Skip to content

feat: add VITA-Audio#146

Merged
weedge merged 16 commits intomainfrom
feat/voice
May 12, 2025
Merged

feat: add VITA-Audio#146
weedge merged 16 commits intomainfrom
feat/voice

Conversation

@weedge
Copy link
Collaborator

@weedge weedge commented May 9, 2025

image

VITA-Audio 由四个主要组件组成:音频编码器、音频解码器、大型语言模型主干和一组跨模态token预测 (MCTP) 模块。Boost/Balance 与 GLM-4-Voice: Towards Intelligent and Human-like End-To-End Spoken Chatbot 类似,但增加了mtp, 使用 CosyVoice: A Scalable Multilingual Zero-Shot Text-To-Speech Synthesizer Based on Supervised 作为音频编码器和解码器。音频信号首先通过音频编码器编码为一系列离散的音频token,然后输入到 LLM 进行处理。在每次前向传递过程中,LLM 交替生成文本和音频token。LLM 最后一层的隐藏状态以及预测token的嵌入作为 MCTP 模块的输入。历史输入token、LLM 和 MCTP 模块预测的token被连接起来,形成下一个 LLM 前向传递的输入。最后,LLM 和 MCTP 模块生成的音频token被聚合并传递给音频解码器以生成最终的音频输出。

训练: #146 (comment)
推理流程分析见: #146 (comment)


feat:

  • add VITA audio task cases test (s2s,asr,tts) run on the modal
# download model
modal run src/download_models.py --repo-ids "VITA-MLLM/VITA-Audio-Boost"
modal run src/download_models.py --repo-ids "VITA-MLLM/VITA-Audio-Balance" 
modal run src/download_models.py --repo-ids "VITA-MLLM/VITA-Audio-Plus-Vanilla"
modal run src/download_models.py --repo-ids "THUDM/glm-4-voice-tokenizer"
modal run src/download_models.py --repo-ids "THUDM/glm-4-voice-decoder" 
modal run src/download_models.py --repo-ids "FunAudioLLM/SenseVoiceSmall"

# optional download, just for test
modal run src/download_models.py --repo-ids "SparkAudio/Spark-TTS-0.5B"
modal run src/download_models.py --repo-ids "hubertsiuzdak/snac_24khz"
modal run src/download_models.py --repo-ids "FunAudioLLM/CosyVoice2-0.5B"

# tokenize
IMAGE_GPU=T4 modal run src/llm/transformers/vita_voice.py --task tokenize
IMAGE_GPU=L4 modal run src/llm/transformers/vita_voice.py --task dump_model

# LLM: sensevoice(no use ctc)_qwen2_mtp(no mtp)
IMAGE_GPU=L4 modal run src/llm/transformers/vita_voice.py --task text 
IMAGE_GPU=L4 modal run src/llm/transformers/vita_voice.py --task text_audio
IMAGE_GPU=L4 modal run src/llm/transformers/vita_voice.py --task text_stream
IMAGE_GPU=L4 modal run src/llm/transformers/vita_voice.py --task text_audio_stream

# LLM: qwen2_mtp
AUDIO_TOKENIZER_TYPE=glm4voice IMAGE_GPU=L4 modal run src/llm/transformers/vita_voice.py --task text 
AUDIO_TOKENIZER_TYPE=glm4voice IMAGE_GPU=L4 modal run src/llm/transformers/vita_voice.py --task text_audio
AUDIO_TOKENIZER_TYPE=glm4voice IMAGE_GPU=L4 modal run src/llm/transformers/vita_voice.py --task text_stream
AUDIO_TOKENIZER_TYPE=glm4voice IMAGE_GPU=L4 modal run src/llm/transformers/vita_voice.py --task text_audio_stream
MTP_LLM_MODEL=VITA-MLLM/VITA-Audio-Boost AUDIO_TOKENIZER_TYPE=glm4voice IMAGE_GPU=L4 modal run src/llm/transformers/vita_voice.py --task text 
MTP_LLM_MODEL=VITA-MLLM/VITA-Audio-Boost AUDIO_TOKENIZER_TYPE=glm4voice IMAGE_GPU=L4 modal run src/llm/transformers/vita_voice.py --task text_stream

# LLM(Plus-Vanilla): sensevoice(no use ctc)_qwen2_mtp(no mtp) + AudioTokenizer: sensevoice_glm4voice(sensevoice WavFrontend, decoder)
IMAGE_GPU=L4 modal run src/llm/transformers/vita_voice.py --task sts
IMAGE_GPU=L4 modal run src/llm/transformers/vita_voice.py --task sts_stream
IMAGE_GPU=L4 modal run src/llm/transformers/vita_voice.py --task asr
IMAGE_GPU=L4 modal run src/llm/transformers/vita_voice.py --task asr_text_stream
IMAGE_GPU=L4 modal run src/llm/transformers/vita_voice.py --task tts
IMAGE_GPU=L4 modal run src/llm/transformers/vita_voice.py --task tts_stream
IMAGE_GPU=L4 modal run src/llm/transformers/vita_voice.py --task tts_clone
IMAGE_GPU=L4 modal run src/llm/transformers/vita_voice.py --task tts_audio_chunk_static_stream
IMAGE_GPU=L4 modal run src/llm/transformers/vita_voice.py --task tts_audio_chunk_dynamic_stream


# LLM(Plus-Boost): sensevoice(no use ctc)_qwen2_mtp(no mtp) + AudioTokenizer: sensevoice_glm4voice(sensevoice WavFrontend, decoder)
MTP_LLM_MODEL=VITA-MLLM/VITA-Audio-Plus-Boost IMAGE_GPU=L40s modal run src/llm/transformers/vita_voice.py --task sts
MTP_LLM_MODEL=VITA-MLLM/VITA-Audio-Plus-Boost IMAGE_GPU=L40s modal run src/llm/transformers/vita_voice.py --task sts_stream
MTP_LLM_MODEL=VITA-MLLM/VITA-Audio-Plus-Boost IMAGE_GPU=L4 modal run src/llm/transformers/vita_voice.py --task asr
MTP_LLM_MODEL=VITA-MLLM/VITA-Audio-Plus-Boost IMAGE_GPU=L4 modal run src/llm/transformers/vita_voice.py --task asr_text_stream
MTP_LLM_MODEL=VITA-MLLM/VITA-Audio-Plus-Boost IMAGE_GPU=L4 modal run src/llm/transformers/vita_voice.py --task tts
MTP_LLM_MODEL=VITA-MLLM/VITA-Audio-Plus-Boost IMAGE_GPU=L4 modal run src/llm/transformers/vita_voice.py --task tts_stream
MTP_LLM_MODEL=VITA-MLLM/VITA-Audio-Plus-Boost IMAGE_GPU=L4 modal run src/llm/transformers/vita_voice.py --task tts_clone
MTP_LLM_MODEL=VITA-MLLM/VITA-Audio-Plus-Boost IMAGE_GPU=L4 modal run src/llm/transformers/vita_voice.py --task tts_audio_chunk_static_stream
MTP_LLM_MODEL=VITA-MLLM/VITA-Audio-Plus-Boost IMAGE_GPU=L4 modal run src/llm/transformers/vita_voice.py --task tts_audio_chunk_dynamic_stream

# LLM(Boost/Balance): qwen2_mtp + AudioTokenizer: glm4voice(tokenizer, decoder)
AUDIO_TOKENIZER_TYPE=glm4voice IMAGE_GPU=L40s modal run src/llm/transformers/vita_voice.py --task sts
AUDIO_TOKENIZER_TYPE=glm4voice IMAGE_GPU=L40s modal run src/llm/transformers/vita_voice.py --task sts_stream
AUDIO_TOKENIZER_TYPE=glm4voice IMAGE_GPU=L4 modal run src/llm/transformers/vita_voice.py --task asr
AUDIO_TOKENIZER_TYPE=glm4voice IMAGE_GPU=L4 modal run src/llm/transformers/vita_voice.py --task asr_text_stream
AUDIO_TOKENIZER_TYPE=glm4voice IMAGE_GPU=L4 modal run src/llm/transformers/vita_voice.py --task tts
AUDIO_TOKENIZER_TYPE=glm4voice IMAGE_GPU=L4 modal run src/llm/transformers/vita_voice.py --task tts_stream
AUDIO_TOKENIZER_TYPE=glm4voice IMAGE_GPU=L4 modal run src/llm/transformers/vita_voice.py --task tts_clone
AUDIO_TOKENIZER_TYPE=glm4voice IMAGE_GPU=L4 modal run src/llm/transformers/vita_voice.py --task tts_audio_chunk_static_stream
AUDIO_TOKENIZER_TYPE=glm4voice IMAGE_GPU=L4 modal run src/llm/transformers/vita_voice.py --task tts_audio_chunk_dynamic_stream

MTP_LLM_MODEL=VITA-MLLM/VITA-Audio-Boost AUDIO_TOKENIZER_TYPE=glm4voice IMAGE_GPU=L40s modal run src/llm/transformers/vita_voice.py --task sts
MTP_LLM_MODEL=VITA-MLLM/VITA-Audio-Boost AUDIO_TOKENIZER_TYPE=glm4voice IMAGE_GPU=L40s modal run src/llm/transformers/vita_voice.py --task sts_stream
MTP_LLM_MODEL=VITA-MLLM/VITA-Audio-Boost AUDIO_TOKENIZER_TYPE=glm4voice IMAGE_GPU=L4 modal run src/llm/transformers/vita_voice.py --task asr
MTP_LLM_MODEL=VITA-MLLM/VITA-Audio-Boost AUDIO_TOKENIZER_TYPE=glm4voice IMAGE_GPU=L4 modal run src/llm/transformers/vita_voice.py --task asr_text_stream
MTP_LLM_MODEL=VITA-MLLM/VITA-Audio-Boost AUDIO_TOKENIZER_TYPE=glm4voice IMAGE_GPU=L4 modal run src/llm/transformers/vita_voice.py --task tts
MTP_LLM_MODEL=VITA-MLLM/VITA-Audio-Boost AUDIO_TOKENIZER_TYPE=glm4voice IMAGE_GPU=L4 modal run src/llm/transformers/vita_voice.py --task tts_stream
MTP_LLM_MODEL=VITA-MLLM/VITA-Audio-Boost AUDIO_TOKENIZER_TYPE=glm4voice IMAGE_GPU=L4 modal run src/llm/transformers/vita_voice.py --task tts_clone
MTP_LLM_MODEL=VITA-MLLM/VITA-Audio-Boost AUDIO_TOKENIZER_TYPE=glm4voice IMAGE_GPU=L4 modal run src/llm/transformers/vita_voice.py --task tts_audio_chunk_static_stream
MTP_LLM_MODEL=VITA-MLLM/VITA-Audio-Boost AUDIO_TOKENIZER_TYPE=glm4voice IMAGE_GPU=L4 modal run src/llm/transformers/vita_voice.py --task tts_audio_chunk_dynamic_stream
  • add llm_transformers_manual_vita_text llm_transformers_manual_vita_audio_asr llm_transformers_manual_vita_tts llm_transformers_manual_vita_text_voice llm_transformers_manual_vita_voice
  • add vita_asr and unit test
LLM_MODEL_NAME_OR_PATH=./models/VITA-MLLM/VITA-Audio-Plus-Vanilla \
    SENSE_VOICE_MODEL_PATH=./models/FunAudioLLM/SenseVoiceSmall \
    AUDIO_TOKENIZER_TYPE=sensevoice_glm4voice \
    LLM_DEVICE=cuda LLM_TORCH_DTYPE=bfloat16 \
    python -m unittest test.modules.speech.asr.test_vita_asr.TestVITAASR.test_transcribe_stream
LLM_MODEL_NAME_OR_PATH=./models/VITA-MLLM/VITA-Audio-Plus-Vanilla \
    SENSE_VOICE_MODEL_PATH=./models/FunAudioLLM/SenseVoiceSmall \
    AUDIO_TOKENIZER_TYPE=sensevoice_glm4voice \
    LLM_DEVICE=cuda LLM_TORCH_DTYPE=bfloat16 LLM_ATTN_IMP=flash_attention_2 \
    python -m unittest test.modules.speech.asr.test_vita_asr.TestVITAASR.test_transcribe_stream
  • add tts_vita, unit and grpc client test
LLM_MODEL_NAME_OR_PATH=./models/VITA-MLLM/VITA-Audio-Plus-Vanilla \
    AUDIO_TOKENIZER_TYPE=sensevoice_glm4voice \
    FLOW_PATH=./models/THUDM/glm-4-voice-decoder \
    LLM_DEVICE=cuda LLM_TORCH_DTYPE=bfloat16 LLM_ATTN_IMP=flash_attention_2 \
    python -m unittest test.modules.speech.tts.test_vita.TestVITATTS.test_synthesize
TTS_TAG=tts_vita IS_SAVE=1 \
    LLM_MODEL_NAME_OR_PATH=./models/VITA-MLLM/VITA-Audio-Plus-Vanilla \
    AUDIO_TOKENIZER_TYPE=sensevoice_glm4voice \
    FLOW_PATH=./models/THUDM/glm-4-voice-decoder \
    LLM_DEVICE=cuda LLM_TORCH_DTYPE=bfloat16 LLM_ATTN_IMP=flash_attention_2 \
    python -m src.cmd.grpc.speaker.client
  • add livekit_asr_vita_voice_bot , livekit_vita_voice_bot and fastapi_webrtc_vita_voice_bot_serve run with modal
# run fastapi_webrtc_vita_voice_bot_serve
ACHATBOT_VERSION=0.0.11 IMAGE_CONCURRENT_CN=1 IMAGE_GPU=L40s modal serve src/fastapi_webrtc_vita_voice_bot_serve.py

speech -> text + speech | use Plus-Vanilla qwen2_mtp_sensevoice no mtp

# curl api to run chat room bot with webrtc (livekit room)
curl --location 'https://weedge-achatbot--fastapi-webrtc-vita-voice-bot-srv-app-dev.modal.run/bot_join/chat-room/LivekitVITAVoiceBot' \
--header 'Content-Type: application/json' \
--data '{
    "chat_bot_name": "LivekitVITAVoiceBot",
    "room_name": "chat-room",
    "room_url": "",
    "token": "",
    "room_manager": {
        "tag": "livekit_room",
        "args": {
            "bot_name": "LivekitVITAVoiceBot",
            "is_common_session": false
        }
    },
    "services": {
        "pipeline": "achatbot",
        "vad": "silero",
        "voice_llm": "llm_transformers_manual_vita_voice"
    },
    "config": {
        "vad": {
            "tag": "silero_vad_analyzer",
            "args": {
                "stop_secs": 0.7
            }
        },
        "voice_llm": {
            "tag": "llm_transformers_manual_vita_voice",
            "args": {
                "no_stream_sleep_time": 0.5,
                "lm_device": "cuda",
                "lm_torch_dtype": "bfloat16",
                "lm_attn_impl": "flash_attention_2",
                "warmup_steps": 1,
                "chat_history_size": 0,
                "audio_tokenizer_type": "sensevoice_glm4voice",
                "audio_tokenizer_model_path": null,
                "sense_voice_model_path": "/root/.achatbot/models/FunAudioLLM/SenseVoiceSmall",
                "flow_path": "/root/.achatbot/models/THUDM/glm-4-voice-decoder",
                "audio_tokenizer_rank": 0,
                "chunk_size_list": [8, 16, 25, 50, 100, 150, 200],
                "lm_model_name_or_path": "/root/.achatbot/models/VITA-MLLM/VITA-Audio-Plus-Boost"
            }
        }
    },
    "config_list": []
}'

speech -> text + speech | use Balance/Boost qwen2_mtp LLM with mtp

{
    "chat_bot_name": "LivekitVitaVoiceBot",
    "room_name": "chat-room",
    "room_url": "",
    "token": "",
    "room_manager": {
        "tag": "livekit_room",
        "args": {
            "bot_name": "LivekitVitaVoiceBot",
            "is_common_session": false
        }
    },
    "services": {
        "pipeline": "achatbot",
        "vad": "silero",
        "voice_llm": "llm_transformers_manual_vita_voice"
    },
    "config": {
        "vad": {
            "tag": "silero_vad_analyzer",
            "args": {
                "stop_secs": 0.7
            }
        },
        "voice_llm": {
            "tag": "llm_transformers_manual_vita_voice",
            "args": {
                "no_stream_sleep_time": 0.5,
                "lm_device": "cuda",
                "lm_torch_dtype": "bfloat16",
                "lm_attn_impl": "flash_attention_2",
                "warmup_steps": 1,
                "chat_history_size": 0,
                "audio_tokenizer_type": "glm4voice",
                "audio_tokenizer_model_path": "/root/.achatbot/models/THUDM/glm-4-voice-tokenizer",
                "sense_voice_model_path": null,
                "flow_path": "/root/.achatbot/models/THUDM/glm-4-voice-decoder",
                "audio_tokenizer_rank": 0,
                "chunk_size_list": [
                    8, 16, 25, 50, 100, 150, 200
                ],
                "lm_model_name_or_path": "/root/.achatbot/models/VITA-MLLM/VITA-Audio-Balance"
            }
        }
    },
    "config_list": []
}

asr + text -> text+speech | use qwen2_mtp_sensevoice LLM no mtp

curl --location 'https://weedge--fastapi-webrtc-vita-voice-bot-srv-app-dev.modal.run/bot_join/chat-room/LivekitAsrVITAVoiceBot' \
--header 'Content-Type: application/json' \
--data '{
  "chat_bot_name": "LivekitAsrVITAVoiceBot",
  "room_name": "chat-room",
  "room_url": "",
  "token": "",
  "room_manager": {
    "tag": "livekit_room",
    "args": {
      "bot_name": "LivekitAsrVITAVoiceBot",
      "is_common_session": false
    }
  },
  "services": {
    "pipeline": "achatbot",
    "vad": "silero",
    "asr": "sense_voice_asr",
    "voice_llm": "llm_transformers_manual_vita_text_voice"
  },
  "config": {
    "vad": {
      "tag": "silero_vad_analyzer",
      "args": {
        "stop_secs": 0.7
      }
    },
    "asr": {
      "args": {
        "language": "zn",
        "model_name_or_path": "/root/.achatbot/models/FunAudioLLM/SenseVoiceSmall"
      },
      "tag": "sense_voice_asr"
    },
    "voice_llm": {
      "tag": "llm_transformers_manual_vita_voice",
      "args": {
        "no_stream_sleep_time": 0.5,
        "lm_device": "cuda",
        "lm_torch_dtype": "bfloat16",
        "lm_attn_impl": "flash_attention_2",
        "warmup_steps": 1,
        "chat_history_size": 0,
        "audio_tokenizer_type": "sensevoice_glm4voice",
        "audio_tokenizer_model_path": null,
        "sense_voice_model_path": null,
        "flow_path": "/root/.achatbot/models/THUDM/glm-4-voice-decoder",
        "audio_tokenizer_rank": 0,
        "chunk_size_list": [8, 16, 25, 50, 100, 150, 200],
        "lm_model_name_or_path": "/root/.achatbot/models/VITA-MLLM/VITA-Audio-Plus-Boost"
      }
    }
  },
  "config_list": []
}
'

asr + text -> text+speech | use Balance/Boost qwen2_mtp LLM

{
    "chat_bot_name": "LivekitAsrVITAVoiceBot",
    "room_name": "chat-room",
    "room_url": "",
    "token": "",
    "room_manager": {
        "tag": "livekit_room",
        "args": {
            "bot_name": "LivekitAsrVITAVoiceBot",
            "is_common_session": false
        }
    },
    "services": {
        "pipeline": "achatbot",
        "vad": "silero",
        "asr": "sense_voice_asr",
        "voice_llm": "llm_transformers_manual_vita_text_voice"
    },
    "config": {
        "vad": {
            "tag": "silero_vad_analyzer",
            "args": {
                "stop_secs": 0.7
            }
        },
        "asr": {
            "args": {
                "language": "zn",
                "model_name_or_path": "/root/.achatbot/models/FunAudioLLM/SenseVoiceSmall"
            },
            "tag": "sense_voice_asr"
        },
        "voice_llm": {
            "tag": "llm_transformers_manual_vita_voice",
            "args": {
                "no_stream_sleep_time": 0.5,
                "lm_device": "cuda",
                "lm_torch_dtype": "bfloat16",
                "lm_attn_impl": "flash_attention_2",
                "warmup_steps": 1,
                "chat_history_size": 0,
                "audio_tokenizer_type": "glm4voice",
                "audio_tokenizer_model_path": null,
                "sense_voice_model_path": null,
                "flow_path": "/root/.achatbot/models/THUDM/glm-4-voice-decoder",
                "audio_tokenizer_rank": 0,
                "chunk_size_list": [
                    8, 16, 25, 50, 100, 150, 200
                ],
                "lm_model_name_or_path": "/root/.achatbot/models/VITA-MLLM/VITA-Audio-Balance"
            }
        }
    },
    "config_list": []
}

paper ai podcast: https://podcast-997.pages.dev/podcast/81f84087d1884716b2d6acec1273ccb1


colab: https://github.com/weedge/doraemon-nb/blob/main/achatbot_vita_audio.ipynb


Note

  • tex->text+speech chat QA任务中,生成的文本和音频一般,有时候答非所问;(可进行tex->text+speech chat QA任务 SFT)
  • tts任务中,一次生成长文本音频,boost/balance 生成的语音质量低,;(可将文本进行切分批量并发生成)
  • tts clone 任务暂未支持,需要训练音频模块;(论文的关注点主要在多模态模型引入mtp)

Tip

目标: (公开的训练和推理源码,应该是开源了中间过程,还可以加入数据集继续训练, 采用不同的audio tokenizer(encode) 和 llm 训练mtp模型(对于参数量大的模型比如dsv3可以使用LoRA的方式训练mtp模型), 但整体框架差不多)

  • 可插拔audio tokenizer model (encode) + llm(支持多模态tokens) + audio vocoder (decode) 进行训练, 以及推理
  • 训练推理加速引入mtp (推理速度优先的场景, 引入mtp模型性能低于纯主干模型性能, mtp 在 scaling 模型的宽度和深度 到大参数量的模型中表现比较好,低参数模型不太好)

llm 分为: 稠密模型 和 稀疏模型(moe)

audio tokenizer 分为:

  • 分开的 audio tokenizer encode 模型和 decode 模型
  • 单独的 audio codec 模型(支持encode, decode 重构)

工程化:

  • 如果将模型用于在线服务,对于推理工程架构来说,想充分利用计算存储资源,合理部署,或者复利LLM的推理优化,可以将音频相关模型单独训练部署。
  • 如果把语音模块网络模型结构和text LLM模型结构集成训练推理,需要进一步融合优化语音模块相关的算子到已有LLM优化结构中;或者可以对整体模块进行权重拆分部署(相对于单个整体omni模型结构部署,拆分会增加系统延迟)。

reference

Signed-off-by: weedge <weege007@gmail.com>
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hello @weedge, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

Summary of Changes

This pull request introduces VITA-Audio to the project, including necessary dependencies and a modal app for running audio-related tasks. It adds a new submodule for VITA-Audio and sets up the environment for tasks like speech-to-speech (s2s), automatic speech recognition (asr), and text-to-speech (tts). The modal app configuration includes GPU setup, model loading, and task execution functions.

Highlights

  • Submodule Addition: The .gitmodules file is updated to include VITA-Audio as a submodule, pointing to the weedge/VITA-Audio repository.
  • Modal App Implementation: A new modal app, vita_voice.py, is created to handle VITA-Audio tasks. It includes image configuration with CUDA, dependency installation, and task-specific functions.
  • Dependency Management: The modal image installs necessary dependencies such as git, ffmpeg, torch, torchaudio, and flash-attn to support audio processing and model execution.
  • Task Execution Functions: The vita_voice.py script defines functions for tokenizing, model dumping, and benchmarking various audio tasks (STS, ASR, TTS) with streaming and non-streaming options.
  • Inference Class: The S2SInference class is implemented to handle sequence-to-sequence inference, including loading models, tokenizers, and running different audio processing pipelines.

Changelog

  • .gitmodules
    • Added VITA-Audio as a submodule.
  • deploy/modal/src/llm/transformers/vita_voice.py
    • Created a new modal app for VITA-Audio tasks.
    • Configured the modal image with CUDA and necessary dependencies.
    • Implemented functions for tokenizing, model dumping, and benchmarking audio tasks.
    • Defined the S2SInference class for handling sequence-to-sequence inference.
  • deps/VITAAudio
    • Added a file that represents the VITA-Audio submodule commit.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in issue comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.


A voice from the code,
Transforms sound into words,
LLM sings softly.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

gemini-code-assist[bot]

This comment was marked as resolved.

@weedge
Copy link
Collaborator Author

weedge commented May 9, 2025

text token和audio token 交替生成
image
生成配置:

inference with mtp

按照 mtp_inference_mode 配置生成 奇数位为主干网络(28 x Qwen2DecoderLayers)生成text token数,偶数位为mtp(单个Qwen2DecoderLayer)生成audio token数。
比如:

  • Boost mtp_inference_mode:[1, 10, 4, 10] 表示首个text token由主干网络(28 x Qwen2DecoderLayers)生成,后续10个audio token 分别对应 第 (28 + mtp_idx) Qwen2DecoderLayer 来生成,0<mtp_idx<=10 ; 然后接着4个text token由主干网络(28 x Qwen2DecoderLayers)生成, 后续10个audio token 分别对应 第 (28 + mtp_idx) Qwen2DecoderLayer 来生成;后续一直交替生成。
  • Balance mtp_inference_mode: [1, 4, 3, 8, 4, 10] 类似 Boost 类推

Note

Qwen2DecoderLayer 越往后的层,对应生成hidden stats 包含的语义信息越好。(对于训练一个宽度+深度,参数量大的网络模型,使用mtp的效果会更好些)

代码:

# Boost/Balance mtp layer:  第 38-10 +  mtp_idx = 28+mtp_idx Qwen2DecoderLayer
layer_idxs=[self.config.num_hidden_layers - self.config.num_nextn_predict_layers + mtp_idx]

# 主干 qwen2: 28 x Qwen2DecoderLayer
layer_idxs=list(range(self.config.num_hidden_layers - self.config.num_nextn_predict_layers))

Qwen2MTP_AR_LLM_Boost

training or inference forward step: https://huggingface.co/VITA-MLLM/VITA-Audio-Boost/blob/main/modeling_qwen2.py#L781
model config: https://huggingface.co/VITA-MLLM/VITA-Audio-Boost/blob/main/config.json#L22

Qwen2MTP_AR_LLM_Boost 10317.912576 M parameters (qwen2 + 10 mtp (Linear(FC) projs + embed_norms + mtp_hidden_norms)

Qwen2MTPForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(168072, 3584)
    (layers): ModuleList(
      (0-37): 38 x Qwen2DecoderLayer(
        (self_attn): Qwen2Attention(
          (q_proj): Linear(in_features=3584, out_features=3584, bias=True)
          (k_proj): Linear(in_features=3584, out_features=512, bias=True)
          (v_proj): Linear(in_features=3584, out_features=512, bias=True)
          (o_proj): Linear(in_features=3584, out_features=3584, bias=False)
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=3584, out_features=18944, bias=False)
          (up_proj): Linear(in_features=3584, out_features=18944, bias=False)
          (down_proj): Linear(in_features=18944, out_features=3584, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((3584,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((3584,), eps=1e-06)
      )
    )
    (norm): Qwen2RMSNorm((3584,), eps=1e-06)
    (rotary_emb): Qwen2RotaryEmbedding()
  )
  (lm_head): Linear(in_features=3584, out_features=168072, bias=False)
  (mtp_projs): ModuleList(
    (0-9): 10 x Linear(in_features=7168, out_features=3584, bias=False)
  )
  (mtp_embed_norms): ModuleList(
    (0-9): 10 x Qwen2RMSNorm((3584,), eps=1e-06)
  )
  (mtp_hidden_norms): ModuleList(
    (0-9): 10 x Qwen2RMSNorm((3584,), eps=1e-06)
  )
)

Qwen2MTP_AR_LLM_Balance

training or inference forward step: https://huggingface.co/VITA-MLLM/VITA-Audio-Balance/blob/main/modeling_qwen2.py#L781
model config: https://huggingface.co/VITA-MLLM/VITA-Audio-Balance/blob/main/config.json#L22

Qwen2MTP_AR_LLM_Balance 10317.912576 M parameters (qwen2 + 10 mtp (Linear(FC) projs + embed_norms + mtp_hidden_norms)

Qwen2MTPForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(168072, 3584)
    (layers): ModuleList(
      (0-37): 38 x Qwen2DecoderLayer(
        (self_attn): Qwen2Attention(
          (q_proj): Linear(in_features=3584, out_features=3584, bias=True)
          (k_proj): Linear(in_features=3584, out_features=512, bias=True)
          (v_proj): Linear(in_features=3584, out_features=512, bias=True)
          (o_proj): Linear(in_features=3584, out_features=3584, bias=False)
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=3584, out_features=18944, bias=False)
          (up_proj): Linear(in_features=3584, out_features=18944, bias=False)
          (down_proj): Linear(in_features=18944, out_features=3584, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((3584,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((3584,), eps=1e-06)
      )
    )
    (norm): Qwen2RMSNorm((3584,), eps=1e-06)
    (rotary_emb): Qwen2RotaryEmbedding()
  )
  (lm_head): Linear(in_features=3584, out_features=168072, bias=False)
  (mtp_projs): ModuleList(
    (0-9): 10 x Linear(in_features=7168, out_features=3584, bias=False)
  )
  (mtp_embed_norms): ModuleList(
    (0-9): 10 x Qwen2RMSNorm((3584,), eps=1e-06)
  )
  (mtp_hidden_norms): ModuleList(
    (0-9): 10 x Qwen2RMSNorm((3584,), eps=1e-06)
  )
)

Qwen2MTP_AR_LLM_Vanilla

training or inference forward step: https://huggingface.co/VITA-MLLM/VITA-Audio-Plus-Vanilla/blob/main/modeling_qwen2.py#L834
model config: https://huggingface.co/VITA-MLLM/VITA-Audio-Plus-Vanilla/blob/main/config.json#L23 no mtp

Qwen2MTP_AR_LLM_Vanilla 7979.042111 M parameters (sensevoice small encoder(no use ctc header) + mlp(ResamplerProjector) + qwen2) no mtp

Qwen2MTPSenseVoiceForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(168072, 3584)
    (layers): ModuleList(
      (0-27): 28 x Qwen2DecoderLayer(
        (self_attn): Qwen2Attention(
          (q_proj): Linear(in_features=3584, out_features=3584, bias=True)
          (k_proj): Linear(in_features=3584, out_features=512, bias=True)
          (v_proj): Linear(in_features=3584, out_features=512, bias=True)
          (o_proj): Linear(in_features=3584, out_features=3584, bias=False)
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=3584, out_features=18944, bias=False)
          (up_proj): Linear(in_features=3584, out_features=18944, bias=False)
          (down_proj): Linear(in_features=18944, out_features=3584, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((3584,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((3584,), eps=1e-06)
      )
    )
    (norm): Qwen2RMSNorm((3584,), eps=1e-06)
    (rotary_emb): Qwen2RotaryEmbedding()
    (audio_model): AudioEncoder(
      (model): SenseVoiceSmall(
        (specaug): SpecAugLFR(
          (freq_mask): MaskAlongAxisLFR(mask_width_range=[0, 30], num_mask=1, axis=freq)
          (time_mask): MaskAlongAxisLFR(mask_width_range=[0, 12], num_mask=1, axis=time)
        )
        (encoder): SenseVoiceEncoderSmall(
          (embed): SinusoidalPositionEncoder()
          (encoders0): ModuleList(
            (0): EncoderLayerSANM(
              (self_attn): MultiHeadedAttentionSANM(
                (linear_out): Linear(in_features=512, out_features=512, bias=True)
                (linear_q_k_v): Linear(in_features=560, out_features=1536, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
                (fsmn_block): Conv1d(512, 512, kernel_size=(11,), stride=(1,), groups=512, bias=False)
                (pad_fn): ConstantPad1d(padding=(5, 5), value=0.0)
              )
              (feed_forward): PositionwiseFeedForward(
                (w_1): Linear(in_features=512, out_features=2048, bias=True)
                (w_2): Linear(in_features=2048, out_features=512, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
                (activation): ReLU()
              )
              (norm1): LayerNorm((560,), eps=1e-05, elementwise_affine=True)
              (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (encoders): ModuleList(
            (0-48): 49 x EncoderLayerSANM(
              (self_attn): MultiHeadedAttentionSANM(
                (linear_out): Linear(in_features=512, out_features=512, bias=True)
                (linear_q_k_v): Linear(in_features=512, out_features=1536, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
                (fsmn_block): Conv1d(512, 512, kernel_size=(11,), stride=(1,), groups=512, bias=False)
                (pad_fn): ConstantPad1d(padding=(5, 5), value=0.0)
              )
              (feed_forward): PositionwiseFeedForward(
                (w_1): Linear(in_features=512, out_features=2048, bias=True)
                (w_2): Linear(in_features=2048, out_features=512, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
                (activation): ReLU()
              )
              (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
              (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (tp_encoders): ModuleList(
            (0-19): 20 x EncoderLayerSANM(
              (self_attn): MultiHeadedAttentionSANM(
                (linear_out): Linear(in_features=512, out_features=512, bias=True)
                (linear_q_k_v): Linear(in_features=512, out_features=1536, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
                (fsmn_block): Conv1d(512, 512, kernel_size=(11,), stride=(1,), groups=512, bias=False)
                (pad_fn): ConstantPad1d(padding=(5, 5), value=0.0)
              )
              (feed_forward): PositionwiseFeedForward(
                (w_1): Linear(in_features=512, out_features=2048, bias=True)
                (w_2): Linear(in_features=2048, out_features=512, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
                (activation): ReLU()
              )
              (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
              (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
          (after_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (tp_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        )
        (ctc): CTC(
          (ctc_lo): Linear(in_features=512, out_features=25055, bias=True)
          (ctc_loss): CTCLoss()
        )
        (embed): Embedding(16, 560)
        (criterion_att): LabelSmoothingLoss(
          (criterion): KLDivLoss()
        )
      )
    )
    (audio_projection): ResamplerProjector(
      (pre_proj_layernorm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (mlp): Sequential(
        (0): Linear(in_features=512, out_features=3584, bias=False)
        (1): GELU(approximate='none')
        (2): Linear(in_features=3584, out_features=3584, bias=False)
      )
    )
  )
  (lm_head): Linear(in_features=3584, out_features=168072, bias=False)
  (mtp_projs): ModuleList()
  (mtp_embed_norms): ModuleList()
  (mtp_hidden_norms): ModuleList()
)

@weedge
Copy link
Collaborator Author

weedge commented May 9, 2025

train:

image

training with mtp

使用 text-audio-interval-ratio配置训练 前向推理时使用的 mtp_inference_mode

loss 值计算,用于后向梯度更新

主干网推理后计算loss

        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
            # loss = ForCausalLMLoss(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
            if self.training and torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
                with torch.no_grad():
                    logger.info(f"STP {loss=}")

mtp每层推理后计算loss

            for mtp_idx in range(self.config.num_nextn_predict_layers):

                # SFT with data packing
                if True:
                    mtp_mask = position_ids > mtp_idx
                    # input_ids = input_ids[mtp_mask].unsqueeze(0)
                    inputs_embeds = inputs_embeds[mtp_mask].unsqueeze(0)
                    if attention_mask is not None:
                        attention_mask = attention_mask[mtp_mask].unsqueeze(0)
                    if position_ids is not None:
                        position_ids = position_ids[mtp_mask].unsqueeze(0)
                    labels = labels[mtp_mask].unsqueeze(0)
                    kl_labels = kl_labels[mtp_mask].unsqueeze(0)

                    mtp_mask = torch.cat((mtp_mask[:, 1:], mtp_mask[:, :1]), dim=1)
                    hidden_states = hidden_states[mtp_mask].unsqueeze(0)

                    cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k = prepare_fa2_from_position_ids_for_mtp(position_ids, mtp_idx)
                    # kwargs["cu_seq_lens_q"] = cu_seq_lens_q
                    # kwargs["cu_seq_lens_k"] = cu_seq_lens_k
                    # kwargs["max_length_q"] = max_length_q
                    # kwargs["max_length_k"] = max_length_k

                    # print(f"{cu_seq_lens_q}")
                    # print(f"{cu_seq_lens_k}")
                    # print(f"{max_length_q}")
                    # print(f"{max_length_k}")

                mtp_outputs, _, mtp_loss = self.mtp_forward(
                    mtp_idx,
                    input_ids=None,
                    hidden_states=hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_values=past_key_values,
                    inputs_embeds=inputs_embeds,
                    labels=labels,
                    kl_labels=kl_labels,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    output_hidden_states=output_hidden_states,
                    return_dict=return_dict,
                    cache_position=cache_position,
                    num_logits_to_keep=num_logits_to_keep,
                    cu_seq_lens_q=cu_seq_lens_q,
                    cu_seq_lens_k=cu_seq_lens_k,
                    max_length_q=max_length_q,
                    max_length_k=max_length_k,
                    **kwargs,
                )

                loss += sum(mtp_loss) / self.config.num_nextn_predict_layers * self.config.mtp_loss_weight

                hidden_states = mtp_outputs.last_hidden_state

mtp loss:

        if labels is not None:
            loss = []
            # ce_loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
            ce_loss = ForCausalLMLoss(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)

            loss += [ce_loss]

            if False:
                kl_logits = logits.contiguous()
                kl_labels = kl_labels.contiguous()
                kl_loss = compute_kl_loss(kl_logits, kl_labels)

                kl_loss_weight = 1
                loss += [kl_loss_weight * kl_loss]

            if self.training and torch.distributed.is_initialized() and torch.distributed.get_rank() == 0:
                with torch.no_grad():
                    logger.info(f"\tMTP {mtp_idx=} {loss=}")

@weedge
Copy link
Collaborator Author

weedge commented May 10, 2025

sensevoice(连续 特征fbank提取) + glm4voice decoder (flow + hifi-gan)

Note

SenseVoiceEncoderSmall 已经集成到Vanilla中, 使用SenseVoiceSmall对音频编码 仅使用 funasr WavFrontend 进行fbank特征提取(连续)

image

SenseVoiceSmall args:

{
    "encoder": "SenseVoiceEncoderSmall",
    "encoder_conf": {
        "output_size": 512,
        "attention_heads": 4,
        "linear_units": 2048,
        "num_blocks": 50,
        "tp_blocks": 20,
        "dropout_rate": 0.1,
        "positional_dropout_rate": 0.1,
        "attention_dropout_rate": 0.1,
        "input_layer": "pe",
        "pos_enc_class": "SinusoidalPositionEncoder",
        "normalize_before": True,
        "kernel_size": 11,
        "sanm_shfit": 0,
        "selfattention_layer_type": "sanm",
    },
    "model": "SenseVoiceSmall",
    "model_conf": {"length_normalized_loss": True, "sos": 1, "eos": 2, "ignore_id": -1},
    "tokenizer": SentencepiecesTokenizer(
        model="/root/.achatbot/models/FunAudioLLM/SenseVoiceSmall/chn_jpn_yue_eng_ko_spectok.bpe.model"
    ),
    "tokenizer_conf": {
        "bpemodel": "/root/.achatbot/models/FunAudioLLM/SenseVoiceSmall/chn_jpn_yue_eng_ko_spectok.bpe.model",
        "unk_symbol": "<unk>",
        "split_with_space": True,
    },
    "frontend": WavFrontend(),
    "frontend_conf": {
        "fs": 16000,
        "window": "hamming",
        "n_mels": 80,
        "frame_length": 25,
        "frame_shift": 10,
        "lfr_m": 7,
        "lfr_n": 6,
        "cmvn_file": "/root/.achatbot/models/FunAudioLLM/SenseVoiceSmall/am.mvn",
    },
    "dataset": "SenseVoiceCTCDataset",
    "dataset_conf": {
        "index_ds": "IndexDSJsonl",
        "batch_sampler": "EspnetStyleBatchSampler",
        "data_split_num": 32,
        "batch_type": "token",
        "batch_size": 14000,
        "max_token_length": 2000,
        "min_token_length": 60,
        "max_source_length": 2000,
        "min_source_length": 60,
        "max_target_length": 200,
        "min_target_length": 0,
        "shuffle": True,
        "num_workers": 4,
        "sos": 1,
        "eos": 2,
        "IndexDSJsonl": "IndexDSJsonl",
        "retry": 20,
    },
    "train_conf": {
        "accum_grad": 1,
        "grad_clip": 5,
        "max_epoch": 20,
        "keep_nbest_models": 10,
        "avg_nbest_model": 10,
        "log_interval": 100,
        "resume": True,
        "validate_interval": 10000,
        "save_checkpoint_interval": 10000,
    },
    "optim": "adamw",
    "optim_conf": {"lr": 2e-05},
    "scheduler": "warmuplr",
    "scheduler_conf": {"warmup_steps": 25000},
    "specaug": "SpecAugLFR",
    "specaug_conf": {
        "apply_time_warp": False,
        "time_warp_window": 5,
        "time_warp_mode": "bicubic",
        "apply_freq_mask": True,
        "freq_mask_width_range": [0, 30],
        "lfr_rate": 6,
        "num_freq_mask": 1,
        "apply_time_mask": True,
        "time_mask_width_range": [0, 12],
        "num_time_mask": 1,
    },
    "init_param": "/root/.achatbot/models/FunAudioLLM/SenseVoiceSmall/model.pt",
    "config": "/root/.achatbot/models/FunAudioLLM/SenseVoiceSmall/config.yaml",
    "trust_remote_code": True,
    "device": "cuda:0",
    "model_path": "/root/.achatbot/models/FunAudioLLM/SenseVoiceSmall",
    "vocab_size": 25055,
    "token_list": None,
    "input_size": 560,
}

SenseVoiceSmall
audio_tokenizer.audio_encoder.sensevoice 233.999167 M parameters

SenseVoiceSmall(
  (specaug): SpecAugLFR(
    (freq_mask): MaskAlongAxisLFR(mask_width_range=[0, 30], num_mask=1, axis=freq)
    (time_mask): MaskAlongAxisLFR(mask_width_range=[0, 12], num_mask=1, axis=time)
  )
  (encoder): SenseVoiceEncoderSmall(
    (embed): SinusoidalPositionEncoder()
    (encoders0): ModuleList(
      (0): EncoderLayerSANM(
        (self_attn): MultiHeadedAttentionSANM(
          (linear_out): Linear(in_features=512, out_features=512, bias=True)
          (linear_q_k_v): Linear(in_features=560, out_features=1536, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (fsmn_block): Conv1d(512, 512, kernel_size=(11,), stride=(1,), groups=512, bias=False)
          (pad_fn): ConstantPad1d(padding=(5, 5), value=0.0)
        )
        (feed_forward): PositionwiseFeedForward(
          (w_1): Linear(in_features=512, out_features=2048, bias=True)
          (w_2): Linear(in_features=2048, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (activation): ReLU()
        )
        (norm1): LayerNorm((560,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
    (encoders): ModuleList(
      (0-48): 49 x EncoderLayerSANM(
        (self_attn): MultiHeadedAttentionSANM(
          (linear_out): Linear(in_features=512, out_features=512, bias=True)
          (linear_q_k_v): Linear(in_features=512, out_features=1536, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (fsmn_block): Conv1d(512, 512, kernel_size=(11,), stride=(1,), groups=512, bias=False)
          (pad_fn): ConstantPad1d(padding=(5, 5), value=0.0)
        )
        (feed_forward): PositionwiseFeedForward(
          (w_1): Linear(in_features=512, out_features=2048, bias=True)
          (w_2): Linear(in_features=2048, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (activation): ReLU()
        )
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
    (tp_encoders): ModuleList(
      (0-19): 20 x EncoderLayerSANM(
        (self_attn): MultiHeadedAttentionSANM(
          (linear_out): Linear(in_features=512, out_features=512, bias=True)
          (linear_q_k_v): Linear(in_features=512, out_features=1536, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (fsmn_block): Conv1d(512, 512, kernel_size=(11,), stride=(1,), groups=512, bias=False)
          (pad_fn): ConstantPad1d(padding=(5, 5), value=0.0)
        )
        (feed_forward): PositionwiseFeedForward(
          (w_1): Linear(in_features=512, out_features=2048, bias=True)
          (w_2): Linear(in_features=2048, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (activation): ReLU()
        )
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
    (after_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (tp_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (ctc): CTC(
    (ctc_lo): Linear(in_features=512, out_features=25055, bias=True)
    (ctc_loss): CTCLoss()
  )
  (embed): Embedding(16, 560)
  (criterion_att): LabelSmoothingLoss(
    (criterion): KLDivLoss()
  )
)

THUDM/glm-4-voice-decoder:
audio_tokenizer.audio_decoder.flow 111.166208 M parameters

MaskedDiffWithXvec(
  (input_embedding): Embedding(16384, 512)
  (spk_embed_affine_layer): Linear(in_features=192, out_features=80, bias=True)
  (encoder): BlockConformerEncoder(
    (embed): LinearNoSubsampling(
      (out): Sequential(
        (0): Linear(in_features=512, out_features=512, bias=True)
        (1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (2): Dropout(p=0.1, inplace=False)
      )
      (pos_enc): EspnetRelPositionalEncoding(
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
    (after_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (encoders): ModuleList(
      (0-5): 6 x ConformerEncoderLayer(
        (self_attn): BlockRelPositionMultiHeadedAttention(
          (linear_q): Linear(in_features=512, out_features=512, bias=True)
          (linear_k): Linear(in_features=512, out_features=512, bias=True)
          (linear_v): Linear(in_features=512, out_features=512, bias=True)
          (linear_out): Linear(in_features=512, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear_pos): Linear(in_features=512, out_features=512, bias=False)
        )
        (feed_forward): PositionwiseFeedForward(
          (w_1): Linear(in_features=512, out_features=2048, bias=True)
          (activation): SiLU()
          (dropout): Dropout(p=0.1, inplace=False)
          (w_2): Linear(in_features=2048, out_features=512, bias=True)
        )
        (norm_ff): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm_mha): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (encoder_proj): Linear(in_features=512, out_features=80, bias=True)
  (decoder): ConditionalCFM(
    (estimator): ConditionalDecoder(
      (time_embeddings): SinusoidalPosEmb()
      (time_mlp): TimestepEmbedding(
        (linear_1): Linear(in_features=320, out_features=1024, bias=True)
        (act): SiLU()
        (linear_2): Linear(in_features=1024, out_features=1024, bias=True)
      )
      (down_blocks): ModuleList(
        (0): ModuleList(
          (0): ResnetBlock1D(
            (mlp): Sequential(
              (0): Mish()
              (1): Linear(in_features=1024, out_features=256, bias=True)
            )
            (block1): Block1D(
              (block): Sequential(
                (0): Conv1d(320, 256, kernel_size=(3,), stride=(1,), padding=(1,))
                (1): GroupNorm(8, 256, eps=1e-05, affine=True)
                (2): Mish()
              )
            )
            (block2): Block1D(
              (block): Sequential(
                (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
                (1): GroupNorm(8, 256, eps=1e-05, affine=True)
                (2): Mish()
              )
            )
            (res_conv): Conv1d(320, 256, kernel_size=(1,), stride=(1,))
          )
          (1): ModuleList(
            (0-3): 4 x BasicTransformerBlock(
              (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
              (attn1): Attention(
                (to_q): Linear(in_features=256, out_features=512, bias=False)
                (to_k): Linear(in_features=256, out_features=512, bias=False)
                (to_v): Linear(in_features=256, out_features=512, bias=False)
                (to_out): ModuleList(
                  (0): Linear(in_features=512, out_features=256, bias=True)
                  (1): Dropout(p=0, inplace=False)
                )
              )
              (norm3): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
              (ff): FeedForward(
                (net): ModuleList(
                  (0): GELU(
                    (proj): Linear(in_features=256, out_features=1024, bias=True)
                  )
                  (1): Dropout(p=0, inplace=False)
                  (2): LoRACompatibleLinear(in_features=1024, out_features=256, bias=True)
                )
              )
            )
          )
          (2): Downsample1D(
            (conv): Conv1d(256, 256, kernel_size=(3,), stride=(2,), padding=(1,))
          )
        )
        (1): ModuleList(
          (0): ResnetBlock1D(
            (mlp): Sequential(
              (0): Mish()
              (1): Linear(in_features=1024, out_features=256, bias=True)
            )
            (block1): Block1D(
              (block): Sequential(
                (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
                (1): GroupNorm(8, 256, eps=1e-05, affine=True)
                (2): Mish()
              )
            )
            (block2): Block1D(
              (block): Sequential(
                (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
                (1): GroupNorm(8, 256, eps=1e-05, affine=True)
                (2): Mish()
              )
            )
            (res_conv): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
          )
          (1): ModuleList(
            (0-3): 4 x BasicTransformerBlock(
              (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
              (attn1): Attention(
                (to_q): Linear(in_features=256, out_features=512, bias=False)
                (to_k): Linear(in_features=256, out_features=512, bias=False)
                (to_v): Linear(in_features=256, out_features=512, bias=False)
                (to_out): ModuleList(
                  (0): Linear(in_features=512, out_features=256, bias=True)
                  (1): Dropout(p=0, inplace=False)
                )
              )
              (norm3): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
              (ff): FeedForward(
                (net): ModuleList(
                  (0): GELU(
                    (proj): Linear(in_features=256, out_features=1024, bias=True)
                  )
                  (1): Dropout(p=0, inplace=False)
                  (2): LoRACompatibleLinear(in_features=1024, out_features=256, bias=True)
                )
              )
            )
          )
          (2): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
        )
      )
      (mid_blocks): ModuleList(
        (0-11): 12 x ModuleList(
          (0): ResnetBlock1D(
            (mlp): Sequential(
              (0): Mish()
              (1): Linear(in_features=1024, out_features=256, bias=True)
            )
            (block1): Block1D(
              (block): Sequential(
                (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
                (1): GroupNorm(8, 256, eps=1e-05, affine=True)
                (2): Mish()
              )
            )
            (block2): Block1D(
              (block): Sequential(
                (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
                (1): GroupNorm(8, 256, eps=1e-05, affine=True)
                (2): Mish()
              )
            )
            (res_conv): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
          )
          (1): ModuleList(
            (0-3): 4 x BasicTransformerBlock(
              (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
              (attn1): Attention(
                (to_q): Linear(in_features=256, out_features=512, bias=False)
                (to_k): Linear(in_features=256, out_features=512, bias=False)
                (to_v): Linear(in_features=256, out_features=512, bias=False)
                (to_out): ModuleList(
                  (0): Linear(in_features=512, out_features=256, bias=True)
                  (1): Dropout(p=0, inplace=False)
                )
              )
              (norm3): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
              (ff): FeedForward(
                (net): ModuleList(
                  (0): GELU(
                    (proj): Linear(in_features=256, out_features=1024, bias=True)
                  )
                  (1): Dropout(p=0, inplace=False)
                  (2): LoRACompatibleLinear(in_features=1024, out_features=256, bias=True)
                )
              )
            )
          )
        )
      )
      (up_blocks): ModuleList(
        (0): ModuleList(
          (0): ResnetBlock1D(
            (mlp): Sequential(
              (0): Mish()
              (1): Linear(in_features=1024, out_features=256, bias=True)
            )
            (block1): Block1D(
              (block): Sequential(
                (0): Conv1d(512, 256, kernel_size=(3,), stride=(1,), padding=(1,))
                (1): GroupNorm(8, 256, eps=1e-05, affine=True)
                (2): Mish()
              )
            )
            (block2): Block1D(
              (block): Sequential(
                (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
                (1): GroupNorm(8, 256, eps=1e-05, affine=True)
                (2): Mish()
              )
            )
            (res_conv): Conv1d(512, 256, kernel_size=(1,), stride=(1,))
          )
          (1): ModuleList(
            (0-3): 4 x BasicTransformerBlock(
              (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
              (attn1): Attention(
                (to_q): Linear(in_features=256, out_features=512, bias=False)
                (to_k): Linear(in_features=256, out_features=512, bias=False)
                (to_v): Linear(in_features=256, out_features=512, bias=False)
                (to_out): ModuleList(
                  (0): Linear(in_features=512, out_features=256, bias=True)
                  (1): Dropout(p=0, inplace=False)
                )
              )
              (norm3): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
              (ff): FeedForward(
                (net): ModuleList(
                  (0): GELU(
                    (proj): Linear(in_features=256, out_features=1024, bias=True)
                  )
                  (1): Dropout(p=0, inplace=False)
                  (2): LoRACompatibleLinear(in_features=1024, out_features=256, bias=True)
                )
              )
            )
          )
          (2): Upsample1D(
            (conv): ConvTranspose1d(256, 256, kernel_size=(4,), stride=(2,), padding=(1,))
          )
        )
        (1): ModuleList(
          (0): ResnetBlock1D(
            (mlp): Sequential(
              (0): Mish()
              (1): Linear(in_features=1024, out_features=256, bias=True)
            )
            (block1): Block1D(
              (block): Sequential(
                (0): Conv1d(512, 256, kernel_size=(3,), stride=(1,), padding=(1,))
                (1): GroupNorm(8, 256, eps=1e-05, affine=True)
                (2): Mish()
              )
            )
            (block2): Block1D(
              (block): Sequential(
                (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
                (1): GroupNorm(8, 256, eps=1e-05, affine=True)
                (2): Mish()
              )
            )
            (res_conv): Conv1d(512, 256, kernel_size=(1,), stride=(1,))
          )
          (1): ModuleList(
            (0-3): 4 x BasicTransformerBlock(
              (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
              (attn1): Attention(
                (to_q): Linear(in_features=256, out_features=512, bias=False)
                (to_k): Linear(in_features=256, out_features=512, bias=False)
                (to_v): Linear(in_features=256, out_features=512, bias=False)
                (to_out): ModuleList(
                  (0): Linear(in_features=512, out_features=256, bias=True)
                  (1): Dropout(p=0, inplace=False)
                )
              )
              (norm3): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
              (ff): FeedForward(
                (net): ModuleList(
                  (0): GELU(
                    (proj): Linear(in_features=256, out_features=1024, bias=True)
                  )
                  (1): Dropout(p=0, inplace=False)
                  (2): LoRACompatibleLinear(in_features=1024, out_features=256, bias=True)
                )
              )
            )
          )
          (2): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
        )
      )
      (final_block): Block1D(
        (block): Sequential(
          (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
          (1): GroupNorm(8, 256, eps=1e-05, affine=True)
          (2): Mish()
        )
      )
      (final_proj): Conv1d(256, 80, kernel_size=(1,), stride=(1,))
    )
  )
  (length_regulator): InterpolateRegulator(
    (model): Sequential(
      (0): Conv1d(80, 80, kernel_size=(3,), stride=(1,), padding=(1,))
      (1): GroupNorm(1, 80, eps=1e-05, affine=True)
      (2): Mish()
      (3): Conv1d(80, 80, kernel_size=(3,), stride=(1,), padding=(1,))
      (4): GroupNorm(1, 80, eps=1e-05, affine=True)
      (5): Mish()
      (6): Conv1d(80, 80, kernel_size=(3,), stride=(1,), padding=(1,))
      (7): GroupNorm(1, 80, eps=1e-05, affine=True)
      (8): Mish()
      (9): Conv1d(80, 80, kernel_size=(3,), stride=(1,), padding=(1,))
      (10): GroupNorm(1, 80, eps=1e-05, affine=True)
      (11): Mish()
      (12): Conv1d(80, 80, kernel_size=(1,), stride=(1,))
    )
  )
)

THUDM/glm-4-voice-decoder:
audio_tokenizer.audio_decoder.hift 20.460591 M parameters

HiFTGenerator(
  (m_source): SourceModuleHnNSF(
    (l_sin_gen): SineGen()
    (l_linear): Linear(in_features=9, out_features=1, bias=True)
    (l_tanh): Tanh()
  )
  (f0_upsamp): Upsample(scale_factor=256.0, mode='nearest')
  (conv_pre): Conv1d(80, 512, kernel_size=(7,), stride=(1,), padding=(3,))
  (ups): ModuleList(
    (0): ConvTranspose1d(512, 256, kernel_size=(16,), stride=(8,), padding=(4,))
    (1): ConvTranspose1d(256, 128, kernel_size=(16,), stride=(8,), padding=(4,))
  )
  (source_downs): ModuleList(
    (0): Conv1d(18, 256, kernel_size=(np.int64(16),), stride=(np.int64(8),), padding=(np.int64(4),))
    (1): Conv1d(18, 128, kernel_size=(1,), stride=(1,))
  )
  (source_resblocks): ModuleList(
    (0): ResBlock(
      (convs1): ModuleList(
        (0): Conv1d(256, 256, kernel_size=(7,), stride=(1,), padding=(3,))
        (1): Conv1d(256, 256, kernel_size=(7,), stride=(1,), padding=(9,), dilation=(3,))
        (2): Conv1d(256, 256, kernel_size=(7,), stride=(1,), padding=(15,), dilation=(5,))
      )
      (convs2): ModuleList(
        (0-2): 3 x Conv1d(256, 256, kernel_size=(7,), stride=(1,), padding=(3,))
      )
      (activations1): ModuleList(
        (0-2): 3 x Snake()
      )
      (activations2): ModuleList(
        (0-2): 3 x Snake()
      )
    )
    (1): ResBlock(
      (convs1): ModuleList(
        (0): Conv1d(128, 128, kernel_size=(11,), stride=(1,), padding=(5,))
        (1): Conv1d(128, 128, kernel_size=(11,), stride=(1,), padding=(15,), dilation=(3,))
        (2): Conv1d(128, 128, kernel_size=(11,), stride=(1,), padding=(25,), dilation=(5,))
      )
      (convs2): ModuleList(
        (0-2): 3 x Conv1d(128, 128, kernel_size=(11,), stride=(1,), padding=(5,))
      )
      (activations1): ModuleList(
        (0-2): 3 x Snake()
      )
      (activations2): ModuleList(
        (0-2): 3 x Snake()
      )
    )
  )
  (resblocks): ModuleList(
    (0): ResBlock(
      (convs1): ModuleList(
        (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
        (1): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(3,), dilation=(3,))
        (2): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(5,), dilation=(5,))
      )
      (convs2): ModuleList(
        (0-2): 3 x Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
      )
      (activations1): ModuleList(
        (0-2): 3 x Snake()
      )
      (activations2): ModuleList(
        (0-2): 3 x Snake()
      )
    )
    (1): ResBlock(
      (convs1): ModuleList(
        (0): Conv1d(256, 256, kernel_size=(7,), stride=(1,), padding=(3,))
        (1): Conv1d(256, 256, kernel_size=(7,), stride=(1,), padding=(9,), dilation=(3,))
        (2): Conv1d(256, 256, kernel_size=(7,), stride=(1,), padding=(15,), dilation=(5,))
      )
      (convs2): ModuleList(
        (0-2): 3 x Conv1d(256, 256, kernel_size=(7,), stride=(1,), padding=(3,))
      )
      (activations1): ModuleList(
        (0-2): 3 x Snake()
      )
      (activations2): ModuleList(
        (0-2): 3 x Snake()
      )
    )
    (2): ResBlock(
      (convs1): ModuleList(
        (0): Conv1d(256, 256, kernel_size=(11,), stride=(1,), padding=(5,))
        (1): Conv1d(256, 256, kernel_size=(11,), stride=(1,), padding=(15,), dilation=(3,))
        (2): Conv1d(256, 256, kernel_size=(11,), stride=(1,), padding=(25,), dilation=(5,))
      )
      (convs2): ModuleList(
        (0-2): 3 x Conv1d(256, 256, kernel_size=(11,), stride=(1,), padding=(5,))
      )
      (activations1): ModuleList(
        (0-2): 3 x Snake()
      )
      (activations2): ModuleList(
        (0-2): 3 x Snake()
      )
    )
    (3): ResBlock(
      (convs1): ModuleList(
        (0): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,))
        (1): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(3,), dilation=(3,))
        (2): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(5,), dilation=(5,))
      )
      (convs2): ModuleList(
        (0-2): 3 x Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,))
      )
      (activations1): ModuleList(
        (0-2): 3 x Snake()
      )
      (activations2): ModuleList(
        (0-2): 3 x Snake()
      )
    )
    (4): ResBlock(
      (convs1): ModuleList(
        (0): Conv1d(128, 128, kernel_size=(7,), stride=(1,), padding=(3,))
        (1): Conv1d(128, 128, kernel_size=(7,), stride=(1,), padding=(9,), dilation=(3,))
        (2): Conv1d(128, 128, kernel_size=(7,), stride=(1,), padding=(15,), dilation=(5,))
      )
      (convs2): ModuleList(
        (0-2): 3 x Conv1d(128, 128, kernel_size=(7,), stride=(1,), padding=(3,))
      )
      (activations1): ModuleList(
        (0-2): 3 x Snake()
      )
      (activations2): ModuleList(
        (0-2): 3 x Snake()
      )
    )
    (5): ResBlock(
      (convs1): ModuleList(
        (0): Conv1d(128, 128, kernel_size=(11,), stride=(1,), padding=(5,))
        (1): Conv1d(128, 128, kernel_size=(11,), stride=(1,), padding=(15,), dilation=(3,))
        (2): Conv1d(128, 128, kernel_size=(11,), stride=(1,), padding=(25,), dilation=(5,))
      )
      (convs2): ModuleList(
        (0-2): 3 x Conv1d(128, 128, kernel_size=(11,), stride=(1,), padding=(5,))
      )
      (activations1): ModuleList(
        (0-2): 3 x Snake()
      )
      (activations2): ModuleList(
        (0-2): 3 x Snake()
      )
    )
  )
  (conv_post): Conv1d(128, 18, kernel_size=(7,), stride=(1,), padding=(3,))
  (reflection_pad): ReflectionPad1d((1, 0))
  (f0_predictor): ConvRNNF0Predictor(
    (condnet): Sequential(
      (0): Conv1d(80, 512, kernel_size=(3,), stride=(1,), padding=(1,))
      (1): ELU(alpha=1.0)
      (2): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(1,))
      (3): ELU(alpha=1.0)
      (4): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(1,))
      (5): ELU(alpha=1.0)
      (6): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(1,))
      (7): ELU(alpha=1.0)
      (8): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(1,))
      (9): ELU(alpha=1.0)
    )
    (classifier): Linear(in_features=512, out_features=1, bias=True)
  )
)

@weedge
Copy link
Collaborator Author

weedge commented May 10, 2025

glm4voice tokenizer + glm4voice decoder

THUDM/glm-4-voice-tokenizer:
audio_tokenizer.audio_encoder.whisper-encoder+vq 343.59936 M parameters (vq 离散token)

WhisperVQEncoder(
  (conv1): CausalConv1d(128, 1280, kernel_size=(3,), stride=(1,))
  (conv2): CausalConv1d(1280, 1280, kernel_size=(3,), stride=(2,))
  (embed_positions): Embedding(1500, 1280)
  (layers): ModuleList(
    (0-15): 16 x WhisperVQEncoderLayer(
      (self_attn): WhisperSdpaAttention(
        (k_proj): Linear(in_features=1280, out_features=1280, bias=False)
        (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
      )
      (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (activation_fn): GELUActivation()
      (fc1): Linear(in_features=1280, out_features=5120, bias=True)
      (fc2): Linear(in_features=5120, out_features=1280, bias=True)
      (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
    )
  )
  (pooling_layer): AvgPool1d(kernel_size=(4,), stride=(4,), padding=(0,))
  (codebook): Embedding(16384, 1280)
  (embed_positions2): Embedding(375, 1280)
)

THUDM/glm-4-voice-decoder:
audio_tokenizer.audio_decoder.flow 111.166208 M parameters

MaskedDiffWithXvec(
  (input_embedding): Embedding(16384, 512)
  (spk_embed_affine_layer): Linear(in_features=192, out_features=80, bias=True)
  (encoder): BlockConformerEncoder(
    (embed): LinearNoSubsampling(
      (out): Sequential(
        (0): Linear(in_features=512, out_features=512, bias=True)
        (1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (2): Dropout(p=0.1, inplace=False)
      )
      (pos_enc): EspnetRelPositionalEncoding(
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
    (after_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (encoders): ModuleList(
      (0-5): 6 x ConformerEncoderLayer(
        (self_attn): BlockRelPositionMultiHeadedAttention(
          (linear_q): Linear(in_features=512, out_features=512, bias=True)
          (linear_k): Linear(in_features=512, out_features=512, bias=True)
          (linear_v): Linear(in_features=512, out_features=512, bias=True)
          (linear_out): Linear(in_features=512, out_features=512, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear_pos): Linear(in_features=512, out_features=512, bias=False)
        )
        (feed_forward): PositionwiseFeedForward(
          (w_1): Linear(in_features=512, out_features=2048, bias=True)
          (activation): SiLU()
          (dropout): Dropout(p=0.1, inplace=False)
          (w_2): Linear(in_features=2048, out_features=512, bias=True)
        )
        (norm_ff): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm_mha): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (encoder_proj): Linear(in_features=512, out_features=80, bias=True)
  (decoder): ConditionalCFM(
    (estimator): ConditionalDecoder(
      (time_embeddings): SinusoidalPosEmb()
      (time_mlp): TimestepEmbedding(
        (linear_1): Linear(in_features=320, out_features=1024, bias=True)
        (act): SiLU()
        (linear_2): Linear(in_features=1024, out_features=1024, bias=True)
      )
      (down_blocks): ModuleList(
        (0): ModuleList(
          (0): ResnetBlock1D(
            (mlp): Sequential(
              (0): Mish()
              (1): Linear(in_features=1024, out_features=256, bias=True)
            )
            (block1): Block1D(
              (block): Sequential(
                (0): Conv1d(320, 256, kernel_size=(3,), stride=(1,), padding=(1,))
                (1): GroupNorm(8, 256, eps=1e-05, affine=True)
                (2): Mish()
              )
            )
            (block2): Block1D(
              (block): Sequential(
                (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
                (1): GroupNorm(8, 256, eps=1e-05, affine=True)
                (2): Mish()
              )
            )
            (res_conv): Conv1d(320, 256, kernel_size=(1,), stride=(1,))
          )
          (1): ModuleList(
            (0-3): 4 x BasicTransformerBlock(
              (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
              (attn1): Attention(
                (to_q): Linear(in_features=256, out_features=512, bias=False)
                (to_k): Linear(in_features=256, out_features=512, bias=False)
                (to_v): Linear(in_features=256, out_features=512, bias=False)
                (to_out): ModuleList(
                  (0): Linear(in_features=512, out_features=256, bias=True)
                  (1): Dropout(p=0, inplace=False)
                )
              )
              (norm3): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
              (ff): FeedForward(
                (net): ModuleList(
                  (0): GELU(
                    (proj): Linear(in_features=256, out_features=1024, bias=True)
                  )
                  (1): Dropout(p=0, inplace=False)
                  (2): LoRACompatibleLinear(in_features=1024, out_features=256, bias=True)
                )
              )
            )
          )
          (2): Downsample1D(
            (conv): Conv1d(256, 256, kernel_size=(3,), stride=(2,), padding=(1,))
          )
        )
        (1): ModuleList(
          (0): ResnetBlock1D(
            (mlp): Sequential(
              (0): Mish()
              (1): Linear(in_features=1024, out_features=256, bias=True)
            )
            (block1): Block1D(
              (block): Sequential(
                (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
                (1): GroupNorm(8, 256, eps=1e-05, affine=True)
                (2): Mish()
              )
            )
            (block2): Block1D(
              (block): Sequential(
                (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
                (1): GroupNorm(8, 256, eps=1e-05, affine=True)
                (2): Mish()
              )
            )
            (res_conv): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
          )
          (1): ModuleList(
            (0-3): 4 x BasicTransformerBlock(
              (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
              (attn1): Attention(
                (to_q): Linear(in_features=256, out_features=512, bias=False)
                (to_k): Linear(in_features=256, out_features=512, bias=False)
                (to_v): Linear(in_features=256, out_features=512, bias=False)
                (to_out): ModuleList(
                  (0): Linear(in_features=512, out_features=256, bias=True)
                  (1): Dropout(p=0, inplace=False)
                )
              )
              (norm3): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
              (ff): FeedForward(
                (net): ModuleList(
                  (0): GELU(
                    (proj): Linear(in_features=256, out_features=1024, bias=True)
                  )
                  (1): Dropout(p=0, inplace=False)
                  (2): LoRACompatibleLinear(in_features=1024, out_features=256, bias=True)
                )
              )
            )
          )
          (2): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
        )
      )
      (mid_blocks): ModuleList(
        (0-11): 12 x ModuleList(
          (0): ResnetBlock1D(
            (mlp): Sequential(
              (0): Mish()
              (1): Linear(in_features=1024, out_features=256, bias=True)
            )
            (block1): Block1D(
              (block): Sequential(
                (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
                (1): GroupNorm(8, 256, eps=1e-05, affine=True)
                (2): Mish()
              )
            )
            (block2): Block1D(
              (block): Sequential(
                (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
                (1): GroupNorm(8, 256, eps=1e-05, affine=True)
                (2): Mish()
              )
            )
            (res_conv): Conv1d(256, 256, kernel_size=(1,), stride=(1,))
          )
          (1): ModuleList(
            (0-3): 4 x BasicTransformerBlock(
              (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
              (attn1): Attention(
                (to_q): Linear(in_features=256, out_features=512, bias=False)
                (to_k): Linear(in_features=256, out_features=512, bias=False)
                (to_v): Linear(in_features=256, out_features=512, bias=False)
                (to_out): ModuleList(
                  (0): Linear(in_features=512, out_features=256, bias=True)
                  (1): Dropout(p=0, inplace=False)
                )
              )
              (norm3): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
              (ff): FeedForward(
                (net): ModuleList(
                  (0): GELU(
                    (proj): Linear(in_features=256, out_features=1024, bias=True)
                  )
                  (1): Dropout(p=0, inplace=False)
                  (2): LoRACompatibleLinear(in_features=1024, out_features=256, bias=True)
                )
              )
            )
          )
        )
      )
      (up_blocks): ModuleList(
        (0): ModuleList(
          (0): ResnetBlock1D(
            (mlp): Sequential(
              (0): Mish()
              (1): Linear(in_features=1024, out_features=256, bias=True)
            )
            (block1): Block1D(
              (block): Sequential(
                (0): Conv1d(512, 256, kernel_size=(3,), stride=(1,), padding=(1,))
                (1): GroupNorm(8, 256, eps=1e-05, affine=True)
                (2): Mish()
              )
            )
            (block2): Block1D(
              (block): Sequential(
                (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
                (1): GroupNorm(8, 256, eps=1e-05, affine=True)
                (2): Mish()
              )
            )
            (res_conv): Conv1d(512, 256, kernel_size=(1,), stride=(1,))
          )
          (1): ModuleList(
            (0-3): 4 x BasicTransformerBlock(
              (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
              (attn1): Attention(
                (to_q): Linear(in_features=256, out_features=512, bias=False)
                (to_k): Linear(in_features=256, out_features=512, bias=False)
                (to_v): Linear(in_features=256, out_features=512, bias=False)
                (to_out): ModuleList(
                  (0): Linear(in_features=512, out_features=256, bias=True)
                  (1): Dropout(p=0, inplace=False)
                )
              )
              (norm3): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
              (ff): FeedForward(
                (net): ModuleList(
                  (0): GELU(
                    (proj): Linear(in_features=256, out_features=1024, bias=True)
                  )
                  (1): Dropout(p=0, inplace=False)
                  (2): LoRACompatibleLinear(in_features=1024, out_features=256, bias=True)
                )
              )
            )
          )
          (2): Upsample1D(
            (conv): ConvTranspose1d(256, 256, kernel_size=(4,), stride=(2,), padding=(1,))
          )
        )
        (1): ModuleList(
          (0): ResnetBlock1D(
            (mlp): Sequential(
              (0): Mish()
              (1): Linear(in_features=1024, out_features=256, bias=True)
            )
            (block1): Block1D(
              (block): Sequential(
                (0): Conv1d(512, 256, kernel_size=(3,), stride=(1,), padding=(1,))
                (1): GroupNorm(8, 256, eps=1e-05, affine=True)
                (2): Mish()
              )
            )
            (block2): Block1D(
              (block): Sequential(
                (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
                (1): GroupNorm(8, 256, eps=1e-05, affine=True)
                (2): Mish()
              )
            )
            (res_conv): Conv1d(512, 256, kernel_size=(1,), stride=(1,))
          )
          (1): ModuleList(
            (0-3): 4 x BasicTransformerBlock(
              (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
              (attn1): Attention(
                (to_q): Linear(in_features=256, out_features=512, bias=False)
                (to_k): Linear(in_features=256, out_features=512, bias=False)
                (to_v): Linear(in_features=256, out_features=512, bias=False)
                (to_out): ModuleList(
                  (0): Linear(in_features=512, out_features=256, bias=True)
                  (1): Dropout(p=0, inplace=False)
                )
              )
              (norm3): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
              (ff): FeedForward(
                (net): ModuleList(
                  (0): GELU(
                    (proj): Linear(in_features=256, out_features=1024, bias=True)
                  )
                  (1): Dropout(p=0, inplace=False)
                  (2): LoRACompatibleLinear(in_features=1024, out_features=256, bias=True)
                )
              )
            )
          )
          (2): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
        )
      )
      (final_block): Block1D(
        (block): Sequential(
          (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
          (1): GroupNorm(8, 256, eps=1e-05, affine=True)
          (2): Mish()
        )
      )
      (final_proj): Conv1d(256, 80, kernel_size=(1,), stride=(1,))
    )
  )
  (length_regulator): InterpolateRegulator(
    (model): Sequential(
      (0): Conv1d(80, 80, kernel_size=(3,), stride=(1,), padding=(1,))
      (1): GroupNorm(1, 80, eps=1e-05, affine=True)
      (2): Mish()
      (3): Conv1d(80, 80, kernel_size=(3,), stride=(1,), padding=(1,))
      (4): GroupNorm(1, 80, eps=1e-05, affine=True)
      (5): Mish()
      (6): Conv1d(80, 80, kernel_size=(3,), stride=(1,), padding=(1,))
      (7): GroupNorm(1, 80, eps=1e-05, affine=True)
      (8): Mish()
      (9): Conv1d(80, 80, kernel_size=(3,), stride=(1,), padding=(1,))
      (10): GroupNorm(1, 80, eps=1e-05, affine=True)
      (11): Mish()
      (12): Conv1d(80, 80, kernel_size=(1,), stride=(1,))
    )
  )
)

THUDM/glm-4-voice-decoder:
audio_tokenizer.audio_decoder.hift 20.460591 M parameters

HiFTGenerator(
  (m_source): SourceModuleHnNSF(
    (l_sin_gen): SineGen()
    (l_linear): Linear(in_features=9, out_features=1, bias=True)
    (l_tanh): Tanh()
  )
  (f0_upsamp): Upsample(scale_factor=256.0, mode='nearest')
  (conv_pre): Conv1d(80, 512, kernel_size=(7,), stride=(1,), padding=(3,))
  (ups): ModuleList(
    (0): ConvTranspose1d(512, 256, kernel_size=(16,), stride=(8,), padding=(4,))
    (1): ConvTranspose1d(256, 128, kernel_size=(16,), stride=(8,), padding=(4,))
  )
  (source_downs): ModuleList(
    (0): Conv1d(18, 256, kernel_size=(np.int64(16),), stride=(np.int64(8),), padding=(np.int64(4),))
    (1): Conv1d(18, 128, kernel_size=(1,), stride=(1,))
  )
  (source_resblocks): ModuleList(
    (0): ResBlock(
      (convs1): ModuleList(
        (0): Conv1d(256, 256, kernel_size=(7,), stride=(1,), padding=(3,))
        (1): Conv1d(256, 256, kernel_size=(7,), stride=(1,), padding=(9,), dilation=(3,))
        (2): Conv1d(256, 256, kernel_size=(7,), stride=(1,), padding=(15,), dilation=(5,))
      )
      (convs2): ModuleList(
        (0-2): 3 x Conv1d(256, 256, kernel_size=(7,), stride=(1,), padding=(3,))
      )
      (activations1): ModuleList(
        (0-2): 3 x Snake()
      )
      (activations2): ModuleList(
        (0-2): 3 x Snake()
      )
    )
    (1): ResBlock(
      (convs1): ModuleList(
        (0): Conv1d(128, 128, kernel_size=(11,), stride=(1,), padding=(5,))
        (1): Conv1d(128, 128, kernel_size=(11,), stride=(1,), padding=(15,), dilation=(3,))
        (2): Conv1d(128, 128, kernel_size=(11,), stride=(1,), padding=(25,), dilation=(5,))
      )
      (convs2): ModuleList(
        (0-2): 3 x Conv1d(128, 128, kernel_size=(11,), stride=(1,), padding=(5,))
      )
      (activations1): ModuleList(
        (0-2): 3 x Snake()
      )
      (activations2): ModuleList(
        (0-2): 3 x Snake()
      )
    )
  )
  (resblocks): ModuleList(
    (0): ResBlock(
      (convs1): ModuleList(
        (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
        (1): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(3,), dilation=(3,))
        (2): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(5,), dilation=(5,))
      )
      (convs2): ModuleList(
        (0-2): 3 x Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
      )
      (activations1): ModuleList(
        (0-2): 3 x Snake()
      )
      (activations2): ModuleList(
        (0-2): 3 x Snake()
      )
    )
    (1): ResBlock(
      (convs1): ModuleList(
        (0): Conv1d(256, 256, kernel_size=(7,), stride=(1,), padding=(3,))
        (1): Conv1d(256, 256, kernel_size=(7,), stride=(1,), padding=(9,), dilation=(3,))
        (2): Conv1d(256, 256, kernel_size=(7,), stride=(1,), padding=(15,), dilation=(5,))
      )
      (convs2): ModuleList(
        (0-2): 3 x Conv1d(256, 256, kernel_size=(7,), stride=(1,), padding=(3,))
      )
      (activations1): ModuleList(
        (0-2): 3 x Snake()
      )
      (activations2): ModuleList(
        (0-2): 3 x Snake()
      )
    )
    (2): ResBlock(
      (convs1): ModuleList(
        (0): Conv1d(256, 256, kernel_size=(11,), stride=(1,), padding=(5,))
        (1): Conv1d(256, 256, kernel_size=(11,), stride=(1,), padding=(15,), dilation=(3,))
        (2): Conv1d(256, 256, kernel_size=(11,), stride=(1,), padding=(25,), dilation=(5,))
      )
      (convs2): ModuleList(
        (0-2): 3 x Conv1d(256, 256, kernel_size=(11,), stride=(1,), padding=(5,))
      )
      (activations1): ModuleList(
        (0-2): 3 x Snake()
      )
      (activations2): ModuleList(
        (0-2): 3 x Snake()
      )
    )
    (3): ResBlock(
      (convs1): ModuleList(
        (0): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,))
        (1): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(3,), dilation=(3,))
        (2): Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(5,), dilation=(5,))
      )
      (convs2): ModuleList(
        (0-2): 3 x Conv1d(128, 128, kernel_size=(3,), stride=(1,), padding=(1,))
      )
      (activations1): ModuleList(
        (0-2): 3 x Snake()
      )
      (activations2): ModuleList(
        (0-2): 3 x Snake()
      )
    )
    (4): ResBlock(
      (convs1): ModuleList(
        (0): Conv1d(128, 128, kernel_size=(7,), stride=(1,), padding=(3,))
        (1): Conv1d(128, 128, kernel_size=(7,), stride=(1,), padding=(9,), dilation=(3,))
        (2): Conv1d(128, 128, kernel_size=(7,), stride=(1,), padding=(15,), dilation=(5,))
      )
      (convs2): ModuleList(
        (0-2): 3 x Conv1d(128, 128, kernel_size=(7,), stride=(1,), padding=(3,))
      )
      (activations1): ModuleList(
        (0-2): 3 x Snake()
      )
      (activations2): ModuleList(
        (0-2): 3 x Snake()
      )
    )
    (5): ResBlock(
      (convs1): ModuleList(
        (0): Conv1d(128, 128, kernel_size=(11,), stride=(1,), padding=(5,))
        (1): Conv1d(128, 128, kernel_size=(11,), stride=(1,), padding=(15,), dilation=(3,))
        (2): Conv1d(128, 128, kernel_size=(11,), stride=(1,), padding=(25,), dilation=(5,))
      )
      (convs2): ModuleList(
        (0-2): 3 x Conv1d(128, 128, kernel_size=(11,), stride=(1,), padding=(5,))
      )
      (activations1): ModuleList(
        (0-2): 3 x Snake()
      )
      (activations2): ModuleList(
        (0-2): 3 x Snake()
      )
    )
  )
  (conv_post): Conv1d(128, 18, kernel_size=(7,), stride=(1,), padding=(3,))
  (reflection_pad): ReflectionPad1d((1, 0))
  (f0_predictor): ConvRNNF0Predictor(
    (condnet): Sequential(
      (0): Conv1d(80, 512, kernel_size=(3,), stride=(1,), padding=(1,))
      (1): ELU(alpha=1.0)
      (2): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(1,))
      (3): ELU(alpha=1.0)
      (4): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(1,))
      (5): ELU(alpha=1.0)
      (6): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(1,))
      (7): ELU(alpha=1.0)
      (8): Conv1d(512, 512, kernel_size=(3,), stride=(1,), padding=(1,))
      (9): ELU(alpha=1.0)
    )
    (classifier): Linear(in_features=512, out_features=1, bias=True)
  )
)

weedge added 4 commits May 10, 2025 19:30
…evoice_glm4voice(sensevoice WavFrontend, decoder)

Signed-off-by: weedge <weege007@gmail.com>
Signed-off-by: weedge <weege007@gmail.com>
…ita_audio_asr llm_transformers_manual_vita_tts llm_transformers_manual_vita_text_voice llm_transformers_manual_vita_voice

Signed-off-by: weedge <weege007@gmail.com>
@weedge weedge added TTS ASR AR Flow modal voice A1-T2A2 (speech)-to-(text and speech) vocoder T1-T2A2 (text)-to-(text and speech) labels May 11, 2025
weedge added 10 commits May 11, 2025 20:55
Signed-off-by: weedge <weege007@gmail.com>
Signed-off-by: weedge <weege007@gmail.com>
Signed-off-by: weedge <weege007@gmail.com>
Signed-off-by: weedge <weege007@gmail.com>
Signed-off-by: weedge <weege007@gmail.com>
Signed-off-by: weedge <weege007@gmail.com>
Signed-off-by: weedge <weege007@gmail.com>
Signed-off-by: weedge <weege007@gmail.com>
Signed-off-by: weedge <weege007@gmail.com>
Signed-off-by: weedge <weege007@gmail.com>
@weedge weedge added the MTP label May 12, 2025
Signed-off-by: weedge <weege007@gmail.com>
@weedge
Copy link
Collaborator Author

weedge commented May 15, 2025

training script args


ModelArguments(
    model_name_or_path="/models/Qwen/Qwen2.5-7B-Instruct",
    model_type=None,
    config_overrides=None,
    config_name="/models/Qwen/Qwen2.5-7B-Instruct",
    tokenizer_name="/models/Qwen/Qwen2.5-7B-Instruct",
    cache_dir=None,
    use_fast_tokenizer=True,
    model_revision="main",
    token=None,
    trust_remote_code=False,
    torch_dtype="bfloat16",
    low_cpu_mem_usage=False,
    attn_implementation="flash_attention_2",
    audio_tokenizer_path="/models/THUDM/glm-4-voice-tokenizer",
    audio_tokenizer_type="glm4voice",
    text_audio_interval_ratio=[1, 10, 4, 10],
    audio_model_freeze=False,
    vision_model_name_or_path=None,
    vision_model_type=None,
    vision_model_freeze=False,
    language_model_freeze=False,
    vision_projector_type="mlp",
    vision_projector_pre_norm=False,
    vision_downsample_ratio=0.5,
    image_size=448,
    image_token_length=1025,
    max_num_frame=16,
    max_fps=1,
    min_patch_grid=1,
    max_patch_grid=12,
    vision_process_type="dynamic",
    vision_normalize_type="imagenet",
    model_max_length=8192,
    A=0,
    B=None,
    C=False,
)

DataTrainingArguments(
    dataset_name="/VITA-Audio/configs/sts_finetune_stage1_test.yaml",
    dataset_config_name=None,
    train_file=None,
    validation_file=None,
    max_train_samples=None,
    max_eval_samples=None,
    streaming=False,
    block_size=None,
    overwrite_cache=False,
    validation_split_percentage=5,
    preprocessing_num_workers=None,
    keep_linebreaks=True,
    create_attention_mask=False,
    create_attention_mask_2d=False,
    reset_position_ids=True,
    reset_attention_mask=True,
    cross_dataset_joint=False,
    dataset_joint=True,
    variable_length=False,
    D=0,
    E=None,
    F=False,
)

TrainingArguments(output_dir='/train_output/finetune_glm4voice_stage1',
    overwrite_output_dir=True,
    do_train=True,
    do_eval=False,
    do_predict=False,
    eval_strategy=<IntervalStrategy.NO: 'no'>,
    prediction_loss_only=False,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    per_gpu_train_batch_size=None,
    per_gpu_eval_batch_size=None,
    gradient_accumulation_steps=16,
    eval_accumulation_steps=None,
    eval_delay=0,
    torch_empty_cache_steps=None,
    learning_rate=6e-05,
    weight_decay=0.0,
    adam_beta1=0.9,
    adam_beta2=0.95,
    adam_epsilon=1e-08,
    max_grad_norm=1.0,
    num_train_epochs=1.0,
    max_steps=8000,
    lr_scheduler_type=<SchedulerType.COSINE: 'cosine'>,
    lr_scheduler_kwargs={},
    warmup_ratio=0.03,
    warmup_steps=0,
    log_level='info',
    log_level_replica='warning',
    log_on_each_node=True,
    logging_dir='/train_output/finetune_glm4voice_stage1/runs/May15_07-29-43_modal',
    logging_strategy=<IntervalStrategy.STEPS: 'steps'>,
    logging_first_step=False,
    logging_steps=1.0,
    logging_nan_inf_filter=True,
    save_strategy=<SaveStrategy.STEPS: 'steps'>,
    save_steps=0.1,
    save_total_limit=2,
    save_safetensors=True,
    save_on_each_node=False,
    save_only_model=False,
    restore_callback_states_from_checkpoint=False,
    no_cuda=False,
    use_cpu=False,
    use_mps_device=False,
    seed=42,
    data_seed=42,
    jit_mode_eval=False,
    use_ipex=False,
    bf16=True,
    fp16=False,
    fp16_opt_level='O1',
    half_precision_backend='auto',
    bf16_full_eval=False,
    fp16_full_eval=False,
    tf32=True,
    local_rank=0,
    ddp_backend='nccl',
    tpu_num_cores=None,
    tpu_metrics_debug=False,
    debug=[],
    dataloader_drop_last=False,
    eval_steps=None,
    dataloader_num_workers=8,
    dataloader_prefetch_factor=None,
    past_index=-1,
    run_name='/train_output/finetune_glm4voice_stage1',
    disable_tqdm=False,
    remove_unused_columns=True,
    label_names=None,
    load_best_model_at_end=False,
    metric_for_best_model=None,
    greater_is_better=None,
    ignore_data_skip=False,
    fsdp=[],
    fsdp_min_num_params=0,
    fsdp_config={'min_num_params': 0,
    'xla': False,
    'xla_fsdp_v2': False,
    'xla_fsdp_grad_ckpt': False},
    fsdp_transformer_layer_cls_to_wrap=None,
    accelerator_config=AcceleratorConfig(split_batches=False,
    dispatch_batches=None,
    even_batches=True,
    use_seedable_sampler=True,
    non_blocking=False,
    gradient_accumulation_kwargs=None,
    use_configured_state=False),
    deepspeed='/VITA-Audio/scripts/deepspeed/ds_config_zero2.json',
    label_smoothing_factor=0.0,
    optim=<OptimizerNames.ADAMW_TORCH: 'adamw_torch'>,
    optim_args=None,
    adafactor=False,
    group_by_length=False,
    length_column_name='length',
    report_to=['tensorboard'],
    ddp_find_unused_parameters=None,
    ddp_bucket_cap_mb=None,
    ddp_broadcast_buffers=None,
    dataloader_pin_memory=True,
    dataloader_persistent_workers=False,
    skip_memory_metrics=True,
    use_legacy_prediction_loop=False,
    push_to_hub=False,
    resume_from_checkpoint=None,
    hub_model_id=None,
    hub_strategy=<HubStrategy.EVERY_SAVE: 'every_save'>,
    hub_token=None,
    hub_private_repo=None,
    hub_always_push=False,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={'use_reentrant': False},
    include_inputs_for_metrics=False,
    include_for_metrics=[],
    eval_do_concat_batches=True,
    fp16_backend='auto',
    evaluation_strategy=None,
    push_to_hub_model_id=None,
    push_to_hub_organization=None,
    push_to_hub_token=None,
    mp_parameters='',
    auto_find_batch_size=False,
    full_determinism=False,
    torchdynamo=None,
    ray_scope='last',
    ddp_timeout=7200,
    torch_compile=False,
    torch_compile_backend=None,
    torch_compile_mode=None,
    dispatch_batches=None,
    split_batches=None,
    include_tokens_per_second=False,
    include_num_input_tokens_seen=False,
    neftune_noise_alpha=None,
    optim_target_modules=None,
    batch_eval_metrics=False,
    eval_on_start=False,
    use_liger_kernel=False,
    eval_use_gather_object=False,
    average_tokens_across_devices=False,
    vision_model_lr_mult=1.0,
    vision_model_lr_decay_rate=1.0,
    mtp_model_lr_mult=1.0,
)

@weedge
Copy link
Collaborator Author

weedge commented May 15, 2025

deepspeed args:

{
    "fp16": {
        "enabled": false, 
        "loss_scale": 0, 
        "loss_scale_window": 1000, 
        "initial_scale_power": 16, 
        "hysteresis": 2, 
        "min_loss_scale": 1
    }, 
    "bf16": {
        "enabled": true
    }, 
    "optimizer": {
        "type": "AdamW", 
        "params": {
            "lr": 6e-05, 
            "betas": [0.9, 0.95], 
            "eps": 1e-08, 
            "weight_decay": 0.0
        }
    }, 
    "scheduler": {
        "type": "WarmupCosineLR", 
        "params": {
            "total_num_steps": 8.000000e+03, 
            "warmup_min_ratio": 0, 
            "warmup_num_steps": 240, 
            "cos_min_ratio": 0.1
        }
    }, 
    "zero_optimization": {
        "stage": 2, 
        "offload_optimizer": {
            "device": "none", 
            "pin_memory": true
        }, 
        "offload_param": {
            "device": "none", 
            "pin_memory": true
        }, 
        "allgather_partitions": true, 
        "allgather_bucket_size": 5.000000e+08, 
        "overlap_comm": true, 
        "reduce_scatter": true, 
        "reduce_bucket_size": 5.000000e+08, 
        "contiguous_gradients": true, 
        "round_robin_gradients": true, 
        "sub_group_size": 1.000000e+12
    }, 
    "gradient_accumulation_steps": 16, 
    "gradient_clipping": 1.0, 
    "steps_per_print": inf, 
    "train_batch_size": 16, 
    "train_micro_batch_size_per_gpu": 1, 
    "wall_clock_breakdown": false, 
    "dump_state": false
}

要进一步优化DeepSpeed Zero优化器的内存使用效率,可以考虑以下几个方面:

  1. 选择合适的ZeRO阶段:

ZeRO Stage 1:适用于中小模型,内存需求适中。
ZeRO Stage 2:适用于中等规模模型,平衡内存和通信。
ZeRO Stage 3:适用于超大规模模型,但需优化网络带宽。
2. 调整批次大小和梯度累积步数:

增大批次大小(通过 train_batch_size 或 gradient_accumulation_steps)可提高GPU利用率。
结合实际情况,调整 train_batch_size 和 gradient_accumulation_steps,以获得最佳性能。
3. 启用混合精度训练:

使用半精度(FP16)或混合精度(BF16)来表示模型参数和计算,从而减少内存占用和加速计算。
确保硬件支持,并使用自动损失缩放(Loss Scaling)来保持数值稳定性。
4. 优化通信:

确保集群网络带宽充足(推荐InfiniBand或高性能以太网)。
启用通信压缩("compress_communication": true)减少通信量。
使用高效的通信后端(如NCCL),确保高带宽互联(如NVLink、InfiniBand)。
5. 卸载到CPU/NVMe:

如果GPU内存不足,启用优化器和参数卸载("offload_optimizer" 和 "offload_param")。
确保CPU内存和NVMe速度足够,避免瓶颈。
6. 使用激活检查点:

对于深层模型,启用激活检查点("activation_checkpointing": true)可显著减少内存占用。
权衡:检查点会增加约20%-30%的计算开销。
7. 合理选择ZeRO优化器的阶段:

根据模型规模和硬件资源,选择合适的ZeRO阶段。
ZeRO-1: 仅划分优化器参数,每个GPU各有一份完整的模型参数与梯度。

ZeRO-2: 划分优化器参数与梯度,每个GPU各有一份完整的模型参数。

ZeRO-3: 划分优化器参数、梯度与模型参数。

简单来说:从 ZeRO-1 到 ZeRO-3,阶段数越高,显存需求越小,但是训练速度也依次变慢。此外,设置 offload_param=cpu 参数会大幅减小显存需求,但会极大地使训练速度减慢。因此,如果您有足够的显存, 应当使用 ZeRO-1,并且确保 offload_param=none。


在DeepSpeed中,ZeRO(Zero Redundancy Optimizer)通过三个阶段来优化内存使用,每个阶段都有其特定的目标和实现方式。以下是对ZeRO的三个阶段的详细说明:

ZeRO Stage 1: 优化器状态分区

  • 原理:将优化器状态(例如,对于Adam优化器,32位权重以及一阶和二阶矩估计)在各个进程之间进行分区,使得每个进程只更新其分区。
  • 优点:显著减少了每个GPU上的显存占用,适用于中小规模模型,内存需求适中。

ZeRO Stage 2: 优化器+梯度状态分区

  • 原理:在Stage 1的基础上,进一步将用于更新模型权重的减少的16位梯度也进行分割,使得每个进程仅保留与其优化器状态部分相对应的梯度。
  • 优点:较大幅度降低显存需求,同时通信开销适中,适用于大多数大规模模型训练,兼顾效率和资源节省。

ZeRO Stage 3: 优化器+梯度+参数分区

  • 原理:将16位模型参数在各个进程中进行划分。ZeRO-3会在前向和后向传递过程中自动收集和划分它们。此外,ZeRO-3还包括一个无限卸载引擎,形成ZeRO-Infinity,可以将所有模型状态卸载到CPU和NVMe内存中,从而实现巨大的内存节省。
  • 优点:最大限度地减少内存占用,适合超大规模模型(如GPT-3)训练,但通信开销较大,对网络带宽要求高。

具体见: https://weedge.github.io/post/llm/trainingparallelstrategy/

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

A1-T2A2 (speech)-to-(text and speech) AR ASR Flow modal MTP T1-T2A2 (text)-to-(text and speech) TTS vocoder voice

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant