diff --git a/src/renderer/src/assets/images/apps/n8n.ico b/src/renderer/src/assets/images/apps/n8n.ico deleted file mode 100644 index 4df30bfeda..0000000000 Binary files a/src/renderer/src/assets/images/apps/n8n.ico and /dev/null differ diff --git a/src/renderer/src/assets/images/apps/n8n.svg b/src/renderer/src/assets/images/apps/n8n.svg new file mode 100644 index 0000000000..82f0a6da2e --- /dev/null +++ b/src/renderer/src/assets/images/apps/n8n.svg @@ -0,0 +1 @@ +n8n \ No newline at end of file diff --git a/src/renderer/src/assets/images/apps/paratera.ico b/src/renderer/src/assets/images/apps/paratera.ico deleted file mode 100644 index ff3958618c..0000000000 Binary files a/src/renderer/src/assets/images/apps/paratera.ico and /dev/null differ diff --git a/src/renderer/src/components/Popups/SelectModelPopup.tsx b/src/renderer/src/components/Popups/SelectModelPopup.tsx index c883e02ffa..c19c5def9c 100644 --- a/src/renderer/src/components/Popups/SelectModelPopup.tsx +++ b/src/renderer/src/components/Popups/SelectModelPopup.tsx @@ -59,9 +59,11 @@ const PopupContainer: React.FC = ({ model, resolve }) => { const [pinnedModels, setPinnedModels] = useState([]) const [_focusedItemKey, setFocusedItemKey] = useState('') const focusedItemKey = useDeferredValue(_focusedItemKey) - const [currentStickyGroup, setCurrentStickyGroup] = useState(null) + const [_stickyGroup, setStickyGroup] = useState(null) + const stickyGroup = useDeferredValue(_stickyGroup) const firstGroupRef = useRef(null) const scrollTriggerRef = useRef('initial') + const lastScrollOffsetRef = useRef(0) // Currently selected model ID const currentModelId = model ? getModelUniqId(model) : '' @@ -220,6 +222,45 @@ const PopupContainer: React.FC = ({ model, resolve }) => { return items }, [providers, getFilteredModels, pinnedModels, searchText, t, createModelItem]) + // Update the sticky group header based on the scroll position + const updateStickyGroup = useCallback( + (scrollOffset?: number) => { + if (listItems.length === 0) { + setStickyGroup(null) + return + } + + // Estimate the index of the first visible item from the scroll position + const estimatedIndex = Math.floor((scrollOffset ?? lastScrollOffsetRef.current) / ITEM_HEIGHT) + + // Search backwards from that index for the nearest group header + for (let i = estimatedIndex - 1; i >= 0; i--) { + if (i < listItems.length && listItems[i]?.type === 'group') { + setStickyGroup(listItems[i]) + return + } + } + + // Fall back to the first group header if none is found + setStickyGroup(firstGroupRef.current ?? null) + }, + [listItems] + ) + + // Update the sticky group when listItems changes + useEffect(() => { + updateStickyGroup() + }, [listItems, updateStickyGroup]) + + // Handle list scroll events: record lastScrollOffset and refresh the sticky group + const handleScroll = useCallback( + ({ scrollOffset }) => { + lastScrollOffsetRef.current = scrollOffset + updateStickyGroup(scrollOffset) + }, + [updateStickyGroup] + ) + // Get the selectable model items (group headers filtered out) const modelItems = useMemo(() => { return listItems.filter((item) => item.type === 'model') }, [listItems]) @@ -257,9 +298,6 @@ const PopupContainer: React.FC = ({ model, resolve }) => { const alignment = scrollTriggerRef.current === 'keyboard' ?
'auto' : 'center' listRef.current?.scrollToItem(index, alignment) - console.log('focusedItemKey', focusedItemKey) - console.log('scrollToFocusedItem', index, alignment) - - // Reset the scroll trigger after scrolling scrollTriggerRef.current = 'none' }, [focusedItemKey, listItems]) @@ -365,41 +403,19 @@ const PopupContainer: React.FC = ({ model, resolve }) => { if (!open) return setTimeout(() => inputRef.current?.focus(), 0) scrollTriggerRef.current = 'initial' + lastScrollOffsetRef.current = 0 }, [open]) - // Initialize the sticky group header - useEffect(() => { - if (firstGroupRef.current) { - setCurrentStickyGroup(firstGroupRef.current) - } - }, [listItems]) - - const handleItemsRendered = useCallback( - ({ visibleStartIndex }: { visibleStartIndex: number; visibleStopIndex: number }) => { - // Search backwards from the start of the visible area for the nearest group header - for (let i = visibleStartIndex - 1; i >= 0; i--) { - if (listItems[i]?.type === 'group') { - setCurrentStickyGroup(listItems[i]) - return - } - } - - // Fall back to the first group header if none is found - setCurrentStickyGroup(firstGroupRef.current ?? null) - }, - [listItems] - ) - const RowData = useMemo( (): VirtualizedRowData => ({ listItems, focusedItemKey, setFocusedItemKey, - currentStickyGroup, + stickyGroup, handleItemClick, togglePin }), - [currentStickyGroup, focusedItemKey, handleItemClick, listItems, togglePin] + [stickyGroup, focusedItemKey, handleItemClick, listItems, togglePin] ) const listHeight = useMemo(() => { @@ -456,7 +472,7 @@ const PopupContainer: React.FC = ({ model, resolve }) => { {listItems.length > 0 ? ( setIsMouseOver(true)}> {/* Sticky Group Banner, it replaces the first group name */} - {currentStickyGroup?.name} + {stickyGroup?.name} = ({ model, resolve }) => { itemData={RowData} itemKey={(index, data) => data.listItems[index].key} overscanCount={4} - onItemsRendered={handleItemsRendered} + onScroll={handleScroll} style={{ pointerEvents: isMouseOver ? 'auto' : 'none' }}> {VirtualizedRow} @@ -484,7 +500,7 @@ interface VirtualizedRowData { listItems: FlatListItem[] focusedItemKey: string setFocusedItemKey: (key: string) => void - currentStickyGroup: FlatListItem | null + stickyGroup: FlatListItem | null handleItemClick: (item: FlatListItem) => void togglePin: (modelId: string) => void } @@ -494,7 +510,7 @@ interface VirtualizedRowData { */ const VirtualizedRow = React.memo( ({ data, index, style }: { data: VirtualizedRowData; index: number; style: React.CSSProperties }) => { - const { listItems, focusedItemKey, setFocusedItemKey, handleItemClick, togglePin, currentStickyGroup } = data + const { listItems, focusedItemKey, setFocusedItemKey, handleItemClick, togglePin, stickyGroup } = data const item = listItems[index] return (
{item.type === 'group' ? ( - {item.name} + {item.name} ) : ( => { try { let content: string try { - content = await window.api.file.read('customMiniAPP') + content = await window.api.file.read('custom-minapps.json') } catch (error) { // 如果文件不存在,创建一个空的 JSON 数组 content = '[]' - await window.api.file.writeWithId('customMiniAPP', content) + await window.api.file.writeWithId('custom-minapps.json', content) } const customApps = JSON.parse(content) @@ -451,18 +450,15 @@ const ORIGIN_DEFAULT_MIN_APPS: MinAppType[] = [ padding: 10 } }, - { - id: 'paratera', - name: 'ParateraAI', - logo: ParateraLogo, - url: 'https://ai.paratera.com/' - }, { id: 'n8n', name: 'n8n', logo: n8nLogo, url: 'https://app.n8n.cloud/', - bodered: true + bodered: true, + style: { + padding: 5 + } } ] diff --git a/src/renderer/src/config/models.ts b/src/renderer/src/config/models.ts index 77d9efc615..65ed25a04f 100644 --- a/src/renderer/src/config/models.ts +++ b/src/renderer/src/config/models.ts @@ -2074,62 +2074,6 @@ export const SYSTEM_MODELS: Record = { name: 'Qwen2.5 72B Instruct', group: 'Qwen' } - ], - paratera: [ - { - id: 'GLM-Z1-Flash-P002', - provider: 'paratera', - name: 'GLM-Z1-Flash-P002', - group: 'GLM' - }, - { - id: 'GLM-Z1-AirX-P002', - provider: 'paratera', - name: 'GLM-Z1-AirX-P002', - group: 'GLM' - }, - { - id: 'DeepSeek-V3-250324-P001', - provider: 'paratera', - name: 'DeepSeek-V3-250324-P001', - group: 'DeepSeek' - }, - { - id: 'DeepSeek-R1', - provider: 'paratera', - name: 'DeepSeek-R1', - group: 'DeepSeek' - }, - { - id: 'QwQ-N011-32B', - provider: 'paratera', - name: 'QwQ-N011-32B', - group: 'Qwen' - }, - { - id: 'GLM-Embedding-2-P002', - provider: 'paratera', - name: 'GLM-Embedding-2-P002', - group: 'GLM' - }, - { - id: 'GLM-Embedding-3-P002', - provider: 'paratera', - name: 'GLM-Embedding-3-P002', - group: 'GLM' - }, - { - id: 'Doubao-Embedding-Text-P001', - provider: 'paratera', - name: 'Doubao-Embedding-Text-P001', - group: 'Doubao' - }, - { - id: 'Doubao-Embedding-Large-Text-P001', - provider: 'paratera', - name: 'Doubao-Embedding-Large-Text-P001', - group: 'Doubao' - } ] } diff --git a/src/renderer/src/config/providers.ts b/src/renderer/src/config/providers.ts index 48ca7c9cc7..81407a8ae2 100644 --- a/src/renderer/src/config/providers.ts +++ b/src/renderer/src/config/providers.ts @@ -42,7 +42,6 @@ import VoyageAIProviderLogo from '@renderer/assets/images/providers/voyageai.png import XirangProviderLogo from '@renderer/assets/images/providers/xirang.png' import ZeroOneProviderLogo from '@renderer/assets/images/providers/zero-one.png' import ZhipuProviderLogo from '@renderer/assets/images/providers/zhipu.png' -import ParateraLogo from '@renderer/assets/images/apps/paratera.ico' const PROVIDER_LOGO_MAP = { openai: OpenAiProviderLogo, @@ -89,8 +88,7 @@ const PROVIDER_LOGO_MAP = { gpustack: GPUStackProviderLogo, alayanew: AlayaNewProviderLogo, voyageai: VoyageAIProviderLogo, - qiniu: QiniuProviderLogo, - paratera: ParateraLogo + qiniu: QiniuProviderLogo } as const export function getProviderLogo(providerId: string) { @@ -585,16 +583,5 @@ export const PROVIDER_CONFIG = { docs: 'https://developer.qiniu.com/aitokenapi', models: 'https://developer.qiniu.com/aitokenapi/12883/model-list' } - }, - paratera: { - api: { - url: 'https://llmapi.paratera.com' - }, - websites: { - official: 'https://ai.paratera.com/', - apiKey: 'https://ai.paratera.com/#/lms/api', - docs: 'https://ai.paratera.com/document/llm/quickStart/useApi', - models: 'https://ai.paratera.com/#/lms/model' - } } } diff --git 
a/src/renderer/src/i18n/locales/en-us.json b/src/renderer/src/i18n/locales/en-us.json index dae4ed6ca5..3d8ab74b96 100644 --- a/src/renderer/src/i18n/locales/en-us.json +++ b/src/renderer/src/i18n/locales/en-us.json @@ -705,6 +705,7 @@ "rerank_model_tooltip": "Click the Manage button in Settings -> Model Services to add.", "search": "Search models...", "stream_output": "Stream output", + "enable_tool_use": "Enable Tool Use", "type": { "embedding": "Embedding", "free": "Free", diff --git a/src/renderer/src/i18n/locales/ja-jp.json b/src/renderer/src/i18n/locales/ja-jp.json index 046d2ff589..f1b12358d8 100644 --- a/src/renderer/src/i18n/locales/ja-jp.json +++ b/src/renderer/src/i18n/locales/ja-jp.json @@ -705,6 +705,7 @@ "rerank_model_tooltip": "設定->モデルサービスに移動し、管理ボタンをクリックして追加します。", "search": "モデルを検索...", "stream_output": "ストリーム出力", + "enable_tool_use": "ツール呼び出し", "type": { "embedding": "埋め込み", "free": "無料", diff --git a/src/renderer/src/i18n/locales/ru-ru.json b/src/renderer/src/i18n/locales/ru-ru.json index 41242b74c3..2692e1c270 100644 --- a/src/renderer/src/i18n/locales/ru-ru.json +++ b/src/renderer/src/i18n/locales/ru-ru.json @@ -705,6 +705,7 @@ "rerank_model_tooltip": "В настройках -> Служба модели нажмите кнопку \"Управление\", чтобы добавить.", "search": "Поиск моделей...", "stream_output": "Потоковый вывод", + "enable_tool_use": "Вызов инструмента", "type": { "embedding": "Встраиваемые", "free": "Бесплатные", diff --git a/src/renderer/src/i18n/locales/zh-cn.json b/src/renderer/src/i18n/locales/zh-cn.json index a57c106b49..e68409a642 100644 --- a/src/renderer/src/i18n/locales/zh-cn.json +++ b/src/renderer/src/i18n/locales/zh-cn.json @@ -705,6 +705,7 @@ "rerank_model_tooltip": "在设置->模型服务中点击管理按钮添加", "search": "搜索模型...", "stream_output": "流式输出", + "enable_tool_use": "工具调用", "type": { "embedding": "嵌入", "free": "免费", diff --git a/src/renderer/src/i18n/locales/zh-tw.json b/src/renderer/src/i18n/locales/zh-tw.json index d91d8fb257..893a7b75a1 100644 --- a/src/renderer/src/i18n/locales/zh-tw.json +++ b/src/renderer/src/i18n/locales/zh-tw.json @@ -705,6 +705,7 @@ "rerank_model_tooltip": "在設定->模型服務中點擊管理按鈕添加", "search": "搜尋模型...", "stream_output": "串流輸出", + "enable_tool_use": "工具調用", "type": { "embedding": "嵌入", "free": "免費", diff --git a/src/renderer/src/pages/apps/App.tsx b/src/renderer/src/pages/apps/App.tsx index cbf6ec75c6..c787a6ab1c 100644 --- a/src/renderer/src/pages/apps/App.tsx +++ b/src/renderer/src/pages/apps/App.tsx @@ -40,7 +40,7 @@ const App: FC = ({ app, onClick, size = 60, isLast }) => { const handleAddCustomApp = async (values: any) => { try { - const content = await window.api.file.read('customMiniAPP') + const content = await window.api.file.read('custom-minapps.json') const customApps = JSON.parse(content) // Check for duplicate ID @@ -62,7 +62,7 @@ const App: FC = ({ app, onClick, size = 60, isLast }) => { addTime: new Date().toISOString() } customApps.push(newApp) - await window.api.file.writeWithId('customMiniAPP', JSON.stringify(customApps, null, 2)) + await window.api.file.writeWithId('custom-minapps.json', JSON.stringify(customApps, null, 2)) message.success(t('settings.miniapps.custom.save_success')) setIsModalVisible(false) form.resetFields() @@ -138,10 +138,10 @@ const App: FC = ({ app, onClick, size = 60, isLast }) => { danger: true, onClick: async () => { try { - const content = await window.api.file.read('customMiniAPP') + const content = await window.api.file.read('custom-minapps.json') const customApps = JSON.parse(content) const updatedApps = 
customApps.filter((customApp: MinAppType) => customApp.id !== app.id) - await window.api.file.writeWithId('customMiniAPP', JSON.stringify(updatedApps, null, 2)) + await window.api.file.writeWithId('custom-minapps.json', JSON.stringify(updatedApps, null, 2)) message.success(t('settings.miniapps.custom.remove_success')) const reloadedApps = [...ORIGIN_DEFAULT_MIN_APPS, ...(await loadCustomMiniApp())] updateDefaultMinApps(reloadedApps) diff --git a/src/renderer/src/pages/home/Messages/CitationsList.tsx b/src/renderer/src/pages/home/Messages/CitationsList.tsx index 14f0b2b36b..c2ac673cff 100644 --- a/src/renderer/src/pages/home/Messages/CitationsList.tsx +++ b/src/renderer/src/pages/home/Messages/CitationsList.tsx @@ -91,7 +91,7 @@ const CitationsList: React.FC = ({ citations }) => { onClose={() => setOpen(false)} open={open} width={680} - styles={{ header: { border: 'none' }, body: { paddingTop: 0, backgroundColor: 'var(--color-background)' } }} + styles={{ header: { border: 'none' }, body: { paddingTop: 0 } }} destroyOnClose={false}> {open && citations.map((citation) => ( @@ -127,12 +127,12 @@ const WebSearchCitation: React.FC<{ citation: Citation }> = ({ citation }) => { }) return ( - handleLinkClick(citation.url)}> + {citation.showFavicon && citation.url && ( )} - + handleLinkClick(citation.url, e)}> {citation.title || {citation.hostname}} @@ -146,10 +146,12 @@ const WebSearchCitation: React.FC<{ citation: Citation }> = ({ citation }) => { } const KnowledgeCitation: React.FC<{ citation: Citation }> = ({ citation }) => ( - handleLinkClick(citation.url)}> + {citation.showFavicon && } - {citation.title} + handleLinkClick(citation.url, e)}> + {citation.title} + {citation.content && truncateText(citation.content, 100)} @@ -189,11 +191,15 @@ const PreviewIcon = styled.div` } ` -const CitationLink = styled.div` +const CitationLink = styled.a` font-size: 14px; line-height: 1.6; color: var(--color-text-1); text-decoration: none; + + .hostname { + color: var(--color-link); + } ` const WebSearchCard = styled.div` @@ -204,11 +210,6 @@ const WebSearchCard = styled.div` border-radius: var(--list-item-border-radius); background-color: var(--color-background); transition: all 0.3s ease; - cursor: pointer; - - &:hover { - background-color: var(--color-background-soft); - } ` const WebSearchCardHeader = styled.div` @@ -217,7 +218,6 @@ const WebSearchCardHeader = styled.div` align-items: center; gap: 8px; margin-bottom: 6px; - font-weight: 500; ` const WebSearchCardContent = styled.div` diff --git a/src/renderer/src/pages/home/Messages/MessageTools.tsx b/src/renderer/src/pages/home/Messages/MessageTools.tsx index 495c56cf10..b281e40642 100644 --- a/src/renderer/src/pages/home/Messages/MessageTools.tsx +++ b/src/renderer/src/pages/home/Messages/MessageTools.tsx @@ -67,7 +67,7 @@ const MessageTools: FC = ({ blocks }) => { const isDone = status === 'done' const hasError = isDone && response?.isError === true const result = { - params: tool.inputSchema, + params: toolResponse.arguments, response: toolResponse.response } diff --git a/src/renderer/src/pages/home/Tabs/SettingsTab.tsx b/src/renderer/src/pages/home/Tabs/SettingsTab.tsx index 9a21b1c808..0c8972db6d 100644 --- a/src/renderer/src/pages/home/Tabs/SettingsTab.tsx +++ b/src/renderer/src/pages/home/Tabs/SettingsTab.tsx @@ -70,6 +70,7 @@ const SettingsTab: FC = (props) => { const [maxTokens, setMaxTokens] = useState(assistant?.settings?.maxTokens ?? 
0) const [fontSizeValue, setFontSizeValue] = useState(fontSize) const [streamOutput, setStreamOutput] = useState(assistant?.settings?.streamOutput ?? true) + const [enableToolUse, setEnableToolUse] = useState(assistant?.settings?.enableToolUse ?? false) const { t } = useTranslation() const dispatch = useAppDispatch() @@ -222,6 +223,18 @@ const SettingsTab: FC = (props) => { /> + + {t('models.enable_tool_use')} + { + setEnableToolUse(checked) + updateAssistantSettings({ enableToolUse: checked }) + }} + /> + + diff --git a/src/renderer/src/pages/paintings/AihubmixPage.tsx b/src/renderer/src/pages/paintings/AihubmixPage.tsx index d9a6c2b485..aa52b0d560 100644 --- a/src/renderer/src/pages/paintings/AihubmixPage.tsx +++ b/src/renderer/src/pages/paintings/AihubmixPage.tsx @@ -1,4 +1,4 @@ -import { InfoCircleFilled, PlusOutlined, RedoOutlined } from '@ant-design/icons' +import { PlusOutlined, RedoOutlined } from '@ant-design/icons' import IcImageUp from '@renderer/assets/images/paintings/ic_ImageUp.svg' import { Navbar, NavbarCenter, NavbarRight } from '@renderer/components/app/Navbar' import { HStack } from '@renderer/components/Layout' @@ -20,6 +20,7 @@ import type { PaintingAction, PaintingsState } from '@renderer/types' import { getErrorMessage, uuid } from '@renderer/utils' import { Avatar, Button, Input, InputNumber, Radio, Segmented, Select, Slider, Switch, Tooltip, Upload } from 'antd' import TextArea from 'antd/es/input/TextArea' +import { Info } from 'lucide-react' import type { FC } from 'react' import { useEffect, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' @@ -72,6 +73,7 @@ const AihubmixPage: FC<{ Options: string[] }> = ({ Options }) => { { label: t('paintings.mode.remix'), value: 'remix' }, { label: t('paintings.mode.upscale'), value: 'upscale' } ] + const getNewPainting = () => { return { ...DEFAULT_PAINTING, @@ -278,14 +280,6 @@ const AihubmixPage: FC<{ Options: string[] }> = ({ Options }) => { } removePainting(mode, paintingToDelete) - - if (filteredPaintings.length === 1) { - const defaultPainting = { - ...DEFAULT_PAINTING, - id: uuid() - } - setPainting(defaultPainting) - } } const translate = async () => { @@ -334,6 +328,7 @@ const AihubmixPage: FC<{ Options: string[] }> = ({ Options }) => { navigate('../' + providerId, { replace: true }) } } + // 处理模式切换 const handleModeChange = (value: string) => { setMode(value as keyof PaintingsState) @@ -494,8 +489,9 @@ const AihubmixPage: FC<{ Options: string[] }> = ({ Options }) => { useEffect(() => { if (filteredPaintings.length === 0) { - addPainting(mode, getNewPainting()) - setPainting(DEFAULT_PAINTING) + const newPainting = getNewPainting() + addPainting(mode, newPainting) + setPainting(newPainting) } }, [filteredPaintings, mode, addPainting, painting]) @@ -674,11 +670,17 @@ const ToolbarMenu = styled.div` gap: 6px; ` -const InfoIcon = styled(InfoCircleFilled)` +const InfoIcon = styled(Info)` margin-left: 5px; cursor: help; - color: #8d94a6; - font-size: 12px; + color: var(--color-text-2); + opacity: 0.6; + width: 14px; + height: 16px; + + &:hover { + opacity: 1; + } ` const SliderContainer = styled.div` diff --git a/src/renderer/src/pages/paintings/PaintingsPage.tsx b/src/renderer/src/pages/paintings/PaintingsPage.tsx index fecfc9cd69..093105e08f 100644 --- a/src/renderer/src/pages/paintings/PaintingsPage.tsx +++ b/src/renderer/src/pages/paintings/PaintingsPage.tsx @@ -1,4 +1,4 @@ -import { PlusOutlined, QuestionCircleOutlined, RedoOutlined } from '@ant-design/icons' +import { 
PlusOutlined, RedoOutlined } from '@ant-design/icons' import ImageSize1_1 from '@renderer/assets/images/paintings/image-size-1-1.svg' import ImageSize1_2 from '@renderer/assets/images/paintings/image-size-1-2.svg' import ImageSize3_2 from '@renderer/assets/images/paintings/image-size-3-2.svg' @@ -26,6 +26,7 @@ import type { FileType, Painting } from '@renderer/types' import { getErrorMessage, uuid } from '@renderer/utils' import { Button, Input, InputNumber, Radio, Select, Slider, Switch, Tooltip } from 'antd' import TextArea from 'antd/es/input/TextArea' +import { Info } from 'lucide-react' import type { FC } from 'react' import { useEffect, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' @@ -90,7 +91,7 @@ const DEFAULT_PAINTING: Painting = { const PaintingsPage: FC<{ Options: string[] }> = ({ Options }) => { const { t } = useTranslation() const { paintings, addPainting, removePainting, updatePainting } = usePaintings() - const [painting, setPainting] = useState(DEFAULT_PAINTING) + const [painting, setPainting] = useState(paintings[0] || DEFAULT_PAINTING) const { theme } = useTheme() const providers = useAllProviders() const providerOptions = Options.map((option) => { @@ -260,10 +261,6 @@ const PaintingsPage: FC<{ Options: string[] }> = ({ Options }) => { } removePainting('paintings', paintingToDelete) - - if (paintings.length === 1) { - setPainting(getNewPainting()) - } } const onSelectPainting = (newPainting: Painting) => { @@ -326,8 +323,11 @@ const PaintingsPage: FC<{ Options: string[] }> = ({ Options }) => { useEffect(() => { if (paintings.length === 0) { - addPainting('paintings', getNewPainting()) + const newPainting = getNewPainting() + addPainting('paintings', newPainting) + setPainting(newPainting) } + return () => { if (spaceClickTimer.current) { clearTimeout(spaceClickTimer.current) @@ -602,11 +602,13 @@ const RadioButton = styled(Radio.Button)` align-items: center; ` -const InfoIcon = styled(QuestionCircleOutlined)` +const InfoIcon = styled(Info)` margin-left: 5px; cursor: help; color: var(--color-text-2); opacity: 0.6; + width: 16px; + height: 16px; &:hover { opacity: 1; diff --git a/src/renderer/src/pages/settings/AssistantSettings/AssistantModelSettings.tsx b/src/renderer/src/pages/settings/AssistantSettings/AssistantModelSettings.tsx index 31075ebbb7..9ea4559b47 100644 --- a/src/renderer/src/pages/settings/AssistantSettings/AssistantModelSettings.tsx +++ b/src/renderer/src/pages/settings/AssistantSettings/AssistantModelSettings.tsx @@ -24,6 +24,7 @@ const AssistantModelSettings: FC = ({ assistant, updateAssistant, updateA const [enableMaxTokens, setEnableMaxTokens] = useState(assistant?.settings?.enableMaxTokens ?? false) const [maxTokens, setMaxTokens] = useState(assistant?.settings?.maxTokens ?? 0) const [streamOutput, setStreamOutput] = useState(assistant?.settings?.streamOutput ?? true) + const [enableToolUse, setEnableToolUse] = useState(assistant?.settings?.enableToolUse ?? false) const [defaultModel, setDefaultModel] = useState(assistant?.defaultModel) const [topP, setTopP] = useState(assistant?.settings?.topP ?? 
1) const [customParameters, setCustomParameters] = useState( @@ -377,6 +378,18 @@ const AssistantModelSettings: FC = ({ assistant, updateAssistant, updateA /> + + + { + setEnableToolUse(checked) + updateAssistantSettings({ enableToolUse: checked }) + }} + /> + + - ) diff --git a/src/renderer/src/providers/AiProvider/AihubmixProvider.ts b/src/renderer/src/providers/AiProvider/AihubmixProvider.ts index 8961a050d4..83b5377b5c 100644 --- a/src/renderer/src/providers/AiProvider/AihubmixProvider.ts +++ b/src/renderer/src/providers/AiProvider/AihubmixProvider.ts @@ -1,6 +1,6 @@ import { isOpenAILLMModel } from '@renderer/config/models' import { getDefaultModel } from '@renderer/services/AssistantService' -import { Assistant, Model, Provider, Suggestion } from '@renderer/types' +import { Assistant, MCPCallToolResponse, MCPTool, MCPToolResponse, Model, Provider, Suggestion } from '@renderer/types' import { Message } from '@renderer/types/newMessage' import OpenAI from 'openai' @@ -18,6 +18,7 @@ import OpenAIProvider from './OpenAIProvider' export default class AihubmixProvider extends BaseProvider { private providers: Map = new Map() private defaultProvider: BaseProvider + private currentProvider: BaseProvider constructor(provider: Provider) { super(provider) @@ -30,6 +31,7 @@ export default class AihubmixProvider extends BaseProvider { // 设置默认提供商 this.defaultProvider = this.providers.get('default')! + this.currentProvider = this.defaultProvider } /** @@ -70,7 +72,8 @@ export default class AihubmixProvider extends BaseProvider { public async completions(params: CompletionsParams): Promise { const model = params.assistant.model - return this.getProvider(model!).completions(params) + this.currentProvider = this.getProvider(model!) + return this.currentProvider.completions(params) } public async translate( @@ -100,4 +103,12 @@ export default class AihubmixProvider extends BaseProvider { public async getEmbeddingDimensions(model: Model): Promise { return this.getProvider(model).getEmbeddingDimensions(model) } + + public convertMcpTools(mcpTools: MCPTool[]) { + return this.currentProvider.convertMcpTools(mcpTools) as T[] + } + + public mcpToolCallResponseToMessage(mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) { + return this.currentProvider.mcpToolCallResponseToMessage(mcpToolResponse, resp, model) + } } diff --git a/src/renderer/src/providers/AiProvider/AnthropicProvider.ts b/src/renderer/src/providers/AiProvider/AnthropicProvider.ts index d113d7f532..1f9fdfbe7e 100644 --- a/src/renderer/src/providers/AiProvider/AnthropicProvider.ts +++ b/src/renderer/src/providers/AiProvider/AnthropicProvider.ts @@ -1,15 +1,19 @@ import Anthropic from '@anthropic-ai/sdk' import { + Base64ImageSource, + ImageBlockParam, MessageCreateParamsNonStreaming, MessageParam, TextBlockParam, + ToolResultBlockParam, ToolUnion, + ToolUseBlock, WebSearchResultBlock, WebSearchTool20250305, WebSearchToolResultError } from '@anthropic-ai/sdk/resources' import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant' -import { isReasoningModel, isVisionModel, isWebSearchModel } from '@renderer/config/models' +import { isReasoningModel, isWebSearchModel } from '@renderer/config/models' import { getStoreSetting } from '@renderer/hooks/useSettings' import i18n from '@renderer/i18n' import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService' @@ -23,16 +27,24 @@ import { Assistant, EFFORT_RATIO, FileTypes, + MCPCallToolResponse, + MCPTool, MCPToolResponse, Model, 
Provider, Suggestion, + ToolCallResponse, WebSearchSource } from '@renderer/types' import { ChunkType } from '@renderer/types/chunk' import type { Message } from '@renderer/types/newMessage' import { removeSpecialCharactersForTopicName } from '@renderer/utils' -import { mcpToolCallResponseToAnthropicMessage, parseAndCallTools } from '@renderer/utils/mcp-tools' +import { + anthropicToolUseToMcpTool, + mcpToolCallResponseToAnthropicMessage, + mcpToolsToAnthropicTools, + parseAndCallTools +} from '@renderer/utils/mcp-tools' import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find' import { buildSystemPrompt } from '@renderer/utils/prompt' import { first, flatten, sum, takeRight } from 'lodash' @@ -199,7 +211,7 @@ export default class AnthropicProvider extends BaseProvider { public async completions({ messages, assistant, mcpTools, onChunk, onFilterMessages }: CompletionsParams) { const defaultModel = getDefaultModel() const model = assistant.model || defaultModel - const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant) + const { contextCount, maxTokens, streamOutput, enableToolUse } = getAssistantSettings(assistant) const userMessagesParams: MessageParam[] = [] @@ -215,10 +227,16 @@ export default class AnthropicProvider extends BaseProvider { const userMessages = flatten(userMessagesParams) const lastUserMessage = _messages.findLast((m) => m.role === 'user') - // const tools = mcpTools ? mcpToolsToAnthropicTools(mcpTools) : undefined let systemPrompt = assistant.prompt - if (mcpTools && mcpTools.length > 0) { + + const { tools } = this.setupToolsConfig({ + model, + mcpTools, + enableToolUse + }) + + if (this.useSystemPromptForTools && mcpTools && mcpTools.length) { systemPrompt = buildSystemPrompt(systemPrompt, mcpTools) } @@ -232,8 +250,6 @@ export default class AnthropicProvider extends BaseProvider { const isEnabledBuiltinWebSearch = assistant.enableWebSearch - const tools: ToolUnion[] = [] - if (isEnabledBuiltinWebSearch) { const webSearchTool = await this.getWebSearchParams(model) if (webSearchTool) { @@ -244,7 +260,6 @@ export default class AnthropicProvider extends BaseProvider { const body: MessageCreateParamsNonStreaming = { model: model.id, messages: userMessages, - // tools: isEmpty(tools) ? 
undefined : tools, max_tokens: maxTokens || DEFAULT_MAX_TOKENS, temperature: this.getTemperature(assistant, model), top_p: this.getTopP(assistant, model), @@ -303,7 +318,7 @@ export default class AnthropicProvider extends BaseProvider { const processStream = (body: MessageCreateParamsNonStreaming, idx: number) => { return new Promise((resolve, reject) => { // 等待接口返回流 - onChunk({ type: ChunkType.LLM_RESPONSE_CREATED }) + const toolCalls: ToolUseBlock[] = [] let hasThinkingContent = false this.sdk.messages .stream({ ...body, stream: true }, { signal, timeout: 5 * 60 * 1000 }) @@ -380,30 +395,70 @@ export default class AnthropicProvider extends BaseProvider { }) thinking_content += thinking }) + .on('contentBlock', (content) => { + if (content.type === 'tool_use') { + toolCalls.push(content) + } + }) .on('finalMessage', async (message) => { + const toolResults: Awaited> = [] + // tool call + if (toolCalls.length > 0) { + const mcpToolResponses = toolCalls + .map((toolCall) => { + const mcpTool = anthropicToolUseToMcpTool(mcpTools, toolCall) + if (!mcpTool) { + return undefined + } + return { + id: toolCall.id, + toolCallId: toolCall.id, + tool: mcpTool, + arguments: toolCall.input as Record, + status: 'pending' + } as ToolCallResponse + }) + .filter((t) => typeof t !== 'undefined') + toolResults.push( + ...(await parseAndCallTools( + mcpToolResponses, + toolResponses, + onChunk, + this.mcpToolCallResponseToMessage, + model, + mcpTools + )) + ) + } + + // tool use const content = message.content[0] if (content && content.type === 'text') { onChunk({ type: ChunkType.TEXT_COMPLETE, text: content.text }) - const toolResults = await parseAndCallTools( - content.text, - toolResponses, - onChunk, - idx, - mcpToolCallResponseToAnthropicMessage, - mcpTools, - isVisionModel(model) + toolResults.push( + ...(await parseAndCallTools( + content.text, + toolResponses, + onChunk, + this.mcpToolCallResponseToMessage, + model, + mcpTools + )) ) - if (toolResults.length > 0) { - userMessages.push({ - role: message.role, - content: message.content - }) + } - toolResults.forEach((ts) => userMessages.push(ts as MessageParam)) - const newBody = body - newBody.messages = userMessages - await processStream(newBody, idx + 1) - } + userMessages.push({ + role: message.role, + content: message.content + }) + + if (toolResults.length > 0) { + toolResults.forEach((ts) => userMessages.push(ts as MessageParam)) + const newBody = body + newBody.messages = userMessages + + onChunk({ type: ChunkType.LLM_RESPONSE_CREATED }) + await processStream(newBody, idx + 1) } const time_completion_millsec = new Date().getTime() - start_time_millsec @@ -434,7 +489,7 @@ export default class AnthropicProvider extends BaseProvider { }) }) } - + onChunk({ type: ChunkType.LLM_RESPONSE_CREATED }) await processStream(body, 0).finally(cleanup) } @@ -683,4 +738,47 @@ export default class AnthropicProvider extends BaseProvider { public async getEmbeddingDimensions(): Promise { return 0 } + + public convertMcpTools(mcpTools: MCPTool[]): T[] { + return mcpToolsToAnthropicTools(mcpTools) as T[] + } + + public mcpToolCallResponseToMessage = (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => { + if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) { + return mcpToolCallResponseToAnthropicMessage(mcpToolResponse, resp, model) + } else if ('toolCallId' in mcpToolResponse) { + return { + role: 'user', + content: [ + { + type: 'tool_result', + tool_use_id: mcpToolResponse.toolCallId!, + content: resp.content + 
.map((item) => { + if (item.type === 'text') { + return { + type: 'text', + text: item.text || '' + } satisfies TextBlockParam + } + if (item.type === 'image') { + return { + type: 'image', + source: { + data: item.data || '', + media_type: (item.mimeType || 'image/png') as Base64ImageSource['media_type'], + type: 'base64' + } + } satisfies ImageBlockParam + } + return + }) + .filter((n) => typeof n !== 'undefined'), + is_error: resp.isError + } satisfies ToolResultBlockParam + ] + } + } + return + } } diff --git a/src/renderer/src/providers/AiProvider/BaseProvider.ts b/src/renderer/src/providers/AiProvider/BaseProvider.ts index f39f09b20c..5af351aa57 100644 --- a/src/renderer/src/providers/AiProvider/BaseProvider.ts +++ b/src/renderer/src/providers/AiProvider/BaseProvider.ts @@ -1,9 +1,13 @@ +import { isFunctionCallingModel } from '@renderer/config/models' import { REFERENCE_PROMPT } from '@renderer/config/prompts' import { getLMStudioKeepAliveTime } from '@renderer/hooks/useLMStudio' import type { Assistant, GenerateImageParams, KnowledgeReference, + MCPCallToolResponse, + MCPTool, + MCPToolResponse, Model, Provider, Suggestion, @@ -22,10 +26,15 @@ import type OpenAI from 'openai' import type { CompletionsParams } from '.' export default abstract class BaseProvider { + // Threshold for determining whether to use system prompt for tools + private static readonly SYSTEM_PROMPT_THRESHOLD: number = 128 + protected provider: Provider protected host: string protected apiKey: string + protected useSystemPromptForTools: boolean = true + constructor(provider: Provider) { this.provider = provider this.host = this.getBaseURL() @@ -47,6 +56,12 @@ export default abstract class BaseProvider { abstract generateImage(params: GenerateImageParams): Promise abstract generateImageByChat({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams): Promise abstract getEmbeddingDimensions(model: Model): Promise + public abstract convertMcpTools(mcpTools: MCPTool[]): T[] + public abstract mcpToolCallResponseToMessage( + mcpToolResponse: MCPToolResponse, + resp: MCPCallToolResponse, + model: Model + ): any public getBaseURL(): string { const host = this.provider.apiHost @@ -229,4 +244,31 @@ export default abstract class BaseProvider { cleanup } } + + // Setup tools configuration based on provided parameters + protected setupToolsConfig(params: { mcpTools?: MCPTool[]; model: Model; enableToolUse?: boolean }): { + tools: T[] + } { + const { mcpTools, model, enableToolUse } = params + let tools: T[] = [] + + // If there are no tools, return an empty array + if (!mcpTools?.length) { + return { tools } + } + + // If the number of tools exceeds the threshold, use the system prompt + if (mcpTools.length > BaseProvider.SYSTEM_PROMPT_THRESHOLD) { + this.useSystemPromptForTools = true + return { tools } + } + + // If the model supports function calling and tool usage is enabled + if (isFunctionCallingModel(model) && enableToolUse) { + tools = this.convertMcpTools(mcpTools) + this.useSystemPromptForTools = false + } + + return { tools } + } } diff --git a/src/renderer/src/providers/AiProvider/GeminiProvider.ts b/src/renderer/src/providers/AiProvider/GeminiProvider.ts index 57a50cdca4..9ccef9b28d 100644 --- a/src/renderer/src/providers/AiProvider/GeminiProvider.ts +++ b/src/renderer/src/providers/AiProvider/GeminiProvider.ts @@ -1,6 +1,7 @@ import { Content, File, + FunctionCall, GenerateContentConfig, GenerateContentResponse, GoogleGenAI, @@ -11,8 +12,9 @@ import { PartUnion, SafetySetting, ThinkingConfig, - 
ToolListUnion + Tool } from '@google/genai' +import { nanoid } from '@reduxjs/toolkit' import { findTokenLimit, isGeminiReasoningModel, @@ -35,17 +37,25 @@ import { EFFORT_RATIO, FileType, FileTypes, + MCPCallToolResponse, + MCPTool, MCPToolResponse, Model, Provider, Suggestion, + ToolCallResponse, Usage, WebSearchSource } from '@renderer/types' import { BlockCompleteChunk, Chunk, ChunkType, LLMWebSearchCompleteChunk } from '@renderer/types/chunk' import type { Message, Response } from '@renderer/types/newMessage' import { removeSpecialCharactersForTopicName } from '@renderer/utils' -import { mcpToolCallResponseToGeminiMessage, parseAndCallTools } from '@renderer/utils/mcp-tools' +import { + geminiFunctionCallToMcpTool, + mcpToolCallResponseToGeminiMessage, + mcpToolsToGeminiTools, + parseAndCallTools +} from '@renderer/utils/mcp-tools' import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find' import { buildSystemPrompt } from '@renderer/utils/prompt' import { MB } from '@shared/config/constant' @@ -263,7 +273,7 @@ export default class GeminiProvider extends BaseProvider { }: CompletionsParams): Promise { const defaultModel = getDefaultModel() const model = assistant.model || defaultModel - const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant) + const { contextCount, maxTokens, streamOutput, enableToolUse } = getAssistantSettings(assistant) const userMessages = filterUserRoleStartMessages( filterEmptyMessages(filterContextMessages(takeRight(messages, contextCount + 2))) @@ -280,12 +290,16 @@ export default class GeminiProvider extends BaseProvider { let systemInstruction = assistant.prompt - if (mcpTools && mcpTools.length > 0) { + const { tools } = this.setupToolsConfig({ + mcpTools, + model, + enableToolUse + }) + + if (this.useSystemPromptForTools) { systemInstruction = buildSystemPrompt(assistant.prompt || '', mcpTools) } - // const tools = mcpToolsToGeminiTools(mcpTools) - const tools: ToolListUnion = [] const toolResponses: MCPToolResponse[] = [] if (assistant.enableWebSearch && isWebSearchModel(model)) { @@ -351,6 +365,224 @@ export default class GeminiProvider extends BaseProvider { const { cleanup, abortController } = this.createAbortController(userLastMessage?.id, true) + const processToolResults = async (toolResults: Awaited>, idx: number) => { + if (toolResults.length === 0) return + const newChat = this.sdk.chats.create({ + model: model.id, + config: generateContentConfig, + history: history as Content[] + }) + + const newStream = await newChat.sendMessageStream({ + message: flatten(toolResults.map((ts) => (ts as Content).parts)) as PartUnion, + config: { + ...generateContentConfig, + abortSignal: abortController.signal + } + }) + await processStream(newStream, idx + 1) + } + + const processToolCalls = async (toolCalls: FunctionCall[]) => { + const mcpToolResponses: ToolCallResponse[] = toolCalls + .map((toolCall) => { + const mcpTool = geminiFunctionCallToMcpTool(mcpTools, toolCall) + if (!mcpTool) return undefined + + const parsedArgs = (() => { + try { + return typeof toolCall.args === 'string' ? 
JSON.parse(toolCall.args) : toolCall.args + } catch { + return toolCall.args + } + })() + + return { + id: toolCall.id || nanoid(), + toolCallId: toolCall.id, + tool: mcpTool, + arguments: parsedArgs, + status: 'pending' + } as ToolCallResponse + }) + .filter((t): t is ToolCallResponse => typeof t !== 'undefined') + + return await parseAndCallTools( + mcpToolResponses, + toolResponses, + onChunk, + this.mcpToolCallResponseToMessage, + model, + mcpTools + ) + } + + const processToolUses = async (content: string) => { + return await parseAndCallTools( + content, + toolResponses, + onChunk, + this.mcpToolCallResponseToMessage, + model, + mcpTools + ) + } + + const processStream = async ( + stream: AsyncGenerator | GenerateContentResponse, + idx: number + ) => { + history.push(messageContents) + + let functionCalls: FunctionCall[] = [] + + if (stream instanceof GenerateContentResponse) { + let content = '' + const time_completion_millsec = new Date().getTime() - start_time_millsec + + const toolResults: Awaited> = [] + if (stream.text?.length) { + toolResults.push(...(await processToolUses(stream.text))) + } + stream.candidates?.forEach((candidate) => { + if (candidate.content) { + history.push(candidate.content) + + candidate.content.parts?.forEach((part) => { + if (part.functionCall) { + functionCalls.push(part.functionCall) + } + if (part.text) { + content += part.text + onChunk({ type: ChunkType.TEXT_DELTA, text: part.text }) + } + }) + } + }) + if (content.length) { + onChunk({ type: ChunkType.TEXT_COMPLETE, text: content }) + } + if (functionCalls.length) { + toolResults.push(...(await processToolCalls(functionCalls))) + } + if (stream.text?.length) { + toolResults.push(...(await processToolUses(stream.text))) + } + if (toolResults.length) { + await processToolResults(toolResults, idx) + } + onChunk({ + type: ChunkType.BLOCK_COMPLETE, + response: { + text: stream.text, + usage: { + prompt_tokens: stream.usageMetadata?.promptTokenCount || 0, + thoughts_tokens: stream.usageMetadata?.thoughtsTokenCount || 0, + completion_tokens: stream.usageMetadata?.candidatesTokenCount || 0, + total_tokens: stream.usageMetadata?.totalTokenCount || 0 + }, + metrics: { + completion_tokens: stream.usageMetadata?.candidatesTokenCount, + time_completion_millsec, + time_first_token_millsec: 0 + }, + webSearch: { + results: stream.candidates?.[0]?.groundingMetadata, + source: 'gemini' + } + } as Response + } as BlockCompleteChunk) + } else { + let content = '' + let final_time_completion_millsec = 0 + let lastUsage: Usage | undefined = undefined + for await (const chunk of stream) { + if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break + + // --- Calculate Metrics --- + if (time_first_token_millsec == 0 && chunk.text !== undefined) { + // Update based on text arrival + time_first_token_millsec = new Date().getTime() - start_time_millsec + } + + // 1. Text Content + if (chunk.text !== undefined) { + content += chunk.text + onChunk({ type: ChunkType.TEXT_DELTA, text: chunk.text }) + } + + // 2. Usage Data + if (chunk.usageMetadata) { + lastUsage = { + prompt_tokens: chunk.usageMetadata.promptTokenCount || 0, + completion_tokens: chunk.usageMetadata.candidatesTokenCount || 0, + total_tokens: chunk.usageMetadata.totalTokenCount || 0 + } + final_time_completion_millsec = new Date().getTime() - start_time_millsec + } + + // 4. 
Image Generation + const generateImage = this.processGeminiImageResponse(chunk, onChunk) + if (generateImage?.images?.length) { + onChunk({ type: ChunkType.IMAGE_COMPLETE, image: generateImage }) + } + + if (chunk.candidates?.[0]?.finishReason) { + if (chunk.text) { + onChunk({ type: ChunkType.TEXT_COMPLETE, text: content }) + } + if (chunk.candidates?.[0]?.groundingMetadata) { + // 3. Grounding/Search Metadata + const groundingMetadata = chunk.candidates?.[0]?.groundingMetadata + onChunk({ + type: ChunkType.LLM_WEB_SEARCH_COMPLETE, + llm_web_search: { + results: groundingMetadata, + source: WebSearchSource.GEMINI + } + } as LLMWebSearchCompleteChunk) + } + if (chunk.functionCalls) { + chunk.candidates?.forEach((candidate) => { + if (candidate.content) { + history.push(candidate.content) + } + }) + functionCalls = functionCalls.concat(chunk.functionCalls) + } + + onChunk({ + type: ChunkType.BLOCK_COMPLETE, + response: { + metrics: { + completion_tokens: lastUsage?.completion_tokens, + time_completion_millsec: final_time_completion_millsec, + time_first_token_millsec + }, + usage: lastUsage + } + }) + } + + // --- End Incremental onChunk calls --- + + // Call processToolUses AFTER potentially processing text content in this chunk + // This assumes tools might be specified within the text stream + // Note: parseAndCallTools inside should handle its own onChunk for tool responses + let toolResults: Awaited> = [] + if (functionCalls.length) { + toolResults = await processToolCalls(functionCalls) + } + if (content.length) { + toolResults = toolResults.concat(await processToolUses(content)) + } + if (toolResults.length) { + await processToolResults(toolResults, idx) + } + } + } + } + if (!streamOutput) { const response = await chat.sendMessage({ message: messageContents as PartUnion, @@ -359,32 +591,10 @@ export default class GeminiProvider extends BaseProvider { abortSignal: abortController.signal } }) - const time_completion_millsec = new Date().getTime() - start_time_millsec - onChunk({ - type: ChunkType.BLOCK_COMPLETE, - response: { - text: response.text, - usage: { - prompt_tokens: response.usageMetadata?.promptTokenCount || 0, - thoughts_tokens: response.usageMetadata?.thoughtsTokenCount || 0, - completion_tokens: response.usageMetadata?.candidatesTokenCount || 0, - total_tokens: response.usageMetadata?.totalTokenCount || 0 - }, - metrics: { - completion_tokens: response.usageMetadata?.candidatesTokenCount, - time_completion_millsec, - time_first_token_millsec: 0 - }, - webSearch: { - results: response.candidates?.[0]?.groundingMetadata, - source: 'gemini' - } - } as Response - } as BlockCompleteChunk) - return + onChunk({ type: ChunkType.LLM_RESPONSE_CREATED }) + return await processStream(response, 0).then(cleanup) } - // 等待接口返回流 onChunk({ type: ChunkType.LLM_RESPONSE_CREATED }) const userMessagesStream = await chat.sendMessageStream({ message: messageContents as PartUnion, @@ -394,105 +604,6 @@ export default class GeminiProvider extends BaseProvider { } }) - const processToolUses = async (content: string, idx: number) => { - const toolResults = await parseAndCallTools( - content, - toolResponses, - onChunk, - idx, - mcpToolCallResponseToGeminiMessage, - mcpTools, - isVisionModel(model) - ) - if (toolResults && toolResults.length > 0) { - history.push(messageContents) - const newChat = this.sdk.chats.create({ - model: model.id, - config: generateContentConfig, - history: history as Content[] - }) - const newStream = await newChat.sendMessageStream({ - message: 
flatten(toolResults.map((ts) => (ts as Content).parts)) as PartUnion, - config: { - ...generateContentConfig, - abortSignal: abortController.signal - } - }) - await processStream(newStream, idx + 1) - } - } - - const processStream = async (stream: AsyncGenerator, idx: number) => { - let content = '' - let final_time_completion_millsec = 0 - let lastUsage: Usage | undefined = undefined - for await (const chunk of stream) { - if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break - - // --- Calculate Metrics --- - if (time_first_token_millsec == 0 && chunk.text !== undefined) { - // Update based on text arrival - time_first_token_millsec = new Date().getTime() - start_time_millsec - } - - // 1. Text Content - if (chunk.text !== undefined) { - content += chunk.text - onChunk({ type: ChunkType.TEXT_DELTA, text: chunk.text }) - } - - // 2. Usage Data - if (chunk.usageMetadata) { - lastUsage = { - prompt_tokens: chunk.usageMetadata.promptTokenCount || 0, - completion_tokens: chunk.usageMetadata.candidatesTokenCount || 0, - total_tokens: chunk.usageMetadata.totalTokenCount || 0 - } - final_time_completion_millsec = new Date().getTime() - start_time_millsec - } - - // 4. Image Generation - const generateImage = this.processGeminiImageResponse(chunk, onChunk) - if (generateImage?.images?.length) { - onChunk({ type: ChunkType.IMAGE_COMPLETE, image: generateImage }) - } - - if (chunk.candidates?.[0]?.finishReason) { - if (chunk.text) { - onChunk({ type: ChunkType.TEXT_COMPLETE, text: content }) - } - if (chunk.candidates?.[0]?.groundingMetadata) { - // 3. Grounding/Search Metadata - const groundingMetadata = chunk.candidates?.[0]?.groundingMetadata - onChunk({ - type: ChunkType.LLM_WEB_SEARCH_COMPLETE, - llm_web_search: { - results: groundingMetadata, - source: WebSearchSource.GEMINI - } - } as LLMWebSearchCompleteChunk) - } - onChunk({ - type: ChunkType.BLOCK_COMPLETE, - response: { - metrics: { - completion_tokens: lastUsage?.completion_tokens, - time_completion_millsec: final_time_completion_millsec, - time_first_token_millsec - }, - usage: lastUsage - } - }) - } - // --- End Incremental onChunk calls --- - - // Call processToolUses AFTER potentially processing text content in this chunk - // This assumes tools might be specified within the text stream - // Note: parseAndCallTools inside should handle its own onChunk for tool responses - await processToolUses(content, idx) - } - } - await processStream(userMessagesStream, 0).finally(cleanup) const final_time_completion_millsec = new Date().getTime() - start_time_millsec @@ -841,4 +952,32 @@ export default class GeminiProvider extends BaseProvider { public generateImageByChat(): Promise { throw new Error('Method not implemented.') } + + public convertMcpTools(mcpTools: MCPTool[]): T[] { + return mcpToolsToGeminiTools(mcpTools) as T[] + } + + public mcpToolCallResponseToMessage = (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => { + if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) { + return mcpToolCallResponseToGeminiMessage(mcpToolResponse, resp, isVisionModel(model)) + } else if ('toolCallId' in mcpToolResponse) { + const toolCallOut = { + role: 'user', + parts: [ + { + functionResponse: { + id: mcpToolResponse.toolCallId, + name: mcpToolResponse.tool.id, + response: { + output: !resp.isError ? resp.content : undefined, + error: resp.isError ? 
resp.content : undefined + } + } + } + ] + } satisfies Content + return toolCallOut + } + return + } } diff --git a/src/renderer/src/providers/AiProvider/OpenAICompatibleProvider.ts b/src/renderer/src/providers/AiProvider/OpenAICompatibleProvider.ts index c21a32bec0..1bd9535234 100644 --- a/src/renderer/src/providers/AiProvider/OpenAICompatibleProvider.ts +++ b/src/renderer/src/providers/AiProvider/OpenAICompatibleProvider.ts @@ -31,10 +31,13 @@ import { Assistant, EFFORT_RATIO, FileTypes, + MCPCallToolResponse, + MCPTool, MCPToolResponse, Model, Provider, Suggestion, + ToolCallResponse, Usage, WebSearchSource } from '@renderer/types' @@ -48,7 +51,12 @@ import { convertLinksToOpenRouter, convertLinksToZhipu } from '@renderer/utils/linkConverter' -import { mcpToolCallResponseToOpenAICompatibleMessage, parseAndCallTools } from '@renderer/utils/mcp-tools' +import { + mcpToolCallResponseToOpenAICompatibleMessage, + mcpToolsToOpenAIChatTools, + openAIToolsToMcpTool, + parseAndCallTools +} from '@renderer/utils/mcp-tools' import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find' import { buildSystemPrompt } from '@renderer/utils/prompt' import { asyncGeneratorToReadableStream, readableStreamAsyncIterable } from '@renderer/utils/stream' @@ -57,18 +65,22 @@ import OpenAI, { AzureOpenAI } from 'openai' import { ChatCompletionContentPart, ChatCompletionCreateParamsNonStreaming, - ChatCompletionMessageParam + ChatCompletionMessageParam, + ChatCompletionMessageToolCall, + ChatCompletionTool, + ChatCompletionToolMessageParam } from 'openai/resources' import { CompletionsParams } from '.' -import OpenAIProvider from './OpenAIProvider' +import { BaseOpenAiProvider } from './OpenAIProvider' // 1. 定义联合类型 export type OpenAIStreamChunk = | { type: 'reasoning' | 'text-delta'; textDelta: string } + | { type: 'tool-calls'; delta: any } | { type: 'finish'; finishReason: any; usage: any; delta: any; chunk: any } -export default class OpenAICompatibleProvider extends OpenAIProvider { +export default class OpenAICompatibleProvider extends BaseOpenAiProvider { constructor(provider: Provider) { super(provider) @@ -313,6 +325,24 @@ export default class OpenAICompatibleProvider extends OpenAIProvider { return {} } + public convertMcpTools(mcpTools: MCPTool[]): T[] { + return mcpToolsToOpenAIChatTools(mcpTools) as T[] + } + + public mcpToolCallResponseToMessage = (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => { + if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) { + return mcpToolCallResponseToOpenAICompatibleMessage(mcpToolResponse, resp, isVisionModel(model)) + } else if ('toolCallId' in mcpToolResponse && mcpToolResponse.toolCallId) { + const toolCallOut: ChatCompletionToolMessageParam = { + role: 'tool', + tool_call_id: mcpToolResponse.toolCallId, + content: JSON.stringify(resp.content) + } + return toolCallOut + } + return + } + /** * Generate completions for the assistant * @param messages - The messages @@ -330,7 +360,7 @@ export default class OpenAICompatibleProvider extends OpenAIProvider { const defaultModel = getDefaultModel() const model = assistant.model || defaultModel - const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant) + const { contextCount, maxTokens, streamOutput, enableToolUse } = getAssistantSettings(assistant) const isEnabledBultinWebSearch = assistant.enableWebSearch messages = addImageFileToContents(messages) const enableReasoning = @@ -344,7 +374,9 @@ export default class 
OpenAICompatibleProvider extends OpenAIProvider { content: `Formatting re-enabled${systemMessage ? '\n' + systemMessage.content : ''}` } } - if (mcpTools && mcpTools.length > 0) { + const { tools } = this.setupToolsConfig({ mcpTools, model, enableToolUse }) + + if (this.useSystemPromptForTools) { systemMessage.content = buildSystemPrompt(systemMessage.content || '', mcpTools) } @@ -379,53 +411,86 @@ export default class OpenAICompatibleProvider extends OpenAIProvider { const toolResponses: MCPToolResponse[] = [] - const processToolUses = async (content: string, idx: number) => { - const toolResults = await parseAndCallTools( + const processToolResults = async (toolResults: Awaited>, idx: number) => { + if (toolResults.length === 0) return + + toolResults.forEach((ts) => reqMessages.push(ts as ChatCompletionMessageParam)) + + console.debug('[tool] reqMessages before processing', model.id, reqMessages) + reqMessages = processReqMessages(model, reqMessages) + console.debug('[tool] reqMessages', model.id, reqMessages) + + onChunk({ type: ChunkType.LLM_RESPONSE_CREATED }) + const newStream = await this.sdk.chat.completions + // @ts-ignore key is not typed + .create( + { + model: model.id, + messages: reqMessages, + temperature: this.getTemperature(assistant, model), + top_p: this.getTopP(assistant, model), + max_tokens: maxTokens, + keep_alive: this.keepAliveTime, + stream: isSupportStreamOutput(), + tools: !isEmpty(tools) ? tools : undefined, + ...getOpenAIWebSearchParams(assistant, model), + ...this.getReasoningEffort(assistant, model), + ...this.getProviderSpecificParameters(assistant, model), + ...this.getCustomParameters(assistant) + }, + { + signal + } + ) + await processStream(newStream, idx + 1) + } + + const processToolCalls = async (mcpTools, toolCalls: ChatCompletionMessageToolCall[]) => { + const mcpToolResponses = toolCalls + .map((toolCall) => { + const mcpTool = openAIToolsToMcpTool(mcpTools, toolCall as ChatCompletionMessageToolCall) + if (!mcpTool) return undefined + + const parsedArgs = (() => { + try { + return JSON.parse(toolCall.function.arguments) + } catch { + return toolCall.function.arguments + } + })() + + return { + id: toolCall.id, + toolCallId: toolCall.id, + tool: mcpTool, + arguments: parsedArgs, + status: 'pending' + } as ToolCallResponse + }) + .filter((t): t is ToolCallResponse => typeof t !== 'undefined') + return await parseAndCallTools( + mcpToolResponses, + toolResponses, + onChunk, + this.mcpToolCallResponseToMessage, + model, + mcpTools + ) + } + + const processToolUses = async (content: string) => { + return await parseAndCallTools( content, toolResponses, onChunk, - idx, - mcpToolCallResponseToOpenAICompatibleMessage, - mcpTools, - isVisionModel(model) + this.mcpToolCallResponseToMessage, + model, + mcpTools ) - - if (toolResults.length > 0) { - reqMessages.push({ - role: 'assistant', - content: content - } as ChatCompletionMessageParam) - toolResults.forEach((ts) => reqMessages.push(ts as ChatCompletionMessageParam)) - - reqMessages = processReqMessages(model, reqMessages) - const newStream = await this.sdk.chat.completions - // @ts-ignore key is not typed - .create( - { - model: model.id, - messages: reqMessages, - temperature: this.getTemperature(assistant, model), - top_p: this.getTopP(assistant, model), - max_tokens: maxTokens, - keep_alive: this.keepAliveTime, - stream: isSupportStreamOutput(), - // tools: tools, - service_tier: this.getServiceTier(model), - ...getOpenAIWebSearchParams(assistant, model), - ...this.getReasoningEffort(assistant, 
model), - ...this.getProviderSpecificParameters(assistant, model), - ...this.getCustomParameters(assistant) - }, - { - signal, - timeout: this.getTimeout(model) - } - ) - await processStream(newStream, idx + 1) - } } const processStream = async (stream: any, idx: number) => { + const toolCalls: ChatCompletionMessageToolCall[] = [] // Handle non-streaming case (already returns early, no change needed here) if (!isSupportStreamOutput()) { const time_completion_millsec = new Date().getTime() - start_time_millsec @@ -439,10 +504,59 @@ export default class OpenAICompatibleProvider extends OpenAIProvider { // Create a synthetic usage object if stream.usage is undefined const finalUsage = stream.usage // Separate onChunk calls for text and usage/metrics - if (stream.choices[0].message?.content) { - onChunk({ type: ChunkType.TEXT_COMPLETE, text: stream.choices[0].message.content }) + let content = '' + stream.choices.forEach((choice) => { + // reasoning + if (choice.message.reasoning) { + onChunk({ type: ChunkType.THINKING_DELTA, text: choice.message.reasoning }) + onChunk({ + type: ChunkType.THINKING_COMPLETE, + text: choice.message.reasoning, + thinking_millsec: time_completion_millsec + }) + } + // text + if (choice.message.content) { + content += choice.message.content + onChunk({ type: ChunkType.TEXT_DELTA, text: choice.message.content }) + } + // tool call + if (choice.message.tool_calls && choice.message.tool_calls.length) { + choice.message.tool_calls.forEach((t) => toolCalls.push(t)) + } + + reqMessages.push({ + role: choice.message.role, + content: choice.message.content, + tool_calls: toolCalls.length + ? toolCalls.map((toolCall) => ({ + id: toolCall.id, + function: { + ...toolCall.function, + arguments: + typeof toolCall.function.arguments === 'string' + ? 
toolCall.function.arguments + : JSON.stringify(toolCall.function.arguments) + }, + type: 'function' + })) + : undefined + }) + }) + + if (content.length) { + onChunk({ type: ChunkType.TEXT_COMPLETE, text: content }) } + const toolResults: Awaited> = [] + if (toolCalls.length) { + toolResults.push(...(await processToolCalls(mcpTools, toolCalls))) + } + if (stream.choices[0].message?.content) { + toolResults.push(...(await processToolUses(stream.choices[0].message?.content))) + } + await processToolResults(toolResults, idx) + // Always send usage and metrics data onChunk({ type: ChunkType.BLOCK_COMPLETE, response: { usage: finalUsage, metrics: finalMetrics } }) return @@ -486,6 +600,9 @@ export default class OpenAICompatibleProvider extends OpenAIProvider { if (delta?.content) { yield { type: 'text-delta', textDelta: delta.content } } + if (delta?.tool_calls) { + yield { type: 'tool-calls', delta: delta } + } const finishReason = chunk.choices[0]?.finish_reason if (!isEmpty(finishReason)) { @@ -563,6 +680,25 @@ export default class OpenAICompatibleProvider extends OpenAIProvider { onChunk({ type: ChunkType.TEXT_DELTA, text: textDelta }) break } + case 'tool-calls': { + chunk.delta.tool_calls.forEach((toolCall) => { + const { id, index, type, function: fun } = toolCall + if (id && type === 'function' && fun) { + const { name, arguments: args } = fun + toolCalls.push({ + id, + function: { + name: name || '', + arguments: args || '' + }, + type: 'function' + }) + } else if (fun?.arguments) { + toolCalls[index].function.arguments += fun.arguments + } + }) + break + } case 'finish': { const finishReason = chunk.finishReason const usage = chunk.usage @@ -624,7 +760,33 @@ export default class OpenAICompatibleProvider extends OpenAIProvider { } as LLMWebSearchCompleteChunk) } } - await processToolUses(content, idx) + reqMessages.push({ + role: 'assistant', + content: content, + tool_calls: toolCalls.length + ? toolCalls.map((toolCall) => ({ + id: toolCall.id, + function: { + ...toolCall.function, + arguments: + typeof toolCall.function.arguments === 'string' + ? toolCall.function.arguments + : JSON.stringify(toolCall.function.arguments) + }, + type: 'function' + })) + : undefined + }) + let toolResults: Awaited> = [] + if (toolCalls.length) { + toolResults = await processToolCalls(mcpTools, toolCalls) + } + if (content.length) { + toolResults = toolResults.concat(await processToolUses(content)) + } + if (toolResults.length) { + await processToolResults(toolResults, idx) + } onChunk({ type: ChunkType.BLOCK_COMPLETE, response: { @@ -657,7 +819,7 @@ export default class OpenAICompatibleProvider extends OpenAIProvider { max_tokens: maxTokens, keep_alive: this.keepAliveTime, stream: isSupportStreamOutput(), - // tools: tools, + tools: !isEmpty(tools) ? 
tools : undefined, service_tier: this.getServiceTier(model), ...getOpenAIWebSearchParams(assistant, model), ...this.getReasoningEffort(assistant, model), diff --git a/src/renderer/src/providers/AiProvider/OpenAIProvider.ts b/src/renderer/src/providers/AiProvider/OpenAIProvider.ts index 8536ac96f4..7d70e05aa5 100644 --- a/src/renderer/src/providers/AiProvider/OpenAIProvider.ts +++ b/src/renderer/src/providers/AiProvider/OpenAIProvider.ts @@ -21,10 +21,13 @@ import { Assistant, FileTypes, GenerateImageParams, + MCPCallToolResponse, + MCPTool, MCPToolResponse, Model, Provider, Suggestion, + ToolCallResponse, Usage, WebSearchSource } from '@renderer/types' @@ -33,7 +36,12 @@ import { Message } from '@renderer/types/newMessage' import { removeSpecialCharactersForTopicName } from '@renderer/utils' import { addImageFileToContents } from '@renderer/utils/formats' import { convertLinks } from '@renderer/utils/linkConverter' -import { mcpToolCallResponseToOpenAIMessage, parseAndCallTools } from '@renderer/utils/mcp-tools' +import { + mcpToolCallResponseToOpenAIMessage, + mcpToolsToOpenAIResponseTools, + openAIToolsToMcpTool, + parseAndCallTools +} from '@renderer/utils/mcp-tools' import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find' import { buildSystemPrompt } from '@renderer/utils/prompt' import { isEmpty, takeRight } from 'lodash' @@ -45,7 +53,7 @@ import { FileLike, toFile } from 'openai/uploads' import { CompletionsParams } from '.' import BaseProvider from './BaseProvider' -export default class OpenAIProvider extends BaseProvider { +export abstract class BaseOpenAiProvider extends BaseProvider { protected sdk: OpenAI constructor(provider: Provider) { @@ -61,6 +69,14 @@ export default class OpenAIProvider extends BaseProvider { }) } + abstract convertMcpTools(mcpTools: MCPTool[]): T[] + + abstract mcpToolCallResponseToMessage: ( + mcpToolResponse: MCPToolResponse, + resp: MCPCallToolResponse, + model: Model + ) => OpenAI.Responses.ResponseInputItem | ChatCompletionMessageParam | undefined + /** * Extract the file content from the message * @param message - The message @@ -91,16 +107,23 @@ export default class OpenAIProvider extends BaseProvider { return '' } - private async getReponseMessageParam(message: Message, model: Model): Promise { + private async getReponseMessageParam(message: Message, model: Model): Promise { const isVision = isVisionModel(model) const content = await this.getMessageContent(message) const fileBlocks = findFileBlocks(message) const imageBlocks = findImageBlocks(message) if (fileBlocks.length === 0 && imageBlocks.length === 0) { - return { - role: message.role === 'system' ? 'user' : message.role, - content: content ? [{ type: 'input_text', text: content }] : [] + if (message.role === 'assistant') { + return { + role: 'assistant', + content: content + } + } else { + return { + role: message.role === 'system' ? 'user' : message.role, + content: content ? 
[{ type: 'input_text', text: content }] : [] + } as OpenAI.Responses.EasyInputMessage } } @@ -285,10 +308,8 @@ export default class OpenAIProvider extends BaseProvider { } const defaultModel = getDefaultModel() const model = assistant.model || defaultModel - const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant) - + const { contextCount, maxTokens, streamOutput, enableToolUse } = getAssistantSettings(assistant) const isEnabledBuiltinWebSearch = assistant.enableWebSearch - onChunk({ type: ChunkType.LLM_RESPONSE_CREATED }) // 退回到 OpenAI 兼容模式 if (isOpenAIWebSearch(model)) { const systemMessage = { role: 'system', content: assistant.prompt || '' } @@ -387,7 +408,7 @@ export default class OpenAIProvider extends BaseProvider { }) return } - const tools: OpenAI.Responses.Tool[] = [] + let tools: OpenAI.Responses.Tool[] = [] const toolChoices: OpenAI.Responses.ToolChoiceTypes = { type: 'web_search_preview' } @@ -411,7 +432,15 @@ export default class OpenAIProvider extends BaseProvider { systemMessage.role = 'developer' } - if (mcpTools && mcpTools.length > 0) { + const { tools: extraTools } = this.setupToolsConfig({ + mcpTools, + model, + enableToolUse + }) + + tools = tools.concat(extraTools) + + if (this.useSystemPromptForTools) { systemMessageInput.text = buildSystemPrompt(systemMessageInput.text || '', mcpTools) } systemMessageContent.push(systemMessageInput) @@ -421,7 +450,7 @@ export default class OpenAIProvider extends BaseProvider { ) onFilterMessages(_messages) - const userMessage: OpenAI.Responses.EasyInputMessage[] = [] + const userMessage: OpenAI.Responses.ResponseInputItem[] = [] for (const message of _messages) { userMessage.push(await this.getReponseMessageParam(message, model)) } @@ -434,7 +463,7 @@ export default class OpenAIProvider extends BaseProvider { const { signal } = abortController // 当 systemMessage 内容为空时不发送 systemMessage - let reqMessages: OpenAI.Responses.EasyInputMessage[] + let reqMessages: OpenAI.Responses.ResponseInput if (!systemMessage.content) { reqMessages = [...userMessage] } else { @@ -443,48 +472,84 @@ export default class OpenAIProvider extends BaseProvider { const toolResponses: MCPToolResponse[] = [] - const processToolUses = async (content: string, idx: number) => { - const toolResults = await parseAndCallTools( + const processToolResults = async (toolResults: Awaited>, idx: number) => { + if (toolResults.length === 0) return + + toolResults.forEach((ts) => reqMessages.push(ts as OpenAI.Responses.EasyInputMessage)) + + onChunk({ type: ChunkType.LLM_RESPONSE_CREATED }) + const stream = await this.sdk.responses.create( + { + model: model.id, + input: reqMessages, + temperature: this.getTemperature(assistant, model), + top_p: this.getTopP(assistant, model), + max_output_tokens: maxTokens, + stream: streamOutput, + tools: !isEmpty(tools) ? 
tools : undefined, + service_tier: this.getServiceTier(model), + ...this.getResponseReasoningEffort(assistant, model), + ...this.getCustomParameters(assistant) + }, + { + signal, + timeout: this.getTimeout(model) + } + ) + await processStream(stream, idx + 1) + } + + const processToolCalls = async (mcpTools, toolCalls: OpenAI.Responses.ResponseFunctionToolCall[]) => { + const mcpToolResponses = toolCalls + .map((toolCall) => { + const mcpTool = openAIToolsToMcpTool(mcpTools, toolCall as OpenAI.Responses.ResponseFunctionToolCall) + if (!mcpTool) return undefined + + const parsedArgs = (() => { + try { + return JSON.parse(toolCall.arguments) + } catch { + return toolCall.arguments + } + })() + + return { + id: toolCall.call_id, + toolCallId: toolCall.call_id, + tool: mcpTool, + arguments: parsedArgs, + status: 'pending' + } as ToolCallResponse + }) + .filter((t): t is ToolCallResponse => typeof t !== 'undefined') + + return await parseAndCallTools( + mcpToolResponses, + toolResponses, + onChunk, + this.mcpToolCallResponseToMessage, + model, + mcpTools + ) + } + + const processToolUses = async (content: string) => { + return await parseAndCallTools( content, toolResponses, onChunk, - idx, - mcpToolCallResponseToOpenAIMessage, - mcpTools, - isVisionModel(model) + this.mcpToolCallResponseToMessage, + model, + mcpTools ) - - if (toolResults.length > 0) { - reqMessages.push({ - role: 'assistant', - content: content - }) - toolResults.forEach((ts) => reqMessages.push(ts as OpenAI.Responses.EasyInputMessage)) - const newStream = await this.sdk.responses.create( - { - model: model.id, - input: reqMessages, - temperature: this.getTemperature(assistant, model), - top_p: this.getTopP(assistant, model), - max_output_tokens: maxTokens, - stream: true, - service_tier: this.getServiceTier(model), - ...this.getResponseReasoningEffort(assistant, model), - ...this.getCustomParameters(assistant) - }, - { - signal, - timeout: this.getTimeout(model) - } - ) - await processStream(newStream, idx + 1) - } } const processStream = async ( stream: Stream | OpenAI.Responses.Response, idx: number ) => { + const toolCalls: OpenAI.Responses.ResponseFunctionToolCall[] = [] + if (!streamOutput) { const nonStream = stream as OpenAI.Responses.Response const time_completion_millsec = new Date().getTime() - start_time_millsec @@ -502,11 +567,15 @@ export default class OpenAIProvider extends BaseProvider { prompt_tokens: nonStream.usage?.input_tokens || 0, total_tokens } + let content = '' + for (const output of nonStream.output) { switch (output.type) { case 'message': if (output.content[0].type === 'output_text') { + onChunk({ type: ChunkType.TEXT_DELTA, text: output.content[0].text }) onChunk({ type: ChunkType.TEXT_COMPLETE, text: output.content[0].text }) + content += output.content[0].text if (output.content[0].annotations && output.content[0].annotations.length > 0) { onChunk({ type: ChunkType.LLM_WEB_SEARCH_COMPLETE, @@ -525,8 +594,32 @@ export default class OpenAIProvider extends BaseProvider { thinking_millsec: new Date().getTime() - start_time_millsec }) break + case 'function_call': + toolCalls.push(output) } } + + if (content) { + reqMessages.push({ + role: 'assistant', + content: content + }) + } + if (toolCalls.length) { + toolCalls.forEach((toolCall) => { + reqMessages.push(toolCall) + }) + } + + const toolResults: Awaited> = [] + if (toolCalls.length) { + toolResults.push(...(await processToolCalls(mcpTools, toolCalls))) + } + if (content.length) { + toolResults.push(...(await processToolUses(content))) + } + 
await processToolResults(toolResults, idx) + onChunk({ type: ChunkType.BLOCK_COMPLETE, response: { @@ -537,6 +630,9 @@ export default class OpenAIProvider extends BaseProvider { return } let content = '' + + const outputItems: OpenAI.Responses.ResponseOutputItem[] = [] + let lastUsage: Usage | undefined = undefined let final_time_completion_millsec_delta = 0 for await (const chunk of stream as Stream) { @@ -547,6 +643,12 @@ export default class OpenAIProvider extends BaseProvider { case 'response.created': time_first_token_millsec = new Date().getTime() break + case 'response.output_item.added': + if (chunk.item.type === 'function_call') { + outputItems.push(chunk.item) + } + break + case 'response.reasoning_summary_text.delta': onChunk({ type: ChunkType.THINKING_DELTA, @@ -579,6 +681,21 @@ export default class OpenAIProvider extends BaseProvider { text: content }) break + case 'response.function_call_arguments.done': { + const outputItem: OpenAI.Responses.ResponseOutputItem | undefined = outputItems.find( + (item) => item.id === chunk.item_id + ) + if (outputItem) { + if (outputItem.type === 'function_call') { + toolCalls.push({ + ...outputItem, + arguments: chunk.arguments + }) + } + } + + break + } case 'response.content_part.done': if (chunk.part.type === 'output_text' && chunk.part.annotations && chunk.part.annotations.length > 0) { onChunk({ @@ -615,9 +732,31 @@ export default class OpenAIProvider extends BaseProvider { }) break } + + // --- End of Incremental onChunk calls --- + } // End of for await loop + if (content) { + reqMessages.push({ + role: 'assistant', + content: content + }) + } + if (toolCalls.length) { + toolCalls.forEach((toolCall) => { + reqMessages.push(toolCall) + }) } - await processToolUses(content, idx) + // Call processToolUses AFTER the loop finishes processing the main stream content + // Note: parseAndCallTools inside processToolUses should handle its own onChunk for tool responses + const toolResults: Awaited> = [] + if (toolCalls.length) { + toolResults.push(...(await processToolCalls(mcpTools, toolCalls))) + } + if (content) { + toolResults.push(...(await processToolUses(content))) + } + await processToolResults(toolResults, idx) onChunk({ type: ChunkType.BLOCK_COMPLETE, @@ -632,6 +771,7 @@ export default class OpenAIProvider extends BaseProvider { }) } + onChunk({ type: ChunkType.LLM_RESPONSE_CREATED }) const stream = await this.sdk.responses.create( { model: model.id, @@ -1081,3 +1221,31 @@ export default class OpenAIProvider extends BaseProvider { return data.data[0].embedding.length } } + +export default class OpenAIProvider extends BaseOpenAiProvider { + constructor(provider: Provider) { + super(provider) + } + + public convertMcpTools(mcpTools: MCPTool[]) { + return mcpToolsToOpenAIResponseTools(mcpTools) as T[] + } + + public mcpToolCallResponseToMessage = ( + mcpToolResponse: MCPToolResponse, + resp: MCPCallToolResponse, + model: Model + ): OpenAI.Responses.ResponseInputItem | undefined => { + if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) { + return mcpToolCallResponseToOpenAIMessage(mcpToolResponse, resp, isVisionModel(model)) + } else if ('toolCallId' in mcpToolResponse && mcpToolResponse.toolCallId) { + const toolCallOut: OpenAI.Responses.ResponseInputItem = { + type: 'function_call_output', + call_id: mcpToolResponse.toolCallId, + output: JSON.stringify(resp.content) + } + return toolCallOut + } + return + } +} diff --git a/src/renderer/src/services/AssistantService.ts b/src/renderer/src/services/AssistantService.ts index 
4ad50c29ef..1bef2ec2b7 100644 --- a/src/renderer/src/services/AssistantService.ts +++ b/src/renderer/src/services/AssistantService.ts @@ -107,6 +107,7 @@ export const getAssistantSettings = (assistant: Assistant): AssistantSettings => enableMaxTokens: assistant?.settings?.enableMaxTokens ?? false, maxTokens: getAssistantMaxTokens(), streamOutput: assistant?.settings?.streamOutput ?? true, + enableToolUse: assistant?.settings?.enableToolUse ?? false, hideMessages: assistant?.settings?.hideMessages ?? false, defaultModel: assistant?.defaultModel ?? undefined, customParameters: assistant?.settings?.customParameters ?? [] diff --git a/src/renderer/src/store/index.ts b/src/renderer/src/store/index.ts index df7af60f69..e8f4c2ac32 100644 --- a/src/renderer/src/store/index.ts +++ b/src/renderer/src/store/index.ts @@ -46,7 +46,7 @@ const persistedReducer = persistReducer( { key: 'cherry-studio', storage, - version: 99, + version: 98, blacklist: ['runtime', 'messages', 'messageBlocks'], migrate }, diff --git a/src/renderer/src/store/llm.ts b/src/renderer/src/store/llm.ts index 293a8db50f..27a68b342b 100644 --- a/src/renderer/src/store/llm.ts +++ b/src/renderer/src/store/llm.ts @@ -476,16 +476,6 @@ export const INITIAL_PROVIDERS: Provider[] = [ models: SYSTEM_MODELS.voyageai, isSystem: true, enabled: false - }, - { - id: 'paratera', - name: 'Paratera AI', - type: 'openai-compatible', - apiKey: '', - apiHost: 'https://llmapi.paratera.com', - models: SYSTEM_MODELS.paratera, - isSystem: true, - enabled: false } ] diff --git a/src/renderer/src/store/migrate.ts b/src/renderer/src/store/migrate.ts index 69abd39eb9..f09ce39536 100644 --- a/src/renderer/src/store/migrate.ts +++ b/src/renderer/src/store/migrate.ts @@ -1248,15 +1248,6 @@ const migrateConfig = { provider.type = 'openai-compatible' } }) - return state - } catch (error) { - return state - } - }, - '99': (state: RootState) => { - try { - addProvider(state, 'paratera') - return state } catch (error) { return state diff --git a/src/renderer/src/store/thunk/messageThunk.ts b/src/renderer/src/store/thunk/messageThunk.ts index ac69002148..bfb68de0c0 100644 --- a/src/renderer/src/store/thunk/messageThunk.ts +++ b/src/renderer/src/store/thunk/messageThunk.ts @@ -427,7 +427,17 @@ const fetchAndProcessAssistantResponseImpl = async ( } }, onToolCallInProgress: (toolResponse: MCPToolResponse) => { - if (toolResponse.status === 'invoking') { + if (lastBlockType === MessageBlockType.UNKNOWN && lastBlockId) { + lastBlockType = MessageBlockType.TOOL + const changes = { + type: MessageBlockType.TOOL, + status: MessageBlockStatus.PROCESSING, + metadata: { rawMcpToolResponse: toolResponse } + } + dispatch(updateOneBlock({ id: lastBlockId, changes })) + saveUpdatedBlockToDB(lastBlockId, assistantMsgId, topicId, getState) + toolCallIdToBlockIdMap.set(toolResponse.id, lastBlockId) + } else if (toolResponse.status === 'invoking') { const toolBlock = createToolBlock(assistantMsgId, toolResponse.id, { toolName: toolResponse.tool.name, status: MessageBlockStatus.PROCESSING, diff --git a/src/renderer/src/types/index.ts b/src/renderer/src/types/index.ts index 3f05e727b9..95b3756980 100644 --- a/src/renderer/src/types/index.ts +++ b/src/renderer/src/types/index.ts @@ -55,6 +55,7 @@ export type AssistantSettings = { maxTokens: number | undefined enableMaxTokens: boolean streamOutput: boolean + enableToolUse: boolean hideMessages: boolean defaultModel?: Model customParameters?: AssistantSettingCustomParameters[] @@ -570,13 +571,25 @@ export interface MCPConfig { servers: 
MCPServer[] } -export interface MCPToolResponse { - id: string // tool call id, it should be unique - tool: MCPTool // tool info +interface BaseToolResponse { + id: string // unique id + tool: MCPTool + arguments: Record | undefined status: string // 'invoking' | 'done' response?: any } +export interface ToolUseResponse extends BaseToolResponse { + toolUseId: string +} + +export interface ToolCallResponse extends BaseToolResponse { + // gemini tool call id might be undefined + toolCallId?: string +} + +export type MCPToolResponse = ToolUseResponse | ToolCallResponse + export interface MCPToolResultContent { type: 'text' | 'image' | 'audio' | 'resource' text?: string @@ -586,6 +599,7 @@ export interface MCPToolResultContent { uri?: string text?: string mimeType?: string + blob?: string } } diff --git a/src/renderer/src/utils/mcp-tools.ts b/src/renderer/src/utils/mcp-tools.ts index e8abc2f4bc..b484d0cc18 100644 --- a/src/renderer/src/utils/mcp-tools.ts +++ b/src/renderer/src/utils/mcp-tools.ts @@ -1,18 +1,31 @@ -import { ContentBlockParam, ToolUnion, ToolUseBlock } from '@anthropic-ai/sdk/resources' -import { MessageParam } from '@anthropic-ai/sdk/resources' -import { Content, FunctionCall, Part } from '@google/genai' +import { + ContentBlockParam, + MessageParam, + ToolResultBlockParam, + ToolUnion, + ToolUseBlock +} from '@anthropic-ai/sdk/resources' +import { Content, FunctionCall, Part, Tool, Type as GeminiSchemaType } from '@google/genai' +import { isVisionModel } from '@renderer/config/models' import store from '@renderer/store' import { addMCPServer } from '@renderer/store/mcp' -import { MCPCallToolResponse, MCPServer, MCPTool, MCPToolResponse } from '@renderer/types' +import { MCPCallToolResponse, MCPServer, MCPTool, MCPToolResponse, Model, ToolUseResponse } from '@renderer/types' import type { MCPToolCompleteChunk, MCPToolInProgressChunk } from '@renderer/types/chunk' import { ChunkType } from '@renderer/types/chunk' +import { isArray, isObject, pull, transform } from 'lodash' import { nanoid } from 'nanoid' import OpenAI from 'openai' -import { ChatCompletionContentPart, ChatCompletionMessageParam, ChatCompletionMessageToolCall } from 'openai/resources' +import { + ChatCompletionContentPart, + ChatCompletionMessageParam, + ChatCompletionMessageToolCall, + ChatCompletionTool +} from 'openai/resources' import { CompletionsParams } from '../providers/AiProvider' const MCP_AUTO_INSTALL_SERVER_NAME = '@cherry/mcp-auto-install' +const EXTRA_SCHEMA_KEYS = ['schema', 'headers'] // const ensureValidSchema = (obj: Record) => { // // Filter out unsupported keys for Gemini @@ -153,77 +166,116 @@ const MCP_AUTO_INSTALL_SERVER_NAME = '@cherry/mcp-auto-install' // return processedProperties // } -// export function mcpToolsToOpenAITools(mcpTools: MCPTool[]): Array { -// return mcpTools.map((tool) => ({ -// type: 'function', -// name: tool.name, -// function: { -// name: tool.id, -// description: tool.description, -// parameters: { -// type: 'object', -// properties: filterPropertieAttributes(tool) -// } -// } -// })) -// } - -export function openAIToolsToMcpTool( - mcpTools: MCPTool[] | undefined, - llmTool: ChatCompletionMessageToolCall -): MCPTool | undefined { - if (!mcpTools) { - return undefined +export function filterProperties( + properties: Record | string | number | boolean | Array | string | number | boolean>, + supportedKeys: string[] +) { + // If it is an array, recursively process each element + if (isArray(properties)) { + return properties.map((item) => filterProperties(item, 
supportedKeys)) } - const tool = mcpTools.find( - (mcptool) => mcptool.id === llmTool.function.name || mcptool.name === llmTool.function.name - ) + // If it is an object, recursively process each property + if (isObject(properties)) { + return transform( + properties, + (result, value, key) => { + if (key === 'properties') { + result[key] = transform(value, (acc, v, k) => { + acc[k] = filterProperties(v, supportedKeys) + }) - if (!tool) { - console.warn('No MCP Tool found for tool call:', llmTool) - return undefined + result['additionalProperties'] = false + result['required'] = pull(Object.keys(value), ...EXTRA_SCHEMA_KEYS) + } else if (key === 'oneOf') { + // openai only supports anyOf + result['anyOf'] = filterProperties(value, supportedKeys) + } else if (supportedKeys.includes(key)) { + result[key] = filterProperties(value, supportedKeys) + if (key === 'type' && value === 'object') { + result['additionalProperties'] = false + } + } + }, + {} + ) } - console.log( - `[MCP] OpenAI Tool to MCP Tool: ${tool.serverName} ${tool.name}`, - tool, - 'args', - llmTool.function.arguments - ) - // use this to parse the arguments and avoid parsing errors - let args: any = {} - try { - args = JSON.parse(llmTool.function.arguments) - } catch (e) { - console.error('Error parsing arguments', e) - } - - return { - id: tool.id, - serverId: tool.serverId, - serverName: tool.serverName, - name: tool.name, - description: tool.description, - inputSchema: args - } + // Return other types directly (e.g., string, number, etc.) + return properties } -export async function callMCPTool(tool: MCPTool): Promise { - console.log(`[MCP] Calling Tool: ${tool.serverName} ${tool.name}`, tool) +export function mcpToolsToOpenAIResponseTools(mcpTools: MCPTool[]): OpenAI.Responses.Tool[] { + const schemaKeys = ['type', 'description', 'items', 'enum', 'additionalProperties', 'anyof'] + return mcpTools.map( + (tool) => + ({ + type: 'function', + name: tool.id, + parameters: { + type: 'object', + properties: filterProperties(tool.inputSchema, schemaKeys).properties, + required: pull(Object.keys(tool.inputSchema.properties), ...EXTRA_SCHEMA_KEYS), + additionalProperties: false + }, + strict: true + }) satisfies OpenAI.Responses.Tool + ) +} + +export function mcpToolsToOpenAIChatTools(mcpTools: MCPTool[]): Array { + return mcpTools.map( + (tool) => + ({ + type: 'function', + function: { + name: tool.id, + description: tool.description, + parameters: { + type: 'object', + properties: tool.inputSchema.properties, + required: tool.inputSchema.required + } + } + }) as ChatCompletionTool + ) +} + +export function openAIToolsToMcpTool( + mcpTools: MCPTool[], + toolCall: OpenAI.Responses.ResponseFunctionToolCall | ChatCompletionMessageToolCall +): MCPTool | undefined { + const tool = mcpTools.find((mcpTool) => { + if ('name' in toolCall) { + return mcpTool.id === toolCall.name || mcpTool.name === toolCall.name + } else { + return mcpTool.id === toolCall.function.name || mcpTool.name === toolCall.function.name + } + }) + + if (!tool) { + console.warn('No MCP Tool found for tool call:', toolCall) + return undefined + } + + return tool +} + +export async function callMCPTool(toolResponse: MCPToolResponse): Promise { + console.log(`[MCP] Calling Tool: ${toolResponse.tool.serverName} ${toolResponse.tool.name}`, toolResponse.tool) try { - const server = getMcpServerByTool(tool) + const server = getMcpServerByTool(toolResponse.tool) if (!server) { - throw new Error(`Server not found: ${tool.serverName}`) + throw new Error(`Server not found: 
${toolResponse.tool.serverName}`) } const resp = await window.api.mcp.callTool({ server, - name: tool.name, - args: tool.inputSchema + name: toolResponse.tool.name, + args: toolResponse.arguments }) - if (tool.serverName === MCP_AUTO_INSTALL_SERVER_NAME) { + if (toolResponse.tool.serverName === MCP_AUTO_INSTALL_SERVER_NAME) { if (resp.data) { const mcpServer: MCPServer = { id: `f${nanoid()}`, @@ -241,16 +293,16 @@ export async function callMCPTool(tool: MCPTool): Promise { } } - console.log(`[MCP] Tool called: ${tool.serverName} ${tool.name}`, resp) + console.log(`[MCP] Tool called: ${toolResponse.tool.serverName} ${toolResponse.tool.name}`, resp) return resp } catch (e) { - console.error(`[MCP] Error calling Tool: ${tool.serverName} ${tool.name}`, e) + console.error(`[MCP] Error calling Tool: ${toolResponse.tool.serverName} ${toolResponse.tool.name}`, e) return Promise.resolve({ isError: true, content: [ { type: 'text', - text: `Error calling tool ${tool.name}: ${e instanceof Error ? e.stack || e.message || 'No error details available' : JSON.stringify(e)}` + text: `Error calling tool ${toolResponse.tool.name}: ${e instanceof Error ? e.stack || e.message || 'No error details available' : JSON.stringify(e)}` } ] }) @@ -262,7 +314,7 @@ export function mcpToolsToAnthropicTools(mcpTools: MCPTool[]): Array const t: ToolUnion = { name: tool.id, description: tool.description, - // @ts-ignore no check + // @ts-ignore ignore type as it it unknow input_schema: tool.inputSchema } return t @@ -275,53 +327,68 @@ export function anthropicToolUseToMcpTool(mcpTools: MCPTool[] | undefined, toolU if (!tool) { return undefined } - // @ts-ignore ignore type as it it unknow - tool.inputSchema = toolUse.input return tool } -// export function mcpToolsToGeminiTools(mcpTools: MCPTool[] | undefined): geminiTool[] { -// if (!mcpTools || mcpTools.length === 0) { -// // No tools available -// return [] -// } -// const functions: FunctionDeclaration[] = [] - -// for (const tool of mcpTools) { -// const properties = filterPropertieAttributes(tool, true) -// const functionDeclaration: FunctionDeclaration = { -// name: tool.id, -// description: tool.description, -// parameters: { -// type: SchemaType.OBJECT, -// properties: -// Object.keys(properties).length > 0 -// ? 
Object.fromEntries( -// Object.entries(properties).map(([key, value]) => [key, ensureValidSchema(value as Record)]) -// ) -// : { _empty: { type: SchemaType.STRING } as SimpleStringSchema } -// } as FunctionDeclarationSchema -// } -// functions.push(functionDeclaration) -// } -// const tool: geminiTool = { -// functionDeclarations: functions -// } -// return [tool] -// } +/** + * @param mcpTools + * @returns + */ +export function mcpToolsToGeminiTools(mcpTools: MCPTool[]): Tool[] { + /** + * @typedef {import('@google/genai').Schema} Schema + */ + const schemaKeys = [ + 'example', + 'pattern', + 'default', + 'maxLength', + 'minLength', + 'minProperties', + 'maxProperties', + 'anyOf', + 'description', + 'enum', + 'format', + 'items', + 'maxItems', + 'maximum', + 'minItems', + 'minimum', + 'nullable', + 'properties', + 'propertyOrdering', + 'required', + 'title', + 'type' + ] + return [ + { + functionDeclarations: mcpTools?.map((tool) => { + return { + name: tool.id, + description: tool.description, + parameters: { + type: GeminiSchemaType.OBJECT, + properties: filterProperties(tool.inputSchema, schemaKeys).properties, + required: tool.inputSchema.required + } + } + }) + } + ] +} export function geminiFunctionCallToMcpTool( mcpTools: MCPTool[] | undefined, - fcall: FunctionCall | undefined + toolCall: FunctionCall | undefined ): MCPTool | undefined { - if (!fcall) return undefined + if (!toolCall) return undefined if (!mcpTools) return undefined - const tool = mcpTools.find((tool) => tool.id === fcall.name) + const tool = mcpTools.find((tool) => tool.id === toolCall.name) if (!tool) { return undefined } - // @ts-ignore schema is not a valid property - tool.inputSchema = fcall.args return tool } @@ -368,13 +435,13 @@ export function getMcpServerByTool(tool: MCPTool) { return servers.find((s) => s.id === tool.serverId) } -export function parseToolUse(content: string, mcpTools: MCPTool[]): MCPToolResponse[] { +export function parseToolUse(content: string, mcpTools: MCPTool[]): ToolUseResponse[] { if (!content || !mcpTools || mcpTools.length === 0) { return [] } const toolUsePattern = /<tool_use>([\s\S]*?)<name>([\s\S]*?)<\/name>([\s\S]*?)<arguments>([\s\S]*?)<\/arguments>([\s\S]*?)<\/tool_use>/g - const tools: MCPToolResponse[] = [] + const tools: ToolUseResponse[] = [] let match let idx = 0 // Find all tool use blocks @@ -401,10 +468,9 @@ export function parseToolUse(content: string, mcpTools: MCPTool[]): MCPToolRespo // Add to tools array tools.push({ id: `${toolName}-${idx++}`, // Unique ID for each tool use - tool: { - ...mcpTool, - inputSchema: parsedArgs - }, + toolUseId: mcpTool.id, + tool: mcpTool, + arguments: parsedArgs, status: 'pending' }) @@ -414,36 +480,69 @@ export function parseToolUse(content: string, mcpTools: MCPTool[]): MCPToolRespo return tools } -export async function parseAndCallTools( - content: string, - toolResponses: MCPToolResponse[], +export async function parseAndCallTools<R>( + tools: MCPToolResponse[], + allToolResponses: MCPToolResponse[], onChunk: CompletionsParams['onChunk'], - idx: number, - convertToMessage: ( - toolCallId: string, - resp: MCPCallToolResponse, - isVisionModel: boolean - ) => ChatCompletionMessageParam | MessageParam | Content | OpenAI.Responses.EasyInputMessage, - mcpTools?: MCPTool[], - isVisionModel: boolean = false -): Promise<(ChatCompletionMessageParam | MessageParam | Content | OpenAI.Responses.EasyInputMessage)[]> { - const toolResults: (ChatCompletionMessageParam | MessageParam | Content | OpenAI.Responses.EasyInputMessage)[] = [] - // process tool use - const
tools = parseToolUse(content, mcpTools || []) - if (!tools || tools.length === 0) { + convertToMessage: (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => R | undefined, + model: Model, + mcpTools?: MCPTool[] +): Promise< + (ChatCompletionMessageParam | MessageParam | Content | OpenAI.Responses.ResponseInputItem | ToolResultBlockParam)[] +> + +export async function parseAndCallTools( + content: string, + allToolResponses: MCPToolResponse[], + onChunk: CompletionsParams['onChunk'], + convertToMessage: (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => R | undefined, + model: Model, + mcpTools?: MCPTool[] +): Promise< + (ChatCompletionMessageParam | MessageParam | Content | OpenAI.Responses.ResponseInputItem | ToolResultBlockParam)[] +> + +export async function parseAndCallTools( + content: string | MCPToolResponse[], + allToolResponses: MCPToolResponse[], + onChunk: CompletionsParams['onChunk'], + convertToMessage: (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => R | undefined, + model: Model, + mcpTools?: MCPTool[] +): Promise { + const toolResults: R[] = [] + let curToolResponses: MCPToolResponse[] = [] + if (Array.isArray(content)) { + curToolResponses = content + } else { + // process tool use + curToolResponses = parseToolUse(content, mcpTools || []) + } + if (!curToolResponses || curToolResponses.length === 0) { return toolResults } - for (let i = 0; i < tools.length; i++) { - const tool = tools[i] - upsertMCPToolResponse(toolResponses, { id: `${tool.id}-${idx}-${i}`, tool: tool.tool, status: 'invoking' }, onChunk) + for (let i = 0; i < curToolResponses.length; i++) { + const toolResponse = curToolResponses[i] + upsertMCPToolResponse( + allToolResponses, + { + ...toolResponse, + status: 'invoking' + }, + onChunk + ) } - const toolPromises = tools.map(async (tool, i) => { + const toolPromises = curToolResponses.map(async (toolResponse) => { const images: string[] = [] - const toolCallResponse = await callMCPTool(tool.tool) + const toolCallResponse = await callMCPTool(toolResponse) upsertMCPToolResponse( - toolResponses, - { id: `${tool.id}-${idx}-${i}`, tool: tool.tool, status: 'done', response: toolCallResponse }, + allToolResponses, + { + ...toolResponse, + status: 'done', + response: toolCallResponse + }, onChunk ) @@ -466,15 +565,15 @@ export async function parseAndCallTools( }) } - return convertToMessage(tool.tool.id, toolCallResponse, isVisionModel) + return convertToMessage(toolResponse, toolCallResponse, model) }) - toolResults.push(...(await Promise.all(toolPromises))) + toolResults.push(...(await Promise.all(toolPromises)).filter((t) => typeof t !== 'undefined')) return toolResults } export function mcpToolCallResponseToOpenAICompatibleMessage( - toolCallId: string, + mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, isVisionModel: boolean = false ): ChatCompletionMessageParam { @@ -488,7 +587,7 @@ export function mcpToolCallResponseToOpenAICompatibleMessage( const content: ChatCompletionContentPart[] = [ { type: 'text', - text: `Here is the result of tool call ${toolCallId}:` + text: `Here is the result of mcp tool use \`${mcpToolResponse.tool.name}\`:` } ] @@ -541,7 +640,7 @@ export function mcpToolCallResponseToOpenAICompatibleMessage( } export function mcpToolCallResponseToOpenAIMessage( - toolCallId: string, + mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, isVisionModel: boolean = false ): OpenAI.Responses.EasyInputMessage { @@ -555,7 +654,7 @@ export function 
mcpToolCallResponseToOpenAIMessage( const content: OpenAI.Responses.ResponseInputContent[] = [ { type: 'input_text', - text: `Here is the result of tool call ${toolCallId}:` + text: `Here is the result of mcp tool use \`${mcpToolResponse.tool.name}\`:` } ] @@ -597,9 +696,9 @@ export function mcpToolCallResponseToOpenAIMessage( } export function mcpToolCallResponseToAnthropicMessage( - toolCallId: string, + mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, - isVisionModel: boolean = false + model: Model ): MessageParam { const message = { role: 'user' @@ -610,10 +709,10 @@ export function mcpToolCallResponseToAnthropicMessage( const content: ContentBlockParam[] = [ { type: 'text', - text: `Here is the result of tool call ${toolCallId}:` + text: `Here is the result of mcp tool use \`${mcpToolResponse.tool.name}\`:` } ] - if (isVisionModel) { + if (isVisionModel(model)) { for (const item of resp.content) { switch (item.type) { case 'text': @@ -665,7 +764,7 @@ export function mcpToolCallResponseToAnthropicMessage( } export function mcpToolCallResponseToGeminiMessage( - toolCallId: string, + mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, isVisionModel: boolean = false ): Content { @@ -682,7 +781,7 @@ export function mcpToolCallResponseToGeminiMessage( } else { const parts: Part[] = [ { - text: `Here is the result of tool call ${toolCallId}:` + text: `Here is the result of mcp tool use \`${mcpToolResponse.tool.name}\`:` } ] if (isVisionModel) { diff --git a/src/renderer/src/utils/prompt.ts b/src/renderer/src/utils/prompt.ts index 698fdb1186..d014bd4ce0 100644 --- a/src/renderer/src/utils/prompt.ts +++ b/src/renderer/src/utils/prompt.ts @@ -147,7 +147,7 @@ ${availableTools} ` } -export const buildSystemPrompt = (userSystemPrompt: string, tools: MCPTool[]): string => { +export const buildSystemPrompt = (userSystemPrompt: string, tools?: MCPTool[]): string => { if (tools && tools.length > 0) { return SYSTEM_PROMPT.replace('{{ USER_SYSTEM_PROMPT }}', userSystemPrompt) .replace('{{ TOOL_USE_EXAMPLES }}', ToolUseExamples)
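A few illustrative sketches follow. They are editorial additions, not part of the patch, and the import paths assume the repository's @renderer alias. First, a minimal example of how the new conversion helpers in mcp-tools.ts turn an MCP tool definition into provider-side tool declarations; the searchTool literal is hypothetical and only shows the fields the converters read.

import { MCPTool } from '@renderer/types'
import { mcpToolsToOpenAIChatTools, mcpToolsToOpenAIResponseTools } from '@renderer/utils/mcp-tools'

// Hypothetical MCP tool; only the fields used by the converters are shown.
const searchTool = {
  id: 'f_search',
  serverId: 'server-1',
  serverName: 'docs',
  name: 'search',
  description: 'Full-text search over the indexed documents',
  inputSchema: {
    type: 'object',
    properties: {
      query: { type: 'string', description: 'Search query' },
      limit: { type: 'number', description: 'Maximum number of results' }
    },
    required: ['query']
  }
} as MCPTool

// Responses API: filterProperties() keeps only the supported schema keys, forces
// additionalProperties: false and derives `required` from the property names minus
// EXTRA_SCHEMA_KEYS, so the declaration satisfies strict mode.
const responseTools = mcpToolsToOpenAIResponseTools([searchTool])

// Chat Completions API: the input schema's properties and required list are passed through as-is.
const chatTools = mcpToolsToOpenAIChatTools([searchTool])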
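Second, a sketch of the accumulation performed by the new 'tool-calls' case in OpenAICompatibleProvider while streaming: the first delta for a call carries id, type and function.name, and later deltas only append argument fragments at the same index. The ToolCallDelta type is a simplified stand-in for the OpenAI SDK's streamed delta shape.

import { ChatCompletionMessageToolCall } from 'openai/resources'

// Simplified stand-in for the SDK's streamed tool_call delta shape.
type ToolCallDelta = {
  index: number
  id?: string
  type?: string
  function?: { name?: string; arguments?: string }
}

export function accumulateToolCallDeltas(
  toolCalls: ChatCompletionMessageToolCall[],
  deltas: ToolCallDelta[]
): ChatCompletionMessageToolCall[] {
  for (const delta of deltas) {
    const { id, index, type, function: fun } = delta
    if (id && type === 'function' && fun) {
      // First chunk for this call: start a new entry with whatever is present so far.
      toolCalls.push({ id, type: 'function', function: { name: fun.name || '', arguments: fun.arguments || '' } })
    } else if (fun?.arguments) {
      // Later chunks carry only argument fragments; append them to the call at the same index.
      toolCalls[index].function.arguments += fun.arguments
    }
  }
  return toolCalls
}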
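Finally, a sketch of how the reshaped parseAndCallTools is meant to be driven from a provider, covering both entry points visible in the diff: an array of ToolCallResponse objects built from native tool_calls, and raw assistant text containing <tool_use> blocks. The runTools wrapper and its parameters are illustrative; the converter mirrors how the providers bind mcpToolCallResponseToMessage, and the CompletionsParams import path is assumed.

import { isVisionModel } from '@renderer/config/models'
import { CompletionsParams } from '@renderer/providers/AiProvider'
import { MCPCallToolResponse, MCPTool, MCPToolResponse, Model, ToolCallResponse } from '@renderer/types'
import {
  mcpToolCallResponseToOpenAICompatibleMessage,
  openAIToolsToMcpTool,
  parseAndCallTools
} from '@renderer/utils/mcp-tools'
import { ChatCompletionMessageParam, ChatCompletionMessageToolCall } from 'openai/resources'

// Converter for the OpenAI-compatible path: MCP call result -> chat message.
const toMessage = (
  toolResponse: MCPToolResponse,
  resp: MCPCallToolResponse,
  model: Model
): ChatCompletionMessageParam => mcpToolCallResponseToOpenAICompatibleMessage(toolResponse, resp, isVisionModel(model))

// Illustrative wrapper: run native tool calls first, then any <tool_use> blocks in the text.
export async function runTools(
  assistantText: string,
  toolCalls: ChatCompletionMessageToolCall[],
  onChunk: CompletionsParams['onChunk'],
  model: Model,
  mcpTools: MCPTool[]
) {
  const allToolResponses: MCPToolResponse[] = []

  // Native function calling: map SDK tool_calls onto ToolCallResponse objects.
  const pending = toolCalls
    .map((call) => {
      const tool = openAIToolsToMcpTool(mcpTools, call)
      if (!tool) return undefined
      const parsedArgs = (() => {
        try {
          return JSON.parse(call.function.arguments)
        } catch {
          return call.function.arguments
        }
      })()
      return { id: call.id, toolCallId: call.id, tool, arguments: parsedArgs, status: 'pending' } as ToolCallResponse
    })
    .filter((t): t is ToolCallResponse => t !== undefined)
  const fromCalls = await parseAndCallTools(pending, allToolResponses, onChunk, toMessage, model, mcpTools)

  // Prompt-based tool use: <tool_use> blocks in the assistant text are parsed by parseToolUse().
  const fromText = await parseAndCallTools(assistantText, allToolResponses, onChunk, toMessage, model, mcpTools)

  // Both arrays hold converted messages ready to be appended to reqMessages before the follow-up request.
  return [...fromCalls, ...fromText]
}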