diff --git a/src/renderer/src/config/models.ts b/src/renderer/src/config/models.ts
index f00f2d1a88..5fe9b44525 100644
--- a/src/renderer/src/config/models.ts
+++ b/src/renderer/src/config/models.ts
@@ -2233,6 +2233,9 @@ export function isOpenAILLMModel(model: Model): boolean {
   if (!model) {
     return false
   }
+  if (model.id.includes('gpt-4o-image')) {
+    return false
+  }
   if (isOpenAIReasoningModel(model)) {
     return true
   }
diff --git a/src/renderer/src/providers/AiProvider/OpenAIProvider.ts b/src/renderer/src/providers/AiProvider/OpenAIProvider.ts
index 26a637b129..3c59f0e71b 100644
--- a/src/renderer/src/providers/AiProvider/OpenAIProvider.ts
+++ b/src/renderer/src/providers/AiProvider/OpenAIProvider.ts
@@ -73,7 +73,7 @@ import {
 } from 'openai/resources'

 import { CompletionsParams } from '.'
-import { BaseOpenAiProvider } from './OpenAIResponseProvider'
+import { BaseOpenAIProvider } from './OpenAIResponseProvider'

 // 1. Define the union type
 export type OpenAIStreamChunk =
@@ -81,7 +81,7 @@ export type OpenAIStreamChunk =
   | { type: 'tool-calls'; delta: any }
   | { type: 'finish'; finishReason: any; usage: any; delta: any; chunk: any }

-export default class OpenAIProvider extends BaseOpenAiProvider {
+export default class OpenAIProvider extends BaseOpenAIProvider {
   constructor(provider: Provider) {
     super(provider)
diff --git a/src/renderer/src/providers/AiProvider/OpenAIResponseProvider.ts b/src/renderer/src/providers/AiProvider/OpenAIResponseProvider.ts
index c515621169..500d212a6b 100644
--- a/src/renderer/src/providers/AiProvider/OpenAIResponseProvider.ts
+++ b/src/renderer/src/providers/AiProvider/OpenAIResponseProvider.ts
@@ -1,5 +1,4 @@
 import {
-  getOpenAIWebSearchParams,
   isOpenAILLMModel,
   isOpenAIReasoningModel,
   isOpenAIWebSearch,
@@ -53,8 +52,9 @@ import { FileLike, toFile } from 'openai/uploads'

 import { CompletionsParams } from '.'
 import BaseProvider from './BaseProvider'
+import OpenAIProvider from './OpenAIProvider'

-export abstract class BaseOpenAiProvider extends BaseProvider {
+export abstract class BaseOpenAIProvider extends BaseProvider {
   protected sdk: OpenAI

   constructor(provider: Provider) {
@@ -311,112 +311,7 @@ export abstract class BaseOpenAiProvider extends BaseProvider {
     const model = assistant.model || defaultModel
     const { contextCount, maxTokens, streamOutput, enableToolUse } = getAssistantSettings(assistant)
     const isEnabledBuiltinWebSearch = assistant.enableWebSearch
-    // Fall back to OpenAI-compatible mode
-    if (isOpenAIWebSearch(model)) {
-      const systemMessage = { role: 'system', content: assistant.prompt || '' }
-      const userMessages: ChatCompletionMessageParam[] = []
-      const _messages = filterUserRoleStartMessages(
-        filterEmptyMessages(filterContextMessages(takeRight(messages, contextCount + 1)))
-      )
-      onFilterMessages(_messages)
-      for (const message of _messages) {
-        userMessages.push(await this.getMessageParam(message, model))
-      }
-      // Don't send systemMessage when its content is empty
-      let reqMessages: ChatCompletionMessageParam[]
-      if (!systemMessage.content) {
-        reqMessages = [...userMessages]
-      } else {
-        reqMessages = [systemMessage, ...userMessages].filter(Boolean) as ChatCompletionMessageParam[]
-      }
-      const lastUserMessage = _messages.findLast((m) => m.role === 'user')
-      const { abortController, cleanup, signalPromise } = this.createAbortController(lastUserMessage?.id, true)
-      const { signal } = abortController
-      const start_time_millsec = new Date().getTime()
-      const response = await this.sdk.chat.completions
-        // @ts-ignore key is not typed
-        .create(
-          {
-            model: model.id,
-            messages: reqMessages,
-            stream: true,
-            temperature: this.getTemperature(assistant, model),
-            top_p: this.getTopP(assistant, model),
-            max_tokens: maxTokens,
-            ...getOpenAIWebSearchParams(assistant, model),
-            ...this.getCustomParameters(assistant)
-          },
-          {
-            signal
-          }
-        )
-      const processStream = async (stream: any) => {
-        let content = ''
-        let isFirstChunk = true
-        const finalUsage: Usage = {
-          completion_tokens: 0,
-          prompt_tokens: 0,
-          total_tokens: 0
-        }
-
-        const finalMetrics: Metrics = {
-          completion_tokens: 0,
-          time_completion_millsec: 0,
-          time_first_token_millsec: 0
-        }
-        for await (const chunk of stream as any) {
-          if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) {
-            break
-          }
-          const delta = chunk.choices[0]?.delta
-          const finishReason = chunk.choices[0]?.finish_reason
-          if (delta?.content) {
-            if (isOpenAIWebSearch(model)) {
-              delta.content = convertLinks(delta.content || '', isFirstChunk)
-            }
-            if (isFirstChunk) {
-              isFirstChunk = false
-              finalMetrics.time_first_token_millsec = new Date().getTime() - start_time_millsec
-            }
-            content += delta.content
-            onChunk({ type: ChunkType.TEXT_DELTA, text: delta.content })
-          }
-          if (!isEmpty(finishReason) || chunk?.annotations) {
-            onChunk({ type: ChunkType.TEXT_COMPLETE, text: content })
-            finalMetrics.time_completion_millsec = new Date().getTime() - start_time_millsec
-            if (chunk.usage) {
-              const usage = chunk.usage as OpenAI.Completions.CompletionUsage
-              finalUsage.completion_tokens = usage.completion_tokens
-              finalUsage.prompt_tokens = usage.prompt_tokens
-              finalUsage.total_tokens = usage.total_tokens
-            }
-            finalMetrics.completion_tokens = finalUsage.completion_tokens
-          }
-          if (delta?.annotations) {
-            onChunk({
-              type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
-              llm_web_search: {
-                results: delta.annotations,
-                source: WebSearchSource.OPENAI
-              }
-            })
-          }
-        }
-        onChunk({
-          type: ChunkType.BLOCK_COMPLETE,
-          response: {
-            usage: finalUsage,
-            metrics: finalMetrics
-          }
-        })
-      }
-      await processStream(response).finally(cleanup)
-      await signalPromise?.promise?.catch((error) => {
-        throw error
-      })
-      return
-    }
     let tools: OpenAI.Responses.Tool[] = []
     const toolChoices: OpenAI.Responses.ToolChoiceTypes = {
       type: 'web_search_preview'
@@ -1162,6 +1057,11 @@ export abstract class BaseOpenAiProvider extends BaseProvider {
       )
       images = images.concat(assistantImages.filter(Boolean) as FileLike[])
     }
+
+    onChunk({
+      type: ChunkType.LLM_RESPONSE_CREATED
+    })
+
     onChunk({
       type: ChunkType.IMAGE_CREATED
     })
@@ -1240,9 +1140,30 @@ export abstract class BaseOpenAiProvider extends BaseProvider {
   }
 }

-export default class OpenAIResponseProvider extends BaseOpenAiProvider {
+export default class OpenAIResponseProvider extends BaseOpenAIProvider {
+  private providers: Map<string, BaseOpenAIProvider> = new Map()
+
   constructor(provider: Provider) {
     super(provider)
+    this.providers.set('openai-compatible', new OpenAIProvider(provider))
+  }
+
+  private getProvider(model: Model): BaseOpenAIProvider {
+    if (isOpenAIWebSearch(model)) {
+      return this.providers.get('openai-compatible')!
+    } else {
+      return this
+    }
+  }
+
+  public completions(params: CompletionsParams): Promise<void> {
+    const model = params.assistant.model
+    if (!model) {
+      return Promise.reject(new Error('Model is required'))
+    }
+
+    const provider = this.getProvider(model)
+    return provider === this ? super.completions(params) : provider.completions(params)
   }

   public convertMcpTools(mcpTools: MCPTool[]) {
diff --git a/src/renderer/src/store/thunk/messageThunk.ts b/src/renderer/src/store/thunk/messageThunk.ts
index c4d2d7898d..a8a58dc3d1 100644
--- a/src/renderer/src/store/thunk/messageThunk.ts
+++ b/src/renderer/src/store/thunk/messageThunk.ts
@@ -536,10 +536,22 @@ const fetchAndProcessAssistantResponseImpl = async (
       }
     },
     onImageCreated: () => {
-      const imageBlock = createImageBlock(assistantMsgId, {
-        status: MessageBlockStatus.PROCESSING
-      })
-      handleBlockTransition(imageBlock, MessageBlockType.IMAGE)
+      if (lastBlockId) {
+        if (lastBlockType === MessageBlockType.UNKNOWN) {
+          const initialChanges: Partial<MessageBlock> = {
+            type: MessageBlockType.IMAGE,
+            status: MessageBlockStatus.STREAMING
+          }
+          lastBlockType = MessageBlockType.IMAGE
+          dispatch(updateOneBlock({ id: lastBlockId, changes: initialChanges }))
+          saveUpdatedBlockToDB(lastBlockId, assistantMsgId, topicId, getState)
+        } else {
+          const imageBlock = createImageBlock(assistantMsgId, {
+            status: MessageBlockStatus.PROCESSING
+          })
+          handleBlockTransition(imageBlock, MessageBlockType.IMAGE)
+        }
+      }
     },
     onImageGenerated: (imageData) => {
       const imageUrl = imageData.images?.[0] || 'placeholder_image_url'
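
A minimal, self-contained sketch of the routing pattern the patch introduces, for reference. The types below (Model, CompletionsParams) and the isOpenAIWebSearch check are simplified stand-ins, not the app's real definitions:

// Hypothetical stand-ins for the app's real types and model check.
type Model = { id: string }
type CompletionsParams = { assistant: { model?: Model } }
const isOpenAIWebSearch = (m: Model): boolean => m.id.includes('search')

abstract class BaseOpenAIProvider {
  // Responses API path (simplified).
  completions(params: CompletionsParams): Promise<void> {
    console.log('Responses API:', params.assistant.model?.id)
    return Promise.resolve()
  }
}

class OpenAIProvider extends BaseOpenAIProvider {
  // Chat Completions path (simplified).
  completions(params: CompletionsParams): Promise<void> {
    console.log('Chat Completions:', params.assistant.model?.id)
    return Promise.resolve()
  }
}

class OpenAIResponseProvider extends BaseOpenAIProvider {
  private providers = new Map<string, BaseOpenAIProvider>([['openai-compatible', new OpenAIProvider()]])

  private getProvider(model: Model): BaseOpenAIProvider {
    // Web-search models are routed to the OpenAI-compatible provider;
    // everything else stays on the Responses API path.
    return isOpenAIWebSearch(model) ? this.providers.get('openai-compatible')! : this
  }

  completions(params: CompletionsParams): Promise<void> {
    const model = params.assistant.model
    if (!model) return Promise.reject(new Error('Model is required'))
    const provider = this.getProvider(model)
    // When dispatch resolves to `this`, call super.completions directly:
    // calling this.completions would re-enter this override and recurse forever.
    return provider === this ? super.completions(params) : provider.completions(params)
  }
}

The `provider === this ? super.completions(params) : provider.completions(params)` guard in the patch is what makes the dispatch safe: the subclass override is itself the dispatcher, so the non-delegated case must bypass it and go straight to the base implementation.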