mirror of https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-19 06:30:10 +08:00
fix: azure gpt-image-1 and openrouter gemini-image (#10797)
* fix: azure gpt-image-1 and openrouter gemini-image
* feat: update encoding format handling for embeddings based on model type
* fix: normalize model ID check for Azure OpenAI GPT-Image-1-Mini
* feat: enhance regex for gemini-2.5-flash-image in image enhancement models
* feat: support base64-format image URLs in message conversion
* feat: update message conversion to handle image enhancement models specially
* fix: update model handling in AzureOpenAI and Embeddings classes
* fix: update OpenAI package version to 6.5.0
* fix: remove outdated OpenAI package patch for version 6.4.0
* fix: remove outdated OpenAI package entry from yarn.lock
This commit is contained in:
parent 6795a044fa
commit ac4aa33e79
@@ -148,7 +148,7 @@
     "@modelcontextprotocol/sdk": "^1.17.5",
     "@mozilla/readability": "^0.6.0",
     "@notionhq/client": "^2.2.15",
-    "@openrouter/ai-sdk-provider": "^1.1.2",
+    "@openrouter/ai-sdk-provider": "^1.2.0",
     "@opentelemetry/api": "^1.9.0",
     "@opentelemetry/core": "2.0.0",
     "@opentelemetry/exporter-trace-otlp-http": "^0.200.0",
@@ -391,7 +391,8 @@
     "@img/sharp-linux-arm": "0.34.3",
     "@img/sharp-linux-arm64": "0.34.3",
     "@img/sharp-linux-x64": "0.34.3",
-    "@img/sharp-win32-x64": "0.34.3"
+    "@img/sharp-win32-x64": "0.34.3",
+    "openai@npm:5.12.2": "npm:@cherrystudio/openai@6.5.0"
   },
   "packageManager": "yarn@4.9.1",
   "lint-staged": {
@@ -342,29 +342,28 @@ export class OpenAIResponseAPIClient extends OpenAIBaseClient<
         }
       }
       switch (message.type) {
-        case 'function_call_output':
-          {
-            let str = ''
-            if (typeof message.output === 'string') {
-              str = message.output
-            } else {
-              for (const part of message.output) {
-                switch (part.type) {
-                  case 'input_text':
-                    str += part.text
-                    break
-                  case 'input_image':
-                    str += part.image_url || ''
-                    break
-                  case 'input_file':
-                    str += part.file_data || ''
-                    break
-                }
-              }
-              sum += estimateTextTokens(str)
-            }
-            break
-          }
+        case 'function_call_output': {
+          let str = ''
+          if (typeof message.output === 'string') {
+            str = message.output
+          } else {
+            for (const part of message.output) {
+              switch (part.type) {
+                case 'input_text':
+                  str += part.text
+                  break
+                case 'input_image':
+                  str += part.image_url || ''
+                  break
+                case 'input_file':
+                  str += part.file_data || ''
+                  break
+              }
+            }
+          }
+          sum += estimateTextTokens(str)
+          break
+        }
         case 'function_call':
           sum += estimateTextTokens(message.arguments)
           break
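Beyond the brace restyle, this hunk (as reconstructed above) moves sum += estimateTextTokens(str) out of the else branch, so string-valued function_call_output messages are now counted as well. A minimal sketch of that difference, with a stand-in estimator in place of the repo's real estimateTextTokens:

// Stand-in token estimator, for illustration only.
const estimateTextTokens = (text: string): number => Math.ceil(text.length / 4)

function countOutputTokensOld(output: string | { type: string; text?: string }[]): number {
  let sum = 0
  let str = ''
  if (typeof output === 'string') {
    str = output
  } else {
    for (const part of output) {
      if (part.type === 'input_text') str += part.text ?? ''
    }
    sum += estimateTextTokens(str) // only reached for array outputs
  }
  return sum
}

function countOutputTokensNew(output: string | { type: string; text?: string }[]): number {
  let str = ''
  if (typeof output === 'string') {
    str = output
  } else {
    for (const part of output) {
      if (part.type === 'input_text') str += part.text ?? ''
    }
  }
  return estimateTextTokens(str) // counted for both shapes
}

console.log(countOutputTokensOld('twelve chars')) // 0 - string outputs were skipped
console.log(countOutputTokensNew('twelve chars')) // 3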
@@ -78,6 +78,12 @@ export const ImageGenerationMiddleware: CompletionsMiddleware =
       const options = { signal, timeout: defaultTimeout }

       if (imageFiles.length > 0) {
+        const model = assistant.model
+        const provider = context.apiClientInstance.provider
+        // https://learn.microsoft.com/en-us/azure/ai-foundry/openai/how-to/dall-e?tabs=gpt-image-1#call-the-image-edit-api
+        if (model.id.toLowerCase().includes('gpt-image-1-mini') && provider.type === 'azure-openai') {
+          throw new Error('Azure OpenAI GPT-Image-1-Mini model does not support image editing.')
+        }
         response = await sdk.images.edit(
           {
             model: assistant.model.id,
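The toLowerCase() call is the "normalize model ID check" item from the commit message: Azure deployment names are user-chosen, so the same model can appear under different casings. A sketch of the guard in isolation (isAzureGptImage1MiniEdit is a hypothetical helper name, introduced here only for illustration):

// Hypothetical helper demonstrating the case-insensitive guard.
type ProviderType = 'azure-openai' | 'openai'

function isAzureGptImage1MiniEdit(modelId: string, providerType: ProviderType): boolean {
  return modelId.toLowerCase().includes('gpt-image-1-mini') && providerType === 'azure-openai'
}

console.log(isAzureGptImage1MiniEdit('GPT-Image-1-Mini', 'azure-openai')) // true - edit request would throw
console.log(isAzureGptImage1MiniEdit('gpt-image-1', 'azure-openai'))      // false - edit allowed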
@@ -1,10 +1,12 @@
 import { WebSearchPluginConfig } from '@cherrystudio/ai-core/built-in/plugins'
 import { loggerService } from '@logger'
-import type { MCPTool, Message, Model, Provider } from '@renderer/types'
+import { type MCPTool, type Message, type Model, type Provider } from '@renderer/types'
 import type { Chunk } from '@renderer/types/chunk'
 import { extractReasoningMiddleware, LanguageModelMiddleware, simulateStreamingMiddleware } from 'ai'

+import { isOpenRouterGeminiGenerateImageModel } from '../utils/image'
 import { noThinkMiddleware } from './noThinkMiddleware'
+import { openrouterGenerateImageMiddleware } from './openrouterGenerateImageMiddleware'
 import { toolChoiceMiddleware } from './toolChoiceMiddleware'

 const logger = loggerService.withContext('AiSdkMiddlewareBuilder')
@@ -213,15 +215,16 @@ function addProviderSpecificMiddlewares(builder: AiSdkMiddlewareBuilder, config:
 /**
  * Add model-specific middlewares
  */
-function addModelSpecificMiddlewares(_: AiSdkMiddlewareBuilder, config: AiSdkMiddlewareConfig): void {
-  if (!config.model) return
+function addModelSpecificMiddlewares(builder: AiSdkMiddlewareBuilder, config: AiSdkMiddlewareConfig): void {
+  if (!config.model || !config.provider) return

   // Specific middlewares can be added based on model ID or capabilities,
   // e.g. image generation models, multimodal models, etc.

-  // Example: some models need special handling
-  if (config.model.id.includes('dalle') || config.model.id.includes('midjourney')) {
-    // Image-generation-related middleware
+  if (isOpenRouterGeminiGenerateImageModel(config.model, config.provider)) {
+    builder.add({
+      name: 'openrouter-gemini-image-generation',
+      middleware: openrouterGenerateImageMiddleware()
+    })
   }
 }
@@ -0,0 +1,33 @@
+import { LanguageModelMiddleware } from 'ai'
+
+/**
+ * Returns a LanguageModelMiddleware that ensures the OpenRouter provider is configured to support both
+ * image and text modalities.
+ * https://openrouter.ai/docs/features/multimodal/image-generation
+ *
+ * Remarks:
+ * - The middleware declares middlewareVersion as 'v2'.
+ * - transformParams asynchronously clones the incoming params and sets
+ *   providerOptions.openrouter.modalities = ['image', 'text'], preserving other providerOptions and
+ *   openrouter fields when present.
+ * - Intended to ensure the provider can handle image and text generation without altering other
+ *   parameter values.
+ *
+ * @returns LanguageModelMiddleware - a middleware that augments providerOptions for OpenRouter to include image and text modalities.
+ */
+export function openrouterGenerateImageMiddleware(): LanguageModelMiddleware {
+  return {
+    middlewareVersion: 'v2',
+
+    transformParams: async ({ params }) => {
+      const transformedParams = { ...params }
+      transformedParams.providerOptions = {
+        ...transformedParams.providerOptions,
+        openrouter: { ...transformedParams.providerOptions?.openrouter, modalities: ['image', 'text'] }
+      }
+
+      return transformedParams
+    }
+  }
+}
+
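A quick sketch of what transformParams does to a request. The param shape below is pared down for illustration; the real AI SDK types are richer:

// Pared-down param shape, for illustration only.
type DemoParams = { prompt?: string; providerOptions?: Record<string, Record<string, unknown>> }

function demo() {
  const params: DemoParams = {
    prompt: 'draw a cat',
    providerOptions: { openrouter: { reasoning: { enabled: false } } }
  }
  // Same merge the middleware performs: existing openrouter fields survive,
  // and modalities is added alongside them.
  const transformed = {
    ...params,
    providerOptions: {
      ...params.providerOptions,
      openrouter: { ...params.providerOptions?.openrouter, modalities: ['image', 'text'] }
    }
  }
  console.log(transformed.providerOptions.openrouter)
  // { reasoning: { enabled: false }, modalities: [ 'image', 'text' ] }
}

demo()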
@@ -4,7 +4,7 @@
  */

 import { loggerService } from '@logger'
-import { isVisionModel } from '@renderer/config/models'
+import { isImageEnhancementModel, isVisionModel } from '@renderer/config/models'
 import type { Message, Model } from '@renderer/types'
 import { FileMessageBlock, ImageMessageBlock, ThinkingMessageBlock } from '@renderer/types/newMessage'
 import {
@@ -47,6 +47,41 @@ export async function convertMessageToSdkParam(
   }
 }

+async function convertImageBlockToImagePart(imageBlocks: ImageMessageBlock[]): Promise<Array<ImagePart>> {
+  const parts: Array<ImagePart> = []
+  for (const imageBlock of imageBlocks) {
+    if (imageBlock.file) {
+      try {
+        const image = await window.api.file.base64Image(imageBlock.file.id + imageBlock.file.ext)
+        parts.push({
+          type: 'image',
+          image: image.base64,
+          mediaType: image.mime
+        })
+      } catch (error) {
+        logger.warn('Failed to load image:', error as Error)
+      }
+    } else if (imageBlock.url) {
+      const isBase64 = imageBlock.url.startsWith('data:')
+      if (isBase64) {
+        const base64 = imageBlock.url.match(/^data:[^;]*;base64,(.+)$/)![1]
+        const mimeMatch = imageBlock.url.match(/^data:([^;]+)/)
+        parts.push({
+          type: 'image',
+          image: base64,
+          mediaType: mimeMatch ? mimeMatch[1] : 'image/png'
+        })
+      } else {
+        parts.push({
+          type: 'image',
+          image: imageBlock.url
+        })
+      }
+    }
+  }
+  return parts
+}
+
 /**
  * Convert to a user model message
  */
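The new base64 branch splits a data URL into MIME type and payload with two regular expressions; here is the same parsing run standalone on a tiny example string:

// Standalone run of the data-URL parsing used above.
const url = 'data:image/jpeg;base64,/9j/4AAQSkZJRg=='

if (url.startsWith('data:')) {
  const base64 = url.match(/^data:[^;]*;base64,(.+)$/)![1] // payload after the comma
  const mimeMatch = url.match(/^data:([^;]+)/)             // MIME type between 'data:' and ';'
  console.log(mimeMatch ? mimeMatch[1] : 'image/png')      // image/jpeg
  console.log(base64)                                      // /9j/4AAQSkZJRg==
}

Note that the non-null assertion assumes every data: URL reaching this branch is base64-encoded; a plain data URL such as data:text/plain,hi would make the first match return null and throw.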
@@ -64,25 +99,7 @@ async function convertMessageToUserModelMessage(

   // Handle images (only for vision-capable models)
   if (isVisionModel) {
-    for (const imageBlock of imageBlocks) {
-      if (imageBlock.file) {
-        try {
-          const image = await window.api.file.base64Image(imageBlock.file.id + imageBlock.file.ext)
-          parts.push({
-            type: 'image',
-            image: image.base64,
-            mediaType: image.mime
-          })
-        } catch (error) {
-          logger.warn('Failed to load image:', error as Error)
-        }
-      } else if (imageBlock.url) {
-        parts.push({
-          type: 'image',
-          image: imageBlock.url
-        })
-      }
-    }
+    parts.push(...(await convertImageBlockToImagePart(imageBlocks)))
   }
   // Handle files
   for (const fileBlock of fileBlocks) {
@@ -172,7 +189,27 @@ async function convertMessageToAssistantModelMessage(
 }

 /**
- * Convert a Cherry Studio message array to an AI SDK message array
+ * Converts an array of messages to SDK-compatible model messages.
+ *
+ * This function processes messages and transforms them into the format required by the SDK.
+ * It handles special cases for vision models and image enhancement models.
+ *
+ * @param messages - Array of messages to convert. Must contain at least 2 messages when using image enhancement models.
+ * @param model - The model configuration that determines conversion behavior
+ *
+ * @returns A promise that resolves to an array of SDK-compatible model messages
+ *
+ * @remarks
+ * For image enhancement models with 2+ messages:
+ * - Expects the second-to-last message (index length-2) to be an assistant message containing image blocks
+ * - Expects the last message (index length-1) to be a user message
+ * - Extracts images from the assistant message and appends them to the user message content
+ * - Returns only the last two processed messages [assistantSdkMessage, userSdkMessage]
+ *
+ * For other models:
+ * - Returns all converted messages in order
+ *
+ * The function automatically detects vision model capabilities and adjusts conversion accordingly.
  */
 export async function convertMessagesToSdkMessages(messages: Message[], model: Model): Promise<ModelMessage[]> {
   const sdkMessages: ModelMessage[] = []
@@ -182,6 +219,31 @@ export async function convertMessagesToSdkMessages(messages: Message[], model: M
     const sdkMessage = await convertMessageToSdkParam(message, isVision, model)
     sdkMessages.push(...(Array.isArray(sdkMessage) ? sdkMessage : [sdkMessage]))
   }
+  // Special handling for image enhancement models
+  // Only keep the last two messages and merge images into the user message
+  // [system?, user, assistant, user]
+  if (isImageEnhancementModel(model) && messages.length >= 3) {
+    const needUpdatedMessages = messages.slice(-2)
+    const needUpdatedSdkMessages = sdkMessages.slice(-2)
+    const assistantMessage = needUpdatedMessages.filter((m) => m.role === 'assistant')[0]
+    const assistantSdkMessage = needUpdatedSdkMessages.filter((m) => m.role === 'assistant')[0]
+    const userSdkMessage = needUpdatedSdkMessages.filter((m) => m.role === 'user')[0]
+    const systemSdkMessages = sdkMessages.filter((m) => m.role === 'system')
+    const imageBlocks = findImageBlocks(assistantMessage)
+    const imageParts = await convertImageBlockToImagePart(imageBlocks)
+    const parts: Array<TextPart | ImagePart | FilePart> = []
+    if (typeof userSdkMessage.content === 'string') {
+      parts.push({ type: 'text', text: userSdkMessage.content })
+      parts.push(...imageParts)
+      userSdkMessage.content = parts
+    } else {
+      userSdkMessage.content.push(...imageParts)
+    }
+    if (systemSdkMessages.length > 0) {
+      return [systemSdkMessages[0], assistantSdkMessage, userSdkMessage]
+    }
+    return [assistantSdkMessage, userSdkMessage]
+  }

   return sdkMessages
 }
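A sketch of the resulting shape for an image-enhancement turn, with hand-rolled stand-ins for the real Message and ModelMessage types:

// Stand-in types: [system, user, assistant(with image), user] becomes
// [system, assistant, user(with the assistant's image merged in)].
type Part = { type: 'text'; text: string } | { type: 'image'; image: string }
type Msg = { role: 'system' | 'user' | 'assistant'; content: string | Part[] }

const history: Msg[] = [
  { role: 'system', content: 'You edit images.' },
  { role: 'user', content: 'Generate a cat.' },
  { role: 'assistant', content: [{ type: 'image', image: 'data:image/png;base64,...' }] },
  { role: 'user', content: 'Now make it wear a hat.' }
]

const lastTwo = history.slice(-2)
const assistant = lastTwo.find((m) => m.role === 'assistant')!
const user = lastTwo.find((m) => m.role === 'user')!
const system = history.find((m) => m.role === 'system')

const imageParts = (assistant.content as Part[]).filter((p) => p.type === 'image')
user.content =
  typeof user.content === 'string'
    ? [{ type: 'text', text: user.content } as Part, ...imageParts]
    : [...user.content, ...imageParts]

const result = system ? [system, assistant, user] : [assistant, user]
console.log(result.map((m) => m.role)) // [ 'system', 'assistant', 'user' ]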
@@ -1,5 +1,15 @@
+import { isSystemProvider, Model, Provider, SystemProviderIds } from '@renderer/types'
+
 export function buildGeminiGenerateImageParams(): Record<string, any> {
   return {
     responseModalities: ['TEXT', 'IMAGE']
   }
 }
+
+export function isOpenRouterGeminiGenerateImageModel(model: Model, provider: Provider): boolean {
+  return (
+    model.id.includes('gemini-2.5-flash-image') &&
+    isSystemProvider(provider) &&
+    provider.id === SystemProviderIds.openrouter
+  )
+}
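The dual check matters: matching on provider.id alone would also catch a user-defined provider that happens to be named openrouter. A sketch with simplified stand-ins for the repo's Provider type and isSystemProvider guard:

// Simplified stand-ins; the real guard lives in @renderer/types.
type DemoProvider = { id: string; isSystem?: boolean }
const isSystemProviderDemo = (p: DemoProvider): boolean => p.isSystem === true

function isOpenRouterGeminiImage(modelId: string, provider: DemoProvider): boolean {
  return modelId.includes('gemini-2.5-flash-image') && isSystemProviderDemo(provider) && provider.id === 'openrouter'
}

console.log(isOpenRouterGeminiImage('google/gemini-2.5-flash-image', { id: 'openrouter', isSystem: true })) // true
console.log(isOpenRouterGeminiImage('google/gemini-2.5-flash-image', { id: 'openrouter' }))                 // false - custom provider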
@@ -83,7 +83,7 @@ export const IMAGE_ENHANCEMENT_MODELS = [
   'grok-2-image(?:-[\\w-]+)?',
   'qwen-image-edit',
   'gpt-image-1',
-  'gemini-2.5-flash-image',
+  'gemini-2.5-flash-image(?:-[\\w-]+)?',
   'gemini-2.0-flash-preview-image-generation'
 ]
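The widened pattern keeps matching the bare ID while also covering suffixed variants such as -preview. The optional suffix group only changes behavior when the pattern is anchored; an end anchor is assumed here for illustration, since the diff does not show how the repo compiles these strings:

// Old vs new pattern, with an assumed end anchor.
const oldPattern = new RegExp('(?:^|/)gemini-2.5-flash-image$')
const newPattern = new RegExp('(?:^|/)gemini-2.5-flash-image(?:-[\\w-]+)?$')

for (const id of ['gemini-2.5-flash-image', 'google/gemini-2.5-flash-image-preview']) {
  console.log(id, oldPattern.test(id), newPattern.test(id))
}
// gemini-2.5-flash-image true true
// google/gemini-2.5-flash-image-preview false true   <- the -preview variant now matches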
yarn.lock (27 lines changed)
@@ -7116,13 +7116,13 @@ __metadata:
   languageName: node
   linkType: hard

-"@openrouter/ai-sdk-provider@npm:^1.1.2":
-  version: 1.1.2
-  resolution: "@openrouter/ai-sdk-provider@npm:1.1.2"
+"@openrouter/ai-sdk-provider@npm:^1.2.0":
+  version: 1.2.0
+  resolution: "@openrouter/ai-sdk-provider@npm:1.2.0"
   peerDependencies:
     ai: ^5.0.0
     zod: ^3.24.1 || ^v4
-  checksum: 10c0/1ad50804189910d52c2c10e479bec40dfbd2109820e43135d001f4f8706be6ace532d4769a8c30111f5870afdfa97b815c7334b2e4d8d36ca68b1578ce5d9a41
+  checksum: 10c0/4ca7c471ec46bdd48eea9c56d94778a06ca4b74b6ef2ab892ab7eadbd409e3530ac0c5791cd80e88cafc44a49a76585e59707104792e3e3124237fed767104ef
   languageName: node
   linkType: hard
@@ -13902,7 +13902,7 @@ __metadata:
     "@mozilla/readability": "npm:^0.6.0"
     "@napi-rs/system-ocr": "patch:@napi-rs/system-ocr@npm%3A1.0.2#~/.yarn/patches/@napi-rs-system-ocr-npm-1.0.2-59e7a78e8b.patch"
     "@notionhq/client": "npm:^2.2.15"
-    "@openrouter/ai-sdk-provider": "npm:^1.1.2"
+    "@openrouter/ai-sdk-provider": "npm:^1.2.0"
     "@opentelemetry/api": "npm:^1.9.0"
     "@opentelemetry/core": "npm:2.0.0"
     "@opentelemetry/exporter-trace-otlp-http": "npm:^0.200.0"
@@ -23907,23 +23907,6 @@ __metadata:
   languageName: node
   linkType: hard

-"openai@npm:5.12.2":
-  version: 5.12.2
-  resolution: "openai@npm:5.12.2"
-  peerDependencies:
-    ws: ^8.18.0
-    zod: ^3.23.8
-  peerDependenciesMeta:
-    ws:
-      optional: true
-    zod:
-      optional: true
-  bin:
-    openai: bin/cli
-  checksum: 10c0/7737b9b24edc81fcf9e6dcfb18a196cc0f8e29b6e839adf06a2538558c03908e3aa4cd94901b1a7f4a9dd62676fe9e34d6202281b2395090d998618ea1614c0c
-  languageName: node
-  linkType: hard
-
 "openapi-types@npm:^12.1.3":
   version: 12.1.3
   resolution: "openapi-types@npm:12.1.3"