Commit 6bdd3f6

Remove dequantize_stretched_affine

1 parent 3cbb705, commit 6bdd3f6

File tree: 2 files changed (+4, -28 lines)

torchao/prototype/parq/quant/quant_api.py
Lines changed: 4 additions & 26 deletions

@@ -56,7 +56,7 @@ def choose_qparams_stretched_affine(
 
     scale = max_val / quant_max
     scale = scale.to(dtype=scale_dtype, device=input_float.device)
-    zero_point = torch.full_like(scale, 0.5, dtype=zero_point_dtype)
+    zero_point = torch.full_like(scale, -0.5, dtype=zero_point_dtype)
     return scale, zero_point
 
 

@@ -95,34 +95,12 @@ def quantize_stretched_affine(
     max_val = scale.mul(quant_max)
     input_float = input_float.clamp(min=-max_val, max=max_val)
     with torch.no_grad():
-        quant = torch.round(input_float / scale - zero_point)
+        # difference from quantize_affine: add zero_point before rounding
+        quant = torch.round(input_float / scale + zero_point)
     quant = quant.to(dtype=target_dtype).view(original_shape)
     return quant
 
 
-def dequantize_stretched_affine(
-    data: torch.Tensor,
-    block_size: Tuple[int, ...],
-    scale: torch.Tensor,
-    zero_point: torch.Tensor,
-    data_dtype: torch.dtype,
-    quant_min: Optional[int] = None,
-    quant_max: Optional[int] = None,
-    output_dtype: torch.dtype = torch.float32,
-) -> torch.Tensor:
-    # allow float data_dtype instead of restricting to _SUB_BYTE_UINT_BOUNDS
-    return dequantize_affine(
-        data,
-        block_size,
-        scale,
-        -zero_point,
-        data_dtype,
-        quant_min=quant_min,
-        quant_max=quant_max,
-        output_dtype=output_dtype,
-    )
-
-
 class StretchedAffineQuantizedTensor(AffineQuantizedTensor):
     @classmethod
     def from_hp_to_intx(

@@ -184,7 +162,7 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor
         )
 
         data, scale, zero_point = self.tensor_impl.get_plain()
-        dq = dequantize_stretched_affine(
+        dq = dequantize_affine(
             data,
             self.block_size,
             scale,

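Aside, not part of the diff: the zero_point sign flip is what lets the call site switch from the removed dequantize_stretched_affine to the stock dequantize_affine. Below is a minimal sketch in plain PyTorch, not the torchao primitives; the values (b = 2, quant_max = 1.5, the sample tensor) are assumptions for illustration. It checks that the standard affine dequant, (q - zero_point) * scale, now round-trips the stretched quantizer without negating zero_point, which is the workaround the deleted helper existed for.

import torch

x = torch.tensor([-0.9, -0.3, 0.2, 0.8])
quant_max = 1.5                      # assumed stretched bound for b=2 (levels {-1.5, -0.5, 0.5, 1.5})
scale = x.abs().max() / quant_max    # assumed max_val, mirroring scale = max_val / quant_max
zero_point = torch.tensor(-0.5)      # was +0.5 before this commit

# quantize_stretched_affine now adds zero_point before rounding ...
q = torch.round(x / scale + zero_point)

# ... so the plain affine dequant, (q - zero_point) * scale, lands back on the
# stretched half-integer grid directly; no `-zero_point` negation is needed.
dq = (q - zero_point) * scale
print(q, dq)  # q on the integer grid, dq back on the stretched grid (close to x)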
torchao/prototype/parq/quant/uniform_torchao.py
Lines changed: 0 additions & 2 deletions

@@ -29,7 +29,6 @@
 
 from .quant_api import (
     choose_qparams_stretched_affine,
-    dequantize_stretched_affine,
     quantize_stretched_affine,
 )
 from .quantizer import Quantizer

@@ -157,7 +156,6 @@ def __init__(self, b: int, int_shift: float = 0.5) -> None:
 
         self._choose_qparams = partial(choose_qparams_stretched_affine, b=b)
         self._quantize = quantize_stretched_affine
-        self._dequantize = dequantize_stretched_affine
 
     def get_quant_size(self, b: int) -> int:
         return math.floor(2**b - 2 * self.int_shift) + 1

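For context, a quick standalone check of the get_quant_size formula shown in the last hunk. This is a hedged sketch: the bit widths b = 2 and b = 4 are example values, and int_shift = 0.5 is the default visible in the diff.

import math

def get_quant_size(b: int, int_shift: float = 0.5) -> int:
    # number of representable levels on the stretched grid
    return math.floor(2**b - 2 * int_shift) + 1

print(get_quant_size(2))  # floor(4 - 1) + 1 = 4
print(get_quant_size(4))  # floor(16 - 1) + 1 = 16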
0 commit comments
