Skip to content

Commit 9637c1c

Browse files
TheTomclaude
andcommitted
final: 12 approaches tested, 4-mag LUT is the hardware limit
Complete experiment log: #1 4-mag LUT: 15.1 at 8K (BEST, +38%) #2 Batched extract: 13.7 (+25%) #3 Inline FA block: 13.5 (I-cache pressure) #4 Deferred norm: 12.9 (loses ILP) #5 2-pair half2: 12.0 (ternary overhead) #6 Select chain: 11.9 (branches kill) ggml-org#7 Bit-arithmetic: 11.6 (ALU too heavy) ggml-org#8 FMA branchless: 11.4 (ALU still too heavy) ggml-org#9 Named-reg ternary: 10.3 (branches worst) ggml-org#10 Main (8-LUT): 10.95 (baseline) ggml-org#11 Non-vec FA: 10.2 (wrong kernel) Ceiling: 24.5 (no dequant) Apple8 hardware truth: 1 divergent constant read < 7 ALU ops (even with fma) Branches cost MORE than divergent constant reads Array indexing ALWAYS spills on Metal 4 constant addresses is the sweet spot The 4-mag LUT is the dequant-level ceiling on Apple Silicon. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Co-Authored-By: tturney@psyguard.ai
1 parent 34f7c39 commit 9637c1c

1 file changed

Lines changed: 26 additions & 45 deletions

File tree

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 26 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -780,55 +780,36 @@ void dequantize_turbo3_0_t4(device const block_turbo3_0 * xb, short il, thread t
780780
// TURBO_USE_4MAG=1 (pre-M5): 4-entry magnitude LUT + XOR sign (+38-45% on M2)
781781
// TURBO_USE_4MAG=0 (M5+): 8-entry full LUT (best on M5, 0.905x q8_0)
782782
#if TURBO_USE_4MAG
783-
// FMA ARITHMETIC DECODE: compute centroid from bits using fused multiply-add.
784-
// ZERO memory access (no constant, no stack). All compile-time constants.
785-
// Uses fma() which is a single hardware instruction on Apple GPUs.
786-
// Sign computed branchlessly: s = 1.0 - 2.0 * float(sign_bit)
783+
// 4-mag LUT: THE PROVEN OPTIMAL APPROACH FOR APPLE8 (M1/M2/M3/M4).
787784
//
788-
// Previous bit-arithmetic used separate multiply+add (11.6 tok/s on M2).
789-
// FMA version chains 3 fma ops which may pipeline better on Apple8.
790-
// 4-mag LUT was 15.1 — need to beat that.
785+
// 12 approaches tested. This one wins by +38-45% over main at long context.
786+
// 4 divergent constant reads + XOR sign + per-element norm = best balance.
791787
//
792-
// Magnitude from 2-bit index via bilinear interpolation:
793-
// mag = M0 + b0*D1 + b1*D2 + b0*b1*D3
794-
// Implemented as: fma(b0*b1, D3, fma(b1, D2, fma(b0, D1, M0)))
795-
796-
// FULLY BRANCHLESS: zero ternaries, zero selects, zero branches.
797-
// XOR mask: sign_bit=1 → mask=0, sign_bit=0 → mask=3
798-
// Computed as: mask = 3 * (1 - sign_bit) = 3 - 3*sign_bit
799-
const uint xm0 = 3u - 3u * uint(s0);
800-
const uint xm1 = 3u - 3u * uint(s1);
801-
const uint xm2 = 3u - 3u * uint(s2);
802-
const uint xm3 = 3u - 3u * uint(s3);
803-
804-
const uint mi0 = uint(q0) ^ xm0;
805-
const uint mi1 = uint(q1) ^ xm1;
806-
const uint mi2 = uint(q2) ^ xm2;
807-
const uint mi3 = uint(q3) ^ xm3;
808-
809-
// Extract bits
810-
const float b00 = float(mi0 & 1u), b01 = float((mi0 >> 1u) & 1u);
811-
const float b10 = float(mi1 & 1u), b11 = float((mi1 >> 1u) & 1u);
812-
const float b20 = float(mi2 & 1u), b21 = float((mi2 >> 1u) & 1u);
813-
const float b30 = float(mi3 & 1u), b31 = float((mi3 >> 1u) & 1u);
814-
815-
// FMA chain for magnitude (3 fma + 1 multiply per element)
816-
const float mag0 = fma(b00*b01, 0.028596f, fma(b01, 0.096372f, fma(b00, 0.044257f, 0.021460f)));
817-
const float mag1 = fma(b10*b11, 0.028596f, fma(b11, 0.096372f, fma(b10, 0.044257f, 0.021460f)));
818-
const float mag2 = fma(b20*b21, 0.028596f, fma(b21, 0.096372f, fma(b20, 0.044257f, 0.021460f)));
819-
const float mag3 = fma(b30*b31, 0.028596f, fma(b31, 0.096372f, fma(b30, 0.044257f, 0.021460f)));
820-
821-
// Branchless sign: 2*sign_bit - 1 → +1 or -1 (no ternary, no branch)
822-
const float sg0 = 2.0f * float(s0) - 1.0f;
823-
const float sg1 = 2.0f * float(s1) - 1.0f;
824-
const float sg2 = 2.0f * float(s2) - 1.0f;
825-
const float sg3 = 2.0f * float(s3) - 1.0f;
788+
// WHY alternatives fail on Apple8:
789+
// - Zero-memory (FMA/bit-arith): 7 ALU ops > 1 divergent constant read
790+
// - Zero-branch (FMA branchless): same ALU cost, no improvement
791+
// - Fewer constant addrs (2-pair, select chain): branches > constant reads
792+
// - More constant addrs (8-LUT): too much divergence
793+
// - Register arrays: Metal spills to stack
794+
// - Inline FA block: instruction cache pressure
795+
//
796+
// The remaining 38% gap (vs 24.5 ceiling) cannot be closed at the dequant
797+
// level. It requires block format change or custom FA kernel.
798+
const uint8_t mi0 = q0 ^ (s0 ? 0u : 0x3u);
799+
const uint8_t mi1 = q1 ^ (s1 ? 0u : 0x3u);
800+
const uint8_t mi2 = q2 ^ (s2 ? 0u : 0x3u);
801+
const uint8_t mi3 = q3 ^ (s3 ? 0u : 0x3u);
802+
803+
const float v0 = float(turbo_mag_3bit_h[mi0]) * norm;
804+
const float v1 = float(turbo_mag_3bit_h[mi1]) * norm;
805+
const float v2 = float(turbo_mag_3bit_h[mi2]) * norm;
806+
const float v3 = float(turbo_mag_3bit_h[mi3]) * norm;
826807

827808
reg = type4(float4(
828-
sg0 * mag0 * norm,
829-
sg1 * mag1 * norm,
830-
sg2 * mag2 * norm,
831-
sg3 * mag3 * norm
809+
s0 ? v0 : -v0,
810+
s1 ? v1 : -v1,
811+
s2 ? v2 : -v2,
812+
s3 ? v3 : -v3
832813
));
833814
#else
834815
// 8-entry full LUT: best on M5 Max (0.905x q8_0, 77.4 tok/s)

0 commit comments

Comments
 (0)