feat: enhance AI SDK integration with middleware support and improve message handling

suyao 2025-11-28 04:12:18 +08:00
parent ce25001590
commit 356e828422
4 changed files with 130 additions and 60 deletions

View File

@@ -36,7 +36,7 @@ import type {
   Usage
 } from '@anthropic-ai/sdk/resources/messages'
 import { loggerService } from '@logger'
-import type { TextStreamPart, ToolSet } from 'ai'
+import type { FinishReason, LanguageModelUsage, TextStreamPart, ToolSet } from 'ai'

 const logger = loggerService.withContext('AiSdkToAnthropicSSE')
@@ -56,6 +56,7 @@ interface AdapterState {
   model: string
   inputTokens: number
   outputTokens: number
+  cacheInputTokens: number
   currentBlockIndex: number
   blocks: Map<number, ContentBlockState>
   textBlockIndex: number | null
@@ -67,10 +68,6 @@ interface AdapterState {
   hasEmittedMessageStart: boolean
 }

-// ============================================================================
-// Adapter Class
-// ============================================================================
-
 export type SSEEventCallback = (event: RawMessageStreamEvent) => void

 export interface AiSdkToAnthropicSSEOptions {
@@ -94,6 +91,7 @@ export class AiSdkToAnthropicSSE {
       model: options.model,
       inputTokens: options.inputTokens || 0,
       outputTokens: 0,
+      cacheInputTokens: 0,
       currentBlockIndex: 0,
       blocks: new Map(),
       textBlockIndex: null,
@@ -153,19 +151,19 @@ export class AiSdkToAnthropicSSE {
       // === Reasoning/Thinking Events ===
       case 'reasoning-start': {
-        const reasoningId = (chunk as { id?: string }).id || `reasoning_${Date.now()}`
+        const reasoningId = chunk.id
         this.startThinkingBlock(reasoningId)
         break
       }
       case 'reasoning-delta': {
-        const reasoningId = (chunk as { id?: string }).id
+        const reasoningId = chunk.id
         this.emitThinkingDelta(chunk.text || '', reasoningId)
         break
       }
       case 'reasoning-end': {
-        const reasoningId = (chunk as { id?: string }).id
+        const reasoningId = chunk.id
         this.stopThinkingBlock(reasoningId)
         break
       }
@@ -176,14 +174,18 @@
           type: 'tool-call',
           toolCallId: chunk.toolCallId,
           toolName: chunk.toolName,
-          // AI SDK uses 'args' in some versions and 'input' in others
-          args: 'args' in chunk ? chunk.args : (chunk as any).input
+          args: chunk.input
         })
         break
       case 'tool-result':
         // Tool results are handled separately in Anthropic API
         // They come from user messages, not assistant stream
+        // this.handleToolResult({
+        //   type: 'tool-result',
+        //   toolCallId: chunk.toolCallId,
+        //   toolName: chunk.toolName,
+        //   args: chunk.input,
+        //   result: chunk.output
+        // })
         break
       // === Completion Events ===
@@ -465,34 +467,29 @@
     this.state.stopReason = 'tool_use'
   }

-  private handleFinish(chunk: {
-    type: 'finish'
-    finishReason?: string
-    totalUsage?: {
-      inputTokens?: number
-      outputTokens?: number
-    }
-  }): void {
+  private handleFinish(chunk: { type: 'finish'; finishReason?: FinishReason; totalUsage?: LanguageModelUsage }): void {
     // Update usage
     if (chunk.totalUsage) {
       this.state.inputTokens = chunk.totalUsage.inputTokens || 0
       this.state.outputTokens = chunk.totalUsage.outputTokens || 0
+      this.state.cacheInputTokens = chunk.totalUsage.cachedInputTokens || 0
     }

     // Determine finish reason
     if (!this.state.stopReason) {
       switch (chunk.finishReason) {
         case 'stop':
-        case 'end_turn':
           this.state.stopReason = 'end_turn'
           break
         case 'length':
-        case 'max_tokens':
           this.state.stopReason = 'max_tokens'
           break
         case 'tool-calls':
           this.state.stopReason = 'tool_use'
           break
+        case 'content-filter':
+          this.state.stopReason = 'refusal'
+          break
         default:
           this.state.stopReason = 'end_turn'
       }
@@ -539,8 +536,8 @@
     // Emit message_delta with final stop reason and usage
     const usage: MessageDeltaUsage = {
       output_tokens: this.state.outputTokens,
-      input_tokens: null,
-      cache_creation_input_tokens: null,
+      input_tokens: this.state.inputTokens,
+      cache_creation_input_tokens: this.state.cacheInputTokens,
       cache_read_input_tokens: null,
       server_tool_use: null
     }

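Note on the handleFinish change above: typing the chunk against the AI SDK's FinishReason makes the stop-reason translation explicit. A standalone sketch of that mapping (local type aliases stand in for the real SDK imports so the snippet runs on its own):

type AiFinishReason = 'stop' | 'length' | 'content-filter' | 'tool-calls' | 'error' | 'other' | 'unknown'
type AnthropicStopReason = 'end_turn' | 'max_tokens' | 'tool_use' | 'refusal'

// Mirrors the switch in handleFinish (which only runs when no stop reason
// was already set by an earlier tool-call event).
function mapFinishReason(reason?: AiFinishReason): AnthropicStopReason {
  switch (reason) {
    case 'length':
      return 'max_tokens' // output hit the token limit
    case 'tool-calls':
      return 'tool_use' // model requested tool execution
    case 'content-filter':
      return 'refusal' // provider filtered the output
    case 'stop':
    default:
      return 'end_turn' // normal completion, and the fallback
  }
}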
View File

@@ -50,7 +50,7 @@ export interface SharedMiddlewareConfig {
 export function isGemini3ModelId(modelId?: string): boolean {
   if (!modelId) return false
   const lowerModelId = modelId.toLowerCase()
-  return lowerModelId.includes('gemini-2.5') || lowerModelId.includes('gemini-exp') || lowerModelId.includes('gemini-3')
+  return lowerModelId.includes('gemini-3')
 }

 /**

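The narrowed check changes behavior for older Gemini ids. With the new body (model ids below are illustrative):

isGemini3ModelId('gemini-3-pro-preview') // true
isGemini3ModelId('gemini-2.5-pro')       // false — matched before this commit
isGemini3ModelId('gemini-exp-1206')      // false — likewise no longer matched
isGemini3ModelId(undefined)              // false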
View File

@@ -1,5 +1,7 @@
 import type { MessageCreateParams } from '@anthropic-ai/sdk/resources'
 import { loggerService } from '@logger'
+import { buildSharedMiddlewares, type SharedMiddlewareConfig } from '@shared/middleware'
+import { getAiSdkProviderId } from '@shared/provider'
 import type { Provider } from '@types'
 import type { Request, Response } from 'express'
 import express from 'express'
@@ -206,12 +208,26 @@ async function handleUnifiedProcessing({
     return
   }

+  const middlewareConfig: SharedMiddlewareConfig = {
+    modelId: actualModelId,
+    providerId: provider.id,
+    aiSdkProviderId: getAiSdkProviderId(provider)
+  }
+  const middlewares = buildSharedMiddlewares(middlewareConfig)
+  logger.debug('Built middlewares for unified processing', {
+    middlewareCount: middlewares.length,
+    modelId: actualModelId,
+    providerId: provider.id
+  })
+
   if (request.stream) {
     await streamUnifiedMessages({
       response: res,
       provider,
       modelId: actualModelId,
       params: request,
+      middlewares,
       onError: (error) => {
         logger.error('Stream error', error as Error)
       },
@@ -220,7 +236,12 @@
         }
       })
     } else {
-      const response = await generateUnifiedMessage(provider, actualModelId, request)
+      const response = await generateUnifiedMessage({
+        provider,
+        modelId: actualModelId,
+        params: request,
+        middlewares
+      })
       res.json(response)
     }
   } catch (error: any) {

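For orientation: each entry in the middlewares array built here is an AI SDK LanguageModelV2Middleware. What buildSharedMiddlewares actually returns is not part of this diff; a minimal illustrative middleware of the right shape could look like:

import type { LanguageModelV2Middleware } from '@ai-sdk/provider'

// Hypothetical example: log call parameters before every generate/stream call.
const logParamsMiddleware: LanguageModelV2Middleware = {
  transformParams: async ({ params }) => {
    console.debug('model params', params)
    return params
  }
}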
View File

@@ -1,5 +1,5 @@
 import type { LanguageModelV2Middleware, LanguageModelV2ToolResultOutput } from '@ai-sdk/provider'
-import type { ReasoningPart, ToolCallPart, ToolResultPart } from '@ai-sdk/provider-utils'
+import type { ProviderOptions, ReasoningPart, ToolCallPart, ToolResultPart } from '@ai-sdk/provider-utils'
 import type {
   ImageBlockParam,
   MessageCreateParams,
@@ -7,9 +7,11 @@
   Tool as AnthropicTool
 } from '@anthropic-ai/sdk/resources/messages'
 import { type AiPlugin, createExecutor } from '@cherrystudio/ai-core'
+import { createProvider as createProviderCore } from '@cherrystudio/ai-core/provider'
 import { loggerService } from '@logger'
 import { reduxService } from '@main/services/ReduxService'
 import { AiSdkToAnthropicSSE, formatSSEDone, formatSSEEvent } from '@shared/adapters'
+import { isGemini3ModelId } from '@shared/middleware'
 import {
   type AiSdkConfig,
   type AiSdkConfigContext,
@@ -21,13 +23,15 @@
 } from '@shared/provider'
 import { defaultAppHeaders } from '@shared/utils'
 import type { Provider } from '@types'
-import type { ImagePart, ModelMessage, TextPart, Tool } from 'ai'
-import { jsonSchema, simulateStreamingMiddleware, stepCountIs, tool } from 'ai'
+import type { ImagePart, ModelMessage, Provider as AiSdkProvider, TextPart, Tool } from 'ai'
+import { jsonSchema, simulateStreamingMiddleware, stepCountIs, tool, wrapLanguageModel } from 'ai'
 import { net } from 'electron'
 import type { Response } from 'express'

 const logger = loggerService.withContext('UnifiedMessagesService')

+const MAGIC_STRING = 'skip_thought_signature_validator'
+
 initializeSharedProviders({
   warn: (message) => logger.warn(message),
   error: (message, error) => logger.error(message, error)
@@ -64,10 +68,6 @@
   plugins?: AiPlugin[]
 }

-// ============================================================================
-// Internal Utilities
-// ============================================================================
-
 function getMainProcessFormatContext(): ProviderFormatContext {
   const vertexSettings = reduxService.selectSync<{ projectId: string; location: string }>('state.llm.settings.vertexai')
   return {
@@ -154,6 +154,19 @@ function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage
     }
   }

+  // Build a map of tool_use_id -> toolName from all messages first
+  // This is needed because tool_result references tool_use from previous assistant messages
+  const toolCallIdToName = new Map<string, string>()
+  for (const msg of params.messages) {
+    if (Array.isArray(msg.content)) {
+      for (const block of msg.content) {
+        if (block.type === 'tool_use') {
+          toolCallIdToName.set(block.id, block.name)
+        }
+      }
+    }
+  }
+
   // User/assistant messages
   for (const msg of params.messages) {
     if (typeof msg.content === 'string') {
@@ -190,10 +203,12 @@ function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage
             input: block.input
           })
         } else if (block.type === 'tool_result') {
+          // Look up toolName from the pre-built map (covers cross-message references)
+          const toolName = toolCallIdToName.get(block.tool_use_id) || 'unknown'
           toolResultParts.push({
             type: 'tool-result',
             toolCallId: block.tool_use_id,
-            toolName: toolCallParts.find((t) => t.toolCallId === block.tool_use_id)?.toolName || 'unknown',
+            toolName,
             output: block.content ? convertAnthropicToolResultToAiSdk(block.content) : { type: 'text', value: '' }
           })
         }
@@ -211,7 +226,18 @@ function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage
       } else {
         const assistantContent = [...reasoningParts, ...textParts, ...toolCallParts]
         if (assistantContent.length > 0) {
-          messages.push({ role: 'assistant', content: assistantContent })
+          let providerOptions: ProviderOptions | undefined = undefined
+          if (isGemini3ModelId(params.model)) {
+            providerOptions = {
+              google: {
+                thoughtSignature: MAGIC_STRING
+              },
+              openrouter: {
+                reasoning_details: []
+              }
+            }
+          }
+          messages.push({ role: 'assistant', content: assistantContent, providerOptions })
         }
       }
     }
@@ -229,6 +255,32 @@
   onEvent?: (event: Parameters<typeof formatSSEEvent>[0]) => void
 }

+/**
+ * Create AI SDK provider instance from config
+ * Similar to renderer's createAiSdkProvider
+ */
+async function createAiSdkProvider(config: AiSdkConfig): Promise<AiSdkProvider> {
+  let providerId = config.providerId
+
+  // Handle special provider modes (same as renderer)
+  if (providerId === 'openai' && config.options?.mode === 'chat') {
+    providerId = 'openai-chat'
+  } else if (providerId === 'azure' && config.options?.mode === 'responses') {
+    providerId = 'azure-responses'
+  } else if (providerId === 'cherryin' && config.options?.mode === 'chat') {
+    providerId = 'cherryin-chat'
+  }
+
+  const provider = await createProviderCore(providerId, config.options)
+  logger.debug('AI SDK provider created', {
+    providerId,
+    hasOptions: !!config.options
+  })
+  return provider
+}
+
 /**
  * Core stream execution function - single source of truth for AI SDK calls
  */
@@ -240,9 +292,20 @@
   logger.debug('Created AI SDK config', {
     providerId: sdkConfig.providerId,
-    hasOptions: !!sdkConfig.options
+    hasOptions: !!sdkConfig.options,
+    message: params.messages
   })

+  // Create provider instance and get language model
+  const aiSdkProvider = await createAiSdkProvider(sdkConfig)
+  const baseModel = aiSdkProvider.languageModel(modelId)
+
+  // Apply middlewares if present
+  const model =
+    middlewares.length > 0 && typeof baseModel === 'object'
+      ? (wrapLanguageModel({ model: baseModel, middleware: middlewares }) as typeof baseModel)
+      : baseModel
+
   // Create executor with plugins
   const executor = createExecutor(sdkConfig.providerId, sdkConfig.options, plugins)
@@ -250,36 +313,25 @@
   const coreMessages = convertAnthropicToAiMessages(params)
   const tools = convertAnthropicToolsToAiSdk(params.tools)

-  logger.debug('Converted messages', {
-    originalCount: params.messages.length,
-    convertedCount: coreMessages.length,
-    hasSystem: !!params.system,
-    hasTools: !!tools,
-    toolCount: tools ? Object.keys(tools).length : 0
-  })
-
   // Create the adapter
   const adapter = new AiSdkToAnthropicSSE({
     model: `${provider.id}:${modelId}`,
     onEvent: onEvent || (() => {})
   })

-  // Execute stream
-  const result = await executor.streamText(
-    {
-      model: modelId,
-      messages: coreMessages,
-      maxOutputTokens: params.max_tokens,
-      temperature: params.temperature,
-      topP: params.top_p,
-      stopSequences: params.stop_sequences,
-      stopWhen: stepCountIs(100),
-      headers: defaultAppHeaders(),
-      tools,
-      providerOptions: {}
-    },
-    { middlewares }
-  )
+  // Execute stream - pass model object instead of string
+  const result = await executor.streamText({
+    model, // Now passing LanguageModel object, not string
+    messages: coreMessages,
+    maxOutputTokens: params.max_tokens,
+    temperature: params.temperature,
+    topP: params.top_p,
+    stopSequences: params.stop_sequences,
+    stopWhen: stepCountIs(100),
+    headers: defaultAppHeaders(),
+    tools,
+    providerOptions: {}
+  })

   // Process the stream through the adapter
   await adapter.processStream(result.fullStream)
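
The load-bearing pattern in executeStream is wrap-then-call: middlewares attach to the model object via the AI SDK's wrapLanguageModel, which is why streamText now receives a LanguageModel instance instead of a model-id string. A condensed sketch of just that step (hypothetical helper name):

import type { LanguageModelV2, LanguageModelV2Middleware } from '@ai-sdk/provider'
import { wrapLanguageModel } from 'ai'

// Wrap a resolved model with zero or more middlewares; no-op when the list is empty.
function withMiddlewares(baseModel: LanguageModelV2, middlewares: LanguageModelV2Middleware[]): LanguageModelV2 {
  if (middlewares.length === 0) return baseModel
  return wrapLanguageModel({ model: baseModel, middleware: middlewares })
}

Usage then reduces to executor.streamText({ model: withMiddlewares(aiSdkProvider.languageModel(modelId), middlewares), ... }), matching the diff above.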