Description
Bug report
I converted the Qwen3 30B A3B MoE model as instructed in the example:
hf_transfer download Qwen/Qwen3-30B-A3B-Base --local-dir /tmp/qwen3_hf_checkpoint
python3 -m maxtext.checkpoint_conversion.standalone_scripts.convert_qwen3_moe \
--base_model_path=/tmp/qwen3_hf_checkpoint \
--maxtext_model_path=gs://${my_bucket} \
--model_size=qwen3-30b-a3b
The following error occurs when I try to continue pre-training the model converted from HF. The same error occurs with all 4 dropless methods.
For the dense training, I used the following command (executed with the multihost runner):
python3 -m maxtext.trainers.pre_train.train src/MaxText/configs/base.yml \
base_output_directory=gs://${my_bucket} \
tokenizer_path=src/MaxText/assets/tokenizers/qwen3-tokenizer \
tokenizer_type=huggingface \
load_parameters_path=gs://${my_bucket} \
run_name=some_training \
max_target_length=4096 \
async_checkpointing=true \
model_name='qwen3-30b-a3b' \
dataset_type=grain \
grain_file_type=arrayrecord \
grain_train_files=${my_data}/*.array_record \
sparse_matmul=False \
capacity_factor=-1 \
load_balance_loss_weight=0.02 \
routed_bias=True \
routed_scaling_factor=1.0 \
float32_weight_sum=True
Logs/Output
Traceback (most recent call last):
File "/home/jed351/run5/maxtext_venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 146, in _python_pjit_helper
out_flat, compiled, profiler, const_args = _pjit_call_impl_python(
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jed351/run5/maxtext_venv/lib/python3.12/site-packages/jax/_src/pjit.py", line 1642, in _pjit_call_impl_python
return (compiled.unsafe_call(*computation.const_args, *args),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jed351/run5/maxtext_venv/lib/python3.12/site-packages/jax/_src/profiler.py", line 359, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/jed351/run5/maxtext_venv/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py", line 1366, in __call__
input_bufs = self.in_handler(args)
^^^^^^^^^^^^^^^^^^^^^
File "/home/jed351/run5/maxtext_venv/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py", line 1249, in __call__
return self.handler(input_buffers)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jed351/run5/maxtext_venv/lib/python3.12/site-packages/jax/_src/profiler.py", line 359, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/jed351/run5/maxtext_venv/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py", line 134, in shard_args
arg = dtypes.canonicalize_value(arg)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jed351/run5/maxtext_venv/lib/python3.12/site-packages/jax/_src/dtypes.py", line 388, in canonicalize_value
raise InvalidInputException(
jax._src.dtypes.InvalidInputException: Argument 'ShapeDtypeStruct(shape=(128, 48), dtype=float32, sharding=NamedSharding(mesh=Mesh('diloco': 1, 'data': 1, 'stage': 1, 'fsdp': 64, 'fsdp_transpose': 1, 'sequence': 1, 'context': 1, 'context_autoregressive': 1, 'tensor': 1, 'tensor_transpose': 1, 'tensor_sequence': 1, 'expert': 1, 'autoregressive': 1, axis_types=(Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto, Auto)), spec=PartitionSpec(None, 'stage'), memory_kind=device))' of type <class 'jax.ShapeDtypeStruct'> is not a valid JAX type.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/jed351/run5/src/MaxText/train.py", line 33, in <module>
_new_module.main(sys.argv)
File "/home/jed351/run5/src/maxtext/trainers/pre_train/train.py", line 598, in main
run(config, recorder, diagnostic_config)
File "/home/jed351/run5/src/maxtext/trainers/pre_train/train.py", line 592, in run
train_loop(config, recorder)
File "/home/jed351/run5/src/maxtext/trainers/pre_train/train.py", line 472, in train_loop
state, metrics = p_train_step(state, example_batch, nextrng)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: Argument 'state.params['params']['decoder']['layers']['moe_block']['gate']['bias']' of shape float32[128,48] of type <class 'jax.ShapeDtypeStruct'> is not a valid JAX type.
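The final TypeError names one parameter leaf, the MoE gate bias, as a jax.ShapeDtypeStruct rather than a concrete array. To check whether other leaves are in the same state before the first train step, the parameter tree can be walked and every placeholder leaf reported. The following is a minimal sketch, not MaxText code: `find_abstract_leaves` is a hypothetical helper, the `ShapeDtypeStruct` class is a stand-in so the snippet runs without JAX installed, and the miniature `params` dict only mimics the path from the traceback.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ShapeDtypeStruct:
    """Stand-in for jax.ShapeDtypeStruct (assumption: used only so this runs without JAX)."""
    shape: tuple
    dtype: str


def find_abstract_leaves(tree, path=()):
    """Return '/'-joined key paths of leaves that are shape/dtype placeholders."""
    if isinstance(tree, dict):
        found = []
        for key, value in tree.items():
            found.extend(find_abstract_leaves(value, path + (key,)))
        return found
    # Match on the class name so a real jax.ShapeDtypeStruct is caught as well.
    if type(tree).__name__ == "ShapeDtypeStruct":
        return ["/".join(map(str, path))]
    return []


# Hypothetical miniature of the parameter tree named in the traceback.
params = {
    "params": {"decoder": {"layers": {"moe_block": {"gate": {
        "kernel": [[0.0]],  # stands in for an array that was restored
        "bias": ShapeDtypeStruct((128, 48), "float32"),  # placeholder, no data
    }}}}}
}

missing = find_abstract_leaves(params)
print(missing)
```

With the real training state, the same walk over `state.params` would list every parameter that still holds only a shape and dtype, which helps tell whether the gate bias is the only affected leaf.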
Environment Information
TPU v6e-64
Software version: v2-alpha-tpuv6e
MaxText installed from source with uv
Additional Context
No response