Files
openclaw/src/agents/pi-embedded-runner/model.ts

104 lines
3.7 KiB
TypeScript
Raw Normal View History

import { join } from "node:path";
2026-01-14 01:08:15 +00:00
import type { Api, Model } from "@mariozechner/pi-ai";
import { AuthStorage, ModelRegistry } from "@mariozechner/pi-coding-agent";
2026-01-14 01:08:15 +00:00
2026-01-30 03:15:10 +01:00
import type { OpenClawConfig } from "../../config/config.js";
import type { ModelDefinitionConfig } from "../../config/types.js";
2026-01-30 03:15:10 +01:00
import { resolveOpenClawAgentDir } from "../agent-paths.js";
2026-01-14 01:08:15 +00:00
import { DEFAULT_CONTEXT_TOKENS } from "../defaults.js";
import { normalizeModelCompat } from "../model-compat.js";
import { normalizeProviderId } from "../model-selection.js";
/** Provider-level fields stamped onto each model flattened out of an inline provider config. */
type InlineProviderFields = {
  provider: string;
  baseUrl?: string;
};

/** A model definition from config, tagged with the provider id it was declared under. */
type InlineModelEntry = ModelDefinitionConfig & InlineProviderFields;

/**
 * One provider entry as read by this module from `models.providers`.
 * `api` and `baseUrl` act as provider-wide values for the listed models.
 */
type InlineProviderConfig = {
  models?: ModelDefinitionConfig[];
  api?: ModelDefinitionConfig["api"];
  baseUrl?: string;
};
/**
 * Flatten inline provider configs into a single list of model entries.
 *
 * - Provider ids are trimmed; providers whose id is empty after trimming are skipped.
 * - Every emitted entry carries the (trimmed) provider id and the provider's `baseUrl`,
 *   which is written after the model's own fields and therefore takes precedence.
 * - A model's own `api` wins; otherwise the provider-level `api` is used.
 *
 * Entries are returned in provider insertion order, then model order.
 */
export function buildInlineProviderModels(
  providers: Record<string, InlineProviderConfig>,
): InlineModelEntry[] {
  const flattened: InlineModelEntry[] = [];
  for (const [rawId, providerEntry] of Object.entries(providers)) {
    const providerName = rawId.trim();
    if (providerName === "") {
      continue;
    }
    const declaredModels = providerEntry?.models ?? [];
    for (const declared of declaredModels) {
      flattened.push({
        ...declared,
        provider: providerName,
        baseUrl: providerEntry?.baseUrl,
        api: declared.api ?? providerEntry?.api,
      });
    }
  }
  return flattened;
}
2026-01-14 01:08:15 +00:00
2026-01-30 03:15:10 +01:00
/**
 * Render the configured model aliases as markdown bullet lines.
 *
 * Reads `cfg.agents.defaults.models`, a map keyed by model reference. Keys and
 * aliases are trimmed, and entries missing either one are dropped. The result
 * is sorted by alias via `localeCompare`, with each line shaped `- <alias>: <model>`.
 */
export function buildModelAliasLines(cfg?: OpenClawConfig) {
  const configured = cfg?.agents?.defaults?.models ?? {};
  const pairs = Object.entries(configured)
    .map(([rawKey, rawEntry]) => ({
      model: String(rawKey ?? "").trim(),
      alias: String((rawEntry as { alias?: string } | undefined)?.alias ?? "").trim(),
    }))
    .filter(({ model, alias }) => model !== "" && alias !== "");
  pairs.sort((left, right) => left.alias.localeCompare(right.alias));
  return pairs.map(({ alias, model }) => `- ${alias}: ${model}`);
}
/**
 * Look up a provider's config entry by id.
 *
 * An exact own-property key match wins. Otherwise ids are compared through
 * `normalizeProviderId`, the same way inline models are matched in `resolveModel`.
 * The own-property check keeps ids such as "constructor" or "toString" from
 * resolving to `Object.prototype` members.
 */
function findProviderConfig<T>(providers: Record<string, T>, provider: string): T | undefined {
  if (Object.prototype.hasOwnProperty.call(providers, provider)) {
    return providers[provider];
  }
  const normalizedProvider = normalizeProviderId(provider);
  for (const [key, value] of Object.entries(providers)) {
    if (normalizeProviderId(key) === normalizedProvider) {
      return value;
    }
  }
  return undefined;
}

/**
 * Resolve a model definition for `provider`/`modelId`.
 *
 * Resolution order:
 * 1. The `ModelRegistry` backed by `<agentDir>/models.json` (auth from `<agentDir>/auth.json`).
 * 2. Inline models declared under `cfg.models.providers`, matched by normalized provider id
 *    and exact model id.
 * 3. A synthesized fallback model when the provider has a config entry, or when the
 *    model id starts with `mock-`. It uses the provider's `api` (default "openai-responses")
 *    and `baseUrl`. It takes its context/max token limits from the provider's first listed
 *    model, else `DEFAULT_CONTEXT_TOKENS`.
 *
 * @param provider - Provider id as requested by the caller.
 * @param modelId - Model id within that provider.
 * @param agentDir - Agent directory; defaults to `resolveOpenClawAgentDir()`.
 * @param cfg - Optional OpenClaw config supplying inline provider/model definitions.
 * @returns The resolved model (normalized via `normalizeModelCompat`) or an `error` string,
 *   always accompanied by the `authStorage` and `modelRegistry` that were constructed.
 */
export function resolveModel(
  provider: string,
  modelId: string,
  agentDir?: string,
  cfg?: OpenClawConfig,
): {
  model?: Model<Api>;
  error?: string;
  authStorage: AuthStorage;
  modelRegistry: ModelRegistry;
} {
  const resolvedAgentDir = agentDir ?? resolveOpenClawAgentDir();
  const authStorage = new AuthStorage(join(resolvedAgentDir, "auth.json"));
  const modelRegistry = new ModelRegistry(authStorage, join(resolvedAgentDir, "models.json"));

  // 1. Registry lookup.
  const registryModel = modelRegistry.find(provider, modelId) as Model<Api> | null;
  if (registryModel) {
    return { model: normalizeModelCompat(registryModel), authStorage, modelRegistry };
  }

  // 2. Inline model definitions from config, matched on normalized provider id.
  const providers = cfg?.models?.providers ?? {};
  const normalizedProvider = normalizeProviderId(provider);
  const inlineMatch = buildInlineProviderModels(providers).find(
    (entry) => normalizeProviderId(entry.provider) === normalizedProvider && entry.id === modelId,
  );
  if (inlineMatch) {
    return {
      model: normalizeModelCompat(inlineMatch as Model<Api>),
      authStorage,
      modelRegistry,
    };
  }

  // 3. Synthesized fallback. The provider config is looked up with the same
  //    normalization as step 2, and never via inherited prototype keys.
  const providerCfg = findProviderConfig(providers, provider);
  if (providerCfg || modelId.startsWith("mock-")) {
    const fallbackModel: Model<Api> = normalizeModelCompat({
      id: modelId,
      name: modelId,
      api: providerCfg?.api ?? "openai-responses",
      provider,
      baseUrl: providerCfg?.baseUrl,
      reasoning: false,
      input: ["text"],
      cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
      // The requested model is not listed, so borrow limits from the provider's first model.
      contextWindow: providerCfg?.models?.[0]?.contextWindow ?? DEFAULT_CONTEXT_TOKENS,
      maxTokens: providerCfg?.models?.[0]?.maxTokens ?? DEFAULT_CONTEXT_TOKENS,
    } as Model<Api>);
    return { model: fallbackModel, authStorage, modelRegistry };
  }

  return {
    error: `Unknown model: ${provider}/${modelId}`,
    authStorage,
    modelRegistry,
  };
}