@@ -43,7 +43,7 @@ def int4_row_quantize(
     scales = scales.view(x.shape[0], -1).t().contiguous()
     zeros = zeros.view(x.shape[0], -1).t().contiguous()
 
-    return out, scales, zeros
+    return out, scales.to(x.dtype), zeros.to(x.dtype)
 
 
 def pack_int4(x: torch.Tensor) -> torch.Tensor:
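# --- Illustrative sketch, not part of the diff: why scales/zeros are cast back to x.dtype ---
# The min/max math typically runs in float32, so without the .to(x.dtype) cast the
# returned scales/zeros would not match a bf16 input. Simplified per-row asymmetric
# 4-bit quantizer; all names below are illustrative only.
import torch

x = torch.randn(4, 8, dtype=torch.bfloat16)
max_val = x.amax(dim=1, keepdim=True).float()
min_val = x.amin(dim=1, keepdim=True).float()
scales = (max_val - min_val).clamp(min=1e-6) / 15.0            # 4-bit range 0..15
zeros = min_val
q = torch.round((x.float() - zeros) / scales).clamp(0, 15).to(torch.uint8)
scales, zeros = scales.to(x.dtype), zeros.to(x.dtype)           # cast back, as in the change above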
@@ -68,6 +68,7 @@ def __new__(cls, packed_weight, scale, zero_point, group_size):
         shape = packed_weight.shape
         kwargs = {}
         kwargs["device"] = packed_weight.device
+        kwargs["dtype"] = scale.dtype
         kwargs["requires_grad"] = False
         return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
 
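# --- Illustrative sketch, not part of the diff: effect of passing dtype to _make_wrapper_subclass ---
# With kwargs["dtype"] = scale.dtype, the wrapper tensor reports the scale's dtype
# (e.g. bf16) rather than the framework default. Simplified, hypothetical wrapper;
# the real class in this file carries more state plus dispatch logic.
import torch

class _Int4Wrapper(torch.Tensor):
    @staticmethod
    def __new__(cls, packed_weight, scale):
        kwargs = {
            "device": packed_weight.device,
            "dtype": scale.dtype,          # wrapper now reports the scale dtype
            "requires_grad": False,
        }
        return torch.Tensor._make_wrapper_subclass(cls, packed_weight.shape, **kwargs)

    def __init__(self, packed_weight, scale):
        self.packed_weight = packed_weight
        self.scale = scale

w = _Int4Wrapper(torch.zeros(8, 4, dtype=torch.int8), torch.ones(8, dtype=torch.bfloat16))
print(w.dtype)   # torch.bfloat16, not the int8 storage dtype of packed_weight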
@@ -98,7 +99,10 @@ def _apply_fn_to_data(self, fn):
         )
 
     def __repr__(self):
-        raise NotImplementedError("Subclasses must implement __repr__")
+        return (
+            f"{self.__class__.__name__}(weight={self.packed_weight}, group_size={self.group_size}, "
+            f"shape={self.shape}, device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})"
+        )
 
     @classmethod
     def from_float(cls, w: torch.Tensor, group_size: int = 128):
@@ -136,6 +140,9 @@ def _(func, types, args, kwargs):
             f"{func} is not implemented for non floating point input"
         )
 
+    orig_act_size = input_tensor.size()
+    orig_out_features = weight_tensor.shape[-2]
+
     res = torch.ops.fbgemm.bf16i4bf16_rowwise(
         input_tensor,
         weight_tensor.packed_weight,
@@ -144,7 +151,7 @@ def _(func, types, args, kwargs):
     )
     if bias is not None:
         res = res + bias
-    return res
+    return res.reshape(*orig_act_size[:-1], orig_out_features)
 
 
 @implements([aten.detach.default, aten.alias.default])
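# --- Illustrative sketch, not part of the diff: what the added reshape restores ---
# Capturing orig_act_size/orig_out_features and reshaping the result folds a 2D
# matmul output back onto the original leading activation dims. Plain-matmul
# stand-in, not the fbgemm kernel:
import torch

x = torch.randn(2, 5, 16)                  # (batch, seq, in_features)
w = torch.randn(8, 16)                     # (out_features, in_features)

orig_act_size = x.size()
orig_out_features = w.shape[-2]            # 8

res = x.reshape(-1, x.shape[-1]) @ w.t()   # 2D result: (2 * 5, 8)
res = res.reshape(*orig_act_size[:-1], orig_out_features)
assert res.shape == (2, 5, 8)              # leading activation dims restored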