Skip to content

Commit ff97f1a

Browse files
committed
feat: implement SEP-1577 sampling with tools support
1 parent 8d09f88 commit ff97f1a

File tree

12 files changed

+1756
-191
lines changed

12 files changed

+1756
-191
lines changed

crates/rmcp/src/model.rs

Lines changed: 265 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1209,18 +1209,257 @@ pub enum Role {
12091209
Assistant,
12101210
}
12111211

1212-
/// A message in a sampling conversation, containing a role and content.
1213-
///
1214-
/// This represents a single message in a conversation flow, used primarily
1215-
/// in LLM sampling requests where the conversation history is important
1216-
/// for generating appropriate responses.
1212+
/// Tool selection mode (SEP-1577).
1213+
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
1214+
#[serde(rename_all = "lowercase")]
1215+
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
1216+
pub enum ToolChoiceMode {
1217+
/// Model decides whether to use tools
1218+
Auto,
1219+
/// Model must use at least one tool
1220+
Required,
1221+
/// Model must not use tools
1222+
None,
1223+
}
1224+
1225+
impl Default for ToolChoiceMode {
1226+
fn default() -> Self {
1227+
Self::Auto
1228+
}
1229+
}
1230+
1231+
/// Tool choice configuration (SEP-1577).
1232+
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
1233+
#[serde(rename_all = "camelCase")]
1234+
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
1235+
pub struct ToolChoice {
1236+
#[serde(skip_serializing_if = "Option::is_none")]
1237+
pub mode: Option<ToolChoiceMode>,
1238+
}
1239+
1240+
impl ToolChoice {
1241+
pub fn auto() -> Self {
1242+
Self {
1243+
mode: Some(ToolChoiceMode::Auto),
1244+
}
1245+
}
1246+
1247+
pub fn required() -> Self {
1248+
Self {
1249+
mode: Some(ToolChoiceMode::Required),
1250+
}
1251+
}
1252+
1253+
pub fn none() -> Self {
1254+
Self {
1255+
mode: Some(ToolChoiceMode::None),
1256+
}
1257+
}
1258+
}
1259+
1260+
/// Single or array content wrapper (SEP-1577).
1261+
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
1262+
#[serde(untagged)]
1263+
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
1264+
pub enum SamplingContent<T> {
1265+
Single(T),
1266+
Multiple(Vec<T>),
1267+
}
1268+
1269+
impl<T> SamplingContent<T> {
1270+
/// Convert to a Vec regardless of whether it's single or multiple
1271+
pub fn into_vec(self) -> Vec<T> {
1272+
match self {
1273+
SamplingContent::Single(item) => vec![item],
1274+
SamplingContent::Multiple(items) => items,
1275+
}
1276+
}
1277+
1278+
/// Check if the content is empty
1279+
pub fn is_empty(&self) -> bool {
1280+
match self {
1281+
SamplingContent::Single(_) => false,
1282+
SamplingContent::Multiple(items) => items.is_empty(),
1283+
}
1284+
}
1285+
1286+
/// Get the number of content items
1287+
pub fn len(&self) -> usize {
1288+
match self {
1289+
SamplingContent::Single(_) => 1,
1290+
SamplingContent::Multiple(items) => items.len(),
1291+
}
1292+
}
1293+
}
1294+
1295+
impl<T> Default for SamplingContent<T> {
1296+
fn default() -> Self {
1297+
SamplingContent::Multiple(Vec::new())
1298+
}
1299+
}
1300+
1301+
impl<T> SamplingContent<T> {
1302+
/// Get the first item if present
1303+
pub fn first(&self) -> Option<&T> {
1304+
match self {
1305+
SamplingContent::Single(item) => Some(item),
1306+
SamplingContent::Multiple(items) => items.first(),
1307+
}
1308+
}
1309+
1310+
/// Iterate over all content items
1311+
pub fn iter(&self) -> impl Iterator<Item = &T> {
1312+
let items: Vec<&T> = match self {
1313+
SamplingContent::Single(item) => vec![item],
1314+
SamplingContent::Multiple(items) => items.iter().collect(),
1315+
};
1316+
items.into_iter()
1317+
}
1318+
}
1319+
1320+
impl SamplingMessageContent {
1321+
/// Get the text content if this is a Text variant
1322+
pub fn as_text(&self) -> Option<&RawTextContent> {
1323+
match self {
1324+
SamplingMessageContent::Text(text) => Some(text),
1325+
_ => None,
1326+
}
1327+
}
1328+
1329+
/// Get the tool use content if this is a ToolUse variant
1330+
pub fn as_tool_use(&self) -> Option<&ToolUseContent> {
1331+
match self {
1332+
SamplingMessageContent::ToolUse(tool_use) => Some(tool_use),
1333+
_ => None,
1334+
}
1335+
}
1336+
1337+
/// Get the tool result content if this is a ToolResult variant
1338+
pub fn as_tool_result(&self) -> Option<&ToolResultContent> {
1339+
match self {
1340+
SamplingMessageContent::ToolResult(tool_result) => Some(tool_result),
1341+
_ => None,
1342+
}
1343+
}
1344+
}
1345+
1346+
impl<T> From<T> for SamplingContent<T> {
1347+
fn from(item: T) -> Self {
1348+
SamplingContent::Single(item)
1349+
}
1350+
}
1351+
1352+
impl<T> From<Vec<T>> for SamplingContent<T> {
1353+
fn from(items: Vec<T>) -> Self {
1354+
SamplingContent::Multiple(items)
1355+
}
1356+
}
1357+
1358+
/// A message in a sampling conversation (SEP-1577).
12171359
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
12181360
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
12191361
pub struct SamplingMessage {
1220-
/// The role of the message sender (User or Assistant)
12211362
pub role: Role,
1222-
/// The actual content of the message (text, image, etc.)
1223-
pub content: Content,
1363+
/// Single content or array of contents
1364+
pub content: SamplingContent<SamplingMessageContent>,
1365+
#[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
1366+
pub meta: Option<Meta>,
1367+
}
1368+
1369+
/// Content types for sampling messages (SEP-1577).
1370+
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
1371+
#[serde(tag = "type", rename_all = "snake_case")]
1372+
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
1373+
pub enum SamplingMessageContent {
1374+
Text(RawTextContent),
1375+
Image(RawImageContent),
1376+
Audio(RawAudioContent),
1377+
/// Assistant only
1378+
ToolUse(ToolUseContent),
1379+
/// User only
1380+
ToolResult(ToolResultContent),
1381+
}
1382+
1383+
impl SamplingMessageContent {
1384+
/// Create a text content
1385+
pub fn text(text: impl Into<String>) -> Self {
1386+
Self::Text(RawTextContent {
1387+
text: text.into(),
1388+
meta: None,
1389+
})
1390+
}
1391+
1392+
pub fn tool_use(id: impl Into<String>, name: impl Into<String>, input: JsonObject) -> Self {
1393+
Self::ToolUse(ToolUseContent::new(id, name, input))
1394+
}
1395+
1396+
pub fn tool_result(tool_use_id: impl Into<String>, content: Vec<Content>) -> Self {
1397+
Self::ToolResult(ToolResultContent::new(tool_use_id, content))
1398+
}
1399+
}
1400+
1401+
impl SamplingMessage {
1402+
pub fn new(role: Role, content: impl Into<SamplingMessageContent>) -> Self {
1403+
Self {
1404+
role,
1405+
content: SamplingContent::Single(content.into()),
1406+
meta: None,
1407+
}
1408+
}
1409+
1410+
pub fn new_multiple(role: Role, contents: Vec<SamplingMessageContent>) -> Self {
1411+
Self {
1412+
role,
1413+
content: SamplingContent::Multiple(contents),
1414+
meta: None,
1415+
}
1416+
}
1417+
1418+
pub fn user_text(text: impl Into<String>) -> Self {
1419+
Self::new(Role::User, SamplingMessageContent::text(text))
1420+
}
1421+
1422+
pub fn assistant_text(text: impl Into<String>) -> Self {
1423+
Self::new(Role::Assistant, SamplingMessageContent::text(text))
1424+
}
1425+
1426+
pub fn user_tool_result(tool_use_id: impl Into<String>, content: Vec<Content>) -> Self {
1427+
Self::new(
1428+
Role::User,
1429+
SamplingMessageContent::tool_result(tool_use_id, content),
1430+
)
1431+
}
1432+
1433+
pub fn assistant_tool_use(
1434+
id: impl Into<String>,
1435+
name: impl Into<String>,
1436+
input: JsonObject,
1437+
) -> Self {
1438+
Self::new(
1439+
Role::Assistant,
1440+
SamplingMessageContent::tool_use(id, name, input),
1441+
)
1442+
}
1443+
}
1444+
1445+
// Conversion from RawTextContent to SamplingMessageContent
1446+
impl From<RawTextContent> for SamplingMessageContent {
1447+
fn from(text: RawTextContent) -> Self {
1448+
SamplingMessageContent::Text(text)
1449+
}
1450+
}
1451+
1452+
// Conversion from String to SamplingMessageContent (as text)
1453+
impl From<String> for SamplingMessageContent {
1454+
fn from(text: String) -> Self {
1455+
SamplingMessageContent::text(text)
1456+
}
1457+
}
1458+
1459+
impl From<&str> for SamplingMessageContent {
1460+
fn from(text: &str) -> Self {
1461+
SamplingMessageContent::text(text)
1462+
}
12241463
}
12251464

12261465
/// Specifies how much context should be included in sampling requests.
@@ -1267,20 +1506,26 @@ pub struct CreateMessageRequestParams {
12671506
/// System prompt to guide the model's behavior
12681507
#[serde(skip_serializing_if = "Option::is_none")]
12691508
pub system_prompt: Option<String>,
1270-
/// How much context to include from MCP servers
1509+
/// Context inclusion (soft-deprecated)
12711510
#[serde(skip_serializing_if = "Option::is_none")]
12721511
pub include_context: Option<ContextInclusion>,
1273-
/// Temperature for controlling randomness (0.0 to 1.0)
1512+
/// Temperature (0.0 to 1.0)
12741513
#[serde(skip_serializing_if = "Option::is_none")]
12751514
pub temperature: Option<f32>,
1276-
/// Maximum number of tokens to generate
1515+
/// Max tokens to generate
12771516
pub max_tokens: u32,
1278-
/// Sequences that should stop generation
1517+
/// Stop sequences
12791518
#[serde(skip_serializing_if = "Option::is_none")]
12801519
pub stop_sequences: Option<Vec<String>>,
1281-
/// Additional metadata for the request
1520+
/// Request metadata
12821521
#[serde(skip_serializing_if = "Option::is_none")]
12831522
pub metadata: Option<Value>,
1523+
/// Tools for the model (SEP-1577)
1524+
#[serde(skip_serializing_if = "Option::is_none")]
1525+
pub tools: Option<Vec<Tool>>,
1526+
/// Tool selection config (SEP-1577)
1527+
#[serde(skip_serializing_if = "Option::is_none")]
1528+
pub tool_choice: Option<ToolChoice>,
12841529
}
12851530

12861531
impl RequestParamsMeta for CreateMessageRequestParams {
@@ -1926,20 +2171,15 @@ pub type CallToolRequestParam = CallToolRequestParams;
19262171
/// Request to call a specific tool
19272172
pub type CallToolRequest = Request<CallToolRequestMethod, CallToolRequestParams>;
19282173

1929-
/// The result of a sampling/createMessage request containing the generated response.
1930-
///
1931-
/// This structure contains the generated message along with metadata about
1932-
/// how the generation was performed and why it stopped.
2174+
/// Result of sampling/createMessage (SEP-1577).
19332175
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
19342176
#[serde(rename_all = "camelCase")]
19352177
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
19362178
pub struct CreateMessageResult {
1937-
/// The identifier of the model that generated the response
19382179
pub model: String,
1939-
/// The reason why generation stopped (e.g., "endTurn", "maxTokens")
2180+
/// Why generation stopped: "endTurn", "stopSequence", "maxTokens", "toolUse"
19402181
#[serde(skip_serializing_if = "Option::is_none")]
19412182
pub stop_reason: Option<String>,
1942-
/// The generated message with role and content
19432183
#[serde(flatten)]
19442184
pub message: SamplingMessage,
19452185
}
@@ -1948,6 +2188,7 @@ impl CreateMessageResult {
19482188
pub const STOP_REASON_END_TURN: &str = "endTurn";
19492189
pub const STOP_REASON_END_SEQUENCE: &str = "stopSequence";
19502190
pub const STOP_REASON_END_MAX_TOKEN: &str = "maxTokens";
2191+
pub const STOP_REASON_TOOL_USE: &str = "toolUse";
19512192
}
19522193

19532194
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
@@ -2476,7 +2717,10 @@ mod tests {
24762717
..
24772718
}) => {
24782719
assert_eq!(capabilities.roots.unwrap().list_changed, Some(true));
2479-
assert_eq!(capabilities.sampling.unwrap().len(), 0);
2720+
// Empty sampling capability (no tools or context sub-capabilities)
2721+
let sampling = capabilities.sampling.unwrap();
2722+
assert_eq!(sampling.tools, None);
2723+
assert_eq!(sampling.context, None);
24802724
assert_eq!(client_info.name, "ExampleClient");
24812725
assert_eq!(client_info.version, "1.0.0");
24822726
}

0 commit comments

Comments
 (0)