
[Deepseek R1][v0] Porting deepseek r1 to habana_main#1161

Merged
xuechendi merged 11 commits into habana_main from dev/vllmfork_deepseek_r1
May 8, 2025
Conversation


@xuechendi xuechendi commented Apr 24, 2025

JIRA: https://jira.habana-labs.com/browse/SW-227174

Cherry-picks #1030 and fixes conflicts after the rebase.
Dependencies: HabanaAI/vllm-hpu-extension#161
HabanaAI/vllm-hpu-extension#170

Verified with the three methods below:

  1. Test with DeepSeek-V2 BF16 weights => passed
  2. Evaluate accuracy on DeepSeek-R1 with out-of-the-box block FP8 weights => passed
  3. Evaluate accuracy on DeepSeek-R1 with out-of-the-box block FP8 weights + INC-calibrated per-channel scales => passed the accuracy check; performance reaches the goal (numbers are in the JIRA ticket)

== Details ==

  1. Test with DeepSeek-V2 BF16 weights:
PT_HPU_LAZY_MODE=1 python run_example_tp.py --model DeepSeek-V2-Lite --tokenizer DeepSeek-V2-Lite --osl 32 
(VllmWorkerProcess pid=1039) WARNING 04-25 03:01:53 [hpu_model_runner.py:1039] Configuration: ('decode', 4, 128) was not warmed-up!
(VllmWorkerProcess pid=1038) WARNING 04-25 03:01:53 [hpu_model_runner.py:1039] Configuration: ('decode', 4, 128) was not warmed-up!
(VllmWorkerProcess pid=1041) WARNING 04-25 03:01:53 [hpu_model_runner.py:1039] Configuration: ('decode', 4, 128) was not warmed-up!
WARNING 04-25 03:01:53 [hpu_model_runner.py:1039] Configuration: ('decode', 4, 128) was not warmed-up!
Processed prompts: 100%|████████████████████████████████████████████████████████████████████████████| 4/4 [00:02<00:00,  1.57it/s, est. speed input: 12.59 toks/s, output: 50.37 toks/s]
e2e took 2.5509743690199684 seconds
====================================
Prompt: 'Hello, my name is'
Generated text: '\nI am a 20 year old student from the UK. I am currently studying for a degree in English Literature and Creative Writing at the University of East'
Ground truth: None
====================================
====================================
Prompt: '0.999 compares to 0.9 is '
Generated text: '100%\n0.9999999999999999999999999'
Ground truth: None
====================================
====================================
Prompt: 'The capital of France is'
Generated text: ' Paris, which is also the largest city in the country. The city is located on the Seine River and is known for its beautiful architecture, museums, and art'
Ground truth: None
====================================
====================================
Prompt: 'The future of AI is'
Generated text: ' in the hands of the people\nThe future of AI is in the hands of the people\nThe future of AI is in the hands of the people\nThe'
Ground truth: None
====================================
  2. Evaluate accuracy on DeepSeek-R1 with out-of-the-box block FP8 weights (limit 256):

|Tasks|Version| Filter |n-shot| Metric |Value |Stderr|
|-----|------:|----------------|-----:|-----------|-----:|-----:|
|gsm8k| 3|flexible-extract| 5|exact_match|0.9648|± 0.0115|
| | |strict-match | 5|exact_match|0.9648|± 0.0115|
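As a quick sanity check (my own, not part of the PR), the reported Stderr is consistent with the standard binomial standard-error approximation for a 256-sample run:

```python
import math

# Sanity-check the reported stderr for gsm8k exact_match with --limit 256,
# using the standard binomial standard-error approximation sqrt(p*(1-p)/n).
p, n = 0.9648, 256
stderr = math.sqrt(p * (1 - p) / n)
print(round(stderr, 4))  # 0.0115, matching the reported ± value
```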
  3. Evaluate accuracy on DeepSeek-R1 with out-of-the-box block FP8 weights + INC-calibrated per-channel scales:

Run calibration:

{
    "method": "HOOKS",
    "mode": "MEASURE",
    "observer": "maxabs",
    "whitelist": {
        "types": [],
        "names": []
    },
    "blocklist": {
        "types": [],
        "names": ["lm_head", "mlp\\.gate\\b"]
    },
    "quantize_weight": false,
    "dump_stats_path": "./nc_workspace_measure_kvache/inc_measure_output"
}
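The `names` entries in the blocklist are regular expressions; the trailing `\b` word boundary blocks the MoE router `mlp.gate` while leaving sibling modules such as a hypothetical `mlp.gate_proj` quantizable. A quick illustration (the module names here are made up for demonstration):

```python
import re

# The JSON blocklist entry "mlp\\.gate\\b" becomes this regex.
pattern = re.compile(r"mlp\.gate\b")

# The MoE router gate matches and is blocked from quantization...
print(bool(pattern.search("model.layers.0.mlp.gate")))       # True
# ...but gate_proj does not: "_" is a word character, so \b fails there.
print(bool(pattern.search("model.layers.0.mlp.gate_proj")))  # False
```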
OFFICIAL_FP8_MODEL=DeepSeek-R1

PT_HPU_LAZY_MODE=1 \
VLLM_SKIP_WARMUP=true \
PT_HPU_ENABLE_LAZY_COLLECTIVES=true \
PT_HPU_WEIGHT_SHARING=0 \
QUANT_CONFIG=inc_measure_with_fp8kv_config.json \
python run_example_tp.py --model ${OFFICIAL_FP8_MODEL} --tokenizer ${OFFICIAL_FP8_MODEL} --osl 32 --max_num_seqs 1 --nprompts 512 --dataset pile

Run the test:

QUANT_CONFIG=inc_quant_with_fp8kv_config.json \
PT_HPU_LAZY_MODE=1 \
VLLM_DISABLE_MARK_SCALES_AS_CONST=1 \
VLLM_SKIP_WARMUP=true \
PT_HPU_ENABLE_LAZY_COLLECTIVES=true \
PT_HPU_WEIGHT_SHARING=0 \
lm_eval --model vllm \
  --model_args "pretrained=/mnt/weka/data/pytorch/DeepSeek-R1/,tensor_parallel_size=8,distributed_executor_backend=mp,trust_remote_code=true,max_model_len=4096,use_v2_block_manager=True,dtype=bfloat16,kv_cache_dtype=fp8_inc,enable_expert_parallel=True,max_num_seqs=256" \
  --tasks gsm8k --num_fewshot "5" \
  --batch_size "256" --limit 16 --log_samples --output_path fp8_gsm8k.json
|Tasks|Version| Filter |n-shot| Metric |Value |Stderr|
|-----|------:|----------------|-----:|-----------|-----:|-----:|
|gsm8k| 3|flexible-extract| 5|exact_match|0.9688|± 0.0109|
| | |strict-match | 5|exact_match|0.9688|± 0.0109|


@michalkuligowski

/run-gaudi-tests

@xuechendi
Author

xuechendi commented Apr 25, 2025

I verified the failed CI locally, and the tests are working OK.
I also realized that since HabanaAI/vllm-hpu-extension#161 hasn't been merged yet, this PR is lacking the vllm-hpu-extension functions as well.

Please help to review HabanaAI/vllm-hpu-extension#161 first.

@michalkuligowski

@xuechendi you can update requirements/hpu.txt for the related extension change to be utilized. We can't merge it before it is tested in this PR.

@xuechendi
Author

/run-gaudi-tests

@xuechendi xuechendi enabled auto-merge (squash) May 5, 2025 22:17
@xuechendi
Author

CI failed on getting a resource (screenshot omitted).

Comment thread vllm/worker/hpu_model_runner.py Outdated
Comment thread vllm/worker/hpu_model_runner.py Outdated
@michalkuligowski

/skip-gaudi-tests - the two failing multimodal jobs are a false negative; the underlying tests are passing


@kwisniewski98 kwisniewski98 left a comment


There are a lot of changes that I'm not sure are upstreamable; we'd have to ask the maintainers.
Also, I see workarounds that were introduced by me, which I'm also not sure we can introduce to main, i.e. MLA not supporting V1 and issues with torch.compile

Comment thread vllm/attention/backends/hpu_attn.py
Comment thread vllm/model_executor/layers/fused_moe/fused_moe.py Outdated
@xuechendi
Author

xuechendi commented May 6, 2025

> There is a lot of changes that I'm not sure if are upstreamable, we'd have to ask maintainers. Also, I see workarounds that were introduced by me, about which I'm also not sure if we can introduce to main, i. e. MLA not supporting V1 and issues with torch compile

This PR is aiming to cherry-pick the 1.21 DeepSeek support to habana_main. I don't want to push more changes into this PR that would make the implementation diverge from the one we merged for 1.21.
Let's aim for V1 support and torch.compile in the next PR.
I hope to get this PR merged so that Llama4 and Qwen3 can share the same MoE changes.

@xuechendi
Author

/run-gaudi-tests

xuechendi and others added 6 commits May 7, 2025 17:43
migrated from a PR to habana_main:
#1014

For best performance, running this PR with INC is recommended:
[[SW-223553] [VLLM] Merge deepseek changes into habana_main - Habana
Labs](https://jira.habana-labs.com/browse/SW-223553)

**test acc of G3**:
```bash
huggingface-cli download Yi30/inc-woq-default-pile-one-cache-408  --local-dir ./scripts/nc_workspace_measure_kvache

cat inc_quant_with_fp8kv_config.json
{
    "mode": "QUANTIZE",
    "observer": "maxabs",
    "scale_method": "maxabs_hw",
    "scale_format": "const",
    "allowlist": {
        "types": [],
        "names": []
    },
    "blocklist": {
        "types": [],
        "names": [
            "lm_head",
            "mlp\\.gate\\b",
            "block2batch_matmul"
        ]
    },
    "dump_stats_path": "./inc-woq-default-pile-one-cache-408-for-fp8-mla/inc_measure_output"
}

QUANT_CONFIG=inc_quant_with_fp8kv_config.json \
PT_HPU_LAZY_MODE=1 \
VLLM_SKIP_WARMUP=true \
PT_HPU_ENABLE_LAZY_COLLECTIVES=true \
PT_HPU_WEIGHT_SHARING=0 \
VLLM_MLA_DISABLE_REQUANTIZATION=1 \
lm_eval --model vllm \
  --model_args "pretrained=/mnt/weka/data/pytorch/DeepSeek-R1/,tensor_parallel_size=8,distributed_executor_backend=mp,trust_remote_code=true,max_model_len=4096,use_v2_block_manager=True,dtype=bfloat16,kv_cache_dtype=fp8_inc" \
  --tasks gsm8k --num_fewshot "5" --limit "256" \
  --batch_size "8"
```

**test acc of G2**:
**convert original DeepSeek-R1** using
[convert_for_g2.py](https://github.com/yangulei/vllm-fork/blob/deepseek_r1_g2/scripts/convert_for_g2.py)
(this step will be removed as INC updates.)

```bash

huggingface-cli download Yi30/inc-woq-default-pile-one-cache-412-g2  --local-dir ./scripts/nc_workspace_measure_kvache

cat inc_quant_with_fp8kv_config.json
{
    "mode": "QUANTIZE",
    "observer": "maxabs",
    "scale_method": "maxabs_hw",
    "scale_format": "const",
    "allowlist": {
        "types": [],
        "names": []
    },
    "blocklist": {
        "types": [],
        "names": [
            "lm_head",
            "mlp\\.gate\\b",
            "block2batch_matmul"
        ]
    },
    "dump_stats_path": "./nc_workspace_measure_kvache/inc_measure_output"
}
```

vllm
(pretrained=/mnt/weka/data/pytorch/DeepSeek-R1/,tensor_parallel_size=8,distributed_executor_backend=mp,trust_remote_code=true,max_model_len=4096,use_v2_block_manager=True,dtype=bfloat16,kv_cache_dtype=fp8_inc),
gen_kwargs: (None), limit: 256.0, num_fewshot: 5, batch_size: 128
|Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.9492|± |0.0137|
| | |strict-match | 5|exact_match|↑ |0.9453|± |0.0142|

----------
Need to use vllm-hpu-extension:
https://github.com/HabanaAI/vllm-hpu-extension/tree/dev/chendi/deepseek_r1

Status:

- Runnable with DeepSeek-R1.
- Accuracy check for block FP8 weights => garbage output.
- Accuracy check for BF16 weights => looks good.

Test script:
```python
from vllm import LLM, SamplingParams
import os

os.environ['VLLM_SKIP_WARMUP'] = 'true'
os.environ['PT_HPU_LAZY_MODE'] = '1'
os.environ['PT_HPU_ENABLE_LAZY_COLLECTIVES']='true'
os.environ['PT_HPU_WEIGHT_SHARING']='0'
os.environ['VLLM_MLA_DISABLE_REQUANTIZATION']='1'

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

if __name__ == "__main__":
    # Create a sampling params object.
    sampling_params = SamplingParams(temperature=0.0, max_tokens=16, ignore_eos=True)

    # Create an LLM.
    model_path = "/data/models/DeepSeek-R1"

    llm = LLM(model=model_path,
            trust_remote_code=True,
            enforce_eager=True,
            dtype="bfloat16",
            use_v2_block_manager=True,
            max_model_len=1024,
            max_num_seqs=1,
            tensor_parallel_size=8,
            distributed_executor_backend='mp',
            gpu_memory_utilization=0.8,
            #kv_cache_dtype="fp8_inc",
            seed=2024)

    # Generate texts from the prompts. The output is a list of RequestOutput objects
    # that contain the prompt, generated text, and other information.
    outputs = llm.generate(prompts, sampling_params)
    # Print the outputs.
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    if os.environ.get("QUANT_CONFIG", None) is not None:
        llm.llm_engine.model_executor.shutdown()
```

---------

Signed-off-by: Chendi.Xue <chendi.xue@intel.com>
Signed-off-by: kwisniewski98 <kwisniewski@habana.ai>
Signed-off-by: Chendi Xue <chendi.xue@intel.com>
Co-authored-by: kwisniewski98 <kwisniewski@habana.ai>
Use fp32 `gating_output` instead of adding `mark_step()` to fix the accuracy issues in 117555d.
This reduces the graph replay duration from ~41 ms to ~32 ms for the decoding phase of the 16k/1k bs=16 benchmark on Gaudi2.
Signed-off-by: Chendi Xue <chendi.xue@intel.com>
Signed-off-by: Chendi Xue <chendi.xue@intel.com>
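The fp32 gating fix described above can be sketched roughly as follows. This is a minimal, self-contained illustration of the idea (upcast router logits before softmax/top-k), not the actual vLLM code; NumPy has no bfloat16, so float16 stands in for it here:

```python
import numpy as np

def route(gating_output: np.ndarray, top_k: int = 2):
    # Upcast router logits to float32 before softmax, mirroring the fix:
    # routing weights are computed in fp32 rather than low precision.
    logits = gating_output.astype(np.float32)
    probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)
    # Pick the top-k experts per token and their routing weights.
    expert_ids = np.argsort(-probs, axis=-1)[:, :top_k]
    weights = np.take_along_axis(probs, expert_ids, axis=-1)
    return weights, expert_ids

logits = np.random.randn(4, 8).astype(np.float16)  # [tokens, experts]
w, ids = route(logits)
print(w.dtype, w.shape)  # float32 (4, 2)
```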
… multiple cards (#1100)

- Add `VLLM_DISABLE_MARK_SCALES_AS_CONST=true` to speed up the warmup stage.
- Fix the `dist.barrier` issue for a single card

cc @xuechendi @thuang6

---------

Signed-off-by: Yi Liu <yiliu4@habana.ai>
Co-authored-by: Yi Liu <yiliu4@habana.ai>
Previously it was only checking whether quant_config is used and choosing VllmMixtureOfExpertsOpFP8 as the OP, whose only difference is that it assumes block quant when measuring scales. This should only happen when we are using Fp8MoEMethod as the quant_method.
Kwargs in the moe_op call had to be disabled because of the different APIs of the FP8 and unquantized paths.

---------

Signed-off-by: kwisniewski98 <kwisniewski@habana.ai>

Signed-off-by: Chendi Xue <chendi.xue@intel.com>
xuechendi added 2 commits May 7, 2025 17:44
Signed-off-by: Chendi Xue <chendi.xue@intel.com>
@xuechendi xuechendi force-pushed the dev/vllmfork_deepseek_r1 branch from a808b51 to 27e267c Compare May 7, 2025 14:45
@xuechendi xuechendi requested a review from mswiniarsk as a code owner May 7, 2025 14:45
@xuechendi
Author

/run-gaudi-tests


@jikunshang jikunshang left a comment


q_proj was removed; this needs to change.

xuechendi added 3 commits May 7, 2025 21:52
Signed-off-by: Chendi Xue <chendi.xue@intel.com>
Signed-off-by: Chendi Xue <chendi.xue@intel.com>
Signed-off-by: Chendi Xue <chendi.xue@intel.com>
@xuechendi
Author

/run-gaudi-tests

@xuechendi xuechendi merged commit ae79743 into habana_main May 8, 2025
43 checks passed
@xuechendi xuechendi deleted the dev/vllmfork_deepseek_r1 branch May 8, 2025 00:07