Skip to content

Conversation

@jianan-gu
Copy link
Contributor

@jianan-gu jianan-gu commented Jul 22, 2025

Motivation

Fixes the Kimi-K2-Instruct (FP8) TP=6 failure on CPU.

ValueError: Weight output_partition_size = 2112 is not divisible by weight quantization block_n = 128.

Modifications

Considering weight_block_size when padding TP, to make self.num_heads * self.qk_head_dim / tp_size divisible by weight_block_size in the below ColumnParallelLinear.

  self.q_b_proj = ColumnParallelLinear(
      q_lora_rank,
      self.num_heads * self.qk_head_dim,
      bias=False,
      quant_config=quant_config,
      prefix=add_prefix("q_b_proj", prefix),
      tp_rank=attn_tp_rank,
      tp_size=attn_tp_size,
  )

In our case, self.num_heads = 64, self.qk_head_dim =192, tp_size = 6, block_size = 128

Before this PR, self.num_heads is padded to 66, self.num_heads * self.qk_head_dim / tp_size = 2112, which is not divisible by 128.
After this PR, self.num_heads is padded to 72, self.num_heads * self.qk_head_dim / tp_size = 2304, which is divisible by 128.

MISC.
This PR also covers a minor fix for unquant MoE module with apply_router_weight_on_input config, sine CPU AMX path supports now.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summary of Changes

Hello @jianan-gu, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a tensor parallelism (TP) padding issue on CPU, specifically when weight_block_size is a factor. It refines the padding logic for attention heads by introducing head_dim as a critical parameter, ensuring that padding accounts for cases where head dimensions are not perfectly aligned with weight block sizes. This change improves the robustness of CPU-based tensor parallelism configurations.

Highlights

  • Refined Padding Logic: The get_num_heads_padding_size function now incorporates head_dim into its padding calculation. Previously, padding was applied if tp_size was odd and weight_block_size was present. The updated logic adds an additional condition: padding is also applied if head_dim is not perfectly divisible by the first element of weight_block_size (i.e., weight_block_size[0]), ensuring better alignment for tensor parallelism.
  • Dynamic Head Dimension Calculation: The adjust_config_with_unaligned_cpu_tp function has been enhanced to dynamically determine the appropriate head_dim to pass to the padding function. It now checks for the presence of qk_nope_head_dim and qk_rope_head_dim to compute a combined qk_head_dim, which is then prioritized as the head_dim for padding calculations. This ensures that models with more complex head dimension configurations (e.g., for different types of attention) are correctly handled.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in issue comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments or fill out our survey to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request addresses a padding issue for Tensor Parallelism on CPUs when using block-wise quantized weights. The changes introduce a new condition for padding based on the model's head_dim and the weight_block_size. The logic correctly identifies the head_dim for various model architectures and passes it to an updated padding calculation function.

@jianan-gu jianan-gu marked this pull request as ready for review July 22, 2025 06:59
@mingfeima mingfeima added cpu cpu backend performance optimization intel labels Jul 22, 2025
@mingfeima
Copy link
Collaborator

I reckon that we still need to move fast for DP-Attn and EPMoE to skip padding for head dimension... not very decent to have a TP=6 for any scenario

@jianan-gu
Copy link
Contributor Author

I reckon that we still need to move fast for DP-Attn and EPMoE to skip padding for head dimension... not very decent to have a TP=6 for any scenario

Yes, agree and noted.

@Alcanderian Alcanderian added the ready-to-merge The PR is ready to merge after the CI is green. label Aug 7, 2025
@Alcanderian Alcanderian removed the ready-to-merge The PR is ready to merge after the CI is green. label Aug 7, 2025
@Alcanderian Alcanderian added the ready-to-merge The PR is ready to merge after the CI is green. label Aug 7, 2025
@hnyls2002 hnyls2002 merged commit 6e6009f into sgl-project:main Nov 6, 2025
61 of 72 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

cpu cpu backend performance optimization intel ready-to-merge The PR is ready to merge after the CI is green. run-ci

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants