mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-27 21:01:32 +08:00
feat(openrouter): support openrouter to generate image (#9750)
* feat(openrouter): 支持OpenRouter的图像生成功能并处理模型名称
  - 修改getLowerBaseModelName函数以处理OpenRouter的:free后缀
  - 在OpenAIAPIClient中添加enableGenerateImage参数支持图像生成
* refactor(openai): 重构OpenAI参数类型并优化翻译选项处理
  - 重构OpenAIParamsWithoutReasoningEffort为OpenAIParamsPurified,新增OpenAIModalities和OpenAIExtraBody类型
  - 优化翻译选项处理逻辑,提前验证目标语言有效性
  - 将modalities参数从extra_body分离以提升类型安全性
* test(naming): 修复模型名称处理测试并添加新测试用例
  - 修复getLowerBaseModelName测试中对GPT-4:free的预期结果
  - 添加新测试用例验证去除:free后缀的功能
* test(naming): 移除对包含冒号的模型名称的测试
This commit is contained in:
parent
fd2d4c723c
commit
22ca77188b
@ -60,6 +60,8 @@ import {
|
||||
import { ChunkType, TextStartChunk, ThinkingStartChunk } from '@renderer/types/chunk'
|
||||
import { Message } from '@renderer/types/newMessage'
|
||||
import {
|
||||
OpenAIExtraBody,
|
||||
OpenAIModality,
|
||||
OpenAISdkMessageParam,
|
||||
OpenAISdkParams,
|
||||
OpenAISdkRawChunk,
|
||||
@ -564,7 +566,7 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
|
||||
messages: OpenAISdkMessageParam[]
|
||||
metadata: Record<string, any>
|
||||
}> => {
|
||||
const { messages, mcpTools, maxTokens, enableWebSearch } = coreRequest
|
||||
const { messages, mcpTools, maxTokens, enableWebSearch, enableGenerateImage } = coreRequest
|
||||
let { streamOutput } = coreRequest
|
||||
|
||||
// Qwen3商业版(思考模式)、Qwen3开源版、QwQ、QVQ只支持流式输出。
|
||||
@ -572,18 +574,18 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
|
||||
streamOutput = true
|
||||
}
|
||||
|
||||
const extra_body: Record<string, any> = {}
|
||||
const extra_body: OpenAIExtraBody = {}
|
||||
|
||||
if (isQwenMTModel(model)) {
|
||||
if (isTranslateAssistant(assistant)) {
|
||||
const targetLanguage = assistant.targetLanguage
|
||||
const targetLanguage = mapLanguageToQwenMTModel(assistant.targetLanguage)
|
||||
if (!targetLanguage) {
|
||||
throw new Error(t('translate.error.not_supported', { language: assistant.targetLanguage.value }))
|
||||
}
|
||||
const translationOptions = {
|
||||
source_lang: 'auto',
|
||||
target_lang: mapLanguageToQwenMTModel(targetLanguage)
|
||||
target_lang: targetLanguage
|
||||
} as const
|
||||
if (!translationOptions.target_lang) {
|
||||
throw new Error(t('translate.error.not_supported', { language: targetLanguage.value }))
|
||||
}
|
||||
extra_body.translation_options = translationOptions
|
||||
} else {
|
||||
throw new Error(t('translate.error.chat_qwen_mt'))
|
||||
@ -684,6 +686,15 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
|
||||
reasoningEffort.reasoning_effort = 'low'
|
||||
}
|
||||
|
||||
const modalities: {
|
||||
modalities?: OpenAIModality[]
|
||||
} = {}
|
||||
// for openrouter generate image
|
||||
// https://openrouter.ai/docs/features/multimodal/image-generation
|
||||
if (enableGenerateImage && this.provider.id === SystemProviderIds.openrouter) {
|
||||
modalities.modalities = ['image', 'text']
|
||||
}
|
||||
|
||||
const commonParams: OpenAISdkParams = {
|
||||
model: model.id,
|
||||
messages:
|
||||
@ -696,6 +707,7 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
|
||||
tools: tools.length > 0 ? tools : undefined,
|
||||
stream: streamOutput,
|
||||
...(shouldIncludeStreamOptions ? { stream_options: { include_usage: true } } : {}),
|
||||
...modalities,
|
||||
// groq 有不同的 service tier 配置,不符合 openai 接口类型
|
||||
service_tier: this.getServiceTier(model) as OpenAIServiceTier,
|
||||
...this.getProviderSpecificParameters(assistant, model),
|
||||
@ -703,7 +715,7 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
|
||||
...getOpenAIWebSearchParams(model, enableWebSearch),
|
||||
// OpenRouter usage tracking
|
||||
...(this.provider.id === 'openrouter' ? { usage: { include: true } } : {}),
|
||||
...(isQwenMTModel(model) ? extra_body : {}),
|
||||
...extra_body,
|
||||
// 只在对话场景下应用自定义参数,避免影响翻译、总结等其他业务逻辑
|
||||
// 注意:用户自定义参数总是应该覆盖其他参数
|
||||
...(coreRequest.callType === 'chat' ? this.getCustomParameters(assistant) : {})
|
||||
|
||||
@ -72,7 +72,7 @@ export type RequestOptions = Anthropic.RequestOptions | OpenAI.RequestOptions |
|
||||
* OpenAI
|
||||
*/
|
||||
|
||||
type OpenAIParamsWithoutReasoningEffort = Omit<OpenAI.Chat.Completions.ChatCompletionCreateParams, 'reasoning_effort'>
|
||||
type OpenAIParamsPurified = Omit<OpenAI.Chat.Completions.ChatCompletionCreateParams, 'reasoning_effort' | 'modalities'>
|
||||
|
||||
export type ReasoningEffortOptionalParams = {
|
||||
thinking?: { type: 'disabled' | 'enabled' | 'auto'; budget_tokens?: number }
|
||||
@ -97,7 +97,7 @@ export type ReasoningEffortOptionalParams = {
|
||||
// Add any other potential reasoning-related keys here if they exist
|
||||
}
|
||||
|
||||
export type OpenAISdkParams = OpenAIParamsWithoutReasoningEffort & ReasoningEffortOptionalParams
|
||||
export type OpenAISdkParams = OpenAIParamsPurified & ReasoningEffortOptionalParams & OpenAIModalities & OpenAIExtraBody
|
||||
|
||||
// OpenRouter may include additional fields like cost
|
||||
export type OpenAISdkRawChunk =
|
||||
@ -116,7 +116,18 @@ export type OpenAISdkRawContentSource =
|
||||
})
|
||||
|
||||
export type OpenAISdkMessageParam = OpenAI.Chat.Completions.ChatCompletionMessageParam
|
||||
|
||||
export type OpenAIExtraBody = {
|
||||
// for qwen mt
|
||||
translation_options?: {
|
||||
source_lang: 'auto'
|
||||
target_lang: string
|
||||
}
|
||||
}
|
||||
// image is for openrouter. audio is ignored for now
|
||||
export type OpenAIModality = OpenAI.ChatCompletionModality | 'image'
|
||||
export type OpenAIModalities = {
|
||||
modalities?: OpenAIModality[]
|
||||
}
|
||||
/**
|
||||
* OpenAI Response
|
||||
*/
|
||||
|
||||
@ -174,7 +174,6 @@ describe('naming', () => {
|
||||
|
||||
it('should return original id if no delimiter found', () => {
|
||||
expect(getBaseModelName('deepseek-r1')).toBe('deepseek-r1')
|
||||
expect(getBaseModelName('deepseek-r1:free')).toBe('deepseek-r1:free')
|
||||
})
|
||||
|
||||
it('should handle edge cases', () => {
|
||||
@ -209,7 +208,7 @@ describe('naming', () => {
|
||||
it('should return lowercase original id if no delimiter found', () => {
|
||||
// 验证没有分隔符时返回小写原始ID
|
||||
expect(getLowerBaseModelName('DeepSeek-R1')).toBe('deepseek-r1')
|
||||
expect(getLowerBaseModelName('GPT-4:Free')).toBe('gpt-4:free')
|
||||
expect(getLowerBaseModelName('GPT-4')).toBe('gpt-4')
|
||||
})
|
||||
|
||||
it('should handle edge cases', () => {
|
||||
@ -219,6 +218,10 @@ describe('naming', () => {
|
||||
expect(getLowerBaseModelName('/Model')).toBe('model')
|
||||
expect(getLowerBaseModelName('Model//Name')).toBe('name')
|
||||
})
|
||||
|
||||
it('should remove trailing :free', () => {
|
||||
expect(getLowerBaseModelName('gpt-4:free')).toBe('gpt-4')
|
||||
})
|
||||
})
|
||||
|
||||
describe('getFirstCharacter', () => {
|
||||
|
||||
@ -73,7 +73,12 @@ export const getBaseModelName = (id: string, delimiter: string = '/'): string =>
|
||||
* @returns {string} 小写的基础名称
|
||||
*/
|
||||
export const getLowerBaseModelName = (id: string, delimiter: string = '/'): string => {
  // Take the base name (as computed by getBaseModelName) and normalise it to lower case.
  const baseModelName = getBaseModelName(id, delimiter).toLowerCase()

  // For OpenRouter: model ids may carry a trailing ":free" suffix (e.g. "deepseek-r1:free").
  // Cut the suffix off the END of the name. String.replace with a string pattern removes the
  // FIRST occurrence, which is the wrong one when ":free" also appears earlier in the name.
  const freeSuffix = ':free'
  if (baseModelName.endsWith(freeSuffix)) {
    return baseModelName.slice(0, -freeSuffix.length)
  }
  return baseModelName
}
|
||||
|
||||
/**
|
||||
|
||||
Loading…
Reference in New Issue
Block a user