diff --git a/README.md b/README.md index adf4f1a1f9..e9cd2f5424 100644 --- a/README.md +++ b/README.md @@ -54,6 +54,7 @@

## 📢 News +- **[2025-12]** AgentScope supports [TTS(Text-to-Speech)](https://doc.agentscope.io/tutorial/task_tts.html) now! Check our [example](https://github.com/agentscope-ai/agentscope/tree/main/examples/functionality/tts) and [tutorial](https://doc.agentscope.io/tutorial/task_tts.html) for more details. - **[2025-11]** AgentScope supports [Anthropic Agent Skill](https://claude.com/blog/skills) now! Check our [example](https://github.com/agentscope-ai/agentscope/tree/main/examples/functionality/agent_skill) and [tutorial](https://doc.agentscope.io/tutorial/task_agent_skill.html) for more details. - **[2025-11]** AgentScope open-sources [Alias-Agent](https://github.com/agentscope-ai/agentscope-samples/tree/main/alias) for diverse real-world tasks and [Data-Juicer Agent](https://github.com/agentscope-ai/agentscope-samples/tree/main/data_juicer_agent) for data processing. - **[2025-11]** AgentScope supports [Agentic RL](https://github.com/agentscope-ai/agentscope/tree/main/examples/training/react_agent) via integrating [Trinity-RFT](https://github.com/modelscope/Trinity-RFT) library. diff --git a/README_zh.md b/README_zh.md index bd5be58d4c..3ccb661226 100644 --- a/README_zh.md +++ b/README_zh.md @@ -54,6 +54,7 @@

## 📢 新闻 +- **[2025-12]** AgentScope 已支持 [TTS(Text-to-Speech) 模型](https://doc.agentscope.io/zh_CN/tutorial/task_tts.html) !欢迎查看 [样例](https://github.com/agentscope-ai/agentscope/tree/main/examples/functionality/tts) 和 [教程](https://doc.agentscope.io/zh_CN/tutorial/task_tts.html) 了解更多详情。 - **[2025-11]** AgentScope 已支持 [Anthropic Agent Skill](https://claude.com/blog/skills) !欢迎查看 [样例](https://github.com/agentscope-ai/agentscope/tree/main/examples/functionality/agent_skill) 和 [教程](https://doc.agentscope.io/zh_CN/tutorial/task_agent_skill.html) 了解更多详情。 - **[2025-11]** AgentScope 开源 [Alias-Agent](https://github.com/agentscope-ai/agentscope-samples/tree/main/alias) 用于处理多样化的真实任务,以及 [Data-Juicer Agent](https://github.com/agentscope-ai/agentscope-samples/tree/main/data_juicer_agent) 用于自然语言驱动的数据处理。 - **[2025-11]** AgentScope 通过集成 [Trinity-RFT](https://github.com/modelscope/Trinity-RFT) 实现对 [Agentic RL](https://github.com/agentscope-ai/agentscope/tree/main/examples/training/react_agent) 的支持。 diff --git a/docs/tutorial/en/index.rst b/docs/tutorial/en/index.rst index a9f53b16c1..e5ecb2684a 100644 --- a/docs/tutorial/en/index.rst +++ b/docs/tutorial/en/index.rst @@ -33,26 +33,42 @@ Welcome to AgentScope's documentation! .. toctree:: :maxdepth: 1 - :caption: Task Guides + :caption: Model and Context tutorial/task_model tutorial/task_prompt - tutorial/task_tool + tutorial/task_token tutorial/task_memory tutorial/task_long_term_memory + +.. toctree:: + :maxdepth: 1 + :caption: Tool + + tutorial/task_tool + tutorial/task_mcp + tutorial/task_agent_skill + +.. toctree:: + :maxdepth: 1 + :caption: Agent + tutorial/task_agent + tutorial/task_state + tutorial/task_hook + +.. toctree:: + :maxdepth: 1 + :caption: Features + tutorial/task_pipeline tutorial/task_plan tutorial/task_rag - tutorial/task_state - tutorial/task_hook - tutorial/task_mcp - tutorial/task_agent_skill tutorial/task_studio tutorial/task_tracing tutorial/task_eval tutorial/task_embedding - tutorial/task_token + tutorial/task_tts .. 
toctree:: :maxdepth: 1 @@ -76,3 +92,4 @@ Welcome to AgentScope's documentation! api/agentscope.tracing api/agentscope.session api/agentscope.exception + api/agentscope.tts diff --git a/docs/tutorial/en/src/task_tts.py b/docs/tutorial/en/src/task_tts.py new file mode 100644 index 0000000000..f692237d78 --- /dev/null +++ b/docs/tutorial/en/src/task_tts.py @@ -0,0 +1,243 @@ +# -*- coding: utf-8 -*- +""" +.. _tts: + +TTS +==================== + +AgentScope provides a unified interface for Text-to-Speech (TTS) models across multiple API providers. +This tutorial demonstrates how to use TTS models in AgentScope. + +AgentScope supports the following TTS APIs: + +.. list-table:: Built-in TTS Models + :header-rows: 1 + + * - API + - Class + - Streaming Input + - Non-Streaming Input + - Streaming Output + - Non-Streaming Output + * - DashScope Realtime API + - ``DashScopeRealtimeTTSModel`` + - ✅ + - ✅ + - ✅ + - ✅ + * - DashScope API + - ``DashScopeTTSModel`` + - ❌ + - ✅ + - ✅ + - ✅ + * - OpenAI API + - ``OpenAITTSModel`` + - ❌ + - ✅ + - ✅ + - ✅ + * - Gemini API + - ``GeminiTTSModel`` + - ❌ + - ✅ + - ✅ + - ✅ + +.. note:: The streaming input and output in AgentScope TTS models are all accumulative. + +**Choosing the Right Model:** + +- **Use Non-Realtime TTS** when you have complete text ready (e.g., pre-written + responses, complete LLM outputs) +- **Use Realtime TTS** when text is generated progressively (e.g., streaming + LLM responses) for lower latency + +""" + +import asyncio +import os + +from agentscope.agent import ReActAgent, UserAgent +from agentscope.formatter import DashScopeChatFormatter +from agentscope.message import Msg +from agentscope.model import DashScopeChatModel +from agentscope.tts import ( + DashScopeRealtimeTTSModel, + DashScopeTTSModel, +) + +# %% +# Non-Realtime TTS +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# Non-realtime TTS models process complete text inputs and are the simplest +# to use. You can directly call their ``synthesize()`` method. 
+# +# Taking DashScope TTS model as an example: + + +async def example_non_realtime_tts() -> None: + """A basic example of using non-realtime TTS models.""" + # Example with DashScope TTS + tts_model = DashScopeTTSModel( + api_key=os.environ.get("DASHSCOPE_API_KEY", ""), + model_name="qwen3-tts-flash", + voice="Cherry", + stream=False, # Non-streaming output + ) + + msg = Msg( + name="assistant", + content="Hello, this is DashScope TTS.", + role="assistant", + ) + + # Directly synthesize without connecting + tts_response = await tts_model.synthesize(msg) + + # tts_response.content contains an audio block with base64-encoded audio data + print( + "The length of audio data:", + len(tts_response.content[0]["source"]["data"]), + ) + + +asyncio.run(example_non_realtime_tts()) + +# %% +# **Streaming Output for Lower Latency:** +# +# When ``stream=True``, the model returns audio chunks progressively, allowing +# you to start playback before synthesis completes. This reduces perceived latency. +# + + +async def example_non_realtime_tts_streaming() -> None: + """An example of using non-realtime TTS models with streaming output.""" + # Example with DashScope TTS with streaming output + tts_model = DashScopeTTSModel( + api_key=os.environ.get("DASHSCOPE_API_KEY", ""), + model_name="qwen3-tts-flash", + voice="Cherry", + stream=True, # Enable streaming output + ) + + msg = Msg( + name="assistant", + content="Hello, this is DashScope TTS with streaming output.", + role="assistant", + ) + + # Synthesize and receive an async generator for streaming output + async for tts_response in await tts_model.synthesize(msg): + # Process each audio chunk as it arrives + print( + "Received audio chunk of length:", + len(tts_response.content[0]["source"]["data"]), + ) + + +asyncio.run(example_non_realtime_tts_streaming()) + + +# %% +# Realtime TTS +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# Realtime TTS models are designed for scenarios where text is generated +# incrementally, such as streaming LLM 
responses. This enables the lowest +# possible latency by starting audio synthesis before the complete text is ready. +# +# **Key Concepts:** +# +# - **Stateful Processing**: Realtime TTS maintains state for a single streaming +# session, identified by ``msg.id``. Only one streaming session can be active +# at a time. +# - **Two Methods**: +# +# - ``push(msg)``: Non-blocking method that submits text chunks and returns +# immediately. May return partial audio if available. +# - ``synthesize(msg)``: Blocking method that finalizes the session and returns +# all remaining audio. When ``stream=True``, it returns an async generator. +# +# .. code-block:: python +# +# async def example_realtime_tts_streaming(): +# tts_model = DashScopeRealtimeTTSModel( +# api_key=os.environ.get("DASHSCOPE_API_KEY", ""), +# model_name="qwen3-tts-flash-realtime", +# voice="Cherry", +# stream=False, +# ) +# +# # realtime tts model received accumulative text chunks +# res = await tts_model.push(msg_chunk_1) # non-blocking +# res = await tts_model.push(msg_chunk_2) # non-blocking +# ... +# res = await tts_model.synthesize(final_msg) # blocking, get all remaining audio +# +# When setting ``stream=True`` during initialization, the ``synthesize()`` method returns an async generator of ``TTSResponse`` objects, allowing you to process audio chunks as they arrive. +# +# +# Integrating with ReActAgent +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# AgentScope agents can automatically synthesize their responses to speech +# when provided with a TTS model. This works seamlessly with both realtime +# and non-realtime TTS models. +# +# **How It Works:** +# +# 1. The agent generates a text response (potentially streamed from an LLM) +# 2. The TTS model synthesizes the text to audio automatically +# 3. The synthesized audio is attached to the ``speech`` field of the ``Msg`` object +# 4. 
The audio is played during the agent's ``self.print()`` method +# + + +async def example_agent_with_tts() -> None: + """An example of using TTS with ReActAgent.""" + # Create an agent with TTS enabled + agent = ReActAgent( + name="Assistant", + sys_prompt="You are a helpful assistant.", + model=DashScopeChatModel( + api_key=os.environ.get("DASHSCOPE_API_KEY", ""), + model_name="qwen-max", + stream=True, + ), + formatter=DashScopeChatFormatter(), + # Enable TTS + tts_model=DashScopeRealtimeTTSModel( + api_key=os.getenv("DASHSCOPE_API_KEY"), + model_name="qwen3-tts-flash-realtime", + voice="Cherry", + ), + ) + user = UserAgent("User") + + # Build a conversation just like normal + msg = None + while True: + msg = await agent(msg) + msg = await user(msg) + if msg.get_text_content() == "exit": + break + + +# %% +# Customizing TTS Model +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# You can create custom TTS implementations by inheriting from ``TTSModelBase``. +# The base class provides a flexible interface for both realtime and non-realtime +# TTS models. +# We use an attribute ``supports_streaming_input`` to indicate if the TTS model is realtime or not. +# +# For realtime TTS models, you need to implement the ``connect``, ``close``, ``push`` and ``synthesize`` methods to handle the lifecycle and streaming input. +# +# While for non-realtime TTS models, you only need to implement the ``synthesize`` method. +# +# Further Reading +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# - :ref:`agent` - Learn more about agents in AgentScope +# - :ref:`message` - Understand message format in AgentScope +# - API Reference: :class:`agentscope.tts.TTSModelBase` +# diff --git a/docs/tutorial/zh_CN/index.rst b/docs/tutorial/zh_CN/index.rst index a9f53b16c1..704bcf4b67 100644 --- a/docs/tutorial/zh_CN/index.rst +++ b/docs/tutorial/zh_CN/index.rst @@ -31,28 +31,45 @@ Welcome to AgentScope's documentation! tutorial/faq + .. 
toctree:: :maxdepth: 1 - :caption: Task Guides + :caption: Model and Context tutorial/task_model tutorial/task_prompt - tutorial/task_tool + tutorial/task_token tutorial/task_memory tutorial/task_long_term_memory + +.. toctree:: + :maxdepth: 1 + :caption: Tool + + tutorial/task_tool + tutorial/task_mcp + tutorial/task_agent_skill + +.. toctree:: + :maxdepth: 1 + :caption: Agent + tutorial/task_agent + tutorial/task_state + tutorial/task_hook + +.. toctree:: + :maxdepth: 1 + :caption: Features + tutorial/task_pipeline tutorial/task_plan tutorial/task_rag - tutorial/task_state - tutorial/task_hook - tutorial/task_mcp - tutorial/task_agent_skill tutorial/task_studio tutorial/task_tracing tutorial/task_eval tutorial/task_embedding - tutorial/task_token + tutorial/task_tts .. toctree:: :maxdepth: 1 @@ -76,3 +93,4 @@ Welcome to AgentScope's documentation! api/agentscope.tracing api/agentscope.session api/agentscope.exception + api/agentscope.tts diff --git a/docs/tutorial/zh_CN/src/task_tts.py b/docs/tutorial/zh_CN/src/task_tts.py new file mode 100644 index 0000000000..8dd4f2513d --- /dev/null +++ b/docs/tutorial/zh_CN/src/task_tts.py @@ -0,0 +1,230 @@ +# -*- coding: utf-8 -*- +""" +.. _tts: + +TTS +==================== + +AgentScope 为多个 API 提供商的文本转语音(TTS)模型提供了统一接口。 +本章节演示如何在 AgentScope 中使用 TTS 模型。 + +AgentScope 支持以下 TTS API: + +.. list-table:: 内置 TTS 模型 + :header-rows: 1 + + * - API + - 类 + - 流式输入 + - 非流式输入 + - 流式输出 + - 非流式输出 + * - DashScope 实时 API + - ``DashScopeRealtimeTTSModel`` + - ✅ + - ✅ + - ✅ + - ✅ + * - DashScope API + - ``DashScopeTTSModel`` + - ❌ + - ✅ + - ✅ + - ✅ + * - OpenAI API + - ``OpenAITTSModel`` + - ❌ + - ✅ + - ✅ + - ✅ + * - Gemini API + - ``GeminiTTSModel`` + - ❌ + - ✅ + - ✅ + - ✅ + +.. 
note:: AgentScope TTS 模型中的流式输入和输出都是累积式的。 + +**选择合适的模型:** + +- **使用非实时 TTS**:当已有完整文本时(例如预先编写的响应、完整的 LLM 输出) +- **使用实时 TTS**:当文本是逐步生成时(例如 LLM 的流式返回),以获得更低的延迟 + +""" + +import asyncio +import os + +from agentscope.agent import ReActAgent, UserAgent +from agentscope.formatter import DashScopeChatFormatter +from agentscope.message import Msg +from agentscope.model import DashScopeChatModel +from agentscope.tts import ( + DashScopeRealtimeTTSModel, + DashScopeTTSModel, +) + +# %% +# 非实时 TTS +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# 非实时 TTS 模型处理完整的文本输入,使用起来最简单,可以直接调用它们的 ``synthesize()`` 方法。 +# +# 以 DashScope TTS 模型为例: + + +async def example_non_realtime_tts() -> None: + """使用非实时 TTS 模型的基本示例。""" + # DashScope TTS 示例 + tts_model = DashScopeTTSModel( + api_key=os.environ.get("DASHSCOPE_API_KEY", ""), + model_name="qwen3-tts-flash", + voice="Cherry", + stream=False, # 非流式输出 + ) + + msg = Msg( + name="assistant", + content="你好,这是 DashScope TTS。", + role="assistant", + ) + + tts_response = await tts_model.synthesize(msg) + + # tts_response.content 包含一个带有 base64 编码音频数据的音频块 + print("音频数据长度:", len(tts_response.content[0]["source"]["data"])) + + +asyncio.run(example_non_realtime_tts()) + +# %% +# **流式输出以降低延迟:** +# +# 当 ``stream=True`` 时,模型会逐步返回音频块,允许 +# 您在合成完成前开始播放。这减少了感知延迟。 +# + + +async def example_non_realtime_tts_streaming() -> None: + """使用带流式输出的非实时 TTS 模型的示例。""" + # 使用流式输出的 DashScope TTS 示例 + tts_model = DashScopeTTSModel( + api_key=os.environ.get("DASHSCOPE_API_KEY", ""), + model_name="qwen3-tts-flash", + voice="Cherry", + stream=True, # 启用流式输出 + ) + + msg = Msg( + name="assistant", + content="你好,这是带流式输出的 DashScope TTS。", + role="assistant", + ) + + # 合成并接收用于流式输出的异步生成器 + async for tts_response in await tts_model.synthesize(msg): + # 处理到达的每个音频块 + print("接收到的音频块长度:", len(tts_response.content[0]["source"]["data"])) + + +asyncio.run(example_non_realtime_tts_streaming()) + + +# %% +# 实时 TTS +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# 实时 TTS 模型专为文本增量生成的场景设计, +# 例如流式 LLM 响应。这通过在完整文本准备好之前 
+# 开始音频合成,实现尽可能低的延迟。 +# +# **核心概念:** +# +# - **有状态处理**:实时 TTS 为单个流式会话维护状态, +# 由 ``msg.id`` 标识。一次只能有一个流式会话处于活动状态。 +# - **两种方法**: +# +# - ``push(msg)``:非阻塞方法,提交文本块并立即返回。 +# 如果有可用的部分音频,可能会返回。 +# - ``synthesize(msg)``:阻塞方法,完成会话并返回 +# 所有剩余的音频。当 ``stream=True`` 时,返回异步生成器。 +# +# .. code-block:: python +# +# async def example_realtime_tts_streaming(): +# tts_model = DashScopeRealtimeTTSModel( +# api_key=os.environ.get("DASHSCOPE_API_KEY", ""), +# model_name="qwen3-tts-flash-realtime", +# voice="Cherry", +# stream=False, +# ) +# +# # 实时 tts 模型接收累积的文本块 +# res = await tts_model.push(msg_chunk_1) # 非阻塞 +# res = await tts_model.push(msg_chunk_2) # 非阻塞 +# ... +# res = await tts_model.synthesize(final_msg) # 阻塞,获取所有剩余音频 +# +# 在初始化时设置 ``stream=True`` 时,``synthesize()`` 方法返回 ``TTSResponse`` 对象的异步生成器,允许您在音频块到达时处理它们。 +# +# +# 与 ReActAgent 集成 +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# AgentScope 智能体在提供 TTS 模型时,可以自动将其响应合成为语音。 +# 这与实时和非实时 TTS 模型都能无缝协作。 +# +# **工作原理:** +# +# 1. 智能体生成文本响应(可能从 LLM 流式传输) +# 2. TTS 模型自动将文本合成为音频 +# 3. 合成的音频附加到 ``Msg`` 对象的 ``speech`` 字段 +# 4. 
音频在智能体的 ``self.print()`` 方法期间播放 +# + + +async def example_agent_with_tts() -> None: + """使用带 TTS 的 ReActAgent 的示例。""" + # 创建启用了 TTS 的智能体 + agent = ReActAgent( + name="Assistant", + sys_prompt="你是一个有用的助手。", + model=DashScopeChatModel( + api_key=os.environ["DASHSCOPE_API_KEY"], + model_name="qwen-max", + stream=True, + ), + formatter=DashScopeChatFormatter(), + # 启用 TTS + tts_model=DashScopeRealtimeTTSModel( + api_key=os.getenv("DASHSCOPE_API_KEY"), + model_name="qwen3-tts-flash-realtime", + voice="Cherry", + ), + ) + user = UserAgent("User") + + # 像正常情况一样构建对话 + msg = None + while True: + msg = await agent(msg) + msg = await user(msg) + if msg.get_text_content() == "exit": + break + + +# %% +# 自定义 TTS 模型 +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# 可以通过继承 ``TTSModelBase`` 来创建自定义 TTS 实现。 +# 基类为实时和非实时 TTS 模型提供了灵活的接口。 +# 我们使用属性 ``supports_streaming_input`` 来指示 TTS 模型是否为实时模型。 +# +# 对于实时 TTS 模型,需要实现 ``connect``、``close``、``push`` 和 ``synthesize`` 方法来处理 API 的生命周期和流式输入。 +# +# 而对于非实时 TTS 模型,只需实现 ``synthesize`` 方法。 +# +# 进一步阅读 +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# - :ref:`agent` - 了解更多关于 AgentScope 中的智能体 +# - :ref:`message` - 理解 AgentScope 中的消息格式 +# - API 参考::class:`agentscope.tts.TTSModelBase` +# diff --git a/examples/functionality/tts/README.md b/examples/functionality/tts/README.md new file mode 100644 index 0000000000..74edfa963f --- /dev/null +++ b/examples/functionality/tts/README.md @@ -0,0 +1,13 @@ +# TTS (Text-to-Speech) in AgentScope + +This example demonstrates how to integrate DashScope Realtime TTS model with `ReActAgent` to enable audio output. +The agent can speak its responses in real-time. + +This example uses DashScope's Realtime TTS model, you can also change to other TTS models supported by AgentScope, e.g. +OpenAI, Gemini, etc. 
+ +To run the example, execute: + +```bash +python main.py +``` diff --git a/examples/functionality/tts/main.py b/examples/functionality/tts/main.py new file mode 100644 index 0000000000..a9967ae970 --- /dev/null +++ b/examples/functionality/tts/main.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- +"""The main entry point of the ReAct agent example.""" +import asyncio +import os + +from agentscope.agent import ReActAgent, UserAgent +from agentscope.formatter import DashScopeChatFormatter +from agentscope.memory import InMemoryMemory +from agentscope.model import DashScopeChatModel +from agentscope.tool import ( + Toolkit, + execute_shell_command, + execute_python_code, + view_text_file, +) +from agentscope.tts import DashScopeRealtimeTTSModel + + +async def main() -> None: + """The main entry point for the ReAct agent example.""" + toolkit = Toolkit() + toolkit.register_tool_function(execute_shell_command) + toolkit.register_tool_function(execute_python_code) + toolkit.register_tool_function(view_text_file) + + agent = ReActAgent( + name="Friday", + sys_prompt="You are a helpful assistant named Friday.", + model=DashScopeChatModel( + api_key=os.environ.get("DASHSCOPE_API_KEY"), + model_name="qwen3-max", + enable_thinking=False, + stream=True, + ), + formatter=DashScopeChatFormatter(), + toolkit=toolkit, + memory=InMemoryMemory(), + # Specify the TTS model for real-time speech synthesis + tts_model=DashScopeRealtimeTTSModel( + model_name="qwen3-tts-flash-realtime", + api_key=os.environ.get("DASHSCOPE_API_KEY"), + voice="Cherry", + stream=False, + ), + ) + user = UserAgent("User") + + msg = None + while True: + msg = await user(msg) + if msg.get_text_content() == "exit": + break + msg = await agent(msg) + + +asyncio.run(main()) diff --git a/src/agentscope/agent/_agent_base.py b/src/agentscope/agent/_agent_base.py index 0779a17815..e2023cd5da 100644 --- a/src/agentscope/agent/_agent_base.py +++ b/src/agentscope/agent/_agent_base.py @@ -202,7 +202,12 @@ async def 
reply(self, *args: Any, **kwargs: Any) -> Msg: f"{self.__class__.__name__} class.", ) - async def print(self, msg: Msg, last: bool = True) -> None: + async def print( + self, + msg: Msg, + last: bool = True, + speech: AudioBlock | list[AudioBlock] | None = None, + ) -> None: """The function to display the message. Args: @@ -211,9 +216,12 @@ async def print(self, msg: Msg, last: bool = True) -> None: last (`bool`, defaults to `True`): Whether this is the last one in streaming messages. For non-streaming message, this should always be `True`. + speech (`AudioBlock | list[AudioBlock] | None`, optional): + The audio content block(s) to be played along with the + message. """ if not self._disable_msg_queue: - await self.msg_queue.put((deepcopy(msg), last)) + await self.msg_queue.put((deepcopy(msg), last, speech)) if self._disable_console_output: return @@ -223,10 +231,7 @@ async def print(self, msg: Msg, last: bool = True) -> None: thinking_and_text_to_print = [] for block in msg.get_content_blocks(): - if block["type"] == "audio": - self._process_audio_block(msg.id, block) - - elif block["type"] == "text": + if block["type"] == "text": self._print_text_block( msg.id, name_prefix=msg.name, @@ -245,6 +250,13 @@ async def print(self, msg: Msg, last: bool = True) -> None: elif last: self._print_last_block(block, msg) + # Play audio block if exists + if isinstance(speech, list): + for audio_block in speech: + self._process_audio_block(msg.id, audio_block) + elif isinstance(speech, dict): + self._process_audio_block(msg.id, speech) + # Clean up resources if this is the last message in streaming if last and msg.id in self._stream_prefix: if "audio" in self._stream_prefix[msg.id]: @@ -392,18 +404,28 @@ def _print_text_block( def _print_last_block( self, - block: ToolUseBlock | ToolResultBlock | ImageBlock | VideoBlock, + block: ToolUseBlock + | ToolResultBlock + | ImageBlock + | VideoBlock + | AudioBlock, msg: Msg, ) -> None: """Process and print the last content block, and 
the block type - is not audio, text, or thinking. + is not text, or thinking. Args: - block (`ToolUseBlock | ToolResultBlock | ImageBlock | VideoBlock`): + block (`ToolUseBlock | ToolResultBlock | ImageBlock | VideoBlock \ + | AudioBlock`): The content block to be printed msg (`Msg`): The message object """ + # TODO: We should consider how to handle the multimodal blocks in the + # terminal, since the base64 data may be too long to display. + if block.get("type") in ["image", "video", "audio"]: + return + text_prefix = self._stream_prefix.get(msg.id, {}).get("text", "") if text_prefix: diff --git a/src/agentscope/agent/_react_agent.py b/src/agentscope/agent/_react_agent.py index e52b2aec41..d80a58f0c1 100644 --- a/src/agentscope/agent/_react_agent.py +++ b/src/agentscope/agent/_react_agent.py @@ -7,16 +7,23 @@ from pydantic import BaseModel, ValidationError, Field +from ._utils import _AsyncNullContext from ._react_agent_base import ReActAgentBase from .._logging import logger from ..formatter import FormatterBase from ..memory import MemoryBase, LongTermMemoryBase, InMemoryMemory -from ..message import Msg, ToolUseBlock, ToolResultBlock, TextBlock +from ..message import ( + Msg, + ToolUseBlock, + ToolResultBlock, + TextBlock, +) from ..model import ChatModelBase from ..rag import KnowledgeBase, Document from ..plan import PlanNotebook from ..tool import Toolkit, ToolResponse from ..tracing import trace_reply +from ..tts import TTSModelBase class _QueryRewriteModel(BaseModel): @@ -63,6 +70,7 @@ def __init__( plan_notebook: PlanNotebook | None = None, print_hint_msg: bool = False, max_iters: int = 10, + tts_model: TTSModelBase | None = None, ) -> None: """Initialize the ReAct agent @@ -120,6 +128,8 @@ def __init__( the long-term memory and knowledge base(s). max_iters (`int`, defaults to `10`): The maximum number of iterations of the reasoning-acting loops. + tts_model (`TTSModelBase | None` optional): + The TTS model used by the agent. 
""" super().__init__() @@ -135,6 +145,7 @@ def __init__( self.max_iters = max_iters self.model = model self.formatter = formatter + self.tts_model = tts_model # -------------- Memory management -------------- # Record the dialogue history in the memory @@ -390,11 +401,13 @@ async def reply( # pylint: disable=too-many-branches await self.memory.add(reply_msg) return reply_msg + # pylint: disable=too-many-branches async def _reasoning( self, tool_choice: Literal["auto", "none", "required"] | None = None, ) -> Msg: """Perform the reasoning process.""" + if self.plan_notebook: # Insert the reasoning hint from the plan notebook hint_msg = await self.plan_notebook.get_current_hint() @@ -423,22 +436,52 @@ async def _reasoning( # handle output from the model interrupted_by_user = False msg = None + + # TTS model context manager + tts_context = self.tts_model or _AsyncNullContext() + speech = None + try: - if self.model.stream: - msg = Msg(self.name, [], "assistant") - async for content_chunk in res: - msg.content = content_chunk.content - await self.print(msg, False) - await self.print(msg, True) + async with tts_context: + msg = Msg(name=self.name, content=[], role="assistant") + if self.model.stream: + async for content_chunk in res: + msg.content = content_chunk.content + + # The speech generated from multimodal (audio) models + # e.g. 
Qwen-Omni and GPT-AUDIO + speech = msg.get_content_blocks("audio") or None + + # Push to TTS model if available + if ( + self.tts_model + and self.tts_model.supports_streaming_input + ): + tts_res = await self.tts_model.push(msg) + speech = tts_res.content + + await self.print(msg, False, speech=speech) + + else: + msg.content = list(res.content) + + if self.tts_model: + # Push to TTS model and block to receive the full speech + # synthesis result + tts_res = await self.tts_model.synthesize(msg) + if self.tts_model.stream: + async for tts_chunk in tts_res: + speech = tts_chunk.content + await self.print(msg, False, speech=speech) + else: + speech = tts_res.content + + await self.print(msg, True, speech=speech) # Add a tiny sleep to yield the last message object in the # message queue await asyncio.sleep(0.001) - else: - msg = Msg(self.name, list(res.content), "assistant") - await self.print(msg, True) - except asyncio.CancelledError as e: interrupted_by_user = True raise e from None @@ -542,6 +585,7 @@ async def observe(self, msg: Msg | list[Msg] | None) -> None: async def _summarizing(self) -> Msg: """Generate a response when the agent fails to solve the problem in the maximum iterations.""" + hint_msg = Msg( "user", "You have failed to generate response within the maximum " @@ -562,18 +606,47 @@ async def _summarizing(self) -> Msg: # finish_function here res = await self.model(prompt) - res_msg = Msg(self.name, [], "assistant") - if isinstance(res, AsyncGenerator): - async for chunk in res: - res_msg.content = chunk.content - await self.print(res_msg, False) - await self.print(res_msg, True) + # TTS model context manager + tts_context = self.tts_model or _AsyncNullContext() + speech = None - else: - res_msg.content = res.content - await self.print(res_msg, True) + async with tts_context: + res_msg = Msg(self.name, [], "assistant") + if isinstance(res, AsyncGenerator): + async for chunk in res: + res_msg.content = chunk.content + + # The speech generated from 
multimodal (audio) models + # e.g. Qwen-Omni and GPT-AUDIO + speech = res_msg.get_content_blocks("audio") or None + + # Push to TTS model if available + if ( + self.tts_model + and self.tts_model.supports_streaming_input + ): + tts_res = await self.tts_model.push(res_msg) + speech = tts_res.content - return res_msg + await self.print(res_msg, False, speech=speech) + + else: + res_msg.content = res.content + + if self.tts_model: + # Push to TTS model and block to receive the full speech + # synthesis result + tts_res = await self.tts_model.synthesize(res_msg) + if self.tts_model.stream: + async for tts_chunk in tts_res: + speech = tts_chunk.content + await self.print(res_msg, False, speech=speech) + else: + speech = tts_res.content + + await self.print(res_msg, True, speech=speech) + + return res_msg async def handle_interrupt( self, diff --git a/src/agentscope/agent/_utils.py b/src/agentscope/agent/_utils.py new file mode 100644 index 0000000000..ae6a2cce08 --- /dev/null +++ b/src/agentscope/agent/_utils.py @@ -0,0 +1,18 @@ +# -*- coding: utf-8 -*- +"""Utils for agents in agentscope.""" +from typing import Any + + +class _AsyncNullContext: + """An async null context manager.""" + + async def __aenter__(self) -> None: + return None + + async def __aexit__( + self, + exc_type: Any, + exc_val: Any, + exc_tb: Any, + ) -> None: + return None diff --git a/src/agentscope/hooks/_studio_hooks.py b/src/agentscope/hooks/_studio_hooks.py index 7d17ae1b8c..111777e26c 100644 --- a/src/agentscope/hooks/_studio_hooks.py +++ b/src/agentscope/hooks/_studio_hooks.py @@ -32,8 +32,8 @@ def as_studio_forward_message_pre_print_hook( json={ "runId": run_id, "replyId": reply_id, - "name": getattr(self, "name", msg.name), - "role": "user" + "replyName": getattr(self, "name", msg.name), + "replyRole": "user" if isinstance(self, UserAgent) else "assistant", "msg": message_data, diff --git a/src/agentscope/message/_message_base.py b/src/agentscope/message/_message_base.py index 
51798ebc98..481d17639b 100644 --- a/src/agentscope/message/_message_base.py +++ b/src/agentscope/message/_message_base.py @@ -120,19 +120,31 @@ def has_content_blocks( """ return len(self.get_content_blocks(block_type)) > 0 - def get_text_content(self) -> str | None: - """Get the pure text blocks from the message content.""" + def get_text_content(self, separator: str = "\n") -> str | None: + """Get the pure text blocks from the message content. + + Args: + separator (`str`, defaults to `\n`): + The separator to use when concatenating multiple text blocks. + Defaults to newline character. + + Returns: + `str | None`: + The concatenated text content, or `None` if there is no text + content. + """ if isinstance(self.content, str): return self.content - gathered_text = None + gathered_text = [] for block in self.content: if block.get("type") == "text": - if gathered_text is None: - gathered_text = str(block.get("text")) - else: - gathered_text += block.get("text") - return gathered_text + gathered_text.append(block["text"]) + + if gathered_text: + return separator.join(gathered_text) + + return None @overload def get_content_blocks( diff --git a/src/agentscope/pipeline/_functional.py b/src/agentscope/pipeline/_functional.py index 8770cd544c..efee84378b 100644 --- a/src/agentscope/pipeline/_functional.py +++ b/src/agentscope/pipeline/_functional.py @@ -4,7 +4,7 @@ from copy import deepcopy from typing import Any, AsyncGenerator, Tuple, Coroutine from ..agent import AgentBase -from ..message import Msg +from ..message import Msg, AudioBlock async def sequential_pipeline( @@ -108,7 +108,11 @@ async def stream_printing_messages( agents: list[AgentBase], coroutine_task: Coroutine, end_signal: str = "[END]", -) -> AsyncGenerator[Tuple[Msg, bool], None]: + yield_speech: bool = False, +) -> AsyncGenerator[ + Tuple[Msg, bool] | Tuple[Msg, bool, AudioBlock | list[AudioBlock] | None], + None, +]: """This pipeline will gather the printing messages from agents when execute the 
given coroutine task, and yield them one by one. Only the messages that are printed by `await self.print(msg)` in the agent @@ -134,12 +138,19 @@ async def stream_printing_messages( A special signal to indicate the end of message streaming. When this signal is received from the message queue, the generator will stop yielding messages and exit the loop. - - Returns: - `AsyncGenerator[Tuple[Msg, bool], None]`: - An async generator that yields tuples of (message, is_last_chunk). - The `is_last_chunk` boolean indicates whether the message is the - last chunk in a streaming message. + yield_speech (`bool`, defaults to `False`): + Whether to yield speech associated with the messages, if any. + If `True` and a speech is attached when calling `await + self.print()` in the agent, the yielded tuple will include the + speech as the third element. If `False`, only the message and + the boolean flag will be yielded. + + Yields: + `Tuple[Msg, bool] | Tuple[Msg, bool, AudioBlock | list[AudioBlock] | \ + None]`: + A tuple containing the message, a boolean indicating whether + it's the last chunk in a streaming message, and optionally + the associated speech (if `yield_speech` is `True`). 
""" # Enable the message queue to get the intermediate messages @@ -156,6 +167,7 @@ async def stream_printing_messages( else: task.add_done_callback(lambda _: queue.put_nowait(end_signal)) + # Receive the messages from the agent's message queue while True: # The message obj, and a boolean indicating whether it's the last chunk # in a streaming message @@ -165,7 +177,11 @@ async def stream_printing_messages( if isinstance(printing_msg, str) and printing_msg == end_signal: break - yield printing_msg + if yield_speech: + yield printing_msg + else: + msg, last, _ = printing_msg + yield msg, last # Check exception after processing all messages exception = task.exception() diff --git a/src/agentscope/tts/__init__.py b/src/agentscope/tts/__init__.py new file mode 100644 index 0000000000..b54a1f3ed8 --- /dev/null +++ b/src/agentscope/tts/__init__.py @@ -0,0 +1,19 @@ +# -*- coding: utf-8 -*- +"""The TTS (Text-to-Speech) module.""" + +from ._tts_base import TTSModelBase +from ._tts_response import TTSResponse, TTSUsage +from ._dashscope_tts_model import DashScopeTTSModel +from ._dashscope_realtime_tts_model import DashScopeRealtimeTTSModel +from ._gemini_tts_model import GeminiTTSModel +from ._openai_tts_model import OpenAITTSModel + +__all__ = [ + "TTSModelBase", + "TTSResponse", + "TTSUsage", + "DashScopeTTSModel", + "DashScopeRealtimeTTSModel", + "GeminiTTSModel", + "OpenAITTSModel", +] diff --git a/src/agentscope/tts/_dashscope_realtime_tts_model.py b/src/agentscope/tts/_dashscope_realtime_tts_model.py new file mode 100644 index 0000000000..df02bdd2b0 --- /dev/null +++ b/src/agentscope/tts/_dashscope_realtime_tts_model.py @@ -0,0 +1,445 @@ +# -*- coding: utf-8 -*- +"""DashScope Realtime TTS model implementation.""" + +import threading +from typing import Any, Literal, TYPE_CHECKING, AsyncGenerator + +from ._tts_base import TTSModelBase +from ._tts_response import TTSResponse +from ..message import Msg, AudioBlock, Base64Source +from ..types import JSONSerializableObject 
if TYPE_CHECKING:
    from dashscope.audio.qwen_tts_realtime import (
        QwenTtsRealtime,
        QwenTtsRealtimeCallback,
    )
else:
    QwenTtsRealtime = "dashscope.audio.qwen_tts_realtime.QwenTtsRealtime"
    QwenTtsRealtimeCallback = (
        "dashscope.audio.qwen_tts_realtime.QwenTtsRealtimeCallback"
    )


def _get_qwen_tts_realtime_callback_class() -> type["QwenTtsRealtimeCallback"]:
    """Build the callback class lazily, so that the ``dashscope`` package
    is only required when the realtime TTS model is actually used.

    Returns:
        `type[QwenTtsRealtimeCallback]`:
            The callback class used by `DashScopeRealtimeTTSModel`.
    """
    import asyncio
    import base64
    import traceback

    from dashscope.audio.qwen_tts_realtime import QwenTtsRealtimeCallback

    class _DashScopeRealtimeTTSCallback(QwenTtsRealtimeCallback):
        """DashScope Realtime TTS callback that accumulates the audio
        chunks pushed by the service from the SDK's worker thread."""

        def __init__(self) -> None:
            """Initialize the DashScope Realtime TTS callback."""
            super().__init__()

            # Set whenever a new audio chunk has been received.
            self.chunk_event = threading.Event()
            # Set when the TTS synthesis session is finished.
            self.finish_event = threading.Event()
            # The base64-encoded audio accumulated so far.
            self._audio_data: str = ""

        def on_event(self, response: dict[str, Any]) -> None:
            """Handle a TTS event (invoked by the DashScope SDK thread).

            Args:
                response (`dict[str, Any]`):
                    The event response dictionary.
            """
            try:
                event_type = response.get("type")

                if event_type == "session.created":
                    # A new session starts with an empty audio buffer.
                    self._audio_data = ""
                    self.finish_event.clear()

                elif event_type == "response.audio.delta":
                    audio_data = response.get("delta")
                    if audio_data:
                        # The SDK may deliver raw bytes; normalize to a
                        # base64 string before accumulating.
                        if isinstance(audio_data, bytes):
                            audio_data = base64.b64encode(audio_data).decode()

                        self._audio_data += audio_data

                        # Wake up any consumer waiting for a chunk.
                        if not self.chunk_event.is_set():
                            self.chunk_event.set()

                elif event_type == "response.done":
                    # Response completed; reserved for usage metrics.
                    pass

                elif event_type == "session.finished":
                    self.chunk_event.set()
                    self.finish_event.set()

            except Exception:
                traceback.print_exc()
                # Unblock any waiters so they do not hang forever.
                self.finish_event.set()

        def _build_response(self, is_last: bool = True) -> TTSResponse:
            """Wrap the accumulated audio into a `TTSResponse`."""
            return TTSResponse(
                content=AudioBlock(
                    type="audio",
                    source=Base64Source(
                        type="base64",
                        data=self._audio_data,
                        media_type="audio/pcm;rate=24000",
                    ),
                ),
                is_last=is_last,
            )

        async def get_audio_data(self, block: bool) -> TTSResponse:
            """Get the audio accumulated so far as a `TTSResponse`.

            Args:
                block (`bool`):
                    When `True`, wait until synthesis is finished and
                    reset the internal state afterwards; when `False`,
                    return immediately with whatever has arrived.

            Returns:
                `TTSResponse`:
                    The response holding the base64 audio block, or a
                    response with `content=None` if no audio arrived.
            """
            if block:
                # Event.wait() blocks the calling thread; run it in a
                # worker thread so the event loop keeps serving other
                # coroutines.
                await asyncio.get_running_loop().run_in_executor(
                    None,
                    self.finish_event.wait,
                )

            if self._audio_data:
                res = self._build_response()
                if block:
                    # The request is complete: clear the buffer so the
                    # next request does not replay stale audio.
                    await self._reset()
                return res

            if block:
                # Finished with no audio: still reset for the next request.
                await self._reset()

            # Return empty response if no audio data yet
            return TTSResponse(content=None)

        async def get_audio_chunk(self) -> AsyncGenerator[TTSResponse, None]:
            """Yield the cumulative audio as it is synthesized.

            Yields:
                `TTSResponse`:
                    Responses carrying the audio received so far; the
                    final one has ``is_last=True``.
            """
            while True:
                if self.finish_event.is_set():
                    yield self._build_response(is_last=True)

                    # Reset for the next tts request
                    await self._reset()
                    break

                if self.chunk_event.is_set():
                    # Consume the pending chunk notification.
                    self.chunk_event.clear()
                else:
                    # Wait off the event loop for the next chunk.
                    await asyncio.get_running_loop().run_in_executor(
                        None,
                        self.chunk_event.wait,
                    )

                yield self._build_response(is_last=False)

        async def _reset(self) -> None:
            """Reset the callback state for a new TTS request."""
            self.finish_event.clear()
            self.chunk_event.clear()
            self._audio_data = ""

    return _DashScopeRealtimeTTSCallback
class DashScopeRealtimeTTSModel(TTSModelBase):
    """TTS implementation for DashScope Qwen Realtime TTS API, which
    supports streaming input. The supported models include
    "qwen3-tts-flash-realtime", "qwen-tts-realtime", etc.

    For more details, please see the `official document `_.

    .. note:: The DashScopeRealtimeTTSModel can only handle one streaming
        input request at a time, and cannot process multiple streaming
        input requests concurrently. For example, it cannot handle input
        sequences like `[msg_1_chunk0, msg_1_chunk1, msg_2_chunk0]`, where
        the prefixes "msg_x" indicate different streaming input requests.
    """

    supports_streaming_input: bool = True
    """Whether the model supports streaming input."""

    def __init__(
        self,
        api_key: str,
        model_name: str = "qwen3-tts-flash-realtime",
        voice: Literal["Cherry", "Nofish", "Ethan", "Jennifer"]
        | str = "Cherry",
        stream: bool = True,
        cold_start_length: int | None = None,
        cold_start_words: int | None = None,
        client_kwargs: dict[str, JSONSerializableObject] | None = None,
        generate_kwargs: dict[str, JSONSerializableObject] | None = None,
    ) -> None:
        """Initialize the DashScope realtime TTS model by specifying the
        model, voice, and other parameters.

        .. note:: More details about the parameters, such as `model_name`
            and `voice`, can be found in the `official document `_.

        .. note:: You can use `cold_start_length` and `cold_start_words`
            simultaneously to set both character and word thresholds for
            the first TTS request. For Chinese text, word segmentation
            (based on spaces) may not be effective.

        Args:
            api_key (`str`):
                The DashScope API key.
            model_name (`str`, defaults to "qwen3-tts-flash-realtime"):
                The TTS model name, e.g. "qwen3-tts-flash-realtime",
                "qwen-tts-realtime", etc.
            voice (`Literal["Cherry", "Nofish", "Ethan", "Jennifer"] | \
             str`, defaults to "Cherry"):
                The voice to use for synthesis. Refer to the `official
                document `_ for the supported voices of each model.
            stream (`bool`, defaults to `True`):
                Whether to use streaming synthesis.
            cold_start_length (`int | None`, optional):
                The minimum length send threshold for the first TTS
                request, ensuring there is no pause in the synthesized
                speech for too short input text. Measured in characters.
            cold_start_words (`int | None`, optional):
                The minimum words send threshold for the first TTS
                request. Words are identified by spaces in the text.
            client_kwargs (`dict[str, JSONSerializableObject] | None`, \
             optional):
                The extra keyword arguments to initialize the DashScope
                realtime tts client.
            generate_kwargs (`dict[str, JSONSerializableObject] | None`, \
             optional):
                The extra keyword arguments used in DashScope realtime
                tts API generation.
        """
        super().__init__(model_name=model_name, stream=stream)

        import dashscope
        from dashscope.audio.qwen_tts_realtime import QwenTtsRealtime

        dashscope.api_key = api_key

        # Store configuration
        self.voice = voice
        self.mode = "server_commit"
        self.cold_start_length = cold_start_length
        self.cold_start_words = cold_start_words
        self.client_kwargs = client_kwargs or {}
        self.generate_kwargs = generate_kwargs or {}

        # The SDK callback accumulates the synthesized audio chunks.
        self._dashscope_callback = _get_qwen_tts_realtime_callback_class()()
        self._tts_client: QwenTtsRealtime = QwenTtsRealtime(
            model=self.model_name,
            callback=self._dashscope_callback,
            **self.client_kwargs,
        )

        self._connected = False

        # Streaming-input bookkeeping:
        # whether no text has been sent yet for the current message
        self._first_send: bool = True
        # the ID of the message currently being processed
        self._current_msg_id: str | None = None
        # the prefix of the current message's text already sent
        self._current_prefix: str = ""

    async def connect(self) -> None:
        """Establish the realtime connection and configure the session."""
        if self._connected:
            return

        self._tts_client.connect()

        # Update session with voice and format settings
        self._tts_client.update_session(
            voice=self.voice,
            mode=self.mode,
            **self.generate_kwargs,
        )

        self._connected = True

    async def close(self) -> None:
        """Close the TTS model and clean up resources."""
        if not self._connected:
            return

        self._connected = False

        self._tts_client.finish()
        self._tts_client.close()

    async def push(
        self,
        msg: Msg,
        **kwargs: Any,
    ) -> TTSResponse:
        """Append text to be synthesized and return the received TTS
        response. Note this method is non-blocking, and may return an
        empty response if no audio has been received yet.

        To receive all the synthesized speech, call the `synthesize`
        method after pushing all the text chunks.

        Args:
            msg (`Msg`):
                The message to be synthesized. The `msg.id` identifies
                the streaming input request.
            **kwargs (`Any`):
                Additional keyword arguments to pass to the TTS API call.

        Returns:
            `TTSResponse`:
                The TTSResponse containing the audio accumulated so far.
        """
        if not self._connected:
            raise RuntimeError(
                "TTS model is not connected. Call `connect()` first.",
            )

        if self._current_msg_id is not None and self._current_msg_id != msg.id:
            raise RuntimeError(
                "DashScopeRealtimeTTSModel can only handle one streaming "
                "input request at a time. Please ensure that all chunks "
                "belong to the same message ID.",
            )

        # Record current message ID
        self._current_msg_id = msg.id

        text = msg.get_text_content()

        if text:
            if self._first_send:
                # Apply the cold-start thresholds only to the first send,
                # so that very short openings do not cause audible pauses.
                if self.cold_start_length and len(text) < self.cold_start_length:
                    delta_to_send = ""
                else:
                    delta_to_send = text

                if delta_to_send and self.cold_start_words:
                    if len(delta_to_send.split()) < self.cold_start_words:
                        delta_to_send = ""
            else:
                # Remove the already sent prefix on subsequent sends.
                delta_to_send = text.removeprefix(self._current_prefix)

            if delta_to_send:
                self._tts_client.append_text(delta_to_send)

                # Record sent prefix
                self._current_prefix += delta_to_send
                self._first_send = False

            # Non-blocking: return whatever audio has arrived so far.
            return await self._dashscope_callback.get_audio_data(block=False)

        # Return empty response if no text to send
        return TTSResponse(content=None)

    async def synthesize(
        self,
        msg: Msg | None = None,
        **kwargs: Any,
    ) -> TTSResponse | AsyncGenerator[TTSResponse, None]:
        """Append the remaining text (if any) and block until the full
        speech is synthesized.

        Args:
            msg (`Msg | None`, optional):
                The message to be synthesized. If `None`, only the
                previously pushed text is finalized.
            **kwargs (`Any`):
                Additional keyword arguments to pass to the TTS API call.

        Returns:
            `TTSResponse | AsyncGenerator[TTSResponse, None]`:
                The TTSResponse object in non-streaming mode, or an async
                generator yielding TTSResponse objects in streaming mode.
        """
        if not self._connected:
            raise RuntimeError(
                "TTS model is not connected. Call `connect()` first.",
            )

        if msg is None:
            # Nothing new to send; just finalize the pending request.
            delta_to_send = ""
        else:
            # Check the ID only when a message is given; the original
            # order dereferenced `msg.id` before the None check.
            if (
                self._current_msg_id is not None
                and self._current_msg_id != msg.id
            ):
                raise RuntimeError(
                    "DashScopeRealtimeTTSModel can only handle one streaming "
                    "input request at a time. Please ensure that all chunks "
                    "belong to the same message ID.",
                )

            # Record current message ID
            self._current_msg_id = msg.id
            delta_to_send = (msg.get_text_content() or "").removeprefix(
                self._current_prefix,
            )

        if delta_to_send:
            self._tts_client.append_text(delta_to_send)

            # Keep prefix tracking correct for this request.
            self._current_prefix += delta_to_send
            self._first_send = False

        # Commit and finish so the service flushes all remaining audio.
        self._tts_client.commit()
        self._tts_client.finish()

        if self.stream:
            # Return an async generator over the audio chunks.
            res = self._dashscope_callback.get_audio_chunk()
        else:
            # Block and wait for all audio data to be available.
            res = await self._dashscope_callback.get_audio_data(block=True)

        # Reset the streaming-input state for the next message.
        self._current_msg_id = None
        self._first_send = True
        self._current_prefix = ""

        return res
class DashScopeTTSModel(TTSModelBase):
    """DashScope TTS model implementation using the MultiModalConversation
    API. For more details, please see the `official document `_.
    """

    supports_streaming_input: bool = False
    """Whether the model supports streaming input."""

    def __init__(
        self,
        api_key: str,
        model_name: str = "qwen3-tts-flash",
        voice: Literal["Cherry", "Serena", "Ethan", "Chelsie"]
        | str = "Cherry",
        language_type: str = "Auto",
        stream: bool = True,
        generate_kwargs: dict[str, JSONSerializableObject] | None = None,
    ) -> None:
        """Initialize the DashScope SDK TTS model.

        .. note:: More details about the parameters, such as `model_name`,
            `voice`, and `language_type` can be found in the `official
            document `_.

        Args:
            api_key (`str`):
                The DashScope API key. Required.
            model_name (`str`, defaults to "qwen3-tts-flash"):
                The TTS model name. Supported models are qwen3-tts-flash,
                qwen-tts, etc.
            voice (`Literal["Cherry", "Serena", "Ethan", "Chelsie"] | \
             str`, defaults to "Cherry"):
                The voice to use. Supported voices are "Cherry", "Serena",
                "Ethan", "Chelsie", etc.
            language_type (`str`, defaults to "Auto"):
                The language type. Should match the text language for
                correct pronunciation and natural intonation.
            stream (`bool`, defaults to `True`):
                Whether to return an async generator of audio chunks
                (`True`) or a single response (`False`).
            generate_kwargs (`dict[str, JSONSerializableObject] | None`, \
             optional):
                The extra keyword arguments used in Dashscope TTS API
                generation, e.g. `temperature`, `seed`.
        """
        super().__init__(model_name=model_name, stream=stream)

        self.api_key = api_key
        self.voice = voice
        self.language_type = language_type
        self.generate_kwargs = generate_kwargs or {}

    async def synthesize(
        self,
        msg: Msg | None = None,
        **kwargs: Any,
    ) -> TTSResponse | AsyncGenerator[TTSResponse, None]:
        """Call the DashScope TTS API to synthesize speech from text.

        Args:
            msg (`Msg | None`, optional):
                The message to be synthesized.
            **kwargs (`Any`):
                Additional keyword arguments to pass to the TTS API call.

        Returns:
            `TTSResponse | AsyncGenerator[TTSResponse, None]`:
                The TTS response, or an async generator yielding
                TTSResponse objects in streaming mode.
        """
        if msg is None:
            return TTSResponse(content=None)

        text = msg.get_text_content()

        import dashscope

        # Always request a streaming response from the API; the
        # non-streaming branch below just drains it into one response.
        # NOTE(review): this is a synchronous SDK call iterated inside an
        # async method — it blocks the event loop; confirm acceptable.
        response = dashscope.MultiModalConversation.call(
            model=self.model_name,
            api_key=self.api_key,
            text=text,
            voice=self.voice,
            language_type=self.language_type,
            stream=True,
            **self.generate_kwargs,
            **kwargs,
        )

        if self.stream:
            return self._parse_into_async_generator(response)

        audio_data = ""
        for chunk in response:
            if chunk.output is not None:
                audio = chunk.output.audio
                # Guard against chunks without an audio payload, mirroring
                # the streaming branch.
                if audio and audio.data:
                    audio_data += audio.data

        return TTSResponse(
            content=AudioBlock(
                type="audio",
                source=Base64Source(
                    type="base64",
                    data=audio_data,
                    media_type="audio/pcm;rate=24000",
                ),
            ),
        )

    @staticmethod
    async def _parse_into_async_generator(
        response: Generator[MultiModalConversationResponse, None, None],
    ) -> AsyncGenerator[TTSResponse, None]:
        """Parse the TTS response into an async generator.

        Args:
            response (`Generator[MultiModalConversationResponse, None, \
             None]`):
                The streaming response from DashScope TTS API.

        Yields:
            `TTSResponse`:
                Responses carrying the audio accumulated so far; the
                final one has ``is_last=True``.
        """
        audio_data = ""
        for chunk in response:
            if chunk.output is not None:
                audio = chunk.output.audio
                if audio and audio.data:
                    audio_data += audio.data
                    yield TTSResponse(
                        content=AudioBlock(
                            type="audio",
                            source=Base64Source(
                                type="base64",
                                data=audio_data,
                                media_type="audio/pcm;rate=24000",
                            ),
                        ),
                        is_last=False,
                    )

        # Final cumulative response marks the end of the stream.
        yield TTSResponse(
            content=AudioBlock(
                type="audio",
                source=Base64Source(
                    type="base64",
                    data=audio_data,
                    media_type="audio/pcm;rate=24000",
                ),
            ),
            is_last=True,
        )
class GeminiTTSModel(TTSModelBase):
    """Gemini TTS model implementation.
    For more details, please see the `official document `_.
    """

    supports_streaming_input: bool = False
    """Whether the model supports streaming input."""

    def __init__(
        self,
        api_key: str,
        model_name: str = "gemini-2.5-flash-preview-tts",
        voice: Literal["Zephyr", "Kore", "Orus", "Autonoe"] | str = "Kore",
        stream: bool = True,
        client_kwargs: dict[str, JSONSerializableObject] | None = None,
        generate_kwargs: dict[str, JSONSerializableObject] | None = None,
    ) -> None:
        """Initialize the Gemini TTS model.

        .. note::
            More details about the parameters, such as `model_name` and
            `voice` can be found in the `official document `_.

        Args:
            api_key (`str`):
                The Gemini API key.
            model_name (`str`, defaults to "gemini-2.5-flash-preview-tts"):
                The TTS model name. Supported models are
                "gemini-2.5-flash-preview-tts",
                "gemini-2.5-pro-preview-tts", etc.
            voice (`Literal["Zephyr", "Kore", "Orus", "Autonoe"] | str`, \
             defaults to "Kore"):
                The voice name to use. Supported voices are "Zephyr",
                "Kore", "Orus", "Autonoe", etc.
            stream (`bool`, defaults to `True`):
                Whether to use streaming synthesis if supported by the
                model.
            client_kwargs (`dict[str, JSONSerializableObject] | None`, \
             optional):
                The extra keyword arguments to initialize the Gemini
                client.
            generate_kwargs (`dict[str, JSONSerializableObject] | None`, \
             optional):
                The extra keyword arguments used in Gemini API generation,
                e.g. `temperature`, `seed`.
        """
        super().__init__(model_name=model_name, stream=stream)

        self.api_key = api_key
        self.voice = voice

        from google import genai

        self._client = genai.Client(
            api_key=self.api_key,
            **(client_kwargs or {}),
        )

        self.generate_kwargs = generate_kwargs or {}

    async def synthesize(
        self,
        msg: Msg | None = None,
        **kwargs: Any,
    ) -> TTSResponse | AsyncGenerator[TTSResponse, None]:
        """Synthesize speech from the given message.

        Args:
            msg (`Msg | None`, optional):
                The message to be synthesized.
            **kwargs (`Any`):
                Additional keyword arguments to pass to the TTS API call.

        Returns:
            `TTSResponse | AsyncGenerator[TTSResponse, None]`:
                The TTSResponse object in non-streaming mode, or an async
                generator yielding TTSResponse objects in streaming mode.
        """
        if msg is None:
            return TTSResponse(content=None)

        from google.genai import types

        text = msg.get_text_content()

        # Prepare config
        config = types.GenerateContentConfig(
            response_modalities=["AUDIO"],
            speech_config=types.SpeechConfig(
                voice_config=types.VoiceConfig(
                    prebuilt_voice_config=types.PrebuiltVoiceConfig(
                        voice_name=self.voice,
                    ),
                ),
            ),
            **self.generate_kwargs,
            **kwargs,
        )

        # Prepare API kwargs
        api_kwargs: dict[str, JSONSerializableObject] = {
            "model": self.model_name,
            "contents": text,
            "config": config,
        }

        if self.stream:
            response = self._client.models.generate_content_stream(
                **api_kwargs,
            )
            return self._parse_into_async_generator(response)

        # Call Gemini TTS API
        response = self._client.models.generate_content(**api_kwargs)

        # Extract audio data
        if (
            response.candidates
            and response.candidates[0].content
            and response.candidates[0].content.parts
            and response.candidates[0].content.parts[0].inline_data
        ):
            inline_data = response.candidates[0].content.parts[0].inline_data
            # Convert the raw PCM bytes to base64
            audio_base64 = base64.b64encode(inline_data.data).decode("utf-8")

            return TTSResponse(
                content=AudioBlock(
                    type="audio",
                    source=Base64Source(
                        type="base64",
                        data=audio_base64,
                        media_type=inline_data.mime_type,
                    ),
                ),
            )

        # No audio returned: yield an empty audio block so callers still
        # receive a well-formed response.
        return TTSResponse(
            content=AudioBlock(
                type="audio",
                source=Base64Source(
                    type="base64",
                    data="",
                    media_type="audio/pcm;rate=24000",
                ),
            ),
        )

    @staticmethod
    async def _parse_into_async_generator(
        response: Iterator[GenerateContentResponse],
    ) -> AsyncGenerator[TTSResponse, None]:
        """Parse the TTS response into an async generator.

        Args:
            response (`Iterator[GenerateContentResponse]`):
                The streaming response from Gemini TTS API.

        Yields:
            `TTSResponse`:
                Responses carrying the base64 audio accumulated so far,
                followed by a final marker response with `content=None`
                and ``is_last=True``.
        """
        # Accumulate the raw bytes and re-encode the whole buffer on each
        # yield: concatenating base64 *strings* is only valid when every
        # chunk's byte length is a multiple of three, so per-chunk string
        # concatenation would produce corrupt base64.
        audio_bytes = b""
        mime_type = "audio/pcm;rate=24000"
        for chunk in response:
            if not (
                chunk.candidates
                and chunk.candidates[0].content
                and chunk.candidates[0].content.parts
                and chunk.candidates[0].content.parts[0].inline_data
            ):
                # Skip chunks that carry no audio payload.
                continue

            inline_data = chunk.candidates[0].content.parts[0].inline_data
            audio_bytes += inline_data.data
            mime_type = inline_data.mime_type or mime_type

            yield TTSResponse(
                content=AudioBlock(
                    type="audio",
                    source=Base64Source(
                        type="base64",
                        data=base64.b64encode(audio_bytes).decode("utf-8"),
                        media_type=mime_type,
                    ),
                ),
                # Intermediate chunks are not the last response.
                is_last=False,
            )

        # Terminal marker (is_last defaults to True).
        yield TTSResponse(content=None)
class OpenAITTSModel(TTSModelBase):
    """OpenAI TTS model implementation.
    For more details, please see the `official document `_.
    """

    # This model does not support streaming input (requires complete text)
    supports_streaming_input: bool = False

    def __init__(
        self,
        api_key: str,
        model_name: str = "gpt-4o-mini-tts",
        voice: Literal["alloy", "ash", "ballad", "coral"] | str = "alloy",
        stream: bool = True,
        client_kwargs: dict | None = None,
        generate_kwargs: dict[str, JSONSerializableObject] | None = None,
    ) -> None:
        """Initialize the OpenAI TTS model.

        .. note::
            More details about the parameters, such as `model_name` and
            `voice` can be found in the `official document `_.

        Args:
            api_key (`str`):
                The OpenAI API key.
            model_name (`str`, defaults to "gpt-4o-mini-tts"):
                The TTS model name. Supported models are "gpt-4o-mini-tts",
                "tts-1", etc.
            voice (`Literal["alloy", "ash", "ballad", "coral"] | str`, \
             defaults to "alloy"):
                The voice to use. Supported voices are "alloy", "ash",
                "ballad", "coral", etc.
            stream (`bool`, defaults to `True`):
                Whether to stream the synthesized audio chunk by chunk.
            client_kwargs (`dict | None`, defaults to `None`):
                The extra keyword arguments to initialize the OpenAI
                client.
            generate_kwargs (`dict[str, JSONSerializableObject] | None`, \
             optional):
                The extra keyword arguments used in OpenAI API generation,
                e.g. `temperature`, `seed`.
        """
        # The base class stores both `model_name` and `stream`.
        super().__init__(model_name=model_name, stream=stream)

        self.api_key = api_key
        self.voice = voice

        import openai

        self._client = openai.AsyncOpenAI(
            api_key=self.api_key,
            **(client_kwargs or {}),
        )

        self.generate_kwargs = generate_kwargs or {}

    async def synthesize(
        self,
        msg: Msg | None = None,
        **kwargs: Any,
    ) -> TTSResponse | AsyncGenerator[TTSResponse, None]:
        """Synthesize speech from the given message.

        Args:
            msg (`Msg | None`, optional):
                The message to be synthesized.
            **kwargs (`Any`):
                Additional keyword arguments to pass to the TTS API call.

        Returns:
            `TTSResponse | AsyncGenerator[TTSResponse, None]`:
                The TTSResponse object in non-streaming mode, or an async
                generator yielding TTSResponse objects in streaming mode.
        """
        if msg is None:
            return TTSResponse(content=None)

        text = msg.get_text_content()

        if not text:
            return TTSResponse(content=None)

        if self.stream:
            # Request PCM in the streaming path as well, so the
            # `audio/pcm` media type attached to the yielded blocks is
            # accurate and matches the non-streaming branch (the original
            # requested "mp3" while labelling the chunks "audio/pcm").
            response = (
                self._client.audio.speech.with_streaming_response.create(
                    model=self.model_name,
                    voice=self.voice,
                    input=text,
                    response_format="pcm",
                    **self.generate_kwargs,
                    **kwargs,
                )
            )
            return self._parse_into_async_generator(response)

        response = await self._client.audio.speech.create(
            model=self.model_name,
            voice=self.voice,
            input=text,
            response_format="pcm",
            **self.generate_kwargs,
            **kwargs,
        )

        audio_base64 = base64.b64encode(response.content).decode("utf-8")
        return TTSResponse(
            content=AudioBlock(
                type="audio",
                source=Base64Source(
                    type="base64",
                    data=audio_base64,
                    media_type="audio/pcm",
                ),
            ),
        )

    @staticmethod
    async def _parse_into_async_generator(
        response: HttpxBinaryResponseContent,
    ) -> AsyncGenerator[TTSResponse, None]:
        """Parse the streaming response into an async generator of
        TTSResponse objects.

        Args:
            response (`HttpxBinaryResponseContent`):
                The streaming response from OpenAI TTS API.

        Yields:
            `TTSResponse`:
                Responses carrying the base64 audio accumulated so far;
                the final one holds the full audio with ``is_last=True``.
                This cumulative contract matches the other TTS models.
        """
        # Accumulate raw bytes and re-encode the whole buffer per yield:
        # base64 strings of arbitrary byte chunks cannot be concatenated.
        audio_bytes = b""
        async with response as stream:
            async for chunk in stream.iter_bytes():
                if not chunk:
                    continue

                audio_bytes += chunk

                yield TTSResponse(
                    content=AudioBlock(
                        type="audio",
                        source=Base64Source(
                            type="base64",
                            data=base64.b64encode(audio_bytes).decode(
                                "utf-8",
                            ),
                            media_type="audio/pcm",
                        ),
                    ),
                    is_last=False,
                )

        # Final response: full audio with is_last=True to end the stream.
        yield TTSResponse(
            content=AudioBlock(
                type="audio",
                source=Base64Source(
                    type="base64",
                    data=base64.b64encode(audio_bytes).decode("utf-8"),
                    media_type="audio/pcm",
                ),
            ),
            is_last=True,
        )
class TTSModelBase(ABC):
    """Base class for TTS models in AgentScope.

    This base class provides a general abstraction for both realtime and
    non-realtime TTS models (depending on whether streaming input is
    supported).

    For non-realtime TTS models, only the `synthesize` method needs to be
    implemented: it synthesizes speech from the input text in one shot.

    For realtime TTS models, the lifecycle is managed via the async
    context manager or by calling the `connect` and `close` methods. The
    `push` method appends text chunks and returns the TTS response
    received so far, while `synthesize` blocks until the full speech is
    synthesized. Realtime models implement `connect`, `close`, `push`,
    and `synthesize`.
    """

    # If the TTS model class supports streaming input.
    supports_streaming_input: bool = False

    # The name of the TTS model.
    model_name: str

    # Whether to use streaming synthesis if supported by the model.
    stream: bool

    def __init__(self, model_name: str, stream: bool) -> None:
        """Initialize the TTS model base class.

        Args:
            model_name (`str`):
                The name of the TTS model.
            stream (`bool`):
                Whether to use streaming synthesis if supported by the
                model.
        """
        self.model_name = model_name
        self.stream = stream

    async def __aenter__(self) -> "TTSModelBase":
        """Enter the async context manager, connecting realtime models."""
        if self.supports_streaming_input:
            await self.connect()

        return self

    async def __aexit__(
        self,
        exc_type: Any,
        exc_value: Any,
        traceback: Any,
    ) -> None:
        """Exit the async context manager, closing realtime models."""
        if self.supports_streaming_input:
            await self.close()

    async def connect(self) -> None:
        """Connect to the TTS model and initialize resources.

        .. note:: Only needs to be implemented for realtime TTS models;
            non-realtime models never call it.
        """
        raise NotImplementedError(
            f"The connect method is not implemented for "
            f"{self.__class__.__name__} class.",
        )

    async def close(self) -> None:
        """Close the connection to the TTS model and clean up resources.

        .. note:: Only needs to be implemented for realtime TTS models;
            non-realtime models never call it.
        """
        raise NotImplementedError(
            "The close method is not implemented for "
            f"{self.__class__.__name__} class.",
        )

    async def push(
        self,
        msg: Msg,
        **kwargs: Any,
    ) -> TTSResponse:
        """Append text to be synthesized and return the received TTS
        response. Note this method is non-blocking, and may return an
        empty response if no audio has been received yet.

        To receive all the synthesized speech, call the `synthesize`
        method after pushing all the text chunks.

        .. note:: Only needs to be implemented for realtime TTS models.

        Args:
            msg (`Msg`):
                The message to be synthesized. The `msg.id` identifies
                the streaming input request.
            **kwargs (`Any`):
                Additional keyword arguments to pass to the TTS API call.

        Returns:
            `TTSResponse`:
                The TTSResponse containing audio block.
        """
        raise NotImplementedError(
            "The push method is not implemented for "
            f"{self.__class__.__name__} class.",
        )

    @abstractmethod
    async def synthesize(
        self,
        msg: Msg | None = None,
        **kwargs: Any,
    ) -> TTSResponse | AsyncGenerator[TTSResponse, None]:
        """Synthesize speech from the appended text. Different from the
        `push` method, this method will block until the full speech is
        synthesized.

        Args:
            msg (`Msg | None`, defaults to `None`):
                The message to be synthesized. If `None`, this method
                will wait for all previously pushed text to be
                synthesized, and return the last synthesized TTSResponse.

        Returns:
            `TTSResponse | AsyncGenerator[TTSResponse, None]`:
                The TTSResponse containing audio blocks, or an async
                generator yielding TTSResponse objects in streaming mode.
        """
+ """ diff --git a/src/agentscope/tts/_tts_response.py b/src/agentscope/tts/_tts_response.py new file mode 100644 index 0000000000..7d4c4fcac2 --- /dev/null +++ b/src/agentscope/tts/_tts_response.py @@ -0,0 +1,55 @@ +# -*- coding: utf-8 -*- +"""The TTS response module.""" + +from dataclasses import dataclass, field +from typing import Literal + +from .._utils._common import _get_timestamp +from .._utils._mixin import DictMixin +from ..message import AudioBlock +from ..types import JSONSerializableObject + + +@dataclass +class TTSUsage(DictMixin): + """The usage of a TTS model API invocation.""" + + input_tokens: int + """The number of input tokens.""" + + output_tokens: int + """The number of output tokens.""" + + time: float + """The time used in seconds.""" + + type: Literal["tts"] = field(default_factory=lambda: "tts") + """The type of the usage, must be `tts`.""" + + +@dataclass +class TTSResponse(DictMixin): + """The response of TTS models.""" + + content: AudioBlock | None + """The content of the TTS response, which contains audio block""" + + id: str = field(default_factory=lambda: _get_timestamp(True)) + """The unique identifier of the response.""" + + created_at: str = field(default_factory=_get_timestamp) + """When the response was created.""" + + type: Literal["tts"] = field(default_factory=lambda: "tts") + """The type of the response, which is always 'tts'.""" + + usage: TTSUsage | None = field(default_factory=lambda: None) + """The usage information of the TTS response, if available.""" + + metadata: dict[str, JSONSerializableObject] | None = field( + default_factory=lambda: None, + ) + """The metadata of the TTS response.""" + + is_last: bool = True + """Whether this is the last response in a stream of TTS responses.""" diff --git a/tests/tts_dashscope_test.py b/tests/tts_dashscope_test.py new file mode 100644 index 0000000000..8a37f58bd7 --- /dev/null +++ b/tests/tts_dashscope_test.py @@ -0,0 +1,321 @@ +# -*- coding: utf-8 -*- +# pylint: 
disable=protected-access +"""The unittests for DashScope TTS models.""" +import base64 +from typing import AsyncGenerator +from unittest import IsolatedAsyncioTestCase +from unittest.mock import Mock, patch, AsyncMock, MagicMock + +from agentscope.message import Msg, AudioBlock, Base64Source +from agentscope.tts import ( + DashScopeRealtimeTTSModel, + DashScopeTTSModel, + TTSResponse, +) + + +class DashScopeRealtimeTTSModelTest(IsolatedAsyncioTestCase): + """The unittests for DashScope Realtime TTS model.""" + + def setUp(self) -> None: + """Set up the test case.""" + self.api_key = "test_api_key" + self.mock_audio_data = base64.b64encode(b"fake_audio_data").decode( + "utf-8", + ) + + def _create_mock_tts_client(self) -> Mock: + """Create a mock QwenTtsRealtime client.""" + mock_client = Mock() + mock_client.connect = Mock() + mock_client.close = Mock() + mock_client.finish = Mock() + mock_client.update_session = Mock() + mock_client.append_text = Mock() + return mock_client + + def _create_mock_dashscope_modules(self) -> dict: + """Create mock dashscope modules for patching.""" + mock_qwen_tts_realtime = MagicMock() + mock_qwen_tts_realtime.QwenTtsRealtime = Mock + mock_qwen_tts_realtime.QwenTtsRealtimeCallback = Mock + + mock_audio = MagicMock() + mock_audio.qwen_tts_realtime = mock_qwen_tts_realtime + + mock_dashscope = MagicMock() + mock_dashscope.api_key = None + mock_dashscope.audio = mock_audio + + return { + "dashscope": mock_dashscope, + "dashscope.audio": mock_audio, + "dashscope.audio.qwen_tts_realtime": mock_qwen_tts_realtime, + } + + def test_init(self) -> None: + """Test initialization of DashScopeRealtimeTTSModel.""" + mock_modules = self._create_mock_dashscope_modules() + mock_tts_client = self._create_mock_tts_client() + mock_tts_class = Mock(return_value=mock_tts_client) + mock_modules[ + "dashscope.audio.qwen_tts_realtime" + ].QwenTtsRealtime = mock_tts_class + + with patch.dict("sys.modules", mock_modules): + model = DashScopeRealtimeTTSModel( + 
api_key=self.api_key, + stream=False, + ) + self.assertEqual(model.model_name, "qwen3-tts-flash-realtime") + self.assertFalse(model.stream) + self.assertFalse(model._connected) + + async def test_push_incremental_text(self) -> None: + """Test push method with incremental text chunks.""" + mock_modules = self._create_mock_dashscope_modules() + mock_client = self._create_mock_tts_client() + mock_tts_class = Mock(return_value=mock_client) + mock_modules[ + "dashscope.audio.qwen_tts_realtime" + ].QwenTtsRealtime = mock_tts_class + + with patch.dict("sys.modules", mock_modules): + async with DashScopeRealtimeTTSModel( + api_key=self.api_key, + stream=False, + ) as model: + # Mock callback to return audio data + model._dashscope_callback.get_audio_data = AsyncMock( + return_value=TTSResponse( + content=AudioBlock( + type="audio", + source=Base64Source( + type="base64", + data=self.mock_audio_data, + media_type="audio/pcm;rate=24000", + ), + ), + ), + ) + + msg_id = "test_msg_001" + text_chunks = ["Hello there!\n\n", "This is a test message."] + + accumulated_text = "" + for chunk in text_chunks: + accumulated_text += chunk + msg = Msg( + name="user", + content=accumulated_text, + role="user", + ) + msg.id = msg_id + + response = await model.push(msg) + self.assertIsInstance(response, TTSResponse) + + # Verify append_text was called + self.assertGreater(mock_client.append_text.call_count, 0) + + async def test_synthesize_non_streaming(self) -> None: + """Test synthesize method in non-streaming mode.""" + mock_modules = self._create_mock_dashscope_modules() + mock_client = self._create_mock_tts_client() + mock_tts_class = Mock(return_value=mock_client) + mock_modules[ + "dashscope.audio.qwen_tts_realtime" + ].QwenTtsRealtime = mock_tts_class + + with patch.dict("sys.modules", mock_modules): + async with DashScopeRealtimeTTSModel( + api_key=self.api_key, + stream=False, + ) as model: + model._dashscope_callback.get_audio_data = AsyncMock( + return_value=TTSResponse( + 
content=AudioBlock( + type="audio", + source=Base64Source( + type="base64", + data=self.mock_audio_data, + media_type="audio/pcm;rate=24000", + ), + ), + ), + ) + + msg = Msg( + name="user", + content="Hello! Test message.", + role="user", + ) + response = await model.synthesize(msg) + + self.assertIsInstance(response, TTSResponse) + self.assertEqual(response.content["type"], "audio") + + async def test_synthesize_streaming(self) -> None: + """Test synthesize method in streaming mode.""" + mock_modules = self._create_mock_dashscope_modules() + mock_client = self._create_mock_tts_client() + mock_tts_class = Mock(return_value=mock_client) + mock_modules[ + "dashscope.audio.qwen_tts_realtime" + ].QwenTtsRealtime = mock_tts_class + + with patch.dict("sys.modules", mock_modules): + async with DashScopeRealtimeTTSModel( + api_key=self.api_key, + stream=True, + ) as model: + + async def mock_generator() -> AsyncGenerator[ + TTSResponse, + None, + ]: + yield TTSResponse( + content=AudioBlock( + type="audio", + source=Base64Source( + type="base64", + data=self.mock_audio_data, + media_type="audio/pcm;rate=24000", + ), + ), + ) + yield TTSResponse(content=None) + + model._dashscope_callback.get_audio_chunk = mock_generator + + msg = Msg(name="user", content="Test streaming.", role="user") + response = await model.synthesize(msg) + + self.assertIsInstance(response, AsyncGenerator) + chunk_count = 0 + async for chunk in response: + self.assertIsInstance(chunk, TTSResponse) + chunk_count += 1 + self.assertGreater(chunk_count, 0) + + +class DashScopeTTSModelTest(IsolatedAsyncioTestCase): + """The unittests for DashScope TTS model (non-realtime).""" + + def setUp(self) -> None: + """Set up the test case.""" + self.api_key = "test_api_key" + self.mock_audio_data = "ZmFrZV9hdWRpb19kYXRh" # base64 encoded + + def _create_mock_response_chunk(self, audio_data: str) -> Mock: + """Create a mock response chunk.""" + mock_chunk = Mock() + mock_chunk.output = Mock() + 
mock_chunk.output.audio = Mock() + mock_chunk.output.audio.data = audio_data + return mock_chunk + + def test_init(self) -> None: + """Test initialization of DashScopeTTSModel.""" + model = DashScopeTTSModel( + api_key=self.api_key, + model_name="qwen3-tts-flash", + voice="Cherry", + stream=False, + ) + self.assertEqual(model.model_name, "qwen3-tts-flash") + self.assertEqual(model.voice, "Cherry") + self.assertFalse(model.stream) + self.assertFalse(model.supports_streaming_input) + + async def test_synthesize_non_streaming(self) -> None: + """Test synthesize method in non-streaming mode.""" + model = DashScopeTTSModel( + api_key=self.api_key, + stream=False, + ) + + mock_chunks = [ + self._create_mock_response_chunk("audio1"), + self._create_mock_response_chunk("audio2"), + ] + + with patch("dashscope.MultiModalConversation.call") as mock_call: + mock_call.return_value = iter(mock_chunks) + + msg = Msg(name="user", content="Hello! Test message.", role="user") + response = await model.synthesize(msg) + + expected_content = AudioBlock( + type="audio", + source=Base64Source( + type="base64", + data="audio1audio2", + media_type="audio/pcm;rate=24000", + ), + ) + self.assertEqual(response.content, expected_content) + + async def test_synthesize_streaming(self) -> None: + """Test synthesize method in streaming mode.""" + model = DashScopeTTSModel( + api_key=self.api_key, + stream=True, + ) + + mock_chunks = [ + self._create_mock_response_chunk("audio1"), + self._create_mock_response_chunk("audio2"), + ] + + with patch("dashscope.MultiModalConversation.call") as mock_call: + mock_call.return_value = iter(mock_chunks) + + msg = Msg(name="user", content="Test streaming.", role="user") + response = await model.synthesize(msg) + + self.assertIsInstance(response, AsyncGenerator) + chunks = [chunk async for chunk in response] + + # Should have 3 chunks: 2 from audio data + 1 final + self.assertEqual(len(chunks), 3) + + # Chunk 1: accumulated "audio1" + self.assertEqual( + 
chunks[0].content, + AudioBlock( + type="audio", + source=Base64Source( + type="base64", + data="audio1", + media_type="audio/pcm;rate=24000", + ), + ), + ) + + # Chunk 2: accumulated "audio1audio2" + self.assertEqual( + chunks[1].content, + AudioBlock( + type="audio", + source=Base64Source( + type="base64", + data="audio1audio2", + media_type="audio/pcm;rate=24000", + ), + ), + ) + + # Final chunk: complete audio data + self.assertEqual( + chunks[2].content, + AudioBlock( + type="audio", + source=Base64Source( + type="base64", + data="audio1audio2", + media_type="audio/pcm;rate=24000", + ), + ), + ) + self.assertTrue(chunks[2].is_last) diff --git a/tests/tts_gemini_test.py b/tests/tts_gemini_test.py new file mode 100644 index 0000000000..7df86df777 --- /dev/null +++ b/tests/tts_gemini_test.py @@ -0,0 +1,167 @@ +# -*- coding: utf-8 -*- +# pylint: disable=protected-access +"""The unittests for Gemini TTS model.""" +import base64 +import sys +from typing import AsyncGenerator +from unittest import IsolatedAsyncioTestCase +from unittest.mock import Mock, patch, MagicMock + +from agentscope.message import Msg, AudioBlock, Base64Source +from agentscope.tts import GeminiTTSModel + + +# Create mock google.genai modules (required for import-time patching) +mock_types = MagicMock() +mock_types.GenerateContentConfig = Mock(return_value=Mock()) +mock_types.SpeechConfig = Mock(return_value=Mock()) +mock_types.VoiceConfig = Mock(return_value=Mock()) +mock_types.PrebuiltVoiceConfig = Mock(return_value=Mock()) + +mock_genai = MagicMock() +mock_genai.Client = Mock(return_value=MagicMock()) +mock_genai.types = mock_types + +mock_google = MagicMock() +mock_google.genai = mock_genai + + +@patch.dict( + sys.modules, + { + "google": mock_google, + "google.genai": mock_genai, + "google.genai.types": mock_types, + }, +) +class GeminiTTSModelTest(IsolatedAsyncioTestCase): + """The unittests for Gemini TTS model.""" + + def setUp(self) -> None: + """Set up the test case.""" + self.api_key 
= "test_api_key" + self.mock_audio_bytes = b"fake_audio_data" + self.mock_audio_base64 = base64.b64encode( + self.mock_audio_bytes, + ).decode( + "utf-8", + ) + self.mock_mime_type = "audio/pcm;rate=24000" + + def _create_mock_response( + self, + audio_data: bytes, + mime_type: str, + ) -> MagicMock: + """Create a mock Gemini response.""" + mock = MagicMock() + mock.candidates[0].content.parts[0].inline_data.data = audio_data + mock.candidates[0].content.parts[0].inline_data.mime_type = mime_type + return mock + + def test_init(self) -> None: + """Test initialization of GeminiTTSModel.""" + model = GeminiTTSModel( + api_key=self.api_key, + model_name="gemini-2.5-flash-preview-tts", + voice="Kore", + stream=False, + ) + self.assertEqual(model.model_name, "gemini-2.5-flash-preview-tts") + self.assertEqual(model.voice, "Kore") + self.assertFalse(model.stream) + self.assertFalse(model.supports_streaming_input) + + async def test_synthesize_non_streaming(self) -> None: + """Test synthesize method in non-streaming mode.""" + model = GeminiTTSModel( + api_key=self.api_key, + stream=False, + ) + + # Mock the generate_content response + mock_response = self._create_mock_response( + self.mock_audio_bytes, + self.mock_mime_type, + ) + model._client.models.generate_content = Mock( + return_value=mock_response, + ) + + msg = Msg(name="user", content="Hello! 
Test message.", role="user") + response = await model.synthesize(msg) + + expected_content = AudioBlock( + type="audio", + source=Base64Source( + type="base64", + data=self.mock_audio_base64, + media_type=self.mock_mime_type, + ), + ) + self.assertEqual(response.content, expected_content) + + async def test_synthesize_streaming(self) -> None: + """Test synthesize method in streaming mode.""" + model = GeminiTTSModel( + api_key=self.api_key, + stream=True, + ) + + # Create mock streaming response chunks + chunk1_data = b"audio_chunk_1" + chunk2_data = b"audio_chunk_2" + mock_chunk1 = self._create_mock_response( + chunk1_data, + self.mock_mime_type, + ) + mock_chunk2 = self._create_mock_response( + chunk2_data, + self.mock_mime_type, + ) + + # Mock streaming response + model._client.models.generate_content_stream = Mock( + return_value=iter([mock_chunk1, mock_chunk2]), + ) + + msg = Msg(name="user", content="Test streaming.", role="user") + response = await model.synthesize(msg) + + self.assertIsInstance(response, AsyncGenerator) + chunks = [chunk async for chunk in response] + + # Should have 3 chunks: 2 from audio data + 1 final empty + self.assertEqual(len(chunks), 3) + + chunk1_base64 = base64.b64encode(chunk1_data).decode("utf-8") + chunk2_base64 = base64.b64encode(chunk2_data).decode("utf-8") + + # Chunk 1: accumulated chunk1 + self.assertEqual( + chunks[0].content, + AudioBlock( + type="audio", + source=Base64Source( + type="base64", + data=chunk1_base64, + media_type=self.mock_mime_type, + ), + ), + ) + + # Chunk 2: accumulated chunk1 + chunk2 + self.assertEqual( + chunks[1].content, + AudioBlock( + type="audio", + source=Base64Source( + type="base64", + data=chunk1_base64 + chunk2_base64, + media_type=self.mock_mime_type, + ), + ), + ) + + # Final chunk: empty + self.assertIsNone(chunks[2].content) diff --git a/tests/tts_openai_test.py b/tests/tts_openai_test.py new file mode 100644 index 0000000000..86a5093c35 --- /dev/null +++ b/tests/tts_openai_test.py @@ 
-0,0 +1,135 @@ +# -*- coding: utf-8 -*- +# pylint: disable=protected-access +"""The unittests for OpenAI TTS model.""" +import base64 +import sys +from typing import AsyncGenerator +from unittest import IsolatedAsyncioTestCase +from unittest.mock import Mock, patch, AsyncMock, MagicMock + +from agentscope.message import Msg, AudioBlock, Base64Source +from agentscope.tts import OpenAITTSModel + + +# Create mock openai module (required for import-time patching) +mock_openai = MagicMock() +mock_openai.AsyncOpenAI = Mock(return_value=MagicMock()) + + +@patch.dict(sys.modules, {"openai": mock_openai}) +class OpenAITTSModelTest(IsolatedAsyncioTestCase): + """The unittests for OpenAI TTS model.""" + + def setUp(self) -> None: + """Set up the test case.""" + self.api_key = "test_api_key" + self.mock_audio_bytes = b"fake_audio_data" + self.mock_audio_base64 = base64.b64encode( + self.mock_audio_bytes, + ).decode( + "utf-8", + ) + + def test_init(self) -> None: + """Test initialization of OpenAITTSModel.""" + model = OpenAITTSModel( + api_key=self.api_key, + model_name="gpt-4o-mini-tts", + voice="alloy", + stream=False, + ) + self.assertEqual(model.model_name, "gpt-4o-mini-tts") + self.assertEqual(model.voice, "alloy") + self.assertFalse(model.stream) + self.assertFalse(model.supports_streaming_input) + + async def test_synthesize_non_streaming(self) -> None: + """Test synthesize method in non-streaming mode.""" + model = OpenAITTSModel( + api_key=self.api_key, + stream=False, + ) + + # Mock the speech.create response + mock_response = Mock() + mock_response.content = self.mock_audio_bytes + model._client.audio.speech.create = AsyncMock( + return_value=mock_response, + ) + + msg = Msg(name="user", content="Hello! 
Test message.", role="user") + response = await model.synthesize(msg) + + expected_content = AudioBlock( + type="audio", + source=Base64Source( + type="base64", + data=self.mock_audio_base64, + media_type="audio/pcm", + ), + ) + self.assertEqual(response.content, expected_content) + model._client.audio.speech.create.assert_called_once() + + async def test_synthesize_streaming(self) -> None: + """Test synthesize method in streaming mode.""" + model = OpenAITTSModel( + api_key=self.api_key, + stream=True, + ) + + chunk1 = b"audio_chunk_1" + chunk2 = b"audio_chunk_2" + + # Create mock streaming response inline + mock_stream = MagicMock() + mock_stream.__aenter__ = AsyncMock(return_value=mock_stream) + mock_stream.__aexit__ = AsyncMock(return_value=None) + + async def mock_iter_bytes() -> AsyncGenerator[bytes, None]: + yield chunk1 + yield chunk2 + + mock_stream.iter_bytes = mock_iter_bytes + + model._client.audio.speech.with_streaming_response.create = Mock( + return_value=mock_stream, + ) + + msg = Msg(name="user", content="Test streaming.", role="user") + response = await model.synthesize(msg) + + self.assertIsInstance(response, AsyncGenerator) + chunks = [chunk async for chunk in response] + + # Should have 3 chunks: 2 from audio data + 1 final + self.assertEqual(len(chunks), 3) + + # Chunk 1 + self.assertEqual( + chunks[0].content, + AudioBlock( + type="audio", + source=Base64Source( + type="base64", + data=base64.b64encode(chunk1).decode("utf-8"), + media_type="audio/pcm", + ), + ), + ) + + # Chunk 2 + self.assertEqual( + chunks[1].content, + AudioBlock( + type="audio", + source=Base64Source( + type="base64", + data=base64.b64encode(chunk2).decode("utf-8"), + media_type="audio/pcm", + ), + ), + ) + + # Final chunk + self.assertTrue(chunks[2].is_last)