mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-27 21:01:32 +08:00
feat(AI Provider): Update image generation handling to use legacy implementation for enhanced features
- Refactor image generation logic to use the legacy completions path, which better supports image editing.
- Introduce uiMessages as a required parameter for image generation endpoints.
- Update related types and middleware configuration to accommodate the new message structures.
- Adjust ConversationService and OrchestrationService to handle model messages and UI messages separately.
This commit is contained in:
parent
20311af8a8
commit
b8cefb8e85
@ -9,18 +9,16 @@
|
||||
|
||||
import { createExecutor } from '@cherrystudio/ai-core'
|
||||
import { loggerService } from '@logger'
|
||||
import { isNotSupportedImageSizeModel } from '@renderer/config/models'
|
||||
import { getEnableDeveloperMode } from '@renderer/hooks/useSettings'
|
||||
import { addSpan, endSpan } from '@renderer/services/SpanManagerService'
|
||||
import { StartSpanParams } from '@renderer/trace/types/ModelSpanEntity'
|
||||
import type { Assistant, GenerateImageParams, Model, Provider } from '@renderer/types'
|
||||
import type { AiSdkModel, StreamTextParams } from '@renderer/types/aiCoreTypes'
|
||||
import { ChunkType } from '@renderer/types/chunk'
|
||||
import { type ImageModel, type LanguageModel, type Provider as AiSdkProvider, wrapLanguageModel } from 'ai'
|
||||
|
||||
import AiSdkToChunkAdapter from './chunk/AiSdkToChunkAdapter'
|
||||
import LegacyAiProvider from './legacy/index'
|
||||
import { CompletionsResult } from './legacy/middleware/schemas'
|
||||
import { CompletionsParams, CompletionsResult } from './legacy/middleware/schemas'
|
||||
import { AiSdkMiddlewareConfig, buildAiSdkMiddlewares } from './middleware/AiSdkMiddlewareBuilder'
|
||||
import { buildPlugins } from './plugins/PluginBuilder'
|
||||
import { createAiSdkProvider } from './provider/factory'
|
||||
@ -140,7 +138,24 @@ export default class ModernAiProvider {
|
||||
config: ModernAiProviderConfig
|
||||
): Promise<CompletionsResult> {
|
||||
if (config.isImageGenerationEndpoint) {
|
||||
return await this.modernImageGeneration(model as ImageModel, params, config)
|
||||
// 使用 legacy 实现处理图像生成(支持图片编辑等高级功能)
|
||||
if (!config.uiMessages) {
|
||||
throw new Error('uiMessages is required for image generation endpoint')
|
||||
}
|
||||
|
||||
const legacyParams: CompletionsParams = {
|
||||
callType: 'chat',
|
||||
messages: config.uiMessages, // 使用原始的 UI 消息格式
|
||||
assistant: config.assistant,
|
||||
streamOutput: config.streamOutput ?? true,
|
||||
onChunk: config.onChunk,
|
||||
topicId: config.topicId,
|
||||
mcpTools: config.mcpTools,
|
||||
enableWebSearch: config.enableWebSearch
|
||||
}
|
||||
|
||||
// 调用 legacy 的 completions,会自动使用 ImageGenerationMiddleware
|
||||
return await this.legacyProvider.completions(legacyParams)
|
||||
}
|
||||
|
||||
return await this.modernCompletions(model as LanguageModel, params, config)
|
||||
@ -290,7 +305,9 @@ export default class ModernAiProvider {
|
||||
|
||||
/**
|
||||
* 使用现代化 AI SDK 的图像生成实现,支持流式输出
|
||||
* @deprecated 已改为使用 legacy 实现以支持图片编辑等高级功能
|
||||
*/
|
||||
/*
|
||||
private async modernImageGeneration(
|
||||
model: ImageModel,
|
||||
params: StreamTextParams,
|
||||
@ -407,6 +424,7 @@ export default class ModernAiProvider {
|
||||
throw error
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
// 代理其他方法到原有实现
|
||||
public async models() {
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import { loggerService } from '@logger'
|
||||
import type { MCPTool, Model, Provider } from '@renderer/types'
|
||||
import type { MCPTool, Message, Model, Provider } from '@renderer/types'
|
||||
import type { Chunk } from '@renderer/types/chunk'
|
||||
import { extractReasoningMiddleware, LanguageModelMiddleware, simulateStreamingMiddleware } from 'ai'
|
||||
|
||||
@ -23,6 +23,7 @@ export interface AiSdkMiddlewareConfig {
|
||||
enableWebSearch: boolean
|
||||
enableGenerateImage: boolean
|
||||
mcpTools?: MCPTool[]
|
||||
uiMessages?: Message[]
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@ -6,7 +6,6 @@
|
||||
import { loggerService } from '@logger'
|
||||
import { isVisionModel } from '@renderer/config/models'
|
||||
import type { Message, Model } from '@renderer/types'
|
||||
import type { StreamTextParams } from '@renderer/types/aiCoreTypes'
|
||||
import { FileMessageBlock, ImageMessageBlock, ThinkingMessageBlock } from '@renderer/types/newMessage'
|
||||
import {
|
||||
findFileBlocks,
|
||||
@ -154,11 +153,8 @@ async function convertMessageToAssistantModelMessage(
|
||||
/**
|
||||
* 转换 Cherry Studio 消息数组为 AI SDK 消息数组
|
||||
*/
|
||||
export async function convertMessagesToSdkMessages(
|
||||
messages: Message[],
|
||||
model: Model
|
||||
): Promise<StreamTextParams['messages']> {
|
||||
const sdkMessages: StreamTextParams['messages'] = []
|
||||
export async function convertMessagesToSdkMessages(messages: Message[], model: Model): Promise<ModelMessage[]> {
|
||||
const sdkMessages: ModelMessage[] = []
|
||||
const isVision = isVisionModel(model)
|
||||
|
||||
for (const message of messages) {
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import { PlusOutlined, RedoOutlined } from '@ant-design/icons'
|
||||
import { loggerService } from '@logger'
|
||||
import AiProviderNew from '@renderer/aiCore/index_new'
|
||||
import AiProvider from '@renderer/aiCore'
|
||||
import IcImageUp from '@renderer/assets/images/paintings/ic_ImageUp.svg'
|
||||
import { Navbar, NavbarCenter, NavbarRight } from '@renderer/components/app/Navbar'
|
||||
import { HStack } from '@renderer/components/Layout'
|
||||
@ -203,12 +203,7 @@ const AihubmixPage: FC<{ Options: string[] }> = ({ Options }) => {
|
||||
try {
|
||||
if (mode === 'aihubmix_image_generate') {
|
||||
if (painting.model.startsWith('imagen-')) {
|
||||
const AI = new AiProviderNew({
|
||||
id: painting.model,
|
||||
provider: 'aihubmix',
|
||||
name: painting.model,
|
||||
group: 'imagen'
|
||||
})
|
||||
const AI = new AiProvider(aihubmixProvider)
|
||||
const base64s = await AI.generateImage({
|
||||
prompt,
|
||||
model: painting.model,
|
||||
|
||||
@ -83,7 +83,8 @@ export async function fetchChatCompletion({
|
||||
assistant,
|
||||
options,
|
||||
onChunkReceived,
|
||||
topicId
|
||||
topicId,
|
||||
uiMessages
|
||||
}: FetchChatCompletionParams) {
|
||||
logger.info('fetchChatCompletion called with detailed context', {
|
||||
messageCount: messages?.length || 0,
|
||||
@ -132,7 +133,8 @@ export async function fetchChatCompletion({
|
||||
isImageGenerationEndpoint: isDedicatedImageGenerationModel(assistant.model || getDefaultModel()),
|
||||
enableWebSearch: capabilities.enableWebSearch,
|
||||
enableGenerateImage: capabilities.enableGenerateImage,
|
||||
mcpTools
|
||||
mcpTools,
|
||||
uiMessages
|
||||
}
|
||||
|
||||
// --- Call AI Completions ---
|
||||
@ -141,7 +143,8 @@ export async function fetchChatCompletion({
|
||||
...middlewareConfig,
|
||||
assistant,
|
||||
topicId,
|
||||
callType: 'chat'
|
||||
callType: 'chat',
|
||||
uiMessages
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import { convertMessagesToSdkMessages } from '@renderer/aiCore/prepareParams'
|
||||
import { Assistant, Message } from '@renderer/types'
|
||||
import type { StreamTextParams } from '@renderer/types/aiCoreTypes'
|
||||
import { filterAdjacentUserMessaegs, filterLastAssistantMessage } from '@renderer/utils/messageUtils/filters'
|
||||
import { ModelMessage } from 'ai'
|
||||
import { findLast, isEmpty, takeRight } from 'lodash'
|
||||
|
||||
import { getAssistantSettings, getDefaultModel } from './AssistantService'
|
||||
@ -16,13 +16,16 @@ export class ConversationService {
|
||||
static async prepareMessagesForModel(
|
||||
messages: Message[],
|
||||
assistant: Assistant
|
||||
): Promise<StreamTextParams['messages']> {
|
||||
): Promise<{ modelMessages: ModelMessage[]; uiMessages: Message[] }> {
|
||||
const { contextCount } = getAssistantSettings(assistant)
|
||||
// This logic is extracted from the original ApiService.fetchChatCompletion
|
||||
// const contextMessages = filterContextMessages(messages)
|
||||
const lastUserMessage = findLast(messages, (m) => m.role === 'user')
|
||||
if (!lastUserMessage) {
|
||||
return
|
||||
return {
|
||||
modelMessages: [],
|
||||
uiMessages: []
|
||||
}
|
||||
}
|
||||
|
||||
const filteredMessages1 = filterAfterContextClearMessages(messages)
|
||||
@ -33,16 +36,19 @@ export class ConversationService {
|
||||
|
||||
const filteredMessages4 = filterAdjacentUserMessaegs(filteredMessages3)
|
||||
|
||||
let _messages = filterUserRoleStartMessages(
|
||||
let uiMessages = filterUserRoleStartMessages(
|
||||
filterEmptyMessages(filterAfterContextClearMessages(takeRight(filteredMessages4, contextCount + 2))) // 取原来几个provider的最大值
|
||||
)
|
||||
|
||||
// Fallback: ensure at least the last user message is present to avoid empty payloads
|
||||
if ((!_messages || _messages.length === 0) && lastUserMessage) {
|
||||
_messages = [lastUserMessage]
|
||||
if ((!uiMessages || uiMessages.length === 0) && lastUserMessage) {
|
||||
uiMessages = [lastUserMessage]
|
||||
}
|
||||
|
||||
return await convertMessagesToSdkMessages(_messages, assistant.model || getDefaultModel())
|
||||
return {
|
||||
modelMessages: await convertMessagesToSdkMessages(uiMessages, assistant.model || getDefaultModel()),
|
||||
uiMessages
|
||||
}
|
||||
}
|
||||
|
||||
static needsWebSearch(assistant: Assistant): boolean {
|
||||
|
||||
@ -42,14 +42,15 @@ export class OrchestrationService {
|
||||
const { messages, assistant } = request
|
||||
|
||||
try {
|
||||
const llmMessages = await ConversationService.prepareMessagesForModel(messages, assistant)
|
||||
const { modelMessages, uiMessages } = await ConversationService.prepareMessagesForModel(messages, assistant)
|
||||
|
||||
await fetchChatCompletion({
|
||||
messages: llmMessages,
|
||||
messages: modelMessages,
|
||||
assistant: assistant,
|
||||
options: request.options,
|
||||
onChunkReceived,
|
||||
topicId: request.topicId
|
||||
topicId: request.topicId,
|
||||
uiMessages: uiMessages
|
||||
})
|
||||
} catch (error: any) {
|
||||
onChunkReceived({ type: ChunkType.ERROR, error })
|
||||
@ -70,17 +71,18 @@ export async function transformMessagesAndFetch(
|
||||
const { messages, assistant } = request
|
||||
|
||||
try {
|
||||
const llmMessages = await ConversationService.prepareMessagesForModel(messages, assistant)
|
||||
const { modelMessages, uiMessages } = await ConversationService.prepareMessagesForModel(messages, assistant)
|
||||
|
||||
// replace prompt variables
|
||||
assistant.prompt = await replacePromptVariables(assistant.prompt, assistant.model?.name)
|
||||
|
||||
await fetchChatCompletion({
|
||||
messages: llmMessages,
|
||||
messages: modelMessages,
|
||||
assistant: assistant,
|
||||
options: request.options,
|
||||
onChunkReceived,
|
||||
topicId: request.topicId
|
||||
topicId: request.topicId,
|
||||
uiMessages
|
||||
})
|
||||
} catch (error: any) {
|
||||
onChunkReceived({ type: ChunkType.ERROR, error })
|
||||
|
||||
@ -1307,6 +1307,7 @@ type BaseParams = {
|
||||
options?: FetchChatCompletionOptions
|
||||
onChunkReceived: (chunk: Chunk) => void
|
||||
topicId?: string // 添加 topicId 参数
|
||||
uiMessages: Message[]
|
||||
}
|
||||
|
||||
type MessagesParams = BaseParams & {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user