mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-28 05:11:24 +08:00
refactor: enhance image generation handling and tool integration
- Updated image generation logic to support new model types and improved size handling. - Refactored middleware configuration to better manage tool usage and reasoning capabilities. - Introduced new utility functions for checking model compatibility with image generation. - Enhanced the integration of plugins for improved functionality during image generation processes. - Removed deprecated knowledge search tool to streamline the codebase.
This commit is contained in:
parent
ecc08bd3f7
commit
71959f577d
@ -253,16 +253,16 @@ export class AiSdkToChunkAdapter {
|
||||
// })
|
||||
}
|
||||
break
|
||||
// case 'file':
|
||||
// // 文件相关事件,可能是图片生成
|
||||
// this.onChunk({
|
||||
// type: ChunkType.IMAGE_COMPLETE,
|
||||
// image: {
|
||||
// type: 'base64',
|
||||
// images: [chunk.base64]
|
||||
// }
|
||||
// })
|
||||
// break
|
||||
case 'file':
|
||||
// 文件相关事件,可能是图片生成
|
||||
this.onChunk({
|
||||
type: ChunkType.IMAGE_COMPLETE,
|
||||
image: {
|
||||
type: 'base64',
|
||||
images: [`data:${chunk.file.mediaType};base64,${chunk.file.base64}`]
|
||||
}
|
||||
})
|
||||
break
|
||||
case 'error':
|
||||
this.onChunk({
|
||||
type: ChunkType.ERROR,
|
||||
|
||||
@ -64,7 +64,7 @@ import {
|
||||
import { addImageFileToContents } from '@renderer/utils/formats'
|
||||
import {
|
||||
anthropicToolUseToMcpTool,
|
||||
isEnabledToolUse,
|
||||
isSupportedToolUse,
|
||||
mcpToolCallResponseToAnthropicMessage,
|
||||
mcpToolsToAnthropicTools
|
||||
} from '@renderer/utils/mcp-tools'
|
||||
@ -454,7 +454,7 @@ export class AnthropicAPIClient extends BaseApiClient<
|
||||
const { tools } = this.setupToolsConfig({
|
||||
mcpTools: mcpTools,
|
||||
model,
|
||||
enableToolUse: isEnabledToolUse(assistant)
|
||||
enableToolUse: isSupportedToolUse(assistant)
|
||||
})
|
||||
|
||||
if (this.useSystemPromptForTools) {
|
||||
|
||||
@ -53,7 +53,7 @@ import {
|
||||
} from '@renderer/types/sdk'
|
||||
import {
|
||||
geminiFunctionCallToMcpTool,
|
||||
isEnabledToolUse,
|
||||
isSupportedToolUse,
|
||||
mcpToolCallResponseToGeminiMessage,
|
||||
mcpToolsToGeminiTools
|
||||
} from '@renderer/utils/mcp-tools'
|
||||
@ -451,7 +451,7 @@ export class GeminiAPIClient extends BaseApiClient<
|
||||
const { tools } = this.setupToolsConfig({
|
||||
mcpTools,
|
||||
model,
|
||||
enableToolUse: isEnabledToolUse(assistant)
|
||||
enableToolUse: isSupportedToolUse(assistant)
|
||||
})
|
||||
|
||||
if (this.useSystemPromptForTools) {
|
||||
|
||||
@ -44,7 +44,7 @@ import {
|
||||
} from '@renderer/types/sdk'
|
||||
import { addImageFileToContents } from '@renderer/utils/formats'
|
||||
import {
|
||||
isEnabledToolUse,
|
||||
isSupportedToolUse,
|
||||
mcpToolCallResponseToOpenAICompatibleMessage,
|
||||
mcpToolsToOpenAIChatTools,
|
||||
openAIToolsToMcpTool
|
||||
@ -479,7 +479,7 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
|
||||
const { tools } = this.setupToolsConfig({
|
||||
mcpTools: mcpTools,
|
||||
model,
|
||||
enableToolUse: isEnabledToolUse(assistant)
|
||||
enableToolUse: isSupportedToolUse(assistant)
|
||||
})
|
||||
|
||||
if (this.useSystemPromptForTools) {
|
||||
|
||||
@ -30,7 +30,7 @@ import {
|
||||
} from '@renderer/types/sdk'
|
||||
import { addImageFileToContents } from '@renderer/utils/formats'
|
||||
import {
|
||||
isEnabledToolUse,
|
||||
isSupportedToolUse,
|
||||
mcpToolCallResponseToOpenAIMessage,
|
||||
mcpToolsToOpenAIResponseTools,
|
||||
openAIToolsToMcpTool
|
||||
@ -362,7 +362,7 @@ export class OpenAIResponseAPIClient extends OpenAIBaseClient<
|
||||
const { tools: extraTools } = this.setupToolsConfig({
|
||||
mcpTools: mcpTools,
|
||||
model,
|
||||
enableToolUse: isEnabledToolUse(assistant)
|
||||
enableToolUse: isSupportedToolUse(assistant)
|
||||
})
|
||||
|
||||
if (this.useSystemPromptForTools) {
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
import { ApiClientFactory } from '@renderer/aiCore/clients/ApiClientFactory'
|
||||
import { BaseApiClient } from '@renderer/aiCore/clients/BaseApiClient'
|
||||
import { isDedicatedImageGenerationModel, isFunctionCallingModel } from '@renderer/config/models'
|
||||
import { isDedicatedImageGenerationModel } from '@renderer/config/models'
|
||||
import type { GenerateImageParams, Model, Provider } from '@renderer/types'
|
||||
import { RequestOptions, SdkModel } from '@renderer/types/sdk'
|
||||
import { isEnabledToolUse } from '@renderer/utils/mcp-tools'
|
||||
import { isPromptToolUse } from '@renderer/utils/mcp-tools'
|
||||
|
||||
import { OpenAIAPIClient } from './clients'
|
||||
import { AihubmixAPIClient } from './clients/AihubmixAPIClient'
|
||||
@ -93,7 +93,7 @@ export default class AiProvider {
|
||||
builder.remove(ToolUseExtractionMiddlewareName)
|
||||
builder.remove(McpToolChunkMiddlewareName)
|
||||
}
|
||||
if (isEnabledToolUse(params.assistant) && isFunctionCallingModel(model)) {
|
||||
if (!isPromptToolUse(params.assistant)) {
|
||||
builder.remove(ToolUseExtractionMiddlewareName)
|
||||
}
|
||||
if (params.callType !== 'chat') {
|
||||
|
||||
@ -19,7 +19,7 @@ import {
|
||||
StreamTextParams
|
||||
} from '@cherrystudio/ai-core'
|
||||
import { createPromptToolUsePlugin, webSearchPlugin } from '@cherrystudio/ai-core/built-in/plugins'
|
||||
import { isDedicatedImageGenerationModel } from '@renderer/config/models'
|
||||
import { isDedicatedImageGenerationModel, isNotSupportedImageSizeModel } from '@renderer/config/models'
|
||||
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
|
||||
import { getProviderByModel } from '@renderer/services/AssistantService'
|
||||
import type { GenerateImageParams, Model, Provider } from '@renderer/types'
|
||||
@ -69,10 +69,10 @@ function providerToAiSdkConfig(actualProvider: Provider): {
|
||||
const aiSdkProviderId = getAiSdkProviderId(actualProvider)
|
||||
// console.log('aiSdkProviderId', aiSdkProviderId)
|
||||
// 如果provider是openai,则使用strict模式并且默认responses api
|
||||
const actualProviderId = actualProvider.id
|
||||
const actualProviderId = actualProvider.type
|
||||
const openaiResponseOptions =
|
||||
// 对于实际是openai的需要走responses,aiCore内部会判断model是否可用responses
|
||||
actualProviderId === 'openai'
|
||||
actualProviderId === 'openai-response'
|
||||
? {
|
||||
mode: 'responses'
|
||||
}
|
||||
@ -167,15 +167,18 @@ export default class ModernAiProvider {
|
||||
// 内置了默认搜索参数,如果改的话可以传config进去
|
||||
plugins.push(webSearchPlugin())
|
||||
}
|
||||
plugins.push(searchOrchestrationPlugin(middlewareConfig.assistant))
|
||||
// 2. 支持工具调用时添加搜索插件
|
||||
if (middlewareConfig.isSupportedToolUse) {
|
||||
plugins.push(searchOrchestrationPlugin(middlewareConfig.assistant))
|
||||
}
|
||||
|
||||
// 2. 推理模型时添加推理插件
|
||||
// 3. 推理模型时添加推理插件
|
||||
if (middlewareConfig.enableReasoning) {
|
||||
plugins.push(reasoningTimePlugin)
|
||||
}
|
||||
|
||||
// 3. 启用Prompt工具调用时添加工具插件
|
||||
if (middlewareConfig.enableTool && middlewareConfig.mcpTools && middlewareConfig.mcpTools.length > 0) {
|
||||
// 4. 启用Prompt工具调用时添加工具插件
|
||||
if (middlewareConfig.isPromptToolUse && middlewareConfig.mcpTools && middlewareConfig.mcpTools.length > 0) {
|
||||
plugins.push(
|
||||
createPromptToolUsePlugin({
|
||||
enabled: true,
|
||||
@ -216,8 +219,7 @@ export default class ModernAiProvider {
|
||||
): Promise<CompletionsResult> {
|
||||
console.log('completions', modelId, params, middlewareConfig)
|
||||
|
||||
// 检查是否为图像生成模型
|
||||
if (middlewareConfig.model && isDedicatedImageGenerationModel(middlewareConfig.model)) {
|
||||
if (middlewareConfig.isImageGenerationEndpoint) {
|
||||
return await this.modernImageGeneration(modelId, params, middlewareConfig)
|
||||
}
|
||||
|
||||
@ -313,17 +315,17 @@ export default class ModernAiProvider {
|
||||
throw new Error('No prompt found in user message.')
|
||||
}
|
||||
|
||||
const startTime = Date.now()
|
||||
|
||||
// 发送图像生成开始事件
|
||||
if (onChunk) {
|
||||
onChunk({ type: ChunkType.IMAGE_CREATED })
|
||||
}
|
||||
|
||||
const startTime = Date.now()
|
||||
|
||||
// 构建图像生成参数
|
||||
const imageParams = {
|
||||
prompt,
|
||||
size: '1024x1024' as `${number}x${number}`, // 默认尺寸,使用正确的类型
|
||||
size: isNotSupportedImageSizeModel(middlewareConfig.model) ? undefined : ('1024x1024' as `${number}x${number}`), // 默认尺寸,使用正确的类型
|
||||
n: 1,
|
||||
...(params.abortSignal && { abortSignal: params.abortSignal })
|
||||
}
|
||||
@ -338,7 +340,7 @@ export default class ModernAiProvider {
|
||||
if (result.images) {
|
||||
for (const image of result.images) {
|
||||
if ('base64' in image && image.base64) {
|
||||
images.push(`data:image/png;base64,${image.base64}`)
|
||||
images.push(`data:${image.mediaType};base64,${image.base64}`)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -351,14 +353,27 @@ export default class ModernAiProvider {
|
||||
})
|
||||
}
|
||||
|
||||
// 发送响应完成事件
|
||||
// 发送块完成事件(类似于 modernCompletions 的处理)
|
||||
if (onChunk) {
|
||||
const usage = {
|
||||
prompt_tokens: 0,
|
||||
completion_tokens: 0,
|
||||
total_tokens: 0
|
||||
prompt_tokens: prompt.length, // 估算的 token 数量
|
||||
completion_tokens: 0, // 图像生成没有 completion tokens
|
||||
total_tokens: prompt.length
|
||||
}
|
||||
|
||||
onChunk({
|
||||
type: ChunkType.BLOCK_COMPLETE,
|
||||
response: {
|
||||
usage,
|
||||
metrics: {
|
||||
completion_tokens: usage.completion_tokens,
|
||||
time_first_token_millsec: 0,
|
||||
time_completion_millsec: Date.now() - startTime
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// 发送 LLM 响应完成事件
|
||||
onChunk({
|
||||
type: ChunkType.LLM_RESPONSE_COMPLETE,
|
||||
response: {
|
||||
|
||||
@ -10,14 +10,19 @@ import type { Chunk } from '@renderer/types/chunk'
|
||||
* AI SDK 中间件配置项
|
||||
*/
|
||||
export interface AiSdkMiddlewareConfig {
|
||||
streamOutput?: boolean
|
||||
streamOutput: boolean
|
||||
onChunk?: (chunk: Chunk) => void
|
||||
model?: Model
|
||||
provider?: Provider
|
||||
enableReasoning?: boolean
|
||||
enableReasoning: boolean
|
||||
// 是否开启提示词工具调用
|
||||
enableTool?: boolean
|
||||
enableWebSearch?: boolean
|
||||
isPromptToolUse: boolean
|
||||
// 是否支持工具调用
|
||||
isSupportedToolUse: boolean
|
||||
// image generation endpoint
|
||||
isImageGenerationEndpoint: boolean
|
||||
enableWebSearch: boolean
|
||||
enableGenerateImage: boolean
|
||||
mcpTools?: BaseTool[]
|
||||
// TODO assistant
|
||||
assistant: Assistant
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import { ProviderId } from '@cherrystudio/ai-core/types'
|
||||
import { isOpenAILLMModel } from '@renderer/config/models'
|
||||
import { isOpenAIModel } from '@renderer/config/models'
|
||||
import { Model, Provider } from '@renderer/types'
|
||||
|
||||
export function getAiSdkProviderIdForAihubmix(model: Model): ProviderId | 'openai-compatible' {
|
||||
@ -16,7 +16,7 @@ export function getAiSdkProviderIdForAihubmix(model: Model): ProviderId | 'opena
|
||||
return 'google'
|
||||
}
|
||||
|
||||
if (isOpenAILLMModel(model)) {
|
||||
if (isOpenAIModel(model)) {
|
||||
return 'openai'
|
||||
}
|
||||
|
||||
@ -43,7 +43,7 @@ export function createAihubmixProvider(model: Model, provider: Provider): Provid
|
||||
if (providerId === 'openai') {
|
||||
return {
|
||||
...provider,
|
||||
type: 'openai'
|
||||
type: 'openai-response'
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -255,8 +255,6 @@ export async function buildStreamTextParams(
|
||||
provider: Provider,
|
||||
options: {
|
||||
mcpTools?: MCPTool[]
|
||||
enableTools?: boolean
|
||||
enableWebSearch?: boolean
|
||||
webSearchProviderId?: string
|
||||
requestOptions?: {
|
||||
signal?: AbortSignal
|
||||
@ -267,9 +265,9 @@ export async function buildStreamTextParams(
|
||||
): Promise<{
|
||||
params: StreamTextParams
|
||||
modelId: string
|
||||
capabilities: { enableReasoning?: boolean; enableWebSearch?: boolean; enableGenerateImage?: boolean }
|
||||
capabilities: { enableReasoning: boolean; enableWebSearch: boolean; enableGenerateImage: boolean }
|
||||
}> {
|
||||
const { mcpTools, enableTools, webSearchProviderId } = options
|
||||
const { mcpTools } = options
|
||||
|
||||
const model = assistant.model || getDefaultModel()
|
||||
|
||||
@ -291,12 +289,7 @@ export async function buildStreamTextParams(
|
||||
isGenerateImageModel(model) &&
|
||||
(isSupportedDisableGenerationModel(model) ? assistant.enableGenerateImage || false : true)
|
||||
|
||||
// 构建系统提示
|
||||
const { tools } = setupToolsConfig({
|
||||
mcpTools,
|
||||
model,
|
||||
enableToolUse: enableTools
|
||||
})
|
||||
const tools = setupToolsConfig(mcpTools)
|
||||
|
||||
// if (webSearchProviderId) {
|
||||
// tools['builtin_web_search'] = webSearchTool(webSearchProviderId)
|
||||
|
||||
5
src/renderer/src/aiCore/utils/image.ts
Normal file
5
src/renderer/src/aiCore/utils/image.ts
Normal file
@ -0,0 +1,5 @@
|
||||
/**
 * Builds the extra provider options used when image generation is enabled
 * for a Gemini model.
 *
 * @returns An object whose `responseModalities` lists both `'TEXT'` and
 *          `'IMAGE'`. A fresh object is returned on every call, so callers
 *          may spread or mutate it freely.
 */
export function buildGeminiGenerateImageParams(): Record<string, any> {
  return {
    responseModalities: ['TEXT', 'IMAGE']
  }
}
|
||||
@ -1,36 +1,21 @@
|
||||
import { aiSdk, Tool } from '@cherrystudio/ai-core'
|
||||
// import { AiSdkTool, ToolCallResult } from '@renderer/aiCore/tools/types'
|
||||
import { SYSTEM_PROMPT_THRESHOLD } from '@renderer/config/constant'
|
||||
import { isFunctionCallingModel } from '@renderer/config/models'
|
||||
import { MCPTool, MCPToolResponse, Model } from '@renderer/types'
|
||||
import { MCPTool, MCPToolResponse } from '@renderer/types'
|
||||
import { callMCPTool } from '@renderer/utils/mcp-tools'
|
||||
import { tool } from 'ai'
|
||||
import { JSONSchema7 } from 'json-schema'
|
||||
|
||||
/**
 * Converts the given MCP tools into the AI SDK tool map.
 *
 * @param mcpTools - MCP tools to expose to the model; may be omitted.
 * @returns The converted tools keyed by name, or `undefined` when no MCP
 *          tools were supplied (missing or empty list).
 */
export function setupToolsConfig(mcpTools?: MCPTool[]): Record<string, Tool> | undefined {
  // Nothing to convert: signal "no tools" with undefined rather than an empty map.
  if (!mcpTools?.length) {
    return undefined
  }

  return convertMcpToolsToAiSdkTools(mcpTools)
}
|
||||
|
||||
/**
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import { Assistant, Model, Provider } from '@renderer/types'
|
||||
|
||||
import { getAiSdkProviderId } from '../provider/factory'
|
||||
import { buildGeminiGenerateImageParams } from './image'
|
||||
import {
|
||||
getAnthropicReasoningParams,
|
||||
getCustomParameters,
|
||||
@ -140,7 +141,7 @@ function buildGeminiProviderOptions(
|
||||
enableGenerateImage: boolean
|
||||
}
|
||||
): Record<string, any> {
|
||||
const { enableReasoning } = capabilities
|
||||
const { enableReasoning, enableGenerateImage } = capabilities
|
||||
let providerOptions: Record<string, any> = {}
|
||||
|
||||
// Gemini 推理参数
|
||||
@ -152,6 +153,13 @@ function buildGeminiProviderOptions(
|
||||
}
|
||||
}
|
||||
|
||||
if (enableGenerateImage) {
|
||||
providerOptions = {
|
||||
...providerOptions,
|
||||
...buildGeminiGenerateImageParams()
|
||||
}
|
||||
}
|
||||
|
||||
return providerOptions
|
||||
}
|
||||
|
||||
|
||||
@ -204,9 +204,10 @@ export const VISION_REGEX = new RegExp(
|
||||
)
|
||||
|
||||
/**
 * Patterns identifying models that must use the dedicated Image API.
 *
 * NOTE: entries are regular-expression fragments, not plain ids; they are
 * joined with `|` into a single case-insensitive pattern below.
 */
export const DEDICATED_IMAGE_MODELS = ['grok-2-image', 'dall-e-3', 'dall-e-2', 'gpt-image-1', 'imagen(?:-[\\w-]+)?']

// Compiled once at module load so the check does not rebuild the pattern per call.
const DEDICATED_IMAGE_MODELS_REGEX = new RegExp(DEDICATED_IMAGE_MODELS.join('|'), 'i')

/**
 * Tells whether a model has to be served through the dedicated Image API.
 *
 * @param model - Model whose `id` is inspected.
 * @returns `true` when any pattern in `DEDICATED_IMAGE_MODELS` matches
 *          anywhere inside `model.id`, ignoring case.
 */
export const isDedicatedImageGenerationModel = (model: Model): boolean => DEDICATED_IMAGE_MODELS_REGEX.test(model.id)
|
||||
|
||||
// Text to image models
|
||||
export const TEXT_TO_IMAGE_REGEX = /flux|diffusion|stabilityai|sd-|dall|cogview|janus|midjourney|mj-|image|gpt-image/i
|
||||
@ -272,7 +273,7 @@ export function isFunctionCallingModel(model?: Model): boolean {
|
||||
return true
|
||||
}
|
||||
|
||||
if (isEmbeddingModel(model)) {
|
||||
if (isEmbeddingModel(model) || isGenerateImageModel(model)) {
|
||||
return false
|
||||
}
|
||||
|
||||
@ -2775,6 +2776,16 @@ export function isGenerateImageModel(model: Model): boolean {
|
||||
return false
|
||||
}
|
||||
|
||||
/**
 * Tells whether the given model is one for which an explicit image size
 * is treated as unsupported.
 *
 * @param model - Model to inspect; may be omitted.
 * @returns `false` when no model is given; otherwise `true` when the base
 *          name produced by `getLowerBaseModelName(model.id, '/')` contains
 *          `'grok-2-image'`.
 */
export function isNotSupportedImageSizeModel(model?: Model): boolean {
  if (!model) {
    return false
  }

  // NOTE(review): presumably strips any provider prefix before '/' and lower-cases the id — confirm against the helper.
  const baseName = getLowerBaseModelName(model.id, '/')

  return baseName.includes('grok-2-image')
}
|
||||
|
||||
export function isSupportedDisableGenerationModel(model: Model): boolean {
|
||||
if (!model) {
|
||||
return false
|
||||
|
||||
@ -37,6 +37,7 @@ import type { MessageInputBaseParams } from '@renderer/types/newMessage'
|
||||
import { classNames, delay, formatFileSize, getFileExtension } from '@renderer/utils'
|
||||
import { formatQuotedText } from '@renderer/utils/formats'
|
||||
import { getFilesFromDropEvent, getSendMessageShortcutLabel, isSendMessageKeyPressed } from '@renderer/utils/input'
|
||||
import { isSupportedToolUse } from '@renderer/utils/mcp-tools'
|
||||
import { getLanguageByLangcode } from '@renderer/utils/translate'
|
||||
import { documentExts, imageExts, textExts } from '@shared/config/constant'
|
||||
import { IpcChannel } from '@shared/IpcChannel'
|
||||
@ -801,6 +802,7 @@ const Inputbar: FC<Props> = ({ assistant: _assistant, setActiveTopic, topic }) =
|
||||
|
||||
const isExpended = expended || !!textareaHeight
|
||||
const showThinkingButton = isSupportedThinkingTokenModel(model) || isSupportedReasoningEffortModel(model)
|
||||
const showMcpTools = isSupportedToolUse(assistant)
|
||||
|
||||
if (isMultiSelectMode) {
|
||||
return null
|
||||
@ -879,6 +881,7 @@ const Inputbar: FC<Props> = ({ assistant: _assistant, setActiveTopic, topic }) =
|
||||
setFiles={setFiles}
|
||||
showThinkingButton={showThinkingButton}
|
||||
showKnowledgeIcon={showKnowledgeIcon}
|
||||
showMcpTools={showMcpTools}
|
||||
selectedKnowledgeBases={selectedKnowledgeBases}
|
||||
handleKnowledgeBaseSelect={handleKnowledgeBaseSelect}
|
||||
setText={setText}
|
||||
|
||||
@ -61,6 +61,7 @@ export interface InputbarToolsProps {
|
||||
extensions: string[]
|
||||
showThinkingButton: boolean
|
||||
showKnowledgeIcon: boolean
|
||||
showMcpTools: boolean
|
||||
selectedKnowledgeBases: KnowledgeBase[]
|
||||
handleKnowledgeBaseSelect: (bases?: KnowledgeBase[]) => void
|
||||
setText: Dispatch<SetStateAction<string>>
|
||||
@ -102,6 +103,7 @@ const InputbarTools = ({
|
||||
setFiles,
|
||||
showThinkingButton,
|
||||
showKnowledgeIcon,
|
||||
showMcpTools,
|
||||
selectedKnowledgeBases,
|
||||
handleKnowledgeBaseSelect,
|
||||
setText,
|
||||
@ -371,7 +373,8 @@ const InputbarTools = ({
|
||||
setInputValue={setText}
|
||||
resizeTextArea={resizeTextArea}
|
||||
/>
|
||||
)
|
||||
),
|
||||
condition: showMcpTools
|
||||
},
|
||||
{
|
||||
key: 'generate_image',
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import { PlusOutlined, RedoOutlined } from '@ant-design/icons'
|
||||
import AiProvider from '@renderer/aiCore'
|
||||
import AiProviderNew from '@renderer/aiCore/index_new'
|
||||
import IcImageUp from '@renderer/assets/images/paintings/ic_ImageUp.svg'
|
||||
import { Navbar, NavbarCenter, NavbarRight } from '@renderer/components/app/Navbar'
|
||||
import { HStack } from '@renderer/components/Layout'
|
||||
@ -179,12 +179,17 @@ const AihubmixPage: FC<{ Options: string[] }> = ({ Options }) => {
|
||||
try {
|
||||
if (mode === 'generate') {
|
||||
if (painting.model.startsWith('imagen-')) {
|
||||
const AI = new AiProvider(aihubmixProvider)
|
||||
const AI = new AiProviderNew({
|
||||
id: painting.model,
|
||||
provider: 'aihubmix',
|
||||
name: painting.model,
|
||||
group: 'imagen'
|
||||
})
|
||||
const base64s = await AI.generateImage({
|
||||
prompt,
|
||||
model: painting.model,
|
||||
imageSize: painting.aspectRatio?.replace('ASPECT_', '').replace('_', ':') || '1:1',
|
||||
batchSize: painting.model.startsWith('imagen-4.0-ultra-generate-exp') ? 1 : painting.numberOfImages || 1,
|
||||
batchSize: painting.model.startsWith('imagen-4.0-ultra') ? 1 : painting.numberOfImages || 1,
|
||||
personGeneration: painting.personGeneration
|
||||
})
|
||||
if (base64s?.length > 0) {
|
||||
|
||||
@ -72,8 +72,8 @@ export const createModeConfigs = (): Record<AihubmixMode, ConfigItem[]> => {
|
||||
label: 'Gemini',
|
||||
title: 'Gemini',
|
||||
options: [
|
||||
{ label: 'imagen-4.0-preview', value: 'imagen-4.0-generate-preview-05-20' },
|
||||
{ label: 'imagen-4.0-ultra-exp', value: 'imagen-4.0-ultra-generate-exp-05-20' },
|
||||
{ label: 'imagen-4.0-preview', value: 'imagen-4.0-generate-preview-06-06' },
|
||||
{ label: 'imagen-4.0-ultra-preview', value: 'imagen-4.0-ultra-generate-preview-06-06' },
|
||||
{ label: 'imagen-3.0', value: 'imagen-3.0-generate-001' }
|
||||
]
|
||||
},
|
||||
|
||||
@ -7,6 +7,7 @@ import { AiSdkMiddlewareConfig } from '@renderer/aiCore/middleware/aisdk/AiSdkMi
|
||||
import { CompletionsParams } from '@renderer/aiCore/middleware/schemas'
|
||||
import { buildStreamTextParams } from '@renderer/aiCore/transformParameters'
|
||||
import {
|
||||
isDedicatedImageGenerationModel,
|
||||
isEmbeddingModel,
|
||||
isReasoningModel,
|
||||
isSupportedReasoningEffortModel,
|
||||
@ -20,7 +21,7 @@ import { type Chunk, ChunkType } from '@renderer/types/chunk'
|
||||
import { Message } from '@renderer/types/newMessage'
|
||||
import { SdkModel } from '@renderer/types/sdk'
|
||||
import { removeSpecialCharactersForTopicName } from '@renderer/utils'
|
||||
import { isEnabledToolUse } from '@renderer/utils/mcp-tools'
|
||||
import { isPromptToolUse, isSupportedToolUse } from '@renderer/utils/mcp-tools'
|
||||
import { findFileBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
|
||||
import { isEmpty, takeRight } from 'lodash'
|
||||
|
||||
@ -391,7 +392,11 @@ export async function fetchChatCompletion({
|
||||
const AI = new AiProviderNew(assistant.model || getDefaultModel())
|
||||
const provider = AI.getActualProvider()
|
||||
|
||||
const mcpTools = await fetchMcpTools(assistant)
|
||||
const mcpTools: MCPTool[] = []
|
||||
|
||||
if (isSupportedToolUse(assistant)) {
|
||||
mcpTools.push(...(await fetchMcpTools(assistant)))
|
||||
}
|
||||
|
||||
// 使用 transformParameters 模块构建参数
|
||||
const {
|
||||
@ -400,7 +405,6 @@ export async function fetchChatCompletion({
|
||||
capabilities
|
||||
} = await buildStreamTextParams(messages, assistant, provider, {
|
||||
mcpTools: mcpTools,
|
||||
enableTools: isEnabledToolUse(assistant),
|
||||
webSearchProviderId: assistant.webSearchProviderId,
|
||||
requestOptions: options
|
||||
})
|
||||
@ -430,8 +434,11 @@ export async function fetchChatCompletion({
|
||||
model: assistant.model,
|
||||
provider: provider,
|
||||
enableReasoning: capabilities.enableReasoning,
|
||||
enableTool: assistant.settings?.toolUseMode === 'prompt',
|
||||
isPromptToolUse: isPromptToolUse(assistant),
|
||||
isSupportedToolUse: isSupportedToolUse(assistant),
|
||||
isImageGenerationEndpoint: isDedicatedImageGenerationModel(assistant.model || getDefaultModel()),
|
||||
enableWebSearch: capabilities.enableWebSearch,
|
||||
enableGenerateImage: capabilities.enableGenerateImage,
|
||||
mcpTools,
|
||||
assistant
|
||||
}
|
||||
|
||||
@ -44,7 +44,7 @@ export const createImageCallbacks = (deps: ImageCallbacksDependencies) => {
|
||||
}
|
||||
},
|
||||
|
||||
onImageGenerated: (imageData: any) => {
|
||||
onImageGenerated: async (imageData: any) => {
|
||||
if (imageBlockId) {
|
||||
if (!imageData) {
|
||||
const changes: Partial<ImageMessageBlock> = {
|
||||
@ -62,7 +62,16 @@ export const createImageCallbacks = (deps: ImageCallbacksDependencies) => {
|
||||
}
|
||||
imageBlockId = null
|
||||
} else {
|
||||
console.error('[onImageGenerated] Last block was not an Image block or ID is missing.')
|
||||
if (imageData) {
|
||||
const imageBlock = createImageBlock(assistantMsgId, {
|
||||
status: MessageBlockStatus.SUCCESS,
|
||||
url: imageData.images?.[0] || 'placeholder_image_url',
|
||||
metadata: { generateImageResponse: imageData }
|
||||
})
|
||||
await blockManager.handleBlockTransition(imageBlock, MessageBlockType.IMAGE)
|
||||
} else {
|
||||
console.error('[onImageGenerated] Last block was not an Image block or ID is missing.')
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1006,12 +1006,30 @@ export function mcpToolCallResponseToGeminiMessage(
|
||||
return message
|
||||
}
|
||||
|
||||
/**
 * Whether tool use is available for the assistant.
 * 1. If the assistant's model supports function calling, tool use is available.
 * 2. Otherwise it is available only when the tool-use mode is `'prompt'`.
 * An assistant without a model never supports tool use.
 *
 * @param assistant - Assistant whose model and settings are inspected.
 * @returns Whether tool use is available.
 */
export function isSupportedToolUse(assistant: Assistant) {
  if (assistant.model) {
    if (isFunctionCallingModel(assistant.model)) {
      return true
    } else {
      // Models without native function calling can still use prompt-based tools.
      return assistant.settings?.toolUseMode === 'prompt'
    }
  }

  return false
}
|
||||
|
||||
/**
 * Whether the assistant is configured for prompt-based tool use.
 *
 * @param assistant - Assistant whose settings are inspected.
 * @returns `true` only when `assistant.settings.toolUseMode` is `'prompt'`;
 *          `false` when settings are missing or another mode is set.
 */
export function isPromptToolUse(assistant: Assistant) {
  return assistant.settings?.toolUseMode === 'prompt'
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user