feat: enhance AI SDK integration with middleware support and improve message handling
commit 356e828422
parent ce25001590
@@ -36,7 +36,7 @@ import type {
   Usage
 } from '@anthropic-ai/sdk/resources/messages'
 import { loggerService } from '@logger'
-import type { TextStreamPart, ToolSet } from 'ai'
+import type { FinishReason, LanguageModelUsage, TextStreamPart, ToolSet } from 'ai'
 
 const logger = loggerService.withContext('AiSdkToAnthropicSSE')
 
@@ -56,6 +56,7 @@ interface AdapterState {
   model: string
   inputTokens: number
   outputTokens: number
+  cacheInputTokens: number
   currentBlockIndex: number
   blocks: Map<number, ContentBlockState>
   textBlockIndex: number | null
@@ -67,10 +68,6 @@ interface AdapterState {
   hasEmittedMessageStart: boolean
 }
 
-// ============================================================================
-// Adapter Class
-// ============================================================================
-
 export type SSEEventCallback = (event: RawMessageStreamEvent) => void
 
 export interface AiSdkToAnthropicSSEOptions {
@@ -94,6 +91,7 @@ export class AiSdkToAnthropicSSE {
       model: options.model,
       inputTokens: options.inputTokens || 0,
       outputTokens: 0,
+      cacheInputTokens: 0,
       currentBlockIndex: 0,
       blocks: new Map(),
       textBlockIndex: null,
@@ -153,19 +151,19 @@ export class AiSdkToAnthropicSSE {
 
       // === Reasoning/Thinking Events ===
       case 'reasoning-start': {
-        const reasoningId = (chunk as { id?: string }).id || `reasoning_${Date.now()}`
+        const reasoningId = chunk.id
         this.startThinkingBlock(reasoningId)
         break
       }
 
       case 'reasoning-delta': {
-        const reasoningId = (chunk as { id?: string }).id
+        const reasoningId = chunk.id
         this.emitThinkingDelta(chunk.text || '', reasoningId)
         break
       }
 
       case 'reasoning-end': {
-        const reasoningId = (chunk as { id?: string }).id
+        const reasoningId = chunk.id
         this.stopThinkingBlock(reasoningId)
         break
       }
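Note: dropping the `(chunk as { id?: string })` casts relies on the AI SDK v5 stream part typings, where the reasoning parts of `TextStreamPart` carry a typed `id`. A minimal sketch of that assumption (types only, not the project's actual handler):

```ts
import type { TextStreamPart, ToolSet } from 'ai'

// Assumption: in AI SDK v5 the reasoning-start/-delta/-end parts are
// discriminated by `type` and expose `id: string` directly.
type ReasoningStart = Extract<TextStreamPart<ToolSet>, { type: 'reasoning-start' }>

function reasoningIdOf(chunk: ReasoningStart): string {
  return chunk.id // no structural cast needed
}
```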
@@ -176,14 +174,18 @@ export class AiSdkToAnthropicSSE {
           type: 'tool-call',
           toolCallId: chunk.toolCallId,
           toolName: chunk.toolName,
-          // AI SDK uses 'args' in some versions and 'input' in others
-          args: 'args' in chunk ? chunk.args : (chunk as any).input
+          args: chunk.input
         })
         break
 
       case 'tool-result':
         // Tool results are handled separately in Anthropic API
         // They come from user messages, not assistant stream
+        // this.handleToolResult({
+        //   type: 'tool-result',
+        //   toolCallId: chunk.toolCallId,
+        //   toolName: chunk.toolName,
+        //   args: chunk.input,
+        //   result: chunk.output
+        // })
         break
 
       // === Completion Events ===
@@ -465,34 +467,29 @@ export class AiSdkToAnthropicSSE {
     this.state.stopReason = 'tool_use'
   }
 
-  private handleFinish(chunk: {
-    type: 'finish'
-    finishReason?: string
-    totalUsage?: {
-      inputTokens?: number
-      outputTokens?: number
-    }
-  }): void {
+  private handleFinish(chunk: { type: 'finish'; finishReason?: FinishReason; totalUsage?: LanguageModelUsage }): void {
     // Update usage
     if (chunk.totalUsage) {
       this.state.inputTokens = chunk.totalUsage.inputTokens || 0
       this.state.outputTokens = chunk.totalUsage.outputTokens || 0
+      this.state.cacheInputTokens = chunk.totalUsage.cachedInputTokens || 0
     }
 
     // Determine finish reason
     if (!this.state.stopReason) {
       switch (chunk.finishReason) {
         case 'stop':
         case 'end_turn':
           this.state.stopReason = 'end_turn'
           break
         case 'length':
         case 'max_tokens':
           this.state.stopReason = 'max_tokens'
           break
         case 'tool-calls':
           this.state.stopReason = 'tool_use'
           break
         case 'content-filter':
           this.state.stopReason = 'refusal'
           break
         default:
           this.state.stopReason = 'end_turn'
       }
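For reference, the stop-reason mapping above distilled into a standalone function (a sketch mirroring the switch, not exported code):

```ts
import type { FinishReason } from 'ai'

type AnthropicStopReason = 'end_turn' | 'max_tokens' | 'tool_use' | 'refusal'

// Mirrors the switch in handleFinish: AI SDK finish reasons map onto
// Anthropic stop_reason values, defaulting to 'end_turn'.
function toAnthropicStopReason(reason?: FinishReason): AnthropicStopReason {
  switch (reason) {
    case 'length':
      return 'max_tokens'
    case 'tool-calls':
      return 'tool_use'
    case 'content-filter':
      return 'refusal'
    case 'stop':
    default:
      return 'end_turn'
  }
}
```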
@@ -539,8 +536,8 @@ export class AiSdkToAnthropicSSE {
     // Emit message_delta with final stop reason and usage
     const usage: MessageDeltaUsage = {
       output_tokens: this.state.outputTokens,
-      input_tokens: null,
-      cache_creation_input_tokens: null,
+      input_tokens: this.state.inputTokens,
+      cache_creation_input_tokens: this.state.cacheInputTokens,
       cache_read_input_tokens: null,
       server_tool_use: null
     }
 
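With `input_tokens` and `cache_creation_input_tokens` now carried through from adapter state, a final `message_delta` event would look roughly like this (illustrative values, not taken from the codebase):

```ts
// Illustrative payload only; field values depend on the upstream provider.
const exampleMessageDelta = {
  type: 'message_delta',
  delta: { stop_reason: 'end_turn', stop_sequence: null },
  usage: {
    output_tokens: 128,
    input_tokens: 2048, // was always null before this change
    cache_creation_input_tokens: 512, // was always null before this change
    cache_read_input_tokens: null,
    server_tool_use: null
  }
}
```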
@@ -50,7 +50,7 @@ export interface SharedMiddlewareConfig {
 export function isGemini3ModelId(modelId?: string): boolean {
   if (!modelId) return false
   const lowerModelId = modelId.toLowerCase()
-  return lowerModelId.includes('gemini-2.5') || lowerModelId.includes('gemini-exp') || lowerModelId.includes('gemini-3')
+  return lowerModelId.includes('gemini-3')
 }
 
 /**
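The narrowed matcher changes behavior for older model ids:

```ts
// Behavior after the change (model ids are examples):
isGemini3ModelId('gemini-3-pro-preview') // true
isGemini3ModelId('gemini-2.5-pro')       // false (previously true)
isGemini3ModelId('gemini-exp-1206')      // false (previously true)
isGemini3ModelId(undefined)              // false
```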
@@ -1,5 +1,7 @@
 import type { MessageCreateParams } from '@anthropic-ai/sdk/resources'
+import { loggerService } from '@logger'
+import { buildSharedMiddlewares, type SharedMiddlewareConfig } from '@shared/middleware'
 import { getAiSdkProviderId } from '@shared/provider'
 import type { Provider } from '@types'
 import type { Request, Response } from 'express'
 import express from 'express'
@@ -206,12 +208,26 @@ async function handleUnifiedProcessing({
     return
   }
 
+  const middlewareConfig: SharedMiddlewareConfig = {
+    modelId: actualModelId,
+    providerId: provider.id,
+    aiSdkProviderId: getAiSdkProviderId(provider)
+  }
+  const middlewares = buildSharedMiddlewares(middlewareConfig)
+
+  logger.debug('Built middlewares for unified processing', {
+    middlewareCount: middlewares.length,
+    modelId: actualModelId,
+    providerId: provider.id
+  })
+
   if (request.stream) {
     await streamUnifiedMessages({
       response: res,
       provider,
       modelId: actualModelId,
       params: request,
+      middlewares,
       onError: (error) => {
         logger.error('Stream error', error as Error)
       },
@@ -220,7 +236,12 @@ async function handleUnifiedProcessing({
       }
     })
   } else {
-    const response = await generateUnifiedMessage(provider, actualModelId, request)
+    const response = await generateUnifiedMessage({
+      provider,
+      modelId: actualModelId,
+      params: request,
+      middlewares
+    })
     res.json(response)
   }
 } catch (error: any) {
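Both branches now receive the same `middlewares` array built from `SharedMiddlewareConfig`. For context, a shared middleware here is an AI SDK `LanguageModelV2Middleware`; a minimal sketch of the contract `buildSharedMiddlewares` returns (the real implementations live in `@shared/middleware`):

```ts
import type { LanguageModelV2Middleware } from '@ai-sdk/provider'

// Hypothetical example middleware, shown only to illustrate the shape.
const exampleMiddleware: LanguageModelV2Middleware = {
  transformParams: async ({ params }) => {
    // e.g. tweak provider options for a specific model family before each call
    return params
  }
}
```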
@@ -1,5 +1,5 @@
 import type { LanguageModelV2Middleware, LanguageModelV2ToolResultOutput } from '@ai-sdk/provider'
-import type { ReasoningPart, ToolCallPart, ToolResultPart } from '@ai-sdk/provider-utils'
+import type { ProviderOptions, ReasoningPart, ToolCallPart, ToolResultPart } from '@ai-sdk/provider-utils'
 import type {
   ImageBlockParam,
   MessageCreateParams,
@@ -7,9 +7,11 @@ import type {
   Tool as AnthropicTool
 } from '@anthropic-ai/sdk/resources/messages'
 import { type AiPlugin, createExecutor } from '@cherrystudio/ai-core'
+import { createProvider as createProviderCore } from '@cherrystudio/ai-core/provider'
 import { loggerService } from '@logger'
 import { reduxService } from '@main/services/ReduxService'
 import { AiSdkToAnthropicSSE, formatSSEDone, formatSSEEvent } from '@shared/adapters'
+import { isGemini3ModelId } from '@shared/middleware'
 import {
   type AiSdkConfig,
   type AiSdkConfigContext,
@@ -21,13 +23,15 @@ import {
 } from '@shared/provider'
 import { defaultAppHeaders } from '@shared/utils'
 import type { Provider } from '@types'
-import type { ImagePart, ModelMessage, TextPart, Tool } from 'ai'
-import { jsonSchema, simulateStreamingMiddleware, stepCountIs, tool } from 'ai'
+import type { ImagePart, ModelMessage, Provider as AiSdkProvider, TextPart, Tool } from 'ai'
+import { jsonSchema, simulateStreamingMiddleware, stepCountIs, tool, wrapLanguageModel } from 'ai'
 import { net } from 'electron'
 import type { Response } from 'express'
 
 const logger = loggerService.withContext('UnifiedMessagesService')
 
+const MAGIC_STRING = 'skip_thought_signature_validator'
+
 initializeSharedProviders({
   warn: (message) => logger.warn(message),
   error: (message, error) => logger.error(message, error)
@@ -64,10 +68,6 @@ export interface GenerateUnifiedMessageConfig {
   plugins?: AiPlugin[]
 }
 
-// ============================================================================
-// Internal Utilities
-// ============================================================================
-
 function getMainProcessFormatContext(): ProviderFormatContext {
   const vertexSettings = reduxService.selectSync<{ projectId: string; location: string }>('state.llm.settings.vertexai')
   return {
@@ -154,6 +154,19 @@ function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage
     }
   }
 
+  // Build a map of tool_use_id -> toolName from all messages first
+  // This is needed because tool_result references tool_use from previous assistant messages
+  const toolCallIdToName = new Map<string, string>()
+  for (const msg of params.messages) {
+    if (Array.isArray(msg.content)) {
+      for (const block of msg.content) {
+        if (block.type === 'tool_use') {
+          toolCallIdToName.set(block.id, block.name)
+        }
+      }
+    }
+  }
+
   // User/assistant messages
   for (const msg of params.messages) {
     if (typeof msg.content === 'string') {
@@ -190,10 +203,12 @@ function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage
             input: block.input
           })
         } else if (block.type === 'tool_result') {
+          // Look up toolName from the pre-built map (covers cross-message references)
+          const toolName = toolCallIdToName.get(block.tool_use_id) || 'unknown'
           toolResultParts.push({
             type: 'tool-result',
             toolCallId: block.tool_use_id,
-            toolName: toolCallParts.find((t) => t.toolCallId === block.tool_use_id)?.toolName || 'unknown',
+            toolName,
             output: block.content ? convertAnthropicToolResultToAiSdk(block.content) : { type: 'text', value: '' }
           })
         }
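Why the pre-built map matters: in an Anthropic-format history, `tool_use` lives in an assistant message while the matching `tool_result` arrives in the *next* user message, so the old per-message `toolCallParts.find(...)` lookup came up empty and fell back to `'unknown'`. A sketch of the shape being handled (values hypothetical):

```ts
// Hypothetical two-message exchange in Anthropic MessageCreateParams shape:
const history = [
  {
    role: 'assistant',
    content: [{ type: 'tool_use', id: 'toolu_01', name: 'get_weather', input: { city: 'Paris' } }]
  },
  {
    role: 'user',
    content: [{ type: 'tool_result', tool_use_id: 'toolu_01', content: 'sunny' }]
  }
]
// Pass 1 over all messages: toolCallIdToName.set('toolu_01', 'get_weather')
// Pass 2 can then resolve toolName for the tool_result even though the
// matching tool_use sits in a different (earlier) message.
```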
@@ -211,7 +226,18 @@ function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage
   } else {
     const assistantContent = [...reasoningParts, ...textParts, ...toolCallParts]
     if (assistantContent.length > 0) {
-      messages.push({ role: 'assistant', content: assistantContent })
+      let providerOptions: ProviderOptions | undefined = undefined
+      if (isGemini3ModelId(params.model)) {
+        providerOptions = {
+          google: {
+            thoughtSignature: MAGIC_STRING
+          },
+          openrouter: {
+            reasoning_details: []
+          }
+        }
+      }
+      messages.push({ role: 'assistant', content: assistantContent, providerOptions })
     }
   }
 }
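Note on the Gemini 3 branch: Gemini 3 validates thought signatures on replayed assistant turns that contain tool calls, and an Anthropic-format history carries no Google thought signatures to replay. The `MAGIC_STRING` sentinel (`skip_thought_signature_validator`) appears intended to tell the Google provider to skip that validation, with the empty `openrouter.reasoning_details` array serving the analogous purpose when the same model is reached via OpenRouter.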
@@ -229,6 +255,32 @@ interface ExecuteStreamConfig {
   onEvent?: (event: Parameters<typeof formatSSEEvent>[0]) => void
 }
 
+/**
+ * Create AI SDK provider instance from config
+ * Similar to renderer's createAiSdkProvider
+ */
+async function createAiSdkProvider(config: AiSdkConfig): Promise<AiSdkProvider> {
+  let providerId = config.providerId
+
+  // Handle special provider modes (same as renderer)
+  if (providerId === 'openai' && config.options?.mode === 'chat') {
+    providerId = 'openai-chat'
+  } else if (providerId === 'azure' && config.options?.mode === 'responses') {
+    providerId = 'azure-responses'
+  } else if (providerId === 'cherryin' && config.options?.mode === 'chat') {
+    providerId = 'cherryin-chat'
+  }
+
+  const provider = await createProviderCore(providerId, config.options)
+
+  logger.debug('AI SDK provider created', {
+    providerId,
+    hasOptions: !!config.options
+  })
+
+  return provider
+}
+
 /**
  * Core stream execution function - single source of truth for AI SDK calls
  */
@@ -240,9 +292,20 @@ async function executeStream(config: ExecuteStreamConfig): Promise<AiSdkToAnthropicSSE> {
 
   logger.debug('Created AI SDK config', {
     providerId: sdkConfig.providerId,
-    hasOptions: !!sdkConfig.options
+    hasOptions: !!sdkConfig.options,
+    message: params.messages
   })
 
+  // Create provider instance and get language model
+  const aiSdkProvider = await createAiSdkProvider(sdkConfig)
+  const baseModel = aiSdkProvider.languageModel(modelId)
+
+  // Apply middlewares if present
+  const model =
+    middlewares.length > 0 && typeof baseModel === 'object'
+      ? (wrapLanguageModel({ model: baseModel, middleware: middlewares }) as typeof baseModel)
+      : baseModel
+
   // Create executor with plugins
   const executor = createExecutor(sdkConfig.providerId, sdkConfig.options, plugins)
 
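`wrapLanguageModel` is the standard AI SDK way to layer middleware over a model instance; the cast back to `typeof baseModel` keeps the executor's expected type. A minimal standalone sketch of the same pattern (middleware name and log fields hypothetical):

```ts
import { wrapLanguageModel } from 'ai'
import type { LanguageModelV2, LanguageModelV2Middleware } from '@ai-sdk/provider'

declare const baseModel: LanguageModelV2 // e.g. aiSdkProvider.languageModel(modelId)

// Hypothetical logging middleware, purely for illustration.
const logCalls: LanguageModelV2Middleware = {
  wrapStream: async ({ doStream, params }) => {
    console.debug('doStream called', { promptParts: params.prompt.length })
    return doStream()
  }
}

// `middleware` accepts a single middleware or an array; with an array,
// each entry wraps the model in turn.
const wrapped = wrapLanguageModel({ model: baseModel, middleware: [logCalls] })
```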
@@ -250,36 +313,25 @@ async function executeStream(config: ExecuteStreamConfig): Promise<AiSdkToAnthropicSSE> {
   const coreMessages = convertAnthropicToAiMessages(params)
   const tools = convertAnthropicToolsToAiSdk(params.tools)
 
   logger.debug('Converted messages', {
     originalCount: params.messages.length,
     convertedCount: coreMessages.length,
     hasSystem: !!params.system,
     hasTools: !!tools,
     toolCount: tools ? Object.keys(tools).length : 0
   })
 
   // Create the adapter
   const adapter = new AiSdkToAnthropicSSE({
     model: `${provider.id}:${modelId}`,
     onEvent: onEvent || (() => {})
   })
 
-  // Execute stream
-  const result = await executor.streamText(
-    {
-      model: modelId,
-      messages: coreMessages,
-      maxOutputTokens: params.max_tokens,
-      temperature: params.temperature,
-      topP: params.top_p,
-      stopSequences: params.stop_sequences,
-      stopWhen: stepCountIs(100),
-      headers: defaultAppHeaders(),
-      tools,
-      providerOptions: {}
-    },
-    { middlewares }
-  )
+  // Execute stream - pass model object instead of string
+  const result = await executor.streamText({
+    model, // Now passing LanguageModel object, not string
+    messages: coreMessages,
+    maxOutputTokens: params.max_tokens,
+    temperature: params.temperature,
+    topP: params.top_p,
+    stopSequences: params.stop_sequences,
+    stopWhen: stepCountIs(100),
+    headers: defaultAppHeaders(),
+    tools,
+    providerOptions: {}
+  })
 
   // Process the stream through the adapter
   await adapter.processStream(result.fullStream)