Commit 9df9b49

fix dtype
1 parent 22e0eba

2 files changed: 11 additions & 4 deletions


torchao/_models/llama/generate.py

Lines changed: 1 addition & 1 deletion
@@ -1174,7 +1174,7 @@ def callback(x):
     help=(
         "Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-hqq, autoquant, "
         + "autoquant-int4, autoquant-gemlite-int4, autoquant-float8, autoquant-sparse, autoquant-all, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, spinquant, "
-        + "embed-int8wo, marlin_qqq, gemlite-<pack_bitwidth>-<nbits>-<groupsize>, float8dq, int4dq-<nbits>"
+        + "embed-int8wo, marlin_qqq, gemlite-<pack_bitwidth>-<nbits>-<groupsize>, float8dq, int4dq-<nbits>, fbgemm-int4-<group_size>"
     ),
 )
 parser.add_argument(
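
The new fbgemm-int4-<group_size> option follows the same name-<param> convention as the existing choices. A minimal sketch of how such a string could be parsed (illustrative only; the actual dispatch logic in generate.py is not part of this diff):

# Hypothetical parsing of the new option string; names are illustrative.
quantization = "fbgemm-int4-128"
if quantization.startswith("fbgemm-int4-"):
    group_size = int(quantization.split("-")[-1])
    print(f"fbgemm int4 weight-only quantization, group_size={group_size}")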

torchao/dtypes/fbgemm_int4_tensor.py

Lines changed: 10 additions & 3 deletions
@@ -43,7 +43,7 @@ def int4_row_quantize(
     scales = scales.view(x.shape[0], -1).t().contiguous()
     zeros = zeros.view(x.shape[0], -1).t().contiguous()
 
-    return out, scales, zeros
+    return out, scales.to(x.dtype), zeros.to(x.dtype)
 
 
 def pack_int4(x: torch.Tensor) -> torch.Tensor:
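
Casting scales and zeros back to x.dtype matters because group-wise int4 quantization is typically computed in float32 for accuracy, so a bfloat16 weight would otherwise come back with float32 scale/zero tensors and a mismatched dtype downstream. A minimal standalone sketch of the pattern (assumed shapes and epsilon; not the actual torchao/fbgemm code):

import torch

def int4_row_quantize_sketch(x: torch.Tensor, group_size: int = 128):
    to_quant = x.reshape(-1, group_size).to(torch.float)  # do the math in fp32
    max_val = to_quant.amax(dim=1, keepdim=True)
    min_val = to_quant.amin(dim=1, keepdim=True)
    scales = (max_val - min_val).clamp(min=1e-6) / 15  # 2**4 - 1 levels
    zeros = min_val + scales * 8                       # 2**(4 - 1)
    out = to_quant.sub(min_val).div(scales).round().clamp_(0, 15)
    out = out.to(torch.int8).reshape(x.shape)
    # the fix: return scales/zeros in the weight's dtype, not fp32
    return out, scales.to(x.dtype), zeros.to(x.dtype)

w = torch.randn(4, 256, dtype=torch.bfloat16)
_, s, z = int4_row_quantize_sketch(w)
assert s.dtype == torch.bfloat16 and z.dtype == torch.bfloat16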
@@ -68,6 +68,7 @@ def __new__(cls, packed_weight, scale, zero_point, group_size):
         shape = packed_weight.shape
         kwargs = {}
         kwargs["device"] = packed_weight.device
+        kwargs["dtype"] = scale.dtype
         kwargs["requires_grad"] = False
         return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
 
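Passing dtype to _make_wrapper_subclass is what makes the subclass report the scale's dtype to callers; without it, tensor.dtype falls back to the global default (float32). A minimal sketch of that behavior with a throwaway wrapper class (illustrative, not the torchao tensor subclass):

import torch

class WrapperSketch(torch.Tensor):
    # illustrative wrapper subclass, not the torchao class
    @staticmethod
    def __new__(cls, packed_weight, scale):
        kwargs = {}
        kwargs["device"] = packed_weight.device
        kwargs["dtype"] = scale.dtype  # the fix: report the scale's dtype
        kwargs["requires_grad"] = False
        return torch.Tensor._make_wrapper_subclass(cls, packed_weight.shape, **kwargs)

packed = torch.zeros(8, 16, dtype=torch.int8)
scale = torch.ones(8, 1, dtype=torch.bfloat16)
t = WrapperSketch(packed, scale)
assert t.dtype == torch.bfloat16  # would be torch.float32 without the dtype kwarg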

@@ -98,7 +99,10 @@ def _apply_fn_to_data(self, fn):
         )
 
     def __repr__(self):
-        raise NotImplementedError("Subclasses must implement __repr__")
+        return (
+            f"{self.__class__.__name__}(weight={self.packed_weight}, group_size={self.group_size}, "
+            f"shape={self.shape}, device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})"
+        )
 
     @classmethod
     def from_float(cls, w: torch.Tensor, group_size: int = 128):
@@ -136,6 +140,9 @@ def _(func, types, args, kwargs):
             f"{func} is not implemented for non floating point input"
         )
 
+    orig_act_size = input_tensor.size()
+    orig_out_features = weight_tensor.shape[-2]
+
     res = torch.ops.fbgemm.bf16i4bf16_rowwise(
         input_tensor,
         weight_tensor.packed_weight,
@@ -144,7 +151,7 @@ def _(func, types, args, kwargs):
     )
     if bias is not None:
         res = res + bias
-    return res
+    return res.reshape(*orig_act_size[:-1], orig_out_features)
 
 
 @implements([aten.detach.default, aten.alias.default])
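
The last two hunks restore the activation's leading dimensions: the kernel output is reshaped so that, for example, a (batch, seq_len, in_features) input yields a (batch, seq_len, out_features) output. The same pattern in plain torch, with F.linear standing in for the fbgemm kernel:

import torch
import torch.nn.functional as F

def linear_keep_shape(input_tensor, weight, bias=None):
    orig_act_size = input_tensor.size()
    orig_out_features = weight.shape[-2]  # weight is (out_features, in_features)
    # flatten leading dims, run the 2D matmul, then restore the original shape
    res = F.linear(input_tensor.reshape(-1, orig_act_size[-1]), weight, bias)
    return res.reshape(*orig_act_size[:-1], orig_out_features)

x = torch.randn(2, 5, 16)  # (batch, seq_len, in_features)
w = torch.randn(32, 16)    # (out_features, in_features)
assert linear_keep_shape(x, w).shape == (2, 5, 32)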
