
[BUG] DeepSpeedEngine __del__/destroy throws AttributeError after initialize() failure (missing _deepcompile_active) + BrokenPipeError noise #7812

@griffinstalha

Description


Bug Description

When deepspeed.initialize() fails early (in this reproducer, a config validation error related to loss scaling), DeepSpeed later emits a destructor-time traceback:

  • DeepSpeedEngine.__del__ calls self.destroy()
  • destroy() calls is_deepcompile_active()
  • is_deepcompile_active() unconditionally reads self._deepcompile_active
  • But on this early-failure path, _deepcompile_active was never set, leading to:

AttributeError: 'DeepSpeedEngine' object has no attribute '_deepcompile_active'

Additionally, after the process exits, I also observe BrokenPipeError: [Errno 32] Broken pipe (“Exception ignored in: <_io.BufferedWriter …>”), which adds more post-run noise.
This is not just cosmetic: any CI/harness that treats stderr tracebacks as failures will flag the run even though the expected init failure was handled.
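The failure mode is a general Python pattern, which the following standalone sketch reproduces (the Engine class here is illustrative, not DeepSpeed source):

```python
# Minimal sketch of the failure class: if __init__ raises before an
# attribute is assigned, a __del__ that reads the attribute unconditionally
# raises AttributeError at collection time, which CPython reports as
# "Exception ignored in: <function ... __del__>".
class Engine:
    def __init__(self, fail_validation: bool):
        if fail_validation:
            # analogous to the config validation error in this report
            raise ValueError("requires dynamic loss scaling")
        self._deepcompile_active = False  # never reached on the failure path

    def is_deepcompile_active(self):
        return self._deepcompile_active  # unguarded attribute read

    def __del__(self):
        if self.is_deepcompile_active():  # AttributeError on partial objects
            pass


try:
    Engine(fail_validation=True)
except ValueError:
    pass  # caller handles the expected failure, yet the destructor still trips
```

Even though the caller catches the ValueError cleanly, the partially constructed object is still garbage-collected, its `__del__` runs, and the AttributeError surfaces on stderr outside any user-reachable try/except.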

To Reproduce

Steps to reproduce the behavior:

  1. Create a Python venv and install compatible PyTorch + DeepSpeed.
  2. Save the script below as testcases/deepspeed_testcase.py.
  3. Run:
source ~/.venvs/dl_testing/bin/activate
export TF_CPP_MIN_LOG_LEVEL=2
export PYTHONNOUSERSITE=1
deepspeed --num_gpus=1 testcases/deepspeed_testcase.py

  4. Observe: the script prints Test Passed ✅ (meaning it caught the expected init failure), but DeepSpeed then prints the destructor traceback and BrokenPipeError.

Expected behavior
If deepspeed.initialize() fails due to configuration validation (here, the expected "requires dynamic loss scaling" error) and the exception is handled by the caller, DeepSpeed should not emit any additional destructor-time traceback. Cleanup should be exception-safe for partially constructed DeepSpeedEngine objects.

Actual behavior
After the expected init failure is handled, DeepSpeed prints:

  • Exception ignored in: <function DeepSpeedEngine.__del__ ...>
  • AttributeError: 'DeepSpeedEngine' object has no attribute '_deepcompile_active'
  • BrokenPipeError: [Errno 32] Broken pipe

Reproducer script

import sys
import os
import json
import tempfile
import random

EXPECTED_SUBSTR = "requires dynamic loss scaling"

_ENGINE = None
_CFG_PATH = None


def _skip(reason: str):
    print(f"SKIP_ENV: {reason}")
    sys.exit(0)


def _cleanup_best_effort():
    global _ENGINE, _CFG_PATH
    try:
        if _ENGINE is not None:
            try:
                _ENGINE.destroy()
            except Exception:
                pass
    except Exception:
        pass

    try:
        import torch.distributed as dist
        if dist.is_available() and dist.is_initialized():
            dist.destroy_process_group()
    except Exception:
        pass

    try:
        if _CFG_PATH and os.path.exists(_CFG_PATH):
            os.remove(_CFG_PATH)
    except Exception:
        pass


def _pass():
    _cleanup_best_effort()
    print("Test Passed ✅")
    sys.exit(0)


def _fail():
    _cleanup_best_effort()
    print("Test Failed ❌")
    sys.exit(0)


def _seed_all(seed: int):
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except Exception:
        pass
    try:
        import torch
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
    except Exception:
        pass


def _write_json_config(cfg: dict) -> str:
    fd, path = tempfile.mkstemp(prefix="ds_cfg_", suffix=".json")
    os.close(fd)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(cfg, f, indent=2, sort_keys=True)
    return path


def main():
    global _ENGINE, _CFG_PATH

    try:
        try:
            import torch
            import torch.nn as nn
        except Exception as e:
            _skip(f"torch missing: {e}")

        try:
            import deepspeed
        except Exception as e:
            _skip(f"deepspeed missing: {e}")

        if not torch.cuda.is_available():
            _skip("CUDA not available")

        _seed_all(1337)

        class TinyModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.lin = nn.Linear(8, 4)

            def forward(self, x):
                return self.lin(x)

        model = TinyModel().cuda()

        ds_config = {
            "train_batch_size": 1,
            "train_micro_batch_size_per_gpu": 1,
            "gradient_accumulation_steps": 1,
            "bf16": {"enabled": True},
            "fp16": {"enabled": False, "loss_scale": 128},
            "loss_scale": 128,
            "optimizer": {"type": "Lamb", "params": {"lr": 1e-3}},
            "zero_optimization": {"stage": 0},
        }

        _CFG_PATH = _write_json_config(ds_config)

        try:
            engine, _, _, _ = deepspeed.initialize(
                model=model,
                model_parameters=list(model.parameters()),
                config=_CFG_PATH,
            )
            _ENGINE = engine
        except BaseException as e:
            if EXPECTED_SUBSTR in str(e).lower():
                _pass()
            print(f"DEBUG_EXCEPTION_INIT: {e}")
            _fail()

        _fail()

    except SystemExit:
        raise
    except BaseException as e:
        _cleanup_best_effort()
        print(f"HARNESS_ERROR: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()

Triggering Command

source ~/.venvs/dl_testing/bin/activate
export TF_CPP_MIN_LOG_LEVEL=2
PYTHONNOUSERSITE=1 deepspeed --num_gpus=1 deepspeed_testcase.py

Output:

Test Passed ✅
Exception ignored in: <function DeepSpeedEngine.__del__>
Traceback (most recent call last):
  File ".../deepspeed/runtime/engine.py", line 565, in __del__
    self.destroy()
  File ".../deepspeed/runtime/engine.py", line 570, in destroy
    if self.is_deepcompile_active():
  File ".../deepspeed/runtime/engine.py", line 4358, in is_deepcompile_active
    return self._deepcompile_active
AttributeError: 'DeepSpeedEngine' object has no attribute '_deepcompile_active'

DS Report

------------------------------------------------
DeepSpeed C++/CUDA extension op report
--------------------------------------------------
NOTE: Ops not installed will be just-in-time (JIT) compiled at
      runtime if needed. Op compatibility means that your system
      meet the required dependencies to JIT install the op.
--------------------------------------------------
JIT compiled ops requires ninja
ninja .................. [OKAY]
--------------------------------------------------
op name ................ installed .. compatible
--------------------------------------------------
 [WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-dev package with apt
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
async_io ............... [NO] ....... [NO]
fused_adam ............. [NO] ....... [OKAY]
cpu_adam ............... [NO] ....... [OKAY]
cpu_adagrad ............ [NO] ....... [OKAY]
cpu_lion ............... [NO] ....... [OKAY]
dc ..................... [NO] ....... [OKAY]
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
evoformer_attn ......... [NO] ....... [NO]
fp_quantizer ........... [NO] ....... [OKAY]
fused_lamb ............. [NO] ....... [OKAY]
fused_lion ............. [NO] ....... [OKAY]
gds .................... [NO] ....... [NO]
transformer_inference .. [NO] ....... [OKAY]
inference_core_ops ..... [NO] ....... [OKAY]
cutlass_ops ............ [NO] ....... [OKAY]
quantizer .............. [NO] ....... [OKAY]
ragged_device_ops ...... [NO] ....... [OKAY]
ragged_ops ............. [NO] ....... [OKAY]
random_ltd ............. [NO] ....... [OKAY]
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.6
 [WARNING]  using untested triton version (3.2.0), only 1.0.0 is known to be compatible
sparse_attn ............ [NO] ....... [NO]
spatial_inference ...... [NO] ....... [OKAY]
transformer ............ [NO] ....... [OKAY]
stochastic_transformer . [NO] ....... [OKAY]
utils .................. [NO] ....... [OKAY]
--------------------------------------------------
DeepSpeed general environment info:
torch install path ............... ['/home/talha/.venvs/dl_testing/lib/python3.13/site-packages/torch']
torch version .................... 2.6.0+cu124
deepspeed install path ........... ['/home/talha/.venvs/dl_testing/lib/python3.13/site-packages/deepspeed']
deepspeed info ................... 0.18.4, unknown, unknown
torch cuda version ............... 12.4
torch hip version ................ None
nvcc version ..................... 12.0
deepspeed wheel compiled w. ...... torch 0.0, cuda 0.0
shared memory (/dev/shm) size .... 62.88 GB

Screenshots

Not applicable (terminal traceback included above).

System info

  • OS: Ubuntu 24.04.3 LTS (noble)
  • GPU count and types: 4 × NVIDIA GeForce RTX 3090 (repro ran with --num_gpus=1, CUDA_VISIBLE_DEVICES=0)
  • Python version: 3.13.5
  • PyTorch version: 2.6.0+cu124
  • DeepSpeed version: 0.18.4
  • NVIDIA driver: 550.78
  • CUDA (driver-reported): 12.4
  • torch cuda version: 12.4
  • nvcc version (ds_report): 12.0

Python version

Python: 3.13.5

Launcher context

Launched with the DeepSpeed launcher:
deepspeed --num_gpus=1 deepspeed_testcase.py

Docker context

Not using Docker (bare-metal / standard venv)

Additional context

This appears to be an unsafe cleanup path for partially constructed DeepSpeedEngine objects: destroy() assumes _deepcompile_active exists. A defensive fix would be to initialize _deepcompile_active very early in engine construction (before any failure points), or to guard its use:

  • return getattr(self, "_deepcompile_active", False) in is_deepcompile_active(), and/or
  • ensure __del__ never lets exceptions escape.
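A minimal sketch of both guards together (the class and method names mirror the traceback; this is not the actual DeepSpeed patch):

```python
class DeepSpeedEngineSketch:
    # Sketch only; names mirror the traceback, not actual DeepSpeed code.
    def is_deepcompile_active(self) -> bool:
        # Guard 1: tolerate a partially constructed instance on which
        # __init__ failed before the attribute was assigned.
        return getattr(self, "_deepcompile_active", False)

    def destroy(self):
        if self.is_deepcompile_active():
            pass  # deepcompile-specific teardown would go here

    def __del__(self):
        # Guard 2: a destructor should never let exceptions propagate;
        # CPython can only report them as "Exception ignored in: ...".
        try:
            self.destroy()
        except Exception:
            pass


partial = DeepSpeedEngineSketch()  # _deepcompile_active was never assigned
partial.destroy()                  # no AttributeError with the guards in place
```

Either guard alone suppresses the reported traceback; applying both makes destroy() safe when called explicitly (as the reproducer's best-effort cleanup does) as well as from `__del__`.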

The BrokenPipeError stderr noise appears after the launcher reports successful exit; even if it’s a launcher pipe/stream shutdown issue, it’s still undesirable log pollution.
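If the noise comes from a final flush of a pipe whose reader (the launcher) has already exited (an assumption on my part, not confirmed), the standard mitigation on the writing side looks like this sketch:

```python
import os
import sys


def flush_quietly():
    """Best-effort flush at shutdown; swallows EPIPE if the reader is gone."""
    for stream in (sys.stdout, sys.stderr):
        try:
            stream.flush()
        except BrokenPipeError:
            # Point the fd at /dev/null so the implicit close at interpreter
            # exit does not raise again and print
            # 'Exception ignored in: <_io.BufferedWriter ...>'.
            devnull = os.open(os.devnull, os.O_WRONLY)
            os.dup2(devnull, stream.fileno())
            os.close(devnull)
```

Calling something like this (e.g. via atexit) before the process exits would keep the shutdown path quiet even when the launcher tears down its pipes first.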

Metadata

Labels: bug, training