Z1/2 init: flatten params on device #7828

Merged
stas00 merged 24 commits into deepspeedai:master from ksugama:flatten-tensor-gpu
Feb 13, 2026

Conversation

@ksugama (Contributor) commented Feb 3, 2026

This PR addresses #7677 by flattening parameter tensors on the accelerator instead of the CPU during ZeRO stage 1 and 2 initialization. This should alleviate CPU contention, with the caveat that the optimization is only used when there is enough VRAM to allocate a full copy of the parameter buffers.

On 8 x H100s and an Intel Xeon Platinum 8480+, profiling the initialization of DeepSpeed on 32 layers of Qwen3-30B with Z2 gives the following:

Old = ~382s
New = ~130s


If necessary, this optimization can be extended to allow a tiered system that trades off VRAM space against performance, which might look like the following:

```
if enough VRAM for 2x model_size:
    naive flatten
else if enough VRAM for model_size / N:
    distributed flatten across N devices
else:
    flatten on CPU
```

The distributed flatten would involve each device flattening a portion of the parameters and performing an all-gather to assemble the full flattened model. See #7677 for original discussion.
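As an illustration of the distributed variant, here is a minimal sketch of how the flat buffer could be split into contiguous per-rank ranges before the all-gather. The helper name `partition_flat_ranges` is hypothetical, not the PR's actual code:

```python
def partition_flat_ranges(numels, world_size):
    """Split a flat buffer of sum(numels) elements into world_size
    contiguous [start, end) ranges. Each rank would flatten the
    parameters covering its range, then an all-gather assembles the
    full flattened model."""
    total = sum(numels)
    chunk = (total + world_size - 1) // world_size  # ceil division
    return [(min(r * chunk, total), min((r + 1) * chunk, total))
            for r in range(world_size)]
```

In practice each rank would copy its slice of the parameter data into a device buffer and call `torch.distributed.all_gather` on the per-rank chunks.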

@ksugama ksugama force-pushed the flatten-tensor-gpu branch from a07a21b to 293fbab Compare February 3, 2026 17:19
@ksugama ksugama changed the title Z1/2 Flatten Parameters on device Z1/2 init: flatten params on device Feb 3, 2026
@sfc-gh-truwase (Collaborator) commented:

@ksugama thanks for working on this. Looking forward to when it is ready for review.

tohtana and others added 20 commits February 9, 2026 16:24
Fix deepspeedai#7812: This PR makes DeepSpeedEngine cleanup safe for partial
initialization.

This prevents destructor-time tracebacks by guarding access to
uninitialized attributes of the DeepSpeed engine.
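The guarding pattern can be sketched with a toy class (the class name and attributes below are hypothetical, not the engine's actual fields):

```python
import gc

class PartialInitSafe:
    """Illustration of destructor-safe cleanup when __init__ may fail partway."""

    def __init__(self, fail=False):
        self.closed = False
        if fail:
            raise RuntimeError("init failed before all attributes were set")
        self.resource = "handle"

    def __del__(self):
        # getattr guards attributes that __init__ may never have reached,
        # preventing AttributeError tracebacks at destruction time
        if getattr(self, "resource", None) is not None:
            self.closed = True
```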

Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Signed-off-by: Kento Sugama <kentosugama@protonmail.ch>
Evoformer tests fail with this error. We ignore this in the full test
for now.

```
RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing, you must use the 'spawn' start method
```

Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Signed-off-by: Kento Sugama <kentosugama@protonmail.ch>
Two spelling errors in docs/_tutorials/accelerator-setup-guide.md:

Line 50: comma-seperated-dash-range → comma-separated-dash-range
Line 97: optimzied → optimized
Both typos are in the Intel Architecture CPU section of the accelerator
setup guide.

Signed-off-by: leejianwoo-collab <leejianwoo@gmail.com>
Signed-off-by: Kento Sugama <kentosugama@protonmail.ch>
…er_rank (deepspeedai#7817)

Fix deepspeedai#7811
Issue: ZeRO-3 crashes when `zero.GatheredParameters` is used with
`modifier_rank=None` and a parameter is modified in-place. The failure
surfaced as an internal invariant assert in free_param, which is not
actionable for users.

Solution: `GatheredParameters` now detects in-place mutation with
`modifier_rank=None` and raises a clear, user-facing error early. The
mutation check is synchronized across ranks to avoid divergence and
hangs.
This PR also raises a clearer error from free_param when parameters are
still active in submodules.
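The mutation check can be illustrated with a simplified context manager over a raw buffer. `GatherGuard` is a hypothetical name; the real implementation works on gathered parameter tensors and synchronizes the check across ranks:

```python
import hashlib

class GatherGuard:
    """Raise a clear, user-facing error if the guarded buffer is mutated
    in place while no modifier_rank is designated to persist the change."""

    def __init__(self, buf):
        self.buf = buf

    def __enter__(self):
        self._before = hashlib.sha1(bytes(self.buf)).hexdigest()
        return self.buf

    def __exit__(self, exc_type, exc, tb):
        after = hashlib.sha1(bytes(self.buf)).hexdigest()
        if exc_type is None and after != self._before:
            raise RuntimeError(
                "parameter modified in-place with modifier_rank=None; "
                "pass a modifier_rank so the change can be persisted")
        return False
```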

---------

Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Signed-off-by: Kento Sugama <kentosugama@protonmail.ch>
…rld_size in Ulysses (deepspeedai#7809)

### Description
This PR addresses Issue deepspeedai#7672.

When sequence_parallel_size is smaller than world_size (e.g., sp_size=2
on 4 GPUs) with PyTorch < 2.3, using
torch.distributed.nn.functional.all_gather for loss aggregation triggers
an IndexError: tuple index out of range during the backward pass. This
is due to a known PyTorch issue where the backward hook accesses the
global rank instead of the group rank.

### Solution
1. Regression Test & Workaround: Updated the regression test
TestUlyssesLossBackward to implement a Weighted All-Reduce pattern.
- Before: all_gather -> manual sum (Vulnerable to rank indexing mismatch
on older PyTorch).
- After: all_reduce(weighted_loss) / all_reduce(total_weight) (Robust
and supports weighted averaging).
2. Runtime Warning: Added a version check (required_torch_version) in
DeepSpeedEngine. It now logs a warning if Sequence Parallelism is
enabled on PyTorch < 2.3, providing a link to the workaround test case.
3. Documentation: Updated ulysses-alst-sequence-parallelism.md with a
note regarding legacy PyTorch versions and the recommended workaround.
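The weighted all-reduce pattern reduces to two global sums, which is why it sidesteps per-rank indexing entirely. A rank-free sketch of the arithmetic (the actual test issues two `torch.distributed.all_reduce` calls; `weighted_average_loss` is an illustrative name):

```python
def weighted_average_loss(per_rank_loss, per_rank_weight):
    """all_reduce(weighted_loss) / all_reduce(total_weight), written as
    plain sums over the values each rank would contribute."""
    weighted = sum(l * w for l, w in zip(per_rank_loss, per_rank_weight))
    total = sum(per_rank_weight)
    return weighted / total
```

Because every rank computes the same two sums, no rank ever indexes into a gathered tuple, which is exactly the operation that broke on PyTorch < 2.3.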

### Verification
Added and verified the regression test
tests/unit/sequence_parallelism/test_ulysses.py which now validates the
weighted averaging logic.

**1. Reproduction (Before Fix)**
Confirmed IndexError crash on Rank 2/3 with sp_size=2 on a 4-GPU setup.
<img width="1370" height="860" alt="Screenshot 2026-01-23 at 23 53 42"
src="https://github.com/user-attachments/assets/f4005c02-ff6c-46ea-a1a7-caac2093128b"
/>

**2. Verification (After Fix)**
Verified the fix using the regression test logic on 4x RTX A6000. The
backward pass now completes successfully on all ranks without error.
<img width="1192" height="605" alt="Screenshot 2026-01-23 at 23 52 54"
src="https://github.com/user-attachments/assets/c14cd093-67b7-42b0-ae15-65555c129082"
/>

---------

Signed-off-by: vensen <vensenmu@gmail.com>
Co-authored-by: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com>
Signed-off-by: Kento Sugama <kentosugama@protonmail.ch>
Update PyTorch to v2.9 for modal tests

---------

Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Signed-off-by: Kento Sugama <kentosugama@protonmail.ch>
Signed-off-by: Kento Sugama <kentosugama@protonmail.ch>
Fix deepspeedai#7824

For [leaf
modules](https://deepspeed.readthedocs.io/en/latest/training.html#configuring-zero-leaf-modules),
ZeRO3 manages all parameters within the module uniformly. When a module
returns multiple output tensors, PyTorch's autograd can trigger backward
hooks from multiple threads concurrently. This causes race conditions
when multiple threads simultaneously modify __inflight_param_registry
and parameter states.

This PR ensures that for leaf modules, only one thread performs the
actual parameter fetching work while other concurrent threads wait and
return early, preventing the race condition.
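The single-fetcher idea can be sketched with a hypothetical guard class (the actual PR operates on `__inflight_param_registry` and parameter states):

```python
import threading

class SingleFetcher:
    """Only the first thread to arrive performs the fetch; concurrent
    callers observe the in-flight flag and return early."""

    def __init__(self):
        self._lock = threading.Lock()
        self._in_flight = False

    def maybe_fetch(self, fetch_fn):
        with self._lock:
            if self._in_flight:
                return False  # another thread is already fetching
            self._in_flight = True
        try:
            fetch_fn()  # perform the actual parameter fetch
            return True
        finally:
            with self._lock:
                self._in_flight = False
```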

---------

Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Signed-off-by: Kento Sugama <kentosugama@protonmail.ch>
hi deepspeed!

[Make it possible to evaluate when using sequence parallel in HF Trainer
#43517](huggingface/transformers#43517)

I initially opened a PR in transformers to prevent errors that occur
when running eval with deepspeed sequence parallel applied, but there
was feedback that this should be handled in deepspeed rather than in
transformers or accelerate, so I'm opening this PR here.

If you have any questions, feel free to ask.

Signed-off-by: Kento Sugama <kentosugama@protonmail.ch>
This PR introduces a flexible, configuration-driven API for AutoTP
(Automatic Tensor Parallelism) that allows users to define custom layer
partitioning patterns for training.
@inkcherry @delock

## Motivation

Previously, AutoTP relied on hardcoded layer detection logic that was
difficult to customize for new model architectures. This PR enables:

1. **Custom models**: Users can define exact regex patterns to match
their model's parameter names
2. **Fused layers**: Support for fused QKV, gate_up_proj, and other
packed weight matrices with unequal sub-parameter sizes (e.g., GQA with
different Q/K/V dimensions)
3. **Extensibility**: Easy to add new model presets or customize
existing ones

Here is an example of a config including custom partitioning patterns:

```json
{
    "tensor_parallel": {
        "autotp_size": 4,
        "partition_config": {
            "use_default_specs": false,
            "layer_specs": [
                {
                    "patterns": [".*\\.o_proj\\.weight$", ".*\\.down_proj\\.weight$"],
                    "partition_type": "row"
                },
                {
                    "patterns": [".*\\.[qkv]_proj\\.weight$"],
                    "partition_type": "column"
                },
                {
                    "patterns": [".*\\.gate_up_proj\\.weight$"],
                    "partition_type": "column",
                    "shape": [2, -1],
                    "partition_dim": 0
                }
            ]
        }
    }
}
```
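To see how the patterns above resolve against parameter names, here is a quick check with Python's `re`, using hypothetical parameter names for a transformer-style model:

```python
import re

# Hypothetical parameter names; real names depend on the model architecture
names = [
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.0.self_attn.o_proj.weight",
    "model.layers.0.mlp.gate_up_proj.weight",
]
row = re.compile(r".*\.o_proj\.weight$")
col_qkv = re.compile(r".*\.[qkv]_proj\.weight$")
col_fused = re.compile(r".*\.gate_up_proj\.weight$")

# First matching spec wins, mirroring an ordered layer_specs list
matched = {n: ("row" if row.match(n) else
               "column" if col_qkv.match(n) or col_fused.match(n) else
               "none")
           for n in names}
```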

Refer to the
[document](https://github.com/tohtana/DeepSpeed/blob/tohtana/autotp_custom_patterns/docs/code-docs/source/training.rst)
for more details (including preset models and how to define partitioning
for fused models).
We also opened a new
[PR](deepspeedai/DeepSpeedExamples#998) to show
the usage.

## Simplified initialization step

AutoTP previously required calling ``set_autotp_mode(training=True)``
and ``deepspeed.tp_model_init`` before ``deepspeed.initialize``. Now we
can include all the necessary configurations in the DeepSpeed config.

We still support the traditional initialization path for backward
compatibility.
When you use both (i.e. calling ``set_autotp_mode(training=True)`` and
``deepspeed.tp_model_init`` and passing the config to
``deepspeed.initialize``), we will merge the settings at initialization.
When we have conflicting settings, we will error out.

---------

Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Signed-off-by: Kento Sugama <kentosugama@protonmail.ch>
Signed-off-by: Kento Sugama <kentosugama@protonmail.ch>
Signed-off-by: Kento Sugama <kentosugama@protonmail.ch>
Signed-off-by: Kento Sugama <kentosugama@protonmail.ch>
Signed-off-by: Kento Sugama <kentosugama@protonmail.ch>
Signed-off-by: Kento Sugama <kentosugama@protonmail.ch>
Signed-off-by: Kento Sugama <kentosugama@protonmail.ch>
Signed-off-by: Kento Sugama <kentosugama@protonmail.ch>
Signed-off-by: Kento Sugama <kentosugama@protonmail.ch>
Correctly handle `ds_grad_is_ready` in ZeRO2

---------

Signed-off-by: Olatunji Ruwase <tunji.ruwase@snowflake.com>
Signed-off-by: Kento Sugama <kentosugama@protonmail.ch>
The current code has the following issues:
- `use_default_specs: false` doesn't work
- Injection by the traditional pattern runs even when custom patterns
are set
- `mpu` needs to be passed to `deepspeed.initialize` (HF integration
doesn't pass mpu)

This PR fixes AutoTP setup to respect `use_default_specs: false` and
disable the traditional injection path when custom patterns are enabled.
Also, when `mpu` is not passed, we create a TP group in the
initialization process.

With these changes, the [related
tests](https://github.com/deepspeedai/DeepSpeed/tree/master/tests/unit/model_parallelism)
pass and [all AutoTP
examples](https://github.com/tohtana/DeepSpeedExamples/tree/tohtana/custom_auto_tp/training/tensor_parallel)
in DeepSpeedExamples work now
([PR](deepspeedai/DeepSpeedExamples#998)).

---------

Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
Signed-off-by: Kento Sugama <kentosugama@protonmail.ch>
@ksugama ksugama force-pushed the flatten-tensor-gpu branch from bf577e8 to 3610631 Compare February 9, 2026 21:24
@ksugama ksugama marked this pull request as ready for review February 9, 2026 21:25
```python
empty_cache()
see_memory_usage(f"After moving param group {i} to CPU", force=False)
accelerator = get_accelerator()
available_vram = accelerator.available_memory() if accelerator.is_available() else 0
```
ksugama (Contributor, Author) commented:

@stas00 I believe available_memory() eventually calls into nvidia-smi. Is this the foot gun you were warning about?

If it is, maybe this should be fixed in a different PR since that problem touches more than is related to these changes

stas00 (Collaborator) commented:

Please remind me what issue I flagged that you are referring to?

available_memory is here:

```python
def available_memory(self, device_index=None):
    if pynvml:
        if device_index is None:
            device_index = self.current_device()
        handle = pynvml.nvmlDeviceGetHandleByIndex(self._get_nvml_gpu_id(device_index))
        info = pynvml.nvmlDeviceGetMemoryInfo(handle)
        return info.free
    else:
        return self.total_memory(device_index) - self.memory_allocated(device_index)
```

ksugama (Contributor, Author) commented Feb 10, 2026:

You mentioned in the discussion in #7677 that in order to check device memory, nvml should not be used and PyTorch API should be used instead

stas00 (Collaborator) commented Feb 10, 2026:

thank you for the refresher, Kento. I see my comment wasn't very precise. I was trying to say that it doesn't work on all devices.

We do want to use nvml in the general case when we want to see the total stats - since any gpu could be shared by multiple processes and torch's counters can only see its own process.

BTW, torch.cuda.memory.mem_get_info is an alternative to pynvml stats, but I think it'll fail just the same on gb10.

So I think the solution can be this:

```python
def available_memory(self, device_index=None):
    this_process_view_available_mem = self.total_memory(device_index) - self.memory_allocated(device_index)
    if pynvml:
        if device_index is None:
            device_index = self.current_device()
        handle = pynvml.nvmlDeviceGetHandleByIndex(self._get_nvml_gpu_id(device_index))
        try:
            # gb10 will fail this call
            info = pynvml.nvmlDeviceGetMemoryInfo(handle)
            return info.free
        except Exception:
            return this_process_view_available_mem
    else:
        return this_process_view_available_mem
```

cc: @sfc-gh-truwase

@ksugama (Contributor, Author) commented Feb 9, 2026

@stas00 Hoping this alleviates the issue you were running into

@ksugama (Contributor, Author) commented Feb 12, 2026

@sfc-gh-truwase Would you be able to approve the CI workflows whenever you get the chance? I ran them on a lambda instance and they pass on my end

@stas00 (Collaborator) commented Feb 13, 2026

@stas00 Hoping this alleviates the issue you were running into

@ksugama, if it solves the problem described in the OP of #7677 then it's fantastic.

I will try to find time to run this check myself.

@stas00 (Collaborator) commented Feb 13, 2026

edit: my repro isn't measuring the right thing, since I see it's taking forever for Transformers to load the shards when doing multi-gpu. your PR is definitely making things load much faster.

I think there is some other inefficiency in HF Transformers's DS handling which makes the checkpoint shards load much slower on 8 gpus as compared to 1 gpu.

stas00 (Collaborator) left a review comment:

Amazing work, thank you, @ksugama.

So many people are going to have a much better experience with DS w/ huge models.

@stas00 stas00 merged commit 84af822 into deepspeedai:master Feb 13, 2026
9 checks passed
@ksugama (Contributor, Author) commented Feb 13, 2026

Incredible! Glad I can help. Thank you for reviewing

@ksugama ksugama deleted the flatten-tensor-gpu branch February 13, 2026 16:52
@stas00 (Collaborator) commented Feb 13, 2026

https://x.com/StasBekman/status/2022354880049082658
