@@ -244,6 +244,13 @@ def deepgemm_w8a8_block_fp8_linear_with_fallback(
 
     if not (shape_supported and dtype_supported):
         # fall back to triton
+        # If weight_scale is in UE8M0 packed format (int32), convert it back to float32.
+        # UE8M0 format has shape (N, K//block_k//4) with dtype int32;
+        # Triton expects shape (N//block_n, K//block_k) with dtype float32.
+        if weight_scale.dtype == torch.int32:
+            weight_scale = _unpack_ue8m0_scale_for_triton(
+                weight_scale, weight.shape, block_size
+            )
         return triton_w8a8_block_fp8_linear(
             input, weight, block_size, weight_scale, input_scale, bias
         )
@@ -267,6 +274,67 @@ def deepgemm_w8a8_block_fp8_linear_with_fallback(
     return output.to(dtype=output_dtype).view(*output_shape)
 
 
+def _unpack_ue8m0_scale_for_triton(
+    sf_packed: torch.Tensor,
+    weight_shape: Tuple[int, int],
+    block_size: List[int],
+) -> torch.Tensor:
+    """
+    Unpack a UE8M0-packed scale tensor back to float32 format for the triton kernel.
+
+    The UE8M0 format packs scales as:
+    - Shape: (N, K//block_k//4) with dtype int32
+    - Each int32 contains 4 uint8 scale values
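+    - Each uint8 is an E8M0 exponent e; the decoded float32 scale is 2.0 ** (e - 127)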
+
+    Triton expects:
+    - Shape: (N//block_n, K//block_k) with dtype float32
+
+    Args:
+        sf_packed: Packed scale tensor with shape (N, packed_k_groups) and dtype int32
+        weight_shape: (N, K) shape of the weight tensor
+        block_size: [block_n, block_k] quantization block size
+
+    Returns:
+        Unpacked scale tensor with shape (n_groups, k_groups) and dtype float32
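+
+    Example (illustrative shapes only, assuming block_size == [128, 128]):
+        >>> packed = torch.randint(0, 2**31 - 1, (4096, 8), dtype=torch.int32)
+        >>> _unpack_ue8m0_scale_for_triton(packed, (4096, 4096), [128, 128]).shape
+        torch.Size([32, 32])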
+    """
+    assert sf_packed.dtype == torch.int32
+    assert len(sf_packed.shape) == 2
+
+    N, K = weight_shape
+    block_n, block_k = block_size
+    n_groups = ceil_div(N, block_n)
+    k_groups = ceil_div(K, block_k)
+
+    mn_repeat, k_div_4 = sf_packed.shape
+    k_packed = k_div_4 * 4
+
+    # Unpack int32 -> 4x uint8 -> float32
+    # Each uint8 represents an exponent in UE8M0 format
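+    # Shifting each exponent byte into the float32 exponent field (bits 30..23)
+    # reinterprets it as the power-of-two scale 2.0 ** (e - 127).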
+    sf_u8 = sf_packed.contiguous().view(torch.uint8).view(mn_repeat, k_packed)
+    sf_fp32 = (sf_u8.to(torch.int32) << 23).view(torch.float32)
+
+    # Handle the row dimension - it may be replicated per output row or already grouped
+    if mn_repeat == N:
+        # Each block of block_n consecutive rows (typically 128) shares one scale,
+        # so keep only the representative row of each block.
+        # sf_fp32 shape: (N, k_packed) -> (n_groups, k_packed)
+        indices = torch.arange(0, N, block_n, device=sf_packed.device)
+        sf_fp32 = sf_fp32.index_select(0, indices)
+    elif mn_repeat == n_groups:
+        # Already in the correct n_groups format
+        pass
+    else:
+        raise ValueError(
+            f"Unexpected scale shape: sf_packed.shape={sf_packed.shape}, "
+            f"weight_shape={weight_shape}, block_size={block_size}"
+        )
+
+    # Crop the k dimension to the expected size (remove padding, if any)
+    sf_fp32 = sf_fp32[:, :k_groups].contiguous()
+
+    return sf_fp32
+
+
 def aiter_w8a8_block_fp8_linear(
     input: torch.Tensor,
     weight: torch.Tensor,