mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-24 10:40:07 +08:00
fix: enhance image block handling in message processing (#5971)
This commit is contained in:
parent
729752f96a
commit
5b49f77965
@ -2233,6 +2233,9 @@ export function isOpenAILLMModel(model: Model): boolean {
|
||||
if (!model) {
|
||||
return false
|
||||
}
|
||||
if (model.id.includes('gpt-4o-image')) {
|
||||
return false
|
||||
}
|
||||
if (isOpenAIReasoningModel(model)) {
|
||||
return true
|
||||
}
|
||||
|
||||
@ -73,7 +73,7 @@ import {
|
||||
} from 'openai/resources'
|
||||
|
||||
import { CompletionsParams } from '.'
|
||||
import { BaseOpenAiProvider } from './OpenAIResponseProvider'
|
||||
import { BaseOpenAIProvider } from './OpenAIResponseProvider'
|
||||
|
||||
// 1. 定义联合类型
|
||||
export type OpenAIStreamChunk =
|
||||
@ -81,7 +81,7 @@ export type OpenAIStreamChunk =
|
||||
| { type: 'tool-calls'; delta: any }
|
||||
| { type: 'finish'; finishReason: any; usage: any; delta: any; chunk: any }
|
||||
|
||||
export default class OpenAIProvider extends BaseOpenAiProvider {
|
||||
export default class OpenAIProvider extends BaseOpenAIProvider {
|
||||
constructor(provider: Provider) {
|
||||
super(provider)
|
||||
|
||||
|
||||
@ -1,5 +1,4 @@
|
||||
import {
|
||||
getOpenAIWebSearchParams,
|
||||
isOpenAILLMModel,
|
||||
isOpenAIReasoningModel,
|
||||
isOpenAIWebSearch,
|
||||
@ -53,8 +52,9 @@ import { FileLike, toFile } from 'openai/uploads'
|
||||
|
||||
import { CompletionsParams } from '.'
|
||||
import BaseProvider from './BaseProvider'
|
||||
import OpenAIProvider from './OpenAIProvider'
|
||||
|
||||
export abstract class BaseOpenAiProvider extends BaseProvider {
|
||||
export abstract class BaseOpenAIProvider extends BaseProvider {
|
||||
protected sdk: OpenAI
|
||||
|
||||
constructor(provider: Provider) {
|
||||
@ -311,112 +311,7 @@ export abstract class BaseOpenAiProvider extends BaseProvider {
|
||||
const model = assistant.model || defaultModel
|
||||
const { contextCount, maxTokens, streamOutput, enableToolUse } = getAssistantSettings(assistant)
|
||||
const isEnabledBuiltinWebSearch = assistant.enableWebSearch
|
||||
// 退回到 OpenAI 兼容模式
|
||||
if (isOpenAIWebSearch(model)) {
|
||||
const systemMessage = { role: 'system', content: assistant.prompt || '' }
|
||||
const userMessages: ChatCompletionMessageParam[] = []
|
||||
const _messages = filterUserRoleStartMessages(
|
||||
filterEmptyMessages(filterContextMessages(takeRight(messages, contextCount + 1)))
|
||||
)
|
||||
onFilterMessages(_messages)
|
||||
|
||||
for (const message of _messages) {
|
||||
userMessages.push(await this.getMessageParam(message, model))
|
||||
}
|
||||
//当 systemMessage 内容为空时不发送 systemMessage
|
||||
let reqMessages: ChatCompletionMessageParam[]
|
||||
if (!systemMessage.content) {
|
||||
reqMessages = [...userMessages]
|
||||
} else {
|
||||
reqMessages = [systemMessage, ...userMessages].filter(Boolean) as ChatCompletionMessageParam[]
|
||||
}
|
||||
const lastUserMessage = _messages.findLast((m) => m.role === 'user')
|
||||
const { abortController, cleanup, signalPromise } = this.createAbortController(lastUserMessage?.id, true)
|
||||
const { signal } = abortController
|
||||
const start_time_millsec = new Date().getTime()
|
||||
const response = await this.sdk.chat.completions
|
||||
// @ts-ignore key is not typed
|
||||
.create(
|
||||
{
|
||||
model: model.id,
|
||||
messages: reqMessages,
|
||||
stream: true,
|
||||
temperature: this.getTemperature(assistant, model),
|
||||
top_p: this.getTopP(assistant, model),
|
||||
max_tokens: maxTokens,
|
||||
...getOpenAIWebSearchParams(assistant, model),
|
||||
...this.getCustomParameters(assistant)
|
||||
},
|
||||
{
|
||||
signal
|
||||
}
|
||||
)
|
||||
const processStream = async (stream: any) => {
|
||||
let content = ''
|
||||
let isFirstChunk = true
|
||||
const finalUsage: Usage = {
|
||||
completion_tokens: 0,
|
||||
prompt_tokens: 0,
|
||||
total_tokens: 0
|
||||
}
|
||||
|
||||
const finalMetrics: Metrics = {
|
||||
completion_tokens: 0,
|
||||
time_completion_millsec: 0,
|
||||
time_first_token_millsec: 0
|
||||
}
|
||||
for await (const chunk of stream as any) {
|
||||
if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) {
|
||||
break
|
||||
}
|
||||
const delta = chunk.choices[0]?.delta
|
||||
const finishReason = chunk.choices[0]?.finish_reason
|
||||
if (delta?.content) {
|
||||
if (isOpenAIWebSearch(model)) {
|
||||
delta.content = convertLinks(delta.content || '', isFirstChunk)
|
||||
}
|
||||
if (isFirstChunk) {
|
||||
isFirstChunk = false
|
||||
finalMetrics.time_first_token_millsec = new Date().getTime() - start_time_millsec
|
||||
}
|
||||
content += delta.content
|
||||
onChunk({ type: ChunkType.TEXT_DELTA, text: delta.content })
|
||||
}
|
||||
if (!isEmpty(finishReason) || chunk?.annotations) {
|
||||
onChunk({ type: ChunkType.TEXT_COMPLETE, text: content })
|
||||
finalMetrics.time_completion_millsec = new Date().getTime() - start_time_millsec
|
||||
if (chunk.usage) {
|
||||
const usage = chunk.usage as OpenAI.Completions.CompletionUsage
|
||||
finalUsage.completion_tokens = usage.completion_tokens
|
||||
finalUsage.prompt_tokens = usage.prompt_tokens
|
||||
finalUsage.total_tokens = usage.total_tokens
|
||||
}
|
||||
finalMetrics.completion_tokens = finalUsage.completion_tokens
|
||||
}
|
||||
if (delta?.annotations) {
|
||||
onChunk({
|
||||
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
|
||||
llm_web_search: {
|
||||
results: delta.annotations,
|
||||
source: WebSearchSource.OPENAI
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
onChunk({
|
||||
type: ChunkType.BLOCK_COMPLETE,
|
||||
response: {
|
||||
usage: finalUsage,
|
||||
metrics: finalMetrics
|
||||
}
|
||||
})
|
||||
}
|
||||
await processStream(response).finally(cleanup)
|
||||
await signalPromise?.promise?.catch((error) => {
|
||||
throw error
|
||||
})
|
||||
return
|
||||
}
|
||||
let tools: OpenAI.Responses.Tool[] = []
|
||||
const toolChoices: OpenAI.Responses.ToolChoiceTypes = {
|
||||
type: 'web_search_preview'
|
||||
@ -1164,6 +1059,11 @@ export abstract class BaseOpenAiProvider extends BaseProvider {
|
||||
)
|
||||
images = images.concat(assistantImages.filter(Boolean) as FileLike[])
|
||||
}
|
||||
|
||||
onChunk({
|
||||
type: ChunkType.LLM_RESPONSE_CREATED
|
||||
})
|
||||
|
||||
onChunk({
|
||||
type: ChunkType.IMAGE_CREATED
|
||||
})
|
||||
@ -1242,9 +1142,30 @@ export abstract class BaseOpenAiProvider extends BaseProvider {
|
||||
}
|
||||
}
|
||||
|
||||
export default class OpenAIResponseProvider extends BaseOpenAiProvider {
|
||||
export default class OpenAIResponseProvider extends BaseOpenAIProvider {
|
||||
private providers: Map<string, BaseOpenAIProvider> = new Map()
|
||||
|
||||
constructor(provider: Provider) {
|
||||
super(provider)
|
||||
this.providers.set('openai-compatible', new OpenAIProvider(provider))
|
||||
}
|
||||
|
||||
private getProvider(model: Model): BaseOpenAIProvider {
|
||||
if (isOpenAIWebSearch(model)) {
|
||||
return this.providers.get('openai-compatible')!
|
||||
} else {
|
||||
return this
|
||||
}
|
||||
}
|
||||
|
||||
public completions(params: CompletionsParams): Promise<void> {
|
||||
const model = params.assistant.model
|
||||
if (!model) {
|
||||
return Promise.reject(new Error('Model is required'))
|
||||
}
|
||||
|
||||
const provider = this.getProvider(model)
|
||||
return provider === this ? super.completions(params) : provider.completions(params)
|
||||
}
|
||||
|
||||
public convertMcpTools<T>(mcpTools: MCPTool[]) {
|
||||
|
||||
@ -536,10 +536,22 @@ const fetchAndProcessAssistantResponseImpl = async (
|
||||
}
|
||||
},
|
||||
onImageCreated: () => {
|
||||
const imageBlock = createImageBlock(assistantMsgId, {
|
||||
status: MessageBlockStatus.PROCESSING
|
||||
})
|
||||
handleBlockTransition(imageBlock, MessageBlockType.IMAGE)
|
||||
if (lastBlockId) {
|
||||
if (lastBlockType === MessageBlockType.UNKNOWN) {
|
||||
const initialChanges: Partial<MessageBlock> = {
|
||||
type: MessageBlockType.IMAGE,
|
||||
status: MessageBlockStatus.STREAMING
|
||||
}
|
||||
lastBlockType = MessageBlockType.IMAGE
|
||||
dispatch(updateOneBlock({ id: lastBlockId, changes: initialChanges }))
|
||||
saveUpdatedBlockToDB(lastBlockId, assistantMsgId, topicId, getState)
|
||||
} else {
|
||||
const imageBlock = createImageBlock(assistantMsgId, {
|
||||
status: MessageBlockStatus.PROCESSING
|
||||
})
|
||||
handleBlockTransition(imageBlock, MessageBlockType.IMAGE)
|
||||
}
|
||||
}
|
||||
},
|
||||
onImageGenerated: (imageData) => {
|
||||
const imageUrl = imageData.images?.[0] || 'placeholder_image_url'
|
||||
|
||||
Loading…
Reference in New Issue
Block a user