64 changes: 64 additions & 0 deletions benchmarks/benchmark_marlin_qqq.py
@@ -0,0 +1,64 @@
import torch
import pandas as pd
from torchao.utils import benchmark_torch_function_in_microseconds
from torchao.ops import marlin_qqq_gemm
from torchao.quantization.marlin_qqq import marlin_qqq_workspace, pack_to_marlin_qqq
from tqdm import tqdm


def get_problem(m, n, k, group_size=-1):
    if group_size == -1:
        group_size = k
    dev = torch.device("cuda")
    A_ref = torch.randn((m, k), dtype=torch.half, device=dev)
    B_ref = torch.randn((k, n), dtype=torch.half, device=dev)

    A = torch.randint(-128, 127, (m, k), dtype=torch.int8, device=dev)
    B = torch.randint(low=-(2**31), high=2**31, size=(k, n), device=dev)
    s_tok = torch.ones((m, 1), dtype=torch.float, device=dev)
    if group_size == k:
        s_group = torch.tensor([], dtype=torch.half, device=dev)
    else:
        s_group = torch.ones((k // group_size, n), dtype=torch.half, device=dev)
    s_channel = torch.ones((1, n), dtype=torch.float, device=dev)
    B, s_group, s_channel = pack_to_marlin_qqq(
        B, s_group, s_channel, num_bits=4, group_size=group_size
    )
    qqq_workspace = marlin_qqq_workspace(n)
    return A, B, A_ref, B_ref, s_tok, s_channel, s_group, qqq_workspace


def benchmark(m: int, k: int, n: int, group_size: int):
    A, B, A_ref, B_ref, s_tok, s_channel, s_group, qqq_workspace = get_problem(
        m, n, k, group_size
    )

    fp16_time = benchmark_torch_function_in_microseconds(torch.matmul, A_ref, B_ref)
    marlin_qqq_w4a8_time = benchmark_torch_function_in_microseconds(
        marlin_qqq_gemm, A, B, s_tok, s_channel, s_group, qqq_workspace, m, n, k
    )

    return {
        "m": m,
        "k": k,
        "n": n,
        "group_size": group_size,
        "fp16_latency (ms)": fp16_time,
        "marlin_qqq_w4a8_latency (ms)": marlin_qqq_w4a8_time,
        "speedup (d/s)": fp16_time / marlin_qqq_w4a8_time,
    }


if __name__ == "__main__":
    k_vals = (8192, 8192, 8192, 28672)
    n_vals = (8192, 10240, 57344, 8192)

    results = []
    for group_size in tqdm([-1, 128]):
        for m in tqdm([1 << i for i in range(10)]):
            for n, k in zip(n_vals, k_vals):
                results.append(benchmark(m, k, n, group_size))

    df = pd.DataFrame(results)
    df.to_csv("marlin_qqq_w4a8_llm_benchmark_results.csv", index=False)
    print(df.to_markdown(index=False))
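
For a quick single-shape sanity check, the benchmark function above can also be called directly; this is a sketch rather than part of the diff, and it assumes a CUDA GPU and a torchao build that includes the Marlin QQQ kernels:

# Sketch (not part of the diff): time one (m, k, n) shape and print the measured speedup.
if torch.cuda.is_available():
    single = benchmark(m=16, k=8192, n=8192, group_size=128)
    print(single["speedup (d/s)"])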
129 changes: 129 additions & 0 deletions test/quantization/test_marlin_qqq.py
@@ -0,0 +1,129 @@
import copy

import pytest
import torch
from torch import nn
from torch.testing._internal.common_utils import TestCase, run_tests

from torchao.dtypes import MarlinQQQLayout
from torchao.quantization.marlin_qqq import (
pack_to_marlin_qqq,
unpack_from_marlin_qqq,
)
from torchao.quantization.quant_api import (
int8_dynamic_activation_int4_weight,
quantize_,
)
from torchao.quantization.quant_primitives import (
MappingType,
choose_qparams_and_quantize_affine_qqq,
)
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5


class MarlinQQQ(TestCase):
    def setUp(self):
        super().setUp()
        torch.manual_seed(0)

        self.input = torch.randn((64, 32, 8192), dtype=torch.float16, device="cuda")
        self.model = (
            nn.Sequential(
                nn.Linear(8192, 21504),
                nn.Linear(21504, 8192),
                nn.ReLU(),
                nn.Linear(8192, 21504),
                nn.Linear(21504, 8192),
            )
            .half()
            .cuda()
        )

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
    def test_marlin_qqq(self):
        output_ref = self.model(self.input)
        for group_size in [-1, 128]:
            modelq = copy.deepcopy(self.model)
            quantize_(
                modelq,
                int8_dynamic_activation_int4_weight(
                    group_size=group_size,
                    mapping_type=MappingType.SYMMETRIC,
                    act_mapping_type=MappingType.SYMMETRIC,
                    layout=MarlinQQQLayout(),
                ),
            )
            output = modelq(self.input)

            assert torch.allclose(
                output, output_ref, atol=1e-1
            ), "Results are not close"

    @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="Needs PyTorch 2.5+")
    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
    def test_marlin_qqq_compile(self):
        model_copy = copy.deepcopy(self.model)
        model_copy.forward = torch.compile(model_copy.forward, fullgraph=True)
        output_ref = model_copy(self.input)

        for group_size in [-1, 128]:
            modelq = copy.deepcopy(self.model)
            quantize_(
                modelq,
                int8_dynamic_activation_int4_weight(
                    group_size=group_size,
                    mapping_type=MappingType.SYMMETRIC,
                    act_mapping_type=MappingType.SYMMETRIC,
                    layout=MarlinQQQLayout(),
                ),
            )
            modelq.forward = torch.compile(modelq.forward, fullgraph=True)
            output = modelq(self.input)

            assert torch.allclose(
                output, output_ref, atol=1e-1
            ), "Results are not close"

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
    def test_pack_unpack_equivalence(self):
        num_bits = 4
        shape = (11008, 4096)

        w = torch.rand(shape, dtype=torch.float16, device="cuda")

        for group_size in [-1, 128]:
            # Quantize weights
            q_w, s_group, s_channel, _ = choose_qparams_and_quantize_affine_qqq(
                w, num_bits, group_size
            )

            q_w = q_w.t()
            s_group = s_group.t()
            s_channel = s_channel.t()

            # Test pack/unpack equivalence
            q_w_comp, packed_s_group, packed_s_channel = pack_to_marlin_qqq(
                q_w, s_group, s_channel, num_bits, group_size
            )
            unpacked_q_w, unpacked_s_group, unpacked_s_channel = unpack_from_marlin_qqq(
                q_w_comp,
                packed_s_group,
                packed_s_channel,
                q_w.shape,
                num_bits,
                group_size,
            )

            assert torch.equal(
                q_w, unpacked_q_w
            ), "Unpacked weights do not match original weights"
            assert torch.equal(
                s_channel, unpacked_s_channel
            ), "Unpacked s_channel do not match original s_channel"
            assert torch.equal(
                s_group, unpacked_s_group
            ), "Unpacked s_group do not match original s_group"


if __name__ == "__main__":
    run_tests()
109 changes: 109 additions & 0 deletions test/test_ops.py
@@ -13,6 +13,11 @@
from torchao.utils import is_fbcode, TORCH_VERSION_AT_LEAST_2_5, compute_max_diff
from torchao.dtypes.floatx import from_scaled_tc_floatx
from torchao.sparsity.marlin import marlin_24_workspace, pack_to_marlin_24, inject_24
from torchao.quantization.marlin_qqq import (
marlin_qqq_workspace,
pack_to_marlin_qqq,
)
from torchao.quantization.quant_primitives import choose_qparams_and_quantize_affine_qqq
import pytest

if is_fbcode():
@@ -426,5 +431,109 @@ def test_marlin_24(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors):
)


MARLIN_QQQ_BATCH_SIZE = [1, 4, 8, 16, 32, 64]
MARLIN_QQQ_K_CHUNKS = [128]
MARLIN_QQQ_N_CHUNKS = [64, 128, 256]
MNK_FACTORS = [
    (1, 1, 1),
    (1, 4, 8),
    (1, 7, 5),
    (13, 17, 67),
    (26, 37, 13),
    (67, 13, 11),
]
MARLIN_QQQ_SUPPORTED_NUM_BITS = [4]
MARLIN_QQQ_SUPPORTED_GROUP_SIZES = [-1, 128]

MARLIN_TEST_PARAMS = list(
    itertools.product(
        MARLIN_QQQ_BATCH_SIZE,
        MARLIN_QQQ_K_CHUNKS,
        MARLIN_QQQ_N_CHUNKS,
        MARLIN_QQQ_SUPPORTED_NUM_BITS,
        MARLIN_QQQ_SUPPORTED_GROUP_SIZES,
        MNK_FACTORS,
    )
)


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize(
    "batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors",
    MARLIN_TEST_PARAMS,
    ids=str,
)
def test_marlin_qqq(batch_size, k_chunk, n_chunk, num_bits, group_size, mnk_factors):
    int8_traits = torch.iinfo(torch.int8)
    m_factor, n_factor, k_factor = mnk_factors

    size_m = m_factor
    size_k = k_chunk * k_factor
    size_n = n_chunk * n_factor

    a_input = torch.randn(
        (batch_size, size_m, size_k), dtype=torch.float16, device="cuda"
    )
    b_weight = torch.rand((size_n, size_k), dtype=torch.float16, device="cuda")

    # Reshape input into 2D tensor
    input_2d = a_input.view(-1, a_input.shape[-1])
    a_input_in, a_input_out = input_2d.shape

    # Quantize activations
    s_a = (
        input_2d.abs()
        .max(dim=-1, keepdim=True)[0]
        .div(int8_traits.max)
        .to(torch.float32)
    )
    q_a = (
        (input_2d / s_a).round().clamp(int8_traits.min, int8_traits.max).to(torch.int8)
    )

    # Quantize weights
    q_w, s_group, s_channel, w_ref = choose_qparams_and_quantize_affine_qqq(
        b_weight, num_bits, group_size
    )
    q_w = q_w.t()
    s_group = s_group.t()
    s_channel = s_channel.t()
    w_ref = w_ref.t()
    marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel = pack_to_marlin_qqq(
        q_w, s_group, s_channel, num_bits, group_size
    )

    workspace = marlin_qqq_workspace(size_n)

    # Obtains reference output
    output_ref = torch.matmul(q_a.half() * s_a.half(), w_ref)
    output_ref = output_ref.reshape(a_input.shape[:-1] + (size_n,))

    fn_inputs = (
        q_a,
        marlin_qqq_q_w,
        s_a,
        marlin_qqq_s_channel,
        marlin_qqq_s_group,
        workspace,
        a_input_in,
        size_n,
        a_input_out,
    )
    output = torchao.ops.marlin_qqq_gemm(*fn_inputs)
    output = output.reshape(a_input.shape[:-1] + (size_n,))

    max_diff = compute_max_diff(output, output_ref)
    assert max_diff < 0.04

    # Performs opcheck
    test_utils = ["test_schema", "test_autograd_registration", "test_faketensor"]
    opcheck(
        torch.ops.torchao.marlin_qqq_gemm,
        fn_inputs,
        test_utils=test_utils,
    )


if __name__ == "__main__":
    run_tests()
20 changes: 17 additions & 3 deletions torchao/_models/llama/generate.py
@@ -14,6 +14,7 @@
import torch._dynamo.config
import torch._inductor.config
from torchao.utils import get_model_size_in_bytes
from torchao.quantization.quant_primitives import MappingType
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

def device_sync(device):
@@ -211,6 +212,7 @@ def main(
        int8_weight_only,
        int8_dynamic_activation_int8_weight,
        int4_weight_only,
        int8_dynamic_activation_int4_weight,
        fpx_weight_only,
        uintx_weight_only,
        autoquant,
@@ -235,8 +237,20 @@
            assert group_size in [32,64,128,256], f"int4wo group_size needs to be one of [32,64,128,256] but got {group_size}"
            quantize_(model, int4_weight_only(group_size=group_size))
        if "marlin" in quantization:
-            from torchao.dtypes import MarlinSparseLayout
-            quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))
+            if "qqq" in quantization:
+                from torchao.dtypes import MarlinQQQLayout
+                quantize_(
+                    model,
+                    int8_dynamic_activation_int4_weight(
+                        group_size=128,
+                        mapping_type=MappingType.SYMMETRIC,
+                        act_mapping_type=MappingType.SYMMETRIC,
+                        layout=MarlinQQQLayout(),
+                    ),
+                )
+            else:
+                from torchao.dtypes import MarlinSparseLayout
+                quantize_(model, int4_weight_only(layout=MarlinSparseLayout()))
        if "fp6" in quantization:
            quantize_(model, fpx_weight_only(3, 2))
        if "embed-int8wo" in quantization:
@@ -474,7 +488,7 @@ def callback(x):
        help=(
            'Which quantization techniques to apply: int8dq, int8wo, fp6, int4wo-<groupsize>, int4wo-<groupsize>-hqq, autoquant, '
            +'autoquant-int4, autoquant-float8, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq, sparse-marlin, spinquant, '
-            +'embed-int8wo'
+            +'embed-int8wo, marlin_qqq'
        )
    )
    parser.add_argument("--calibration_limit", type=int, default=10, help="Number of calibration examples")
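
For context, the new marlin_qqq option above (selected via --quantization marlin_qqq, per the help text) reduces to the following standalone torchao call; this is a minimal sketch, not part of the diff, assuming an fp16 model on CUDA, with a single torch.nn.Linear standing in for the Llama model:

# Sketch (not part of the diff): the same Marlin QQQ (W4A8) config applied outside generate.py.
import torch
from torchao.quantization.quant_api import quantize_, int8_dynamic_activation_int4_weight
from torchao.quantization.quant_primitives import MappingType
from torchao.dtypes import MarlinQQQLayout

model = torch.nn.Linear(8192, 8192).half().cuda()  # stand-in for the actual model
quantize_(
    model,
    int8_dynamic_activation_int4_weight(
        group_size=128,
        mapping_type=MappingType.SYMMETRIC,
        act_mapping_type=MappingType.SYMMETRIC,
        layout=MarlinQQQLayout(),
    ),
)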