@@ -1380,6 +1380,84 @@ def forward(
             input_ids, hidden_states, self.lm_head, forward_batch
         )

+    def post_load_weights(self):
+
+        # Perform post-processing after loading weights
+
+        if not global_server_args_dict["disable_mla"]:
+            for layer_id in range(self.config.num_hidden_layers):
+                self_attn = self.model.layers[layer_id].self_attn
+                if hasattr(self_attn.kv_b_proj, "qweight"):
+                    # AWQ compatible
+                    if _is_cuda:
+                        w = awq_dequantize(
+                            self_attn.kv_b_proj.qweight,
+                            self_attn.kv_b_proj.scales,
+                            self_attn.kv_b_proj.qzeros,
+                        ).T
+                    else:
+                        w = ops.awq_dequantize(
+                            self_attn.kv_b_proj.qweight,
+                            self_attn.kv_b_proj.scales,
+                            self_attn.kv_b_proj.qzeros,
+                            0,
+                            0,
+                            0,
+                        ).T
+                else:
+                    w = self_attn.kv_b_proj.weight
+                # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
+                # This may affect the accuracy of fp8 model.
+                if hasattr(self.quant_config, "weight_block_size") and w.dtype in (
+                    torch.float8_e4m3fn,
+                    torch.float8_e4m3fnuz,
+                ):
+                    weight_block_size = self.quant_config.weight_block_size
+                    if weight_block_size is not None:
+                        assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+                        if _is_hip:
+                            weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                                weight=w,
+                                weight_scale=self_attn.kv_b_proj.weight_scale_inv,
+                                input_scale=None,
+                            )
+                        else:
+                            weight = w
+                            weight_scale = self_attn.kv_b_proj.weight_scale_inv
+
+                        w, scale = block_quant_to_tensor_quant(
+                            weight, weight_scale, weight_block_size
+                        )
+                        self_attn.w_scale = scale
+                if w.dtype == torch.int8:
+                    if hasattr(self.quant_config, "weight_block_size"):
+                        # block-wise int8 need it
+                        weight_block_size = self.quant_config.weight_block_size
+                        if weight_block_size is not None:
+                            assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+                            weight = w
+                            weight_scale = self_attn.kv_b_proj.weight_scale_inv
+                            w = int8_block_dequant(
+                                weight, weight_scale, weight_block_size
+                            ).to(torch.bfloat16)
+                    else:
+                        # channel-wise int8 need it
+                        w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
+                            torch.bfloat16
+                        )
+                w_kc, w_vc = w.unflatten(
+                    0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
+                ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+                self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
+                self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
+                if (
+                    hasattr(self_attn.kv_b_proj, "weight_scale")
+                    and self_attn.w_scale is None
+                ):
+                    self_attn.w_scale = self_attn.kv_b_proj.weight_scale
+                    if _is_hip:
+                        self_attn.w_scale *= 2.0
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
@@ -1504,79 +1582,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 )
                 weight_loader(param, loaded_weight)

-        if not global_server_args_dict["disable_mla"]:
-            for layer_id in range(self.config.num_hidden_layers):
-                self_attn = self.model.layers[layer_id].self_attn
-                if hasattr(self_attn.kv_b_proj, "qweight"):
-                    # AWQ compatible
-                    if _is_cuda:
-                        w = awq_dequantize(
-                            self_attn.kv_b_proj.qweight,
-                            self_attn.kv_b_proj.scales,
-                            self_attn.kv_b_proj.qzeros,
-                        ).T
-                    else:
-                        w = ops.awq_dequantize(
-                            self_attn.kv_b_proj.qweight,
-                            self_attn.kv_b_proj.scales,
-                            self_attn.kv_b_proj.qzeros,
-                            0,
-                            0,
-                            0,
-                        ).T
-                else:
-                    w = self_attn.kv_b_proj.weight
-                # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
-                # This may affect the accuracy of fp8 model.
-                if hasattr(self.quant_config, "weight_block_size") and w.dtype in (
-                    torch.float8_e4m3fn,
-                    torch.float8_e4m3fnuz,
-                ):
-                    weight_block_size = self.quant_config.weight_block_size
-                    if weight_block_size is not None:
-                        assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-                        if _is_hip:
-                            weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
-                                weight=w,
-                                weight_scale=self_attn.kv_b_proj.weight_scale_inv,
-                                input_scale=None,
-                            )
-                        else:
-                            weight = w
-                            weight_scale = self_attn.kv_b_proj.weight_scale_inv
-
-                        w, scale = block_quant_to_tensor_quant(
-                            weight, weight_scale, weight_block_size
-                        )
-                        self_attn.w_scale = scale
-                if w.dtype == torch.int8:
-                    if hasattr(self.quant_config, "weight_block_size"):
-                        # block-wise int8 need it
-                        weight_block_size = self.quant_config.weight_block_size
-                        if weight_block_size is not None:
-                            assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
-                            weight = w
-                            weight_scale = self_attn.kv_b_proj.weight_scale_inv
-                            w = int8_block_dequant(
-                                weight, weight_scale, weight_block_size
-                            ).to(torch.bfloat16)
-                    else:
-                        # channel-wise int8 need it
-                        w = w.to(torch.bfloat16) * self_attn.kv_b_proj.weight_scale.to(
-                            torch.bfloat16
-                        )
-                w_kc, w_vc = w.unflatten(
-                    0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
-                ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
-                self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
-                self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
-                if (
-                    hasattr(self_attn.kv_b_proj, "weight_scale")
-                    and self_attn.w_scale is None
-                ):
-                    self_attn.w_scale = self_attn.kv_b_proj.weight_scale
-                    if _is_hip:
-                        self_attn.w_scale *= 2.0
+        self.post_load_weights()

     def get_embed_and_head(self):
         return self.model.embed_tokens.weight, self.lm_head.weight
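
For context, a minimal usage sketch of what this refactor enables: the MLA-specific post-processing of `kv_b_proj` (dequantization and the `w_kc` / `w_vc` / `w_scale` derivation) now lives in `post_load_weights()`, and `load_weights()` simply calls it at the end, so the same step can be re-run on its own after weights are replaced in place. The class name `Model` and the `new_weights` iterable below are placeholders, not part of this diff.

```python
# Hypothetical usage sketch; `Model` stands in for the model class patched
# above and `new_weights` for any iterable of (name, tensor) pairs.
model = Model(config)

# Normal path: load_weights() now ends by calling post_load_weights().
model.load_weights(new_weights)

# If kv_b_proj weights are later swapped in by some other path, the derived
# MLA tensors (w_kc, w_vc, w_scale) can be refreshed without a full reload.
model.post_load_weights()
```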