diff --git a/package.json b/package.json
index 9c4398cdc2..d846f69d3a 100644
--- a/package.json
+++ b/package.json
@@ -148,7 +148,7 @@
     "@modelcontextprotocol/sdk": "^1.17.5",
     "@mozilla/readability": "^0.6.0",
     "@notionhq/client": "^2.2.15",
-    "@openrouter/ai-sdk-provider": "^1.1.2",
+    "@openrouter/ai-sdk-provider": "^1.2.0",
     "@opentelemetry/api": "^1.9.0",
     "@opentelemetry/core": "2.0.0",
     "@opentelemetry/exporter-trace-otlp-http": "^0.200.0",
@@ -391,7 +391,8 @@
     "@img/sharp-linux-arm": "0.34.3",
     "@img/sharp-linux-arm64": "0.34.3",
     "@img/sharp-linux-x64": "0.34.3",
-    "@img/sharp-win32-x64": "0.34.3"
+    "@img/sharp-win32-x64": "0.34.3",
+    "openai@npm:5.12.2": "npm:@cherrystudio/openai@6.5.0"
   },
   "packageManager": "yarn@4.9.1",
   "lint-staged": {
diff --git a/src/renderer/src/aiCore/legacy/clients/openai/OpenAIResponseAPIClient.ts b/src/renderer/src/aiCore/legacy/clients/openai/OpenAIResponseAPIClient.ts
index 6c3d12bb41..5d13d6ff70 100644
--- a/src/renderer/src/aiCore/legacy/clients/openai/OpenAIResponseAPIClient.ts
+++ b/src/renderer/src/aiCore/legacy/clients/openai/OpenAIResponseAPIClient.ts
@@ -342,29 +342,28 @@ export class OpenAIResponseAPIClient extends OpenAIBaseClient<
        }
      }
      switch (message.type) {
-        case 'function_call_output':
-          {
-            let str = ''
-            if (typeof message.output === 'string') {
-              str = message.output
-            } else {
-              for (const part of message.output) {
-                switch (part.type) {
-                  case 'input_text':
-                    str += part.text
-                    break
-                  case 'input_image':
-                    str += part.image_url || ''
-                    break
-                  case 'input_file':
-                    str += part.file_data || ''
-                    break
-                }
+        case 'function_call_output': {
+          let str = ''
+          if (typeof message.output === 'string') {
+            str = message.output
+          } else {
+            for (const part of message.output) {
+              switch (part.type) {
+                case 'input_text':
+                  str += part.text
+                  break
+                case 'input_image':
+                  str += part.image_url || ''
+                  break
+                case 'input_file':
+                  str += part.file_data || ''
+                  break
              }
            }
-            sum += estimateTextTokens(str)
          }
+          sum += estimateTextTokens(str)
          break
+        }
        case 'function_call':
          sum += estimateTextTokens(message.arguments)
          break
diff --git a/src/renderer/src/aiCore/legacy/middleware/feat/ImageGenerationMiddleware.ts b/src/renderer/src/aiCore/legacy/middleware/feat/ImageGenerationMiddleware.ts
index de8034d514..40ab43c561 100644
--- a/src/renderer/src/aiCore/legacy/middleware/feat/ImageGenerationMiddleware.ts
+++ b/src/renderer/src/aiCore/legacy/middleware/feat/ImageGenerationMiddleware.ts
@@ -78,6 +78,12 @@ export const ImageGenerationMiddleware: CompletionsMiddleware =
        const options = { signal, timeout: defaultTimeout }

        if (imageFiles.length > 0) {
+          const model = assistant.model
+          const provider = context.apiClientInstance.provider
+          // https://learn.microsoft.com/en-us/azure/ai-foundry/openai/how-to/dall-e?tabs=gpt-image-1#call-the-image-edit-api
+          if (model.id.toLowerCase().includes('gpt-image-1-mini') && provider.type === 'azure-openai') {
+            throw new Error('Azure OpenAI GPT-Image-1-Mini model does not support image editing.')
+          }
          response = await sdk.images.edit(
            {
              model: assistant.model.id,
diff --git a/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts b/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts
index 46cb0af6f0..924cc5f47e 100644
--- a/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts
+++ b/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts
@@ -1,10 +1,12 @@
 import { WebSearchPluginConfig } from '@cherrystudio/ai-core/built-in/plugins'
 import { loggerService } from '@logger'
-import type { MCPTool, Message, Model, Provider } from '@renderer/types'
+import { type MCPTool, type Message, type Model, type Provider } from '@renderer/types'
 import type { Chunk } from '@renderer/types/chunk'
 import { extractReasoningMiddleware, LanguageModelMiddleware, simulateStreamingMiddleware } from 'ai'

+import { isOpenRouterGeminiGenerateImageModel } from '../utils/image'
 import { noThinkMiddleware } from './noThinkMiddleware'
+import { openrouterGenerateImageMiddleware } from './openrouterGenerateImageMiddleware'
 import { toolChoiceMiddleware } from './toolChoiceMiddleware'

 const logger = loggerService.withContext('AiSdkMiddlewareBuilder')
@@ -213,15 +215,16 @@ function addProviderSpecificMiddlewares(builder: AiSdkMiddlewareBuilder, config:
 /**
  * Add model-specific middlewares
  */
-function addModelSpecificMiddlewares(_: AiSdkMiddlewareBuilder, config: AiSdkMiddlewareConfig): void {
-  if (!config.model) return
+function addModelSpecificMiddlewares(builder: AiSdkMiddlewareBuilder, config: AiSdkMiddlewareConfig): void {
+  if (!config.model || !config.provider) return

   // Specific middlewares can be added here based on model ID or capabilities
   // e.g. image generation models, multimodal models, etc.
-
-  // Example: some models need special handling
-  if (config.model.id.includes('dalle') || config.model.id.includes('midjourney')) {
-    // Image-generation-related middleware
+  if (isOpenRouterGeminiGenerateImageModel(config.model, config.provider)) {
+    builder.add({
+      name: 'openrouter-gemini-image-generation',
+      middleware: openrouterGenerateImageMiddleware()
+    })
   }
 }

diff --git a/src/renderer/src/aiCore/middleware/openrouterGenerateImageMiddleware.ts b/src/renderer/src/aiCore/middleware/openrouterGenerateImageMiddleware.ts
new file mode 100644
index 0000000000..0110d9a4f0
--- /dev/null
+++ b/src/renderer/src/aiCore/middleware/openrouterGenerateImageMiddleware.ts
@@ -0,0 +1,36 @@
+import { LanguageModelMiddleware } from 'ai'
+
+/**
+ * Returns a LanguageModelMiddleware that ensures the OpenRouter provider is configured to support both
+ * image and text modalities.
+ * https://openrouter.ai/docs/features/multimodal/image-generation
+ *
+ * Remarks:
+ * - The middleware declares middlewareVersion as 'v2'.
+ * - transformParams asynchronously clones the incoming params and sets
+ *   providerOptions.openrouter.modalities = ['image', 'text'], preserving other providerOptions and
+ *   openrouter fields when present.
+ * - Intended to ensure the provider can handle image and text generation without altering other
+ *   parameter values.
+ *
+ * @returns LanguageModelMiddleware - a middleware that augments providerOptions for OpenRouter to include image and text modalities.
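+ *
+ * @example
+ * // Illustrative sketch only; assumes `model` is a language model created by @openrouter/ai-sdk-provider.
+ * // const wrapped = wrapLanguageModel({ model, middleware: openrouterGenerateImageMiddleware() })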
+ */
+export function openrouterGenerateImageMiddleware(): LanguageModelMiddleware {
+  return {
+    middlewareVersion: 'v2',
+
+    transformParams: async ({ params }) => {
+      const transformedParams = { ...params }
+      transformedParams.providerOptions = {
+        ...transformedParams.providerOptions,
+        openrouter: { ...transformedParams.providerOptions?.openrouter, modalities: ['image', 'text'] }
+      }
+
+      return transformedParams
+    }
+  }
+}
diff --git a/src/renderer/src/aiCore/prepareParams/messageConverter.ts b/src/renderer/src/aiCore/prepareParams/messageConverter.ts
index 4c2d5baba6..46cacb5b74 100644
--- a/src/renderer/src/aiCore/prepareParams/messageConverter.ts
+++ b/src/renderer/src/aiCore/prepareParams/messageConverter.ts
@@ -4,7 +4,7 @@
  */

 import { loggerService } from '@logger'
-import { isVisionModel } from '@renderer/config/models'
+import { isImageEnhancementModel, isVisionModel } from '@renderer/config/models'
 import type { Message, Model } from '@renderer/types'
 import { FileMessageBlock, ImageMessageBlock, ThinkingMessageBlock } from '@renderer/types/newMessage'
 import {
@@ -47,6 +47,42 @@ export async function convertMessageToSdkParam(
   }
 }

+async function convertImageBlockToImagePart(imageBlocks: ImageMessageBlock[]): Promise<Array<ImagePart>> {
+  const parts: Array<ImagePart> = []
+  for (const imageBlock of imageBlocks) {
+    if (imageBlock.file) {
+      try {
+        const image = await window.api.file.base64Image(imageBlock.file.id + imageBlock.file.ext)
+        parts.push({
+          type: 'image',
+          image: image.base64,
+          mediaType: image.mime
+        })
+      } catch (error) {
+        logger.warn('Failed to load image:', error as Error)
+      }
+    } else if (imageBlock.url) {
+      const isBase64 = imageBlock.url.startsWith('data:')
+      if (isBase64) {
+        // e.g. 'data:image/png;base64,iVBORw...' yields the base64 payload and the mime type 'image/png'
+        const base64 = imageBlock.url.match(/^data:[^;]*;base64,(.+)$/)![1]
+        const mimeMatch = imageBlock.url.match(/^data:([^;]+)/)
+        parts.push({
+          type: 'image',
+          image: base64,
+          mediaType: mimeMatch ? mimeMatch[1] : 'image/png'
+        })
+      } else {
+        parts.push({
+          type: 'image',
+          image: imageBlock.url
+        })
+      }
+    }
+  }
+  return parts
+}
+
 /**
  * Convert to a user model message
  */
@@ -64,25 +100,7 @@ async function convertMessageToUserModelMessage(

   // Handle images (only for vision-capable models)
   if (isVisionModel) {
-    for (const imageBlock of imageBlocks) {
-      if (imageBlock.file) {
-        try {
-          const image = await window.api.file.base64Image(imageBlock.file.id + imageBlock.file.ext)
-          parts.push({
-            type: 'image',
-            image: image.base64,
-            mediaType: image.mime
-          })
-        } catch (error) {
-          logger.warn('Failed to load image:', error as Error)
-        }
-      } else if (imageBlock.url) {
-        parts.push({
-          type: 'image',
-          image: imageBlock.url
-        })
-      }
-    }
+    parts.push(...(await convertImageBlockToImagePart(imageBlocks)))
   }
   // Handle files
   for (const fileBlock of fileBlocks) {
@@ -172,7 +190,32 @@ async function convertMessageToAssistantModelMessage(
 }

 /**
- * Convert a Cherry Studio message array to an AI SDK message array
+ * Converts an array of messages to SDK-compatible model messages.
+ *
+ * This function processes messages and transforms them into the format required by the SDK.
+ * It handles special cases for vision models and image enhancement models.
+ *
+ * @param messages - Array of messages to convert. Must contain at least 3 messages when using image enhancement models.
+ * @param model - The model configuration that determines conversion behavior
+ *
+ * @returns A promise that resolves to an array of SDK-compatible model messages
+ *
+ * @remarks
+ * For image enhancement models with 3+ messages:
+ * - Expects the second-to-last message (index length-2) to be an assistant message containing image blocks
+ * - Expects the last message (index length-1) to be a user message
+ * - Extracts images from the assistant message and appends them to the user message content
+ * - Returns only the last two processed messages, preceded by the first system message when one exists
+ *
+ * For other models:
+ * - Returns all converted messages in order
+ *
+ * The function automatically detects vision model capabilities and adjusts conversion accordingly.
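+ *
+ * @example
+ * // Hypothetical shape of the reshaping for an image enhancement model:
+ * //   input:  [system, user, assistant(with images), user]
+ * //   output: [system, assistant(with images), user + extracted image parts]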
+ */
 export async function convertMessagesToSdkMessages(messages: Message[], model: Model): Promise<ModelMessage[]> {
   const sdkMessages: ModelMessage[] = []
@@ -182,6 +225,31 @@ export async function convertMessagesToSdkMessages(messages: Message[], model: M
     const sdkMessage = await convertMessageToSdkParam(message, isVision, model)
     sdkMessages.push(...(Array.isArray(sdkMessage) ? sdkMessage : [sdkMessage]))
   }
+  // Special handling for image enhancement models
+  // Only keep the last two messages and merge images into the user message
+  // [system?, user, assistant, user]
+  if (isImageEnhancementModel(model) && messages.length >= 3) {
+    const needUpdatedMessages = messages.slice(-2)
+    const needUpdatedSdkMessages = sdkMessages.slice(-2)
+    const assistantMessage = needUpdatedMessages.filter((m) => m.role === 'assistant')[0]
+    const assistantSdkMessage = needUpdatedSdkMessages.filter((m) => m.role === 'assistant')[0]
+    const userSdkMessage = needUpdatedSdkMessages.filter((m) => m.role === 'user')[0]
+    const systemSdkMessages = sdkMessages.filter((m) => m.role === 'system')
+    const imageBlocks = findImageBlocks(assistantMessage)
+    const imageParts = await convertImageBlockToImagePart(imageBlocks)
+    const parts: Array<TextPart | ImagePart> = []
+    if (typeof userSdkMessage.content === 'string') {
+      parts.push({ type: 'text', text: userSdkMessage.content })
+      parts.push(...imageParts)
+      userSdkMessage.content = parts
+    } else {
+      userSdkMessage.content.push(...imageParts)
+    }
+    if (systemSdkMessages.length > 0) {
+      return [systemSdkMessages[0], assistantSdkMessage, userSdkMessage]
+    }
+    return [assistantSdkMessage, userSdkMessage]
+  }
   return sdkMessages
 }

diff --git a/src/renderer/src/aiCore/utils/image.ts b/src/renderer/src/aiCore/utils/image.ts
index 7691f9d4b1..43d916640a 100644
--- a/src/renderer/src/aiCore/utils/image.ts
+++ b/src/renderer/src/aiCore/utils/image.ts
@@ -1,5 +1,18 @@
+import { isSystemProvider, Model, Provider, SystemProviderIds } from '@renderer/types'
+
 export function buildGeminiGenerateImageParams(): Record<string, any> {
   return {
     responseModalities: ['TEXT', 'IMAGE']
   }
 }
+
+export function isOpenRouterGeminiGenerateImageModel(model: Model, provider: Provider): boolean {
+  return (
+    model.id.includes('gemini-2.5-flash-image') &&
+    isSystemProvider(provider) &&
+    provider.id === SystemProviderIds.openrouter
+  )
+}
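+
+// Example with hypothetical ids: a model 'google/gemini-2.5-flash-image-preview' on the built-in
+// OpenRouter provider matches; the same id on a user-defined (non-system) provider does not.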
diff --git a/src/renderer/src/config/models/vision.ts b/src/renderer/src/config/models/vision.ts
index ceff0a10c3..19e6ce6047 100644
--- a/src/renderer/src/config/models/vision.ts
+++ b/src/renderer/src/config/models/vision.ts
@@ -83,7 +83,7 @@ export const IMAGE_ENHANCEMENT_MODELS = [
   'grok-2-image(?:-[\\w-]+)?',
   'qwen-image-edit',
   'gpt-image-1',
-  'gemini-2.5-flash-image',
+  'gemini-2.5-flash-image(?:-[\\w-]+)?',
   'gemini-2.0-flash-preview-image-generation'
 ]

diff --git a/yarn.lock b/yarn.lock
index 672c4af35f..79909f1781 100644
--- a/yarn.lock
+++ b/yarn.lock
@@ -7116,13 +7116,13 @@ __metadata:
   languageName: node
   linkType: hard

-"@openrouter/ai-sdk-provider@npm:^1.1.2":
-  version: 1.1.2
-  resolution: "@openrouter/ai-sdk-provider@npm:1.1.2"
+"@openrouter/ai-sdk-provider@npm:^1.2.0":
+  version: 1.2.0
+  resolution: "@openrouter/ai-sdk-provider@npm:1.2.0"
   peerDependencies:
     ai: ^5.0.0
     zod: ^3.24.1 || ^v4
-  checksum: 10c0/1ad50804189910d52c2c10e479bec40dfbd2109820e43135d001f4f8706be6ace532d4769a8c30111f5870afdfa97b815c7334b2e4d8d36ca68b1578ce5d9a41
+  checksum: 10c0/4ca7c471ec46bdd48eea9c56d94778a06ca4b74b6ef2ab892ab7eadbd409e3530ac0c5791cd80e88cafc44a49a76585e59707104792e3e3124237fed767104ef
   languageName: node
   linkType: hard

@@ -13902,7 +13902,7 @@ __metadata:
     "@mozilla/readability": "npm:^0.6.0"
     "@napi-rs/system-ocr": "patch:@napi-rs/system-ocr@npm%3A1.0.2#~/.yarn/patches/@napi-rs-system-ocr-npm-1.0.2-59e7a78e8b.patch"
     "@notionhq/client": "npm:^2.2.15"
-    "@openrouter/ai-sdk-provider": "npm:^1.1.2"
+    "@openrouter/ai-sdk-provider": "npm:^1.2.0"
     "@opentelemetry/api": "npm:^1.9.0"
     "@opentelemetry/core": "npm:2.0.0"
     "@opentelemetry/exporter-trace-otlp-http": "npm:^0.200.0"
@@ -23907,23 +23907,6 @@ __metadata:
   languageName: node
   linkType: hard

-"openai@npm:5.12.2":
-  version: 5.12.2
-  resolution: "openai@npm:5.12.2"
-  peerDependencies:
-    ws: ^8.18.0
-    zod: ^3.23.8
-  peerDependenciesMeta:
-    ws:
-      optional: true
-    zod:
-      optional: true
-  bin:
-    openai: bin/cli
-  checksum: 10c0/7737b9b24edc81fcf9e6dcfb18a196cc0f8e29b6e839adf06a2538558c03908e3aa4cd94901b1a7f4a9dd62676fe9e34d6202281b2395090d998618ea1614c0c
-  languageName: node
-  linkType: hard
-
 "openapi-types@npm:^12.1.3":
   version: 12.1.3
   resolution: "openapi-types@npm:12.1.3"