
Add mimo-audio model #132

Open
Mozoltov821 wants to merge 23 commits into jax-ml:main from Mozoltov821:feat-support-mimo-audio

Conversation

@Mozoltov821

Resolves #91

Adds MiMo-Audio model implementation in JAX/Flax NNX.

Reference

Paper: {https://github.com/XiaomiMiMo/MiMo-Audio/blob/main/MiMo-Audio-Technical-Report.pdf}
Website: {https://github.com/XiaomiMiMo/MiMo-Audio}
Model code: {https://github.com/XiaomiMiMo/MiMo-Audio/tree/main/src}
Model weights: {https://huggingface.co/collections/XiaomiMiMo/mimo-audio}

Checklist

  • I have read the Contribution Guidelines and used pre-commit hooks to format this commit.
  • I have added all the necessary unit tests for my change. (run_model.py for model usage, test_outputs.py and/or model_validation_colab.ipynb for quality).
  • (If using an LLM) I have carefully reviewed and removed all superfluous comments or unneeded, commented-out code. Only necessary and functional code remains.
  • I have signed the Contributor License Agreement (CLA).

@pengchengneo

@jenriver @chapman20j Please review this PR, thanks!

model_ckpt_path = snapshot_download(model_name)

config = modeling.ModelConfig.qwen2_7b(use_sharding=False)
# mesh, batch_shd = None, None
Collaborator

Can you remove this commented out code?

Comment on lines +35 to +42
query = [
"why sky is blue?",
]

tokenizer = AutoTokenizer.from_pretrained(
model_name,
trust_remote_code=True,
)
Collaborator

Please take advantage of the 120-character limit. Removing the trailing commas on the last line will fix this with the auto-formatter.

self.batch_size = 32
self.num_input_tokens = 5
self.cache_size, self.gen_steps = 128, 10
self.relaxed_tol = 1e-3
Collaborator

Can you remove this relaxed_tol? It is helpful to use the smallest numbers in the testing which cause the tests to pass. This helps us see how numerically precise each operation is. We typically go for using the same atol and rtol with one mantissa digit and one exponent digit (e.g. 2e-5).
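For illustration, a hedged sketch of that convention (`np.testing.assert_allclose` is a stand-in here; the repo's actual assertion helper may differ):

```python
import numpy as np

def assert_close(actual, expected):
    # One mantissa digit, one exponent digit, same atol and rtol (e.g. 2e-5),
    # tightened until just before the test starts failing.
    np.testing.assert_allclose(actual, expected, rtol=2e-5, atol=2e-5)
    # Check dtypes explicitly, since assert_allclose does not.
    assert actual.dtype == expected.dtype, (actual.dtype, expected.dtype)
```

The point of the tight tolerance is diagnostic: if a port only passes at 1e-3, that itself signals a numerics difference worth investigating.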

tl[lp:, :],
rtol=self.relaxed_tol,
atol=self.relaxed_tol,
check_dtype=False,
Collaborator

Please enable dtype checking on all tests.

Comment on lines +132 to +146
self.q_proj = shard(
nnx.Linear(cfg.emb_dim, cfg.num_heads * cfg.head_dim, use_bias=True, rngs=rngs), self.shd_cfg.q_weight_ndh
)
self.k_proj = shard(
nnx.Linear(cfg.emb_dim, cfg.num_kv_heads * cfg.head_dim, use_bias=True, rngs=rngs),
self.shd_cfg.kv_weight_ndh,
)
self.v_proj = shard(
nnx.Linear(cfg.emb_dim, cfg.num_kv_heads * cfg.head_dim, use_bias=True, rngs=rngs),
self.shd_cfg.kv_weight_ndh,
)
self.o_proj = shard(
nnx.Linear(cfg.num_heads * cfg.head_dim, cfg.emb_dim, use_bias=False, rngs=rngs), self.shd_cfg.o_weight_nhd
)

Collaborator

These shardings may not be optimal because the attention weights are stored differently in this implementation. Have you had a chance to test the performance here?


TRANSFORM_LINEAR = Transform(permute=(1, 0))
TRANSFORM_NONE = Transform()

Collaborator

Can you update this to match the Transform enum from other models?

Comment on lines +13 to +14
from bonsai.models.mimo_audio.mimo_audio_tokenizer_configuration import MiMoAudioTokenizerConfig
from bonsai.models.mimo_audio.mimo_audio_tokenizer_params import load_tokenizer_weights_from_safetensors
Collaborator

Can we move all these imports to the beginning of the file?

decoded_audio = tokenizer_model.decode(audio_tokens_array)
os.makedirs(output_dir, exist_ok=True)

import soundfile as sf
Collaborator

Please add additional dependencies to the pyproject.toml file.

cls.num_input_tokens = 5
cls.group_size = cls.bonsai_config.group_size
cls.audio_channels = cls.bonsai_config.audio_channels
cls.tol = 1e-3
Collaborator

Please use atol and rtol based on each test. Refer to my comment in the qwen2 part of the PR for more information.


# Since the transformer library does not yet support mimo-audio-tokenizer,
# during testing, the official implementation code of mimo-audio-tokenizer (https://github.com/XiaomiMiMo/MiMo-Audio) needs to be copied to the corresponding location.
from bonsai.models.mimo_audio.pytorch.src.mimo_audio_tokenizer.modeling_audio_tokenizer import (
Collaborator

This path isn't part of the PR, so the tests will fail in CI. Could you include an install for this in the pyproject.toml file under the testing dependencies?

Collaborator

Actually it looks like the comment before this code may be outdated. Huggingface supports tokenizers for this model. Could you please use that tokenizer in the testing?

return cls(**kwargs)

def create_qwen2_config(self) -> "Qwen2Config":
from bonsai.models.qwen2.modeling import ModelConfig as Qwen2Config
Collaborator

Can we move this import to the beginning of the file?

from typing import TYPE_CHECKING
from bonsai.models.qwen3.modeling import ShardingCfg

if TYPE_CHECKING:
Collaborator

Let's make this just a regular import, since it is used in the MiMoAudioConfig.

emb_dim=1024,
mlp_dim=4096,
num_heads=64,
head_dim=16, # 1024 // 64 = 16
Collaborator

Can you remove this comment?

@jenriver
Member

Hi, thanks for this addition. Could we split this into two PRs? Please merge Qwen2 first, then follow up with the MiMo components.

On a high level:
For Qwen2:

  • Formatting: In run_model.py, please remove the trailing comma in the tokenizer initialization so the auto-formatter works correctly.
  • Testing: Remove relaxed_tol. Please use strict tolerances (e.g., 1e-5) and enable check_dtype=True in assertions.
  • Params: Update the Transform class in params.py to match the Enum style used elsewhere in the repo.
  • Sharding: Please verify the sharding configuration. nnx.Linear uses 2D weights, but the config appears to expect N-dimensional axes.

For MiMo (in the follow-up PR):

  • CI Blocker: The test requiring manual file copying (ex: bonsai.models.mimo_audio.pytorch.src.mimo_audio.modeling_mimo_audio) is a blocker. Please mock the expected outputs or include a reference implementation within the PR.
  • Duplication: Remove the duplicate MiMoSampler class and import bonsai.utils.samplers instead.
  • Numerics: Ensure the manual JAX ISTFT implementation strictly matches PyTorch padding behavior, as this often diverges.
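On the ISTFT point, the usual divergence is the centering pad: `torch.stft`/`torch.istft` with `center=True` assume reflect-padded, centered frames, so a manual JAX ISTFT has to add and trim exactly `n_fft // 2` samples on each side. A tiny sketch of the padding in question (NumPy used for illustration):

```python
import numpy as np

# torch.stft(center=True) reflect-pads the signal by n_fft // 2 per side;
# a hand-rolled ISTFT must trim the same amount after overlap-add.
n_fft = 4
x = np.arange(6.0)
padded = np.pad(x, (n_fft // 2, n_fft // 2), mode="reflect")
trimmed = padded[n_fft // 2 : -(n_fft // 2)]  # recovers the original signal
```

Getting the pad mode wrong (e.g. zero padding instead of reflect) produces small edge errors that only show up at tight tolerances.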

Comment on lines +7 to +10
from jax.sharding import PartitionSpec as P

Array = jnp.ndarray
ShardingSpec = P
Collaborator

Can you update this to

from jax import P
from jaxtyping import Array
from jax.sharding import PartitionSpec

Comment on lines +15 to +47
conv_weight: ShardingSpec # (in_channels, out_channels, kernel_size)
conv_bias: ShardingSpec # (out_channels,)

# Transformer weight sharding (shared by Encoder/Decoder/Vocoder)
attn_qkvo_weight: ShardingSpec # (d_model, d_model)
attn_qkv_bias: ShardingSpec # (d_model,)
attn_out_bias: ShardingSpec # (d_model,)

# FFN weight sharding
ffn_weight_in: ShardingSpec # (d_model, ffn_dim)
ffn_weight_out: ShardingSpec # (ffn_dim, d_model)
ffn_bias: ShardingSpec # (ffn_dim,) or (d_model,)

# LayerNorm/GroupNorm sharding
norm_scale: ShardingSpec # (dim,)
norm_bias: ShardingSpec # (dim,)

# Quantizer codebook sharding
codebook: ShardingSpec # (codebook_size, d_model)

# ConvTranspose1d weight sharding
conv_transpose_weight: ShardingSpec # (in_ch, out_ch, kernel)
conv_transpose_bias: ShardingSpec # (out_ch,)

# ISTFT related sharding
istft_linear_weight: ShardingSpec # (dim, n_fft+2)
istft_linear_bias: ShardingSpec # (n_fft+2,)
istft_window: ShardingSpec # (win_length,)

# Activation sharding
act_btd: ShardingSpec # [batch, time, d_model]
act_btnh: ShardingSpec # [batch, time, num_heads, head_dim]
act_btc: ShardingSpec # [batch, time, channels]
Collaborator

Can you update this dataclass to use PartitionSpec instead of ShardingSpec?
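For illustration, a minimal sketch of the dataclass with PartitionSpec used directly (field names and mesh axis names below are placeholders, not the PR's actual config):

```python
from dataclasses import dataclass
from jax.sharding import PartitionSpec

@dataclass(frozen=True)
class TokenizerShardingCfg:
    # Hypothetical fields; "data"/"model" axis names are assumptions.
    conv_weight: PartitionSpec = PartitionSpec(None, "model", None)
    norm_scale: PartitionSpec = PartitionSpec(None)
    act_btd: PartitionSpec = PartitionSpec("data", None, "model")
```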

@@ -0,0 +1,210 @@
"""Configuration classes for MiMo Audio Tokenizer."""
Collaborator

This file can be combined with the mimo_audio_tokenizer.py file.

Comment on lines +16 to +25
@dataclass(frozen=True)
class Transform:
permute: tuple[int, ...] | None = None
reshape: tuple[int, ...] | None = None
reshape_first: bool = False


TRANSFORM_LINEAR = Transform(permute=(1, 0))
TRANSFORM_CONV1D = Transform(permute=(2, 1, 0))
TRANSFORM_NONE = Transform()
Collaborator

Can you use the Transform enum for consistency with other models?

TRANSFORM_NONE = Transform()


def _get_key_mapping(config: model_lib.MiMoAudioTokenizerConfig) -> dict[str, tuple[str, Transform]]:
Collaborator

Using a regex can simplify this function.
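For example, a hedged sketch of the regex approach (the key patterns below are hypothetical, not the PR's actual checkpoint keys):

```python
import re

def map_key(torch_key: str) -> str:
    # Rewrite per-layer PyTorch keys to their Flax counterparts with one
    # pattern instead of enumerating every layer index by hand.
    key = re.sub(r"layers\.(\d+)", r"layers_\1", torch_key)
    key = re.sub(r"\.weight$", ".kernel", key)
    return key
```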

@@ -0,0 +1,110 @@
from dataclasses import dataclass
Collaborator

These configs can be moved into the modeling.py file.

input_local_layers: int = 6
input_local_dim: int = 1024
input_full_attention: bool = True

Collaborator

For consistency with other models, can we store the qwen2 config as a variable in this config? The other functions could then just assign to that variable and return this class.

Comment on lines +49 to +51
self.speech_vocab_sizes = [1025, 1025, 129, 129, 129, 129, 129, 129]
self.speech_empty_ids = [1024, 1024, 128, 128, 128, 128, 128, 128]
self.delay_pattern = [0, 1, 2, 3, 4, 5, 6, 7]
Collaborator

These should be part of the model config.

output = text_embeds + speech_grouped_embeds
return shard(output, self.shd_cfg.act_btd)

def forward(
Collaborator

Can we name this __call__ for consistency with other models?
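That is, the standard callable-module convention; a minimal sketch (class name illustrative):

```python
class Embedder:
    # Defining __call__ instead of forward lets the module be invoked
    # directly, e.g. emb(x), matching the nnx.Module style used elsewhere.
    def __call__(self, x):
        return x * 2
```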

Comment on lines +265 to +275
def _local_transformer_step_jit(
local_transformer: nnx.Module,
local_embeds: jnp.ndarray,
cache: Cache,
segment_ids: jnp.ndarray,
) -> Tuple[jnp.ndarray, Cache]:
x = local_embeds
for i, layer in enumerate(local_transformer.layers):
x = layer(x, cache[i], segment_ids)
hidden_state = local_transformer.final_norm(x)
return hidden_state, cache
Collaborator

The logic for this function should be built into the local_forward function.

Collaborator

Actually this looks pretty similar to the Qwen2 call method. Can you use the jitted-forward function from that implementation?

Comment on lines +289 to +295
@jax.jit
def local_forward_jit(
model: FlaxMiMoAudioForCausalLM,
local_embeds: jnp.ndarray,
key: jax.random.PRNGKey,
) -> jnp.ndarray:
return model.local_forward(local_embeds, key, local_sampler=None)
Collaborator

This function is unused

Comment on lines +12 to +13
q_bias_flatten = Transform(reshape=(-1,))
kv_bias_flatten = Transform(reshape=(-1,))
Collaborator

Can you use the Transform enum to maintain consistency with other implementations?

@Mozoltov821
Author

@jenriver @chapman20j Thank you for the review and the detailed feedback. I've gone through all the comments and will need some time to work through the requested changes. I plan to start addressing them shortly and will update the PR once I have meaningful progress.

coder0143 and others added 7 commits February 3, 2026 11:12
…-ml#130)

* adding readme, modeling and params files for dinov3

* added run and testing

* added readme and removed some basic imports

* removed ignore statement for ide

* Added colab link in markdown

* changed config and made changes wrt it, added all outputs, removed cosine similarity in tests

* adding more tests and some final changes

* final changes to params

* fixed ruff based text formatting

* using random init torch model in test_outputs and removed run_model due to gated repo.

* using hf_hub constants for path and pre-commit done for ruff

* updating tests

* rename test_outputs to be model specific

* inital commit for qwen3vl

* removed initial commit

* vjepa2 base fm, classifier and params done

* fixed params and testing for model porting

* updating for additional testing and stability

* more testing, ruff, classifier works well

* fixed foundation model, reformatted modeling.py, fixed testing

* adding readme

* name in readme 😅

* removed unnecessary self.config = config

* changes and fixes

* Fixed resnet link in README (jax-ml#127)

Co-authored-by: James Chapman <chapmanjames@google.com>

* Vjepa2 format fixes (jax-ml#128)

* vjepa2: Use opencv-python than torchcodec

* Refactor test_outputs and forward in modeling.py

* [CI] run_selective_tests: handle renamed paths (jax-ml#129)

* remove dinov3 output class to facilitate jit compilation

---------

Co-authored-by: Jen Ha <25069493+jenriver@users.noreply.github.com>
Co-authored-by: vfdev <vfdev.5@gmail.com>
Co-authored-by: James Chapman <chapmanjames@google.com>
Co-authored-by: Jen Ha <25069493+jenriver@users.noreply.github.com>
Change `- []` to `- [ ]` to ensure checkboxes are properly rendered by GitHub.
…l#131)

* implement state space caching

* refactor docstrings and comments

* convert unicode string to ASCII-only

* move create_empty_cache() to modeling.py

---------

Co-authored-by: James Chapman <chapmanjames@google.com>


Development

Successfully merging this pull request may close these issues.

Model Support : MiMo-Audio

8 participants