Merge branch 'main' into fix/next-release-bugs

This commit is contained in:
kangfenmao 2025-05-10 20:18:44 +08:00
commit f4963888e8
44 changed files with 1009 additions and 518 deletions

View File

@ -19,7 +19,7 @@ Cherry Studio is a desktop client that supports for multiple LLM providers, avai
# 📖 Guide # 📖 Guide
https://docs.cherry-ai.com <https://docs.cherry-ai.com>
# 🌠 Screenshot # 🌠 Screenshot
@ -82,17 +82,18 @@ https://docs.cherry-ai.com
# 🌈 Theme # 🌈 Theme
- Theme Gallery: https://cherrycss.com - Theme Gallery: <https://cherrycss.com>
- Aero Theme: https://github.com/hakadao/CherryStudio-Aero - Aero Theme: <https://github.com/hakadao/CherryStudio-Aero>
- PaperMaterial Theme: https://github.com/rainoffallingstar/CherryStudio-PaperMaterial - PaperMaterial Theme: <https://github.com/rainoffallingstar/CherryStudio-PaperMaterial>
- Claude dynamic-style: https://github.com/bjl101501/CherryStudio-Claudestyle-dynamic - Claude dynamic-style: <https://github.com/bjl101501/CherryStudio-Claudestyle-dynamic>
- Maple Neon Theme: https://github.com/BoningtonChen/CherryStudio_themes - Maple Neon Theme: <https://github.com/BoningtonChen/CherryStudio_themes>
Welcome PR for more themes Welcome PR for more themes
# 🖥️ Develop # 🖥️ Develop
Refer to the [development documentation](docs/dev.md) Refer to the [development documentation](docs/dev.md)
Refer to the [Architecture overview documentation](https://deepwiki.com/CherryHQ/cherry-studio)
# 🤝 Contributing # 🤝 Contributing
@ -144,7 +145,7 @@ Thank you for your support and contributions!
# ✉️ Contact # ✉️ Contact
yinsenho@cherry-ai.com <yinsenho@cherry-ai.com>
# ⭐️ Star History # ⭐️ Star History

View File

@ -17,6 +17,10 @@ export default abstract class BaseReranker {
* Get Rerank Request Url * Get Rerank Request Url
*/ */
protected getRerankUrl() { protected getRerankUrl() {
if (this.base.rerankModelProvider === 'dashscope') {
return 'https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank'
}
let baseURL = this.base?.rerankBaseURL?.endsWith('/') let baseURL = this.base?.rerankBaseURL?.endsWith('/')
? this.base.rerankBaseURL.slice(0, -1) ? this.base.rerankBaseURL.slice(0, -1)
: this.base.rerankBaseURL : this.base.rerankBaseURL
@ -28,6 +32,56 @@ export default abstract class BaseReranker {
return `${baseURL}/rerank` return `${baseURL}/rerank`
} }
/**
* Get Rerank Request Body
*/
protected getRerankRequestBody(query: string, searchResults: ExtractChunkData[]) {
const provider = this.base.rerankModelProvider
const documents = searchResults.map((doc) => doc.pageContent)
const topN = this.base.topN || 5
if (provider === 'voyageai') {
return {
model: this.base.rerankModel,
query,
documents,
top_k: topN
}
} else if (provider === 'dashscope') {
return {
model: this.base.rerankModel,
input: {
query,
documents
},
parameters: {
top_n: topN
}
}
} else {
return {
model: this.base.rerankModel,
query,
documents,
top_n: topN
}
}
}
/**
* Extract Rerank Result
*/
protected extractRerankResult(data: any) {
const provider = this.base.rerankModelProvider
if (provider === 'dashscope') {
return data.output.results
} else if (provider === 'voyageai') {
return data.data
} else {
return data.results
}
}
/** /**
* Get Rerank Result * Get Rerank Result
* @param searchResults * @param searchResults

View File

@ -1,58 +0,0 @@
import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
import axiosProxy from '@main/services/AxiosProxy'
import { KnowledgeBaseParams } from '@types'
import BaseReranker from './BaseReranker'
interface DashscopeRerankResultItem {
document: {
text: string
}
index: number
relevance_score: number
}
interface DashscopeRerankResponse {
output: {
results: DashscopeRerankResultItem[]
}
usage: {
total_tokens: number
}
request_id: string
}
export default class DashscopeReranker extends BaseReranker {
constructor(base: KnowledgeBaseParams) {
super(base)
}
public rerank = async (query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]> => {
const url = 'https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank'
const requestBody = {
model: this.base.rerankModel,
input: {
query,
documents: searchResults.map((doc) => doc.pageContent)
},
parameters: {
return_documents: true, // Recommended to be true to get document details if needed, though scores are primary
top_n: this.base.topN || 5 // Default to 5 if topN is not specified, as per API example
}
}
try {
const { data } = await axiosProxy.axios.post<DashscopeRerankResponse>(url, requestBody, {
headers: this.defaultHeaders()
})
const rerankResults = data.output.results
return this.getRerankResult(searchResults, rerankResults)
} catch (error: any) {
const errorDetails = this.formatErrorMessage(url, error, requestBody)
console.error('Dashscope Reranker API 错误:', errorDetails)
throw new Error(`重排序请求失败: ${error.message}\n请求详情: ${errorDetails}`)
}
}
}

View File

@ -1,14 +0,0 @@
import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
import { KnowledgeBaseParams } from '@types'
import BaseReranker from './BaseReranker'
export default class DefaultReranker extends BaseReranker {
constructor(base: KnowledgeBaseParams) {
super(base)
}
async rerank(): Promise<ExtractChunkData[]> {
throw new Error('Method not implemented.')
}
}

View File

@ -4,7 +4,7 @@ import { KnowledgeBaseParams } from '@types'
import BaseReranker from './BaseReranker' import BaseReranker from './BaseReranker'
export default class JinaReranker extends BaseReranker { export default class GeneralReranker extends BaseReranker {
constructor(base: KnowledgeBaseParams) { constructor(base: KnowledgeBaseParams) {
super(base) super(base)
} }
@ -12,21 +12,15 @@ export default class JinaReranker extends BaseReranker {
public rerank = async (query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]> => { public rerank = async (query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]> => {
const url = this.getRerankUrl() const url = this.getRerankUrl()
const requestBody = { const requestBody = this.getRerankRequestBody(query, searchResults)
model: this.base.rerankModel,
query,
documents: searchResults.map((doc) => doc.pageContent),
top_n: this.base.topN
}
try { try {
const { data } = await AxiosProxy.axios.post(url, requestBody, { headers: this.defaultHeaders() }) const { data } = await AxiosProxy.axios.post(url, requestBody, { headers: this.defaultHeaders() })
const rerankResults = data.results const rerankResults = this.extractRerankResult(data)
return this.getRerankResult(searchResults, rerankResults) return this.getRerankResult(searchResults, rerankResults)
} catch (error: any) { } catch (error: any) {
const errorDetails = this.formatErrorMessage(url, error, requestBody) const errorDetails = this.formatErrorMessage(url, error, requestBody)
console.error('Jina Reranker API Error:', errorDetails)
throw new Error(`重排序请求失败: ${error.message}\n请求详情: ${errorDetails}`) throw new Error(`重排序请求失败: ${error.message}\n请求详情: ${errorDetails}`)
} }
} }

View File

@ -1,13 +1,12 @@
import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces' import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
import { KnowledgeBaseParams } from '@types' import { KnowledgeBaseParams } from '@types'
import BaseReranker from './BaseReranker' import GeneralReranker from './GeneralReranker'
import RerankerFactory from './RerankerFactory'
export default class Reranker { export default class Reranker {
private sdk: BaseReranker private sdk: GeneralReranker
constructor(base: KnowledgeBaseParams) { constructor(base: KnowledgeBaseParams) {
this.sdk = RerankerFactory.create(base) this.sdk = new GeneralReranker(base)
} }
public async rerank(query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]> { public async rerank(query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]> {
return this.sdk.rerank(query, searchResults) return this.sdk.rerank(query, searchResults)

View File

@ -1,23 +0,0 @@
import { KnowledgeBaseParams } from '@types'
import BaseReranker from './BaseReranker'
import DashscopeReranker from './DashscopeReranker'
import DefaultReranker from './DefaultReranker'
import JinaReranker from './JinaReranker'
import SiliconFlowReranker from './SiliconFlowReranker'
import VoyageReranker from './VoyageReranker'
export default class RerankerFactory {
static create(base: KnowledgeBaseParams): BaseReranker {
if (base.rerankModelProvider === 'silicon') {
return new SiliconFlowReranker(base)
} else if (base.rerankModelProvider === 'jina') {
return new JinaReranker(base)
} else if (base.rerankModelProvider === 'voyageai') {
return new VoyageReranker(base)
} else if (base.rerankModelProvider === 'dashscope') {
return new DashscopeReranker(base)
}
return new DefaultReranker(base)
}
}

View File

@ -1,36 +0,0 @@
import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
import axiosProxy from '@main/services/AxiosProxy'
import { KnowledgeBaseParams } from '@types'
import BaseReranker from './BaseReranker'
export default class SiliconFlowReranker extends BaseReranker {
constructor(base: KnowledgeBaseParams) {
super(base)
}
public rerank = async (query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]> => {
const url = this.getRerankUrl()
const requestBody = {
model: this.base.rerankModel,
query,
documents: searchResults.map((doc) => doc.pageContent),
top_n: this.base.topN,
max_chunks_per_doc: this.base.chunkSize,
overlap_tokens: this.base.chunkOverlap
}
try {
const { data } = await axiosProxy.axios.post(url, requestBody, { headers: this.defaultHeaders() })
const rerankResults = data.results
return this.getRerankResult(searchResults, rerankResults)
} catch (error: any) {
const errorDetails = this.formatErrorMessage(url, error, requestBody)
console.error('SiliconFlow Reranker API 错误:', errorDetails)
throw new Error(`重排序请求失败: ${error.message}\n请求详情: ${errorDetails}`)
}
}
}

View File

@ -1,40 +0,0 @@
import { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
import axiosProxy from '@main/services/AxiosProxy'
import { KnowledgeBaseParams } from '@types'
import BaseReranker from './BaseReranker'
export default class VoyageReranker extends BaseReranker {
constructor(base: KnowledgeBaseParams) {
super(base)
}
public rerank = async (query: string, searchResults: ExtractChunkData[]): Promise<ExtractChunkData[]> => {
const url = this.getRerankUrl()
const requestBody = {
model: this.base.rerankModel,
query,
documents: searchResults.map((doc) => doc.pageContent),
top_k: this.base.topN,
return_documents: false,
truncation: true
}
try {
const { data } = await axiosProxy.axios.post(url, requestBody, {
headers: {
...this.defaultHeaders()
}
})
const rerankResults = data.data
return this.getRerankResult(searchResults, rerankResults)
} catch (error: any) {
const errorDetails = this.formatErrorMessage(url, error, requestBody)
console.error('Voyage Reranker API Error:', errorDetails)
throw new Error(`重排序请求失败: ${error.message}\n请求详情: ${errorDetails}`)
}
}
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 4.7 KiB

View File

@ -0,0 +1,41 @@
import { useMemo, useReducer } from 'react'
import { initialScrollState, scrollReducer } from './reducer'
import { FlatListItem, ScrollTrigger } from './types'
/**
* hook
*/
export function useScrollState() {
const [state, dispatch] = useReducer(scrollReducer, initialScrollState)
const actions = useMemo(
() => ({
setFocusedItemKey: (key: string) => dispatch({ type: 'SET_FOCUSED_ITEM_KEY', payload: key }),
setScrollTrigger: (trigger: ScrollTrigger) => dispatch({ type: 'SET_SCROLL_TRIGGER', payload: trigger }),
setLastScrollOffset: (offset: number) => dispatch({ type: 'SET_LAST_SCROLL_OFFSET', payload: offset }),
setStickyGroup: (group: FlatListItem | null) => dispatch({ type: 'SET_STICKY_GROUP', payload: group }),
setIsMouseOver: (isMouseOver: boolean) => dispatch({ type: 'SET_IS_MOUSE_OVER', payload: isMouseOver }),
focusNextItem: (modelItems: FlatListItem[], step: number) =>
dispatch({ type: 'FOCUS_NEXT_ITEM', payload: { modelItems, step } }),
focusPage: (modelItems: FlatListItem[], currentIndex: number, step: number) =>
dispatch({ type: 'FOCUS_PAGE', payload: { modelItems, currentIndex, step } }),
searchChanged: (searchText: string) => dispatch({ type: 'SEARCH_CHANGED', payload: { searchText } }),
updateOnListChange: (modelItems: FlatListItem[]) =>
dispatch({ type: 'UPDATE_ON_LIST_CHANGE', payload: { modelItems } }),
initScroll: () => dispatch({ type: 'INIT_SCROLL' })
}),
[]
)
return {
// 状态
focusedItemKey: state.focusedItemKey,
scrollTrigger: state.scrollTrigger,
lastScrollOffset: state.lastScrollOffset,
stickyGroup: state.stickyGroup,
isMouseOver: state.isMouseOver,
// 操作
...actions
}
}

View File

@ -0,0 +1,3 @@
import { SelectModelPopup } from './popup'
export default SelectModelPopup

View File

@ -1,7 +1,9 @@
import { PushpinOutlined } from '@ant-design/icons' import { PushpinOutlined } from '@ant-design/icons'
import { HStack } from '@renderer/components/Layout'
import ModelTagsWithLabel from '@renderer/components/ModelTagsWithLabel'
import { TopView } from '@renderer/components/TopView' import { TopView } from '@renderer/components/TopView'
import { getModelLogo, isEmbeddingModel, isRerankModel } from '@renderer/config/models' import { getModelLogo, isEmbeddingModel, isRerankModel } from '@renderer/config/models'
import db from '@renderer/databases' import { usePinnedModels } from '@renderer/hooks/usePinnedModels'
import { useProviders } from '@renderer/hooks/useProvider' import { useProviders } from '@renderer/hooks/useProvider'
import { getModelUniqId } from '@renderer/services/ModelService' import { getModelUniqId } from '@renderer/services/ModelService'
import { Model } from '@renderer/types' import { Model } from '@renderer/types'
@ -15,101 +17,61 @@ import { useTranslation } from 'react-i18next'
import { FixedSizeList } from 'react-window' import { FixedSizeList } from 'react-window'
import styled from 'styled-components' import styled from 'styled-components'
import { HStack } from '../Layout' import { useScrollState } from './hook'
import ModelTagsWithLabel from '../ModelTagsWithLabel' import { FlatListItem } from './types'
const PAGE_SIZE = 9 const PAGE_SIZE = 9
const ITEM_HEIGHT = 36 const ITEM_HEIGHT = 36
// 列表项类型,组名也作为列表项 interface PopupParams {
type ListItemType = 'group' | 'model'
// 滚动触发来源类型
type ScrollTrigger = 'initial' | 'search' | 'keyboard' | 'none'
// 扁平化列表项接口
interface FlatListItem {
key: string
type: ListItemType
icon?: React.ReactNode
name: React.ReactNode
tags?: React.ReactNode
model?: Model
isPinned?: boolean
isSelected?: boolean
}
interface Props {
model?: Model model?: Model
} }
interface PopupContainerProps extends Props { interface Props extends PopupParams {
resolve: (value: Model | undefined) => void resolve: (value: Model | undefined) => void
} }
const PopupContainer: React.FC<PopupContainerProps> = ({ model, resolve }) => { const PopupContainer: React.FC<Props> = ({ model, resolve }) => {
const { t } = useTranslation() const { t } = useTranslation()
const { providers } = useProviders() const { providers } = useProviders()
const { pinnedModels, togglePinnedModel, loading: loadingPinnedModels } = usePinnedModels()
const [open, setOpen] = useState(true) const [open, setOpen] = useState(true)
const inputRef = useRef<InputRef>(null) const inputRef = useRef<InputRef>(null)
const listRef = useRef<FixedSizeList>(null) const listRef = useRef<FixedSizeList>(null)
const [_searchText, setSearchText] = useState('') const [_searchText, setSearchText] = useState('')
const searchText = useDeferredValue(_searchText) const searchText = useDeferredValue(_searchText)
const [isMouseOver, setIsMouseOver] = useState(false)
const [pinnedModels, setPinnedModels] = useState<string[]>([])
const [_focusedItemKey, setFocusedItemKey] = useState<string>('')
const focusedItemKey = useDeferredValue(_focusedItemKey)
const [_stickyGroup, setStickyGroup] = useState<FlatListItem | null>(null)
const stickyGroup = useDeferredValue(_stickyGroup)
const firstGroupRef = useRef<FlatListItem | null>(null)
const scrollTriggerRef = useRef<ScrollTrigger>('initial')
const lastScrollOffsetRef = useRef(0)
// 当前选中的模型ID // 当前选中的模型ID
const currentModelId = model ? getModelUniqId(model) : '' const currentModelId = model ? getModelUniqId(model) : ''
// 加载置顶模型列表 // 管理滚动和焦点状态
useEffect(() => { const {
const loadPinnedModels = async () => { focusedItemKey,
const setting = await db.settings.get('pinned:models') scrollTrigger,
const savedPinnedModels = setting?.value || [] lastScrollOffset,
stickyGroup: _stickyGroup,
isMouseOver,
setFocusedItemKey,
setScrollTrigger,
setLastScrollOffset,
setStickyGroup,
setIsMouseOver,
focusNextItem,
focusPage,
searchChanged,
updateOnListChange,
initScroll
} = useScrollState()
// Filter out invalid pinned models const stickyGroup = useDeferredValue(_stickyGroup)
const allModelIds = providers.flatMap((p) => p.models || []).map((m) => getModelUniqId(m)) const firstGroupRef = useRef<FlatListItem | null>(null)
const validPinnedModels = savedPinnedModels.filter((id) => allModelIds.includes(id))
// Update storage if there were invalid models
if (validPinnedModels.length !== savedPinnedModels.length) {
await db.settings.put({ id: 'pinned:models', value: validPinnedModels })
}
setPinnedModels(sortBy(validPinnedModels))
}
try {
loadPinnedModels()
} catch (error) {
console.error('Failed to load pinned models', error)
setPinnedModels([])
}
}, [providers])
const togglePin = useCallback( const togglePin = useCallback(
async (modelId: string) => { async (modelId: string) => {
const newPinnedModels = pinnedModels.includes(modelId) await togglePinnedModel(modelId)
? pinnedModels.filter((id) => id !== modelId) setScrollTrigger('none') // pin操作不触发滚动
: [...pinnedModels, modelId]
try {
await db.settings.put({ id: 'pinned:models', value: newPinnedModels })
setPinnedModels(sortBy(newPinnedModels))
// Pin操作不触发滚动
scrollTriggerRef.current = 'none'
} catch (error) {
console.error('Failed to update pinned models', error)
}
}, },
[pinnedModels] [togglePinnedModel, setScrollTrigger]
) )
// 根据输入的文本筛选模型 // 根据输入的文本筛选模型
@ -222,6 +184,16 @@ const PopupContainer: React.FC<PopupContainerProps> = ({ model, resolve }) => {
return items return items
}, [providers, getFilteredModels, pinnedModels, searchText, t, createModelItem]) }, [providers, getFilteredModels, pinnedModels, searchText, t, createModelItem])
// 获取可选择的模型项(过滤掉分组标题)
const modelItems = useMemo(() => {
return listItems.filter((item) => item.type === 'model')
}, [listItems])
// 当搜索文本变化时更新滚动触发器
useEffect(() => {
searchChanged(searchText)
}, [searchText, searchChanged])
// 基于滚动位置更新sticky分组标题 // 基于滚动位置更新sticky分组标题
const updateStickyGroup = useCallback( const updateStickyGroup = useCallback(
(scrollOffset?: number) => { (scrollOffset?: number) => {
@ -231,7 +203,7 @@ const PopupContainer: React.FC<PopupContainerProps> = ({ model, resolve }) => {
} }
// 基于滚动位置计算当前可见的第一个项的索引 // 基于滚动位置计算当前可见的第一个项的索引
const estimatedIndex = Math.floor((scrollOffset ?? lastScrollOffsetRef.current) / ITEM_HEIGHT) const estimatedIndex = Math.floor((scrollOffset ?? lastScrollOffset) / ITEM_HEIGHT)
// 从该索引向前查找最近的分组标题 // 从该索引向前查找最近的分组标题
for (let i = estimatedIndex - 1; i >= 0; i--) { for (let i = estimatedIndex - 1; i >= 0; i--) {
@ -242,9 +214,9 @@ const PopupContainer: React.FC<PopupContainerProps> = ({ model, resolve }) => {
} }
// 找不到则使用第一个分组标题 // 找不到则使用第一个分组标题
setStickyGroup(firstGroupRef.current ?? null) setStickyGroup(firstGroupRef.current)
}, },
[listItems] [listItems, lastScrollOffset, setStickyGroup]
) )
// 在listItems变化时更新sticky group // 在listItems变化时更新sticky group
@ -255,67 +227,46 @@ const PopupContainer: React.FC<PopupContainerProps> = ({ model, resolve }) => {
// 处理列表滚动事件更新lastScrollOffset并更新sticky分组 // 处理列表滚动事件更新lastScrollOffset并更新sticky分组
const handleScroll = useCallback( const handleScroll = useCallback(
({ scrollOffset }) => { ({ scrollOffset }) => {
lastScrollOffsetRef.current = scrollOffset setLastScrollOffset(scrollOffset)
updateStickyGroup(scrollOffset) updateStickyGroup(scrollOffset)
}, },
[updateStickyGroup] [updateStickyGroup, setLastScrollOffset]
) )
// 获取可选择的模型项(过滤掉分组标题) // 在列表项更新时,更新焦点项
const modelItems = useMemo(() => {
return listItems.filter((item) => item.type === 'model')
}, [listItems])
// 搜索文本变化时设置滚动来源
useEffect(() => { useEffect(() => {
if (searchText.trim() !== '') { updateOnListChange(modelItems)
scrollTriggerRef.current = 'search' }, [modelItems, updateOnListChange])
setFocusedItemKey('')
}
}, [searchText])
// 设置初始聚焦项以触发滚动
useEffect(() => {
if (scrollTriggerRef.current === 'initial' || scrollTriggerRef.current === 'search') {
const selectedItem = modelItems.find((item) => item.isSelected)
if (selectedItem) {
setFocusedItemKey(selectedItem.key)
} else if (scrollTriggerRef.current === 'initial' && modelItems.length > 0) {
setFocusedItemKey(modelItems[0].key)
}
// 其余情况不设置focusedItemKey
}
}, [modelItems])
// 滚动到聚焦项 // 滚动到聚焦项
useEffect(() => { useEffect(() => {
if (scrollTriggerRef.current === 'none' || !focusedItemKey) return if (scrollTrigger === 'none' || !focusedItemKey) return
const index = listItems.findIndex((item) => item.key === focusedItemKey) const index = listItems.findIndex((item) => item.key === focusedItemKey)
if (index < 0) return if (index < 0) return
// 根据触发源决定滚动对齐方式 // 根据触发源决定滚动对齐方式
const alignment = scrollTriggerRef.current === 'keyboard' ? 'auto' : 'center' const alignment = scrollTrigger === 'keyboard' ? 'auto' : 'center'
listRef.current?.scrollToItem(index, alignment) listRef.current?.scrollToItem(index, alignment)
// 滚动后重置触发器 // 滚动后重置触发器
scrollTriggerRef.current = 'none' setScrollTrigger('none')
}, [focusedItemKey, listItems]) }, [focusedItemKey, scrollTrigger, listItems, setScrollTrigger])
const handleItemClick = useCallback( const handleItemClick = useCallback(
(item: FlatListItem) => { (item: FlatListItem) => {
if (item.type === 'model') { if (item.type === 'model') {
scrollTriggerRef.current = 'none' setScrollTrigger('initial')
resolve(item.model) resolve(item.model)
setOpen(false) setOpen(false)
} }
}, },
[resolve] [resolve, setScrollTrigger]
) )
// 处理键盘导航 // 处理键盘导航
useEffect(() => { const handleKeyDown = useCallback(
const handleKeyDown = (e: KeyboardEvent) => { (e: KeyboardEvent) => {
if (!open) return if (!open) return
if (modelItems.length === 0) { if (modelItems.length === 0) {
@ -329,43 +280,21 @@ const PopupContainer: React.FC<PopupContainerProps> = ({ model, resolve }) => {
setIsMouseOver(false) setIsMouseOver(false)
} }
const getCurrentIndex = (currentKey: string) => { const currentIndex = modelItems.findIndex((item) => item.key === focusedItemKey)
const currentIndex = modelItems.findIndex((item) => item.key === currentKey) const normalizedIndex = currentIndex < 0 ? 0 : currentIndex
return currentIndex < 0 ? 0 : currentIndex
}
switch (e.key) { switch (e.key) {
case 'ArrowUp': case 'ArrowUp':
scrollTriggerRef.current = 'keyboard' focusNextItem(modelItems, -1)
setFocusedItemKey((prev) => {
const currentIndex = getCurrentIndex(prev)
const nextIndex = (currentIndex - 1 + modelItems.length) % modelItems.length
return modelItems[nextIndex].key
})
break break
case 'ArrowDown': case 'ArrowDown':
scrollTriggerRef.current = 'keyboard' focusNextItem(modelItems, 1)
setFocusedItemKey((prev) => {
const currentIndex = getCurrentIndex(prev)
const nextIndex = (currentIndex + 1) % modelItems.length
return modelItems[nextIndex].key
})
break break
case 'PageUp': case 'PageUp':
scrollTriggerRef.current = 'keyboard' focusPage(modelItems, normalizedIndex, -PAGE_SIZE)
setFocusedItemKey((prev) => {
const currentIndex = getCurrentIndex(prev)
const nextIndex = Math.max(currentIndex - PAGE_SIZE, 0)
return modelItems[nextIndex].key
})
break break
case 'PageDown': case 'PageDown':
scrollTriggerRef.current = 'keyboard' focusPage(modelItems, normalizedIndex, PAGE_SIZE)
setFocusedItemKey((prev) => {
const currentIndex = getCurrentIndex(prev)
const nextIndex = Math.min(currentIndex + PAGE_SIZE, modelItems.length - 1)
return modelItems[nextIndex].key
})
break break
case 'Enter': case 'Enter':
if (focusedItemKey) { if (focusedItemKey) {
@ -377,34 +306,47 @@ const PopupContainer: React.FC<PopupContainerProps> = ({ model, resolve }) => {
break break
case 'Escape': case 'Escape':
e.preventDefault() e.preventDefault()
scrollTriggerRef.current = 'none' setScrollTrigger('none')
setOpen(false) setOpen(false)
resolve(undefined) resolve(undefined)
break break
} }
} },
[
window.addEventListener('keydown', handleKeyDown) focusedItemKey,
return () => window.removeEventListener('keydown', handleKeyDown) modelItems,
}, [focusedItemKey, modelItems, handleItemClick, open, resolve]) handleItemClick,
open,
const onCancel = useCallback(() => { resolve,
scrollTriggerRef.current = 'none' setIsMouseOver,
setOpen(false) focusNextItem,
}, []) focusPage,
setScrollTrigger
const onClose = useCallback(async () => { ]
scrollTriggerRef.current = 'none' )
resolve(undefined)
SelectModelPopup.hide()
}, [resolve])
useEffect(() => { useEffect(() => {
if (!open) return window.addEventListener('keydown', handleKeyDown)
return () => window.removeEventListener('keydown', handleKeyDown)
}, [handleKeyDown])
const onCancel = useCallback(() => {
setScrollTrigger('initial')
setOpen(false)
}, [setScrollTrigger])
const onClose = useCallback(async () => {
setScrollTrigger('initial')
resolve(undefined)
SelectModelPopup.hide()
}, [resolve, setScrollTrigger])
// 初始化焦点和滚动位置
useEffect(() => {
if (!open || loadingPinnedModels) return
setTimeout(() => inputRef.current?.focus(), 0) setTimeout(() => inputRef.current?.focus(), 0)
scrollTriggerRef.current = 'initial' initScroll()
lastScrollOffsetRef.current = 0 }, [open, initScroll, loadingPinnedModels])
}, [open])
const RowData = useMemo( const RowData = useMemo(
(): VirtualizedRowData => ({ (): VirtualizedRowData => ({
@ -415,7 +357,7 @@ const PopupContainer: React.FC<PopupContainerProps> = ({ model, resolve }) => {
handleItemClick, handleItemClick,
togglePin togglePin
}), }),
[stickyGroup, focusedItemKey, handleItemClick, listItems, togglePin] [stickyGroup, focusedItemKey, handleItemClick, listItems, togglePin, setFocusedItemKey]
) )
const listHeight = useMemo(() => { const listHeight = useMemo(() => {
@ -470,7 +412,7 @@ const PopupContainer: React.FC<PopupContainerProps> = ({ model, resolve }) => {
<Divider style={{ margin: 0, marginTop: 4, borderBlockStartWidth: 0.5 }} /> <Divider style={{ margin: 0, marginTop: 4, borderBlockStartWidth: 0.5 }} />
{listItems.length > 0 ? ( {listItems.length > 0 ? (
<ListContainer onMouseMove={() => setIsMouseOver(true)}> <ListContainer onMouseMove={() => !isMouseOver && setIsMouseOver(true)}>
{/* Sticky Group Banner它会替换第一个分组名称 */} {/* Sticky Group Banner它会替换第一个分组名称 */}
<StickyGroupBanner>{stickyGroup?.name}</StickyGroupBanner> <StickyGroupBanner>{stickyGroup?.name}</StickyGroupBanner>
<FixedSizeList <FixedSizeList
@ -685,14 +627,26 @@ const PinIconWrapper = styled.div.attrs({ className: 'pin-icon' })<{ $isPinned?:
} }
` `
export default class SelectModelPopup { const TopViewKey = 'SelectModelPopup'
export class SelectModelPopup {
static topviewId = 0
static hide() { static hide() {
TopView.hide('SelectModelPopup') TopView.hide(TopViewKey)
} }
static show(params: Props) { static show(params: PopupParams) {
return new Promise<Model | undefined>((resolve) => { return new Promise<Model | undefined>((resolve) => {
TopView.show(<PopupContainer {...params} resolve={resolve} />, 'SelectModelPopup') TopView.show(
<PopupContainer
{...params}
resolve={(v) => {
resolve(v)
TopView.hide(TopViewKey)
}}
/>,
TopViewKey
)
}) })
} }
} }

View File

@ -0,0 +1,109 @@
import { ScrollAction, ScrollState } from './types'
/**
*
*/
export const initialScrollState: ScrollState = {
focusedItemKey: '',
scrollTrigger: 'initial',
lastScrollOffset: 0,
stickyGroup: null,
isMouseOver: false
}
/**
* reducer
* @param state
* @param action
* @returns
*/
export const scrollReducer = (state: ScrollState, action: ScrollAction): ScrollState => {
switch (action.type) {
case 'SET_FOCUSED_ITEM_KEY':
return { ...state, focusedItemKey: action.payload }
case 'SET_SCROLL_TRIGGER':
return { ...state, scrollTrigger: action.payload }
case 'SET_LAST_SCROLL_OFFSET':
return { ...state, lastScrollOffset: action.payload }
case 'SET_STICKY_GROUP':
return { ...state, stickyGroup: action.payload }
case 'SET_IS_MOUSE_OVER':
return { ...state, isMouseOver: action.payload }
case 'FOCUS_NEXT_ITEM': {
const { modelItems, step } = action.payload
if (modelItems.length === 0) {
return {
...state,
focusedItemKey: '',
scrollTrigger: 'keyboard'
}
}
const currentIndex = modelItems.findIndex((item) => item.key === state.focusedItemKey)
const nextIndex = (currentIndex < 0 ? 0 : currentIndex + step + modelItems.length) % modelItems.length
return {
...state,
focusedItemKey: modelItems[nextIndex].key,
scrollTrigger: 'keyboard'
}
}
case 'FOCUS_PAGE': {
const { modelItems, currentIndex, step } = action.payload
const nextIndex = Math.max(0, Math.min(currentIndex + step, modelItems.length - 1))
return {
...state,
focusedItemKey: modelItems.length > 0 ? modelItems[nextIndex].key : '',
scrollTrigger: 'keyboard'
}
}
case 'SEARCH_CHANGED':
return {
...state,
scrollTrigger: action.payload.searchText ? 'search' : 'initial'
}
case 'UPDATE_ON_LIST_CHANGE': {
const { modelItems } = action.payload
// 在列表变化时尝试聚焦一个模型:
// - 如果是 initial 状态,先尝试聚焦当前选中的模型
// - 如果是 search 状态,尝试聚焦第一个模型
let newFocusedKey = ''
if (state.scrollTrigger === 'initial' || state.scrollTrigger === 'search') {
const selectedItem = modelItems.find((item) => item.isSelected)
if (selectedItem && state.scrollTrigger === 'initial') {
newFocusedKey = selectedItem.key
} else if (modelItems.length > 0) {
newFocusedKey = modelItems[0].key
}
} else {
newFocusedKey = state.focusedItemKey
}
return {
...state,
focusedItemKey: newFocusedKey
}
}
case 'INIT_SCROLL':
return {
...state,
scrollTrigger: 'initial',
lastScrollOffset: 0
}
default:
return state
}
}

View File

@ -0,0 +1,42 @@
import { Model } from '@renderer/types'
import { ReactNode } from 'react'
// 列表项类型,组名也作为列表项
export type ListItemType = 'group' | 'model'
// 滚动触发来源类型
export type ScrollTrigger = 'initial' | 'search' | 'keyboard' | 'none'
// 扁平化列表项接口
export interface FlatListItem {
key: string
type: ListItemType
icon?: ReactNode
name: ReactNode
tags?: ReactNode
model?: Model
isPinned?: boolean
isSelected?: boolean
}
// 滚动和焦点相关的状态类型
export interface ScrollState {
focusedItemKey: string
scrollTrigger: ScrollTrigger
lastScrollOffset: number
stickyGroup: FlatListItem | null
isMouseOver: boolean
}
// 滚动和焦点相关的 action 类型
export type ScrollAction =
| { type: 'SET_FOCUSED_ITEM_KEY'; payload: string }
| { type: 'SET_SCROLL_TRIGGER'; payload: ScrollTrigger }
| { type: 'SET_LAST_SCROLL_OFFSET'; payload: number }
| { type: 'SET_STICKY_GROUP'; payload: FlatListItem | null }
| { type: 'SET_IS_MOUSE_OVER'; payload: boolean }
| { type: 'FOCUS_NEXT_ITEM'; payload: { modelItems: FlatListItem[]; step: number } }
| { type: 'FOCUS_PAGE'; payload: { modelItems: FlatListItem[]; currentIndex: number; step: number } }
| { type: 'SEARCH_CHANGED'; payload: { searchText: string } }
| { type: 'UPDATE_ON_LIST_CHANGE'; payload: { modelItems: FlatListItem[] } }
| { type: 'INIT_SCROLL'; payload?: void }

View File

@ -237,10 +237,10 @@ export const CLAUDE_SUPPORTED_WEBSEARCH_REGEX = new RegExp(
) )
export function isFunctionCallingModel(model: Model): boolean { export function isFunctionCallingModel(model: Model): boolean {
if (model.type?.includes('function_calling')) { if (!model) return false
return true if (model.type) {
} return model.type.includes('function_calling')
} else {
if (isEmbeddingModel(model)) { if (isEmbeddingModel(model)) {
return false return false
} }
@ -255,6 +255,7 @@ export function isFunctionCallingModel(model: Model): boolean {
return FUNCTION_CALLING_REGEX.test(model.id) return FUNCTION_CALLING_REGEX.test(model.id)
} }
}
export function getModelLogo(modelId: string) { export function getModelLogo(modelId: string) {
const isLight = true const isLight = true
@ -2188,7 +2189,9 @@ export function isEmbeddingModel(model: Model): boolean {
if (!model) { if (!model) {
return false return false
} }
if (model.type) {
return model.type.includes('embedding')
} else {
if (['anthropic'].includes(model?.provider)) { if (['anthropic'].includes(model?.provider)) {
return false return false
} }
@ -2201,7 +2204,8 @@ export function isEmbeddingModel(model: Model): boolean {
return false return false
} }
return EMBEDDING_REGEX.test(model.id) || model.type?.includes('embedding') || false return EMBEDDING_REGEX.test(model.id)
}
} }
export function isRerankModel(model: Model): boolean { export function isRerankModel(model: Model): boolean {
@ -2212,16 +2216,20 @@ export function isVisionModel(model: Model): boolean {
if (!model) { if (!model) {
return false return false
} }
if (model.type) {
return model.type.includes('vision')
} else {
// 新添字段 copilot-vision-request 后可使用 vision // 新添字段 copilot-vision-request 后可使用 vision
// if (model.provider === 'copilot') { // if (model.provider === 'copilot') {
// return false // return false
// } // }
if (model.provider === 'doubao') { if (model.provider === 'doubao') {
return VISION_REGEX.test(model.name) || model.type?.includes('vision') || false return VISION_REGEX.test(model.name)
} }
return VISION_REGEX.test(model.id) || model.type?.includes('vision') || false return VISION_REGEX.test(model.id)
}
} }
export function isOpenAIReasoningModel(model: Model): boolean { export function isOpenAIReasoningModel(model: Model): boolean {
@ -2355,9 +2363,11 @@ export function isReasoningModel(model?: Model): boolean {
if (!model) { if (!model) {
return false return false
} }
if (model.type) {
return model.type.includes('reasoning')
} else {
if (model.provider === 'doubao') { if (model.provider === 'doubao') {
return REASONING_REGEX.test(model.name) || model.type?.includes('reasoning') || false return REASONING_REGEX.test(model.name)
} }
if ( if (
@ -2371,7 +2381,8 @@ export function isReasoningModel(model?: Model): boolean {
return true return true
} }
return REASONING_REGEX.test(model.id) || model.type?.includes('reasoning') || false return REASONING_REGEX.test(model.id)
}
} }
export function isSupportedModel(model: OpenAI.Models.Model): boolean { export function isSupportedModel(model: OpenAI.Models.Model): boolean {
@ -2386,13 +2397,9 @@ export function isWebSearchModel(model: Model): boolean {
if (!model) { if (!model) {
return false return false
} }
if (model.type) { if (model.type) {
if (model.type.includes('web_search')) { return model.type.includes('web_search')
return true } else {
}
}
const provider = getProviderByModel(model) const provider = getProviderByModel(model)
if (!provider) { if (!provider) {
@ -2470,6 +2477,7 @@ export function isWebSearchModel(model: Model): boolean {
return false return false
} }
}
export function isGenerateImageModel(model: Model): boolean { export function isGenerateImageModel(model: Model): boolean {
if (!model) { if (!model) {

View File

@ -95,7 +95,8 @@ export function getProviderLogo(providerId: string) {
return PROVIDER_LOGO_MAP[providerId as keyof typeof PROVIDER_LOGO_MAP] return PROVIDER_LOGO_MAP[providerId as keyof typeof PROVIDER_LOGO_MAP]
} }
export const SUPPORTED_REANK_PROVIDERS = ['silicon', 'jina', 'voyageai', 'dashscope', 'aihubmix'] // export const SUPPORTED_REANK_PROVIDERS = ['silicon', 'jina', 'voyageai', 'dashscope', 'aihubmix']
export const NOT_SUPPORTED_REANK_PROVIDERS = ['ollama']
export const PROVIDER_CONFIG = { export const PROVIDER_CONFIG = {
openai: { openai: {

View File

@ -1,6 +1,8 @@
import BochaLogo from '@renderer/assets/images/search/bocha.webp'
import ExaLogo from '@renderer/assets/images/search/exa.png' import ExaLogo from '@renderer/assets/images/search/exa.png'
import SearxngLogo from '@renderer/assets/images/search/searxng.svg' import SearxngLogo from '@renderer/assets/images/search/searxng.svg'
import TavilyLogo from '@renderer/assets/images/search/tavily.png' import TavilyLogo from '@renderer/assets/images/search/tavily.png'
export function getWebSearchProviderLogo(providerId: string) { export function getWebSearchProviderLogo(providerId: string) {
switch (providerId) { switch (providerId) {
case 'tavily': case 'tavily':
@ -9,6 +11,8 @@ export function getWebSearchProviderLogo(providerId: string) {
return SearxngLogo return SearxngLogo
case 'exa': case 'exa':
return ExaLogo return ExaLogo
case 'bocha':
return BochaLogo
default: default:
return undefined return undefined
} }
@ -32,6 +36,12 @@ export const WEB_SEARCH_PROVIDER_CONFIG = {
apiKey: 'https://dashboard.exa.ai/api-keys' apiKey: 'https://dashboard.exa.ai/api-keys'
} }
}, },
bocha: {
websites: {
official: 'https://bochaai.com',
apiKey: 'https://open.bochaai.com/overview'
}
},
'local-google': { 'local-google': {
websites: { websites: {
official: 'https://www.google.com' official: 'https://www.google.com'

View File

@ -0,0 +1,63 @@
import db from '@renderer/databases'
import { getModelUniqId } from '@renderer/services/ModelService'
import { sortBy } from 'lodash'
import { useCallback, useEffect, useState } from 'react'
import { useProviders } from './useProvider'
export const usePinnedModels = () => {
const [pinnedModels, setPinnedModels] = useState<string[]>([])
const [loading, setLoading] = useState(true)
const { providers } = useProviders()
useEffect(() => {
const loadPinnedModels = async () => {
setLoading(true)
const setting = await db.settings.get('pinned:models')
const savedPinnedModels = setting?.value || []
// Filter out invalid pinned models
const allModelIds = providers.flatMap((p) => p.models || []).map((m) => getModelUniqId(m))
const validPinnedModels = savedPinnedModels.filter((id) => allModelIds.includes(id))
// Update storage if there were invalid models
if (validPinnedModels.length !== savedPinnedModels.length) {
await db.settings.put({ id: 'pinned:models', value: validPinnedModels })
}
setPinnedModels(sortBy(validPinnedModels))
setLoading(false)
}
loadPinnedModels().catch((error) => {
console.error('Failed to load pinned models', error)
setPinnedModels([])
setLoading(false)
})
}, [providers])
const updatePinnedModels = useCallback(async (models: string[]) => {
await db.settings.put({ id: 'pinned:models', value: models })
setPinnedModels(sortBy(models))
}, [])
/**
* Toggle a single pinned model
* @param modelId - The ID string of the model to toggle
*/
const togglePinnedModel = useCallback(
async (modelId: string) => {
try {
const newPinnedModels = pinnedModels.includes(modelId)
? pinnedModels.filter((id) => id !== modelId)
: [...pinnedModels, modelId]
await updatePinnedModels(newPinnedModels)
} catch (error) {
console.error('Failed to toggle pinned model', error)
}
},
[pinnedModels, updatePinnedModels]
)
return { pinnedModels, updatePinnedModels, togglePinnedModel, loading }
}

View File

@ -705,6 +705,7 @@
"pinned": "Pinned", "pinned": "Pinned",
"rerank_model": "Reordering Model", "rerank_model": "Reordering Model",
"rerank_model_support_provider": "Currently, the reordering model only supports some providers ({{provider}})", "rerank_model_support_provider": "Currently, the reordering model only supports some providers ({{provider}})",
"rerank_model_not_support_provider": "Currently, the reordering model does not support this provider ({{provider}})",
"rerank_model_tooltip": "Click the Manage button in Settings -> Model Services to add.", "rerank_model_tooltip": "Click the Manage button in Settings -> Model Services to add.",
"search": "Search models...", "search": "Search models...",
"stream_output": "Stream output", "stream_output": "Stream output",
@ -1343,6 +1344,7 @@
"providerPlaceholder": "Provider name", "providerPlaceholder": "Provider name",
"advancedSettings": "Advanced Settings" "advancedSettings": "Advanced Settings"
}, },
"messages.prompt": "Show prompt",
"messages.divider": "Show divider between messages", "messages.divider": "Show divider between messages",
"messages.grid_columns": "Message grid display columns", "messages.grid_columns": "Message grid display columns",
"messages.grid_popover_trigger": "Grid detail trigger", "messages.grid_popover_trigger": "Grid detail trigger",

View File

@ -719,7 +719,8 @@
"text": "テキスト", "text": "テキスト",
"vision": "画像", "vision": "画像",
"websearch": "ウェブ検索" "websearch": "ウェブ検索"
} },
"rerank_model_not_support_provider": "現在、並べ替えモデルはこのプロバイダー ({{provider}}) をサポートしていません。"
}, },
"navbar": { "navbar": {
"expand": "ダイアログを展開", "expand": "ダイアログを展開",
@ -1341,6 +1342,7 @@
"providerPlaceholder": "プロバイダー名", "providerPlaceholder": "プロバイダー名",
"advancedSettings": "詳細設定" "advancedSettings": "詳細設定"
}, },
"messages.prompt": "プロンプト表示",
"messages.divider": "メッセージ間に区切り線を表示", "messages.divider": "メッセージ間に区切り線を表示",
"messages.grid_columns": "メッセージグリッドの表示列数", "messages.grid_columns": "メッセージグリッドの表示列数",
"messages.grid_popover_trigger": "グリッド詳細トリガー", "messages.grid_popover_trigger": "グリッド詳細トリガー",

View File

@ -719,7 +719,8 @@
"text": "Текст", "text": "Текст",
"vision": "Визуальные", "vision": "Визуальные",
"websearch": "Веб-поисковые" "websearch": "Веб-поисковые"
} },
"rerank_model_not_support_provider": "В настоящее время модель переупорядочивания не поддерживает этого провайдера ({{provider}})"
}, },
"navbar": { "navbar": {
"expand": "Развернуть диалоговое окно", "expand": "Развернуть диалоговое окно",
@ -1341,6 +1342,7 @@
"providerPlaceholder": "Имя провайдера", "providerPlaceholder": "Имя провайдера",
"advancedSettings": "Расширенные настройки" "advancedSettings": "Расширенные настройки"
}, },
"messages.prompt": "Показывать подсказки",
"messages.divider": "Показывать разделитель между сообщениями", "messages.divider": "Показывать разделитель между сообщениями",
"messages.grid_columns": "Количество столбцов сетки сообщений", "messages.grid_columns": "Количество столбцов сетки сообщений",
"messages.grid_popover_trigger": "Триггер для отображения подробной информации в сетке", "messages.grid_popover_trigger": "Триггер для отображения подробной информации в сетке",

View File

@ -705,6 +705,7 @@
"pinned": "已固定", "pinned": "已固定",
"rerank_model": "重排模型", "rerank_model": "重排模型",
"rerank_model_support_provider": "目前重排序模型仅支持部分服务商 ({{provider}})", "rerank_model_support_provider": "目前重排序模型仅支持部分服务商 ({{provider}})",
"rerank_model_not_support_provider": "目前重排序模型不支持该服务商 ({{provider}})",
"rerank_model_tooltip": "在设置->模型服务中点击管理按钮添加", "rerank_model_tooltip": "在设置->模型服务中点击管理按钮添加",
"search": "搜索模型...", "search": "搜索模型...",
"stream_output": "流式输出", "stream_output": "流式输出",
@ -1343,6 +1344,7 @@
"providerPlaceholder": "提供者名称", "providerPlaceholder": "提供者名称",
"advancedSettings": "高级设置" "advancedSettings": "高级设置"
}, },
"messages.prompt": "提示词显示",
"messages.divider": "消息分割线", "messages.divider": "消息分割线",
"messages.grid_columns": "消息网格展示列数", "messages.grid_columns": "消息网格展示列数",
"messages.grid_popover_trigger": "网格详情触发", "messages.grid_popover_trigger": "网格详情触发",

View File

@ -719,7 +719,8 @@
"text": "文字", "text": "文字",
"vision": "視覺", "vision": "視覺",
"websearch": "網路搜尋" "websearch": "網路搜尋"
} },
"rerank_model_not_support_provider": "目前,重新排序模型不支援此提供者({{provider}}"
}, },
"navbar": { "navbar": {
"expand": "伸縮對話框", "expand": "伸縮對話框",
@ -1342,6 +1343,7 @@
"providerPlaceholder": "提供者名稱", "providerPlaceholder": "提供者名稱",
"advancedSettings": "高級設定" "advancedSettings": "高級設定"
}, },
"messages.prompt": "提示詞顯示",
"messages.divider": "訊息間顯示分隔線", "messages.divider": "訊息間顯示分隔線",
"messages.grid_columns": "訊息網格展示列數", "messages.grid_columns": "訊息網格展示列數",
"messages.grid_popover_trigger": "網格詳細資訊觸發", "messages.grid_popover_trigger": "網格詳細資訊觸發",

View File

@ -42,7 +42,7 @@ interface MessagesProps {
const Messages: React.FC<MessagesProps> = ({ assistant, topic, setActiveTopic }) => { const Messages: React.FC<MessagesProps> = ({ assistant, topic, setActiveTopic }) => {
const { t } = useTranslation() const { t } = useTranslation()
const { showTopics, topicPosition, showAssistants, messageNavigation } = useSettings() const { showPrompt, showTopics, topicPosition, showAssistants, messageNavigation } = useSettings()
const { updateTopic, addTopic } = useAssistant(assistant.id) const { updateTopic, addTopic } = useAssistant(assistant.id)
const dispatch = useAppDispatch() const dispatch = useAppDispatch()
const containerRef = useRef<HTMLDivElement>(null) const containerRef = useRef<HTMLDivElement>(null)
@ -254,7 +254,7 @@ const Messages: React.FC<MessagesProps> = ({ assistant, topic, setActiveTopic })
)} )}
</ScrollContainer> </ScrollContainer>
</InfiniteScroll> </InfiniteScroll>
<Prompt assistant={assistant} key={assistant.prompt} topic={topic} /> {showPrompt && <Prompt assistant={assistant} key={assistant.prompt} topic={topic} />}
</NarrowLayout> </NarrowLayout>
{messageNavigation === 'anchor' && <MessageAnchorLine messages={displayMessages} />} {messageNavigation === 'anchor' && <MessageAnchorLine messages={displayMessages} />}
{messageNavigation === 'buttons' && <ChatNavigation containerId="messages" />} {messageNavigation === 'buttons' && <ChatNavigation containerId="messages" />}

View File

@ -37,6 +37,7 @@ import {
setPasteLongTextThreshold, setPasteLongTextThreshold,
setRenderInputMessageAsMarkdown, setRenderInputMessageAsMarkdown,
setShowInputEstimatedTokens, setShowInputEstimatedTokens,
setShowPrompt,
setShowMessageDivider, setShowMessageDivider,
setShowTranslateConfirm, setShowTranslateConfirm,
setThoughtAutoCollapse setThoughtAutoCollapse
@ -76,6 +77,7 @@ const SettingsTab: FC<Props> = (props) => {
const dispatch = useAppDispatch() const dispatch = useAppDispatch()
const { const {
showPrompt,
showMessageDivider, showMessageDivider,
messageFont, messageFont,
showInputEstimatedTokens, showInputEstimatedTokens,
@ -282,6 +284,11 @@ const SettingsTab: FC<Props> = (props) => {
<SettingGroup> <SettingGroup>
<SettingSubtitle style={{ marginTop: 0 }}>{t('settings.messages.title')}</SettingSubtitle> <SettingSubtitle style={{ marginTop: 0 }}>{t('settings.messages.title')}</SettingSubtitle>
<SettingDivider /> <SettingDivider />
<SettingRow>
<SettingRowTitleSmall>{t('settings.messages.prompt')}</SettingRowTitleSmall>
<Switch size="small" checked={showPrompt} onChange={(checked) => dispatch(setShowPrompt(checked))} />
</SettingRow>
<SettingDivider />
<SettingRow> <SettingRow>
<SettingRowTitleSmall>{t('settings.messages.divider')}</SettingRowTitleSmall> <SettingRowTitleSmall>{t('settings.messages.divider')}</SettingRowTitleSmall>
<Switch <Switch

View File

@ -1,7 +1,8 @@
import { TopView } from '@renderer/components/TopView' import { TopView } from '@renderer/components/TopView'
import { DEFAULT_KNOWLEDGE_DOCUMENT_COUNT } from '@renderer/config/constant' import { DEFAULT_KNOWLEDGE_DOCUMENT_COUNT } from '@renderer/config/constant'
import { isEmbeddingModel, isRerankModel } from '@renderer/config/models' import { isEmbeddingModel, isRerankModel } from '@renderer/config/models'
import { SUPPORTED_REANK_PROVIDERS } from '@renderer/config/providers' import { NOT_SUPPORTED_REANK_PROVIDERS } from '@renderer/config/providers'
// import { SUPPORTED_REANK_PROVIDERS } from '@renderer/config/providers'
import { useKnowledgeBases } from '@renderer/hooks/useKnowledge' import { useKnowledgeBases } from '@renderer/hooks/useKnowledge'
import { useProviders } from '@renderer/hooks/useProvider' import { useProviders } from '@renderer/hooks/useProvider'
import { SettingHelpText } from '@renderer/pages/settings' import { SettingHelpText } from '@renderer/pages/settings'
@ -67,7 +68,7 @@ const PopupContainer: React.FC<Props> = ({ title, resolve }) => {
const rerankSelectOptions = providers const rerankSelectOptions = providers
.filter((p) => p.models.length > 0) .filter((p) => p.models.length > 0)
.filter((p) => SUPPORTED_REANK_PROVIDERS.includes(p.id)) .filter((p) => !NOT_SUPPORTED_REANK_PROVIDERS.includes(p.id))
.map((p) => ({ .map((p) => ({
label: p.isSystem ? t(`provider.${p.id}`) : p.name, label: p.isSystem ? t(`provider.${p.id}`) : p.name,
title: p.name, title: p.name,
@ -176,8 +177,8 @@ const PopupContainer: React.FC<Props> = ({ title, resolve }) => {
<Select style={{ width: '100%' }} options={rerankSelectOptions} placeholder={t('settings.models.empty')} /> <Select style={{ width: '100%' }} options={rerankSelectOptions} placeholder={t('settings.models.empty')} />
</Form.Item> </Form.Item>
<SettingHelpText style={{ marginTop: -15, marginBottom: 20 }}> <SettingHelpText style={{ marginTop: -15, marginBottom: 20 }}>
{t('models.rerank_model_support_provider', { {t('models.rerank_model_not_support_provider', {
provider: SUPPORTED_REANK_PROVIDERS.map((id) => t(`provider.${id}`)) provider: NOT_SUPPORTED_REANK_PROVIDERS.map((id) => t(`provider.${id}`))
})} })}
</SettingHelpText> </SettingHelpText>
<Form.Item <Form.Item

View File

@ -3,7 +3,8 @@ import { TopView } from '@renderer/components/TopView'
import { DEFAULT_KNOWLEDGE_DOCUMENT_COUNT } from '@renderer/config/constant' import { DEFAULT_KNOWLEDGE_DOCUMENT_COUNT } from '@renderer/config/constant'
import { getEmbeddingMaxContext } from '@renderer/config/embedings' import { getEmbeddingMaxContext } from '@renderer/config/embedings'
import { isEmbeddingModel, isRerankModel } from '@renderer/config/models' import { isEmbeddingModel, isRerankModel } from '@renderer/config/models'
import { SUPPORTED_REANK_PROVIDERS } from '@renderer/config/providers' import { NOT_SUPPORTED_REANK_PROVIDERS } from '@renderer/config/providers'
// import { SUPPORTED_REANK_PROVIDERS } from '@renderer/config/providers'
import { useKnowledge } from '@renderer/hooks/useKnowledge' import { useKnowledge } from '@renderer/hooks/useKnowledge'
import { useProviders } from '@renderer/hooks/useProvider' import { useProviders } from '@renderer/hooks/useProvider'
import { SettingHelpText } from '@renderer/pages/settings' import { SettingHelpText } from '@renderer/pages/settings'
@ -68,7 +69,7 @@ const PopupContainer: React.FC<Props> = ({ base: _base, resolve }) => {
const rerankSelectOptions = providers const rerankSelectOptions = providers
.filter((p) => p.models.length > 0) .filter((p) => p.models.length > 0)
.filter((p) => SUPPORTED_REANK_PROVIDERS.includes(p.id)) .filter((p) => !NOT_SUPPORTED_REANK_PROVIDERS.includes(p.id))
.map((p) => ({ .map((p) => ({
label: p.isSystem ? t(`provider.${p.id}`) : p.name, label: p.isSystem ? t(`provider.${p.id}`) : p.name,
title: p.name, title: p.name,
@ -157,8 +158,8 @@ const PopupContainer: React.FC<Props> = ({ base: _base, resolve }) => {
/> />
</Form.Item> </Form.Item>
<SettingHelpText style={{ marginTop: -15, marginBottom: 20 }}> <SettingHelpText style={{ marginTop: -15, marginBottom: 20 }}>
{t('models.rerank_model_support_provider', { {t('models.rerank_model_not_support_provider', {
provider: SUPPORTED_REANK_PROVIDERS.map((id) => t(`provider.${id}`)) provider: NOT_SUPPORTED_REANK_PROVIDERS.map((id) => t(`provider.${id}`))
})} })}
</SettingHelpText> </SettingHelpText>

View File

@ -132,7 +132,7 @@ const ModelEditContent: FC<ModelEditContentProps> = ({ model, onUpdateModel, ope
] as ModelType[] ] as ModelType[]
// 合并现有选择和默认类型 // 合并现有选择和默认类型
const selectedTypes = [...new Set([...(model.type || []), ...defaultTypes])] const selectedTypes = model.type ? model.type : defaultTypes
const showTypeConfirmModal = (type: string) => { const showTypeConfirmModal = (type: string) => {
window.modal.confirm({ window.modal.confirm({
@ -165,28 +165,23 @@ const ModelEditContent: FC<ModelEditContentProps> = ({ model, onUpdateModel, ope
options={[ options={[
{ {
label: t('models.type.vision'), label: t('models.type.vision'),
value: 'vision', value: 'vision'
disabled: isVisionModel(model) && !selectedTypes.includes('vision')
}, },
{ {
label: t('models.type.websearch'), label: t('models.type.websearch'),
value: 'web_search', value: 'web_search'
disabled: isWebSearchModel(model) && !selectedTypes.includes('web_search')
}, },
{ {
label: t('models.type.embedding'), label: t('models.type.embedding'),
value: 'embedding', value: 'embedding'
disabled: isEmbeddingModel(model) && !selectedTypes.includes('embedding')
}, },
{ {
label: t('models.type.reasoning'), label: t('models.type.reasoning'),
value: 'reasoning', value: 'reasoning'
disabled: isReasoningModel(model) && !selectedTypes.includes('reasoning')
}, },
{ {
label: t('models.type.function_calling'), label: t('models.type.function_calling'),
value: 'function_calling', value: 'function_calling'
disabled: isFunctionCallingModel(model) && !selectedTypes.includes('function_calling')
} }
]} ]}
/> />

View File

@ -192,14 +192,11 @@ const WebSearchProviderSetting: FC<Props> = ({ provider: _provider }) => {
onChange={(e) => setApiHost(e.target.value)} onChange={(e) => setApiHost(e.target.value)}
onBlur={onUpdateApiHost} onBlur={onUpdateApiHost}
/> />
<Button
ghost={apiValid}
type={apiValid ? 'primary' : 'default'}
onClick={checkSearch}
disabled={apiChecking}>
{apiChecking ? <LoadingOutlined spin /> : apiValid ? <CheckOutlined /> : t('settings.websearch.check')}
</Button>
</Flex> </Flex>
</>
)}
{hasObjectKey(provider, 'basicAuthUsername') && (
<>
<SettingDivider style={{ marginTop: 12, marginBottom: 12 }} /> <SettingDivider style={{ marginTop: 12, marginBottom: 12 }} />
<SettingSubtitle style={{ marginTop: 5, marginBottom: 10 }}> <SettingSubtitle style={{ marginTop: 5, marginBottom: 10 }}>
{t('settings.provider.basic_auth')} {t('settings.provider.basic_auth')}

View File

@ -402,7 +402,7 @@ export default class OpenAICompatibleProvider extends BaseOpenAiProvider {
await this.checkIsCopilot() await this.checkIsCopilot()
const lastUserMsg = userMessages.findLast((m) => m.role === 'user') const lastUserMsg = userMessages.findLast((m) => m.role === 'user')
if (lastUserMsg) { if (lastUserMsg && isSupportedThinkingTokenQwenModel(model)) {
const postsuffix = '/no_think' const postsuffix = '/no_think'
// qwenThinkMode === true 表示思考模式啓用,此時不應添加 /no_think如果存在則移除 // qwenThinkMode === true 表示思考模式啓用,此時不應添加 /no_think如果存在則移除
const qwenThinkModeEnabled = assistant.settings?.qwenThinkMode === true const qwenThinkModeEnabled = assistant.settings?.qwenThinkMode === true

View File

@ -4,18 +4,32 @@ import { WebSearchProvider, WebSearchProviderResponse } from '@renderer/types'
export default abstract class BaseWebSearchProvider { export default abstract class BaseWebSearchProvider {
// @ts-ignore this // @ts-ignore this
protected provider: WebSearchProvider protected provider: WebSearchProvider
protected apiHost?: string
protected apiKey: string protected apiKey: string
constructor(provider: WebSearchProvider) { constructor(provider: WebSearchProvider) {
this.provider = provider this.provider = provider
this.apiHost = this.getApiHost()
this.apiKey = this.getApiKey() this.apiKey = this.getApiKey()
} }
abstract search( abstract search(
query: string, query: string,
websearch: WebSearchState, websearch: WebSearchState,
httpOptions?: RequestInit httpOptions?: RequestInit
): Promise<WebSearchProviderResponse> ): Promise<WebSearchProviderResponse>
public getApiHost() {
return this.provider.apiHost
}
public defaultHeaders() {
return {
'HTTP-Referer': 'https://cherry-ai.com',
'X-Title': 'Cherry Studio'
}
}
public getApiKey() { public getApiKey() {
const keys = this.provider.apiKey?.split(',').map((key) => key.trim()) || [] const keys = this.provider.apiKey?.split(',').map((key) => key.trim()) || []
const keyName = `web-search-provider:${this.provider.id}:last_used_key` const keyName = `web-search-provider:${this.provider.id}:last_used_key`

View File

@ -0,0 +1,70 @@
import { WebSearchState } from '@renderer/store/websearch'
import { WebSearchProvider, WebSearchProviderResponse } from '@renderer/types'
import { BochaSearchParams, BochaSearchResponse } from '@renderer/types/bocha'
import BaseWebSearchProvider from './BaseWebSearchProvider'
export default class BochaProvider extends BaseWebSearchProvider {
constructor(provider: WebSearchProvider) {
super(provider)
if (!this.apiKey) {
throw new Error('API key is required for Bocha provider')
}
if (!this.apiHost) {
throw new Error('API host is required for Bocha provider')
}
}
public async search(query: string, websearch: WebSearchState): Promise<WebSearchProviderResponse> {
try {
if (!query.trim()) {
throw new Error('Search query cannot be empty')
}
const headers = {
'Content-Type': 'application/json',
Authorization: `Bearer ${this.apiKey}`
}
const contentLimit = websearch.contentLimit
const params: BochaSearchParams = {
query,
count: websearch.maxResults,
exclude: websearch.excludeDomains.join(','),
freshness: websearch.searchWithTime ? 'oneDay' : 'noLimit',
summary: false,
page: contentLimit ? Math.ceil(contentLimit / websearch.maxResults) : 1
}
const response = await fetch(`${this.apiHost}/v1/web-search`, {
method: 'POST',
body: JSON.stringify(params),
headers: {
...this.defaultHeaders(),
...headers
}
})
if (!response.ok) {
throw new Error(`Bocha search failed: ${response.status} ${response.statusText}`)
}
const resp: BochaSearchResponse = await response.json()
if (resp.code !== 200) {
throw new Error(`Bocha search failed: ${resp.msg}`)
}
return {
query: resp.data.queryContext.originalQuery,
results: resp.data.webPages.value.map((result) => ({
title: result.name,
content: result.snippet,
url: result.url
}))
}
} catch (error) {
console.error('Bocha search failed:', error)
throw new Error(`Search failed: ${error instanceof Error ? error.message : 'Unknown error'}`)
}
}
}

View File

@ -12,7 +12,10 @@ export default class ExaProvider extends BaseWebSearchProvider {
if (!this.apiKey) { if (!this.apiKey) {
throw new Error('API key is required for Exa provider') throw new Error('API key is required for Exa provider')
} }
this.exa = new ExaClient({ apiKey: this.apiKey }) if (!this.apiHost) {
throw new Error('API host is required for Exa provider')
}
this.exa = new ExaClient({ apiKey: this.apiKey, apiBaseUrl: this.apiHost })
} }
public async search(query: string, websearch: WebSearchState): Promise<WebSearchProviderResponse> { public async search(query: string, websearch: WebSearchState): Promise<WebSearchProviderResponse> {

View File

@ -10,7 +10,6 @@ import BaseWebSearchProvider from './BaseWebSearchProvider'
export default class SearxngProvider extends BaseWebSearchProvider { export default class SearxngProvider extends BaseWebSearchProvider {
private searxng: SearxngClient private searxng: SearxngClient
private engines: string[] = [] private engines: string[] = []
private readonly apiHost: string
private readonly basicAuthUsername?: string private readonly basicAuthUsername?: string
private readonly basicAuthPassword?: string private readonly basicAuthPassword?: string
private isInitialized = false private isInitialized = false

View File

@ -12,7 +12,10 @@ export default class TavilyProvider extends BaseWebSearchProvider {
if (!this.apiKey) { if (!this.apiKey) {
throw new Error('API key is required for Tavily provider') throw new Error('API key is required for Tavily provider')
} }
this.tvly = new TavilyClient({ apiKey: this.apiKey }) if (!this.apiHost) {
throw new Error('API host is required for Tavily provider')
}
this.tvly = new TavilyClient({ apiKey: this.apiKey, apiBaseUrl: this.apiHost })
} }
public async search(query: string, websearch: WebSearchState): Promise<WebSearchProviderResponse> { public async search(query: string, websearch: WebSearchState): Promise<WebSearchProviderResponse> {

View File

@ -1,6 +1,7 @@
import { WebSearchProvider } from '@renderer/types' import { WebSearchProvider } from '@renderer/types'
import BaseWebSearchProvider from './BaseWebSearchProvider' import BaseWebSearchProvider from './BaseWebSearchProvider'
import BochaProvider from './BochaProvider'
import DefaultProvider from './DefaultProvider' import DefaultProvider from './DefaultProvider'
import ExaProvider from './ExaProvider' import ExaProvider from './ExaProvider'
import LocalBaiduProvider from './LocalBaiduProvider' import LocalBaiduProvider from './LocalBaiduProvider'
@ -14,6 +15,8 @@ export default class WebSearchProviderFactory {
switch (provider.id) { switch (provider.id) {
case 'tavily': case 'tavily':
return new TavilyProvider(provider) return new TavilyProvider(provider)
case 'bocha':
return new BochaProvider(provider)
case 'searxng': case 'searxng':
return new SearxngProvider(provider) return new SearxngProvider(provider)
case 'exa': case 'exa':

View File

@ -125,6 +125,8 @@ async function fetchExternalTool(
return return
} }
if (extractResults.websearch.question[0] === 'not_needed') return
// Add check for assistant.model before using it // Add check for assistant.model before using it
if (!assistant.model) { if (!assistant.model) {
console.warn('searchTheWeb called without assistant.model') console.warn('searchTheWeb called without assistant.model')

View File

@ -106,7 +106,7 @@ class WebSearchService {
const webSearchEngine = new WebSearchEngineProvider(provider) const webSearchEngine = new WebSearchEngineProvider(provider)
let formattedQuery = query let formattedQuery = query
// 有待商榷,效果一般 // FIXME: 有待商榷,效果一般
if (websearch.searchWithTime) { if (websearch.searchWithTime) {
formattedQuery = `today is ${dayjs().format('YYYY-MM-DD')} \r\n ${query}` formattedQuery = `today is ${dayjs().format('YYYY-MM-DD')} \r\n ${query}`
} }

View File

@ -46,7 +46,7 @@ const persistedReducer = persistReducer(
{ {
key: 'cherry-studio', key: 'cherry-studio',
storage, storage,
version: 98, version: 99,
blacklist: ['runtime', 'messages', 'messageBlocks'], blacklist: ['runtime', 'messages', 'messageBlocks'],
migrate migrate
}, },

View File

@ -5,7 +5,7 @@ import { SYSTEM_MODELS } from '@renderer/config/models'
import { TRANSLATE_PROMPT } from '@renderer/config/prompts' import { TRANSLATE_PROMPT } from '@renderer/config/prompts'
import db from '@renderer/databases' import db from '@renderer/databases'
import i18n from '@renderer/i18n' import i18n from '@renderer/i18n'
import { Assistant } from '@renderer/types' import { Assistant, WebSearchProvider } from '@renderer/types'
import { getDefaultGroupName, getLeadingEmoji, runAsyncFunction, uuid } from '@renderer/utils' import { getDefaultGroupName, getLeadingEmoji, runAsyncFunction, uuid } from '@renderer/utils'
import { isEmpty } from 'lodash' import { isEmpty } from 'lodash'
import { createMigrate } from 'redux-persist' import { createMigrate } from 'redux-persist'
@ -64,6 +64,18 @@ function addWebSearchProvider(state: RootState, id: string) {
} }
} }
function updateWebSearchProvider(state: RootState, provider: Partial<WebSearchProvider>) {
if (state.websearch && state.websearch.providers) {
const index = state.websearch.providers.findIndex((p) => p.id === provider.id)
if (index !== -1) {
state.websearch.providers[index] = {
...state.websearch.providers[index],
...provider
}
}
}
}
const migrateConfig = { const migrateConfig = {
'2': (state: RootState) => { '2': (state: RootState) => {
try { try {
@ -1252,6 +1264,38 @@ const migrateConfig = {
} catch (error) { } catch (error) {
return state return state
} }
},
'99': (state: RootState) => {
try {
state.settings.showPrompt = true
addWebSearchProvider(state, 'bocha')
updateWebSearchProvider(state, {
id: 'exa',
apiHost: 'https://api.exa.ai'
})
updateWebSearchProvider(state, {
id: 'tavily',
apiHost: 'https://api.tavily.com'
})
// Remove basic auth fields from exa and tavily
if (state.websearch?.providers) {
state.websearch.providers = state.websearch.providers.map((provider) => {
if (provider.id === 'exa' || provider.id === 'tavily') {
// eslint-disable-next-line @typescript-eslint/no-unused-vars
const { basicAuthUsername, basicAuthPassword, ...rest } = provider
return rest
}
return provider
})
}
return state
} catch (error) {
return state
}
} }
} }

View File

@ -31,6 +31,7 @@ export interface SettingsState {
proxyMode: 'system' | 'custom' | 'none' proxyMode: 'system' | 'custom' | 'none'
proxyUrl?: string proxyUrl?: string
userName: string userName: string
showPrompt: boolean
showMessageDivider: boolean showMessageDivider: boolean
messageFont: 'system' | 'serif' messageFont: 'system' | 'serif'
showInputEstimatedTokens: boolean showInputEstimatedTokens: boolean
@ -143,6 +144,7 @@ export const initialState: SettingsState = {
proxyMode: 'system', proxyMode: 'system',
proxyUrl: undefined, proxyUrl: undefined,
userName: '', userName: '',
showPrompt: true,
showMessageDivider: true, showMessageDivider: true,
messageFont: 'system', messageFont: 'system',
showInputEstimatedTokens: false, showInputEstimatedTokens: false,
@ -272,6 +274,9 @@ const settingsSlice = createSlice({
setUserName: (state, action: PayloadAction<string>) => { setUserName: (state, action: PayloadAction<string>) => {
state.userName = action.payload state.userName = action.payload
}, },
setShowPrompt: (state, action: PayloadAction<boolean>) => {
state.showPrompt = action.payload
},
setShowMessageDivider: (state, action: PayloadAction<boolean>) => { setShowMessageDivider: (state, action: PayloadAction<boolean>) => {
state.showMessageDivider = action.payload state.showMessageDivider = action.payload
}, },
@ -528,6 +533,7 @@ export const {
setProxyMode, setProxyMode,
setProxyUrl, setProxyUrl,
setUserName, setUserName,
setShowPrompt,
setShowMessageDivider, setShowMessageDivider,
setMessageFont, setMessageFont,
setShowInputEstimatedTokens, setShowInputEstimatedTokens,

View File

@ -25,6 +25,8 @@ export interface WebSearchState {
/** @deprecated 支持在快捷菜单中自选搜索供应商,所以这个不再适用 */ /** @deprecated 支持在快捷菜单中自选搜索供应商,所以这个不再适用 */
overwrite: boolean overwrite: boolean
contentLimit?: number contentLimit?: number
// 具体供应商的配置
providerConfig: Record<string, any>
} }
const initialState: WebSearchState = { const initialState: WebSearchState = {
@ -33,16 +35,26 @@ const initialState: WebSearchState = {
{ {
id: 'tavily', id: 'tavily',
name: 'Tavily', name: 'Tavily',
apiHost: 'https://api.tavily.com',
apiKey: '' apiKey: ''
}, },
{ {
id: 'searxng', id: 'searxng',
name: 'Searxng', name: 'Searxng',
apiHost: '' apiHost: '',
basicAuthUsername: '',
basicAuthPassword: ''
}, },
{ {
id: 'exa', id: 'exa',
name: 'Exa', name: 'Exa',
apiHost: 'https://api.exa.ai',
apiKey: ''
},
{
id: 'bocha',
name: 'Bocha',
apiHost: 'https://api.bochaai.com',
apiKey: '' apiKey: ''
}, },
{ {
@ -65,7 +77,8 @@ const initialState: WebSearchState = {
maxResults: 5, maxResults: 5,
excludeDomains: [], excludeDomains: [],
subscribeSources: [], subscribeSources: [],
overwrite: false overwrite: false,
providerConfig: {}
} }
export const defaultWebSearchProviders = initialState.providers export const defaultWebSearchProviders = initialState.providers
@ -139,6 +152,12 @@ const websearchSlice = createSlice({
}, },
setContentLimit: (state, action: PayloadAction<number | undefined>) => { setContentLimit: (state, action: PayloadAction<number | undefined>) => {
state.contentLimit = action.payload state.contentLimit = action.payload
},
setProviderConfig: (state, action: PayloadAction<Record<string, any>>) => {
state.providerConfig = action.payload
},
updateProviderConfig: (state, action: PayloadAction<Record<string, any>>) => {
state.providerConfig = { ...state.providerConfig, ...action.payload }
} }
} }
}) })
@ -157,7 +176,9 @@ export const {
setSubscribeSources, setSubscribeSources,
setOverwrite, setOverwrite,
addWebSearchProvider, addWebSearchProvider,
setContentLimit setContentLimit,
setProviderConfig,
updateProviderConfig
} = websearchSlice.actions } = websearchSlice.actions
export default websearchSlice.reducer export default websearchSlice.reducer

View File

@ -0,0 +1,207 @@
import { z } from 'zod'
export const freshnessOptions = ['oneDay', 'oneWeek', 'oneMonth', 'oneYear', 'noLimit'] as const
const isValidDate = (dateStr: string): boolean => {
// First check basic format
if (!/^\d{4}-\d{2}-\d{2}$/.test(dateStr)) {
return false
}
const [year, month, day] = dateStr.split('-').map(Number)
if (year < 1900 || year > 2100) {
return false
}
// Check month range
if (month < 1 || month > 12) {
return false
}
// Get last day of the month
const lastDay = new Date(year, month, 0).getDate()
// Check day range
if (day < 1 || day > lastDay) {
return false
}
return true
}
const isValidDateRange = (dateRangeStr: string): boolean => {
// Check if it's a single date
if (/^\d{4}-\d{2}-\d{2}$/.test(dateRangeStr)) {
return isValidDate(dateRangeStr)
}
// Check if it's a date range
if (!/^\d{4}-\d{2}-\d{2}\.\.\d{4}-\d{2}-\d{2}$/.test(dateRangeStr)) {
return false
}
const [startDate, endDate] = dateRangeStr.split('..')
// Validate both dates
if (!isValidDate(startDate) || !isValidDate(endDate)) {
return false
}
// Check if start date is before or equal to end date
const start = new Date(startDate)
const end = new Date(endDate)
return start <= end
}
const isValidExcludeDomains = (excludeStr: string): boolean => {
if (!excludeStr) return true
// Split by either | or ,
const domains = excludeStr
.split(/[|,]/)
.map((d) => d.trim())
.filter(Boolean)
// Check number of domains
if (domains.length > 20) {
return false
}
// Domain name regex (supports both root domains and subdomains)
const domainRegex = /^(?:[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]{2,}$/
// Check each domain
return domains.every((domain) => domainRegex.test(domain))
}
const BochaSearchParamsSchema = z.object({
query: z.string(),
freshness: z
.union([
z.enum(freshnessOptions),
z
.string()
.regex(
/^(\d{4}-\d{2}-\d{2})(\.\.\d{4}-\d{2}-\d{2})?$/,
'Date must be in YYYY-MM-DD or YYYY-MM-DD..YYYY-MM-DD format'
)
.refine(isValidDateRange, {
message: 'Invalid date range - please provide valid dates in YYYY-MM-DD or YYYY-MM-DD..YYYY-MM-DD format'
})
])
.optional()
.default('noLimit'),
summary: z.boolean().optional().default(false),
exclude: z
.string()
.optional()
.refine((val) => !val || isValidExcludeDomains(val), {
message:
'Invalid exclude format. Please provide valid domain names separated by | or ,. Maximum 20 domains allowed.'
}),
page: z.number().optional().default(1),
count: z.number().optional().default(10)
})
const BochaSearchResponseDataSchema = z.object({
type: z.string(),
queryContext: z.object({
originalQuery: z.string()
}),
webPages: z.object({
webSearchUrl: z.string(),
totalEstimatedMatches: z.number(),
value: z.array(
z.object({
id: z.string(),
name: z.string(),
url: z.string(),
displayUrl: z.string(),
snippet: z.string(),
summary: z.string().optional(),
siteName: z.string(),
siteIcon: z.string(),
datePublished: z.string(),
dateLastCrawled: z.string(),
cachedPageUrl: z.string(),
language: z.string(),
isFamilyFriendly: z.boolean(),
isNavigational: z.boolean()
})
),
someResultsRemoved: z.boolean()
}),
images: z.object({
id: z.string(),
readLink: z.string(),
webSearchUrl: z.string(),
name: z.string(),
value: z.array(
z.object({
webSearchUrl: z.string(),
name: z.string(),
thumbnailUrl: z.string(),
datePublished: z.string(),
contentUrl: z.string(),
hostPageUrl: z.string(),
contentSize: z.string(),
encodingFormat: z.string(),
hostPageDisplayUrl: z.string(),
width: z.number(),
height: z.number(),
thumbnail: z.object({
width: z.number(),
height: z.number()
})
})
)
}),
videos: z.object({
id: z.string(),
readLink: z.string(),
webSearchUrl: z.string(),
isFamilyFriendly: z.boolean(),
scenario: z.string(),
name: z.string(),
value: z.array(
z.object({
webSearchUrl: z.string(),
name: z.string(),
description: z.string(),
thumbnailUrl: z.string(),
publisher: z.string(),
creator: z.string(),
contentUrl: z.string(),
hostPageUrl: z.string(),
encodingFormat: z.string(),
hostPageDisplayUrl: z.string(),
width: z.number(),
height: z.number(),
duration: z.number(),
motionThumbnailUrl: z.string(),
embedHtml: z.string(),
allowHttpsEmbed: z.boolean(),
viewCount: z.number(),
thumbnail: z.object({
width: z.number(),
height: z.number()
}),
allowMobileEmbed: z.boolean(),
isSuperfresh: z.boolean(),
datePublished: z.string()
})
)
})
})
const BochaSearchResponseSchema = z.object({
code: z.number(),
logId: z.string(),
data: BochaSearchResponseDataSchema,
msg: z.string().optional()
})
export type BochaSearchParams = z.infer<typeof BochaSearchParamsSchema>
export type BochaSearchResponse = z.infer<typeof BochaSearchResponseSchema>
export { BochaSearchParamsSchema, BochaSearchResponseSchema }