Merge branch 'main' into v2

This commit is contained in:
fullex 2025-11-24 16:22:32 +08:00
commit 0045bf6c9c
142 changed files with 11167 additions and 2536 deletions

View File

@ -77,7 +77,7 @@ jobs:
with:
token: ${{ secrets.GITHUB_TOKEN }} # Use the built-in GITHUB_TOKEN for bot actions
commit-message: "feat(bot): Weekly automated script run"
title: "🤖 Weekly Automated Update: ${{ env.CURRENT_DATE }}"
title: "🤖 Weekly Auto I18N Sync: ${{ env.CURRENT_DATE }}"
body: |
This PR includes changes generated by the weekly auto i18n.
Review the changes before merging.

View File

@ -1,152 +0,0 @@
diff --git a/dist/index.js b/dist/index.js
index c2ef089c42e13a8ee4a833899a415564130e5d79..75efa7baafb0f019fb44dd50dec1641eee8879e7 100644
--- a/dist/index.js
+++ b/dist/index.js
@@ -471,7 +471,7 @@ function convertToGoogleGenerativeAIMessages(prompt, options) {
// src/get-model-path.ts
function getModelPath(modelId) {
- return modelId.includes("/") ? modelId : `models/${modelId}`;
+ return modelId.includes("models/") ? modelId : `models/${modelId}`;
}
// src/google-generative-ai-options.ts
diff --git a/dist/index.mjs b/dist/index.mjs
index d75c0cc13c41192408c1f3f2d29d76a7bffa6268..ada730b8cb97d9b7d4cb32883a1d1ff416404d9b 100644
--- a/dist/index.mjs
+++ b/dist/index.mjs
@@ -477,7 +477,7 @@ function convertToGoogleGenerativeAIMessages(prompt, options) {
// src/get-model-path.ts
function getModelPath(modelId) {
- return modelId.includes("/") ? modelId : `models/${modelId}`;
+ return modelId.includes("models/") ? modelId : `models/${modelId}`;
}
// src/google-generative-ai-options.ts
diff --git a/dist/internal/index.js b/dist/internal/index.js
index 277cac8dc734bea2fb4f3e9a225986b402b24f48..bb704cd79e602eb8b0cee1889e42497d59ccdb7a 100644
--- a/dist/internal/index.js
+++ b/dist/internal/index.js
@@ -432,7 +432,15 @@ function prepareTools({
var _a;
tools = (tools == null ? void 0 : tools.length) ? tools : void 0;
const toolWarnings = [];
- const isGemini2 = modelId.includes("gemini-2");
+ // These changes could be safely removed when @ai-sdk/google v3 released.
+ const isLatest = (
+ [
+ 'gemini-flash-latest',
+ 'gemini-flash-lite-latest',
+ 'gemini-pro-latest',
+ ]
+ ).some(id => id === modelId);
+ const isGemini2OrNewer = modelId.includes("gemini-2") || modelId.includes("gemini-3") || isLatest;
const supportsDynamicRetrieval = modelId.includes("gemini-1.5-flash") && !modelId.includes("-8b");
const supportsFileSearch = modelId.includes("gemini-2.5");
if (tools == null) {
@@ -458,7 +466,7 @@ function prepareTools({
providerDefinedTools.forEach((tool) => {
switch (tool.id) {
case "google.google_search":
- if (isGemini2) {
+ if (isGemini2OrNewer) {
googleTools2.push({ googleSearch: {} });
} else if (supportsDynamicRetrieval) {
googleTools2.push({
@@ -474,7 +482,7 @@ function prepareTools({
}
break;
case "google.url_context":
- if (isGemini2) {
+ if (isGemini2OrNewer) {
googleTools2.push({ urlContext: {} });
} else {
toolWarnings.push({
@@ -485,7 +493,7 @@ function prepareTools({
}
break;
case "google.code_execution":
- if (isGemini2) {
+ if (isGemini2OrNewer) {
googleTools2.push({ codeExecution: {} });
} else {
toolWarnings.push({
@@ -507,7 +515,7 @@ function prepareTools({
}
break;
case "google.vertex_rag_store":
- if (isGemini2) {
+ if (isGemini2OrNewer) {
googleTools2.push({
retrieval: {
vertex_rag_store: {
diff --git a/dist/internal/index.mjs b/dist/internal/index.mjs
index 03b7cc591be9b58bcc2e775a96740d9f98862a10..347d2c12e1cee79f0f8bb258f3844fb0522a6485 100644
--- a/dist/internal/index.mjs
+++ b/dist/internal/index.mjs
@@ -424,7 +424,15 @@ function prepareTools({
var _a;
tools = (tools == null ? void 0 : tools.length) ? tools : void 0;
const toolWarnings = [];
- const isGemini2 = modelId.includes("gemini-2");
+ // These changes could be safely removed when @ai-sdk/google v3 released.
+ const isLatest = (
+ [
+ 'gemini-flash-latest',
+ 'gemini-flash-lite-latest',
+ 'gemini-pro-latest',
+ ]
+ ).some(id => id === modelId);
+ const isGemini2OrNewer = modelId.includes("gemini-2") || modelId.includes("gemini-3") || isLatest;
const supportsDynamicRetrieval = modelId.includes("gemini-1.5-flash") && !modelId.includes("-8b");
const supportsFileSearch = modelId.includes("gemini-2.5");
if (tools == null) {
@@ -450,7 +458,7 @@ function prepareTools({
providerDefinedTools.forEach((tool) => {
switch (tool.id) {
case "google.google_search":
- if (isGemini2) {
+ if (isGemini2OrNewer) {
googleTools2.push({ googleSearch: {} });
} else if (supportsDynamicRetrieval) {
googleTools2.push({
@@ -466,7 +474,7 @@ function prepareTools({
}
break;
case "google.url_context":
- if (isGemini2) {
+ if (isGemini2OrNewer) {
googleTools2.push({ urlContext: {} });
} else {
toolWarnings.push({
@@ -477,7 +485,7 @@ function prepareTools({
}
break;
case "google.code_execution":
- if (isGemini2) {
+ if (isGemini2OrNewer) {
googleTools2.push({ codeExecution: {} });
} else {
toolWarnings.push({
@@ -499,7 +507,7 @@ function prepareTools({
}
break;
case "google.vertex_rag_store":
- if (isGemini2) {
+ if (isGemini2OrNewer) {
googleTools2.push({
retrieval: {
vertex_rag_store: {
@@ -1434,9 +1442,7 @@ var googleTools = {
vertexRagStore
};
export {
- GoogleGenerativeAILanguageModel,
getGroundingMetadataSchema,
- getUrlContextMetadataSchema,
- googleTools
+ getUrlContextMetadataSchema, GoogleGenerativeAILanguageModel, googleTools
};
//# sourceMappingURL=index.mjs.map
\ No newline at end of file

View File

@ -0,0 +1,26 @@
diff --git a/dist/index.js b/dist/index.js
index dc7b74ba55337c491cdf1ab3e39ca68cc4187884..ace8c90591288e42c2957e93c9bf7984f1b22444 100644
--- a/dist/index.js
+++ b/dist/index.js
@@ -472,7 +472,7 @@ function convertToGoogleGenerativeAIMessages(prompt, options) {
// src/get-model-path.ts
function getModelPath(modelId) {
- return modelId.includes("/") ? modelId : `models/${modelId}`;
+ return modelId.includes("models/") ? modelId : `models/${modelId}`;
}
// src/google-generative-ai-options.ts
diff --git a/dist/index.mjs b/dist/index.mjs
index 8390439c38cb7eaeb52080862cd6f4c58509e67c..a7647f2e11700dff7e1c8d4ae8f99d3637010733 100644
--- a/dist/index.mjs
+++ b/dist/index.mjs
@@ -478,7 +478,7 @@ function convertToGoogleGenerativeAIMessages(prompt, options) {
// src/get-model-path.ts
function getModelPath(modelId) {
- return modelId.includes("/") ? modelId : `models/${modelId}`;
+ return modelId.includes("models/") ? modelId : `models/${modelId}`;
}
// src/google-generative-ai-options.ts

View File

@ -1,131 +0,0 @@
diff --git a/dist/index.mjs b/dist/index.mjs
index b3f018730a93639aad7c203f15fb1aeb766c73f4..ade2a43d66e9184799d072153df61ef7be4ea110 100644
--- a/dist/index.mjs
+++ b/dist/index.mjs
@@ -296,7 +296,14 @@ var HuggingFaceResponsesLanguageModel = class {
metadata: huggingfaceOptions == null ? void 0 : huggingfaceOptions.metadata,
instructions: huggingfaceOptions == null ? void 0 : huggingfaceOptions.instructions,
...preparedTools && { tools: preparedTools },
- ...preparedToolChoice && { tool_choice: preparedToolChoice }
+ ...preparedToolChoice && { tool_choice: preparedToolChoice },
+ ...(huggingfaceOptions?.reasoningEffort != null && {
+ reasoning: {
+ ...(huggingfaceOptions?.reasoningEffort != null && {
+ effort: huggingfaceOptions.reasoningEffort,
+ }),
+ },
+ }),
};
return { args: baseArgs, warnings };
}
@@ -365,6 +372,20 @@ var HuggingFaceResponsesLanguageModel = class {
}
break;
}
+ case 'reasoning': {
+ for (const contentPart of part.content) {
+ content.push({
+ type: 'reasoning',
+ text: contentPart.text,
+ providerMetadata: {
+ huggingface: {
+ itemId: part.id,
+ },
+ },
+ });
+ }
+ break;
+ }
case "mcp_call": {
content.push({
type: "tool-call",
@@ -519,6 +540,11 @@ var HuggingFaceResponsesLanguageModel = class {
id: value.item.call_id,
toolName: value.item.name
});
+ } else if (value.item.type === 'reasoning') {
+ controller.enqueue({
+ type: 'reasoning-start',
+ id: value.item.id,
+ });
}
return;
}
@@ -570,6 +596,22 @@ var HuggingFaceResponsesLanguageModel = class {
});
return;
}
+ if (isReasoningDeltaChunk(value)) {
+ controller.enqueue({
+ type: 'reasoning-delta',
+ id: value.item_id,
+ delta: value.delta,
+ });
+ return;
+ }
+
+ if (isReasoningEndChunk(value)) {
+ controller.enqueue({
+ type: 'reasoning-end',
+ id: value.item_id,
+ });
+ return;
+ }
},
flush(controller) {
controller.enqueue({
@@ -593,7 +635,8 @@ var HuggingFaceResponsesLanguageModel = class {
var huggingfaceResponsesProviderOptionsSchema = z2.object({
metadata: z2.record(z2.string(), z2.string()).optional(),
instructions: z2.string().optional(),
- strictJsonSchema: z2.boolean().optional()
+ strictJsonSchema: z2.boolean().optional(),
+ reasoningEffort: z2.string().optional(),
});
var huggingfaceResponsesResponseSchema = z2.object({
id: z2.string(),
@@ -727,12 +770,31 @@ var responseCreatedChunkSchema = z2.object({
model: z2.string()
})
});
+var reasoningTextDeltaChunkSchema = z2.object({
+ type: z2.literal('response.reasoning_text.delta'),
+ item_id: z2.string(),
+ output_index: z2.number(),
+ content_index: z2.number(),
+ delta: z2.string(),
+ sequence_number: z2.number(),
+});
+
+var reasoningTextEndChunkSchema = z2.object({
+ type: z2.literal('response.reasoning_text.done'),
+ item_id: z2.string(),
+ output_index: z2.number(),
+ content_index: z2.number(),
+ text: z2.string(),
+ sequence_number: z2.number(),
+});
var huggingfaceResponsesChunkSchema = z2.union([
responseOutputItemAddedSchema,
responseOutputItemDoneSchema,
textDeltaChunkSchema,
responseCompletedChunkSchema,
responseCreatedChunkSchema,
+ reasoningTextDeltaChunkSchema,
+ reasoningTextEndChunkSchema,
z2.object({ type: z2.string() }).loose()
// fallback for unknown chunks
]);
@@ -751,6 +813,12 @@ function isResponseCompletedChunk(chunk) {
function isResponseCreatedChunk(chunk) {
return chunk.type === "response.created";
}
+function isReasoningDeltaChunk(chunk) {
+ return chunk.type === 'response.reasoning_text.delta';
+}
+function isReasoningEndChunk(chunk) {
+ return chunk.type === 'response.reasoning_text.done';
+}
// src/huggingface-provider.ts
function createHuggingFace(options = {}) {

View File

@ -0,0 +1,140 @@
diff --git a/dist/index.js b/dist/index.js
index 73045a7d38faafdc7f7d2cd79d7ff0e2b031056b..8d948c9ac4ea4b474db9ef3c5491961e7fcf9a07 100644
--- a/dist/index.js
+++ b/dist/index.js
@@ -421,6 +421,17 @@ var OpenAICompatibleChatLanguageModel = class {
text: reasoning
});
}
+ if (choice.message.images) {
+ for (const image of choice.message.images) {
+ const match1 = image.image_url.url.match(/^data:([^;]+)/)
+ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/);
+ content.push({
+ type: 'file',
+ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg',
+ data: match2 ? match2[1] : image.image_url.url,
+ });
+ }
+ }
if (choice.message.tool_calls != null) {
for (const toolCall of choice.message.tool_calls) {
content.push({
@@ -598,6 +609,17 @@ var OpenAICompatibleChatLanguageModel = class {
delta: delta.content
});
}
+ if (delta.images) {
+ for (const image of delta.images) {
+ const match1 = image.image_url.url.match(/^data:([^;]+)/)
+ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/);
+ controller.enqueue({
+ type: 'file',
+ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg',
+ data: match2 ? match2[1] : image.image_url.url,
+ });
+ }
+ }
if (delta.tool_calls != null) {
for (const toolCallDelta of delta.tool_calls) {
const index = toolCallDelta.index;
@@ -765,6 +787,14 @@ var OpenAICompatibleChatResponseSchema = import_v43.z.object({
arguments: import_v43.z.string()
})
})
+ ).nullish(),
+ images: import_v43.z.array(
+ import_v43.z.object({
+ type: import_v43.z.literal('image_url'),
+ image_url: import_v43.z.object({
+ url: import_v43.z.string(),
+ })
+ })
).nullish()
}),
finish_reason: import_v43.z.string().nullish()
@@ -795,6 +825,14 @@ var createOpenAICompatibleChatChunkSchema = (errorSchema) => import_v43.z.union(
arguments: import_v43.z.string().nullish()
})
})
+ ).nullish(),
+ images: import_v43.z.array(
+ import_v43.z.object({
+ type: import_v43.z.literal('image_url'),
+ image_url: import_v43.z.object({
+ url: import_v43.z.string(),
+ })
+ })
).nullish()
}).nullish(),
finish_reason: import_v43.z.string().nullish()
diff --git a/dist/index.mjs b/dist/index.mjs
index 1c2b9560bbfbfe10cb01af080aeeed4ff59db29c..2c8ddc4fc9bfc5e7e06cfca105d197a08864c427 100644
--- a/dist/index.mjs
+++ b/dist/index.mjs
@@ -405,6 +405,17 @@ var OpenAICompatibleChatLanguageModel = class {
text: reasoning
});
}
+ if (choice.message.images) {
+ for (const image of choice.message.images) {
+ const match1 = image.image_url.url.match(/^data:([^;]+)/)
+ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/);
+ content.push({
+ type: 'file',
+ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg',
+ data: match2 ? match2[1] : image.image_url.url,
+ });
+ }
+ }
if (choice.message.tool_calls != null) {
for (const toolCall of choice.message.tool_calls) {
content.push({
@@ -582,6 +593,17 @@ var OpenAICompatibleChatLanguageModel = class {
delta: delta.content
});
}
+ if (delta.images) {
+ for (const image of delta.images) {
+ const match1 = image.image_url.url.match(/^data:([^;]+)/)
+ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/);
+ controller.enqueue({
+ type: 'file',
+ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg',
+ data: match2 ? match2[1] : image.image_url.url,
+ });
+ }
+ }
if (delta.tool_calls != null) {
for (const toolCallDelta of delta.tool_calls) {
const index = toolCallDelta.index;
@@ -749,6 +771,14 @@ var OpenAICompatibleChatResponseSchema = z3.object({
arguments: z3.string()
})
})
+ ).nullish(),
+ images: z3.array(
+ z3.object({
+ type: z3.literal('image_url'),
+ image_url: z3.object({
+ url: z3.string(),
+ })
+ })
).nullish()
}),
finish_reason: z3.string().nullish()
@@ -779,6 +809,14 @@ var createOpenAICompatibleChatChunkSchema = (errorSchema) => z3.union([
arguments: z3.string().nullish()
})
})
+ ).nullish(),
+ images: z3.array(
+ z3.object({
+ type: z3.literal('image_url'),
+ image_url: z3.object({
+ url: z3.string(),
+ })
+ })
).nullish()
}).nullish(),
finish_reason: z3.string().nullish()

View File

@ -1,5 +1,5 @@
diff --git a/dist/index.js b/dist/index.js
index 992c85ac6656e51c3471af741583533c5a7bf79f..83c05952a07aebb95fc6c62f9ddb8aa96b52ac0d 100644
index 7481f3b3511078068d87d03855b568b20bb86971..8ac5ec28d2f7ad1b3b0d3f8da945c75674e59637 100644
--- a/dist/index.js
+++ b/dist/index.js
@@ -274,6 +274,7 @@ var openaiChatResponseSchema = (0, import_provider_utils3.lazyValidator)(
@ -18,7 +18,7 @@ index 992c85ac6656e51c3471af741583533c5a7bf79f..83c05952a07aebb95fc6c62f9ddb8aa9
tool_calls: import_v42.z.array(
import_v42.z.object({
index: import_v42.z.number(),
@@ -785,6 +787,13 @@ var OpenAIChatLanguageModel = class {
@@ -795,6 +797,13 @@ var OpenAIChatLanguageModel = class {
if (text != null && text.length > 0) {
content.push({ type: "text", text });
}
@ -32,7 +32,7 @@ index 992c85ac6656e51c3471af741583533c5a7bf79f..83c05952a07aebb95fc6c62f9ddb8aa9
for (const toolCall of (_a = choice.message.tool_calls) != null ? _a : []) {
content.push({
type: "tool-call",
@@ -866,6 +875,7 @@ var OpenAIChatLanguageModel = class {
@@ -876,6 +885,7 @@ var OpenAIChatLanguageModel = class {
};
let metadataExtracted = false;
let isActiveText = false;
@ -40,7 +40,7 @@ index 992c85ac6656e51c3471af741583533c5a7bf79f..83c05952a07aebb95fc6c62f9ddb8aa9
const providerMetadata = { openai: {} };
return {
stream: response.pipeThrough(
@@ -923,6 +933,21 @@ var OpenAIChatLanguageModel = class {
@@ -933,6 +943,21 @@ var OpenAIChatLanguageModel = class {
return;
}
const delta = choice.delta;
@ -62,7 +62,7 @@ index 992c85ac6656e51c3471af741583533c5a7bf79f..83c05952a07aebb95fc6c62f9ddb8aa9
if (delta.content != null) {
if (!isActiveText) {
controller.enqueue({ type: "text-start", id: "0" });
@@ -1035,6 +1060,9 @@ var OpenAIChatLanguageModel = class {
@@ -1045,6 +1070,9 @@ var OpenAIChatLanguageModel = class {
}
},
flush(controller) {

View File

@ -14,7 +14,7 @@
}
},
"enabled": true,
"includes": ["**/*.json", "!*.json", "!**/package.json"]
"includes": ["**/*.json", "!*.json", "!**/package.json", "!coverage/**"]
},
"css": {
"formatter": {
@ -23,7 +23,7 @@
},
"files": {
"ignoreUnknown": false,
"includes": ["**", "!**/.claude/**"],
"includes": ["**", "!**/.claude/**", "!**/.vscode/**"],
"maxSize": 2097152
},
"formatter": {

View File

@ -135,58 +135,66 @@ artifactBuildCompleted: scripts/artifact-build-completed.js
releaseInfo:
releaseNotes: |
<!--LANG:en-->
What's New in v1.7.0-rc.1
🎉 MAJOR NEW FEATURE: AI Agents
- Create and manage custom AI agents with specialized tools and permissions
- Dedicated agent sessions with persistent SQLite storage, separate from regular chats
- Real-time tool approval system - review and approve agent actions dynamically
- MCP (Model Context Protocol) integration for connecting external tools
- Slash commands support for quick agent interactions
- OpenAI-compatible REST API for agent access
What's New in v1.7.0-rc.2
✨ New Features:
- AI Providers: Added support for Hugging Face, Mistral, Perplexity, and SophNet
- Knowledge Base: OpenMinerU document preprocessor, full-text search in notes, enhanced tool selection
- Image & OCR: Intel OVMS painting provider and Intel OpenVINO (NPU) OCR support
- MCP Management: Redesigned interface with dual-column layout for easier management
- Languages: Added German language support
⚡ Improvements:
- Upgraded to Electron 38.7.0
- Enhanced system shutdown handling and automatic update checks
- Improved proxy bypass rules
- AI Models: Added support for Gemini 3, Gemini 3 Pro with image preview, and GPT-5.1
- Import: ChatGPT conversation import feature
- Agent: Git Bash detection and requirement check for Windows agents
- Search: Native language emoji search with CLDR data format
- Provider: Endpoint type support for cherryin provider
- Debug: Local crash mini dump file for better diagnostics
🐛 Important Bug Fixes:
- Fixed streaming response issues across multiple AI providers
- Fixed session list scrolling problems
- Fixed knowledge base deletion errors
- Error Handling: Improved error display in AiSdkToChunkAdapter
- Database: Optimized DatabaseManager and fixed libsql crash issues
- Memory: Fixed EventEmitter memory leak in useApiServer hook
- Messages: Fixed adjacent user messages appearing when assistant message contains error only
- Tools: Fixed missing execution state for approved tool permissions
- File Processing: Fixed "no such file" error for non-English filenames in open-mineru
- PDF: Fixed mineru PDF validation and 403 errors
- Images: Fixed base64 image save issues
- Search: Fixed URL context and web search capability
- Models: Added verbosity parameter support for GPT-5 models
- UI: Improved todo tool status icon visibility and colors
- Providers: Fixed api-host for vercel ai-gateway and gitcode update config
⚡ Improvements:
- SDK: Updated Google and OpenAI SDKs with new features
- UI: Simplified knowledge base creation modal and agent creation form
- Tools: Replaced renderToolContent function with ToolContent component
- Architecture: Namespace tool call IDs with session ID to prevent conflicts
- Config: AI SDK configuration refactoring
<!--LANG:zh-CN-->
v1.7.0-rc.1 新特性
🎉 重大更新AI Agent 智能体系统
- 创建和管理专属 AI Agent配置专用工具和权限
- 独立的 Agent 会话,使用 SQLite 持久化存储,与普通聊天分离
- 实时工具审批系统 - 动态审查和批准 Agent 操作
- MCP模型上下文协议集成连接外部工具
- 支持斜杠命令快速交互
- 兼容 OpenAI 的 REST API 访问
v1.7.0-rc.2 新特性
✨ 新功能:
- AI 提供商:新增 Hugging Face、Mistral、Perplexity 和 SophNet 支持
- 知识库OpenMinerU 文档预处理器、笔记全文搜索、增强的工具选择
- 图像与 OCRIntel OVMS 绘图提供商和 Intel OpenVINO (NPU) OCR 支持
- MCP 管理:重构管理界面,采用双列布局,更加方便管理
- 语言:新增德语支持
⚡ 改进:
- 升级到 Electron 38.7.0
- 增强的系统关机处理和自动更新检查
- 改进的代理绕过规则
- AI 模型:新增 Gemini 3、Gemini 3 Pro 图像预览支持,以及 GPT-5.1
- 导入ChatGPT 对话导入功能
- AgentWindows Agent 的 Git Bash 检测和要求检查
- 搜索:支持本地语言 emoji 搜索CLDR 数据格式)
- 提供商cherryin provider 的端点类型支持
- 调试:启用本地崩溃 mini dump 文件,方便诊断
🐛 重要修复:
- 修复多个 AI 提供商的流式响应问题
- 修复会话列表滚动问题
- 修复知识库删除错误
- 错误处理:改进 AiSdkToChunkAdapter 的错误显示
- 数据库:优化 DatabaseManager 并修复 libsql 崩溃问题
- 内存:修复 useApiServer hook 中的 EventEmitter 内存泄漏
- 消息:修复当助手消息仅包含错误时相邻用户消息出现的问题
- 工具:修复批准工具权限缺少执行状态的问题
- 文件处理:修复 open-mineru 处理非英文文件名时的"无此文件"错误
- PDF修复 mineru PDF 验证和 403 错误
- 图片:修复 base64 图片保存问题
- 搜索:修复 URL 上下文和网络搜索功能
- 模型:为 GPT-5 模型添加 verbosity 参数支持
- UI改进 todo 工具状态图标可见性和颜色
- 提供商:修复 vercel ai-gateway 和 gitcode 更新配置的 api-host
⚡ 改进:
- SDK更新 Google 和 OpenAI SDK新增功能和修复
- UI简化知识库创建模态框和 agent 创建表单
- 工具:用 ToolContent 组件替换 renderToolContent 函数,提升可读性
- 架构:用会话 ID 命名工具调用 ID 以防止冲突
- 配置AI SDK 配置重构
<!--LANG:END-->

View File

@ -113,16 +113,17 @@
"@agentic/exa": "^7.3.3",
"@agentic/searxng": "^7.3.3",
"@agentic/tavily": "^7.3.3",
"@ai-sdk/amazon-bedrock": "^3.0.53",
"@ai-sdk/anthropic": "^2.0.44",
"@ai-sdk/amazon-bedrock": "^3.0.56",
"@ai-sdk/anthropic": "^2.0.45",
"@ai-sdk/cerebras": "^1.0.31",
"@ai-sdk/gateway": "^2.0.9",
"@ai-sdk/google": "patch:@ai-sdk/google@npm%3A2.0.36#~/.yarn/patches/@ai-sdk-google-npm-2.0.36-6f3cc06026.patch",
"@ai-sdk/google-vertex": "^3.0.68",
"@ai-sdk/huggingface": "patch:@ai-sdk/huggingface@npm%3A0.0.8#~/.yarn/patches/@ai-sdk-huggingface-npm-0.0.8-d4d0aaac93.patch",
"@ai-sdk/mistral": "^2.0.23",
"@ai-sdk/openai": "patch:@ai-sdk/openai@npm%3A2.0.64#~/.yarn/patches/@ai-sdk-openai-npm-2.0.64-48f99f5bf3.patch",
"@ai-sdk/perplexity": "^2.0.17",
"@ai-sdk/gateway": "^2.0.13",
"@ai-sdk/google": "patch:@ai-sdk/google@npm%3A2.0.40#~/.yarn/patches/@ai-sdk-google-npm-2.0.40-47e0eeee83.patch",
"@ai-sdk/google-vertex": "^3.0.72",
"@ai-sdk/huggingface": "^0.0.10",
"@ai-sdk/mistral": "^2.0.24",
"@ai-sdk/openai": "patch:@ai-sdk/openai@npm%3A2.0.71#~/.yarn/patches/@ai-sdk-openai-npm-2.0.71-a88ef00525.patch",
"@ai-sdk/perplexity": "^2.0.20",
"@ai-sdk/test-server": "^0.0.1",
"@ant-design/v5-patch-for-react-19": "^1.0.3",
"@anthropic-ai/sdk": "^0.41.0",
"@anthropic-ai/vertex-sdk": "patch:@anthropic-ai/vertex-sdk@npm%3A0.11.4#~/.yarn/patches/@anthropic-ai-vertex-sdk-npm-0.11.4-c19cb41edb.patch",
@ -167,7 +168,7 @@
"@modelcontextprotocol/sdk": "^1.17.5",
"@mozilla/readability": "^0.6.0",
"@notionhq/client": "^2.2.15",
"@openrouter/ai-sdk-provider": "^1.2.0",
"@openrouter/ai-sdk-provider": "^1.2.5",
"@opentelemetry/api": "^1.9.0",
"@opentelemetry/core": "2.0.0",
"@opentelemetry/exporter-trace-otlp-http": "^0.200.0",
@ -244,7 +245,7 @@
"@viz-js/lang-dot": "^1.0.5",
"@viz-js/viz": "^3.14.0",
"@xyflow/react": "^12.4.4",
"ai": "^5.0.90",
"ai": "^5.0.98",
"antd": "patch:antd@npm%3A5.27.0#~/.yarn/patches/antd-npm-5.27.0-aa91c36546.patch",
"archiver": "^7.0.1",
"async-mutex": "^0.5.0",
@ -417,8 +418,11 @@
"@langchain/openai@npm:^0.3.16": "patch:@langchain/openai@npm%3A1.0.0#~/.yarn/patches/@langchain-openai-npm-1.0.0-474d0ad9d4.patch",
"@langchain/openai@npm:>=0.2.0 <0.7.0": "patch:@langchain/openai@npm%3A1.0.0#~/.yarn/patches/@langchain-openai-npm-1.0.0-474d0ad9d4.patch",
"@ai-sdk/openai@npm:2.0.64": "patch:@ai-sdk/openai@npm%3A2.0.64#~/.yarn/patches/@ai-sdk-openai-npm-2.0.64-48f99f5bf3.patch",
"@ai-sdk/openai@npm:^2.0.42": "patch:@ai-sdk/openai@npm%3A2.0.64#~/.yarn/patches/@ai-sdk-openai-npm-2.0.64-48f99f5bf3.patch",
"@ai-sdk/google@npm:2.0.36": "patch:@ai-sdk/google@npm%3A2.0.36#~/.yarn/patches/@ai-sdk-google-npm-2.0.36-6f3cc06026.patch"
"@ai-sdk/openai@npm:^2.0.42": "patch:@ai-sdk/openai@npm%3A2.0.71#~/.yarn/patches/@ai-sdk-openai-npm-2.0.71-a88ef00525.patch",
"@ai-sdk/google@npm:2.0.40": "patch:@ai-sdk/google@npm%3A2.0.40#~/.yarn/patches/@ai-sdk-google-npm-2.0.40-47e0eeee83.patch",
"@ai-sdk/openai@npm:2.0.71": "patch:@ai-sdk/openai@npm%3A2.0.71#~/.yarn/patches/@ai-sdk-openai-npm-2.0.71-a88ef00525.patch",
"@ai-sdk/openai-compatible@npm:1.0.27": "patch:@ai-sdk/openai-compatible@npm%3A1.0.27#~/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.27-06f74278cf.patch",
"@ai-sdk/openai-compatible@npm:^1.0.19": "patch:@ai-sdk/openai-compatible@npm%3A1.0.27#~/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.27-06f74278cf.patch"
},
"packageManager": "yarn@4.9.1",
"lint-staged": {

View File

@ -42,7 +42,7 @@
},
"dependencies": {
"@ai-sdk/provider": "^2.0.0",
"@ai-sdk/provider-utils": "^3.0.12"
"@ai-sdk/provider-utils": "^3.0.17"
},
"devDependencies": {
"tsdown": "^0.13.3",

View File

@ -39,13 +39,13 @@
"ai": "^5.0.26"
},
"dependencies": {
"@ai-sdk/anthropic": "^2.0.43",
"@ai-sdk/azure": "^2.0.66",
"@ai-sdk/deepseek": "^1.0.27",
"@ai-sdk/openai-compatible": "^1.0.26",
"@ai-sdk/anthropic": "^2.0.45",
"@ai-sdk/azure": "^2.0.73",
"@ai-sdk/deepseek": "^1.0.29",
"@ai-sdk/openai-compatible": "patch:@ai-sdk/openai-compatible@npm%3A1.0.27#~/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.27-06f74278cf.patch",
"@ai-sdk/provider": "^2.0.0",
"@ai-sdk/provider-utils": "^3.0.16",
"@ai-sdk/xai": "^2.0.31",
"@ai-sdk/provider-utils": "^3.0.17",
"@ai-sdk/xai": "^2.0.34",
"zod": "^4.1.5"
},
"devDependencies": {

View File

@ -0,0 +1,180 @@
/**
* Mock Provider Instances
* Provides mock implementations for all supported AI providers
*/
import type { ImageModelV2, LanguageModelV2 } from '@ai-sdk/provider'
import { vi } from 'vitest'
/**
* Creates a mock language model with customizable behavior
*/
export function createMockLanguageModel(overrides?: Partial<LanguageModelV2>): LanguageModelV2 {
return {
specificationVersion: 'v1',
provider: 'mock-provider',
modelId: 'mock-model',
defaultObjectGenerationMode: 'tool',
doGenerate: vi.fn().mockResolvedValue({
text: 'Mock response text',
finishReason: 'stop',
usage: {
promptTokens: 10,
completionTokens: 20,
totalTokens: 30
},
rawCall: { rawPrompt: null, rawSettings: {} },
rawResponse: { headers: {} },
warnings: []
}),
doStream: vi.fn().mockReturnValue({
stream: (async function* () {
yield {
type: 'text-delta',
textDelta: 'Mock '
}
yield {
type: 'text-delta',
textDelta: 'streaming '
}
yield {
type: 'text-delta',
textDelta: 'response'
}
yield {
type: 'finish',
finishReason: 'stop',
usage: {
promptTokens: 10,
completionTokens: 15,
totalTokens: 25
}
}
})(),
rawCall: { rawPrompt: null, rawSettings: {} },
rawResponse: { headers: {} },
warnings: []
}),
...overrides
} as LanguageModelV2
}
/**
* Creates a mock image model with customizable behavior
*/
export function createMockImageModel(overrides?: Partial<ImageModelV2>): ImageModelV2 {
return {
specificationVersion: 'v2',
provider: 'mock-provider',
modelId: 'mock-image-model',
doGenerate: vi.fn().mockResolvedValue({
images: [
{
base64: 'mock-base64-image-data',
uint8Array: new Uint8Array([1, 2, 3, 4, 5]),
mimeType: 'image/png'
}
],
warnings: []
}),
...overrides
} as ImageModelV2
}
/**
* Mock provider configurations for testing
*/
export const mockProviderConfigs = {
openai: {
apiKey: 'sk-test-openai-key-123456789',
baseURL: 'https://api.openai.com/v1',
organization: 'test-org'
},
anthropic: {
apiKey: 'sk-ant-test-key-123456789',
baseURL: 'https://api.anthropic.com'
},
google: {
apiKey: 'test-google-api-key-123456789',
baseURL: 'https://generativelanguage.googleapis.com/v1'
},
xai: {
apiKey: 'xai-test-key-123456789',
baseURL: 'https://api.x.ai/v1'
},
azure: {
apiKey: 'test-azure-key-123456789',
resourceName: 'test-resource',
deployment: 'test-deployment'
},
deepseek: {
apiKey: 'sk-test-deepseek-key-123456789',
baseURL: 'https://api.deepseek.com/v1'
},
openrouter: {
apiKey: 'sk-or-test-key-123456789',
baseURL: 'https://openrouter.ai/api/v1'
},
huggingface: {
apiKey: 'hf_test_key_123456789',
baseURL: 'https://api-inference.huggingface.co'
},
'openai-compatible': {
apiKey: 'test-compatible-key-123456789',
baseURL: 'https://api.example.com/v1',
name: 'test-provider'
},
'openai-chat': {
apiKey: 'sk-test-chat-key-123456789',
baseURL: 'https://api.openai.com/v1'
}
} as const
/**
* Mock provider instances for testing
*/
export const mockProviderInstances = {
openai: {
name: 'openai-mock',
languageModel: createMockLanguageModel({ provider: 'openai', modelId: 'gpt-4' }),
imageModel: createMockImageModel({ provider: 'openai', modelId: 'dall-e-3' })
},
anthropic: {
name: 'anthropic-mock',
languageModel: createMockLanguageModel({ provider: 'anthropic', modelId: 'claude-3-5-sonnet-20241022' })
},
google: {
name: 'google-mock',
languageModel: createMockLanguageModel({ provider: 'google', modelId: 'gemini-2.0-flash-exp' }),
imageModel: createMockImageModel({ provider: 'google', modelId: 'imagen-3.0-generate-001' })
},
xai: {
name: 'xai-mock',
languageModel: createMockLanguageModel({ provider: 'xai', modelId: 'grok-2-latest' }),
imageModel: createMockImageModel({ provider: 'xai', modelId: 'grok-2-image-latest' })
},
deepseek: {
name: 'deepseek-mock',
languageModel: createMockLanguageModel({ provider: 'deepseek', modelId: 'deepseek-chat' })
}
}
export type ProviderId = keyof typeof mockProviderConfigs

View File

@ -0,0 +1,331 @@
/**
* Mock Responses
* Provides realistic mock responses for all provider types
*/
import { jsonSchema, type ModelMessage, type Tool } from 'ai'
/**
* Standard test messages for all scenarios
*/
export const testMessages = {
simple: [{ role: 'user' as const, content: 'Hello, how are you?' }],
conversation: [
{ role: 'user' as const, content: 'What is the capital of France?' },
{ role: 'assistant' as const, content: 'The capital of France is Paris.' },
{ role: 'user' as const, content: 'What is its population?' }
],
withSystem: [
{ role: 'system' as const, content: 'You are a helpful assistant that provides concise answers.' },
{ role: 'user' as const, content: 'Explain quantum computing in one sentence.' }
],
withImages: [
{
role: 'user' as const,
content: [
{ type: 'text' as const, text: 'What is in this image?' },
{
type: 'image' as const,
image:
'data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=='
}
]
}
],
toolUse: [{ role: 'user' as const, content: 'What is the weather in San Francisco?' }],
multiTurn: [
{ role: 'user' as const, content: 'Can you help me with a math problem?' },
{ role: 'assistant' as const, content: 'Of course! What math problem would you like help with?' },
{ role: 'user' as const, content: 'What is 15 * 23?' },
{ role: 'assistant' as const, content: '15 * 23 = 345' },
{ role: 'user' as const, content: 'Now divide that by 5' }
]
} satisfies Record<string, ModelMessage[]>
/**
* Standard test tools for tool calling scenarios
*/
export const testTools: Record<string, Tool> = {
getWeather: {
description: 'Get the current weather in a given location',
inputSchema: jsonSchema({
type: 'object',
properties: {
location: {
type: 'string',
description: 'The city and state, e.g. San Francisco, CA'
},
unit: {
type: 'string',
enum: ['celsius', 'fahrenheit'],
description: 'The temperature unit to use'
}
},
required: ['location']
}),
execute: async ({ location, unit = 'fahrenheit' }) => {
return {
location,
temperature: unit === 'celsius' ? 22 : 72,
unit,
condition: 'sunny'
}
}
},
calculate: {
description: 'Perform a mathematical calculation',
inputSchema: jsonSchema({
type: 'object',
properties: {
operation: {
type: 'string',
enum: ['add', 'subtract', 'multiply', 'divide'],
description: 'The operation to perform'
},
a: {
type: 'number',
description: 'The first number'
},
b: {
type: 'number',
description: 'The second number'
}
},
required: ['operation', 'a', 'b']
}),
execute: async ({ operation, a, b }) => {
const operations = {
add: (x: number, y: number) => x + y,
subtract: (x: number, y: number) => x - y,
multiply: (x: number, y: number) => x * y,
divide: (x: number, y: number) => x / y
}
return { result: operations[operation as keyof typeof operations](a, b) }
}
},
searchDatabase: {
description: 'Search for information in a database',
inputSchema: jsonSchema({
type: 'object',
properties: {
query: {
type: 'string',
description: 'The search query'
},
limit: {
type: 'number',
description: 'Maximum number of results to return',
default: 10
}
},
required: ['query']
}),
execute: async ({ query, limit = 10 }) => {
return {
results: [
{ id: 1, title: `Result 1 for ${query}`, relevance: 0.95 },
{ id: 2, title: `Result 2 for ${query}`, relevance: 0.87 }
].slice(0, limit)
}
}
}
}
/**
* Mock streaming chunks for different providers
*/
export const mockStreamingChunks = {
text: [
{ type: 'text-delta' as const, textDelta: 'Hello' },
{ type: 'text-delta' as const, textDelta: ', ' },
{ type: 'text-delta' as const, textDelta: 'this ' },
{ type: 'text-delta' as const, textDelta: 'is ' },
{ type: 'text-delta' as const, textDelta: 'a ' },
{ type: 'text-delta' as const, textDelta: 'test.' }
],
withToolCall: [
{ type: 'text-delta' as const, textDelta: 'Let me check the weather for you.' },
{
type: 'tool-call-delta' as const,
toolCallType: 'function' as const,
toolCallId: 'call_123',
toolName: 'getWeather',
argsTextDelta: '{"location":'
},
{
type: 'tool-call-delta' as const,
toolCallType: 'function' as const,
toolCallId: 'call_123',
toolName: 'getWeather',
argsTextDelta: ' "San Francisco, CA"}'
},
{
type: 'tool-call' as const,
toolCallType: 'function' as const,
toolCallId: 'call_123',
toolName: 'getWeather',
args: { location: 'San Francisco, CA' }
}
],
withFinish: [
{ type: 'text-delta' as const, textDelta: 'Complete response.' },
{
type: 'finish' as const,
finishReason: 'stop' as const,
usage: {
promptTokens: 10,
completionTokens: 5,
totalTokens: 15
}
}
]
}
/**
* Mock complete responses for non-streaming scenarios
*/
export const mockCompleteResponses = {
simple: {
text: 'This is a simple response.',
finishReason: 'stop' as const,
usage: {
promptTokens: 15,
completionTokens: 8,
totalTokens: 23
}
},
withToolCalls: {
text: 'I will check the weather for you.',
toolCalls: [
{
toolCallId: 'call_456',
toolName: 'getWeather',
args: { location: 'New York, NY', unit: 'celsius' }
}
],
finishReason: 'tool-calls' as const,
usage: {
promptTokens: 25,
completionTokens: 12,
totalTokens: 37
}
},
withWarnings: {
text: 'Response with warnings.',
finishReason: 'stop' as const,
usage: {
promptTokens: 10,
completionTokens: 5,
totalTokens: 15
},
warnings: [
{
type: 'unsupported-setting' as const,
message: 'Temperature parameter not supported for this model'
}
]
}
}
/**
* Mock image generation responses
*/
export const mockImageResponses = {
single: {
image: {
base64: 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==',
uint8Array: new Uint8Array([137, 80, 78, 71, 13, 10, 26, 10, 0, 0, 0, 13, 73, 72, 68, 82]),
mimeType: 'image/png' as const
},
warnings: []
},
multiple: {
images: [
{
base64: 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==',
uint8Array: new Uint8Array([137, 80, 78, 71]),
mimeType: 'image/png' as const
},
{
base64: 'iVBORw0KGgoAAAANSUhEUgAAAAIAAAACCAYAAABytg0kAAAAEklEQVR42mNk+M9QzwAEjDAGACCKAgdZ9zImAAAAAElFTkSuQmCC',
uint8Array: new Uint8Array([137, 80, 78, 71]),
mimeType: 'image/png' as const
}
],
warnings: []
},
withProviderMetadata: {
image: {
base64: 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==',
uint8Array: new Uint8Array([137, 80, 78, 71]),
mimeType: 'image/png' as const
},
providerMetadata: {
openai: {
images: [
{
revisedPrompt: 'A detailed and enhanced version of the original prompt'
}
]
}
},
warnings: []
}
}
/**
* Mock error responses
*/
export const mockErrors = {
invalidApiKey: {
name: 'APIError',
message: 'Invalid API key provided',
statusCode: 401
},
rateLimitExceeded: {
name: 'RateLimitError',
message: 'Rate limit exceeded. Please try again later.',
statusCode: 429,
headers: {
'retry-after': '60'
}
},
modelNotFound: {
name: 'ModelNotFoundError',
message: 'The requested model was not found',
statusCode: 404
},
contextLengthExceeded: {
name: 'ContextLengthError',
message: "This model's maximum context length is 4096 tokens",
statusCode: 400
},
timeout: {
name: 'TimeoutError',
message: 'Request timed out after 30000ms',
code: 'ETIMEDOUT'
},
networkError: {
name: 'NetworkError',
message: 'Network connection failed',
code: 'ECONNREFUSED'
}
}

View File

@ -0,0 +1,329 @@
/**
* Provider-Specific Test Utilities
* Helper functions for testing individual providers with all their parameters
*/
import type { Tool } from 'ai'
import { expect } from 'vitest'
/**
* Provider parameter configurations for comprehensive testing
*/
export const providerParameterMatrix = {
openai: {
models: ['gpt-4', 'gpt-4-turbo', 'gpt-3.5-turbo', 'gpt-4o'],
parameters: {
temperature: [0, 0.5, 0.7, 1.0, 1.5, 2.0],
maxTokens: [100, 500, 1000, 2000, 4000],
topP: [0.1, 0.5, 0.9, 1.0],
frequencyPenalty: [-2.0, -1.0, 0, 1.0, 2.0],
presencePenalty: [-2.0, -1.0, 0, 1.0, 2.0],
stop: [undefined, ['stop'], ['STOP', 'END']],
seed: [undefined, 12345, 67890],
responseFormat: [undefined, { type: 'json_object' as const }],
user: [undefined, 'test-user-123']
},
toolChoice: ['auto', 'required', 'none', { type: 'function' as const, name: 'getWeather' }],
parallelToolCalls: [true, false]
},
anthropic: {
models: ['claude-3-5-sonnet-20241022', 'claude-3-opus-20240229', 'claude-3-haiku-20240307'],
parameters: {
temperature: [0, 0.5, 1.0],
maxTokens: [100, 1000, 4000, 8000],
topP: [0.1, 0.5, 0.9, 1.0],
topK: [undefined, 1, 5, 10, 40],
stop: [undefined, ['Human:', 'Assistant:']],
metadata: [undefined, { userId: 'test-123' }]
},
toolChoice: ['auto', 'any', { type: 'tool' as const, name: 'getWeather' }]
},
google: {
models: ['gemini-2.0-flash-exp', 'gemini-1.5-pro', 'gemini-1.5-flash'],
parameters: {
temperature: [0, 0.5, 0.9, 1.0],
maxTokens: [100, 1000, 2000, 8000],
topP: [0.1, 0.5, 0.95, 1.0],
topK: [undefined, 1, 16, 40],
stopSequences: [undefined, ['END'], ['STOP', 'TERMINATE']]
},
safetySettings: [
undefined,
[
{ category: 'HARM_CATEGORY_HARASSMENT', threshold: 'BLOCK_MEDIUM_AND_ABOVE' },
{ category: 'HARM_CATEGORY_HATE_SPEECH', threshold: 'BLOCK_ONLY_HIGH' }
]
]
},
xai: {
models: ['grok-2-latest', 'grok-2-1212'],
parameters: {
temperature: [0, 0.5, 1.0, 1.5],
maxTokens: [100, 500, 2000, 4000],
topP: [0.1, 0.5, 0.9, 1.0],
stop: [undefined, ['STOP'], ['END', 'TERMINATE']],
seed: [undefined, 12345]
}
},
deepseek: {
models: ['deepseek-chat', 'deepseek-coder'],
parameters: {
temperature: [0, 0.5, 1.0],
maxTokens: [100, 1000, 4000],
topP: [0.1, 0.5, 0.95],
frequencyPenalty: [0, 0.5, 1.0],
presencePenalty: [0, 0.5, 1.0],
stop: [undefined, ['```'], ['END']]
}
},
azure: {
deployments: ['gpt-4-deployment', 'gpt-35-turbo-deployment'],
parameters: {
temperature: [0, 0.7, 1.0],
maxTokens: [100, 1000, 2000],
topP: [0.1, 0.5, 0.95],
frequencyPenalty: [0, 1.0],
presencePenalty: [0, 1.0],
stop: [undefined, ['STOP']]
}
}
} as const
/**
* Creates test cases for all parameter combinations
*/
export function generateParameterTestCases<T extends Record<string, any[]>>(
params: T,
maxCombinations = 50
): Array<Partial<{ [K in keyof T]: T[K][number] }>> {
const keys = Object.keys(params) as Array<keyof T>
const testCases: Array<Partial<{ [K in keyof T]: T[K][number] }>> = []
// Generate combinations using sampling strategy for large parameter spaces
const totalCombinations = keys.reduce((acc, key) => acc * params[key].length, 1)
if (totalCombinations <= maxCombinations) {
// Generate all combinations if total is small
generateAllCombinations(params, keys, 0, {}, testCases)
} else {
// Sample diverse combinations if total is large
generateSampledCombinations(params, keys, maxCombinations, testCases)
}
return testCases
}
function generateAllCombinations<T extends Record<string, any[]>>(
params: T,
keys: Array<keyof T>,
index: number,
current: Partial<{ [K in keyof T]: T[K][number] }>,
results: Array<Partial<{ [K in keyof T]: T[K][number] }>>
) {
if (index === keys.length) {
results.push({ ...current })
return
}
const key = keys[index]
for (const value of params[key]) {
generateAllCombinations(params, keys, index + 1, { ...current, [key]: value }, results)
}
}
function generateSampledCombinations<T extends Record<string, any[]>>(
params: T,
keys: Array<keyof T>,
count: number,
results: Array<Partial<{ [K in keyof T]: T[K][number] }>>
) {
// Generate edge cases first (min/max values)
const edgeCase1: any = {}
const edgeCase2: any = {}
for (const key of keys) {
edgeCase1[key] = params[key][0]
edgeCase2[key] = params[key][params[key].length - 1]
}
results.push(edgeCase1, edgeCase2)
// Generate random combinations for the rest
for (let i = results.length; i < count; i++) {
const combination: any = {}
for (const key of keys) {
const values = params[key]
combination[key] = values[Math.floor(Math.random() * values.length)]
}
results.push(combination)
}
}
/**
* Validates that all provider-specific parameters are correctly passed through
*/
export function validateProviderParams(providerId: string, actualParams: any, expectedParams: any): void {
const requiredFields: Record<string, string[]> = {
openai: ['model', 'messages'],
anthropic: ['model', 'messages'],
google: ['model', 'contents'],
xai: ['model', 'messages'],
deepseek: ['model', 'messages'],
azure: ['messages']
}
const fields = requiredFields[providerId] || ['model', 'messages']
for (const field of fields) {
expect(actualParams).toHaveProperty(field)
}
// Validate optional parameters if they were provided
const optionalParams = ['temperature', 'max_tokens', 'top_p', 'stop', 'tools']
for (const param of optionalParams) {
if (expectedParams[param] !== undefined) {
expect(actualParams[param]).toEqual(expectedParams[param])
}
}
}
/**
* Creates a comprehensive test suite for a provider
*/
// oxlint-disable-next-line no-unused-vars
export function createProviderTestSuite(_providerId: string) {
return {
testBasicCompletion: async (executor: any, model: string) => {
const result = await executor.generateText({
model,
messages: [{ role: 'user' as const, content: 'Hello' }]
})
expect(result).toBeDefined()
expect(result.text).toBeDefined()
expect(typeof result.text).toBe('string')
},
testStreaming: async (executor: any, model: string) => {
const chunks: any[] = []
const result = await executor.streamText({
model,
messages: [{ role: 'user' as const, content: 'Hello' }]
})
for await (const chunk of result.textStream) {
chunks.push(chunk)
}
expect(chunks.length).toBeGreaterThan(0)
},
testTemperature: async (executor: any, model: string, temperatures: number[]) => {
for (const temperature of temperatures) {
const result = await executor.generateText({
model,
messages: [{ role: 'user' as const, content: 'Hello' }],
temperature
})
expect(result).toBeDefined()
}
},
testMaxTokens: async (executor: any, model: string, maxTokensValues: number[]) => {
for (const maxTokens of maxTokensValues) {
const result = await executor.generateText({
model,
messages: [{ role: 'user' as const, content: 'Hello' }],
maxTokens
})
expect(result).toBeDefined()
if (result.usage?.completionTokens) {
expect(result.usage.completionTokens).toBeLessThanOrEqual(maxTokens)
}
}
},
testToolCalling: async (executor: any, model: string, tools: Record<string, Tool>) => {
const result = await executor.generateText({
model,
messages: [{ role: 'user' as const, content: 'What is the weather in SF?' }],
tools
})
expect(result).toBeDefined()
},
testStopSequences: async (executor: any, model: string, stopSequences: string[][]) => {
for (const stop of stopSequences) {
const result = await executor.generateText({
model,
messages: [{ role: 'user' as const, content: 'Count to 10' }],
stop
})
expect(result).toBeDefined()
}
}
}
}
/**
* Generates test data for vision/multimodal testing
*/
export function createVisionTestData() {
return {
imageUrl: 'https://example.com/test-image.jpg',
base64Image:
'data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==',
messages: [
{
role: 'user' as const,
content: [
{ type: 'text' as const, text: 'What is in this image?' },
{
type: 'image' as const,
image:
'data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=='
}
]
}
]
}
}
/**
* Creates mock responses for different finish reasons
*/
export function createFinishReasonMocks() {
return {
stop: {
text: 'Complete response.',
finishReason: 'stop' as const,
usage: { promptTokens: 10, completionTokens: 5, totalTokens: 15 }
},
length: {
text: 'Incomplete response due to',
finishReason: 'length' as const,
usage: { promptTokens: 10, completionTokens: 100, totalTokens: 110 }
},
'tool-calls': {
text: 'Calling tools',
finishReason: 'tool-calls' as const,
toolCalls: [{ toolCallId: 'call_1', toolName: 'getWeather', args: { location: 'SF' } }],
usage: { promptTokens: 10, completionTokens: 8, totalTokens: 18 }
},
'content-filter': {
text: '',
finishReason: 'content-filter' as const,
usage: { promptTokens: 10, completionTokens: 0, totalTokens: 10 }
}
}
}

View File

@ -0,0 +1,291 @@
/**
* Test Utilities
* Helper functions for testing AI Core functionality
*/
import { expect, vi } from 'vitest'
import type { ProviderId } from '../fixtures/mock-providers'
import { createMockImageModel, createMockLanguageModel, mockProviderConfigs } from '../fixtures/mock-providers'
/**
* Creates a test provider with streaming support
*/
export function createTestStreamingProvider(chunks: any[]) {
return createMockLanguageModel({
doStream: vi.fn().mockReturnValue({
stream: (async function* () {
for (const chunk of chunks) {
yield chunk
}
})(),
rawCall: { rawPrompt: null, rawSettings: {} },
rawResponse: { headers: {} },
warnings: []
})
})
}
/**
* Creates a test provider that throws errors
*/
export function createErrorProvider(error: Error) {
return createMockLanguageModel({
doGenerate: vi.fn().mockRejectedValue(error),
doStream: vi.fn().mockImplementation(() => {
throw error
})
})
}
/**
* Collects all chunks from a stream
*/
export async function collectStreamChunks<T>(stream: AsyncIterable<T>): Promise<T[]> {
const chunks: T[] = []
for await (const chunk of stream) {
chunks.push(chunk)
}
return chunks
}
/**
* Waits for a specific number of milliseconds
*/
export function wait(ms: number): Promise<void> {
return new Promise((resolve) => setTimeout(resolve, ms))
}
/**
* Creates a mock abort controller that aborts after a delay
*/
export function createDelayedAbortController(delayMs: number): AbortController {
const controller = new AbortController()
setTimeout(() => controller.abort(), delayMs)
return controller
}
/**
* Asserts that a function throws an error with a specific message
*/
export async function expectError(fn: () => Promise<any>, expectedMessage?: string | RegExp): Promise<Error> {
try {
await fn()
throw new Error('Expected function to throw an error, but it did not')
} catch (error) {
if (expectedMessage) {
const message = (error as Error).message
if (typeof expectedMessage === 'string') {
if (!message.includes(expectedMessage)) {
throw new Error(`Expected error message to include "${expectedMessage}", but got "${message}"`)
}
} else {
if (!expectedMessage.test(message)) {
throw new Error(`Expected error message to match ${expectedMessage}, but got "${message}"`)
}
}
}
return error as Error
}
}
/**
* Creates a spy function that tracks calls and arguments
*/
export function createSpy<T extends (...args: any[]) => any>() {
const calls: Array<{ args: Parameters<T>; result?: ReturnType<T>; error?: Error }> = []
const spy = vi.fn((...args: Parameters<T>) => {
try {
const result = undefined as ReturnType<T>
calls.push({ args, result })
return result
} catch (error) {
calls.push({ args, error: error as Error })
throw error
}
})
return {
fn: spy,
calls,
getCalls: () => calls,
getCallCount: () => calls.length,
getLastCall: () => calls[calls.length - 1],
reset: () => {
calls.length = 0
spy.mockClear()
}
}
}
/**
* Validates provider configuration
*/
export function validateProviderConfig(providerId: ProviderId) {
const config = mockProviderConfigs[providerId]
if (!config) {
throw new Error(`No mock configuration found for provider: ${providerId}`)
}
if (!config.apiKey) {
throw new Error(`Provider ${providerId} is missing apiKey in mock config`)
}
return config
}
/**
* Creates a test context with common setup
*/
export function createTestContext() {
const mocks = {
languageModel: createMockLanguageModel(),
imageModel: createMockImageModel(),
providers: new Map<string, any>()
}
const cleanup = () => {
mocks.providers.clear()
vi.clearAllMocks()
}
return {
mocks,
cleanup
}
}
/**
* Measures execution time of an async function
*/
export async function measureTime<T>(fn: () => Promise<T>): Promise<{ result: T; duration: number }> {
const start = Date.now()
const result = await fn()
const duration = Date.now() - start
return { result, duration }
}
/**
* Retries a function until it succeeds or max attempts reached
*/
export async function retryUntilSuccess<T>(fn: () => Promise<T>, maxAttempts = 3, delayMs = 100): Promise<T> {
let lastError: Error | undefined
for (let attempt = 1; attempt <= maxAttempts; attempt++) {
try {
return await fn()
} catch (error) {
lastError = error as Error
if (attempt < maxAttempts) {
await wait(delayMs)
}
}
}
throw lastError || new Error('All retry attempts failed')
}
/**
* Creates a mock streaming response that emits chunks at intervals
*/
export function createTimedStream<T>(chunks: T[], intervalMs = 10) {
return {
async *[Symbol.asyncIterator]() {
for (const chunk of chunks) {
await wait(intervalMs)
yield chunk
}
}
}
}
/**
* Asserts that two objects are deeply equal, ignoring specified keys
*/
export function assertDeepEqualIgnoring<T extends Record<string, any>>(
actual: T,
expected: T,
ignoreKeys: string[] = []
): void {
const filterKeys = (obj: T): Partial<T> => {
const filtered = { ...obj }
for (const key of ignoreKeys) {
delete filtered[key]
}
return filtered
}
const filteredActual = filterKeys(actual)
const filteredExpected = filterKeys(expected)
expect(filteredActual).toEqual(filteredExpected)
}
/**
* Creates a provider mock that simulates rate limiting
*/
export function createRateLimitedProvider(limitPerSecond: number) {
const calls: number[] = []
return createMockLanguageModel({
doGenerate: vi.fn().mockImplementation(async () => {
const now = Date.now()
calls.push(now)
// Remove calls older than 1 second
const recentCalls = calls.filter((time) => now - time < 1000)
if (recentCalls.length > limitPerSecond) {
throw new Error('Rate limit exceeded')
}
return {
text: 'Rate limited response',
finishReason: 'stop' as const,
usage: { promptTokens: 10, completionTokens: 5, totalTokens: 15 },
rawCall: { rawPrompt: null, rawSettings: {} },
rawResponse: { headers: {} },
warnings: []
}
})
})
}
/**
* Validates streaming response structure
*/
export function validateStreamChunk(chunk: any): void {
expect(chunk).toBeDefined()
expect(chunk).toHaveProperty('type')
if (chunk.type === 'text-delta') {
expect(chunk).toHaveProperty('textDelta')
expect(typeof chunk.textDelta).toBe('string')
} else if (chunk.type === 'finish') {
expect(chunk).toHaveProperty('finishReason')
expect(chunk).toHaveProperty('usage')
} else if (chunk.type === 'tool-call') {
expect(chunk).toHaveProperty('toolCallId')
expect(chunk).toHaveProperty('toolName')
expect(chunk).toHaveProperty('args')
}
}
/**
* Creates a test logger that captures log messages
*/
export function createTestLogger() {
const logs: Array<{ level: string; message: string; meta?: any }> = []
return {
info: (message: string, meta?: any) => logs.push({ level: 'info', message, meta }),
warn: (message: string, meta?: any) => logs.push({ level: 'warn', message, meta }),
error: (message: string, meta?: any) => logs.push({ level: 'error', message, meta }),
debug: (message: string, meta?: any) => logs.push({ level: 'debug', message, meta }),
getLogs: () => logs,
clear: () => {
logs.length = 0
}
}
}

View File

@ -0,0 +1,12 @@
/**
* Test Infrastructure Exports
* Central export point for all test utilities, fixtures, and helpers
*/
// Fixtures
export * from './fixtures/mock-providers'
export * from './fixtures/mock-responses'
// Helpers
export * from './helpers/provider-test-utils'
export * from './helpers/test-utils'

View File

@ -0,0 +1,499 @@
/**
* RuntimeExecutor.generateText Comprehensive Tests
* Tests non-streaming text generation across all providers with various parameters
*/
import { generateText } from 'ai'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import {
createMockLanguageModel,
mockCompleteResponses,
mockProviderConfigs,
testMessages,
testTools
} from '../../../__tests__'
import type { AiPlugin } from '../../plugins'
import { globalRegistryManagement } from '../../providers/RegistryManagement'
import { RuntimeExecutor } from '../executor'
// Mock AI SDK
vi.mock('ai', () => ({
generateText: vi.fn()
}))
vi.mock('../../providers/RegistryManagement', () => ({
globalRegistryManagement: {
languageModel: vi.fn()
},
DEFAULT_SEPARATOR: '|'
}))
describe('RuntimeExecutor.generateText', () => {
let executor: RuntimeExecutor<'openai'>
let mockLanguageModel: any
beforeEach(() => {
vi.clearAllMocks()
executor = RuntimeExecutor.create('openai', mockProviderConfigs.openai)
mockLanguageModel = createMockLanguageModel({
provider: 'openai',
modelId: 'gpt-4'
})
vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(mockLanguageModel)
vi.mocked(generateText).mockResolvedValue(mockCompleteResponses.simple as any)
})
describe('Basic Functionality', () => {
it('should generate text with minimal parameters', async () => {
const result = await executor.generateText({
model: 'gpt-4',
messages: testMessages.simple
})
expect(generateText).toHaveBeenCalledWith({
model: mockLanguageModel,
messages: testMessages.simple
})
expect(result.text).toBe('This is a simple response.')
expect(result.finishReason).toBe('stop')
expect(result.usage).toBeDefined()
})
it('should generate with system messages', async () => {
await executor.generateText({
model: 'gpt-4',
messages: testMessages.withSystem
})
expect(generateText).toHaveBeenCalledWith({
model: mockLanguageModel,
messages: testMessages.withSystem
})
})
it('should generate with conversation history', async () => {
await executor.generateText({
model: 'gpt-4',
messages: testMessages.conversation
})
expect(generateText).toHaveBeenCalledWith(
expect.objectContaining({
messages: testMessages.conversation
})
)
})
})
describe('All Parameter Combinations', () => {
it('should support all parameters together', async () => {
await executor.generateText({
model: 'gpt-4',
messages: testMessages.simple,
temperature: 0.7,
maxOutputTokens: 500,
topP: 0.9,
frequencyPenalty: 0.5,
presencePenalty: 0.3,
stopSequences: ['STOP'],
seed: 12345
})
expect(generateText).toHaveBeenCalledWith(
expect.objectContaining({
temperature: 0.7,
maxOutputTokens: 500,
topP: 0.9,
frequencyPenalty: 0.5,
presencePenalty: 0.3,
stopSequences: ['STOP'],
seed: 12345
})
)
})
it('should support partial parameters', async () => {
await executor.generateText({
model: 'gpt-4',
messages: testMessages.simple,
temperature: 0.5,
maxOutputTokens: 100
})
expect(generateText).toHaveBeenCalledWith(
expect.objectContaining({
temperature: 0.5,
maxOutputTokens: 100
})
)
})
})
describe('Tool Calling', () => {
beforeEach(() => {
vi.mocked(generateText).mockResolvedValue(mockCompleteResponses.withToolCalls as any)
})
it('should support tool calling', async () => {
const result = await executor.generateText({
model: 'gpt-4',
messages: testMessages.toolUse,
tools: testTools
})
expect(generateText).toHaveBeenCalledWith(
expect.objectContaining({
tools: testTools
})
)
expect(result.toolCalls).toBeDefined()
expect(result.toolCalls).toHaveLength(1)
})
it('should support toolChoice auto', async () => {
await executor.generateText({
model: 'gpt-4',
messages: testMessages.toolUse,
tools: testTools,
toolChoice: 'auto'
})
expect(generateText).toHaveBeenCalledWith(
expect.objectContaining({
toolChoice: 'auto'
})
)
})
it('should support toolChoice required', async () => {
await executor.generateText({
model: 'gpt-4',
messages: testMessages.toolUse,
tools: testTools,
toolChoice: 'required'
})
expect(generateText).toHaveBeenCalledWith(
expect.objectContaining({
toolChoice: 'required'
})
)
})
it('should support toolChoice none', async () => {
vi.mocked(generateText).mockResolvedValue(mockCompleteResponses.simple as any)
await executor.generateText({
model: 'gpt-4',
messages: testMessages.simple,
tools: testTools,
toolChoice: 'none'
})
expect(generateText).toHaveBeenCalledWith(
expect.objectContaining({
toolChoice: 'none'
})
)
})
it('should support specific tool selection', async () => {
await executor.generateText({
model: 'gpt-4',
messages: testMessages.toolUse,
tools: testTools,
toolChoice: {
type: 'tool',
toolName: 'getWeather'
}
})
expect(generateText).toHaveBeenCalledWith(
expect.objectContaining({
toolChoice: {
type: 'tool',
toolName: 'getWeather'
}
})
)
})
})
describe('Multiple Providers', () => {
it('should work with Anthropic provider', async () => {
const anthropicExecutor = RuntimeExecutor.create('anthropic', mockProviderConfigs.anthropic)
const anthropicModel = createMockLanguageModel({
provider: 'anthropic',
modelId: 'claude-3-5-sonnet-20241022'
})
vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(anthropicModel)
await anthropicExecutor.generateText({
model: 'claude-3-5-sonnet-20241022',
messages: testMessages.simple
})
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('anthropic|claude-3-5-sonnet-20241022')
})
it('should work with Google provider', async () => {
const googleExecutor = RuntimeExecutor.create('google', mockProviderConfigs.google)
const googleModel = createMockLanguageModel({
provider: 'google',
modelId: 'gemini-2.0-flash-exp'
})
vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(googleModel)
await googleExecutor.generateText({
model: 'gemini-2.0-flash-exp',
messages: testMessages.simple
})
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('google|gemini-2.0-flash-exp')
})
it('should work with xAI provider', async () => {
const xaiExecutor = RuntimeExecutor.create('xai', mockProviderConfigs.xai)
const xaiModel = createMockLanguageModel({
provider: 'xai',
modelId: 'grok-2-latest'
})
vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(xaiModel)
await xaiExecutor.generateText({
model: 'grok-2-latest',
messages: testMessages.simple
})
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('xai|grok-2-latest')
})
it('should work with DeepSeek provider', async () => {
const deepseekExecutor = RuntimeExecutor.create('deepseek', mockProviderConfigs.deepseek)
const deepseekModel = createMockLanguageModel({
provider: 'deepseek',
modelId: 'deepseek-chat'
})
vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(deepseekModel)
await deepseekExecutor.generateText({
model: 'deepseek-chat',
messages: testMessages.simple
})
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('deepseek|deepseek-chat')
})
})
describe('Plugin Integration', () => {
it('should execute all plugin hooks', async () => {
const pluginCalls: string[] = []
const testPlugin: AiPlugin = {
name: 'test-plugin',
onRequestStart: vi.fn(async () => {
pluginCalls.push('onRequestStart')
}),
transformParams: vi.fn(async (params) => {
pluginCalls.push('transformParams')
return { ...params, temperature: 0.8 }
}),
transformResult: vi.fn(async (result) => {
pluginCalls.push('transformResult')
return { ...result, text: result.text + ' [modified]' }
}),
onRequestEnd: vi.fn(async () => {
pluginCalls.push('onRequestEnd')
})
}
const executorWithPlugin = RuntimeExecutor.create('openai', mockProviderConfigs.openai, [testPlugin])
const result = await executorWithPlugin.generateText({
model: 'gpt-4',
messages: testMessages.simple
})
expect(pluginCalls).toEqual(['onRequestStart', 'transformParams', 'transformResult', 'onRequestEnd'])
// Verify transformed parameters
expect(generateText).toHaveBeenCalledWith(
expect.objectContaining({
temperature: 0.8
})
)
// Verify transformed result
expect(result.text).toContain('[modified]')
})
it('should handle multiple plugins in order', async () => {
const pluginOrder: string[] = []
const plugin1: AiPlugin = {
name: 'plugin-1',
transformParams: vi.fn(async (params) => {
pluginOrder.push('plugin-1')
return { ...params, temperature: 0.5 }
})
}
const plugin2: AiPlugin = {
name: 'plugin-2',
transformParams: vi.fn(async (params) => {
pluginOrder.push('plugin-2')
return { ...params, maxTokens: 200 }
})
}
const executorWithPlugins = RuntimeExecutor.create('openai', mockProviderConfigs.openai, [plugin1, plugin2])
await executorWithPlugins.generateText({
model: 'gpt-4',
messages: testMessages.simple
})
expect(pluginOrder).toEqual(['plugin-1', 'plugin-2'])
expect(generateText).toHaveBeenCalledWith(
expect.objectContaining({
temperature: 0.5,
maxTokens: 200
})
)
})
})
describe('Error Handling', () => {
it('should handle API errors', async () => {
const error = new Error('API request failed')
vi.mocked(generateText).mockRejectedValue(error)
await expect(
executor.generateText({
model: 'gpt-4',
messages: testMessages.simple
})
).rejects.toThrow('API request failed')
})
it('should execute onError plugin hook', async () => {
const error = new Error('Generation failed')
vi.mocked(generateText).mockRejectedValue(error)
const errorPlugin: AiPlugin = {
name: 'error-handler',
onError: vi.fn()
}
const executorWithPlugin = RuntimeExecutor.create('openai', mockProviderConfigs.openai, [errorPlugin])
await expect(
executorWithPlugin.generateText({
model: 'gpt-4',
messages: testMessages.simple
})
).rejects.toThrow('Generation failed')
expect(errorPlugin.onError).toHaveBeenCalledWith(
error,
expect.objectContaining({
providerId: 'openai',
modelId: 'gpt-4'
})
)
})
it('should handle model not found error', async () => {
const error = new Error('Model not found: invalid-model')
vi.mocked(globalRegistryManagement.languageModel).mockImplementation(() => {
throw error
})
await expect(
executor.generateText({
model: 'invalid-model',
messages: testMessages.simple
})
).rejects.toThrow('Model not found')
})
})
describe('Usage and Metadata', () => {
it('should return usage information', async () => {
const result = await executor.generateText({
model: 'gpt-4',
messages: testMessages.simple
})
expect(result.usage).toBeDefined()
expect(result.usage.inputTokens).toBe(15)
expect(result.usage.outputTokens).toBe(8)
expect(result.usage.totalTokens).toBe(23)
})
it('should handle warnings', async () => {
vi.mocked(generateText).mockResolvedValue(mockCompleteResponses.withWarnings as any)
const result = await executor.generateText({
model: 'gpt-4',
messages: testMessages.simple,
temperature: 2.5 // Unsupported value
})
expect(result.warnings).toBeDefined()
expect(result.warnings).toHaveLength(1)
expect(result.warnings![0].type).toBe('unsupported-setting')
})
})
describe('Abort Signal', () => {
it('should support abort signal', async () => {
const abortController = new AbortController()
await executor.generateText({
model: 'gpt-4',
messages: testMessages.simple,
abortSignal: abortController.signal
})
expect(generateText).toHaveBeenCalledWith(
expect.objectContaining({
abortSignal: abortController.signal
})
)
})
it('should handle aborted request', async () => {
const abortError = new Error('Request aborted')
abortError.name = 'AbortError'
vi.mocked(generateText).mockRejectedValue(abortError)
const abortController = new AbortController()
abortController.abort()
await expect(
executor.generateText({
model: 'gpt-4',
messages: testMessages.simple,
abortSignal: abortController.signal
})
).rejects.toThrow('Request aborted')
})
})
})

View File

@ -0,0 +1,525 @@
/**
* RuntimeExecutor.streamText Comprehensive Tests
* Tests streaming text generation across all providers with various parameters
*/
import { streamText } from 'ai'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { collectStreamChunks, createMockLanguageModel, mockProviderConfigs, testMessages } from '../../../__tests__'
import type { AiPlugin } from '../../plugins'
import { globalRegistryManagement } from '../../providers/RegistryManagement'
import { RuntimeExecutor } from '../executor'
// Mock AI SDK
vi.mock('ai', () => ({
streamText: vi.fn()
}))
vi.mock('../../providers/RegistryManagement', () => ({
globalRegistryManagement: {
languageModel: vi.fn()
},
DEFAULT_SEPARATOR: '|'
}))
describe('RuntimeExecutor.streamText', () => {
let executor: RuntimeExecutor<'openai'>
let mockLanguageModel: any
beforeEach(() => {
vi.clearAllMocks()
executor = RuntimeExecutor.create('openai', mockProviderConfigs.openai)
mockLanguageModel = createMockLanguageModel({
provider: 'openai',
modelId: 'gpt-4'
})
vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(mockLanguageModel)
})
describe('Basic Functionality', () => {
it('should stream text with minimal parameters', async () => {
const mockStream = {
textStream: (async function* () {
yield 'Hello'
yield ' '
yield 'World'
})(),
fullStream: (async function* () {
yield { type: 'text-delta', textDelta: 'Hello' }
yield { type: 'text-delta', textDelta: ' ' }
yield { type: 'text-delta', textDelta: 'World' }
})(),
usage: Promise.resolve({ promptTokens: 5, completionTokens: 3, totalTokens: 8 })
}
vi.mocked(streamText).mockResolvedValue(mockStream as any)
const result = await executor.streamText({
model: 'gpt-4',
messages: testMessages.simple
})
expect(streamText).toHaveBeenCalledWith({
model: mockLanguageModel,
messages: testMessages.simple
})
const chunks = await collectStreamChunks(result.textStream)
expect(chunks).toEqual(['Hello', ' ', 'World'])
})
it('should stream with system messages', async () => {
const mockStream = {
textStream: (async function* () {
yield 'Response'
})(),
fullStream: (async function* () {
yield { type: 'text-delta', textDelta: 'Response' }
})()
}
vi.mocked(streamText).mockResolvedValue(mockStream as any)
await executor.streamText({
model: 'gpt-4',
messages: testMessages.withSystem
})
expect(streamText).toHaveBeenCalledWith({
model: mockLanguageModel,
messages: testMessages.withSystem
})
})
it('should stream multi-turn conversations', async () => {
const mockStream = {
textStream: (async function* () {
yield 'Multi-turn response'
})(),
fullStream: (async function* () {
yield { type: 'text-delta', textDelta: 'Multi-turn response' }
})()
}
vi.mocked(streamText).mockResolvedValue(mockStream as any)
await executor.streamText({
model: 'gpt-4',
messages: testMessages.multiTurn
})
expect(streamText).toHaveBeenCalled()
expect(streamText).toHaveBeenCalledWith(
expect.objectContaining({
messages: testMessages.multiTurn
})
)
})
})
describe('Temperature Parameter', () => {
const temperatures = [0, 0.3, 0.5, 0.7, 0.9, 1.0, 1.5, 2.0]
it.each(temperatures)('should support temperature=%s', async (temperature) => {
const mockStream = {
textStream: (async function* () {
yield 'Response'
})(),
fullStream: (async function* () {
yield { type: 'text-delta', textDelta: 'Response' }
})()
}
vi.mocked(streamText).mockResolvedValue(mockStream as any)
await executor.streamText({
model: 'gpt-4',
messages: testMessages.simple,
temperature
})
expect(streamText).toHaveBeenCalledWith(
expect.objectContaining({
temperature
})
)
})
})
describe('Max Tokens Parameter', () => {
const maxTokensValues = [10, 50, 100, 500, 1000, 2000, 4000]
it.each(maxTokensValues)('should support maxTokens=%s', async (maxTokens) => {
const mockStream = {
textStream: (async function* () {
yield 'Response'
})(),
fullStream: (async function* () {
yield { type: 'text-delta', textDelta: 'Response' }
})()
}
vi.mocked(streamText).mockResolvedValue(mockStream as any)
await executor.streamText({
model: 'gpt-4',
messages: testMessages.simple,
maxOutputTokens: maxTokens
})
expect(streamText).toHaveBeenCalledWith(
expect.objectContaining({
maxTokens
})
)
})
})
describe('Top P Parameter', () => {
const topPValues = [0.1, 0.3, 0.5, 0.7, 0.9, 0.95, 1.0]
it.each(topPValues)('should support topP=%s', async (topP) => {
const mockStream = {
textStream: (async function* () {
yield 'Response'
})(),
fullStream: (async function* () {
yield { type: 'text-delta', textDelta: 'Response' }
})()
}
vi.mocked(streamText).mockResolvedValue(mockStream as any)
await executor.streamText({
model: 'gpt-4',
messages: testMessages.simple,
topP
})
expect(streamText).toHaveBeenCalledWith(
expect.objectContaining({
topP
})
)
})
})
describe('Frequency and Presence Penalty', () => {
it('should support frequency penalty', async () => {
const penalties = [-2.0, -1.0, 0, 0.5, 1.0, 1.5, 2.0]
for (const frequencyPenalty of penalties) {
vi.clearAllMocks()
const mockStream = {
textStream: (async function* () {
yield 'Response'
})(),
fullStream: (async function* () {
yield { type: 'text-delta', textDelta: 'Response' }
})()
}
vi.mocked(streamText).mockResolvedValue(mockStream as any)
await executor.streamText({
model: 'gpt-4',
messages: testMessages.simple,
frequencyPenalty
})
expect(streamText).toHaveBeenCalledWith(
expect.objectContaining({
frequencyPenalty
})
)
}
})
it('should support presence penalty', async () => {
const penalties = [-2.0, -1.0, 0, 0.5, 1.0, 1.5, 2.0]
for (const presencePenalty of penalties) {
vi.clearAllMocks()
const mockStream = {
textStream: (async function* () {
yield 'Response'
})(),
fullStream: (async function* () {
yield { type: 'text-delta', textDelta: 'Response' }
})()
}
vi.mocked(streamText).mockResolvedValue(mockStream as any)
await executor.streamText({
model: 'gpt-4',
messages: testMessages.simple,
presencePenalty
})
expect(streamText).toHaveBeenCalledWith(
expect.objectContaining({
presencePenalty
})
)
}
})
it('should support both penalties together', async () => {
const mockStream = {
textStream: (async function* () {
yield 'Response'
})(),
fullStream: (async function* () {
yield { type: 'text-delta', textDelta: 'Response' }
})()
}
vi.mocked(streamText).mockResolvedValue(mockStream as any)
await executor.streamText({
model: 'gpt-4',
messages: testMessages.simple,
frequencyPenalty: 0.5,
presencePenalty: 0.5
})
expect(streamText).toHaveBeenCalledWith(
expect.objectContaining({
frequencyPenalty: 0.5,
presencePenalty: 0.5
})
)
})
})
describe('Seed Parameter', () => {
it('should support seed for deterministic output', async () => {
const seeds = [0, 12345, 67890, 999999]
for (const seed of seeds) {
vi.clearAllMocks()
const mockStream = {
textStream: (async function* () {
yield 'Response'
})(),
fullStream: (async function* () {
yield { type: 'text-delta', textDelta: 'Response' }
})()
}
vi.mocked(streamText).mockResolvedValue(mockStream as any)
await executor.streamText({
model: 'gpt-4',
messages: testMessages.simple,
seed
})
expect(streamText).toHaveBeenCalledWith(
expect.objectContaining({
seed
})
)
}
})
})
describe('Abort Signal', () => {
it('should support abort signal', async () => {
const abortController = new AbortController()
const mockStream = {
textStream: (async function* () {
yield 'Response'
})(),
fullStream: (async function* () {
yield { type: 'text-delta', textDelta: 'Response' }
})()
}
vi.mocked(streamText).mockResolvedValue(mockStream as any)
await executor.streamText({
model: 'gpt-4',
messages: testMessages.simple,
abortSignal: abortController.signal
})
expect(streamText).toHaveBeenCalledWith(
expect.objectContaining({
abortSignal: abortController.signal
})
)
})
it('should handle abort during streaming', async () => {
const abortController = new AbortController()
const mockStream = {
textStream: (async function* () {
yield 'Start'
// Simulate abort
abortController.abort()
throw new Error('Aborted')
})(),
fullStream: (async function* () {
yield { type: 'text-delta', textDelta: 'Start' }
throw new Error('Aborted')
})()
}
vi.mocked(streamText).mockResolvedValue(mockStream as any)
const result = await executor.streamText({
model: 'gpt-4',
messages: testMessages.simple,
abortSignal: abortController.signal
})
await expect(async () => {
// oxlint-disable-next-line no-unused-vars
for await (const _chunk of result.textStream) {
// Stream should be interrupted
}
}).rejects.toThrow('Aborted')
})
})
describe('Plugin Integration', () => {
it('should execute plugins during streaming', async () => {
const pluginCalls: string[] = []
const testPlugin: AiPlugin = {
name: 'test-plugin',
onRequestStart: vi.fn(async () => {
pluginCalls.push('onRequestStart')
}),
transformParams: vi.fn(async (params) => {
pluginCalls.push('transformParams')
return { ...params, temperature: 0.5 }
}),
onRequestEnd: vi.fn(async () => {
pluginCalls.push('onRequestEnd')
})
}
const executorWithPlugin = RuntimeExecutor.create('openai', mockProviderConfigs.openai, [testPlugin])
const mockStream = {
textStream: (async function* () {
yield 'Response'
})(),
fullStream: (async function* () {
yield { type: 'text-delta', textDelta: 'Response' }
})()
}
vi.mocked(streamText).mockResolvedValue(mockStream as any)
const result = await executorWithPlugin.streamText({
model: 'gpt-4',
messages: testMessages.simple
})
// Consume stream
// oxlint-disable-next-line no-unused-vars
for await (const _chunk of result.textStream) {
// Stream chunks
}
expect(pluginCalls).toContain('onRequestStart')
expect(pluginCalls).toContain('transformParams')
// Verify transformed parameters were used
expect(streamText).toHaveBeenCalledWith(
expect.objectContaining({
temperature: 0.5
})
)
})
})
describe('Full Stream with Finish Reason', () => {
it('should provide finish reason in full stream', async () => {
const mockStream = {
textStream: (async function* () {
yield 'Response'
})(),
fullStream: (async function* () {
yield { type: 'text-delta', textDelta: 'Response' }
yield {
type: 'finish',
finishReason: 'stop',
usage: { promptTokens: 5, completionTokens: 3, totalTokens: 8 }
}
})()
}
vi.mocked(streamText).mockResolvedValue(mockStream as any)
const result = await executor.streamText({
model: 'gpt-4',
messages: testMessages.simple
})
const fullChunks = await collectStreamChunks(result.fullStream)
expect(fullChunks).toHaveLength(2)
expect(fullChunks[0]).toEqual({ type: 'text-delta', textDelta: 'Response' })
expect(fullChunks[1]).toEqual({
type: 'finish',
finishReason: 'stop',
usage: { promptTokens: 5, completionTokens: 3, totalTokens: 8 }
})
})
})
describe('Error Handling', () => {
it('should handle streaming errors', async () => {
const error = new Error('Streaming failed')
vi.mocked(streamText).mockRejectedValue(error)
await expect(
executor.streamText({
model: 'gpt-4',
messages: testMessages.simple
})
).rejects.toThrow('Streaming failed')
})
it('should execute onError plugin hook on failure', async () => {
const error = new Error('Stream error')
vi.mocked(streamText).mockRejectedValue(error)
const errorPlugin: AiPlugin = {
name: 'error-handler',
onError: vi.fn()
}
const executorWithPlugin = RuntimeExecutor.create('openai', mockProviderConfigs.openai, [errorPlugin])
await expect(
executorWithPlugin.streamText({
model: 'gpt-4',
messages: testMessages.simple
})
).rejects.toThrow('Stream error')
expect(errorPlugin.onError).toHaveBeenCalledWith(
error,
expect.objectContaining({
providerId: 'openai',
modelId: 'gpt-4'
})
)
})
})
})

View File

@ -4,3 +4,34 @@ export const defaultAppHeaders = () => {
'X-Title': 'Cherry Studio'
}
}
// Following two function are not being used for now.
// I may use them in the future, so just keep them commented. - by eurfelux
/**
* Converts an `undefined` value to `null`, otherwise returns the value as-is.
* @param value - The value to check
* @returns `null` if the input is `undefined`; otherwise the input value
*/
// export function toNullIfUndefined<T>(value: T | undefined): T | null {
// if (value === undefined) {
// return null
// } else {
// return value
// }
// }
/**
* Converts a `null` value to `undefined`, otherwise returns the value as-is.
* @param value - The value to check
* @returns `undefined` if the input is `null`; otherwise the input value
*/
// export function toUndefinedIfNull<T>(value: T | null): T | undefined {
// if (value === null) {
// return undefined
// } else {
// return value
// }
// }

View File

@ -3,7 +3,6 @@ import { createServer } from 'node:http'
import { loggerService } from '@logger'
import { IpcChannel } from '@shared/IpcChannel'
import { agentService } from '../services/agents'
import { windowService } from '../services/WindowService'
import { app } from './app'
import { config } from './config'
@ -32,11 +31,6 @@ export class ApiServer {
// Load config
const { port, host } = await config.load()
// Initialize AgentService
logger.info('Initializing AgentService')
await agentService.initialize()
logger.info('AgentService initialized')
// Create server with Express app
this.server = createServer(app)
this.applyServerTimeouts(this.server)

View File

@ -44,6 +44,7 @@ import {
import { dataApiService } from '@data/DataApiService'
import { cacheService } from '@data/CacheService'
import { initWebviewHotkeys } from './services/WebviewService'
import { runAsyncFunction } from './utils'
const logger = loggerService.withContext('MainEntry')
@ -256,39 +257,33 @@ if (!app.requestSingleInstanceLock()) {
//start selection assistant service
initSelectionService()
// Initialize Agent Service
try {
await agentService.initialize()
logger.info('Agent service initialized successfully')
} catch (error: any) {
logger.error('Failed to initialize Agent service:', error)
}
runAsyncFunction(async () => {
// Start API server if enabled or if agents exist
try {
const config = await apiServerService.getCurrentConfig()
logger.info('API server config:', config)
// Start API server if enabled or if agents exist
try {
const config = await apiServerService.getCurrentConfig()
logger.info('API server config:', config)
// Check if there are any agents
let shouldStart = config.enabled
if (!shouldStart) {
try {
const { total } = await agentService.listAgents({ limit: 1 })
if (total > 0) {
shouldStart = true
logger.info(`Detected ${total} agent(s), auto-starting API server`)
// Check if there are any agents
let shouldStart = config.enabled
if (!shouldStart) {
try {
const { total } = await agentService.listAgents({ limit: 1 })
if (total > 0) {
shouldStart = true
logger.info(`Detected ${total} agent(s), auto-starting API server`)
}
} catch (error: any) {
logger.warn('Failed to check agent count:', error)
}
} catch (error: any) {
logger.warn('Failed to check agent count:', error)
}
}
if (shouldStart) {
await apiServerService.start()
if (shouldStart) {
await apiServerService.start()
}
} catch (error: any) {
logger.error('Failed to check/start API server:', error)
}
} catch (error: any) {
logger.error('Failed to check/start API server:', error)
}
})
})
registerProtocolClient(app)

View File

@ -1,16 +1,12 @@
import { type Client, createClient } from '@libsql/client'
import { loggerService } from '@logger'
import { mcpApiService } from '@main/apiServer/services/mcp'
import { type ModelValidationError, validateModelId } from '@main/apiServer/utils'
import type { AgentType, MCPTool, SlashCommand, Tool } from '@types'
import { objectKeys } from '@types'
import { drizzle, type LibSQLDatabase } from 'drizzle-orm/libsql'
import fs from 'fs'
import path from 'path'
import { MigrationService } from './database/MigrationService'
import * as schema from './database/schema'
import { dbPath } from './drizzle.config'
import { DatabaseManager } from './database/DatabaseManager'
import { type AgentModelField, AgentModelValidationError } from './errors'
import { builtinSlashCommands } from './services/claudecode/commands'
import { builtinTools } from './services/claudecode/tools'
@ -18,22 +14,16 @@ import { builtinTools } from './services/claudecode/tools'
const logger = loggerService.withContext('BaseService')
/**
* Base service class providing shared database connection and utilities
* for all agent-related services.
* Base service class providing shared utilities for all agent-related services.
*
* Features:
* - Programmatic schema management (no CLI dependencies)
* - Automatic table creation and migration
* - Schema version tracking and compatibility checks
* - Transaction-based operations for safety
* - Development vs production mode handling
* - Connection retry logic with exponential backoff
* - Database access through DatabaseManager singleton
* - JSON field serialization/deserialization
* - Path validation and creation
* - Model validation
* - MCP tools and slash commands listing
*/
export abstract class BaseService {
protected static client: Client | null = null
protected static db: LibSQLDatabase<typeof schema> | null = null
protected static isInitialized = false
protected static initializationPromise: Promise<void> | null = null
protected jsonFields: string[] = [
'tools',
'mcps',
@ -43,23 +33,6 @@ export abstract class BaseService {
'slash_commands'
]
/**
* Initialize database with retry logic and proper error handling
*/
protected static async initialize(): Promise<void> {
// Return existing initialization if in progress
if (BaseService.initializationPromise) {
return BaseService.initializationPromise
}
if (BaseService.isInitialized) {
return
}
BaseService.initializationPromise = BaseService.performInitialization()
return BaseService.initializationPromise
}
public async listMcpTools(agentType: AgentType, ids?: string[]): Promise<Tool[]> {
const tools: Tool[] = []
if (agentType === 'claude-code') {
@ -99,78 +72,13 @@ export abstract class BaseService {
return []
}
private static async performInitialization(): Promise<void> {
const maxRetries = 3
let lastError: Error
for (let attempt = 1; attempt <= maxRetries; attempt++) {
try {
logger.info(`Initializing Agent database at: ${dbPath} (attempt ${attempt}/${maxRetries})`)
// Ensure the database directory exists
const dbDir = path.dirname(dbPath)
if (!fs.existsSync(dbDir)) {
logger.info(`Creating database directory: ${dbDir}`)
fs.mkdirSync(dbDir, { recursive: true })
}
BaseService.client = createClient({
url: `file:${dbPath}`
})
BaseService.db = drizzle(BaseService.client, { schema })
// Run database migrations
const migrationService = new MigrationService(BaseService.db, BaseService.client)
await migrationService.runMigrations()
BaseService.isInitialized = true
logger.info('Agent database initialized successfully')
return
} catch (error) {
lastError = error as Error
logger.warn(`Database initialization attempt ${attempt} failed:`, lastError)
// Clean up on failure
if (BaseService.client) {
try {
BaseService.client.close()
} catch (closeError) {
logger.warn('Failed to close client during cleanup:', closeError as Error)
}
}
BaseService.client = null
BaseService.db = null
// Wait before retrying (exponential backoff)
if (attempt < maxRetries) {
const delay = Math.pow(2, attempt) * 1000 // 2s, 4s, 8s
logger.info(`Retrying in ${delay}ms...`)
await new Promise((resolve) => setTimeout(resolve, delay))
}
}
}
// All retries failed
BaseService.initializationPromise = null
logger.error('Failed to initialize Agent database after all retries:', lastError!)
throw lastError!
}
protected ensureInitialized(): void {
if (!BaseService.isInitialized || !BaseService.db || !BaseService.client) {
throw new Error('Database not initialized. Call initialize() first.')
}
}
protected get database(): LibSQLDatabase<typeof schema> {
this.ensureInitialized()
return BaseService.db!
}
protected get rawClient(): Client {
this.ensureInitialized()
return BaseService.client!
/**
* Get database instance
* Automatically waits for initialization to complete
*/
protected async getDatabase() {
const dbManager = await DatabaseManager.getInstance()
return dbManager.getDatabase()
}
protected serializeJsonFields(data: any): any {
@ -282,7 +190,7 @@ export abstract class BaseService {
}
/**
* Force re-initialization (for development/testing)
* Validate agent model configuration
*/
protected async validateAgentModels(
agentType: AgentType,
@ -323,22 +231,4 @@ export abstract class BaseService {
}
}
}
static async reinitialize(): Promise<void> {
BaseService.isInitialized = false
BaseService.initializationPromise = null
if (BaseService.client) {
try {
BaseService.client.close()
} catch (error) {
logger.warn('Failed to close client during reinitialize:', error as Error)
}
}
BaseService.client = null
BaseService.db = null
await BaseService.initialize()
}
}

View File

@ -0,0 +1,156 @@
import { type Client, createClient } from '@libsql/client'
import { loggerService } from '@logger'
import type { LibSQLDatabase } from 'drizzle-orm/libsql'
import { drizzle } from 'drizzle-orm/libsql'
import fs from 'fs'
import path from 'path'
import { dbPath } from '../drizzle.config'
import { MigrationService } from './MigrationService'
import * as schema from './schema'
const logger = loggerService.withContext('DatabaseManager')
/**
* Database initialization state
*/
enum InitState {
INITIALIZING = 'initializing',
INITIALIZED = 'initialized',
FAILED = 'failed'
}
/**
* DatabaseManager - Singleton class for managing libsql database connections
*
* Responsibilities:
* - Single source of truth for database connection
* - Thread-safe initialization with state management
* - Automatic migration handling
* - Safe connection cleanup
* - Error recovery and retry logic
* - Windows platform compatibility fixes
*/
export class DatabaseManager {
private static instance: DatabaseManager | null = null
private client: Client | null = null
private db: LibSQLDatabase<typeof schema> | null = null
private state: InitState = InitState.INITIALIZING
/**
* Get the singleton instance (database initialization starts automatically)
*/
public static async getInstance(): Promise<DatabaseManager> {
if (DatabaseManager.instance) {
return DatabaseManager.instance
}
const instance = new DatabaseManager()
await instance.initialize()
DatabaseManager.instance = instance
return instance
}
/**
* Perform the actual initialization
*/
public async initialize(): Promise<void> {
if (this.state === InitState.INITIALIZED) {
return
}
try {
logger.info(`Initializing database at: ${dbPath}`)
// Ensure database directory exists
const dbDir = path.dirname(dbPath)
if (!fs.existsSync(dbDir)) {
logger.info(`Creating database directory: ${dbDir}`)
fs.mkdirSync(dbDir, { recursive: true })
}
// Check if database file is corrupted (Windows specific check)
if (fs.existsSync(dbPath)) {
const stats = fs.statSync(dbPath)
if (stats.size === 0) {
logger.warn('Database file is empty, removing corrupted file')
fs.unlinkSync(dbPath)
}
}
// Create client with platform-specific options
this.client = createClient({
url: `file:${dbPath}`,
// intMode: 'number' helps avoid some Windows compatibility issues
intMode: 'number'
})
// Create drizzle instance
this.db = drizzle(this.client, { schema })
// Run migrations
const migrationService = new MigrationService(this.db, this.client)
await migrationService.runMigrations()
this.state = InitState.INITIALIZED
logger.info('Database initialized successfully')
} catch (error) {
const err = error as Error
logger.error('Database initialization failed:', {
error: err.message,
stack: err.stack
})
// Clean up failed initialization
this.cleanupFailedInit()
// Set failed state
this.state = InitState.FAILED
throw new Error(`Database initialization failed: ${err.message || 'Unknown error'}`)
}
}
/**
* Clean up after failed initialization
*/
private cleanupFailedInit(): void {
if (this.client) {
try {
// On Windows, closing a partially initialized client can crash
// Wrap in try-catch and ignore errors during cleanup
this.client.close()
} catch (error) {
logger.warn('Failed to close client during cleanup:', error as Error)
}
}
this.client = null
this.db = null
}
/**
* Get the database instance
* Automatically waits for initialization to complete
* @throws Error if database initialization failed
*/
public getDatabase(): LibSQLDatabase<typeof schema> {
return this.db!
}
/**
* Get the raw client (for advanced operations)
* Automatically waits for initialization to complete
* @throws Error if database initialization failed
*/
public async getClient(): Promise<Client> {
return this.client!
}
/**
* Check if database is initialized
*/
public isInitialized(): boolean {
return this.state === InitState.INITIALIZED
}
}

View File

@ -7,8 +7,14 @@
* Schema evolution is handled by Drizzle Kit migrations.
*/
// Database Manager (Singleton)
export * from './DatabaseManager'
// Drizzle ORM schemas
export * from './schema'
// Repository helpers
export * from './sessionMessageRepository'
// Migration Service
export * from './MigrationService'

View File

@ -15,26 +15,16 @@ import { sessionMessagesTable } from './schema'
const logger = loggerService.withContext('AgentMessageRepository')
type TxClient = any
export type PersistUserMessageParams = AgentMessageUserPersistPayload & {
sessionId: string
agentSessionId?: string
tx?: TxClient
}
export type PersistAssistantMessageParams = AgentMessageAssistantPersistPayload & {
sessionId: string
agentSessionId: string
tx?: TxClient
}
type PersistExchangeParams = AgentMessagePersistExchangePayload & {
tx?: TxClient
}
type PersistExchangeResult = AgentMessagePersistExchangeResult
class AgentMessageRepository extends BaseService {
private static instance: AgentMessageRepository | null = null
@ -87,17 +77,13 @@ class AgentMessageRepository extends BaseService {
return deserialized
}
private getWriter(tx?: TxClient): TxClient {
return tx ?? this.database
}
private async findExistingMessageRow(
writer: TxClient,
sessionId: string,
role: string,
messageId: string
): Promise<SessionMessageRow | null> {
const candidateRows: SessionMessageRow[] = await writer
const database = await this.getDatabase()
const candidateRows: SessionMessageRow[] = await database
.select()
.from(sessionMessagesTable)
.where(and(eq(sessionMessagesTable.session_id, sessionId), eq(sessionMessagesTable.role, role)))
@ -122,10 +108,7 @@ class AgentMessageRepository extends BaseService {
private async upsertMessage(
params: PersistUserMessageParams | PersistAssistantMessageParams
): Promise<AgentSessionMessageEntity> {
await AgentMessageRepository.initialize()
this.ensureInitialized()
const { sessionId, agentSessionId = '', payload, metadata, createdAt, tx } = params
const { sessionId, agentSessionId = '', payload, metadata, createdAt } = params
if (!payload?.message?.role) {
throw new Error('Message payload missing role')
@ -135,18 +118,18 @@ class AgentMessageRepository extends BaseService {
throw new Error('Message payload missing id')
}
const writer = this.getWriter(tx)
const database = await this.getDatabase()
const now = createdAt ?? payload.message.createdAt ?? new Date().toISOString()
const serializedPayload = this.serializeMessage(payload)
const serializedMetadata = this.serializeMetadata(metadata)
const existingRow = await this.findExistingMessageRow(writer, sessionId, payload.message.role, payload.message.id)
const existingRow = await this.findExistingMessageRow(sessionId, payload.message.role, payload.message.id)
if (existingRow) {
const metadataToPersist = serializedMetadata ?? existingRow.metadata ?? undefined
const agentSessionToPersist = agentSessionId || existingRow.agent_session_id || ''
await writer
await database
.update(sessionMessagesTable)
.set({
content: serializedPayload,
@ -175,7 +158,7 @@ class AgentMessageRepository extends BaseService {
updated_at: now
}
const [saved] = await writer.insert(sessionMessagesTable).values(insertData).returning()
const [saved] = await database.insert(sessionMessagesTable).values(insertData).returning()
return this.deserialize(saved)
}
@ -188,49 +171,38 @@ class AgentMessageRepository extends BaseService {
return this.upsertMessage(params)
}
async persistExchange(params: PersistExchangeParams): Promise<PersistExchangeResult> {
await AgentMessageRepository.initialize()
this.ensureInitialized()
async persistExchange(params: AgentMessagePersistExchangePayload): Promise<AgentMessagePersistExchangeResult> {
const { sessionId, agentSessionId, user, assistant } = params
const result = await this.database.transaction(async (tx) => {
const exchangeResult: PersistExchangeResult = {}
const exchangeResult: AgentMessagePersistExchangeResult = {}
if (user?.payload) {
exchangeResult.userMessage = await this.persistUserMessage({
sessionId,
agentSessionId,
payload: user.payload,
metadata: user.metadata,
createdAt: user.createdAt,
tx
})
}
if (user?.payload) {
exchangeResult.userMessage = await this.persistUserMessage({
sessionId,
agentSessionId,
payload: user.payload,
metadata: user.metadata,
createdAt: user.createdAt
})
}
if (assistant?.payload) {
exchangeResult.assistantMessage = await this.persistAssistantMessage({
sessionId,
agentSessionId,
payload: assistant.payload,
metadata: assistant.metadata,
createdAt: assistant.createdAt,
tx
})
}
if (assistant?.payload) {
exchangeResult.assistantMessage = await this.persistAssistantMessage({
sessionId,
agentSessionId,
payload: assistant.payload,
metadata: assistant.metadata,
createdAt: assistant.createdAt
})
}
return exchangeResult
})
return result
return exchangeResult
}
async getSessionHistory(sessionId: string): Promise<AgentPersistedMessage[]> {
await AgentMessageRepository.initialize()
this.ensureInitialized()
try {
const rows = await this.database
const database = await this.getDatabase()
const rows = await database
.select()
.from(sessionMessagesTable)
.where(eq(sessionMessagesTable.session_id, sessionId))

View File

@ -32,14 +32,8 @@ export class AgentService extends BaseService {
return AgentService.instance
}
async initialize(): Promise<void> {
await BaseService.initialize()
}
// Agent Methods
async createAgent(req: CreateAgentRequest): Promise<CreateAgentResponse> {
this.ensureInitialized()
const id = `agent_${Date.now()}_${Math.random().toString(36).substring(2, 11)}`
const now = new Date().toISOString()
@ -75,8 +69,9 @@ export class AgentService extends BaseService {
updated_at: now
}
await this.database.insert(agentsTable).values(insertData)
const result = await this.database.select().from(agentsTable).where(eq(agentsTable.id, id)).limit(1)
const database = await this.getDatabase()
await database.insert(agentsTable).values(insertData)
const result = await database.select().from(agentsTable).where(eq(agentsTable.id, id)).limit(1)
if (!result[0]) {
throw new Error('Failed to create agent')
}
@ -86,9 +81,8 @@ export class AgentService extends BaseService {
}
async getAgent(id: string): Promise<GetAgentResponse | null> {
this.ensureInitialized()
const result = await this.database.select().from(agentsTable).where(eq(agentsTable.id, id)).limit(1)
const database = await this.getDatabase()
const result = await database.select().from(agentsTable).where(eq(agentsTable.id, id)).limit(1)
if (!result[0]) {
return null
@ -118,9 +112,9 @@ export class AgentService extends BaseService {
}
async listAgents(options: ListOptions = {}): Promise<{ agents: AgentEntity[]; total: number }> {
this.ensureInitialized() // Build query with pagination
const totalResult = await this.database.select({ count: count() }).from(agentsTable)
// Build query with pagination
const database = await this.getDatabase()
const totalResult = await database.select({ count: count() }).from(agentsTable)
const sortBy = options.sortBy || 'created_at'
const orderBy = options.orderBy || 'desc'
@ -128,7 +122,7 @@ export class AgentService extends BaseService {
const sortField = agentsTable[sortBy]
const orderFn = orderBy === 'asc' ? asc : desc
const baseQuery = this.database.select().from(agentsTable).orderBy(orderFn(sortField))
const baseQuery = database.select().from(agentsTable).orderBy(orderFn(sortField))
const result =
options.limit !== undefined
@ -151,8 +145,6 @@ export class AgentService extends BaseService {
updates: UpdateAgentRequest,
options: { replace?: boolean } = {}
): Promise<UpdateAgentResponse | null> {
this.ensureInitialized()
// Check if agent exists
const existing = await this.getAgent(id)
if (!existing) {
@ -195,22 +187,21 @@ export class AgentService extends BaseService {
}
}
await this.database.update(agentsTable).set(updateData).where(eq(agentsTable.id, id))
const database = await this.getDatabase()
await database.update(agentsTable).set(updateData).where(eq(agentsTable.id, id))
return await this.getAgent(id)
}
async deleteAgent(id: string): Promise<boolean> {
this.ensureInitialized()
const result = await this.database.delete(agentsTable).where(eq(agentsTable.id, id))
const database = await this.getDatabase()
const result = await database.delete(agentsTable).where(eq(agentsTable.id, id))
return result.rowsAffected > 0
}
async agentExists(id: string): Promise<boolean> {
this.ensureInitialized()
const result = await this.database
const database = await this.getDatabase()
const result = await database
.select({ id: agentsTable.id })
.from(agentsTable)
.where(eq(agentsTable.id, id))

View File

@ -104,14 +104,9 @@ export class SessionMessageService extends BaseService {
return SessionMessageService.instance
}
async initialize(): Promise<void> {
await BaseService.initialize()
}
async sessionMessageExists(id: number): Promise<boolean> {
this.ensureInitialized()
const result = await this.database
const database = await this.getDatabase()
const result = await database
.select({ id: sessionMessagesTable.id })
.from(sessionMessagesTable)
.where(eq(sessionMessagesTable.id, id))
@ -124,10 +119,9 @@ export class SessionMessageService extends BaseService {
sessionId: string,
options: ListOptions = {}
): Promise<{ messages: AgentSessionMessageEntity[] }> {
this.ensureInitialized()
// Get messages with pagination
const baseQuery = this.database
const database = await this.getDatabase()
const baseQuery = database
.select()
.from(sessionMessagesTable)
.where(eq(sessionMessagesTable.session_id, sessionId))
@ -146,9 +140,8 @@ export class SessionMessageService extends BaseService {
}
async deleteSessionMessage(sessionId: string, messageId: number): Promise<boolean> {
this.ensureInitialized()
const result = await this.database
const database = await this.getDatabase()
const result = await database
.delete(sessionMessagesTable)
.where(and(eq(sessionMessagesTable.id, messageId), eq(sessionMessagesTable.session_id, sessionId)))
@ -160,8 +153,6 @@ export class SessionMessageService extends BaseService {
messageData: CreateSessionMessageRequest,
abortController: AbortController
): Promise<SessionStreamResult> {
this.ensureInitialized()
return await this.startSessionMessageStream(session, messageData, abortController)
}
@ -270,10 +261,9 @@ export class SessionMessageService extends BaseService {
}
private async getLastAgentSessionId(sessionId: string): Promise<string> {
this.ensureInitialized()
try {
const result = await this.database
const database = await this.getDatabase()
const result = await database
.select({ agent_session_id: sessionMessagesTable.agent_session_id })
.from(sessionMessagesTable)
.where(and(eq(sessionMessagesTable.session_id, sessionId), not(eq(sessionMessagesTable.agent_session_id, ''))))

View File

@ -31,10 +31,6 @@ export class SessionService extends BaseService {
return SessionService.instance
}
async initialize(): Promise<void> {
await BaseService.initialize()
}
/**
* Override BaseService.listSlashCommands to merge builtin and plugin commands
*/
@ -85,13 +81,12 @@ export class SessionService extends BaseService {
agentId: string,
req: Partial<CreateSessionRequest> = {}
): Promise<GetAgentSessionResponse | null> {
this.ensureInitialized()
// Validate agent exists - we'll need to import AgentService for this check
// For now, we'll skip this validation to avoid circular dependencies
// The database foreign key constraint will handle this
const agents = await this.database.select().from(agentsTable).where(eq(agentsTable.id, agentId)).limit(1)
const database = await this.getDatabase()
const agents = await database.select().from(agentsTable).where(eq(agentsTable.id, agentId)).limit(1)
if (!agents[0]) {
throw new Error('Agent not found')
}
@ -136,9 +131,10 @@ export class SessionService extends BaseService {
updated_at: now
}
await this.database.insert(sessionsTable).values(insertData)
const db = await this.getDatabase()
await db.insert(sessionsTable).values(insertData)
const result = await this.database.select().from(sessionsTable).where(eq(sessionsTable.id, id)).limit(1)
const result = await db.select().from(sessionsTable).where(eq(sessionsTable.id, id)).limit(1)
if (!result[0]) {
throw new Error('Failed to create session')
@ -149,9 +145,8 @@ export class SessionService extends BaseService {
}
async getSession(agentId: string, id: string): Promise<GetAgentSessionResponse | null> {
this.ensureInitialized()
const result = await this.database
const database = await this.getDatabase()
const result = await database
.select()
.from(sessionsTable)
.where(and(eq(sessionsTable.id, id), eq(sessionsTable.agent_id, agentId)))
@ -177,8 +172,6 @@ export class SessionService extends BaseService {
agentId?: string,
options: ListOptions = {}
): Promise<{ sessions: AgentSessionEntity[]; total: number }> {
this.ensureInitialized()
// Build where conditions
const whereConditions: SQL[] = []
if (agentId) {
@ -193,16 +186,13 @@ export class SessionService extends BaseService {
: undefined
// Get total count
const totalResult = await this.database.select({ count: count() }).from(sessionsTable).where(whereClause)
const database = await this.getDatabase()
const totalResult = await database.select({ count: count() }).from(sessionsTable).where(whereClause)
const total = totalResult[0].count
// Build list query with pagination - sort by updated_at descending (latest first)
const baseQuery = this.database
.select()
.from(sessionsTable)
.where(whereClause)
.orderBy(desc(sessionsTable.updated_at))
const baseQuery = database.select().from(sessionsTable).where(whereClause).orderBy(desc(sessionsTable.updated_at))
const result =
options.limit !== undefined
@ -221,8 +211,6 @@ export class SessionService extends BaseService {
id: string,
updates: UpdateSessionRequest
): Promise<UpdateSessionResponse | null> {
this.ensureInitialized()
// Check if session exists
const existing = await this.getSession(agentId, id)
if (!existing) {
@ -263,15 +251,15 @@ export class SessionService extends BaseService {
}
}
await this.database.update(sessionsTable).set(updateData).where(eq(sessionsTable.id, id))
const database = await this.getDatabase()
await database.update(sessionsTable).set(updateData).where(eq(sessionsTable.id, id))
return await this.getSession(agentId, id)
}
async deleteSession(agentId: string, id: string): Promise<boolean> {
this.ensureInitialized()
const result = await this.database
const database = await this.getDatabase()
const result = await database
.delete(sessionsTable)
.where(and(eq(sessionsTable.id, id), eq(sessionsTable.agent_id, agentId)))
@ -279,9 +267,8 @@ export class SessionService extends BaseService {
}
async sessionExists(agentId: string, id: string): Promise<boolean> {
this.ensureInitialized()
const result = await this.database
const database = await this.getDatabase()
const result = await database
.select({ id: sessionsTable.id })
.from(sessionsTable)
.where(and(eq(sessionsTable.id, id), eq(sessionsTable.agent_id, agentId)))

View File

@ -2,7 +2,14 @@
import { EventEmitter } from 'node:events'
import { createRequire } from 'node:module'
import type { CanUseTool, McpHttpServerConfig, Options, SDKMessage } from '@anthropic-ai/claude-agent-sdk'
import type {
CanUseTool,
HookCallback,
McpHttpServerConfig,
Options,
PreToolUseHookInput,
SDKMessage
} from '@anthropic-ai/claude-agent-sdk'
import { query } from '@anthropic-ai/claude-agent-sdk'
import { loggerService } from '@logger'
import { config as apiConfigService } from '@main/apiServer/config'
@ -157,6 +164,63 @@ class ClaudeCodeService implements AgentServiceInterface {
})
}
const preToolUseHook: HookCallback = async (input, toolUseID, options) => {
// Type guard to ensure we're handling PreToolUse event
if (input.hook_event_name !== 'PreToolUse') {
return {}
}
const hookInput = input as PreToolUseHookInput
const toolName = hookInput.tool_name
logger.debug('PreToolUse hook triggered', {
session_id: hookInput.session_id,
tool_name: hookInput.tool_name,
tool_use_id: toolUseID,
tool_input: hookInput.tool_input,
cwd: hookInput.cwd,
permission_mode: hookInput.permission_mode,
autoAllowTools: autoAllowTools
})
if (options?.signal?.aborted) {
logger.debug('PreToolUse hook signal already aborted; skipping tool use', {
tool_name: hookInput.tool_name
})
return {}
}
// handle auto approved tools since it never triggers canUseTool
const normalizedToolName = normalizeToolName(toolName)
if (toolUseID) {
const bypassAll = input.permission_mode === 'bypassPermissions'
const autoAllowed = autoAllowTools.has(toolName) || autoAllowTools.has(normalizedToolName)
if (bypassAll || autoAllowed) {
const namespacedToolCallId = buildNamespacedToolCallId(session.id, toolUseID)
logger.debug('handling auto approved tools', {
toolName,
normalizedToolName,
namespacedToolCallId,
permission_mode: input.permission_mode,
autoAllowTools
})
const isRecord = (v: unknown): v is Record<string, unknown> => {
return !!v && typeof v === 'object' && !Array.isArray(v)
}
const toolInput = isRecord(input.tool_input) ? input.tool_input : {}
await promptForToolApproval(toolName, toolInput, {
...options,
toolCallId: namespacedToolCallId,
autoApprove: true
})
}
}
// Return to proceed without modification
return {}
}
// Build SDK options from parameters
const options: Options = {
abortController,
@ -180,7 +244,14 @@ class ClaudeCodeService implements AgentServiceInterface {
permissionMode: session.configuration?.permission_mode,
maxTurns: session.configuration?.max_turns,
allowedTools: session.allowed_tools,
canUseTool
canUseTool,
hooks: {
PreToolUse: [
{
hooks: [preToolUseHook]
}
]
}
}
if (session.accessible_paths.length > 1) {

View File

@ -31,6 +31,7 @@ type PendingPermissionRequest = {
abortListener?: () => void
originalInput: Record<string, unknown>
toolName: string
toolCallId?: string
}
type RendererPermissionRequestPayload = {
@ -45,6 +46,7 @@ type RendererPermissionRequestPayload = {
createdAt: number
expiresAt: number
suggestions: PermissionUpdate[]
autoApprove?: boolean
}
type RendererPermissionResultPayload = {
@ -52,6 +54,7 @@ type RendererPermissionResultPayload = {
behavior: ToolPermissionBehavior
message?: string
reason: 'response' | 'timeout' | 'aborted' | 'no-window'
toolCallId?: string
}
const pendingRequests = new Map<string, PendingPermissionRequest>()
@ -145,7 +148,8 @@ const finalizeRequest = (
requestId,
behavior: update.behavior,
message: update.behavior === 'deny' ? update.message : undefined,
reason
reason,
toolCallId: pending.toolCallId
}
const dispatched = broadcastToRenderer(IpcChannel.AgentToolPermission_Result, resultPayload)
@ -210,6 +214,7 @@ const ensureIpcHandlersRegistered = () => {
type PromptForToolApprovalOptions = {
signal: AbortSignal
suggestions?: PermissionUpdate[]
autoApprove?: boolean
// NOTICE: This ID is namespaced with session ID, not the raw SDK tool call ID.
// Format: `${sessionId}:${rawToolCallId}`, e.g., `session_123:WebFetch_0`
@ -270,7 +275,8 @@ export async function promptForToolApproval(
inputPreview,
createdAt,
expiresAt,
suggestions: sanitizedSuggestions
suggestions: sanitizedSuggestions,
autoApprove: options.autoApprove
}
const defaultDenyUpdate: PermissionResult = { behavior: 'deny', message: 'Tool request aborted before user decision' }
@ -299,7 +305,8 @@ export async function promptForToolApproval(
timeout,
originalInput: sanitizedInput,
toolName,
signal: options?.signal
signal: options?.signal,
toolCallId: options.toolCallId
}
if (options?.signal) {

View File

@ -385,14 +385,13 @@ export class AiSdkToChunkAdapter {
case 'error':
this.onChunk({
type: ChunkType.ERROR,
error:
chunk.error instanceof AISDKError
? chunk.error
: new ProviderSpecificError({
message: formatErrorMessage(chunk.error),
provider: 'unknown',
cause: chunk.error
})
error: AISDKError.isInstance(chunk.error)
? chunk.error
: new ProviderSpecificError({
message: formatErrorMessage(chunk.error),
provider: 'unknown',
cause: chunk.error
})
})
break

View File

@ -32,6 +32,7 @@ import {
prepareSpecialProviderConfig,
providerToAiSdkConfig
} from './provider/providerConfig'
import type { AiSdkConfig } from './types'
const logger = loggerService.withContext('ModernAiProvider')
@ -44,7 +45,7 @@ export type ModernAiProviderConfig = AiSdkMiddlewareConfig & {
export default class ModernAiProvider {
private legacyProvider: LegacyAiProvider
private config?: ReturnType<typeof providerToAiSdkConfig>
private config?: AiSdkConfig
private actualProvider: Provider
private model?: Model
private localProvider: Awaited<AiSdkProvider> | null = null
@ -89,6 +90,11 @@ export default class ModernAiProvider {
// 每次请求时重新生成配置以确保API key轮换生效
this.config = providerToAiSdkConfig(this.actualProvider, this.model)
logger.debug('Generated provider config for completions', this.config)
// 检查 config 是否存在
if (!this.config) {
throw new Error('Provider config is undefined; cannot proceed with completions')
}
if (SUPPORTED_IMAGE_ENDPOINT_LIST.includes(this.config.options.endpoint)) {
providerConfig.isImageGenerationEndpoint = true
}
@ -149,7 +155,8 @@ export default class ModernAiProvider {
params: StreamTextParams,
config: ModernAiProviderConfig
): Promise<CompletionsResult> {
if (config.isImageGenerationEndpoint) {
// ai-gateway不是image/generation 端点所以就先不走legacy了
if (config.isImageGenerationEndpoint && config.provider!.id !== SystemProviderIds['ai-gateway']) {
// 使用 legacy 实现处理图像生成(支持图片编辑等高级功能)
if (!config.uiMessages) {
throw new Error('uiMessages is required for image generation endpoint')
@ -463,8 +470,13 @@ export default class ModernAiProvider {
// 如果支持新的 AI SDK使用现代化实现
if (isModernSdkSupported(this.actualProvider)) {
try {
// 确保 config 已定义
if (!this.config) {
throw new Error('Provider config is undefined; cannot proceed with generateImage')
}
// 确保本地provider已创建
if (!this.localProvider) {
if (!this.localProvider && this.config) {
this.localProvider = await createAiSdkProvider(this.config)
if (!this.localProvider) {
throw new Error('Local provider not created')

View File

@ -1,6 +1,6 @@
import { loggerService } from '@logger'
import { isNewApiProvider } from '@renderer/config/providers'
import type { Provider } from '@renderer/types'
import { isNewApiProvider } from '@renderer/utils/provider'
import { AihubmixAPIClient } from './aihubmix/AihubmixAPIClient'
import { AnthropicAPIClient } from './anthropic/AnthropicAPIClient'

View File

@ -7,7 +7,6 @@ import {
isOpenAIModel,
isSupportFlexServiceTierModel
} from '@renderer/config/models'
import { isSupportServiceTierProvider } from '@renderer/config/providers'
import { getLMStudioKeepAliveTime } from '@renderer/hooks/useLMStudio'
import { getAssistantSettings } from '@renderer/services/AssistantService'
import type {
@ -19,7 +18,6 @@ import type {
MCPToolResponse,
MemoryItem,
Model,
OpenAIVerbosity,
Provider,
ToolCallResponse,
WebSearchProviderResponse,
@ -33,6 +31,7 @@ import {
OpenAIServiceTiers,
SystemProviderIds
} from '@renderer/types'
import type { OpenAIVerbosity } from '@renderer/types/aiCoreTypes'
import type { Message } from '@renderer/types/newMessage'
import type {
RequestOptions,
@ -48,6 +47,7 @@ import type {
import { isJSON, parseJSON } from '@renderer/utils'
import { addAbortController, removeAbortController } from '@renderer/utils/abortController'
import { findFileBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
import { isSupportServiceTierProvider } from '@renderer/utils/provider'
import { defaultTimeout } from '@shared/config/constant'
import { REFERENCE_PROMPT } from '@shared/config/prompts'
import { defaultAppHeaders } from '@shared/utils'

View File

@ -58,10 +58,27 @@ vi.mock('../aws/AwsBedrockAPIClient', () => ({
AwsBedrockAPIClient: vi.fn().mockImplementation(() => ({}))
}))
vi.mock('@renderer/services/AssistantService.ts', () => ({
getDefaultAssistant: () => {
return {
id: 'default',
name: 'default',
emoji: '😀',
prompt: '',
topics: [],
messages: [],
type: 'assistant',
regularPhrases: [],
settings: {}
}
}
}))
// Mock the models config to prevent circular dependency issues
vi.mock('@renderer/config/models', () => ({
findTokenLimit: vi.fn(),
isReasoningModel: vi.fn(),
isOpenAILLMModel: vi.fn(),
SYSTEM_MODELS: {
silicon: [],
defaultModel: []

View File

@ -1,7 +1,8 @@
import { GoogleGenAI } from '@google/genai'
import { loggerService } from '@logger'
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
import { createVertexProvider, isVertexAIConfigured } from '@renderer/hooks/useVertexAI'
import type { Model, Provider, VertexProvider } from '@renderer/types'
import { isVertexProvider } from '@renderer/utils/provider'
import { isEmpty } from 'lodash'
import { AnthropicVertexClient } from '../anthropic/AnthropicVertexClient'

View File

@ -10,7 +10,6 @@ import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
import {
findTokenLimit,
GEMINI_FLASH_MODEL_REGEX,
getOpenAIWebSearchParams,
getThinkModelType,
isClaudeReasoningModel,
isDeepSeekHybridInferenceModel,
@ -40,12 +39,6 @@ import {
MODEL_SUPPORTED_REASONING_EFFORT,
ZHIPU_RESULT_TOKENS
} from '@renderer/config/models'
import {
isSupportArrayContentProvider,
isSupportDeveloperRoleProvider,
isSupportEnableThinkingProvider,
isSupportStreamOptionsProvider
} from '@renderer/config/providers'
import { mapLanguageToQwenMTModel } from '@renderer/config/translate'
import { processPostsuffixQwen3Model, processReqMessages } from '@renderer/services/ModelMessageService'
import { estimateTextTokens } from '@renderer/services/TokenService'
@ -89,6 +82,12 @@ import {
openAIToolsToMcpTool
} from '@renderer/utils/mcp-tools'
import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find'
import {
isSupportArrayContentProvider,
isSupportDeveloperRoleProvider,
isSupportEnableThinkingProvider,
isSupportStreamOptionsProvider
} from '@renderer/utils/provider'
import { t } from 'i18next'
import type { GenericChunk } from '../../middleware/schemas'
@ -743,7 +742,7 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
: {}),
...this.getProviderSpecificParameters(assistant, model),
...reasoningEffort,
...getOpenAIWebSearchParams(model, enableWebSearch),
// ...getOpenAIWebSearchParams(model, enableWebSearch),
// OpenRouter usage tracking
...(this.provider.id === 'openrouter' ? { usage: { include: true } } : {}),
...extra_body,

View File

@ -12,7 +12,6 @@ import {
isSupportVerbosityModel,
isVisionModel
} from '@renderer/config/models'
import { isSupportDeveloperRoleProvider } from '@renderer/config/providers'
import { estimateTextTokens } from '@renderer/services/TokenService'
import type {
FileMetadata,
@ -43,6 +42,7 @@ import {
openAIToolsToMcpTool
} from '@renderer/utils/mcp-tools'
import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find'
import { isSupportDeveloperRoleProvider } from '@renderer/utils/provider'
import { MB } from '@shared/config/constant'
import { t } from 'i18next'
import { isEmpty } from 'lodash'

View File

@ -1,6 +1,7 @@
import { loggerService } from '@logger'
import { isZhipuModel } from '@renderer/config/models'
import { getStoreProviders } from '@renderer/hooks/useStore'
import { getDefaultModel } from '@renderer/services/AssistantService'
import type { Chunk } from '@renderer/types/chunk'
import type { CompletionsParams, CompletionsResult } from '../schemas'
@ -66,7 +67,7 @@ export const ErrorHandlerMiddleware =
}
function handleError(error: any, params: CompletionsParams): any {
if (isZhipuModel(params.assistant.model) && error.status && !params.enableGenerateImage) {
if (isZhipuModel(params.assistant.model || getDefaultModel()) && error.status && !params.enableGenerateImage) {
return handleZhipuError(error)
}

View File

@ -1,10 +1,10 @@
import type { WebSearchPluginConfig } from '@cherrystudio/ai-core/built-in/plugins'
import { loggerService } from '@logger'
import { isSupportedThinkingTokenQwenModel } from '@renderer/config/models'
import { isSupportEnableThinkingProvider } from '@renderer/config/providers'
import type { MCPTool } from '@renderer/types'
import { type Assistant, type Message, type Model, type Provider } from '@renderer/types'
import { type Assistant, type Message, type Model, type Provider, SystemProviderIds } from '@renderer/types'
import type { Chunk } from '@renderer/types/chunk'
import { isSupportEnableThinkingProvider } from '@renderer/utils/provider'
import type { LanguageModelMiddleware } from 'ai'
import { extractReasoningMiddleware, simulateStreamingMiddleware } from 'ai'
import { isEmpty } from 'lodash'
@ -12,6 +12,7 @@ import { isEmpty } from 'lodash'
import { isOpenRouterGeminiGenerateImageModel } from '../utils/image'
import { noThinkMiddleware } from './noThinkMiddleware'
import { openrouterGenerateImageMiddleware } from './openrouterGenerateImageMiddleware'
import { openrouterReasoningMiddleware } from './openrouterReasoningMiddleware'
import { qwenThinkingMiddleware } from './qwenThinkingMiddleware'
import { toolChoiceMiddleware } from './toolChoiceMiddleware'
@ -217,6 +218,14 @@ function addProviderSpecificMiddlewares(builder: AiSdkMiddlewareBuilder, config:
middleware: noThinkMiddleware()
})
}
if (config.provider.id === SystemProviderIds.openrouter && config.enableReasoning) {
builder.add({
name: 'openrouter-reasoning-redaction',
middleware: openrouterReasoningMiddleware()
})
logger.debug('Added OpenRouter reasoning redaction middleware')
}
}
/**

View File

@ -0,0 +1,50 @@
import type { LanguageModelV2StreamPart } from '@ai-sdk/provider'
import type { LanguageModelMiddleware } from 'ai'
/**
* https://openrouter.ai/docs/docs/best-practices/reasoning-tokens#example-preserving-reasoning-blocks-with-openrouter-and-claude
*
* @returns LanguageModelMiddleware - a middleware filter redacted block
*/
export function openrouterReasoningMiddleware(): LanguageModelMiddleware {
const REDACTED_BLOCK = '[REDACTED]'
return {
middlewareVersion: 'v2',
wrapGenerate: async ({ doGenerate }) => {
const { content, ...rest } = await doGenerate()
const modifiedContent = content.map((part) => {
if (part.type === 'reasoning' && part.text.includes(REDACTED_BLOCK)) {
return {
...part,
text: part.text.replace(REDACTED_BLOCK, '')
}
}
return part
})
return { content: modifiedContent, ...rest }
},
wrapStream: async ({ doStream }) => {
const { stream, ...rest } = await doStream()
return {
stream: stream.pipeThrough(
new TransformStream<LanguageModelV2StreamPart, LanguageModelV2StreamPart>({
transform(
chunk: LanguageModelV2StreamPart,
controller: TransformStreamDefaultController<LanguageModelV2StreamPart>
) {
if (chunk.type === 'reasoning-delta' && chunk.delta.includes(REDACTED_BLOCK)) {
controller.enqueue({
...chunk,
delta: chunk.delta.replace(REDACTED_BLOCK, '')
})
} else {
controller.enqueue(chunk)
}
}
})
),
...rest
}
}
}
}

View File

@ -0,0 +1,234 @@
import type { Message, Model } from '@renderer/types'
import type { FileMetadata } from '@renderer/types/file'
import { FileTypes } from '@renderer/types/file'
import {
AssistantMessageStatus,
type FileMessageBlock,
type ImageMessageBlock,
MessageBlockStatus,
MessageBlockType,
type ThinkingMessageBlock,
UserMessageStatus
} from '@renderer/types/newMessage'
import { beforeEach, describe, expect, it, vi } from 'vitest'
const { convertFileBlockToFilePartMock, convertFileBlockToTextPartMock } = vi.hoisted(() => ({
convertFileBlockToFilePartMock: vi.fn(),
convertFileBlockToTextPartMock: vi.fn()
}))
vi.mock('../fileProcessor', () => ({
convertFileBlockToFilePart: convertFileBlockToFilePartMock,
convertFileBlockToTextPart: convertFileBlockToTextPartMock
}))
const visionModelIds = new Set(['gpt-4o-mini', 'qwen-image-edit'])
const imageEnhancementModelIds = new Set(['qwen-image-edit'])
vi.mock('@renderer/config/models', () => ({
isVisionModel: (model: Model) => visionModelIds.has(model.id),
isImageEnhancementModel: (model: Model) => imageEnhancementModelIds.has(model.id)
}))
type MockableMessage = Message & {
__mockContent?: string
__mockFileBlocks?: FileMessageBlock[]
__mockImageBlocks?: ImageMessageBlock[]
__mockThinkingBlocks?: ThinkingMessageBlock[]
}
vi.mock('@renderer/utils/messageUtils/find', () => ({
getMainTextContent: (message: Message) => (message as MockableMessage).__mockContent ?? '',
findFileBlocks: (message: Message) => (message as MockableMessage).__mockFileBlocks ?? [],
findImageBlocks: (message: Message) => (message as MockableMessage).__mockImageBlocks ?? [],
findThinkingBlocks: (message: Message) => (message as MockableMessage).__mockThinkingBlocks ?? []
}))
import { convertMessagesToSdkMessages, convertMessageToSdkParam } from '../messageConverter'
let messageCounter = 0
let blockCounter = 0
const createModel = (overrides: Partial<Model> = {}): Model => ({
id: 'gpt-4o-mini',
name: 'GPT-4o mini',
provider: 'openai',
group: 'openai',
...overrides
})
const createMessage = (role: Message['role']): MockableMessage =>
({
id: `message-${++messageCounter}`,
role,
assistantId: 'assistant-1',
topicId: 'topic-1',
createdAt: new Date(2024, 0, 1, 0, 0, messageCounter).toISOString(),
status: role === 'assistant' ? AssistantMessageStatus.SUCCESS : UserMessageStatus.SUCCESS,
blocks: []
}) as MockableMessage
const createFileBlock = (
messageId: string,
overrides: Partial<Omit<FileMessageBlock, 'file' | 'messageId' | 'type'>> & { file?: Partial<FileMetadata> } = {}
): FileMessageBlock => {
const { file, ...blockOverrides } = overrides
const timestamp = new Date(2024, 0, 1, 0, 0, ++blockCounter).toISOString()
return {
id: blockOverrides.id ?? `file-block-${blockCounter}`,
messageId,
type: MessageBlockType.FILE,
createdAt: blockOverrides.createdAt ?? timestamp,
status: blockOverrides.status ?? MessageBlockStatus.SUCCESS,
file: {
id: file?.id ?? `file-${blockCounter}`,
name: file?.name ?? 'document.txt',
origin_name: file?.origin_name ?? 'document.txt',
path: file?.path ?? '/tmp/document.txt',
size: file?.size ?? 1024,
ext: file?.ext ?? '.txt',
type: file?.type ?? FileTypes.TEXT,
created_at: file?.created_at ?? timestamp,
count: file?.count ?? 1,
...file
},
...blockOverrides
}
}
const createImageBlock = (
messageId: string,
overrides: Partial<Omit<ImageMessageBlock, 'type' | 'messageId'>> = {}
): ImageMessageBlock => ({
id: overrides.id ?? `image-block-${++blockCounter}`,
messageId,
type: MessageBlockType.IMAGE,
createdAt: overrides.createdAt ?? new Date(2024, 0, 1, 0, 0, blockCounter).toISOString(),
status: overrides.status ?? MessageBlockStatus.SUCCESS,
url: overrides.url ?? 'https://example.com/image.png',
...overrides
})
describe('messageConverter', () => {
beforeEach(() => {
convertFileBlockToFilePartMock.mockReset()
convertFileBlockToTextPartMock.mockReset()
convertFileBlockToFilePartMock.mockResolvedValue(null)
convertFileBlockToTextPartMock.mockResolvedValue(null)
messageCounter = 0
blockCounter = 0
})
describe('convertMessageToSdkParam', () => {
it('includes text and image parts for user messages on vision models', async () => {
const model = createModel()
const message = createMessage('user')
message.__mockContent = 'Describe this picture'
message.__mockImageBlocks = [createImageBlock(message.id, { url: 'https://example.com/cat.png' })]
const result = await convertMessageToSdkParam(message, true, model)
expect(result).toEqual({
role: 'user',
content: [
{ type: 'text', text: 'Describe this picture' },
{ type: 'image', image: 'https://example.com/cat.png' }
]
})
})
it('returns file instructions as a system message when native uploads succeed', async () => {
const model = createModel()
const message = createMessage('user')
message.__mockContent = 'Summarize the PDF'
message.__mockFileBlocks = [createFileBlock(message.id)]
convertFileBlockToFilePartMock.mockResolvedValueOnce({
type: 'file',
filename: 'document.pdf',
mediaType: 'application/pdf',
data: 'fileid://remote-file'
})
const result = await convertMessageToSdkParam(message, false, model)
expect(result).toEqual([
{
role: 'system',
content: 'fileid://remote-file'
},
{
role: 'user',
content: [{ type: 'text', text: 'Summarize the PDF' }]
}
])
})
})
describe('convertMessagesToSdkMessages', () => {
it('appends assistant images to the final user message for image enhancement models', async () => {
const model = createModel({ id: 'qwen-image-edit', name: 'Qwen Image Edit', provider: 'qwen', group: 'qwen' })
const initialUser = createMessage('user')
initialUser.__mockContent = 'Start editing'
const assistant = createMessage('assistant')
assistant.__mockContent = 'Here is the current preview'
assistant.__mockImageBlocks = [createImageBlock(assistant.id, { url: 'https://example.com/preview.png' })]
const finalUser = createMessage('user')
finalUser.__mockContent = 'Increase the brightness'
const result = await convertMessagesToSdkMessages([initialUser, assistant, finalUser], model)
expect(result).toEqual([
{
role: 'assistant',
content: [{ type: 'text', text: 'Here is the current preview' }]
},
{
role: 'user',
content: [
{ type: 'text', text: 'Increase the brightness' },
{ type: 'image', image: 'https://example.com/preview.png' }
]
}
])
})
it('preserves preceding system instructions when building enhancement payloads', async () => {
const model = createModel({ id: 'qwen-image-edit', name: 'Qwen Image Edit', provider: 'qwen', group: 'qwen' })
const fileUser = createMessage('user')
fileUser.__mockContent = 'Use this document as inspiration'
fileUser.__mockFileBlocks = [createFileBlock(fileUser.id, { file: { ext: '.pdf', type: FileTypes.DOCUMENT } })]
convertFileBlockToFilePartMock.mockResolvedValueOnce({
type: 'file',
filename: 'reference.pdf',
mediaType: 'application/pdf',
data: 'fileid://reference'
})
const assistant = createMessage('assistant')
assistant.__mockContent = 'Generated previews ready'
assistant.__mockImageBlocks = [createImageBlock(assistant.id, { url: 'https://example.com/reference.png' })]
const finalUser = createMessage('user')
finalUser.__mockContent = 'Apply the edits'
const result = await convertMessagesToSdkMessages([fileUser, assistant, finalUser], model)
expect(result).toEqual([
{ role: 'system', content: 'fileid://reference' },
{
role: 'assistant',
content: [{ type: 'text', text: 'Generated previews ready' }]
},
{
role: 'user',
content: [
{ type: 'text', text: 'Apply the edits' },
{ type: 'image', image: 'https://example.com/reference.png' }
]
}
])
})
})
})

View File

@ -0,0 +1,218 @@
import type { Assistant, AssistantSettings, Model, Topic } from '@renderer/types'
import { TopicType } from '@renderer/types'
import { defaultTimeout } from '@shared/config/constant'
import { describe, expect, it, vi } from 'vitest'
import { getTemperature, getTimeout, getTopP } from '../modelParameters'
vi.mock('@renderer/services/AssistantService', () => ({
getAssistantSettings: (assistant: Assistant): AssistantSettings => ({
contextCount: assistant.settings?.contextCount ?? 4096,
temperature: assistant.settings?.temperature ?? 0.7,
enableTemperature: assistant.settings?.enableTemperature ?? true,
topP: assistant.settings?.topP ?? 1,
enableTopP: assistant.settings?.enableTopP ?? false,
enableMaxTokens: assistant.settings?.enableMaxTokens ?? false,
maxTokens: assistant.settings?.maxTokens,
streamOutput: assistant.settings?.streamOutput ?? true,
toolUseMode: assistant.settings?.toolUseMode ?? 'prompt',
defaultModel: assistant.defaultModel,
customParameters: assistant.settings?.customParameters ?? [],
reasoning_effort: assistant.settings?.reasoning_effort,
reasoning_effort_cache: assistant.settings?.reasoning_effort_cache,
qwenThinkMode: assistant.settings?.qwenThinkMode
})
}))
vi.mock('@renderer/hooks/useSettings', () => ({
getStoreSetting: vi.fn(),
useSettings: vi.fn(() => ({})),
useNavbarPosition: vi.fn(() => ({ navbarPosition: 'left', isLeftNavbar: true, isTopNavbar: false }))
}))
vi.mock('@renderer/hooks/useStore', () => ({
getStoreProviders: vi.fn(() => [])
}))
vi.mock('@renderer/store/settings', () => ({
default: (state = { settings: {} }) => state
}))
vi.mock('@renderer/store/assistants', () => ({
default: (state = { assistants: [] }) => state
}))
const createTopic = (assistantId: string): Topic => ({
id: `topic-${assistantId}`,
assistantId,
name: 'topic',
createdAt: new Date().toISOString(),
updatedAt: new Date().toISOString(),
messages: [],
type: TopicType.Chat
})
const createAssistant = (settings: Assistant['settings'] = {}): Assistant => {
const assistantId = 'assistant-1'
return {
id: assistantId,
name: 'Test Assistant',
prompt: 'prompt',
topics: [createTopic(assistantId)],
type: 'assistant',
settings
}
}
const createModel = (overrides: Partial<Model> = {}): Model => ({
id: 'gpt-4o',
provider: 'openai',
name: 'GPT-4o',
group: 'openai',
...overrides
})
describe('modelParameters', () => {
describe('getTemperature', () => {
it('returns undefined when reasoning effort is enabled for Claude models', () => {
const assistant = createAssistant({ reasoning_effort: 'medium' })
const model = createModel({ id: 'claude-opus-4', name: 'Claude Opus 4', provider: 'anthropic', group: 'claude' })
expect(getTemperature(assistant, model)).toBeUndefined()
})
it('returns undefined for models without temperature/topP support', () => {
const assistant = createAssistant({ enableTemperature: true })
const model = createModel({ id: 'qwen-mt-large', name: 'Qwen MT', provider: 'qwen', group: 'qwen' })
expect(getTemperature(assistant, model)).toBeUndefined()
})
it('returns undefined for Claude 4.5 reasoning models when only TopP is enabled', () => {
const assistant = createAssistant({ enableTopP: true, enableTemperature: false })
const model = createModel({
id: 'claude-sonnet-4.5',
name: 'Claude Sonnet 4.5',
provider: 'anthropic',
group: 'claude'
})
expect(getTemperature(assistant, model)).toBeUndefined()
})
it('returns configured temperature when enabled', () => {
const assistant = createAssistant({ enableTemperature: true, temperature: 0.42 })
const model = createModel({ id: 'gpt-4o', provider: 'openai', group: 'openai' })
expect(getTemperature(assistant, model)).toBe(0.42)
})
it('returns undefined when temperature is disabled', () => {
const assistant = createAssistant({ enableTemperature: false, temperature: 0.9 })
const model = createModel({ id: 'gpt-4o', provider: 'openai', group: 'openai' })
expect(getTemperature(assistant, model)).toBeUndefined()
})
it('clamps temperature to max 1.0 for Zhipu models', () => {
const assistant = createAssistant({ enableTemperature: true, temperature: 2.0 })
const model = createModel({ id: 'glm-4-plus', name: 'GLM-4 Plus', provider: 'zhipu', group: 'zhipu' })
expect(getTemperature(assistant, model)).toBe(1.0)
})
it('clamps temperature to max 1.0 for Anthropic models', () => {
const assistant = createAssistant({ enableTemperature: true, temperature: 1.5 })
const model = createModel({
id: 'claude-sonnet-3.5',
name: 'Claude 3.5 Sonnet',
provider: 'anthropic',
group: 'claude'
})
expect(getTemperature(assistant, model)).toBe(1.0)
})
it('clamps temperature to max 1.0 for Moonshot models', () => {
const assistant = createAssistant({ enableTemperature: true, temperature: 2.0 })
const model = createModel({
id: 'moonshot-v1-8k',
name: 'Moonshot v1 8k',
provider: 'moonshot',
group: 'moonshot'
})
expect(getTemperature(assistant, model)).toBe(1.0)
})
it('does not clamp temperature for OpenAI models', () => {
const assistant = createAssistant({ enableTemperature: true, temperature: 2.0 })
const model = createModel({ id: 'gpt-4o', provider: 'openai', group: 'openai' })
expect(getTemperature(assistant, model)).toBe(2.0)
})
it('does not clamp temperature when it is already within limits', () => {
const assistant = createAssistant({ enableTemperature: true, temperature: 0.8 })
const model = createModel({ id: 'glm-4-plus', name: 'GLM-4 Plus', provider: 'zhipu', group: 'zhipu' })
expect(getTemperature(assistant, model)).toBe(0.8)
})
})
describe('getTopP', () => {
it('returns undefined when reasoning effort is enabled for Claude models', () => {
const assistant = createAssistant({ reasoning_effort: 'high' })
const model = createModel({ id: 'claude-opus-4', provider: 'anthropic', group: 'claude' })
expect(getTopP(assistant, model)).toBeUndefined()
})
it('returns undefined for models without TopP support', () => {
const assistant = createAssistant({ enableTopP: true })
const model = createModel({ id: 'qwen-mt-small', name: 'Qwen MT', provider: 'qwen', group: 'qwen' })
expect(getTopP(assistant, model)).toBeUndefined()
})
it('returns undefined for Claude 4.5 reasoning models when temperature is enabled', () => {
const assistant = createAssistant({ enableTemperature: true })
const model = createModel({
id: 'claude-opus-4.5',
name: 'Claude Opus 4.5',
provider: 'anthropic',
group: 'claude'
})
expect(getTopP(assistant, model)).toBeUndefined()
})
it('returns configured TopP when enabled', () => {
const assistant = createAssistant({ enableTopP: true, topP: 0.73 })
const model = createModel({ id: 'gpt-4o', provider: 'openai', group: 'openai' })
expect(getTopP(assistant, model)).toBe(0.73)
})
it('returns undefined when TopP is disabled', () => {
const assistant = createAssistant({ enableTopP: false, topP: 0.5 })
const model = createModel({ id: 'gpt-4o', provider: 'openai', group: 'openai' })
expect(getTopP(assistant, model)).toBeUndefined()
})
})
describe('getTimeout', () => {
it('uses an extended timeout for flex service tier models', () => {
const model = createModel({ id: 'o3-pro', provider: 'openai', group: 'openai' })
expect(getTimeout(model)).toBe(15 * 1000 * 60)
})
it('falls back to the default timeout otherwise', () => {
const model = createModel({ id: 'gpt-4o', provider: 'openai', group: 'openai' })
expect(getTimeout(model)).toBe(defaultTimeout)
})
})
})

View File

@ -1,13 +1,31 @@
import { isClaude45ReasoningModel } from '@renderer/config/models'
import { isClaude4SeriesModel, isClaude45ReasoningModel } from '@renderer/config/models'
import { getProviderByModel } from '@renderer/services/AssistantService'
import type { Assistant, Model } from '@renderer/types'
import { isToolUseModeFunction } from '@renderer/utils/assistant'
import { isAwsBedrockProvider, isVertexProvider } from '@renderer/utils/provider'
// https://docs.claude.com/en/docs/build-with-claude/extended-thinking#interleaved-thinking
const INTERLEAVED_THINKING_HEADER = 'interleaved-thinking-2025-05-14'
// https://docs.claude.com/en/docs/build-with-claude/context-windows#1m-token-context-window
const CONTEXT_100M_HEADER = 'context-1m-2025-08-07'
// https://docs.cloud.google.com/vertex-ai/generative-ai/docs/partner-models/claude/web-search
const WEBSEARCH_HEADER = 'web-search-2025-03-05'
export function addAnthropicHeaders(assistant: Assistant, model: Model): string[] {
const anthropicHeaders: string[] = []
if (isClaude45ReasoningModel(model) && isToolUseModeFunction(assistant)) {
const provider = getProviderByModel(model)
if (
isClaude45ReasoningModel(model) &&
isToolUseModeFunction(assistant) &&
!(isVertexProvider(provider) && isAwsBedrockProvider(provider))
) {
anthropicHeaders.push(INTERLEAVED_THINKING_HEADER)
}
if (isClaude4SeriesModel(model)) {
if (isVertexProvider(provider) && assistant.enableWebSearch) {
anthropicHeaders.push(WEBSEARCH_HEADER)
}
anthropicHeaders.push(CONTEXT_100M_HEADER)
}
return anthropicHeaders
}

View File

@ -85,19 +85,6 @@ export function supportsLargeFileUpload(model: Model): boolean {
})
}
/**
* TopP
*/
export function supportsTopP(model: Model): boolean {
const provider = getProviderByModel(model)
if (provider?.type === 'anthropic' || model?.endpoint_type === 'anthropic') {
return false
}
return true
}
/**
*
*/

View File

@ -3,17 +3,27 @@
* TopP
*/
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
import {
isClaude45ReasoningModel,
isClaudeReasoningModel,
isMaxTemperatureOneModel,
isNotSupportTemperatureAndTopP,
isSupportedFlexServiceTier
isSupportedFlexServiceTier,
isSupportedThinkingTokenClaudeModel
} from '@renderer/config/models'
import { getAssistantSettings } from '@renderer/services/AssistantService'
import { getAssistantSettings, getProviderByModel } from '@renderer/services/AssistantService'
import type { Assistant, Model } from '@renderer/types'
import { defaultTimeout } from '@shared/config/constant'
import { getAnthropicThinkingBudget } from '../utils/reasoning'
/**
* Claude 4.5 :
* - temperature 使 temperature
* - top_p 使 top_p
* - temperature ,top_p
* - 使
*
*/
export function getTemperature(assistant: Assistant, model: Model): number | undefined {
@ -27,7 +37,11 @@ export function getTemperature(assistant: Assistant, model: Model): number | und
return undefined
}
const assistantSettings = getAssistantSettings(assistant)
return assistantSettings?.enableTemperature ? assistantSettings?.temperature : undefined
let temperature = assistantSettings?.temperature
if (temperature && isMaxTemperatureOneModel(model)) {
temperature = Math.min(1, temperature)
}
return assistantSettings?.enableTemperature ? temperature : undefined
}
/**
@ -56,3 +70,18 @@ export function getTimeout(model: Model): number {
}
return defaultTimeout
}
export function getMaxTokens(assistant: Assistant, model: Model): number | undefined {
// NOTE: ai-sdk会把maxToken和budgetToken加起来
let { maxTokens = DEFAULT_MAX_TOKENS } = getAssistantSettings(assistant)
const provider = getProviderByModel(model)
if (isSupportedThinkingTokenClaudeModel(model) && ['anthropic', 'aws-bedrock'].includes(provider.type)) {
const { reasoning_effort: reasoningEffort } = getAssistantSettings(assistant)
const budget = getAnthropicThinkingBudget(maxTokens, reasoningEffort, model.id)
if (budget) {
maxTokens -= budget
}
}
return maxTokens
}

View File

@ -4,11 +4,12 @@
*/
import { anthropic } from '@ai-sdk/anthropic'
import { azure } from '@ai-sdk/azure'
import { google } from '@ai-sdk/google'
import { vertexAnthropic } from '@ai-sdk/google-vertex/anthropic/edge'
import { vertex } from '@ai-sdk/google-vertex/edge'
import { combineHeaders } from '@ai-sdk/provider-utils'
import type { WebSearchPluginConfig } from '@cherrystudio/ai-core/built-in/plugins'
import type { AnthropicSearchConfig, WebSearchPluginConfig } from '@cherrystudio/ai-core/built-in/plugins'
import { isBaseProvider } from '@cherrystudio/ai-core/core/providers/schemas'
import { loggerService } from '@logger'
import {
@ -17,13 +18,10 @@ import {
isOpenRouterBuiltInWebSearchModel,
isReasoningModel,
isSupportedReasoningEffortModel,
isSupportedThinkingTokenClaudeModel,
isSupportedThinkingTokenModel,
isWebSearchModel
} from '@renderer/config/models'
import { isAwsBedrockProvider } from '@renderer/config/providers'
import { isVertexProvider } from '@renderer/hooks/useVertexAI'
import { getAssistantSettings, getDefaultModel } from '@renderer/services/AssistantService'
import { getDefaultModel } from '@renderer/services/AssistantService'
import store from '@renderer/store'
import type { CherryWebSearchConfig } from '@renderer/store/websearch'
import { type Assistant, type MCPTool, type Provider } from '@renderer/types'
@ -36,11 +34,9 @@ import { stepCountIs } from 'ai'
import { getAiSdkProviderId } from '../provider/factory'
import { setupToolsConfig } from '../utils/mcp'
import { buildProviderOptions } from '../utils/options'
import { getAnthropicThinkingBudget } from '../utils/reasoning'
import { buildProviderBuiltinWebSearchConfig } from '../utils/websearch'
import { addAnthropicHeaders } from './header'
import { supportsTopP } from './modelCapabilities'
import { getTemperature, getTopP } from './modelParameters'
import { getMaxTokens, getTemperature, getTopP } from './modelParameters'
const logger = loggerService.withContext('parameterBuilder')
@ -63,7 +59,7 @@ export async function buildStreamTextParams(
timeout?: number
headers?: Record<string, string>
}
} = {}
}
): Promise<{
params: StreamTextParams
modelId: string
@ -80,8 +76,6 @@ export async function buildStreamTextParams(
const model = assistant.model || getDefaultModel()
const aiSdkProviderId = getAiSdkProviderId(provider)
let { maxTokens } = getAssistantSettings(assistant)
// 这三个变量透传出来,交给下面启用插件/中间件
// 也可以在外部构建好再传入buildStreamTextParams
// FIXME: qwen3即使关闭思考仍然会导致enableReasoning的结果为true
@ -118,16 +112,6 @@ export async function buildStreamTextParams(
enableGenerateImage
})
// NOTE: ai-sdk会把maxToken和budgetToken加起来
if (
enableReasoning &&
maxTokens !== undefined &&
isSupportedThinkingTokenClaudeModel(model) &&
(provider.type === 'anthropic' || provider.type === 'aws-bedrock')
) {
maxTokens -= getAnthropicThinkingBudget(assistant, model)
}
let webSearchPluginConfig: WebSearchPluginConfig | undefined = undefined
if (enableWebSearch) {
if (isBaseProvider(aiSdkProviderId)) {
@ -144,6 +128,17 @@ export async function buildStreamTextParams(
maxUses: webSearchConfig.maxResults,
blockedDomains: blockedDomains.length > 0 ? blockedDomains : undefined
}) as ProviderDefinedTool
} else if (aiSdkProviderId === 'azure-responses') {
tools.web_search_preview = azure.tools.webSearchPreview({
searchContextSize: webSearchPluginConfig?.openai!.searchContextSize
}) as ProviderDefinedTool
} else if (aiSdkProviderId === 'azure-anthropic') {
const blockedDomains = mapRegexToPatterns(webSearchConfig.excludeDomains)
const anthropicSearchOptions: AnthropicSearchConfig = {
maxUses: webSearchConfig.maxResults,
blockedDomains: blockedDomains.length > 0 ? blockedDomains : undefined
}
tools.web_search = anthropic.tools.webSearch_20250305(anthropicSearchOptions) as ProviderDefinedTool
}
}
@ -161,9 +156,10 @@ export async function buildStreamTextParams(
tools.url_context = google.tools.urlContext({}) as ProviderDefinedTool
break
case 'anthropic':
case 'azure-anthropic':
case 'google-vertex-anthropic':
tools.web_fetch = (
aiSdkProviderId === 'anthropic'
['anthropic', 'azure-anthropic'].includes(aiSdkProviderId)
? anthropic.tools.webFetch_20250910({
maxUses: webSearchConfig.maxResults,
blockedDomains: blockedDomains.length > 0 ? blockedDomains : undefined
@ -179,8 +175,7 @@ export async function buildStreamTextParams(
let headers: Record<string, string | undefined> = options.requestOptions?.headers ?? {}
// https://docs.claude.com/en/docs/build-with-claude/extended-thinking#interleaved-thinking
if (!isVertexProvider(provider) && !isAwsBedrockProvider(provider) && isAnthropicModel(model)) {
if (isAnthropicModel(model)) {
const newBetaHeaders = { 'anthropic-beta': addAnthropicHeaders(assistant, model).join(',') }
headers = combineHeaders(headers, newBetaHeaders)
}
@ -188,8 +183,9 @@ export async function buildStreamTextParams(
// 构建基础参数
const params: StreamTextParams = {
messages: sdkMessages,
maxOutputTokens: maxTokens,
maxOutputTokens: getMaxTokens(assistant, model),
temperature: getTemperature(assistant, model),
topP: getTopP(assistant, model),
abortSignal: options.requestOptions?.signal,
headers,
providerOptions,
@ -197,10 +193,6 @@ export async function buildStreamTextParams(
maxRetries: 0
}
if (supportsTopP(model)) {
params.topP = getTopP(assistant, model)
}
if (tools) {
params.tools = tools
}

View File

@ -23,6 +23,26 @@ vi.mock('@cherrystudio/ai-core', () => ({
}
}))
vi.mock('@renderer/services/AssistantService', () => ({
getProviderByModel: vi.fn(),
getAssistantSettings: vi.fn(),
getDefaultAssistant: vi.fn().mockReturnValue({
id: 'default',
name: 'Default Assistant',
prompt: '',
settings: {}
})
}))
vi.mock('@renderer/store/settings', () => ({
default: {},
settingsSlice: {
name: 'settings',
reducer: vi.fn(),
actions: {}
}
}))
// Mock the provider configs
vi.mock('../providerConfigs', () => ({
initializeNewProviders: vi.fn()

View File

@ -12,7 +12,14 @@ vi.mock('@renderer/services/LoggerService', () => ({
}))
vi.mock('@renderer/services/AssistantService', () => ({
getProviderByModel: vi.fn()
getProviderByModel: vi.fn(),
getAssistantSettings: vi.fn(),
getDefaultAssistant: vi.fn().mockReturnValue({
id: 'default',
name: 'Default Assistant',
prompt: '',
settings: {}
})
}))
vi.mock('@renderer/store', () => ({
@ -34,7 +41,7 @@ vi.mock('@renderer/utils/api', () => ({
}))
}))
vi.mock('@renderer/config/providers', async (importOriginal) => {
vi.mock('@renderer/utils/provider', async (importOriginal) => {
const actual = (await importOriginal()) as any
return {
...actual,
@ -53,10 +60,21 @@ vi.mock('@renderer/hooks/useVertexAI', () => ({
createVertexProvider: vi.fn()
}))
import { isCherryAIProvider, isPerplexityProvider } from '@renderer/config/providers'
vi.mock('@renderer/services/AssistantService', () => ({
getProviderByModel: vi.fn(),
getAssistantSettings: vi.fn(),
getDefaultAssistant: vi.fn().mockReturnValue({
id: 'default',
name: 'Default Assistant',
prompt: '',
settings: {}
})
}))
import { getProviderByModel } from '@renderer/services/AssistantService'
import type { Model, Provider } from '@renderer/types'
import { formatApiHost } from '@renderer/utils/api'
import { isCherryAIProvider, isPerplexityProvider } from '@renderer/utils/provider'
import { COPILOT_DEFAULT_HEADERS, COPILOT_EDITOR_VERSION, isCopilotResponsesModel } from '../constants'
import { getActualProvider, providerToAiSdkConfig } from '../providerConfig'

View File

@ -0,0 +1,22 @@
import type { Provider } from '@renderer/types'
import { provider2Provider, startsWith } from './helper'
import type { RuleSet } from './types'
// https://platform.claude.com/docs/en/build-with-claude/claude-in-microsoft-foundry
const AZURE_ANTHROPIC_RULES: RuleSet = {
rules: [
{
match: startsWith('claude'),
provider: (provider: Provider) => ({
...provider,
type: 'anthropic',
apiHost: provider.apiHost + 'anthropic/v1',
id: 'azure-anthropic'
})
}
],
fallbackRule: (provider: Provider) => provider
}
export const azureAnthropicProviderCreator = provider2Provider.bind(null, AZURE_ANTHROPIC_RULES)

View File

@ -2,8 +2,10 @@ import { hasProviderConfigByAlias, type ProviderId, resolveProviderConfigId } fr
import { createProvider as createProviderCore } from '@cherrystudio/ai-core/provider'
import { loggerService } from '@logger'
import type { Provider } from '@renderer/types'
import { isAzureOpenAIProvider, isAzureResponsesEndpoint } from '@renderer/utils/provider'
import type { Provider as AiSdkProvider } from 'ai'
import type { AiSdkConfig } from '../types'
import { initializeNewProviders } from './providerInitialization'
const logger = loggerService.withContext('ProviderFactory')
@ -55,9 +57,12 @@ function tryResolveProviderId(identifier: string): ProviderId | null {
* AI SDK Provider ID
*
*/
export function getAiSdkProviderId(provider: Provider): ProviderId | 'openai-compatible' {
export function getAiSdkProviderId(provider: Provider): string {
// 1. 尝试解析provider.id
const resolvedFromId = tryResolveProviderId(provider.id)
if (isAzureOpenAIProvider(provider) && isAzureResponsesEndpoint(provider)) {
return 'azure-responses'
}
if (resolvedFromId) {
return resolvedFromId
}
@ -73,11 +78,11 @@ export function getAiSdkProviderId(provider: Provider): ProviderId | 'openai-com
if (provider.apiHost.includes('api.openai.com')) {
return 'openai-chat'
}
// 3. 最后的fallback通常会成为openai-compatible
return provider.id as ProviderId
// 3. 最后的fallback使用provider本身的id
return provider.id
}
export async function createAiSdkProvider(config) {
export async function createAiSdkProvider(config: AiSdkConfig): Promise<AiSdkProvider | null> {
let localProvider: Awaited<AiSdkProvider> | null = null
try {
if (config.providerId === 'openai' && config.options?.mode === 'chat') {

View File

@ -1,20 +1,6 @@
import {
formatPrivateKey,
hasProviderConfig,
ProviderConfigFactory,
type ProviderId,
type ProviderSettingsMap
} from '@cherrystudio/ai-core/provider'
import { formatPrivateKey, hasProviderConfig, ProviderConfigFactory } from '@cherrystudio/ai-core/provider'
import { cacheService } from '@data/CacheService'
import { isOpenAIChatCompletionOnlyModel } from '@renderer/config/models'
import {
isAnthropicProvider,
isAzureOpenAIProvider,
isCherryAIProvider,
isGeminiProvider,
isNewApiProvider,
isPerplexityProvider
} from '@renderer/config/providers'
import {
getAwsBedrockAccessKeyId,
getAwsBedrockApiKey,
@ -22,14 +8,25 @@ import {
getAwsBedrockRegion,
getAwsBedrockSecretAccessKey
} from '@renderer/hooks/useAwsBedrock'
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
import { createVertexProvider, isVertexAIConfigured } from '@renderer/hooks/useVertexAI'
import { getProviderByModel } from '@renderer/services/AssistantService'
import store from '@renderer/store'
import { isSystemProvider, type Model, type Provider, SystemProviderIds } from '@renderer/types'
import { formatApiHost, formatAzureOpenAIApiHost, formatVertexApiHost, routeToEndpoint } from '@renderer/utils/api'
import {
isAnthropicProvider,
isAzureOpenAIProvider,
isCherryAIProvider,
isGeminiProvider,
isNewApiProvider,
isPerplexityProvider,
isVertexProvider
} from '@renderer/utils/provider'
import { cloneDeep } from 'lodash'
import type { AiSdkConfig } from '../types'
import { aihubmixProviderCreator, newApiResolverCreator, vertexAnthropicProviderCreator } from './config'
import { azureAnthropicProviderCreator } from './config/azure-anthropic'
import { COPILOT_DEFAULT_HEADERS } from './constants'
import { getAiSdkProviderId } from './factory'
@ -75,6 +72,9 @@ function handleSpecialProviders(model: Model, provider: Provider): Provider {
return vertexAnthropicProviderCreator(model, provider)
}
}
if (isAzureOpenAIProvider(provider)) {
return azureAnthropicProviderCreator(model, provider)
}
return provider
}
@ -132,13 +132,7 @@ export function getActualProvider(model: Model): Provider {
* Provider AI SDK
*
*/
export function providerToAiSdkConfig(
actualProvider: Provider,
model: Model
): {
providerId: ProviderId | 'openai-compatible'
options: ProviderSettingsMap[keyof ProviderSettingsMap]
} {
export function providerToAiSdkConfig(actualProvider: Provider, model: Model): AiSdkConfig {
const aiSdkProviderId = getAiSdkProviderId(actualProvider)
// 构建基础配置
@ -192,13 +186,10 @@ export function providerToAiSdkConfig(
// azure
// https://learn.microsoft.com/en-us/azure/ai-foundry/openai/latest
// https://learn.microsoft.com/en-us/azure/ai-foundry/openai/how-to/responses?tabs=python-key#responses-api
if (aiSdkProviderId === 'azure' || actualProvider.type === 'azure-openai') {
// extraOptions.apiVersion = actualProvider.apiVersion === 'preview' ? 'v1' : actualProvider.apiVersion 默认使用v1不使用azure endpoint
if (actualProvider.apiVersion === 'preview' || actualProvider.apiVersion === 'v1') {
extraOptions.mode = 'responses'
} else {
extraOptions.mode = 'chat'
}
if (aiSdkProviderId === 'azure-responses') {
extraOptions.mode = 'responses'
} else if (aiSdkProviderId === 'azure') {
extraOptions.mode = 'chat'
}
// bedrock
@ -238,7 +229,7 @@ export function providerToAiSdkConfig(
if (hasProviderConfig(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, baseConfig, extraOptions)
return {
providerId: aiSdkProviderId as ProviderId,
providerId: aiSdkProviderId,
options
}
}

View File

@ -32,6 +32,14 @@ export const NEW_PROVIDER_CONFIGS: ProviderConfig[] = [
supportsImageGeneration: true,
aliases: ['vertexai-anthropic']
},
{
id: 'azure-anthropic',
name: 'Azure AI Anthropic',
import: () => import('@ai-sdk/anthropic'),
creatorFunctionName: 'createAnthropic',
supportsImageGeneration: false,
aliases: ['azure-anthropic']
},
{
id: 'github-copilot-openai-compatible',
name: 'GitHub Copilot OpenAI Compatible',

View File

@ -133,7 +133,7 @@ export class AiSdkSpanAdapter {
// 详细记录转换过程
const operationId = attributes['ai.operationId']
logger.info('Converting AI SDK span to SpanEntity', {
logger.debug('Converting AI SDK span to SpanEntity', {
spanName: spanName,
operationId,
spanTag,
@ -149,7 +149,7 @@ export class AiSdkSpanAdapter {
})
if (tokenUsage) {
logger.info('Token usage data found', {
logger.debug('Token usage data found', {
spanName: spanName,
operationId,
usage: tokenUsage,
@ -158,7 +158,7 @@ export class AiSdkSpanAdapter {
}
if (inputs || outputs) {
logger.info('Input/Output data extracted', {
logger.debug('Input/Output data extracted', {
spanName: spanName,
operationId,
hasInputs: !!inputs,
@ -170,7 +170,7 @@ export class AiSdkSpanAdapter {
}
if (Object.keys(typeSpecificData).length > 0) {
logger.info('Type-specific data extracted', {
logger.debug('Type-specific data extracted', {
spanName: spanName,
operationId,
typeSpecificKeys: Object.keys(typeSpecificData),
@ -204,7 +204,7 @@ export class AiSdkSpanAdapter {
modelName: modelName || this.extractModelFromAttributes(attributes)
}
logger.info('AI SDK span successfully converted to SpanEntity', {
logger.debug('AI SDK span successfully converted to SpanEntity', {
spanName: spanName,
operationId,
spanId: spanContext.spanId,

View File

@ -0,0 +1,15 @@
/**
* This type definition file is only for renderer.
* It cannot be migrated to @renderer/types since files within it are actually being used by both main and renderer.
* If we do that, main would throw an error because it cannot import a module which imports a type from a browser-enviroment-only package.
* (ai-core package is set as browser-enviroment-only)
*
* TODO: We should separate them clearly. Keep renderer only types in renderer, and main only types in main, and shared types in shared.
*/
import type { ProviderSettingsMap } from '@cherrystudio/ai-core/provider'
export type AiSdkConfig = {
providerId: string
options: ProviderSettingsMap[keyof ProviderSettingsMap]
}

View File

@ -0,0 +1,121 @@
/**
* image.ts Unit Tests
* Tests for Gemini image generation utilities
*/
import type { Model, Provider } from '@renderer/types'
import { SystemProviderIds } from '@renderer/types'
import { describe, expect, it } from 'vitest'
import { buildGeminiGenerateImageParams, isOpenRouterGeminiGenerateImageModel } from '../image'
describe('image utils', () => {
describe('buildGeminiGenerateImageParams', () => {
it('should return correct response modalities', () => {
const result = buildGeminiGenerateImageParams()
expect(result).toEqual({
responseModalities: ['TEXT', 'IMAGE']
})
})
it('should return an object with responseModalities property', () => {
const result = buildGeminiGenerateImageParams()
expect(result).toHaveProperty('responseModalities')
expect(Array.isArray(result.responseModalities)).toBe(true)
expect(result.responseModalities).toHaveLength(2)
})
})
describe('isOpenRouterGeminiGenerateImageModel', () => {
const mockOpenRouterProvider: Provider = {
id: SystemProviderIds.openrouter,
name: 'OpenRouter',
apiKey: 'test-key',
apiHost: 'https://openrouter.ai/api/v1',
isSystem: true
} as Provider
const mockOtherProvider: Provider = {
id: SystemProviderIds.openai,
name: 'OpenAI',
apiKey: 'test-key',
apiHost: 'https://api.openai.com/v1',
isSystem: true
} as Provider
it('should return true for OpenRouter Gemini 2.5 Flash Image model', () => {
const model: Model = {
id: 'google/gemini-2.5-flash-image-preview',
name: 'Gemini 2.5 Flash Image',
provider: SystemProviderIds.openrouter
} as Model
const result = isOpenRouterGeminiGenerateImageModel(model, mockOpenRouterProvider)
expect(result).toBe(true)
})
it('should return false for non-Gemini model on OpenRouter', () => {
const model: Model = {
id: 'openai/gpt-4',
name: 'GPT-4',
provider: SystemProviderIds.openrouter
} as Model
const result = isOpenRouterGeminiGenerateImageModel(model, mockOpenRouterProvider)
expect(result).toBe(false)
})
it('should return false for Gemini model on non-OpenRouter provider', () => {
const model: Model = {
id: 'gemini-2.5-flash-image-preview',
name: 'Gemini 2.5 Flash Image',
provider: SystemProviderIds.gemini
} as Model
const result = isOpenRouterGeminiGenerateImageModel(model, mockOtherProvider)
expect(result).toBe(false)
})
it('should return false for Gemini model without image suffix', () => {
const model: Model = {
id: 'google/gemini-2.5-flash',
name: 'Gemini 2.5 Flash',
provider: SystemProviderIds.openrouter
} as Model
const result = isOpenRouterGeminiGenerateImageModel(model, mockOpenRouterProvider)
expect(result).toBe(false)
})
it('should handle model ID with partial match', () => {
const model: Model = {
id: 'google/gemini-2.5-flash-image-generation',
name: 'Gemini Image Gen',
provider: SystemProviderIds.openrouter
} as Model
const result = isOpenRouterGeminiGenerateImageModel(model, mockOpenRouterProvider)
expect(result).toBe(true)
})
it('should return false for custom provider', () => {
const customProvider: Provider = {
id: 'custom-provider-123',
name: 'Custom Provider',
apiKey: 'test-key',
apiHost: 'https://custom.com'
} as Provider
const model: Model = {
id: 'gemini-2.5-flash-image-preview',
name: 'Gemini 2.5 Flash Image',
provider: 'custom-provider-123'
} as Model
const result = isOpenRouterGeminiGenerateImageModel(model, customProvider)
expect(result).toBe(false)
})
})
})

View File

@ -0,0 +1,435 @@
/**
* mcp.ts Unit Tests
* Tests for MCP tools configuration and conversion utilities
*/
import type { MCPTool } from '@renderer/types'
import type { Tool } from 'ai'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { convertMcpToolsToAiSdkTools, setupToolsConfig } from '../mcp'
// Mock dependencies
vi.mock('@logger', () => ({
loggerService: {
withContext: () => ({
debug: vi.fn(),
error: vi.fn(),
warn: vi.fn(),
info: vi.fn()
})
}
}))
vi.mock('@renderer/utils/mcp-tools', () => ({
getMcpServerByTool: vi.fn(() => ({ id: 'test-server', autoApprove: false })),
isToolAutoApproved: vi.fn(() => false),
callMCPTool: vi.fn(async () => ({
content: [{ type: 'text', text: 'Tool executed successfully' }],
isError: false
}))
}))
vi.mock('@renderer/utils/userConfirmation', () => ({
requestToolConfirmation: vi.fn(async () => true)
}))
describe('mcp utils', () => {
beforeEach(() => {
vi.clearAllMocks()
})
describe('setupToolsConfig', () => {
it('should return undefined when no MCP tools provided', () => {
const result = setupToolsConfig()
expect(result).toBeUndefined()
})
it('should return undefined when empty MCP tools array provided', () => {
const result = setupToolsConfig([])
expect(result).toBeUndefined()
})
it('should convert MCP tools to AI SDK tools format', () => {
const mcpTools: MCPTool[] = [
{
id: 'test-tool-1',
serverId: 'test-server',
serverName: 'test-server',
name: 'test-tool',
description: 'A test tool',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {
query: { type: 'string' }
}
}
}
]
const result = setupToolsConfig(mcpTools)
expect(result).not.toBeUndefined()
expect(Object.keys(result!)).toEqual(['test-tool'])
expect(result!['test-tool']).toHaveProperty('description')
expect(result!['test-tool']).toHaveProperty('inputSchema')
expect(result!['test-tool']).toHaveProperty('execute')
})
it('should handle multiple MCP tools', () => {
const mcpTools: MCPTool[] = [
{
id: 'tool1-id',
serverId: 'server1',
serverName: 'server1',
name: 'tool1',
description: 'First tool',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {}
}
},
{
id: 'tool2-id',
serverId: 'server2',
serverName: 'server2',
name: 'tool2',
description: 'Second tool',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {}
}
}
]
const result = setupToolsConfig(mcpTools)
expect(result).not.toBeUndefined()
expect(Object.keys(result!)).toHaveLength(2)
expect(Object.keys(result!)).toEqual(['tool1', 'tool2'])
})
})
describe('convertMcpToolsToAiSdkTools', () => {
it('should convert single MCP tool to AI SDK tool', () => {
const mcpTools: MCPTool[] = [
{
id: 'get-weather-id',
serverId: 'weather-server',
serverName: 'weather-server',
name: 'get-weather',
description: 'Get weather information',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {
location: { type: 'string' }
},
required: ['location']
}
}
]
const result = convertMcpToolsToAiSdkTools(mcpTools)
expect(Object.keys(result)).toEqual(['get-weather'])
const tool = result['get-weather'] as Tool
expect(tool.description).toBe('Get weather information')
expect(tool.inputSchema).toBeDefined()
expect(typeof tool.execute).toBe('function')
})
it('should handle tool without description', () => {
const mcpTools: MCPTool[] = [
{
id: 'no-desc-tool-id',
serverId: 'test-server',
serverName: 'test-server',
name: 'no-desc-tool',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {}
}
}
]
const result = convertMcpToolsToAiSdkTools(mcpTools)
expect(Object.keys(result)).toEqual(['no-desc-tool'])
const tool = result['no-desc-tool'] as Tool
expect(tool.description).toBe('Tool from test-server')
})
it('should convert empty tools array', () => {
const result = convertMcpToolsToAiSdkTools([])
expect(result).toEqual({})
})
it('should handle complex input schemas', () => {
const mcpTools: MCPTool[] = [
{
id: 'complex-tool-id',
serverId: 'server',
serverName: 'server',
name: 'complex-tool',
description: 'Tool with complex schema',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {
name: { type: 'string' },
age: { type: 'number' },
tags: {
type: 'array',
items: { type: 'string' }
},
metadata: {
type: 'object',
properties: {
key: { type: 'string' }
}
}
},
required: ['name']
}
}
]
const result = convertMcpToolsToAiSdkTools(mcpTools)
expect(Object.keys(result)).toEqual(['complex-tool'])
const tool = result['complex-tool'] as Tool
expect(tool.inputSchema).toBeDefined()
expect(typeof tool.execute).toBe('function')
})
it('should preserve tool names with special characters', () => {
const mcpTools: MCPTool[] = [
{
id: 'special-tool-id',
serverId: 'server',
serverName: 'server',
name: 'tool_with-special.chars',
description: 'Special chars tool',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {}
}
}
]
const result = convertMcpToolsToAiSdkTools(mcpTools)
expect(Object.keys(result)).toEqual(['tool_with-special.chars'])
})
it('should handle multiple tools with different schemas', () => {
const mcpTools: MCPTool[] = [
{
id: 'string-tool-id',
serverId: 'server1',
serverName: 'server1',
name: 'string-tool',
description: 'String tool',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {
input: { type: 'string' }
}
}
},
{
id: 'number-tool-id',
serverId: 'server2',
serverName: 'server2',
name: 'number-tool',
description: 'Number tool',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {
count: { type: 'number' }
}
}
},
{
id: 'boolean-tool-id',
serverId: 'server3',
serverName: 'server3',
name: 'boolean-tool',
description: 'Boolean tool',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {
enabled: { type: 'boolean' }
}
}
}
]
const result = convertMcpToolsToAiSdkTools(mcpTools)
expect(Object.keys(result).sort()).toEqual(['boolean-tool', 'number-tool', 'string-tool'])
expect(result['string-tool']).toBeDefined()
expect(result['number-tool']).toBeDefined()
expect(result['boolean-tool']).toBeDefined()
})
})
describe('tool execution', () => {
it('should execute tool with user confirmation', async () => {
const { callMCPTool } = await import('@renderer/utils/mcp-tools')
const { requestToolConfirmation } = await import('@renderer/utils/userConfirmation')
vi.mocked(requestToolConfirmation).mockResolvedValue(true)
vi.mocked(callMCPTool).mockResolvedValue({
content: [{ type: 'text', text: 'Success' }],
isError: false
})
const mcpTools: MCPTool[] = [
{
id: 'test-exec-tool-id',
serverId: 'test-server',
serverName: 'test-server',
name: 'test-exec-tool',
description: 'Test execution tool',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {}
}
}
]
const tools = convertMcpToolsToAiSdkTools(mcpTools)
const tool = tools['test-exec-tool'] as Tool
const result = await tool.execute!({}, { messages: [], abortSignal: undefined, toolCallId: 'test-call-123' })
expect(requestToolConfirmation).toHaveBeenCalled()
expect(callMCPTool).toHaveBeenCalled()
expect(result).toEqual({
content: [{ type: 'text', text: 'Success' }],
isError: false
})
})
it('should handle user cancellation', async () => {
const { requestToolConfirmation } = await import('@renderer/utils/userConfirmation')
const { callMCPTool } = await import('@renderer/utils/mcp-tools')
vi.mocked(requestToolConfirmation).mockResolvedValue(false)
const mcpTools: MCPTool[] = [
{
id: 'cancelled-tool-id',
serverId: 'test-server',
serverName: 'test-server',
name: 'cancelled-tool',
description: 'Tool to cancel',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {}
}
}
]
const tools = convertMcpToolsToAiSdkTools(mcpTools)
const tool = tools['cancelled-tool'] as Tool
const result = await tool.execute!({}, { messages: [], abortSignal: undefined, toolCallId: 'cancel-call-123' })
expect(requestToolConfirmation).toHaveBeenCalled()
expect(callMCPTool).not.toHaveBeenCalled()
expect(result).toEqual({
content: [
{
type: 'text',
text: 'User declined to execute tool "cancelled-tool".'
}
],
isError: false
})
})
it('should handle tool execution error', async () => {
const { callMCPTool } = await import('@renderer/utils/mcp-tools')
const { requestToolConfirmation } = await import('@renderer/utils/userConfirmation')
vi.mocked(requestToolConfirmation).mockResolvedValue(true)
vi.mocked(callMCPTool).mockResolvedValue({
content: [{ type: 'text', text: 'Error occurred' }],
isError: true
})
const mcpTools: MCPTool[] = [
{
id: 'error-tool-id',
serverId: 'test-server',
serverName: 'test-server',
name: 'error-tool',
description: 'Tool that errors',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {}
}
}
]
const tools = convertMcpToolsToAiSdkTools(mcpTools)
const tool = tools['error-tool'] as Tool
await expect(
tool.execute!({}, { messages: [], abortSignal: undefined, toolCallId: 'error-call-123' })
).rejects.toEqual({
content: [{ type: 'text', text: 'Error occurred' }],
isError: true
})
})
it('should auto-approve when enabled', async () => {
const { callMCPTool, isToolAutoApproved } = await import('@renderer/utils/mcp-tools')
const { requestToolConfirmation } = await import('@renderer/utils/userConfirmation')
vi.mocked(isToolAutoApproved).mockReturnValue(true)
vi.mocked(callMCPTool).mockResolvedValue({
content: [{ type: 'text', text: 'Auto-approved success' }],
isError: false
})
const mcpTools: MCPTool[] = [
{
id: 'auto-approve-tool-id',
serverId: 'test-server',
serverName: 'test-server',
name: 'auto-approve-tool',
description: 'Auto-approved tool',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {}
}
}
]
const tools = convertMcpToolsToAiSdkTools(mcpTools)
const tool = tools['auto-approve-tool'] as Tool
const result = await tool.execute!({}, { messages: [], abortSignal: undefined, toolCallId: 'auto-call-123' })
expect(requestToolConfirmation).not.toHaveBeenCalled()
expect(callMCPTool).toHaveBeenCalled()
expect(result).toEqual({
content: [{ type: 'text', text: 'Auto-approved success' }],
isError: false
})
})
})
})

View File

@ -0,0 +1,545 @@
/**
* options.ts Unit Tests
* Tests for building provider-specific options
*/
import type { Assistant, Model, Provider } from '@renderer/types'
import { OpenAIServiceTiers, SystemProviderIds } from '@renderer/types'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { buildProviderOptions } from '../options'
// Mock dependencies
vi.mock('@cherrystudio/ai-core/provider', async (importOriginal) => {
const actual = (await importOriginal()) as object
return {
...actual,
baseProviderIdSchema: {
safeParse: vi.fn((id) => {
const baseProviders = [
'openai',
'openai-chat',
'azure',
'azure-responses',
'huggingface',
'anthropic',
'google',
'xai',
'deepseek',
'openrouter',
'openai-compatible'
]
if (baseProviders.includes(id)) {
return { success: true, data: id }
}
return { success: false }
})
},
customProviderIdSchema: {
safeParse: vi.fn((id) => {
const customProviders = ['google-vertex', 'google-vertex-anthropic', 'bedrock']
if (customProviders.includes(id)) {
return { success: true, data: id }
}
return { success: false, error: new Error('Invalid provider') }
})
}
}
})
vi.mock('../provider/factory', () => ({
getAiSdkProviderId: vi.fn((provider) => {
// Simulate the provider ID mapping
const mapping: Record<string, string> = {
[SystemProviderIds.gemini]: 'google',
[SystemProviderIds.openai]: 'openai',
[SystemProviderIds.anthropic]: 'anthropic',
[SystemProviderIds.grok]: 'xai',
[SystemProviderIds.deepseek]: 'deepseek',
[SystemProviderIds.openrouter]: 'openrouter'
}
return mapping[provider.id] || provider.id
})
}))
vi.mock('@renderer/config/models', async (importOriginal) => ({
...(await importOriginal()),
isOpenAIModel: vi.fn((model) => model.id.includes('gpt') || model.id.includes('o1')),
isQwenMTModel: vi.fn(() => false),
isSupportFlexServiceTierModel: vi.fn(() => true),
isOpenAILLMModel: vi.fn(() => true),
SYSTEM_MODELS: {
defaultModel: [
{ id: 'default-1', name: 'Default 1' },
{ id: 'default-2', name: 'Default 2' },
{ id: 'default-3', name: 'Default 3' }
]
}
}))
vi.mock(import('@renderer/utils/provider'), async (importOriginal) => {
return {
...(await importOriginal()),
isSupportServiceTierProvider: vi.fn((provider) => {
return [SystemProviderIds.openai, SystemProviderIds.groq].includes(provider.id)
})
}
})
vi.mock('@renderer/store/settings', () => ({
default: (state = { settings: {} }) => state
}))
vi.mock('@renderer/hooks/useSettings', () => ({
getStoreSetting: vi.fn((key) => {
if (key === 'openAI') {
return { summaryText: 'off', verbosity: 'medium' } as any
}
return {}
})
}))
vi.mock('@renderer/services/AssistantService', () => ({
getDefaultAssistant: vi.fn(() => ({
id: 'default',
name: 'Default Assistant',
settings: {}
})),
getAssistantSettings: vi.fn(() => ({
reasoning_effort: 'medium',
maxTokens: 4096
})),
getProviderByModel: vi.fn((model: Model) => ({
id: model.provider,
name: 'Mock Provider'
}))
}))
vi.mock('../reasoning', () => ({
getOpenAIReasoningParams: vi.fn(() => ({ reasoningEffort: 'medium' })),
getAnthropicReasoningParams: vi.fn(() => ({
thinking: { type: 'enabled', budgetTokens: 5000 }
})),
getGeminiReasoningParams: vi.fn(() => ({
thinkingConfig: { include_thoughts: true }
})),
getXAIReasoningParams: vi.fn(() => ({ reasoningEffort: 'high' })),
getBedrockReasoningParams: vi.fn(() => ({
reasoningConfig: { type: 'enabled', budgetTokens: 5000 }
})),
getReasoningEffort: vi.fn(() => ({ reasoningEffort: 'medium' })),
getCustomParameters: vi.fn(() => ({}))
}))
vi.mock('../image', () => ({
buildGeminiGenerateImageParams: vi.fn(() => ({
responseModalities: ['TEXT', 'IMAGE']
}))
}))
vi.mock('../websearch', () => ({
getWebSearchParams: vi.fn(() => ({ enable_search: true }))
}))
const ensureWindowApi = () => {
const globalWindow = window as any
globalWindow.api = globalWindow.api || {}
globalWindow.api.getAppInfo = globalWindow.api.getAppInfo || vi.fn(async () => ({ notesPath: '' }))
}
ensureWindowApi()
describe('options utils', () => {
const mockAssistant: Assistant = {
id: 'test-assistant',
name: 'Test Assistant',
settings: {}
} as Assistant
const mockModel: Model = {
id: 'gpt-4',
name: 'GPT-4',
provider: SystemProviderIds.openai
} as Model
beforeEach(() => {
vi.clearAllMocks()
})
describe('buildProviderOptions', () => {
describe('OpenAI provider', () => {
const openaiProvider: Provider = {
id: SystemProviderIds.openai,
name: 'OpenAI',
type: 'openai-response',
apiKey: 'test-key',
apiHost: 'https://api.openai.com/v1',
isSystem: true
} as Provider
it('should build basic OpenAI options', () => {
const result = buildProviderOptions(mockAssistant, mockModel, openaiProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result).toHaveProperty('openai')
expect(result.openai).toBeDefined()
})
it('should include reasoning parameters when enabled', () => {
const result = buildProviderOptions(mockAssistant, mockModel, openaiProvider, {
enableReasoning: true,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result.openai).toHaveProperty('reasoningEffort')
expect(result.openai.reasoningEffort).toBe('medium')
})
it('should include service tier when supported', () => {
const providerWithServiceTier: Provider = {
...openaiProvider,
serviceTier: OpenAIServiceTiers.auto
}
const result = buildProviderOptions(mockAssistant, mockModel, providerWithServiceTier, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result.openai).toHaveProperty('serviceTier')
expect(result.openai.serviceTier).toBe(OpenAIServiceTiers.auto)
})
})
describe('Anthropic provider', () => {
const anthropicProvider: Provider = {
id: SystemProviderIds.anthropic,
name: 'Anthropic',
type: 'anthropic',
apiKey: 'test-key',
apiHost: 'https://api.anthropic.com',
isSystem: true
} as Provider
const anthropicModel: Model = {
id: 'claude-3-5-sonnet-20241022',
name: 'Claude 3.5 Sonnet',
provider: SystemProviderIds.anthropic
} as Model
it('should build basic Anthropic options', () => {
const result = buildProviderOptions(mockAssistant, anthropicModel, anthropicProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result).toHaveProperty('anthropic')
expect(result.anthropic).toBeDefined()
})
it('should include reasoning parameters when enabled', () => {
const result = buildProviderOptions(mockAssistant, anthropicModel, anthropicProvider, {
enableReasoning: true,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result.anthropic).toHaveProperty('thinking')
expect(result.anthropic.thinking).toEqual({
type: 'enabled',
budgetTokens: 5000
})
})
})
describe('Google provider', () => {
const googleProvider: Provider = {
id: SystemProviderIds.gemini,
name: 'Google',
type: 'gemini',
apiKey: 'test-key',
apiHost: 'https://generativelanguage.googleapis.com',
isSystem: true,
models: [{ id: 'gemini-2.0-flash-exp' }] as Model[]
} as Provider
const googleModel: Model = {
id: 'gemini-2.0-flash-exp',
name: 'Gemini 2.0 Flash',
provider: SystemProviderIds.gemini
} as Model
it('should build basic Google options', () => {
const result = buildProviderOptions(mockAssistant, googleModel, googleProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result).toHaveProperty('google')
expect(result.google).toBeDefined()
})
it('should include reasoning parameters when enabled', () => {
const result = buildProviderOptions(mockAssistant, googleModel, googleProvider, {
enableReasoning: true,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result.google).toHaveProperty('thinkingConfig')
expect(result.google.thinkingConfig).toEqual({
include_thoughts: true
})
})
it('should include image generation parameters when enabled', () => {
const result = buildProviderOptions(mockAssistant, googleModel, googleProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: true
})
expect(result.google).toHaveProperty('responseModalities')
expect(result.google.responseModalities).toEqual(['TEXT', 'IMAGE'])
})
})
describe('xAI provider', () => {
const xaiProvider = {
id: SystemProviderIds.grok,
name: 'xAI',
type: 'new-api',
apiKey: 'test-key',
apiHost: 'https://api.x.ai/v1',
isSystem: true,
models: [] as Model[]
} as Provider
const xaiModel: Model = {
id: 'grok-2-latest',
name: 'Grok 2',
provider: SystemProviderIds.grok
} as Model
it('should build basic xAI options', () => {
const result = buildProviderOptions(mockAssistant, xaiModel, xaiProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result).toHaveProperty('xai')
expect(result.xai).toBeDefined()
})
it('should include reasoning parameters when enabled', () => {
const result = buildProviderOptions(mockAssistant, xaiModel, xaiProvider, {
enableReasoning: true,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result.xai).toHaveProperty('reasoningEffort')
expect(result.xai.reasoningEffort).toBe('high')
})
})
describe('DeepSeek provider', () => {
const deepseekProvider: Provider = {
id: SystemProviderIds.deepseek,
name: 'DeepSeek',
type: 'openai',
apiKey: 'test-key',
apiHost: 'https://api.deepseek.com',
isSystem: true
} as Provider
const deepseekModel: Model = {
id: 'deepseek-chat',
name: 'DeepSeek Chat',
provider: SystemProviderIds.deepseek
} as Model
it('should build basic DeepSeek options', () => {
const result = buildProviderOptions(mockAssistant, deepseekModel, deepseekProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result).toHaveProperty('deepseek')
expect(result.deepseek).toBeDefined()
})
})
describe('OpenRouter provider', () => {
const openrouterProvider: Provider = {
id: SystemProviderIds.openrouter,
name: 'OpenRouter',
type: 'openai',
apiKey: 'test-key',
apiHost: 'https://openrouter.ai/api/v1',
isSystem: true
} as Provider
const openrouterModel: Model = {
id: 'openai/gpt-4',
name: 'GPT-4',
provider: SystemProviderIds.openrouter
} as Model
it('should build basic OpenRouter options', () => {
const result = buildProviderOptions(mockAssistant, openrouterModel, openrouterProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result).toHaveProperty('openrouter')
expect(result.openrouter).toBeDefined()
})
it('should include web search parameters when enabled', () => {
const result = buildProviderOptions(mockAssistant, openrouterModel, openrouterProvider, {
enableReasoning: false,
enableWebSearch: true,
enableGenerateImage: false
})
expect(result.openrouter).toHaveProperty('enable_search')
})
})
describe('Custom parameters', () => {
it('should merge custom parameters', async () => {
const { getCustomParameters } = await import('../reasoning')
vi.mocked(getCustomParameters).mockReturnValue({
custom_param: 'custom_value',
another_param: 123
})
const result = buildProviderOptions(
mockAssistant,
mockModel,
{
id: SystemProviderIds.openai,
name: 'OpenAI',
type: 'openai',
apiKey: 'test-key',
apiHost: 'https://api.openai.com/v1'
} as Provider,
{
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
}
)
expect(result.openai).toHaveProperty('custom_param')
expect(result.openai.custom_param).toBe('custom_value')
expect(result.openai).toHaveProperty('another_param')
expect(result.openai.another_param).toBe(123)
})
})
describe('Multiple capabilities', () => {
const googleProvider = {
id: SystemProviderIds.gemini,
name: 'Google',
type: 'gemini',
apiKey: 'test-key',
apiHost: 'https://generativelanguage.googleapis.com',
isSystem: true,
models: [] as Model[]
} as Provider
const googleModel: Model = {
id: 'gemini-2.0-flash-exp',
name: 'Gemini 2.0 Flash',
provider: SystemProviderIds.gemini
} as Model
it('should combine reasoning and image generation', () => {
const result = buildProviderOptions(mockAssistant, googleModel, googleProvider, {
enableReasoning: true,
enableWebSearch: false,
enableGenerateImage: true
})
expect(result.google).toHaveProperty('thinkingConfig')
expect(result.google).toHaveProperty('responseModalities')
})
it('should handle all capabilities enabled', () => {
const result = buildProviderOptions(mockAssistant, googleModel, googleProvider, {
enableReasoning: true,
enableWebSearch: true,
enableGenerateImage: true
})
expect(result.google).toBeDefined()
expect(Object.keys(result.google).length).toBeGreaterThan(0)
})
})
describe('Vertex AI providers', () => {
it('should map google-vertex to google', () => {
const vertexProvider = {
id: 'google-vertex',
name: 'Vertex AI',
type: 'vertexai',
apiKey: 'test-key',
apiHost: 'https://vertex-ai.googleapis.com',
models: [] as Model[]
} as Provider
const vertexModel: Model = {
id: 'gemini-2.0-flash-exp',
name: 'Gemini 2.0 Flash',
provider: 'google-vertex'
} as Model
const result = buildProviderOptions(mockAssistant, vertexModel, vertexProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result).toHaveProperty('google')
})
it('should map google-vertex-anthropic to anthropic', () => {
const vertexAnthropicProvider = {
id: 'google-vertex-anthropic',
name: 'Vertex AI Anthropic',
type: 'vertex-anthropic',
apiKey: 'test-key',
apiHost: 'https://vertex-ai.googleapis.com',
models: [] as Model[]
} as Provider
const vertexModel: Model = {
id: 'claude-3-5-sonnet-20241022',
name: 'Claude 3.5 Sonnet',
provider: 'google-vertex-anthropic'
} as Model
const result = buildProviderOptions(mockAssistant, vertexModel, vertexAnthropicProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result).toHaveProperty('anthropic')
})
})
})
})

View File

@ -0,0 +1,967 @@
/**
* reasoning.ts Unit Tests
* Tests for reasoning parameter generation utilities
*/
import { getStoreSetting } from '@renderer/hooks/useSettings'
import type { SettingsState } from '@renderer/store/settings'
import type { Assistant, Model, Provider } from '@renderer/types'
import { SystemProviderIds } from '@renderer/types'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import {
getAnthropicReasoningParams,
getBedrockReasoningParams,
getCustomParameters,
getGeminiReasoningParams,
getOpenAIReasoningParams,
getReasoningEffort,
getXAIReasoningParams
} from '../reasoning'
function defaultGetStoreSetting<K extends keyof SettingsState>(key: K): SettingsState[K] {
if (key === 'openAI') {
return {
summaryText: 'auto',
verbosity: 'medium'
} as SettingsState[K]
}
return undefined as SettingsState[K]
}
// Mock dependencies
vi.mock('@logger', () => ({
loggerService: {
withContext: () => ({
debug: vi.fn(),
error: vi.fn(),
warn: vi.fn(),
info: vi.fn()
})
}
}))
vi.mock('@renderer/store/settings', () => ({
default: (state = { settings: {} }) => state
}))
vi.mock('@renderer/store/llm', () => ({
initialState: {},
default: (state = { llm: {} }) => state
}))
vi.mock('@renderer/config/constant', () => ({
DEFAULT_MAX_TOKENS: 4096,
isMac: false,
isWin: false,
TOKENFLUX_HOST: 'mock-host'
}))
vi.mock('@renderer/utils/provider', () => ({
isSupportEnableThinkingProvider: vi.fn((provider) => {
return [SystemProviderIds.dashscope, SystemProviderIds.silicon].includes(provider.id)
})
}))
vi.mock('@renderer/config/models', async (importOriginal) => {
const actual: any = await importOriginal()
return {
...actual,
isReasoningModel: vi.fn(() => false),
isOpenAIDeepResearchModel: vi.fn(() => false),
isOpenAIModel: vi.fn(() => false),
isSupportedReasoningEffortOpenAIModel: vi.fn(() => false),
isSupportedThinkingTokenQwenModel: vi.fn(() => false),
isQwenReasoningModel: vi.fn(() => false),
isSupportedThinkingTokenClaudeModel: vi.fn(() => false),
isSupportedThinkingTokenGeminiModel: vi.fn(() => false),
isSupportedThinkingTokenDoubaoModel: vi.fn(() => false),
isSupportedThinkingTokenZhipuModel: vi.fn(() => false),
isSupportedReasoningEffortModel: vi.fn(() => false),
isDeepSeekHybridInferenceModel: vi.fn(() => false),
isSupportedReasoningEffortGrokModel: vi.fn(() => false),
getThinkModelType: vi.fn(() => 'default'),
isDoubaoSeedAfter251015: vi.fn(() => false),
isDoubaoThinkingAutoModel: vi.fn(() => false),
isGrok4FastReasoningModel: vi.fn(() => false),
isGrokReasoningModel: vi.fn(() => false),
isOpenAIReasoningModel: vi.fn(() => false),
isQwenAlwaysThinkModel: vi.fn(() => false),
isSupportedThinkingTokenHunyuanModel: vi.fn(() => false),
isSupportedThinkingTokenModel: vi.fn(() => false),
isGPT51SeriesModel: vi.fn(() => false)
}
})
vi.mock('@renderer/hooks/useSettings', () => ({
getStoreSetting: vi.fn(defaultGetStoreSetting)
}))
vi.mock('@renderer/services/AssistantService', () => ({
getAssistantSettings: vi.fn((assistant) => ({
maxTokens: assistant?.settings?.maxTokens || 4096,
reasoning_effort: assistant?.settings?.reasoning_effort
})),
getProviderByModel: vi.fn((model) => ({
id: model.provider,
name: 'Test Provider'
})),
getDefaultAssistant: vi.fn(() => ({
id: 'default',
name: 'Default Assistant',
settings: {}
}))
}))
const ensureWindowApi = () => {
const globalWindow = window as any
globalWindow.api = globalWindow.api || {}
globalWindow.api.getAppInfo = globalWindow.api.getAppInfo || vi.fn(async () => ({ notesPath: '' }))
}
ensureWindowApi()
describe('reasoning utils', () => {
beforeEach(() => {
vi.resetAllMocks()
})
describe('getReasoningEffort', () => {
it('should return empty object for non-reasoning model', async () => {
const model: Model = {
id: 'gpt-4',
name: 'GPT-4',
provider: SystemProviderIds.openai
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getReasoningEffort(assistant, model)
expect(result).toEqual({})
})
it('should disable reasoning for OpenRouter when no reasoning effort set', async () => {
const { isReasoningModel } = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(true)
const model: Model = {
id: 'anthropic/claude-sonnet-4',
name: 'Claude Sonnet 4',
provider: SystemProviderIds.openrouter
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getReasoningEffort(assistant, model)
expect(result).toEqual({ reasoning: { enabled: false, exclude: true } })
})
it('should handle Qwen models with enable_thinking', async () => {
const { isReasoningModel, isSupportedThinkingTokenQwenModel, isQwenReasoningModel } = await import(
'@renderer/config/models'
)
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isSupportedThinkingTokenQwenModel).mockReturnValue(true)
vi.mocked(isQwenReasoningModel).mockReturnValue(true)
const model: Model = {
id: 'qwen-plus',
name: 'Qwen Plus',
provider: SystemProviderIds.dashscope
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
reasoning_effort: 'medium'
}
} as Assistant
const result = getReasoningEffort(assistant, model)
expect(result).toHaveProperty('enable_thinking')
})
it('should handle Claude models with thinking config', async () => {
const {
isSupportedThinkingTokenClaudeModel,
isReasoningModel,
isQwenReasoningModel,
isSupportedThinkingTokenGeminiModel,
isSupportedThinkingTokenDoubaoModel,
isSupportedThinkingTokenZhipuModel,
isSupportedReasoningEffortModel
} = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isSupportedThinkingTokenClaudeModel).mockReturnValue(true)
vi.mocked(isQwenReasoningModel).mockReturnValue(false)
vi.mocked(isSupportedThinkingTokenGeminiModel).mockReturnValue(false)
vi.mocked(isSupportedThinkingTokenDoubaoModel).mockReturnValue(false)
vi.mocked(isSupportedThinkingTokenZhipuModel).mockReturnValue(false)
vi.mocked(isSupportedReasoningEffortModel).mockReturnValue(false)
const model: Model = {
id: 'claude-3-7-sonnet',
name: 'Claude 3.7 Sonnet',
provider: SystemProviderIds.anthropic
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
reasoning_effort: 'high',
maxTokens: 4096
}
} as Assistant
const result = getReasoningEffort(assistant, model)
expect(result).toEqual({
thinking: {
type: 'enabled',
budget_tokens: expect.any(Number)
}
})
})
it('should handle Gemini Flash models with thinking budget 0', async () => {
const {
isSupportedThinkingTokenGeminiModel,
isReasoningModel,
isQwenReasoningModel,
isSupportedThinkingTokenClaudeModel,
isSupportedThinkingTokenDoubaoModel,
isSupportedThinkingTokenZhipuModel,
isOpenAIDeepResearchModel,
isSupportedThinkingTokenQwenModel,
isSupportedThinkingTokenHunyuanModel,
isDeepSeekHybridInferenceModel
} = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isOpenAIDeepResearchModel).mockReturnValue(false)
vi.mocked(isSupportedThinkingTokenGeminiModel).mockReturnValue(true)
vi.mocked(isQwenReasoningModel).mockReturnValue(false)
vi.mocked(isSupportedThinkingTokenClaudeModel).mockReturnValue(false)
vi.mocked(isSupportedThinkingTokenDoubaoModel).mockReturnValue(false)
vi.mocked(isSupportedThinkingTokenZhipuModel).mockReturnValue(false)
vi.mocked(isSupportedThinkingTokenQwenModel).mockReturnValue(false)
vi.mocked(isSupportedThinkingTokenHunyuanModel).mockReturnValue(false)
vi.mocked(isDeepSeekHybridInferenceModel).mockReturnValue(false)
const model: Model = {
id: 'gemini-2.5-flash',
name: 'Gemini 2.5 Flash',
provider: SystemProviderIds.openai
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getReasoningEffort(assistant, model)
expect(result).toEqual({
extra_body: {
google: {
thinking_config: {
thinking_budget: 0
}
}
}
})
})
it('should handle GPT-5.1 reasoning model with effort levels', async () => {
const {
isReasoningModel,
isOpenAIDeepResearchModel,
isSupportedReasoningEffortModel,
isGPT51SeriesModel,
getThinkModelType
} = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isOpenAIDeepResearchModel).mockReturnValue(false)
vi.mocked(isSupportedReasoningEffortModel).mockReturnValue(true)
vi.mocked(getThinkModelType).mockReturnValue('gpt5_1')
vi.mocked(isGPT51SeriesModel).mockReturnValue(true)
const model: Model = {
id: 'gpt-5.1',
name: 'GPT-5.1',
provider: SystemProviderIds.openai
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
reasoning_effort: 'none'
}
} as Assistant
const result = getReasoningEffort(assistant, model)
expect(result).toEqual({
reasoningEffort: 'none'
})
})
it('should handle DeepSeek hybrid inference models', async () => {
const { isReasoningModel, isDeepSeekHybridInferenceModel } = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isDeepSeekHybridInferenceModel).mockReturnValue(true)
const model: Model = {
id: 'deepseek-v3.1',
name: 'DeepSeek V3.1',
provider: SystemProviderIds.silicon
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
reasoning_effort: 'high'
}
} as Assistant
const result = getReasoningEffort(assistant, model)
expect(result).toEqual({
enable_thinking: true
})
})
it('should return medium effort for deep research models', async () => {
const { isReasoningModel, isOpenAIDeepResearchModel } = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isOpenAIDeepResearchModel).mockReturnValue(true)
const model: Model = {
id: 'o3-deep-research',
provider: SystemProviderIds.openai
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getReasoningEffort(assistant, model)
expect(result).toEqual({ reasoning_effort: 'medium' })
})
it('should return empty for groq provider', async () => {
const { getProviderByModel } = await import('@renderer/services/AssistantService')
vi.mocked(getProviderByModel).mockReturnValue({
id: 'groq',
name: 'Groq'
} as Provider)
const model: Model = {
id: 'groq-model',
name: 'Groq Model',
provider: 'groq'
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getReasoningEffort(assistant, model)
expect(result).toEqual({})
})
})
describe('getOpenAIReasoningParams', () => {
it('should return empty object for non-reasoning model', async () => {
const model: Model = {
id: 'gpt-4',
name: 'GPT-4',
provider: SystemProviderIds.openai
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getOpenAIReasoningParams(assistant, model)
expect(result).toEqual({})
})
it('should return empty when no reasoning effort set', async () => {
const model: Model = {
id: 'o1-preview',
name: 'O1 Preview',
provider: SystemProviderIds.openai
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getOpenAIReasoningParams(assistant, model)
expect(result).toEqual({})
})
it('should return reasoning effort for OpenAI models', async () => {
const { isReasoningModel, isOpenAIModel, isSupportedReasoningEffortOpenAIModel } = await import(
'@renderer/config/models'
)
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isOpenAIModel).mockReturnValue(true)
vi.mocked(isSupportedReasoningEffortOpenAIModel).mockReturnValue(true)
const model: Model = {
id: 'gpt-5.1',
name: 'GPT 5.1',
provider: SystemProviderIds.openai
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
reasoning_effort: 'high'
}
} as Assistant
const result = getOpenAIReasoningParams(assistant, model)
expect(result).toEqual({
reasoningEffort: 'high',
reasoningSummary: 'auto'
})
})
it('should include reasoning summary when not o1-pro', async () => {
const { isReasoningModel, isOpenAIModel, isSupportedReasoningEffortOpenAIModel } = await import(
'@renderer/config/models'
)
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isOpenAIModel).mockReturnValue(true)
vi.mocked(isSupportedReasoningEffortOpenAIModel).mockReturnValue(true)
const model: Model = {
id: 'gpt-5',
provider: SystemProviderIds.openai
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
reasoning_effort: 'medium'
}
} as Assistant
const result = getOpenAIReasoningParams(assistant, model)
expect(result).toEqual({
reasoningEffort: 'medium',
reasoningSummary: 'auto'
})
})
it('should not include reasoning summary for o1-pro', async () => {
const { isReasoningModel, isOpenAIDeepResearchModel, isSupportedReasoningEffortOpenAIModel } = await import(
'@renderer/config/models'
)
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isOpenAIDeepResearchModel).mockReturnValue(false)
vi.mocked(isSupportedReasoningEffortOpenAIModel).mockReturnValue(true)
vi.mocked(getStoreSetting).mockReturnValue({ summaryText: 'off' } as any)
const model: Model = {
id: 'o1-pro',
name: 'O1 Pro',
provider: SystemProviderIds.openai
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
reasoning_effort: 'high'
}
} as Assistant
const result = getOpenAIReasoningParams(assistant, model)
expect(result).toEqual({
reasoningEffort: 'high',
reasoningSummary: undefined
})
})
it('should force medium effort for deep research models', async () => {
const { isReasoningModel, isOpenAIModel, isOpenAIDeepResearchModel, isSupportedReasoningEffortOpenAIModel } =
await import('@renderer/config/models')
const { getStoreSetting } = await import('@renderer/hooks/useSettings')
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isOpenAIModel).mockReturnValue(true)
vi.mocked(isOpenAIDeepResearchModel).mockReturnValue(true)
vi.mocked(isSupportedReasoningEffortOpenAIModel).mockReturnValue(true)
vi.mocked(getStoreSetting).mockReturnValue({ summaryText: 'off' } as any)
const model: Model = {
id: 'o3-deep-research',
name: 'O3 Mini',
provider: SystemProviderIds.openai
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
reasoning_effort: 'high'
}
} as Assistant
const result = getOpenAIReasoningParams(assistant, model)
expect(result).toEqual({
reasoningEffort: 'medium',
reasoningSummary: 'off'
})
})
})
describe('getAnthropicReasoningParams', () => {
it('should return empty for non-reasoning model', async () => {
const { isReasoningModel } = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(false)
const model: Model = {
id: 'claude-3-5-sonnet',
name: 'Claude 3.5 Sonnet',
provider: SystemProviderIds.anthropic
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getAnthropicReasoningParams(assistant, model)
expect(result).toEqual({})
})
it('should return disabled thinking when no reasoning effort', async () => {
const { isReasoningModel, isSupportedThinkingTokenClaudeModel } = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isSupportedThinkingTokenClaudeModel).mockReturnValue(false)
const model: Model = {
id: 'claude-3-7-sonnet',
name: 'Claude 3.7 Sonnet',
provider: SystemProviderIds.anthropic
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getAnthropicReasoningParams(assistant, model)
expect(result).toEqual({
thinking: {
type: 'disabled'
}
})
})
it('should return enabled thinking with budget for Claude models', async () => {
const { isReasoningModel, isSupportedThinkingTokenClaudeModel } = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isSupportedThinkingTokenClaudeModel).mockReturnValue(true)
const model: Model = {
id: 'claude-3-7-sonnet',
name: 'Claude 3.7 Sonnet',
provider: SystemProviderIds.anthropic
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
reasoning_effort: 'medium',
maxTokens: 4096
}
} as Assistant
const result = getAnthropicReasoningParams(assistant, model)
expect(result).toEqual({
thinking: {
type: 'enabled',
budgetTokens: 2048
}
})
})
})
describe('getGeminiReasoningParams', () => {
it('should return empty for non-reasoning model', async () => {
const { isReasoningModel } = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(false)
const model: Model = {
id: 'gemini-2.0-flash',
name: 'Gemini 2.0 Flash',
provider: SystemProviderIds.gemini
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getGeminiReasoningParams(assistant, model)
expect(result).toEqual({})
})
it('should disable thinking for Flash models without reasoning effort', async () => {
const { isReasoningModel, isSupportedThinkingTokenGeminiModel } = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isSupportedThinkingTokenGeminiModel).mockReturnValue(true)
const model: Model = {
id: 'gemini-2.5-flash',
name: 'Gemini 2.5 Flash',
provider: SystemProviderIds.gemini
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getGeminiReasoningParams(assistant, model)
expect(result).toEqual({
thinkingConfig: {
includeThoughts: false,
thinkingBudget: 0
}
})
})
it('should enable thinking with budget for reasoning effort', async () => {
const { isReasoningModel, isSupportedThinkingTokenGeminiModel } = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isSupportedThinkingTokenGeminiModel).mockReturnValue(true)
const model: Model = {
id: 'gemini-2.5-pro',
name: 'Gemini 2.5 Pro',
provider: SystemProviderIds.gemini
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
reasoning_effort: 'medium'
}
} as Assistant
const result = getGeminiReasoningParams(assistant, model)
expect(result).toEqual({
thinkingConfig: {
thinkingBudget: 16448,
includeThoughts: true
}
})
})
it('should enable thinking without budget for auto effort ratio > 1', async () => {
const { isReasoningModel, isSupportedThinkingTokenGeminiModel } = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isSupportedThinkingTokenGeminiModel).mockReturnValue(true)
const model: Model = {
id: 'gemini-2.5-pro',
name: 'Gemini 2.5 Pro',
provider: SystemProviderIds.gemini
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
reasoning_effort: 'auto'
}
} as Assistant
const result = getGeminiReasoningParams(assistant, model)
expect(result).toEqual({
thinkingConfig: {
includeThoughts: true
}
})
})
})
describe('getXAIReasoningParams', () => {
it('should return empty for non-Grok model', async () => {
const { isSupportedReasoningEffortGrokModel } = await import('@renderer/config/models')
vi.mocked(isSupportedReasoningEffortGrokModel).mockReturnValue(false)
const model: Model = {
id: 'other-model',
name: 'Other Model',
provider: SystemProviderIds.grok
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getXAIReasoningParams(assistant, model)
expect(result).toEqual({})
})
it('should return empty when no reasoning effort', async () => {
const { isSupportedReasoningEffortGrokModel } = await import('@renderer/config/models')
vi.mocked(isSupportedReasoningEffortGrokModel).mockReturnValue(true)
const model: Model = {
id: 'grok-2',
name: 'Grok 2',
provider: SystemProviderIds.grok
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getXAIReasoningParams(assistant, model)
expect(result).toEqual({})
})
it('should return reasoning effort for Grok models', async () => {
const { isSupportedReasoningEffortGrokModel } = await import('@renderer/config/models')
vi.mocked(isSupportedReasoningEffortGrokModel).mockReturnValue(true)
const model: Model = {
id: 'grok-3',
name: 'Grok 3',
provider: SystemProviderIds.grok
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
reasoning_effort: 'high'
}
} as Assistant
const result = getXAIReasoningParams(assistant, model)
expect(result).toHaveProperty('reasoningEffort')
expect(result.reasoningEffort).toBe('high')
})
})
describe('getBedrockReasoningParams', () => {
it('should return empty for non-reasoning model', async () => {
const model: Model = {
id: 'other-model',
name: 'Other Model',
provider: 'bedrock'
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getBedrockReasoningParams(assistant, model)
expect(result).toEqual({})
})
it('should return empty when no reasoning effort', async () => {
const model: Model = {
id: 'claude-3-7-sonnet',
name: 'Claude 3.7 Sonnet',
provider: 'bedrock'
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getBedrockReasoningParams(assistant, model)
expect(result).toEqual({})
})
it('should return reasoning config for Claude models on Bedrock', async () => {
const { isReasoningModel, isSupportedThinkingTokenClaudeModel } = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isSupportedThinkingTokenClaudeModel).mockReturnValue(true)
const model: Model = {
id: 'claude-3-7-sonnet',
name: 'Claude 3.7 Sonnet',
provider: 'bedrock'
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
reasoning_effort: 'medium',
maxTokens: 4096
}
} as Assistant
const result = getBedrockReasoningParams(assistant, model)
expect(result).toEqual({
reasoningConfig: {
type: 'enabled',
budgetTokens: 2048
}
})
})
})
describe('getCustomParameters', () => {
it('should return empty object when no custom parameters', async () => {
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getCustomParameters(assistant)
expect(result).toEqual({})
})
it('should return custom parameters as key-value pairs', async () => {
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
customParameters: [
{ name: 'param1', value: 'value1', type: 'string' },
{ name: 'param2', value: 123, type: 'number' }
]
}
} as Assistant
const result = getCustomParameters(assistant)
expect(result).toEqual({
param1: 'value1',
param2: 123
})
})
it('should parse JSON type parameters', async () => {
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
customParameters: [{ name: 'config', value: '{"key": "value"}', type: 'json' }]
}
} as Assistant
const result = getCustomParameters(assistant)
expect(result).toEqual({
config: { key: 'value' }
})
})
it('should handle invalid JSON gracefully', async () => {
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
customParameters: [{ name: 'invalid', value: '{invalid json', type: 'json' }]
}
} as Assistant
const result = getCustomParameters(assistant)
expect(result).toEqual({
invalid: '{invalid json'
})
})
it('should handle undefined JSON value', async () => {
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
customParameters: [{ name: 'undef', value: 'undefined', type: 'json' }]
}
} as Assistant
const result = getCustomParameters(assistant)
expect(result).toEqual({
undef: undefined
})
})
it('should skip parameters with empty names', async () => {
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
customParameters: [
{ name: '', value: 'value1', type: 'string' },
{ name: ' ', value: 'value2', type: 'string' },
{ name: 'valid', value: 'value3', type: 'string' }
]
}
} as Assistant
const result = getCustomParameters(assistant)
expect(result).toEqual({
valid: 'value3'
})
})
})
})

View File

@ -0,0 +1,384 @@
/**
* websearch.ts Unit Tests
* Tests for web search parameters generation utilities
*/
import type { CherryWebSearchConfig } from '@renderer/store/websearch'
import type { Model } from '@renderer/types'
import { describe, expect, it, vi } from 'vitest'
import { buildProviderBuiltinWebSearchConfig, getWebSearchParams } from '../websearch'
// Mock dependencies
vi.mock('@renderer/config/models', () => ({
isOpenAIWebSearchChatCompletionOnlyModel: vi.fn((model) => model?.id?.includes('o1-pro') ?? false),
isOpenAIDeepResearchModel: vi.fn((model) => model?.id?.includes('o3-mini') ?? false)
}))
vi.mock('@renderer/utils/blacklistMatchPattern', () => ({
mapRegexToPatterns: vi.fn((patterns) => patterns || [])
}))
describe('websearch utils', () => {
describe('getWebSearchParams', () => {
it('should return enhancement params for hunyuan provider', () => {
const model: Model = {
id: 'hunyuan-model',
name: 'Hunyuan Model',
provider: 'hunyuan'
} as Model
const result = getWebSearchParams(model)
expect(result).toEqual({
enable_enhancement: true,
citation: true,
search_info: true
})
})
it('should return search params for dashscope provider', () => {
const model: Model = {
id: 'qwen-model',
name: 'Qwen Model',
provider: 'dashscope'
} as Model
const result = getWebSearchParams(model)
expect(result).toEqual({
enable_search: true,
search_options: {
forced_search: true
}
})
})
it('should return web_search_options for OpenAI web search models', () => {
const model: Model = {
id: 'o1-pro',
name: 'O1 Pro',
provider: 'openai'
} as Model
const result = getWebSearchParams(model)
expect(result).toEqual({
web_search_options: {}
})
})
it('should return empty object for other providers', () => {
const model: Model = {
id: 'gpt-4',
name: 'GPT-4',
provider: 'openai'
} as Model
const result = getWebSearchParams(model)
expect(result).toEqual({})
})
it('should return empty object for custom provider', () => {
const model: Model = {
id: 'custom-model',
name: 'Custom Model',
provider: 'custom-provider'
} as Model
const result = getWebSearchParams(model)
expect(result).toEqual({})
})
})
describe('buildProviderBuiltinWebSearchConfig', () => {
const defaultWebSearchConfig: CherryWebSearchConfig = {
searchWithTime: true,
maxResults: 50,
excludeDomains: []
}
describe('openai provider', () => {
it('should return low search context size for low maxResults', () => {
const config: CherryWebSearchConfig = {
searchWithTime: true,
maxResults: 20,
excludeDomains: []
}
const result = buildProviderBuiltinWebSearchConfig('openai', config)
expect(result).toEqual({
openai: {
searchContextSize: 'low'
}
})
})
it('should return medium search context size for medium maxResults', () => {
const config: CherryWebSearchConfig = {
searchWithTime: true,
maxResults: 50,
excludeDomains: []
}
const result = buildProviderBuiltinWebSearchConfig('openai', config)
expect(result).toEqual({
openai: {
searchContextSize: 'medium'
}
})
})
it('should return high search context size for high maxResults', () => {
const config: CherryWebSearchConfig = {
searchWithTime: true,
maxResults: 80,
excludeDomains: []
}
const result = buildProviderBuiltinWebSearchConfig('openai', config)
expect(result).toEqual({
openai: {
searchContextSize: 'high'
}
})
})
it('should use medium for deep research models regardless of maxResults', () => {
const config: CherryWebSearchConfig = {
searchWithTime: true,
maxResults: 100,
excludeDomains: []
}
const model: Model = {
id: 'o3-mini',
name: 'O3 Mini',
provider: 'openai'
} as Model
const result = buildProviderBuiltinWebSearchConfig('openai', config, model)
expect(result).toEqual({
openai: {
searchContextSize: 'medium'
}
})
})
})
describe('openai-chat provider', () => {
it('should return correct search context size', () => {
const config: CherryWebSearchConfig = {
searchWithTime: true,
maxResults: 50,
excludeDomains: []
}
const result = buildProviderBuiltinWebSearchConfig('openai-chat', config)
expect(result).toEqual({
'openai-chat': {
searchContextSize: 'medium'
}
})
})
it('should handle deep research models', () => {
const config: CherryWebSearchConfig = {
searchWithTime: true,
maxResults: 100,
excludeDomains: []
}
const model: Model = {
id: 'o3-mini',
name: 'O3 Mini',
provider: 'openai'
} as Model
const result = buildProviderBuiltinWebSearchConfig('openai-chat', config, model)
expect(result).toEqual({
'openai-chat': {
searchContextSize: 'medium'
}
})
})
})
describe('anthropic provider', () => {
it('should return anthropic search options with maxUses', () => {
const result = buildProviderBuiltinWebSearchConfig('anthropic', defaultWebSearchConfig)
expect(result).toEqual({
anthropic: {
maxUses: 50,
blockedDomains: undefined
}
})
})
it('should include blockedDomains when excludeDomains provided', () => {
const config: CherryWebSearchConfig = {
searchWithTime: true,
maxResults: 30,
excludeDomains: ['example.com', 'test.com']
}
const result = buildProviderBuiltinWebSearchConfig('anthropic', config)
expect(result).toEqual({
anthropic: {
maxUses: 30,
blockedDomains: ['example.com', 'test.com']
}
})
})
it('should not include blockedDomains when empty', () => {
const result = buildProviderBuiltinWebSearchConfig('anthropic', defaultWebSearchConfig)
expect(result).toEqual({
anthropic: {
maxUses: 50,
blockedDomains: undefined
}
})
})
})
describe('xai provider', () => {
it('should return xai search options', () => {
const result = buildProviderBuiltinWebSearchConfig('xai', defaultWebSearchConfig)
expect(result).toEqual({
xai: {
maxSearchResults: 50,
returnCitations: true,
sources: [{ type: 'web', excludedWebsites: [] }, { type: 'news' }, { type: 'x' }],
mode: 'on'
}
})
})
it('should limit excluded websites to 5', () => {
const config: CherryWebSearchConfig = {
searchWithTime: true,
maxResults: 40,
excludeDomains: ['site1.com', 'site2.com', 'site3.com', 'site4.com', 'site5.com', 'site6.com', 'site7.com']
}
const result = buildProviderBuiltinWebSearchConfig('xai', config)
expect(result?.xai?.sources).toBeDefined()
const webSource = result?.xai?.sources?.[0]
if (webSource && webSource.type === 'web') {
expect(webSource.excludedWebsites).toHaveLength(5)
}
})
it('should include all sources types', () => {
const result = buildProviderBuiltinWebSearchConfig('xai', defaultWebSearchConfig)
expect(result?.xai?.sources).toHaveLength(3)
expect(result?.xai?.sources?.[0].type).toBe('web')
expect(result?.xai?.sources?.[1].type).toBe('news')
expect(result?.xai?.sources?.[2].type).toBe('x')
})
})
describe('openrouter provider', () => {
it('should return openrouter plugins config', () => {
const result = buildProviderBuiltinWebSearchConfig('openrouter', defaultWebSearchConfig)
expect(result).toEqual({
openrouter: {
plugins: [
{
id: 'web',
max_results: 50
}
]
}
})
})
it('should respect custom maxResults', () => {
const config: CherryWebSearchConfig = {
searchWithTime: true,
maxResults: 75,
excludeDomains: []
}
const result = buildProviderBuiltinWebSearchConfig('openrouter', config)
expect(result).toEqual({
openrouter: {
plugins: [
{
id: 'web',
max_results: 75
}
]
}
})
})
})
describe('unsupported provider', () => {
it('should return empty object for unsupported provider', () => {
const result = buildProviderBuiltinWebSearchConfig('unsupported' as any, defaultWebSearchConfig)
expect(result).toEqual({})
})
it('should return empty object for google provider', () => {
const result = buildProviderBuiltinWebSearchConfig('google', defaultWebSearchConfig)
expect(result).toEqual({})
})
})
describe('edge cases', () => {
it('should handle maxResults at boundary values', () => {
// Test boundary at 33 (low/medium)
const config33: CherryWebSearchConfig = { searchWithTime: true, maxResults: 33, excludeDomains: [] }
const result33 = buildProviderBuiltinWebSearchConfig('openai', config33)
expect(result33?.openai?.searchContextSize).toBe('low')
// Test boundary at 34 (medium)
const config34: CherryWebSearchConfig = { searchWithTime: true, maxResults: 34, excludeDomains: [] }
const result34 = buildProviderBuiltinWebSearchConfig('openai', config34)
expect(result34?.openai?.searchContextSize).toBe('medium')
// Test boundary at 66 (medium)
const config66: CherryWebSearchConfig = { searchWithTime: true, maxResults: 66, excludeDomains: [] }
const result66 = buildProviderBuiltinWebSearchConfig('openai', config66)
expect(result66?.openai?.searchContextSize).toBe('medium')
// Test boundary at 67 (high)
const config67: CherryWebSearchConfig = { searchWithTime: true, maxResults: 67, excludeDomains: [] }
const result67 = buildProviderBuiltinWebSearchConfig('openai', config67)
expect(result67?.openai?.searchContextSize).toBe('high')
})
it('should handle zero maxResults', () => {
const config: CherryWebSearchConfig = { searchWithTime: true, maxResults: 0, excludeDomains: [] }
const result = buildProviderBuiltinWebSearchConfig('openai', config)
expect(result?.openai?.searchContextSize).toBe('low')
})
it('should handle very large maxResults', () => {
const config: CherryWebSearchConfig = { searchWithTime: true, maxResults: 1000, excludeDomains: [] }
const result = buildProviderBuiltinWebSearchConfig('openai', config)
expect(result?.openai?.searchContextSize).toBe('high')
})
})
})
})

View File

@ -1,3 +1,8 @@
import type { BedrockProviderOptions } from '@ai-sdk/amazon-bedrock'
import type { AnthropicProviderOptions } from '@ai-sdk/anthropic'
import type { GoogleGenerativeAIProviderOptions } from '@ai-sdk/google'
import type { OpenAIResponsesProviderOptions } from '@ai-sdk/openai'
import type { XaiProviderOptions } from '@ai-sdk/xai'
import { baseProviderIdSchema, customProviderIdSchema } from '@cherrystudio/ai-core/provider'
import { loggerService } from '@logger'
import {
@ -7,17 +12,27 @@ import {
isSupportFlexServiceTierModel,
isSupportVerbosityModel
} from '@renderer/config/models'
import { isSupportServiceTierProvider } from '@renderer/config/providers'
import { mapLanguageToQwenMTModel } from '@renderer/config/translate'
import type { Assistant, Model, Provider } from '@renderer/types'
import { getStoreSetting } from '@renderer/hooks/useSettings'
import {
type Assistant,
type GroqServiceTier,
GroqServiceTiers,
type GroqSystemProvider,
isGroqServiceTier,
isGroqSystemProvider,
isOpenAIServiceTier,
isTranslateAssistant,
type Model,
type NotGroqProvider,
type OpenAIServiceTier,
OpenAIServiceTiers,
SystemProviderIds
type Provider,
type ServiceTier
} from '@renderer/types'
import type { OpenAIVerbosity } from '@renderer/types/aiCoreTypes'
import { isSupportServiceTierProvider } from '@renderer/utils/provider'
import type { JSONValue } from 'ai'
import { t } from 'i18next'
import { getAiSdkProviderId } from '../provider/factory'
@ -35,8 +50,31 @@ import { getWebSearchParams } from './websearch'
const logger = loggerService.withContext('aiCore.utils.options')
// copy from BaseApiClient.ts
const getServiceTier = (model: Model, provider: Provider) => {
function toOpenAIServiceTier(model: Model, serviceTier: ServiceTier): OpenAIServiceTier {
if (
!isOpenAIServiceTier(serviceTier) ||
(serviceTier === OpenAIServiceTiers.flex && !isSupportFlexServiceTierModel(model))
) {
return undefined
} else {
return serviceTier
}
}
function toGroqServiceTier(model: Model, serviceTier: ServiceTier): GroqServiceTier {
if (
!isGroqServiceTier(serviceTier) ||
(serviceTier === GroqServiceTiers.flex && !isSupportFlexServiceTierModel(model))
) {
return undefined
} else {
return serviceTier
}
}
function getServiceTier<T extends GroqSystemProvider>(model: Model, provider: T): GroqServiceTier
function getServiceTier<T extends NotGroqProvider>(model: Model, provider: T): OpenAIServiceTier
function getServiceTier<T extends Provider>(model: Model, provider: T): OpenAIServiceTier | GroqServiceTier {
const serviceTierSetting = provider.serviceTier
if (!isSupportServiceTierProvider(provider) || !isOpenAIModel(model) || !serviceTierSetting) {
@ -44,24 +82,17 @@ const getServiceTier = (model: Model, provider: Provider) => {
}
// 处理不同供应商需要 fallback 到默认值的情况
if (provider.id === SystemProviderIds.groq) {
if (
!isGroqServiceTier(serviceTierSetting) ||
(serviceTierSetting === GroqServiceTiers.flex && !isSupportFlexServiceTierModel(model))
) {
return undefined
}
if (isGroqSystemProvider(provider)) {
return toGroqServiceTier(model, serviceTierSetting)
} else {
// 其他 OpenAI 供应商,假设他们的服务层级设置和 OpenAI 完全相同
if (
!isOpenAIServiceTier(serviceTierSetting) ||
(serviceTierSetting === OpenAIServiceTiers.flex && !isSupportFlexServiceTierModel(model))
) {
return undefined
}
return toOpenAIServiceTier(model, serviceTierSetting)
}
}
return serviceTierSetting
function getVerbosity(): OpenAIVerbosity {
const openAI = getStoreSetting('openAI')
return openAI.verbosity
}
/**
@ -78,13 +109,13 @@ export function buildProviderOptions(
enableWebSearch: boolean
enableGenerateImage: boolean
}
): Record<string, any> {
): Record<string, Record<string, JSONValue>> {
logger.debug('buildProviderOptions', { assistant, model, actualProvider, capabilities })
const rawProviderId = getAiSdkProviderId(actualProvider)
// 构建 provider 特定的选项
let providerSpecificOptions: Record<string, any> = {}
const serviceTierSetting = getServiceTier(model, actualProvider)
providerSpecificOptions.serviceTier = serviceTierSetting
const serviceTier = getServiceTier(model, actualProvider)
const textVerbosity = getVerbosity()
// 根据 provider 类型分离构建逻辑
const { data: baseProviderId, success } = baseProviderIdSchema.safeParse(rawProviderId)
if (success) {
@ -94,9 +125,14 @@ export function buildProviderOptions(
case 'openai-chat':
case 'azure':
case 'azure-responses':
providerSpecificOptions = {
...buildOpenAIProviderOptions(assistant, model, capabilities),
serviceTier: serviceTierSetting
{
const options: OpenAIResponsesProviderOptions = buildOpenAIProviderOptions(
assistant,
model,
capabilities,
serviceTier
)
providerSpecificOptions = options
}
break
case 'anthropic':
@ -116,12 +152,19 @@ export function buildProviderOptions(
// 对于其他 provider使用通用的构建逻辑
providerSpecificOptions = {
...buildGenericProviderOptions(assistant, model, capabilities),
serviceTier: serviceTierSetting
serviceTier,
textVerbosity
}
break
}
case 'cherryin':
providerSpecificOptions = buildCherryInProviderOptions(assistant, model, capabilities, actualProvider)
providerSpecificOptions = buildCherryInProviderOptions(
assistant,
model,
capabilities,
actualProvider,
serviceTier
)
break
default:
throw new Error(`Unsupported base provider ${baseProviderId}`)
@ -135,6 +178,7 @@ export function buildProviderOptions(
case 'google-vertex':
providerSpecificOptions = buildGeminiProviderOptions(assistant, model, capabilities)
break
case 'azure-anthropic':
case 'google-vertex-anthropic':
providerSpecificOptions = buildAnthropicProviderOptions(assistant, model, capabilities)
break
@ -142,13 +186,14 @@ export function buildProviderOptions(
providerSpecificOptions = buildBedrockProviderOptions(assistant, model, capabilities)
break
case 'huggingface':
providerSpecificOptions = buildOpenAIProviderOptions(assistant, model, capabilities)
providerSpecificOptions = buildOpenAIProviderOptions(assistant, model, capabilities, serviceTier)
break
default:
// 对于其他 provider使用通用的构建逻辑
providerSpecificOptions = {
...buildGenericProviderOptions(assistant, model, capabilities),
serviceTier: serviceTierSetting
serviceTier,
textVerbosity
}
}
} else {
@ -166,6 +211,7 @@ export function buildProviderOptions(
{
'google-vertex': 'google',
'google-vertex-anthropic': 'anthropic',
'azure-anthropic': 'anthropic',
'ai-gateway': 'gateway'
}[rawProviderId] || rawProviderId
@ -189,10 +235,11 @@ function buildOpenAIProviderOptions(
enableReasoning: boolean
enableWebSearch: boolean
enableGenerateImage: boolean
}
): Record<string, any> {
},
serviceTier: OpenAIServiceTier
): OpenAIResponsesProviderOptions {
const { enableReasoning } = capabilities
let providerOptions: Record<string, any> = {}
let providerOptions: OpenAIResponsesProviderOptions = {}
// OpenAI 推理参数
if (enableReasoning) {
const reasoningParams = getOpenAIReasoningParams(assistant, model)
@ -203,8 +250,8 @@ function buildOpenAIProviderOptions(
}
if (isSupportVerbosityModel(model)) {
const state = window.store?.getState()
const userVerbosity = state?.settings?.openAI?.verbosity
const openAI = getStoreSetting<'openAI'>('openAI')
const userVerbosity = openAI?.verbosity
if (userVerbosity && ['low', 'medium', 'high'].includes(userVerbosity)) {
const supportedVerbosity = getModelSupportedVerbosity(model)
@ -218,6 +265,11 @@ function buildOpenAIProviderOptions(
}
}
providerOptions = {
...providerOptions,
serviceTier
}
return providerOptions
}
@ -232,9 +284,9 @@ function buildAnthropicProviderOptions(
enableWebSearch: boolean
enableGenerateImage: boolean
}
): Record<string, any> {
): AnthropicProviderOptions {
const { enableReasoning } = capabilities
let providerOptions: Record<string, any> = {}
let providerOptions: AnthropicProviderOptions = {}
// Anthropic 推理参数
if (enableReasoning) {
@ -259,9 +311,9 @@ function buildGeminiProviderOptions(
enableWebSearch: boolean
enableGenerateImage: boolean
}
): Record<string, any> {
): GoogleGenerativeAIProviderOptions {
const { enableReasoning, enableGenerateImage } = capabilities
let providerOptions: Record<string, any> = {}
let providerOptions: GoogleGenerativeAIProviderOptions = {}
// Gemini 推理参数
if (enableReasoning) {
@ -290,7 +342,7 @@ function buildXAIProviderOptions(
enableWebSearch: boolean
enableGenerateImage: boolean
}
): Record<string, any> {
): XaiProviderOptions {
const { enableReasoning } = capabilities
let providerOptions: Record<string, any> = {}
@ -313,16 +365,12 @@ function buildCherryInProviderOptions(
enableWebSearch: boolean
enableGenerateImage: boolean
},
actualProvider: Provider
): Record<string, any> {
const serviceTierSetting = getServiceTier(model, actualProvider)
actualProvider: Provider,
serviceTier: OpenAIServiceTier
): OpenAIResponsesProviderOptions | AnthropicProviderOptions | GoogleGenerativeAIProviderOptions {
switch (actualProvider.type) {
case 'openai':
return {
...buildOpenAIProviderOptions(assistant, model, capabilities),
serviceTier: serviceTierSetting
}
return buildOpenAIProviderOptions(assistant, model, capabilities, serviceTier)
case 'anthropic':
return buildAnthropicProviderOptions(assistant, model, capabilities)
@ -344,9 +392,9 @@ function buildBedrockProviderOptions(
enableWebSearch: boolean
enableGenerateImage: boolean
}
): Record<string, any> {
): BedrockProviderOptions {
const { enableReasoning } = capabilities
let providerOptions: Record<string, any> = {}
let providerOptions: BedrockProviderOptions = {}
if (enableReasoning) {
const reasoningParams = getBedrockReasoningParams(assistant, model)

View File

@ -1,6 +1,7 @@
import type { BedrockProviderOptions } from '@ai-sdk/amazon-bedrock'
import type { AnthropicProviderOptions } from '@ai-sdk/anthropic'
import type { GoogleGenerativeAIProviderOptions } from '@ai-sdk/google'
import type { OpenAIResponsesProviderOptions } from '@ai-sdk/openai'
import type { XaiProviderOptions } from '@ai-sdk/xai'
import { loggerService } from '@logger'
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
@ -11,6 +12,7 @@ import {
isDeepSeekHybridInferenceModel,
isDoubaoSeedAfter251015,
isDoubaoThinkingAutoModel,
isGemini3Model,
isGPT51SeriesModel,
isGrok4FastReasoningModel,
isGrokReasoningModel,
@ -32,13 +34,13 @@ import {
isSupportedThinkingTokenZhipuModel,
MODEL_SUPPORTED_REASONING_EFFORT
} from '@renderer/config/models'
import { isSupportEnableThinkingProvider } from '@renderer/config/providers'
import { getStoreSetting } from '@renderer/hooks/useSettings'
import { getAssistantSettings, getProviderByModel } from '@renderer/services/AssistantService'
import type { SettingsState } from '@renderer/store/settings'
import type { Assistant, Model } from '@renderer/types'
import type { Assistant, Model, ReasoningEffortOption } from '@renderer/types'
import { EFFORT_RATIO, isSystemProvider, SystemProviderIds } from '@renderer/types'
import type { OpenAISummaryText } from '@renderer/types/aiCoreTypes'
import type { ReasoningEffortOptionalParams } from '@renderer/types/sdk'
import { isSupportEnableThinkingProvider } from '@renderer/utils/provider'
import { toInteger } from 'lodash'
const logger = loggerService.withContext('reasoning')
@ -130,7 +132,7 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin
}
// Specially for GPT-5.1. Suppose this is a OpenAI Compatible provider
if (isGPT51SeriesModel(model) && reasoningEffort === 'none') {
if (isGPT51SeriesModel(model)) {
return {
reasoningEffort: 'none'
}
@ -278,6 +280,12 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin
// gemini series, openai compatible api
if (isSupportedThinkingTokenGeminiModel(model)) {
// https://ai.google.dev/gemini-api/docs/gemini-3?thinking=high#openai_compatibility
if (isGemini3Model(model)) {
return {
reasoning_effort: reasoningEffort
}
}
if (reasoningEffort === 'auto') {
return {
extra_body: {
@ -341,10 +349,14 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin
}
/**
* OpenAI
* OpenAIResponseAPIClient OpenAIAPIClient
* Get OpenAI reasoning parameters
* Extracted from OpenAIResponseAPIClient and OpenAIAPIClient logic
* For official OpenAI provider only
*/
export function getOpenAIReasoningParams(assistant: Assistant, model: Model): Record<string, any> {
export function getOpenAIReasoningParams(
assistant: Assistant,
model: Model
): Pick<OpenAIResponsesProviderOptions, 'reasoningEffort' | 'reasoningSummary'> {
if (!isReasoningModel(model)) {
return {}
}
@ -355,6 +367,10 @@ export function getOpenAIReasoningParams(assistant: Assistant, model: Model): Re
return {}
}
if (isOpenAIDeepResearchModel(model) || reasoningEffort === 'auto') {
reasoningEffort = 'medium'
}
// 非OpenAI模型但是Provider类型是responses/azure openai的情况
if (!isOpenAIModel(model)) {
return {
@ -362,21 +378,17 @@ export function getOpenAIReasoningParams(assistant: Assistant, model: Model): Re
}
}
const openAI = getStoreSetting('openAI') as SettingsState['openAI']
const summaryText = openAI?.summaryText || 'off'
const openAI = getStoreSetting('openAI')
const summaryText = openAI.summaryText
let reasoningSummary: string | undefined = undefined
let reasoningSummary: OpenAISummaryText = undefined
if (summaryText === 'off' || model.id.includes('o1-pro')) {
if (model.id.includes('o1-pro')) {
reasoningSummary = undefined
} else {
reasoningSummary = summaryText
}
if (isOpenAIDeepResearchModel(model)) {
reasoningEffort = 'medium'
}
// OpenAI 推理参数
if (isSupportedReasoningEffortOpenAIModel(model)) {
return {
@ -388,19 +400,26 @@ export function getOpenAIReasoningParams(assistant: Assistant, model: Model): Re
return {}
}
export function getAnthropicThinkingBudget(assistant: Assistant, model: Model): number {
const { maxTokens, reasoning_effort: reasoningEffort } = getAssistantSettings(assistant)
export function getAnthropicThinkingBudget(
maxTokens: number | undefined,
reasoningEffort: string | undefined,
modelId: string
): number | undefined {
if (reasoningEffort === undefined || reasoningEffort === 'none') {
return 0
return undefined
}
const effortRatio = EFFORT_RATIO[reasoningEffort]
const tokenLimit = findTokenLimit(modelId)
if (!tokenLimit) {
return undefined
}
const budgetTokens = Math.max(
1024,
Math.floor(
Math.min(
(findTokenLimit(model.id)?.max! - findTokenLimit(model.id)?.min!) * effortRatio +
findTokenLimit(model.id)?.min!,
(tokenLimit.max - tokenLimit.min) * effortRatio + tokenLimit.min,
(maxTokens || DEFAULT_MAX_TOKENS) * effortRatio
)
)
@ -432,7 +451,8 @@ export function getAnthropicReasoningParams(
// Claude 推理参数
if (isSupportedThinkingTokenClaudeModel(model)) {
const budgetTokens = getAnthropicThinkingBudget(assistant, model)
const { maxTokens } = getAssistantSettings(assistant)
const budgetTokens = getAnthropicThinkingBudget(maxTokens, reasoningEffort, model.id)
return {
thinking: {
@ -445,6 +465,21 @@ export function getAnthropicReasoningParams(
return {}
}
type GoogelThinkingLevel = NonNullable<GoogleGenerativeAIProviderOptions['thinkingConfig']>['thinkingLevel']
function mapToGeminiThinkingLevel(reasoningEffort: ReasoningEffortOption): GoogelThinkingLevel {
switch (reasoningEffort) {
case 'low':
return 'low'
case 'medium':
return 'medium'
case 'high':
return 'high'
default:
return 'medium'
}
}
/**
* Gemini
* GeminiAPIClient
@ -472,6 +507,15 @@ export function getGeminiReasoningParams(
}
}
// https://ai.google.dev/gemini-api/docs/gemini-3?thinking=high#new_api_features_in_gemini_3
if (isGemini3Model(model)) {
return {
thinkingConfig: {
thinkingLevel: mapToGeminiThinkingLevel(reasoningEffort)
}
}
}
const effortRatio = EFFORT_RATIO[reasoningEffort]
if (effortRatio > 1) {
@ -555,7 +599,8 @@ export function getBedrockReasoningParams(
return {}
}
const budgetTokens = getAnthropicThinkingBudget(assistant, model)
const { maxTokens } = getAssistantSettings(assistant)
const budgetTokens = getAnthropicThinkingBudget(maxTokens, reasoningEffort, model.id)
return {
reasoningConfig: {
type: 'enabled',

View File

@ -47,6 +47,7 @@ export function buildProviderBuiltinWebSearchConfig(
model?: Model
): WebSearchPluginConfig | undefined {
switch (providerId) {
case 'azure-responses':
case 'openai': {
const searchContextSize = isOpenAIDeepResearchModel(model)
? 'medium'

View File

@ -6,7 +6,7 @@ import { useEffect, useMemo, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import styled, { css } from 'styled-components'
interface SelectorOption<V = string | number> {
interface SelectorOption<V = string | number | undefined | null> {
label: string | ReactNode
value: V
type?: 'group'
@ -14,7 +14,7 @@ interface SelectorOption<V = string | number> {
disabled?: boolean
}
interface BaseSelectorProps<V = string | number> {
interface BaseSelectorProps<V = string | number | undefined | null> {
options: SelectorOption<V>[]
placeholder?: string
placement?: 'topLeft' | 'topCenter' | 'topRight' | 'bottomLeft' | 'bottomCenter' | 'bottomRight' | 'top' | 'bottom'
@ -39,7 +39,7 @@ interface MultipleSelectorProps<V> extends BaseSelectorProps<V> {
export type SelectorProps<V> = SingleSelectorProps<V> | MultipleSelectorProps<V>
const Selector = <V extends string | number>({
const Selector = <V extends string | number | undefined | null>({
options,
value,
onChange = () => {},

View File

@ -1,520 +0,0 @@
import { describe, expect, it, vi } from 'vitest'
import {
isDoubaoSeedAfter251015,
isDoubaoThinkingAutoModel,
isGeminiReasoningModel,
isLingReasoningModel,
isSupportedThinkingTokenGeminiModel
} from '../models/reasoning'
vi.mock('@renderer/store', () => ({
default: {
getState: () => ({
llm: {
settings: {}
}
})
}
}))
// FIXME: Idk why it's imported. Maybe circular dependency somewhere
vi.mock('@renderer/services/AssistantService.ts', () => ({
getDefaultAssistant: () => {
return {
id: 'default',
name: 'default',
emoji: '😀',
prompt: '',
topics: [],
messages: [],
type: 'assistant',
regularPhrases: [],
settings: {}
}
}
}))
describe('Doubao Models', () => {
describe('isDoubaoThinkingAutoModel', () => {
it('should return false for invalid models', () => {
expect(
isDoubaoThinkingAutoModel({
id: 'doubao-seed-1-6-251015',
name: 'doubao-seed-1-6-251015',
provider: '',
group: ''
})
).toBe(false)
expect(
isDoubaoThinkingAutoModel({
id: 'doubao-seed-1-6-lite-251015',
name: 'doubao-seed-1-6-lite-251015',
provider: '',
group: ''
})
).toBe(false)
expect(
isDoubaoThinkingAutoModel({
id: 'doubao-seed-1-6-thinking-250715',
name: 'doubao-seed-1-6-thinking-250715',
provider: '',
group: ''
})
).toBe(false)
expect(
isDoubaoThinkingAutoModel({
id: 'doubao-seed-1-6-flash',
name: 'doubao-seed-1-6-flash',
provider: '',
group: ''
})
).toBe(false)
expect(
isDoubaoThinkingAutoModel({
id: 'doubao-seed-1-6-thinking',
name: 'doubao-seed-1-6-thinking',
provider: '',
group: ''
})
).toBe(false)
})
it('should return true for valid models', () => {
expect(
isDoubaoThinkingAutoModel({
id: 'doubao-seed-1-6-250615',
name: 'doubao-seed-1-6-250615',
provider: '',
group: ''
})
).toBe(true)
expect(
isDoubaoThinkingAutoModel({
id: 'Doubao-Seed-1.6',
name: 'Doubao-Seed-1.6',
provider: '',
group: ''
})
).toBe(true)
expect(
isDoubaoThinkingAutoModel({
id: 'doubao-1-5-thinking-pro-m',
name: 'doubao-1-5-thinking-pro-m',
provider: '',
group: ''
})
).toBe(true)
expect(
isDoubaoThinkingAutoModel({
id: 'doubao-seed-1.6-lite',
name: 'doubao-seed-1.6-lite',
provider: '',
group: ''
})
).toBe(true)
expect(
isDoubaoThinkingAutoModel({
id: 'doubao-1-5-thinking-pro-m-12345',
name: 'doubao-1-5-thinking-pro-m-12345',
provider: '',
group: ''
})
).toBe(true)
})
})
describe('isDoubaoSeedAfter251015', () => {
it('should return true for models matching the pattern', () => {
expect(
isDoubaoSeedAfter251015({
id: 'doubao-seed-1-6-251015',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isDoubaoSeedAfter251015({
id: 'doubao-seed-1-6-lite-251015',
name: '',
provider: '',
group: ''
})
).toBe(true)
})
it('should return false for models not matching the pattern', () => {
expect(
isDoubaoSeedAfter251015({
id: 'doubao-seed-1-6-250615',
name: '',
provider: '',
group: ''
})
).toBe(false)
expect(
isDoubaoSeedAfter251015({
id: 'Doubao-Seed-1.6',
name: '',
provider: '',
group: ''
})
).toBe(false)
expect(
isDoubaoSeedAfter251015({
id: 'doubao-1-5-thinking-pro-m',
name: '',
provider: '',
group: ''
})
).toBe(false)
expect(
isDoubaoSeedAfter251015({
id: 'doubao-seed-1-6-lite-251016',
name: '',
provider: '',
group: ''
})
).toBe(false)
})
})
})
describe('Ling Models', () => {
describe('isLingReasoningModel', () => {
it('should return false for ling variants', () => {
expect(
isLingReasoningModel({
id: 'ling-1t',
name: '',
provider: '',
group: ''
})
).toBe(false)
expect(
isLingReasoningModel({
id: 'ling-flash-2.0',
name: '',
provider: '',
group: ''
})
).toBe(false)
expect(
isLingReasoningModel({
id: 'ling-mini-2.0',
name: '',
provider: '',
group: ''
})
).toBe(false)
})
it('should return true for ring variants', () => {
expect(
isLingReasoningModel({
id: 'ring-1t',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isLingReasoningModel({
id: 'ring-flash-2.0',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isLingReasoningModel({
id: 'ring-mini-2.0',
name: '',
provider: '',
group: ''
})
).toBe(true)
})
})
})
describe('Gemini Models', () => {
describe('isSupportedThinkingTokenGeminiModel', () => {
it('should return true for gemini 2.5 models', () => {
expect(
isSupportedThinkingTokenGeminiModel({
id: 'gemini-2.5-flash',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isSupportedThinkingTokenGeminiModel({
id: 'gemini-2.5-pro',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isSupportedThinkingTokenGeminiModel({
id: 'gemini-2.5-flash-latest',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isSupportedThinkingTokenGeminiModel({
id: 'gemini-2.5-pro-latest',
name: '',
provider: '',
group: ''
})
).toBe(true)
})
it('should return true for gemini latest models', () => {
expect(
isSupportedThinkingTokenGeminiModel({
id: 'gemini-flash-latest',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isSupportedThinkingTokenGeminiModel({
id: 'gemini-pro-latest',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isSupportedThinkingTokenGeminiModel({
id: 'gemini-flash-lite-latest',
name: '',
provider: '',
group: ''
})
).toBe(true)
})
it('should return true for gemini 3 models', () => {
// Preview versions
expect(
isSupportedThinkingTokenGeminiModel({
id: 'gemini-3-pro-preview',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isSupportedThinkingTokenGeminiModel({
id: 'google/gemini-3-pro-preview',
name: '',
provider: '',
group: ''
})
).toBe(true)
// Future stable versions
expect(
isSupportedThinkingTokenGeminiModel({
id: 'gemini-3-flash',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isSupportedThinkingTokenGeminiModel({
id: 'gemini-3-pro',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isSupportedThinkingTokenGeminiModel({
id: 'google/gemini-3-flash',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isSupportedThinkingTokenGeminiModel({
id: 'google/gemini-3-pro',
name: '',
provider: '',
group: ''
})
).toBe(true)
})
it('should return false for image and tts models', () => {
expect(
isSupportedThinkingTokenGeminiModel({
id: 'gemini-2.5-flash-image',
name: '',
provider: '',
group: ''
})
).toBe(false)
expect(
isSupportedThinkingTokenGeminiModel({
id: 'gemini-2.5-flash-preview-tts',
name: '',
provider: '',
group: ''
})
).toBe(false)
})
it('should return false for older gemini models', () => {
expect(
isSupportedThinkingTokenGeminiModel({
id: 'gemini-1.5-flash',
name: '',
provider: '',
group: ''
})
).toBe(false)
expect(
isSupportedThinkingTokenGeminiModel({
id: 'gemini-1.5-pro',
name: '',
provider: '',
group: ''
})
).toBe(false)
expect(
isSupportedThinkingTokenGeminiModel({
id: 'gemini-1.0-pro',
name: '',
provider: '',
group: ''
})
).toBe(false)
})
})
describe('isGeminiReasoningModel', () => {
it('should return true for gemini thinking models', () => {
expect(
isGeminiReasoningModel({
id: 'gemini-2.0-flash-thinking',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isGeminiReasoningModel({
id: 'gemini-thinking-exp',
name: '',
provider: '',
group: ''
})
).toBe(true)
})
it('should return true for supported thinking token gemini models', () => {
expect(
isGeminiReasoningModel({
id: 'gemini-2.5-flash',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isGeminiReasoningModel({
id: 'gemini-2.5-pro',
name: '',
provider: '',
group: ''
})
).toBe(true)
})
it('should return true for gemini-3 models', () => {
// Preview versions
expect(
isGeminiReasoningModel({
id: 'gemini-3-pro-preview',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isGeminiReasoningModel({
id: 'google/gemini-3-pro-preview',
name: '',
provider: '',
group: ''
})
).toBe(true)
// Future stable versions
expect(
isGeminiReasoningModel({
id: 'gemini-3-flash',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isGeminiReasoningModel({
id: 'gemini-3-pro',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isGeminiReasoningModel({
id: 'google/gemini-3-flash',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isGeminiReasoningModel({
id: 'google/gemini-3-pro',
name: '',
provider: '',
group: ''
})
).toBe(true)
})
it('should return false for older gemini models without thinking', () => {
expect(
isGeminiReasoningModel({
id: 'gemini-1.5-flash',
name: '',
provider: '',
group: ''
})
).toBe(false)
expect(
isGeminiReasoningModel({
id: 'gemini-1.5-pro',
name: '',
provider: '',
group: ''
})
).toBe(false)
})
it('should return false for undefined model', () => {
expect(isGeminiReasoningModel(undefined)).toBe(false)
})
})
})

View File

@ -1,167 +0,0 @@
import { describe, expect, it, vi } from 'vitest'
import { isVisionModel } from '../models/vision'
vi.mock('@renderer/store', () => ({
default: {
getState: () => ({
llm: {
settings: {}
}
})
}
}))
// FIXME: Idk why it's imported. Maybe circular dependency somewhere
vi.mock('@renderer/services/AssistantService.ts', () => ({
getDefaultAssistant: () => {
return {
id: 'default',
name: 'default',
emoji: '😀',
prompt: '',
topics: [],
messages: [],
type: 'assistant',
regularPhrases: [],
settings: {}
}
},
getProviderByModel: () => null
}))
describe('isVisionModel', () => {
describe('Gemini Models', () => {
it('should return true for gemini 1.5 models', () => {
expect(
isVisionModel({
id: 'gemini-1.5-flash',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isVisionModel({
id: 'gemini-1.5-pro',
name: '',
provider: '',
group: ''
})
).toBe(true)
})
it('should return true for gemini 2.x models', () => {
expect(
isVisionModel({
id: 'gemini-2.0-flash',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isVisionModel({
id: 'gemini-2.0-pro',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isVisionModel({
id: 'gemini-2.5-flash',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isVisionModel({
id: 'gemini-2.5-pro',
name: '',
provider: '',
group: ''
})
).toBe(true)
})
it('should return true for gemini latest models', () => {
expect(
isVisionModel({
id: 'gemini-flash-latest',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isVisionModel({
id: 'gemini-pro-latest',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isVisionModel({
id: 'gemini-flash-lite-latest',
name: '',
provider: '',
group: ''
})
).toBe(true)
})
it('should return true for gemini 3 models', () => {
// Preview versions
expect(
isVisionModel({
id: 'gemini-3-pro-preview',
name: '',
provider: '',
group: ''
})
).toBe(true)
// Future stable versions
expect(
isVisionModel({
id: 'gemini-3-flash',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isVisionModel({
id: 'gemini-3-pro',
name: '',
provider: '',
group: ''
})
).toBe(true)
})
it('should return true for gemini exp models', () => {
expect(
isVisionModel({
id: 'gemini-exp-1206',
name: '',
provider: '',
group: ''
})
).toBe(true)
})
it('should return false for gemini 1.0 models', () => {
expect(
isVisionModel({
id: 'gemini-1.0-pro',
name: '',
provider: '',
group: ''
})
).toBe(false)
})
})
})

View File

@ -1,64 +0,0 @@
import { describe, expect, it, vi } from 'vitest'
import { GEMINI_SEARCH_REGEX } from '../models/websearch'
vi.mock('@renderer/store', () => ({
default: {
getState: () => ({
llm: {
settings: {}
}
})
}
}))
// FIXME: Idk why it's imported. Maybe circular dependency somewhere
vi.mock('@renderer/services/AssistantService.ts', () => ({
getDefaultAssistant: () => {
return {
id: 'default',
name: 'default',
emoji: '😀',
prompt: '',
topics: [],
messages: [],
type: 'assistant',
regularPhrases: [],
settings: {}
}
},
getProviderByModel: () => null
}))
describe('Gemini Search Models', () => {
describe('GEMINI_SEARCH_REGEX', () => {
it('should match gemini 2.x models', () => {
expect(GEMINI_SEARCH_REGEX.test('gemini-2.0-flash')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-2.0-pro')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-2.5-flash')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-2.5-pro')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-2.5-flash-latest')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-2.5-pro-latest')).toBe(true)
})
it('should match gemini latest models', () => {
expect(GEMINI_SEARCH_REGEX.test('gemini-flash-latest')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-pro-latest')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-flash-lite-latest')).toBe(true)
})
it('should match gemini 3 models', () => {
// Preview versions
expect(GEMINI_SEARCH_REGEX.test('gemini-3-pro-preview')).toBe(true)
// Future stable versions
expect(GEMINI_SEARCH_REGEX.test('gemini-3-flash')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-3-pro')).toBe(true)
})
it('should not match older gemini models', () => {
expect(GEMINI_SEARCH_REGEX.test('gemini-1.5-flash')).toBe(false)
expect(GEMINI_SEARCH_REGEX.test('gemini-1.5-pro')).toBe(false)
expect(GEMINI_SEARCH_REGEX.test('gemini-1.0-pro')).toBe(false)
})
})
})

View File

@ -0,0 +1,101 @@
import type { Model } from '@renderer/types'
import { describe, expect, it, vi } from 'vitest'
vi.mock('@renderer/hooks/useStore', () => ({
getStoreProviders: vi.fn(() => [])
}))
vi.mock('@renderer/store', () => ({
__esModule: true,
default: {
getState: () => ({
llm: { providers: [] },
settings: {}
})
},
useAppDispatch: vi.fn(),
useAppSelector: vi.fn()
}))
vi.mock('@renderer/store/settings', () => {
const noop = vi.fn()
return new Proxy(
{},
{
get: (_target, prop) => {
if (prop === 'initialState') {
return {}
}
return noop
}
}
)
})
vi.mock('@renderer/hooks/useSettings', () => ({
useSettings: vi.fn(() => ({})),
useNavbarPosition: vi.fn(() => ({ navbarPosition: 'left' })),
useMessageStyle: vi.fn(() => ({ isBubbleStyle: false })),
getStoreSetting: vi.fn()
}))
import { isEmbeddingModel, isRerankModel } from '../embedding'
const createModel = (overrides: Partial<Model> = {}): Model => ({
id: 'test-model',
name: 'Test Model',
provider: 'openai',
group: 'Test',
...overrides
})
describe('isEmbeddingModel', () => {
it('returns true for ids that match the embedding regex', () => {
expect(isEmbeddingModel(createModel({ id: 'Text-Embedding-3-Small' }))).toBe(true)
})
it('returns false for rerank models even if they match embedding patterns', () => {
const model = createModel({ id: 'rerank-qa', name: 'rerank-qa' })
expect(isRerankModel(model)).toBe(true)
expect(isEmbeddingModel(model)).toBe(false)
})
it('honors user overrides for embedding capability', () => {
const model = createModel({
id: 'text-embedding-3-small',
capabilities: [{ type: 'embedding', isUserSelected: false }]
})
expect(isEmbeddingModel(model)).toBe(false)
})
it('uses the model name when provider is doubao', () => {
const model = createModel({
id: 'custom-id',
name: 'BGE-Large-zh-v1.5',
provider: 'doubao'
})
expect(isEmbeddingModel(model)).toBe(true)
})
it('returns false for anthropic provider models', () => {
const model = createModel({
id: 'text-embedding-ada-002',
provider: 'anthropic'
})
expect(isEmbeddingModel(model)).toBe(false)
})
})
describe('isRerankModel', () => {
it('identifies ids that match rerank regex', () => {
expect(isRerankModel(createModel({ id: 'jina-rerank-v2-base' }))).toBe(true)
})
it('honors user overrides for rerank capability', () => {
const model = createModel({
id: 'jina-rerank-v2-base',
capabilities: [{ type: 'rerank', isUserSelected: false }]
})
expect(isRerankModel(model)).toBe(false)
})
})

View File

@ -1,33 +1,55 @@
import {
isImageEnhancementModel,
isPureGenerateImageModel,
isQwenReasoningModel,
isSupportedThinkingTokenQwenModel,
isVisionModel,
isWebSearchModel
isVisionModel
} from '@renderer/config/models'
import type { Model } from '@renderer/types'
import { beforeEach, describe, expect, test, vi } from 'vitest'
vi.mock('@renderer/store/llm', () => ({
initialState: {}
}))
vi.mock('@renderer/store', () => ({
default: {
getState: () => ({
llm: {
settings: {}
}
})
}
}))
const getProviderByModelMock = vi.fn()
const isEmbeddingModelMock = vi.fn()
const isRerankModelMock = vi.fn()
vi.mock('@renderer/services/AssistantService', () => ({
getProviderByModel: (...args: any[]) => getProviderByModelMock(...args),
getAssistantSettings: vi.fn(),
getDefaultAssistant: vi.fn().mockReturnValue({
id: 'default',
name: 'Default Assistant',
prompt: '',
settings: {}
})
}))
vi.mock('@renderer/config/models/embedding', () => ({
isEmbeddingModel: (...args: any[]) => isEmbeddingModelMock(...args),
isRerankModel: (...args: any[]) => isRerankModelMock(...args)
}))
beforeEach(() => {
vi.clearAllMocks()
getProviderByModelMock.mockReturnValue({ type: 'openai-response' } as any)
isEmbeddingModelMock.mockReturnValue(false)
isRerankModelMock.mockReturnValue(false)
})
// Suggested test cases
describe('Qwen Model Detection', () => {
beforeEach(() => {
vi.mock('@renderer/store/llm', () => ({
initialState: {}
}))
vi.mock('@renderer/services/AssistantService', () => ({
getProviderByModel: vi.fn().mockReturnValue({ id: 'cherryai' })
}))
vi.mock('@renderer/store', () => ({
default: {
getState: () => ({
llm: {
settings: {}
}
})
}
}))
})
test('isQwenReasoningModel', () => {
expect(isQwenReasoningModel({ id: 'qwen3-thinking' } as Model)).toBe(true)
expect(isQwenReasoningModel({ id: 'qwen3-instruct' } as Model)).toBe(false)
@ -56,14 +78,6 @@ describe('Qwen Model Detection', () => {
})
describe('Vision Model Detection', () => {
beforeEach(() => {
vi.mock('@renderer/store/llm', () => ({
initialState: {}
}))
vi.mock('@renderer/services/AssistantService', () => ({
getProviderByModel: vi.fn().mockReturnValue({ id: 'cherryai' })
}))
})
test('isVisionModel', () => {
expect(isVisionModel({ id: 'qwen-vl-max' } as Model)).toBe(true)
expect(isVisionModel({ id: 'qwen-omni-turbo' } as Model)).toBe(true)
@ -75,25 +89,4 @@ describe('Vision Model Detection', () => {
expect(isImageEnhancementModel({ id: 'qwen-image-edit' } as Model)).toBe(true)
expect(isImageEnhancementModel({ id: 'grok-2-image-latest' } as Model)).toBe(true)
})
test('isPureGenerateImageModel', () => {
expect(isPureGenerateImageModel({ id: 'gpt-image-1' } as Model)).toBe(true)
expect(isPureGenerateImageModel({ id: 'gemini-2.5-flash-image-preview' } as Model)).toBe(true)
expect(isPureGenerateImageModel({ id: 'gemini-2.0-flash-preview-image-generation' } as Model)).toBe(true)
expect(isPureGenerateImageModel({ id: 'grok-2-image-latest' } as Model)).toBe(true)
expect(isPureGenerateImageModel({ id: 'gpt-4o' } as Model)).toBe(false)
})
})
describe('Web Search Model Detection', () => {
beforeEach(() => {
vi.mock('@renderer/store/llm', () => ({
initialState: {}
}))
vi.mock('@renderer/services/AssistantService', () => ({
getProviderByModel: vi.fn().mockReturnValue({ id: 'cherryai' })
}))
})
test('isWebSearchModel', () => {
expect(isWebSearchModel({ id: 'grok-2-image-latest' } as Model)).toBe(false)
})
})

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,137 @@
import type { Model } from '@renderer/types'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { isEmbeddingModel, isRerankModel } from '../embedding'
import { isDeepSeekHybridInferenceModel } from '../reasoning'
import { isFunctionCallingModel } from '../tooluse'
import { isPureGenerateImageModel, isTextToImageModel } from '../vision'
vi.mock('@renderer/hooks/useStore', () => ({
getStoreProviders: vi.fn(() => [])
}))
vi.mock('@renderer/store', () => ({
__esModule: true,
default: {
getState: () => ({
llm: { providers: [] },
settings: {}
})
},
useAppDispatch: vi.fn(),
useAppSelector: vi.fn()
}))
vi.mock('@renderer/store/settings', () => {
const noop = vi.fn()
return new Proxy(
{},
{
get: (_target, prop) => {
if (prop === 'initialState') {
return {}
}
return noop
}
}
)
})
vi.mock('@renderer/hooks/useSettings', () => ({
useSettings: vi.fn(() => ({})),
useNavbarPosition: vi.fn(() => ({ navbarPosition: 'left' })),
useMessageStyle: vi.fn(() => ({ isBubbleStyle: false })),
getStoreSetting: vi.fn()
}))
vi.mock('../embedding', () => ({
isEmbeddingModel: vi.fn(),
isRerankModel: vi.fn()
}))
vi.mock('../vision', () => ({
isPureGenerateImageModel: vi.fn(),
isTextToImageModel: vi.fn()
}))
vi.mock('../reasoning', () => ({
isDeepSeekHybridInferenceModel: vi.fn()
}))
const createModel = (overrides: Partial<Model> = {}): Model => ({
id: 'gpt-4o',
name: 'gpt-4o',
provider: 'openai',
group: 'OpenAI',
...overrides
})
const embeddingMock = vi.mocked(isEmbeddingModel)
const rerankMock = vi.mocked(isRerankModel)
const pureImageMock = vi.mocked(isPureGenerateImageModel)
const textToImageMock = vi.mocked(isTextToImageModel)
const deepSeekHybridMock = vi.mocked(isDeepSeekHybridInferenceModel)
describe('isFunctionCallingModel', () => {
beforeEach(() => {
vi.clearAllMocks()
embeddingMock.mockReturnValue(false)
rerankMock.mockReturnValue(false)
pureImageMock.mockReturnValue(false)
textToImageMock.mockReturnValue(false)
deepSeekHybridMock.mockReturnValue(false)
})
it('returns false when the model is undefined', () => {
expect(isFunctionCallingModel(undefined as unknown as Model)).toBe(false)
})
it('returns false when model is classified as embedding/rerank/image', () => {
embeddingMock.mockReturnValueOnce(true)
expect(isFunctionCallingModel(createModel())).toBe(false)
})
it('respect manual user overrides', () => {
const model = createModel({
capabilities: [{ type: 'function_calling', isUserSelected: false }]
})
expect(isFunctionCallingModel(model)).toBe(false)
const enabled = createModel({
capabilities: [{ type: 'function_calling', isUserSelected: true }]
})
expect(isFunctionCallingModel(enabled)).toBe(true)
})
it('matches doubao models by name when regex applies', () => {
const doubao = createModel({
id: 'custom-model',
name: 'Doubao-Seed-1.6-251015',
provider: 'doubao'
})
expect(isFunctionCallingModel(doubao)).toBe(true)
})
it('returns true for regex matches on standard providers', () => {
expect(isFunctionCallingModel(createModel({ id: 'gpt-5' }))).toBe(true)
})
it('excludes explicitly blocked ids', () => {
expect(isFunctionCallingModel(createModel({ id: 'gemini-1.5-flash' }))).toBe(false)
})
it('forces support for trusted providers', () => {
for (const provider of ['deepseek', 'anthropic', 'kimi', 'moonshot']) {
expect(isFunctionCallingModel(createModel({ provider }))).toBe(true)
}
})
it('returns true when identified as deepseek hybrid inference model', () => {
deepSeekHybridMock.mockReturnValueOnce(true)
expect(isFunctionCallingModel(createModel({ id: 'deepseek-v3-1', provider: 'custom' }))).toBe(true)
})
it('returns false for deepseek hybrid models behind restricted system providers', () => {
deepSeekHybridMock.mockReturnValueOnce(true)
expect(isFunctionCallingModel(createModel({ id: 'deepseek-v3-1', provider: 'dashscope' }))).toBe(false)
})
})

View File

@ -0,0 +1,280 @@
import { isEmbeddingModel, isRerankModel } from '@renderer/config/models/embedding'
import type { Model } from '@renderer/types'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import {
isGPT5ProModel,
isGPT5SeriesModel,
isGPT5SeriesReasoningModel,
isGPT51SeriesModel,
isOpenAIChatCompletionOnlyModel,
isOpenAILLMModel,
isOpenAIModel,
isOpenAIOpenWeightModel,
isOpenAIReasoningModel,
isSupportVerbosityModel
} from '../openai'
import { isQwenMTModel } from '../qwen'
import {
agentModelFilter,
getModelSupportedVerbosity,
groupQwenModels,
isAnthropicModel,
isGeminiModel,
isGemmaModel,
isGenerateImageModels,
isMaxTemperatureOneModel,
isNotSupportedTextDelta,
isNotSupportSystemMessageModel,
isNotSupportTemperatureAndTopP,
isSupportedFlexServiceTier,
isSupportedModel,
isSupportFlexServiceTierModel,
isVisionModels,
isZhipuModel
} from '../utils'
import { isGenerateImageModel, isTextToImageModel, isVisionModel } from '../vision'
import { isOpenAIWebSearchChatCompletionOnlyModel } from '../websearch'
vi.mock('@renderer/hooks/useStore', () => ({
getStoreProviders: vi.fn(() => [])
}))
vi.mock('@renderer/store', () => ({
__esModule: true,
default: {
getState: () => ({
llm: { providers: [] },
settings: {}
})
},
useAppDispatch: vi.fn(),
useAppSelector: vi.fn()
}))
vi.mock('@renderer/store/settings', () => {
const noop = vi.fn()
return new Proxy(
{},
{
get: (_target, prop) => {
if (prop === 'initialState') {
return {}
}
return noop
}
}
)
})
vi.mock('@renderer/hooks/useSettings', () => ({
useSettings: vi.fn(() => ({})),
useNavbarPosition: vi.fn(() => ({ navbarPosition: 'left' })),
useMessageStyle: vi.fn(() => ({ isBubbleStyle: false })),
getStoreSetting: vi.fn()
}))
vi.mock('@renderer/config/models/embedding', () => ({
isEmbeddingModel: vi.fn(),
isRerankModel: vi.fn()
}))
vi.mock('../vision', () => ({
isGenerateImageModel: vi.fn(),
isTextToImageModel: vi.fn(),
isVisionModel: vi.fn()
}))
vi.mock(import('../openai'), async (importOriginal) => {
const actual = await importOriginal()
return {
...actual,
isOpenAIReasoningModel: vi.fn()
}
})
vi.mock('../websearch', () => ({
isOpenAIWebSearchChatCompletionOnlyModel: vi.fn()
}))
const createModel = (overrides: Partial<Model> = {}): Model => ({
id: 'gpt-4o',
name: 'gpt-4o',
provider: 'openai',
group: 'OpenAI',
...overrides
})
const embeddingMock = vi.mocked(isEmbeddingModel)
const rerankMock = vi.mocked(isRerankModel)
const visionMock = vi.mocked(isVisionModel)
const textToImageMock = vi.mocked(isTextToImageModel)
const generateImageMock = vi.mocked(isGenerateImageModel)
const reasoningMock = vi.mocked(isOpenAIReasoningModel)
const openAIWebSearchOnlyMock = vi.mocked(isOpenAIWebSearchChatCompletionOnlyModel)
describe('model utils', () => {
beforeEach(() => {
vi.clearAllMocks()
embeddingMock.mockReturnValue(false)
rerankMock.mockReturnValue(false)
visionMock.mockReturnValue(true)
textToImageMock.mockReturnValue(false)
generateImageMock.mockReturnValue(true)
reasoningMock.mockReturnValue(false)
openAIWebSearchOnlyMock.mockReturnValue(false)
})
it('detects OpenAI LLM models through reasoning and GPT prefix', () => {
expect(isOpenAILLMModel(undefined as unknown as Model)).toBe(false)
expect(isOpenAILLMModel(createModel({ id: 'gpt-4o-image' }))).toBe(false)
reasoningMock.mockReturnValueOnce(true)
expect(isOpenAILLMModel(createModel({ id: 'o1-preview' }))).toBe(true)
expect(isOpenAILLMModel(createModel({ id: 'GPT-5-turbo' }))).toBe(true)
})
it('detects OpenAI models via GPT prefix or reasoning support', () => {
expect(isOpenAIModel(createModel({ id: 'gpt-4.1' }))).toBe(true)
reasoningMock.mockReturnValueOnce(true)
expect(isOpenAIModel(createModel({ id: 'o3' }))).toBe(true)
})
it('evaluates support for flex service tier and alias helper', () => {
expect(isSupportFlexServiceTierModel(createModel({ id: 'o3' }))).toBe(true)
expect(isSupportFlexServiceTierModel(createModel({ id: 'o3-mini' }))).toBe(false)
expect(isSupportFlexServiceTierModel(createModel({ id: 'o4-mini' }))).toBe(true)
expect(isSupportFlexServiceTierModel(createModel({ id: 'gpt-5-preview' }))).toBe(true)
expect(isSupportedFlexServiceTier(createModel({ id: 'gpt-4o' }))).toBe(false)
})
it('detects verbosity support for GPT-5+ families', () => {
expect(isSupportVerbosityModel(createModel({ id: 'gpt-5' }))).toBe(true)
expect(isSupportVerbosityModel(createModel({ id: 'gpt-5-chat' }))).toBe(false)
expect(isSupportVerbosityModel(createModel({ id: 'gpt-5.1-preview' }))).toBe(true)
})
it('limits verbosity controls for GPT-5 Pro models', () => {
const proModel = createModel({ id: 'gpt-5-pro' })
const previewModel = createModel({ id: 'gpt-5-preview' })
expect(getModelSupportedVerbosity(proModel)).toEqual([undefined, 'high'])
expect(getModelSupportedVerbosity(previewModel)).toEqual([undefined, 'low', 'medium', 'high'])
expect(isGPT5ProModel(proModel)).toBe(true)
expect(isGPT5ProModel(previewModel)).toBe(false)
})
it('identifies OpenAI chat-completion-only models', () => {
expect(isOpenAIChatCompletionOnlyModel(createModel({ id: 'gpt-4o-search-preview' }))).toBe(true)
expect(isOpenAIChatCompletionOnlyModel(createModel({ id: 'o1-mini' }))).toBe(true)
expect(isOpenAIChatCompletionOnlyModel(createModel({ id: 'gpt-4o' }))).toBe(false)
})
it('filters unsupported OpenAI catalog entries', () => {
expect(isSupportedModel({ id: 'gpt-4', object: 'model' } as any)).toBe(true)
expect(isSupportedModel({ id: 'tts-1', object: 'model' } as any)).toBe(false)
})
it('calculates temperature/top-p support correctly', () => {
const model = createModel({ id: 'o1' })
reasoningMock.mockReturnValue(true)
expect(isNotSupportTemperatureAndTopP(model)).toBe(true)
const openWeight = createModel({ id: 'gpt-oss-debug' })
expect(isNotSupportTemperatureAndTopP(openWeight)).toBe(false)
const chatOnly = createModel({ id: 'o1-preview' })
reasoningMock.mockReturnValue(false)
expect(isNotSupportTemperatureAndTopP(chatOnly)).toBe(true)
const qwenMt = createModel({ id: 'qwen-mt-large', provider: 'aliyun' })
expect(isNotSupportTemperatureAndTopP(qwenMt)).toBe(true)
})
it('handles gemma and gemini detections plus zhipu tagging', () => {
expect(isGemmaModel(createModel({ id: 'Gemma-3-27B' }))).toBe(true)
expect(isGemmaModel(createModel({ group: 'Gemma' }))).toBe(true)
expect(isGemmaModel(createModel({ id: 'gpt-4o' }))).toBe(false)
expect(isGeminiModel(createModel({ id: 'Gemini-2.0' }))).toBe(true)
expect(isZhipuModel(createModel({ provider: 'zhipu' }))).toBe(true)
expect(isZhipuModel(createModel({ provider: 'openai' }))).toBe(false)
})
it('groups qwen models by prefix', () => {
const qwen = createModel({ id: 'Qwen-7B', provider: 'qwen', name: 'Qwen-7B' })
const qwenOmni = createModel({ id: 'qwen2.5-omni', name: 'qwen2.5-omni' })
const other = createModel({ id: 'deepseek-v3', group: 'DeepSeek' })
const grouped = groupQwenModels([qwen, qwenOmni, other])
expect(Object.keys(grouped)).toContain('qwen-7b')
expect(Object.keys(grouped)).toContain('qwen2.5')
expect(grouped.DeepSeek).toContain(other)
})
it('aggregates boolean helpers based on regex rules', () => {
expect(isAnthropicModel(createModel({ id: 'claude-3.5' }))).toBe(true)
expect(isQwenMTModel(createModel({ id: 'qwen-mt-large' }))).toBe(true)
expect(isNotSupportedTextDelta(createModel({ id: 'qwen-mt-large' }))).toBe(true)
expect(isNotSupportSystemMessageModel(createModel({ id: 'gemma-moe' }))).toBe(true)
expect(isOpenAIOpenWeightModel(createModel({ id: 'gpt-oss-free' }))).toBe(true)
})
it('evaluates GPT-5 family helpers', () => {
expect(isGPT5SeriesModel(createModel({ id: 'gpt-5-preview' }))).toBe(true)
expect(isGPT5SeriesModel(createModel({ id: 'gpt-5.1-preview' }))).toBe(false)
expect(isGPT51SeriesModel(createModel({ id: 'gpt-5.1-mini' }))).toBe(true)
expect(isGPT5SeriesReasoningModel(createModel({ id: 'gpt-5-prompt' }))).toBe(true)
expect(isSupportVerbosityModel(createModel({ id: 'gpt-5-chat' }))).toBe(false)
})
it('wraps generate/vision helpers that operate on arrays', () => {
const models = [createModel({ id: 'gpt-4o' }), createModel({ id: 'gpt-4o-mini' })]
expect(isVisionModels(models)).toBe(true)
visionMock.mockReturnValueOnce(true).mockReturnValueOnce(false)
expect(isVisionModels(models)).toBe(false)
expect(isGenerateImageModels(models)).toBe(true)
generateImageMock.mockReturnValueOnce(true).mockReturnValueOnce(false)
expect(isGenerateImageModels(models)).toBe(false)
})
it('filters models for agent usage', () => {
expect(agentModelFilter(createModel())).toBe(true)
embeddingMock.mockReturnValueOnce(true)
expect(agentModelFilter(createModel({ id: 'text-embedding' }))).toBe(false)
embeddingMock.mockReturnValue(false)
rerankMock.mockReturnValueOnce(true)
expect(agentModelFilter(createModel({ id: 'rerank' }))).toBe(false)
rerankMock.mockReturnValue(false)
textToImageMock.mockReturnValueOnce(true)
expect(agentModelFilter(createModel({ id: 'gpt-image-1' }))).toBe(false)
})
it('identifies models with maximum temperature of 1.0', () => {
// Zhipu models should have max temperature of 1.0
expect(isMaxTemperatureOneModel(createModel({ id: 'glm-4' }))).toBe(true)
expect(isMaxTemperatureOneModel(createModel({ id: 'GLM-4-Plus' }))).toBe(true)
expect(isMaxTemperatureOneModel(createModel({ id: 'glm-3-turbo' }))).toBe(true)
// Anthropic models should have max temperature of 1.0
expect(isMaxTemperatureOneModel(createModel({ id: 'claude-3.5-sonnet' }))).toBe(true)
expect(isMaxTemperatureOneModel(createModel({ id: 'Claude-3-opus' }))).toBe(true)
expect(isMaxTemperatureOneModel(createModel({ id: 'claude-2.1' }))).toBe(true)
// Moonshot models should have max temperature of 1.0
expect(isMaxTemperatureOneModel(createModel({ id: 'moonshot-1.0' }))).toBe(true)
expect(isMaxTemperatureOneModel(createModel({ id: 'kimi-k2-thinking' }))).toBe(true)
expect(isMaxTemperatureOneModel(createModel({ id: 'Moonshot-Pro' }))).toBe(true)
// Other models should return false
expect(isMaxTemperatureOneModel(createModel({ id: 'gpt-4o' }))).toBe(false)
expect(isMaxTemperatureOneModel(createModel({ id: 'gpt-4-turbo' }))).toBe(false)
expect(isMaxTemperatureOneModel(createModel({ id: 'qwen-max' }))).toBe(false)
expect(isMaxTemperatureOneModel(createModel({ id: 'gemini-pro' }))).toBe(false)
})
})

View File

@ -0,0 +1,311 @@
import { getProviderByModel } from '@renderer/services/AssistantService'
import type { Model } from '@renderer/types'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { isEmbeddingModel, isRerankModel } from '../embedding'
import {
isAutoEnableImageGenerationModel,
isDedicatedImageGenerationModel,
isGenerateImageModel,
isImageEnhancementModel,
isPureGenerateImageModel,
isTextToImageModel,
isVisionModel
} from '../vision'
vi.mock('@renderer/hooks/useStore', () => ({
getStoreProviders: vi.fn(() => [])
}))
vi.mock('@renderer/store', () => ({
__esModule: true,
default: {
getState: () => ({
llm: { providers: [] },
settings: {}
})
},
useAppDispatch: vi.fn(),
useAppSelector: vi.fn()
}))
vi.mock('@renderer/store/settings', () => {
const noop = vi.fn()
return new Proxy(
{},
{
get: (_target, prop) => {
if (prop === 'initialState') {
return {}
}
return noop
}
}
)
})
vi.mock('@renderer/hooks/useSettings', () => ({
useSettings: vi.fn(() => ({})),
useNavbarPosition: vi.fn(() => ({ navbarPosition: 'left' })),
useMessageStyle: vi.fn(() => ({ isBubbleStyle: false })),
getStoreSetting: vi.fn()
}))
vi.mock('@renderer/services/AssistantService', () => ({
getProviderByModel: vi.fn()
}))
vi.mock('../embedding', () => ({
isEmbeddingModel: vi.fn(),
isRerankModel: vi.fn()
}))
const createModel = (overrides: Partial<Model> = {}): Model => ({
id: 'gpt-4o',
name: 'gpt-4o',
provider: 'openai',
group: 'OpenAI',
...overrides
})
const providerMock = vi.mocked(getProviderByModel)
const embeddingMock = vi.mocked(isEmbeddingModel)
const rerankMock = vi.mocked(isRerankModel)
describe('vision helpers', () => {
beforeEach(() => {
vi.clearAllMocks()
providerMock.mockReturnValue({ type: 'openai-response' } as any)
embeddingMock.mockReturnValue(false)
rerankMock.mockReturnValue(false)
})
describe('isGenerateImageModel', () => {
it('returns false for embedding/rerank models or missing providers', () => {
embeddingMock.mockReturnValueOnce(true)
expect(isGenerateImageModel(createModel({ id: 'gpt-image-1' }))).toBe(false)
embeddingMock.mockReturnValue(false)
rerankMock.mockReturnValueOnce(true)
expect(isGenerateImageModel(createModel({ id: 'gpt-image-1' }))).toBe(false)
rerankMock.mockReturnValue(false)
providerMock.mockReturnValueOnce(undefined as any)
expect(isGenerateImageModel(createModel({ id: 'gpt-image-1' }))).toBe(false)
})
it('detects OpenAI and third-party generative image models', () => {
expect(isGenerateImageModel(createModel({ id: 'gpt-4o-mini' }))).toBe(true)
providerMock.mockReturnValue({ type: 'custom' } as any)
expect(isGenerateImageModel(createModel({ id: 'gemini-2.5-flash-image' }))).toBe(true)
})
it('returns false when openai-response model is not on allow list', () => {
expect(isGenerateImageModel(createModel({ id: 'gpt-4.2-experimental' }))).toBe(false)
})
})
describe('isPureGenerateImageModel', () => {
it('requires both generate and text-to-image support', () => {
expect(isPureGenerateImageModel(createModel({ id: 'gpt-image-1' }))).toBe(true)
expect(isPureGenerateImageModel(createModel({ id: 'gpt-4o' }))).toBe(false)
expect(isPureGenerateImageModel(createModel({ id: 'gemini-2.5-flash-image-preview' }))).toBe(true)
})
})
describe('text-to-image helpers', () => {
it('matches predefined keywords', () => {
expect(isTextToImageModel(createModel({ id: 'midjourney-v6' }))).toBe(true)
expect(isTextToImageModel(createModel({ id: 'gpt-4o' }))).toBe(false)
})
it('detects models with restricted image size support and enhancement', () => {
expect(isImageEnhancementModel(createModel({ id: 'qwen-image-edit' }))).toBe(true)
expect(isImageEnhancementModel(createModel({ id: 'gpt-4o' }))).toBe(false)
})
it('identifies dedicated and auto-enabled image generation models', () => {
expect(isDedicatedImageGenerationModel(createModel({ id: 'grok-2-image-1212' }))).toBe(true)
expect(isAutoEnableImageGenerationModel(createModel({ id: 'gemini-2.5-flash-image-ultra' }))).toBe(true)
})
it('returns false when models are not in dedicated or auto-enable sets', () => {
expect(isDedicatedImageGenerationModel(createModel({ id: 'gpt-4o' }))).toBe(false)
expect(isAutoEnableImageGenerationModel(createModel({ id: 'gpt-4o' }))).toBe(false)
})
})
})
describe('isVisionModel', () => {
it('returns false for embedding/rerank models and honors overrides', () => {
embeddingMock.mockReturnValueOnce(true)
expect(isVisionModel(createModel({ id: 'gpt-4o' }))).toBe(false)
embeddingMock.mockReturnValue(false)
const disabled = createModel({
id: 'gpt-4o',
capabilities: [{ type: 'vision', isUserSelected: false }]
})
expect(isVisionModel(disabled)).toBe(false)
const forced = createModel({
id: 'gpt-4o',
capabilities: [{ type: 'vision', isUserSelected: true }]
})
expect(isVisionModel(forced)).toBe(true)
})
it('matches doubao models by name and general regexes by id', () => {
const doubao = createModel({
id: 'custom-id',
provider: 'doubao',
name: 'Doubao-Seed-1-6-Lite-251015'
})
expect(isVisionModel(doubao)).toBe(true)
expect(isVisionModel(createModel({ id: 'gpt-4o-mini' }))).toBe(true)
})
it('leverages image enhancement regex when standard vision regex does not match', () => {
expect(isVisionModel(createModel({ id: 'qwen-image-edit' }))).toBe(true)
})
it('returns false for doubao models that fail regex checks', () => {
const doubao = createModel({ id: 'doubao-standard', provider: 'doubao', name: 'basic' })
expect(isVisionModel(doubao)).toBe(false)
})
describe('Gemini Models', () => {
it('should return true for gemini 1.5 models', () => {
expect(
isVisionModel({
id: 'gemini-1.5-flash',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isVisionModel({
id: 'gemini-1.5-pro',
name: '',
provider: '',
group: ''
})
).toBe(true)
})
it('should return true for gemini 2.x models', () => {
expect(
isVisionModel({
id: 'gemini-2.0-flash',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isVisionModel({
id: 'gemini-2.0-pro',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isVisionModel({
id: 'gemini-2.5-flash',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isVisionModel({
id: 'gemini-2.5-pro',
name: '',
provider: '',
group: ''
})
).toBe(true)
})
it('should return true for gemini latest models', () => {
expect(
isVisionModel({
id: 'gemini-flash-latest',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isVisionModel({
id: 'gemini-pro-latest',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isVisionModel({
id: 'gemini-flash-lite-latest',
name: '',
provider: '',
group: ''
})
).toBe(true)
})
it('should return true for gemini 3 models', () => {
// Preview versions
expect(
isVisionModel({
id: 'gemini-3-pro-preview',
name: '',
provider: '',
group: ''
})
).toBe(true)
// Future stable versions
expect(
isVisionModel({
id: 'gemini-3-flash',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isVisionModel({
id: 'gemini-3-pro',
name: '',
provider: '',
group: ''
})
).toBe(true)
})
it('should return true for gemini exp models', () => {
expect(
isVisionModel({
id: 'gemini-exp-1206',
name: '',
provider: '',
group: ''
})
).toBe(true)
})
it('should return false for gemini 1.0 models', () => {
expect(
isVisionModel({
id: 'gemini-1.0-pro',
name: '',
provider: '',
group: ''
})
).toBe(false)
})
})
})

View File

@ -0,0 +1,397 @@
import { beforeEach, describe, expect, it, vi } from 'vitest'
const providerMock = vi.mocked(getProviderByModel)
vi.mock('@renderer/services/AssistantService', () => ({
getProviderByModel: vi.fn(),
getAssistantSettings: vi.fn(),
getDefaultAssistant: vi.fn().mockReturnValue({
id: 'default',
name: 'Default Assistant',
prompt: '',
settings: {}
})
}))
const isEmbeddingModel = vi.hoisted(() => vi.fn())
const isRerankModel = vi.hoisted(() => vi.fn())
vi.mock('../embedding', () => ({
isEmbeddingModel: (...args: any[]) => isEmbeddingModel(...args),
isRerankModel: (...args: any[]) => isRerankModel(...args)
}))
const isPureGenerateImageModel = vi.hoisted(() => vi.fn())
const isTextToImageModel = vi.hoisted(() => vi.fn())
const isGenerateImageModel = vi.hoisted(() => vi.fn())
vi.mock('../vision', () => ({
isPureGenerateImageModel: (...args: any[]) => isPureGenerateImageModel(...args),
isTextToImageModel: (...args: any[]) => isTextToImageModel(...args),
isGenerateImageModel: (...args: any[]) => isGenerateImageModel(...args),
isModernGenerateImageModel: vi.fn()
}))
const providerMocks = vi.hoisted(() => ({
isGeminiProvider: vi.fn(),
isNewApiProvider: vi.fn(),
isOpenAICompatibleProvider: vi.fn(),
isOpenAIProvider: vi.fn(),
isVertexProvider: vi.fn(),
isAwsBedrockProvider: vi.fn(),
isAzureOpenAIProvider: vi.fn()
}))
vi.mock('@renderer/utils/provider', () => providerMocks)
vi.mock('@renderer/hooks/useStore', () => ({
getStoreProviders: vi.fn(() => [])
}))
vi.mock('@renderer/store', () => ({
__esModule: true,
default: {
getState: () => ({
llm: { providers: [] },
settings: {}
})
},
useAppDispatch: vi.fn(),
useAppSelector: vi.fn()
}))
vi.mock('@renderer/store/settings', () => {
const noop = vi.fn()
return new Proxy(
{},
{
get: (_target, prop) => {
if (prop === 'initialState') {
return {}
}
return noop
}
}
)
})
vi.mock('@renderer/hooks/useSettings', () => ({
useSettings: vi.fn(() => ({})),
useNavbarPosition: vi.fn(() => ({ navbarPosition: 'left' })),
useMessageStyle: vi.fn(() => ({ isBubbleStyle: false })),
getStoreSetting: vi.fn()
}))
import { getProviderByModel } from '@renderer/services/AssistantService'
import type { Model, Provider } from '@renderer/types'
import { SystemProviderIds } from '@renderer/types'
import { isOpenAIDeepResearchModel } from '../openai'
import {
GEMINI_SEARCH_REGEX,
isHunyuanSearchModel,
isMandatoryWebSearchModel,
isOpenAIWebSearchChatCompletionOnlyModel,
isOpenAIWebSearchModel,
isOpenRouterBuiltInWebSearchModel,
isWebSearchModel
} from '../websearch'
const createModel = (overrides: Partial<Model> = {}): Model => ({
id: 'gpt-4o',
name: 'gpt-4o',
provider: 'openai',
group: 'OpenAI',
...overrides
})
const createProvider = (overrides: Partial<Provider> = {}): Provider => ({
id: 'openai',
type: 'openai',
name: 'OpenAI',
apiKey: '',
apiHost: '',
models: [],
...overrides
})
const resetMocks = () => {
providerMock.mockReturnValue(createProvider())
isEmbeddingModel.mockReturnValue(false)
isRerankModel.mockReturnValue(false)
isPureGenerateImageModel.mockReturnValue(false)
isTextToImageModel.mockReturnValue(false)
providerMocks.isGeminiProvider.mockReturnValue(false)
providerMocks.isNewApiProvider.mockReturnValue(false)
providerMocks.isOpenAICompatibleProvider.mockReturnValue(false)
providerMocks.isOpenAIProvider.mockReturnValue(false)
}
describe('websearch helpers', () => {
beforeEach(() => {
vi.clearAllMocks()
resetMocks()
})
describe('isOpenAIDeepResearchModel', () => {
it('detects deep research ids for OpenAI only', () => {
expect(isOpenAIDeepResearchModel(createModel({ id: 'openai/deep-research-preview' }))).toBe(true)
expect(isOpenAIDeepResearchModel(createModel({ provider: 'openai', id: 'gpt-4o' }))).toBe(false)
expect(isOpenAIDeepResearchModel(createModel({ provider: 'openrouter', id: 'deep-research' }))).toBe(false)
})
})
describe('isWebSearchModel', () => {
it('returns false for embedding/rerank/image models', () => {
isEmbeddingModel.mockReturnValueOnce(true)
expect(isWebSearchModel(createModel())).toBe(false)
resetMocks()
isRerankModel.mockReturnValueOnce(true)
expect(isWebSearchModel(createModel())).toBe(false)
resetMocks()
isTextToImageModel.mockReturnValueOnce(true)
expect(isWebSearchModel(createModel())).toBe(false)
})
it('honors user overrides', () => {
const enabled = createModel({ capabilities: [{ type: 'web_search', isUserSelected: true }] })
expect(isWebSearchModel(enabled)).toBe(true)
const disabled = createModel({ capabilities: [{ type: 'web_search', isUserSelected: false }] })
expect(isWebSearchModel(disabled)).toBe(false)
})
it('returns false when provider lookup fails', () => {
providerMock.mockReturnValueOnce(undefined as any)
expect(isWebSearchModel(createModel())).toBe(false)
})
it('handles Anthropic providers on unsupported platforms', () => {
providerMock.mockReturnValueOnce(createProvider({ id: SystemProviderIds['aws-bedrock'] }))
const model = createModel({ id: 'claude-2-sonnet' })
expect(isWebSearchModel(model)).toBe(false)
})
it('returns true for first-party Anthropic provider', () => {
providerMock.mockReturnValueOnce(createProvider({ id: 'anthropic' }))
const model = createModel({ id: 'claude-3.5-sonnet-latest', provider: 'anthropic' })
expect(isWebSearchModel(model)).toBe(true)
})
it('detects OpenAI preview search models only when supported', () => {
providerMocks.isOpenAIProvider.mockReturnValue(true)
const model = createModel({ id: 'gpt-4o-search-preview' })
expect(isWebSearchModel(model)).toBe(true)
const nonSearch = createModel({ id: 'gpt-4o-image' })
expect(isWebSearchModel(nonSearch)).toBe(false)
})
it('supports Perplexity sonar families including mandatory variants', () => {
providerMock.mockReturnValueOnce(createProvider({ id: SystemProviderIds.perplexity }))
expect(isWebSearchModel(createModel({ id: 'sonar-deep-research' }))).toBe(true)
})
it('handles AIHubMix Gemini and OpenAI search models', () => {
providerMock.mockReturnValueOnce(createProvider({ id: SystemProviderIds.aihubmix }))
expect(isWebSearchModel(createModel({ id: 'gemini-2.5-pro-preview' }))).toBe(true)
providerMock.mockReturnValueOnce(createProvider({ id: SystemProviderIds.aihubmix }))
const openaiSearch = createModel({ id: 'gpt-4o-search-preview' })
expect(isWebSearchModel(openaiSearch)).toBe(true)
})
it('supports OpenAI-compatible or new API providers for Gemini/OpenAI models', () => {
const model = createModel({ id: 'gemini-2.5-flash-lite-latest' })
providerMock.mockReturnValueOnce(createProvider({ id: 'custom' }))
providerMocks.isOpenAICompatibleProvider.mockReturnValueOnce(true)
expect(isWebSearchModel(model)).toBe(true)
resetMocks()
providerMock.mockReturnValueOnce(createProvider({ id: 'custom' }))
providerMocks.isNewApiProvider.mockReturnValueOnce(true)
expect(isWebSearchModel(createModel({ id: 'gpt-4o-search-preview' }))).toBe(true)
})
it('falls back to Gemini/Vertex provider regex matching', () => {
providerMock.mockReturnValueOnce(createProvider({ id: SystemProviderIds.vertexai }))
providerMocks.isGeminiProvider.mockReturnValueOnce(true)
expect(isWebSearchModel(createModel({ id: 'gemini-2.0-flash-latest' }))).toBe(true)
})
it('evaluates hunyuan/zhipu/dashscope/openrouter/grok providers', () => {
providerMock.mockReturnValueOnce(createProvider({ id: 'hunyuan' }))
expect(isWebSearchModel(createModel({ id: 'hunyuan-pro' }))).toBe(true)
expect(isWebSearchModel(createModel({ id: 'hunyuan-lite', provider: 'hunyuan' }))).toBe(false)
providerMock.mockReturnValueOnce(createProvider({ id: 'zhipu' }))
expect(isWebSearchModel(createModel({ id: 'glm-4-air' }))).toBe(true)
providerMock.mockReturnValueOnce(createProvider({ id: 'dashscope' }))
expect(isWebSearchModel(createModel({ id: 'qwen-max-latest' }))).toBe(true)
providerMock.mockReturnValueOnce(createProvider({ id: 'openrouter' }))
expect(isWebSearchModel(createModel())).toBe(true)
providerMock.mockReturnValueOnce(createProvider({ id: 'grok' }))
expect(isWebSearchModel(createModel({ id: 'grok-2' }))).toBe(true)
})
})
describe('isMandatoryWebSearchModel', () => {
it('requires sonar ids for perplexity/openrouter providers', () => {
providerMock.mockReturnValueOnce(createProvider({ id: SystemProviderIds.perplexity }))
expect(isMandatoryWebSearchModel(createModel({ id: 'sonar-pro' }))).toBe(true)
providerMock.mockReturnValueOnce(createProvider({ id: SystemProviderIds.openrouter }))
expect(isMandatoryWebSearchModel(createModel({ id: 'sonar-reasoning' }))).toBe(true)
providerMock.mockReturnValueOnce(createProvider({ id: 'openai' }))
expect(isMandatoryWebSearchModel(createModel({ id: 'sonar-pro' }))).toBe(false)
})
it.each([
['perplexity', 'non-sonar'],
['openrouter', 'gpt-4o-search-preview']
])('returns false for %s provider when id is %s', (providerId, modelId) => {
providerMock.mockReturnValueOnce(createProvider({ id: providerId }))
expect(isMandatoryWebSearchModel(createModel({ id: modelId }))).toBe(false)
})
})
describe('isOpenRouterBuiltInWebSearchModel', () => {
it('checks for sonar ids or OpenAI chat-completion-only variants', () => {
providerMock.mockReturnValueOnce(createProvider({ id: 'openrouter' }))
expect(isOpenRouterBuiltInWebSearchModel(createModel({ id: 'sonar-reasoning' }))).toBe(true)
providerMock.mockReturnValueOnce(createProvider({ id: 'openrouter' }))
expect(isOpenRouterBuiltInWebSearchModel(createModel({ id: 'gpt-4o-search-preview' }))).toBe(true)
providerMock.mockReturnValueOnce(createProvider({ id: 'custom' }))
expect(isOpenRouterBuiltInWebSearchModel(createModel({ id: 'sonar-reasoning' }))).toBe(false)
})
})
describe('OpenAI web search helpers', () => {
it('detects chat completion only variants and openai search ids', () => {
expect(isOpenAIWebSearchChatCompletionOnlyModel(createModel({ id: 'gpt-4o-search-preview' }))).toBe(true)
expect(isOpenAIWebSearchChatCompletionOnlyModel(createModel({ id: 'gpt-4o-mini-search-preview' }))).toBe(true)
expect(isOpenAIWebSearchChatCompletionOnlyModel(createModel({ id: 'gpt-4o' }))).toBe(false)
expect(isOpenAIWebSearchModel(createModel({ id: 'gpt-4.1-turbo' }))).toBe(true)
expect(isOpenAIWebSearchModel(createModel({ id: 'gpt-4o-image' }))).toBe(false)
expect(isOpenAIWebSearchModel(createModel({ id: 'gpt-5.1-chat' }))).toBe(false)
expect(isOpenAIWebSearchModel(createModel({ id: 'o3-mini' }))).toBe(true)
})
it.each(['gpt-4.1-preview', 'gpt-4o-2024-05-13', 'o4-mini', 'gpt-5-explorer'])(
'treats %s as an OpenAI web search model',
(id) => {
expect(isOpenAIWebSearchModel(createModel({ id }))).toBe(true)
}
)
it.each(['gpt-4o-image-preview', 'gpt-4.1-nano', 'gpt-5.1-chat', 'gpt-image-1'])(
'excludes %s from OpenAI web search',
(id) => {
expect(isOpenAIWebSearchModel(createModel({ id }))).toBe(false)
}
)
it.each(['gpt-4o-search-preview', 'gpt-4o-mini-search-preview'])('flags %s as chat-completion-only', (id) => {
expect(isOpenAIWebSearchChatCompletionOnlyModel(createModel({ id }))).toBe(true)
})
})
describe('isHunyuanSearchModel', () => {
it('identifies hunyuan models except lite', () => {
expect(isHunyuanSearchModel(createModel({ id: 'hunyuan-pro', provider: 'hunyuan' }))).toBe(true)
expect(isHunyuanSearchModel(createModel({ id: 'hunyuan-lite', provider: 'hunyuan' }))).toBe(false)
expect(isHunyuanSearchModel(createModel())).toBe(false)
})
it.each(['hunyuan-standard', 'hunyuan-advanced'])('accepts %s', (suffix) => {
expect(isHunyuanSearchModel(createModel({ id: suffix, provider: 'hunyuan' }))).toBe(true)
})
})
describe('provider-specific regex coverage', () => {
it.each(['qwen-turbo', 'qwen-max-0919', 'qwen3-max', 'qwen-plus-2024', 'qwq-32b'])(
'dashscope treats %s as searchable',
(id) => {
providerMock.mockReturnValue(createProvider({ id: 'dashscope' }))
expect(isWebSearchModel(createModel({ id }))).toBe(true)
}
)
it.each(['qwen-1.5-chat', 'custom-model'])('dashscope ignores %s', (id) => {
providerMock.mockReturnValue(createProvider({ id: 'dashscope' }))
expect(isWebSearchModel(createModel({ id }))).toBe(false)
})
it.each(['sonar', 'sonar-pro', 'sonar-reasoning-pro', 'sonar-deep-research'])(
'perplexity provider supports %s',
(id) => {
providerMock.mockReturnValue(createProvider({ id: SystemProviderIds.perplexity }))
expect(isWebSearchModel(createModel({ id }))).toBe(true)
}
)
it.each([
'gemini-2.0-flash-latest',
'gemini-2.5-flash-lite-latest',
'gemini-flash-lite-latest',
'gemini-pro-latest'
])('Gemini provider supports %s', (id) => {
providerMock.mockReturnValue(createProvider({ id: SystemProviderIds.vertexai }))
providerMocks.isGeminiProvider.mockReturnValue(true)
expect(isWebSearchModel(createModel({ id }))).toBe(true)
})
})
describe('Gemini Search Models', () => {
describe('GEMINI_SEARCH_REGEX', () => {
it('should match gemini 2.x models', () => {
expect(GEMINI_SEARCH_REGEX.test('gemini-2.0-flash')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-2.0-pro')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-2.5-flash')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-2.5-pro')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-2.5-flash-latest')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-2.5-pro-latest')).toBe(true)
})
it('should match gemini latest models', () => {
expect(GEMINI_SEARCH_REGEX.test('gemini-flash-latest')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-pro-latest')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-flash-lite-latest')).toBe(true)
})
it('should match gemini 3 models', () => {
// Preview versions
expect(GEMINI_SEARCH_REGEX.test('gemini-3-pro-preview')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-3-flash-preview')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-3-pro-image-preview')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-3-flash-image-preview')).toBe(true)
// Future stable versions
expect(GEMINI_SEARCH_REGEX.test('gemini-3-flash')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-3-pro')).toBe(true)
// Version with decimals
expect(GEMINI_SEARCH_REGEX.test('gemini-3.0-flash')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-3.0-pro')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-3.5-flash-preview')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-3.5-pro-image-preview')).toBe(true)
})
it('should not match gemini 2.x image-preview models', () => {
expect(GEMINI_SEARCH_REGEX.test('gemini-2.5-flash-image-preview')).toBe(false)
expect(GEMINI_SEARCH_REGEX.test('gemini-2.0-pro-image-preview')).toBe(false)
})
it('should not match older gemini models', () => {
expect(GEMINI_SEARCH_REGEX.test('gemini-1.5-flash')).toBe(false)
expect(GEMINI_SEARCH_REGEX.test('gemini-1.5-pro')).toBe(false)
expect(GEMINI_SEARCH_REGEX.test('gemini-1.0-pro')).toBe(false)
})
})
})
})

View File

@ -1,6 +1,8 @@
export * from './default'
export * from './embedding'
export * from './logo'
export * from './openai'
export * from './qwen'
export * from './reasoning'
export * from './tooluse'
export * from './utils'

View File

@ -0,0 +1,107 @@
import type { Model } from '@renderer/types'
import { getLowerBaseModelName } from '@renderer/utils'
export const OPENAI_NO_SUPPORT_DEV_ROLE_MODELS = ['o1-preview', 'o1-mini']
export function isOpenAILLMModel(model: Model): boolean {
if (!model) {
return false
}
const modelId = getLowerBaseModelName(model.id)
if (modelId.includes('gpt-4o-image')) {
return false
}
if (isOpenAIReasoningModel(model)) {
return true
}
if (modelId.includes('gpt')) {
return true
}
return false
}
export function isOpenAIModel(model: Model): boolean {
if (!model) {
return false
}
const modelId = getLowerBaseModelName(model.id)
return modelId.includes('gpt') || isOpenAIReasoningModel(model)
}
export const isGPT5ProModel = (model: Model) => {
const modelId = getLowerBaseModelName(model.id)
return modelId.includes('gpt-5-pro')
}
export const isOpenAIOpenWeightModel = (model: Model) => {
const modelId = getLowerBaseModelName(model.id)
return modelId.includes('gpt-oss')
}
export const isGPT5SeriesModel = (model: Model) => {
const modelId = getLowerBaseModelName(model.id)
return modelId.includes('gpt-5') && !modelId.includes('gpt-5.1')
}
export const isGPT5SeriesReasoningModel = (model: Model) => {
const modelId = getLowerBaseModelName(model.id)
return isGPT5SeriesModel(model) && !modelId.includes('chat')
}
export const isGPT51SeriesModel = (model: Model) => {
const modelId = getLowerBaseModelName(model.id)
return modelId.includes('gpt-5.1')
}
export function isSupportVerbosityModel(model: Model): boolean {
const modelId = getLowerBaseModelName(model.id)
return (isGPT5SeriesModel(model) || isGPT51SeriesModel(model)) && !modelId.includes('chat')
}
export function isOpenAIChatCompletionOnlyModel(model: Model): boolean {
if (!model) {
return false
}
const modelId = getLowerBaseModelName(model.id)
return (
modelId.includes('gpt-4o-search-preview') ||
modelId.includes('gpt-4o-mini-search-preview') ||
modelId.includes('o1-mini') ||
modelId.includes('o1-preview')
)
}
export function isOpenAIReasoningModel(model: Model): boolean {
const modelId = getLowerBaseModelName(model.id, '/')
return isSupportedReasoningEffortOpenAIModel(model) || modelId.includes('o1')
}
export function isSupportedReasoningEffortOpenAIModel(model: Model): boolean {
const modelId = getLowerBaseModelName(model.id)
return (
(modelId.includes('o1') && !(modelId.includes('o1-preview') || modelId.includes('o1-mini'))) ||
modelId.includes('o3') ||
modelId.includes('o4') ||
modelId.includes('gpt-oss') ||
((isGPT5SeriesModel(model) || isGPT51SeriesModel(model)) && !modelId.includes('chat'))
)
}
const OPENAI_DEEP_RESEARCH_MODEL_REGEX = /deep[-_]?research/
export function isOpenAIDeepResearchModel(model?: Model): boolean {
if (!model) {
return false
}
const providerId = model.provider
if (providerId !== 'openai' && providerId !== 'openai-chat') {
return false
}
const modelId = getLowerBaseModelName(model.id, '/')
return OPENAI_DEEP_RESEARCH_MODEL_REGEX.test(modelId)
}

View File

@ -0,0 +1,7 @@
import type { Model } from '@renderer/types'
import { getLowerBaseModelName } from '@renderer/utils'
export const isQwenMTModel = (model: Model): boolean => {
const modelId = getLowerBaseModelName(model.id)
return modelId.includes('qwen-mt')
}

View File

@ -8,9 +8,16 @@ import type {
import { getLowerBaseModelName, isUserSelectedModelType } from '@renderer/utils'
import { isEmbeddingModel, isRerankModel } from './embedding'
import { isGPT5ProModel, isGPT5SeriesModel, isGPT51SeriesModel } from './utils'
import {
isGPT5ProModel,
isGPT5SeriesModel,
isGPT51SeriesModel,
isOpenAIDeepResearchModel,
isOpenAIReasoningModel,
isSupportedReasoningEffortOpenAIModel
} from './openai'
import { GEMINI_FLASH_MODEL_REGEX, isGemini3Model } from './utils'
import { isTextToImageModel } from './vision'
import { GEMINI_FLASH_MODEL_REGEX, isOpenAIDeepResearchModel } from './websearch'
// Reasoning models
export const REASONING_REGEX =
@ -30,6 +37,7 @@ export const MODEL_SUPPORTED_REASONING_EFFORT: ReasoningEffortConfig = {
grok: ['low', 'high'] as const,
grok4_fast: ['auto'] as const,
gemini: ['low', 'medium', 'high', 'auto'] as const,
gemini3: ['low', 'medium', 'high'] as const,
gemini_pro: ['low', 'medium', 'high', 'auto'] as const,
qwen: ['low', 'medium', 'high'] as const,
qwen_thinking: ['low', 'medium', 'high'] as const,
@ -56,6 +64,7 @@ export const MODEL_SUPPORTED_OPTIONS: ThinkingOptionConfig = {
grok4_fast: ['none', ...MODEL_SUPPORTED_REASONING_EFFORT.grok4_fast] as const,
gemini: ['none', ...MODEL_SUPPORTED_REASONING_EFFORT.gemini] as const,
gemini_pro: MODEL_SUPPORTED_REASONING_EFFORT.gemini_pro,
gemini3: MODEL_SUPPORTED_REASONING_EFFORT.gemini3,
qwen: ['none', ...MODEL_SUPPORTED_REASONING_EFFORT.qwen] as const,
qwen_thinking: MODEL_SUPPORTED_REASONING_EFFORT.qwen_thinking,
doubao: ['none', ...MODEL_SUPPORTED_REASONING_EFFORT.doubao] as const,
@ -106,6 +115,9 @@ const _getThinkModelType = (model: Model): ThinkingModelType => {
} else {
thinkingModelType = 'gemini_pro'
}
if (isGemini3Model(model)) {
thinkingModelType = 'gemini3'
}
} else if (isSupportedReasoningEffortGrokModel(model)) thinkingModelType = 'grok'
else if (isSupportedThinkingTokenQwenModel(model)) {
if (isQwenAlwaysThinkModel(model)) {
@ -254,11 +266,19 @@ export function isGeminiReasoningModel(model?: Model): boolean {
// Gemini 支持思考模式的模型正则
export const GEMINI_THINKING_MODEL_REGEX =
/gemini-(?:2\.5.*(?:-latest)?|3-(?:flash|pro)(?:-preview)?|flash-latest|pro-latest|flash-lite-latest)(?:-[\w-]+)*$/i
/gemini-(?:2\.5.*(?:-latest)?|3(?:\.\d+)?-(?:flash|pro)(?:-preview)?|flash-latest|pro-latest|flash-lite-latest)(?:-[\w-]+)*$/i
export const isSupportedThinkingTokenGeminiModel = (model: Model): boolean => {
const modelId = getLowerBaseModelName(model.id, '/')
if (GEMINI_THINKING_MODEL_REGEX.test(modelId)) {
// gemini-3.x 的 image 模型支持思考模式
if (isGemini3Model(model)) {
if (modelId.includes('tts')) {
return false
}
return true
}
// gemini-2.x 的 image/tts 模型不支持
if (modelId.includes('image') || modelId.includes('tts')) {
return false
}
@ -382,6 +402,12 @@ export function isClaude45ReasoningModel(model: Model): boolean {
return regex.test(modelId)
}
export function isClaude4SeriesModel(model: Model): boolean {
const modelId = getLowerBaseModelName(model.id, '/')
const regex = /claude-(sonnet|opus|haiku)-4(?:[.-]\d+)?(?:-[\w-]+)?$/i
return regex.test(modelId)
}
export function isClaudeReasoningModel(model?: Model): boolean {
if (!model) {
return false
@ -529,22 +555,6 @@ export function isReasoningModel(model?: Model): boolean {
return REASONING_REGEX.test(modelId) || false
}
export function isOpenAIReasoningModel(model: Model): boolean {
const modelId = getLowerBaseModelName(model.id, '/')
return isSupportedReasoningEffortOpenAIModel(model) || modelId.includes('o1')
}
export function isSupportedReasoningEffortOpenAIModel(model: Model): boolean {
const modelId = getLowerBaseModelName(model.id)
return (
(modelId.includes('o1') && !(modelId.includes('o1-preview') || modelId.includes('o1-mini'))) ||
modelId.includes('o3') ||
modelId.includes('o4') ||
modelId.includes('gpt-oss') ||
((isGPT5SeriesModel(model) || isGPT51SeriesModel(model)) && !modelId.includes('chat'))
)
}
export const THINKING_TOKEN_MAP: Record<string, { min: number; max: number }> = {
// Gemini models
'gemini-2\\.5-flash-lite.*$': { min: 512, max: 24576 },

View File

@ -4,7 +4,7 @@ import { getLowerBaseModelName, isUserSelectedModelType } from '@renderer/utils'
import { isEmbeddingModel, isRerankModel } from './embedding'
import { isDeepSeekHybridInferenceModel } from './reasoning'
import { isPureGenerateImageModel, isTextToImageModel } from './vision'
import { isTextToImageModel } from './vision'
// Tool calling models
export const FUNCTION_CALLING_MODELS = [
@ -41,7 +41,9 @@ const FUNCTION_CALLING_EXCLUDED_MODELS = [
'gemini-1(?:\\.[\\w-]+)?',
'qwen-mt(?:-[\\w-]+)?',
'gpt-5-chat(?:-[\\w-]+)?',
'glm-4\\.5v'
'glm-4\\.5v',
'gemini-2.5-flash-image(?:-[\\w-]+)?',
'gemini-2.0-flash-preview-image-generation'
]
export const FUNCTION_CALLING_REGEX = new RegExp(
@ -50,13 +52,7 @@ export const FUNCTION_CALLING_REGEX = new RegExp(
)
export function isFunctionCallingModel(model?: Model): boolean {
if (
!model ||
isEmbeddingModel(model) ||
isRerankModel(model) ||
isTextToImageModel(model) ||
isPureGenerateImageModel(model)
) {
if (!model || isEmbeddingModel(model) || isRerankModel(model) || isTextToImageModel(model)) {
return false
}
@ -66,10 +62,6 @@ export function isFunctionCallingModel(model?: Model): boolean {
return isUserSelectedModelType(model, 'function_calling')!
}
if (model.provider === 'qiniu') {
return ['deepseek-v3-tool', 'deepseek-v3-0324', 'qwq-32b', 'qwen2.5-72b-instruct'].includes(modelId)
}
if (model.provider === 'doubao' || modelId.includes('doubao')) {
return FUNCTION_CALLING_REGEX.test(modelId) || FUNCTION_CALLING_REGEX.test(model.name)
}

View File

@ -1,43 +1,14 @@
import type OpenAI from '@cherrystudio/openai'
import { isEmbeddingModel, isRerankModel } from '@renderer/config/models/embedding'
import type { Model } from '@renderer/types'
import { type Model, SystemProviderIds } from '@renderer/types'
import type { OpenAIVerbosity, ValidOpenAIVerbosity } from '@renderer/types/aiCoreTypes'
import { getLowerBaseModelName } from '@renderer/utils'
import { WEB_SEARCH_PROMPT_FOR_OPENROUTER } from '@shared/config/prompts'
import { getWebSearchTools } from '../tools'
import { isOpenAIReasoningModel } from './reasoning'
import { isOpenAIChatCompletionOnlyModel, isOpenAIOpenWeightModel, isOpenAIReasoningModel } from './openai'
import { isQwenMTModel } from './qwen'
import { isGenerateImageModel, isTextToImageModel, isVisionModel } from './vision'
import { isOpenAIWebSearchChatCompletionOnlyModel } from './websearch'
export const NOT_SUPPORTED_REGEX = /(?:^tts|whisper|speech)/i
export const OPENAI_NO_SUPPORT_DEV_ROLE_MODELS = ['o1-preview', 'o1-mini']
export function isOpenAILLMModel(model: Model): boolean {
if (!model) {
return false
}
const modelId = getLowerBaseModelName(model.id)
if (modelId.includes('gpt-4o-image')) {
return false
}
if (isOpenAIReasoningModel(model)) {
return true
}
if (modelId.includes('gpt')) {
return true
}
return false
}
export function isOpenAIModel(model: Model): boolean {
if (!model) {
return false
}
const modelId = getLowerBaseModelName(model.id)
return modelId.includes('gpt') || isOpenAIReasoningModel(model)
}
export const GEMINI_FLASH_MODEL_REGEX = new RegExp('gemini.*-flash.*$', 'i')
export function isSupportFlexServiceTierModel(model: Model): boolean {
if (!model) {
@ -52,33 +23,6 @@ export function isSupportedFlexServiceTier(model: Model): boolean {
return isSupportFlexServiceTierModel(model)
}
export function isSupportVerbosityModel(model: Model): boolean {
const modelId = getLowerBaseModelName(model.id)
return (isGPT5SeriesModel(model) || isGPT51SeriesModel(model)) && !modelId.includes('chat')
}
export function isOpenAIChatCompletionOnlyModel(model: Model): boolean {
if (!model) {
return false
}
const modelId = getLowerBaseModelName(model.id)
return (
modelId.includes('gpt-4o-search-preview') ||
modelId.includes('gpt-4o-mini-search-preview') ||
modelId.includes('o1-mini') ||
modelId.includes('o1-preview')
)
}
export function isGrokModel(model?: Model): boolean {
if (!model) {
return false
}
const modelId = getLowerBaseModelName(model.id)
return modelId.includes('grok')
}
export function isSupportedModel(model: OpenAI.Models.Model): boolean {
if (!model) {
return false
@ -105,53 +49,6 @@ export function isNotSupportTemperatureAndTopP(model: Model): boolean {
return false
}
export function getOpenAIWebSearchParams(model: Model, isEnableWebSearch?: boolean): Record<string, any> {
if (!isEnableWebSearch) {
return {}
}
const webSearchTools = getWebSearchTools(model)
if (model.provider === 'grok') {
return {
search_parameters: {
mode: 'auto',
return_citations: true,
sources: [{ type: 'web' }, { type: 'x' }, { type: 'news' }]
}
}
}
if (model.provider === 'hunyuan') {
return { enable_enhancement: true, citation: true, search_info: true }
}
if (model.provider === 'dashscope') {
return {
enable_search: true,
search_options: {
forced_search: true
}
}
}
if (isOpenAIWebSearchChatCompletionOnlyModel(model)) {
return {
web_search_options: {}
}
}
if (model.provider === 'openrouter') {
return {
plugins: [{ id: 'web', search_prompts: WEB_SEARCH_PROMPT_FOR_OPENROUTER }]
}
}
return {
tools: webSearchTools
}
}
export function isGemmaModel(model?: Model): boolean {
if (!model) {
return false
@ -161,12 +58,14 @@ export function isGemmaModel(model?: Model): boolean {
return modelId.includes('gemma-') || model.group === 'Gemma'
}
export function isZhipuModel(model?: Model): boolean {
if (!model) {
return false
}
export function isZhipuModel(model: Model): boolean {
const modelId = getLowerBaseModelName(model.id)
return modelId.includes('glm') || model.provider === SystemProviderIds.zhipu
}
return model.provider === 'zhipu'
export function isMoonshotModel(model: Model): boolean {
const modelId = getLowerBaseModelName(model.id)
return ['moonshot', 'kimi'].some((m) => modelId.includes(m))
}
/**
@ -212,11 +111,6 @@ export const isAnthropicModel = (model?: Model): boolean => {
return modelId.startsWith('claude')
}
export const isQwenMTModel = (model: Model): boolean => {
const modelId = getLowerBaseModelName(model.id)
return modelId.includes('qwen-mt')
}
export const isNotSupportedTextDelta = (model: Model): boolean => {
return isQwenMTModel(model)
}
@ -225,34 +119,22 @@ export const isNotSupportSystemMessageModel = (model: Model): boolean => {
return isQwenMTModel(model) || isGemmaModel(model)
}
export const isGPT5SeriesModel = (model: Model) => {
const modelId = getLowerBaseModelName(model.id)
return modelId.includes('gpt-5') && !modelId.includes('gpt-5.1')
}
export const isGPT5SeriesReasoningModel = (model: Model) => {
const modelId = getLowerBaseModelName(model.id)
return isGPT5SeriesModel(model) && !modelId.includes('chat')
}
export const isGPT51SeriesModel = (model: Model) => {
const modelId = getLowerBaseModelName(model.id)
return modelId.includes('gpt-5.1')
}
// GPT-5 verbosity configuration
// gpt-5-pro only supports 'high', other GPT-5 models support all levels
export const MODEL_SUPPORTED_VERBOSITY: Record<string, ('low' | 'medium' | 'high')[]> = {
export const MODEL_SUPPORTED_VERBOSITY: Record<string, ValidOpenAIVerbosity[]> = {
'gpt-5-pro': ['high'],
default: ['low', 'medium', 'high']
}
} as const
export const getModelSupportedVerbosity = (model: Model): ('low' | 'medium' | 'high')[] => {
export const getModelSupportedVerbosity = (model: Model): OpenAIVerbosity[] => {
const modelId = getLowerBaseModelName(model.id)
let supportedValues: ValidOpenAIVerbosity[]
if (modelId.includes('gpt-5-pro')) {
return MODEL_SUPPORTED_VERBOSITY['gpt-5-pro']
supportedValues = MODEL_SUPPORTED_VERBOSITY['gpt-5-pro']
} else {
supportedValues = MODEL_SUPPORTED_VERBOSITY.default
}
return MODEL_SUPPORTED_VERBOSITY.default
return [undefined, ...supportedValues]
}
export const isGeminiModel = (model: Model) => {
@ -260,11 +142,6 @@ export const isGeminiModel = (model: Model) => {
return modelId.includes('gemini')
}
export const isOpenAIOpenWeightModel = (model: Model) => {
const modelId = getLowerBaseModelName(model.id)
return modelId.includes('gpt-oss')
}
// zhipu 视觉推理模型用这组 special token 标记推理结果
export const ZHIPU_RESULT_TOKENS = ['<|begin_of_box|>', '<|end_of_box|>'] as const
@ -272,7 +149,14 @@ export const agentModelFilter = (model: Model): boolean => {
return !isEmbeddingModel(model) && !isRerankModel(model) && !isTextToImageModel(model)
}
export const isGPT5ProModel = (model: Model) => {
const modelId = getLowerBaseModelName(model.id)
return modelId.includes('gpt-5-pro')
export const isMaxTemperatureOneModel = (model: Model): boolean => {
if (isZhipuModel(model) || isAnthropicModel(model) || isMoonshotModel(model)) {
return true
}
return false
}
export const isGemini3Model = (model: Model) => {
const modelId = getLowerBaseModelName(model.id)
return modelId.includes('gemini-3')
}

View File

@ -3,6 +3,7 @@ import type { Model } from '@renderer/types'
import { getLowerBaseModelName, isUserSelectedModelType } from '@renderer/utils'
import { isEmbeddingModel, isRerankModel } from './embedding'
import { isFunctionCallingModel } from './tooluse'
// Vision models
const visionAllowedModels = [
@ -72,12 +73,10 @@ const VISION_REGEX = new RegExp(
// For middleware to identify models that must use the dedicated Image API
const DEDICATED_IMAGE_MODELS = [
'grok-2-image',
'grok-2-image-1212',
'grok-2-image-latest',
'dall-e-3',
'dall-e-2',
'gpt-image-1'
'grok-2-image(?:-[\\w-]+)?',
'dall-e(?:-[\\w-]+)?',
'gpt-image-1(?:-[\\w-]+)?',
'imagen(?:-[\\w-]+)?'
]
const IMAGE_ENHANCEMENT_MODELS = [
@ -85,13 +84,22 @@ const IMAGE_ENHANCEMENT_MODELS = [
'qwen-image-edit',
'gpt-image-1',
'gemini-2.5-flash-image(?:-[\\w-]+)?',
'gemini-2.0-flash-preview-image-generation'
'gemini-2.0-flash-preview-image-generation',
'gemini-3(?:\\.\\d+)?-pro-image(?:-[\\w-]+)?'
]
const IMAGE_ENHANCEMENT_MODELS_REGEX = new RegExp(IMAGE_ENHANCEMENT_MODELS.join('|'), 'i')
const DEDICATED_IMAGE_MODELS_REGEX = new RegExp(DEDICATED_IMAGE_MODELS.join('|'), 'i')
// Models that should auto-enable image generation button when selected
const AUTO_ENABLE_IMAGE_MODELS = ['gemini-2.5-flash-image', ...DEDICATED_IMAGE_MODELS]
const AUTO_ENABLE_IMAGE_MODELS = [
'gemini-2.5-flash-image(?:-[\\w-]+)?',
'gemini-3(?:\\.\\d+)?-pro-image(?:-[\\w-]+)?',
...DEDICATED_IMAGE_MODELS
]
const AUTO_ENABLE_IMAGE_MODELS_REGEX = new RegExp(AUTO_ENABLE_IMAGE_MODELS.join('|'), 'i')
const OPENAI_TOOL_USE_IMAGE_GENERATION_MODELS = [
'o3',
@ -105,26 +113,34 @@ const OPENAI_TOOL_USE_IMAGE_GENERATION_MODELS = [
const OPENAI_IMAGE_GENERATION_MODELS = [...OPENAI_TOOL_USE_IMAGE_GENERATION_MODELS, 'gpt-image-1']
const MODERN_IMAGE_MODELS = ['gemini-3(?:\\.\\d+)?-pro-image(?:-[\\w-]+)?']
const GENERATE_IMAGE_MODELS = [
'gemini-2.0-flash-exp',
'gemini-2.0-flash-exp-image-generation',
'gemini-2.0-flash-exp(?:-[\\w-]+)?',
'gemini-2.5-flash-image(?:-[\\w-]+)?',
'gemini-2.0-flash-preview-image-generation',
'gemini-2.5-flash-image',
...MODERN_IMAGE_MODELS,
...DEDICATED_IMAGE_MODELS
]
const OPENAI_IMAGE_GENERATION_MODELS_REGEX = new RegExp(OPENAI_IMAGE_GENERATION_MODELS.join('|'), 'i')
const GENERATE_IMAGE_MODELS_REGEX = new RegExp(GENERATE_IMAGE_MODELS.join('|'), 'i')
const MODERN_GENERATE_IMAGE_MODELS_REGEX = new RegExp(MODERN_IMAGE_MODELS.join('|'), 'i')
export const isDedicatedImageGenerationModel = (model: Model): boolean => {
if (!model) return false
const modelId = getLowerBaseModelName(model.id)
return DEDICATED_IMAGE_MODELS.some((m) => modelId.includes(m))
return DEDICATED_IMAGE_MODELS_REGEX.test(modelId)
}
export const isAutoEnableImageGenerationModel = (model: Model): boolean => {
if (!model) return false
const modelId = getLowerBaseModelName(model.id)
return AUTO_ENABLE_IMAGE_MODELS.some((m) => modelId.includes(m))
return AUTO_ENABLE_IMAGE_MODELS_REGEX.test(modelId)
}
/**
@ -146,48 +162,44 @@ export function isGenerateImageModel(model: Model): boolean {
const modelId = getLowerBaseModelName(model.id, '/')
if (provider.type === 'openai-response') {
return (
OPENAI_IMAGE_GENERATION_MODELS.some((imageModel) => modelId.includes(imageModel)) ||
GENERATE_IMAGE_MODELS.some((imageModel) => modelId.includes(imageModel))
)
return OPENAI_IMAGE_GENERATION_MODELS_REGEX.test(modelId) || GENERATE_IMAGE_MODELS_REGEX.test(modelId)
}
return GENERATE_IMAGE_MODELS.some((imageModel) => modelId.includes(imageModel))
return GENERATE_IMAGE_MODELS_REGEX.test(modelId)
}
// TODO: refine the regex
/**
*
* @param model
* @returns
*/
export function isPureGenerateImageModel(model: Model): boolean {
if (!isGenerateImageModel(model) || !isTextToImageModel(model)) {
if (!isGenerateImageModel(model) && !isTextToImageModel(model)) {
return false
}
if (isFunctionCallingModel(model)) {
return false
}
const modelId = getLowerBaseModelName(model.id)
return !OPENAI_TOOL_USE_IMAGE_GENERATION_MODELS.some((imageModel) => modelId.includes(imageModel))
if (GENERATE_IMAGE_MODELS_REGEX.test(modelId) && !MODERN_GENERATE_IMAGE_MODELS_REGEX.test(modelId)) {
return true
}
return !OPENAI_TOOL_USE_IMAGE_GENERATION_MODELS.some((m) => modelId.includes(m))
}
// TODO: refine the regex
// Text to image models
const TEXT_TO_IMAGE_REGEX = /flux|diffusion|stabilityai|sd-|dall|cogview|janus|midjourney|mj-|image|gpt-image/i
const TEXT_TO_IMAGE_REGEX = /flux|diffusion|stabilityai|sd-|dall|cogview|janus|midjourney|mj-|imagen|gpt-image/i
export function isTextToImageModel(model: Model): boolean {
const modelId = getLowerBaseModelName(model.id)
return TEXT_TO_IMAGE_REGEX.test(modelId)
}
// It's not used now
// export function isNotSupportedImageSizeModel(model?: Model): boolean {
// if (!model) {
// return false
// }
// const baseName = getLowerBaseModelName(model.id, '/')
// return baseName.includes('grok-2-image')
// }
/**
*
* @param model

View File

@ -2,27 +2,29 @@ import { getProviderByModel } from '@renderer/services/AssistantService'
import type { Model } from '@renderer/types'
import { SystemProviderIds } from '@renderer/types'
import { getLowerBaseModelName, isUserSelectedModelType } from '@renderer/utils'
import {
isAzureOpenAIProvider,
isGeminiProvider,
isNewApiProvider,
isOpenAICompatibleProvider,
isOpenAIProvider,
isVertexAiProvider
} from '../providers'
import { isEmbeddingModel, isRerankModel } from './embedding'
import { isAnthropicModel } from './utils'
import { isPureGenerateImageModel, isTextToImageModel } from './vision'
isVertexProvider
} from '@renderer/utils/provider'
export const CLAUDE_SUPPORTED_WEBSEARCH_REGEX = new RegExp(
export { GEMINI_FLASH_MODEL_REGEX } from './utils'
import { isEmbeddingModel, isRerankModel } from './embedding'
import { isClaude4SeriesModel } from './reasoning'
import { isAnthropicModel } from './utils'
import { isTextToImageModel } from './vision'
const CLAUDE_SUPPORTED_WEBSEARCH_REGEX = new RegExp(
`\\b(?:claude-3(-|\\.)(7|5)-sonnet(?:-[\\w-]+)|claude-3(-|\\.)5-haiku(?:-[\\w-]+)|claude-(haiku|sonnet|opus)-4(?:-[\\w-]+)?)\\b`,
'i'
)
export const GEMINI_FLASH_MODEL_REGEX = new RegExp('gemini.*-flash.*$')
export const GEMINI_SEARCH_REGEX = new RegExp(
'gemini-(?:2.*(?:-latest)?|3-(?:flash|pro)(?:-preview)?|flash-latest|pro-latest|flash-lite-latest)(?:-[\\w-]+)*$',
'gemini-(?:2(?!.*-image-preview).*(?:-latest)?|3(?:\\.\\d+)?-(?:flash|pro)(?:-(?:image-)?preview)?|flash-latest|pro-latest|flash-lite-latest)(?:-[\\w-]+)*$',
'i'
)
@ -34,30 +36,8 @@ export const PERPLEXITY_SEARCH_MODELS = [
'sonar-deep-research'
]
const OPENAI_DEEP_RESEARCH_MODEL_REGEX = /deep[-_]?research/
export function isOpenAIDeepResearchModel(model?: Model): boolean {
if (!model) {
return false
}
const providerId = model.provider
if (providerId !== 'openai' && providerId !== 'openai-chat') {
return false
}
const modelId = getLowerBaseModelName(model.id, '/')
return OPENAI_DEEP_RESEARCH_MODEL_REGEX.test(modelId)
}
export function isWebSearchModel(model: Model): boolean {
if (
!model ||
isEmbeddingModel(model) ||
isRerankModel(model) ||
isTextToImageModel(model) ||
isPureGenerateImageModel(model)
) {
if (!model || isEmbeddingModel(model) || isRerankModel(model) || isTextToImageModel(model)) {
return false
}
@ -73,16 +53,17 @@ export function isWebSearchModel(model: Model): boolean {
const modelId = getLowerBaseModelName(model.id, '/')
// bedrock和vertex不支持
if (
isAnthropicModel(model) &&
!(provider.id === SystemProviderIds['aws-bedrock'] || provider.id === SystemProviderIds.vertexai)
) {
// bedrock不支持, azure支持
if (isAnthropicModel(model) && !(provider.id === SystemProviderIds['aws-bedrock'])) {
if (isVertexProvider(provider)) {
return isClaude4SeriesModel(model)
}
return CLAUDE_SUPPORTED_WEBSEARCH_REGEX.test(modelId)
}
// TODO: 当其他供应商采用Response端点时这个地方逻辑需要改进
if (isOpenAIProvider(provider)) {
// azure现在也支持了websearch
if (isOpenAIProvider(provider) || isAzureOpenAIProvider(provider)) {
if (isOpenAIWebSearchModel(model)) {
return true
}
@ -113,7 +94,7 @@ export function isWebSearchModel(model: Model): boolean {
}
}
if (isGeminiProvider(provider) || isVertexAiProvider(provider)) {
if (isGeminiProvider(provider) || isVertexProvider(provider)) {
return GEMINI_SEARCH_REGEX.test(modelId)
}

View File

@ -59,15 +59,8 @@ import VoyageAIProviderLogo from '@renderer/assets/images/providers/voyageai.png
import XirangProviderLogo from '@renderer/assets/images/providers/xirang.png'
import ZeroOneProviderLogo from '@renderer/assets/images/providers/zero-one.png'
import ZhipuProviderLogo from '@renderer/assets/images/providers/zhipu.png'
import type {
AtLeast,
AzureOpenAIProvider,
Provider,
ProviderType,
SystemProvider,
SystemProviderId
} from '@renderer/types'
import { isSystemProvider, OpenAIServiceTiers, SystemProviderIds } from '@renderer/types'
import type { AtLeast, SystemProvider, SystemProviderId } from '@renderer/types'
import { OpenAIServiceTiers } from '@renderer/types'
import { TOKENFLUX_HOST } from './constant'
import { glm45FlashModel, qwen38bModel, SYSTEM_MODELS } from './models'
@ -1441,153 +1434,3 @@ export const PROVIDER_URLS: Record<SystemProviderId, ProviderUrls> = {
}
}
}
const NOT_SUPPORT_ARRAY_CONTENT_PROVIDERS = [
'deepseek',
'baichuan',
'minimax',
'xirang',
'poe',
'cephalon'
] as const satisfies SystemProviderId[]
/**
* message content Only for OpenAI Chat Completions API.
*/
export const isSupportArrayContentProvider = (provider: Provider) => {
return (
provider.apiOptions?.isNotSupportArrayContent !== true &&
!NOT_SUPPORT_ARRAY_CONTENT_PROVIDERS.some((pid) => pid === provider.id)
)
}
const NOT_SUPPORT_DEVELOPER_ROLE_PROVIDERS = ['poe', 'qiniu'] as const satisfies SystemProviderId[]
/**
* developer message role Only for OpenAI API.
*/
export const isSupportDeveloperRoleProvider = (provider: Provider) => {
return (
provider.apiOptions?.isSupportDeveloperRole === true ||
(isSystemProvider(provider) && !NOT_SUPPORT_DEVELOPER_ROLE_PROVIDERS.some((pid) => pid === provider.id))
)
}
const NOT_SUPPORT_STREAM_OPTIONS_PROVIDERS = ['mistral'] as const satisfies SystemProviderId[]
/**
* stream_options Only for OpenAI API.
*/
export const isSupportStreamOptionsProvider = (provider: Provider) => {
return (
provider.apiOptions?.isNotSupportStreamOptions !== true &&
!NOT_SUPPORT_STREAM_OPTIONS_PROVIDERS.some((pid) => pid === provider.id)
)
}
const NOT_SUPPORT_QWEN3_ENABLE_THINKING_PROVIDER = [
'ollama',
'lmstudio',
'nvidia'
] as const satisfies SystemProviderId[]
/**
* 使 enable_thinking Qwen3 Only for OpenAI Chat Completions API.
*/
export const isSupportEnableThinkingProvider = (provider: Provider) => {
return (
provider.apiOptions?.isNotSupportEnableThinking !== true &&
!NOT_SUPPORT_QWEN3_ENABLE_THINKING_PROVIDER.some((pid) => pid === provider.id)
)
}
const NOT_SUPPORT_SERVICE_TIER_PROVIDERS = ['github', 'copilot', 'cerebras'] as const satisfies SystemProviderId[]
/**
* service_tier Only for OpenAI API.
*/
export const isSupportServiceTierProvider = (provider: Provider) => {
return (
provider.apiOptions?.isSupportServiceTier === true ||
(isSystemProvider(provider) && !NOT_SUPPORT_SERVICE_TIER_PROVIDERS.some((pid) => pid === provider.id))
)
}
const SUPPORT_URL_CONTEXT_PROVIDER_TYPES = [
'gemini',
'vertexai',
'anthropic',
'new-api'
] as const satisfies ProviderType[]
export const isSupportUrlContextProvider = (provider: Provider) => {
return (
SUPPORT_URL_CONTEXT_PROVIDER_TYPES.some((type) => type === provider.type) ||
provider.id === SystemProviderIds.cherryin
)
}
const SUPPORT_GEMINI_NATIVE_WEB_SEARCH_PROVIDERS = ['gemini', 'vertexai'] as const satisfies SystemProviderId[]
/** 判断是否是使用 Gemini 原生搜索工具的 provider. 目前假设只有官方 API 使用原生工具 */
export const isGeminiWebSearchProvider = (provider: Provider) => {
return SUPPORT_GEMINI_NATIVE_WEB_SEARCH_PROVIDERS.some((id) => id === provider.id)
}
export const isNewApiProvider = (provider: Provider) => {
return ['new-api', 'cherryin'].includes(provider.id) || provider.type === 'new-api'
}
export function isCherryAIProvider(provider: Provider): boolean {
return provider.id === 'cherryai'
}
export function isPerplexityProvider(provider: Provider): boolean {
return provider.id === 'perplexity'
}
/**
* OpenAI
* @param {Provider} provider
* @returns {boolean} OpenAI
*/
export function isOpenAICompatibleProvider(provider: Provider): boolean {
return ['openai', 'new-api', 'mistral'].includes(provider.type)
}
export function isAzureOpenAIProvider(provider: Provider): provider is AzureOpenAIProvider {
return provider.type === 'azure-openai'
}
export function isOpenAIProvider(provider: Provider): boolean {
return provider.type === 'openai-response'
}
export function isAnthropicProvider(provider: Provider): boolean {
return provider.type === 'anthropic'
}
export function isGeminiProvider(provider: Provider): boolean {
return provider.type === 'gemini'
}
export function isVertexAiProvider(provider: Provider): boolean {
return provider.type === 'vertexai'
}
export function isAIGatewayProvider(provider: Provider): boolean {
return provider.type === 'ai-gateway'
}
export function isAwsBedrockProvider(provider: Provider): boolean {
return provider.type === 'aws-bedrock'
}
const NOT_SUPPORT_API_VERSION_PROVIDERS = ['github', 'copilot', 'perplexity'] as const satisfies SystemProviderId[]
export const isSupportAPIVersionProvider = (provider: Provider) => {
if (isSystemProvider(provider)) {
return !NOT_SUPPORT_API_VERSION_PROVIDERS.some((pid) => pid === provider.id)
}
return provider.apiOptions?.isNotSupportAPIVersion !== false
}

View File

@ -1,55 +0,0 @@
import type { ChatCompletionTool } from '@cherrystudio/openai/resources'
import type { Model } from '@renderer/types'
import { WEB_SEARCH_PROMPT_FOR_ZHIPU } from '@shared/config/prompts'
export function getWebSearchTools(model: Model): ChatCompletionTool[] {
if (model?.provider === 'zhipu') {
if (model.id === 'glm-4-alltools') {
return [
{
type: 'web_browser',
web_browser: {
browser: 'auto'
}
} as unknown as ChatCompletionTool
]
}
return [
{
type: 'web_search',
web_search: {
enable: true,
search_result: true,
search_prompt: WEB_SEARCH_PROMPT_FOR_ZHIPU
}
} as unknown as ChatCompletionTool
]
}
if (model?.id.includes('gemini')) {
return [
{
type: 'function',
function: {
name: 'googleSearch'
}
}
]
}
return []
}
export function getUrlContextTools(model: Model): ChatCompletionTool[] {
if (model.id.includes('gemini')) {
return [
{
type: 'function',
function: {
name: 'urlContext'
}
}
]
}
return []
}

View File

@ -175,14 +175,46 @@ export function useAppInit() {
useEffect(() => {
if (!window.electron?.ipcRenderer) return
const requestListener = (_event: Electron.IpcRendererEvent, payload: ToolPermissionRequestPayload) => {
const requestListener = async (_event: Electron.IpcRendererEvent, payload: ToolPermissionRequestPayload) => {
logger.debug('Renderer received tool permission request', {
requestId: payload.requestId,
toolName: payload.toolName,
expiresAt: payload.expiresAt,
suggestionCount: payload.suggestions.length
suggestionCount: payload.suggestions.length,
autoApprove: payload.autoApprove
})
dispatch(toolPermissionsActions.requestReceived(payload))
// Auto-approve if requested
if (payload.autoApprove) {
logger.debug('Auto-approving tool permission request', {
requestId: payload.requestId,
toolName: payload.toolName
})
dispatch(toolPermissionsActions.submissionSent({ requestId: payload.requestId, behavior: 'allow' }))
try {
const response = await window.api.agentTools.respondToPermission({
requestId: payload.requestId,
behavior: 'allow',
updatedInput: payload.input,
updatedPermissions: payload.suggestions
})
if (!response?.success) {
throw new Error('Auto-approval response rejected by main process')
}
logger.debug('Auto-approval acknowledged by main process', {
requestId: payload.requestId,
toolName: payload.toolName
})
} catch (error) {
logger.error('Failed to send auto-approval response', error as Error)
dispatch(toolPermissionsActions.submissionFailed({ requestId: payload.requestId }))
}
}
}
const resultListener = (_event: Electron.IpcRendererEvent, payload: ToolPermissionResultPayload) => {

View File

@ -38,13 +38,6 @@ export function getVertexAIServiceAccount() {
return store.getState().llm.settings.vertexai.serviceAccount
}
/**
* Provider VertexProvider
*/
export function isVertexProvider(provider: Provider): provider is VertexProvider {
return provider.type === 'vertexai'
}
/**
* VertexProvider
* @param baseProvider provider

View File

@ -266,6 +266,7 @@
"error": {
"sendFailed": "Failed to send your decision. Please try again."
},
"executing": "Executing...",
"expired": "Expired",
"inputPreview": "Tool input preview",
"pending": "Pending ({{seconds}}s)",
@ -1158,6 +1159,7 @@
"name": "Name",
"no_results": "No results",
"none": "None",
"off": "Off",
"open": "Open",
"paste": "Paste",
"placeholders": {
@ -4259,7 +4261,6 @@
"default": "default",
"flex": "flex",
"on_demand": "on demand",
"performance": "performance",
"priority": "priority",
"tip": "Specifies the latency tier to use for processing the request",
"title": "Service Tier"
@ -4278,7 +4279,7 @@
"low": "Low",
"medium": "Medium",
"tip": "Control the level of detail in the model's output",
"title": "Level of detail"
"title": "Verbosity"
}
},
"privacy": {

View File

@ -266,6 +266,7 @@
"error": {
"sendFailed": "发送您的决定失败,请重试。"
},
"executing": "正在执行...",
"expired": "已过期",
"inputPreview": "工具输入预览",
"pending": "等待中 ({{seconds}}秒)",
@ -1158,6 +1159,7 @@
"name": "名称",
"no_results": "无结果",
"none": "无",
"off": "关闭",
"open": "打开",
"paste": "粘贴",
"placeholders": {
@ -4259,7 +4261,6 @@
"default": "默认",
"flex": "灵活",
"on_demand": "按需",
"performance": "性能",
"priority": "优先",
"tip": "指定用于处理请求的延迟层级",
"title": "服务层级"

View File

@ -37,7 +37,7 @@
"success": "成功偵測到 Git Bash"
},
"input": {
"placeholder": "[to be translated]:Enter your message here, send with {{key}} - @ select path, / select command"
"placeholder": "在這裡輸入您的訊息,使用 {{key}} 傳送 - @ 選擇路徑,/ 選擇命令"
},
"list": {
"error": {
@ -266,6 +266,7 @@
"error": {
"sendFailed": "傳送您的決定失敗,請重試。"
},
"executing": "執行中...",
"expired": "已過期",
"inputPreview": "工具輸入預覽",
"pending": "等待中 ({{seconds}}秒)",
@ -1158,6 +1159,7 @@
"name": "名稱",
"no_results": "沒有結果",
"none": "無",
"off": "關閉",
"open": "開啟",
"paste": "貼上",
"placeholders": {
@ -4259,7 +4261,6 @@
"default": "預設",
"flex": "彈性",
"on_demand": "按需",
"performance": "效能",
"priority": "優先",
"tip": "指定用於處理請求的延遲層級",
"title": "服務層級"

View File

@ -29,15 +29,15 @@
},
"gitBash": {
"error": {
"description": "[to be translated]:Git Bash is required to run agents on Windows. The agent cannot function without it. Please install Git for Windows from",
"recheck": "[to be translated]:Recheck Git Bash Installation",
"title": "[to be translated]:Git Bash Required"
"description": "Git Bash ist erforderlich, um Agents unter Windows auszuführen. Der Agent kann ohne es nicht funktionieren. Bitte installieren Sie Git für Windows von",
"recheck": "Überprüfe die Git Bash-Installation erneut",
"title": "Git Bash erforderlich"
},
"notFound": "[to be translated]:Git Bash not found. Please install it first.",
"success": "[to be translated]:Git Bash detected successfully!"
"notFound": "Git Bash nicht gefunden. Bitte installieren Sie es zuerst.",
"success": "Git Bash erfolgreich erkannt!"
},
"input": {
"placeholder": "[to be translated]:Enter your message here, send with {{key}} - @ select path, / select command"
"placeholder": "Gib hier deine Nachricht ein, senden mit {{key}} @ Pfad auswählen, / Befehl auswählen"
},
"list": {
"error": {
@ -266,6 +266,7 @@
"error": {
"sendFailed": "Ihre Entscheidung konnte nicht gesendet werden. Bitte versuchen Sie es erneut."
},
"executing": "Ausführen...",
"expired": "Abgelaufen",
"inputPreview": "Vorschau der Werkzeugeingabe",
"pending": "Ausstehend ({{seconds}}s)",
@ -1158,6 +1159,7 @@
"name": "Name",
"no_results": "Keine Ergebnisse",
"none": "Keine",
"off": "Aus",
"open": "Öffnen",
"paste": "Einfügen",
"placeholders": {
@ -1385,6 +1387,36 @@
"preview": "Vorschau",
"split": "Geteilte Ansicht"
},
"import": {
"chatgpt": {
"assistant_name": "ChatGPT-Import",
"button": "Datei auswählen",
"description": "Importiert nur Gesprächstexte, keine Bilder und Anhänge",
"error": {
"invalid_json": "Ungültiges JSON-Dateiformat",
"no_conversations": "Keine Gespräche in der Datei gefunden",
"no_valid_conversations": "Keine gültigen Konversationen zum Importieren",
"unknown": "Import fehlgeschlagen, bitte überprüfen Sie das Dateiformat"
},
"help": {
"step1": "1. Melden Sie sich bei ChatGPT an, gehen Sie zu Einstellungen > Dateneinstellungen > Daten exportieren",
"step2": "2. Warten Sie auf die Exportdatei per E-Mail",
"step3": "3. Extrahiere die heruntergeladene Datei und finde conversations.json",
"title": "Wie exportiere ich ChatGPT-Gespräche?"
},
"importing": "Konversationen werden importiert...",
"selecting": "Datei wird ausgewählt...",
"success": "Erfolgreich {{topics}} Konversationen mit {{messages}} Nachrichten importiert",
"title": "ChatGPT-Unterhaltungen importieren",
"untitled_conversation": "Unbenannte Unterhaltung"
},
"confirm": {
"button": "Importdatei auswählen",
"label": "Sind Sie sicher, dass Sie externe Daten importieren möchten?"
},
"content": "Wählen Sie die zu importierende Gesprächsdatei einer externen Anwendung aus; derzeit werden nur ChatGPT-JSON-Formatdateien unterstützt.",
"title": "Externe Gespräche importieren"
},
"knowledge": {
"add": {
"title": "Wissensdatenbank hinzufügen"
@ -3094,6 +3126,7 @@
"basic": "Grundlegende Dateneinstellungen",
"cloud_storage": "Cloud-Backup-Einstellungen",
"export_settings": "Export-Einstellungen",
"import_settings": "Importeinstellungen",
"third_party": "Drittanbieter-Verbindungen"
},
"export_menu": {
@ -3152,6 +3185,11 @@
},
"hour_interval_one": "{{count}} Stunde",
"hour_interval_other": "{{count}} Stunden",
"import_settings": {
"button": "JSON-Datei importieren",
"chatgpt": "Import aus ChatGPT",
"title": "Importiere Daten von externen Anwendungen"
},
"joplin": {
"check": {
"button": "Erkennen",
@ -4223,7 +4261,6 @@
"default": "Standard",
"flex": "Flexibel",
"on_demand": "Auf Anfrage",
"performance": "Leistung",
"priority": "Priorität",
"tip": "Latenz-Ebene für Anfrageverarbeitung festlegen",
"title": "Service-Tier"

View File

@ -29,15 +29,15 @@
},
"gitBash": {
"error": {
"description": "[to be translated]:Git Bash is required to run agents on Windows. The agent cannot function without it. Please install Git for Windows from",
"recheck": "[to be translated]:Recheck Git Bash Installation",
"title": "[to be translated]:Git Bash Required"
"description": "Το Git Bash απαιτείται για την εκτέλεση πρακτόρων στα Windows. Ο πράκτορας δεν μπορεί να λειτουργήσει χωρίς αυτό. Παρακαλούμε εγκαταστήστε το Git για Windows από",
"recheck": "Επανέλεγχος Εγκατάστασης του Git Bash",
"title": "Απαιτείται Git Bash"
},
"notFound": "[to be translated]:Git Bash not found. Please install it first.",
"success": "[to be translated]:Git Bash detected successfully!"
"notFound": "Το Git Bash δεν βρέθηκε. Παρακαλώ εγκαταστήστε το πρώτα.",
"success": "Το Git Bash εντοπίστηκε με επιτυχία!"
},
"input": {
"placeholder": "[to be translated]:Enter your message here, send with {{key}} - @ select path, / select command"
"placeholder": "Εισάγετε το μήνυμά σας εδώ, στείλτε με {{key}} - @ επιλέξτε διαδρομή, / επιλέξτε εντολή"
},
"list": {
"error": {
@ -266,6 +266,7 @@
"error": {
"sendFailed": "Αποτυχία αποστολής της απόφασής σας. Προσπαθήστε ξανά."
},
"executing": "Εκτέλεση...",
"expired": "Ληγμένο",
"inputPreview": "Προεπισκόπηση εισόδου εργαλείου",
"pending": "Εκκρεμεί ({{seconds}}δ)",
@ -1158,6 +1159,7 @@
"name": "Όνομα",
"no_results": "Δεν βρέθηκαν αποτελέσματα",
"none": "Χωρίς",
"off": "Κλειστό",
"open": "Άνοιγμα",
"paste": "Επικόλληση",
"placeholders": {
@ -1385,6 +1387,36 @@
"preview": "Προεπισκόπηση",
"split": "Διαχωρισμός"
},
"import": {
"chatgpt": {
"assistant_name": "Εισαγωγή ChatGPT",
"button": "Επιλέξτε Αρχείο",
"description": "Εισάγει μόνο κείμενο συνομιλίας, δεν περιλαμβάνει εικόνες και συνημμένα",
"error": {
"invalid_json": "Μη έγκυρη μορφή αρχείου JSON",
"no_conversations": "Δεν βρέθηκαν συνομιλίες στο αρχείο",
"no_valid_conversations": "Δεν υπάρχουν έγκυρες συνομιλίες προς εισαγωγή",
"unknown": "Η εισαγωγή απέτυχε, παρακαλώ ελέγξτε τη μορφή του αρχείου"
},
"help": {
"step1": "1. Συνδεθείτε στο ChatGPT, πηγαίνετε στις Ρυθμίσεις > Έλεγχοι δεδομένων > Εξαγωγή δεδομένων",
"step2": "2. Περιμένετε το αρχείο εξαγωγής μέσω email",
"step3": "3. Εξαγάγετε το ληφθέν αρχείο και βρείτε το conversations.json",
"title": "Πώς να εξάγετε συνομιλίες του ChatGPT;"
},
"importing": "Εισαγωγή συνομιλιών...",
"selecting": "Επιλογή αρχείου...",
"success": "Επιτυχής εισαγωγή {{topics}} συνομιλιών με {{messages}} μηνύματα",
"title": "Εισαγωγή Συνομιλιών του ChatGPT",
"untitled_conversation": "Συνομιλία χωρίς τίτλο"
},
"confirm": {
"button": "Επιλέξτε Εισαγωγή Αρχείου",
"label": "Είστε σίγουροι ότι θέλετε να εισάγετε εξωτερικά δεδομένα;"
},
"content": "Επιλέξτε εξωτερικό αρχείο συνομιλίας εφαρμογής για εισαγωγή, προς το παρόν υποστηρίζονται μόνο αρχεία μορφής JSON του ChatGPT",
"title": "Εισαγωγή Εξωτερικών Συνομιλιών"
},
"knowledge": {
"add": {
"title": "Προσθήκη βιβλιοθήκης γνώσεων"
@ -3094,6 +3126,7 @@
"basic": "Ρυθμίσεις βασικών δεδομένων",
"cloud_storage": "Ρυθμίσεις αποθήκευσης στο νέφος",
"export_settings": "Ρυθμίσεις εξαγωγής",
"import_settings": "Εισαγωγή Ρυθμίσεων",
"third_party": "Σύνδεση τρίτων"
},
"export_menu": {
@ -3152,6 +3185,11 @@
},
"hour_interval_one": "{{count}} ώρα",
"hour_interval_other": "{{count}} ώρες",
"import_settings": {
"button": "Εισαγωγή αρχείου Json",
"chatgpt": "Εισαγωγή από το ChatGPT",
"title": "Εισαγωγή Δεδομένων Εξωτερικής Εφαρμογής"
},
"joplin": {
"check": {
"button": "Έλεγχος",
@ -4223,7 +4261,6 @@
"default": "Προεπιλογή",
"flex": "Εύκαμπτο",
"on_demand": "κατά παραγγελία",
"performance": "Απόδοση",
"priority": "προτεραιότητα",
"tip": "Καθορίστε το επίπεδο καθυστέρησης που χρησιμοποιείται για την επεξεργασία των αιτημάτων",
"title": "Επίπεδο υπηρεσίας"

View File

@ -29,15 +29,15 @@
},
"gitBash": {
"error": {
"description": "[to be translated]:Git Bash is required to run agents on Windows. The agent cannot function without it. Please install Git for Windows from",
"recheck": "[to be translated]:Recheck Git Bash Installation",
"title": "[to be translated]:Git Bash Required"
"description": "Se requiere Git Bash para ejecutar agentes en Windows. El agente no puede funcionar sin él. Instale Git para Windows desde",
"recheck": "Volver a verificar la instalación de Git Bash",
"title": "Git Bash Requerido"
},
"notFound": "[to be translated]:Git Bash not found. Please install it first.",
"success": "[to be translated]:Git Bash detected successfully!"
"notFound": "Git Bash no encontrado. Por favor, instálalo primero.",
"success": "¡Git Bash detectado con éxito!"
},
"input": {
"placeholder": "[to be translated]:Enter your message here, send with {{key}} - @ select path, / select command"
"placeholder": "Introduce tu mensaje aquí, envía con {{key}} - @ seleccionar ruta, / seleccionar comando"
},
"list": {
"error": {
@ -266,6 +266,7 @@
"error": {
"sendFailed": "No se pudo enviar tu decisión. Por favor, inténtalo de nuevo."
},
"executing": "Ejecutando...",
"expired": "Caducado",
"inputPreview": "Vista previa de entrada de herramienta",
"pending": "Pendiente ({{seconds}}s)",
@ -1158,6 +1159,7 @@
"name": "Nombre",
"no_results": "Sin resultados",
"none": "无",
"off": "Apagado",
"open": "Abrir",
"paste": "Pegar",
"placeholders": {
@ -1385,6 +1387,36 @@
"preview": "Vista previa",
"split": "Dividir"
},
"import": {
"chatgpt": {
"assistant_name": "Importación de ChatGPT",
"button": "Seleccionar archivo",
"description": "Solo importa el texto de la conversación, no incluye imágenes ni archivos adjuntos",
"error": {
"invalid_json": "Formato de archivo JSON inválido",
"no_conversations": "No se encontraron conversaciones en el archivo",
"no_valid_conversations": "No hay conversaciones válidas para importar",
"unknown": "Error de importación, por favor verifica el formato del archivo"
},
"help": {
"step1": "1. Inicia sesión en ChatGPT, ve a Configuración > Controles de datos > Exportar datos",
"step2": "2. Espera el archivo de exportación por correo electrónico",
"step3": "3. Extrae el archivo descargado y busca conversations.json",
"title": "Cómo exportar conversaciones de ChatGPT"
},
"importing": "Importando conversaciones...",
"selecting": "Seleccionando archivo...",
"success": "Importadas con éxito {{topics}} conversaciones con {{messages}} mensajes",
"title": "Importar conversaciones de ChatGPT",
"untitled_conversation": "Conversación Sin Título"
},
"confirm": {
"button": "Seleccionar Archivo de Importación",
"label": "¿Estás seguro de que quieres importar datos externos?"
},
"content": "Selecciona el archivo de conversación de la aplicación externa para importar; actualmente solo admite archivos en formato JSON de ChatGPT",
"title": "Importar Conversaciones Externas"
},
"knowledge": {
"add": {
"title": "Agregar base de conocimientos"
@ -3094,6 +3126,7 @@
"basic": "Configuración básica",
"cloud_storage": "Configuración de almacenamiento en la nube",
"export_settings": "Configuración de exportación",
"import_settings": "Importar configuración",
"third_party": "Conexiones de terceros"
},
"export_menu": {
@ -3152,6 +3185,11 @@
},
"hour_interval_one": "{{count}} hora",
"hour_interval_other": "{{count}} horas",
"import_settings": {
"button": "Importar archivo Json",
"chatgpt": "Importar desde ChatGPT",
"title": "Importar datos de aplicaciones externas"
},
"joplin": {
"check": {
"button": "Revisar",
@ -4223,7 +4261,6 @@
"default": "Predeterminado",
"flex": "Flexible",
"on_demand": "según demanda",
"performance": "rendimiento",
"priority": "prioridad",
"tip": "Especifica el nivel de latencia utilizado para procesar la solicitud",
"title": "Nivel de servicio"

View File

@ -29,15 +29,15 @@
},
"gitBash": {
"error": {
"description": "[to be translated]:Git Bash is required to run agents on Windows. The agent cannot function without it. Please install Git for Windows from",
"recheck": "[to be translated]:Recheck Git Bash Installation",
"title": "[to be translated]:Git Bash Required"
"description": "Git Bash est requis pour exécuter des agents sur Windows. L'agent ne peut pas fonctionner sans. Veuillez installer Git pour Windows depuis",
"recheck": "Revérifier l'installation de Git Bash",
"title": "Git Bash requis"
},
"notFound": "[to be translated]:Git Bash not found. Please install it first.",
"success": "[to be translated]:Git Bash detected successfully!"
"notFound": "Git Bash introuvable. Veuillez linstaller dabord.",
"success": "Git Bash détecté avec succès !"
},
"input": {
"placeholder": "[to be translated]:Enter your message here, send with {{key}} - @ select path, / select command"
"placeholder": "Entrez votre message ici, envoyez avec {{key}} - @ sélectionner le chemin, / sélectionner la commande"
},
"list": {
"error": {
@ -266,6 +266,7 @@
"error": {
"sendFailed": "Échec de l'envoi de votre décision. Veuillez réessayer."
},
"executing": "Exécution en cours...",
"expired": "Expiré",
"inputPreview": "Aperçu de l'entrée de l'outil",
"pending": "En attente ({{seconds}}s)",
@ -1158,6 +1159,7 @@
"name": "Nom",
"no_results": "Aucun résultat",
"none": "Aucun",
"off": "Désactivé",
"open": "Ouvrir",
"paste": "Coller",
"placeholders": {
@ -1385,6 +1387,36 @@
"preview": "Aperçu",
"split": "Diviser"
},
"import": {
"chatgpt": {
"assistant_name": "Importation de ChatGPT",
"button": "Sélectionner le fichier",
"description": "Importe uniquement le texte de la conversation, n'inclut pas les images et les pièces jointes",
"error": {
"invalid_json": "Format de fichier JSON invalide",
"no_conversations": "Aucune conversation trouvée dans le fichier",
"no_valid_conversations": "Aucune conversation valide à importer",
"unknown": "L'importation a échoué, veuillez vérifier le format du fichier"
},
"help": {
"step1": "1. Connectez-vous à ChatGPT, allez dans Paramètres > Contrôles des données > Exporter les données",
"step2": "2. Attendez le fichier dexportation par e-mail",
"step3": "3. Extrayez le fichier téléchargé et recherchez conversations.json",
"title": "Comment exporter les conversations de ChatGPT ?"
},
"importing": "Importation des conversations...",
"selecting": "Sélection du fichier...",
"success": "Importation réussie de {{topics}} conversations avec {{messages}} messages",
"title": "Importer les conversations de ChatGPT",
"untitled_conversation": "Conversation sans titre"
},
"confirm": {
"button": "Sélectionner le fichier à importer",
"label": "Êtes-vous sûr de vouloir importer des données externes ?"
},
"content": "Sélectionnez le fichier de conversation de l'application externe à importer, actuellement uniquement les fichiers au format JSON de ChatGPT sont pris en charge",
"title": "Importer des conversations externes"
},
"knowledge": {
"add": {
"title": "Ajouter une base de connaissances"
@ -3094,6 +3126,7 @@
"basic": "Paramètres de base",
"cloud_storage": "Paramètres de sauvegarde cloud",
"export_settings": "Paramètres d'exportation",
"import_settings": "Importer les paramètres",
"third_party": "Connexion tierce"
},
"export_menu": {
@ -3152,6 +3185,11 @@
},
"hour_interval_one": "{{count}} heure",
"hour_interval_other": "{{count}} heures",
"import_settings": {
"button": "Importer le fichier JSON",
"chatgpt": "Importer depuis ChatGPT",
"title": "Importer des données d'applications externes"
},
"joplin": {
"check": {
"button": "Vérifier",
@ -4223,7 +4261,6 @@
"default": "Par défaut",
"flex": "Flexible",
"on_demand": "à la demande",
"performance": "performance",
"priority": "priorité",
"tip": "Spécifie le niveau de latence utilisé pour traiter la demande",
"title": "Niveau de service"

View File

@ -29,15 +29,15 @@
},
"gitBash": {
"error": {
"description": "[to be translated]:Git Bash is required to run agents on Windows. The agent cannot function without it. Please install Git for Windows from",
"recheck": "[to be translated]:Recheck Git Bash Installation",
"title": "[to be translated]:Git Bash Required"
"description": "Windowsでエージェントを実行するにはGit Bashが必要です。これがないとエージェントは動作しません。以下からGit for Windowsをインストールしてください。",
"recheck": "Git Bashのインストールを再確認してください",
"title": "Git Bashが必要です"
},
"notFound": "[to be translated]:Git Bash not found. Please install it first.",
"success": "[to be translated]:Git Bash detected successfully!"
"notFound": "Git Bash が見つかりません。先にインストールしてください。",
"success": "Git Bashが正常に検出されました"
},
"input": {
"placeholder": "[to be translated]:Enter your message here, send with {{key}} - @ select path, / select command"
"placeholder": "メッセージをここに入力し、{{key}}で送信 - @でパスを選択、/でコマンドを選択"
},
"list": {
"error": {
@ -266,6 +266,7 @@
"error": {
"sendFailed": "決定の送信に失敗しました。もう一度お試しください。"
},
"executing": "実行中...",
"expired": "期限切れ",
"inputPreview": "ツール入力プレビュー",
"pending": "保留中({{seconds}}秒)",
@ -1158,6 +1159,7 @@
"name": "名前",
"no_results": "検索結果なし",
"none": "無",
"off": "オフ",
"open": "開く",
"paste": "貼り付け",
"placeholders": {
@ -1385,6 +1387,36 @@
"preview": "プレビュー",
"split": "分割"
},
"import": {
"chatgpt": {
"assistant_name": "ChatGPTインポート",
"button": "ファイルを選択",
"description": "会話のテキストのみをインポートし、画像や添付ファイルは含まれません",
"error": {
"invalid_json": "無効なJSONファイル形式",
"no_conversations": "ファイルに会話が見つかりません",
"no_valid_conversations": "インポートする有効な会話がありません",
"unknown": "インポートに失敗しました。ファイル形式を確認してください。"
},
"help": {
"step1": "1. ChatGPTにログインし、設定 > データ管理 > データをエクスポートへ進みます",
"step2": "2. エクスポートファイルが届くまでメールでお待ちください",
"step3": "3. ダウンロードしたファイルを展開し、conversations.jsonを探してください",
"title": "ChatGPTの会話をエクスポートする方法は"
},
"importing": "会話をインポートしています...",
"selecting": "ファイルを選択中...",
"success": "{{topics}}件の会話と{{messages}}件のメッセージを正常にインポートしました",
"title": "ChatGPTの会話をインポート",
"untitled_conversation": "無題の会話"
},
"confirm": {
"button": "ファイルのインポートを選択",
"label": "外部データをインポートしてもよろしいですか?"
},
"content": "外部アプリケーションの会話ファイルを選択してインポートします。現在、ChatGPT JSON形式ファイルのみサポートしています。",
"title": "外部会話をインポート"
},
"knowledge": {
"add": {
"title": "ナレッジベースを追加"
@ -3094,6 +3126,7 @@
"basic": "基本データ設定",
"cloud_storage": "クラウドバックアップ設定",
"export_settings": "エクスポート設定",
"import_settings": "設定のインポート",
"third_party": "サードパーティー連携"
},
"export_menu": {
@ -3152,6 +3185,11 @@
},
"hour_interval_one": "{{count}} 時間",
"hour_interval_other": "{{count}} 時間",
"import_settings": {
"button": "JSONファイルをインポート",
"chatgpt": "ChatGPTからインポート",
"title": "外部アプリケーションデータをインポート"
},
"joplin": {
"check": {
"button": "確認",
@ -4223,7 +4261,6 @@
"default": "デフォルト",
"flex": "フレックス",
"on_demand": "オンデマンド",
"performance": "性能",
"priority": "優先",
"tip": "リクエスト処理に使用するレイテンシティアを指定します",
"title": "サービスティア"

View File

@ -29,15 +29,15 @@
},
"gitBash": {
"error": {
"description": "[to be translated]:Git Bash is required to run agents on Windows. The agent cannot function without it. Please install Git for Windows from",
"recheck": "[to be translated]:Recheck Git Bash Installation",
"title": "[to be translated]:Git Bash Required"
"description": "O Git Bash é necessário para executar agentes no Windows. O agente não pode funcionar sem ele. Por favor, instale o Git para Windows a partir de",
"recheck": "Reverificar a Instalação do Git Bash",
"title": "Git Bash Necessário"
},
"notFound": "[to be translated]:Git Bash not found. Please install it first.",
"success": "[to be translated]:Git Bash detected successfully!"
"notFound": "Git Bash não encontrado. Por favor, instale-o primeiro.",
"success": "Git Bash detectado com sucesso!"
},
"input": {
"placeholder": "[to be translated]:Enter your message here, send with {{key}} - @ select path, / select command"
"placeholder": "Digite sua mensagem aqui, envie com {{key}} - @ selecionar caminho, / selecionar comando"
},
"list": {
"error": {
@ -266,6 +266,7 @@
"error": {
"sendFailed": "Falha ao enviar sua decisão. Por favor, tente novamente."
},
"executing": "Executando...",
"expired": "Expirado",
"inputPreview": "Pré-visualização da entrada da ferramenta",
"pending": "Pendente ({{seconds}}s)",
@ -1158,6 +1159,7 @@
"name": "Nome",
"no_results": "Nenhum resultado",
"none": "Nenhum",
"off": "Desligado",
"open": "Abrir",
"paste": "Colar",
"placeholders": {
@ -1385,6 +1387,36 @@
"preview": "Visualizar",
"split": "Dividir"
},
"import": {
"chatgpt": {
"assistant_name": "Importação do ChatGPT",
"button": "Selecionar Arquivo",
"description": "Importa apenas o texto da conversa, não inclui imagens e anexos",
"error": {
"invalid_json": "Formato de arquivo JSON inválido",
"no_conversations": "Nenhuma conversa encontrada no arquivo",
"no_valid_conversations": "Nenhuma conversa válida para importar",
"unknown": "Falha na importação, verifique o formato do arquivo"
},
"help": {
"step1": "1. Faça login no ChatGPT, vá para Configurações > Controles de dados > Exportar dados",
"step2": "2. Aguarde o arquivo de exportação por e-mail",
"step3": "3. Extraia o arquivo baixado e localize conversations.json",
"title": "Como exportar conversas do ChatGPT?"
},
"importing": "Importando conversas...",
"selecting": "Selecionando arquivo...",
"success": "Importadas com sucesso {{topics}} conversas com {{messages}} mensagens",
"title": "Importar Conversas do ChatGPT",
"untitled_conversation": "Conversa Sem Título"
},
"confirm": {
"button": "Selecionar Arquivo de Importação",
"label": "Tem certeza de que deseja importar dados externos?"
},
"content": "Selecione o arquivo de conversa do aplicativo externo para importar; atualmente, apenas arquivos no formato JSON do ChatGPT são suportados.",
"title": "Importar Conversas Externas"
},
"knowledge": {
"add": {
"title": "Adicionar Base de Conhecimento"
@ -3094,6 +3126,7 @@
"basic": "Configurações Básicas",
"cloud_storage": "Configurações de Armazenamento em Nuvem",
"export_settings": "Configurações de Exportação",
"import_settings": "Importar Configurações",
"third_party": "Conexões de Terceiros"
},
"export_menu": {
@ -3152,6 +3185,11 @@
},
"hour_interval_one": "{{count}} hora",
"hour_interval_other": "{{count}} horas",
"import_settings": {
"button": "Importar Arquivo Json",
"chatgpt": "Importar do ChatGPT",
"title": "Importar Dados de Aplicações Externas"
},
"joplin": {
"check": {
"button": "Verificar",
@ -4223,7 +4261,6 @@
"default": "Padrão",
"flex": "Flexível",
"on_demand": "sob demanda",
"performance": "desempenho",
"priority": "prioridade",
"tip": "Especifique o nível de latência usado para processar a solicitação",
"title": "Nível de Serviço"

View File

@ -29,15 +29,15 @@
},
"gitBash": {
"error": {
"description": "[to be translated]:Git Bash is required to run agents on Windows. The agent cannot function without it. Please install Git for Windows from",
"recheck": "[to be translated]:Recheck Git Bash Installation",
"title": "[to be translated]:Git Bash Required"
"description": "Для запуска агентов в Windows требуется Git Bash. Без него агент не может работать. Пожалуйста, установите Git для Windows с",
"recheck": "Повторная проверка установки Git Bash",
"title": "Требуется Git Bash"
},
"notFound": "[to be translated]:Git Bash not found. Please install it first.",
"success": "[to be translated]:Git Bash detected successfully!"
"notFound": "Git Bash не найден. Пожалуйста, сначала установите его.",
"success": "Git Bash успешно обнаружен!"
},
"input": {
"placeholder": "[to be translated]:Enter your message here, send with {{key}} - @ select path, / select command"
"placeholder": "Введите ваше сообщение здесь, отправьте с помощью {{key}} — @ выбрать путь, / выбрать команду"
},
"list": {
"error": {
@ -266,6 +266,7 @@
"error": {
"sendFailed": "Не удалось отправить ваше решение. Попробуйте ещё раз."
},
"executing": "Выполнение...",
"expired": "Истёк",
"inputPreview": "Предварительный просмотр ввода инструмента",
"pending": "Ожидание ({{seconds}}с)",
@ -1158,6 +1159,7 @@
"name": "Имя",
"no_results": "Результатов не найдено",
"none": "без",
"off": "Выкл",
"open": "Открыть",
"paste": "Вставить",
"placeholders": {
@ -1385,6 +1387,36 @@
"preview": "Предпросмотр",
"split": "Разделить"
},
"import": {
"chatgpt": {
"assistant_name": "Импорт ChatGPT",
"button": "Выбрать файл",
"description": "Импортирует только текст переписки, не включает изображения и вложения",
"error": {
"invalid_json": "Неверный формат файла JSON",
"no_conversations": "В файле не найдено ни одной беседы",
"no_valid_conversations": "Нет допустимых бесед для импорта",
"unknown": "Импорт не удался, проверьте формат файла"
},
"help": {
"step1": "1. Войдите в ChatGPT, перейдите в Настройки > Управление данными > Экспортировать данные",
"step2": "2. Дождитесь письма с файлом экспорта",
"step3": "3. Распакуйте загруженный файл и найдите conversations.json",
"title": "Как экспортировать диалоги с ChatGPT?"
},
"importing": "Импортирование разговоров...",
"selecting": "Выбор файла...",
"success": "Успешно импортировано {{topics}} бесед с {{messages}} сообщениями",
"title": "Импортировать диалоги ChatGPT",
"untitled_conversation": "Безымянный разговор"
},
"confirm": {
"button": "Выберите файл для импорта",
"label": "Вы уверены, что хотите импортировать внешние данные?"
},
"content": "Выберите внешний файл с перепиской для импорта; в настоящее время поддерживаются только файлы в формате JSON ChatGPT",
"title": "Импорт внешних бесед"
},
"knowledge": {
"add": {
"title": "Добавить базу знаний"
@ -3094,6 +3126,7 @@
"basic": "Основные настройки данных",
"cloud_storage": "Настройки облачного резервирования",
"export_settings": "Настройки экспорта",
"import_settings": "Импорт настроек",
"third_party": "Сторонние подключения"
},
"export_menu": {
@ -3152,6 +3185,11 @@
},
"hour_interval_one": "{{count}} час",
"hour_interval_other": "{{count}} часов",
"import_settings": {
"button": "Импортировать файл JSON",
"chatgpt": "Импорт из ChatGPT",
"title": "Импорт внешних данных приложения"
},
"joplin": {
"check": {
"button": "Проверить",
@ -4223,7 +4261,6 @@
"default": "По умолчанию",
"flex": "Гибкий",
"on_demand": "по требованию",
"performance": "производительность",
"priority": "приоритет",
"tip": "Указывает уровень задержки, который следует использовать для обработки запроса",
"title": "Уровень сервиса"

View File

@ -3,7 +3,6 @@ import { ActionIconButton } from '@renderer/components/Buttons'
import type { QuickPanelListItem } from '@renderer/components/QuickPanel'
import { QuickPanelReservedSymbol, useQuickPanel } from '@renderer/components/QuickPanel'
import { isGeminiModel } from '@renderer/config/models'
import { isGeminiWebSearchProvider, isSupportUrlContextProvider } from '@renderer/config/providers'
import { useAssistant } from '@renderer/hooks/useAssistant'
import { useMCPServers } from '@renderer/hooks/useMCPServers'
import { useTimer } from '@renderer/hooks/useTimer'
@ -12,6 +11,7 @@ import { getProviderByModel } from '@renderer/services/AssistantService'
import { EventEmitter } from '@renderer/services/EventService'
import type { MCPPrompt, MCPResource, MCPServer } from '@renderer/types'
import { isToolUseModeFunction } from '@renderer/utils/assistant'
import { isGeminiWebSearchProvider, isSupportUrlContextProvider } from '@renderer/utils/provider'
import { Form, Input } from 'antd'
import { CircleX, Hammer, Plus } from 'lucide-react'
import type { FC } from 'react'

View File

@ -9,7 +9,6 @@ import {
isOpenAIWebSearchModel,
isWebSearchModel
} from '@renderer/config/models'
import { isGeminiWebSearchProvider } from '@renderer/config/providers'
import { useAssistant } from '@renderer/hooks/useAssistant'
import { useTimer } from '@renderer/hooks/useTimer'
import { useWebSearchProviders } from '@renderer/hooks/useWebSearchProviders'
@ -19,6 +18,7 @@ import WebSearchService from '@renderer/services/WebSearchService'
import type { WebSearchProvider, WebSearchProviderId } from '@renderer/types'
import { hasObjectKey } from '@renderer/utils'
import { isToolUseModeFunction } from '@renderer/utils/assistant'
import { isGeminiWebSearchProvider } from '@renderer/utils/provider'
import { Globe } from 'lucide-react'
import { useCallback, useEffect, useMemo } from 'react'
import { useTranslation } from 'react-i18next'

View File

@ -1,7 +1,7 @@
import { isAnthropicModel, isGeminiModel } from '@renderer/config/models'
import { isSupportUrlContextProvider } from '@renderer/config/providers'
import { isAnthropicModel, isGeminiModel, isPureGenerateImageModel } from '@renderer/config/models'
import { defineTool, registerTool, TopicType } from '@renderer/pages/home/Inputbar/types'
import { getProviderByModel } from '@renderer/services/AssistantService'
import { isSupportUrlContextProvider } from '@renderer/utils/provider'
import UrlContextButton from './components/UrlContextbutton'
@ -11,7 +11,12 @@ const urlContextTool = defineTool({
visibleInScopes: [TopicType.Chat],
condition: ({ model }) => {
const provider = getProviderByModel(model)
return !!provider && isSupportUrlContextProvider(provider) && (isGeminiModel(model) || isAnthropicModel(model))
return (
!!provider &&
isSupportUrlContextProvider(provider) &&
!isPureGenerateImageModel(model) &&
(isGeminiModel(model) || isAnthropicModel(model))
)
},
render: ({ assistant }) => <UrlContextButton assistantId={assistant.id} />
})

View File

@ -1,4 +1,4 @@
import { isMandatoryWebSearchModel } from '@renderer/config/models'
import { isMandatoryWebSearchModel, isWebSearchModel } from '@renderer/config/models'
import { defineTool, registerTool, TopicType } from '@renderer/pages/home/Inputbar/types'
import WebSearchButton from './components/WebSearchButton'
@ -15,7 +15,7 @@ const webSearchTool = defineTool({
label: (t) => t('chat.input.web_search.label'),
visibleInScopes: [TopicType.Chat],
condition: ({ model }) => !isMandatoryWebSearchModel(model),
condition: ({ model }) => isWebSearchModel(model) && !isMandatoryWebSearchModel(model),
render: function WebSearchToolRender(context) {
const { assistant, quickPanelController } = context

Some files were not shown because too many files have changed in this diff Show More