|
1 | 1 | import { type SetupServerApi, setupServer } from "msw/node"; |
2 | 2 | import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; |
| 3 | +import { classifyAgentError } from "../adapters/claude/conversion/sdk-to-acp"; |
3 | 4 | import type { PostHogAPIClient } from "../posthog-api"; |
4 | 5 | import { createTestRepo, type TestRepo } from "../test/fixtures/api"; |
5 | 6 | import { createPostHogHandlers } from "../test/mocks/msw-handlers"; |
@@ -49,7 +50,42 @@ const QUESTION_META = { |
49 | 50 | ], |
50 | 51 | }; |
51 | 52 |
|
| 53 | +function createTransientPromptError(): Error & { |
| 54 | + data: { classification: string; result: string }; |
| 55 | +} { |
| 56 | + const error = new Error("API Error: terminated") as Error & { |
| 57 | + data: { classification: string; result: string }; |
| 58 | + }; |
| 59 | + error.data = { |
| 60 | + classification: "upstream_stream_terminated", |
| 61 | + result: "API Error: terminated", |
| 62 | + }; |
| 63 | + return error; |
| 64 | +} |
| 65 | + |
| 66 | +function createTransientConnectionError(): Error & { |
| 67 | + data: { classification: string; result: string }; |
| 68 | +} { |
| 69 | + const error = new Error("fetch failed") as Error & { |
| 70 | + data: { classification: string; result: string }; |
| 71 | + }; |
| 72 | + error.data = { |
| 73 | + classification: "upstream_connection_error", |
| 74 | + result: "fetch failed", |
| 75 | + }; |
| 76 | + return error; |
| 77 | +} |
| 78 | + |
52 | 79 | describe("Question relay", () => { |
| 80 | + it.each([ |
| 81 | + ["API Error: terminated", "upstream_stream_terminated"], |
| 82 | + ["API Error: Connection error", "upstream_connection_error"], |
| 83 | + ["something else", "agent_error"], |
| 84 | + [undefined, "agent_error"], |
| 85 | + ])("classifies %p as %s", (message, expected) => { |
| 86 | + expect(classifyAgentError(message)).toBe(expected); |
| 87 | + }); |
| 88 | + |
53 | 89 | let repo: TestRepo; |
54 | 90 | let server: TestableAgentServer; |
55 | 91 | let mswServer: SetupServerApi; |
@@ -514,5 +550,93 @@ describe("Question relay", () => { |
514 | 550 | prompt: [{ type: "text", text: "original task description" }], |
515 | 551 | }); |
516 | 552 | }); |
| 553 | + |
| 554 | + it("does not replay a transient upstream termination before any session activity", async () => { |
| 555 | + vi.spyOn(server.posthogAPI, "getTask").mockResolvedValue({ |
| 556 | + id: "test-task-id", |
| 557 | + title: "t", |
| 558 | + description: "original task description", |
| 559 | + } as unknown as Task); |
| 560 | + vi.spyOn(server.posthogAPI, "getTaskRun").mockResolvedValue({ |
| 561 | + id: "test-run-id", |
| 562 | + task: "test-task-id", |
| 563 | + state: {}, |
| 564 | + } as unknown as TaskRun); |
| 565 | + |
| 566 | + const promptSpy = vi |
| 567 | + .fn() |
| 568 | + .mockRejectedValueOnce(createTransientPromptError()); |
| 569 | + const updateTaskRunSpy = vi |
| 570 | + .spyOn(server.posthogAPI, "updateTaskRun") |
| 571 | + .mockResolvedValue({} as TaskRun); |
| 572 | + server.session = { |
| 573 | + payload: TEST_PAYLOAD, |
| 574 | + acpSessionId: "acp-session", |
| 575 | + clientConnection: { prompt: promptSpy }, |
| 576 | + logWriter: { |
| 577 | + flushAll: vi.fn().mockResolvedValue(undefined), |
| 578 | + getFullAgentResponse: vi.fn().mockReturnValue(null), |
| 579 | + resetTurnMessages: vi.fn(), |
| 580 | + flush: vi.fn().mockResolvedValue(undefined), |
| 581 | + isRegistered: vi.fn().mockReturnValue(true), |
| 582 | + }, |
| 583 | + }; |
| 584 | + |
| 585 | + await server.sendInitialTaskMessage(TEST_PAYLOAD); |
| 586 | + |
| 587 | + expect(promptSpy).toHaveBeenCalledTimes(1); |
| 588 | + expect(updateTaskRunSpy).toHaveBeenCalledWith( |
| 589 | + "test-task-id", |
| 590 | + "test-run-id", |
| 591 | + { |
| 592 | + status: "failed", |
| 593 | + error_message: "Upstream LLM stream terminated", |
| 594 | + }, |
| 595 | + ); |
| 596 | + }); |
| 597 | + |
| 598 | + it("surfaces upstream connection errors with the connection-specific message", async () => { |
| 599 | + vi.spyOn(server.posthogAPI, "getTask").mockResolvedValue({ |
| 600 | + id: "test-task-id", |
| 601 | + title: "t", |
| 602 | + description: "original task description", |
| 603 | + } as unknown as Task); |
| 604 | + vi.spyOn(server.posthogAPI, "getTaskRun").mockResolvedValue({ |
| 605 | + id: "test-run-id", |
| 606 | + task: "test-task-id", |
| 607 | + state: {}, |
| 608 | + } as unknown as TaskRun); |
| 609 | + |
| 610 | + const promptSpy = vi.fn().mockImplementationOnce(async () => { |
| 611 | + throw createTransientConnectionError(); |
| 612 | + }); |
| 613 | + const updateTaskRunSpy = vi |
| 614 | + .spyOn(server.posthogAPI, "updateTaskRun") |
| 615 | + .mockResolvedValue({} as TaskRun); |
| 616 | + server.session = { |
| 617 | + payload: TEST_PAYLOAD, |
| 618 | + acpSessionId: "acp-session", |
| 619 | + clientConnection: { prompt: promptSpy }, |
| 620 | + logWriter: { |
| 621 | + flushAll: vi.fn().mockResolvedValue(undefined), |
| 622 | + getFullAgentResponse: vi.fn().mockReturnValue(null), |
| 623 | + resetTurnMessages: vi.fn(), |
| 624 | + flush: vi.fn().mockResolvedValue(undefined), |
| 625 | + isRegistered: vi.fn().mockReturnValue(true), |
| 626 | + }, |
| 627 | + }; |
| 628 | + |
| 629 | + await server.sendInitialTaskMessage(TEST_PAYLOAD); |
| 630 | + |
| 631 | + expect(promptSpy).toHaveBeenCalledTimes(1); |
| 632 | + expect(updateTaskRunSpy).toHaveBeenCalledWith( |
| 633 | + "test-task-id", |
| 634 | + "test-run-id", |
| 635 | + { |
| 636 | + status: "failed", |
| 637 | + error_message: "Upstream LLM connection error", |
| 638 | + }, |
| 639 | + ); |
| 640 | + }); |
517 | 641 | }); |
518 | 642 | }); |
0 commit comments