feat(AI Provider): Route image generation through the legacy implementation to support image editing and other advanced features

- Refactor the image generation logic to delegate to legacy completions, which better support image editing.
- Introduce uiMessages as a required parameter for image generation endpoints.
- Update related types and middleware configuration to carry the new message structures.
- Adjust ConversationService and OrchestrationService to handle model messages and UI messages separately.
MyPrototypeWhat 2025-09-04 12:46:57 +08:00
parent 20311af8a8
commit b8cefb8e85
8 changed files with 56 additions and 34 deletions
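
Review note: taken together, these changes split message preparation into two parallel representations: ModelMessage[] for the AI SDK text path and the raw Cherry Studio Message[] for the legacy image path. A minimal sketch of the resulting flow in TypeScript, using only names that appear in the hunks below:

// Sketch: prepareMessagesForModel now returns both message shapes.
const { modelMessages, uiMessages } = await ConversationService.prepareMessagesForModel(messages, assistant)

await fetchChatCompletion({
  messages: modelMessages, // AI SDK format, used for text completions
  uiMessages,              // original UI format, consumed by the legacy image pipeline
  assistant,
  options: request.options,
  onChunkReceived,
  topicId: request.topicId
})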

View File

@@ -9,18 +9,16 @@
import { createExecutor } from '@cherrystudio/ai-core'
import { loggerService } from '@logger'
import { isNotSupportedImageSizeModel } from '@renderer/config/models'
import { getEnableDeveloperMode } from '@renderer/hooks/useSettings'
import { addSpan, endSpan } from '@renderer/services/SpanManagerService'
import { StartSpanParams } from '@renderer/trace/types/ModelSpanEntity'
import type { Assistant, GenerateImageParams, Model, Provider } from '@renderer/types'
import type { AiSdkModel, StreamTextParams } from '@renderer/types/aiCoreTypes'
import { ChunkType } from '@renderer/types/chunk'
import { type ImageModel, type LanguageModel, type Provider as AiSdkProvider, wrapLanguageModel } from 'ai'
import AiSdkToChunkAdapter from './chunk/AiSdkToChunkAdapter'
import LegacyAiProvider from './legacy/index'
import { CompletionsResult } from './legacy/middleware/schemas'
import { CompletionsParams, CompletionsResult } from './legacy/middleware/schemas'
import { AiSdkMiddlewareConfig, buildAiSdkMiddlewares } from './middleware/AiSdkMiddlewareBuilder'
import { buildPlugins } from './plugins/PluginBuilder'
import { createAiSdkProvider } from './provider/factory'
@@ -140,7 +138,24 @@ export default class ModernAiProvider {
config: ModernAiProviderConfig
): Promise<CompletionsResult> {
if (config.isImageGenerationEndpoint) {
return await this.modernImageGeneration(model as ImageModel, params, config)
// Use the legacy implementation for image generation (supports image editing and other advanced features)
if (!config.uiMessages) {
throw new Error('uiMessages is required for image generation endpoint')
}
const legacyParams: CompletionsParams = {
callType: 'chat',
messages: config.uiMessages, // use the original UI message format
assistant: config.assistant,
streamOutput: config.streamOutput ?? true,
onChunk: config.onChunk,
topicId: config.topicId,
mcpTools: config.mcpTools,
enableWebSearch: config.enableWebSearch
}
// Calling the legacy completions automatically applies ImageGenerationMiddleware
return await this.legacyProvider.completions(legacyParams)
}
return await this.modernCompletions(model as LanguageModel, params, config)
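Note the contract introduced above: any call that reaches the image-generation branch must carry config.uiMessages, otherwise completions throws before any request is made. A minimal sketch of a conforming call, assuming ModernAiProviderConfig includes the middleware fields shown further down:

const config: ModernAiProviderConfig = {
  isImageGenerationEndpoint: true,
  enableWebSearch: false,
  enableGenerateImage: true,
  uiMessages, // required on this branch; text-only calls may omit it
  assistant,
  streamOutput: true,
  onChunk,
  topicId
}
// Image requests are delegated to legacyProvider.completions(), which applies
// ImageGenerationMiddleware automatically; text requests still go through modernCompletions().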
@@ -290,7 +305,9 @@ export default class ModernAiProvider {
/**
* Image generation using the AI SDK
* @deprecated superseded by the legacy implementation
*/
/*
private async modernImageGeneration(
model: ImageModel,
params: StreamTextParams,
@@ -407,6 +424,7 @@ export default class ModernAiProvider {
throw error
}
}
*/
// Proxy the remaining methods to the original implementation
public async models() {

View File

@@ -1,5 +1,5 @@
import { loggerService } from '@logger'
import type { MCPTool, Model, Provider } from '@renderer/types'
import type { MCPTool, Message, Model, Provider } from '@renderer/types'
import type { Chunk } from '@renderer/types/chunk'
import { extractReasoningMiddleware, LanguageModelMiddleware, simulateStreamingMiddleware } from 'ai'
@@ -23,6 +23,7 @@ export interface AiSdkMiddlewareConfig {
enableWebSearch: boolean
enableGenerateImage: boolean
mcpTools?: MCPTool[]
uiMessages?: Message[]
}
/**

View File

@@ -6,7 +6,6 @@
import { loggerService } from '@logger'
import { isVisionModel } from '@renderer/config/models'
import type { Message, Model } from '@renderer/types'
import type { StreamTextParams } from '@renderer/types/aiCoreTypes'
import { FileMessageBlock, ImageMessageBlock, ThinkingMessageBlock } from '@renderer/types/newMessage'
import {
findFileBlocks,
@@ -154,11 +153,8 @@
/**
* Convert Cherry Studio messages into AI SDK messages
*/
export async function convertMessagesToSdkMessages(
messages: Message[],
model: Model
): Promise<StreamTextParams['messages']> {
const sdkMessages: StreamTextParams['messages'] = []
export async function convertMessagesToSdkMessages(messages: Message[], model: Model): Promise<ModelMessage[]> {
const sdkMessages: ModelMessage[] = []
const isVision = isVisionModel(model)
for (const message of messages) {

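The signature change above decouples the converter from streaming parameters: it now returns a plain ModelMessage[] instead of StreamTextParams['messages']. A usage sketch, assuming ModelMessage is imported from 'ai' as in the other hunks of this commit:

import type { ModelMessage } from 'ai'

// The converter's output can now feed any consumer of ModelMessage[],
// not just a streamText call.
const sdkMessages: ModelMessage[] = await convertMessagesToSdkMessages(uiMessages, model)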
View File

@@ -1,6 +1,6 @@
import { PlusOutlined, RedoOutlined } from '@ant-design/icons'
import { loggerService } from '@logger'
import AiProviderNew from '@renderer/aiCore/index_new'
import AiProvider from '@renderer/aiCore'
import IcImageUp from '@renderer/assets/images/paintings/ic_ImageUp.svg'
import { Navbar, NavbarCenter, NavbarRight } from '@renderer/components/app/Navbar'
import { HStack } from '@renderer/components/Layout'
@@ -203,12 +203,7 @@ const AihubmixPage: FC<{ Options: string[] }> = ({ Options }) => {
try {
if (mode === 'aihubmix_image_generate') {
if (painting.model.startsWith('imagen-')) {
const AI = new AiProviderNew({
id: painting.model,
provider: 'aihubmix',
name: painting.model,
group: 'imagen'
})
const AI = new AiProvider(aihubmixProvider)
const base64s = await AI.generateImage({
prompt,
model: painting.model,

View File

@@ -83,7 +83,8 @@ export async function fetchChatCompletion({
assistant,
options,
onChunkReceived,
topicId
topicId,
uiMessages
}: FetchChatCompletionParams) {
logger.info('fetchChatCompletion called with detailed context', {
messageCount: messages?.length || 0,
@@ -132,7 +133,8 @@
isImageGenerationEndpoint: isDedicatedImageGenerationModel(assistant.model || getDefaultModel()),
enableWebSearch: capabilities.enableWebSearch,
enableGenerateImage: capabilities.enableGenerateImage,
mcpTools
mcpTools,
uiMessages
}
// --- Call AI Completions ---
@@ -141,7 +143,8 @@
...middlewareConfig,
assistant,
topicId,
callType: 'chat'
callType: 'chat',
uiMessages
})
}

View File

@@ -1,7 +1,7 @@
import { convertMessagesToSdkMessages } from '@renderer/aiCore/prepareParams'
import { Assistant, Message } from '@renderer/types'
import type { StreamTextParams } from '@renderer/types/aiCoreTypes'
import { filterAdjacentUserMessaegs, filterLastAssistantMessage } from '@renderer/utils/messageUtils/filters'
import { ModelMessage } from 'ai'
import { findLast, isEmpty, takeRight } from 'lodash'
import { getAssistantSettings, getDefaultModel } from './AssistantService'
@@ -16,13 +16,16 @@ export class ConversationService {
static async prepareMessagesForModel(
messages: Message[],
assistant: Assistant
): Promise<StreamTextParams['messages']> {
): Promise<{ modelMessages: ModelMessage[]; uiMessages: Message[] }> {
const { contextCount } = getAssistantSettings(assistant)
// This logic is extracted from the original ApiService.fetchChatCompletion
// const contextMessages = filterContextMessages(messages)
const lastUserMessage = findLast(messages, (m) => m.role === 'user')
if (!lastUserMessage) {
return
return {
modelMessages: [],
uiMessages: []
}
}
const filteredMessages1 = filterAfterContextClearMessages(messages)
@@ -33,16 +36,19 @@
const filteredMessages4 = filterAdjacentUserMessaegs(filteredMessages3)
let _messages = filterUserRoleStartMessages(
let uiMessages = filterUserRoleStartMessages(
filterEmptyMessages(filterAfterContextClearMessages(takeRight(filteredMessages4, contextCount + 2))) // take the maximum context size used by the previous providers
)
// Fallback: ensure at least the last user message is present to avoid empty payloads
if ((!_messages || _messages.length === 0) && lastUserMessage) {
_messages = [lastUserMessage]
if ((!uiMessages || uiMessages.length === 0) && lastUserMessage) {
uiMessages = [lastUserMessage]
}
return await convertMessagesToSdkMessages(_messages, assistant.model || getDefaultModel())
return {
modelMessages: await convertMessagesToSdkMessages(uiMessages, assistant.model || getDefaultModel()),
uiMessages
}
}
static needsWebSearch(assistant: Assistant): boolean {

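One subtlety in the hunk above: callers now destructure the result, so the old bare return (which yielded undefined when no user message was found) would throw at the destructuring site. Returning empty arrays keeps the contract total. A sketch of the edge case, using a hypothetical empty history:

// With no user message in the history, both arrays come back empty
// instead of the whole result being undefined.
const { modelMessages, uiMessages } = await ConversationService.prepareMessagesForModel([], assistant)
console.assert(modelMessages.length === 0 && uiMessages.length === 0)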
View File

@@ -42,14 +42,15 @@ export class OrchestrationService {
const { messages, assistant } = request
try {
const llmMessages = await ConversationService.prepareMessagesForModel(messages, assistant)
const { modelMessages, uiMessages } = await ConversationService.prepareMessagesForModel(messages, assistant)
await fetchChatCompletion({
messages: llmMessages,
messages: modelMessages,
assistant: assistant,
options: request.options,
onChunkReceived,
topicId: request.topicId
topicId: request.topicId,
uiMessages: uiMessages
})
} catch (error: any) {
onChunkReceived({ type: ChunkType.ERROR, error })
@@ -70,17 +71,18 @@ export async function transformMessagesAndFetch(
const { messages, assistant } = request
try {
const llmMessages = await ConversationService.prepareMessagesForModel(messages, assistant)
const { modelMessages, uiMessages } = await ConversationService.prepareMessagesForModel(messages, assistant)
// replace prompt variables
assistant.prompt = await replacePromptVariables(assistant.prompt, assistant.model?.name)
await fetchChatCompletion({
messages: llmMessages,
messages: modelMessages,
assistant: assistant,
options: request.options,
onChunkReceived,
topicId: request.topicId
topicId: request.topicId,
uiMessages
})
} catch (error: any) {
onChunkReceived({ type: ChunkType.ERROR, error })

View File

@@ -1307,6 +1307,7 @@ type BaseParams = {
options?: FetchChatCompletionOptions
onChunkReceived: (chunk: Chunk) => void
topicId?: string // added topicId parameter
uiMessages: Message[]
}
type MessagesParams = BaseParams & {