diff --git a/packages/core/src/mcp/stored-oauth-provider.test.ts b/packages/core/src/mcp/stored-oauth-provider.test.ts new file mode 100644 index 00000000000..b31e2b6c88e --- /dev/null +++ b/packages/core/src/mcp/stored-oauth-provider.test.ts @@ -0,0 +1,164 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import { beforeEach, describe, expect, it, vi } from 'vitest'; +import type { MCPOAuthConfig } from './oauth-provider.js'; +import { StoredOAuthMcpProvider } from './stored-oauth-provider.js'; +import type { MCPOAuthTokenStorage } from './oauth-token-storage.js'; +import type { OAuthCredentials } from './token-storage/types.js'; + +vi.mock('../utils/events.js', () => ({ + coreEvents: { + emitFeedback: vi.fn(), + }, +})); + +vi.mock('../mcp/token-storage/hybrid-token-storage.js', () => ({ + HybridTokenStorage: vi.fn(), +})); + +describe('StoredOAuthMcpProvider', () => { + const oauthConfig: MCPOAuthConfig = { + tokenUrl: 'https://auth.example.com/token', + }; + + const storedCredentials: OAuthCredentials = { + serverName: 'test-server', + token: { + accessToken: 'stored-access-token', + refreshToken: 'stored-refresh-token', + tokenType: 'Bearer', + scope: 'scope-a scope-b', + expiresAt: Date.now() + 3600_000, + }, + clientId: 'stored-client-id', + tokenUrl: 'https://auth.example.com/token', + mcpServerUrl: 'https://example.com/mcp', + updatedAt: Date.now(), + }; + + let tokenStorage: MCPOAuthTokenStorage; + + beforeEach(() => { + tokenStorage = { + getCredentials: vi.fn().mockResolvedValue(storedCredentials), + saveToken: vi.fn().mockResolvedValue(undefined), + deleteCredentials: vi.fn().mockResolvedValue(undefined), + } as unknown as MCPOAuthTokenStorage; + }); + + it('returns stored client information and tokens in SDK shape', async () => { + const provider = new StoredOAuthMcpProvider( + 'test-server', + oauthConfig, + tokenStorage, + ); + + await expect(provider.clientInformation()).resolves.toEqual({ + client_id: 'stored-client-id', + client_secret: undefined, + token_endpoint_auth_method: 'none', + }); + + const tokens = await provider.tokens(); + expect(tokens?.access_token).toBe('stored-access-token'); + expect(tokens?.refresh_token).toBe('stored-refresh-token'); + expect(tokens?.token_type).toBe('Bearer'); + expect(tokens?.scope).toBe('scope-a scope-b'); + expect(tokens?.expires_in).toBeGreaterThan(0); + }); + + it('saves refreshed tokens and preserves the previous refresh token when omitted', async () => { + vi.mocked(tokenStorage.getCredentials) + .mockResolvedValueOnce(storedCredentials) + .mockResolvedValueOnce({ + ...storedCredentials, + token: { + ...storedCredentials.token, + accessToken: 'refreshed-access-token', + }, + }); + + const provider = new StoredOAuthMcpProvider( + 'test-server', + oauthConfig, + tokenStorage, + ); + + await provider.saveTokens({ + access_token: 'refreshed-access-token', + token_type: 'Bearer', + expires_in: 1800, + }); + + expect(tokenStorage.saveToken).toHaveBeenCalledWith( + 'test-server', + expect.objectContaining({ + accessToken: 'refreshed-access-token', + refreshToken: 'stored-refresh-token', + tokenType: 'Bearer', + }), + 'stored-client-id', + 'https://auth.example.com/token', + 'https://example.com/mcp', + ); + }); + + it('does not preserve a stale expiresAt when refreshed tokens omit expires_in', async () => { + const expiredCredentials: OAuthCredentials = { + ...storedCredentials, + token: { + ...storedCredentials.token, + expiresAt: Date.now() - 60_000, + }, + }; + + vi.mocked(tokenStorage.getCredentials) + .mockResolvedValueOnce(expiredCredentials) + .mockResolvedValueOnce({ + ...expiredCredentials, + token: { + accessToken: 'refreshed-access-token', + refreshToken: 'stored-refresh-token', + tokenType: 'Bearer', + scope: 'scope-a scope-b', + }, + }); + + const provider = new StoredOAuthMcpProvider( + 'test-server', + oauthConfig, + tokenStorage, + ); + + await provider.saveTokens({ + access_token: 'refreshed-access-token', + token_type: 'Bearer', + }); + + expect(tokenStorage.saveToken).toHaveBeenCalledWith( + 'test-server', + expect.not.objectContaining({ + expiresAt: expect.any(Number), + }), + 'stored-client-id', + 'https://auth.example.com/token', + 'https://example.com/mcp', + ); + }); + + it('invalidates stored credentials', async () => { + const provider = new StoredOAuthMcpProvider( + 'test-server', + oauthConfig, + tokenStorage, + ); + + await provider.invalidateCredentials('tokens'); + + expect(tokenStorage.deleteCredentials).toHaveBeenCalledWith('test-server'); + }); +}); diff --git a/packages/core/src/mcp/stored-oauth-provider.ts b/packages/core/src/mcp/stored-oauth-provider.ts new file mode 100644 index 00000000000..0fc8bf02485 --- /dev/null +++ b/packages/core/src/mcp/stored-oauth-provider.ts @@ -0,0 +1,189 @@ +/** + * @license + * Copyright 2026 Google LLC + * SPDX-License-Identifier: Apache-2.0 + */ + +import type { + OAuthClientInformationMixed, + OAuthClientMetadata, + OAuthTokens, +} from '@modelcontextprotocol/sdk/shared/auth.js'; +import { REDIRECT_PATH } from '../utils/oauth-flow.js'; +import { coreEvents } from '../utils/events.js'; +import { debugLogger } from '../utils/debugLogger.js'; +import type { MCPOAuthConfig } from './oauth-provider.js'; +import { MCPOAuthTokenStorage } from './oauth-token-storage.js'; +import type { OAuthCredentials, OAuthToken } from './token-storage/types.js'; +import type { McpAuthProvider } from './auth-provider.js'; + +const DEFAULT_REDIRECT_URL = `http://localhost${REDIRECT_PATH}`; + +function toOAuthTokens(token: OAuthToken): OAuthTokens { + const tokens: OAuthTokens = { + access_token: token.accessToken, + token_type: token.tokenType, + }; + + if (token.refreshToken) { + tokens.refresh_token = token.refreshToken; + } + if (token.scope) { + tokens.scope = token.scope; + } + if (token.expiresAt) { + tokens.expires_in = Math.max( + 0, + Math.floor((token.expiresAt - Date.now()) / 1000), + ); + } + + return tokens; +} + +function toStoredToken( + tokens: OAuthTokens, + previousToken?: OAuthToken, +): OAuthToken { + const storedToken: OAuthToken = { + accessToken: tokens.access_token, + tokenType: tokens.token_type || previousToken?.tokenType || 'Bearer', + refreshToken: tokens.refresh_token || previousToken?.refreshToken, + scope: tokens.scope || previousToken?.scope, + }; + + if (tokens.expires_in !== undefined) { + storedToken.expiresAt = Date.now() + tokens.expires_in * 1000; + } + + return storedToken; +} + +export class StoredOAuthMcpProvider implements McpAuthProvider { + private cachedCredentials?: OAuthCredentials | null; + private cachedClientInformation?: OAuthClientInformationMixed; + private cachedCodeVerifier?: string; + + constructor( + private readonly serverName: string, + private readonly oauthConfig: MCPOAuthConfig = {}, + private readonly tokenStorage: MCPOAuthTokenStorage = new MCPOAuthTokenStorage(), + ) {} + + get redirectUrl(): string { + return this.oauthConfig.redirectUri || DEFAULT_REDIRECT_URL; + } + + get clientMetadata(): OAuthClientMetadata { + return { + client_name: 'Gemini CLI MCP Client', + redirect_uris: [this.redirectUrl], + grant_types: ['authorization_code', 'refresh_token'], + response_types: ['code'], + token_endpoint_auth_method: this.oauthConfig.clientSecret + ? 'client_secret_post' + : 'none', + scope: this.oauthConfig.scopes?.join(' ') || undefined, + }; + } + + private async getCredentials(): Promise { + if (this.cachedCredentials !== undefined) { + return this.cachedCredentials; + } + this.cachedCredentials = await this.tokenStorage.getCredentials( + this.serverName, + ); + return this.cachedCredentials; + } + + async clientInformation(): Promise { + if (this.cachedClientInformation) { + return this.cachedClientInformation; + } + + const credentials = await this.getCredentials(); + const clientId = this.oauthConfig.clientId || credentials?.clientId; + if (!clientId) { + return undefined; + } + + this.cachedClientInformation = { + client_id: clientId, + client_secret: this.oauthConfig.clientSecret, + token_endpoint_auth_method: this.oauthConfig.clientSecret + ? 'client_secret_post' + : 'none', + }; + return this.cachedClientInformation; + } + + saveClientInformation(clientInformation: OAuthClientInformationMixed): void { + this.cachedClientInformation = clientInformation; + } + + async tokens(): Promise { + const credentials = await this.getCredentials(); + if (!credentials) { + return undefined; + } + return toOAuthTokens(credentials.token); + } + + async saveTokens(tokens: OAuthTokens): Promise { + const credentials = await this.getCredentials(); + const clientId = + this.oauthConfig.clientId || + credentials?.clientId || + this.cachedClientInformation?.client_id; + + await this.tokenStorage.saveToken( + this.serverName, + toStoredToken(tokens, credentials?.token), + clientId, + this.oauthConfig.tokenUrl || credentials?.tokenUrl, + credentials?.mcpServerUrl, + ); + + this.cachedCredentials = await this.tokenStorage.getCredentials( + this.serverName, + ); + } + + async redirectToAuthorization(authorizationUrl: URL): Promise { + debugLogger.log( + `Stored OAuth provider for '${this.serverName}' needs re-authentication at ${authorizationUrl.toString()}`, + ); + coreEvents.emitFeedback( + 'info', + `MCP server '${this.serverName}' requires re-authentication using: /mcp auth ${this.serverName}`, + ); + } + + saveCodeVerifier(codeVerifier: string): void { + this.cachedCodeVerifier = codeVerifier; + } + + codeVerifier(): string { + if (!this.cachedCodeVerifier) { + throw new Error('No code verifier saved'); + } + return this.cachedCodeVerifier; + } + + async invalidateCredentials( + scope: 'all' | 'client' | 'tokens' | 'verifier', + ): Promise { + if (scope === 'all' || scope === 'client' || scope === 'tokens') { + await this.tokenStorage.deleteCredentials(this.serverName); + this.cachedCredentials = null; + if (scope === 'all' || scope === 'client') { + this.cachedClientInformation = undefined; + } + } + + if (scope === 'all' || scope === 'verifier') { + this.cachedCodeVerifier = undefined; + } + } +} diff --git a/packages/core/src/tools/mcp-client.test.ts b/packages/core/src/tools/mcp-client.test.ts index 4a14b671a0d..90a1c82b085 100644 --- a/packages/core/src/tools/mcp-client.test.ts +++ b/packages/core/src/tools/mcp-client.test.ts @@ -18,6 +18,7 @@ import { GoogleCredentialProvider } from '../mcp/google-auth-provider.js'; import { MCPOAuthProvider } from '../mcp/oauth-provider.js'; import { MCPOAuthTokenStorage } from '../mcp/oauth-token-storage.js'; import { OAuthUtils } from '../mcp/oauth-utils.js'; +import type { McpAuthProvider } from '../mcp/auth-provider.js'; import type { PromptRegistry } from '../prompts/prompt-registry.js'; import { PromptListChangedNotificationSchema, @@ -33,6 +34,7 @@ import { createTransport, hasNetworkTransport, isEnabled, + MCPServerStatus, McpClient, populateMcpServerCommand, type McpContext, @@ -46,7 +48,7 @@ import { coreEvents } from '../utils/events.js'; import type { EnvironmentSanitizationConfig } from '../services/environmentSanitization.js'; interface TestableTransport { - _authProvider?: GoogleCredentialProvider; + _authProvider?: McpAuthProvider; _requestInit?: { headers?: Record; }; @@ -438,6 +440,269 @@ describe('mcp-client', () => { expect(mockedToolRegistry.registerTool).toHaveBeenCalledOnce(); }); + it('surfaces recoverable auth transport errors when no auth recovery is in progress', async () => { + const mockedClient = { + connect: vi.fn(), + discover: vi.fn(), + disconnect: vi.fn(), + close: vi.fn(), + getStatus: vi.fn(), + registerCapabilities: vi.fn(), + setRequestHandler: vi.fn(), + setNotificationHandler: vi.fn(), + getServerCapabilities: vi.fn().mockReturnValue({ tools: {} }), + listTools: vi.fn().mockResolvedValue({ tools: [] }), + listPrompts: vi.fn().mockResolvedValue({ prompts: [] }), + request: vi.fn().mockResolvedValue({}), + onerror: undefined, + }; + vi.mocked(ClientLib.Client).mockReturnValue( + mockedClient as unknown as ClientLib.Client, + ); + vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue( + {} as SdkClientStdioLib.StdioClientTransport, + ); + + const client = new McpClient( + 'test-server', + { command: 'test-command' }, + workspaceContext, + MOCK_CONTEXT, + false, + '0.0.1', + ); + + await client.connect(); + ( + mockedClient as { + onerror?: (error: unknown) => void; + } + ).onerror?.( + new StreamableHTTPError( + 401, + 'Server returned 401 after successful authentication', + ), + ); + + expect(client.getStatus()).toBe(MCPServerStatus.DISCONNECTED); + expect(MOCK_CONTEXT.emitMcpDiagnostic).toHaveBeenCalledWith( + 'error', + 'MCP ERROR (test-server)', + expect.anything(), + 'test-server', + ); + }); + + it('ignores transient Streamable HTTP background SSE disconnects', async () => { + const mockedClient = { + connect: vi.fn(), + discover: vi.fn(), + disconnect: vi.fn(), + close: vi.fn(), + getStatus: vi.fn(), + registerCapabilities: vi.fn(), + setRequestHandler: vi.fn(), + setNotificationHandler: vi.fn(), + getServerCapabilities: vi.fn().mockReturnValue({ tools: {} }), + listTools: vi.fn().mockResolvedValue({ tools: [] }), + listPrompts: vi.fn().mockResolvedValue({ prompts: [] }), + request: vi.fn().mockResolvedValue({}), + onerror: undefined, + }; + vi.mocked(ClientLib.Client).mockReturnValue( + mockedClient as unknown as ClientLib.Client, + ); + vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue( + {} as SdkClientStdioLib.StdioClientTransport, + ); + + const client = new McpClient( + 'test-server', + { url: 'http://test-server', type: 'http' }, + workspaceContext, + MOCK_CONTEXT, + false, + '0.0.1', + ); + + await client.connect(); + ( + mockedClient as { + onerror?: (error: unknown) => void; + } + ).onerror?.(new Error('SSE stream disconnected: TypeError: terminated')); + + expect(client.getStatus()).toBe(MCPServerStatus.CONNECTED); + expect(MOCK_CONTEXT.emitMcpDiagnostic).not.toHaveBeenCalledWith( + 'error', + 'MCP ERROR (test-server)', + expect.anything(), + 'test-server', + ); + }); + + it('surfaces terminal Streamable HTTP reconnect exhaustion errors', async () => { + const mockedClient = { + connect: vi.fn(), + discover: vi.fn(), + disconnect: vi.fn(), + close: vi.fn(), + getStatus: vi.fn(), + registerCapabilities: vi.fn(), + setRequestHandler: vi.fn(), + setNotificationHandler: vi.fn(), + getServerCapabilities: vi.fn().mockReturnValue({ tools: {} }), + listTools: vi.fn().mockResolvedValue({ tools: [] }), + listPrompts: vi.fn().mockResolvedValue({ prompts: [] }), + request: vi.fn().mockResolvedValue({}), + onerror: undefined, + }; + vi.mocked(ClientLib.Client).mockReturnValue( + mockedClient as unknown as ClientLib.Client, + ); + vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue( + {} as SdkClientStdioLib.StdioClientTransport, + ); + + const client = new McpClient( + 'test-server', + { url: 'http://test-server', type: 'http' }, + workspaceContext, + MOCK_CONTEXT, + false, + '0.0.1', + ); + + await client.connect(); + ( + mockedClient as { + onerror?: (error: unknown) => void; + } + ).onerror?.(new Error('Maximum reconnection attempts (2) exceeded.')); + + expect(client.getStatus()).toBe(MCPServerStatus.DISCONNECTED); + expect(MOCK_CONTEXT.emitMcpDiagnostic).toHaveBeenCalledWith( + 'error', + 'MCP ERROR (test-server)', + expect.anything(), + 'test-server', + ); + }); + + it('reconnects and retries a tool call once after a recoverable auth error', async () => { + const firstClient = { + connect: vi.fn(), + discover: vi.fn(), + disconnect: vi.fn(), + close: vi.fn(), + getStatus: vi.fn(), + registerCapabilities: vi.fn(), + setRequestHandler: vi.fn(), + setNotificationHandler: vi.fn(), + getServerCapabilities: vi.fn().mockReturnValue({ tools: {} }), + listTools: vi.fn().mockResolvedValue({ + tools: [ + { + name: 'testTool', + description: 'A test tool', + inputSchema: { type: 'object', properties: {} }, + }, + ], + }), + listPrompts: vi.fn().mockResolvedValue({ prompts: [] }), + request: vi.fn().mockResolvedValue({}), + callTool: vi + .fn() + .mockRejectedValue( + new StreamableHTTPError( + 401, + 'Server returned 401 after successful authentication', + ), + ), + }; + const secondClient = { + connect: vi.fn(), + discover: vi.fn(), + disconnect: vi.fn(), + close: vi.fn(), + getStatus: vi.fn(), + registerCapabilities: vi.fn(), + setRequestHandler: vi.fn(), + setNotificationHandler: vi.fn(), + getServerCapabilities: vi.fn().mockReturnValue({ tools: {} }), + listTools: vi.fn().mockResolvedValue({ + tools: [ + { + name: 'testTool', + description: 'A test tool', + inputSchema: { type: 'object', properties: {} }, + }, + ], + }), + listPrompts: vi.fn().mockResolvedValue({ prompts: [] }), + request: vi.fn().mockResolvedValue({}), + callTool: vi.fn().mockResolvedValue({ + content: [{ type: 'text', text: 'ok' }], + }), + }; + vi.mocked(ClientLib.Client) + .mockReturnValueOnce(firstClient as unknown as ClientLib.Client) + .mockReturnValueOnce(secondClient as unknown as ClientLib.Client); + vi.spyOn(SdkClientStdioLib, 'StdioClientTransport').mockReturnValue( + {} as SdkClientStdioLib.StdioClientTransport, + ); + + const mockedToolRegistry = { + registerTool: vi.fn(), + sortTools: vi.fn(), + getToolsByServer: vi.fn().mockReturnValue([]), + getMessageBus: vi.fn().mockReturnValue(undefined), + } as unknown as ToolRegistry; + const promptRegistry = { + registerPrompt: vi.fn(), + getPromptsByServer: vi.fn().mockReturnValue([]), + removePromptsByServer: vi.fn(), + } as unknown as PromptRegistry; + const resourceRegistry = { + getResourcesByServer: vi.fn().mockReturnValue([]), + setResourcesForServer: vi.fn(), + removeResourcesByServer: vi.fn(), + } as unknown as ResourceRegistry; + + const client = new McpClient( + 'test-server', + { command: 'test-command' }, + workspaceContext, + MOCK_CONTEXT, + false, + '0.0.1', + ); + + await client.connect(); + await client.discoverInto(MOCK_CONTEXT, { + toolRegistry: mockedToolRegistry, + promptRegistry, + resourceRegistry, + }); + + const discoveredTool = vi.mocked(mockedToolRegistry.registerTool).mock + .calls[0][0] as DiscoveredMCPTool; + const invocation = discoveredTool.build({}); + const result = await invocation.execute(new AbortController().signal); + + expect(result.error).toBeUndefined(); + expect(firstClient.callTool).toHaveBeenCalledTimes(1); + expect(firstClient.close).toHaveBeenCalledTimes(1); + expect(secondClient.callTool).toHaveBeenCalledTimes(1); + expect(client.getStatus()).toBe(MCPServerStatus.CONNECTED); + expect(MOCK_CONTEXT.emitMcpDiagnostic).not.toHaveBeenCalledWith( + 'error', + 'MCP ERROR (test-server)', + expect.anything(), + 'test-server', + ); + }); + it('should register tool with readOnlyHint and preserve annotations', async () => { const mockedClient = { connect: vi.fn(), @@ -1883,6 +2148,41 @@ describe('mcp-client', () => { }); }); + it('uses a refresh-capable auth provider for stored OAuth credentials on HTTP transport', async () => { + const mockStoredCredentials = { + serverName: 'test-server', + token: { + accessToken: 'stored-access-token', + refreshToken: 'stored-refresh-token', + tokenType: 'Bearer', + expiresAt: Date.now() - 1000, + }, + clientId: 'stored-client-id', + tokenUrl: 'https://auth.example.com/token', + updatedAt: Date.now(), + }; + vi.mocked(MCPOAuthTokenStorage).mockReturnValue({ + getCredentials: vi.fn().mockResolvedValue(mockStoredCredentials), + } as unknown as MCPOAuthTokenStorage); + + const transport = await createTransport( + 'test-server', + { + url: 'http://test-server', + }, + false, + MOCK_CONTEXT, + ); + + expect(transport).toBeInstanceOf(StreamableHTTPClientTransport); + const testableTransport = transport as unknown as TestableTransport; + expect(testableTransport._authProvider).toBeDefined(); + expect( + testableTransport._requestInit?.headers?.['Authorization'], + ).toBeUndefined(); + expect(vi.mocked(MCPOAuthProvider)).not.toHaveBeenCalled(); + }); + it('with type="http" and headers applies headers correctly', async () => { const transport = await createTransport( 'test-server', diff --git a/packages/core/src/tools/mcp-client.ts b/packages/core/src/tools/mcp-client.ts index 58b7b6c8e22..fc39fced834 100644 --- a/packages/core/src/tools/mcp-client.ts +++ b/packages/core/src/tools/mcp-client.ts @@ -55,6 +55,7 @@ import type { McpAuthProvider } from '../mcp/auth-provider.js'; import { MCPOAuthProvider } from '../mcp/oauth-provider.js'; import { MCPOAuthTokenStorage } from '../mcp/oauth-token-storage.js'; import { OAuthUtils } from '../mcp/oauth-utils.js'; +import { StoredOAuthMcpProvider } from '../mcp/stored-oauth-provider.js'; import type { PromptRegistry } from '../prompts/prompt-registry.js'; import { getErrorMessage, @@ -80,6 +81,7 @@ import { type EnvironmentSanitizationConfig, } from '../services/environmentSanitization.js'; import { expandEnvVars } from '../utils/envExpansion.js'; +import { isRecord } from '../utils/markdownUtils.js'; import { GEMINI_CLI_IDENTIFICATION_ENV_VAR, GEMINI_CLI_IDENTIFICATION_ENV_VAR_VALUE, @@ -146,6 +148,7 @@ export class McpClient implements McpProgressReporter { private client: Client | undefined; private transport: Transport | undefined; private status: MCPServerStatus = MCPServerStatus.DISCONNECTED; + private authRecoveryInProgress = false; private isRefreshingTools: boolean = false; private pendingToolRefresh: boolean = false; private isRefreshingResources: boolean = false; @@ -175,6 +178,147 @@ export class McpClient implements McpProgressReporter { return this.serverName; } + private isRecoverableTransportError(error: unknown): boolean { + const message = getErrorMessage(error); + + if (isAuthenticationError(error)) { + return true; + } + + return ( + message.includes('401') || + message.includes('Unauthorized') || + message.includes('No auth provider') || + message.includes('after successful authentication') + ); + } + + private isTransientStreamableHttpBackgroundError(error: unknown): boolean { + if (!usesHttpTransport(this.serverConfig)) { + return false; + } + + const message = getErrorMessage(error); + return ( + message.includes('SSE stream disconnected') || + message.includes('Failed to reconnect SSE stream') || + message.includes('Failed to open SSE stream') + ); + } + + private wireClientErrorHandling(client: Client): void { + const originalOnError = client.onerror; + client.onerror = (error) => { + if (this.status !== MCPServerStatus.CONNECTED) { + return; + } + + if ( + this.authRecoveryInProgress && + this.isRecoverableTransportError(error) + ) { + debugLogger.log( + `Ignoring MCP transport error during auth recovery for '${this.serverName}': ${getErrorMessage(error)}`, + ); + return; + } + + if (this.isTransientStreamableHttpBackgroundError(error)) { + debugLogger.log( + `Ignoring transient Streamable HTTP background transport error for '${this.serverName}': ${getErrorMessage(error)}`, + ); + return; + } + + if (originalOnError) originalOnError(error); + this.cliConfig.emitMcpDiagnostic( + 'error', + `MCP ERROR (${this.serverName})`, + error, + this.serverName, + ); + this.updateStatus(MCPServerStatus.DISCONNECTED); + }; + } + + private async establishConnection(): Promise { + this.client = await connectToMcpServer( + this.clientVersion, + this.serverName, + this.serverConfig, + this.debugMode, + this.workspaceContext, + this.cliConfig, + ); + + this.registerNotificationHandlers(); + this.wireClientErrorHandling(this.client); + this.updateStatus(MCPServerStatus.CONNECTED); + } + + private async reconnect(): Promise { + if (this.client) { + try { + await this.client.close(); + } catch (error) { + debugLogger.debug( + `Ignoring MCP client close error during reconnect for '${this.serverName}': ${getErrorMessage(error)}`, + ); + } + } + + this.client = undefined; + this.updateStatus(MCPServerStatus.CONNECTING); + await this.establishConnection(); + } + + private async runWithAuthRecovery( + operation: () => Promise, + ): Promise { + this.authRecoveryInProgress = true; + try { + return await operation(); + } finally { + this.authRecoveryInProgress = false; + } + } + + private async callToolWithReconnect( + call: FunctionCall, + progressToken: string, + ): Promise { + this.assertConnected(); + + const request = { + name: call.name!, + // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion + arguments: call.args as Record, + _meta: { progressToken }, + }; + + try { + return await this.client!.callTool(request, undefined, { + timeout: this.serverConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC, + }); + } catch (error) { + if (!this.isRecoverableTransportError(error)) { + throw error; + } + + debugLogger.log( + `Recoverable MCP tool call error for '${this.serverName}', reconnecting and retrying once: ${getErrorMessage(error)}`, + ); + + return this.runWithAuthRecovery(async () => { + await this.reconnect(); + + return this.client!.callTool(request, undefined, { + timeout: this.serverConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC, + }); + }); + } + } + /** * Connects to the MCP server. */ @@ -186,32 +330,7 @@ export class McpClient implements McpProgressReporter { } this.updateStatus(MCPServerStatus.CONNECTING); try { - this.client = await connectToMcpServer( - this.clientVersion, - this.serverName, - this.serverConfig, - this.debugMode, - this.workspaceContext, - this.cliConfig, - ); - - this.registerNotificationHandlers(); - - const originalOnError = this.client.onerror; - this.client.onerror = (error) => { - if (this.status !== MCPServerStatus.CONNECTED) { - return; - } - if (originalOnError) originalOnError(error); - this.cliConfig.emitMcpDiagnostic( - 'error', - `MCP ERROR (${this.serverName})`, - error, - this.serverName, - ); - this.updateStatus(MCPServerStatus.DISCONNECTED); - }; - this.updateStatus(MCPServerStatus.CONNECTED); + await this.establishConnection(); } catch (error) { this.updateStatus(MCPServerStatus.DISCONNECTED); throw error; @@ -335,6 +454,8 @@ export class McpClient implements McpProgressReporter { timeout: this.serverConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC, }), progressReporter: this, + toolCallRunner: (call, progressToken) => + this.callToolWithReconnect(call, progressToken), }, ); } @@ -1027,6 +1148,39 @@ function createAuthProvider( return undefined; } +function usesHttpTransport(mcpServerConfig: MCPServerConfig): boolean { + if (mcpServerConfig.httpUrl) { + return true; + } + + if (mcpServerConfig.type === 'sse') { + return false; + } + + return !!mcpServerConfig.url; +} + +async function createStoredOAuthAuthProvider( + mcpServerName: string, + mcpServerConfig: MCPServerConfig, +): Promise { + if (!usesHttpTransport(mcpServerConfig)) { + return undefined; + } + + const tokenStorage = new MCPOAuthTokenStorage(); + const credentials = await tokenStorage.getCredentials(mcpServerName); + if (!credentials) { + return undefined; + } + + return new StoredOAuthMcpProvider( + mcpServerName, + mcpServerConfig.oauth ?? {}, + tokenStorage, + ); +} + /** * Create a transport with OAuth token for the given server configuration. * @@ -1276,6 +1430,10 @@ export async function discoverTools( timeout?: number; signal?: AbortSignal; progressReporter?: McpProgressReporter; + toolCallRunner?: ( + call: FunctionCall, + progressToken: string, + ) => Promise; }, ): Promise { try { @@ -1295,6 +1453,7 @@ export async function discoverTools( toolDef, mcpServerConfig.timeout ?? MCP_DEFAULT_TIMEOUT_MSEC, options?.progressReporter, + options?.toolCallRunner, ); // Extract annotations from the tool definition @@ -1355,6 +1514,10 @@ class McpCallableTool implements CallableTool { private readonly toolDef: McpTool, private readonly timeout: number, private readonly progressReporter?: McpProgressReporter, + private readonly toolCallRunner?: ( + call: FunctionCall, + progressToken: string, + ) => Promise, ) {} async tool(): Promise { @@ -1386,22 +1549,29 @@ class McpCallableTool implements CallableTool { } try { - const result = await this.client.callTool( - { - name: call.name!, - // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion - arguments: call.args as Record, - _meta: { progressToken }, - }, - undefined, - { timeout: this.timeout }, - ); + const result = this.toolCallRunner + ? await this.toolCallRunner(call, progressToken) + : await this.client.callTool( + { + name: call.name!, + // eslint-disable-next-line @typescript-eslint/no-unsafe-type-assertion + arguments: call.args as Record, + _meta: { progressToken }, + }, + undefined, + { timeout: this.timeout }, + ); return [ { functionResponse: { name: call.name, - response: result, + response: + result === undefined + ? undefined + : isRecord(result) + ? result + : { result }, }, }, ]; @@ -2186,7 +2356,11 @@ export async function createTransport( } } if (mcpServerConfig.httpUrl || mcpServerConfig.url) { - const authProvider = createAuthProvider(mcpServerConfig); + let authProvider = createAuthProvider(mcpServerConfig); + authProvider ??= await createStoredOAuthAuthProvider( + mcpServerName, + mcpServerConfig, + ); const headers: Record = (await authProvider?.getRequestHeaders?.()) ?? {};