mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-28 05:11:24 +08:00
refactor: update type exports and enhance web search functionality
- Added `ReasoningPart`, `FilePart`, and `ImagePart` to type exports in `index.ts`. - Refactored `transformParameters.ts` to include `enableWebSearch` option and integrate web search tools. - Introduced new utility `getWebSearchTools` in `websearch.ts` to manage web search tool configurations based on model type. - Commented out deprecated code in `smoothReasoningPlugin.ts` and `textPlugin.ts` for potential removal.
This commit is contained in:
parent
342c5ab82c
commit
bb520910bc
@ -49,29 +49,30 @@ export * as aiSdk from 'ai'
|
||||
// 直接导出 AI SDK 的常用类型,方便使用
|
||||
export type { LanguageModelV2Middleware, LanguageModelV2StreamPart } from '@ai-sdk/provider'
|
||||
export type { ToolCall } from '@ai-sdk/provider-utils'
|
||||
export type { ReasoningPart } from '@ai-sdk/provider-utils'
|
||||
export type {
|
||||
AssistantModelMessage,
|
||||
FilePart,
|
||||
// 通用类型
|
||||
FinishReason,
|
||||
GenerateObjectResult,
|
||||
// 生成相关类型
|
||||
GenerateTextResult,
|
||||
ImagePart,
|
||||
InvalidToolInputError,
|
||||
LanguageModelUsage, // AI SDK 4.0 中 TokenUsage 改名为 LanguageModelUsage
|
||||
// 消息相关类型
|
||||
ModelMessage,
|
||||
TextPart,
|
||||
FilePart,
|
||||
ImagePart,
|
||||
ToolCallPart,
|
||||
// 错误类型
|
||||
NoSuchToolError,
|
||||
StreamTextResult,
|
||||
SystemModelMessage,
|
||||
TextPart,
|
||||
// 流相关类型
|
||||
TextStreamPart,
|
||||
// 工具相关类型
|
||||
Tool,
|
||||
ToolCallPart,
|
||||
ToolCallUnion,
|
||||
ToolModelMessage,
|
||||
ToolResultPart,
|
||||
@ -79,7 +80,6 @@ export type {
|
||||
ToolSet,
|
||||
UserModelMessage
|
||||
} from 'ai'
|
||||
export type { ReasoningPart } from '@ai-sdk/provider-utils'
|
||||
export {
|
||||
defaultSettingsMiddleware,
|
||||
extractReasoningMiddleware,
|
||||
|
||||
@ -1,152 +1,152 @@
|
||||
// 可能会废弃,在流上做delay还是有问题
|
||||
// // 可能会废弃,在流上做delay还是有问题
|
||||
|
||||
import { definePlugin } from '@cherrystudio/ai-core'
|
||||
// import { definePlugin } from '@cherrystudio/ai-core'
|
||||
|
||||
const chunkingRegex = /([\u4E00-\u9FFF])|\S+\s+/
|
||||
const delayInMs = 50
|
||||
// const chunkingRegex = /([\u4E00-\u9FFF])|\S+\s+/
|
||||
// const delayInMs = 50
|
||||
|
||||
export default definePlugin({
|
||||
name: 'reasoningPlugin',
|
||||
// export default definePlugin({
|
||||
// name: 'reasoningPlugin',
|
||||
|
||||
transformStream: () => () => {
|
||||
// === smoothing 状态 ===
|
||||
let buffer = ''
|
||||
// transformStream: () => () => {
|
||||
// // === smoothing 状态 ===
|
||||
// let buffer = ''
|
||||
|
||||
// === 时间跟踪状态 ===
|
||||
let thinkingStartTime = performance.now()
|
||||
let hasStartedThinking = false
|
||||
let accumulatedThinkingContent = ''
|
||||
// // === 时间跟踪状态 ===
|
||||
// let thinkingStartTime = performance.now()
|
||||
// let hasStartedThinking = false
|
||||
// let accumulatedThinkingContent = ''
|
||||
|
||||
// === 日志计数器 ===
|
||||
let chunkCount = 0
|
||||
let delayCount = 0
|
||||
// // === 日志计数器 ===
|
||||
// let chunkCount = 0
|
||||
// let delayCount = 0
|
||||
|
||||
const delay = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms))
|
||||
// const delay = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms))
|
||||
|
||||
// 收集所有当前可匹配的chunks
|
||||
const collectMatches = (inputBuffer: string) => {
|
||||
const matches: string[] = []
|
||||
let tempBuffer = inputBuffer
|
||||
let match
|
||||
// // 收集所有当前可匹配的chunks
|
||||
// const collectMatches = (inputBuffer: string) => {
|
||||
// const matches: string[] = []
|
||||
// let tempBuffer = inputBuffer
|
||||
// let match
|
||||
|
||||
// 重置regex状态
|
||||
chunkingRegex.lastIndex = 0
|
||||
// // 重置regex状态
|
||||
// chunkingRegex.lastIndex = 0
|
||||
|
||||
while ((match = chunkingRegex.exec(tempBuffer)) !== null) {
|
||||
matches.push(match[0])
|
||||
tempBuffer = tempBuffer.slice(match.index + match[0].length)
|
||||
// 重置regex以从头开始匹配剩余内容
|
||||
chunkingRegex.lastIndex = 0
|
||||
}
|
||||
// while ((match = chunkingRegex.exec(tempBuffer)) !== null) {
|
||||
// matches.push(match[0])
|
||||
// tempBuffer = tempBuffer.slice(match.index + match[0].length)
|
||||
// // 重置regex以从头开始匹配剩余内容
|
||||
// chunkingRegex.lastIndex = 0
|
||||
// }
|
||||
|
||||
return {
|
||||
matches,
|
||||
remaining: tempBuffer
|
||||
}
|
||||
}
|
||||
// return {
|
||||
// matches,
|
||||
// remaining: tempBuffer
|
||||
// }
|
||||
// }
|
||||
|
||||
return new TransformStream({
|
||||
async transform(chunk, controller) {
|
||||
if (chunk.type !== 'reasoning') {
|
||||
// === 处理 reasoning 结束 ===
|
||||
if (hasStartedThinking && accumulatedThinkingContent) {
|
||||
console.log(
|
||||
`[ReasoningPlugin] Ending reasoning. Final stats: chunks=${chunkCount}, delays=${delayCount}, efficiency=${(chunkCount / Math.max(delayCount, 1)).toFixed(2)}x`
|
||||
)
|
||||
// return new TransformStream({
|
||||
// async transform(chunk, controller) {
|
||||
// if (chunk.type !== 'reasoning') {
|
||||
// // === 处理 reasoning 结束 ===
|
||||
// if (hasStartedThinking && accumulatedThinkingContent) {
|
||||
// console.log(
|
||||
// `[ReasoningPlugin] Ending reasoning. Final stats: chunks=${chunkCount}, delays=${delayCount}, efficiency=${(chunkCount / Math.max(delayCount, 1)).toFixed(2)}x`
|
||||
// )
|
||||
|
||||
// 先输出剩余的 buffer
|
||||
if (buffer.length > 0) {
|
||||
console.log(`[ReasoningPlugin] Flushing remaining buffer: "${buffer}"`)
|
||||
controller.enqueue({
|
||||
type: 'reasoning',
|
||||
textDelta: buffer,
|
||||
thinking_millsec: performance.now() - thinkingStartTime
|
||||
})
|
||||
buffer = ''
|
||||
}
|
||||
// // 先输出剩余的 buffer
|
||||
// if (buffer.length > 0) {
|
||||
// console.log(`[ReasoningPlugin] Flushing remaining buffer: "${buffer}"`)
|
||||
// controller.enqueue({
|
||||
// type: 'reasoning',
|
||||
// textDelta: buffer,
|
||||
// thinking_millsec: performance.now() - thinkingStartTime
|
||||
// })
|
||||
// buffer = ''
|
||||
// }
|
||||
|
||||
// 生成 reasoning-signature
|
||||
controller.enqueue({
|
||||
type: 'reasoning-signature',
|
||||
text: accumulatedThinkingContent,
|
||||
thinking_millsec: performance.now() - thinkingStartTime
|
||||
})
|
||||
// // 生成 reasoning-signature
|
||||
// controller.enqueue({
|
||||
// type: 'reasoning-signature',
|
||||
// text: accumulatedThinkingContent,
|
||||
// thinking_millsec: performance.now() - thinkingStartTime
|
||||
// })
|
||||
|
||||
// 重置状态
|
||||
accumulatedThinkingContent = ''
|
||||
hasStartedThinking = false
|
||||
thinkingStartTime = 0
|
||||
chunkCount = 0
|
||||
delayCount = 0
|
||||
}
|
||||
// // 重置状态
|
||||
// accumulatedThinkingContent = ''
|
||||
// hasStartedThinking = false
|
||||
// thinkingStartTime = 0
|
||||
// chunkCount = 0
|
||||
// delayCount = 0
|
||||
// }
|
||||
|
||||
controller.enqueue(chunk)
|
||||
return
|
||||
}
|
||||
// controller.enqueue(chunk)
|
||||
// return
|
||||
// }
|
||||
|
||||
// === 处理 reasoning 类型 ===
|
||||
// // === 处理 reasoning 类型 ===
|
||||
|
||||
// 1. 时间跟踪逻辑
|
||||
if (!hasStartedThinking) {
|
||||
hasStartedThinking = true
|
||||
thinkingStartTime = performance.now()
|
||||
console.log(`[ReasoningPlugin] Starting reasoning session`)
|
||||
}
|
||||
accumulatedThinkingContent += chunk.textDelta
|
||||
// // 1. 时间跟踪逻辑
|
||||
// if (!hasStartedThinking) {
|
||||
// hasStartedThinking = true
|
||||
// thinkingStartTime = performance.now()
|
||||
// console.log(`[ReasoningPlugin] Starting reasoning session`)
|
||||
// }
|
||||
// accumulatedThinkingContent += chunk.textDelta
|
||||
|
||||
// 2. 动态Smooth处理逻辑
|
||||
const beforeBuffer = buffer
|
||||
buffer += chunk.textDelta
|
||||
// // 2. 动态Smooth处理逻辑
|
||||
// const beforeBuffer = buffer
|
||||
// buffer += chunk.textDelta
|
||||
|
||||
console.log(`[ReasoningPlugin] Received chunk: "${chunk.textDelta}", buffer: "${beforeBuffer}" → "${buffer}"`)
|
||||
// console.log(`[ReasoningPlugin] Received chunk: "${chunk.textDelta}", buffer: "${beforeBuffer}" → "${buffer}"`)
|
||||
|
||||
// 收集所有当前可以匹配的chunks
|
||||
const { matches, remaining } = collectMatches(buffer)
|
||||
// // 收集所有当前可以匹配的chunks
|
||||
// const { matches, remaining } = collectMatches(buffer)
|
||||
|
||||
if (matches.length > 0) {
|
||||
console.log(
|
||||
`[ReasoningPlugin] Collected ${matches.length} matches: [${matches.map((m) => `"${m}"`).join(', ')}], remaining: "${remaining}"`
|
||||
)
|
||||
// if (matches.length > 0) {
|
||||
// console.log(
|
||||
// `[ReasoningPlugin] Collected ${matches.length} matches: [${matches.map((m) => `"${m}"`).join(', ')}], remaining: "${remaining}"`
|
||||
// )
|
||||
|
||||
// 批量输出所有匹配的chunks
|
||||
for (const matchText of matches) {
|
||||
controller.enqueue({
|
||||
type: 'reasoning',
|
||||
textDelta: matchText,
|
||||
thinking_millsec: performance.now() - thinkingStartTime
|
||||
})
|
||||
chunkCount++
|
||||
}
|
||||
// // 批量输出所有匹配的chunks
|
||||
// for (const matchText of matches) {
|
||||
// controller.enqueue({
|
||||
// type: 'reasoning',
|
||||
// textDelta: matchText,
|
||||
// thinking_millsec: performance.now() - thinkingStartTime
|
||||
// })
|
||||
// chunkCount++
|
||||
// }
|
||||
|
||||
// 更新buffer为剩余内容
|
||||
buffer = remaining
|
||||
// // 更新buffer为剩余内容
|
||||
// buffer = remaining
|
||||
|
||||
// 只等待一次,而不是每个chunk都等待
|
||||
delayCount++
|
||||
console.log(
|
||||
`[ReasoningPlugin] Delaying ${delayInMs}ms (delay #${delayCount}, efficiency: ${(chunkCount / delayCount).toFixed(2)} chunks/delay)`
|
||||
)
|
||||
const delayStart = performance.now()
|
||||
await delay(delayInMs)
|
||||
const actualDelay = performance.now() - delayStart
|
||||
console.log(`[ReasoningPlugin] Delay completed: expected=${delayInMs}ms, actual=${actualDelay.toFixed(1)}ms`)
|
||||
} else {
|
||||
console.log(`[ReasoningPlugin] No matches found, keeping in buffer: "${buffer}"`)
|
||||
}
|
||||
// 如果没有匹配,保留在buffer中等待下次数据
|
||||
},
|
||||
// // 只等待一次,而不是每个chunk都等待
|
||||
// delayCount++
|
||||
// console.log(
|
||||
// `[ReasoningPlugin] Delaying ${delayInMs}ms (delay #${delayCount}, efficiency: ${(chunkCount / delayCount).toFixed(2)} chunks/delay)`
|
||||
// )
|
||||
// const delayStart = performance.now()
|
||||
// await delay(delayInMs)
|
||||
// const actualDelay = performance.now() - delayStart
|
||||
// console.log(`[ReasoningPlugin] Delay completed: expected=${delayInMs}ms, actual=${actualDelay.toFixed(1)}ms`)
|
||||
// } else {
|
||||
// console.log(`[ReasoningPlugin] No matches found, keeping in buffer: "${buffer}"`)
|
||||
// }
|
||||
// // 如果没有匹配,保留在buffer中等待下次数据
|
||||
// },
|
||||
|
||||
// === flush 处理剩余 buffer ===
|
||||
flush(controller) {
|
||||
if (buffer.length > 0) {
|
||||
console.log(`[ReasoningPlugin] Final flush: "${buffer}"`)
|
||||
controller.enqueue({
|
||||
type: 'reasoning',
|
||||
textDelta: buffer,
|
||||
thinking_millsec: hasStartedThinking ? performance.now() - thinkingStartTime : 0
|
||||
})
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
// // === flush 处理剩余 buffer ===
|
||||
// flush(controller) {
|
||||
// if (buffer.length > 0) {
|
||||
// console.log(`[ReasoningPlugin] Final flush: "${buffer}"`)
|
||||
// controller.enqueue({
|
||||
// type: 'reasoning',
|
||||
// textDelta: buffer,
|
||||
// thinking_millsec: hasStartedThinking ? performance.now() - thinkingStartTime : 0
|
||||
// })
|
||||
// }
|
||||
// }
|
||||
// })
|
||||
// }
|
||||
// })
|
||||
|
||||
@ -1,13 +1,13 @@
|
||||
// 可能会废弃,在流上做delay还是有问题
|
||||
// // 可能会废弃,在流上做delay还是有问题
|
||||
|
||||
import { definePlugin, smoothStream } from '@cherrystudio/ai-core'
|
||||
// import { definePlugin, smoothStream } from '@cherrystudio/ai-core'
|
||||
|
||||
export default definePlugin({
|
||||
name: 'textPlugin',
|
||||
transformStream: () =>
|
||||
smoothStream({
|
||||
delayInMs: 50,
|
||||
// 中文3个字符一个chunk,英文一个单词一个chunk
|
||||
chunking: /([\u4E00-\u9FFF]{3})|\S+\s+/
|
||||
})
|
||||
})
|
||||
// export default definePlugin({
|
||||
// name: 'textPlugin',
|
||||
// transformStream: () =>
|
||||
// smoothStream({
|
||||
// delayInMs: 50,
|
||||
// // 中文3个字符一个chunk,英文一个单词一个chunk
|
||||
// chunking: /([\u4E00-\u9FFF]{3})|\S+\s+/
|
||||
// })
|
||||
// })
|
||||
|
||||
@ -29,6 +29,7 @@ import {
|
||||
import { getAssistantSettings, getDefaultModel } from '@renderer/services/AssistantService'
|
||||
import type { Assistant, MCPTool, Message, Model } from '@renderer/types'
|
||||
import { FileTypes } from '@renderer/types'
|
||||
import { FileMessageBlock, ImageMessageBlock, ThinkingMessageBlock } from '@renderer/types/newMessage'
|
||||
import {
|
||||
findFileBlocks,
|
||||
findImageBlocks,
|
||||
@ -41,7 +42,7 @@ import { defaultTimeout } from '@shared/config/constant'
|
||||
// import { jsonSchemaToZod } from 'json-schema-to-zod'
|
||||
import { setupToolsConfig } from './utils/mcp'
|
||||
import { buildProviderOptions } from './utils/options'
|
||||
import { FileMessageBlock, ImageMessageBlock, ThinkingMessageBlock } from '@renderer/types/newMessage'
|
||||
import { getWebSearchTools } from './utils/websearch'
|
||||
|
||||
/**
|
||||
* 获取温度参数
|
||||
@ -243,6 +244,7 @@ export async function buildStreamTextParams(
|
||||
options: {
|
||||
mcpTools?: MCPTool[]
|
||||
enableTools?: boolean
|
||||
enableWebSearch?: boolean
|
||||
requestOptions?: {
|
||||
signal?: AbortSignal
|
||||
timeout?: number
|
||||
@ -277,12 +279,18 @@ export async function buildStreamTextParams(
|
||||
(isSupportedDisableGenerationModel(model) ? assistant.enableGenerateImage || false : true)
|
||||
|
||||
// 构建系统提示
|
||||
const { tools } = setupToolsConfig({
|
||||
let { tools } = setupToolsConfig({
|
||||
mcpTools,
|
||||
model,
|
||||
enableToolUse: enableTools
|
||||
})
|
||||
|
||||
// Add web search tools if enabled
|
||||
if (enableWebSearch) {
|
||||
const webSearchTools = getWebSearchTools(model)
|
||||
tools = { ...tools, ...webSearchTools }
|
||||
}
|
||||
|
||||
// 构建真正的 providerOptions
|
||||
const providerOptions = buildProviderOptions(assistant, model, {
|
||||
enableReasoning,
|
||||
|
||||
37
src/renderer/src/aiCore/utils/websearch.ts
Normal file
37
src/renderer/src/aiCore/utils/websearch.ts
Normal file
@ -0,0 +1,37 @@
|
||||
import { isWebSearchModel } from '@renderer/config/models'
|
||||
import { Model } from '@renderer/types'
|
||||
// import {} from '@cherrystudio/ai-core'
|
||||
|
||||
// The tool name for Gemini search can be arbitrary, but let's use a descriptive one.
|
||||
const GEMINI_SEARCH_TOOL_NAME = 'google_search'
|
||||
|
||||
export function getWebSearchTools(model: Model): Record<string, any> {
|
||||
if (!isWebSearchModel(model)) {
|
||||
return {}
|
||||
}
|
||||
|
||||
// Use provider from model if available, otherwise fallback to parsing model id.
|
||||
const provider = model.provider || model.id.split('/')[0]
|
||||
|
||||
switch (provider) {
|
||||
case 'anthropic':
|
||||
return {
|
||||
web_search: {
|
||||
type: 'web_search_20250305',
|
||||
name: 'web_search',
|
||||
max_uses: 5
|
||||
}
|
||||
}
|
||||
case 'google':
|
||||
case 'gemini':
|
||||
return {
|
||||
[GEMINI_SEARCH_TOOL_NAME]: {
|
||||
googleSearch: {}
|
||||
}
|
||||
}
|
||||
default:
|
||||
// For OpenAI and others, web search is often a parameter, not a tool.
|
||||
// The logic is handled in `buildProviderOptions`.
|
||||
return {}
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user