Mirror of https://github.com/CherryHQ/cherry-studio.git (synced 2026-01-05 04:19:02 +08:00)
fix: enhance image block handling in message processing (#5971)
This commit is contained in:
commit 2e504a92dc (parent 462a63d36c)
@@ -2233,6 +2233,9 @@ export function isOpenAILLMModel(model: Model): boolean {
   if (!model) {
     return false
   }
+  if (model.id.includes('gpt-4o-image')) {
+    return false
+  }
   if (isOpenAIReasoningModel(model)) {
     return true
   }
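The new guard keeps `gpt-4o-image` variants out of the text-LLM classification, since they generate images rather than text. A minimal sketch of the resulting behavior, with `Model` reduced to just an `id` and the remaining checks elided:

```ts
interface Model {
  id: string
}

function isOpenAILLMModel(model: Model): boolean {
  if (!model) {
    return false
  }
  // gpt-4o-image generates images, so it must not be treated as a text LLM
  if (model.id.includes('gpt-4o-image')) {
    return false
  }
  // ...reasoning-model and other checks elided...
  return true
}

console.log(isOpenAILLMModel({ id: 'gpt-4o-image' })) // => false after this fix
console.log(isOpenAILLMModel({ id: 'gpt-4o' }))       // unchanged, assuming the elided checks pass
```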
@@ -73,7 +73,7 @@ import {
 } from 'openai/resources'
 
 import { CompletionsParams } from '.'
-import { BaseOpenAiProvider } from './OpenAIResponseProvider'
+import { BaseOpenAIProvider } from './OpenAIResponseProvider'
 
 // 1. Define the union type
 export type OpenAIStreamChunk =
@@ -81,7 +81,7 @@ export type OpenAIStreamChunk =
   | { type: 'tool-calls'; delta: any }
   | { type: 'finish'; finishReason: any; usage: any; delta: any; chunk: any }
 
-export default class OpenAIProvider extends BaseOpenAiProvider {
+export default class OpenAIProvider extends BaseOpenAIProvider {
   constructor(provider: Provider) {
     super(provider)
 
@@ -1,5 +1,4 @@
 import {
-  getOpenAIWebSearchParams,
   isOpenAILLMModel,
   isOpenAIReasoningModel,
   isOpenAIWebSearch,
@@ -53,8 +52,9 @@ import { FileLike, toFile } from 'openai/uploads'
 
 import { CompletionsParams } from '.'
 import BaseProvider from './BaseProvider'
+import OpenAIProvider from './OpenAIProvider'
 
-export abstract class BaseOpenAiProvider extends BaseProvider {
+export abstract class BaseOpenAIProvider extends BaseProvider {
   protected sdk: OpenAI
 
   constructor(provider: Provider) {
@@ -311,112 +311,7 @@ export abstract class BaseOpenAiProvider extends BaseProvider {
     const model = assistant.model || defaultModel
     const { contextCount, maxTokens, streamOutput, enableToolUse } = getAssistantSettings(assistant)
     const isEnabledBuiltinWebSearch = assistant.enableWebSearch
-    // Fall back to OpenAI compatibility mode
-    if (isOpenAIWebSearch(model)) {
-      const systemMessage = { role: 'system', content: assistant.prompt || '' }
-      const userMessages: ChatCompletionMessageParam[] = []
-      const _messages = filterUserRoleStartMessages(
-        filterEmptyMessages(filterContextMessages(takeRight(messages, contextCount + 1)))
-      )
-      onFilterMessages(_messages)
 
-      for (const message of _messages) {
-        userMessages.push(await this.getMessageParam(message, model))
-      }
-      // Do not send systemMessage when its content is empty
-      let reqMessages: ChatCompletionMessageParam[]
-      if (!systemMessage.content) {
-        reqMessages = [...userMessages]
-      } else {
-        reqMessages = [systemMessage, ...userMessages].filter(Boolean) as ChatCompletionMessageParam[]
-      }
-      const lastUserMessage = _messages.findLast((m) => m.role === 'user')
-      const { abortController, cleanup, signalPromise } = this.createAbortController(lastUserMessage?.id, true)
-      const { signal } = abortController
-      const start_time_millsec = new Date().getTime()
-      const response = await this.sdk.chat.completions
-        // @ts-ignore key is not typed
-        .create(
-          {
-            model: model.id,
-            messages: reqMessages,
-            stream: true,
-            temperature: this.getTemperature(assistant, model),
-            top_p: this.getTopP(assistant, model),
-            max_tokens: maxTokens,
-            ...getOpenAIWebSearchParams(assistant, model),
-            ...this.getCustomParameters(assistant)
-          },
-          {
-            signal
-          }
-        )
-      const processStream = async (stream: any) => {
-        let content = ''
-        let isFirstChunk = true
-        const finalUsage: Usage = {
-          completion_tokens: 0,
-          prompt_tokens: 0,
-          total_tokens: 0
-        }
-
-        const finalMetrics: Metrics = {
-          completion_tokens: 0,
-          time_completion_millsec: 0,
-          time_first_token_millsec: 0
-        }
-        for await (const chunk of stream as any) {
-          if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) {
-            break
-          }
-          const delta = chunk.choices[0]?.delta
-          const finishReason = chunk.choices[0]?.finish_reason
-          if (delta?.content) {
-            if (isOpenAIWebSearch(model)) {
-              delta.content = convertLinks(delta.content || '', isFirstChunk)
-            }
-            if (isFirstChunk) {
-              isFirstChunk = false
-              finalMetrics.time_first_token_millsec = new Date().getTime() - start_time_millsec
-            }
-            content += delta.content
-            onChunk({ type: ChunkType.TEXT_DELTA, text: delta.content })
-          }
-          if (!isEmpty(finishReason) || chunk?.annotations) {
-            onChunk({ type: ChunkType.TEXT_COMPLETE, text: content })
-            finalMetrics.time_completion_millsec = new Date().getTime() - start_time_millsec
-            if (chunk.usage) {
-              const usage = chunk.usage as OpenAI.Completions.CompletionUsage
-              finalUsage.completion_tokens = usage.completion_tokens
-              finalUsage.prompt_tokens = usage.prompt_tokens
-              finalUsage.total_tokens = usage.total_tokens
-            }
-            finalMetrics.completion_tokens = finalUsage.completion_tokens
-          }
-          if (delta?.annotations) {
-            onChunk({
-              type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
-              llm_web_search: {
-                results: delta.annotations,
-                source: WebSearchSource.OPENAI
-              }
-            })
-          }
-        }
-        onChunk({
-          type: ChunkType.BLOCK_COMPLETE,
-          response: {
-            usage: finalUsage,
-            metrics: finalMetrics
-          }
-        })
-      }
-      await processStream(response).finally(cleanup)
-      await signalPromise?.promise?.catch((error) => {
-        throw error
-      })
-      return
-    }
     let tools: OpenAI.Responses.Tool[] = []
     const toolChoices: OpenAI.Responses.ToolChoiceTypes = {
       type: 'web_search_preview'
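The block removed above was the inline Chat Completions fallback for OpenAI web-search models; the hunks below route those models through `OpenAIProvider` instead. For reference, a minimal self-contained sketch of the streaming pattern the removed code used, assuming the official `openai` SDK (the onChunk plumbing, link conversion, and usage metrics are elided):

```ts
import OpenAI from 'openai'

// Sketch only: the model id and message content are placeholders.
async function streamChatCompletion(client: OpenAI, modelId: string, signal: AbortSignal): Promise<string> {
  const stream = await client.chat.completions.create(
    { model: modelId, messages: [{ role: 'user', content: 'hello' }], stream: true },
    { signal } // per-request abort support, as in the removed code
  )

  let content = ''
  for await (const chunk of stream) {
    const delta = chunk.choices[0]?.delta
    if (delta?.content) {
      content += delta.content // the real code forwarded each delta via onChunk(...)
    }
    if (chunk.choices[0]?.finish_reason) {
      break // the real code finalized usage and timing metrics here
    }
  }
  return content
}
```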
@@ -1162,6 +1057,11 @@ export abstract class BaseOpenAiProvider extends BaseProvider {
       )
       images = images.concat(assistantImages.filter(Boolean) as FileLike[])
     }
+
+    onChunk({
+      type: ChunkType.LLM_RESPONSE_CREATED
+    })
+
     onChunk({
       type: ChunkType.IMAGE_CREATED
     })
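With this change, image requests emit `LLM_RESPONSE_CREATED` before `IMAGE_CREATED`, so the renderer can create a placeholder block first and then upgrade it. A rough sketch of a consumer of that ordering (the chunk types here are string stand-ins for the real `ChunkType` enum):

```ts
// Stand-in union; the real ChunkType enum lives elsewhere in the app.
type Chunk = { type: 'LLM_RESPONSE_CREATED' } | { type: 'IMAGE_CREATED' }

function onChunk(chunk: Chunk): void {
  switch (chunk.type) {
    case 'LLM_RESPONSE_CREATED':
      // the message thunk creates a placeholder block here
      break
    case 'IMAGE_CREATED':
      // the placeholder can now be upgraded into an image block (see the last hunk)
      break
  }
}

onChunk({ type: 'LLM_RESPONSE_CREATED' })
onChunk({ type: 'IMAGE_CREATED' })
```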
@@ -1240,9 +1140,30 @@ export abstract class BaseOpenAiProvider extends BaseProvider {
   }
 }
 
-export default class OpenAIResponseProvider extends BaseOpenAiProvider {
+export default class OpenAIResponseProvider extends BaseOpenAIProvider {
+  private providers: Map<string, BaseOpenAIProvider> = new Map()
+
   constructor(provider: Provider) {
     super(provider)
+    this.providers.set('openai-compatible', new OpenAIProvider(provider))
+  }
+
+  private getProvider(model: Model): BaseOpenAIProvider {
+    if (isOpenAIWebSearch(model)) {
+      return this.providers.get('openai-compatible')!
+    } else {
+      return this
+    }
+  }
+
+  public completions(params: CompletionsParams): Promise<void> {
+    const model = params.assistant.model
+    if (!model) {
+      return Promise.reject(new Error('Model is required'))
+    }
+
+    const provider = this.getProvider(model)
+    return provider === this ? super.completions(params) : provider.completions(params)
   }
 
   public convertMcpTools<T>(mcpTools: MCPTool[]) {
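This hunk is the other half of the fallback removal: `OpenAIResponseProvider` now holds a map of delegate providers and forwards web-search models to the Chat Completions implementation. A self-contained sketch of that routing, with the provider classes and the `isOpenAIWebSearch` predicate stubbed out:

```ts
type Model = { id: string }
type CompletionsParams = { assistant: { model?: Model } }

// Stand-in for the real predicate imported in the diff.
const isOpenAIWebSearch = (model: Model): boolean => model.id.includes('search')

class BaseProviderSketch {
  completions(_params: CompletionsParams): Promise<void> {
    console.log('Responses API path')
    return Promise.resolve()
  }
}

class CompatibleProviderSketch extends BaseProviderSketch {
  override completions(_params: CompletionsParams): Promise<void> {
    console.log('Chat Completions path')
    return Promise.resolve()
  }
}

class ResponseProviderSketch extends BaseProviderSketch {
  private providers = new Map<string, BaseProviderSketch>([
    ['openai-compatible', new CompatibleProviderSketch()]
  ])

  private getProvider(model: Model): BaseProviderSketch {
    return isOpenAIWebSearch(model) ? this.providers.get('openai-compatible')! : this
  }

  override completions(params: CompletionsParams): Promise<void> {
    const model = params.assistant.model
    if (!model) {
      return Promise.reject(new Error('Model is required'))
    }
    const provider = this.getProvider(model)
    // Delegate unless this instance is the right handler itself.
    return provider === this ? super.completions(params) : provider.completions(params)
  }
}

// Usage: a web-search model routes to the compatible provider.
new ResponseProviderSketch().completions({ assistant: { model: { id: 'gpt-4o-search-preview' } } })
```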
@@ -536,10 +536,22 @@ const fetchAndProcessAssistantResponseImpl = async (
           }
         },
         onImageCreated: () => {
-          const imageBlock = createImageBlock(assistantMsgId, {
-            status: MessageBlockStatus.PROCESSING
-          })
-          handleBlockTransition(imageBlock, MessageBlockType.IMAGE)
+          if (lastBlockId) {
+            if (lastBlockType === MessageBlockType.UNKNOWN) {
+              const initialChanges: Partial<MessageBlock> = {
+                type: MessageBlockType.IMAGE,
+                status: MessageBlockStatus.STREAMING
+              }
+              lastBlockType = MessageBlockType.IMAGE
+              dispatch(updateOneBlock({ id: lastBlockId, changes: initialChanges }))
+              saveUpdatedBlockToDB(lastBlockId, assistantMsgId, topicId, getState)
+            }
+          } else {
+            const imageBlock = createImageBlock(assistantMsgId, {
+              status: MessageBlockStatus.PROCESSING
+            })
+            handleBlockTransition(imageBlock, MessageBlockType.IMAGE)
+          }
         },
         onImageGenerated: (imageData) => {
           const imageUrl = imageData.images?.[0] || 'placeholder_image_url'
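This last hunk is the fix named in the commit title: `onImageCreated` previously always appended a fresh image block; now, when a placeholder block already exists (created at `LLM_RESPONSE_CREATED`), an `UNKNOWN` placeholder is converted in place, and a new block is only created when none exists. A simplified sketch of that decision, with the Redux dispatch, DB save, and block helpers replaced by stubs:

```ts
type BlockType = 'UNKNOWN' | 'IMAGE'
type BlockStatus = 'PROCESSING' | 'STREAMING'

let lastBlockId: string | null = 'block-1' // id of the placeholder block, if one exists
let lastBlockType: BlockType = 'UNKNOWN'

// Stand-in for dispatch(updateOneBlock(...)) plus saveUpdatedBlockToDB(...).
function updateBlockInPlace(id: string, changes: { type: BlockType; status: BlockStatus }): void {
  console.log('upgrade block', id, changes)
}

// Stand-in for createImageBlock(...) plus handleBlockTransition(...).
function appendNewImageBlock(status: BlockStatus): void {
  console.log('append new image block', status)
}

function onImageCreated(): void {
  if (lastBlockId) {
    // A placeholder already exists: convert it in place rather than
    // appending a duplicate block to the message.
    if (lastBlockType === 'UNKNOWN') {
      lastBlockType = 'IMAGE'
      updateBlockInPlace(lastBlockId, { type: 'IMAGE', status: 'STREAMING' })
    }
  } else {
    // No placeholder yet: keep the previous behavior and create a fresh block.
    appendNewImageBlock('PROCESSING')
  }
}

onImageCreated() // upgrades 'block-1' instead of creating a second block
```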