feat: add shared AI SDK middlewares and refactor middleware handling

suyao 2025-11-28 01:27:20 +08:00
parent 77c1b77113
commit ce25001590
7 changed files with 401 additions and 373 deletions

View File

@ -0,0 +1,15 @@
/**
* Shared AI SDK Middlewares
*
* Environment-agnostic middlewares that can be used in both
* renderer process and main process (API server).
*/
export {
buildSharedMiddlewares,
getReasoningTagName,
isGemini3ModelId,
openrouterReasoningMiddleware,
type SharedMiddlewareConfig,
skipGeminiThoughtSignatureMiddleware
} from './middlewares'

View File

@ -0,0 +1,205 @@
/**
* Shared AI SDK Middlewares
*
* These middlewares are environment-agnostic and can be used in both
* renderer process and main process (API server).
*/
import type { LanguageModelV2Middleware, LanguageModelV2StreamPart } from '@ai-sdk/provider'
import { extractReasoningMiddleware } from 'ai'
/**
* Configuration for building shared middlewares
*/
export interface SharedMiddlewareConfig {
/**
* Whether to enable reasoning extraction
*/
enableReasoning?: boolean
/**
* Tag name for reasoning extraction
* Defaults based on model ID
*/
reasoningTagName?: string
/**
* Model ID - used to determine default reasoning tag and model detection
*/
modelId?: string
/**
* Provider ID (Cherry Studio provider ID)
* Used for provider-specific middlewares like OpenRouter
*/
providerId?: string
/**
* AI SDK Provider ID
* Used for Gemini thought signature middleware
* e.g., 'google', 'google-vertex'
*/
aiSdkProviderId?: string
}
/**
* Check if model ID represents a Gemini 3 (2.5) model
* that requires thought signature handling
*
* @param modelId - The model ID string (not Model object)
*/
export function isGemini3ModelId(modelId?: string): boolean {
if (!modelId) return false
const lowerModelId = modelId.toLowerCase()
return lowerModelId.includes('gemini-2.5') || lowerModelId.includes('gemini-exp') || lowerModelId.includes('gemini-3')
}
/**
* Get the default reasoning tag name based on model ID
*
* Different models use different tags for reasoning content:
* - Most models: 'think'
* - GPT-OSS models: 'reasoning'
* - Gemini models: 'thought'
* - Seed models: 'seed:think'
*/
export function getReasoningTagName(modelId?: string): string {
if (!modelId) return 'think'
const lowerModelId = modelId.toLowerCase()
if (lowerModelId.includes('gpt-oss')) return 'reasoning'
if (lowerModelId.includes('gemini')) return 'thought'
if (lowerModelId.includes('seed-oss-36b')) return 'seed:think'
return 'think'
}
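
For illustration, a sketch of how the returned tag feeds the AI SDK's extractReasoningMiddleware (the model id below is hypothetical):

```typescript
import { extractReasoningMiddleware } from 'ai'

// Hypothetical model id with no special-case mapping, so the tag falls back to 'think'.
const tagName = getReasoningTagName('qwen3-32b') // → 'think'

// A raw completion such as '<think>Compare both options…</think>Answer: B'
// is then split into a reasoning part ('Compare both options…') and a plain
// text part ('Answer: B') instead of leaving the tags inline.
const reasoning = extractReasoningMiddleware({ tagName })
```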
/**
* Skip Gemini Thought Signature Middleware
*
* Because multi-model client requests can switch to other models
* mid-conversation, this middleware skips thought signature validation
* for all Gemini 3 requests.
*
* @param aiSdkId - AI SDK Provider ID (e.g., 'google', 'google-vertex')
* @returns LanguageModelV2Middleware
*/
export function skipGeminiThoughtSignatureMiddleware(aiSdkId: string): LanguageModelV2Middleware {
const MAGIC_STRING = 'skip_thought_signature_validator'
return {
middlewareVersion: 'v2',
transformParams: async ({ params }) => {
const transformedParams = { ...params }
// Process messages in prompt
if (transformedParams.prompt && Array.isArray(transformedParams.prompt)) {
transformedParams.prompt = transformedParams.prompt.map((message) => {
if (typeof message.content !== 'string') {
for (const part of message.content) {
const googleOptions = part?.providerOptions?.[aiSdkId]
if (googleOptions?.thoughtSignature) {
googleOptions.thoughtSignature = MAGIC_STRING
}
}
}
return message
})
}
return transformedParams
}
}
}
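
To make the transform concrete, here is a hypothetical prompt part before and after transformParams runs (shape abbreviated; the signature value is invented):

```typescript
// Hypothetical assistant message part carrying a Gemini thought signature:
const part = {
  type: 'text',
  text: 'Done.',
  providerOptions: { google: { thoughtSignature: 'EqQBCk…' } } // invented value
}
// After transformParams with aiSdkId = 'google', the signature is overwritten
// in place, so the provider skips validating it against whichever model is
// now handling the turn:
// part.providerOptions.google.thoughtSignature === 'skip_thought_signature_validator'
```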
/**
* OpenRouter Reasoning Middleware
*
* Filters out [REDACTED] blocks from OpenRouter reasoning responses.
* OpenRouter may include [REDACTED] markers in reasoning content that
* should be removed for cleaner output.
*
* @see https://openrouter.ai/docs/docs/best-practices/reasoning-tokens
* @returns LanguageModelV2Middleware
*/
export function openrouterReasoningMiddleware(): LanguageModelV2Middleware {
const REDACTED_BLOCK = '[REDACTED]'
return {
middlewareVersion: 'v2',
wrapGenerate: async ({ doGenerate }) => {
const { content, ...rest } = await doGenerate()
const modifiedContent = content.map((part) => {
if (part.type === 'reasoning' && part.text.includes(REDACTED_BLOCK)) {
return {
...part,
text: part.text.replace(REDACTED_BLOCK, '')
}
}
return part
})
return { content: modifiedContent, ...rest }
},
wrapStream: async ({ doStream }) => {
const { stream, ...rest } = await doStream()
return {
stream: stream.pipeThrough(
new TransformStream<LanguageModelV2StreamPart, LanguageModelV2StreamPart>({
transform(
chunk: LanguageModelV2StreamPart,
controller: TransformStreamDefaultController<LanguageModelV2StreamPart>
) {
if (chunk.type === 'reasoning-delta' && chunk.delta.includes(REDACTED_BLOCK)) {
controller.enqueue({
...chunk,
delta: chunk.delta.replace(REDACTED_BLOCK, '')
})
} else {
controller.enqueue(chunk)
}
}
})
),
...rest
}
}
}
}
/**
* Build shared middlewares based on configuration
*
* This function builds a set of middlewares that are commonly needed
* across different environments (renderer, API server).
*
* @param config - Configuration for middleware building
* @returns Array of AI SDK middlewares
*
* @example
* ```typescript
* import { buildSharedMiddlewares } from '@shared/middleware'
*
* const middlewares = buildSharedMiddlewares({
* enableReasoning: true,
* modelId: 'gemini-2.5-pro',
* providerId: 'openrouter',
* aiSdkProviderId: 'google'
* })
* ```
*/
export function buildSharedMiddlewares(config: SharedMiddlewareConfig): LanguageModelV2Middleware[] {
const middlewares: LanguageModelV2Middleware[] = []
// 1. Reasoning extraction middleware
if (config.enableReasoning) {
const tagName = config.reasoningTagName || getReasoningTagName(config.modelId)
middlewares.push(extractReasoningMiddleware({ tagName }))
}
// 2. OpenRouter-specific: filter [REDACTED] blocks
if (config.providerId === 'openrouter' && config.enableReasoning) {
middlewares.push(openrouterReasoningMiddleware())
}
// 3. Gemini 3 (2.5) specific: skip thought signature validation
if (isGemini3ModelId(config.modelId) && config.aiSdkProviderId) {
middlewares.push(skipGeminiThoughtSignatureMiddleware(config.aiSdkProviderId))
}
return middlewares
}
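
Outside the executor path shown later in this commit, the returned array can also be applied directly with the AI SDK's wrapLanguageModel, which accepts an array of middlewares. A minimal sketch, assuming the community @openrouter/ai-sdk-provider package and an API key in the environment:

```typescript
import { wrapLanguageModel } from 'ai'
import { createOpenRouter } from '@openrouter/ai-sdk-provider'

const openrouter = createOpenRouter({ apiKey: process.env.OPENROUTER_API_KEY! })

const model = wrapLanguageModel({
  model: openrouter('google/gemini-2.5-pro'),
  middleware: buildSharedMiddlewares({
    enableReasoning: true,
    modelId: 'gemini-2.5-pro',
    providerId: 'openrouter',
    aiSdkProviderId: 'google'
  })
})
```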

View File

@ -1,4 +1,4 @@
import type { LanguageModelV2Middleware, LanguageModelV2ToolResultOutput } from '@ai-sdk/provider'
import type { ReasoningPart, ToolCallPart, ToolResultPart } from '@ai-sdk/provider-utils'
import type {
ImageBlockParam,
@ -6,7 +6,7 @@ import type {
TextBlockParam,
Tool as AnthropicTool
} from '@anthropic-ai/sdk/resources/messages'
import { type AiPlugin, createExecutor } from '@cherrystudio/ai-core'
import { loggerService } from '@logger'
import { reduxService } from '@main/services/ReduxService'
import { AiSdkToAnthropicSSE, formatSSEDone, formatSSEEvent } from '@shared/adapters'
@ -21,8 +21,8 @@ import {
} from '@shared/provider'
import { defaultAppHeaders } from '@shared/utils'
import type { Provider } from '@types'
import type { ImagePart, ModelMessage, TextPart, Tool } from 'ai'
import { jsonSchema, simulateStreamingMiddleware, stepCountIs, tool } from 'ai'
import { net } from 'electron'
import type { Response } from 'express'
@ -33,6 +33,9 @@ initializeSharedProviders({
error: (message, error) => logger.error(message, error)
})
/**
* Configuration for unified message streaming
*/
export interface UnifiedStreamConfig {
response: Response
provider: Provider
@ -40,12 +43,31 @@ export interface UnifiedStreamConfig {
params: MessageCreateParams
onError?: (error: unknown) => void
onComplete?: () => void
/**
* Optional AI SDK middlewares to apply
*/
middlewares?: LanguageModelV2Middleware[]
/**
* Optional AI Core plugins to use with the executor
*/
plugins?: AiPlugin[]
}
/**
* Configuration for non-streaming message generation
*/
export interface GenerateUnifiedMessageConfig {
provider: Provider
modelId: string
params: MessageCreateParams
middlewares?: LanguageModelV2Middleware[]
plugins?: AiPlugin[]
}
// ============================================================================
// Internal Utilities
// ============================================================================
function getMainProcessFormatContext(): ProviderFormatContext {
const vertexSettings = reduxService.selectSync<{ projectId: string; location: string }>('state.llm.settings.vertexai')
return {
@ -56,12 +78,7 @@ function getMainProcessFormatContext(): ProviderFormatContext {
}
}
/**
* Main process context for providerToAiSdkConfig
* Main process doesn't have access to browser APIs like window.keyv
*/
const mainProcessSdkContext: AiSdkConfigContext = {
// Simple key rotation - just return first key (no persistent rotation in main process)
getRotatedApiKey: (provider) => {
const keys = provider.apiKey.split(',').map((k) => k.trim())
return keys[0] || provider.apiKey
@ -69,199 +86,82 @@ const mainProcessSdkContext: AiSdkConfigContext = {
fetch: net.fetch as typeof globalThis.fetch
}
/**
* Get actual provider configuration for a model
*
* For aggregated providers (new-api, aihubmix, vertexai, azure-openai),
* this resolves the actual provider type based on the model's characteristics.
*/
function getActualProvider(provider: Provider, modelId: string): Provider {
// Find the model in provider's models list
const model = provider.models?.find((m) => m.id === modelId)
// If model not found, return provider as-is
if (!model) return provider
return resolveActualProvider(provider, model)
}
/**
* Convert Cherry Studio Provider to AI SDK config
* Uses shared implementation with main process context
*/
function providerToAiSdkConfig(provider: Provider, modelId: string): AiSdkConfig {
// First resolve actual provider for aggregated providers
const actualProvider = getActualProvider(provider, modelId)
// Format the provider's apiHost for AI SDK
const formattedProvider = formatProviderApiHost(actualProvider, getMainProcessFormatContext())
// Use shared implementation
return sharedProviderToAiSdkConfig(formattedProvider, modelId, mainProcessSdkContext)
}
/**
* Create an AI SDK provider from Cherry Studio provider configuration
*/
async function createAiSdkProvider(config: AiSdkConfig): Promise<AiSdkProvider | null> {
try {
const provider = await createProviderCore(config.providerId, config.options)
logger.debug('AI SDK provider created', {
providerId: config.providerId,
hasOptions: !!config.options
})
return provider
} catch (error) {
logger.error('Failed to create AI SDK provider', error as Error, {
providerId: config.providerId
})
throw error
}
}
/**
* Create an AI SDK language model from a Cherry Studio provider configuration
* Uses shared provider utilities for consistent behavior with renderer
*/
async function createLanguageModel(provider: Provider, modelId: string): Promise<LanguageModel> {
logger.debug('Creating language model', {
providerId: provider.id,
providerType: provider.type,
modelId,
apiHost: provider.apiHost
})
// Convert provider config to AI SDK config
const config = providerToAiSdkConfig(provider, modelId)
// Create the AI SDK provider
const aiSdkProvider = await createAiSdkProvider(config)
if (!aiSdkProvider) {
throw new Error(`Failed to create AI SDK provider for ${provider.id}`)
}
// Get the language model
return aiSdkProvider.languageModel(modelId)
}
function convertAnthropicToolResultToAiSdk(
content: string | Array<TextBlockParam | ImageBlockParam>
): LanguageModelV2ToolResultOutput {
if (typeof content === 'string') {
return { type: 'text', value: content }
}
// 'media' values hold base64 data (or a URL fallback) plus an IANA media type
// @see https://www.iana.org/assignments/media-types/media-types.xhtml
const values: Array<{ type: 'text'; text: string } | { type: 'media'; data: string; mediaType: string }> = []
for (const block of content) {
if (block.type === 'text') {
values.push({ type: 'text', text: block.text })
} else if (block.type === 'image') {
values.push({
type: 'media',
data: block.source.type === 'base64' ? block.source.data : block.source.url,
mediaType: block.source.type === 'base64' ? block.source.media_type : 'image/png'
})
}
}
return { type: 'content', value: values }
}
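
For example (hypothetical values, data truncated), a mixed tool result maps as:

```typescript
// Input: Anthropic tool_result content blocks
//   [
//     { type: 'text', text: '2 matches found' },
//     { type: 'image', source: { type: 'base64', media_type: 'image/png', data: 'iVBOR…' } }
//   ]
// Output: LanguageModelV2ToolResultOutput
//   {
//     type: 'content',
//     value: [
//       { type: 'text', text: '2 matches found' },
//       { type: 'media', data: 'iVBOR…', mediaType: 'image/png' }
//     ]
//   }
```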
/**
* Convert Anthropic tools format to AI SDK tools format
*/
function convertAnthropicToolsToAiSdk(tools: MessageCreateParams['tools']): Record<string, Tool> | undefined {
if (!tools || tools.length === 0) return undefined
const aiSdkTools: Record<string, Tool> = {}
for (const anthropicTool of tools) {
// Skip computer use and bash tools - these are Anthropic-specific
if (anthropicTool.type === 'bash_20250124') continue
// Regular tool (type === 'custom' or no type)
const toolDef = anthropicTool as AnthropicTool
const parameters = toolDef.input_schema as Parameters<typeof jsonSchema>[0]
aiSdkTools[toolDef.name] = tool({
description: toolDef.description || '',
inputSchema: jsonSchema(parameters),
execute: async (input: Record<string, unknown>) => input
})
}
return Object.keys(aiSdkTools).length > 0 ? aiSdkTools : undefined
}
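
Note that execute simply echoes the call input rather than running the tool in the main process. A hypothetical custom tool converts roughly like this:

```typescript
// Hypothetical Anthropic custom tool:
const weatherTool = {
  name: 'get_weather',
  description: 'Get the current weather for a city',
  input_schema: {
    type: 'object' as const,
    properties: { city: { type: 'string' } },
    required: ['city']
  }
}
// Converts to roughly:
//   {
//     get_weather: tool({
//       description: 'Get the current weather for a city',
//       inputSchema: jsonSchema(weatherTool.input_schema),
//       execute: async (input) => input // echo; no execution happens here
//     })
//   }
```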
/**
* Convert Anthropic MessageCreateParams to AI SDK message format
*/
function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage[] {
const messages: ModelMessage[] = []
// Add system message if present
if (params.system) {
if (typeof params.system === 'string') {
messages.push({ role: 'system', content: params.system })
} else if (Array.isArray(params.system)) {
// Handle TextBlockParam array
const systemText = params.system
.filter((block) => block.type === 'text')
.map((block) => block.text)
.join('\n')
if (systemText) {
messages.push({ role: 'system', content: systemText })
}
}
}
// Convert user/assistant messages
for (const msg of params.messages) {
if (typeof msg.content === 'string') {
messages.push({
role: msg.role === 'user' ? 'user' : 'assistant',
content: msg.content
})
} else if (Array.isArray(msg.content)) {
// Handle content blocks
const textParts: TextPart[] = []
const imageParts: ImagePart[] = []
const reasoningParts: ReasoningPart[] = []
@ -278,15 +178,9 @@ function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage
} else if (block.type === 'image') {
const source = block.source
if (source.type === 'base64') {
imageParts.push({ type: 'image', image: `data:${source.media_type};base64,${source.data}` })
} else if (source.type === 'url') {
imageParts.push({ type: 'image', image: source.url })
}
} else if (block.type === 'tool_use') {
toolCallParts.push({
@ -306,30 +200,18 @@ function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage
}
if (toolResultParts.length > 0) {
messages.push({ role: 'tool', content: [...toolResultParts] })
}
// Build the message based on role
// Only push user/assistant message if there's actual content (avoid empty messages)
if (msg.role === 'user') {
const userContent = [...textParts, ...imageParts]
if (userContent.length > 0) {
messages.push({ role: 'user', content: userContent })
}
} else {
// Assistant messages contain tool calls, not tool results
const assistantContent = [...reasoningParts, ...textParts, ...toolCallParts]
if (assistantContent.length > 0) {
messages.push({ role: 'assistant', content: assistantContent })
}
}
}
@ -338,67 +220,54 @@ function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage
return messages
}
interface ExecuteStreamConfig {
provider: Provider
modelId: string
params: MessageCreateParams
middlewares?: LanguageModelV2Middleware[]
plugins?: AiPlugin[]
onEvent?: (event: Parameters<typeof formatSSEEvent>[0]) => void
}
/**
* Core stream execution function - single source of truth for AI SDK calls
*/
async function executeStream(config: ExecuteStreamConfig): Promise<AiSdkToAnthropicSSE> {
const { provider, modelId, params, middlewares = [], plugins = [], onEvent } = config
// Convert provider config to AI SDK config
const sdkConfig = providerToAiSdkConfig(provider, modelId)
logger.debug('Created AI SDK config', {
providerId: sdkConfig.providerId,
hasOptions: !!sdkConfig.options
})
// Create executor with plugins
const executor = createExecutor(sdkConfig.providerId, sdkConfig.options, plugins)
// Convert messages and tools
const coreMessages = convertAnthropicToAiMessages(params)
const tools = convertAnthropicToolsToAiSdk(params.tools)
// Create the adapter
const adapter = new AiSdkToAnthropicSSE({
model: `${provider.id}:${modelId}`,
onEvent: onEvent || (() => {})
})
logger.debug('Converted messages', {
originalCount: params.messages.length,
convertedCount: coreMessages.length,
hasSystem: !!params.system,
hasTools: !!tools,
toolCount: tools ? Object.keys(tools).length : 0,
toolNames: tools ? Object.keys(tools).slice(0, 10) : [],
paramsToolCount: params.tools?.length || 0
})
// Execute stream
const result = await executor.streamText(
{
model: modelId,
messages: coreMessages,
maxOutputTokens: params.max_tokens,
temperature: params.temperature,
@ -408,38 +277,65 @@ export async function streamUnifiedMessages(config: UnifiedStreamConfig): Promis
headers: defaultAppHeaders(),
tools,
providerOptions: {}
},
{ middlewares }
)
// Process the stream through the adapter
await adapter.processStream(result.fullStream)
return adapter
}
/**
* Stream a message request using AI SDK executor and convert to Anthropic SSE format
*/
export async function streamUnifiedMessages(config: UnifiedStreamConfig): Promise<void> {
const { response, provider, modelId, params, onError, onComplete, middlewares = [], plugins = [] } = config
logger.info('Starting unified message stream', {
providerId: provider.id,
providerType: provider.type,
modelId,
stream: params.stream,
middlewareCount: middlewares.length,
pluginCount: plugins.length
})
try {
response.setHeader('Content-Type', 'text/event-stream')
response.setHeader('Cache-Control', 'no-cache')
response.setHeader('Connection', 'keep-alive')
response.setHeader('X-Accel-Buffering', 'no')
await executeStream({
provider,
modelId,
params,
middlewares,
plugins,
onEvent: (event) => {
const sseData = formatSSEEvent(event)
response.write(sseData)
}
})
// Send done marker
response.write(formatSSEDone())
response.end()
logger.info('Unified message stream completed', { providerId: provider.id, modelId })
onComplete?.()
} catch (error) {
logger.error('Error in unified message stream', error as Error, { providerId: provider.id, modelId })
// Try to send error event if response is still writable
if (!response.writableEnded) {
try {
const errorMessage = error instanceof Error ? error.message : 'Unknown error'
response.write(
`event: error\ndata: ${JSON.stringify({
type: 'error',
error: { type: 'api_error', message: errorMessage }
})}\n\n`
)
response.end()
@ -455,64 +351,61 @@ export async function streamUnifiedMessages(config: UnifiedStreamConfig): Promis
/**
* Generate a non-streaming message response
*
* Uses simulateStreamingMiddleware to reuse the same streaming logic,
* similar to renderer's ModernAiProvider pattern.
*/
export async function generateUnifiedMessage(
providerOrConfig: Provider | GenerateUnifiedMessageConfig,
modelId?: string,
params?: MessageCreateParams
): Promise<ReturnType<typeof AiSdkToAnthropicSSE.prototype.buildNonStreamingResponse>> {
// Support both old signature and new config-based signature
let config: GenerateUnifiedMessageConfig
if ('provider' in providerOrConfig && 'modelId' in providerOrConfig && 'params' in providerOrConfig) {
config = providerOrConfig
} else {
config = {
provider: providerOrConfig as Provider,
modelId: modelId!,
params: params!
}
}
const { provider, middlewares = [], plugins = [] } = config
logger.info('Starting unified message generation', {
providerId: provider.id,
providerType: provider.type,
modelId: config.modelId,
middlewareCount: middlewares.length,
pluginCount: plugins.length
})
try {
// Add simulateStreamingMiddleware to reuse streaming logic for non-streaming
const allMiddlewares = [simulateStreamingMiddleware(), ...middlewares]
const adapter = await executeStream({
provider,
modelId: config.modelId,
params: config.params,
middlewares: allMiddlewares,
plugins
})
const finalResponse = adapter.buildNonStreamingResponse()
logger.info('Unified message generation completed', {
providerId: provider.id,
modelId: config.modelId
})
return finalResponse
} catch (error) {
logger.error('Error in unified message generation', error as Error, {
providerId: provider.id,
modelId: config.modelId
})
throw error
}
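
Both call shapes resolve to the same config internally; a usage sketch (the model id is hypothetical):

```typescript
// Legacy positional signature:
await generateUnifiedMessage(provider, 'claude-sonnet-4-5', params)

// Config-based signature, threading in shared middlewares:
await generateUnifiedMessage({
  provider,
  modelId: 'claude-sonnet-4-5',
  params,
  middlewares: buildSharedMiddlewares({ enableReasoning: true, modelId: 'claude-sonnet-4-5' })
})
```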

View File

@ -5,6 +5,7 @@ import type { MCPTool } from '@renderer/types'
import { type Assistant, type Message, type Model, type Provider, SystemProviderIds } from '@renderer/types'
import type { Chunk } from '@renderer/types/chunk'
import { isSupportEnableThinkingProvider } from '@renderer/utils/provider'
import { openrouterReasoningMiddleware, skipGeminiThoughtSignatureMiddleware } from '@shared/middleware'
import type { LanguageModelMiddleware } from 'ai'
import { extractReasoningMiddleware, simulateStreamingMiddleware } from 'ai'
import { isEmpty } from 'lodash'
@ -13,9 +14,7 @@ import { getAiSdkProviderId } from '../provider/factory'
import { isOpenRouterGeminiGenerateImageModel } from '../utils/image'
import { noThinkMiddleware } from './noThinkMiddleware'
import { openrouterGenerateImageMiddleware } from './openrouterGenerateImageMiddleware'
import { qwenThinkingMiddleware } from './qwenThinkingMiddleware'
import { toolChoiceMiddleware } from './toolChoiceMiddleware'
const logger = loggerService.withContext('AiSdkMiddlewareBuilder')

View File

@ -1,50 +0,0 @@
import type { LanguageModelV2StreamPart } from '@ai-sdk/provider'
import type { LanguageModelMiddleware } from 'ai'
/**
* https://openrouter.ai/docs/docs/best-practices/reasoning-tokens#example-preserving-reasoning-blocks-with-openrouter-and-claude
*
* @returns LanguageModelMiddleware - a middleware that filters [REDACTED] blocks
*/
export function openrouterReasoningMiddleware(): LanguageModelMiddleware {
const REDACTED_BLOCK = '[REDACTED]'
return {
middlewareVersion: 'v2',
wrapGenerate: async ({ doGenerate }) => {
const { content, ...rest } = await doGenerate()
const modifiedContent = content.map((part) => {
if (part.type === 'reasoning' && part.text.includes(REDACTED_BLOCK)) {
return {
...part,
text: part.text.replace(REDACTED_BLOCK, '')
}
}
return part
})
return { content: modifiedContent, ...rest }
},
wrapStream: async ({ doStream }) => {
const { stream, ...rest } = await doStream()
return {
stream: stream.pipeThrough(
new TransformStream<LanguageModelV2StreamPart, LanguageModelV2StreamPart>({
transform(
chunk: LanguageModelV2StreamPart,
controller: TransformStreamDefaultController<LanguageModelV2StreamPart>
) {
if (chunk.type === 'reasoning-delta' && chunk.delta.includes(REDACTED_BLOCK)) {
controller.enqueue({
...chunk,
delta: chunk.delta.replace(REDACTED_BLOCK, '')
})
} else {
controller.enqueue(chunk)
}
}
})
),
...rest
}
}
}
}

View File

@ -1,36 +0,0 @@
import type { LanguageModelMiddleware } from 'ai'
/**
* Skip Gemini Thought Signature Middleware (Gemini 3)
* Due to the complexity of multi-model client requests (which can switch to other models mid-process),
* this middleware skips thought signature validation for all Gemini 3 requests.
* @param aiSdkId AI SDK Provider ID
* @returns LanguageModelMiddleware
*/
export function skipGeminiThoughtSignatureMiddleware(aiSdkId: string): LanguageModelMiddleware {
const MAGIC_STRING = 'skip_thought_signature_validator'
return {
middlewareVersion: 'v2',
transformParams: async ({ params }) => {
const transformedParams = { ...params }
// Process messages in prompt
if (transformedParams.prompt && Array.isArray(transformedParams.prompt)) {
transformedParams.prompt = transformedParams.prompt.map((message) => {
if (typeof message.content !== 'string') {
for (const part of message.content) {
const googleOptions = part?.providerOptions?.[aiSdkId]
if (googleOptions?.thoughtSignature) {
googleOptions.thoughtSignature = MAGIC_STRING
}
}
}
return message
})
}
return transformedParams
}
}
}

View File

@ -7,9 +7,11 @@
"src/main/env.d.ts",
"src/renderer/src/types/*",
"packages/shared/**/*",
"packages/aiCore/src/**/*",
"scripts",
"packages/mcp-trace/**/*",
"src/renderer/src/services/traceApi.ts"
"src/renderer/src/services/traceApi.ts",
"packages/ai-sdk-provider/**/*"
],
"compilerOptions": {
"composite": true,