mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-11 08:19:01 +08:00
refactor: enhance image generation handling and tool integration
- Updated image generation logic to support new model types and improved size handling. - Refactored middleware configuration to better manage tool usage and reasoning capabilities. - Introduced new utility functions for checking model compatibility with image generation. - Enhanced the integration of plugins for improved functionality during image generation processes. - Removed deprecated knowledge search tool to streamline the codebase.
This commit is contained in:
parent
ecc08bd3f7
commit
71959f577d
@ -253,16 +253,16 @@ export class AiSdkToChunkAdapter {
|
|||||||
// })
|
// })
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
// case 'file':
|
case 'file':
|
||||||
// // 文件相关事件,可能是图片生成
|
// 文件相关事件,可能是图片生成
|
||||||
// this.onChunk({
|
this.onChunk({
|
||||||
// type: ChunkType.IMAGE_COMPLETE,
|
type: ChunkType.IMAGE_COMPLETE,
|
||||||
// image: {
|
image: {
|
||||||
// type: 'base64',
|
type: 'base64',
|
||||||
// images: [chunk.base64]
|
images: [`data:${chunk.file.mediaType};base64,${chunk.file.base64}`]
|
||||||
// }
|
}
|
||||||
// })
|
})
|
||||||
// break
|
break
|
||||||
case 'error':
|
case 'error':
|
||||||
this.onChunk({
|
this.onChunk({
|
||||||
type: ChunkType.ERROR,
|
type: ChunkType.ERROR,
|
||||||
|
|||||||
@ -64,7 +64,7 @@ import {
|
|||||||
import { addImageFileToContents } from '@renderer/utils/formats'
|
import { addImageFileToContents } from '@renderer/utils/formats'
|
||||||
import {
|
import {
|
||||||
anthropicToolUseToMcpTool,
|
anthropicToolUseToMcpTool,
|
||||||
isEnabledToolUse,
|
isSupportedToolUse,
|
||||||
mcpToolCallResponseToAnthropicMessage,
|
mcpToolCallResponseToAnthropicMessage,
|
||||||
mcpToolsToAnthropicTools
|
mcpToolsToAnthropicTools
|
||||||
} from '@renderer/utils/mcp-tools'
|
} from '@renderer/utils/mcp-tools'
|
||||||
@ -454,7 +454,7 @@ export class AnthropicAPIClient extends BaseApiClient<
|
|||||||
const { tools } = this.setupToolsConfig({
|
const { tools } = this.setupToolsConfig({
|
||||||
mcpTools: mcpTools,
|
mcpTools: mcpTools,
|
||||||
model,
|
model,
|
||||||
enableToolUse: isEnabledToolUse(assistant)
|
enableToolUse: isSupportedToolUse(assistant)
|
||||||
})
|
})
|
||||||
|
|
||||||
if (this.useSystemPromptForTools) {
|
if (this.useSystemPromptForTools) {
|
||||||
|
|||||||
@ -53,7 +53,7 @@ import {
|
|||||||
} from '@renderer/types/sdk'
|
} from '@renderer/types/sdk'
|
||||||
import {
|
import {
|
||||||
geminiFunctionCallToMcpTool,
|
geminiFunctionCallToMcpTool,
|
||||||
isEnabledToolUse,
|
isSupportedToolUse,
|
||||||
mcpToolCallResponseToGeminiMessage,
|
mcpToolCallResponseToGeminiMessage,
|
||||||
mcpToolsToGeminiTools
|
mcpToolsToGeminiTools
|
||||||
} from '@renderer/utils/mcp-tools'
|
} from '@renderer/utils/mcp-tools'
|
||||||
@ -451,7 +451,7 @@ export class GeminiAPIClient extends BaseApiClient<
|
|||||||
const { tools } = this.setupToolsConfig({
|
const { tools } = this.setupToolsConfig({
|
||||||
mcpTools,
|
mcpTools,
|
||||||
model,
|
model,
|
||||||
enableToolUse: isEnabledToolUse(assistant)
|
enableToolUse: isSupportedToolUse(assistant)
|
||||||
})
|
})
|
||||||
|
|
||||||
if (this.useSystemPromptForTools) {
|
if (this.useSystemPromptForTools) {
|
||||||
|
|||||||
@ -44,7 +44,7 @@ import {
|
|||||||
} from '@renderer/types/sdk'
|
} from '@renderer/types/sdk'
|
||||||
import { addImageFileToContents } from '@renderer/utils/formats'
|
import { addImageFileToContents } from '@renderer/utils/formats'
|
||||||
import {
|
import {
|
||||||
isEnabledToolUse,
|
isSupportedToolUse,
|
||||||
mcpToolCallResponseToOpenAICompatibleMessage,
|
mcpToolCallResponseToOpenAICompatibleMessage,
|
||||||
mcpToolsToOpenAIChatTools,
|
mcpToolsToOpenAIChatTools,
|
||||||
openAIToolsToMcpTool
|
openAIToolsToMcpTool
|
||||||
@ -479,7 +479,7 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
|
|||||||
const { tools } = this.setupToolsConfig({
|
const { tools } = this.setupToolsConfig({
|
||||||
mcpTools: mcpTools,
|
mcpTools: mcpTools,
|
||||||
model,
|
model,
|
||||||
enableToolUse: isEnabledToolUse(assistant)
|
enableToolUse: isSupportedToolUse(assistant)
|
||||||
})
|
})
|
||||||
|
|
||||||
if (this.useSystemPromptForTools) {
|
if (this.useSystemPromptForTools) {
|
||||||
|
|||||||
@ -30,7 +30,7 @@ import {
|
|||||||
} from '@renderer/types/sdk'
|
} from '@renderer/types/sdk'
|
||||||
import { addImageFileToContents } from '@renderer/utils/formats'
|
import { addImageFileToContents } from '@renderer/utils/formats'
|
||||||
import {
|
import {
|
||||||
isEnabledToolUse,
|
isSupportedToolUse,
|
||||||
mcpToolCallResponseToOpenAIMessage,
|
mcpToolCallResponseToOpenAIMessage,
|
||||||
mcpToolsToOpenAIResponseTools,
|
mcpToolsToOpenAIResponseTools,
|
||||||
openAIToolsToMcpTool
|
openAIToolsToMcpTool
|
||||||
@ -362,7 +362,7 @@ export class OpenAIResponseAPIClient extends OpenAIBaseClient<
|
|||||||
const { tools: extraTools } = this.setupToolsConfig({
|
const { tools: extraTools } = this.setupToolsConfig({
|
||||||
mcpTools: mcpTools,
|
mcpTools: mcpTools,
|
||||||
model,
|
model,
|
||||||
enableToolUse: isEnabledToolUse(assistant)
|
enableToolUse: isSupportedToolUse(assistant)
|
||||||
})
|
})
|
||||||
|
|
||||||
if (this.useSystemPromptForTools) {
|
if (this.useSystemPromptForTools) {
|
||||||
|
|||||||
@ -1,9 +1,9 @@
|
|||||||
import { ApiClientFactory } from '@renderer/aiCore/clients/ApiClientFactory'
|
import { ApiClientFactory } from '@renderer/aiCore/clients/ApiClientFactory'
|
||||||
import { BaseApiClient } from '@renderer/aiCore/clients/BaseApiClient'
|
import { BaseApiClient } from '@renderer/aiCore/clients/BaseApiClient'
|
||||||
import { isDedicatedImageGenerationModel, isFunctionCallingModel } from '@renderer/config/models'
|
import { isDedicatedImageGenerationModel } from '@renderer/config/models'
|
||||||
import type { GenerateImageParams, Model, Provider } from '@renderer/types'
|
import type { GenerateImageParams, Model, Provider } from '@renderer/types'
|
||||||
import { RequestOptions, SdkModel } from '@renderer/types/sdk'
|
import { RequestOptions, SdkModel } from '@renderer/types/sdk'
|
||||||
import { isEnabledToolUse } from '@renderer/utils/mcp-tools'
|
import { isPromptToolUse } from '@renderer/utils/mcp-tools'
|
||||||
|
|
||||||
import { OpenAIAPIClient } from './clients'
|
import { OpenAIAPIClient } from './clients'
|
||||||
import { AihubmixAPIClient } from './clients/AihubmixAPIClient'
|
import { AihubmixAPIClient } from './clients/AihubmixAPIClient'
|
||||||
@ -93,7 +93,7 @@ export default class AiProvider {
|
|||||||
builder.remove(ToolUseExtractionMiddlewareName)
|
builder.remove(ToolUseExtractionMiddlewareName)
|
||||||
builder.remove(McpToolChunkMiddlewareName)
|
builder.remove(McpToolChunkMiddlewareName)
|
||||||
}
|
}
|
||||||
if (isEnabledToolUse(params.assistant) && isFunctionCallingModel(model)) {
|
if (!isPromptToolUse(params.assistant)) {
|
||||||
builder.remove(ToolUseExtractionMiddlewareName)
|
builder.remove(ToolUseExtractionMiddlewareName)
|
||||||
}
|
}
|
||||||
if (params.callType !== 'chat') {
|
if (params.callType !== 'chat') {
|
||||||
|
|||||||
@ -19,7 +19,7 @@ import {
|
|||||||
StreamTextParams
|
StreamTextParams
|
||||||
} from '@cherrystudio/ai-core'
|
} from '@cherrystudio/ai-core'
|
||||||
import { createPromptToolUsePlugin, webSearchPlugin } from '@cherrystudio/ai-core/built-in/plugins'
|
import { createPromptToolUsePlugin, webSearchPlugin } from '@cherrystudio/ai-core/built-in/plugins'
|
||||||
import { isDedicatedImageGenerationModel } from '@renderer/config/models'
|
import { isDedicatedImageGenerationModel, isNotSupportedImageSizeModel } from '@renderer/config/models'
|
||||||
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
|
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
|
||||||
import { getProviderByModel } from '@renderer/services/AssistantService'
|
import { getProviderByModel } from '@renderer/services/AssistantService'
|
||||||
import type { GenerateImageParams, Model, Provider } from '@renderer/types'
|
import type { GenerateImageParams, Model, Provider } from '@renderer/types'
|
||||||
@ -69,10 +69,10 @@ function providerToAiSdkConfig(actualProvider: Provider): {
|
|||||||
const aiSdkProviderId = getAiSdkProviderId(actualProvider)
|
const aiSdkProviderId = getAiSdkProviderId(actualProvider)
|
||||||
// console.log('aiSdkProviderId', aiSdkProviderId)
|
// console.log('aiSdkProviderId', aiSdkProviderId)
|
||||||
// 如果provider是openai,则使用strict模式并且默认responses api
|
// 如果provider是openai,则使用strict模式并且默认responses api
|
||||||
const actualProviderId = actualProvider.id
|
const actualProviderId = actualProvider.type
|
||||||
const openaiResponseOptions =
|
const openaiResponseOptions =
|
||||||
// 对于实际是openai的需要走responses,aiCore内部会判断model是否可用responses
|
// 对于实际是openai的需要走responses,aiCore内部会判断model是否可用responses
|
||||||
actualProviderId === 'openai'
|
actualProviderId === 'openai-response'
|
||||||
? {
|
? {
|
||||||
mode: 'responses'
|
mode: 'responses'
|
||||||
}
|
}
|
||||||
@ -167,15 +167,18 @@ export default class ModernAiProvider {
|
|||||||
// 内置了默认搜索参数,如果改的话可以传config进去
|
// 内置了默认搜索参数,如果改的话可以传config进去
|
||||||
plugins.push(webSearchPlugin())
|
plugins.push(webSearchPlugin())
|
||||||
}
|
}
|
||||||
plugins.push(searchOrchestrationPlugin(middlewareConfig.assistant))
|
// 2. 支持工具调用时添加搜索插件
|
||||||
|
if (middlewareConfig.isSupportedToolUse) {
|
||||||
|
plugins.push(searchOrchestrationPlugin(middlewareConfig.assistant))
|
||||||
|
}
|
||||||
|
|
||||||
// 2. 推理模型时添加推理插件
|
// 3. 推理模型时添加推理插件
|
||||||
if (middlewareConfig.enableReasoning) {
|
if (middlewareConfig.enableReasoning) {
|
||||||
plugins.push(reasoningTimePlugin)
|
plugins.push(reasoningTimePlugin)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. 启用Prompt工具调用时添加工具插件
|
// 4. 启用Prompt工具调用时添加工具插件
|
||||||
if (middlewareConfig.enableTool && middlewareConfig.mcpTools && middlewareConfig.mcpTools.length > 0) {
|
if (middlewareConfig.isPromptToolUse && middlewareConfig.mcpTools && middlewareConfig.mcpTools.length > 0) {
|
||||||
plugins.push(
|
plugins.push(
|
||||||
createPromptToolUsePlugin({
|
createPromptToolUsePlugin({
|
||||||
enabled: true,
|
enabled: true,
|
||||||
@ -216,8 +219,7 @@ export default class ModernAiProvider {
|
|||||||
): Promise<CompletionsResult> {
|
): Promise<CompletionsResult> {
|
||||||
console.log('completions', modelId, params, middlewareConfig)
|
console.log('completions', modelId, params, middlewareConfig)
|
||||||
|
|
||||||
// 检查是否为图像生成模型
|
if (middlewareConfig.isImageGenerationEndpoint) {
|
||||||
if (middlewareConfig.model && isDedicatedImageGenerationModel(middlewareConfig.model)) {
|
|
||||||
return await this.modernImageGeneration(modelId, params, middlewareConfig)
|
return await this.modernImageGeneration(modelId, params, middlewareConfig)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -313,17 +315,17 @@ export default class ModernAiProvider {
|
|||||||
throw new Error('No prompt found in user message.')
|
throw new Error('No prompt found in user message.')
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const startTime = Date.now()
|
||||||
|
|
||||||
// 发送图像生成开始事件
|
// 发送图像生成开始事件
|
||||||
if (onChunk) {
|
if (onChunk) {
|
||||||
onChunk({ type: ChunkType.IMAGE_CREATED })
|
onChunk({ type: ChunkType.IMAGE_CREATED })
|
||||||
}
|
}
|
||||||
|
|
||||||
const startTime = Date.now()
|
|
||||||
|
|
||||||
// 构建图像生成参数
|
// 构建图像生成参数
|
||||||
const imageParams = {
|
const imageParams = {
|
||||||
prompt,
|
prompt,
|
||||||
size: '1024x1024' as `${number}x${number}`, // 默认尺寸,使用正确的类型
|
size: isNotSupportedImageSizeModel(middlewareConfig.model) ? undefined : ('1024x1024' as `${number}x${number}`), // 默认尺寸,使用正确的类型
|
||||||
n: 1,
|
n: 1,
|
||||||
...(params.abortSignal && { abortSignal: params.abortSignal })
|
...(params.abortSignal && { abortSignal: params.abortSignal })
|
||||||
}
|
}
|
||||||
@ -338,7 +340,7 @@ export default class ModernAiProvider {
|
|||||||
if (result.images) {
|
if (result.images) {
|
||||||
for (const image of result.images) {
|
for (const image of result.images) {
|
||||||
if ('base64' in image && image.base64) {
|
if ('base64' in image && image.base64) {
|
||||||
images.push(`data:image/png;base64,${image.base64}`)
|
images.push(`data:${image.mediaType};base64,${image.base64}`)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -351,14 +353,27 @@ export default class ModernAiProvider {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// 发送响应完成事件
|
// 发送块完成事件(类似于 modernCompletions 的处理)
|
||||||
if (onChunk) {
|
if (onChunk) {
|
||||||
const usage = {
|
const usage = {
|
||||||
prompt_tokens: 0,
|
prompt_tokens: prompt.length, // 估算的 token 数量
|
||||||
completion_tokens: 0,
|
completion_tokens: 0, // 图像生成没有 completion tokens
|
||||||
total_tokens: 0
|
total_tokens: prompt.length
|
||||||
}
|
}
|
||||||
|
|
||||||
|
onChunk({
|
||||||
|
type: ChunkType.BLOCK_COMPLETE,
|
||||||
|
response: {
|
||||||
|
usage,
|
||||||
|
metrics: {
|
||||||
|
completion_tokens: usage.completion_tokens,
|
||||||
|
time_first_token_millsec: 0,
|
||||||
|
time_completion_millsec: Date.now() - startTime
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// 发送 LLM 响应完成事件
|
||||||
onChunk({
|
onChunk({
|
||||||
type: ChunkType.LLM_RESPONSE_COMPLETE,
|
type: ChunkType.LLM_RESPONSE_COMPLETE,
|
||||||
response: {
|
response: {
|
||||||
|
|||||||
@ -10,14 +10,19 @@ import type { Chunk } from '@renderer/types/chunk'
|
|||||||
* AI SDK 中间件配置项
|
* AI SDK 中间件配置项
|
||||||
*/
|
*/
|
||||||
export interface AiSdkMiddlewareConfig {
|
export interface AiSdkMiddlewareConfig {
|
||||||
streamOutput?: boolean
|
streamOutput: boolean
|
||||||
onChunk?: (chunk: Chunk) => void
|
onChunk?: (chunk: Chunk) => void
|
||||||
model?: Model
|
model?: Model
|
||||||
provider?: Provider
|
provider?: Provider
|
||||||
enableReasoning?: boolean
|
enableReasoning: boolean
|
||||||
// 是否开启提示词工具调用
|
// 是否开启提示词工具调用
|
||||||
enableTool?: boolean
|
isPromptToolUse: boolean
|
||||||
enableWebSearch?: boolean
|
// 是否支持工具调用
|
||||||
|
isSupportedToolUse: boolean
|
||||||
|
// image generation endpoint
|
||||||
|
isImageGenerationEndpoint: boolean
|
||||||
|
enableWebSearch: boolean
|
||||||
|
enableGenerateImage: boolean
|
||||||
mcpTools?: BaseTool[]
|
mcpTools?: BaseTool[]
|
||||||
// TODO assistant
|
// TODO assistant
|
||||||
assistant: Assistant
|
assistant: Assistant
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
import { ProviderId } from '@cherrystudio/ai-core/types'
|
import { ProviderId } from '@cherrystudio/ai-core/types'
|
||||||
import { isOpenAILLMModel } from '@renderer/config/models'
|
import { isOpenAIModel } from '@renderer/config/models'
|
||||||
import { Model, Provider } from '@renderer/types'
|
import { Model, Provider } from '@renderer/types'
|
||||||
|
|
||||||
export function getAiSdkProviderIdForAihubmix(model: Model): ProviderId | 'openai-compatible' {
|
export function getAiSdkProviderIdForAihubmix(model: Model): ProviderId | 'openai-compatible' {
|
||||||
@ -16,7 +16,7 @@ export function getAiSdkProviderIdForAihubmix(model: Model): ProviderId | 'opena
|
|||||||
return 'google'
|
return 'google'
|
||||||
}
|
}
|
||||||
|
|
||||||
if (isOpenAILLMModel(model)) {
|
if (isOpenAIModel(model)) {
|
||||||
return 'openai'
|
return 'openai'
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -43,7 +43,7 @@ export function createAihubmixProvider(model: Model, provider: Provider): Provid
|
|||||||
if (providerId === 'openai') {
|
if (providerId === 'openai') {
|
||||||
return {
|
return {
|
||||||
...provider,
|
...provider,
|
||||||
type: 'openai'
|
type: 'openai-response'
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -255,8 +255,6 @@ export async function buildStreamTextParams(
|
|||||||
provider: Provider,
|
provider: Provider,
|
||||||
options: {
|
options: {
|
||||||
mcpTools?: MCPTool[]
|
mcpTools?: MCPTool[]
|
||||||
enableTools?: boolean
|
|
||||||
enableWebSearch?: boolean
|
|
||||||
webSearchProviderId?: string
|
webSearchProviderId?: string
|
||||||
requestOptions?: {
|
requestOptions?: {
|
||||||
signal?: AbortSignal
|
signal?: AbortSignal
|
||||||
@ -267,9 +265,9 @@ export async function buildStreamTextParams(
|
|||||||
): Promise<{
|
): Promise<{
|
||||||
params: StreamTextParams
|
params: StreamTextParams
|
||||||
modelId: string
|
modelId: string
|
||||||
capabilities: { enableReasoning?: boolean; enableWebSearch?: boolean; enableGenerateImage?: boolean }
|
capabilities: { enableReasoning: boolean; enableWebSearch: boolean; enableGenerateImage: boolean }
|
||||||
}> {
|
}> {
|
||||||
const { mcpTools, enableTools, webSearchProviderId } = options
|
const { mcpTools } = options
|
||||||
|
|
||||||
const model = assistant.model || getDefaultModel()
|
const model = assistant.model || getDefaultModel()
|
||||||
|
|
||||||
@ -291,12 +289,7 @@ export async function buildStreamTextParams(
|
|||||||
isGenerateImageModel(model) &&
|
isGenerateImageModel(model) &&
|
||||||
(isSupportedDisableGenerationModel(model) ? assistant.enableGenerateImage || false : true)
|
(isSupportedDisableGenerationModel(model) ? assistant.enableGenerateImage || false : true)
|
||||||
|
|
||||||
// 构建系统提示
|
const tools = setupToolsConfig(mcpTools)
|
||||||
const { tools } = setupToolsConfig({
|
|
||||||
mcpTools,
|
|
||||||
model,
|
|
||||||
enableToolUse: enableTools
|
|
||||||
})
|
|
||||||
|
|
||||||
// if (webSearchProviderId) {
|
// if (webSearchProviderId) {
|
||||||
// tools['builtin_web_search'] = webSearchTool(webSearchProviderId)
|
// tools['builtin_web_search'] = webSearchTool(webSearchProviderId)
|
||||||
|
|||||||
5
src/renderer/src/aiCore/utils/image.ts
Normal file
5
src/renderer/src/aiCore/utils/image.ts
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
export function buildGeminiGenerateImageParams(): Record<string, any> {
|
||||||
|
return {
|
||||||
|
responseModalities: ['TEXT', 'IMAGE']
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -1,36 +1,21 @@
|
|||||||
import { aiSdk, Tool } from '@cherrystudio/ai-core'
|
import { aiSdk, Tool } from '@cherrystudio/ai-core'
|
||||||
// import { AiSdkTool, ToolCallResult } from '@renderer/aiCore/tools/types'
|
// import { AiSdkTool, ToolCallResult } from '@renderer/aiCore/tools/types'
|
||||||
import { SYSTEM_PROMPT_THRESHOLD } from '@renderer/config/constant'
|
import { MCPTool, MCPToolResponse } from '@renderer/types'
|
||||||
import { isFunctionCallingModel } from '@renderer/config/models'
|
|
||||||
import { MCPTool, MCPToolResponse, Model } from '@renderer/types'
|
|
||||||
import { callMCPTool } from '@renderer/utils/mcp-tools'
|
import { callMCPTool } from '@renderer/utils/mcp-tools'
|
||||||
import { tool } from 'ai'
|
import { tool } from 'ai'
|
||||||
import { JSONSchema7 } from 'json-schema'
|
import { JSONSchema7 } from 'json-schema'
|
||||||
|
|
||||||
// Setup tools configuration based on provided parameters
|
// Setup tools configuration based on provided parameters
|
||||||
export function setupToolsConfig(params: { mcpTools?: MCPTool[]; model: Model; enableToolUse?: boolean }): {
|
export function setupToolsConfig(mcpTools?: MCPTool[]): Record<string, Tool> | undefined {
|
||||||
tools: Record<string, Tool>
|
|
||||||
useSystemPromptForTools?: boolean
|
|
||||||
} {
|
|
||||||
const { mcpTools, model, enableToolUse } = params
|
|
||||||
|
|
||||||
let tools: Record<string, Tool> = {}
|
let tools: Record<string, Tool> = {}
|
||||||
|
|
||||||
if (!mcpTools?.length) {
|
if (!mcpTools?.length) {
|
||||||
return { tools }
|
return undefined
|
||||||
}
|
}
|
||||||
|
|
||||||
tools = convertMcpToolsToAiSdkTools(mcpTools)
|
tools = convertMcpToolsToAiSdkTools(mcpTools)
|
||||||
|
|
||||||
if (mcpTools.length > SYSTEM_PROMPT_THRESHOLD) {
|
return tools
|
||||||
return { tools, useSystemPromptForTools: true }
|
|
||||||
}
|
|
||||||
|
|
||||||
if (isFunctionCallingModel(model) && enableToolUse) {
|
|
||||||
return { tools, useSystemPromptForTools: false }
|
|
||||||
}
|
|
||||||
|
|
||||||
return { tools, useSystemPromptForTools: true }
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
import { Assistant, Model, Provider } from '@renderer/types'
|
import { Assistant, Model, Provider } from '@renderer/types'
|
||||||
|
|
||||||
import { getAiSdkProviderId } from '../provider/factory'
|
import { getAiSdkProviderId } from '../provider/factory'
|
||||||
|
import { buildGeminiGenerateImageParams } from './image'
|
||||||
import {
|
import {
|
||||||
getAnthropicReasoningParams,
|
getAnthropicReasoningParams,
|
||||||
getCustomParameters,
|
getCustomParameters,
|
||||||
@ -140,7 +141,7 @@ function buildGeminiProviderOptions(
|
|||||||
enableGenerateImage: boolean
|
enableGenerateImage: boolean
|
||||||
}
|
}
|
||||||
): Record<string, any> {
|
): Record<string, any> {
|
||||||
const { enableReasoning } = capabilities
|
const { enableReasoning, enableGenerateImage } = capabilities
|
||||||
let providerOptions: Record<string, any> = {}
|
let providerOptions: Record<string, any> = {}
|
||||||
|
|
||||||
// Gemini 推理参数
|
// Gemini 推理参数
|
||||||
@ -152,6 +153,13 @@ function buildGeminiProviderOptions(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (enableGenerateImage) {
|
||||||
|
providerOptions = {
|
||||||
|
...providerOptions,
|
||||||
|
...buildGeminiGenerateImageParams()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return providerOptions
|
return providerOptions
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -204,9 +204,10 @@ export const VISION_REGEX = new RegExp(
|
|||||||
)
|
)
|
||||||
|
|
||||||
// For middleware to identify models that must use the dedicated Image API
|
// For middleware to identify models that must use the dedicated Image API
|
||||||
export const DEDICATED_IMAGE_MODELS = ['grok-2-image', 'dall-e-3', 'dall-e-2', 'gpt-image-1']
|
export const DEDICATED_IMAGE_MODELS = ['grok-2-image', 'dall-e-3', 'dall-e-2', 'gpt-image-1', 'imagen(?:-[\\w-]+)?']
|
||||||
export const isDedicatedImageGenerationModel = (model: Model): boolean =>
|
const DEDICATED_IMAGE_MODELS_REGEX = new RegExp(DEDICATED_IMAGE_MODELS.join('|'), 'i')
|
||||||
DEDICATED_IMAGE_MODELS.filter((m) => model.id.includes(m)).length > 0
|
|
||||||
|
export const isDedicatedImageGenerationModel = (model: Model): boolean => DEDICATED_IMAGE_MODELS_REGEX.test(model.id)
|
||||||
|
|
||||||
// Text to image models
|
// Text to image models
|
||||||
export const TEXT_TO_IMAGE_REGEX = /flux|diffusion|stabilityai|sd-|dall|cogview|janus|midjourney|mj-|image|gpt-image/i
|
export const TEXT_TO_IMAGE_REGEX = /flux|diffusion|stabilityai|sd-|dall|cogview|janus|midjourney|mj-|image|gpt-image/i
|
||||||
@ -272,7 +273,7 @@ export function isFunctionCallingModel(model?: Model): boolean {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
if (isEmbeddingModel(model)) {
|
if (isEmbeddingModel(model) || isGenerateImageModel(model)) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2775,6 +2776,16 @@ export function isGenerateImageModel(model: Model): boolean {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export function isNotSupportedImageSizeModel(model?: Model): boolean {
|
||||||
|
if (!model) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
const baseName = getLowerBaseModelName(model.id, '/')
|
||||||
|
|
||||||
|
return baseName.includes('grok-2-image')
|
||||||
|
}
|
||||||
|
|
||||||
export function isSupportedDisableGenerationModel(model: Model): boolean {
|
export function isSupportedDisableGenerationModel(model: Model): boolean {
|
||||||
if (!model) {
|
if (!model) {
|
||||||
return false
|
return false
|
||||||
|
|||||||
@ -37,6 +37,7 @@ import type { MessageInputBaseParams } from '@renderer/types/newMessage'
|
|||||||
import { classNames, delay, formatFileSize, getFileExtension } from '@renderer/utils'
|
import { classNames, delay, formatFileSize, getFileExtension } from '@renderer/utils'
|
||||||
import { formatQuotedText } from '@renderer/utils/formats'
|
import { formatQuotedText } from '@renderer/utils/formats'
|
||||||
import { getFilesFromDropEvent, getSendMessageShortcutLabel, isSendMessageKeyPressed } from '@renderer/utils/input'
|
import { getFilesFromDropEvent, getSendMessageShortcutLabel, isSendMessageKeyPressed } from '@renderer/utils/input'
|
||||||
|
import { isSupportedToolUse } from '@renderer/utils/mcp-tools'
|
||||||
import { getLanguageByLangcode } from '@renderer/utils/translate'
|
import { getLanguageByLangcode } from '@renderer/utils/translate'
|
||||||
import { documentExts, imageExts, textExts } from '@shared/config/constant'
|
import { documentExts, imageExts, textExts } from '@shared/config/constant'
|
||||||
import { IpcChannel } from '@shared/IpcChannel'
|
import { IpcChannel } from '@shared/IpcChannel'
|
||||||
@ -801,6 +802,7 @@ const Inputbar: FC<Props> = ({ assistant: _assistant, setActiveTopic, topic }) =
|
|||||||
|
|
||||||
const isExpended = expended || !!textareaHeight
|
const isExpended = expended || !!textareaHeight
|
||||||
const showThinkingButton = isSupportedThinkingTokenModel(model) || isSupportedReasoningEffortModel(model)
|
const showThinkingButton = isSupportedThinkingTokenModel(model) || isSupportedReasoningEffortModel(model)
|
||||||
|
const showMcpTools = isSupportedToolUse(assistant)
|
||||||
|
|
||||||
if (isMultiSelectMode) {
|
if (isMultiSelectMode) {
|
||||||
return null
|
return null
|
||||||
@ -879,6 +881,7 @@ const Inputbar: FC<Props> = ({ assistant: _assistant, setActiveTopic, topic }) =
|
|||||||
setFiles={setFiles}
|
setFiles={setFiles}
|
||||||
showThinkingButton={showThinkingButton}
|
showThinkingButton={showThinkingButton}
|
||||||
showKnowledgeIcon={showKnowledgeIcon}
|
showKnowledgeIcon={showKnowledgeIcon}
|
||||||
|
showMcpTools={showMcpTools}
|
||||||
selectedKnowledgeBases={selectedKnowledgeBases}
|
selectedKnowledgeBases={selectedKnowledgeBases}
|
||||||
handleKnowledgeBaseSelect={handleKnowledgeBaseSelect}
|
handleKnowledgeBaseSelect={handleKnowledgeBaseSelect}
|
||||||
setText={setText}
|
setText={setText}
|
||||||
|
|||||||
@ -61,6 +61,7 @@ export interface InputbarToolsProps {
|
|||||||
extensions: string[]
|
extensions: string[]
|
||||||
showThinkingButton: boolean
|
showThinkingButton: boolean
|
||||||
showKnowledgeIcon: boolean
|
showKnowledgeIcon: boolean
|
||||||
|
showMcpTools: boolean
|
||||||
selectedKnowledgeBases: KnowledgeBase[]
|
selectedKnowledgeBases: KnowledgeBase[]
|
||||||
handleKnowledgeBaseSelect: (bases?: KnowledgeBase[]) => void
|
handleKnowledgeBaseSelect: (bases?: KnowledgeBase[]) => void
|
||||||
setText: Dispatch<SetStateAction<string>>
|
setText: Dispatch<SetStateAction<string>>
|
||||||
@ -102,6 +103,7 @@ const InputbarTools = ({
|
|||||||
setFiles,
|
setFiles,
|
||||||
showThinkingButton,
|
showThinkingButton,
|
||||||
showKnowledgeIcon,
|
showKnowledgeIcon,
|
||||||
|
showMcpTools,
|
||||||
selectedKnowledgeBases,
|
selectedKnowledgeBases,
|
||||||
handleKnowledgeBaseSelect,
|
handleKnowledgeBaseSelect,
|
||||||
setText,
|
setText,
|
||||||
@ -371,7 +373,8 @@ const InputbarTools = ({
|
|||||||
setInputValue={setText}
|
setInputValue={setText}
|
||||||
resizeTextArea={resizeTextArea}
|
resizeTextArea={resizeTextArea}
|
||||||
/>
|
/>
|
||||||
)
|
),
|
||||||
|
condition: showMcpTools
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
key: 'generate_image',
|
key: 'generate_image',
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
import { PlusOutlined, RedoOutlined } from '@ant-design/icons'
|
import { PlusOutlined, RedoOutlined } from '@ant-design/icons'
|
||||||
import AiProvider from '@renderer/aiCore'
|
import AiProviderNew from '@renderer/aiCore/index_new'
|
||||||
import IcImageUp from '@renderer/assets/images/paintings/ic_ImageUp.svg'
|
import IcImageUp from '@renderer/assets/images/paintings/ic_ImageUp.svg'
|
||||||
import { Navbar, NavbarCenter, NavbarRight } from '@renderer/components/app/Navbar'
|
import { Navbar, NavbarCenter, NavbarRight } from '@renderer/components/app/Navbar'
|
||||||
import { HStack } from '@renderer/components/Layout'
|
import { HStack } from '@renderer/components/Layout'
|
||||||
@ -179,12 +179,17 @@ const AihubmixPage: FC<{ Options: string[] }> = ({ Options }) => {
|
|||||||
try {
|
try {
|
||||||
if (mode === 'generate') {
|
if (mode === 'generate') {
|
||||||
if (painting.model.startsWith('imagen-')) {
|
if (painting.model.startsWith('imagen-')) {
|
||||||
const AI = new AiProvider(aihubmixProvider)
|
const AI = new AiProviderNew({
|
||||||
|
id: painting.model,
|
||||||
|
provider: 'aihubmix',
|
||||||
|
name: painting.model,
|
||||||
|
group: 'imagen'
|
||||||
|
})
|
||||||
const base64s = await AI.generateImage({
|
const base64s = await AI.generateImage({
|
||||||
prompt,
|
prompt,
|
||||||
model: painting.model,
|
model: painting.model,
|
||||||
imageSize: painting.aspectRatio?.replace('ASPECT_', '').replace('_', ':') || '1:1',
|
imageSize: painting.aspectRatio?.replace('ASPECT_', '').replace('_', ':') || '1:1',
|
||||||
batchSize: painting.model.startsWith('imagen-4.0-ultra-generate-exp') ? 1 : painting.numberOfImages || 1,
|
batchSize: painting.model.startsWith('imagen-4.0-ultra') ? 1 : painting.numberOfImages || 1,
|
||||||
personGeneration: painting.personGeneration
|
personGeneration: painting.personGeneration
|
||||||
})
|
})
|
||||||
if (base64s?.length > 0) {
|
if (base64s?.length > 0) {
|
||||||
|
|||||||
@ -72,8 +72,8 @@ export const createModeConfigs = (): Record<AihubmixMode, ConfigItem[]> => {
|
|||||||
label: 'Gemini',
|
label: 'Gemini',
|
||||||
title: 'Gemini',
|
title: 'Gemini',
|
||||||
options: [
|
options: [
|
||||||
{ label: 'imagen-4.0-preview', value: 'imagen-4.0-generate-preview-05-20' },
|
{ label: 'imagen-4.0-preview', value: 'imagen-4.0-generate-preview-06-06' },
|
||||||
{ label: 'imagen-4.0-ultra-exp', value: 'imagen-4.0-ultra-generate-exp-05-20' },
|
{ label: 'imagen-4.0-ultra-preview', value: 'imagen-4.0-ultra-generate-preview-06-06' },
|
||||||
{ label: 'imagen-3.0', value: 'imagen-3.0-generate-001' }
|
{ label: 'imagen-3.0', value: 'imagen-3.0-generate-001' }
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
|||||||
@ -7,6 +7,7 @@ import { AiSdkMiddlewareConfig } from '@renderer/aiCore/middleware/aisdk/AiSdkMi
|
|||||||
import { CompletionsParams } from '@renderer/aiCore/middleware/schemas'
|
import { CompletionsParams } from '@renderer/aiCore/middleware/schemas'
|
||||||
import { buildStreamTextParams } from '@renderer/aiCore/transformParameters'
|
import { buildStreamTextParams } from '@renderer/aiCore/transformParameters'
|
||||||
import {
|
import {
|
||||||
|
isDedicatedImageGenerationModel,
|
||||||
isEmbeddingModel,
|
isEmbeddingModel,
|
||||||
isReasoningModel,
|
isReasoningModel,
|
||||||
isSupportedReasoningEffortModel,
|
isSupportedReasoningEffortModel,
|
||||||
@ -20,7 +21,7 @@ import { type Chunk, ChunkType } from '@renderer/types/chunk'
|
|||||||
import { Message } from '@renderer/types/newMessage'
|
import { Message } from '@renderer/types/newMessage'
|
||||||
import { SdkModel } from '@renderer/types/sdk'
|
import { SdkModel } from '@renderer/types/sdk'
|
||||||
import { removeSpecialCharactersForTopicName } from '@renderer/utils'
|
import { removeSpecialCharactersForTopicName } from '@renderer/utils'
|
||||||
import { isEnabledToolUse } from '@renderer/utils/mcp-tools'
|
import { isPromptToolUse, isSupportedToolUse } from '@renderer/utils/mcp-tools'
|
||||||
import { findFileBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
|
import { findFileBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
|
||||||
import { isEmpty, takeRight } from 'lodash'
|
import { isEmpty, takeRight } from 'lodash'
|
||||||
|
|
||||||
@ -391,7 +392,11 @@ export async function fetchChatCompletion({
|
|||||||
const AI = new AiProviderNew(assistant.model || getDefaultModel())
|
const AI = new AiProviderNew(assistant.model || getDefaultModel())
|
||||||
const provider = AI.getActualProvider()
|
const provider = AI.getActualProvider()
|
||||||
|
|
||||||
const mcpTools = await fetchMcpTools(assistant)
|
const mcpTools: MCPTool[] = []
|
||||||
|
|
||||||
|
if (isSupportedToolUse(assistant)) {
|
||||||
|
mcpTools.push(...(await fetchMcpTools(assistant)))
|
||||||
|
}
|
||||||
|
|
||||||
// 使用 transformParameters 模块构建参数
|
// 使用 transformParameters 模块构建参数
|
||||||
const {
|
const {
|
||||||
@ -400,7 +405,6 @@ export async function fetchChatCompletion({
|
|||||||
capabilities
|
capabilities
|
||||||
} = await buildStreamTextParams(messages, assistant, provider, {
|
} = await buildStreamTextParams(messages, assistant, provider, {
|
||||||
mcpTools: mcpTools,
|
mcpTools: mcpTools,
|
||||||
enableTools: isEnabledToolUse(assistant),
|
|
||||||
webSearchProviderId: assistant.webSearchProviderId,
|
webSearchProviderId: assistant.webSearchProviderId,
|
||||||
requestOptions: options
|
requestOptions: options
|
||||||
})
|
})
|
||||||
@ -430,8 +434,11 @@ export async function fetchChatCompletion({
|
|||||||
model: assistant.model,
|
model: assistant.model,
|
||||||
provider: provider,
|
provider: provider,
|
||||||
enableReasoning: capabilities.enableReasoning,
|
enableReasoning: capabilities.enableReasoning,
|
||||||
enableTool: assistant.settings?.toolUseMode === 'prompt',
|
isPromptToolUse: isPromptToolUse(assistant),
|
||||||
|
isSupportedToolUse: isSupportedToolUse(assistant),
|
||||||
|
isImageGenerationEndpoint: isDedicatedImageGenerationModel(assistant.model || getDefaultModel()),
|
||||||
enableWebSearch: capabilities.enableWebSearch,
|
enableWebSearch: capabilities.enableWebSearch,
|
||||||
|
enableGenerateImage: capabilities.enableGenerateImage,
|
||||||
mcpTools,
|
mcpTools,
|
||||||
assistant
|
assistant
|
||||||
}
|
}
|
||||||
|
|||||||
@ -44,7 +44,7 @@ export const createImageCallbacks = (deps: ImageCallbacksDependencies) => {
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|
||||||
onImageGenerated: (imageData: any) => {
|
onImageGenerated: async (imageData: any) => {
|
||||||
if (imageBlockId) {
|
if (imageBlockId) {
|
||||||
if (!imageData) {
|
if (!imageData) {
|
||||||
const changes: Partial<ImageMessageBlock> = {
|
const changes: Partial<ImageMessageBlock> = {
|
||||||
@ -62,7 +62,16 @@ export const createImageCallbacks = (deps: ImageCallbacksDependencies) => {
|
|||||||
}
|
}
|
||||||
imageBlockId = null
|
imageBlockId = null
|
||||||
} else {
|
} else {
|
||||||
console.error('[onImageGenerated] Last block was not an Image block or ID is missing.')
|
if (imageData) {
|
||||||
|
const imageBlock = createImageBlock(assistantMsgId, {
|
||||||
|
status: MessageBlockStatus.SUCCESS,
|
||||||
|
url: imageData.images?.[0] || 'placeholder_image_url',
|
||||||
|
metadata: { generateImageResponse: imageData }
|
||||||
|
})
|
||||||
|
await blockManager.handleBlockTransition(imageBlock, MessageBlockType.IMAGE)
|
||||||
|
} else {
|
||||||
|
console.error('[onImageGenerated] Last block was not an Image block or ID is missing.')
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1006,12 +1006,30 @@ export function mcpToolCallResponseToGeminiMessage(
|
|||||||
return message
|
return message
|
||||||
}
|
}
|
||||||
|
|
||||||
export function isEnabledToolUse(assistant: Assistant) {
|
/**
|
||||||
|
* 是否启用工具使用
|
||||||
|
* 1. 如果模型支持函数调用,则启用工具使用
|
||||||
|
* 2. 如果工具使用模式为 prompt,则启用工具使用
|
||||||
|
* @param assistant
|
||||||
|
* @returns 是否启用工具使用
|
||||||
|
*/
|
||||||
|
export function isSupportedToolUse(assistant: Assistant) {
|
||||||
if (assistant.model) {
|
if (assistant.model) {
|
||||||
if (isFunctionCallingModel(assistant.model)) {
|
if (isFunctionCallingModel(assistant.model)) {
|
||||||
return assistant.settings?.toolUseMode === 'function'
|
return true
|
||||||
|
} else {
|
||||||
|
return assistant.settings?.toolUseMode === 'prompt'
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 是否使用提示词工具使用
|
||||||
|
* @param assistant
|
||||||
|
* @returns 是否使用提示词工具使用
|
||||||
|
*/
|
||||||
|
export function isPromptToolUse(assistant: Assistant) {
|
||||||
|
return assistant.settings?.toolUseMode === 'prompt'
|
||||||
|
}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user