mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-24 10:40:07 +08:00
feat: enhance web search functionality with abort signal support
- Updated WebSearchTool to accept an abort signal in the execute method. - Modified various WebSearchProvider classes to include httpOptions for search methods, allowing for abort signal handling. - Improved WebSearchService to prioritize external abort signals for better request management. - Enhanced MessageTool to reflect tool status with appropriate UI feedback.
This commit is contained in:
parent
654f19eaa9
commit
ff378ca567
@ -40,7 +40,7 @@ You can use this tool as-is to search with the prepared queries, or provide addi
|
||||
.describe('Optional additional context, keywords, or specific focus to enhance the search')
|
||||
}),
|
||||
|
||||
execute: async ({ additionalContext }) => {
|
||||
execute: async ({ additionalContext }, { abortSignal }) => {
|
||||
let finalQueries = [...extractedKeywords.question]
|
||||
|
||||
if (additionalContext?.trim()) {
|
||||
@ -67,7 +67,15 @@ You can use this tool as-is to search with the prepared queries, or provide addi
|
||||
links: extractedKeywords.links
|
||||
}
|
||||
}
|
||||
searchResults = await WebSearchService.processWebsearch(webSearchProvider!, extractResults, requestId)
|
||||
// abortSignal?.addEventListener('abort', () => {
|
||||
// console.log('tool_call_abortSignal', abortSignal?.aborted)
|
||||
// })
|
||||
searchResults = await WebSearchService.processWebsearch(
|
||||
webSearchProvider!,
|
||||
extractResults,
|
||||
requestId,
|
||||
abortSignal
|
||||
)
|
||||
|
||||
return searchResults
|
||||
},
|
||||
|
||||
@ -1,5 +1,8 @@
|
||||
import { NormalToolResponse } from '@renderer/types'
|
||||
import type { ToolMessageBlock } from '@renderer/types/newMessage'
|
||||
import { MessageBlockStatus, ToolMessageBlock } from '@renderer/types/newMessage'
|
||||
import { TFunction } from 'i18next'
|
||||
import { Pause } from 'lucide-react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
|
||||
import { MessageAgentTools } from './MessageAgentTools'
|
||||
import { MessageKnowledgeSearchToolTitle } from './MessageKnowledgeSearch'
|
||||
@ -35,11 +38,23 @@ const isAgentTool = (toolName: string) => {
|
||||
return false
|
||||
}
|
||||
|
||||
const ChooseTool = (toolResponse: NormalToolResponse): React.ReactNode | null => {
|
||||
const ChooseTool = (
|
||||
toolResponse: NormalToolResponse,
|
||||
status: MessageBlockStatus,
|
||||
t: TFunction
|
||||
): React.ReactNode | null => {
|
||||
let toolName = toolResponse.tool.name
|
||||
const toolType = toolResponse.tool.type
|
||||
if (toolName.startsWith(prefix)) {
|
||||
toolName = toolName.slice(prefix.length)
|
||||
if (status === MessageBlockStatus.PAUSED) {
|
||||
return (
|
||||
<div className="flex items-center gap-1">
|
||||
<Pause className="h-4 w-4" />
|
||||
<span>{t('message.tools.aborted')}</span>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
switch (toolName) {
|
||||
case 'web_search':
|
||||
case 'web_search_preview':
|
||||
@ -58,12 +73,13 @@ const ChooseTool = (toolResponse: NormalToolResponse): React.ReactNode | null =>
|
||||
}
|
||||
|
||||
export default function MessageTool({ block }: Props) {
|
||||
const { t } = useTranslation()
|
||||
// FIXME: 语义错误,这里已经不是 MCP tool 了,更改rawMcpToolResponse需要改用户数据, 所以暂时保留
|
||||
const toolResponse = block.metadata?.rawMcpToolResponse as NormalToolResponse
|
||||
|
||||
if (!toolResponse) return null
|
||||
|
||||
const toolRenderer = ChooseTool(toolResponse as NormalToolResponse)
|
||||
const toolRenderer = ChooseTool(toolResponse as NormalToolResponse, block.status, t)
|
||||
|
||||
if (!toolRenderer) return null
|
||||
|
||||
|
||||
@ -18,7 +18,11 @@ export default class BochaProvider extends BaseWebSearchProvider {
|
||||
}
|
||||
}
|
||||
|
||||
public async search(query: string, websearch: WebSearchState): Promise<WebSearchProviderResponse> {
|
||||
public async search(
|
||||
query: string,
|
||||
websearch: WebSearchState,
|
||||
httpOptions?: RequestInit
|
||||
): Promise<WebSearchProviderResponse> {
|
||||
try {
|
||||
if (!query.trim()) {
|
||||
throw new Error('Search query cannot be empty')
|
||||
@ -44,7 +48,8 @@ export default class BochaProvider extends BaseWebSearchProvider {
|
||||
headers: {
|
||||
...this.defaultHeaders(),
|
||||
...headers
|
||||
}
|
||||
},
|
||||
signal: httpOptions?.signal
|
||||
})
|
||||
|
||||
if (!response.ok) {
|
||||
|
||||
@ -1,9 +1,10 @@
|
||||
import { WebSearchState } from '@renderer/store/websearch'
|
||||
import { WebSearchProviderResponse } from '@renderer/types'
|
||||
|
||||
import BaseWebSearchProvider from './BaseWebSearchProvider'
|
||||
|
||||
export default class DefaultProvider extends BaseWebSearchProvider {
|
||||
search(): Promise<WebSearchProviderResponse> {
|
||||
search(_query: string, _websearch: WebSearchState, _httpOptions?: RequestInit): Promise<WebSearchProviderResponse> {
|
||||
throw new Error('Method not implemented.')
|
||||
}
|
||||
}
|
||||
|
||||
@ -20,13 +20,18 @@ export default class ExaProvider extends BaseWebSearchProvider {
|
||||
this.exa = new ExaClient({ apiKey: this.apiKey, apiBaseUrl: this.apiHost })
|
||||
}
|
||||
|
||||
public async search(query: string, websearch: WebSearchState): Promise<WebSearchProviderResponse> {
|
||||
public async search(
|
||||
query: string,
|
||||
websearch: WebSearchState,
|
||||
httpOptions?: RequestInit
|
||||
): Promise<WebSearchProviderResponse> {
|
||||
try {
|
||||
if (!query.trim()) {
|
||||
throw new Error('Search query cannot be empty')
|
||||
}
|
||||
|
||||
const response = await this.exa.search({
|
||||
// 使用 Promise.race 来支持 abort signal
|
||||
const searchPromise = this.exa.search({
|
||||
query,
|
||||
numResults: Math.max(1, websearch.maxResults),
|
||||
contents: {
|
||||
@ -34,6 +39,20 @@ export default class ExaProvider extends BaseWebSearchProvider {
|
||||
}
|
||||
})
|
||||
|
||||
let response: Awaited<typeof searchPromise>
|
||||
if (httpOptions?.signal) {
|
||||
response = await Promise.race([
|
||||
searchPromise,
|
||||
new Promise<never>((_, reject) => {
|
||||
httpOptions.signal?.addEventListener('abort', () => {
|
||||
reject(new DOMException('The operation was aborted.', 'AbortError'))
|
||||
})
|
||||
})
|
||||
])
|
||||
} else {
|
||||
response = await searchPromise
|
||||
}
|
||||
|
||||
return {
|
||||
query: response.autopromptString,
|
||||
results: response.results.slice(0, websearch.maxResults).map((result) => {
|
||||
|
||||
@ -95,7 +95,11 @@ export default class SearxngProvider extends BaseWebSearchProvider {
|
||||
}
|
||||
}
|
||||
|
||||
public async search(query: string, websearch: WebSearchState): Promise<WebSearchProviderResponse> {
|
||||
public async search(
|
||||
query: string,
|
||||
websearch: WebSearchState,
|
||||
httpOptions?: RequestInit
|
||||
): Promise<WebSearchProviderResponse> {
|
||||
try {
|
||||
if (!query) {
|
||||
throw new Error('Search query cannot be empty')
|
||||
@ -124,7 +128,7 @@ export default class SearxngProvider extends BaseWebSearchProvider {
|
||||
// Fetch content for each URL concurrently
|
||||
const fetchPromises = validItems.map(async (item) => {
|
||||
// Logger.log(`Fetching content for ${item.url}...`)
|
||||
return await fetchWebContent(item.url, 'markdown', this.provider.usingBrowser)
|
||||
return await fetchWebContent(item.url, 'markdown', this.provider.usingBrowser, httpOptions)
|
||||
})
|
||||
|
||||
// Wait for all fetches to complete
|
||||
|
||||
@ -20,23 +20,42 @@ export default class TavilyProvider extends BaseWebSearchProvider {
|
||||
this.tvly = new TavilyClient({ apiKey: this.apiKey, apiBaseUrl: this.apiHost })
|
||||
}
|
||||
|
||||
public async search(query: string, websearch: WebSearchState): Promise<WebSearchProviderResponse> {
|
||||
public async search(
|
||||
query: string,
|
||||
websearch: WebSearchState,
|
||||
httpOptions?: RequestInit
|
||||
): Promise<WebSearchProviderResponse> {
|
||||
try {
|
||||
if (!query.trim()) {
|
||||
throw new Error('Search query cannot be empty')
|
||||
}
|
||||
|
||||
const result = await this.tvly.search({
|
||||
// 使用 Promise.race 来支持 abort signal
|
||||
const searchPromise = this.tvly.search({
|
||||
query,
|
||||
max_results: Math.max(1, websearch.maxResults)
|
||||
})
|
||||
|
||||
let result: Awaited<typeof searchPromise>
|
||||
if (httpOptions?.signal) {
|
||||
result = await Promise.race([
|
||||
searchPromise,
|
||||
new Promise<never>((_, reject) => {
|
||||
httpOptions.signal?.addEventListener('abort', () => {
|
||||
reject(new DOMException('The operation was aborted.', 'AbortError'))
|
||||
})
|
||||
})
|
||||
])
|
||||
} else {
|
||||
result = await searchPromise
|
||||
}
|
||||
return {
|
||||
query: result.query,
|
||||
results: result.results.slice(0, websearch.maxResults).map((result) => {
|
||||
results: result.results.slice(0, websearch.maxResults).map((item) => {
|
||||
return {
|
||||
title: result.title || 'No title',
|
||||
content: result.content || '',
|
||||
url: result.url || ''
|
||||
title: item.title || 'No title',
|
||||
content: item.content || '',
|
||||
url: item.url || ''
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@ -43,7 +43,11 @@ export default class ZhipuProvider extends BaseWebSearchProvider {
|
||||
}
|
||||
}
|
||||
|
||||
public async search(query: string, websearch: WebSearchState): Promise<WebSearchProviderResponse> {
|
||||
public async search(
|
||||
query: string,
|
||||
websearch: WebSearchState,
|
||||
httpOptions?: RequestInit
|
||||
): Promise<WebSearchProviderResponse> {
|
||||
try {
|
||||
if (!query.trim()) {
|
||||
throw new Error('Search query cannot be empty')
|
||||
@ -62,7 +66,8 @@ export default class ZhipuProvider extends BaseWebSearchProvider {
|
||||
'Content-Type': 'application/json',
|
||||
...this.defaultHeaders()
|
||||
},
|
||||
body: JSON.stringify(requestBody)
|
||||
body: JSON.stringify(requestBody),
|
||||
signal: httpOptions?.signal
|
||||
})
|
||||
|
||||
if (!response.ok) {
|
||||
|
||||
@ -430,7 +430,8 @@ class WebSearchService {
|
||||
public async processWebsearch(
|
||||
webSearchProvider: WebSearchProvider,
|
||||
extractResults: ExtractResults,
|
||||
requestId: string
|
||||
requestId: string,
|
||||
externalSignal?: AbortSignal
|
||||
): Promise<WebSearchProviderResponse> {
|
||||
// 重置状态
|
||||
await this.setWebSearchStatus(requestId, { phase: 'default' })
|
||||
@ -441,8 +442,8 @@ class WebSearchService {
|
||||
return { results: [] }
|
||||
}
|
||||
|
||||
// 使用请求特定的signal,如果没有则回退到全局signal
|
||||
const signal = this.getRequestState(requestId).signal || this.signal
|
||||
// 优先使用外部传入的signal,其次是请求特定的signal,最后回退到全局signal
|
||||
const signal = externalSignal || this.getRequestState(requestId).signal || this.signal
|
||||
|
||||
const span = webSearchProvider.topicId
|
||||
? addSpan({
|
||||
|
||||
Loading…
Reference in New Issue
Block a user