feat: add sessions_yield tool for cooperative turn-ending (#36537)
Merged via squash. Prepared head SHA: 75d9204c863792226389a4d33eeb40c4e842528d Co-authored-by: jriff <50276+jriff@users.noreply.github.com> Co-authored-by: jalehman <550978+jalehman@users.noreply.github.com> Reviewed-by: @jalehman
This commit is contained in:
@@ -24,6 +24,7 @@ Docs: https://docs.openclaw.ai
|
||||
### Changes
|
||||
|
||||
- Docs/Kubernetes: Add a starter K8s install path with raw manifests, Kind setup, and deployment docs. Thanks @sallyom @dzianisv @egkristi
|
||||
- Agents/subagents: add `sessions_yield` so orchestrators can end the current turn immediately, skip queued tool work, and carry a hidden follow-up payload into the next session turn. (#36537) thanks @jriff
|
||||
|
||||
### Fixes
|
||||
|
||||
|
||||
@@ -32,13 +32,13 @@ INPUT_PATHS=(
|
||||
)
|
||||
|
||||
compute_hash() {
|
||||
ROOT_DIR="$ROOT_DIR" node --input-type=module - "${INPUT_PATHS[@]}" <<'NODE'
|
||||
ROOT_DIR="$ROOT_DIR" node --input-type=module --eval '
|
||||
import { createHash } from "node:crypto";
|
||||
import { promises as fs } from "node:fs";
|
||||
import path from "node:path";
|
||||
|
||||
const rootDir = process.env.ROOT_DIR ?? process.cwd();
|
||||
const inputs = process.argv.slice(2);
|
||||
const inputs = process.argv.slice(1);
|
||||
const files = [];
|
||||
|
||||
async function walk(entryPath) {
|
||||
@@ -73,7 +73,7 @@ for (const filePath of files) {
|
||||
}
|
||||
|
||||
process.stdout.write(hash.digest("hex"));
|
||||
NODE
|
||||
' "${INPUT_PATHS[@]}"
|
||||
}
|
||||
|
||||
current_hash="$(compute_hash)"
|
||||
|
||||
@@ -21,6 +21,7 @@ import { createSessionsHistoryTool } from "./tools/sessions-history-tool.js";
|
||||
import { createSessionsListTool } from "./tools/sessions-list-tool.js";
|
||||
import { createSessionsSendTool } from "./tools/sessions-send-tool.js";
|
||||
import { createSessionsSpawnTool } from "./tools/sessions-spawn-tool.js";
|
||||
import { createSessionsYieldTool } from "./tools/sessions-yield-tool.js";
|
||||
import { createSubagentsTool } from "./tools/subagents-tool.js";
|
||||
import { createTtsTool } from "./tools/tts-tool.js";
|
||||
import { createWebFetchTool, createWebSearchTool } from "./tools/web-tools.js";
|
||||
@@ -77,6 +78,8 @@ export function createOpenClawTools(
|
||||
* subagents inherit the real workspace path instead of the sandbox copy.
|
||||
*/
|
||||
spawnWorkspaceDir?: string;
|
||||
/** Callback invoked when sessions_yield tool is called. */
|
||||
onYield?: (message: string) => Promise<void> | void;
|
||||
} & SpawnedToolContext,
|
||||
): AnyAgentTool[] {
|
||||
const workspaceDir = resolveWorkspaceRoot(options?.workspaceDir);
|
||||
@@ -181,6 +184,10 @@ export function createOpenClawTools(
|
||||
agentChannel: options?.agentChannel,
|
||||
sandboxed: options?.sandboxed,
|
||||
}),
|
||||
createSessionsYieldTool({
|
||||
sessionId: options?.sessionId,
|
||||
onYield: options?.onYield,
|
||||
}),
|
||||
createSessionsSpawnTool({
|
||||
agentSessionKey: options?.agentSessionKey,
|
||||
agentChannel: options?.agentChannel,
|
||||
|
||||
@@ -276,7 +276,7 @@ describe("applyExtraParamsToAgent", () => {
|
||||
const payloads: Record<string, unknown>[] = [];
|
||||
const baseStreamFn: StreamFn = (_model, _context, options) => {
|
||||
const payload: Record<string, unknown> = { model: "deepseek/deepseek-r1" };
|
||||
options?.onPayload?.(payload, model);
|
||||
options?.onPayload?.(payload, _model);
|
||||
payloads.push(payload);
|
||||
return {} as ReturnType<StreamFn>;
|
||||
};
|
||||
@@ -308,7 +308,7 @@ describe("applyExtraParamsToAgent", () => {
|
||||
const payloads: Record<string, unknown>[] = [];
|
||||
const baseStreamFn: StreamFn = (_model, _context, options) => {
|
||||
const payload: Record<string, unknown> = {};
|
||||
options?.onPayload?.(payload, model);
|
||||
options?.onPayload?.(payload, _model);
|
||||
payloads.push(payload);
|
||||
return {} as ReturnType<StreamFn>;
|
||||
};
|
||||
@@ -332,7 +332,7 @@ describe("applyExtraParamsToAgent", () => {
|
||||
const payloads: Record<string, unknown>[] = [];
|
||||
const baseStreamFn: StreamFn = (_model, _context, options) => {
|
||||
const payload: Record<string, unknown> = { reasoning_effort: "high" };
|
||||
options?.onPayload?.(payload, model);
|
||||
options?.onPayload?.(payload, _model);
|
||||
payloads.push(payload);
|
||||
return {} as ReturnType<StreamFn>;
|
||||
};
|
||||
@@ -357,7 +357,7 @@ describe("applyExtraParamsToAgent", () => {
|
||||
const payloads: Record<string, unknown>[] = [];
|
||||
const baseStreamFn: StreamFn = (_model, _context, options) => {
|
||||
const payload: Record<string, unknown> = { reasoning: { max_tokens: 256 } };
|
||||
options?.onPayload?.(payload, model);
|
||||
options?.onPayload?.(payload, _model);
|
||||
payloads.push(payload);
|
||||
return {} as ReturnType<StreamFn>;
|
||||
};
|
||||
@@ -381,7 +381,7 @@ describe("applyExtraParamsToAgent", () => {
|
||||
const payloads: Record<string, unknown>[] = [];
|
||||
const baseStreamFn: StreamFn = (_model, _context, options) => {
|
||||
const payload: Record<string, unknown> = { reasoning_effort: "medium" };
|
||||
options?.onPayload?.(payload, model);
|
||||
options?.onPayload?.(payload, _model);
|
||||
payloads.push(payload);
|
||||
return {} as ReturnType<StreamFn>;
|
||||
};
|
||||
@@ -588,7 +588,7 @@ describe("applyExtraParamsToAgent", () => {
|
||||
const payloads: Record<string, unknown>[] = [];
|
||||
const baseStreamFn: StreamFn = (_model, _context, options) => {
|
||||
const payload: Record<string, unknown> = { thinking: "off" };
|
||||
options?.onPayload?.(payload, model);
|
||||
options?.onPayload?.(payload, _model);
|
||||
payloads.push(payload);
|
||||
return {} as ReturnType<StreamFn>;
|
||||
};
|
||||
@@ -619,7 +619,7 @@ describe("applyExtraParamsToAgent", () => {
|
||||
const payloads: Record<string, unknown>[] = [];
|
||||
const baseStreamFn: StreamFn = (_model, _context, options) => {
|
||||
const payload: Record<string, unknown> = { thinking: "off" };
|
||||
options?.onPayload?.(payload, model);
|
||||
options?.onPayload?.(payload, _model);
|
||||
payloads.push(payload);
|
||||
return {} as ReturnType<StreamFn>;
|
||||
};
|
||||
@@ -650,7 +650,7 @@ describe("applyExtraParamsToAgent", () => {
|
||||
const payloads: Record<string, unknown>[] = [];
|
||||
const baseStreamFn: StreamFn = (_model, _context, options) => {
|
||||
const payload: Record<string, unknown> = {};
|
||||
options?.onPayload?.(payload, model);
|
||||
options?.onPayload?.(payload, _model);
|
||||
payloads.push(payload);
|
||||
return {} as ReturnType<StreamFn>;
|
||||
};
|
||||
@@ -674,7 +674,7 @@ describe("applyExtraParamsToAgent", () => {
|
||||
const payloads: Record<string, unknown>[] = [];
|
||||
const baseStreamFn: StreamFn = (_model, _context, options) => {
|
||||
const payload: Record<string, unknown> = { tool_choice: "required" };
|
||||
options?.onPayload?.(payload, model);
|
||||
options?.onPayload?.(payload, _model);
|
||||
payloads.push(payload);
|
||||
return {} as ReturnType<StreamFn>;
|
||||
};
|
||||
@@ -699,7 +699,7 @@ describe("applyExtraParamsToAgent", () => {
|
||||
const payloads: Record<string, unknown>[] = [];
|
||||
const baseStreamFn: StreamFn = (_model, _context, options) => {
|
||||
const payload: Record<string, unknown> = {};
|
||||
options?.onPayload?.(payload, model);
|
||||
options?.onPayload?.(payload, _model);
|
||||
payloads.push(payload);
|
||||
return {} as ReturnType<StreamFn>;
|
||||
};
|
||||
@@ -749,7 +749,7 @@ describe("applyExtraParamsToAgent", () => {
|
||||
],
|
||||
tool_choice: { type: "tool", name: "read" },
|
||||
};
|
||||
options?.onPayload?.(payload, model);
|
||||
options?.onPayload?.(payload, _model);
|
||||
payloads.push(payload);
|
||||
return {} as ReturnType<StreamFn>;
|
||||
};
|
||||
@@ -793,7 +793,7 @@ describe("applyExtraParamsToAgent", () => {
|
||||
},
|
||||
],
|
||||
};
|
||||
options?.onPayload?.(payload, model);
|
||||
options?.onPayload?.(payload, _model);
|
||||
payloads.push(payload);
|
||||
return {} as ReturnType<StreamFn>;
|
||||
};
|
||||
@@ -832,7 +832,7 @@ describe("applyExtraParamsToAgent", () => {
|
||||
},
|
||||
],
|
||||
};
|
||||
options?.onPayload?.(payload, model);
|
||||
options?.onPayload?.(payload, _model);
|
||||
payloads.push(payload);
|
||||
return {} as ReturnType<StreamFn>;
|
||||
};
|
||||
@@ -896,7 +896,7 @@ describe("applyExtraParamsToAgent", () => {
|
||||
},
|
||||
},
|
||||
};
|
||||
options?.onPayload?.(payload, model);
|
||||
options?.onPayload?.(payload, _model);
|
||||
payloads.push(payload);
|
||||
return {} as ReturnType<StreamFn>;
|
||||
};
|
||||
@@ -943,7 +943,7 @@ describe("applyExtraParamsToAgent", () => {
|
||||
},
|
||||
},
|
||||
};
|
||||
options?.onPayload?.(payload, model);
|
||||
options?.onPayload?.(payload, _model);
|
||||
payloads.push(payload);
|
||||
return {} as ReturnType<StreamFn>;
|
||||
};
|
||||
|
||||
370
src/agents/pi-embedded-runner.sessions-yield.e2e.test.ts
Normal file
370
src/agents/pi-embedded-runner.sessions-yield.e2e.test.ts
Normal file
@@ -0,0 +1,370 @@
|
||||
/**
|
||||
* End-to-end test proving that when sessions_yield is called:
|
||||
* 1. The attempt completes with yieldDetected
|
||||
* 2. The run exits with stopReason "end_turn" and no pendingToolCalls
|
||||
* 3. The parent session is idle (clearActiveEmbeddedRun has run)
|
||||
*
|
||||
* This exercises the full path: mock LLM → agent loop → tool execution → callback → attempt result → run result.
|
||||
* Follows the same pattern as pi-embedded-runner.e2e.test.ts.
|
||||
*/
|
||||
import fs from "node:fs/promises";
|
||||
import os from "node:os";
|
||||
import path from "node:path";
|
||||
import "./test-helpers/fast-coding-tools.js";
|
||||
import { afterAll, beforeAll, describe, expect, it, vi } from "vitest";
|
||||
import type { OpenClawConfig } from "../config/config.js";
|
||||
import { isEmbeddedPiRunActive, queueEmbeddedPiMessage } from "./pi-embedded-runner/runs.js";
|
||||
|
||||
function createMockUsage(input: number, output: number) {
|
||||
return {
|
||||
input,
|
||||
output,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: input + output,
|
||||
cost: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
total: 0,
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
let streamCallCount = 0;
|
||||
let multiToolMode = false;
|
||||
let responsePlan: Array<"toolUse" | "stop"> = [];
|
||||
let observedContexts: Array<Array<{ role?: string; content?: unknown }>> = [];
|
||||
|
||||
vi.mock("@mariozechner/pi-coding-agent", async () => {
|
||||
return await vi.importActual<typeof import("@mariozechner/pi-coding-agent")>(
|
||||
"@mariozechner/pi-coding-agent",
|
||||
);
|
||||
});
|
||||
|
||||
vi.mock("@mariozechner/pi-ai", async () => {
|
||||
const actual = await vi.importActual<typeof import("@mariozechner/pi-ai")>("@mariozechner/pi-ai");
|
||||
|
||||
const buildToolUseMessage = (model: { api: string; provider: string; id: string }) => {
|
||||
const toolCalls: Array<{
|
||||
type: "toolCall";
|
||||
id: string;
|
||||
name: string;
|
||||
arguments: Record<string, unknown>;
|
||||
}> = [
|
||||
{
|
||||
type: "toolCall" as const,
|
||||
id: "tc-yield-e2e-1",
|
||||
name: "sessions_yield",
|
||||
arguments: { message: "Yielding turn." },
|
||||
},
|
||||
];
|
||||
if (multiToolMode) {
|
||||
toolCalls.push({
|
||||
type: "toolCall" as const,
|
||||
id: "tc-post-yield-2",
|
||||
name: "read",
|
||||
arguments: { file_path: "/etc/hostname" },
|
||||
});
|
||||
}
|
||||
return {
|
||||
role: "assistant" as const,
|
||||
content: toolCalls,
|
||||
stopReason: "toolUse" as const,
|
||||
api: model.api,
|
||||
provider: model.provider,
|
||||
model: model.id,
|
||||
usage: createMockUsage(1, 1),
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
};
|
||||
|
||||
const buildStopMessage = (model: { api: string; provider: string; id: string }) => ({
|
||||
role: "assistant" as const,
|
||||
content: [{ type: "text" as const, text: "Acknowledged." }],
|
||||
stopReason: "stop" as const,
|
||||
api: model.api,
|
||||
provider: model.provider,
|
||||
model: model.id,
|
||||
usage: createMockUsage(1, 1),
|
||||
timestamp: Date.now(),
|
||||
});
|
||||
|
||||
return {
|
||||
...actual,
|
||||
complete: async (model: { api: string; provider: string; id: string }) => {
|
||||
streamCallCount++;
|
||||
const next = responsePlan.shift() ?? "stop";
|
||||
return next === "toolUse" ? buildToolUseMessage(model) : buildStopMessage(model);
|
||||
},
|
||||
completeSimple: async (model: { api: string; provider: string; id: string }) => {
|
||||
streamCallCount++;
|
||||
const next = responsePlan.shift() ?? "stop";
|
||||
return next === "toolUse" ? buildToolUseMessage(model) : buildStopMessage(model);
|
||||
},
|
||||
streamSimple: (
|
||||
model: { api: string; provider: string; id: string },
|
||||
context: { messages?: Array<{ role?: string; content?: unknown }> },
|
||||
) => {
|
||||
streamCallCount++;
|
||||
observedContexts.push((context.messages ?? []).map((message) => ({ ...message })));
|
||||
const next = responsePlan.shift() ?? "stop";
|
||||
const message = next === "toolUse" ? buildToolUseMessage(model) : buildStopMessage(model);
|
||||
const stream = actual.createAssistantMessageEventStream();
|
||||
queueMicrotask(() => {
|
||||
stream.push({
|
||||
type: "done",
|
||||
reason: next === "toolUse" ? "toolUse" : "stop",
|
||||
message,
|
||||
});
|
||||
stream.end();
|
||||
});
|
||||
return stream;
|
||||
},
|
||||
};
|
||||
});
|
||||
|
||||
let runEmbeddedPiAgent: typeof import("./pi-embedded-runner/run.js").runEmbeddedPiAgent;
|
||||
let tempRoot: string | undefined;
|
||||
let agentDir: string;
|
||||
let workspaceDir: string;
|
||||
|
||||
beforeAll(async () => {
|
||||
vi.useRealTimers();
|
||||
streamCallCount = 0;
|
||||
responsePlan = [];
|
||||
observedContexts = [];
|
||||
({ runEmbeddedPiAgent } = await import("./pi-embedded-runner/run.js"));
|
||||
tempRoot = await fs.mkdtemp(path.join(os.tmpdir(), "openclaw-yield-e2e-"));
|
||||
agentDir = path.join(tempRoot, "agent");
|
||||
workspaceDir = path.join(tempRoot, "workspace");
|
||||
await fs.mkdir(agentDir, { recursive: true });
|
||||
await fs.mkdir(workspaceDir, { recursive: true });
|
||||
}, 180_000);
|
||||
|
||||
afterAll(async () => {
|
||||
if (!tempRoot) {
|
||||
return;
|
||||
}
|
||||
await fs.rm(tempRoot, { recursive: true, force: true });
|
||||
tempRoot = undefined;
|
||||
});
|
||||
|
||||
const makeConfig = (modelIds: string[]) =>
|
||||
({
|
||||
models: {
|
||||
providers: {
|
||||
openai: {
|
||||
api: "openai-responses",
|
||||
apiKey: "sk-test",
|
||||
baseUrl: "https://example.com",
|
||||
models: modelIds.map((id) => ({
|
||||
id,
|
||||
name: `Mock ${id}`,
|
||||
reasoning: false,
|
||||
input: ["text"],
|
||||
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
|
||||
contextWindow: 16_000,
|
||||
maxTokens: 2048,
|
||||
})),
|
||||
},
|
||||
},
|
||||
},
|
||||
}) satisfies OpenClawConfig;
|
||||
|
||||
const immediateEnqueue = async <T>(task: () => Promise<T>) => task();
|
||||
|
||||
const readSessionMessages = async (sessionFile: string) => {
|
||||
const raw = await fs.readFile(sessionFile, "utf-8");
|
||||
return raw
|
||||
.split(/\r?\n/)
|
||||
.filter(Boolean)
|
||||
.map(
|
||||
(line) =>
|
||||
JSON.parse(line) as { type?: string; message?: { role?: string; content?: unknown } },
|
||||
)
|
||||
.filter((entry) => entry.type === "message")
|
||||
.map((entry) => entry.message) as Array<{ role?: string; content?: unknown }>;
|
||||
};
|
||||
|
||||
const readSessionEntries = async (sessionFile: string) =>
|
||||
(await fs.readFile(sessionFile, "utf-8"))
|
||||
.split(/\r?\n/)
|
||||
.filter(Boolean)
|
||||
.map((line) => JSON.parse(line) as Record<string, unknown>);
|
||||
|
||||
describe("sessions_yield e2e", () => {
|
||||
it(
|
||||
"parent session is idle after yield and preserves the follow-up payload",
|
||||
{ timeout: 15_000 },
|
||||
async () => {
|
||||
streamCallCount = 0;
|
||||
responsePlan = ["toolUse"];
|
||||
observedContexts = [];
|
||||
|
||||
const sessionId = "yield-e2e-parent";
|
||||
const sessionFile = path.join(workspaceDir, "session-yield-e2e.jsonl");
|
||||
const cfg = makeConfig(["mock-yield"]);
|
||||
|
||||
const result = await runEmbeddedPiAgent({
|
||||
sessionId,
|
||||
sessionKey: "agent:test:yield-e2e",
|
||||
sessionFile,
|
||||
workspaceDir,
|
||||
config: cfg,
|
||||
prompt: "Spawn subagent and yield.",
|
||||
provider: "openai",
|
||||
model: "mock-yield",
|
||||
timeoutMs: 10_000,
|
||||
agentDir,
|
||||
runId: "run-yield-e2e-1",
|
||||
enqueue: immediateEnqueue,
|
||||
});
|
||||
|
||||
// 1. Run completed with end_turn (yield causes clean exit)
|
||||
expect(result.meta.stopReason).toBe("end_turn");
|
||||
|
||||
// 2. No pending tool calls (yield is NOT a client tool call)
|
||||
expect(result.meta.pendingToolCalls).toBeUndefined();
|
||||
|
||||
// 3. Parent session is IDLE — clearActiveEmbeddedRun ran in finally block
|
||||
expect(isEmbeddedPiRunActive(sessionId)).toBe(false);
|
||||
|
||||
// 4. Steer would fail — session not in ACTIVE_EMBEDDED_RUNS
|
||||
expect(queueEmbeddedPiMessage(sessionId, "subagent result")).toBe(false);
|
||||
|
||||
// 5. The yield stops at tool time — there is no second provider call.
|
||||
expect(streamCallCount).toBe(1);
|
||||
|
||||
// 6. Session transcript contains only the original assistant tool call.
|
||||
const messages = await readSessionMessages(sessionFile);
|
||||
const roles = messages.map((m) => m?.role);
|
||||
expect(roles).toContain("user");
|
||||
expect(roles.filter((r) => r === "assistant")).toHaveLength(1);
|
||||
|
||||
const firstAssistant = messages.find((m) => m?.role === "assistant");
|
||||
const content = firstAssistant?.content;
|
||||
expect(Array.isArray(content)).toBe(true);
|
||||
const toolCall = (content as Array<{ type?: string; name?: string }>).find(
|
||||
(c) => c.type === "toolCall" && c.name === "sessions_yield",
|
||||
);
|
||||
expect(toolCall).toBeDefined();
|
||||
|
||||
const entries = await readSessionEntries(sessionFile);
|
||||
const yieldContext = entries.find(
|
||||
(entry) =>
|
||||
entry.type === "custom_message" && entry.customType === "openclaw.sessions_yield",
|
||||
);
|
||||
expect(yieldContext).toMatchObject({
|
||||
content: expect.stringContaining("Yielding turn."),
|
||||
});
|
||||
|
||||
streamCallCount = 0;
|
||||
responsePlan = ["stop"];
|
||||
observedContexts = [];
|
||||
await runEmbeddedPiAgent({
|
||||
sessionId,
|
||||
sessionKey: "agent:test:yield-e2e",
|
||||
sessionFile,
|
||||
workspaceDir,
|
||||
config: cfg,
|
||||
prompt: "Subagent finished with the requested result.",
|
||||
provider: "openai",
|
||||
model: "mock-yield",
|
||||
timeoutMs: 10_000,
|
||||
agentDir,
|
||||
runId: "run-yield-e2e-2",
|
||||
enqueue: immediateEnqueue,
|
||||
});
|
||||
|
||||
const resumeContext = observedContexts[0] ?? [];
|
||||
const resumeTexts = resumeContext.flatMap((message) =>
|
||||
Array.isArray(message.content)
|
||||
? (message.content as Array<{ type?: string; text?: string }>)
|
||||
.filter((part) => part.type === "text" && typeof part.text === "string")
|
||||
.map((part) => part.text ?? "")
|
||||
: [],
|
||||
);
|
||||
expect(resumeTexts.some((text) => text.includes("Yielding turn."))).toBe(true);
|
||||
expect(
|
||||
resumeTexts.some((text) => text.includes("Subagent finished with the requested result.")),
|
||||
).toBe(true);
|
||||
},
|
||||
);
|
||||
|
||||
it(
|
||||
"abort prevents subsequent tool calls from executing after yield",
|
||||
{ timeout: 15_000 },
|
||||
async () => {
|
||||
streamCallCount = 0;
|
||||
multiToolMode = true;
|
||||
responsePlan = ["toolUse"];
|
||||
observedContexts = [];
|
||||
|
||||
const sessionId = "yield-e2e-abort";
|
||||
const sessionFile = path.join(workspaceDir, "session-yield-abort.jsonl");
|
||||
const cfg = makeConfig(["mock-yield-abort"]);
|
||||
|
||||
const result = await runEmbeddedPiAgent({
|
||||
sessionId,
|
||||
sessionKey: "agent:test:yield-abort",
|
||||
sessionFile,
|
||||
workspaceDir,
|
||||
config: cfg,
|
||||
prompt: "Yield and then read a file.",
|
||||
provider: "openai",
|
||||
model: "mock-yield-abort",
|
||||
timeoutMs: 10_000,
|
||||
agentDir,
|
||||
runId: "run-yield-abort-1",
|
||||
enqueue: immediateEnqueue,
|
||||
});
|
||||
|
||||
// Reset for other tests
|
||||
multiToolMode = false;
|
||||
|
||||
// 1. Run completed with end_turn despite the extra queued tool call
|
||||
expect(result.meta.stopReason).toBe("end_turn");
|
||||
|
||||
// 2. Session is idle
|
||||
expect(isEmbeddedPiRunActive(sessionId)).toBe(false);
|
||||
|
||||
// 3. The yield prevented a post-tool provider call.
|
||||
expect(streamCallCount).toBe(1);
|
||||
|
||||
// 4. Transcript should contain sessions_yield but NOT a successful read result
|
||||
const messages = await readSessionMessages(sessionFile);
|
||||
const allContent = messages.flatMap((m) =>
|
||||
Array.isArray(m?.content) ? (m.content as Array<{ type?: string; name?: string }>) : [],
|
||||
);
|
||||
const yieldCall = allContent.find(
|
||||
(c) => c.type === "toolCall" && c.name === "sessions_yield",
|
||||
);
|
||||
expect(yieldCall).toBeDefined();
|
||||
|
||||
// The read tool call should be in the assistant message (LLM requested it),
|
||||
// but its result should NOT show a successful file read.
|
||||
const readCall = allContent.find((c) => c.type === "toolCall" && c.name === "read");
|
||||
expect(readCall).toBeDefined(); // LLM asked for it...
|
||||
|
||||
// ...but the file was never actually read (no tool result with file contents)
|
||||
const toolResults = messages.filter((m) => m?.role === "toolResult");
|
||||
const readResult = toolResults.find((tr) => {
|
||||
const content = tr?.content;
|
||||
if (typeof content === "string") {
|
||||
return content.includes("/etc/hostname");
|
||||
}
|
||||
if (Array.isArray(content)) {
|
||||
return (content as Array<{ text?: string }>).some((c) =>
|
||||
c.text?.includes("/etc/hostname"),
|
||||
);
|
||||
}
|
||||
return false;
|
||||
});
|
||||
// If the read tool ran, its result would reference the file path.
|
||||
// The abort should have prevented it from executing.
|
||||
expect(readResult).toBeUndefined();
|
||||
},
|
||||
);
|
||||
});
|
||||
@@ -1574,6 +1574,8 @@ export async function runEmbeddedPiAgent(
|
||||
// ACP bridge) can distinguish end_turn from max_tokens.
|
||||
stopReason: attempt.clientToolCall
|
||||
? "tool_calls"
|
||||
: attempt.yieldDetected
|
||||
? "end_turn"
|
||||
: (lastAssistant?.stopReason as string | undefined),
|
||||
pendingToolCalls: attempt.clientToolCall
|
||||
? [
|
||||
|
||||
@@ -148,6 +148,186 @@ type PromptBuildHookRunner = {
|
||||
) => Promise<PluginHookBeforeAgentStartResult | undefined>;
|
||||
};
|
||||
|
||||
const SESSIONS_YIELD_INTERRUPT_CUSTOM_TYPE = "openclaw.sessions_yield_interrupt";
|
||||
const SESSIONS_YIELD_CONTEXT_CUSTOM_TYPE = "openclaw.sessions_yield";
|
||||
|
||||
// Persist a hidden context reminder so the next turn knows why the runner stopped.
|
||||
function buildSessionsYieldContextMessage(message: string): string {
|
||||
return `${message}\n\n[Context: The previous turn ended intentionally via sessions_yield while waiting for a follow-up event.]`;
|
||||
}
|
||||
|
||||
// Return a synthetic aborted response so pi-agent-core unwinds without a real provider call.
|
||||
function createYieldAbortedResponse(model: { api?: string; provider?: string; id?: string }): {
|
||||
[Symbol.asyncIterator]: () => AsyncGenerator<never, void, unknown>;
|
||||
result: () => Promise<{
|
||||
role: "assistant";
|
||||
content: Array<{ type: "text"; text: string }>;
|
||||
stopReason: "aborted";
|
||||
api: string;
|
||||
provider: string;
|
||||
model: string;
|
||||
usage: {
|
||||
input: number;
|
||||
output: number;
|
||||
cacheRead: number;
|
||||
cacheWrite: number;
|
||||
totalTokens: number;
|
||||
cost: {
|
||||
input: number;
|
||||
output: number;
|
||||
cacheRead: number;
|
||||
cacheWrite: number;
|
||||
total: number;
|
||||
};
|
||||
};
|
||||
timestamp: number;
|
||||
}>;
|
||||
} {
|
||||
const message = {
|
||||
role: "assistant" as const,
|
||||
content: [{ type: "text" as const, text: "" }],
|
||||
stopReason: "aborted" as const,
|
||||
api: model.api ?? "",
|
||||
provider: model.provider ?? "",
|
||||
model: model.id ?? "",
|
||||
usage: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
totalTokens: 0,
|
||||
cost: {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cacheRead: 0,
|
||||
cacheWrite: 0,
|
||||
total: 0,
|
||||
},
|
||||
},
|
||||
timestamp: Date.now(),
|
||||
};
|
||||
return {
|
||||
async *[Symbol.asyncIterator]() {},
|
||||
result: async () => message,
|
||||
};
|
||||
}
|
||||
|
||||
// Queue a hidden steering message so pi-agent-core skips any remaining tool calls.
|
||||
function queueSessionsYieldInterruptMessage(activeSession: {
|
||||
agent: { steer: (message: AgentMessage) => void };
|
||||
}) {
|
||||
activeSession.agent.steer({
|
||||
role: "custom",
|
||||
customType: SESSIONS_YIELD_INTERRUPT_CUSTOM_TYPE,
|
||||
content: "[sessions_yield interrupt]",
|
||||
display: false,
|
||||
details: { source: "sessions_yield" },
|
||||
timestamp: Date.now(),
|
||||
});
|
||||
}
|
||||
|
||||
// Append the caller-provided yield payload as a hidden session message once the run is idle.
|
||||
async function persistSessionsYieldContextMessage(
|
||||
activeSession: {
|
||||
sendCustomMessage: (
|
||||
message: {
|
||||
customType: string;
|
||||
content: string;
|
||||
display: boolean;
|
||||
details?: Record<string, unknown>;
|
||||
},
|
||||
options?: { triggerTurn?: boolean },
|
||||
) => Promise<void>;
|
||||
},
|
||||
message: string,
|
||||
) {
|
||||
await activeSession.sendCustomMessage(
|
||||
{
|
||||
customType: SESSIONS_YIELD_CONTEXT_CUSTOM_TYPE,
|
||||
content: buildSessionsYieldContextMessage(message),
|
||||
display: false,
|
||||
details: { source: "sessions_yield", message },
|
||||
},
|
||||
{ triggerTurn: false },
|
||||
);
|
||||
}
|
||||
|
||||
// Remove the synthetic yield interrupt + aborted assistant entry from the live transcript.
|
||||
function stripSessionsYieldArtifacts(activeSession: {
|
||||
messages: AgentMessage[];
|
||||
agent: { replaceMessages: (messages: AgentMessage[]) => void };
|
||||
sessionManager?: unknown;
|
||||
}) {
|
||||
const strippedMessages = activeSession.messages.slice();
|
||||
while (strippedMessages.length > 0) {
|
||||
const last = strippedMessages.at(-1) as
|
||||
| AgentMessage
|
||||
| { role?: string; customType?: string; stopReason?: string };
|
||||
if (last?.role === "assistant" && "stopReason" in last && last.stopReason === "aborted") {
|
||||
strippedMessages.pop();
|
||||
continue;
|
||||
}
|
||||
if (
|
||||
last?.role === "custom" &&
|
||||
"customType" in last &&
|
||||
last.customType === SESSIONS_YIELD_INTERRUPT_CUSTOM_TYPE
|
||||
) {
|
||||
strippedMessages.pop();
|
||||
continue;
|
||||
}
|
||||
break;
|
||||
}
|
||||
if (strippedMessages.length !== activeSession.messages.length) {
|
||||
activeSession.agent.replaceMessages(strippedMessages);
|
||||
}
|
||||
|
||||
const sessionManager = activeSession.sessionManager as
|
||||
| {
|
||||
fileEntries?: Array<{
|
||||
type?: string;
|
||||
id?: string;
|
||||
parentId?: string | null;
|
||||
message?: { role?: string; stopReason?: string };
|
||||
customType?: string;
|
||||
}>;
|
||||
byId?: Map<string, { id: string }>;
|
||||
leafId?: string | null;
|
||||
_rewriteFile?: () => void;
|
||||
}
|
||||
| undefined;
|
||||
const fileEntries = sessionManager?.fileEntries;
|
||||
const byId = sessionManager?.byId;
|
||||
if (!fileEntries || !byId) {
|
||||
return;
|
||||
}
|
||||
|
||||
let changed = false;
|
||||
while (fileEntries.length > 1) {
|
||||
const last = fileEntries.at(-1);
|
||||
if (!last || last.type === "session") {
|
||||
break;
|
||||
}
|
||||
const isYieldAbortAssistant =
|
||||
last.type === "message" &&
|
||||
last.message?.role === "assistant" &&
|
||||
last.message?.stopReason === "aborted";
|
||||
const isYieldInterruptMessage =
|
||||
last.type === "custom_message" && last.customType === SESSIONS_YIELD_INTERRUPT_CUSTOM_TYPE;
|
||||
if (!isYieldAbortAssistant && !isYieldInterruptMessage) {
|
||||
break;
|
||||
}
|
||||
fileEntries.pop();
|
||||
if (last.id) {
|
||||
byId.delete(last.id);
|
||||
}
|
||||
sessionManager.leafId = last.parentId ?? null;
|
||||
changed = true;
|
||||
}
|
||||
if (changed) {
|
||||
sessionManager._rewriteFile?.();
|
||||
}
|
||||
}
|
||||
|
||||
export function isOllamaCompatProvider(model: {
|
||||
provider?: string;
|
||||
baseUrl?: string;
|
||||
@@ -1121,6 +1301,13 @@ export async function runEmbeddedAttempt(
|
||||
config: params.config,
|
||||
sessionAgentId,
|
||||
});
|
||||
// Track sessions_yield tool invocation (callback pattern, like clientToolCallDetected)
|
||||
let yieldDetected = false;
|
||||
let yieldMessage: string | null = null;
|
||||
// Late-binding reference so onYield can abort the session (declared after tool creation)
|
||||
let abortSessionForYield: (() => void) | null = null;
|
||||
let queueYieldInterruptForSession: (() => void) | null = null;
|
||||
let yieldAbortSettled: Promise<void> | null = null;
|
||||
// Check if the model supports native image input
|
||||
const modelHasVision = params.model.input?.includes("image") ?? false;
|
||||
const toolsRaw = params.disableTools
|
||||
@@ -1165,6 +1352,13 @@ export async function runEmbeddedAttempt(
|
||||
requireExplicitMessageTarget:
|
||||
params.requireExplicitMessageTarget ?? isSubagentSessionKey(params.sessionKey),
|
||||
disableMessageTool: params.disableMessageTool,
|
||||
onYield: (message) => {
|
||||
yieldDetected = true;
|
||||
yieldMessage = message;
|
||||
queueYieldInterruptForSession?.();
|
||||
runAbortController.abort("sessions_yield");
|
||||
abortSessionForYield?.();
|
||||
},
|
||||
});
|
||||
const toolsEnabled = supportsModelTools(params.model);
|
||||
const tools = sanitizeToolsForGoogle({
|
||||
@@ -1475,6 +1669,12 @@ export async function runEmbeddedAttempt(
|
||||
throw new Error("Embedded agent session missing");
|
||||
}
|
||||
const activeSession = session;
|
||||
abortSessionForYield = () => {
|
||||
yieldAbortSettled = Promise.resolve(activeSession.abort());
|
||||
};
|
||||
queueYieldInterruptForSession = () => {
|
||||
queueSessionsYieldInterruptMessage(activeSession);
|
||||
};
|
||||
removeToolResultContextGuard = installToolResultContextGuard({
|
||||
agent: activeSession.agent,
|
||||
contextWindowTokens: Math.max(
|
||||
@@ -1646,6 +1846,17 @@ export async function runEmbeddedAttempt(
|
||||
};
|
||||
}
|
||||
|
||||
const innerStreamFn = activeSession.agent.streamFn;
|
||||
activeSession.agent.streamFn = (model, context, options) => {
|
||||
const signal = runAbortController.signal as AbortSignal & { reason?: unknown };
|
||||
if (yieldDetected && signal.aborted && signal.reason === "sessions_yield") {
|
||||
return createYieldAbortedResponse(model) as unknown as Awaited<
|
||||
ReturnType<typeof innerStreamFn>
|
||||
>;
|
||||
}
|
||||
return innerStreamFn(model, context, options);
|
||||
};
|
||||
|
||||
// Some models emit tool names with surrounding whitespace (e.g. " read ").
|
||||
// pi-agent-core dispatches tool calls with exact string matching, so normalize
|
||||
// names on the live response stream before tool execution.
|
||||
@@ -1746,6 +1957,7 @@ export async function runEmbeddedAttempt(
|
||||
}
|
||||
|
||||
let aborted = Boolean(params.abortSignal?.aborted);
|
||||
let yieldAborted = false;
|
||||
let timedOut = false;
|
||||
let timedOutDuringCompaction = false;
|
||||
const getAbortReason = (signal: AbortSignal): unknown =>
|
||||
@@ -2075,8 +2287,29 @@ export async function runEmbeddedAttempt(
|
||||
await abortable(activeSession.prompt(effectivePrompt));
|
||||
}
|
||||
} catch (err) {
|
||||
// Yield-triggered abort is intentional — treat as clean stop, not error.
|
||||
// Check the abort reason to distinguish from external aborts (timeout, user cancel)
|
||||
// that may race after yieldDetected is set.
|
||||
yieldAborted =
|
||||
yieldDetected &&
|
||||
isRunnerAbortError(err) &&
|
||||
err instanceof Error &&
|
||||
err.cause === "sessions_yield";
|
||||
if (yieldAborted) {
|
||||
aborted = false;
|
||||
// Ensure the session abort has fully settled before proceeding.
|
||||
if (yieldAbortSettled) {
|
||||
// eslint-disable-next-line @typescript-eslint/await-thenable -- abort() returns Promise<void> per AgentSession.d.ts
|
||||
await yieldAbortSettled;
|
||||
}
|
||||
stripSessionsYieldArtifacts(activeSession);
|
||||
if (yieldMessage) {
|
||||
await persistSessionsYieldContextMessage(activeSession, yieldMessage);
|
||||
}
|
||||
} else {
|
||||
promptError = err;
|
||||
promptErrorSource = "prompt";
|
||||
}
|
||||
} finally {
|
||||
log.debug(
|
||||
`embedded run prompt end: runId=${params.runId} sessionId=${params.sessionId} durationMs=${Date.now() - promptStartedAt}`,
|
||||
@@ -2103,7 +2336,11 @@ export async function runEmbeddedAttempt(
|
||||
await params.onBlockReplyFlush();
|
||||
}
|
||||
|
||||
const compactionRetryWait = await waitForCompactionRetryWithAggregateTimeout({
|
||||
// Skip compaction wait when yield aborted the run — the signal is
|
||||
// already tripped and abortable() would immediately reject.
|
||||
const compactionRetryWait = yieldAborted
|
||||
? { timedOut: false }
|
||||
: await waitForCompactionRetryWithAggregateTimeout({
|
||||
waitForCompactionRetry,
|
||||
abortable,
|
||||
aggregateTimeoutMs: COMPACTION_RETRY_AGGREGATE_TIMEOUT_MS,
|
||||
@@ -2365,6 +2602,7 @@ export async function runEmbeddedAttempt(
|
||||
compactionCount: getCompactionCount(),
|
||||
// Client tool call detected (OpenResponses hosted tools)
|
||||
clientToolCall: clientToolCallDetected ?? undefined,
|
||||
yieldDetected: yieldDetected || undefined,
|
||||
};
|
||||
} finally {
|
||||
// Always tear down the session (and release the lock) before we leave this attempt.
|
||||
|
||||
@@ -64,4 +64,6 @@ export type EmbeddedRunAttemptResult = {
|
||||
compactionCount?: number;
|
||||
/** Client tool call detected (OpenResponses hosted tools). */
|
||||
clientToolCall?: { name: string; params: Record<string, unknown> };
|
||||
/** True when sessions_yield tool was called during this attempt. */
|
||||
yieldDetected?: boolean;
|
||||
};
|
||||
|
||||
@@ -0,0 +1,87 @@
|
||||
/**
|
||||
* Integration test proving that sessions_yield produces a clean end_turn exit
|
||||
* with no pending tool calls, so the parent session is idle when subagent
|
||||
* results arrive.
|
||||
*/
|
||||
import "./run.overflow-compaction.mocks.shared.js";
|
||||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import { runEmbeddedPiAgent } from "./run.js";
|
||||
import { makeAttemptResult } from "./run.overflow-compaction.fixture.js";
|
||||
import { mockedGlobalHookRunner } from "./run.overflow-compaction.mocks.shared.js";
|
||||
import {
|
||||
mockedRunEmbeddedAttempt,
|
||||
overflowBaseRunParams,
|
||||
} from "./run.overflow-compaction.shared-test.js";
|
||||
import { isEmbeddedPiRunActive, queueEmbeddedPiMessage } from "./runs.js";
|
||||
|
||||
describe("sessions_yield orchestration", () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks();
|
||||
mockedGlobalHookRunner.hasHooks.mockImplementation(() => false);
|
||||
});
|
||||
|
||||
it("parent session is idle after yield — end_turn, no pendingToolCalls", async () => {
|
||||
const sessionId = "yield-parent-session";
|
||||
|
||||
// Simulate an attempt where sessions_yield was called
|
||||
mockedRunEmbeddedAttempt.mockResolvedValueOnce(
|
||||
makeAttemptResult({
|
||||
promptError: null,
|
||||
sessionIdUsed: sessionId,
|
||||
yieldDetected: true,
|
||||
}),
|
||||
);
|
||||
|
||||
const result = await runEmbeddedPiAgent({
|
||||
...overflowBaseRunParams,
|
||||
sessionId,
|
||||
runId: "run-yield-orchestration",
|
||||
});
|
||||
|
||||
// 1. Run completed with end_turn (yield causes clean exit)
|
||||
expect(result.meta.stopReason).toBe("end_turn");
|
||||
|
||||
// 2. No pending tool calls (yield is NOT a client tool call)
|
||||
expect(result.meta.pendingToolCalls).toBeUndefined();
|
||||
|
||||
// 3. Parent session is IDLE (not in ACTIVE_EMBEDDED_RUNS)
|
||||
expect(isEmbeddedPiRunActive(sessionId)).toBe(false);
|
||||
|
||||
// 4. Steer would fail (message delivery must take direct path, not steer)
|
||||
expect(queueEmbeddedPiMessage(sessionId, "subagent result")).toBe(false);
|
||||
});
|
||||
|
||||
it("clientToolCall takes precedence over yieldDetected", async () => {
|
||||
// Edge case: both flags set (shouldn't happen, but clientToolCall wins)
|
||||
mockedRunEmbeddedAttempt.mockResolvedValueOnce(
|
||||
makeAttemptResult({
|
||||
promptError: null,
|
||||
yieldDetected: true,
|
||||
clientToolCall: { name: "hosted_tool", params: { arg: "value" } },
|
||||
}),
|
||||
);
|
||||
|
||||
const result = await runEmbeddedPiAgent({
|
||||
...overflowBaseRunParams,
|
||||
runId: "run-yield-vs-client-tool",
|
||||
});
|
||||
|
||||
// clientToolCall wins — tool_calls stopReason, pendingToolCalls populated
|
||||
expect(result.meta.stopReason).toBe("tool_calls");
|
||||
expect(result.meta.pendingToolCalls).toHaveLength(1);
|
||||
expect(result.meta.pendingToolCalls![0].name).toBe("hosted_tool");
|
||||
});
|
||||
|
||||
it("normal attempt without yield has no stopReason override", async () => {
|
||||
mockedRunEmbeddedAttempt.mockResolvedValueOnce(makeAttemptResult({ promptError: null }));
|
||||
|
||||
const result = await runEmbeddedPiAgent({
|
||||
...overflowBaseRunParams,
|
||||
runId: "run-no-yield",
|
||||
});
|
||||
|
||||
// Neither clientToolCall nor yieldDetected → stopReason is undefined
|
||||
expect(result.meta.stopReason).toBeUndefined();
|
||||
expect(result.meta.pendingToolCalls).toBeUndefined();
|
||||
});
|
||||
});
|
||||
@@ -267,6 +267,8 @@ export function createOpenClawCodingTools(options?: {
|
||||
disableMessageTool?: boolean;
|
||||
/** Whether the sender is an owner (required for owner-only tools). */
|
||||
senderIsOwner?: boolean;
|
||||
/** Callback invoked when sessions_yield tool is called. */
|
||||
onYield?: (message: string) => Promise<void> | void;
|
||||
}): AnyAgentTool[] {
|
||||
const execToolName = "exec";
|
||||
const sandbox = options?.sandbox?.enabled ? options.sandbox : undefined;
|
||||
@@ -530,6 +532,7 @@ export function createOpenClawCodingTools(options?: {
|
||||
requesterSenderId: options?.senderId,
|
||||
senderIsOwner: options?.senderIsOwner,
|
||||
sessionId: options?.sessionId,
|
||||
onYield: options?.onYield,
|
||||
}),
|
||||
];
|
||||
const toolsForMemoryFlush =
|
||||
|
||||
@@ -7,11 +7,14 @@ vi.mock("@mariozechner/pi-ai", async (importOriginal) => {
|
||||
const original = await importOriginal<typeof import("@mariozechner/pi-ai")>();
|
||||
return {
|
||||
...original,
|
||||
getOAuthApiKey: () => undefined,
|
||||
getOAuthProviders: () => [],
|
||||
};
|
||||
});
|
||||
|
||||
vi.mock("@mariozechner/pi-ai/oauth", () => ({
|
||||
getOAuthApiKey: () => undefined,
|
||||
getOAuthProviders: () => [],
|
||||
}));
|
||||
|
||||
import { createOpenClawCodingTools } from "./pi-tools.js";
|
||||
|
||||
describe("FS tools with workspaceOnly=false", () => {
|
||||
|
||||
@@ -22,6 +22,7 @@ export const DEFAULT_TOOL_ALLOW = [
|
||||
"sessions_history",
|
||||
"sessions_send",
|
||||
"sessions_spawn",
|
||||
"sessions_yield",
|
||||
"subagents",
|
||||
"session_status",
|
||||
] as const;
|
||||
|
||||
@@ -145,6 +145,14 @@ const CORE_TOOL_DEFINITIONS: CoreToolDefinition[] = [
|
||||
profiles: ["coding"],
|
||||
includeInOpenClawGroup: true,
|
||||
},
|
||||
{
|
||||
id: "sessions_yield",
|
||||
label: "sessions_yield",
|
||||
description: "End turn to receive sub-agent results",
|
||||
sectionId: "sessions",
|
||||
profiles: ["coding"],
|
||||
includeInOpenClawGroup: true,
|
||||
},
|
||||
{
|
||||
id: "subagents",
|
||||
label: "subagents",
|
||||
|
||||
45
src/agents/tools/sessions-yield-tool.test.ts
Normal file
45
src/agents/tools/sessions-yield-tool.test.ts
Normal file
@@ -0,0 +1,45 @@
|
||||
import { describe, expect, it, vi } from "vitest";
|
||||
import { createSessionsYieldTool } from "./sessions-yield-tool.js";
|
||||
|
||||
describe("sessions_yield tool", () => {
|
||||
it("returns error when no sessionId is provided", async () => {
|
||||
const onYield = vi.fn();
|
||||
const tool = createSessionsYieldTool({ onYield });
|
||||
const result = await tool.execute("call-1", {});
|
||||
expect(result.details).toMatchObject({
|
||||
status: "error",
|
||||
error: "No session context",
|
||||
});
|
||||
expect(onYield).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
it("invokes onYield callback with default message", async () => {
|
||||
const onYield = vi.fn();
|
||||
const tool = createSessionsYieldTool({ sessionId: "test-session", onYield });
|
||||
const result = await tool.execute("call-1", {});
|
||||
expect(result.details).toMatchObject({ status: "yielded", message: "Turn yielded." });
|
||||
expect(onYield).toHaveBeenCalledOnce();
|
||||
expect(onYield).toHaveBeenCalledWith("Turn yielded.");
|
||||
});
|
||||
|
||||
it("passes the custom message through the yield callback", async () => {
|
||||
const onYield = vi.fn();
|
||||
const tool = createSessionsYieldTool({ sessionId: "test-session", onYield });
|
||||
const result = await tool.execute("call-1", { message: "Waiting for fact-checker" });
|
||||
expect(result.details).toMatchObject({
|
||||
status: "yielded",
|
||||
message: "Waiting for fact-checker",
|
||||
});
|
||||
expect(onYield).toHaveBeenCalledOnce();
|
||||
expect(onYield).toHaveBeenCalledWith("Waiting for fact-checker");
|
||||
});
|
||||
|
||||
it("returns error without onYield callback", async () => {
|
||||
const tool = createSessionsYieldTool({ sessionId: "test-session" });
|
||||
const result = await tool.execute("call-1", {});
|
||||
expect(result.details).toMatchObject({
|
||||
status: "error",
|
||||
error: "Yield not supported in this context",
|
||||
});
|
||||
});
|
||||
});
|
||||
32
src/agents/tools/sessions-yield-tool.ts
Normal file
32
src/agents/tools/sessions-yield-tool.ts
Normal file
@@ -0,0 +1,32 @@
|
||||
import { Type } from "@sinclair/typebox";
|
||||
import type { AnyAgentTool } from "./common.js";
|
||||
import { jsonResult, readStringParam } from "./common.js";
|
||||
|
||||
const SessionsYieldToolSchema = Type.Object({
|
||||
message: Type.Optional(Type.String()),
|
||||
});
|
||||
|
||||
export function createSessionsYieldTool(opts?: {
|
||||
sessionId?: string;
|
||||
onYield?: (message: string) => Promise<void> | void;
|
||||
}): AnyAgentTool {
|
||||
return {
|
||||
label: "Yield",
|
||||
name: "sessions_yield",
|
||||
description:
|
||||
"End your current turn. Use after spawning subagents to receive their results as the next message.",
|
||||
parameters: SessionsYieldToolSchema,
|
||||
execute: async (_toolCallId, args) => {
|
||||
const params = args as Record<string, unknown>;
|
||||
const message = readStringParam(params, "message") || "Turn yielded.";
|
||||
if (!opts?.sessionId) {
|
||||
return jsonResult({ status: "error", error: "No session context" });
|
||||
}
|
||||
if (!opts?.onYield) {
|
||||
return jsonResult({ status: "error", error: "Yield not supported in this context" });
|
||||
}
|
||||
await opts.onYield(message);
|
||||
return jsonResult({ status: "yielded", message });
|
||||
},
|
||||
};
|
||||
}
|
||||
@@ -69,7 +69,12 @@ export type TelegramBotOptions = {
|
||||
|
||||
export { getTelegramSequentialKey };
|
||||
|
||||
function readRequestUrl(input: RequestInfo | URL): string | null {
|
||||
type TelegramFetchInput = Parameters<NonNullable<ApiClientOptions["fetch"]>>[0];
|
||||
type TelegramFetchInit = Parameters<NonNullable<ApiClientOptions["fetch"]>>[1];
|
||||
type GlobalFetchInput = Parameters<typeof globalThis.fetch>[0];
|
||||
type GlobalFetchInit = Parameters<typeof globalThis.fetch>[1];
|
||||
|
||||
function readRequestUrl(input: TelegramFetchInput): string | null {
|
||||
if (typeof input === "string") {
|
||||
return input;
|
||||
}
|
||||
@@ -83,7 +88,7 @@ function readRequestUrl(input: RequestInfo | URL): string | null {
|
||||
return null;
|
||||
}
|
||||
|
||||
function extractTelegramApiMethod(input: RequestInfo | URL): string | null {
|
||||
function extractTelegramApiMethod(input: TelegramFetchInput): string | null {
|
||||
const url = readRequestUrl(input);
|
||||
if (!url) {
|
||||
return null;
|
||||
@@ -150,7 +155,7 @@ export function createTelegramBot(opts: TelegramBotOptions) {
|
||||
// Use manual event forwarding instead of AbortSignal.any() to avoid the cross-realm
|
||||
// AbortSignal issue in Node.js (grammY's signal may come from a different module context,
|
||||
// causing "signals[0] must be an instance of AbortSignal" errors).
|
||||
finalFetch = ((input: RequestInfo | URL, init?: RequestInit) => {
|
||||
finalFetch = ((input: TelegramFetchInput, init?: TelegramFetchInit) => {
|
||||
const controller = new AbortController();
|
||||
const abortWith = (signal: AbortSignal) => controller.abort(signal.reason);
|
||||
const onShutdown = () => abortWith(shutdownSignal);
|
||||
@@ -162,13 +167,16 @@ export function createTelegramBot(opts: TelegramBotOptions) {
|
||||
}
|
||||
if (init?.signal) {
|
||||
if (init.signal.aborted) {
|
||||
abortWith(init.signal);
|
||||
abortWith(init.signal as unknown as AbortSignal);
|
||||
} else {
|
||||
onRequestAbort = () => abortWith(init.signal as AbortSignal);
|
||||
init.signal.addEventListener("abort", onRequestAbort, { once: true });
|
||||
init.signal.addEventListener("abort", onRequestAbort);
|
||||
}
|
||||
}
|
||||
return callFetch(input, { ...init, signal: controller.signal }).finally(() => {
|
||||
return callFetch(input as GlobalFetchInput, {
|
||||
...(init as GlobalFetchInit),
|
||||
signal: controller.signal,
|
||||
}).finally(() => {
|
||||
shutdownSignal.removeEventListener("abort", onShutdown);
|
||||
if (init?.signal && onRequestAbort) {
|
||||
init.signal.removeEventListener("abort", onRequestAbort);
|
||||
@@ -178,7 +186,7 @@ export function createTelegramBot(opts: TelegramBotOptions) {
|
||||
}
|
||||
if (finalFetch) {
|
||||
const baseFetch = finalFetch;
|
||||
finalFetch = ((input: RequestInfo | URL, init?: RequestInit) => {
|
||||
finalFetch = ((input: TelegramFetchInput, init?: TelegramFetchInit) => {
|
||||
return Promise.resolve(baseFetch(input, init)).catch((err: unknown) => {
|
||||
try {
|
||||
tagTelegramNetworkError(err, {
|
||||
|
||||
@@ -12,12 +12,14 @@ vi.mock("@mariozechner/pi-ai", async (importOriginal) => {
|
||||
return {
|
||||
...original,
|
||||
completeSimple: vi.fn(),
|
||||
// Some auth helpers import oauth provider metadata at module load time.
|
||||
getOAuthProviders: () => [],
|
||||
getOAuthApiKey: vi.fn(async () => null),
|
||||
};
|
||||
});
|
||||
|
||||
vi.mock("@mariozechner/pi-ai/oauth", () => ({
|
||||
getOAuthProviders: () => [],
|
||||
getOAuthApiKey: vi.fn(async () => null),
|
||||
}));
|
||||
|
||||
vi.mock("../agents/pi-embedded-runner/model.js", () => ({
|
||||
resolveModel: vi.fn((provider: string, modelId: string) => ({
|
||||
model: {
|
||||
|
||||
Reference in New Issue
Block a user