diff --git a/src/renderer/src/aiCore/index_new.ts b/src/renderer/src/aiCore/index_new.ts index 071b6102e9..9991dffd1f 100644 --- a/src/renderer/src/aiCore/index_new.ts +++ b/src/renderer/src/aiCore/index_new.ts @@ -9,18 +9,16 @@ import { createExecutor } from '@cherrystudio/ai-core' import { loggerService } from '@logger' -import { isNotSupportedImageSizeModel } from '@renderer/config/models' import { getEnableDeveloperMode } from '@renderer/hooks/useSettings' import { addSpan, endSpan } from '@renderer/services/SpanManagerService' import { StartSpanParams } from '@renderer/trace/types/ModelSpanEntity' import type { Assistant, GenerateImageParams, Model, Provider } from '@renderer/types' import type { AiSdkModel, StreamTextParams } from '@renderer/types/aiCoreTypes' -import { ChunkType } from '@renderer/types/chunk' import { type ImageModel, type LanguageModel, type Provider as AiSdkProvider, wrapLanguageModel } from 'ai' import AiSdkToChunkAdapter from './chunk/AiSdkToChunkAdapter' import LegacyAiProvider from './legacy/index' -import { CompletionsResult } from './legacy/middleware/schemas' +import { CompletionsParams, CompletionsResult } from './legacy/middleware/schemas' import { AiSdkMiddlewareConfig, buildAiSdkMiddlewares } from './middleware/AiSdkMiddlewareBuilder' import { buildPlugins } from './plugins/PluginBuilder' import { createAiSdkProvider } from './provider/factory' @@ -140,7 +138,24 @@ export default class ModernAiProvider { config: ModernAiProviderConfig ): Promise { if (config.isImageGenerationEndpoint) { - return await this.modernImageGeneration(model as ImageModel, params, config) + // 使用 legacy 实现处理图像生成(支持图片编辑等高级功能) + if (!config.uiMessages) { + throw new Error('uiMessages is required for image generation endpoint') + } + + const legacyParams: CompletionsParams = { + callType: 'chat', + messages: config.uiMessages, // 使用原始的 UI 消息格式 + assistant: config.assistant, + streamOutput: config.streamOutput ?? true, + onChunk: config.onChunk, + topicId: config.topicId, + mcpTools: config.mcpTools, + enableWebSearch: config.enableWebSearch + } + + // 调用 legacy 的 completions,会自动使用 ImageGenerationMiddleware + return await this.legacyProvider.completions(legacyParams) } return await this.modernCompletions(model as LanguageModel, params, config) @@ -290,7 +305,9 @@ export default class ModernAiProvider { /** * 使用现代化 AI SDK 的图像生成实现,支持流式输出 + * @deprecated 已改为使用 legacy 实现以支持图片编辑等高级功能 */ + /* private async modernImageGeneration( model: ImageModel, params: StreamTextParams, @@ -407,6 +424,7 @@ export default class ModernAiProvider { throw error } } + */ // 代理其他方法到原有实现 public async models() { diff --git a/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts b/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts index faef00051f..f331d36a7e 100644 --- a/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts +++ b/src/renderer/src/aiCore/middleware/AiSdkMiddlewareBuilder.ts @@ -1,5 +1,5 @@ import { loggerService } from '@logger' -import type { MCPTool, Model, Provider } from '@renderer/types' +import type { MCPTool, Message, Model, Provider } from '@renderer/types' import type { Chunk } from '@renderer/types/chunk' import { extractReasoningMiddleware, LanguageModelMiddleware, simulateStreamingMiddleware } from 'ai' @@ -23,6 +23,7 @@ export interface AiSdkMiddlewareConfig { enableWebSearch: boolean enableGenerateImage: boolean mcpTools?: MCPTool[] + uiMessages?: Message[] } /** diff --git a/src/renderer/src/aiCore/prepareParams/messageConverter.ts b/src/renderer/src/aiCore/prepareParams/messageConverter.ts index a869b19d8f..d11f25fc2c 100644 --- a/src/renderer/src/aiCore/prepareParams/messageConverter.ts +++ b/src/renderer/src/aiCore/prepareParams/messageConverter.ts @@ -6,7 +6,6 @@ import { loggerService } from '@logger' import { isVisionModel } from '@renderer/config/models' import type { Message, Model } from '@renderer/types' -import type { StreamTextParams } from '@renderer/types/aiCoreTypes' import { FileMessageBlock, ImageMessageBlock, ThinkingMessageBlock } from '@renderer/types/newMessage' import { findFileBlocks, @@ -154,11 +153,8 @@ async function convertMessageToAssistantModelMessage( /** * 转换 Cherry Studio 消息数组为 AI SDK 消息数组 */ -export async function convertMessagesToSdkMessages( - messages: Message[], - model: Model -): Promise { - const sdkMessages: StreamTextParams['messages'] = [] +export async function convertMessagesToSdkMessages(messages: Message[], model: Model): Promise { + const sdkMessages: ModelMessage[] = [] const isVision = isVisionModel(model) for (const message of messages) { diff --git a/src/renderer/src/pages/paintings/AihubmixPage.tsx b/src/renderer/src/pages/paintings/AihubmixPage.tsx index 78fd923b9d..e9bbac452a 100644 --- a/src/renderer/src/pages/paintings/AihubmixPage.tsx +++ b/src/renderer/src/pages/paintings/AihubmixPage.tsx @@ -1,6 +1,6 @@ import { PlusOutlined, RedoOutlined } from '@ant-design/icons' import { loggerService } from '@logger' -import AiProviderNew from '@renderer/aiCore/index_new' +import AiProvider from '@renderer/aiCore' import IcImageUp from '@renderer/assets/images/paintings/ic_ImageUp.svg' import { Navbar, NavbarCenter, NavbarRight } from '@renderer/components/app/Navbar' import { HStack } from '@renderer/components/Layout' @@ -203,12 +203,7 @@ const AihubmixPage: FC<{ Options: string[] }> = ({ Options }) => { try { if (mode === 'aihubmix_image_generate') { if (painting.model.startsWith('imagen-')) { - const AI = new AiProviderNew({ - id: painting.model, - provider: 'aihubmix', - name: painting.model, - group: 'imagen' - }) + const AI = new AiProvider(aihubmixProvider) const base64s = await AI.generateImage({ prompt, model: painting.model, diff --git a/src/renderer/src/services/ApiService.ts b/src/renderer/src/services/ApiService.ts index 3e09050ea3..2f69eea8e0 100644 --- a/src/renderer/src/services/ApiService.ts +++ b/src/renderer/src/services/ApiService.ts @@ -83,7 +83,8 @@ export async function fetchChatCompletion({ assistant, options, onChunkReceived, - topicId + topicId, + uiMessages }: FetchChatCompletionParams) { logger.info('fetchChatCompletion called with detailed context', { messageCount: messages?.length || 0, @@ -132,7 +133,8 @@ export async function fetchChatCompletion({ isImageGenerationEndpoint: isDedicatedImageGenerationModel(assistant.model || getDefaultModel()), enableWebSearch: capabilities.enableWebSearch, enableGenerateImage: capabilities.enableGenerateImage, - mcpTools + mcpTools, + uiMessages } // --- Call AI Completions --- @@ -141,7 +143,8 @@ export async function fetchChatCompletion({ ...middlewareConfig, assistant, topicId, - callType: 'chat' + callType: 'chat', + uiMessages }) } diff --git a/src/renderer/src/services/ConversationService.ts b/src/renderer/src/services/ConversationService.ts index 269424cc0f..a7f3fab13c 100644 --- a/src/renderer/src/services/ConversationService.ts +++ b/src/renderer/src/services/ConversationService.ts @@ -1,7 +1,7 @@ import { convertMessagesToSdkMessages } from '@renderer/aiCore/prepareParams' import { Assistant, Message } from '@renderer/types' -import type { StreamTextParams } from '@renderer/types/aiCoreTypes' import { filterAdjacentUserMessaegs, filterLastAssistantMessage } from '@renderer/utils/messageUtils/filters' +import { ModelMessage } from 'ai' import { findLast, isEmpty, takeRight } from 'lodash' import { getAssistantSettings, getDefaultModel } from './AssistantService' @@ -16,13 +16,16 @@ export class ConversationService { static async prepareMessagesForModel( messages: Message[], assistant: Assistant - ): Promise { + ): Promise<{ modelMessages: ModelMessage[]; uiMessages: Message[] }> { const { contextCount } = getAssistantSettings(assistant) // This logic is extracted from the original ApiService.fetchChatCompletion // const contextMessages = filterContextMessages(messages) const lastUserMessage = findLast(messages, (m) => m.role === 'user') if (!lastUserMessage) { - return + return { + modelMessages: [], + uiMessages: [] + } } const filteredMessages1 = filterAfterContextClearMessages(messages) @@ -33,16 +36,19 @@ export class ConversationService { const filteredMessages4 = filterAdjacentUserMessaegs(filteredMessages3) - let _messages = filterUserRoleStartMessages( + let uiMessages = filterUserRoleStartMessages( filterEmptyMessages(filterAfterContextClearMessages(takeRight(filteredMessages4, contextCount + 2))) // 取原来几个provider的最大值 ) // Fallback: ensure at least the last user message is present to avoid empty payloads - if ((!_messages || _messages.length === 0) && lastUserMessage) { - _messages = [lastUserMessage] + if ((!uiMessages || uiMessages.length === 0) && lastUserMessage) { + uiMessages = [lastUserMessage] } - return await convertMessagesToSdkMessages(_messages, assistant.model || getDefaultModel()) + return { + modelMessages: await convertMessagesToSdkMessages(uiMessages, assistant.model || getDefaultModel()), + uiMessages + } } static needsWebSearch(assistant: Assistant): boolean { diff --git a/src/renderer/src/services/OrchestrateService.ts b/src/renderer/src/services/OrchestrateService.ts index 996ee3f923..eef206d14e 100644 --- a/src/renderer/src/services/OrchestrateService.ts +++ b/src/renderer/src/services/OrchestrateService.ts @@ -42,14 +42,15 @@ export class OrchestrationService { const { messages, assistant } = request try { - const llmMessages = await ConversationService.prepareMessagesForModel(messages, assistant) + const { modelMessages, uiMessages } = await ConversationService.prepareMessagesForModel(messages, assistant) await fetchChatCompletion({ - messages: llmMessages, + messages: modelMessages, assistant: assistant, options: request.options, onChunkReceived, - topicId: request.topicId + topicId: request.topicId, + uiMessages: uiMessages }) } catch (error: any) { onChunkReceived({ type: ChunkType.ERROR, error }) @@ -70,17 +71,18 @@ export async function transformMessagesAndFetch( const { messages, assistant } = request try { - const llmMessages = await ConversationService.prepareMessagesForModel(messages, assistant) + const { modelMessages, uiMessages } = await ConversationService.prepareMessagesForModel(messages, assistant) // replace prompt variables assistant.prompt = await replacePromptVariables(assistant.prompt, assistant.model?.name) await fetchChatCompletion({ - messages: llmMessages, + messages: modelMessages, assistant: assistant, options: request.options, onChunkReceived, - topicId: request.topicId + topicId: request.topicId, + uiMessages }) } catch (error: any) { onChunkReceived({ type: ChunkType.ERROR, error }) diff --git a/src/renderer/src/types/index.ts b/src/renderer/src/types/index.ts index d38a739676..e8a60c412d 100644 --- a/src/renderer/src/types/index.ts +++ b/src/renderer/src/types/index.ts @@ -1307,6 +1307,7 @@ type BaseParams = { options?: FetchChatCompletionOptions onChunkReceived: (chunk: Chunk) => void topicId?: string // 添加 topicId 参数 + uiMessages: Message[] } type MessagesParams = BaseParams & {