mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-28 13:31:32 +08:00
feat(tests): add unit tests for utility functions in utils.test.ts
- Implemented tests for `createErrorChunk`, `capitalize`, and `isAsyncIterable` functions. - Ensured comprehensive coverage for various input scenarios, including error handling and edge cases.
This commit is contained in:
parent
bf02afa841
commit
ff7ad52ad5
@ -332,9 +332,18 @@ export const searchOrchestrationPlugin = (assistant: Assistant) => {
|
||||
|
||||
if (hasKnowledgeBase) {
|
||||
if (knowledgeRecognition === 'off') {
|
||||
// off 模式:直接添加知识库搜索工具,跳过意图识别
|
||||
// off 模式:直接添加知识库搜索工具,使用用户消息作为搜索关键词
|
||||
const userMessage = userMessages[context.requestId]
|
||||
const fallbackKeywords = {
|
||||
question: [getMessageContent(userMessage) || 'search'],
|
||||
rewrite: getMessageContent(userMessage) || 'search'
|
||||
}
|
||||
console.log('📚 [SearchOrchestration] Adding knowledge search tool (force mode)')
|
||||
params.tools['builtin_knowledge_search'] = knowledgeSearchTool(assistant)
|
||||
params.tools['builtin_knowledge_search'] = knowledgeSearchTool(
|
||||
assistant,
|
||||
fallbackKeywords,
|
||||
getMessageContent(userMessage)
|
||||
)
|
||||
params.toolChoice = { type: 'tool', toolName: 'builtin_knowledge_search' }
|
||||
} else {
|
||||
// on 模式:根据意图识别结果决定是否添加工具
|
||||
@ -343,9 +352,14 @@ export const searchOrchestrationPlugin = (assistant: Assistant) => {
|
||||
analysisResult.knowledge.question &&
|
||||
analysisResult.knowledge.question[0] !== 'not_needed'
|
||||
|
||||
if (needsKnowledgeSearch) {
|
||||
if (needsKnowledgeSearch && analysisResult.knowledge) {
|
||||
console.log('📚 [SearchOrchestration] Adding knowledge search tool (intent-based)')
|
||||
params.tools['builtin_knowledge_search'] = knowledgeSearchTool(assistant)
|
||||
const userMessage = userMessages[context.requestId]
|
||||
params.tools['builtin_knowledge_search'] = knowledgeSearchTool(
|
||||
assistant,
|
||||
analysisResult.knowledge,
|
||||
getMessageContent(userMessage)
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,30 +1,37 @@
|
||||
import { processKnowledgeSearch } from '@renderer/services/KnowledgeService'
|
||||
import type { Assistant, KnowledgeReference } from '@renderer/types'
|
||||
import { ExtractResults } from '@renderer/utils/extract'
|
||||
import { ExtractResults, KnowledgeExtractResults } from '@renderer/utils/extract'
|
||||
import { type InferToolInput, type InferToolOutput, tool } from 'ai'
|
||||
import { isEmpty } from 'lodash'
|
||||
import { z } from 'zod'
|
||||
|
||||
// Schema definitions - 添加 userMessage 字段来获取用户消息
|
||||
const KnowledgeSearchInputSchema = z.object({
|
||||
query: z.string().describe('The search query for knowledge base'),
|
||||
rewrite: z.string().optional().describe('Optional rewritten query with alternative phrasing'),
|
||||
userMessage: z.string().describe('The original user message content for direct search mode')
|
||||
})
|
||||
|
||||
export type KnowledgeSearchToolInput = InferToolInput<ReturnType<typeof knowledgeSearchTool>>
|
||||
export type KnowledgeSearchToolOutput = InferToolOutput<ReturnType<typeof knowledgeSearchTool>>
|
||||
|
||||
/**
|
||||
* 知识库搜索工具
|
||||
* 基于 ApiService.ts 中的 searchKnowledgeBase 逻辑实现
|
||||
* 使用预提取关键词,直接使用插件阶段分析的搜索意图,避免重复分析
|
||||
*/
|
||||
export const knowledgeSearchTool = (assistant: Assistant) => {
|
||||
export const knowledgeSearchTool = (
|
||||
assistant: Assistant,
|
||||
extractedKeywords: KnowledgeExtractResults,
|
||||
userMessage?: string
|
||||
) => {
|
||||
return tool({
|
||||
name: 'builtin_knowledge_search',
|
||||
description: 'Search the knowledge base for relevant information',
|
||||
inputSchema: KnowledgeSearchInputSchema,
|
||||
execute: async ({ query, rewrite, userMessage }) => {
|
||||
description: `Search the knowledge base for relevant information using pre-analyzed search intent.
|
||||
|
||||
Pre-extracted search queries: "${extractedKeywords.question.join(', ')}"
|
||||
Rewritten query: "${extractedKeywords.rewrite}"
|
||||
|
||||
This tool searches your knowledge base for relevant documents and returns results for easy reference.
|
||||
Call this tool to execute the search. You can optionally provide additional context to refine the search.`,
|
||||
|
||||
inputSchema: z.object({
|
||||
additionalContext: z
|
||||
.string()
|
||||
.optional()
|
||||
.describe('Optional additional context or specific focus to enhance the knowledge search')
|
||||
}),
|
||||
|
||||
execute: async ({ additionalContext }) => {
|
||||
try {
|
||||
// 获取助手的知识库配置
|
||||
const knowledgeBaseIds = assistant.knowledge_bases?.map((base) => base.id)
|
||||
@ -36,35 +43,51 @@ export const knowledgeSearchTool = (assistant: Assistant) => {
|
||||
return []
|
||||
}
|
||||
|
||||
// 构建搜索条件 - 复制原逻辑
|
||||
let finalQueries = [...extractedKeywords.question]
|
||||
let finalRewrite = extractedKeywords.rewrite
|
||||
|
||||
if (additionalContext?.trim()) {
|
||||
// 如果大模型提供了额外上下文,使用更具体的描述
|
||||
console.log(`🔍 AI enhanced knowledge search with: ${additionalContext}`)
|
||||
const cleanContext = additionalContext.trim()
|
||||
if (cleanContext) {
|
||||
finalQueries = [cleanContext]
|
||||
finalRewrite = cleanContext
|
||||
console.log(`➕ Added additional context: ${cleanContext}`)
|
||||
}
|
||||
}
|
||||
|
||||
// 检查是否需要搜索
|
||||
if (finalQueries[0] === 'not_needed') {
|
||||
return []
|
||||
}
|
||||
|
||||
// 构建搜索条件
|
||||
let searchCriteria: { question: string[]; rewrite: string }
|
||||
|
||||
if (knowledgeRecognition === 'off') {
|
||||
// 直接模式:使用用户消息内容 (类似原逻辑的 getMainTextContent(lastUserMessage))
|
||||
const directContent = userMessage || query || 'search'
|
||||
// 直接模式:使用用户消息内容
|
||||
const directContent = userMessage || finalQueries[0] || 'search'
|
||||
searchCriteria = {
|
||||
question: [directContent],
|
||||
rewrite: directContent
|
||||
}
|
||||
} else {
|
||||
// 自动模式:使用意图识别的结果 (类似原逻辑的 extractResults.knowledge)
|
||||
// 自动模式:使用意图识别的结果
|
||||
searchCriteria = {
|
||||
question: [query],
|
||||
rewrite: rewrite || query
|
||||
question: finalQueries,
|
||||
rewrite: finalRewrite
|
||||
}
|
||||
}
|
||||
|
||||
// 检查是否需要搜索
|
||||
if (searchCriteria.question[0] === 'not_needed') {
|
||||
return []
|
||||
}
|
||||
|
||||
// 构建 ExtractResults 对象 - 与原逻辑一致
|
||||
// 构建 ExtractResults 对象
|
||||
const extractResults: ExtractResults = {
|
||||
websearch: undefined,
|
||||
knowledge: searchCriteria
|
||||
}
|
||||
|
||||
console.log('Knowledge search extractResults:', extractResults)
|
||||
|
||||
// 执行知识库搜索
|
||||
const knowledgeReferences = await processKnowledgeSearch(extractResults, knowledgeBaseIds)
|
||||
|
||||
@ -86,4 +109,7 @@ export const knowledgeSearchTool = (assistant: Assistant) => {
|
||||
})
|
||||
}
|
||||
|
||||
export type KnowledgeSearchToolInput = InferToolInput<ReturnType<typeof knowledgeSearchTool>>
|
||||
export type KnowledgeSearchToolOutput = InferToolOutput<ReturnType<typeof knowledgeSearchTool>>
|
||||
|
||||
export default knowledgeSearchTool
|
||||
|
||||
@ -5,42 +5,6 @@ import { ExtractResults } from '@renderer/utils/extract'
|
||||
import { InferToolInput, InferToolOutput, tool } from 'ai'
|
||||
import { z } from 'zod'
|
||||
|
||||
// import { AiSdkTool, ToolCallResult } from './types'
|
||||
|
||||
// const WebSearchResult = z.array(
|
||||
// z.object({
|
||||
// query: z.string().optional(),
|
||||
// results: z.array(
|
||||
// z.object({
|
||||
// title: z.string(),
|
||||
// content: z.string(),
|
||||
// url: z.string()
|
||||
// })
|
||||
// )
|
||||
// })
|
||||
// )
|
||||
// const webSearchToolInputSchema = z.object({
|
||||
// query: z.string().describe('The query to search for')
|
||||
// })
|
||||
|
||||
// export const webSearchTool = (webSearchProviderId: WebSearchProvider['id']) => {
|
||||
// const webSearchService = WebSearchService.getInstance(webSearchProviderId)
|
||||
// return tool({
|
||||
// name: 'builtin_web_search',
|
||||
// description: 'Search the web for information',
|
||||
// inputSchema: webSearchToolInputSchema,
|
||||
// outputSchema: WebSearchProviderResult,
|
||||
// execute: async ({ query }) => {
|
||||
// console.log('webSearchTool', query)
|
||||
// const response = await webSearchService.search(query)
|
||||
// console.log('webSearchTool response', response)
|
||||
// return response
|
||||
// }
|
||||
// })
|
||||
// }
|
||||
// export type WebSearchToolInput = InferToolInput<ReturnType<typeof webSearchTool>>
|
||||
// export type WebSearchToolOutput = InferToolOutput<ReturnType<typeof webSearchTool>>
|
||||
|
||||
/**
|
||||
* 使用预提取关键词的网络搜索工具
|
||||
* 这个工具直接使用插件阶段分析的搜索意图,避免重复分析
|
||||
@ -53,7 +17,7 @@ export const webSearchToolWithPreExtractedKeywords = (
|
||||
},
|
||||
requestId: string
|
||||
) => {
|
||||
const webSearchService = WebSearchService.getInstance(webSearchProviderId)
|
||||
const webSearchProvider = WebSearchService.getWebSearchProvider(webSearchProviderId)
|
||||
|
||||
return tool({
|
||||
name: 'builtin_web_search',
|
||||
@ -112,7 +76,8 @@ Call this tool to execute the search. You can optionally provide additional cont
|
||||
}
|
||||
}
|
||||
console.log('extractResults', extractResults)
|
||||
const response = await webSearchService.processWebsearch(extractResults, requestId)
|
||||
console.log('webSearchProvider', webSearchProvider)
|
||||
const response = await WebSearchService.processWebsearch(webSearchProvider!, extractResults, requestId)
|
||||
searchResults.push(response)
|
||||
} catch (error) {
|
||||
console.error(`Web search failed for query "${finalQueries}":`, error)
|
||||
|
||||
@ -3,7 +3,7 @@ import { loggerService } from '@logger'
|
||||
import { SpanEntity, TokenUsage } from '@mcp-trace/trace-core'
|
||||
import { cleanContext, endContext, getContext, startContext } from '@mcp-trace/trace-web'
|
||||
import { Context, context, Span, SpanStatusCode, trace } from '@opentelemetry/api'
|
||||
import { isAsyncIterable } from '@renderer/aiCore/middleware/utils'
|
||||
import { isAsyncIterable } from '@renderer/aiCore/legacy/middleware/utils'
|
||||
import { db } from '@renderer/databases'
|
||||
import { getEnableDeveloperMode } from '@renderer/hooks/useSettings'
|
||||
import { EVENT_NAMES, EventEmitter } from '@renderer/services/EventService'
|
||||
|
||||
@ -40,21 +40,7 @@ interface RequestState {
|
||||
/**
|
||||
* 提供网络搜索相关功能的服务类
|
||||
*/
|
||||
export default class WebSearchService {
|
||||
private static instance: WebSearchService
|
||||
private webSearchProviderId: WebSearchProvider['id']
|
||||
|
||||
private constructor(webSearchProviderId: WebSearchProvider['id']) {
|
||||
this.webSearchProviderId = webSearchProviderId
|
||||
}
|
||||
|
||||
public static getInstance(webSearchProviderId: WebSearchProvider['id']): WebSearchService {
|
||||
if (!WebSearchService.instance) {
|
||||
WebSearchService.instance = new WebSearchService(webSearchProviderId)
|
||||
}
|
||||
return WebSearchService.instance
|
||||
}
|
||||
|
||||
class WebSearchService {
|
||||
/**
|
||||
* 是否暂停
|
||||
*/
|
||||
@ -113,7 +99,7 @@ export default class WebSearchService {
|
||||
* @private
|
||||
* @returns 网络搜索状态
|
||||
*/
|
||||
private static getWebSearchState(): WebSearchState {
|
||||
private getWebSearchState(): WebSearchState {
|
||||
return store.getState().websearch
|
||||
}
|
||||
|
||||
@ -122,8 +108,8 @@ export default class WebSearchService {
|
||||
* @public
|
||||
* @returns 如果默认搜索提供商已启用则返回true,否则返回false
|
||||
*/
|
||||
public static isWebSearchEnabled(providerId?: WebSearchProvider['id']): boolean {
|
||||
const { providers } = WebSearchService.getWebSearchState()
|
||||
public isWebSearchEnabled(providerId?: WebSearchProvider['id']): boolean {
|
||||
const { providers } = this.getWebSearchState()
|
||||
const provider = providers.find((provider) => provider.id === providerId)
|
||||
|
||||
if (!provider) {
|
||||
@ -153,7 +139,7 @@ export default class WebSearchService {
|
||||
* @returns 如果启用覆盖搜索则返回true,否则返回false
|
||||
*/
|
||||
public isOverwriteEnabled(): boolean {
|
||||
const { overwrite } = WebSearchService.getWebSearchState()
|
||||
const { overwrite } = this.getWebSearchState()
|
||||
return overwrite
|
||||
}
|
||||
|
||||
@ -163,7 +149,8 @@ export default class WebSearchService {
|
||||
* @returns 网络搜索提供商
|
||||
*/
|
||||
public getWebSearchProvider(providerId?: string): WebSearchProvider | undefined {
|
||||
const { providers } = WebSearchService.getWebSearchState()
|
||||
const { providers } = this.getWebSearchState()
|
||||
console.log('providers', providers)
|
||||
const provider = providers.find((provider) => provider.id === providerId)
|
||||
|
||||
return provider
|
||||
@ -172,16 +159,18 @@ export default class WebSearchService {
|
||||
/**
|
||||
* 使用指定的提供商执行网络搜索
|
||||
* @public
|
||||
* @param provider 搜索提供商
|
||||
* @param query 搜索查询
|
||||
* @returns 搜索响应
|
||||
*/
|
||||
public async search(query: string, httpOptions?: RequestInit, spanId?: string): Promise<WebSearchProviderResponse> {
|
||||
const websearch = WebSearchService.getWebSearchState()
|
||||
const webSearchProvider = this.getWebSearchProvider(this.webSearchProviderId)
|
||||
if (!webSearchProvider) {
|
||||
throw new Error(`WebSearchProvider ${this.webSearchProviderId} not found`)
|
||||
}
|
||||
const webSearchEngine = new WebSearchEngineProvider(webSearchProvider, spanId)
|
||||
public async search(
|
||||
provider: WebSearchProvider,
|
||||
query: string,
|
||||
httpOptions?: RequestInit,
|
||||
spanId?: string
|
||||
): Promise<WebSearchProviderResponse> {
|
||||
const websearch = this.getWebSearchState()
|
||||
const webSearchEngine = new WebSearchEngineProvider(provider, spanId)
|
||||
|
||||
let formattedQuery = query
|
||||
// FIXME: 有待商榷,效果一般
|
||||
@ -203,9 +192,9 @@ export default class WebSearchService {
|
||||
* @param provider 要检查的搜索提供商
|
||||
* @returns 如果提供商可用返回true,否则返回false
|
||||
*/
|
||||
public async checkSearch(): Promise<{ valid: boolean; error?: any }> {
|
||||
public async checkSearch(provider: WebSearchProvider): Promise<{ valid: boolean; error?: any }> {
|
||||
try {
|
||||
const response = await this.search('test query')
|
||||
const response = await this.search(provider, 'test query')
|
||||
logger.debug('Search response:', response)
|
||||
// 优化的判断条件:检查结果是否有效且没有错误
|
||||
return { valid: response.results !== undefined, error: undefined }
|
||||
@ -437,7 +426,11 @@ export default class WebSearchService {
|
||||
*
|
||||
* @returns 包含搜索结果的响应对象
|
||||
*/
|
||||
public async processWebsearch(extractResults: ExtractResults, requestId: string): Promise<WebSearchProviderResponse> {
|
||||
public async processWebsearch(
|
||||
webSearchProvider: WebSearchProvider,
|
||||
extractResults: ExtractResults,
|
||||
requestId: string
|
||||
): Promise<WebSearchProviderResponse> {
|
||||
// 重置状态
|
||||
await this.setWebSearchStatus(requestId, { phase: 'default' })
|
||||
|
||||
@ -479,7 +472,9 @@ export default class WebSearchService {
|
||||
return { query: 'summaries', results: contents }
|
||||
}
|
||||
|
||||
const searchPromises = questions.map((q) => this.search(q, { signal }, span?.spanContext().spanId))
|
||||
const searchPromises = questions.map((q) =>
|
||||
this.search(webSearchProvider, q, { signal }, span?.spanContext().spanId)
|
||||
)
|
||||
const searchResults = await Promise.allSettled(searchPromises)
|
||||
|
||||
// 统计成功完成的搜索数量
|
||||
@ -524,7 +519,7 @@ export default class WebSearchService {
|
||||
}
|
||||
}
|
||||
|
||||
const { compressionConfig } = WebSearchService.getWebSearchState()
|
||||
const { compressionConfig } = this.getWebSearchState()
|
||||
|
||||
// RAG压缩处理
|
||||
if (compressionConfig?.method === 'rag' && requestId) {
|
||||
@ -578,3 +573,5 @@ export default class WebSearchService {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export default new WebSearchService()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user