diff --git a/src/server/index.test.ts b/src/server/index.test.ts index 46205d726..664ed4520 100644 --- a/src/server/index.test.ts +++ b/src/server/index.test.ts @@ -15,7 +15,8 @@ import { ListResourcesRequestSchema, ListToolsRequestSchema, SetLevelRequestSchema, - ErrorCode + ErrorCode, + LoggingMessageNotification } from "../types.js"; import { Transport } from "../shared/transport.js"; import { InMemoryTransport } from "../inMemory.js"; @@ -569,7 +570,7 @@ test("should allow elicitation reject and cancel without validation", async () = action: "decline", }); - // Test cancel - should not validate + // Test cancel - should not validate await expect( server.elicitInput({ message: "Please provide your name", @@ -861,3 +862,154 @@ test("should handle request timeout", async () => { code: ErrorCode.RequestTimeout, }); }); + +/* + Test automatic log level handling for transports with and without sessionId + */ +test("should respect log level for transport without sessionId", async () => { + + const server = new Server( + { + name: "test server", + version: "1.0", + }, + { + capabilities: { + prompts: {}, + resources: {}, + tools: {}, + logging: {}, + }, + enforceStrictCapabilities: true, + }, + ); + + const client = new Client( + { + name: "test client", + version: "1.0", + }, + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + await Promise.all([ + client.connect(clientTransport), + server.connect(serverTransport), + ]); + + expect(clientTransport.sessionId).toEqual(undefined); + + // Client sets logging level to warning + await client.setLoggingLevel("warning"); + + // This one will make it through + const warningParams: LoggingMessageNotification["params"] = { + level: "warning", + logger: "test server", + data: "Warning message", + }; + + // This one will not + const debugParams: LoggingMessageNotification["params"] = { + level: "debug", + logger: "test server", + data: "Debug message", + }; + + // Test the one that makes it through + clientTransport.onmessage = jest.fn().mockImplementation((message) => { + expect(message).toEqual({ + jsonrpc: "2.0", + method: "notifications/message", + params: warningParams + }); + }); + + // This one will not make it through + await server.sendLoggingMessage(debugParams); + expect(clientTransport.onmessage).not.toHaveBeenCalled(); + + // This one will, triggering the above test in clientTransport.onmessage + await server.sendLoggingMessage(warningParams); + expect(clientTransport.onmessage).toHaveBeenCalled(); + +}); + +test("should respect log level for transport with sessionId", async () => { + + const server = new Server( + { + name: "test server", + version: "1.0", + }, + { + capabilities: { + prompts: {}, + resources: {}, + tools: {}, + logging: {}, + }, + enforceStrictCapabilities: true, + }, + ); + + const client = new Client( + { + name: "test client", + version: "1.0", + }, + ); + + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); + + // Add a session id to the transports + const SESSION_ID = "test-session-id"; + clientTransport.sessionId = SESSION_ID; + serverTransport.sessionId = SESSION_ID; + + expect(clientTransport.sessionId).toBeDefined(); + expect(serverTransport.sessionId).toBeDefined(); + + await Promise.all([ + client.connect(clientTransport), + server.connect(serverTransport), + ]); + + + // Client sets logging level to warning + await client.setLoggingLevel("warning"); + + // This one will make it through + const warningParams: LoggingMessageNotification["params"] = { + level: "warning", + logger: "test server", + data: "Warning message", + }; + + // This one will not + const debugParams: LoggingMessageNotification["params"] = { + level: "debug", + logger: "test server", + data: "Debug message", + }; + + // Test the one that makes it through + clientTransport.onmessage = jest.fn().mockImplementation((message) => { + expect(message).toEqual({ + jsonrpc: "2.0", + method: "notifications/message", + params: warningParams + }); + }); + + // This one will not make it through + await server.sendLoggingMessage(debugParams, SESSION_ID); + expect(clientTransport.onmessage).not.toHaveBeenCalled(); + + // This one will, triggering the above test in clientTransport.onmessage + await server.sendLoggingMessage(warningParams, SESSION_ID); + expect(clientTransport.onmessage).toHaveBeenCalled(); + +}); + diff --git a/src/server/index.ts b/src/server/index.ts index b1f71ea28..970657358 100644 --- a/src/server/index.ts +++ b/src/server/index.ts @@ -117,7 +117,7 @@ export class Server< const transportSessionId: string | undefined = extra.sessionId || extra.requestInfo?.headers['mcp-session-id'] as string || undefined; const { level } = request.params; const parseResult = LoggingLevelSchema.safeParse(level); - if (transportSessionId && parseResult.success) { + if (parseResult.success) { this._loggingLevels.set(transportSessionId, parseResult.data); } return {}; @@ -126,7 +126,7 @@ export class Server< } // Map log levels by session id - private _loggingLevels = new Map(); + private _loggingLevels = new Map(); // Map LogLevelSchema to severity index private readonly LOG_LEVEL_SEVERITY = new Map( @@ -134,7 +134,7 @@ export class Server< ); // Is a message with the given level ignored in the log level set for the given session id? - private isMessageIgnored = (level: LoggingLevel, sessionId: string): boolean => { + private isMessageIgnored = (level: LoggingLevel, sessionId?: string): boolean => { const currentLevel = this._loggingLevels.get(sessionId); return (currentLevel) ? this.LOG_LEVEL_SEVERITY.get(level)! < this.LOG_LEVEL_SEVERITY.get(currentLevel)! @@ -398,7 +398,7 @@ export class Server< */ async sendLoggingMessage(params: LoggingMessageNotification["params"], sessionId?: string) { if (this._capabilities.logging) { - if (!sessionId || !this.isMessageIgnored(params.level, sessionId)) { + if (!this.isMessageIgnored(params.level, sessionId)) { return this.notification({method: "notifications/message", params}) } }