feat(tests): add unit tests for utility functions in utils.test.ts

- Implemented tests for `createErrorChunk`, `capitalize`, and `isAsyncIterable` functions.
- Ensured comprehensive coverage for various input scenarios, including error handling and edge cases.
This commit is contained in:
MyPrototypeWhat 2025-08-08 15:20:02 +08:00
parent bf02afa841
commit ff7ad52ad5
6 changed files with 105 additions and 103 deletions

View File

@ -332,9 +332,18 @@ export const searchOrchestrationPlugin = (assistant: Assistant) => {
if (hasKnowledgeBase) {
if (knowledgeRecognition === 'off') {
// off 模式:直接添加知识库搜索工具,跳过意图识别
// off 模式:直接添加知识库搜索工具,使用用户消息作为搜索关键词
const userMessage = userMessages[context.requestId]
const fallbackKeywords = {
question: [getMessageContent(userMessage) || 'search'],
rewrite: getMessageContent(userMessage) || 'search'
}
console.log('📚 [SearchOrchestration] Adding knowledge search tool (force mode)')
params.tools['builtin_knowledge_search'] = knowledgeSearchTool(assistant)
params.tools['builtin_knowledge_search'] = knowledgeSearchTool(
assistant,
fallbackKeywords,
getMessageContent(userMessage)
)
params.toolChoice = { type: 'tool', toolName: 'builtin_knowledge_search' }
} else {
// on 模式:根据意图识别结果决定是否添加工具
@ -343,9 +352,14 @@ export const searchOrchestrationPlugin = (assistant: Assistant) => {
analysisResult.knowledge.question &&
analysisResult.knowledge.question[0] !== 'not_needed'
if (needsKnowledgeSearch) {
if (needsKnowledgeSearch && analysisResult.knowledge) {
console.log('📚 [SearchOrchestration] Adding knowledge search tool (intent-based)')
params.tools['builtin_knowledge_search'] = knowledgeSearchTool(assistant)
const userMessage = userMessages[context.requestId]
params.tools['builtin_knowledge_search'] = knowledgeSearchTool(
assistant,
analysisResult.knowledge,
getMessageContent(userMessage)
)
}
}
}

View File

@ -1,30 +1,37 @@
import { processKnowledgeSearch } from '@renderer/services/KnowledgeService'
import type { Assistant, KnowledgeReference } from '@renderer/types'
import { ExtractResults } from '@renderer/utils/extract'
import { ExtractResults, KnowledgeExtractResults } from '@renderer/utils/extract'
import { type InferToolInput, type InferToolOutput, tool } from 'ai'
import { isEmpty } from 'lodash'
import { z } from 'zod'
// Schema definitions - 添加 userMessage 字段来获取用户消息
const KnowledgeSearchInputSchema = z.object({
query: z.string().describe('The search query for knowledge base'),
rewrite: z.string().optional().describe('Optional rewritten query with alternative phrasing'),
userMessage: z.string().describe('The original user message content for direct search mode')
})
export type KnowledgeSearchToolInput = InferToolInput<ReturnType<typeof knowledgeSearchTool>>
export type KnowledgeSearchToolOutput = InferToolOutput<ReturnType<typeof knowledgeSearchTool>>
/**
*
* ApiService.ts searchKnowledgeBase
* 使使
*/
export const knowledgeSearchTool = (assistant: Assistant) => {
export const knowledgeSearchTool = (
assistant: Assistant,
extractedKeywords: KnowledgeExtractResults,
userMessage?: string
) => {
return tool({
name: 'builtin_knowledge_search',
description: 'Search the knowledge base for relevant information',
inputSchema: KnowledgeSearchInputSchema,
execute: async ({ query, rewrite, userMessage }) => {
description: `Search the knowledge base for relevant information using pre-analyzed search intent.
Pre-extracted search queries: "${extractedKeywords.question.join(', ')}"
Rewritten query: "${extractedKeywords.rewrite}"
This tool searches your knowledge base for relevant documents and returns results for easy reference.
Call this tool to execute the search. You can optionally provide additional context to refine the search.`,
inputSchema: z.object({
additionalContext: z
.string()
.optional()
.describe('Optional additional context or specific focus to enhance the knowledge search')
}),
execute: async ({ additionalContext }) => {
try {
// 获取助手的知识库配置
const knowledgeBaseIds = assistant.knowledge_bases?.map((base) => base.id)
@ -36,35 +43,51 @@ export const knowledgeSearchTool = (assistant: Assistant) => {
return []
}
// 构建搜索条件 - 复制原逻辑
let finalQueries = [...extractedKeywords.question]
let finalRewrite = extractedKeywords.rewrite
if (additionalContext?.trim()) {
// 如果大模型提供了额外上下文,使用更具体的描述
console.log(`🔍 AI enhanced knowledge search with: ${additionalContext}`)
const cleanContext = additionalContext.trim()
if (cleanContext) {
finalQueries = [cleanContext]
finalRewrite = cleanContext
console.log(` Added additional context: ${cleanContext}`)
}
}
// 检查是否需要搜索
if (finalQueries[0] === 'not_needed') {
return []
}
// 构建搜索条件
let searchCriteria: { question: string[]; rewrite: string }
if (knowledgeRecognition === 'off') {
// 直接模式:使用用户消息内容 (类似原逻辑的 getMainTextContent(lastUserMessage))
const directContent = userMessage || query || 'search'
// 直接模式:使用用户消息内容
const directContent = userMessage || finalQueries[0] || 'search'
searchCriteria = {
question: [directContent],
rewrite: directContent
}
} else {
// 自动模式:使用意图识别的结果 (类似原逻辑的 extractResults.knowledge)
// 自动模式:使用意图识别的结果
searchCriteria = {
question: [query],
rewrite: rewrite || query
question: finalQueries,
rewrite: finalRewrite
}
}
// 检查是否需要搜索
if (searchCriteria.question[0] === 'not_needed') {
return []
}
// 构建 ExtractResults 对象 - 与原逻辑一致
// 构建 ExtractResults 对象
const extractResults: ExtractResults = {
websearch: undefined,
knowledge: searchCriteria
}
console.log('Knowledge search extractResults:', extractResults)
// 执行知识库搜索
const knowledgeReferences = await processKnowledgeSearch(extractResults, knowledgeBaseIds)
@ -86,4 +109,7 @@ export const knowledgeSearchTool = (assistant: Assistant) => {
})
}
export type KnowledgeSearchToolInput = InferToolInput<ReturnType<typeof knowledgeSearchTool>>
export type KnowledgeSearchToolOutput = InferToolOutput<ReturnType<typeof knowledgeSearchTool>>
export default knowledgeSearchTool

View File

@ -5,42 +5,6 @@ import { ExtractResults } from '@renderer/utils/extract'
import { InferToolInput, InferToolOutput, tool } from 'ai'
import { z } from 'zod'
// import { AiSdkTool, ToolCallResult } from './types'
// const WebSearchResult = z.array(
// z.object({
// query: z.string().optional(),
// results: z.array(
// z.object({
// title: z.string(),
// content: z.string(),
// url: z.string()
// })
// )
// })
// )
// const webSearchToolInputSchema = z.object({
// query: z.string().describe('The query to search for')
// })
// export const webSearchTool = (webSearchProviderId: WebSearchProvider['id']) => {
// const webSearchService = WebSearchService.getInstance(webSearchProviderId)
// return tool({
// name: 'builtin_web_search',
// description: 'Search the web for information',
// inputSchema: webSearchToolInputSchema,
// outputSchema: WebSearchProviderResult,
// execute: async ({ query }) => {
// console.log('webSearchTool', query)
// const response = await webSearchService.search(query)
// console.log('webSearchTool response', response)
// return response
// }
// })
// }
// export type WebSearchToolInput = InferToolInput<ReturnType<typeof webSearchTool>>
// export type WebSearchToolOutput = InferToolOutput<ReturnType<typeof webSearchTool>>
/**
* 使
* 使
@ -53,7 +17,7 @@ export const webSearchToolWithPreExtractedKeywords = (
},
requestId: string
) => {
const webSearchService = WebSearchService.getInstance(webSearchProviderId)
const webSearchProvider = WebSearchService.getWebSearchProvider(webSearchProviderId)
return tool({
name: 'builtin_web_search',
@ -112,7 +76,8 @@ Call this tool to execute the search. You can optionally provide additional cont
}
}
console.log('extractResults', extractResults)
const response = await webSearchService.processWebsearch(extractResults, requestId)
console.log('webSearchProvider', webSearchProvider)
const response = await WebSearchService.processWebsearch(webSearchProvider!, extractResults, requestId)
searchResults.push(response)
} catch (error) {
console.error(`Web search failed for query "${finalQueries}":`, error)

View File

@ -3,7 +3,7 @@ import { loggerService } from '@logger'
import { SpanEntity, TokenUsage } from '@mcp-trace/trace-core'
import { cleanContext, endContext, getContext, startContext } from '@mcp-trace/trace-web'
import { Context, context, Span, SpanStatusCode, trace } from '@opentelemetry/api'
import { isAsyncIterable } from '@renderer/aiCore/middleware/utils'
import { isAsyncIterable } from '@renderer/aiCore/legacy/middleware/utils'
import { db } from '@renderer/databases'
import { getEnableDeveloperMode } from '@renderer/hooks/useSettings'
import { EVENT_NAMES, EventEmitter } from '@renderer/services/EventService'

View File

@ -40,21 +40,7 @@ interface RequestState {
/**
*
*/
export default class WebSearchService {
private static instance: WebSearchService
private webSearchProviderId: WebSearchProvider['id']
private constructor(webSearchProviderId: WebSearchProvider['id']) {
this.webSearchProviderId = webSearchProviderId
}
public static getInstance(webSearchProviderId: WebSearchProvider['id']): WebSearchService {
if (!WebSearchService.instance) {
WebSearchService.instance = new WebSearchService(webSearchProviderId)
}
return WebSearchService.instance
}
class WebSearchService {
/**
*
*/
@ -113,7 +99,7 @@ export default class WebSearchService {
* @private
* @returns
*/
private static getWebSearchState(): WebSearchState {
private getWebSearchState(): WebSearchState {
return store.getState().websearch
}
@ -122,8 +108,8 @@ export default class WebSearchService {
* @public
* @returns truefalse
*/
public static isWebSearchEnabled(providerId?: WebSearchProvider['id']): boolean {
const { providers } = WebSearchService.getWebSearchState()
public isWebSearchEnabled(providerId?: WebSearchProvider['id']): boolean {
const { providers } = this.getWebSearchState()
const provider = providers.find((provider) => provider.id === providerId)
if (!provider) {
@ -153,7 +139,7 @@ export default class WebSearchService {
* @returns truefalse
*/
public isOverwriteEnabled(): boolean {
const { overwrite } = WebSearchService.getWebSearchState()
const { overwrite } = this.getWebSearchState()
return overwrite
}
@ -163,7 +149,8 @@ export default class WebSearchService {
* @returns
*/
public getWebSearchProvider(providerId?: string): WebSearchProvider | undefined {
const { providers } = WebSearchService.getWebSearchState()
const { providers } = this.getWebSearchState()
console.log('providers', providers)
const provider = providers.find((provider) => provider.id === providerId)
return provider
@ -172,16 +159,18 @@ export default class WebSearchService {
/**
* 使
* @public
* @param provider
* @param query
* @returns
*/
public async search(query: string, httpOptions?: RequestInit, spanId?: string): Promise<WebSearchProviderResponse> {
const websearch = WebSearchService.getWebSearchState()
const webSearchProvider = this.getWebSearchProvider(this.webSearchProviderId)
if (!webSearchProvider) {
throw new Error(`WebSearchProvider ${this.webSearchProviderId} not found`)
}
const webSearchEngine = new WebSearchEngineProvider(webSearchProvider, spanId)
public async search(
provider: WebSearchProvider,
query: string,
httpOptions?: RequestInit,
spanId?: string
): Promise<WebSearchProviderResponse> {
const websearch = this.getWebSearchState()
const webSearchEngine = new WebSearchEngineProvider(provider, spanId)
let formattedQuery = query
// FIXME: 有待商榷,效果一般
@ -203,9 +192,9 @@ export default class WebSearchService {
* @param provider
* @returns truefalse
*/
public async checkSearch(): Promise<{ valid: boolean; error?: any }> {
public async checkSearch(provider: WebSearchProvider): Promise<{ valid: boolean; error?: any }> {
try {
const response = await this.search('test query')
const response = await this.search(provider, 'test query')
logger.debug('Search response:', response)
// 优化的判断条件:检查结果是否有效且没有错误
return { valid: response.results !== undefined, error: undefined }
@ -437,7 +426,11 @@ export default class WebSearchService {
*
* @returns
*/
public async processWebsearch(extractResults: ExtractResults, requestId: string): Promise<WebSearchProviderResponse> {
public async processWebsearch(
webSearchProvider: WebSearchProvider,
extractResults: ExtractResults,
requestId: string
): Promise<WebSearchProviderResponse> {
// 重置状态
await this.setWebSearchStatus(requestId, { phase: 'default' })
@ -479,7 +472,9 @@ export default class WebSearchService {
return { query: 'summaries', results: contents }
}
const searchPromises = questions.map((q) => this.search(q, { signal }, span?.spanContext().spanId))
const searchPromises = questions.map((q) =>
this.search(webSearchProvider, q, { signal }, span?.spanContext().spanId)
)
const searchResults = await Promise.allSettled(searchPromises)
// 统计成功完成的搜索数量
@ -524,7 +519,7 @@ export default class WebSearchService {
}
}
const { compressionConfig } = WebSearchService.getWebSearchState()
const { compressionConfig } = this.getWebSearchState()
// RAG压缩处理
if (compressionConfig?.method === 'rag' && requestId) {
@ -578,3 +573,5 @@ export default class WebSearchService {
}
}
}
export default new WebSearchService()