Conversation

@jenriver @chapman20j please review this PR, thanks!
```python
model_ckpt_path = snapshot_download(model_name)

config = modeling.ModelConfig.qwen2_7b(use_sharding=False)
# mesh, batch_shd = None, None
```
Can you remove this commented out code?
```python
query = [
    "why sky is blue?",
]

tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True,
)
```
Please take advantage of the 120 character limit. Removing the trailing commas will let the auto-formatter fix this.
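For illustration, once the trailing comma is gone the auto-formatter should collapse this into a single line, something like:

```python
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
```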
```python
self.batch_size = 32
self.num_input_tokens = 5
self.cache_size, self.gen_steps = 128, 10
self.relaxed_tol = 1e-3
```
Can you remove this relaxed_tol? It is helpful to use the smallest tolerances that still let the tests pass, since that shows how numerically precise each operation is. We typically use the same atol and rtol with one mantissa digit and one exponent digit (e.g. 2e-5).
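A minimal sketch of the suggested pattern (the variable names here are placeholders, not from this PR):

```python
import numpy as np

# Use the tightest tolerance that still passes: one mantissa digit, one exponent digit.
np.testing.assert_allclose(jax_output, torch_output, rtol=2e-5, atol=2e-5)
```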
```python
tl[lp:, :],
rtol=self.relaxed_tol,
atol=self.relaxed_tol,
check_dtype=False,
```
Please enable dtype checking on all tests.
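For reference, if the assertion helper's `check_dtype` flag ever gets dropped, an explicit check is a one-liner (variable names are hypothetical):

```python
assert jax_output.dtype == torch_output.dtype
```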
```python
self.q_proj = shard(
    nnx.Linear(cfg.emb_dim, cfg.num_heads * cfg.head_dim, use_bias=True, rngs=rngs), self.shd_cfg.q_weight_ndh
)
self.k_proj = shard(
    nnx.Linear(cfg.emb_dim, cfg.num_kv_heads * cfg.head_dim, use_bias=True, rngs=rngs),
    self.shd_cfg.kv_weight_ndh,
)
self.v_proj = shard(
    nnx.Linear(cfg.emb_dim, cfg.num_kv_heads * cfg.head_dim, use_bias=True, rngs=rngs),
    self.shd_cfg.kv_weight_ndh,
)
self.o_proj = shard(
    nnx.Linear(cfg.num_heads * cfg.head_dim, cfg.emb_dim, use_bias=False, rngs=rngs), self.shd_cfg.o_weight_nhd
)
```
These shardings may not be optimal because the attention weights are stored differently in this implementation. Have you had a chance to test the performance here?
```python
TRANSFORM_LINEAR = Transform(permute=(1, 0))
TRANSFORM_NONE = Transform()
```
Can you update this to match the Transform enum from other models?
```python
from bonsai.models.mimo_audio.mimo_audio_tokenizer_configuration import MiMoAudioTokenizerConfig
from bonsai.models.mimo_audio.mimo_audio_tokenizer_params import load_tokenizer_weights_from_safetensors
```
Can we move all these imports to the beginning of the file?
```python
decoded_audio = tokenizer_model.decode(audio_tokens_array)
os.makedirs(output_dir, exist_ok=True)

import soundfile as sf
```
Please add additional dependencies to the pyproject.toml file.
```python
cls.num_input_tokens = 5
cls.group_size = cls.bonsai_config.group_size
cls.audio_channels = cls.bonsai_config.audio_channels
cls.tol = 1e-3
```
Please use atol and rtol based on each test. Refer to my comment in the qwen2 part of the PR for more information.
```python
# Since the transformer library does not yet support mimo-audio-tokenizer,
# during testing, the official implementation code of mimo-audio-tokenizer
# (https://github.com/XiaomiMiMo/MiMo-Audio) needs to be copied to the corresponding location.
from bonsai.models.mimo_audio.pytorch.src.mimo_audio_tokenizer.modeling_audio_tokenizer import (
```
This path isn't part of the PR, so it will cause the tests to fail in CI. Could you include an install for this in the pyproject.toml file under testing?
Actually, it looks like the comment before this code may be outdated. Hugging Face supports tokenizers for this model. Could you please use that tokenizer in the testing?
```python
return cls(**kwargs)

def create_qwen2_config(self) -> "Qwen2Config":
    from bonsai.models.qwen2.modeling import ModelConfig as Qwen2Config
```
Can we move this import to the beginning of the file?
```python
from typing import TYPE_CHECKING
from bonsai.models.qwen3.modeling import ShardingCfg

if TYPE_CHECKING:
```
Let's make this just a regular import, since it is used in the MiMoAudioConfig.
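That is, keep the import (the same line already shown in the diff above) unconditional at module scope:

```python
from bonsai.models.qwen3.modeling import ShardingCfg
```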
```python
emb_dim=1024,
mlp_dim=4096,
num_heads=64,
head_dim=16,  # 1024 // 64 = 16
```
Can you remove this comment?
Hi, thanks for this addition. Could we split this into two PRs? Please merge Qwen2 first, then follow up with the MiMo components. On a high level:

For MiMo (in the follow-up PR):
```python
from jax.sharding import PartitionSpec as P

Array = jnp.ndarray
ShardingSpec = P
```
Can you update this to:

```python
from jax import P
from jaxtyping import Array
from jax.sharding import PartitionSpec
```

```python
conv_weight: ShardingSpec  # (in_channels, out_channels, kernel_size)
conv_bias: ShardingSpec  # (out_channels,)

# Transformer weight sharding (shared by Encoder/Decoder/Vocoder)
attn_qkvo_weight: ShardingSpec  # (d_model, d_model)
attn_qkv_bias: ShardingSpec  # (d_model,)
attn_out_bias: ShardingSpec  # (d_model,)

# FFN weight sharding
ffn_weight_in: ShardingSpec  # (d_model, ffn_dim)
ffn_weight_out: ShardingSpec  # (ffn_dim, d_model)
ffn_bias: ShardingSpec  # (ffn_dim,) or (d_model,)

# LayerNorm/GroupNorm sharding
norm_scale: ShardingSpec  # (dim,)
norm_bias: ShardingSpec  # (dim,)

# Quantizer codebook sharding
codebook: ShardingSpec  # (codebook_size, d_model)

# ConvTranspose1d weight sharding
conv_transpose_weight: ShardingSpec  # (in_ch, out_ch, kernel)
conv_transpose_bias: ShardingSpec  # (out_ch,)

# ISTFT related sharding
istft_linear_weight: ShardingSpec  # (dim, n_fft+2)
istft_linear_bias: ShardingSpec  # (n_fft+2,)
istft_window: ShardingSpec  # (win_length,)

# Activation sharding
act_btd: ShardingSpec  # [batch, time, d_model]
act_btnh: ShardingSpec  # [batch, time, num_heads, head_dim]
act_btc: ShardingSpec  # [batch, time, channels]
```
Can you update this dataclass to use PartitionSpec instead of ShardingSpec.
```python
@@ -0,0 +1,210 @@
"""Configuration classes for MiMo Audio Tokenizer."""
```
This file can be combined with the mimo_audio_tokenizer.py file.
```python
@dataclass(frozen=True)
class Transform:
    permute: tuple[int, ...] | None = None
    reshape: tuple[int, ...] | None = None
    reshape_first: bool = False


TRANSFORM_LINEAR = Transform(permute=(1, 0))
TRANSFORM_CONV1D = Transform(permute=(2, 1, 0))
TRANSFORM_NONE = Transform()
```
Can you use the Transform enum for consistency with other models?
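As a rough sketch only (the actual enum in the other bonsai models may differ), the module-level constants above could become enum members:

```python
import enum


class Transform(enum.Enum):
    # Each member's value is the permutation applied to the torch tensor;
    # None means the tensor is used as-is.
    NONE = None
    LINEAR = (1, 0)
    CONV1D = (2, 1, 0)
```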
```python
TRANSFORM_NONE = Transform()


def _get_key_mapping(config: model_lib.MiMoAudioTokenizerConfig) -> dict[str, tuple[str, Transform]]:
```
Using a regex can simplify this function.
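A minimal sketch of the regex approach, assuming the `TRANSFORM_*` constants from the diff above; the patterns and key names here are hypothetical and will differ from the PR's real mapping:

```python
import re

# Map whole families of checkpoint keys with one pattern each,
# instead of enumerating every layer index explicitly.
_KEY_PATTERNS = [
    (r"encoder\.layers\.(\d+)\.attn\.(q|k|v)_proj\.weight",
     r"encoder.layers.\1.\2_proj.kernel", TRANSFORM_LINEAR),
    (r"encoder\.conv\.(\d+)\.weight",
     r"encoder.conv.\1.kernel", TRANSFORM_CONV1D),
]


def map_key(torch_key: str) -> tuple[str, Transform]:
    """Return the bonsai key and transform for one checkpoint key."""
    for pattern, replacement, transform in _KEY_PATTERNS:
        if re.fullmatch(pattern, torch_key):
            return re.sub(pattern, replacement, torch_key), transform
    return torch_key, TRANSFORM_NONE
```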
```python
@@ -0,0 +1,110 @@
from dataclasses import dataclass
```
These configs can be moved into the modeling.py file.
```python
input_local_layers: int = 6
input_local_dim: int = 1024
input_full_attention: bool = True
```
For consistency with other models, can we store the Qwen2 config as a variable in this config? The other functions could then just assign to that variable and return this class.
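A hypothetical sketch of that shape, reusing the fields and the Qwen2 import from the diffs above (the field name `qwen2_config` is illustrative):

```python
from dataclasses import dataclass

from bonsai.models.qwen2.modeling import ModelConfig as Qwen2Config


@dataclass
class MiMoAudioConfig:
    input_local_layers: int = 6
    input_local_dim: int = 1024
    input_full_attention: bool = True
    # Factory methods just assign here and return this class.
    qwen2_config: Qwen2Config | None = None
```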
```python
self.speech_vocab_sizes = [1025, 1025, 129, 129, 129, 129, 129, 129]
self.speech_empty_ids = [1024, 1024, 128, 128, 128, 128, 128, 128]
self.delay_pattern = [0, 1, 2, 3, 4, 5, 6, 7]
```
These should be part of the model config.
```python
output = text_embeds + speech_grouped_embeds
return shard(output, self.shd_cfg.act_btd)

def forward(
```
Can we name this `__call__` for consistency with other models?
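For illustration, the rename would look roughly like this (the class here is a hypothetical minimal stand-in, not the PR's actual module):

```python
from flax import nnx


class MiMoAudioModel(nnx.Module):
    def __call__(self, x):  # was `def forward(...)`; __call__ matches the other models
        return x
```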
```python
def _local_transformer_step_jit(
    local_transformer: nnx.Module,
    local_embeds: jnp.ndarray,
    cache: Cache,
    segment_ids: jnp.ndarray,
) -> Tuple[jnp.ndarray, Cache]:
    x = local_embeds
    for i, layer in enumerate(local_transformer.layers):
        x = layer(x, cache[i], segment_ids)
    hidden_state = local_transformer.final_norm(x)
    return hidden_state, cache
```
The logic for this function should be built into the local_forward function.
Actually this looks pretty similar to the Qwen2 call method. Can you use the jitted-forward function from that implementation?
```python
@jax.jit
def local_forward_jit(
    model: FlaxMiMoAudioForCausalLM,
    local_embeds: jnp.ndarray,
    key: jax.random.PRNGKey,
) -> jnp.ndarray:
    return model.local_forward(local_embeds, key, local_sampler=None)
```
This function is unused.
```python
q_bias_flatten = Transform(reshape=(-1,))
kv_bias_flatten = Transform(reshape=(-1,))
```
Can you use the Transform enum to maintain consistency with other implementations?
@jenriver @chapman20j Thank you for the review and the detailed feedback. I've gone through all the comments and will need some time to work through the requested changes. I plan to start addressing them shortly and will update the PR once I have meaningful progress.
…-ml#130)

* adding readme, modeling and params files for dinov3
* added run and testing
* added readme and removed some basic imports
* removed ignore statement for ide
* Added colab link in markdown
* changed config and made changes wrt it, added all outputs, removed cosine similarity in tests
* adding more tests and some final changes
* final changes to params
* fixed ruff based text formatting
* using random init torch model in test_outputs and removed run_model due to gated repo.
* using hf_hub constants for path and pre-commit done for ruff
* updating tests
* rename test_outputs to be model specific
* inital commit for qwen3vl
* removed initial commit
* vjepa2 base fm, classifier and params done
* fixed params and testing for model porting
* updating for additional testing and stability
* more testing, ruff, classifier works well
* fixed foundation model, reformatted modeling.py, fixed testing
* adding readme
* name in readme 😅
* removed unnecessary self.config = config
* changes and fixes
* Fixed resnet link in README (jax-ml#127)

  Co-authored-by: James Chapman <chapmanjames@google.com>

* Vjepa2 format fixes (jax-ml#128)

  * vjepa2: Use opencv-python than torchcodec
  * Refactor test_outputs and forward in modeling.py

* [CI] run_selective_tests: handle renamed paths (jax-ml#129)
* remove dinov3 output class to facilitate jit compilation

---------

Co-authored-by: Jen Ha <25069493+jenriver@users.noreply.github.com>
Co-authored-by: vfdev <vfdev.5@gmail.com>
Co-authored-by: James Chapman <chapmanjames@google.com>
Co-authored-by: Jen Ha <25069493+jenriver@users.noreply.github.com>
Change `- []` to `- [ ]` to ensure checkboxes are properly rendered by GitHub.
…l#131)

* implement state space caching
* refactor docstrings and comments
* convert unicode string to ASCII-only
* move create_empty_cache() to modeling.py

---------

Co-authored-by: James Chapman <chapmanjames@google.com>
…at-support-mimo-audio
Resolves #91

Adds a MiMo-Audio model implementation in JAX/Flax NNX.

Reference
- Paper: https://github.com/XiaomiMiMo/MiMo-Audio/blob/main/MiMo-Audio-Technical-Report.pdf
- Website: https://github.com/XiaomiMiMo/MiMo-Audio
- Model code: https://github.com/XiaomiMiMo/MiMo-Audio/tree/main/src
- Model weights: https://huggingface.co/collections/XiaomiMiMo/mimo-audio
Checklist
- [ ] `run_model.py` for model usage, `test_outputs.py` and/or `model_validation_colab.ipynb` for quality.