Qwen3 MoE Conversion / Pre-Training Error #3231

@jedcheng

Bug report

I converted the Qwen3 30B A3B MoE model as instructed in the example:

hf_transfer download Qwen/Qwen3-30B-A3B-Base --local-dir /tmp/qwen3_hf_checkpoint

python3 -m maxtext.checkpoint_conversion.standalone_scripts.convert_qwen3_moe \
  --base_model_path=/tmp/qwen3_hf_checkpoint \
  --maxtext_model_path=gs://${my_bucket} \
  --model_size=qwen3-30b-a3b

The following error occurs when I try to continue pre-training the model converted from HF. The same error occurs with all 4 dropless methods.

For the dense matmul run (sparse_matmul=False), I used the following command (executed with the multihost runner):

python3 -m maxtext.trainers.pre_train.train src/MaxText/configs/base.yml \
    base_output_directory=gs://${my_bucket} \
    tokenizer_path=src/MaxText/assets/tokenizers/qwen3-tokenizer \
    tokenizer_type=huggingface \
    load_parameters_path=gs://${my_bucket} \
    run_name=some_training \
    max_target_length=4096 \
    async_checkpointing=true \
    model_name='qwen3-30b-a3b' \
    dataset_type=grain \
    grain_file_type=arrayrecord \
    grain_train_files=${my_data}/*.array_record  \
    sparse_matmul=False \
    capacity_factor=-1 \
    load_balance_loss_weight=0.02 \
    routed_bias=True \
    routed_scaling_factor=1.0 \
    float32_weight_sum=True

Logs/Output

Traceback (most recent call last):
File "/home/jed351/run5/maxtext_venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 146, in _python_pjit_helper
out_flat, compiled, profiler, const_args = _pjit_call_impl_python(
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jed351/run5/maxtext_venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 1642, in _pjit_call_impl_python
return (compiled.unsafe_call(*computation.const_args, *args),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jed351/run5/maxtext_venv/lib/python3.12/site-packages/jax/_src/profiler.py", line 359, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/jed351/run5/maxtext_venv/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py", line 1366, in __call__
input_bufs = self.in_handler(args)
^^^^^^^^^^^^^^^^^^^^^
File "/home/jed351/run5/maxtext_venv/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py", line 1249, in __call__
return self.handler(input_buffers)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jed351/run5/maxtext_venv/lib/python3.12/site-packages/jax/_src/profiler.py", line 359, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/jed351/run5/maxtext_venv/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py", line 134, in shard_args
arg = dtypes.canonicalize_value(arg)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jed351/run5/maxtext_venv/lib/python3.12/site-packages/jax/_src/dtypes.py", line 388, in canonicalize_value
raise InvalidInputException(
jax._src.dtypes.InvalidInputException: Argument 'ShapeDtypeStruct(shape=(128, 48), dtype=float32, sharding=NamedSharding(mesh=Mesh('diloco': 1, 'data': 1, 'stage': 1, 'fsdp': 64, 'fsdp_transpose': 1, 'sequence': 1, 'context': 1, 'context_autoregressive': 1, 'tensor': 1, 'tensor_transpose': 1, 'tensor_sequence': 1, 'expert': 1, 'autoregressive': 1, axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto)), spec=PartitionSpec(None, 'stage'), memory_kind=device))' of type <class 'jax.ShapeDtypeStruct'> is not a valid JAX type.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
File "/home/jed351/run5/src/MaxText/train.py", line 33, in <module>
_new_module.main(sys.argv)
File "/home/jed351/run5/src/maxtext/trainers/pre_train/train.py", line 598, in main
run(config, recorder, diagnostic_config)
File "/home/jed351/run5/src/maxtext/trainers/pre_train/train.py", line 592, in run
train_loop(config, recorder)
File "/home/jed351/run5/src/maxtext/trainers/pre_train/train.py", line 472, in train_loop
state, metrics = p_train_step(state, example_batch, nextrng)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: Argument 'state.params['params']['decoder']['layers']['moe_block']['gate']['bias']' of shape float32[128,48] of type <class 'jax.ShapeDtypeStruct'> is not a valid JAX type.
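The TypeError says that one leaf of state.params is still a jax.ShapeDtypeStruct (a shape/dtype placeholder) rather than a concrete array when p_train_step is called. One possible reading, not confirmed here, is that the converted checkpoint carries no value for gate/bias (the parameter that routed_bias=True adds), so that entry is never materialized. A minimal sketch for listing such leaves before the first train step follows. The helper names are hypothetical and not part of MaxText, and the type-name check is an assumption made only to keep the snippet free of imports; with JAX available, isinstance(leaf, jax.ShapeDtypeStruct) is the direct test.

```python
from collections.abc import Mapping


def is_abstract_leaf(leaf):
    """Heuristic: True for shape/dtype placeholders such as jax.ShapeDtypeStruct.

    Assumption: matching on the class name avoids importing jax here.
    """
    return type(leaf).__name__ == "ShapeDtypeStruct"


def find_abstract_leaves(tree, path=()):
    """Yield the slash-joined key path of every placeholder leaf in a nested mapping."""
    if isinstance(tree, Mapping):
        # Recurse into dict-like nodes, extending the key path as we go.
        for key, child in tree.items():
            yield from find_abstract_leaves(child, path + (str(key),))
    elif is_abstract_leaf(tree):
        yield "/".join(path)
```

Calling list(find_abstract_leaves(state.params)) right after the state is restored should name every parameter that was not loaded. If only gate/bias entries show up, the mismatch is between the training flags and the checkpoint contents rather than in the conversion of the remaining weights.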

Environment Information

TPU v6e-64
Software version: v2-alpha-tpuv6e
MaxText installed from source with uv

Labels: bug