From 356e82842299df02701d55cc59f989c6f3ca095f Mon Sep 17 00:00:00 2001
From: suyao
Date: Fri, 28 Nov 2025 04:12:18 +0800
Subject: [PATCH] feat: enhance AI SDK integration with middleware support and improve message handling

---
 .../shared/adapters/AiSdkToAnthropicSSE.ts |  45 +++----
 packages/shared/middleware/middlewares.ts  |   2 +-
 src/main/apiServer/routes/messages.ts      |  23 +++-
 .../apiServer/services/unified-messages.ts | 120 +++++++++++++-----
 4 files changed, 130 insertions(+), 60 deletions(-)

diff --git a/packages/shared/adapters/AiSdkToAnthropicSSE.ts b/packages/shared/adapters/AiSdkToAnthropicSSE.ts
index 1674609236..f1d6b0c022 100644
--- a/packages/shared/adapters/AiSdkToAnthropicSSE.ts
+++ b/packages/shared/adapters/AiSdkToAnthropicSSE.ts
@@ -36,7 +36,7 @@ import type {
   Usage
 } from '@anthropic-ai/sdk/resources/messages'
 import { loggerService } from '@logger'
-import type { TextStreamPart, ToolSet } from 'ai'
+import type { FinishReason, LanguageModelUsage, TextStreamPart, ToolSet } from 'ai'
 
 const logger = loggerService.withContext('AiSdkToAnthropicSSE')
 
@@ -56,6 +56,7 @@ interface AdapterState {
   model: string
   inputTokens: number
   outputTokens: number
+  cacheInputTokens: number
   currentBlockIndex: number
   blocks: Map<string, number>
   textBlockIndex: number | null
@@ -67,10 +68,6 @@ interface AdapterState {
   hasEmittedMessageStart: boolean
 }
 
-// ============================================================================
-// Adapter Class
-// ============================================================================
-
 export type SSEEventCallback = (event: RawMessageStreamEvent) => void
 
 export interface AiSdkToAnthropicSSEOptions {
@@ -94,6 +91,7 @@ export class AiSdkToAnthropicSSE {
       model: options.model,
       inputTokens: options.inputTokens || 0,
       outputTokens: 0,
+      cacheInputTokens: 0,
       currentBlockIndex: 0,
       blocks: new Map(),
       textBlockIndex: null,
@@ -153,19 +151,19 @@ export class AiSdkToAnthropicSSE {
 
       // === Reasoning/Thinking Events ===
       case 'reasoning-start': {
-        const reasoningId = (chunk as { id?: string }).id || `reasoning_${Date.now()}`
+        const reasoningId = chunk.id
         this.startThinkingBlock(reasoningId)
         break
       }
 
       case 'reasoning-delta': {
-        const reasoningId = (chunk as { id?: string }).id
+        const reasoningId = chunk.id
         this.emitThinkingDelta(chunk.text || '', reasoningId)
         break
       }
 
      case 'reasoning-end': {
-        const reasoningId = (chunk as { id?: string }).id
+        const reasoningId = chunk.id
         this.stopThinkingBlock(reasoningId)
         break
      }
@@ -176,14 +174,18 @@ export class AiSdkToAnthropicSSE {
         this.handleToolCall({
           type: 'tool-call',
           toolCallId: chunk.toolCallId,
           toolName: chunk.toolName,
-          // AI SDK uses 'args' in some versions and 'input' in others
-          args: 'args' in chunk ? chunk.args : (chunk as any).input
+          args: chunk.input
         })
         break
 
       case 'tool-result':
-        // Tool results are handled separately in Anthropic API
-        // They come from user messages, not assistant stream
+        // Tool results arrive in user messages in the Anthropic API, not on
+        // the assistant stream, so they are intentionally ignored here.
+        // this.handleToolResult({
+        //   type: 'tool-result',
+        //   toolCallId: chunk.toolCallId,
+        //   toolName: chunk.toolName,
+        //   args: chunk.input,
+        //   result: chunk.output
+        // })
         break
 
       // === Completion Events ===
@@ -465,34 +467,29 @@ export class AiSdkToAnthropicSSE {
     this.state.stopReason = 'tool_use'
   }
 
-  private handleFinish(chunk: {
-    type: 'finish'
-    finishReason?: string
-    totalUsage?: {
-      inputTokens?: number
-      outputTokens?: number
-    }
-  }): void {
+  private handleFinish(chunk: { type: 'finish'; finishReason?: FinishReason; totalUsage?: LanguageModelUsage }): void {
     // Update usage
     if (chunk.totalUsage) {
       this.state.inputTokens = chunk.totalUsage.inputTokens || 0
       this.state.outputTokens = chunk.totalUsage.outputTokens || 0
+      this.state.cacheInputTokens = chunk.totalUsage.cachedInputTokens || 0
     }
 
     // Determine finish reason
     if (!this.state.stopReason) {
       switch (chunk.finishReason) {
         case 'stop':
-        case 'end_turn':
           this.state.stopReason = 'end_turn'
           break
         case 'length':
-        case 'max_tokens':
           this.state.stopReason = 'max_tokens'
           break
         case 'tool-calls':
           this.state.stopReason = 'tool_use'
           break
+        case 'content-filter':
+          this.state.stopReason = 'refusal'
+          break
         default:
           this.state.stopReason = 'end_turn'
       }
@@ -539,8 +536,8 @@ export class AiSdkToAnthropicSSE {
     // Emit message_delta with final stop reason and usage
     const usage: MessageDeltaUsage = {
       output_tokens: this.state.outputTokens,
-      input_tokens: null,
+      input_tokens: this.state.inputTokens,
       cache_creation_input_tokens: null,
-      cache_read_input_tokens: null,
+      cache_read_input_tokens: this.state.cacheInputTokens,
       server_tool_use: null
     }
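Note: the adapter changes above lean on two AI SDK v5 types: FinishReason ('stop' | 'length' | 'tool-calls' | 'content-filter' | ...) and LanguageModelUsage, which carries cachedInputTokens alongside the input/output counts. A minimal sketch of the two mappings the adapter performs, assuming those types (the helper names are illustrative, not part of the patch):

  import type { FinishReason, LanguageModelUsage } from 'ai'

  // Illustrative helper: AI SDK finish reason -> Anthropic stop_reason.
  function toAnthropicStopReason(reason?: FinishReason): string {
    switch (reason) {
      case 'stop': return 'end_turn'
      case 'length': return 'max_tokens'
      case 'tool-calls': return 'tool_use'
      case 'content-filter': return 'refusal'
      default: return 'end_turn'
    }
  }

  // Illustrative helper: cachedInputTokens counts cache *reads*, so it maps
  // to cache_read_input_tokens, not cache_creation_input_tokens.
  function toAnthropicUsage(usage: LanguageModelUsage) {
    return {
      input_tokens: usage.inputTokens ?? 0,
      output_tokens: usage.outputTokens ?? 0,
      cache_read_input_tokens: usage.cachedInputTokens ?? 0
    }
  }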
diff --git a/packages/shared/middleware/middlewares.ts b/packages/shared/middleware/middlewares.ts
index d9725101c2..de857699f7 100644
--- a/packages/shared/middleware/middlewares.ts
+++ b/packages/shared/middleware/middlewares.ts
@@ -50,7 +50,7 @@ export interface SharedMiddlewareConfig {
 export function isGemini3ModelId(modelId?: string): boolean {
   if (!modelId) return false
   const lowerModelId = modelId.toLowerCase()
-  return lowerModelId.includes('gemini-2.5') || lowerModelId.includes('gemini-exp') || lowerModelId.includes('gemini-3')
+  return lowerModelId.includes('gemini-3')
 }
 
 /**
diff --git a/src/main/apiServer/routes/messages.ts b/src/main/apiServer/routes/messages.ts
index 907b498273..018e7d60ad 100644
--- a/src/main/apiServer/routes/messages.ts
+++ b/src/main/apiServer/routes/messages.ts
@@ -1,5 +1,7 @@
 import type { MessageCreateParams } from '@anthropic-ai/sdk/resources'
 import { loggerService } from '@logger'
+import { buildSharedMiddlewares, type SharedMiddlewareConfig } from '@shared/middleware'
+import { getAiSdkProviderId } from '@shared/provider'
 import type { Provider } from '@types'
 import type { Request, Response } from 'express'
 import express from 'express'
@@ -206,12 +208,26 @@ async function handleUnifiedProcessing({
     return
   }
 
+  const middlewareConfig: SharedMiddlewareConfig = {
+    modelId: actualModelId,
+    providerId: provider.id,
+    aiSdkProviderId: getAiSdkProviderId(provider)
+  }
+  const middlewares = buildSharedMiddlewares(middlewareConfig)
+
+  logger.debug('Built middlewares for unified processing', {
+    middlewareCount: middlewares.length,
+    modelId: actualModelId,
+    providerId: provider.id
+  })
+
   if (request.stream) {
     await streamUnifiedMessages({
       response: res,
       provider,
       modelId: actualModelId,
       params: request,
+      middlewares,
       onError: (error) => {
         logger.error('Stream error', error as Error)
       },
@@ -220,7 +236,12 @@ async function handleUnifiedProcessing({
       }
     })
   } else {
-    const response = await generateUnifiedMessage(provider, actualModelId, request)
+    const response = await generateUnifiedMessage({
+      provider,
+      modelId: actualModelId,
+      params: request,
+      middlewares
+    })
     res.json(response)
   }
 } catch (error: any) {
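Note: for context on what buildSharedMiddlewares returns here: each entry is an AI SDK LanguageModelV2Middleware, an object exposing optional transformParams / wrapGenerate / wrapStream hooks. A pass-through sketch, assuming only the published @ai-sdk/provider interface (this middleware is illustrative, not one the patch adds):

  import type { LanguageModelV2Middleware } from '@ai-sdk/provider'

  const loggingMiddleware: LanguageModelV2Middleware = {
    // Inspect or rewrite call params before they reach the provider.
    transformParams: async ({ params }) => params,
    // Wrap the underlying stream call, e.g. for timing or logging.
    wrapStream: async ({ doStream }) => {
      const start = Date.now()
      const result = await doStream()
      console.debug('stream started after', Date.now() - start, 'ms')
      return result
    }
  }

Middlewares built this way are applied in unified-messages.ts via wrapLanguageModel; see the last file in this patch.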
diff --git a/src/main/apiServer/services/unified-messages.ts b/src/main/apiServer/services/unified-messages.ts
index be8b05aeac..4370a429d0 100644
--- a/src/main/apiServer/services/unified-messages.ts
+++ b/src/main/apiServer/services/unified-messages.ts
@@ -1,5 +1,5 @@
 import type { LanguageModelV2Middleware, LanguageModelV2ToolResultOutput } from '@ai-sdk/provider'
-import type { ReasoningPart, ToolCallPart, ToolResultPart } from '@ai-sdk/provider-utils'
+import type { ProviderOptions, ReasoningPart, ToolCallPart, ToolResultPart } from '@ai-sdk/provider-utils'
 import type {
   ImageBlockParam,
   MessageCreateParams,
@@ -7,9 +7,11 @@ import type {
   Tool as AnthropicTool
 } from '@anthropic-ai/sdk/resources/messages'
 import { type AiPlugin, createExecutor } from '@cherrystudio/ai-core'
+import { createProvider as createProviderCore } from '@cherrystudio/ai-core/provider'
 import { loggerService } from '@logger'
 import { reduxService } from '@main/services/ReduxService'
 import { AiSdkToAnthropicSSE, formatSSEDone, formatSSEEvent } from '@shared/adapters'
+import { isGemini3ModelId } from '@shared/middleware'
 import {
   type AiSdkConfig,
   type AiSdkConfigContext,
@@ -21,13 +23,15 @@ import {
 } from '@shared/provider'
 import { defaultAppHeaders } from '@shared/utils'
 import type { Provider } from '@types'
-import type { ImagePart, ModelMessage, TextPart, Tool } from 'ai'
-import { jsonSchema, simulateStreamingMiddleware, stepCountIs, tool } from 'ai'
+import type { ImagePart, ModelMessage, Provider as AiSdkProvider, TextPart, Tool } from 'ai'
+import { jsonSchema, simulateStreamingMiddleware, stepCountIs, tool, wrapLanguageModel } from 'ai'
 import { net } from 'electron'
 import type { Response } from 'express'
 
 const logger = loggerService.withContext('UnifiedMessagesService')
 
+const MAGIC_STRING = 'skip_thought_signature_validator'
+
 initializeSharedProviders({
   warn: (message) => logger.warn(message),
   error: (message, error) => logger.error(message, error)
@@ -64,10 +68,6 @@ export interface GenerateUnifiedMessageConfig {
   plugins?: AiPlugin[]
 }
 
-// ============================================================================
-// Internal Utilities
-// ============================================================================
-
 function getMainProcessFormatContext(): ProviderFormatContext {
   const vertexSettings = reduxService.selectSync<{ projectId: string; location: string }>('state.llm.settings.vertexai')
   return {
@@ -154,6 +154,19 @@ function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage
     }
   }
 
+  // Build a map of tool_use_id -> toolName from all messages first.
+  // This is needed because a tool_result references a tool_use from a
+  // previous assistant message.
+  const toolCallIdToName = new Map<string, string>()
+  for (const msg of params.messages) {
+    if (Array.isArray(msg.content)) {
+      for (const block of msg.content) {
+        if (block.type === 'tool_use') {
+          toolCallIdToName.set(block.id, block.name)
+        }
+      }
+    }
+  }
+
   // User/assistant messages
   for (const msg of params.messages) {
     if (typeof msg.content === 'string') {
@@ -190,10 +203,12 @@ function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage
           input: block.input
         })
       } else if (block.type === 'tool_result') {
+        // Look up toolName from the pre-built map (covers cross-message references)
+        const toolName = toolCallIdToName.get(block.tool_use_id) || 'unknown'
         toolResultParts.push({
           type: 'tool-result',
           toolCallId: block.tool_use_id,
-          toolName: toolCallParts.find((t) => t.toolCallId === block.tool_use_id)?.toolName || 'unknown',
+          toolName,
           output: block.content ? convertAnthropicToolResultToAiSdk(block.content) : { type: 'text', value: '' }
         })
       }
@@ -211,7 +226,18 @@ function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage
       } else {
         const assistantContent = [...reasoningParts, ...textParts, ...toolCallParts]
         if (assistantContent.length > 0) {
-          messages.push({ role: 'assistant', content: assistantContent })
+          let providerOptions: ProviderOptions | undefined = undefined
+          if (isGemini3ModelId(params.model)) {
+            providerOptions = {
+              google: {
+                thoughtSignature: MAGIC_STRING
+              },
+              openrouter: {
+                reasoning_details: []
+              }
+            }
+          }
+          messages.push({ role: 'assistant', content: assistantContent, providerOptions })
         }
       }
     }
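Note: the providerOptions branch above exists because Gemini 3 validates thought signatures on replayed assistant turns; the MAGIC_STRING sentinel (as its value suggests) is meant to skip that validation on the Google provider, and openrouter.reasoning_details: [] keeps OpenRouter from replaying stale reasoning. A sketch of the resulting assistant turn for a Gemini 3 model (the tool call ID, tool name, and input are made up):

  const assistantTurn: ModelMessage = {
    role: 'assistant',
    content: [
      { type: 'tool-call', toolCallId: 'toolu_01', toolName: 'get_weather', input: { city: 'Paris' } }
    ],
    providerOptions: {
      google: { thoughtSignature: 'skip_thought_signature_validator' },
      openrouter: { reasoning_details: [] }
    }
  }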
@@ -229,6 +255,32 @@ interface ExecuteStreamConfig {
   onEvent?: (event: Parameters<SSEEventCallback>[0]) => void
 }
 
+/**
+ * Create an AI SDK provider instance from config.
+ * Similar to the renderer's createAiSdkProvider.
+ */
+async function createAiSdkProvider(config: AiSdkConfig): Promise<AiSdkProvider> {
+  let providerId = config.providerId
+
+  // Handle special provider modes (same as renderer)
+  if (providerId === 'openai' && config.options?.mode === 'chat') {
+    providerId = 'openai-chat'
+  } else if (providerId === 'azure' && config.options?.mode === 'responses') {
+    providerId = 'azure-responses'
+  } else if (providerId === 'cherryin' && config.options?.mode === 'chat') {
+    providerId = 'cherryin-chat'
+  }
+
+  const provider = await createProviderCore(providerId, config.options)
+
+  logger.debug('AI SDK provider created', {
+    providerId,
+    hasOptions: !!config.options
+  })
+
+  return provider
+}
+
 /**
  * Core stream execution function - single source of truth for AI SDK calls
  */
@@ -240,9 +292,20 @@ async function executeStream(config: ExecuteStreamConfig): Promise<void> {
+  // Create the provider instance and resolve the base model
+  const aiSdkProvider = await createAiSdkProvider(sdkConfig)
+  const baseModel = aiSdkProvider.languageModel(modelId)
+
+  // Apply middlewares by wrapping the language model (if any)
+  const model =
+    middlewares && middlewares.length > 0 && typeof baseModel === 'object'
+      ? (wrapLanguageModel({ model: baseModel, middleware: middlewares }) as typeof baseModel)
+      : baseModel
+
   // Create executor with plugins
   const executor = createExecutor(sdkConfig.providerId, sdkConfig.options, plugins)
 
@@ -250,36 +313,25 @@ async function executeStream(config: ExecuteStreamConfig): Promise<void> {
     onEvent: onEvent || (() => {})
   })
 
-  // Execute stream
-  const result = await executor.streamText(
-    {
-      model: modelId,
-      messages: coreMessages,
-      maxOutputTokens: params.max_tokens,
-      temperature: params.temperature,
-      topP: params.top_p,
-      stopSequences: params.stop_sequences,
-      stopWhen: stepCountIs(100),
-      headers: defaultAppHeaders(),
-      tools,
-      providerOptions: {}
-    },
-    { middlewares }
-  )
+  // Execute stream - pass the wrapped model object instead of a model id string
+  const result = await executor.streamText({
+    model, // LanguageModel object with middlewares applied, not a string
+    messages: coreMessages,
+    maxOutputTokens: params.max_tokens,
+    temperature: params.temperature,
+    topP: params.top_p,
+    stopSequences: params.stop_sequences,
+    stopWhen: stepCountIs(100),
+    headers: defaultAppHeaders(),
+    tools,
+    providerOptions: {}
+  })
 
   // Process the stream through the adapter
   await adapter.processStream(result.fullStream)
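Note: taken together, the new executeStream path is provider -> languageModel -> wrapLanguageModel -> streamText -> AiSdkToAnthropicSSE. A condensed sketch of that flow under the patch's names (the formatSSEEvent/formatSSEDone response wiring is an assumption about their signatures, and error handling is omitted):

  const aiSdkProvider = await createAiSdkProvider(sdkConfig)
  const baseModel = aiSdkProvider.languageModel(modelId)
  const model = middlewares.length > 0
    ? wrapLanguageModel({ model: baseModel, middleware: middlewares })
    : baseModel

  // The adapter converts AI SDK stream parts into Anthropic SSE events.
  const adapter = new AiSdkToAnthropicSSE({
    model: modelId,
    onEvent: (event) => response.write(formatSSEEvent(event))
  })

  const executor = createExecutor(sdkConfig.providerId, sdkConfig.options, plugins)
  const result = await executor.streamText({
    model,
    messages: convertAnthropicToAiMessages(params),
    stopWhen: stepCountIs(100),
    tools
  })

  await adapter.processStream(result.fullStream)
  response.write(formatSSEDone())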