mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2026-01-05 04:19:02 +08:00
feat: enhance web search functionality with abort signal support
- Updated WebSearchTool to accept an abort signal in the execute method. - Modified various WebSearchProvider classes to include httpOptions for search methods, allowing for abort signal handling. - Improved WebSearchService to prioritize external abort signals for better request management. - Enhanced MessageTool to reflect tool status with appropriate UI feedback.
This commit is contained in:
parent
654f19eaa9
commit
ff378ca567
@ -40,7 +40,7 @@ You can use this tool as-is to search with the prepared queries, or provide addi
|
|||||||
.describe('Optional additional context, keywords, or specific focus to enhance the search')
|
.describe('Optional additional context, keywords, or specific focus to enhance the search')
|
||||||
}),
|
}),
|
||||||
|
|
||||||
execute: async ({ additionalContext }) => {
|
execute: async ({ additionalContext }, { abortSignal }) => {
|
||||||
let finalQueries = [...extractedKeywords.question]
|
let finalQueries = [...extractedKeywords.question]
|
||||||
|
|
||||||
if (additionalContext?.trim()) {
|
if (additionalContext?.trim()) {
|
||||||
@ -67,7 +67,15 @@ You can use this tool as-is to search with the prepared queries, or provide addi
|
|||||||
links: extractedKeywords.links
|
links: extractedKeywords.links
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
searchResults = await WebSearchService.processWebsearch(webSearchProvider!, extractResults, requestId)
|
// abortSignal?.addEventListener('abort', () => {
|
||||||
|
// console.log('tool_call_abortSignal', abortSignal?.aborted)
|
||||||
|
// })
|
||||||
|
searchResults = await WebSearchService.processWebsearch(
|
||||||
|
webSearchProvider!,
|
||||||
|
extractResults,
|
||||||
|
requestId,
|
||||||
|
abortSignal
|
||||||
|
)
|
||||||
|
|
||||||
return searchResults
|
return searchResults
|
||||||
},
|
},
|
||||||
|
|||||||
@ -1,5 +1,8 @@
|
|||||||
import { NormalToolResponse } from '@renderer/types'
|
import { NormalToolResponse } from '@renderer/types'
|
||||||
import type { ToolMessageBlock } from '@renderer/types/newMessage'
|
import { MessageBlockStatus, ToolMessageBlock } from '@renderer/types/newMessage'
|
||||||
|
import { TFunction } from 'i18next'
|
||||||
|
import { Pause } from 'lucide-react'
|
||||||
|
import { useTranslation } from 'react-i18next'
|
||||||
|
|
||||||
import { MessageAgentTools } from './MessageAgentTools'
|
import { MessageAgentTools } from './MessageAgentTools'
|
||||||
import { MessageKnowledgeSearchToolTitle } from './MessageKnowledgeSearch'
|
import { MessageKnowledgeSearchToolTitle } from './MessageKnowledgeSearch'
|
||||||
@ -35,11 +38,23 @@ const isAgentTool = (toolName: string) => {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
const ChooseTool = (toolResponse: NormalToolResponse): React.ReactNode | null => {
|
const ChooseTool = (
|
||||||
|
toolResponse: NormalToolResponse,
|
||||||
|
status: MessageBlockStatus,
|
||||||
|
t: TFunction
|
||||||
|
): React.ReactNode | null => {
|
||||||
let toolName = toolResponse.tool.name
|
let toolName = toolResponse.tool.name
|
||||||
const toolType = toolResponse.tool.type
|
const toolType = toolResponse.tool.type
|
||||||
if (toolName.startsWith(prefix)) {
|
if (toolName.startsWith(prefix)) {
|
||||||
toolName = toolName.slice(prefix.length)
|
toolName = toolName.slice(prefix.length)
|
||||||
|
if (status === MessageBlockStatus.PAUSED) {
|
||||||
|
return (
|
||||||
|
<div className="flex items-center gap-1">
|
||||||
|
<Pause className="h-4 w-4" />
|
||||||
|
<span>{t('message.tools.aborted')}</span>
|
||||||
|
</div>
|
||||||
|
)
|
||||||
|
}
|
||||||
switch (toolName) {
|
switch (toolName) {
|
||||||
case 'web_search':
|
case 'web_search':
|
||||||
case 'web_search_preview':
|
case 'web_search_preview':
|
||||||
@ -58,12 +73,13 @@ const ChooseTool = (toolResponse: NormalToolResponse): React.ReactNode | null =>
|
|||||||
}
|
}
|
||||||
|
|
||||||
export default function MessageTool({ block }: Props) {
|
export default function MessageTool({ block }: Props) {
|
||||||
|
const { t } = useTranslation()
|
||||||
// FIXME: 语义错误,这里已经不是 MCP tool 了,更改rawMcpToolResponse需要改用户数据, 所以暂时保留
|
// FIXME: 语义错误,这里已经不是 MCP tool 了,更改rawMcpToolResponse需要改用户数据, 所以暂时保留
|
||||||
const toolResponse = block.metadata?.rawMcpToolResponse as NormalToolResponse
|
const toolResponse = block.metadata?.rawMcpToolResponse as NormalToolResponse
|
||||||
|
|
||||||
if (!toolResponse) return null
|
if (!toolResponse) return null
|
||||||
|
|
||||||
const toolRenderer = ChooseTool(toolResponse as NormalToolResponse)
|
const toolRenderer = ChooseTool(toolResponse as NormalToolResponse, block.status, t)
|
||||||
|
|
||||||
if (!toolRenderer) return null
|
if (!toolRenderer) return null
|
||||||
|
|
||||||
|
|||||||
@ -18,7 +18,11 @@ export default class BochaProvider extends BaseWebSearchProvider {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public async search(query: string, websearch: WebSearchState): Promise<WebSearchProviderResponse> {
|
public async search(
|
||||||
|
query: string,
|
||||||
|
websearch: WebSearchState,
|
||||||
|
httpOptions?: RequestInit
|
||||||
|
): Promise<WebSearchProviderResponse> {
|
||||||
try {
|
try {
|
||||||
if (!query.trim()) {
|
if (!query.trim()) {
|
||||||
throw new Error('Search query cannot be empty')
|
throw new Error('Search query cannot be empty')
|
||||||
@ -44,7 +48,8 @@ export default class BochaProvider extends BaseWebSearchProvider {
|
|||||||
headers: {
|
headers: {
|
||||||
...this.defaultHeaders(),
|
...this.defaultHeaders(),
|
||||||
...headers
|
...headers
|
||||||
}
|
},
|
||||||
|
signal: httpOptions?.signal
|
||||||
})
|
})
|
||||||
|
|
||||||
if (!response.ok) {
|
if (!response.ok) {
|
||||||
|
|||||||
@ -1,9 +1,10 @@
|
|||||||
|
import { WebSearchState } from '@renderer/store/websearch'
|
||||||
import { WebSearchProviderResponse } from '@renderer/types'
|
import { WebSearchProviderResponse } from '@renderer/types'
|
||||||
|
|
||||||
import BaseWebSearchProvider from './BaseWebSearchProvider'
|
import BaseWebSearchProvider from './BaseWebSearchProvider'
|
||||||
|
|
||||||
export default class DefaultProvider extends BaseWebSearchProvider {
|
export default class DefaultProvider extends BaseWebSearchProvider {
|
||||||
search(): Promise<WebSearchProviderResponse> {
|
search(_query: string, _websearch: WebSearchState, _httpOptions?: RequestInit): Promise<WebSearchProviderResponse> {
|
||||||
throw new Error('Method not implemented.')
|
throw new Error('Method not implemented.')
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -20,13 +20,18 @@ export default class ExaProvider extends BaseWebSearchProvider {
|
|||||||
this.exa = new ExaClient({ apiKey: this.apiKey, apiBaseUrl: this.apiHost })
|
this.exa = new ExaClient({ apiKey: this.apiKey, apiBaseUrl: this.apiHost })
|
||||||
}
|
}
|
||||||
|
|
||||||
public async search(query: string, websearch: WebSearchState): Promise<WebSearchProviderResponse> {
|
public async search(
|
||||||
|
query: string,
|
||||||
|
websearch: WebSearchState,
|
||||||
|
httpOptions?: RequestInit
|
||||||
|
): Promise<WebSearchProviderResponse> {
|
||||||
try {
|
try {
|
||||||
if (!query.trim()) {
|
if (!query.trim()) {
|
||||||
throw new Error('Search query cannot be empty')
|
throw new Error('Search query cannot be empty')
|
||||||
}
|
}
|
||||||
|
|
||||||
const response = await this.exa.search({
|
// 使用 Promise.race 来支持 abort signal
|
||||||
|
const searchPromise = this.exa.search({
|
||||||
query,
|
query,
|
||||||
numResults: Math.max(1, websearch.maxResults),
|
numResults: Math.max(1, websearch.maxResults),
|
||||||
contents: {
|
contents: {
|
||||||
@ -34,6 +39,20 @@ export default class ExaProvider extends BaseWebSearchProvider {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
let response: Awaited<typeof searchPromise>
|
||||||
|
if (httpOptions?.signal) {
|
||||||
|
response = await Promise.race([
|
||||||
|
searchPromise,
|
||||||
|
new Promise<never>((_, reject) => {
|
||||||
|
httpOptions.signal?.addEventListener('abort', () => {
|
||||||
|
reject(new DOMException('The operation was aborted.', 'AbortError'))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
])
|
||||||
|
} else {
|
||||||
|
response = await searchPromise
|
||||||
|
}
|
||||||
|
|
||||||
return {
|
return {
|
||||||
query: response.autopromptString,
|
query: response.autopromptString,
|
||||||
results: response.results.slice(0, websearch.maxResults).map((result) => {
|
results: response.results.slice(0, websearch.maxResults).map((result) => {
|
||||||
|
|||||||
@ -95,7 +95,11 @@ export default class SearxngProvider extends BaseWebSearchProvider {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public async search(query: string, websearch: WebSearchState): Promise<WebSearchProviderResponse> {
|
public async search(
|
||||||
|
query: string,
|
||||||
|
websearch: WebSearchState,
|
||||||
|
httpOptions?: RequestInit
|
||||||
|
): Promise<WebSearchProviderResponse> {
|
||||||
try {
|
try {
|
||||||
if (!query) {
|
if (!query) {
|
||||||
throw new Error('Search query cannot be empty')
|
throw new Error('Search query cannot be empty')
|
||||||
@ -124,7 +128,7 @@ export default class SearxngProvider extends BaseWebSearchProvider {
|
|||||||
// Fetch content for each URL concurrently
|
// Fetch content for each URL concurrently
|
||||||
const fetchPromises = validItems.map(async (item) => {
|
const fetchPromises = validItems.map(async (item) => {
|
||||||
// Logger.log(`Fetching content for ${item.url}...`)
|
// Logger.log(`Fetching content for ${item.url}...`)
|
||||||
return await fetchWebContent(item.url, 'markdown', this.provider.usingBrowser)
|
return await fetchWebContent(item.url, 'markdown', this.provider.usingBrowser, httpOptions)
|
||||||
})
|
})
|
||||||
|
|
||||||
// Wait for all fetches to complete
|
// Wait for all fetches to complete
|
||||||
|
|||||||
@ -20,23 +20,42 @@ export default class TavilyProvider extends BaseWebSearchProvider {
|
|||||||
this.tvly = new TavilyClient({ apiKey: this.apiKey, apiBaseUrl: this.apiHost })
|
this.tvly = new TavilyClient({ apiKey: this.apiKey, apiBaseUrl: this.apiHost })
|
||||||
}
|
}
|
||||||
|
|
||||||
public async search(query: string, websearch: WebSearchState): Promise<WebSearchProviderResponse> {
|
public async search(
|
||||||
|
query: string,
|
||||||
|
websearch: WebSearchState,
|
||||||
|
httpOptions?: RequestInit
|
||||||
|
): Promise<WebSearchProviderResponse> {
|
||||||
try {
|
try {
|
||||||
if (!query.trim()) {
|
if (!query.trim()) {
|
||||||
throw new Error('Search query cannot be empty')
|
throw new Error('Search query cannot be empty')
|
||||||
}
|
}
|
||||||
|
|
||||||
const result = await this.tvly.search({
|
// 使用 Promise.race 来支持 abort signal
|
||||||
|
const searchPromise = this.tvly.search({
|
||||||
query,
|
query,
|
||||||
max_results: Math.max(1, websearch.maxResults)
|
max_results: Math.max(1, websearch.maxResults)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
let result: Awaited<typeof searchPromise>
|
||||||
|
if (httpOptions?.signal) {
|
||||||
|
result = await Promise.race([
|
||||||
|
searchPromise,
|
||||||
|
new Promise<never>((_, reject) => {
|
||||||
|
httpOptions.signal?.addEventListener('abort', () => {
|
||||||
|
reject(new DOMException('The operation was aborted.', 'AbortError'))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
])
|
||||||
|
} else {
|
||||||
|
result = await searchPromise
|
||||||
|
}
|
||||||
return {
|
return {
|
||||||
query: result.query,
|
query: result.query,
|
||||||
results: result.results.slice(0, websearch.maxResults).map((result) => {
|
results: result.results.slice(0, websearch.maxResults).map((item) => {
|
||||||
return {
|
return {
|
||||||
title: result.title || 'No title',
|
title: item.title || 'No title',
|
||||||
content: result.content || '',
|
content: item.content || '',
|
||||||
url: result.url || ''
|
url: item.url || ''
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@ -43,7 +43,11 @@ export default class ZhipuProvider extends BaseWebSearchProvider {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public async search(query: string, websearch: WebSearchState): Promise<WebSearchProviderResponse> {
|
public async search(
|
||||||
|
query: string,
|
||||||
|
websearch: WebSearchState,
|
||||||
|
httpOptions?: RequestInit
|
||||||
|
): Promise<WebSearchProviderResponse> {
|
||||||
try {
|
try {
|
||||||
if (!query.trim()) {
|
if (!query.trim()) {
|
||||||
throw new Error('Search query cannot be empty')
|
throw new Error('Search query cannot be empty')
|
||||||
@ -62,7 +66,8 @@ export default class ZhipuProvider extends BaseWebSearchProvider {
|
|||||||
'Content-Type': 'application/json',
|
'Content-Type': 'application/json',
|
||||||
...this.defaultHeaders()
|
...this.defaultHeaders()
|
||||||
},
|
},
|
||||||
body: JSON.stringify(requestBody)
|
body: JSON.stringify(requestBody),
|
||||||
|
signal: httpOptions?.signal
|
||||||
})
|
})
|
||||||
|
|
||||||
if (!response.ok) {
|
if (!response.ok) {
|
||||||
|
|||||||
@ -430,7 +430,8 @@ class WebSearchService {
|
|||||||
public async processWebsearch(
|
public async processWebsearch(
|
||||||
webSearchProvider: WebSearchProvider,
|
webSearchProvider: WebSearchProvider,
|
||||||
extractResults: ExtractResults,
|
extractResults: ExtractResults,
|
||||||
requestId: string
|
requestId: string,
|
||||||
|
externalSignal?: AbortSignal
|
||||||
): Promise<WebSearchProviderResponse> {
|
): Promise<WebSearchProviderResponse> {
|
||||||
// 重置状态
|
// 重置状态
|
||||||
await this.setWebSearchStatus(requestId, { phase: 'default' })
|
await this.setWebSearchStatus(requestId, { phase: 'default' })
|
||||||
@ -441,8 +442,8 @@ class WebSearchService {
|
|||||||
return { results: [] }
|
return { results: [] }
|
||||||
}
|
}
|
||||||
|
|
||||||
// 使用请求特定的signal,如果没有则回退到全局signal
|
// 优先使用外部传入的signal,其次是请求特定的signal,最后回退到全局signal
|
||||||
const signal = this.getRequestState(requestId).signal || this.signal
|
const signal = externalSignal || this.getRequestState(requestId).signal || this.signal
|
||||||
|
|
||||||
const span = webSearchProvider.topicId
|
const span = webSearchProvider.topicId
|
||||||
? addSpan({
|
? addSpan({
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user