Skip to content

Commit 12ae856

Browse files
olaservoclaude
andauthored
[v1.x backport] Use correct schema for client sampling validation when tools are present (#1407)
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
1 parent b392f02 commit 12ae856

File tree

2 files changed

+131
-2
lines changed

2 files changed

+131
-2
lines changed

src/client/index.ts

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ import {
4040
CreateTaskResultSchema,
4141
CreateMessageRequestSchema,
4242
CreateMessageResultSchema,
43+
CreateMessageResultWithToolsSchema,
4344
ToolListChangedNotificationSchema,
4445
PromptListChangedNotificationSchema,
4546
ResourceListChangedNotificationSchema,
@@ -452,8 +453,10 @@ export class Client<
452453
return taskValidationResult.data;
453454
}
454455

455-
// For non-task requests, validate against CreateMessageResultSchema
456-
const validationResult = safeParse(CreateMessageResultSchema, result);
456+
// For non-task requests, validate against appropriate schema based on tools presence
457+
const hasTools = params.tools || params.toolChoice;
458+
const resultSchema = hasTools ? CreateMessageResultWithToolsSchema : CreateMessageResultSchema;
459+
const validationResult = safeParse(resultSchema, result);
457460
if (!validationResult.success) {
458461
const errorMessage =
459462
validationResult.error instanceof Error ? validationResult.error.message : String(validationResult.error);

test/client/index.test.ts

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4137,3 +4137,129 @@ describe('getSupportedElicitationModes', () => {
41374137
expect(result.supportsUrlMode).toBe(false);
41384138
});
41394139
});
4140+
4141+
describe('Client sampling validation with tools', () => {
4142+
test('should validate array content with tool_use when request includes tools', async () => {
4143+
const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} });
4144+
4145+
const client = new Client({ name: 'test client', version: '1.0' }, { capabilities: { sampling: { tools: {} } } });
4146+
4147+
// Handler returns array content with tool_use - should validate with CreateMessageResultWithToolsSchema
4148+
client.setRequestHandler(CreateMessageRequestSchema, async () => ({
4149+
model: 'test-model',
4150+
role: 'assistant',
4151+
stopReason: 'toolUse',
4152+
content: [{ type: 'tool_use', id: 'call_1', name: 'test_tool', input: { arg: 'value' } }]
4153+
}));
4154+
4155+
const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair();
4156+
await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]);
4157+
4158+
const result = await server.createMessage({
4159+
messages: [{ role: 'user', content: { type: 'text', text: 'hello' } }],
4160+
maxTokens: 100,
4161+
tools: [{ name: 'test_tool', inputSchema: { type: 'object' } }]
4162+
});
4163+
4164+
expect(result.stopReason).toBe('toolUse');
4165+
expect(Array.isArray(result.content)).toBe(true);
4166+
expect((result.content as Array<{ type: string }>)[0].type).toBe('tool_use');
4167+
});
4168+
4169+
test('should validate single content when request includes tools', async () => {
4170+
const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} });
4171+
4172+
const client = new Client({ name: 'test client', version: '1.0' }, { capabilities: { sampling: { tools: {} } } });
4173+
4174+
// Handler returns single content (text) - should still validate with CreateMessageResultWithToolsSchema
4175+
client.setRequestHandler(CreateMessageRequestSchema, async () => ({
4176+
model: 'test-model',
4177+
role: 'assistant',
4178+
content: { type: 'text', text: 'No tool needed' }
4179+
}));
4180+
4181+
const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair();
4182+
await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]);
4183+
4184+
const result = await server.createMessage({
4185+
messages: [{ role: 'user', content: { type: 'text', text: 'hello' } }],
4186+
maxTokens: 100,
4187+
tools: [{ name: 'test_tool', inputSchema: { type: 'object' } }]
4188+
});
4189+
4190+
expect((result.content as { type: string }).type).toBe('text');
4191+
});
4192+
4193+
test('should validate single content when request has no tools', async () => {
4194+
const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} });
4195+
4196+
const client = new Client({ name: 'test client', version: '1.0' }, { capabilities: { sampling: {} } });
4197+
4198+
// Handler returns single content - should validate with CreateMessageResultSchema
4199+
client.setRequestHandler(CreateMessageRequestSchema, async () => ({
4200+
model: 'test-model',
4201+
role: 'assistant',
4202+
content: { type: 'text', text: 'Response' }
4203+
}));
4204+
4205+
const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair();
4206+
await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]);
4207+
4208+
const result = await server.createMessage({
4209+
messages: [{ role: 'user', content: { type: 'text', text: 'hello' } }],
4210+
maxTokens: 100
4211+
});
4212+
4213+
expect((result.content as { type: string }).type).toBe('text');
4214+
});
4215+
4216+
test('should reject array content when request has no tools', async () => {
4217+
const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} });
4218+
4219+
const client = new Client({ name: 'test client', version: '1.0' }, { capabilities: { sampling: {} } });
4220+
4221+
// Handler returns array content - should fail validation with CreateMessageResultSchema
4222+
client.setRequestHandler(CreateMessageRequestSchema, async () => ({
4223+
model: 'test-model',
4224+
role: 'assistant',
4225+
content: [{ type: 'text', text: 'Array response' }]
4226+
}));
4227+
4228+
const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair();
4229+
await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]);
4230+
4231+
await expect(
4232+
server.createMessage({
4233+
messages: [{ role: 'user', content: { type: 'text', text: 'hello' } }],
4234+
maxTokens: 100
4235+
})
4236+
).rejects.toThrow('Invalid sampling result');
4237+
});
4238+
4239+
test('should validate array content when request includes toolChoice', async () => {
4240+
const server = new Server({ name: 'test server', version: '1.0' }, { capabilities: {} });
4241+
4242+
const client = new Client({ name: 'test client', version: '1.0' }, { capabilities: { sampling: { tools: {} } } });
4243+
4244+
// Handler returns array content with tool_use
4245+
client.setRequestHandler(CreateMessageRequestSchema, async () => ({
4246+
model: 'test-model',
4247+
role: 'assistant',
4248+
stopReason: 'toolUse',
4249+
content: [{ type: 'tool_use', id: 'call_1', name: 'test_tool', input: {} }]
4250+
}));
4251+
4252+
const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair();
4253+
await Promise.all([client.connect(clientTransport), server.connect(serverTransport)]);
4254+
4255+
const result = await server.createMessage({
4256+
messages: [{ role: 'user', content: { type: 'text', text: 'hello' } }],
4257+
maxTokens: 100,
4258+
tools: [{ name: 'test_tool', inputSchema: { type: 'object' } }],
4259+
toolChoice: { mode: 'auto' }
4260+
});
4261+
4262+
expect(result.stopReason).toBe('toolUse');
4263+
expect(Array.isArray(result.content)).toBe(true);
4264+
});
4265+
});

0 commit comments

Comments
 (0)