mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-28 13:31:32 +08:00
refactor: migrate to v5 patch-2
This commit is contained in:
parent
f20d964be3
commit
448b5b5c9e
@ -8,7 +8,7 @@ export type { NamedMiddleware } from './middleware'
|
||||
export { createMiddlewares, wrapModelWithMiddlewares } from './middleware'
|
||||
|
||||
// 创建管理
|
||||
export type { ModelCreationRequest, ResolvedConfig } from './models'
|
||||
export type { ModelConfig } from './models'
|
||||
export {
|
||||
createBaseModel,
|
||||
createImageModel,
|
||||
@ -19,6 +19,6 @@ export {
|
||||
} from './models'
|
||||
|
||||
// 执行管理
|
||||
export type { MCPRequestContext } from './plugins/built-in/mcpPromptPlugin'
|
||||
export type { ToolUseRequestContext } from './plugins/built-in/toolUsePlugin/type'
|
||||
export type { ExecutionOptions, ExecutorConfig } from './runtime'
|
||||
export { createExecutor, createOpenAICompatibleExecutor } from './runtime'
|
||||
|
||||
@ -1,37 +1,37 @@
|
||||
/**
|
||||
* 配置管理器
|
||||
* 整合options、plugins、middlewares等配置
|
||||
*/
|
||||
import { LanguageModelV2Middleware } from '@ai-sdk/provider'
|
||||
// /**
|
||||
// * 配置管理器
|
||||
// * 整合options、plugins、middlewares等配置
|
||||
// */
|
||||
// import { LanguageModelV2Middleware } from '@ai-sdk/provider'
|
||||
|
||||
import { ProviderId, ProviderSettingsMap } from '../../types'
|
||||
import { createMiddlewares } from '../middleware/manager'
|
||||
import { AiPlugin } from '../plugins'
|
||||
import { ResolvedConfig } from './types'
|
||||
// import { ProviderId, ProviderSettingsMap } from '../../types'
|
||||
// import { createMiddlewares } from '../middleware/manager'
|
||||
// import { AiPlugin } from '../plugins'
|
||||
// import { ResolvedConfig } from './types'
|
||||
|
||||
/**
|
||||
* 解析配置
|
||||
* 整合provider配置、插件、中间件、provider选项等
|
||||
*/
|
||||
export function resolveConfig(
|
||||
providerId: ProviderId,
|
||||
modelId: string,
|
||||
providerSettings: ProviderSettingsMap[ProviderId],
|
||||
plugins: AiPlugin[] = [],
|
||||
middlewares: LanguageModelV2Middleware[] = []
|
||||
): ResolvedConfig {
|
||||
// 使用独立的中间件管理器处理中间件
|
||||
const resolvedMiddlewares = createMiddlewares(middlewares)
|
||||
// /**
|
||||
// * 解析配置
|
||||
// * 整合provider配置、插件、中间件、provider选项等
|
||||
// */
|
||||
// export function resolveConfig(
|
||||
// providerId: ProviderId,
|
||||
// modelId: string,
|
||||
// providerSettings: ProviderSettingsMap[ProviderId],
|
||||
// plugins: AiPlugin[] = [],
|
||||
// middlewares: LanguageModelV2Middleware[] = []
|
||||
// ): ResolvedConfig {
|
||||
// // 使用独立的中间件管理器处理中间件
|
||||
// const resolvedMiddlewares = createMiddlewares(middlewares)
|
||||
|
||||
return {
|
||||
provider: {
|
||||
id: providerId,
|
||||
options: providerSettings
|
||||
},
|
||||
model: {
|
||||
id: modelId
|
||||
},
|
||||
plugins,
|
||||
middlewares: resolvedMiddlewares
|
||||
}
|
||||
}
|
||||
// return {
|
||||
// provider: {
|
||||
// id: providerId,
|
||||
// options: providerSettings
|
||||
// },
|
||||
// model: {
|
||||
// id: modelId
|
||||
// },
|
||||
// plugins,
|
||||
// middlewares: resolvedMiddlewares
|
||||
// }
|
||||
// }
|
||||
|
||||
@ -2,8 +2,7 @@
|
||||
* Provider 创建器
|
||||
* 负责动态导入 AI SDK providers 并创建基础模型实例
|
||||
*/
|
||||
import { ImageModelV2, type LanguageModelV2, LanguageModelV2Middleware } from '@ai-sdk/provider'
|
||||
import { wrapLanguageModel } from 'ai'
|
||||
import { ImageModelV2, type LanguageModelV2 } from '@ai-sdk/provider'
|
||||
|
||||
import { type ProviderId, type ProviderSettingsMap } from '../../types'
|
||||
import { isOpenAIChatCompletionOnlyModel } from '../../utils/model'
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
* 模型工厂函数
|
||||
* 统一的模型创建和配置管理
|
||||
*/
|
||||
import { LanguageModelV2, LanguageModelV2Middleware } from '@ai-sdk/provider'
|
||||
import { LanguageModelV2 } from '@ai-sdk/provider'
|
||||
import { LanguageModel } from 'ai'
|
||||
|
||||
import { wrapModelWithMiddlewares } from '../middleware'
|
||||
@ -16,7 +16,11 @@ export async function createModel(config: ModelConfig): Promise<LanguageModelV2>
|
||||
validateModelConfig(config)
|
||||
|
||||
// 1. 创建基础模型
|
||||
const baseModel = await createBaseModel(config)
|
||||
const baseModel = await createBaseModel({
|
||||
providerId: config.providerId,
|
||||
modelId: config.modelId,
|
||||
providerSettings: config.options
|
||||
})
|
||||
|
||||
// 2. 应用中间件(如果有)
|
||||
return config.middlewares?.length ? wrapModelWithMiddlewares(baseModel, config.middlewares) : baseModel
|
||||
@ -39,7 +43,7 @@ function validateModelConfig(config: ModelConfig): void {
|
||||
if (!config.modelId) {
|
||||
throw new Error('ModelConfig: modelId is required')
|
||||
}
|
||||
if (!config.providerSettings) {
|
||||
if (!config.options) {
|
||||
throw new Error('ModelConfig: providerSettings is required')
|
||||
}
|
||||
}
|
||||
|
||||
@ -5,6 +5,6 @@
|
||||
export const BUILT_IN_PLUGIN_PREFIX = 'built-in:'
|
||||
|
||||
export { createLoggingPlugin } from './logging'
|
||||
export type { MCPPromptConfig, ToolUseResult } from './mcpPromptPlugin'
|
||||
export { createMCPPromptPlugin } from './mcpPromptPlugin'
|
||||
export { createPromptToolUsePlugin } from './toolUsePlugin/promptToolUsePlugin'
|
||||
export type { PromptToolUseConfig, ToolUseRequestContext, ToolUseResult } from './toolUsePlugin/type'
|
||||
export { type WebSearchConfig, webSearchPlugin } from './webSearchPlugin'
|
||||
|
||||
@ -5,8 +5,9 @@
|
||||
*/
|
||||
import type { ModelMessage, TextStreamPart, ToolErrorUnion, ToolSet } from 'ai'
|
||||
|
||||
import { definePlugin } from '../index'
|
||||
import type { AiRequestContext } from '../types'
|
||||
import { definePlugin } from '../../index'
|
||||
import type { AiRequestContext } from '../../types'
|
||||
import { PromptToolUseConfig, ToolUseResult } from './type'
|
||||
|
||||
/**
|
||||
* 使用 AI SDK 的 Tool 类型,更通用
|
||||
@ -25,37 +26,6 @@ import type { AiRequestContext } from '../types'
|
||||
// }
|
||||
// }
|
||||
|
||||
/**
|
||||
* 解析结果类型
|
||||
* 表示从AI响应中解析出的工具使用意图
|
||||
*/
|
||||
export interface ToolUseResult {
|
||||
id: string
|
||||
toolName: string
|
||||
arguments: any
|
||||
status: 'pending' | 'invoking' | 'done' | 'error'
|
||||
}
|
||||
|
||||
/**
|
||||
* MCP Prompt 插件配置
|
||||
*/
|
||||
export interface MCPPromptConfig {
|
||||
// 是否启用(用于运行时开关)
|
||||
enabled?: boolean
|
||||
// 自定义系统提示符构建函数(可选,有默认实现)
|
||||
buildSystemPrompt?: (userSystemPrompt: string, tools: ToolSet) => string
|
||||
// 自定义工具解析函数(可选,有默认实现)
|
||||
parseToolUse?: (content: string, tools: ToolSet) => ToolUseResult[]
|
||||
createSystemMessage?: (systemPrompt: string, originalParams: any, context: AiRequestContext) => string | null
|
||||
}
|
||||
|
||||
/**
|
||||
* 扩展的 AI 请求上下文,支持 MCP 工具存储
|
||||
*/
|
||||
export interface MCPRequestContext extends AiRequestContext {
|
||||
mcpTools: ToolSet
|
||||
}
|
||||
|
||||
/**
|
||||
* 默认系统提示符模板(提取自 Cherry Studio)
|
||||
*/
|
||||
@ -282,14 +252,11 @@ function defaultParseToolUse(content: string, tools: ToolSet): ToolUseResult[] {
|
||||
return results
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建 MCP Prompt 插件
|
||||
*/
|
||||
export const createMCPPromptPlugin = (config: MCPPromptConfig = {}) => {
|
||||
export const createPromptToolUsePlugin = (config: PromptToolUseConfig = {}) => {
|
||||
const { enabled = true, buildSystemPrompt = defaultBuildSystemPrompt, parseToolUse = defaultParseToolUse } = config
|
||||
|
||||
return definePlugin({
|
||||
name: 'built-in:mcp-prompt',
|
||||
name: 'built-in:prompt-tool-use',
|
||||
transformParams: (params: any, context: AiRequestContext) => {
|
||||
if (!enabled || !params.tools || typeof params.tools !== 'object') {
|
||||
return params
|
||||
@ -315,6 +282,7 @@ export const createMCPPromptPlugin = (config: MCPPromptConfig = {}) => {
|
||||
...(systemMessage ? { system: systemMessage } : {}),
|
||||
tools: undefined
|
||||
}
|
||||
context.originalParams = transformedParams
|
||||
console.log('transformedParams', transformedParams)
|
||||
return transformedParams
|
||||
},
|
||||
@ -0,0 +1,33 @@
|
||||
import { ToolSet } from 'ai'
|
||||
|
||||
import { AiRequestContext } from '../..'
|
||||
|
||||
/**
|
||||
* 解析结果类型
|
||||
* 表示从AI响应中解析出的工具使用意图
|
||||
*/
|
||||
export interface ToolUseResult {
|
||||
id: string
|
||||
toolName: string
|
||||
arguments: any
|
||||
status: 'pending' | 'invoking' | 'done' | 'error'
|
||||
}
|
||||
|
||||
export interface BaseToolUsePluginConfig {
|
||||
enabled?: boolean
|
||||
}
|
||||
|
||||
export interface PromptToolUseConfig extends BaseToolUsePluginConfig {
|
||||
// 自定义系统提示符构建函数(可选,有默认实现)
|
||||
buildSystemPrompt?: (userSystemPrompt: string, tools: ToolSet) => string
|
||||
// 自定义工具解析函数(可选,有默认实现)
|
||||
parseToolUse?: (content: string, tools: ToolSet) => ToolUseResult[]
|
||||
createSystemMessage?: (systemPrompt: string, originalParams: any, context: AiRequestContext) => string | null
|
||||
}
|
||||
|
||||
/**
|
||||
* 扩展的 AI 请求上下文,支持 MCP 工具存储
|
||||
*/
|
||||
export interface ToolUseRequestContext extends AiRequestContext {
|
||||
mcpTools: ToolSet
|
||||
}
|
||||
@ -14,7 +14,7 @@ import { adaptWebSearchForProvider, type WebSearchConfig } from './helper'
|
||||
* options.ts 文件负责将高层级的设置(如 assistant.enableWebSearch)
|
||||
* 转换为 providerOptions 中的 webSearch: { enabled: true } 配置。
|
||||
*/
|
||||
export const webSearchPlugin = (config) =>
|
||||
export const webSearchPlugin = () =>
|
||||
definePlugin({
|
||||
name: 'webSearch',
|
||||
enforce: 'pre',
|
||||
|
||||
@ -38,6 +38,7 @@ export interface AiPlugin {
|
||||
// 【Sequential】串行钩子 - 链式执行,支持数据转换
|
||||
transformParams?: (params: any, context: AiRequestContext) => any | Promise<any>
|
||||
transformResult?: (result: any, context: AiRequestContext) => any | Promise<any>
|
||||
configureModel?: (model: any, context: AiRequestContext) => any | Promise<any>
|
||||
|
||||
// 【Parallel】并行钩子 - 不依赖顺序,用于副作用
|
||||
onRequestStart?: (context: AiRequestContext) => void | Promise<void>
|
||||
|
||||
@ -218,7 +218,7 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
return await createModel({
|
||||
providerId: this.config.providerId,
|
||||
modelId: modelOrId,
|
||||
providerSettings: this.config.providerSettings,
|
||||
options: this.config.providerSettings,
|
||||
middlewares
|
||||
})
|
||||
} else {
|
||||
|
||||
@ -75,7 +75,13 @@ export type {
|
||||
ToolSet,
|
||||
UserModelMessage
|
||||
} from 'ai'
|
||||
export { defaultSettingsMiddleware, extractReasoningMiddleware, simulateStreamingMiddleware, smoothStream } from 'ai'
|
||||
export {
|
||||
defaultSettingsMiddleware,
|
||||
extractReasoningMiddleware,
|
||||
simulateStreamingMiddleware,
|
||||
smoothStream,
|
||||
stepCountIs
|
||||
} from 'ai'
|
||||
|
||||
// 重新导出所有 Provider Settings 类型
|
||||
export type {
|
||||
|
||||
@ -17,7 +17,7 @@ import {
|
||||
type ProviderSettingsMap,
|
||||
StreamTextParams
|
||||
} from '@cherrystudio/ai-core'
|
||||
import { createMCPPromptPlugin } from '@cherrystudio/ai-core/core/plugins/built-in'
|
||||
import { createPromptToolUsePlugin } from '@cherrystudio/ai-core/core/plugins/built-in'
|
||||
import { isDedicatedImageGenerationModel } from '@renderer/config/models'
|
||||
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
|
||||
import type { GenerateImageParams, Model, Provider } from '@renderer/types'
|
||||
@ -136,7 +136,7 @@ export default class ModernAiProvider {
|
||||
// 3. 启用Prompt工具调用时添加工具插件
|
||||
if (middlewareConfig.enableTool && middlewareConfig.mcpTools && middlewareConfig.mcpTools.length > 0) {
|
||||
plugins.push(
|
||||
createMCPPromptPlugin({
|
||||
createPromptToolUsePlugin({
|
||||
enabled: true,
|
||||
createSystemMessage: (systemPrompt, params, context) => {
|
||||
if (context.modelId.includes('o1-mini') || context.modelId.includes('o1-preview')) {
|
||||
@ -157,6 +157,10 @@ export default class ModernAiProvider {
|
||||
})
|
||||
)
|
||||
}
|
||||
|
||||
// if (!middlewareConfig.enableTool && middlewareConfig.mcpTools && middlewareConfig.mcpTools.length > 0) {
|
||||
// plugins.push(createNativeToolUsePlugin())
|
||||
// }
|
||||
console.log(
|
||||
'最终插件列表:',
|
||||
plugins.map((p) => p.name)
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
* 统一管理从各个 apiClient 提取的参数处理和转换功能
|
||||
*/
|
||||
|
||||
import { type ModelMessage, type StreamTextParams } from '@cherrystudio/ai-core'
|
||||
import { type ModelMessage, stepCountIs, type StreamTextParams } from '@cherrystudio/ai-core'
|
||||
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
|
||||
import {
|
||||
isGenerateImageModel,
|
||||
@ -245,7 +245,8 @@ export async function buildStreamTextParams(
|
||||
abortSignal: options.requestOptions?.signal,
|
||||
headers: options.requestOptions?.headers,
|
||||
providerOptions,
|
||||
tools
|
||||
tools,
|
||||
stopWhen: stepCountIs(10)
|
||||
}
|
||||
|
||||
return { params, modelId: model.id, capabilities: { enableReasoning, enableWebSearch, enableGenerateImage } }
|
||||
|
||||
@ -42,7 +42,7 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin
|
||||
|
||||
if (isSupportedThinkingTokenGeminiModel(model)) {
|
||||
if (GEMINI_FLASH_MODEL_REGEX.test(model.id)) {
|
||||
return { reasoning_effort: 'none' }
|
||||
return { reasoningEffort: 'none' }
|
||||
}
|
||||
return {}
|
||||
}
|
||||
@ -85,7 +85,7 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin
|
||||
return { reasoning: { max_tokens: 0, exclude: true } }
|
||||
}
|
||||
if (GEMINI_FLASH_MODEL_REGEX.test(model.id)) {
|
||||
return { reasoning_effort: 'none' }
|
||||
return { reasoningEffort: 'none' }
|
||||
}
|
||||
return {}
|
||||
}
|
||||
@ -123,14 +123,14 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin
|
||||
// Grok models
|
||||
if (isSupportedReasoningEffortGrokModel(model)) {
|
||||
return {
|
||||
reasoning_effort: reasoningEffort
|
||||
reasoningEffort: reasoningEffort
|
||||
}
|
||||
}
|
||||
|
||||
// OpenAI models
|
||||
if (isSupportedReasoningEffortOpenAIModel(model) || isSupportedThinkingTokenGeminiModel(model)) {
|
||||
return {
|
||||
reasoning_effort: reasoningEffort
|
||||
reasoningEffort: reasoningEffort
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -172,7 +172,6 @@ const Inputbar: FC<Props> = ({ assistant: _assistant, setActiveTopic, topic }) =
|
||||
const uploadedFiles = await FileManager.uploadFiles(files)
|
||||
|
||||
const baseUserMessage: MessageInputBaseParams = { assistant, topic, content: text }
|
||||
Logger.log('baseUserMessage', baseUserMessage)
|
||||
|
||||
// getUserMessage()
|
||||
if (uploadedFiles) {
|
||||
|
||||
@ -49,7 +49,7 @@ type OpenAIParamsWithoutReasoningEffort = Omit<OpenAI.Chat.Completions.ChatCompl
|
||||
export type ReasoningEffortOptionalParams = {
|
||||
thinking?: { type: 'disabled' | 'enabled' | 'auto'; budget_tokens?: number }
|
||||
reasoning?: { max_tokens?: number; exclude?: boolean; effort?: string; enabled?: boolean } | OpenAI.Reasoning
|
||||
reasoning_effort?: OpenAI.Chat.Completions.ChatCompletionCreateParams['reasoning_effort'] | 'none' | 'auto'
|
||||
reasoningEffort?: OpenAI.Chat.Completions.ChatCompletionCreateParams['reasoning_effort'] | 'none' | 'auto'
|
||||
enable_thinking?: boolean
|
||||
thinking_budget?: number
|
||||
enable_reasoning?: boolean
|
||||
|
||||
Loading…
Reference in New Issue
Block a user