Merge branch 'main' into v2

This commit is contained in:
fullex 2025-11-24 16:22:32 +08:00
commit 0045bf6c9c
142 changed files with 11167 additions and 2536 deletions

View File

@ -77,7 +77,7 @@ jobs:
with: with:
token: ${{ secrets.GITHUB_TOKEN }} # Use the built-in GITHUB_TOKEN for bot actions token: ${{ secrets.GITHUB_TOKEN }} # Use the built-in GITHUB_TOKEN for bot actions
commit-message: "feat(bot): Weekly automated script run" commit-message: "feat(bot): Weekly automated script run"
title: "🤖 Weekly Automated Update: ${{ env.CURRENT_DATE }}" title: "🤖 Weekly Auto I18N Sync: ${{ env.CURRENT_DATE }}"
body: | body: |
This PR includes changes generated by the weekly auto i18n. This PR includes changes generated by the weekly auto i18n.
Review the changes before merging. Review the changes before merging.

View File

@ -1,152 +0,0 @@
diff --git a/dist/index.js b/dist/index.js
index c2ef089c42e13a8ee4a833899a415564130e5d79..75efa7baafb0f019fb44dd50dec1641eee8879e7 100644
--- a/dist/index.js
+++ b/dist/index.js
@@ -471,7 +471,7 @@ function convertToGoogleGenerativeAIMessages(prompt, options) {
// src/get-model-path.ts
function getModelPath(modelId) {
- return modelId.includes("/") ? modelId : `models/${modelId}`;
+ return modelId.includes("models/") ? modelId : `models/${modelId}`;
}
// src/google-generative-ai-options.ts
diff --git a/dist/index.mjs b/dist/index.mjs
index d75c0cc13c41192408c1f3f2d29d76a7bffa6268..ada730b8cb97d9b7d4cb32883a1d1ff416404d9b 100644
--- a/dist/index.mjs
+++ b/dist/index.mjs
@@ -477,7 +477,7 @@ function convertToGoogleGenerativeAIMessages(prompt, options) {
// src/get-model-path.ts
function getModelPath(modelId) {
- return modelId.includes("/") ? modelId : `models/${modelId}`;
+ return modelId.includes("models/") ? modelId : `models/${modelId}`;
}
// src/google-generative-ai-options.ts
diff --git a/dist/internal/index.js b/dist/internal/index.js
index 277cac8dc734bea2fb4f3e9a225986b402b24f48..bb704cd79e602eb8b0cee1889e42497d59ccdb7a 100644
--- a/dist/internal/index.js
+++ b/dist/internal/index.js
@@ -432,7 +432,15 @@ function prepareTools({
var _a;
tools = (tools == null ? void 0 : tools.length) ? tools : void 0;
const toolWarnings = [];
- const isGemini2 = modelId.includes("gemini-2");
+ // These changes could be safely removed when @ai-sdk/google v3 released.
+ const isLatest = (
+ [
+ 'gemini-flash-latest',
+ 'gemini-flash-lite-latest',
+ 'gemini-pro-latest',
+ ]
+ ).some(id => id === modelId);
+ const isGemini2OrNewer = modelId.includes("gemini-2") || modelId.includes("gemini-3") || isLatest;
const supportsDynamicRetrieval = modelId.includes("gemini-1.5-flash") && !modelId.includes("-8b");
const supportsFileSearch = modelId.includes("gemini-2.5");
if (tools == null) {
@@ -458,7 +466,7 @@ function prepareTools({
providerDefinedTools.forEach((tool) => {
switch (tool.id) {
case "google.google_search":
- if (isGemini2) {
+ if (isGemini2OrNewer) {
googleTools2.push({ googleSearch: {} });
} else if (supportsDynamicRetrieval) {
googleTools2.push({
@@ -474,7 +482,7 @@ function prepareTools({
}
break;
case "google.url_context":
- if (isGemini2) {
+ if (isGemini2OrNewer) {
googleTools2.push({ urlContext: {} });
} else {
toolWarnings.push({
@@ -485,7 +493,7 @@ function prepareTools({
}
break;
case "google.code_execution":
- if (isGemini2) {
+ if (isGemini2OrNewer) {
googleTools2.push({ codeExecution: {} });
} else {
toolWarnings.push({
@@ -507,7 +515,7 @@ function prepareTools({
}
break;
case "google.vertex_rag_store":
- if (isGemini2) {
+ if (isGemini2OrNewer) {
googleTools2.push({
retrieval: {
vertex_rag_store: {
diff --git a/dist/internal/index.mjs b/dist/internal/index.mjs
index 03b7cc591be9b58bcc2e775a96740d9f98862a10..347d2c12e1cee79f0f8bb258f3844fb0522a6485 100644
--- a/dist/internal/index.mjs
+++ b/dist/internal/index.mjs
@@ -424,7 +424,15 @@ function prepareTools({
var _a;
tools = (tools == null ? void 0 : tools.length) ? tools : void 0;
const toolWarnings = [];
- const isGemini2 = modelId.includes("gemini-2");
+ // These changes could be safely removed when @ai-sdk/google v3 released.
+ const isLatest = (
+ [
+ 'gemini-flash-latest',
+ 'gemini-flash-lite-latest',
+ 'gemini-pro-latest',
+ ]
+ ).some(id => id === modelId);
+ const isGemini2OrNewer = modelId.includes("gemini-2") || modelId.includes("gemini-3") || isLatest;
const supportsDynamicRetrieval = modelId.includes("gemini-1.5-flash") && !modelId.includes("-8b");
const supportsFileSearch = modelId.includes("gemini-2.5");
if (tools == null) {
@@ -450,7 +458,7 @@ function prepareTools({
providerDefinedTools.forEach((tool) => {
switch (tool.id) {
case "google.google_search":
- if (isGemini2) {
+ if (isGemini2OrNewer) {
googleTools2.push({ googleSearch: {} });
} else if (supportsDynamicRetrieval) {
googleTools2.push({
@@ -466,7 +474,7 @@ function prepareTools({
}
break;
case "google.url_context":
- if (isGemini2) {
+ if (isGemini2OrNewer) {
googleTools2.push({ urlContext: {} });
} else {
toolWarnings.push({
@@ -477,7 +485,7 @@ function prepareTools({
}
break;
case "google.code_execution":
- if (isGemini2) {
+ if (isGemini2OrNewer) {
googleTools2.push({ codeExecution: {} });
} else {
toolWarnings.push({
@@ -499,7 +507,7 @@ function prepareTools({
}
break;
case "google.vertex_rag_store":
- if (isGemini2) {
+ if (isGemini2OrNewer) {
googleTools2.push({
retrieval: {
vertex_rag_store: {
@@ -1434,9 +1442,7 @@ var googleTools = {
vertexRagStore
};
export {
- GoogleGenerativeAILanguageModel,
getGroundingMetadataSchema,
- getUrlContextMetadataSchema,
- googleTools
+ getUrlContextMetadataSchema, GoogleGenerativeAILanguageModel, googleTools
};
//# sourceMappingURL=index.mjs.map
\ No newline at end of file

View File

@ -0,0 +1,26 @@
diff --git a/dist/index.js b/dist/index.js
index dc7b74ba55337c491cdf1ab3e39ca68cc4187884..ace8c90591288e42c2957e93c9bf7984f1b22444 100644
--- a/dist/index.js
+++ b/dist/index.js
@@ -472,7 +472,7 @@ function convertToGoogleGenerativeAIMessages(prompt, options) {
// src/get-model-path.ts
function getModelPath(modelId) {
- return modelId.includes("/") ? modelId : `models/${modelId}`;
+ return modelId.includes("models/") ? modelId : `models/${modelId}`;
}
// src/google-generative-ai-options.ts
diff --git a/dist/index.mjs b/dist/index.mjs
index 8390439c38cb7eaeb52080862cd6f4c58509e67c..a7647f2e11700dff7e1c8d4ae8f99d3637010733 100644
--- a/dist/index.mjs
+++ b/dist/index.mjs
@@ -478,7 +478,7 @@ function convertToGoogleGenerativeAIMessages(prompt, options) {
// src/get-model-path.ts
function getModelPath(modelId) {
- return modelId.includes("/") ? modelId : `models/${modelId}`;
+ return modelId.includes("models/") ? modelId : `models/${modelId}`;
}
// src/google-generative-ai-options.ts

View File

@ -1,131 +0,0 @@
diff --git a/dist/index.mjs b/dist/index.mjs
index b3f018730a93639aad7c203f15fb1aeb766c73f4..ade2a43d66e9184799d072153df61ef7be4ea110 100644
--- a/dist/index.mjs
+++ b/dist/index.mjs
@@ -296,7 +296,14 @@ var HuggingFaceResponsesLanguageModel = class {
metadata: huggingfaceOptions == null ? void 0 : huggingfaceOptions.metadata,
instructions: huggingfaceOptions == null ? void 0 : huggingfaceOptions.instructions,
...preparedTools && { tools: preparedTools },
- ...preparedToolChoice && { tool_choice: preparedToolChoice }
+ ...preparedToolChoice && { tool_choice: preparedToolChoice },
+ ...(huggingfaceOptions?.reasoningEffort != null && {
+ reasoning: {
+ ...(huggingfaceOptions?.reasoningEffort != null && {
+ effort: huggingfaceOptions.reasoningEffort,
+ }),
+ },
+ }),
};
return { args: baseArgs, warnings };
}
@@ -365,6 +372,20 @@ var HuggingFaceResponsesLanguageModel = class {
}
break;
}
+ case 'reasoning': {
+ for (const contentPart of part.content) {
+ content.push({
+ type: 'reasoning',
+ text: contentPart.text,
+ providerMetadata: {
+ huggingface: {
+ itemId: part.id,
+ },
+ },
+ });
+ }
+ break;
+ }
case "mcp_call": {
content.push({
type: "tool-call",
@@ -519,6 +540,11 @@ var HuggingFaceResponsesLanguageModel = class {
id: value.item.call_id,
toolName: value.item.name
});
+ } else if (value.item.type === 'reasoning') {
+ controller.enqueue({
+ type: 'reasoning-start',
+ id: value.item.id,
+ });
}
return;
}
@@ -570,6 +596,22 @@ var HuggingFaceResponsesLanguageModel = class {
});
return;
}
+ if (isReasoningDeltaChunk(value)) {
+ controller.enqueue({
+ type: 'reasoning-delta',
+ id: value.item_id,
+ delta: value.delta,
+ });
+ return;
+ }
+
+ if (isReasoningEndChunk(value)) {
+ controller.enqueue({
+ type: 'reasoning-end',
+ id: value.item_id,
+ });
+ return;
+ }
},
flush(controller) {
controller.enqueue({
@@ -593,7 +635,8 @@ var HuggingFaceResponsesLanguageModel = class {
var huggingfaceResponsesProviderOptionsSchema = z2.object({
metadata: z2.record(z2.string(), z2.string()).optional(),
instructions: z2.string().optional(),
- strictJsonSchema: z2.boolean().optional()
+ strictJsonSchema: z2.boolean().optional(),
+ reasoningEffort: z2.string().optional(),
});
var huggingfaceResponsesResponseSchema = z2.object({
id: z2.string(),
@@ -727,12 +770,31 @@ var responseCreatedChunkSchema = z2.object({
model: z2.string()
})
});
+var reasoningTextDeltaChunkSchema = z2.object({
+ type: z2.literal('response.reasoning_text.delta'),
+ item_id: z2.string(),
+ output_index: z2.number(),
+ content_index: z2.number(),
+ delta: z2.string(),
+ sequence_number: z2.number(),
+});
+
+var reasoningTextEndChunkSchema = z2.object({
+ type: z2.literal('response.reasoning_text.done'),
+ item_id: z2.string(),
+ output_index: z2.number(),
+ content_index: z2.number(),
+ text: z2.string(),
+ sequence_number: z2.number(),
+});
var huggingfaceResponsesChunkSchema = z2.union([
responseOutputItemAddedSchema,
responseOutputItemDoneSchema,
textDeltaChunkSchema,
responseCompletedChunkSchema,
responseCreatedChunkSchema,
+ reasoningTextDeltaChunkSchema,
+ reasoningTextEndChunkSchema,
z2.object({ type: z2.string() }).loose()
// fallback for unknown chunks
]);
@@ -751,6 +813,12 @@ function isResponseCompletedChunk(chunk) {
function isResponseCreatedChunk(chunk) {
return chunk.type === "response.created";
}
+function isReasoningDeltaChunk(chunk) {
+ return chunk.type === 'response.reasoning_text.delta';
+}
+function isReasoningEndChunk(chunk) {
+ return chunk.type === 'response.reasoning_text.done';
+}
// src/huggingface-provider.ts
function createHuggingFace(options = {}) {

View File

@ -0,0 +1,140 @@
diff --git a/dist/index.js b/dist/index.js
index 73045a7d38faafdc7f7d2cd79d7ff0e2b031056b..8d948c9ac4ea4b474db9ef3c5491961e7fcf9a07 100644
--- a/dist/index.js
+++ b/dist/index.js
@@ -421,6 +421,17 @@ var OpenAICompatibleChatLanguageModel = class {
text: reasoning
});
}
+ if (choice.message.images) {
+ for (const image of choice.message.images) {
+ const match1 = image.image_url.url.match(/^data:([^;]+)/)
+ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/);
+ content.push({
+ type: 'file',
+ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg',
+ data: match2 ? match2[1] : image.image_url.url,
+ });
+ }
+ }
if (choice.message.tool_calls != null) {
for (const toolCall of choice.message.tool_calls) {
content.push({
@@ -598,6 +609,17 @@ var OpenAICompatibleChatLanguageModel = class {
delta: delta.content
});
}
+ if (delta.images) {
+ for (const image of delta.images) {
+ const match1 = image.image_url.url.match(/^data:([^;]+)/)
+ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/);
+ controller.enqueue({
+ type: 'file',
+ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg',
+ data: match2 ? match2[1] : image.image_url.url,
+ });
+ }
+ }
if (delta.tool_calls != null) {
for (const toolCallDelta of delta.tool_calls) {
const index = toolCallDelta.index;
@@ -765,6 +787,14 @@ var OpenAICompatibleChatResponseSchema = import_v43.z.object({
arguments: import_v43.z.string()
})
})
+ ).nullish(),
+ images: import_v43.z.array(
+ import_v43.z.object({
+ type: import_v43.z.literal('image_url'),
+ image_url: import_v43.z.object({
+ url: import_v43.z.string(),
+ })
+ })
).nullish()
}),
finish_reason: import_v43.z.string().nullish()
@@ -795,6 +825,14 @@ var createOpenAICompatibleChatChunkSchema = (errorSchema) => import_v43.z.union(
arguments: import_v43.z.string().nullish()
})
})
+ ).nullish(),
+ images: import_v43.z.array(
+ import_v43.z.object({
+ type: import_v43.z.literal('image_url'),
+ image_url: import_v43.z.object({
+ url: import_v43.z.string(),
+ })
+ })
).nullish()
}).nullish(),
finish_reason: import_v43.z.string().nullish()
diff --git a/dist/index.mjs b/dist/index.mjs
index 1c2b9560bbfbfe10cb01af080aeeed4ff59db29c..2c8ddc4fc9bfc5e7e06cfca105d197a08864c427 100644
--- a/dist/index.mjs
+++ b/dist/index.mjs
@@ -405,6 +405,17 @@ var OpenAICompatibleChatLanguageModel = class {
text: reasoning
});
}
+ if (choice.message.images) {
+ for (const image of choice.message.images) {
+ const match1 = image.image_url.url.match(/^data:([^;]+)/)
+ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/);
+ content.push({
+ type: 'file',
+ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg',
+ data: match2 ? match2[1] : image.image_url.url,
+ });
+ }
+ }
if (choice.message.tool_calls != null) {
for (const toolCall of choice.message.tool_calls) {
content.push({
@@ -582,6 +593,17 @@ var OpenAICompatibleChatLanguageModel = class {
delta: delta.content
});
}
+ if (delta.images) {
+ for (const image of delta.images) {
+ const match1 = image.image_url.url.match(/^data:([^;]+)/)
+ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/);
+ controller.enqueue({
+ type: 'file',
+ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg',
+ data: match2 ? match2[1] : image.image_url.url,
+ });
+ }
+ }
if (delta.tool_calls != null) {
for (const toolCallDelta of delta.tool_calls) {
const index = toolCallDelta.index;
@@ -749,6 +771,14 @@ var OpenAICompatibleChatResponseSchema = z3.object({
arguments: z3.string()
})
})
+ ).nullish(),
+ images: z3.array(
+ z3.object({
+ type: z3.literal('image_url'),
+ image_url: z3.object({
+ url: z3.string(),
+ })
+ })
).nullish()
}),
finish_reason: z3.string().nullish()
@@ -779,6 +809,14 @@ var createOpenAICompatibleChatChunkSchema = (errorSchema) => z3.union([
arguments: z3.string().nullish()
})
})
+ ).nullish(),
+ images: z3.array(
+ z3.object({
+ type: z3.literal('image_url'),
+ image_url: z3.object({
+ url: z3.string(),
+ })
+ })
).nullish()
}).nullish(),
finish_reason: z3.string().nullish()

View File

@ -1,5 +1,5 @@
diff --git a/dist/index.js b/dist/index.js diff --git a/dist/index.js b/dist/index.js
index 992c85ac6656e51c3471af741583533c5a7bf79f..83c05952a07aebb95fc6c62f9ddb8aa96b52ac0d 100644 index 7481f3b3511078068d87d03855b568b20bb86971..8ac5ec28d2f7ad1b3b0d3f8da945c75674e59637 100644
--- a/dist/index.js --- a/dist/index.js
+++ b/dist/index.js +++ b/dist/index.js
@@ -274,6 +274,7 @@ var openaiChatResponseSchema = (0, import_provider_utils3.lazyValidator)( @@ -274,6 +274,7 @@ var openaiChatResponseSchema = (0, import_provider_utils3.lazyValidator)(
@ -18,7 +18,7 @@ index 992c85ac6656e51c3471af741583533c5a7bf79f..83c05952a07aebb95fc6c62f9ddb8aa9
tool_calls: import_v42.z.array( tool_calls: import_v42.z.array(
import_v42.z.object({ import_v42.z.object({
index: import_v42.z.number(), index: import_v42.z.number(),
@@ -785,6 +787,13 @@ var OpenAIChatLanguageModel = class { @@ -795,6 +797,13 @@ var OpenAIChatLanguageModel = class {
if (text != null && text.length > 0) { if (text != null && text.length > 0) {
content.push({ type: "text", text }); content.push({ type: "text", text });
} }
@ -32,7 +32,7 @@ index 992c85ac6656e51c3471af741583533c5a7bf79f..83c05952a07aebb95fc6c62f9ddb8aa9
for (const toolCall of (_a = choice.message.tool_calls) != null ? _a : []) { for (const toolCall of (_a = choice.message.tool_calls) != null ? _a : []) {
content.push({ content.push({
type: "tool-call", type: "tool-call",
@@ -866,6 +875,7 @@ var OpenAIChatLanguageModel = class { @@ -876,6 +885,7 @@ var OpenAIChatLanguageModel = class {
}; };
let metadataExtracted = false; let metadataExtracted = false;
let isActiveText = false; let isActiveText = false;
@ -40,7 +40,7 @@ index 992c85ac6656e51c3471af741583533c5a7bf79f..83c05952a07aebb95fc6c62f9ddb8aa9
const providerMetadata = { openai: {} }; const providerMetadata = { openai: {} };
return { return {
stream: response.pipeThrough( stream: response.pipeThrough(
@@ -923,6 +933,21 @@ var OpenAIChatLanguageModel = class { @@ -933,6 +943,21 @@ var OpenAIChatLanguageModel = class {
return; return;
} }
const delta = choice.delta; const delta = choice.delta;
@ -62,7 +62,7 @@ index 992c85ac6656e51c3471af741583533c5a7bf79f..83c05952a07aebb95fc6c62f9ddb8aa9
if (delta.content != null) { if (delta.content != null) {
if (!isActiveText) { if (!isActiveText) {
controller.enqueue({ type: "text-start", id: "0" }); controller.enqueue({ type: "text-start", id: "0" });
@@ -1035,6 +1060,9 @@ var OpenAIChatLanguageModel = class { @@ -1045,6 +1070,9 @@ var OpenAIChatLanguageModel = class {
} }
}, },
flush(controller) { flush(controller) {

View File

@ -14,7 +14,7 @@
} }
}, },
"enabled": true, "enabled": true,
"includes": ["**/*.json", "!*.json", "!**/package.json"] "includes": ["**/*.json", "!*.json", "!**/package.json", "!coverage/**"]
}, },
"css": { "css": {
"formatter": { "formatter": {
@ -23,7 +23,7 @@
}, },
"files": { "files": {
"ignoreUnknown": false, "ignoreUnknown": false,
"includes": ["**", "!**/.claude/**"], "includes": ["**", "!**/.claude/**", "!**/.vscode/**"],
"maxSize": 2097152 "maxSize": 2097152
}, },
"formatter": { "formatter": {

View File

@ -135,58 +135,66 @@ artifactBuildCompleted: scripts/artifact-build-completed.js
releaseInfo: releaseInfo:
releaseNotes: | releaseNotes: |
<!--LANG:en--> <!--LANG:en-->
What's New in v1.7.0-rc.1 What's New in v1.7.0-rc.2
🎉 MAJOR NEW FEATURE: AI Agents
- Create and manage custom AI agents with specialized tools and permissions
- Dedicated agent sessions with persistent SQLite storage, separate from regular chats
- Real-time tool approval system - review and approve agent actions dynamically
- MCP (Model Context Protocol) integration for connecting external tools
- Slash commands support for quick agent interactions
- OpenAI-compatible REST API for agent access
✨ New Features: ✨ New Features:
- AI Providers: Added support for Hugging Face, Mistral, Perplexity, and SophNet - AI Models: Added support for Gemini 3, Gemini 3 Pro with image preview, and GPT-5.1
- Knowledge Base: OpenMinerU document preprocessor, full-text search in notes, enhanced tool selection - Import: ChatGPT conversation import feature
- Image & OCR: Intel OVMS painting provider and Intel OpenVINO (NPU) OCR support - Agent: Git Bash detection and requirement check for Windows agents
- MCP Management: Redesigned interface with dual-column layout for easier management - Search: Native language emoji search with CLDR data format
- Languages: Added German language support - Provider: Endpoint type support for cherryin provider
- Debug: Local crash mini dump file for better diagnostics
⚡ Improvements:
- Upgraded to Electron 38.7.0
- Enhanced system shutdown handling and automatic update checks
- Improved proxy bypass rules
🐛 Important Bug Fixes: 🐛 Important Bug Fixes:
- Fixed streaming response issues across multiple AI providers - Error Handling: Improved error display in AiSdkToChunkAdapter
- Fixed session list scrolling problems - Database: Optimized DatabaseManager and fixed libsql crash issues
- Fixed knowledge base deletion errors - Memory: Fixed EventEmitter memory leak in useApiServer hook
- Messages: Fixed adjacent user messages appearing when assistant message contains error only
- Tools: Fixed missing execution state for approved tool permissions
- File Processing: Fixed "no such file" error for non-English filenames in open-mineru
- PDF: Fixed mineru PDF validation and 403 errors
- Images: Fixed base64 image save issues
- Search: Fixed URL context and web search capability
- Models: Added verbosity parameter support for GPT-5 models
- UI: Improved todo tool status icon visibility and colors
- Providers: Fixed api-host for vercel ai-gateway and gitcode update config
⚡ Improvements:
- SDK: Updated Google and OpenAI SDKs with new features
- UI: Simplified knowledge base creation modal and agent creation form
- Tools: Replaced renderToolContent function with ToolContent component
- Architecture: Namespace tool call IDs with session ID to prevent conflicts
- Config: AI SDK configuration refactoring
<!--LANG:zh-CN--> <!--LANG:zh-CN-->
v1.7.0-rc.1 新特性 v1.7.0-rc.2 新特性
🎉 重大更新AI Agent 智能体系统
- 创建和管理专属 AI Agent配置专用工具和权限
- 独立的 Agent 会话,使用 SQLite 持久化存储,与普通聊天分离
- 实时工具审批系统 - 动态审查和批准 Agent 操作
- MCP模型上下文协议集成连接外部工具
- 支持斜杠命令快速交互
- 兼容 OpenAI 的 REST API 访问
✨ 新功能: ✨ 新功能:
- AI 提供商:新增 Hugging Face、Mistral、Perplexity 和 SophNet 支持 - AI 模型:新增 Gemini 3、Gemini 3 Pro 图像预览支持,以及 GPT-5.1
- 知识库OpenMinerU 文档预处理器、笔记全文搜索、增强的工具选择 - 导入ChatGPT 对话导入功能
- 图像与 OCRIntel OVMS 绘图提供商和 Intel OpenVINO (NPU) OCR 支持 - AgentWindows Agent 的 Git Bash 检测和要求检查
- MCP 管理:重构管理界面,采用双列布局,更加方便管理 - 搜索:支持本地语言 emoji 搜索CLDR 数据格式)
- 语言:新增德语支持 - 提供商cherryin provider 的端点类型支持
- 调试:启用本地崩溃 mini dump 文件,方便诊断
⚡ 改进:
- 升级到 Electron 38.7.0
- 增强的系统关机处理和自动更新检查
- 改进的代理绕过规则
🐛 重要修复: 🐛 重要修复:
- 修复多个 AI 提供商的流式响应问题 - 错误处理:改进 AiSdkToChunkAdapter 的错误显示
- 修复会话列表滚动问题 - 数据库:优化 DatabaseManager 并修复 libsql 崩溃问题
- 修复知识库删除错误 - 内存:修复 useApiServer hook 中的 EventEmitter 内存泄漏
- 消息:修复当助手消息仅包含错误时相邻用户消息出现的问题
- 工具:修复批准工具权限缺少执行状态的问题
- 文件处理:修复 open-mineru 处理非英文文件名时的"无此文件"错误
- PDF修复 mineru PDF 验证和 403 错误
- 图片:修复 base64 图片保存问题
- 搜索:修复 URL 上下文和网络搜索功能
- 模型:为 GPT-5 模型添加 verbosity 参数支持
- UI改进 todo 工具状态图标可见性和颜色
- 提供商:修复 vercel ai-gateway 和 gitcode 更新配置的 api-host
⚡ 改进:
- SDK更新 Google 和 OpenAI SDK新增功能和修复
- UI简化知识库创建模态框和 agent 创建表单
- 工具:用 ToolContent 组件替换 renderToolContent 函数,提升可读性
- 架构:用会话 ID 命名工具调用 ID 以防止冲突
- 配置AI SDK 配置重构
<!--LANG:END--> <!--LANG:END-->

View File

@ -113,16 +113,17 @@
"@agentic/exa": "^7.3.3", "@agentic/exa": "^7.3.3",
"@agentic/searxng": "^7.3.3", "@agentic/searxng": "^7.3.3",
"@agentic/tavily": "^7.3.3", "@agentic/tavily": "^7.3.3",
"@ai-sdk/amazon-bedrock": "^3.0.53", "@ai-sdk/amazon-bedrock": "^3.0.56",
"@ai-sdk/anthropic": "^2.0.44", "@ai-sdk/anthropic": "^2.0.45",
"@ai-sdk/cerebras": "^1.0.31", "@ai-sdk/cerebras": "^1.0.31",
"@ai-sdk/gateway": "^2.0.9", "@ai-sdk/gateway": "^2.0.13",
"@ai-sdk/google": "patch:@ai-sdk/google@npm%3A2.0.36#~/.yarn/patches/@ai-sdk-google-npm-2.0.36-6f3cc06026.patch", "@ai-sdk/google": "patch:@ai-sdk/google@npm%3A2.0.40#~/.yarn/patches/@ai-sdk-google-npm-2.0.40-47e0eeee83.patch",
"@ai-sdk/google-vertex": "^3.0.68", "@ai-sdk/google-vertex": "^3.0.72",
"@ai-sdk/huggingface": "patch:@ai-sdk/huggingface@npm%3A0.0.8#~/.yarn/patches/@ai-sdk-huggingface-npm-0.0.8-d4d0aaac93.patch", "@ai-sdk/huggingface": "^0.0.10",
"@ai-sdk/mistral": "^2.0.23", "@ai-sdk/mistral": "^2.0.24",
"@ai-sdk/openai": "patch:@ai-sdk/openai@npm%3A2.0.64#~/.yarn/patches/@ai-sdk-openai-npm-2.0.64-48f99f5bf3.patch", "@ai-sdk/openai": "patch:@ai-sdk/openai@npm%3A2.0.71#~/.yarn/patches/@ai-sdk-openai-npm-2.0.71-a88ef00525.patch",
"@ai-sdk/perplexity": "^2.0.17", "@ai-sdk/perplexity": "^2.0.20",
"@ai-sdk/test-server": "^0.0.1",
"@ant-design/v5-patch-for-react-19": "^1.0.3", "@ant-design/v5-patch-for-react-19": "^1.0.3",
"@anthropic-ai/sdk": "^0.41.0", "@anthropic-ai/sdk": "^0.41.0",
"@anthropic-ai/vertex-sdk": "patch:@anthropic-ai/vertex-sdk@npm%3A0.11.4#~/.yarn/patches/@anthropic-ai-vertex-sdk-npm-0.11.4-c19cb41edb.patch", "@anthropic-ai/vertex-sdk": "patch:@anthropic-ai/vertex-sdk@npm%3A0.11.4#~/.yarn/patches/@anthropic-ai-vertex-sdk-npm-0.11.4-c19cb41edb.patch",
@ -167,7 +168,7 @@
"@modelcontextprotocol/sdk": "^1.17.5", "@modelcontextprotocol/sdk": "^1.17.5",
"@mozilla/readability": "^0.6.0", "@mozilla/readability": "^0.6.0",
"@notionhq/client": "^2.2.15", "@notionhq/client": "^2.2.15",
"@openrouter/ai-sdk-provider": "^1.2.0", "@openrouter/ai-sdk-provider": "^1.2.5",
"@opentelemetry/api": "^1.9.0", "@opentelemetry/api": "^1.9.0",
"@opentelemetry/core": "2.0.0", "@opentelemetry/core": "2.0.0",
"@opentelemetry/exporter-trace-otlp-http": "^0.200.0", "@opentelemetry/exporter-trace-otlp-http": "^0.200.0",
@ -244,7 +245,7 @@
"@viz-js/lang-dot": "^1.0.5", "@viz-js/lang-dot": "^1.0.5",
"@viz-js/viz": "^3.14.0", "@viz-js/viz": "^3.14.0",
"@xyflow/react": "^12.4.4", "@xyflow/react": "^12.4.4",
"ai": "^5.0.90", "ai": "^5.0.98",
"antd": "patch:antd@npm%3A5.27.0#~/.yarn/patches/antd-npm-5.27.0-aa91c36546.patch", "antd": "patch:antd@npm%3A5.27.0#~/.yarn/patches/antd-npm-5.27.0-aa91c36546.patch",
"archiver": "^7.0.1", "archiver": "^7.0.1",
"async-mutex": "^0.5.0", "async-mutex": "^0.5.0",
@ -417,8 +418,11 @@
"@langchain/openai@npm:^0.3.16": "patch:@langchain/openai@npm%3A1.0.0#~/.yarn/patches/@langchain-openai-npm-1.0.0-474d0ad9d4.patch", "@langchain/openai@npm:^0.3.16": "patch:@langchain/openai@npm%3A1.0.0#~/.yarn/patches/@langchain-openai-npm-1.0.0-474d0ad9d4.patch",
"@langchain/openai@npm:>=0.2.0 <0.7.0": "patch:@langchain/openai@npm%3A1.0.0#~/.yarn/patches/@langchain-openai-npm-1.0.0-474d0ad9d4.patch", "@langchain/openai@npm:>=0.2.0 <0.7.0": "patch:@langchain/openai@npm%3A1.0.0#~/.yarn/patches/@langchain-openai-npm-1.0.0-474d0ad9d4.patch",
"@ai-sdk/openai@npm:2.0.64": "patch:@ai-sdk/openai@npm%3A2.0.64#~/.yarn/patches/@ai-sdk-openai-npm-2.0.64-48f99f5bf3.patch", "@ai-sdk/openai@npm:2.0.64": "patch:@ai-sdk/openai@npm%3A2.0.64#~/.yarn/patches/@ai-sdk-openai-npm-2.0.64-48f99f5bf3.patch",
"@ai-sdk/openai@npm:^2.0.42": "patch:@ai-sdk/openai@npm%3A2.0.64#~/.yarn/patches/@ai-sdk-openai-npm-2.0.64-48f99f5bf3.patch", "@ai-sdk/openai@npm:^2.0.42": "patch:@ai-sdk/openai@npm%3A2.0.71#~/.yarn/patches/@ai-sdk-openai-npm-2.0.71-a88ef00525.patch",
"@ai-sdk/google@npm:2.0.36": "patch:@ai-sdk/google@npm%3A2.0.36#~/.yarn/patches/@ai-sdk-google-npm-2.0.36-6f3cc06026.patch" "@ai-sdk/google@npm:2.0.40": "patch:@ai-sdk/google@npm%3A2.0.40#~/.yarn/patches/@ai-sdk-google-npm-2.0.40-47e0eeee83.patch",
"@ai-sdk/openai@npm:2.0.71": "patch:@ai-sdk/openai@npm%3A2.0.71#~/.yarn/patches/@ai-sdk-openai-npm-2.0.71-a88ef00525.patch",
"@ai-sdk/openai-compatible@npm:1.0.27": "patch:@ai-sdk/openai-compatible@npm%3A1.0.27#~/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.27-06f74278cf.patch",
"@ai-sdk/openai-compatible@npm:^1.0.19": "patch:@ai-sdk/openai-compatible@npm%3A1.0.27#~/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.27-06f74278cf.patch"
}, },
"packageManager": "yarn@4.9.1", "packageManager": "yarn@4.9.1",
"lint-staged": { "lint-staged": {

View File

@ -42,7 +42,7 @@
}, },
"dependencies": { "dependencies": {
"@ai-sdk/provider": "^2.0.0", "@ai-sdk/provider": "^2.0.0",
"@ai-sdk/provider-utils": "^3.0.12" "@ai-sdk/provider-utils": "^3.0.17"
}, },
"devDependencies": { "devDependencies": {
"tsdown": "^0.13.3", "tsdown": "^0.13.3",

View File

@ -39,13 +39,13 @@
"ai": "^5.0.26" "ai": "^5.0.26"
}, },
"dependencies": { "dependencies": {
"@ai-sdk/anthropic": "^2.0.43", "@ai-sdk/anthropic": "^2.0.45",
"@ai-sdk/azure": "^2.0.66", "@ai-sdk/azure": "^2.0.73",
"@ai-sdk/deepseek": "^1.0.27", "@ai-sdk/deepseek": "^1.0.29",
"@ai-sdk/openai-compatible": "^1.0.26", "@ai-sdk/openai-compatible": "patch:@ai-sdk/openai-compatible@npm%3A1.0.27#~/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.27-06f74278cf.patch",
"@ai-sdk/provider": "^2.0.0", "@ai-sdk/provider": "^2.0.0",
"@ai-sdk/provider-utils": "^3.0.16", "@ai-sdk/provider-utils": "^3.0.17",
"@ai-sdk/xai": "^2.0.31", "@ai-sdk/xai": "^2.0.34",
"zod": "^4.1.5" "zod": "^4.1.5"
}, },
"devDependencies": { "devDependencies": {

View File

@ -0,0 +1,180 @@
/**
* Mock Provider Instances
* Provides mock implementations for all supported AI providers
*/
import type { ImageModelV2, LanguageModelV2 } from '@ai-sdk/provider'
import { vi } from 'vitest'
/**
* Creates a mock language model with customizable behavior
*/
export function createMockLanguageModel(overrides?: Partial<LanguageModelV2>): LanguageModelV2 {
return {
specificationVersion: 'v1',
provider: 'mock-provider',
modelId: 'mock-model',
defaultObjectGenerationMode: 'tool',
doGenerate: vi.fn().mockResolvedValue({
text: 'Mock response text',
finishReason: 'stop',
usage: {
promptTokens: 10,
completionTokens: 20,
totalTokens: 30
},
rawCall: { rawPrompt: null, rawSettings: {} },
rawResponse: { headers: {} },
warnings: []
}),
doStream: vi.fn().mockReturnValue({
stream: (async function* () {
yield {
type: 'text-delta',
textDelta: 'Mock '
}
yield {
type: 'text-delta',
textDelta: 'streaming '
}
yield {
type: 'text-delta',
textDelta: 'response'
}
yield {
type: 'finish',
finishReason: 'stop',
usage: {
promptTokens: 10,
completionTokens: 15,
totalTokens: 25
}
}
})(),
rawCall: { rawPrompt: null, rawSettings: {} },
rawResponse: { headers: {} },
warnings: []
}),
...overrides
} as LanguageModelV2
}
/**
* Creates a mock image model with customizable behavior
*/
export function createMockImageModel(overrides?: Partial<ImageModelV2>): ImageModelV2 {
return {
specificationVersion: 'v2',
provider: 'mock-provider',
modelId: 'mock-image-model',
doGenerate: vi.fn().mockResolvedValue({
images: [
{
base64: 'mock-base64-image-data',
uint8Array: new Uint8Array([1, 2, 3, 4, 5]),
mimeType: 'image/png'
}
],
warnings: []
}),
...overrides
} as ImageModelV2
}
/**
* Mock provider configurations for testing
*/
export const mockProviderConfigs = {
openai: {
apiKey: 'sk-test-openai-key-123456789',
baseURL: 'https://api.openai.com/v1',
organization: 'test-org'
},
anthropic: {
apiKey: 'sk-ant-test-key-123456789',
baseURL: 'https://api.anthropic.com'
},
google: {
apiKey: 'test-google-api-key-123456789',
baseURL: 'https://generativelanguage.googleapis.com/v1'
},
xai: {
apiKey: 'xai-test-key-123456789',
baseURL: 'https://api.x.ai/v1'
},
azure: {
apiKey: 'test-azure-key-123456789',
resourceName: 'test-resource',
deployment: 'test-deployment'
},
deepseek: {
apiKey: 'sk-test-deepseek-key-123456789',
baseURL: 'https://api.deepseek.com/v1'
},
openrouter: {
apiKey: 'sk-or-test-key-123456789',
baseURL: 'https://openrouter.ai/api/v1'
},
huggingface: {
apiKey: 'hf_test_key_123456789',
baseURL: 'https://api-inference.huggingface.co'
},
'openai-compatible': {
apiKey: 'test-compatible-key-123456789',
baseURL: 'https://api.example.com/v1',
name: 'test-provider'
},
'openai-chat': {
apiKey: 'sk-test-chat-key-123456789',
baseURL: 'https://api.openai.com/v1'
}
} as const
/**
* Mock provider instances for testing
*/
export const mockProviderInstances = {
openai: {
name: 'openai-mock',
languageModel: createMockLanguageModel({ provider: 'openai', modelId: 'gpt-4' }),
imageModel: createMockImageModel({ provider: 'openai', modelId: 'dall-e-3' })
},
anthropic: {
name: 'anthropic-mock',
languageModel: createMockLanguageModel({ provider: 'anthropic', modelId: 'claude-3-5-sonnet-20241022' })
},
google: {
name: 'google-mock',
languageModel: createMockLanguageModel({ provider: 'google', modelId: 'gemini-2.0-flash-exp' }),
imageModel: createMockImageModel({ provider: 'google', modelId: 'imagen-3.0-generate-001' })
},
xai: {
name: 'xai-mock',
languageModel: createMockLanguageModel({ provider: 'xai', modelId: 'grok-2-latest' }),
imageModel: createMockImageModel({ provider: 'xai', modelId: 'grok-2-image-latest' })
},
deepseek: {
name: 'deepseek-mock',
languageModel: createMockLanguageModel({ provider: 'deepseek', modelId: 'deepseek-chat' })
}
}
export type ProviderId = keyof typeof mockProviderConfigs

View File

@ -0,0 +1,331 @@
/**
* Mock Responses
* Provides realistic mock responses for all provider types
*/
import { jsonSchema, type ModelMessage, type Tool } from 'ai'
/**
* Standard test messages for all scenarios
*/
export const testMessages = {
simple: [{ role: 'user' as const, content: 'Hello, how are you?' }],
conversation: [
{ role: 'user' as const, content: 'What is the capital of France?' },
{ role: 'assistant' as const, content: 'The capital of France is Paris.' },
{ role: 'user' as const, content: 'What is its population?' }
],
withSystem: [
{ role: 'system' as const, content: 'You are a helpful assistant that provides concise answers.' },
{ role: 'user' as const, content: 'Explain quantum computing in one sentence.' }
],
withImages: [
{
role: 'user' as const,
content: [
{ type: 'text' as const, text: 'What is in this image?' },
{
type: 'image' as const,
image:
''
}
]
}
],
toolUse: [{ role: 'user' as const, content: 'What is the weather in San Francisco?' }],
multiTurn: [
{ role: 'user' as const, content: 'Can you help me with a math problem?' },
{ role: 'assistant' as const, content: 'Of course! What math problem would you like help with?' },
{ role: 'user' as const, content: 'What is 15 * 23?' },
{ role: 'assistant' as const, content: '15 * 23 = 345' },
{ role: 'user' as const, content: 'Now divide that by 5' }
]
} satisfies Record<string, ModelMessage[]>
/**
* Standard test tools for tool calling scenarios
*/
export const testTools: Record<string, Tool> = {
getWeather: {
description: 'Get the current weather in a given location',
inputSchema: jsonSchema({
type: 'object',
properties: {
location: {
type: 'string',
description: 'The city and state, e.g. San Francisco, CA'
},
unit: {
type: 'string',
enum: ['celsius', 'fahrenheit'],
description: 'The temperature unit to use'
}
},
required: ['location']
}),
execute: async ({ location, unit = 'fahrenheit' }) => {
return {
location,
temperature: unit === 'celsius' ? 22 : 72,
unit,
condition: 'sunny'
}
}
},
calculate: {
description: 'Perform a mathematical calculation',
inputSchema: jsonSchema({
type: 'object',
properties: {
operation: {
type: 'string',
enum: ['add', 'subtract', 'multiply', 'divide'],
description: 'The operation to perform'
},
a: {
type: 'number',
description: 'The first number'
},
b: {
type: 'number',
description: 'The second number'
}
},
required: ['operation', 'a', 'b']
}),
execute: async ({ operation, a, b }) => {
const operations = {
add: (x: number, y: number) => x + y,
subtract: (x: number, y: number) => x - y,
multiply: (x: number, y: number) => x * y,
divide: (x: number, y: number) => x / y
}
return { result: operations[operation as keyof typeof operations](a, b) }
}
},
searchDatabase: {
description: 'Search for information in a database',
inputSchema: jsonSchema({
type: 'object',
properties: {
query: {
type: 'string',
description: 'The search query'
},
limit: {
type: 'number',
description: 'Maximum number of results to return',
default: 10
}
},
required: ['query']
}),
execute: async ({ query, limit = 10 }) => {
return {
results: [
{ id: 1, title: `Result 1 for ${query}`, relevance: 0.95 },
{ id: 2, title: `Result 2 for ${query}`, relevance: 0.87 }
].slice(0, limit)
}
}
}
}
/**
* Mock streaming chunks for different providers
*/
export const mockStreamingChunks = {
text: [
{ type: 'text-delta' as const, textDelta: 'Hello' },
{ type: 'text-delta' as const, textDelta: ', ' },
{ type: 'text-delta' as const, textDelta: 'this ' },
{ type: 'text-delta' as const, textDelta: 'is ' },
{ type: 'text-delta' as const, textDelta: 'a ' },
{ type: 'text-delta' as const, textDelta: 'test.' }
],
withToolCall: [
{ type: 'text-delta' as const, textDelta: 'Let me check the weather for you.' },
{
type: 'tool-call-delta' as const,
toolCallType: 'function' as const,
toolCallId: 'call_123',
toolName: 'getWeather',
argsTextDelta: '{"location":'
},
{
type: 'tool-call-delta' as const,
toolCallType: 'function' as const,
toolCallId: 'call_123',
toolName: 'getWeather',
argsTextDelta: ' "San Francisco, CA"}'
},
{
type: 'tool-call' as const,
toolCallType: 'function' as const,
toolCallId: 'call_123',
toolName: 'getWeather',
args: { location: 'San Francisco, CA' }
}
],
withFinish: [
{ type: 'text-delta' as const, textDelta: 'Complete response.' },
{
type: 'finish' as const,
finishReason: 'stop' as const,
usage: {
promptTokens: 10,
completionTokens: 5,
totalTokens: 15
}
}
]
}
/**
* Mock complete responses for non-streaming scenarios
*/
export const mockCompleteResponses = {
simple: {
text: 'This is a simple response.',
finishReason: 'stop' as const,
usage: {
promptTokens: 15,
completionTokens: 8,
totalTokens: 23
}
},
withToolCalls: {
text: 'I will check the weather for you.',
toolCalls: [
{
toolCallId: 'call_456',
toolName: 'getWeather',
args: { location: 'New York, NY', unit: 'celsius' }
}
],
finishReason: 'tool-calls' as const,
usage: {
promptTokens: 25,
completionTokens: 12,
totalTokens: 37
}
},
withWarnings: {
text: 'Response with warnings.',
finishReason: 'stop' as const,
usage: {
promptTokens: 10,
completionTokens: 5,
totalTokens: 15
},
warnings: [
{
type: 'unsupported-setting' as const,
message: 'Temperature parameter not supported for this model'
}
]
}
}
/**
* Mock image generation responses
*/
export const mockImageResponses = {
single: {
image: {
base64: 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==',
uint8Array: new Uint8Array([137, 80, 78, 71, 13, 10, 26, 10, 0, 0, 0, 13, 73, 72, 68, 82]),
mimeType: 'image/png' as const
},
warnings: []
},
multiple: {
images: [
{
base64: 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==',
uint8Array: new Uint8Array([137, 80, 78, 71]),
mimeType: 'image/png' as const
},
{
base64: 'iVBORw0KGgoAAAANSUhEUgAAAAIAAAACCAYAAABytg0kAAAAEklEQVR42mNk+M9QzwAEjDAGACCKAgdZ9zImAAAAAElFTkSuQmCC',
uint8Array: new Uint8Array([137, 80, 78, 71]),
mimeType: 'image/png' as const
}
],
warnings: []
},
withProviderMetadata: {
image: {
base64: 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==',
uint8Array: new Uint8Array([137, 80, 78, 71]),
mimeType: 'image/png' as const
},
providerMetadata: {
openai: {
images: [
{
revisedPrompt: 'A detailed and enhanced version of the original prompt'
}
]
}
},
warnings: []
}
}
/**
* Mock error responses
*/
export const mockErrors = {
invalidApiKey: {
name: 'APIError',
message: 'Invalid API key provided',
statusCode: 401
},
rateLimitExceeded: {
name: 'RateLimitError',
message: 'Rate limit exceeded. Please try again later.',
statusCode: 429,
headers: {
'retry-after': '60'
}
},
modelNotFound: {
name: 'ModelNotFoundError',
message: 'The requested model was not found',
statusCode: 404
},
contextLengthExceeded: {
name: 'ContextLengthError',
message: "This model's maximum context length is 4096 tokens",
statusCode: 400
},
timeout: {
name: 'TimeoutError',
message: 'Request timed out after 30000ms',
code: 'ETIMEDOUT'
},
networkError: {
name: 'NetworkError',
message: 'Network connection failed',
code: 'ECONNREFUSED'
}
}

View File

@ -0,0 +1,329 @@
/**
* Provider-Specific Test Utilities
* Helper functions for testing individual providers with all their parameters
*/
import type { Tool } from 'ai'
import { expect } from 'vitest'
/**
* Provider parameter configurations for comprehensive testing
*/
export const providerParameterMatrix = {
openai: {
models: ['gpt-4', 'gpt-4-turbo', 'gpt-3.5-turbo', 'gpt-4o'],
parameters: {
temperature: [0, 0.5, 0.7, 1.0, 1.5, 2.0],
maxTokens: [100, 500, 1000, 2000, 4000],
topP: [0.1, 0.5, 0.9, 1.0],
frequencyPenalty: [-2.0, -1.0, 0, 1.0, 2.0],
presencePenalty: [-2.0, -1.0, 0, 1.0, 2.0],
stop: [undefined, ['stop'], ['STOP', 'END']],
seed: [undefined, 12345, 67890],
responseFormat: [undefined, { type: 'json_object' as const }],
user: [undefined, 'test-user-123']
},
toolChoice: ['auto', 'required', 'none', { type: 'function' as const, name: 'getWeather' }],
parallelToolCalls: [true, false]
},
anthropic: {
models: ['claude-3-5-sonnet-20241022', 'claude-3-opus-20240229', 'claude-3-haiku-20240307'],
parameters: {
temperature: [0, 0.5, 1.0],
maxTokens: [100, 1000, 4000, 8000],
topP: [0.1, 0.5, 0.9, 1.0],
topK: [undefined, 1, 5, 10, 40],
stop: [undefined, ['Human:', 'Assistant:']],
metadata: [undefined, { userId: 'test-123' }]
},
toolChoice: ['auto', 'any', { type: 'tool' as const, name: 'getWeather' }]
},
google: {
models: ['gemini-2.0-flash-exp', 'gemini-1.5-pro', 'gemini-1.5-flash'],
parameters: {
temperature: [0, 0.5, 0.9, 1.0],
maxTokens: [100, 1000, 2000, 8000],
topP: [0.1, 0.5, 0.95, 1.0],
topK: [undefined, 1, 16, 40],
stopSequences: [undefined, ['END'], ['STOP', 'TERMINATE']]
},
safetySettings: [
undefined,
[
{ category: 'HARM_CATEGORY_HARASSMENT', threshold: 'BLOCK_MEDIUM_AND_ABOVE' },
{ category: 'HARM_CATEGORY_HATE_SPEECH', threshold: 'BLOCK_ONLY_HIGH' }
]
]
},
xai: {
models: ['grok-2-latest', 'grok-2-1212'],
parameters: {
temperature: [0, 0.5, 1.0, 1.5],
maxTokens: [100, 500, 2000, 4000],
topP: [0.1, 0.5, 0.9, 1.0],
stop: [undefined, ['STOP'], ['END', 'TERMINATE']],
seed: [undefined, 12345]
}
},
deepseek: {
models: ['deepseek-chat', 'deepseek-coder'],
parameters: {
temperature: [0, 0.5, 1.0],
maxTokens: [100, 1000, 4000],
topP: [0.1, 0.5, 0.95],
frequencyPenalty: [0, 0.5, 1.0],
presencePenalty: [0, 0.5, 1.0],
stop: [undefined, ['```'], ['END']]
}
},
azure: {
deployments: ['gpt-4-deployment', 'gpt-35-turbo-deployment'],
parameters: {
temperature: [0, 0.7, 1.0],
maxTokens: [100, 1000, 2000],
topP: [0.1, 0.5, 0.95],
frequencyPenalty: [0, 1.0],
presencePenalty: [0, 1.0],
stop: [undefined, ['STOP']]
}
}
} as const
/**
* Creates test cases for all parameter combinations
*/
export function generateParameterTestCases<T extends Record<string, any[]>>(
params: T,
maxCombinations = 50
): Array<Partial<{ [K in keyof T]: T[K][number] }>> {
const keys = Object.keys(params) as Array<keyof T>
const testCases: Array<Partial<{ [K in keyof T]: T[K][number] }>> = []
// Generate combinations using sampling strategy for large parameter spaces
const totalCombinations = keys.reduce((acc, key) => acc * params[key].length, 1)
if (totalCombinations <= maxCombinations) {
// Generate all combinations if total is small
generateAllCombinations(params, keys, 0, {}, testCases)
} else {
// Sample diverse combinations if total is large
generateSampledCombinations(params, keys, maxCombinations, testCases)
}
return testCases
}
function generateAllCombinations<T extends Record<string, any[]>>(
params: T,
keys: Array<keyof T>,
index: number,
current: Partial<{ [K in keyof T]: T[K][number] }>,
results: Array<Partial<{ [K in keyof T]: T[K][number] }>>
) {
if (index === keys.length) {
results.push({ ...current })
return
}
const key = keys[index]
for (const value of params[key]) {
generateAllCombinations(params, keys, index + 1, { ...current, [key]: value }, results)
}
}
function generateSampledCombinations<T extends Record<string, any[]>>(
params: T,
keys: Array<keyof T>,
count: number,
results: Array<Partial<{ [K in keyof T]: T[K][number] }>>
) {
// Generate edge cases first (min/max values)
const edgeCase1: any = {}
const edgeCase2: any = {}
for (const key of keys) {
edgeCase1[key] = params[key][0]
edgeCase2[key] = params[key][params[key].length - 1]
}
results.push(edgeCase1, edgeCase2)
// Generate random combinations for the rest
for (let i = results.length; i < count; i++) {
const combination: any = {}
for (const key of keys) {
const values = params[key]
combination[key] = values[Math.floor(Math.random() * values.length)]
}
results.push(combination)
}
}
/**
* Validates that all provider-specific parameters are correctly passed through
*/
export function validateProviderParams(providerId: string, actualParams: any, expectedParams: any): void {
const requiredFields: Record<string, string[]> = {
openai: ['model', 'messages'],
anthropic: ['model', 'messages'],
google: ['model', 'contents'],
xai: ['model', 'messages'],
deepseek: ['model', 'messages'],
azure: ['messages']
}
const fields = requiredFields[providerId] || ['model', 'messages']
for (const field of fields) {
expect(actualParams).toHaveProperty(field)
}
// Validate optional parameters if they were provided
const optionalParams = ['temperature', 'max_tokens', 'top_p', 'stop', 'tools']
for (const param of optionalParams) {
if (expectedParams[param] !== undefined) {
expect(actualParams[param]).toEqual(expectedParams[param])
}
}
}
/**
* Creates a comprehensive test suite for a provider
*/
// oxlint-disable-next-line no-unused-vars
export function createProviderTestSuite(_providerId: string) {
return {
testBasicCompletion: async (executor: any, model: string) => {
const result = await executor.generateText({
model,
messages: [{ role: 'user' as const, content: 'Hello' }]
})
expect(result).toBeDefined()
expect(result.text).toBeDefined()
expect(typeof result.text).toBe('string')
},
testStreaming: async (executor: any, model: string) => {
const chunks: any[] = []
const result = await executor.streamText({
model,
messages: [{ role: 'user' as const, content: 'Hello' }]
})
for await (const chunk of result.textStream) {
chunks.push(chunk)
}
expect(chunks.length).toBeGreaterThan(0)
},
testTemperature: async (executor: any, model: string, temperatures: number[]) => {
for (const temperature of temperatures) {
const result = await executor.generateText({
model,
messages: [{ role: 'user' as const, content: 'Hello' }],
temperature
})
expect(result).toBeDefined()
}
},
testMaxTokens: async (executor: any, model: string, maxTokensValues: number[]) => {
for (const maxTokens of maxTokensValues) {
const result = await executor.generateText({
model,
messages: [{ role: 'user' as const, content: 'Hello' }],
maxTokens
})
expect(result).toBeDefined()
if (result.usage?.completionTokens) {
expect(result.usage.completionTokens).toBeLessThanOrEqual(maxTokens)
}
}
},
testToolCalling: async (executor: any, model: string, tools: Record<string, Tool>) => {
const result = await executor.generateText({
model,
messages: [{ role: 'user' as const, content: 'What is the weather in SF?' }],
tools
})
expect(result).toBeDefined()
},
testStopSequences: async (executor: any, model: string, stopSequences: string[][]) => {
for (const stop of stopSequences) {
const result = await executor.generateText({
model,
messages: [{ role: 'user' as const, content: 'Count to 10' }],
stop
})
expect(result).toBeDefined()
}
}
}
}
/**
* Generates test data for vision/multimodal testing
*/
export function createVisionTestData() {
return {
imageUrl: 'https://example.com/test-image.jpg',
base64Image:
'',
messages: [
{
role: 'user' as const,
content: [
{ type: 'text' as const, text: 'What is in this image?' },
{
type: 'image' as const,
image:
''
}
]
}
]
}
}
/**
* Creates mock responses for different finish reasons
*/
export function createFinishReasonMocks() {
return {
stop: {
text: 'Complete response.',
finishReason: 'stop' as const,
usage: { promptTokens: 10, completionTokens: 5, totalTokens: 15 }
},
length: {
text: 'Incomplete response due to',
finishReason: 'length' as const,
usage: { promptTokens: 10, completionTokens: 100, totalTokens: 110 }
},
'tool-calls': {
text: 'Calling tools',
finishReason: 'tool-calls' as const,
toolCalls: [{ toolCallId: 'call_1', toolName: 'getWeather', args: { location: 'SF' } }],
usage: { promptTokens: 10, completionTokens: 8, totalTokens: 18 }
},
'content-filter': {
text: '',
finishReason: 'content-filter' as const,
usage: { promptTokens: 10, completionTokens: 0, totalTokens: 10 }
}
}
}

View File

@ -0,0 +1,291 @@
/**
* Test Utilities
* Helper functions for testing AI Core functionality
*/
import { expect, vi } from 'vitest'
import type { ProviderId } from '../fixtures/mock-providers'
import { createMockImageModel, createMockLanguageModel, mockProviderConfigs } from '../fixtures/mock-providers'
/**
* Creates a test provider with streaming support
*/
export function createTestStreamingProvider(chunks: any[]) {
return createMockLanguageModel({
doStream: vi.fn().mockReturnValue({
stream: (async function* () {
for (const chunk of chunks) {
yield chunk
}
})(),
rawCall: { rawPrompt: null, rawSettings: {} },
rawResponse: { headers: {} },
warnings: []
})
})
}
/**
* Creates a test provider that throws errors
*/
export function createErrorProvider(error: Error) {
return createMockLanguageModel({
doGenerate: vi.fn().mockRejectedValue(error),
doStream: vi.fn().mockImplementation(() => {
throw error
})
})
}
/**
* Collects all chunks from a stream
*/
export async function collectStreamChunks<T>(stream: AsyncIterable<T>): Promise<T[]> {
const chunks: T[] = []
for await (const chunk of stream) {
chunks.push(chunk)
}
return chunks
}
/**
* Waits for a specific number of milliseconds
*/
export function wait(ms: number): Promise<void> {
return new Promise((resolve) => setTimeout(resolve, ms))
}
/**
* Creates a mock abort controller that aborts after a delay
*/
export function createDelayedAbortController(delayMs: number): AbortController {
const controller = new AbortController()
setTimeout(() => controller.abort(), delayMs)
return controller
}
/**
* Asserts that a function throws an error with a specific message
*/
export async function expectError(fn: () => Promise<any>, expectedMessage?: string | RegExp): Promise<Error> {
try {
await fn()
throw new Error('Expected function to throw an error, but it did not')
} catch (error) {
if (expectedMessage) {
const message = (error as Error).message
if (typeof expectedMessage === 'string') {
if (!message.includes(expectedMessage)) {
throw new Error(`Expected error message to include "${expectedMessage}", but got "${message}"`)
}
} else {
if (!expectedMessage.test(message)) {
throw new Error(`Expected error message to match ${expectedMessage}, but got "${message}"`)
}
}
}
return error as Error
}
}
/**
* Creates a spy function that tracks calls and arguments
*/
export function createSpy<T extends (...args: any[]) => any>() {
const calls: Array<{ args: Parameters<T>; result?: ReturnType<T>; error?: Error }> = []
const spy = vi.fn((...args: Parameters<T>) => {
try {
const result = undefined as ReturnType<T>
calls.push({ args, result })
return result
} catch (error) {
calls.push({ args, error: error as Error })
throw error
}
})
return {
fn: spy,
calls,
getCalls: () => calls,
getCallCount: () => calls.length,
getLastCall: () => calls[calls.length - 1],
reset: () => {
calls.length = 0
spy.mockClear()
}
}
}
/**
* Validates provider configuration
*/
export function validateProviderConfig(providerId: ProviderId) {
const config = mockProviderConfigs[providerId]
if (!config) {
throw new Error(`No mock configuration found for provider: ${providerId}`)
}
if (!config.apiKey) {
throw new Error(`Provider ${providerId} is missing apiKey in mock config`)
}
return config
}
/**
* Creates a test context with common setup
*/
export function createTestContext() {
const mocks = {
languageModel: createMockLanguageModel(),
imageModel: createMockImageModel(),
providers: new Map<string, any>()
}
const cleanup = () => {
mocks.providers.clear()
vi.clearAllMocks()
}
return {
mocks,
cleanup
}
}
/**
* Measures execution time of an async function
*/
export async function measureTime<T>(fn: () => Promise<T>): Promise<{ result: T; duration: number }> {
const start = Date.now()
const result = await fn()
const duration = Date.now() - start
return { result, duration }
}
/**
* Retries a function until it succeeds or max attempts reached
*/
export async function retryUntilSuccess<T>(fn: () => Promise<T>, maxAttempts = 3, delayMs = 100): Promise<T> {
let lastError: Error | undefined
for (let attempt = 1; attempt <= maxAttempts; attempt++) {
try {
return await fn()
} catch (error) {
lastError = error as Error
if (attempt < maxAttempts) {
await wait(delayMs)
}
}
}
throw lastError || new Error('All retry attempts failed')
}
/**
* Creates a mock streaming response that emits chunks at intervals
*/
export function createTimedStream<T>(chunks: T[], intervalMs = 10) {
return {
async *[Symbol.asyncIterator]() {
for (const chunk of chunks) {
await wait(intervalMs)
yield chunk
}
}
}
}
/**
* Asserts that two objects are deeply equal, ignoring specified keys
*/
export function assertDeepEqualIgnoring<T extends Record<string, any>>(
actual: T,
expected: T,
ignoreKeys: string[] = []
): void {
const filterKeys = (obj: T): Partial<T> => {
const filtered = { ...obj }
for (const key of ignoreKeys) {
delete filtered[key]
}
return filtered
}
const filteredActual = filterKeys(actual)
const filteredExpected = filterKeys(expected)
expect(filteredActual).toEqual(filteredExpected)
}
/**
* Creates a provider mock that simulates rate limiting
*/
export function createRateLimitedProvider(limitPerSecond: number) {
const calls: number[] = []
return createMockLanguageModel({
doGenerate: vi.fn().mockImplementation(async () => {
const now = Date.now()
calls.push(now)
// Remove calls older than 1 second
const recentCalls = calls.filter((time) => now - time < 1000)
if (recentCalls.length > limitPerSecond) {
throw new Error('Rate limit exceeded')
}
return {
text: 'Rate limited response',
finishReason: 'stop' as const,
usage: { promptTokens: 10, completionTokens: 5, totalTokens: 15 },
rawCall: { rawPrompt: null, rawSettings: {} },
rawResponse: { headers: {} },
warnings: []
}
})
})
}
/**
* Validates streaming response structure
*/
export function validateStreamChunk(chunk: any): void {
expect(chunk).toBeDefined()
expect(chunk).toHaveProperty('type')
if (chunk.type === 'text-delta') {
expect(chunk).toHaveProperty('textDelta')
expect(typeof chunk.textDelta).toBe('string')
} else if (chunk.type === 'finish') {
expect(chunk).toHaveProperty('finishReason')
expect(chunk).toHaveProperty('usage')
} else if (chunk.type === 'tool-call') {
expect(chunk).toHaveProperty('toolCallId')
expect(chunk).toHaveProperty('toolName')
expect(chunk).toHaveProperty('args')
}
}
/**
* Creates a test logger that captures log messages
*/
export function createTestLogger() {
const logs: Array<{ level: string; message: string; meta?: any }> = []
return {
info: (message: string, meta?: any) => logs.push({ level: 'info', message, meta }),
warn: (message: string, meta?: any) => logs.push({ level: 'warn', message, meta }),
error: (message: string, meta?: any) => logs.push({ level: 'error', message, meta }),
debug: (message: string, meta?: any) => logs.push({ level: 'debug', message, meta }),
getLogs: () => logs,
clear: () => {
logs.length = 0
}
}
}

View File

@ -0,0 +1,12 @@
/**
* Test Infrastructure Exports
* Central export point for all test utilities, fixtures, and helpers
*/
// Fixtures
export * from './fixtures/mock-providers'
export * from './fixtures/mock-responses'
// Helpers
export * from './helpers/provider-test-utils'
export * from './helpers/test-utils'

View File

@ -0,0 +1,499 @@
/**
* RuntimeExecutor.generateText Comprehensive Tests
* Tests non-streaming text generation across all providers with various parameters
*/
import { generateText } from 'ai'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import {
createMockLanguageModel,
mockCompleteResponses,
mockProviderConfigs,
testMessages,
testTools
} from '../../../__tests__'
import type { AiPlugin } from '../../plugins'
import { globalRegistryManagement } from '../../providers/RegistryManagement'
import { RuntimeExecutor } from '../executor'
// Mock AI SDK
vi.mock('ai', () => ({
generateText: vi.fn()
}))
vi.mock('../../providers/RegistryManagement', () => ({
globalRegistryManagement: {
languageModel: vi.fn()
},
DEFAULT_SEPARATOR: '|'
}))
describe('RuntimeExecutor.generateText', () => {
let executor: RuntimeExecutor<'openai'>
let mockLanguageModel: any
beforeEach(() => {
vi.clearAllMocks()
executor = RuntimeExecutor.create('openai', mockProviderConfigs.openai)
mockLanguageModel = createMockLanguageModel({
provider: 'openai',
modelId: 'gpt-4'
})
vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(mockLanguageModel)
vi.mocked(generateText).mockResolvedValue(mockCompleteResponses.simple as any)
})
describe('Basic Functionality', () => {
it('should generate text with minimal parameters', async () => {
const result = await executor.generateText({
model: 'gpt-4',
messages: testMessages.simple
})
expect(generateText).toHaveBeenCalledWith({
model: mockLanguageModel,
messages: testMessages.simple
})
expect(result.text).toBe('This is a simple response.')
expect(result.finishReason).toBe('stop')
expect(result.usage).toBeDefined()
})
it('should generate with system messages', async () => {
await executor.generateText({
model: 'gpt-4',
messages: testMessages.withSystem
})
expect(generateText).toHaveBeenCalledWith({
model: mockLanguageModel,
messages: testMessages.withSystem
})
})
it('should generate with conversation history', async () => {
await executor.generateText({
model: 'gpt-4',
messages: testMessages.conversation
})
expect(generateText).toHaveBeenCalledWith(
expect.objectContaining({
messages: testMessages.conversation
})
)
})
})
describe('All Parameter Combinations', () => {
it('should support all parameters together', async () => {
await executor.generateText({
model: 'gpt-4',
messages: testMessages.simple,
temperature: 0.7,
maxOutputTokens: 500,
topP: 0.9,
frequencyPenalty: 0.5,
presencePenalty: 0.3,
stopSequences: ['STOP'],
seed: 12345
})
expect(generateText).toHaveBeenCalledWith(
expect.objectContaining({
temperature: 0.7,
maxOutputTokens: 500,
topP: 0.9,
frequencyPenalty: 0.5,
presencePenalty: 0.3,
stopSequences: ['STOP'],
seed: 12345
})
)
})
it('should support partial parameters', async () => {
await executor.generateText({
model: 'gpt-4',
messages: testMessages.simple,
temperature: 0.5,
maxOutputTokens: 100
})
expect(generateText).toHaveBeenCalledWith(
expect.objectContaining({
temperature: 0.5,
maxOutputTokens: 100
})
)
})
})
describe('Tool Calling', () => {
beforeEach(() => {
vi.mocked(generateText).mockResolvedValue(mockCompleteResponses.withToolCalls as any)
})
it('should support tool calling', async () => {
const result = await executor.generateText({
model: 'gpt-4',
messages: testMessages.toolUse,
tools: testTools
})
expect(generateText).toHaveBeenCalledWith(
expect.objectContaining({
tools: testTools
})
)
expect(result.toolCalls).toBeDefined()
expect(result.toolCalls).toHaveLength(1)
})
it('should support toolChoice auto', async () => {
await executor.generateText({
model: 'gpt-4',
messages: testMessages.toolUse,
tools: testTools,
toolChoice: 'auto'
})
expect(generateText).toHaveBeenCalledWith(
expect.objectContaining({
toolChoice: 'auto'
})
)
})
it('should support toolChoice required', async () => {
await executor.generateText({
model: 'gpt-4',
messages: testMessages.toolUse,
tools: testTools,
toolChoice: 'required'
})
expect(generateText).toHaveBeenCalledWith(
expect.objectContaining({
toolChoice: 'required'
})
)
})
it('should support toolChoice none', async () => {
vi.mocked(generateText).mockResolvedValue(mockCompleteResponses.simple as any)
await executor.generateText({
model: 'gpt-4',
messages: testMessages.simple,
tools: testTools,
toolChoice: 'none'
})
expect(generateText).toHaveBeenCalledWith(
expect.objectContaining({
toolChoice: 'none'
})
)
})
it('should support specific tool selection', async () => {
await executor.generateText({
model: 'gpt-4',
messages: testMessages.toolUse,
tools: testTools,
toolChoice: {
type: 'tool',
toolName: 'getWeather'
}
})
expect(generateText).toHaveBeenCalledWith(
expect.objectContaining({
toolChoice: {
type: 'tool',
toolName: 'getWeather'
}
})
)
})
})
describe('Multiple Providers', () => {
it('should work with Anthropic provider', async () => {
const anthropicExecutor = RuntimeExecutor.create('anthropic', mockProviderConfigs.anthropic)
const anthropicModel = createMockLanguageModel({
provider: 'anthropic',
modelId: 'claude-3-5-sonnet-20241022'
})
vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(anthropicModel)
await anthropicExecutor.generateText({
model: 'claude-3-5-sonnet-20241022',
messages: testMessages.simple
})
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('anthropic|claude-3-5-sonnet-20241022')
})
it('should work with Google provider', async () => {
const googleExecutor = RuntimeExecutor.create('google', mockProviderConfigs.google)
const googleModel = createMockLanguageModel({
provider: 'google',
modelId: 'gemini-2.0-flash-exp'
})
vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(googleModel)
await googleExecutor.generateText({
model: 'gemini-2.0-flash-exp',
messages: testMessages.simple
})
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('google|gemini-2.0-flash-exp')
})
it('should work with xAI provider', async () => {
const xaiExecutor = RuntimeExecutor.create('xai', mockProviderConfigs.xai)
const xaiModel = createMockLanguageModel({
provider: 'xai',
modelId: 'grok-2-latest'
})
vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(xaiModel)
await xaiExecutor.generateText({
model: 'grok-2-latest',
messages: testMessages.simple
})
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('xai|grok-2-latest')
})
it('should work with DeepSeek provider', async () => {
const deepseekExecutor = RuntimeExecutor.create('deepseek', mockProviderConfigs.deepseek)
const deepseekModel = createMockLanguageModel({
provider: 'deepseek',
modelId: 'deepseek-chat'
})
vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(deepseekModel)
await deepseekExecutor.generateText({
model: 'deepseek-chat',
messages: testMessages.simple
})
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('deepseek|deepseek-chat')
})
})
describe('Plugin Integration', () => {
it('should execute all plugin hooks', async () => {
const pluginCalls: string[] = []
const testPlugin: AiPlugin = {
name: 'test-plugin',
onRequestStart: vi.fn(async () => {
pluginCalls.push('onRequestStart')
}),
transformParams: vi.fn(async (params) => {
pluginCalls.push('transformParams')
return { ...params, temperature: 0.8 }
}),
transformResult: vi.fn(async (result) => {
pluginCalls.push('transformResult')
return { ...result, text: result.text + ' [modified]' }
}),
onRequestEnd: vi.fn(async () => {
pluginCalls.push('onRequestEnd')
})
}
const executorWithPlugin = RuntimeExecutor.create('openai', mockProviderConfigs.openai, [testPlugin])
const result = await executorWithPlugin.generateText({
model: 'gpt-4',
messages: testMessages.simple
})
expect(pluginCalls).toEqual(['onRequestStart', 'transformParams', 'transformResult', 'onRequestEnd'])
// Verify transformed parameters
expect(generateText).toHaveBeenCalledWith(
expect.objectContaining({
temperature: 0.8
})
)
// Verify transformed result
expect(result.text).toContain('[modified]')
})
it('should handle multiple plugins in order', async () => {
const pluginOrder: string[] = []
const plugin1: AiPlugin = {
name: 'plugin-1',
transformParams: vi.fn(async (params) => {
pluginOrder.push('plugin-1')
return { ...params, temperature: 0.5 }
})
}
const plugin2: AiPlugin = {
name: 'plugin-2',
transformParams: vi.fn(async (params) => {
pluginOrder.push('plugin-2')
return { ...params, maxTokens: 200 }
})
}
const executorWithPlugins = RuntimeExecutor.create('openai', mockProviderConfigs.openai, [plugin1, plugin2])
await executorWithPlugins.generateText({
model: 'gpt-4',
messages: testMessages.simple
})
expect(pluginOrder).toEqual(['plugin-1', 'plugin-2'])
expect(generateText).toHaveBeenCalledWith(
expect.objectContaining({
temperature: 0.5,
maxTokens: 200
})
)
})
})
describe('Error Handling', () => {
it('should handle API errors', async () => {
const error = new Error('API request failed')
vi.mocked(generateText).mockRejectedValue(error)
await expect(
executor.generateText({
model: 'gpt-4',
messages: testMessages.simple
})
).rejects.toThrow('API request failed')
})
it('should execute onError plugin hook', async () => {
const error = new Error('Generation failed')
vi.mocked(generateText).mockRejectedValue(error)
const errorPlugin: AiPlugin = {
name: 'error-handler',
onError: vi.fn()
}
const executorWithPlugin = RuntimeExecutor.create('openai', mockProviderConfigs.openai, [errorPlugin])
await expect(
executorWithPlugin.generateText({
model: 'gpt-4',
messages: testMessages.simple
})
).rejects.toThrow('Generation failed')
expect(errorPlugin.onError).toHaveBeenCalledWith(
error,
expect.objectContaining({
providerId: 'openai',
modelId: 'gpt-4'
})
)
})
it('should handle model not found error', async () => {
const error = new Error('Model not found: invalid-model')
vi.mocked(globalRegistryManagement.languageModel).mockImplementation(() => {
throw error
})
await expect(
executor.generateText({
model: 'invalid-model',
messages: testMessages.simple
})
).rejects.toThrow('Model not found')
})
})
describe('Usage and Metadata', () => {
it('should return usage information', async () => {
const result = await executor.generateText({
model: 'gpt-4',
messages: testMessages.simple
})
expect(result.usage).toBeDefined()
expect(result.usage.inputTokens).toBe(15)
expect(result.usage.outputTokens).toBe(8)
expect(result.usage.totalTokens).toBe(23)
})
it('should handle warnings', async () => {
vi.mocked(generateText).mockResolvedValue(mockCompleteResponses.withWarnings as any)
const result = await executor.generateText({
model: 'gpt-4',
messages: testMessages.simple,
temperature: 2.5 // Unsupported value
})
expect(result.warnings).toBeDefined()
expect(result.warnings).toHaveLength(1)
expect(result.warnings![0].type).toBe('unsupported-setting')
})
})
describe('Abort Signal', () => {
it('should support abort signal', async () => {
const abortController = new AbortController()
await executor.generateText({
model: 'gpt-4',
messages: testMessages.simple,
abortSignal: abortController.signal
})
expect(generateText).toHaveBeenCalledWith(
expect.objectContaining({
abortSignal: abortController.signal
})
)
})
it('should handle aborted request', async () => {
const abortError = new Error('Request aborted')
abortError.name = 'AbortError'
vi.mocked(generateText).mockRejectedValue(abortError)
const abortController = new AbortController()
abortController.abort()
await expect(
executor.generateText({
model: 'gpt-4',
messages: testMessages.simple,
abortSignal: abortController.signal
})
).rejects.toThrow('Request aborted')
})
})
})

View File

@ -0,0 +1,525 @@
/**
* RuntimeExecutor.streamText Comprehensive Tests
* Tests streaming text generation across all providers with various parameters
*/
import { streamText } from 'ai'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { collectStreamChunks, createMockLanguageModel, mockProviderConfigs, testMessages } from '../../../__tests__'
import type { AiPlugin } from '../../plugins'
import { globalRegistryManagement } from '../../providers/RegistryManagement'
import { RuntimeExecutor } from '../executor'
// Mock AI SDK
vi.mock('ai', () => ({
streamText: vi.fn()
}))
vi.mock('../../providers/RegistryManagement', () => ({
globalRegistryManagement: {
languageModel: vi.fn()
},
DEFAULT_SEPARATOR: '|'
}))
describe('RuntimeExecutor.streamText', () => {
let executor: RuntimeExecutor<'openai'>
let mockLanguageModel: any
beforeEach(() => {
vi.clearAllMocks()
executor = RuntimeExecutor.create('openai', mockProviderConfigs.openai)
mockLanguageModel = createMockLanguageModel({
provider: 'openai',
modelId: 'gpt-4'
})
vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(mockLanguageModel)
})
describe('Basic Functionality', () => {
it('should stream text with minimal parameters', async () => {
const mockStream = {
textStream: (async function* () {
yield 'Hello'
yield ' '
yield 'World'
})(),
fullStream: (async function* () {
yield { type: 'text-delta', textDelta: 'Hello' }
yield { type: 'text-delta', textDelta: ' ' }
yield { type: 'text-delta', textDelta: 'World' }
})(),
usage: Promise.resolve({ promptTokens: 5, completionTokens: 3, totalTokens: 8 })
}
vi.mocked(streamText).mockResolvedValue(mockStream as any)
const result = await executor.streamText({
model: 'gpt-4',
messages: testMessages.simple
})
expect(streamText).toHaveBeenCalledWith({
model: mockLanguageModel,
messages: testMessages.simple
})
const chunks = await collectStreamChunks(result.textStream)
expect(chunks).toEqual(['Hello', ' ', 'World'])
})
it('should stream with system messages', async () => {
const mockStream = {
textStream: (async function* () {
yield 'Response'
})(),
fullStream: (async function* () {
yield { type: 'text-delta', textDelta: 'Response' }
})()
}
vi.mocked(streamText).mockResolvedValue(mockStream as any)
await executor.streamText({
model: 'gpt-4',
messages: testMessages.withSystem
})
expect(streamText).toHaveBeenCalledWith({
model: mockLanguageModel,
messages: testMessages.withSystem
})
})
it('should stream multi-turn conversations', async () => {
const mockStream = {
textStream: (async function* () {
yield 'Multi-turn response'
})(),
fullStream: (async function* () {
yield { type: 'text-delta', textDelta: 'Multi-turn response' }
})()
}
vi.mocked(streamText).mockResolvedValue(mockStream as any)
await executor.streamText({
model: 'gpt-4',
messages: testMessages.multiTurn
})
expect(streamText).toHaveBeenCalled()
expect(streamText).toHaveBeenCalledWith(
expect.objectContaining({
messages: testMessages.multiTurn
})
)
})
})
describe('Temperature Parameter', () => {
const temperatures = [0, 0.3, 0.5, 0.7, 0.9, 1.0, 1.5, 2.0]
it.each(temperatures)('should support temperature=%s', async (temperature) => {
const mockStream = {
textStream: (async function* () {
yield 'Response'
})(),
fullStream: (async function* () {
yield { type: 'text-delta', textDelta: 'Response' }
})()
}
vi.mocked(streamText).mockResolvedValue(mockStream as any)
await executor.streamText({
model: 'gpt-4',
messages: testMessages.simple,
temperature
})
expect(streamText).toHaveBeenCalledWith(
expect.objectContaining({
temperature
})
)
})
})
describe('Max Tokens Parameter', () => {
const maxTokensValues = [10, 50, 100, 500, 1000, 2000, 4000]
it.each(maxTokensValues)('should support maxTokens=%s', async (maxTokens) => {
const mockStream = {
textStream: (async function* () {
yield 'Response'
})(),
fullStream: (async function* () {
yield { type: 'text-delta', textDelta: 'Response' }
})()
}
vi.mocked(streamText).mockResolvedValue(mockStream as any)
await executor.streamText({
model: 'gpt-4',
messages: testMessages.simple,
maxOutputTokens: maxTokens
})
expect(streamText).toHaveBeenCalledWith(
expect.objectContaining({
maxTokens
})
)
})
})
describe('Top P Parameter', () => {
const topPValues = [0.1, 0.3, 0.5, 0.7, 0.9, 0.95, 1.0]
it.each(topPValues)('should support topP=%s', async (topP) => {
const mockStream = {
textStream: (async function* () {
yield 'Response'
})(),
fullStream: (async function* () {
yield { type: 'text-delta', textDelta: 'Response' }
})()
}
vi.mocked(streamText).mockResolvedValue(mockStream as any)
await executor.streamText({
model: 'gpt-4',
messages: testMessages.simple,
topP
})
expect(streamText).toHaveBeenCalledWith(
expect.objectContaining({
topP
})
)
})
})
describe('Frequency and Presence Penalty', () => {
it('should support frequency penalty', async () => {
const penalties = [-2.0, -1.0, 0, 0.5, 1.0, 1.5, 2.0]
for (const frequencyPenalty of penalties) {
vi.clearAllMocks()
const mockStream = {
textStream: (async function* () {
yield 'Response'
})(),
fullStream: (async function* () {
yield { type: 'text-delta', textDelta: 'Response' }
})()
}
vi.mocked(streamText).mockResolvedValue(mockStream as any)
await executor.streamText({
model: 'gpt-4',
messages: testMessages.simple,
frequencyPenalty
})
expect(streamText).toHaveBeenCalledWith(
expect.objectContaining({
frequencyPenalty
})
)
}
})
it('should support presence penalty', async () => {
const penalties = [-2.0, -1.0, 0, 0.5, 1.0, 1.5, 2.0]
for (const presencePenalty of penalties) {
vi.clearAllMocks()
const mockStream = {
textStream: (async function* () {
yield 'Response'
})(),
fullStream: (async function* () {
yield { type: 'text-delta', textDelta: 'Response' }
})()
}
vi.mocked(streamText).mockResolvedValue(mockStream as any)
await executor.streamText({
model: 'gpt-4',
messages: testMessages.simple,
presencePenalty
})
expect(streamText).toHaveBeenCalledWith(
expect.objectContaining({
presencePenalty
})
)
}
})
it('should support both penalties together', async () => {
const mockStream = {
textStream: (async function* () {
yield 'Response'
})(),
fullStream: (async function* () {
yield { type: 'text-delta', textDelta: 'Response' }
})()
}
vi.mocked(streamText).mockResolvedValue(mockStream as any)
await executor.streamText({
model: 'gpt-4',
messages: testMessages.simple,
frequencyPenalty: 0.5,
presencePenalty: 0.5
})
expect(streamText).toHaveBeenCalledWith(
expect.objectContaining({
frequencyPenalty: 0.5,
presencePenalty: 0.5
})
)
})
})
describe('Seed Parameter', () => {
it('should support seed for deterministic output', async () => {
const seeds = [0, 12345, 67890, 999999]
for (const seed of seeds) {
vi.clearAllMocks()
const mockStream = {
textStream: (async function* () {
yield 'Response'
})(),
fullStream: (async function* () {
yield { type: 'text-delta', textDelta: 'Response' }
})()
}
vi.mocked(streamText).mockResolvedValue(mockStream as any)
await executor.streamText({
model: 'gpt-4',
messages: testMessages.simple,
seed
})
expect(streamText).toHaveBeenCalledWith(
expect.objectContaining({
seed
})
)
}
})
})
describe('Abort Signal', () => {
it('should support abort signal', async () => {
const abortController = new AbortController()
const mockStream = {
textStream: (async function* () {
yield 'Response'
})(),
fullStream: (async function* () {
yield { type: 'text-delta', textDelta: 'Response' }
})()
}
vi.mocked(streamText).mockResolvedValue(mockStream as any)
await executor.streamText({
model: 'gpt-4',
messages: testMessages.simple,
abortSignal: abortController.signal
})
expect(streamText).toHaveBeenCalledWith(
expect.objectContaining({
abortSignal: abortController.signal
})
)
})
it('should handle abort during streaming', async () => {
const abortController = new AbortController()
const mockStream = {
textStream: (async function* () {
yield 'Start'
// Simulate abort
abortController.abort()
throw new Error('Aborted')
})(),
fullStream: (async function* () {
yield { type: 'text-delta', textDelta: 'Start' }
throw new Error('Aborted')
})()
}
vi.mocked(streamText).mockResolvedValue(mockStream as any)
const result = await executor.streamText({
model: 'gpt-4',
messages: testMessages.simple,
abortSignal: abortController.signal
})
await expect(async () => {
// oxlint-disable-next-line no-unused-vars
for await (const _chunk of result.textStream) {
// Stream should be interrupted
}
}).rejects.toThrow('Aborted')
})
})
describe('Plugin Integration', () => {
it('should execute plugins during streaming', async () => {
const pluginCalls: string[] = []
const testPlugin: AiPlugin = {
name: 'test-plugin',
onRequestStart: vi.fn(async () => {
pluginCalls.push('onRequestStart')
}),
transformParams: vi.fn(async (params) => {
pluginCalls.push('transformParams')
return { ...params, temperature: 0.5 }
}),
onRequestEnd: vi.fn(async () => {
pluginCalls.push('onRequestEnd')
})
}
const executorWithPlugin = RuntimeExecutor.create('openai', mockProviderConfigs.openai, [testPlugin])
const mockStream = {
textStream: (async function* () {
yield 'Response'
})(),
fullStream: (async function* () {
yield { type: 'text-delta', textDelta: 'Response' }
})()
}
vi.mocked(streamText).mockResolvedValue(mockStream as any)
const result = await executorWithPlugin.streamText({
model: 'gpt-4',
messages: testMessages.simple
})
// Consume stream
// oxlint-disable-next-line no-unused-vars
for await (const _chunk of result.textStream) {
// Stream chunks
}
expect(pluginCalls).toContain('onRequestStart')
expect(pluginCalls).toContain('transformParams')
// Verify transformed parameters were used
expect(streamText).toHaveBeenCalledWith(
expect.objectContaining({
temperature: 0.5
})
)
})
})
describe('Full Stream with Finish Reason', () => {
it('should provide finish reason in full stream', async () => {
const mockStream = {
textStream: (async function* () {
yield 'Response'
})(),
fullStream: (async function* () {
yield { type: 'text-delta', textDelta: 'Response' }
yield {
type: 'finish',
finishReason: 'stop',
usage: { promptTokens: 5, completionTokens: 3, totalTokens: 8 }
}
})()
}
vi.mocked(streamText).mockResolvedValue(mockStream as any)
const result = await executor.streamText({
model: 'gpt-4',
messages: testMessages.simple
})
const fullChunks = await collectStreamChunks(result.fullStream)
expect(fullChunks).toHaveLength(2)
expect(fullChunks[0]).toEqual({ type: 'text-delta', textDelta: 'Response' })
expect(fullChunks[1]).toEqual({
type: 'finish',
finishReason: 'stop',
usage: { promptTokens: 5, completionTokens: 3, totalTokens: 8 }
})
})
})
describe('Error Handling', () => {
it('should handle streaming errors', async () => {
const error = new Error('Streaming failed')
vi.mocked(streamText).mockRejectedValue(error)
await expect(
executor.streamText({
model: 'gpt-4',
messages: testMessages.simple
})
).rejects.toThrow('Streaming failed')
})
it('should execute onError plugin hook on failure', async () => {
const error = new Error('Stream error')
vi.mocked(streamText).mockRejectedValue(error)
const errorPlugin: AiPlugin = {
name: 'error-handler',
onError: vi.fn()
}
const executorWithPlugin = RuntimeExecutor.create('openai', mockProviderConfigs.openai, [errorPlugin])
await expect(
executorWithPlugin.streamText({
model: 'gpt-4',
messages: testMessages.simple
})
).rejects.toThrow('Stream error')
expect(errorPlugin.onError).toHaveBeenCalledWith(
error,
expect.objectContaining({
providerId: 'openai',
modelId: 'gpt-4'
})
)
})
})
})

View File

@ -4,3 +4,34 @@ export const defaultAppHeaders = () => {
'X-Title': 'Cherry Studio' 'X-Title': 'Cherry Studio'
} }
} }
// Following two function are not being used for now.
// I may use them in the future, so just keep them commented. - by eurfelux
/**
* Converts an `undefined` value to `null`, otherwise returns the value as-is.
* @param value - The value to check
* @returns `null` if the input is `undefined`; otherwise the input value
*/
// export function toNullIfUndefined<T>(value: T | undefined): T | null {
// if (value === undefined) {
// return null
// } else {
// return value
// }
// }
/**
* Converts a `null` value to `undefined`, otherwise returns the value as-is.
* @param value - The value to check
* @returns `undefined` if the input is `null`; otherwise the input value
*/
// export function toUndefinedIfNull<T>(value: T | null): T | undefined {
// if (value === null) {
// return undefined
// } else {
// return value
// }
// }

View File

@ -3,7 +3,6 @@ import { createServer } from 'node:http'
import { loggerService } from '@logger' import { loggerService } from '@logger'
import { IpcChannel } from '@shared/IpcChannel' import { IpcChannel } from '@shared/IpcChannel'
import { agentService } from '../services/agents'
import { windowService } from '../services/WindowService' import { windowService } from '../services/WindowService'
import { app } from './app' import { app } from './app'
import { config } from './config' import { config } from './config'
@ -32,11 +31,6 @@ export class ApiServer {
// Load config // Load config
const { port, host } = await config.load() const { port, host } = await config.load()
// Initialize AgentService
logger.info('Initializing AgentService')
await agentService.initialize()
logger.info('AgentService initialized')
// Create server with Express app // Create server with Express app
this.server = createServer(app) this.server = createServer(app)
this.applyServerTimeouts(this.server) this.applyServerTimeouts(this.server)

View File

@ -44,6 +44,7 @@ import {
import { dataApiService } from '@data/DataApiService' import { dataApiService } from '@data/DataApiService'
import { cacheService } from '@data/CacheService' import { cacheService } from '@data/CacheService'
import { initWebviewHotkeys } from './services/WebviewService' import { initWebviewHotkeys } from './services/WebviewService'
import { runAsyncFunction } from './utils'
const logger = loggerService.withContext('MainEntry') const logger = loggerService.withContext('MainEntry')
@ -256,39 +257,33 @@ if (!app.requestSingleInstanceLock()) {
//start selection assistant service //start selection assistant service
initSelectionService() initSelectionService()
// Initialize Agent Service runAsyncFunction(async () => {
try { // Start API server if enabled or if agents exist
await agentService.initialize() try {
logger.info('Agent service initialized successfully') const config = await apiServerService.getCurrentConfig()
} catch (error: any) { logger.info('API server config:', config)
logger.error('Failed to initialize Agent service:', error)
}
// Start API server if enabled or if agents exist // Check if there are any agents
try { let shouldStart = config.enabled
const config = await apiServerService.getCurrentConfig() if (!shouldStart) {
logger.info('API server config:', config) try {
const { total } = await agentService.listAgents({ limit: 1 })
// Check if there are any agents if (total > 0) {
let shouldStart = config.enabled shouldStart = true
if (!shouldStart) { logger.info(`Detected ${total} agent(s), auto-starting API server`)
try { }
const { total } = await agentService.listAgents({ limit: 1 }) } catch (error: any) {
if (total > 0) { logger.warn('Failed to check agent count:', error)
shouldStart = true
logger.info(`Detected ${total} agent(s), auto-starting API server`)
} }
} catch (error: any) {
logger.warn('Failed to check agent count:', error)
} }
}
if (shouldStart) { if (shouldStart) {
await apiServerService.start() await apiServerService.start()
}
} catch (error: any) {
logger.error('Failed to check/start API server:', error)
} }
} catch (error: any) { })
logger.error('Failed to check/start API server:', error)
}
}) })
registerProtocolClient(app) registerProtocolClient(app)

View File

@ -1,16 +1,12 @@
import { type Client, createClient } from '@libsql/client'
import { loggerService } from '@logger' import { loggerService } from '@logger'
import { mcpApiService } from '@main/apiServer/services/mcp' import { mcpApiService } from '@main/apiServer/services/mcp'
import { type ModelValidationError, validateModelId } from '@main/apiServer/utils' import { type ModelValidationError, validateModelId } from '@main/apiServer/utils'
import type { AgentType, MCPTool, SlashCommand, Tool } from '@types' import type { AgentType, MCPTool, SlashCommand, Tool } from '@types'
import { objectKeys } from '@types' import { objectKeys } from '@types'
import { drizzle, type LibSQLDatabase } from 'drizzle-orm/libsql'
import fs from 'fs' import fs from 'fs'
import path from 'path' import path from 'path'
import { MigrationService } from './database/MigrationService' import { DatabaseManager } from './database/DatabaseManager'
import * as schema from './database/schema'
import { dbPath } from './drizzle.config'
import { type AgentModelField, AgentModelValidationError } from './errors' import { type AgentModelField, AgentModelValidationError } from './errors'
import { builtinSlashCommands } from './services/claudecode/commands' import { builtinSlashCommands } from './services/claudecode/commands'
import { builtinTools } from './services/claudecode/tools' import { builtinTools } from './services/claudecode/tools'
@ -18,22 +14,16 @@ import { builtinTools } from './services/claudecode/tools'
const logger = loggerService.withContext('BaseService') const logger = loggerService.withContext('BaseService')
/** /**
* Base service class providing shared database connection and utilities * Base service class providing shared utilities for all agent-related services.
* for all agent-related services.
* *
* Features: * Features:
* - Programmatic schema management (no CLI dependencies) * - Database access through DatabaseManager singleton
* - Automatic table creation and migration * - JSON field serialization/deserialization
* - Schema version tracking and compatibility checks * - Path validation and creation
* - Transaction-based operations for safety * - Model validation
* - Development vs production mode handling * - MCP tools and slash commands listing
* - Connection retry logic with exponential backoff
*/ */
export abstract class BaseService { export abstract class BaseService {
protected static client: Client | null = null
protected static db: LibSQLDatabase<typeof schema> | null = null
protected static isInitialized = false
protected static initializationPromise: Promise<void> | null = null
protected jsonFields: string[] = [ protected jsonFields: string[] = [
'tools', 'tools',
'mcps', 'mcps',
@ -43,23 +33,6 @@ export abstract class BaseService {
'slash_commands' 'slash_commands'
] ]
/**
* Initialize database with retry logic and proper error handling
*/
protected static async initialize(): Promise<void> {
// Return existing initialization if in progress
if (BaseService.initializationPromise) {
return BaseService.initializationPromise
}
if (BaseService.isInitialized) {
return
}
BaseService.initializationPromise = BaseService.performInitialization()
return BaseService.initializationPromise
}
public async listMcpTools(agentType: AgentType, ids?: string[]): Promise<Tool[]> { public async listMcpTools(agentType: AgentType, ids?: string[]): Promise<Tool[]> {
const tools: Tool[] = [] const tools: Tool[] = []
if (agentType === 'claude-code') { if (agentType === 'claude-code') {
@ -99,78 +72,13 @@ export abstract class BaseService {
return [] return []
} }
private static async performInitialization(): Promise<void> { /**
const maxRetries = 3 * Get database instance
let lastError: Error * Automatically waits for initialization to complete
*/
for (let attempt = 1; attempt <= maxRetries; attempt++) { protected async getDatabase() {
try { const dbManager = await DatabaseManager.getInstance()
logger.info(`Initializing Agent database at: ${dbPath} (attempt ${attempt}/${maxRetries})`) return dbManager.getDatabase()
// Ensure the database directory exists
const dbDir = path.dirname(dbPath)
if (!fs.existsSync(dbDir)) {
logger.info(`Creating database directory: ${dbDir}`)
fs.mkdirSync(dbDir, { recursive: true })
}
BaseService.client = createClient({
url: `file:${dbPath}`
})
BaseService.db = drizzle(BaseService.client, { schema })
// Run database migrations
const migrationService = new MigrationService(BaseService.db, BaseService.client)
await migrationService.runMigrations()
BaseService.isInitialized = true
logger.info('Agent database initialized successfully')
return
} catch (error) {
lastError = error as Error
logger.warn(`Database initialization attempt ${attempt} failed:`, lastError)
// Clean up on failure
if (BaseService.client) {
try {
BaseService.client.close()
} catch (closeError) {
logger.warn('Failed to close client during cleanup:', closeError as Error)
}
}
BaseService.client = null
BaseService.db = null
// Wait before retrying (exponential backoff)
if (attempt < maxRetries) {
const delay = Math.pow(2, attempt) * 1000 // 2s, 4s, 8s
logger.info(`Retrying in ${delay}ms...`)
await new Promise((resolve) => setTimeout(resolve, delay))
}
}
}
// All retries failed
BaseService.initializationPromise = null
logger.error('Failed to initialize Agent database after all retries:', lastError!)
throw lastError!
}
protected ensureInitialized(): void {
if (!BaseService.isInitialized || !BaseService.db || !BaseService.client) {
throw new Error('Database not initialized. Call initialize() first.')
}
}
protected get database(): LibSQLDatabase<typeof schema> {
this.ensureInitialized()
return BaseService.db!
}
protected get rawClient(): Client {
this.ensureInitialized()
return BaseService.client!
} }
protected serializeJsonFields(data: any): any { protected serializeJsonFields(data: any): any {
@ -282,7 +190,7 @@ export abstract class BaseService {
} }
/** /**
* Force re-initialization (for development/testing) * Validate agent model configuration
*/ */
protected async validateAgentModels( protected async validateAgentModels(
agentType: AgentType, agentType: AgentType,
@ -323,22 +231,4 @@ export abstract class BaseService {
} }
} }
} }
static async reinitialize(): Promise<void> {
BaseService.isInitialized = false
BaseService.initializationPromise = null
if (BaseService.client) {
try {
BaseService.client.close()
} catch (error) {
logger.warn('Failed to close client during reinitialize:', error as Error)
}
}
BaseService.client = null
BaseService.db = null
await BaseService.initialize()
}
} }

View File

@ -0,0 +1,156 @@
import { type Client, createClient } from '@libsql/client'
import { loggerService } from '@logger'
import type { LibSQLDatabase } from 'drizzle-orm/libsql'
import { drizzle } from 'drizzle-orm/libsql'
import fs from 'fs'
import path from 'path'
import { dbPath } from '../drizzle.config'
import { MigrationService } from './MigrationService'
import * as schema from './schema'
const logger = loggerService.withContext('DatabaseManager')
/**
* Database initialization state
*/
enum InitState {
INITIALIZING = 'initializing',
INITIALIZED = 'initialized',
FAILED = 'failed'
}
/**
* DatabaseManager - Singleton class for managing libsql database connections
*
* Responsibilities:
* - Single source of truth for database connection
* - Thread-safe initialization with state management
* - Automatic migration handling
* - Safe connection cleanup
* - Error recovery and retry logic
* - Windows platform compatibility fixes
*/
export class DatabaseManager {
private static instance: DatabaseManager | null = null
private client: Client | null = null
private db: LibSQLDatabase<typeof schema> | null = null
private state: InitState = InitState.INITIALIZING
/**
* Get the singleton instance (database initialization starts automatically)
*/
public static async getInstance(): Promise<DatabaseManager> {
if (DatabaseManager.instance) {
return DatabaseManager.instance
}
const instance = new DatabaseManager()
await instance.initialize()
DatabaseManager.instance = instance
return instance
}
/**
* Perform the actual initialization
*/
public async initialize(): Promise<void> {
if (this.state === InitState.INITIALIZED) {
return
}
try {
logger.info(`Initializing database at: ${dbPath}`)
// Ensure database directory exists
const dbDir = path.dirname(dbPath)
if (!fs.existsSync(dbDir)) {
logger.info(`Creating database directory: ${dbDir}`)
fs.mkdirSync(dbDir, { recursive: true })
}
// Check if database file is corrupted (Windows specific check)
if (fs.existsSync(dbPath)) {
const stats = fs.statSync(dbPath)
if (stats.size === 0) {
logger.warn('Database file is empty, removing corrupted file')
fs.unlinkSync(dbPath)
}
}
// Create client with platform-specific options
this.client = createClient({
url: `file:${dbPath}`,
// intMode: 'number' helps avoid some Windows compatibility issues
intMode: 'number'
})
// Create drizzle instance
this.db = drizzle(this.client, { schema })
// Run migrations
const migrationService = new MigrationService(this.db, this.client)
await migrationService.runMigrations()
this.state = InitState.INITIALIZED
logger.info('Database initialized successfully')
} catch (error) {
const err = error as Error
logger.error('Database initialization failed:', {
error: err.message,
stack: err.stack
})
// Clean up failed initialization
this.cleanupFailedInit()
// Set failed state
this.state = InitState.FAILED
throw new Error(`Database initialization failed: ${err.message || 'Unknown error'}`)
}
}
/**
* Clean up after failed initialization
*/
private cleanupFailedInit(): void {
if (this.client) {
try {
// On Windows, closing a partially initialized client can crash
// Wrap in try-catch and ignore errors during cleanup
this.client.close()
} catch (error) {
logger.warn('Failed to close client during cleanup:', error as Error)
}
}
this.client = null
this.db = null
}
/**
* Get the database instance
* Automatically waits for initialization to complete
* @throws Error if database initialization failed
*/
public getDatabase(): LibSQLDatabase<typeof schema> {
return this.db!
}
/**
* Get the raw client (for advanced operations)
* Automatically waits for initialization to complete
* @throws Error if database initialization failed
*/
public async getClient(): Promise<Client> {
return this.client!
}
/**
* Check if database is initialized
*/
public isInitialized(): boolean {
return this.state === InitState.INITIALIZED
}
}

View File

@ -7,8 +7,14 @@
* Schema evolution is handled by Drizzle Kit migrations. * Schema evolution is handled by Drizzle Kit migrations.
*/ */
// Database Manager (Singleton)
export * from './DatabaseManager'
// Drizzle ORM schemas // Drizzle ORM schemas
export * from './schema' export * from './schema'
// Repository helpers // Repository helpers
export * from './sessionMessageRepository' export * from './sessionMessageRepository'
// Migration Service
export * from './MigrationService'

View File

@ -15,26 +15,16 @@ import { sessionMessagesTable } from './schema'
const logger = loggerService.withContext('AgentMessageRepository') const logger = loggerService.withContext('AgentMessageRepository')
type TxClient = any
export type PersistUserMessageParams = AgentMessageUserPersistPayload & { export type PersistUserMessageParams = AgentMessageUserPersistPayload & {
sessionId: string sessionId: string
agentSessionId?: string agentSessionId?: string
tx?: TxClient
} }
export type PersistAssistantMessageParams = AgentMessageAssistantPersistPayload & { export type PersistAssistantMessageParams = AgentMessageAssistantPersistPayload & {
sessionId: string sessionId: string
agentSessionId: string agentSessionId: string
tx?: TxClient
} }
type PersistExchangeParams = AgentMessagePersistExchangePayload & {
tx?: TxClient
}
type PersistExchangeResult = AgentMessagePersistExchangeResult
class AgentMessageRepository extends BaseService { class AgentMessageRepository extends BaseService {
private static instance: AgentMessageRepository | null = null private static instance: AgentMessageRepository | null = null
@ -87,17 +77,13 @@ class AgentMessageRepository extends BaseService {
return deserialized return deserialized
} }
private getWriter(tx?: TxClient): TxClient {
return tx ?? this.database
}
private async findExistingMessageRow( private async findExistingMessageRow(
writer: TxClient,
sessionId: string, sessionId: string,
role: string, role: string,
messageId: string messageId: string
): Promise<SessionMessageRow | null> { ): Promise<SessionMessageRow | null> {
const candidateRows: SessionMessageRow[] = await writer const database = await this.getDatabase()
const candidateRows: SessionMessageRow[] = await database
.select() .select()
.from(sessionMessagesTable) .from(sessionMessagesTable)
.where(and(eq(sessionMessagesTable.session_id, sessionId), eq(sessionMessagesTable.role, role))) .where(and(eq(sessionMessagesTable.session_id, sessionId), eq(sessionMessagesTable.role, role)))
@ -122,10 +108,7 @@ class AgentMessageRepository extends BaseService {
private async upsertMessage( private async upsertMessage(
params: PersistUserMessageParams | PersistAssistantMessageParams params: PersistUserMessageParams | PersistAssistantMessageParams
): Promise<AgentSessionMessageEntity> { ): Promise<AgentSessionMessageEntity> {
await AgentMessageRepository.initialize() const { sessionId, agentSessionId = '', payload, metadata, createdAt } = params
this.ensureInitialized()
const { sessionId, agentSessionId = '', payload, metadata, createdAt, tx } = params
if (!payload?.message?.role) { if (!payload?.message?.role) {
throw new Error('Message payload missing role') throw new Error('Message payload missing role')
@ -135,18 +118,18 @@ class AgentMessageRepository extends BaseService {
throw new Error('Message payload missing id') throw new Error('Message payload missing id')
} }
const writer = this.getWriter(tx) const database = await this.getDatabase()
const now = createdAt ?? payload.message.createdAt ?? new Date().toISOString() const now = createdAt ?? payload.message.createdAt ?? new Date().toISOString()
const serializedPayload = this.serializeMessage(payload) const serializedPayload = this.serializeMessage(payload)
const serializedMetadata = this.serializeMetadata(metadata) const serializedMetadata = this.serializeMetadata(metadata)
const existingRow = await this.findExistingMessageRow(writer, sessionId, payload.message.role, payload.message.id) const existingRow = await this.findExistingMessageRow(sessionId, payload.message.role, payload.message.id)
if (existingRow) { if (existingRow) {
const metadataToPersist = serializedMetadata ?? existingRow.metadata ?? undefined const metadataToPersist = serializedMetadata ?? existingRow.metadata ?? undefined
const agentSessionToPersist = agentSessionId || existingRow.agent_session_id || '' const agentSessionToPersist = agentSessionId || existingRow.agent_session_id || ''
await writer await database
.update(sessionMessagesTable) .update(sessionMessagesTable)
.set({ .set({
content: serializedPayload, content: serializedPayload,
@ -175,7 +158,7 @@ class AgentMessageRepository extends BaseService {
updated_at: now updated_at: now
} }
const [saved] = await writer.insert(sessionMessagesTable).values(insertData).returning() const [saved] = await database.insert(sessionMessagesTable).values(insertData).returning()
return this.deserialize(saved) return this.deserialize(saved)
} }
@ -188,49 +171,38 @@ class AgentMessageRepository extends BaseService {
return this.upsertMessage(params) return this.upsertMessage(params)
} }
async persistExchange(params: PersistExchangeParams): Promise<PersistExchangeResult> { async persistExchange(params: AgentMessagePersistExchangePayload): Promise<AgentMessagePersistExchangeResult> {
await AgentMessageRepository.initialize()
this.ensureInitialized()
const { sessionId, agentSessionId, user, assistant } = params const { sessionId, agentSessionId, user, assistant } = params
const result = await this.database.transaction(async (tx) => { const exchangeResult: AgentMessagePersistExchangeResult = {}
const exchangeResult: PersistExchangeResult = {}
if (user?.payload) { if (user?.payload) {
exchangeResult.userMessage = await this.persistUserMessage({ exchangeResult.userMessage = await this.persistUserMessage({
sessionId, sessionId,
agentSessionId, agentSessionId,
payload: user.payload, payload: user.payload,
metadata: user.metadata, metadata: user.metadata,
createdAt: user.createdAt, createdAt: user.createdAt
tx })
}) }
}
if (assistant?.payload) { if (assistant?.payload) {
exchangeResult.assistantMessage = await this.persistAssistantMessage({ exchangeResult.assistantMessage = await this.persistAssistantMessage({
sessionId, sessionId,
agentSessionId, agentSessionId,
payload: assistant.payload, payload: assistant.payload,
metadata: assistant.metadata, metadata: assistant.metadata,
createdAt: assistant.createdAt, createdAt: assistant.createdAt
tx })
}) }
}
return exchangeResult return exchangeResult
})
return result
} }
async getSessionHistory(sessionId: string): Promise<AgentPersistedMessage[]> { async getSessionHistory(sessionId: string): Promise<AgentPersistedMessage[]> {
await AgentMessageRepository.initialize()
this.ensureInitialized()
try { try {
const rows = await this.database const database = await this.getDatabase()
const rows = await database
.select() .select()
.from(sessionMessagesTable) .from(sessionMessagesTable)
.where(eq(sessionMessagesTable.session_id, sessionId)) .where(eq(sessionMessagesTable.session_id, sessionId))

View File

@ -32,14 +32,8 @@ export class AgentService extends BaseService {
return AgentService.instance return AgentService.instance
} }
async initialize(): Promise<void> {
await BaseService.initialize()
}
// Agent Methods // Agent Methods
async createAgent(req: CreateAgentRequest): Promise<CreateAgentResponse> { async createAgent(req: CreateAgentRequest): Promise<CreateAgentResponse> {
this.ensureInitialized()
const id = `agent_${Date.now()}_${Math.random().toString(36).substring(2, 11)}` const id = `agent_${Date.now()}_${Math.random().toString(36).substring(2, 11)}`
const now = new Date().toISOString() const now = new Date().toISOString()
@ -75,8 +69,9 @@ export class AgentService extends BaseService {
updated_at: now updated_at: now
} }
await this.database.insert(agentsTable).values(insertData) const database = await this.getDatabase()
const result = await this.database.select().from(agentsTable).where(eq(agentsTable.id, id)).limit(1) await database.insert(agentsTable).values(insertData)
const result = await database.select().from(agentsTable).where(eq(agentsTable.id, id)).limit(1)
if (!result[0]) { if (!result[0]) {
throw new Error('Failed to create agent') throw new Error('Failed to create agent')
} }
@ -86,9 +81,8 @@ export class AgentService extends BaseService {
} }
async getAgent(id: string): Promise<GetAgentResponse | null> { async getAgent(id: string): Promise<GetAgentResponse | null> {
this.ensureInitialized() const database = await this.getDatabase()
const result = await database.select().from(agentsTable).where(eq(agentsTable.id, id)).limit(1)
const result = await this.database.select().from(agentsTable).where(eq(agentsTable.id, id)).limit(1)
if (!result[0]) { if (!result[0]) {
return null return null
@ -118,9 +112,9 @@ export class AgentService extends BaseService {
} }
async listAgents(options: ListOptions = {}): Promise<{ agents: AgentEntity[]; total: number }> { async listAgents(options: ListOptions = {}): Promise<{ agents: AgentEntity[]; total: number }> {
this.ensureInitialized() // Build query with pagination // Build query with pagination
const database = await this.getDatabase()
const totalResult = await this.database.select({ count: count() }).from(agentsTable) const totalResult = await database.select({ count: count() }).from(agentsTable)
const sortBy = options.sortBy || 'created_at' const sortBy = options.sortBy || 'created_at'
const orderBy = options.orderBy || 'desc' const orderBy = options.orderBy || 'desc'
@ -128,7 +122,7 @@ export class AgentService extends BaseService {
const sortField = agentsTable[sortBy] const sortField = agentsTable[sortBy]
const orderFn = orderBy === 'asc' ? asc : desc const orderFn = orderBy === 'asc' ? asc : desc
const baseQuery = this.database.select().from(agentsTable).orderBy(orderFn(sortField)) const baseQuery = database.select().from(agentsTable).orderBy(orderFn(sortField))
const result = const result =
options.limit !== undefined options.limit !== undefined
@ -151,8 +145,6 @@ export class AgentService extends BaseService {
updates: UpdateAgentRequest, updates: UpdateAgentRequest,
options: { replace?: boolean } = {} options: { replace?: boolean } = {}
): Promise<UpdateAgentResponse | null> { ): Promise<UpdateAgentResponse | null> {
this.ensureInitialized()
// Check if agent exists // Check if agent exists
const existing = await this.getAgent(id) const existing = await this.getAgent(id)
if (!existing) { if (!existing) {
@ -195,22 +187,21 @@ export class AgentService extends BaseService {
} }
} }
await this.database.update(agentsTable).set(updateData).where(eq(agentsTable.id, id)) const database = await this.getDatabase()
await database.update(agentsTable).set(updateData).where(eq(agentsTable.id, id))
return await this.getAgent(id) return await this.getAgent(id)
} }
async deleteAgent(id: string): Promise<boolean> { async deleteAgent(id: string): Promise<boolean> {
this.ensureInitialized() const database = await this.getDatabase()
const result = await database.delete(agentsTable).where(eq(agentsTable.id, id))
const result = await this.database.delete(agentsTable).where(eq(agentsTable.id, id))
return result.rowsAffected > 0 return result.rowsAffected > 0
} }
async agentExists(id: string): Promise<boolean> { async agentExists(id: string): Promise<boolean> {
this.ensureInitialized() const database = await this.getDatabase()
const result = await database
const result = await this.database
.select({ id: agentsTable.id }) .select({ id: agentsTable.id })
.from(agentsTable) .from(agentsTable)
.where(eq(agentsTable.id, id)) .where(eq(agentsTable.id, id))

View File

@ -104,14 +104,9 @@ export class SessionMessageService extends BaseService {
return SessionMessageService.instance return SessionMessageService.instance
} }
async initialize(): Promise<void> {
await BaseService.initialize()
}
async sessionMessageExists(id: number): Promise<boolean> { async sessionMessageExists(id: number): Promise<boolean> {
this.ensureInitialized() const database = await this.getDatabase()
const result = await database
const result = await this.database
.select({ id: sessionMessagesTable.id }) .select({ id: sessionMessagesTable.id })
.from(sessionMessagesTable) .from(sessionMessagesTable)
.where(eq(sessionMessagesTable.id, id)) .where(eq(sessionMessagesTable.id, id))
@ -124,10 +119,9 @@ export class SessionMessageService extends BaseService {
sessionId: string, sessionId: string,
options: ListOptions = {} options: ListOptions = {}
): Promise<{ messages: AgentSessionMessageEntity[] }> { ): Promise<{ messages: AgentSessionMessageEntity[] }> {
this.ensureInitialized()
// Get messages with pagination // Get messages with pagination
const baseQuery = this.database const database = await this.getDatabase()
const baseQuery = database
.select() .select()
.from(sessionMessagesTable) .from(sessionMessagesTable)
.where(eq(sessionMessagesTable.session_id, sessionId)) .where(eq(sessionMessagesTable.session_id, sessionId))
@ -146,9 +140,8 @@ export class SessionMessageService extends BaseService {
} }
async deleteSessionMessage(sessionId: string, messageId: number): Promise<boolean> { async deleteSessionMessage(sessionId: string, messageId: number): Promise<boolean> {
this.ensureInitialized() const database = await this.getDatabase()
const result = await database
const result = await this.database
.delete(sessionMessagesTable) .delete(sessionMessagesTable)
.where(and(eq(sessionMessagesTable.id, messageId), eq(sessionMessagesTable.session_id, sessionId))) .where(and(eq(sessionMessagesTable.id, messageId), eq(sessionMessagesTable.session_id, sessionId)))
@ -160,8 +153,6 @@ export class SessionMessageService extends BaseService {
messageData: CreateSessionMessageRequest, messageData: CreateSessionMessageRequest,
abortController: AbortController abortController: AbortController
): Promise<SessionStreamResult> { ): Promise<SessionStreamResult> {
this.ensureInitialized()
return await this.startSessionMessageStream(session, messageData, abortController) return await this.startSessionMessageStream(session, messageData, abortController)
} }
@ -270,10 +261,9 @@ export class SessionMessageService extends BaseService {
} }
private async getLastAgentSessionId(sessionId: string): Promise<string> { private async getLastAgentSessionId(sessionId: string): Promise<string> {
this.ensureInitialized()
try { try {
const result = await this.database const database = await this.getDatabase()
const result = await database
.select({ agent_session_id: sessionMessagesTable.agent_session_id }) .select({ agent_session_id: sessionMessagesTable.agent_session_id })
.from(sessionMessagesTable) .from(sessionMessagesTable)
.where(and(eq(sessionMessagesTable.session_id, sessionId), not(eq(sessionMessagesTable.agent_session_id, '')))) .where(and(eq(sessionMessagesTable.session_id, sessionId), not(eq(sessionMessagesTable.agent_session_id, ''))))

View File

@ -31,10 +31,6 @@ export class SessionService extends BaseService {
return SessionService.instance return SessionService.instance
} }
async initialize(): Promise<void> {
await BaseService.initialize()
}
/** /**
* Override BaseService.listSlashCommands to merge builtin and plugin commands * Override BaseService.listSlashCommands to merge builtin and plugin commands
*/ */
@ -85,13 +81,12 @@ export class SessionService extends BaseService {
agentId: string, agentId: string,
req: Partial<CreateSessionRequest> = {} req: Partial<CreateSessionRequest> = {}
): Promise<GetAgentSessionResponse | null> { ): Promise<GetAgentSessionResponse | null> {
this.ensureInitialized()
// Validate agent exists - we'll need to import AgentService for this check // Validate agent exists - we'll need to import AgentService for this check
// For now, we'll skip this validation to avoid circular dependencies // For now, we'll skip this validation to avoid circular dependencies
// The database foreign key constraint will handle this // The database foreign key constraint will handle this
const agents = await this.database.select().from(agentsTable).where(eq(agentsTable.id, agentId)).limit(1) const database = await this.getDatabase()
const agents = await database.select().from(agentsTable).where(eq(agentsTable.id, agentId)).limit(1)
if (!agents[0]) { if (!agents[0]) {
throw new Error('Agent not found') throw new Error('Agent not found')
} }
@ -136,9 +131,10 @@ export class SessionService extends BaseService {
updated_at: now updated_at: now
} }
await this.database.insert(sessionsTable).values(insertData) const db = await this.getDatabase()
await db.insert(sessionsTable).values(insertData)
const result = await this.database.select().from(sessionsTable).where(eq(sessionsTable.id, id)).limit(1) const result = await db.select().from(sessionsTable).where(eq(sessionsTable.id, id)).limit(1)
if (!result[0]) { if (!result[0]) {
throw new Error('Failed to create session') throw new Error('Failed to create session')
@ -149,9 +145,8 @@ export class SessionService extends BaseService {
} }
async getSession(agentId: string, id: string): Promise<GetAgentSessionResponse | null> { async getSession(agentId: string, id: string): Promise<GetAgentSessionResponse | null> {
this.ensureInitialized() const database = await this.getDatabase()
const result = await database
const result = await this.database
.select() .select()
.from(sessionsTable) .from(sessionsTable)
.where(and(eq(sessionsTable.id, id), eq(sessionsTable.agent_id, agentId))) .where(and(eq(sessionsTable.id, id), eq(sessionsTable.agent_id, agentId)))
@ -177,8 +172,6 @@ export class SessionService extends BaseService {
agentId?: string, agentId?: string,
options: ListOptions = {} options: ListOptions = {}
): Promise<{ sessions: AgentSessionEntity[]; total: number }> { ): Promise<{ sessions: AgentSessionEntity[]; total: number }> {
this.ensureInitialized()
// Build where conditions // Build where conditions
const whereConditions: SQL[] = [] const whereConditions: SQL[] = []
if (agentId) { if (agentId) {
@ -193,16 +186,13 @@ export class SessionService extends BaseService {
: undefined : undefined
// Get total count // Get total count
const totalResult = await this.database.select({ count: count() }).from(sessionsTable).where(whereClause) const database = await this.getDatabase()
const totalResult = await database.select({ count: count() }).from(sessionsTable).where(whereClause)
const total = totalResult[0].count const total = totalResult[0].count
// Build list query with pagination - sort by updated_at descending (latest first) // Build list query with pagination - sort by updated_at descending (latest first)
const baseQuery = this.database const baseQuery = database.select().from(sessionsTable).where(whereClause).orderBy(desc(sessionsTable.updated_at))
.select()
.from(sessionsTable)
.where(whereClause)
.orderBy(desc(sessionsTable.updated_at))
const result = const result =
options.limit !== undefined options.limit !== undefined
@ -221,8 +211,6 @@ export class SessionService extends BaseService {
id: string, id: string,
updates: UpdateSessionRequest updates: UpdateSessionRequest
): Promise<UpdateSessionResponse | null> { ): Promise<UpdateSessionResponse | null> {
this.ensureInitialized()
// Check if session exists // Check if session exists
const existing = await this.getSession(agentId, id) const existing = await this.getSession(agentId, id)
if (!existing) { if (!existing) {
@ -263,15 +251,15 @@ export class SessionService extends BaseService {
} }
} }
await this.database.update(sessionsTable).set(updateData).where(eq(sessionsTable.id, id)) const database = await this.getDatabase()
await database.update(sessionsTable).set(updateData).where(eq(sessionsTable.id, id))
return await this.getSession(agentId, id) return await this.getSession(agentId, id)
} }
async deleteSession(agentId: string, id: string): Promise<boolean> { async deleteSession(agentId: string, id: string): Promise<boolean> {
this.ensureInitialized() const database = await this.getDatabase()
const result = await database
const result = await this.database
.delete(sessionsTable) .delete(sessionsTable)
.where(and(eq(sessionsTable.id, id), eq(sessionsTable.agent_id, agentId))) .where(and(eq(sessionsTable.id, id), eq(sessionsTable.agent_id, agentId)))
@ -279,9 +267,8 @@ export class SessionService extends BaseService {
} }
async sessionExists(agentId: string, id: string): Promise<boolean> { async sessionExists(agentId: string, id: string): Promise<boolean> {
this.ensureInitialized() const database = await this.getDatabase()
const result = await database
const result = await this.database
.select({ id: sessionsTable.id }) .select({ id: sessionsTable.id })
.from(sessionsTable) .from(sessionsTable)
.where(and(eq(sessionsTable.id, id), eq(sessionsTable.agent_id, agentId))) .where(and(eq(sessionsTable.id, id), eq(sessionsTable.agent_id, agentId)))

View File

@ -2,7 +2,14 @@
import { EventEmitter } from 'node:events' import { EventEmitter } from 'node:events'
import { createRequire } from 'node:module' import { createRequire } from 'node:module'
import type { CanUseTool, McpHttpServerConfig, Options, SDKMessage } from '@anthropic-ai/claude-agent-sdk' import type {
CanUseTool,
HookCallback,
McpHttpServerConfig,
Options,
PreToolUseHookInput,
SDKMessage
} from '@anthropic-ai/claude-agent-sdk'
import { query } from '@anthropic-ai/claude-agent-sdk' import { query } from '@anthropic-ai/claude-agent-sdk'
import { loggerService } from '@logger' import { loggerService } from '@logger'
import { config as apiConfigService } from '@main/apiServer/config' import { config as apiConfigService } from '@main/apiServer/config'
@ -157,6 +164,63 @@ class ClaudeCodeService implements AgentServiceInterface {
}) })
} }
const preToolUseHook: HookCallback = async (input, toolUseID, options) => {
// Type guard to ensure we're handling PreToolUse event
if (input.hook_event_name !== 'PreToolUse') {
return {}
}
const hookInput = input as PreToolUseHookInput
const toolName = hookInput.tool_name
logger.debug('PreToolUse hook triggered', {
session_id: hookInput.session_id,
tool_name: hookInput.tool_name,
tool_use_id: toolUseID,
tool_input: hookInput.tool_input,
cwd: hookInput.cwd,
permission_mode: hookInput.permission_mode,
autoAllowTools: autoAllowTools
})
if (options?.signal?.aborted) {
logger.debug('PreToolUse hook signal already aborted; skipping tool use', {
tool_name: hookInput.tool_name
})
return {}
}
// handle auto approved tools since it never triggers canUseTool
const normalizedToolName = normalizeToolName(toolName)
if (toolUseID) {
const bypassAll = input.permission_mode === 'bypassPermissions'
const autoAllowed = autoAllowTools.has(toolName) || autoAllowTools.has(normalizedToolName)
if (bypassAll || autoAllowed) {
const namespacedToolCallId = buildNamespacedToolCallId(session.id, toolUseID)
logger.debug('handling auto approved tools', {
toolName,
normalizedToolName,
namespacedToolCallId,
permission_mode: input.permission_mode,
autoAllowTools
})
const isRecord = (v: unknown): v is Record<string, unknown> => {
return !!v && typeof v === 'object' && !Array.isArray(v)
}
const toolInput = isRecord(input.tool_input) ? input.tool_input : {}
await promptForToolApproval(toolName, toolInput, {
...options,
toolCallId: namespacedToolCallId,
autoApprove: true
})
}
}
// Return to proceed without modification
return {}
}
// Build SDK options from parameters // Build SDK options from parameters
const options: Options = { const options: Options = {
abortController, abortController,
@ -180,7 +244,14 @@ class ClaudeCodeService implements AgentServiceInterface {
permissionMode: session.configuration?.permission_mode, permissionMode: session.configuration?.permission_mode,
maxTurns: session.configuration?.max_turns, maxTurns: session.configuration?.max_turns,
allowedTools: session.allowed_tools, allowedTools: session.allowed_tools,
canUseTool canUseTool,
hooks: {
PreToolUse: [
{
hooks: [preToolUseHook]
}
]
}
} }
if (session.accessible_paths.length > 1) { if (session.accessible_paths.length > 1) {

View File

@ -31,6 +31,7 @@ type PendingPermissionRequest = {
abortListener?: () => void abortListener?: () => void
originalInput: Record<string, unknown> originalInput: Record<string, unknown>
toolName: string toolName: string
toolCallId?: string
} }
type RendererPermissionRequestPayload = { type RendererPermissionRequestPayload = {
@ -45,6 +46,7 @@ type RendererPermissionRequestPayload = {
createdAt: number createdAt: number
expiresAt: number expiresAt: number
suggestions: PermissionUpdate[] suggestions: PermissionUpdate[]
autoApprove?: boolean
} }
type RendererPermissionResultPayload = { type RendererPermissionResultPayload = {
@ -52,6 +54,7 @@ type RendererPermissionResultPayload = {
behavior: ToolPermissionBehavior behavior: ToolPermissionBehavior
message?: string message?: string
reason: 'response' | 'timeout' | 'aborted' | 'no-window' reason: 'response' | 'timeout' | 'aborted' | 'no-window'
toolCallId?: string
} }
const pendingRequests = new Map<string, PendingPermissionRequest>() const pendingRequests = new Map<string, PendingPermissionRequest>()
@ -145,7 +148,8 @@ const finalizeRequest = (
requestId, requestId,
behavior: update.behavior, behavior: update.behavior,
message: update.behavior === 'deny' ? update.message : undefined, message: update.behavior === 'deny' ? update.message : undefined,
reason reason,
toolCallId: pending.toolCallId
} }
const dispatched = broadcastToRenderer(IpcChannel.AgentToolPermission_Result, resultPayload) const dispatched = broadcastToRenderer(IpcChannel.AgentToolPermission_Result, resultPayload)
@ -210,6 +214,7 @@ const ensureIpcHandlersRegistered = () => {
type PromptForToolApprovalOptions = { type PromptForToolApprovalOptions = {
signal: AbortSignal signal: AbortSignal
suggestions?: PermissionUpdate[] suggestions?: PermissionUpdate[]
autoApprove?: boolean
// NOTICE: This ID is namespaced with session ID, not the raw SDK tool call ID. // NOTICE: This ID is namespaced with session ID, not the raw SDK tool call ID.
// Format: `${sessionId}:${rawToolCallId}`, e.g., `session_123:WebFetch_0` // Format: `${sessionId}:${rawToolCallId}`, e.g., `session_123:WebFetch_0`
@ -270,7 +275,8 @@ export async function promptForToolApproval(
inputPreview, inputPreview,
createdAt, createdAt,
expiresAt, expiresAt,
suggestions: sanitizedSuggestions suggestions: sanitizedSuggestions,
autoApprove: options.autoApprove
} }
const defaultDenyUpdate: PermissionResult = { behavior: 'deny', message: 'Tool request aborted before user decision' } const defaultDenyUpdate: PermissionResult = { behavior: 'deny', message: 'Tool request aborted before user decision' }
@ -299,7 +305,8 @@ export async function promptForToolApproval(
timeout, timeout,
originalInput: sanitizedInput, originalInput: sanitizedInput,
toolName, toolName,
signal: options?.signal signal: options?.signal,
toolCallId: options.toolCallId
} }
if (options?.signal) { if (options?.signal) {

View File

@ -385,14 +385,13 @@ export class AiSdkToChunkAdapter {
case 'error': case 'error':
this.onChunk({ this.onChunk({
type: ChunkType.ERROR, type: ChunkType.ERROR,
error: error: AISDKError.isInstance(chunk.error)
chunk.error instanceof AISDKError ? chunk.error
? chunk.error : new ProviderSpecificError({
: new ProviderSpecificError({ message: formatErrorMessage(chunk.error),
message: formatErrorMessage(chunk.error), provider: 'unknown',
provider: 'unknown', cause: chunk.error
cause: chunk.error })
})
}) })
break break

View File

@ -32,6 +32,7 @@ import {
prepareSpecialProviderConfig, prepareSpecialProviderConfig,
providerToAiSdkConfig providerToAiSdkConfig
} from './provider/providerConfig' } from './provider/providerConfig'
import type { AiSdkConfig } from './types'
const logger = loggerService.withContext('ModernAiProvider') const logger = loggerService.withContext('ModernAiProvider')
@ -44,7 +45,7 @@ export type ModernAiProviderConfig = AiSdkMiddlewareConfig & {
export default class ModernAiProvider { export default class ModernAiProvider {
private legacyProvider: LegacyAiProvider private legacyProvider: LegacyAiProvider
private config?: ReturnType<typeof providerToAiSdkConfig> private config?: AiSdkConfig
private actualProvider: Provider private actualProvider: Provider
private model?: Model private model?: Model
private localProvider: Awaited<AiSdkProvider> | null = null private localProvider: Awaited<AiSdkProvider> | null = null
@ -89,6 +90,11 @@ export default class ModernAiProvider {
// 每次请求时重新生成配置以确保API key轮换生效 // 每次请求时重新生成配置以确保API key轮换生效
this.config = providerToAiSdkConfig(this.actualProvider, this.model) this.config = providerToAiSdkConfig(this.actualProvider, this.model)
logger.debug('Generated provider config for completions', this.config) logger.debug('Generated provider config for completions', this.config)
// 检查 config 是否存在
if (!this.config) {
throw new Error('Provider config is undefined; cannot proceed with completions')
}
if (SUPPORTED_IMAGE_ENDPOINT_LIST.includes(this.config.options.endpoint)) { if (SUPPORTED_IMAGE_ENDPOINT_LIST.includes(this.config.options.endpoint)) {
providerConfig.isImageGenerationEndpoint = true providerConfig.isImageGenerationEndpoint = true
} }
@ -149,7 +155,8 @@ export default class ModernAiProvider {
params: StreamTextParams, params: StreamTextParams,
config: ModernAiProviderConfig config: ModernAiProviderConfig
): Promise<CompletionsResult> { ): Promise<CompletionsResult> {
if (config.isImageGenerationEndpoint) { // ai-gateway不是image/generation 端点所以就先不走legacy了
if (config.isImageGenerationEndpoint && config.provider!.id !== SystemProviderIds['ai-gateway']) {
// 使用 legacy 实现处理图像生成(支持图片编辑等高级功能) // 使用 legacy 实现处理图像生成(支持图片编辑等高级功能)
if (!config.uiMessages) { if (!config.uiMessages) {
throw new Error('uiMessages is required for image generation endpoint') throw new Error('uiMessages is required for image generation endpoint')
@ -463,8 +470,13 @@ export default class ModernAiProvider {
// 如果支持新的 AI SDK使用现代化实现 // 如果支持新的 AI SDK使用现代化实现
if (isModernSdkSupported(this.actualProvider)) { if (isModernSdkSupported(this.actualProvider)) {
try { try {
// 确保 config 已定义
if (!this.config) {
throw new Error('Provider config is undefined; cannot proceed with generateImage')
}
// 确保本地provider已创建 // 确保本地provider已创建
if (!this.localProvider) { if (!this.localProvider && this.config) {
this.localProvider = await createAiSdkProvider(this.config) this.localProvider = await createAiSdkProvider(this.config)
if (!this.localProvider) { if (!this.localProvider) {
throw new Error('Local provider not created') throw new Error('Local provider not created')

View File

@ -1,6 +1,6 @@
import { loggerService } from '@logger' import { loggerService } from '@logger'
import { isNewApiProvider } from '@renderer/config/providers'
import type { Provider } from '@renderer/types' import type { Provider } from '@renderer/types'
import { isNewApiProvider } from '@renderer/utils/provider'
import { AihubmixAPIClient } from './aihubmix/AihubmixAPIClient' import { AihubmixAPIClient } from './aihubmix/AihubmixAPIClient'
import { AnthropicAPIClient } from './anthropic/AnthropicAPIClient' import { AnthropicAPIClient } from './anthropic/AnthropicAPIClient'

View File

@ -7,7 +7,6 @@ import {
isOpenAIModel, isOpenAIModel,
isSupportFlexServiceTierModel isSupportFlexServiceTierModel
} from '@renderer/config/models' } from '@renderer/config/models'
import { isSupportServiceTierProvider } from '@renderer/config/providers'
import { getLMStudioKeepAliveTime } from '@renderer/hooks/useLMStudio' import { getLMStudioKeepAliveTime } from '@renderer/hooks/useLMStudio'
import { getAssistantSettings } from '@renderer/services/AssistantService' import { getAssistantSettings } from '@renderer/services/AssistantService'
import type { import type {
@ -19,7 +18,6 @@ import type {
MCPToolResponse, MCPToolResponse,
MemoryItem, MemoryItem,
Model, Model,
OpenAIVerbosity,
Provider, Provider,
ToolCallResponse, ToolCallResponse,
WebSearchProviderResponse, WebSearchProviderResponse,
@ -33,6 +31,7 @@ import {
OpenAIServiceTiers, OpenAIServiceTiers,
SystemProviderIds SystemProviderIds
} from '@renderer/types' } from '@renderer/types'
import type { OpenAIVerbosity } from '@renderer/types/aiCoreTypes'
import type { Message } from '@renderer/types/newMessage' import type { Message } from '@renderer/types/newMessage'
import type { import type {
RequestOptions, RequestOptions,
@ -48,6 +47,7 @@ import type {
import { isJSON, parseJSON } from '@renderer/utils' import { isJSON, parseJSON } from '@renderer/utils'
import { addAbortController, removeAbortController } from '@renderer/utils/abortController' import { addAbortController, removeAbortController } from '@renderer/utils/abortController'
import { findFileBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find' import { findFileBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
import { isSupportServiceTierProvider } from '@renderer/utils/provider'
import { defaultTimeout } from '@shared/config/constant' import { defaultTimeout } from '@shared/config/constant'
import { REFERENCE_PROMPT } from '@shared/config/prompts' import { REFERENCE_PROMPT } from '@shared/config/prompts'
import { defaultAppHeaders } from '@shared/utils' import { defaultAppHeaders } from '@shared/utils'

View File

@ -58,10 +58,27 @@ vi.mock('../aws/AwsBedrockAPIClient', () => ({
AwsBedrockAPIClient: vi.fn().mockImplementation(() => ({})) AwsBedrockAPIClient: vi.fn().mockImplementation(() => ({}))
})) }))
vi.mock('@renderer/services/AssistantService.ts', () => ({
getDefaultAssistant: () => {
return {
id: 'default',
name: 'default',
emoji: '😀',
prompt: '',
topics: [],
messages: [],
type: 'assistant',
regularPhrases: [],
settings: {}
}
}
}))
// Mock the models config to prevent circular dependency issues // Mock the models config to prevent circular dependency issues
vi.mock('@renderer/config/models', () => ({ vi.mock('@renderer/config/models', () => ({
findTokenLimit: vi.fn(), findTokenLimit: vi.fn(),
isReasoningModel: vi.fn(), isReasoningModel: vi.fn(),
isOpenAILLMModel: vi.fn(),
SYSTEM_MODELS: { SYSTEM_MODELS: {
silicon: [], silicon: [],
defaultModel: [] defaultModel: []

View File

@ -1,7 +1,8 @@
import { GoogleGenAI } from '@google/genai' import { GoogleGenAI } from '@google/genai'
import { loggerService } from '@logger' import { loggerService } from '@logger'
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI' import { createVertexProvider, isVertexAIConfigured } from '@renderer/hooks/useVertexAI'
import type { Model, Provider, VertexProvider } from '@renderer/types' import type { Model, Provider, VertexProvider } from '@renderer/types'
import { isVertexProvider } from '@renderer/utils/provider'
import { isEmpty } from 'lodash' import { isEmpty } from 'lodash'
import { AnthropicVertexClient } from '../anthropic/AnthropicVertexClient' import { AnthropicVertexClient } from '../anthropic/AnthropicVertexClient'

View File

@ -10,7 +10,6 @@ import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
import { import {
findTokenLimit, findTokenLimit,
GEMINI_FLASH_MODEL_REGEX, GEMINI_FLASH_MODEL_REGEX,
getOpenAIWebSearchParams,
getThinkModelType, getThinkModelType,
isClaudeReasoningModel, isClaudeReasoningModel,
isDeepSeekHybridInferenceModel, isDeepSeekHybridInferenceModel,
@ -40,12 +39,6 @@ import {
MODEL_SUPPORTED_REASONING_EFFORT, MODEL_SUPPORTED_REASONING_EFFORT,
ZHIPU_RESULT_TOKENS ZHIPU_RESULT_TOKENS
} from '@renderer/config/models' } from '@renderer/config/models'
import {
isSupportArrayContentProvider,
isSupportDeveloperRoleProvider,
isSupportEnableThinkingProvider,
isSupportStreamOptionsProvider
} from '@renderer/config/providers'
import { mapLanguageToQwenMTModel } from '@renderer/config/translate' import { mapLanguageToQwenMTModel } from '@renderer/config/translate'
import { processPostsuffixQwen3Model, processReqMessages } from '@renderer/services/ModelMessageService' import { processPostsuffixQwen3Model, processReqMessages } from '@renderer/services/ModelMessageService'
import { estimateTextTokens } from '@renderer/services/TokenService' import { estimateTextTokens } from '@renderer/services/TokenService'
@ -89,6 +82,12 @@ import {
openAIToolsToMcpTool openAIToolsToMcpTool
} from '@renderer/utils/mcp-tools' } from '@renderer/utils/mcp-tools'
import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find' import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find'
import {
isSupportArrayContentProvider,
isSupportDeveloperRoleProvider,
isSupportEnableThinkingProvider,
isSupportStreamOptionsProvider
} from '@renderer/utils/provider'
import { t } from 'i18next' import { t } from 'i18next'
import type { GenericChunk } from '../../middleware/schemas' import type { GenericChunk } from '../../middleware/schemas'
@ -743,7 +742,7 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
: {}), : {}),
...this.getProviderSpecificParameters(assistant, model), ...this.getProviderSpecificParameters(assistant, model),
...reasoningEffort, ...reasoningEffort,
...getOpenAIWebSearchParams(model, enableWebSearch), // ...getOpenAIWebSearchParams(model, enableWebSearch),
// OpenRouter usage tracking // OpenRouter usage tracking
...(this.provider.id === 'openrouter' ? { usage: { include: true } } : {}), ...(this.provider.id === 'openrouter' ? { usage: { include: true } } : {}),
...extra_body, ...extra_body,

View File

@ -12,7 +12,6 @@ import {
isSupportVerbosityModel, isSupportVerbosityModel,
isVisionModel isVisionModel
} from '@renderer/config/models' } from '@renderer/config/models'
import { isSupportDeveloperRoleProvider } from '@renderer/config/providers'
import { estimateTextTokens } from '@renderer/services/TokenService' import { estimateTextTokens } from '@renderer/services/TokenService'
import type { import type {
FileMetadata, FileMetadata,
@ -43,6 +42,7 @@ import {
openAIToolsToMcpTool openAIToolsToMcpTool
} from '@renderer/utils/mcp-tools' } from '@renderer/utils/mcp-tools'
import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find' import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find'
import { isSupportDeveloperRoleProvider } from '@renderer/utils/provider'
import { MB } from '@shared/config/constant' import { MB } from '@shared/config/constant'
import { t } from 'i18next' import { t } from 'i18next'
import { isEmpty } from 'lodash' import { isEmpty } from 'lodash'

View File

@ -1,6 +1,7 @@
import { loggerService } from '@logger' import { loggerService } from '@logger'
import { isZhipuModel } from '@renderer/config/models' import { isZhipuModel } from '@renderer/config/models'
import { getStoreProviders } from '@renderer/hooks/useStore' import { getStoreProviders } from '@renderer/hooks/useStore'
import { getDefaultModel } from '@renderer/services/AssistantService'
import type { Chunk } from '@renderer/types/chunk' import type { Chunk } from '@renderer/types/chunk'
import type { CompletionsParams, CompletionsResult } from '../schemas' import type { CompletionsParams, CompletionsResult } from '../schemas'
@ -66,7 +67,7 @@ export const ErrorHandlerMiddleware =
} }
function handleError(error: any, params: CompletionsParams): any { function handleError(error: any, params: CompletionsParams): any {
if (isZhipuModel(params.assistant.model) && error.status && !params.enableGenerateImage) { if (isZhipuModel(params.assistant.model || getDefaultModel()) && error.status && !params.enableGenerateImage) {
return handleZhipuError(error) return handleZhipuError(error)
} }

View File

@ -1,10 +1,10 @@
import type { WebSearchPluginConfig } from '@cherrystudio/ai-core/built-in/plugins' import type { WebSearchPluginConfig } from '@cherrystudio/ai-core/built-in/plugins'
import { loggerService } from '@logger' import { loggerService } from '@logger'
import { isSupportedThinkingTokenQwenModel } from '@renderer/config/models' import { isSupportedThinkingTokenQwenModel } from '@renderer/config/models'
import { isSupportEnableThinkingProvider } from '@renderer/config/providers'
import type { MCPTool } from '@renderer/types' import type { MCPTool } from '@renderer/types'
import { type Assistant, type Message, type Model, type Provider } from '@renderer/types' import { type Assistant, type Message, type Model, type Provider, SystemProviderIds } from '@renderer/types'
import type { Chunk } from '@renderer/types/chunk' import type { Chunk } from '@renderer/types/chunk'
import { isSupportEnableThinkingProvider } from '@renderer/utils/provider'
import type { LanguageModelMiddleware } from 'ai' import type { LanguageModelMiddleware } from 'ai'
import { extractReasoningMiddleware, simulateStreamingMiddleware } from 'ai' import { extractReasoningMiddleware, simulateStreamingMiddleware } from 'ai'
import { isEmpty } from 'lodash' import { isEmpty } from 'lodash'
@ -12,6 +12,7 @@ import { isEmpty } from 'lodash'
import { isOpenRouterGeminiGenerateImageModel } from '../utils/image' import { isOpenRouterGeminiGenerateImageModel } from '../utils/image'
import { noThinkMiddleware } from './noThinkMiddleware' import { noThinkMiddleware } from './noThinkMiddleware'
import { openrouterGenerateImageMiddleware } from './openrouterGenerateImageMiddleware' import { openrouterGenerateImageMiddleware } from './openrouterGenerateImageMiddleware'
import { openrouterReasoningMiddleware } from './openrouterReasoningMiddleware'
import { qwenThinkingMiddleware } from './qwenThinkingMiddleware' import { qwenThinkingMiddleware } from './qwenThinkingMiddleware'
import { toolChoiceMiddleware } from './toolChoiceMiddleware' import { toolChoiceMiddleware } from './toolChoiceMiddleware'
@ -217,6 +218,14 @@ function addProviderSpecificMiddlewares(builder: AiSdkMiddlewareBuilder, config:
middleware: noThinkMiddleware() middleware: noThinkMiddleware()
}) })
} }
if (config.provider.id === SystemProviderIds.openrouter && config.enableReasoning) {
builder.add({
name: 'openrouter-reasoning-redaction',
middleware: openrouterReasoningMiddleware()
})
logger.debug('Added OpenRouter reasoning redaction middleware')
}
} }
/** /**

View File

@ -0,0 +1,50 @@
import type { LanguageModelV2StreamPart } from '@ai-sdk/provider'
import type { LanguageModelMiddleware } from 'ai'
/**
* https://openrouter.ai/docs/docs/best-practices/reasoning-tokens#example-preserving-reasoning-blocks-with-openrouter-and-claude
*
* @returns LanguageModelMiddleware - a middleware filter redacted block
*/
export function openrouterReasoningMiddleware(): LanguageModelMiddleware {
const REDACTED_BLOCK = '[REDACTED]'
return {
middlewareVersion: 'v2',
wrapGenerate: async ({ doGenerate }) => {
const { content, ...rest } = await doGenerate()
const modifiedContent = content.map((part) => {
if (part.type === 'reasoning' && part.text.includes(REDACTED_BLOCK)) {
return {
...part,
text: part.text.replace(REDACTED_BLOCK, '')
}
}
return part
})
return { content: modifiedContent, ...rest }
},
wrapStream: async ({ doStream }) => {
const { stream, ...rest } = await doStream()
return {
stream: stream.pipeThrough(
new TransformStream<LanguageModelV2StreamPart, LanguageModelV2StreamPart>({
transform(
chunk: LanguageModelV2StreamPart,
controller: TransformStreamDefaultController<LanguageModelV2StreamPart>
) {
if (chunk.type === 'reasoning-delta' && chunk.delta.includes(REDACTED_BLOCK)) {
controller.enqueue({
...chunk,
delta: chunk.delta.replace(REDACTED_BLOCK, '')
})
} else {
controller.enqueue(chunk)
}
}
})
),
...rest
}
}
}
}

View File

@ -0,0 +1,234 @@
import type { Message, Model } from '@renderer/types'
import type { FileMetadata } from '@renderer/types/file'
import { FileTypes } from '@renderer/types/file'
import {
AssistantMessageStatus,
type FileMessageBlock,
type ImageMessageBlock,
MessageBlockStatus,
MessageBlockType,
type ThinkingMessageBlock,
UserMessageStatus
} from '@renderer/types/newMessage'
import { beforeEach, describe, expect, it, vi } from 'vitest'
const { convertFileBlockToFilePartMock, convertFileBlockToTextPartMock } = vi.hoisted(() => ({
convertFileBlockToFilePartMock: vi.fn(),
convertFileBlockToTextPartMock: vi.fn()
}))
vi.mock('../fileProcessor', () => ({
convertFileBlockToFilePart: convertFileBlockToFilePartMock,
convertFileBlockToTextPart: convertFileBlockToTextPartMock
}))
const visionModelIds = new Set(['gpt-4o-mini', 'qwen-image-edit'])
const imageEnhancementModelIds = new Set(['qwen-image-edit'])
vi.mock('@renderer/config/models', () => ({
isVisionModel: (model: Model) => visionModelIds.has(model.id),
isImageEnhancementModel: (model: Model) => imageEnhancementModelIds.has(model.id)
}))
type MockableMessage = Message & {
__mockContent?: string
__mockFileBlocks?: FileMessageBlock[]
__mockImageBlocks?: ImageMessageBlock[]
__mockThinkingBlocks?: ThinkingMessageBlock[]
}
vi.mock('@renderer/utils/messageUtils/find', () => ({
getMainTextContent: (message: Message) => (message as MockableMessage).__mockContent ?? '',
findFileBlocks: (message: Message) => (message as MockableMessage).__mockFileBlocks ?? [],
findImageBlocks: (message: Message) => (message as MockableMessage).__mockImageBlocks ?? [],
findThinkingBlocks: (message: Message) => (message as MockableMessage).__mockThinkingBlocks ?? []
}))
import { convertMessagesToSdkMessages, convertMessageToSdkParam } from '../messageConverter'
let messageCounter = 0
let blockCounter = 0
const createModel = (overrides: Partial<Model> = {}): Model => ({
id: 'gpt-4o-mini',
name: 'GPT-4o mini',
provider: 'openai',
group: 'openai',
...overrides
})
const createMessage = (role: Message['role']): MockableMessage =>
({
id: `message-${++messageCounter}`,
role,
assistantId: 'assistant-1',
topicId: 'topic-1',
createdAt: new Date(2024, 0, 1, 0, 0, messageCounter).toISOString(),
status: role === 'assistant' ? AssistantMessageStatus.SUCCESS : UserMessageStatus.SUCCESS,
blocks: []
}) as MockableMessage
const createFileBlock = (
messageId: string,
overrides: Partial<Omit<FileMessageBlock, 'file' | 'messageId' | 'type'>> & { file?: Partial<FileMetadata> } = {}
): FileMessageBlock => {
const { file, ...blockOverrides } = overrides
const timestamp = new Date(2024, 0, 1, 0, 0, ++blockCounter).toISOString()
return {
id: blockOverrides.id ?? `file-block-${blockCounter}`,
messageId,
type: MessageBlockType.FILE,
createdAt: blockOverrides.createdAt ?? timestamp,
status: blockOverrides.status ?? MessageBlockStatus.SUCCESS,
file: {
id: file?.id ?? `file-${blockCounter}`,
name: file?.name ?? 'document.txt',
origin_name: file?.origin_name ?? 'document.txt',
path: file?.path ?? '/tmp/document.txt',
size: file?.size ?? 1024,
ext: file?.ext ?? '.txt',
type: file?.type ?? FileTypes.TEXT,
created_at: file?.created_at ?? timestamp,
count: file?.count ?? 1,
...file
},
...blockOverrides
}
}
const createImageBlock = (
messageId: string,
overrides: Partial<Omit<ImageMessageBlock, 'type' | 'messageId'>> = {}
): ImageMessageBlock => ({
id: overrides.id ?? `image-block-${++blockCounter}`,
messageId,
type: MessageBlockType.IMAGE,
createdAt: overrides.createdAt ?? new Date(2024, 0, 1, 0, 0, blockCounter).toISOString(),
status: overrides.status ?? MessageBlockStatus.SUCCESS,
url: overrides.url ?? 'https://example.com/image.png',
...overrides
})
describe('messageConverter', () => {
beforeEach(() => {
convertFileBlockToFilePartMock.mockReset()
convertFileBlockToTextPartMock.mockReset()
convertFileBlockToFilePartMock.mockResolvedValue(null)
convertFileBlockToTextPartMock.mockResolvedValue(null)
messageCounter = 0
blockCounter = 0
})
describe('convertMessageToSdkParam', () => {
it('includes text and image parts for user messages on vision models', async () => {
const model = createModel()
const message = createMessage('user')
message.__mockContent = 'Describe this picture'
message.__mockImageBlocks = [createImageBlock(message.id, { url: 'https://example.com/cat.png' })]
const result = await convertMessageToSdkParam(message, true, model)
expect(result).toEqual({
role: 'user',
content: [
{ type: 'text', text: 'Describe this picture' },
{ type: 'image', image: 'https://example.com/cat.png' }
]
})
})
it('returns file instructions as a system message when native uploads succeed', async () => {
const model = createModel()
const message = createMessage('user')
message.__mockContent = 'Summarize the PDF'
message.__mockFileBlocks = [createFileBlock(message.id)]
convertFileBlockToFilePartMock.mockResolvedValueOnce({
type: 'file',
filename: 'document.pdf',
mediaType: 'application/pdf',
data: 'fileid://remote-file'
})
const result = await convertMessageToSdkParam(message, false, model)
expect(result).toEqual([
{
role: 'system',
content: 'fileid://remote-file'
},
{
role: 'user',
content: [{ type: 'text', text: 'Summarize the PDF' }]
}
])
})
})
describe('convertMessagesToSdkMessages', () => {
it('appends assistant images to the final user message for image enhancement models', async () => {
const model = createModel({ id: 'qwen-image-edit', name: 'Qwen Image Edit', provider: 'qwen', group: 'qwen' })
const initialUser = createMessage('user')
initialUser.__mockContent = 'Start editing'
const assistant = createMessage('assistant')
assistant.__mockContent = 'Here is the current preview'
assistant.__mockImageBlocks = [createImageBlock(assistant.id, { url: 'https://example.com/preview.png' })]
const finalUser = createMessage('user')
finalUser.__mockContent = 'Increase the brightness'
const result = await convertMessagesToSdkMessages([initialUser, assistant, finalUser], model)
expect(result).toEqual([
{
role: 'assistant',
content: [{ type: 'text', text: 'Here is the current preview' }]
},
{
role: 'user',
content: [
{ type: 'text', text: 'Increase the brightness' },
{ type: 'image', image: 'https://example.com/preview.png' }
]
}
])
})
it('preserves preceding system instructions when building enhancement payloads', async () => {
const model = createModel({ id: 'qwen-image-edit', name: 'Qwen Image Edit', provider: 'qwen', group: 'qwen' })
const fileUser = createMessage('user')
fileUser.__mockContent = 'Use this document as inspiration'
fileUser.__mockFileBlocks = [createFileBlock(fileUser.id, { file: { ext: '.pdf', type: FileTypes.DOCUMENT } })]
convertFileBlockToFilePartMock.mockResolvedValueOnce({
type: 'file',
filename: 'reference.pdf',
mediaType: 'application/pdf',
data: 'fileid://reference'
})
const assistant = createMessage('assistant')
assistant.__mockContent = 'Generated previews ready'
assistant.__mockImageBlocks = [createImageBlock(assistant.id, { url: 'https://example.com/reference.png' })]
const finalUser = createMessage('user')
finalUser.__mockContent = 'Apply the edits'
const result = await convertMessagesToSdkMessages([fileUser, assistant, finalUser], model)
expect(result).toEqual([
{ role: 'system', content: 'fileid://reference' },
{
role: 'assistant',
content: [{ type: 'text', text: 'Generated previews ready' }]
},
{
role: 'user',
content: [
{ type: 'text', text: 'Apply the edits' },
{ type: 'image', image: 'https://example.com/reference.png' }
]
}
])
})
})
})

View File

@ -0,0 +1,218 @@
import type { Assistant, AssistantSettings, Model, Topic } from '@renderer/types'
import { TopicType } from '@renderer/types'
import { defaultTimeout } from '@shared/config/constant'
import { describe, expect, it, vi } from 'vitest'
import { getTemperature, getTimeout, getTopP } from '../modelParameters'
vi.mock('@renderer/services/AssistantService', () => ({
getAssistantSettings: (assistant: Assistant): AssistantSettings => ({
contextCount: assistant.settings?.contextCount ?? 4096,
temperature: assistant.settings?.temperature ?? 0.7,
enableTemperature: assistant.settings?.enableTemperature ?? true,
topP: assistant.settings?.topP ?? 1,
enableTopP: assistant.settings?.enableTopP ?? false,
enableMaxTokens: assistant.settings?.enableMaxTokens ?? false,
maxTokens: assistant.settings?.maxTokens,
streamOutput: assistant.settings?.streamOutput ?? true,
toolUseMode: assistant.settings?.toolUseMode ?? 'prompt',
defaultModel: assistant.defaultModel,
customParameters: assistant.settings?.customParameters ?? [],
reasoning_effort: assistant.settings?.reasoning_effort,
reasoning_effort_cache: assistant.settings?.reasoning_effort_cache,
qwenThinkMode: assistant.settings?.qwenThinkMode
})
}))
vi.mock('@renderer/hooks/useSettings', () => ({
getStoreSetting: vi.fn(),
useSettings: vi.fn(() => ({})),
useNavbarPosition: vi.fn(() => ({ navbarPosition: 'left', isLeftNavbar: true, isTopNavbar: false }))
}))
vi.mock('@renderer/hooks/useStore', () => ({
getStoreProviders: vi.fn(() => [])
}))
vi.mock('@renderer/store/settings', () => ({
default: (state = { settings: {} }) => state
}))
vi.mock('@renderer/store/assistants', () => ({
default: (state = { assistants: [] }) => state
}))
const createTopic = (assistantId: string): Topic => ({
id: `topic-${assistantId}`,
assistantId,
name: 'topic',
createdAt: new Date().toISOString(),
updatedAt: new Date().toISOString(),
messages: [],
type: TopicType.Chat
})
const createAssistant = (settings: Assistant['settings'] = {}): Assistant => {
const assistantId = 'assistant-1'
return {
id: assistantId,
name: 'Test Assistant',
prompt: 'prompt',
topics: [createTopic(assistantId)],
type: 'assistant',
settings
}
}
const createModel = (overrides: Partial<Model> = {}): Model => ({
id: 'gpt-4o',
provider: 'openai',
name: 'GPT-4o',
group: 'openai',
...overrides
})
describe('modelParameters', () => {
describe('getTemperature', () => {
it('returns undefined when reasoning effort is enabled for Claude models', () => {
const assistant = createAssistant({ reasoning_effort: 'medium' })
const model = createModel({ id: 'claude-opus-4', name: 'Claude Opus 4', provider: 'anthropic', group: 'claude' })
expect(getTemperature(assistant, model)).toBeUndefined()
})
it('returns undefined for models without temperature/topP support', () => {
const assistant = createAssistant({ enableTemperature: true })
const model = createModel({ id: 'qwen-mt-large', name: 'Qwen MT', provider: 'qwen', group: 'qwen' })
expect(getTemperature(assistant, model)).toBeUndefined()
})
it('returns undefined for Claude 4.5 reasoning models when only TopP is enabled', () => {
const assistant = createAssistant({ enableTopP: true, enableTemperature: false })
const model = createModel({
id: 'claude-sonnet-4.5',
name: 'Claude Sonnet 4.5',
provider: 'anthropic',
group: 'claude'
})
expect(getTemperature(assistant, model)).toBeUndefined()
})
it('returns configured temperature when enabled', () => {
const assistant = createAssistant({ enableTemperature: true, temperature: 0.42 })
const model = createModel({ id: 'gpt-4o', provider: 'openai', group: 'openai' })
expect(getTemperature(assistant, model)).toBe(0.42)
})
it('returns undefined when temperature is disabled', () => {
const assistant = createAssistant({ enableTemperature: false, temperature: 0.9 })
const model = createModel({ id: 'gpt-4o', provider: 'openai', group: 'openai' })
expect(getTemperature(assistant, model)).toBeUndefined()
})
it('clamps temperature to max 1.0 for Zhipu models', () => {
const assistant = createAssistant({ enableTemperature: true, temperature: 2.0 })
const model = createModel({ id: 'glm-4-plus', name: 'GLM-4 Plus', provider: 'zhipu', group: 'zhipu' })
expect(getTemperature(assistant, model)).toBe(1.0)
})
it('clamps temperature to max 1.0 for Anthropic models', () => {
const assistant = createAssistant({ enableTemperature: true, temperature: 1.5 })
const model = createModel({
id: 'claude-sonnet-3.5',
name: 'Claude 3.5 Sonnet',
provider: 'anthropic',
group: 'claude'
})
expect(getTemperature(assistant, model)).toBe(1.0)
})
it('clamps temperature to max 1.0 for Moonshot models', () => {
const assistant = createAssistant({ enableTemperature: true, temperature: 2.0 })
const model = createModel({
id: 'moonshot-v1-8k',
name: 'Moonshot v1 8k',
provider: 'moonshot',
group: 'moonshot'
})
expect(getTemperature(assistant, model)).toBe(1.0)
})
it('does not clamp temperature for OpenAI models', () => {
const assistant = createAssistant({ enableTemperature: true, temperature: 2.0 })
const model = createModel({ id: 'gpt-4o', provider: 'openai', group: 'openai' })
expect(getTemperature(assistant, model)).toBe(2.0)
})
it('does not clamp temperature when it is already within limits', () => {
const assistant = createAssistant({ enableTemperature: true, temperature: 0.8 })
const model = createModel({ id: 'glm-4-plus', name: 'GLM-4 Plus', provider: 'zhipu', group: 'zhipu' })
expect(getTemperature(assistant, model)).toBe(0.8)
})
})
describe('getTopP', () => {
it('returns undefined when reasoning effort is enabled for Claude models', () => {
const assistant = createAssistant({ reasoning_effort: 'high' })
const model = createModel({ id: 'claude-opus-4', provider: 'anthropic', group: 'claude' })
expect(getTopP(assistant, model)).toBeUndefined()
})
it('returns undefined for models without TopP support', () => {
const assistant = createAssistant({ enableTopP: true })
const model = createModel({ id: 'qwen-mt-small', name: 'Qwen MT', provider: 'qwen', group: 'qwen' })
expect(getTopP(assistant, model)).toBeUndefined()
})
it('returns undefined for Claude 4.5 reasoning models when temperature is enabled', () => {
const assistant = createAssistant({ enableTemperature: true })
const model = createModel({
id: 'claude-opus-4.5',
name: 'Claude Opus 4.5',
provider: 'anthropic',
group: 'claude'
})
expect(getTopP(assistant, model)).toBeUndefined()
})
it('returns configured TopP when enabled', () => {
const assistant = createAssistant({ enableTopP: true, topP: 0.73 })
const model = createModel({ id: 'gpt-4o', provider: 'openai', group: 'openai' })
expect(getTopP(assistant, model)).toBe(0.73)
})
it('returns undefined when TopP is disabled', () => {
const assistant = createAssistant({ enableTopP: false, topP: 0.5 })
const model = createModel({ id: 'gpt-4o', provider: 'openai', group: 'openai' })
expect(getTopP(assistant, model)).toBeUndefined()
})
})
describe('getTimeout', () => {
it('uses an extended timeout for flex service tier models', () => {
const model = createModel({ id: 'o3-pro', provider: 'openai', group: 'openai' })
expect(getTimeout(model)).toBe(15 * 1000 * 60)
})
it('falls back to the default timeout otherwise', () => {
const model = createModel({ id: 'gpt-4o', provider: 'openai', group: 'openai' })
expect(getTimeout(model)).toBe(defaultTimeout)
})
})
})

View File

@ -1,13 +1,31 @@
import { isClaude45ReasoningModel } from '@renderer/config/models' import { isClaude4SeriesModel, isClaude45ReasoningModel } from '@renderer/config/models'
import { getProviderByModel } from '@renderer/services/AssistantService'
import type { Assistant, Model } from '@renderer/types' import type { Assistant, Model } from '@renderer/types'
import { isToolUseModeFunction } from '@renderer/utils/assistant' import { isToolUseModeFunction } from '@renderer/utils/assistant'
import { isAwsBedrockProvider, isVertexProvider } from '@renderer/utils/provider'
// https://docs.claude.com/en/docs/build-with-claude/extended-thinking#interleaved-thinking
const INTERLEAVED_THINKING_HEADER = 'interleaved-thinking-2025-05-14' const INTERLEAVED_THINKING_HEADER = 'interleaved-thinking-2025-05-14'
// https://docs.claude.com/en/docs/build-with-claude/context-windows#1m-token-context-window
const CONTEXT_100M_HEADER = 'context-1m-2025-08-07'
// https://docs.cloud.google.com/vertex-ai/generative-ai/docs/partner-models/claude/web-search
const WEBSEARCH_HEADER = 'web-search-2025-03-05'
export function addAnthropicHeaders(assistant: Assistant, model: Model): string[] { export function addAnthropicHeaders(assistant: Assistant, model: Model): string[] {
const anthropicHeaders: string[] = [] const anthropicHeaders: string[] = []
if (isClaude45ReasoningModel(model) && isToolUseModeFunction(assistant)) { const provider = getProviderByModel(model)
if (
isClaude45ReasoningModel(model) &&
isToolUseModeFunction(assistant) &&
!(isVertexProvider(provider) && isAwsBedrockProvider(provider))
) {
anthropicHeaders.push(INTERLEAVED_THINKING_HEADER) anthropicHeaders.push(INTERLEAVED_THINKING_HEADER)
} }
if (isClaude4SeriesModel(model)) {
if (isVertexProvider(provider) && assistant.enableWebSearch) {
anthropicHeaders.push(WEBSEARCH_HEADER)
}
anthropicHeaders.push(CONTEXT_100M_HEADER)
}
return anthropicHeaders return anthropicHeaders
} }

View File

@ -85,19 +85,6 @@ export function supportsLargeFileUpload(model: Model): boolean {
}) })
} }
/**
* TopP
*/
export function supportsTopP(model: Model): boolean {
const provider = getProviderByModel(model)
if (provider?.type === 'anthropic' || model?.endpoint_type === 'anthropic') {
return false
}
return true
}
/** /**
* *
*/ */

View File

@ -3,17 +3,27 @@
* TopP * TopP
*/ */
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
import { import {
isClaude45ReasoningModel, isClaude45ReasoningModel,
isClaudeReasoningModel, isClaudeReasoningModel,
isMaxTemperatureOneModel,
isNotSupportTemperatureAndTopP, isNotSupportTemperatureAndTopP,
isSupportedFlexServiceTier isSupportedFlexServiceTier,
isSupportedThinkingTokenClaudeModel
} from '@renderer/config/models' } from '@renderer/config/models'
import { getAssistantSettings } from '@renderer/services/AssistantService' import { getAssistantSettings, getProviderByModel } from '@renderer/services/AssistantService'
import type { Assistant, Model } from '@renderer/types' import type { Assistant, Model } from '@renderer/types'
import { defaultTimeout } from '@shared/config/constant' import { defaultTimeout } from '@shared/config/constant'
import { getAnthropicThinkingBudget } from '../utils/reasoning'
/** /**
* Claude 4.5 :
* - temperature 使 temperature
* - top_p 使 top_p
* - temperature ,top_p
* - 使
* *
*/ */
export function getTemperature(assistant: Assistant, model: Model): number | undefined { export function getTemperature(assistant: Assistant, model: Model): number | undefined {
@ -27,7 +37,11 @@ export function getTemperature(assistant: Assistant, model: Model): number | und
return undefined return undefined
} }
const assistantSettings = getAssistantSettings(assistant) const assistantSettings = getAssistantSettings(assistant)
return assistantSettings?.enableTemperature ? assistantSettings?.temperature : undefined let temperature = assistantSettings?.temperature
if (temperature && isMaxTemperatureOneModel(model)) {
temperature = Math.min(1, temperature)
}
return assistantSettings?.enableTemperature ? temperature : undefined
} }
/** /**
@ -56,3 +70,18 @@ export function getTimeout(model: Model): number {
} }
return defaultTimeout return defaultTimeout
} }
export function getMaxTokens(assistant: Assistant, model: Model): number | undefined {
// NOTE: ai-sdk会把maxToken和budgetToken加起来
let { maxTokens = DEFAULT_MAX_TOKENS } = getAssistantSettings(assistant)
const provider = getProviderByModel(model)
if (isSupportedThinkingTokenClaudeModel(model) && ['anthropic', 'aws-bedrock'].includes(provider.type)) {
const { reasoning_effort: reasoningEffort } = getAssistantSettings(assistant)
const budget = getAnthropicThinkingBudget(maxTokens, reasoningEffort, model.id)
if (budget) {
maxTokens -= budget
}
}
return maxTokens
}

View File

@ -4,11 +4,12 @@
*/ */
import { anthropic } from '@ai-sdk/anthropic' import { anthropic } from '@ai-sdk/anthropic'
import { azure } from '@ai-sdk/azure'
import { google } from '@ai-sdk/google' import { google } from '@ai-sdk/google'
import { vertexAnthropic } from '@ai-sdk/google-vertex/anthropic/edge' import { vertexAnthropic } from '@ai-sdk/google-vertex/anthropic/edge'
import { vertex } from '@ai-sdk/google-vertex/edge' import { vertex } from '@ai-sdk/google-vertex/edge'
import { combineHeaders } from '@ai-sdk/provider-utils' import { combineHeaders } from '@ai-sdk/provider-utils'
import type { WebSearchPluginConfig } from '@cherrystudio/ai-core/built-in/plugins' import type { AnthropicSearchConfig, WebSearchPluginConfig } from '@cherrystudio/ai-core/built-in/plugins'
import { isBaseProvider } from '@cherrystudio/ai-core/core/providers/schemas' import { isBaseProvider } from '@cherrystudio/ai-core/core/providers/schemas'
import { loggerService } from '@logger' import { loggerService } from '@logger'
import { import {
@ -17,13 +18,10 @@ import {
isOpenRouterBuiltInWebSearchModel, isOpenRouterBuiltInWebSearchModel,
isReasoningModel, isReasoningModel,
isSupportedReasoningEffortModel, isSupportedReasoningEffortModel,
isSupportedThinkingTokenClaudeModel,
isSupportedThinkingTokenModel, isSupportedThinkingTokenModel,
isWebSearchModel isWebSearchModel
} from '@renderer/config/models' } from '@renderer/config/models'
import { isAwsBedrockProvider } from '@renderer/config/providers' import { getDefaultModel } from '@renderer/services/AssistantService'
import { isVertexProvider } from '@renderer/hooks/useVertexAI'
import { getAssistantSettings, getDefaultModel } from '@renderer/services/AssistantService'
import store from '@renderer/store' import store from '@renderer/store'
import type { CherryWebSearchConfig } from '@renderer/store/websearch' import type { CherryWebSearchConfig } from '@renderer/store/websearch'
import { type Assistant, type MCPTool, type Provider } from '@renderer/types' import { type Assistant, type MCPTool, type Provider } from '@renderer/types'
@ -36,11 +34,9 @@ import { stepCountIs } from 'ai'
import { getAiSdkProviderId } from '../provider/factory' import { getAiSdkProviderId } from '../provider/factory'
import { setupToolsConfig } from '../utils/mcp' import { setupToolsConfig } from '../utils/mcp'
import { buildProviderOptions } from '../utils/options' import { buildProviderOptions } from '../utils/options'
import { getAnthropicThinkingBudget } from '../utils/reasoning'
import { buildProviderBuiltinWebSearchConfig } from '../utils/websearch' import { buildProviderBuiltinWebSearchConfig } from '../utils/websearch'
import { addAnthropicHeaders } from './header' import { addAnthropicHeaders } from './header'
import { supportsTopP } from './modelCapabilities' import { getMaxTokens, getTemperature, getTopP } from './modelParameters'
import { getTemperature, getTopP } from './modelParameters'
const logger = loggerService.withContext('parameterBuilder') const logger = loggerService.withContext('parameterBuilder')
@ -63,7 +59,7 @@ export async function buildStreamTextParams(
timeout?: number timeout?: number
headers?: Record<string, string> headers?: Record<string, string>
} }
} = {} }
): Promise<{ ): Promise<{
params: StreamTextParams params: StreamTextParams
modelId: string modelId: string
@ -80,8 +76,6 @@ export async function buildStreamTextParams(
const model = assistant.model || getDefaultModel() const model = assistant.model || getDefaultModel()
const aiSdkProviderId = getAiSdkProviderId(provider) const aiSdkProviderId = getAiSdkProviderId(provider)
let { maxTokens } = getAssistantSettings(assistant)
// 这三个变量透传出来,交给下面启用插件/中间件 // 这三个变量透传出来,交给下面启用插件/中间件
// 也可以在外部构建好再传入buildStreamTextParams // 也可以在外部构建好再传入buildStreamTextParams
// FIXME: qwen3即使关闭思考仍然会导致enableReasoning的结果为true // FIXME: qwen3即使关闭思考仍然会导致enableReasoning的结果为true
@ -118,16 +112,6 @@ export async function buildStreamTextParams(
enableGenerateImage enableGenerateImage
}) })
// NOTE: ai-sdk会把maxToken和budgetToken加起来
if (
enableReasoning &&
maxTokens !== undefined &&
isSupportedThinkingTokenClaudeModel(model) &&
(provider.type === 'anthropic' || provider.type === 'aws-bedrock')
) {
maxTokens -= getAnthropicThinkingBudget(assistant, model)
}
let webSearchPluginConfig: WebSearchPluginConfig | undefined = undefined let webSearchPluginConfig: WebSearchPluginConfig | undefined = undefined
if (enableWebSearch) { if (enableWebSearch) {
if (isBaseProvider(aiSdkProviderId)) { if (isBaseProvider(aiSdkProviderId)) {
@ -144,6 +128,17 @@ export async function buildStreamTextParams(
maxUses: webSearchConfig.maxResults, maxUses: webSearchConfig.maxResults,
blockedDomains: blockedDomains.length > 0 ? blockedDomains : undefined blockedDomains: blockedDomains.length > 0 ? blockedDomains : undefined
}) as ProviderDefinedTool }) as ProviderDefinedTool
} else if (aiSdkProviderId === 'azure-responses') {
tools.web_search_preview = azure.tools.webSearchPreview({
searchContextSize: webSearchPluginConfig?.openai!.searchContextSize
}) as ProviderDefinedTool
} else if (aiSdkProviderId === 'azure-anthropic') {
const blockedDomains = mapRegexToPatterns(webSearchConfig.excludeDomains)
const anthropicSearchOptions: AnthropicSearchConfig = {
maxUses: webSearchConfig.maxResults,
blockedDomains: blockedDomains.length > 0 ? blockedDomains : undefined
}
tools.web_search = anthropic.tools.webSearch_20250305(anthropicSearchOptions) as ProviderDefinedTool
} }
} }
@ -161,9 +156,10 @@ export async function buildStreamTextParams(
tools.url_context = google.tools.urlContext({}) as ProviderDefinedTool tools.url_context = google.tools.urlContext({}) as ProviderDefinedTool
break break
case 'anthropic': case 'anthropic':
case 'azure-anthropic':
case 'google-vertex-anthropic': case 'google-vertex-anthropic':
tools.web_fetch = ( tools.web_fetch = (
aiSdkProviderId === 'anthropic' ['anthropic', 'azure-anthropic'].includes(aiSdkProviderId)
? anthropic.tools.webFetch_20250910({ ? anthropic.tools.webFetch_20250910({
maxUses: webSearchConfig.maxResults, maxUses: webSearchConfig.maxResults,
blockedDomains: blockedDomains.length > 0 ? blockedDomains : undefined blockedDomains: blockedDomains.length > 0 ? blockedDomains : undefined
@ -179,8 +175,7 @@ export async function buildStreamTextParams(
let headers: Record<string, string | undefined> = options.requestOptions?.headers ?? {} let headers: Record<string, string | undefined> = options.requestOptions?.headers ?? {}
// https://docs.claude.com/en/docs/build-with-claude/extended-thinking#interleaved-thinking if (isAnthropicModel(model)) {
if (!isVertexProvider(provider) && !isAwsBedrockProvider(provider) && isAnthropicModel(model)) {
const newBetaHeaders = { 'anthropic-beta': addAnthropicHeaders(assistant, model).join(',') } const newBetaHeaders = { 'anthropic-beta': addAnthropicHeaders(assistant, model).join(',') }
headers = combineHeaders(headers, newBetaHeaders) headers = combineHeaders(headers, newBetaHeaders)
} }
@ -188,8 +183,9 @@ export async function buildStreamTextParams(
// 构建基础参数 // 构建基础参数
const params: StreamTextParams = { const params: StreamTextParams = {
messages: sdkMessages, messages: sdkMessages,
maxOutputTokens: maxTokens, maxOutputTokens: getMaxTokens(assistant, model),
temperature: getTemperature(assistant, model), temperature: getTemperature(assistant, model),
topP: getTopP(assistant, model),
abortSignal: options.requestOptions?.signal, abortSignal: options.requestOptions?.signal,
headers, headers,
providerOptions, providerOptions,
@ -197,10 +193,6 @@ export async function buildStreamTextParams(
maxRetries: 0 maxRetries: 0
} }
if (supportsTopP(model)) {
params.topP = getTopP(assistant, model)
}
if (tools) { if (tools) {
params.tools = tools params.tools = tools
} }

View File

@ -23,6 +23,26 @@ vi.mock('@cherrystudio/ai-core', () => ({
} }
})) }))
vi.mock('@renderer/services/AssistantService', () => ({
getProviderByModel: vi.fn(),
getAssistantSettings: vi.fn(),
getDefaultAssistant: vi.fn().mockReturnValue({
id: 'default',
name: 'Default Assistant',
prompt: '',
settings: {}
})
}))
vi.mock('@renderer/store/settings', () => ({
default: {},
settingsSlice: {
name: 'settings',
reducer: vi.fn(),
actions: {}
}
}))
// Mock the provider configs // Mock the provider configs
vi.mock('../providerConfigs', () => ({ vi.mock('../providerConfigs', () => ({
initializeNewProviders: vi.fn() initializeNewProviders: vi.fn()

View File

@ -12,7 +12,14 @@ vi.mock('@renderer/services/LoggerService', () => ({
})) }))
vi.mock('@renderer/services/AssistantService', () => ({ vi.mock('@renderer/services/AssistantService', () => ({
getProviderByModel: vi.fn() getProviderByModel: vi.fn(),
getAssistantSettings: vi.fn(),
getDefaultAssistant: vi.fn().mockReturnValue({
id: 'default',
name: 'Default Assistant',
prompt: '',
settings: {}
})
})) }))
vi.mock('@renderer/store', () => ({ vi.mock('@renderer/store', () => ({
@ -34,7 +41,7 @@ vi.mock('@renderer/utils/api', () => ({
})) }))
})) }))
vi.mock('@renderer/config/providers', async (importOriginal) => { vi.mock('@renderer/utils/provider', async (importOriginal) => {
const actual = (await importOriginal()) as any const actual = (await importOriginal()) as any
return { return {
...actual, ...actual,
@ -53,10 +60,21 @@ vi.mock('@renderer/hooks/useVertexAI', () => ({
createVertexProvider: vi.fn() createVertexProvider: vi.fn()
})) }))
import { isCherryAIProvider, isPerplexityProvider } from '@renderer/config/providers' vi.mock('@renderer/services/AssistantService', () => ({
getProviderByModel: vi.fn(),
getAssistantSettings: vi.fn(),
getDefaultAssistant: vi.fn().mockReturnValue({
id: 'default',
name: 'Default Assistant',
prompt: '',
settings: {}
})
}))
import { getProviderByModel } from '@renderer/services/AssistantService' import { getProviderByModel } from '@renderer/services/AssistantService'
import type { Model, Provider } from '@renderer/types' import type { Model, Provider } from '@renderer/types'
import { formatApiHost } from '@renderer/utils/api' import { formatApiHost } from '@renderer/utils/api'
import { isCherryAIProvider, isPerplexityProvider } from '@renderer/utils/provider'
import { COPILOT_DEFAULT_HEADERS, COPILOT_EDITOR_VERSION, isCopilotResponsesModel } from '../constants' import { COPILOT_DEFAULT_HEADERS, COPILOT_EDITOR_VERSION, isCopilotResponsesModel } from '../constants'
import { getActualProvider, providerToAiSdkConfig } from '../providerConfig' import { getActualProvider, providerToAiSdkConfig } from '../providerConfig'

View File

@ -0,0 +1,22 @@
import type { Provider } from '@renderer/types'
import { provider2Provider, startsWith } from './helper'
import type { RuleSet } from './types'
// https://platform.claude.com/docs/en/build-with-claude/claude-in-microsoft-foundry
const AZURE_ANTHROPIC_RULES: RuleSet = {
rules: [
{
match: startsWith('claude'),
provider: (provider: Provider) => ({
...provider,
type: 'anthropic',
apiHost: provider.apiHost + 'anthropic/v1',
id: 'azure-anthropic'
})
}
],
fallbackRule: (provider: Provider) => provider
}
export const azureAnthropicProviderCreator = provider2Provider.bind(null, AZURE_ANTHROPIC_RULES)

View File

@ -2,8 +2,10 @@ import { hasProviderConfigByAlias, type ProviderId, resolveProviderConfigId } fr
import { createProvider as createProviderCore } from '@cherrystudio/ai-core/provider' import { createProvider as createProviderCore } from '@cherrystudio/ai-core/provider'
import { loggerService } from '@logger' import { loggerService } from '@logger'
import type { Provider } from '@renderer/types' import type { Provider } from '@renderer/types'
import { isAzureOpenAIProvider, isAzureResponsesEndpoint } from '@renderer/utils/provider'
import type { Provider as AiSdkProvider } from 'ai' import type { Provider as AiSdkProvider } from 'ai'
import type { AiSdkConfig } from '../types'
import { initializeNewProviders } from './providerInitialization' import { initializeNewProviders } from './providerInitialization'
const logger = loggerService.withContext('ProviderFactory') const logger = loggerService.withContext('ProviderFactory')
@ -55,9 +57,12 @@ function tryResolveProviderId(identifier: string): ProviderId | null {
* AI SDK Provider ID * AI SDK Provider ID
* *
*/ */
export function getAiSdkProviderId(provider: Provider): ProviderId | 'openai-compatible' { export function getAiSdkProviderId(provider: Provider): string {
// 1. 尝试解析provider.id // 1. 尝试解析provider.id
const resolvedFromId = tryResolveProviderId(provider.id) const resolvedFromId = tryResolveProviderId(provider.id)
if (isAzureOpenAIProvider(provider) && isAzureResponsesEndpoint(provider)) {
return 'azure-responses'
}
if (resolvedFromId) { if (resolvedFromId) {
return resolvedFromId return resolvedFromId
} }
@ -73,11 +78,11 @@ export function getAiSdkProviderId(provider: Provider): ProviderId | 'openai-com
if (provider.apiHost.includes('api.openai.com')) { if (provider.apiHost.includes('api.openai.com')) {
return 'openai-chat' return 'openai-chat'
} }
// 3. 最后的fallback通常会成为openai-compatible // 3. 最后的fallback使用provider本身的id
return provider.id as ProviderId return provider.id
} }
export async function createAiSdkProvider(config) { export async function createAiSdkProvider(config: AiSdkConfig): Promise<AiSdkProvider | null> {
let localProvider: Awaited<AiSdkProvider> | null = null let localProvider: Awaited<AiSdkProvider> | null = null
try { try {
if (config.providerId === 'openai' && config.options?.mode === 'chat') { if (config.providerId === 'openai' && config.options?.mode === 'chat') {

View File

@ -1,20 +1,6 @@
import { import { formatPrivateKey, hasProviderConfig, ProviderConfigFactory } from '@cherrystudio/ai-core/provider'
formatPrivateKey,
hasProviderConfig,
ProviderConfigFactory,
type ProviderId,
type ProviderSettingsMap
} from '@cherrystudio/ai-core/provider'
import { cacheService } from '@data/CacheService' import { cacheService } from '@data/CacheService'
import { isOpenAIChatCompletionOnlyModel } from '@renderer/config/models' import { isOpenAIChatCompletionOnlyModel } from '@renderer/config/models'
import {
isAnthropicProvider,
isAzureOpenAIProvider,
isCherryAIProvider,
isGeminiProvider,
isNewApiProvider,
isPerplexityProvider
} from '@renderer/config/providers'
import { import {
getAwsBedrockAccessKeyId, getAwsBedrockAccessKeyId,
getAwsBedrockApiKey, getAwsBedrockApiKey,
@ -22,14 +8,25 @@ import {
getAwsBedrockRegion, getAwsBedrockRegion,
getAwsBedrockSecretAccessKey getAwsBedrockSecretAccessKey
} from '@renderer/hooks/useAwsBedrock' } from '@renderer/hooks/useAwsBedrock'
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI' import { createVertexProvider, isVertexAIConfigured } from '@renderer/hooks/useVertexAI'
import { getProviderByModel } from '@renderer/services/AssistantService' import { getProviderByModel } from '@renderer/services/AssistantService'
import store from '@renderer/store' import store from '@renderer/store'
import { isSystemProvider, type Model, type Provider, SystemProviderIds } from '@renderer/types' import { isSystemProvider, type Model, type Provider, SystemProviderIds } from '@renderer/types'
import { formatApiHost, formatAzureOpenAIApiHost, formatVertexApiHost, routeToEndpoint } from '@renderer/utils/api' import { formatApiHost, formatAzureOpenAIApiHost, formatVertexApiHost, routeToEndpoint } from '@renderer/utils/api'
import {
isAnthropicProvider,
isAzureOpenAIProvider,
isCherryAIProvider,
isGeminiProvider,
isNewApiProvider,
isPerplexityProvider,
isVertexProvider
} from '@renderer/utils/provider'
import { cloneDeep } from 'lodash' import { cloneDeep } from 'lodash'
import type { AiSdkConfig } from '../types'
import { aihubmixProviderCreator, newApiResolverCreator, vertexAnthropicProviderCreator } from './config' import { aihubmixProviderCreator, newApiResolverCreator, vertexAnthropicProviderCreator } from './config'
import { azureAnthropicProviderCreator } from './config/azure-anthropic'
import { COPILOT_DEFAULT_HEADERS } from './constants' import { COPILOT_DEFAULT_HEADERS } from './constants'
import { getAiSdkProviderId } from './factory' import { getAiSdkProviderId } from './factory'
@ -75,6 +72,9 @@ function handleSpecialProviders(model: Model, provider: Provider): Provider {
return vertexAnthropicProviderCreator(model, provider) return vertexAnthropicProviderCreator(model, provider)
} }
} }
if (isAzureOpenAIProvider(provider)) {
return azureAnthropicProviderCreator(model, provider)
}
return provider return provider
} }
@ -132,13 +132,7 @@ export function getActualProvider(model: Model): Provider {
* Provider AI SDK * Provider AI SDK
* *
*/ */
export function providerToAiSdkConfig( export function providerToAiSdkConfig(actualProvider: Provider, model: Model): AiSdkConfig {
actualProvider: Provider,
model: Model
): {
providerId: ProviderId | 'openai-compatible'
options: ProviderSettingsMap[keyof ProviderSettingsMap]
} {
const aiSdkProviderId = getAiSdkProviderId(actualProvider) const aiSdkProviderId = getAiSdkProviderId(actualProvider)
// 构建基础配置 // 构建基础配置
@ -192,13 +186,10 @@ export function providerToAiSdkConfig(
// azure // azure
// https://learn.microsoft.com/en-us/azure/ai-foundry/openai/latest // https://learn.microsoft.com/en-us/azure/ai-foundry/openai/latest
// https://learn.microsoft.com/en-us/azure/ai-foundry/openai/how-to/responses?tabs=python-key#responses-api // https://learn.microsoft.com/en-us/azure/ai-foundry/openai/how-to/responses?tabs=python-key#responses-api
if (aiSdkProviderId === 'azure' || actualProvider.type === 'azure-openai') { if (aiSdkProviderId === 'azure-responses') {
// extraOptions.apiVersion = actualProvider.apiVersion === 'preview' ? 'v1' : actualProvider.apiVersion 默认使用v1不使用azure endpoint extraOptions.mode = 'responses'
if (actualProvider.apiVersion === 'preview' || actualProvider.apiVersion === 'v1') { } else if (aiSdkProviderId === 'azure') {
extraOptions.mode = 'responses' extraOptions.mode = 'chat'
} else {
extraOptions.mode = 'chat'
}
} }
// bedrock // bedrock
@ -238,7 +229,7 @@ export function providerToAiSdkConfig(
if (hasProviderConfig(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') { if (hasProviderConfig(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, baseConfig, extraOptions) const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, baseConfig, extraOptions)
return { return {
providerId: aiSdkProviderId as ProviderId, providerId: aiSdkProviderId,
options options
} }
} }

View File

@ -32,6 +32,14 @@ export const NEW_PROVIDER_CONFIGS: ProviderConfig[] = [
supportsImageGeneration: true, supportsImageGeneration: true,
aliases: ['vertexai-anthropic'] aliases: ['vertexai-anthropic']
}, },
{
id: 'azure-anthropic',
name: 'Azure AI Anthropic',
import: () => import('@ai-sdk/anthropic'),
creatorFunctionName: 'createAnthropic',
supportsImageGeneration: false,
aliases: ['azure-anthropic']
},
{ {
id: 'github-copilot-openai-compatible', id: 'github-copilot-openai-compatible',
name: 'GitHub Copilot OpenAI Compatible', name: 'GitHub Copilot OpenAI Compatible',

View File

@ -133,7 +133,7 @@ export class AiSdkSpanAdapter {
// 详细记录转换过程 // 详细记录转换过程
const operationId = attributes['ai.operationId'] const operationId = attributes['ai.operationId']
logger.info('Converting AI SDK span to SpanEntity', { logger.debug('Converting AI SDK span to SpanEntity', {
spanName: spanName, spanName: spanName,
operationId, operationId,
spanTag, spanTag,
@ -149,7 +149,7 @@ export class AiSdkSpanAdapter {
}) })
if (tokenUsage) { if (tokenUsage) {
logger.info('Token usage data found', { logger.debug('Token usage data found', {
spanName: spanName, spanName: spanName,
operationId, operationId,
usage: tokenUsage, usage: tokenUsage,
@ -158,7 +158,7 @@ export class AiSdkSpanAdapter {
} }
if (inputs || outputs) { if (inputs || outputs) {
logger.info('Input/Output data extracted', { logger.debug('Input/Output data extracted', {
spanName: spanName, spanName: spanName,
operationId, operationId,
hasInputs: !!inputs, hasInputs: !!inputs,
@ -170,7 +170,7 @@ export class AiSdkSpanAdapter {
} }
if (Object.keys(typeSpecificData).length > 0) { if (Object.keys(typeSpecificData).length > 0) {
logger.info('Type-specific data extracted', { logger.debug('Type-specific data extracted', {
spanName: spanName, spanName: spanName,
operationId, operationId,
typeSpecificKeys: Object.keys(typeSpecificData), typeSpecificKeys: Object.keys(typeSpecificData),
@ -204,7 +204,7 @@ export class AiSdkSpanAdapter {
modelName: modelName || this.extractModelFromAttributes(attributes) modelName: modelName || this.extractModelFromAttributes(attributes)
} }
logger.info('AI SDK span successfully converted to SpanEntity', { logger.debug('AI SDK span successfully converted to SpanEntity', {
spanName: spanName, spanName: spanName,
operationId, operationId,
spanId: spanContext.spanId, spanId: spanContext.spanId,

View File

@ -0,0 +1,15 @@
/**
* This type definition file is only for renderer.
* It cannot be migrated to @renderer/types since files within it are actually being used by both main and renderer.
* If we do that, main would throw an error because it cannot import a module which imports a type from a browser-enviroment-only package.
* (ai-core package is set as browser-enviroment-only)
*
* TODO: We should separate them clearly. Keep renderer only types in renderer, and main only types in main, and shared types in shared.
*/
import type { ProviderSettingsMap } from '@cherrystudio/ai-core/provider'
export type AiSdkConfig = {
providerId: string
options: ProviderSettingsMap[keyof ProviderSettingsMap]
}

View File

@ -0,0 +1,121 @@
/**
* image.ts Unit Tests
* Tests for Gemini image generation utilities
*/
import type { Model, Provider } from '@renderer/types'
import { SystemProviderIds } from '@renderer/types'
import { describe, expect, it } from 'vitest'
import { buildGeminiGenerateImageParams, isOpenRouterGeminiGenerateImageModel } from '../image'
describe('image utils', () => {
describe('buildGeminiGenerateImageParams', () => {
it('should return correct response modalities', () => {
const result = buildGeminiGenerateImageParams()
expect(result).toEqual({
responseModalities: ['TEXT', 'IMAGE']
})
})
it('should return an object with responseModalities property', () => {
const result = buildGeminiGenerateImageParams()
expect(result).toHaveProperty('responseModalities')
expect(Array.isArray(result.responseModalities)).toBe(true)
expect(result.responseModalities).toHaveLength(2)
})
})
describe('isOpenRouterGeminiGenerateImageModel', () => {
const mockOpenRouterProvider: Provider = {
id: SystemProviderIds.openrouter,
name: 'OpenRouter',
apiKey: 'test-key',
apiHost: 'https://openrouter.ai/api/v1',
isSystem: true
} as Provider
const mockOtherProvider: Provider = {
id: SystemProviderIds.openai,
name: 'OpenAI',
apiKey: 'test-key',
apiHost: 'https://api.openai.com/v1',
isSystem: true
} as Provider
it('should return true for OpenRouter Gemini 2.5 Flash Image model', () => {
const model: Model = {
id: 'google/gemini-2.5-flash-image-preview',
name: 'Gemini 2.5 Flash Image',
provider: SystemProviderIds.openrouter
} as Model
const result = isOpenRouterGeminiGenerateImageModel(model, mockOpenRouterProvider)
expect(result).toBe(true)
})
it('should return false for non-Gemini model on OpenRouter', () => {
const model: Model = {
id: 'openai/gpt-4',
name: 'GPT-4',
provider: SystemProviderIds.openrouter
} as Model
const result = isOpenRouterGeminiGenerateImageModel(model, mockOpenRouterProvider)
expect(result).toBe(false)
})
it('should return false for Gemini model on non-OpenRouter provider', () => {
const model: Model = {
id: 'gemini-2.5-flash-image-preview',
name: 'Gemini 2.5 Flash Image',
provider: SystemProviderIds.gemini
} as Model
const result = isOpenRouterGeminiGenerateImageModel(model, mockOtherProvider)
expect(result).toBe(false)
})
it('should return false for Gemini model without image suffix', () => {
const model: Model = {
id: 'google/gemini-2.5-flash',
name: 'Gemini 2.5 Flash',
provider: SystemProviderIds.openrouter
} as Model
const result = isOpenRouterGeminiGenerateImageModel(model, mockOpenRouterProvider)
expect(result).toBe(false)
})
it('should handle model ID with partial match', () => {
const model: Model = {
id: 'google/gemini-2.5-flash-image-generation',
name: 'Gemini Image Gen',
provider: SystemProviderIds.openrouter
} as Model
const result = isOpenRouterGeminiGenerateImageModel(model, mockOpenRouterProvider)
expect(result).toBe(true)
})
it('should return false for custom provider', () => {
const customProvider: Provider = {
id: 'custom-provider-123',
name: 'Custom Provider',
apiKey: 'test-key',
apiHost: 'https://custom.com'
} as Provider
const model: Model = {
id: 'gemini-2.5-flash-image-preview',
name: 'Gemini 2.5 Flash Image',
provider: 'custom-provider-123'
} as Model
const result = isOpenRouterGeminiGenerateImageModel(model, customProvider)
expect(result).toBe(false)
})
})
})

View File

@ -0,0 +1,435 @@
/**
* mcp.ts Unit Tests
* Tests for MCP tools configuration and conversion utilities
*/
import type { MCPTool } from '@renderer/types'
import type { Tool } from 'ai'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { convertMcpToolsToAiSdkTools, setupToolsConfig } from '../mcp'
// Mock dependencies
vi.mock('@logger', () => ({
loggerService: {
withContext: () => ({
debug: vi.fn(),
error: vi.fn(),
warn: vi.fn(),
info: vi.fn()
})
}
}))
vi.mock('@renderer/utils/mcp-tools', () => ({
getMcpServerByTool: vi.fn(() => ({ id: 'test-server', autoApprove: false })),
isToolAutoApproved: vi.fn(() => false),
callMCPTool: vi.fn(async () => ({
content: [{ type: 'text', text: 'Tool executed successfully' }],
isError: false
}))
}))
vi.mock('@renderer/utils/userConfirmation', () => ({
requestToolConfirmation: vi.fn(async () => true)
}))
describe('mcp utils', () => {
beforeEach(() => {
vi.clearAllMocks()
})
describe('setupToolsConfig', () => {
it('should return undefined when no MCP tools provided', () => {
const result = setupToolsConfig()
expect(result).toBeUndefined()
})
it('should return undefined when empty MCP tools array provided', () => {
const result = setupToolsConfig([])
expect(result).toBeUndefined()
})
it('should convert MCP tools to AI SDK tools format', () => {
const mcpTools: MCPTool[] = [
{
id: 'test-tool-1',
serverId: 'test-server',
serverName: 'test-server',
name: 'test-tool',
description: 'A test tool',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {
query: { type: 'string' }
}
}
}
]
const result = setupToolsConfig(mcpTools)
expect(result).not.toBeUndefined()
expect(Object.keys(result!)).toEqual(['test-tool'])
expect(result!['test-tool']).toHaveProperty('description')
expect(result!['test-tool']).toHaveProperty('inputSchema')
expect(result!['test-tool']).toHaveProperty('execute')
})
it('should handle multiple MCP tools', () => {
const mcpTools: MCPTool[] = [
{
id: 'tool1-id',
serverId: 'server1',
serverName: 'server1',
name: 'tool1',
description: 'First tool',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {}
}
},
{
id: 'tool2-id',
serverId: 'server2',
serverName: 'server2',
name: 'tool2',
description: 'Second tool',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {}
}
}
]
const result = setupToolsConfig(mcpTools)
expect(result).not.toBeUndefined()
expect(Object.keys(result!)).toHaveLength(2)
expect(Object.keys(result!)).toEqual(['tool1', 'tool2'])
})
})
describe('convertMcpToolsToAiSdkTools', () => {
it('should convert single MCP tool to AI SDK tool', () => {
const mcpTools: MCPTool[] = [
{
id: 'get-weather-id',
serverId: 'weather-server',
serverName: 'weather-server',
name: 'get-weather',
description: 'Get weather information',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {
location: { type: 'string' }
},
required: ['location']
}
}
]
const result = convertMcpToolsToAiSdkTools(mcpTools)
expect(Object.keys(result)).toEqual(['get-weather'])
const tool = result['get-weather'] as Tool
expect(tool.description).toBe('Get weather information')
expect(tool.inputSchema).toBeDefined()
expect(typeof tool.execute).toBe('function')
})
it('should handle tool without description', () => {
const mcpTools: MCPTool[] = [
{
id: 'no-desc-tool-id',
serverId: 'test-server',
serverName: 'test-server',
name: 'no-desc-tool',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {}
}
}
]
const result = convertMcpToolsToAiSdkTools(mcpTools)
expect(Object.keys(result)).toEqual(['no-desc-tool'])
const tool = result['no-desc-tool'] as Tool
expect(tool.description).toBe('Tool from test-server')
})
it('should convert empty tools array', () => {
const result = convertMcpToolsToAiSdkTools([])
expect(result).toEqual({})
})
it('should handle complex input schemas', () => {
const mcpTools: MCPTool[] = [
{
id: 'complex-tool-id',
serverId: 'server',
serverName: 'server',
name: 'complex-tool',
description: 'Tool with complex schema',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {
name: { type: 'string' },
age: { type: 'number' },
tags: {
type: 'array',
items: { type: 'string' }
},
metadata: {
type: 'object',
properties: {
key: { type: 'string' }
}
}
},
required: ['name']
}
}
]
const result = convertMcpToolsToAiSdkTools(mcpTools)
expect(Object.keys(result)).toEqual(['complex-tool'])
const tool = result['complex-tool'] as Tool
expect(tool.inputSchema).toBeDefined()
expect(typeof tool.execute).toBe('function')
})
it('should preserve tool names with special characters', () => {
const mcpTools: MCPTool[] = [
{
id: 'special-tool-id',
serverId: 'server',
serverName: 'server',
name: 'tool_with-special.chars',
description: 'Special chars tool',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {}
}
}
]
const result = convertMcpToolsToAiSdkTools(mcpTools)
expect(Object.keys(result)).toEqual(['tool_with-special.chars'])
})
it('should handle multiple tools with different schemas', () => {
const mcpTools: MCPTool[] = [
{
id: 'string-tool-id',
serverId: 'server1',
serverName: 'server1',
name: 'string-tool',
description: 'String tool',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {
input: { type: 'string' }
}
}
},
{
id: 'number-tool-id',
serverId: 'server2',
serverName: 'server2',
name: 'number-tool',
description: 'Number tool',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {
count: { type: 'number' }
}
}
},
{
id: 'boolean-tool-id',
serverId: 'server3',
serverName: 'server3',
name: 'boolean-tool',
description: 'Boolean tool',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {
enabled: { type: 'boolean' }
}
}
}
]
const result = convertMcpToolsToAiSdkTools(mcpTools)
expect(Object.keys(result).sort()).toEqual(['boolean-tool', 'number-tool', 'string-tool'])
expect(result['string-tool']).toBeDefined()
expect(result['number-tool']).toBeDefined()
expect(result['boolean-tool']).toBeDefined()
})
})
describe('tool execution', () => {
it('should execute tool with user confirmation', async () => {
const { callMCPTool } = await import('@renderer/utils/mcp-tools')
const { requestToolConfirmation } = await import('@renderer/utils/userConfirmation')
vi.mocked(requestToolConfirmation).mockResolvedValue(true)
vi.mocked(callMCPTool).mockResolvedValue({
content: [{ type: 'text', text: 'Success' }],
isError: false
})
const mcpTools: MCPTool[] = [
{
id: 'test-exec-tool-id',
serverId: 'test-server',
serverName: 'test-server',
name: 'test-exec-tool',
description: 'Test execution tool',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {}
}
}
]
const tools = convertMcpToolsToAiSdkTools(mcpTools)
const tool = tools['test-exec-tool'] as Tool
const result = await tool.execute!({}, { messages: [], abortSignal: undefined, toolCallId: 'test-call-123' })
expect(requestToolConfirmation).toHaveBeenCalled()
expect(callMCPTool).toHaveBeenCalled()
expect(result).toEqual({
content: [{ type: 'text', text: 'Success' }],
isError: false
})
})
it('should handle user cancellation', async () => {
const { requestToolConfirmation } = await import('@renderer/utils/userConfirmation')
const { callMCPTool } = await import('@renderer/utils/mcp-tools')
vi.mocked(requestToolConfirmation).mockResolvedValue(false)
const mcpTools: MCPTool[] = [
{
id: 'cancelled-tool-id',
serverId: 'test-server',
serverName: 'test-server',
name: 'cancelled-tool',
description: 'Tool to cancel',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {}
}
}
]
const tools = convertMcpToolsToAiSdkTools(mcpTools)
const tool = tools['cancelled-tool'] as Tool
const result = await tool.execute!({}, { messages: [], abortSignal: undefined, toolCallId: 'cancel-call-123' })
expect(requestToolConfirmation).toHaveBeenCalled()
expect(callMCPTool).not.toHaveBeenCalled()
expect(result).toEqual({
content: [
{
type: 'text',
text: 'User declined to execute tool "cancelled-tool".'
}
],
isError: false
})
})
it('should handle tool execution error', async () => {
const { callMCPTool } = await import('@renderer/utils/mcp-tools')
const { requestToolConfirmation } = await import('@renderer/utils/userConfirmation')
vi.mocked(requestToolConfirmation).mockResolvedValue(true)
vi.mocked(callMCPTool).mockResolvedValue({
content: [{ type: 'text', text: 'Error occurred' }],
isError: true
})
const mcpTools: MCPTool[] = [
{
id: 'error-tool-id',
serverId: 'test-server',
serverName: 'test-server',
name: 'error-tool',
description: 'Tool that errors',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {}
}
}
]
const tools = convertMcpToolsToAiSdkTools(mcpTools)
const tool = tools['error-tool'] as Tool
await expect(
tool.execute!({}, { messages: [], abortSignal: undefined, toolCallId: 'error-call-123' })
).rejects.toEqual({
content: [{ type: 'text', text: 'Error occurred' }],
isError: true
})
})
it('should auto-approve when enabled', async () => {
const { callMCPTool, isToolAutoApproved } = await import('@renderer/utils/mcp-tools')
const { requestToolConfirmation } = await import('@renderer/utils/userConfirmation')
vi.mocked(isToolAutoApproved).mockReturnValue(true)
vi.mocked(callMCPTool).mockResolvedValue({
content: [{ type: 'text', text: 'Auto-approved success' }],
isError: false
})
const mcpTools: MCPTool[] = [
{
id: 'auto-approve-tool-id',
serverId: 'test-server',
serverName: 'test-server',
name: 'auto-approve-tool',
description: 'Auto-approved tool',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {}
}
}
]
const tools = convertMcpToolsToAiSdkTools(mcpTools)
const tool = tools['auto-approve-tool'] as Tool
const result = await tool.execute!({}, { messages: [], abortSignal: undefined, toolCallId: 'auto-call-123' })
expect(requestToolConfirmation).not.toHaveBeenCalled()
expect(callMCPTool).toHaveBeenCalled()
expect(result).toEqual({
content: [{ type: 'text', text: 'Auto-approved success' }],
isError: false
})
})
})
})

View File

@ -0,0 +1,545 @@
/**
* options.ts Unit Tests
* Tests for building provider-specific options
*/
import type { Assistant, Model, Provider } from '@renderer/types'
import { OpenAIServiceTiers, SystemProviderIds } from '@renderer/types'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { buildProviderOptions } from '../options'
// Mock dependencies
vi.mock('@cherrystudio/ai-core/provider', async (importOriginal) => {
const actual = (await importOriginal()) as object
return {
...actual,
baseProviderIdSchema: {
safeParse: vi.fn((id) => {
const baseProviders = [
'openai',
'openai-chat',
'azure',
'azure-responses',
'huggingface',
'anthropic',
'google',
'xai',
'deepseek',
'openrouter',
'openai-compatible'
]
if (baseProviders.includes(id)) {
return { success: true, data: id }
}
return { success: false }
})
},
customProviderIdSchema: {
safeParse: vi.fn((id) => {
const customProviders = ['google-vertex', 'google-vertex-anthropic', 'bedrock']
if (customProviders.includes(id)) {
return { success: true, data: id }
}
return { success: false, error: new Error('Invalid provider') }
})
}
}
})
vi.mock('../provider/factory', () => ({
getAiSdkProviderId: vi.fn((provider) => {
// Simulate the provider ID mapping
const mapping: Record<string, string> = {
[SystemProviderIds.gemini]: 'google',
[SystemProviderIds.openai]: 'openai',
[SystemProviderIds.anthropic]: 'anthropic',
[SystemProviderIds.grok]: 'xai',
[SystemProviderIds.deepseek]: 'deepseek',
[SystemProviderIds.openrouter]: 'openrouter'
}
return mapping[provider.id] || provider.id
})
}))
vi.mock('@renderer/config/models', async (importOriginal) => ({
...(await importOriginal()),
isOpenAIModel: vi.fn((model) => model.id.includes('gpt') || model.id.includes('o1')),
isQwenMTModel: vi.fn(() => false),
isSupportFlexServiceTierModel: vi.fn(() => true),
isOpenAILLMModel: vi.fn(() => true),
SYSTEM_MODELS: {
defaultModel: [
{ id: 'default-1', name: 'Default 1' },
{ id: 'default-2', name: 'Default 2' },
{ id: 'default-3', name: 'Default 3' }
]
}
}))
vi.mock(import('@renderer/utils/provider'), async (importOriginal) => {
return {
...(await importOriginal()),
isSupportServiceTierProvider: vi.fn((provider) => {
return [SystemProviderIds.openai, SystemProviderIds.groq].includes(provider.id)
})
}
})
vi.mock('@renderer/store/settings', () => ({
default: (state = { settings: {} }) => state
}))
vi.mock('@renderer/hooks/useSettings', () => ({
getStoreSetting: vi.fn((key) => {
if (key === 'openAI') {
return { summaryText: 'off', verbosity: 'medium' } as any
}
return {}
})
}))
vi.mock('@renderer/services/AssistantService', () => ({
getDefaultAssistant: vi.fn(() => ({
id: 'default',
name: 'Default Assistant',
settings: {}
})),
getAssistantSettings: vi.fn(() => ({
reasoning_effort: 'medium',
maxTokens: 4096
})),
getProviderByModel: vi.fn((model: Model) => ({
id: model.provider,
name: 'Mock Provider'
}))
}))
vi.mock('../reasoning', () => ({
getOpenAIReasoningParams: vi.fn(() => ({ reasoningEffort: 'medium' })),
getAnthropicReasoningParams: vi.fn(() => ({
thinking: { type: 'enabled', budgetTokens: 5000 }
})),
getGeminiReasoningParams: vi.fn(() => ({
thinkingConfig: { include_thoughts: true }
})),
getXAIReasoningParams: vi.fn(() => ({ reasoningEffort: 'high' })),
getBedrockReasoningParams: vi.fn(() => ({
reasoningConfig: { type: 'enabled', budgetTokens: 5000 }
})),
getReasoningEffort: vi.fn(() => ({ reasoningEffort: 'medium' })),
getCustomParameters: vi.fn(() => ({}))
}))
vi.mock('../image', () => ({
buildGeminiGenerateImageParams: vi.fn(() => ({
responseModalities: ['TEXT', 'IMAGE']
}))
}))
vi.mock('../websearch', () => ({
getWebSearchParams: vi.fn(() => ({ enable_search: true }))
}))
const ensureWindowApi = () => {
const globalWindow = window as any
globalWindow.api = globalWindow.api || {}
globalWindow.api.getAppInfo = globalWindow.api.getAppInfo || vi.fn(async () => ({ notesPath: '' }))
}
ensureWindowApi()
describe('options utils', () => {
const mockAssistant: Assistant = {
id: 'test-assistant',
name: 'Test Assistant',
settings: {}
} as Assistant
const mockModel: Model = {
id: 'gpt-4',
name: 'GPT-4',
provider: SystemProviderIds.openai
} as Model
beforeEach(() => {
vi.clearAllMocks()
})
describe('buildProviderOptions', () => {
describe('OpenAI provider', () => {
const openaiProvider: Provider = {
id: SystemProviderIds.openai,
name: 'OpenAI',
type: 'openai-response',
apiKey: 'test-key',
apiHost: 'https://api.openai.com/v1',
isSystem: true
} as Provider
it('should build basic OpenAI options', () => {
const result = buildProviderOptions(mockAssistant, mockModel, openaiProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result).toHaveProperty('openai')
expect(result.openai).toBeDefined()
})
it('should include reasoning parameters when enabled', () => {
const result = buildProviderOptions(mockAssistant, mockModel, openaiProvider, {
enableReasoning: true,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result.openai).toHaveProperty('reasoningEffort')
expect(result.openai.reasoningEffort).toBe('medium')
})
it('should include service tier when supported', () => {
const providerWithServiceTier: Provider = {
...openaiProvider,
serviceTier: OpenAIServiceTiers.auto
}
const result = buildProviderOptions(mockAssistant, mockModel, providerWithServiceTier, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result.openai).toHaveProperty('serviceTier')
expect(result.openai.serviceTier).toBe(OpenAIServiceTiers.auto)
})
})
describe('Anthropic provider', () => {
const anthropicProvider: Provider = {
id: SystemProviderIds.anthropic,
name: 'Anthropic',
type: 'anthropic',
apiKey: 'test-key',
apiHost: 'https://api.anthropic.com',
isSystem: true
} as Provider
const anthropicModel: Model = {
id: 'claude-3-5-sonnet-20241022',
name: 'Claude 3.5 Sonnet',
provider: SystemProviderIds.anthropic
} as Model
it('should build basic Anthropic options', () => {
const result = buildProviderOptions(mockAssistant, anthropicModel, anthropicProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result).toHaveProperty('anthropic')
expect(result.anthropic).toBeDefined()
})
it('should include reasoning parameters when enabled', () => {
const result = buildProviderOptions(mockAssistant, anthropicModel, anthropicProvider, {
enableReasoning: true,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result.anthropic).toHaveProperty('thinking')
expect(result.anthropic.thinking).toEqual({
type: 'enabled',
budgetTokens: 5000
})
})
})
describe('Google provider', () => {
const googleProvider: Provider = {
id: SystemProviderIds.gemini,
name: 'Google',
type: 'gemini',
apiKey: 'test-key',
apiHost: 'https://generativelanguage.googleapis.com',
isSystem: true,
models: [{ id: 'gemini-2.0-flash-exp' }] as Model[]
} as Provider
const googleModel: Model = {
id: 'gemini-2.0-flash-exp',
name: 'Gemini 2.0 Flash',
provider: SystemProviderIds.gemini
} as Model
it('should build basic Google options', () => {
const result = buildProviderOptions(mockAssistant, googleModel, googleProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result).toHaveProperty('google')
expect(result.google).toBeDefined()
})
it('should include reasoning parameters when enabled', () => {
const result = buildProviderOptions(mockAssistant, googleModel, googleProvider, {
enableReasoning: true,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result.google).toHaveProperty('thinkingConfig')
expect(result.google.thinkingConfig).toEqual({
include_thoughts: true
})
})
it('should include image generation parameters when enabled', () => {
const result = buildProviderOptions(mockAssistant, googleModel, googleProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: true
})
expect(result.google).toHaveProperty('responseModalities')
expect(result.google.responseModalities).toEqual(['TEXT', 'IMAGE'])
})
})
describe('xAI provider', () => {
const xaiProvider = {
id: SystemProviderIds.grok,
name: 'xAI',
type: 'new-api',
apiKey: 'test-key',
apiHost: 'https://api.x.ai/v1',
isSystem: true,
models: [] as Model[]
} as Provider
const xaiModel: Model = {
id: 'grok-2-latest',
name: 'Grok 2',
provider: SystemProviderIds.grok
} as Model
it('should build basic xAI options', () => {
const result = buildProviderOptions(mockAssistant, xaiModel, xaiProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result).toHaveProperty('xai')
expect(result.xai).toBeDefined()
})
it('should include reasoning parameters when enabled', () => {
const result = buildProviderOptions(mockAssistant, xaiModel, xaiProvider, {
enableReasoning: true,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result.xai).toHaveProperty('reasoningEffort')
expect(result.xai.reasoningEffort).toBe('high')
})
})
describe('DeepSeek provider', () => {
const deepseekProvider: Provider = {
id: SystemProviderIds.deepseek,
name: 'DeepSeek',
type: 'openai',
apiKey: 'test-key',
apiHost: 'https://api.deepseek.com',
isSystem: true
} as Provider
const deepseekModel: Model = {
id: 'deepseek-chat',
name: 'DeepSeek Chat',
provider: SystemProviderIds.deepseek
} as Model
it('should build basic DeepSeek options', () => {
const result = buildProviderOptions(mockAssistant, deepseekModel, deepseekProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result).toHaveProperty('deepseek')
expect(result.deepseek).toBeDefined()
})
})
describe('OpenRouter provider', () => {
const openrouterProvider: Provider = {
id: SystemProviderIds.openrouter,
name: 'OpenRouter',
type: 'openai',
apiKey: 'test-key',
apiHost: 'https://openrouter.ai/api/v1',
isSystem: true
} as Provider
const openrouterModel: Model = {
id: 'openai/gpt-4',
name: 'GPT-4',
provider: SystemProviderIds.openrouter
} as Model
it('should build basic OpenRouter options', () => {
const result = buildProviderOptions(mockAssistant, openrouterModel, openrouterProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result).toHaveProperty('openrouter')
expect(result.openrouter).toBeDefined()
})
it('should include web search parameters when enabled', () => {
const result = buildProviderOptions(mockAssistant, openrouterModel, openrouterProvider, {
enableReasoning: false,
enableWebSearch: true,
enableGenerateImage: false
})
expect(result.openrouter).toHaveProperty('enable_search')
})
})
describe('Custom parameters', () => {
it('should merge custom parameters', async () => {
const { getCustomParameters } = await import('../reasoning')
vi.mocked(getCustomParameters).mockReturnValue({
custom_param: 'custom_value',
another_param: 123
})
const result = buildProviderOptions(
mockAssistant,
mockModel,
{
id: SystemProviderIds.openai,
name: 'OpenAI',
type: 'openai',
apiKey: 'test-key',
apiHost: 'https://api.openai.com/v1'
} as Provider,
{
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
}
)
expect(result.openai).toHaveProperty('custom_param')
expect(result.openai.custom_param).toBe('custom_value')
expect(result.openai).toHaveProperty('another_param')
expect(result.openai.another_param).toBe(123)
})
})
describe('Multiple capabilities', () => {
const googleProvider = {
id: SystemProviderIds.gemini,
name: 'Google',
type: 'gemini',
apiKey: 'test-key',
apiHost: 'https://generativelanguage.googleapis.com',
isSystem: true,
models: [] as Model[]
} as Provider
const googleModel: Model = {
id: 'gemini-2.0-flash-exp',
name: 'Gemini 2.0 Flash',
provider: SystemProviderIds.gemini
} as Model
it('should combine reasoning and image generation', () => {
const result = buildProviderOptions(mockAssistant, googleModel, googleProvider, {
enableReasoning: true,
enableWebSearch: false,
enableGenerateImage: true
})
expect(result.google).toHaveProperty('thinkingConfig')
expect(result.google).toHaveProperty('responseModalities')
})
it('should handle all capabilities enabled', () => {
const result = buildProviderOptions(mockAssistant, googleModel, googleProvider, {
enableReasoning: true,
enableWebSearch: true,
enableGenerateImage: true
})
expect(result.google).toBeDefined()
expect(Object.keys(result.google).length).toBeGreaterThan(0)
})
})
describe('Vertex AI providers', () => {
it('should map google-vertex to google', () => {
const vertexProvider = {
id: 'google-vertex',
name: 'Vertex AI',
type: 'vertexai',
apiKey: 'test-key',
apiHost: 'https://vertex-ai.googleapis.com',
models: [] as Model[]
} as Provider
const vertexModel: Model = {
id: 'gemini-2.0-flash-exp',
name: 'Gemini 2.0 Flash',
provider: 'google-vertex'
} as Model
const result = buildProviderOptions(mockAssistant, vertexModel, vertexProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result).toHaveProperty('google')
})
it('should map google-vertex-anthropic to anthropic', () => {
const vertexAnthropicProvider = {
id: 'google-vertex-anthropic',
name: 'Vertex AI Anthropic',
type: 'vertex-anthropic',
apiKey: 'test-key',
apiHost: 'https://vertex-ai.googleapis.com',
models: [] as Model[]
} as Provider
const vertexModel: Model = {
id: 'claude-3-5-sonnet-20241022',
name: 'Claude 3.5 Sonnet',
provider: 'google-vertex-anthropic'
} as Model
const result = buildProviderOptions(mockAssistant, vertexModel, vertexAnthropicProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result).toHaveProperty('anthropic')
})
})
})
})

View File

@ -0,0 +1,967 @@
/**
* reasoning.ts Unit Tests
* Tests for reasoning parameter generation utilities
*/
import { getStoreSetting } from '@renderer/hooks/useSettings'
import type { SettingsState } from '@renderer/store/settings'
import type { Assistant, Model, Provider } from '@renderer/types'
import { SystemProviderIds } from '@renderer/types'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import {
getAnthropicReasoningParams,
getBedrockReasoningParams,
getCustomParameters,
getGeminiReasoningParams,
getOpenAIReasoningParams,
getReasoningEffort,
getXAIReasoningParams
} from '../reasoning'
function defaultGetStoreSetting<K extends keyof SettingsState>(key: K): SettingsState[K] {
if (key === 'openAI') {
return {
summaryText: 'auto',
verbosity: 'medium'
} as SettingsState[K]
}
return undefined as SettingsState[K]
}
// Mock dependencies
vi.mock('@logger', () => ({
loggerService: {
withContext: () => ({
debug: vi.fn(),
error: vi.fn(),
warn: vi.fn(),
info: vi.fn()
})
}
}))
vi.mock('@renderer/store/settings', () => ({
default: (state = { settings: {} }) => state
}))
vi.mock('@renderer/store/llm', () => ({
initialState: {},
default: (state = { llm: {} }) => state
}))
vi.mock('@renderer/config/constant', () => ({
DEFAULT_MAX_TOKENS: 4096,
isMac: false,
isWin: false,
TOKENFLUX_HOST: 'mock-host'
}))
vi.mock('@renderer/utils/provider', () => ({
isSupportEnableThinkingProvider: vi.fn((provider) => {
return [SystemProviderIds.dashscope, SystemProviderIds.silicon].includes(provider.id)
})
}))
vi.mock('@renderer/config/models', async (importOriginal) => {
const actual: any = await importOriginal()
return {
...actual,
isReasoningModel: vi.fn(() => false),
isOpenAIDeepResearchModel: vi.fn(() => false),
isOpenAIModel: vi.fn(() => false),
isSupportedReasoningEffortOpenAIModel: vi.fn(() => false),
isSupportedThinkingTokenQwenModel: vi.fn(() => false),
isQwenReasoningModel: vi.fn(() => false),
isSupportedThinkingTokenClaudeModel: vi.fn(() => false),
isSupportedThinkingTokenGeminiModel: vi.fn(() => false),
isSupportedThinkingTokenDoubaoModel: vi.fn(() => false),
isSupportedThinkingTokenZhipuModel: vi.fn(() => false),
isSupportedReasoningEffortModel: vi.fn(() => false),
isDeepSeekHybridInferenceModel: vi.fn(() => false),
isSupportedReasoningEffortGrokModel: vi.fn(() => false),
getThinkModelType: vi.fn(() => 'default'),
isDoubaoSeedAfter251015: vi.fn(() => false),
isDoubaoThinkingAutoModel: vi.fn(() => false),
isGrok4FastReasoningModel: vi.fn(() => false),
isGrokReasoningModel: vi.fn(() => false),
isOpenAIReasoningModel: vi.fn(() => false),
isQwenAlwaysThinkModel: vi.fn(() => false),
isSupportedThinkingTokenHunyuanModel: vi.fn(() => false),
isSupportedThinkingTokenModel: vi.fn(() => false),
isGPT51SeriesModel: vi.fn(() => false)
}
})
vi.mock('@renderer/hooks/useSettings', () => ({
getStoreSetting: vi.fn(defaultGetStoreSetting)
}))
vi.mock('@renderer/services/AssistantService', () => ({
getAssistantSettings: vi.fn((assistant) => ({
maxTokens: assistant?.settings?.maxTokens || 4096,
reasoning_effort: assistant?.settings?.reasoning_effort
})),
getProviderByModel: vi.fn((model) => ({
id: model.provider,
name: 'Test Provider'
})),
getDefaultAssistant: vi.fn(() => ({
id: 'default',
name: 'Default Assistant',
settings: {}
}))
}))
const ensureWindowApi = () => {
const globalWindow = window as any
globalWindow.api = globalWindow.api || {}
globalWindow.api.getAppInfo = globalWindow.api.getAppInfo || vi.fn(async () => ({ notesPath: '' }))
}
ensureWindowApi()
describe('reasoning utils', () => {
beforeEach(() => {
vi.resetAllMocks()
})
describe('getReasoningEffort', () => {
it('should return empty object for non-reasoning model', async () => {
const model: Model = {
id: 'gpt-4',
name: 'GPT-4',
provider: SystemProviderIds.openai
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getReasoningEffort(assistant, model)
expect(result).toEqual({})
})
it('should disable reasoning for OpenRouter when no reasoning effort set', async () => {
const { isReasoningModel } = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(true)
const model: Model = {
id: 'anthropic/claude-sonnet-4',
name: 'Claude Sonnet 4',
provider: SystemProviderIds.openrouter
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getReasoningEffort(assistant, model)
expect(result).toEqual({ reasoning: { enabled: false, exclude: true } })
})
it('should handle Qwen models with enable_thinking', async () => {
const { isReasoningModel, isSupportedThinkingTokenQwenModel, isQwenReasoningModel } = await import(
'@renderer/config/models'
)
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isSupportedThinkingTokenQwenModel).mockReturnValue(true)
vi.mocked(isQwenReasoningModel).mockReturnValue(true)
const model: Model = {
id: 'qwen-plus',
name: 'Qwen Plus',
provider: SystemProviderIds.dashscope
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
reasoning_effort: 'medium'
}
} as Assistant
const result = getReasoningEffort(assistant, model)
expect(result).toHaveProperty('enable_thinking')
})
it('should handle Claude models with thinking config', async () => {
const {
isSupportedThinkingTokenClaudeModel,
isReasoningModel,
isQwenReasoningModel,
isSupportedThinkingTokenGeminiModel,
isSupportedThinkingTokenDoubaoModel,
isSupportedThinkingTokenZhipuModel,
isSupportedReasoningEffortModel
} = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isSupportedThinkingTokenClaudeModel).mockReturnValue(true)
vi.mocked(isQwenReasoningModel).mockReturnValue(false)
vi.mocked(isSupportedThinkingTokenGeminiModel).mockReturnValue(false)
vi.mocked(isSupportedThinkingTokenDoubaoModel).mockReturnValue(false)
vi.mocked(isSupportedThinkingTokenZhipuModel).mockReturnValue(false)
vi.mocked(isSupportedReasoningEffortModel).mockReturnValue(false)
const model: Model = {
id: 'claude-3-7-sonnet',
name: 'Claude 3.7 Sonnet',
provider: SystemProviderIds.anthropic
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
reasoning_effort: 'high',
maxTokens: 4096
}
} as Assistant
const result = getReasoningEffort(assistant, model)
expect(result).toEqual({
thinking: {
type: 'enabled',
budget_tokens: expect.any(Number)
}
})
})
it('should handle Gemini Flash models with thinking budget 0', async () => {
const {
isSupportedThinkingTokenGeminiModel,
isReasoningModel,
isQwenReasoningModel,
isSupportedThinkingTokenClaudeModel,
isSupportedThinkingTokenDoubaoModel,
isSupportedThinkingTokenZhipuModel,
isOpenAIDeepResearchModel,
isSupportedThinkingTokenQwenModel,
isSupportedThinkingTokenHunyuanModel,
isDeepSeekHybridInferenceModel
} = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isOpenAIDeepResearchModel).mockReturnValue(false)
vi.mocked(isSupportedThinkingTokenGeminiModel).mockReturnValue(true)
vi.mocked(isQwenReasoningModel).mockReturnValue(false)
vi.mocked(isSupportedThinkingTokenClaudeModel).mockReturnValue(false)
vi.mocked(isSupportedThinkingTokenDoubaoModel).mockReturnValue(false)
vi.mocked(isSupportedThinkingTokenZhipuModel).mockReturnValue(false)
vi.mocked(isSupportedThinkingTokenQwenModel).mockReturnValue(false)
vi.mocked(isSupportedThinkingTokenHunyuanModel).mockReturnValue(false)
vi.mocked(isDeepSeekHybridInferenceModel).mockReturnValue(false)
const model: Model = {
id: 'gemini-2.5-flash',
name: 'Gemini 2.5 Flash',
provider: SystemProviderIds.openai
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getReasoningEffort(assistant, model)
expect(result).toEqual({
extra_body: {
google: {
thinking_config: {
thinking_budget: 0
}
}
}
})
})
it('should handle GPT-5.1 reasoning model with effort levels', async () => {
const {
isReasoningModel,
isOpenAIDeepResearchModel,
isSupportedReasoningEffortModel,
isGPT51SeriesModel,
getThinkModelType
} = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isOpenAIDeepResearchModel).mockReturnValue(false)
vi.mocked(isSupportedReasoningEffortModel).mockReturnValue(true)
vi.mocked(getThinkModelType).mockReturnValue('gpt5_1')
vi.mocked(isGPT51SeriesModel).mockReturnValue(true)
const model: Model = {
id: 'gpt-5.1',
name: 'GPT-5.1',
provider: SystemProviderIds.openai
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
reasoning_effort: 'none'
}
} as Assistant
const result = getReasoningEffort(assistant, model)
expect(result).toEqual({
reasoningEffort: 'none'
})
})
it('should handle DeepSeek hybrid inference models', async () => {
const { isReasoningModel, isDeepSeekHybridInferenceModel } = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isDeepSeekHybridInferenceModel).mockReturnValue(true)
const model: Model = {
id: 'deepseek-v3.1',
name: 'DeepSeek V3.1',
provider: SystemProviderIds.silicon
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
reasoning_effort: 'high'
}
} as Assistant
const result = getReasoningEffort(assistant, model)
expect(result).toEqual({
enable_thinking: true
})
})
it('should return medium effort for deep research models', async () => {
const { isReasoningModel, isOpenAIDeepResearchModel } = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isOpenAIDeepResearchModel).mockReturnValue(true)
const model: Model = {
id: 'o3-deep-research',
provider: SystemProviderIds.openai
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getReasoningEffort(assistant, model)
expect(result).toEqual({ reasoning_effort: 'medium' })
})
it('should return empty for groq provider', async () => {
const { getProviderByModel } = await import('@renderer/services/AssistantService')
vi.mocked(getProviderByModel).mockReturnValue({
id: 'groq',
name: 'Groq'
} as Provider)
const model: Model = {
id: 'groq-model',
name: 'Groq Model',
provider: 'groq'
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getReasoningEffort(assistant, model)
expect(result).toEqual({})
})
})
describe('getOpenAIReasoningParams', () => {
it('should return empty object for non-reasoning model', async () => {
const model: Model = {
id: 'gpt-4',
name: 'GPT-4',
provider: SystemProviderIds.openai
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getOpenAIReasoningParams(assistant, model)
expect(result).toEqual({})
})
it('should return empty when no reasoning effort set', async () => {
const model: Model = {
id: 'o1-preview',
name: 'O1 Preview',
provider: SystemProviderIds.openai
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getOpenAIReasoningParams(assistant, model)
expect(result).toEqual({})
})
it('should return reasoning effort for OpenAI models', async () => {
const { isReasoningModel, isOpenAIModel, isSupportedReasoningEffortOpenAIModel } = await import(
'@renderer/config/models'
)
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isOpenAIModel).mockReturnValue(true)
vi.mocked(isSupportedReasoningEffortOpenAIModel).mockReturnValue(true)
const model: Model = {
id: 'gpt-5.1',
name: 'GPT 5.1',
provider: SystemProviderIds.openai
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
reasoning_effort: 'high'
}
} as Assistant
const result = getOpenAIReasoningParams(assistant, model)
expect(result).toEqual({
reasoningEffort: 'high',
reasoningSummary: 'auto'
})
})
it('should include reasoning summary when not o1-pro', async () => {
const { isReasoningModel, isOpenAIModel, isSupportedReasoningEffortOpenAIModel } = await import(
'@renderer/config/models'
)
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isOpenAIModel).mockReturnValue(true)
vi.mocked(isSupportedReasoningEffortOpenAIModel).mockReturnValue(true)
const model: Model = {
id: 'gpt-5',
provider: SystemProviderIds.openai
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
reasoning_effort: 'medium'
}
} as Assistant
const result = getOpenAIReasoningParams(assistant, model)
expect(result).toEqual({
reasoningEffort: 'medium',
reasoningSummary: 'auto'
})
})
it('should not include reasoning summary for o1-pro', async () => {
const { isReasoningModel, isOpenAIDeepResearchModel, isSupportedReasoningEffortOpenAIModel } = await import(
'@renderer/config/models'
)
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isOpenAIDeepResearchModel).mockReturnValue(false)
vi.mocked(isSupportedReasoningEffortOpenAIModel).mockReturnValue(true)
vi.mocked(getStoreSetting).mockReturnValue({ summaryText: 'off' } as any)
const model: Model = {
id: 'o1-pro',
name: 'O1 Pro',
provider: SystemProviderIds.openai
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
reasoning_effort: 'high'
}
} as Assistant
const result = getOpenAIReasoningParams(assistant, model)
expect(result).toEqual({
reasoningEffort: 'high',
reasoningSummary: undefined
})
})
it('should force medium effort for deep research models', async () => {
const { isReasoningModel, isOpenAIModel, isOpenAIDeepResearchModel, isSupportedReasoningEffortOpenAIModel } =
await import('@renderer/config/models')
const { getStoreSetting } = await import('@renderer/hooks/useSettings')
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isOpenAIModel).mockReturnValue(true)
vi.mocked(isOpenAIDeepResearchModel).mockReturnValue(true)
vi.mocked(isSupportedReasoningEffortOpenAIModel).mockReturnValue(true)
vi.mocked(getStoreSetting).mockReturnValue({ summaryText: 'off' } as any)
const model: Model = {
id: 'o3-deep-research',
name: 'O3 Mini',
provider: SystemProviderIds.openai
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
reasoning_effort: 'high'
}
} as Assistant
const result = getOpenAIReasoningParams(assistant, model)
expect(result).toEqual({
reasoningEffort: 'medium',
reasoningSummary: 'off'
})
})
})
describe('getAnthropicReasoningParams', () => {
it('should return empty for non-reasoning model', async () => {
const { isReasoningModel } = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(false)
const model: Model = {
id: 'claude-3-5-sonnet',
name: 'Claude 3.5 Sonnet',
provider: SystemProviderIds.anthropic
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getAnthropicReasoningParams(assistant, model)
expect(result).toEqual({})
})
it('should return disabled thinking when no reasoning effort', async () => {
const { isReasoningModel, isSupportedThinkingTokenClaudeModel } = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isSupportedThinkingTokenClaudeModel).mockReturnValue(false)
const model: Model = {
id: 'claude-3-7-sonnet',
name: 'Claude 3.7 Sonnet',
provider: SystemProviderIds.anthropic
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getAnthropicReasoningParams(assistant, model)
expect(result).toEqual({
thinking: {
type: 'disabled'
}
})
})
it('should return enabled thinking with budget for Claude models', async () => {
const { isReasoningModel, isSupportedThinkingTokenClaudeModel } = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isSupportedThinkingTokenClaudeModel).mockReturnValue(true)
const model: Model = {
id: 'claude-3-7-sonnet',
name: 'Claude 3.7 Sonnet',
provider: SystemProviderIds.anthropic
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
reasoning_effort: 'medium',
maxTokens: 4096
}
} as Assistant
const result = getAnthropicReasoningParams(assistant, model)
expect(result).toEqual({
thinking: {
type: 'enabled',
budgetTokens: 2048
}
})
})
})
describe('getGeminiReasoningParams', () => {
it('should return empty for non-reasoning model', async () => {
const { isReasoningModel } = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(false)
const model: Model = {
id: 'gemini-2.0-flash',
name: 'Gemini 2.0 Flash',
provider: SystemProviderIds.gemini
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getGeminiReasoningParams(assistant, model)
expect(result).toEqual({})
})
it('should disable thinking for Flash models without reasoning effort', async () => {
const { isReasoningModel, isSupportedThinkingTokenGeminiModel } = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isSupportedThinkingTokenGeminiModel).mockReturnValue(true)
const model: Model = {
id: 'gemini-2.5-flash',
name: 'Gemini 2.5 Flash',
provider: SystemProviderIds.gemini
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getGeminiReasoningParams(assistant, model)
expect(result).toEqual({
thinkingConfig: {
includeThoughts: false,
thinkingBudget: 0
}
})
})
it('should enable thinking with budget for reasoning effort', async () => {
const { isReasoningModel, isSupportedThinkingTokenGeminiModel } = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isSupportedThinkingTokenGeminiModel).mockReturnValue(true)
const model: Model = {
id: 'gemini-2.5-pro',
name: 'Gemini 2.5 Pro',
provider: SystemProviderIds.gemini
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
reasoning_effort: 'medium'
}
} as Assistant
const result = getGeminiReasoningParams(assistant, model)
expect(result).toEqual({
thinkingConfig: {
thinkingBudget: 16448,
includeThoughts: true
}
})
})
it('should enable thinking without budget for auto effort ratio > 1', async () => {
const { isReasoningModel, isSupportedThinkingTokenGeminiModel } = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isSupportedThinkingTokenGeminiModel).mockReturnValue(true)
const model: Model = {
id: 'gemini-2.5-pro',
name: 'Gemini 2.5 Pro',
provider: SystemProviderIds.gemini
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
reasoning_effort: 'auto'
}
} as Assistant
const result = getGeminiReasoningParams(assistant, model)
expect(result).toEqual({
thinkingConfig: {
includeThoughts: true
}
})
})
})
describe('getXAIReasoningParams', () => {
it('should return empty for non-Grok model', async () => {
const { isSupportedReasoningEffortGrokModel } = await import('@renderer/config/models')
vi.mocked(isSupportedReasoningEffortGrokModel).mockReturnValue(false)
const model: Model = {
id: 'other-model',
name: 'Other Model',
provider: SystemProviderIds.grok
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getXAIReasoningParams(assistant, model)
expect(result).toEqual({})
})
it('should return empty when no reasoning effort', async () => {
const { isSupportedReasoningEffortGrokModel } = await import('@renderer/config/models')
vi.mocked(isSupportedReasoningEffortGrokModel).mockReturnValue(true)
const model: Model = {
id: 'grok-2',
name: 'Grok 2',
provider: SystemProviderIds.grok
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getXAIReasoningParams(assistant, model)
expect(result).toEqual({})
})
it('should return reasoning effort for Grok models', async () => {
const { isSupportedReasoningEffortGrokModel } = await import('@renderer/config/models')
vi.mocked(isSupportedReasoningEffortGrokModel).mockReturnValue(true)
const model: Model = {
id: 'grok-3',
name: 'Grok 3',
provider: SystemProviderIds.grok
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
reasoning_effort: 'high'
}
} as Assistant
const result = getXAIReasoningParams(assistant, model)
expect(result).toHaveProperty('reasoningEffort')
expect(result.reasoningEffort).toBe('high')
})
})
describe('getBedrockReasoningParams', () => {
it('should return empty for non-reasoning model', async () => {
const model: Model = {
id: 'other-model',
name: 'Other Model',
provider: 'bedrock'
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getBedrockReasoningParams(assistant, model)
expect(result).toEqual({})
})
it('should return empty when no reasoning effort', async () => {
const model: Model = {
id: 'claude-3-7-sonnet',
name: 'Claude 3.7 Sonnet',
provider: 'bedrock'
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getBedrockReasoningParams(assistant, model)
expect(result).toEqual({})
})
it('should return reasoning config for Claude models on Bedrock', async () => {
const { isReasoningModel, isSupportedThinkingTokenClaudeModel } = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isSupportedThinkingTokenClaudeModel).mockReturnValue(true)
const model: Model = {
id: 'claude-3-7-sonnet',
name: 'Claude 3.7 Sonnet',
provider: 'bedrock'
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
reasoning_effort: 'medium',
maxTokens: 4096
}
} as Assistant
const result = getBedrockReasoningParams(assistant, model)
expect(result).toEqual({
reasoningConfig: {
type: 'enabled',
budgetTokens: 2048
}
})
})
})
describe('getCustomParameters', () => {
it('should return empty object when no custom parameters', async () => {
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getCustomParameters(assistant)
expect(result).toEqual({})
})
it('should return custom parameters as key-value pairs', async () => {
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
customParameters: [
{ name: 'param1', value: 'value1', type: 'string' },
{ name: 'param2', value: 123, type: 'number' }
]
}
} as Assistant
const result = getCustomParameters(assistant)
expect(result).toEqual({
param1: 'value1',
param2: 123
})
})
it('should parse JSON type parameters', async () => {
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
customParameters: [{ name: 'config', value: '{"key": "value"}', type: 'json' }]
}
} as Assistant
const result = getCustomParameters(assistant)
expect(result).toEqual({
config: { key: 'value' }
})
})
it('should handle invalid JSON gracefully', async () => {
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
customParameters: [{ name: 'invalid', value: '{invalid json', type: 'json' }]
}
} as Assistant
const result = getCustomParameters(assistant)
expect(result).toEqual({
invalid: '{invalid json'
})
})
it('should handle undefined JSON value', async () => {
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
customParameters: [{ name: 'undef', value: 'undefined', type: 'json' }]
}
} as Assistant
const result = getCustomParameters(assistant)
expect(result).toEqual({
undef: undefined
})
})
it('should skip parameters with empty names', async () => {
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
customParameters: [
{ name: '', value: 'value1', type: 'string' },
{ name: ' ', value: 'value2', type: 'string' },
{ name: 'valid', value: 'value3', type: 'string' }
]
}
} as Assistant
const result = getCustomParameters(assistant)
expect(result).toEqual({
valid: 'value3'
})
})
})
})

View File

@ -0,0 +1,384 @@
/**
* websearch.ts Unit Tests
* Tests for web search parameters generation utilities
*/
import type { CherryWebSearchConfig } from '@renderer/store/websearch'
import type { Model } from '@renderer/types'
import { describe, expect, it, vi } from 'vitest'
import { buildProviderBuiltinWebSearchConfig, getWebSearchParams } from '../websearch'
// Mock dependencies
vi.mock('@renderer/config/models', () => ({
isOpenAIWebSearchChatCompletionOnlyModel: vi.fn((model) => model?.id?.includes('o1-pro') ?? false),
isOpenAIDeepResearchModel: vi.fn((model) => model?.id?.includes('o3-mini') ?? false)
}))
vi.mock('@renderer/utils/blacklistMatchPattern', () => ({
mapRegexToPatterns: vi.fn((patterns) => patterns || [])
}))
describe('websearch utils', () => {
describe('getWebSearchParams', () => {
it('should return enhancement params for hunyuan provider', () => {
const model: Model = {
id: 'hunyuan-model',
name: 'Hunyuan Model',
provider: 'hunyuan'
} as Model
const result = getWebSearchParams(model)
expect(result).toEqual({
enable_enhancement: true,
citation: true,
search_info: true
})
})
it('should return search params for dashscope provider', () => {
const model: Model = {
id: 'qwen-model',
name: 'Qwen Model',
provider: 'dashscope'
} as Model
const result = getWebSearchParams(model)
expect(result).toEqual({
enable_search: true,
search_options: {
forced_search: true
}
})
})
it('should return web_search_options for OpenAI web search models', () => {
const model: Model = {
id: 'o1-pro',
name: 'O1 Pro',
provider: 'openai'
} as Model
const result = getWebSearchParams(model)
expect(result).toEqual({
web_search_options: {}
})
})
it('should return empty object for other providers', () => {
const model: Model = {
id: 'gpt-4',
name: 'GPT-4',
provider: 'openai'
} as Model
const result = getWebSearchParams(model)
expect(result).toEqual({})
})
it('should return empty object for custom provider', () => {
const model: Model = {
id: 'custom-model',
name: 'Custom Model',
provider: 'custom-provider'
} as Model
const result = getWebSearchParams(model)
expect(result).toEqual({})
})
})
describe('buildProviderBuiltinWebSearchConfig', () => {
const defaultWebSearchConfig: CherryWebSearchConfig = {
searchWithTime: true,
maxResults: 50,
excludeDomains: []
}
describe('openai provider', () => {
it('should return low search context size for low maxResults', () => {
const config: CherryWebSearchConfig = {
searchWithTime: true,
maxResults: 20,
excludeDomains: []
}
const result = buildProviderBuiltinWebSearchConfig('openai', config)
expect(result).toEqual({
openai: {
searchContextSize: 'low'
}
})
})
it('should return medium search context size for medium maxResults', () => {
const config: CherryWebSearchConfig = {
searchWithTime: true,
maxResults: 50,
excludeDomains: []
}
const result = buildProviderBuiltinWebSearchConfig('openai', config)
expect(result).toEqual({
openai: {
searchContextSize: 'medium'
}
})
})
it('should return high search context size for high maxResults', () => {
const config: CherryWebSearchConfig = {
searchWithTime: true,
maxResults: 80,
excludeDomains: []
}
const result = buildProviderBuiltinWebSearchConfig('openai', config)
expect(result).toEqual({
openai: {
searchContextSize: 'high'
}
})
})
it('should use medium for deep research models regardless of maxResults', () => {
const config: CherryWebSearchConfig = {
searchWithTime: true,
maxResults: 100,
excludeDomains: []
}
const model: Model = {
id: 'o3-mini',
name: 'O3 Mini',
provider: 'openai'
} as Model
const result = buildProviderBuiltinWebSearchConfig('openai', config, model)
expect(result).toEqual({
openai: {
searchContextSize: 'medium'
}
})
})
})
describe('openai-chat provider', () => {
it('should return correct search context size', () => {
const config: CherryWebSearchConfig = {
searchWithTime: true,
maxResults: 50,
excludeDomains: []
}
const result = buildProviderBuiltinWebSearchConfig('openai-chat', config)
expect(result).toEqual({
'openai-chat': {
searchContextSize: 'medium'
}
})
})
it('should handle deep research models', () => {
const config: CherryWebSearchConfig = {
searchWithTime: true,
maxResults: 100,
excludeDomains: []
}
const model: Model = {
id: 'o3-mini',
name: 'O3 Mini',
provider: 'openai'
} as Model
const result = buildProviderBuiltinWebSearchConfig('openai-chat', config, model)
expect(result).toEqual({
'openai-chat': {
searchContextSize: 'medium'
}
})
})
})
describe('anthropic provider', () => {
it('should return anthropic search options with maxUses', () => {
const result = buildProviderBuiltinWebSearchConfig('anthropic', defaultWebSearchConfig)
expect(result).toEqual({
anthropic: {
maxUses: 50,
blockedDomains: undefined
}
})
})
it('should include blockedDomains when excludeDomains provided', () => {
const config: CherryWebSearchConfig = {
searchWithTime: true,
maxResults: 30,
excludeDomains: ['example.com', 'test.com']
}
const result = buildProviderBuiltinWebSearchConfig('anthropic', config)
expect(result).toEqual({
anthropic: {
maxUses: 30,
blockedDomains: ['example.com', 'test.com']
}
})
})
it('should not include blockedDomains when empty', () => {
const result = buildProviderBuiltinWebSearchConfig('anthropic', defaultWebSearchConfig)
expect(result).toEqual({
anthropic: {
maxUses: 50,
blockedDomains: undefined
}
})
})
})
describe('xai provider', () => {
it('should return xai search options', () => {
const result = buildProviderBuiltinWebSearchConfig('xai', defaultWebSearchConfig)
expect(result).toEqual({
xai: {
maxSearchResults: 50,
returnCitations: true,
sources: [{ type: 'web', excludedWebsites: [] }, { type: 'news' }, { type: 'x' }],
mode: 'on'
}
})
})
it('should limit excluded websites to 5', () => {
const config: CherryWebSearchConfig = {
searchWithTime: true,
maxResults: 40,
excludeDomains: ['site1.com', 'site2.com', 'site3.com', 'site4.com', 'site5.com', 'site6.com', 'site7.com']
}
const result = buildProviderBuiltinWebSearchConfig('xai', config)
expect(result?.xai?.sources).toBeDefined()
const webSource = result?.xai?.sources?.[0]
if (webSource && webSource.type === 'web') {
expect(webSource.excludedWebsites).toHaveLength(5)
}
})
it('should include all sources types', () => {
const result = buildProviderBuiltinWebSearchConfig('xai', defaultWebSearchConfig)
expect(result?.xai?.sources).toHaveLength(3)
expect(result?.xai?.sources?.[0].type).toBe('web')
expect(result?.xai?.sources?.[1].type).toBe('news')
expect(result?.xai?.sources?.[2].type).toBe('x')
})
})
describe('openrouter provider', () => {
it('should return openrouter plugins config', () => {
const result = buildProviderBuiltinWebSearchConfig('openrouter', defaultWebSearchConfig)
expect(result).toEqual({
openrouter: {
plugins: [
{
id: 'web',
max_results: 50
}
]
}
})
})
it('should respect custom maxResults', () => {
const config: CherryWebSearchConfig = {
searchWithTime: true,
maxResults: 75,
excludeDomains: []
}
const result = buildProviderBuiltinWebSearchConfig('openrouter', config)
expect(result).toEqual({
openrouter: {
plugins: [
{
id: 'web',
max_results: 75
}
]
}
})
})
})
describe('unsupported provider', () => {
it('should return empty object for unsupported provider', () => {
const result = buildProviderBuiltinWebSearchConfig('unsupported' as any, defaultWebSearchConfig)
expect(result).toEqual({})
})
it('should return empty object for google provider', () => {
const result = buildProviderBuiltinWebSearchConfig('google', defaultWebSearchConfig)
expect(result).toEqual({})
})
})
describe('edge cases', () => {
it('should handle maxResults at boundary values', () => {
// Test boundary at 33 (low/medium)
const config33: CherryWebSearchConfig = { searchWithTime: true, maxResults: 33, excludeDomains: [] }
const result33 = buildProviderBuiltinWebSearchConfig('openai', config33)
expect(result33?.openai?.searchContextSize).toBe('low')
// Test boundary at 34 (medium)
const config34: CherryWebSearchConfig = { searchWithTime: true, maxResults: 34, excludeDomains: [] }
const result34 = buildProviderBuiltinWebSearchConfig('openai', config34)
expect(result34?.openai?.searchContextSize).toBe('medium')
// Test boundary at 66 (medium)
const config66: CherryWebSearchConfig = { searchWithTime: true, maxResults: 66, excludeDomains: [] }
const result66 = buildProviderBuiltinWebSearchConfig('openai', config66)
expect(result66?.openai?.searchContextSize).toBe('medium')
// Test boundary at 67 (high)
const config67: CherryWebSearchConfig = { searchWithTime: true, maxResults: 67, excludeDomains: [] }
const result67 = buildProviderBuiltinWebSearchConfig('openai', config67)
expect(result67?.openai?.searchContextSize).toBe('high')
})
it('should handle zero maxResults', () => {
const config: CherryWebSearchConfig = { searchWithTime: true, maxResults: 0, excludeDomains: [] }
const result = buildProviderBuiltinWebSearchConfig('openai', config)
expect(result?.openai?.searchContextSize).toBe('low')
})
it('should handle very large maxResults', () => {
const config: CherryWebSearchConfig = { searchWithTime: true, maxResults: 1000, excludeDomains: [] }
const result = buildProviderBuiltinWebSearchConfig('openai', config)
expect(result?.openai?.searchContextSize).toBe('high')
})
})
})
})

View File

@ -1,3 +1,8 @@
import type { BedrockProviderOptions } from '@ai-sdk/amazon-bedrock'
import type { AnthropicProviderOptions } from '@ai-sdk/anthropic'
import type { GoogleGenerativeAIProviderOptions } from '@ai-sdk/google'
import type { OpenAIResponsesProviderOptions } from '@ai-sdk/openai'
import type { XaiProviderOptions } from '@ai-sdk/xai'
import { baseProviderIdSchema, customProviderIdSchema } from '@cherrystudio/ai-core/provider' import { baseProviderIdSchema, customProviderIdSchema } from '@cherrystudio/ai-core/provider'
import { loggerService } from '@logger' import { loggerService } from '@logger'
import { import {
@ -7,17 +12,27 @@ import {
isSupportFlexServiceTierModel, isSupportFlexServiceTierModel,
isSupportVerbosityModel isSupportVerbosityModel
} from '@renderer/config/models' } from '@renderer/config/models'
import { isSupportServiceTierProvider } from '@renderer/config/providers'
import { mapLanguageToQwenMTModel } from '@renderer/config/translate' import { mapLanguageToQwenMTModel } from '@renderer/config/translate'
import type { Assistant, Model, Provider } from '@renderer/types' import { getStoreSetting } from '@renderer/hooks/useSettings'
import { import {
type Assistant,
type GroqServiceTier,
GroqServiceTiers, GroqServiceTiers,
type GroqSystemProvider,
isGroqServiceTier, isGroqServiceTier,
isGroqSystemProvider,
isOpenAIServiceTier, isOpenAIServiceTier,
isTranslateAssistant, isTranslateAssistant,
type Model,
type NotGroqProvider,
type OpenAIServiceTier,
OpenAIServiceTiers, OpenAIServiceTiers,
SystemProviderIds type Provider,
type ServiceTier
} from '@renderer/types' } from '@renderer/types'
import type { OpenAIVerbosity } from '@renderer/types/aiCoreTypes'
import { isSupportServiceTierProvider } from '@renderer/utils/provider'
import type { JSONValue } from 'ai'
import { t } from 'i18next' import { t } from 'i18next'
import { getAiSdkProviderId } from '../provider/factory' import { getAiSdkProviderId } from '../provider/factory'
@ -35,8 +50,31 @@ import { getWebSearchParams } from './websearch'
const logger = loggerService.withContext('aiCore.utils.options') const logger = loggerService.withContext('aiCore.utils.options')
// copy from BaseApiClient.ts function toOpenAIServiceTier(model: Model, serviceTier: ServiceTier): OpenAIServiceTier {
const getServiceTier = (model: Model, provider: Provider) => { if (
!isOpenAIServiceTier(serviceTier) ||
(serviceTier === OpenAIServiceTiers.flex && !isSupportFlexServiceTierModel(model))
) {
return undefined
} else {
return serviceTier
}
}
function toGroqServiceTier(model: Model, serviceTier: ServiceTier): GroqServiceTier {
if (
!isGroqServiceTier(serviceTier) ||
(serviceTier === GroqServiceTiers.flex && !isSupportFlexServiceTierModel(model))
) {
return undefined
} else {
return serviceTier
}
}
function getServiceTier<T extends GroqSystemProvider>(model: Model, provider: T): GroqServiceTier
function getServiceTier<T extends NotGroqProvider>(model: Model, provider: T): OpenAIServiceTier
function getServiceTier<T extends Provider>(model: Model, provider: T): OpenAIServiceTier | GroqServiceTier {
const serviceTierSetting = provider.serviceTier const serviceTierSetting = provider.serviceTier
if (!isSupportServiceTierProvider(provider) || !isOpenAIModel(model) || !serviceTierSetting) { if (!isSupportServiceTierProvider(provider) || !isOpenAIModel(model) || !serviceTierSetting) {
@ -44,24 +82,17 @@ const getServiceTier = (model: Model, provider: Provider) => {
} }
// 处理不同供应商需要 fallback 到默认值的情况 // 处理不同供应商需要 fallback 到默认值的情况
if (provider.id === SystemProviderIds.groq) { if (isGroqSystemProvider(provider)) {
if ( return toGroqServiceTier(model, serviceTierSetting)
!isGroqServiceTier(serviceTierSetting) ||
(serviceTierSetting === GroqServiceTiers.flex && !isSupportFlexServiceTierModel(model))
) {
return undefined
}
} else { } else {
// 其他 OpenAI 供应商,假设他们的服务层级设置和 OpenAI 完全相同 // 其他 OpenAI 供应商,假设他们的服务层级设置和 OpenAI 完全相同
if ( return toOpenAIServiceTier(model, serviceTierSetting)
!isOpenAIServiceTier(serviceTierSetting) ||
(serviceTierSetting === OpenAIServiceTiers.flex && !isSupportFlexServiceTierModel(model))
) {
return undefined
}
} }
}
return serviceTierSetting function getVerbosity(): OpenAIVerbosity {
const openAI = getStoreSetting('openAI')
return openAI.verbosity
} }
/** /**
@ -78,13 +109,13 @@ export function buildProviderOptions(
enableWebSearch: boolean enableWebSearch: boolean
enableGenerateImage: boolean enableGenerateImage: boolean
} }
): Record<string, any> { ): Record<string, Record<string, JSONValue>> {
logger.debug('buildProviderOptions', { assistant, model, actualProvider, capabilities }) logger.debug('buildProviderOptions', { assistant, model, actualProvider, capabilities })
const rawProviderId = getAiSdkProviderId(actualProvider) const rawProviderId = getAiSdkProviderId(actualProvider)
// 构建 provider 特定的选项 // 构建 provider 特定的选项
let providerSpecificOptions: Record<string, any> = {} let providerSpecificOptions: Record<string, any> = {}
const serviceTierSetting = getServiceTier(model, actualProvider) const serviceTier = getServiceTier(model, actualProvider)
providerSpecificOptions.serviceTier = serviceTierSetting const textVerbosity = getVerbosity()
// 根据 provider 类型分离构建逻辑 // 根据 provider 类型分离构建逻辑
const { data: baseProviderId, success } = baseProviderIdSchema.safeParse(rawProviderId) const { data: baseProviderId, success } = baseProviderIdSchema.safeParse(rawProviderId)
if (success) { if (success) {
@ -94,9 +125,14 @@ export function buildProviderOptions(
case 'openai-chat': case 'openai-chat':
case 'azure': case 'azure':
case 'azure-responses': case 'azure-responses':
providerSpecificOptions = { {
...buildOpenAIProviderOptions(assistant, model, capabilities), const options: OpenAIResponsesProviderOptions = buildOpenAIProviderOptions(
serviceTier: serviceTierSetting assistant,
model,
capabilities,
serviceTier
)
providerSpecificOptions = options
} }
break break
case 'anthropic': case 'anthropic':
@ -116,12 +152,19 @@ export function buildProviderOptions(
// 对于其他 provider使用通用的构建逻辑 // 对于其他 provider使用通用的构建逻辑
providerSpecificOptions = { providerSpecificOptions = {
...buildGenericProviderOptions(assistant, model, capabilities), ...buildGenericProviderOptions(assistant, model, capabilities),
serviceTier: serviceTierSetting serviceTier,
textVerbosity
} }
break break
} }
case 'cherryin': case 'cherryin':
providerSpecificOptions = buildCherryInProviderOptions(assistant, model, capabilities, actualProvider) providerSpecificOptions = buildCherryInProviderOptions(
assistant,
model,
capabilities,
actualProvider,
serviceTier
)
break break
default: default:
throw new Error(`Unsupported base provider ${baseProviderId}`) throw new Error(`Unsupported base provider ${baseProviderId}`)
@ -135,6 +178,7 @@ export function buildProviderOptions(
case 'google-vertex': case 'google-vertex':
providerSpecificOptions = buildGeminiProviderOptions(assistant, model, capabilities) providerSpecificOptions = buildGeminiProviderOptions(assistant, model, capabilities)
break break
case 'azure-anthropic':
case 'google-vertex-anthropic': case 'google-vertex-anthropic':
providerSpecificOptions = buildAnthropicProviderOptions(assistant, model, capabilities) providerSpecificOptions = buildAnthropicProviderOptions(assistant, model, capabilities)
break break
@ -142,13 +186,14 @@ export function buildProviderOptions(
providerSpecificOptions = buildBedrockProviderOptions(assistant, model, capabilities) providerSpecificOptions = buildBedrockProviderOptions(assistant, model, capabilities)
break break
case 'huggingface': case 'huggingface':
providerSpecificOptions = buildOpenAIProviderOptions(assistant, model, capabilities) providerSpecificOptions = buildOpenAIProviderOptions(assistant, model, capabilities, serviceTier)
break break
default: default:
// 对于其他 provider使用通用的构建逻辑 // 对于其他 provider使用通用的构建逻辑
providerSpecificOptions = { providerSpecificOptions = {
...buildGenericProviderOptions(assistant, model, capabilities), ...buildGenericProviderOptions(assistant, model, capabilities),
serviceTier: serviceTierSetting serviceTier,
textVerbosity
} }
} }
} else { } else {
@ -166,6 +211,7 @@ export function buildProviderOptions(
{ {
'google-vertex': 'google', 'google-vertex': 'google',
'google-vertex-anthropic': 'anthropic', 'google-vertex-anthropic': 'anthropic',
'azure-anthropic': 'anthropic',
'ai-gateway': 'gateway' 'ai-gateway': 'gateway'
}[rawProviderId] || rawProviderId }[rawProviderId] || rawProviderId
@ -189,10 +235,11 @@ function buildOpenAIProviderOptions(
enableReasoning: boolean enableReasoning: boolean
enableWebSearch: boolean enableWebSearch: boolean
enableGenerateImage: boolean enableGenerateImage: boolean
} },
): Record<string, any> { serviceTier: OpenAIServiceTier
): OpenAIResponsesProviderOptions {
const { enableReasoning } = capabilities const { enableReasoning } = capabilities
let providerOptions: Record<string, any> = {} let providerOptions: OpenAIResponsesProviderOptions = {}
// OpenAI 推理参数 // OpenAI 推理参数
if (enableReasoning) { if (enableReasoning) {
const reasoningParams = getOpenAIReasoningParams(assistant, model) const reasoningParams = getOpenAIReasoningParams(assistant, model)
@ -203,8 +250,8 @@ function buildOpenAIProviderOptions(
} }
if (isSupportVerbosityModel(model)) { if (isSupportVerbosityModel(model)) {
const state = window.store?.getState() const openAI = getStoreSetting<'openAI'>('openAI')
const userVerbosity = state?.settings?.openAI?.verbosity const userVerbosity = openAI?.verbosity
if (userVerbosity && ['low', 'medium', 'high'].includes(userVerbosity)) { if (userVerbosity && ['low', 'medium', 'high'].includes(userVerbosity)) {
const supportedVerbosity = getModelSupportedVerbosity(model) const supportedVerbosity = getModelSupportedVerbosity(model)
@ -218,6 +265,11 @@ function buildOpenAIProviderOptions(
} }
} }
providerOptions = {
...providerOptions,
serviceTier
}
return providerOptions return providerOptions
} }
@ -232,9 +284,9 @@ function buildAnthropicProviderOptions(
enableWebSearch: boolean enableWebSearch: boolean
enableGenerateImage: boolean enableGenerateImage: boolean
} }
): Record<string, any> { ): AnthropicProviderOptions {
const { enableReasoning } = capabilities const { enableReasoning } = capabilities
let providerOptions: Record<string, any> = {} let providerOptions: AnthropicProviderOptions = {}
// Anthropic 推理参数 // Anthropic 推理参数
if (enableReasoning) { if (enableReasoning) {
@ -259,9 +311,9 @@ function buildGeminiProviderOptions(
enableWebSearch: boolean enableWebSearch: boolean
enableGenerateImage: boolean enableGenerateImage: boolean
} }
): Record<string, any> { ): GoogleGenerativeAIProviderOptions {
const { enableReasoning, enableGenerateImage } = capabilities const { enableReasoning, enableGenerateImage } = capabilities
let providerOptions: Record<string, any> = {} let providerOptions: GoogleGenerativeAIProviderOptions = {}
// Gemini 推理参数 // Gemini 推理参数
if (enableReasoning) { if (enableReasoning) {
@ -290,7 +342,7 @@ function buildXAIProviderOptions(
enableWebSearch: boolean enableWebSearch: boolean
enableGenerateImage: boolean enableGenerateImage: boolean
} }
): Record<string, any> { ): XaiProviderOptions {
const { enableReasoning } = capabilities const { enableReasoning } = capabilities
let providerOptions: Record<string, any> = {} let providerOptions: Record<string, any> = {}
@ -313,16 +365,12 @@ function buildCherryInProviderOptions(
enableWebSearch: boolean enableWebSearch: boolean
enableGenerateImage: boolean enableGenerateImage: boolean
}, },
actualProvider: Provider actualProvider: Provider,
): Record<string, any> { serviceTier: OpenAIServiceTier
const serviceTierSetting = getServiceTier(model, actualProvider) ): OpenAIResponsesProviderOptions | AnthropicProviderOptions | GoogleGenerativeAIProviderOptions {
switch (actualProvider.type) { switch (actualProvider.type) {
case 'openai': case 'openai':
return { return buildOpenAIProviderOptions(assistant, model, capabilities, serviceTier)
...buildOpenAIProviderOptions(assistant, model, capabilities),
serviceTier: serviceTierSetting
}
case 'anthropic': case 'anthropic':
return buildAnthropicProviderOptions(assistant, model, capabilities) return buildAnthropicProviderOptions(assistant, model, capabilities)
@ -344,9 +392,9 @@ function buildBedrockProviderOptions(
enableWebSearch: boolean enableWebSearch: boolean
enableGenerateImage: boolean enableGenerateImage: boolean
} }
): Record<string, any> { ): BedrockProviderOptions {
const { enableReasoning } = capabilities const { enableReasoning } = capabilities
let providerOptions: Record<string, any> = {} let providerOptions: BedrockProviderOptions = {}
if (enableReasoning) { if (enableReasoning) {
const reasoningParams = getBedrockReasoningParams(assistant, model) const reasoningParams = getBedrockReasoningParams(assistant, model)

View File

@ -1,6 +1,7 @@
import type { BedrockProviderOptions } from '@ai-sdk/amazon-bedrock' import type { BedrockProviderOptions } from '@ai-sdk/amazon-bedrock'
import type { AnthropicProviderOptions } from '@ai-sdk/anthropic' import type { AnthropicProviderOptions } from '@ai-sdk/anthropic'
import type { GoogleGenerativeAIProviderOptions } from '@ai-sdk/google' import type { GoogleGenerativeAIProviderOptions } from '@ai-sdk/google'
import type { OpenAIResponsesProviderOptions } from '@ai-sdk/openai'
import type { XaiProviderOptions } from '@ai-sdk/xai' import type { XaiProviderOptions } from '@ai-sdk/xai'
import { loggerService } from '@logger' import { loggerService } from '@logger'
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant' import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
@ -11,6 +12,7 @@ import {
isDeepSeekHybridInferenceModel, isDeepSeekHybridInferenceModel,
isDoubaoSeedAfter251015, isDoubaoSeedAfter251015,
isDoubaoThinkingAutoModel, isDoubaoThinkingAutoModel,
isGemini3Model,
isGPT51SeriesModel, isGPT51SeriesModel,
isGrok4FastReasoningModel, isGrok4FastReasoningModel,
isGrokReasoningModel, isGrokReasoningModel,
@ -32,13 +34,13 @@ import {
isSupportedThinkingTokenZhipuModel, isSupportedThinkingTokenZhipuModel,
MODEL_SUPPORTED_REASONING_EFFORT MODEL_SUPPORTED_REASONING_EFFORT
} from '@renderer/config/models' } from '@renderer/config/models'
import { isSupportEnableThinkingProvider } from '@renderer/config/providers'
import { getStoreSetting } from '@renderer/hooks/useSettings' import { getStoreSetting } from '@renderer/hooks/useSettings'
import { getAssistantSettings, getProviderByModel } from '@renderer/services/AssistantService' import { getAssistantSettings, getProviderByModel } from '@renderer/services/AssistantService'
import type { SettingsState } from '@renderer/store/settings' import type { Assistant, Model, ReasoningEffortOption } from '@renderer/types'
import type { Assistant, Model } from '@renderer/types'
import { EFFORT_RATIO, isSystemProvider, SystemProviderIds } from '@renderer/types' import { EFFORT_RATIO, isSystemProvider, SystemProviderIds } from '@renderer/types'
import type { OpenAISummaryText } from '@renderer/types/aiCoreTypes'
import type { ReasoningEffortOptionalParams } from '@renderer/types/sdk' import type { ReasoningEffortOptionalParams } from '@renderer/types/sdk'
import { isSupportEnableThinkingProvider } from '@renderer/utils/provider'
import { toInteger } from 'lodash' import { toInteger } from 'lodash'
const logger = loggerService.withContext('reasoning') const logger = loggerService.withContext('reasoning')
@ -130,7 +132,7 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin
} }
// Specially for GPT-5.1. Suppose this is a OpenAI Compatible provider // Specially for GPT-5.1. Suppose this is a OpenAI Compatible provider
if (isGPT51SeriesModel(model) && reasoningEffort === 'none') { if (isGPT51SeriesModel(model)) {
return { return {
reasoningEffort: 'none' reasoningEffort: 'none'
} }
@ -278,6 +280,12 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin
// gemini series, openai compatible api // gemini series, openai compatible api
if (isSupportedThinkingTokenGeminiModel(model)) { if (isSupportedThinkingTokenGeminiModel(model)) {
// https://ai.google.dev/gemini-api/docs/gemini-3?thinking=high#openai_compatibility
if (isGemini3Model(model)) {
return {
reasoning_effort: reasoningEffort
}
}
if (reasoningEffort === 'auto') { if (reasoningEffort === 'auto') {
return { return {
extra_body: { extra_body: {
@ -341,10 +349,14 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin
} }
/** /**
* OpenAI * Get OpenAI reasoning parameters
* OpenAIResponseAPIClient OpenAIAPIClient * Extracted from OpenAIResponseAPIClient and OpenAIAPIClient logic
* For official OpenAI provider only
*/ */
export function getOpenAIReasoningParams(assistant: Assistant, model: Model): Record<string, any> { export function getOpenAIReasoningParams(
assistant: Assistant,
model: Model
): Pick<OpenAIResponsesProviderOptions, 'reasoningEffort' | 'reasoningSummary'> {
if (!isReasoningModel(model)) { if (!isReasoningModel(model)) {
return {} return {}
} }
@ -355,6 +367,10 @@ export function getOpenAIReasoningParams(assistant: Assistant, model: Model): Re
return {} return {}
} }
if (isOpenAIDeepResearchModel(model) || reasoningEffort === 'auto') {
reasoningEffort = 'medium'
}
// 非OpenAI模型但是Provider类型是responses/azure openai的情况 // 非OpenAI模型但是Provider类型是responses/azure openai的情况
if (!isOpenAIModel(model)) { if (!isOpenAIModel(model)) {
return { return {
@ -362,21 +378,17 @@ export function getOpenAIReasoningParams(assistant: Assistant, model: Model): Re
} }
} }
const openAI = getStoreSetting('openAI') as SettingsState['openAI'] const openAI = getStoreSetting('openAI')
const summaryText = openAI?.summaryText || 'off' const summaryText = openAI.summaryText
let reasoningSummary: string | undefined = undefined let reasoningSummary: OpenAISummaryText = undefined
if (summaryText === 'off' || model.id.includes('o1-pro')) { if (model.id.includes('o1-pro')) {
reasoningSummary = undefined reasoningSummary = undefined
} else { } else {
reasoningSummary = summaryText reasoningSummary = summaryText
} }
if (isOpenAIDeepResearchModel(model)) {
reasoningEffort = 'medium'
}
// OpenAI 推理参数 // OpenAI 推理参数
if (isSupportedReasoningEffortOpenAIModel(model)) { if (isSupportedReasoningEffortOpenAIModel(model)) {
return { return {
@ -388,19 +400,26 @@ export function getOpenAIReasoningParams(assistant: Assistant, model: Model): Re
return {} return {}
} }
export function getAnthropicThinkingBudget(assistant: Assistant, model: Model): number { export function getAnthropicThinkingBudget(
const { maxTokens, reasoning_effort: reasoningEffort } = getAssistantSettings(assistant) maxTokens: number | undefined,
reasoningEffort: string | undefined,
modelId: string
): number | undefined {
if (reasoningEffort === undefined || reasoningEffort === 'none') { if (reasoningEffort === undefined || reasoningEffort === 'none') {
return 0 return undefined
} }
const effortRatio = EFFORT_RATIO[reasoningEffort] const effortRatio = EFFORT_RATIO[reasoningEffort]
const tokenLimit = findTokenLimit(modelId)
if (!tokenLimit) {
return undefined
}
const budgetTokens = Math.max( const budgetTokens = Math.max(
1024, 1024,
Math.floor( Math.floor(
Math.min( Math.min(
(findTokenLimit(model.id)?.max! - findTokenLimit(model.id)?.min!) * effortRatio + (tokenLimit.max - tokenLimit.min) * effortRatio + tokenLimit.min,
findTokenLimit(model.id)?.min!,
(maxTokens || DEFAULT_MAX_TOKENS) * effortRatio (maxTokens || DEFAULT_MAX_TOKENS) * effortRatio
) )
) )
@ -432,7 +451,8 @@ export function getAnthropicReasoningParams(
// Claude 推理参数 // Claude 推理参数
if (isSupportedThinkingTokenClaudeModel(model)) { if (isSupportedThinkingTokenClaudeModel(model)) {
const budgetTokens = getAnthropicThinkingBudget(assistant, model) const { maxTokens } = getAssistantSettings(assistant)
const budgetTokens = getAnthropicThinkingBudget(maxTokens, reasoningEffort, model.id)
return { return {
thinking: { thinking: {
@ -445,6 +465,21 @@ export function getAnthropicReasoningParams(
return {} return {}
} }
type GoogelThinkingLevel = NonNullable<GoogleGenerativeAIProviderOptions['thinkingConfig']>['thinkingLevel']
function mapToGeminiThinkingLevel(reasoningEffort: ReasoningEffortOption): GoogelThinkingLevel {
switch (reasoningEffort) {
case 'low':
return 'low'
case 'medium':
return 'medium'
case 'high':
return 'high'
default:
return 'medium'
}
}
/** /**
* Gemini * Gemini
* GeminiAPIClient * GeminiAPIClient
@ -472,6 +507,15 @@ export function getGeminiReasoningParams(
} }
} }
// https://ai.google.dev/gemini-api/docs/gemini-3?thinking=high#new_api_features_in_gemini_3
if (isGemini3Model(model)) {
return {
thinkingConfig: {
thinkingLevel: mapToGeminiThinkingLevel(reasoningEffort)
}
}
}
const effortRatio = EFFORT_RATIO[reasoningEffort] const effortRatio = EFFORT_RATIO[reasoningEffort]
if (effortRatio > 1) { if (effortRatio > 1) {
@ -555,7 +599,8 @@ export function getBedrockReasoningParams(
return {} return {}
} }
const budgetTokens = getAnthropicThinkingBudget(assistant, model) const { maxTokens } = getAssistantSettings(assistant)
const budgetTokens = getAnthropicThinkingBudget(maxTokens, reasoningEffort, model.id)
return { return {
reasoningConfig: { reasoningConfig: {
type: 'enabled', type: 'enabled',

View File

@ -47,6 +47,7 @@ export function buildProviderBuiltinWebSearchConfig(
model?: Model model?: Model
): WebSearchPluginConfig | undefined { ): WebSearchPluginConfig | undefined {
switch (providerId) { switch (providerId) {
case 'azure-responses':
case 'openai': { case 'openai': {
const searchContextSize = isOpenAIDeepResearchModel(model) const searchContextSize = isOpenAIDeepResearchModel(model)
? 'medium' ? 'medium'

View File

@ -6,7 +6,7 @@ import { useEffect, useMemo, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import styled, { css } from 'styled-components' import styled, { css } from 'styled-components'
interface SelectorOption<V = string | number> { interface SelectorOption<V = string | number | undefined | null> {
label: string | ReactNode label: string | ReactNode
value: V value: V
type?: 'group' type?: 'group'
@ -14,7 +14,7 @@ interface SelectorOption<V = string | number> {
disabled?: boolean disabled?: boolean
} }
interface BaseSelectorProps<V = string | number> { interface BaseSelectorProps<V = string | number | undefined | null> {
options: SelectorOption<V>[] options: SelectorOption<V>[]
placeholder?: string placeholder?: string
placement?: 'topLeft' | 'topCenter' | 'topRight' | 'bottomLeft' | 'bottomCenter' | 'bottomRight' | 'top' | 'bottom' placement?: 'topLeft' | 'topCenter' | 'topRight' | 'bottomLeft' | 'bottomCenter' | 'bottomRight' | 'top' | 'bottom'
@ -39,7 +39,7 @@ interface MultipleSelectorProps<V> extends BaseSelectorProps<V> {
export type SelectorProps<V> = SingleSelectorProps<V> | MultipleSelectorProps<V> export type SelectorProps<V> = SingleSelectorProps<V> | MultipleSelectorProps<V>
const Selector = <V extends string | number>({ const Selector = <V extends string | number | undefined | null>({
options, options,
value, value,
onChange = () => {}, onChange = () => {},

View File

@ -1,520 +0,0 @@
import { describe, expect, it, vi } from 'vitest'
import {
isDoubaoSeedAfter251015,
isDoubaoThinkingAutoModel,
isGeminiReasoningModel,
isLingReasoningModel,
isSupportedThinkingTokenGeminiModel
} from '../models/reasoning'
vi.mock('@renderer/store', () => ({
default: {
getState: () => ({
llm: {
settings: {}
}
})
}
}))
// FIXME: Idk why it's imported. Maybe circular dependency somewhere
vi.mock('@renderer/services/AssistantService.ts', () => ({
getDefaultAssistant: () => {
return {
id: 'default',
name: 'default',
emoji: '😀',
prompt: '',
topics: [],
messages: [],
type: 'assistant',
regularPhrases: [],
settings: {}
}
}
}))
describe('Doubao Models', () => {
describe('isDoubaoThinkingAutoModel', () => {
it('should return false for invalid models', () => {
expect(
isDoubaoThinkingAutoModel({
id: 'doubao-seed-1-6-251015',
name: 'doubao-seed-1-6-251015',
provider: '',
group: ''
})
).toBe(false)
expect(
isDoubaoThinkingAutoModel({
id: 'doubao-seed-1-6-lite-251015',
name: 'doubao-seed-1-6-lite-251015',
provider: '',
group: ''
})
).toBe(false)
expect(
isDoubaoThinkingAutoModel({
id: 'doubao-seed-1-6-thinking-250715',
name: 'doubao-seed-1-6-thinking-250715',
provider: '',
group: ''
})
).toBe(false)
expect(
isDoubaoThinkingAutoModel({
id: 'doubao-seed-1-6-flash',
name: 'doubao-seed-1-6-flash',
provider: '',
group: ''
})
).toBe(false)
expect(
isDoubaoThinkingAutoModel({
id: 'doubao-seed-1-6-thinking',
name: 'doubao-seed-1-6-thinking',
provider: '',
group: ''
})
).toBe(false)
})
it('should return true for valid models', () => {
expect(
isDoubaoThinkingAutoModel({
id: 'doubao-seed-1-6-250615',
name: 'doubao-seed-1-6-250615',
provider: '',
group: ''
})
).toBe(true)
expect(
isDoubaoThinkingAutoModel({
id: 'Doubao-Seed-1.6',
name: 'Doubao-Seed-1.6',
provider: '',
group: ''
})
).toBe(true)
expect(
isDoubaoThinkingAutoModel({
id: 'doubao-1-5-thinking-pro-m',
name: 'doubao-1-5-thinking-pro-m',
provider: '',
group: ''
})
).toBe(true)
expect(
isDoubaoThinkingAutoModel({
id: 'doubao-seed-1.6-lite',
name: 'doubao-seed-1.6-lite',
provider: '',
group: ''
})
).toBe(true)
expect(
isDoubaoThinkingAutoModel({
id: 'doubao-1-5-thinking-pro-m-12345',
name: 'doubao-1-5-thinking-pro-m-12345',
provider: '',
group: ''
})
).toBe(true)
})
})
describe('isDoubaoSeedAfter251015', () => {
it('should return true for models matching the pattern', () => {
expect(
isDoubaoSeedAfter251015({
id: 'doubao-seed-1-6-251015',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isDoubaoSeedAfter251015({
id: 'doubao-seed-1-6-lite-251015',
name: '',
provider: '',
group: ''
})
).toBe(true)
})
it('should return false for models not matching the pattern', () => {
expect(
isDoubaoSeedAfter251015({
id: 'doubao-seed-1-6-250615',
name: '',
provider: '',
group: ''
})
).toBe(false)
expect(
isDoubaoSeedAfter251015({
id: 'Doubao-Seed-1.6',
name: '',
provider: '',
group: ''
})
).toBe(false)
expect(
isDoubaoSeedAfter251015({
id: 'doubao-1-5-thinking-pro-m',
name: '',
provider: '',
group: ''
})
).toBe(false)
expect(
isDoubaoSeedAfter251015({
id: 'doubao-seed-1-6-lite-251016',
name: '',
provider: '',
group: ''
})
).toBe(false)
})
})
})
describe('Ling Models', () => {
describe('isLingReasoningModel', () => {
it('should return false for ling variants', () => {
expect(
isLingReasoningModel({
id: 'ling-1t',
name: '',
provider: '',
group: ''
})
).toBe(false)
expect(
isLingReasoningModel({
id: 'ling-flash-2.0',
name: '',
provider: '',
group: ''
})
).toBe(false)
expect(
isLingReasoningModel({
id: 'ling-mini-2.0',
name: '',
provider: '',
group: ''
})
).toBe(false)
})
it('should return true for ring variants', () => {
expect(
isLingReasoningModel({
id: 'ring-1t',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isLingReasoningModel({
id: 'ring-flash-2.0',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isLingReasoningModel({
id: 'ring-mini-2.0',
name: '',
provider: '',
group: ''
})
).toBe(true)
})
})
})
describe('Gemini Models', () => {
describe('isSupportedThinkingTokenGeminiModel', () => {
it('should return true for gemini 2.5 models', () => {
expect(
isSupportedThinkingTokenGeminiModel({
id: 'gemini-2.5-flash',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isSupportedThinkingTokenGeminiModel({
id: 'gemini-2.5-pro',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isSupportedThinkingTokenGeminiModel({
id: 'gemini-2.5-flash-latest',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isSupportedThinkingTokenGeminiModel({
id: 'gemini-2.5-pro-latest',
name: '',
provider: '',
group: ''
})
).toBe(true)
})
it('should return true for gemini latest models', () => {
expect(
isSupportedThinkingTokenGeminiModel({
id: 'gemini-flash-latest',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isSupportedThinkingTokenGeminiModel({
id: 'gemini-pro-latest',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isSupportedThinkingTokenGeminiModel({
id: 'gemini-flash-lite-latest',
name: '',
provider: '',
group: ''
})
).toBe(true)
})
it('should return true for gemini 3 models', () => {
// Preview versions
expect(
isSupportedThinkingTokenGeminiModel({
id: 'gemini-3-pro-preview',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isSupportedThinkingTokenGeminiModel({
id: 'google/gemini-3-pro-preview',
name: '',
provider: '',
group: ''
})
).toBe(true)
// Future stable versions
expect(
isSupportedThinkingTokenGeminiModel({
id: 'gemini-3-flash',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isSupportedThinkingTokenGeminiModel({
id: 'gemini-3-pro',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isSupportedThinkingTokenGeminiModel({
id: 'google/gemini-3-flash',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isSupportedThinkingTokenGeminiModel({
id: 'google/gemini-3-pro',
name: '',
provider: '',
group: ''
})
).toBe(true)
})
it('should return false for image and tts models', () => {
expect(
isSupportedThinkingTokenGeminiModel({
id: 'gemini-2.5-flash-image',
name: '',
provider: '',
group: ''
})
).toBe(false)
expect(
isSupportedThinkingTokenGeminiModel({
id: 'gemini-2.5-flash-preview-tts',
name: '',
provider: '',
group: ''
})
).toBe(false)
})
it('should return false for older gemini models', () => {
expect(
isSupportedThinkingTokenGeminiModel({
id: 'gemini-1.5-flash',
name: '',
provider: '',
group: ''
})
).toBe(false)
expect(
isSupportedThinkingTokenGeminiModel({
id: 'gemini-1.5-pro',
name: '',
provider: '',
group: ''
})
).toBe(false)
expect(
isSupportedThinkingTokenGeminiModel({
id: 'gemini-1.0-pro',
name: '',
provider: '',
group: ''
})
).toBe(false)
})
})
describe('isGeminiReasoningModel', () => {
it('should return true for gemini thinking models', () => {
expect(
isGeminiReasoningModel({
id: 'gemini-2.0-flash-thinking',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isGeminiReasoningModel({
id: 'gemini-thinking-exp',
name: '',
provider: '',
group: ''
})
).toBe(true)
})
it('should return true for supported thinking token gemini models', () => {
expect(
isGeminiReasoningModel({
id: 'gemini-2.5-flash',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isGeminiReasoningModel({
id: 'gemini-2.5-pro',
name: '',
provider: '',
group: ''
})
).toBe(true)
})
it('should return true for gemini-3 models', () => {
// Preview versions
expect(
isGeminiReasoningModel({
id: 'gemini-3-pro-preview',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isGeminiReasoningModel({
id: 'google/gemini-3-pro-preview',
name: '',
provider: '',
group: ''
})
).toBe(true)
// Future stable versions
expect(
isGeminiReasoningModel({
id: 'gemini-3-flash',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isGeminiReasoningModel({
id: 'gemini-3-pro',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isGeminiReasoningModel({
id: 'google/gemini-3-flash',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isGeminiReasoningModel({
id: 'google/gemini-3-pro',
name: '',
provider: '',
group: ''
})
).toBe(true)
})
it('should return false for older gemini models without thinking', () => {
expect(
isGeminiReasoningModel({
id: 'gemini-1.5-flash',
name: '',
provider: '',
group: ''
})
).toBe(false)
expect(
isGeminiReasoningModel({
id: 'gemini-1.5-pro',
name: '',
provider: '',
group: ''
})
).toBe(false)
})
it('should return false for undefined model', () => {
expect(isGeminiReasoningModel(undefined)).toBe(false)
})
})
})

View File

@ -1,167 +0,0 @@
import { describe, expect, it, vi } from 'vitest'
import { isVisionModel } from '../models/vision'
vi.mock('@renderer/store', () => ({
default: {
getState: () => ({
llm: {
settings: {}
}
})
}
}))
// FIXME: Idk why it's imported. Maybe circular dependency somewhere
vi.mock('@renderer/services/AssistantService.ts', () => ({
getDefaultAssistant: () => {
return {
id: 'default',
name: 'default',
emoji: '😀',
prompt: '',
topics: [],
messages: [],
type: 'assistant',
regularPhrases: [],
settings: {}
}
},
getProviderByModel: () => null
}))
describe('isVisionModel', () => {
describe('Gemini Models', () => {
it('should return true for gemini 1.5 models', () => {
expect(
isVisionModel({
id: 'gemini-1.5-flash',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isVisionModel({
id: 'gemini-1.5-pro',
name: '',
provider: '',
group: ''
})
).toBe(true)
})
it('should return true for gemini 2.x models', () => {
expect(
isVisionModel({
id: 'gemini-2.0-flash',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isVisionModel({
id: 'gemini-2.0-pro',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isVisionModel({
id: 'gemini-2.5-flash',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isVisionModel({
id: 'gemini-2.5-pro',
name: '',
provider: '',
group: ''
})
).toBe(true)
})
it('should return true for gemini latest models', () => {
expect(
isVisionModel({
id: 'gemini-flash-latest',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isVisionModel({
id: 'gemini-pro-latest',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isVisionModel({
id: 'gemini-flash-lite-latest',
name: '',
provider: '',
group: ''
})
).toBe(true)
})
it('should return true for gemini 3 models', () => {
// Preview versions
expect(
isVisionModel({
id: 'gemini-3-pro-preview',
name: '',
provider: '',
group: ''
})
).toBe(true)
// Future stable versions
expect(
isVisionModel({
id: 'gemini-3-flash',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isVisionModel({
id: 'gemini-3-pro',
name: '',
provider: '',
group: ''
})
).toBe(true)
})
it('should return true for gemini exp models', () => {
expect(
isVisionModel({
id: 'gemini-exp-1206',
name: '',
provider: '',
group: ''
})
).toBe(true)
})
it('should return false for gemini 1.0 models', () => {
expect(
isVisionModel({
id: 'gemini-1.0-pro',
name: '',
provider: '',
group: ''
})
).toBe(false)
})
})
})

View File

@ -1,64 +0,0 @@
import { describe, expect, it, vi } from 'vitest'
import { GEMINI_SEARCH_REGEX } from '../models/websearch'
vi.mock('@renderer/store', () => ({
default: {
getState: () => ({
llm: {
settings: {}
}
})
}
}))
// FIXME: Idk why it's imported. Maybe circular dependency somewhere
vi.mock('@renderer/services/AssistantService.ts', () => ({
getDefaultAssistant: () => {
return {
id: 'default',
name: 'default',
emoji: '😀',
prompt: '',
topics: [],
messages: [],
type: 'assistant',
regularPhrases: [],
settings: {}
}
},
getProviderByModel: () => null
}))
describe('Gemini Search Models', () => {
describe('GEMINI_SEARCH_REGEX', () => {
it('should match gemini 2.x models', () => {
expect(GEMINI_SEARCH_REGEX.test('gemini-2.0-flash')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-2.0-pro')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-2.5-flash')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-2.5-pro')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-2.5-flash-latest')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-2.5-pro-latest')).toBe(true)
})
it('should match gemini latest models', () => {
expect(GEMINI_SEARCH_REGEX.test('gemini-flash-latest')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-pro-latest')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-flash-lite-latest')).toBe(true)
})
it('should match gemini 3 models', () => {
// Preview versions
expect(GEMINI_SEARCH_REGEX.test('gemini-3-pro-preview')).toBe(true)
// Future stable versions
expect(GEMINI_SEARCH_REGEX.test('gemini-3-flash')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-3-pro')).toBe(true)
})
it('should not match older gemini models', () => {
expect(GEMINI_SEARCH_REGEX.test('gemini-1.5-flash')).toBe(false)
expect(GEMINI_SEARCH_REGEX.test('gemini-1.5-pro')).toBe(false)
expect(GEMINI_SEARCH_REGEX.test('gemini-1.0-pro')).toBe(false)
})
})
})

View File

@ -0,0 +1,101 @@
import type { Model } from '@renderer/types'
import { describe, expect, it, vi } from 'vitest'
vi.mock('@renderer/hooks/useStore', () => ({
getStoreProviders: vi.fn(() => [])
}))
vi.mock('@renderer/store', () => ({
__esModule: true,
default: {
getState: () => ({
llm: { providers: [] },
settings: {}
})
},
useAppDispatch: vi.fn(),
useAppSelector: vi.fn()
}))
vi.mock('@renderer/store/settings', () => {
const noop = vi.fn()
return new Proxy(
{},
{
get: (_target, prop) => {
if (prop === 'initialState') {
return {}
}
return noop
}
}
)
})
vi.mock('@renderer/hooks/useSettings', () => ({
useSettings: vi.fn(() => ({})),
useNavbarPosition: vi.fn(() => ({ navbarPosition: 'left' })),
useMessageStyle: vi.fn(() => ({ isBubbleStyle: false })),
getStoreSetting: vi.fn()
}))
import { isEmbeddingModel, isRerankModel } from '../embedding'
const createModel = (overrides: Partial<Model> = {}): Model => ({
id: 'test-model',
name: 'Test Model',
provider: 'openai',
group: 'Test',
...overrides
})
describe('isEmbeddingModel', () => {
it('returns true for ids that match the embedding regex', () => {
expect(isEmbeddingModel(createModel({ id: 'Text-Embedding-3-Small' }))).toBe(true)
})
it('returns false for rerank models even if they match embedding patterns', () => {
const model = createModel({ id: 'rerank-qa', name: 'rerank-qa' })
expect(isRerankModel(model)).toBe(true)
expect(isEmbeddingModel(model)).toBe(false)
})
it('honors user overrides for embedding capability', () => {
const model = createModel({
id: 'text-embedding-3-small',
capabilities: [{ type: 'embedding', isUserSelected: false }]
})
expect(isEmbeddingModel(model)).toBe(false)
})
it('uses the model name when provider is doubao', () => {
const model = createModel({
id: 'custom-id',
name: 'BGE-Large-zh-v1.5',
provider: 'doubao'
})
expect(isEmbeddingModel(model)).toBe(true)
})
it('returns false for anthropic provider models', () => {
const model = createModel({
id: 'text-embedding-ada-002',
provider: 'anthropic'
})
expect(isEmbeddingModel(model)).toBe(false)
})
})
describe('isRerankModel', () => {
it('identifies ids that match rerank regex', () => {
expect(isRerankModel(createModel({ id: 'jina-rerank-v2-base' }))).toBe(true)
})
it('honors user overrides for rerank capability', () => {
const model = createModel({
id: 'jina-rerank-v2-base',
capabilities: [{ type: 'rerank', isUserSelected: false }]
})
expect(isRerankModel(model)).toBe(false)
})
})

View File

@ -1,33 +1,55 @@
import { import {
isImageEnhancementModel, isImageEnhancementModel,
isPureGenerateImageModel,
isQwenReasoningModel, isQwenReasoningModel,
isSupportedThinkingTokenQwenModel, isSupportedThinkingTokenQwenModel,
isVisionModel, isVisionModel
isWebSearchModel
} from '@renderer/config/models' } from '@renderer/config/models'
import type { Model } from '@renderer/types' import type { Model } from '@renderer/types'
import { beforeEach, describe, expect, test, vi } from 'vitest' import { beforeEach, describe, expect, test, vi } from 'vitest'
vi.mock('@renderer/store/llm', () => ({
initialState: {}
}))
vi.mock('@renderer/store', () => ({
default: {
getState: () => ({
llm: {
settings: {}
}
})
}
}))
const getProviderByModelMock = vi.fn()
const isEmbeddingModelMock = vi.fn()
const isRerankModelMock = vi.fn()
vi.mock('@renderer/services/AssistantService', () => ({
getProviderByModel: (...args: any[]) => getProviderByModelMock(...args),
getAssistantSettings: vi.fn(),
getDefaultAssistant: vi.fn().mockReturnValue({
id: 'default',
name: 'Default Assistant',
prompt: '',
settings: {}
})
}))
vi.mock('@renderer/config/models/embedding', () => ({
isEmbeddingModel: (...args: any[]) => isEmbeddingModelMock(...args),
isRerankModel: (...args: any[]) => isRerankModelMock(...args)
}))
beforeEach(() => {
vi.clearAllMocks()
getProviderByModelMock.mockReturnValue({ type: 'openai-response' } as any)
isEmbeddingModelMock.mockReturnValue(false)
isRerankModelMock.mockReturnValue(false)
})
// Suggested test cases // Suggested test cases
describe('Qwen Model Detection', () => { describe('Qwen Model Detection', () => {
beforeEach(() => {
vi.mock('@renderer/store/llm', () => ({
initialState: {}
}))
vi.mock('@renderer/services/AssistantService', () => ({
getProviderByModel: vi.fn().mockReturnValue({ id: 'cherryai' })
}))
vi.mock('@renderer/store', () => ({
default: {
getState: () => ({
llm: {
settings: {}
}
})
}
}))
})
test('isQwenReasoningModel', () => { test('isQwenReasoningModel', () => {
expect(isQwenReasoningModel({ id: 'qwen3-thinking' } as Model)).toBe(true) expect(isQwenReasoningModel({ id: 'qwen3-thinking' } as Model)).toBe(true)
expect(isQwenReasoningModel({ id: 'qwen3-instruct' } as Model)).toBe(false) expect(isQwenReasoningModel({ id: 'qwen3-instruct' } as Model)).toBe(false)
@ -56,14 +78,6 @@ describe('Qwen Model Detection', () => {
}) })
describe('Vision Model Detection', () => { describe('Vision Model Detection', () => {
beforeEach(() => {
vi.mock('@renderer/store/llm', () => ({
initialState: {}
}))
vi.mock('@renderer/services/AssistantService', () => ({
getProviderByModel: vi.fn().mockReturnValue({ id: 'cherryai' })
}))
})
test('isVisionModel', () => { test('isVisionModel', () => {
expect(isVisionModel({ id: 'qwen-vl-max' } as Model)).toBe(true) expect(isVisionModel({ id: 'qwen-vl-max' } as Model)).toBe(true)
expect(isVisionModel({ id: 'qwen-omni-turbo' } as Model)).toBe(true) expect(isVisionModel({ id: 'qwen-omni-turbo' } as Model)).toBe(true)
@ -75,25 +89,4 @@ describe('Vision Model Detection', () => {
expect(isImageEnhancementModel({ id: 'qwen-image-edit' } as Model)).toBe(true) expect(isImageEnhancementModel({ id: 'qwen-image-edit' } as Model)).toBe(true)
expect(isImageEnhancementModel({ id: 'grok-2-image-latest' } as Model)).toBe(true) expect(isImageEnhancementModel({ id: 'grok-2-image-latest' } as Model)).toBe(true)
}) })
test('isPureGenerateImageModel', () => {
expect(isPureGenerateImageModel({ id: 'gpt-image-1' } as Model)).toBe(true)
expect(isPureGenerateImageModel({ id: 'gemini-2.5-flash-image-preview' } as Model)).toBe(true)
expect(isPureGenerateImageModel({ id: 'gemini-2.0-flash-preview-image-generation' } as Model)).toBe(true)
expect(isPureGenerateImageModel({ id: 'grok-2-image-latest' } as Model)).toBe(true)
expect(isPureGenerateImageModel({ id: 'gpt-4o' } as Model)).toBe(false)
})
})
describe('Web Search Model Detection', () => {
beforeEach(() => {
vi.mock('@renderer/store/llm', () => ({
initialState: {}
}))
vi.mock('@renderer/services/AssistantService', () => ({
getProviderByModel: vi.fn().mockReturnValue({ id: 'cherryai' })
}))
})
test('isWebSearchModel', () => {
expect(isWebSearchModel({ id: 'grok-2-image-latest' } as Model)).toBe(false)
})
}) })

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,137 @@
import type { Model } from '@renderer/types'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { isEmbeddingModel, isRerankModel } from '../embedding'
import { isDeepSeekHybridInferenceModel } from '../reasoning'
import { isFunctionCallingModel } from '../tooluse'
import { isPureGenerateImageModel, isTextToImageModel } from '../vision'
vi.mock('@renderer/hooks/useStore', () => ({
getStoreProviders: vi.fn(() => [])
}))
vi.mock('@renderer/store', () => ({
__esModule: true,
default: {
getState: () => ({
llm: { providers: [] },
settings: {}
})
},
useAppDispatch: vi.fn(),
useAppSelector: vi.fn()
}))
vi.mock('@renderer/store/settings', () => {
const noop = vi.fn()
return new Proxy(
{},
{
get: (_target, prop) => {
if (prop === 'initialState') {
return {}
}
return noop
}
}
)
})
vi.mock('@renderer/hooks/useSettings', () => ({
useSettings: vi.fn(() => ({})),
useNavbarPosition: vi.fn(() => ({ navbarPosition: 'left' })),
useMessageStyle: vi.fn(() => ({ isBubbleStyle: false })),
getStoreSetting: vi.fn()
}))
vi.mock('../embedding', () => ({
isEmbeddingModel: vi.fn(),
isRerankModel: vi.fn()
}))
vi.mock('../vision', () => ({
isPureGenerateImageModel: vi.fn(),
isTextToImageModel: vi.fn()
}))
vi.mock('../reasoning', () => ({
isDeepSeekHybridInferenceModel: vi.fn()
}))
const createModel = (overrides: Partial<Model> = {}): Model => ({
id: 'gpt-4o',
name: 'gpt-4o',
provider: 'openai',
group: 'OpenAI',
...overrides
})
const embeddingMock = vi.mocked(isEmbeddingModel)
const rerankMock = vi.mocked(isRerankModel)
const pureImageMock = vi.mocked(isPureGenerateImageModel)
const textToImageMock = vi.mocked(isTextToImageModel)
const deepSeekHybridMock = vi.mocked(isDeepSeekHybridInferenceModel)
describe('isFunctionCallingModel', () => {
beforeEach(() => {
vi.clearAllMocks()
embeddingMock.mockReturnValue(false)
rerankMock.mockReturnValue(false)
pureImageMock.mockReturnValue(false)
textToImageMock.mockReturnValue(false)
deepSeekHybridMock.mockReturnValue(false)
})
it('returns false when the model is undefined', () => {
expect(isFunctionCallingModel(undefined as unknown as Model)).toBe(false)
})
it('returns false when model is classified as embedding/rerank/image', () => {
embeddingMock.mockReturnValueOnce(true)
expect(isFunctionCallingModel(createModel())).toBe(false)
})
it('respect manual user overrides', () => {
const model = createModel({
capabilities: [{ type: 'function_calling', isUserSelected: false }]
})
expect(isFunctionCallingModel(model)).toBe(false)
const enabled = createModel({
capabilities: [{ type: 'function_calling', isUserSelected: true }]
})
expect(isFunctionCallingModel(enabled)).toBe(true)
})
it('matches doubao models by name when regex applies', () => {
const doubao = createModel({
id: 'custom-model',
name: 'Doubao-Seed-1.6-251015',
provider: 'doubao'
})
expect(isFunctionCallingModel(doubao)).toBe(true)
})
it('returns true for regex matches on standard providers', () => {
expect(isFunctionCallingModel(createModel({ id: 'gpt-5' }))).toBe(true)
})
it('excludes explicitly blocked ids', () => {
expect(isFunctionCallingModel(createModel({ id: 'gemini-1.5-flash' }))).toBe(false)
})
it('forces support for trusted providers', () => {
for (const provider of ['deepseek', 'anthropic', 'kimi', 'moonshot']) {
expect(isFunctionCallingModel(createModel({ provider }))).toBe(true)
}
})
it('returns true when identified as deepseek hybrid inference model', () => {
deepSeekHybridMock.mockReturnValueOnce(true)
expect(isFunctionCallingModel(createModel({ id: 'deepseek-v3-1', provider: 'custom' }))).toBe(true)
})
it('returns false for deepseek hybrid models behind restricted system providers', () => {
deepSeekHybridMock.mockReturnValueOnce(true)
expect(isFunctionCallingModel(createModel({ id: 'deepseek-v3-1', provider: 'dashscope' }))).toBe(false)
})
})

View File

@ -0,0 +1,280 @@
import { isEmbeddingModel, isRerankModel } from '@renderer/config/models/embedding'
import type { Model } from '@renderer/types'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import {
isGPT5ProModel,
isGPT5SeriesModel,
isGPT5SeriesReasoningModel,
isGPT51SeriesModel,
isOpenAIChatCompletionOnlyModel,
isOpenAILLMModel,
isOpenAIModel,
isOpenAIOpenWeightModel,
isOpenAIReasoningModel,
isSupportVerbosityModel
} from '../openai'
import { isQwenMTModel } from '../qwen'
import {
agentModelFilter,
getModelSupportedVerbosity,
groupQwenModels,
isAnthropicModel,
isGeminiModel,
isGemmaModel,
isGenerateImageModels,
isMaxTemperatureOneModel,
isNotSupportedTextDelta,
isNotSupportSystemMessageModel,
isNotSupportTemperatureAndTopP,
isSupportedFlexServiceTier,
isSupportedModel,
isSupportFlexServiceTierModel,
isVisionModels,
isZhipuModel
} from '../utils'
import { isGenerateImageModel, isTextToImageModel, isVisionModel } from '../vision'
import { isOpenAIWebSearchChatCompletionOnlyModel } from '../websearch'
vi.mock('@renderer/hooks/useStore', () => ({
getStoreProviders: vi.fn(() => [])
}))
vi.mock('@renderer/store', () => ({
__esModule: true,
default: {
getState: () => ({
llm: { providers: [] },
settings: {}
})
},
useAppDispatch: vi.fn(),
useAppSelector: vi.fn()
}))
vi.mock('@renderer/store/settings', () => {
const noop = vi.fn()
return new Proxy(
{},
{
get: (_target, prop) => {
if (prop === 'initialState') {
return {}
}
return noop
}
}
)
})
vi.mock('@renderer/hooks/useSettings', () => ({
useSettings: vi.fn(() => ({})),
useNavbarPosition: vi.fn(() => ({ navbarPosition: 'left' })),
useMessageStyle: vi.fn(() => ({ isBubbleStyle: false })),
getStoreSetting: vi.fn()
}))
vi.mock('@renderer/config/models/embedding', () => ({
isEmbeddingModel: vi.fn(),
isRerankModel: vi.fn()
}))
vi.mock('../vision', () => ({
isGenerateImageModel: vi.fn(),
isTextToImageModel: vi.fn(),
isVisionModel: vi.fn()
}))
vi.mock(import('../openai'), async (importOriginal) => {
const actual = await importOriginal()
return {
...actual,
isOpenAIReasoningModel: vi.fn()
}
})
vi.mock('../websearch', () => ({
isOpenAIWebSearchChatCompletionOnlyModel: vi.fn()
}))
const createModel = (overrides: Partial<Model> = {}): Model => ({
id: 'gpt-4o',
name: 'gpt-4o',
provider: 'openai',
group: 'OpenAI',
...overrides
})
const embeddingMock = vi.mocked(isEmbeddingModel)
const rerankMock = vi.mocked(isRerankModel)
const visionMock = vi.mocked(isVisionModel)
const textToImageMock = vi.mocked(isTextToImageModel)
const generateImageMock = vi.mocked(isGenerateImageModel)
const reasoningMock = vi.mocked(isOpenAIReasoningModel)
const openAIWebSearchOnlyMock = vi.mocked(isOpenAIWebSearchChatCompletionOnlyModel)
describe('model utils', () => {
beforeEach(() => {
vi.clearAllMocks()
embeddingMock.mockReturnValue(false)
rerankMock.mockReturnValue(false)
visionMock.mockReturnValue(true)
textToImageMock.mockReturnValue(false)
generateImageMock.mockReturnValue(true)
reasoningMock.mockReturnValue(false)
openAIWebSearchOnlyMock.mockReturnValue(false)
})
it('detects OpenAI LLM models through reasoning and GPT prefix', () => {
expect(isOpenAILLMModel(undefined as unknown as Model)).toBe(false)
expect(isOpenAILLMModel(createModel({ id: 'gpt-4o-image' }))).toBe(false)
reasoningMock.mockReturnValueOnce(true)
expect(isOpenAILLMModel(createModel({ id: 'o1-preview' }))).toBe(true)
expect(isOpenAILLMModel(createModel({ id: 'GPT-5-turbo' }))).toBe(true)
})
it('detects OpenAI models via GPT prefix or reasoning support', () => {
expect(isOpenAIModel(createModel({ id: 'gpt-4.1' }))).toBe(true)
reasoningMock.mockReturnValueOnce(true)
expect(isOpenAIModel(createModel({ id: 'o3' }))).toBe(true)
})
it('evaluates support for flex service tier and alias helper', () => {
expect(isSupportFlexServiceTierModel(createModel({ id: 'o3' }))).toBe(true)
expect(isSupportFlexServiceTierModel(createModel({ id: 'o3-mini' }))).toBe(false)
expect(isSupportFlexServiceTierModel(createModel({ id: 'o4-mini' }))).toBe(true)
expect(isSupportFlexServiceTierModel(createModel({ id: 'gpt-5-preview' }))).toBe(true)
expect(isSupportedFlexServiceTier(createModel({ id: 'gpt-4o' }))).toBe(false)
})
it('detects verbosity support for GPT-5+ families', () => {
expect(isSupportVerbosityModel(createModel({ id: 'gpt-5' }))).toBe(true)
expect(isSupportVerbosityModel(createModel({ id: 'gpt-5-chat' }))).toBe(false)
expect(isSupportVerbosityModel(createModel({ id: 'gpt-5.1-preview' }))).toBe(true)
})
it('limits verbosity controls for GPT-5 Pro models', () => {
const proModel = createModel({ id: 'gpt-5-pro' })
const previewModel = createModel({ id: 'gpt-5-preview' })
expect(getModelSupportedVerbosity(proModel)).toEqual([undefined, 'high'])
expect(getModelSupportedVerbosity(previewModel)).toEqual([undefined, 'low', 'medium', 'high'])
expect(isGPT5ProModel(proModel)).toBe(true)
expect(isGPT5ProModel(previewModel)).toBe(false)
})
it('identifies OpenAI chat-completion-only models', () => {
expect(isOpenAIChatCompletionOnlyModel(createModel({ id: 'gpt-4o-search-preview' }))).toBe(true)
expect(isOpenAIChatCompletionOnlyModel(createModel({ id: 'o1-mini' }))).toBe(true)
expect(isOpenAIChatCompletionOnlyModel(createModel({ id: 'gpt-4o' }))).toBe(false)
})
it('filters unsupported OpenAI catalog entries', () => {
expect(isSupportedModel({ id: 'gpt-4', object: 'model' } as any)).toBe(true)
expect(isSupportedModel({ id: 'tts-1', object: 'model' } as any)).toBe(false)
})
it('calculates temperature/top-p support correctly', () => {
const model = createModel({ id: 'o1' })
reasoningMock.mockReturnValue(true)
expect(isNotSupportTemperatureAndTopP(model)).toBe(true)
const openWeight = createModel({ id: 'gpt-oss-debug' })
expect(isNotSupportTemperatureAndTopP(openWeight)).toBe(false)
const chatOnly = createModel({ id: 'o1-preview' })
reasoningMock.mockReturnValue(false)
expect(isNotSupportTemperatureAndTopP(chatOnly)).toBe(true)
const qwenMt = createModel({ id: 'qwen-mt-large', provider: 'aliyun' })
expect(isNotSupportTemperatureAndTopP(qwenMt)).toBe(true)
})
it('handles gemma and gemini detections plus zhipu tagging', () => {
expect(isGemmaModel(createModel({ id: 'Gemma-3-27B' }))).toBe(true)
expect(isGemmaModel(createModel({ group: 'Gemma' }))).toBe(true)
expect(isGemmaModel(createModel({ id: 'gpt-4o' }))).toBe(false)
expect(isGeminiModel(createModel({ id: 'Gemini-2.0' }))).toBe(true)
expect(isZhipuModel(createModel({ provider: 'zhipu' }))).toBe(true)
expect(isZhipuModel(createModel({ provider: 'openai' }))).toBe(false)
})
it('groups qwen models by prefix', () => {
const qwen = createModel({ id: 'Qwen-7B', provider: 'qwen', name: 'Qwen-7B' })
const qwenOmni = createModel({ id: 'qwen2.5-omni', name: 'qwen2.5-omni' })
const other = createModel({ id: 'deepseek-v3', group: 'DeepSeek' })
const grouped = groupQwenModels([qwen, qwenOmni, other])
expect(Object.keys(grouped)).toContain('qwen-7b')
expect(Object.keys(grouped)).toContain('qwen2.5')
expect(grouped.DeepSeek).toContain(other)
})
it('aggregates boolean helpers based on regex rules', () => {
expect(isAnthropicModel(createModel({ id: 'claude-3.5' }))).toBe(true)
expect(isQwenMTModel(createModel({ id: 'qwen-mt-large' }))).toBe(true)
expect(isNotSupportedTextDelta(createModel({ id: 'qwen-mt-large' }))).toBe(true)
expect(isNotSupportSystemMessageModel(createModel({ id: 'gemma-moe' }))).toBe(true)
expect(isOpenAIOpenWeightModel(createModel({ id: 'gpt-oss-free' }))).toBe(true)
})
it('evaluates GPT-5 family helpers', () => {
expect(isGPT5SeriesModel(createModel({ id: 'gpt-5-preview' }))).toBe(true)
expect(isGPT5SeriesModel(createModel({ id: 'gpt-5.1-preview' }))).toBe(false)
expect(isGPT51SeriesModel(createModel({ id: 'gpt-5.1-mini' }))).toBe(true)
expect(isGPT5SeriesReasoningModel(createModel({ id: 'gpt-5-prompt' }))).toBe(true)
expect(isSupportVerbosityModel(createModel({ id: 'gpt-5-chat' }))).toBe(false)
})
it('wraps generate/vision helpers that operate on arrays', () => {
const models = [createModel({ id: 'gpt-4o' }), createModel({ id: 'gpt-4o-mini' })]
expect(isVisionModels(models)).toBe(true)
visionMock.mockReturnValueOnce(true).mockReturnValueOnce(false)
expect(isVisionModels(models)).toBe(false)
expect(isGenerateImageModels(models)).toBe(true)
generateImageMock.mockReturnValueOnce(true).mockReturnValueOnce(false)
expect(isGenerateImageModels(models)).toBe(false)
})
it('filters models for agent usage', () => {
expect(agentModelFilter(createModel())).toBe(true)
embeddingMock.mockReturnValueOnce(true)
expect(agentModelFilter(createModel({ id: 'text-embedding' }))).toBe(false)
embeddingMock.mockReturnValue(false)
rerankMock.mockReturnValueOnce(true)
expect(agentModelFilter(createModel({ id: 'rerank' }))).toBe(false)
rerankMock.mockReturnValue(false)
textToImageMock.mockReturnValueOnce(true)
expect(agentModelFilter(createModel({ id: 'gpt-image-1' }))).toBe(false)
})
it('identifies models with maximum temperature of 1.0', () => {
// Zhipu models should have max temperature of 1.0
expect(isMaxTemperatureOneModel(createModel({ id: 'glm-4' }))).toBe(true)
expect(isMaxTemperatureOneModel(createModel({ id: 'GLM-4-Plus' }))).toBe(true)
expect(isMaxTemperatureOneModel(createModel({ id: 'glm-3-turbo' }))).toBe(true)
// Anthropic models should have max temperature of 1.0
expect(isMaxTemperatureOneModel(createModel({ id: 'claude-3.5-sonnet' }))).toBe(true)
expect(isMaxTemperatureOneModel(createModel({ id: 'Claude-3-opus' }))).toBe(true)
expect(isMaxTemperatureOneModel(createModel({ id: 'claude-2.1' }))).toBe(true)
// Moonshot models should have max temperature of 1.0
expect(isMaxTemperatureOneModel(createModel({ id: 'moonshot-1.0' }))).toBe(true)
expect(isMaxTemperatureOneModel(createModel({ id: 'kimi-k2-thinking' }))).toBe(true)
expect(isMaxTemperatureOneModel(createModel({ id: 'Moonshot-Pro' }))).toBe(true)
// Other models should return false
expect(isMaxTemperatureOneModel(createModel({ id: 'gpt-4o' }))).toBe(false)
expect(isMaxTemperatureOneModel(createModel({ id: 'gpt-4-turbo' }))).toBe(false)
expect(isMaxTemperatureOneModel(createModel({ id: 'qwen-max' }))).toBe(false)
expect(isMaxTemperatureOneModel(createModel({ id: 'gemini-pro' }))).toBe(false)
})
})

View File

@ -0,0 +1,311 @@
import { getProviderByModel } from '@renderer/services/AssistantService'
import type { Model } from '@renderer/types'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { isEmbeddingModel, isRerankModel } from '../embedding'
import {
isAutoEnableImageGenerationModel,
isDedicatedImageGenerationModel,
isGenerateImageModel,
isImageEnhancementModel,
isPureGenerateImageModel,
isTextToImageModel,
isVisionModel
} from '../vision'
vi.mock('@renderer/hooks/useStore', () => ({
getStoreProviders: vi.fn(() => [])
}))
vi.mock('@renderer/store', () => ({
__esModule: true,
default: {
getState: () => ({
llm: { providers: [] },
settings: {}
})
},
useAppDispatch: vi.fn(),
useAppSelector: vi.fn()
}))
vi.mock('@renderer/store/settings', () => {
const noop = vi.fn()
return new Proxy(
{},
{
get: (_target, prop) => {
if (prop === 'initialState') {
return {}
}
return noop
}
}
)
})
vi.mock('@renderer/hooks/useSettings', () => ({
useSettings: vi.fn(() => ({})),
useNavbarPosition: vi.fn(() => ({ navbarPosition: 'left' })),
useMessageStyle: vi.fn(() => ({ isBubbleStyle: false })),
getStoreSetting: vi.fn()
}))
vi.mock('@renderer/services/AssistantService', () => ({
getProviderByModel: vi.fn()
}))
vi.mock('../embedding', () => ({
isEmbeddingModel: vi.fn(),
isRerankModel: vi.fn()
}))
const createModel = (overrides: Partial<Model> = {}): Model => ({
id: 'gpt-4o',
name: 'gpt-4o',
provider: 'openai',
group: 'OpenAI',
...overrides
})
const providerMock = vi.mocked(getProviderByModel)
const embeddingMock = vi.mocked(isEmbeddingModel)
const rerankMock = vi.mocked(isRerankModel)
describe('vision helpers', () => {
beforeEach(() => {
vi.clearAllMocks()
providerMock.mockReturnValue({ type: 'openai-response' } as any)
embeddingMock.mockReturnValue(false)
rerankMock.mockReturnValue(false)
})
describe('isGenerateImageModel', () => {
it('returns false for embedding/rerank models or missing providers', () => {
embeddingMock.mockReturnValueOnce(true)
expect(isGenerateImageModel(createModel({ id: 'gpt-image-1' }))).toBe(false)
embeddingMock.mockReturnValue(false)
rerankMock.mockReturnValueOnce(true)
expect(isGenerateImageModel(createModel({ id: 'gpt-image-1' }))).toBe(false)
rerankMock.mockReturnValue(false)
providerMock.mockReturnValueOnce(undefined as any)
expect(isGenerateImageModel(createModel({ id: 'gpt-image-1' }))).toBe(false)
})
it('detects OpenAI and third-party generative image models', () => {
expect(isGenerateImageModel(createModel({ id: 'gpt-4o-mini' }))).toBe(true)
providerMock.mockReturnValue({ type: 'custom' } as any)
expect(isGenerateImageModel(createModel({ id: 'gemini-2.5-flash-image' }))).toBe(true)
})
it('returns false when openai-response model is not on allow list', () => {
expect(isGenerateImageModel(createModel({ id: 'gpt-4.2-experimental' }))).toBe(false)
})
})
describe('isPureGenerateImageModel', () => {
it('requires both generate and text-to-image support', () => {
expect(isPureGenerateImageModel(createModel({ id: 'gpt-image-1' }))).toBe(true)
expect(isPureGenerateImageModel(createModel({ id: 'gpt-4o' }))).toBe(false)
expect(isPureGenerateImageModel(createModel({ id: 'gemini-2.5-flash-image-preview' }))).toBe(true)
})
})
describe('text-to-image helpers', () => {
it('matches predefined keywords', () => {
expect(isTextToImageModel(createModel({ id: 'midjourney-v6' }))).toBe(true)
expect(isTextToImageModel(createModel({ id: 'gpt-4o' }))).toBe(false)
})
it('detects models with restricted image size support and enhancement', () => {
expect(isImageEnhancementModel(createModel({ id: 'qwen-image-edit' }))).toBe(true)
expect(isImageEnhancementModel(createModel({ id: 'gpt-4o' }))).toBe(false)
})
it('identifies dedicated and auto-enabled image generation models', () => {
expect(isDedicatedImageGenerationModel(createModel({ id: 'grok-2-image-1212' }))).toBe(true)
expect(isAutoEnableImageGenerationModel(createModel({ id: 'gemini-2.5-flash-image-ultra' }))).toBe(true)
})
it('returns false when models are not in dedicated or auto-enable sets', () => {
expect(isDedicatedImageGenerationModel(createModel({ id: 'gpt-4o' }))).toBe(false)
expect(isAutoEnableImageGenerationModel(createModel({ id: 'gpt-4o' }))).toBe(false)
})
})
})
describe('isVisionModel', () => {
it('returns false for embedding/rerank models and honors overrides', () => {
embeddingMock.mockReturnValueOnce(true)
expect(isVisionModel(createModel({ id: 'gpt-4o' }))).toBe(false)
embeddingMock.mockReturnValue(false)
const disabled = createModel({
id: 'gpt-4o',
capabilities: [{ type: 'vision', isUserSelected: false }]
})
expect(isVisionModel(disabled)).toBe(false)
const forced = createModel({
id: 'gpt-4o',
capabilities: [{ type: 'vision', isUserSelected: true }]
})
expect(isVisionModel(forced)).toBe(true)
})
it('matches doubao models by name and general regexes by id', () => {
const doubao = createModel({
id: 'custom-id',
provider: 'doubao',
name: 'Doubao-Seed-1-6-Lite-251015'
})
expect(isVisionModel(doubao)).toBe(true)
expect(isVisionModel(createModel({ id: 'gpt-4o-mini' }))).toBe(true)
})
it('leverages image enhancement regex when standard vision regex does not match', () => {
expect(isVisionModel(createModel({ id: 'qwen-image-edit' }))).toBe(true)
})
it('returns false for doubao models that fail regex checks', () => {
const doubao = createModel({ id: 'doubao-standard', provider: 'doubao', name: 'basic' })
expect(isVisionModel(doubao)).toBe(false)
})
describe('Gemini Models', () => {
it('should return true for gemini 1.5 models', () => {
expect(
isVisionModel({
id: 'gemini-1.5-flash',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isVisionModel({
id: 'gemini-1.5-pro',
name: '',
provider: '',
group: ''
})
).toBe(true)
})
it('should return true for gemini 2.x models', () => {
expect(
isVisionModel({
id: 'gemini-2.0-flash',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isVisionModel({
id: 'gemini-2.0-pro',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isVisionModel({
id: 'gemini-2.5-flash',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isVisionModel({
id: 'gemini-2.5-pro',
name: '',
provider: '',
group: ''
})
).toBe(true)
})
it('should return true for gemini latest models', () => {
expect(
isVisionModel({
id: 'gemini-flash-latest',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isVisionModel({
id: 'gemini-pro-latest',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isVisionModel({
id: 'gemini-flash-lite-latest',
name: '',
provider: '',
group: ''
})
).toBe(true)
})
it('should return true for gemini 3 models', () => {
// Preview versions
expect(
isVisionModel({
id: 'gemini-3-pro-preview',
name: '',
provider: '',
group: ''
})
).toBe(true)
// Future stable versions
expect(
isVisionModel({
id: 'gemini-3-flash',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isVisionModel({
id: 'gemini-3-pro',
name: '',
provider: '',
group: ''
})
).toBe(true)
})
it('should return true for gemini exp models', () => {
expect(
isVisionModel({
id: 'gemini-exp-1206',
name: '',
provider: '',
group: ''
})
).toBe(true)
})
it('should return false for gemini 1.0 models', () => {
expect(
isVisionModel({
id: 'gemini-1.0-pro',
name: '',
provider: '',
group: ''
})
).toBe(false)
})
})
})

View File

@ -0,0 +1,397 @@
import { beforeEach, describe, expect, it, vi } from 'vitest'
const providerMock = vi.mocked(getProviderByModel)
vi.mock('@renderer/services/AssistantService', () => ({
getProviderByModel: vi.fn(),
getAssistantSettings: vi.fn(),
getDefaultAssistant: vi.fn().mockReturnValue({
id: 'default',
name: 'Default Assistant',
prompt: '',
settings: {}
})
}))
const isEmbeddingModel = vi.hoisted(() => vi.fn())
const isRerankModel = vi.hoisted(() => vi.fn())
vi.mock('../embedding', () => ({
isEmbeddingModel: (...args: any[]) => isEmbeddingModel(...args),
isRerankModel: (...args: any[]) => isRerankModel(...args)
}))
const isPureGenerateImageModel = vi.hoisted(() => vi.fn())
const isTextToImageModel = vi.hoisted(() => vi.fn())
const isGenerateImageModel = vi.hoisted(() => vi.fn())
vi.mock('../vision', () => ({
isPureGenerateImageModel: (...args: any[]) => isPureGenerateImageModel(...args),
isTextToImageModel: (...args: any[]) => isTextToImageModel(...args),
isGenerateImageModel: (...args: any[]) => isGenerateImageModel(...args),
isModernGenerateImageModel: vi.fn()
}))
const providerMocks = vi.hoisted(() => ({
isGeminiProvider: vi.fn(),
isNewApiProvider: vi.fn(),
isOpenAICompatibleProvider: vi.fn(),
isOpenAIProvider: vi.fn(),
isVertexProvider: vi.fn(),
isAwsBedrockProvider: vi.fn(),
isAzureOpenAIProvider: vi.fn()
}))
vi.mock('@renderer/utils/provider', () => providerMocks)
vi.mock('@renderer/hooks/useStore', () => ({
getStoreProviders: vi.fn(() => [])
}))
vi.mock('@renderer/store', () => ({
__esModule: true,
default: {
getState: () => ({
llm: { providers: [] },
settings: {}
})
},
useAppDispatch: vi.fn(),
useAppSelector: vi.fn()
}))
vi.mock('@renderer/store/settings', () => {
const noop = vi.fn()
return new Proxy(
{},
{
get: (_target, prop) => {
if (prop === 'initialState') {
return {}
}
return noop
}
}
)
})
vi.mock('@renderer/hooks/useSettings', () => ({
useSettings: vi.fn(() => ({})),
useNavbarPosition: vi.fn(() => ({ navbarPosition: 'left' })),
useMessageStyle: vi.fn(() => ({ isBubbleStyle: false })),
getStoreSetting: vi.fn()
}))
import { getProviderByModel } from '@renderer/services/AssistantService'
import type { Model, Provider } from '@renderer/types'
import { SystemProviderIds } from '@renderer/types'
import { isOpenAIDeepResearchModel } from '../openai'
import {
GEMINI_SEARCH_REGEX,
isHunyuanSearchModel,
isMandatoryWebSearchModel,
isOpenAIWebSearchChatCompletionOnlyModel,
isOpenAIWebSearchModel,
isOpenRouterBuiltInWebSearchModel,
isWebSearchModel
} from '../websearch'
const createModel = (overrides: Partial<Model> = {}): Model => ({
id: 'gpt-4o',
name: 'gpt-4o',
provider: 'openai',
group: 'OpenAI',
...overrides
})
const createProvider = (overrides: Partial<Provider> = {}): Provider => ({
id: 'openai',
type: 'openai',
name: 'OpenAI',
apiKey: '',
apiHost: '',
models: [],
...overrides
})
const resetMocks = () => {
providerMock.mockReturnValue(createProvider())
isEmbeddingModel.mockReturnValue(false)
isRerankModel.mockReturnValue(false)
isPureGenerateImageModel.mockReturnValue(false)
isTextToImageModel.mockReturnValue(false)
providerMocks.isGeminiProvider.mockReturnValue(false)
providerMocks.isNewApiProvider.mockReturnValue(false)
providerMocks.isOpenAICompatibleProvider.mockReturnValue(false)
providerMocks.isOpenAIProvider.mockReturnValue(false)
}
describe('websearch helpers', () => {
beforeEach(() => {
vi.clearAllMocks()
resetMocks()
})
describe('isOpenAIDeepResearchModel', () => {
it('detects deep research ids for OpenAI only', () => {
expect(isOpenAIDeepResearchModel(createModel({ id: 'openai/deep-research-preview' }))).toBe(true)
expect(isOpenAIDeepResearchModel(createModel({ provider: 'openai', id: 'gpt-4o' }))).toBe(false)
expect(isOpenAIDeepResearchModel(createModel({ provider: 'openrouter', id: 'deep-research' }))).toBe(false)
})
})
describe('isWebSearchModel', () => {
it('returns false for embedding/rerank/image models', () => {
isEmbeddingModel.mockReturnValueOnce(true)
expect(isWebSearchModel(createModel())).toBe(false)
resetMocks()
isRerankModel.mockReturnValueOnce(true)
expect(isWebSearchModel(createModel())).toBe(false)
resetMocks()
isTextToImageModel.mockReturnValueOnce(true)
expect(isWebSearchModel(createModel())).toBe(false)
})
it('honors user overrides', () => {
const enabled = createModel({ capabilities: [{ type: 'web_search', isUserSelected: true }] })
expect(isWebSearchModel(enabled)).toBe(true)
const disabled = createModel({ capabilities: [{ type: 'web_search', isUserSelected: false }] })
expect(isWebSearchModel(disabled)).toBe(false)
})
it('returns false when provider lookup fails', () => {
providerMock.mockReturnValueOnce(undefined as any)
expect(isWebSearchModel(createModel())).toBe(false)
})
it('handles Anthropic providers on unsupported platforms', () => {
providerMock.mockReturnValueOnce(createProvider({ id: SystemProviderIds['aws-bedrock'] }))
const model = createModel({ id: 'claude-2-sonnet' })
expect(isWebSearchModel(model)).toBe(false)
})
it('returns true for first-party Anthropic provider', () => {
providerMock.mockReturnValueOnce(createProvider({ id: 'anthropic' }))
const model = createModel({ id: 'claude-3.5-sonnet-latest', provider: 'anthropic' })
expect(isWebSearchModel(model)).toBe(true)
})
it('detects OpenAI preview search models only when supported', () => {
providerMocks.isOpenAIProvider.mockReturnValue(true)
const model = createModel({ id: 'gpt-4o-search-preview' })
expect(isWebSearchModel(model)).toBe(true)
const nonSearch = createModel({ id: 'gpt-4o-image' })
expect(isWebSearchModel(nonSearch)).toBe(false)
})
it('supports Perplexity sonar families including mandatory variants', () => {
providerMock.mockReturnValueOnce(createProvider({ id: SystemProviderIds.perplexity }))
expect(isWebSearchModel(createModel({ id: 'sonar-deep-research' }))).toBe(true)
})
it('handles AIHubMix Gemini and OpenAI search models', () => {
providerMock.mockReturnValueOnce(createProvider({ id: SystemProviderIds.aihubmix }))
expect(isWebSearchModel(createModel({ id: 'gemini-2.5-pro-preview' }))).toBe(true)
providerMock.mockReturnValueOnce(createProvider({ id: SystemProviderIds.aihubmix }))
const openaiSearch = createModel({ id: 'gpt-4o-search-preview' })
expect(isWebSearchModel(openaiSearch)).toBe(true)
})
it('supports OpenAI-compatible or new API providers for Gemini/OpenAI models', () => {
const model = createModel({ id: 'gemini-2.5-flash-lite-latest' })
providerMock.mockReturnValueOnce(createProvider({ id: 'custom' }))
providerMocks.isOpenAICompatibleProvider.mockReturnValueOnce(true)
expect(isWebSearchModel(model)).toBe(true)
resetMocks()
providerMock.mockReturnValueOnce(createProvider({ id: 'custom' }))
providerMocks.isNewApiProvider.mockReturnValueOnce(true)
expect(isWebSearchModel(createModel({ id: 'gpt-4o-search-preview' }))).toBe(true)
})
it('falls back to Gemini/Vertex provider regex matching', () => {
providerMock.mockReturnValueOnce(createProvider({ id: SystemProviderIds.vertexai }))
providerMocks.isGeminiProvider.mockReturnValueOnce(true)
expect(isWebSearchModel(createModel({ id: 'gemini-2.0-flash-latest' }))).toBe(true)
})
it('evaluates hunyuan/zhipu/dashscope/openrouter/grok providers', () => {
providerMock.mockReturnValueOnce(createProvider({ id: 'hunyuan' }))
expect(isWebSearchModel(createModel({ id: 'hunyuan-pro' }))).toBe(true)
expect(isWebSearchModel(createModel({ id: 'hunyuan-lite', provider: 'hunyuan' }))).toBe(false)
providerMock.mockReturnValueOnce(createProvider({ id: 'zhipu' }))
expect(isWebSearchModel(createModel({ id: 'glm-4-air' }))).toBe(true)
providerMock.mockReturnValueOnce(createProvider({ id: 'dashscope' }))
expect(isWebSearchModel(createModel({ id: 'qwen-max-latest' }))).toBe(true)
providerMock.mockReturnValueOnce(createProvider({ id: 'openrouter' }))
expect(isWebSearchModel(createModel())).toBe(true)
providerMock.mockReturnValueOnce(createProvider({ id: 'grok' }))
expect(isWebSearchModel(createModel({ id: 'grok-2' }))).toBe(true)
})
})
describe('isMandatoryWebSearchModel', () => {
it('requires sonar ids for perplexity/openrouter providers', () => {
providerMock.mockReturnValueOnce(createProvider({ id: SystemProviderIds.perplexity }))
expect(isMandatoryWebSearchModel(createModel({ id: 'sonar-pro' }))).toBe(true)
providerMock.mockReturnValueOnce(createProvider({ id: SystemProviderIds.openrouter }))
expect(isMandatoryWebSearchModel(createModel({ id: 'sonar-reasoning' }))).toBe(true)
providerMock.mockReturnValueOnce(createProvider({ id: 'openai' }))
expect(isMandatoryWebSearchModel(createModel({ id: 'sonar-pro' }))).toBe(false)
})
it.each([
['perplexity', 'non-sonar'],
['openrouter', 'gpt-4o-search-preview']
])('returns false for %s provider when id is %s', (providerId, modelId) => {
providerMock.mockReturnValueOnce(createProvider({ id: providerId }))
expect(isMandatoryWebSearchModel(createModel({ id: modelId }))).toBe(false)
})
})
describe('isOpenRouterBuiltInWebSearchModel', () => {
it('checks for sonar ids or OpenAI chat-completion-only variants', () => {
providerMock.mockReturnValueOnce(createProvider({ id: 'openrouter' }))
expect(isOpenRouterBuiltInWebSearchModel(createModel({ id: 'sonar-reasoning' }))).toBe(true)
providerMock.mockReturnValueOnce(createProvider({ id: 'openrouter' }))
expect(isOpenRouterBuiltInWebSearchModel(createModel({ id: 'gpt-4o-search-preview' }))).toBe(true)
providerMock.mockReturnValueOnce(createProvider({ id: 'custom' }))
expect(isOpenRouterBuiltInWebSearchModel(createModel({ id: 'sonar-reasoning' }))).toBe(false)
})
})
describe('OpenAI web search helpers', () => {
it('detects chat completion only variants and openai search ids', () => {
expect(isOpenAIWebSearchChatCompletionOnlyModel(createModel({ id: 'gpt-4o-search-preview' }))).toBe(true)
expect(isOpenAIWebSearchChatCompletionOnlyModel(createModel({ id: 'gpt-4o-mini-search-preview' }))).toBe(true)
expect(isOpenAIWebSearchChatCompletionOnlyModel(createModel({ id: 'gpt-4o' }))).toBe(false)
expect(isOpenAIWebSearchModel(createModel({ id: 'gpt-4.1-turbo' }))).toBe(true)
expect(isOpenAIWebSearchModel(createModel({ id: 'gpt-4o-image' }))).toBe(false)
expect(isOpenAIWebSearchModel(createModel({ id: 'gpt-5.1-chat' }))).toBe(false)
expect(isOpenAIWebSearchModel(createModel({ id: 'o3-mini' }))).toBe(true)
})
it.each(['gpt-4.1-preview', 'gpt-4o-2024-05-13', 'o4-mini', 'gpt-5-explorer'])(
'treats %s as an OpenAI web search model',
(id) => {
expect(isOpenAIWebSearchModel(createModel({ id }))).toBe(true)
}
)
it.each(['gpt-4o-image-preview', 'gpt-4.1-nano', 'gpt-5.1-chat', 'gpt-image-1'])(
'excludes %s from OpenAI web search',
(id) => {
expect(isOpenAIWebSearchModel(createModel({ id }))).toBe(false)
}
)
it.each(['gpt-4o-search-preview', 'gpt-4o-mini-search-preview'])('flags %s as chat-completion-only', (id) => {
expect(isOpenAIWebSearchChatCompletionOnlyModel(createModel({ id }))).toBe(true)
})
})
describe('isHunyuanSearchModel', () => {
it('identifies hunyuan models except lite', () => {
expect(isHunyuanSearchModel(createModel({ id: 'hunyuan-pro', provider: 'hunyuan' }))).toBe(true)
expect(isHunyuanSearchModel(createModel({ id: 'hunyuan-lite', provider: 'hunyuan' }))).toBe(false)
expect(isHunyuanSearchModel(createModel())).toBe(false)
})
it.each(['hunyuan-standard', 'hunyuan-advanced'])('accepts %s', (suffix) => {
expect(isHunyuanSearchModel(createModel({ id: suffix, provider: 'hunyuan' }))).toBe(true)
})
})
describe('provider-specific regex coverage', () => {
it.each(['qwen-turbo', 'qwen-max-0919', 'qwen3-max', 'qwen-plus-2024', 'qwq-32b'])(
'dashscope treats %s as searchable',
(id) => {
providerMock.mockReturnValue(createProvider({ id: 'dashscope' }))
expect(isWebSearchModel(createModel({ id }))).toBe(true)
}
)
it.each(['qwen-1.5-chat', 'custom-model'])('dashscope ignores %s', (id) => {
providerMock.mockReturnValue(createProvider({ id: 'dashscope' }))
expect(isWebSearchModel(createModel({ id }))).toBe(false)
})
it.each(['sonar', 'sonar-pro', 'sonar-reasoning-pro', 'sonar-deep-research'])(
'perplexity provider supports %s',
(id) => {
providerMock.mockReturnValue(createProvider({ id: SystemProviderIds.perplexity }))
expect(isWebSearchModel(createModel({ id }))).toBe(true)
}
)
it.each([
'gemini-2.0-flash-latest',
'gemini-2.5-flash-lite-latest',
'gemini-flash-lite-latest',
'gemini-pro-latest'
])('Gemini provider supports %s', (id) => {
providerMock.mockReturnValue(createProvider({ id: SystemProviderIds.vertexai }))
providerMocks.isGeminiProvider.mockReturnValue(true)
expect(isWebSearchModel(createModel({ id }))).toBe(true)
})
})
describe('Gemini Search Models', () => {
describe('GEMINI_SEARCH_REGEX', () => {
it('should match gemini 2.x models', () => {
expect(GEMINI_SEARCH_REGEX.test('gemini-2.0-flash')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-2.0-pro')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-2.5-flash')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-2.5-pro')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-2.5-flash-latest')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-2.5-pro-latest')).toBe(true)
})
it('should match gemini latest models', () => {
expect(GEMINI_SEARCH_REGEX.test('gemini-flash-latest')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-pro-latest')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-flash-lite-latest')).toBe(true)
})
it('should match gemini 3 models', () => {
// Preview versions
expect(GEMINI_SEARCH_REGEX.test('gemini-3-pro-preview')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-3-flash-preview')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-3-pro-image-preview')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-3-flash-image-preview')).toBe(true)
// Future stable versions
expect(GEMINI_SEARCH_REGEX.test('gemini-3-flash')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-3-pro')).toBe(true)
// Version with decimals
expect(GEMINI_SEARCH_REGEX.test('gemini-3.0-flash')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-3.0-pro')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-3.5-flash-preview')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-3.5-pro-image-preview')).toBe(true)
})
it('should not match gemini 2.x image-preview models', () => {
expect(GEMINI_SEARCH_REGEX.test('gemini-2.5-flash-image-preview')).toBe(false)
expect(GEMINI_SEARCH_REGEX.test('gemini-2.0-pro-image-preview')).toBe(false)
})
it('should not match older gemini models', () => {
expect(GEMINI_SEARCH_REGEX.test('gemini-1.5-flash')).toBe(false)
expect(GEMINI_SEARCH_REGEX.test('gemini-1.5-pro')).toBe(false)
expect(GEMINI_SEARCH_REGEX.test('gemini-1.0-pro')).toBe(false)
})
})
})
})

View File

@ -1,6 +1,8 @@
export * from './default' export * from './default'
export * from './embedding' export * from './embedding'
export * from './logo' export * from './logo'
export * from './openai'
export * from './qwen'
export * from './reasoning' export * from './reasoning'
export * from './tooluse' export * from './tooluse'
export * from './utils' export * from './utils'

View File

@ -0,0 +1,107 @@
import type { Model } from '@renderer/types'
import { getLowerBaseModelName } from '@renderer/utils'
export const OPENAI_NO_SUPPORT_DEV_ROLE_MODELS = ['o1-preview', 'o1-mini']
export function isOpenAILLMModel(model: Model): boolean {
if (!model) {
return false
}
const modelId = getLowerBaseModelName(model.id)
if (modelId.includes('gpt-4o-image')) {
return false
}
if (isOpenAIReasoningModel(model)) {
return true
}
if (modelId.includes('gpt')) {
return true
}
return false
}
export function isOpenAIModel(model: Model): boolean {
if (!model) {
return false
}
const modelId = getLowerBaseModelName(model.id)
return modelId.includes('gpt') || isOpenAIReasoningModel(model)
}
export const isGPT5ProModel = (model: Model) => {
const modelId = getLowerBaseModelName(model.id)
return modelId.includes('gpt-5-pro')
}
export const isOpenAIOpenWeightModel = (model: Model) => {
const modelId = getLowerBaseModelName(model.id)
return modelId.includes('gpt-oss')
}
export const isGPT5SeriesModel = (model: Model) => {
const modelId = getLowerBaseModelName(model.id)
return modelId.includes('gpt-5') && !modelId.includes('gpt-5.1')
}
export const isGPT5SeriesReasoningModel = (model: Model) => {
const modelId = getLowerBaseModelName(model.id)
return isGPT5SeriesModel(model) && !modelId.includes('chat')
}
export const isGPT51SeriesModel = (model: Model) => {
const modelId = getLowerBaseModelName(model.id)
return modelId.includes('gpt-5.1')
}
export function isSupportVerbosityModel(model: Model): boolean {
const modelId = getLowerBaseModelName(model.id)
return (isGPT5SeriesModel(model) || isGPT51SeriesModel(model)) && !modelId.includes('chat')
}
export function isOpenAIChatCompletionOnlyModel(model: Model): boolean {
if (!model) {
return false
}
const modelId = getLowerBaseModelName(model.id)
return (
modelId.includes('gpt-4o-search-preview') ||
modelId.includes('gpt-4o-mini-search-preview') ||
modelId.includes('o1-mini') ||
modelId.includes('o1-preview')
)
}
export function isOpenAIReasoningModel(model: Model): boolean {
const modelId = getLowerBaseModelName(model.id, '/')
return isSupportedReasoningEffortOpenAIModel(model) || modelId.includes('o1')
}
export function isSupportedReasoningEffortOpenAIModel(model: Model): boolean {
const modelId = getLowerBaseModelName(model.id)
return (
(modelId.includes('o1') && !(modelId.includes('o1-preview') || modelId.includes('o1-mini'))) ||
modelId.includes('o3') ||
modelId.includes('o4') ||
modelId.includes('gpt-oss') ||
((isGPT5SeriesModel(model) || isGPT51SeriesModel(model)) && !modelId.includes('chat'))
)
}
const OPENAI_DEEP_RESEARCH_MODEL_REGEX = /deep[-_]?research/
export function isOpenAIDeepResearchModel(model?: Model): boolean {
if (!model) {
return false
}
const providerId = model.provider
if (providerId !== 'openai' && providerId !== 'openai-chat') {
return false
}
const modelId = getLowerBaseModelName(model.id, '/')
return OPENAI_DEEP_RESEARCH_MODEL_REGEX.test(modelId)
}

View File

@ -0,0 +1,7 @@
import type { Model } from '@renderer/types'
import { getLowerBaseModelName } from '@renderer/utils'
export const isQwenMTModel = (model: Model): boolean => {
const modelId = getLowerBaseModelName(model.id)
return modelId.includes('qwen-mt')
}

View File

@ -8,9 +8,16 @@ import type {
import { getLowerBaseModelName, isUserSelectedModelType } from '@renderer/utils' import { getLowerBaseModelName, isUserSelectedModelType } from '@renderer/utils'
import { isEmbeddingModel, isRerankModel } from './embedding' import { isEmbeddingModel, isRerankModel } from './embedding'
import { isGPT5ProModel, isGPT5SeriesModel, isGPT51SeriesModel } from './utils' import {
isGPT5ProModel,
isGPT5SeriesModel,
isGPT51SeriesModel,
isOpenAIDeepResearchModel,
isOpenAIReasoningModel,
isSupportedReasoningEffortOpenAIModel
} from './openai'
import { GEMINI_FLASH_MODEL_REGEX, isGemini3Model } from './utils'
import { isTextToImageModel } from './vision' import { isTextToImageModel } from './vision'
import { GEMINI_FLASH_MODEL_REGEX, isOpenAIDeepResearchModel } from './websearch'
// Reasoning models // Reasoning models
export const REASONING_REGEX = export const REASONING_REGEX =
@ -30,6 +37,7 @@ export const MODEL_SUPPORTED_REASONING_EFFORT: ReasoningEffortConfig = {
grok: ['low', 'high'] as const, grok: ['low', 'high'] as const,
grok4_fast: ['auto'] as const, grok4_fast: ['auto'] as const,
gemini: ['low', 'medium', 'high', 'auto'] as const, gemini: ['low', 'medium', 'high', 'auto'] as const,
gemini3: ['low', 'medium', 'high'] as const,
gemini_pro: ['low', 'medium', 'high', 'auto'] as const, gemini_pro: ['low', 'medium', 'high', 'auto'] as const,
qwen: ['low', 'medium', 'high'] as const, qwen: ['low', 'medium', 'high'] as const,
qwen_thinking: ['low', 'medium', 'high'] as const, qwen_thinking: ['low', 'medium', 'high'] as const,
@ -56,6 +64,7 @@ export const MODEL_SUPPORTED_OPTIONS: ThinkingOptionConfig = {
grok4_fast: ['none', ...MODEL_SUPPORTED_REASONING_EFFORT.grok4_fast] as const, grok4_fast: ['none', ...MODEL_SUPPORTED_REASONING_EFFORT.grok4_fast] as const,
gemini: ['none', ...MODEL_SUPPORTED_REASONING_EFFORT.gemini] as const, gemini: ['none', ...MODEL_SUPPORTED_REASONING_EFFORT.gemini] as const,
gemini_pro: MODEL_SUPPORTED_REASONING_EFFORT.gemini_pro, gemini_pro: MODEL_SUPPORTED_REASONING_EFFORT.gemini_pro,
gemini3: MODEL_SUPPORTED_REASONING_EFFORT.gemini3,
qwen: ['none', ...MODEL_SUPPORTED_REASONING_EFFORT.qwen] as const, qwen: ['none', ...MODEL_SUPPORTED_REASONING_EFFORT.qwen] as const,
qwen_thinking: MODEL_SUPPORTED_REASONING_EFFORT.qwen_thinking, qwen_thinking: MODEL_SUPPORTED_REASONING_EFFORT.qwen_thinking,
doubao: ['none', ...MODEL_SUPPORTED_REASONING_EFFORT.doubao] as const, doubao: ['none', ...MODEL_SUPPORTED_REASONING_EFFORT.doubao] as const,
@ -106,6 +115,9 @@ const _getThinkModelType = (model: Model): ThinkingModelType => {
} else { } else {
thinkingModelType = 'gemini_pro' thinkingModelType = 'gemini_pro'
} }
if (isGemini3Model(model)) {
thinkingModelType = 'gemini3'
}
} else if (isSupportedReasoningEffortGrokModel(model)) thinkingModelType = 'grok' } else if (isSupportedReasoningEffortGrokModel(model)) thinkingModelType = 'grok'
else if (isSupportedThinkingTokenQwenModel(model)) { else if (isSupportedThinkingTokenQwenModel(model)) {
if (isQwenAlwaysThinkModel(model)) { if (isQwenAlwaysThinkModel(model)) {
@ -254,11 +266,19 @@ export function isGeminiReasoningModel(model?: Model): boolean {
// Gemini 支持思考模式的模型正则 // Gemini 支持思考模式的模型正则
export const GEMINI_THINKING_MODEL_REGEX = export const GEMINI_THINKING_MODEL_REGEX =
/gemini-(?:2\.5.*(?:-latest)?|3-(?:flash|pro)(?:-preview)?|flash-latest|pro-latest|flash-lite-latest)(?:-[\w-]+)*$/i /gemini-(?:2\.5.*(?:-latest)?|3(?:\.\d+)?-(?:flash|pro)(?:-preview)?|flash-latest|pro-latest|flash-lite-latest)(?:-[\w-]+)*$/i
export const isSupportedThinkingTokenGeminiModel = (model: Model): boolean => { export const isSupportedThinkingTokenGeminiModel = (model: Model): boolean => {
const modelId = getLowerBaseModelName(model.id, '/') const modelId = getLowerBaseModelName(model.id, '/')
if (GEMINI_THINKING_MODEL_REGEX.test(modelId)) { if (GEMINI_THINKING_MODEL_REGEX.test(modelId)) {
// gemini-3.x 的 image 模型支持思考模式
if (isGemini3Model(model)) {
if (modelId.includes('tts')) {
return false
}
return true
}
// gemini-2.x 的 image/tts 模型不支持
if (modelId.includes('image') || modelId.includes('tts')) { if (modelId.includes('image') || modelId.includes('tts')) {
return false return false
} }
@ -382,6 +402,12 @@ export function isClaude45ReasoningModel(model: Model): boolean {
return regex.test(modelId) return regex.test(modelId)
} }
export function isClaude4SeriesModel(model: Model): boolean {
const modelId = getLowerBaseModelName(model.id, '/')
const regex = /claude-(sonnet|opus|haiku)-4(?:[.-]\d+)?(?:-[\w-]+)?$/i
return regex.test(modelId)
}
export function isClaudeReasoningModel(model?: Model): boolean { export function isClaudeReasoningModel(model?: Model): boolean {
if (!model) { if (!model) {
return false return false
@ -529,22 +555,6 @@ export function isReasoningModel(model?: Model): boolean {
return REASONING_REGEX.test(modelId) || false return REASONING_REGEX.test(modelId) || false
} }
export function isOpenAIReasoningModel(model: Model): boolean {
const modelId = getLowerBaseModelName(model.id, '/')
return isSupportedReasoningEffortOpenAIModel(model) || modelId.includes('o1')
}
export function isSupportedReasoningEffortOpenAIModel(model: Model): boolean {
const modelId = getLowerBaseModelName(model.id)
return (
(modelId.includes('o1') && !(modelId.includes('o1-preview') || modelId.includes('o1-mini'))) ||
modelId.includes('o3') ||
modelId.includes('o4') ||
modelId.includes('gpt-oss') ||
((isGPT5SeriesModel(model) || isGPT51SeriesModel(model)) && !modelId.includes('chat'))
)
}
export const THINKING_TOKEN_MAP: Record<string, { min: number; max: number }> = { export const THINKING_TOKEN_MAP: Record<string, { min: number; max: number }> = {
// Gemini models // Gemini models
'gemini-2\\.5-flash-lite.*$': { min: 512, max: 24576 }, 'gemini-2\\.5-flash-lite.*$': { min: 512, max: 24576 },

View File

@ -4,7 +4,7 @@ import { getLowerBaseModelName, isUserSelectedModelType } from '@renderer/utils'
import { isEmbeddingModel, isRerankModel } from './embedding' import { isEmbeddingModel, isRerankModel } from './embedding'
import { isDeepSeekHybridInferenceModel } from './reasoning' import { isDeepSeekHybridInferenceModel } from './reasoning'
import { isPureGenerateImageModel, isTextToImageModel } from './vision' import { isTextToImageModel } from './vision'
// Tool calling models // Tool calling models
export const FUNCTION_CALLING_MODELS = [ export const FUNCTION_CALLING_MODELS = [
@ -41,7 +41,9 @@ const FUNCTION_CALLING_EXCLUDED_MODELS = [
'gemini-1(?:\\.[\\w-]+)?', 'gemini-1(?:\\.[\\w-]+)?',
'qwen-mt(?:-[\\w-]+)?', 'qwen-mt(?:-[\\w-]+)?',
'gpt-5-chat(?:-[\\w-]+)?', 'gpt-5-chat(?:-[\\w-]+)?',
'glm-4\\.5v' 'glm-4\\.5v',
'gemini-2.5-flash-image(?:-[\\w-]+)?',
'gemini-2.0-flash-preview-image-generation'
] ]
export const FUNCTION_CALLING_REGEX = new RegExp( export const FUNCTION_CALLING_REGEX = new RegExp(
@ -50,13 +52,7 @@ export const FUNCTION_CALLING_REGEX = new RegExp(
) )
export function isFunctionCallingModel(model?: Model): boolean { export function isFunctionCallingModel(model?: Model): boolean {
if ( if (!model || isEmbeddingModel(model) || isRerankModel(model) || isTextToImageModel(model)) {
!model ||
isEmbeddingModel(model) ||
isRerankModel(model) ||
isTextToImageModel(model) ||
isPureGenerateImageModel(model)
) {
return false return false
} }
@ -66,10 +62,6 @@ export function isFunctionCallingModel(model?: Model): boolean {
return isUserSelectedModelType(model, 'function_calling')! return isUserSelectedModelType(model, 'function_calling')!
} }
if (model.provider === 'qiniu') {
return ['deepseek-v3-tool', 'deepseek-v3-0324', 'qwq-32b', 'qwen2.5-72b-instruct'].includes(modelId)
}
if (model.provider === 'doubao' || modelId.includes('doubao')) { if (model.provider === 'doubao' || modelId.includes('doubao')) {
return FUNCTION_CALLING_REGEX.test(modelId) || FUNCTION_CALLING_REGEX.test(model.name) return FUNCTION_CALLING_REGEX.test(modelId) || FUNCTION_CALLING_REGEX.test(model.name)
} }

View File

@ -1,43 +1,14 @@
import type OpenAI from '@cherrystudio/openai' import type OpenAI from '@cherrystudio/openai'
import { isEmbeddingModel, isRerankModel } from '@renderer/config/models/embedding' import { isEmbeddingModel, isRerankModel } from '@renderer/config/models/embedding'
import type { Model } from '@renderer/types' import { type Model, SystemProviderIds } from '@renderer/types'
import type { OpenAIVerbosity, ValidOpenAIVerbosity } from '@renderer/types/aiCoreTypes'
import { getLowerBaseModelName } from '@renderer/utils' import { getLowerBaseModelName } from '@renderer/utils'
import { WEB_SEARCH_PROMPT_FOR_OPENROUTER } from '@shared/config/prompts'
import { getWebSearchTools } from '../tools' import { isOpenAIChatCompletionOnlyModel, isOpenAIOpenWeightModel, isOpenAIReasoningModel } from './openai'
import { isOpenAIReasoningModel } from './reasoning' import { isQwenMTModel } from './qwen'
import { isGenerateImageModel, isTextToImageModel, isVisionModel } from './vision' import { isGenerateImageModel, isTextToImageModel, isVisionModel } from './vision'
import { isOpenAIWebSearchChatCompletionOnlyModel } from './websearch'
export const NOT_SUPPORTED_REGEX = /(?:^tts|whisper|speech)/i export const NOT_SUPPORTED_REGEX = /(?:^tts|whisper|speech)/i
export const GEMINI_FLASH_MODEL_REGEX = new RegExp('gemini.*-flash.*$', 'i')
export const OPENAI_NO_SUPPORT_DEV_ROLE_MODELS = ['o1-preview', 'o1-mini']
export function isOpenAILLMModel(model: Model): boolean {
if (!model) {
return false
}
const modelId = getLowerBaseModelName(model.id)
if (modelId.includes('gpt-4o-image')) {
return false
}
if (isOpenAIReasoningModel(model)) {
return true
}
if (modelId.includes('gpt')) {
return true
}
return false
}
export function isOpenAIModel(model: Model): boolean {
if (!model) {
return false
}
const modelId = getLowerBaseModelName(model.id)
return modelId.includes('gpt') || isOpenAIReasoningModel(model)
}
export function isSupportFlexServiceTierModel(model: Model): boolean { export function isSupportFlexServiceTierModel(model: Model): boolean {
if (!model) { if (!model) {
@ -52,33 +23,6 @@ export function isSupportedFlexServiceTier(model: Model): boolean {
return isSupportFlexServiceTierModel(model) return isSupportFlexServiceTierModel(model)
} }
export function isSupportVerbosityModel(model: Model): boolean {
const modelId = getLowerBaseModelName(model.id)
return (isGPT5SeriesModel(model) || isGPT51SeriesModel(model)) && !modelId.includes('chat')
}
export function isOpenAIChatCompletionOnlyModel(model: Model): boolean {
if (!model) {
return false
}
const modelId = getLowerBaseModelName(model.id)
return (
modelId.includes('gpt-4o-search-preview') ||
modelId.includes('gpt-4o-mini-search-preview') ||
modelId.includes('o1-mini') ||
modelId.includes('o1-preview')
)
}
export function isGrokModel(model?: Model): boolean {
if (!model) {
return false
}
const modelId = getLowerBaseModelName(model.id)
return modelId.includes('grok')
}
export function isSupportedModel(model: OpenAI.Models.Model): boolean { export function isSupportedModel(model: OpenAI.Models.Model): boolean {
if (!model) { if (!model) {
return false return false
@ -105,53 +49,6 @@ export function isNotSupportTemperatureAndTopP(model: Model): boolean {
return false return false
} }
export function getOpenAIWebSearchParams(model: Model, isEnableWebSearch?: boolean): Record<string, any> {
if (!isEnableWebSearch) {
return {}
}
const webSearchTools = getWebSearchTools(model)
if (model.provider === 'grok') {
return {
search_parameters: {
mode: 'auto',
return_citations: true,
sources: [{ type: 'web' }, { type: 'x' }, { type: 'news' }]
}
}
}
if (model.provider === 'hunyuan') {
return { enable_enhancement: true, citation: true, search_info: true }
}
if (model.provider === 'dashscope') {
return {
enable_search: true,
search_options: {
forced_search: true
}
}
}
if (isOpenAIWebSearchChatCompletionOnlyModel(model)) {
return {
web_search_options: {}
}
}
if (model.provider === 'openrouter') {
return {
plugins: [{ id: 'web', search_prompts: WEB_SEARCH_PROMPT_FOR_OPENROUTER }]
}
}
return {
tools: webSearchTools
}
}
export function isGemmaModel(model?: Model): boolean { export function isGemmaModel(model?: Model): boolean {
if (!model) { if (!model) {
return false return false
@ -161,12 +58,14 @@ export function isGemmaModel(model?: Model): boolean {
return modelId.includes('gemma-') || model.group === 'Gemma' return modelId.includes('gemma-') || model.group === 'Gemma'
} }
export function isZhipuModel(model?: Model): boolean { export function isZhipuModel(model: Model): boolean {
if (!model) { const modelId = getLowerBaseModelName(model.id)
return false return modelId.includes('glm') || model.provider === SystemProviderIds.zhipu
} }
return model.provider === 'zhipu' export function isMoonshotModel(model: Model): boolean {
const modelId = getLowerBaseModelName(model.id)
return ['moonshot', 'kimi'].some((m) => modelId.includes(m))
} }
/** /**
@ -212,11 +111,6 @@ export const isAnthropicModel = (model?: Model): boolean => {
return modelId.startsWith('claude') return modelId.startsWith('claude')
} }
export const isQwenMTModel = (model: Model): boolean => {
const modelId = getLowerBaseModelName(model.id)
return modelId.includes('qwen-mt')
}
export const isNotSupportedTextDelta = (model: Model): boolean => { export const isNotSupportedTextDelta = (model: Model): boolean => {
return isQwenMTModel(model) return isQwenMTModel(model)
} }
@ -225,34 +119,22 @@ export const isNotSupportSystemMessageModel = (model: Model): boolean => {
return isQwenMTModel(model) || isGemmaModel(model) return isQwenMTModel(model) || isGemmaModel(model)
} }
export const isGPT5SeriesModel = (model: Model) => {
const modelId = getLowerBaseModelName(model.id)
return modelId.includes('gpt-5') && !modelId.includes('gpt-5.1')
}
export const isGPT5SeriesReasoningModel = (model: Model) => {
const modelId = getLowerBaseModelName(model.id)
return isGPT5SeriesModel(model) && !modelId.includes('chat')
}
export const isGPT51SeriesModel = (model: Model) => {
const modelId = getLowerBaseModelName(model.id)
return modelId.includes('gpt-5.1')
}
// GPT-5 verbosity configuration // GPT-5 verbosity configuration
// gpt-5-pro only supports 'high', other GPT-5 models support all levels // gpt-5-pro only supports 'high', other GPT-5 models support all levels
export const MODEL_SUPPORTED_VERBOSITY: Record<string, ('low' | 'medium' | 'high')[]> = { export const MODEL_SUPPORTED_VERBOSITY: Record<string, ValidOpenAIVerbosity[]> = {
'gpt-5-pro': ['high'], 'gpt-5-pro': ['high'],
default: ['low', 'medium', 'high'] default: ['low', 'medium', 'high']
} } as const
export const getModelSupportedVerbosity = (model: Model): ('low' | 'medium' | 'high')[] => { export const getModelSupportedVerbosity = (model: Model): OpenAIVerbosity[] => {
const modelId = getLowerBaseModelName(model.id) const modelId = getLowerBaseModelName(model.id)
let supportedValues: ValidOpenAIVerbosity[]
if (modelId.includes('gpt-5-pro')) { if (modelId.includes('gpt-5-pro')) {
return MODEL_SUPPORTED_VERBOSITY['gpt-5-pro'] supportedValues = MODEL_SUPPORTED_VERBOSITY['gpt-5-pro']
} else {
supportedValues = MODEL_SUPPORTED_VERBOSITY.default
} }
return MODEL_SUPPORTED_VERBOSITY.default return [undefined, ...supportedValues]
} }
export const isGeminiModel = (model: Model) => { export const isGeminiModel = (model: Model) => {
@ -260,11 +142,6 @@ export const isGeminiModel = (model: Model) => {
return modelId.includes('gemini') return modelId.includes('gemini')
} }
export const isOpenAIOpenWeightModel = (model: Model) => {
const modelId = getLowerBaseModelName(model.id)
return modelId.includes('gpt-oss')
}
// zhipu 视觉推理模型用这组 special token 标记推理结果 // zhipu 视觉推理模型用这组 special token 标记推理结果
export const ZHIPU_RESULT_TOKENS = ['<|begin_of_box|>', '<|end_of_box|>'] as const export const ZHIPU_RESULT_TOKENS = ['<|begin_of_box|>', '<|end_of_box|>'] as const
@ -272,7 +149,14 @@ export const agentModelFilter = (model: Model): boolean => {
return !isEmbeddingModel(model) && !isRerankModel(model) && !isTextToImageModel(model) return !isEmbeddingModel(model) && !isRerankModel(model) && !isTextToImageModel(model)
} }
export const isGPT5ProModel = (model: Model) => { export const isMaxTemperatureOneModel = (model: Model): boolean => {
const modelId = getLowerBaseModelName(model.id) if (isZhipuModel(model) || isAnthropicModel(model) || isMoonshotModel(model)) {
return modelId.includes('gpt-5-pro') return true
}
return false
}
export const isGemini3Model = (model: Model) => {
const modelId = getLowerBaseModelName(model.id)
return modelId.includes('gemini-3')
} }

View File

@ -3,6 +3,7 @@ import type { Model } from '@renderer/types'
import { getLowerBaseModelName, isUserSelectedModelType } from '@renderer/utils' import { getLowerBaseModelName, isUserSelectedModelType } from '@renderer/utils'
import { isEmbeddingModel, isRerankModel } from './embedding' import { isEmbeddingModel, isRerankModel } from './embedding'
import { isFunctionCallingModel } from './tooluse'
// Vision models // Vision models
const visionAllowedModels = [ const visionAllowedModels = [
@ -72,12 +73,10 @@ const VISION_REGEX = new RegExp(
// For middleware to identify models that must use the dedicated Image API // For middleware to identify models that must use the dedicated Image API
const DEDICATED_IMAGE_MODELS = [ const DEDICATED_IMAGE_MODELS = [
'grok-2-image', 'grok-2-image(?:-[\\w-]+)?',
'grok-2-image-1212', 'dall-e(?:-[\\w-]+)?',
'grok-2-image-latest', 'gpt-image-1(?:-[\\w-]+)?',
'dall-e-3', 'imagen(?:-[\\w-]+)?'
'dall-e-2',
'gpt-image-1'
] ]
const IMAGE_ENHANCEMENT_MODELS = [ const IMAGE_ENHANCEMENT_MODELS = [
@ -85,13 +84,22 @@ const IMAGE_ENHANCEMENT_MODELS = [
'qwen-image-edit', 'qwen-image-edit',
'gpt-image-1', 'gpt-image-1',
'gemini-2.5-flash-image(?:-[\\w-]+)?', 'gemini-2.5-flash-image(?:-[\\w-]+)?',
'gemini-2.0-flash-preview-image-generation' 'gemini-2.0-flash-preview-image-generation',
'gemini-3(?:\\.\\d+)?-pro-image(?:-[\\w-]+)?'
] ]
const IMAGE_ENHANCEMENT_MODELS_REGEX = new RegExp(IMAGE_ENHANCEMENT_MODELS.join('|'), 'i') const IMAGE_ENHANCEMENT_MODELS_REGEX = new RegExp(IMAGE_ENHANCEMENT_MODELS.join('|'), 'i')
const DEDICATED_IMAGE_MODELS_REGEX = new RegExp(DEDICATED_IMAGE_MODELS.join('|'), 'i')
// Models that should auto-enable image generation button when selected // Models that should auto-enable image generation button when selected
const AUTO_ENABLE_IMAGE_MODELS = ['gemini-2.5-flash-image', ...DEDICATED_IMAGE_MODELS] const AUTO_ENABLE_IMAGE_MODELS = [
'gemini-2.5-flash-image(?:-[\\w-]+)?',
'gemini-3(?:\\.\\d+)?-pro-image(?:-[\\w-]+)?',
...DEDICATED_IMAGE_MODELS
]
const AUTO_ENABLE_IMAGE_MODELS_REGEX = new RegExp(AUTO_ENABLE_IMAGE_MODELS.join('|'), 'i')
const OPENAI_TOOL_USE_IMAGE_GENERATION_MODELS = [ const OPENAI_TOOL_USE_IMAGE_GENERATION_MODELS = [
'o3', 'o3',
@ -105,26 +113,34 @@ const OPENAI_TOOL_USE_IMAGE_GENERATION_MODELS = [
const OPENAI_IMAGE_GENERATION_MODELS = [...OPENAI_TOOL_USE_IMAGE_GENERATION_MODELS, 'gpt-image-1'] const OPENAI_IMAGE_GENERATION_MODELS = [...OPENAI_TOOL_USE_IMAGE_GENERATION_MODELS, 'gpt-image-1']
const MODERN_IMAGE_MODELS = ['gemini-3(?:\\.\\d+)?-pro-image(?:-[\\w-]+)?']
const GENERATE_IMAGE_MODELS = [ const GENERATE_IMAGE_MODELS = [
'gemini-2.0-flash-exp', 'gemini-2.0-flash-exp(?:-[\\w-]+)?',
'gemini-2.0-flash-exp-image-generation', 'gemini-2.5-flash-image(?:-[\\w-]+)?',
'gemini-2.0-flash-preview-image-generation', 'gemini-2.0-flash-preview-image-generation',
'gemini-2.5-flash-image', ...MODERN_IMAGE_MODELS,
...DEDICATED_IMAGE_MODELS ...DEDICATED_IMAGE_MODELS
] ]
const OPENAI_IMAGE_GENERATION_MODELS_REGEX = new RegExp(OPENAI_IMAGE_GENERATION_MODELS.join('|'), 'i')
const GENERATE_IMAGE_MODELS_REGEX = new RegExp(GENERATE_IMAGE_MODELS.join('|'), 'i')
const MODERN_GENERATE_IMAGE_MODELS_REGEX = new RegExp(MODERN_IMAGE_MODELS.join('|'), 'i')
export const isDedicatedImageGenerationModel = (model: Model): boolean => { export const isDedicatedImageGenerationModel = (model: Model): boolean => {
if (!model) return false if (!model) return false
const modelId = getLowerBaseModelName(model.id) const modelId = getLowerBaseModelName(model.id)
return DEDICATED_IMAGE_MODELS.some((m) => modelId.includes(m)) return DEDICATED_IMAGE_MODELS_REGEX.test(modelId)
} }
export const isAutoEnableImageGenerationModel = (model: Model): boolean => { export const isAutoEnableImageGenerationModel = (model: Model): boolean => {
if (!model) return false if (!model) return false
const modelId = getLowerBaseModelName(model.id) const modelId = getLowerBaseModelName(model.id)
return AUTO_ENABLE_IMAGE_MODELS.some((m) => modelId.includes(m)) return AUTO_ENABLE_IMAGE_MODELS_REGEX.test(modelId)
} }
/** /**
@ -146,48 +162,44 @@ export function isGenerateImageModel(model: Model): boolean {
const modelId = getLowerBaseModelName(model.id, '/') const modelId = getLowerBaseModelName(model.id, '/')
if (provider.type === 'openai-response') { if (provider.type === 'openai-response') {
return ( return OPENAI_IMAGE_GENERATION_MODELS_REGEX.test(modelId) || GENERATE_IMAGE_MODELS_REGEX.test(modelId)
OPENAI_IMAGE_GENERATION_MODELS.some((imageModel) => modelId.includes(imageModel)) ||
GENERATE_IMAGE_MODELS.some((imageModel) => modelId.includes(imageModel))
)
} }
return GENERATE_IMAGE_MODELS.some((imageModel) => modelId.includes(imageModel)) return GENERATE_IMAGE_MODELS_REGEX.test(modelId)
} }
// TODO: refine the regex
/** /**
* *
* @param model * @param model
* @returns * @returns
*/ */
export function isPureGenerateImageModel(model: Model): boolean { export function isPureGenerateImageModel(model: Model): boolean {
if (!isGenerateImageModel(model) || !isTextToImageModel(model)) { if (!isGenerateImageModel(model) && !isTextToImageModel(model)) {
return false
}
if (isFunctionCallingModel(model)) {
return false return false
} }
const modelId = getLowerBaseModelName(model.id) const modelId = getLowerBaseModelName(model.id)
return !OPENAI_TOOL_USE_IMAGE_GENERATION_MODELS.some((imageModel) => modelId.includes(imageModel)) if (GENERATE_IMAGE_MODELS_REGEX.test(modelId) && !MODERN_GENERATE_IMAGE_MODELS_REGEX.test(modelId)) {
return true
}
return !OPENAI_TOOL_USE_IMAGE_GENERATION_MODELS.some((m) => modelId.includes(m))
} }
// TODO: refine the regex
// Text to image models // Text to image models
const TEXT_TO_IMAGE_REGEX = /flux|diffusion|stabilityai|sd-|dall|cogview|janus|midjourney|mj-|image|gpt-image/i const TEXT_TO_IMAGE_REGEX = /flux|diffusion|stabilityai|sd-|dall|cogview|janus|midjourney|mj-|imagen|gpt-image/i
export function isTextToImageModel(model: Model): boolean { export function isTextToImageModel(model: Model): boolean {
const modelId = getLowerBaseModelName(model.id) const modelId = getLowerBaseModelName(model.id)
return TEXT_TO_IMAGE_REGEX.test(modelId) return TEXT_TO_IMAGE_REGEX.test(modelId)
} }
// It's not used now
// export function isNotSupportedImageSizeModel(model?: Model): boolean {
// if (!model) {
// return false
// }
// const baseName = getLowerBaseModelName(model.id, '/')
// return baseName.includes('grok-2-image')
// }
/** /**
* *
* @param model * @param model

View File

@ -2,27 +2,29 @@ import { getProviderByModel } from '@renderer/services/AssistantService'
import type { Model } from '@renderer/types' import type { Model } from '@renderer/types'
import { SystemProviderIds } from '@renderer/types' import { SystemProviderIds } from '@renderer/types'
import { getLowerBaseModelName, isUserSelectedModelType } from '@renderer/utils' import { getLowerBaseModelName, isUserSelectedModelType } from '@renderer/utils'
import { import {
isAzureOpenAIProvider,
isGeminiProvider, isGeminiProvider,
isNewApiProvider, isNewApiProvider,
isOpenAICompatibleProvider, isOpenAICompatibleProvider,
isOpenAIProvider, isOpenAIProvider,
isVertexAiProvider isVertexProvider
} from '../providers' } from '@renderer/utils/provider'
import { isEmbeddingModel, isRerankModel } from './embedding'
import { isAnthropicModel } from './utils'
import { isPureGenerateImageModel, isTextToImageModel } from './vision'
export const CLAUDE_SUPPORTED_WEBSEARCH_REGEX = new RegExp( export { GEMINI_FLASH_MODEL_REGEX } from './utils'
import { isEmbeddingModel, isRerankModel } from './embedding'
import { isClaude4SeriesModel } from './reasoning'
import { isAnthropicModel } from './utils'
import { isTextToImageModel } from './vision'
const CLAUDE_SUPPORTED_WEBSEARCH_REGEX = new RegExp(
`\\b(?:claude-3(-|\\.)(7|5)-sonnet(?:-[\\w-]+)|claude-3(-|\\.)5-haiku(?:-[\\w-]+)|claude-(haiku|sonnet|opus)-4(?:-[\\w-]+)?)\\b`, `\\b(?:claude-3(-|\\.)(7|5)-sonnet(?:-[\\w-]+)|claude-3(-|\\.)5-haiku(?:-[\\w-]+)|claude-(haiku|sonnet|opus)-4(?:-[\\w-]+)?)\\b`,
'i' 'i'
) )
export const GEMINI_FLASH_MODEL_REGEX = new RegExp('gemini.*-flash.*$')
export const GEMINI_SEARCH_REGEX = new RegExp( export const GEMINI_SEARCH_REGEX = new RegExp(
'gemini-(?:2.*(?:-latest)?|3-(?:flash|pro)(?:-preview)?|flash-latest|pro-latest|flash-lite-latest)(?:-[\\w-]+)*$', 'gemini-(?:2(?!.*-image-preview).*(?:-latest)?|3(?:\\.\\d+)?-(?:flash|pro)(?:-(?:image-)?preview)?|flash-latest|pro-latest|flash-lite-latest)(?:-[\\w-]+)*$',
'i' 'i'
) )
@ -34,30 +36,8 @@ export const PERPLEXITY_SEARCH_MODELS = [
'sonar-deep-research' 'sonar-deep-research'
] ]
const OPENAI_DEEP_RESEARCH_MODEL_REGEX = /deep[-_]?research/
export function isOpenAIDeepResearchModel(model?: Model): boolean {
if (!model) {
return false
}
const providerId = model.provider
if (providerId !== 'openai' && providerId !== 'openai-chat') {
return false
}
const modelId = getLowerBaseModelName(model.id, '/')
return OPENAI_DEEP_RESEARCH_MODEL_REGEX.test(modelId)
}
export function isWebSearchModel(model: Model): boolean { export function isWebSearchModel(model: Model): boolean {
if ( if (!model || isEmbeddingModel(model) || isRerankModel(model) || isTextToImageModel(model)) {
!model ||
isEmbeddingModel(model) ||
isRerankModel(model) ||
isTextToImageModel(model) ||
isPureGenerateImageModel(model)
) {
return false return false
} }
@ -73,16 +53,17 @@ export function isWebSearchModel(model: Model): boolean {
const modelId = getLowerBaseModelName(model.id, '/') const modelId = getLowerBaseModelName(model.id, '/')
// bedrock和vertex不支持 // bedrock不支持, azure支持
if ( if (isAnthropicModel(model) && !(provider.id === SystemProviderIds['aws-bedrock'])) {
isAnthropicModel(model) && if (isVertexProvider(provider)) {
!(provider.id === SystemProviderIds['aws-bedrock'] || provider.id === SystemProviderIds.vertexai) return isClaude4SeriesModel(model)
) { }
return CLAUDE_SUPPORTED_WEBSEARCH_REGEX.test(modelId) return CLAUDE_SUPPORTED_WEBSEARCH_REGEX.test(modelId)
} }
// TODO: 当其他供应商采用Response端点时这个地方逻辑需要改进 // TODO: 当其他供应商采用Response端点时这个地方逻辑需要改进
if (isOpenAIProvider(provider)) { // azure现在也支持了websearch
if (isOpenAIProvider(provider) || isAzureOpenAIProvider(provider)) {
if (isOpenAIWebSearchModel(model)) { if (isOpenAIWebSearchModel(model)) {
return true return true
} }
@ -113,7 +94,7 @@ export function isWebSearchModel(model: Model): boolean {
} }
} }
if (isGeminiProvider(provider) || isVertexAiProvider(provider)) { if (isGeminiProvider(provider) || isVertexProvider(provider)) {
return GEMINI_SEARCH_REGEX.test(modelId) return GEMINI_SEARCH_REGEX.test(modelId)
} }

View File

@ -59,15 +59,8 @@ import VoyageAIProviderLogo from '@renderer/assets/images/providers/voyageai.png
import XirangProviderLogo from '@renderer/assets/images/providers/xirang.png' import XirangProviderLogo from '@renderer/assets/images/providers/xirang.png'
import ZeroOneProviderLogo from '@renderer/assets/images/providers/zero-one.png' import ZeroOneProviderLogo from '@renderer/assets/images/providers/zero-one.png'
import ZhipuProviderLogo from '@renderer/assets/images/providers/zhipu.png' import ZhipuProviderLogo from '@renderer/assets/images/providers/zhipu.png'
import type { import type { AtLeast, SystemProvider, SystemProviderId } from '@renderer/types'
AtLeast, import { OpenAIServiceTiers } from '@renderer/types'
AzureOpenAIProvider,
Provider,
ProviderType,
SystemProvider,
SystemProviderId
} from '@renderer/types'
import { isSystemProvider, OpenAIServiceTiers, SystemProviderIds } from '@renderer/types'
import { TOKENFLUX_HOST } from './constant' import { TOKENFLUX_HOST } from './constant'
import { glm45FlashModel, qwen38bModel, SYSTEM_MODELS } from './models' import { glm45FlashModel, qwen38bModel, SYSTEM_MODELS } from './models'
@ -1441,153 +1434,3 @@ export const PROVIDER_URLS: Record<SystemProviderId, ProviderUrls> = {
} }
} }
} }
const NOT_SUPPORT_ARRAY_CONTENT_PROVIDERS = [
'deepseek',
'baichuan',
'minimax',
'xirang',
'poe',
'cephalon'
] as const satisfies SystemProviderId[]
/**
* message content Only for OpenAI Chat Completions API.
*/
export const isSupportArrayContentProvider = (provider: Provider) => {
return (
provider.apiOptions?.isNotSupportArrayContent !== true &&
!NOT_SUPPORT_ARRAY_CONTENT_PROVIDERS.some((pid) => pid === provider.id)
)
}
const NOT_SUPPORT_DEVELOPER_ROLE_PROVIDERS = ['poe', 'qiniu'] as const satisfies SystemProviderId[]
/**
* developer message role Only for OpenAI API.
*/
export const isSupportDeveloperRoleProvider = (provider: Provider) => {
return (
provider.apiOptions?.isSupportDeveloperRole === true ||
(isSystemProvider(provider) && !NOT_SUPPORT_DEVELOPER_ROLE_PROVIDERS.some((pid) => pid === provider.id))
)
}
const NOT_SUPPORT_STREAM_OPTIONS_PROVIDERS = ['mistral'] as const satisfies SystemProviderId[]
/**
* stream_options Only for OpenAI API.
*/
export const isSupportStreamOptionsProvider = (provider: Provider) => {
return (
provider.apiOptions?.isNotSupportStreamOptions !== true &&
!NOT_SUPPORT_STREAM_OPTIONS_PROVIDERS.some((pid) => pid === provider.id)
)
}
const NOT_SUPPORT_QWEN3_ENABLE_THINKING_PROVIDER = [
'ollama',
'lmstudio',
'nvidia'
] as const satisfies SystemProviderId[]
/**
* 使 enable_thinking Qwen3 Only for OpenAI Chat Completions API.
*/
export const isSupportEnableThinkingProvider = (provider: Provider) => {
return (
provider.apiOptions?.isNotSupportEnableThinking !== true &&
!NOT_SUPPORT_QWEN3_ENABLE_THINKING_PROVIDER.some((pid) => pid === provider.id)
)
}
const NOT_SUPPORT_SERVICE_TIER_PROVIDERS = ['github', 'copilot', 'cerebras'] as const satisfies SystemProviderId[]
/**
* service_tier Only for OpenAI API.
*/
export const isSupportServiceTierProvider = (provider: Provider) => {
return (
provider.apiOptions?.isSupportServiceTier === true ||
(isSystemProvider(provider) && !NOT_SUPPORT_SERVICE_TIER_PROVIDERS.some((pid) => pid === provider.id))
)
}
const SUPPORT_URL_CONTEXT_PROVIDER_TYPES = [
'gemini',
'vertexai',
'anthropic',
'new-api'
] as const satisfies ProviderType[]
export const isSupportUrlContextProvider = (provider: Provider) => {
return (
SUPPORT_URL_CONTEXT_PROVIDER_TYPES.some((type) => type === provider.type) ||
provider.id === SystemProviderIds.cherryin
)
}
const SUPPORT_GEMINI_NATIVE_WEB_SEARCH_PROVIDERS = ['gemini', 'vertexai'] as const satisfies SystemProviderId[]
/** 判断是否是使用 Gemini 原生搜索工具的 provider. 目前假设只有官方 API 使用原生工具 */
export const isGeminiWebSearchProvider = (provider: Provider) => {
return SUPPORT_GEMINI_NATIVE_WEB_SEARCH_PROVIDERS.some((id) => id === provider.id)
}
export const isNewApiProvider = (provider: Provider) => {
return ['new-api', 'cherryin'].includes(provider.id) || provider.type === 'new-api'
}
export function isCherryAIProvider(provider: Provider): boolean {
return provider.id === 'cherryai'
}
export function isPerplexityProvider(provider: Provider): boolean {
return provider.id === 'perplexity'
}
/**
* OpenAI
* @param {Provider} provider
* @returns {boolean} OpenAI
*/
export function isOpenAICompatibleProvider(provider: Provider): boolean {
return ['openai', 'new-api', 'mistral'].includes(provider.type)
}
export function isAzureOpenAIProvider(provider: Provider): provider is AzureOpenAIProvider {
return provider.type === 'azure-openai'
}
export function isOpenAIProvider(provider: Provider): boolean {
return provider.type === 'openai-response'
}
export function isAnthropicProvider(provider: Provider): boolean {
return provider.type === 'anthropic'
}
export function isGeminiProvider(provider: Provider): boolean {
return provider.type === 'gemini'
}
export function isVertexAiProvider(provider: Provider): boolean {
return provider.type === 'vertexai'
}
export function isAIGatewayProvider(provider: Provider): boolean {
return provider.type === 'ai-gateway'
}
export function isAwsBedrockProvider(provider: Provider): boolean {
return provider.type === 'aws-bedrock'
}
const NOT_SUPPORT_API_VERSION_PROVIDERS = ['github', 'copilot', 'perplexity'] as const satisfies SystemProviderId[]
export const isSupportAPIVersionProvider = (provider: Provider) => {
if (isSystemProvider(provider)) {
return !NOT_SUPPORT_API_VERSION_PROVIDERS.some((pid) => pid === provider.id)
}
return provider.apiOptions?.isNotSupportAPIVersion !== false
}

View File

@ -1,55 +0,0 @@
import type { ChatCompletionTool } from '@cherrystudio/openai/resources'
import type { Model } from '@renderer/types'
import { WEB_SEARCH_PROMPT_FOR_ZHIPU } from '@shared/config/prompts'
export function getWebSearchTools(model: Model): ChatCompletionTool[] {
if (model?.provider === 'zhipu') {
if (model.id === 'glm-4-alltools') {
return [
{
type: 'web_browser',
web_browser: {
browser: 'auto'
}
} as unknown as ChatCompletionTool
]
}
return [
{
type: 'web_search',
web_search: {
enable: true,
search_result: true,
search_prompt: WEB_SEARCH_PROMPT_FOR_ZHIPU
}
} as unknown as ChatCompletionTool
]
}
if (model?.id.includes('gemini')) {
return [
{
type: 'function',
function: {
name: 'googleSearch'
}
}
]
}
return []
}
export function getUrlContextTools(model: Model): ChatCompletionTool[] {
if (model.id.includes('gemini')) {
return [
{
type: 'function',
function: {
name: 'urlContext'
}
}
]
}
return []
}

View File

@ -175,14 +175,46 @@ export function useAppInit() {
useEffect(() => { useEffect(() => {
if (!window.electron?.ipcRenderer) return if (!window.electron?.ipcRenderer) return
const requestListener = (_event: Electron.IpcRendererEvent, payload: ToolPermissionRequestPayload) => { const requestListener = async (_event: Electron.IpcRendererEvent, payload: ToolPermissionRequestPayload) => {
logger.debug('Renderer received tool permission request', { logger.debug('Renderer received tool permission request', {
requestId: payload.requestId, requestId: payload.requestId,
toolName: payload.toolName, toolName: payload.toolName,
expiresAt: payload.expiresAt, expiresAt: payload.expiresAt,
suggestionCount: payload.suggestions.length suggestionCount: payload.suggestions.length,
autoApprove: payload.autoApprove
}) })
dispatch(toolPermissionsActions.requestReceived(payload)) dispatch(toolPermissionsActions.requestReceived(payload))
// Auto-approve if requested
if (payload.autoApprove) {
logger.debug('Auto-approving tool permission request', {
requestId: payload.requestId,
toolName: payload.toolName
})
dispatch(toolPermissionsActions.submissionSent({ requestId: payload.requestId, behavior: 'allow' }))
try {
const response = await window.api.agentTools.respondToPermission({
requestId: payload.requestId,
behavior: 'allow',
updatedInput: payload.input,
updatedPermissions: payload.suggestions
})
if (!response?.success) {
throw new Error('Auto-approval response rejected by main process')
}
logger.debug('Auto-approval acknowledged by main process', {
requestId: payload.requestId,
toolName: payload.toolName
})
} catch (error) {
logger.error('Failed to send auto-approval response', error as Error)
dispatch(toolPermissionsActions.submissionFailed({ requestId: payload.requestId }))
}
}
} }
const resultListener = (_event: Electron.IpcRendererEvent, payload: ToolPermissionResultPayload) => { const resultListener = (_event: Electron.IpcRendererEvent, payload: ToolPermissionResultPayload) => {

View File

@ -38,13 +38,6 @@ export function getVertexAIServiceAccount() {
return store.getState().llm.settings.vertexai.serviceAccount return store.getState().llm.settings.vertexai.serviceAccount
} }
/**
* Provider VertexProvider
*/
export function isVertexProvider(provider: Provider): provider is VertexProvider {
return provider.type === 'vertexai'
}
/** /**
* VertexProvider * VertexProvider
* @param baseProvider provider * @param baseProvider provider

View File

@ -266,6 +266,7 @@
"error": { "error": {
"sendFailed": "Failed to send your decision. Please try again." "sendFailed": "Failed to send your decision. Please try again."
}, },
"executing": "Executing...",
"expired": "Expired", "expired": "Expired",
"inputPreview": "Tool input preview", "inputPreview": "Tool input preview",
"pending": "Pending ({{seconds}}s)", "pending": "Pending ({{seconds}}s)",
@ -1158,6 +1159,7 @@
"name": "Name", "name": "Name",
"no_results": "No results", "no_results": "No results",
"none": "None", "none": "None",
"off": "Off",
"open": "Open", "open": "Open",
"paste": "Paste", "paste": "Paste",
"placeholders": { "placeholders": {
@ -4259,7 +4261,6 @@
"default": "default", "default": "default",
"flex": "flex", "flex": "flex",
"on_demand": "on demand", "on_demand": "on demand",
"performance": "performance",
"priority": "priority", "priority": "priority",
"tip": "Specifies the latency tier to use for processing the request", "tip": "Specifies the latency tier to use for processing the request",
"title": "Service Tier" "title": "Service Tier"
@ -4278,7 +4279,7 @@
"low": "Low", "low": "Low",
"medium": "Medium", "medium": "Medium",
"tip": "Control the level of detail in the model's output", "tip": "Control the level of detail in the model's output",
"title": "Level of detail" "title": "Verbosity"
} }
}, },
"privacy": { "privacy": {

View File

@ -266,6 +266,7 @@
"error": { "error": {
"sendFailed": "发送您的决定失败,请重试。" "sendFailed": "发送您的决定失败,请重试。"
}, },
"executing": "正在执行...",
"expired": "已过期", "expired": "已过期",
"inputPreview": "工具输入预览", "inputPreview": "工具输入预览",
"pending": "等待中 ({{seconds}}秒)", "pending": "等待中 ({{seconds}}秒)",
@ -1158,6 +1159,7 @@
"name": "名称", "name": "名称",
"no_results": "无结果", "no_results": "无结果",
"none": "无", "none": "无",
"off": "关闭",
"open": "打开", "open": "打开",
"paste": "粘贴", "paste": "粘贴",
"placeholders": { "placeholders": {
@ -4259,7 +4261,6 @@
"default": "默认", "default": "默认",
"flex": "灵活", "flex": "灵活",
"on_demand": "按需", "on_demand": "按需",
"performance": "性能",
"priority": "优先", "priority": "优先",
"tip": "指定用于处理请求的延迟层级", "tip": "指定用于处理请求的延迟层级",
"title": "服务层级" "title": "服务层级"

View File

@ -37,7 +37,7 @@
"success": "成功偵測到 Git Bash" "success": "成功偵測到 Git Bash"
}, },
"input": { "input": {
"placeholder": "[to be translated]:Enter your message here, send with {{key}} - @ select path, / select command" "placeholder": "在這裡輸入您的訊息,使用 {{key}} 傳送 - @ 選擇路徑,/ 選擇命令"
}, },
"list": { "list": {
"error": { "error": {
@ -266,6 +266,7 @@
"error": { "error": {
"sendFailed": "傳送您的決定失敗,請重試。" "sendFailed": "傳送您的決定失敗,請重試。"
}, },
"executing": "執行中...",
"expired": "已過期", "expired": "已過期",
"inputPreview": "工具輸入預覽", "inputPreview": "工具輸入預覽",
"pending": "等待中 ({{seconds}}秒)", "pending": "等待中 ({{seconds}}秒)",
@ -1158,6 +1159,7 @@
"name": "名稱", "name": "名稱",
"no_results": "沒有結果", "no_results": "沒有結果",
"none": "無", "none": "無",
"off": "關閉",
"open": "開啟", "open": "開啟",
"paste": "貼上", "paste": "貼上",
"placeholders": { "placeholders": {
@ -4259,7 +4261,6 @@
"default": "預設", "default": "預設",
"flex": "彈性", "flex": "彈性",
"on_demand": "按需", "on_demand": "按需",
"performance": "效能",
"priority": "優先", "priority": "優先",
"tip": "指定用於處理請求的延遲層級", "tip": "指定用於處理請求的延遲層級",
"title": "服務層級" "title": "服務層級"

View File

@ -29,15 +29,15 @@
}, },
"gitBash": { "gitBash": {
"error": { "error": {
"description": "[to be translated]:Git Bash is required to run agents on Windows. The agent cannot function without it. Please install Git for Windows from", "description": "Git Bash ist erforderlich, um Agents unter Windows auszuführen. Der Agent kann ohne es nicht funktionieren. Bitte installieren Sie Git für Windows von",
"recheck": "[to be translated]:Recheck Git Bash Installation", "recheck": "Überprüfe die Git Bash-Installation erneut",
"title": "[to be translated]:Git Bash Required" "title": "Git Bash erforderlich"
}, },
"notFound": "[to be translated]:Git Bash not found. Please install it first.", "notFound": "Git Bash nicht gefunden. Bitte installieren Sie es zuerst.",
"success": "[to be translated]:Git Bash detected successfully!" "success": "Git Bash erfolgreich erkannt!"
}, },
"input": { "input": {
"placeholder": "[to be translated]:Enter your message here, send with {{key}} - @ select path, / select command" "placeholder": "Gib hier deine Nachricht ein, senden mit {{key}} @ Pfad auswählen, / Befehl auswählen"
}, },
"list": { "list": {
"error": { "error": {
@ -266,6 +266,7 @@
"error": { "error": {
"sendFailed": "Ihre Entscheidung konnte nicht gesendet werden. Bitte versuchen Sie es erneut." "sendFailed": "Ihre Entscheidung konnte nicht gesendet werden. Bitte versuchen Sie es erneut."
}, },
"executing": "Ausführen...",
"expired": "Abgelaufen", "expired": "Abgelaufen",
"inputPreview": "Vorschau der Werkzeugeingabe", "inputPreview": "Vorschau der Werkzeugeingabe",
"pending": "Ausstehend ({{seconds}}s)", "pending": "Ausstehend ({{seconds}}s)",
@ -1158,6 +1159,7 @@
"name": "Name", "name": "Name",
"no_results": "Keine Ergebnisse", "no_results": "Keine Ergebnisse",
"none": "Keine", "none": "Keine",
"off": "Aus",
"open": "Öffnen", "open": "Öffnen",
"paste": "Einfügen", "paste": "Einfügen",
"placeholders": { "placeholders": {
@ -1385,6 +1387,36 @@
"preview": "Vorschau", "preview": "Vorschau",
"split": "Geteilte Ansicht" "split": "Geteilte Ansicht"
}, },
"import": {
"chatgpt": {
"assistant_name": "ChatGPT-Import",
"button": "Datei auswählen",
"description": "Importiert nur Gesprächstexte, keine Bilder und Anhänge",
"error": {
"invalid_json": "Ungültiges JSON-Dateiformat",
"no_conversations": "Keine Gespräche in der Datei gefunden",
"no_valid_conversations": "Keine gültigen Konversationen zum Importieren",
"unknown": "Import fehlgeschlagen, bitte überprüfen Sie das Dateiformat"
},
"help": {
"step1": "1. Melden Sie sich bei ChatGPT an, gehen Sie zu Einstellungen > Dateneinstellungen > Daten exportieren",
"step2": "2. Warten Sie auf die Exportdatei per E-Mail",
"step3": "3. Extrahiere die heruntergeladene Datei und finde conversations.json",
"title": "Wie exportiere ich ChatGPT-Gespräche?"
},
"importing": "Konversationen werden importiert...",
"selecting": "Datei wird ausgewählt...",
"success": "Erfolgreich {{topics}} Konversationen mit {{messages}} Nachrichten importiert",
"title": "ChatGPT-Unterhaltungen importieren",
"untitled_conversation": "Unbenannte Unterhaltung"
},
"confirm": {
"button": "Importdatei auswählen",
"label": "Sind Sie sicher, dass Sie externe Daten importieren möchten?"
},
"content": "Wählen Sie die zu importierende Gesprächsdatei einer externen Anwendung aus; derzeit werden nur ChatGPT-JSON-Formatdateien unterstützt.",
"title": "Externe Gespräche importieren"
},
"knowledge": { "knowledge": {
"add": { "add": {
"title": "Wissensdatenbank hinzufügen" "title": "Wissensdatenbank hinzufügen"
@ -3094,6 +3126,7 @@
"basic": "Grundlegende Dateneinstellungen", "basic": "Grundlegende Dateneinstellungen",
"cloud_storage": "Cloud-Backup-Einstellungen", "cloud_storage": "Cloud-Backup-Einstellungen",
"export_settings": "Export-Einstellungen", "export_settings": "Export-Einstellungen",
"import_settings": "Importeinstellungen",
"third_party": "Drittanbieter-Verbindungen" "third_party": "Drittanbieter-Verbindungen"
}, },
"export_menu": { "export_menu": {
@ -3152,6 +3185,11 @@
}, },
"hour_interval_one": "{{count}} Stunde", "hour_interval_one": "{{count}} Stunde",
"hour_interval_other": "{{count}} Stunden", "hour_interval_other": "{{count}} Stunden",
"import_settings": {
"button": "JSON-Datei importieren",
"chatgpt": "Import aus ChatGPT",
"title": "Importiere Daten von externen Anwendungen"
},
"joplin": { "joplin": {
"check": { "check": {
"button": "Erkennen", "button": "Erkennen",
@ -4223,7 +4261,6 @@
"default": "Standard", "default": "Standard",
"flex": "Flexibel", "flex": "Flexibel",
"on_demand": "Auf Anfrage", "on_demand": "Auf Anfrage",
"performance": "Leistung",
"priority": "Priorität", "priority": "Priorität",
"tip": "Latenz-Ebene für Anfrageverarbeitung festlegen", "tip": "Latenz-Ebene für Anfrageverarbeitung festlegen",
"title": "Service-Tier" "title": "Service-Tier"

View File

@ -29,15 +29,15 @@
}, },
"gitBash": { "gitBash": {
"error": { "error": {
"description": "[to be translated]:Git Bash is required to run agents on Windows. The agent cannot function without it. Please install Git for Windows from", "description": "Το Git Bash απαιτείται για την εκτέλεση πρακτόρων στα Windows. Ο πράκτορας δεν μπορεί να λειτουργήσει χωρίς αυτό. Παρακαλούμε εγκαταστήστε το Git για Windows από",
"recheck": "[to be translated]:Recheck Git Bash Installation", "recheck": "Επανέλεγχος Εγκατάστασης του Git Bash",
"title": "[to be translated]:Git Bash Required" "title": "Απαιτείται Git Bash"
}, },
"notFound": "[to be translated]:Git Bash not found. Please install it first.", "notFound": "Το Git Bash δεν βρέθηκε. Παρακαλώ εγκαταστήστε το πρώτα.",
"success": "[to be translated]:Git Bash detected successfully!" "success": "Το Git Bash εντοπίστηκε με επιτυχία!"
}, },
"input": { "input": {
"placeholder": "[to be translated]:Enter your message here, send with {{key}} - @ select path, / select command" "placeholder": "Εισάγετε το μήνυμά σας εδώ, στείλτε με {{key}} - @ επιλέξτε διαδρομή, / επιλέξτε εντολή"
}, },
"list": { "list": {
"error": { "error": {
@ -266,6 +266,7 @@
"error": { "error": {
"sendFailed": "Αποτυχία αποστολής της απόφασής σας. Προσπαθήστε ξανά." "sendFailed": "Αποτυχία αποστολής της απόφασής σας. Προσπαθήστε ξανά."
}, },
"executing": "Εκτέλεση...",
"expired": "Ληγμένο", "expired": "Ληγμένο",
"inputPreview": "Προεπισκόπηση εισόδου εργαλείου", "inputPreview": "Προεπισκόπηση εισόδου εργαλείου",
"pending": "Εκκρεμεί ({{seconds}}δ)", "pending": "Εκκρεμεί ({{seconds}}δ)",
@ -1158,6 +1159,7 @@
"name": "Όνομα", "name": "Όνομα",
"no_results": "Δεν βρέθηκαν αποτελέσματα", "no_results": "Δεν βρέθηκαν αποτελέσματα",
"none": "Χωρίς", "none": "Χωρίς",
"off": "Κλειστό",
"open": "Άνοιγμα", "open": "Άνοιγμα",
"paste": "Επικόλληση", "paste": "Επικόλληση",
"placeholders": { "placeholders": {
@ -1385,6 +1387,36 @@
"preview": "Προεπισκόπηση", "preview": "Προεπισκόπηση",
"split": "Διαχωρισμός" "split": "Διαχωρισμός"
}, },
"import": {
"chatgpt": {
"assistant_name": "Εισαγωγή ChatGPT",
"button": "Επιλέξτε Αρχείο",
"description": "Εισάγει μόνο κείμενο συνομιλίας, δεν περιλαμβάνει εικόνες και συνημμένα",
"error": {
"invalid_json": "Μη έγκυρη μορφή αρχείου JSON",
"no_conversations": "Δεν βρέθηκαν συνομιλίες στο αρχείο",
"no_valid_conversations": "Δεν υπάρχουν έγκυρες συνομιλίες προς εισαγωγή",
"unknown": "Η εισαγωγή απέτυχε, παρακαλώ ελέγξτε τη μορφή του αρχείου"
},
"help": {
"step1": "1. Συνδεθείτε στο ChatGPT, πηγαίνετε στις Ρυθμίσεις > Έλεγχοι δεδομένων > Εξαγωγή δεδομένων",
"step2": "2. Περιμένετε το αρχείο εξαγωγής μέσω email",
"step3": "3. Εξαγάγετε το ληφθέν αρχείο και βρείτε το conversations.json",
"title": "Πώς να εξάγετε συνομιλίες του ChatGPT;"
},
"importing": "Εισαγωγή συνομιλιών...",
"selecting": "Επιλογή αρχείου...",
"success": "Επιτυχής εισαγωγή {{topics}} συνομιλιών με {{messages}} μηνύματα",
"title": "Εισαγωγή Συνομιλιών του ChatGPT",
"untitled_conversation": "Συνομιλία χωρίς τίτλο"
},
"confirm": {
"button": "Επιλέξτε Εισαγωγή Αρχείου",
"label": "Είστε σίγουροι ότι θέλετε να εισάγετε εξωτερικά δεδομένα;"
},
"content": "Επιλέξτε εξωτερικό αρχείο συνομιλίας εφαρμογής για εισαγωγή, προς το παρόν υποστηρίζονται μόνο αρχεία μορφής JSON του ChatGPT",
"title": "Εισαγωγή Εξωτερικών Συνομιλιών"
},
"knowledge": { "knowledge": {
"add": { "add": {
"title": "Προσθήκη βιβλιοθήκης γνώσεων" "title": "Προσθήκη βιβλιοθήκης γνώσεων"
@ -3094,6 +3126,7 @@
"basic": "Ρυθμίσεις βασικών δεδομένων", "basic": "Ρυθμίσεις βασικών δεδομένων",
"cloud_storage": "Ρυθμίσεις αποθήκευσης στο νέφος", "cloud_storage": "Ρυθμίσεις αποθήκευσης στο νέφος",
"export_settings": "Ρυθμίσεις εξαγωγής", "export_settings": "Ρυθμίσεις εξαγωγής",
"import_settings": "Εισαγωγή Ρυθμίσεων",
"third_party": "Σύνδεση τρίτων" "third_party": "Σύνδεση τρίτων"
}, },
"export_menu": { "export_menu": {
@ -3152,6 +3185,11 @@
}, },
"hour_interval_one": "{{count}} ώρα", "hour_interval_one": "{{count}} ώρα",
"hour_interval_other": "{{count}} ώρες", "hour_interval_other": "{{count}} ώρες",
"import_settings": {
"button": "Εισαγωγή αρχείου Json",
"chatgpt": "Εισαγωγή από το ChatGPT",
"title": "Εισαγωγή Δεδομένων Εξωτερικής Εφαρμογής"
},
"joplin": { "joplin": {
"check": { "check": {
"button": "Έλεγχος", "button": "Έλεγχος",
@ -4223,7 +4261,6 @@
"default": "Προεπιλογή", "default": "Προεπιλογή",
"flex": "Εύκαμπτο", "flex": "Εύκαμπτο",
"on_demand": "κατά παραγγελία", "on_demand": "κατά παραγγελία",
"performance": "Απόδοση",
"priority": "προτεραιότητα", "priority": "προτεραιότητα",
"tip": "Καθορίστε το επίπεδο καθυστέρησης που χρησιμοποιείται για την επεξεργασία των αιτημάτων", "tip": "Καθορίστε το επίπεδο καθυστέρησης που χρησιμοποιείται για την επεξεργασία των αιτημάτων",
"title": "Επίπεδο υπηρεσίας" "title": "Επίπεδο υπηρεσίας"

View File

@ -29,15 +29,15 @@
}, },
"gitBash": { "gitBash": {
"error": { "error": {
"description": "[to be translated]:Git Bash is required to run agents on Windows. The agent cannot function without it. Please install Git for Windows from", "description": "Se requiere Git Bash para ejecutar agentes en Windows. El agente no puede funcionar sin él. Instale Git para Windows desde",
"recheck": "[to be translated]:Recheck Git Bash Installation", "recheck": "Volver a verificar la instalación de Git Bash",
"title": "[to be translated]:Git Bash Required" "title": "Git Bash Requerido"
}, },
"notFound": "[to be translated]:Git Bash not found. Please install it first.", "notFound": "Git Bash no encontrado. Por favor, instálalo primero.",
"success": "[to be translated]:Git Bash detected successfully!" "success": "¡Git Bash detectado con éxito!"
}, },
"input": { "input": {
"placeholder": "[to be translated]:Enter your message here, send with {{key}} - @ select path, / select command" "placeholder": "Introduce tu mensaje aquí, envía con {{key}} - @ seleccionar ruta, / seleccionar comando"
}, },
"list": { "list": {
"error": { "error": {
@ -266,6 +266,7 @@
"error": { "error": {
"sendFailed": "No se pudo enviar tu decisión. Por favor, inténtalo de nuevo." "sendFailed": "No se pudo enviar tu decisión. Por favor, inténtalo de nuevo."
}, },
"executing": "Ejecutando...",
"expired": "Caducado", "expired": "Caducado",
"inputPreview": "Vista previa de entrada de herramienta", "inputPreview": "Vista previa de entrada de herramienta",
"pending": "Pendiente ({{seconds}}s)", "pending": "Pendiente ({{seconds}}s)",
@ -1158,6 +1159,7 @@
"name": "Nombre", "name": "Nombre",
"no_results": "Sin resultados", "no_results": "Sin resultados",
"none": "无", "none": "无",
"off": "Apagado",
"open": "Abrir", "open": "Abrir",
"paste": "Pegar", "paste": "Pegar",
"placeholders": { "placeholders": {
@ -1385,6 +1387,36 @@
"preview": "Vista previa", "preview": "Vista previa",
"split": "Dividir" "split": "Dividir"
}, },
"import": {
"chatgpt": {
"assistant_name": "Importación de ChatGPT",
"button": "Seleccionar archivo",
"description": "Solo importa el texto de la conversación, no incluye imágenes ni archivos adjuntos",
"error": {
"invalid_json": "Formato de archivo JSON inválido",
"no_conversations": "No se encontraron conversaciones en el archivo",
"no_valid_conversations": "No hay conversaciones válidas para importar",
"unknown": "Error de importación, por favor verifica el formato del archivo"
},
"help": {
"step1": "1. Inicia sesión en ChatGPT, ve a Configuración > Controles de datos > Exportar datos",
"step2": "2. Espera el archivo de exportación por correo electrónico",
"step3": "3. Extrae el archivo descargado y busca conversations.json",
"title": "Cómo exportar conversaciones de ChatGPT"
},
"importing": "Importando conversaciones...",
"selecting": "Seleccionando archivo...",
"success": "Importadas con éxito {{topics}} conversaciones con {{messages}} mensajes",
"title": "Importar conversaciones de ChatGPT",
"untitled_conversation": "Conversación Sin Título"
},
"confirm": {
"button": "Seleccionar Archivo de Importación",
"label": "¿Estás seguro de que quieres importar datos externos?"
},
"content": "Selecciona el archivo de conversación de la aplicación externa para importar; actualmente solo admite archivos en formato JSON de ChatGPT",
"title": "Importar Conversaciones Externas"
},
"knowledge": { "knowledge": {
"add": { "add": {
"title": "Agregar base de conocimientos" "title": "Agregar base de conocimientos"
@ -3094,6 +3126,7 @@
"basic": "Configuración básica", "basic": "Configuración básica",
"cloud_storage": "Configuración de almacenamiento en la nube", "cloud_storage": "Configuración de almacenamiento en la nube",
"export_settings": "Configuración de exportación", "export_settings": "Configuración de exportación",
"import_settings": "Importar configuración",
"third_party": "Conexiones de terceros" "third_party": "Conexiones de terceros"
}, },
"export_menu": { "export_menu": {
@ -3152,6 +3185,11 @@
}, },
"hour_interval_one": "{{count}} hora", "hour_interval_one": "{{count}} hora",
"hour_interval_other": "{{count}} horas", "hour_interval_other": "{{count}} horas",
"import_settings": {
"button": "Importar archivo Json",
"chatgpt": "Importar desde ChatGPT",
"title": "Importar datos de aplicaciones externas"
},
"joplin": { "joplin": {
"check": { "check": {
"button": "Revisar", "button": "Revisar",
@ -4223,7 +4261,6 @@
"default": "Predeterminado", "default": "Predeterminado",
"flex": "Flexible", "flex": "Flexible",
"on_demand": "según demanda", "on_demand": "según demanda",
"performance": "rendimiento",
"priority": "prioridad", "priority": "prioridad",
"tip": "Especifica el nivel de latencia utilizado para procesar la solicitud", "tip": "Especifica el nivel de latencia utilizado para procesar la solicitud",
"title": "Nivel de servicio" "title": "Nivel de servicio"

View File

@ -29,15 +29,15 @@
}, },
"gitBash": { "gitBash": {
"error": { "error": {
"description": "[to be translated]:Git Bash is required to run agents on Windows. The agent cannot function without it. Please install Git for Windows from", "description": "Git Bash est requis pour exécuter des agents sur Windows. L'agent ne peut pas fonctionner sans. Veuillez installer Git pour Windows depuis",
"recheck": "[to be translated]:Recheck Git Bash Installation", "recheck": "Revérifier l'installation de Git Bash",
"title": "[to be translated]:Git Bash Required" "title": "Git Bash requis"
}, },
"notFound": "[to be translated]:Git Bash not found. Please install it first.", "notFound": "Git Bash introuvable. Veuillez linstaller dabord.",
"success": "[to be translated]:Git Bash detected successfully!" "success": "Git Bash détecté avec succès !"
}, },
"input": { "input": {
"placeholder": "[to be translated]:Enter your message here, send with {{key}} - @ select path, / select command" "placeholder": "Entrez votre message ici, envoyez avec {{key}} - @ sélectionner le chemin, / sélectionner la commande"
}, },
"list": { "list": {
"error": { "error": {
@ -266,6 +266,7 @@
"error": { "error": {
"sendFailed": "Échec de l'envoi de votre décision. Veuillez réessayer." "sendFailed": "Échec de l'envoi de votre décision. Veuillez réessayer."
}, },
"executing": "Exécution en cours...",
"expired": "Expiré", "expired": "Expiré",
"inputPreview": "Aperçu de l'entrée de l'outil", "inputPreview": "Aperçu de l'entrée de l'outil",
"pending": "En attente ({{seconds}}s)", "pending": "En attente ({{seconds}}s)",
@ -1158,6 +1159,7 @@
"name": "Nom", "name": "Nom",
"no_results": "Aucun résultat", "no_results": "Aucun résultat",
"none": "Aucun", "none": "Aucun",
"off": "Désactivé",
"open": "Ouvrir", "open": "Ouvrir",
"paste": "Coller", "paste": "Coller",
"placeholders": { "placeholders": {
@ -1385,6 +1387,36 @@
"preview": "Aperçu", "preview": "Aperçu",
"split": "Diviser" "split": "Diviser"
}, },
"import": {
"chatgpt": {
"assistant_name": "Importation de ChatGPT",
"button": "Sélectionner le fichier",
"description": "Importe uniquement le texte de la conversation, n'inclut pas les images et les pièces jointes",
"error": {
"invalid_json": "Format de fichier JSON invalide",
"no_conversations": "Aucune conversation trouvée dans le fichier",
"no_valid_conversations": "Aucune conversation valide à importer",
"unknown": "L'importation a échoué, veuillez vérifier le format du fichier"
},
"help": {
"step1": "1. Connectez-vous à ChatGPT, allez dans Paramètres > Contrôles des données > Exporter les données",
"step2": "2. Attendez le fichier dexportation par e-mail",
"step3": "3. Extrayez le fichier téléchargé et recherchez conversations.json",
"title": "Comment exporter les conversations de ChatGPT ?"
},
"importing": "Importation des conversations...",
"selecting": "Sélection du fichier...",
"success": "Importation réussie de {{topics}} conversations avec {{messages}} messages",
"title": "Importer les conversations de ChatGPT",
"untitled_conversation": "Conversation sans titre"
},
"confirm": {
"button": "Sélectionner le fichier à importer",
"label": "Êtes-vous sûr de vouloir importer des données externes ?"
},
"content": "Sélectionnez le fichier de conversation de l'application externe à importer, actuellement uniquement les fichiers au format JSON de ChatGPT sont pris en charge",
"title": "Importer des conversations externes"
},
"knowledge": { "knowledge": {
"add": { "add": {
"title": "Ajouter une base de connaissances" "title": "Ajouter une base de connaissances"
@ -3094,6 +3126,7 @@
"basic": "Paramètres de base", "basic": "Paramètres de base",
"cloud_storage": "Paramètres de sauvegarde cloud", "cloud_storage": "Paramètres de sauvegarde cloud",
"export_settings": "Paramètres d'exportation", "export_settings": "Paramètres d'exportation",
"import_settings": "Importer les paramètres",
"third_party": "Connexion tierce" "third_party": "Connexion tierce"
}, },
"export_menu": { "export_menu": {
@ -3152,6 +3185,11 @@
}, },
"hour_interval_one": "{{count}} heure", "hour_interval_one": "{{count}} heure",
"hour_interval_other": "{{count}} heures", "hour_interval_other": "{{count}} heures",
"import_settings": {
"button": "Importer le fichier JSON",
"chatgpt": "Importer depuis ChatGPT",
"title": "Importer des données d'applications externes"
},
"joplin": { "joplin": {
"check": { "check": {
"button": "Vérifier", "button": "Vérifier",
@ -4223,7 +4261,6 @@
"default": "Par défaut", "default": "Par défaut",
"flex": "Flexible", "flex": "Flexible",
"on_demand": "à la demande", "on_demand": "à la demande",
"performance": "performance",
"priority": "priorité", "priority": "priorité",
"tip": "Spécifie le niveau de latence utilisé pour traiter la demande", "tip": "Spécifie le niveau de latence utilisé pour traiter la demande",
"title": "Niveau de service" "title": "Niveau de service"

View File

@ -29,15 +29,15 @@
}, },
"gitBash": { "gitBash": {
"error": { "error": {
"description": "[to be translated]:Git Bash is required to run agents on Windows. The agent cannot function without it. Please install Git for Windows from", "description": "Windowsでエージェントを実行するにはGit Bashが必要です。これがないとエージェントは動作しません。以下からGit for Windowsをインストールしてください。",
"recheck": "[to be translated]:Recheck Git Bash Installation", "recheck": "Git Bashのインストールを再確認してください",
"title": "[to be translated]:Git Bash Required" "title": "Git Bashが必要です"
}, },
"notFound": "[to be translated]:Git Bash not found. Please install it first.", "notFound": "Git Bash が見つかりません。先にインストールしてください。",
"success": "[to be translated]:Git Bash detected successfully!" "success": "Git Bashが正常に検出されました"
}, },
"input": { "input": {
"placeholder": "[to be translated]:Enter your message here, send with {{key}} - @ select path, / select command" "placeholder": "メッセージをここに入力し、{{key}}で送信 - @でパスを選択、/でコマンドを選択"
}, },
"list": { "list": {
"error": { "error": {
@ -266,6 +266,7 @@
"error": { "error": {
"sendFailed": "決定の送信に失敗しました。もう一度お試しください。" "sendFailed": "決定の送信に失敗しました。もう一度お試しください。"
}, },
"executing": "実行中...",
"expired": "期限切れ", "expired": "期限切れ",
"inputPreview": "ツール入力プレビュー", "inputPreview": "ツール入力プレビュー",
"pending": "保留中({{seconds}}秒)", "pending": "保留中({{seconds}}秒)",
@ -1158,6 +1159,7 @@
"name": "名前", "name": "名前",
"no_results": "検索結果なし", "no_results": "検索結果なし",
"none": "無", "none": "無",
"off": "オフ",
"open": "開く", "open": "開く",
"paste": "貼り付け", "paste": "貼り付け",
"placeholders": { "placeholders": {
@ -1385,6 +1387,36 @@
"preview": "プレビュー", "preview": "プレビュー",
"split": "分割" "split": "分割"
}, },
"import": {
"chatgpt": {
"assistant_name": "ChatGPTインポート",
"button": "ファイルを選択",
"description": "会話のテキストのみをインポートし、画像や添付ファイルは含まれません",
"error": {
"invalid_json": "無効なJSONファイル形式",
"no_conversations": "ファイルに会話が見つかりません",
"no_valid_conversations": "インポートする有効な会話がありません",
"unknown": "インポートに失敗しました。ファイル形式を確認してください。"
},
"help": {
"step1": "1. ChatGPTにログインし、設定 > データ管理 > データをエクスポートへ進みます",
"step2": "2. エクスポートファイルが届くまでメールでお待ちください",
"step3": "3. ダウンロードしたファイルを展開し、conversations.jsonを探してください",
"title": "ChatGPTの会話をエクスポートする方法は"
},
"importing": "会話をインポートしています...",
"selecting": "ファイルを選択中...",
"success": "{{topics}}件の会話と{{messages}}件のメッセージを正常にインポートしました",
"title": "ChatGPTの会話をインポート",
"untitled_conversation": "無題の会話"
},
"confirm": {
"button": "ファイルのインポートを選択",
"label": "外部データをインポートしてもよろしいですか?"
},
"content": "外部アプリケーションの会話ファイルを選択してインポートします。現在、ChatGPT JSON形式ファイルのみサポートしています。",
"title": "外部会話をインポート"
},
"knowledge": { "knowledge": {
"add": { "add": {
"title": "ナレッジベースを追加" "title": "ナレッジベースを追加"
@ -3094,6 +3126,7 @@
"basic": "基本データ設定", "basic": "基本データ設定",
"cloud_storage": "クラウドバックアップ設定", "cloud_storage": "クラウドバックアップ設定",
"export_settings": "エクスポート設定", "export_settings": "エクスポート設定",
"import_settings": "設定のインポート",
"third_party": "サードパーティー連携" "third_party": "サードパーティー連携"
}, },
"export_menu": { "export_menu": {
@ -3152,6 +3185,11 @@
}, },
"hour_interval_one": "{{count}} 時間", "hour_interval_one": "{{count}} 時間",
"hour_interval_other": "{{count}} 時間", "hour_interval_other": "{{count}} 時間",
"import_settings": {
"button": "JSONファイルをインポート",
"chatgpt": "ChatGPTからインポート",
"title": "外部アプリケーションデータをインポート"
},
"joplin": { "joplin": {
"check": { "check": {
"button": "確認", "button": "確認",
@ -4223,7 +4261,6 @@
"default": "デフォルト", "default": "デフォルト",
"flex": "フレックス", "flex": "フレックス",
"on_demand": "オンデマンド", "on_demand": "オンデマンド",
"performance": "性能",
"priority": "優先", "priority": "優先",
"tip": "リクエスト処理に使用するレイテンシティアを指定します", "tip": "リクエスト処理に使用するレイテンシティアを指定します",
"title": "サービスティア" "title": "サービスティア"

View File

@ -29,15 +29,15 @@
}, },
"gitBash": { "gitBash": {
"error": { "error": {
"description": "[to be translated]:Git Bash is required to run agents on Windows. The agent cannot function without it. Please install Git for Windows from", "description": "O Git Bash é necessário para executar agentes no Windows. O agente não pode funcionar sem ele. Por favor, instale o Git para Windows a partir de",
"recheck": "[to be translated]:Recheck Git Bash Installation", "recheck": "Reverificar a Instalação do Git Bash",
"title": "[to be translated]:Git Bash Required" "title": "Git Bash Necessário"
}, },
"notFound": "[to be translated]:Git Bash not found. Please install it first.", "notFound": "Git Bash não encontrado. Por favor, instale-o primeiro.",
"success": "[to be translated]:Git Bash detected successfully!" "success": "Git Bash detectado com sucesso!"
}, },
"input": { "input": {
"placeholder": "[to be translated]:Enter your message here, send with {{key}} - @ select path, / select command" "placeholder": "Digite sua mensagem aqui, envie com {{key}} - @ selecionar caminho, / selecionar comando"
}, },
"list": { "list": {
"error": { "error": {
@ -266,6 +266,7 @@
"error": { "error": {
"sendFailed": "Falha ao enviar sua decisão. Por favor, tente novamente." "sendFailed": "Falha ao enviar sua decisão. Por favor, tente novamente."
}, },
"executing": "Executando...",
"expired": "Expirado", "expired": "Expirado",
"inputPreview": "Pré-visualização da entrada da ferramenta", "inputPreview": "Pré-visualização da entrada da ferramenta",
"pending": "Pendente ({{seconds}}s)", "pending": "Pendente ({{seconds}}s)",
@ -1158,6 +1159,7 @@
"name": "Nome", "name": "Nome",
"no_results": "Nenhum resultado", "no_results": "Nenhum resultado",
"none": "Nenhum", "none": "Nenhum",
"off": "Desligado",
"open": "Abrir", "open": "Abrir",
"paste": "Colar", "paste": "Colar",
"placeholders": { "placeholders": {
@ -1385,6 +1387,36 @@
"preview": "Visualizar", "preview": "Visualizar",
"split": "Dividir" "split": "Dividir"
}, },
"import": {
"chatgpt": {
"assistant_name": "Importação do ChatGPT",
"button": "Selecionar Arquivo",
"description": "Importa apenas o texto da conversa, não inclui imagens e anexos",
"error": {
"invalid_json": "Formato de arquivo JSON inválido",
"no_conversations": "Nenhuma conversa encontrada no arquivo",
"no_valid_conversations": "Nenhuma conversa válida para importar",
"unknown": "Falha na importação, verifique o formato do arquivo"
},
"help": {
"step1": "1. Faça login no ChatGPT, vá para Configurações > Controles de dados > Exportar dados",
"step2": "2. Aguarde o arquivo de exportação por e-mail",
"step3": "3. Extraia o arquivo baixado e localize conversations.json",
"title": "Como exportar conversas do ChatGPT?"
},
"importing": "Importando conversas...",
"selecting": "Selecionando arquivo...",
"success": "Importadas com sucesso {{topics}} conversas com {{messages}} mensagens",
"title": "Importar Conversas do ChatGPT",
"untitled_conversation": "Conversa Sem Título"
},
"confirm": {
"button": "Selecionar Arquivo de Importação",
"label": "Tem certeza de que deseja importar dados externos?"
},
"content": "Selecione o arquivo de conversa do aplicativo externo para importar; atualmente, apenas arquivos no formato JSON do ChatGPT são suportados.",
"title": "Importar Conversas Externas"
},
"knowledge": { "knowledge": {
"add": { "add": {
"title": "Adicionar Base de Conhecimento" "title": "Adicionar Base de Conhecimento"
@ -3094,6 +3126,7 @@
"basic": "Configurações Básicas", "basic": "Configurações Básicas",
"cloud_storage": "Configurações de Armazenamento em Nuvem", "cloud_storage": "Configurações de Armazenamento em Nuvem",
"export_settings": "Configurações de Exportação", "export_settings": "Configurações de Exportação",
"import_settings": "Importar Configurações",
"third_party": "Conexões de Terceiros" "third_party": "Conexões de Terceiros"
}, },
"export_menu": { "export_menu": {
@ -3152,6 +3185,11 @@
}, },
"hour_interval_one": "{{count}} hora", "hour_interval_one": "{{count}} hora",
"hour_interval_other": "{{count}} horas", "hour_interval_other": "{{count}} horas",
"import_settings": {
"button": "Importar Arquivo Json",
"chatgpt": "Importar do ChatGPT",
"title": "Importar Dados de Aplicações Externas"
},
"joplin": { "joplin": {
"check": { "check": {
"button": "Verificar", "button": "Verificar",
@ -4223,7 +4261,6 @@
"default": "Padrão", "default": "Padrão",
"flex": "Flexível", "flex": "Flexível",
"on_demand": "sob demanda", "on_demand": "sob demanda",
"performance": "desempenho",
"priority": "prioridade", "priority": "prioridade",
"tip": "Especifique o nível de latência usado para processar a solicitação", "tip": "Especifique o nível de latência usado para processar a solicitação",
"title": "Nível de Serviço" "title": "Nível de Serviço"

View File

@ -29,15 +29,15 @@
}, },
"gitBash": { "gitBash": {
"error": { "error": {
"description": "[to be translated]:Git Bash is required to run agents on Windows. The agent cannot function without it. Please install Git for Windows from", "description": "Для запуска агентов в Windows требуется Git Bash. Без него агент не может работать. Пожалуйста, установите Git для Windows с",
"recheck": "[to be translated]:Recheck Git Bash Installation", "recheck": "Повторная проверка установки Git Bash",
"title": "[to be translated]:Git Bash Required" "title": "Требуется Git Bash"
}, },
"notFound": "[to be translated]:Git Bash not found. Please install it first.", "notFound": "Git Bash не найден. Пожалуйста, сначала установите его.",
"success": "[to be translated]:Git Bash detected successfully!" "success": "Git Bash успешно обнаружен!"
}, },
"input": { "input": {
"placeholder": "[to be translated]:Enter your message here, send with {{key}} - @ select path, / select command" "placeholder": "Введите ваше сообщение здесь, отправьте с помощью {{key}} — @ выбрать путь, / выбрать команду"
}, },
"list": { "list": {
"error": { "error": {
@ -266,6 +266,7 @@
"error": { "error": {
"sendFailed": "Не удалось отправить ваше решение. Попробуйте ещё раз." "sendFailed": "Не удалось отправить ваше решение. Попробуйте ещё раз."
}, },
"executing": "Выполнение...",
"expired": "Истёк", "expired": "Истёк",
"inputPreview": "Предварительный просмотр ввода инструмента", "inputPreview": "Предварительный просмотр ввода инструмента",
"pending": "Ожидание ({{seconds}}с)", "pending": "Ожидание ({{seconds}}с)",
@ -1158,6 +1159,7 @@
"name": "Имя", "name": "Имя",
"no_results": "Результатов не найдено", "no_results": "Результатов не найдено",
"none": "без", "none": "без",
"off": "Выкл",
"open": "Открыть", "open": "Открыть",
"paste": "Вставить", "paste": "Вставить",
"placeholders": { "placeholders": {
@ -1385,6 +1387,36 @@
"preview": "Предпросмотр", "preview": "Предпросмотр",
"split": "Разделить" "split": "Разделить"
}, },
"import": {
"chatgpt": {
"assistant_name": "Импорт ChatGPT",
"button": "Выбрать файл",
"description": "Импортирует только текст переписки, не включает изображения и вложения",
"error": {
"invalid_json": "Неверный формат файла JSON",
"no_conversations": "В файле не найдено ни одной беседы",
"no_valid_conversations": "Нет допустимых бесед для импорта",
"unknown": "Импорт не удался, проверьте формат файла"
},
"help": {
"step1": "1. Войдите в ChatGPT, перейдите в Настройки > Управление данными > Экспортировать данные",
"step2": "2. Дождитесь письма с файлом экспорта",
"step3": "3. Распакуйте загруженный файл и найдите conversations.json",
"title": "Как экспортировать диалоги с ChatGPT?"
},
"importing": "Импортирование разговоров...",
"selecting": "Выбор файла...",
"success": "Успешно импортировано {{topics}} бесед с {{messages}} сообщениями",
"title": "Импортировать диалоги ChatGPT",
"untitled_conversation": "Безымянный разговор"
},
"confirm": {
"button": "Выберите файл для импорта",
"label": "Вы уверены, что хотите импортировать внешние данные?"
},
"content": "Выберите внешний файл с перепиской для импорта; в настоящее время поддерживаются только файлы в формате JSON ChatGPT",
"title": "Импорт внешних бесед"
},
"knowledge": { "knowledge": {
"add": { "add": {
"title": "Добавить базу знаний" "title": "Добавить базу знаний"
@ -3094,6 +3126,7 @@
"basic": "Основные настройки данных", "basic": "Основные настройки данных",
"cloud_storage": "Настройки облачного резервирования", "cloud_storage": "Настройки облачного резервирования",
"export_settings": "Настройки экспорта", "export_settings": "Настройки экспорта",
"import_settings": "Импорт настроек",
"third_party": "Сторонние подключения" "third_party": "Сторонние подключения"
}, },
"export_menu": { "export_menu": {
@ -3152,6 +3185,11 @@
}, },
"hour_interval_one": "{{count}} час", "hour_interval_one": "{{count}} час",
"hour_interval_other": "{{count}} часов", "hour_interval_other": "{{count}} часов",
"import_settings": {
"button": "Импортировать файл JSON",
"chatgpt": "Импорт из ChatGPT",
"title": "Импорт внешних данных приложения"
},
"joplin": { "joplin": {
"check": { "check": {
"button": "Проверить", "button": "Проверить",
@ -4223,7 +4261,6 @@
"default": "По умолчанию", "default": "По умолчанию",
"flex": "Гибкий", "flex": "Гибкий",
"on_demand": "по требованию", "on_demand": "по требованию",
"performance": "производительность",
"priority": "приоритет", "priority": "приоритет",
"tip": "Указывает уровень задержки, который следует использовать для обработки запроса", "tip": "Указывает уровень задержки, который следует использовать для обработки запроса",
"title": "Уровень сервиса" "title": "Уровень сервиса"

View File

@ -3,7 +3,6 @@ import { ActionIconButton } from '@renderer/components/Buttons'
import type { QuickPanelListItem } from '@renderer/components/QuickPanel' import type { QuickPanelListItem } from '@renderer/components/QuickPanel'
import { QuickPanelReservedSymbol, useQuickPanel } from '@renderer/components/QuickPanel' import { QuickPanelReservedSymbol, useQuickPanel } from '@renderer/components/QuickPanel'
import { isGeminiModel } from '@renderer/config/models' import { isGeminiModel } from '@renderer/config/models'
import { isGeminiWebSearchProvider, isSupportUrlContextProvider } from '@renderer/config/providers'
import { useAssistant } from '@renderer/hooks/useAssistant' import { useAssistant } from '@renderer/hooks/useAssistant'
import { useMCPServers } from '@renderer/hooks/useMCPServers' import { useMCPServers } from '@renderer/hooks/useMCPServers'
import { useTimer } from '@renderer/hooks/useTimer' import { useTimer } from '@renderer/hooks/useTimer'
@ -12,6 +11,7 @@ import { getProviderByModel } from '@renderer/services/AssistantService'
import { EventEmitter } from '@renderer/services/EventService' import { EventEmitter } from '@renderer/services/EventService'
import type { MCPPrompt, MCPResource, MCPServer } from '@renderer/types' import type { MCPPrompt, MCPResource, MCPServer } from '@renderer/types'
import { isToolUseModeFunction } from '@renderer/utils/assistant' import { isToolUseModeFunction } from '@renderer/utils/assistant'
import { isGeminiWebSearchProvider, isSupportUrlContextProvider } from '@renderer/utils/provider'
import { Form, Input } from 'antd' import { Form, Input } from 'antd'
import { CircleX, Hammer, Plus } from 'lucide-react' import { CircleX, Hammer, Plus } from 'lucide-react'
import type { FC } from 'react' import type { FC } from 'react'

View File

@ -9,7 +9,6 @@ import {
isOpenAIWebSearchModel, isOpenAIWebSearchModel,
isWebSearchModel isWebSearchModel
} from '@renderer/config/models' } from '@renderer/config/models'
import { isGeminiWebSearchProvider } from '@renderer/config/providers'
import { useAssistant } from '@renderer/hooks/useAssistant' import { useAssistant } from '@renderer/hooks/useAssistant'
import { useTimer } from '@renderer/hooks/useTimer' import { useTimer } from '@renderer/hooks/useTimer'
import { useWebSearchProviders } from '@renderer/hooks/useWebSearchProviders' import { useWebSearchProviders } from '@renderer/hooks/useWebSearchProviders'
@ -19,6 +18,7 @@ import WebSearchService from '@renderer/services/WebSearchService'
import type { WebSearchProvider, WebSearchProviderId } from '@renderer/types' import type { WebSearchProvider, WebSearchProviderId } from '@renderer/types'
import { hasObjectKey } from '@renderer/utils' import { hasObjectKey } from '@renderer/utils'
import { isToolUseModeFunction } from '@renderer/utils/assistant' import { isToolUseModeFunction } from '@renderer/utils/assistant'
import { isGeminiWebSearchProvider } from '@renderer/utils/provider'
import { Globe } from 'lucide-react' import { Globe } from 'lucide-react'
import { useCallback, useEffect, useMemo } from 'react' import { useCallback, useEffect, useMemo } from 'react'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'

View File

@ -1,7 +1,7 @@
import { isAnthropicModel, isGeminiModel } from '@renderer/config/models' import { isAnthropicModel, isGeminiModel, isPureGenerateImageModel } from '@renderer/config/models'
import { isSupportUrlContextProvider } from '@renderer/config/providers'
import { defineTool, registerTool, TopicType } from '@renderer/pages/home/Inputbar/types' import { defineTool, registerTool, TopicType } from '@renderer/pages/home/Inputbar/types'
import { getProviderByModel } from '@renderer/services/AssistantService' import { getProviderByModel } from '@renderer/services/AssistantService'
import { isSupportUrlContextProvider } from '@renderer/utils/provider'
import UrlContextButton from './components/UrlContextbutton' import UrlContextButton from './components/UrlContextbutton'
@ -11,7 +11,12 @@ const urlContextTool = defineTool({
visibleInScopes: [TopicType.Chat], visibleInScopes: [TopicType.Chat],
condition: ({ model }) => { condition: ({ model }) => {
const provider = getProviderByModel(model) const provider = getProviderByModel(model)
return !!provider && isSupportUrlContextProvider(provider) && (isGeminiModel(model) || isAnthropicModel(model)) return (
!!provider &&
isSupportUrlContextProvider(provider) &&
!isPureGenerateImageModel(model) &&
(isGeminiModel(model) || isAnthropicModel(model))
)
}, },
render: ({ assistant }) => <UrlContextButton assistantId={assistant.id} /> render: ({ assistant }) => <UrlContextButton assistantId={assistant.id} />
}) })

View File

@ -1,4 +1,4 @@
import { isMandatoryWebSearchModel } from '@renderer/config/models' import { isMandatoryWebSearchModel, isWebSearchModel } from '@renderer/config/models'
import { defineTool, registerTool, TopicType } from '@renderer/pages/home/Inputbar/types' import { defineTool, registerTool, TopicType } from '@renderer/pages/home/Inputbar/types'
import WebSearchButton from './components/WebSearchButton' import WebSearchButton from './components/WebSearchButton'
@ -15,7 +15,7 @@ const webSearchTool = defineTool({
label: (t) => t('chat.input.web_search.label'), label: (t) => t('chat.input.web_search.label'),
visibleInScopes: [TopicType.Chat], visibleInScopes: [TopicType.Chat],
condition: ({ model }) => !isMandatoryWebSearchModel(model), condition: ({ model }) => isWebSearchModel(model) && !isMandatoryWebSearchModel(model),
render: function WebSearchToolRender(context) { render: function WebSearchToolRender(context) {
const { assistant, quickPanelController } = context const { assistant, quickPanelController } = context

Some files were not shown because too many files have changed in this diff Show More