mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-29 14:31:35 +08:00
Merge branch 'main' of https://github.com/CherryHQ/cherry-studio into wip/refactor/databases
This commit is contained in:
commit
4bb5ff8086
88
.github/dependabot.yml
vendored
88
.github/dependabot.yml
vendored
@ -4,38 +4,26 @@ updates:
|
||||
directory: "/"
|
||||
schedule:
|
||||
interval: "monthly"
|
||||
open-pull-requests-limit: 7
|
||||
open-pull-requests-limit: 5
|
||||
target-branch: "main"
|
||||
commit-message:
|
||||
prefix: "chore"
|
||||
include: "scope"
|
||||
ignore:
|
||||
- dependency-name: "*"
|
||||
update-types:
|
||||
- "version-update:semver-major"
|
||||
- dependency-name: "@google/genai"
|
||||
- dependency-name: "antd"
|
||||
- dependency-name: "epub"
|
||||
- dependency-name: "openai"
|
||||
groups:
|
||||
# 核心框架
|
||||
core-framework:
|
||||
# CherryStudio 自定义包
|
||||
cherrystudio-packages:
|
||||
patterns:
|
||||
- "react"
|
||||
- "react-dom"
|
||||
- "electron"
|
||||
- "typescript"
|
||||
- "@types/react*"
|
||||
- "@types/node"
|
||||
update-types:
|
||||
- "minor"
|
||||
- "patch"
|
||||
|
||||
# Electron 生态和构建工具
|
||||
electron-build:
|
||||
patterns:
|
||||
- "electron-*"
|
||||
- "@electron*"
|
||||
- "vite"
|
||||
- "@vitejs/*"
|
||||
- "dotenv-cli"
|
||||
- "rollup-plugin-*"
|
||||
- "@swc/*"
|
||||
update-types:
|
||||
- "minor"
|
||||
- "patch"
|
||||
- "@cherrystudio/*"
|
||||
- "@kangfenmao/*"
|
||||
- "selection-hook"
|
||||
|
||||
# 测试工具
|
||||
testing-tools:
|
||||
@ -44,30 +32,40 @@ updates:
|
||||
- "@vitest/*"
|
||||
- "playwright"
|
||||
- "@playwright/*"
|
||||
- "eslint*"
|
||||
- "@eslint*"
|
||||
- "testing-library/*"
|
||||
- "jest-styled-components"
|
||||
|
||||
# Lint 工具
|
||||
lint-tools:
|
||||
patterns:
|
||||
- "eslint"
|
||||
- "eslint-plugin-*"
|
||||
- "@eslint/*"
|
||||
- "@eslint-react/*"
|
||||
- "@electron-toolkit/eslint-config-*"
|
||||
- "prettier"
|
||||
- "husky"
|
||||
- "lint-staged"
|
||||
update-types:
|
||||
- "minor"
|
||||
- "patch"
|
||||
|
||||
# CherryStudio 自定义包
|
||||
cherrystudio-packages:
|
||||
# Markdown
|
||||
markdown:
|
||||
patterns:
|
||||
- "@cherrystudio/*"
|
||||
update-types:
|
||||
- "minor"
|
||||
- "patch"
|
||||
|
||||
# 兜底其他 dependencies
|
||||
other-dependencies:
|
||||
dependency-type: "production"
|
||||
|
||||
# 兜底其他 devDependencies
|
||||
other-dev-dependencies:
|
||||
dependency-type: "development"
|
||||
- "react-markdown"
|
||||
- "rehype-katex"
|
||||
- "rehype-mathjax"
|
||||
- "rehype-raw"
|
||||
- "remark-cjk-friendly"
|
||||
- "remark-gfm"
|
||||
- "remark-math"
|
||||
- "remove-markdown"
|
||||
- "markdown-it"
|
||||
- "@shikijs/markdown-it"
|
||||
- "shiki"
|
||||
- "@uiw/codemirror-extensions-langs"
|
||||
- "@uiw/codemirror-themes-all"
|
||||
- "@uiw/react-codemirror"
|
||||
- "fast-diff"
|
||||
- "mermaid"
|
||||
|
||||
- package-ecosystem: "github-actions"
|
||||
directory: "/"
|
||||
|
||||
2
.github/workflows/nightly-build.yml
vendored
2
.github/workflows/nightly-build.yml
vendored
@ -53,7 +53,7 @@ jobs:
|
||||
- name: Check out Git repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: develop
|
||||
ref: main
|
||||
|
||||
- name: Install Node.js
|
||||
uses: actions/setup-node@v4
|
||||
|
||||
2
.github/workflows/pr-ci.yml
vendored
2
.github/workflows/pr-ci.yml
vendored
@ -44,4 +44,4 @@ jobs:
|
||||
run: yarn build:check
|
||||
|
||||
- name: Lint Check
|
||||
run: yarn lint
|
||||
run: yarn test:lint
|
||||
|
||||
39
.github/workflows/release.yml
vendored
39
.github/workflows/release.yml
vendored
@ -27,7 +27,7 @@ jobs:
|
||||
- name: Check out Git repository
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: main
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Get release tag
|
||||
id: get-tag
|
||||
@ -113,5 +113,40 @@ jobs:
|
||||
allowUpdates: true
|
||||
makeLatest: false
|
||||
tag: ${{ steps.get-tag.outputs.tag }}
|
||||
artifacts: 'dist/*.exe,dist/*.zip,dist/*.dmg,dist/*.AppImage,dist/*.snap,dist/*.deb,dist/*.rpm,dist/*.tar.gz,dist/latest*.yml,dist/*.blockmap'
|
||||
artifacts: 'dist/*.exe,dist/*.zip,dist/*.dmg,dist/*.AppImage,dist/*.snap,dist/*.deb,dist/*.rpm,dist/*.tar.gz,dist/latest*.yml,dist/rc*.yml,dist/*.blockmap'
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
dispatch-docs-update:
|
||||
needs: release
|
||||
if: success() && github.repository == 'CherryHQ/cherry-studio' # 确保所有构建成功且在主仓库中运行
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Get release tag
|
||||
id: get-tag
|
||||
shell: bash
|
||||
run: |
|
||||
if [ "${{ github.event_name }}" = "workflow_dispatch" ]; then
|
||||
echo "tag=${{ github.event.inputs.tag }}" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "tag=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Check if tag is pre-release
|
||||
id: check-tag
|
||||
shell: bash
|
||||
run: |
|
||||
TAG="${{ steps.get-tag.outputs.tag }}"
|
||||
if [[ "$TAG" == *"rc"* || "$TAG" == *"pre-release"* ]]; then
|
||||
echo "is_pre_release=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "is_pre_release=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Dispatch update-download-version workflow to cherry-studio-docs
|
||||
if: steps.check-tag.outputs.is_pre_release == 'false'
|
||||
uses: peter-evans/repository-dispatch@v3
|
||||
with:
|
||||
token: ${{ secrets.REPO_DISPATCH_TOKEN }}
|
||||
repository: CherryHQ/cherry-studio-docs
|
||||
event-type: update-download-version
|
||||
client-payload: '{"version": "${{ steps.get-tag.outputs.tag }}"}'
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@ -45,7 +45,7 @@ stats.html
|
||||
local
|
||||
.aider*
|
||||
.cursorrules
|
||||
.cursor/rules
|
||||
.cursor/*
|
||||
|
||||
# vitest
|
||||
coverage
|
||||
|
||||
1
.vscode/launch.json
vendored
1
.vscode/launch.json
vendored
@ -7,7 +7,6 @@
|
||||
"request": "launch",
|
||||
"cwd": "${workspaceRoot}",
|
||||
"runtimeExecutable": "${workspaceRoot}/node_modules/.bin/electron-vite",
|
||||
"runtimeVersion": "20",
|
||||
"windows": {
|
||||
"runtimeExecutable": "${workspaceRoot}/node_modules/.bin/electron-vite.cmd"
|
||||
},
|
||||
|
||||
3
.vscode/settings.json
vendored
3
.vscode/settings.json
vendored
@ -1,7 +1,8 @@
|
||||
{
|
||||
"editor.formatOnSave": true,
|
||||
"editor.codeActionsOnSave": {
|
||||
"source.fixAll.eslint": "explicit"
|
||||
"source.fixAll.eslint": "explicit",
|
||||
"source.organizeImports": "never"
|
||||
},
|
||||
"search.exclude": {
|
||||
"**/dist/**": true,
|
||||
|
||||
6471
.yarn/patches/@google-genai-npm-1.0.1-e26f0f9af7.patch
vendored
Normal file
6471
.yarn/patches/@google-genai-npm-1.0.1-e26f0f9af7.patch
vendored
Normal file
File diff suppressed because one or more lines are too long
71
.yarn/patches/@langchain-core-npm-0.3.44-41d5c3cb0a.patch
vendored
Normal file
71
.yarn/patches/@langchain-core-npm-0.3.44-41d5c3cb0a.patch
vendored
Normal file
@ -0,0 +1,71 @@
|
||||
diff --git a/dist/utils/tiktoken.cjs b/dist/utils/tiktoken.cjs
|
||||
index 973b0d0e75aeaf8de579419af31b879b32975413..f23c7caa8b9dc8bd404132725346a4786f6b278b 100644
|
||||
--- a/dist/utils/tiktoken.cjs
|
||||
+++ b/dist/utils/tiktoken.cjs
|
||||
@@ -1,25 +1,14 @@
|
||||
"use strict";
|
||||
Object.defineProperty(exports, "__esModule", { value: true });
|
||||
exports.encodingForModel = exports.getEncoding = void 0;
|
||||
-const lite_1 = require("js-tiktoken/lite");
|
||||
const async_caller_js_1 = require("./async_caller.cjs");
|
||||
const cache = {};
|
||||
const caller = /* #__PURE__ */ new async_caller_js_1.AsyncCaller({});
|
||||
async function getEncoding(encoding) {
|
||||
- if (!(encoding in cache)) {
|
||||
- cache[encoding] = caller
|
||||
- .fetch(`https://tiktoken.pages.dev/js/${encoding}.json`)
|
||||
- .then((res) => res.json())
|
||||
- .then((data) => new lite_1.Tiktoken(data))
|
||||
- .catch((e) => {
|
||||
- delete cache[encoding];
|
||||
- throw e;
|
||||
- });
|
||||
- }
|
||||
- return await cache[encoding];
|
||||
+ throw new Error("TikToken Not implemented");
|
||||
}
|
||||
exports.getEncoding = getEncoding;
|
||||
async function encodingForModel(model) {
|
||||
- return getEncoding((0, lite_1.getEncodingNameForModel)(model));
|
||||
+ throw new Error("TikToken Not implemented");
|
||||
}
|
||||
exports.encodingForModel = encodingForModel;
|
||||
diff --git a/dist/utils/tiktoken.js b/dist/utils/tiktoken.js
|
||||
index 8e41ee6f00f2f9c7fa2c59fa2b2f4297634b97aa..aa5f314a6349ad0d1c5aea8631a56aad099176e0 100644
|
||||
--- a/dist/utils/tiktoken.js
|
||||
+++ b/dist/utils/tiktoken.js
|
||||
@@ -1,20 +1,9 @@
|
||||
-import { Tiktoken, getEncodingNameForModel, } from "js-tiktoken/lite";
|
||||
import { AsyncCaller } from "./async_caller.js";
|
||||
const cache = {};
|
||||
const caller = /* #__PURE__ */ new AsyncCaller({});
|
||||
export async function getEncoding(encoding) {
|
||||
- if (!(encoding in cache)) {
|
||||
- cache[encoding] = caller
|
||||
- .fetch(`https://tiktoken.pages.dev/js/${encoding}.json`)
|
||||
- .then((res) => res.json())
|
||||
- .then((data) => new Tiktoken(data))
|
||||
- .catch((e) => {
|
||||
- delete cache[encoding];
|
||||
- throw e;
|
||||
- });
|
||||
- }
|
||||
- return await cache[encoding];
|
||||
+ throw new Error("TikToken Not implemented");
|
||||
}
|
||||
export async function encodingForModel(model) {
|
||||
- return getEncoding(getEncodingNameForModel(model));
|
||||
+ throw new Error("TikToken Not implemented");
|
||||
}
|
||||
diff --git a/package.json b/package.json
|
||||
index 36072aecf700fca1bc49832a19be832eca726103..90b8922fba1c3d1b26f78477c891b07816d6238a 100644
|
||||
--- a/package.json
|
||||
+++ b/package.json
|
||||
@@ -37,7 +37,6 @@
|
||||
"ansi-styles": "^5.0.0",
|
||||
"camelcase": "6",
|
||||
"decamelize": "1.2.0",
|
||||
- "js-tiktoken": "^1.0.12",
|
||||
"langsmith": ">=0.2.8 <0.4.0",
|
||||
"mustache": "^4.2.0",
|
||||
"p-queue": "^6.6.2",
|
||||
69
.yarn/patches/antd-npm-5.24.7-356a553ae5.patch
vendored
Normal file
69
.yarn/patches/antd-npm-5.24.7-356a553ae5.patch
vendored
Normal file
@ -0,0 +1,69 @@
|
||||
diff --git a/es/dropdown/dropdown.js b/es/dropdown/dropdown.js
|
||||
index 986877a762b9ad0aca596a8552732cd12d2eaabb..1f18aa2ea745e68950e4cee16d4d655f5c835fd5 100644
|
||||
--- a/es/dropdown/dropdown.js
|
||||
+++ b/es/dropdown/dropdown.js
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import * as React from 'react';
|
||||
import LeftOutlined from "@ant-design/icons/es/icons/LeftOutlined";
|
||||
-import RightOutlined from "@ant-design/icons/es/icons/RightOutlined";
|
||||
+import { ChevronRight } from 'lucide-react';
|
||||
import classNames from 'classnames';
|
||||
import RcDropdown from 'rc-dropdown';
|
||||
import useEvent from "rc-util/es/hooks/useEvent";
|
||||
@@ -158,8 +158,10 @@ const Dropdown = props => {
|
||||
className: `${prefixCls}-menu-submenu-arrow`
|
||||
}, direction === 'rtl' ? (/*#__PURE__*/React.createElement(LeftOutlined, {
|
||||
className: `${prefixCls}-menu-submenu-arrow-icon`
|
||||
- })) : (/*#__PURE__*/React.createElement(RightOutlined, {
|
||||
- className: `${prefixCls}-menu-submenu-arrow-icon`
|
||||
+ })) : (/*#__PURE__*/React.createElement(ChevronRight, {
|
||||
+ size: 16,
|
||||
+ strokeWidth: 1.8,
|
||||
+ className: `${prefixCls}-menu-submenu-arrow-icon lucide-custom`
|
||||
}))),
|
||||
mode: "vertical",
|
||||
selectable: false,
|
||||
diff --git a/es/dropdown/style/index.js b/es/dropdown/style/index.js
|
||||
index 768c01783002c6901c85a73061ff6b3e776a60ce..39b1b95a56cdc9fb586a193c3adad5141f5cf213 100644
|
||||
--- a/es/dropdown/style/index.js
|
||||
+++ b/es/dropdown/style/index.js
|
||||
@@ -240,7 +240,8 @@ const genBaseStyle = token => {
|
||||
marginInlineEnd: '0 !important',
|
||||
color: token.colorTextDescription,
|
||||
fontSize: fontSizeIcon,
|
||||
- fontStyle: 'normal'
|
||||
+ fontStyle: 'normal',
|
||||
+ marginTop: 3,
|
||||
}
|
||||
}
|
||||
}),
|
||||
diff --git a/es/select/useIcons.js b/es/select/useIcons.js
|
||||
index 959115be936ef8901548af2658c5dcfdc5852723..c812edd52123eb0faf4638b1154fcfa1b05b513b 100644
|
||||
--- a/es/select/useIcons.js
|
||||
+++ b/es/select/useIcons.js
|
||||
@@ -4,10 +4,10 @@ import * as React from 'react';
|
||||
import CheckOutlined from "@ant-design/icons/es/icons/CheckOutlined";
|
||||
import CloseCircleFilled from "@ant-design/icons/es/icons/CloseCircleFilled";
|
||||
import CloseOutlined from "@ant-design/icons/es/icons/CloseOutlined";
|
||||
-import DownOutlined from "@ant-design/icons/es/icons/DownOutlined";
|
||||
import LoadingOutlined from "@ant-design/icons/es/icons/LoadingOutlined";
|
||||
import SearchOutlined from "@ant-design/icons/es/icons/SearchOutlined";
|
||||
import { devUseWarning } from '../_util/warning';
|
||||
+import { ChevronDown } from 'lucide-react';
|
||||
export default function useIcons(_ref) {
|
||||
let {
|
||||
suffixIcon,
|
||||
@@ -56,8 +56,10 @@ export default function useIcons(_ref) {
|
||||
className: iconCls
|
||||
}));
|
||||
}
|
||||
- return getSuffixIconNode(/*#__PURE__*/React.createElement(DownOutlined, {
|
||||
- className: iconCls
|
||||
+ return getSuffixIconNode(/*#__PURE__*/React.createElement(ChevronDown, {
|
||||
+ size: 16,
|
||||
+ strokeWidth: 1.8,
|
||||
+ className: `${iconCls} lucide-custom`
|
||||
}));
|
||||
};
|
||||
}
|
||||
@ -65,11 +65,44 @@ index e8bd7bb46c8a54b3f55cf3a853ef924195271e01..f956e9f3fe9eb903c78aef3502553b01
|
||||
await packager.info.emitArtifactBuildCompleted({
|
||||
file: installerPath,
|
||||
updateInfo,
|
||||
diff --git a/out/util/yarn.js b/out/util/yarn.js
|
||||
index 1ee20f8b252a8f28d0c7b103789cf0a9a427aec1..c2878ec54d57da50bf14225e0c70c9c88664eb8a 100644
|
||||
--- a/out/util/yarn.js
|
||||
+++ b/out/util/yarn.js
|
||||
@@ -140,6 +140,7 @@ async function rebuild(config, { appDir, projectDir }, options) {
|
||||
arch,
|
||||
platform,
|
||||
buildFromSource,
|
||||
+ ignoreModules: config.excludeReBuildModules || undefined,
|
||||
projectRootPath: projectDir,
|
||||
mode: config.nativeRebuilder || "sequential",
|
||||
disablePreGypCopy: true,
|
||||
diff --git a/scheme.json b/scheme.json
|
||||
index 433e2efc9cef156ff5444f0c4520362ed2ef9ea7..a89c7a9b0b608fef67902c49106a43ebd0fa8b61 100644
|
||||
index 433e2efc9cef156ff5444f0c4520362ed2ef9ea7..0167441bf928a92f59b5dbe70b2317a74dda74c9 100644
|
||||
--- a/scheme.json
|
||||
+++ b/scheme.json
|
||||
@@ -1975,6 +1975,13 @@
|
||||
@@ -1825,6 +1825,20 @@
|
||||
"string"
|
||||
]
|
||||
},
|
||||
+ "excludeReBuildModules": {
|
||||
+ "anyOf": [
|
||||
+ {
|
||||
+ "items": {
|
||||
+ "type": "string"
|
||||
+ },
|
||||
+ "type": "array"
|
||||
+ },
|
||||
+ {
|
||||
+ "type": "null"
|
||||
+ }
|
||||
+ ],
|
||||
+ "description": "The modules to exclude from the rebuild."
|
||||
+ },
|
||||
"executableArgs": {
|
||||
"anyOf": [
|
||||
{
|
||||
@@ -1975,6 +1989,13 @@
|
||||
],
|
||||
"description": "The mime types in addition to specified in the file associations. Use it if you don't want to register a new mime type, but reuse existing."
|
||||
},
|
||||
@ -83,7 +116,7 @@ index 433e2efc9cef156ff5444f0c4520362ed2ef9ea7..a89c7a9b0b608fef67902c49106a43eb
|
||||
"packageCategory": {
|
||||
"description": "backward compatibility + to allow specify fpm-only category for all possible fpm targets in one place",
|
||||
"type": [
|
||||
@@ -2327,6 +2334,13 @@
|
||||
@@ -2327,6 +2348,13 @@
|
||||
"MacConfiguration": {
|
||||
"additionalProperties": false,
|
||||
"properties": {
|
||||
@ -97,7 +130,28 @@ index 433e2efc9cef156ff5444f0c4520362ed2ef9ea7..a89c7a9b0b608fef67902c49106a43eb
|
||||
"additionalArguments": {
|
||||
"anyOf": [
|
||||
{
|
||||
@@ -2737,7 +2751,7 @@
|
||||
@@ -2527,6 +2555,20 @@
|
||||
"string"
|
||||
]
|
||||
},
|
||||
+ "excludeReBuildModules": {
|
||||
+ "anyOf": [
|
||||
+ {
|
||||
+ "items": {
|
||||
+ "type": "string"
|
||||
+ },
|
||||
+ "type": "array"
|
||||
+ },
|
||||
+ {
|
||||
+ "type": "null"
|
||||
+ }
|
||||
+ ],
|
||||
+ "description": "The modules to exclude from the rebuild."
|
||||
+ },
|
||||
"executableName": {
|
||||
"description": "The executable name. Defaults to `productName`.",
|
||||
"type": [
|
||||
@@ -2737,7 +2779,7 @@
|
||||
"type": "boolean"
|
||||
},
|
||||
"minimumSystemVersion": {
|
||||
@ -106,7 +160,7 @@ index 433e2efc9cef156ff5444f0c4520362ed2ef9ea7..a89c7a9b0b608fef67902c49106a43eb
|
||||
"type": [
|
||||
"null",
|
||||
"string"
|
||||
@@ -2959,6 +2973,13 @@
|
||||
@@ -2959,6 +3001,13 @@
|
||||
"MasConfiguration": {
|
||||
"additionalProperties": false,
|
||||
"properties": {
|
||||
@ -120,7 +174,28 @@ index 433e2efc9cef156ff5444f0c4520362ed2ef9ea7..a89c7a9b0b608fef67902c49106a43eb
|
||||
"additionalArguments": {
|
||||
"anyOf": [
|
||||
{
|
||||
@@ -3369,7 +3390,7 @@
|
||||
@@ -3159,6 +3208,20 @@
|
||||
"string"
|
||||
]
|
||||
},
|
||||
+ "excludeReBuildModules": {
|
||||
+ "anyOf": [
|
||||
+ {
|
||||
+ "items": {
|
||||
+ "type": "string"
|
||||
+ },
|
||||
+ "type": "array"
|
||||
+ },
|
||||
+ {
|
||||
+ "type": "null"
|
||||
+ }
|
||||
+ ],
|
||||
+ "description": "The modules to exclude from the rebuild."
|
||||
+ },
|
||||
"executableName": {
|
||||
"description": "The executable name. Defaults to `productName`.",
|
||||
"type": [
|
||||
@@ -3369,7 +3432,7 @@
|
||||
"type": "boolean"
|
||||
},
|
||||
"minimumSystemVersion": {
|
||||
@ -129,7 +204,28 @@ index 433e2efc9cef156ff5444f0c4520362ed2ef9ea7..a89c7a9b0b608fef67902c49106a43eb
|
||||
"type": [
|
||||
"null",
|
||||
"string"
|
||||
@@ -6507,6 +6528,13 @@
|
||||
@@ -6381,6 +6444,20 @@
|
||||
"string"
|
||||
]
|
||||
},
|
||||
+ "excludeReBuildModules": {
|
||||
+ "anyOf": [
|
||||
+ {
|
||||
+ "items": {
|
||||
+ "type": "string"
|
||||
+ },
|
||||
+ "type": "array"
|
||||
+ },
|
||||
+ {
|
||||
+ "type": "null"
|
||||
+ }
|
||||
+ ],
|
||||
+ "description": "The modules to exclude from the rebuild."
|
||||
+ },
|
||||
"executableName": {
|
||||
"description": "The executable name. Defaults to `productName`.",
|
||||
"type": [
|
||||
@@ -6507,6 +6584,13 @@
|
||||
"string"
|
||||
]
|
||||
},
|
||||
@ -143,7 +239,28 @@ index 433e2efc9cef156ff5444f0c4520362ed2ef9ea7..a89c7a9b0b608fef67902c49106a43eb
|
||||
"protocols": {
|
||||
"anyOf": [
|
||||
{
|
||||
@@ -7376,6 +7404,13 @@
|
||||
@@ -7153,6 +7237,20 @@
|
||||
"string"
|
||||
]
|
||||
},
|
||||
+ "excludeReBuildModules": {
|
||||
+ "anyOf": [
|
||||
+ {
|
||||
+ "items": {
|
||||
+ "type": "string"
|
||||
+ },
|
||||
+ "type": "array"
|
||||
+ },
|
||||
+ {
|
||||
+ "type": "null"
|
||||
+ }
|
||||
+ ],
|
||||
+ "description": "The modules to exclude from the rebuild."
|
||||
+ },
|
||||
"executableName": {
|
||||
"description": "The executable name. Defaults to `productName`.",
|
||||
"type": [
|
||||
@@ -7376,6 +7474,13 @@
|
||||
],
|
||||
"description": "MAS (Mac Application Store) development options (`mas-dev` target)."
|
||||
},
|
||||
|
||||
85
.yarn/patches/openai-npm-4.96.0-0665b05cb9.patch
vendored
85
.yarn/patches/openai-npm-4.96.0-0665b05cb9.patch
vendored
@ -1,85 +0,0 @@
|
||||
diff --git a/core.js b/core.js
|
||||
index 862d66101f441fb4f47dfc8cff5e2d39e1f5a11e..6464bebbf696c39d35f0368f061ea4236225c162 100644
|
||||
--- a/core.js
|
||||
+++ b/core.js
|
||||
@@ -159,7 +159,7 @@ class APIClient {
|
||||
Accept: 'application/json',
|
||||
'Content-Type': 'application/json',
|
||||
'User-Agent': this.getUserAgent(),
|
||||
- ...getPlatformHeaders(),
|
||||
+ // ...getPlatformHeaders(),
|
||||
...this.authHeaders(opts),
|
||||
};
|
||||
}
|
||||
diff --git a/core.mjs b/core.mjs
|
||||
index 05dbc6cfde51589a2b100d4e4b5b3c1a33b32b89..789fbb4985eb952a0349b779fa83b1a068af6e7e 100644
|
||||
--- a/core.mjs
|
||||
+++ b/core.mjs
|
||||
@@ -152,7 +152,7 @@ export class APIClient {
|
||||
Accept: 'application/json',
|
||||
'Content-Type': 'application/json',
|
||||
'User-Agent': this.getUserAgent(),
|
||||
- ...getPlatformHeaders(),
|
||||
+ // ...getPlatformHeaders(),
|
||||
...this.authHeaders(opts),
|
||||
};
|
||||
}
|
||||
diff --git a/error.mjs b/error.mjs
|
||||
index 7d19f5578040afa004bc887aab1725e8703d2bac..59ec725b6142299a62798ac4bdedb63ba7d9932c 100644
|
||||
--- a/error.mjs
|
||||
+++ b/error.mjs
|
||||
@@ -36,7 +36,7 @@ export class APIError extends OpenAIError {
|
||||
if (!status || !headers) {
|
||||
return new APIConnectionError({ message, cause: castToError(errorResponse) });
|
||||
}
|
||||
- const error = errorResponse?.['error'];
|
||||
+ const error = errorResponse?.['error'] || errorResponse;
|
||||
if (status === 400) {
|
||||
return new BadRequestError(status, error, message, headers);
|
||||
}
|
||||
diff --git a/resources/embeddings.js b/resources/embeddings.js
|
||||
index aae578404cb2d09a39ac33fc416f1c215c45eecd..25c54b05bdae64d5c3b36fbb30dc7c8221b14034 100644
|
||||
--- a/resources/embeddings.js
|
||||
+++ b/resources/embeddings.js
|
||||
@@ -36,6 +36,9 @@ class Embeddings extends resource_1.APIResource {
|
||||
// No encoding_format specified, defaulting to base64 for performance reasons
|
||||
// See https://github.com/openai/openai-node/pull/1312
|
||||
let encoding_format = hasUserProvidedEncodingFormat ? body.encoding_format : 'base64';
|
||||
+ if (body.model.includes('jina')) {
|
||||
+ encoding_format = undefined;
|
||||
+ }
|
||||
if (hasUserProvidedEncodingFormat) {
|
||||
Core.debug('Request', 'User defined encoding_format:', body.encoding_format);
|
||||
}
|
||||
@@ -47,7 +50,7 @@ class Embeddings extends resource_1.APIResource {
|
||||
...options,
|
||||
});
|
||||
// if the user specified an encoding_format, return the response as-is
|
||||
- if (hasUserProvidedEncodingFormat) {
|
||||
+ if (hasUserProvidedEncodingFormat || body.model.includes('jina')) {
|
||||
return response;
|
||||
}
|
||||
// in this stage, we are sure the user did not specify an encoding_format
|
||||
diff --git a/resources/embeddings.mjs b/resources/embeddings.mjs
|
||||
index 0df3c6cc79a520e54acb4c2b5f77c43b774035ff..aa488b8a11b2c413c0a663d9a6059d286d7b5faf 100644
|
||||
--- a/resources/embeddings.mjs
|
||||
+++ b/resources/embeddings.mjs
|
||||
@@ -10,6 +10,9 @@ export class Embeddings extends APIResource {
|
||||
// No encoding_format specified, defaulting to base64 for performance reasons
|
||||
// See https://github.com/openai/openai-node/pull/1312
|
||||
let encoding_format = hasUserProvidedEncodingFormat ? body.encoding_format : 'base64';
|
||||
+ if (body.model.includes('jina')) {
|
||||
+ encoding_format = undefined;
|
||||
+ }
|
||||
if (hasUserProvidedEncodingFormat) {
|
||||
Core.debug('Request', 'User defined encoding_format:', body.encoding_format);
|
||||
}
|
||||
@@ -21,7 +24,7 @@ export class Embeddings extends APIResource {
|
||||
...options,
|
||||
});
|
||||
// if the user specified an encoding_format, return the response as-is
|
||||
- if (hasUserProvidedEncodingFormat) {
|
||||
+ if (hasUserProvidedEncodingFormat || body.model.includes('jina')) {
|
||||
return response;
|
||||
}
|
||||
// in this stage, we are sure the user did not specify an encoding_format
|
||||
279
.yarn/patches/openai-npm-5.1.0-0e7b3ccb07.patch
vendored
Normal file
279
.yarn/patches/openai-npm-5.1.0-0e7b3ccb07.patch
vendored
Normal file
@ -0,0 +1,279 @@
|
||||
diff --git a/client.js b/client.js
|
||||
index 33b4ff6309d5f29187dab4e285d07dac20340bab..8f568637ee9e4677585931fb0284c8165a933f69 100644
|
||||
--- a/client.js
|
||||
+++ b/client.js
|
||||
@@ -433,7 +433,7 @@ class OpenAI {
|
||||
'User-Agent': this.getUserAgent(),
|
||||
'X-Stainless-Retry-Count': String(retryCount),
|
||||
...(options.timeout ? { 'X-Stainless-Timeout': String(Math.trunc(options.timeout / 1000)) } : {}),
|
||||
- ...(0, detect_platform_1.getPlatformHeaders)(),
|
||||
+ // ...(0, detect_platform_1.getPlatformHeaders)(),
|
||||
'OpenAI-Organization': this.organization,
|
||||
'OpenAI-Project': this.project,
|
||||
},
|
||||
diff --git a/client.mjs b/client.mjs
|
||||
index c34c18213073540ebb296ea540b1d1ad39527906..1ce1a98256d7e90e26ca963582f235b23e996e73 100644
|
||||
--- a/client.mjs
|
||||
+++ b/client.mjs
|
||||
@@ -430,7 +430,7 @@ export class OpenAI {
|
||||
'User-Agent': this.getUserAgent(),
|
||||
'X-Stainless-Retry-Count': String(retryCount),
|
||||
...(options.timeout ? { 'X-Stainless-Timeout': String(Math.trunc(options.timeout / 1000)) } : {}),
|
||||
- ...getPlatformHeaders(),
|
||||
+ // ...getPlatformHeaders(),
|
||||
'OpenAI-Organization': this.organization,
|
||||
'OpenAI-Project': this.project,
|
||||
},
|
||||
diff --git a/core/error.js b/core/error.js
|
||||
index a12d9d9ccd242050161adeb0f82e1b98d9e78e20..fe3a5462480558bc426deea147f864f12b36f9bd 100644
|
||||
--- a/core/error.js
|
||||
+++ b/core/error.js
|
||||
@@ -40,7 +40,7 @@ class APIError extends OpenAIError {
|
||||
if (!status || !headers) {
|
||||
return new APIConnectionError({ message, cause: (0, errors_1.castToError)(errorResponse) });
|
||||
}
|
||||
- const error = errorResponse?.['error'];
|
||||
+ const error = errorResponse?.['error'] || errorResponse;
|
||||
if (status === 400) {
|
||||
return new BadRequestError(status, error, message, headers);
|
||||
}
|
||||
diff --git a/core/error.mjs b/core/error.mjs
|
||||
index 83cefbaffeb8c657536347322d8de9516af479a2..63334b7972ec04882aa4a0800c1ead5982345045 100644
|
||||
--- a/core/error.mjs
|
||||
+++ b/core/error.mjs
|
||||
@@ -36,7 +36,7 @@ export class APIError extends OpenAIError {
|
||||
if (!status || !headers) {
|
||||
return new APIConnectionError({ message, cause: castToError(errorResponse) });
|
||||
}
|
||||
- const error = errorResponse?.['error'];
|
||||
+ const error = errorResponse?.['error'] || errorResponse;
|
||||
if (status === 400) {
|
||||
return new BadRequestError(status, error, message, headers);
|
||||
}
|
||||
diff --git a/resources/embeddings.js b/resources/embeddings.js
|
||||
index 2404264d4ba0204322548945ebb7eab3bea82173..8f1bc45cc45e0797d50989d96b51147b90ae6790 100644
|
||||
--- a/resources/embeddings.js
|
||||
+++ b/resources/embeddings.js
|
||||
@@ -5,52 +5,64 @@ exports.Embeddings = void 0;
|
||||
const resource_1 = require("../core/resource.js");
|
||||
const utils_1 = require("../internal/utils.js");
|
||||
class Embeddings extends resource_1.APIResource {
|
||||
- /**
|
||||
- * Creates an embedding vector representing the input text.
|
||||
- *
|
||||
- * @example
|
||||
- * ```ts
|
||||
- * const createEmbeddingResponse =
|
||||
- * await client.embeddings.create({
|
||||
- * input: 'The quick brown fox jumped over the lazy dog',
|
||||
- * model: 'text-embedding-3-small',
|
||||
- * });
|
||||
- * ```
|
||||
- */
|
||||
- create(body, options) {
|
||||
- const hasUserProvidedEncodingFormat = !!body.encoding_format;
|
||||
- // No encoding_format specified, defaulting to base64 for performance reasons
|
||||
- // See https://github.com/openai/openai-node/pull/1312
|
||||
- let encoding_format = hasUserProvidedEncodingFormat ? body.encoding_format : 'base64';
|
||||
- if (hasUserProvidedEncodingFormat) {
|
||||
- (0, utils_1.loggerFor)(this._client).debug('embeddings/user defined encoding_format:', body.encoding_format);
|
||||
- }
|
||||
- const response = this._client.post('/embeddings', {
|
||||
- body: {
|
||||
- ...body,
|
||||
- encoding_format: encoding_format,
|
||||
- },
|
||||
- ...options,
|
||||
- });
|
||||
- // if the user specified an encoding_format, return the response as-is
|
||||
- if (hasUserProvidedEncodingFormat) {
|
||||
- return response;
|
||||
- }
|
||||
- // in this stage, we are sure the user did not specify an encoding_format
|
||||
- // and we defaulted to base64 for performance reasons
|
||||
- // we are sure then that the response is base64 encoded, let's decode it
|
||||
- // the returned result will be a float32 array since this is OpenAI API's default encoding
|
||||
- (0, utils_1.loggerFor)(this._client).debug('embeddings/decoding base64 embeddings from base64');
|
||||
- return response._thenUnwrap((response) => {
|
||||
- if (response && response.data) {
|
||||
- response.data.forEach((embeddingBase64Obj) => {
|
||||
- const embeddingBase64Str = embeddingBase64Obj.embedding;
|
||||
- embeddingBase64Obj.embedding = (0, utils_1.toFloat32Array)(embeddingBase64Str);
|
||||
- });
|
||||
- }
|
||||
- return response;
|
||||
- });
|
||||
- }
|
||||
+ /**
|
||||
+ * Creates an embedding vector representing the input text.
|
||||
+ *
|
||||
+ * @example
|
||||
+ * ```ts
|
||||
+ * const createEmbeddingResponse =
|
||||
+ * await client.embeddings.create({
|
||||
+ * input: 'The quick brown fox jumped over the lazy dog',
|
||||
+ * model: 'text-embedding-3-small',
|
||||
+ * });
|
||||
+ * ```
|
||||
+ */
|
||||
+ create(body, options) {
|
||||
+ const hasUserProvidedEncodingFormat = !!body.encoding_format;
|
||||
+ // No encoding_format specified, defaulting to base64 for performance reasons
|
||||
+ // See https://github.com/openai/openai-node/pull/1312
|
||||
+ let encoding_format = hasUserProvidedEncodingFormat
|
||||
+ ? body.encoding_format
|
||||
+ : "base64";
|
||||
+ if (body.model.includes("jina")) {
|
||||
+ encoding_format = undefined;
|
||||
+ }
|
||||
+ if (hasUserProvidedEncodingFormat) {
|
||||
+ (0, utils_1.loggerFor)(this._client).debug(
|
||||
+ "embeddings/user defined encoding_format:",
|
||||
+ body.encoding_format
|
||||
+ );
|
||||
+ }
|
||||
+ const response = this._client.post("/embeddings", {
|
||||
+ body: {
|
||||
+ ...body,
|
||||
+ encoding_format: encoding_format,
|
||||
+ },
|
||||
+ ...options,
|
||||
+ });
|
||||
+ // if the user specified an encoding_format, return the response as-is
|
||||
+ if (hasUserProvidedEncodingFormat || body.model.includes("jina")) {
|
||||
+ return response;
|
||||
+ }
|
||||
+ // in this stage, we are sure the user did not specify an encoding_format
|
||||
+ // and we defaulted to base64 for performance reasons
|
||||
+ // we are sure then that the response is base64 encoded, let's decode it
|
||||
+ // the returned result will be a float32 array since this is OpenAI API's default encoding
|
||||
+ (0, utils_1.loggerFor)(this._client).debug(
|
||||
+ "embeddings/decoding base64 embeddings from base64"
|
||||
+ );
|
||||
+ return response._thenUnwrap((response) => {
|
||||
+ if (response && response.data && typeof response.data[0]?.embedding === 'string') {
|
||||
+ response.data.forEach((embeddingBase64Obj) => {
|
||||
+ const embeddingBase64Str = embeddingBase64Obj.embedding;
|
||||
+ embeddingBase64Obj.embedding = (0, utils_1.toFloat32Array)(
|
||||
+ embeddingBase64Str
|
||||
+ );
|
||||
+ });
|
||||
+ }
|
||||
+ return response;
|
||||
+ });
|
||||
+ }
|
||||
}
|
||||
exports.Embeddings = Embeddings;
|
||||
//# sourceMappingURL=embeddings.js.map
|
||||
diff --git a/resources/embeddings.mjs b/resources/embeddings.mjs
|
||||
index 19dcaef578c194a89759c4360073cfd4f7dd2cbf..0284e9cc615c900eff508eb595f7360a74bd9200 100644
|
||||
--- a/resources/embeddings.mjs
|
||||
+++ b/resources/embeddings.mjs
|
||||
@@ -2,51 +2,61 @@
|
||||
import { APIResource } from "../core/resource.mjs";
|
||||
import { loggerFor, toFloat32Array } from "../internal/utils.mjs";
|
||||
export class Embeddings extends APIResource {
|
||||
- /**
|
||||
- * Creates an embedding vector representing the input text.
|
||||
- *
|
||||
- * @example
|
||||
- * ```ts
|
||||
- * const createEmbeddingResponse =
|
||||
- * await client.embeddings.create({
|
||||
- * input: 'The quick brown fox jumped over the lazy dog',
|
||||
- * model: 'text-embedding-3-small',
|
||||
- * });
|
||||
- * ```
|
||||
- */
|
||||
- create(body, options) {
|
||||
- const hasUserProvidedEncodingFormat = !!body.encoding_format;
|
||||
- // No encoding_format specified, defaulting to base64 for performance reasons
|
||||
- // See https://github.com/openai/openai-node/pull/1312
|
||||
- let encoding_format = hasUserProvidedEncodingFormat ? body.encoding_format : 'base64';
|
||||
- if (hasUserProvidedEncodingFormat) {
|
||||
- loggerFor(this._client).debug('embeddings/user defined encoding_format:', body.encoding_format);
|
||||
- }
|
||||
- const response = this._client.post('/embeddings', {
|
||||
- body: {
|
||||
- ...body,
|
||||
- encoding_format: encoding_format,
|
||||
- },
|
||||
- ...options,
|
||||
- });
|
||||
- // if the user specified an encoding_format, return the response as-is
|
||||
- if (hasUserProvidedEncodingFormat) {
|
||||
- return response;
|
||||
- }
|
||||
- // in this stage, we are sure the user did not specify an encoding_format
|
||||
- // and we defaulted to base64 for performance reasons
|
||||
- // we are sure then that the response is base64 encoded, let's decode it
|
||||
- // the returned result will be a float32 array since this is OpenAI API's default encoding
|
||||
- loggerFor(this._client).debug('embeddings/decoding base64 embeddings from base64');
|
||||
- return response._thenUnwrap((response) => {
|
||||
- if (response && response.data) {
|
||||
- response.data.forEach((embeddingBase64Obj) => {
|
||||
- const embeddingBase64Str = embeddingBase64Obj.embedding;
|
||||
- embeddingBase64Obj.embedding = toFloat32Array(embeddingBase64Str);
|
||||
- });
|
||||
- }
|
||||
- return response;
|
||||
- });
|
||||
- }
|
||||
+ /**
|
||||
+ * Creates an embedding vector representing the input text.
|
||||
+ *
|
||||
+ * @example
|
||||
+ * ```ts
|
||||
+ * const createEmbeddingResponse =
|
||||
+ * await client.embeddings.create({
|
||||
+ * input: 'The quick brown fox jumped over the lazy dog',
|
||||
+ * model: 'text-embedding-3-small',
|
||||
+ * });
|
||||
+ * ```
|
||||
+ */
|
||||
+ create(body, options) {
|
||||
+ const hasUserProvidedEncodingFormat = !!body.encoding_format;
|
||||
+ // No encoding_format specified, defaulting to base64 for performance reasons
|
||||
+ // See https://github.com/openai/openai-node/pull/1312
|
||||
+ let encoding_format = hasUserProvidedEncodingFormat
|
||||
+ ? body.encoding_format
|
||||
+ : "base64";
|
||||
+ if (body.model.includes("jina")) {
|
||||
+ encoding_format = undefined;
|
||||
+ }
|
||||
+ if (hasUserProvidedEncodingFormat) {
|
||||
+ loggerFor(this._client).debug(
|
||||
+ "embeddings/user defined encoding_format:",
|
||||
+ body.encoding_format
|
||||
+ );
|
||||
+ }
|
||||
+ const response = this._client.post("/embeddings", {
|
||||
+ body: {
|
||||
+ ...body,
|
||||
+ encoding_format: encoding_format,
|
||||
+ },
|
||||
+ ...options,
|
||||
+ });
|
||||
+ // if the user specified an encoding_format, return the response as-is
|
||||
+ if (hasUserProvidedEncodingFormat || body.model.includes("jina")) {
|
||||
+ return response;
|
||||
+ }
|
||||
+ // in this stage, we are sure the user did not specify an encoding_format
|
||||
+ // and we defaulted to base64 for performance reasons
|
||||
+ // we are sure then that the response is base64 encoded, let's decode it
|
||||
+ // the returned result will be a float32 array since this is OpenAI API's default encoding
|
||||
+ loggerFor(this._client).debug(
|
||||
+ "embeddings/decoding base64 embeddings from base64"
|
||||
+ );
|
||||
+ return response._thenUnwrap((response) => {
|
||||
+ if (response && response.data && typeof response.data[0]?.embedding === 'string') {
|
||||
+ response.data.forEach((embeddingBase64Obj) => {
|
||||
+ const embeddingBase64Str = embeddingBase64Obj.embedding;
|
||||
+ embeddingBase64Obj.embedding = toFloat32Array(embeddingBase64Str);
|
||||
+ });
|
||||
+ }
|
||||
+ return response;
|
||||
+ });
|
||||
+ }
|
||||
}
|
||||
//# sourceMappingURL=embeddings.mjs.map
|
||||
201
README.md
201
README.md
@ -1,12 +1,74 @@
|
||||
<div align="right" >
|
||||
<details>
|
||||
<summary >🌐 Language</summary>
|
||||
<div>
|
||||
<div align="right">
|
||||
<p><a href="https://openaitx.github.io/view.html?user=CherryHQ&project=cherry-studio&lang=en">English</a></p>
|
||||
<p><a href="https://openaitx.github.io/view.html?user=CherryHQ&project=cherry-studio&lang=zh-CN">简体中文</a></p>
|
||||
<p><a href="https://openaitx.github.io/view.html?user=CherryHQ&project=cherry-studio&lang=zh-TW">繁體中文</a></p>
|
||||
<p><a href="https://openaitx.github.io/view.html?user=CherryHQ&project=cherry-studio&lang=ja">日本語</a></p>
|
||||
<p><a href="https://openaitx.github.io/view.html?user=CherryHQ&project=cherry-studio&lang=ko">한국어</a></p>
|
||||
<p><a href="https://openaitx.github.io/view.html?user=CherryHQ&project=cherry-studio&lang=hi">हिन्दी</a></p>
|
||||
<p><a href="https://openaitx.github.io/view.html?user=CherryHQ&project=cherry-studio&lang=th">ไทย</a></p>
|
||||
<p><a href="https://openaitx.github.io/view.html?user=CherryHQ&project=cherry-studio&lang=fr">Français</a></p>
|
||||
<p><a href="https://openaitx.github.io/view.html?user=CherryHQ&project=cherry-studio&lang=de">Deutsch</a></p>
|
||||
<p><a href="https://openaitx.github.io/view.html?user=CherryHQ&project=cherry-studio&lang=es">Español</a></p>
|
||||
<p><a href="https://openaitx.github.io/view.html?user=CherryHQ&project=cherry-studio&lang=it">Itapano</a></p>
|
||||
<p><a href="https://openaitx.github.io/view.html?user=CherryHQ&project=cherry-studio&lang=ru">Русский</a></p>
|
||||
<p><a href="https://openaitx.github.io/view.html?user=CherryHQ&project=cherry-studio&lang=pt">Português</a></p>
|
||||
<p><a href="https://openaitx.github.io/view.html?user=CherryHQ&project=cherry-studio&lang=nl">Nederlands</a></p>
|
||||
<p><a href="https://openaitx.github.io/view.html?user=CherryHQ&project=cherry-studio&lang=pl">Polski</a></p>
|
||||
<p><a href="https://openaitx.github.io/view.html?user=CherryHQ&project=cherry-studio&lang=ar">العربية</a></p>
|
||||
<p><a href="https://openaitx.github.io/view.html?user=CherryHQ&project=cherry-studio&lang=fa">فارسی</a></p>
|
||||
<p><a href="https://openaitx.github.io/view.html?user=CherryHQ&project=cherry-studio&lang=tr">Türkçe</a></p>
|
||||
<p><a href="https://openaitx.github.io/view.html?user=CherryHQ&project=cherry-studio&lang=vi">Tiếng Việt</a></p>
|
||||
<p><a href="https://openaitx.github.io/view.html?user=CherryHQ&project=cherry-studio&lang=id">Bahasa Indonesia</a></p>
|
||||
</div>
|
||||
</div>
|
||||
</details>
|
||||
</div>
|
||||
|
||||
<h1 align="center">
|
||||
<a href="https://github.com/CherryHQ/cherry-studio/releases">
|
||||
<img src="https://github.com/CherryHQ/cherry-studio/blob/main/build/icon.png?raw=true" width="150" height="150" alt="banner" /><br>
|
||||
</a>
|
||||
</h1>
|
||||
<p align="center">English | <a href="./docs/README.zh.md">中文</a> | <a href="./docs/README.ja.md">日本語</a><br></p>
|
||||
<p align="center">English | <a href="./docs/README.zh.md">中文</a> | <a href="./docs/README.ja.md">日本語</a> | <a href="https://cherry-ai.com">Official Site</a> | <a href="https://docs.cherry-ai.com/cherry-studio-wen-dang/en-us">Documents</a> | <a href="./docs/dev.md">Development</a> | <a href="https://github.com/CherryHQ/cherry-studio/issues">Feedback</a><br></p>
|
||||
|
||||
<!-- 题头徽章组合 -->
|
||||
|
||||
<div align="center">
|
||||
|
||||
[![][deepwiki-shield]][deepwiki-link]
|
||||
[![][twitter-shield]][twitter-link]
|
||||
[![][discord-shield]][discord-link]
|
||||
[![][telegram-shield]][telegram-link]
|
||||
|
||||
</div>
|
||||
|
||||
<!-- 项目统计徽章 -->
|
||||
|
||||
<div align="center">
|
||||
|
||||
[![][github-stars-shield]][github-stars-link]
|
||||
[![][github-forks-shield]][github-forks-link]
|
||||
[![][github-release-shield]][github-release-link]
|
||||
[![][github-contributors-shield]][github-contributors-link]
|
||||
|
||||
</div>
|
||||
|
||||
<div align="center">
|
||||
|
||||
[![][license-shield]][license-link]
|
||||
[![][commercial-shield]][commercial-link]
|
||||
[![][sponsor-shield]][sponsor-link]
|
||||
|
||||
</div>
|
||||
|
||||
<div align="center">
|
||||
<a href="https://hellogithub.com/repository/1605492e1e2a4df3be07abfa4578dd37" target="_blank"><img src="https://api.hellogithub.com/v1/widgets/recommend.svg?rid=1605492e1e2a4df3be07abfa4578dd37" alt="Featured|HelloGitHub" style="width: 200px; height: 43px;" width="200" height="43" /></a>
|
||||
<a href="https://trendshift.io/repositories/11772" target="_blank"><img src="https://trendshift.io/api/badge/repositories/11772" alt="kangfenmao%2Fcherry-studio | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||
<a href="https://www.producthunt.com/posts/cherry-studio?embed=true&utm_source=badge-featured&utm_medium=badge&utm_souce=badge-cherry-studio" target="_blank"><img src="https://api.producthunt.com/widgets/embed-image/v1/featured.svg?post_id=496640&theme=light" alt="Cherry Studio - AI Chatbots, AI Desktop Client | Product Hunt" style="width: 250px; height: 54px;" width="250" height="54" /></a>
|
||||
<a href="https://www.producthunt.com/posts/cherry-studio?embed=true&utm_source=badge-featured&utm_medium=badge&utm_souce=badge-cherry-studio" target="_blank"><img src="https://api.producthunt.com/widgets/embed-image/v1/featured.svg?post_id=496640&theme=light" alt="Cherry Studio - AI Chatbots, AI Desktop Client | Product Hunt" style="width: 200px; height: 43px;" width="200" height="43" /></a>
|
||||
</div>
|
||||
|
||||
# 🍒 Cherry Studio
|
||||
@ -17,10 +79,6 @@ Cherry Studio is a desktop client that supports for multiple LLM providers, avai
|
||||
|
||||
❤️ Like Cherry Studio? Give it a star 🌟 or [Sponsor](docs/sponsor.md) to support the development!
|
||||
|
||||
# 📖 Guide
|
||||
|
||||
<https://docs.cherry-ai.com>
|
||||
|
||||
# 🌠 Screenshot
|
||||
|
||||

|
||||
@ -114,14 +172,6 @@ Want to influence our roadmap? Join our [GitHub Discussions](https://github.com/
|
||||
|
||||
Welcome PR for more themes
|
||||
|
||||
# 🖥️ Develop
|
||||
|
||||
Refer to the [development documentation](docs/dev.md)
|
||||
|
||||
Refer to the [Architecture overview documentation](https://deepwiki.com/CherryHQ/cherry-studio)
|
||||
|
||||
Refer to the [Branching Strategy](docs/branching-strategy-en.md) for contribution guidelines
|
||||
|
||||
# 🤝 Contributing
|
||||
|
||||
We welcome contributions to Cherry Studio! Here are some ways you can contribute:
|
||||
@ -134,6 +184,8 @@ We welcome contributions to Cherry Studio! Here are some ways you can contribute
|
||||
6. **Community Engagement**: Join discussions and help users.
|
||||
7. **Promote Usage**: Spread the word about Cherry Studio.
|
||||
|
||||
Refer to the [Branching Strategy](docs/branching-strategy-en.md) for contribution guidelines
|
||||
|
||||
## Getting Started
|
||||
|
||||
1. **Fork the Repository**: Fork and clone it to your local machine.
|
||||
@ -145,6 +197,78 @@ For more detailed guidelines, please refer to our [Contributing Guide](./CONTRIB
|
||||
|
||||
Thank you for your support and contributions!
|
||||
|
||||
# 🔧 Developer Co-creation Program
|
||||
|
||||
We are launching the Cherry Studio Developer Co-creation Program to foster a healthy and positive-feedback loop within the open-source ecosystem. We believe that great software is built collaboratively, and every merged pull request breathes new life into the project.
|
||||
|
||||
We sincerely invite you to join our ranks of contributors and shape the future of Cherry Studio with us.
|
||||
|
||||
## Contributor Rewards Program
|
||||
|
||||
To give back to our core contributors and create a virtuous cycle, we have established the following long-term incentive plan.
|
||||
|
||||
**The inaugural tracking period for this program will be Q3 2025 (July, August, September). Rewards for this cycle will be distributed on October 1st.**
|
||||
|
||||
Within any tracking period (e.g., July 1st to September 30th for the first cycle), any developer who contributes more than **30 meaningful commits** to any of Cherry Studio's open-source projects on GitHub is eligible for the following benefits:
|
||||
|
||||
- **Cursor Subscription Sponsorship**: Receive a **$70 USD** credit or reimbursement for your [Cursor](https://cursor.sh/) subscription, making AI your most efficient coding partner.
|
||||
- **Unlimited Model Access**: Get **unlimited** API calls for the **DeepSeek** and **Qwen** models.
|
||||
- **Cutting-Edge Tech Access**: Enjoy occasional perks, including API access to models like **Claude**, **Gemini**, and **OpenAI**, keeping you at the forefront of technology.
|
||||
|
||||
## Growing Together & Future Plans
|
||||
|
||||
A vibrant community is the driving force behind any sustainable open-source project. As Cherry Studio grows, so will our rewards program. We are committed to continuously aligning our benefits with the best-in-class tools and resources in the industry. This ensures our core contributors receive meaningful support, creating a positive cycle where developers, the community, and the project grow together.
|
||||
|
||||
**Moving forward, the project will also embrace an increasingly open stance to give back to the entire open-source community.**
|
||||
|
||||
## How to Get Started?
|
||||
|
||||
We look forward to your first Pull Request!
|
||||
|
||||
You can start by exploring our repositories, picking up a `good first issue`, or proposing your own enhancements. Every commit is a testament to the spirit of open source.
|
||||
|
||||
Thank you for your interest and contributions.
|
||||
|
||||
Let's build together.
|
||||
|
||||
# 🏢 Enterprise Edition
|
||||
|
||||
Building on the Community Edition, we are proud to introduce **Cherry Studio Enterprise Edition**—a privately deployable AI productivity and management platform designed for modern teams and enterprises.
|
||||
|
||||
The Enterprise Edition addresses core challenges in team collaboration by centralizing the management of AI resources, knowledge, and data. It empowers organizations to enhance efficiency, foster innovation, and ensure compliance, all while maintaining 100% control over their data in a secure environment.
|
||||
|
||||
## Core Advantages
|
||||
|
||||
- **Unified Model Management**: Centrally integrate and manage various cloud-based LLMs (e.g., OpenAI, Anthropic, Google Gemini) and locally deployed private models. Employees can use them out-of-the-box without individual configuration.
|
||||
- **Enterprise-Grade Knowledge Base**: Build, manage, and share team-wide knowledge bases. Ensure knowledge is retained and consistent, enabling team members to interact with AI based on unified and accurate information.
|
||||
- **Fine-Grained Access Control**: Easily manage employee accounts and assign role-based permissions for different models, knowledge bases, and features through a unified admin backend.
|
||||
- **Fully Private Deployment**: Deploy the entire backend service on your on-premises servers or private cloud, ensuring your data remains 100% private and under your control to meet the strictest security and compliance standards.
|
||||
- **Reliable Backend Services**: Provides stable API services, enterprise-grade data backup and recovery mechanisms to ensure business continuity.
|
||||
|
||||
## ✨ Online Demo
|
||||
|
||||
> 🚧 **Public Beta Notice**
|
||||
>
|
||||
> The Enterprise Edition is currently in its early public beta stage, and we are actively iterating and optimizing its features. We are aware that it may not be perfectly stable yet. If you encounter any issues or have valuable suggestions during your trial, we would be very grateful if you could contact us via email to provide feedback.
|
||||
|
||||
**🔗 [Cherry Studio Enterprise](https://www.cherry-ai.com/enterprise)**
|
||||
|
||||
## Version Comparison
|
||||
|
||||
| Feature | Community Edition | Enterprise Edition |
|
||||
| :---------------- | :----------------------------------------- | :-------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| **Open Source** | ✅ Yes | ⭕️ part. released to cust. |
|
||||
| **Cost** | Free for Personal Use / Commercial License | Buyout / Subscription Fee |
|
||||
| **Admin Backend** | — | ● Centralized **Model** Access<br>● **Employee** Management<br>● Shared **Knowledge Base**<br>● **Access** Control<br>● **Data** Backup |
|
||||
| **Server** | — | ✅ Dedicated Private Deployment |
|
||||
|
||||
## Get the Enterprise Edition
|
||||
|
||||
We believe the Enterprise Edition will become your team's AI productivity engine. If you are interested in Cherry Studio Enterprise Edition and would like to learn more, request a quote, or schedule a demo, please contact us.
|
||||
|
||||
- **For Business Inquiries & Purchasing**:
|
||||
**📧 [bd@cherry-ai.com](mailto:bd@cherry-ai.com)**
|
||||
|
||||
# 🔗 Related Projects
|
||||
|
||||
- [one-api](https://github.com/songquanpeng/one-api):LLM API management and distribution system, supporting mainstream models like OpenAI, Azure, and Anthropic. Features unified API interface, suitable for key management and secondary distribution.
|
||||
@ -158,22 +282,37 @@ Thank you for your support and contributions!
|
||||
</a>
|
||||
<br /><br />
|
||||
|
||||
# 🌐 Community
|
||||
|
||||
[Telegram](https://t.me/CherryStudioAI) | [Email](mailto:support@cherry-ai.com) | [Twitter](https://x.com/kangfenmao)
|
||||
|
||||
# ☕ Sponsor
|
||||
|
||||
[Buy Me a Coffee](docs/sponsor.md)
|
||||
|
||||
# 📃 License
|
||||
|
||||
[LICENSE](./LICENSE)
|
||||
|
||||
# ✉️ Contact
|
||||
|
||||
<yinsenho@cherry-ai.com>
|
||||
|
||||
# ⭐️ Star History
|
||||
|
||||
[](https://star-history.com/#kangfenmao/cherry-studio&Timeline)
|
||||
[](https://star-history.com/#CherryHQ/cherry-studio&Timeline)
|
||||
|
||||
<!-- Links & Images -->
|
||||
|
||||
[deepwiki-shield]: https://img.shields.io/badge/Deepwiki-CherryHQ-0088CC?style=plastic
|
||||
[deepwiki-link]: https://deepwiki.com/CherryHQ/cherry-studio
|
||||
[twitter-shield]: https://img.shields.io/badge/Twitter-CherryStudioApp-0088CC?style=plastic&logo=x
|
||||
[twitter-link]: https://twitter.com/CherryStudioHQ
|
||||
[discord-shield]: https://img.shields.io/badge/Discord-@CherryStudio-0088CC?style=plastic&logo=discord
|
||||
[discord-link]: https://discord.gg/wez8HtpxqQ
|
||||
[telegram-shield]: https://img.shields.io/badge/Telegram-@CherryStudioAI-0088CC?style=plastic&logo=telegram
|
||||
[telegram-link]: https://t.me/CherryStudioAI
|
||||
|
||||
<!-- Links & Images -->
|
||||
|
||||
[github-stars-shield]: https://img.shields.io/github/stars/CherryHQ/cherry-studio?style=social
|
||||
[github-stars-link]: https://github.com/CherryHQ/cherry-studio/stargazers
|
||||
[github-forks-shield]: https://img.shields.io/github/forks/CherryHQ/cherry-studio?style=social
|
||||
[github-forks-link]: https://github.com/CherryHQ/cherry-studio/network
|
||||
[github-release-shield]: https://img.shields.io/github/v/release/CherryHQ/cherry-studio
|
||||
[github-release-link]: https://github.com/CherryHQ/cherry-studio/releases
|
||||
[github-contributors-shield]: https://img.shields.io/github/contributors/CherryHQ/cherry-studio
|
||||
[github-contributors-link]: https://github.com/CherryHQ/cherry-studio/graphs/contributors
|
||||
|
||||
<!-- Links & Images -->
|
||||
|
||||
[license-shield]: https://img.shields.io/badge/License-AGPLv3-important.svg?style=plastic&logo=gnu
|
||||
[license-link]: https://www.gnu.org/licenses/agpl-3.0
|
||||
[commercial-shield]: https://img.shields.io/badge/License-Contact-white.svg?style=plastic&logoColor=white&logo=telegram&color=blue
|
||||
[commercial-link]: mailto:license@cherry-ai.com?subject=Commercial%20License%20Inquiry
|
||||
[sponsor-shield]: https://img.shields.io/badge/Sponsor-FF6699.svg?style=plastic&logo=githubsponsors&logoColor=white
|
||||
[sponsor-link]: https://github.com/CherryHQ/cherry-studio/blob/main/docs/sponsor.md
|
||||
|
||||
@ -1,15 +1,46 @@
|
||||
<h1 align="center">
|
||||
<a href="https://github.com/CherryHQ/cherry-studio/releases">
|
||||
<img src="https://github.com/CherryHQ/cherry-studio/blob/main/build/icon.png?raw=true" width="150" height="150" alt="banner" />
|
||||
<img src="https://github.com/CherryHQ/cherry-studio/blob/main/build/icon.png?raw=true" width="150" height="150" alt="banner" /><br>
|
||||
</a>
|
||||
</h1>
|
||||
<p align="center">
|
||||
<a href="https://github.com/CherryHQ/cherry-studio">English</a> | <a href="./README.zh.md">中文</a> | 日本語 <br>
|
||||
<a href="https://github.com/CherryHQ/cherry-studio">English</a> | <a href="./README.zh.md">中文</a> | 日本語 | <a href="https://cherry-ai.com">公式サイト</a> | <a href="https://docs.cherry-ai.com/cherry-studio-wen-dang/ja">ドキュメント</a> | <a href="./dev.md">開発</a> | <a href="https://github.com/CherryHQ/cherry-studio/issues">フィードバック</a><br>
|
||||
</p>
|
||||
|
||||
<!-- バッジコレクション -->
|
||||
|
||||
<div align="center">
|
||||
|
||||
[![][deepwiki-shield]][deepwiki-link]
|
||||
[![][twitter-shield]][twitter-link]
|
||||
[![][discord-shield]][discord-link]
|
||||
[![][telegram-shield]][telegram-link]
|
||||
|
||||
</div>
|
||||
|
||||
<!-- プロジェクト統計 -->
|
||||
|
||||
<div align="center">
|
||||
|
||||
[![][github-stars-shield]][github-stars-link]
|
||||
[![][github-forks-shield]][github-forks-link]
|
||||
[![][github-release-shield]][github-release-link]
|
||||
[![][github-contributors-shield]][github-contributors-link]
|
||||
|
||||
</div>
|
||||
|
||||
<div align="center">
|
||||
|
||||
[![][license-shield]][license-link]
|
||||
[![][commercial-shield]][commercial-link]
|
||||
[![][sponsor-shield]][sponsor-link]
|
||||
|
||||
</div>
|
||||
|
||||
<div align="center">
|
||||
<a href="https://hellogithub.com/repository/1605492e1e2a4df3be07abfa4578dd37" target="_blank"><img src="https://api.hellogithub.com/v1/widgets/recommend.svg?rid=1605492e1e2a4df3be07abfa4578dd37" alt="Featured|HelloGitHub" style="width: 200px; height: 43px;" width="200" height="43" /></a>
|
||||
<a href="https://trendshift.io/repositories/11772" target="_blank"><img src="https://trendshift.io/api/badge/repositories/11772" alt="kangfenmao%2Fcherry-studio | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||
<a href="https://www.producthunt.com/posts/cherry-studio?embed=true&utm_source=badge-featured&utm_medium=badge&utm_souce=badge-cherry-studio" target="_blank"><img src="https://api.producthunt.com/widgets/embed-image/v1/featured.svg?post_id=496640&theme=light" alt="Cherry Studio - AI Chatbots, AI Desktop Client | Product Hunt" style="width: 250px; height: 54px;" width="250" height="54" /></a>
|
||||
<a href="https://www.producthunt.com/posts/cherry-studio?embed=true&utm_source=badge-featured&utm_medium=badge&utm_souce=badge-cherry-studio" target="_blank"><img src="https://api.producthunt.com/widgets/embed-image/v1/featured.svg?post_id=496640&theme=light" alt="Cherry Studio - AI Chatbots, AI Desktop Client | Product Hunt" style="width: 200px; height: 43px;" width="200" height="43" /></a>
|
||||
</div>
|
||||
|
||||
# 🍒 Cherry Studio
|
||||
@ -20,10 +51,6 @@ Cherry Studio は、複数の LLM プロバイダーをサポートするデス
|
||||
|
||||
❤️ Cherry Studio をお気に入りにしましたか?小さな星をつけてください 🌟 または [スポンサー](sponsor.md) をして開発をサポートしてください!
|
||||
|
||||
# 📖 ガイド
|
||||
|
||||
https://docs.cherry-ai.com
|
||||
|
||||
# 🌠 スクリーンショット
|
||||
|
||||

|
||||
@ -117,14 +144,6 @@ https://docs.cherry-ai.com
|
||||
|
||||
より多くのテーマの PR を歓迎します
|
||||
|
||||
# 🖥️ 開発
|
||||
|
||||
[開発ドキュメント](dev.md)を参照してください
|
||||
|
||||
[アーキテクチャ概要ドキュメント](https://deepwiki.com/CherryHQ/cherry-studio)を参照してください
|
||||
|
||||
[ブランチ戦略](branching-strategy-en.md)を参照して貢献ガイドラインを確認してください
|
||||
|
||||
# 🤝 貢献
|
||||
|
||||
Cherry Studio への貢献を歓迎します!以下の方法で貢献できます:
|
||||
@ -137,6 +156,8 @@ Cherry Studio への貢献を歓迎します!以下の方法で貢献できま
|
||||
6. **コミュニティの参加**:ディスカッションに参加し、ユーザーを支援します
|
||||
7. **使用の促進**:Cherry Studio を広めます
|
||||
|
||||
[ブランチ戦略](branching-strategy-en.md)を参照して貢献ガイドラインを確認してください
|
||||
|
||||
## 始め方
|
||||
|
||||
1. **リポジトリをフォーク**:フォークしてローカルマシンにクローンします
|
||||
@ -161,22 +182,34 @@ Cherry Studio への貢献を歓迎します!以下の方法で貢献できま
|
||||
</a>
|
||||
<br /><br />
|
||||
|
||||
# 🌐 コミュニティ
|
||||
|
||||
[Telegram](https://t.me/CherryStudioAI) | [Email](mailto:support@cherry-ai.com) | [Twitter](https://x.com/kangfenmao)
|
||||
|
||||
# ☕ スポンサー
|
||||
|
||||
[開発者を支援する](sponsor.md)
|
||||
|
||||
# 📃 ライセンス
|
||||
|
||||
[LICENSE](../LICENSE)
|
||||
|
||||
# ✉️ お問い合わせ
|
||||
|
||||
yinsenho@cherry-ai.com
|
||||
|
||||
# ⭐️ スター履歴
|
||||
|
||||
[](https://star-history.com/#kangfenmao/cherry-studio&Timeline)
|
||||
[](https://star-history.com/#CherryHQ/cherry-studio&Timeline)
|
||||
|
||||
<!-- リンクと画像 -->
|
||||
[deepwiki-shield]: https://img.shields.io/badge/Deepwiki-CherryHQ-0088CC?style=plastic
|
||||
[deepwiki-link]: https://deepwiki.com/CherryHQ/cherry-studio
|
||||
[twitter-shield]: https://img.shields.io/badge/Twitter-CherryStudioApp-0088CC?style=plastic&logo=x
|
||||
[twitter-link]: https://twitter.com/CherryStudioHQ
|
||||
[discord-shield]: https://img.shields.io/badge/Discord-@CherryStudio-0088CC?style=plastic&logo=discord
|
||||
[discord-link]: https://discord.gg/wez8HtpxqQ
|
||||
[telegram-shield]: https://img.shields.io/badge/Telegram-@CherryStudioAI-0088CC?style=plastic&logo=telegram
|
||||
[telegram-link]: https://t.me/CherryStudioAI
|
||||
|
||||
<!-- プロジェクト統計 -->
|
||||
[github-stars-shield]: https://img.shields.io/github/stars/CherryHQ/cherry-studio?style=social
|
||||
[github-stars-link]: https://github.com/CherryHQ/cherry-studio/stargazers
|
||||
[github-forks-shield]: https://img.shields.io/github/forks/CherryHQ/cherry-studio?style=social
|
||||
[github-forks-link]: https://github.com/CherryHQ/cherry-studio/network
|
||||
[github-release-shield]: https://img.shields.io/github/v/release/CherryHQ/cherry-studio
|
||||
[github-release-link]: https://github.com/CherryHQ/cherry-studio/releases
|
||||
[github-contributors-shield]: https://img.shields.io/github/contributors/CherryHQ/cherry-studio
|
||||
[github-contributors-link]: https://github.com/CherryHQ/cherry-studio/graphs/contributors
|
||||
|
||||
<!-- ライセンスとスポンサー -->
|
||||
[license-shield]: https://img.shields.io/badge/License-AGPLv3-important.svg?style=plastic&logo=gnu
|
||||
[license-link]: https://www.gnu.org/licenses/agpl-3.0
|
||||
[commercial-shield]: https://img.shields.io/badge/商用ライセンス-お問い合わせ-white.svg?style=plastic&logoColor=white&logo=telegram&color=blue
|
||||
[commercial-link]: mailto:license@cherry-ai.com?subject=商業ライセンスについて
|
||||
[sponsor-shield]: https://img.shields.io/badge/スポンサー-FF6699.svg?style=plastic&logo=githubsponsors&logoColor=white
|
||||
[sponsor-link]: https://github.com/CherryHQ/cherry-studio/blob/main/docs/sponsor.md
|
||||
|
||||
@ -1,14 +1,46 @@
|
||||
<h1 align="center">
|
||||
<a href="https://github.com/CherryHQ/cherry-studio/releases">
|
||||
<img src="https://github.com/CherryHQ/cherry-studio/blob/main/build/icon.png?raw=true" width="150" height="150" alt="banner" />
|
||||
<img src="https://github.com/CherryHQ/cherry-studio/blob/main/build/icon.png?raw=true" width="150" height="150" alt="banner" /><br>
|
||||
</a>
|
||||
</h1>
|
||||
<p align="center">
|
||||
<a href="https://github.com/CherryHQ/cherry-studio">English</a> | 中文 | <a href="./README.ja.md">日本語</a><br>
|
||||
<a href="https://github.com/CherryHQ/cherry-studio">English</a> | 中文 | <a href="./README.ja.md">日本語</a> | <a href="https://cherry-ai.com">官方网站</a> | <a href="https://docs.cherry-ai.com/cherry-studio-wen-dang/zh-cn">文档</a> | <a href="./dev.md">开发</a> | <a href="https://github.com/CherryHQ/cherry-studio/issues">反馈</a><br>
|
||||
</p>
|
||||
|
||||
<!-- 题头徽章组合 -->
|
||||
|
||||
<div align="center">
|
||||
|
||||
[![][deepwiki-shield]][deepwiki-link]
|
||||
[![][twitter-shield]][twitter-link]
|
||||
[![][discord-shield]][discord-link]
|
||||
[![][telegram-shield]][telegram-link]
|
||||
|
||||
</div>
|
||||
|
||||
<!-- 项目统计徽章 -->
|
||||
|
||||
<div align="center">
|
||||
|
||||
[![][github-stars-shield]][github-stars-link]
|
||||
[![][github-forks-shield]][github-forks-link]
|
||||
[![][github-release-shield]][github-release-link]
|
||||
[![][github-contributors-shield]][github-contributors-link]
|
||||
|
||||
</div>
|
||||
|
||||
<div align="center">
|
||||
|
||||
[![][license-shield]][license-link]
|
||||
[![][commercial-shield]][commercial-link]
|
||||
[![][sponsor-shield]][sponsor-link]
|
||||
|
||||
</div>
|
||||
|
||||
<div align="center">
|
||||
<a href="https://hellogithub.com/repository/1605492e1e2a4df3be07abfa4578dd37" target="_blank"><img src="https://api.hellogithub.com/v1/widgets/recommend.svg?rid=1605492e1e2a4df3be07abfa4578dd37" alt="Featured|HelloGitHub" style="width: 200px; height: 43px;" width="200" height="43" /></a>
|
||||
<a href="https://trendshift.io/repositories/11772" target="_blank"><img src="https://trendshift.io/api/badge/repositories/11772" alt="kangfenmao%2Fcherry-studio | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||
<a href="https://www.producthunt.com/posts/cherry-studio?embed=true&utm_source=badge-featured&utm_medium=badge&utm_souce=badge-cherry-studio" target="_blank"><img src="https://api.producthunt.com/widgets/embed-image/v1/featured.svg?post_id=496640&theme=light" alt="Cherry Studio - AI Chatbots, AI Desktop Client | Product Hunt" style="width: 250px; height: 54px;" width="250" height="54" /></a>
|
||||
<a href="https://www.producthunt.com/posts/cherry-studio?embed=true&utm_source=badge-featured&utm_medium=badge&utm_souce=badge-cherry-studio" target="_blank"><img src="https://api.producthunt.com/widgets/embed-image/v1/featured.svg?post_id=496640&theme=light" alt="Cherry Studio - AI Chatbots, AI Desktop Client | Product Hunt" style="width: 200px; height: 43px;" width="200" height="43" /></a>
|
||||
</div>
|
||||
|
||||
# 🍒 Cherry Studio
|
||||
@ -124,14 +156,6 @@ https://docs.cherry-ai.com
|
||||
|
||||
欢迎 PR 更多主题
|
||||
|
||||
# 🖥️ 开发
|
||||
|
||||
参考[开发文档](dev.md)
|
||||
|
||||
参考[架构概览文档](https://deepwiki.com/CherryHQ/cherry-studio)
|
||||
|
||||
参考[分支策略](branching-strategy-zh.md)了解贡献指南
|
||||
|
||||
# 🤝 贡献
|
||||
|
||||
我们欢迎对 Cherry Studio 的贡献!您可以通过以下方式贡献:
|
||||
@ -144,6 +168,8 @@ https://docs.cherry-ai.com
|
||||
6. **社区参与**:加入讨论并帮助用户
|
||||
7. **推广使用**:宣传 Cherry Studio
|
||||
|
||||
参考[分支策略](branching-strategy-zh.md)了解贡献指南
|
||||
|
||||
## 入门
|
||||
|
||||
1. **Fork 仓库**:Fork 并克隆到您的本地机器
|
||||
@ -168,22 +194,34 @@ https://docs.cherry-ai.com
|
||||
</a>
|
||||
<br /><br />
|
||||
|
||||
# 🌐 社区
|
||||
|
||||
[Telegram](https://t.me/CherryStudioAI) | [Email](mailto:support@cherry-ai.com) | [Twitter](https://x.com/kangfenmao)
|
||||
|
||||
# ☕ 赞助
|
||||
|
||||
[赞助开发者](sponsor.md)
|
||||
|
||||
# 📃 许可证
|
||||
|
||||
[LICENSE](../LICENSE)
|
||||
|
||||
# ✉️ 联系我们
|
||||
|
||||
yinsenho@cherry-ai.com
|
||||
|
||||
# ⭐️ Star 记录
|
||||
|
||||
[](https://star-history.com/#kangfenmao/cherry-studio&Timeline)
|
||||
[](https://star-history.com/#CherryHQ/cherry-studio&Timeline)
|
||||
|
||||
<!-- Links & Images -->
|
||||
[deepwiki-shield]: https://img.shields.io/badge/Deepwiki-CherryHQ-0088CC?style=plastic
|
||||
[deepwiki-link]: https://deepwiki.com/CherryHQ/cherry-studio
|
||||
[twitter-shield]: https://img.shields.io/badge/Twitter-CherryStudioApp-0088CC?style=plastic&logo=x
|
||||
[twitter-link]: https://twitter.com/CherryStudioHQ
|
||||
[discord-shield]: https://img.shields.io/badge/Discord-@CherryStudio-0088CC?style=plastic&logo=discord
|
||||
[discord-link]: https://discord.gg/wez8HtpxqQ
|
||||
[telegram-shield]: https://img.shields.io/badge/Telegram-@CherryStudioAI-0088CC?style=plastic&logo=telegram
|
||||
[telegram-link]: https://t.me/CherryStudioAI
|
||||
|
||||
<!-- 项目统计徽章 -->
|
||||
[github-stars-shield]: https://img.shields.io/github/stars/CherryHQ/cherry-studio?style=social
|
||||
[github-stars-link]: https://github.com/CherryHQ/cherry-studio/stargazers
|
||||
[github-forks-shield]: https://img.shields.io/github/forks/CherryHQ/cherry-studio?style=social
|
||||
[github-forks-link]: https://github.com/CherryHQ/cherry-studio/network
|
||||
[github-release-shield]: https://img.shields.io/github/v/release/CherryHQ/cherry-studio
|
||||
[github-release-link]: https://github.com/CherryHQ/cherry-studio/releases
|
||||
[github-contributors-shield]: https://img.shields.io/github/contributors/CherryHQ/cherry-studio
|
||||
[github-contributors-link]: https://github.com/CherryHQ/cherry-studio/graphs/contributors
|
||||
|
||||
<!-- 许可和赞助徽章 -->
|
||||
[license-shield]: https://img.shields.io/badge/License-AGPLv3-important.svg?style=plastic&logo=gnu
|
||||
[license-link]: https://www.gnu.org/licenses/agpl-3.0
|
||||
[commercial-shield]: https://img.shields.io/badge/商用授权-联系-white.svg?style=plastic&logoColor=white&logo=telegram&color=blue
|
||||
[commercial-link]: mailto:license@cherry-ai.com?subject=商业授权咨询
|
||||
[sponsor-shield]: https://img.shields.io/badge/赞助支持-FF6699.svg?style=plastic&logo=githubsponsors&logoColor=white
|
||||
[sponsor-link]: https://github.com/CherryHQ/cherry-studio/blob/main/docs/sponsor.md
|
||||
|
||||
214
docs/technical/how-to-write-middlewares.md
Normal file
214
docs/technical/how-to-write-middlewares.md
Normal file
@ -0,0 +1,214 @@
|
||||
# 如何为 AI Provider 编写中间件
|
||||
|
||||
本文档旨在指导开发者如何为我们的 AI Provider 框架创建和集成自定义中间件。中间件提供了一种强大而灵活的方式来增强、修改或观察 Provider 方法的调用过程,例如日志记录、缓存、请求/响应转换、错误处理等。
|
||||
|
||||
## 架构概览
|
||||
|
||||
我们的中间件架构借鉴了 Redux 的三段式设计,并结合了 JavaScript Proxy 来动态地将中间件应用于 Provider 的方法。
|
||||
|
||||
- **Proxy**: 拦截对 Provider 方法的调用,并将调用引导至中间件链。
|
||||
- **中间件链**: 一系列按顺序执行的中间件函数。每个中间件都可以处理请求/响应,然后将控制权传递给链中的下一个中间件,或者在某些情况下提前终止链。
|
||||
- **上下文 (Context)**: 一个在中间件之间传递的对象,携带了关于当前调用的信息(如方法名、原始参数、Provider 实例、以及中间件自定义的数据)。
|
||||
|
||||
## 中间件的类型
|
||||
|
||||
目前主要支持两种类型的中间件,它们共享相似的结构但针对不同的场景:
|
||||
|
||||
1. **`CompletionsMiddleware`**: 专门为 `completions` 方法设计。这是最常用的中间件类型,因为它允许对 AI 模型的核心聊天/文本生成功能进行精细控制。
|
||||
2. **`ProviderMethodMiddleware`**: 通用中间件,可以应用于 Provider 上的任何其他方法(例如,`translate`, `summarize` 等,如果这些方法也通过中间件系统包装)。
|
||||
|
||||
## 编写一个 `CompletionsMiddleware`
|
||||
|
||||
`CompletionsMiddleware` 的基本签名(TypeScript 类型)如下:
|
||||
|
||||
```typescript
|
||||
import { AiProviderMiddlewareCompletionsContext, CompletionsParams, MiddlewareAPI } from './AiProviderMiddlewareTypes' // 假设类型定义文件路径
|
||||
|
||||
export type CompletionsMiddleware = (
|
||||
api: MiddlewareAPI<AiProviderMiddlewareCompletionsContext, [CompletionsParams]>
|
||||
) => (
|
||||
next: (context: AiProviderMiddlewareCompletionsContext, params: CompletionsParams) => Promise<any> // next 返回 Promise<any> 代表原始SDK响应或下游中间件的结果
|
||||
) => (context: AiProviderMiddlewareCompletionsContext, params: CompletionsParams) => Promise<void> // 最内层函数通常返回 Promise<void>,因为结果通过 onChunk 或 context 副作用传递
|
||||
```
|
||||
|
||||
让我们分解这个三段式结构:
|
||||
|
||||
1. **第一层函数 `(api) => { ... }`**:
|
||||
|
||||
- 接收一个 `api` 对象。
|
||||
- `api` 对象提供了以下方法:
|
||||
- `api.getContext()`: 获取当前调用的上下文对象 (`AiProviderMiddlewareCompletionsContext`)。
|
||||
- `api.getOriginalArgs()`: 获取传递给 `completions` 方法的原始参数数组 (即 `[CompletionsParams]`)。
|
||||
- `api.getProviderId()`: 获取当前 Provider 的 ID。
|
||||
- `api.getProviderInstance()`: 获取原始的 Provider 实例。
|
||||
- 此函数通常用于进行一次性的设置或获取所需的服务/配置。它返回第二层函数。
|
||||
|
||||
2. **第二层函数 `(next) => { ... }`**:
|
||||
|
||||
- 接收一个 `next` 函数。
|
||||
- `next` 函数代表了中间件链中的下一个环节。调用 `next(context, params)` 会将控制权传递给下一个中间件,或者如果当前中间件是链中的最后一个,则会调用核心的 Provider 方法逻辑 (例如,实际的 SDK 调用)。
|
||||
- `next` 函数接收当前的 `context` 和 `params` (这些可能已被上游中间件修改)。
|
||||
- **重要的是**:`next` 的返回类型通常是 `Promise<any>`。对于 `completions` 方法,如果 `next` 调用了实际的 SDK,它将返回原始的 SDK 响应(例如,OpenAI 的流对象或 JSON 对象)。你需要处理这个响应。
|
||||
- 此函数返回第三层(也是最核心的)函数。
|
||||
|
||||
3. **第三层函数 `(context, params) => { ... }`**:
|
||||
- 这是执行中间件主要逻辑的地方。
|
||||
- 它接收当前的 `context` (`AiProviderMiddlewareCompletionsContext`) 和 `params` (`CompletionsParams`)。
|
||||
- 在此函数中,你可以:
|
||||
- **在调用 `next` 之前**:
|
||||
- 读取或修改 `params`。例如,添加默认参数、转换消息格式。
|
||||
- 读取或修改 `context`。例如,设置一个时间戳用于后续计算延迟。
|
||||
- 执行某些检查,如果不满足条件,可以不调用 `next` 而直接返回或抛出错误(例如,参数校验失败)。
|
||||
- **调用 `await next(context, params)`**:
|
||||
- 这是将控制权传递给下游的关键步骤。
|
||||
- `next` 的返回值是原始的 SDK 响应或下游中间件的结果,你需要根据情况处理它(例如,如果是流,则开始消费流)。
|
||||
- **在调用 `next` 之后**:
|
||||
- 处理 `next` 的返回结果。例如,如果 `next` 返回了一个流,你可以在这里开始迭代处理这个流,并通过 `context.onChunk` 发送数据块。
|
||||
- 基于 `context` 的变化或 `next` 的结果执行进一步操作。例如,计算总耗时、记录日志。
|
||||
- 修改最终结果(尽管对于 `completions`,结果通常通过 `onChunk` 副作用发出)。
|
||||
|
||||
### 示例:一个简单的日志中间件
|
||||
|
||||
```typescript
|
||||
import {
|
||||
AiProviderMiddlewareCompletionsContext,
|
||||
CompletionsParams,
|
||||
MiddlewareAPI,
|
||||
OnChunkFunction // 假设 OnChunkFunction 类型被导出
|
||||
} from './AiProviderMiddlewareTypes' // 调整路径
|
||||
import { ChunkType } from '@renderer/types' // 调整路径
|
||||
|
||||
export const createSimpleLoggingMiddleware = (): CompletionsMiddleware => {
|
||||
return (api: MiddlewareAPI<AiProviderMiddlewareCompletionsContext, [CompletionsParams]>) => {
|
||||
// console.log(`[LoggingMiddleware] Initialized for provider: ${api.getProviderId()}`);
|
||||
|
||||
return (next: (context: AiProviderMiddlewareCompletionsContext, params: CompletionsParams) => Promise<any>) => {
|
||||
return async (context: AiProviderMiddlewareCompletionsContext, params: CompletionsParams): Promise<void> => {
|
||||
const startTime = Date.now()
|
||||
// 从 context 中获取 onChunk (它最初来自 params.onChunk)
|
||||
const onChunk = context.onChunk
|
||||
|
||||
console.log(
|
||||
`[LoggingMiddleware] Request for ${context.methodName} with params:`,
|
||||
params.messages?.[params.messages.length - 1]?.content
|
||||
)
|
||||
|
||||
try {
|
||||
// 调用下一个中间件或核心逻辑
|
||||
// `rawSdkResponse` 是来自下游的原始响应 (例如 OpenAIStream 或 ChatCompletion 对象)
|
||||
const rawSdkResponse = await next(context, params)
|
||||
|
||||
// 此处简单示例不处理 rawSdkResponse,假设下游中间件 (如 StreamingResponseHandler)
|
||||
// 会处理它并通过 onChunk 发送数据。
|
||||
// 如果这个日志中间件在 StreamingResponseHandler 之后,那么流已经被处理。
|
||||
// 如果在之前,那么它需要自己处理 rawSdkResponse 或确保下游会处理。
|
||||
|
||||
const duration = Date.now() - startTime
|
||||
console.log(`[LoggingMiddleware] Request for ${context.methodName} completed in ${duration}ms.`)
|
||||
|
||||
// 假设下游已经通过 onChunk 发送了所有数据。
|
||||
// 如果这个中间件是链的末端,并且需要确保 BLOCK_COMPLETE 被发送,
|
||||
// 它可能需要更复杂的逻辑来跟踪何时所有数据都已发送。
|
||||
} catch (error) {
|
||||
const duration = Date.now() - startTime
|
||||
console.error(`[LoggingMiddleware] Request for ${context.methodName} failed after ${duration}ms:`, error)
|
||||
|
||||
// 如果 onChunk 可用,可以尝试发送一个错误块
|
||||
if (onChunk) {
|
||||
onChunk({
|
||||
type: ChunkType.ERROR,
|
||||
error: { message: (error as Error).message, name: (error as Error).name, stack: (error as Error).stack }
|
||||
})
|
||||
// 考虑是否还需要发送 BLOCK_COMPLETE 来结束流
|
||||
onChunk({ type: ChunkType.BLOCK_COMPLETE, response: {} })
|
||||
}
|
||||
throw error // 重新抛出错误,以便上层或全局错误处理器可以捕获
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### `AiProviderMiddlewareCompletionsContext` 的重要性
|
||||
|
||||
`AiProviderMiddlewareCompletionsContext` 是在中间件之间传递状态和数据的核心。它通常包含:
|
||||
|
||||
- `methodName`: 当前调用的方法名 (总是 `'completions'`)。
|
||||
- `originalArgs`: 传递给 `completions` 的原始参数数组。
|
||||
- `providerId`: Provider 的 ID。
|
||||
- `_providerInstance`: Provider 实例。
|
||||
- `onChunk`: 从原始 `CompletionsParams` 传入的回调函数,用于流式发送数据块。**所有中间件都应该通过 `context.onChunk` 来发送数据。**
|
||||
- `messages`, `model`, `assistant`, `mcpTools`: 从原始 `CompletionsParams` 中提取的常用字段,方便访问。
|
||||
- **自定义字段**: 中间件可以向上下文中添加自定义字段,以供后续中间件使用。例如,一个缓存中间件可能会添加 `context.cacheHit = true`。
|
||||
|
||||
**关键**: 当你在中间件中修改 `params` 或 `context` 时,这些修改会向下游中间件传播(如果它们在 `next` 调用之前修改)。
|
||||
|
||||
### 中间件的顺序
|
||||
|
||||
中间件的执行顺序非常重要。它们在 `AiProviderMiddlewareConfig` 的数组中定义的顺序就是它们的执行顺序。
|
||||
|
||||
- 请求首先通过第一个中间件,然后是第二个,依此类推。
|
||||
- 响应(或 `next` 的调用结果)则以相反的顺序"冒泡"回来。
|
||||
|
||||
例如,如果链是 `[AuthMiddleware, CacheMiddleware, LoggingMiddleware]`:
|
||||
|
||||
1. `AuthMiddleware` 先执行其 "调用 `next` 之前" 的逻辑。
|
||||
2. 然后 `CacheMiddleware` 执行其 "调用 `next` 之前" 的逻辑。
|
||||
3. 然后 `LoggingMiddleware` 执行其 "调用 `next` 之前" 的逻辑。
|
||||
4. 核心SDK调用(或链的末端)。
|
||||
5. `LoggingMiddleware` 先接收到结果,执行其 "调用 `next` 之后" 的逻辑。
|
||||
6. 然后 `CacheMiddleware` 接收到结果(可能已被 LoggingMiddleware 修改的上下文),执行其 "调用 `next` 之后" 的逻辑(例如,存储结果)。
|
||||
7. 最后 `AuthMiddleware` 接收到结果,执行其 "调用 `next` 之后" 的逻辑。
|
||||
|
||||
### 注册中间件
|
||||
|
||||
中间件在 `src/renderer/src/providers/middleware/register.ts` (或其他类似的配置文件) 中进行注册。
|
||||
|
||||
```typescript
|
||||
// register.ts
|
||||
import { AiProviderMiddlewareConfig } from './AiProviderMiddlewareTypes'
|
||||
import { createSimpleLoggingMiddleware } from './common/SimpleLoggingMiddleware' // 假设你创建了这个文件
|
||||
import { createCompletionsLoggingMiddleware } from './common/CompletionsLoggingMiddleware' // 已有的
|
||||
|
||||
const middlewareConfig: AiProviderMiddlewareConfig = {
|
||||
completions: [
|
||||
createSimpleLoggingMiddleware(), // 你新加的中间件
|
||||
createCompletionsLoggingMiddleware() // 已有的日志中间件
|
||||
// ... 其他 completions 中间件
|
||||
],
|
||||
methods: {
|
||||
// translate: [createGenericLoggingMiddleware()],
|
||||
// ... 其他方法的中间件
|
||||
}
|
||||
}
|
||||
|
||||
export default middlewareConfig
|
||||
```
|
||||
|
||||
### 最佳实践
|
||||
|
||||
1. **单一职责**: 每个中间件应专注于一个特定的功能(例如,日志、缓存、转换特定数据)。
|
||||
2. **无副作用 (尽可能)**: 除了通过 `context` 或 `onChunk` 明确的副作用外,尽量避免修改全局状态或产生其他隐蔽的副作用。
|
||||
3. **错误处理**:
|
||||
- 在中间件内部使用 `try...catch` 来处理可能发生的错误。
|
||||
- 决定是自行处理错误(例如,通过 `onChunk` 发送错误块)还是将错误重新抛出给上游。
|
||||
- 如果重新抛出,确保错误对象包含足够的信息。
|
||||
4. **性能考虑**: 中间件会增加请求处理的开销。避免在中间件中执行非常耗时的同步操作。对于IO密集型操作,确保它们是异步的。
|
||||
5. **可配置性**: 使中间件的行为可通过参数或配置进行调整。例如,日志中间件可以接受一个日志级别参数。
|
||||
6. **上下文管理**:
|
||||
- 谨慎地向 `context` 添加数据。避免污染 `context` 或添加过大的对象。
|
||||
- 明确你添加到 `context` 的字段的用途和生命周期。
|
||||
7. **`next` 的调用**:
|
||||
- 除非你有充分的理由提前终止请求(例如,缓存命中、授权失败),否则**总是确保调用 `await next(context, params)`**。否则,下游的中间件和核心逻辑将不会执行。
|
||||
- 理解 `next` 的返回值并正确处理它,特别是当它是一个流时。你需要负责消费这个流或将其传递给另一个能够消费它的组件/中间件。
|
||||
8. **命名清晰**: 给你的中间件和它们创建的函数起描述性的名字。
|
||||
9. **文档和注释**: 对复杂的中间件逻辑添加注释,解释其工作原理和目的。
|
||||
|
||||
### 调试技巧
|
||||
|
||||
- 在中间件的关键点使用 `console.log` 或调试器来检查 `params`、`context` 的状态以及 `next` 的返回值。
|
||||
- 暂时简化中间件链,只保留你正在调试的中间件和最简单的核心逻辑,以隔离问题。
|
||||
- 编写单元测试来独立验证每个中间件的行为。
|
||||
|
||||
通过遵循这些指南,你应该能够有效地为我们的系统创建强大且可维护的中间件。如果你有任何疑问或需要进一步的帮助,请咨询团队。
|
||||
@ -11,13 +11,19 @@ electronLanguages:
|
||||
- en # for macOS
|
||||
directories:
|
||||
buildResources: build
|
||||
|
||||
protocols:
|
||||
- name: Cherry Studio
|
||||
schemes:
|
||||
- cherrystudio
|
||||
files:
|
||||
- '**/*'
|
||||
- '!{.vscode,.yarn,.yarn-lock,.github,.cursorrules,.prettierrc}'
|
||||
- '!electron.vite.config.{js,ts,mjs,cjs}'
|
||||
- '!{.eslintignore,.eslintrc.cjs,.prettierignore,.prettierrc.yaml,eslint.config.mjs,dev-app-update.yml,CHANGELOG.md,README.md}'
|
||||
- '!{.env,.env.*,.npmrc,pnpm-lock.yaml}'
|
||||
- '!{tsconfig.json,tsconfig.node.json,tsconfig.web.json}'
|
||||
- '!**/{.vscode,.yarn,.yarn-lock,.github,.cursorrules,.prettierrc}'
|
||||
- '!electron.vite.config.{js,ts,mjs,cjs}}'
|
||||
- '!**/{.eslintignore,.eslintrc.js,.eslintrc.json,.eslintcache,root.eslint.config.js,eslint.config.js,.eslintrc.cjs,.prettierignore,.prettierrc.yaml,eslint.config.mjs,dev-app-update.yml,CHANGELOG.md,README.md}'
|
||||
- '!**/{.env,.env.*,.npmrc,pnpm-lock.yaml}'
|
||||
- '!**/{tsconfig.json,tsconfig.tsbuildinfo,tsconfig.node.json,tsconfig.web.json}'
|
||||
- '!**/{.editorconfig,.jekyll-metadata}'
|
||||
- '!src'
|
||||
- '!scripts'
|
||||
- '!local'
|
||||
@ -36,8 +42,11 @@ files:
|
||||
- '!**/*.{spec,test}.{js,jsx,ts,tsx}'
|
||||
- '!**/*.min.*.map'
|
||||
- '!**/*.d.ts'
|
||||
- '!**/dist/es6/**'
|
||||
- '!**/dist/demo/**'
|
||||
- '!**/amd/**'
|
||||
- '!**/{.DS_Store,Thumbs.db,thumbs.db,__pycache__}'
|
||||
- '!**/{LICENSE,LICENSE.txt,LICENSE-MIT.txt,*.LICENSE.txt,NOTICE.txt,README.md,readme.md,CHANGELOG.md}'
|
||||
- '!**/{LICENSE,license,LICENSE.*,*.LICENSE.txt,NOTICE.txt,README.md,readme.md,CHANGELOG.md}'
|
||||
- '!node_modules/rollup-plugin-visualizer'
|
||||
- '!node_modules/js-tiktoken'
|
||||
- '!node_modules/@tavily/core/node_modules/js-tiktoken'
|
||||
@ -89,6 +98,7 @@ linux:
|
||||
artifactName: ${productName}-${version}-${arch}.${ext}
|
||||
target:
|
||||
- target: AppImage
|
||||
- target: deb
|
||||
maintainer: electronjs.org
|
||||
category: Utility
|
||||
desktop:
|
||||
@ -106,10 +116,10 @@ afterSign: scripts/notarize.js
|
||||
artifactBuildCompleted: scripts/artifact-build-completed.js
|
||||
releaseInfo:
|
||||
releaseNotes: |
|
||||
⚠️ 注意:升级前请备份数据,否则将无法降级
|
||||
文生图新增服务商 DMXAPI(限时免费)
|
||||
输入框按钮支持拖拽排序
|
||||
修复知识库搜索结果 100% 问题
|
||||
修复拖拽多选消息相关问题
|
||||
修复翻译回复内容导致内存异常问题
|
||||
常规错误修复和优化
|
||||
界面优化:优化多处界面样式,气泡样式改版,自动调整代码预览边栏宽度
|
||||
知识库:修复知识库引用不显示问题,修复部分嵌入模型适配问题
|
||||
备份与恢复:修复超过 2GB 大文件无法恢复问题
|
||||
文件处理:添加 .doc 文件支持
|
||||
划词助手:支持自定义 CSS 样式
|
||||
MCP:基于 Pyodide 实现 Python MCP 服务
|
||||
其他错误修复和优化
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import react from '@vitejs/plugin-react-swc'
|
||||
import { CodeInspectorPlugin } from 'code-inspector-plugin'
|
||||
import { defineConfig, externalizeDepsPlugin } from 'electron-vite'
|
||||
import { resolve } from 'path'
|
||||
import { visualizer } from 'rollup-plugin-visualizer'
|
||||
@ -9,25 +10,7 @@ const visualizerPlugin = (type: 'renderer' | 'main') => {
|
||||
|
||||
export default defineConfig({
|
||||
main: {
|
||||
plugins: [
|
||||
externalizeDepsPlugin({
|
||||
exclude: [
|
||||
'@cherrystudio/embedjs',
|
||||
'@cherrystudio/embedjs-openai',
|
||||
'@cherrystudio/embedjs-loader-web',
|
||||
'@cherrystudio/embedjs-loader-markdown',
|
||||
'@cherrystudio/embedjs-loader-msoffice',
|
||||
'@cherrystudio/embedjs-loader-xml',
|
||||
'@cherrystudio/embedjs-loader-pdf',
|
||||
'@cherrystudio/embedjs-loader-sitemap',
|
||||
'@cherrystudio/embedjs-libsql',
|
||||
'@cherrystudio/embedjs-loader-image',
|
||||
'p-queue',
|
||||
'webdav'
|
||||
]
|
||||
}),
|
||||
...visualizerPlugin('main')
|
||||
],
|
||||
plugins: [externalizeDepsPlugin(), ...visualizerPlugin('main')],
|
||||
resolve: {
|
||||
alias: {
|
||||
'@main': resolve('src/main'),
|
||||
@ -38,7 +21,13 @@ export default defineConfig({
|
||||
},
|
||||
build: {
|
||||
rollupOptions: {
|
||||
external: ['@libsql/client', 'bufferutil', 'utf-8-validate']
|
||||
external: ['@libsql/client', 'bufferutil', 'utf-8-validate'],
|
||||
output: {
|
||||
// 彻底禁用代码分割 - 返回 null 强制单文件打包
|
||||
manualChunks: undefined,
|
||||
// 内联所有动态导入,这是关键配置
|
||||
inlineDynamicImports: true
|
||||
}
|
||||
},
|
||||
sourcemap: process.env.NODE_ENV === 'development'
|
||||
},
|
||||
@ -72,6 +61,14 @@ export default defineConfig({
|
||||
]
|
||||
]
|
||||
}),
|
||||
// 只在开发环境下启用 CodeInspectorPlugin
|
||||
...(process.env.NODE_ENV === 'development'
|
||||
? [
|
||||
CodeInspectorPlugin({
|
||||
bundler: 'vite'
|
||||
})
|
||||
]
|
||||
: []),
|
||||
...visualizerPlugin('renderer')
|
||||
],
|
||||
resolve: {
|
||||
@ -81,12 +78,16 @@ export default defineConfig({
|
||||
}
|
||||
},
|
||||
optimizeDeps: {
|
||||
exclude: ['pyodide']
|
||||
exclude: ['pyodide'],
|
||||
esbuildOptions: {
|
||||
target: 'esnext' // for dev
|
||||
}
|
||||
},
|
||||
worker: {
|
||||
format: 'es'
|
||||
},
|
||||
build: {
|
||||
target: 'esnext', // for build
|
||||
rollupOptions: {
|
||||
input: {
|
||||
index: resolve(__dirname, 'src/renderer/index.html'),
|
||||
|
||||
130
package.json
130
package.json
@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "CherryStudio",
|
||||
"version": "1.3.12",
|
||||
"version": "1.4.7",
|
||||
"private": true,
|
||||
"description": "A powerful AI assistant for producer.",
|
||||
"main": "./out/main/index.js",
|
||||
@ -22,7 +22,7 @@
|
||||
"dev": "electron-vite dev",
|
||||
"debug": "electron-vite -- --inspect --sourcemap --remote-debugging-port=9222",
|
||||
"build": "npm run typecheck && electron-vite build",
|
||||
"build:check": "yarn test && yarn typecheck && yarn check:i18n",
|
||||
"build:check": "yarn typecheck && yarn check:i18n && yarn test",
|
||||
"build:unpack": "dotenv npm run build && electron-builder --dir",
|
||||
"build:win": "dotenv npm run build && electron-builder --win --x64 --arm64",
|
||||
"build:win:x64": "dotenv npm run build && electron-builder --win --x64",
|
||||
@ -38,7 +38,6 @@
|
||||
"publish": "yarn build:check && yarn release patch push",
|
||||
"pulish:artifacts": "cd packages/artifacts && npm publish && cd -",
|
||||
"generate:agents": "yarn workspace @cherry-studio/database agents",
|
||||
"generate:icons": "electron-icon-builder --input=./build/logo.png --output=build",
|
||||
"analyze:renderer": "VISUALIZER_RENDERER=true yarn build",
|
||||
"analyze:main": "VISUALIZER_MAIN=true yarn build",
|
||||
"typecheck": "npm run typecheck:node && npm run typecheck:web",
|
||||
@ -48,6 +47,7 @@
|
||||
"test": "vitest run --silent",
|
||||
"test:main": "vitest run --project main",
|
||||
"test:renderer": "vitest run --project renderer",
|
||||
"test:update": "yarn test:renderer --update",
|
||||
"test:coverage": "vitest run --coverage --silent",
|
||||
"test:ui": "vitest --ui",
|
||||
"test:watch": "vitest",
|
||||
@ -59,6 +59,23 @@
|
||||
"migrations:generate": "drizzle-kit generate --config ./migrations/sqlite-drizzle.config.ts"
|
||||
},
|
||||
"dependencies": {
|
||||
"@libsql/client": "0.14.0",
|
||||
"@libsql/win32-x64-msvc": "^0.4.7",
|
||||
"@strongtz/win32-arm64-msvc": "^0.4.7",
|
||||
"jsdom": "26.1.0",
|
||||
"macos-release": "^3.4.0",
|
||||
"node-stream-zip": "^1.15.0",
|
||||
"notion-helper": "^1.3.22",
|
||||
"os-proxy-config": "^1.1.2",
|
||||
"selection-hook": "^0.9.23",
|
||||
"turndown": "7.2.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@agentic/exa": "^7.3.3",
|
||||
"@agentic/searxng": "^7.3.3",
|
||||
"@agentic/tavily": "^7.3.3",
|
||||
"@ant-design/v5-patch-for-react-19": "^1.0.3",
|
||||
"@anthropic-ai/sdk": "^0.41.0",
|
||||
"@cherrystudio/embedjs": "^0.1.31",
|
||||
"@cherrystudio/embedjs-libsql": "^0.1.31",
|
||||
"@cherrystudio/embedjs-loader-csv": "^0.1.31",
|
||||
@ -69,61 +86,30 @@
|
||||
"@cherrystudio/embedjs-loader-sitemap": "^0.1.31",
|
||||
"@cherrystudio/embedjs-loader-web": "^0.1.31",
|
||||
"@cherrystudio/embedjs-loader-xml": "^0.1.31",
|
||||
"@cherrystudio/embedjs-ollama": "^0.1.31",
|
||||
"@cherrystudio/embedjs-openai": "^0.1.31",
|
||||
"@electron-toolkit/utils": "^3.0.0",
|
||||
"@langchain/community": "^0.3.36",
|
||||
"@libsql/client": "^0.15.7",
|
||||
"@strongtz/win32-arm64-msvc": "^0.4.7",
|
||||
"@tanstack/react-query": "^5.27.0",
|
||||
"@types/react-infinite-scroll-component": "^5.0.0",
|
||||
"archiver": "^7.0.1",
|
||||
"async-mutex": "^0.5.0",
|
||||
"diff": "^7.0.0",
|
||||
"docx": "^9.0.2",
|
||||
"drizzle-orm": "^0.43.1",
|
||||
"electron-log": "^5.1.5",
|
||||
"electron-store": "^8.2.0",
|
||||
"electron-updater": "6.6.4",
|
||||
"electron-window-state": "^5.0.3",
|
||||
"epub": "patch:epub@npm%3A1.3.0#~/.yarn/patches/epub-npm-1.3.0-8325494ffe.patch",
|
||||
"fast-xml-parser": "^5.2.0",
|
||||
"fs-extra": "^11.2.0",
|
||||
"jsdom": "^26.0.0",
|
||||
"markdown-it": "^14.1.0",
|
||||
"node-stream-zip": "^1.15.0",
|
||||
"officeparser": "^4.1.1",
|
||||
"os-proxy-config": "^1.1.2",
|
||||
"proxy-agent": "^6.5.0",
|
||||
"selection-hook": "^0.9.14",
|
||||
"tar": "^7.4.3",
|
||||
"turndown": "^7.2.0",
|
||||
"webdav": "^5.8.0",
|
||||
"zipread": "^1.3.3"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@agentic/exa": "^7.3.3",
|
||||
"@agentic/searxng": "^7.3.3",
|
||||
"@agentic/tavily": "^7.3.3",
|
||||
"@ant-design/v5-patch-for-react-19": "^1.0.3",
|
||||
"@anthropic-ai/sdk": "^0.41.0",
|
||||
"@electron-toolkit/eslint-config-prettier": "^3.0.0",
|
||||
"@electron-toolkit/eslint-config-ts": "^3.0.0",
|
||||
"@electron-toolkit/preload": "^3.0.0",
|
||||
"@electron-toolkit/tsconfig": "^1.0.1",
|
||||
"@electron-toolkit/utils": "^3.0.0",
|
||||
"@electron/notarize": "^2.5.0",
|
||||
"@emotion/is-prop-valid": "^1.3.1",
|
||||
"@eslint-react/eslint-plugin": "^1.36.1",
|
||||
"@eslint/js": "^9.22.0",
|
||||
"@google/genai": "^0.13.0",
|
||||
"@google/genai": "patch:@google/genai@npm%3A1.0.1#~/.yarn/patches/@google-genai-npm-1.0.1-e26f0f9af7.patch",
|
||||
"@hello-pangea/dnd": "^16.6.0",
|
||||
"@kangfenmao/keyv-storage": "^0.1.0",
|
||||
"@langchain/community": "^0.3.36",
|
||||
"@langchain/ollama": "^0.2.1",
|
||||
"@modelcontextprotocol/sdk": "^1.11.4",
|
||||
"@mozilla/readability": "^0.6.0",
|
||||
"@notionhq/client": "^2.2.15",
|
||||
"@playwright/test": "^1.52.0",
|
||||
"@reduxjs/toolkit": "^2.2.5",
|
||||
"@shikijs/markdown-it": "^3.4.2",
|
||||
"@shikijs/markdown-it": "^3.7.0",
|
||||
"@swc/plugin-styled-components": "^7.1.5",
|
||||
"@tanstack/react-query": "^5.27.0",
|
||||
"@testing-library/dom": "^10.4.0",
|
||||
"@testing-library/jest-dom": "^6.6.3",
|
||||
"@testing-library/react": "^16.3.0",
|
||||
@ -140,37 +126,51 @@
|
||||
"@types/react-infinite-scroll-component": "^5.0.0",
|
||||
"@types/react-window": "^1",
|
||||
"@types/tinycolor2": "^1",
|
||||
"@types/ws": "^8",
|
||||
"@uiw/codemirror-extensions-langs": "^4.23.12",
|
||||
"@uiw/codemirror-themes-all": "^4.23.12",
|
||||
"@uiw/react-codemirror": "^4.23.12",
|
||||
"@types/word-extractor": "^1",
|
||||
"@uiw/codemirror-extensions-langs": "^4.23.14",
|
||||
"@uiw/codemirror-themes-all": "^4.23.14",
|
||||
"@uiw/react-codemirror": "^4.23.14",
|
||||
"@vitejs/plugin-react-swc": "^3.9.0",
|
||||
"@vitest/browser": "^3.1.4",
|
||||
"@vitest/coverage-v8": "^3.1.4",
|
||||
"@vitest/ui": "^3.1.4",
|
||||
"@vitest/web-worker": "^3.1.4",
|
||||
"@xyflow/react": "^12.4.4",
|
||||
"antd": "^5.22.5",
|
||||
"antd": "patch:antd@npm%3A5.24.7#~/.yarn/patches/antd-npm-5.24.7-356a553ae5.patch",
|
||||
"archiver": "^7.0.1",
|
||||
"async-mutex": "^0.5.0",
|
||||
"axios": "^1.7.3",
|
||||
"browser-image-compression": "^2.0.2",
|
||||
"code-inspector-plugin": "^0.20.14",
|
||||
"color": "^5.0.0",
|
||||
"country-flag-emoji-polyfill": "0.1.8",
|
||||
"dayjs": "^1.11.11",
|
||||
"dexie": "^4.0.8",
|
||||
"dexie-react-hooks": "^1.1.7",
|
||||
"diff": "^7.0.0",
|
||||
"docx": "^9.0.2",
|
||||
"dotenv-cli": "^7.4.2",
|
||||
"drizzle-kit": "^0.31.1",
|
||||
"electron": "35.4.0",
|
||||
"electron": "35.6.0",
|
||||
"electron-builder": "26.0.15",
|
||||
"electron-devtools-installer": "^3.2.0",
|
||||
"electron-icon-builder": "^2.0.1",
|
||||
"electron-log": "^5.1.5",
|
||||
"electron-store": "^8.2.0",
|
||||
"electron-updater": "6.6.4",
|
||||
"electron-vite": "^3.1.0",
|
||||
"electron-window-state": "^5.0.3",
|
||||
"emittery": "^1.0.3",
|
||||
"emoji-picker-element": "^1.22.1",
|
||||
"epub": "patch:epub@npm%3A1.3.0#~/.yarn/patches/epub-npm-1.3.0-8325494ffe.patch",
|
||||
"eslint": "^9.22.0",
|
||||
"eslint-plugin-react-hooks": "^5.2.0",
|
||||
"eslint-plugin-simple-import-sort": "^12.1.1",
|
||||
"eslint-plugin-unused-imports": "^4.1.4",
|
||||
"fast-diff": "^1.3.0",
|
||||
"fast-xml-parser": "^5.2.0",
|
||||
"franc-min": "^6.2.0",
|
||||
"fs-extra": "^11.2.0",
|
||||
"google-auth-library": "^9.15.1",
|
||||
"html-to-image": "^1.11.13",
|
||||
"husky": "^9.1.7",
|
||||
"i18next": "^23.11.5",
|
||||
@ -179,21 +179,24 @@
|
||||
"lodash": "^4.17.21",
|
||||
"lru-cache": "^11.1.0",
|
||||
"lucide-react": "^0.487.0",
|
||||
"mermaid": "^11.6.0",
|
||||
"markdown-it": "^14.1.0",
|
||||
"mermaid": "^11.7.0",
|
||||
"mime": "^4.0.4",
|
||||
"motion": "^12.10.5",
|
||||
"npx-scope-finder": "^1.2.0",
|
||||
"openai": "patch:openai@npm%3A4.96.0#~/.yarn/patches/openai-npm-4.96.0-0665b05cb9.patch",
|
||||
"officeparser": "^4.1.1",
|
||||
"openai": "patch:openai@npm%3A5.1.0#~/.yarn/patches/openai-npm-5.1.0-0e7b3ccb07.patch",
|
||||
"p-queue": "^8.1.0",
|
||||
"playwright": "^1.52.0",
|
||||
"prettier": "^3.5.3",
|
||||
"proxy-agent": "^6.5.0",
|
||||
"rc-virtual-list": "^3.18.6",
|
||||
"react": "^19.0.0",
|
||||
"react-dom": "^19.0.0",
|
||||
"react-hotkeys-hook": "^4.6.1",
|
||||
"react-i18next": "^14.1.2",
|
||||
"react-infinite-scroll-component": "^6.1.0",
|
||||
"react-markdown": "^9.0.1",
|
||||
"react-markdown": "^10.1.0",
|
||||
"react-redux": "^9.1.2",
|
||||
"react-router": "6",
|
||||
"react-router-dom": "6",
|
||||
@ -202,34 +205,39 @@
|
||||
"redux": "^5.0.1",
|
||||
"redux-persist": "^6.0.0",
|
||||
"rehype-katex": "^7.0.1",
|
||||
"rehype-mathjax": "^7.0.0",
|
||||
"rehype-mathjax": "^7.1.0",
|
||||
"rehype-raw": "^7.0.0",
|
||||
"remark-cjk-friendly": "^1.1.0",
|
||||
"remark-gfm": "^4.0.0",
|
||||
"remark-cjk-friendly": "^1.2.0",
|
||||
"remark-gfm": "^4.0.1",
|
||||
"remark-math": "^6.0.0",
|
||||
"remove-markdown": "^0.6.2",
|
||||
"rollup-plugin-visualizer": "^5.12.0",
|
||||
"sass": "^1.88.0",
|
||||
"shiki": "^3.4.2",
|
||||
"shiki": "^3.7.0",
|
||||
"string-width": "^7.2.0",
|
||||
"styled-components": "^6.1.11",
|
||||
"tar": "^7.4.3",
|
||||
"tiny-pinyin": "^1.3.2",
|
||||
"tokenx": "^0.4.1",
|
||||
"tokenx": "^1.1.0",
|
||||
"typescript": "^5.6.2",
|
||||
"uuid": "^10.0.0",
|
||||
"vite": "6.2.6",
|
||||
"vitest": "^3.1.4"
|
||||
"vitest": "^3.1.4",
|
||||
"webdav": "^5.8.0",
|
||||
"word-extractor": "^1.0.4",
|
||||
"zipread": "^1.3.3"
|
||||
},
|
||||
"resolutions": {
|
||||
"pdf-parse@npm:1.1.1": "patch:pdf-parse@npm%3A1.1.1#~/.yarn/patches/pdf-parse-npm-1.1.1-04a6109b2a.patch",
|
||||
"@langchain/openai@npm:^0.3.16": "patch:@langchain/openai@npm%3A0.3.16#~/.yarn/patches/@langchain-openai-npm-0.3.16-e525b59526.patch",
|
||||
"@langchain/openai@npm:>=0.1.0 <0.4.0": "patch:@langchain/openai@npm%3A0.3.16#~/.yarn/patches/@langchain-openai-npm-0.3.16-e525b59526.patch",
|
||||
"node-gyp": "^9.1.0",
|
||||
"libsql@npm:^0.4.4": "patch:libsql@npm%3A0.4.7#~/.yarn/patches/libsql-npm-0.4.7-444e260fb1.patch",
|
||||
"openai@npm:^4.77.0": "patch:openai@npm%3A4.96.0#~/.yarn/patches/openai-npm-4.96.0-0665b05cb9.patch",
|
||||
"openai@npm:^4.77.0": "patch:openai@npm%3A5.1.0#~/.yarn/patches/openai-npm-5.1.0-0e7b3ccb07.patch",
|
||||
"pkce-challenge@npm:^4.1.0": "patch:pkce-challenge@npm%3A4.1.0#~/.yarn/patches/pkce-challenge-npm-4.1.0-fbc51695a3.patch",
|
||||
"app-builder-lib@npm:26.0.13": "patch:app-builder-lib@npm%3A26.0.13#~/.yarn/patches/app-builder-lib-npm-26.0.13-a064c9e1d0.patch",
|
||||
"openai@npm:^4.87.3": "patch:openai@npm%3A4.96.0#~/.yarn/patches/openai-npm-4.96.0-0665b05cb9.patch",
|
||||
"app-builder-lib@npm:26.0.15": "patch:app-builder-lib@npm%3A26.0.15#~/.yarn/patches/app-builder-lib-npm-26.0.15-360e5b0476.patch"
|
||||
"openai@npm:^4.87.3": "patch:openai@npm%3A5.1.0#~/.yarn/patches/openai-npm-5.1.0-0e7b3ccb07.patch",
|
||||
"app-builder-lib@npm:26.0.15": "patch:app-builder-lib@npm%3A26.0.15#~/.yarn/patches/app-builder-lib-npm-26.0.15-360e5b0476.patch",
|
||||
"@langchain/core@npm:^0.3.26": "patch:@langchain/core@npm%3A0.3.44#~/.yarn/patches/@langchain-core-npm-0.3.44-41d5c3cb0a.patch"
|
||||
},
|
||||
"packageManager": "yarn@4.9.1",
|
||||
"lint-staged": {
|
||||
|
||||
@ -3,6 +3,8 @@ export enum IpcChannel {
|
||||
App_ClearCache = 'app:clear-cache',
|
||||
App_SetLaunchOnBoot = 'app:set-launch-on-boot',
|
||||
App_SetLanguage = 'app:set-language',
|
||||
App_SetEnableSpellCheck = 'app:set-enable-spell-check',
|
||||
App_SetSpellCheckLanguages = 'app:set-spell-check-languages',
|
||||
App_ShowUpdateDialog = 'app:show-update-dialog',
|
||||
App_CheckForUpdate = 'app:check-for-update',
|
||||
App_Reload = 'app:reload',
|
||||
@ -11,20 +13,32 @@ export enum IpcChannel {
|
||||
App_SetLaunchToTray = 'app:set-launch-to-tray',
|
||||
App_SetTray = 'app:set-tray',
|
||||
App_SetTrayOnClose = 'app:set-tray-on-close',
|
||||
App_RestartTray = 'app:restart-tray',
|
||||
App_SetTheme = 'app:set-theme',
|
||||
App_SetAutoUpdate = 'app:set-auto-update',
|
||||
App_SetTestPlan = 'app:set-test-plan',
|
||||
App_SetTestChannel = 'app:set-test-channel',
|
||||
App_HandleZoomFactor = 'app:handle-zoom-factor',
|
||||
|
||||
App_Select = 'app:select',
|
||||
App_HasWritePermission = 'app:has-write-permission',
|
||||
App_Copy = 'app:copy',
|
||||
App_SetStopQuitApp = 'app:set-stop-quit-app',
|
||||
App_SetAppDataPath = 'app:set-app-data-path',
|
||||
App_GetDataPathFromArgs = 'app:get-data-path-from-args',
|
||||
App_FlushAppData = 'app:flush-app-data',
|
||||
App_IsNotEmptyDir = 'app:is-not-empty-dir',
|
||||
App_RelaunchApp = 'app:relaunch-app',
|
||||
App_IsBinaryExist = 'app:is-binary-exist',
|
||||
App_GetBinaryPath = 'app:get-binary-path',
|
||||
App_InstallUvBinary = 'app:install-uv-binary',
|
||||
App_InstallBunBinary = 'app:install-bun-binary',
|
||||
|
||||
App_QuoteToMain = 'app:quote-to-main',
|
||||
|
||||
Notification_Send = 'notification:send',
|
||||
Notification_OnClick = 'notification:on-click',
|
||||
|
||||
Webview_SetOpenLinkExternal = 'webview:set-open-link-external',
|
||||
Webview_SetSpellCheckEnabled = 'webview:set-spell-check-enabled',
|
||||
|
||||
// Open
|
||||
Open_Path = 'open:path',
|
||||
@ -57,6 +71,9 @@ export enum IpcChannel {
|
||||
Mcp_ServersUpdated = 'mcp:servers-updated',
|
||||
Mcp_CheckConnectivity = 'mcp:check-connectivity',
|
||||
|
||||
// Python
|
||||
Python_Execute = 'python:execute',
|
||||
|
||||
//copilot
|
||||
Copilot_GetAuthMessage = 'copilot:get-auth-message',
|
||||
Copilot_GetCopilotToken = 'copilot:get-copilot-token',
|
||||
@ -84,6 +101,10 @@ export enum IpcChannel {
|
||||
Gemini_ListFiles = 'gemini:list-files',
|
||||
Gemini_DeleteFile = 'gemini:delete-file',
|
||||
|
||||
// VertexAI
|
||||
VertexAI_GetAuthHeaders = 'vertexai:get-auth-headers',
|
||||
VertexAI_ClearAuthCache = 'vertexai:clear-auth-cache',
|
||||
|
||||
Windows_ResetMinimumSize = 'window:reset-minimum-size',
|
||||
Windows_SetMinimumSize = 'window:set-minimum-size',
|
||||
|
||||
@ -111,10 +132,12 @@ export enum IpcChannel {
|
||||
File_WriteWithId = 'file:writeWithId',
|
||||
File_SaveImage = 'file:saveImage',
|
||||
File_Base64Image = 'file:base64Image',
|
||||
File_SaveBase64Image = 'file:saveBase64Image',
|
||||
File_Download = 'file:download',
|
||||
File_Copy = 'file:copy',
|
||||
File_BinaryImage = 'file:binaryImage',
|
||||
File_Base64File = 'file:base64File',
|
||||
File_GetPdfInfo = 'file:getPdfInfo',
|
||||
Fs_Read = 'fs:read',
|
||||
|
||||
Export_Word = 'export:word',
|
||||
@ -144,7 +167,7 @@ export enum IpcChannel {
|
||||
|
||||
// events
|
||||
BackupProgress = 'backup-progress',
|
||||
ThemeChange = 'theme:change',
|
||||
ThemeUpdated = 'theme:updated',
|
||||
UpdateDownloadedCancelled = 'update-downloaded-cancelled',
|
||||
RestoreProgress = 'restore-progress',
|
||||
UpdateError = 'update-error',
|
||||
@ -186,7 +209,10 @@ export enum IpcChannel {
|
||||
Selection_WriteToClipboard = 'selection:write-to-clipboard',
|
||||
Selection_SetEnabled = 'selection:set-enabled',
|
||||
Selection_SetTriggerMode = 'selection:set-trigger-mode',
|
||||
Selection_SetFilterMode = 'selection:set-filter-mode',
|
||||
Selection_SetFilterList = 'selection:set-filter-list',
|
||||
Selection_SetFollowToolbar = 'selection:set-follow-toolbar',
|
||||
Selection_SetRemeberWinSize = 'selection:set-remeber-win-size',
|
||||
Selection_ActionWindowClose = 'selection:action-window-close',
|
||||
Selection_ActionWindowMinimize = 'selection:action-window-minimize',
|
||||
Selection_ActionWindowPin = 'selection:action-window-pin',
|
||||
|
||||
@ -1,138 +1,371 @@
|
||||
export const imageExts = ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp']
|
||||
export const videoExts = ['.mp4', '.avi', '.mov', '.wmv', '.flv', '.mkv']
|
||||
export const audioExts = ['.mp3', '.wav', '.ogg', '.flac', '.aac']
|
||||
export const documentExts = ['.pdf', '.docx', '.pptx', '.xlsx', '.odt', '.odp', '.ods']
|
||||
export const documentExts = ['.pdf', '.doc', '.docx', '.pptx', '.xlsx', '.odt', '.odp', '.ods']
|
||||
export const thirdPartyApplicationExts = ['.draftsExport']
|
||||
export const bookExts = ['.epub']
|
||||
export const textExts = [
|
||||
'.txt', // 普通文本文件
|
||||
'.md', // Markdown 文件
|
||||
'.mdx', // Markdown 文件
|
||||
'.html', // HTML 文件
|
||||
'.htm', // HTML 文件的另一种扩展名
|
||||
'.xml', // XML 文件
|
||||
'.json', // JSON 文件
|
||||
'.yaml', // YAML 文件
|
||||
'.yml', // YAML 文件的另一种扩展名
|
||||
'.csv', // 逗号分隔值文件
|
||||
'.tsv', // 制表符分隔值文件
|
||||
'.ini', // 配置文件
|
||||
'.log', // 日志文件
|
||||
'.rtf', // 富文本格式文件
|
||||
'.org', // org-mode 文件
|
||||
'.wiki', // VimWiki 文件
|
||||
'.tex', // LaTeX 文件
|
||||
'.bib', // BibTeX 文件
|
||||
'.srt', // 字幕文件
|
||||
'.xhtml', // XHTML 文件
|
||||
'.nfo', // 信息文件(主要用于场景发布)
|
||||
'.conf', // 配置文件
|
||||
'.config', // 配置文件
|
||||
'.env', // 环境变量文件
|
||||
'.rst', // reStructuredText 文件
|
||||
'.php', // PHP 脚本文件,包含嵌入的 HTML
|
||||
'.js', // JavaScript 文件(部分是文本,部分可能包含代码)
|
||||
'.ts', // TypeScript 文件
|
||||
'.jsp', // JavaServer Pages 文件
|
||||
'.aspx', // ASP.NET 文件
|
||||
'.bat', // Windows 批处理文件
|
||||
'.sh', // Unix/Linux Shell 脚本文件
|
||||
'.py', // Python 脚本文件
|
||||
'.ipynb', // Jupyter 笔记本格式
|
||||
'.rb', // Ruby 脚本文件
|
||||
'.pl', // Perl 脚本文件
|
||||
'.sql', // SQL 脚本文件
|
||||
'.css', // Cascading Style Sheets 文件
|
||||
'.less', // Less CSS 预处理器文件
|
||||
'.scss', // Sass CSS 预处理器文件
|
||||
'.sass', // Sass 文件
|
||||
'.styl', // Stylus CSS 预处理器文件
|
||||
'.coffee', // CoffeeScript 文件
|
||||
'.ino', // Arduino 代码文件
|
||||
'.asm', // Assembly 语言文件
|
||||
'.go', // Go 语言文件
|
||||
'.scala', // Scala 语言文件
|
||||
'.swift', // Swift 语言文件
|
||||
'.kt', // Kotlin 语言文件
|
||||
'.rs', // Rust 语言文件
|
||||
'.lua', // Lua 语言文件
|
||||
'.groovy', // Groovy 语言文件
|
||||
'.dart', // Dart 语言文件
|
||||
'.hs', // Haskell 语言文件
|
||||
'.clj', // Clojure 语言文件
|
||||
'.cljs', // ClojureScript 语言文件
|
||||
'.elm', // Elm 语言文件
|
||||
'.erl', // Erlang 语言文件
|
||||
'.ex', // Elixir 语言文件
|
||||
'.exs', // Elixir 脚本文件
|
||||
'.pug', // Pug (formerly Jade) 模板文件
|
||||
'.haml', // Haml 模板文件
|
||||
'.slim', // Slim 模板文件
|
||||
'.tpl', // 模板文件(通用)
|
||||
'.ejs', // Embedded JavaScript 模板文件
|
||||
'.hbs', // Handlebars 模板文件
|
||||
'.mustache', // Mustache 模板文件
|
||||
'.jade', // Jade 模板文件 (已重命名为 Pug)
|
||||
'.twig', // Twig 模板文件
|
||||
'.blade', // Blade 模板文件 (Laravel)
|
||||
'.vue', // Vue.js 单文件组件
|
||||
'.jsx', // React JSX 文件
|
||||
'.tsx', // React TSX 文件
|
||||
'.graphql', // GraphQL 查询语言文件
|
||||
'.gql', // GraphQL 查询语言文件
|
||||
'.proto', // Protocol Buffers 文件
|
||||
'.thrift', // Thrift 文件
|
||||
'.toml', // TOML 配置文件
|
||||
'.edn', // Clojure 数据表示文件
|
||||
'.cake', // CakePHP 配置文件
|
||||
'.ctp', // CakePHP 视图文件
|
||||
'.cfm', // ColdFusion 标记语言文件
|
||||
'.cfc', // ColdFusion 组件文件
|
||||
'.m', // Objective-C 或 MATLAB 源文件
|
||||
'.mm', // Objective-C++ 源文件
|
||||
'.gradle', // Gradle 构建文件
|
||||
'.groovy', // Gradle 构建文件
|
||||
'.kts', // Kotlin Script 文件
|
||||
'.java', // Java 代码文件
|
||||
'.cs', // C# 代码文件
|
||||
'.cpp', // C++ 代码文件
|
||||
'.c', // C++ 代码文件
|
||||
'.h', // C++ 头文件
|
||||
'.hpp', // C++ 头文件
|
||||
'.cc', // C++ 源文件
|
||||
'.cxx', // C++ 源文件
|
||||
'.cppm', // C++20 模块接口文件
|
||||
'.ipp', // 模板实现文件
|
||||
'.ixx', // C++20 模块实现文件
|
||||
'.f90', // Fortran 90 源文件
|
||||
'.f', // Fortran 固定格式源代码文件
|
||||
'.f03', // Fortran 2003+ 源代码文件
|
||||
'.ahk', // AutoHotKey 语言文件
|
||||
'.tcl', // Tcl 脚本
|
||||
'.do', // Questa 或 Modelsim Tcl 脚本
|
||||
'.v', // Verilog 源文件
|
||||
'.sv', // SystemVerilog 源文件
|
||||
'.svh', // SystemVerilog 头文件
|
||||
'.vhd', // VHDL 源文件
|
||||
'.vhdl', // VHDL 源文件
|
||||
'.lef', // Library Exchange Format
|
||||
'.def', // Design Exchange Format
|
||||
'.edif', // Electronic Design Interchange Format
|
||||
'.sdf', // Standard Delay Format
|
||||
'.sdc', // Synopsys Design Constraints
|
||||
'.xdc', // Xilinx Design Constraints
|
||||
'.rpt', // 报告文件
|
||||
'.lisp', // Lisp 脚本
|
||||
'.il', // Cadence SKILL 脚本
|
||||
'.ils', // Cadence SKILL++ 脚本
|
||||
'.sp', // SPICE netlist 文件
|
||||
'.spi', // SPICE netlist 文件
|
||||
'.cir', // SPICE netlist 文件
|
||||
'.net', // SPICE netlist 文件
|
||||
'.scs', // Spectre netlist 文件
|
||||
'.asc', // LTspice netlist schematic 文件
|
||||
'.tf' // Technology File
|
||||
]
|
||||
const textExtsByCategory = new Map([
|
||||
[
|
||||
'language',
|
||||
[
|
||||
'.js',
|
||||
'.mjs',
|
||||
'.cjs',
|
||||
'.ts',
|
||||
'.jsx',
|
||||
'.tsx', // JavaScript/TypeScript
|
||||
'.py', // Python
|
||||
'.java', // Java
|
||||
'.cs', // C#
|
||||
'.cpp',
|
||||
'.c',
|
||||
'.h',
|
||||
'.hpp',
|
||||
'.cc',
|
||||
'.cxx',
|
||||
'.cppm',
|
||||
'.ipp',
|
||||
'.ixx', // C/C++
|
||||
'.php', // PHP
|
||||
'.rb', // Ruby
|
||||
'.pl', // Perl
|
||||
'.go', // Go
|
||||
'.rs', // Rust
|
||||
'.swift', // Swift
|
||||
'.kt',
|
||||
'.kts', // Kotlin
|
||||
'.scala', // Scala
|
||||
'.lua', // Lua
|
||||
'.groovy', // Groovy
|
||||
'.dart', // Dart
|
||||
'.hs', // Haskell
|
||||
'.clj',
|
||||
'.cljs', // Clojure
|
||||
'.elm', // Elm
|
||||
'.erl', // Erlang
|
||||
'.ex',
|
||||
'.exs', // Elixir
|
||||
'.ml',
|
||||
'.mli', // OCaml
|
||||
'.fs', // F#
|
||||
'.r',
|
||||
'.R', // R
|
||||
'.sol', // Solidity
|
||||
'.awk', // AWK
|
||||
'.cob', // COBOL
|
||||
'.asm',
|
||||
'.s', // Assembly
|
||||
'.lisp',
|
||||
'.lsp', // Lisp
|
||||
'.coffee', // CoffeeScript
|
||||
'.ino', // Arduino
|
||||
'.jl', // Julia
|
||||
'.nim', // Nim
|
||||
'.zig', // Zig
|
||||
'.d', // D语言
|
||||
'.pas', // Pascal
|
||||
'.vb', // Visual Basic
|
||||
'.rkt', // Racket
|
||||
'.scm', // Scheme
|
||||
'.hx', // Haxe
|
||||
'.as', // ActionScript
|
||||
'.pde', // Processing
|
||||
'.f90',
|
||||
'.f',
|
||||
'.f03',
|
||||
'.for',
|
||||
'.f95', // Fortran
|
||||
'.adb',
|
||||
'.ads', // Ada
|
||||
'.pro', // Prolog
|
||||
'.m',
|
||||
'.mm', // Objective-C/MATLAB
|
||||
'.rpy', // Ren'Py
|
||||
'.ets', // OpenHarmony,
|
||||
'.uniswap', // DeFi
|
||||
'.vy', // Vyper
|
||||
'.shader',
|
||||
'.glsl',
|
||||
'.frag',
|
||||
'.vert',
|
||||
'.gd' // Godot
|
||||
]
|
||||
],
|
||||
[
|
||||
'script',
|
||||
[
|
||||
'.sh', // Shell
|
||||
'.bat',
|
||||
'.cmd', // Windows批处理
|
||||
'.ps1', // PowerShell
|
||||
'.tcl',
|
||||
'.do', // Tcl
|
||||
'.ahk', // AutoHotkey
|
||||
'.zsh', // Zsh
|
||||
'.fish', // Fish shell
|
||||
'.csh', // C shell
|
||||
'.vbs', // VBScript
|
||||
'.applescript', // AppleScript
|
||||
'.au3', // AutoIt
|
||||
'.bash',
|
||||
'.nu'
|
||||
]
|
||||
],
|
||||
[
|
||||
'style',
|
||||
[
|
||||
'.css', // CSS
|
||||
'.less', // Less
|
||||
'.scss',
|
||||
'.sass', // Sass
|
||||
'.styl', // Stylus
|
||||
'.pcss', // PostCSS
|
||||
'.postcss' // PostCSS
|
||||
]
|
||||
],
|
||||
[
|
||||
'template',
|
||||
[
|
||||
'.vue', // Vue.js
|
||||
'.pug',
|
||||
'.jade', // Pug/Jade
|
||||
'.haml', // Haml
|
||||
'.slim', // Slim
|
||||
'.tpl', // 通用模板
|
||||
'.ejs', // EJS
|
||||
'.hbs', // Handlebars
|
||||
'.mustache', // Mustache
|
||||
'.twig', // Twig
|
||||
'.blade', // Blade (Laravel)
|
||||
'.liquid', // Liquid
|
||||
'.jinja',
|
||||
'.jinja2',
|
||||
'.j2', // Jinja
|
||||
'.erb', // ERB
|
||||
'.vm', // Velocity
|
||||
'.ftl', // FreeMarker
|
||||
'.svelte', // Svelte
|
||||
'.astro' // Astro
|
||||
]
|
||||
],
|
||||
[
|
||||
'config',
|
||||
[
|
||||
'.ini', // INI配置
|
||||
'.conf',
|
||||
'.config', // 通用配置
|
||||
'.env', // 环境变量
|
||||
'.toml', // TOML
|
||||
'.cfg', // 通用配置
|
||||
'.properties', // Java属性
|
||||
'.desktop', // Linux桌面文件
|
||||
'.service', // systemd服务
|
||||
'.rc',
|
||||
'.bashrc',
|
||||
'.zshrc', // Shell配置
|
||||
'.fishrc', // Fish shell配置
|
||||
'.vimrc', // Vim配置
|
||||
'.htaccess', // Apache配置
|
||||
'.robots', // robots.txt
|
||||
'.editorconfig', // EditorConfig
|
||||
'.eslintrc', // ESLint
|
||||
'.prettierrc', // Prettier
|
||||
'.babelrc', // Babel
|
||||
'.npmrc', // npm
|
||||
'.dockerignore', // Docker ignore
|
||||
'.npmignore',
|
||||
'.yarnrc',
|
||||
'.prettierignore',
|
||||
'.eslintignore',
|
||||
'.browserslistrc',
|
||||
'.json5',
|
||||
'.tfvars'
|
||||
]
|
||||
],
|
||||
[
|
||||
'document',
|
||||
[
|
||||
'.txt',
|
||||
'.text', // 纯文本
|
||||
'.md',
|
||||
'.mdx', // Markdown
|
||||
'.html',
|
||||
'.htm',
|
||||
'.xhtml', // HTML
|
||||
'.xml', // XML
|
||||
'.org', // Org-mode
|
||||
'.wiki', // Wiki
|
||||
'.tex',
|
||||
'.bib', // LaTeX
|
||||
'.rst', // reStructuredText
|
||||
'.rtf', // 富文本
|
||||
'.nfo', // 信息文件
|
||||
'.adoc',
|
||||
'.asciidoc', // AsciiDoc
|
||||
'.pod', // Perl文档
|
||||
'.1',
|
||||
'.2',
|
||||
'.3',
|
||||
'.4',
|
||||
'.5',
|
||||
'.6',
|
||||
'.7',
|
||||
'.8',
|
||||
'.9', // man页面
|
||||
'.man', // man页面
|
||||
'.texi',
|
||||
'.texinfo', // Texinfo
|
||||
'.readme',
|
||||
'.me', // README
|
||||
'.changelog', // 变更日志
|
||||
'.license', // 许可证
|
||||
'.authors', // 作者文件
|
||||
'.po',
|
||||
'.pot'
|
||||
]
|
||||
],
|
||||
[
|
||||
'data',
|
||||
[
|
||||
'.json', // JSON
|
||||
'.jsonc', // JSON with comments
|
||||
'.yaml',
|
||||
'.yml', // YAML
|
||||
'.csv',
|
||||
'.tsv', // 分隔值文件
|
||||
'.edn', // Clojure数据
|
||||
'.jsonl',
|
||||
'.ndjson', // 换行分隔JSON
|
||||
'.geojson', // GeoJSON
|
||||
'.gpx', // GPS Exchange
|
||||
'.kml', // Keyhole Markup
|
||||
'.rss',
|
||||
'.atom', // Feed格式
|
||||
'.vcf', // vCard
|
||||
'.ics', // iCalendar
|
||||
'.ldif', // LDAP数据交换
|
||||
'.pbtxt',
|
||||
'.map'
|
||||
]
|
||||
],
|
||||
[
|
||||
'build',
|
||||
[
|
||||
'.gradle', // Gradle
|
||||
'.make',
|
||||
'.mk', // Make
|
||||
'.cmake', // CMake
|
||||
'.sbt', // SBT
|
||||
'.rake', // Rake
|
||||
'.spec', // RPM spec
|
||||
'.pom',
|
||||
'.build', // Meson
|
||||
'.bazel' // Bazel
|
||||
]
|
||||
],
|
||||
[
|
||||
'database',
|
||||
[
|
||||
'.sql', // SQL
|
||||
'.ddl',
|
||||
'.dml', // DDL/DML
|
||||
'.plsql', // PL/SQL
|
||||
'.psql', // PostgreSQL
|
||||
'.cypher', // Cypher
|
||||
'.sparql' // SPARQL
|
||||
]
|
||||
],
|
||||
[
|
||||
'web',
|
||||
[
|
||||
'.graphql',
|
||||
'.gql', // GraphQL
|
||||
'.proto', // Protocol Buffers
|
||||
'.thrift', // Thrift
|
||||
'.wsdl', // WSDL
|
||||
'.raml', // RAML
|
||||
'.swagger',
|
||||
'.openapi' // API文档
|
||||
]
|
||||
],
|
||||
[
|
||||
'version',
|
||||
[
|
||||
'.gitignore', // Git ignore
|
||||
'.gitattributes', // Git attributes
|
||||
'.gitconfig', // Git config
|
||||
'.hgignore', // Mercurial ignore
|
||||
'.bzrignore', // Bazaar ignore
|
||||
'.svnignore', // SVN ignore
|
||||
'.githistory' // Git history
|
||||
]
|
||||
],
|
||||
[
|
||||
'subtitle',
|
||||
[
|
||||
'.srt',
|
||||
'.sub',
|
||||
'.ass' // 字幕格式
|
||||
]
|
||||
],
|
||||
[
|
||||
'log',
|
||||
[
|
||||
'.log',
|
||||
'.rpt' // 日志和报告 (移除了.out,因为通常是二进制可执行文件)
|
||||
]
|
||||
],
|
||||
[
|
||||
'eda',
|
||||
[
|
||||
'.v',
|
||||
'.sv',
|
||||
'.svh', // Verilog/SystemVerilog
|
||||
'.vhd',
|
||||
'.vhdl', // VHDL
|
||||
'.lef',
|
||||
'.def', // LEF/DEF
|
||||
'.edif', // EDIF
|
||||
'.sdf', // SDF
|
||||
'.sdc',
|
||||
'.xdc', // 约束文件
|
||||
'.sp',
|
||||
'.spi',
|
||||
'.cir',
|
||||
'.net', // SPICE
|
||||
'.scs', // Spectre
|
||||
'.asc', // LTspice
|
||||
'.tf', // Technology File
|
||||
'.il',
|
||||
'.ils' // SKILL
|
||||
]
|
||||
],
|
||||
[
|
||||
'game',
|
||||
[
|
||||
'.mtl', // Material Template Library
|
||||
'.x3d', // X3D文件
|
||||
'.gltf', // glTF JSON
|
||||
'.prefab', // Unity预制体 (YAML格式)
|
||||
'.meta' // Unity元数据文件 (YAML格式)
|
||||
]
|
||||
],
|
||||
[
|
||||
'other',
|
||||
[
|
||||
'.mcfunction', // Minecraft函数
|
||||
'.jsp', // JSP
|
||||
'.aspx', // ASP.NET
|
||||
'.ipynb', // Jupyter Notebook
|
||||
'.cake',
|
||||
'.ctp', // CakePHP
|
||||
'.cfm',
|
||||
'.cfc' // ColdFusion
|
||||
]
|
||||
]
|
||||
])
|
||||
|
||||
export const textExts = Array.from(textExtsByCategory.values()).flat()
|
||||
|
||||
export const ZOOM_LEVELS = [0.25, 0.33, 0.5, 0.67, 0.75, 0.8, 0.9, 1, 1.1, 1.25, 1.5, 1.75, 2, 2.5, 3, 4, 5]
|
||||
|
||||
@ -170,3 +403,19 @@ export const KB = 1024
|
||||
export const MB = 1024 * KB
|
||||
export const GB = 1024 * MB
|
||||
export const defaultLanguage = 'en-US'
|
||||
|
||||
export enum FeedUrl {
|
||||
PRODUCTION = 'https://releases.cherry-ai.com',
|
||||
GITHUB_LATEST = 'https://github.com/CherryHQ/cherry-studio/releases/latest/download',
|
||||
PRERELEASE_LOWEST = 'https://github.com/CherryHQ/cherry-studio/releases/download/v1.4.0'
|
||||
}
|
||||
|
||||
export enum UpgradeChannel {
|
||||
LATEST = 'latest', // 最新稳定版本
|
||||
RC = 'rc', // 公测版本
|
||||
BETA = 'beta' // 预览版本
|
||||
}
|
||||
|
||||
export const defaultTimeout = 10 * 1000 * 60
|
||||
|
||||
export const occupiedDirs = ['logs', 'Network', 'Partitions/webview/Network']
|
||||
|
||||
2904
packages/shared/config/languages.ts
Normal file
2904
packages/shared/config/languages.ts
Normal file
File diff suppressed because it is too large
Load Diff
9098
resources/data/agents-en.json
Normal file
9098
resources/data/agents-en.json
Normal file
File diff suppressed because one or more lines are too long
9098
resources/data/agents-zh.json
Normal file
9098
resources/data/agents-zh.json
Normal file
File diff suppressed because one or more lines are too long
@ -2,12 +2,12 @@ const fs = require('fs')
|
||||
const path = require('path')
|
||||
const os = require('os')
|
||||
const { execSync } = require('child_process')
|
||||
const AdmZip = require('adm-zip')
|
||||
const StreamZip = require('node-stream-zip')
|
||||
const { downloadWithRedirects } = require('./download')
|
||||
|
||||
// Base URL for downloading bun binaries
|
||||
const BUN_RELEASE_BASE_URL = 'https://gitcode.com/CherryHQ/bun/releases/download'
|
||||
const DEFAULT_BUN_VERSION = '1.2.9' // Default fallback version
|
||||
const DEFAULT_BUN_VERSION = '1.2.17' // Default fallback version
|
||||
|
||||
// Mapping of platform+arch to binary package name
|
||||
const BUN_PACKAGES = {
|
||||
@ -66,35 +66,36 @@ async function downloadBunBinary(platform, arch, version = DEFAULT_BUN_VERSION,
|
||||
|
||||
// Extract the zip file using adm-zip
|
||||
console.log(`Extracting ${packageName} to ${binDir}...`)
|
||||
const zip = new AdmZip(tempFilename)
|
||||
zip.extractAllTo(tempdir, true)
|
||||
const zip = new StreamZip.async({ file: tempFilename })
|
||||
|
||||
// Move files using Node.js fs
|
||||
const sourceDir = path.join(tempdir, packageName.split('.')[0])
|
||||
const files = fs.readdirSync(sourceDir)
|
||||
// Get all entries in the zip file
|
||||
const entries = await zip.entries()
|
||||
|
||||
for (const file of files) {
|
||||
const sourcePath = path.join(sourceDir, file)
|
||||
const destPath = path.join(binDir, file)
|
||||
// Extract files directly to binDir, flattening the directory structure
|
||||
for (const entry of Object.values(entries)) {
|
||||
if (!entry.isDirectory) {
|
||||
// Get just the filename without path
|
||||
const filename = path.basename(entry.name)
|
||||
const outputPath = path.join(binDir, filename)
|
||||
|
||||
fs.copyFileSync(sourcePath, destPath)
|
||||
fs.unlinkSync(sourcePath)
|
||||
|
||||
// Set executable permissions for non-Windows platforms
|
||||
if (platform !== 'win32') {
|
||||
try {
|
||||
// 755 permission: rwxr-xr-x
|
||||
fs.chmodSync(destPath, '755')
|
||||
} catch (error) {
|
||||
console.warn(`Warning: Failed to set executable permissions: ${error.message}`)
|
||||
console.log(`Extracting ${entry.name} -> ${filename}`)
|
||||
await zip.extract(entry.name, outputPath)
|
||||
// Make executable files executable on Unix-like systems
|
||||
if (platform !== 'win32') {
|
||||
try {
|
||||
fs.chmodSync(outputPath, 0o755)
|
||||
} catch (chmodError) {
|
||||
console.error(`Warning: Failed to set executable permissions on ${filename}`)
|
||||
return false
|
||||
}
|
||||
}
|
||||
console.log(`Extracted ${entry.name} -> ${outputPath}`)
|
||||
}
|
||||
}
|
||||
await zip.close()
|
||||
|
||||
// Clean up
|
||||
fs.unlinkSync(tempFilename)
|
||||
fs.rmSync(sourceDir, { recursive: true })
|
||||
|
||||
console.log(`Successfully installed bun ${version} for ${platformKey}`)
|
||||
return true
|
||||
} catch (error) {
|
||||
|
||||
@ -2,34 +2,33 @@ const fs = require('fs')
|
||||
const path = require('path')
|
||||
const os = require('os')
|
||||
const { execSync } = require('child_process')
|
||||
const tar = require('tar')
|
||||
const AdmZip = require('adm-zip')
|
||||
const StreamZip = require('node-stream-zip')
|
||||
const { downloadWithRedirects } = require('./download')
|
||||
|
||||
// Base URL for downloading uv binaries
|
||||
const UV_RELEASE_BASE_URL = 'https://gitcode.com/CherryHQ/uv/releases/download'
|
||||
const DEFAULT_UV_VERSION = '0.6.14'
|
||||
const DEFAULT_UV_VERSION = '0.7.13'
|
||||
|
||||
// Mapping of platform+arch to binary package name
|
||||
const UV_PACKAGES = {
|
||||
'darwin-arm64': 'uv-aarch64-apple-darwin.tar.gz',
|
||||
'darwin-x64': 'uv-x86_64-apple-darwin.tar.gz',
|
||||
'darwin-arm64': 'uv-aarch64-apple-darwin.zip',
|
||||
'darwin-x64': 'uv-x86_64-apple-darwin.zip',
|
||||
'win32-arm64': 'uv-aarch64-pc-windows-msvc.zip',
|
||||
'win32-ia32': 'uv-i686-pc-windows-msvc.zip',
|
||||
'win32-x64': 'uv-x86_64-pc-windows-msvc.zip',
|
||||
'linux-arm64': 'uv-aarch64-unknown-linux-gnu.tar.gz',
|
||||
'linux-ia32': 'uv-i686-unknown-linux-gnu.tar.gz',
|
||||
'linux-ppc64': 'uv-powerpc64-unknown-linux-gnu.tar.gz',
|
||||
'linux-ppc64le': 'uv-powerpc64le-unknown-linux-gnu.tar.gz',
|
||||
'linux-s390x': 'uv-s390x-unknown-linux-gnu.tar.gz',
|
||||
'linux-x64': 'uv-x86_64-unknown-linux-gnu.tar.gz',
|
||||
'linux-armv7l': 'uv-armv7-unknown-linux-gnueabihf.tar.gz',
|
||||
'linux-arm64': 'uv-aarch64-unknown-linux-gnu.zip',
|
||||
'linux-ia32': 'uv-i686-unknown-linux-gnu.zip',
|
||||
'linux-ppc64': 'uv-powerpc64-unknown-linux-gnu.zip',
|
||||
'linux-ppc64le': 'uv-powerpc64le-unknown-linux-gnu.zip',
|
||||
'linux-s390x': 'uv-s390x-unknown-linux-gnu.zip',
|
||||
'linux-x64': 'uv-x86_64-unknown-linux-gnu.zip',
|
||||
'linux-armv7l': 'uv-armv7-unknown-linux-gnueabihf.zip',
|
||||
// MUSL variants
|
||||
'linux-musl-arm64': 'uv-aarch64-unknown-linux-musl.tar.gz',
|
||||
'linux-musl-ia32': 'uv-i686-unknown-linux-musl.tar.gz',
|
||||
'linux-musl-x64': 'uv-x86_64-unknown-linux-musl.tar.gz',
|
||||
'linux-musl-armv6l': 'uv-arm-unknown-linux-musleabihf.tar.gz',
|
||||
'linux-musl-armv7l': 'uv-armv7-unknown-linux-musleabihf.tar.gz'
|
||||
'linux-musl-arm64': 'uv-aarch64-unknown-linux-musl.zip',
|
||||
'linux-musl-ia32': 'uv-i686-unknown-linux-musl.zip',
|
||||
'linux-musl-x64': 'uv-x86_64-unknown-linux-musl.zip',
|
||||
'linux-musl-armv6l': 'uv-arm-unknown-linux-musleabihf.zip',
|
||||
'linux-musl-armv7l': 'uv-armv7-unknown-linux-musleabihf.zip'
|
||||
}
|
||||
|
||||
/**
|
||||
@ -66,46 +65,35 @@ async function downloadUvBinary(platform, arch, version = DEFAULT_UV_VERSION, is
|
||||
|
||||
console.log(`Extracting ${packageName} to ${binDir}...`)
|
||||
|
||||
// 根据文件扩展名选择解压方法
|
||||
if (packageName.endsWith('.zip')) {
|
||||
// 使用 adm-zip 处理 zip 文件
|
||||
const zip = new AdmZip(tempFilename)
|
||||
zip.extractAllTo(binDir, true)
|
||||
fs.unlinkSync(tempFilename)
|
||||
console.log(`Successfully installed uv ${version} for ${platform}-${arch}`)
|
||||
return true
|
||||
} else {
|
||||
// tar.gz 文件的处理保持不变
|
||||
await tar.x({
|
||||
file: tempFilename,
|
||||
cwd: tempdir,
|
||||
z: true
|
||||
})
|
||||
const zip = new StreamZip.async({ file: tempFilename })
|
||||
|
||||
// Move files using Node.js fs
|
||||
const sourceDir = path.join(tempdir, packageName.split('.')[0])
|
||||
const files = fs.readdirSync(sourceDir)
|
||||
for (const file of files) {
|
||||
const sourcePath = path.join(sourceDir, file)
|
||||
const destPath = path.join(binDir, file)
|
||||
fs.copyFileSync(sourcePath, destPath)
|
||||
fs.unlinkSync(sourcePath)
|
||||
// Get all entries in the zip file
|
||||
const entries = await zip.entries()
|
||||
|
||||
// Set executable permissions for non-Windows platforms
|
||||
// Extract files directly to binDir, flattening the directory structure
|
||||
for (const entry of Object.values(entries)) {
|
||||
if (!entry.isDirectory) {
|
||||
// Get just the filename without path
|
||||
const filename = path.basename(entry.name)
|
||||
const outputPath = path.join(binDir, filename)
|
||||
|
||||
console.log(`Extracting ${entry.name} -> ${filename}`)
|
||||
await zip.extract(entry.name, outputPath)
|
||||
// Make executable files executable on Unix-like systems
|
||||
if (platform !== 'win32') {
|
||||
try {
|
||||
fs.chmodSync(destPath, '755')
|
||||
} catch (error) {
|
||||
console.warn(`Warning: Failed to set executable permissions: ${error.message}`)
|
||||
fs.chmodSync(outputPath, 0o755)
|
||||
} catch (chmodError) {
|
||||
console.error(`Warning: Failed to set executable permissions on ${filename}`)
|
||||
return false
|
||||
}
|
||||
}
|
||||
console.log(`Extracted ${entry.name} -> ${outputPath}`)
|
||||
}
|
||||
|
||||
// Clean up
|
||||
fs.unlinkSync(tempFilename)
|
||||
fs.rmSync(sourceDir, { recursive: true })
|
||||
}
|
||||
|
||||
await zip.close()
|
||||
fs.unlinkSync(tempFilename)
|
||||
console.log(`Successfully installed uv ${version} for ${platform}-${arch}`)
|
||||
return true
|
||||
} catch (error) {
|
||||
|
||||
@ -36,6 +36,11 @@ exports.default = async function (context) {
|
||||
keepPackageNodeFiles(node_modules_path, '@libsql', ['win32-x64-msvc'])
|
||||
}
|
||||
}
|
||||
|
||||
if (platform === 'windows') {
|
||||
fs.rmSync(path.join(context.appOutDir, 'LICENSE.electron.txt'), { force: true })
|
||||
fs.rmSync(path.join(context.appOutDir, 'LICENSES.chromium.html'), { force: true })
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@ -1,16 +1,19 @@
|
||||
/**
|
||||
* Paratera_API_KEY=sk-abcxxxxxxxxxxxxxxxxxxxxxxx123 ts-node scripts/update-i18n.ts
|
||||
* 使用 OpenAI 兼容的模型生成 i18n 文本,并更新到 translate 目录
|
||||
*
|
||||
* API_KEY=sk-xxxx BASE_URL=xxxx MODEL=xxxx ts-node scripts/update-i18n.ts
|
||||
*/
|
||||
|
||||
// OCOOL API KEY
|
||||
const Paratera_API_KEY = process.env.Paratera_API_KEY
|
||||
const API_KEY = process.env.API_KEY
|
||||
const BASE_URL = process.env.BASE_URL || 'https://llmapi.paratera.com/v1'
|
||||
const MODEL = process.env.MODEL || 'Qwen3-235B-A22B'
|
||||
|
||||
const INDEX = [
|
||||
// 语言的名称 代码 用来翻译的模型
|
||||
{ name: 'France', code: 'fr-fr', model: 'Qwen3-235B-A22B' },
|
||||
{ name: 'Spanish', code: 'es-es', model: 'Qwen3-235B-A22B' },
|
||||
{ name: 'Portuguese', code: 'pt-pt', model: 'Qwen3-235B-A22B' },
|
||||
{ name: 'Greek', code: 'el-gr', model: 'Qwen3-235B-A22B' }
|
||||
// 语言的名称代码用来翻译的模型
|
||||
{ name: 'France', code: 'fr-fr', model: MODEL },
|
||||
{ name: 'Spanish', code: 'es-es', model: MODEL },
|
||||
{ name: 'Portuguese', code: 'pt-pt', model: MODEL },
|
||||
{ name: 'Greek', code: 'el-gr', model: MODEL }
|
||||
]
|
||||
|
||||
const fs = require('fs')
|
||||
@ -19,8 +22,8 @@ import OpenAI from 'openai'
|
||||
const zh = JSON.parse(fs.readFileSync('src/renderer/src/i18n/locales/zh-cn.json', 'utf8')) as object
|
||||
|
||||
const openai = new OpenAI({
|
||||
apiKey: Paratera_API_KEY,
|
||||
baseURL: 'https://llmapi.paratera.com/v1'
|
||||
apiKey: API_KEY,
|
||||
baseURL: BASE_URL
|
||||
})
|
||||
|
||||
// 递归遍历翻译
|
||||
|
||||
33
src/main/bootstrap.ts
Normal file
33
src/main/bootstrap.ts
Normal file
@ -0,0 +1,33 @@
|
||||
import { occupiedDirs } from '@shared/config/constant'
|
||||
import { app } from 'electron'
|
||||
import fs from 'fs'
|
||||
import path from 'path'
|
||||
|
||||
import { initAppDataDir } from './utils/file'
|
||||
|
||||
app.isPackaged && initAppDataDir()
|
||||
|
||||
// 在主进程中复制 appData 中某些一直被占用的文件
|
||||
// 在renderer进程还没有启动时,主进程可以复制这些文件到新的appData中
|
||||
function copyOccupiedDirsInMainProcess() {
|
||||
const newAppDataPath = process.argv
|
||||
.slice(1)
|
||||
.find((arg) => arg.startsWith('--new-data-path='))
|
||||
?.split('--new-data-path=')[1]
|
||||
if (!newAppDataPath) {
|
||||
return
|
||||
}
|
||||
|
||||
if (process.platform === 'win32') {
|
||||
const appDataPath = app.getPath('userData')
|
||||
occupiedDirs.forEach((dir) => {
|
||||
const dirPath = path.join(appDataPath, dir)
|
||||
const newDirPath = path.join(newAppDataPath, dir)
|
||||
if (fs.existsSync(dirPath)) {
|
||||
fs.cpSync(dirPath, newDirPath, { recursive: true })
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
copyOccupiedDirsInMainProcess()
|
||||
@ -1,7 +1,6 @@
|
||||
import { app } from 'electron'
|
||||
|
||||
import { getDataPath } from './utils'
|
||||
|
||||
const isDev = process.env.NODE_ENV === 'development'
|
||||
|
||||
if (isDev) {
|
||||
|
||||
58
src/main/configs/SelectionConfig.ts
Normal file
58
src/main/configs/SelectionConfig.ts
Normal file
@ -0,0 +1,58 @@
|
||||
interface IFilterList {
|
||||
WINDOWS: string[]
|
||||
MAC?: string[]
|
||||
}
|
||||
|
||||
interface IFinetunedList {
|
||||
EXCLUDE_CLIPBOARD_CURSOR_DETECT: IFilterList
|
||||
INCLUDE_CLIPBOARD_DELAY_READ: IFilterList
|
||||
}
|
||||
|
||||
/*************************************************************************
|
||||
* 注意:请不要修改此配置,除非你非常清楚其含义、影响和行为的目的
|
||||
* Note: Do not modify this configuration unless you fully understand its meaning, implications, and intended behavior.
|
||||
* -----------------------------------------------------------------------
|
||||
* A predefined application filter list to include commonly used software
|
||||
* that does not require text selection but may conflict with it, and disable them in advance.
|
||||
* Only available in the selected mode.
|
||||
*
|
||||
* Specification: must be all lowercase, need to accurately find the actual running program name
|
||||
*************************************************************************/
|
||||
export const SELECTION_PREDEFINED_BLACKLIST: IFilterList = {
|
||||
WINDOWS: [
|
||||
'explorer.exe',
|
||||
// Screenshot
|
||||
'snipaste.exe',
|
||||
'pixpin.exe',
|
||||
'sharex.exe',
|
||||
// Office
|
||||
'excel.exe',
|
||||
'powerpnt.exe',
|
||||
// Image Editor
|
||||
'photoshop.exe',
|
||||
'illustrator.exe',
|
||||
// Video Editor
|
||||
'adobe premiere pro.exe',
|
||||
'afterfx.exe',
|
||||
// Audio Editor
|
||||
'adobe audition.exe',
|
||||
// 3D Editor
|
||||
'blender.exe',
|
||||
'3dsmax.exe',
|
||||
'maya.exe',
|
||||
// CAD
|
||||
'acad.exe',
|
||||
'sldworks.exe',
|
||||
// Remote Desktop
|
||||
'mstsc.exe'
|
||||
]
|
||||
}
|
||||
|
||||
export const SELECTION_FINETUNED_LIST: IFinetunedList = {
|
||||
EXCLUDE_CLIPBOARD_CURSOR_DETECT: {
|
||||
WINDOWS: ['acrobat.exe', 'wps.exe', 'cajviewer.exe']
|
||||
},
|
||||
INCLUDE_CLIPBOARD_DELAY_READ: {
|
||||
WINDOWS: ['acrobat.exe', 'wps.exe', 'cajviewer.exe', 'foxitphantom.exe']
|
||||
}
|
||||
}
|
||||
@ -1,38 +0,0 @@
|
||||
import type { BaseEmbeddings } from '@cherrystudio/embedjs-interfaces'
|
||||
import { OpenAiEmbeddings } from '@cherrystudio/embedjs-openai'
|
||||
import { AzureOpenAiEmbeddings } from '@cherrystudio/embedjs-openai/src/azure-openai-embeddings'
|
||||
import { getInstanceName } from '@main/utils'
|
||||
import { KnowledgeBaseParams } from '@types'
|
||||
|
||||
import VoyageEmbeddings from './VoyageEmbeddings'
|
||||
|
||||
export default class EmbeddingsFactory {
|
||||
static create({ model, apiKey, apiVersion, baseURL, dimensions }: KnowledgeBaseParams): BaseEmbeddings {
|
||||
const batchSize = 10
|
||||
if (model.includes('voyage')) {
|
||||
return new VoyageEmbeddings({
|
||||
modelName: model,
|
||||
apiKey,
|
||||
outputDimension: dimensions,
|
||||
batchSize: 8
|
||||
})
|
||||
}
|
||||
if (apiVersion !== undefined) {
|
||||
return new AzureOpenAiEmbeddings({
|
||||
azureOpenAIApiKey: apiKey,
|
||||
azureOpenAIApiVersion: apiVersion,
|
||||
azureOpenAIApiDeploymentName: model,
|
||||
azureOpenAIApiInstanceName: getInstanceName(baseURL),
|
||||
dimensions,
|
||||
batchSize
|
||||
})
|
||||
}
|
||||
return new OpenAiEmbeddings({
|
||||
model,
|
||||
apiKey,
|
||||
dimensions,
|
||||
batchSize,
|
||||
configuration: { baseURL }
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -1,3 +1,8 @@
|
||||
// don't reorder this file, it's used to initialize the app data dir and
|
||||
// other which should be run before the main process is ready
|
||||
// eslint-disable-next-line
|
||||
import './bootstrap'
|
||||
|
||||
import '@main/config'
|
||||
|
||||
import { electronApp, optimizer } from '@electron-toolkit/utils'
|
||||
@ -7,7 +12,7 @@ import { app } from 'electron'
|
||||
import installExtension, { REACT_DEVELOPER_TOOLS, REDUX_DEVTOOLS } from 'electron-devtools-installer'
|
||||
import Logger from 'electron-log'
|
||||
|
||||
import { isDev } from './constant'
|
||||
import { isDev, isWin } from './constant'
|
||||
import { registerIpc } from './ipc'
|
||||
import { configManager } from './services/ConfigManager'
|
||||
import mcpService from './services/MCPService'
|
||||
@ -21,9 +26,39 @@ import selectionService, { initSelectionService } from './services/SelectionServ
|
||||
import { registerShortcuts } from './services/ShortcutService'
|
||||
import { TrayService } from './services/TrayService'
|
||||
import { windowService } from './services/WindowService'
|
||||
import { setUserDataDir } from './utils/file'
|
||||
|
||||
Logger.initialize()
|
||||
|
||||
/**
|
||||
* Disable chromium's window animations
|
||||
* main purpose for this is to avoid the transparent window flashing when it is shown
|
||||
* (especially on Windows for SelectionAssistant Toolbar)
|
||||
* Know Issue: https://github.com/electron/electron/issues/12130#issuecomment-627198990
|
||||
*/
|
||||
if (isWin) {
|
||||
app.commandLine.appendSwitch('wm-window-animations-disabled')
|
||||
}
|
||||
|
||||
// Enable features for unresponsive renderer js call stacks
|
||||
app.commandLine.appendSwitch('enable-features', 'DocumentPolicyIncludeJSCallStacksInCrashReports')
|
||||
app.on('web-contents-created', (_, webContents) => {
|
||||
webContents.session.webRequest.onHeadersReceived((details, callback) => {
|
||||
callback({
|
||||
responseHeaders: {
|
||||
...details.responseHeaders,
|
||||
'Document-Policy': ['include-js-call-stacks-in-crash-reports']
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
webContents.on('unresponsive', async () => {
|
||||
// Interrupt execution and collect call stack from unresponsive renderer
|
||||
Logger.error('Renderer unresponsive start')
|
||||
const callStack = await webContents.mainFrame.collectJavaScriptCallStack()
|
||||
Logger.error('Renderer unresponsive js call stack\n', callStack)
|
||||
})
|
||||
})
|
||||
|
||||
// in production mode, handle uncaught exception and unhandled rejection globally
|
||||
if (!isDev) {
|
||||
// handle uncaught exception
|
||||
@ -42,9 +77,6 @@ if (!app.requestSingleInstanceLock()) {
|
||||
app.quit()
|
||||
process.exit(0)
|
||||
} else {
|
||||
// Portable dir must be setup before app ready
|
||||
setUserDataDir()
|
||||
|
||||
dbService.migrateDb().then(async () => {
|
||||
await dbService.migrateSeed('preference')
|
||||
})
|
||||
@ -96,19 +128,27 @@ if (!app.requestSingleInstanceLock()) {
|
||||
registerProtocolClient(app)
|
||||
|
||||
// macOS specific: handle protocol when app is already running
|
||||
|
||||
app.on('open-url', (event, url) => {
|
||||
event.preventDefault()
|
||||
handleProtocolUrl(url)
|
||||
})
|
||||
|
||||
const handleOpenUrl = (args: string[]) => {
|
||||
const url = args.find((arg) => arg.startsWith(CHERRY_STUDIO_PROTOCOL + '://'))
|
||||
if (url) handleProtocolUrl(url)
|
||||
}
|
||||
|
||||
// for windows to start with url
|
||||
handleOpenUrl(process.argv)
|
||||
|
||||
// Listen for second instance
|
||||
app.on('second-instance', (_event, argv) => {
|
||||
windowService.showMainWindow()
|
||||
|
||||
// Protocol handler for Windows/Linux
|
||||
// The commandLine is an array of strings where the last item might be the URL
|
||||
const url = argv.find((arg) => arg.startsWith(CHERRY_STUDIO_PROTOCOL + '://'))
|
||||
if (url) handleProtocolUrl(url)
|
||||
handleOpenUrl(argv)
|
||||
})
|
||||
|
||||
app.on('browser-window-created', (_, window) => {
|
||||
|
||||
217
src/main/ipc.ts
217
src/main/ipc.ts
@ -1,16 +1,17 @@
|
||||
import fs from 'node:fs'
|
||||
import { arch } from 'node:os'
|
||||
import path from 'node:path'
|
||||
|
||||
import { isMac, isWin } from '@main/constant'
|
||||
import { getBinaryPath, isBinaryExists, runInstallScript } from '@main/utils/process'
|
||||
import { handleZoomFactor } from '@main/utils/zoom'
|
||||
import { UpgradeChannel } from '@shared/config/constant'
|
||||
import { IpcChannel } from '@shared/IpcChannel'
|
||||
import { Shortcut, ThemeMode } from '@types'
|
||||
import { BrowserWindow, ipcMain, nativeTheme, session, shell } from 'electron'
|
||||
import { BrowserWindow, dialog, ipcMain, session, shell, webContents } from 'electron'
|
||||
import log from 'electron-log'
|
||||
import { Notification } from 'src/renderer/src/types/notification'
|
||||
|
||||
import { titleBarOverlayDark, titleBarOverlayLight } from './config'
|
||||
import AppUpdater from './services/AppUpdater'
|
||||
import BackupManager from './services/BackupManager'
|
||||
import { configManager } from './services/ConfigManager'
|
||||
@ -18,34 +19,39 @@ import CopilotService from './services/CopilotService'
|
||||
import { ExportService } from './services/ExportService'
|
||||
import FileService from './services/FileService'
|
||||
import FileStorage from './services/FileStorage'
|
||||
import { GeminiService } from './services/GeminiService'
|
||||
import KnowledgeService from './services/KnowledgeService'
|
||||
import mcpService from './services/MCPService'
|
||||
import NotificationService from './services/NotificationService'
|
||||
import * as NutstoreService from './services/NutstoreService'
|
||||
import ObsidianVaultService from './services/ObsidianVaultService'
|
||||
import { ProxyConfig, proxyManager } from './services/ProxyManager'
|
||||
import { pythonService } from './services/PythonService'
|
||||
import { searchService } from './services/SearchService'
|
||||
import { SelectionService } from './services/SelectionService'
|
||||
import { registerShortcuts, unregisterAllShortcuts } from './services/ShortcutService'
|
||||
import storeSyncService from './services/StoreSyncService'
|
||||
import { TrayService } from './services/TrayService'
|
||||
import { themeService } from './services/ThemeService'
|
||||
import VertexAIService from './services/VertexAIService'
|
||||
import { setOpenLinkExternal } from './services/WebviewService'
|
||||
import { windowService } from './services/WindowService'
|
||||
import { calculateDirectorySize, getResourcePath } from './utils'
|
||||
import { decrypt, encrypt } from './utils/aes'
|
||||
import { getCacheDir, getConfigDir, getFilesDir } from './utils/file'
|
||||
import { getCacheDir, getConfigDir, getFilesDir, hasWritePermission, updateAppDataConfig } from './utils/file'
|
||||
import { compress, decompress } from './utils/zip'
|
||||
|
||||
const fileManager = new FileStorage()
|
||||
const backupManager = new BackupManager()
|
||||
const exportService = new ExportService(fileManager)
|
||||
const obsidianVaultService = new ObsidianVaultService()
|
||||
const vertexAIService = VertexAIService.getInstance()
|
||||
|
||||
export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
|
||||
const appUpdater = new AppUpdater(mainWindow)
|
||||
const notificationService = new NotificationService(mainWindow)
|
||||
|
||||
// Initialize Python service with main window
|
||||
pythonService.setMainWindow(mainWindow)
|
||||
|
||||
ipcMain.handle(IpcChannel.App_Info, () => ({
|
||||
version: app.getVersion(),
|
||||
isPackaged: app.isPackaged,
|
||||
@ -56,7 +62,8 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
|
||||
resourcesPath: getResourcePath(),
|
||||
logsPath: log.transports.file.getFile().path,
|
||||
arch: arch(),
|
||||
isPortable: isWin && 'PORTABLE_EXECUTABLE_DIR' in process.env
|
||||
isPortable: isWin && 'PORTABLE_EXECUTABLE_DIR' in process.env,
|
||||
installPath: path.dirname(app.getPath('exe'))
|
||||
}))
|
||||
|
||||
ipcMain.handle(IpcChannel.App_Proxy, async (_, proxy: string) => {
|
||||
@ -84,6 +91,27 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
|
||||
configManager.setLanguage(language)
|
||||
})
|
||||
|
||||
// spell check
|
||||
ipcMain.handle(IpcChannel.App_SetEnableSpellCheck, (_, isEnable: boolean) => {
|
||||
// disable spell check for all webviews
|
||||
const webviews = webContents.getAllWebContents()
|
||||
webviews.forEach((webview) => {
|
||||
webview.session.setSpellCheckerEnabled(isEnable)
|
||||
})
|
||||
})
|
||||
|
||||
// spell check languages
|
||||
ipcMain.handle(IpcChannel.App_SetSpellCheckLanguages, (_, languages: string[]) => {
|
||||
if (languages.length === 0) {
|
||||
return
|
||||
}
|
||||
const windows = BrowserWindow.getAllWindows()
|
||||
windows.forEach((window) => {
|
||||
window.webContents.session.setSpellCheckerLanguages(languages)
|
||||
})
|
||||
configManager.set('spellCheckLanguages', languages)
|
||||
})
|
||||
|
||||
// launch on boot
|
||||
ipcMain.handle(IpcChannel.App_SetLaunchOnBoot, (_, openAtLogin: boolean) => {
|
||||
// Set login item settings for windows and mac
|
||||
@ -114,10 +142,24 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
|
||||
configManager.setAutoUpdate(isActive)
|
||||
})
|
||||
|
||||
ipcMain.handle(IpcChannel.App_RestartTray, () => TrayService.getInstance().restartTray())
|
||||
ipcMain.handle(IpcChannel.App_SetTestPlan, async (_, isActive: boolean) => {
|
||||
log.info('set test plan', isActive)
|
||||
if (isActive !== configManager.getTestPlan()) {
|
||||
appUpdater.cancelDownload()
|
||||
configManager.setTestPlan(isActive)
|
||||
}
|
||||
})
|
||||
|
||||
ipcMain.handle(IpcChannel.Config_Set, (_, key: string, value: any) => {
|
||||
configManager.set(key, value)
|
||||
ipcMain.handle(IpcChannel.App_SetTestChannel, async (_, channel: UpgradeChannel) => {
|
||||
log.info('set test channel', channel)
|
||||
if (channel !== configManager.getTestChannel()) {
|
||||
appUpdater.cancelDownload()
|
||||
configManager.setTestChannel(channel)
|
||||
}
|
||||
})
|
||||
|
||||
ipcMain.handle(IpcChannel.Config_Set, (_, key: string, value: any, isNotify: boolean = false) => {
|
||||
configManager.set(key, value, isNotify)
|
||||
})
|
||||
|
||||
ipcMain.handle(IpcChannel.Config_Get, (_, key: string) => {
|
||||
@ -126,34 +168,7 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
|
||||
|
||||
// theme
|
||||
ipcMain.handle(IpcChannel.App_SetTheme, (_, theme: ThemeMode) => {
|
||||
const updateTitleBarOverlay = () => {
|
||||
if (!mainWindow?.setTitleBarOverlay) return
|
||||
const isDark = nativeTheme.shouldUseDarkColors
|
||||
mainWindow.setTitleBarOverlay(isDark ? titleBarOverlayDark : titleBarOverlayLight)
|
||||
}
|
||||
|
||||
const broadcastThemeChange = () => {
|
||||
const isDark = nativeTheme.shouldUseDarkColors
|
||||
const effectiveTheme = isDark ? ThemeMode.dark : ThemeMode.light
|
||||
BrowserWindow.getAllWindows().forEach((win) => win.webContents.send(IpcChannel.ThemeChange, effectiveTheme))
|
||||
}
|
||||
|
||||
const notifyThemeChange = () => {
|
||||
updateTitleBarOverlay()
|
||||
broadcastThemeChange()
|
||||
}
|
||||
|
||||
if (theme === ThemeMode.auto) {
|
||||
nativeTheme.themeSource = 'system'
|
||||
nativeTheme.on('updated', notifyThemeChange)
|
||||
} else {
|
||||
nativeTheme.themeSource = theme
|
||||
nativeTheme.off('updated', notifyThemeChange)
|
||||
}
|
||||
|
||||
updateTitleBarOverlay()
|
||||
configManager.setTheme(theme)
|
||||
notifyThemeChange()
|
||||
themeService.setTheme(theme)
|
||||
})
|
||||
|
||||
ipcMain.handle(IpcChannel.App_HandleZoomFactor, (_, delta: number, reset: boolean = false) => {
|
||||
@ -199,6 +214,102 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
|
||||
}
|
||||
})
|
||||
|
||||
let preventQuitListener: ((event: Electron.Event) => void) | null = null
|
||||
ipcMain.handle(IpcChannel.App_SetStopQuitApp, (_, stop: boolean = false, reason: string = '') => {
|
||||
if (stop) {
|
||||
// Only add listener if not already added
|
||||
if (!preventQuitListener) {
|
||||
preventQuitListener = (event: Electron.Event) => {
|
||||
event.preventDefault()
|
||||
notificationService.sendNotification({
|
||||
title: reason,
|
||||
message: reason
|
||||
} as Notification)
|
||||
}
|
||||
app.on('before-quit', preventQuitListener)
|
||||
}
|
||||
} else {
|
||||
// Remove listener if it exists
|
||||
if (preventQuitListener) {
|
||||
app.removeListener('before-quit', preventQuitListener)
|
||||
preventQuitListener = null
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Select app data path
|
||||
ipcMain.handle(IpcChannel.App_Select, async (_, options: Electron.OpenDialogOptions) => {
|
||||
try {
|
||||
const { canceled, filePaths } = await dialog.showOpenDialog(options)
|
||||
if (canceled || filePaths.length === 0) {
|
||||
return null
|
||||
}
|
||||
return filePaths[0]
|
||||
} catch (error: any) {
|
||||
log.error('Failed to select app data path:', error)
|
||||
return null
|
||||
}
|
||||
})
|
||||
|
||||
ipcMain.handle(IpcChannel.App_HasWritePermission, async (_, filePath: string) => {
|
||||
return hasWritePermission(filePath)
|
||||
})
|
||||
|
||||
// Set app data path
|
||||
ipcMain.handle(IpcChannel.App_SetAppDataPath, async (_, filePath: string) => {
|
||||
updateAppDataConfig(filePath)
|
||||
app.setPath('userData', filePath)
|
||||
})
|
||||
|
||||
ipcMain.handle(IpcChannel.App_GetDataPathFromArgs, () => {
|
||||
return process.argv
|
||||
.slice(1)
|
||||
.find((arg) => arg.startsWith('--new-data-path='))
|
||||
?.split('--new-data-path=')[1]
|
||||
})
|
||||
|
||||
ipcMain.handle(IpcChannel.App_FlushAppData, () => {
|
||||
BrowserWindow.getAllWindows().forEach((w) => {
|
||||
w.webContents.session.flushStorageData()
|
||||
w.webContents.session.cookies.flushStore()
|
||||
|
||||
w.webContents.session.closeAllConnections()
|
||||
})
|
||||
|
||||
session.defaultSession.flushStorageData()
|
||||
session.defaultSession.cookies.flushStore()
|
||||
session.defaultSession.closeAllConnections()
|
||||
})
|
||||
|
||||
ipcMain.handle(IpcChannel.App_IsNotEmptyDir, async (_, path: string) => {
|
||||
return fs.readdirSync(path).length > 0
|
||||
})
|
||||
|
||||
// Copy user data to new location
|
||||
ipcMain.handle(IpcChannel.App_Copy, async (_, oldPath: string, newPath: string, occupiedDirs: string[] = []) => {
|
||||
try {
|
||||
await fs.promises.cp(oldPath, newPath, {
|
||||
recursive: true,
|
||||
filter: (src) => {
|
||||
if (occupiedDirs.some((dir) => src.startsWith(path.resolve(dir)))) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
})
|
||||
return { success: true }
|
||||
} catch (error: any) {
|
||||
log.error('Failed to copy user data:', error)
|
||||
return { success: false, error: error.message }
|
||||
}
|
||||
})
|
||||
|
||||
// Relaunch app
|
||||
ipcMain.handle(IpcChannel.App_RelaunchApp, (_, options?: Electron.RelaunchOptions) => {
|
||||
app.relaunch(options)
|
||||
app.exit(0)
|
||||
})
|
||||
|
||||
// check for update
|
||||
ipcMain.handle(IpcChannel.App_CheckForUpdate, async () => {
|
||||
return await appUpdater.checkForUpdates()
|
||||
@ -250,7 +361,9 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
|
||||
ipcMain.handle(IpcChannel.File_WriteWithId, fileManager.writeFileWithId)
|
||||
ipcMain.handle(IpcChannel.File_SaveImage, fileManager.saveImage)
|
||||
ipcMain.handle(IpcChannel.File_Base64Image, fileManager.base64Image)
|
||||
ipcMain.handle(IpcChannel.File_SaveBase64Image, fileManager.saveBase64Image)
|
||||
ipcMain.handle(IpcChannel.File_Base64File, fileManager.base64File)
|
||||
ipcMain.handle(IpcChannel.File_GetPdfInfo, fileManager.pdfPageCount)
|
||||
ipcMain.handle(IpcChannel.File_Download, fileManager.downloadFile)
|
||||
ipcMain.handle(IpcChannel.File_Copy, fileManager.copyFile)
|
||||
ipcMain.handle(IpcChannel.File_BinaryImage, fileManager.binaryImage)
|
||||
@ -298,12 +411,14 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
|
||||
}
|
||||
})
|
||||
|
||||
// gemini
|
||||
ipcMain.handle(IpcChannel.Gemini_UploadFile, GeminiService.uploadFile)
|
||||
ipcMain.handle(IpcChannel.Gemini_Base64File, GeminiService.base64File)
|
||||
ipcMain.handle(IpcChannel.Gemini_RetrieveFile, GeminiService.retrieveFile)
|
||||
ipcMain.handle(IpcChannel.Gemini_ListFiles, GeminiService.listFiles)
|
||||
ipcMain.handle(IpcChannel.Gemini_DeleteFile, GeminiService.deleteFile)
|
||||
// VertexAI
|
||||
ipcMain.handle(IpcChannel.VertexAI_GetAuthHeaders, async (_, params) => {
|
||||
return vertexAIService.getAuthHeaders(params)
|
||||
})
|
||||
|
||||
ipcMain.handle(IpcChannel.VertexAI_ClearAuthCache, async (_, projectId: string, clientEmail?: string) => {
|
||||
vertexAIService.clearAuthCache(projectId, clientEmail)
|
||||
})
|
||||
|
||||
// mini window
|
||||
ipcMain.handle(IpcChannel.MiniWindow_Show, () => windowService.showMiniWindow())
|
||||
@ -333,6 +448,14 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
|
||||
ipcMain.handle(IpcChannel.Mcp_GetInstallInfo, mcpService.getInstallInfo)
|
||||
ipcMain.handle(IpcChannel.Mcp_CheckConnectivity, mcpService.checkMcpConnectivity)
|
||||
|
||||
// Register Python execution handler
|
||||
ipcMain.handle(
|
||||
IpcChannel.Python_Execute,
|
||||
async (_, script: string, context?: Record<string, any>, timeout?: number) => {
|
||||
return await pythonService.executeScript(script, context, timeout)
|
||||
}
|
||||
)
|
||||
|
||||
ipcMain.handle(IpcChannel.App_IsBinaryExist, (_, name: string) => isBinaryExists(name))
|
||||
ipcMain.handle(IpcChannel.App_GetBinaryPath, (_, name: string) => getBinaryPath(name))
|
||||
ipcMain.handle(IpcChannel.App_InstallUvBinary, () => runInstallScript('install-uv.js'))
|
||||
@ -378,9 +501,17 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
|
||||
setOpenLinkExternal(webviewId, isExternal)
|
||||
)
|
||||
|
||||
ipcMain.handle(IpcChannel.Webview_SetSpellCheckEnabled, (_, webviewId: number, isEnable: boolean) => {
|
||||
const webview = webContents.fromId(webviewId)
|
||||
if (!webview) return
|
||||
webview.session.setSpellCheckerEnabled(isEnable)
|
||||
})
|
||||
|
||||
// store sync
|
||||
storeSyncService.registerIpcHandler()
|
||||
|
||||
// selection assistant
|
||||
SelectionService.registerIpcHandler()
|
||||
|
||||
ipcMain.handle(IpcChannel.App_QuoteToMain, (_, text: string) => windowService.quoteToMainWindow(text))
|
||||
}
|
||||
|
||||
@ -5,8 +5,15 @@ import EmbeddingsFactory from './EmbeddingsFactory'
|
||||
|
||||
export default class Embeddings {
|
||||
private sdk: BaseEmbeddings
|
||||
constructor({ model, apiKey, apiVersion, baseURL, dimensions }: KnowledgeBaseParams) {
|
||||
this.sdk = EmbeddingsFactory.create({ model, apiKey, apiVersion, baseURL, dimensions } as KnowledgeBaseParams)
|
||||
constructor({ model, provider, apiKey, apiVersion, baseURL, dimensions }: KnowledgeBaseParams) {
|
||||
this.sdk = EmbeddingsFactory.create({
|
||||
model,
|
||||
provider,
|
||||
apiKey,
|
||||
apiVersion,
|
||||
baseURL,
|
||||
dimensions
|
||||
} as KnowledgeBaseParams)
|
||||
}
|
||||
public async init(): Promise<void> {
|
||||
return this.sdk.init()
|
||||
67
src/main/knowledage/embeddings/EmbeddingsFactory.ts
Normal file
67
src/main/knowledage/embeddings/EmbeddingsFactory.ts
Normal file
@ -0,0 +1,67 @@
|
||||
import type { BaseEmbeddings } from '@cherrystudio/embedjs-interfaces'
|
||||
import { OllamaEmbeddings } from '@cherrystudio/embedjs-ollama'
|
||||
import { OpenAiEmbeddings } from '@cherrystudio/embedjs-openai'
|
||||
import { AzureOpenAiEmbeddings } from '@cherrystudio/embedjs-openai/src/azure-openai-embeddings'
|
||||
import { getInstanceName } from '@main/utils'
|
||||
import { KnowledgeBaseParams } from '@types'
|
||||
|
||||
import { SUPPORTED_DIM_MODELS as VOYAGE_SUPPORTED_DIM_MODELS, VoyageEmbeddings } from './VoyageEmbeddings'
|
||||
|
||||
export default class EmbeddingsFactory {
|
||||
static create({ model, provider, apiKey, apiVersion, baseURL, dimensions }: KnowledgeBaseParams): BaseEmbeddings {
|
||||
const batchSize = 10
|
||||
if (provider === 'voyageai') {
|
||||
if (VOYAGE_SUPPORTED_DIM_MODELS.includes(model)) {
|
||||
return new VoyageEmbeddings({
|
||||
modelName: model,
|
||||
apiKey,
|
||||
outputDimension: dimensions,
|
||||
batchSize: 8
|
||||
})
|
||||
} else {
|
||||
return new VoyageEmbeddings({
|
||||
modelName: model,
|
||||
apiKey,
|
||||
batchSize: 8
|
||||
})
|
||||
}
|
||||
}
|
||||
if (provider === 'ollama') {
|
||||
if (baseURL.includes('v1/')) {
|
||||
return new OllamaEmbeddings({
|
||||
model: model,
|
||||
baseUrl: baseURL.replace('v1/', ''),
|
||||
requestOptions: {
|
||||
// @ts-ignore expected
|
||||
'encoding-format': 'float'
|
||||
}
|
||||
})
|
||||
}
|
||||
return new OllamaEmbeddings({
|
||||
model: model,
|
||||
baseUrl: baseURL,
|
||||
requestOptions: {
|
||||
// @ts-ignore expected
|
||||
'encoding-format': 'float'
|
||||
}
|
||||
})
|
||||
}
|
||||
if (apiVersion !== undefined) {
|
||||
return new AzureOpenAiEmbeddings({
|
||||
azureOpenAIApiKey: apiKey,
|
||||
azureOpenAIApiVersion: apiVersion,
|
||||
azureOpenAIApiDeploymentName: model,
|
||||
azureOpenAIApiInstanceName: getInstanceName(baseURL),
|
||||
dimensions,
|
||||
batchSize
|
||||
})
|
||||
}
|
||||
return new OpenAiEmbeddings({
|
||||
model,
|
||||
apiKey,
|
||||
dimensions,
|
||||
batchSize,
|
||||
configuration: { baseURL }
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -1,16 +1,20 @@
|
||||
import { BaseEmbeddings } from '@cherrystudio/embedjs-interfaces'
|
||||
import { VoyageEmbeddings as _VoyageEmbeddings } from '@langchain/community/embeddings/voyage'
|
||||
|
||||
export default class VoyageEmbeddings extends BaseEmbeddings {
|
||||
/**
|
||||
* 支持设置嵌入维度的模型
|
||||
*/
|
||||
export const SUPPORTED_DIM_MODELS = ['voyage-3-large', 'voyage-3.5', 'voyage-3.5-lite', 'voyage-code-3']
|
||||
export class VoyageEmbeddings extends BaseEmbeddings {
|
||||
private model: _VoyageEmbeddings
|
||||
constructor(private readonly configuration?: ConstructorParameters<typeof _VoyageEmbeddings>[0]) {
|
||||
super()
|
||||
if (!this.configuration) this.configuration = {}
|
||||
if (!this.configuration.modelName) this.configuration.modelName = 'voyage-3'
|
||||
|
||||
if (!this.configuration.outputDimension) {
|
||||
throw new Error('You need to pass in the optional dimensions parameter for this model')
|
||||
if (!SUPPORTED_DIM_MODELS.includes(this.configuration.modelName) && this.configuration.outputDimension) {
|
||||
throw new Error(`VoyageEmbeddings only supports ${SUPPORTED_DIM_MODELS.join(', ')}`)
|
||||
}
|
||||
|
||||
this.model = new _VoyageEmbeddings(this.configuration)
|
||||
}
|
||||
override async getDimensions(): Promise<number> {
|
||||
@ -16,6 +16,7 @@ const FILE_LOADER_MAP: Record<string, string> = {
|
||||
// 内置类型
|
||||
'.pdf': 'common',
|
||||
'.csv': 'common',
|
||||
'.doc': 'common',
|
||||
'.docx': 'common',
|
||||
'.pptx': 'common',
|
||||
'.xlsx': 'common',
|
||||
44
src/main/knowledage/loader/noteLoader.ts
Normal file
44
src/main/knowledage/loader/noteLoader.ts
Normal file
@ -0,0 +1,44 @@
|
||||
import { BaseLoader } from '@cherrystudio/embedjs-interfaces'
|
||||
import { cleanString } from '@cherrystudio/embedjs-utils'
|
||||
import { RecursiveCharacterTextSplitter } from '@langchain/textsplitters'
|
||||
import md5 from 'md5'
|
||||
|
||||
export class NoteLoader extends BaseLoader<{ type: 'NoteLoader' }> {
|
||||
private readonly text: string
|
||||
private readonly sourceUrl?: string
|
||||
|
||||
constructor({
|
||||
text,
|
||||
sourceUrl,
|
||||
chunkSize,
|
||||
chunkOverlap
|
||||
}: {
|
||||
text: string
|
||||
sourceUrl?: string
|
||||
chunkSize?: number
|
||||
chunkOverlap?: number
|
||||
}) {
|
||||
super(`NoteLoader_${md5(text + (sourceUrl || ''))}`, { text, sourceUrl }, chunkSize ?? 2000, chunkOverlap ?? 0)
|
||||
this.text = text
|
||||
this.sourceUrl = sourceUrl
|
||||
}
|
||||
|
||||
override async *getUnfilteredChunks() {
|
||||
const chunker = new RecursiveCharacterTextSplitter({
|
||||
chunkSize: this.chunkSize,
|
||||
chunkOverlap: this.chunkOverlap
|
||||
})
|
||||
|
||||
const chunks = await chunker.splitText(cleanString(this.text))
|
||||
|
||||
for (const chunk of chunks) {
|
||||
yield {
|
||||
pageContent: chunk,
|
||||
metadata: {
|
||||
type: 'NoteLoader' as const,
|
||||
source: this.sourceUrl || 'note'
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -17,14 +17,17 @@ export default abstract class BaseReranker {
|
||||
* Get Rerank Request Url
|
||||
*/
|
||||
protected getRerankUrl() {
|
||||
if (this.base.rerankModelProvider === 'dashscope') {
|
||||
if (this.base.rerankModelProvider === 'bailian') {
|
||||
return 'https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank'
|
||||
}
|
||||
|
||||
let baseURL = this.base?.rerankBaseURL?.endsWith('/')
|
||||
? this.base.rerankBaseURL.slice(0, -1)
|
||||
: this.base.rerankBaseURL
|
||||
// 必须携带/v1,否则会404
|
||||
let baseURL = this.base.rerankBaseURL
|
||||
|
||||
if (baseURL && baseURL.endsWith('/')) {
|
||||
// `/` 结尾强制使用rerankBaseURL
|
||||
return `${baseURL}rerank`
|
||||
}
|
||||
|
||||
if (baseURL && !baseURL.endsWith('/v1')) {
|
||||
baseURL = `${baseURL}/v1`
|
||||
}
|
||||
@ -47,7 +50,7 @@ export default abstract class BaseReranker {
|
||||
documents,
|
||||
top_k: topN
|
||||
}
|
||||
} else if (provider === 'dashscope') {
|
||||
} else if (provider === 'bailian') {
|
||||
return {
|
||||
model: this.base.rerankModel,
|
||||
input: {
|
||||
@ -58,6 +61,12 @@ export default abstract class BaseReranker {
|
||||
top_n: topN
|
||||
}
|
||||
}
|
||||
} else if (provider?.includes('tei')) {
|
||||
return {
|
||||
query,
|
||||
texts: documents,
|
||||
return_text: true
|
||||
}
|
||||
} else {
|
||||
return {
|
||||
model: this.base.rerankModel,
|
||||
@ -73,10 +82,17 @@ export default abstract class BaseReranker {
|
||||
*/
|
||||
protected extractRerankResult(data: any) {
|
||||
const provider = this.base.rerankModelProvider
|
||||
if (provider === 'dashscope') {
|
||||
if (provider === 'bailian') {
|
||||
return data.output.results
|
||||
} else if (provider === 'voyageai') {
|
||||
return data.data
|
||||
} else if (provider?.includes('tei')) {
|
||||
return data.map((item: any) => {
|
||||
return {
|
||||
index: item.index,
|
||||
relevance_score: item.score
|
||||
}
|
||||
})
|
||||
} else {
|
||||
return data.results
|
||||
}
|
||||
@ -6,6 +6,7 @@ import DifyKnowledgeServer from './dify-knowledge'
|
||||
import FetchServer from './fetch'
|
||||
import FileSystemServer from './filesystem'
|
||||
import MemoryServer from './memory'
|
||||
import PythonServer from './python'
|
||||
import ThinkingServer from './sequentialthinking'
|
||||
|
||||
export function createInMemoryMCPServer(name: string, args: string[] = [], envs: Record<string, string> = {}): Server {
|
||||
@ -31,6 +32,9 @@ export function createInMemoryMCPServer(name: string, args: string[] = [], envs:
|
||||
const difyKey = envs.DIFY_KEY
|
||||
return new DifyKnowledgeServer(difyKey, args).server
|
||||
}
|
||||
case '@cherry/python': {
|
||||
return new PythonServer().server
|
||||
}
|
||||
default:
|
||||
throw new Error(`Unknown in-memory MCP server: ${name}`)
|
||||
}
|
||||
|
||||
113
src/main/mcpServers/python.ts
Normal file
113
src/main/mcpServers/python.ts
Normal file
@ -0,0 +1,113 @@
|
||||
import { pythonService } from '@main/services/PythonService'
|
||||
import { Server } from '@modelcontextprotocol/sdk/server/index.js'
|
||||
import { CallToolRequestSchema, ErrorCode, ListToolsRequestSchema, McpError } from '@modelcontextprotocol/sdk/types.js'
|
||||
import Logger from 'electron-log'
|
||||
|
||||
/**
|
||||
* Python MCP Server for executing Python code using Pyodide
|
||||
*/
|
||||
class PythonServer {
|
||||
public server: Server
|
||||
|
||||
constructor() {
|
||||
this.server = new Server(
|
||||
{
|
||||
name: 'python-server',
|
||||
version: '1.0.0'
|
||||
},
|
||||
{
|
||||
capabilities: {
|
||||
tools: {}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
this.setupRequestHandlers()
|
||||
}
|
||||
|
||||
private setupRequestHandlers() {
|
||||
// List available tools
|
||||
this.server.setRequestHandler(ListToolsRequestSchema, async () => {
|
||||
return {
|
||||
tools: [
|
||||
{
|
||||
name: 'python_execute',
|
||||
description: `Execute Python code using Pyodide in a sandboxed environment. Supports most Python standard library and scientific packages.
|
||||
The code will be executed with Python 3.12.
|
||||
Dependencies may be defined via PEP 723 script metadata, e.g. to install "pydantic", the script should start
|
||||
with a comment of the form:
|
||||
# /// script
|
||||
# dependencies = ['pydantic']
|
||||
# ///
|
||||
print('python code here')`,
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
code: {
|
||||
type: 'string',
|
||||
description: 'The Python code to execute'
|
||||
},
|
||||
context: {
|
||||
type: 'object',
|
||||
description: 'Optional context variables to pass to the Python execution environment',
|
||||
additionalProperties: true
|
||||
},
|
||||
timeout: {
|
||||
type: 'number',
|
||||
description: 'Timeout in milliseconds (default: 60000)',
|
||||
default: 60000
|
||||
}
|
||||
},
|
||||
required: ['code']
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
})
|
||||
|
||||
// Handle tool calls
|
||||
this.server.setRequestHandler(CallToolRequestSchema, async (request) => {
|
||||
const { name, arguments: args } = request.params
|
||||
|
||||
if (name !== 'python_execute') {
|
||||
throw new McpError(ErrorCode.MethodNotFound, `Tool ${name} not found`)
|
||||
}
|
||||
|
||||
try {
|
||||
const {
|
||||
code,
|
||||
context = {},
|
||||
timeout = 60000
|
||||
} = args as {
|
||||
code: string
|
||||
context?: Record<string, any>
|
||||
timeout?: number
|
||||
}
|
||||
|
||||
if (!code || typeof code !== 'string') {
|
||||
throw new McpError(ErrorCode.InvalidParams, 'Code parameter is required and must be a string')
|
||||
}
|
||||
|
||||
Logger.info('Executing Python code via Pyodide')
|
||||
|
||||
const result = await pythonService.executeScript(code, context, timeout)
|
||||
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: result
|
||||
}
|
||||
]
|
||||
}
|
||||
} catch (error) {
|
||||
const errorMessage = error instanceof Error ? error.message : String(error)
|
||||
Logger.error('Python execution error:', errorMessage)
|
||||
|
||||
throw new McpError(ErrorCode.InternalError, `Python execution failed: ${errorMessage}`)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
export default PythonServer
|
||||
@ -106,6 +106,7 @@ class SequentialThinkingServer {
|
||||
type: 'text',
|
||||
text: JSON.stringify(
|
||||
{
|
||||
thought: validatedInput.thought,
|
||||
thoughtNumber: validatedInput.thoughtNumber,
|
||||
totalThoughts: validatedInput.totalThoughts,
|
||||
nextThoughtNeeded: validatedInput.nextThoughtNeeded,
|
||||
|
||||
@ -1,9 +1,13 @@
|
||||
import { isWin } from '@main/constant'
|
||||
import { locales } from '@main/utils/locales'
|
||||
import { generateUserAgent } from '@main/utils/systemInfo'
|
||||
import { FeedUrl, UpgradeChannel } from '@shared/config/constant'
|
||||
import { IpcChannel } from '@shared/IpcChannel'
|
||||
import { UpdateInfo } from 'builder-util-runtime'
|
||||
import { CancellationToken, UpdateInfo } from 'builder-util-runtime'
|
||||
import { app, BrowserWindow, dialog } from 'electron'
|
||||
import logger from 'electron-log'
|
||||
import { AppUpdater as _AppUpdater, autoUpdater } from 'electron-updater'
|
||||
import { AppUpdater as _AppUpdater, autoUpdater, NsisUpdater, UpdateCheckResult } from 'electron-updater'
|
||||
import path from 'path'
|
||||
|
||||
import icon from '../../../build/icon.png?asset'
|
||||
import { configManager } from './ConfigManager'
|
||||
@ -11,6 +15,8 @@ import { configManager } from './ConfigManager'
|
||||
export default class AppUpdater {
|
||||
autoUpdater: _AppUpdater = autoUpdater
|
||||
private releaseInfo: UpdateInfo | undefined
|
||||
private cancellationToken: CancellationToken = new CancellationToken()
|
||||
private updateCheckResult: UpdateCheckResult | null = null
|
||||
|
||||
constructor(mainWindow: BrowserWindow) {
|
||||
logger.transports.file.level = 'info'
|
||||
@ -19,8 +25,11 @@ export default class AppUpdater {
|
||||
autoUpdater.forceDevUpdateConfig = !app.isPackaged
|
||||
autoUpdater.autoDownload = configManager.getAutoUpdate()
|
||||
autoUpdater.autoInstallOnAppQuit = configManager.getAutoUpdate()
|
||||
autoUpdater.requestHeaders = {
|
||||
...autoUpdater.requestHeaders,
|
||||
'User-Agent': generateUserAgent()
|
||||
}
|
||||
|
||||
// 检测下载错误
|
||||
autoUpdater.on('error', (error) => {
|
||||
// 简单记录错误信息和时间戳
|
||||
logger.error('更新异常', {
|
||||
@ -53,14 +62,139 @@ export default class AppUpdater {
|
||||
logger.info('下载完成', releaseInfo)
|
||||
})
|
||||
|
||||
if (isWin) {
|
||||
;(autoUpdater as NsisUpdater).installDirectory = path.dirname(app.getPath('exe'))
|
||||
}
|
||||
|
||||
this.autoUpdater = autoUpdater
|
||||
}
|
||||
|
||||
private async _getPreReleaseVersionFromGithub(channel: UpgradeChannel) {
|
||||
try {
|
||||
logger.info('get pre release version from github', channel)
|
||||
const responses = await fetch('https://api.github.com/repos/CherryHQ/cherry-studio/releases?per_page=8', {
|
||||
headers: {
|
||||
Accept: 'application/vnd.github+json',
|
||||
'X-GitHub-Api-Version': '2022-11-28',
|
||||
'Accept-Language': 'en-US,en;q=0.9'
|
||||
}
|
||||
})
|
||||
const data = (await responses.json()) as GithubReleaseInfo[]
|
||||
const release: GithubReleaseInfo | undefined = data.find((item: GithubReleaseInfo) => {
|
||||
return item.prerelease && item.tag_name.includes(`-${channel}.`)
|
||||
})
|
||||
|
||||
logger.info('release info', release)
|
||||
|
||||
if (!release) {
|
||||
return null
|
||||
}
|
||||
|
||||
logger.info('release info', release.tag_name)
|
||||
return `https://github.com/CherryHQ/cherry-studio/releases/download/${release.tag_name}`
|
||||
} catch (error) {
|
||||
logger.error('Failed to get latest not draft version from github:', error)
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
private async _getIpCountry() {
|
||||
try {
|
||||
// add timeout using AbortController
|
||||
const controller = new AbortController()
|
||||
const timeoutId = setTimeout(() => controller.abort(), 5000)
|
||||
|
||||
const ipinfo = await fetch('https://ipinfo.io/json', {
|
||||
signal: controller.signal,
|
||||
headers: {
|
||||
'User-Agent':
|
||||
'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36',
|
||||
'Accept-Language': 'en-US,en;q=0.9'
|
||||
}
|
||||
})
|
||||
|
||||
clearTimeout(timeoutId)
|
||||
const data = await ipinfo.json()
|
||||
return data.country || 'CN'
|
||||
} catch (error) {
|
||||
logger.error('Failed to get ipinfo:', error)
|
||||
return 'CN'
|
||||
}
|
||||
}
|
||||
|
||||
public setAutoUpdate(isActive: boolean) {
|
||||
autoUpdater.autoDownload = isActive
|
||||
autoUpdater.autoInstallOnAppQuit = isActive
|
||||
}
|
||||
|
||||
private _getChannelByVersion(version: string) {
|
||||
if (version.includes(`-${UpgradeChannel.BETA}.`)) {
|
||||
return UpgradeChannel.BETA
|
||||
}
|
||||
if (version.includes(`-${UpgradeChannel.RC}.`)) {
|
||||
return UpgradeChannel.RC
|
||||
}
|
||||
return UpgradeChannel.LATEST
|
||||
}
|
||||
|
||||
private _getTestChannel() {
|
||||
const currentChannel = this._getChannelByVersion(app.getVersion())
|
||||
const savedChannel = configManager.getTestChannel()
|
||||
|
||||
if (currentChannel === UpgradeChannel.LATEST) {
|
||||
return savedChannel || UpgradeChannel.RC
|
||||
}
|
||||
|
||||
if (savedChannel === currentChannel) {
|
||||
return savedChannel
|
||||
}
|
||||
|
||||
// if the upgrade channel is not equal to the current channel, use the latest channel
|
||||
return UpgradeChannel.LATEST
|
||||
}
|
||||
|
||||
private async _setFeedUrl() {
|
||||
const testPlan = configManager.getTestPlan()
|
||||
if (testPlan) {
|
||||
const channel = this._getTestChannel()
|
||||
|
||||
if (channel === UpgradeChannel.LATEST) {
|
||||
this.autoUpdater.channel = UpgradeChannel.LATEST
|
||||
this.autoUpdater.setFeedURL(FeedUrl.GITHUB_LATEST)
|
||||
return
|
||||
}
|
||||
|
||||
const preReleaseUrl = await this._getPreReleaseVersionFromGithub(channel)
|
||||
if (preReleaseUrl) {
|
||||
this.autoUpdater.setFeedURL(preReleaseUrl)
|
||||
this.autoUpdater.channel = channel
|
||||
return
|
||||
}
|
||||
|
||||
// if no prerelease url, use lowest prerelease version to avoid error
|
||||
this.autoUpdater.setFeedURL(FeedUrl.PRERELEASE_LOWEST)
|
||||
this.autoUpdater.channel = UpgradeChannel.LATEST
|
||||
return
|
||||
}
|
||||
|
||||
this.autoUpdater.channel = UpgradeChannel.LATEST
|
||||
this.autoUpdater.setFeedURL(FeedUrl.PRODUCTION)
|
||||
|
||||
const ipCountry = await this._getIpCountry()
|
||||
logger.info('ipCountry', ipCountry)
|
||||
if (ipCountry.toLowerCase() !== 'cn') {
|
||||
this.autoUpdater.setFeedURL(FeedUrl.GITHUB_LATEST)
|
||||
}
|
||||
}
|
||||
|
||||
public cancelDownload() {
|
||||
this.cancellationToken.cancel()
|
||||
this.cancellationToken = new CancellationToken()
|
||||
if (this.autoUpdater.autoDownload) {
|
||||
this.updateCheckResult?.cancellationToken?.cancel()
|
||||
}
|
||||
}
|
||||
|
||||
public async checkForUpdates() {
|
||||
if (isWin && 'PORTABLE_EXECUTABLE_DIR' in process.env) {
|
||||
return {
|
||||
@ -69,17 +203,26 @@ export default class AppUpdater {
|
||||
}
|
||||
}
|
||||
|
||||
await this._setFeedUrl()
|
||||
|
||||
// disable downgrade after change the channel
|
||||
this.autoUpdater.allowDowngrade = false
|
||||
|
||||
// github and gitcode don't support multiple range download
|
||||
this.autoUpdater.disableDifferentialDownload = true
|
||||
|
||||
try {
|
||||
const update = await this.autoUpdater.checkForUpdates()
|
||||
if (update?.isUpdateAvailable && !this.autoUpdater.autoDownload) {
|
||||
this.updateCheckResult = await this.autoUpdater.checkForUpdates()
|
||||
if (this.updateCheckResult?.isUpdateAvailable && !this.autoUpdater.autoDownload) {
|
||||
// 如果 autoDownload 为 false,则需要再调用下面的函数触发下
|
||||
// do not use await, because it will block the return of this function
|
||||
this.autoUpdater.downloadUpdate()
|
||||
logger.info('downloadUpdate manual by check for updates', this.cancellationToken)
|
||||
this.autoUpdater.downloadUpdate(this.cancellationToken)
|
||||
}
|
||||
|
||||
return {
|
||||
currentVersion: this.autoUpdater.currentVersion,
|
||||
updateInfo: update?.updateInfo
|
||||
updateInfo: this.updateCheckResult?.updateInfo
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('Failed to check for update:', error)
|
||||
@ -94,15 +237,22 @@ export default class AppUpdater {
|
||||
if (!this.releaseInfo) {
|
||||
return
|
||||
}
|
||||
const locale = locales[configManager.getLanguage()]
|
||||
const { update: updateLocale } = locale.translation
|
||||
|
||||
let detail = this.formatReleaseNotes(this.releaseInfo.releaseNotes)
|
||||
if (detail === '') {
|
||||
detail = updateLocale.noReleaseNotes
|
||||
}
|
||||
|
||||
dialog
|
||||
.showMessageBox({
|
||||
type: 'info',
|
||||
title: '安装更新',
|
||||
title: updateLocale.title,
|
||||
icon,
|
||||
message: `新版本 ${this.releaseInfo.version} 已准备就绪`,
|
||||
detail: this.formatReleaseNotes(this.releaseInfo.releaseNotes),
|
||||
buttons: ['稍后安装', '立即安装'],
|
||||
message: updateLocale.message.replace('{{version}}', this.releaseInfo.version),
|
||||
detail,
|
||||
buttons: [updateLocale.later, updateLocale.install],
|
||||
defaultId: 1,
|
||||
cancelId: 0
|
||||
})
|
||||
@ -118,7 +268,7 @@ export default class AppUpdater {
|
||||
|
||||
private formatReleaseNotes(releaseNotes: string | ReleaseNoteInfo[] | null | undefined): string {
|
||||
if (!releaseNotes) {
|
||||
return '暂无更新说明'
|
||||
return ''
|
||||
}
|
||||
|
||||
if (typeof releaseNotes === 'string') {
|
||||
@ -128,7 +278,11 @@ export default class AppUpdater {
|
||||
return releaseNotes.map((note) => note.note).join('\n')
|
||||
}
|
||||
}
|
||||
|
||||
interface GithubReleaseInfo {
|
||||
draft: boolean
|
||||
prerelease: boolean
|
||||
tag_name: string
|
||||
}
|
||||
interface ReleaseNoteInfo {
|
||||
readonly version: string
|
||||
readonly note: string | null
|
||||
|
||||
@ -7,8 +7,9 @@ import Logger from 'electron-log'
|
||||
import * as fs from 'fs-extra'
|
||||
import StreamZip from 'node-stream-zip'
|
||||
import * as path from 'path'
|
||||
import { createClient, CreateDirectoryOptions, FileStat } from 'webdav'
|
||||
import { CreateDirectoryOptions, FileStat } from 'webdav'
|
||||
|
||||
import { getDataPath } from '../utils'
|
||||
import WebDav from './WebDav'
|
||||
import { windowService } from './WindowService'
|
||||
|
||||
@ -253,7 +254,7 @@ class BackupManager {
|
||||
Logger.log('[backup] step 3: restore Data directory')
|
||||
// 恢复 Data 目录
|
||||
const sourcePath = path.join(this.tempDir, 'Data')
|
||||
const destPath = path.join(app.getPath('userData'), 'Data')
|
||||
const destPath = getDataPath()
|
||||
|
||||
const dataExists = await fs.pathExists(sourcePath)
|
||||
const dataFiles = dataExists ? await fs.readdir(sourcePath) : []
|
||||
@ -295,10 +296,12 @@ class BackupManager {
|
||||
async backupToWebdav(_: Electron.IpcMainInvokeEvent, data: string, webdavConfig: WebDavConfig) {
|
||||
const filename = webdavConfig.fileName || 'cherry-studio.backup.zip'
|
||||
const backupedFilePath = await this.backup(_, filename, data, undefined, webdavConfig.skipBackupFile)
|
||||
const contentLength = (await fs.stat(backupedFilePath)).size
|
||||
const webdavClient = new WebDav(webdavConfig)
|
||||
try {
|
||||
const result = await webdavClient.putFileContents(filename, fs.createReadStream(backupedFilePath), {
|
||||
overwrite: true
|
||||
overwrite: true,
|
||||
contentLength
|
||||
})
|
||||
// 上传成功后删除本地备份文件
|
||||
await fs.remove(backupedFilePath)
|
||||
@ -340,12 +343,8 @@ class BackupManager {
|
||||
|
||||
listWebdavFiles = async (_: Electron.IpcMainInvokeEvent, config: WebDavConfig) => {
|
||||
try {
|
||||
const client = createClient(config.webdavHost, {
|
||||
username: config.webdavUser,
|
||||
password: config.webdavPass
|
||||
})
|
||||
|
||||
const response = await client.getDirectoryContents(config.webdavPath)
|
||||
const client = new WebDav(config)
|
||||
const response = await client.getDirectoryContents()
|
||||
const files = Array.isArray(response) ? response : response.data
|
||||
|
||||
return files
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import { defaultLanguage, ZOOM_SHORTCUTS } from '@shared/config/constant'
|
||||
import { defaultLanguage, UpgradeChannel, ZOOM_SHORTCUTS } from '@shared/config/constant'
|
||||
import { LanguageVarious, Shortcut, ThemeMode } from '@types'
|
||||
import { app } from 'electron'
|
||||
import Store from 'electron-store'
|
||||
@ -16,10 +16,15 @@ export enum ConfigKeys {
|
||||
ClickTrayToShowQuickAssistant = 'clickTrayToShowQuickAssistant',
|
||||
EnableQuickAssistant = 'enableQuickAssistant',
|
||||
AutoUpdate = 'autoUpdate',
|
||||
TestPlan = 'testPlan',
|
||||
TestChannel = 'testChannel',
|
||||
EnableDataCollection = 'enableDataCollection',
|
||||
SelectionAssistantEnabled = 'selectionAssistantEnabled',
|
||||
SelectionAssistantTriggerMode = 'selectionAssistantTriggerMode',
|
||||
SelectionAssistantFollowToolbar = 'selectionAssistantFollowToolbar'
|
||||
SelectionAssistantFollowToolbar = 'selectionAssistantFollowToolbar',
|
||||
SelectionAssistantRemeberWinSize = 'selectionAssistantRemeberWinSize',
|
||||
SelectionAssistantFilterMode = 'selectionAssistantFilterMode',
|
||||
SelectionAssistantFilterList = 'selectionAssistantFilterList'
|
||||
}
|
||||
|
||||
export class ConfigManager {
|
||||
@ -35,12 +40,12 @@ export class ConfigManager {
|
||||
return this.get(ConfigKeys.Language, locale) as LanguageVarious
|
||||
}
|
||||
|
||||
setLanguage(theme: LanguageVarious) {
|
||||
this.set(ConfigKeys.Language, theme)
|
||||
setLanguage(lang: LanguageVarious) {
|
||||
this.setAndNotify(ConfigKeys.Language, lang)
|
||||
}
|
||||
|
||||
getTheme(): ThemeMode {
|
||||
return this.get(ConfigKeys.Theme, ThemeMode.auto)
|
||||
return this.get(ConfigKeys.Theme, ThemeMode.system)
|
||||
}
|
||||
|
||||
setTheme(theme: ThemeMode) {
|
||||
@ -60,8 +65,7 @@ export class ConfigManager {
|
||||
}
|
||||
|
||||
setTray(value: boolean) {
|
||||
this.set(ConfigKeys.Tray, value)
|
||||
this.notifySubscribers(ConfigKeys.Tray, value)
|
||||
this.setAndNotify(ConfigKeys.Tray, value)
|
||||
}
|
||||
|
||||
getTrayOnClose(): boolean {
|
||||
@ -77,8 +81,7 @@ export class ConfigManager {
|
||||
}
|
||||
|
||||
setZoomFactor(factor: number) {
|
||||
this.set(ConfigKeys.ZoomFactor, factor)
|
||||
this.notifySubscribers(ConfigKeys.ZoomFactor, factor)
|
||||
this.setAndNotify(ConfigKeys.ZoomFactor, factor)
|
||||
}
|
||||
|
||||
subscribe<T>(key: string, callback: (newValue: T) => void) {
|
||||
@ -110,11 +113,10 @@ export class ConfigManager {
|
||||
}
|
||||
|
||||
setShortcuts(shortcuts: Shortcut[]) {
|
||||
this.set(
|
||||
this.setAndNotify(
|
||||
ConfigKeys.Shortcuts,
|
||||
shortcuts.filter((shortcut) => shortcut.system)
|
||||
)
|
||||
this.notifySubscribers(ConfigKeys.Shortcuts, shortcuts)
|
||||
}
|
||||
|
||||
getClickTrayToShowQuickAssistant(): boolean {
|
||||
@ -130,7 +132,7 @@ export class ConfigManager {
|
||||
}
|
||||
|
||||
setEnableQuickAssistant(value: boolean) {
|
||||
this.set(ConfigKeys.EnableQuickAssistant, value)
|
||||
this.setAndNotify(ConfigKeys.EnableQuickAssistant, value)
|
||||
}
|
||||
|
||||
getAutoUpdate(): boolean {
|
||||
@ -141,6 +143,22 @@ export class ConfigManager {
|
||||
this.set(ConfigKeys.AutoUpdate, value)
|
||||
}
|
||||
|
||||
getTestPlan(): boolean {
|
||||
return this.get<boolean>(ConfigKeys.TestPlan, false)
|
||||
}
|
||||
|
||||
setTestPlan(value: boolean) {
|
||||
this.set(ConfigKeys.TestPlan, value)
|
||||
}
|
||||
|
||||
getTestChannel(): UpgradeChannel {
|
||||
return this.get<UpgradeChannel>(ConfigKeys.TestChannel)
|
||||
}
|
||||
|
||||
setTestChannel(value: UpgradeChannel) {
|
||||
this.set(ConfigKeys.TestChannel, value)
|
||||
}
|
||||
|
||||
getEnableDataCollection(): boolean {
|
||||
return this.get<boolean>(ConfigKeys.EnableDataCollection, true)
|
||||
}
|
||||
@ -151,12 +169,11 @@ export class ConfigManager {
|
||||
|
||||
// Selection Assistant: is enabled the selection assistant
|
||||
getSelectionAssistantEnabled(): boolean {
|
||||
return this.get<boolean>(ConfigKeys.SelectionAssistantEnabled, true)
|
||||
return this.get<boolean>(ConfigKeys.SelectionAssistantEnabled, false)
|
||||
}
|
||||
|
||||
setSelectionAssistantEnabled(value: boolean) {
|
||||
this.set(ConfigKeys.SelectionAssistantEnabled, value)
|
||||
this.notifySubscribers(ConfigKeys.SelectionAssistantEnabled, value)
|
||||
this.setAndNotify(ConfigKeys.SelectionAssistantEnabled, value)
|
||||
}
|
||||
|
||||
// Selection Assistant: trigger mode (selected, ctrlkey)
|
||||
@ -165,8 +182,7 @@ export class ConfigManager {
|
||||
}
|
||||
|
||||
setSelectionAssistantTriggerMode(value: string) {
|
||||
this.set(ConfigKeys.SelectionAssistantTriggerMode, value)
|
||||
this.notifySubscribers(ConfigKeys.SelectionAssistantTriggerMode, value)
|
||||
this.setAndNotify(ConfigKeys.SelectionAssistantTriggerMode, value)
|
||||
}
|
||||
|
||||
// Selection Assistant: if action window position follow toolbar
|
||||
@ -175,12 +191,40 @@ export class ConfigManager {
|
||||
}
|
||||
|
||||
setSelectionAssistantFollowToolbar(value: boolean) {
|
||||
this.set(ConfigKeys.SelectionAssistantFollowToolbar, value)
|
||||
this.notifySubscribers(ConfigKeys.SelectionAssistantFollowToolbar, value)
|
||||
this.setAndNotify(ConfigKeys.SelectionAssistantFollowToolbar, value)
|
||||
}
|
||||
|
||||
set(key: string, value: unknown) {
|
||||
getSelectionAssistantRemeberWinSize(): boolean {
|
||||
return this.get<boolean>(ConfigKeys.SelectionAssistantRemeberWinSize, false)
|
||||
}
|
||||
|
||||
setSelectionAssistantRemeberWinSize(value: boolean) {
|
||||
this.setAndNotify(ConfigKeys.SelectionAssistantRemeberWinSize, value)
|
||||
}
|
||||
|
||||
getSelectionAssistantFilterMode(): string {
|
||||
return this.get<string>(ConfigKeys.SelectionAssistantFilterMode, 'default')
|
||||
}
|
||||
|
||||
setSelectionAssistantFilterMode(value: string) {
|
||||
this.setAndNotify(ConfigKeys.SelectionAssistantFilterMode, value)
|
||||
}
|
||||
|
||||
getSelectionAssistantFilterList(): string[] {
|
||||
return this.get<string[]>(ConfigKeys.SelectionAssistantFilterList, [])
|
||||
}
|
||||
|
||||
setSelectionAssistantFilterList(value: string[]) {
|
||||
this.setAndNotify(ConfigKeys.SelectionAssistantFilterList, value)
|
||||
}
|
||||
|
||||
setAndNotify(key: string, value: unknown) {
|
||||
this.set(key, value, true)
|
||||
}
|
||||
|
||||
set(key: string, value: unknown, isNotify: boolean = false) {
|
||||
this.store.set(key, value)
|
||||
isNotify && this.notifySubscribers(key, value)
|
||||
}
|
||||
|
||||
get<T>(key: string, defaultValue?: T) {
|
||||
|
||||
@ -4,18 +4,29 @@ import { locales } from '../utils/locales'
|
||||
import { configManager } from './ConfigManager'
|
||||
|
||||
class ContextMenu {
|
||||
public contextMenu(w: Electron.BrowserWindow) {
|
||||
w.webContents.on('context-menu', (_event, properties) => {
|
||||
public contextMenu(w: Electron.WebContents) {
|
||||
w.on('context-menu', (_event, properties) => {
|
||||
const template: MenuItemConstructorOptions[] = this.createEditMenuItems(properties)
|
||||
const filtered = template.filter((item) => item.visible !== false)
|
||||
if (filtered.length > 0) {
|
||||
const menu = Menu.buildFromTemplate([...filtered, ...this.createInspectMenuItems(w)])
|
||||
let template = [...filtered, ...this.createInspectMenuItems(w)]
|
||||
const dictionarySuggestions = this.createDictionarySuggestions(properties, w)
|
||||
if (dictionarySuggestions.length > 0) {
|
||||
template = [
|
||||
...dictionarySuggestions,
|
||||
{ type: 'separator' },
|
||||
this.createSpellCheckMenuItem(properties, w),
|
||||
{ type: 'separator' },
|
||||
...template
|
||||
]
|
||||
}
|
||||
const menu = Menu.buildFromTemplate(template)
|
||||
menu.popup()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
private createInspectMenuItems(w: Electron.BrowserWindow): MenuItemConstructorOptions[] {
|
||||
private createInspectMenuItems(w: Electron.WebContents): MenuItemConstructorOptions[] {
|
||||
const locale = locales[configManager.getLanguage()]
|
||||
const { common } = locale.translation
|
||||
const template: MenuItemConstructorOptions[] = [
|
||||
@ -23,7 +34,7 @@ class ContextMenu {
|
||||
id: 'inspect',
|
||||
label: common.inspect,
|
||||
click: () => {
|
||||
w.webContents.toggleDevTools()
|
||||
w.toggleDevTools()
|
||||
},
|
||||
enabled: true
|
||||
}
|
||||
@ -72,6 +83,53 @@ class ContextMenu {
|
||||
|
||||
return template
|
||||
}
|
||||
|
||||
private createSpellCheckMenuItem(
|
||||
properties: Electron.ContextMenuParams,
|
||||
w: Electron.WebContents
|
||||
): MenuItemConstructorOptions {
|
||||
const hasText = properties.selectionText.length > 0
|
||||
|
||||
return {
|
||||
id: 'learnSpelling',
|
||||
label: '&Learn Spelling',
|
||||
visible: Boolean(properties.isEditable && hasText && properties.misspelledWord),
|
||||
click: () => {
|
||||
w.session.addWordToSpellCheckerDictionary(properties.misspelledWord)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private createDictionarySuggestions(
|
||||
properties: Electron.ContextMenuParams,
|
||||
w: Electron.WebContents
|
||||
): MenuItemConstructorOptions[] {
|
||||
const hasText = properties.selectionText.length > 0
|
||||
|
||||
if (!hasText || !properties.misspelledWord) {
|
||||
return []
|
||||
}
|
||||
|
||||
if (properties.dictionarySuggestions.length === 0) {
|
||||
return [
|
||||
{
|
||||
id: 'dictionarySuggestions',
|
||||
label: 'No Guesses Found',
|
||||
visible: true,
|
||||
enabled: false
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
return properties.dictionarySuggestions.map((suggestion) => ({
|
||||
id: 'dictionarySuggestions',
|
||||
label: suggestion,
|
||||
visible: Boolean(properties.isEditable && hasText && properties.misspelledWord),
|
||||
click: (menuItem: Electron.MenuItem) => {
|
||||
w.replaceMisspelling(menuItem.label)
|
||||
}
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
||||
export const contextMenu = new ContextMenu()
|
||||
|
||||
@ -1,7 +1,9 @@
|
||||
import fs from 'node:fs'
|
||||
import fs from 'fs/promises'
|
||||
|
||||
export default class FileService {
|
||||
public static async readFile(_: Electron.IpcMainInvokeEvent, path: string) {
|
||||
return fs.readFileSync(path, 'utf8')
|
||||
public static async readFile(_: Electron.IpcMainInvokeEvent, pathOrUrl: string, encoding?: BufferEncoding) {
|
||||
const path = pathOrUrl.startsWith('file://') ? new URL(pathOrUrl) : pathOrUrl
|
||||
if (encoding) return fs.readFile(path, { encoding })
|
||||
return fs.readFile(path)
|
||||
}
|
||||
}
|
||||
|
||||
@ -15,9 +15,11 @@ import * as fs from 'fs'
|
||||
import { writeFileSync } from 'fs'
|
||||
import { readFile } from 'fs/promises'
|
||||
import officeParser from 'officeparser'
|
||||
import { getDocument } from 'officeparser/pdfjs-dist-build/pdf.js'
|
||||
import * as path from 'path'
|
||||
import { chdir } from 'process'
|
||||
import { v4 as uuidv4 } from 'uuid'
|
||||
import WordExtractor from 'word-extractor'
|
||||
|
||||
class FileStorage {
|
||||
private storageDir = getFilesDir()
|
||||
@ -219,10 +221,20 @@ class FileStorage {
|
||||
public readFile = async (_: Electron.IpcMainInvokeEvent, id: string): Promise<string> => {
|
||||
const filePath = path.join(this.storageDir, id)
|
||||
|
||||
if (documentExts.includes(path.extname(filePath))) {
|
||||
const fileExtension = path.extname(filePath)
|
||||
|
||||
if (documentExts.includes(fileExtension)) {
|
||||
const originalCwd = process.cwd()
|
||||
try {
|
||||
chdir(this.tempDir)
|
||||
|
||||
if (fileExtension === '.doc') {
|
||||
const extractor = new WordExtractor()
|
||||
const extracted = await extractor.extract(filePath)
|
||||
chdir(originalCwd)
|
||||
return extracted.getBody()
|
||||
}
|
||||
|
||||
const data = await officeParser.parseOfficeAsync(filePath)
|
||||
chdir(originalCwd)
|
||||
return data
|
||||
@ -268,6 +280,51 @@ class FileStorage {
|
||||
}
|
||||
}
|
||||
|
||||
public saveBase64Image = async (_: Electron.IpcMainInvokeEvent, base64Data: string): Promise<FileType> => {
|
||||
try {
|
||||
if (!base64Data) {
|
||||
throw new Error('Base64 data is required')
|
||||
}
|
||||
|
||||
// 移除 base64 头部信息(如果存在)
|
||||
const base64String = base64Data.replace(/^data:.*;base64,/, '')
|
||||
const buffer = Buffer.from(base64String, 'base64')
|
||||
const uuid = uuidv4()
|
||||
const ext = '.png'
|
||||
const destPath = path.join(this.storageDir, uuid + ext)
|
||||
|
||||
logger.info('[FileStorage] Saving base64 image:', {
|
||||
storageDir: this.storageDir,
|
||||
destPath,
|
||||
bufferSize: buffer.length
|
||||
})
|
||||
|
||||
// 确保目录存在
|
||||
if (!fs.existsSync(this.storageDir)) {
|
||||
fs.mkdirSync(this.storageDir, { recursive: true })
|
||||
}
|
||||
|
||||
await fs.promises.writeFile(destPath, buffer)
|
||||
|
||||
const fileMetadata: FileType = {
|
||||
id: uuid,
|
||||
origin_name: uuid + ext,
|
||||
name: uuid + ext,
|
||||
path: destPath,
|
||||
created_at: new Date().toISOString(),
|
||||
size: buffer.length,
|
||||
ext: ext.slice(1),
|
||||
type: getFileType(ext),
|
||||
count: 1
|
||||
}
|
||||
|
||||
return fileMetadata
|
||||
} catch (error) {
|
||||
logger.error('[FileStorage] Failed to save base64 image:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
public base64File = async (_: Electron.IpcMainInvokeEvent, id: string): Promise<{ data: string; mime: string }> => {
|
||||
const filePath = path.join(this.storageDir, id)
|
||||
const buffer = await fs.promises.readFile(filePath)
|
||||
@ -276,6 +333,16 @@ class FileStorage {
|
||||
return { data: base64, mime }
|
||||
}
|
||||
|
||||
public pdfPageCount = async (_: Electron.IpcMainInvokeEvent, id: string): Promise<number> => {
|
||||
const filePath = path.join(this.storageDir, id)
|
||||
const buffer = await fs.promises.readFile(filePath)
|
||||
|
||||
const doc = await getDocument({ data: buffer }).promise
|
||||
const pages = doc.numPages
|
||||
await doc.destroy()
|
||||
return pages
|
||||
}
|
||||
|
||||
public binaryImage = async (_: Electron.IpcMainInvokeEvent, id: string): Promise<{ data: Buffer; mime: string }> => {
|
||||
const filePath = path.join(this.storageDir, id)
|
||||
const data = await fs.promises.readFile(filePath)
|
||||
@ -296,7 +363,7 @@ class FileStorage {
|
||||
public open = async (
|
||||
_: Electron.IpcMainInvokeEvent,
|
||||
options: OpenDialogOptions
|
||||
): Promise<{ fileName: string; filePath: string; content: Buffer } | null> => {
|
||||
): Promise<{ fileName: string; filePath: string; content?: Buffer; size: number } | null> => {
|
||||
try {
|
||||
const result: OpenDialogReturnValue = await dialog.showOpenDialog({
|
||||
title: '打开文件',
|
||||
@ -308,8 +375,16 @@ class FileStorage {
|
||||
if (!result.canceled && result.filePaths.length > 0) {
|
||||
const filePath = result.filePaths[0]
|
||||
const fileName = filePath.split('/').pop() || ''
|
||||
const content = await readFile(filePath)
|
||||
return { fileName, filePath, content }
|
||||
const stats = await fs.promises.stat(filePath)
|
||||
|
||||
// If the file is less than 2GB, read the content
|
||||
if (stats.size < 2 * 1024 * 1024 * 1024) {
|
||||
const content = await readFile(filePath)
|
||||
return { fileName, filePath, content, size: stats.size }
|
||||
}
|
||||
|
||||
// For large files, only return file information, do not read content
|
||||
return { fileName, filePath, size: stats.size }
|
||||
}
|
||||
|
||||
return null
|
||||
|
||||
@ -1,79 +0,0 @@
|
||||
import { File, FileState, GoogleGenAI, Pager } from '@google/genai'
|
||||
import { FileType } from '@types'
|
||||
import fs from 'fs'
|
||||
|
||||
import { CacheService } from './CacheService'
|
||||
|
||||
export class GeminiService {
|
||||
private static readonly FILE_LIST_CACHE_KEY = 'gemini_file_list'
|
||||
private static readonly CACHE_DURATION = 3000
|
||||
|
||||
static async uploadFile(
|
||||
_: Electron.IpcMainInvokeEvent,
|
||||
file: FileType,
|
||||
{ apiKey, baseURL }: { apiKey: string; baseURL: string }
|
||||
): Promise<File> {
|
||||
const sdk = new GoogleGenAI({
|
||||
vertexai: false,
|
||||
apiKey,
|
||||
httpOptions: {
|
||||
baseUrl: baseURL
|
||||
}
|
||||
})
|
||||
|
||||
return await sdk.files.upload({
|
||||
file: file.path,
|
||||
config: {
|
||||
mimeType: 'application/pdf',
|
||||
name: file.id,
|
||||
displayName: file.origin_name
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
static async base64File(_: Electron.IpcMainInvokeEvent, file: FileType) {
|
||||
return {
|
||||
data: Buffer.from(fs.readFileSync(file.path)).toString('base64'),
|
||||
mimeType: 'application/pdf'
|
||||
}
|
||||
}
|
||||
|
||||
static async retrieveFile(_: Electron.IpcMainInvokeEvent, file: FileType, apiKey: string): Promise<File | undefined> {
|
||||
const sdk = new GoogleGenAI({ vertexai: false, apiKey })
|
||||
const cachedResponse = CacheService.get<any>(GeminiService.FILE_LIST_CACHE_KEY)
|
||||
if (cachedResponse) {
|
||||
return GeminiService.processResponse(cachedResponse, file)
|
||||
}
|
||||
|
||||
const response = await sdk.files.list()
|
||||
CacheService.set(GeminiService.FILE_LIST_CACHE_KEY, response, GeminiService.CACHE_DURATION)
|
||||
|
||||
return GeminiService.processResponse(response, file)
|
||||
}
|
||||
|
||||
private static async processResponse(response: Pager<File>, file: FileType) {
|
||||
for await (const f of response) {
|
||||
if (f.state === FileState.ACTIVE) {
|
||||
if (f.displayName === file.origin_name && Number(f.sizeBytes) === file.size) {
|
||||
return f
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return undefined
|
||||
}
|
||||
|
||||
static async listFiles(_: Electron.IpcMainInvokeEvent, apiKey: string): Promise<File[]> {
|
||||
const sdk = new GoogleGenAI({ vertexai: false, apiKey })
|
||||
const files: File[] = []
|
||||
for await (const f of await sdk.files.list()) {
|
||||
files.push(f)
|
||||
}
|
||||
return files
|
||||
}
|
||||
|
||||
static async deleteFile(_: Electron.IpcMainInvokeEvent, fileId: string, apiKey: string) {
|
||||
const sdk = new GoogleGenAI({ vertexai: false, apiKey })
|
||||
await sdk.files.delete({ name: fileId })
|
||||
}
|
||||
}
|
||||
@ -16,21 +16,22 @@
|
||||
import * as fs from 'node:fs'
|
||||
import path from 'node:path'
|
||||
|
||||
import { RAGApplication, RAGApplicationBuilder, TextLoader } from '@cherrystudio/embedjs'
|
||||
import { RAGApplication, RAGApplicationBuilder } from '@cherrystudio/embedjs'
|
||||
import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
|
||||
import { LibSqlDb } from '@cherrystudio/embedjs-libsql'
|
||||
import { SitemapLoader } from '@cherrystudio/embedjs-loader-sitemap'
|
||||
import { WebLoader } from '@cherrystudio/embedjs-loader-web'
|
||||
import Embeddings from '@main/embeddings/Embeddings'
|
||||
import { addFileLoader } from '@main/loader'
|
||||
import Reranker from '@main/reranker/Reranker'
|
||||
import Embeddings from '@main/knowledage/embeddings/Embeddings'
|
||||
import { addFileLoader } from '@main/knowledage/loader'
|
||||
import { NoteLoader } from '@main/knowledage/loader/noteLoader'
|
||||
import Reranker from '@main/knowledage/reranker/Reranker'
|
||||
import { windowService } from '@main/services/WindowService'
|
||||
import { getDataPath } from '@main/utils'
|
||||
import { getAllFiles } from '@main/utils/file'
|
||||
import { MB } from '@shared/config/constant'
|
||||
import type { LoaderReturn } from '@shared/config/types'
|
||||
import { IpcChannel } from '@shared/IpcChannel'
|
||||
import { FileType, KnowledgeBaseParams, KnowledgeItem } from '@types'
|
||||
import { app } from 'electron'
|
||||
import Logger from 'electron-log'
|
||||
import { v4 as uuidv4 } from 'uuid'
|
||||
|
||||
@ -88,7 +89,7 @@ const loaderTaskIntoOfSet = (loaderTask: LoaderTask): LoaderTaskOfSet => {
|
||||
}
|
||||
|
||||
class KnowledgeService {
|
||||
private storageDir = path.join(app.getPath('userData'), 'Data', 'KnowledgeBase')
|
||||
private storageDir = path.join(getDataPath(), 'KnowledgeBase')
|
||||
// Byte based
|
||||
private workload = 0
|
||||
private processingItemCount = 0
|
||||
@ -110,13 +111,21 @@ class KnowledgeService {
|
||||
private getRagApplication = async ({
|
||||
id,
|
||||
model,
|
||||
provider,
|
||||
apiKey,
|
||||
apiVersion,
|
||||
baseURL,
|
||||
dimensions
|
||||
}: KnowledgeBaseParams): Promise<RAGApplication> => {
|
||||
let ragApplication: RAGApplication
|
||||
const embeddings = new Embeddings({ model, apiKey, apiVersion, baseURL, dimensions } as KnowledgeBaseParams)
|
||||
const embeddings = new Embeddings({
|
||||
model,
|
||||
provider,
|
||||
apiKey,
|
||||
apiVersion,
|
||||
baseURL,
|
||||
dimensions
|
||||
} as KnowledgeBaseParams)
|
||||
try {
|
||||
ragApplication = await new RAGApplicationBuilder()
|
||||
.setModel('NO_MODEL')
|
||||
@ -135,7 +144,7 @@ class KnowledgeService {
|
||||
this.getRagApplication(base)
|
||||
}
|
||||
|
||||
public reset = async (_: Electron.IpcMainInvokeEvent, { base }: { base: KnowledgeBaseParams }): Promise<void> => {
|
||||
public reset = async (_: Electron.IpcMainInvokeEvent, base: KnowledgeBaseParams): Promise<void> => {
|
||||
const ragApplication = await this.getRagApplication(base)
|
||||
await ragApplication.reset()
|
||||
}
|
||||
@ -325,6 +334,7 @@ class KnowledgeService {
|
||||
): LoaderTask {
|
||||
const { base, item, forceReload } = options
|
||||
const content = item.content as string
|
||||
const sourceUrl = (item as any).sourceUrl
|
||||
|
||||
const encoder = new TextEncoder()
|
||||
const contentBytes = encoder.encode(content)
|
||||
@ -334,7 +344,12 @@ class KnowledgeService {
|
||||
state: LoaderTaskItemState.PENDING,
|
||||
task: () => {
|
||||
const loaderReturn = ragApplication.addLoader(
|
||||
new TextLoader({ text: content, chunkSize: base.chunkSize, chunkOverlap: base.chunkOverlap }),
|
||||
new NoteLoader({
|
||||
text: content,
|
||||
sourceUrl,
|
||||
chunkSize: base.chunkSize,
|
||||
chunkOverlap: base.chunkOverlap
|
||||
}),
|
||||
forceReload
|
||||
) as Promise<LoaderReturn>
|
||||
|
||||
|
||||
@ -19,7 +19,7 @@ export function registerProtocolClient(app: Electron.App) {
|
||||
}
|
||||
}
|
||||
|
||||
app.setAsDefaultProtocolClient('cherrystudio')
|
||||
app.setAsDefaultProtocolClient(CHERRY_STUDIO_PROTOCOL)
|
||||
}
|
||||
|
||||
export function handleProtocolUrl(url: string) {
|
||||
|
||||
102
src/main/services/PythonService.ts
Normal file
102
src/main/services/PythonService.ts
Normal file
@ -0,0 +1,102 @@
|
||||
import { randomUUID } from 'node:crypto'
|
||||
|
||||
import { BrowserWindow, ipcMain } from 'electron'
|
||||
|
||||
interface PythonExecutionRequest {
|
||||
id: string
|
||||
script: string
|
||||
context: Record<string, any>
|
||||
timeout: number
|
||||
}
|
||||
|
||||
interface PythonExecutionResponse {
|
||||
id: string
|
||||
result?: string
|
||||
error?: string
|
||||
}
|
||||
|
||||
/**
|
||||
* Service for executing Python code by communicating with the PyodideService in the renderer process
|
||||
*/
|
||||
export class PythonService {
|
||||
private static instance: PythonService | null = null
|
||||
private mainWindow: BrowserWindow | null = null
|
||||
private pendingRequests = new Map<string, { resolve: (value: string) => void; reject: (error: Error) => void }>()
|
||||
|
||||
private constructor() {
|
||||
// Private constructor for singleton pattern
|
||||
this.setupIpcHandlers()
|
||||
}
|
||||
|
||||
public static getInstance(): PythonService {
|
||||
if (!PythonService.instance) {
|
||||
PythonService.instance = new PythonService()
|
||||
}
|
||||
return PythonService.instance
|
||||
}
|
||||
|
||||
private setupIpcHandlers() {
|
||||
// Handle responses from renderer
|
||||
ipcMain.on('python-execution-response', (_, response: PythonExecutionResponse) => {
|
||||
const request = this.pendingRequests.get(response.id)
|
||||
if (request) {
|
||||
this.pendingRequests.delete(response.id)
|
||||
if (response.error) {
|
||||
request.reject(new Error(response.error))
|
||||
} else {
|
||||
request.resolve(response.result || '')
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
public setMainWindow(mainWindow: BrowserWindow) {
|
||||
this.mainWindow = mainWindow
|
||||
}
|
||||
|
||||
/**
|
||||
* Execute Python code by sending request to renderer PyodideService
|
||||
*/
|
||||
public async executeScript(
|
||||
script: string,
|
||||
context: Record<string, any> = {},
|
||||
timeout: number = 60000
|
||||
): Promise<string> {
|
||||
if (!this.mainWindow) {
|
||||
throw new Error('Main window not set in PythonService')
|
||||
}
|
||||
|
||||
return new Promise((resolve, reject) => {
|
||||
const requestId = randomUUID()
|
||||
|
||||
// Store the request
|
||||
this.pendingRequests.set(requestId, { resolve, reject })
|
||||
|
||||
// Set up timeout
|
||||
const timeoutId = setTimeout(() => {
|
||||
this.pendingRequests.delete(requestId)
|
||||
reject(new Error('Python execution timed out'))
|
||||
}, timeout + 5000) // Add 5s buffer for IPC communication
|
||||
|
||||
// Update resolve/reject to clear timeout
|
||||
const originalResolve = resolve
|
||||
const originalReject = reject
|
||||
this.pendingRequests.set(requestId, {
|
||||
resolve: (value: string) => {
|
||||
clearTimeout(timeoutId)
|
||||
originalResolve(value)
|
||||
},
|
||||
reject: (error: Error) => {
|
||||
clearTimeout(timeoutId)
|
||||
originalReject(error)
|
||||
}
|
||||
})
|
||||
|
||||
// Send request to renderer
|
||||
const request: PythonExecutionRequest = { id: requestId, script, context, timeout }
|
||||
this.mainWindow?.webContents.send('python-execution-request', request)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
export const pythonService = PythonService.getInstance()
|
||||
@ -1,3 +1,4 @@
|
||||
import { SELECTION_FINETUNED_LIST, SELECTION_PREDEFINED_BLACKLIST } from '@main/configs/SelectionConfig'
|
||||
import { isDev, isWin } from '@main/constant'
|
||||
import { IpcChannel } from '@shared/IpcChannel'
|
||||
import { BrowserWindow, ipcMain, screen } from 'electron'
|
||||
@ -13,6 +14,7 @@ import type {
|
||||
|
||||
import type { ActionItem } from '../../renderer/src/types/selectionTypes'
|
||||
import { ConfigKeys, configManager } from './ConfigManager'
|
||||
import storeSyncService from './StoreSyncService'
|
||||
|
||||
let SelectionHook: SelectionHookConstructor | null = null
|
||||
try {
|
||||
@ -36,6 +38,12 @@ type RelativeOrientation =
|
||||
| 'middleRight'
|
||||
| 'center'
|
||||
|
||||
enum TriggerMode {
|
||||
Selected = 'selected',
|
||||
Ctrlkey = 'ctrlkey',
|
||||
Shortcut = 'shortcut'
|
||||
}
|
||||
|
||||
/** SelectionService is a singleton class that manages the selection hook and the toolbar window
|
||||
*
|
||||
* Features:
|
||||
@ -58,8 +66,11 @@ export class SelectionService {
|
||||
private initStatus: boolean = false
|
||||
private started: boolean = false
|
||||
|
||||
private triggerMode = 'selected'
|
||||
private triggerMode = TriggerMode.Selected
|
||||
private isFollowToolbar = true
|
||||
private isRemeberWinSize = false
|
||||
private filterMode = 'default'
|
||||
private filterList: string[] = []
|
||||
|
||||
private toolbarWindow: BrowserWindow | null = null
|
||||
private actionWindows = new Set<BrowserWindow>()
|
||||
@ -84,6 +95,11 @@ export class SelectionService {
|
||||
private readonly ACTION_WINDOW_WIDTH = 500
|
||||
private readonly ACTION_WINDOW_HEIGHT = 400
|
||||
|
||||
private lastActionWindowSize: { width: number; height: number } = {
|
||||
width: this.ACTION_WINDOW_WIDTH,
|
||||
height: this.ACTION_WINDOW_HEIGHT
|
||||
}
|
||||
|
||||
private constructor() {
|
||||
try {
|
||||
if (!SelectionHook) {
|
||||
@ -136,17 +152,106 @@ export class SelectionService {
|
||||
}
|
||||
|
||||
private initConfig() {
|
||||
this.triggerMode = configManager.getSelectionAssistantTriggerMode()
|
||||
this.triggerMode = configManager.getSelectionAssistantTriggerMode() as TriggerMode
|
||||
this.isFollowToolbar = configManager.getSelectionAssistantFollowToolbar()
|
||||
this.isRemeberWinSize = configManager.getSelectionAssistantRemeberWinSize()
|
||||
this.filterMode = configManager.getSelectionAssistantFilterMode()
|
||||
this.filterList = configManager.getSelectionAssistantFilterList()
|
||||
|
||||
this.setHookGlobalFilterMode(this.filterMode, this.filterList)
|
||||
this.setHookFineTunedList()
|
||||
|
||||
configManager.subscribe(ConfigKeys.SelectionAssistantTriggerMode, (triggerMode: TriggerMode) => {
|
||||
const oldTriggerMode = this.triggerMode
|
||||
|
||||
configManager.subscribe(ConfigKeys.SelectionAssistantTriggerMode, (triggerMode: string) => {
|
||||
this.triggerMode = triggerMode
|
||||
this.processTriggerMode()
|
||||
|
||||
//trigger mode changed, need to update the filter list
|
||||
if (oldTriggerMode !== triggerMode) {
|
||||
this.setHookGlobalFilterMode(this.filterMode, this.filterList)
|
||||
}
|
||||
})
|
||||
|
||||
configManager.subscribe(ConfigKeys.SelectionAssistantFollowToolbar, (isFollowToolbar: boolean) => {
|
||||
this.isFollowToolbar = isFollowToolbar
|
||||
})
|
||||
|
||||
configManager.subscribe(ConfigKeys.SelectionAssistantRemeberWinSize, (isRemeberWinSize: boolean) => {
|
||||
this.isRemeberWinSize = isRemeberWinSize
|
||||
//when off, reset the last action window size to default
|
||||
if (!this.isRemeberWinSize) {
|
||||
this.lastActionWindowSize = {
|
||||
width: this.ACTION_WINDOW_WIDTH,
|
||||
height: this.ACTION_WINDOW_HEIGHT
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
configManager.subscribe(ConfigKeys.SelectionAssistantFilterMode, (filterMode: string) => {
|
||||
this.filterMode = filterMode
|
||||
this.setHookGlobalFilterMode(this.filterMode, this.filterList)
|
||||
})
|
||||
|
||||
configManager.subscribe(ConfigKeys.SelectionAssistantFilterList, (filterList: string[]) => {
|
||||
this.filterList = filterList
|
||||
this.setHookGlobalFilterMode(this.filterMode, this.filterList)
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Set the global filter mode for the selection-hook
|
||||
* @param mode - The mode to set, either 'default', 'whitelist', or 'blacklist'
|
||||
* @param list - An array of strings representing the list of items to include or exclude
|
||||
*/
|
||||
private setHookGlobalFilterMode(mode: string, list: string[]) {
|
||||
if (!this.selectionHook) return
|
||||
|
||||
const modeMap = {
|
||||
default: SelectionHook!.FilterMode.DEFAULT,
|
||||
whitelist: SelectionHook!.FilterMode.INCLUDE_LIST,
|
||||
blacklist: SelectionHook!.FilterMode.EXCLUDE_LIST
|
||||
}
|
||||
|
||||
let combinedList: string[] = list
|
||||
let combinedMode = mode
|
||||
|
||||
//only the selected mode need to combine the predefined blacklist with the user-defined blacklist
|
||||
if (this.triggerMode === TriggerMode.Selected) {
|
||||
switch (mode) {
|
||||
case 'blacklist':
|
||||
//combine the predefined blacklist with the user-defined blacklist
|
||||
combinedList = [...new Set([...list, ...SELECTION_PREDEFINED_BLACKLIST.WINDOWS])]
|
||||
break
|
||||
case 'whitelist':
|
||||
combinedList = [...list]
|
||||
break
|
||||
case 'default':
|
||||
default:
|
||||
//use the predefined blacklist as the default filter list
|
||||
combinedList = [...SELECTION_PREDEFINED_BLACKLIST.WINDOWS]
|
||||
combinedMode = 'blacklist'
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if (!this.selectionHook.setGlobalFilterMode(modeMap[combinedMode], combinedList)) {
|
||||
this.logError(new Error('Failed to set selection-hook global filter mode'))
|
||||
}
|
||||
}
|
||||
|
||||
private setHookFineTunedList() {
|
||||
if (!this.selectionHook) return
|
||||
|
||||
this.selectionHook.setFineTunedList(
|
||||
SelectionHook!.FineTunedListType.EXCLUDE_CLIPBOARD_CURSOR_DETECT,
|
||||
SELECTION_FINETUNED_LIST.EXCLUDE_CLIPBOARD_CURSOR_DETECT.WINDOWS
|
||||
)
|
||||
|
||||
this.selectionHook.setFineTunedList(
|
||||
SelectionHook!.FineTunedListType.INCLUDE_CLIPBOARD_DELAY_READ,
|
||||
SELECTION_FINETUNED_LIST.INCLUDE_CLIPBOARD_DELAY_READ.WINDOWS
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
@ -160,8 +265,6 @@ export class SelectionService {
|
||||
}
|
||||
|
||||
try {
|
||||
//init basic configs
|
||||
this.initConfig()
|
||||
//make sure the toolbar window is ready
|
||||
this.createToolbarWindow()
|
||||
// Initialize preloaded windows
|
||||
@ -175,11 +278,14 @@ export class SelectionService {
|
||||
|
||||
// Start the hook
|
||||
if (this.selectionHook.start({ debug: isDev })) {
|
||||
//init basic configs
|
||||
this.initConfig()
|
||||
|
||||
//init trigger mode configs
|
||||
this.processTriggerMode()
|
||||
|
||||
this.started = true
|
||||
this.logInfo('SelectionService Started')
|
||||
this.logInfo('SelectionService Started', true)
|
||||
return true
|
||||
}
|
||||
|
||||
@ -200,13 +306,20 @@ export class SelectionService {
|
||||
if (!this.selectionHook) return false
|
||||
|
||||
this.selectionHook.stop()
|
||||
this.selectionHook.cleanup()
|
||||
this.selectionHook.cleanup() //already remove all listeners
|
||||
|
||||
//reset the listener states
|
||||
this.isCtrlkeyListenerActive = false
|
||||
this.isHideByMouseKeyListenerActive = false
|
||||
|
||||
if (this.toolbarWindow) {
|
||||
this.toolbarWindow.close()
|
||||
this.toolbarWindow = null
|
||||
}
|
||||
this.closePreloadedActionWindows()
|
||||
|
||||
this.started = false
|
||||
this.logInfo('SelectionService Stopped')
|
||||
this.logInfo('SelectionService Stopped', true)
|
||||
return true
|
||||
}
|
||||
|
||||
@ -222,7 +335,22 @@ export class SelectionService {
|
||||
this.selectionHook = null
|
||||
this.initStatus = false
|
||||
SelectionService.instance = null
|
||||
this.logInfo('SelectionService Quitted')
|
||||
this.logInfo('SelectionService Quitted', true)
|
||||
}
|
||||
|
||||
/**
|
||||
* Toggle the enabled state of the selection service
|
||||
* Will sync the new enabled store to all renderer windows
|
||||
*/
|
||||
public toggleEnabled(enabled: boolean | undefined = undefined) {
|
||||
if (!this.selectionHook) return
|
||||
|
||||
const newEnabled = enabled === undefined ? !configManager.getSelectionAssistantEnabled() : enabled
|
||||
|
||||
configManager.setSelectionAssistantEnabled(newEnabled)
|
||||
|
||||
//sync the new enabled state to all renderer windows
|
||||
storeSyncService.syncToRenderer('selectionStore/setSelectionEnabled', newEnabled)
|
||||
}
|
||||
|
||||
/**
|
||||
@ -269,6 +397,9 @@ export class SelectionService {
|
||||
|
||||
// Clean up when closed
|
||||
this.toolbarWindow.on('closed', () => {
|
||||
if (!this.toolbarWindow?.isDestroyed()) {
|
||||
this.toolbarWindow?.destroy()
|
||||
}
|
||||
this.toolbarWindow = null
|
||||
})
|
||||
|
||||
@ -325,8 +456,18 @@ export class SelectionService {
|
||||
x: posX,
|
||||
y: posY
|
||||
})
|
||||
|
||||
//set the window to always on top (highest level)
|
||||
//should set every time the window is shown
|
||||
this.toolbarWindow!.setAlwaysOnTop(true, 'screen-saver')
|
||||
this.toolbarWindow!.show()
|
||||
this.toolbarWindow!.setOpacity(1)
|
||||
|
||||
/**
|
||||
* In Windows 10, setOpacity(1) will make the window completely transparent
|
||||
* It's a strange behavior, so we don't use it for compatibility
|
||||
*/
|
||||
// this.toolbarWindow!.setOpacity(1)
|
||||
|
||||
this.startHideByMouseKeyListener()
|
||||
}
|
||||
|
||||
@ -336,7 +477,7 @@ export class SelectionService {
|
||||
public hideToolbar(): void {
|
||||
if (!this.isToolbarAlive()) return
|
||||
|
||||
this.toolbarWindow!.setOpacity(0)
|
||||
// this.toolbarWindow!.setOpacity(0)
|
||||
this.toolbarWindow!.hide()
|
||||
|
||||
this.stopHideByMouseKeyListener()
|
||||
@ -454,6 +595,45 @@ export class SelectionService {
|
||||
return startTop.y === endTop.y && startBottom.y === endBottom.y
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the user selected text and process it (trigger by shortcut)
|
||||
*
|
||||
* it's a public method used by shortcut service
|
||||
*/
|
||||
public processSelectTextByShortcut(): void {
|
||||
if (!this.selectionHook || !this.started || this.triggerMode !== TriggerMode.Shortcut) return
|
||||
|
||||
const selectionData = this.selectionHook.getCurrentSelection()
|
||||
|
||||
if (selectionData) {
|
||||
this.processTextSelection(selectionData)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Determine if the text selection should be processed by filter mode&list
|
||||
* @param selectionData Text selection information and coordinates
|
||||
* @returns {boolean} True if the selection should be processed, false otherwise
|
||||
*/
|
||||
private shouldProcessTextSelection(selectionData: TextSelectionData): boolean {
|
||||
if (selectionData.programName === '' || this.filterMode === 'default') {
|
||||
return true
|
||||
}
|
||||
|
||||
const programName = selectionData.programName.toLowerCase()
|
||||
//items in filterList are already in lower case
|
||||
const isFound = this.filterList.some((item) => programName.includes(item))
|
||||
|
||||
switch (this.filterMode) {
|
||||
case 'whitelist':
|
||||
return isFound
|
||||
case 'blacklist':
|
||||
return !isFound
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
/**
|
||||
* Process text selection data and show toolbar
|
||||
* Handles different selection scenarios:
|
||||
@ -468,6 +648,10 @@ export class SelectionService {
|
||||
return
|
||||
}
|
||||
|
||||
if (!this.shouldProcessTextSelection(selectionData)) {
|
||||
return
|
||||
}
|
||||
|
||||
// Determine reference point and position for toolbar
|
||||
let refPoint: { x: number; y: number } = { x: 0, y: 0 }
|
||||
let isLogical = false
|
||||
@ -551,12 +735,16 @@ export class SelectionService {
|
||||
selectionData.endBottom
|
||||
)
|
||||
|
||||
// Note: shift key + mouse click == DoubleClick
|
||||
|
||||
//double click to select a word
|
||||
if (isDoubleClick && isSameLine) {
|
||||
refOrientation = 'bottomMiddle'
|
||||
refPoint = { x: selectionData.mousePosEnd.x, y: selectionData.endBottom.y + 4 }
|
||||
break
|
||||
}
|
||||
|
||||
// below: isDoubleClick || isSameLine
|
||||
if (isSameLine) {
|
||||
const direction = selectionData.mousePosEnd.x - selectionData.mousePosStart.x
|
||||
|
||||
@ -570,6 +758,7 @@ export class SelectionService {
|
||||
break
|
||||
}
|
||||
|
||||
// below: !isDoubleClick && !isSameLine
|
||||
const direction = selectionData.mousePosEnd.y - selectionData.mousePosStart.y
|
||||
|
||||
if (direction > 0) {
|
||||
@ -667,7 +856,11 @@ export class SelectionService {
|
||||
*/
|
||||
private handleKeyDownHide = (data: KeyboardEventData) => {
|
||||
//dont hide toolbar when ctrlkey is pressed
|
||||
if (this.triggerMode === 'ctrlkey' && this.isCtrlkey(data.vkCode)) {
|
||||
if (this.triggerMode === TriggerMode.Ctrlkey && this.isCtrlkey(data.vkCode)) {
|
||||
return
|
||||
}
|
||||
//dont hide toolbar when shiftkey or altkey is pressed, because it's used for selection
|
||||
if (this.isShiftkey(data.vkCode) || this.isAltkey(data.vkCode)) {
|
||||
return
|
||||
}
|
||||
|
||||
@ -695,6 +888,9 @@ export class SelectionService {
|
||||
//ctrlkey pressed
|
||||
if (this.lastCtrlkeyDownTime === 0) {
|
||||
this.lastCtrlkeyDownTime = Date.now()
|
||||
//add the mouse-wheel&mouse-down listener, detect if user is zooming in/out or multi-selecting
|
||||
this.selectionHook!.on('mouse-wheel', this.handleMouseWheelCtrlkeyMode)
|
||||
this.selectionHook!.on('mouse-down', this.handleMouseDownCtrlkeyMode)
|
||||
return
|
||||
}
|
||||
|
||||
@ -705,7 +901,6 @@ export class SelectionService {
|
||||
this.lastCtrlkeyDownTime = -1
|
||||
|
||||
const selectionData = this.selectionHook!.getCurrentSelection()
|
||||
|
||||
if (selectionData) {
|
||||
this.processTextSelection(selectionData)
|
||||
}
|
||||
@ -718,14 +913,45 @@ export class SelectionService {
|
||||
*/
|
||||
private handleKeyUpCtrlkeyMode = (data: KeyboardEventData) => {
|
||||
if (!this.isCtrlkey(data.vkCode)) return
|
||||
//remove the mouse-wheel&mouse-down listener
|
||||
this.selectionHook!.off('mouse-wheel', this.handleMouseWheelCtrlkeyMode)
|
||||
this.selectionHook!.off('mouse-down', this.handleMouseDownCtrlkeyMode)
|
||||
this.lastCtrlkeyDownTime = 0
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle mouse wheel events in ctrlkey trigger mode
|
||||
* ignore CtrlKey pressing when mouse wheel is used
|
||||
* because user is zooming in/out
|
||||
*/
|
||||
private handleMouseWheelCtrlkeyMode = () => {
|
||||
this.lastCtrlkeyDownTime = -1
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle mouse down events in ctrlkey trigger mode
|
||||
* ignore CtrlKey pressing when mouse down is used
|
||||
* because user is multi-selecting
|
||||
*/
|
||||
private handleMouseDownCtrlkeyMode = () => {
|
||||
this.lastCtrlkeyDownTime = -1
|
||||
}
|
||||
|
||||
//check if the key is ctrl key
|
||||
private isCtrlkey(vkCode: number) {
|
||||
return vkCode === 162 || vkCode === 163
|
||||
}
|
||||
|
||||
//check if the key is shift key
|
||||
private isShiftkey(vkCode: number) {
|
||||
return vkCode === 160 || vkCode === 161
|
||||
}
|
||||
|
||||
//check if the key is alt key
|
||||
private isAltkey(vkCode: number) {
|
||||
return vkCode === 164 || vkCode === 165
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a preloaded action window for quick response
|
||||
* Action windows handle specific operations on selected text
|
||||
@ -733,8 +959,8 @@ export class SelectionService {
|
||||
*/
|
||||
private createPreloadedActionWindow(): BrowserWindow {
|
||||
const preloadedActionWindow = new BrowserWindow({
|
||||
width: this.ACTION_WINDOW_WIDTH,
|
||||
height: this.ACTION_WINDOW_HEIGHT,
|
||||
width: this.isRemeberWinSize ? this.lastActionWindowSize.width : this.ACTION_WINDOW_WIDTH,
|
||||
height: this.isRemeberWinSize ? this.lastActionWindowSize.height : this.ACTION_WINDOW_HEIGHT,
|
||||
minWidth: 300,
|
||||
minHeight: 200,
|
||||
frame: false,
|
||||
@ -778,6 +1004,17 @@ export class SelectionService {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Close all preloaded action windows
|
||||
*/
|
||||
private closePreloadedActionWindows() {
|
||||
for (const actionWindow of this.preloadedActionWindows) {
|
||||
if (!actionWindow.isDestroyed()) {
|
||||
actionWindow.destroy()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Preload a new action window asynchronously
|
||||
* This method is called after popping a window to ensure we always have windows ready
|
||||
@ -808,6 +1045,16 @@ export class SelectionService {
|
||||
}
|
||||
})
|
||||
|
||||
//remember the action window size
|
||||
actionWindow.on('resized', () => {
|
||||
if (this.isRemeberWinSize) {
|
||||
this.lastActionWindowSize = {
|
||||
width: actionWindow.getBounds().width,
|
||||
height: actionWindow.getBounds().height
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
this.actionWindows.add(actionWindow)
|
||||
|
||||
// Asynchronously create a new preloaded window
|
||||
@ -830,30 +1077,58 @@ export class SelectionService {
|
||||
* @param actionWindow Window to position and show
|
||||
*/
|
||||
private showActionWindow(actionWindow: BrowserWindow) {
|
||||
let actionWindowWidth = this.ACTION_WINDOW_WIDTH
|
||||
let actionWindowHeight = this.ACTION_WINDOW_HEIGHT
|
||||
|
||||
//if remember win size is true, use the last remembered size
|
||||
if (this.isRemeberWinSize) {
|
||||
actionWindowWidth = this.lastActionWindowSize.width
|
||||
actionWindowHeight = this.lastActionWindowSize.height
|
||||
}
|
||||
|
||||
//center way
|
||||
if (!this.isFollowToolbar || !this.toolbarWindow) {
|
||||
if (this.isRemeberWinSize) {
|
||||
actionWindow.setBounds({
|
||||
width: actionWindowWidth,
|
||||
height: actionWindowHeight
|
||||
})
|
||||
}
|
||||
|
||||
actionWindow.show()
|
||||
this.hideToolbar()
|
||||
return
|
||||
}
|
||||
|
||||
//follow toolbar
|
||||
|
||||
const toolbarBounds = this.toolbarWindow!.getBounds()
|
||||
const display = screen.getDisplayNearestPoint({ x: toolbarBounds.x, y: toolbarBounds.y })
|
||||
const workArea = display.workArea
|
||||
const GAP = 6 // 6px gap from screen edges
|
||||
|
||||
//make sure action window is inside screen
|
||||
if (actionWindowWidth > workArea.width - 2 * GAP) {
|
||||
actionWindowWidth = workArea.width - 2 * GAP
|
||||
}
|
||||
|
||||
if (actionWindowHeight > workArea.height - 2 * GAP) {
|
||||
actionWindowHeight = workArea.height - 2 * GAP
|
||||
}
|
||||
|
||||
// Calculate initial position to center action window horizontally below toolbar
|
||||
let posX = Math.round(toolbarBounds.x + (toolbarBounds.width - this.ACTION_WINDOW_WIDTH) / 2)
|
||||
let posX = Math.round(toolbarBounds.x + (toolbarBounds.width - actionWindowWidth) / 2)
|
||||
let posY = Math.round(toolbarBounds.y)
|
||||
|
||||
// Ensure action window stays within screen boundaries with a small gap
|
||||
if (posX + this.ACTION_WINDOW_WIDTH > workArea.x + workArea.width) {
|
||||
posX = workArea.x + workArea.width - this.ACTION_WINDOW_WIDTH - GAP
|
||||
if (posX + actionWindowWidth > workArea.x + workArea.width) {
|
||||
posX = workArea.x + workArea.width - actionWindowWidth - GAP
|
||||
} else if (posX < workArea.x) {
|
||||
posX = workArea.x + GAP
|
||||
}
|
||||
if (posY + this.ACTION_WINDOW_HEIGHT > workArea.y + workArea.height) {
|
||||
if (posY + actionWindowHeight > workArea.y + workArea.height) {
|
||||
// If window would go below screen, try to position it above toolbar
|
||||
posY = workArea.y + workArea.height - this.ACTION_WINDOW_HEIGHT - GAP
|
||||
posY = workArea.y + workArea.height - actionWindowHeight - GAP
|
||||
} else if (posY < workArea.y) {
|
||||
posY = workArea.y + GAP
|
||||
}
|
||||
@ -861,8 +1136,8 @@ export class SelectionService {
|
||||
actionWindow.setPosition(posX, posY, false)
|
||||
//KEY to make window not resize
|
||||
actionWindow.setBounds({
|
||||
width: this.ACTION_WINDOW_WIDTH,
|
||||
height: this.ACTION_WINDOW_HEIGHT,
|
||||
width: actionWindowWidth,
|
||||
height: actionWindowHeight,
|
||||
x: posX,
|
||||
y: posY
|
||||
})
|
||||
@ -888,31 +1163,44 @@ export class SelectionService {
|
||||
* Manages appropriate event listeners for each mode
|
||||
*/
|
||||
private processTriggerMode() {
|
||||
if (this.triggerMode === 'selected') {
|
||||
if (this.isCtrlkeyListenerActive) {
|
||||
this.selectionHook!.off('key-down', this.handleKeyDownCtrlkeyMode)
|
||||
this.selectionHook!.off('key-up', this.handleKeyUpCtrlkeyMode)
|
||||
switch (this.triggerMode) {
|
||||
case TriggerMode.Selected:
|
||||
if (this.isCtrlkeyListenerActive) {
|
||||
this.selectionHook!.off('key-down', this.handleKeyDownCtrlkeyMode)
|
||||
this.selectionHook!.off('key-up', this.handleKeyUpCtrlkeyMode)
|
||||
|
||||
this.isCtrlkeyListenerActive = false
|
||||
}
|
||||
this.isCtrlkeyListenerActive = false
|
||||
}
|
||||
|
||||
this.selectionHook!.enableClipboard()
|
||||
this.selectionHook!.setSelectionPassiveMode(false)
|
||||
} else if (this.triggerMode === 'ctrlkey') {
|
||||
if (!this.isCtrlkeyListenerActive) {
|
||||
this.selectionHook!.on('key-down', this.handleKeyDownCtrlkeyMode)
|
||||
this.selectionHook!.on('key-up', this.handleKeyUpCtrlkeyMode)
|
||||
this.selectionHook!.setSelectionPassiveMode(false)
|
||||
break
|
||||
case TriggerMode.Ctrlkey:
|
||||
if (!this.isCtrlkeyListenerActive) {
|
||||
this.selectionHook!.on('key-down', this.handleKeyDownCtrlkeyMode)
|
||||
this.selectionHook!.on('key-up', this.handleKeyUpCtrlkeyMode)
|
||||
|
||||
this.isCtrlkeyListenerActive = true
|
||||
}
|
||||
this.isCtrlkeyListenerActive = true
|
||||
}
|
||||
|
||||
this.selectionHook!.disableClipboard()
|
||||
this.selectionHook!.setSelectionPassiveMode(true)
|
||||
this.selectionHook!.setSelectionPassiveMode(true)
|
||||
break
|
||||
case TriggerMode.Shortcut:
|
||||
//remove the ctrlkey listener, don't need any key listener for shortcut mode
|
||||
if (this.isCtrlkeyListenerActive) {
|
||||
this.selectionHook!.off('key-down', this.handleKeyDownCtrlkeyMode)
|
||||
this.selectionHook!.off('key-up', this.handleKeyUpCtrlkeyMode)
|
||||
|
||||
this.isCtrlkeyListenerActive = false
|
||||
}
|
||||
|
||||
this.selectionHook!.setSelectionPassiveMode(true)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
public writeToClipboard(text: string): boolean {
|
||||
return this.selectionHook?.writeToClipboard(text) ?? false
|
||||
if (!this.selectionHook || !this.started) return false
|
||||
return this.selectionHook.writeToClipboard(text)
|
||||
}
|
||||
|
||||
/**
|
||||
@ -946,6 +1234,18 @@ export class SelectionService {
|
||||
configManager.setSelectionAssistantFollowToolbar(isFollowToolbar)
|
||||
})
|
||||
|
||||
ipcMain.handle(IpcChannel.Selection_SetRemeberWinSize, (_, isRemeberWinSize: boolean) => {
|
||||
configManager.setSelectionAssistantRemeberWinSize(isRemeberWinSize)
|
||||
})
|
||||
|
||||
ipcMain.handle(IpcChannel.Selection_SetFilterMode, (_, filterMode: string) => {
|
||||
configManager.setSelectionAssistantFilterMode(filterMode)
|
||||
})
|
||||
|
||||
ipcMain.handle(IpcChannel.Selection_SetFilterList, (_, filterList: string[]) => {
|
||||
configManager.setSelectionAssistantFilterList(filterList)
|
||||
})
|
||||
|
||||
ipcMain.handle(IpcChannel.Selection_ProcessAction, (_, actionItem: ActionItem) => {
|
||||
selectionService?.processAction(actionItem)
|
||||
})
|
||||
@ -974,8 +1274,10 @@ export class SelectionService {
|
||||
this.isIpcHandlerRegistered = true
|
||||
}
|
||||
|
||||
private logInfo(message: string) {
|
||||
isDev && Logger.info('[SelectionService] Info: ', message)
|
||||
private logInfo(message: string, forceShow: boolean = false) {
|
||||
if (isDev || forceShow) {
|
||||
Logger.info('[SelectionService] Info: ', message)
|
||||
}
|
||||
}
|
||||
|
||||
private logError(...args: [...string[], Error]) {
|
||||
|
||||
@ -4,10 +4,16 @@ import { BrowserWindow, globalShortcut } from 'electron'
|
||||
import Logger from 'electron-log'
|
||||
|
||||
import { configManager } from './ConfigManager'
|
||||
import selectionService from './SelectionService'
|
||||
import { windowService } from './WindowService'
|
||||
|
||||
let showAppAccelerator: string | null = null
|
||||
let showMiniWindowAccelerator: string | null = null
|
||||
let selectionAssistantToggleAccelerator: string | null = null
|
||||
let selectionAssistantSelectTextAccelerator: string | null = null
|
||||
|
||||
//indicate if the shortcuts are registered on app boot time
|
||||
let isRegisterOnBoot = true
|
||||
|
||||
// store the focus and blur handlers for each window to unregister them later
|
||||
const windowOnHandlers = new Map<BrowserWindow, { onFocusHandler: () => void; onBlurHandler: () => void }>()
|
||||
@ -28,6 +34,18 @@ function getShortcutHandler(shortcut: Shortcut) {
|
||||
return () => {
|
||||
windowService.toggleMiniWindow()
|
||||
}
|
||||
case 'selection_assistant_toggle':
|
||||
return () => {
|
||||
if (selectionService) {
|
||||
selectionService.toggleEnabled()
|
||||
}
|
||||
}
|
||||
case 'selection_assistant_select_text':
|
||||
return () => {
|
||||
if (selectionService) {
|
||||
selectionService.processSelectTextByShortcut()
|
||||
}
|
||||
}
|
||||
default:
|
||||
return null
|
||||
}
|
||||
@ -37,9 +55,8 @@ function formatShortcutKey(shortcut: string[]): string {
|
||||
return shortcut.join('+')
|
||||
}
|
||||
|
||||
const convertShortcutRecordedByKeyboardEventKeyValueToElectronGlobalShortcutFormat = (
|
||||
shortcut: string | string[]
|
||||
): string => {
|
||||
// convert the shortcut recorded by keyboard event key value to electron global shortcut format
|
||||
const convertShortcutFormat = (shortcut: string | string[]): string => {
|
||||
const accelerator = (() => {
|
||||
if (Array.isArray(shortcut)) {
|
||||
return shortcut
|
||||
@ -93,11 +110,14 @@ const convertShortcutRecordedByKeyboardEventKeyValueToElectronGlobalShortcutForm
|
||||
}
|
||||
|
||||
export function registerShortcuts(window: BrowserWindow) {
|
||||
window.once('ready-to-show', () => {
|
||||
if (configManager.getLaunchToTray()) {
|
||||
registerOnlyUniversalShortcuts()
|
||||
}
|
||||
})
|
||||
if (isRegisterOnBoot) {
|
||||
window.once('ready-to-show', () => {
|
||||
if (configManager.getLaunchToTray()) {
|
||||
registerOnlyUniversalShortcuts()
|
||||
}
|
||||
})
|
||||
isRegisterOnBoot = false
|
||||
}
|
||||
|
||||
//only for clearer code
|
||||
const registerOnlyUniversalShortcuts = () => {
|
||||
@ -124,7 +144,12 @@ export function registerShortcuts(window: BrowserWindow) {
|
||||
}
|
||||
|
||||
// only register universal shortcuts when needed
|
||||
if (onlyUniversalShortcuts && !['show_app', 'mini_window'].includes(shortcut.key)) {
|
||||
if (
|
||||
onlyUniversalShortcuts &&
|
||||
!['show_app', 'mini_window', 'selection_assistant_toggle', 'selection_assistant_select_text'].includes(
|
||||
shortcut.key
|
||||
)
|
||||
) {
|
||||
return
|
||||
}
|
||||
|
||||
@ -146,6 +171,14 @@ export function registerShortcuts(window: BrowserWindow) {
|
||||
showMiniWindowAccelerator = formatShortcutKey(shortcut.shortcut)
|
||||
break
|
||||
|
||||
case 'selection_assistant_toggle':
|
||||
selectionAssistantToggleAccelerator = formatShortcutKey(shortcut.shortcut)
|
||||
break
|
||||
|
||||
case 'selection_assistant_select_text':
|
||||
selectionAssistantSelectTextAccelerator = formatShortcutKey(shortcut.shortcut)
|
||||
break
|
||||
|
||||
//the following ZOOMs will register shortcuts seperately, so will return
|
||||
case 'zoom_in':
|
||||
globalShortcut.register('CommandOrControl+=', () => handler(window))
|
||||
@ -162,9 +195,7 @@ export function registerShortcuts(window: BrowserWindow) {
|
||||
return
|
||||
}
|
||||
|
||||
const accelerator = convertShortcutRecordedByKeyboardEventKeyValueToElectronGlobalShortcutFormat(
|
||||
shortcut.shortcut
|
||||
)
|
||||
const accelerator = convertShortcutFormat(shortcut.shortcut)
|
||||
|
||||
globalShortcut.register(accelerator, () => handler(window))
|
||||
} catch (error) {
|
||||
@ -181,15 +212,25 @@ export function registerShortcuts(window: BrowserWindow) {
|
||||
|
||||
if (showAppAccelerator) {
|
||||
const handler = getShortcutHandler({ key: 'show_app' } as Shortcut)
|
||||
const accelerator =
|
||||
convertShortcutRecordedByKeyboardEventKeyValueToElectronGlobalShortcutFormat(showAppAccelerator)
|
||||
const accelerator = convertShortcutFormat(showAppAccelerator)
|
||||
handler && globalShortcut.register(accelerator, () => handler(window))
|
||||
}
|
||||
|
||||
if (showMiniWindowAccelerator) {
|
||||
const handler = getShortcutHandler({ key: 'mini_window' } as Shortcut)
|
||||
const accelerator =
|
||||
convertShortcutRecordedByKeyboardEventKeyValueToElectronGlobalShortcutFormat(showMiniWindowAccelerator)
|
||||
const accelerator = convertShortcutFormat(showMiniWindowAccelerator)
|
||||
handler && globalShortcut.register(accelerator, () => handler(window))
|
||||
}
|
||||
|
||||
if (selectionAssistantToggleAccelerator) {
|
||||
const handler = getShortcutHandler({ key: 'selection_assistant_toggle' } as Shortcut)
|
||||
const accelerator = convertShortcutFormat(selectionAssistantToggleAccelerator)
|
||||
handler && globalShortcut.register(accelerator, () => handler(window))
|
||||
}
|
||||
|
||||
if (selectionAssistantSelectTextAccelerator) {
|
||||
const handler = getShortcutHandler({ key: 'selection_assistant_select_text' } as Shortcut)
|
||||
const accelerator = convertShortcutFormat(selectionAssistantSelectTextAccelerator)
|
||||
handler && globalShortcut.register(accelerator, () => handler(window))
|
||||
}
|
||||
} catch (error) {
|
||||
@ -217,6 +258,8 @@ export function unregisterAllShortcuts() {
|
||||
try {
|
||||
showAppAccelerator = null
|
||||
showMiniWindowAccelerator = null
|
||||
selectionAssistantToggleAccelerator = null
|
||||
selectionAssistantSelectTextAccelerator = null
|
||||
windowOnHandlers.forEach((handlers, window) => {
|
||||
window.off('focus', handlers.onFocusHandler)
|
||||
window.off('blur', handlers.onBlurHandler)
|
||||
|
||||
@ -49,6 +49,23 @@ export class StoreSyncService {
|
||||
this.windowIds = this.windowIds.filter((id) => id !== windowId)
|
||||
}
|
||||
|
||||
/**
|
||||
* Sync an action to all renderer windows
|
||||
* @param type Action type, like 'settings/setTray'
|
||||
* @param payload Action payload
|
||||
*
|
||||
* NOTICE: DO NOT use directly in ConfigManager, may cause infinite sync loop
|
||||
*/
|
||||
public syncToRenderer(type: string, payload: any): void {
|
||||
const action: StoreSyncAction = {
|
||||
type,
|
||||
payload
|
||||
}
|
||||
|
||||
//-1 means the action is from the main process, will be broadcast to all windows
|
||||
this.broadcastToOtherWindows(-1, action)
|
||||
}
|
||||
|
||||
/**
|
||||
* Register IPC handlers for store sync communication
|
||||
* Handles window subscription, unsubscription and action broadcasting
|
||||
|
||||
48
src/main/services/ThemeService.ts
Normal file
48
src/main/services/ThemeService.ts
Normal file
@ -0,0 +1,48 @@
|
||||
import { IpcChannel } from '@shared/IpcChannel'
|
||||
import { ThemeMode } from '@types'
|
||||
import { BrowserWindow, nativeTheme } from 'electron'
|
||||
|
||||
import { titleBarOverlayDark, titleBarOverlayLight } from '../config'
|
||||
import { configManager } from './ConfigManager'
|
||||
|
||||
class ThemeService {
|
||||
private theme: ThemeMode = ThemeMode.system
|
||||
constructor() {
|
||||
this.theme = configManager.getTheme()
|
||||
|
||||
if (this.theme === ThemeMode.dark || this.theme === ThemeMode.light || this.theme === ThemeMode.system) {
|
||||
nativeTheme.themeSource = this.theme
|
||||
} else {
|
||||
// 兼容旧版本
|
||||
configManager.setTheme(ThemeMode.system)
|
||||
nativeTheme.themeSource = ThemeMode.system
|
||||
}
|
||||
nativeTheme.on('updated', this.themeUpdatadHandler.bind(this))
|
||||
}
|
||||
|
||||
themeUpdatadHandler() {
|
||||
BrowserWindow.getAllWindows().forEach((win) => {
|
||||
if (win && !win.isDestroyed() && win.setTitleBarOverlay) {
|
||||
try {
|
||||
win.setTitleBarOverlay(nativeTheme.shouldUseDarkColors ? titleBarOverlayDark : titleBarOverlayLight)
|
||||
} catch (error) {
|
||||
// don't throw error if setTitleBarOverlay failed
|
||||
// Because it may be called with some windows have some title bar
|
||||
}
|
||||
}
|
||||
win.webContents.send(IpcChannel.ThemeUpdated, nativeTheme.shouldUseDarkColors ? ThemeMode.dark : ThemeMode.light)
|
||||
})
|
||||
}
|
||||
|
||||
setTheme(theme: ThemeMode) {
|
||||
if (theme === this.theme) {
|
||||
return
|
||||
}
|
||||
|
||||
this.theme = theme
|
||||
nativeTheme.themeSource = theme
|
||||
configManager.setTheme(theme)
|
||||
}
|
||||
}
|
||||
|
||||
export const themeService = new ThemeService()
|
||||
@ -1,20 +1,22 @@
|
||||
import { isMac } from '@main/constant'
|
||||
import { isLinux, isMac, isWin } from '@main/constant'
|
||||
import { locales } from '@main/utils/locales'
|
||||
import { app, Menu, MenuItemConstructorOptions, nativeImage, nativeTheme, Tray } from 'electron'
|
||||
|
||||
import icon from '../../../build/tray_icon.png?asset'
|
||||
import iconDark from '../../../build/tray_icon_dark.png?asset'
|
||||
import iconLight from '../../../build/tray_icon_light.png?asset'
|
||||
import { configManager } from './ConfigManager'
|
||||
import { ConfigKeys, configManager } from './ConfigManager'
|
||||
import selectionService from './SelectionService'
|
||||
import { windowService } from './WindowService'
|
||||
|
||||
export class TrayService {
|
||||
private static instance: TrayService
|
||||
private tray: Tray | null = null
|
||||
private contextMenu: Menu | null = null
|
||||
|
||||
constructor() {
|
||||
this.watchConfigChanges()
|
||||
this.updateTray()
|
||||
this.watchTrayChanges()
|
||||
TrayService.instance = this
|
||||
}
|
||||
|
||||
@ -28,14 +30,14 @@ export class TrayService {
|
||||
const iconPath = isMac ? (nativeTheme.shouldUseDarkColors ? iconLight : iconDark) : icon
|
||||
const tray = new Tray(iconPath)
|
||||
|
||||
if (process.platform === 'win32') {
|
||||
if (isWin) {
|
||||
tray.setImage(iconPath)
|
||||
} else if (process.platform === 'darwin') {
|
||||
} else if (isMac) {
|
||||
const image = nativeImage.createFromPath(iconPath)
|
||||
const resizedImage = image.resize({ width: 16, height: 16 })
|
||||
resizedImage.setTemplateImage(true)
|
||||
tray.setImage(resizedImage)
|
||||
} else if (process.platform === 'linux') {
|
||||
} else if (isLinux) {
|
||||
const image = nativeImage.createFromPath(iconPath)
|
||||
const resizedImage = image.resize({ width: 16, height: 16 })
|
||||
tray.setImage(resizedImage)
|
||||
@ -43,20 +45,56 @@ export class TrayService {
|
||||
|
||||
this.tray = tray
|
||||
|
||||
const locale = locales[configManager.getLanguage()]
|
||||
const { tray: trayLocale } = locale.translation
|
||||
this.updateContextMenu()
|
||||
|
||||
const enableQuickAssistant = configManager.getEnableQuickAssistant()
|
||||
if (isLinux) {
|
||||
this.tray.setContextMenu(this.contextMenu)
|
||||
}
|
||||
|
||||
this.tray.setToolTip('Cherry Studio')
|
||||
|
||||
this.tray.on('right-click', () => {
|
||||
if (this.contextMenu) {
|
||||
this.tray?.popUpContextMenu(this.contextMenu)
|
||||
}
|
||||
})
|
||||
|
||||
this.tray.on('click', () => {
|
||||
if (configManager.getEnableQuickAssistant() && configManager.getClickTrayToShowQuickAssistant()) {
|
||||
windowService.showMiniWindow()
|
||||
} else {
|
||||
windowService.showMainWindow()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
private updateContextMenu() {
|
||||
const locale = locales[configManager.getLanguage()]
|
||||
const { tray: trayLocale, selection: selectionLocale } = locale.translation
|
||||
|
||||
const quickAssistantEnabled = configManager.getEnableQuickAssistant()
|
||||
const selectionAssistantEnabled = configManager.getSelectionAssistantEnabled()
|
||||
|
||||
const template = [
|
||||
{
|
||||
label: trayLocale.show_window,
|
||||
click: () => windowService.showMainWindow()
|
||||
},
|
||||
enableQuickAssistant && {
|
||||
quickAssistantEnabled && {
|
||||
label: trayLocale.show_mini_window,
|
||||
click: () => windowService.showMiniWindow()
|
||||
},
|
||||
isWin && {
|
||||
label: selectionLocale.name + (selectionAssistantEnabled ? ' - On' : ' - Off'),
|
||||
// type: 'checkbox',
|
||||
// checked: selectionAssistantEnabled,
|
||||
click: () => {
|
||||
if (selectionService) {
|
||||
selectionService.toggleEnabled()
|
||||
this.updateContextMenu()
|
||||
}
|
||||
}
|
||||
},
|
||||
{ type: 'separator' },
|
||||
{
|
||||
label: trayLocale.quit,
|
||||
@ -64,25 +102,7 @@ export class TrayService {
|
||||
}
|
||||
].filter(Boolean) as MenuItemConstructorOptions[]
|
||||
|
||||
const contextMenu = Menu.buildFromTemplate(template)
|
||||
|
||||
if (process.platform === 'linux') {
|
||||
this.tray.setContextMenu(contextMenu)
|
||||
}
|
||||
|
||||
this.tray.setToolTip('Cherry Studio')
|
||||
|
||||
this.tray.on('right-click', () => {
|
||||
this.tray?.popUpContextMenu(contextMenu)
|
||||
})
|
||||
|
||||
this.tray.on('click', () => {
|
||||
if (enableQuickAssistant && configManager.getClickTrayToShowQuickAssistant()) {
|
||||
windowService.showMiniWindow()
|
||||
} else {
|
||||
windowService.showMainWindow()
|
||||
}
|
||||
})
|
||||
this.contextMenu = Menu.buildFromTemplate(template)
|
||||
}
|
||||
|
||||
private updateTray() {
|
||||
@ -94,13 +114,6 @@ export class TrayService {
|
||||
}
|
||||
}
|
||||
|
||||
public restartTray() {
|
||||
if (configManager.getTray()) {
|
||||
this.destroyTray()
|
||||
this.createTray()
|
||||
}
|
||||
}
|
||||
|
||||
private destroyTray() {
|
||||
if (this.tray) {
|
||||
this.tray.destroy()
|
||||
@ -108,8 +121,20 @@ export class TrayService {
|
||||
}
|
||||
}
|
||||
|
||||
private watchTrayChanges() {
|
||||
configManager.subscribe<boolean>('tray', () => this.updateTray())
|
||||
private watchConfigChanges() {
|
||||
configManager.subscribe(ConfigKeys.Tray, () => this.updateTray())
|
||||
|
||||
configManager.subscribe(ConfigKeys.Language, () => {
|
||||
this.updateContextMenu()
|
||||
})
|
||||
|
||||
configManager.subscribe(ConfigKeys.EnableQuickAssistant, () => {
|
||||
this.updateContextMenu()
|
||||
})
|
||||
|
||||
configManager.subscribe(ConfigKeys.SelectionAssistantEnabled, () => {
|
||||
this.updateContextMenu()
|
||||
})
|
||||
}
|
||||
|
||||
private quit() {
|
||||
|
||||
142
src/main/services/VertexAIService.ts
Normal file
142
src/main/services/VertexAIService.ts
Normal file
@ -0,0 +1,142 @@
|
||||
import { GoogleAuth } from 'google-auth-library'
|
||||
|
||||
interface ServiceAccountCredentials {
|
||||
privateKey: string
|
||||
clientEmail: string
|
||||
}
|
||||
|
||||
interface VertexAIAuthParams {
|
||||
projectId: string
|
||||
serviceAccount?: ServiceAccountCredentials
|
||||
}
|
||||
|
||||
const REQUIRED_VERTEX_AI_SCOPE = 'https://www.googleapis.com/auth/cloud-platform'
|
||||
|
||||
class VertexAIService {
|
||||
private static instance: VertexAIService
|
||||
private authClients: Map<string, GoogleAuth> = new Map()
|
||||
|
||||
static getInstance(): VertexAIService {
|
||||
if (!VertexAIService.instance) {
|
||||
VertexAIService.instance = new VertexAIService()
|
||||
}
|
||||
return VertexAIService.instance
|
||||
}
|
||||
|
||||
/**
|
||||
* 格式化私钥,确保它包含正确的PEM头部和尾部
|
||||
*/
|
||||
private formatPrivateKey(privateKey: string): string {
|
||||
if (!privateKey || typeof privateKey !== 'string') {
|
||||
throw new Error('Private key must be a non-empty string')
|
||||
}
|
||||
|
||||
// 处理JSON字符串中的转义换行符
|
||||
let key = privateKey.replace(/\\n/g, '\n')
|
||||
|
||||
// 如果已经是正确格式的PEM,直接返回
|
||||
if (key.includes('-----BEGIN PRIVATE KEY-----') && key.includes('-----END PRIVATE KEY-----')) {
|
||||
return key
|
||||
}
|
||||
|
||||
// 移除所有换行符和空白字符(为了重新格式化)
|
||||
key = key.replace(/\s+/g, '')
|
||||
|
||||
// 移除可能存在的头部和尾部
|
||||
key = key.replace(/-----BEGIN[^-]*-----/g, '')
|
||||
key = key.replace(/-----END[^-]*-----/g, '')
|
||||
|
||||
// 确保私钥不为空
|
||||
if (!key) {
|
||||
throw new Error('Private key is empty after formatting')
|
||||
}
|
||||
|
||||
// 添加正确的PEM头部和尾部,并格式化为64字符一行
|
||||
const formattedKey = key.match(/.{1,64}/g)?.join('\n') || key
|
||||
|
||||
return `-----BEGIN PRIVATE KEY-----\n${formattedKey}\n-----END PRIVATE KEY-----`
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取认证头用于 Vertex AI 请求
|
||||
*/
|
||||
async getAuthHeaders(params: VertexAIAuthParams): Promise<Record<string, string>> {
|
||||
const { projectId, serviceAccount } = params
|
||||
|
||||
if (!serviceAccount?.privateKey || !serviceAccount?.clientEmail) {
|
||||
throw new Error('Service account credentials are required')
|
||||
}
|
||||
|
||||
// 创建缓存键
|
||||
const cacheKey = `${projectId}-${serviceAccount.clientEmail}`
|
||||
|
||||
// 检查是否已有客户端实例
|
||||
let auth = this.authClients.get(cacheKey)
|
||||
|
||||
if (!auth) {
|
||||
try {
|
||||
// 格式化私钥
|
||||
const formattedPrivateKey = this.formatPrivateKey(serviceAccount.privateKey)
|
||||
|
||||
// 创建新的认证客户端
|
||||
auth = new GoogleAuth({
|
||||
credentials: {
|
||||
private_key: formattedPrivateKey,
|
||||
client_email: serviceAccount.clientEmail
|
||||
},
|
||||
projectId,
|
||||
scopes: [REQUIRED_VERTEX_AI_SCOPE]
|
||||
})
|
||||
|
||||
this.authClients.set(cacheKey, auth)
|
||||
} catch (formatError: any) {
|
||||
throw new Error(`Invalid private key format: ${formatError.message}`)
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
// 获取认证头
|
||||
const authHeaders = await auth.getRequestHeaders()
|
||||
|
||||
// 转换为普通对象
|
||||
const headers: Record<string, string> = {}
|
||||
for (const [key, value] of Object.entries(authHeaders)) {
|
||||
if (typeof value === 'string') {
|
||||
headers[key] = value
|
||||
}
|
||||
}
|
||||
|
||||
return headers
|
||||
} catch (error: any) {
|
||||
// 如果认证失败,清除缓存的客户端
|
||||
this.authClients.delete(cacheKey)
|
||||
throw new Error(`Failed to authenticate with service account: ${error.message}`)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 清理指定项目的认证缓存
|
||||
*/
|
||||
clearAuthCache(projectId: string, clientEmail?: string): void {
|
||||
if (clientEmail) {
|
||||
const cacheKey = `${projectId}-${clientEmail}`
|
||||
this.authClients.delete(cacheKey)
|
||||
} else {
|
||||
// 清理该项目的所有缓存
|
||||
for (const [key] of this.authClients) {
|
||||
if (key.startsWith(`${projectId}-`)) {
|
||||
this.authClients.delete(key)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 清理所有认证缓存
|
||||
*/
|
||||
clearAllAuthCache(): void {
|
||||
this.authClients.clear()
|
||||
}
|
||||
}
|
||||
|
||||
export default VertexAIService
|
||||
@ -1,5 +1,7 @@
|
||||
import { WebDavConfig } from '@types'
|
||||
import Logger from 'electron-log'
|
||||
import https from 'https'
|
||||
import path from 'path'
|
||||
import Stream from 'stream'
|
||||
import {
|
||||
BufferLike,
|
||||
@ -14,13 +16,14 @@ export default class WebDav {
|
||||
private webdavPath: string
|
||||
|
||||
constructor(params: WebDavConfig) {
|
||||
this.webdavPath = params.webdavPath
|
||||
this.webdavPath = params.webdavPath || '/'
|
||||
|
||||
this.instance = createClient(params.webdavHost, {
|
||||
username: params.webdavUser,
|
||||
password: params.webdavPass,
|
||||
maxBodyLength: Infinity,
|
||||
maxContentLength: Infinity
|
||||
maxContentLength: Infinity,
|
||||
httpsAgent: new https.Agent({ rejectUnauthorized: false })
|
||||
})
|
||||
|
||||
this.putFileContents = this.putFileContents.bind(this)
|
||||
@ -49,7 +52,7 @@ export default class WebDav {
|
||||
throw error
|
||||
}
|
||||
|
||||
const remoteFilePath = `${this.webdavPath}/${filename}`
|
||||
const remoteFilePath = path.posix.join(this.webdavPath, filename)
|
||||
|
||||
try {
|
||||
return await this.instance.putFileContents(remoteFilePath, data, options)
|
||||
@ -64,7 +67,7 @@ export default class WebDav {
|
||||
throw new Error('WebDAV client not initialized')
|
||||
}
|
||||
|
||||
const remoteFilePath = `${this.webdavPath}/${filename}`
|
||||
const remoteFilePath = path.posix.join(this.webdavPath, filename)
|
||||
|
||||
try {
|
||||
return await this.instance.getFileContents(remoteFilePath, options)
|
||||
@ -74,6 +77,19 @@ export default class WebDav {
|
||||
}
|
||||
}
|
||||
|
||||
public getDirectoryContents = async () => {
|
||||
if (!this.instance) {
|
||||
throw new Error('WebDAV client not initialized')
|
||||
}
|
||||
|
||||
try {
|
||||
return await this.instance.getDirectoryContents(this.webdavPath)
|
||||
} catch (error) {
|
||||
Logger.error('[WebDAV] Error getting directory contents on WebDAV:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
public checkConnection = async () => {
|
||||
if (!this.instance) {
|
||||
throw new Error('WebDAV client not initialized')
|
||||
@ -105,7 +121,7 @@ export default class WebDav {
|
||||
throw new Error('WebDAV client not initialized')
|
||||
}
|
||||
|
||||
const remoteFilePath = `${this.webdavPath}/${filename}`
|
||||
const remoteFilePath = path.posix.join(this.webdavPath, filename)
|
||||
|
||||
try {
|
||||
return await this.instance.deleteFile(remoteFilePath)
|
||||
|
||||
@ -1,8 +1,10 @@
|
||||
// just import the themeService to ensure the theme is initialized
|
||||
import './ThemeService'
|
||||
|
||||
import { is } from '@electron-toolkit/utils'
|
||||
import { isDev, isLinux, isMac, isWin } from '@main/constant'
|
||||
import { getFilesDir } from '@main/utils/file'
|
||||
import { IpcChannel } from '@shared/IpcChannel'
|
||||
import { ThemeMode } from '@types'
|
||||
import { app, BrowserWindow, nativeTheme, shell } from 'electron'
|
||||
import Logger from 'electron-log'
|
||||
import windowStateKeeper from 'electron-window-state'
|
||||
@ -45,13 +47,6 @@ export class WindowService {
|
||||
maximize: false
|
||||
})
|
||||
|
||||
const theme = configManager.getTheme()
|
||||
if (theme === ThemeMode.auto) {
|
||||
nativeTheme.themeSource = 'system'
|
||||
} else {
|
||||
nativeTheme.themeSource = theme
|
||||
}
|
||||
|
||||
this.mainWindow = new BrowserWindow({
|
||||
x: mainWindowState.x,
|
||||
y: mainWindowState.y,
|
||||
@ -61,7 +56,7 @@ export class WindowService {
|
||||
minHeight: 600,
|
||||
show: false,
|
||||
autoHideMenuBar: true,
|
||||
transparent: isMac,
|
||||
transparent: false,
|
||||
vibrancy: 'sidebar',
|
||||
visualEffectState: 'active',
|
||||
titleBarStyle: 'hidden',
|
||||
@ -100,6 +95,7 @@ export class WindowService {
|
||||
|
||||
this.setupMaximize(mainWindow, mainWindowState.isMaximized)
|
||||
this.setupContextMenu(mainWindow)
|
||||
this.setupSpellCheck(mainWindow)
|
||||
this.setupWindowEvents(mainWindow)
|
||||
this.setupWebContentsHandlers(mainWindow)
|
||||
this.setupWindowLifecycleEvents(mainWindow)
|
||||
@ -107,6 +103,18 @@ export class WindowService {
|
||||
this.loadMainWindowContent(mainWindow)
|
||||
}
|
||||
|
||||
private setupSpellCheck(mainWindow: BrowserWindow) {
|
||||
const enableSpellCheck = configManager.get('enableSpellCheck', false)
|
||||
if (enableSpellCheck) {
|
||||
try {
|
||||
const spellCheckLanguages = configManager.get('spellCheckLanguages', []) as string[]
|
||||
spellCheckLanguages.length > 0 && mainWindow.webContents.session.setSpellCheckerLanguages(spellCheckLanguages)
|
||||
} catch (error) {
|
||||
Logger.error('Failed to set spell check languages:', error as Error)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private setupMainWindowMonitor(mainWindow: BrowserWindow) {
|
||||
mainWindow.webContents.on('render-process-gone', (_, details) => {
|
||||
Logger.error(`Renderer process crashed with: ${JSON.stringify(details)}`)
|
||||
@ -121,12 +129,6 @@ export class WindowService {
|
||||
app.exit(1)
|
||||
}
|
||||
})
|
||||
|
||||
mainWindow.webContents.on('unresponsive', () => {
|
||||
// 在升级到electron 34后,可以获取具体js stack trace,目前只打个日志监控下
|
||||
// https://www.electronjs.org/blog/electron-34-0#unresponsive-renderer-javascript-call-stacks
|
||||
Logger.error('Renderer process unresponsive')
|
||||
})
|
||||
}
|
||||
|
||||
private setupMaximize(mainWindow: BrowserWindow, isMaximized: boolean) {
|
||||
@ -141,9 +143,10 @@ export class WindowService {
|
||||
}
|
||||
|
||||
private setupContextMenu(mainWindow: BrowserWindow) {
|
||||
contextMenu.contextMenu(mainWindow)
|
||||
app.on('browser-window-created', (_, win) => {
|
||||
contextMenu.contextMenu(win)
|
||||
contextMenu.contextMenu(mainWindow.webContents)
|
||||
// setup context menu for all webviews like miniapp
|
||||
app.on('web-contents-created', (_, webContents) => {
|
||||
contextMenu.contextMenu(webContents)
|
||||
})
|
||||
|
||||
// Dangerous API
|
||||
@ -448,8 +451,7 @@ export class WindowService {
|
||||
preload: join(__dirname, '../preload/index.js'),
|
||||
sandbox: false,
|
||||
webSecurity: false,
|
||||
webviewTag: true,
|
||||
backgroundThrottling: false
|
||||
webviewTag: true
|
||||
}
|
||||
})
|
||||
|
||||
@ -549,6 +551,25 @@ export class WindowService {
|
||||
public setPinMiniWindow(isPinned) {
|
||||
this.isPinnedMiniWindow = isPinned
|
||||
}
|
||||
|
||||
/**
|
||||
* 引用文本到主窗口
|
||||
* @param text 原始文本(未格式化)
|
||||
*/
|
||||
public quoteToMainWindow(text: string): void {
|
||||
try {
|
||||
this.showMainWindow()
|
||||
|
||||
const mainWindow = this.getMainWindow()
|
||||
if (mainWindow && !mainWindow.isDestroyed()) {
|
||||
setTimeout(() => {
|
||||
mainWindow.webContents.send(IpcChannel.App_QuoteToMain, text)
|
||||
}, 100)
|
||||
}
|
||||
} catch (error) {
|
||||
Logger.error('Failed to quote to main window:', error as Error)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export const windowService = WindowService.getInstance()
|
||||
|
||||
@ -1,37 +1,47 @@
|
||||
import { IpcChannel } from '@shared/IpcChannel'
|
||||
import Logger from 'electron-log'
|
||||
|
||||
import { windowService } from '../WindowService'
|
||||
|
||||
export function handleProvidersProtocolUrl(url: URL) {
|
||||
const params = new URLSearchParams(url.search)
|
||||
export async function handleProvidersProtocolUrl(url: URL) {
|
||||
switch (url.pathname) {
|
||||
case '/api-keys': {
|
||||
// jsonConfig example:
|
||||
// {
|
||||
// "id": "tokenflux",
|
||||
// "baseUrl": "https://tokenflux.ai/v1",
|
||||
// "apiKey": "sk-xxxx"
|
||||
// "apiKey": "sk-xxxx",
|
||||
// "name": "TokenFlux", // optional
|
||||
// "type": "openai" // optional
|
||||
// }
|
||||
// cherrystudio://providers/api-keys?data={base64Encode(JSON.stringify(jsonConfig))}
|
||||
// cherrystudio://providers/api-keys?v=1&data={base64Encode(JSON.stringify(jsonConfig))}
|
||||
|
||||
// replace + and / to _ and - because + and / are processed by URLSearchParams
|
||||
const processedSearch = url.search.replaceAll('+', '_').replaceAll('/', '-')
|
||||
const params = new URLSearchParams(processedSearch)
|
||||
const data = params.get('data')
|
||||
if (data) {
|
||||
const stringify = Buffer.from(data, 'base64').toString('utf8')
|
||||
Logger.info('get api keys from urlschema: ', stringify)
|
||||
const jsonConfig = JSON.parse(stringify)
|
||||
Logger.info('get api keys from urlschema: ', jsonConfig)
|
||||
const mainWindow = windowService.getMainWindow()
|
||||
if (mainWindow && !mainWindow.isDestroyed()) {
|
||||
mainWindow.webContents.send(IpcChannel.Provider_AddKey, jsonConfig)
|
||||
mainWindow.webContents.executeJavaScript(`window.navigate('/settings/provider?id=${jsonConfig.id}')`)
|
||||
}
|
||||
const mainWindow = windowService.getMainWindow()
|
||||
const version = params.get('v')
|
||||
if (version == '1') {
|
||||
// TODO: handle different version
|
||||
Logger.info('handleProvidersProtocolUrl', { data, version })
|
||||
}
|
||||
|
||||
// add check there is window.navigate function in mainWindow
|
||||
if (
|
||||
mainWindow &&
|
||||
!mainWindow.isDestroyed() &&
|
||||
(await mainWindow.webContents.executeJavaScript(`typeof window.navigate === 'function'`))
|
||||
) {
|
||||
mainWindow.webContents.executeJavaScript(`window.navigate('/settings/provider?addProviderData=${data}')`)
|
||||
} else {
|
||||
Logger.error('No data found in URL')
|
||||
setTimeout(() => {
|
||||
handleProvidersProtocolUrl(url)
|
||||
}, 1000)
|
||||
}
|
||||
break
|
||||
}
|
||||
default:
|
||||
console.error(`Unknown MCP protocol URL: ${url}`)
|
||||
Logger.error(`Unknown MCP protocol URL: ${url}`)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
@ -92,6 +92,7 @@ describe('file', () => {
|
||||
it('should return DOCUMENT for document extensions', () => {
|
||||
expect(getFileType('.pdf')).toBe(FileTypes.DOCUMENT)
|
||||
expect(getFileType('.pptx')).toBe(FileTypes.DOCUMENT)
|
||||
expect(getFileType('.doc')).toBe(FileTypes.DOCUMENT)
|
||||
expect(getFileType('.docx')).toBe(FileTypes.DOCUMENT)
|
||||
expect(getFileType('.xlsx')).toBe(FileTypes.DOCUMENT)
|
||||
expect(getFileType('.odt')).toBe(FileTypes.DOCUMENT)
|
||||
|
||||
@ -2,12 +2,26 @@ import * as fs from 'node:fs'
|
||||
import os from 'node:os'
|
||||
import path from 'node:path'
|
||||
|
||||
import { isMac } from '@main/constant'
|
||||
import { isPortable } from '@main/constant'
|
||||
import { audioExts, documentExts, imageExts, textExts, videoExts } from '@shared/config/constant'
|
||||
import { FileType, FileTypes } from '@types'
|
||||
import { app } from 'electron'
|
||||
import { v4 as uuidv4 } from 'uuid'
|
||||
|
||||
export function initAppDataDir() {
|
||||
const appDataPath = getAppDataPathFromConfig()
|
||||
if (appDataPath) {
|
||||
app.setPath('userData', appDataPath)
|
||||
return
|
||||
}
|
||||
|
||||
if (isPortable) {
|
||||
const portableDir = process.env.PORTABLE_EXECUTABLE_DIR
|
||||
app.setPath('userData', path.join(portableDir || app.getPath('exe'), 'data'))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 创建文件类型映射表,提高查找效率
|
||||
const fileTypeMap = new Map<string, FileTypes>()
|
||||
|
||||
@ -23,6 +37,85 @@ function initFileTypeMap() {
|
||||
// 初始化映射表
|
||||
initFileTypeMap()
|
||||
|
||||
export function hasWritePermission(path: string) {
|
||||
try {
|
||||
fs.accessSync(path, fs.constants.W_OK)
|
||||
return true
|
||||
} catch (error) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
function getAppDataPathFromConfig() {
|
||||
try {
|
||||
const configPath = path.join(getConfigDir(), 'config.json')
|
||||
if (!fs.existsSync(configPath)) {
|
||||
return null
|
||||
}
|
||||
|
||||
const config = JSON.parse(fs.readFileSync(configPath, 'utf-8'))
|
||||
|
||||
if (!config.appDataPath) {
|
||||
return null
|
||||
}
|
||||
|
||||
let appDataPath = null
|
||||
// 兼容旧版本
|
||||
if (config.appDataPath && typeof config.appDataPath === 'string') {
|
||||
appDataPath = config.appDataPath
|
||||
// 将旧版本数据迁移到新版本
|
||||
appDataPath && updateAppDataConfig(appDataPath)
|
||||
} else {
|
||||
appDataPath = config.appDataPath.find(
|
||||
(item: { executablePath: string }) => item.executablePath === app.getPath('exe')
|
||||
)?.dataPath
|
||||
}
|
||||
|
||||
if (appDataPath && fs.existsSync(appDataPath) && hasWritePermission(appDataPath)) {
|
||||
return appDataPath
|
||||
}
|
||||
|
||||
return null
|
||||
} catch (error) {
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
export function updateAppDataConfig(appDataPath: string) {
|
||||
const configDir = getConfigDir()
|
||||
if (!fs.existsSync(configDir)) {
|
||||
fs.mkdirSync(configDir, { recursive: true })
|
||||
}
|
||||
|
||||
// config.json
|
||||
// appDataPath: [{ executablePath: string, dataPath: string }]
|
||||
const configPath = path.join(getConfigDir(), 'config.json')
|
||||
if (!fs.existsSync(configPath)) {
|
||||
fs.writeFileSync(
|
||||
configPath,
|
||||
JSON.stringify({ appDataPath: [{ executablePath: app.getPath('exe'), dataPath: appDataPath }] }, null, 2)
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
const config = JSON.parse(fs.readFileSync(configPath, 'utf-8'))
|
||||
if (!config.appDataPath || (config.appDataPath && typeof config.appDataPath !== 'object')) {
|
||||
config.appDataPath = []
|
||||
}
|
||||
|
||||
const existingPath = config.appDataPath.find(
|
||||
(item: { executablePath: string }) => item.executablePath === app.getPath('exe')
|
||||
)
|
||||
|
||||
if (existingPath) {
|
||||
existingPath.dataPath = appDataPath
|
||||
} else {
|
||||
config.appDataPath.push({ executablePath: app.getPath('exe'), dataPath: appDataPath })
|
||||
}
|
||||
|
||||
fs.writeFileSync(configPath, JSON.stringify(config, null, 2))
|
||||
}
|
||||
|
||||
export function getFileType(ext: string): FileTypes {
|
||||
ext = ext.toLowerCase()
|
||||
return fileTypeMap.get(ext) || FileTypes.OTHER
|
||||
@ -88,12 +181,3 @@ export function getCacheDir() {
|
||||
export function getAppConfigDir(name: string) {
|
||||
return path.join(getConfigDir(), name)
|
||||
}
|
||||
|
||||
export function setUserDataDir() {
|
||||
if (!isMac) {
|
||||
const dir = path.join(path.dirname(app.getPath('exe')), 'data')
|
||||
if (fs.existsSync(dir) && fs.statSync(dir).isDirectory()) {
|
||||
app.setPath('userData', dir)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
92
src/main/utils/systemInfo.ts
Normal file
92
src/main/utils/systemInfo.ts
Normal file
@ -0,0 +1,92 @@
|
||||
import { app } from 'electron'
|
||||
import macosRelease from 'macos-release'
|
||||
import os from 'os'
|
||||
|
||||
/**
|
||||
* System information interface
|
||||
*/
|
||||
export interface SystemInfo {
|
||||
platform: NodeJS.Platform
|
||||
arch: string
|
||||
osRelease: string
|
||||
appVersion: string
|
||||
osString: string
|
||||
archString: string
|
||||
}
|
||||
|
||||
/**
|
||||
* Get basic system constants for quick access
|
||||
* @returns {Object} Basic system constants
|
||||
*/
|
||||
export function getSystemConstants() {
|
||||
return {
|
||||
platform: process.platform,
|
||||
arch: process.arch,
|
||||
osRelease: os.release(),
|
||||
appVersion: app.getVersion()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get system information
|
||||
* @returns {SystemInfo} Complete system information object
|
||||
*/
|
||||
export function getSystemInfo(): SystemInfo {
|
||||
const platform = process.platform
|
||||
const arch = process.arch
|
||||
const osRelease = os.release()
|
||||
const appVersion = app.getVersion()
|
||||
|
||||
let osString = ''
|
||||
|
||||
switch (platform) {
|
||||
case 'win32': {
|
||||
// Get Windows version
|
||||
const parts = osRelease.split('.')
|
||||
const buildNumber = parseInt(parts[2], 10)
|
||||
osString = buildNumber >= 22000 ? 'Windows 11' : 'Windows 10'
|
||||
break
|
||||
}
|
||||
case 'darwin': {
|
||||
// macOS version handling using macos-release for better accuracy
|
||||
try {
|
||||
const macVersionInfo = macosRelease()
|
||||
const versionString = macVersionInfo.version.replace(/\./g, '_') // 15.0.0 -> 15_0_0
|
||||
osString = arch === 'arm64' ? `Mac OS X ${versionString}` : `Intel Mac OS X ${versionString}` // Mac OS X 15_0_0
|
||||
} catch (error) {
|
||||
// Fallback to original logic if macos-release fails
|
||||
const macVersion = osRelease.split('.').slice(0, 2).join('_')
|
||||
osString = arch === 'arm64' ? `Mac OS X ${macVersion}` : `Intel Mac OS X ${macVersion}`
|
||||
}
|
||||
break
|
||||
}
|
||||
case 'linux': {
|
||||
osString = `Linux ${arch}`
|
||||
break
|
||||
}
|
||||
default: {
|
||||
osString = `${platform} ${arch}`
|
||||
}
|
||||
}
|
||||
|
||||
const archString = arch === 'x64' ? 'x86_64' : arch === 'arm64' ? 'arm64' : arch
|
||||
|
||||
return {
|
||||
platform,
|
||||
arch,
|
||||
osRelease,
|
||||
appVersion,
|
||||
osString,
|
||||
archString
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate User-Agent string based on user system data
|
||||
* @returns {string} Dynamically generated User-Agent string
|
||||
*/
|
||||
export function generateUserAgent(): string {
|
||||
const systemInfo = getSystemInfo()
|
||||
|
||||
return `Mozilla/5.0 (${systemInfo.osString}; ${systemInfo.archString}) AppleWebKit/537.36 (KHTML, like Gecko) CherryStudio/${systemInfo.appVersion} Chrome/124.0.0.0 Safari/537.36`
|
||||
}
|
||||
@ -1,7 +1,8 @@
|
||||
import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
|
||||
import { electronAPI } from '@electron-toolkit/preload'
|
||||
import { UpgradeChannel } from '@shared/config/constant'
|
||||
import { IpcChannel } from '@shared/IpcChannel'
|
||||
import { FileType, KnowledgeBaseParams, KnowledgeItem, MCPServer, Shortcut, WebDavConfig } from '@types'
|
||||
import { FileType, KnowledgeBaseParams, KnowledgeItem, MCPServer, Shortcut, ThemeMode, WebDavConfig } from '@types'
|
||||
import { contextBridge, ipcRenderer, OpenDialogOptions, shell, webUtils } from 'electron'
|
||||
import { Notification } from 'src/renderer/src/types/notification'
|
||||
import { CreateDirectoryOptions } from 'webdav'
|
||||
@ -16,15 +17,28 @@ const api = {
|
||||
checkForUpdate: () => ipcRenderer.invoke(IpcChannel.App_CheckForUpdate),
|
||||
showUpdateDialog: () => ipcRenderer.invoke(IpcChannel.App_ShowUpdateDialog),
|
||||
setLanguage: (lang: string) => ipcRenderer.invoke(IpcChannel.App_SetLanguage, lang),
|
||||
setEnableSpellCheck: (isEnable: boolean) => ipcRenderer.invoke(IpcChannel.App_SetEnableSpellCheck, isEnable),
|
||||
setSpellCheckLanguages: (languages: string[]) => ipcRenderer.invoke(IpcChannel.App_SetSpellCheckLanguages, languages),
|
||||
setLaunchOnBoot: (isActive: boolean) => ipcRenderer.invoke(IpcChannel.App_SetLaunchOnBoot, isActive),
|
||||
setLaunchToTray: (isActive: boolean) => ipcRenderer.invoke(IpcChannel.App_SetLaunchToTray, isActive),
|
||||
setTray: (isActive: boolean) => ipcRenderer.invoke(IpcChannel.App_SetTray, isActive),
|
||||
setTrayOnClose: (isActive: boolean) => ipcRenderer.invoke(IpcChannel.App_SetTrayOnClose, isActive),
|
||||
restartTray: () => ipcRenderer.invoke(IpcChannel.App_RestartTray),
|
||||
setTheme: (theme: 'light' | 'dark' | 'auto') => ipcRenderer.invoke(IpcChannel.App_SetTheme, theme),
|
||||
setTestPlan: (isActive: boolean) => ipcRenderer.invoke(IpcChannel.App_SetTestPlan, isActive),
|
||||
setTestChannel: (channel: UpgradeChannel) => ipcRenderer.invoke(IpcChannel.App_SetTestChannel, channel),
|
||||
setTheme: (theme: ThemeMode) => ipcRenderer.invoke(IpcChannel.App_SetTheme, theme),
|
||||
handleZoomFactor: (delta: number, reset: boolean = false) =>
|
||||
ipcRenderer.invoke(IpcChannel.App_HandleZoomFactor, delta, reset),
|
||||
setAutoUpdate: (isActive: boolean) => ipcRenderer.invoke(IpcChannel.App_SetAutoUpdate, isActive),
|
||||
select: (options: Electron.OpenDialogOptions) => ipcRenderer.invoke(IpcChannel.App_Select, options),
|
||||
hasWritePermission: (path: string) => ipcRenderer.invoke(IpcChannel.App_HasWritePermission, path),
|
||||
setAppDataPath: (path: string) => ipcRenderer.invoke(IpcChannel.App_SetAppDataPath, path),
|
||||
getDataPathFromArgs: () => ipcRenderer.invoke(IpcChannel.App_GetDataPathFromArgs),
|
||||
copy: (oldPath: string, newPath: string, occupiedDirs: string[] = []) =>
|
||||
ipcRenderer.invoke(IpcChannel.App_Copy, oldPath, newPath, occupiedDirs),
|
||||
setStopQuitApp: (stop: boolean, reason: string) => ipcRenderer.invoke(IpcChannel.App_SetStopQuitApp, stop, reason),
|
||||
flushAppData: () => ipcRenderer.invoke(IpcChannel.App_FlushAppData),
|
||||
isNotEmptyDir: (path: string) => ipcRenderer.invoke(IpcChannel.App_IsNotEmptyDir, path),
|
||||
relaunchApp: (options?: Electron.RelaunchOptions) => ipcRenderer.invoke(IpcChannel.App_RelaunchApp, options),
|
||||
openWebsite: (url: string) => ipcRenderer.invoke(IpcChannel.Open_Website, url),
|
||||
getCacheSize: () => ipcRenderer.invoke(IpcChannel.App_GetCacheSize),
|
||||
clearCache: () => ipcRenderer.invoke(IpcChannel.App_ClearCache),
|
||||
@ -76,14 +90,17 @@ const api = {
|
||||
selectFolder: () => ipcRenderer.invoke(IpcChannel.File_SelectFolder),
|
||||
saveImage: (name: string, data: string) => ipcRenderer.invoke(IpcChannel.File_SaveImage, name, data),
|
||||
base64Image: (fileId: string) => ipcRenderer.invoke(IpcChannel.File_Base64Image, fileId),
|
||||
download: (url: string, isUseContentType?: boolean) => ipcRenderer.invoke(IpcChannel.File_Download, url, isUseContentType),
|
||||
saveBase64Image: (data: string) => ipcRenderer.invoke(IpcChannel.File_SaveBase64Image, data),
|
||||
download: (url: string, isUseContentType?: boolean) =>
|
||||
ipcRenderer.invoke(IpcChannel.File_Download, url, isUseContentType),
|
||||
copy: (fileId: string, destPath: string) => ipcRenderer.invoke(IpcChannel.File_Copy, fileId, destPath),
|
||||
binaryImage: (fileId: string) => ipcRenderer.invoke(IpcChannel.File_BinaryImage, fileId),
|
||||
base64File: (fileId: string) => ipcRenderer.invoke(IpcChannel.File_Base64File, fileId),
|
||||
pdfInfo: (fileId: string) => ipcRenderer.invoke(IpcChannel.File_GetPdfInfo, fileId),
|
||||
getPathForFile: (file: File) => webUtils.getPathForFile(file)
|
||||
},
|
||||
fs: {
|
||||
read: (path: string) => ipcRenderer.invoke(IpcChannel.Fs_Read, path)
|
||||
read: (pathOrUrl: string, encoding?: BufferEncoding) => ipcRenderer.invoke(IpcChannel.Fs_Read, pathOrUrl, encoding)
|
||||
},
|
||||
export: {
|
||||
toWord: (markdown: string, fileName: string) => ipcRenderer.invoke(IpcChannel.Export_Word, markdown, fileName)
|
||||
@ -125,8 +142,16 @@ const api = {
|
||||
listFiles: (apiKey: string) => ipcRenderer.invoke(IpcChannel.Gemini_ListFiles, apiKey),
|
||||
deleteFile: (fileId: string, apiKey: string) => ipcRenderer.invoke(IpcChannel.Gemini_DeleteFile, fileId, apiKey)
|
||||
},
|
||||
|
||||
vertexAI: {
|
||||
getAuthHeaders: (params: { projectId: string; serviceAccount?: { privateKey: string; clientEmail: string } }) =>
|
||||
ipcRenderer.invoke(IpcChannel.VertexAI_GetAuthHeaders, params),
|
||||
clearAuthCache: (projectId: string, clientEmail?: string) =>
|
||||
ipcRenderer.invoke(IpcChannel.VertexAI_ClearAuthCache, projectId, clientEmail)
|
||||
},
|
||||
config: {
|
||||
set: (key: string, value: any) => ipcRenderer.invoke(IpcChannel.Config_Set, key, value),
|
||||
set: (key: string, value: any, isNotify: boolean = false) =>
|
||||
ipcRenderer.invoke(IpcChannel.Config_Set, key, value, isNotify),
|
||||
get: (key: string) => ipcRenderer.invoke(IpcChannel.Config_Get, key)
|
||||
},
|
||||
miniWindow: {
|
||||
@ -158,6 +183,10 @@ const api = {
|
||||
getInstallInfo: () => ipcRenderer.invoke(IpcChannel.Mcp_GetInstallInfo),
|
||||
checkMcpConnectivity: (server: any) => ipcRenderer.invoke(IpcChannel.Mcp_CheckConnectivity, server)
|
||||
},
|
||||
python: {
|
||||
execute: (script: string, context?: Record<string, any>, timeout?: number) =>
|
||||
ipcRenderer.invoke(IpcChannel.Python_Execute, script, context, timeout)
|
||||
},
|
||||
shell: {
|
||||
openExternal: (url: string, options?: Electron.OpenExternalOptions) => shell.openExternal(url, options)
|
||||
},
|
||||
@ -200,7 +229,9 @@ const api = {
|
||||
},
|
||||
webview: {
|
||||
setOpenLinkExternal: (webviewId: number, isExternal: boolean) =>
|
||||
ipcRenderer.invoke(IpcChannel.Webview_SetOpenLinkExternal, webviewId, isExternal)
|
||||
ipcRenderer.invoke(IpcChannel.Webview_SetOpenLinkExternal, webviewId, isExternal),
|
||||
setSpellCheckEnabled: (webviewId: number, isEnable: boolean) =>
|
||||
ipcRenderer.invoke(IpcChannel.Webview_SetSpellCheckEnabled, webviewId, isEnable)
|
||||
},
|
||||
storeSync: {
|
||||
subscribe: () => ipcRenderer.invoke(IpcChannel.StoreSync_Subscribe),
|
||||
@ -216,11 +247,16 @@ const api = {
|
||||
setTriggerMode: (triggerMode: string) => ipcRenderer.invoke(IpcChannel.Selection_SetTriggerMode, triggerMode),
|
||||
setFollowToolbar: (isFollowToolbar: boolean) =>
|
||||
ipcRenderer.invoke(IpcChannel.Selection_SetFollowToolbar, isFollowToolbar),
|
||||
setRemeberWinSize: (isRemeberWinSize: boolean) =>
|
||||
ipcRenderer.invoke(IpcChannel.Selection_SetRemeberWinSize, isRemeberWinSize),
|
||||
setFilterMode: (filterMode: string) => ipcRenderer.invoke(IpcChannel.Selection_SetFilterMode, filterMode),
|
||||
setFilterList: (filterList: string[]) => ipcRenderer.invoke(IpcChannel.Selection_SetFilterList, filterList),
|
||||
processAction: (actionItem: ActionItem) => ipcRenderer.invoke(IpcChannel.Selection_ProcessAction, actionItem),
|
||||
closeActionWindow: () => ipcRenderer.invoke(IpcChannel.Selection_ActionWindowClose),
|
||||
minimizeActionWindow: () => ipcRenderer.invoke(IpcChannel.Selection_ActionWindowMinimize),
|
||||
pinActionWindow: (isPinned: boolean) => ipcRenderer.invoke(IpcChannel.Selection_ActionWindowPin, isPinned)
|
||||
}
|
||||
},
|
||||
quoteToMainWindow: (text: string) => ipcRenderer.invoke(IpcChannel.App_QuoteToMain, text)
|
||||
}
|
||||
|
||||
// Use `contextBridge` APIs to expose Electron APIs to
|
||||
|
||||
@ -2,42 +2,45 @@
|
||||
<html lang="zh-CN">
|
||||
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="initial-scale=1, width=device-width" />
|
||||
<meta http-equiv="Content-Security-Policy"
|
||||
content="default-src 'self'; connect-src blob: *; script-src 'self' 'unsafe-eval' *; worker-src 'self' blob:; style-src 'self' 'unsafe-inline' *; font-src 'self' data: *; img-src 'self' data: file: * blob:; frame-src * file:" />
|
||||
<title>Cherry Studio Selection Toolbar</title>
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="initial-scale=1, width=device-width" />
|
||||
<meta http-equiv="Content-Security-Policy"
|
||||
content="default-src 'self'; connect-src blob: *; script-src 'self' 'unsafe-eval' *; worker-src 'self' blob:; style-src 'self' 'unsafe-inline' *; font-src 'self' data: *; img-src 'self' data: file: * blob:; frame-src * file:" />
|
||||
<title>Cherry Studio Selection Toolbar</title>
|
||||
|
||||
</head>
|
||||
|
||||
<body>
|
||||
<div id="root"></div>
|
||||
<script type="module" src="/src/windows/selection/toolbar/entryPoint.tsx"></script>
|
||||
<style>
|
||||
html {
|
||||
margin: 0;
|
||||
}
|
||||
<div id="root"></div>
|
||||
<script type="module" src="/src/windows/selection/toolbar/entryPoint.tsx"></script>
|
||||
<style>
|
||||
html {
|
||||
margin: 0 !important;
|
||||
background-color: transparent !important;
|
||||
background-image: none !important;
|
||||
|
||||
body {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
overflow: hidden;
|
||||
width: 100vw;
|
||||
height: 100vh;
|
||||
}
|
||||
|
||||
-webkit-user-select: none;
|
||||
-moz-user-select: none;
|
||||
-ms-user-select: none;
|
||||
user-select: none;
|
||||
}
|
||||
body {
|
||||
margin: 0 !important;
|
||||
padding: 0 !important;
|
||||
overflow: hidden !important;
|
||||
width: 100vw !important;
|
||||
height: 100vh !important;
|
||||
|
||||
#root {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
width: max-content !important;
|
||||
height: fit-content !important;
|
||||
}
|
||||
</style>
|
||||
-webkit-user-select: none;
|
||||
-moz-user-select: none;
|
||||
-ms-user-select: none;
|
||||
user-select: none;
|
||||
}
|
||||
|
||||
#root {
|
||||
margin: 0 !important;
|
||||
padding: 0 !important;
|
||||
width: max-content !important;
|
||||
height: fit-content !important;
|
||||
}
|
||||
</style>
|
||||
</body>
|
||||
|
||||
</html>
|
||||
223
src/renderer/src/aiCore/AI_CORE_DESIGN.md
Normal file
223
src/renderer/src/aiCore/AI_CORE_DESIGN.md
Normal file
@ -0,0 +1,223 @@
|
||||
# Cherry Studio AI Provider 技术架构文档 (新方案)
|
||||
|
||||
## 1. 核心设计理念与目标
|
||||
|
||||
本架构旨在重构 Cherry Studio 的 AI Provider(现称为 `aiCore`)层,以实现以下目标:
|
||||
|
||||
- **职责清晰**:明确划分各组件的职责,降低耦合度。
|
||||
- **高度复用**:最大化业务逻辑和通用处理逻辑的复用,减少重复代码。
|
||||
- **易于扩展**:方便快捷地接入新的 AI Provider (LLM供应商) 和添加新的 AI 功能 (如翻译、摘要、图像生成等)。
|
||||
- **易于维护**:简化单个组件的复杂性,提高代码的可读性和可维护性。
|
||||
- **标准化**:统一内部数据流和接口,简化不同 Provider 之间的差异处理。
|
||||
|
||||
核心思路是将纯粹的 **SDK 适配层 (`XxxApiClient`)**、**通用逻辑处理与智能解析层 (中间件)** 以及 **统一业务功能入口层 (`AiCoreService`)** 清晰地分离开来。
|
||||
|
||||
## 2. 核心组件详解
|
||||
|
||||
### 2.1. `aiCore` (原 `AiProvider` 文件夹)
|
||||
|
||||
这是整个 AI 功能的核心模块。
|
||||
|
||||
#### 2.1.1. `XxxApiClient` (例如 `aiCore/clients/openai/OpenAIApiClient.ts`)
|
||||
|
||||
- **职责**:作为特定 AI Provider SDK 的纯粹适配层。
|
||||
- **参数适配**:将应用内部统一的 `CoreRequest` 对象 (见下文) 转换为特定 SDK 所需的请求参数格式。
|
||||
- **基础响应转换**:将 SDK 返回的原始数据块 (`RawSdkChunk`,例如 `OpenAI.Chat.Completions.ChatCompletionChunk`) 转换为一组最基础、最直接的应用层 `Chunk` 对象 (定义于 `src/renderer/src/types/chunk.ts`)。
|
||||
- 例如:SDK 的 `delta.content` -> `TextDeltaChunk`;SDK 的 `delta.reasoning_content` -> `ThinkingDeltaChunk`;SDK 的 `delta.tool_calls` -> `RawToolCallChunk` (包含原始工具调用数据)。
|
||||
- **关键**:`XxxApiClient` **不处理**耦合在文本内容中的复杂结构,如 `<think>` 或 `<tool_use>` 标签。
|
||||
- **特点**:极度轻量化,代码量少,易于实现和维护新的 Provider 适配。
|
||||
|
||||
#### 2.1.2. `ApiClient.ts` (或 `BaseApiClient.ts` 的核心接口)
|
||||
|
||||
- 定义了所有 `XxxApiClient` 必须实现的接口,如:
|
||||
- `getSdkInstance(): Promise<TSdkInstance> | TSdkInstance`
|
||||
- `getRequestTransformer(): RequestTransformer<TSdkParams>`
|
||||
- `getResponseChunkTransformer(): ResponseChunkTransformer<TRawChunk, TResponseContext>`
|
||||
- 其他可选的、与特定 Provider 相关的辅助方法 (如工具调用转换)。
|
||||
|
||||
#### 2.1.3. `ApiClientFactory.ts`
|
||||
|
||||
- 根据 Provider 配置动态创建和返回相应的 `XxxApiClient` 实例。
|
||||
|
||||
#### 2.1.4. `AiCoreService.ts` (`aiCore/index.ts`)
|
||||
|
||||
- **职责**:作为所有 AI 相关业务功能的统一入口。
|
||||
- 提供面向应用的高层接口,例如:
|
||||
- `executeCompletions(params: CompletionsParams): Promise<AggregatedCompletionsResult>`
|
||||
- `translateText(params: TranslateParams): Promise<AggregatedTranslateResult>`
|
||||
- `summarizeText(params: SummarizeParams): Promise<AggregatedSummarizeResult>`
|
||||
- 未来可能的 `generateImage(prompt: string): Promise<ImageResult>` 等。
|
||||
- **返回 `Promise`**:每个服务方法返回一个 `Promise`,该 `Promise` 会在整个(可能是流式的)操作完成后,以包含所有聚合结果(如完整文本、工具调用详情、最终的`usage`/`metrics`等)的对象来 `resolve`。
|
||||
- **支持流式回调**:服务方法的参数 (如 `CompletionsParams`) 依然包含 `onChunk` 回调,用于向调用方实时推送处理过程中的 `Chunk` 数据,实现流式UI更新。
|
||||
- **封装特定任务的提示工程 (Prompt Engineering)**:
|
||||
- 例如,`translateText` 方法内部会构建一个包含特定翻译指令的 `CoreRequest`。
|
||||
- **编排和调用中间件链**:通过内部的 `MiddlewareBuilder` (参见 `middleware/BUILDER_USAGE.md`) 实例,根据调用的业务方法和参数,动态构建和组织合适的中间件序列,然后通过 `applyCompletionsMiddlewares` 等组合函数执行。
|
||||
- 获取 `ApiClient` 实例并将其注入到中间件上游的 `Context` 中。
|
||||
- **将 `Promise` 的 `resolve` 和 `reject` 函数传递给中间件链** (通过 `Context`),以便 `FinalChunkConsumerAndNotifierMiddleware` 可以在操作完成或发生错误时结束该 `Promise`。
|
||||
- **优势**:
|
||||
- 业务逻辑(如翻译、摘要的提示构建和流程控制)只需实现一次,即可支持所有通过 `ApiClient` 接入的底层 Provider。
|
||||
- **支持外部编排**:调用方可以 `await` 服务方法以获取最终聚合结果,然后将此结果作为后续操作的输入,轻松实现多步骤工作流。
|
||||
- **支持内部组合**:服务自身也可以通过 `await` 调用其他原子服务方法来构建更复杂的组合功能。
|
||||
|
||||
#### 2.1.5. `coreRequestTypes.ts` (或 `types.ts`)
|
||||
|
||||
- 定义核心的、Provider 无关的内部请求结构,例如:
|
||||
- `CoreCompletionsRequest`: 包含标准化后的消息列表、模型配置、工具列表、最大Token数、是否流式输出等。
|
||||
- `CoreTranslateRequest`, `CoreSummarizeRequest` 等 (如果与 `CoreCompletionsRequest` 结构差异较大,否则可复用并添加任务类型标记)。
|
||||
|
||||
### 2.2. `middleware`
|
||||
|
||||
中间件层负责处理请求和响应流中的通用逻辑和特定特性。其设计和使用遵循 `middleware/BUILDER_USAGE.md` 中定义的规范。
|
||||
|
||||
**核心组件包括:**
|
||||
|
||||
- **`MiddlewareBuilder`**: 一个通用的、提供流式API的类,用于动态构建中间件链。它支持从基础链开始,根据条件添加、插入、替换或移除中间件。
|
||||
- **`applyCompletionsMiddlewares`**: 负责接收 `MiddlewareBuilder` 构建的链并按顺序执行,专门用于 Completions 流程。
|
||||
- **`MiddlewareRegistry`**: 集中管理所有可用中间件的注册表,提供统一的中间件访问接口。
|
||||
- **各种独立的中间件模块** (存放于 `common/`, `core/`, `feat/` 子目录)。
|
||||
|
||||
#### 2.2.1. `middlewareTypes.ts`
|
||||
|
||||
- 定义中间件的核心类型,如 `AiProviderMiddlewareContext` (扩展后包含 `_apiClientInstance` 和 `_coreRequest`)、`MiddlewareAPI`、`CompletionsMiddleware` 等。
|
||||
|
||||
#### 2.2.2. 核心中间件 (`middleware/core/`)
|
||||
|
||||
- **`TransformCoreToSdkParamsMiddleware.ts`**: 调用 `ApiClient.getRequestTransformer()` 将 `CoreRequest` 转换为特定 SDK 的参数,并存入上下文。
|
||||
- **`RequestExecutionMiddleware.ts`**: 调用 `ApiClient.getSdkInstance()` 获取 SDK 实例,并使用转换后的参数执行实际的 API 调用,返回原始 SDK 流。
|
||||
- **`StreamAdapterMiddleware.ts`**: 将各种形态的原始 SDK 流 (如异步迭代器) 统一适配为 `ReadableStream<RawSdkChunk>`。
|
||||
- **`RawSdkChunk`**:指特定AI提供商SDK在流式响应中返回的、未经应用层统一处理的原始数据块格式 (例如 OpenAI 的 `ChatCompletionChunk`,Gemini 的 `GenerateContentResponse` 中的部分等)。
|
||||
- **`RawSdkChunkToAppChunkMiddleware.ts`**: (新增) 消费 `ReadableStream<RawSdkChunk>`,在其内部对每个 `RawSdkChunk` 调用 `ApiClient.getResponseChunkTransformer()`,将其转换为一个或多个基础的应用层 `Chunk` 对象,并输出 `ReadableStream<Chunk>`。
|
||||
|
||||
#### 2.2.3. 特性中间件 (`middleware/feat/`)
|
||||
|
||||
这些中间件消费由 `ResponseTransformMiddleware` 输出的、相对标准化的 `Chunk` 流,并处理更复杂的逻辑。
|
||||
|
||||
- **`ThinkingTagExtractionMiddleware.ts`**: 检查 `TextDeltaChunk`,解析其中可能包含的 `<think>...</think>` 文本内嵌标签,生成 `ThinkingDeltaChunk` 和 `ThinkingCompleteChunk`。
|
||||
- **`ToolUseExtractionMiddleware.ts`**: 检查 `TextDeltaChunk`,解析其中可能包含的 `<tool_use>...</tool_use>` 文本内嵌标签,生成工具调用相关的 Chunk。如果 `ApiClient` 输出了原生工具调用数据,此中间件也负责将其转换为标准格式。
|
||||
|
||||
#### 2.2.4. 核心处理中间件 (`middleware/core/`)
|
||||
|
||||
- **`TransformCoreToSdkParamsMiddleware.ts`**: 调用 `ApiClient.getRequestTransformer()` 将 `CoreRequest` 转换为特定 SDK 的参数,并存入上下文。
|
||||
- **`SdkCallMiddleware.ts`**: 调用 `ApiClient.getSdkInstance()` 获取 SDK 实例,并使用转换后的参数执行实际的 API 调用,返回原始 SDK 流。
|
||||
- **`StreamAdapterMiddleware.ts`**: 将各种形态的原始 SDK 流统一适配为标准流格式。
|
||||
- **`ResponseTransformMiddleware.ts`**: 将原始 SDK 响应转换为应用层标准 `Chunk` 对象。
|
||||
- **`TextChunkMiddleware.ts`**: 处理文本相关的 Chunk 流。
|
||||
- **`ThinkChunkMiddleware.ts`**: 处理思考相关的 Chunk 流。
|
||||
- **`McpToolChunkMiddleware.ts`**: 处理工具调用相关的 Chunk 流。
|
||||
- **`WebSearchMiddleware.ts`**: 处理 Web 搜索相关逻辑。
|
||||
|
||||
#### 2.2.5. 通用中间件 (`middleware/common/`)
|
||||
|
||||
- **`LoggingMiddleware.ts`**: 请求和响应日志。
|
||||
- **`AbortHandlerMiddleware.ts`**: 处理请求中止。
|
||||
- **`FinalChunkConsumerMiddleware.ts`**: 消费最终的 `Chunk` 流,通过 `context.onChunk` 回调通知应用层实时数据。
|
||||
- **累积数据**:在流式处理过程中,累积关键数据,如文本片段、工具调用信息、`usage`/`metrics` 等。
|
||||
- **结束 `Promise`**:当输入流结束时,使用累积的聚合结果来完成整个处理流程。
|
||||
- 在流结束时,发送包含最终累加信息的完成信号。
|
||||
|
||||
### 2.3. `types/chunk.ts`
|
||||
|
||||
- 定义应用全局统一的 `Chunk` 类型及其所有变体。这包括基础类型 (如 `TextDeltaChunk`, `ThinkingDeltaChunk`)、SDK原生数据传递类型 (如 `RawToolCallChunk`, `RawFinishChunk` - 作为 `ApiClient` 转换的中间产物),以及功能性类型 (如 `McpToolCallRequestChunk`, `WebSearchCompleteChunk`)。
|
||||
|
||||
## 3. 核心执行流程 (以 `AiCoreService.executeCompletions` 为例)
|
||||
|
||||
```markdown
|
||||
**应用层 (例如 UI 组件)**
|
||||
||
|
||||
\\/
|
||||
**`AiProvider.completions` (`aiCore/index.ts`)**
|
||||
(1. prepare ApiClient instance. 2. use `CompletionsMiddlewareBuilder.withDefaults()` to build middleware chain. 3. call `applyCompletionsMiddlewares`)
|
||||
||
|
||||
\\/
|
||||
**`applyCompletionsMiddlewares` (`middleware/composer.ts`)**
|
||||
(接收构建好的链、ApiClient实例、原始SDK方法,开始按序执行中间件)
|
||||
||
|
||||
\\/
|
||||
**[ 预处理阶段中间件 ]**
|
||||
(例如: `FinalChunkConsumerMiddleware`, `TransformCoreToSdkParamsMiddleware`, `AbortHandlerMiddleware`)
|
||||
|| (Context 中准备好 SDK 请求参数)
|
||||
\\/
|
||||
**[ 处理阶段中间件 ]**
|
||||
(例如: `McpToolChunkMiddleware`, `WebSearchMiddleware`, `TextChunkMiddleware`, `ThinkingTagExtractionMiddleware`)
|
||||
|| (处理各种特性和Chunk类型)
|
||||
\\/
|
||||
**[ SDK调用阶段中间件 ]**
|
||||
(例如: `ResponseTransformMiddleware`, `StreamAdapterMiddleware`, `SdkCallMiddleware`)
|
||||
|| (输出: 标准化的应用层Chunk流)
|
||||
\\/
|
||||
**`FinalChunkConsumerMiddleware` (核心)**
|
||||
(消费最终的 `Chunk` 流, 通过 `context.onChunk` 回调通知应用层, 并在流结束时完成处理)
|
||||
||
|
||||
\\/
|
||||
**`AiProvider.completions` 返回 `Promise<CompletionsResult>`**
|
||||
```
|
||||
|
||||
## 4. 建议的文件/目录结构
|
||||
|
||||
```
|
||||
src/renderer/src/
|
||||
└── aiCore/
|
||||
├── clients/
|
||||
│ ├── openai/
|
||||
│ ├── gemini/
|
||||
│ ├── anthropic/
|
||||
│ ├── BaseApiClient.ts
|
||||
│ ├── ApiClientFactory.ts
|
||||
│ ├── AihubmixAPIClient.ts
|
||||
│ ├── index.ts
|
||||
│ └── types.ts
|
||||
├── middleware/
|
||||
│ ├── common/
|
||||
│ ├── core/
|
||||
│ ├── feat/
|
||||
│ ├── builder.ts
|
||||
│ ├── composer.ts
|
||||
│ ├── index.ts
|
||||
│ ├── register.ts
|
||||
│ ├── schemas.ts
|
||||
│ ├── types.ts
|
||||
│ └── utils.ts
|
||||
├── types/
|
||||
│ ├── chunk.ts
|
||||
│ └── ...
|
||||
└── index.ts
|
||||
```
|
||||
|
||||
## 5. 迁移和实施建议
|
||||
|
||||
- **小步快跑,逐步迭代**:优先完成核心流程的重构(例如 `completions`),再逐步迁移其他功能(`translate` 等)和其他 Provider。
|
||||
- **优先定义核心类型**:`CoreRequest`, `Chunk`, `ApiClient` 接口是整个架构的基石。
|
||||
- **为 `ApiClient` 瘦身**:将现有 `XxxProvider` 中的复杂逻辑剥离到新的中间件或 `AiCoreService` 中。
|
||||
- **强化中间件**:让中间件承担起更多解析和特性处理的责任。
|
||||
- **编写单元测试和集成测试**:确保每个组件和整体流程的正确性。
|
||||
|
||||
此架构旨在提供一个更健壮、更灵活、更易于维护的 AI 功能核心,支撑 Cherry Studio 未来的发展。
|
||||
|
||||
## 6. 迁移策略与实施建议
|
||||
|
||||
本节内容提炼自早期的 `migrate.md` 文档,并根据最新的架构讨论进行了调整。
|
||||
|
||||
**目标架构核心组件回顾:**
|
||||
|
||||
与第 2 节描述的核心组件一致,主要包括 `XxxApiClient`, `AiCoreService`, 中间件链, `CoreRequest` 类型, 和标准化的 `Chunk` 类型。
|
||||
|
||||
**迁移步骤:**
|
||||
|
||||
**Phase 0: 准备工作和类型定义**
|
||||
|
||||
1. **定义核心数据结构 (TypeScript 类型):**
|
||||
- `CoreCompletionsRequest` (Type):定义应用内部统一的对话请求结构。
|
||||
- `Chunk` (Type - 检查并按需扩展现有 `src/renderer/src/types/chunk.ts`):定义所有可能的通用Chunk类型。
|
||||
- 为其他API(翻译、总结)定义类似的 `CoreXxxRequest` (Type)。
|
||||
2. **定义 `ApiClient` 接口:** 明确 `getRequestTransformer`, `getResponseChunkTransformer`, `getSdkInstance` 等核心方法。
|
||||
3. **调整 `AiProviderMiddlewareContext`:**
|
||||
- 确保包含 `_apiClientInstance: ApiClient<any,any,any>`。
|
||||
- 确保包含 `_coreRequest: CoreRequestType`。
|
||||
- 考虑添加 `resolvePromise: (value: AggregatedResultType) => void` 和 `rejectPromise: (reason?: any) => void` 用于 `AiCoreService` 的 Promise 返回。
|
||||
|
||||
**Phase 1: 实现第一个 `ApiClient` (以 `OpenAIApiClient` 为例)**
|
||||
|
||||
1. **创建 `OpenAIApiClient` 类:** 实现 `ApiClient` 接口。
|
||||
2. **迁移SDK实例和配置。**
|
||||
3. **实现 `getRequestTransformer()`:** 将 `CoreCompletionsRequest` 转换为 OpenAI SDK 参数。
|
||||
4. **实现 `getResponseChunkTransformer()`:** 将 `OpenAI.Chat.Completions.ChatCompletionChunk` 转换为基础的 `
|
||||
223
src/renderer/src/aiCore/clients/AihubmixAPIClient.ts
Normal file
223
src/renderer/src/aiCore/clients/AihubmixAPIClient.ts
Normal file
@ -0,0 +1,223 @@
|
||||
import { isOpenAILLMModel } from '@renderer/config/models'
|
||||
import {
|
||||
GenerateImageParams,
|
||||
MCPCallToolResponse,
|
||||
MCPTool,
|
||||
MCPToolResponse,
|
||||
Model,
|
||||
Provider,
|
||||
ToolCallResponse
|
||||
} from '@renderer/types'
|
||||
import {
|
||||
RequestOptions,
|
||||
SdkInstance,
|
||||
SdkMessageParam,
|
||||
SdkModel,
|
||||
SdkParams,
|
||||
SdkRawChunk,
|
||||
SdkRawOutput,
|
||||
SdkTool,
|
||||
SdkToolCall
|
||||
} from '@renderer/types/sdk'
|
||||
|
||||
import { CompletionsContext } from '../middleware/types'
|
||||
import { AnthropicAPIClient } from './anthropic/AnthropicAPIClient'
|
||||
import { BaseApiClient } from './BaseApiClient'
|
||||
import { GeminiAPIClient } from './gemini/GeminiAPIClient'
|
||||
import { OpenAIAPIClient } from './openai/OpenAIApiClient'
|
||||
import { OpenAIResponseAPIClient } from './openai/OpenAIResponseAPIClient'
|
||||
import { RequestTransformer, ResponseChunkTransformer } from './types'
|
||||
|
||||
/**
|
||||
* AihubmixAPIClient - 根据模型类型自动选择合适的ApiClient
|
||||
* 使用装饰器模式实现,在ApiClient层面进行模型路由
|
||||
*/
|
||||
export class AihubmixAPIClient extends BaseApiClient {
|
||||
// 使用联合类型而不是any,保持类型安全
|
||||
private clients: Map<string, AnthropicAPIClient | GeminiAPIClient | OpenAIResponseAPIClient | OpenAIAPIClient> =
|
||||
new Map()
|
||||
private defaultClient: OpenAIAPIClient
|
||||
private currentClient: BaseApiClient
|
||||
|
||||
constructor(provider: Provider) {
|
||||
super(provider)
|
||||
|
||||
const providerExtraHeaders = {
|
||||
...provider,
|
||||
extra_headers: {
|
||||
...provider.extra_headers,
|
||||
'APP-Code': 'MLTG2087'
|
||||
}
|
||||
}
|
||||
|
||||
// 初始化各个client - 现在有类型安全
|
||||
const claudeClient = new AnthropicAPIClient(providerExtraHeaders)
|
||||
const geminiClient = new GeminiAPIClient({ ...providerExtraHeaders, apiHost: 'https://aihubmix.com/gemini' })
|
||||
const openaiClient = new OpenAIResponseAPIClient(providerExtraHeaders)
|
||||
const defaultClient = new OpenAIAPIClient(providerExtraHeaders)
|
||||
|
||||
this.clients.set('claude', claudeClient)
|
||||
this.clients.set('gemini', geminiClient)
|
||||
this.clients.set('openai', openaiClient)
|
||||
this.clients.set('default', defaultClient)
|
||||
|
||||
// 设置默认client
|
||||
this.defaultClient = defaultClient
|
||||
this.currentClient = this.defaultClient as BaseApiClient
|
||||
}
|
||||
|
||||
override getBaseURL(): string {
|
||||
if (!this.currentClient) {
|
||||
return this.provider.apiHost
|
||||
}
|
||||
return this.currentClient.getBaseURL()
|
||||
}
|
||||
|
||||
/**
|
||||
* 类型守卫:确保client是BaseApiClient的实例
|
||||
*/
|
||||
private isValidClient(client: unknown): client is BaseApiClient {
|
||||
return (
|
||||
client !== null &&
|
||||
client !== undefined &&
|
||||
typeof client === 'object' &&
|
||||
'createCompletions' in client &&
|
||||
'getRequestTransformer' in client &&
|
||||
'getResponseChunkTransformer' in client
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* 根据模型获取合适的client
|
||||
*/
|
||||
private getClient(model: Model): BaseApiClient {
|
||||
const id = model.id.toLowerCase()
|
||||
|
||||
// claude开头
|
||||
if (id.startsWith('claude')) {
|
||||
const client = this.clients.get('claude')
|
||||
if (!client || !this.isValidClient(client)) {
|
||||
throw new Error('Claude client not properly initialized')
|
||||
}
|
||||
return client
|
||||
}
|
||||
|
||||
// gemini开头 且不以-nothink、-search结尾
|
||||
if ((id.startsWith('gemini') || id.startsWith('imagen')) && !id.endsWith('-nothink') && !id.endsWith('-search')) {
|
||||
const client = this.clients.get('gemini')
|
||||
if (!client || !this.isValidClient(client)) {
|
||||
throw new Error('Gemini client not properly initialized')
|
||||
}
|
||||
return client
|
||||
}
|
||||
|
||||
// OpenAI系列模型
|
||||
if (isOpenAILLMModel(model)) {
|
||||
const client = this.clients.get('openai')
|
||||
if (!client || !this.isValidClient(client)) {
|
||||
throw new Error('OpenAI client not properly initialized')
|
||||
}
|
||||
return client
|
||||
}
|
||||
|
||||
return this.defaultClient as BaseApiClient
|
||||
}
|
||||
|
||||
/**
|
||||
* 根据模型选择合适的client并委托调用
|
||||
*/
|
||||
public getClientForModel(model: Model): BaseApiClient {
|
||||
this.currentClient = this.getClient(model)
|
||||
return this.currentClient
|
||||
}
|
||||
|
||||
// ============ BaseApiClient 抽象方法实现 ============
|
||||
|
||||
async createCompletions(payload: SdkParams, options?: RequestOptions): Promise<SdkRawOutput> {
|
||||
// 尝试从payload中提取模型信息来选择client
|
||||
const modelId = this.extractModelFromPayload(payload)
|
||||
if (modelId) {
|
||||
const modelObj = { id: modelId } as Model
|
||||
const targetClient = this.getClient(modelObj)
|
||||
return targetClient.createCompletions(payload, options)
|
||||
}
|
||||
|
||||
// 如果无法从payload中提取模型,使用当前设置的client
|
||||
return this.currentClient.createCompletions(payload, options)
|
||||
}
|
||||
|
||||
/**
|
||||
* 从SDK payload中提取模型ID
|
||||
*/
|
||||
private extractModelFromPayload(payload: SdkParams): string | null {
|
||||
// 不同的SDK可能有不同的字段名
|
||||
if ('model' in payload && typeof payload.model === 'string') {
|
||||
return payload.model
|
||||
}
|
||||
return null
|
||||
}
|
||||
|
||||
async generateImage(params: GenerateImageParams): Promise<string[]> {
|
||||
return this.currentClient.generateImage(params)
|
||||
}
|
||||
|
||||
async getEmbeddingDimensions(model?: Model): Promise<number> {
|
||||
const client = model ? this.getClient(model) : this.currentClient
|
||||
return client.getEmbeddingDimensions(model)
|
||||
}
|
||||
|
||||
async listModels(): Promise<SdkModel[]> {
|
||||
// 可以聚合所有client的模型,或者使用默认client
|
||||
return this.defaultClient.listModels()
|
||||
}
|
||||
|
||||
async getSdkInstance(): Promise<SdkInstance> {
|
||||
return this.currentClient.getSdkInstance()
|
||||
}
|
||||
|
||||
getRequestTransformer(): RequestTransformer<SdkParams, SdkMessageParam> {
|
||||
return this.currentClient.getRequestTransformer()
|
||||
}
|
||||
|
||||
getResponseChunkTransformer(ctx: CompletionsContext): ResponseChunkTransformer<SdkRawChunk> {
|
||||
return this.currentClient.getResponseChunkTransformer(ctx)
|
||||
}
|
||||
|
||||
convertMcpToolsToSdkTools(mcpTools: MCPTool[]): SdkTool[] {
|
||||
return this.currentClient.convertMcpToolsToSdkTools(mcpTools)
|
||||
}
|
||||
|
||||
convertSdkToolCallToMcp(toolCall: SdkToolCall, mcpTools: MCPTool[]): MCPTool | undefined {
|
||||
return this.currentClient.convertSdkToolCallToMcp(toolCall, mcpTools)
|
||||
}
|
||||
|
||||
convertSdkToolCallToMcpToolResponse(toolCall: SdkToolCall, mcpTool: MCPTool): ToolCallResponse {
|
||||
return this.currentClient.convertSdkToolCallToMcpToolResponse(toolCall, mcpTool)
|
||||
}
|
||||
|
||||
buildSdkMessages(
|
||||
currentReqMessages: SdkMessageParam[],
|
||||
output: SdkRawOutput | string,
|
||||
toolResults: SdkMessageParam[],
|
||||
toolCalls?: SdkToolCall[]
|
||||
): SdkMessageParam[] {
|
||||
return this.currentClient.buildSdkMessages(currentReqMessages, output, toolResults, toolCalls)
|
||||
}
|
||||
|
||||
convertMcpToolResponseToSdkMessageParam(
|
||||
mcpToolResponse: MCPToolResponse,
|
||||
resp: MCPCallToolResponse,
|
||||
model: Model
|
||||
): SdkMessageParam | undefined {
|
||||
const client = this.getClient(model)
|
||||
return client.convertMcpToolResponseToSdkMessageParam(mcpToolResponse, resp, model)
|
||||
}
|
||||
|
||||
extractMessagesFromSdkPayload(sdkPayload: SdkParams): SdkMessageParam[] {
|
||||
return this.currentClient.extractMessagesFromSdkPayload(sdkPayload)
|
||||
}
|
||||
|
||||
estimateMessageTokens(message: SdkMessageParam): number {
|
||||
return this.currentClient.estimateMessageTokens(message)
|
||||
}
|
||||
}
|
||||
72
src/renderer/src/aiCore/clients/ApiClientFactory.ts
Normal file
72
src/renderer/src/aiCore/clients/ApiClientFactory.ts
Normal file
@ -0,0 +1,72 @@
|
||||
import { Provider } from '@renderer/types'
|
||||
|
||||
import { AihubmixAPIClient } from './AihubmixAPIClient'
|
||||
import { AnthropicAPIClient } from './anthropic/AnthropicAPIClient'
|
||||
import { BaseApiClient } from './BaseApiClient'
|
||||
import { GeminiAPIClient } from './gemini/GeminiAPIClient'
|
||||
import { VertexAPIClient } from './gemini/VertexAPIClient'
|
||||
import { OpenAIAPIClient } from './openai/OpenAIApiClient'
|
||||
import { OpenAIResponseAPIClient } from './openai/OpenAIResponseAPIClient'
|
||||
import { PPIOAPIClient } from './ppio/PPIOAPIClient'
|
||||
|
||||
/**
|
||||
* Factory for creating ApiClient instances based on provider configuration
|
||||
* 根据提供者配置创建ApiClient实例的工厂
|
||||
*/
|
||||
export class ApiClientFactory {
|
||||
/**
|
||||
* Create an ApiClient instance for the given provider
|
||||
* 为给定的提供者创建ApiClient实例
|
||||
*/
|
||||
static create(provider: Provider): BaseApiClient {
|
||||
console.log(`[ApiClientFactory] Creating ApiClient for provider:`, {
|
||||
id: provider.id,
|
||||
type: provider.type
|
||||
})
|
||||
|
||||
let instance: BaseApiClient
|
||||
|
||||
// 首先检查特殊的provider id
|
||||
if (provider.id === 'aihubmix') {
|
||||
console.log(`[ApiClientFactory] Creating AihubmixAPIClient for provider: ${provider.id}`)
|
||||
instance = new AihubmixAPIClient(provider) as BaseApiClient
|
||||
return instance
|
||||
}
|
||||
if (provider.id === 'ppio') {
|
||||
console.log(`[ApiClientFactory] Creating PPIOAPIClient for provider: ${provider.id}`)
|
||||
instance = new PPIOAPIClient(provider) as BaseApiClient
|
||||
return instance
|
||||
}
|
||||
|
||||
// 然后检查标准的provider type
|
||||
switch (provider.type) {
|
||||
case 'openai':
|
||||
case 'azure-openai':
|
||||
console.log(`[ApiClientFactory] Creating OpenAIApiClient for provider: ${provider.id}`)
|
||||
instance = new OpenAIAPIClient(provider) as BaseApiClient
|
||||
break
|
||||
case 'openai-response':
|
||||
instance = new OpenAIResponseAPIClient(provider) as BaseApiClient
|
||||
break
|
||||
case 'gemini':
|
||||
instance = new GeminiAPIClient(provider) as BaseApiClient
|
||||
break
|
||||
case 'vertexai':
|
||||
instance = new VertexAPIClient(provider) as BaseApiClient
|
||||
break
|
||||
case 'anthropic':
|
||||
instance = new AnthropicAPIClient(provider) as BaseApiClient
|
||||
break
|
||||
default:
|
||||
console.log(`[ApiClientFactory] Using default OpenAIApiClient for provider: ${provider.id}`)
|
||||
instance = new OpenAIAPIClient(provider) as BaseApiClient
|
||||
break
|
||||
}
|
||||
|
||||
return instance
|
||||
}
|
||||
}
|
||||
|
||||
export function isOpenAIProvider(provider: Provider) {
|
||||
return !['anthropic', 'gemini'].includes(provider.type)
|
||||
}
|
||||
@ -1,40 +1,70 @@
|
||||
import Logger from '@renderer/config/logger'
|
||||
import { isFunctionCallingModel, isNotSupportTemperatureAndTopP } from '@renderer/config/models'
|
||||
import {
|
||||
isFunctionCallingModel,
|
||||
isNotSupportTemperatureAndTopP,
|
||||
isOpenAIModel,
|
||||
isSupportedFlexServiceTier
|
||||
} from '@renderer/config/models'
|
||||
import { REFERENCE_PROMPT } from '@renderer/config/prompts'
|
||||
import { getLMStudioKeepAliveTime } from '@renderer/hooks/useLMStudio'
|
||||
import type {
|
||||
import { getStoreSetting } from '@renderer/hooks/useSettings'
|
||||
import { SettingsState } from '@renderer/store/settings'
|
||||
import {
|
||||
Assistant,
|
||||
FileTypes,
|
||||
GenerateImageParams,
|
||||
KnowledgeReference,
|
||||
MCPCallToolResponse,
|
||||
MCPTool,
|
||||
MCPToolResponse,
|
||||
Model,
|
||||
OpenAIServiceTier,
|
||||
Provider,
|
||||
Suggestion,
|
||||
ToolCallResponse,
|
||||
WebSearchProviderResponse,
|
||||
WebSearchResponse
|
||||
} from '@renderer/types'
|
||||
import { ChunkType } from '@renderer/types/chunk'
|
||||
import type { Message } from '@renderer/types/newMessage'
|
||||
import { delay, isJSON, parseJSON } from '@renderer/utils'
|
||||
import { Message } from '@renderer/types/newMessage'
|
||||
import {
|
||||
RequestOptions,
|
||||
SdkInstance,
|
||||
SdkMessageParam,
|
||||
SdkModel,
|
||||
SdkParams,
|
||||
SdkRawChunk,
|
||||
SdkRawOutput,
|
||||
SdkTool,
|
||||
SdkToolCall
|
||||
} from '@renderer/types/sdk'
|
||||
import { isJSON, parseJSON } from '@renderer/utils'
|
||||
import { addAbortController, removeAbortController } from '@renderer/utils/abortController'
|
||||
import { formatApiHost } from '@renderer/utils/api'
|
||||
import { getMainTextContent } from '@renderer/utils/messageUtils/find'
|
||||
import { findFileBlocks, getContentWithTools, getMainTextContent } from '@renderer/utils/messageUtils/find'
|
||||
import { defaultTimeout } from '@shared/config/constant'
|
||||
import Logger from 'electron-log/renderer'
|
||||
import { isEmpty } from 'lodash'
|
||||
import type OpenAI from 'openai'
|
||||
|
||||
import type { CompletionsParams } from '.'
|
||||
import { CompletionsContext } from '../middleware/types'
|
||||
import { ApiClient, RequestTransformer, ResponseChunkTransformer } from './types'
|
||||
|
||||
export default abstract class BaseProvider {
|
||||
// Threshold for determining whether to use system prompt for tools
|
||||
/**
|
||||
* Abstract base class for API clients.
|
||||
* Provides common functionality and structure for specific client implementations.
|
||||
*/
|
||||
export abstract class BaseApiClient<
|
||||
TSdkInstance extends SdkInstance = SdkInstance,
|
||||
TSdkParams extends SdkParams = SdkParams,
|
||||
TRawOutput extends SdkRawOutput = SdkRawOutput,
|
||||
TRawChunk extends SdkRawChunk = SdkRawChunk,
|
||||
TMessageParam extends SdkMessageParam = SdkMessageParam,
|
||||
TToolCall extends SdkToolCall = SdkToolCall,
|
||||
TSdkSpecificTool extends SdkTool = SdkTool
|
||||
> implements ApiClient<TSdkInstance, TSdkParams, TRawOutput, TRawChunk, TMessageParam, TToolCall, TSdkSpecificTool>
|
||||
{
|
||||
private static readonly SYSTEM_PROMPT_THRESHOLD: number = 128
|
||||
|
||||
protected provider: Provider
|
||||
public provider: Provider
|
||||
protected host: string
|
||||
protected apiKey: string
|
||||
|
||||
protected useSystemPromptForTools: boolean = true
|
||||
protected sdkInstance?: TSdkInstance
|
||||
public useSystemPromptForTools: boolean = true
|
||||
|
||||
constructor(provider: Provider) {
|
||||
this.provider = provider
|
||||
@ -42,31 +72,70 @@ export default abstract class BaseProvider {
|
||||
this.apiKey = this.getApiKey()
|
||||
}
|
||||
|
||||
abstract completions({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams): Promise<void>
|
||||
abstract translate(
|
||||
content: string,
|
||||
assistant: Assistant,
|
||||
onResponse?: (text: string, isComplete: boolean) => void
|
||||
): Promise<string>
|
||||
abstract summaries(messages: Message[], assistant: Assistant): Promise<string>
|
||||
abstract summaryForSearch(messages: Message[], assistant: Assistant): Promise<string | null>
|
||||
abstract suggestions(messages: Message[], assistant: Assistant): Promise<Suggestion[]>
|
||||
abstract generateText({ prompt, content }: { prompt: string; content: string }): Promise<string>
|
||||
abstract check(model: Model, stream: boolean): Promise<{ valid: boolean; error: Error | null }>
|
||||
abstract models(): Promise<OpenAI.Models.Model[]>
|
||||
abstract generateImage(params: GenerateImageParams): Promise<string[]>
|
||||
abstract generateImageByChat({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams): Promise<void>
|
||||
abstract getEmbeddingDimensions(model: Model): Promise<number>
|
||||
public abstract convertMcpTools<T>(mcpTools: MCPTool[]): T[]
|
||||
public abstract mcpToolCallResponseToMessage(
|
||||
// // 核心的completions方法 - 在中间件架构中,这通常只是一个占位符
|
||||
// abstract completions(params: CompletionsParams, internal?: ProcessingState): Promise<CompletionsResult>
|
||||
|
||||
/**
|
||||
* 核心API Endpoint
|
||||
**/
|
||||
|
||||
abstract createCompletions(payload: TSdkParams, options?: RequestOptions): Promise<TRawOutput>
|
||||
|
||||
abstract generateImage(generateImageParams: GenerateImageParams): Promise<string[]>
|
||||
|
||||
abstract getEmbeddingDimensions(model?: Model): Promise<number>
|
||||
|
||||
abstract listModels(): Promise<SdkModel[]>
|
||||
|
||||
abstract getSdkInstance(): Promise<TSdkInstance> | TSdkInstance
|
||||
|
||||
/**
|
||||
* 中间件
|
||||
**/
|
||||
|
||||
// 在 CoreRequestToSdkParamsMiddleware中使用
|
||||
abstract getRequestTransformer(): RequestTransformer<TSdkParams, TMessageParam>
|
||||
// 在RawSdkChunkToGenericChunkMiddleware中使用
|
||||
abstract getResponseChunkTransformer(ctx: CompletionsContext): ResponseChunkTransformer<TRawChunk>
|
||||
|
||||
/**
|
||||
* 工具转换
|
||||
**/
|
||||
|
||||
// Optional tool conversion methods - implement if needed by the specific provider
|
||||
abstract convertMcpToolsToSdkTools(mcpTools: MCPTool[]): TSdkSpecificTool[]
|
||||
|
||||
abstract convertSdkToolCallToMcp(toolCall: TToolCall, mcpTools: MCPTool[]): MCPTool | undefined
|
||||
|
||||
abstract convertSdkToolCallToMcpToolResponse(toolCall: TToolCall, mcpTool: MCPTool): ToolCallResponse
|
||||
|
||||
abstract buildSdkMessages(
|
||||
currentReqMessages: TMessageParam[],
|
||||
output: TRawOutput | string | undefined,
|
||||
toolResults: TMessageParam[],
|
||||
toolCalls?: TToolCall[]
|
||||
): TMessageParam[]
|
||||
|
||||
abstract estimateMessageTokens(message: TMessageParam): number
|
||||
|
||||
abstract convertMcpToolResponseToSdkMessageParam(
|
||||
mcpToolResponse: MCPToolResponse,
|
||||
resp: MCPCallToolResponse,
|
||||
model: Model
|
||||
): any
|
||||
): TMessageParam | undefined
|
||||
|
||||
/**
|
||||
* 从SDK载荷中提取消息数组(用于中间件中的类型安全访问)
|
||||
* 不同的提供商可能使用不同的字段名(如messages、history等)
|
||||
*/
|
||||
abstract extractMessagesFromSdkPayload(sdkPayload: TSdkParams): TMessageParam[]
|
||||
|
||||
/**
|
||||
* 通用函数
|
||||
**/
|
||||
|
||||
public getBaseURL(): string {
|
||||
const host = this.provider.apiHost
|
||||
return formatApiHost(host)
|
||||
return this.provider.apiHost
|
||||
}
|
||||
|
||||
public getApiKey() {
|
||||
@ -111,18 +180,37 @@ export default abstract class BaseProvider {
|
||||
return isNotSupportTemperatureAndTopP(model) ? undefined : assistant.settings?.topP
|
||||
}
|
||||
|
||||
public async fakeCompletions({ onChunk }: CompletionsParams) {
|
||||
for (let i = 0; i < 100; i++) {
|
||||
await delay(0.01)
|
||||
onChunk({
|
||||
response: { text: i + '\n', usage: { completion_tokens: 0, prompt_tokens: 0, total_tokens: 0 } },
|
||||
type: ChunkType.BLOCK_COMPLETE
|
||||
})
|
||||
protected getServiceTier(model: Model) {
|
||||
if (!isOpenAIModel(model) || model.provider === 'github' || model.provider === 'copilot') {
|
||||
return undefined
|
||||
}
|
||||
|
||||
const openAI = getStoreSetting('openAI') as SettingsState['openAI']
|
||||
let serviceTier = 'auto' as OpenAIServiceTier
|
||||
|
||||
if (openAI && openAI?.serviceTier === 'flex') {
|
||||
if (isSupportedFlexServiceTier(model)) {
|
||||
serviceTier = 'flex'
|
||||
} else {
|
||||
serviceTier = 'auto'
|
||||
}
|
||||
} else {
|
||||
serviceTier = openAI.serviceTier
|
||||
}
|
||||
|
||||
return serviceTier
|
||||
}
|
||||
|
||||
protected getTimeout(model: Model) {
|
||||
if (isSupportedFlexServiceTier(model)) {
|
||||
return 15 * 1000 * 60
|
||||
}
|
||||
return defaultTimeout
|
||||
}
|
||||
|
||||
public async getMessageContent(message: Message): Promise<string> {
|
||||
const content = getMainTextContent(message)
|
||||
const content = getContentWithTools(message)
|
||||
|
||||
if (isEmpty(content)) {
|
||||
return ''
|
||||
}
|
||||
@ -148,6 +236,36 @@ export default abstract class BaseProvider {
|
||||
return content
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract the file content from the message
|
||||
* @param message - The message
|
||||
* @returns The file content
|
||||
*/
|
||||
protected async extractFileContent(message: Message) {
|
||||
const fileBlocks = findFileBlocks(message)
|
||||
if (fileBlocks.length > 0) {
|
||||
const textFileBlocks = fileBlocks.filter(
|
||||
(fb) => fb.file && [FileTypes.TEXT, FileTypes.DOCUMENT].includes(fb.file.type)
|
||||
)
|
||||
|
||||
if (textFileBlocks.length > 0) {
|
||||
let text = ''
|
||||
const divider = '\n\n---\n\n'
|
||||
|
||||
for (const fileBlock of textFileBlocks) {
|
||||
const file = fileBlock.file
|
||||
const fileContent = (await window.api.file.read(file.id + file.ext)).trim()
|
||||
const fileNameRow = 'file: ' + file.origin_name + '\n\n'
|
||||
text = text + fileNameRow + fileContent + divider
|
||||
}
|
||||
|
||||
return text
|
||||
}
|
||||
}
|
||||
|
||||
return ''
|
||||
}
|
||||
|
||||
private async getWebSearchReferencesFromCache(message: Message) {
|
||||
const content = getMainTextContent(message)
|
||||
if (isEmpty(content)) {
|
||||
@ -156,6 +274,7 @@ export default abstract class BaseProvider {
|
||||
const webSearch: WebSearchResponse = window.keyv.get(`web-search-${message.id}`)
|
||||
|
||||
if (webSearch) {
|
||||
window.keyv.remove(`web-search-${message.id}`)
|
||||
return (webSearch.results as WebSearchProviderResponse).results.map(
|
||||
(result, index) =>
|
||||
({
|
||||
@ -181,6 +300,7 @@ export default abstract class BaseProvider {
|
||||
const knowledgeReferences: KnowledgeReference[] = window.keyv.get(`knowledge-search-${message.id}`)
|
||||
|
||||
if (!isEmpty(knowledgeReferences)) {
|
||||
window.keyv.remove(`knowledge-search-${message.id}`)
|
||||
// Logger.log(`Found ${knowledgeReferences.length} knowledge base references in cache for ID: ${message.id}`)
|
||||
return knowledgeReferences
|
||||
}
|
||||
@ -209,7 +329,7 @@ export default abstract class BaseProvider {
|
||||
)
|
||||
}
|
||||
|
||||
protected createAbortController(messageId?: string, isAddEventListener?: boolean) {
|
||||
public createAbortController(messageId?: string, isAddEventListener?: boolean) {
|
||||
const abortController = new AbortController()
|
||||
const abortFn = () => abortController.abort()
|
||||
|
||||
@ -255,11 +375,11 @@ export default abstract class BaseProvider {
|
||||
}
|
||||
|
||||
// Setup tools configuration based on provided parameters
|
||||
protected setupToolsConfig<T>(params: { mcpTools?: MCPTool[]; model: Model; enableToolUse?: boolean }): {
|
||||
tools: T[]
|
||||
public setupToolsConfig(params: { mcpTools?: MCPTool[]; model: Model; enableToolUse?: boolean }): {
|
||||
tools: TSdkSpecificTool[]
|
||||
} {
|
||||
const { mcpTools, model, enableToolUse } = params
|
||||
let tools: T[] = []
|
||||
let tools: TSdkSpecificTool[] = []
|
||||
|
||||
// If there are no tools, return an empty array
|
||||
if (!mcpTools?.length) {
|
||||
@ -267,14 +387,14 @@ export default abstract class BaseProvider {
|
||||
}
|
||||
|
||||
// If the number of tools exceeds the threshold, use the system prompt
|
||||
if (mcpTools.length > BaseProvider.SYSTEM_PROMPT_THRESHOLD) {
|
||||
if (mcpTools.length > BaseApiClient.SYSTEM_PROMPT_THRESHOLD) {
|
||||
this.useSystemPromptForTools = true
|
||||
return { tools }
|
||||
}
|
||||
|
||||
// If the model supports function calling and tool usage is enabled
|
||||
if (isFunctionCallingModel(model) && enableToolUse) {
|
||||
tools = this.convertMcpTools<T>(mcpTools)
|
||||
tools = this.convertMcpToolsToSdkTools(mcpTools)
|
||||
this.useSystemPromptForTools = false
|
||||
}
|
||||
|
||||
736
src/renderer/src/aiCore/clients/anthropic/AnthropicAPIClient.ts
Normal file
736
src/renderer/src/aiCore/clients/anthropic/AnthropicAPIClient.ts
Normal file
@ -0,0 +1,736 @@
|
||||
import Anthropic from '@anthropic-ai/sdk'
|
||||
import {
|
||||
Base64ImageSource,
|
||||
ImageBlockParam,
|
||||
MessageParam,
|
||||
TextBlockParam,
|
||||
ToolResultBlockParam,
|
||||
ToolUseBlock,
|
||||
WebSearchTool20250305
|
||||
} from '@anthropic-ai/sdk/resources'
|
||||
import {
|
||||
ContentBlock,
|
||||
ContentBlockParam,
|
||||
MessageCreateParams,
|
||||
MessageCreateParamsBase,
|
||||
RedactedThinkingBlockParam,
|
||||
ServerToolUseBlockParam,
|
||||
ThinkingBlockParam,
|
||||
ThinkingConfigParam,
|
||||
ToolUnion,
|
||||
ToolUseBlockParam,
|
||||
WebSearchResultBlock,
|
||||
WebSearchToolResultBlockParam,
|
||||
WebSearchToolResultError
|
||||
} from '@anthropic-ai/sdk/resources/messages'
|
||||
import { MessageStream } from '@anthropic-ai/sdk/resources/messages/messages'
|
||||
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
|
||||
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
|
||||
import Logger from '@renderer/config/logger'
|
||||
import { findTokenLimit, isClaudeReasoningModel, isReasoningModel, isWebSearchModel } from '@renderer/config/models'
|
||||
import { getAssistantSettings } from '@renderer/services/AssistantService'
|
||||
import FileManager from '@renderer/services/FileManager'
|
||||
import { estimateTextTokens } from '@renderer/services/TokenService'
|
||||
import {
|
||||
Assistant,
|
||||
EFFORT_RATIO,
|
||||
FileTypes,
|
||||
MCPCallToolResponse,
|
||||
MCPTool,
|
||||
MCPToolResponse,
|
||||
Model,
|
||||
Provider,
|
||||
ToolCallResponse,
|
||||
WebSearchSource
|
||||
} from '@renderer/types'
|
||||
import {
|
||||
ChunkType,
|
||||
ErrorChunk,
|
||||
LLMWebSearchCompleteChunk,
|
||||
LLMWebSearchInProgressChunk,
|
||||
MCPToolCreatedChunk,
|
||||
TextDeltaChunk,
|
||||
ThinkingDeltaChunk
|
||||
} from '@renderer/types/chunk'
|
||||
import { type Message } from '@renderer/types/newMessage'
|
||||
import {
|
||||
AnthropicSdkMessageParam,
|
||||
AnthropicSdkParams,
|
||||
AnthropicSdkRawChunk,
|
||||
AnthropicSdkRawOutput
|
||||
} from '@renderer/types/sdk'
|
||||
import { addImageFileToContents } from '@renderer/utils/formats'
|
||||
import {
|
||||
anthropicToolUseToMcpTool,
|
||||
isEnabledToolUse,
|
||||
mcpToolCallResponseToAnthropicMessage,
|
||||
mcpToolsToAnthropicTools
|
||||
} from '@renderer/utils/mcp-tools'
|
||||
import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find'
|
||||
import { buildSystemPrompt } from '@renderer/utils/prompt'
|
||||
|
||||
import { BaseApiClient } from '../BaseApiClient'
|
||||
import { AnthropicStreamListener, RawStreamListener, RequestTransformer, ResponseChunkTransformer } from '../types'
|
||||
|
||||
export class AnthropicAPIClient extends BaseApiClient<
|
||||
Anthropic,
|
||||
AnthropicSdkParams,
|
||||
AnthropicSdkRawOutput,
|
||||
AnthropicSdkRawChunk,
|
||||
AnthropicSdkMessageParam,
|
||||
ToolUseBlock,
|
||||
ToolUnion
|
||||
> {
|
||||
constructor(provider: Provider) {
|
||||
super(provider)
|
||||
}
|
||||
|
||||
async getSdkInstance(): Promise<Anthropic> {
|
||||
if (this.sdkInstance) {
|
||||
return this.sdkInstance
|
||||
}
|
||||
this.sdkInstance = new Anthropic({
|
||||
apiKey: this.apiKey,
|
||||
baseURL: this.getBaseURL(),
|
||||
dangerouslyAllowBrowser: true,
|
||||
defaultHeaders: {
|
||||
'anthropic-beta': 'output-128k-2025-02-19',
|
||||
...this.provider.extra_headers
|
||||
}
|
||||
})
|
||||
return this.sdkInstance
|
||||
}
|
||||
|
||||
override async createCompletions(
|
||||
payload: AnthropicSdkParams,
|
||||
options?: Anthropic.RequestOptions
|
||||
): Promise<AnthropicSdkRawOutput> {
|
||||
const sdk = await this.getSdkInstance()
|
||||
if (payload.stream) {
|
||||
return sdk.messages.stream(payload, options)
|
||||
}
|
||||
return await sdk.messages.create(payload, options)
|
||||
}
|
||||
|
||||
// @ts-ignore sdk未提供
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
override async generateImage(generateImageParams: GenerateImageParams): Promise<string[]> {
|
||||
return []
|
||||
}
|
||||
|
||||
override async listModels(): Promise<Anthropic.ModelInfo[]> {
|
||||
const sdk = await this.getSdkInstance()
|
||||
const response = await sdk.models.list()
|
||||
return response.data
|
||||
}
|
||||
|
||||
// @ts-ignore sdk未提供
|
||||
override async getEmbeddingDimensions(): Promise<number> {
|
||||
throw new Error("Anthropic SDK doesn't support getEmbeddingDimensions method.")
|
||||
}
|
||||
|
||||
override getTemperature(assistant: Assistant, model: Model): number | undefined {
|
||||
if (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model)) {
|
||||
return undefined
|
||||
}
|
||||
return assistant.settings?.temperature
|
||||
}
|
||||
|
||||
override getTopP(assistant: Assistant, model: Model): number | undefined {
|
||||
if (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model)) {
|
||||
return undefined
|
||||
}
|
||||
return assistant.settings?.topP
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the reasoning effort
|
||||
* @param assistant - The assistant
|
||||
* @param model - The model
|
||||
* @returns The reasoning effort
|
||||
*/
|
||||
private getBudgetToken(assistant: Assistant, model: Model): ThinkingConfigParam | undefined {
|
||||
if (!isReasoningModel(model)) {
|
||||
return undefined
|
||||
}
|
||||
const { maxTokens } = getAssistantSettings(assistant)
|
||||
|
||||
const reasoningEffort = assistant?.settings?.reasoning_effort
|
||||
|
||||
if (reasoningEffort === undefined) {
|
||||
return {
|
||||
type: 'disabled'
|
||||
}
|
||||
}
|
||||
|
||||
const effortRatio = EFFORT_RATIO[reasoningEffort]
|
||||
|
||||
const budgetTokens = Math.max(
|
||||
1024,
|
||||
Math.floor(
|
||||
Math.min(
|
||||
(findTokenLimit(model.id)?.max! - findTokenLimit(model.id)?.min!) * effortRatio +
|
||||
findTokenLimit(model.id)?.min!,
|
||||
(maxTokens || DEFAULT_MAX_TOKENS) * effortRatio
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
return {
|
||||
type: 'enabled',
|
||||
budget_tokens: budgetTokens
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the message parameter
|
||||
* @param message - The message
|
||||
* @param model - The model
|
||||
* @returns The message parameter
|
||||
*/
|
||||
public async convertMessageToSdkParam(message: Message): Promise<AnthropicSdkMessageParam> {
|
||||
const parts: MessageParam['content'] = [
|
||||
{
|
||||
type: 'text',
|
||||
text: await this.getMessageContent(message)
|
||||
}
|
||||
]
|
||||
|
||||
// Get and process image blocks
|
||||
const imageBlocks = findImageBlocks(message)
|
||||
for (const imageBlock of imageBlocks) {
|
||||
if (imageBlock.file) {
|
||||
// Handle uploaded file
|
||||
const file = imageBlock.file
|
||||
const base64Data = await window.api.file.base64Image(file.id + file.ext)
|
||||
parts.push({
|
||||
type: 'image',
|
||||
source: {
|
||||
data: base64Data.base64,
|
||||
media_type: base64Data.mime.replace('jpg', 'jpeg') as any,
|
||||
type: 'base64'
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
// Get and process file blocks
|
||||
const fileBlocks = findFileBlocks(message)
|
||||
for (const fileBlock of fileBlocks) {
|
||||
const { file } = fileBlock
|
||||
if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) {
|
||||
if (file.ext === '.pdf' && file.size < 32 * 1024 * 1024) {
|
||||
const base64Data = await FileManager.readBase64File(file)
|
||||
parts.push({
|
||||
type: 'document',
|
||||
source: {
|
||||
type: 'base64',
|
||||
media_type: 'application/pdf',
|
||||
data: base64Data
|
||||
}
|
||||
})
|
||||
} else {
|
||||
const fileContent = await (await window.api.file.read(file.id + file.ext)).trim()
|
||||
parts.push({
|
||||
type: 'text',
|
||||
text: file.origin_name + '\n' + fileContent
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
role: message.role === 'system' ? 'user' : message.role,
|
||||
content: parts
|
||||
}
|
||||
}
|
||||
|
||||
public convertMcpToolsToSdkTools(mcpTools: MCPTool[]): ToolUnion[] {
|
||||
return mcpToolsToAnthropicTools(mcpTools)
|
||||
}
|
||||
|
||||
public convertMcpToolResponseToSdkMessageParam(
|
||||
mcpToolResponse: MCPToolResponse,
|
||||
resp: MCPCallToolResponse,
|
||||
model: Model
|
||||
): AnthropicSdkMessageParam | undefined {
|
||||
if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) {
|
||||
return mcpToolCallResponseToAnthropicMessage(mcpToolResponse, resp, model)
|
||||
} else if ('toolCallId' in mcpToolResponse) {
|
||||
return {
|
||||
role: 'user',
|
||||
content: [
|
||||
{
|
||||
type: 'tool_result',
|
||||
tool_use_id: mcpToolResponse.toolCallId!,
|
||||
content: resp.content
|
||||
.map((item) => {
|
||||
if (item.type === 'text') {
|
||||
return {
|
||||
type: 'text',
|
||||
text: item.text || ''
|
||||
} satisfies TextBlockParam
|
||||
}
|
||||
if (item.type === 'image') {
|
||||
return {
|
||||
type: 'image',
|
||||
source: {
|
||||
data: item.data || '',
|
||||
media_type: (item.mimeType || 'image/png') as Base64ImageSource['media_type'],
|
||||
type: 'base64'
|
||||
}
|
||||
} satisfies ImageBlockParam
|
||||
}
|
||||
return
|
||||
})
|
||||
.filter((n) => typeof n !== 'undefined'),
|
||||
is_error: resp.isError
|
||||
} satisfies ToolResultBlockParam
|
||||
]
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Implementing abstract methods from BaseApiClient
|
||||
convertSdkToolCallToMcp(toolCall: ToolUseBlock, mcpTools: MCPTool[]): MCPTool | undefined {
|
||||
// Based on anthropicToolUseToMcpTool logic in AnthropicProvider
|
||||
// This might need adjustment based on how tool calls are specifically handled in the new structure
|
||||
const mcpTool = anthropicToolUseToMcpTool(mcpTools, toolCall)
|
||||
return mcpTool
|
||||
}
|
||||
|
||||
convertSdkToolCallToMcpToolResponse(toolCall: ToolUseBlock, mcpTool: MCPTool): ToolCallResponse {
|
||||
return {
|
||||
id: toolCall.id,
|
||||
toolCallId: toolCall.id,
|
||||
tool: mcpTool,
|
||||
arguments: toolCall.input as Record<string, unknown>,
|
||||
status: 'pending'
|
||||
} as ToolCallResponse
|
||||
}
|
||||
|
||||
override buildSdkMessages(
|
||||
currentReqMessages: AnthropicSdkMessageParam[],
|
||||
output: Anthropic.Message,
|
||||
toolResults: AnthropicSdkMessageParam[]
|
||||
): AnthropicSdkMessageParam[] {
|
||||
const assistantMessage: AnthropicSdkMessageParam = {
|
||||
role: output.role,
|
||||
content: convertContentBlocksToParams(output.content)
|
||||
}
|
||||
|
||||
const newMessages: AnthropicSdkMessageParam[] = [...currentReqMessages, assistantMessage]
|
||||
if (toolResults && toolResults.length > 0) {
|
||||
newMessages.push(...toolResults)
|
||||
}
|
||||
return newMessages
|
||||
}
|
||||
|
||||
override estimateMessageTokens(message: AnthropicSdkMessageParam): number {
|
||||
if (typeof message.content === 'string') {
|
||||
return estimateTextTokens(message.content)
|
||||
}
|
||||
return message.content
|
||||
.map((content) => {
|
||||
switch (content.type) {
|
||||
case 'text':
|
||||
return estimateTextTokens(content.text)
|
||||
case 'image':
|
||||
if (content.source.type === 'base64') {
|
||||
return estimateTextTokens(content.source.data)
|
||||
} else {
|
||||
return estimateTextTokens(content.source.url)
|
||||
}
|
||||
case 'tool_use':
|
||||
return estimateTextTokens(JSON.stringify(content.input))
|
||||
case 'tool_result':
|
||||
return estimateTextTokens(JSON.stringify(content.content))
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
})
|
||||
.reduce((acc, curr) => acc + curr, 0)
|
||||
}
|
||||
|
||||
public buildAssistantMessage(message: Anthropic.Message): AnthropicSdkMessageParam {
|
||||
const messageParam: AnthropicSdkMessageParam = {
|
||||
role: message.role,
|
||||
content: convertContentBlocksToParams(message.content)
|
||||
}
|
||||
return messageParam
|
||||
}
|
||||
|
||||
public extractMessagesFromSdkPayload(sdkPayload: AnthropicSdkParams): AnthropicSdkMessageParam[] {
|
||||
return sdkPayload.messages || []
|
||||
}
|
||||
|
||||
/**
|
||||
* Anthropic专用的原始流监听器
|
||||
* 处理MessageStream对象的特定事件
|
||||
*/
|
||||
attachRawStreamListener(
|
||||
rawOutput: AnthropicSdkRawOutput,
|
||||
listener: RawStreamListener<AnthropicSdkRawChunk>
|
||||
): AnthropicSdkRawOutput {
|
||||
console.log(`[AnthropicApiClient] 附加流监听器到原始输出`)
|
||||
// 专用的Anthropic事件处理
|
||||
const anthropicListener = listener as AnthropicStreamListener
|
||||
// 检查是否为MessageStream
|
||||
if (rawOutput instanceof MessageStream) {
|
||||
console.log(`[AnthropicApiClient] 检测到 Anthropic MessageStream,附加专用监听器`)
|
||||
|
||||
if (listener.onStart) {
|
||||
listener.onStart()
|
||||
}
|
||||
|
||||
if (listener.onChunk) {
|
||||
rawOutput.on('streamEvent', (event: AnthropicSdkRawChunk) => {
|
||||
listener.onChunk!(event)
|
||||
})
|
||||
}
|
||||
|
||||
if (anthropicListener.onContentBlock) {
|
||||
rawOutput.on('contentBlock', anthropicListener.onContentBlock)
|
||||
}
|
||||
|
||||
if (anthropicListener.onMessage) {
|
||||
rawOutput.on('finalMessage', anthropicListener.onMessage)
|
||||
}
|
||||
|
||||
if (listener.onEnd) {
|
||||
rawOutput.on('end', () => {
|
||||
listener.onEnd!()
|
||||
})
|
||||
}
|
||||
|
||||
if (listener.onError) {
|
||||
rawOutput.on('error', (error: Error) => {
|
||||
listener.onError!(error)
|
||||
})
|
||||
}
|
||||
|
||||
return rawOutput
|
||||
}
|
||||
|
||||
if (anthropicListener.onMessage) {
|
||||
anthropicListener.onMessage(rawOutput)
|
||||
}
|
||||
|
||||
// 对于非MessageStream响应
|
||||
return rawOutput
|
||||
}
|
||||
|
||||
private async getWebSearchParams(model: Model): Promise<WebSearchTool20250305 | undefined> {
|
||||
if (!isWebSearchModel(model)) {
|
||||
return undefined
|
||||
}
|
||||
return {
|
||||
type: 'web_search_20250305',
|
||||
name: 'web_search',
|
||||
max_uses: 5
|
||||
} as WebSearchTool20250305
|
||||
}
|
||||
|
||||
getRequestTransformer(): RequestTransformer<AnthropicSdkParams, AnthropicSdkMessageParam> {
|
||||
return {
|
||||
transform: async (
|
||||
coreRequest,
|
||||
assistant,
|
||||
model,
|
||||
isRecursiveCall,
|
||||
recursiveSdkMessages
|
||||
): Promise<{
|
||||
payload: AnthropicSdkParams
|
||||
messages: AnthropicSdkMessageParam[]
|
||||
metadata: Record<string, any>
|
||||
}> => {
|
||||
const { messages, mcpTools, maxTokens, streamOutput, enableWebSearch } = coreRequest
|
||||
// 1. 处理系统消息
|
||||
let systemPrompt = assistant.prompt
|
||||
|
||||
// 2. 设置工具
|
||||
const { tools } = this.setupToolsConfig({
|
||||
mcpTools: mcpTools,
|
||||
model,
|
||||
enableToolUse: isEnabledToolUse(assistant)
|
||||
})
|
||||
|
||||
if (this.useSystemPromptForTools) {
|
||||
systemPrompt = await buildSystemPrompt(systemPrompt, mcpTools, assistant)
|
||||
}
|
||||
|
||||
const systemMessage: TextBlockParam | undefined = systemPrompt
|
||||
? { type: 'text', text: systemPrompt }
|
||||
: undefined
|
||||
|
||||
// 3. 处理用户消息
|
||||
const sdkMessages: AnthropicSdkMessageParam[] = []
|
||||
if (typeof messages === 'string') {
|
||||
sdkMessages.push({ role: 'user', content: messages })
|
||||
} else {
|
||||
const processedMessages = addImageFileToContents(messages)
|
||||
for (const message of processedMessages) {
|
||||
sdkMessages.push(await this.convertMessageToSdkParam(message))
|
||||
}
|
||||
}
|
||||
|
||||
if (enableWebSearch) {
|
||||
const webSearchTool = await this.getWebSearchParams(model)
|
||||
if (webSearchTool) {
|
||||
tools.push(webSearchTool)
|
||||
}
|
||||
}
|
||||
|
||||
const commonParams: MessageCreateParamsBase = {
|
||||
model: model.id,
|
||||
messages:
|
||||
isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0
|
||||
? recursiveSdkMessages
|
||||
: sdkMessages,
|
||||
max_tokens: maxTokens || DEFAULT_MAX_TOKENS,
|
||||
temperature: this.getTemperature(assistant, model),
|
||||
top_p: this.getTopP(assistant, model),
|
||||
system: systemMessage ? [systemMessage] : undefined,
|
||||
thinking: this.getBudgetToken(assistant, model),
|
||||
tools: tools.length > 0 ? tools : undefined,
|
||||
// 只在对话场景下应用自定义参数,避免影响翻译、总结等其他业务逻辑
|
||||
...(coreRequest.callType === 'chat' ? this.getCustomParameters(assistant) : {})
|
||||
}
|
||||
|
||||
const finalParams: MessageCreateParams = streamOutput
|
||||
? {
|
||||
...commonParams,
|
||||
stream: true
|
||||
}
|
||||
: {
|
||||
...commonParams,
|
||||
stream: false
|
||||
}
|
||||
|
||||
const timeout = this.getTimeout(model)
|
||||
return { payload: finalParams, messages: sdkMessages, metadata: { timeout } }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
getResponseChunkTransformer(): ResponseChunkTransformer<AnthropicSdkRawChunk> {
|
||||
return () => {
|
||||
let accumulatedJson = ''
|
||||
const toolCalls: Record<number, ToolUseBlock> = {}
|
||||
|
||||
return {
|
||||
async transform(rawChunk: AnthropicSdkRawChunk, controller: TransformStreamDefaultController<GenericChunk>) {
|
||||
switch (rawChunk.type) {
|
||||
case 'message': {
|
||||
let i = 0
|
||||
for (const content of rawChunk.content) {
|
||||
switch (content.type) {
|
||||
case 'text': {
|
||||
controller.enqueue({
|
||||
type: ChunkType.TEXT_DELTA,
|
||||
text: content.text
|
||||
} as TextDeltaChunk)
|
||||
break
|
||||
}
|
||||
case 'tool_use': {
|
||||
toolCalls[i] = content
|
||||
i++
|
||||
break
|
||||
}
|
||||
case 'thinking': {
|
||||
controller.enqueue({
|
||||
type: ChunkType.THINKING_DELTA,
|
||||
text: content.thinking
|
||||
} as ThinkingDeltaChunk)
|
||||
break
|
||||
}
|
||||
case 'web_search_tool_result': {
|
||||
controller.enqueue({
|
||||
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
|
||||
llm_web_search: {
|
||||
results: content.content,
|
||||
source: WebSearchSource.ANTHROPIC
|
||||
}
|
||||
} as LLMWebSearchCompleteChunk)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if (i > 0) {
|
||||
controller.enqueue({
|
||||
type: ChunkType.MCP_TOOL_CREATED,
|
||||
tool_calls: Object.values(toolCalls)
|
||||
} as MCPToolCreatedChunk)
|
||||
}
|
||||
controller.enqueue({
|
||||
type: ChunkType.LLM_RESPONSE_COMPLETE,
|
||||
response: {
|
||||
usage: {
|
||||
prompt_tokens: rawChunk.usage.input_tokens || 0,
|
||||
completion_tokens: rawChunk.usage.output_tokens || 0,
|
||||
total_tokens: (rawChunk.usage.input_tokens || 0) + (rawChunk.usage.output_tokens || 0)
|
||||
}
|
||||
}
|
||||
})
|
||||
break
|
||||
}
|
||||
case 'content_block_start': {
|
||||
const contentBlock = rawChunk.content_block
|
||||
switch (contentBlock.type) {
|
||||
case 'server_tool_use': {
|
||||
if (contentBlock.name === 'web_search') {
|
||||
controller.enqueue({
|
||||
type: ChunkType.LLM_WEB_SEARCH_IN_PROGRESS
|
||||
} as LLMWebSearchInProgressChunk)
|
||||
}
|
||||
break
|
||||
}
|
||||
case 'web_search_tool_result': {
|
||||
if (
|
||||
contentBlock.content &&
|
||||
(contentBlock.content as WebSearchToolResultError).type === 'web_search_tool_result_error'
|
||||
) {
|
||||
controller.enqueue({
|
||||
type: ChunkType.ERROR,
|
||||
error: {
|
||||
code: (contentBlock.content as WebSearchToolResultError).error_code,
|
||||
message: (contentBlock.content as WebSearchToolResultError).error_code
|
||||
}
|
||||
} as ErrorChunk)
|
||||
} else {
|
||||
controller.enqueue({
|
||||
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
|
||||
llm_web_search: {
|
||||
results: contentBlock.content as Array<WebSearchResultBlock>,
|
||||
source: WebSearchSource.ANTHROPIC
|
||||
}
|
||||
} as LLMWebSearchCompleteChunk)
|
||||
}
|
||||
break
|
||||
}
|
||||
case 'tool_use': {
|
||||
toolCalls[rawChunk.index] = contentBlock
|
||||
break
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
case 'content_block_delta': {
|
||||
const messageDelta = rawChunk.delta
|
||||
switch (messageDelta.type) {
|
||||
case 'text_delta': {
|
||||
if (messageDelta.text) {
|
||||
controller.enqueue({
|
||||
type: ChunkType.TEXT_DELTA,
|
||||
text: messageDelta.text
|
||||
} as TextDeltaChunk)
|
||||
}
|
||||
break
|
||||
}
|
||||
case 'thinking_delta': {
|
||||
if (messageDelta.thinking) {
|
||||
controller.enqueue({
|
||||
type: ChunkType.THINKING_DELTA,
|
||||
text: messageDelta.thinking
|
||||
} as ThinkingDeltaChunk)
|
||||
}
|
||||
break
|
||||
}
|
||||
case 'input_json_delta': {
|
||||
if (messageDelta.partial_json) {
|
||||
accumulatedJson += messageDelta.partial_json
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
case 'content_block_stop': {
|
||||
const toolCall = toolCalls[rawChunk.index]
|
||||
if (toolCall) {
|
||||
try {
|
||||
toolCall.input = JSON.parse(accumulatedJson)
|
||||
Logger.debug(`Tool call id: ${toolCall.id}, accumulated json: ${accumulatedJson}`)
|
||||
controller.enqueue({
|
||||
type: ChunkType.MCP_TOOL_CREATED,
|
||||
tool_calls: [toolCall]
|
||||
} as MCPToolCreatedChunk)
|
||||
} catch (error) {
|
||||
Logger.error(`Error parsing tool call input: ${error}`)
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
case 'message_delta': {
|
||||
controller.enqueue({
|
||||
type: ChunkType.LLM_RESPONSE_COMPLETE,
|
||||
response: {
|
||||
usage: {
|
||||
prompt_tokens: rawChunk.usage.input_tokens || 0,
|
||||
completion_tokens: rawChunk.usage.output_tokens || 0,
|
||||
total_tokens: (rawChunk.usage.input_tokens || 0) + (rawChunk.usage.output_tokens || 0)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 将 ContentBlock 数组转换为 ContentBlockParam 数组
|
||||
* 去除服务器生成的额外字段,只保留发送给API所需的字段
|
||||
*/
|
||||
function convertContentBlocksToParams(contentBlocks: ContentBlock[]): ContentBlockParam[] {
|
||||
return contentBlocks.map((block): ContentBlockParam => {
|
||||
switch (block.type) {
|
||||
case 'text':
|
||||
// TextBlock -> TextBlockParam,去除 citations 等服务器字段
|
||||
return {
|
||||
type: 'text',
|
||||
text: block.text
|
||||
} satisfies TextBlockParam
|
||||
case 'tool_use':
|
||||
// ToolUseBlock -> ToolUseBlockParam
|
||||
return {
|
||||
type: 'tool_use',
|
||||
id: block.id,
|
||||
name: block.name,
|
||||
input: block.input
|
||||
} satisfies ToolUseBlockParam
|
||||
case 'thinking':
|
||||
// ThinkingBlock -> ThinkingBlockParam
|
||||
return {
|
||||
type: 'thinking',
|
||||
thinking: block.thinking,
|
||||
signature: block.signature
|
||||
} satisfies ThinkingBlockParam
|
||||
case 'redacted_thinking':
|
||||
// RedactedThinkingBlock -> RedactedThinkingBlockParam
|
||||
return {
|
||||
type: 'redacted_thinking',
|
||||
data: block.data
|
||||
} satisfies RedactedThinkingBlockParam
|
||||
case 'server_tool_use':
|
||||
// ServerToolUseBlock -> ServerToolUseBlockParam
|
||||
return {
|
||||
type: 'server_tool_use',
|
||||
id: block.id,
|
||||
name: block.name,
|
||||
input: block.input
|
||||
} satisfies ServerToolUseBlockParam
|
||||
case 'web_search_tool_result':
|
||||
// WebSearchToolResultBlock -> WebSearchToolResultBlockParam
|
||||
return {
|
||||
type: 'web_search_tool_result',
|
||||
tool_use_id: block.tool_use_id,
|
||||
content: block.content
|
||||
} satisfies WebSearchToolResultBlockParam
|
||||
default:
|
||||
return block as ContentBlockParam
|
||||
}
|
||||
})
|
||||
}
|
||||
827
src/renderer/src/aiCore/clients/gemini/GeminiAPIClient.ts
Normal file
827
src/renderer/src/aiCore/clients/gemini/GeminiAPIClient.ts
Normal file
@ -0,0 +1,827 @@
|
||||
import {
|
||||
Content,
|
||||
File,
|
||||
FileState,
|
||||
FunctionCall,
|
||||
GenerateContentConfig,
|
||||
GenerateImagesConfig,
|
||||
GoogleGenAI,
|
||||
HarmBlockThreshold,
|
||||
HarmCategory,
|
||||
Modality,
|
||||
Model as GeminiModel,
|
||||
Pager,
|
||||
Part,
|
||||
SafetySetting,
|
||||
SendMessageParameters,
|
||||
ThinkingConfig,
|
||||
Tool
|
||||
} from '@google/genai'
|
||||
import { nanoid } from '@reduxjs/toolkit'
|
||||
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
|
||||
import {
|
||||
findTokenLimit,
|
||||
GEMINI_FLASH_MODEL_REGEX,
|
||||
isGemmaModel,
|
||||
isSupportedThinkingTokenGeminiModel,
|
||||
isVisionModel
|
||||
} from '@renderer/config/models'
|
||||
import { CacheService } from '@renderer/services/CacheService'
|
||||
import { estimateTextTokens } from '@renderer/services/TokenService'
|
||||
import {
|
||||
Assistant,
|
||||
EFFORT_RATIO,
|
||||
FileType,
|
||||
FileTypes,
|
||||
GenerateImageParams,
|
||||
MCPCallToolResponse,
|
||||
MCPTool,
|
||||
MCPToolResponse,
|
||||
Model,
|
||||
Provider,
|
||||
ToolCallResponse,
|
||||
WebSearchSource
|
||||
} from '@renderer/types'
|
||||
import { ChunkType, LLMWebSearchCompleteChunk } from '@renderer/types/chunk'
|
||||
import { Message } from '@renderer/types/newMessage'
|
||||
import {
|
||||
GeminiOptions,
|
||||
GeminiSdkMessageParam,
|
||||
GeminiSdkParams,
|
||||
GeminiSdkRawChunk,
|
||||
GeminiSdkRawOutput,
|
||||
GeminiSdkToolCall
|
||||
} from '@renderer/types/sdk'
|
||||
import {
|
||||
geminiFunctionCallToMcpTool,
|
||||
isEnabledToolUse,
|
||||
mcpToolCallResponseToGeminiMessage,
|
||||
mcpToolsToGeminiTools
|
||||
} from '@renderer/utils/mcp-tools'
|
||||
import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
|
||||
import { buildSystemPrompt } from '@renderer/utils/prompt'
|
||||
import { defaultTimeout, MB } from '@shared/config/constant'
|
||||
|
||||
import { BaseApiClient } from '../BaseApiClient'
|
||||
import { RequestTransformer, ResponseChunkTransformer } from '../types'
|
||||
|
||||
export class GeminiAPIClient extends BaseApiClient<
|
||||
GoogleGenAI,
|
||||
GeminiSdkParams,
|
||||
GeminiSdkRawOutput,
|
||||
GeminiSdkRawChunk,
|
||||
GeminiSdkMessageParam,
|
||||
GeminiSdkToolCall,
|
||||
Tool
|
||||
> {
|
||||
constructor(provider: Provider) {
|
||||
super(provider)
|
||||
}
|
||||
|
||||
override async createCompletions(payload: GeminiSdkParams, options?: GeminiOptions): Promise<GeminiSdkRawOutput> {
|
||||
const sdk = await this.getSdkInstance()
|
||||
const { model, history, ...rest } = payload
|
||||
const realPayload: Omit<GeminiSdkParams, 'model'> = {
|
||||
...rest,
|
||||
config: {
|
||||
...rest.config,
|
||||
abortSignal: options?.signal,
|
||||
httpOptions: {
|
||||
...rest.config?.httpOptions,
|
||||
timeout: options?.timeout
|
||||
}
|
||||
}
|
||||
} satisfies SendMessageParameters
|
||||
|
||||
const streamOutput = options?.streamOutput
|
||||
|
||||
const chat = sdk.chats.create({
|
||||
model: model,
|
||||
history: history
|
||||
})
|
||||
|
||||
if (streamOutput) {
|
||||
const stream = chat.sendMessageStream(realPayload)
|
||||
return stream
|
||||
} else {
|
||||
const response = await chat.sendMessage(realPayload)
|
||||
return response
|
||||
}
|
||||
}
|
||||
|
||||
override async generateImage(generateImageParams: GenerateImageParams): Promise<string[]> {
|
||||
const sdk = await this.getSdkInstance()
|
||||
try {
|
||||
const { model, prompt, imageSize, batchSize, signal } = generateImageParams
|
||||
const config: GenerateImagesConfig = {
|
||||
numberOfImages: batchSize,
|
||||
aspectRatio: imageSize,
|
||||
abortSignal: signal,
|
||||
httpOptions: {
|
||||
timeout: defaultTimeout
|
||||
}
|
||||
}
|
||||
const response = await sdk.models.generateImages({
|
||||
model: model,
|
||||
prompt,
|
||||
config
|
||||
})
|
||||
|
||||
if (!response.generatedImages || response.generatedImages.length === 0) {
|
||||
return []
|
||||
}
|
||||
|
||||
const images = response.generatedImages
|
||||
.filter((image) => image.image?.imageBytes)
|
||||
.map((image) => {
|
||||
const dataPrefix = `data:${image.image?.mimeType || 'image/png'};base64,`
|
||||
return dataPrefix + image.image?.imageBytes
|
||||
})
|
||||
// console.log(response?.generatedImages?.[0]?.image?.imageBytes);
|
||||
return images
|
||||
} catch (error) {
|
||||
console.error('[generateImage] error:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
override async getEmbeddingDimensions(model: Model): Promise<number> {
|
||||
const sdk = await this.getSdkInstance()
|
||||
|
||||
const data = await sdk.models.embedContent({
|
||||
model: model.id,
|
||||
contents: [{ role: 'user', parts: [{ text: 'hi' }] }]
|
||||
})
|
||||
return data.embeddings?.[0]?.values?.length || 0
|
||||
}
|
||||
|
||||
override async listModels(): Promise<GeminiModel[]> {
|
||||
const sdk = await this.getSdkInstance()
|
||||
const response = await sdk.models.list()
|
||||
const models: GeminiModel[] = []
|
||||
for await (const model of response) {
|
||||
models.push(model)
|
||||
}
|
||||
return models
|
||||
}
|
||||
|
||||
override async getSdkInstance() {
|
||||
if (this.sdkInstance) {
|
||||
return this.sdkInstance
|
||||
}
|
||||
|
||||
this.sdkInstance = new GoogleGenAI({
|
||||
vertexai: false,
|
||||
apiKey: this.apiKey,
|
||||
apiVersion: this.getApiVersion(),
|
||||
httpOptions: {
|
||||
baseUrl: this.getBaseURL(),
|
||||
apiVersion: this.getApiVersion(),
|
||||
headers: {
|
||||
...this.provider.extra_headers
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return this.sdkInstance
|
||||
}
|
||||
|
||||
protected getApiVersion(): string {
|
||||
if (this.provider.isVertex) {
|
||||
return 'v1'
|
||||
}
|
||||
return 'v1beta'
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle a PDF file
|
||||
* @param file - The file
|
||||
* @returns The part
|
||||
*/
|
||||
private async handlePdfFile(file: FileType): Promise<Part> {
|
||||
const smallFileSize = 20 * MB
|
||||
const isSmallFile = file.size < smallFileSize
|
||||
|
||||
if (isSmallFile) {
|
||||
const { data, mimeType } = await this.base64File(file)
|
||||
return {
|
||||
inlineData: {
|
||||
data,
|
||||
mimeType
|
||||
} as Part['inlineData']
|
||||
}
|
||||
}
|
||||
|
||||
// Retrieve file from Gemini uploaded files
|
||||
const fileMetadata: File | undefined = await this.retrieveFile(file)
|
||||
|
||||
if (fileMetadata) {
|
||||
return {
|
||||
fileData: {
|
||||
fileUri: fileMetadata.uri,
|
||||
mimeType: fileMetadata.mimeType
|
||||
} as Part['fileData']
|
||||
}
|
||||
}
|
||||
|
||||
// If file is not found, upload it to Gemini
|
||||
const result = await this.uploadFile(file)
|
||||
|
||||
return {
|
||||
fileData: {
|
||||
fileUri: result.uri,
|
||||
mimeType: result.mimeType
|
||||
} as Part['fileData']
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the message contents
|
||||
* @param message - The message
|
||||
* @returns The message contents
|
||||
*/
|
||||
private async convertMessageToSdkParam(message: Message): Promise<Content> {
|
||||
const role = message.role === 'user' ? 'user' : 'model'
|
||||
const parts: Part[] = [{ text: await this.getMessageContent(message) }]
|
||||
|
||||
// Add any generated images from previous responses
|
||||
const imageBlocks = findImageBlocks(message)
|
||||
for (const imageBlock of imageBlocks) {
|
||||
if (
|
||||
imageBlock.metadata?.generateImageResponse?.images &&
|
||||
imageBlock.metadata.generateImageResponse.images.length > 0
|
||||
) {
|
||||
for (const imageUrl of imageBlock.metadata.generateImageResponse.images) {
|
||||
if (imageUrl && imageUrl.startsWith('data:')) {
|
||||
// Extract base64 data and mime type from the data URL
|
||||
const matches = imageUrl.match(/^data:(.+);base64,(.*)$/)
|
||||
if (matches && matches.length === 3) {
|
||||
const mimeType = matches[1]
|
||||
const base64Data = matches[2]
|
||||
parts.push({
|
||||
inlineData: {
|
||||
data: base64Data,
|
||||
mimeType: mimeType
|
||||
} as Part['inlineData']
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
const file = imageBlock.file
|
||||
if (file) {
|
||||
const base64Data = await window.api.file.base64Image(file.id + file.ext)
|
||||
parts.push({
|
||||
inlineData: {
|
||||
data: base64Data.base64,
|
||||
mimeType: base64Data.mime
|
||||
} as Part['inlineData']
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const fileBlocks = findFileBlocks(message)
|
||||
for (const fileBlock of fileBlocks) {
|
||||
const file = fileBlock.file
|
||||
if (file.type === FileTypes.IMAGE) {
|
||||
const base64Data = await window.api.file.base64Image(file.id + file.ext)
|
||||
parts.push({
|
||||
inlineData: {
|
||||
data: base64Data.base64,
|
||||
mimeType: base64Data.mime
|
||||
} as Part['inlineData']
|
||||
})
|
||||
}
|
||||
|
||||
if (file.ext === '.pdf') {
|
||||
parts.push(await this.handlePdfFile(file))
|
||||
continue
|
||||
}
|
||||
if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) {
|
||||
const fileContent = await (await window.api.file.read(file.id + file.ext)).trim()
|
||||
parts.push({
|
||||
text: file.origin_name + '\n' + fileContent
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
role,
|
||||
parts: parts
|
||||
}
|
||||
}
|
||||
|
||||
// @ts-ignore unused
|
||||
private async getImageFileContents(message: Message): Promise<Content> {
|
||||
const role = message.role === 'user' ? 'user' : 'model'
|
||||
const content = getMainTextContent(message)
|
||||
const parts: Part[] = [{ text: content }]
|
||||
const imageBlocks = findImageBlocks(message)
|
||||
for (const imageBlock of imageBlocks) {
|
||||
if (
|
||||
imageBlock.metadata?.generateImageResponse?.images &&
|
||||
imageBlock.metadata.generateImageResponse.images.length > 0
|
||||
) {
|
||||
for (const imageUrl of imageBlock.metadata.generateImageResponse.images) {
|
||||
if (imageUrl && imageUrl.startsWith('data:')) {
|
||||
// Extract base64 data and mime type from the data URL
|
||||
const matches = imageUrl.match(/^data:(.+);base64,(.*)$/)
|
||||
if (matches && matches.length === 3) {
|
||||
const mimeType = matches[1]
|
||||
const base64Data = matches[2]
|
||||
parts.push({
|
||||
inlineData: {
|
||||
data: base64Data,
|
||||
mimeType: mimeType
|
||||
} as Part['inlineData']
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
const file = imageBlock.file
|
||||
if (file) {
|
||||
const base64Data = await window.api.file.base64Image(file.id + file.ext)
|
||||
parts.push({
|
||||
inlineData: {
|
||||
data: base64Data.base64,
|
||||
mimeType: base64Data.mime
|
||||
} as Part['inlineData']
|
||||
})
|
||||
}
|
||||
}
|
||||
return {
|
||||
role,
|
||||
parts: parts
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the safety settings
|
||||
* @returns The safety settings
|
||||
*/
|
||||
private getSafetySettings(): SafetySetting[] {
|
||||
const safetyThreshold = 'OFF' as HarmBlockThreshold
|
||||
|
||||
return [
|
||||
{
|
||||
category: HarmCategory.HARM_CATEGORY_HATE_SPEECH,
|
||||
threshold: safetyThreshold
|
||||
},
|
||||
{
|
||||
category: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
|
||||
threshold: safetyThreshold
|
||||
},
|
||||
{
|
||||
category: HarmCategory.HARM_CATEGORY_HARASSMENT,
|
||||
threshold: safetyThreshold
|
||||
},
|
||||
{
|
||||
category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
|
||||
threshold: safetyThreshold
|
||||
},
|
||||
{
|
||||
category: HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY,
|
||||
threshold: HarmBlockThreshold.BLOCK_NONE
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the reasoning effort for the assistant
|
||||
* @param assistant - The assistant
|
||||
* @param model - The model
|
||||
* @returns The reasoning effort
|
||||
*/
|
||||
private getBudgetToken(assistant: Assistant, model: Model) {
|
||||
if (isSupportedThinkingTokenGeminiModel(model)) {
|
||||
const reasoningEffort = assistant?.settings?.reasoning_effort
|
||||
|
||||
// 如果thinking_budget是undefined,不思考
|
||||
if (reasoningEffort === undefined) {
|
||||
return GEMINI_FLASH_MODEL_REGEX.test(model.id)
|
||||
? {
|
||||
thinkingConfig: {
|
||||
thinkingBudget: 0
|
||||
}
|
||||
}
|
||||
: {}
|
||||
}
|
||||
|
||||
if (reasoningEffort === 'auto') {
|
||||
return {
|
||||
thinkingConfig: {
|
||||
includeThoughts: true,
|
||||
thinkingBudget: -1
|
||||
}
|
||||
}
|
||||
}
|
||||
const effortRatio = EFFORT_RATIO[reasoningEffort]
|
||||
const { min, max } = findTokenLimit(model.id) || { min: 0, max: 0 }
|
||||
// 计算 budgetTokens,确保不低于 min
|
||||
const budget = Math.floor((max - min) * effortRatio + min)
|
||||
|
||||
return {
|
||||
thinkingConfig: {
|
||||
...(budget > 0 ? { thinkingBudget: budget } : {}),
|
||||
includeThoughts: true
|
||||
} as ThinkingConfig
|
||||
}
|
||||
}
|
||||
|
||||
return {}
|
||||
}
|
||||
|
||||
private getGenerateImageParameter(): Partial<GenerateContentConfig> {
|
||||
return {
|
||||
systemInstruction: undefined,
|
||||
responseModalities: [Modality.TEXT, Modality.IMAGE],
|
||||
responseMimeType: 'text/plain'
|
||||
}
|
||||
}
|
||||
|
||||
getRequestTransformer(): RequestTransformer<GeminiSdkParams, GeminiSdkMessageParam> {
|
||||
return {
|
||||
transform: async (
|
||||
coreRequest,
|
||||
assistant,
|
||||
model,
|
||||
isRecursiveCall,
|
||||
recursiveSdkMessages
|
||||
): Promise<{
|
||||
payload: GeminiSdkParams
|
||||
messages: GeminiSdkMessageParam[]
|
||||
metadata: Record<string, any>
|
||||
}> => {
|
||||
const { messages, mcpTools, maxTokens, enableWebSearch, enableGenerateImage } = coreRequest
|
||||
// 1. 处理系统消息
|
||||
let systemInstruction = assistant.prompt
|
||||
|
||||
// 2. 设置工具
|
||||
const { tools } = this.setupToolsConfig({
|
||||
mcpTools,
|
||||
model,
|
||||
enableToolUse: isEnabledToolUse(assistant)
|
||||
})
|
||||
|
||||
if (this.useSystemPromptForTools) {
|
||||
systemInstruction = await buildSystemPrompt(assistant.prompt || '', mcpTools, assistant)
|
||||
}
|
||||
|
||||
let messageContents: Content = { role: 'user', parts: [] } // Initialize messageContents
|
||||
const history: Content[] = []
|
||||
// 3. 处理用户消息
|
||||
if (typeof messages === 'string') {
|
||||
messageContents = {
|
||||
role: 'user',
|
||||
parts: [{ text: messages }]
|
||||
}
|
||||
} else {
|
||||
const userLastMessage = messages.pop()
|
||||
if (userLastMessage) {
|
||||
messageContents = await this.convertMessageToSdkParam(userLastMessage)
|
||||
for (const message of messages) {
|
||||
history.push(await this.convertMessageToSdkParam(message))
|
||||
}
|
||||
messages.push(userLastMessage)
|
||||
}
|
||||
}
|
||||
|
||||
if (enableWebSearch) {
|
||||
tools.push({
|
||||
googleSearch: {}
|
||||
})
|
||||
}
|
||||
|
||||
if (isGemmaModel(model) && assistant.prompt) {
|
||||
const isFirstMessage = history.length === 0
|
||||
if (isFirstMessage && messageContents) {
|
||||
const userMessageText =
|
||||
messageContents.parts && messageContents.parts.length > 0
|
||||
? (messageContents.parts[0] as Part).text || ''
|
||||
: ''
|
||||
const systemMessage = [
|
||||
{
|
||||
text:
|
||||
'<start_of_turn>user\n' +
|
||||
systemInstruction +
|
||||
'<end_of_turn>\n' +
|
||||
'<start_of_turn>user\n' +
|
||||
userMessageText +
|
||||
'<end_of_turn>'
|
||||
}
|
||||
] as Part[]
|
||||
if (messageContents && messageContents.parts) {
|
||||
messageContents.parts[0] = systemMessage[0]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const newHistory =
|
||||
isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0
|
||||
? recursiveSdkMessages.slice(0, recursiveSdkMessages.length - 1)
|
||||
: history
|
||||
|
||||
const newMessageContents =
|
||||
isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0
|
||||
? recursiveSdkMessages[recursiveSdkMessages.length - 1]
|
||||
: messageContents
|
||||
|
||||
const generateContentConfig: GenerateContentConfig = {
|
||||
safetySettings: this.getSafetySettings(),
|
||||
systemInstruction: isGemmaModel(model) ? undefined : systemInstruction,
|
||||
temperature: this.getTemperature(assistant, model),
|
||||
topP: this.getTopP(assistant, model),
|
||||
maxOutputTokens: maxTokens,
|
||||
tools: tools,
|
||||
...(enableGenerateImage ? this.getGenerateImageParameter() : {}),
|
||||
...this.getBudgetToken(assistant, model),
|
||||
// 只在对话场景下应用自定义参数,避免影响翻译、总结等其他业务逻辑
|
||||
...(coreRequest.callType === 'chat' ? this.getCustomParameters(assistant) : {})
|
||||
}
|
||||
|
||||
const param: GeminiSdkParams = {
|
||||
model: model.id,
|
||||
config: generateContentConfig,
|
||||
history: newHistory,
|
||||
message: newMessageContents.parts!
|
||||
}
|
||||
|
||||
return {
|
||||
payload: param,
|
||||
messages: [messageContents],
|
||||
metadata: {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
getResponseChunkTransformer(): ResponseChunkTransformer<GeminiSdkRawChunk> {
|
||||
return () => ({
|
||||
async transform(chunk: GeminiSdkRawChunk, controller: TransformStreamDefaultController<GenericChunk>) {
|
||||
const toolCalls: FunctionCall[] = []
|
||||
if (chunk.candidates && chunk.candidates.length > 0) {
|
||||
for (const candidate of chunk.candidates) {
|
||||
if (candidate.content) {
|
||||
candidate.content.parts?.forEach((part) => {
|
||||
const text = part.text || ''
|
||||
if (part.thought) {
|
||||
controller.enqueue({
|
||||
type: ChunkType.THINKING_DELTA,
|
||||
text: text
|
||||
})
|
||||
} else if (part.text) {
|
||||
controller.enqueue({
|
||||
type: ChunkType.TEXT_DELTA,
|
||||
text: text
|
||||
})
|
||||
} else if (part.inlineData) {
|
||||
controller.enqueue({
|
||||
type: ChunkType.IMAGE_COMPLETE,
|
||||
image: {
|
||||
type: 'base64',
|
||||
images: [
|
||||
part.inlineData?.data?.startsWith('data:')
|
||||
? part.inlineData?.data
|
||||
: `data:${part.inlineData?.mimeType || 'image/png'};base64,${part.inlineData?.data}`
|
||||
]
|
||||
}
|
||||
})
|
||||
} else if (part.functionCall) {
|
||||
toolCalls.push(part.functionCall)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
if (candidate.finishReason) {
|
||||
if (candidate.groundingMetadata) {
|
||||
controller.enqueue({
|
||||
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
|
||||
llm_web_search: {
|
||||
results: candidate.groundingMetadata,
|
||||
source: WebSearchSource.GEMINI
|
||||
}
|
||||
} as LLMWebSearchCompleteChunk)
|
||||
}
|
||||
controller.enqueue({
|
||||
type: ChunkType.LLM_RESPONSE_COMPLETE,
|
||||
response: {
|
||||
usage: {
|
||||
prompt_tokens: chunk.usageMetadata?.promptTokenCount || 0,
|
||||
completion_tokens:
|
||||
(chunk.usageMetadata?.totalTokenCount || 0) - (chunk.usageMetadata?.promptTokenCount || 0),
|
||||
total_tokens: chunk.usageMetadata?.totalTokenCount || 0
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (toolCalls.length > 0) {
|
||||
controller.enqueue({
|
||||
type: ChunkType.MCP_TOOL_CREATED,
|
||||
tool_calls: toolCalls
|
||||
})
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
public convertMcpToolsToSdkTools(mcpTools: MCPTool[]): Tool[] {
|
||||
return mcpToolsToGeminiTools(mcpTools)
|
||||
}
|
||||
|
||||
public convertSdkToolCallToMcp(toolCall: GeminiSdkToolCall, mcpTools: MCPTool[]): MCPTool | undefined {
|
||||
return geminiFunctionCallToMcpTool(mcpTools, toolCall)
|
||||
}
|
||||
|
||||
public convertSdkToolCallToMcpToolResponse(toolCall: GeminiSdkToolCall, mcpTool: MCPTool): ToolCallResponse {
|
||||
const parsedArgs = (() => {
|
||||
try {
|
||||
return typeof toolCall.args === 'string' ? JSON.parse(toolCall.args) : toolCall.args
|
||||
} catch {
|
||||
return toolCall.args
|
||||
}
|
||||
})()
|
||||
|
||||
return {
|
||||
id: toolCall.id || nanoid(),
|
||||
toolCallId: toolCall.id,
|
||||
tool: mcpTool,
|
||||
arguments: parsedArgs,
|
||||
status: 'pending'
|
||||
} as ToolCallResponse
|
||||
}
|
||||
|
||||
public convertMcpToolResponseToSdkMessageParam(
|
||||
mcpToolResponse: MCPToolResponse,
|
||||
resp: MCPCallToolResponse,
|
||||
model: Model
|
||||
): GeminiSdkMessageParam | undefined {
|
||||
if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) {
|
||||
return mcpToolCallResponseToGeminiMessage(mcpToolResponse, resp, isVisionModel(model))
|
||||
} else if ('toolCallId' in mcpToolResponse) {
|
||||
return {
|
||||
role: 'user',
|
||||
parts: [
|
||||
{
|
||||
functionResponse: {
|
||||
id: mcpToolResponse.toolCallId,
|
||||
name: mcpToolResponse.tool.id,
|
||||
response: {
|
||||
output: !resp.isError ? resp.content : undefined,
|
||||
error: resp.isError ? resp.content : undefined
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
} satisfies Content
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
public buildSdkMessages(
|
||||
currentReqMessages: Content[],
|
||||
output: string,
|
||||
toolResults: Content[],
|
||||
toolCalls: FunctionCall[]
|
||||
): Content[] {
|
||||
const parts: Part[] = []
|
||||
const modelParts: Part[] = []
|
||||
if (output) {
|
||||
modelParts.push({
|
||||
text: output
|
||||
})
|
||||
}
|
||||
|
||||
toolCalls.forEach((toolCall) => {
|
||||
modelParts.push({
|
||||
functionCall: toolCall
|
||||
})
|
||||
})
|
||||
|
||||
parts.push(
|
||||
...toolResults
|
||||
.map((ts) => ts.parts)
|
||||
.flat()
|
||||
.filter((p) => p !== undefined)
|
||||
)
|
||||
|
||||
const userMessage: Content = {
|
||||
role: 'user',
|
||||
parts: []
|
||||
}
|
||||
|
||||
if (modelParts.length > 0) {
|
||||
currentReqMessages.push({
|
||||
role: 'model',
|
||||
parts: modelParts
|
||||
})
|
||||
}
|
||||
if (parts.length > 0) {
|
||||
userMessage.parts?.push(...parts)
|
||||
currentReqMessages.push(userMessage)
|
||||
}
|
||||
|
||||
return currentReqMessages
|
||||
}
|
||||
|
||||
override estimateMessageTokens(message: GeminiSdkMessageParam): number {
|
||||
return (
|
||||
message.parts?.reduce((acc, part) => {
|
||||
if (part.text) {
|
||||
return acc + estimateTextTokens(part.text)
|
||||
}
|
||||
if (part.functionCall) {
|
||||
return acc + estimateTextTokens(JSON.stringify(part.functionCall))
|
||||
}
|
||||
if (part.functionResponse) {
|
||||
return acc + estimateTextTokens(JSON.stringify(part.functionResponse.response))
|
||||
}
|
||||
if (part.inlineData) {
|
||||
return acc + estimateTextTokens(part.inlineData.data || '')
|
||||
}
|
||||
if (part.fileData) {
|
||||
return acc + estimateTextTokens(part.fileData.fileUri || '')
|
||||
}
|
||||
return acc
|
||||
}, 0) || 0
|
||||
)
|
||||
}
|
||||
|
||||
public extractMessagesFromSdkPayload(sdkPayload: GeminiSdkParams): GeminiSdkMessageParam[] {
|
||||
const messageParam: GeminiSdkMessageParam = {
|
||||
role: 'user',
|
||||
parts: []
|
||||
}
|
||||
if (Array.isArray(sdkPayload.message)) {
|
||||
sdkPayload.message.forEach((part) => {
|
||||
if (typeof part === 'string') {
|
||||
messageParam.parts?.push({ text: part })
|
||||
} else if (typeof part === 'object') {
|
||||
messageParam.parts?.push(part)
|
||||
}
|
||||
})
|
||||
}
|
||||
return [...(sdkPayload.history || []), messageParam]
|
||||
}
|
||||
|
||||
private async uploadFile(file: FileType): Promise<File> {
|
||||
return await this.sdkInstance!.files.upload({
|
||||
file: file.path,
|
||||
config: {
|
||||
mimeType: 'application/pdf',
|
||||
name: file.id,
|
||||
displayName: file.origin_name
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
private async base64File(file: FileType) {
|
||||
const { data } = await window.api.file.base64File(file.id + file.ext)
|
||||
return {
|
||||
data,
|
||||
mimeType: 'application/pdf'
|
||||
}
|
||||
}
|
||||
|
||||
private async retrieveFile(file: FileType): Promise<File | undefined> {
|
||||
const cachedResponse = CacheService.get<any>('gemini_file_list')
|
||||
|
||||
if (cachedResponse) {
|
||||
return this.processResponse(cachedResponse, file)
|
||||
}
|
||||
|
||||
const response = await this.sdkInstance!.files.list()
|
||||
CacheService.set('gemini_file_list', response, 3000)
|
||||
|
||||
return this.processResponse(response, file)
|
||||
}
|
||||
|
||||
private async processResponse(response: Pager<File>, file: FileType) {
|
||||
for await (const f of response) {
|
||||
if (f.state === FileState.ACTIVE) {
|
||||
if (f.displayName === file.origin_name && Number(f.sizeBytes) === file.size) {
|
||||
return f
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return undefined
|
||||
}
|
||||
|
||||
// @ts-ignore unused
|
||||
private async listFiles(): Promise<File[]> {
|
||||
const files: File[] = []
|
||||
for await (const f of await this.sdkInstance!.files.list()) {
|
||||
files.push(f)
|
||||
}
|
||||
return files
|
||||
}
|
||||
|
||||
// @ts-ignore unused
|
||||
private async deleteFile(fileId: string) {
|
||||
await this.sdkInstance!.files.delete({ name: fileId })
|
||||
}
|
||||
}
|
||||
95
src/renderer/src/aiCore/clients/gemini/VertexAPIClient.ts
Normal file
95
src/renderer/src/aiCore/clients/gemini/VertexAPIClient.ts
Normal file
@ -0,0 +1,95 @@
|
||||
import { GoogleGenAI } from '@google/genai'
|
||||
import { getVertexAILocation, getVertexAIProjectId, getVertexAIServiceAccount } from '@renderer/hooks/useVertexAI'
|
||||
import { Provider } from '@renderer/types'
|
||||
|
||||
import { GeminiAPIClient } from './GeminiAPIClient'
|
||||
|
||||
export class VertexAPIClient extends GeminiAPIClient {
|
||||
private authHeaders?: Record<string, string>
|
||||
private authHeadersExpiry?: number
|
||||
|
||||
constructor(provider: Provider) {
|
||||
super(provider)
|
||||
}
|
||||
|
||||
override async getSdkInstance() {
|
||||
if (this.sdkInstance) {
|
||||
return this.sdkInstance
|
||||
}
|
||||
|
||||
const serviceAccount = getVertexAIServiceAccount()
|
||||
const projectId = getVertexAIProjectId()
|
||||
const location = getVertexAILocation()
|
||||
|
||||
if (!serviceAccount.privateKey || !serviceAccount.clientEmail || !projectId || !location) {
|
||||
throw new Error('Vertex AI settings are not configured')
|
||||
}
|
||||
|
||||
const authHeaders = await this.getServiceAccountAuthHeaders()
|
||||
|
||||
this.sdkInstance = new GoogleGenAI({
|
||||
vertexai: true,
|
||||
project: projectId,
|
||||
location: location,
|
||||
httpOptions: {
|
||||
apiVersion: this.getApiVersion(),
|
||||
headers: authHeaders
|
||||
}
|
||||
})
|
||||
|
||||
return this.sdkInstance
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取认证头,如果配置了 service account 则从主进程获取
|
||||
*/
|
||||
private async getServiceAccountAuthHeaders(): Promise<Record<string, string> | undefined> {
|
||||
const serviceAccount = getVertexAIServiceAccount()
|
||||
const projectId = getVertexAIProjectId()
|
||||
|
||||
// 检查是否配置了 service account
|
||||
if (!serviceAccount.privateKey || !serviceAccount.clientEmail || !projectId) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
// 检查是否已有有效的认证头(提前 5 分钟过期)
|
||||
const now = Date.now()
|
||||
if (this.authHeaders && this.authHeadersExpiry && this.authHeadersExpiry - now > 5 * 60 * 1000) {
|
||||
return this.authHeaders
|
||||
}
|
||||
|
||||
try {
|
||||
// 从主进程获取认证头
|
||||
this.authHeaders = await window.api.vertexAI.getAuthHeaders({
|
||||
projectId,
|
||||
serviceAccount: {
|
||||
privateKey: serviceAccount.privateKey,
|
||||
clientEmail: serviceAccount.clientEmail
|
||||
}
|
||||
})
|
||||
|
||||
// 设置过期时间(通常认证头有效期为 1 小时)
|
||||
this.authHeadersExpiry = now + 60 * 60 * 1000
|
||||
|
||||
return this.authHeaders
|
||||
} catch (error: any) {
|
||||
console.error('Failed to get auth headers:', error)
|
||||
throw new Error(`Service Account authentication failed: ${error.message}`)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 清理认证缓存并重新初始化
|
||||
*/
|
||||
clearAuthCache(): void {
|
||||
this.authHeaders = undefined
|
||||
this.authHeadersExpiry = undefined
|
||||
|
||||
const serviceAccount = getVertexAIServiceAccount()
|
||||
const projectId = getVertexAIProjectId()
|
||||
|
||||
if (projectId && serviceAccount.clientEmail) {
|
||||
window.api.vertexAI.clearAuthCache(projectId, serviceAccount.clientEmail)
|
||||
}
|
||||
}
|
||||
}
|
||||
6
src/renderer/src/aiCore/clients/index.ts
Normal file
6
src/renderer/src/aiCore/clients/index.ts
Normal file
@ -0,0 +1,6 @@
|
||||
export * from './ApiClientFactory'
|
||||
export * from './BaseApiClient'
|
||||
export * from './types'
|
||||
|
||||
// Export specific clients from subdirectories
|
||||
export * from './openai/OpenAIApiClient'
|
||||
764
src/renderer/src/aiCore/clients/openai/OpenAIApiClient.ts
Normal file
764
src/renderer/src/aiCore/clients/openai/OpenAIApiClient.ts
Normal file
@ -0,0 +1,764 @@
|
||||
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
|
||||
import Logger from '@renderer/config/logger'
|
||||
import {
|
||||
findTokenLimit,
|
||||
GEMINI_FLASH_MODEL_REGEX,
|
||||
getOpenAIWebSearchParams,
|
||||
isDoubaoThinkingAutoModel,
|
||||
isReasoningModel,
|
||||
isSupportedReasoningEffortGrokModel,
|
||||
isSupportedReasoningEffortModel,
|
||||
isSupportedReasoningEffortOpenAIModel,
|
||||
isSupportedThinkingTokenClaudeModel,
|
||||
isSupportedThinkingTokenDoubaoModel,
|
||||
isSupportedThinkingTokenGeminiModel,
|
||||
isSupportedThinkingTokenModel,
|
||||
isSupportedThinkingTokenQwenModel,
|
||||
isVisionModel
|
||||
} from '@renderer/config/models'
|
||||
import { processPostsuffixQwen3Model, processReqMessages } from '@renderer/services/ModelMessageService'
|
||||
import { estimateTextTokens } from '@renderer/services/TokenService'
|
||||
// For Copilot token
|
||||
import {
|
||||
Assistant,
|
||||
EFFORT_RATIO,
|
||||
FileTypes,
|
||||
MCPCallToolResponse,
|
||||
MCPTool,
|
||||
MCPToolResponse,
|
||||
Model,
|
||||
Provider,
|
||||
ToolCallResponse,
|
||||
WebSearchSource
|
||||
} from '@renderer/types'
|
||||
import { ChunkType } from '@renderer/types/chunk'
|
||||
import { Message } from '@renderer/types/newMessage'
|
||||
import {
|
||||
OpenAISdkMessageParam,
|
||||
OpenAISdkParams,
|
||||
OpenAISdkRawChunk,
|
||||
OpenAISdkRawContentSource,
|
||||
OpenAISdkRawOutput,
|
||||
ReasoningEffortOptionalParams
|
||||
} from '@renderer/types/sdk'
|
||||
import { addImageFileToContents } from '@renderer/utils/formats'
|
||||
import {
|
||||
isEnabledToolUse,
|
||||
mcpToolCallResponseToOpenAICompatibleMessage,
|
||||
mcpToolsToOpenAIChatTools,
|
||||
openAIToolsToMcpTool
|
||||
} from '@renderer/utils/mcp-tools'
|
||||
import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find'
|
||||
import { buildSystemPrompt } from '@renderer/utils/prompt'
|
||||
import OpenAI, { AzureOpenAI } from 'openai'
|
||||
import { ChatCompletionContentPart, ChatCompletionContentPartRefusal, ChatCompletionTool } from 'openai/resources'
|
||||
|
||||
import { GenericChunk } from '../../middleware/schemas'
|
||||
import { RequestTransformer, ResponseChunkTransformer, ResponseChunkTransformerContext } from '../types'
|
||||
import { OpenAIBaseClient } from './OpenAIBaseClient'
|
||||
|
||||
export class OpenAIAPIClient extends OpenAIBaseClient<
|
||||
OpenAI | AzureOpenAI,
|
||||
OpenAISdkParams,
|
||||
OpenAISdkRawOutput,
|
||||
OpenAISdkRawChunk,
|
||||
OpenAISdkMessageParam,
|
||||
OpenAI.Chat.Completions.ChatCompletionMessageToolCall,
|
||||
ChatCompletionTool
|
||||
> {
|
||||
constructor(provider: Provider) {
|
||||
super(provider)
|
||||
}
|
||||
|
||||
override async createCompletions(
|
||||
payload: OpenAISdkParams,
|
||||
options?: OpenAI.RequestOptions
|
||||
): Promise<OpenAISdkRawOutput> {
|
||||
const sdk = await this.getSdkInstance()
|
||||
// @ts-ignore - SDK参数可能有额外的字段
|
||||
return await sdk.chat.completions.create(payload, options)
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the reasoning effort for the assistant
|
||||
* @param assistant - The assistant
|
||||
* @param model - The model
|
||||
* @returns The reasoning effort
|
||||
*/
|
||||
// Method for reasoning effort, moved from OpenAIProvider
|
||||
override getReasoningEffort(assistant: Assistant, model: Model): ReasoningEffortOptionalParams {
|
||||
if (this.provider.id === 'groq') {
|
||||
return {}
|
||||
}
|
||||
|
||||
if (!isReasoningModel(model)) {
|
||||
return {}
|
||||
}
|
||||
const reasoningEffort = assistant?.settings?.reasoning_effort
|
||||
|
||||
// Doubao 思考模式支持
|
||||
if (isSupportedThinkingTokenDoubaoModel(model)) {
|
||||
// reasoningEffort 为空,默认开启 enabled
|
||||
if (!reasoningEffort) {
|
||||
return { thinking: { type: 'disabled' } }
|
||||
}
|
||||
if (reasoningEffort === 'high') {
|
||||
return { thinking: { type: 'enabled' } }
|
||||
}
|
||||
if (reasoningEffort === 'auto' && isDoubaoThinkingAutoModel(model)) {
|
||||
return { thinking: { type: 'auto' } }
|
||||
}
|
||||
// 其他情况不带 thinking 字段
|
||||
return {}
|
||||
}
|
||||
|
||||
if (!reasoningEffort) {
|
||||
if (model.provider === 'openrouter') {
|
||||
if (isSupportedThinkingTokenGeminiModel(model) && !GEMINI_FLASH_MODEL_REGEX.test(model.id)) {
|
||||
return {}
|
||||
}
|
||||
return { reasoning: { enabled: false, exclude: true } }
|
||||
}
|
||||
if (isSupportedThinkingTokenQwenModel(model)) {
|
||||
return { enable_thinking: false }
|
||||
}
|
||||
|
||||
if (isSupportedThinkingTokenClaudeModel(model)) {
|
||||
return {}
|
||||
}
|
||||
|
||||
if (isSupportedThinkingTokenGeminiModel(model)) {
|
||||
if (GEMINI_FLASH_MODEL_REGEX.test(model.id)) {
|
||||
return {
|
||||
extra_body: {
|
||||
google: {
|
||||
thinking_config: {
|
||||
thinking_budget: 0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return {}
|
||||
}
|
||||
|
||||
if (isSupportedThinkingTokenDoubaoModel(model)) {
|
||||
return { thinking: { type: 'disabled' } }
|
||||
}
|
||||
|
||||
return {}
|
||||
}
|
||||
const effortRatio = EFFORT_RATIO[reasoningEffort]
|
||||
const budgetTokens = Math.floor(
|
||||
(findTokenLimit(model.id)?.max! - findTokenLimit(model.id)?.min!) * effortRatio + findTokenLimit(model.id)?.min!
|
||||
)
|
||||
|
||||
// OpenRouter models
|
||||
if (model.provider === 'openrouter') {
|
||||
if (isSupportedReasoningEffortModel(model) || isSupportedThinkingTokenModel(model)) {
|
||||
return {
|
||||
reasoning: {
|
||||
effort: reasoningEffort === 'auto' ? 'medium' : reasoningEffort
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Qwen models
|
||||
if (isSupportedThinkingTokenQwenModel(model)) {
|
||||
return {
|
||||
enable_thinking: true,
|
||||
thinking_budget: budgetTokens
|
||||
}
|
||||
}
|
||||
|
||||
// Grok models
|
||||
if (isSupportedReasoningEffortGrokModel(model)) {
|
||||
return {
|
||||
reasoning_effort: reasoningEffort
|
||||
}
|
||||
}
|
||||
|
||||
// OpenAI models
|
||||
if (isSupportedReasoningEffortOpenAIModel(model)) {
|
||||
return {
|
||||
reasoning_effort: reasoningEffort
|
||||
}
|
||||
}
|
||||
|
||||
if (isSupportedThinkingTokenGeminiModel(model)) {
|
||||
if (reasoningEffort === 'auto') {
|
||||
return {
|
||||
extra_body: {
|
||||
google: {
|
||||
thinking_config: {
|
||||
thinking_budget: -1,
|
||||
include_thoughts: true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return {
|
||||
extra_body: {
|
||||
google: {
|
||||
thinking_config: {
|
||||
thinking_budget: budgetTokens,
|
||||
include_thoughts: true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Claude models
|
||||
if (isSupportedThinkingTokenClaudeModel(model)) {
|
||||
const maxTokens = assistant.settings?.maxTokens
|
||||
return {
|
||||
thinking: {
|
||||
type: 'enabled',
|
||||
budget_tokens: Math.floor(
|
||||
Math.max(1024, Math.min(budgetTokens, (maxTokens || DEFAULT_MAX_TOKENS) * effortRatio))
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Doubao models
|
||||
if (isSupportedThinkingTokenDoubaoModel(model)) {
|
||||
if (assistant.settings?.reasoning_effort === 'high') {
|
||||
return {
|
||||
thinking: {
|
||||
type: 'enabled'
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Default case: no special thinking settings
|
||||
return {}
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if the provider does not support files
|
||||
* @returns True if the provider does not support files, false otherwise
|
||||
*/
|
||||
private get isNotSupportFiles() {
|
||||
if (this.provider?.isNotSupportArrayContent) {
|
||||
return true
|
||||
}
|
||||
|
||||
const providers = ['deepseek', 'baichuan', 'minimax', 'xirang']
|
||||
|
||||
return providers.includes(this.provider.id)
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the message parameter
|
||||
* @param message - The message
|
||||
* @param model - The model
|
||||
* @returns The message parameter
|
||||
*/
|
||||
public async convertMessageToSdkParam(message: Message, model: Model): Promise<OpenAISdkMessageParam> {
|
||||
const isVision = isVisionModel(model)
|
||||
const content = await this.getMessageContent(message)
|
||||
const fileBlocks = findFileBlocks(message)
|
||||
const imageBlocks = findImageBlocks(message)
|
||||
|
||||
if (fileBlocks.length === 0 && imageBlocks.length === 0) {
|
||||
return {
|
||||
role: message.role === 'system' ? 'user' : message.role,
|
||||
content
|
||||
} as OpenAISdkMessageParam
|
||||
}
|
||||
|
||||
// If the model does not support files, extract the file content
|
||||
if (this.isNotSupportFiles) {
|
||||
const fileContent = await this.extractFileContent(message)
|
||||
|
||||
return {
|
||||
role: message.role === 'system' ? 'user' : message.role,
|
||||
content: content + '\n\n---\n\n' + fileContent
|
||||
} as OpenAISdkMessageParam
|
||||
}
|
||||
|
||||
// If the model supports files, add the file content to the message
|
||||
const parts: ChatCompletionContentPart[] = []
|
||||
|
||||
if (content) {
|
||||
parts.push({ type: 'text', text: content })
|
||||
}
|
||||
|
||||
for (const imageBlock of imageBlocks) {
|
||||
if (isVision) {
|
||||
if (imageBlock.file) {
|
||||
const image = await window.api.file.base64Image(imageBlock.file.id + imageBlock.file.ext)
|
||||
parts.push({ type: 'image_url', image_url: { url: image.data } })
|
||||
} else if (imageBlock.url && imageBlock.url.startsWith('data:')) {
|
||||
parts.push({ type: 'image_url', image_url: { url: imageBlock.url } })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (const fileBlock of fileBlocks) {
|
||||
const file = fileBlock.file
|
||||
if (!file) {
|
||||
continue
|
||||
}
|
||||
|
||||
if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) {
|
||||
const fileContent = await (await window.api.file.read(file.id + file.ext)).trim()
|
||||
parts.push({
|
||||
type: 'text',
|
||||
text: file.origin_name + '\n' + fileContent
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
role: message.role === 'system' ? 'user' : message.role,
|
||||
content: parts
|
||||
} as OpenAISdkMessageParam
|
||||
}
|
||||
|
||||
public convertMcpToolsToSdkTools(mcpTools: MCPTool[]): ChatCompletionTool[] {
|
||||
return mcpToolsToOpenAIChatTools(mcpTools)
|
||||
}
|
||||
|
||||
public convertSdkToolCallToMcp(
|
||||
toolCall: OpenAI.Chat.Completions.ChatCompletionMessageToolCall,
|
||||
mcpTools: MCPTool[]
|
||||
): MCPTool | undefined {
|
||||
return openAIToolsToMcpTool(mcpTools, toolCall)
|
||||
}
|
||||
|
||||
public convertSdkToolCallToMcpToolResponse(
|
||||
toolCall: OpenAI.Chat.Completions.ChatCompletionMessageToolCall,
|
||||
mcpTool: MCPTool
|
||||
): ToolCallResponse {
|
||||
let parsedArgs: any
|
||||
try {
|
||||
parsedArgs = JSON.parse(toolCall.function.arguments)
|
||||
} catch {
|
||||
parsedArgs = toolCall.function.arguments
|
||||
}
|
||||
return {
|
||||
id: toolCall.id,
|
||||
toolCallId: toolCall.id,
|
||||
tool: mcpTool,
|
||||
arguments: parsedArgs,
|
||||
status: 'pending'
|
||||
} as ToolCallResponse
|
||||
}
|
||||
|
||||
public convertMcpToolResponseToSdkMessageParam(
|
||||
mcpToolResponse: MCPToolResponse,
|
||||
resp: MCPCallToolResponse,
|
||||
model: Model
|
||||
): OpenAISdkMessageParam | undefined {
|
||||
if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) {
|
||||
// This case is for Anthropic/Claude like tool usage, OpenAI uses tool_call_id
|
||||
// For OpenAI, we primarily expect toolCallId. This might need adjustment if mixing provider concepts.
|
||||
return mcpToolCallResponseToOpenAICompatibleMessage(mcpToolResponse, resp, isVisionModel(model))
|
||||
} else if ('toolCallId' in mcpToolResponse && mcpToolResponse.toolCallId) {
|
||||
return {
|
||||
role: 'tool',
|
||||
tool_call_id: mcpToolResponse.toolCallId,
|
||||
content: JSON.stringify(resp.content)
|
||||
} as OpenAI.Chat.Completions.ChatCompletionToolMessageParam
|
||||
}
|
||||
return undefined
|
||||
}
|
||||
|
||||
public buildSdkMessages(
|
||||
currentReqMessages: OpenAISdkMessageParam[],
|
||||
output: string | undefined,
|
||||
toolResults: OpenAISdkMessageParam[],
|
||||
toolCalls: OpenAI.Chat.Completions.ChatCompletionMessageToolCall[]
|
||||
): OpenAISdkMessageParam[] {
|
||||
if (!output && toolCalls.length === 0) {
|
||||
return [...currentReqMessages, ...toolResults]
|
||||
}
|
||||
|
||||
const assistantMessage: OpenAISdkMessageParam = {
|
||||
role: 'assistant',
|
||||
content: output,
|
||||
tool_calls: toolCalls.length > 0 ? toolCalls : undefined
|
||||
}
|
||||
const newReqMessages = [...currentReqMessages, assistantMessage, ...toolResults]
|
||||
return newReqMessages
|
||||
}
|
||||
|
||||
override estimateMessageTokens(message: OpenAISdkMessageParam): number {
|
||||
let sum = 0
|
||||
if (typeof message.content === 'string') {
|
||||
sum += estimateTextTokens(message.content)
|
||||
} else if (Array.isArray(message.content)) {
|
||||
sum += (message.content || [])
|
||||
.map((part: ChatCompletionContentPart | ChatCompletionContentPartRefusal) => {
|
||||
switch (part.type) {
|
||||
case 'text':
|
||||
return estimateTextTokens(part.text)
|
||||
case 'image_url':
|
||||
return estimateTextTokens(part.image_url.url)
|
||||
case 'input_audio':
|
||||
return estimateTextTokens(part.input_audio.data)
|
||||
case 'file':
|
||||
return estimateTextTokens(part.file.file_data || '')
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
})
|
||||
.reduce((acc, curr) => acc + curr, 0)
|
||||
}
|
||||
if ('tool_calls' in message && message.tool_calls) {
|
||||
sum += message.tool_calls.reduce((acc, toolCall) => {
|
||||
return acc + estimateTextTokens(JSON.stringify(toolCall.function.arguments))
|
||||
}, 0)
|
||||
}
|
||||
return sum
|
||||
}
|
||||
|
||||
public extractMessagesFromSdkPayload(sdkPayload: OpenAISdkParams): OpenAISdkMessageParam[] {
|
||||
return sdkPayload.messages || []
|
||||
}
|
||||
|
||||
getRequestTransformer(): RequestTransformer<OpenAISdkParams, OpenAISdkMessageParam> {
|
||||
return {
|
||||
transform: async (
|
||||
coreRequest,
|
||||
assistant,
|
||||
model,
|
||||
isRecursiveCall,
|
||||
recursiveSdkMessages
|
||||
): Promise<{
|
||||
payload: OpenAISdkParams
|
||||
messages: OpenAISdkMessageParam[]
|
||||
metadata: Record<string, any>
|
||||
}> => {
|
||||
const { messages, mcpTools, maxTokens, streamOutput, enableWebSearch } = coreRequest
|
||||
// 1. 处理系统消息
|
||||
let systemMessage = { role: 'system', content: assistant.prompt || '' }
|
||||
|
||||
if (isSupportedReasoningEffortOpenAIModel(model)) {
|
||||
systemMessage = {
|
||||
role: 'developer',
|
||||
content: `Formatting re-enabled${systemMessage ? '\n' + systemMessage.content : ''}`
|
||||
}
|
||||
}
|
||||
|
||||
if (model.id.includes('o1-mini') || model.id.includes('o1-preview')) {
|
||||
systemMessage.role = 'assistant'
|
||||
}
|
||||
|
||||
// 2. 设置工具(必须在this.usesystemPromptForTools前面)
|
||||
const { tools } = this.setupToolsConfig({
|
||||
mcpTools: mcpTools,
|
||||
model,
|
||||
enableToolUse: isEnabledToolUse(assistant)
|
||||
})
|
||||
|
||||
if (this.useSystemPromptForTools) {
|
||||
systemMessage.content = await buildSystemPrompt(systemMessage.content || '', mcpTools, assistant)
|
||||
}
|
||||
|
||||
// 3. 处理用户消息
|
||||
const userMessages: OpenAISdkMessageParam[] = []
|
||||
if (typeof messages === 'string') {
|
||||
userMessages.push({ role: 'user', content: messages })
|
||||
} else {
|
||||
const processedMessages = addImageFileToContents(messages)
|
||||
for (const message of processedMessages) {
|
||||
userMessages.push(await this.convertMessageToSdkParam(message, model))
|
||||
}
|
||||
}
|
||||
|
||||
const lastUserMsg = userMessages.findLast((m) => m.role === 'user')
|
||||
if (lastUserMsg && isSupportedThinkingTokenQwenModel(model)) {
|
||||
const postsuffix = '/no_think'
|
||||
const qwenThinkModeEnabled = assistant.settings?.qwenThinkMode === true
|
||||
const currentContent = lastUserMsg.content
|
||||
|
||||
lastUserMsg.content = processPostsuffixQwen3Model(currentContent, postsuffix, qwenThinkModeEnabled) as any
|
||||
}
|
||||
|
||||
// 4. 最终请求消息
|
||||
let reqMessages: OpenAISdkMessageParam[]
|
||||
if (!systemMessage.content) {
|
||||
reqMessages = [...userMessages]
|
||||
} else {
|
||||
reqMessages = [systemMessage, ...userMessages].filter(Boolean) as OpenAISdkMessageParam[]
|
||||
}
|
||||
|
||||
reqMessages = processReqMessages(model, reqMessages)
|
||||
|
||||
// 5. 创建通用参数
|
||||
const commonParams = {
|
||||
model: model.id,
|
||||
messages:
|
||||
isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0
|
||||
? recursiveSdkMessages
|
||||
: reqMessages,
|
||||
temperature: this.getTemperature(assistant, model),
|
||||
top_p: this.getTopP(assistant, model),
|
||||
max_tokens: maxTokens,
|
||||
tools: tools.length > 0 ? tools : undefined,
|
||||
service_tier: this.getServiceTier(model),
|
||||
...this.getProviderSpecificParameters(assistant, model),
|
||||
...this.getReasoningEffort(assistant, model),
|
||||
...getOpenAIWebSearchParams(model, enableWebSearch),
|
||||
// 只在对话场景下应用自定义参数,避免影响翻译、总结等其他业务逻辑
|
||||
...(coreRequest.callType === 'chat' ? this.getCustomParameters(assistant) : {})
|
||||
}
|
||||
|
||||
// Create the appropriate parameters object based on whether streaming is enabled
|
||||
const sdkParams: OpenAISdkParams = streamOutput
|
||||
? {
|
||||
...commonParams,
|
||||
stream: true
|
||||
}
|
||||
: {
|
||||
...commonParams,
|
||||
stream: false
|
||||
}
|
||||
|
||||
const timeout = this.getTimeout(model)
|
||||
|
||||
return { payload: sdkParams, messages: reqMessages, metadata: { timeout } }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 在RawSdkChunkToGenericChunkMiddleware中使用
|
||||
getResponseChunkTransformer(): ResponseChunkTransformer<OpenAISdkRawChunk> {
|
||||
let hasBeenCollectedWebSearch = false
|
||||
const collectWebSearchData = (
|
||||
chunk: OpenAISdkRawChunk,
|
||||
contentSource: OpenAISdkRawContentSource,
|
||||
context: ResponseChunkTransformerContext
|
||||
) => {
|
||||
if (hasBeenCollectedWebSearch) {
|
||||
return
|
||||
}
|
||||
// OpenAI annotations
|
||||
// @ts-ignore - annotations may not be in standard type definitions
|
||||
const annotations = contentSource.annotations || chunk.annotations
|
||||
if (annotations && annotations.length > 0 && annotations[0].type === 'url_citation') {
|
||||
hasBeenCollectedWebSearch = true
|
||||
return {
|
||||
results: annotations,
|
||||
source: WebSearchSource.OPENAI
|
||||
}
|
||||
}
|
||||
|
||||
// Grok citations
|
||||
// @ts-ignore - citations may not be in standard type definitions
|
||||
if (context.provider?.id === 'grok' && chunk.citations) {
|
||||
hasBeenCollectedWebSearch = true
|
||||
return {
|
||||
// @ts-ignore - citations may not be in standard type definitions
|
||||
results: chunk.citations,
|
||||
source: WebSearchSource.GROK
|
||||
}
|
||||
}
|
||||
|
||||
// Perplexity citations
|
||||
// @ts-ignore - citations may not be in standard type definitions
|
||||
if (context.provider?.id === 'perplexity' && chunk.citations && chunk.citations.length > 0) {
|
||||
hasBeenCollectedWebSearch = true
|
||||
return {
|
||||
// @ts-ignore - citations may not be in standard type definitions
|
||||
results: chunk.citations,
|
||||
source: WebSearchSource.PERPLEXITY
|
||||
}
|
||||
}
|
||||
|
||||
// OpenRouter citations
|
||||
// @ts-ignore - citations may not be in standard type definitions
|
||||
if (context.provider?.id === 'openrouter' && chunk.citations && chunk.citations.length > 0) {
|
||||
hasBeenCollectedWebSearch = true
|
||||
return {
|
||||
// @ts-ignore - citations may not be in standard type definitions
|
||||
results: chunk.citations,
|
||||
source: WebSearchSource.OPENROUTER
|
||||
}
|
||||
}
|
||||
|
||||
// Zhipu web search
|
||||
// @ts-ignore - web_search may not be in standard type definitions
|
||||
if (context.provider?.id === 'zhipu' && chunk.web_search) {
|
||||
hasBeenCollectedWebSearch = true
|
||||
return {
|
||||
// @ts-ignore - web_search may not be in standard type definitions
|
||||
results: chunk.web_search,
|
||||
source: WebSearchSource.ZHIPU
|
||||
}
|
||||
}
|
||||
|
||||
// Hunyuan web search
|
||||
// @ts-ignore - search_info may not be in standard type definitions
|
||||
if (context.provider?.id === 'hunyuan' && chunk.search_info?.search_results) {
|
||||
hasBeenCollectedWebSearch = true
|
||||
return {
|
||||
// @ts-ignore - search_info may not be in standard type definitions
|
||||
results: chunk.search_info.search_results,
|
||||
source: WebSearchSource.HUNYUAN
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: 放到AnthropicApiClient中
|
||||
// // Other providers...
|
||||
// // @ts-ignore - web_search may not be in standard type definitions
|
||||
// if (chunk.web_search) {
|
||||
// const sourceMap: Record<string, string> = {
|
||||
// openai: 'openai',
|
||||
// anthropic: 'anthropic',
|
||||
// qwenlm: 'qwen'
|
||||
// }
|
||||
// const source = sourceMap[context.provider?.id] || 'openai_response'
|
||||
// return {
|
||||
// results: chunk.web_search,
|
||||
// source: source as const
|
||||
// }
|
||||
// }
|
||||
|
||||
return null
|
||||
}
|
||||
|
||||
const toolCalls: OpenAI.Chat.Completions.ChatCompletionMessageToolCall[] = []
|
||||
let isFinished = false
|
||||
let lastUsageInfo: any = null
|
||||
|
||||
/**
|
||||
* 统一的完成信号发送逻辑
|
||||
* - 有 finish_reason 时
|
||||
* - 无 finish_reason 但是流正常结束时
|
||||
*/
|
||||
const emitCompletionSignals = (controller: TransformStreamDefaultController<GenericChunk>) => {
|
||||
if (isFinished) return
|
||||
|
||||
if (toolCalls.length > 0) {
|
||||
controller.enqueue({
|
||||
type: ChunkType.MCP_TOOL_CREATED,
|
||||
tool_calls: toolCalls
|
||||
})
|
||||
}
|
||||
|
||||
const usage = lastUsageInfo || {
|
||||
prompt_tokens: 0,
|
||||
completion_tokens: 0,
|
||||
total_tokens: 0
|
||||
}
|
||||
|
||||
controller.enqueue({
|
||||
type: ChunkType.LLM_RESPONSE_COMPLETE,
|
||||
response: { usage }
|
||||
})
|
||||
|
||||
// 防止重复发送
|
||||
isFinished = true
|
||||
}
|
||||
|
||||
return (context: ResponseChunkTransformerContext) => ({
|
||||
async transform(chunk: OpenAISdkRawChunk, controller: TransformStreamDefaultController<GenericChunk>) {
|
||||
// 持续更新usage信息
|
||||
if (chunk.usage) {
|
||||
lastUsageInfo = {
|
||||
prompt_tokens: chunk.usage.prompt_tokens || 0,
|
||||
completion_tokens: chunk.usage.completion_tokens || 0,
|
||||
total_tokens: (chunk.usage.prompt_tokens || 0) + (chunk.usage.completion_tokens || 0)
|
||||
}
|
||||
}
|
||||
|
||||
// 处理chunk
|
||||
if ('choices' in chunk && chunk.choices && chunk.choices.length > 0) {
|
||||
const choice = chunk.choices[0]
|
||||
|
||||
if (!choice) return
|
||||
|
||||
// 对于流式响应,使用 delta;对于非流式响应,使用 message。
|
||||
// 然而某些 OpenAI 兼容平台在非流式请求时会错误地返回一个空对象的 delta 字段。
|
||||
// 如果 delta 为空对象,应当忽略它并回退到 message,避免造成内容缺失。
|
||||
let contentSource: OpenAISdkRawContentSource | null = null
|
||||
if ('delta' in choice && choice.delta && Object.keys(choice.delta).length > 0) {
|
||||
contentSource = choice.delta
|
||||
} else if ('message' in choice) {
|
||||
contentSource = choice.message
|
||||
}
|
||||
|
||||
if (!contentSource) return
|
||||
|
||||
const webSearchData = collectWebSearchData(chunk, contentSource, context)
|
||||
if (webSearchData) {
|
||||
controller.enqueue({
|
||||
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
|
||||
llm_web_search: webSearchData
|
||||
})
|
||||
}
|
||||
|
||||
// 处理推理内容 (e.g. from OpenRouter DeepSeek-R1)
|
||||
// @ts-ignore - reasoning_content is not in standard OpenAI types but some providers use it
|
||||
const reasoningText = contentSource.reasoning_content || contentSource.reasoning
|
||||
if (reasoningText) {
|
||||
controller.enqueue({
|
||||
type: ChunkType.THINKING_DELTA,
|
||||
text: reasoningText
|
||||
})
|
||||
}
|
||||
|
||||
// 处理文本内容
|
||||
if (contentSource.content) {
|
||||
controller.enqueue({
|
||||
type: ChunkType.TEXT_DELTA,
|
||||
text: contentSource.content
|
||||
})
|
||||
}
|
||||
|
||||
// 处理工具调用
|
||||
if (contentSource.tool_calls) {
|
||||
for (const toolCall of contentSource.tool_calls) {
|
||||
if ('index' in toolCall) {
|
||||
const { id, index, function: fun } = toolCall
|
||||
if (fun?.name) {
|
||||
toolCalls[index] = {
|
||||
id: id || '',
|
||||
function: {
|
||||
name: fun.name,
|
||||
arguments: fun.arguments || ''
|
||||
},
|
||||
type: 'function'
|
||||
}
|
||||
} else if (fun?.arguments) {
|
||||
toolCalls[index].function.arguments += fun.arguments
|
||||
}
|
||||
} else {
|
||||
toolCalls.push(toolCall)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 处理finish_reason,发送流结束信号
|
||||
if ('finish_reason' in choice && choice.finish_reason) {
|
||||
Logger.debug(`[OpenAIApiClient] Stream finished with reason: ${choice.finish_reason}`)
|
||||
const webSearchData = collectWebSearchData(chunk, contentSource, context)
|
||||
if (webSearchData) {
|
||||
controller.enqueue({
|
||||
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
|
||||
llm_web_search: webSearchData
|
||||
})
|
||||
}
|
||||
emitCompletionSignals(controller)
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
// 流正常结束时,检查是否需要发送完成信号
|
||||
flush(controller) {
|
||||
if (isFinished) return
|
||||
|
||||
Logger.debug('[OpenAIApiClient] Stream ended without finish_reason, emitting fallback completion signals')
|
||||
emitCompletionSignals(controller)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
256
src/renderer/src/aiCore/clients/openai/OpenAIBaseClient.ts
Normal file
256
src/renderer/src/aiCore/clients/openai/OpenAIBaseClient.ts
Normal file
@ -0,0 +1,256 @@
|
||||
import {
|
||||
isClaudeReasoningModel,
|
||||
isNotSupportTemperatureAndTopP,
|
||||
isOpenAIReasoningModel,
|
||||
isSupportedModel,
|
||||
isSupportedReasoningEffortOpenAIModel
|
||||
} from '@renderer/config/models'
|
||||
import { getStoreSetting } from '@renderer/hooks/useSettings'
|
||||
import { getAssistantSettings } from '@renderer/services/AssistantService'
|
||||
import store from '@renderer/store'
|
||||
import { SettingsState } from '@renderer/store/settings'
|
||||
import { Assistant, GenerateImageParams, Model, Provider } from '@renderer/types'
|
||||
import {
|
||||
OpenAIResponseSdkMessageParam,
|
||||
OpenAIResponseSdkParams,
|
||||
OpenAIResponseSdkRawChunk,
|
||||
OpenAIResponseSdkRawOutput,
|
||||
OpenAIResponseSdkTool,
|
||||
OpenAIResponseSdkToolCall,
|
||||
OpenAISdkMessageParam,
|
||||
OpenAISdkParams,
|
||||
OpenAISdkRawChunk,
|
||||
OpenAISdkRawOutput,
|
||||
ReasoningEffortOptionalParams
|
||||
} from '@renderer/types/sdk'
|
||||
import { formatApiHost } from '@renderer/utils/api'
|
||||
import OpenAI, { AzureOpenAI } from 'openai'
|
||||
|
||||
import { BaseApiClient } from '../BaseApiClient'
|
||||
|
||||
/**
|
||||
* 抽象的OpenAI基础客户端类,包含两个OpenAI客户端之间的共享功能
|
||||
*/
|
||||
export abstract class OpenAIBaseClient<
|
||||
TSdkInstance extends OpenAI | AzureOpenAI,
|
||||
TSdkParams extends OpenAISdkParams | OpenAIResponseSdkParams,
|
||||
TRawOutput extends OpenAISdkRawOutput | OpenAIResponseSdkRawOutput,
|
||||
TRawChunk extends OpenAISdkRawChunk | OpenAIResponseSdkRawChunk,
|
||||
TMessageParam extends OpenAISdkMessageParam | OpenAIResponseSdkMessageParam,
|
||||
TToolCall extends OpenAI.Chat.Completions.ChatCompletionMessageToolCall | OpenAIResponseSdkToolCall,
|
||||
TSdkSpecificTool extends OpenAI.Chat.Completions.ChatCompletionTool | OpenAIResponseSdkTool
|
||||
> extends BaseApiClient<TSdkInstance, TSdkParams, TRawOutput, TRawChunk, TMessageParam, TToolCall, TSdkSpecificTool> {
|
||||
constructor(provider: Provider) {
|
||||
super(provider)
|
||||
}
|
||||
|
||||
// 仅适用于openai
|
||||
override getBaseURL(): string {
|
||||
const host = this.provider.apiHost
|
||||
return formatApiHost(host)
|
||||
}
|
||||
|
||||
override async generateImage({
|
||||
model,
|
||||
prompt,
|
||||
negativePrompt,
|
||||
imageSize,
|
||||
batchSize,
|
||||
seed,
|
||||
numInferenceSteps,
|
||||
guidanceScale,
|
||||
signal,
|
||||
promptEnhancement
|
||||
}: GenerateImageParams): Promise<string[]> {
|
||||
const sdk = await this.getSdkInstance()
|
||||
const response = (await sdk.request({
|
||||
method: 'post',
|
||||
path: '/images/generations',
|
||||
signal,
|
||||
body: {
|
||||
model,
|
||||
prompt,
|
||||
negative_prompt: negativePrompt,
|
||||
image_size: imageSize,
|
||||
batch_size: batchSize,
|
||||
seed: seed ? parseInt(seed) : undefined,
|
||||
num_inference_steps: numInferenceSteps,
|
||||
guidance_scale: guidanceScale,
|
||||
prompt_enhancement: promptEnhancement
|
||||
}
|
||||
})) as { data: Array<{ url: string }> }
|
||||
|
||||
return response.data.map((item) => item.url)
|
||||
}
|
||||
|
||||
override async getEmbeddingDimensions(model: Model): Promise<number> {
|
||||
const sdk = await this.getSdkInstance()
|
||||
|
||||
const data = await sdk.embeddings.create({
|
||||
model: model.id,
|
||||
input: model?.provider === 'baidu-cloud' ? ['hi'] : 'hi',
|
||||
encoding_format: 'float'
|
||||
})
|
||||
return data.data[0].embedding.length
|
||||
}
|
||||
|
||||
override async listModels(): Promise<OpenAI.Models.Model[]> {
|
||||
try {
|
||||
const sdk = await this.getSdkInstance()
|
||||
const response = await sdk.models.list()
|
||||
if (this.provider.id === 'github') {
|
||||
// @ts-ignore key is not typed
|
||||
return response?.body
|
||||
.map((model) => ({
|
||||
id: model.name,
|
||||
description: model.summary,
|
||||
object: 'model',
|
||||
owned_by: model.publisher
|
||||
}))
|
||||
.filter(isSupportedModel)
|
||||
}
|
||||
if (this.provider.id === 'together') {
|
||||
// @ts-ignore key is not typed
|
||||
return response?.body.map((model) => ({
|
||||
id: model.id,
|
||||
description: model.display_name,
|
||||
object: 'model',
|
||||
owned_by: model.organization
|
||||
}))
|
||||
}
|
||||
const models = response.data || []
|
||||
models.forEach((model) => {
|
||||
model.id = model.id.trim()
|
||||
})
|
||||
|
||||
return models.filter(isSupportedModel)
|
||||
} catch (error) {
|
||||
console.error('Error listing models:', error)
|
||||
return []
|
||||
}
|
||||
}
|
||||
|
||||
override async getSdkInstance() {
|
||||
if (this.sdkInstance) {
|
||||
return this.sdkInstance
|
||||
}
|
||||
|
||||
let apiKeyForSdkInstance = this.apiKey
|
||||
|
||||
if (this.provider.id === 'copilot') {
|
||||
const defaultHeaders = store.getState().copilot.defaultHeaders
|
||||
const { token } = await window.api.copilot.getToken(defaultHeaders)
|
||||
// this.provider.apiKey不允许修改
|
||||
// this.provider.apiKey = token
|
||||
apiKeyForSdkInstance = token
|
||||
}
|
||||
|
||||
if (this.provider.id === 'azure-openai' || this.provider.type === 'azure-openai') {
|
||||
this.sdkInstance = new AzureOpenAI({
|
||||
dangerouslyAllowBrowser: true,
|
||||
apiKey: apiKeyForSdkInstance,
|
||||
apiVersion: this.provider.apiVersion,
|
||||
endpoint: this.provider.apiHost
|
||||
}) as TSdkInstance
|
||||
} else {
|
||||
this.sdkInstance = new OpenAI({
|
||||
dangerouslyAllowBrowser: true,
|
||||
apiKey: apiKeyForSdkInstance,
|
||||
baseURL: this.getBaseURL(),
|
||||
defaultHeaders: {
|
||||
...this.defaultHeaders(),
|
||||
...this.provider.extra_headers,
|
||||
...(this.provider.id === 'copilot' ? { 'editor-version': 'vscode/1.97.2' } : {}),
|
||||
...(this.provider.id === 'copilot' ? { 'copilot-vision-request': 'true' } : {})
|
||||
}
|
||||
}) as TSdkInstance
|
||||
}
|
||||
return this.sdkInstance
|
||||
}
|
||||
|
||||
override getTemperature(assistant: Assistant, model: Model): number | undefined {
|
||||
if (
|
||||
isNotSupportTemperatureAndTopP(model) ||
|
||||
(assistant.settings?.reasoning_effort && isClaudeReasoningModel(model))
|
||||
) {
|
||||
return undefined
|
||||
}
|
||||
return assistant.settings?.temperature
|
||||
}
|
||||
|
||||
override getTopP(assistant: Assistant, model: Model): number | undefined {
|
||||
if (
|
||||
isNotSupportTemperatureAndTopP(model) ||
|
||||
(assistant.settings?.reasoning_effort && isClaudeReasoningModel(model))
|
||||
) {
|
||||
return undefined
|
||||
}
|
||||
return assistant.settings?.topP
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the provider specific parameters for the assistant
|
||||
* @param assistant - The assistant
|
||||
* @param model - The model
|
||||
* @returns The provider specific parameters
|
||||
*/
|
||||
protected getProviderSpecificParameters(assistant: Assistant, model: Model) {
|
||||
const { maxTokens } = getAssistantSettings(assistant)
|
||||
|
||||
if (this.provider.id === 'openrouter') {
|
||||
if (model.id.includes('deepseek-r1')) {
|
||||
return {
|
||||
include_reasoning: true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (isOpenAIReasoningModel(model)) {
|
||||
return {
|
||||
max_tokens: undefined,
|
||||
max_completion_tokens: maxTokens
|
||||
}
|
||||
}
|
||||
|
||||
return {}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the reasoning effort for the assistant
|
||||
* @param assistant - The assistant
|
||||
* @param model - The model
|
||||
* @returns The reasoning effort
|
||||
*/
|
||||
protected getReasoningEffort(assistant: Assistant, model: Model): ReasoningEffortOptionalParams {
|
||||
if (!isSupportedReasoningEffortOpenAIModel(model)) {
|
||||
return {}
|
||||
}
|
||||
|
||||
const openAI = getStoreSetting('openAI') as SettingsState['openAI']
|
||||
const summaryText = openAI?.summaryText || 'off'
|
||||
|
||||
let summary: string | undefined = undefined
|
||||
|
||||
if (summaryText === 'off' || model.id.includes('o1-pro')) {
|
||||
summary = undefined
|
||||
} else {
|
||||
summary = summaryText
|
||||
}
|
||||
|
||||
const reasoningEffort = assistant?.settings?.reasoning_effort
|
||||
if (!reasoningEffort) {
|
||||
return {}
|
||||
}
|
||||
|
||||
if (isSupportedReasoningEffortOpenAIModel(model)) {
|
||||
return {
|
||||
reasoning: {
|
||||
effort: reasoningEffort as OpenAI.ReasoningEffort,
|
||||
summary: summary
|
||||
} as OpenAI.Reasoning
|
||||
}
|
||||
}
|
||||
|
||||
return {}
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,613 @@
|
||||
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
|
||||
import { CompletionsContext } from '@renderer/aiCore/middleware/types'
|
||||
import {
|
||||
isOpenAIChatCompletionOnlyModel,
|
||||
isSupportedReasoningEffortOpenAIModel,
|
||||
isVisionModel
|
||||
} from '@renderer/config/models'
|
||||
import { estimateTextTokens } from '@renderer/services/TokenService'
|
||||
import {
|
||||
FileType,
|
||||
FileTypes,
|
||||
MCPCallToolResponse,
|
||||
MCPTool,
|
||||
MCPToolResponse,
|
||||
Model,
|
||||
Provider,
|
||||
ToolCallResponse,
|
||||
WebSearchSource
|
||||
} from '@renderer/types'
|
||||
import { ChunkType } from '@renderer/types/chunk'
|
||||
import { Message } from '@renderer/types/newMessage'
|
||||
import {
|
||||
OpenAIResponseSdkMessageParam,
|
||||
OpenAIResponseSdkParams,
|
||||
OpenAIResponseSdkRawChunk,
|
||||
OpenAIResponseSdkRawOutput,
|
||||
OpenAIResponseSdkTool,
|
||||
OpenAIResponseSdkToolCall
|
||||
} from '@renderer/types/sdk'
|
||||
import { addImageFileToContents } from '@renderer/utils/formats'
|
||||
import {
|
||||
isEnabledToolUse,
|
||||
mcpToolCallResponseToOpenAIMessage,
|
||||
mcpToolsToOpenAIResponseTools,
|
||||
openAIToolsToMcpTool
|
||||
} from '@renderer/utils/mcp-tools'
|
||||
import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find'
|
||||
import { buildSystemPrompt } from '@renderer/utils/prompt'
|
||||
import { MB } from '@shared/config/constant'
|
||||
import { isEmpty } from 'lodash'
|
||||
import OpenAI from 'openai'
|
||||
import { ResponseInput } from 'openai/resources/responses/responses'
|
||||
|
||||
import { RequestTransformer, ResponseChunkTransformer } from '../types'
|
||||
import { OpenAIAPIClient } from './OpenAIApiClient'
|
||||
import { OpenAIBaseClient } from './OpenAIBaseClient'
|
||||
|
||||
export class OpenAIResponseAPIClient extends OpenAIBaseClient<
|
||||
OpenAI,
|
||||
OpenAIResponseSdkParams,
|
||||
OpenAIResponseSdkRawOutput,
|
||||
OpenAIResponseSdkRawChunk,
|
||||
OpenAIResponseSdkMessageParam,
|
||||
OpenAIResponseSdkToolCall,
|
||||
OpenAIResponseSdkTool
|
||||
> {
|
||||
private client: OpenAIAPIClient
|
||||
constructor(provider: Provider) {
|
||||
super(provider)
|
||||
this.client = new OpenAIAPIClient(provider)
|
||||
}
|
||||
|
||||
/**
|
||||
* 根据模型特征选择合适的客户端
|
||||
*/
|
||||
public getClient(model: Model) {
|
||||
if (isOpenAIChatCompletionOnlyModel(model)) {
|
||||
return this.client
|
||||
} else {
|
||||
return this
|
||||
}
|
||||
}
|
||||
|
||||
override async getSdkInstance() {
|
||||
if (this.sdkInstance) {
|
||||
return this.sdkInstance
|
||||
}
|
||||
|
||||
return new OpenAI({
|
||||
dangerouslyAllowBrowser: true,
|
||||
apiKey: this.apiKey,
|
||||
baseURL: this.getBaseURL(),
|
||||
defaultHeaders: {
|
||||
...this.defaultHeaders(),
|
||||
...this.provider.extra_headers
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
override async createCompletions(
|
||||
payload: OpenAIResponseSdkParams,
|
||||
options?: OpenAI.RequestOptions
|
||||
): Promise<OpenAIResponseSdkRawOutput> {
|
||||
const sdk = await this.getSdkInstance()
|
||||
return await sdk.responses.create(payload, options)
|
||||
}
|
||||
|
||||
private async handlePdfFile(file: FileType): Promise<OpenAI.Responses.ResponseInputFile | undefined> {
|
||||
if (file.size > 32 * MB) return undefined
|
||||
try {
|
||||
const pageCount = await window.api.file.pdfInfo(file.id + file.ext)
|
||||
if (pageCount > 100) return undefined
|
||||
} catch {
|
||||
return undefined
|
||||
}
|
||||
|
||||
const { data } = await window.api.file.base64File(file.id + file.ext)
|
||||
return {
|
||||
type: 'input_file',
|
||||
filename: file.origin_name,
|
||||
file_data: `data:application/pdf;base64,${data}`
|
||||
} as OpenAI.Responses.ResponseInputFile
|
||||
}
|
||||
|
||||
public async convertMessageToSdkParam(message: Message, model: Model): Promise<OpenAIResponseSdkMessageParam> {
|
||||
const isVision = isVisionModel(model)
|
||||
const content = await this.getMessageContent(message)
|
||||
const fileBlocks = findFileBlocks(message)
|
||||
const imageBlocks = findImageBlocks(message)
|
||||
|
||||
if (fileBlocks.length === 0 && imageBlocks.length === 0) {
|
||||
if (message.role === 'assistant') {
|
||||
return {
|
||||
role: 'assistant',
|
||||
content: content
|
||||
}
|
||||
} else {
|
||||
return {
|
||||
role: message.role === 'system' ? 'user' : message.role,
|
||||
content: content ? [{ type: 'input_text', text: content }] : []
|
||||
} as OpenAI.Responses.EasyInputMessage
|
||||
}
|
||||
}
|
||||
|
||||
const parts: OpenAI.Responses.ResponseInputContent[] = []
|
||||
if (content) {
|
||||
parts.push({
|
||||
type: 'input_text',
|
||||
text: content
|
||||
})
|
||||
}
|
||||
|
||||
for (const imageBlock of imageBlocks) {
|
||||
if (isVision) {
|
||||
if (imageBlock.file) {
|
||||
const image = await window.api.file.base64Image(imageBlock.file.id + imageBlock.file.ext)
|
||||
parts.push({
|
||||
detail: 'auto',
|
||||
type: 'input_image',
|
||||
image_url: image.data as string
|
||||
})
|
||||
} else if (imageBlock.url && imageBlock.url.startsWith('data:')) {
|
||||
parts.push({
|
||||
detail: 'auto',
|
||||
type: 'input_image',
|
||||
image_url: imageBlock.url
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (const fileBlock of fileBlocks) {
|
||||
const file = fileBlock.file
|
||||
if (!file) continue
|
||||
|
||||
if (isVision && file.ext === '.pdf') {
|
||||
const pdfPart = await this.handlePdfFile(file)
|
||||
if (pdfPart) {
|
||||
parts.push(pdfPart)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) {
|
||||
const fileContent = (await window.api.file.read(file.id + file.ext)).trim()
|
||||
parts.push({
|
||||
type: 'input_text',
|
||||
text: file.origin_name + '\n' + fileContent
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
role: message.role === 'system' ? 'user' : message.role,
|
||||
content: parts
|
||||
}
|
||||
}
|
||||
|
||||
public convertMcpToolsToSdkTools(mcpTools: MCPTool[]): OpenAI.Responses.Tool[] {
|
||||
return mcpToolsToOpenAIResponseTools(mcpTools)
|
||||
}
|
||||
|
||||
public convertSdkToolCallToMcp(toolCall: OpenAIResponseSdkToolCall, mcpTools: MCPTool[]): MCPTool | undefined {
|
||||
return openAIToolsToMcpTool(mcpTools, toolCall)
|
||||
}
|
||||
public convertSdkToolCallToMcpToolResponse(toolCall: OpenAIResponseSdkToolCall, mcpTool: MCPTool): ToolCallResponse {
|
||||
const parsedArgs = (() => {
|
||||
try {
|
||||
return JSON.parse(toolCall.arguments)
|
||||
} catch {
|
||||
return toolCall.arguments
|
||||
}
|
||||
})()
|
||||
|
||||
return {
|
||||
id: toolCall.call_id,
|
||||
toolCallId: toolCall.call_id,
|
||||
tool: mcpTool,
|
||||
arguments: parsedArgs,
|
||||
status: 'pending'
|
||||
}
|
||||
}
|
||||
|
||||
public convertMcpToolResponseToSdkMessageParam(
|
||||
mcpToolResponse: MCPToolResponse,
|
||||
resp: MCPCallToolResponse,
|
||||
model: Model
|
||||
): OpenAIResponseSdkMessageParam | undefined {
|
||||
if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) {
|
||||
return mcpToolCallResponseToOpenAIMessage(mcpToolResponse, resp, isVisionModel(model))
|
||||
} else if ('toolCallId' in mcpToolResponse && mcpToolResponse.toolCallId) {
|
||||
return {
|
||||
type: 'function_call_output',
|
||||
call_id: mcpToolResponse.toolCallId,
|
||||
output: JSON.stringify(resp.content)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
private convertResponseToMessageContent(response: OpenAI.Responses.Response): ResponseInput {
|
||||
const content: OpenAI.Responses.ResponseInput = []
|
||||
content.push(...response.output)
|
||||
return content
|
||||
}
|
||||
|
||||
public buildSdkMessages(
|
||||
currentReqMessages: OpenAIResponseSdkMessageParam[],
|
||||
output: OpenAI.Responses.Response | undefined,
|
||||
toolResults: OpenAIResponseSdkMessageParam[],
|
||||
toolCalls: OpenAIResponseSdkToolCall[]
|
||||
): OpenAIResponseSdkMessageParam[] {
|
||||
if (!output && toolCalls.length === 0) {
|
||||
return [...currentReqMessages, ...toolResults]
|
||||
}
|
||||
|
||||
if (!output) {
|
||||
return [...currentReqMessages, ...(toolCalls || []), ...(toolResults || [])]
|
||||
}
|
||||
|
||||
const content = this.convertResponseToMessageContent(output)
|
||||
|
||||
const newReqMessages = [...currentReqMessages, ...content, ...(toolResults || [])]
|
||||
return newReqMessages
|
||||
}
|
||||
|
||||
override estimateMessageTokens(message: OpenAIResponseSdkMessageParam): number {
|
||||
let sum = 0
|
||||
if ('content' in message) {
|
||||
if (typeof message.content === 'string') {
|
||||
sum += estimateTextTokens(message.content)
|
||||
} else if (Array.isArray(message.content)) {
|
||||
for (const part of message.content) {
|
||||
switch (part.type) {
|
||||
case 'input_text':
|
||||
sum += estimateTextTokens(part.text)
|
||||
break
|
||||
case 'input_image':
|
||||
sum += estimateTextTokens(part.image_url || '')
|
||||
break
|
||||
default:
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
switch (message.type) {
|
||||
case 'function_call_output':
|
||||
sum += estimateTextTokens(message.output)
|
||||
break
|
||||
case 'function_call':
|
||||
sum += estimateTextTokens(message.arguments)
|
||||
break
|
||||
default:
|
||||
break
|
||||
}
|
||||
return sum
|
||||
}
|
||||
|
||||
public extractMessagesFromSdkPayload(sdkPayload: OpenAIResponseSdkParams): OpenAIResponseSdkMessageParam[] {
|
||||
if (typeof sdkPayload.input === 'string') {
|
||||
return [{ role: 'user', content: sdkPayload.input }]
|
||||
}
|
||||
return sdkPayload.input
|
||||
}
|
||||
|
||||
getRequestTransformer(): RequestTransformer<OpenAIResponseSdkParams, OpenAIResponseSdkMessageParam> {
|
||||
return {
|
||||
transform: async (
|
||||
coreRequest,
|
||||
assistant,
|
||||
model,
|
||||
isRecursiveCall,
|
||||
recursiveSdkMessages
|
||||
): Promise<{
|
||||
payload: OpenAIResponseSdkParams
|
||||
messages: OpenAIResponseSdkMessageParam[]
|
||||
metadata: Record<string, any>
|
||||
}> => {
|
||||
const { messages, mcpTools, maxTokens, streamOutput, enableWebSearch, enableGenerateImage } = coreRequest
|
||||
// 1. 处理系统消息
|
||||
const systemMessage: OpenAI.Responses.EasyInputMessage = {
|
||||
role: 'system',
|
||||
content: []
|
||||
}
|
||||
|
||||
const systemMessageContent: OpenAI.Responses.ResponseInputMessageContentList = []
|
||||
const systemMessageInput: OpenAI.Responses.ResponseInputText = {
|
||||
text: assistant.prompt || '',
|
||||
type: 'input_text'
|
||||
}
|
||||
if (isSupportedReasoningEffortOpenAIModel(model)) {
|
||||
systemMessage.role = 'developer'
|
||||
}
|
||||
|
||||
// 2. 设置工具
|
||||
let tools: OpenAI.Responses.Tool[] = []
|
||||
const { tools: extraTools } = this.setupToolsConfig({
|
||||
mcpTools: mcpTools,
|
||||
model,
|
||||
enableToolUse: isEnabledToolUse(assistant)
|
||||
})
|
||||
|
||||
if (this.useSystemPromptForTools) {
|
||||
systemMessageInput.text = await buildSystemPrompt(systemMessageInput.text || '', mcpTools, assistant)
|
||||
}
|
||||
systemMessageContent.push(systemMessageInput)
|
||||
systemMessage.content = systemMessageContent
|
||||
|
||||
// 3. 处理用户消息
|
||||
let userMessage: OpenAI.Responses.ResponseInputItem[] = []
|
||||
if (typeof messages === 'string') {
|
||||
userMessage.push({ role: 'user', content: messages })
|
||||
} else {
|
||||
const processedMessages = addImageFileToContents(messages)
|
||||
for (const message of processedMessages) {
|
||||
userMessage.push(await this.convertMessageToSdkParam(message, model))
|
||||
}
|
||||
}
|
||||
// FIXME: 最好还是直接使用previous_response_id来处理(或者在数据库中存储image_generation_call的id)
|
||||
if (enableGenerateImage) {
|
||||
const finalAssistantMessage = userMessage.findLast(
|
||||
(m) => (m as OpenAI.Responses.EasyInputMessage).role === 'assistant'
|
||||
) as OpenAI.Responses.EasyInputMessage
|
||||
const finalUserMessage = userMessage.pop() as OpenAI.Responses.EasyInputMessage
|
||||
if (
|
||||
finalAssistantMessage &&
|
||||
Array.isArray(finalAssistantMessage.content) &&
|
||||
finalUserMessage &&
|
||||
Array.isArray(finalUserMessage.content)
|
||||
) {
|
||||
finalAssistantMessage.content = [...finalAssistantMessage.content, ...finalUserMessage.content]
|
||||
}
|
||||
// 这里是故意将上条助手消息的内容(包含图片和文件)作为用户消息发送
|
||||
userMessage = [{ ...finalAssistantMessage, role: 'user' } as OpenAI.Responses.EasyInputMessage]
|
||||
}
|
||||
|
||||
// 4. 最终请求消息
|
||||
let reqMessages: OpenAI.Responses.ResponseInput
|
||||
if (!systemMessage.content) {
|
||||
reqMessages = [...userMessage]
|
||||
} else {
|
||||
reqMessages = [systemMessage, ...userMessage].filter(Boolean) as OpenAI.Responses.EasyInputMessage[]
|
||||
}
|
||||
|
||||
if (enableWebSearch) {
|
||||
tools.push({
|
||||
type: 'web_search_preview'
|
||||
})
|
||||
}
|
||||
|
||||
if (enableGenerateImage) {
|
||||
tools.push({
|
||||
type: 'image_generation',
|
||||
partial_images: streamOutput ? 2 : undefined
|
||||
})
|
||||
}
|
||||
|
||||
tools = tools.concat(extraTools)
|
||||
const commonParams = {
|
||||
model: model.id,
|
||||
input:
|
||||
isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0
|
||||
? recursiveSdkMessages
|
||||
: reqMessages,
|
||||
temperature: this.getTemperature(assistant, model),
|
||||
top_p: this.getTopP(assistant, model),
|
||||
max_output_tokens: maxTokens,
|
||||
stream: streamOutput,
|
||||
tools: !isEmpty(tools) ? tools : undefined,
|
||||
service_tier: this.getServiceTier(model),
|
||||
...(this.getReasoningEffort(assistant, model) as OpenAI.Reasoning),
|
||||
// 只在对话场景下应用自定义参数,避免影响翻译、总结等其他业务逻辑
|
||||
...(coreRequest.callType === 'chat' ? this.getCustomParameters(assistant) : {})
|
||||
}
|
||||
const sdkParams: OpenAIResponseSdkParams = streamOutput
|
||||
? {
|
||||
...commonParams,
|
||||
stream: true
|
||||
}
|
||||
: {
|
||||
...commonParams,
|
||||
stream: false
|
||||
}
|
||||
const timeout = this.getTimeout(model)
|
||||
return { payload: sdkParams, messages: reqMessages, metadata: { timeout } }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
getResponseChunkTransformer(ctx: CompletionsContext): ResponseChunkTransformer<OpenAIResponseSdkRawChunk> {
|
||||
const toolCalls: OpenAIResponseSdkToolCall[] = []
|
||||
const outputItems: OpenAI.Responses.ResponseOutputItem[] = []
|
||||
let hasBeenCollectedToolCalls = false
|
||||
let hasReasoningSummary = false
|
||||
return () => ({
|
||||
async transform(chunk: OpenAIResponseSdkRawChunk, controller: TransformStreamDefaultController<GenericChunk>) {
|
||||
// 处理chunk
|
||||
if ('output' in chunk) {
|
||||
if (ctx._internal?.toolProcessingState) {
|
||||
ctx._internal.toolProcessingState.output = chunk
|
||||
}
|
||||
for (const output of chunk.output) {
|
||||
switch (output.type) {
|
||||
case 'message':
|
||||
if (output.content[0].type === 'output_text') {
|
||||
controller.enqueue({
|
||||
type: ChunkType.TEXT_DELTA,
|
||||
text: output.content[0].text
|
||||
})
|
||||
if (output.content[0].annotations && output.content[0].annotations.length > 0) {
|
||||
controller.enqueue({
|
||||
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
|
||||
llm_web_search: {
|
||||
source: WebSearchSource.OPENAI_RESPONSE,
|
||||
results: output.content[0].annotations
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
break
|
||||
case 'reasoning':
|
||||
controller.enqueue({
|
||||
type: ChunkType.THINKING_DELTA,
|
||||
text: output.summary.map((s) => s.text).join('\n')
|
||||
})
|
||||
break
|
||||
case 'function_call':
|
||||
toolCalls.push(output)
|
||||
break
|
||||
case 'image_generation_call':
|
||||
controller.enqueue({
|
||||
type: ChunkType.IMAGE_CREATED
|
||||
})
|
||||
controller.enqueue({
|
||||
type: ChunkType.IMAGE_COMPLETE,
|
||||
image: {
|
||||
type: 'base64',
|
||||
images: [`data:image/png;base64,${output.result}`]
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
if (toolCalls.length > 0) {
|
||||
controller.enqueue({
|
||||
type: ChunkType.MCP_TOOL_CREATED,
|
||||
tool_calls: toolCalls
|
||||
})
|
||||
}
|
||||
controller.enqueue({
|
||||
type: ChunkType.LLM_RESPONSE_COMPLETE,
|
||||
response: {
|
||||
usage: {
|
||||
prompt_tokens: chunk.usage?.input_tokens || 0,
|
||||
completion_tokens: chunk.usage?.output_tokens || 0,
|
||||
total_tokens: chunk.usage?.total_tokens || 0
|
||||
}
|
||||
}
|
||||
})
|
||||
} else {
|
||||
switch (chunk.type) {
|
||||
case 'response.output_item.added':
|
||||
if (chunk.item.type === 'function_call') {
|
||||
outputItems.push(chunk.item)
|
||||
}
|
||||
break
|
||||
case 'response.reasoning_summary_part.added':
|
||||
if (hasReasoningSummary) {
|
||||
const separator = '\n\n'
|
||||
controller.enqueue({
|
||||
type: ChunkType.THINKING_DELTA,
|
||||
text: separator
|
||||
})
|
||||
}
|
||||
hasReasoningSummary = true
|
||||
break
|
||||
case 'response.reasoning_summary_text.delta':
|
||||
controller.enqueue({
|
||||
type: ChunkType.THINKING_DELTA,
|
||||
text: chunk.delta
|
||||
})
|
||||
break
|
||||
case 'response.image_generation_call.generating':
|
||||
controller.enqueue({
|
||||
type: ChunkType.IMAGE_CREATED
|
||||
})
|
||||
break
|
||||
case 'response.image_generation_call.partial_image':
|
||||
controller.enqueue({
|
||||
type: ChunkType.IMAGE_DELTA,
|
||||
image: {
|
||||
type: 'base64',
|
||||
images: [`data:image/png;base64,${chunk.partial_image_b64}`]
|
||||
}
|
||||
})
|
||||
break
|
||||
case 'response.image_generation_call.completed':
|
||||
controller.enqueue({
|
||||
type: ChunkType.IMAGE_COMPLETE
|
||||
})
|
||||
break
|
||||
case 'response.output_text.delta': {
|
||||
controller.enqueue({
|
||||
type: ChunkType.TEXT_DELTA,
|
||||
text: chunk.delta
|
||||
})
|
||||
break
|
||||
}
|
||||
case 'response.function_call_arguments.done': {
|
||||
const outputItem: OpenAI.Responses.ResponseOutputItem | undefined = outputItems.find(
|
||||
(item) => item.id === chunk.item_id
|
||||
)
|
||||
if (outputItem) {
|
||||
if (outputItem.type === 'function_call') {
|
||||
toolCalls.push({
|
||||
...outputItem,
|
||||
arguments: chunk.arguments,
|
||||
status: 'completed'
|
||||
})
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
case 'response.content_part.done': {
|
||||
if (chunk.part.type === 'output_text' && chunk.part.annotations && chunk.part.annotations.length > 0) {
|
||||
controller.enqueue({
|
||||
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
|
||||
llm_web_search: {
|
||||
source: WebSearchSource.OPENAI_RESPONSE,
|
||||
results: chunk.part.annotations
|
||||
}
|
||||
})
|
||||
}
|
||||
if (toolCalls.length > 0 && !hasBeenCollectedToolCalls) {
|
||||
controller.enqueue({
|
||||
type: ChunkType.MCP_TOOL_CREATED,
|
||||
tool_calls: toolCalls
|
||||
})
|
||||
hasBeenCollectedToolCalls = true
|
||||
}
|
||||
break
|
||||
}
|
||||
case 'response.completed': {
|
||||
if (ctx._internal?.toolProcessingState) {
|
||||
ctx._internal.toolProcessingState.output = chunk.response
|
||||
}
|
||||
if (toolCalls.length > 0 && !hasBeenCollectedToolCalls) {
|
||||
controller.enqueue({
|
||||
type: ChunkType.MCP_TOOL_CREATED,
|
||||
tool_calls: toolCalls
|
||||
})
|
||||
hasBeenCollectedToolCalls = true
|
||||
}
|
||||
const completion_tokens = chunk.response.usage?.output_tokens || 0
|
||||
const total_tokens = chunk.response.usage?.total_tokens || 0
|
||||
controller.enqueue({
|
||||
type: ChunkType.LLM_RESPONSE_COMPLETE,
|
||||
response: {
|
||||
usage: {
|
||||
prompt_tokens: chunk.response.usage?.input_tokens || 0,
|
||||
completion_tokens: completion_tokens,
|
||||
total_tokens: total_tokens
|
||||
}
|
||||
}
|
||||
})
|
||||
break
|
||||
}
|
||||
case 'error': {
|
||||
controller.enqueue({
|
||||
type: ChunkType.ERROR,
|
||||
error: {
|
||||
message: chunk.message,
|
||||
code: chunk.code
|
||||
}
|
||||
})
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
65
src/renderer/src/aiCore/clients/ppio/PPIOAPIClient.ts
Normal file
65
src/renderer/src/aiCore/clients/ppio/PPIOAPIClient.ts
Normal file
@ -0,0 +1,65 @@
|
||||
import { isSupportedModel } from '@renderer/config/models'
|
||||
import { Provider } from '@renderer/types'
|
||||
import OpenAI from 'openai'
|
||||
|
||||
import { OpenAIAPIClient } from '../openai/OpenAIApiClient'
|
||||
|
||||
export class PPIOAPIClient extends OpenAIAPIClient {
|
||||
constructor(provider: Provider) {
|
||||
super(provider)
|
||||
}
|
||||
|
||||
override async listModels(): Promise<OpenAI.Models.Model[]> {
|
||||
try {
|
||||
const sdk = await this.getSdkInstance()
|
||||
|
||||
// PPIO requires three separate requests to get all model types
|
||||
const [chatModelsResponse, embeddingModelsResponse, rerankerModelsResponse] = await Promise.all([
|
||||
// Chat/completion models
|
||||
sdk.request({
|
||||
method: 'get',
|
||||
path: '/models'
|
||||
}),
|
||||
// Embedding models
|
||||
sdk.request({
|
||||
method: 'get',
|
||||
path: '/models?model_type=embedding'
|
||||
}),
|
||||
// Reranker models
|
||||
sdk.request({
|
||||
method: 'get',
|
||||
path: '/models?model_type=reranker'
|
||||
})
|
||||
])
|
||||
|
||||
// Extract models from all responses
|
||||
// @ts-ignore - PPIO response structure may not be typed
|
||||
const allModels = [
|
||||
...((chatModelsResponse as any)?.data || []),
|
||||
...((embeddingModelsResponse as any)?.data || []),
|
||||
...((rerankerModelsResponse as any)?.data || [])
|
||||
]
|
||||
|
||||
// Process and standardize model data
|
||||
const processedModels = allModels.map((model: any) => ({
|
||||
id: model.id || model.name,
|
||||
description: model.description || model.display_name || model.summary,
|
||||
object: 'model' as const,
|
||||
owned_by: model.owned_by || model.publisher || model.organization || 'ppio',
|
||||
created: model.created || Date.now()
|
||||
}))
|
||||
|
||||
// Clean up model IDs and filter supported models
|
||||
processedModels.forEach((model) => {
|
||||
if (model.id) {
|
||||
model.id = model.id.trim()
|
||||
}
|
||||
})
|
||||
|
||||
return processedModels.filter(isSupportedModel)
|
||||
} catch (error) {
|
||||
console.error('Error listing PPIO models:', error)
|
||||
return []
|
||||
}
|
||||
}
|
||||
}
|
||||
140
src/renderer/src/aiCore/clients/types.ts
Normal file
140
src/renderer/src/aiCore/clients/types.ts
Normal file
@ -0,0 +1,140 @@
|
||||
import Anthropic from '@anthropic-ai/sdk'
|
||||
import { Assistant, MCPTool, MCPToolResponse, Model, ToolCallResponse } from '@renderer/types'
|
||||
import { Provider } from '@renderer/types'
|
||||
import {
|
||||
AnthropicSdkRawChunk,
|
||||
OpenAIResponseSdkRawChunk,
|
||||
OpenAIResponseSdkRawOutput,
|
||||
OpenAISdkRawChunk,
|
||||
SdkMessageParam,
|
||||
SdkParams,
|
||||
SdkRawChunk,
|
||||
SdkRawOutput,
|
||||
SdkTool,
|
||||
SdkToolCall
|
||||
} from '@renderer/types/sdk'
|
||||
import OpenAI from 'openai'
|
||||
|
||||
import { CompletionsParams, GenericChunk } from '../middleware/schemas'
|
||||
import { CompletionsContext } from '../middleware/types'
|
||||
|
||||
/**
|
||||
* 原始流监听器接口
|
||||
*/
|
||||
export interface RawStreamListener<TRawChunk = SdkRawChunk> {
|
||||
onChunk?: (chunk: TRawChunk) => void
|
||||
onStart?: () => void
|
||||
onEnd?: () => void
|
||||
onError?: (error: Error) => void
|
||||
}
|
||||
|
||||
/**
|
||||
* OpenAI 专用的流监听器
|
||||
*/
|
||||
export interface OpenAIStreamListener extends RawStreamListener<OpenAISdkRawChunk> {
|
||||
onChoice?: (choice: OpenAI.Chat.Completions.ChatCompletionChunk.Choice) => void
|
||||
onFinishReason?: (reason: string) => void
|
||||
}
|
||||
|
||||
/**
|
||||
* OpenAI Response 专用的流监听器
|
||||
*/
|
||||
export interface OpenAIResponseStreamListener<TChunk extends OpenAIResponseSdkRawChunk = OpenAIResponseSdkRawChunk>
|
||||
extends RawStreamListener<TChunk> {
|
||||
onMessage?: (response: OpenAIResponseSdkRawOutput) => void
|
||||
}
|
||||
|
||||
/**
|
||||
* Anthropic 专用的流监听器
|
||||
*/
|
||||
export interface AnthropicStreamListener<TChunk extends AnthropicSdkRawChunk = AnthropicSdkRawChunk>
|
||||
extends RawStreamListener<TChunk> {
|
||||
onContentBlock?: (contentBlock: Anthropic.Messages.ContentBlock) => void
|
||||
onMessage?: (message: Anthropic.Messages.Message) => void
|
||||
}
|
||||
|
||||
/**
|
||||
* 请求转换器接口
|
||||
*/
|
||||
export interface RequestTransformer<
|
||||
TSdkParams extends SdkParams = SdkParams,
|
||||
TMessageParam extends SdkMessageParam = SdkMessageParam
|
||||
> {
|
||||
transform(
|
||||
completionsParams: CompletionsParams,
|
||||
assistant: Assistant,
|
||||
model: Model,
|
||||
isRecursiveCall?: boolean,
|
||||
recursiveSdkMessages?: TMessageParam[]
|
||||
): Promise<{
|
||||
payload: TSdkParams
|
||||
messages: TMessageParam[]
|
||||
metadata?: Record<string, any>
|
||||
}>
|
||||
}
|
||||
|
||||
/**
|
||||
* 响应块转换器接口
|
||||
*/
|
||||
export type ResponseChunkTransformer<TRawChunk extends SdkRawChunk = SdkRawChunk, TContext = any> = (
|
||||
context?: TContext
|
||||
) => Transformer<TRawChunk, GenericChunk>
|
||||
|
||||
export interface ResponseChunkTransformerContext {
|
||||
isStreaming: boolean
|
||||
isEnabledToolCalling: boolean
|
||||
isEnabledWebSearch: boolean
|
||||
isEnabledReasoning: boolean
|
||||
mcpTools: MCPTool[]
|
||||
provider: Provider
|
||||
}
|
||||
|
||||
/**
|
||||
* API客户端接口
|
||||
*/
|
||||
export interface ApiClient<
|
||||
TSdkInstance = any,
|
||||
TSdkParams extends SdkParams = SdkParams,
|
||||
TRawOutput extends SdkRawOutput = SdkRawOutput,
|
||||
TRawChunk extends SdkRawChunk = SdkRawChunk,
|
||||
TMessageParam extends SdkMessageParam = SdkMessageParam,
|
||||
TToolCall extends SdkToolCall = SdkToolCall,
|
||||
TSdkSpecificTool extends SdkTool = SdkTool
|
||||
> {
|
||||
provider: Provider
|
||||
|
||||
// 核心方法 - 在中间件架构中,这个方法可能只是一个占位符
|
||||
// 实际的SDK调用由SdkCallMiddleware处理
|
||||
// completions(params: CompletionsParams): Promise<CompletionsResult>
|
||||
|
||||
createCompletions(payload: TSdkParams): Promise<TRawOutput>
|
||||
|
||||
// SDK相关方法
|
||||
getSdkInstance(): Promise<TSdkInstance> | TSdkInstance
|
||||
getRequestTransformer(): RequestTransformer<TSdkParams, TMessageParam>
|
||||
getResponseChunkTransformer(ctx: CompletionsContext): ResponseChunkTransformer<TRawChunk>
|
||||
|
||||
// 原始流监听方法
|
||||
attachRawStreamListener?(rawOutput: TRawOutput, listener: RawStreamListener<TRawChunk>): TRawOutput
|
||||
|
||||
// 工具转换相关方法 (保持可选,因为不是所有Provider都支持工具)
|
||||
convertMcpToolsToSdkTools(mcpTools: MCPTool[]): TSdkSpecificTool[]
|
||||
convertMcpToolResponseToSdkMessageParam?(
|
||||
mcpToolResponse: MCPToolResponse,
|
||||
resp: any,
|
||||
model: Model
|
||||
): TMessageParam | undefined
|
||||
convertSdkToolCallToMcp?(toolCall: TToolCall, mcpTools: MCPTool[]): MCPTool | undefined
|
||||
convertSdkToolCallToMcpToolResponse(toolCall: TToolCall, mcpTool: MCPTool): ToolCallResponse
|
||||
|
||||
// 构建SDK特定的消息列表,用于工具调用后的递归调用
|
||||
buildSdkMessages(
|
||||
currentReqMessages: TMessageParam[],
|
||||
output: TRawOutput | string,
|
||||
toolResults: TMessageParam[],
|
||||
toolCalls?: TToolCall[]
|
||||
): TMessageParam[]
|
||||
|
||||
// 从SDK载荷中提取消息数组(用于中间件中的类型安全访问)
|
||||
extractMessagesFromSdkPayload(sdkPayload: TSdkParams): TMessageParam[]
|
||||
}
|
||||
132
src/renderer/src/aiCore/index.ts
Normal file
132
src/renderer/src/aiCore/index.ts
Normal file
@ -0,0 +1,132 @@
|
||||
import { ApiClientFactory } from '@renderer/aiCore/clients/ApiClientFactory'
|
||||
import { BaseApiClient } from '@renderer/aiCore/clients/BaseApiClient'
|
||||
import { isDedicatedImageGenerationModel, isFunctionCallingModel } from '@renderer/config/models'
|
||||
import type { GenerateImageParams, Model, Provider } from '@renderer/types'
|
||||
import { RequestOptions, SdkModel } from '@renderer/types/sdk'
|
||||
import { isEnabledToolUse } from '@renderer/utils/mcp-tools'
|
||||
|
||||
import { OpenAIAPIClient } from './clients'
|
||||
import { AihubmixAPIClient } from './clients/AihubmixAPIClient'
|
||||
import { AnthropicAPIClient } from './clients/anthropic/AnthropicAPIClient'
|
||||
import { OpenAIResponseAPIClient } from './clients/openai/OpenAIResponseAPIClient'
|
||||
import { CompletionsMiddlewareBuilder } from './middleware/builder'
|
||||
import { MIDDLEWARE_NAME as AbortHandlerMiddlewareName } from './middleware/common/AbortHandlerMiddleware'
|
||||
import { MIDDLEWARE_NAME as ErrorHandlerMiddlewareName } from './middleware/common/ErrorHandlerMiddleware'
|
||||
import { MIDDLEWARE_NAME as FinalChunkConsumerMiddlewareName } from './middleware/common/FinalChunkConsumerMiddleware'
|
||||
import { applyCompletionsMiddlewares } from './middleware/composer'
|
||||
import { MIDDLEWARE_NAME as McpToolChunkMiddlewareName } from './middleware/core/McpToolChunkMiddleware'
|
||||
import { MIDDLEWARE_NAME as RawStreamListenerMiddlewareName } from './middleware/core/RawStreamListenerMiddleware'
|
||||
import { MIDDLEWARE_NAME as ThinkChunkMiddlewareName } from './middleware/core/ThinkChunkMiddleware'
|
||||
import { MIDDLEWARE_NAME as WebSearchMiddlewareName } from './middleware/core/WebSearchMiddleware'
|
||||
import { MIDDLEWARE_NAME as ImageGenerationMiddlewareName } from './middleware/feat/ImageGenerationMiddleware'
|
||||
import { MIDDLEWARE_NAME as ThinkingTagExtractionMiddlewareName } from './middleware/feat/ThinkingTagExtractionMiddleware'
|
||||
import { MIDDLEWARE_NAME as ToolUseExtractionMiddlewareName } from './middleware/feat/ToolUseExtractionMiddleware'
|
||||
import { MiddlewareRegistry } from './middleware/register'
|
||||
import { CompletionsParams, CompletionsResult } from './middleware/schemas'
|
||||
|
||||
export default class AiProvider {
|
||||
private apiClient: BaseApiClient
|
||||
|
||||
constructor(provider: Provider) {
|
||||
// Use the new ApiClientFactory to get a BaseApiClient instance
|
||||
this.apiClient = ApiClientFactory.create(provider)
|
||||
}
|
||||
|
||||
public async completions(params: CompletionsParams, options?: RequestOptions): Promise<CompletionsResult> {
|
||||
// 1. 根据模型识别正确的客户端
|
||||
const model = params.assistant.model
|
||||
if (!model) {
|
||||
return Promise.reject(new Error('Model is required'))
|
||||
}
|
||||
|
||||
// 根据client类型选择合适的处理方式
|
||||
let client: BaseApiClient
|
||||
|
||||
if (this.apiClient instanceof AihubmixAPIClient) {
|
||||
// AihubmixAPIClient: 根据模型选择合适的子client
|
||||
client = this.apiClient.getClientForModel(model)
|
||||
if (client instanceof OpenAIResponseAPIClient) {
|
||||
client = client.getClient(model) as BaseApiClient
|
||||
}
|
||||
} else if (this.apiClient instanceof OpenAIResponseAPIClient) {
|
||||
// OpenAIResponseAPIClient: 根据模型特征选择API类型
|
||||
client = this.apiClient.getClient(model) as BaseApiClient
|
||||
} else {
|
||||
// 其他client直接使用
|
||||
client = this.apiClient
|
||||
}
|
||||
|
||||
// 2. 构建中间件链
|
||||
const builder = CompletionsMiddlewareBuilder.withDefaults()
|
||||
// images api
|
||||
if (isDedicatedImageGenerationModel(model)) {
|
||||
builder.clear()
|
||||
builder
|
||||
.add(MiddlewareRegistry[FinalChunkConsumerMiddlewareName])
|
||||
.add(MiddlewareRegistry[ErrorHandlerMiddlewareName])
|
||||
.add(MiddlewareRegistry[AbortHandlerMiddlewareName])
|
||||
.add(MiddlewareRegistry[ImageGenerationMiddlewareName])
|
||||
} else {
|
||||
// Existing logic for other models
|
||||
if (!params.enableReasoning) {
|
||||
builder.remove(ThinkingTagExtractionMiddlewareName)
|
||||
builder.remove(ThinkChunkMiddlewareName)
|
||||
}
|
||||
// 注意:用client判断会导致typescript类型收窄
|
||||
if (!(this.apiClient instanceof OpenAIAPIClient)) {
|
||||
builder.remove(ThinkingTagExtractionMiddlewareName)
|
||||
}
|
||||
if (!(this.apiClient instanceof AnthropicAPIClient) && !(this.apiClient instanceof OpenAIResponseAPIClient)) {
|
||||
builder.remove(RawStreamListenerMiddlewareName)
|
||||
}
|
||||
if (!params.enableWebSearch) {
|
||||
builder.remove(WebSearchMiddlewareName)
|
||||
}
|
||||
if (!params.mcpTools?.length) {
|
||||
builder.remove(ToolUseExtractionMiddlewareName)
|
||||
builder.remove(McpToolChunkMiddlewareName)
|
||||
}
|
||||
if (isEnabledToolUse(params.assistant) && isFunctionCallingModel(model)) {
|
||||
builder.remove(ToolUseExtractionMiddlewareName)
|
||||
}
|
||||
if (params.callType !== 'chat') {
|
||||
builder.remove(AbortHandlerMiddlewareName)
|
||||
}
|
||||
}
|
||||
|
||||
const middlewares = builder.build()
|
||||
|
||||
// 3. Create the wrapped SDK method with middlewares
|
||||
const wrappedCompletionMethod = applyCompletionsMiddlewares(client, client.createCompletions, middlewares)
|
||||
|
||||
// 4. Execute the wrapped method with the original params
|
||||
return wrappedCompletionMethod(params, options)
|
||||
}
|
||||
|
||||
public async models(): Promise<SdkModel[]> {
|
||||
return this.apiClient.listModels()
|
||||
}
|
||||
|
||||
public async getEmbeddingDimensions(model: Model): Promise<number> {
|
||||
try {
|
||||
// Use the SDK instance to test embedding capabilities
|
||||
const dimensions = await this.apiClient.getEmbeddingDimensions(model)
|
||||
return dimensions
|
||||
} catch (error) {
|
||||
console.error('Error getting embedding dimensions:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
public async generateImage(params: GenerateImageParams): Promise<string[]> {
|
||||
return this.apiClient.generateImage(params)
|
||||
}
|
||||
|
||||
public getBaseURL(): string {
|
||||
return this.apiClient.getBaseURL()
|
||||
}
|
||||
|
||||
public getApiKey(): string {
|
||||
return this.apiClient.getApiKey()
|
||||
}
|
||||
}
|
||||
182
src/renderer/src/aiCore/middleware/BUILDER_USAGE.md
Normal file
182
src/renderer/src/aiCore/middleware/BUILDER_USAGE.md
Normal file
@ -0,0 +1,182 @@
|
||||
# MiddlewareBuilder 使用指南
|
||||
|
||||
`MiddlewareBuilder` 是一个用于动态构建和管理中间件链的工具,提供灵活的中间件组织和配置能力。
|
||||
|
||||
## 主要特性
|
||||
|
||||
### 1. 统一的中间件命名
|
||||
|
||||
所有中间件都通过导出的 `MIDDLEWARE_NAME` 常量标识:
|
||||
|
||||
```typescript
|
||||
// 中间件文件示例
|
||||
export const MIDDLEWARE_NAME = 'SdkCallMiddleware'
|
||||
export const SdkCallMiddleware: CompletionsMiddleware = ...
|
||||
```
|
||||
|
||||
### 2. NamedMiddleware 接口
|
||||
|
||||
中间件使用统一的 `NamedMiddleware` 接口格式:
|
||||
|
||||
```typescript
|
||||
interface NamedMiddleware<TMiddleware = any> {
|
||||
name: string
|
||||
middleware: TMiddleware
|
||||
}
|
||||
```
|
||||
|
||||
### 3. 中间件注册表
|
||||
|
||||
通过 `MiddlewareRegistry` 集中管理所有可用中间件:
|
||||
|
||||
```typescript
|
||||
import { MiddlewareRegistry } from './register'
|
||||
|
||||
// 通过名称获取中间件
|
||||
const sdkCallMiddleware = MiddlewareRegistry['SdkCallMiddleware']
|
||||
```
|
||||
|
||||
## 基本用法
|
||||
|
||||
### 1. 使用默认中间件链
|
||||
|
||||
```typescript
|
||||
import { CompletionsMiddlewareBuilder } from './builder'
|
||||
|
||||
const builder = CompletionsMiddlewareBuilder.withDefaults()
|
||||
const middlewares = builder.build()
|
||||
```
|
||||
|
||||
### 2. 自定义中间件链
|
||||
|
||||
```typescript
|
||||
import { createCompletionsBuilder, MiddlewareRegistry } from './builder'
|
||||
|
||||
const builder = createCompletionsBuilder([
|
||||
MiddlewareRegistry['AbortHandlerMiddleware'],
|
||||
MiddlewareRegistry['TextChunkMiddleware']
|
||||
])
|
||||
|
||||
const middlewares = builder.build()
|
||||
```
|
||||
|
||||
### 3. 动态调整中间件链
|
||||
|
||||
```typescript
|
||||
const builder = CompletionsMiddlewareBuilder.withDefaults()
|
||||
|
||||
// 根据条件添加、移除、替换中间件
|
||||
if (needsLogging) {
|
||||
builder.prepend(MiddlewareRegistry['GenericLoggingMiddleware'])
|
||||
}
|
||||
|
||||
if (disableTools) {
|
||||
builder.remove('McpToolChunkMiddleware')
|
||||
}
|
||||
|
||||
if (customThinking) {
|
||||
builder.replace('ThinkingTagExtractionMiddleware', customThinkingMiddleware)
|
||||
}
|
||||
|
||||
const middlewares = builder.build()
|
||||
```
|
||||
|
||||
### 4. 链式操作
|
||||
|
||||
```typescript
|
||||
const middlewares = CompletionsMiddlewareBuilder.withDefaults()
|
||||
.add(MiddlewareRegistry['CustomMiddleware'])
|
||||
.insertBefore('SdkCallMiddleware', MiddlewareRegistry['SecurityCheckMiddleware'])
|
||||
.remove('WebSearchMiddleware')
|
||||
.build()
|
||||
```
|
||||
|
||||
## API 参考
|
||||
|
||||
### CompletionsMiddlewareBuilder
|
||||
|
||||
**静态方法:**
|
||||
|
||||
- `static withDefaults()`: 创建带有默认中间件链的构建器
|
||||
|
||||
**实例方法:**
|
||||
|
||||
- `add(middleware: NamedMiddleware)`: 在链末尾添加中间件
|
||||
- `prepend(middleware: NamedMiddleware)`: 在链开头添加中间件
|
||||
- `insertAfter(targetName: string, middleware: NamedMiddleware)`: 在指定中间件后插入
|
||||
- `insertBefore(targetName: string, middleware: NamedMiddleware)`: 在指定中间件前插入
|
||||
- `replace(targetName: string, middleware: NamedMiddleware)`: 替换指定中间件
|
||||
- `remove(targetName: string)`: 移除指定中间件
|
||||
- `has(name: string)`: 检查是否包含指定中间件
|
||||
- `build()`: 构建最终的中间件数组
|
||||
- `getChain()`: 获取当前链(包含名称信息)
|
||||
- `clear()`: 清空中间件链
|
||||
- `execute(context, params, middlewareExecutor)`: 直接执行构建好的中间件链
|
||||
|
||||
### 工厂函数
|
||||
|
||||
- `createCompletionsBuilder(baseChain?)`: 创建 Completions 中间件构建器
|
||||
- `createMethodBuilder(baseChain?)`: 创建通用方法中间件构建器
|
||||
- `addMiddlewareName(middleware, name)`: 为中间件添加名称属性的辅助函数
|
||||
|
||||
### 中间件注册表
|
||||
|
||||
- `MiddlewareRegistry`: 所有注册中间件的集中访问点
|
||||
- `getMiddleware(name)`: 根据名称获取中间件
|
||||
- `getRegisteredMiddlewareNames()`: 获取所有注册的中间件名称
|
||||
- `DefaultCompletionsNamedMiddlewares`: 默认的 Completions 中间件链(NamedMiddleware 格式)
|
||||
|
||||
## 类型安全
|
||||
|
||||
构建器提供完整的 TypeScript 类型支持:
|
||||
|
||||
- `CompletionsMiddlewareBuilder` 专门用于 `CompletionsMiddleware` 类型
|
||||
- `MethodMiddlewareBuilder` 用于通用的 `MethodMiddleware` 类型
|
||||
- 所有中间件操作都基于 `NamedMiddleware<TMiddleware>` 接口
|
||||
|
||||
## 默认中间件链
|
||||
|
||||
默认的 Completions 中间件执行顺序:
|
||||
|
||||
1. `FinalChunkConsumerMiddleware` - 最终消费者
|
||||
2. `TransformCoreToSdkParamsMiddleware` - 参数转换
|
||||
3. `AbortHandlerMiddleware` - 中止处理
|
||||
4. `McpToolChunkMiddleware` - 工具处理
|
||||
5. `WebSearchMiddleware` - Web搜索处理
|
||||
6. `TextChunkMiddleware` - 文本处理
|
||||
7. `ThinkingTagExtractionMiddleware` - 思考标签提取处理
|
||||
8. `ThinkChunkMiddleware` - 思考处理
|
||||
9. `ResponseTransformMiddleware` - 响应转换
|
||||
10. `StreamAdapterMiddleware` - 流适配器
|
||||
11. `SdkCallMiddleware` - SDK调用
|
||||
|
||||
## 在 AiProvider 中的使用
|
||||
|
||||
```typescript
|
||||
export default class AiProvider {
|
||||
public async completions(params: CompletionsParams): Promise<CompletionsResult> {
|
||||
// 1. 构建中间件链
|
||||
const builder = CompletionsMiddlewareBuilder.withDefaults()
|
||||
|
||||
// 2. 根据参数动态调整
|
||||
if (params.enableCustomFeature) {
|
||||
builder.insertAfter('StreamAdapterMiddleware', customFeatureMiddleware)
|
||||
}
|
||||
|
||||
// 3. 应用中间件
|
||||
const middlewares = builder.build()
|
||||
const wrappedMethod = applyCompletionsMiddlewares(this.apiClient, this.apiClient.createCompletions, middlewares)
|
||||
|
||||
return wrappedMethod(params)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## 注意事项
|
||||
|
||||
1. **类型兼容性**:`MethodMiddleware` 和 `CompletionsMiddleware` 不兼容,需要使用对应的构建器
|
||||
2. **中间件名称**:所有中间件必须导出 `MIDDLEWARE_NAME` 常量用于标识
|
||||
3. **注册表管理**:新增中间件需要在 `register.ts` 中注册
|
||||
4. **默认链**:默认链通过 `DefaultCompletionsNamedMiddlewares` 提供,支持延迟加载避免循环依赖
|
||||
|
||||
这种设计使得中间件链的构建既灵活又类型安全,同时保持了简洁的 API 接口。
|
||||
175
src/renderer/src/aiCore/middleware/MIDDLEWARE_SPECIFICATION.md
Normal file
175
src/renderer/src/aiCore/middleware/MIDDLEWARE_SPECIFICATION.md
Normal file
@ -0,0 +1,175 @@
|
||||
# Cherry Studio 中间件规范
|
||||
|
||||
本文档定义了 Cherry Studio `aiCore` 模块中中间件的设计、实现和使用规范。目标是建立一个灵活、可维护且易于扩展的中间件系统。
|
||||
|
||||
## 1. 核心概念
|
||||
|
||||
### 1.1. 中间件 (Middleware)
|
||||
|
||||
中间件是一个函数或对象,它在 AI 请求的处理流程中的特定阶段执行,可以访问和修改请求上下文 (`AiProviderMiddlewareContext`)、请求参数 (`Params`),并控制是否将请求传递给下一个中间件或终止流程。
|
||||
|
||||
每个中间件应该专注于一个单一的横切关注点,例如日志记录、错误处理、流适配、特性解析等。
|
||||
|
||||
### 1.2. `AiProviderMiddlewareContext` (上下文对象)
|
||||
|
||||
这是一个在整个中间件链执行过程中传递的对象,包含以下核心信息:
|
||||
|
||||
- `_apiClientInstance: ApiClient<any,any,any>`: 当前选定的、已实例化的 AI Provider 客户端。
|
||||
- `_coreRequest: CoreRequestType`: 标准化的内部核心请求对象。
|
||||
- `resolvePromise: (value: AggregatedResultType) => void`: 用于在整个操作成功完成时解析 `AiCoreService` 返回的 Promise。
|
||||
- `rejectPromise: (reason?: any) => void`: 用于在发生错误时拒绝 `AiCoreService` 返回的 Promise。
|
||||
- `onChunk?: (chunk: Chunk) => void`: 应用层提供的流式数据块回调。
|
||||
- `abortController?: AbortController`: 用于中止请求的控制器。
|
||||
- 其他中间件可能读写的、与当前请求相关的动态数据。
|
||||
|
||||
### 1.3. `MiddlewareName` (中间件名称)
|
||||
|
||||
为了方便动态操作(如插入、替换、移除)中间件,每个重要的、可能被其他逻辑引用的中间件都应该有一个唯一的、可识别的名称。推荐使用 TypeScript 的 `enum` 来定义:
|
||||
|
||||
```typescript
|
||||
// example
|
||||
export enum MiddlewareName {
|
||||
LOGGING_START = 'LoggingStartMiddleware',
|
||||
LOGGING_END = 'LoggingEndMiddleware',
|
||||
ERROR_HANDLING = 'ErrorHandlingMiddleware',
|
||||
ABORT_HANDLER = 'AbortHandlerMiddleware',
|
||||
// Core Flow
|
||||
TRANSFORM_CORE_TO_SDK_PARAMS = 'TransformCoreToSdkParamsMiddleware',
|
||||
REQUEST_EXECUTION = 'RequestExecutionMiddleware',
|
||||
STREAM_ADAPTER = 'StreamAdapterMiddleware',
|
||||
RAW_SDK_CHUNK_TO_APP_CHUNK = 'RawSdkChunkToAppChunkMiddleware',
|
||||
// Features
|
||||
THINKING_TAG_EXTRACTION = 'ThinkingTagExtractionMiddleware',
|
||||
TOOL_USE_TAG_EXTRACTION = 'ToolUseTagExtractionMiddleware',
|
||||
MCP_TOOL_HANDLER = 'McpToolHandlerMiddleware',
|
||||
// Finalization
|
||||
FINAL_CHUNK_CONSUMER = 'FinalChunkConsumerAndNotifierMiddleware'
|
||||
// Add more as needed
|
||||
}
|
||||
```
|
||||
|
||||
中间件实例需要某种方式暴露其 `MiddlewareName`,例如通过一个 `name` 属性。
|
||||
|
||||
### 1.4. 中间件执行结构
|
||||
|
||||
我们采用一种灵活的中间件执行结构。一个中间件通常是一个函数,它接收 `Context`、`Params`,以及一个 `next` 函数(用于调用链中的下一个中间件)。
|
||||
|
||||
```typescript
|
||||
// 简化形式的中间件函数签名
|
||||
type MiddlewareFunction = (
|
||||
context: AiProviderMiddlewareContext,
|
||||
params: any, // e.g., CompletionsParams
|
||||
next: () => Promise<void> // next 通常返回 Promise 以支持异步操作
|
||||
) => Promise<void> // 中间件自身也可能返回 Promise
|
||||
|
||||
// 或者更经典的 Koa/Express 风格 (三段式)
|
||||
// type MiddlewareFactory = (api?: MiddlewareApi) =>
|
||||
// (nextMiddleware: (ctx: AiProviderMiddlewareContext, params: any) => Promise<void>) =>
|
||||
// (context: AiProviderMiddlewareContext, params: any) => Promise<void>;
|
||||
// 当前设计更倾向于上述简化的 MiddlewareFunction,由 MiddlewareExecutor 负责 next 的编排。
|
||||
```
|
||||
|
||||
`MiddlewareExecutor` (或 `applyMiddlewares`) 会负责管理 `next` 的调用。
|
||||
|
||||
## 2. `MiddlewareBuilder` (通用中间件构建器)
|
||||
|
||||
为了动态构建和管理中间件链,我们引入一个通用的 `MiddlewareBuilder` 类。
|
||||
|
||||
### 2.1. 设计理念
|
||||
|
||||
`MiddlewareBuilder` 提供了一个流式 API,用于以声明式的方式构建中间件链。它允许从一个基础链开始,然后根据特定条件添加、插入、替换或移除中间件。
|
||||
|
||||
### 2.2. API 概览
|
||||
|
||||
```typescript
|
||||
class MiddlewareBuilder {
|
||||
constructor(baseChain?: Middleware[])
|
||||
|
||||
add(middleware: Middleware): this
|
||||
prepend(middleware: Middleware): this
|
||||
insertAfter(targetName: MiddlewareName, middlewareToInsert: Middleware): this
|
||||
insertBefore(targetName: MiddlewareName, middlewareToInsert: Middleware): this
|
||||
replace(targetName: MiddlewareName, newMiddleware: Middleware): this
|
||||
remove(targetName: MiddlewareName): this
|
||||
|
||||
build(): Middleware[] // 返回构建好的中间件数组
|
||||
|
||||
// 可选:直接执行链
|
||||
execute(
|
||||
context: AiProviderMiddlewareContext,
|
||||
params: any,
|
||||
middlewareExecutor: (chain: Middleware[], context: AiProviderMiddlewareContext, params: any) => void
|
||||
): void
|
||||
}
|
||||
```
|
||||
|
||||
### 2.3. 使用示例
|
||||
|
||||
```typescript
|
||||
// 1. 定义一些中间件实例 (假设它们有 .name 属性)
|
||||
const loggingStart = { name: MiddlewareName.LOGGING_START, fn: loggingStartFn }
|
||||
const requestExec = { name: MiddlewareName.REQUEST_EXECUTION, fn: requestExecFn }
|
||||
const streamAdapter = { name: MiddlewareName.STREAM_ADAPTER, fn: streamAdapterFn }
|
||||
const customFeature = { name: MiddlewareName.CUSTOM_FEATURE, fn: customFeatureFn } // 假设自定义
|
||||
|
||||
// 2. 定义一个基础链 (可选)
|
||||
const BASE_CHAIN: Middleware[] = [loggingStart, requestExec, streamAdapter]
|
||||
|
||||
// 3. 使用 MiddlewareBuilder
|
||||
const builder = new MiddlewareBuilder(BASE_CHAIN)
|
||||
|
||||
if (params.needsCustomFeature) {
|
||||
builder.insertAfter(MiddlewareName.STREAM_ADAPTER, customFeature)
|
||||
}
|
||||
|
||||
if (params.isHighSecurityContext) {
|
||||
builder.insertBefore(MiddlewareName.REQUEST_EXECUTION, высокоSecurityCheckMiddleware)
|
||||
}
|
||||
|
||||
if (params.overrideLogging) {
|
||||
builder.replace(MiddlewareName.LOGGING_START, newSpecialLoggingMiddleware)
|
||||
}
|
||||
|
||||
// 4. 获取最终链
|
||||
const finalChain = builder.build()
|
||||
|
||||
// 5. 执行 (通过外部执行器)
|
||||
// middlewareExecutor(finalChain, context, params);
|
||||
// 或者 builder.execute(context, params, middlewareExecutor);
|
||||
```
|
||||
|
||||
## 3. `MiddlewareExecutor` / `applyMiddlewares` (中间件执行器)
|
||||
|
||||
这是负责接收 `MiddlewareBuilder` 构建的中间件链并实际执行它们的组件。
|
||||
|
||||
### 3.1. 职责
|
||||
|
||||
- 接收 `Middleware[]`, `AiProviderMiddlewareContext`, `Params`。
|
||||
- 按顺序迭代中间件。
|
||||
- 为每个中间件提供正确的 `next` 函数,该函数在被调用时会执行链中的下一个中间件。
|
||||
- 处理中间件执行过程中的Promise(如果中间件是异步的)。
|
||||
- 基础的错误捕获(具体错误处理应由链内的 `ErrorHandlingMiddleware` 负责)。
|
||||
|
||||
## 4. 在 `AiCoreService` 中使用
|
||||
|
||||
`AiCoreService` 中的每个核心业务方法 (如 `executeCompletions`) 将负责:
|
||||
|
||||
1. 准备基础数据:实例化 `ApiClient`,转换 `Params` 为 `CoreRequest`。
|
||||
2. 实例化 `MiddlewareBuilder`,可能会传入一个特定于该业务方法的基础中间件链。
|
||||
3. 根据 `Params` 和 `CoreRequest` 中的条件,调用 `MiddlewareBuilder` 的方法来动态调整中间件链。
|
||||
4. 调用 `MiddlewareBuilder.build()` 获取最终的中间件链。
|
||||
5. 创建完整的 `AiProviderMiddlewareContext` (包含 `resolvePromise`, `rejectPromise` 等)。
|
||||
6. 调用 `MiddlewareExecutor` (或 `applyMiddlewares`) 来执行构建好的链。
|
||||
|
||||
## 5. 组合功能
|
||||
|
||||
对于组合功能(例如 "Completions then Translate"):
|
||||
|
||||
- 不推荐创建一个单一、庞大的 `MiddlewareBuilder` 来处理整个组合流程。
|
||||
- 推荐在 `AiCoreService` 中创建一个新的方法,该方法按顺序 `await` 调用底层的原子 `AiCoreService` 方法(例如,先 `await this.executeCompletions(...)`,然后用其结果 `await this.translateText(...)`)。
|
||||
- 每个被调用的原子方法内部会使用其自身的 `MiddlewareBuilder` 实例来构建和执行其特定阶段的中间件链。
|
||||
- 这种方式最大化了复用,并保持了各部分职责的清晰。
|
||||
|
||||
## 6. 中间件命名和发现
|
||||
|
||||
为中间件赋予唯一的 `MiddlewareName` 对于 `MiddlewareBuilder` 的 `insertAfter`, `insertBefore`, `replace`, `remove` 等操作至关重要。确保中间件实例能够以某种方式暴露其名称(例如,一个 `name` 属性)。
|
||||
241
src/renderer/src/aiCore/middleware/builder.ts
Normal file
241
src/renderer/src/aiCore/middleware/builder.ts
Normal file
@ -0,0 +1,241 @@
|
||||
import { DefaultCompletionsNamedMiddlewares } from './register'
|
||||
import { BaseContext, CompletionsMiddleware, MethodMiddleware } from './types'
|
||||
|
||||
/**
|
||||
* 带有名称标识的中间件接口
|
||||
*/
|
||||
export interface NamedMiddleware<TMiddleware = any> {
|
||||
name: string
|
||||
middleware: TMiddleware
|
||||
}
|
||||
|
||||
/**
|
||||
* 中间件执行器函数类型
|
||||
*/
|
||||
export type MiddlewareExecutor<TContext extends BaseContext = BaseContext> = (
|
||||
chain: any[],
|
||||
context: TContext,
|
||||
params: any
|
||||
) => Promise<any>
|
||||
|
||||
/**
|
||||
* 通用中间件构建器类
|
||||
* 提供流式 API 用于动态构建和管理中间件链
|
||||
*
|
||||
* 注意:所有中间件都通过 MiddlewareRegistry 管理,使用 NamedMiddleware 格式
|
||||
*/
|
||||
export class MiddlewareBuilder<TMiddleware = any> {
|
||||
private middlewares: NamedMiddleware<TMiddleware>[]
|
||||
|
||||
/**
|
||||
* 构造函数
|
||||
* @param baseChain - 可选的基础中间件链(NamedMiddleware 格式)
|
||||
*/
|
||||
constructor(baseChain?: NamedMiddleware<TMiddleware>[]) {
|
||||
this.middlewares = baseChain ? [...baseChain] : []
|
||||
}
|
||||
|
||||
/**
|
||||
* 在链的末尾添加中间件
|
||||
* @param middleware - 要添加的具名中间件
|
||||
* @returns this,支持链式调用
|
||||
*/
|
||||
add(middleware: NamedMiddleware<TMiddleware>): this {
|
||||
this.middlewares.push(middleware)
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 在链的开头添加中间件
|
||||
* @param middleware - 要添加的具名中间件
|
||||
* @returns this,支持链式调用
|
||||
*/
|
||||
prepend(middleware: NamedMiddleware<TMiddleware>): this {
|
||||
this.middlewares.unshift(middleware)
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 在指定中间件之后插入新中间件
|
||||
* @param targetName - 目标中间件名称
|
||||
* @param middlewareToInsert - 要插入的具名中间件
|
||||
* @returns this,支持链式调用
|
||||
*/
|
||||
insertAfter(targetName: string, middlewareToInsert: NamedMiddleware<TMiddleware>): this {
|
||||
const index = this.findMiddlewareIndex(targetName)
|
||||
if (index !== -1) {
|
||||
this.middlewares.splice(index + 1, 0, middlewareToInsert)
|
||||
} else {
|
||||
console.warn(`MiddlewareBuilder: 未找到名为 '${targetName}' 的中间件,无法插入`)
|
||||
}
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 在指定中间件之前插入新中间件
|
||||
* @param targetName - 目标中间件名称
|
||||
* @param middlewareToInsert - 要插入的具名中间件
|
||||
* @returns this,支持链式调用
|
||||
*/
|
||||
insertBefore(targetName: string, middlewareToInsert: NamedMiddleware<TMiddleware>): this {
|
||||
const index = this.findMiddlewareIndex(targetName)
|
||||
if (index !== -1) {
|
||||
this.middlewares.splice(index, 0, middlewareToInsert)
|
||||
} else {
|
||||
console.warn(`MiddlewareBuilder: 未找到名为 '${targetName}' 的中间件,无法插入`)
|
||||
}
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 替换指定的中间件
|
||||
* @param targetName - 要替换的中间件名称
|
||||
* @param newMiddleware - 新的具名中间件
|
||||
* @returns this,支持链式调用
|
||||
*/
|
||||
replace(targetName: string, newMiddleware: NamedMiddleware<TMiddleware>): this {
|
||||
const index = this.findMiddlewareIndex(targetName)
|
||||
if (index !== -1) {
|
||||
this.middlewares[index] = newMiddleware
|
||||
} else {
|
||||
console.warn(`MiddlewareBuilder: 未找到名为 '${targetName}' 的中间件,无法替换`)
|
||||
}
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 移除指定的中间件
|
||||
* @param targetName - 要移除的中间件名称
|
||||
* @returns this,支持链式调用
|
||||
*/
|
||||
remove(targetName: string): this {
|
||||
const index = this.findMiddlewareIndex(targetName)
|
||||
if (index !== -1) {
|
||||
this.middlewares.splice(index, 1)
|
||||
}
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建最终的中间件数组
|
||||
* @returns 构建好的中间件数组
|
||||
*/
|
||||
build(): TMiddleware[] {
|
||||
return this.middlewares.map((item) => item.middleware)
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取当前中间件链的副本(包含名称信息)
|
||||
* @returns 当前中间件链的副本
|
||||
*/
|
||||
getChain(): NamedMiddleware<TMiddleware>[] {
|
||||
return [...this.middlewares]
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查是否包含指定名称的中间件
|
||||
* @param name - 中间件名称
|
||||
* @returns 是否包含该中间件
|
||||
*/
|
||||
has(name: string): boolean {
|
||||
return this.findMiddlewareIndex(name) !== -1
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取中间件链的长度
|
||||
* @returns 中间件数量
|
||||
*/
|
||||
get length(): number {
|
||||
return this.middlewares.length
|
||||
}
|
||||
|
||||
/**
|
||||
* 清空中间件链
|
||||
* @returns this,支持链式调用
|
||||
*/
|
||||
clear(): this {
|
||||
this.middlewares = []
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 直接执行构建好的中间件链
|
||||
* @param context - 中间件上下文
|
||||
* @param params - 参数
|
||||
* @param middlewareExecutor - 中间件执行器
|
||||
* @returns 执行结果
|
||||
*/
|
||||
execute<TContext extends BaseContext>(
|
||||
context: TContext,
|
||||
params: any,
|
||||
middlewareExecutor: MiddlewareExecutor<TContext>
|
||||
): Promise<any> {
|
||||
const chain = this.build()
|
||||
return middlewareExecutor(chain, context, params)
|
||||
}
|
||||
|
||||
/**
|
||||
* 查找中间件在链中的索引
|
||||
* @param name - 中间件名称
|
||||
* @returns 索引,如果未找到返回 -1
|
||||
*/
|
||||
private findMiddlewareIndex(name: string): number {
|
||||
return this.middlewares.findIndex((item) => item.name === name)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Completions 中间件构建器
|
||||
*/
|
||||
export class CompletionsMiddlewareBuilder extends MiddlewareBuilder<CompletionsMiddleware> {
|
||||
constructor(baseChain?: NamedMiddleware<CompletionsMiddleware>[]) {
|
||||
super(baseChain)
|
||||
}
|
||||
|
||||
/**
|
||||
* 使用默认的 Completions 中间件链
|
||||
* @returns CompletionsMiddlewareBuilder 实例
|
||||
*/
|
||||
static withDefaults(): CompletionsMiddlewareBuilder {
|
||||
return new CompletionsMiddlewareBuilder(DefaultCompletionsNamedMiddlewares)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 通用方法中间件构建器
|
||||
*/
|
||||
export class MethodMiddlewareBuilder extends MiddlewareBuilder<MethodMiddleware> {
|
||||
constructor(baseChain?: NamedMiddleware<MethodMiddleware>[]) {
|
||||
super(baseChain)
|
||||
}
|
||||
}
|
||||
|
||||
// 便捷的工厂函数
|
||||
|
||||
/**
|
||||
* 创建 Completions 中间件构建器
|
||||
* @param baseChain - 可选的基础链
|
||||
* @returns Completions 中间件构建器实例
|
||||
*/
|
||||
export function createCompletionsBuilder(
|
||||
baseChain?: NamedMiddleware<CompletionsMiddleware>[]
|
||||
): CompletionsMiddlewareBuilder {
|
||||
return new CompletionsMiddlewareBuilder(baseChain)
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建通用方法中间件构建器
|
||||
* @param baseChain - 可选的基础链
|
||||
* @returns 通用方法中间件构建器实例
|
||||
*/
|
||||
export function createMethodBuilder(baseChain?: NamedMiddleware<MethodMiddleware>[]): MethodMiddlewareBuilder {
|
||||
return new MethodMiddlewareBuilder(baseChain)
|
||||
}
|
||||
|
||||
/**
|
||||
* 为中间件添加名称属性的辅助函数
|
||||
* 可以用于给现有的中间件添加名称属性
|
||||
*/
|
||||
export function addMiddlewareName<T extends object>(middleware: T, name: string): T & { MIDDLEWARE_NAME: string } {
|
||||
return Object.assign(middleware, { MIDDLEWARE_NAME: name })
|
||||
}
|
||||
@ -0,0 +1,106 @@
|
||||
import { Chunk, ChunkType, ErrorChunk } from '@renderer/types/chunk'
|
||||
import { addAbortController, removeAbortController } from '@renderer/utils/abortController'
|
||||
|
||||
import { CompletionsParams, CompletionsResult } from '../schemas'
|
||||
import type { CompletionsContext, CompletionsMiddleware } from '../types'
|
||||
|
||||
export const MIDDLEWARE_NAME = 'AbortHandlerMiddleware'
|
||||
|
||||
export const AbortHandlerMiddleware: CompletionsMiddleware =
|
||||
() =>
|
||||
(next) =>
|
||||
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
|
||||
const isRecursiveCall = ctx._internal?.toolProcessingState?.isRecursiveCall || false
|
||||
|
||||
// 在递归调用中,跳过 AbortController 的创建,直接使用已有的
|
||||
if (isRecursiveCall) {
|
||||
const result = await next(ctx, params)
|
||||
return result
|
||||
}
|
||||
|
||||
// 获取当前消息的ID用于abort管理
|
||||
// 优先使用处理过的消息,如果没有则使用原始消息
|
||||
let messageId: string | undefined
|
||||
|
||||
if (typeof params.messages === 'string') {
|
||||
messageId = `message-${Date.now()}-${Math.random().toString(36).substring(2, 9)}`
|
||||
} else {
|
||||
const processedMessages = params.messages
|
||||
const lastUserMessage = processedMessages.findLast((m) => m.role === 'user')
|
||||
messageId = lastUserMessage?.id
|
||||
}
|
||||
|
||||
if (!messageId) {
|
||||
console.warn(`[${MIDDLEWARE_NAME}] No messageId found, abort functionality will not be available.`)
|
||||
return next(ctx, params)
|
||||
}
|
||||
|
||||
const abortController = new AbortController()
|
||||
const abortFn = (): void => abortController.abort()
|
||||
|
||||
addAbortController(messageId, abortFn)
|
||||
|
||||
let abortSignal: AbortSignal | null = abortController.signal
|
||||
|
||||
const cleanup = (): void => {
|
||||
removeAbortController(messageId as string, abortFn)
|
||||
if (ctx._internal?.flowControl) {
|
||||
ctx._internal.flowControl.abortController = undefined
|
||||
ctx._internal.flowControl.abortSignal = undefined
|
||||
ctx._internal.flowControl.cleanup = undefined
|
||||
}
|
||||
abortSignal = null
|
||||
}
|
||||
|
||||
// 将controller添加到_internal中的flowControl状态
|
||||
if (!ctx._internal.flowControl) {
|
||||
ctx._internal.flowControl = {}
|
||||
}
|
||||
ctx._internal.flowControl.abortController = abortController
|
||||
ctx._internal.flowControl.abortSignal = abortSignal
|
||||
ctx._internal.flowControl.cleanup = cleanup
|
||||
|
||||
const result = await next(ctx, params)
|
||||
|
||||
const error = new DOMException('Request was aborted', 'AbortError')
|
||||
|
||||
const streamWithAbortHandler = (result.stream as ReadableStream<Chunk>).pipeThrough(
|
||||
new TransformStream<Chunk, Chunk | ErrorChunk>({
|
||||
transform(chunk, controller) {
|
||||
// 检查 abort 状态
|
||||
if (abortSignal?.aborted) {
|
||||
// 转换为 ErrorChunk
|
||||
const errorChunk: ErrorChunk = {
|
||||
type: ChunkType.ERROR,
|
||||
error
|
||||
}
|
||||
|
||||
controller.enqueue(errorChunk)
|
||||
cleanup()
|
||||
return
|
||||
}
|
||||
|
||||
// 正常传递 chunk
|
||||
controller.enqueue(chunk)
|
||||
},
|
||||
|
||||
flush(controller) {
|
||||
// 在流结束时再次检查 abort 状态
|
||||
if (abortSignal?.aborted) {
|
||||
const errorChunk: ErrorChunk = {
|
||||
type: ChunkType.ERROR,
|
||||
error
|
||||
}
|
||||
controller.enqueue(errorChunk)
|
||||
}
|
||||
// 在流完全处理完成后清理 AbortController
|
||||
cleanup()
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
return {
|
||||
...result,
|
||||
stream: streamWithAbortHandler
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,56 @@
|
||||
import { Chunk } from '@renderer/types/chunk'
|
||||
|
||||
import { CompletionsResult } from '../schemas'
|
||||
import { CompletionsContext } from '../types'
|
||||
import { createErrorChunk } from '../utils'
|
||||
|
||||
export const MIDDLEWARE_NAME = 'ErrorHandlerMiddleware'
|
||||
|
||||
/**
|
||||
* 创建一个错误处理中间件。
|
||||
*
|
||||
* 这是一个高阶函数,它接收配置并返回一个标准的中间件。
|
||||
* 它的主要职责是捕获下游中间件或API调用中发生的任何错误。
|
||||
*
|
||||
* @param config - 中间件的配置。
|
||||
* @returns 一个配置好的CompletionsMiddleware。
|
||||
*/
|
||||
export const ErrorHandlerMiddleware =
|
||||
() =>
|
||||
(next) =>
|
||||
async (ctx: CompletionsContext, params): Promise<CompletionsResult> => {
|
||||
const { shouldThrow } = params
|
||||
|
||||
try {
|
||||
// 尝试执行下一个中间件
|
||||
return await next(ctx, params)
|
||||
} catch (error: any) {
|
||||
console.log('ErrorHandlerMiddleware_error', error)
|
||||
// 1. 使用通用的工具函数将错误解析为标准格式
|
||||
const errorChunk = createErrorChunk(error)
|
||||
// 2. 调用从外部传入的 onError 回调
|
||||
if (params.onError) {
|
||||
params.onError(error)
|
||||
}
|
||||
|
||||
// 3. 根据配置决定是重新抛出错误,还是将其作为流的一部分向下传递
|
||||
if (shouldThrow) {
|
||||
throw error
|
||||
}
|
||||
|
||||
// 如果不抛出,则创建一个只包含该错误块的流并向下传递
|
||||
const errorStream = new ReadableStream<Chunk>({
|
||||
start(controller) {
|
||||
controller.enqueue(errorChunk)
|
||||
controller.close()
|
||||
}
|
||||
})
|
||||
|
||||
return {
|
||||
rawOutput: undefined,
|
||||
stream: errorStream, // 将包含错误的流传递下去
|
||||
controller: undefined,
|
||||
getText: () => '' // 错误情况下没有文本结果
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,183 @@
|
||||
import Logger from '@renderer/config/logger'
|
||||
import { Usage } from '@renderer/types'
|
||||
import type { Chunk } from '@renderer/types/chunk'
|
||||
import { ChunkType } from '@renderer/types/chunk'
|
||||
|
||||
import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
|
||||
import { CompletionsContext, CompletionsMiddleware } from '../types'
|
||||
|
||||
export const MIDDLEWARE_NAME = 'FinalChunkConsumerAndNotifierMiddleware'
|
||||
|
||||
/**
|
||||
* 最终Chunk消费和通知中间件
|
||||
*
|
||||
* 职责:
|
||||
* 1. 消费所有GenericChunk流中的chunks并转发给onChunk回调
|
||||
* 2. 累加usage/metrics数据(从原始SDK chunks或GenericChunk中提取)
|
||||
* 3. 在检测到LLM_RESPONSE_COMPLETE时发送包含累计数据的BLOCK_COMPLETE
|
||||
* 4. 处理MCP工具调用的多轮请求中的数据累加
|
||||
*/
|
||||
const FinalChunkConsumerMiddleware: CompletionsMiddleware =
|
||||
() =>
|
||||
(next) =>
|
||||
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
|
||||
const isRecursiveCall =
|
||||
params._internal?.toolProcessingState?.isRecursiveCall ||
|
||||
ctx._internal?.toolProcessingState?.isRecursiveCall ||
|
||||
false
|
||||
|
||||
// 初始化累计数据(只在顶层调用时初始化)
|
||||
if (!isRecursiveCall) {
|
||||
if (!ctx._internal.customState) {
|
||||
ctx._internal.customState = {}
|
||||
}
|
||||
ctx._internal.observer = {
|
||||
usage: {
|
||||
prompt_tokens: 0,
|
||||
completion_tokens: 0,
|
||||
total_tokens: 0
|
||||
},
|
||||
metrics: {
|
||||
completion_tokens: 0,
|
||||
time_completion_millsec: 0,
|
||||
time_first_token_millsec: 0,
|
||||
time_thinking_millsec: 0
|
||||
}
|
||||
}
|
||||
// 初始化文本累积器
|
||||
ctx._internal.customState.accumulatedText = ''
|
||||
ctx._internal.customState.startTimestamp = Date.now()
|
||||
}
|
||||
|
||||
// 调用下游中间件
|
||||
const result = await next(ctx, params)
|
||||
|
||||
// 响应后处理:处理GenericChunk流式响应
|
||||
if (result.stream) {
|
||||
const resultFromUpstream = result.stream
|
||||
|
||||
if (resultFromUpstream && resultFromUpstream instanceof ReadableStream) {
|
||||
const reader = resultFromUpstream.getReader()
|
||||
|
||||
try {
|
||||
while (true) {
|
||||
const { done, value: chunk } = await reader.read()
|
||||
if (done) {
|
||||
Logger.debug(`[${MIDDLEWARE_NAME}] Input stream finished.`)
|
||||
break
|
||||
}
|
||||
|
||||
if (chunk) {
|
||||
const genericChunk = chunk as GenericChunk
|
||||
// 提取并累加usage/metrics数据
|
||||
extractAndAccumulateUsageMetrics(ctx, genericChunk)
|
||||
|
||||
const shouldSkipChunk =
|
||||
isRecursiveCall &&
|
||||
(genericChunk.type === ChunkType.BLOCK_COMPLETE ||
|
||||
genericChunk.type === ChunkType.LLM_RESPONSE_COMPLETE)
|
||||
|
||||
if (!shouldSkipChunk) params.onChunk?.(genericChunk)
|
||||
} else {
|
||||
Logger.warn(`[${MIDDLEWARE_NAME}] Received undefined chunk before stream was done.`)
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
Logger.error(`[${MIDDLEWARE_NAME}] Error consuming stream:`, error)
|
||||
throw error
|
||||
} finally {
|
||||
if (params.onChunk && !isRecursiveCall) {
|
||||
params.onChunk({
|
||||
type: ChunkType.BLOCK_COMPLETE,
|
||||
response: {
|
||||
usage: ctx._internal.observer?.usage ? { ...ctx._internal.observer.usage } : undefined,
|
||||
metrics: ctx._internal.observer?.metrics ? { ...ctx._internal.observer.metrics } : undefined
|
||||
}
|
||||
} as Chunk)
|
||||
if (ctx._internal.toolProcessingState) {
|
||||
ctx._internal.toolProcessingState = {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 为流式输出添加getText方法
|
||||
const modifiedResult = {
|
||||
...result,
|
||||
stream: new ReadableStream<GenericChunk>({
|
||||
start(controller) {
|
||||
controller.close()
|
||||
}
|
||||
}),
|
||||
getText: () => {
|
||||
return ctx._internal.customState?.accumulatedText || ''
|
||||
}
|
||||
}
|
||||
|
||||
return modifiedResult
|
||||
} else {
|
||||
Logger.debug(`[${MIDDLEWARE_NAME}] No GenericChunk stream to process.`)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
/**
|
||||
* 从GenericChunk或原始SDK chunks中提取usage/metrics数据并累加
|
||||
*/
|
||||
function extractAndAccumulateUsageMetrics(ctx: CompletionsContext, chunk: GenericChunk): void {
|
||||
if (!ctx._internal.observer?.usage || !ctx._internal.observer?.metrics) {
|
||||
return
|
||||
}
|
||||
|
||||
try {
|
||||
if (ctx._internal.customState && !ctx._internal.customState?.firstTokenTimestamp) {
|
||||
ctx._internal.customState.firstTokenTimestamp = Date.now()
|
||||
Logger.debug(`[${MIDDLEWARE_NAME}] First token timestamp: ${ctx._internal.customState.firstTokenTimestamp}`)
|
||||
}
|
||||
if (chunk.type === ChunkType.LLM_RESPONSE_COMPLETE) {
|
||||
Logger.debug(`[${MIDDLEWARE_NAME}] LLM_RESPONSE_COMPLETE chunk received:`, ctx._internal)
|
||||
// 从LLM_RESPONSE_COMPLETE chunk中提取usage数据
|
||||
if (chunk.response?.usage) {
|
||||
accumulateUsage(ctx._internal.observer.usage, chunk.response.usage)
|
||||
}
|
||||
|
||||
if (ctx._internal.customState && ctx._internal.customState?.firstTokenTimestamp) {
|
||||
ctx._internal.observer.metrics.time_first_token_millsec =
|
||||
ctx._internal.customState.firstTokenTimestamp - ctx._internal.customState.startTimestamp
|
||||
ctx._internal.observer.metrics.time_completion_millsec +=
|
||||
Date.now() - ctx._internal.customState.firstTokenTimestamp
|
||||
}
|
||||
}
|
||||
|
||||
// 也可以从其他chunk类型中提取metrics数据
|
||||
if (chunk.type === ChunkType.THINKING_COMPLETE && chunk.thinking_millsec && ctx._internal.observer?.metrics) {
|
||||
ctx._internal.observer.metrics.time_thinking_millsec = Math.max(
|
||||
ctx._internal.observer.metrics.time_thinking_millsec || 0,
|
||||
chunk.thinking_millsec
|
||||
)
|
||||
}
|
||||
} catch (error) {
|
||||
console.error(`[${MIDDLEWARE_NAME}] Error extracting usage/metrics from chunk:`, error)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 累加usage数据
|
||||
*/
|
||||
function accumulateUsage(accumulated: Usage, newUsage: Usage): void {
|
||||
if (newUsage.prompt_tokens !== undefined) {
|
||||
accumulated.prompt_tokens += newUsage.prompt_tokens
|
||||
}
|
||||
if (newUsage.completion_tokens !== undefined) {
|
||||
accumulated.completion_tokens += newUsage.completion_tokens
|
||||
}
|
||||
if (newUsage.total_tokens !== undefined) {
|
||||
accumulated.total_tokens += newUsage.total_tokens
|
||||
}
|
||||
if (newUsage.thoughts_tokens !== undefined) {
|
||||
accumulated.thoughts_tokens = (accumulated.thoughts_tokens || 0) + newUsage.thoughts_tokens
|
||||
}
|
||||
}
|
||||
|
||||
export default FinalChunkConsumerMiddleware
|
||||
@ -0,0 +1,64 @@
|
||||
import { BaseContext, MethodMiddleware, MiddlewareAPI } from '../types'
|
||||
|
||||
export const MIDDLEWARE_NAME = 'GenericLoggingMiddlewares'
|
||||
|
||||
/**
|
||||
* Helper function to safely stringify arguments for logging, handling circular references and large objects.
|
||||
* 安全地字符串化日志参数的辅助函数,处理循环引用和大型对象。
|
||||
* @param args - The arguments array to stringify. 要字符串化的参数数组。
|
||||
* @returns A string representation of the arguments. 参数的字符串表示形式。
|
||||
*/
|
||||
const stringifyArgsForLogging = (args: any[]): string => {
|
||||
try {
|
||||
return args
|
||||
.map((arg) => {
|
||||
if (typeof arg === 'function') return '[Function]'
|
||||
if (typeof arg === 'object' && arg !== null && arg.constructor === Object && Object.keys(arg).length > 20) {
|
||||
return '[Object with >20 keys]'
|
||||
}
|
||||
// Truncate long strings to avoid flooding logs 截断长字符串以避免日志泛滥
|
||||
const stringifiedArg = JSON.stringify(arg, null, 2)
|
||||
return stringifiedArg && stringifiedArg.length > 200 ? stringifiedArg.substring(0, 200) + '...' : stringifiedArg
|
||||
})
|
||||
.join(', ')
|
||||
} catch (e) {
|
||||
return '[Error serializing arguments]' // Handle potential errors during stringification 处理字符串化期间的潜在错误
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Generic logging middleware for provider methods.
|
||||
* 为提供者方法创建一个通用的日志中间件。
|
||||
* This middleware logs the initiation, success/failure, and duration of a method call.
|
||||
* 此中间件记录方法调用的启动、成功/失败以及持续时间。
|
||||
*/
|
||||
|
||||
/**
|
||||
* Creates a generic logging middleware for provider methods.
|
||||
* 为提供者方法创建一个通用的日志中间件。
|
||||
* @returns A `MethodMiddleware` instance. 一个 `MethodMiddleware` 实例。
|
||||
*/
|
||||
export const createGenericLoggingMiddleware: () => MethodMiddleware = () => {
|
||||
const middlewareName = 'GenericLoggingMiddleware'
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
return (_: MiddlewareAPI<BaseContext, any[]>) => (next) => async (ctx, args) => {
|
||||
const methodName = ctx.methodName
|
||||
const logPrefix = `[${middlewareName} (${methodName})]`
|
||||
console.log(`${logPrefix} Initiating. Args:`, stringifyArgsForLogging(args))
|
||||
const startTime = Date.now()
|
||||
try {
|
||||
const result = await next(ctx, args)
|
||||
const duration = Date.now() - startTime
|
||||
// Log successful completion of the method call with duration. /
|
||||
// 记录方法调用成功完成及其持续时间。
|
||||
console.log(`${logPrefix} Successful. Duration: ${duration}ms`)
|
||||
return result
|
||||
} catch (error) {
|
||||
const duration = Date.now() - startTime
|
||||
// Log failure of the method call with duration and error information. /
|
||||
// 记录方法调用失败及其持续时间和错误信息。
|
||||
console.error(`${logPrefix} Failed. Duration: ${duration}ms`, error)
|
||||
throw error // Re-throw the error to be handled by subsequent layers or the caller / 重新抛出错误,由后续层或调用者处理
|
||||
}
|
||||
}
|
||||
}
|
||||
285
src/renderer/src/aiCore/middleware/composer.ts
Normal file
285
src/renderer/src/aiCore/middleware/composer.ts
Normal file
@ -0,0 +1,285 @@
|
||||
import {
|
||||
RequestOptions,
|
||||
SdkInstance,
|
||||
SdkMessageParam,
|
||||
SdkParams,
|
||||
SdkRawChunk,
|
||||
SdkRawOutput,
|
||||
SdkTool,
|
||||
SdkToolCall
|
||||
} from '@renderer/types/sdk'
|
||||
|
||||
import { BaseApiClient } from '../clients'
|
||||
import { CompletionsParams, CompletionsResult } from './schemas'
|
||||
import {
|
||||
BaseContext,
|
||||
CompletionsContext,
|
||||
CompletionsMiddleware,
|
||||
MethodMiddleware,
|
||||
MIDDLEWARE_CONTEXT_SYMBOL,
|
||||
MiddlewareAPI
|
||||
} from './types'
|
||||
|
||||
/**
|
||||
* Creates the initial context for a method call, populating method-specific fields. /
|
||||
* 为方法调用创建初始上下文,并填充特定于该方法的字段。
|
||||
* @param methodName - The name of the method being called. / 被调用的方法名。
|
||||
* @param originalCallArgs - The actual arguments array from the proxy/method call. / 代理/方法调用的实际参数数组。
|
||||
* @param providerId - The ID of the provider, if available. / 提供者的ID(如果可用)。
|
||||
* @param providerInstance - The instance of the provider. / 提供者实例。
|
||||
* @param specificContextFactory - An optional factory function to create a specific context type from the base context and original call arguments. / 一个可选的工厂函数,用于从基础上下文和原始调用参数创建特定的上下文类型。
|
||||
* @returns The created context object. / 创建的上下文对象。
|
||||
*/
|
||||
function createInitialCallContext<TContext extends BaseContext, TCallArgs extends unknown[]>(
|
||||
methodName: string,
|
||||
originalCallArgs: TCallArgs, // Renamed from originalArgs to avoid confusion with context.originalArgs
|
||||
// Factory to create specific context from base and the *original call arguments array*
|
||||
specificContextFactory?: (base: BaseContext, callArgs: TCallArgs) => TContext
|
||||
): TContext {
|
||||
const baseContext: BaseContext = {
|
||||
[MIDDLEWARE_CONTEXT_SYMBOL]: true,
|
||||
methodName,
|
||||
originalArgs: originalCallArgs // Store the full original arguments array in the context
|
||||
}
|
||||
|
||||
if (specificContextFactory) {
|
||||
return specificContextFactory(baseContext, originalCallArgs)
|
||||
}
|
||||
return baseContext as TContext // Fallback to base context if no specific factory
|
||||
}
|
||||
|
||||
/**
|
||||
* Composes an array of functions from right to left. /
|
||||
* 从右到左组合一个函数数组。
|
||||
* `compose(f, g, h)` is `(...args) => f(g(h(...args)))`. /
|
||||
* `compose(f, g, h)` 等同于 `(...args) => f(g(h(...args)))`。
|
||||
* Each function in funcs is expected to take the result of the next function
|
||||
* (or the initial value for the rightmost function) as its argument. /
|
||||
* `funcs` 中的每个函数都期望接收下一个函数的结果(或最右侧函数的初始值)作为其参数。
|
||||
* @param funcs - Array of functions to compose. / 要组合的函数数组。
|
||||
* @returns The composed function. / 组合后的函数。
|
||||
*/
|
||||
function compose(...funcs: Array<(...args: any[]) => any>): (...args: any[]) => any {
|
||||
if (funcs.length === 0) {
|
||||
// If no functions to compose, return a function that returns its first argument, or undefined if no args. /
|
||||
// 如果没有要组合的函数,则返回一个函数,该函数返回其第一个参数,如果没有参数则返回undefined。
|
||||
return (...args: any[]) => (args.length > 0 ? args[0] : undefined)
|
||||
}
|
||||
if (funcs.length === 1) {
|
||||
return funcs[0]
|
||||
}
|
||||
return funcs.reduce(
|
||||
(a, b) =>
|
||||
(...args: any[]) =>
|
||||
a(b(...args))
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Applies an array of Redux-style middlewares to a generic provider method. /
|
||||
* 将一组Redux风格的中间件应用于一个通用的提供者方法。
|
||||
* This version keeps arguments as an array throughout the middleware chain. /
|
||||
* 此版本在整个中间件链中将参数保持为数组形式。
|
||||
* @param originalProviderInstance - The original provider instance. / 原始提供者实例。
|
||||
* @param methodName - The name of the method to be enhanced. / 需要增强的方法名。
|
||||
* @param originalMethod - The original method to be wrapped. / 需要包装的原始方法。
|
||||
* @param middlewares - An array of `ProviderMethodMiddleware` to apply. / 要应用的 `ProviderMethodMiddleware` 数组。
|
||||
* @param specificContextFactory - An optional factory to create a specific context for this method. / 可选的工厂函数,用于为此方法创建特定的上下文。
|
||||
* @returns An enhanced method with the middlewares applied. / 应用了中间件的增强方法。
|
||||
*/
|
||||
export function applyMethodMiddlewares<
|
||||
TArgs extends unknown[] = unknown[], // Original method's arguments array type / 原始方法的参数数组类型
|
||||
TResult = unknown,
|
||||
TContext extends BaseContext = BaseContext
|
||||
>(
|
||||
methodName: string,
|
||||
originalMethod: (...args: TArgs) => Promise<TResult>,
|
||||
middlewares: MethodMiddleware[], // Expects generic middlewares / 期望通用中间件
|
||||
specificContextFactory?: (base: BaseContext, callArgs: TArgs) => TContext
|
||||
): (...args: TArgs) => Promise<TResult> {
|
||||
// Returns a function matching the original method signature. /
|
||||
// 返回一个与原始方法签名匹配的函数。
|
||||
return async function enhancedMethod(...methodCallArgs: TArgs): Promise<TResult> {
|
||||
const ctx = createInitialCallContext<TContext, TArgs>(
|
||||
methodName,
|
||||
methodCallArgs, // Pass the actual call arguments array / 传递实际的调用参数数组
|
||||
specificContextFactory
|
||||
)
|
||||
|
||||
const api: MiddlewareAPI<TContext, TArgs> = {
|
||||
getContext: () => ctx,
|
||||
getOriginalArgs: () => methodCallArgs // API provides the original arguments array / API提供原始参数数组
|
||||
}
|
||||
|
||||
// `finalDispatch` is the function that will ultimately call the original provider method. /
|
||||
// `finalDispatch` 是最终将调用原始提供者方法的函数。
|
||||
// It receives the current context and arguments, which may have been transformed by middlewares. /
|
||||
// 它接收当前的上下文和参数,这些参数可能已被中间件转换。
|
||||
const finalDispatch = async (
|
||||
_: TContext,
|
||||
currentArgs: TArgs // Generic final dispatch expects args array / 通用finalDispatch期望参数数组
|
||||
): Promise<TResult> => {
|
||||
return originalMethod.apply(currentArgs)
|
||||
}
|
||||
|
||||
const chain = middlewares.map((middleware) => middleware(api)) // Cast API if TContext/TArgs mismatch general ProviderMethodMiddleware / 如果TContext/TArgs与通用的ProviderMethodMiddleware不匹配,则转换API
|
||||
const composedMiddlewareLogic = compose(...chain)
|
||||
const enhancedDispatch = composedMiddlewareLogic(finalDispatch)
|
||||
|
||||
return enhancedDispatch(ctx, methodCallArgs) // Pass context and original args array / 传递上下文和原始参数数组
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Applies an array of `CompletionsMiddleware` to the `completions` method. /
|
||||
* 将一组 `CompletionsMiddleware` 应用于 `completions` 方法。
|
||||
* This version adapts for `CompletionsMiddleware` expecting a single `params` object. /
|
||||
* 此版本适配了期望单个 `params` 对象的 `CompletionsMiddleware`。
|
||||
* @param originalProviderInstance - The original provider instance. / 原始提供者实例。
|
||||
* @param originalCompletionsMethod - The original SDK `createCompletions` method. / 原始的 SDK `createCompletions` 方法。
|
||||
* @param middlewares - An array of `CompletionsMiddleware` to apply. / 要应用的 `CompletionsMiddleware` 数组。
|
||||
* @returns An enhanced `completions` method with the middlewares applied. / 应用了中间件的增强版 `completions` 方法。
|
||||
*/
|
||||
export function applyCompletionsMiddlewares<
|
||||
TSdkInstance extends SdkInstance = SdkInstance,
|
||||
TSdkParams extends SdkParams = SdkParams,
|
||||
TRawOutput extends SdkRawOutput = SdkRawOutput,
|
||||
TRawChunk extends SdkRawChunk = SdkRawChunk,
|
||||
TMessageParam extends SdkMessageParam = SdkMessageParam,
|
||||
TToolCall extends SdkToolCall = SdkToolCall,
|
||||
TSdkSpecificTool extends SdkTool = SdkTool
|
||||
>(
|
||||
originalApiClientInstance: BaseApiClient<
|
||||
TSdkInstance,
|
||||
TSdkParams,
|
||||
TRawOutput,
|
||||
TRawChunk,
|
||||
TMessageParam,
|
||||
TToolCall,
|
||||
TSdkSpecificTool
|
||||
>,
|
||||
originalCompletionsMethod: (payload: TSdkParams, options?: RequestOptions) => Promise<TRawOutput>,
|
||||
middlewares: CompletionsMiddleware<
|
||||
TSdkParams,
|
||||
TMessageParam,
|
||||
TToolCall,
|
||||
TSdkInstance,
|
||||
TRawOutput,
|
||||
TRawChunk,
|
||||
TSdkSpecificTool
|
||||
>[]
|
||||
): (params: CompletionsParams, options?: RequestOptions) => Promise<CompletionsResult> {
|
||||
// Returns a function matching the original method signature. /
|
||||
// 返回一个与原始方法签名匹配的函数。
|
||||
|
||||
const methodName = 'completions'
|
||||
|
||||
// Factory to create AiProviderMiddlewareCompletionsContext. /
|
||||
// 用于创建 AiProviderMiddlewareCompletionsContext 的工厂函数。
|
||||
const completionsContextFactory = (
|
||||
base: BaseContext,
|
||||
callArgs: [CompletionsParams]
|
||||
): CompletionsContext<
|
||||
TSdkParams,
|
||||
TMessageParam,
|
||||
TToolCall,
|
||||
TSdkInstance,
|
||||
TRawOutput,
|
||||
TRawChunk,
|
||||
TSdkSpecificTool
|
||||
> => {
|
||||
return {
|
||||
...base,
|
||||
methodName,
|
||||
apiClientInstance: originalApiClientInstance,
|
||||
originalArgs: callArgs,
|
||||
_internal: {
|
||||
toolProcessingState: {
|
||||
recursionDepth: 0,
|
||||
isRecursiveCall: false
|
||||
},
|
||||
observer: {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return async function enhancedCompletionsMethod(
|
||||
params: CompletionsParams,
|
||||
options?: RequestOptions
|
||||
): Promise<CompletionsResult> {
|
||||
// `originalCallArgs` for context creation is `[params]`. /
|
||||
// 用于上下文创建的 `originalCallArgs` 是 `[params]`。
|
||||
const originalCallArgs: [CompletionsParams] = [params]
|
||||
const baseContext: BaseContext = {
|
||||
[MIDDLEWARE_CONTEXT_SYMBOL]: true,
|
||||
methodName,
|
||||
originalArgs: originalCallArgs
|
||||
}
|
||||
const ctx = completionsContextFactory(baseContext, originalCallArgs)
|
||||
|
||||
const api: MiddlewareAPI<
|
||||
CompletionsContext<TSdkParams, TMessageParam, TToolCall, TSdkInstance, TRawOutput, TRawChunk, TSdkSpecificTool>,
|
||||
[CompletionsParams]
|
||||
> = {
|
||||
getContext: () => ctx,
|
||||
getOriginalArgs: () => originalCallArgs // API provides [CompletionsParams] / API提供 `[CompletionsParams]`
|
||||
}
|
||||
|
||||
// `finalDispatch` for CompletionsMiddleware: expects (context, params) not (context, args_array). /
|
||||
// `CompletionsMiddleware` 的 `finalDispatch`:期望 (context, params) 而不是 (context, args_array)。
|
||||
const finalDispatch = async (
|
||||
context: CompletionsContext<
|
||||
TSdkParams,
|
||||
TMessageParam,
|
||||
TToolCall,
|
||||
TSdkInstance,
|
||||
TRawOutput,
|
||||
TRawChunk,
|
||||
TSdkSpecificTool
|
||||
> // Context passed through / 上下文透传
|
||||
// _currentParams: CompletionsParams // Directly takes params / 直接接收参数 (unused but required for middleware signature)
|
||||
): Promise<CompletionsResult> => {
|
||||
// At this point, middleware should have transformed CompletionsParams to SDK params
|
||||
// and stored them in context. If no transformation happened, we need to handle it.
|
||||
// 此时,中间件应该已经将 CompletionsParams 转换为 SDK 参数并存储在上下文中。
|
||||
// 如果没有进行转换,我们需要处理它。
|
||||
|
||||
const sdkPayload = context._internal?.sdkPayload
|
||||
if (!sdkPayload) {
|
||||
throw new Error('SDK payload not found in context. Middleware chain should have transformed parameters.')
|
||||
}
|
||||
|
||||
const abortSignal = context._internal.flowControl?.abortSignal
|
||||
const timeout = context._internal.customState?.sdkMetadata?.timeout
|
||||
|
||||
// Call the original SDK method with transformed parameters
|
||||
// 使用转换后的参数调用原始 SDK 方法
|
||||
const rawOutput = await originalCompletionsMethod.call(originalApiClientInstance, sdkPayload, {
|
||||
...options,
|
||||
signal: abortSignal,
|
||||
timeout
|
||||
})
|
||||
|
||||
// Return result wrapped in CompletionsResult format
|
||||
// 以 CompletionsResult 格式返回包装的结果
|
||||
return {
|
||||
rawOutput
|
||||
} as CompletionsResult
|
||||
}
|
||||
|
||||
const chain = middlewares.map((middleware) => middleware(api))
|
||||
const composedMiddlewareLogic = compose(...chain)
|
||||
|
||||
// `enhancedDispatch` has the signature `(context, params) => Promise<CompletionsResult>`. /
|
||||
// `enhancedDispatch` 的签名为 `(context, params) => Promise<CompletionsResult>`。
|
||||
const enhancedDispatch = composedMiddlewareLogic(finalDispatch)
|
||||
|
||||
// 将 enhancedDispatch 保存到 context 中,供中间件进行递归调用
|
||||
// 这样可以避免重复执行整个中间件链
|
||||
ctx._internal.enhancedDispatch = enhancedDispatch
|
||||
|
||||
// Execute with context and the single params object. /
|
||||
// 使用上下文和单个参数对象执行。
|
||||
return enhancedDispatch(ctx, params)
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,310 @@
|
||||
import Logger from '@renderer/config/logger'
|
||||
import { MCPTool, MCPToolResponse, Model, ToolCallResponse } from '@renderer/types'
|
||||
import { ChunkType, MCPToolCreatedChunk } from '@renderer/types/chunk'
|
||||
import { SdkMessageParam, SdkRawOutput, SdkToolCall } from '@renderer/types/sdk'
|
||||
import { parseAndCallTools } from '@renderer/utils/mcp-tools'
|
||||
|
||||
import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
|
||||
import { CompletionsContext, CompletionsMiddleware } from '../types'
|
||||
|
||||
export const MIDDLEWARE_NAME = 'McpToolChunkMiddleware'
|
||||
const MAX_TOOL_RECURSION_DEPTH = 20 // 防止无限递归
|
||||
|
||||
/**
|
||||
* MCP工具处理中间件
|
||||
*
|
||||
* 职责:
|
||||
* 1. 检测并拦截MCP工具进展chunk(Function Call方式和Tool Use方式)
|
||||
* 2. 执行工具调用
|
||||
* 3. 递归处理工具结果
|
||||
* 4. 管理工具调用状态和递归深度
|
||||
*/
|
||||
export const McpToolChunkMiddleware: CompletionsMiddleware =
|
||||
() =>
|
||||
(next) =>
|
||||
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
|
||||
const mcpTools = params.mcpTools || []
|
||||
|
||||
// 如果没有工具,直接调用下一个中间件
|
||||
if (!mcpTools || mcpTools.length === 0) {
|
||||
return next(ctx, params)
|
||||
}
|
||||
|
||||
const executeWithToolHandling = async (currentParams: CompletionsParams, depth = 0): Promise<CompletionsResult> => {
|
||||
if (depth >= MAX_TOOL_RECURSION_DEPTH) {
|
||||
Logger.error(`🔧 [${MIDDLEWARE_NAME}] Maximum recursion depth ${MAX_TOOL_RECURSION_DEPTH} exceeded`)
|
||||
throw new Error(`Maximum tool recursion depth ${MAX_TOOL_RECURSION_DEPTH} exceeded`)
|
||||
}
|
||||
|
||||
let result: CompletionsResult
|
||||
|
||||
if (depth === 0) {
|
||||
result = await next(ctx, currentParams)
|
||||
} else {
|
||||
const enhancedCompletions = ctx._internal.enhancedDispatch
|
||||
if (!enhancedCompletions) {
|
||||
Logger.error(`🔧 [${MIDDLEWARE_NAME}] Enhanced completions method not found, cannot perform recursive call`)
|
||||
throw new Error('Enhanced completions method not found')
|
||||
}
|
||||
|
||||
ctx._internal.toolProcessingState!.isRecursiveCall = true
|
||||
ctx._internal.toolProcessingState!.recursionDepth = depth
|
||||
|
||||
result = await enhancedCompletions(ctx, currentParams)
|
||||
}
|
||||
|
||||
if (!result.stream) {
|
||||
Logger.error(`🔧 [${MIDDLEWARE_NAME}] No stream returned from enhanced completions`)
|
||||
throw new Error('No stream returned from enhanced completions')
|
||||
}
|
||||
|
||||
const resultFromUpstream = result.stream as ReadableStream<GenericChunk>
|
||||
const toolHandlingStream = resultFromUpstream.pipeThrough(
|
||||
createToolHandlingTransform(ctx, currentParams, mcpTools, depth, executeWithToolHandling)
|
||||
)
|
||||
|
||||
return {
|
||||
...result,
|
||||
stream: toolHandlingStream
|
||||
}
|
||||
}
|
||||
|
||||
return executeWithToolHandling(params, 0)
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建工具处理的 TransformStream
|
||||
*/
|
||||
function createToolHandlingTransform(
|
||||
ctx: CompletionsContext,
|
||||
currentParams: CompletionsParams,
|
||||
mcpTools: MCPTool[],
|
||||
depth: number,
|
||||
executeWithToolHandling: (params: CompletionsParams, depth: number) => Promise<CompletionsResult>
|
||||
): TransformStream<GenericChunk, GenericChunk> {
|
||||
const toolCalls: SdkToolCall[] = []
|
||||
const toolUseResponses: MCPToolResponse[] = []
|
||||
const allToolResponses: MCPToolResponse[] = [] // 统一的工具响应状态管理数组
|
||||
let hasToolCalls = false
|
||||
let hasToolUseResponses = false
|
||||
let streamEnded = false
|
||||
|
||||
return new TransformStream({
|
||||
async transform(chunk: GenericChunk, controller) {
|
||||
try {
|
||||
// 处理MCP工具进展chunk
|
||||
if (chunk.type === ChunkType.MCP_TOOL_CREATED) {
|
||||
const createdChunk = chunk as MCPToolCreatedChunk
|
||||
|
||||
// 1. 处理Function Call方式的工具调用
|
||||
if (createdChunk.tool_calls && createdChunk.tool_calls.length > 0) {
|
||||
toolCalls.push(...createdChunk.tool_calls)
|
||||
hasToolCalls = true
|
||||
}
|
||||
|
||||
// 2. 处理Tool Use方式的工具调用
|
||||
if (createdChunk.tool_use_responses && createdChunk.tool_use_responses.length > 0) {
|
||||
toolUseResponses.push(...createdChunk.tool_use_responses)
|
||||
hasToolUseResponses = true
|
||||
}
|
||||
|
||||
// 不转发MCP工具进展chunks,避免重复处理
|
||||
return
|
||||
}
|
||||
|
||||
// 转发其他所有chunk
|
||||
controller.enqueue(chunk)
|
||||
} catch (error) {
|
||||
console.error(`🔧 [${MIDDLEWARE_NAME}] Error processing chunk:`, error)
|
||||
controller.error(error)
|
||||
}
|
||||
},
|
||||
|
||||
async flush(controller) {
|
||||
const shouldExecuteToolCalls = hasToolCalls && toolCalls.length > 0
|
||||
const shouldExecuteToolUseResponses = hasToolUseResponses && toolUseResponses.length > 0
|
||||
|
||||
if (!streamEnded && (shouldExecuteToolCalls || shouldExecuteToolUseResponses)) {
|
||||
streamEnded = true
|
||||
|
||||
try {
|
||||
let toolResult: SdkMessageParam[] = []
|
||||
|
||||
if (shouldExecuteToolCalls) {
|
||||
toolResult = await executeToolCalls(
|
||||
ctx,
|
||||
toolCalls,
|
||||
mcpTools,
|
||||
allToolResponses,
|
||||
currentParams.onChunk,
|
||||
currentParams.assistant.model!
|
||||
)
|
||||
} else if (shouldExecuteToolUseResponses) {
|
||||
toolResult = await executeToolUseResponses(
|
||||
ctx,
|
||||
toolUseResponses,
|
||||
mcpTools,
|
||||
allToolResponses,
|
||||
currentParams.onChunk,
|
||||
currentParams.assistant.model!
|
||||
)
|
||||
}
|
||||
|
||||
if (toolResult.length > 0) {
|
||||
const output = ctx._internal.toolProcessingState?.output
|
||||
|
||||
const newParams = buildParamsWithToolResults(ctx, currentParams, output, toolResult, toolCalls)
|
||||
await executeWithToolHandling(newParams, depth + 1)
|
||||
}
|
||||
} catch (error) {
|
||||
console.error(`🔧 [${MIDDLEWARE_NAME}] Error in tool processing:`, error)
|
||||
controller.error(error)
|
||||
} finally {
|
||||
hasToolCalls = false
|
||||
hasToolUseResponses = false
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* 执行工具调用(Function Call 方式)
|
||||
*/
|
||||
async function executeToolCalls(
|
||||
ctx: CompletionsContext,
|
||||
toolCalls: SdkToolCall[],
|
||||
mcpTools: MCPTool[],
|
||||
allToolResponses: MCPToolResponse[],
|
||||
onChunk: CompletionsParams['onChunk'],
|
||||
model: Model
|
||||
): Promise<SdkMessageParam[]> {
|
||||
// 转换为MCPToolResponse格式
|
||||
const mcpToolResponses: ToolCallResponse[] = toolCalls
|
||||
.map((toolCall) => {
|
||||
const mcpTool = ctx.apiClientInstance.convertSdkToolCallToMcp(toolCall, mcpTools)
|
||||
if (!mcpTool) {
|
||||
return undefined
|
||||
}
|
||||
return ctx.apiClientInstance.convertSdkToolCallToMcpToolResponse(toolCall, mcpTool)
|
||||
})
|
||||
.filter((t): t is ToolCallResponse => typeof t !== 'undefined')
|
||||
|
||||
if (mcpToolResponses.length === 0) {
|
||||
console.warn(`🔧 [${MIDDLEWARE_NAME}] No valid MCP tool responses to execute`)
|
||||
return []
|
||||
}
|
||||
|
||||
// 使用现有的parseAndCallTools函数执行工具
|
||||
const toolResults = await parseAndCallTools(
|
||||
mcpToolResponses,
|
||||
allToolResponses,
|
||||
onChunk,
|
||||
(mcpToolResponse, resp, model) => {
|
||||
return ctx.apiClientInstance.convertMcpToolResponseToSdkMessageParam(mcpToolResponse, resp, model)
|
||||
},
|
||||
model,
|
||||
mcpTools
|
||||
)
|
||||
|
||||
return toolResults
|
||||
}
|
||||
|
||||
/**
|
||||
* 执行工具使用响应(Tool Use Response 方式)
|
||||
* 处理已经解析好的 ToolUseResponse[],不需要重新解析字符串
|
||||
*/
|
||||
async function executeToolUseResponses(
|
||||
ctx: CompletionsContext,
|
||||
toolUseResponses: MCPToolResponse[],
|
||||
mcpTools: MCPTool[],
|
||||
allToolResponses: MCPToolResponse[],
|
||||
onChunk: CompletionsParams['onChunk'],
|
||||
model: Model
|
||||
): Promise<SdkMessageParam[]> {
|
||||
// 直接使用parseAndCallTools函数处理已经解析好的ToolUseResponse
|
||||
const toolResults = await parseAndCallTools(
|
||||
toolUseResponses,
|
||||
allToolResponses,
|
||||
onChunk,
|
||||
(mcpToolResponse, resp, model) => {
|
||||
return ctx.apiClientInstance.convertMcpToolResponseToSdkMessageParam(mcpToolResponse, resp, model)
|
||||
},
|
||||
model,
|
||||
mcpTools
|
||||
)
|
||||
|
||||
return toolResults
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建包含工具结果的新参数
|
||||
*/
|
||||
function buildParamsWithToolResults(
|
||||
ctx: CompletionsContext,
|
||||
currentParams: CompletionsParams,
|
||||
output: SdkRawOutput | string | undefined,
|
||||
toolResults: SdkMessageParam[],
|
||||
toolCalls: SdkToolCall[]
|
||||
): CompletionsParams {
|
||||
// 获取当前已经转换好的reqMessages,如果没有则使用原始messages
|
||||
const currentReqMessages = getCurrentReqMessages(ctx)
|
||||
|
||||
const apiClient = ctx.apiClientInstance
|
||||
|
||||
// 从回复中构建助手消息
|
||||
const newReqMessages = apiClient.buildSdkMessages(currentReqMessages, output, toolResults, toolCalls)
|
||||
|
||||
if (output && ctx._internal.toolProcessingState) {
|
||||
ctx._internal.toolProcessingState.output = undefined
|
||||
}
|
||||
|
||||
// 估算新增消息的 token 消耗并累加到 usage 中
|
||||
if (ctx._internal.observer?.usage && newReqMessages.length > currentReqMessages.length) {
|
||||
try {
|
||||
const newMessages = newReqMessages.slice(currentReqMessages.length)
|
||||
const additionalTokens = newMessages.reduce((acc, message) => {
|
||||
return acc + ctx.apiClientInstance.estimateMessageTokens(message)
|
||||
}, 0)
|
||||
|
||||
if (additionalTokens > 0) {
|
||||
ctx._internal.observer.usage.prompt_tokens += additionalTokens
|
||||
ctx._internal.observer.usage.total_tokens += additionalTokens
|
||||
}
|
||||
} catch (error) {
|
||||
Logger.error(`🔧 [${MIDDLEWARE_NAME}] Error estimating token usage for new messages:`, error)
|
||||
}
|
||||
}
|
||||
|
||||
// 更新递归状态
|
||||
if (!ctx._internal.toolProcessingState) {
|
||||
ctx._internal.toolProcessingState = {}
|
||||
}
|
||||
ctx._internal.toolProcessingState.isRecursiveCall = true
|
||||
ctx._internal.toolProcessingState.recursionDepth = (ctx._internal.toolProcessingState?.recursionDepth || 0) + 1
|
||||
|
||||
return {
|
||||
...currentParams,
|
||||
_internal: {
|
||||
...ctx._internal,
|
||||
sdkPayload: ctx._internal.sdkPayload,
|
||||
newReqMessages: newReqMessages
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 类型安全地获取当前请求消息
|
||||
* 使用API客户端提供的抽象方法,保持中间件的provider无关性
|
||||
*/
|
||||
function getCurrentReqMessages(ctx: CompletionsContext): SdkMessageParam[] {
|
||||
const sdkPayload = ctx._internal.sdkPayload
|
||||
if (!sdkPayload) {
|
||||
return []
|
||||
}
|
||||
|
||||
// 使用API客户端的抽象方法来提取消息,保持provider无关性
|
||||
return ctx.apiClientInstance.extractMessagesFromSdkPayload(sdkPayload)
|
||||
}
|
||||
|
||||
export default McpToolChunkMiddleware
|
||||
@ -0,0 +1,46 @@
|
||||
import { AnthropicAPIClient } from '@renderer/aiCore/clients/anthropic/AnthropicAPIClient'
|
||||
import { AnthropicSdkRawChunk, AnthropicSdkRawOutput } from '@renderer/types/sdk'
|
||||
|
||||
import { AnthropicStreamListener } from '../../clients/types'
|
||||
import { CompletionsParams, CompletionsResult } from '../schemas'
|
||||
import { CompletionsContext, CompletionsMiddleware } from '../types'
|
||||
|
||||
export const MIDDLEWARE_NAME = 'RawStreamListenerMiddleware'
|
||||
|
||||
export const RawStreamListenerMiddleware: CompletionsMiddleware =
|
||||
() =>
|
||||
(next) =>
|
||||
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
|
||||
const result = await next(ctx, params)
|
||||
|
||||
// 在这里可以监听到从SDK返回的最原始流
|
||||
if (result.rawOutput) {
|
||||
const providerType = ctx.apiClientInstance.provider.type
|
||||
// TODO: 后面下放到AnthropicAPIClient
|
||||
if (providerType === 'anthropic') {
|
||||
const anthropicListener: AnthropicStreamListener<AnthropicSdkRawChunk> = {
|
||||
onMessage: (message) => {
|
||||
if (ctx._internal?.toolProcessingState) {
|
||||
ctx._internal.toolProcessingState.output = message
|
||||
}
|
||||
}
|
||||
// onContentBlock: (contentBlock) => {
|
||||
// console.log(`[${MIDDLEWARE_NAME}] 📝 Anthropic content block:`, contentBlock.type)
|
||||
// }
|
||||
}
|
||||
|
||||
const specificApiClient = ctx.apiClientInstance as AnthropicAPIClient
|
||||
|
||||
const monitoredOutput = specificApiClient.attachRawStreamListener(
|
||||
result.rawOutput as AnthropicSdkRawOutput,
|
||||
anthropicListener
|
||||
)
|
||||
return {
|
||||
...result,
|
||||
rawOutput: monitoredOutput
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
@ -0,0 +1,85 @@
|
||||
import Logger from '@renderer/config/logger'
|
||||
import { SdkRawChunk } from '@renderer/types/sdk'
|
||||
|
||||
import { ResponseChunkTransformerContext } from '../../clients/types'
|
||||
import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
|
||||
import { CompletionsContext, CompletionsMiddleware } from '../types'
|
||||
|
||||
export const MIDDLEWARE_NAME = 'ResponseTransformMiddleware'
|
||||
|
||||
/**
|
||||
* 响应转换中间件
|
||||
*
|
||||
* 职责:
|
||||
* 1. 检测ReadableStream类型的响应流
|
||||
* 2. 使用ApiClient的getResponseChunkTransformer()将原始SDK响应块转换为通用格式
|
||||
* 3. 将转换后的ReadableStream保存到ctx._internal.apiCall.genericChunkStream,供下游中间件使用
|
||||
*
|
||||
* 注意:此中间件应该在StreamAdapterMiddleware之后执行
|
||||
*/
|
||||
export const ResponseTransformMiddleware: CompletionsMiddleware =
|
||||
() =>
|
||||
(next) =>
|
||||
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
|
||||
// 调用下游中间件
|
||||
const result = await next(ctx, params)
|
||||
|
||||
// 响应后处理:转换原始SDK响应块
|
||||
if (result.stream) {
|
||||
const adaptedStream = result.stream
|
||||
|
||||
// 处理ReadableStream类型的流
|
||||
if (adaptedStream instanceof ReadableStream) {
|
||||
const apiClient = ctx.apiClientInstance
|
||||
if (!apiClient) {
|
||||
console.error(`[${MIDDLEWARE_NAME}] ApiClient instance not found in context`)
|
||||
throw new Error('ApiClient instance not found in context')
|
||||
}
|
||||
|
||||
// 获取响应转换器
|
||||
const responseChunkTransformer = apiClient.getResponseChunkTransformer(ctx)
|
||||
if (!responseChunkTransformer) {
|
||||
Logger.warn(`[${MIDDLEWARE_NAME}] No ResponseChunkTransformer available, skipping transformation`)
|
||||
return result
|
||||
}
|
||||
|
||||
const assistant = params.assistant
|
||||
const model = assistant?.model
|
||||
|
||||
if (!assistant || !model) {
|
||||
console.error(`[${MIDDLEWARE_NAME}] Assistant or Model not found for transformation`)
|
||||
throw new Error('Assistant or Model not found for transformation')
|
||||
}
|
||||
|
||||
const transformerContext: ResponseChunkTransformerContext = {
|
||||
isStreaming: params.streamOutput || false,
|
||||
isEnabledToolCalling: (params.mcpTools && params.mcpTools.length > 0) || false,
|
||||
isEnabledWebSearch: params.enableWebSearch || false,
|
||||
isEnabledReasoning: params.enableReasoning || false,
|
||||
mcpTools: params.mcpTools || [],
|
||||
provider: ctx.apiClientInstance?.provider
|
||||
}
|
||||
|
||||
console.log(`[${MIDDLEWARE_NAME}] Transforming raw SDK chunks with context:`, transformerContext)
|
||||
|
||||
try {
|
||||
// 创建转换后的流
|
||||
const genericChunkTransformStream = (adaptedStream as ReadableStream<SdkRawChunk>).pipeThrough<GenericChunk>(
|
||||
new TransformStream<SdkRawChunk, GenericChunk>(responseChunkTransformer(transformerContext))
|
||||
)
|
||||
|
||||
// 将转换后的ReadableStream保存到result,供下游中间件使用
|
||||
return {
|
||||
...result,
|
||||
stream: genericChunkTransformStream
|
||||
}
|
||||
} catch (error) {
|
||||
Logger.error(`[${MIDDLEWARE_NAME}] Error during chunk transformation:`, error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有流或不是ReadableStream,返回原始结果
|
||||
return result
|
||||
}
|
||||
@ -0,0 +1,56 @@
|
||||
import { SdkRawChunk } from '@renderer/types/sdk'
|
||||
import { asyncGeneratorToReadableStream, createSingleChunkReadableStream } from '@renderer/utils/stream'
|
||||
|
||||
import { CompletionsParams, CompletionsResult } from '../schemas'
|
||||
import { CompletionsContext, CompletionsMiddleware } from '../types'
|
||||
import { isAsyncIterable } from '../utils'
|
||||
|
||||
export const MIDDLEWARE_NAME = 'StreamAdapterMiddleware'
|
||||
|
||||
/**
|
||||
* 流适配器中间件
|
||||
*
|
||||
* 职责:
|
||||
* 1. 检测ctx._internal.apiCall.rawSdkOutput(优先)或原始AsyncIterable流
|
||||
* 2. 将AsyncIterable转换为WHATWG ReadableStream
|
||||
* 3. 更新响应结果中的stream
|
||||
*
|
||||
* 注意:如果ResponseTransformMiddleware已处理过,会优先使用transformedStream
|
||||
*/
|
||||
export const StreamAdapterMiddleware: CompletionsMiddleware =
|
||||
() =>
|
||||
(next) =>
|
||||
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
|
||||
// TODO:调用开始,因为这个是最靠近接口请求的地方,next执行代表着开始接口请求了
|
||||
// 但是这个中间件的职责是流适配,是否在这调用优待商榷
|
||||
// 调用下游中间件
|
||||
const result = await next(ctx, params)
|
||||
if (
|
||||
result.rawOutput &&
|
||||
!(result.rawOutput instanceof ReadableStream) &&
|
||||
isAsyncIterable<SdkRawChunk>(result.rawOutput)
|
||||
) {
|
||||
const whatwgReadableStream: ReadableStream<SdkRawChunk> = asyncGeneratorToReadableStream<SdkRawChunk>(
|
||||
result.rawOutput
|
||||
)
|
||||
return {
|
||||
...result,
|
||||
stream: whatwgReadableStream
|
||||
}
|
||||
} else if (result.rawOutput && result.rawOutput instanceof ReadableStream) {
|
||||
return {
|
||||
...result,
|
||||
stream: result.rawOutput
|
||||
}
|
||||
} else if (result.rawOutput) {
|
||||
// 非流式输出,强行变为可读流
|
||||
const whatwgReadableStream: ReadableStream<SdkRawChunk> = createSingleChunkReadableStream<SdkRawChunk>(
|
||||
result.rawOutput
|
||||
)
|
||||
return {
|
||||
...result,
|
||||
stream: whatwgReadableStream
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
@ -0,0 +1,99 @@
|
||||
import Logger from '@renderer/config/logger'
|
||||
import { ChunkType, TextDeltaChunk } from '@renderer/types/chunk'
|
||||
|
||||
import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
|
||||
import { CompletionsContext, CompletionsMiddleware } from '../types'
|
||||
|
||||
export const MIDDLEWARE_NAME = 'TextChunkMiddleware'
|
||||
|
||||
/**
|
||||
* 文本块处理中间件
|
||||
*
|
||||
* 职责:
|
||||
* 1. 累积文本内容(TEXT_DELTA)
|
||||
* 2. 对文本内容进行智能链接转换
|
||||
* 3. 生成TEXT_COMPLETE事件
|
||||
* 4. 暂存Web搜索结果,用于最终链接完善
|
||||
* 5. 处理 onResponse 回调,实时发送文本更新和最终完整文本
|
||||
*/
|
||||
export const TextChunkMiddleware: CompletionsMiddleware =
|
||||
() =>
|
||||
(next) =>
|
||||
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
|
||||
// 调用下游中间件
|
||||
const result = await next(ctx, params)
|
||||
|
||||
// 响应后处理:转换流式响应中的文本内容
|
||||
if (result.stream) {
|
||||
const resultFromUpstream = result.stream as ReadableStream<GenericChunk>
|
||||
|
||||
if (resultFromUpstream && resultFromUpstream instanceof ReadableStream) {
|
||||
const assistant = params.assistant
|
||||
const model = params.assistant?.model
|
||||
|
||||
if (!assistant || !model) {
|
||||
Logger.warn(`[${MIDDLEWARE_NAME}] Missing assistant or model information, skipping text processing`)
|
||||
return result
|
||||
}
|
||||
|
||||
// 用于跨chunk的状态管理
|
||||
let accumulatedTextContent = ''
|
||||
let hasEnqueue = false
|
||||
const enhancedTextStream = resultFromUpstream.pipeThrough(
|
||||
new TransformStream<GenericChunk, GenericChunk>({
|
||||
transform(chunk: GenericChunk, controller) {
|
||||
if (chunk.type === ChunkType.TEXT_DELTA) {
|
||||
const textChunk = chunk as TextDeltaChunk
|
||||
accumulatedTextContent += textChunk.text
|
||||
|
||||
// 处理 onResponse 回调 - 发送增量文本更新
|
||||
if (params.onResponse) {
|
||||
params.onResponse(accumulatedTextContent, false)
|
||||
}
|
||||
|
||||
// 创建新的chunk,包含处理后的文本
|
||||
controller.enqueue(chunk)
|
||||
} else if (accumulatedTextContent) {
|
||||
if (chunk.type !== ChunkType.LLM_RESPONSE_COMPLETE) {
|
||||
controller.enqueue(chunk)
|
||||
hasEnqueue = true
|
||||
}
|
||||
const finalText = accumulatedTextContent
|
||||
ctx._internal.customState!.accumulatedText = finalText
|
||||
if (ctx._internal.toolProcessingState && !ctx._internal.toolProcessingState?.output) {
|
||||
ctx._internal.toolProcessingState.output = finalText
|
||||
}
|
||||
|
||||
// 处理 onResponse 回调 - 发送最终完整文本
|
||||
if (params.onResponse) {
|
||||
params.onResponse(finalText, true)
|
||||
}
|
||||
|
||||
controller.enqueue({
|
||||
type: ChunkType.TEXT_COMPLETE,
|
||||
text: finalText
|
||||
})
|
||||
accumulatedTextContent = ''
|
||||
if (!hasEnqueue) {
|
||||
controller.enqueue(chunk)
|
||||
}
|
||||
} else {
|
||||
// 其他类型的chunk直接传递
|
||||
controller.enqueue(chunk)
|
||||
}
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
// 更新响应结果
|
||||
return {
|
||||
...result,
|
||||
stream: enhancedTextStream
|
||||
}
|
||||
} else {
|
||||
Logger.warn(`[${MIDDLEWARE_NAME}] No stream to process or not a ReadableStream. Returning original result.`)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user