kernels.quantize
Dequantization utilities for bitsandbytes and FP8 integration.
Functions
| Name | Description |
|---|---|
| dequantize | Fast NF4 dequantization using bitsandbytes CUDA kernels. |
| dequantize_fp8 | Dequantize FP8 block-quantized weights: W_dequant = W_fp8 * scale_inv. |
| dequantize_weight | Unified dequantization for both torchao and bnb quantized weights. |
| is_quant_tensor_subclass | True for torchao quant tensor subclasses (NF4Tensor, AffineQuantizedTensor, etc.). |
dequantize
kernels.quantize.dequantize(W, quant_state=None, out=None)
Fast NF4 dequantization using bitsandbytes CUDA kernels.
Performs efficient dequantization of weights from NF4 format using bitsandbytes’
optimized CUDA implementations. Supports both legacy list and new QuantState
formats.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| W | torch.Tensor | Quantized weight tensor to dequantize | required |
| quant_state | QuantState \| list \| torch.Tensor \| None | Quantization state containing metadata needed for dequantization. Can be either a QuantState object or legacy list format. If None, returns W unchanged. | None |
| out | torch.Tensor \| None | Optional output tensor for storing dequantized results. Must match expected shape and dtype if provided. | None |
Returns
| Name | Type | Description |
|---|---|---|
| | torch.Tensor | Dequantized tensor in the specified dtype (fp16 or bf16). Will be transposed if input W was transposed. |
Raises
| Name | Type | Description |
|---|---|---|
| | AssertionError | If the provided output tensor doesn’t match the expected shape or dtype. |
Note
Uses CUDA streams for better performance when available in newer bitsandbytes
versions (>0.43.3).
dequantize_fp8
kernels.quantize.dequantize_fp8(W, scale_inv, dtype=torch.bfloat16)
Dequantize FP8 block-quantized weights: W_dequant = W_fp8 * scale_inv.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| W | torch.Tensor | FP8 weight tensor [out_features, in_features] in float8_e4m3fn. | required |
| scale_inv | torch.Tensor | Per-block inverse scale [ceil(out/block), ceil(in/block)] or per-tensor scalar. | required |
| dtype | torch.dtype | Output dtype (default bf16). | torch.bfloat16 |
Returns
| Name | Type | Description |
|---|---|---|
| | torch.Tensor | Dequantized tensor in the specified dtype. |
dequantize_weight
kernels.quantize.dequantize_weight(W, quant_state=None, transpose=False)
Unified dequantization for both torchao and bnb quantized weights.
For torchao tensor subclasses (AffineQuantizedTensor, NF4Tensor), dequantizes
using the appropriate instance method. For bnb Params4bit, delegates to the
optimized CUDA kernel in dequantize.
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| W | torch.Tensor | Quantized weight tensor [out_features, in_features]. | required |
| quant_state | QuantState \| list \| None | bnb QuantState (None for torchao / unquantized). | None |
| transpose | bool | If True, return [in_features, out_features]. | False |
Returns
| Name | Type | Description |
|---|---|---|
| | torch.Tensor | Dequantized float tensor, optionally transposed. |
is_quant_tensor_subclass
kernels.quantize.is_quant_tensor_subclass(W)
True for torchao quant tensor subclasses (NF4Tensor, AffineQuantizedTensor, etc.), that is, anything that is not a plain torch.Tensor or torch.nn.Parameter.
The check type(W) is not torch.Tensor is unsafe on its own: Parameter is a subclass of Tensor, not the same type, so the bare check misclassifies every unquantized PEFT base weight as quantized.
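The pitfall described above can be reproduced in a few lines. The following is a minimal sketch of one way such a check could be written, not the library's actual source: `naive_check`, `is_quant_tensor_subclass_sketch`, and `FakeQuantTensor` are hypothetical names, and `FakeQuantTensor` merely stands in for a torchao tensor subclass so the example runs without torchao installed.

```python
import torch


def naive_check(W):
    # The bare check: flags anything whose exact type is not torch.Tensor.
    return type(W) is not torch.Tensor


def is_quant_tensor_subclass_sketch(W):
    # Exact-type comparison against both plain types, so Parameter is
    # treated as unquantized while any other subclass is flagged.
    return type(W) not in (torch.Tensor, torch.nn.Parameter)


class FakeQuantTensor(torch.Tensor):
    """Hypothetical stand-in for a torchao quantized tensor subclass."""


plain = torch.zeros(2, 2)
param = torch.nn.Parameter(torch.zeros(2, 2))
fake_quant = torch.zeros(2, 2).as_subclass(FakeQuantTensor)

for name, W in [("plain", plain), ("param", param), ("fake_quant", fake_quant)]:
    print(name, naive_check(W), is_quant_tensor_subclass_sketch(W))
```

The naive check reports the `Parameter` as quantized, while the two-type comparison does not. An `isinstance` test would not help here, because every tensor subclass, quantized or not, is an instance of `torch.Tensor`.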