diff --git a/embed-pool.js b/embed-pool.js new file mode 100644 index 0000000..51c5e7b --- /dev/null +++ b/embed-pool.js @@ -0,0 +1,146 @@ +/** + * Encapsulates embed-worker lifecycle management. + * + * @param {() => import("node:events").EventEmitter} workerFactory Creates a worker-like object. + * @param {object} [opts] + * @param {number} [opts.embedTimeout=60000] Per-embed timeout in ms. + * @param {number} [opts.restartDelay=2000] Delay before restarting a crashed worker. + * @param {number} [opts.workerReadyTimeout=30000] Max time to wait for a restarting worker. + * @param {number} [opts.maxRestartDelay=60000] Backoff cap for repeated restart failures. + */ +export function createEmbedPool(workerFactory, opts = {}) { + const EMBED_TIMEOUT_MS = opts.embedTimeout ?? 60_000; + const RESTART_DELAY_MS = opts.restartDelay ?? 2000; + const WORKER_READY_TIMEOUT_MS = opts.workerReadyTimeout ?? 30_000; + const MAX_RESTART_DELAY_MS = opts.maxRestartDelay ?? 60_000; + + let worker = null; + let workerAlive = false; + let shuttingDown = false; + let embedIdCounter = 0; + let currentRestartDelay = RESTART_DELAY_MS; + const pendingEmbeds = new Map(); + + let workerReadyResolve = null; + let workerReadyPromise = null; + + function rejectAllPending(reason) { + for (const [, { reject, timer }] of pendingEmbeds) { + clearTimeout(timer); + reject(new Error(reason)); + } + pendingEmbeds.clear(); + } + + let restartTimer = null; + + function scheduleRestart(code) { + if (shuttingDown) return; + const delay = currentRestartDelay; + process.stderr.write(`[vector-memory] Worker exited (code ${code}) — restarting in ${delay}ms\n`); + workerReadyPromise = new Promise(resolve => { workerReadyResolve = resolve; }); + restartTimer = setTimeout(() => { + restartTimer = null; + if (shuttingDown) return; + try { + initWorker(); + currentRestartDelay = RESTART_DELAY_MS; + } catch (err) { + process.stderr.write(`[vector-memory] Worker restart failed: ${err.message}\n`); + currentRestartDelay = Math.min(currentRestartDelay * 2, MAX_RESTART_DELAY_MS); + scheduleRestart(code); + return; + } + if (workerReadyResolve) { + workerReadyResolve(); + workerReadyResolve = null; + } + }, delay); + } + + function initWorker() { + worker = workerFactory(); + workerAlive = true; + workerReadyPromise = null; + + worker.on("message", (msg) => { + if (msg.type === "ready") return; + if (msg.type === "error") { + process.stderr.write(`[vector-memory] Embedding model error: ${msg.message}\n`); + return; + } + const pending = pendingEmbeds.get(msg.id); + if (pending) { + clearTimeout(pending.timer); + pendingEmbeds.delete(msg.id); + pending.resolve(msg.embedding); + } + }); + + worker.on("error", (err) => { + process.stderr.write(`[vector-memory] Worker crashed: ${err.message}\n`); + workerAlive = false; + rejectAllPending("Worker crashed: " + err.message); + }); + + worker.on("exit", (code) => { + workerAlive = false; + rejectAllPending("Worker exited with code " + code); + scheduleRestart(code); + }); + } + + async function embed(text) { + if (!workerAlive && workerReadyPromise) { + let timeoutId; + const timeout = new Promise((_, reject) => { + timeoutId = setTimeout( + () => reject(new Error("Embed worker restart timed out")), + WORKER_READY_TIMEOUT_MS + ); + }); + try { + await Promise.race([workerReadyPromise, timeout]); + } finally { + clearTimeout(timeoutId); + } + } + + if (!workerAlive) { + throw new Error("Embed worker is not running"); + } + + return new Promise((resolve, reject) => { + const id = embedIdCounter++; + const timer = setTimeout(() => { + pendingEmbeds.delete(id); + reject(new Error("Embedding timed out after " + EMBED_TIMEOUT_MS + "ms")); + }, EMBED_TIMEOUT_MS); + pendingEmbeds.set(id, { resolve, reject, timer }); + try { + worker.postMessage({ id, text }); + } catch (err) { + clearTimeout(timer); + pendingEmbeds.delete(id); + reject(err); + } + }); + } + + function shutdown() { + shuttingDown = true; + clearTimeout(restartTimer); + rejectAllPending("Pool shutting down"); + if (worker) { + worker.terminate(); + worker = null; + workerAlive = false; + } + if (workerReadyResolve) { + workerReadyResolve(); + workerReadyResolve = null; + } + } + + return { embed, initWorker, shutdown, isAlive: () => workerAlive }; +} diff --git a/index.js b/index.js index 21e3c36..ce34f36 100644 --- a/index.js +++ b/index.js @@ -11,7 +11,7 @@ import { request } from "http"; import { userInfo } from "os"; const __dirname = dirname(fileURLToPath(import.meta.url)); -const COPILOT_DIR = join(homedir(), ".copilot"); +const COPILOT_DIR = process.env.VECTOR_MEMORY_DATA_DIR || join(homedir(), ".copilot"); const EXPECTED_USER = userInfo().username; const PKG = JSON.parse(readFileSync(join(__dirname, "package.json"), "utf-8")); @@ -27,7 +27,7 @@ const PID_FILE = join(COPILOT_DIR, "vector-memory.pid"); const SERVER_DEPS_DIR = join(__dirname, ".server"); const SERVER_DEPS_JSON = join(__dirname, "server-deps.json"); -const SERVER_FILES = ["vector-memory-server.js", "embed-worker.js", "lib.js"]; +const SERVER_FILES = ["vector-memory-server.js", "embed-worker.js", "embed-pool.js", "lib.js"]; /** Check if server deps are available (either in node_modules or .server/) */ function serverDepsInstalled() { diff --git a/package-lock.json b/package-lock.json index 2243087..9c08298 100644 --- a/package-lock.json +++ b/package-lock.json @@ -18,6 +18,9 @@ "devDependencies": { "@eslint/js": "^10.0.1", "eslint": "^10.0.2" + }, + "engines": { + "node": ">=18" } }, "node_modules/@eslint-community/eslint-utils": { diff --git a/package.json b/package.json index e3f0dba..2884b58 100644 --- a/package.json +++ b/package.json @@ -10,6 +10,7 @@ "index.js", "vector-memory-server.js", "embed-worker.js", + "embed-pool.js", "lib.js", "server-deps.json", "LICENSE", @@ -50,9 +51,10 @@ "offline" ], "scripts": { - "lint": "eslint index.js vector-memory-server.js embed-worker.js lib.js", + "lint": "eslint index.js vector-memory-server.js embed-worker.js embed-pool.js lib.js", "test": "node --test test.js", - "test:coverage": "node --test --experimental-test-coverage --test-coverage-lines=100 --test-coverage-branches=100 --test-coverage-functions=100 --test-coverage-exclude=test.js test.js", + "test:integration": "node --test test-integration.js", + "test:coverage": "node --test --experimental-test-coverage --test-coverage-lines=100 --test-coverage-branches=100 --test-coverage-functions=100 --test-coverage-exclude=test.js --test-coverage-exclude=test-integration.js test.js", "check": "npm run lint && npm run test:coverage" }, "dependencies": { diff --git a/test-integration.js b/test-integration.js new file mode 100644 index 0000000..2c07071 --- /dev/null +++ b/test-integration.js @@ -0,0 +1,261 @@ +/** + * Integration tests for the MCP STDIO proxy + HTTP server pipeline. + * + * Spawns index.js (the proxy), which spawns vector-memory-server.js (the + * HTTP server) on first tool call. All data goes to a temp directory via + * VECTOR_MEMORY_DATA_DIR, leaving the real ~/.copilot/ untouched. + * + * Refs #5 + */ + +import { describe, it, before, after } from "node:test"; +import assert from "node:assert/strict"; +import { spawn } from "node:child_process"; +import { mkdtempSync, rmSync, existsSync } from "node:fs"; +import { join } from "node:path"; +import { tmpdir } from "node:os"; +import { fileURLToPath } from "node:url"; +import { dirname } from "node:path"; + +const __dirname = dirname(fileURLToPath(import.meta.url)); +const INDEX_JS = join(__dirname, "index.js"); + +// --- MCP JSON-RPC helpers --- + +let msgId = 0; + +function jsonrpc(method, params = {}) { + return JSON.stringify({ jsonrpc: "2.0", id: ++msgId, method, params }); +} + +function notification(method, params = {}) { + return JSON.stringify({ jsonrpc: "2.0", method, params }); +} + +/** + * Spawns the MCP proxy, performs the initialize handshake, and returns + * a helper object for sending tool calls and reading responses. + */ +function createMcpClient(env = {}) { + const child = spawn(process.execPath, [INDEX_JS], { + stdio: ["pipe", "pipe", "pipe"], + env: { ...process.env, ...env }, + windowsHide: true, + }); + + let buffer = ""; + const pending = new Map(); + + child.stdout.on("data", (chunk) => { + buffer += chunk.toString(); + // MCP messages are newline-delimited JSON + let nl; + while ((nl = buffer.indexOf("\n")) !== -1) { + const line = buffer.slice(0, nl).trim(); + buffer = buffer.slice(nl + 1); + if (!line) continue; + try { + const msg = JSON.parse(line); + if (msg.id != null && pending.has(msg.id)) { + pending.get(msg.id)(msg); + pending.delete(msg.id); + } + } catch { /* ignore non-JSON lines */ } + } + }); + + function send(text) { + child.stdin.write(text + "\n"); + } + + function request(method, params = {}, timeoutMs = 120_000) { + const id = ++msgId; + return new Promise((resolve, reject) => { + const timer = setTimeout(() => { + pending.delete(id); + reject(new Error(`MCP request "${method}" (id=${id}) timed out after ${timeoutMs}ms`)); + }, timeoutMs); + + pending.set(id, (msg) => { + clearTimeout(timer); + resolve(msg); + }); + + send(JSON.stringify({ jsonrpc: "2.0", id, method, params })); + }); + } + + async function initialize() { + const resp = await request("initialize", { + protocolVersion: "2024-11-05", + capabilities: {}, + clientInfo: { name: "integration-test", version: "0.1" }, + }); + // Send initialized notification (required by MCP spec) + send(notification("notifications/initialized")); + return resp; + } + + async function callTool(name, args = {}, timeoutMs = 120_000) { + return request("tools/call", { name, arguments: args }, timeoutMs); + } + + async function listTools() { + return request("tools/list"); + } + + function kill() { + try { child.stdin.end(); } catch {} + try { child.kill(); } catch {} + } + + return { initialize, callTool, listTools, request, kill, child }; +} + +// --- Integration tests --- + +describe("MCP Integration (end-to-end)", { timeout: 180_000 }, () => { + let tmpDir; + let client; + let testPort; + + before(async () => { + // Create isolated temp data directory + tmpDir = mkdtempSync(join(tmpdir(), "vector-memory-test-")); + + // Use a random high port to avoid conflicting with a running server + testPort = 40000 + Math.floor(Math.random() * 20000); + + client = createMcpClient({ + VECTOR_MEMORY_DATA_DIR: tmpDir, + VECTOR_MEMORY_PORT: String(testPort), + VECTOR_MEMORY_IDLE_TIMEOUT: "0", // disable idle shutdown + }); + + // Perform MCP handshake + const initResp = await client.initialize(); + assert.ok(initResp.result, "initialize should return a result"); + assert.ok(initResp.result.protocolVersion, "should include protocol version"); + }); + + after(async () => { + if (client) client.kill(); + + // Give child processes a moment to exit + await new Promise(r => setTimeout(r, 1000)); + + // Kill any server on our test port + try { + const { execSync } = await import("node:child_process"); + const out = execSync("netstat -ano", { encoding: "utf-8", windowsHide: true }); + for (const line of out.split("\n")) { + if (line.includes(`:${testPort}`) && line.includes("LISTENING")) { + const pid = parseInt(line.trim().split(/\s+/).pop()); + if (!isNaN(pid) && pid > 0) { + try { process.kill(pid); } catch {} + } + } + } + } catch {} + + // Clean up temp directory + if (tmpDir && existsSync(tmpDir)) { + rmSync(tmpDir, { recursive: true, force: true }); + } + }); + + it("tools/list returns vector_search and vector_reindex", async () => { + const resp = await client.listTools(); + assert.ok(resp.result, "should have result"); + const tools = resp.result.tools; + assert.ok(Array.isArray(tools), "tools should be an array"); + + const names = tools.map(t => t.name).sort(); + assert.deepEqual(names, ["vector_reindex", "vector_search"]); + + // vector_search should have query and limit params + const searchTool = tools.find(t => t.name === "vector_search"); + assert.ok(searchTool.inputSchema.properties.query, "vector_search should have query param"); + assert.ok(searchTool.inputSchema.properties.limit, "vector_search should have limit param"); + + // vector_reindex should have no required params + const reindexTool = tools.find(t => t.name === "vector_reindex"); + assert.ok(reindexTool, "vector_reindex should exist"); + }); + + it("vector_search with valid query returns results (empty DB = no results)", async () => { + const resp = await client.callTool("vector_search", { query: "test query", limit: 5 }); + assert.ok(resp.result, "should have result"); + assert.ok(Array.isArray(resp.result.content), "should have content array"); + assert.equal(resp.result.content[0].type, "text"); + const text = resp.result.content[0].text; + // Empty temp DB → "No results found." or worker error (acceptable on first run) + assert.ok( + text.includes("No results") || + text.includes("score:") || + text.includes("unavailable") || + text.includes("Error"), + `Unexpected response: ${text.slice(0, 300)}` + ); + }); + + it("vector_search with missing query returns validation error", async () => { + const resp = await client.callTool("vector_search", {}); + // MCP SDK returns validation errors as result with isError: true + assert.ok(resp.result || resp.error, "should have result or error"); + if (resp.result) { + assert.ok(resp.result.isError, "should be flagged as error"); + assert.ok(resp.result.content[0].text.includes("invalid") || + resp.result.content[0].text.includes("Invalid") || + resp.result.content[0].text.includes("required"), + `Expected validation error, got: ${resp.result.content[0].text.slice(0, 200)}`); + } + }); + + it("vector_search with invalid limit type returns validation error", async () => { + const resp = await client.callTool("vector_search", { query: "test", limit: "not a number" }); + assert.ok(resp.result || resp.error, "should have result or error"); + if (resp.result) { + assert.ok(resp.result.isError, "should be flagged as error"); + assert.ok(resp.result.content[0].text.includes("invalid") || + resp.result.content[0].text.includes("Invalid") || + resp.result.content[0].text.includes("number"), + `Expected type validation error, got: ${resp.result.content[0].text.slice(0, 200)}`); + } + }); + + it("vector_reindex returns count or session store message", async () => { + const resp = await client.callTool("vector_reindex", {}); + assert.ok(resp.result, "should have result"); + assert.ok(Array.isArray(resp.result.content), "should have content array"); + const text = resp.result.content[0].text; + assert.ok( + text.includes("Reindexed") || + text.includes("Session store") || + text.includes("not found") || + text.includes("unavailable") || + text.includes("Error"), + `Expected reindex result or error message, got: ${text.slice(0, 300)}` + ); + }); + + it("calling unknown tool returns error", async () => { + const resp = await client.callTool("nonexistent_tool", {}); + // MCP SDK returns unknown tool as either error or result with isError + assert.ok( + resp.error || (resp.result && resp.result.isError), + `Expected error for unknown tool, got: ${JSON.stringify(resp.result || resp.error).slice(0, 200)}` + ); + }); + + it("PID file is created in temp data dir", async () => { + // The server should have written its PID file to our temp dir + const pidFile = join(tmpDir, "vector-memory.pid"); + // Give the server a moment if it's still starting + for (let i = 0; i < 10; i++) { + if (existsSync(pidFile)) break; + await new Promise(r => setTimeout(r, 500)); + } + assert.ok(existsSync(pidFile), "PID file should exist in temp data dir (not ~/.copilot/)"); + }); +}); diff --git a/test.js b/test.js index a234b8c..f1b4dfa 100644 --- a/test.js +++ b/test.js @@ -1,7 +1,9 @@ import { describe, it } from "node:test"; import assert from "node:assert/strict"; import { Readable } from "node:stream"; +import { EventEmitter } from "node:events"; import { filterUnindexed, dedup, postProcessResults, isOurServer, isIndexable, userPort, BASE_PORT, MIN_SCORE, createHandler } from "./lib.js"; +import { createEmbedPool } from "./embed-pool.js"; // --- Mock helpers for handler tests --- @@ -446,3 +448,498 @@ describe("handleRequest - /reindex", () => { assert.equal(res.statusCode, 500); }); }); + +// --- MockWorker for embed pool tests --- + +class MockWorker extends EventEmitter { + constructor() { + super(); + this.messages = []; + this.terminated = false; + } + postMessage(msg) { + this.messages.push(msg); + } + terminate() { + this.terminated = true; + } +} + +function mockWorkerFactory() { + const workers = []; + const factory = () => { + const w = new MockWorker(); + workers.push(w); + return w; + }; + factory.workers = workers; + return factory; +} + +// --- Embed pool tests --- + +describe("createEmbedPool", () => { + it("embeds text through the worker (happy path)", async () => { + const factory = mockWorkerFactory(); + const pool = createEmbedPool(factory); + pool.initWorker(); + + const p = pool.embed("hello"); + const msg = factory.workers[0].messages[0]; + factory.workers[0].emit("message", { id: msg.id, embedding: Buffer.from([1, 2, 3]) }); + + const result = await p; + assert.deepEqual(result, Buffer.from([1, 2, 3])); + pool.shutdown(); + }); + + it("rejects embed when worker was never started", async () => { + const pool = createEmbedPool(() => new MockWorker()); + await assert.rejects(() => pool.embed("hello"), /Embed worker is not running/); + }); + + it("rejects pending embeds on worker error", async () => { + const factory = mockWorkerFactory(); + const pool = createEmbedPool(factory); + pool.initWorker(); + + const p = pool.embed("hello"); + factory.workers[0].emit("error", new Error("segfault")); + + await assert.rejects(() => p, /Worker crashed/); + pool.shutdown(); + }); + + it("restarts worker on non-zero exit", async () => { + const factory = mockWorkerFactory(); + const pool = createEmbedPool(factory, { restartDelay: 50 }); + pool.initWorker(); + + factory.workers[0].emit("exit", 1); + await new Promise(r => setTimeout(r, 100)); + + assert.equal(factory.workers.length, 2, "worker should have been recreated"); + assert.equal(pool.isAlive(), true); + pool.shutdown(); + }); + + // === BUG TESTS: these demonstrate the issues we're fixing === + + it("restarts worker on code-0 exit (bug: currently does not)", async () => { + const factory = mockWorkerFactory(); + const pool = createEmbedPool(factory, { restartDelay: 50 }); + pool.initWorker(); + + factory.workers[0].emit("exit", 0); + await new Promise(r => setTimeout(r, 100)); + + assert.equal(factory.workers.length, 2, "worker should restart even on clean exit"); + assert.equal(pool.isAlive(), true, "pool should be alive after code-0 restart"); + pool.shutdown(); + }); + + it("waits for worker restart instead of rejecting immediately (bug: currently rejects)", async () => { + const factory = mockWorkerFactory(); + const pool = createEmbedPool(factory, { restartDelay: 50, workerReadyTimeout: 5000 }); + pool.initWorker(); + + // Crash the worker + factory.workers[0].emit("exit", 1); + + // Immediately try to embed — should wait for restart, not reject + let rejected = false; + let error = null; + const embedPromise = pool.embed("test text") + .catch(e => { rejected = true; error = e; }); + + // Give microtasks a chance to settle (synchronous rejection would be caught here) + await new Promise(r => setTimeout(r, 10)); + + assert.equal(rejected, false, + `embed() rejected immediately with "${error?.message}" instead of waiting for worker restart`); + + // Let restart happen + await new Promise(r => setTimeout(r, 100)); + + // Respond from new worker + assert.equal(factory.workers.length, 2, "worker should have restarted"); + const msg = factory.workers[1].messages[0]; + factory.workers[1].emit("message", { id: msg.id, embedding: Buffer.from([4, 5, 6]) }); + + await embedPromise; + assert.equal(rejected, false); + pool.shutdown(); + }); + + it("does not restart after explicit shutdown", async () => { + const factory = mockWorkerFactory(); + const pool = createEmbedPool(factory, { restartDelay: 50 }); + pool.initWorker(); + pool.shutdown(); + + factory.workers[0].emit("exit", 1); + await new Promise(r => setTimeout(r, 100)); + + assert.equal(factory.workers.length, 1, "should not restart after shutdown"); + }); + + it("times out if worker restart takes too long", async () => { + const factory = mockWorkerFactory(); + const pool = createEmbedPool(factory, { restartDelay: 10000, workerReadyTimeout: 50 }); + pool.initWorker(); + + factory.workers[0].emit("exit", 1); + + await assert.rejects(() => pool.embed("test"), /restart timed out/i); + pool.shutdown(); + }); + + it("handles worker 'error' message type (model error)", async () => { + const factory = mockWorkerFactory(); + const pool = createEmbedPool(factory); + pool.initWorker(); + + // Send an error-type message (model failed to load, etc.) + factory.workers[0].emit("message", { type: "error", message: "ONNX load failed" }); + + // Pool should still be alive — this is a non-fatal model error, not a crash + assert.equal(pool.isAlive(), true); + pool.shutdown(); + }); + + it("times out embed when worker never responds", async () => { + const factory = mockWorkerFactory(); + const pool = createEmbedPool(factory, { embedTimeout: 50 }); + pool.initWorker(); + + // Send an embed but never respond from the mock worker + await assert.rejects(() => pool.embed("hello"), /timed out after 50ms/); + pool.shutdown(); + }); + + it("rejects embed when postMessage throws", async () => { + const factory = mockWorkerFactory(); + const pool = createEmbedPool(factory); + pool.initWorker(); + + // Make postMessage throw (simulates worker in bad state) + factory.workers[0].postMessage = () => { throw new Error("DataCloneError"); }; + + await assert.rejects(() => pool.embed("hello"), /DataCloneError/); + pool.shutdown(); + }); + + it("ignores 'ready' message type without affecting pending embeds", async () => { + const factory = mockWorkerFactory(); + const pool = createEmbedPool(factory); + pool.initWorker(); + + const p = pool.embed("hello"); + // Send a ready message — should be ignored, embed still pending + factory.workers[0].emit("message", { type: "ready" }); + + // Now send the real response + const msg = factory.workers[0].messages[0]; + factory.workers[0].emit("message", { id: msg.id, embedding: Buffer.from([9]) }); + + const result = await p; + assert.deepEqual(result, Buffer.from([9])); + pool.shutdown(); + }); + + it("ignores response for unknown embed id", async () => { + const factory = mockWorkerFactory(); + const pool = createEmbedPool(factory); + pool.initWorker(); + + // Send a response for an ID that was never requested — should not throw + factory.workers[0].emit("message", { id: 99999, embedding: Buffer.from([1]) }); + pool.shutdown(); + }); + + it("shutdown is idempotent when no worker was started", () => { + const pool = createEmbedPool(() => new MockWorker()); + // Should not throw + pool.shutdown(); + pool.shutdown(); + }); + + it("BUG: workerFactory() throwing in scheduleRestart retries with backoff", async () => { + let callCount = 0; + const workers = []; + const factory = () => { + callCount++; + if (callCount === 2) throw new Error("Worker constructor exploded"); + const w = new MockWorker(); + workers.push(w); + return w; + }; + const pool = createEmbedPool(factory, { restartDelay: 50, workerReadyTimeout: 2000, maxRestartDelay: 200 }); + pool.initWorker(); + + // Worker exits cleanly — first restart throws, second should succeed via backoff + workers[0].emit("exit", 0); + + // Wait for retry + await new Promise(r => setTimeout(r, 500)); + + assert.ok(callCount >= 3, `Expected at least 3 factory calls, got ${callCount}`); + assert.equal(pool.isAlive(), true, "pool should recover via backoff retry"); + pool.shutdown(); + }); + + it("BUG: shutdown during restart delay still spawns zombie worker", async () => { + const factory = mockWorkerFactory(); + const pool = createEmbedPool(factory, { restartDelay: 50 }); + pool.initWorker(); + + // Worker exits — restart is scheduled with 50ms delay + factory.workers[0].emit("exit", 1); + + // Shutdown immediately (before the 50ms restart fires) + pool.shutdown(); + + // Wait for the restart timer to fire (guard should catch it) + await new Promise(r => setTimeout(r, 150)); + + // Should NOT have created a second worker — guard prevents it + assert.equal(factory.workers.length, 1, + `Expected 1 worker but got ${factory.workers.length} — zombie worker spawned after shutdown`); + }); + + it("shutdown resolves pending restart waiters", async () => { + const factory = mockWorkerFactory(); + const pool = createEmbedPool(factory, { restartDelay: 10000, workerReadyTimeout: 5000 }); + pool.initWorker(); + + // Trigger exit — restart is scheduled but slow + factory.workers[0].emit("exit", 1); + + // Start an embed — it will wait for the restart + const embedPromise = pool.embed("hello").catch(e => e); + + // Shutdown while it's waiting + pool.shutdown(); + + const result = await embedPromise; + assert.ok(result instanceof Error); + assert.match(result.message, /not running|shutting down/i); + }); + + // === BUG: code-0 exit orphans in-flight embeds === + + it("BUG: code-0 exit rejects in-flight embeds (not orphaned for 60s)", async () => { + const factory = mockWorkerFactory(); + const pool = createEmbedPool(factory, { embedTimeout: 5000, restartDelay: 50 }); + pool.initWorker(); + + // Start an embed (worker won't respond) + const embedPromise = pool.embed("hello").catch(e => e); + await new Promise(r => setTimeout(r, 10)); + + // Worker exits cleanly — in-flight embed should be rejected promptly + const start = Date.now(); + factory.workers[0].emit("exit", 0); + + const result = await embedPromise; + const elapsed = Date.now() - start; + + assert.ok(result instanceof Error, "in-flight embed should have been rejected"); + assert.ok(elapsed < 1000, + `embed took ${elapsed}ms to reject — orphaned until embedTimeout instead of rejected on exit`); + pool.shutdown(); + }); + + // === BUG: failed restart = permanent death (no retry) === + + it("BUG: failed restart retries with backoff instead of dying permanently", async () => { + let callCount = 0; + const workers = []; + const factory = () => { + callCount++; + // Fail on attempts 2 and 3, succeed on attempt 4 + if (callCount >= 2 && callCount <= 3) throw new Error("ONNX load failed"); + const w = new MockWorker(); + workers.push(w); + return w; + }; + const pool = createEmbedPool(factory, { + restartDelay: 30, + workerReadyTimeout: 5000, + maxRestartDelay: 200, + }); + pool.initWorker(); + + // Worker exits — first restart attempt will fail, second will fail, third should succeed + workers[0].emit("exit", 0); + + // Wait for backoff retries to play out + await new Promise(r => setTimeout(r, 1500)); + + assert.ok(callCount >= 4, + `Expected at least 4 factory calls (1 init + 2 failures + 1 success), got ${callCount}`); + assert.equal(pool.isAlive(), true, "pool should have recovered after transient failures"); + pool.shutdown(); + }); + + it("POSITIVE: backoff delay increases on consecutive failures", async () => { + let callCount = 0; + const timestamps = []; + const workers = []; + const factory = () => { + callCount++; + timestamps.push(Date.now()); + if (callCount >= 2) throw new Error("still broken"); + const w = new MockWorker(); + workers.push(w); + return w; + }; + const pool = createEmbedPool(factory, { + restartDelay: 50, + workerReadyTimeout: 5000, + maxRestartDelay: 400, + }); + pool.initWorker(); + + workers[0].emit("exit", 0); + // Wait for several backoff attempts + await new Promise(r => setTimeout(r, 2000)); + + // Should have multiple attempts with increasing gaps + assert.ok(callCount >= 4, `Expected at least 4 attempts, got ${callCount}`); + + // Verify delays are increasing (backoff) + for (let i = 2; i < timestamps.length - 1; i++) { + const gap1 = timestamps[i] - timestamps[i - 1]; + const gap2 = timestamps[i + 1] - timestamps[i]; + assert.ok(gap2 >= gap1 * 0.8, // allow 20% timing jitter + `Expected increasing delays but gap ${i}: ${gap1}ms, gap ${i+1}: ${gap2}ms`); + } + pool.shutdown(); + }); + + it("POSITIVE: backoff resets after a successful restart", async () => { + let callCount = 0; + const workers = []; + const factory = () => { + callCount++; + // Fail on second call, succeed on all others + if (callCount === 2) throw new Error("transient failure"); + const w = new MockWorker(); + workers.push(w); + return w; + }; + const pool = createEmbedPool(factory, { + restartDelay: 30, + workerReadyTimeout: 5000, + maxRestartDelay: 500, + }); + pool.initWorker(); + + // First exit → restart fails → retries → succeeds + workers[0].emit("exit", 0); + await new Promise(r => setTimeout(r, 500)); + + assert.equal(pool.isAlive(), true, "pool should have recovered"); + const secondWorkerIdx = workers.length - 1; + + // Second exit → should restart with initial delay (not backoff from previous failure) + workers[secondWorkerIdx].emit("exit", 0); + await new Promise(r => setTimeout(r, 200)); + + assert.equal(pool.isAlive(), true, "pool should restart at base delay after prior success"); + pool.shutdown(); + }); + + it("POSITIVE: backoff caps at maxRestartDelay", async () => { + let callCount = 0; + const timestamps = []; + const workers = []; + const factory = () => { + callCount++; + timestamps.push(Date.now()); + if (callCount >= 2) throw new Error("permanently broken"); + const w = new MockWorker(); + workers.push(w); + return w; + }; + const pool = createEmbedPool(factory, { + restartDelay: 50, + maxRestartDelay: 150, + }); + pool.initWorker(); + workers[0].emit("exit", 0); + + await new Promise(r => setTimeout(r, 2000)); + + // All gaps after the first few should be capped at ~150ms + const gaps = []; + for (let i = 1; i < timestamps.length; i++) { + gaps.push(timestamps[i] - timestamps[i - 1]); + } + // The last few gaps should all be ≤ maxRestartDelay + jitter + const lastGaps = gaps.slice(-3); + for (const gap of lastGaps) { + assert.ok(gap <= 250, + `Gap ${gap}ms exceeds maxRestartDelay (150ms) + reasonable jitter`); + } + pool.shutdown(); + }); + + it("shutdown cancels pending restart timer", async () => { + const factory = mockWorkerFactory(); + const pool = createEmbedPool(factory, { restartDelay: 50 }); + pool.initWorker(); + + // Trigger exit — restart timer starts + factory.workers[0].emit("exit", 1); + + // Shutdown immediately — should cancel the restart + pool.shutdown(); + + // Wait for timer that would have fired + await new Promise(r => setTimeout(r, 100)); + + // Should NOT have created a second worker + assert.equal(factory.workers.length, 1, "shutdown should cancel pending restart"); + }); + + it("belt-and-suspenders: shuttingDown guard catches callback when clearTimeout fails", async () => { + const factory = mockWorkerFactory(); + const pool = createEmbedPool(factory, { restartDelay: 50 }); + pool.initWorker(); + + // Worker exits — restart timer is scheduled + factory.workers[0].emit("exit", 1); + + // Monkeypatch clearTimeout to a no-op, simulating a race where + // the timer callback is already queued when shutdown tries to cancel it + const realClearTimeout = globalThis.clearTimeout; + globalThis.clearTimeout = () => {}; + + try { + pool.shutdown(); + } finally { + globalThis.clearTimeout = realClearTimeout; + } + + // Timer fires because clearTimeout was neutered — the shuttingDown guard catches it + await new Promise(r => setTimeout(r, 150)); + + // No zombie worker spawned — guard did its job + assert.equal(factory.workers.length, 1, + "shuttingDown guard should prevent zombie worker when clearTimeout fails"); + }); + + it("uses default options when none provided", async () => { + const factory = mockWorkerFactory(); + const pool = createEmbedPool(factory); + pool.initWorker(); + + const p = pool.embed("test"); + const msg = factory.workers[0].messages[0]; + factory.workers[0].emit("message", { id: msg.id, embedding: Buffer.from([1]) }); + await p; + pool.shutdown(); + }); +}); diff --git a/vector-memory-server.js b/vector-memory-server.js index 4d894a2..372c223 100644 --- a/vector-memory-server.js +++ b/vector-memory-server.js @@ -9,11 +9,12 @@ import { createServer, request as httpReq } from "http"; import { execSync } from "child_process"; import { userInfo } from "os"; import { filterUnindexed, dedup, postProcessResults, isOurServer, isIndexable, DIMS, createHandler, userPort } from "./lib.js"; +import { createEmbedPool } from "./embed-pool.js"; const __dirname = dirname(fileURLToPath(import.meta.url)); const PKG = JSON.parse(readFileSync(join(__dirname, "package.json"), "utf-8")); const SERVER_USER = userInfo().username; -const COPILOT_DIR = join(homedir(), ".copilot"); +const COPILOT_DIR = process.env.VECTOR_MEMORY_DATA_DIR || join(homedir(), ".copilot"); const SESSION_STORE_PATH = join(COPILOT_DIR, "session-store.db"); const VECTOR_INDEX_PATH = join(COPILOT_DIR, "vector-index.db"); const INDEX_INTERVAL_MS = 15 * 60 * 1000; @@ -24,77 +25,7 @@ const IDLE_CHECK_MS = 60_000; // check every 60s let isIndexing = false; -// --- Embedding via Worker Thread --- - -// Worker is started lazily after we win the singleton race. -// Auto-restarts on crash to prevent zombie server state. -let worker; -let workerAlive = false; -let embedIdCounter = 0; -const pendingEmbeds = new Map(); // id → { resolve, reject } -const EMBED_TIMEOUT_MS = 60_000; - -function rejectAllPending(reason) { - for (const [id, { reject }] of pendingEmbeds) { - reject(new Error(reason)); - } - pendingEmbeds.clear(); -} - -function initWorker() { - worker = new Worker(join(__dirname, "embed-worker.js")); - workerAlive = true; - - worker.on("message", (msg) => { - if (msg.type === "ready") return; - if (msg.type === "error") { - process.stderr.write(`[vector-memory] Embedding model error: ${msg.message}\n`); - return; - } - const pending = pendingEmbeds.get(msg.id); - if (pending) { - clearTimeout(pending.timer); - pendingEmbeds.delete(msg.id); - pending.resolve(msg.embedding); - } - }); - - worker.on("error", (err) => { - process.stderr.write(`[vector-memory] Worker crashed: ${err.message}\n`); - workerAlive = false; - rejectAllPending("Worker crashed: " + err.message); - }); - - worker.on("exit", (code) => { - workerAlive = false; - if (code !== 0) { - process.stderr.write(`[vector-memory] Worker exited with code ${code} — restarting in 2s\n`); - rejectAllPending("Worker exited with code " + code); - setTimeout(() => initWorker(), 2000); - } - }); -} - -function embed(text) { - return new Promise((resolve, reject) => { - if (!workerAlive) { - return reject(new Error("Embed worker is not running")); - } - const id = embedIdCounter++; - const timer = setTimeout(() => { - pendingEmbeds.delete(id); - reject(new Error("Embedding timed out after " + EMBED_TIMEOUT_MS + "ms")); - }, EMBED_TIMEOUT_MS); - pendingEmbeds.set(id, { resolve, reject, timer }); - try { - worker.postMessage({ id, text }); - } catch (err) { - clearTimeout(timer); - pendingEmbeds.delete(id); - reject(err); - } - }); -} +const pool = createEmbedPool(() => new Worker(join(__dirname, "embed-worker.js"))); function openVectorDb() { const db = new Database(VECTOR_INDEX_PATH); @@ -154,7 +85,7 @@ async function indexContent(vecDb, items) { for (const item of items) { if (!isIndexable(item)) continue; - const embedding = await embed(item.content); + const embedding = await pool.embed(item.content); const result = insertMeta.run(item.session_id, item.source_type, item.content, item.source_id ?? null); if (result.changes > 0) { insertVec.run(BigInt(result.lastInsertRowid), embedding); @@ -189,7 +120,7 @@ async function backgroundIndex() { } async function search(vecDb, query, limit = 10) { - const queryEmbedding = await embed(query); + const queryEmbedding = await pool.embed(query); const results = vecDb .prepare( @@ -366,7 +297,7 @@ try { } // --- We won the singleton race — now do the heavy init --- -initWorker(); +pool.initWorker(); { const vecDb = openVectorDb(); @@ -393,7 +324,7 @@ function cleanup() { const pidFile = join(COPILOT_DIR, "vector-memory.pid"); if (existsSync(pidFile)) unlinkSync(pidFile); } catch {} - if (worker) worker.terminate(); + pool.shutdown(); process.exit(0); } process.on("SIGTERM", cleanup);