From ce2500159041bc31ac9dee8e5d27faf4a80d578e Mon Sep 17 00:00:00 2001
From: suyao
Date: Fri, 28 Nov 2025 01:27:20 +0800
Subject: [PATCH] feat: add shared AI SDK middlewares and refactor middleware
 handling

---
 packages/shared/middleware/index.ts         |  15 +
 packages/shared/middleware/middlewares.ts   | 205 ++++++++
 .../apiServer/services/unified-messages.ts  | 461 +++++++-----------
 .../middleware/AiSdkMiddlewareBuilder.ts    |   3 +-
 .../openrouterReasoningMiddleware.ts        |  50 --
 .../skipGeminiThoughtSignatureMiddleware.ts |  36 --
 tsconfig.node.json                          |   4 +-
 7 files changed, 401 insertions(+), 373 deletions(-)
 create mode 100644 packages/shared/middleware/index.ts
 create mode 100644 packages/shared/middleware/middlewares.ts
 delete mode 100644 src/renderer/src/aiCore/middleware/openrouterReasoningMiddleware.ts
 delete mode 100644 src/renderer/src/aiCore/middleware/skipGeminiThoughtSignatureMiddleware.ts

diff --git a/packages/shared/middleware/index.ts b/packages/shared/middleware/index.ts
new file mode 100644
index 0000000000..a4db5ad2dd
--- /dev/null
+++ b/packages/shared/middleware/index.ts
@@ -0,0 +1,15 @@
+/**
+ * Shared AI SDK Middlewares
+ *
+ * Environment-agnostic middlewares that can be used in both
+ * renderer process and main process (API server).
+ */
+
+export {
+  buildSharedMiddlewares,
+  getReasoningTagName,
+  isGemini3ModelId,
+  openrouterReasoningMiddleware,
+  type SharedMiddlewareConfig,
+  skipGeminiThoughtSignatureMiddleware
+} from './middlewares'
diff --git a/packages/shared/middleware/middlewares.ts b/packages/shared/middleware/middlewares.ts
new file mode 100644
index 0000000000..d9725101c2
--- /dev/null
+++ b/packages/shared/middleware/middlewares.ts
@@ -0,0 +1,205 @@
+/**
+ * Shared AI SDK Middlewares
+ *
+ * These middlewares are environment-agnostic and can be used in both
+ * renderer process and main process (API server).
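+ *
+ * A minimal application sketch (hedged: `baseModel` stands in for any
+ * LanguageModelV2 instance you already have; `wrapLanguageModel` is the
+ * AI SDK helper for attaching middlewares):
+ *
+ * ```typescript
+ * import { wrapLanguageModel } from 'ai'
+ * import { buildSharedMiddlewares } from '@shared/middleware'
+ *
+ * const model = wrapLanguageModel({
+ *   model: baseModel,
+ *   middleware: buildSharedMiddlewares({ enableReasoning: true, modelId: 'gemini-2.5-pro' })
+ * })
+ * ```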
+ */
+import type { LanguageModelV2Middleware, LanguageModelV2StreamPart } from '@ai-sdk/provider'
+import { extractReasoningMiddleware } from 'ai'
+
+/**
+ * Configuration for building shared middlewares
+ */
+export interface SharedMiddlewareConfig {
+  /**
+   * Whether to enable reasoning extraction
+   */
+  enableReasoning?: boolean
+
+  /**
+   * Tag name for reasoning extraction
+   * Defaults to a value derived from the model ID
+   */
+  reasoningTagName?: string
+
+  /**
+   * Model ID - used to determine the default reasoning tag and for model detection
+   */
+  modelId?: string
+
+  /**
+   * Provider ID (Cherry Studio provider ID)
+   * Used for provider-specific middlewares like OpenRouter
+   */
+  providerId?: string
+
+  /**
+   * AI SDK Provider ID
+   * Used for the Gemini thought signature middleware
+   * e.g., 'google', 'google-vertex'
+   */
+  aiSdkProviderId?: string
+}
+
+/**
+ * Check if a model ID represents a Gemini 3 (or 2.5) model
+ * that requires thought signature handling
+ *
+ * @param modelId - The model ID string (not a Model object)
+ */
+export function isGemini3ModelId(modelId?: string): boolean {
+  if (!modelId) return false
+  const lowerModelId = modelId.toLowerCase()
+  return lowerModelId.includes('gemini-2.5') || lowerModelId.includes('gemini-exp') || lowerModelId.includes('gemini-3')
+}
+
+/**
+ * Get the default reasoning tag name based on the model ID
+ *
+ * Different models use different tags for reasoning content:
+ * - Most models: 'think'
+ * - GPT-OSS models: 'reasoning'
+ * - Gemini models: 'thought'
+ * - Seed models: 'seed:think'
+ */
+export function getReasoningTagName(modelId?: string): string {
+  if (!modelId) return 'think'
+  const lowerModelId = modelId.toLowerCase()
+  if (lowerModelId.includes('gpt-oss')) return 'reasoning'
+  if (lowerModelId.includes('gemini')) return 'thought'
+  if (lowerModelId.includes('seed-oss-36b')) return 'seed:think'
+  return 'think'
+}
+
+/**
+ * Skip Gemini Thought Signature Middleware
+ *
+ * Due to the complexity of multi-model client requests (which can switch
+ * to other models mid-process), this middleware skips thought signature
+ * validation for all Gemini 3 requests.
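+ *
+ * @example
+ * ```typescript
+ * // Hedged sketch: `googleModel` stands in for a LanguageModelV2 created by an
+ * // AI SDK Google provider; 'google' must match that provider's AI SDK ID.
+ * import { wrapLanguageModel } from 'ai'
+ *
+ * const model = wrapLanguageModel({
+ *   model: googleModel,
+ *   middleware: skipGeminiThoughtSignatureMiddleware('google')
+ * })
+ * ```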
+ *
+ * @param aiSdkId - AI SDK Provider ID (e.g., 'google', 'google-vertex')
+ * @returns LanguageModelV2Middleware
+ */
+export function skipGeminiThoughtSignatureMiddleware(aiSdkId: string): LanguageModelV2Middleware {
+  const MAGIC_STRING = 'skip_thought_signature_validator'
+  return {
+    middlewareVersion: 'v2',
+
+    transformParams: async ({ params }) => {
+      const transformedParams = { ...params }
+      // Process messages in prompt
+      if (transformedParams.prompt && Array.isArray(transformedParams.prompt)) {
+        transformedParams.prompt = transformedParams.prompt.map((message) => {
+          if (typeof message.content !== 'string') {
+            for (const part of message.content) {
+              const googleOptions = part?.providerOptions?.[aiSdkId]
+              if (googleOptions?.thoughtSignature) {
+                googleOptions.thoughtSignature = MAGIC_STRING
+              }
+            }
+          }
+          return message
+        })
+      }
+
+      return transformedParams
+    }
+  }
+}
+
+/**
+ * OpenRouter Reasoning Middleware
+ *
+ * Filters out [REDACTED] blocks from OpenRouter reasoning responses.
+ * OpenRouter may include [REDACTED] markers in reasoning content that
+ * should be removed for cleaner output.
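+ *
+ * @example
+ * ```typescript
+ * // Hedged sketch: `openrouterModel` stands in for a LanguageModelV2 created
+ * // by an OpenRouter provider.
+ * import { wrapLanguageModel } from 'ai'
+ *
+ * const model = wrapLanguageModel({
+ *   model: openrouterModel,
+ *   middleware: openrouterReasoningMiddleware()
+ * })
+ * ```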
+ *
+ * @see https://openrouter.ai/docs/docs/best-practices/reasoning-tokens
+ * @returns LanguageModelV2Middleware
+ */
+export function openrouterReasoningMiddleware(): LanguageModelV2Middleware {
+  const REDACTED_BLOCK = '[REDACTED]'
+  return {
+    middlewareVersion: 'v2',
+    wrapGenerate: async ({ doGenerate }) => {
+      const { content, ...rest } = await doGenerate()
+      const modifiedContent = content.map((part) => {
+        if (part.type === 'reasoning' && part.text.includes(REDACTED_BLOCK)) {
+          return {
+            ...part,
+            text: part.text.replace(REDACTED_BLOCK, '')
+          }
+        }
+        return part
+      })
+      return { content: modifiedContent, ...rest }
+    },
+    wrapStream: async ({ doStream }) => {
+      const { stream, ...rest } = await doStream()
+      return {
+        stream: stream.pipeThrough(
+          new TransformStream({
+            transform(
+              chunk: LanguageModelV2StreamPart,
+              controller: TransformStreamDefaultController
+            ) {
+              if (chunk.type === 'reasoning-delta' && chunk.delta.includes(REDACTED_BLOCK)) {
+                controller.enqueue({
+                  ...chunk,
+                  delta: chunk.delta.replace(REDACTED_BLOCK, '')
+                })
+              } else {
+                controller.enqueue(chunk)
+              }
+            }
+          })
+        ),
+        ...rest
+      }
+    }
+  }
+}
+
+/**
+ * Build shared middlewares based on configuration
+ *
+ * This function builds a set of middlewares that are commonly needed
+ * across different environments (renderer, API server).
+ *
+ * @param config - Configuration for middleware building
+ * @returns Array of AI SDK middlewares
+ *
+ * @example
+ * ```typescript
+ * import { buildSharedMiddlewares } from '@shared/middleware'
+ *
+ * const middlewares = buildSharedMiddlewares({
+ *   enableReasoning: true,
+ *   modelId: 'gemini-2.5-pro',
+ *   providerId: 'openrouter',
+ *   aiSdkProviderId: 'google'
+ * })
+ * ```
+ */
+export function buildSharedMiddlewares(config: SharedMiddlewareConfig): LanguageModelV2Middleware[] {
+  const middlewares: LanguageModelV2Middleware[] = []
+
+  // 1. Reasoning extraction middleware
+  if (config.enableReasoning) {
+    const tagName = config.reasoningTagName || getReasoningTagName(config.modelId)
+    middlewares.push(extractReasoningMiddleware({ tagName }))
+  }
+
+  // 2. OpenRouter-specific: filter [REDACTED] blocks
+  if (config.providerId === 'openrouter' && config.enableReasoning) {
+    middlewares.push(openrouterReasoningMiddleware())
+  }
+
+  // 3. Gemini 3 (2.5) specific: skip thought signature validation
+  if (isGemini3ModelId(config.modelId) && config.aiSdkProviderId) {
+    middlewares.push(skipGeminiThoughtSignatureMiddleware(config.aiSdkProviderId))
+  }
+
+  return middlewares
+}
diff --git a/src/main/apiServer/services/unified-messages.ts b/src/main/apiServer/services/unified-messages.ts
index ddb6d59b37..be8b05aeac 100644
--- a/src/main/apiServer/services/unified-messages.ts
+++ b/src/main/apiServer/services/unified-messages.ts
@@ -1,4 +1,4 @@
-import type { LanguageModelV2ToolResultOutput } from '@ai-sdk/provider'
+import type { LanguageModelV2Middleware, LanguageModelV2ToolResultOutput } from '@ai-sdk/provider'
 import type { ReasoningPart, ToolCallPart, ToolResultPart } from '@ai-sdk/provider-utils'
 import type {
   ImageBlockParam,
@@ -6,7 +6,7 @@ import type {
   TextBlockParam,
   Tool as AnthropicTool
 } from '@anthropic-ai/sdk/resources/messages'
-import { createProvider as createProviderCore } from '@cherrystudio/ai-core/provider'
+import { type AiPlugin, createExecutor } from '@cherrystudio/ai-core'
 import { loggerService } from '@logger'
 import { reduxService } from '@main/services/ReduxService'
 import { AiSdkToAnthropicSSE, formatSSEDone, formatSSEEvent } from '@shared/adapters'
@@ -21,8 +21,8 @@ import {
 } from '@shared/provider'
 import { defaultAppHeaders } from '@shared/utils'
 import type { Provider } from '@types'
-import type { ImagePart, LanguageModel, ModelMessage, Provider as AiSdkProvider, TextPart, Tool } from 'ai'
-import { jsonSchema, stepCountIs, streamText, tool } from 'ai'
+import type { ImagePart, ModelMessage, TextPart, Tool } from 'ai'
+import { jsonSchema, simulateStreamingMiddleware, stepCountIs, tool } from 'ai'
 import { net } from 'electron'
 import type { Response } from 'express'
 
@@ -33,6 +33,9 @@ initializeSharedProviders({
   error: (message, error) => logger.error(message, error)
 })
 
+/**
+ * Configuration for unified message streaming
+ */
 export interface UnifiedStreamConfig {
   response: Response
   provider: Provider
@@ -40,12 +43,31 @@ export interface UnifiedStreamConfig {
   params: MessageCreateParams
   onError?: (error: unknown) => void
   onComplete?: () => void
+  /**
+   * Optional AI SDK middlewares to apply
+   */
+  middlewares?: LanguageModelV2Middleware[]
+  /**
+   * Optional AI Core plugins to use with the executor
+   */
+  plugins?: AiPlugin[]
 }
 
 /**
- * Main process format context for formatProviderApiHost
- * Unlike renderer, main process doesn't have direct access to store getters, so use reduxService cache
+ * Configuration for non-streaming message generation
  */
+export interface GenerateUnifiedMessageConfig {
+  provider: Provider
+  modelId: string
+  params: MessageCreateParams
+  middlewares?: LanguageModelV2Middleware[]
+  plugins?: AiPlugin[]
+}
+
+// ============================================================================
+// Internal Utilities
+// ============================================================================
+
 function getMainProcessFormatContext(): ProviderFormatContext {
   const vertexSettings = reduxService.selectSync<{ projectId: string; location: string }>('state.llm.settings.vertexai')
   return {
@@ -56,12 +78,7 @@ function getMainProcessFormatContext(): ProviderFormatContext {
   }
 }
 
-/**
- * Main process context for providerToAiSdkConfig
- * Main process doesn't have access to browser APIs like window.keyv
- */
 const mainProcessSdkContext: AiSdkConfigContext = {
-  // Simple key rotation - just return first key (no persistent rotation in main process)
   getRotatedApiKey: (provider) => {
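+    // Simple key rotation: no persistent rotation state in the main process, so return the first key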
     const keys = provider.apiKey.split(',').map((k) => k.trim())
     return keys[0] || provider.apiKey
@@ -69,199 +86,82 @@ const mainProcessSdkContext: AiSdkConfigContext = {
   fetch: net.fetch as typeof globalThis.fetch
 }
 
-/**
- * Get actual provider configuration for a model
- *
- * For aggregated providers (new-api, aihubmix, vertexai, azure-openai),
- * this resolves the actual provider type based on the model's characteristics.
- */
 function getActualProvider(provider: Provider, modelId: string): Provider {
-  // Find the model in provider's models list
   const model = provider.models?.find((m) => m.id === modelId)
-  if (!model) {
-    // If model not found, return provider as-is
-    return provider
-  }
-
-  // Resolve actual provider based on model
+  if (!model) return provider
   return resolveActualProvider(provider, model)
 }
 
-/**
- * Convert Cherry Studio Provider to AI SDK config
- * Uses shared implementation with main process context
- */
 function providerToAiSdkConfig(provider: Provider, modelId: string): AiSdkConfig {
-  // First resolve actual provider for aggregated providers
   const actualProvider = getActualProvider(provider, modelId)
-
-  // Format the provider's apiHost for AI SDK
   const formattedProvider = formatProviderApiHost(actualProvider, getMainProcessFormatContext())
-
-  // Use shared implementation
   return sharedProviderToAiSdkConfig(formattedProvider, modelId, mainProcessSdkContext)
 }
 
-/**
- * Create an AI SDK provider from Cherry Studio provider configuration
- */
-async function createAiSdkProvider(config: AiSdkConfig): Promise<AiSdkProvider | null> {
-  try {
-    const provider = await createProviderCore(config.providerId, config.options)
-    logger.debug('AI SDK provider created', {
-      providerId: config.providerId,
-      hasOptions: !!config.options
-    })
-    return provider
-  } catch (error) {
-    logger.error('Failed to create AI SDK provider', error as Error, {
-      providerId: config.providerId
-    })
-    throw error
-  }
-}
-
-/**
- * Create an AI SDK language model from a Cherry Studio provider configuration
- * Uses shared provider utilities for consistent behavior with renderer
- */
-async function createLanguageModel(provider: Provider, modelId: string): Promise<LanguageModel> {
-  logger.debug('Creating language model', {
-    providerId: provider.id,
-    providerType: provider.type,
-    modelId,
-    apiHost: provider.apiHost
-  })
-
-  // Convert provider config to AI SDK config
-  const config = providerToAiSdkConfig(provider, modelId)
-
-  // Create the AI SDK provider
-  const aiSdkProvider = await createAiSdkProvider(config)
-  if (!aiSdkProvider) {
-    throw new Error(`Failed to create AI SDK provider for ${provider.id}`)
-  }
-
-  // Get the language model
-  return aiSdkProvider.languageModel(modelId)
-}
-
 function convertAnthropicToolResultToAiSdk(
   content: string | Array<TextBlockParam | ImageBlockParam>
 ): LanguageModelV2ToolResultOutput {
   if (typeof content === 'string') {
-    return {
-      type: 'text',
-      value: content
-    }
-  } else {
-    const values: Array<
-      | { type: 'text'; text: string }
-      | {
-          type: 'media'
-          /**
-Base-64 encoded media data.
-*/
-          data: string
-          /**
-IANA media type.
-@see https://www.iana.org/assignments/media-types/media-types.xhtml
-*/
-          mediaType: string
-        }
-    > = []
-    for (const block of content) {
-      if (block.type === 'text') {
-        values.push({
-          type: 'text',
-          text: block.text
-        })
-      } else if (block.type === 'image') {
-        values.push({
-          type: 'media',
-          data: block.source.type === 'base64' ? block.source.data : block.source.url,
-          mediaType: block.source.type === 'base64' ? block.source.media_type : 'image/png'
-        })
-      }
-    }
-    return {
-      type: 'content',
-      value: values
+    return { type: 'text', value: content }
+  }
+  const values: Array<{ type: 'text'; text: string } | { type: 'media'; data: string; mediaType: string }> = []
+  for (const block of content) {
+    if (block.type === 'text') {
+      values.push({ type: 'text', text: block.text })
+    } else if (block.type === 'image') {
+      values.push({
+        type: 'media',
+        data: block.source.type === 'base64' ? block.source.data : block.source.url,
+        mediaType: block.source.type === 'base64' ? block.source.media_type : 'image/png'
+      })
     }
   }
+  return { type: 'content', value: values }
 }
 
-/**
- * Convert Anthropic tools format to AI SDK tools format
- */
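+/**
+ * Convert Anthropic tool definitions to AI SDK tools
+ *
+ * Hedged sketch of the mapping: each Anthropic `{ name, description, input_schema }`
+ * entry becomes an AI SDK `tool({ description, inputSchema: jsonSchema(...) })` whose
+ * `execute` simply echoes its input, so tool calls round-trip through the adapter.
+ */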
 function convertAnthropicToolsToAiSdk(tools: MessageCreateParams['tools']): Record<string, Tool> | undefined {
-  if (!tools || tools.length === 0) {
-    return undefined
-  }
+  if (!tools || tools.length === 0) return undefined
   const aiSdkTools: Record<string, Tool> = {}
-
   for (const anthropicTool of tools) {
-    // Handle different tool types
-    if (anthropicTool.type === 'bash_20250124') {
-      // Skip computer use and bash tools - these are Anthropic-specific
-      continue
-    }
-
-    // Regular tool (type === 'custom' or no type)
+    if (anthropicTool.type === 'bash_20250124') continue
     const toolDef = anthropicTool as AnthropicTool
     const parameters = toolDef.input_schema as Parameters<typeof jsonSchema>[0]
-
     aiSdkTools[toolDef.name] = tool({
       description: toolDef.description || '',
       inputSchema: jsonSchema(parameters),
-      execute: async (input: Record<string, unknown>) => {
-        return input
-      }
+      execute: async (input: Record<string, unknown>) => input
     })
   }
-
   return Object.keys(aiSdkTools).length > 0 ? aiSdkTools : undefined
 }
 
-/**
- * Convert Anthropic MessageCreateParams to AI SDK message format
- */
 function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage[] {
   const messages: ModelMessage[] = []
 
-  // Add system message if present
+  // System message
   if (params.system) {
     if (typeof params.system === 'string') {
-      messages.push({
-        role: 'system',
-        content: params.system
-      })
+      messages.push({ role: 'system', content: params.system })
    } else if (Array.isArray(params.system)) {
-      // Handle TextBlockParam array
       const systemText = params.system
         .filter((block) => block.type === 'text')
         .map((block) => block.text)
         .join('\n')
       if (systemText) {
-        messages.push({
-          role: 'system',
-          content: systemText
-        })
+        messages.push({ role: 'system', content: systemText })
       }
     }
   }
 
-  // Convert user/assistant messages
+  // User/assistant messages
   for (const msg of params.messages) {
     if (typeof msg.content === 'string') {
-      if (msg.role === 'user') {
-        messages.push({ role: 'user', content: msg.content })
-      } else {
-        messages.push({ role: 'assistant', content: msg.content })
-      }
+      messages.push({
+        role: msg.role === 'user' ? 'user' : 'assistant',
+        content: msg.content
+      })
     } else if (Array.isArray(msg.content)) {
-      // Handle content blocks
       const textParts: TextPart[] = []
       const imageParts: ImagePart[] = []
       const reasoningParts: ReasoningPart[] = []
@@ -278,15 +178,9 @@ function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage
         } else if (block.type === 'image') {
           const source = block.source
           if (source.type === 'base64') {
-            imageParts.push({
-              type: 'image',
-              image: `data:${source.media_type};base64,${source.data}`
-            })
+            imageParts.push({ type: 'image', image: `data:${source.media_type};base64,${source.data}` })
           } else if (source.type === 'url') {
-            imageParts.push({
-              type: 'image',
-              image: source.url
-            })
+            imageParts.push({ type: 'image', image: source.url })
           }
         } else if (block.type === 'tool_use') {
           toolCallParts.push({
@@ -306,30 +200,18 @@ function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage
       }
 
       if (toolResultParts.length > 0) {
-        messages.push({
-          role: 'tool',
-          content: [...toolResultParts]
-        })
+        messages.push({ role: 'tool', content: [...toolResultParts] })
       }
 
-      // Build the message based on role
-      // Only push user/assistant message if there's actual content (avoid empty messages)
       if (msg.role === 'user') {
        const userContent = [...textParts, ...imageParts]
        if (userContent.length > 0) {
-          messages.push({
-            role: 'user',
-            content: userContent
-          })
+          messages.push({ role: 'user', content: userContent })
        }
      } else {
-        // Assistant messages contain tool calls, not tool results
        const assistantContent = [...reasoningParts, ...textParts, ...toolCallParts]
        if (assistantContent.length > 0) {
-          messages.push({
-            role: 'assistant',
-            content: assistantContent
-          })
+          messages.push({ role: 'assistant', content: assistantContent })
        }
      }
    }
  }
@@ -338,67 +220,54 @@ function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage
   return messages
 }
 
-/**
- * Stream a message request using AI SDK and convert to Anthropic SSE format
- */
-// TODO: 使用ai-core executor集成中间件和transformstream进来
-export async function streamUnifiedMessages(config: UnifiedStreamConfig): Promise<void> {
-  const { response, provider, modelId, params, onError, onComplete } = config
+interface ExecuteStreamConfig {
+  provider: Provider
+  modelId: string
+  params: MessageCreateParams
+  middlewares?: LanguageModelV2Middleware[]
+  plugins?: AiPlugin[]
+  onEvent?: (event: Parameters<typeof formatSSEEvent>[0]) => void
+}
 
-  logger.info('Starting unified message stream', {
-    providerId: provider.id,
-    providerType: provider.type,
-    modelId,
-    stream: params.stream
+/**
+ * Core stream execution function - single source of truth for AI SDK calls
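+ *
+ * Both the SSE streaming endpoint and the simulated non-streaming path funnel
+ * through this function, so middlewares and plugins behave identically in both modes.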
+ */
+async function executeStream(config: ExecuteStreamConfig): Promise<AiSdkToAnthropicSSE> {
+  const { provider, modelId, params, middlewares = [], plugins = [], onEvent } = config
+
+  // Convert provider config to AI SDK config
+  const sdkConfig = providerToAiSdkConfig(provider, modelId)
+
+  logger.debug('Created AI SDK config', {
+    providerId: sdkConfig.providerId,
+    hasOptions: !!sdkConfig.options
   })
 
-  try {
-    response.setHeader('Content-Type', 'text/event-stream')
-    response.setHeader('Cache-Control', 'no-cache')
-    response.setHeader('Connection', 'keep-alive')
-    response.setHeader('X-Accel-Buffering', 'no')
+  // Create executor with plugins
+  const executor = createExecutor(sdkConfig.providerId, sdkConfig.options, plugins)
 
-    const model = await createLanguageModel(provider, modelId)
+  // Convert messages and tools
+  const coreMessages = convertAnthropicToAiMessages(params)
+  const tools = convertAnthropicToolsToAiSdk(params.tools)
 
-    const coreMessages = convertAnthropicToAiMessages(params)
+  logger.debug('Converted messages', {
+    originalCount: params.messages.length,
+    convertedCount: coreMessages.length,
+    hasSystem: !!params.system,
+    hasTools: !!tools,
+    toolCount: tools ? Object.keys(tools).length : 0
+  })
 
-    // Convert tools if present
-    const tools = convertAnthropicToolsToAiSdk(params.tools)
+  // Create the adapter
+  const adapter = new AiSdkToAnthropicSSE({
+    model: `${provider.id}:${modelId}`,
+    onEvent: onEvent || (() => {})
+  })
 
-    logger.debug('Converted messages', {
-      originalCount: params.messages.length,
-      convertedCount: coreMessages.length,
-      hasSystem: !!params.system,
-      hasTools: !!tools,
-      toolCount: tools ? Object.keys(tools).length : 0,
-      toolNames: tools ? Object.keys(tools).slice(0, 10) : [],
-      paramsToolCount: params.tools?.length || 0
-    })
-
-    // Debug: Log message structure to understand tool_result handling
-    logger.silly('Message structure for debugging', {
-      messages: coreMessages.map((m) => ({
-        role: m.role,
-        contentTypes: Array.isArray(m.content)
-          ? m.content.map((c: { type: string }) => c.type)
-          : typeof m.content === 'string'
-            ? ['string']
-            : ['unknown']
-      }))
-    })
-
-    // Create the adapter
-    const adapter = new AiSdkToAnthropicSSE({
-      model: `${provider.id}:${modelId}`,
-      onEvent: (event) => {
-        const sseData = formatSSEEvent(event)
-        response.write(sseData)
-      }
-    })
-
-    // Start streaming
-    const result = streamText({
-      model,
+  // Execute stream
+  const result = await executor.streamText(
+    {
+      model: modelId,
       messages: coreMessages,
       maxOutputTokens: params.max_tokens,
      temperature: params.temperature,
@@ -408,38 +277,65 @@ export async function streamUnifiedMessages(config: UnifiedStreamConfig): Promis
       headers: defaultAppHeaders(),
       tools,
       providerOptions: {}
-    })
+    },
+    { middlewares }
+  )
 
-    // Process the stream through the adapter
-    await adapter.processStream(result.fullStream)
+  // Process the stream through the adapter
+  await adapter.processStream(result.fullStream)
+
+  return adapter
+}
+
+/**
+ * Stream a message request using AI SDK executor and convert to Anthropic SSE format
+ */
+export async function streamUnifiedMessages(config: UnifiedStreamConfig): Promise<void> {
+  const { response, provider, modelId, params, onError, onComplete, middlewares = [], plugins = [] } = config
+
+  logger.info('Starting unified message stream', {
+    providerId: provider.id,
+    providerType: provider.type,
+    modelId,
+    stream: params.stream,
+    middlewareCount: middlewares.length,
+    pluginCount: plugins.length
+  })
+
+  try {
+    response.setHeader('Content-Type', 'text/event-stream')
+    response.setHeader('Cache-Control', 'no-cache')
+    response.setHeader('Connection', 'keep-alive')
+    response.setHeader('X-Accel-Buffering', 'no')
+
+    await executeStream({
+      provider,
+      modelId,
+      params,
+      middlewares,
+      plugins,
+      onEvent: (event) => {
+        const sseData = formatSSEEvent(event)
+        response.write(sseData)
+      }
+    })
 
     // Send done marker
     response.write(formatSSEDone())
     response.end()
 
-    logger.info('Unified message stream completed', {
-      providerId: provider.id,
-      modelId
-    })
-
+    logger.info('Unified message stream completed', { providerId: provider.id, modelId })
     onComplete?.()
   } catch (error) {
-    logger.error('Error in unified message stream', error as Error, {
-      providerId: provider.id,
-      modelId
-    })
+    logger.error('Error in unified message stream', error as Error, { providerId: provider.id, modelId })
 
-    // Try to send error event if response is still writable
     if (!response.writableEnded) {
       try {
         const errorMessage = error instanceof Error ? error.message : 'Unknown error'
         response.write(
           `event: error\ndata: ${JSON.stringify({
             type: 'error',
-            error: {
-              type: 'api_error',
-              message: errorMessage
-            }
+            error: { type: 'api_error', message: errorMessage }
           })}\n\n`
         )
         response.end()
@@ -455,64 +351,61 @@ export async function streamUnifiedMessages(config: UnifiedStreamConfig): Promis
 
 /**
  * Generate a non-streaming message response
+ *
+ * Uses simulateStreamingMiddleware to reuse the same streaming logic,
+ * similar to renderer's ModernAiProvider pattern.
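+ *
+ * @example
+ * ```typescript
+ * // Hedged sketch; provider and params come from the caller, the model ID is illustrative
+ * const res = await generateUnifiedMessage({ provider, modelId: 'some-model', params })
+ * // Legacy positional signature, still supported:
+ * const legacy = await generateUnifiedMessage(provider, 'some-model', params)
+ * ```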
  */
 export async function generateUnifiedMessage(
-  provider: Provider,
-  modelId: string,
-  params: MessageCreateParams
+  providerOrConfig: Provider | GenerateUnifiedMessageConfig,
+  modelId?: string,
+  params?: MessageCreateParams
 ): Promise<ReturnType<AiSdkToAnthropicSSE['buildNonStreamingResponse']>> {
+  // Support both old signature and new config-based signature
+  let config: GenerateUnifiedMessageConfig
+  if ('provider' in providerOrConfig && 'modelId' in providerOrConfig && 'params' in providerOrConfig) {
+    config = providerOrConfig
+  } else {
+    config = {
+      provider: providerOrConfig as Provider,
+      modelId: modelId!,
+      params: params!
+    }
+  }
+
+  const { provider, middlewares = [], plugins = [] } = config
+
   logger.info('Starting unified message generation', {
     providerId: provider.id,
     providerType: provider.type,
-    modelId
+    modelId: config.modelId,
+    middlewareCount: middlewares.length,
+    pluginCount: plugins.length
   })
 
   try {
-    // Create language model (async - uses @cherrystudio/ai-core)
-    const model = await createLanguageModel(provider, modelId)
+    // Add simulateStreamingMiddleware to reuse streaming logic for non-streaming
+    const allMiddlewares = [simulateStreamingMiddleware(), ...middlewares]
 
-    // Convert messages and tools
-    const coreMessages = convertAnthropicToAiMessages(params)
-    const tools = convertAnthropicToolsToAiSdk(params.tools)
-
-    // Create adapter to collect the response
-    let finalResponse: ReturnType<AiSdkToAnthropicSSE['buildNonStreamingResponse']> | null = null
-    const adapter = new AiSdkToAnthropicSSE({
-      model: `${provider.id}:${modelId}`,
-      onEvent: () => {
-        // We don't need to emit events for non-streaming
-      }
+    const adapter = await executeStream({
+      provider,
+      modelId: config.modelId,
+      params: config.params,
+      middlewares: allMiddlewares,
+      plugins
     })
 
-    // Generate text
-    const result = streamText({
-      model,
-      messages: coreMessages,
-      maxOutputTokens: params.max_tokens,
-      temperature: params.temperature,
-      topP: params.top_p,
-      stopSequences: params.stop_sequences,
-      headers: defaultAppHeaders(),
-      tools,
-      stopWhen: stepCountIs(100)
-    })
-
-    // Process the stream to build the response
-    await adapter.processStream(result.fullStream)
-
-    // Get the final response
-    finalResponse = adapter.buildNonStreamingResponse()
+    const finalResponse = adapter.buildNonStreamingResponse()
 
     logger.info('Unified message generation completed', {
       providerId: provider.id,
-      modelId
+      modelId: config.modelId
     })
 
     return finalResponse
   } catch (error) {
     logger.error('Error in unified message generation', error as Error, {
       providerId: provider.id,
-      modelId
+      modelId: config.modelId
     })
     throw error
   }
diff --git a/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts b/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts
index b314ddd737..82e1c32465 100644
--- a/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts
+++ b/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts
@@ -5,6 +5,7 @@ import type { MCPTool } from '@renderer/types'
 import { type Assistant, type Message, type Model, type Provider, SystemProviderIds } from '@renderer/types'
 import type { Chunk } from '@renderer/types/chunk'
 import { isSupportEnableThinkingProvider } from '@renderer/utils/provider'
+import { openrouterReasoningMiddleware, skipGeminiThoughtSignatureMiddleware } from '@shared/middleware'
 import type { LanguageModelMiddleware } from 'ai'
 import { extractReasoningMiddleware, simulateStreamingMiddleware } from 'ai'
 import { isEmpty } from 'lodash'
@@ -13,9 +14,7 @@ import { getAiSdkProviderId } from '../provider/factory'
 import { isOpenRouterGeminiGenerateImageModel } from '../utils/image'
 import { noThinkMiddleware } from './noThinkMiddleware'
 import { openrouterGenerateImageMiddleware } from './openrouterGenerateImageMiddleware'
-import { openrouterReasoningMiddleware } from './openrouterReasoningMiddleware'
 import { qwenThinkingMiddleware } from './qwenThinkingMiddleware'
-import { skipGeminiThoughtSignatureMiddleware } from './skipGeminiThoughtSignatureMiddleware'
 import { toolChoiceMiddleware } from './toolChoiceMiddleware'
 
 const logger = loggerService.withContext('AiSdkMiddlewareBuilder')
diff --git a/src/renderer/src/aiCore/middleware/openrouterReasoningMiddleware.ts b/src/renderer/src/aiCore/middleware/openrouterReasoningMiddleware.ts
deleted file mode 100644
index 9ef3df61e9..0000000000
--- a/src/renderer/src/aiCore/middleware/openrouterReasoningMiddleware.ts
+++ /dev/null
@@ -1,50 +0,0 @@
-import type { LanguageModelV2StreamPart } from '@ai-sdk/provider'
-import type { LanguageModelMiddleware } from 'ai'
-
-/**
- * https://openrouter.ai/docs/docs/best-practices/reasoning-tokens#example-preserving-reasoning-blocks-with-openrouter-and-claude
- *
- * @returns LanguageModelMiddleware - a middleware filter redacted block
- */
-export function openrouterReasoningMiddleware(): LanguageModelMiddleware {
-  const REDACTED_BLOCK = '[REDACTED]'
-  return {
-    middlewareVersion: 'v2',
-    wrapGenerate: async ({ doGenerate }) => {
-      const { content, ...rest } = await doGenerate()
-      const modifiedContent = content.map((part) => {
-        if (part.type === 'reasoning' && part.text.includes(REDACTED_BLOCK)) {
-          return {
-            ...part,
-            text: part.text.replace(REDACTED_BLOCK, '')
-          }
-        }
-        return part
-      })
-      return { content: modifiedContent, ...rest }
-    },
-    wrapStream: async ({ doStream }) => {
-      const { stream, ...rest } = await doStream()
-      return {
-        stream: stream.pipeThrough(
-          new TransformStream({
-            transform(
-              chunk: LanguageModelV2StreamPart,
-              controller: TransformStreamDefaultController
-            ) {
-              if (chunk.type === 'reasoning-delta' && chunk.delta.includes(REDACTED_BLOCK)) {
-                controller.enqueue({
-                  ...chunk,
-                  delta: chunk.delta.replace(REDACTED_BLOCK, '')
-                })
-              } else {
-                controller.enqueue(chunk)
-              }
-            }
-          })
-        ),
-        ...rest
-      }
-    }
-  }
-}
diff --git a/src/renderer/src/aiCore/middleware/skipGeminiThoughtSignatureMiddleware.ts b/src/renderer/src/aiCore/middleware/skipGeminiThoughtSignatureMiddleware.ts
deleted file mode 100644
index da318ea60d..0000000000
--- a/src/renderer/src/aiCore/middleware/skipGeminiThoughtSignatureMiddleware.ts
+++ /dev/null
@@ -1,36 +0,0 @@
-import type { LanguageModelMiddleware } from 'ai'
-
-/**
- * skip Gemini Thought Signature Middleware
- * 由于多模型客户端请求的复杂性(可以中途切换其他模型),这里选择通过中间件方式添加跳过所有 Gemini3 思考签名
- * Due to the complexity of multi-model client requests (which can switch to other models mid-process),
- * it was decided to add a skip for all Gemini3 thinking signatures via middleware.
- * @param aiSdkId AI SDK Provider ID
- * @returns LanguageModelMiddleware
- */
-export function skipGeminiThoughtSignatureMiddleware(aiSdkId: string): LanguageModelMiddleware {
-  const MAGIC_STRING = 'skip_thought_signature_validator'
-  return {
-    middlewareVersion: 'v2',
-
-    transformParams: async ({ params }) => {
-      const transformedParams = { ...params }
-      // Process messages in prompt
-      if (transformedParams.prompt && Array.isArray(transformedParams.prompt)) {
-        transformedParams.prompt = transformedParams.prompt.map((message) => {
-          if (typeof message.content !== 'string') {
-            for (const part of message.content) {
-              const googleOptions = part?.providerOptions?.[aiSdkId]
-              if (googleOptions?.thoughtSignature) {
-                googleOptions.thoughtSignature = MAGIC_STRING
-              }
-            }
-          }
-          return message
-        })
-      }
-
-      return transformedParams
-    }
-  }
-}
diff --git a/tsconfig.node.json b/tsconfig.node.json
index 9871e604f2..4f9e797146 100644
--- a/tsconfig.node.json
+++ b/tsconfig.node.json
@@ -7,9 +7,11 @@
     "src/main/env.d.ts",
     "src/renderer/src/types/*",
     "packages/shared/**/*",
+    "packages/aiCore/src/**/*",
     "scripts",
     "packages/mcp-trace/**/*",
-    "src/renderer/src/services/traceApi.ts"
+    "src/renderer/src/services/traceApi.ts",
+    "packages/ai-sdk-provider/**/*"
   ],
   "compilerOptions": {
     "composite": true,