From 4ab13eca4ddc90c75292aaf47f0b5deae0765048 Mon Sep 17 00:00:00 2001 From: Peter Steinberger Date: Sun, 1 Mar 2026 22:38:47 +0000 Subject: [PATCH] test(agents): port OpenAI websocket coverage from #24911 Co-authored-by: Jonathan Jing --- src/agents/openai-ws-connection.test.ts | 712 +++++++++++++++ src/agents/openai-ws-stream.e2e.test.ts | 151 ++++ src/agents/openai-ws-stream.test.ts | 1062 +++++++++++++++++++++++ 3 files changed, 1925 insertions(+) create mode 100644 src/agents/openai-ws-connection.test.ts create mode 100644 src/agents/openai-ws-stream.e2e.test.ts create mode 100644 src/agents/openai-ws-stream.test.ts diff --git a/src/agents/openai-ws-connection.test.ts b/src/agents/openai-ws-connection.test.ts new file mode 100644 index 000000000..3122e4f6e --- /dev/null +++ b/src/agents/openai-ws-connection.test.ts @@ -0,0 +1,712 @@ +/** + * Unit tests for OpenAIWebSocketManager + * + * Uses a mock WebSocket implementation to avoid real network calls. + * The mock simulates the ws package's EventEmitter-based API. + */ + +import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; +import type { + ClientEvent, + OpenAIWebSocketEvent, + ResponseCompletedEvent, + ResponseCreateEvent, +} from "./openai-ws-connection.js"; +import { OpenAIWebSocketManager } from "./openai-ws-connection.js"; + +// ───────────────────────────────────────────────────────────────────────────── +// Mock WebSocket (hoisted so vi.mock factory can reference it) +// ───────────────────────────────────────────────────────────────────────────── + +// vi.mock() factories are hoisted before ES module imports are resolved. +// vi.hoisted() allows us to define values that are available to both the +// factory AND the test body. We avoid importing EventEmitter here because +// ESM imports aren't available yet in the hoisted zone — instead we +// implement a minimal listener pattern inline. 
+const { MockWebSocket } = vi.hoisted(() => { + type AnyFn = (...args: unknown[]) => void; + + class MockWebSocket { + static CONNECTING = 0; + static OPEN = 1; + static CLOSING = 2; + static CLOSED = 3; + + readyState: number = MockWebSocket.CONNECTING; + url: string; + options: Record; + sentMessages: string[] = []; + + private _listeners: Map = new Map(); + + constructor(url: string, options?: Record) { + this.url = url; + this.options = options ?? {}; + MockWebSocket.lastInstance = this; + MockWebSocket.instances.push(this); + } + + // Minimal EventEmitter-compatible interface + on(event: string, fn: AnyFn): this { + const list = this._listeners.get(event) ?? []; + list.push(fn); + this._listeners.set(event, list); + return this; + } + + once(event: string, fn: AnyFn): this { + const wrapper = (...args: unknown[]) => { + this.off(event, wrapper); + fn(...args); + }; + return this.on(event, wrapper); + } + + off(event: string, fn: AnyFn): this { + const list = this._listeners.get(event) ?? []; + this._listeners.set( + event, + list.filter((l) => l !== fn), + ); + return this; + } + + removeAllListeners(event?: string): this { + if (event !== undefined) { + this._listeners.delete(event); + } else { + this._listeners.clear(); + } + return this; + } + + emit(event: string, ...args: unknown[]): boolean { + const list = this._listeners.get(event) ?? 
[]; + for (const fn of list) { + fn(...args); + } + return list.length > 0; + } + + // ws-compatible send + send(data: string): void { + this.sentMessages.push(data); + } + + // ws-compatible close — triggers async close event + close(code = 1000, reason = ""): void { + this.readyState = MockWebSocket.CLOSING; + setImmediate(() => { + this.readyState = MockWebSocket.CLOSED; + this.emit("close", code, Buffer.from(reason)); + }); + } + + // ── Test helpers ────────────────────────────────────────────────────── + + simulateOpen(): void { + this.readyState = MockWebSocket.OPEN; + this.emit("open"); + } + + simulateMessage(event: unknown): void { + this.emit("message", Buffer.from(JSON.stringify(event))); + } + + simulateError(err: Error): void { + this.readyState = MockWebSocket.CLOSED; + this.emit("error", err); + } + + simulateClose(code = 1006, reason = "Connection lost"): void { + this.readyState = MockWebSocket.CLOSED; + this.emit("close", code, Buffer.from(reason)); + } + + static lastInstance: MockWebSocket | null = null; + static instances: MockWebSocket[] = []; + + static reset(): void { + MockWebSocket.lastInstance = null; + MockWebSocket.instances = []; + } + } + + return { MockWebSocket }; +}); + +// ───────────────────────────────────────────────────────────────────────────── +// Module Mock +// ───────────────────────────────────────────────────────────────────────────── + +vi.mock("ws", () => { + // ws exports WebSocket as the default export; static constants (OPEN, etc.) + // live on the class itself. + return { default: MockWebSocket }; +}); + +// ───────────────────────────────────────────────────────────────────────────── +// Type alias for the mock class (improves test readability) +// ───────────────────────────────────────────────────────────────────────────── + +type MockWS = typeof MockWebSocket extends { new (...a: infer _): infer R } ? 
R : never; + +// ───────────────────────────────────────────────────────────────────────────── +// Helpers +// ───────────────────────────────────────────────────────────────────────────── + +function lastSocket(): MockWS { + const sock = MockWebSocket.lastInstance; + if (!sock) { + throw new Error("No MockWebSocket instance created"); + } + return sock; +} + +function buildManager(opts?: ConstructorParameters[0]) { + return new OpenAIWebSocketManager({ + // Use faster backoff in tests to avoid slow timer waits + backoffDelaysMs: [10, 20, 40, 80, 160], + ...opts, + }); +} + +// ───────────────────────────────────────────────────────────────────────────── +// Tests +// ───────────────────────────────────────────────────────────────────────────── + +describe("OpenAIWebSocketManager", () => { + beforeEach(() => { + MockWebSocket.reset(); + vi.useFakeTimers(); + }); + + afterEach(() => { + vi.useRealTimers(); + }); + + // ─── connect() ───────────────────────────────────────────────────────────── + + describe("connect()", () => { + it("opens a WebSocket with Bearer auth header", async () => { + const manager = buildManager(); + const connectPromise = manager.connect("sk-test-key"); + + const sock = lastSocket(); + expect(sock.url).toBe("wss://api.openai.com/v1/responses"); + expect(sock.options).toMatchObject({ + headers: expect.objectContaining({ + Authorization: "Bearer sk-test-key", + }), + }); + + sock.simulateOpen(); + await connectPromise; + }); + + it("resolves when the connection opens", async () => { + const manager = buildManager(); + const connectPromise = manager.connect("sk-test"); + lastSocket().simulateOpen(); + await expect(connectPromise).resolves.toBeUndefined(); + }); + + it("rejects when the initial connection fails (maxRetries=0)", async () => { + const manager = buildManager({ maxRetries: 0 }); + const connectPromise = manager.connect("sk-test"); + + lastSocket().simulateError(new Error("ECONNREFUSED")); + + await 
expect(connectPromise).rejects.toThrow("ECONNREFUSED"); + }); + + it("sets isConnected() to true after open", async () => { + const manager = buildManager(); + expect(manager.isConnected()).toBe(false); + + const connectPromise = manager.connect("sk-test"); + lastSocket().simulateOpen(); + await connectPromise; + + expect(manager.isConnected()).toBe(true); + }); + + it("uses the custom URL when provided", async () => { + const manager = buildManager({ url: "ws://localhost:9999/v1/responses" }); + const connectPromise = manager.connect("sk-test"); + + expect(lastSocket().url).toBe("ws://localhost:9999/v1/responses"); + lastSocket().simulateOpen(); + await connectPromise; + }); + }); + + // ─── send() ──────────────────────────────────────────────────────────────── + + describe("send()", () => { + it("sends a JSON-serialized event over the socket", async () => { + const manager = buildManager(); + const connectPromise = manager.connect("sk-test"); + const sock = lastSocket(); + sock.simulateOpen(); + await connectPromise; + + const event: ResponseCreateEvent = { + type: "response.create", + model: "gpt-5.2", + input: [{ type: "message", role: "user", content: "Hello" }], + }; + manager.send(event); + + expect(sock.sentMessages).toHaveLength(1); + expect(JSON.parse(sock.sentMessages[0] ?? 
"{}")).toEqual(event); + }); + + it("throws if the connection is not open", () => { + const manager = buildManager(); + const event: ClientEvent = { + type: "response.create", + model: "gpt-5.2", + }; + expect(() => manager.send(event)).toThrow(/cannot send/); + }); + + it("includes previous_response_id when provided", async () => { + const manager = buildManager(); + const connectPromise = manager.connect("sk-test"); + const sock = lastSocket(); + sock.simulateOpen(); + await connectPromise; + + const event: ResponseCreateEvent = { + type: "response.create", + model: "gpt-5.2", + previous_response_id: "resp_abc123", + input: [{ type: "function_call_output", call_id: "call_1", output: "result" }], + }; + manager.send(event); + + const sent = JSON.parse(sock.sentMessages[0] ?? "{}") as ResponseCreateEvent; + expect(sent.previous_response_id).toBe("resp_abc123"); + }); + }); + + // ─── onMessage() ─────────────────────────────────────────────────────────── + + describe("onMessage()", () => { + it("calls handler for each incoming message", async () => { + const manager = buildManager(); + const connectPromise = manager.connect("sk-test"); + const sock = lastSocket(); + sock.simulateOpen(); + await connectPromise; + + const received: OpenAIWebSocketEvent[] = []; + manager.onMessage((e) => received.push(e)); + + const deltaEvent: OpenAIWebSocketEvent = { + type: "response.output_text.delta", + item_id: "item_1", + output_index: 0, + content_index: 0, + delta: "Hello", + }; + sock.simulateMessage(deltaEvent); + + expect(received).toHaveLength(1); + expect(received[0]).toEqual(deltaEvent); + }); + + it("returns an unsubscribe function that stops delivery", async () => { + const manager = buildManager(); + const connectPromise = manager.connect("sk-test"); + const sock = lastSocket(); + sock.simulateOpen(); + await connectPromise; + + const received: OpenAIWebSocketEvent[] = []; + const unsubscribe = manager.onMessage((e) => received.push(e)); + + sock.simulateMessage({ 
type: "response.in_progress", response: makeResponse("r1") }); + unsubscribe(); + sock.simulateMessage({ type: "response.in_progress", response: makeResponse("r2") }); + + expect(received).toHaveLength(1); + }); + + it("supports multiple simultaneous handlers", async () => { + const manager = buildManager(); + const connectPromise = manager.connect("sk-test"); + const sock = lastSocket(); + sock.simulateOpen(); + await connectPromise; + + const calls: number[] = []; + manager.onMessage(() => calls.push(1)); + manager.onMessage(() => calls.push(2)); + + sock.simulateMessage({ type: "response.in_progress", response: makeResponse("r1") }); + + expect(calls.toSorted((a, b) => a - b)).toEqual([1, 2]); + }); + }); + + // ─── previousResponseId ──────────────────────────────────────────────────── + + describe("previousResponseId", () => { + it("starts as null", () => { + expect(new OpenAIWebSocketManager().previousResponseId).toBeNull(); + }); + + it("is updated when a response.completed event is received", async () => { + const manager = buildManager(); + const connectPromise = manager.connect("sk-test"); + const sock = lastSocket(); + sock.simulateOpen(); + await connectPromise; + + const completedEvent: ResponseCompletedEvent = { + type: "response.completed", + response: makeResponse("resp_done_42", "completed"), + }; + sock.simulateMessage(completedEvent); + + expect(manager.previousResponseId).toBe("resp_done_42"); + }); + + it("tracks the most recent completed response", async () => { + const manager = buildManager(); + const connectPromise = manager.connect("sk-test"); + const sock = lastSocket(); + sock.simulateOpen(); + await connectPromise; + + sock.simulateMessage({ + type: "response.completed", + response: makeResponse("resp_1", "completed"), + }); + sock.simulateMessage({ + type: "response.completed", + response: makeResponse("resp_2", "completed"), + }); + + expect(manager.previousResponseId).toBe("resp_2"); + }); + + it("is not updated for non-completed 
events", async () => { + const manager = buildManager(); + const connectPromise = manager.connect("sk-test"); + const sock = lastSocket(); + sock.simulateOpen(); + await connectPromise; + + sock.simulateMessage({ type: "response.in_progress", response: makeResponse("resp_x") }); + + expect(manager.previousResponseId).toBeNull(); + }); + }); + + // ─── isConnected() ───────────────────────────────────────────────────────── + + describe("isConnected()", () => { + it("returns false before connect", () => { + expect(buildManager().isConnected()).toBe(false); + }); + + it("returns true while open", async () => { + const manager = buildManager(); + const p = manager.connect("sk-test"); + lastSocket().simulateOpen(); + await p; + expect(manager.isConnected()).toBe(true); + }); + + it("returns false after close()", async () => { + const manager = buildManager(); + const p = manager.connect("sk-test"); + lastSocket().simulateOpen(); + await p; + manager.close(); + expect(manager.isConnected()).toBe(false); + }); + }); + + // ─── close() ─────────────────────────────────────────────────────────────── + + describe("close()", () => { + it("marks the manager as disconnected", async () => { + const manager = buildManager(); + const p = manager.connect("sk-test"); + lastSocket().simulateOpen(); + await p; + + manager.close(); + + expect(manager.isConnected()).toBe(false); + }); + + it("prevents reconnect after explicit close", async () => { + const manager = buildManager(); + const p = manager.connect("sk-test"); + const sock = lastSocket(); + sock.simulateOpen(); + await p; + + const socketCountBefore = MockWebSocket.instances.length; + manager.close(); + + // Simulate a network drop — should NOT trigger reconnect + sock.simulateClose(1006, "Network error"); + await vi.runAllTimersAsync(); + + expect(MockWebSocket.instances.length).toBe(socketCountBefore); + }); + + it("is safe to call before connect()", () => { + const manager = buildManager(); + expect(() => 
manager.close()).not.toThrow(); + }); + }); + + // ─── Auto-reconnect ──────────────────────────────────────────────────────── + + describe("auto-reconnect", () => { + it("reconnects on unexpected close", async () => { + const manager = buildManager({ backoffDelaysMs: [10, 20, 40, 80, 160] }); + const p = manager.connect("sk-test"); + lastSocket().simulateOpen(); + await p; + + const sock1 = lastSocket(); + const instancesBefore = MockWebSocket.instances.length; + + // Simulate a network drop + sock1.simulateClose(1006, "Network error"); + + // Advance time to trigger first retry (10ms delay) + await vi.advanceTimersByTimeAsync(15); + + // A new socket should have been created + expect(MockWebSocket.instances.length).toBeGreaterThan(instancesBefore); + expect(lastSocket()).not.toBe(sock1); + }); + + it("stops retrying after maxRetries", async () => { + const manager = buildManager({ maxRetries: 2, backoffDelaysMs: [5, 5] }); + const p = manager.connect("sk-test"); + lastSocket().simulateOpen(); + await p; + + const errors: Error[] = []; + manager.on("error", (e) => errors.push(e)); + + // Drop repeatedly — each reconnect attempt also drops immediately + for (let i = 0; i < 4; i++) { + lastSocket().simulateClose(1006, "drop"); + await vi.advanceTimersByTimeAsync(20); + } + + const maxRetryError = errors.find((e) => e.message.includes("max reconnect retries")); + expect(maxRetryError).toBeDefined(); + }); + + it("resets retry count after a successful reconnect", async () => { + const manager = buildManager({ maxRetries: 3, backoffDelaysMs: [5, 10, 20] }); + const p = manager.connect("sk-test"); + lastSocket().simulateOpen(); + await p; + + // Drop and let first retry succeed + lastSocket().simulateClose(1006, "drop"); + await vi.advanceTimersByTimeAsync(10); + lastSocket().simulateOpen(); // second socket opens successfully + + const socketCountAfterReconnect = MockWebSocket.instances.length; + + // Drop again — should still retry (retry count was reset) + 
lastSocket().simulateClose(1006, "drop again"); + await vi.advanceTimersByTimeAsync(10); + + expect(MockWebSocket.instances.length).toBeGreaterThan(socketCountAfterReconnect); + }); + }); + + // ─── warmUp() ────────────────────────────────────────────────────────────── + + describe("warmUp()", () => { + it("sends a response.create event with generate: false", async () => { + const manager = buildManager(); + const p = manager.connect("sk-test"); + const sock = lastSocket(); + sock.simulateOpen(); + await p; + + manager.warmUp({ model: "gpt-5.2", instructions: "You are helpful." }); + + expect(sock.sentMessages).toHaveLength(1); + const sent = JSON.parse(sock.sentMessages[0] ?? "{}") as Record; + expect(sent["type"]).toBe("response.create"); + expect(sent["generate"]).toBe(false); + expect(sent["model"]).toBe("gpt-5.2"); + expect(sent["instructions"]).toBe("You are helpful."); + }); + + it("includes tools when provided", async () => { + const manager = buildManager(); + const p = manager.connect("sk-test"); + const sock = lastSocket(); + sock.simulateOpen(); + await p; + + manager.warmUp({ + model: "gpt-5.2", + tools: [{ type: "function", function: { name: "exec", description: "Run a command" } }], + }); + + const sent = JSON.parse(sock.sentMessages[0] ?? 
"{}") as Record; + expect(sent["tools"]).toHaveLength(1); + expect((sent["tools"] as Array<{ function?: { name?: string } }>)[0]?.function?.name).toBe( + "exec", + ); + }); + }); + + // ─── Error handling ───────────────────────────────────────────────────────── + + describe("error handling", () => { + it("emits error event on malformed JSON message", async () => { + const manager = buildManager(); + const p = manager.connect("sk-test"); + const sock = lastSocket(); + sock.simulateOpen(); + await p; + + const errors: Error[] = []; + manager.on("error", (e) => errors.push(e)); + + sock.emit("message", Buffer.from("not valid json{{{{")); + + expect(errors).toHaveLength(1); + expect(errors[0]?.message).toContain("failed to parse message"); + }); + + it("emits error event when message has no type field", async () => { + const manager = buildManager(); + const p = manager.connect("sk-test"); + const sock = lastSocket(); + sock.simulateOpen(); + await p; + + const errors: Error[] = []; + manager.on("error", (e) => errors.push(e)); + + sock.emit("message", Buffer.from(JSON.stringify({ foo: "bar" }))); + + expect(errors).toHaveLength(1); + expect(errors[0]?.message).toContain('no "type" field'); + }); + + it("emits error event on WebSocket socket error", async () => { + const manager = buildManager({ maxRetries: 0 }); + const p = manager.connect("sk-test").catch(() => { + /* ignore rejection */ + }); + + const errors: Error[] = []; + manager.on("error", (e) => errors.push(e)); + + lastSocket().simulateError(new Error("SSL handshake failed")); + await p; + + expect(errors.some((e) => e.message === "SSL handshake failed")).toBe(true); + }); + + it("handles multiple successive socket errors without crashing", async () => { + const manager = buildManager({ maxRetries: 0 }); + const p = manager.connect("sk-test").catch(() => { + /* ignore rejection */ + }); + + const errors: Error[] = []; + manager.on("error", (e) => errors.push(e)); + + // Fire two errors in quick succession — 
previously the second would + // be unhandled because .once("error") removed the handler after #1. + lastSocket().simulateError(new Error("first error")); + lastSocket().simulateError(new Error("second error")); + await p; + + expect(errors.length).toBeGreaterThanOrEqual(2); + expect(errors.some((e) => e.message === "first error")).toBe(true); + expect(errors.some((e) => e.message === "second error")).toBe(true); + }); + }); + + // ─── Integration: full multi-turn sequence ──────────────────────────────── + + describe("full turn sequence", () => { + it("tracks previous_response_id across turns and sends continuation correctly", async () => { + const manager = buildManager(); + const p = manager.connect("sk-test"); + const sock = lastSocket(); + sock.simulateOpen(); + await p; + + const received: OpenAIWebSocketEvent[] = []; + manager.onMessage((e) => received.push(e)); + + // Send initial turn + manager.send({ type: "response.create", model: "gpt-5.2", input: "Hello" }); + + // Simulate streaming events from server + sock.simulateMessage({ type: "response.created", response: makeResponse("resp_1") }); + sock.simulateMessage({ + type: "response.output_text.delta", + item_id: "i1", + output_index: 0, + content_index: 0, + delta: "Hi!", + }); + sock.simulateMessage({ + type: "response.completed", + response: makeResponse("resp_1", "completed"), + }); + + expect(manager.previousResponseId).toBe("resp_1"); + expect(received).toHaveLength(3); + + // Send continuation turn using the tracked previous_response_id + manager.send({ + type: "response.create", + model: "gpt-5.2", + previous_response_id: manager.previousResponseId!, + input: [{ type: "function_call_output", call_id: "call_99", output: "tool result" }], + }); + + const lastSent = JSON.parse(sock.sentMessages[1] ?? 
"{}") as ResponseCreateEvent; + expect(lastSent.previous_response_id).toBe("resp_1"); + expect(lastSent.input).toEqual([ + { type: "function_call_output", call_id: "call_99", output: "tool result" }, + ]); + }); + }); +}); + +// ───────────────────────────────────────────────────────────────────────────── +// Test Fixtures +// ───────────────────────────────────────────────────────────────────────────── + +function makeResponse( + id: string, + status: ResponseCompletedEvent["response"]["status"] = "in_progress", +): ResponseCompletedEvent["response"] { + return { + id, + object: "response", + created_at: Date.now(), + status, + model: "gpt-5.2", + output: [], + usage: { input_tokens: 10, output_tokens: 5, total_tokens: 15 }, + }; +} diff --git a/src/agents/openai-ws-stream.e2e.test.ts b/src/agents/openai-ws-stream.e2e.test.ts new file mode 100644 index 000000000..2b90d0dbc --- /dev/null +++ b/src/agents/openai-ws-stream.e2e.test.ts @@ -0,0 +1,151 @@ +/** + * End-to-end integration tests for OpenAI WebSocket streaming. + * + * These tests hit the real OpenAI Responses API over WebSocket and verify + * the full request/response lifecycle including: + * - Connection establishment and session reuse + * - Context options forwarding (temperature) + * - Graceful fallback to HTTP on connection failure + * - Connection lifecycle cleanup via releaseWsSession + * + * Run manually with a valid OPENAI_API_KEY: + * OPENAI_API_KEY=sk-... npx vitest run src/agents/openai-ws-stream.e2e.test.ts + * + * Skipped in CI — no API key available and we avoid billable external calls. + */ + +import { describe, it, expect, afterEach } from "vitest"; +import { + createOpenAIWebSocketStreamFn, + releaseWsSession, + hasWsSession, +} from "./openai-ws-stream.js"; + +const API_KEY = process.env.OPENAI_API_KEY; +const LIVE = !!API_KEY; +const testFn = LIVE ? 
it : it.skip; + +const model = { + api: "openai-responses" as const, + provider: "openai", + id: "gpt-4o-mini", + name: "gpt-4o-mini", + baseUrl: "", + reasoning: false, + input: { maxTokens: 128_000 }, + output: { maxTokens: 16_384 }, + cache: false, + compat: {}, +} as unknown as Parameters>[0]; + +type StreamFnParams = Parameters>; +function makeContext(userMessage: string): StreamFnParams[1] { + return { + systemPrompt: "You are a helpful assistant. Reply in one sentence.", + messages: [{ role: "user" as const, content: userMessage }], + tools: [], + } as unknown as StreamFnParams[1]; +} + +/** Each test gets a unique session ID to avoid cross-test interference. */ +const sessions: string[] = []; +function freshSession(name: string): string { + const id = `e2e-${name}-${Date.now()}`; + sessions.push(id); + return id; +} + +describe("OpenAI WebSocket e2e", () => { + afterEach(() => { + for (const id of sessions) { + releaseWsSession(id); + } + sessions.length = 0; + }); + + testFn( + "completes a single-turn request over WebSocket", + async () => { + const sid = freshSession("single"); + const streamFn = createOpenAIWebSocketStreamFn(API_KEY!, sid); + const stream = streamFn(model, makeContext("What is 2+2?"), {}); + + const events: Array<{ type: string }> = []; + for await (const event of stream as AsyncIterable<{ type: string }>) { + events.push(event); + } + + const done = events.find((e) => e.type === "done") as + | { type: "done"; message: { content: Array<{ type: string; text?: string }> } } + | undefined; + expect(done).toBeDefined(); + expect(done!.message.content.length).toBeGreaterThan(0); + + const text = done!.message.content + .filter((c) => c.type === "text") + .map((c) => c.text) + .join(""); + expect(text).toMatch(/4/); + }, + 30_000, + ); + + testFn( + "forwards temperature option to the API", + async () => { + const sid = freshSession("temp"); + const streamFn = createOpenAIWebSocketStreamFn(API_KEY!, sid); + const stream = streamFn(model, 
makeContext("Pick a random number between 1 and 1000."), { + temperature: 0.8, + }); + + const events: Array<{ type: string }> = []; + for await (const event of stream as AsyncIterable<{ type: string }>) { + events.push(event); + } + + // Stream must complete (done or error with fallback) — must NOT hang. + const hasTerminal = events.some((e) => e.type === "done" || e.type === "error"); + expect(hasTerminal).toBe(true); + }, + 30_000, + ); + + testFn( + "session is tracked in registry during request", + async () => { + const sid = freshSession("registry"); + const streamFn = createOpenAIWebSocketStreamFn(API_KEY!, sid); + + expect(hasWsSession(sid)).toBe(false); + + const stream = streamFn(model, makeContext("Say hello."), {}); + for await (const _ of stream as AsyncIterable) { + /* consume */ + } + + expect(hasWsSession(sid)).toBe(true); + releaseWsSession(sid); + expect(hasWsSession(sid)).toBe(false); + }, + 30_000, + ); + + testFn( + "falls back to HTTP gracefully with invalid API key", + async () => { + const sid = freshSession("fallback"); + const streamFn = createOpenAIWebSocketStreamFn("sk-invalid-key", sid); + const stream = streamFn(model, makeContext("Hello"), {}); + + const events: Array<{ type: string }> = []; + for await (const event of stream as AsyncIterable<{ type: string }>) { + events.push(event); + } + + const hasTerminal = events.some((e) => e.type === "done" || e.type === "error"); + expect(hasTerminal).toBe(true); + }, + 30_000, + ); +}); diff --git a/src/agents/openai-ws-stream.test.ts b/src/agents/openai-ws-stream.test.ts new file mode 100644 index 000000000..0b2911ce8 --- /dev/null +++ b/src/agents/openai-ws-stream.test.ts @@ -0,0 +1,1062 @@ +/** + * Unit tests for openai-ws-stream.ts + * + * Covers: + * - Message format converters (convertMessagesToInputItems, convertTools) + * - Response → AssistantMessage parser (buildAssistantMessageFromResponse) + * - createOpenAIWebSocketStreamFn behaviour (connect, send, receive, fallback) + * - 
Session registry helpers (releaseWsSession, hasWsSession) + */ + +import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; +import type { ResponseObject } from "./openai-ws-connection.js"; +import { + buildAssistantMessageFromResponse, + convertMessagesToInputItems, + convertTools, + createOpenAIWebSocketStreamFn, + hasWsSession, + releaseWsSession, +} from "./openai-ws-stream.js"; + +// ───────────────────────────────────────────────────────────────────────────── +// Mock OpenAIWebSocketManager +// ───────────────────────────────────────────────────────────────────────────── + +// We mock the entire openai-ws-connection module so no real WebSocket is opened. +const { MockManager } = vi.hoisted(() => { + // eslint-disable-next-line @typescript-eslint/no-require-imports + const { EventEmitter } = require("node:events") as typeof import("node:events"); + type AnyFn = (...args: unknown[]) => void; + + // Shared mutable flag so inner class can see it + let _globalConnectShouldFail = false; + + class MockManager extends EventEmitter { + private _listeners: AnyFn[] = []; + private _previousResponseId: string | null = null; + private _connected = false; + private _broken = false; + + sentEvents: unknown[] = []; + connectCallCount = 0; + closeCallCount = 0; + + // Allow tests to override connect/send behaviour + connectShouldFail = false; + sendShouldFail = false; + + get previousResponseId(): string | null { + return this._previousResponseId; + } + + async connect(_apiKey: string): Promise { + this.connectCallCount++; + if (this.connectShouldFail || _globalConnectShouldFail) { + throw new Error("Mock connect failure"); + } + this._connected = true; + } + + isConnected(): boolean { + return this._connected && !this._broken; + } + + send(event: unknown): void { + if (!this._connected) { + throw new Error("cannot send — not connected"); + } + if (this.sendShouldFail) { + throw new Error("Mock send failure"); + } + this.sentEvents.push(event); + } + + 
onMessage(handler: (event: unknown) => void): () => void { + this._listeners.push(handler as AnyFn); + return () => { + this._listeners = this._listeners.filter((l) => l !== handler); + }; + } + + close(): void { + this.closeCallCount++; + this._connected = false; + } + + // Test helper: simulate WebSocket connection drop mid-request + simulateClose(code = 1006, reason = "connection lost"): void { + this._connected = false; + this.emit("close", code, reason); + } + + // Test helper: simulate a server event + simulateEvent(event: unknown): void { + for (const fn of this._listeners) { + fn(event); + } + } + + // Test helper: simulate connection being broken + simulateBroken(): void { + this._connected = false; + this._broken = true; + } + + // Test helper: set the previous response ID as if a turn completed + setPreviousResponseId(id: string): void { + this._previousResponseId = id; + } + + static lastInstance: MockManager | null = null; + static instances: MockManager[] = []; + + static reset(): void { + MockManager.lastInstance = null; + MockManager.instances = []; + } + } + + // Patch constructor to track instances + const OriginalMockManager = MockManager; + class TrackedMockManager extends OriginalMockManager { + constructor(...args: ConstructorParameters) { + super(...args); + TrackedMockManager.lastInstance = this; + TrackedMockManager.instances.push(this); + } + + static lastInstance: TrackedMockManager | null = null; + static instances: TrackedMockManager[] = []; + + /** Class-level flag: make ALL new instances fail on connect(). 
*/ + static get globalConnectShouldFail(): boolean { + return _globalConnectShouldFail; + } + static set globalConnectShouldFail(v: boolean) { + _globalConnectShouldFail = v; + } + + static reset(): void { + TrackedMockManager.lastInstance = null; + TrackedMockManager.instances = []; + _globalConnectShouldFail = false; + } + } + + return { MockManager: TrackedMockManager }; +}); + +vi.mock("./openai-ws-connection.js", async (importOriginal) => { + const original = await importOriginal(); + return { + ...original, + OpenAIWebSocketManager: MockManager, + }; +}); + +// ───────────────────────────────────────────────────────────────────────────── +// Mock pi-ai +// ───────────────────────────────────────────────────────────────────────────── + +// Track if streamSimple (HTTP fallback) was called +const streamSimpleCalls: Array<{ model: unknown; context: unknown }> = []; + +vi.mock("@mariozechner/pi-ai", async (importOriginal) => { + const original = await importOriginal(); + + const mockStreamSimple = vi.fn((model: unknown, context: unknown) => { + streamSimpleCalls.push({ model, context }); + // Return a minimal AssistantMessageEventStream-like async iterable + const stream = original.createAssistantMessageEventStream(); + queueMicrotask(() => { + const msg = makeFakeAssistantMessage("http fallback response"); + stream.push({ type: "done", reason: "stop", message: msg }); + stream.end(); + }); + return stream; + }); + + return { + ...original, + streamSimple: mockStreamSimple, + }; +}); + +// ───────────────────────────────────────────────────────────────────────────── +// Helpers +// ───────────────────────────────────────────────────────────────────────────── + +/** Resolve a StreamFn return value (which may be a Promise) to an AsyncIterable. */ +async function resolveStream( + stream: ReturnType>, +): Promise> { + return stream instanceof Promise ? 
await stream : stream; +} + +// ───────────────────────────────────────────────────────────────────────────── +// Fixtures +// ───────────────────────────────────────────────────────────────────────────── + +type FakeMessage = + | { role: "user"; content: string | unknown[]; timestamp: number } + | { + role: "assistant"; + content: unknown[]; + stopReason: string; + api: string; + provider: string; + model: string; + usage: unknown; + timestamp: number; + } + | { + role: "toolResult"; + toolCallId: string; + toolName: string; + content: unknown[]; + isError: boolean; + timestamp: number; + }; + +function userMsg(text: string): FakeMessage { + return { role: "user", content: text, timestamp: 0 }; +} + +function assistantMsg( + textBlocks: string[], + toolCalls: Array<{ id: string; name: string; args: Record }> = [], +): FakeMessage { + const content: unknown[] = []; + for (const t of textBlocks) { + content.push({ type: "text", text: t }); + } + for (const tc of toolCalls) { + content.push({ type: "toolCall", id: tc.id, name: tc.name, arguments: tc.args }); + } + return { + role: "assistant", + content, + stopReason: toolCalls.length > 0 ? 
"toolUse" : "stop", + api: "openai-responses", + provider: "openai", + model: "gpt-5.2", + usage: {}, + timestamp: 0, + }; +} + +function toolResultMsg(callId: string, output: string): FakeMessage { + return { + role: "toolResult", + toolCallId: callId, + toolName: "test_tool", + content: [{ type: "text", text: output }], + isError: false, + timestamp: 0, + }; +} + +function makeFakeAssistantMessage(text: string) { + return { + role: "assistant" as const, + content: [{ type: "text" as const, text }], + stopReason: "stop" as const, + api: "openai-responses", + provider: "openai", + model: "gpt-5.2", + usage: { + input: 10, + output: 5, + cacheRead: 0, + cacheWrite: 0, + totalTokens: 15, + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 }, + }, + timestamp: Date.now(), + }; +} + +function makeResponseObject( + id: string, + outputText?: string, + toolCallName?: string, +): ResponseObject { + const output: ResponseObject["output"] = []; + if (outputText) { + output.push({ + type: "message", + id: "item_1", + role: "assistant", + content: [{ type: "output_text", text: outputText }], + }); + } + if (toolCallName) { + output.push({ + type: "function_call", + id: "item_2", + call_id: "call_abc", + name: toolCallName, + arguments: '{"arg":"value"}', + }); + } + return { + id, + object: "response", + created_at: Date.now(), + status: "completed", + model: "gpt-5.2", + output, + usage: { input_tokens: 100, output_tokens: 50, total_tokens: 150 }, + }; +} + +// ───────────────────────────────────────────────────────────────────────────── +// Test suite +// ───────────────────────────────────────────────────────────────────────────── + +describe("convertTools", () => { + it("returns empty array for undefined tools", () => { + expect(convertTools(undefined)).toEqual([]); + }); + + it("returns empty array for empty tools", () => { + expect(convertTools([])).toEqual([]); + }); + + it("converts tools to FunctionToolDefinition format", () => { + const tools = [ + 
{ + name: "exec", + description: "Run a command", + parameters: { type: "object", properties: { cmd: { type: "string" } } }, + }, + ]; + const result = convertTools(tools as unknown as Parameters[0]); + expect(result).toHaveLength(1); + expect(result[0]).toMatchObject({ + type: "function", + function: { + name: "exec", + description: "Run a command", + parameters: { type: "object", properties: { cmd: { type: "string" } } }, + }, + }); + }); + + it("handles tools without description", () => { + const tools = [{ name: "ping", description: "", parameters: {} }]; + const result = convertTools(tools as Parameters[0]); + expect(result[0]?.function?.name).toBe("ping"); + }); +}); + +// ───────────────────────────────────────────────────────────────────────────── + +describe("convertMessagesToInputItems", () => { + it("converts a simple user text message", () => { + const items = convertMessagesToInputItems([userMsg("Hello!")] as Parameters< + typeof convertMessagesToInputItems + >[0]); + expect(items).toHaveLength(1); + expect(items[0]).toMatchObject({ type: "message", role: "user", content: "Hello!" }); + }); + + it("converts an assistant text-only message", () => { + const items = convertMessagesToInputItems([assistantMsg(["Hi there."])] as Parameters< + typeof convertMessagesToInputItems + >[0]); + expect(items).toHaveLength(1); + expect(items[0]).toMatchObject({ type: "message", role: "assistant", content: "Hi there." 
}); + }); + + it("converts an assistant message with a tool call", () => { + const msg = assistantMsg( + ["Let me run that."], + [{ id: "call_1", name: "exec", args: { cmd: "ls" } }], + ); + const items = convertMessagesToInputItems([msg] as Parameters< + typeof convertMessagesToInputItems + >[0]); + // Should produce a text message and a function_call item + const textItem = items.find((i) => i.type === "message"); + const fcItem = items.find((i) => i.type === "function_call"); + expect(textItem).toBeDefined(); + expect(fcItem).toMatchObject({ + type: "function_call", + call_id: "call_1", + name: "exec", + }); + const fc = fcItem as { arguments: string }; + expect(JSON.parse(fc.arguments)).toEqual({ cmd: "ls" }); + }); + + it("converts a tool result message", () => { + const items = convertMessagesToInputItems([toolResultMsg("call_1", "file.txt")] as Parameters< + typeof convertMessagesToInputItems + >[0]); + expect(items).toHaveLength(1); + expect(items[0]).toMatchObject({ + type: "function_call_output", + call_id: "call_1", + output: "file.txt", + }); + }); + + it("converts a full multi-turn conversation", () => { + const messages: FakeMessage[] = [ + userMsg("Run ls"), + assistantMsg([], [{ id: "call_1", name: "exec", args: { cmd: "ls" } }]), + toolResultMsg("call_1", "file.txt\nfoo.ts"), + ]; + const items = convertMessagesToInputItems( + messages as Parameters[0], + ); + + const userItem = items.find( + (i) => i.type === "message" && (i as { role?: string }).role === "user", + ); + const fcItem = items.find((i) => i.type === "function_call"); + const outputItem = items.find((i) => i.type === "function_call_output"); + + expect(userItem).toBeDefined(); + expect(fcItem).toBeDefined(); + expect(outputItem).toBeDefined(); + }); + + it("handles assistant messages with only tool calls (no text)", () => { + const msg = assistantMsg([], [{ id: "call_2", name: "read", args: { path: "/etc/hosts" } }]); + const items = convertMessagesToInputItems([msg] as Parameters< + 
typeof convertMessagesToInputItems + >[0]); + expect(items).toHaveLength(1); + expect(items[0]?.type).toBe("function_call"); + }); + + it("skips thinking blocks in assistant messages", () => { + const msg = { + role: "assistant" as const, + content: [ + { type: "thinking", thinking: "internal reasoning..." }, + { type: "text", text: "Here is my answer." }, + ], + stopReason: "stop", + api: "openai-responses", + provider: "openai", + model: "gpt-5.2", + usage: {}, + timestamp: 0, + }; + const items = convertMessagesToInputItems([msg] as Parameters< + typeof convertMessagesToInputItems + >[0]); + expect(items).toHaveLength(1); + expect((items[0] as { content?: unknown }).content).toBe("Here is my answer."); + }); + + it("returns empty array for empty messages", () => { + expect(convertMessagesToInputItems([])).toEqual([]); + }); +}); + +// ───────────────────────────────────────────────────────────────────────────── + +describe("buildAssistantMessageFromResponse", () => { + const modelInfo = { api: "openai-responses", provider: "openai", id: "gpt-5.2" }; + + it("extracts text content from a message output item", () => { + const response = makeResponseObject("resp_1", "Hello from assistant"); + const msg = buildAssistantMessageFromResponse(response, modelInfo); + expect(msg.content).toHaveLength(1); + const textBlock = msg.content[0] as { type: string; text: string }; + expect(textBlock.type).toBe("text"); + expect(textBlock.text).toBe("Hello from assistant"); + }); + + it("sets stopReason to 'stop' for text-only responses", () => { + const response = makeResponseObject("resp_1", "Just text"); + const msg = buildAssistantMessageFromResponse(response, modelInfo); + expect(msg.stopReason).toBe("stop"); + }); + + it("extracts tool call from function_call output item", () => { + const response = makeResponseObject("resp_2", undefined, "exec"); + const msg = buildAssistantMessageFromResponse(response, modelInfo); + const tc = msg.content.find((c) => c.type === "toolCall") 
as { + type: string; + id: string; + name: string; + arguments: Record; + }; + expect(tc).toBeDefined(); + expect(tc.name).toBe("exec"); + expect(tc.id).toBe("call_abc"); + expect(tc.arguments).toEqual({ arg: "value" }); + }); + + it("sets stopReason to 'toolUse' when tool calls are present", () => { + const response = makeResponseObject("resp_3", undefined, "exec"); + const msg = buildAssistantMessageFromResponse(response, modelInfo); + expect(msg.stopReason).toBe("toolUse"); + }); + + it("includes both text and tool calls when both present", () => { + const response = makeResponseObject("resp_4", "Running...", "exec"); + const msg = buildAssistantMessageFromResponse(response, modelInfo); + expect(msg.content.some((c) => c.type === "text")).toBe(true); + expect(msg.content.some((c) => c.type === "toolCall")).toBe(true); + expect(msg.stopReason).toBe("toolUse"); + }); + + it("maps usage tokens correctly", () => { + const response = makeResponseObject("resp_5", "Hello"); + const msg = buildAssistantMessageFromResponse(response, modelInfo); + expect(msg.usage.input).toBe(100); + expect(msg.usage.output).toBe(50); + expect(msg.usage.totalTokens).toBe(150); + }); + + it("sets model/provider/api from modelInfo", () => { + const response = makeResponseObject("resp_6", "Hi"); + const msg = buildAssistantMessageFromResponse(response, modelInfo); + expect(msg.api).toBe("openai-responses"); + expect(msg.provider).toBe("openai"); + expect(msg.model).toBe("gpt-5.2"); + }); + + it("handles empty output gracefully", () => { + const response = makeResponseObject("resp_7"); + const msg = buildAssistantMessageFromResponse(response, modelInfo); + expect(msg.content).toEqual([]); + expect(msg.stopReason).toBe("stop"); + }); +}); + +// ───────────────────────────────────────────────────────────────────────────── + +describe("createOpenAIWebSocketStreamFn", () => { + const modelStub = { + api: "openai-responses", + provider: "openai", + id: "gpt-5.2", + contextWindow: 128000, + 
maxTokens: 4096, + reasoning: false, + input: ["text"], + cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 }, + name: "GPT-5.2", + }; + + const contextStub = { + systemPrompt: "You are helpful.", + messages: [userMsg("Hello!") as Parameters[0][number]], + tools: [], + }; + + beforeEach(() => { + MockManager.reset(); + streamSimpleCalls.length = 0; + }); + + afterEach(() => { + // Clean up any sessions created in tests to avoid cross-test pollution + MockManager.instances.forEach((_, i) => { + // Session IDs used in tests follow a predictable pattern + releaseWsSession(`test-session-${i}`); + }); + releaseWsSession("sess-1"); + releaseWsSession("sess-2"); + releaseWsSession("sess-fallback"); + releaseWsSession("sess-incremental"); + releaseWsSession("sess-full"); + releaseWsSession("sess-tools"); + }); + + it("connects to the WebSocket on first call", async () => { + const streamFn = createOpenAIWebSocketStreamFn("sk-test", "sess-1"); + const stream = streamFn( + modelStub as Parameters[0], + contextStub as Parameters[1], + ); + + // Give the microtask queue time to run + await new Promise((r) => setImmediate(r)); + + const manager = MockManager.lastInstance; + expect(manager?.connectCallCount).toBe(1); + // Consume stream to avoid dangling promise + void resolveStream(stream); + }); + + it("sends a response.create event on first turn (full context)", async () => { + const streamFn = createOpenAIWebSocketStreamFn("sk-test", "sess-full"); + const stream = streamFn( + modelStub as Parameters[0], + contextStub as Parameters[1], + ); + + const completed = new Promise((res, rej) => { + queueMicrotask(async () => { + try { + await new Promise((r) => setImmediate(r)); + const manager = MockManager.lastInstance!; + + // Simulate the server completing the response + manager.simulateEvent({ + type: "response.completed", + response: makeResponseObject("resp_1", "Hello!"), + }); + + for await (const _ of await resolveStream(stream)) { + // consume events + } + res(); + 
} catch (e) { + rej(e); + } + }); + }); + + await completed; + + const manager = MockManager.lastInstance!; + expect(manager.sentEvents).toHaveLength(1); + const sent = manager.sentEvents[0] as { type: string; model: string; input: unknown[] }; + expect(sent.type).toBe("response.create"); + expect(sent.model).toBe("gpt-5.2"); + expect(Array.isArray(sent.input)).toBe(true); + }); + + it("emits an AssistantMessage on response.completed", async () => { + const streamFn = createOpenAIWebSocketStreamFn("sk-test", "sess-2"); + const stream = streamFn( + modelStub as Parameters[0], + contextStub as Parameters[1], + ); + + const events: unknown[] = []; + const done = (async () => { + for await (const ev of await resolveStream(stream)) { + events.push(ev); + } + })(); + + await new Promise((r) => setImmediate(r)); + const manager = MockManager.lastInstance!; + manager.simulateEvent({ + type: "response.completed", + response: makeResponseObject("resp_hello", "Hello back!"), + }); + + await done; + + const doneEvent = events.find((e) => (e as { type?: string }).type === "done") as + | { + type: string; + reason: string; + message: { content: Array<{ text: string }> }; + } + | undefined; + expect(doneEvent).toBeDefined(); + expect(doneEvent?.message.content[0]?.text).toBe("Hello back!"); + }); + + it("falls back to HTTP when WebSocket connect fails (session pre-broken via flag)", async () => { + // Set the class-level flag BEFORE calling streamFn so the new instance + // fails on connect(). We patch the static default via MockManager directly. + MockManager.globalConnectShouldFail = true; + + try { + const streamFn = createOpenAIWebSocketStreamFn("sk-test", "sess-fallback"); + const stream = streamFn( + modelStub as Parameters[0], + contextStub as Parameters[1], + ); + + // Consume — should fall back to HTTP (streamSimple mock). 
+ const messages: unknown[] = []; + for await (const ev of await resolveStream(stream)) { + messages.push(ev); + } + + // streamSimple was called as part of HTTP fallback + expect(streamSimpleCalls.length).toBeGreaterThanOrEqual(1); + + // manager.close() must be called to cancel background reconnect attempts + expect(MockManager.lastInstance!.closeCallCount).toBeGreaterThanOrEqual(1); + } finally { + MockManager.globalConnectShouldFail = false; + } + }); + + it("tracks previous_response_id across turns (incremental send)", async () => { + const sessionId = "sess-incremental"; + const streamFn = createOpenAIWebSocketStreamFn("sk-test", sessionId); + + // ── Turn 1: full context ───────────────────────────────────────────── + const ctx1 = { + systemPrompt: "You are helpful.", + messages: [userMsg("Run ls")] as Parameters[0], + tools: [], + }; + + const stream1 = streamFn( + modelStub as Parameters[0], + ctx1 as Parameters[1], + ); + + const events1: unknown[] = []; + const done1 = (async () => { + for await (const ev of await resolveStream(stream1)) { + events1.push(ev); + } + })(); + + await new Promise((r) => setImmediate(r)); + const manager = MockManager.lastInstance!; + + // Server responds with a tool call + const turn1Response = makeResponseObject("resp_turn1", undefined, "exec"); + manager.setPreviousResponseId("resp_turn1"); + manager.simulateEvent({ type: "response.completed", response: turn1Response }); + await done1; + + // ── Turn 2: incremental (tool results only) ─────────────────────────── + const ctx2 = { + systemPrompt: "You are helpful.", + messages: [ + userMsg("Run ls"), + assistantMsg([], [{ id: "call_1", name: "exec", args: { cmd: "ls" } }]), + toolResultMsg("call_1", "file.txt"), + ] as Parameters[0], + tools: [], + }; + + const stream2 = streamFn( + modelStub as Parameters[0], + ctx2 as Parameters[1], + ); + + const events2: unknown[] = []; + const done2 = (async () => { + for await (const ev of await resolveStream(stream2)) { + 
events2.push(ev); + } + })(); + + await new Promise((r) => setImmediate(r)); + manager.simulateEvent({ + type: "response.completed", + response: makeResponseObject("resp_turn2", "Here are the files."), + }); + await done2; + + // Turn 2 should have sent previous_response_id and only tool results + expect(manager.sentEvents).toHaveLength(2); + const sent2 = manager.sentEvents[1] as { + previous_response_id?: string; + input: Array<{ type: string }>; + }; + expect(sent2.previous_response_id).toBe("resp_turn1"); + // Input should only contain tool results, not the full history + const inputTypes = (sent2.input ?? []).map((i) => i.type); + expect(inputTypes.every((t) => t === "function_call_output")).toBe(true); + expect(inputTypes).toHaveLength(1); + }); + + it("sends instructions (system prompt) in each request", async () => { + const streamFn = createOpenAIWebSocketStreamFn("sk-test", "sess-tools"); + const ctx = { + systemPrompt: "Be concise.", + messages: [userMsg("Hello")] as Parameters[0], + tools: [{ name: "exec", description: "run", parameters: {} }], + }; + + const stream = streamFn( + modelStub as Parameters[0], + ctx as Parameters[1], + ); + + await new Promise((r) => setImmediate(r)); + const manager = MockManager.lastInstance!; + manager.simulateEvent({ + type: "response.completed", + response: makeResponseObject("resp_x", "ok"), + }); + + for await (const _ of await resolveStream(stream)) { + // consume + } + + const sent = manager.sentEvents[0] as { + instructions?: string; + tools?: unknown[]; + }; + expect(sent.instructions).toBe("Be concise."); + expect(Array.isArray(sent.tools)).toBe(true); + expect((sent.tools ?? []).length).toBeGreaterThan(0); + }); + + it("resets session state and falls back to HTTP when send() throws", async () => { + const sessionId = "sess-send-fail-reset"; + const streamFn = createOpenAIWebSocketStreamFn("sk-test", sessionId); + + // 1. 
    // Run a successful first turn to populate the registry.
    const stream1 = streamFn(
      modelStub as Parameters[0],
      contextStub as Parameters[1],
    );
    await new Promise((resolve, reject) => {
      queueMicrotask(async () => {
        try {
          await new Promise((r) => setImmediate(r));
          MockManager.lastInstance!.simulateEvent({
            type: "response.completed",
            response: makeResponseObject("resp-ok", "OK"),
          });
          for await (const _ of await resolveStream(stream1)) {
            /* consume */
          }
          resolve();
        } catch (e) {
          reject(e);
        }
      });
    });
    expect(hasWsSession(sessionId)).toBe(true);

    // 2. Arm send failure and record pre-call streamSimpleCalls count
    MockManager.lastInstance!.sendShouldFail = true;
    const callsBefore = streamSimpleCalls.length;

    // 3. Second call: send throws → must fall back to HTTP and clear registry
    const stream2 = streamFn(
      modelStub as Parameters[0],
      contextStub as Parameters[1],
    );
    for await (const _ of await resolveStream(stream2)) {
      /* consume */
    }

    // Registry cleared after send failure
    expect(hasWsSession(sessionId)).toBe(false);
    // HTTP fallback invoked
    expect(streamSimpleCalls.length).toBeGreaterThan(callsBefore);
  });

  it("forwards temperature and maxTokens to response.create", async () => {
    const streamFn = createOpenAIWebSocketStreamFn("sk-test", "sess-temp");
    const opts = { temperature: 0.3, maxTokens: 256 };
    const stream = streamFn(
      modelStub as Parameters[0],
      contextStub as Parameters[1],
      opts as Parameters[2],
    );
    // Complete the turn via the mock server, then drain the stream.
    await new Promise((resolve, reject) => {
      queueMicrotask(async () => {
        try {
          await new Promise((r) => setImmediate(r));
          MockManager.lastInstance!.simulateEvent({
            type: "response.completed",
            response: makeResponseObject("resp-temp", "Done"),
          });
          for await (const _ of await resolveStream(stream)) {
            /* consume */
          }
          resolve();
        } catch (e) {
          reject(e);
        }
      });
    });
    const sent = MockManager.lastInstance!.sentEvents[0] as Record;
    expect(sent.type).toBe("response.create");
    expect(sent.temperature).toBe(0.3);
    // maxTokens maps to the wire field max_output_tokens.
    expect(sent.max_output_tokens).toBe(256);
  });

  it("forwards reasoningEffort/reasoningSummary to response.create reasoning block", async () => {
    const streamFn = createOpenAIWebSocketStreamFn("sk-test", "sess-reason");
    const opts = { reasoningEffort: "high", reasoningSummary: "auto" };
    const stream = streamFn(
      modelStub as Parameters[0],
      contextStub as Parameters[1],
      opts as unknown as Parameters[2],
    );
    await new Promise((resolve, reject) => {
      queueMicrotask(async () => {
        try {
          await new Promise((r) => setImmediate(r));
          MockManager.lastInstance!.simulateEvent({
            type: "response.completed",
            response: makeResponseObject("resp-reason", "Deep thought"),
          });
          for await (const _ of await resolveStream(stream)) {
            /* consume */
          }
          resolve();
        } catch (e) {
          reject(e);
        }
      });
    });
    const sent = MockManager.lastInstance!.sentEvents[0] as Record;
    expect(sent.type).toBe("response.create");
    // Both options collapse into the nested reasoning block on the wire.
    expect(sent.reasoning).toEqual({ effort: "high", summary: "auto" });
  });

  it("forwards topP and toolChoice to response.create", async () => {
    const streamFn = createOpenAIWebSocketStreamFn("sk-test", "sess-topp");
    const opts = { topP: 0.9, toolChoice: "auto" };
    const stream = streamFn(
      modelStub as Parameters[0],
      contextStub as Parameters[1],
      opts as unknown as Parameters[2],
    );
    await new Promise((resolve, reject) => {
      queueMicrotask(async () => {
        try {
          await new Promise((r) => setImmediate(r));
          MockManager.lastInstance!.simulateEvent({
            type: "response.completed",
            response: makeResponseObject("resp-topp", "Done"),
          });
          for await (const _ of await resolveStream(stream)) {
            /* consume */
          }
          resolve();
        } catch (e) {
          reject(e);
        }
      });
    });
    const sent = MockManager.lastInstance!.sentEvents[0] as Record;
    expect(sent.type).toBe("response.create");
    // camelCase options map to snake_case wire fields.
    expect(sent.top_p).toBe(0.9);
    expect(sent.tool_choice).toBe("auto");
  });

  it("rejects promise when WebSocket drops mid-request", async () => {
    const streamFn = createOpenAIWebSocketStreamFn("sk-test", "sess-drop");
    const stream = streamFn(
      modelStub as Parameters[0],
      contextStub as Parameters[1],
      {} as Parameters[2],
    );
    // Let the send go through, then simulate connection drop before response.completed
    await new Promise((resolve) => {
      queueMicrotask(async () => {
        try {
          await new Promise((r) => setImmediate(r));
          // Simulate a connection drop instead of sending response.completed
          MockManager.lastInstance!.simulateClose(1006, "connection lost");
          const events: unknown[] = [];
          for await (const ev of await resolveStream(stream)) {
            events.push(ev);
          }
          // Should have gotten an error event, not hung forever
          const hasError = events.some(
            (e) => typeof e === "object" && e !== null && (e as { type: string }).type === "error",
          );
          expect(hasError).toBe(true);
          resolve();
        } catch {
          // The error propagation is also acceptable — promise rejected
          resolve();
        }
      });
    });
  });
});

// ─────────────────────────────────────────────────────────────────────────────

describe("releaseWsSession / hasWsSession", () => {
  beforeEach(() => {
    MockManager.reset();
  });

  afterEach(() => {
    releaseWsSession("registry-test");
  });

  it("hasWsSession returns false for unknown session", () => {
    expect(hasWsSession("nonexistent-session")).toBe(false);
  });

  it("hasWsSession returns true after a session is created", async () => {
    const streamFn = createOpenAIWebSocketStreamFn("sk-test", "registry-test");
    const stream = streamFn(
      {
        api: "openai-responses",
        provider: "openai",
        id: "gpt-5.2",
        contextWindow: 128000,
        maxTokens: 4096,
        reasoning: false,
        input: ["text"],
        cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
        name: "GPT-5.2",
      } as Parameters[0],
      {
        systemPrompt: "test",
        messages: [userMsg("Hi") as Parameters[0][number]],
        tools: [],
      } as Parameters[1],
    );

    await new Promise((r) => setImmediate(r));
    // Session should be registered and connected
    expect(hasWsSession("registry-test")).toBe(true);

    // Clean up
    const manager = MockManager.lastInstance!;
    manager.simulateEvent({
      type: "response.completed",
      response: makeResponseObject("resp_z", "done"),
    });
    for await (const _ of await resolveStream(stream)) {
      // consume
    }
  });

  it("releaseWsSession closes the connection and removes the session", async () => {
    const streamFn = createOpenAIWebSocketStreamFn("sk-test", "registry-test");
    const stream = streamFn(
      {
        api: "openai-responses",
        provider: "openai",
        id: "gpt-5.2",
        contextWindow: 128000,
        maxTokens: 4096,
        reasoning: false,
        input: ["text"],
        cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
        name: "GPT-5.2",
      } as Parameters[0],
      {
        systemPrompt: "test",
        messages: [userMsg("Hi") as Parameters[0][number]],
        tools: [],
      } as Parameters[1],
    );

    await new Promise((r) => setImmediate(r));
    const manager = MockManager.lastInstance!;
    manager.simulateEvent({
      type: "response.completed",
      response: makeResponseObject("resp_zz", "done"),
    });
    for await (const _ of await resolveStream(stream)) {
      // consume
    }

    releaseWsSession("registry-test");
    expect(hasWsSession("registry-test")).toBe(false);
    // Releasing must close the underlying connection exactly once.
    expect(manager.closeCallCount).toBe(1);
  });

  it("releaseWsSession is a no-op for unknown sessions", () => {
    expect(() => releaseWsSession("nonexistent-session")).not.toThrow();
  });
});