feat(openrouter): support openrouter to generate image (#9750)

* feat(openrouter): 支持OpenRouter的图像生成功能并处理模型名称

修改getLowerBaseModelName函数以处理OpenRouter的:free后缀
在OpenAIApiClient中添加enableGenerateImage参数支持图像生成

* refactor(openai): 重构OpenAI参数类型并优化翻译选项处理

重构OpenAIParamsWithoutReasoningEffort为OpenAIParamsPurified,新增OpenAIModalities和OpenAIExtraBody类型
优化翻译选项处理逻辑,提前验证目标语言有效性
将modalities参数从extra_body分离以提升类型安全性

* test(naming): 修复模型名称处理测试并添加新测试用例

修复getLowerBaseModelName测试中对GPT-4:free的预期结果
添加新测试用例验证去除:free后缀的功能

* test(naming): 移除对包含冒号的模型名称的测试
This commit is contained in:
Phantom 2025-09-01 12:55:46 +08:00 committed by GitHub
parent fd2d4c723c
commit 22ca77188b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 45 additions and 14 deletions

View File

@ -60,6 +60,8 @@ import {
import { ChunkType, TextStartChunk, ThinkingStartChunk } from '@renderer/types/chunk'
import { Message } from '@renderer/types/newMessage'
import {
OpenAIExtraBody,
OpenAIModality,
OpenAISdkMessageParam,
OpenAISdkParams,
OpenAISdkRawChunk,
@ -564,7 +566,7 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
messages: OpenAISdkMessageParam[]
metadata: Record<string, any>
}> => {
const { messages, mcpTools, maxTokens, enableWebSearch } = coreRequest
const { messages, mcpTools, maxTokens, enableWebSearch, enableGenerateImage } = coreRequest
let { streamOutput } = coreRequest
// Qwen3商业版思考模式、Qwen3开源版、QwQ、QVQ只支持流式输出。
@ -572,18 +574,18 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
streamOutput = true
}
const extra_body: Record<string, any> = {}
const extra_body: OpenAIExtraBody = {}
if (isQwenMTModel(model)) {
if (isTranslateAssistant(assistant)) {
const targetLanguage = assistant.targetLanguage
const targetLanguage = mapLanguageToQwenMTModel(assistant.targetLanguage)
if (!targetLanguage) {
throw new Error(t('translate.error.not_supported', { language: assistant.targetLanguage.value }))
}
const translationOptions = {
source_lang: 'auto',
target_lang: mapLanguageToQwenMTModel(targetLanguage)
target_lang: targetLanguage
} as const
if (!translationOptions.target_lang) {
throw new Error(t('translate.error.not_supported', { language: targetLanguage.value }))
}
extra_body.translation_options = translationOptions
} else {
throw new Error(t('translate.error.chat_qwen_mt'))
@ -684,6 +686,15 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
reasoningEffort.reasoning_effort = 'low'
}
const modalities: {
modalities?: OpenAIModality[]
} = {}
// for openrouter generate image
// https://openrouter.ai/docs/features/multimodal/image-generation
if (enableGenerateImage && this.provider.id === SystemProviderIds.openrouter) {
modalities.modalities = ['image', 'text']
}
const commonParams: OpenAISdkParams = {
model: model.id,
messages:
@ -696,6 +707,7 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
tools: tools.length > 0 ? tools : undefined,
stream: streamOutput,
...(shouldIncludeStreamOptions ? { stream_options: { include_usage: true } } : {}),
...modalities,
// groq 有不同的 service tier 配置,不符合 openai 接口类型
service_tier: this.getServiceTier(model) as OpenAIServiceTier,
...this.getProviderSpecificParameters(assistant, model),
@ -703,7 +715,7 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
...getOpenAIWebSearchParams(model, enableWebSearch),
// OpenRouter usage tracking
...(this.provider.id === 'openrouter' ? { usage: { include: true } } : {}),
...(isQwenMTModel(model) ? extra_body : {}),
...extra_body,
// 只在对话场景下应用自定义参数,避免影响翻译、总结等其他业务逻辑
// 注意:用户自定义参数总是应该覆盖其他参数
...(coreRequest.callType === 'chat' ? this.getCustomParameters(assistant) : {})

View File

@ -72,7 +72,7 @@ export type RequestOptions = Anthropic.RequestOptions | OpenAI.RequestOptions |
* OpenAI
*/
type OpenAIParamsWithoutReasoningEffort = Omit<OpenAI.Chat.Completions.ChatCompletionCreateParams, 'reasoning_effort'>
type OpenAIParamsPurified = Omit<OpenAI.Chat.Completions.ChatCompletionCreateParams, 'reasoning_effort' | 'modalities'>
export type ReasoningEffortOptionalParams = {
thinking?: { type: 'disabled' | 'enabled' | 'auto'; budget_tokens?: number }
@ -97,7 +97,7 @@ export type ReasoningEffortOptionalParams = {
// Add any other potential reasoning-related keys here if they exist
}
export type OpenAISdkParams = OpenAIParamsWithoutReasoningEffort & ReasoningEffortOptionalParams
export type OpenAISdkParams = OpenAIParamsPurified & ReasoningEffortOptionalParams & OpenAIModalities & OpenAIExtraBody
// OpenRouter may include additional fields like cost
export type OpenAISdkRawChunk =
@ -116,7 +116,18 @@ export type OpenAISdkRawContentSource =
})
export type OpenAISdkMessageParam = OpenAI.Chat.Completions.ChatCompletionMessageParam
export type OpenAIExtraBody = {
// for qwen mt
translation_options?: {
source_lang: 'auto'
target_lang: string
}
}
// image is for openrouter. audio is ignored for now
export type OpenAIModality = OpenAI.ChatCompletionModality | 'image'
export type OpenAIModalities = {
modalities?: OpenAIModality[]
}
/**
* OpenAI Response
*/

View File

@ -174,7 +174,6 @@ describe('naming', () => {
it('should return original id if no delimiter found', () => {
expect(getBaseModelName('deepseek-r1')).toBe('deepseek-r1')
expect(getBaseModelName('deepseek-r1:free')).toBe('deepseek-r1:free')
})
it('should handle edge cases', () => {
@ -209,7 +208,7 @@ describe('naming', () => {
it('should return lowercase original id if no delimiter found', () => {
// 验证没有分隔符时返回小写原始ID
expect(getLowerBaseModelName('DeepSeek-R1')).toBe('deepseek-r1')
expect(getLowerBaseModelName('GPT-4:Free')).toBe('gpt-4:free')
expect(getLowerBaseModelName('GPT-4')).toBe('gpt-4')
})
it('should handle edge cases', () => {
@ -219,6 +218,10 @@ describe('naming', () => {
expect(getLowerBaseModelName('/Model')).toBe('model')
expect(getLowerBaseModelName('Model//Name')).toBe('name')
})
it('should remove trailing :free', () => {
expect(getLowerBaseModelName('gpt-4:free')).toBe('gpt-4')
})
})
describe('getFirstCharacter', () => {

View File

@ -73,7 +73,12 @@ export const getBaseModelName = (id: string, delimiter: string = '/'): string =>
* @returns {string}
*/
export const getLowerBaseModelName = (id: string, delimiter: string = '/'): string => {
return getBaseModelName(id, delimiter).toLowerCase()
const baseModelName = getBaseModelName(id, delimiter).toLowerCase()
// for openrouter
if (baseModelName.endsWith(':free')) {
return baseModelName.replace(':free', '')
}
return baseModelName
}
/**