refactor: update type exports and enhance web search functionality

- Added `ReasoningPart`, `FilePart`, and `ImagePart` to type exports in `index.ts`.
- Refactored `transformParameters.ts` to include `enableWebSearch` option and integrate web search tools.
- Introduced new utility `getWebSearchTools` in `websearch.ts` to manage web search tool configurations based on model type.
- Commented out deprecated code in `smoothReasoningPlugin.ts` and `textPlugin.ts` for potential removal.
This commit is contained in:
MyPrototypeWhat 2025-07-07 19:34:04 +08:00
parent 342c5ab82c
commit bb520910bc
5 changed files with 189 additions and 144 deletions

View File

@ -49,29 +49,30 @@ export * as aiSdk from 'ai'
// 直接导出 AI SDK 的常用类型,方便使用
export type { LanguageModelV2Middleware, LanguageModelV2StreamPart } from '@ai-sdk/provider'
export type { ToolCall } from '@ai-sdk/provider-utils'
export type { ReasoningPart } from '@ai-sdk/provider-utils'
export type {
AssistantModelMessage,
FilePart,
// 通用类型
FinishReason,
GenerateObjectResult,
// 生成相关类型
GenerateTextResult,
ImagePart,
InvalidToolInputError,
LanguageModelUsage, // AI SDK 4.0 中 TokenUsage 改名为 LanguageModelUsage
// 消息相关类型
ModelMessage,
TextPart,
FilePart,
ImagePart,
ToolCallPart,
// 错误类型
NoSuchToolError,
StreamTextResult,
SystemModelMessage,
TextPart,
// 流相关类型
TextStreamPart,
// 工具相关类型
Tool,
ToolCallPart,
ToolCallUnion,
ToolModelMessage,
ToolResultPart,
@ -79,7 +80,6 @@ export type {
ToolSet,
UserModelMessage
} from 'ai'
export type { ReasoningPart } from '@ai-sdk/provider-utils'
export {
defaultSettingsMiddleware,
extractReasoningMiddleware,

View File

@ -1,152 +1,152 @@
// 可能会废弃在流上做delay还是有问题
// // 可能会废弃在流上做delay还是有问题
import { definePlugin } from '@cherrystudio/ai-core'
// import { definePlugin } from '@cherrystudio/ai-core'
const chunkingRegex = /([\u4E00-\u9FFF])|\S+\s+/
const delayInMs = 50
// const chunkingRegex = /([\u4E00-\u9FFF])|\S+\s+/
// const delayInMs = 50
export default definePlugin({
name: 'reasoningPlugin',
// export default definePlugin({
// name: 'reasoningPlugin',
transformStream: () => () => {
// === smoothing 状态 ===
let buffer = ''
// transformStream: () => () => {
// // === smoothing 状态 ===
// let buffer = ''
// === 时间跟踪状态 ===
let thinkingStartTime = performance.now()
let hasStartedThinking = false
let accumulatedThinkingContent = ''
// // === 时间跟踪状态 ===
// let thinkingStartTime = performance.now()
// let hasStartedThinking = false
// let accumulatedThinkingContent = ''
// === 日志计数器 ===
let chunkCount = 0
let delayCount = 0
// // === 日志计数器 ===
// let chunkCount = 0
// let delayCount = 0
const delay = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms))
// const delay = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms))
// 收集所有当前可匹配的chunks
const collectMatches = (inputBuffer: string) => {
const matches: string[] = []
let tempBuffer = inputBuffer
let match
// // 收集所有当前可匹配的chunks
// const collectMatches = (inputBuffer: string) => {
// const matches: string[] = []
// let tempBuffer = inputBuffer
// let match
// 重置regex状态
chunkingRegex.lastIndex = 0
// // 重置regex状态
// chunkingRegex.lastIndex = 0
while ((match = chunkingRegex.exec(tempBuffer)) !== null) {
matches.push(match[0])
tempBuffer = tempBuffer.slice(match.index + match[0].length)
// 重置regex以从头开始匹配剩余内容
chunkingRegex.lastIndex = 0
}
// while ((match = chunkingRegex.exec(tempBuffer)) !== null) {
// matches.push(match[0])
// tempBuffer = tempBuffer.slice(match.index + match[0].length)
// // 重置regex以从头开始匹配剩余内容
// chunkingRegex.lastIndex = 0
// }
return {
matches,
remaining: tempBuffer
}
}
// return {
// matches,
// remaining: tempBuffer
// }
// }
return new TransformStream({
async transform(chunk, controller) {
if (chunk.type !== 'reasoning') {
// === 处理 reasoning 结束 ===
if (hasStartedThinking && accumulatedThinkingContent) {
console.log(
`[ReasoningPlugin] Ending reasoning. Final stats: chunks=${chunkCount}, delays=${delayCount}, efficiency=${(chunkCount / Math.max(delayCount, 1)).toFixed(2)}x`
)
// return new TransformStream({
// async transform(chunk, controller) {
// if (chunk.type !== 'reasoning') {
// // === 处理 reasoning 结束 ===
// if (hasStartedThinking && accumulatedThinkingContent) {
// console.log(
// `[ReasoningPlugin] Ending reasoning. Final stats: chunks=${chunkCount}, delays=${delayCount}, efficiency=${(chunkCount / Math.max(delayCount, 1)).toFixed(2)}x`
// )
// 先输出剩余的 buffer
if (buffer.length > 0) {
console.log(`[ReasoningPlugin] Flushing remaining buffer: "${buffer}"`)
controller.enqueue({
type: 'reasoning',
textDelta: buffer,
thinking_millsec: performance.now() - thinkingStartTime
})
buffer = ''
}
// // 先输出剩余的 buffer
// if (buffer.length > 0) {
// console.log(`[ReasoningPlugin] Flushing remaining buffer: "${buffer}"`)
// controller.enqueue({
// type: 'reasoning',
// textDelta: buffer,
// thinking_millsec: performance.now() - thinkingStartTime
// })
// buffer = ''
// }
// 生成 reasoning-signature
controller.enqueue({
type: 'reasoning-signature',
text: accumulatedThinkingContent,
thinking_millsec: performance.now() - thinkingStartTime
})
// // 生成 reasoning-signature
// controller.enqueue({
// type: 'reasoning-signature',
// text: accumulatedThinkingContent,
// thinking_millsec: performance.now() - thinkingStartTime
// })
// 重置状态
accumulatedThinkingContent = ''
hasStartedThinking = false
thinkingStartTime = 0
chunkCount = 0
delayCount = 0
}
// // 重置状态
// accumulatedThinkingContent = ''
// hasStartedThinking = false
// thinkingStartTime = 0
// chunkCount = 0
// delayCount = 0
// }
controller.enqueue(chunk)
return
}
// controller.enqueue(chunk)
// return
// }
// === 处理 reasoning 类型 ===
// // === 处理 reasoning 类型 ===
// 1. 时间跟踪逻辑
if (!hasStartedThinking) {
hasStartedThinking = true
thinkingStartTime = performance.now()
console.log(`[ReasoningPlugin] Starting reasoning session`)
}
accumulatedThinkingContent += chunk.textDelta
// // 1. 时间跟踪逻辑
// if (!hasStartedThinking) {
// hasStartedThinking = true
// thinkingStartTime = performance.now()
// console.log(`[ReasoningPlugin] Starting reasoning session`)
// }
// accumulatedThinkingContent += chunk.textDelta
// 2. 动态Smooth处理逻辑
const beforeBuffer = buffer
buffer += chunk.textDelta
// // 2. 动态Smooth处理逻辑
// const beforeBuffer = buffer
// buffer += chunk.textDelta
console.log(`[ReasoningPlugin] Received chunk: "${chunk.textDelta}", buffer: "${beforeBuffer}" → "${buffer}"`)
// console.log(`[ReasoningPlugin] Received chunk: "${chunk.textDelta}", buffer: "${beforeBuffer}" → "${buffer}"`)
// 收集所有当前可以匹配的chunks
const { matches, remaining } = collectMatches(buffer)
// // 收集所有当前可以匹配的chunks
// const { matches, remaining } = collectMatches(buffer)
if (matches.length > 0) {
console.log(
`[ReasoningPlugin] Collected ${matches.length} matches: [${matches.map((m) => `"${m}"`).join(', ')}], remaining: "${remaining}"`
)
// if (matches.length > 0) {
// console.log(
// `[ReasoningPlugin] Collected ${matches.length} matches: [${matches.map((m) => `"${m}"`).join(', ')}], remaining: "${remaining}"`
// )
// 批量输出所有匹配的chunks
for (const matchText of matches) {
controller.enqueue({
type: 'reasoning',
textDelta: matchText,
thinking_millsec: performance.now() - thinkingStartTime
})
chunkCount++
}
// // 批量输出所有匹配的chunks
// for (const matchText of matches) {
// controller.enqueue({
// type: 'reasoning',
// textDelta: matchText,
// thinking_millsec: performance.now() - thinkingStartTime
// })
// chunkCount++
// }
// 更新buffer为剩余内容
buffer = remaining
// // 更新buffer为剩余内容
// buffer = remaining
// 只等待一次而不是每个chunk都等待
delayCount++
console.log(
`[ReasoningPlugin] Delaying ${delayInMs}ms (delay #${delayCount}, efficiency: ${(chunkCount / delayCount).toFixed(2)} chunks/delay)`
)
const delayStart = performance.now()
await delay(delayInMs)
const actualDelay = performance.now() - delayStart
console.log(`[ReasoningPlugin] Delay completed: expected=${delayInMs}ms, actual=${actualDelay.toFixed(1)}ms`)
} else {
console.log(`[ReasoningPlugin] No matches found, keeping in buffer: "${buffer}"`)
}
// 如果没有匹配保留在buffer中等待下次数据
},
// // 只等待一次而不是每个chunk都等待
// delayCount++
// console.log(
// `[ReasoningPlugin] Delaying ${delayInMs}ms (delay #${delayCount}, efficiency: ${(chunkCount / delayCount).toFixed(2)} chunks/delay)`
// )
// const delayStart = performance.now()
// await delay(delayInMs)
// const actualDelay = performance.now() - delayStart
// console.log(`[ReasoningPlugin] Delay completed: expected=${delayInMs}ms, actual=${actualDelay.toFixed(1)}ms`)
// } else {
// console.log(`[ReasoningPlugin] No matches found, keeping in buffer: "${buffer}"`)
// }
// // 如果没有匹配保留在buffer中等待下次数据
// },
// === flush 处理剩余 buffer ===
flush(controller) {
if (buffer.length > 0) {
console.log(`[ReasoningPlugin] Final flush: "${buffer}"`)
controller.enqueue({
type: 'reasoning',
textDelta: buffer,
thinking_millsec: hasStartedThinking ? performance.now() - thinkingStartTime : 0
})
}
}
})
}
})
// // === flush 处理剩余 buffer ===
// flush(controller) {
// if (buffer.length > 0) {
// console.log(`[ReasoningPlugin] Final flush: "${buffer}"`)
// controller.enqueue({
// type: 'reasoning',
// textDelta: buffer,
// thinking_millsec: hasStartedThinking ? performance.now() - thinkingStartTime : 0
// })
// }
// }
// })
// }
// })

View File

@ -1,13 +1,13 @@
// 可能会废弃在流上做delay还是有问题
// // 可能会废弃在流上做delay还是有问题
import { definePlugin, smoothStream } from '@cherrystudio/ai-core'
// import { definePlugin, smoothStream } from '@cherrystudio/ai-core'
export default definePlugin({
name: 'textPlugin',
transformStream: () =>
smoothStream({
delayInMs: 50,
// 中文3个字符一个chunk,英文一个单词一个chunk
chunking: /([\u4E00-\u9FFF]{3})|\S+\s+/
})
})
// export default definePlugin({
// name: 'textPlugin',
// transformStream: () =>
// smoothStream({
// delayInMs: 50,
// // 中文3个字符一个chunk,英文一个单词一个chunk
// chunking: /([\u4E00-\u9FFF]{3})|\S+\s+/
// })
// })

View File

@ -29,6 +29,7 @@ import {
import { getAssistantSettings, getDefaultModel } from '@renderer/services/AssistantService'
import type { Assistant, MCPTool, Message, Model } from '@renderer/types'
import { FileTypes } from '@renderer/types'
import { FileMessageBlock, ImageMessageBlock, ThinkingMessageBlock } from '@renderer/types/newMessage'
import {
findFileBlocks,
findImageBlocks,
@ -41,7 +42,7 @@ import { defaultTimeout } from '@shared/config/constant'
// import { jsonSchemaToZod } from 'json-schema-to-zod'
import { setupToolsConfig } from './utils/mcp'
import { buildProviderOptions } from './utils/options'
import { FileMessageBlock, ImageMessageBlock, ThinkingMessageBlock } from '@renderer/types/newMessage'
import { getWebSearchTools } from './utils/websearch'
/**
*
@ -243,6 +244,7 @@ export async function buildStreamTextParams(
options: {
mcpTools?: MCPTool[]
enableTools?: boolean
enableWebSearch?: boolean
requestOptions?: {
signal?: AbortSignal
timeout?: number
@ -277,12 +279,18 @@ export async function buildStreamTextParams(
(isSupportedDisableGenerationModel(model) ? assistant.enableGenerateImage || false : true)
// 构建系统提示
const { tools } = setupToolsConfig({
let { tools } = setupToolsConfig({
mcpTools,
model,
enableToolUse: enableTools
})
// Add web search tools if enabled
if (enableWebSearch) {
const webSearchTools = getWebSearchTools(model)
tools = { ...tools, ...webSearchTools }
}
// 构建真正的 providerOptions
const providerOptions = buildProviderOptions(assistant, model, {
enableReasoning,

View File

@ -0,0 +1,37 @@
import { isWebSearchModel } from '@renderer/config/models'
import { Model } from '@renderer/types'
// import {} from '@cherrystudio/ai-core'
// Key under which Gemini's built-in Google Search grounding tool is registered.
// Gemini accepts any tool name here; a descriptive one is used for readability.
const GEMINI_SEARCH_TOOL_NAME = 'google_search'

/**
 * Builds the provider-specific web-search tool map for a model.
 *
 * @param model - The model whose provider determines the tool shape.
 * @returns A tool map suitable for merging into the request's `tools`,
 *          or an empty object when the model cannot perform web search
 *          or its provider exposes search as a parameter rather than a tool.
 */
export function getWebSearchTools(model: Model): Record<string, any> {
  // Models without web-search capability get no tools at all.
  if (!isWebSearchModel(model)) {
    return {}
  }

  // Prefer the explicit provider field; fall back to the model-id prefix
  // (`||` is intentional so an empty-string provider also falls back).
  const provider = model.provider || model.id.split('/')[0]

  if (provider === 'anthropic') {
    // Anthropic's server-side web search tool (versioned tool type).
    return {
      web_search: {
        type: 'web_search_20250305',
        name: 'web_search',
        max_uses: 5
      }
    }
  }

  if (provider === 'google' || provider === 'gemini') {
    // Gemini grounding via Google Search; the config object is empty by design.
    return {
      [GEMINI_SEARCH_TOOL_NAME]: {
        googleSearch: {}
      }
    }
  }

  // OpenAI and most other providers treat web search as a request parameter,
  // not a tool — that path is handled in `buildProviderOptions`.
  return {}
}