refactor: migrate to v5 patch-2

This commit is contained in:
suyao 2025-07-07 03:58:10 +08:00
parent f20d964be3
commit 448b5b5c9e
No known key found for this signature in database
16 changed files with 109 additions and 94 deletions

View File

@ -8,7 +8,7 @@ export type { NamedMiddleware } from './middleware'
export { createMiddlewares, wrapModelWithMiddlewares } from './middleware' export { createMiddlewares, wrapModelWithMiddlewares } from './middleware'
// 创建管理 // 创建管理
export type { ModelCreationRequest, ResolvedConfig } from './models' export type { ModelConfig } from './models'
export { export {
createBaseModel, createBaseModel,
createImageModel, createImageModel,
@ -19,6 +19,6 @@ export {
} from './models' } from './models'
// 执行管理 // 执行管理
export type { MCPRequestContext } from './plugins/built-in/mcpPromptPlugin' export type { ToolUseRequestContext } from './plugins/built-in/toolUsePlugin/type'
export type { ExecutionOptions, ExecutorConfig } from './runtime' export type { ExecutionOptions, ExecutorConfig } from './runtime'
export { createExecutor, createOpenAICompatibleExecutor } from './runtime' export { createExecutor, createOpenAICompatibleExecutor } from './runtime'

View File

@ -1,37 +1,37 @@
/** // /**
* // * 配置管理器
* optionspluginsmiddlewares等配置 // * 整合options、plugins、middlewares等配置
*/ // */
import { LanguageModelV2Middleware } from '@ai-sdk/provider' // import { LanguageModelV2Middleware } from '@ai-sdk/provider'
import { ProviderId, ProviderSettingsMap } from '../../types' // import { ProviderId, ProviderSettingsMap } from '../../types'
import { createMiddlewares } from '../middleware/manager' // import { createMiddlewares } from '../middleware/manager'
import { AiPlugin } from '../plugins' // import { AiPlugin } from '../plugins'
import { ResolvedConfig } from './types' // import { ResolvedConfig } from './types'
/** // /**
* // * 解析配置
* provider配置provider选项等 // * 整合provider配置、插件、中间件、provider选项等
*/ // */
export function resolveConfig( // export function resolveConfig(
providerId: ProviderId, // providerId: ProviderId,
modelId: string, // modelId: string,
providerSettings: ProviderSettingsMap[ProviderId], // providerSettings: ProviderSettingsMap[ProviderId],
plugins: AiPlugin[] = [], // plugins: AiPlugin[] = [],
middlewares: LanguageModelV2Middleware[] = [] // middlewares: LanguageModelV2Middleware[] = []
): ResolvedConfig { // ): ResolvedConfig {
// 使用独立的中间件管理器处理中间件 // // 使用独立的中间件管理器处理中间件
const resolvedMiddlewares = createMiddlewares(middlewares) // const resolvedMiddlewares = createMiddlewares(middlewares)
return { // return {
provider: { // provider: {
id: providerId, // id: providerId,
options: providerSettings // options: providerSettings
}, // },
model: { // model: {
id: modelId // id: modelId
}, // },
plugins, // plugins,
middlewares: resolvedMiddlewares // middlewares: resolvedMiddlewares
} // }
} // }

View File

@ -2,8 +2,7 @@
* Provider * Provider
* AI SDK providers * AI SDK providers
*/ */
import { ImageModelV2, type LanguageModelV2, LanguageModelV2Middleware } from '@ai-sdk/provider' import { ImageModelV2, type LanguageModelV2 } from '@ai-sdk/provider'
import { wrapLanguageModel } from 'ai'
import { type ProviderId, type ProviderSettingsMap } from '../../types' import { type ProviderId, type ProviderSettingsMap } from '../../types'
import { isOpenAIChatCompletionOnlyModel } from '../../utils/model' import { isOpenAIChatCompletionOnlyModel } from '../../utils/model'

View File

@ -2,7 +2,7 @@
* *
* *
*/ */
import { LanguageModelV2, LanguageModelV2Middleware } from '@ai-sdk/provider' import { LanguageModelV2 } from '@ai-sdk/provider'
import { LanguageModel } from 'ai' import { LanguageModel } from 'ai'
import { wrapModelWithMiddlewares } from '../middleware' import { wrapModelWithMiddlewares } from '../middleware'
@ -16,7 +16,11 @@ export async function createModel(config: ModelConfig): Promise<LanguageModelV2>
validateModelConfig(config) validateModelConfig(config)
// 1. 创建基础模型 // 1. 创建基础模型
const baseModel = await createBaseModel(config) const baseModel = await createBaseModel({
providerId: config.providerId,
modelId: config.modelId,
providerSettings: config.options
})
// 2. 应用中间件(如果有) // 2. 应用中间件(如果有)
return config.middlewares?.length ? wrapModelWithMiddlewares(baseModel, config.middlewares) : baseModel return config.middlewares?.length ? wrapModelWithMiddlewares(baseModel, config.middlewares) : baseModel
@ -39,7 +43,7 @@ function validateModelConfig(config: ModelConfig): void {
if (!config.modelId) { if (!config.modelId) {
throw new Error('ModelConfig: modelId is required') throw new Error('ModelConfig: modelId is required')
} }
if (!config.providerSettings) { if (!config.options) {
throw new Error('ModelConfig: providerSettings is required') throw new Error('ModelConfig: providerSettings is required')
} }
} }

View File

@ -5,6 +5,6 @@
export const BUILT_IN_PLUGIN_PREFIX = 'built-in:' export const BUILT_IN_PLUGIN_PREFIX = 'built-in:'
export { createLoggingPlugin } from './logging' export { createLoggingPlugin } from './logging'
export type { MCPPromptConfig, ToolUseResult } from './mcpPromptPlugin' export { createPromptToolUsePlugin } from './toolUsePlugin/promptToolUsePlugin'
export { createMCPPromptPlugin } from './mcpPromptPlugin' export type { PromptToolUseConfig, ToolUseRequestContext, ToolUseResult } from './toolUsePlugin/type'
export { type WebSearchConfig, webSearchPlugin } from './webSearchPlugin' export { type WebSearchConfig, webSearchPlugin } from './webSearchPlugin'

View File

@ -5,8 +5,9 @@
*/ */
import type { ModelMessage, TextStreamPart, ToolErrorUnion, ToolSet } from 'ai' import type { ModelMessage, TextStreamPart, ToolErrorUnion, ToolSet } from 'ai'
import { definePlugin } from '../index' import { definePlugin } from '../../index'
import type { AiRequestContext } from '../types' import type { AiRequestContext } from '../../types'
import { PromptToolUseConfig, ToolUseResult } from './type'
/** /**
* 使 AI SDK Tool * 使 AI SDK Tool
@ -25,37 +26,6 @@ import type { AiRequestContext } from '../types'
// } // }
// } // }
/**
*
* AI响应中解析出的工具使用意图
*/
export interface ToolUseResult {
id: string
toolName: string
arguments: any
status: 'pending' | 'invoking' | 'done' | 'error'
}
/**
* MCP Prompt
*/
export interface MCPPromptConfig {
// 是否启用(用于运行时开关)
enabled?: boolean
// 自定义系统提示符构建函数(可选,有默认实现)
buildSystemPrompt?: (userSystemPrompt: string, tools: ToolSet) => string
// 自定义工具解析函数(可选,有默认实现)
parseToolUse?: (content: string, tools: ToolSet) => ToolUseResult[]
createSystemMessage?: (systemPrompt: string, originalParams: any, context: AiRequestContext) => string | null
}
/**
* AI MCP
*/
export interface MCPRequestContext extends AiRequestContext {
mcpTools: ToolSet
}
/** /**
* Cherry Studio * Cherry Studio
*/ */
@ -282,14 +252,11 @@ function defaultParseToolUse(content: string, tools: ToolSet): ToolUseResult[] {
return results return results
} }
/** export const createPromptToolUsePlugin = (config: PromptToolUseConfig = {}) => {
* MCP Prompt
*/
export const createMCPPromptPlugin = (config: MCPPromptConfig = {}) => {
const { enabled = true, buildSystemPrompt = defaultBuildSystemPrompt, parseToolUse = defaultParseToolUse } = config const { enabled = true, buildSystemPrompt = defaultBuildSystemPrompt, parseToolUse = defaultParseToolUse } = config
return definePlugin({ return definePlugin({
name: 'built-in:mcp-prompt', name: 'built-in:prompt-tool-use',
transformParams: (params: any, context: AiRequestContext) => { transformParams: (params: any, context: AiRequestContext) => {
if (!enabled || !params.tools || typeof params.tools !== 'object') { if (!enabled || !params.tools || typeof params.tools !== 'object') {
return params return params
@ -315,6 +282,7 @@ export const createMCPPromptPlugin = (config: MCPPromptConfig = {}) => {
...(systemMessage ? { system: systemMessage } : {}), ...(systemMessage ? { system: systemMessage } : {}),
tools: undefined tools: undefined
} }
context.originalParams = transformedParams
console.log('transformedParams', transformedParams) console.log('transformedParams', transformedParams)
return transformedParams return transformedParams
}, },

View File

@ -0,0 +1,33 @@
import { ToolSet } from 'ai'
import { AiRequestContext } from '../..'
/**
*
* AI响应中解析出的工具使用意图
*/
export interface ToolUseResult {
id: string
toolName: string
arguments: any
status: 'pending' | 'invoking' | 'done' | 'error'
}
export interface BaseToolUsePluginConfig {
enabled?: boolean
}
export interface PromptToolUseConfig extends BaseToolUsePluginConfig {
// 自定义系统提示符构建函数(可选,有默认实现)
buildSystemPrompt?: (userSystemPrompt: string, tools: ToolSet) => string
// 自定义工具解析函数(可选,有默认实现)
parseToolUse?: (content: string, tools: ToolSet) => ToolUseResult[]
createSystemMessage?: (systemPrompt: string, originalParams: any, context: AiRequestContext) => string | null
}
/**
* AI MCP
*/
export interface ToolUseRequestContext extends AiRequestContext {
mcpTools: ToolSet
}

View File

@ -14,7 +14,7 @@ import { adaptWebSearchForProvider, type WebSearchConfig } from './helper'
* options.ts assistant.enableWebSearch * options.ts assistant.enableWebSearch
* providerOptions webSearch: { enabled: true } * providerOptions webSearch: { enabled: true }
*/ */
export const webSearchPlugin = (config) => export const webSearchPlugin = () =>
definePlugin({ definePlugin({
name: 'webSearch', name: 'webSearch',
enforce: 'pre', enforce: 'pre',

View File

@ -38,6 +38,7 @@ export interface AiPlugin {
// 【Sequential】串行钩子 - 链式执行,支持数据转换 // 【Sequential】串行钩子 - 链式执行,支持数据转换
transformParams?: (params: any, context: AiRequestContext) => any | Promise<any> transformParams?: (params: any, context: AiRequestContext) => any | Promise<any>
transformResult?: (result: any, context: AiRequestContext) => any | Promise<any> transformResult?: (result: any, context: AiRequestContext) => any | Promise<any>
configureModel?: (model: any, context: AiRequestContext) => any | Promise<any>
// 【Parallel】并行钩子 - 不依赖顺序,用于副作用 // 【Parallel】并行钩子 - 不依赖顺序,用于副作用
onRequestStart?: (context: AiRequestContext) => void | Promise<void> onRequestStart?: (context: AiRequestContext) => void | Promise<void>

View File

@ -218,7 +218,7 @@ export class RuntimeExecutor<T extends ProviderId = ProviderId> {
return await createModel({ return await createModel({
providerId: this.config.providerId, providerId: this.config.providerId,
modelId: modelOrId, modelId: modelOrId,
providerSettings: this.config.providerSettings, options: this.config.providerSettings,
middlewares middlewares
}) })
} else { } else {

View File

@ -75,7 +75,13 @@ export type {
ToolSet, ToolSet,
UserModelMessage UserModelMessage
} from 'ai' } from 'ai'
export { defaultSettingsMiddleware, extractReasoningMiddleware, simulateStreamingMiddleware, smoothStream } from 'ai' export {
defaultSettingsMiddleware,
extractReasoningMiddleware,
simulateStreamingMiddleware,
smoothStream,
stepCountIs
} from 'ai'
// 重新导出所有 Provider Settings 类型 // 重新导出所有 Provider Settings 类型
export type { export type {

View File

@ -17,7 +17,7 @@ import {
type ProviderSettingsMap, type ProviderSettingsMap,
StreamTextParams StreamTextParams
} from '@cherrystudio/ai-core' } from '@cherrystudio/ai-core'
import { createMCPPromptPlugin } from '@cherrystudio/ai-core/core/plugins/built-in' import { createPromptToolUsePlugin } from '@cherrystudio/ai-core/core/plugins/built-in'
import { isDedicatedImageGenerationModel } from '@renderer/config/models' import { isDedicatedImageGenerationModel } from '@renderer/config/models'
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI' import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
import type { GenerateImageParams, Model, Provider } from '@renderer/types' import type { GenerateImageParams, Model, Provider } from '@renderer/types'
@ -136,7 +136,7 @@ export default class ModernAiProvider {
// 3. 启用Prompt工具调用时添加工具插件 // 3. 启用Prompt工具调用时添加工具插件
if (middlewareConfig.enableTool && middlewareConfig.mcpTools && middlewareConfig.mcpTools.length > 0) { if (middlewareConfig.enableTool && middlewareConfig.mcpTools && middlewareConfig.mcpTools.length > 0) {
plugins.push( plugins.push(
createMCPPromptPlugin({ createPromptToolUsePlugin({
enabled: true, enabled: true,
createSystemMessage: (systemPrompt, params, context) => { createSystemMessage: (systemPrompt, params, context) => {
if (context.modelId.includes('o1-mini') || context.modelId.includes('o1-preview')) { if (context.modelId.includes('o1-mini') || context.modelId.includes('o1-preview')) {
@ -157,6 +157,10 @@ export default class ModernAiProvider {
}) })
) )
} }
// if (!middlewareConfig.enableTool && middlewareConfig.mcpTools && middlewareConfig.mcpTools.length > 0) {
// plugins.push(createNativeToolUsePlugin())
// }
console.log( console.log(
'最终插件列表:', '最终插件列表:',
plugins.map((p) => p.name) plugins.map((p) => p.name)

View File

@ -3,7 +3,7 @@
* apiClient * apiClient
*/ */
import { type ModelMessage, type StreamTextParams } from '@cherrystudio/ai-core' import { type ModelMessage, stepCountIs, type StreamTextParams } from '@cherrystudio/ai-core'
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant' import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
import { import {
isGenerateImageModel, isGenerateImageModel,
@ -245,7 +245,8 @@ export async function buildStreamTextParams(
abortSignal: options.requestOptions?.signal, abortSignal: options.requestOptions?.signal,
headers: options.requestOptions?.headers, headers: options.requestOptions?.headers,
providerOptions, providerOptions,
tools tools,
stopWhen: stepCountIs(10)
} }
return { params, modelId: model.id, capabilities: { enableReasoning, enableWebSearch, enableGenerateImage } } return { params, modelId: model.id, capabilities: { enableReasoning, enableWebSearch, enableGenerateImage } }

View File

@ -42,7 +42,7 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin
if (isSupportedThinkingTokenGeminiModel(model)) { if (isSupportedThinkingTokenGeminiModel(model)) {
if (GEMINI_FLASH_MODEL_REGEX.test(model.id)) { if (GEMINI_FLASH_MODEL_REGEX.test(model.id)) {
return { reasoning_effort: 'none' } return { reasoningEffort: 'none' }
} }
return {} return {}
} }
@ -85,7 +85,7 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin
return { reasoning: { max_tokens: 0, exclude: true } } return { reasoning: { max_tokens: 0, exclude: true } }
} }
if (GEMINI_FLASH_MODEL_REGEX.test(model.id)) { if (GEMINI_FLASH_MODEL_REGEX.test(model.id)) {
return { reasoning_effort: 'none' } return { reasoningEffort: 'none' }
} }
return {} return {}
} }
@ -123,14 +123,14 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin
// Grok models // Grok models
if (isSupportedReasoningEffortGrokModel(model)) { if (isSupportedReasoningEffortGrokModel(model)) {
return { return {
reasoning_effort: reasoningEffort reasoningEffort: reasoningEffort
} }
} }
// OpenAI models // OpenAI models
if (isSupportedReasoningEffortOpenAIModel(model) || isSupportedThinkingTokenGeminiModel(model)) { if (isSupportedReasoningEffortOpenAIModel(model) || isSupportedThinkingTokenGeminiModel(model)) {
return { return {
reasoning_effort: reasoningEffort reasoningEffort: reasoningEffort
} }
} }

View File

@ -172,7 +172,6 @@ const Inputbar: FC<Props> = ({ assistant: _assistant, setActiveTopic, topic }) =
const uploadedFiles = await FileManager.uploadFiles(files) const uploadedFiles = await FileManager.uploadFiles(files)
const baseUserMessage: MessageInputBaseParams = { assistant, topic, content: text } const baseUserMessage: MessageInputBaseParams = { assistant, topic, content: text }
Logger.log('baseUserMessage', baseUserMessage)
// getUserMessage() // getUserMessage()
if (uploadedFiles) { if (uploadedFiles) {

View File

@ -49,7 +49,7 @@ type OpenAIParamsWithoutReasoningEffort = Omit<OpenAI.Chat.Completions.ChatCompl
export type ReasoningEffortOptionalParams = { export type ReasoningEffortOptionalParams = {
thinking?: { type: 'disabled' | 'enabled' | 'auto'; budget_tokens?: number } thinking?: { type: 'disabled' | 'enabled' | 'auto'; budget_tokens?: number }
reasoning?: { max_tokens?: number; exclude?: boolean; effort?: string; enabled?: boolean } | OpenAI.Reasoning reasoning?: { max_tokens?: number; exclude?: boolean; effort?: string; enabled?: boolean } | OpenAI.Reasoning
reasoning_effort?: OpenAI.Chat.Completions.ChatCompletionCreateParams['reasoning_effort'] | 'none' | 'auto' reasoningEffort?: OpenAI.Chat.Completions.ChatCompletionCreateParams['reasoning_effort'] | 'none' | 'auto'
enable_thinking?: boolean enable_thinking?: boolean
thinking_budget?: number thinking_budget?: number
enable_reasoning?: boolean enable_reasoning?: boolean