kernels.quantize

Dequantization utilities for bitsandbytes and FP8 integration.

Functions

Name Description
dequantize Fast NF4 dequantization using bitsandbytes CUDA kernels.
dequantize_fp8 Dequantize FP8 block-quantized weights: W_dequant = W_fp8 * scale_inv.
dequantize_weight Unified dequantization for both torchao and bnb quantized weights.
is_quant_tensor_subclass True for torchao quant tensor subclasses (NF4Tensor, AffineQuantizedTensor, etc.).

dequantize

kernels.quantize.dequantize(W, quant_state=None, out=None)

Fast NF4 dequantization using bitsandbytes CUDA kernels.

Performs efficient dequantization of weights from NF4 format using bitsandbytes’ optimized CUDA implementations. Supports both legacy list and new QuantState formats.

Parameters

Name Type Description Default
W torch.Tensor Quantized weight tensor to dequantize required
quant_state QuantState | list | torch.Tensor | None Quantization state containing metadata needed for dequantization. Can be either a QuantState object or legacy list format. If None, returns W unchanged. None
out torch.Tensor | None Optional output tensor for storing dequantized results. Must match expected shape and dtype if provided. None

Returns

Name Type Description
torch.Tensor Dequantized tensor in the specified dtype (fp16 or bf16). Will be transposed if input W was transposed.

Raises

Name Type Description
AssertionError If the provided output tensor does not match the expected shape or dtype.

Note

Uses CUDA streams for better performance when available in newer bitsandbytes versions (>0.43.3).

dequantize_fp8

kernels.quantize.dequantize_fp8(W, scale_inv, dtype=torch.bfloat16)

Dequantize FP8 block-quantized weights: W_dequant = W_fp8 * scale_inv.

Parameters

Name Type Description Default
W torch.Tensor FP8 weight tensor [out_features, in_features] in float8_e4m3fn. required
scale_inv torch.Tensor Per-block inverse scale [ceil(out/block), ceil(in/block)] or per-tensor scalar. required
dtype torch.dtype Output dtype (default bf16). torch.bfloat16

Returns

Name Type Description
torch.Tensor Dequantized tensor in the specified dtype.

dequantize_weight

kernels.quantize.dequantize_weight(W, quant_state=None, transpose=False)

Unified dequantization for both torchao and bnb quantized weights.

For torchao tensor subclasses (AffineQuantizedTensor, NF4Tensor), dequantizes using the appropriate instance method. For bnb Params4bit, delegates to the optimized CUDA kernel in dequantize.

Parameters

Name Type Description Default
W torch.Tensor Quantized weight tensor [out_features, in_features]. required
quant_state QuantState | list | None bnb QuantState (None for torchao / unquantized). None
transpose bool If True, return [in_features, out_features]. False

Returns

Name Type Description
torch.Tensor Dequantized float tensor, optionally transposed.
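
The dispatch described above can be sketched roughly as follows. Everything here is an assumption made for illustration: the branch order, the dequantize() method name on the tensor subclass, the ToyQuantTensor stand-in, and the injected bnb_dequantize callable (which stands in for the dequantize function documented on this page so that the sketch runs without bitsandbytes or a GPU).

```python
import torch


class ToyQuantTensor(torch.Tensor):
    """Hypothetical stand-in for a torchao quantized tensor subclass."""

    def dequantize(self):
        # Pretend the stored values are codes with a fixed scale of 0.5.
        return self.as_subclass(torch.Tensor) * 0.5


def dequantize_weight_sketch(W, quant_state=None, transpose=False, bnb_dequantize=None):
    """Illustrative dispatch only; not the library's implementation."""
    if quant_state is not None:
        # bnb path: delegate to the (injected) NF4 dequantization routine.
        out = bnb_dequantize(W, quant_state)
    elif type(W) not in (torch.Tensor, torch.nn.Parameter):
        # Tensor-subclass path: let the subclass dequantize itself.
        out = W.dequantize()
    else:
        # Unquantized weight: nothing to do.
        out = W
    # Optionally return [in_features, out_features].
    return out.t() if transpose else out


W_toy = torch.full((2, 3), 4.0).as_subclass(ToyQuantTensor)
W_out = dequantize_weight_sketch(W_toy, transpose=True)
```

In this toy run W_out is a plain torch.Tensor of shape [3, 2], matching the documented behavior of transpose=True.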

is_quant_tensor_subclass

kernels.quantize.is_quant_tensor_subclass(W)

Returns True for torchao quant tensor subclasses (NF4Tensor, AffineQuantizedTensor, etc.), that is, for anything whose type is not exactly torch.Tensor or torch.nn.Parameter. Checking type(W) is not torch.Tensor alone is unsafe: Parameter is a subclass of Tensor, not the same type, so the bare check would misclassify every unquantized PEFT base weight as quantized.
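
The difference between the two checks can be seen in a few lines of PyTorch. This is a minimal sketch of the check as described above, not a copy of the library source; the names is_quant_tensor_subclass_sketch and FakeQuantTensor are hypothetical and exist only for this example.

```python
import torch


def is_quant_tensor_subclass_sketch(W):
    # Exact type comparison: plain Tensor and Parameter are both "not quantized".
    return type(W) not in (torch.Tensor, torch.nn.Parameter)


class FakeQuantTensor(torch.Tensor):
    """Hypothetical stand-in for a torchao quantized tensor subclass."""


plain = torch.zeros(2, 2)
param = torch.nn.Parameter(torch.zeros(2, 2))
fake_quant = torch.zeros(2, 2).as_subclass(FakeQuantTensor)

# The bare check flags an ordinary Parameter as if it were quantized.
naive_flags_param = type(param) is not torch.Tensor

# The two-type check treats Tensor and Parameter alike and flags only the subclass.
results = [is_quant_tensor_subclass_sketch(t) for t in (plain, param, fake_quant)]
```

Here naive_flags_param is True even though param is an ordinary unquantized weight, while results is [False, False, True].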