feat: enhance AI SDK integration with middleware support and improve message handling

Author: suyao
Date: 2025-11-28 04:12:18 +08:00
Commit: 356e828422 (parent: ce25001590)

4 changed files with 130 additions and 60 deletions

--- File 1 of 4: AiSdkToAnthropicSSE adapter (path not shown) ---

@@ -36,7 +36,7 @@ import type {
   Usage
 } from '@anthropic-ai/sdk/resources/messages'
 import { loggerService } from '@logger'
-import type { TextStreamPart, ToolSet } from 'ai'
+import type { FinishReason, LanguageModelUsage, TextStreamPart, ToolSet } from 'ai'

 const logger = loggerService.withContext('AiSdkToAnthropicSSE')

@@ -56,6 +56,7 @@ interface AdapterState {
   model: string
   inputTokens: number
   outputTokens: number
+  cacheInputTokens: number
   currentBlockIndex: number
   blocks: Map<number, ContentBlockState>
   textBlockIndex: number | null
@@ -67,10 +68,6 @@ interface AdapterState {
   hasEmittedMessageStart: boolean
 }

-// ============================================================================
-// Adapter Class
-// ============================================================================
-
 export type SSEEventCallback = (event: RawMessageStreamEvent) => void

 export interface AiSdkToAnthropicSSEOptions {
@@ -94,6 +91,7 @@ export class AiSdkToAnthropicSSE {
       model: options.model,
       inputTokens: options.inputTokens || 0,
       outputTokens: 0,
+      cacheInputTokens: 0,
       currentBlockIndex: 0,
       blocks: new Map(),
       textBlockIndex: null,
@@ -153,19 +151,19 @@ export class AiSdkToAnthropicSSE {
       // === Reasoning/Thinking Events ===
       case 'reasoning-start': {
-        const reasoningId = (chunk as { id?: string }).id || `reasoning_${Date.now()}`
+        const reasoningId = chunk.id
         this.startThinkingBlock(reasoningId)
         break
       }

       case 'reasoning-delta': {
-        const reasoningId = (chunk as { id?: string }).id
+        const reasoningId = chunk.id
         this.emitThinkingDelta(chunk.text || '', reasoningId)
         break
       }

       case 'reasoning-end': {
-        const reasoningId = (chunk as { id?: string }).id
+        const reasoningId = chunk.id
         this.stopThinkingBlock(reasoningId)
         break
       }
@@ -176,14 +174,18 @@
           type: 'tool-call',
           toolCallId: chunk.toolCallId,
           toolName: chunk.toolName,
-          // AI SDK uses 'args' in some versions and 'input' in others
-          args: 'args' in chunk ? chunk.args : (chunk as any).input
+          args: chunk.input
         })
         break

       case 'tool-result':
-        // Tool results are handled separately in Anthropic API
-        // They come from user messages, not assistant stream
+        // this.handleToolResult({
+        //   type: 'tool-result',
+        //   toolCallId: chunk.toolCallId,
+        //   toolName: chunk.toolName,
+        //   args: chunk.input,
+        //   result: chunk.output
+        // })
         break

       // === Completion Events ===
@@ -465,34 +467,29 @@
     this.state.stopReason = 'tool_use'
   }

-  private handleFinish(chunk: {
-    type: 'finish'
-    finishReason?: string
-    totalUsage?: {
-      inputTokens?: number
-      outputTokens?: number
-    }
-  }): void {
+  private handleFinish(chunk: { type: 'finish'; finishReason?: FinishReason; totalUsage?: LanguageModelUsage }): void {
     // Update usage
     if (chunk.totalUsage) {
       this.state.inputTokens = chunk.totalUsage.inputTokens || 0
       this.state.outputTokens = chunk.totalUsage.outputTokens || 0
+      this.state.cacheInputTokens = chunk.totalUsage.cachedInputTokens || 0
     }

     // Determine finish reason
     if (!this.state.stopReason) {
       switch (chunk.finishReason) {
         case 'stop':
-        case 'end_turn':
           this.state.stopReason = 'end_turn'
           break
         case 'length':
-        case 'max_tokens':
           this.state.stopReason = 'max_tokens'
           break
         case 'tool-calls':
           this.state.stopReason = 'tool_use'
           break
+        case 'content-filter':
+          this.state.stopReason = 'refusal'
+          break
         default:
           this.state.stopReason = 'end_turn'
       }
@@ -539,8 +536,8 @@
     // Emit message_delta with final stop reason and usage
     const usage: MessageDeltaUsage = {
       output_tokens: this.state.outputTokens,
-      input_tokens: null,
-      cache_creation_input_tokens: null,
+      input_tokens: this.state.inputTokens,
+      cache_creation_input_tokens: this.state.cacheInputTokens,
       cache_read_input_tokens: null,
       server_tool_use: null
     }
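Note: with this change the adapter carries input, output, and cached input token counts through to the final message_delta instead of emitting nulls. For orientation, a minimal sketch of driving the adapter, assuming only the API visible in this diff; the helper name, the placeholder model id, and the Express wiring are illustrative:

import type { TextStreamPart, ToolSet } from 'ai'
import type { Response } from 'express'
import { AiSdkToAnthropicSSE, formatSSEEvent } from '@shared/adapters'

// Hypothetical helper: pipe an AI SDK fullStream to an Express response
// as Anthropic-style SSE events.
async function pipeToAnthropicSSE(res: Response, fullStream: AsyncIterable<TextStreamPart<ToolSet>>) {
  const adapter = new AiSdkToAnthropicSSE({
    model: 'provider:model-id', // placeholder id
    onEvent: (event) => res.write(formatSSEEvent(event))
  })
  await adapter.processStream(fullStream)
}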

--- File 2 of 4: shared middleware helpers (path not shown) ---

@@ -50,7 +50,7 @@ export interface SharedMiddlewareConfig {
 export function isGemini3ModelId(modelId?: string): boolean {
   if (!modelId) return false
   const lowerModelId = modelId.toLowerCase()
-  return lowerModelId.includes('gemini-2.5') || lowerModelId.includes('gemini-exp') || lowerModelId.includes('gemini-3')
+  return lowerModelId.includes('gemini-3')
 }

 /**
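Note: the matcher is now strictly Gemini 3; ids containing 'gemini-2.5' or 'gemini-exp' no longer match. Illustrative calls (the model ids are made-up examples):

isGemini3ModelId('gemini-3-pro-preview') // true: contains 'gemini-3'
isGemini3ModelId('gemini-2.5-pro')       // false: would have matched before this change
isGemini3ModelId(undefined)              // false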

--- File 3 of 4: Express unified-messages route (path not shown) ---

@@ -1,5 +1,7 @@
 import type { MessageCreateParams } from '@anthropic-ai/sdk/resources'
 import { loggerService } from '@logger'
+import { buildSharedMiddlewares, type SharedMiddlewareConfig } from '@shared/middleware'
+import { getAiSdkProviderId } from '@shared/provider'
 import type { Provider } from '@types'
 import type { Request, Response } from 'express'
 import express from 'express'
@@ -206,12 +208,26 @@ async function handleUnifiedProcessing({
       return
     }

+    const middlewareConfig: SharedMiddlewareConfig = {
+      modelId: actualModelId,
+      providerId: provider.id,
+      aiSdkProviderId: getAiSdkProviderId(provider)
+    }
+    const middlewares = buildSharedMiddlewares(middlewareConfig)
+    logger.debug('Built middlewares for unified processing', {
+      middlewareCount: middlewares.length,
+      modelId: actualModelId,
+      providerId: provider.id
+    })
+
     if (request.stream) {
       await streamUnifiedMessages({
         response: res,
         provider,
         modelId: actualModelId,
         params: request,
+        middlewares,
         onError: (error) => {
           logger.error('Stream error', error as Error)
         },
@@ -220,7 +236,12 @@
         }
       })
     } else {
-      const response = await generateUnifiedMessage(provider, actualModelId, request)
+      const response = await generateUnifiedMessage({
+        provider,
+        modelId: actualModelId,
+        params: request,
+        middlewares
+      })
       res.json(response)
     }
   } catch (error: any) {
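Note: the route now builds the middleware list once, from the resolved model and provider, and threads it through both the streaming and non-streaming paths. buildSharedMiddlewares itself is not part of this commit; a hypothetical sketch of such a config-driven builder, reusing isGemini3ModelId from the previous file, could look like:

import type { LanguageModelV2Middleware } from '@ai-sdk/provider'
import { isGemini3ModelId, type SharedMiddlewareConfig } from '@shared/middleware'

// Sketch only: the real buildSharedMiddlewares implementation is not shown in this diff.
function buildMiddlewaresSketch(config: SharedMiddlewareConfig): LanguageModelV2Middleware[] {
  const middlewares: LanguageModelV2Middleware[] = []
  if (isGemini3ModelId(config.modelId)) {
    middlewares.push({
      // e.g. normalize call options for Gemini 3 before each request
      transformParams: async ({ params }) => params
    })
  }
  return middlewares
}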

--- File 4 of 4: UnifiedMessagesService (path not shown) ---

@@ -1,5 +1,5 @@
 import type { LanguageModelV2Middleware, LanguageModelV2ToolResultOutput } from '@ai-sdk/provider'
-import type { ReasoningPart, ToolCallPart, ToolResultPart } from '@ai-sdk/provider-utils'
+import type { ProviderOptions, ReasoningPart, ToolCallPart, ToolResultPart } from '@ai-sdk/provider-utils'
 import type {
   ImageBlockParam,
   MessageCreateParams,
@@ -7,9 +7,11 @@ import type {
   Tool as AnthropicTool
 } from '@anthropic-ai/sdk/resources/messages'
 import { type AiPlugin, createExecutor } from '@cherrystudio/ai-core'
+import { createProvider as createProviderCore } from '@cherrystudio/ai-core/provider'
 import { loggerService } from '@logger'
 import { reduxService } from '@main/services/ReduxService'
 import { AiSdkToAnthropicSSE, formatSSEDone, formatSSEEvent } from '@shared/adapters'
+import { isGemini3ModelId } from '@shared/middleware'
 import {
   type AiSdkConfig,
   type AiSdkConfigContext,
@@ -21,13 +23,15 @@
 } from '@shared/provider'
 import { defaultAppHeaders } from '@shared/utils'
 import type { Provider } from '@types'
-import type { ImagePart, ModelMessage, TextPart, Tool } from 'ai'
-import { jsonSchema, simulateStreamingMiddleware, stepCountIs, tool } from 'ai'
+import type { ImagePart, ModelMessage, Provider as AiSdkProvider, TextPart, Tool } from 'ai'
+import { jsonSchema, simulateStreamingMiddleware, stepCountIs, tool, wrapLanguageModel } from 'ai'
 import { net } from 'electron'
 import type { Response } from 'express'

 const logger = loggerService.withContext('UnifiedMessagesService')

+const MAGIC_STRING = 'skip_thought_signature_validator'
+
 initializeSharedProviders({
   warn: (message) => logger.warn(message),
   error: (message, error) => logger.error(message, error)
@@ -64,10 +68,6 @@ export interface GenerateUnifiedMessageConfig {
   plugins?: AiPlugin[]
 }

-// ============================================================================
-// Internal Utilities
-// ============================================================================
-
 function getMainProcessFormatContext(): ProviderFormatContext {
   const vertexSettings = reduxService.selectSync<{ projectId: string; location: string }>('state.llm.settings.vertexai')
   return {
@@ -154,6 +154,19 @@ function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage
     }
   }

+  // Build a map of tool_use_id -> toolName from all messages first
+  // This is needed because tool_result references tool_use from previous assistant messages
+  const toolCallIdToName = new Map<string, string>()
+  for (const msg of params.messages) {
+    if (Array.isArray(msg.content)) {
+      for (const block of msg.content) {
+        if (block.type === 'tool_use') {
+          toolCallIdToName.set(block.id, block.name)
+        }
+      }
+    }
+  }
+
   // User/assistant messages
   for (const msg of params.messages) {
     if (typeof msg.content === 'string') {
@@ -190,10 +203,12 @@
             input: block.input
           })
         } else if (block.type === 'tool_result') {
+          // Look up toolName from the pre-built map (covers cross-message references)
+          const toolName = toolCallIdToName.get(block.tool_use_id) || 'unknown'
           toolResultParts.push({
             type: 'tool-result',
             toolCallId: block.tool_use_id,
-            toolName: toolCallParts.find((t) => t.toolCallId === block.tool_use_id)?.toolName || 'unknown',
+            toolName,
             output: block.content ? convertAnthropicToolResultToAiSdk(block.content) : { type: 'text', value: '' }
           })
         }
@@ -211,7 +226,18 @@
     } else {
       const assistantContent = [...reasoningParts, ...textParts, ...toolCallParts]
       if (assistantContent.length > 0) {
-        messages.push({ role: 'assistant', content: assistantContent })
+        let providerOptions: ProviderOptions | undefined = undefined
+        if (isGemini3ModelId(params.model)) {
+          providerOptions = {
+            google: {
+              thoughtSignature: MAGIC_STRING
+            },
+            openrouter: {
+              reasoning_details: []
+            }
+          }
+        }
+        messages.push({ role: 'assistant', content: assistantContent, providerOptions })
       }
     }
   }
@@ -229,6 +255,32 @@ interface ExecuteStreamConfig {
   onEvent?: (event: Parameters<typeof formatSSEEvent>[0]) => void
 }

+/**
+ * Create AI SDK provider instance from config
+ * Similar to renderer's createAiSdkProvider
+ */
+async function createAiSdkProvider(config: AiSdkConfig): Promise<AiSdkProvider> {
+  let providerId = config.providerId
+
+  // Handle special provider modes (same as renderer)
+  if (providerId === 'openai' && config.options?.mode === 'chat') {
+    providerId = 'openai-chat'
+  } else if (providerId === 'azure' && config.options?.mode === 'responses') {
+    providerId = 'azure-responses'
+  } else if (providerId === 'cherryin' && config.options?.mode === 'chat') {
+    providerId = 'cherryin-chat'
+  }
+
+  const provider = await createProviderCore(providerId, config.options)
+  logger.debug('AI SDK provider created', {
+    providerId,
+    hasOptions: !!config.options
+  })
+  return provider
+}
+
 /**
  * Core stream execution function - single source of truth for AI SDK calls
  */
@@ -240,9 +292,20 @@ async function executeStream(config: ExecuteStreamConfig): Promise<AiSdkToAnthro
   logger.debug('Created AI SDK config', {
     providerId: sdkConfig.providerId,
-    hasOptions: !!sdkConfig.options
+    hasOptions: !!sdkConfig.options,
+    message: params.messages
   })

+  // Create provider instance and get language model
+  const aiSdkProvider = await createAiSdkProvider(sdkConfig)
+  const baseModel = aiSdkProvider.languageModel(modelId)
+
+  // Apply middlewares if present
+  const model =
+    middlewares.length > 0 && typeof baseModel === 'object'
+      ? (wrapLanguageModel({ model: baseModel, middleware: middlewares }) as typeof baseModel)
+      : baseModel
+
   // Create executor with plugins
   const executor = createExecutor(sdkConfig.providerId, sdkConfig.options, plugins)
@@ -250,36 +313,25 @@
   const coreMessages = convertAnthropicToAiMessages(params)
   const tools = convertAnthropicToolsToAiSdk(params.tools)

-  logger.debug('Converted messages', {
-    originalCount: params.messages.length,
-    convertedCount: coreMessages.length,
-    hasSystem: !!params.system,
-    hasTools: !!tools,
-    toolCount: tools ? Object.keys(tools).length : 0
-  })
-
   // Create the adapter
   const adapter = new AiSdkToAnthropicSSE({
     model: `${provider.id}:${modelId}`,
     onEvent: onEvent || (() => {})
   })

-  // Execute stream
-  const result = await executor.streamText(
-    {
-      model: modelId,
-      messages: coreMessages,
-      maxOutputTokens: params.max_tokens,
-      temperature: params.temperature,
-      topP: params.top_p,
-      stopSequences: params.stop_sequences,
-      stopWhen: stepCountIs(100),
-      headers: defaultAppHeaders(),
-      tools,
-      providerOptions: {}
-    },
-    { middlewares }
-  )
+  // Execute stream - pass model object instead of string
+  const result = await executor.streamText({
+    model, // Now passing LanguageModel object, not string
+    messages: coreMessages,
+    maxOutputTokens: params.max_tokens,
+    temperature: params.temperature,
+    topP: params.top_p,
+    stopSequences: params.stop_sequences,
+    stopWhen: stepCountIs(100),
+    headers: defaultAppHeaders(),
+    tools,
+    providerOptions: {}
+  })

   // Process the stream through the adapter
   await adapter.processStream(result.fullStream)
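Note: middlewares are now applied by wrapping the language model itself rather than being passed to executor.streamText as a second argument, which is why streamText receives a model object instead of a model id string. A minimal sketch of that wrapping step with a hypothetical logging middleware; wrapLanguageModel and LanguageModelV2Middleware are the AI SDK APIs used in this diff, but the middleware body is illustrative:

import type { LanguageModelV2, LanguageModelV2Middleware } from '@ai-sdk/provider'
import { wrapLanguageModel } from 'ai'

// Hypothetical middleware that logs each streaming call before delegating.
const logCalls: LanguageModelV2Middleware = {
  wrapStream: async ({ doStream, params }) => {
    console.log('streaming with', params.prompt.length, 'prompt messages')
    return doStream()
  }
}

// Wrap a base model (e.g. aiSdkProvider.languageModel(modelId)) with middlewares.
function withLogging(baseModel: LanguageModelV2) {
  return wrapLanguageModel({ model: baseModel, middleware: [logCalls] })
}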