From 2d4d44d7b9dd250702d4df60c58a0d9d114a79b7 Mon Sep 17 00:00:00 2001
From: LeiWang1999
Date: Wed, 28 Aug 2024 14:06:45 +0000
Subject: [PATCH 1/7] Merge branch 'main' of
 https://github.com/microsoft/BitBLAS into main

---
 bitblas/gpu/matmul_mma_dequantize.py   | 22 ++++++++++------------
 bitblas/module/__init__.py             |  4 ++--
 bitblas/ops/general_matmul/__init__.py | 21 +++++++++++----------
 3 files changed, 23 insertions(+), 24 deletions(-)

diff --git a/bitblas/gpu/matmul_mma_dequantize.py b/bitblas/gpu/matmul_mma_dequantize.py
index b5c327190..6bc0e39bc 100644
--- a/bitblas/gpu/matmul_mma_dequantize.py
+++ b/bitblas/gpu/matmul_mma_dequantize.py
@@ -992,18 +992,16 @@ def get_idx():

         sch.compute_at(block_shared_local_local, B_shared_vi, preserve_unit_loops=True)
         dequantize_block_local = block_shared_local
-        if ("zeros_mode" in weight_decode_info and
-                weight_decode_info["zeros_mode"] == "quantized"):
-            if ("with_scaling" in weight_decode_info and weight_decode_info["with_scaling"]):
-                block_local_scales = sch.cache_read(dequantize_block_local, b_idx + 1, "local")
-                sch.compute_at(block_local_scales, B_shared_vi, preserve_unit_loops=True)
-                # pop the scale block
-                auto_inline_producers(sch, block_local_scales)
-
-            if ("with_zeros" in weight_decode_info and weight_decode_info["with_zeros"]):
-                block_local_zeros = sch.cache_read(dequantize_block_local, b_idx + 2, "local")
-                sch.compute_at(block_local_zeros, B_shared_vi, preserve_unit_loops=True)
-                auto_inline_producers(sch, block_local_zeros)
+        if ("with_scaling" in weight_decode_info and weight_decode_info["with_scaling"]):
+            block_local_scales = sch.cache_read(dequantize_block_local, b_idx + 1, "local")
+            sch.compute_at(block_local_scales, B_shared_vi, preserve_unit_loops=True)
+            # pop the scale block
+            auto_inline_producers(sch, block_local_scales)
+
+        if ("with_zeros" in weight_decode_info and weight_decode_info["with_zeros"]):
+            block_local_zeros = sch.cache_read(dequantize_block_local, b_idx + 2, "local")
+            sch.compute_at(block_local_zeros, B_shared_vi, preserve_unit_loops=True)
+            auto_inline_producers(sch, block_local_zeros)

         for producer in weight_producers:
             with suppress(Exception):
diff --git a/bitblas/module/__init__.py b/bitblas/module/__init__.py
index 242589c7b..c1cf316ff 100644
--- a/bitblas/module/__init__.py
+++ b/bitblas/module/__init__.py
@@ -312,12 +312,12 @@ def load_and_transform_weight(
         if bias is not None:
             self.bias = bias

-    def repack_from_gptq(self, gptq_module):
+    def repack_from_gptq(self, gptq_module, device="cuda"):
         # qweight in gptq old quant linear stored with (out_features, in_features), should be transposed.
         qweight = gptq_module.qweight.T.contiguous().view(self.TORCH_STORAGE_DTYPE)
         intweight = unpack_qweight(qweight, self.bits).contiguous()
         if self.bitblas_matmul.weight_transform is not None:
-            qweight = self.bitblas_matmul.weight_transform(intweight.cpu()).cuda()
+            qweight = self.bitblas_matmul.weight_transform(intweight.cpu()).to(device)
         self.qweight = qweight
         # scales in gptq old quant linear stored with (in_features // group_size, out_features), should be transposed.
         scales = gptq_module.scales.T.contiguous().view(self.torch_dtype)
diff --git a/bitblas/ops/general_matmul/__init__.py b/bitblas/ops/general_matmul/__init__.py
index dfd22e6e8..39ab6ee2b 100644
--- a/bitblas/ops/general_matmul/__init__.py
+++ b/bitblas/ops/general_matmul/__init__.py
@@ -603,7 +603,7 @@ def transform_weight(self, weight, scale=None, zeros=None, bias=None):
         weight = weight.contiguous()
         if self.W_dtype == self.A_dtype:
             if self.weight_transform is not None:
-                return self.weight_transform(weight.cpu()).cuda().contiguous()
+                return self.weight_transform(weight.cpu()).to(weight.device).contiguous()
             return weight

         source_format, bit = self.source_format, self.bit
@@ -624,7 +624,7 @@ def transform_weight(self, weight, scale=None, zeros=None, bias=None):

         # Apply an optional weight transformation if specified
         if self.weight_transform is not None:
-            weight = self.weight_transform(weight.cpu()).cuda().contiguous()
+            weight = self.weight_transform(weight.cpu()).to(weight.device).contiguous()

         # Prepare the return list with the transformed weight and optionally include scale, zeros, and bias
         result = [weight]
@@ -666,18 +666,19 @@ def forward(self, A, W, scale=None, zeros=None, bias=None, output=None) -> Any:
         if bias is not None:
             args.append(bias)
         args.append(output)
-
-        if self.dynamic_range is not None:
-            m = reduce(operator.mul, A.shape[:-1], 1)
-            args.append(m)
-
-        stream = torch.cuda.current_stream()
-
+
+        # self._forward_from_torch_func(*args)
         if self.lib is None:
             self._forward_from_torch_func(*args)
         else:
+            if self.dynamic_range is not None:
+                m = reduce(operator.mul, A.shape[:-1], 1)
+                args.append(m)
+            print(self.get_source())
+            stream = torch.cuda.current_stream(device=A.device)
+            torch.cuda.set_device(A.device)
             self._forward_from_prebuild_lib(*args, stream=stream.cuda_stream)
-
+
         return output

     def __call__(self, *args: Any, **kwds: Any) -> Any:

From 390ad18588ab291f0f6c541f2899fb7fa5f6597e Mon Sep 17 00:00:00 2001
From: LeiWang1999
Date: Wed, 28 Aug 2024 14:16:53 +0000
Subject: [PATCH 2/7] remove debug print

---
 bitblas/ops/general_matmul/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/bitblas/ops/general_matmul/__init__.py b/bitblas/ops/general_matmul/__init__.py
index 39ab6ee2b..04b88999a 100644
--- a/bitblas/ops/general_matmul/__init__.py
+++ b/bitblas/ops/general_matmul/__init__.py
@@ -674,7 +674,7 @@ def forward(self, A, W, scale=None, zeros=None, bias=None, output=None) -> Any:
             if self.dynamic_range is not None:
                 m = reduce(operator.mul, A.shape[:-1], 1)
                 args.append(m)
-            print(self.get_source())
+
             stream = torch.cuda.current_stream(device=A.device)
             torch.cuda.set_device(A.device)
             self._forward_from_prebuild_lib(*args, stream=stream.cuda_stream)

From dcf3a2e70e1f62b238cbd1bcace9acf537643a03 Mon Sep 17 00:00:00 2001
From: LeiWang1999
Date: Wed, 28 Aug 2024 14:35:44 +0000
Subject: [PATCH 3/7] Refactor Matmul class for improved readability and
 maintainability

---
 bitblas/ops/general_matmul/__init__.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/bitblas/ops/general_matmul/__init__.py b/bitblas/ops/general_matmul/__init__.py
index 04b88999a..ba516e9ef 100644
--- a/bitblas/ops/general_matmul/__init__.py
+++ b/bitblas/ops/general_matmul/__init__.py
@@ -666,7 +666,7 @@ def forward(self, A, W, scale=None, zeros=None, bias=None, output=None) -> Any:
         if bias is not None:
             args.append(bias)
         args.append(output)
-
+
         # self._forward_from_torch_func(*args)
         if self.lib is None:
             self._forward_from_torch_func(*args)
@@ -678,7 +678,7 @@ def forward(self, A, W, scale=None, zeros=None, bias=None, output=None) -> Any:
             stream = torch.cuda.current_stream(device=A.device)
             torch.cuda.set_device(A.device)
             self._forward_from_prebuild_lib(*args, stream=stream.cuda_stream)
-
+
         return output

     def __call__(self, *args: Any, **kwds: Any) -> Any:

From 42b4213bed5083e9c62267f3664ccc50bdb6f2f3 Mon Sep 17 00:00:00 2001
From: LeiWang1999
Date: Wed, 28 Aug 2024 14:36:58 +0000
Subject: [PATCH 4/7] Refactor Matmul class for improved readability and
 maintainability

---
 bitblas/ops/general_matmul/__init__.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/bitblas/ops/general_matmul/__init__.py b/bitblas/ops/general_matmul/__init__.py
index ba516e9ef..72a721429 100644
--- a/bitblas/ops/general_matmul/__init__.py
+++ b/bitblas/ops/general_matmul/__init__.py
@@ -667,7 +667,6 @@ def forward(self, A, W, scale=None, zeros=None, bias=None, output=None) -> Any:
             args.append(bias)
         args.append(output)

-        # self._forward_from_torch_func(*args)
         if self.lib is None:
             self._forward_from_torch_func(*args)
         else:

From d3674ec0885cfdad6b9cde0e8e4e7561d6b568a1 Mon Sep 17 00:00:00 2001
From: LeiWang1999
Date: Wed, 28 Aug 2024 15:05:28 +0000
Subject: [PATCH 5/7] revert set device

---
 bitblas/ops/general_matmul/__init__.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/bitblas/ops/general_matmul/__init__.py b/bitblas/ops/general_matmul/__init__.py
index 72a721429..3d361db43 100644
--- a/bitblas/ops/general_matmul/__init__.py
+++ b/bitblas/ops/general_matmul/__init__.py
@@ -666,7 +666,7 @@ def forward(self, A, W, scale=None, zeros=None, bias=None, output=None) -> Any:
         if bias is not None:
             args.append(bias)
         args.append(output)
-
+
         if self.lib is None:
             self._forward_from_torch_func(*args)
         else:
@@ -675,7 +675,6 @@ def forward(self, A, W, scale=None, zeros=None, bias=None, output=None) -> Any:
                 args.append(m)

             stream = torch.cuda.current_stream(device=A.device)
-            torch.cuda.set_device(A.device)
             self._forward_from_prebuild_lib(*args, stream=stream.cuda_stream)

         return output

From 02176b2de4c3f85ed09acf34fc18965650b29d9c Mon Sep 17 00:00:00 2001
From: LeiWang1999
Date: Wed, 28 Aug 2024 15:15:22 +0000
Subject: [PATCH 6/7] lint fix

---
 bitblas/ops/general_matmul/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/bitblas/ops/general_matmul/__init__.py b/bitblas/ops/general_matmul/__init__.py
index 3d361db43..c26b9c7a9 100644
--- a/bitblas/ops/general_matmul/__init__.py
+++ b/bitblas/ops/general_matmul/__init__.py
@@ -666,7 +666,7 @@ def forward(self, A, W, scale=None, zeros=None, bias=None, output=None) -> Any:
         if bias is not None:
             args.append(bias)
         args.append(output)
-
+
         if self.lib is None:
             self._forward_from_torch_func(*args)
         else:

From fae145c22dae353e10bd37f4233234b1f4055a34 Mon Sep 17 00:00:00 2001
From: LeiWang1999
Date: Fri, 30 Aug 2024 03:13:57 +0000
Subject: [PATCH 7/7] register fp8 for dynamic

---
 3rdparty/tvm | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/3rdparty/tvm b/3rdparty/tvm
index 3c6317a1e..c5d987715 160000
--- a/3rdparty/tvm
+++ b/3rdparty/tvm
@@ -1 +1 @@
-Subproject commit 3c6317a1ea614b7277ffe0b4ede18b4652afad1c
+Subproject commit c5d9877154f67dc0d3651032b15521e09dfda882
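
Note: the net effect of patches 1-6 is that BitBLAS derives devices and streams from its input tensors instead of assuming the default CUDA device. A minimal sketch of that pattern, with hypothetical `transform` and `launch` callables standing in for `weight_transform` and `_forward_from_prebuild_lib` (illustration only, not the library API):

import torch

def move_transformed_weight(transform, intweight: torch.Tensor, device) -> torch.Tensor:
    # Pattern from repack_from_gptq (patch 1): run the transform on CPU, then
    # move the result to the caller-supplied device rather than hard-coding .cuda().
    return transform(intweight.cpu()).to(device)

def launch_on_device_of(A: torch.Tensor, launch, *args):
    # Pattern from Matmul.forward (patches 1 and 5): take the current stream of
    # A's device, not the process-wide default, and pass its raw handle along.
    stream = torch.cuda.current_stream(device=A.device)
    launch(*args, stream=stream.cuda_stream)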