mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-09 23:10:20 +08:00
feat(tests): add unit tests for utility functions in utils.test.ts
- Implemented tests for `createErrorChunk`, `capitalize`, and `isAsyncIterable` functions. - Ensured comprehensive coverage for various input scenarios, including error handling and edge cases.
This commit is contained in:
parent
bf02afa841
commit
ff7ad52ad5
@ -332,9 +332,18 @@ export const searchOrchestrationPlugin = (assistant: Assistant) => {
|
|||||||
|
|
||||||
if (hasKnowledgeBase) {
|
if (hasKnowledgeBase) {
|
||||||
if (knowledgeRecognition === 'off') {
|
if (knowledgeRecognition === 'off') {
|
||||||
// off 模式:直接添加知识库搜索工具,跳过意图识别
|
// off 模式:直接添加知识库搜索工具,使用用户消息作为搜索关键词
|
||||||
|
const userMessage = userMessages[context.requestId]
|
||||||
|
const fallbackKeywords = {
|
||||||
|
question: [getMessageContent(userMessage) || 'search'],
|
||||||
|
rewrite: getMessageContent(userMessage) || 'search'
|
||||||
|
}
|
||||||
console.log('📚 [SearchOrchestration] Adding knowledge search tool (force mode)')
|
console.log('📚 [SearchOrchestration] Adding knowledge search tool (force mode)')
|
||||||
params.tools['builtin_knowledge_search'] = knowledgeSearchTool(assistant)
|
params.tools['builtin_knowledge_search'] = knowledgeSearchTool(
|
||||||
|
assistant,
|
||||||
|
fallbackKeywords,
|
||||||
|
getMessageContent(userMessage)
|
||||||
|
)
|
||||||
params.toolChoice = { type: 'tool', toolName: 'builtin_knowledge_search' }
|
params.toolChoice = { type: 'tool', toolName: 'builtin_knowledge_search' }
|
||||||
} else {
|
} else {
|
||||||
// on 模式:根据意图识别结果决定是否添加工具
|
// on 模式:根据意图识别结果决定是否添加工具
|
||||||
@ -343,9 +352,14 @@ export const searchOrchestrationPlugin = (assistant: Assistant) => {
|
|||||||
analysisResult.knowledge.question &&
|
analysisResult.knowledge.question &&
|
||||||
analysisResult.knowledge.question[0] !== 'not_needed'
|
analysisResult.knowledge.question[0] !== 'not_needed'
|
||||||
|
|
||||||
if (needsKnowledgeSearch) {
|
if (needsKnowledgeSearch && analysisResult.knowledge) {
|
||||||
console.log('📚 [SearchOrchestration] Adding knowledge search tool (intent-based)')
|
console.log('📚 [SearchOrchestration] Adding knowledge search tool (intent-based)')
|
||||||
params.tools['builtin_knowledge_search'] = knowledgeSearchTool(assistant)
|
const userMessage = userMessages[context.requestId]
|
||||||
|
params.tools['builtin_knowledge_search'] = knowledgeSearchTool(
|
||||||
|
assistant,
|
||||||
|
analysisResult.knowledge,
|
||||||
|
getMessageContent(userMessage)
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -1,30 +1,37 @@
|
|||||||
import { processKnowledgeSearch } from '@renderer/services/KnowledgeService'
|
import { processKnowledgeSearch } from '@renderer/services/KnowledgeService'
|
||||||
import type { Assistant, KnowledgeReference } from '@renderer/types'
|
import type { Assistant, KnowledgeReference } from '@renderer/types'
|
||||||
import { ExtractResults } from '@renderer/utils/extract'
|
import { ExtractResults, KnowledgeExtractResults } from '@renderer/utils/extract'
|
||||||
import { type InferToolInput, type InferToolOutput, tool } from 'ai'
|
import { type InferToolInput, type InferToolOutput, tool } from 'ai'
|
||||||
import { isEmpty } from 'lodash'
|
import { isEmpty } from 'lodash'
|
||||||
import { z } from 'zod'
|
import { z } from 'zod'
|
||||||
|
|
||||||
// Schema definitions - 添加 userMessage 字段来获取用户消息
|
|
||||||
const KnowledgeSearchInputSchema = z.object({
|
|
||||||
query: z.string().describe('The search query for knowledge base'),
|
|
||||||
rewrite: z.string().optional().describe('Optional rewritten query with alternative phrasing'),
|
|
||||||
userMessage: z.string().describe('The original user message content for direct search mode')
|
|
||||||
})
|
|
||||||
|
|
||||||
export type KnowledgeSearchToolInput = InferToolInput<ReturnType<typeof knowledgeSearchTool>>
|
|
||||||
export type KnowledgeSearchToolOutput = InferToolOutput<ReturnType<typeof knowledgeSearchTool>>
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 知识库搜索工具
|
* 知识库搜索工具
|
||||||
* 基于 ApiService.ts 中的 searchKnowledgeBase 逻辑实现
|
* 使用预提取关键词,直接使用插件阶段分析的搜索意图,避免重复分析
|
||||||
*/
|
*/
|
||||||
export const knowledgeSearchTool = (assistant: Assistant) => {
|
export const knowledgeSearchTool = (
|
||||||
|
assistant: Assistant,
|
||||||
|
extractedKeywords: KnowledgeExtractResults,
|
||||||
|
userMessage?: string
|
||||||
|
) => {
|
||||||
return tool({
|
return tool({
|
||||||
name: 'builtin_knowledge_search',
|
name: 'builtin_knowledge_search',
|
||||||
description: 'Search the knowledge base for relevant information',
|
description: `Search the knowledge base for relevant information using pre-analyzed search intent.
|
||||||
inputSchema: KnowledgeSearchInputSchema,
|
|
||||||
execute: async ({ query, rewrite, userMessage }) => {
|
Pre-extracted search queries: "${extractedKeywords.question.join(', ')}"
|
||||||
|
Rewritten query: "${extractedKeywords.rewrite}"
|
||||||
|
|
||||||
|
This tool searches your knowledge base for relevant documents and returns results for easy reference.
|
||||||
|
Call this tool to execute the search. You can optionally provide additional context to refine the search.`,
|
||||||
|
|
||||||
|
inputSchema: z.object({
|
||||||
|
additionalContext: z
|
||||||
|
.string()
|
||||||
|
.optional()
|
||||||
|
.describe('Optional additional context or specific focus to enhance the knowledge search')
|
||||||
|
}),
|
||||||
|
|
||||||
|
execute: async ({ additionalContext }) => {
|
||||||
try {
|
try {
|
||||||
// 获取助手的知识库配置
|
// 获取助手的知识库配置
|
||||||
const knowledgeBaseIds = assistant.knowledge_bases?.map((base) => base.id)
|
const knowledgeBaseIds = assistant.knowledge_bases?.map((base) => base.id)
|
||||||
@ -36,35 +43,51 @@ export const knowledgeSearchTool = (assistant: Assistant) => {
|
|||||||
return []
|
return []
|
||||||
}
|
}
|
||||||
|
|
||||||
// 构建搜索条件 - 复制原逻辑
|
let finalQueries = [...extractedKeywords.question]
|
||||||
|
let finalRewrite = extractedKeywords.rewrite
|
||||||
|
|
||||||
|
if (additionalContext?.trim()) {
|
||||||
|
// 如果大模型提供了额外上下文,使用更具体的描述
|
||||||
|
console.log(`🔍 AI enhanced knowledge search with: ${additionalContext}`)
|
||||||
|
const cleanContext = additionalContext.trim()
|
||||||
|
if (cleanContext) {
|
||||||
|
finalQueries = [cleanContext]
|
||||||
|
finalRewrite = cleanContext
|
||||||
|
console.log(`➕ Added additional context: ${cleanContext}`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查是否需要搜索
|
||||||
|
if (finalQueries[0] === 'not_needed') {
|
||||||
|
return []
|
||||||
|
}
|
||||||
|
|
||||||
|
// 构建搜索条件
|
||||||
let searchCriteria: { question: string[]; rewrite: string }
|
let searchCriteria: { question: string[]; rewrite: string }
|
||||||
|
|
||||||
if (knowledgeRecognition === 'off') {
|
if (knowledgeRecognition === 'off') {
|
||||||
// 直接模式:使用用户消息内容 (类似原逻辑的 getMainTextContent(lastUserMessage))
|
// 直接模式:使用用户消息内容
|
||||||
const directContent = userMessage || query || 'search'
|
const directContent = userMessage || finalQueries[0] || 'search'
|
||||||
searchCriteria = {
|
searchCriteria = {
|
||||||
question: [directContent],
|
question: [directContent],
|
||||||
rewrite: directContent
|
rewrite: directContent
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// 自动模式:使用意图识别的结果 (类似原逻辑的 extractResults.knowledge)
|
// 自动模式:使用意图识别的结果
|
||||||
searchCriteria = {
|
searchCriteria = {
|
||||||
question: [query],
|
question: finalQueries,
|
||||||
rewrite: rewrite || query
|
rewrite: finalRewrite
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 检查是否需要搜索
|
// 构建 ExtractResults 对象
|
||||||
if (searchCriteria.question[0] === 'not_needed') {
|
|
||||||
return []
|
|
||||||
}
|
|
||||||
|
|
||||||
// 构建 ExtractResults 对象 - 与原逻辑一致
|
|
||||||
const extractResults: ExtractResults = {
|
const extractResults: ExtractResults = {
|
||||||
websearch: undefined,
|
websearch: undefined,
|
||||||
knowledge: searchCriteria
|
knowledge: searchCriteria
|
||||||
}
|
}
|
||||||
|
|
||||||
|
console.log('Knowledge search extractResults:', extractResults)
|
||||||
|
|
||||||
// 执行知识库搜索
|
// 执行知识库搜索
|
||||||
const knowledgeReferences = await processKnowledgeSearch(extractResults, knowledgeBaseIds)
|
const knowledgeReferences = await processKnowledgeSearch(extractResults, knowledgeBaseIds)
|
||||||
|
|
||||||
@ -86,4 +109,7 @@ export const knowledgeSearchTool = (assistant: Assistant) => {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export type KnowledgeSearchToolInput = InferToolInput<ReturnType<typeof knowledgeSearchTool>>
|
||||||
|
export type KnowledgeSearchToolOutput = InferToolOutput<ReturnType<typeof knowledgeSearchTool>>
|
||||||
|
|
||||||
export default knowledgeSearchTool
|
export default knowledgeSearchTool
|
||||||
|
|||||||
@ -5,42 +5,6 @@ import { ExtractResults } from '@renderer/utils/extract'
|
|||||||
import { InferToolInput, InferToolOutput, tool } from 'ai'
|
import { InferToolInput, InferToolOutput, tool } from 'ai'
|
||||||
import { z } from 'zod'
|
import { z } from 'zod'
|
||||||
|
|
||||||
// import { AiSdkTool, ToolCallResult } from './types'
|
|
||||||
|
|
||||||
// const WebSearchResult = z.array(
|
|
||||||
// z.object({
|
|
||||||
// query: z.string().optional(),
|
|
||||||
// results: z.array(
|
|
||||||
// z.object({
|
|
||||||
// title: z.string(),
|
|
||||||
// content: z.string(),
|
|
||||||
// url: z.string()
|
|
||||||
// })
|
|
||||||
// )
|
|
||||||
// })
|
|
||||||
// )
|
|
||||||
// const webSearchToolInputSchema = z.object({
|
|
||||||
// query: z.string().describe('The query to search for')
|
|
||||||
// })
|
|
||||||
|
|
||||||
// export const webSearchTool = (webSearchProviderId: WebSearchProvider['id']) => {
|
|
||||||
// const webSearchService = WebSearchService.getInstance(webSearchProviderId)
|
|
||||||
// return tool({
|
|
||||||
// name: 'builtin_web_search',
|
|
||||||
// description: 'Search the web for information',
|
|
||||||
// inputSchema: webSearchToolInputSchema,
|
|
||||||
// outputSchema: WebSearchProviderResult,
|
|
||||||
// execute: async ({ query }) => {
|
|
||||||
// console.log('webSearchTool', query)
|
|
||||||
// const response = await webSearchService.search(query)
|
|
||||||
// console.log('webSearchTool response', response)
|
|
||||||
// return response
|
|
||||||
// }
|
|
||||||
// })
|
|
||||||
// }
|
|
||||||
// export type WebSearchToolInput = InferToolInput<ReturnType<typeof webSearchTool>>
|
|
||||||
// export type WebSearchToolOutput = InferToolOutput<ReturnType<typeof webSearchTool>>
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 使用预提取关键词的网络搜索工具
|
* 使用预提取关键词的网络搜索工具
|
||||||
* 这个工具直接使用插件阶段分析的搜索意图,避免重复分析
|
* 这个工具直接使用插件阶段分析的搜索意图,避免重复分析
|
||||||
@ -53,7 +17,7 @@ export const webSearchToolWithPreExtractedKeywords = (
|
|||||||
},
|
},
|
||||||
requestId: string
|
requestId: string
|
||||||
) => {
|
) => {
|
||||||
const webSearchService = WebSearchService.getInstance(webSearchProviderId)
|
const webSearchProvider = WebSearchService.getWebSearchProvider(webSearchProviderId)
|
||||||
|
|
||||||
return tool({
|
return tool({
|
||||||
name: 'builtin_web_search',
|
name: 'builtin_web_search',
|
||||||
@ -112,7 +76,8 @@ Call this tool to execute the search. You can optionally provide additional cont
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
console.log('extractResults', extractResults)
|
console.log('extractResults', extractResults)
|
||||||
const response = await webSearchService.processWebsearch(extractResults, requestId)
|
console.log('webSearchProvider', webSearchProvider)
|
||||||
|
const response = await WebSearchService.processWebsearch(webSearchProvider!, extractResults, requestId)
|
||||||
searchResults.push(response)
|
searchResults.push(response)
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error(`Web search failed for query "${finalQueries}":`, error)
|
console.error(`Web search failed for query "${finalQueries}":`, error)
|
||||||
|
|||||||
@ -3,7 +3,7 @@ import { loggerService } from '@logger'
|
|||||||
import { SpanEntity, TokenUsage } from '@mcp-trace/trace-core'
|
import { SpanEntity, TokenUsage } from '@mcp-trace/trace-core'
|
||||||
import { cleanContext, endContext, getContext, startContext } from '@mcp-trace/trace-web'
|
import { cleanContext, endContext, getContext, startContext } from '@mcp-trace/trace-web'
|
||||||
import { Context, context, Span, SpanStatusCode, trace } from '@opentelemetry/api'
|
import { Context, context, Span, SpanStatusCode, trace } from '@opentelemetry/api'
|
||||||
import { isAsyncIterable } from '@renderer/aiCore/middleware/utils'
|
import { isAsyncIterable } from '@renderer/aiCore/legacy/middleware/utils'
|
||||||
import { db } from '@renderer/databases'
|
import { db } from '@renderer/databases'
|
||||||
import { getEnableDeveloperMode } from '@renderer/hooks/useSettings'
|
import { getEnableDeveloperMode } from '@renderer/hooks/useSettings'
|
||||||
import { EVENT_NAMES, EventEmitter } from '@renderer/services/EventService'
|
import { EVENT_NAMES, EventEmitter } from '@renderer/services/EventService'
|
||||||
|
|||||||
@ -40,21 +40,7 @@ interface RequestState {
|
|||||||
/**
|
/**
|
||||||
* 提供网络搜索相关功能的服务类
|
* 提供网络搜索相关功能的服务类
|
||||||
*/
|
*/
|
||||||
export default class WebSearchService {
|
class WebSearchService {
|
||||||
private static instance: WebSearchService
|
|
||||||
private webSearchProviderId: WebSearchProvider['id']
|
|
||||||
|
|
||||||
private constructor(webSearchProviderId: WebSearchProvider['id']) {
|
|
||||||
this.webSearchProviderId = webSearchProviderId
|
|
||||||
}
|
|
||||||
|
|
||||||
public static getInstance(webSearchProviderId: WebSearchProvider['id']): WebSearchService {
|
|
||||||
if (!WebSearchService.instance) {
|
|
||||||
WebSearchService.instance = new WebSearchService(webSearchProviderId)
|
|
||||||
}
|
|
||||||
return WebSearchService.instance
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 是否暂停
|
* 是否暂停
|
||||||
*/
|
*/
|
||||||
@ -113,7 +99,7 @@ export default class WebSearchService {
|
|||||||
* @private
|
* @private
|
||||||
* @returns 网络搜索状态
|
* @returns 网络搜索状态
|
||||||
*/
|
*/
|
||||||
private static getWebSearchState(): WebSearchState {
|
private getWebSearchState(): WebSearchState {
|
||||||
return store.getState().websearch
|
return store.getState().websearch
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -122,8 +108,8 @@ export default class WebSearchService {
|
|||||||
* @public
|
* @public
|
||||||
* @returns 如果默认搜索提供商已启用则返回true,否则返回false
|
* @returns 如果默认搜索提供商已启用则返回true,否则返回false
|
||||||
*/
|
*/
|
||||||
public static isWebSearchEnabled(providerId?: WebSearchProvider['id']): boolean {
|
public isWebSearchEnabled(providerId?: WebSearchProvider['id']): boolean {
|
||||||
const { providers } = WebSearchService.getWebSearchState()
|
const { providers } = this.getWebSearchState()
|
||||||
const provider = providers.find((provider) => provider.id === providerId)
|
const provider = providers.find((provider) => provider.id === providerId)
|
||||||
|
|
||||||
if (!provider) {
|
if (!provider) {
|
||||||
@ -153,7 +139,7 @@ export default class WebSearchService {
|
|||||||
* @returns 如果启用覆盖搜索则返回true,否则返回false
|
* @returns 如果启用覆盖搜索则返回true,否则返回false
|
||||||
*/
|
*/
|
||||||
public isOverwriteEnabled(): boolean {
|
public isOverwriteEnabled(): boolean {
|
||||||
const { overwrite } = WebSearchService.getWebSearchState()
|
const { overwrite } = this.getWebSearchState()
|
||||||
return overwrite
|
return overwrite
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -163,7 +149,8 @@ export default class WebSearchService {
|
|||||||
* @returns 网络搜索提供商
|
* @returns 网络搜索提供商
|
||||||
*/
|
*/
|
||||||
public getWebSearchProvider(providerId?: string): WebSearchProvider | undefined {
|
public getWebSearchProvider(providerId?: string): WebSearchProvider | undefined {
|
||||||
const { providers } = WebSearchService.getWebSearchState()
|
const { providers } = this.getWebSearchState()
|
||||||
|
console.log('providers', providers)
|
||||||
const provider = providers.find((provider) => provider.id === providerId)
|
const provider = providers.find((provider) => provider.id === providerId)
|
||||||
|
|
||||||
return provider
|
return provider
|
||||||
@ -172,16 +159,18 @@ export default class WebSearchService {
|
|||||||
/**
|
/**
|
||||||
* 使用指定的提供商执行网络搜索
|
* 使用指定的提供商执行网络搜索
|
||||||
* @public
|
* @public
|
||||||
|
* @param provider 搜索提供商
|
||||||
* @param query 搜索查询
|
* @param query 搜索查询
|
||||||
* @returns 搜索响应
|
* @returns 搜索响应
|
||||||
*/
|
*/
|
||||||
public async search(query: string, httpOptions?: RequestInit, spanId?: string): Promise<WebSearchProviderResponse> {
|
public async search(
|
||||||
const websearch = WebSearchService.getWebSearchState()
|
provider: WebSearchProvider,
|
||||||
const webSearchProvider = this.getWebSearchProvider(this.webSearchProviderId)
|
query: string,
|
||||||
if (!webSearchProvider) {
|
httpOptions?: RequestInit,
|
||||||
throw new Error(`WebSearchProvider ${this.webSearchProviderId} not found`)
|
spanId?: string
|
||||||
}
|
): Promise<WebSearchProviderResponse> {
|
||||||
const webSearchEngine = new WebSearchEngineProvider(webSearchProvider, spanId)
|
const websearch = this.getWebSearchState()
|
||||||
|
const webSearchEngine = new WebSearchEngineProvider(provider, spanId)
|
||||||
|
|
||||||
let formattedQuery = query
|
let formattedQuery = query
|
||||||
// FIXME: 有待商榷,效果一般
|
// FIXME: 有待商榷,效果一般
|
||||||
@ -203,9 +192,9 @@ export default class WebSearchService {
|
|||||||
* @param provider 要检查的搜索提供商
|
* @param provider 要检查的搜索提供商
|
||||||
* @returns 如果提供商可用返回true,否则返回false
|
* @returns 如果提供商可用返回true,否则返回false
|
||||||
*/
|
*/
|
||||||
public async checkSearch(): Promise<{ valid: boolean; error?: any }> {
|
public async checkSearch(provider: WebSearchProvider): Promise<{ valid: boolean; error?: any }> {
|
||||||
try {
|
try {
|
||||||
const response = await this.search('test query')
|
const response = await this.search(provider, 'test query')
|
||||||
logger.debug('Search response:', response)
|
logger.debug('Search response:', response)
|
||||||
// 优化的判断条件:检查结果是否有效且没有错误
|
// 优化的判断条件:检查结果是否有效且没有错误
|
||||||
return { valid: response.results !== undefined, error: undefined }
|
return { valid: response.results !== undefined, error: undefined }
|
||||||
@ -437,7 +426,11 @@ export default class WebSearchService {
|
|||||||
*
|
*
|
||||||
* @returns 包含搜索结果的响应对象
|
* @returns 包含搜索结果的响应对象
|
||||||
*/
|
*/
|
||||||
public async processWebsearch(extractResults: ExtractResults, requestId: string): Promise<WebSearchProviderResponse> {
|
public async processWebsearch(
|
||||||
|
webSearchProvider: WebSearchProvider,
|
||||||
|
extractResults: ExtractResults,
|
||||||
|
requestId: string
|
||||||
|
): Promise<WebSearchProviderResponse> {
|
||||||
// 重置状态
|
// 重置状态
|
||||||
await this.setWebSearchStatus(requestId, { phase: 'default' })
|
await this.setWebSearchStatus(requestId, { phase: 'default' })
|
||||||
|
|
||||||
@ -479,7 +472,9 @@ export default class WebSearchService {
|
|||||||
return { query: 'summaries', results: contents }
|
return { query: 'summaries', results: contents }
|
||||||
}
|
}
|
||||||
|
|
||||||
const searchPromises = questions.map((q) => this.search(q, { signal }, span?.spanContext().spanId))
|
const searchPromises = questions.map((q) =>
|
||||||
|
this.search(webSearchProvider, q, { signal }, span?.spanContext().spanId)
|
||||||
|
)
|
||||||
const searchResults = await Promise.allSettled(searchPromises)
|
const searchResults = await Promise.allSettled(searchPromises)
|
||||||
|
|
||||||
// 统计成功完成的搜索数量
|
// 统计成功完成的搜索数量
|
||||||
@ -524,7 +519,7 @@ export default class WebSearchService {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const { compressionConfig } = WebSearchService.getWebSearchState()
|
const { compressionConfig } = this.getWebSearchState()
|
||||||
|
|
||||||
// RAG压缩处理
|
// RAG压缩处理
|
||||||
if (compressionConfig?.method === 'rag' && requestId) {
|
if (compressionConfig?.method === 'rag' && requestId) {
|
||||||
@ -578,3 +573,5 @@ export default class WebSearchService {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export default new WebSearchService()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user