diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_aten.cpp b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_aten.cpp
index 6c4d600e6f..9c0832c5b7 100644
--- a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_aten.cpp
+++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_aten.cpp
@@ -66,6 +66,7 @@ TORCH_LIBRARY(torchao, m) {
   DEFINE_OP(3);
   DEFINE_OP(4);
   DEFINE_OP(5);
+  DEFINE_OP(6);
 }
 
 TORCH_LIBRARY_IMPL(torchao, CPU, m) {
@@ -74,6 +75,7 @@ TORCH_LIBRARY_IMPL(torchao, CPU, m) {
   DEFINE_CPU_IMPL(3);
   DEFINE_CPU_IMPL(4);
   DEFINE_CPU_IMPL(5);
+  DEFINE_CPU_IMPL(6);
 }
 
 TORCH_LIBRARY_IMPL(torchao, Meta, m) {
@@ -82,4 +84,5 @@ TORCH_LIBRARY_IMPL(torchao, Meta, m) {
   DEFINE_META_IMPL(3);
   DEFINE_META_IMPL(4);
   DEFINE_META_IMPL(5);
+  DEFINE_META_IMPL(6);
 }
diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w6s.cpp b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w6s.cpp
new file mode 100644
index 0000000000..4667442de1
--- /dev/null
+++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w6s.cpp
@@ -0,0 +1,29 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
+
+// Unlike ATen, ExecuTorch op registration appears to only allow one
+// EXECUTORCH_LIBRARY per cpp file due to a name redefinition error, so a new
+// file is needed for each variant.
+
+#include <torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h>
+
+namespace {
+Tensor _op_out(
+    RuntimeContext& ctx,
+    const Tensor& activations,
+    const Tensor& packed_weights,
+    const Tensor& group_size_tensor,
+    const Tensor& n_tensor,
+    const Tensor& k_tensor,
+    Tensor& out) {
+  (void)ctx;
+  linear_out_cpu(
+      activations, packed_weights, group_size_tensor, n_tensor, k_tensor, out);
+  return out;
+}
+} // namespace
+
+EXECUTORCH_LIBRARY(torchao, "_linear_8bit_act_6bit0zp_weight.out", _op_out);
diff --git a/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w6sz.cpp b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w6sz.cpp
new file mode 100644
index 0000000000..199bbcc932
--- /dev/null
+++ b/torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight_executorch/w6sz.cpp
@@ -0,0 +1,29 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the license found in the
+// LICENSE file in the root directory of this source tree.
+
+// Unlike ATen, ExecuTorch op registration appears to only allow one
+// EXECUTORCH_LIBRARY per cpp file due to a name redefinition error, so a new
+// file is needed for each variant.
+
+#include <torchao/experimental/ops/linear_8bit_act_xbit_weight/op_linear_8bit_act_xbit_weight-impl.h>
+
+namespace {
+Tensor _op_out(
+    RuntimeContext& ctx,
+    const Tensor& activations,
+    const Tensor& packed_weights,
+    const Tensor& group_size_tensor,
+    const Tensor& n_tensor,
+    const Tensor& k_tensor,
+    Tensor& out) {
+  (void)ctx;
+  linear_out_cpu(
+      activations, packed_weights, group_size_tensor, n_tensor, k_tensor, out);
+  return out;
+}
+} // namespace
+
+EXECUTORCH_LIBRARY(torchao, "_linear_8bit_act_6bit_weight.out", _op_out);
diff --git a/torchao/experimental/quant_api.py b/torchao/experimental/quant_api.py
index ac21c75221..7666f8b78d 100644
--- a/torchao/experimental/quant_api.py
+++ b/torchao/experimental/quant_api.py
@@ -115,15 +115,7 @@ def forward(self, x):
         lead_shape = x.shape[0:-2]
         m, k = x.shape[-2], x.shape[-1]
         n = self._n.shape[1]
-        x = x.reshape(-1, m, k)
-
-        res = [
-            self._linear_op(
-                x[i, :, :], self.packed_weights, self._group_size, self._n, self._k
-            )
-            for i in range(x.shape[0])
-        ]
-        res = torch.stack(res)
+        res = self._linear_op(x.reshape(-1, k), self.packed_weights, self._group_size, self._n, self._k)
         res = res.reshape(*lead_shape, m, n)
         return res
 
@@ -206,7 +198,7 @@
 
 def _maybe_get_quantized_linear_native(nbit, has_weight_zeros):
     try:
-        if nbit in [1, 2, 3, 4, 5]:
+        if nbit in [1, 2, 3, 4, 5, 6]:
             wzp_suffix = "" if has_weight_zeros else "0zp"
             return _Int8DynActIntxWeightQuantizedLinearNative(
                 pack_weight_op=getattr(
diff --git a/torchao/experimental/tests/test_linear_8bit_act_xbit_weight_quantizer.py b/torchao/experimental/tests/test_linear_8bit_act_xbit_weight_quantizer.py
index 1966fd1589..45dce490ac 100644
--- a/torchao/experimental/tests/test_linear_8bit_act_xbit_weight_quantizer.py
+++ b/torchao/experimental/tests/test_linear_8bit_act_xbit_weight_quantizer.py
@@ -36,7 +36,7 @@ def test_accuracy(self):
         m = 1
         n = 1071
         k = 4096
-        activations = torch.randn(m, k, dtype=torch.float32)
+        activations = torch.randn(2, 3, m, k, dtype=torch.float32)
         model = torch.nn.Sequential(*[torch.nn.Linear(k, n, bias=False)])
 
         for nbit in [1, 2, 3, 4, 5, 6, 7]:
@@ -84,7 +84,7 @@
         layers = [torch.nn.Linear(k0, k1, bias=False), torch.nn.Linear(k1, k2, bias=False), torch.nn.Linear(k2, k3, bias=False)]
         model = torch.nn.Sequential(*layers)
 
-        activations = torch.randn(2, 1, m, k0, dtype=torch.float32)
+        activations = torch.randn(m, k0, dtype=torch.float32)
 
         print("Quantizing model")
         quantizer = Int8DynActIntxWeightQuantizer(
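
Reviewer note: the quant_api.py hunk above swaps a per-batch loop (reshape to (-1, m, k), call the op once per slice, torch.stack the results) for a single call on rows flattened to (-1, k). A minimal sketch with plain torch and toy shapes (the matmul, w, and the 8/16 dims are illustrative stand-ins for the packed-weight op, not from the diff) showing the two paths agree:

import torch

x = torch.randn(2, 3, 5, 16)  # (*lead_shape, m, k)
w = torch.randn(8, 16)        # (n, k)
lead_shape = x.shape[0:-2]
m, k = x.shape[-2], x.shape[-1]
n = w.shape[0]

# Old path: collapse leading dims to a batch, one matmul per batch, stack.
xb = x.reshape(-1, m, k)
res_loop = torch.stack([xb[i, :, :] @ w.T for i in range(xb.shape[0])])
res_loop = res_loop.reshape(*lead_shape, m, n)

# New path: flatten every leading dim into rows and do one matmul.
res_flat = (x.reshape(-1, k) @ w.T).reshape(*lead_shape, m, n)

assert torch.allclose(res_loop, res_flat)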
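
A second, more speculative sketch: exercising the newly registered 6-bit variant end to end, assuming the experimental CPU kernels are built and loadable. The Int8DynActIntxWeightQuantizer constructor keywords below (device, precision, bitwidth, groupsize, has_weight_zeros) are assumptions inferred from the truncated test snippet, not confirmed by this diff:

import torch
from torchao.experimental.quant_api import Int8DynActIntxWeightQuantizer

model = torch.nn.Sequential(torch.nn.Linear(4096, 1071, bias=False))
# Keyword names are assumed from the test suite's usage pattern.
quantizer = Int8DynActIntxWeightQuantizer(
    device="cpu",
    precision=torch.float32,
    bitwidth=6,  # the variant enabled by this diff
    groupsize=128,
    has_weight_zeros=False,  # "0zp" suffix -> _linear_8bit_act_6bit0zp_weight
)
quantized_model = quantizer.quantize(model)
y = quantized_model(torch.randn(2, 3, 1, 4096))  # leading dims handled by reshape(-1, k)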