diff --git a/3rdparty/tvm b/3rdparty/tvm
index 3c6317a1e..c5d987715 160000
--- a/3rdparty/tvm
+++ b/3rdparty/tvm
@@ -1 +1 @@
-Subproject commit 3c6317a1ea614b7277ffe0b4ede18b4652afad1c
+Subproject commit c5d9877154f67dc0d3651032b15521e09dfda882
diff --git a/bitblas/gpu/matmul_mma_dequantize.py b/bitblas/gpu/matmul_mma_dequantize.py
index b5c327190..6bc0e39bc 100644
--- a/bitblas/gpu/matmul_mma_dequantize.py
+++ b/bitblas/gpu/matmul_mma_dequantize.py
@@ -992,18 +992,16 @@ def get_idx():
         sch.compute_at(block_shared_local_local, B_shared_vi, preserve_unit_loops=True)

         dequantize_block_local = block_shared_local
-        if ("zeros_mode" in weight_decode_info and
-                weight_decode_info["zeros_mode"] == "quantized"):
-            if ("with_scaling" in weight_decode_info and weight_decode_info["with_scaling"]):
-                block_local_scales = sch.cache_read(dequantize_block_local, b_idx + 1, "local")
-                sch.compute_at(block_local_scales, B_shared_vi, preserve_unit_loops=True)
-                # pop the scale block
-                auto_inline_producers(sch, block_local_scales)
-
-            if ("with_zeros" in weight_decode_info and weight_decode_info["with_zeros"]):
-                block_local_zeros = sch.cache_read(dequantize_block_local, b_idx + 2, "local")
-                sch.compute_at(block_local_zeros, B_shared_vi, preserve_unit_loops=True)
-                auto_inline_producers(sch, block_local_zeros)
+        if ("with_scaling" in weight_decode_info and weight_decode_info["with_scaling"]):
+            block_local_scales = sch.cache_read(dequantize_block_local, b_idx + 1, "local")
+            sch.compute_at(block_local_scales, B_shared_vi, preserve_unit_loops=True)
+            # pop the scale block
+            auto_inline_producers(sch, block_local_scales)
+
+        if ("with_zeros" in weight_decode_info and weight_decode_info["with_zeros"]):
+            block_local_zeros = sch.cache_read(dequantize_block_local, b_idx + 2, "local")
+            sch.compute_at(block_local_zeros, B_shared_vi, preserve_unit_loops=True)
+            auto_inline_producers(sch, block_local_zeros)

         for producer in weight_producers:
             with suppress(Exception):
diff --git a/bitblas/module/__init__.py b/bitblas/module/__init__.py
index 242589c7b..c1cf316ff 100644
--- a/bitblas/module/__init__.py
+++ b/bitblas/module/__init__.py
@@ -312,12 +312,12 @@ def load_and_transform_weight(
        if bias is not None:
            self.bias = bias

-    def repack_from_gptq(self, gptq_module):
+    def repack_from_gptq(self, gptq_module, device="cuda"):
        # qweight in gptq old quant linear stored with (out_features, in_features), should be transposed.
        qweight = gptq_module.qweight.T.contiguous().view(self.TORCH_STORAGE_DTYPE)
        intweight = unpack_qweight(qweight, self.bits).contiguous()
        if self.bitblas_matmul.weight_transform is not None:
-            qweight = self.bitblas_matmul.weight_transform(intweight.cpu()).cuda()
+            qweight = self.bitblas_matmul.weight_transform(intweight.cpu()).to(device)
        self.qweight = qweight
        # scales in gptq old quant linear stored with (in_features // group_size, out_features), should be transposed.
        scales = gptq_module.scales.T.contiguous().view(self.torch_dtype)
diff --git a/bitblas/ops/general_matmul/__init__.py b/bitblas/ops/general_matmul/__init__.py
index dfd22e6e8..c26b9c7a9 100644
--- a/bitblas/ops/general_matmul/__init__.py
+++ b/bitblas/ops/general_matmul/__init__.py
@@ -603,7 +603,7 @@ def transform_weight(self, weight, scale=None, zeros=None, bias=None):
        weight = weight.contiguous()
        if self.W_dtype == self.A_dtype:
            if self.weight_transform is not None:
-                return self.weight_transform(weight.cpu()).cuda().contiguous()
+                return self.weight_transform(weight.cpu()).to(weight.device).contiguous()
            return weight

        source_format, bit = self.source_format, self.bit
@@ -624,7 +624,7 @@ def transform_weight(self, weight, scale=None, zeros=None, bias=None):

        # Apply an optional weight transformation if specified
        if self.weight_transform is not None:
-            weight = self.weight_transform(weight.cpu()).cuda().contiguous()
+            weight = self.weight_transform(weight.cpu()).to(weight.device).contiguous()

        # Prepare the return list with the transformed weight and optionally include scale, zeros, and bias
        result = [weight]
@@ -667,15 +667,14 @@ def forward(self, A, W, scale=None, zeros=None, bias=None, output=None) -> Any:
            args.append(bias)
        args.append(output)

-        if self.dynamic_range is not None:
-            m = reduce(operator.mul, A.shape[:-1], 1)
-            args.append(m)
-
-        stream = torch.cuda.current_stream()
-
        if self.lib is None:
            self._forward_from_torch_func(*args)
        else:
+            if self.dynamic_range is not None:
+                m = reduce(operator.mul, A.shape[:-1], 1)
+                args.append(m)
+
+            stream = torch.cuda.current_stream(device=A.device)
            self._forward_from_prebuild_lib(*args, stream=stream.cuda_stream)

        return output
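
Note on the device-handling changes above: the following is a minimal usage sketch, not part of the patch. It assumes an already-constructed BitBLAS Linear instance (here called bitblas_linear) and a GPTQ QuantLinear module (here called gptq_module); both names are placeholders. With the new device argument, the repacked weight ends up on the device the caller requests instead of the default "cuda", and forward() now uses the current stream of the input's device.

    import torch

    # Placeholder objects, assumed to exist with matching shapes/dtypes.
    target = torch.device("cuda:1")  # previously the repacked weight always landed on the default "cuda" device
    bitblas_linear.repack_from_gptq(gptq_module, device=target)

    # forward() now calls torch.cuda.current_stream(device=A.device), so an input
    # residing on cuda:1 is launched on cuda:1's current stream.
    x = torch.randn(4, bitblas_linear.in_features, dtype=torch.float16, device=target)
    y = bitblas_linear(x)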