@@ -56,7 +56,7 @@ def choose_qparams_stretched_affine(
 
     scale = max_val / quant_max
     scale = scale.to(dtype=scale_dtype, device=input_float.device)
-    zero_point = torch.full_like(scale, 0.5, dtype=zero_point_dtype)
+    zero_point = torch.full_like(scale, -0.5, dtype=zero_point_dtype)
     return scale, zero_point
 
 
@@ -95,34 +95,12 @@ def quantize_stretched_affine(
     max_val = scale.mul(quant_max)
     input_float = input_float.clamp(min=-max_val, max=max_val)
     with torch.no_grad():
-        quant = torch.round(input_float / scale - zero_point)
+        # difference from quantize_affine: add zero_point before rounding
+        quant = torch.round(input_float / scale + zero_point)
     quant = quant.to(dtype=target_dtype).view(original_shape)
     return quant
 
 
-def dequantize_stretched_affine(
-    data: torch.Tensor,
-    block_size: Tuple[int, ...],
-    scale: torch.Tensor,
-    zero_point: torch.Tensor,
-    data_dtype: torch.dtype,
-    quant_min: Optional[int] = None,
-    quant_max: Optional[int] = None,
-    output_dtype: torch.dtype = torch.float32,
-) -> torch.Tensor:
-    # allow float data_dtype instead of restricting to _SUB_BYTE_UINT_BOUNDS
-    return dequantize_affine(
-        data,
-        block_size,
-        scale,
-        -zero_point,
-        data_dtype,
-        quant_min=quant_min,
-        quant_max=quant_max,
-        output_dtype=output_dtype,
-    )
-
-
 class StretchedAffineQuantizedTensor(AffineQuantizedTensor):
     @classmethod
     def from_hp_to_intx(
@@ -184,7 +162,7 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor
         )
 
         data, scale, zero_point = self.tensor_impl.get_plain()
-        dq = dequantize_stretched_affine(
+        dq = dequantize_affine(
             data,
             self.block_size,
             scale,
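For context, the sign flip keeps the quantize/dequantize pair consistent with the plain affine primitives: with zero_point = -0.5, round(x / scale + zero_point) matches the old round(x / scale - 0.5), and dequantize_affine applied to the same zero_point reproduces what the removed wrapper computed by negating a +0.5 zero point. A minimal sketch of that round trip in plain torch (hypothetical values; assumes dequantize_affine effectively computes (q - zero_point) * scale here, and does not call the library API):

import torch

# hypothetical 2-bit "stretched" range: codes map to {-1.5, -0.5, 0.5, 1.5}
quant_max = 1.5
x = torch.tensor([-0.9, -0.2, 0.3, 0.8])

scale = x.abs().max() / quant_max
zero_point = torch.full_like(scale, -0.5)  # new convention from this diff

max_val = scale * quant_max
xc = x.clamp(min=-max_val, max=max_val)
q = torch.round(xc / scale + zero_point)   # == round(x / scale - 0.5)
dq = (q - zero_point) * scale              # assumed dequantize_affine behavior

# the old pair (zero_point = +0.5, subtract before rounding, negate on dequant)
# yields the same codes and reconstruction, so the dedicated wrapper is redundant
q_old = torch.round(xc / scale - 0.5)
dq_old = (q_old + 0.5) * scale
assert torch.equal(q, q_old) and torch.allclose(dq, dq_old)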