diff --git a/.gitignore b/.gitignore index f4cc92ce30..a8107fa93e 100644 --- a/.gitignore +++ b/.gitignore @@ -71,3 +71,5 @@ playwright-report test-results YOUR_MEMORY_FILE_PATH + +.sessions/ diff --git a/.oxlintrc.json b/.oxlintrc.json index 8c440af32c..20f3955d58 100644 --- a/.oxlintrc.json +++ b/.oxlintrc.json @@ -118,7 +118,7 @@ "no-unused-expressions": "off", // this rule disallow us to use expression to call function, like `condition && fn()` "no-unused-labels": "error", "no-unused-private-class-members": "error", - "no-unused-vars": ["error", { "caughtErrors": "none" }], + "no-unused-vars": ["warn", { "caughtErrors": "none" }], "no-useless-backreference": "error", "no-useless-catch": "error", "no-useless-escape": "error", diff --git a/.vscode/settings.json b/.vscode/settings.json index 4097b2bdac..9fe4ec4bb4 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -35,10 +35,10 @@ ".oxlintrc.json": "jsonc" }, "files.eol": "\n", - "i18n-ally.displayLanguage": "zh-cn", + // "i18n-ally.displayLanguage": "zh-cn", // 界面显示语言 "i18n-ally.enabledFrameworks": ["react-i18next", "i18next"], "i18n-ally.enabledParsers": ["ts", "js", "json"], // 解析语言 - "i18n-ally.fullReloadOnChanged": true, // 界面显示语言 + "i18n-ally.fullReloadOnChanged": true, "i18n-ally.keystyle": "nested", // 翻译路径格式 "i18n-ally.localesPaths": ["src/renderer/src/i18n/locales"], // "i18n-ally.namespace": true, // 开启命名空间 diff --git a/.yarn/patches/@anthropic-ai-claude-agent-sdk-npm-0.1.1-d937b73fed.patch b/.yarn/patches/@anthropic-ai-claude-agent-sdk-npm-0.1.1-d937b73fed.patch new file mode 100644 index 0000000000..5e37489f25 --- /dev/null +++ b/.yarn/patches/@anthropic-ai-claude-agent-sdk-npm-0.1.1-d937b73fed.patch @@ -0,0 +1,31 @@ +diff --git a/sdk.mjs b/sdk.mjs +index 461e9a2ba246778261108a682762ffcf26f7224e..44bd667d9f591969d36a105ba5eb8b478c738dd8 100644 +--- a/sdk.mjs ++++ b/sdk.mjs +@@ -6215,7 +6215,7 @@ function createAbortController(maxListeners = DEFAULT_MAX_LISTENERS) { + } + + // ../src/transport/ProcessTransport.ts +-import { spawn } from "child_process"; ++import { fork } from "child_process"; + import { createInterface } from "readline"; + + // ../src/utils/fsOperations.ts +@@ -6473,14 +6473,11 @@ class ProcessTransport { + const errorMessage = isNativeBinary(pathToClaudeCodeExecutable) ? `Claude Code native binary not found at ${pathToClaudeCodeExecutable}. Please ensure Claude Code is installed via native installer or specify a valid path with options.pathToClaudeCodeExecutable.` : `Claude Code executable not found at ${pathToClaudeCodeExecutable}. Is options.pathToClaudeCodeExecutable set?`; + throw new ReferenceError(errorMessage); + } +- const isNative = isNativeBinary(pathToClaudeCodeExecutable); +- const spawnCommand = isNative ? pathToClaudeCodeExecutable : executable; +- const spawnArgs = isNative ? args : [...executableArgs, pathToClaudeCodeExecutable, ...args]; +- this.logForDebugging(isNative ? `Spawning Claude Code native binary: ${pathToClaudeCodeExecutable} ${args.join(" ")}` : `Spawning Claude Code process: ${executable} ${[...executableArgs, pathToClaudeCodeExecutable, ...args].join(" ")}`); ++ this.logForDebugging(`Forking Claude Code Node.js process: ${pathToClaudeCodeExecutable} ${args.join(" ")}`); + const stderrMode = env.DEBUG || stderr ? "pipe" : "ignore"; +- this.child = spawn(spawnCommand, spawnArgs, { ++ this.child = fork(pathToClaudeCodeExecutable, args, { + cwd, +- stdio: ["pipe", "pipe", stderrMode], ++ stdio: stderrMode === "pipe" ? ["pipe", "pipe", "pipe", "ipc"] : ["pipe", "pipe", "ignore", "ipc"], + signal: this.abortController.signal, + env + }); diff --git a/AGENTS.md b/AGENTS.md new file mode 120000 index 0000000000..681311eb9c --- /dev/null +++ b/AGENTS.md @@ -0,0 +1 @@ +CLAUDE.md \ No newline at end of file diff --git a/CLAUDE.md b/CLAUDE.md index 03589aebc5..0eaf1a3aaf 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -1,44 +1,39 @@ -# CLAUDE.md +# AI Assistant Guide -This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. +This file provides guidance to AI coding assistants when working with code in this repository. Adherence to these guidelines is crucial for maintaining code quality and consistency. + +## Guiding Principles (MUST FOLLOW) + +- **Keep it clear**: Write code that is easy to read, maintain, and explain. +- **Match the house style**: Reuse existing patterns, naming, and conventions. +- **Search smart**: Prefer `ast-grep` for semantic queries; fall back to `rg`/`grep` when needed. +- **Build with HeroUI**: Use HeroUI for every new UI component; never add `antd` or `styled-components`. +- **Log centrally**: Route all logging through `loggerService` with the right context—no `console.log`. +- **Research via subagent**: Lean on `subagent` for external docs, APIs, news, and references. +- **Seek review**: Ask a human developer to review substantial changes before merging. +- **Commit in rhythm**: Keep commits small, conventional, and emoji-tagged. ## Development Commands -### Environment Setup +- **Install**: `yarn install` - Install all project dependencies +- **Development**: `yarn dev` - Runs Electron app in development mode with hot reload +- **Debug**: `yarn debug` - Starts with debugging enabled, use `chrome://inspect` to attach debugger +- **Build Check**: `yarn build:check` - **REQUIRED** before commits (lint + test + typecheck) + - If having i18n sort issues, run `yarn sync:i18n` first to sync template + - If having formatting issues, run `yarn format` first +- **Test**: `yarn test` - Run all tests (Vitest) across main and renderer processes +- **Single Test**: + - `yarn test:main` - Run tests for main process only + - `yarn test:renderer` - Run tests for renderer process only +- **Lint**: `yarn lint` - Fix linting issues and run TypeScript type checking +- **Format**: `yarn format` - Auto-format code using Biome -- **Prerequisites**: Node.js v22.x.x or higher, Yarn 4.9.1 -- **Setup Yarn**: `corepack enable && corepack prepare yarn@4.9.1 --activate` -- **Install Dependencies**: `yarn install` -- **Add New Dependencies**: `yarn add -D` for renderer-specific dependencies, `yarn add` for others. +## Project Architecture -### Development - -- **Start Development**: `yarn dev` - Runs Electron app in development mode -- **Debug Mode**: `yarn debug` - Starts with debugging enabled, use chrome://inspect - -### Testing & Quality - -- **Run Tests**: `yarn test` - Runs all tests (Vitest) -- **Run E2E Tests**: `yarn test:e2e` - Playwright end-to-end tests -- **Type Check**: `yarn typecheck` - Checks TypeScript for both node and web -- **Lint**: `yarn lint` - ESLint with auto-fix -- **Format**: `yarn format` - Biome formatting - -### Build & Release - -- **Build**: `yarn build` - Builds for production (includes typecheck) -- **Platform-specific builds**: - - Windows: `yarn build:win` - - macOS: `yarn build:mac` - - Linux: `yarn build:linux` - -## Architecture Overview - -### Electron Multi-Process Architecture - -- **Main Process** (`src/main/`): Node.js backend handling system integration, file operations, and services -- **Renderer Process** (`src/renderer/`): React-based UI running in Chromium -- **Preload Scripts** (`src/preload/`): Secure bridge between main and renderer processes +### Electron Structure +- **Main Process** (`src/main/`): Node.js backend with services (MCP, Knowledge, Storage, etc.) +- **Renderer Process** (`src/renderer/`): React UI with Redux state management +- **Preload Scripts** (`src/preload/`): Secure IPC bridge ### Key Architectural Components @@ -148,24 +143,8 @@ The application uses three distinct data management systems. Choose the appropri ### Usage ```typescript -// Main process import { loggerService } from '@logger' const logger = loggerService.withContext('moduleName') - -// Renderer process (set window source first) -loggerService.initWindowSource('windowName') -const logger = loggerService.withContext('moduleName') - -// Logging +// Renderer: loggerService.initWindowSource('windowName') first logger.info('message', CONTEXT) -logger.error('message', new Error('error'), CONTEXT) ``` - -### Log Levels (highest to lowest) - -- `error` - Critical errors causing crash/unusable functionality -- `warn` - Potential issues that don't affect core functionality -- `info` - Application lifecycle and key user actions -- `verbose` - Detailed flow information for feature tracing -- `debug` - Development diagnostic info (not for production) -- `silly` - Extreme debugging, low-level information diff --git a/electron-builder.yml b/electron-builder.yml index 92dd036e31..76f48eeb24 100644 --- a/electron-builder.yml +++ b/electron-builder.yml @@ -128,7 +128,21 @@ afterSign: scripts/notarize.js artifactBuildCompleted: scripts/artifact-build-completed.js releaseInfo: releaseNotes: | - Optimized note-taking feature, now able to quickly rename by modifying the title - Fixed issue where CherryAI free model could not be used - Fixed issue where VertexAI proxy address could not be called normally - Fixed issue where built-in tools from service providers could not be called normally + What's New in v1.6.3 + + Features: + - Notes: Add spell-check control, automatic table line wrapping, export functionality, and LLM-based renaming + - UI: Expand topic rename clickable area, add middle-click tab closing, remove redundant scrollbars, fix message menubar overflow + - Editor: Add read-only extension support, make TextFilePreview read-only but copyable + - Models: Update support for DeepSeek v3.2, Claude 4.5, GLM 4.6, Gemini regex, and vision models + - Code Tools: Add GitHub Copilot CLI integration + + Bug Fixes: + - Fix migration for missing providers + - Fix forked topic retaining old name after rename + - Restore first token latency reporting in metrics + - Fix UI scrollbar and overflow issues + + Technical Updates: + - Upgrade to Electron 37.6.0 + - Update dependencies across packages diff --git a/electron.vite.config.ts b/electron.vite.config.ts index 7b8990b6e7..76c1b55401 100644 --- a/electron.vite.config.ts +++ b/electron.vite.config.ts @@ -102,6 +102,7 @@ export default defineConfig({ alias: { '@renderer': resolve('src/renderer/src'), '@shared': resolve('packages/shared'), + '@types': resolve('src/renderer/src/types'), '@logger': resolve('src/renderer/src/services/LoggerService'), '@data': resolve('src/renderer/src/data'), '@mcp-trace/trace-core': resolve('packages/mcp-trace/trace-core'), diff --git a/package.json b/package.json index 0a2e7c5865..eaa7c25c1c 100644 --- a/package.json +++ b/package.json @@ -43,15 +43,18 @@ "release": "node scripts/version.js", "publish": "yarn build:check && yarn release patch push", "pulish:artifacts": "cd packages/artifacts && npm publish && cd -", - "generate:agents": "yarn workspace @cherry-studio/database agents", + "agents:generate": "NODE_ENV='development' drizzle-kit generate --config src/main/services/agents/drizzle.config.ts", + "agents:push": "NODE_ENV='development' drizzle-kit push --config src/main/services/agents/drizzle.config.ts", + "agents:studio": "NODE_ENV='development' drizzle-kit studio --config src/main/services/agents/drizzle.config.ts", + "agents:drop": "NODE_ENV='development' drizzle-kit drop --config src/main/services/agents/drizzle.config.ts", "generate:icons": "electron-icon-builder --input=./build/logo.png --output=build", "analyze:renderer": "VISUALIZER_RENDERER=true yarn build", "analyze:main": "VISUALIZER_MAIN=true yarn build", "typecheck": "concurrently -n \"node,web\" -c \"cyan,magenta\" \"npm run typecheck:node\" \"npm run typecheck:web\"", "typecheck:node": "tsgo --noEmit -p tsconfig.node.json --composite false", "typecheck:web": "tsgo --noEmit -p tsconfig.web.json --composite false", - "check:i18n": "tsx scripts/check-i18n.ts", - "sync:i18n": "tsx scripts/sync-i18n.ts", + "check:i18n": "dotenv -e .env -- tsx scripts/check-i18n.ts", + "sync:i18n": "dotenv -e .env -- tsx scripts/sync-i18n.ts", "update:i18n": "dotenv -e .env -- tsx scripts/update-i18n.ts", "auto:i18n": "dotenv -e .env -- tsx scripts/auto-translate-i18n.ts", "update:languages": "tsx scripts/update-languages.ts", @@ -65,7 +68,7 @@ "test:e2e": "yarn playwright test", "test:lint": "oxlint --deny-warnings && eslint . --ext .js,.jsx,.cjs,.mjs,.ts,.tsx,.cts,.mts --cache", "test:scripts": "vitest scripts", - "lint": "oxlint --fix && eslint . --ext .js,.jsx,.cjs,.mjs,.ts,.tsx,.cts,.mts --fix --cache && biome lint --write && biome format --write && yarn typecheck && yarn check:i18n", + "lint": "oxlint --fix && eslint . --ext .js,.jsx,.cjs,.mjs,.ts,.tsx,.cts,.mts --fix --cache && biome lint --write && biome format --write && yarn typecheck && yarn check:i18n && yarn format:check", "lint:ox": "oxlint --fix && biome lint --write && biome format --write", "format": "biome format --write && biome lint --write", "format:check": "biome format && biome lint", @@ -77,11 +80,11 @@ "release:aicore": "yarn workspace @cherrystudio/ai-core version patch --immediate && yarn workspace @cherrystudio/ai-core npm publish --access public" }, "dependencies": { + "@anthropic-ai/claude-agent-sdk": "patch:@anthropic-ai/claude-agent-sdk@npm%3A0.1.1#~/.yarn/patches/@anthropic-ai-claude-agent-sdk-npm-0.1.1-d937b73fed.patch", "@libsql/client": "0.14.0", "@libsql/win32-x64-msvc": "^0.4.7", "@napi-rs/system-ocr": "patch:@napi-rs/system-ocr@npm%3A1.0.2#~/.yarn/patches/@napi-rs-system-ocr-npm-1.0.2-59e7a78e8b.patch", "@strongtz/win32-arm64-msvc": "^0.4.7", - "express": "^5.1.0", "font-list": "^2.0.0", "graceful-fs": "^4.2.11", "jsdom": "26.1.0", @@ -91,7 +94,6 @@ "selection-hook": "^1.0.12", "sharp": "^0.34.3", "swagger-jsdoc": "^6.2.8", - "swagger-ui-express": "^5.0.1", "tesseract.js": "patch:tesseract.js@npm%3A6.0.1#~/.yarn/patches/tesseract.js-npm-6.0.1-2562a7e46d.patch", "turndown": "7.2.0" }, @@ -153,6 +155,7 @@ "@opentelemetry/sdk-trace-node": "^2.0.0", "@opentelemetry/sdk-trace-web": "^2.0.0", "@playwright/test": "^1.52.0", + "@radix-ui/react-context-menu": "^2.2.16", "@reduxjs/toolkit": "^2.2.5", "@shikijs/markdown-it": "^3.12.0", "@swc/plugin-styled-components": "^8.0.4", @@ -204,6 +207,7 @@ "@types/swagger-ui-express": "^4.1.8", "@types/tinycolor2": "^1", "@types/turndown": "^5.0.5", + "@types/uuid": "^10.0.0", "@types/word-extractor": "^1", "@typescript/native-preview": "latest", "@uiw/codemirror-extensions-langs": "^4.25.1", @@ -241,10 +245,11 @@ "dompurify": "^3.2.6", "dotenv-cli": "^7.4.2", "drizzle-kit": "^0.31.4", - "drizzle-orm": "^0.44.2", + "drizzle-orm": "^0.44.5", "electron": "37.6.0", "electron-builder": "26.0.15", "electron-devtools-installer": "^3.2.0", + "electron-reload": "^2.0.0-alpha.1", "electron-store": "^8.2.0", "electron-updater": "6.6.4", "electron-vite": "4.0.0", @@ -257,6 +262,8 @@ "eslint-plugin-react-hooks": "^5.2.0", "eslint-plugin-simple-import-sort": "^12.1.1", "eslint-plugin-unused-imports": "^4.1.4", + "express": "^5.1.0", + "express-validator": "^7.2.1", "fast-diff": "^1.3.0", "fast-xml-parser": "^5.2.0", "fetch-socks": "1.3.2", @@ -329,6 +336,7 @@ "string-width": "^7.2.0", "striptags": "^3.2.0", "styled-components": "^6.1.11", + "swagger-ui-express": "^5.0.1", "swr": "^2.3.6", "tailwindcss": "^4.1.13", "tar": "^7.4.3", @@ -340,7 +348,7 @@ "typescript": "~5.8.2", "undici": "6.21.2", "unified": "^11.0.5", - "uuid": "^10.0.0", + "uuid": "^13.0.0", "vite": "npm:rolldown-vite@latest", "vitest": "^3.2.4", "webdav": "^5.8.0", diff --git a/packages/shared/IpcChannel.ts b/packages/shared/IpcChannel.ts index 9c0c88e9a4..355e340853 100644 --- a/packages/shared/IpcChannel.ts +++ b/packages/shared/IpcChannel.ts @@ -91,6 +91,10 @@ export enum IpcChannel { // Python Python_Execute = 'python:execute', + // agent messages + AgentMessage_PersistExchange = 'agent-message:persist-exchange', + AgentMessage_GetHistory = 'agent-message:get-history', + //copilot Copilot_GetAuthMessage = 'copilot:get-auth-message', Copilot_GetCopilotToken = 'copilot:get-copilot-token', @@ -184,6 +188,7 @@ export enum IpcChannel { File_ValidateNotesDirectory = 'file:validateNotesDirectory', File_StartWatcher = 'file:startWatcher', File_StopWatcher = 'file:stopWatcher', + File_ShowInFolder = 'file:showInFolder', // file service FileService_Upload = 'file-service:upload', diff --git a/packages/shared/agents/claudecode/types.ts b/packages/shared/agents/claudecode/types.ts new file mode 100644 index 0000000000..df1b5dfea4 --- /dev/null +++ b/packages/shared/agents/claudecode/types.ts @@ -0,0 +1,12 @@ +import type { SDKMessage } from '@anthropic-ai/claude-agent-sdk' +import type { ContentBlockParam } from '@anthropic-ai/sdk/resources/messages' + +export type ClaudeCodeRawValue = + | { + type: string + session_id: string + slash_commands: string[] + tools: string[] + raw: Extract + } + | ContentBlockParam diff --git a/packages/shared/anthropic/index.ts b/packages/shared/anthropic/index.ts new file mode 100644 index 0000000000..777cbd13e8 --- /dev/null +++ b/packages/shared/anthropic/index.ts @@ -0,0 +1,170 @@ +/** + * @fileoverview Shared Anthropic AI client utilities for Cherry Studio + * + * This module provides functions for creating Anthropic SDK clients with different + * authentication methods (OAuth, API key) and building Claude Code system messages. + * It supports both standard Anthropic API and Anthropic Vertex AI endpoints. + * + * This shared module can be used by both main and renderer processes. + */ + +import Anthropic from '@anthropic-ai/sdk' +import { TextBlockParam } from '@anthropic-ai/sdk/resources' +import { loggerService } from '@logger' +import { Provider } from '@types' +import type { ModelMessage } from 'ai' + +const logger = loggerService.withContext('anthropic-sdk') + +const defaultClaudeCodeSystemPrompt = `You are Claude Code, Anthropic's official CLI for Claude.` + +const defaultClaudeCodeSystem: Array = [ + { + type: 'text', + text: defaultClaudeCodeSystemPrompt + } +] + +/** + * Creates and configures an Anthropic SDK client based on the provider configuration. + * + * This function supports two authentication methods: + * 1. OAuth: Uses OAuth tokens passed as parameter + * 2. API Key: Uses traditional API key authentication + * + * For OAuth authentication, it includes Claude Code specific headers and beta features. + * For API key authentication, it uses the provider's configuration with custom headers. + * + * @param provider - The provider configuration containing authentication details + * @param oauthToken - Optional OAuth token for OAuth authentication + * @returns An initialized Anthropic or AnthropicVertex client + * @throws Error when OAuth token is not available for OAuth authentication + * + * @example + * ```typescript + * // OAuth authentication + * const oauthProvider = { authType: 'oauth' }; + * const oauthClient = getSdkClient(oauthProvider, 'oauth-token-here'); + * + * // API key authentication + * const apiKeyProvider = { + * authType: 'apikey', + * apiKey: 'your-api-key', + * apiHost: 'https://api.anthropic.com' + * }; + * const apiKeyClient = getSdkClient(apiKeyProvider); + * ``` + */ +export function getSdkClient( + provider: Provider, + oauthToken?: string | null, + extraHeaders?: Record +): Anthropic { + if (provider.authType === 'oauth') { + if (!oauthToken) { + throw new Error('OAuth token is not available') + } + return new Anthropic({ + authToken: oauthToken, + baseURL: 'https://api.anthropic.com', + dangerouslyAllowBrowser: true, + defaultHeaders: { + 'Content-Type': 'application/json', + 'anthropic-version': '2023-06-01', + 'anthropic-beta': + 'oauth-2025-04-20,claude-code-20250219,interleaved-thinking-2025-05-14,fine-grained-tool-streaming-2025-05-14', + 'anthropic-dangerous-direct-browser-access': 'true', + 'user-agent': 'claude-cli/1.0.118 (external, sdk-ts)', + 'x-app': 'cli', + 'x-stainless-retry-count': '0', + 'x-stainless-timeout': '600', + 'x-stainless-lang': 'js', + 'x-stainless-package-version': '0.60.0', + 'x-stainless-os': 'MacOS', + 'x-stainless-arch': 'arm64', + 'x-stainless-runtime': 'node', + 'x-stainless-runtime-version': 'v22.18.0', + ...extraHeaders + } + }) + } + const baseURL = + provider.type === 'anthropic' + ? provider.apiHost + : (provider.anthropicApiHost && provider.anthropicApiHost.trim()) || provider.apiHost + + logger.debug('Anthropic API baseURL', { baseURL, providerId: provider.id }) + + if (provider.id === 'aihubmix') { + return new Anthropic({ + apiKey: provider.apiKey, + baseURL, + dangerouslyAllowBrowser: true, + defaultHeaders: { + 'anthropic-beta': 'output-128k-2025-02-19', + 'APP-Code': 'MLTG2087', + ...provider.extra_headers, + ...extraHeaders + } + }) + } + + return new Anthropic({ + apiKey: provider.apiKey, + authToken: provider.apiKey, + baseURL, + dangerouslyAllowBrowser: true, + defaultHeaders: { + 'anthropic-beta': 'output-128k-2025-02-19', + ...provider.extra_headers + } + }) +} + +/** + * Builds and prepends the Claude Code system message to user-provided system messages. + * + * This function ensures that all interactions with Claude include the official Claude Code + * system prompt, which identifies the assistant as "Claude Code, Anthropic's official CLI for Claude." + * + * The function handles three cases: + * 1. No system message provided: Returns only the default Claude Code system message + * 2. String system message: Converts to array format and prepends Claude Code message + * 3. Array system message: Checks if Claude Code message exists and prepends if missing + * + * @param system - Optional user-provided system message (string or TextBlockParam array) + * @returns Combined system message with Claude Code prompt prepended + * + * ``` + */ +export function buildClaudeCodeSystemMessage(system?: string | Array): Array { + if (!system) { + return defaultClaudeCodeSystem + } + + if (typeof system === 'string') { + if (system.trim() === defaultClaudeCodeSystemPrompt || system.trim() === '') { + return defaultClaudeCodeSystem + } else { + return [...defaultClaudeCodeSystem, { type: 'text', text: system }] + } + } + if (Array.isArray(system)) { + const firstSystem = system[0] + if (firstSystem.type === 'text' && firstSystem.text.trim() === defaultClaudeCodeSystemPrompt) { + return system + } else { + return [...defaultClaudeCodeSystem, ...system] + } + } + + return defaultClaudeCodeSystem +} + +export function buildClaudeCodeSystemModelMessage(system?: string | Array): Array { + const textBlocks = buildClaudeCodeSystemMessage(system) + return textBlocks.map((block) => ({ + role: 'system', + content: block.text + })) +} diff --git a/packages/shared/data/preference/preferenceSchemas.ts b/packages/shared/data/preference/preferenceSchemas.ts index 4a37cde378..6b758c0ff5 100644 --- a/packages/shared/data/preference/preferenceSchemas.ts +++ b/packages/shared/data/preference/preferenceSchemas.ts @@ -659,7 +659,7 @@ export const DefaultPreferences: PreferenceSchemas = { 'ui.sidebar.icons.invisible': [], 'ui.sidebar.icons.visible': [ 'assistants', - 'agents', + 'store', 'paintings', 'translate', 'minapp', diff --git a/packages/shared/data/preference/preferenceTypes.ts b/packages/shared/data/preference/preferenceTypes.ts index 7a9af886c6..a02e57cbcd 100644 --- a/packages/shared/data/preference/preferenceTypes.ts +++ b/packages/shared/data/preference/preferenceTypes.ts @@ -55,7 +55,7 @@ export type AssistantTabSortType = 'tags' | 'list' export type SidebarIcon = | 'assistants' - | 'agents' + | 'store' | 'paintings' | 'translate' | 'minapp' diff --git a/resources/database/drizzle/0000_confused_wendigo.sql b/resources/database/drizzle/0000_confused_wendigo.sql new file mode 100644 index 0000000000..b2328e39c2 --- /dev/null +++ b/resources/database/drizzle/0000_confused_wendigo.sql @@ -0,0 +1,53 @@ +--> statement-breakpoint +CREATE TABLE `migrations` ( + `version` integer PRIMARY KEY NOT NULL, + `tag` text NOT NULL, + `executed_at` integer NOT NULL +); + +CREATE TABLE `agents` ( + `id` text PRIMARY KEY NOT NULL, + `type` text NOT NULL, + `name` text NOT NULL, + `description` text, + `accessible_paths` text, + `instructions` text, + `model` text NOT NULL, + `plan_model` text, + `small_model` text, + `mcps` text, + `allowed_tools` text, + `configuration` text, + `created_at` text NOT NULL, + `updated_at` text NOT NULL +); + +--> statement-breakpoint +CREATE TABLE `sessions` ( + `id` text PRIMARY KEY NOT NULL, + `agent_type` text NOT NULL, + `agent_id` text NOT NULL, + `name` text NOT NULL, + `description` text, + `accessible_paths` text, + `instructions` text, + `model` text NOT NULL, + `plan_model` text, + `small_model` text, + `mcps` text, + `allowed_tools` text, + `configuration` text, + `created_at` text NOT NULL, + `updated_at` text NOT NULL +); + +--> statement-breakpoint +CREATE TABLE `session_messages` ( + `id` integer PRIMARY KEY AUTOINCREMENT NOT NULL, + `session_id` text NOT NULL, + `role` text NOT NULL, + `content` text NOT NULL, + `metadata` text, + `created_at` text NOT NULL, + `updated_at` text NOT NULL +); diff --git a/resources/database/drizzle/0001_woozy_captain_flint.sql b/resources/database/drizzle/0001_woozy_captain_flint.sql new file mode 100644 index 0000000000..f80f483c72 --- /dev/null +++ b/resources/database/drizzle/0001_woozy_captain_flint.sql @@ -0,0 +1 @@ +ALTER TABLE `session_messages` ADD `agent_session_id` text DEFAULT ''; \ No newline at end of file diff --git a/resources/database/drizzle/meta/0000_snapshot.json b/resources/database/drizzle/meta/0000_snapshot.json new file mode 100644 index 0000000000..140460cf09 --- /dev/null +++ b/resources/database/drizzle/meta/0000_snapshot.json @@ -0,0 +1,331 @@ +{ + "version": "6", + "dialect": "sqlite", + "id": "35efb412-0230-4767-9c76-7b7c4d40369f", + "prevId": "00000000-0000-0000-0000-000000000000", + "tables": { + "agents": { + "name": "agents", + "columns": { + "id": { + "name": "id", + "type": "text", + "primaryKey": true, + "notNull": true, + "autoincrement": false + }, + "type": { + "name": "type", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "name": { + "name": "name", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "description": { + "name": "description", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "accessible_paths": { + "name": "accessible_paths", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "instructions": { + "name": "instructions", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "model": { + "name": "model", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "plan_model": { + "name": "plan_model", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "small_model": { + "name": "small_model", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "mcps": { + "name": "mcps", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "allowed_tools": { + "name": "allowed_tools", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "configuration": { + "name": "configuration", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "created_at": { + "name": "created_at", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "updated_at": { + "name": "updated_at", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + } + }, + "indexes": {}, + "foreignKeys": {}, + "compositePrimaryKeys": {}, + "uniqueConstraints": {}, + "checkConstraints": {} + }, + "session_messages": { + "name": "session_messages", + "columns": { + "id": { + "name": "id", + "type": "integer", + "primaryKey": true, + "notNull": true, + "autoincrement": true + }, + "session_id": { + "name": "session_id", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "role": { + "name": "role", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "content": { + "name": "content", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "metadata": { + "name": "metadata", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "created_at": { + "name": "created_at", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "updated_at": { + "name": "updated_at", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + } + }, + "indexes": {}, + "foreignKeys": {}, + "compositePrimaryKeys": {}, + "uniqueConstraints": {}, + "checkConstraints": {} + }, + "migrations": { + "name": "migrations", + "columns": { + "version": { + "name": "version", + "type": "integer", + "primaryKey": true, + "notNull": true, + "autoincrement": false + }, + "tag": { + "name": "tag", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "executed_at": { + "name": "executed_at", + "type": "integer", + "primaryKey": false, + "notNull": true, + "autoincrement": false + } + }, + "indexes": {}, + "foreignKeys": {}, + "compositePrimaryKeys": {}, + "uniqueConstraints": {}, + "checkConstraints": {} + }, + "sessions": { + "name": "sessions", + "columns": { + "id": { + "name": "id", + "type": "text", + "primaryKey": true, + "notNull": true, + "autoincrement": false + }, + "agent_type": { + "name": "agent_type", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "agent_id": { + "name": "agent_id", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "name": { + "name": "name", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "description": { + "name": "description", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "accessible_paths": { + "name": "accessible_paths", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "instructions": { + "name": "instructions", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "model": { + "name": "model", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "plan_model": { + "name": "plan_model", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "small_model": { + "name": "small_model", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "mcps": { + "name": "mcps", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "allowed_tools": { + "name": "allowed_tools", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "configuration": { + "name": "configuration", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "created_at": { + "name": "created_at", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "updated_at": { + "name": "updated_at", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + } + }, + "indexes": {}, + "foreignKeys": {}, + "compositePrimaryKeys": {}, + "uniqueConstraints": {}, + "checkConstraints": {} + } + }, + "views": {}, + "enums": {}, + "_meta": { + "schemas": {}, + "tables": {}, + "columns": {} + }, + "internal": { + "indexes": {} + } +} diff --git a/resources/database/drizzle/meta/0001_snapshot.json b/resources/database/drizzle/meta/0001_snapshot.json new file mode 100644 index 0000000000..3b78976dd0 --- /dev/null +++ b/resources/database/drizzle/meta/0001_snapshot.json @@ -0,0 +1,339 @@ +{ + "version": "6", + "dialect": "sqlite", + "id": "dabab6db-a2cd-4e96-b06e-6cb87d445a87", + "prevId": "35efb412-0230-4767-9c76-7b7c4d40369f", + "tables": { + "agents": { + "name": "agents", + "columns": { + "id": { + "name": "id", + "type": "text", + "primaryKey": true, + "notNull": true, + "autoincrement": false + }, + "type": { + "name": "type", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "name": { + "name": "name", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "description": { + "name": "description", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "accessible_paths": { + "name": "accessible_paths", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "instructions": { + "name": "instructions", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "model": { + "name": "model", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "plan_model": { + "name": "plan_model", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "small_model": { + "name": "small_model", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "mcps": { + "name": "mcps", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "allowed_tools": { + "name": "allowed_tools", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "configuration": { + "name": "configuration", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "created_at": { + "name": "created_at", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "updated_at": { + "name": "updated_at", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + } + }, + "indexes": {}, + "foreignKeys": {}, + "compositePrimaryKeys": {}, + "uniqueConstraints": {}, + "checkConstraints": {} + }, + "session_messages": { + "name": "session_messages", + "columns": { + "id": { + "name": "id", + "type": "integer", + "primaryKey": true, + "notNull": true, + "autoincrement": true + }, + "session_id": { + "name": "session_id", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "role": { + "name": "role", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "content": { + "name": "content", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "agent_session_id": { + "name": "agent_session_id", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false, + "default": "''" + }, + "metadata": { + "name": "metadata", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "created_at": { + "name": "created_at", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "updated_at": { + "name": "updated_at", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + } + }, + "indexes": {}, + "foreignKeys": {}, + "compositePrimaryKeys": {}, + "uniqueConstraints": {}, + "checkConstraints": {} + }, + "migrations": { + "name": "migrations", + "columns": { + "version": { + "name": "version", + "type": "integer", + "primaryKey": true, + "notNull": true, + "autoincrement": false + }, + "tag": { + "name": "tag", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "executed_at": { + "name": "executed_at", + "type": "integer", + "primaryKey": false, + "notNull": true, + "autoincrement": false + } + }, + "indexes": {}, + "foreignKeys": {}, + "compositePrimaryKeys": {}, + "uniqueConstraints": {}, + "checkConstraints": {} + }, + "sessions": { + "name": "sessions", + "columns": { + "id": { + "name": "id", + "type": "text", + "primaryKey": true, + "notNull": true, + "autoincrement": false + }, + "agent_type": { + "name": "agent_type", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "agent_id": { + "name": "agent_id", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "name": { + "name": "name", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "description": { + "name": "description", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "accessible_paths": { + "name": "accessible_paths", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "instructions": { + "name": "instructions", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "model": { + "name": "model", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "plan_model": { + "name": "plan_model", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "small_model": { + "name": "small_model", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "mcps": { + "name": "mcps", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "allowed_tools": { + "name": "allowed_tools", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "configuration": { + "name": "configuration", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "created_at": { + "name": "created_at", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "updated_at": { + "name": "updated_at", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + } + }, + "indexes": {}, + "foreignKeys": {}, + "compositePrimaryKeys": {}, + "uniqueConstraints": {}, + "checkConstraints": {} + } + }, + "views": {}, + "enums": {}, + "_meta": { + "schemas": {}, + "tables": {}, + "columns": {} + }, + "internal": { + "indexes": {} + } +} diff --git a/resources/database/drizzle/meta/_journal.json b/resources/database/drizzle/meta/_journal.json new file mode 100644 index 0000000000..8648e01703 --- /dev/null +++ b/resources/database/drizzle/meta/_journal.json @@ -0,0 +1,20 @@ +{ + "version": "7", + "dialect": "sqlite", + "entries": [ + { + "idx": 0, + "version": "6", + "when": 1758091173882, + "tag": "0000_confused_wendigo", + "breakpoints": true + }, + { + "idx": 1, + "version": "6", + "when": 1758187378775, + "tag": "0001_woozy_captain_flint", + "breakpoints": true + } + ] +} diff --git a/scripts/auto-translate-i18n.ts b/scripts/auto-translate-i18n.ts index ef42c8da41..6a90f5b23f 100644 --- a/scripts/auto-translate-i18n.ts +++ b/scripts/auto-translate-i18n.ts @@ -9,8 +9,9 @@ import * as path from 'path' const localesDir = path.join(__dirname, '../src/renderer/src/i18n/locales') const translateDir = path.join(__dirname, '../src/renderer/src/i18n/translate') -const baseLocale = 'zh-cn' +const baseLocale = process.env.BASE_LOCALE ?? 'zh-cn' const baseFileName = `${baseLocale}.json` +const baseLocalePath = path.join(__dirname, '../src/renderer/src/i18n/locales', baseFileName) type I18NValue = string | { [key: string]: I18NValue } type I18N = { [key: string]: I18NValue } @@ -105,6 +106,9 @@ const translateRecursively = async (originObj: I18N, systemPrompt: string): Prom } const main = async () => { + if (!fs.existsSync(baseLocalePath)) { + throw new Error(`${baseLocalePath} not found.`) + } const localeFiles = fs .readdirSync(localesDir) .filter((file) => file.endsWith('.json') && file !== baseFileName) diff --git a/scripts/before-pack.js b/scripts/before-pack.js index 59c0a39171..3cde049bd4 100644 --- a/scripts/before-pack.js +++ b/scripts/before-pack.js @@ -35,6 +35,9 @@ const allX64 = { '@napi-rs/system-ocr-win32-x64-msvc': '1.0.2' } +const claudeCodeVenderPath = '@anthropic-ai/claude-agent-sdk/vendor' +const claudeCodeVenders = ['arm64-darwin', 'arm64-linux', 'x64-darwin', 'x64-linux', 'x64-win32'] + const platformToArch = { mac: 'darwin', windows: 'win32', @@ -46,9 +49,6 @@ exports.default = async function (context) { const archType = arch === Arch.arm64 ? 'arm64' : 'x64' const platform = context.packager.platform.name - const arm64Filters = Object.keys(allArm64).map((f) => '!node_modules/' + f + '/**') - const x64Filters = Object.keys(allX64).map((f) => '!node_modules/' + f + '/*') - const downloadPackages = async (packages) => { console.log('downloading packages ......') const downloadPromises = [] @@ -67,25 +67,39 @@ exports.default = async function (context) { await Promise.all(downloadPromises) } - const changeFilters = async (packages, filtersToExclude, filtersToInclude) => { - await downloadPackages(packages) + const changeFilters = async (filtersToExclude, filtersToInclude) => { // remove filters for the target architecture (allow inclusion) - let filters = context.packager.config.files[0].filter filters = filters.filter((filter) => !filtersToInclude.includes(filter)) + // add filters for other architectures (exclude them) filters.push(...filtersToExclude) context.packager.config.files[0].filter = filters } - if (arch === Arch.arm64) { - await changeFilters(allArm64, x64Filters, arm64Filters) - return - } + await downloadPackages(arch === Arch.arm64 ? allArm64 : allX64) - if (arch === Arch.x64) { - await changeFilters(allX64, arm64Filters, x64Filters) - return + const arm64Filters = Object.keys(allArm64).map((f) => '!node_modules/' + f + '/**') + const x64Filters = Object.keys(allX64).map((f) => '!node_modules/' + f + '/*') + const excludeClaudeCodeRipgrepFilters = claudeCodeVenders + .filter((f) => f !== `${archType}-${platformToArch[platform]}`) + .map((f) => '!node_modules/' + claudeCodeVenderPath + '/ripgrep/' + f + '/**') + const excludeClaudeCodeJBPlutins = ['!node_modules/' + claudeCodeVenderPath + '/' + 'claude-code-jetbrains-plugin'] + + const includeClaudeCodeFilters = [ + '!node_modules/' + claudeCodeVenderPath + '/ripgrep/' + `${archType}-${platformToArch[platform]}/**` + ] + + if (arch === Arch.arm64) { + await changeFilters( + [...x64Filters, ...excludeClaudeCodeRipgrepFilters, ...excludeClaudeCodeJBPlutins], + [...arm64Filters, ...includeClaudeCodeFilters] + ) + } else { + await changeFilters( + [...arm64Filters, ...excludeClaudeCodeRipgrepFilters, ...excludeClaudeCodeJBPlutins], + [...x64Filters, ...includeClaudeCodeFilters] + ) } } diff --git a/scripts/check-i18n.ts b/scripts/check-i18n.ts index cb357aef09..5735474106 100644 --- a/scripts/check-i18n.ts +++ b/scripts/check-i18n.ts @@ -4,7 +4,7 @@ import * as path from 'path' import { sortedObjectByKeys } from './sort' const translationsDir = path.join(__dirname, '../src/renderer/src/i18n/locales') -const baseLocale = 'zh-cn' +const baseLocale = process.env.BASE_LOCALE ?? 'zh-cn' const baseFileName = `${baseLocale}.json` const baseFilePath = path.join(translationsDir, baseFileName) diff --git a/scripts/sync-i18n.ts b/scripts/sync-i18n.ts index aa13bddefd..6b58756a5d 100644 --- a/scripts/sync-i18n.ts +++ b/scripts/sync-i18n.ts @@ -5,7 +5,7 @@ import { sortedObjectByKeys } from './sort' const localesDir = path.join(__dirname, '../src/renderer/src/i18n/locales') const translateDir = path.join(__dirname, '../src/renderer/src/i18n/translate') -const baseLocale = 'zh-cn' +const baseLocale = process.env.BASE_LOCALE ?? 'zh-cn' const baseFileName = `${baseLocale}.json` const baseFilePath = path.join(localesDir, baseFileName) diff --git a/src/main/apiServer/app.ts b/src/main/apiServer/app.ts index 46da10f876..a645e96740 100644 --- a/src/main/apiServer/app.ts +++ b/src/main/apiServer/app.ts @@ -3,23 +3,42 @@ import cors from 'cors' import express from 'express' import { v4 as uuidv4 } from 'uuid' +import { LONG_POLL_TIMEOUT_MS } from './config/timeouts' import { authMiddleware } from './middleware/auth' import { errorHandler } from './middleware/error' import { setupOpenAPIDocumentation } from './middleware/openapi' +import { agentsRoutes } from './routes/agents' import { chatRoutes } from './routes/chat' import { mcpRoutes } from './routes/mcp' +import { messagesProviderRoutes, messagesRoutes } from './routes/messages' import { modelsRoutes } from './routes/models' const logger = loggerService.withContext('ApiServer') +const extendMessagesTimeout: express.RequestHandler = (req, res, next) => { + req.setTimeout(LONG_POLL_TIMEOUT_MS) + res.setTimeout(LONG_POLL_TIMEOUT_MS) + next() +} + const app = express() +app.use( + express.json({ + limit: '50mb' + }) +) // Global middleware app.use((req, res, next) => { const start = Date.now() res.on('finish', () => { const duration = Date.now() - start - logger.info(`${req.method} ${req.path} - ${res.statusCode} - ${duration}ms`) + logger.info('API request completed', { + method: req.method, + path: req.path, + statusCode: res.statusCode, + durationMs: duration + }) }) next() }) @@ -101,27 +120,28 @@ app.get('/', (_req, res) => { name: 'Cherry Studio API', version: '1.0.0', endpoints: { - health: 'GET /health', - models: 'GET /v1/models', - chat: 'POST /v1/chat/completions', - mcp: 'GET /v1/mcps' + health: 'GET /health' } }) }) +// Setup OpenAPI documentation before protected routes so docs remain public +setupOpenAPIDocumentation(app) + +// Provider-specific messages route requires authentication +app.use('/:provider/v1/messages', authMiddleware, extendMessagesTimeout, messagesProviderRoutes) + // API v1 routes with auth const apiRouter = express.Router() apiRouter.use(authMiddleware) -apiRouter.use(express.json()) // Mount routes apiRouter.use('/chat', chatRoutes) apiRouter.use('/mcps', mcpRoutes) +apiRouter.use('/messages', extendMessagesTimeout, messagesRoutes) apiRouter.use('/models', modelsRoutes) +apiRouter.use('/agents', agentsRoutes) app.use('/v1', apiRouter) -// Setup OpenAPI documentation -setupOpenAPIDocumentation(app) - // Error handling (must be last) app.use(errorHandler) diff --git a/src/main/apiServer/config.ts b/src/main/apiServer/config.ts index d1aac85dad..60b1986be9 100644 --- a/src/main/apiServer/config.ts +++ b/src/main/apiServer/config.ts @@ -36,7 +36,7 @@ class ConfigManager { } return this._config } catch (error: any) { - logger.warn('Failed to load config from Redux, using defaults:', error) + logger.warn('Failed to load config from Redux, using defaults', { error }) this._config = { enabled: false, port: defaultPort, diff --git a/src/main/apiServer/config/timeouts.ts b/src/main/apiServer/config/timeouts.ts new file mode 100644 index 0000000000..2c5e077430 --- /dev/null +++ b/src/main/apiServer/config/timeouts.ts @@ -0,0 +1,3 @@ +export const LONG_POLL_TIMEOUT_MS = 120 * 60_000 // 120 minutes + +export const MESSAGE_STREAM_TIMEOUT_MS = LONG_POLL_TIMEOUT_MS diff --git a/src/main/apiServer/middleware/__tests__/auth.test.ts b/src/main/apiServer/middleware/__tests__/auth.test.ts new file mode 100644 index 0000000000..859050fb2b --- /dev/null +++ b/src/main/apiServer/middleware/__tests__/auth.test.ts @@ -0,0 +1,368 @@ +import type { NextFunction, Request, Response } from 'express' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +import { config } from '../../config' +import { authMiddleware } from '../auth' + +// Mock the config module +vi.mock('../../config', () => ({ + config: { + get: vi.fn() + } +})) + +// Mock the logger +vi.mock('@logger', () => ({ + loggerService: { + withContext: vi.fn(() => ({ + debug: vi.fn() + })) + } +})) + +const mockConfig = config as any + +describe('authMiddleware', () => { + let req: Partial + let res: Partial + let next: NextFunction + let jsonMock: ReturnType + let statusMock: ReturnType + + beforeEach(() => { + jsonMock = vi.fn() + statusMock = vi.fn(() => ({ json: jsonMock })) + + req = { + header: vi.fn() + } + res = { + status: statusMock + } + next = vi.fn() + + vi.clearAllMocks() + }) + + describe('Missing credentials', () => { + it('should return 401 when both auth headers are missing', async () => { + ;(req.header as any).mockReturnValue('') + + await authMiddleware(req as Request, res as Response, next) + + expect(statusMock).toHaveBeenCalledWith(401) + expect(jsonMock).toHaveBeenCalledWith({ error: 'Unauthorized: missing credentials' }) + expect(next).not.toHaveBeenCalled() + }) + + it('should return 401 when both auth headers are empty strings', async () => { + ;(req.header as any).mockImplementation((header: string) => { + if (header === 'authorization') return '' + if (header === 'x-api-key') return '' + return '' + }) + + await authMiddleware(req as Request, res as Response, next) + + expect(statusMock).toHaveBeenCalledWith(401) + expect(jsonMock).toHaveBeenCalledWith({ error: 'Unauthorized: missing credentials' }) + expect(next).not.toHaveBeenCalled() + }) + }) + + describe('Server configuration', () => { + it('should return 403 when API key is not configured', async () => { + ;(req.header as any).mockImplementation((header: string) => { + if (header === 'x-api-key') return 'some-key' + return '' + }) + + mockConfig.get.mockResolvedValue({ apiKey: '' }) + + await authMiddleware(req as Request, res as Response, next) + + expect(statusMock).toHaveBeenCalledWith(403) + expect(jsonMock).toHaveBeenCalledWith({ error: 'Forbidden' }) + expect(next).not.toHaveBeenCalled() + }) + + it('should return 403 when API key is null', async () => { + ;(req.header as any).mockImplementation((header: string) => { + if (header === 'x-api-key') return 'some-key' + return '' + }) + + mockConfig.get.mockResolvedValue({ apiKey: null }) + + await authMiddleware(req as Request, res as Response, next) + + expect(statusMock).toHaveBeenCalledWith(403) + expect(jsonMock).toHaveBeenCalledWith({ error: 'Forbidden' }) + expect(next).not.toHaveBeenCalled() + }) + }) + + describe('API Key authentication (priority)', () => { + const validApiKey = 'valid-api-key-123' + + beforeEach(() => { + mockConfig.get.mockResolvedValue({ apiKey: validApiKey }) + }) + + it('should authenticate successfully with valid API key', async () => { + ;(req.header as any).mockImplementation((header: string) => { + if (header === 'x-api-key') return validApiKey + return '' + }) + + await authMiddleware(req as Request, res as Response, next) + + expect(next).toHaveBeenCalled() + expect(statusMock).not.toHaveBeenCalled() + }) + + it('should return 403 with invalid API key', async () => { + ;(req.header as any).mockImplementation((header: string) => { + if (header === 'x-api-key') return 'invalid-key' + return '' + }) + + await authMiddleware(req as Request, res as Response, next) + + expect(statusMock).toHaveBeenCalledWith(403) + expect(jsonMock).toHaveBeenCalledWith({ error: 'Forbidden' }) + expect(next).not.toHaveBeenCalled() + }) + + it('should return 401 with empty API key', async () => { + ;(req.header as any).mockImplementation((header: string) => { + if (header === 'x-api-key') return ' ' + return '' + }) + + await authMiddleware(req as Request, res as Response, next) + + expect(statusMock).toHaveBeenCalledWith(401) + expect(jsonMock).toHaveBeenCalledWith({ error: 'Unauthorized: empty x-api-key' }) + expect(next).not.toHaveBeenCalled() + }) + + it('should handle API key with whitespace', async () => { + ;(req.header as any).mockImplementation((header: string) => { + if (header === 'x-api-key') return ` ${validApiKey} ` + return '' + }) + + await authMiddleware(req as Request, res as Response, next) + + expect(next).toHaveBeenCalled() + expect(statusMock).not.toHaveBeenCalled() + }) + + it('should prioritize API key over Bearer token when both are present', async () => { + ;(req.header as any).mockImplementation((header: string) => { + if (header === 'x-api-key') return validApiKey + if (header === 'authorization') return 'Bearer invalid-token' + return '' + }) + + await authMiddleware(req as Request, res as Response, next) + + expect(next).toHaveBeenCalled() + expect(statusMock).not.toHaveBeenCalled() + }) + + it('should return 403 when API key is invalid even if Bearer token is valid', async () => { + ;(req.header as any).mockImplementation((header: string) => { + if (header === 'x-api-key') return 'invalid-key' + if (header === 'authorization') return `Bearer ${validApiKey}` + return '' + }) + + await authMiddleware(req as Request, res as Response, next) + + expect(statusMock).toHaveBeenCalledWith(403) + expect(jsonMock).toHaveBeenCalledWith({ error: 'Forbidden' }) + expect(next).not.toHaveBeenCalled() + }) + }) + + describe('Bearer token authentication (fallback)', () => { + const validApiKey = 'valid-api-key-123' + + beforeEach(() => { + mockConfig.get.mockResolvedValue({ apiKey: validApiKey }) + }) + + it('should authenticate successfully with valid Bearer token when no API key', async () => { + ;(req.header as any).mockImplementation((header: string) => { + if (header === 'authorization') return `Bearer ${validApiKey}` + return '' + }) + + await authMiddleware(req as Request, res as Response, next) + + expect(next).toHaveBeenCalled() + expect(statusMock).not.toHaveBeenCalled() + }) + + it('should return 403 with invalid Bearer token', async () => { + ;(req.header as any).mockImplementation((header: string) => { + if (header === 'authorization') return 'Bearer invalid-token' + return '' + }) + + await authMiddleware(req as Request, res as Response, next) + + expect(statusMock).toHaveBeenCalledWith(403) + expect(jsonMock).toHaveBeenCalledWith({ error: 'Forbidden' }) + expect(next).not.toHaveBeenCalled() + }) + + it('should return 401 with malformed authorization header', async () => { + ;(req.header as any).mockImplementation((header: string) => { + if (header === 'authorization') return 'Basic sometoken' + return '' + }) + + await authMiddleware(req as Request, res as Response, next) + + expect(statusMock).toHaveBeenCalledWith(401) + expect(jsonMock).toHaveBeenCalledWith({ error: 'Unauthorized: invalid authorization format' }) + expect(next).not.toHaveBeenCalled() + }) + + it('should return 401 with Bearer without space', async () => { + ;(req.header as any).mockImplementation((header: string) => { + if (header === 'authorization') return 'Bearer' + return '' + }) + + await authMiddleware(req as Request, res as Response, next) + + expect(statusMock).toHaveBeenCalledWith(401) + expect(jsonMock).toHaveBeenCalledWith({ error: 'Unauthorized: invalid authorization format' }) + expect(next).not.toHaveBeenCalled() + }) + + it('should handle Bearer token with only trailing spaces (edge case)', async () => { + ;(req.header as any).mockImplementation((header: string) => { + if (header === 'authorization') return 'Bearer ' // This will be trimmed to "Bearer" and fail format check + return '' + }) + + await authMiddleware(req as Request, res as Response, next) + + expect(statusMock).toHaveBeenCalledWith(401) + expect(jsonMock).toHaveBeenCalledWith({ error: 'Unauthorized: invalid authorization format' }) + expect(next).not.toHaveBeenCalled() + }) + + it('should handle Bearer token with case insensitive prefix', async () => { + ;(req.header as any).mockImplementation((header: string) => { + if (header === 'authorization') return `bearer ${validApiKey}` + return '' + }) + + await authMiddleware(req as Request, res as Response, next) + + expect(next).toHaveBeenCalled() + expect(statusMock).not.toHaveBeenCalled() + }) + + it('should handle Bearer token with whitespace', async () => { + ;(req.header as any).mockImplementation((header: string) => { + if (header === 'authorization') return ` Bearer ${validApiKey} ` + return '' + }) + + await authMiddleware(req as Request, res as Response, next) + + expect(next).toHaveBeenCalled() + expect(statusMock).not.toHaveBeenCalled() + }) + }) + + describe('Edge cases', () => { + const validApiKey = 'valid-api-key-123' + + beforeEach(() => { + mockConfig.get.mockResolvedValue({ apiKey: validApiKey }) + }) + + it('should handle config.get() rejection', async () => { + ;(req.header as any).mockImplementation((header: string) => { + if (header === 'x-api-key') return validApiKey + return '' + }) + + mockConfig.get.mockRejectedValue(new Error('Config error')) + + await expect(authMiddleware(req as Request, res as Response, next)).rejects.toThrow('Config error') + }) + + it('should use timing-safe comparison for different length tokens', async () => { + ;(req.header as any).mockImplementation((header: string) => { + if (header === 'x-api-key') return 'short' + return '' + }) + + await authMiddleware(req as Request, res as Response, next) + + expect(statusMock).toHaveBeenCalledWith(403) + expect(jsonMock).toHaveBeenCalledWith({ error: 'Forbidden' }) + expect(next).not.toHaveBeenCalled() + }) + + it('should return 401 when neither credential format is valid', async () => { + ;(req.header as any).mockImplementation((header: string) => { + if (header === 'authorization') return 'Invalid format' + return '' + }) + + await authMiddleware(req as Request, res as Response, next) + + expect(statusMock).toHaveBeenCalledWith(401) + expect(jsonMock).toHaveBeenCalledWith({ error: 'Unauthorized: invalid authorization format' }) + expect(next).not.toHaveBeenCalled() + }) + }) + + describe('Timing attack protection', () => { + const validApiKey = 'valid-api-key-123' + + beforeEach(() => { + mockConfig.get.mockResolvedValue({ apiKey: validApiKey }) + }) + + it('should handle similar length but different API keys securely', async () => { + const similarKey = 'valid-api-key-124' // Same length, different last char + + ;(req.header as any).mockImplementation((header: string) => { + if (header === 'x-api-key') return similarKey + return '' + }) + + await authMiddleware(req as Request, res as Response, next) + + expect(statusMock).toHaveBeenCalledWith(403) + expect(jsonMock).toHaveBeenCalledWith({ error: 'Forbidden' }) + expect(next).not.toHaveBeenCalled() + }) + + it('should handle similar length but different Bearer tokens securely', async () => { + const similarKey = 'valid-api-key-124' // Same length, different last char + + ;(req.header as any).mockImplementation((header: string) => { + if (header === 'authorization') return `Bearer ${similarKey}` + return '' + }) + + await authMiddleware(req as Request, res as Response, next) + + expect(statusMock).toHaveBeenCalledWith(403) + expect(jsonMock).toHaveBeenCalledWith({ error: 'Forbidden' }) + expect(next).not.toHaveBeenCalled() + }) + }) +}) diff --git a/src/main/apiServer/middleware/auth.ts b/src/main/apiServer/middleware/auth.ts index 2c2838756e..bf44e4eb37 100644 --- a/src/main/apiServer/middleware/auth.ts +++ b/src/main/apiServer/middleware/auth.ts @@ -3,8 +3,17 @@ import type { NextFunction, Request, Response } from 'express' import { config } from '../config' +const isValidToken = (token: string, apiKey: string): boolean => { + if (token.length !== apiKey.length) { + return false + } + const tokenBuf = Buffer.from(token) + const keyBuf = Buffer.from(apiKey) + return crypto.timingSafeEqual(tokenBuf, keyBuf) +} + export const authMiddleware = async (req: Request, res: Response, next: NextFunction) => { - const auth = req.header('Authorization') || '' + const auth = req.header('authorization') || '' const xApiKey = req.header('x-api-key') || '' // Fast rejection if neither credential header provided @@ -12,51 +21,46 @@ export const authMiddleware = async (req: Request, res: Response, next: NextFunc return res.status(401).json({ error: 'Unauthorized: missing credentials' }) } - let token: string | undefined - - // Prefer Bearer if well‑formed - if (auth) { - const trimmed = auth.trim() - const bearerPrefix = /^Bearer\s+/i - if (bearerPrefix.test(trimmed)) { - const candidate = trimmed.replace(bearerPrefix, '').trim() - if (!candidate) { - return res.status(401).json({ error: 'Unauthorized: empty bearer token' }) - } - token = candidate - } - } - - // Fallback to x-api-key if token still not resolved - if (!token && xApiKey) { - if (!xApiKey.trim()) { - return res.status(401).json({ error: 'Unauthorized: empty x-api-key' }) - } - token = xApiKey.trim() - } - - if (!token) { - // At this point we had at least one header, but none yielded a usable token - return res.status(401).json({ error: 'Unauthorized: invalid credentials format' }) - } - const { apiKey } = await config.get() if (!apiKey) { - // If server not configured, treat as forbidden (or could be 500). Choose 403 to avoid leaking config state. return res.status(403).json({ error: 'Forbidden' }) } - // Timing-safe compare when lengths match, else immediate forbidden - if (token.length !== apiKey.length) { - return res.status(403).json({ error: 'Forbidden' }) + // Check API key first (priority) + if (xApiKey) { + const trimmedApiKey = xApiKey.trim() + if (!trimmedApiKey) { + return res.status(401).json({ error: 'Unauthorized: empty x-api-key' }) + } + + if (isValidToken(trimmedApiKey, apiKey)) { + return next() + } else { + return res.status(403).json({ error: 'Forbidden' }) + } } - const tokenBuf = Buffer.from(token) - const keyBuf = Buffer.from(apiKey) - if (!crypto.timingSafeEqual(tokenBuf, keyBuf)) { - return res.status(403).json({ error: 'Forbidden' }) + // Fallback to Bearer token + if (auth) { + const trimmed = auth.trim() + const bearerPrefix = /^Bearer\s+/i + + if (!bearerPrefix.test(trimmed)) { + return res.status(401).json({ error: 'Unauthorized: invalid authorization format' }) + } + + const token = trimmed.replace(bearerPrefix, '').trim() + if (!token) { + return res.status(401).json({ error: 'Unauthorized: empty bearer token' }) + } + + if (isValidToken(token, apiKey)) { + return next() + } else { + return res.status(403).json({ error: 'Forbidden' }) + } } - return next() + return res.status(401).json({ error: 'Unauthorized: invalid credentials format' }) } diff --git a/src/main/apiServer/middleware/error.ts b/src/main/apiServer/middleware/error.ts index 56a2ec54f3..03c2d5617e 100644 --- a/src/main/apiServer/middleware/error.ts +++ b/src/main/apiServer/middleware/error.ts @@ -6,7 +6,7 @@ const logger = loggerService.withContext('ApiServerErrorHandler') // oxlint-disable-next-line @typescript-eslint/no-unused-vars export const errorHandler = (err: Error, _req: Request, res: Response, _next: NextFunction) => { - logger.error('API Server Error:', err) + logger.error('API server error', { error: err }) // Don't expose internal errors in production const isDev = process.env.NODE_ENV === 'development' diff --git a/src/main/apiServer/middleware/openapi.ts b/src/main/apiServer/middleware/openapi.ts index 596d3089d7..c136fecdde 100644 --- a/src/main/apiServer/middleware/openapi.ts +++ b/src/main/apiServer/middleware/openapi.ts @@ -197,10 +197,11 @@ export function setupOpenAPIDocumentation(app: Express) { }) ) - logger.info('OpenAPI documentation setup complete') - logger.info('Documentation available at /api-docs') - logger.info('OpenAPI spec available at /api-docs.json') + logger.info('OpenAPI documentation ready', { + docsPath: '/api-docs', + specPath: '/api-docs.json' + }) } catch (error) { - logger.error('Failed to setup OpenAPI documentation:', error as Error) + logger.error('Failed to setup OpenAPI documentation', { error }) } } diff --git a/src/main/apiServer/routes/agents/handlers/agents.ts b/src/main/apiServer/routes/agents/handlers/agents.ts new file mode 100644 index 0000000000..1e772932e2 --- /dev/null +++ b/src/main/apiServer/routes/agents/handlers/agents.ts @@ -0,0 +1,567 @@ +import { loggerService } from '@logger' +import { AgentModelValidationError, agentService, sessionService } from '@main/services/agents' +import { ListAgentsResponse, type ReplaceAgentRequest, type UpdateAgentRequest } from '@types' +import { Request, Response } from 'express' + +import type { ValidationRequest } from '../validators/zodValidator' + +const logger = loggerService.withContext('ApiServerAgentsHandlers') + +const modelValidationErrorBody = (error: AgentModelValidationError) => ({ + error: { + message: `Invalid ${error.context.field}: ${error.detail.message}`, + type: 'invalid_request_error', + code: error.detail.code + } +}) + +/** + * @swagger + * /v1/agents: + * post: + * summary: Create a new agent + * description: Creates a new autonomous agent with the specified configuration and automatically + * provisions an initial session that mirrors the agent's settings. + * tags: [Agents] + * requestBody: + * required: true + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/CreateAgentRequest' + * responses: + * 201: + * description: Agent created successfully + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/AgentEntity' + * 400: + * description: Validation error + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/Error' + * 500: + * description: Internal server error + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/Error' + */ +export const createAgent = async (req: Request, res: Response): Promise => { + try { + logger.debug('Creating agent') + logger.debug('Agent payload', { body: req.body }) + + const agent = await agentService.createAgent(req.body) + + try { + logger.info('Agent created', { agentId: agent.id }) + logger.debug('Creating default session for agent', { agentId: agent.id }) + + await sessionService.createSession(agent.id, {}) + + logger.info('Default session created for agent', { agentId: agent.id }) + return res.status(201).json(agent) + } catch (sessionError: any) { + logger.error('Failed to create default session for new agent, rolling back agent creation', { + agentId: agent.id, + error: sessionError + }) + + try { + await agentService.deleteAgent(agent.id) + } catch (rollbackError: any) { + logger.error('Failed to roll back agent after session creation failure', { + agentId: agent.id, + error: rollbackError + }) + } + + return res.status(500).json({ + error: { + message: `Failed to create default session for agent: ${sessionError.message}`, + type: 'internal_error', + code: 'agent_session_creation_failed' + } + }) + } + } catch (error: any) { + if (error instanceof AgentModelValidationError) { + logger.warn('Agent model validation error during create', { + agentType: error.context.agentType, + field: error.context.field, + model: error.context.model, + detail: error.detail + }) + return res.status(400).json(modelValidationErrorBody(error)) + } + + logger.error('Error creating agent', { error }) + return res.status(500).json({ + error: { + message: `Failed to create agent: ${error.message}`, + type: 'internal_error', + code: 'agent_creation_failed' + } + }) + } +} + +/** + * @swagger + * /v1/agents: + * get: + * summary: List all agents + * description: Retrieves a paginated list of all agents + * tags: [Agents] + * parameters: + * - in: query + * name: limit + * schema: + * type: integer + * minimum: 1 + * maximum: 100 + * default: 20 + * description: Number of agents to return + * - in: query + * name: offset + * schema: + * type: integer + * minimum: 0 + * default: 0 + * description: Number of agents to skip + * responses: + * 200: + * description: List of agents + * content: + * application/json: + * schema: + * type: object + * properties: + * data: + * type: array + * items: + * $ref: '#/components/schemas/AgentEntity' + * total: + * type: integer + * description: Total number of agents + * limit: + * type: integer + * description: Number of agents returned + * offset: + * type: integer + * description: Number of agents skipped + * 400: + * description: Validation error + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/Error' + * 500: + * description: Internal server error + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/Error' + */ +export const listAgents = async (req: Request, res: Response): Promise => { + try { + const limit = req.query.limit ? parseInt(req.query.limit as string) : 20 + const offset = req.query.offset ? parseInt(req.query.offset as string) : 0 + + logger.debug('Listing agents', { limit, offset }) + + const result = await agentService.listAgents({ limit, offset }) + + logger.info('Agents listed', { + returned: result.agents.length, + total: result.total, + limit, + offset + }) + return res.json({ + data: result.agents, + total: result.total, + limit, + offset + } satisfies ListAgentsResponse) + } catch (error: any) { + logger.error('Error listing agents', { error }) + return res.status(500).json({ + error: { + message: 'Failed to list agents', + type: 'internal_error', + code: 'agent_list_failed' + } + }) + } +} + +/** + * @swagger + * /v1/agents/{agentId}: + * get: + * summary: Get agent by ID + * description: Retrieves a specific agent by its ID + * tags: [Agents] + * parameters: + * - in: path + * name: agentId + * required: true + * schema: + * type: string + * description: Agent ID + * responses: + * 200: + * description: Agent details + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/AgentEntity' + * 404: + * description: Agent not found + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/Error' + * 500: + * description: Internal server error + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/Error' + */ +export const getAgent = async (req: Request, res: Response): Promise => { + try { + const { agentId } = req.params + logger.debug('Getting agent', { agentId }) + + const agent = await agentService.getAgent(agentId) + + if (!agent) { + logger.warn('Agent not found', { agentId }) + return res.status(404).json({ + error: { + message: 'Agent not found', + type: 'not_found', + code: 'agent_not_found' + } + }) + } + + logger.info('Agent retrieved', { agentId }) + return res.json(agent) + } catch (error: any) { + logger.error('Error getting agent', { error, agentId: req.params.agentId }) + return res.status(500).json({ + error: { + message: 'Failed to get agent', + type: 'internal_error', + code: 'agent_get_failed' + } + }) + } +} + +/** + * @swagger + * /v1/agents/{agentId}: + * put: + * summary: Update agent + * description: Updates an existing agent with the provided data + * tags: [Agents] + * parameters: + * - in: path + * name: agentId + * required: true + * schema: + * type: string + * description: Agent ID + * requestBody: + * required: true + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/CreateAgentRequest' + * responses: + * 200: + * description: Agent updated successfully + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/AgentEntity' + * 400: + * description: Validation error + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/Error' + * 404: + * description: Agent not found + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/Error' + * 500: + * description: Internal server error + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/Error' + */ +export const updateAgent = async (req: Request, res: Response): Promise => { + const { agentId } = req.params + try { + logger.debug('Updating agent', { agentId }) + logger.debug('Replace payload', { body: req.body }) + + const { validatedBody } = req as ValidationRequest + const replacePayload = (validatedBody ?? {}) as ReplaceAgentRequest + + const agent = await agentService.updateAgent(agentId, replacePayload, { replace: true }) + + if (!agent) { + logger.warn('Agent not found for update', { agentId }) + return res.status(404).json({ + error: { + message: 'Agent not found', + type: 'not_found', + code: 'agent_not_found' + } + }) + } + + logger.info('Agent updated', { agentId }) + return res.json(agent) + } catch (error: any) { + if (error instanceof AgentModelValidationError) { + logger.warn('Agent model validation error during update', { + agentId, + agentType: error.context.agentType, + field: error.context.field, + model: error.context.model, + detail: error.detail + }) + return res.status(400).json(modelValidationErrorBody(error)) + } + + logger.error('Error updating agent', { error, agentId }) + return res.status(500).json({ + error: { + message: 'Failed to update agent: ' + error.message, + type: 'internal_error', + code: 'agent_update_failed' + } + }) + } +} + +/** + * @swagger + * /v1/agents/{agentId}: + * patch: + * summary: Partially update agent + * description: Partially updates an existing agent with only the provided fields + * tags: [Agents] + * parameters: + * - in: path + * name: agentId + * required: true + * schema: + * type: string + * description: Agent ID + * requestBody: + * required: true + * content: + * application/json: + * schema: + * type: object + * properties: + * name: + * type: string + * description: Agent name + * description: + * type: string + * description: Agent description + * avatar: + * type: string + * description: Agent avatar URL + * instructions: + * type: string + * description: System prompt/instructions + * model: + * type: string + * description: Main model ID + * plan_model: + * type: string + * description: Optional planning model ID + * small_model: + * type: string + * description: Optional small/fast model ID + * tools: + * type: array + * items: + * type: string + * description: Tools + * mcps: + * type: array + * items: + * type: string + * description: MCP tool IDs + * knowledges: + * type: array + * items: + * type: string + * description: Knowledge base IDs + * configuration: + * type: object + * description: Extensible settings + * accessible_paths: + * type: array + * items: + * type: string + * description: Accessible directory paths + * permission_mode: + * type: string + * enum: [readOnly, acceptEdits, bypassPermissions] + * description: Permission mode + * max_steps: + * type: integer + * description: Maximum steps the agent can take + * description: Only include the fields you want to update + * responses: + * 200: + * description: Agent updated successfully + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/AgentEntity' + * 400: + * description: Validation error + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/Error' + * 404: + * description: Agent not found + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/Error' + * 500: + * description: Internal server error + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/Error' + */ +export const patchAgent = async (req: Request, res: Response): Promise => { + const { agentId } = req.params + try { + logger.debug('Partially updating agent', { agentId }) + logger.debug('Patch payload', { body: req.body }) + + const { validatedBody } = req as ValidationRequest + const updatePayload = (validatedBody ?? {}) as UpdateAgentRequest + + const agent = await agentService.updateAgent(agentId, updatePayload) + + if (!agent) { + logger.warn('Agent not found for partial update', { agentId }) + return res.status(404).json({ + error: { + message: 'Agent not found', + type: 'not_found', + code: 'agent_not_found' + } + }) + } + + logger.info('Agent patched', { agentId }) + return res.json(agent) + } catch (error: any) { + if (error instanceof AgentModelValidationError) { + logger.warn('Agent model validation error during partial update', { + agentId, + agentType: error.context.agentType, + field: error.context.field, + model: error.context.model, + detail: error.detail + }) + return res.status(400).json(modelValidationErrorBody(error)) + } + + logger.error('Error partially updating agent', { error, agentId }) + return res.status(500).json({ + error: { + message: `Failed to partially update agent: ${error.message}`, + type: 'internal_error', + code: 'agent_patch_failed' + } + }) + } +} + +/** + * @swagger + * /v1/agents/{agentId}: + * delete: + * summary: Delete agent + * description: Deletes an agent and all associated sessions and logs + * tags: [Agents] + * parameters: + * - in: path + * name: agentId + * required: true + * schema: + * type: string + * description: Agent ID + * responses: + * 204: + * description: Agent deleted successfully + * 404: + * description: Agent not found + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/Error' + * 500: + * description: Internal server error + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/Error' + */ +export const deleteAgent = async (req: Request, res: Response): Promise => { + try { + const { agentId } = req.params + logger.debug('Deleting agent', { agentId }) + + const deleted = await agentService.deleteAgent(agentId) + + if (!deleted) { + logger.warn('Agent not found for deletion', { agentId }) + return res.status(404).json({ + error: { + message: 'Agent not found', + type: 'not_found', + code: 'agent_not_found' + } + }) + } + + logger.info('Agent deleted', { agentId }) + return res.status(204).send() + } catch (error: any) { + logger.error('Error deleting agent', { error, agentId: req.params.agentId }) + return res.status(500).json({ + error: { + message: 'Failed to delete agent', + type: 'internal_error', + code: 'agent_delete_failed' + } + }) + } +} diff --git a/src/main/apiServer/routes/agents/handlers/index.ts b/src/main/apiServer/routes/agents/handlers/index.ts new file mode 100644 index 0000000000..0bd3e1d73a --- /dev/null +++ b/src/main/apiServer/routes/agents/handlers/index.ts @@ -0,0 +1,3 @@ +export * as agentHandlers from './agents' +export * as messageHandlers from './messages' +export * as sessionHandlers from './sessions' diff --git a/src/main/apiServer/routes/agents/handlers/messages.ts b/src/main/apiServer/routes/agents/handlers/messages.ts new file mode 100644 index 0000000000..e18fadc0e0 --- /dev/null +++ b/src/main/apiServer/routes/agents/handlers/messages.ts @@ -0,0 +1,317 @@ +import { loggerService } from '@logger' +import { MESSAGE_STREAM_TIMEOUT_MS } from '@main/apiServer/config/timeouts' +import { createStreamAbortController, STREAM_TIMEOUT_REASON } from '@main/apiServer/utils/createStreamAbortController' +import { agentService, sessionMessageService, sessionService } from '@main/services/agents' +import { Request, Response } from 'express' + +const logger = loggerService.withContext('ApiServerMessagesHandlers') + +// Helper function to verify agent and session exist and belong together +const verifyAgentAndSession = async (agentId: string, sessionId: string) => { + const agentExists = await agentService.agentExists(agentId) + if (!agentExists) { + throw { status: 404, code: 'agent_not_found', message: 'Agent not found' } + } + + const session = await sessionService.getSession(agentId, sessionId) + if (!session) { + throw { status: 404, code: 'session_not_found', message: 'Session not found' } + } + + if (session.agent_id !== agentId) { + throw { status: 404, code: 'session_not_found', message: 'Session not found for this agent' } + } + + return session +} + +export const createMessage = async (req: Request, res: Response): Promise => { + let clearAbortTimeout: (() => void) | undefined + + try { + const { agentId, sessionId } = req.params + + const session = await verifyAgentAndSession(agentId, sessionId) + + const messageData = req.body + + logger.info('Creating streaming message', { agentId, sessionId }) + logger.debug('Streaming message payload', { messageData }) + + // Set SSE headers + res.setHeader('Content-Type', 'text/event-stream') + res.setHeader('Cache-Control', 'no-cache') + res.setHeader('Connection', 'keep-alive') + res.setHeader('Access-Control-Allow-Origin', '*') + res.setHeader('Access-Control-Allow-Headers', 'Cache-Control') + + const { + abortController, + registerAbortHandler, + clearAbortTimeout: helperClearAbortTimeout + } = createStreamAbortController({ + timeoutMs: MESSAGE_STREAM_TIMEOUT_MS + }) + clearAbortTimeout = helperClearAbortTimeout + const { stream, completion } = await sessionMessageService.createSessionMessage( + session, + messageData, + abortController + ) + const reader = stream.getReader() + + // Track stream lifecycle so we keep the SSE connection open until persistence finishes + let responseEnded = false + let streamFinished = false + + const cleanupAbortTimeout = () => { + clearAbortTimeout?.() + } + + const finalizeResponse = () => { + if (responseEnded) { + return + } + + if (!streamFinished) { + return + } + + responseEnded = true + cleanupAbortTimeout() + try { + // res.write('data: {"type":"finish"}\n\n') + res.write('data: [DONE]\n\n') + } catch (writeError) { + logger.error('Error writing final sentinel to SSE stream', { error: writeError as Error }) + } + res.end() + } + + /** + * Client Disconnect Detection for Server-Sent Events (SSE) + * + * We monitor multiple HTTP events to reliably detect when a client disconnects + * from the streaming response. This is crucial for: + * - Aborting long-running Claude Code processes + * - Cleaning up resources and preventing memory leaks + * - Avoiding orphaned processes + * + * Event Priority & Behavior: + * 1. res.on('close') - Most common for SSE client disconnects (browser tab close, curl Ctrl+C) + * 2. req.on('aborted') - Explicit request abortion + * 3. req.on('close') - Request object closure (less common with SSE) + * + * When any disconnect event fires, we: + * - Abort the Claude Code SDK process via abortController + * - Clean up event listeners to prevent memory leaks + * - Mark the response as ended to prevent further writes + */ + registerAbortHandler((abortReason) => { + cleanupAbortTimeout() + + if (responseEnded) return + + responseEnded = true + + if (abortReason === STREAM_TIMEOUT_REASON) { + logger.error('Streaming message timeout', { agentId, sessionId }) + try { + res.write( + `data: ${JSON.stringify({ + type: 'error', + error: { + message: 'Stream timeout', + type: 'timeout_error', + code: 'stream_timeout' + } + })}\n\n` + ) + } catch (writeError) { + logger.error('Error writing timeout to SSE stream', { error: writeError }) + } + } else if (abortReason === 'Client disconnected') { + logger.info('Streaming client disconnected', { agentId, sessionId }) + } else { + logger.warn('Streaming aborted', { agentId, sessionId, reason: abortReason }) + } + + reader.cancel(abortReason ?? 'stream aborted').catch(() => {}) + + if (!res.headersSent) { + res.setHeader('Content-Type', 'text/event-stream') + res.setHeader('Cache-Control', 'no-cache') + res.setHeader('Connection', 'keep-alive') + } + + if (!res.writableEnded) { + res.end() + } + }) + + const handleDisconnect = () => { + if (abortController.signal.aborted) return + abortController.abort('Client disconnected') + } + + req.on('close', handleDisconnect) + req.on('aborted', handleDisconnect) + res.on('close', handleDisconnect) + + const pumpStream = async () => { + try { + while (!responseEnded) { + const { done, value } = await reader.read() + if (done) { + break + } + + res.write(`data: ${JSON.stringify(value)}\n\n`) + } + + streamFinished = true + finalizeResponse() + } catch (error) { + if (responseEnded) return + logger.error('Error reading agent stream', { error }) + try { + res.write( + `data: ${JSON.stringify({ + type: 'error', + error: { + message: (error as Error).message || 'Stream processing error', + type: 'stream_error', + code: 'stream_processing_failed' + } + })}\n\n` + ) + } catch (writeError) { + logger.error('Error writing stream error to SSE', { error: writeError }) + } + responseEnded = true + cleanupAbortTimeout() + res.end() + } + } + + pumpStream().catch((error) => { + logger.error('Pump stream failure', { error }) + }) + + completion + .then(() => { + streamFinished = true + finalizeResponse() + }) + .catch((error) => { + if (responseEnded) return + logger.error('Streaming message error', { agentId, sessionId, error }) + try { + res.write( + `data: ${JSON.stringify({ + type: 'error', + error: { + message: (error as { message?: string })?.message || 'Stream processing error', + type: 'stream_error', + code: 'stream_processing_failed' + } + })}\n\n` + ) + } catch (writeError) { + logger.error('Error writing completion error to SSE stream', { error: writeError }) + } + responseEnded = true + cleanupAbortTimeout() + res.end() + }) + // Clear timeout when response ends + res.on('close', cleanupAbortTimeout) + res.on('finish', cleanupAbortTimeout) + } catch (error: any) { + clearAbortTimeout?.() + logger.error('Error in streaming message handler', { + error, + agentId: req.params.agentId, + sessionId: req.params.sessionId + }) + + // Send error as SSE if possible + if (!res.headersSent) { + res.setHeader('Content-Type', 'text/event-stream') + res.setHeader('Cache-Control', 'no-cache') + res.setHeader('Connection', 'keep-alive') + } + + try { + const errorResponse = { + type: 'error', + error: { + message: error.status ? error.message : 'Failed to create streaming message', + type: error.status ? 'not_found' : 'internal_error', + code: error.status ? error.code : 'stream_creation_failed' + } + } + + res.write(`data: ${JSON.stringify(errorResponse)}\n\n`) + } catch (writeError) { + logger.error('Error writing initial error to SSE stream', { error: writeError }) + } + + res.end() + } +} + +export const deleteMessage = async (req: Request, res: Response): Promise => { + try { + const { agentId, sessionId, messageId: messageIdParam } = req.params + const messageId = Number(messageIdParam) + + await verifyAgentAndSession(agentId, sessionId) + + const deleted = await sessionMessageService.deleteSessionMessage(sessionId, messageId) + + if (!deleted) { + logger.warn('Session message not found', { agentId, sessionId, messageId }) + return res.status(404).json({ + error: { + message: 'Message not found for this session', + type: 'not_found', + code: 'session_message_not_found' + } + }) + } + + logger.info('Session message deleted', { agentId, sessionId, messageId }) + return res.status(204).send() + } catch (error: any) { + if (error?.status === 404) { + logger.warn('Delete message failed - missing resource', { + agentId: req.params.agentId, + sessionId: req.params.sessionId, + messageId: req.params.messageId, + error + }) + return res.status(404).json({ + error: { + message: error.message, + type: 'not_found', + code: error.code ?? 'session_message_not_found' + } + }) + } + + logger.error('Error deleting session message', { + error, + agentId: req.params.agentId, + sessionId: req.params.sessionId, + messageId: Number(req.params.messageId) + }) + return res.status(500).json({ + error: { + message: 'Failed to delete session message', + type: 'internal_error', + code: 'session_message_delete_failed' + } + }) + } +} diff --git a/src/main/apiServer/routes/agents/handlers/sessions.ts b/src/main/apiServer/routes/agents/handlers/sessions.ts new file mode 100644 index 0000000000..72875dab8a --- /dev/null +++ b/src/main/apiServer/routes/agents/handlers/sessions.ts @@ -0,0 +1,366 @@ +import { loggerService } from '@logger' +import { AgentModelValidationError, sessionMessageService, sessionService } from '@main/services/agents' +import { ListAgentSessionsResponse, type ReplaceSessionRequest, UpdateSessionResponse } from '@types' +import { Request, Response } from 'express' + +import type { ValidationRequest } from '../validators/zodValidator' + +const logger = loggerService.withContext('ApiServerSessionsHandlers') + +const modelValidationErrorBody = (error: AgentModelValidationError) => ({ + error: { + message: `Invalid ${error.context.field}: ${error.detail.message}`, + type: 'invalid_request_error', + code: error.detail.code + } +}) + +export const createSession = async (req: Request, res: Response): Promise => { + const { agentId } = req.params + try { + const sessionData = req.body + + logger.debug('Creating new session', { agentId }) + logger.debug('Session payload', { sessionData }) + + const session = await sessionService.createSession(agentId, sessionData) + + logger.info('Session created', { agentId, sessionId: session?.id }) + return res.status(201).json(session) + } catch (error: any) { + if (error instanceof AgentModelValidationError) { + logger.warn('Session model validation error during create', { + agentId, + agentType: error.context.agentType, + field: error.context.field, + model: error.context.model, + detail: error.detail + }) + return res.status(400).json(modelValidationErrorBody(error)) + } + + logger.error('Error creating session', { error, agentId }) + return res.status(500).json({ + error: { + message: `Failed to create session: ${error.message}`, + type: 'internal_error', + code: 'session_creation_failed' + } + }) + } +} + +export const listSessions = async (req: Request, res: Response): Promise => { + const { agentId } = req.params + try { + const limit = req.query.limit ? parseInt(req.query.limit as string) : 20 + const offset = req.query.offset ? parseInt(req.query.offset as string) : 0 + const status = req.query.status as any + + logger.debug('Listing agent sessions', { agentId, limit, offset, status }) + + const result = await sessionService.listSessions(agentId, { limit, offset }) + + logger.info('Agent sessions listed', { + agentId, + returned: result.sessions.length, + total: result.total, + limit, + offset + }) + return res.json({ + data: result.sessions, + total: result.total, + limit, + offset + }) + } catch (error: any) { + logger.error('Error listing sessions', { error, agentId }) + return res.status(500).json({ + error: { + message: 'Failed to list sessions', + type: 'internal_error', + code: 'session_list_failed' + } + }) + } +} + +export const getSession = async (req: Request, res: Response): Promise => { + try { + const { agentId, sessionId } = req.params + logger.debug('Getting session', { agentId, sessionId }) + + const session = await sessionService.getSession(agentId, sessionId) + + if (!session) { + logger.warn('Session not found', { agentId, sessionId }) + return res.status(404).json({ + error: { + message: 'Session not found', + type: 'not_found', + code: 'session_not_found' + } + }) + } + + // // Verify session belongs to the agent + // logger.warn(`Session ${sessionId} does not belong to agent ${agentId}`) + // return res.status(404).json({ + // error: { + // message: 'Session not found for this agent', + // type: 'not_found', + // code: 'session_not_found' + // } + // }) + // } + + // Fetch session messages + logger.debug('Fetching session messages', { sessionId }) + const { messages } = await sessionMessageService.listSessionMessages(sessionId) + + // Add messages to session + const sessionWithMessages = { + ...session, + messages: messages + } + + logger.info('Session retrieved', { agentId, sessionId, messageCount: messages.length }) + return res.json(sessionWithMessages) + } catch (error: any) { + logger.error('Error getting session', { error, agentId: req.params.agentId, sessionId: req.params.sessionId }) + return res.status(500).json({ + error: { + message: 'Failed to get session', + type: 'internal_error', + code: 'session_get_failed' + } + }) + } +} + +export const updateSession = async (req: Request, res: Response): Promise => { + const { agentId, sessionId } = req.params + try { + logger.debug('Updating session', { agentId, sessionId }) + logger.debug('Replace payload', { body: req.body }) + + // First check if session exists and belongs to agent + const existingSession = await sessionService.getSession(agentId, sessionId) + if (!existingSession || existingSession.agent_id !== agentId) { + logger.warn('Session not found for update', { agentId, sessionId }) + return res.status(404).json({ + error: { + message: 'Session not found for this agent', + type: 'not_found', + code: 'session_not_found' + } + }) + } + + const { validatedBody } = req as ValidationRequest + const replacePayload = (validatedBody ?? {}) as ReplaceSessionRequest + + const session = await sessionService.updateSession(agentId, sessionId, replacePayload) + + if (!session) { + logger.warn('Session missing during update', { agentId, sessionId }) + return res.status(404).json({ + error: { + message: 'Session not found', + type: 'not_found', + code: 'session_not_found' + } + }) + } + + logger.info('Session updated', { agentId, sessionId }) + return res.json(session satisfies UpdateSessionResponse) + } catch (error: any) { + if (error instanceof AgentModelValidationError) { + logger.warn('Session model validation error during update', { + agentId, + sessionId, + agentType: error.context.agentType, + field: error.context.field, + model: error.context.model, + detail: error.detail + }) + return res.status(400).json(modelValidationErrorBody(error)) + } + + logger.error('Error updating session', { error, agentId, sessionId }) + return res.status(500).json({ + error: { + message: `Failed to update session: ${error.message}`, + type: 'internal_error', + code: 'session_update_failed' + } + }) + } +} + +export const patchSession = async (req: Request, res: Response): Promise => { + const { agentId, sessionId } = req.params + try { + logger.debug('Patching session', { agentId, sessionId }) + logger.debug('Patch payload', { body: req.body }) + + // First check if session exists and belongs to agent + const existingSession = await sessionService.getSession(agentId, sessionId) + if (!existingSession || existingSession.agent_id !== agentId) { + logger.warn('Session not found for patch', { agentId, sessionId }) + return res.status(404).json({ + error: { + message: 'Session not found for this agent', + type: 'not_found', + code: 'session_not_found' + } + }) + } + + const updateSession = { ...existingSession, ...req.body } + const session = await sessionService.updateSession(agentId, sessionId, updateSession) + + if (!session) { + logger.warn('Session missing while patching', { agentId, sessionId }) + return res.status(404).json({ + error: { + message: 'Session not found', + type: 'not_found', + code: 'session_not_found' + } + }) + } + + logger.info('Session patched', { agentId, sessionId }) + return res.json(session) + } catch (error: any) { + if (error instanceof AgentModelValidationError) { + logger.warn('Session model validation error during patch', { + agentId, + sessionId, + agentType: error.context.agentType, + field: error.context.field, + model: error.context.model, + detail: error.detail + }) + return res.status(400).json(modelValidationErrorBody(error)) + } + + logger.error('Error patching session', { error, agentId, sessionId }) + return res.status(500).json({ + error: { + message: `Failed to patch session, ${error.message}`, + type: 'internal_error', + code: 'session_patch_failed' + } + }) + } +} + +export const deleteSession = async (req: Request, res: Response): Promise => { + try { + const { agentId, sessionId } = req.params + logger.debug('Deleting session', { agentId, sessionId }) + + // First check if session exists and belongs to agent + const existingSession = await sessionService.getSession(agentId, sessionId) + if (!existingSession || existingSession.agent_id !== agentId) { + logger.warn('Session not found for deletion', { agentId, sessionId }) + return res.status(404).json({ + error: { + message: 'Session not found for this agent', + type: 'not_found', + code: 'session_not_found' + } + }) + } + + const deleted = await sessionService.deleteSession(agentId, sessionId) + + if (!deleted) { + logger.warn('Session missing during delete', { agentId, sessionId }) + return res.status(404).json({ + error: { + message: 'Session not found', + type: 'not_found', + code: 'session_not_found' + } + }) + } + + logger.info('Session deleted', { agentId, sessionId }) + + const { total } = await sessionService.listSessions(agentId, { limit: 1 }) + + if (total === 0) { + logger.info('No remaining sessions, creating default', { agentId }) + try { + const fallbackSession = await sessionService.createSession(agentId, {}) + logger.info('Default session created after delete', { + agentId, + sessionId: fallbackSession?.id + }) + } catch (recoveryError: any) { + logger.error('Failed to recreate session after deleting last session', { + agentId, + error: recoveryError + }) + return res.status(500).json({ + error: { + message: `Failed to recreate session after deletion: ${recoveryError.message}`, + type: 'internal_error', + code: 'session_recovery_failed' + } + }) + } + } + + return res.status(204).send() + } catch (error: any) { + logger.error('Error deleting session', { error, agentId: req.params.agentId, sessionId: req.params.sessionId }) + return res.status(500).json({ + error: { + message: 'Failed to delete session', + type: 'internal_error', + code: 'session_delete_failed' + } + }) + } +} + +// Convenience endpoints for sessions without agent context +export const listAllSessions = async (req: Request, res: Response): Promise => { + try { + const limit = req.query.limit ? parseInt(req.query.limit as string) : 20 + const offset = req.query.offset ? parseInt(req.query.offset as string) : 0 + const status = req.query.status as any + + logger.debug('Listing all sessions', { limit, offset, status }) + + const result = await sessionService.listSessions(undefined, { limit, offset }) + + logger.info('Sessions listed', { + returned: result.sessions.length, + total: result.total, + limit, + offset + }) + return res.json({ + data: result.sessions, + total: result.total, + limit, + offset + } satisfies ListAgentSessionsResponse) + } catch (error: any) { + logger.error('Error listing all sessions', { error }) + return res.status(500).json({ + error: { + message: 'Failed to list sessions', + type: 'internal_error', + code: 'session_list_failed' + } + }) + } +} diff --git a/src/main/apiServer/routes/agents/index.ts b/src/main/apiServer/routes/agents/index.ts new file mode 100644 index 0000000000..42843b7201 --- /dev/null +++ b/src/main/apiServer/routes/agents/index.ts @@ -0,0 +1,965 @@ +import express from 'express' + +import { agentHandlers, messageHandlers, sessionHandlers } from './handlers' +import { checkAgentExists, handleValidationErrors } from './middleware' +import { + validateAgent, + validateAgentId, + validateAgentReplace, + validateAgentUpdate, + validatePagination, + validateSession, + validateSessionId, + validateSessionMessage, + validateSessionMessageId, + validateSessionReplace, + validateSessionUpdate +} from './validators' + +// Create main agents router +const agentsRouter = express.Router() + +/** + * @swagger + * components: + * schemas: + * PermissionMode: + * type: string + * enum: [default, acceptEdits, bypassPermissions, plan] + * description: Permission mode for agent operations + * + * AgentType: + * type: string + * enum: [claude-code] + * description: Type of agent + * + * AgentConfiguration: + * type: object + * properties: + * permission_mode: + * $ref: '#/components/schemas/PermissionMode' + * default: default + * max_turns: + * type: integer + * default: 10 + * description: Maximum number of interaction turns + * additionalProperties: true + * + * AgentBase: + * type: object + * properties: + * name: + * type: string + * description: Agent name + * description: + * type: string + * description: Agent description + * accessible_paths: + * type: array + * items: + * type: string + * description: Array of directory paths the agent can access + * instructions: + * type: string + * description: System prompt/instructions + * model: + * type: string + * description: Main model ID + * plan_model: + * type: string + * description: Optional planning model ID + * small_model: + * type: string + * description: Optional small/fast model ID + * mcps: + * type: array + * items: + * type: string + * description: Array of MCP tool IDs + * allowed_tools: + * type: array + * items: + * type: string + * description: Array of allowed tool IDs (whitelist) + * configuration: + * $ref: '#/components/schemas/AgentConfiguration' + * required: + * - model + * - accessible_paths + * + * AgentEntity: + * allOf: + * - $ref: '#/components/schemas/AgentBase' + * - type: object + * properties: + * id: + * type: string + * description: Unique agent identifier + * type: + * $ref: '#/components/schemas/AgentType' + * created_at: + * type: string + * format: date-time + * description: ISO timestamp of creation + * updated_at: + * type: string + * format: date-time + * description: ISO timestamp of last update + * required: + * - id + * - type + * - created_at + * - updated_at + * CreateAgentRequest: + * allOf: + * - $ref: '#/components/schemas/AgentBase' + * - type: object + * properties: + * type: + * $ref: '#/components/schemas/AgentType' + * name: + * type: string + * minLength: 1 + * description: Agent name (required) + * model: + * type: string + * minLength: 1 + * description: Main model ID (required) + * required: + * - type + * - name + * - model + * + * UpdateAgentRequest: + * type: object + * properties: + * name: + * type: string + * description: Agent name + * description: + * type: string + * description: Agent description + * accessible_paths: + * type: array + * items: + * type: string + * description: Array of directory paths the agent can access + * instructions: + * type: string + * description: System prompt/instructions + * model: + * type: string + * description: Main model ID + * plan_model: + * type: string + * description: Optional planning model ID + * small_model: + * type: string + * description: Optional small/fast model ID + * mcps: + * type: array + * items: + * type: string + * description: Array of MCP tool IDs + * allowed_tools: + * type: array + * items: + * type: string + * description: Array of allowed tool IDs (whitelist) + * configuration: + * $ref: '#/components/schemas/AgentConfiguration' + * description: Partial update - all fields are optional + * + * ReplaceAgentRequest: + * $ref: '#/components/schemas/AgentBase' + * + * SessionEntity: + * allOf: + * - $ref: '#/components/schemas/AgentBase' + * - type: object + * properties: + * id: + * type: string + * description: Unique session identifier + * agent_id: + * type: string + * description: Primary agent ID for the session + * agent_type: + * $ref: '#/components/schemas/AgentType' + * created_at: + * type: string + * format: date-time + * description: ISO timestamp of creation + * updated_at: + * type: string + * format: date-time + * description: ISO timestamp of last update + * required: + * - id + * - agent_id + * - agent_type + * - created_at + * - updated_at + * + * CreateSessionRequest: + * allOf: + * - $ref: '#/components/schemas/AgentBase' + * - type: object + * properties: + * model: + * type: string + * minLength: 1 + * description: Main model ID (required) + * required: + * - model + * + * UpdateSessionRequest: + * type: object + * properties: + * name: + * type: string + * description: Session name + * description: + * type: string + * description: Session description + * accessible_paths: + * type: array + * items: + * type: string + * description: Array of directory paths the agent can access + * instructions: + * type: string + * description: System prompt/instructions + * model: + * type: string + * description: Main model ID + * plan_model: + * type: string + * description: Optional planning model ID + * small_model: + * type: string + * description: Optional small/fast model ID + * mcps: + * type: array + * items: + * type: string + * description: Array of MCP tool IDs + * allowed_tools: + * type: array + * items: + * type: string + * description: Array of allowed tool IDs (whitelist) + * configuration: + * $ref: '#/components/schemas/AgentConfiguration' + * description: Partial update - all fields are optional + * + * ReplaceSessionRequest: + * allOf: + * - $ref: '#/components/schemas/AgentBase' + * - type: object + * properties: + * model: + * type: string + * minLength: 1 + * description: Main model ID (required) + * required: + * - model + * + * CreateSessionMessageRequest: + * type: object + * properties: + * content: + * type: string + * minLength: 1 + * description: Message content + * required: + * - content + * + * PaginationQuery: + * type: object + * properties: + * limit: + * type: integer + * minimum: 1 + * maximum: 100 + * default: 20 + * description: Number of items to return + * offset: + * type: integer + * minimum: 0 + * default: 0 + * description: Number of items to skip + * status: + * type: string + * enum: [idle, running, completed, failed, stopped] + * description: Filter by session status + * + * ListAgentsResponse: + * type: object + * properties: + * agents: + * type: array + * items: + * $ref: '#/components/schemas/AgentEntity' + * total: + * type: integer + * description: Total number of agents + * limit: + * type: integer + * description: Number of items returned + * offset: + * type: integer + * description: Number of items skipped + * required: + * - agents + * - total + * - limit + * - offset + * + * ListSessionsResponse: + * type: object + * properties: + * sessions: + * type: array + * items: + * $ref: '#/components/schemas/SessionEntity' + * total: + * type: integer + * description: Total number of sessions + * limit: + * type: integer + * description: Number of items returned + * offset: + * type: integer + * description: Number of items skipped + * required: + * - sessions + * - total + * - limit + * - offset + * + * ErrorResponse: + * type: object + * properties: + * error: + * type: object + * properties: + * message: + * type: string + * description: Error message + * type: + * type: string + * description: Error type + * code: + * type: string + * description: Error code + * required: + * - message + * - type + * - code + * required: + * - error + */ + +/** + * @swagger + * /agents: + * post: + * summary: Create a new agent + * tags: [Agents] + * requestBody: + * required: true + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/CreateAgentRequest' + * responses: + * 201: + * description: Agent created successfully + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/AgentEntity' + * 400: + * description: Invalid request body + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/ErrorResponse' + */ +// Agent CRUD routes +agentsRouter.post('/', validateAgent, handleValidationErrors, agentHandlers.createAgent) + +/** + * @swagger + * /agents: + * get: + * summary: List all agents with pagination + * tags: [Agents] + * parameters: + * - in: query + * name: limit + * schema: + * type: integer + * minimum: 1 + * maximum: 100 + * default: 20 + * description: Number of agents to return + * - in: query + * name: offset + * schema: + * type: integer + * minimum: 0 + * default: 0 + * description: Number of agents to skip + * - in: query + * name: status + * schema: + * type: string + * enum: [idle, running, completed, failed, stopped] + * description: Filter by agent status + * responses: + * 200: + * description: List of agents + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/ListAgentsResponse' + */ +agentsRouter.get('/', validatePagination, handleValidationErrors, agentHandlers.listAgents) + +/** + * @swagger + * /agents/{agentId}: + * get: + * summary: Get agent by ID + * tags: [Agents] + * parameters: + * - in: path + * name: agentId + * required: true + * schema: + * type: string + * description: Agent ID + * responses: + * 200: + * description: Agent details + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/AgentEntity' + * 404: + * description: Agent not found + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/ErrorResponse' + */ +agentsRouter.get('/:agentId', validateAgentId, handleValidationErrors, agentHandlers.getAgent) +/** + * @swagger + * /agents/{agentId}: + * put: + * summary: Replace agent (full update) + * tags: [Agents] + * parameters: + * - in: path + * name: agentId + * required: true + * schema: + * type: string + * description: Agent ID + * requestBody: + * required: true + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/ReplaceAgentRequest' + * responses: + * 200: + * description: Agent updated successfully + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/AgentEntity' + * 400: + * description: Invalid request body + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/ErrorResponse' + * 404: + * description: Agent not found + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/ErrorResponse' + */ +agentsRouter.put('/:agentId', validateAgentId, validateAgentReplace, handleValidationErrors, agentHandlers.updateAgent) +/** + * @swagger + * /agents/{agentId}: + * patch: + * summary: Update agent (partial update) + * tags: [Agents] + * parameters: + * - in: path + * name: agentId + * required: true + * schema: + * type: string + * description: Agent ID + * requestBody: + * required: true + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/UpdateAgentRequest' + * responses: + * 200: + * description: Agent updated successfully + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/AgentEntity' + * 400: + * description: Invalid request body + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/ErrorResponse' + * 404: + * description: Agent not found + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/ErrorResponse' + */ +agentsRouter.patch('/:agentId', validateAgentId, validateAgentUpdate, handleValidationErrors, agentHandlers.patchAgent) +/** + * @swagger + * /agents/{agentId}: + * delete: + * summary: Delete agent + * tags: [Agents] + * parameters: + * - in: path + * name: agentId + * required: true + * schema: + * type: string + * description: Agent ID + * responses: + * 204: + * description: Agent deleted successfully + * 404: + * description: Agent not found + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/ErrorResponse' + */ +agentsRouter.delete('/:agentId', validateAgentId, handleValidationErrors, agentHandlers.deleteAgent) + +// Create sessions router with agent context +const createSessionsRouter = (): express.Router => { + const sessionsRouter = express.Router({ mergeParams: true }) + + // Session CRUD routes (nested under agent) + /** + * @swagger + * /agents/{agentId}/sessions: + * post: + * summary: Create a new session for an agent + * tags: [Sessions] + * parameters: + * - in: path + * name: agentId + * required: true + * schema: + * type: string + * description: Agent ID + * requestBody: + * required: true + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/CreateSessionRequest' + * responses: + * 201: + * description: Session created successfully + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/SessionEntity' + * 400: + * description: Invalid request body + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/ErrorResponse' + * 404: + * description: Agent not found + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/ErrorResponse' + */ + sessionsRouter.post('/', validateSession, handleValidationErrors, sessionHandlers.createSession) + + /** + * @swagger + * /agents/{agentId}/sessions: + * get: + * summary: List sessions for an agent + * tags: [Sessions] + * parameters: + * - in: path + * name: agentId + * required: true + * schema: + * type: string + * description: Agent ID + * - in: query + * name: limit + * schema: + * type: integer + * minimum: 1 + * maximum: 100 + * default: 20 + * description: Number of sessions to return + * - in: query + * name: offset + * schema: + * type: integer + * minimum: 0 + * default: 0 + * description: Number of sessions to skip + * - in: query + * name: status + * schema: + * type: string + * enum: [idle, running, completed, failed, stopped] + * description: Filter by session status + * responses: + * 200: + * description: List of sessions + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/ListSessionsResponse' + * 404: + * description: Agent not found + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/ErrorResponse' + */ + sessionsRouter.get('/', validatePagination, handleValidationErrors, sessionHandlers.listSessions) + /** + * @swagger + * /agents/{agentId}/sessions/{sessionId}: + * get: + * summary: Get session by ID + * tags: [Sessions] + * parameters: + * - in: path + * name: agentId + * required: true + * schema: + * type: string + * description: Agent ID + * - in: path + * name: sessionId + * required: true + * schema: + * type: string + * description: Session ID + * responses: + * 200: + * description: Session details + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/SessionEntity' + * 404: + * description: Agent or session not found + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/ErrorResponse' + */ + sessionsRouter.get('/:sessionId', validateSessionId, handleValidationErrors, sessionHandlers.getSession) + /** + * @swagger + * /agents/{agentId}/sessions/{sessionId}: + * put: + * summary: Replace session (full update) + * tags: [Sessions] + * parameters: + * - in: path + * name: agentId + * required: true + * schema: + * type: string + * description: Agent ID + * - in: path + * name: sessionId + * required: true + * schema: + * type: string + * description: Session ID + * requestBody: + * required: true + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/ReplaceSessionRequest' + * responses: + * 200: + * description: Session updated successfully + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/SessionEntity' + * 400: + * description: Invalid request body + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/ErrorResponse' + * 404: + * description: Agent or session not found + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/ErrorResponse' + */ + sessionsRouter.put( + '/:sessionId', + validateSessionId, + validateSessionReplace, + handleValidationErrors, + sessionHandlers.updateSession + ) + /** + * @swagger + * /agents/{agentId}/sessions/{sessionId}: + * patch: + * summary: Update session (partial update) + * tags: [Sessions] + * parameters: + * - in: path + * name: agentId + * required: true + * schema: + * type: string + * description: Agent ID + * - in: path + * name: sessionId + * required: true + * schema: + * type: string + * description: Session ID + * requestBody: + * required: true + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/UpdateSessionRequest' + * responses: + * 200: + * description: Session updated successfully + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/SessionEntity' + * 400: + * description: Invalid request body + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/ErrorResponse' + * 404: + * description: Agent or session not found + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/ErrorResponse' + */ + sessionsRouter.patch( + '/:sessionId', + validateSessionId, + validateSessionUpdate, + handleValidationErrors, + sessionHandlers.patchSession + ) + /** + * @swagger + * /agents/{agentId}/sessions/{sessionId}: + * delete: + * summary: Delete session + * tags: [Sessions] + * parameters: + * - in: path + * name: agentId + * required: true + * schema: + * type: string + * description: Agent ID + * - in: path + * name: sessionId + * required: true + * schema: + * type: string + * description: Session ID + * responses: + * 204: + * description: Session deleted successfully + * 404: + * description: Agent or session not found + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/ErrorResponse' + */ + sessionsRouter.delete('/:sessionId', validateSessionId, handleValidationErrors, sessionHandlers.deleteSession) + + return sessionsRouter +} + +// Create messages router with agent and session context +const createMessagesRouter = (): express.Router => { + const messagesRouter = express.Router({ mergeParams: true }) + + // Message CRUD routes (nested under agent/session) + /** + * @swagger + * /agents/{agentId}/sessions/{sessionId}/messages: + * post: + * summary: Create a new message in a session + * tags: [Messages] + * parameters: + * - in: path + * name: agentId + * required: true + * schema: + * type: string + * description: Agent ID + * - in: path + * name: sessionId + * required: true + * schema: + * type: string + * description: Session ID + * requestBody: + * required: true + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/CreateSessionMessageRequest' + * responses: + * 201: + * description: Message created successfully + * content: + * application/json: + * schema: + * type: object + * properties: + * id: + * type: number + * description: Message ID + * session_id: + * type: string + * description: Session ID + * role: + * type: string + * enum: [assistant, user, system, tool] + * description: Message role + * content: + * type: object + * description: Message content (AI SDK format) + * agent_session_id: + * type: string + * description: Agent session ID for resuming + * metadata: + * type: object + * description: Additional metadata + * created_at: + * type: string + * format: date-time + * updated_at: + * type: string + * format: date-time + * 400: + * description: Invalid request body + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/ErrorResponse' + * 404: + * description: Agent or session not found + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/ErrorResponse' + */ + messagesRouter.post('/', validateSessionMessage, handleValidationErrors, messageHandlers.createMessage) + + /** + * @swagger + * /agents/{agentId}/sessions/{sessionId}/messages/{messageId}: + * delete: + * summary: Delete a message from a session + * tags: [Messages] + * parameters: + * - in: path + * name: agentId + * required: true + * schema: + * type: string + * description: Agent ID + * - in: path + * name: sessionId + * required: true + * schema: + * type: string + * description: Session ID + * - in: path + * name: messageId + * required: true + * schema: + * type: integer + * description: Message ID + * responses: + * 204: + * description: Message deleted successfully + * 404: + * description: Agent, session, or message not found + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/ErrorResponse' + */ + messagesRouter.delete('/:messageId', validateSessionMessageId, handleValidationErrors, messageHandlers.deleteMessage) + return messagesRouter +} + +// Mount nested resources with clear hierarchy +const sessionsRouter = createSessionsRouter() +const messagesRouter = createMessagesRouter() + +// Mount sessions under specific agent +agentsRouter.use('/:agentId/sessions', validateAgentId, checkAgentExists, handleValidationErrors, sessionsRouter) + +// Mount messages under specific agent/session +agentsRouter.use( + '/:agentId/sessions/:sessionId/messages', + validateAgentId, + validateSessionId, + handleValidationErrors, + messagesRouter +) + +// Export main router and convenience router +export const agentsRoutes = agentsRouter diff --git a/src/main/apiServer/routes/agents/middleware/common.ts b/src/main/apiServer/routes/agents/middleware/common.ts new file mode 100644 index 0000000000..d45f197e4a --- /dev/null +++ b/src/main/apiServer/routes/agents/middleware/common.ts @@ -0,0 +1,44 @@ +import { Request, Response } from 'express' + +import { agentService } from '../../../../services/agents' +import { loggerService } from '../../../../services/LoggerService' + +const logger = loggerService.withContext('ApiServerMiddleware') + +// Since Zod validators handle their own errors, this is now a pass-through +export const handleValidationErrors = (_req: Request, _res: Response, next: any): void => { + next() +} + +// Middleware to check if agent exists +export const checkAgentExists = async (req: Request, res: Response, next: any): Promise => { + try { + const { agentId } = req.params + const exists = await agentService.agentExists(agentId) + + if (!exists) { + res.status(404).json({ + error: { + message: 'Agent not found', + type: 'not_found', + code: 'agent_not_found' + } + }) + return + } + + next() + } catch (error) { + logger.error('Error checking agent existence', { + error: error as Error, + agentId: req.params.agentId + }) + res.status(500).json({ + error: { + message: 'Failed to validate agent', + type: 'internal_error', + code: 'agent_validation_failed' + } + }) + } +} diff --git a/src/main/apiServer/routes/agents/middleware/index.ts b/src/main/apiServer/routes/agents/middleware/index.ts new file mode 100644 index 0000000000..89a3196b12 --- /dev/null +++ b/src/main/apiServer/routes/agents/middleware/index.ts @@ -0,0 +1 @@ +export * from './common' diff --git a/src/main/apiServer/routes/agents/validators/agents.ts b/src/main/apiServer/routes/agents/validators/agents.ts new file mode 100644 index 0000000000..4b29e66929 --- /dev/null +++ b/src/main/apiServer/routes/agents/validators/agents.ts @@ -0,0 +1,24 @@ +import { + AgentIdParamSchema, + CreateAgentRequestSchema, + ReplaceAgentRequestSchema, + UpdateAgentRequestSchema +} from '@types' + +import { createZodValidator } from './zodValidator' + +export const validateAgent = createZodValidator({ + body: CreateAgentRequestSchema +}) + +export const validateAgentReplace = createZodValidator({ + body: ReplaceAgentRequestSchema +}) + +export const validateAgentUpdate = createZodValidator({ + body: UpdateAgentRequestSchema +}) + +export const validateAgentId = createZodValidator({ + params: AgentIdParamSchema +}) diff --git a/src/main/apiServer/routes/agents/validators/common.ts b/src/main/apiServer/routes/agents/validators/common.ts new file mode 100644 index 0000000000..4e9a8ceaa2 --- /dev/null +++ b/src/main/apiServer/routes/agents/validators/common.ts @@ -0,0 +1,7 @@ +import { PaginationQuerySchema } from '@types' + +import { createZodValidator } from './zodValidator' + +export const validatePagination = createZodValidator({ + query: PaginationQuerySchema +}) diff --git a/src/main/apiServer/routes/agents/validators/index.ts b/src/main/apiServer/routes/agents/validators/index.ts new file mode 100644 index 0000000000..7bba43e3b7 --- /dev/null +++ b/src/main/apiServer/routes/agents/validators/index.ts @@ -0,0 +1,4 @@ +export * from './agents' +export * from './common' +export * from './messages' +export * from './sessions' diff --git a/src/main/apiServer/routes/agents/validators/messages.ts b/src/main/apiServer/routes/agents/validators/messages.ts new file mode 100644 index 0000000000..8d7cddfa7b --- /dev/null +++ b/src/main/apiServer/routes/agents/validators/messages.ts @@ -0,0 +1,11 @@ +import { CreateSessionMessageRequestSchema, SessionMessageIdParamSchema } from '@types' + +import { createZodValidator } from './zodValidator' + +export const validateSessionMessage = createZodValidator({ + body: CreateSessionMessageRequestSchema +}) + +export const validateSessionMessageId = createZodValidator({ + params: SessionMessageIdParamSchema +}) diff --git a/src/main/apiServer/routes/agents/validators/sessions.ts b/src/main/apiServer/routes/agents/validators/sessions.ts new file mode 100644 index 0000000000..5081849649 --- /dev/null +++ b/src/main/apiServer/routes/agents/validators/sessions.ts @@ -0,0 +1,24 @@ +import { + CreateSessionRequestSchema, + ReplaceSessionRequestSchema, + SessionIdParamSchema, + UpdateSessionRequestSchema +} from '@types' + +import { createZodValidator } from './zodValidator' + +export const validateSession = createZodValidator({ + body: CreateSessionRequestSchema +}) + +export const validateSessionReplace = createZodValidator({ + body: ReplaceSessionRequestSchema +}) + +export const validateSessionUpdate = createZodValidator({ + body: UpdateSessionRequestSchema +}) + +export const validateSessionId = createZodValidator({ + params: SessionIdParamSchema +}) diff --git a/src/main/apiServer/routes/agents/validators/zodValidator.ts b/src/main/apiServer/routes/agents/validators/zodValidator.ts new file mode 100644 index 0000000000..1a0e83786a --- /dev/null +++ b/src/main/apiServer/routes/agents/validators/zodValidator.ts @@ -0,0 +1,68 @@ +import { NextFunction, Request, Response } from 'express' +import { ZodError, ZodType } from 'zod' + +export interface ValidationRequest extends Request { + validatedBody?: any + validatedParams?: any + validatedQuery?: any +} + +export interface ZodValidationConfig { + body?: ZodType + params?: ZodType + query?: ZodType +} + +export const createZodValidator = (config: ZodValidationConfig) => { + return (req: ValidationRequest, res: Response, next: NextFunction): void => { + try { + if (config.body && req.body) { + req.validatedBody = config.body.parse(req.body) + } + + if (config.params && req.params) { + req.validatedParams = config.params.parse(req.params) + } + + if (config.query && req.query) { + req.validatedQuery = config.query.parse(req.query) + } + + next() + } catch (error) { + if (error instanceof ZodError) { + const validationErrors = error.issues.map((err) => ({ + type: 'field', + value: err.input, + msg: err.message, + path: err.path.map((p) => String(p)).join('.'), + location: getLocationFromPath(err.path, config) + })) + + res.status(400).json({ + error: { + message: 'Validation failed', + type: 'validation_error', + details: validationErrors + } + }) + return + } + + res.status(500).json({ + error: { + message: 'Internal validation error', + type: 'internal_error', + code: 'validation_processing_failed' + } + }) + } + } +} + +function getLocationFromPath(path: (string | number | symbol)[], config: ZodValidationConfig): string { + if (config.body && path.length > 0) return 'body' + if (config.params && path.length > 0) return 'params' + if (config.query && path.length > 0) return 'query' + return 'unknown' +} diff --git a/src/main/apiServer/routes/chat.ts b/src/main/apiServer/routes/chat.ts index f9c9c357e6..0338fd26b7 100644 --- a/src/main/apiServer/routes/chat.ts +++ b/src/main/apiServer/routes/chat.ts @@ -1,16 +1,106 @@ import type { Request, Response } from 'express' import express from 'express' -import OpenAI from 'openai' import type { ChatCompletionCreateParams } from 'openai/resources' import { loggerService } from '../../services/LoggerService' -import { chatCompletionService } from '../services/chat-completion' -import { validateModelId } from '../utils' +import { + ChatCompletionModelError, + chatCompletionService, + ChatCompletionValidationError +} from '../services/chat-completion' const logger = loggerService.withContext('ApiServerChatRoutes') const router = express.Router() +interface ErrorResponseBody { + error: { + message: string + type: string + code: string + } +} + +const mapChatCompletionError = (error: unknown): { status: number; body: ErrorResponseBody } => { + if (error instanceof ChatCompletionValidationError) { + logger.warn('Chat completion validation error', { + errors: error.errors + }) + + return { + status: 400, + body: { + error: { + message: error.errors.join('; '), + type: 'invalid_request_error', + code: 'validation_failed' + } + } + } + } + + if (error instanceof ChatCompletionModelError) { + logger.warn('Chat completion model error', error.error) + + return { + status: 400, + body: { + error: { + message: error.error.message, + type: 'invalid_request_error', + code: error.error.code + } + } + } + } + + if (error instanceof Error) { + let statusCode = 500 + let errorType = 'server_error' + let errorCode = 'internal_error' + + if (error.message.includes('API key') || error.message.includes('authentication')) { + statusCode = 401 + errorType = 'authentication_error' + errorCode = 'invalid_api_key' + } else if (error.message.includes('rate limit') || error.message.includes('quota')) { + statusCode = 429 + errorType = 'rate_limit_error' + errorCode = 'rate_limit_exceeded' + } else if (error.message.includes('timeout') || error.message.includes('connection')) { + statusCode = 502 + errorType = 'server_error' + errorCode = 'upstream_error' + } + + logger.error('Chat completion error', { error }) + + return { + status: statusCode, + body: { + error: { + message: error.message || 'Internal server error', + type: errorType, + code: errorCode + } + } + } + } + + logger.error('Chat completion unknown error', { error }) + + return { + status: 500, + body: { + error: { + message: 'Internal server error', + type: 'server_error', + code: 'internal_error' + } + } + } +} + /** * @swagger * /v1/chat/completions: @@ -61,7 +151,7 @@ const router = express.Router() * type: integer * total_tokens: * type: integer - * text/plain: + * text/event-stream: * schema: * type: string * description: Server-sent events stream (when stream=true) @@ -104,72 +194,31 @@ router.post('/completions', async (req: Request, res: Response) => { }) } - logger.info('Chat completion request:', { + logger.debug('Chat completion request', { model: request.model, messageCount: request.messages?.length || 0, stream: request.stream, temperature: request.temperature }) - // Validate request - const validation = chatCompletionService.validateRequest(request) - if (!validation.isValid) { - return res.status(400).json({ - error: { - message: validation.errors.join('; '), - type: 'invalid_request_error', - code: 'validation_failed' - } - }) - } + const isStreaming = !!request.stream - // Validate model ID and get provider - const modelValidation = await validateModelId(request.model) - if (!modelValidation.valid) { - const error = modelValidation.error! - logger.warn(`Model validation failed for '${request.model}':`, error) - return res.status(400).json({ - error: { - message: error.message, - type: 'invalid_request_error', - code: error.code - } - }) - } + if (isStreaming) { + const { stream } = await chatCompletionService.processStreamingCompletion(request) - const provider = modelValidation.provider! - const modelId = modelValidation.modelId! - - logger.info('Model validation successful:', { - provider: provider.id, - providerType: provider.type, - modelId: modelId, - fullModelId: request.model - }) - - // Create OpenAI client - const client = new OpenAI({ - baseURL: provider.apiHost, - apiKey: provider.apiKey - }) - request.model = modelId - - // Handle streaming - if (request.stream) { - const streamResponse = await client.chat.completions.create(request) - - res.setHeader('Content-Type', 'text/plain; charset=utf-8') - res.setHeader('Cache-Control', 'no-cache') + res.setHeader('Content-Type', 'text/event-stream; charset=utf-8') + res.setHeader('Cache-Control', 'no-cache, no-transform') res.setHeader('Connection', 'keep-alive') + res.setHeader('X-Accel-Buffering', 'no') + res.flushHeaders() try { - for await (const chunk of streamResponse as any) { + for await (const chunk of stream) { res.write(`data: ${JSON.stringify(chunk)}\n\n`) } res.write('data: [DONE]\n\n') - res.end() } catch (streamError: any) { - logger.error('Stream error:', streamError) + logger.error('Stream error', { error: streamError }) res.write( `data: ${JSON.stringify({ error: { @@ -179,47 +228,17 @@ router.post('/completions', async (req: Request, res: Response) => { } })}\n\n` ) + } finally { res.end() } return } - // Handle non-streaming - const response = await client.chat.completions.create(request) + const { response } = await chatCompletionService.processCompletion(request) return res.json(response) - } catch (error: any) { - logger.error('Chat completion error:', error) - - let statusCode = 500 - let errorType = 'server_error' - let errorCode = 'internal_error' - let errorMessage = 'Internal server error' - - if (error instanceof Error) { - errorMessage = error.message - - if (error.message.includes('API key') || error.message.includes('authentication')) { - statusCode = 401 - errorType = 'authentication_error' - errorCode = 'invalid_api_key' - } else if (error.message.includes('rate limit') || error.message.includes('quota')) { - statusCode = 429 - errorType = 'rate_limit_error' - errorCode = 'rate_limit_exceeded' - } else if (error.message.includes('timeout') || error.message.includes('connection')) { - statusCode = 502 - errorType = 'server_error' - errorCode = 'upstream_error' - } - } - - return res.status(statusCode).json({ - error: { - message: errorMessage, - type: errorType, - code: errorCode - } - }) + } catch (error: unknown) { + const { status, body } = mapChatCompletionError(error) + return res.status(status).json(body) } }) diff --git a/src/main/apiServer/routes/mcp.ts b/src/main/apiServer/routes/mcp.ts index e36c57ed36..90626af158 100644 --- a/src/main/apiServer/routes/mcp.ts +++ b/src/main/apiServer/routes/mcp.ts @@ -44,14 +44,14 @@ const router = express.Router() */ router.get('/', async (req: Request, res: Response) => { try { - logger.info('Get all MCP servers request received') + logger.debug('Listing MCP servers') const servers = await mcpApiService.getAllServers(req) return res.json({ success: true, data: servers }) } catch (error: any) { - logger.error('Error fetching MCP servers:', error) + logger.error('Error fetching MCP servers', { error }) return res.status(503).json({ success: false, error: { @@ -104,10 +104,12 @@ router.get('/', async (req: Request, res: Response) => { */ router.get('/:server_id', async (req: Request, res: Response) => { try { - logger.info('Get MCP server info request received') + logger.debug('Get MCP server info request received', { + serverId: req.params.server_id + }) const server = await mcpApiService.getServerInfo(req.params.server_id) if (!server) { - logger.warn('MCP server not found') + logger.warn('MCP server not found', { serverId: req.params.server_id }) return res.status(404).json({ success: false, error: { @@ -122,7 +124,7 @@ router.get('/:server_id', async (req: Request, res: Response) => { data: server }) } catch (error: any) { - logger.error('Error fetching MCP server info:', error) + logger.error('Error fetching MCP server info', { error, serverId: req.params.server_id }) return res.status(503).json({ success: false, error: { @@ -138,7 +140,7 @@ router.get('/:server_id', async (req: Request, res: Response) => { router.all('/:server_id/mcp', async (req: Request, res: Response) => { const server = await mcpApiService.getServerById(req.params.server_id) if (!server) { - logger.warn('MCP server not found') + logger.warn('MCP server not found', { serverId: req.params.server_id }) return res.status(404).json({ success: false, error: { diff --git a/src/main/apiServer/routes/messages.ts b/src/main/apiServer/routes/messages.ts new file mode 100644 index 0000000000..3b7c338199 --- /dev/null +++ b/src/main/apiServer/routes/messages.ts @@ -0,0 +1,403 @@ +import { MessageCreateParams } from '@anthropic-ai/sdk/resources' +import { loggerService } from '@logger' +import { Provider } from '@types' +import express, { Request, Response } from 'express' + +import { messagesService } from '../services/messages' +import { getProviderById, validateModelId } from '../utils' + +const logger = loggerService.withContext('ApiServerMessagesRoutes') + +const router = express.Router() +const providerRouter = express.Router({ mergeParams: true }) + +// Helper function for basic request validation +async function validateRequestBody(req: Request): Promise<{ valid: boolean; error?: any }> { + const request: MessageCreateParams = req.body + + if (!request) { + return { + valid: false, + error: { + type: 'error', + error: { + type: 'invalid_request_error', + message: 'Request body is required' + } + } + } + } + + return { valid: true } +} + +interface HandleMessageProcessingOptions { + req: Request + res: Response + provider: Provider + request: MessageCreateParams + modelId?: string +} + +async function handleMessageProcessing({ + req, + res, + provider, + request, + modelId +}: HandleMessageProcessingOptions): Promise { + try { + const validation = messagesService.validateRequest(request) + if (!validation.isValid) { + res.status(400).json({ + type: 'error', + error: { + type: 'invalid_request_error', + message: validation.errors.join('; ') + } + }) + return + } + + const extraHeaders = messagesService.prepareHeaders(req.headers) + const { client, anthropicRequest } = await messagesService.processMessage({ + provider, + request, + extraHeaders, + modelId + }) + + if (request.stream) { + await messagesService.handleStreaming(client, anthropicRequest, { response: res }, provider) + return + } + + const response = await client.messages.create(anthropicRequest) + res.json(response) + } catch (error: any) { + logger.error('Message processing error', { error }) + const { statusCode, errorResponse } = messagesService.transformError(error) + res.status(statusCode).json(errorResponse) + } +} + +/** + * @swagger + * /v1/messages: + * post: + * summary: Create message + * description: Create a message response using Anthropic's API format + * tags: [Messages] + * requestBody: + * required: true + * content: + * application/json: + * schema: + * type: object + * required: + * - model + * - max_tokens + * - messages + * properties: + * model: + * type: string + * description: Model ID in format "provider:model_id" + * example: "my-anthropic:claude-3-5-sonnet-20241022" + * max_tokens: + * type: integer + * minimum: 1 + * description: Maximum number of tokens to generate + * example: 1024 + * messages: + * type: array + * items: + * type: object + * properties: + * role: + * type: string + * enum: [user, assistant] + * content: + * oneOf: + * - type: string + * - type: array + * system: + * type: string + * description: System message + * temperature: + * type: number + * minimum: 0 + * maximum: 1 + * description: Sampling temperature + * top_p: + * type: number + * minimum: 0 + * maximum: 1 + * description: Nucleus sampling + * top_k: + * type: integer + * minimum: 0 + * description: Top-k sampling + * stream: + * type: boolean + * description: Whether to stream the response + * tools: + * type: array + * description: Available tools for the model + * responses: + * 200: + * description: Message response + * content: + * application/json: + * schema: + * type: object + * properties: + * id: + * type: string + * type: + * type: string + * example: message + * role: + * type: string + * example: assistant + * content: + * type: array + * items: + * type: object + * model: + * type: string + * stop_reason: + * type: string + * stop_sequence: + * type: string + * usage: + * type: object + * properties: + * input_tokens: + * type: integer + * output_tokens: + * type: integer + * text/event-stream: + * schema: + * type: string + * description: Server-sent events stream (when stream=true) + * 400: + * description: Bad request + * content: + * application/json: + * schema: + * type: object + * properties: + * type: + * type: string + * example: error + * error: + * type: object + * properties: + * type: + * type: string + * message: + * type: string + * 401: + * description: Unauthorized + * 429: + * description: Rate limit exceeded + * 500: + * description: Internal server error + */ +router.post('/', async (req: Request, res: Response) => { + // Validate request body + const bodyValidation = await validateRequestBody(req) + if (!bodyValidation.valid) { + return res.status(400).json(bodyValidation.error) + } + + try { + const request: MessageCreateParams = req.body + + // Validate model ID and get provider + const modelValidation = await validateModelId(request.model) + if (!modelValidation.valid) { + const error = modelValidation.error! + logger.warn('Model validation failed', { + model: request.model, + error + }) + return res.status(400).json({ + type: 'error', + error: { + type: 'invalid_request_error', + message: error.message + } + }) + } + + const provider = modelValidation.provider! + const modelId = modelValidation.modelId! + + return handleMessageProcessing({ req, res, provider, request, modelId }) + } catch (error: any) { + logger.error('Message processing error', { error }) + const { statusCode, errorResponse } = messagesService.transformError(error) + return res.status(statusCode).json(errorResponse) + } +}) + +/** + * @swagger + * /{provider_id}/v1/messages: + * post: + * summary: Create message with provider in path + * description: Create a message response using provider ID from URL path + * tags: [Messages] + * parameters: + * - in: path + * name: provider_id + * required: true + * schema: + * type: string + * description: Provider ID (e.g., "my-anthropic") + * example: "my-anthropic" + * requestBody: + * required: true + * content: + * application/json: + * schema: + * type: object + * required: + * - model + * - max_tokens + * - messages + * properties: + * model: + * type: string + * description: Model ID without provider prefix + * example: "claude-3-5-sonnet-20241022" + * max_tokens: + * type: integer + * minimum: 1 + * description: Maximum number of tokens to generate + * example: 1024 + * messages: + * type: array + * items: + * type: object + * properties: + * role: + * type: string + * enum: [user, assistant] + * content: + * oneOf: + * - type: string + * - type: array + * system: + * type: string + * description: System message + * temperature: + * type: number + * minimum: 0 + * maximum: 1 + * description: Sampling temperature + * top_p: + * type: number + * minimum: 0 + * maximum: 1 + * description: Nucleus sampling + * top_k: + * type: integer + * minimum: 0 + * description: Top-k sampling + * stream: + * type: boolean + * description: Whether to stream the response + * tools: + * type: array + * description: Available tools for the model + * responses: + * 200: + * description: Message response + * content: + * application/json: + * schema: + * type: object + * properties: + * id: + * type: string + * type: + * type: string + * example: message + * role: + * type: string + * example: assistant + * content: + * type: array + * items: + * type: object + * model: + * type: string + * stop_reason: + * type: string + * stop_sequence: + * type: string + * usage: + * type: object + * properties: + * input_tokens: + * type: integer + * output_tokens: + * type: integer + * text/event-stream: + * schema: + * type: string + * description: Server-sent events stream (when stream=true) + * 400: + * description: Bad request + * 401: + * description: Unauthorized + * 429: + * description: Rate limit exceeded + * 500: + * description: Internal server error + */ +providerRouter.post('/', async (req: Request, res: Response) => { + // Validate request body + const bodyValidation = await validateRequestBody(req) + if (!bodyValidation.valid) { + return res.status(400).json(bodyValidation.error) + } + + try { + const providerId = req.params.provider + + if (!providerId) { + return res.status(400).json({ + type: 'error', + error: { + type: 'invalid_request_error', + message: 'Provider ID is required in URL path' + } + }) + } + + // Get provider directly by ID from URL path + const provider = await getProviderById(providerId) + if (!provider) { + return res.status(400).json({ + type: 'error', + error: { + type: 'invalid_request_error', + message: `Provider '${providerId}' not found or not enabled` + } + }) + } + + const request: MessageCreateParams = req.body + + return handleMessageProcessing({ req, res, provider, request }) + } catch (error: any) { + logger.error('Message processing error', { error }) + const { statusCode, errorResponse } = messagesService.transformError(error) + return res.status(statusCode).json(errorResponse) + } +}) + +export { providerRouter as messagesProviderRoutes, router as messagesRoutes } diff --git a/src/main/apiServer/routes/models.ts b/src/main/apiServer/routes/models.ts index f4761d688c..4295391c07 100644 --- a/src/main/apiServer/routes/models.ts +++ b/src/main/apiServer/routes/models.ts @@ -1,74 +1,126 @@ +import type { ApiModelsResponse } from '@types' +import { ApiModelsFilterSchema } from '@types' import type { Request, Response } from 'express' import express from 'express' -import { loggerService } from '../../services/LoggerService' -import { chatCompletionService } from '../services/chat-completion' +import { loggerService } from '@logger' +import { modelsService } from '../services/models' const logger = loggerService.withContext('ApiServerModelsRoutes') -const router = express.Router() +const router = express + .Router() -/** - * @swagger - * /v1/models: - * get: - * summary: List available models - * description: Returns a list of available AI models from all configured providers - * tags: [Models] - * responses: - * 200: - * description: List of available models - * content: - * application/json: - * schema: - * type: object - * properties: - * object: - * type: string - * example: list - * data: - * type: array - * items: - * $ref: '#/components/schemas/Model' - * 503: - * description: Service unavailable - * content: - * application/json: - * schema: - * $ref: '#/components/schemas/Error' - */ -router.get('/', async (_req: Request, res: Response) => { - try { - logger.info('Models list request received') + /** + * @swagger + * /v1/models: + * get: + * summary: List available models + * description: Returns a list of available AI models from all configured providers with optional filtering + * tags: [Models] + * parameters: + * - in: query + * name: providerType + * schema: + * type: string + * enum: [openai, openai-response, anthropic, gemini] + * description: Filter models by provider type + * - in: query + * name: offset + * schema: + * type: integer + * minimum: 0 + * default: 0 + * description: Pagination offset + * - in: query + * name: limit + * schema: + * type: integer + * minimum: 1 + * description: Maximum number of models to return + * responses: + * 200: + * description: List of available models + * content: + * application/json: + * schema: + * type: object + * properties: + * object: + * type: string + * example: list + * data: + * type: array + * items: + * $ref: '#/components/schemas/Model' + * total: + * type: integer + * description: Total number of models (when using pagination) + * offset: + * type: integer + * description: Current offset (when using pagination) + * limit: + * type: integer + * description: Current limit (when using pagination) + * 400: + * description: Invalid query parameters + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/Error' + * 503: + * description: Service unavailable + * content: + * application/json: + * schema: + * $ref: '#/components/schemas/Error' + */ + .get('/', async (req: Request, res: Response) => { + try { + logger.debug('Models list request received', { query: req.query }) - const models = await chatCompletionService.getModels() + // Validate query parameters using Zod schema + const filterResult = ApiModelsFilterSchema.safeParse(req.query) - if (models.length === 0) { - logger.warn( - 'No models available from providers. This may be because no OpenAI providers are configured or enabled.' - ) - } - - logger.info(`Returning ${models.length} models (OpenAI providers only)`) - logger.debug( - 'Model IDs:', - models.map((m) => m.id) - ) - - return res.json({ - object: 'list', - data: models - }) - } catch (error: any) { - logger.error('Error fetching models:', error) - return res.status(503).json({ - error: { - message: 'Failed to retrieve models from available providers', - type: 'service_unavailable', - code: 'models_unavailable' + if (!filterResult.success) { + logger.warn('Invalid model query parameters', { issues: filterResult.error.issues }) + return res.status(400).json({ + error: { + message: 'Invalid query parameters', + type: 'invalid_request_error', + code: 'invalid_parameters', + details: filterResult.error.issues.map((issue) => ({ + field: issue.path.join('.'), + message: issue.message + })) + } + }) } - }) - } -}) + + const filter = filterResult.data + const response = await modelsService.getModels(filter) + + if (response.data.length === 0) { + logger.warn('No models available from providers', { filter }) + } + + logger.info('Models response ready', { + filter, + total: response.total, + modelIds: response.data.map((m) => m.id) + }) + + return res.json(response satisfies ApiModelsResponse) + } catch (error: any) { + logger.error('Error fetching models', { error }) + return res.status(503).json({ + error: { + message: 'Failed to retrieve models from available providers', + type: 'service_unavailable', + code: 'models_unavailable' + } + }) + } + }) export { router as modelsRoutes } diff --git a/src/main/apiServer/server.ts b/src/main/apiServer/server.ts index 2555fa8c2e..0cba77aaa3 100644 --- a/src/main/apiServer/server.ts +++ b/src/main/apiServer/server.ts @@ -1,11 +1,16 @@ import { createServer } from 'node:http' +import { agentService } from '../services/agents' import { loggerService } from '../services/LoggerService' import { app } from './app' import { config } from './config' const logger = loggerService.withContext('ApiServer') +const GLOBAL_REQUEST_TIMEOUT_MS = 5 * 60_000 +const GLOBAL_HEADERS_TIMEOUT_MS = GLOBAL_REQUEST_TIMEOUT_MS + 5_000 +const GLOBAL_KEEPALIVE_TIMEOUT_MS = 60_000 + export class ApiServer { private server: ReturnType | null = null @@ -16,16 +21,21 @@ export class ApiServer { } // Load config - const { port, host, apiKey } = await config.load() + const { port, host } = await config.load() + + // Initialize AgentService + logger.info('Initializing AgentService') + await agentService.initialize() + logger.info('AgentService initialized') // Create server with Express app this.server = createServer(app) + this.applyServerTimeouts(this.server) // Start server return new Promise((resolve, reject) => { this.server!.listen(port, host, () => { - logger.info(`API Server started at http://${host}:${port}`) - logger.info(`API Key: ${apiKey}`) + logger.info('API server started', { host, port }) resolve() }) @@ -33,12 +43,19 @@ export class ApiServer { }) } + private applyServerTimeouts(server: ReturnType): void { + server.requestTimeout = GLOBAL_REQUEST_TIMEOUT_MS + server.headersTimeout = Math.max(GLOBAL_HEADERS_TIMEOUT_MS, server.requestTimeout + 1_000) + server.keepAliveTimeout = GLOBAL_KEEPALIVE_TIMEOUT_MS + server.setTimeout(0) + } + async stop(): Promise { if (!this.server) return return new Promise((resolve) => { this.server!.close(() => { - logger.info('API Server stopped') + logger.info('API server stopped') this.server = null resolve() }) @@ -56,7 +73,7 @@ export class ApiServer { const isListening = this.server?.listening || false const result = hasServer && isListening - logger.debug('isRunning check:', { hasServer, isListening, result }) + logger.debug('isRunning check', { hasServer, isListening, result }) return result } diff --git a/src/main/apiServer/services/chat-completion.ts b/src/main/apiServer/services/chat-completion.ts index 5ea077eb59..9ccf363b43 100644 --- a/src/main/apiServer/services/chat-completion.ts +++ b/src/main/apiServer/services/chat-completion.ts @@ -1,83 +1,132 @@ +import type { Provider } from '@types' import OpenAI from 'openai' -import type { ChatCompletionCreateParams } from 'openai/resources' +import type { ChatCompletionCreateParams, ChatCompletionCreateParamsStreaming } from 'openai/resources' -import { loggerService } from '../../services/LoggerService' -import type { OpenAICompatibleModel } from '../utils' -import { - getProviderByModel, - getRealProviderModel, - listAllAvailableModels, - transformModelToOpenAI, - validateProvider -} from '../utils' +import { loggerService } from '@logger' +import { type ModelValidationError, validateModelId } from '../utils' const logger = loggerService.withContext('ChatCompletionService') -export interface ModelData extends OpenAICompatibleModel { - provider_id: string - model_id: string - name: string -} - export interface ValidationResult { isValid: boolean errors: string[] } +export class ChatCompletionValidationError extends Error { + constructor(public readonly errors: string[]) { + super(`Request validation failed: ${errors.join('; ')}`) + this.name = 'ChatCompletionValidationError' + } +} + +export class ChatCompletionModelError extends Error { + constructor(public readonly error: ModelValidationError) { + super(`Model validation failed: ${error.message}`) + this.name = 'ChatCompletionModelError' + } +} + +export type PrepareRequestResult = + | { status: 'validation_error'; errors: string[] } + | { status: 'model_error'; error: ModelValidationError } + | { + status: 'ok' + provider: Provider + modelId: string + client: OpenAI + providerRequest: ChatCompletionCreateParams + } + export class ChatCompletionService { - async getModels(): Promise { - try { - logger.info('Getting available models from providers') + async resolveProviderContext( + model: string + ): Promise< + { ok: false; error: ModelValidationError } | { ok: true; provider: Provider; modelId: string; client: OpenAI } + > { + const modelValidation = await validateModelId(model) + if (!modelValidation.valid) { + return { + ok: false, + error: modelValidation.error! + } + } - const models = await listAllAvailableModels() + const provider = modelValidation.provider! - // Use Map to deduplicate models by their full ID (provider:model_id) - const uniqueModels = new Map() - - for (const model of models) { - const openAIModel = transformModelToOpenAI(model) - const fullModelId = openAIModel.id // This is already in format "provider:model_id" - - // Only add if not already present (first occurrence wins) - if (!uniqueModels.has(fullModelId)) { - uniqueModels.set(fullModelId, { - ...openAIModel, - provider_id: model.provider, - model_id: model.id, - name: model.name - }) - } else { - logger.debug(`Skipping duplicate model: ${fullModelId}`) + if (provider.type !== 'openai') { + return { + ok: false, + error: { + type: 'unsupported_provider_type', + message: `Provider '${provider.id}' of type '${provider.type}' is not supported for OpenAI chat completions`, + code: 'unsupported_provider_type' } } + } - const modelData = Array.from(uniqueModels.values()) + const modelId = modelValidation.modelId! - logger.info(`Successfully retrieved ${modelData.length} unique models from ${models.length} total models`) + const client = new OpenAI({ + baseURL: provider.apiHost, + apiKey: provider.apiKey + }) - if (models.length > modelData.length) { - logger.debug(`Filtered out ${models.length - modelData.length} duplicate models`) + return { + ok: true, + provider, + modelId, + client + } + } + + async prepareRequest(request: ChatCompletionCreateParams, stream: boolean): Promise { + const requestValidation = this.validateRequest(request) + if (!requestValidation.isValid) { + return { + status: 'validation_error', + errors: requestValidation.errors } + } - return modelData - } catch (error: any) { - logger.error('Error getting models:', error) - return [] + const providerContext = await this.resolveProviderContext(request.model!) + if (!providerContext.ok) { + return { + status: 'model_error', + error: providerContext.error + } + } + + const { provider, modelId, client } = providerContext + + logger.debug('Model validation successful', { + provider: provider.id, + providerType: provider.type, + modelId, + fullModelId: request.model + }) + + return { + status: 'ok', + provider, + modelId, + client, + providerRequest: stream + ? { + ...request, + model: modelId, + stream: true as const + } + : { + ...request, + model: modelId, + stream: false as const + } } } validateRequest(request: ChatCompletionCreateParams): ValidationResult { const errors: string[] = [] - // Validate model - if (!request.model) { - errors.push('Model is required') - } else if (typeof request.model !== 'string') { - errors.push('Model must be a string') - } else if (!request.model.includes(':')) { - errors.push('Model must be in format "provider:model_id"') - } - // Validate messages if (!request.messages) { errors.push('Messages array is required') @@ -98,17 +147,6 @@ export class ChatCompletionService { } // Validate optional parameters - if (request.temperature !== undefined) { - if (typeof request.temperature !== 'number' || request.temperature < 0 || request.temperature > 2) { - errors.push('Temperature must be a number between 0 and 2') - } - } - - if (request.max_tokens !== undefined) { - if (typeof request.max_tokens !== 'number' || request.max_tokens < 1) { - errors.push('max_tokens must be a positive number') - } - } return { isValid: errors.length === 0, @@ -116,48 +154,30 @@ export class ChatCompletionService { } } - async processCompletion(request: ChatCompletionCreateParams): Promise { + async processCompletion(request: ChatCompletionCreateParams): Promise<{ + provider: Provider + modelId: string + response: OpenAI.Chat.Completions.ChatCompletion + }> { try { - logger.info('Processing chat completion request:', { + logger.debug('Processing chat completion request', { model: request.model, messageCount: request.messages.length, stream: request.stream }) - // Validate request - const validation = this.validateRequest(request) - if (!validation.isValid) { - throw new Error(`Request validation failed: ${validation.errors.join(', ')}`) + const preparation = await this.prepareRequest(request, false) + if (preparation.status === 'validation_error') { + throw new ChatCompletionValidationError(preparation.errors) } - // Get provider for the model - const provider = await getProviderByModel(request.model!) - if (!provider) { - throw new Error(`Provider not found for model: ${request.model}`) + if (preparation.status === 'model_error') { + throw new ChatCompletionModelError(preparation.error) } - // Validate provider - if (!validateProvider(provider)) { - throw new Error(`Provider validation failed for: ${provider.id}`) - } + const { provider, modelId, client, providerRequest } = preparation - // Extract model ID from the full model string - const modelId = getRealProviderModel(request.model) - - // Create OpenAI client for the provider - const client = new OpenAI({ - baseURL: provider.apiHost, - apiKey: provider.apiKey - }) - - // Prepare request with the actual model ID - const providerRequest = { - ...request, - model: modelId, - stream: false - } - - logger.debug('Sending request to provider:', { + logger.debug('Sending request to provider', { provider: provider.id, model: modelId, apiHost: provider.apiHost @@ -165,71 +185,71 @@ export class ChatCompletionService { const response = (await client.chat.completions.create(providerRequest)) as OpenAI.Chat.Completions.ChatCompletion - logger.info('Successfully processed chat completion') - return response + logger.info('Chat completion processed', { + modelId, + provider: provider.id + }) + return { + provider, + modelId, + response + } } catch (error: any) { - logger.error('Error processing chat completion:', error) + logger.error('Error processing chat completion', { + error, + model: request.model + }) throw error } } - async *processStreamingCompletion( - request: ChatCompletionCreateParams - ): AsyncIterable { + async processStreamingCompletion(request: ChatCompletionCreateParams): Promise<{ + provider: Provider + modelId: string + stream: AsyncIterable + }> { try { - logger.info('Processing streaming chat completion request:', { + logger.debug('Processing streaming chat completion request', { model: request.model, messageCount: request.messages.length }) - // Validate request - const validation = this.validateRequest(request) - if (!validation.isValid) { - throw new Error(`Request validation failed: ${validation.errors.join(', ')}`) + const preparation = await this.prepareRequest(request, true) + if (preparation.status === 'validation_error') { + throw new ChatCompletionValidationError(preparation.errors) } - // Get provider for the model - const provider = await getProviderByModel(request.model!) - if (!provider) { - throw new Error(`Provider not found for model: ${request.model}`) + if (preparation.status === 'model_error') { + throw new ChatCompletionModelError(preparation.error) } - // Validate provider - if (!validateProvider(provider)) { - throw new Error(`Provider validation failed for: ${provider.id}`) - } + const { provider, modelId, client, providerRequest } = preparation - // Extract model ID from the full model string - const modelId = getRealProviderModel(request.model) - - // Create OpenAI client for the provider - const client = new OpenAI({ - baseURL: provider.apiHost, - apiKey: provider.apiKey - }) - - // Prepare streaming request - const streamingRequest = { - ...request, - model: modelId, - stream: true as const - } - - logger.debug('Sending streaming request to provider:', { + logger.debug('Sending streaming request to provider', { provider: provider.id, model: modelId, apiHost: provider.apiHost }) - const stream = await client.chat.completions.create(streamingRequest) + const streamRequest = providerRequest as ChatCompletionCreateParamsStreaming + const stream = (await client.chat.completions.create( + streamRequest + )) as AsyncIterable - for await (const chunk of stream) { - yield chunk + logger.info('Streaming chat completion started', { + modelId, + provider: provider.id + }) + return { + provider, + modelId, + stream } - - logger.info('Successfully completed streaming chat completion') } catch (error: any) { - logger.error('Error processing streaming chat completion:', error) + logger.error('Error processing streaming chat completion', { + error, + model: request.model + }) throw error } } diff --git a/src/main/apiServer/services/mcp.ts b/src/main/apiServer/services/mcp.ts index c743077f1f..d75fadee6c 100644 --- a/src/main/apiServer/services/mcp.ts +++ b/src/main/apiServer/services/mcp.ts @@ -9,8 +9,7 @@ import type { Request, Response } from 'express' import type { IncomingMessage, ServerResponse } from 'http' import { loggerService } from '../../services/LoggerService' -import { reduxService } from '../../services/ReduxService' -import { getMcpServerById } from '../utils/mcp' +import { getMcpServerById, getMCPServersFromRedux } from '../utils/mcp' const logger = loggerService.withContext('MCPApiService') const transports: Record = {} @@ -46,42 +45,18 @@ class MCPApiService extends EventEmitter { constructor() { super() this.initMcpServer() - logger.silly('MCPApiService initialized') + logger.debug('MCPApiService initialized') } private initMcpServer() { this.transport.onmessage = this.onMessage } - /** - * Get servers directly from Redux store - */ - private async getServersFromRedux(): Promise { - try { - logger.silly('Getting servers from Redux store') - - // Try to get from cache first (faster) - const cachedServers = reduxService.selectSync('state.mcp.servers') - if (cachedServers && Array.isArray(cachedServers)) { - logger.silly(`Found ${cachedServers.length} servers in Redux cache`) - return cachedServers - } - - // If cache is not available, get fresh data - const servers = await reduxService.select('state.mcp.servers') - logger.silly(`Fetched ${servers?.length || 0} servers from Redux store`) - return servers || [] - } catch (error: any) { - logger.error('Failed to get servers from Redux:', error) - return [] - } - } - // get all activated servers async getAllServers(req: Request): Promise { try { - const servers = await this.getServersFromRedux() - logger.silly(`Returning ${servers.length} servers`) + const servers = await getMCPServersFromRedux() + logger.debug('Returning servers from Redux', { count: servers.length }) const resp: McpServersResp = { servers: {} } @@ -98,7 +73,7 @@ class MCPApiService extends EventEmitter { } return resp } catch (error: any) { - logger.error('Failed to get all servers:', error) + logger.error('Failed to get all servers', { error }) throw new Error('Failed to retrieve servers') } } @@ -106,87 +81,47 @@ class MCPApiService extends EventEmitter { // get server by id async getServerById(id: string): Promise { try { - logger.silly(`getServerById called with id: ${id}`) - const servers = await this.getServersFromRedux() + logger.debug('getServerById called', { id }) + const servers = await getMCPServersFromRedux() const server = servers.find((s) => s.id === id) if (!server) { - logger.warn(`Server with id ${id} not found`) + logger.warn('Server not found', { id }) return null } - logger.silly(`Returning server with id ${id}`) + logger.debug('Returning server', { id }) return server } catch (error: any) { - logger.error(`Failed to get server with id ${id}:`, error) + logger.error('Failed to get server', { id, error }) throw new Error('Failed to retrieve server') } } async getServerInfo(id: string): Promise { try { - logger.silly(`getServerInfo called with id: ${id}`) const server = await this.getServerById(id) if (!server) { - logger.warn(`Server with id ${id} not found`) + logger.warn('Server not found while fetching info', { id }) return null } - logger.silly(`Returning server info for id ${id}`) const client = await mcpService.initClient(server) const tools = await client.listTools() - - logger.info(`Server with id ${id} info:`, { tools: JSON.stringify(tools) }) - - // const [version, tools, prompts, resources] = await Promise.all([ - // () => { - // try { - // return client.getServerVersion() - // } catch (error) { - // logger.error(`Failed to get server version for id ${id}:`, { error: error }) - // return '1.0.0' - // } - // }, - // (() => { - // try { - // return client.listTools() - // } catch (error) { - // logger.error(`Failed to list tools for id ${id}:`, { error: error }) - // return [] - // } - // })(), - // (() => { - // try { - // return client.listPrompts() - // } catch (error) { - // logger.error(`Failed to list prompts for id ${id}:`, { error: error }) - // return [] - // } - // })(), - // (() => { - // try { - // return client.listResources() - // } catch (error) { - // logger.error(`Failed to list resources for id ${id}:`, { error: error }) - // return [] - // } - // })() - // ]) - return { id: server.id, name: server.name, type: server.type, description: server.description, - tools + tools: tools.tools } } catch (error: any) { - logger.error(`Failed to get server info with id ${id}:`, error) + logger.error('Failed to get server info', { id, error }) throw new Error('Failed to retrieve server info') } } async handleRequest(req: Request, res: Response, server: MCPServer) { const sessionId = req.headers['mcp-session-id'] as string | undefined - logger.silly(`Handling request for server with sessionId ${sessionId}`) + logger.debug('Handling MCP request', { sessionId, serverId: server.id }) let transport: StreamableHTTPServerTransport if (sessionId && transports[sessionId]) { transport = transports[sessionId] @@ -199,7 +134,7 @@ class MCPApiService extends EventEmitter { }) transport.onclose = () => { - logger.info(`Transport for sessionId ${sessionId} closed`) + logger.info('Transport closed', { sessionId }) if (transport.sessionId) { delete transports[transport.sessionId] } @@ -234,12 +169,15 @@ class MCPApiService extends EventEmitter { } } - logger.info(`Request body`, { rawBody: req.body, messages: JSON.stringify(messages) }) + logger.debug('Dispatching MCP request', { + sessionId: transport.sessionId ?? sessionId, + messageCount: messages.length + }) await transport.handleRequest(req as IncomingMessage, res as ServerResponse, messages) } private onMessage(message: JSONRPCMessage, extra?: MessageExtraInfo) { - logger.info(`Received message: ${JSON.stringify(message)}`, extra) + logger.debug('Received MCP message', { message, extra }) // Handle message here } } diff --git a/src/main/apiServer/services/messages.ts b/src/main/apiServer/services/messages.ts new file mode 100644 index 0000000000..edce9a9528 --- /dev/null +++ b/src/main/apiServer/services/messages.ts @@ -0,0 +1,321 @@ +import Anthropic from '@anthropic-ai/sdk' +import { MessageCreateParams, MessageStreamEvent } from '@anthropic-ai/sdk/resources' +import { loggerService } from '@logger' +import anthropicService from '@main/services/AnthropicService' +import { buildClaudeCodeSystemMessage, getSdkClient } from '@shared/anthropic' +import { Provider } from '@types' +import { Response } from 'express' + +const logger = loggerService.withContext('MessagesService') +const EXCLUDED_FORWARD_HEADERS: ReadonlySet = new Set([ + 'host', + 'x-api-key', + 'authorization', + 'sentry-trace', + 'baggage', + 'content-length', + 'connection' +]) + +export interface ValidationResult { + isValid: boolean + errors: string[] +} + +export interface ErrorResponse { + type: 'error' + error: { + type: string + message: string + requestId?: string + } +} + +export interface StreamConfig { + response: Response + onChunk?: (chunk: MessageStreamEvent) => void + onError?: (error: any) => void + onComplete?: () => void +} + +export interface ProcessMessageOptions { + provider: Provider + request: MessageCreateParams + extraHeaders?: Record + modelId?: string +} + +export interface ProcessMessageResult { + client: Anthropic + anthropicRequest: MessageCreateParams +} + +export class MessagesService { + validateRequest(request: MessageCreateParams): ValidationResult { + // TODO: Implement comprehensive request validation + const errors: string[] = [] + + if (!request.model || typeof request.model !== 'string') { + errors.push('Model is required') + } + + if (typeof request.max_tokens !== 'number' || !Number.isFinite(request.max_tokens) || request.max_tokens < 1) { + errors.push('max_tokens is required and must be a positive number') + } + + if (!request.messages || !Array.isArray(request.messages) || request.messages.length === 0) { + errors.push('messages is required and must be a non-empty array') + } else { + request.messages.forEach((message, index) => { + if (!message || typeof message !== 'object') { + errors.push(`messages[${index}] must be an object`) + return + } + + if (!('role' in message) || typeof message.role !== 'string' || message.role.trim().length === 0) { + errors.push(`messages[${index}].role is required`) + } + + const content: unknown = message.content + if (content === undefined || content === null) { + errors.push(`messages[${index}].content is required`) + return + } + + if (typeof content === 'string' && content.trim().length === 0) { + errors.push(`messages[${index}].content cannot be empty`) + } else if (Array.isArray(content) && content.length === 0) { + errors.push(`messages[${index}].content must include at least one item when using an array`) + } + }) + } + + return { + isValid: errors.length === 0, + errors + } + } + + async getClient(provider: Provider, extraHeaders?: Record): Promise { + // Create Anthropic client for the provider + if (provider.authType === 'oauth') { + const oauthToken = await anthropicService.getValidAccessToken() + return getSdkClient(provider, oauthToken, extraHeaders) + } + return getSdkClient(provider, null, extraHeaders) + } + + prepareHeaders(headers: Record): Record { + const extraHeaders: Record = {} + + for (const [key, value] of Object.entries(headers)) { + if (value === undefined) { + continue + } + + const normalizedKey = key.toLowerCase() + if (EXCLUDED_FORWARD_HEADERS.has(normalizedKey)) { + continue + } + + extraHeaders[normalizedKey] = value + } + + return extraHeaders + } + + createAnthropicRequest(request: MessageCreateParams, provider: Provider, modelId?: string): MessageCreateParams { + const anthropicRequest: MessageCreateParams = { + ...request, + stream: !!request.stream + } + + // Override model if provided + if (modelId) { + anthropicRequest.model = modelId + } + + // Add Claude Code system message for OAuth providers + if (provider.type === 'anthropic' && provider.authType === 'oauth') { + anthropicRequest.system = buildClaudeCodeSystemMessage(request.system) + } + + return anthropicRequest + } + + async handleStreaming( + client: Anthropic, + request: MessageCreateParams, + config: StreamConfig, + provider: Provider + ): Promise { + const { response, onChunk, onError, onComplete } = config + + // Set streaming headers + response.setHeader('Content-Type', 'text/event-stream; charset=utf-8') + response.setHeader('Cache-Control', 'no-cache, no-transform') + response.setHeader('Connection', 'keep-alive') + response.setHeader('X-Accel-Buffering', 'no') + response.flushHeaders() + + const flushableResponse = response as Response & { flush?: () => void } + const flushStream = () => { + if (typeof flushableResponse.flush !== 'function') { + return + } + try { + flushableResponse.flush() + } catch (flushError: unknown) { + logger.warn('Failed to flush streaming response', { error: flushError }) + } + } + + const writeSse = (eventType: string | undefined, payload: unknown) => { + if (response.writableEnded || response.destroyed) { + return + } + + if (eventType) { + response.write(`event: ${eventType}\n`) + } + + const data = typeof payload === 'string' ? payload : JSON.stringify(payload) + response.write(`data: ${data}\n\n`) + flushStream() + } + + try { + const stream = client.messages.stream(request) + for await (const chunk of stream) { + if (response.writableEnded || response.destroyed) { + logger.warn('Streaming response ended before stream completion', { + provider: provider.id, + model: request.model + }) + break + } + + writeSse(chunk.type, chunk) + + if (onChunk) { + onChunk(chunk) + } + } + writeSse(undefined, '[DONE]') + + if (onComplete) { + onComplete() + } + } catch (streamError: any) { + logger.error('Stream error', { + error: streamError, + provider: provider.id, + model: request.model, + apiHost: provider.apiHost, + anthropicApiHost: provider.anthropicApiHost + }) + writeSse(undefined, { + type: 'error', + error: { + type: 'api_error', + message: 'Stream processing error' + } + }) + + if (onError) { + onError(streamError) + } + } finally { + if (!response.writableEnded) { + response.end() + } + } + } + + transformError(error: any): { statusCode: number; errorResponse: ErrorResponse } { + let statusCode = 500 + let errorType = 'api_error' + let errorMessage = 'Internal server error' + + const anthropicStatus = typeof error?.status === 'number' ? error.status : undefined + const anthropicError = error?.error + + if (anthropicStatus) { + statusCode = anthropicStatus + } + + if (anthropicError?.type) { + errorType = anthropicError.type + } + + if (anthropicError?.message) { + errorMessage = anthropicError.message + } else if (error instanceof Error && error.message) { + errorMessage = error.message + } + + // Infer error type from message if not from Anthropic API + if (!anthropicStatus && error instanceof Error) { + const errorMessageText = error.message ?? '' + + if (errorMessageText.includes('API key') || errorMessageText.includes('authentication')) { + statusCode = 401 + errorType = 'authentication_error' + } else if (errorMessageText.includes('rate limit') || errorMessageText.includes('quota')) { + statusCode = 429 + errorType = 'rate_limit_error' + } else if (errorMessageText.includes('timeout') || errorMessageText.includes('connection')) { + statusCode = 502 + errorType = 'api_error' + } else if (errorMessageText.includes('validation') || errorMessageText.includes('invalid')) { + statusCode = 400 + errorType = 'invalid_request_error' + } + } + + const safeErrorMessage = + typeof errorMessage === 'string' && errorMessage.length > 0 ? errorMessage : 'Internal server error' + + return { + statusCode, + errorResponse: { + type: 'error', + error: { + type: errorType, + message: safeErrorMessage, + requestId: error?.request_id + } + } + } + } + + async processMessage(options: ProcessMessageOptions): Promise { + const { provider, request, extraHeaders, modelId } = options + + const client = await this.getClient(provider, extraHeaders) + const anthropicRequest = this.createAnthropicRequest(request, provider, modelId) + + const messageCount = Array.isArray(request.messages) ? request.messages.length : 0 + + logger.info('Processing anthropic messages request', { + provider: provider.id, + apiHost: provider.apiHost, + anthropicApiHost: provider.anthropicApiHost, + model: anthropicRequest.model, + stream: !!anthropicRequest.stream, + // systemPrompt: JSON.stringify(!!request.system), + // messages: JSON.stringify(request.messages), + messageCount, + toolCount: Array.isArray(request.tools) ? request.tools.length : 0 + }) + + // Return client and request for route layer to handle streaming/non-streaming + return { + client, + anthropicRequest + } + } +} + +// Export singleton instance +export const messagesService = new MessagesService() diff --git a/src/main/apiServer/services/models.ts b/src/main/apiServer/services/models.ts new file mode 100644 index 0000000000..684d7f10a8 --- /dev/null +++ b/src/main/apiServer/services/models.ts @@ -0,0 +1,108 @@ +import { ApiModel, ApiModelsFilter, ApiModelsResponse } from '../../../renderer/src/types/apiModels' +import { loggerService } from '../../services/LoggerService' +import { getAvailableProviders, listAllAvailableModels, transformModelToOpenAI } from '../utils' + +const logger = loggerService.withContext('ModelsService') + +// Re-export for backward compatibility + +export type ModelsFilter = ApiModelsFilter + +export class ModelsService { + async getModels(filter: ModelsFilter): Promise { + try { + logger.debug('Getting available models from providers', { filter }) + + let providers = await getAvailableProviders() + + if (filter.providerType === 'anthropic') { + providers = providers.filter( + (p) => p.type === 'anthropic' || (p.anthropicApiHost !== undefined && p.anthropicApiHost.trim() !== '') + ) + } + + const models = await listAllAvailableModels(providers) + // Use Map to deduplicate models by their full ID (provider:model_id) + const uniqueModels = new Map() + + for (const model of models) { + const provider = providers.find((p) => p.id === model.provider) + logger.debug(`Processing model ${model.id} from provider ${model.provider}`, { + isAnthropicModel: provider?.isAnthropicModel + }) + if ( + !provider || + (filter.providerType === 'anthropic' && provider.isAnthropicModel && !provider.isAnthropicModel(model)) + ) { + continue + } + // Special case: For "aihubmix", it should be covered by above condition, but just in case + if (provider.id === 'aihubmix' && filter.providerType === 'anthropic' && !model.id.includes('claude')) { + continue + } + + const openAIModel = transformModelToOpenAI(model, provider) + const fullModelId = openAIModel.id // This is already in format "provider:model_id" + + // Only add if not already present (first occurrence wins) + if (!uniqueModels.has(fullModelId)) { + uniqueModels.set(fullModelId, openAIModel) + } else { + logger.debug(`Skipping duplicate model: ${fullModelId}`) + } + } + + let modelData = Array.from(uniqueModels.values()) + const total = modelData.length + + // Apply pagination + const offset = filter?.offset || 0 + const limit = filter?.limit + + if (limit !== undefined) { + modelData = modelData.slice(offset, offset + limit) + logger.debug( + `Applied pagination: offset=${offset}, limit=${limit}, showing ${modelData.length} of ${total} models` + ) + } else if (offset > 0) { + modelData = modelData.slice(offset) + logger.debug(`Applied offset: offset=${offset}, showing ${modelData.length} of ${total} models`) + } + + logger.info('Models retrieved', { + returned: modelData.length, + discovered: models.length, + filter + }) + + if (models.length > total) { + logger.debug(`Filtered out ${models.length - total} models after deduplication and filtering`) + } + + const response: ApiModelsResponse = { + object: 'list', + data: modelData + } + + // Add pagination metadata if applicable + if (filter?.limit !== undefined || filter?.offset !== undefined) { + response.total = total + response.offset = offset + if (filter?.limit !== undefined) { + response.limit = filter.limit + } + } + + return response + } catch (error: any) { + logger.error('Error getting models', { error, filter }) + return { + object: 'list', + data: [] + } + } + } +} + +// Export singleton instance +export const modelsService = new ModelsService() diff --git a/src/main/apiServer/utils/createStreamAbortController.ts b/src/main/apiServer/utils/createStreamAbortController.ts new file mode 100644 index 0000000000..243ad5b96e --- /dev/null +++ b/src/main/apiServer/utils/createStreamAbortController.ts @@ -0,0 +1,64 @@ +export type StreamAbortHandler = (reason: unknown) => void + +export interface StreamAbortController { + abortController: AbortController + registerAbortHandler: (handler: StreamAbortHandler) => void + clearAbortTimeout: () => void +} + +export const STREAM_TIMEOUT_REASON = 'stream timeout' + +interface CreateStreamAbortControllerOptions { + timeoutMs: number +} + +export const createStreamAbortController = (options: CreateStreamAbortControllerOptions): StreamAbortController => { + const { timeoutMs } = options + const abortController = new AbortController() + const signal = abortController.signal + + let timeoutId: NodeJS.Timeout | undefined + let abortHandler: StreamAbortHandler | undefined + + const clearAbortTimeout = () => { + if (!timeoutId) { + return + } + clearTimeout(timeoutId) + timeoutId = undefined + } + + const handleAbort = () => { + clearAbortTimeout() + + if (!abortHandler) { + return + } + + abortHandler(signal.reason) + } + + signal.addEventListener('abort', handleAbort, { once: true }) + + const registerAbortHandler = (handler: StreamAbortHandler) => { + abortHandler = handler + + if (signal.aborted) { + abortHandler(signal.reason) + } + } + + if (timeoutMs > 0) { + timeoutId = setTimeout(() => { + if (!signal.aborted) { + abortController.abort(STREAM_TIMEOUT_REASON) + } + }, timeoutMs) + } + + return { + abortController, + registerAbortHandler, + clearAbortTimeout + } +} diff --git a/src/main/apiServer/utils/index.ts b/src/main/apiServer/utils/index.ts index 9db91282c0..720d3ccf21 100644 --- a/src/main/apiServer/utils/index.ts +++ b/src/main/apiServer/utils/index.ts @@ -1,46 +1,60 @@ +import { CacheService } from '@main/services/CacheService' import { loggerService } from '@main/services/LoggerService' import { reduxService } from '@main/services/ReduxService' -import type { Model, Provider } from '@types' +import type { ApiModel, Model, Provider } from '@types' const logger = loggerService.withContext('ApiServerUtils') -// OpenAI compatible model format -export interface OpenAICompatibleModel { - id: string - object: 'model' - created: number - owned_by: string - provider?: string - provider_model_id?: string -} +// Cache configuration +const PROVIDERS_CACHE_KEY = 'api-server:providers' +const PROVIDERS_CACHE_TTL = 10 * 1000 // 10 seconds export async function getAvailableProviders(): Promise { try { - // Wait for store to be ready before accessing providers + // Try to get from cache first (faster) + const cachedSupportedProviders = CacheService.get(PROVIDERS_CACHE_KEY) + if (cachedSupportedProviders && cachedSupportedProviders.length > 0) { + logger.debug('Providers resolved from cache', { + count: cachedSupportedProviders.length + }) + return cachedSupportedProviders + } + + // If cache is not available, get fresh data from Redux const providers = await reduxService.select('state.llm.providers') if (!providers || !Array.isArray(providers)) { - logger.warn('No providers found in Redux store, returning empty array') + logger.warn('No providers found in Redux store') return [] } - // Only support OpenAI type providers for API server - const openAIProviders = providers.filter((p: Provider) => p.enabled && p.type === 'openai') + // Support OpenAI and Anthropic type providers for API server + const supportedProviders = providers.filter( + (p: Provider) => p.enabled && (p.type === 'openai' || p.type === 'anthropic') + ) - logger.info(`Filtered to ${openAIProviders.length} OpenAI providers from ${providers.length} total providers`) + // Cache the filtered results + CacheService.set(PROVIDERS_CACHE_KEY, supportedProviders, PROVIDERS_CACHE_TTL) - return openAIProviders + logger.info('Providers filtered', { + supported: supportedProviders.length, + total: providers.length + }) + + return supportedProviders } catch (error: any) { - logger.error('Failed to get providers from Redux store:', error) + logger.error('Failed to get providers from Redux store', { error }) return [] } } -export async function listAllAvailableModels(): Promise { +export async function listAllAvailableModels(providers?: Provider[]): Promise { try { - const providers = await getAvailableProviders() + if (!providers) { + providers = await getAvailableProviders() + } return providers.map((p: Provider) => p.models || []).flat() } catch (error: any) { - logger.error('Failed to list available models:', error) + logger.error('Failed to list available models', { error }) return [] } } @@ -48,15 +62,13 @@ export async function listAllAvailableModels(): Promise { export async function getProviderByModel(model: string): Promise { try { if (!model || typeof model !== 'string') { - logger.warn(`Invalid model parameter: ${model}`) + logger.warn('Invalid model parameter', { model }) return undefined } // Validate model format first if (!model.includes(':')) { - logger.warn( - `Invalid model format, must contain ':' separator. Expected format "provider:model_id", got: ${model}` - ) + logger.warn('Invalid model format missing separator', { model }) return undefined } @@ -64,7 +76,7 @@ export async function getProviderByModel(model: string): Promise p.id === providerId) if (!provider) { - logger.warn( - `Provider '${providerId}' not found or not enabled. Available providers: ${providers.map((p) => p.id).join(', ')}` - ) + logger.warn('Provider not found for model', { + providerId, + available: providers.map((p) => p.id) + }) return undefined } - logger.debug(`Found provider '${providerId}' for model: ${model}`) + logger.debug('Provider resolved for model', { providerId, model }) return provider } catch (error: any) { - logger.error('Failed to get provider by model:', error) + logger.error('Failed to get provider by model', { error, model }) return undefined } } @@ -96,9 +109,12 @@ export interface ModelValidationError { code: string } -export async function validateModelId( - model: string -): Promise<{ valid: boolean; error?: ModelValidationError; provider?: Provider; modelId?: string }> { +export async function validateModelId(model: string): Promise<{ + valid: boolean + error?: ModelValidationError + provider?: Provider + modelId?: string +}> { try { if (!model || typeof model !== 'string') { return { @@ -169,7 +185,7 @@ export async function validateModelId( modelId } } catch (error: any) { - logger.error('Error validating model ID:', error) + logger.error('Error validating model ID', { error, model }) return { valid: false, error: { @@ -181,17 +197,47 @@ export async function validateModelId( } } -export function transformModelToOpenAI(model: Model): OpenAICompatibleModel { +export function transformModelToOpenAI(model: Model, provider?: Provider): ApiModel { + const providerDisplayName = provider?.name return { id: `${model.provider}:${model.id}`, object: 'model', + name: model.name, created: Math.floor(Date.now() / 1000), - owned_by: model.owned_by || model.provider, + owned_by: model.owned_by || providerDisplayName || model.provider, provider: model.provider, + provider_name: providerDisplayName, + provider_type: provider?.type, provider_model_id: model.id } } +export async function getProviderById(providerId: string): Promise { + try { + if (!providerId || typeof providerId !== 'string') { + logger.warn('Invalid provider ID parameter', { providerId }) + return undefined + } + + const providers = await getAvailableProviders() + const provider = providers.find((p: Provider) => p.id === providerId) + + if (!provider) { + logger.warn('Provider not found by ID', { + providerId, + available: providers.map((p) => p.id) + }) + return undefined + } + + logger.debug('Provider found by ID', { providerId }) + return provider + } catch (error: any) { + logger.error('Failed to get provider by ID', { error, providerId }) + return undefined + } +} + export function validateProvider(provider: Provider): boolean { try { if (!provider) { @@ -200,7 +246,7 @@ export function validateProvider(provider: Provider): boolean { // Check required fields if (!provider.id || !provider.type || !provider.apiKey || !provider.apiHost) { - logger.warn('Provider missing required fields:', { + logger.warn('Provider missing required fields', { id: !!provider.id, type: !!provider.type, apiKey: !!provider.apiKey, @@ -211,21 +257,25 @@ export function validateProvider(provider: Provider): boolean { // Check if provider is enabled if (!provider.enabled) { - logger.debug(`Provider is disabled: ${provider.id}`) + logger.debug('Provider is disabled', { providerId: provider.id }) return false } - // Only support OpenAI type providers - if (provider.type !== 'openai') { - logger.debug( - `Provider type '${provider.type}' not supported, only 'openai' type is currently supported: ${provider.id}` - ) + // Support OpenAI and Anthropic type providers + if (provider.type !== 'openai' && provider.type !== 'anthropic') { + logger.debug('Provider type not supported', { + providerId: provider.id, + providerType: provider.type + }) return false } return true } catch (error: any) { - logger.error('Error validating provider:', error) + logger.error('Error validating provider', { + error, + providerId: provider?.id + }) return false } } diff --git a/src/main/apiServer/utils/mcp.ts b/src/main/apiServer/utils/mcp.ts index 983d0ff706..f110df5847 100644 --- a/src/main/apiServer/utils/mcp.ts +++ b/src/main/apiServer/utils/mcp.ts @@ -1,3 +1,4 @@ +import { CacheService } from '@main/services/CacheService' import mcpService from '@main/services/MCPService' import { Server } from '@modelcontextprotocol/sdk/server/index.js' import type { ListToolsResult } from '@modelcontextprotocol/sdk/types.js' @@ -9,6 +10,10 @@ import { reduxService } from '../../services/ReduxService' const logger = loggerService.withContext('MCPApiService') +// Cache configuration +const MCP_SERVERS_CACHE_KEY = 'api-server:mcp-servers' +const MCP_SERVERS_CACHE_TTL = 5 * 60 * 1000 // 5 minutes + const cachedServers: Record = {} async function handleListToolsRequest(request: any, extra: any): Promise { @@ -34,20 +39,35 @@ async function handleCallToolRequest(request: any, extra: any): Promise { } async function getMcpServerConfigById(id: string): Promise { - const servers = await getServersFromRedux() + const servers = await getMCPServersFromRedux() return servers.find((s) => s.id === id || s.name === id) } /** * Get servers directly from Redux store */ -async function getServersFromRedux(): Promise { +export async function getMCPServersFromRedux(): Promise { try { + logger.debug('Getting servers from Redux store') + + // Try to get from cache first (faster) + const cachedServers = CacheService.get(MCP_SERVERS_CACHE_KEY) + if (cachedServers) { + logger.debug('MCP servers resolved from cache', { count: cachedServers.length }) + return cachedServers + } + + // If cache is not available, get fresh data from Redux const servers = await reduxService.select('state.mcp.servers') - logger.silly(`Fetched ${servers?.length || 0} servers from Redux store`) - return servers || [] + const serverList = servers || [] + + // Cache the results + CacheService.set(MCP_SERVERS_CACHE_KEY, serverList, MCP_SERVERS_CACHE_TTL) + + logger.debug('Fetched servers from Redux store', { count: serverList.length }) + return serverList } catch (error: any) { - logger.error('Failed to get servers from Redux:', error) + logger.error('Failed to get servers from Redux', { error }) return [] } } @@ -55,7 +75,7 @@ async function getServersFromRedux(): Promise { export async function getMcpServerById(id: string): Promise { const server = cachedServers[id] if (!server) { - const servers = await getServersFromRedux() + const servers = await getMCPServersFromRedux() const mcpServer = servers.find((s) => s.id === id || s.name === id) if (!mcpServer) { throw new Error(`Server not found: ${id}`) @@ -72,6 +92,6 @@ export async function getMcpServerById(id: string): Promise { cachedServers[id] = newServer return newServer } - logger.silly('getMcpServer ', { server: server }) + logger.debug('Returning cached MCP server', { id, hasHandlers: Boolean(server) }) return server } diff --git a/src/main/index.ts b/src/main/index.ts index 86a67c0728..f0c1c0966e 100644 --- a/src/main/index.ts +++ b/src/main/index.ts @@ -12,9 +12,13 @@ import { preferenceService } from '@data/PreferenceService' import { replaceDevtoolsFont } from '@main/utils/windowUtil' import { app, dialog } from 'electron' import installExtension, { REACT_DEVELOPER_TOOLS, REDUX_DEVTOOLS } from 'electron-devtools-installer' - import { isDev, isLinux, isWin } from './constant' + +import process from 'node:process' + import { registerIpc } from './ipc' +import { agentService } from './services/agents' +import { apiServerService } from './services/ApiServerService' import { configManager } from './services/ConfigManager' import mcpService from './services/MCPService' import { nodeTraceService } from './services/NodeTraceService' @@ -29,8 +33,6 @@ import { registerShortcuts } from './services/ShortcutService' import { TrayService } from './services/TrayService' import { windowService } from './services/WindowService' import { dataRefactorMigrateService } from './data/migrate/dataRefactor/DataRefactorMigrateService' -import process from 'node:process' -import { apiServerService } from './services/ApiServerService' import { dataApiService } from '@data/DataApiService' import { cacheService } from '@data/CacheService' @@ -226,6 +228,14 @@ if (!app.requestSingleInstanceLock()) { //start selection assistant service initSelectionService() + // Initialize Agent Service + try { + await agentService.initialize() + logger.info('Agent service initialized successfully') + } catch (error: any) { + logger.error('Failed to initialize Agent service:', error) + } + // Start API server if enabled try { const config = await apiServerService.getCurrentConfig() diff --git a/src/main/ipc.ts b/src/main/ipc.ts index 899ffcf9ea..f439bd563e 100644 --- a/src/main/ipc.ts +++ b/src/main/ipc.ts @@ -14,12 +14,21 @@ import type { SpanEntity, TokenUsage } from '@mcp-trace/trace-core' import { MIN_WINDOW_HEIGHT, MIN_WINDOW_WIDTH } from '@shared/config/constant' import type { UpgradeChannel } from '@shared/data/preference/preferenceTypes' import { IpcChannel } from '@shared/IpcChannel' -import type { FileMetadata, Notification, OcrProvider, Provider, Shortcut, SupportedOcrFile } from '@types' +import type { + FileMetadata, + Notification, + OcrProvider, + Provider, + Shortcut, + SupportedOcrFile, + AgentPersistedMessage +} from '@types' import checkDiskSpace from 'check-disk-space' import type { ProxyConfig } from 'electron' import { BrowserWindow, dialog, ipcMain, session, shell, systemPreferences, webContents } from 'electron' import fontList from 'font-list' +import { agentMessageRepository } from './services/agents/database' import { apiServerService } from './services/ApiServerService' import appService from './services/AppService' import AppUpdater from './services/AppUpdater' @@ -203,6 +212,27 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) { } }) + ipcMain.handle(IpcChannel.AgentMessage_PersistExchange, async (_event, payload) => { + try { + return await agentMessageRepository.persistExchange(payload) + } catch (error) { + logger.error('Failed to persist agent session messages', error as Error) + throw error + } + }) + + ipcMain.handle( + IpcChannel.AgentMessage_GetHistory, + async (_event, { sessionId }: { sessionId: string }): Promise => { + try { + return await agentMessageRepository.getSessionHistory(sessionId) + } catch (error) { + logger.error('Failed to get agent session history', error as Error) + throw error + } + } + ) + //only for mac if (isMac) { ipcMain.handle(IpcChannel.App_MacIsProcessTrusted, (): boolean => { @@ -499,6 +529,7 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) { ipcMain.handle(IpcChannel.File_ValidateNotesDirectory, fileManager.validateNotesDirectory.bind(fileManager)) ipcMain.handle(IpcChannel.File_StartWatcher, fileManager.startFileWatcher.bind(fileManager)) ipcMain.handle(IpcChannel.File_StopWatcher, fileManager.stopFileWatcher.bind(fileManager)) + ipcMain.handle(IpcChannel.File_ShowInFolder, fileManager.showInFolder.bind(fileManager)) // file service ipcMain.handle(IpcChannel.FileService_Upload, async (_, provider: Provider, file: FileMetadata) => { diff --git a/src/main/services/FileStorage.ts b/src/main/services/FileStorage.ts index 8f29f2d7b3..00dda778be 100644 --- a/src/main/services/FileStorage.ts +++ b/src/main/services/FileStorage.ts @@ -719,7 +719,10 @@ class FileStorage { } public openPath = async (_: Electron.IpcMainInvokeEvent, path: string): Promise => { - shell.openPath(path).catch((err) => logger.error('[IPC - Error] Failed to open file:', err)) + const resolved = await shell.openPath(path) + if (resolved !== '') { + throw new Error(resolved) + } } /** @@ -1223,6 +1226,19 @@ class FileStorage { return false } } + + public showInFolder = async (_: Electron.IpcMainInvokeEvent, path: string): Promise => { + if (!fs.existsSync(path)) { + const msg = `File or folder does not exist: ${path}` + logger.error(msg) + throw new Error(msg) + } + try { + shell.showItemInFolder(path) + } catch (error) { + logger.error('Failed to show item in folder:', error as Error) + } + } } export const fileStorage = new FileStorage() diff --git a/src/main/services/MCPService.ts b/src/main/services/MCPService.ts index 9da600b97b..ed19a3a88a 100644 --- a/src/main/services/MCPService.ts +++ b/src/main/services/MCPService.ts @@ -8,6 +8,7 @@ import { createInMemoryMCPServer } from '@main/mcpServers/factory' import { makeSureDirExists, removeEnvProxy } from '@main/utils' import { buildFunctionCallToolName } from '@main/utils/mcp' import { getBinaryName, getBinaryPath } from '@main/utils/process' +import getLoginShellEnvironment from '@main/utils/shell-env' import { TraceMethod, withSpanFunc } from '@mcp-trace/trace-core' import { Client } from '@modelcontextprotocol/sdk/client/index.js' import type { SSEClientTransportOptions } from '@modelcontextprotocol/sdk/client/sse.js' @@ -45,13 +46,11 @@ import { } from '@types' import { app, net } from 'electron' import { EventEmitter } from 'events' -import { memoize } from 'lodash' import { v4 as uuidv4 } from 'uuid' import DxtService from './DxtService' import { CallBackServer } from './mcp/oauth/callback' import { McpOAuthClientProvider } from './mcp/oauth/provider' -import getLoginShellEnvironment from './mcp/shell-env' import { windowService } from './WindowService' // Generic type for caching wrapped functions @@ -336,7 +335,7 @@ class McpService { getServerLogger(server).debug(`Starting server`, { command: cmd, args }) // Logger.info(`[MCP] Environment variables for server:`, server.env) - const loginShellEnv = await this.getLoginShellEnv() + const loginShellEnv = await getLoginShellEnvironment() // Bun not support proxy https://github.com/oven-sh/bun/issues/16812 if (cmd.includes('bun')) { @@ -879,20 +878,6 @@ class McpService { return await cachedGetResource(server, uri) } - private getLoginShellEnv = memoize(async (): Promise> => { - try { - const loginEnv = await getLoginShellEnvironment() - const pathSeparator = process.platform === 'win32' ? ';' : ':' - const cherryBinPath = path.join(os.homedir(), '.cherrystudio', 'bin') - loginEnv.PATH = `${loginEnv.PATH}${pathSeparator}${cherryBinPath}` - logger.debug('Successfully fetched login shell environment variables:') - return loginEnv - } catch (error) { - logger.error('Failed to fetch login shell environment variables:', error as Error) - return {} - } - }) - // 实现 abortTool 方法 public async abortTool(_: Electron.IpcMainInvokeEvent, callId: string) { const activeToolCall = this.activeToolCalls.get(callId) diff --git a/src/main/services/agents/BaseService.ts b/src/main/services/agents/BaseService.ts new file mode 100644 index 0000000000..86d4aef52c --- /dev/null +++ b/src/main/services/agents/BaseService.ts @@ -0,0 +1,336 @@ +import { type Client, createClient } from '@libsql/client' +import { loggerService } from '@logger' +import { mcpApiService } from '@main/apiServer/services/mcp' +import { ModelValidationError, validateModelId } from '@main/apiServer/utils' +import { AgentType, MCPTool, objectKeys, SlashCommand, Tool } from '@types' +import { drizzle, type LibSQLDatabase } from 'drizzle-orm/libsql' +import fs from 'fs' +import path from 'path' + +import { MigrationService } from './database/MigrationService' +import * as schema from './database/schema' +import { dbPath } from './drizzle.config' +import { AgentModelField, AgentModelValidationError } from './errors' +import { builtinSlashCommands } from './services/claudecode/commands' +import { builtinTools } from './services/claudecode/tools' + +const logger = loggerService.withContext('BaseService') + +/** + * Base service class providing shared database connection and utilities + * for all agent-related services. + * + * Features: + * - Programmatic schema management (no CLI dependencies) + * - Automatic table creation and migration + * - Schema version tracking and compatibility checks + * - Transaction-based operations for safety + * - Development vs production mode handling + * - Connection retry logic with exponential backoff + */ +export abstract class BaseService { + protected static client: Client | null = null + protected static db: LibSQLDatabase | null = null + protected static isInitialized = false + protected static initializationPromise: Promise | null = null + protected jsonFields: string[] = ['tools', 'mcps', 'configuration', 'accessible_paths', 'allowed_tools'] + + /** + * Initialize database with retry logic and proper error handling + */ + protected static async initialize(): Promise { + // Return existing initialization if in progress + if (BaseService.initializationPromise) { + return BaseService.initializationPromise + } + + if (BaseService.isInitialized) { + return + } + + BaseService.initializationPromise = BaseService.performInitialization() + return BaseService.initializationPromise + } + + public async listMcpTools(agentType: AgentType, ids?: string[]): Promise { + const tools: Tool[] = [] + if (agentType === 'claude-code') { + tools.push(...builtinTools) + } + if (ids && ids.length > 0) { + for (const id of ids) { + try { + const server = await mcpApiService.getServerInfo(id) + if (server) { + server.tools.forEach((tool: MCPTool) => { + tools.push({ + id: `mcp_${id}_${tool.name}`, + name: tool.name, + type: 'mcp', + description: tool.description || '', + requirePermissions: true + }) + }) + } + } catch (error) { + logger.warn('Failed to list MCP tools', { + id, + error: error as Error + }) + } + } + } + + return tools + } + + public async listSlashCommands(agentType: AgentType): Promise { + if (agentType === 'claude-code') { + return builtinSlashCommands + } + return [] + } + + private static async performInitialization(): Promise { + const maxRetries = 3 + let lastError: Error + + for (let attempt = 1; attempt <= maxRetries; attempt++) { + try { + logger.info(`Initializing Agent database at: ${dbPath} (attempt ${attempt}/${maxRetries})`) + + // Ensure the database directory exists + const dbDir = path.dirname(dbPath) + if (!fs.existsSync(dbDir)) { + logger.info(`Creating database directory: ${dbDir}`) + fs.mkdirSync(dbDir, { recursive: true }) + } + + BaseService.client = createClient({ + url: `file:${dbPath}` + }) + + BaseService.db = drizzle(BaseService.client, { schema }) + + // Run database migrations + const migrationService = new MigrationService(BaseService.db, BaseService.client) + await migrationService.runMigrations() + + BaseService.isInitialized = true + logger.info('Agent database initialized successfully') + return + } catch (error) { + lastError = error as Error + logger.warn(`Database initialization attempt ${attempt} failed:`, lastError) + + // Clean up on failure + if (BaseService.client) { + try { + BaseService.client.close() + } catch (closeError) { + logger.warn('Failed to close client during cleanup:', closeError as Error) + } + } + BaseService.client = null + BaseService.db = null + + // Wait before retrying (exponential backoff) + if (attempt < maxRetries) { + const delay = Math.pow(2, attempt) * 1000 // 2s, 4s, 8s + logger.info(`Retrying in ${delay}ms...`) + await new Promise((resolve) => setTimeout(resolve, delay)) + } + } + } + + // All retries failed + BaseService.initializationPromise = null + logger.error('Failed to initialize Agent database after all retries:', lastError!) + throw lastError! + } + + protected ensureInitialized(): void { + if (!BaseService.isInitialized || !BaseService.db || !BaseService.client) { + throw new Error('Database not initialized. Call initialize() first.') + } + } + + protected get database(): LibSQLDatabase { + this.ensureInitialized() + return BaseService.db! + } + + protected get rawClient(): Client { + this.ensureInitialized() + return BaseService.client! + } + + protected serializeJsonFields(data: any): any { + const serialized = { ...data } + + for (const field of this.jsonFields) { + if (serialized[field] !== undefined) { + serialized[field] = + Array.isArray(serialized[field]) || typeof serialized[field] === 'object' + ? JSON.stringify(serialized[field]) + : serialized[field] + } + } + + return serialized + } + + protected deserializeJsonFields(data: any): any { + if (!data) return data + + const deserialized = { ...data } + + for (const field of this.jsonFields) { + if (deserialized[field] && typeof deserialized[field] === 'string') { + try { + deserialized[field] = JSON.parse(deserialized[field]) + } catch (error) { + logger.warn(`Failed to parse JSON field ${field}:`, error as Error) + } + } + } + + // convert null from db to undefined to satisfy type definition + for (const key of objectKeys(data)) { + if (deserialized[key] === null) { + deserialized[key] = undefined + } + } + + return deserialized + } + + /** + * Validate, normalize, and ensure filesystem access for a set of absolute paths. + * + * - Requires every entry to be an absolute path and throws if not. + * - Normalizes each path and deduplicates while preserving order. + * - Creates missing directories (or parent directories for file-like paths). + */ + protected ensurePathsExist(paths?: string[]): string[] { + if (!paths?.length) { + return [] + } + + const sanitizedPaths: string[] = [] + const seenPaths = new Set() + + for (const rawPath of paths) { + if (!rawPath) { + continue + } + + if (!path.isAbsolute(rawPath)) { + throw new Error(`Accessible path must be absolute: ${rawPath}`) + } + + // Normalize to provide consistent values to downstream consumers. + const resolvedPath = path.normalize(rawPath) + + let stats: fs.Stats | null = null + try { + // Attempt to stat the path to understand whether it already exists and if it is a file. + if (fs.existsSync(resolvedPath)) { + stats = fs.statSync(resolvedPath) + } + } catch (error) { + logger.warn('Failed to inspect accessible path', { + path: rawPath, + error: error instanceof Error ? error.message : String(error) + }) + } + + const looksLikeFile = + (stats && stats.isFile()) || (!stats && path.extname(resolvedPath) !== '' && !resolvedPath.endsWith(path.sep)) + + // For file-like targets create the parent directory; otherwise ensure the directory itself. + const directoryToEnsure = looksLikeFile ? path.dirname(resolvedPath) : resolvedPath + + if (!fs.existsSync(directoryToEnsure)) { + try { + fs.mkdirSync(directoryToEnsure, { recursive: true }) + } catch (error) { + logger.error('Failed to create accessible path directory', { + path: directoryToEnsure, + error: error instanceof Error ? error.message : String(error) + }) + throw error + } + } + + // Preserve the first occurrence only to avoid duplicates while keeping caller order stable. + if (!seenPaths.has(resolvedPath)) { + seenPaths.add(resolvedPath) + sanitizedPaths.push(resolvedPath) + } + } + + return sanitizedPaths + } + + /** + * Force re-initialization (for development/testing) + */ + protected async validateAgentModels( + agentType: AgentType, + models: Partial> + ): Promise { + const entries = Object.entries(models) as [AgentModelField, string | undefined][] + if (entries.length === 0) { + return + } + + for (const [field, rawValue] of entries) { + if (rawValue === undefined || rawValue === null) { + continue + } + + const modelValue = rawValue + const validation = await validateModelId(modelValue) + + if (!validation.valid || !validation.provider) { + const detail: ModelValidationError = validation.error ?? { + type: 'invalid_format', + message: 'Unknown model validation error', + code: 'validation_error' + } + + throw new AgentModelValidationError({ agentType, field, model: modelValue }, detail) + } + + if (!validation.provider.apiKey) { + throw new AgentModelValidationError( + { agentType, field, model: modelValue }, + { + type: 'invalid_format', + message: `Provider '${validation.provider.id}' is missing an API key`, + code: 'provider_api_key_missing' + } + ) + } + } + } + + static async reinitialize(): Promise { + BaseService.isInitialized = false + BaseService.initializationPromise = null + + if (BaseService.client) { + try { + BaseService.client.close() + } catch (error) { + logger.warn('Failed to close client during reinitialize:', error as Error) + } + } + + BaseService.client = null + BaseService.db = null + + await BaseService.initialize() + } +} diff --git a/src/main/services/agents/README.md b/src/main/services/agents/README.md new file mode 100644 index 0000000000..986ac8b8df --- /dev/null +++ b/src/main/services/agents/README.md @@ -0,0 +1,81 @@ +# Agents Service + +Simplified Drizzle ORM implementation for agent and session management in Cherry Studio. + +## Features + +- **Native Drizzle migrations** - Uses built-in migrate() function +- **Zero CLI dependencies** in production +- **Auto-initialization** with retry logic +- **Full TypeScript** type safety +- **Model validation** to ensure models exist and provider configuration matches the agent type + +## Schema + +- `agents.schema.ts` - Agent definitions +- `sessions.schema.ts` - Session and message tables +- `migrations.schema.ts` - Migration tracking + +## Usage + +```typescript +import { agentService } from './services' + +// Create agent - fully typed +const agent = await agentService.createAgent({ + type: 'custom', + name: 'My Agent', + model: 'anthropic:claude-3-5-sonnet-20241022' +}) +``` + +## Model Validation + +- Model identifiers must use the `provider:model_id` format (for example `anthropic:claude-3-5-sonnet-20241022`). +- `model`, `plan_model`, and `small_model` are validated against the configured providers before the database is touched. +- Invalid configurations return a `400 invalid_request_error` response and the create/update operation is aborted. + +## Development Commands + +```bash +# Apply schema changes +yarn agents:generate + +# Quick development sync +yarn agents:push + +# Database tools +yarn agents:studio # Open Drizzle Studio +yarn agents:health # Health check +yarn agents:drop # Reset database +``` + +## Workflow + +1. **Edit schema** in `/database/schema/` +2. **Generate migration** with `yarn agents:generate` +3. **Test changes** with `yarn agents:health` +4. **Deploy** - migrations apply automatically + +## Services + +- `AgentService` - Agent CRUD operations +- `SessionService` - Session management +- `SessionMessageService` - Message logging +- `BaseService` - Database utilities +- `schemaSyncer` - Migration handler + +## Troubleshooting + +```bash +# Check status +yarn agents:health + +# Apply migrations +yarn agents:migrate + +# Reset completely +yarn agents:reset --yes +``` + +The simplified migration system reduced complexity from 463 to ~30 lines while maintaining all functionality through Drizzle's native migration system. diff --git a/src/main/services/agents/database/MigrationService.ts b/src/main/services/agents/database/MigrationService.ts new file mode 100644 index 0000000000..fce09bc68b --- /dev/null +++ b/src/main/services/agents/database/MigrationService.ts @@ -0,0 +1,161 @@ +import { type Client } from '@libsql/client' +import { loggerService } from '@logger' +import { getResourcePath } from '@main/utils' +import { type LibSQLDatabase } from 'drizzle-orm/libsql' +import fs from 'fs' +import path from 'path' + +import * as schema from './schema' +import { migrations, type NewMigration } from './schema/migrations.schema' + +const logger = loggerService.withContext('MigrationService') + +interface MigrationJournal { + version: string + dialect: string + entries: Array<{ + idx: number + version: string + when: number + tag: string + breakpoints: boolean + }> +} + +export class MigrationService { + private db: LibSQLDatabase + private client: Client + private migrationDir: string + + constructor(db: LibSQLDatabase, client: Client) { + this.db = db + this.client = client + this.migrationDir = path.join(getResourcePath(), 'database', 'drizzle') + } + + async runMigrations(): Promise { + try { + logger.info('Starting migration check...') + + const hasMigrationsTable = await this.migrationsTableExists() + + if (!hasMigrationsTable) { + logger.info('Migrations table not found; assuming fresh database state') + } + + // Read migration journal + const journal = await this.readMigrationJournal() + if (!journal.entries.length) { + logger.info('No migrations found in journal') + return + } + + // Get applied migrations + const appliedMigrations = hasMigrationsTable ? await this.getAppliedMigrations() : [] + const appliedVersions = new Set(appliedMigrations.map((m) => Number(m.version))) + + const latestAppliedVersion = appliedMigrations.reduce( + (max, migration) => Math.max(max, Number(migration.version)), + 0 + ) + const latestJournalVersion = journal.entries.reduce((max, entry) => Math.max(max, entry.idx), 0) + + logger.info(`Latest applied migration: v${latestAppliedVersion}, latest available: v${latestJournalVersion}`) + + // Find pending migrations (compare journal idx with stored version, which is the same value) + const pendingMigrations = journal.entries + .filter((entry) => !appliedVersions.has(entry.idx)) + .sort((a, b) => a.idx - b.idx) + + if (pendingMigrations.length === 0) { + logger.info('Database is up to date') + return + } + + logger.info(`Found ${pendingMigrations.length} pending migrations`) + + // Execute pending migrations + for (const migration of pendingMigrations) { + await this.executeMigration(migration) + } + + logger.info('All migrations completed successfully') + } catch (error) { + logger.error('Migration failed:', { error }) + throw error + } + } + + private async migrationsTableExists(): Promise { + try { + const table = await this.client.execute(`SELECT name FROM sqlite_master WHERE type='table' AND name='migrations'`) + return table.rows.length > 0 + } catch (error) { + logger.error('Failed to check migrations table status:', { error }) + throw error + } + } + + private async readMigrationJournal(): Promise { + const journalPath = path.join(this.migrationDir, 'meta', '_journal.json') + + if (!fs.existsSync(journalPath)) { + logger.warn('Migration journal not found:', { journalPath }) + return { version: '7', dialect: 'sqlite', entries: [] } + } + + try { + const journalContent = fs.readFileSync(journalPath, 'utf-8') + return JSON.parse(journalContent) + } catch (error) { + logger.error('Failed to read migration journal:', { error }) + throw error + } + } + + private async getAppliedMigrations(): Promise { + try { + return await this.db.select().from(migrations) + } catch (error) { + // This should not happen since we ensure the table exists in runMigrations() + logger.error('Failed to query applied migrations:', { error }) + throw error + } + } + + private async executeMigration(migration: MigrationJournal['entries'][0]): Promise { + const sqlFilePath = path.join(this.migrationDir, `${migration.tag}.sql`) + + if (!fs.existsSync(sqlFilePath)) { + throw new Error(`Migration SQL file not found: ${sqlFilePath}`) + } + + try { + logger.info(`Executing migration ${migration.tag}...`) + const startTime = Date.now() + + // Read and execute SQL + const sqlContent = fs.readFileSync(sqlFilePath, 'utf-8') + await this.client.executeMultiple(sqlContent) + + // Record migration as applied (store journal idx as version for tracking) + const newMigration: NewMigration = { + version: migration.idx, + tag: migration.tag, + executedAt: Date.now() + } + + if (!(await this.migrationsTableExists())) { + throw new Error('Migrations table missing after executing migration; cannot record progress') + } + + await this.db.insert(migrations).values(newMigration) + + const executionTime = Date.now() - startTime + logger.info(`Migration ${migration.tag} completed in ${executionTime}ms`) + } catch (error) { + logger.error(`Migration ${migration.tag} failed:`, { error }) + throw error + } + } +} diff --git a/src/main/services/agents/database/index.ts b/src/main/services/agents/database/index.ts new file mode 100644 index 0000000000..61b3a9ffcc --- /dev/null +++ b/src/main/services/agents/database/index.ts @@ -0,0 +1,14 @@ +/** + * Database Module + * + * This module provides centralized access to Drizzle ORM schemas + * for type-safe database operations. + * + * Schema evolution is handled by Drizzle Kit migrations. + */ + +// Drizzle ORM schemas +export * from './schema' + +// Repository helpers +export * from './sessionMessageRepository' diff --git a/src/main/services/agents/database/schema/agents.schema.ts b/src/main/services/agents/database/schema/agents.schema.ts new file mode 100644 index 0000000000..be8983c1fc --- /dev/null +++ b/src/main/services/agents/database/schema/agents.schema.ts @@ -0,0 +1,35 @@ +/** + * Drizzle ORM schema for agents table + */ + +import { index, sqliteTable, text } from 'drizzle-orm/sqlite-core' + +export const agentsTable = sqliteTable('agents', { + id: text('id').primaryKey(), + type: text('type').notNull(), + name: text('name').notNull(), + description: text('description'), + accessible_paths: text('accessible_paths'), // JSON array of directory paths the agent can access + + instructions: text('instructions'), + + model: text('model').notNull(), // Main model ID (required) + plan_model: text('plan_model'), // Optional plan/thinking model ID + small_model: text('small_model'), // Optional small/fast model ID + + mcps: text('mcps'), // JSON array of MCP tool IDs + allowed_tools: text('allowed_tools'), // JSON array of allowed tool IDs (whitelist) + + configuration: text('configuration'), // JSON, extensible settings + + created_at: text('created_at').notNull(), + updated_at: text('updated_at').notNull() +}) + +// Indexes for agents table +export const agentsNameIdx = index('idx_agents_name').on(agentsTable.name) +export const agentsTypeIdx = index('idx_agents_type').on(agentsTable.type) +export const agentsCreatedAtIdx = index('idx_agents_created_at').on(agentsTable.created_at) + +export type AgentRow = typeof agentsTable.$inferSelect +export type InsertAgentRow = typeof agentsTable.$inferInsert diff --git a/src/main/services/agents/database/schema/index.ts b/src/main/services/agents/database/schema/index.ts new file mode 100644 index 0000000000..553f94a038 --- /dev/null +++ b/src/main/services/agents/database/schema/index.ts @@ -0,0 +1,8 @@ +/** + * Drizzle ORM schema exports + */ + +export * from './agents.schema' +export * from './messages.schema' +export * from './migrations.schema' +export * from './sessions.schema' diff --git a/src/main/services/agents/database/schema/messages.schema.ts b/src/main/services/agents/database/schema/messages.schema.ts new file mode 100644 index 0000000000..d14c755014 --- /dev/null +++ b/src/main/services/agents/database/schema/messages.schema.ts @@ -0,0 +1,30 @@ +import { foreignKey, index, integer, sqliteTable, text } from 'drizzle-orm/sqlite-core' + +import { sessionsTable } from './sessions.schema' + +// session_messages table to log all messages, thoughts, actions, observations in a session +export const sessionMessagesTable = sqliteTable('session_messages', { + id: integer('id').primaryKey({ autoIncrement: true }), + session_id: text('session_id').notNull(), + role: text('role').notNull(), // 'user', 'agent', 'system', 'tool' + content: text('content').notNull(), // JSON structured data + agent_session_id: text('agent_session_id').default(''), + metadata: text('metadata'), // JSON metadata (optional) + created_at: text('created_at').notNull(), + updated_at: text('updated_at').notNull() +}) + +// Indexes for session_messages table +export const sessionMessagesSessionIdIdx = index('idx_session_messages_session_id').on(sessionMessagesTable.session_id) +export const sessionMessagesCreatedAtIdx = index('idx_session_messages_created_at').on(sessionMessagesTable.created_at) +export const sessionMessagesUpdatedAtIdx = index('idx_session_messages_updated_at').on(sessionMessagesTable.updated_at) + +// Foreign keys for session_messages table +export const sessionMessagesFkSession = foreignKey({ + columns: [sessionMessagesTable.session_id], + foreignColumns: [sessionsTable.id], + name: 'fk_session_messages_session_id' +}).onDelete('cascade') + +export type SessionMessageRow = typeof sessionMessagesTable.$inferSelect +export type InsertSessionMessageRow = typeof sessionMessagesTable.$inferInsert diff --git a/src/main/services/agents/database/schema/migrations.schema.ts b/src/main/services/agents/database/schema/migrations.schema.ts new file mode 100644 index 0000000000..ab0ad17b90 --- /dev/null +++ b/src/main/services/agents/database/schema/migrations.schema.ts @@ -0,0 +1,14 @@ +/** + * Migration tracking schema + */ + +import { integer, sqliteTable, text } from 'drizzle-orm/sqlite-core' + +export const migrations = sqliteTable('migrations', { + version: integer('version').primaryKey(), + tag: text('tag').notNull(), + executedAt: integer('executed_at').notNull() +}) + +export type Migration = typeof migrations.$inferSelect +export type NewMigration = typeof migrations.$inferInsert diff --git a/src/main/services/agents/database/schema/sessions.schema.ts b/src/main/services/agents/database/schema/sessions.schema.ts new file mode 100644 index 0000000000..21ac2fe2c6 --- /dev/null +++ b/src/main/services/agents/database/schema/sessions.schema.ts @@ -0,0 +1,45 @@ +/** + * Drizzle ORM schema for sessions and session_logs tables + */ + +import { foreignKey, index, sqliteTable, text } from 'drizzle-orm/sqlite-core' + +import { agentsTable } from './agents.schema' + +export const sessionsTable = sqliteTable('sessions', { + id: text('id').primaryKey(), + agent_type: text('agent_type').notNull(), + agent_id: text('agent_id').notNull(), // Primary agent ID for the session + name: text('name').notNull(), + description: text('description'), + accessible_paths: text('accessible_paths'), // JSON array of directory paths the agent can access + + instructions: text('instructions'), + + model: text('model').notNull(), // Main model ID (required) + plan_model: text('plan_model'), // Optional plan/thinking model ID + small_model: text('small_model'), // Optional small/fast model ID + + mcps: text('mcps'), // JSON array of MCP tool IDs + allowed_tools: text('allowed_tools'), // JSON array of allowed tool IDs (whitelist) + + configuration: text('configuration'), // JSON, extensible settings + + created_at: text('created_at').notNull(), + updated_at: text('updated_at').notNull() +}) + +// Foreign keys for sessions table +export const sessionsFkAgent = foreignKey({ + columns: [sessionsTable.agent_id], + foreignColumns: [agentsTable.id], + name: 'fk_session_agent_id' +}).onDelete('cascade') + +// Indexes for sessions table +export const sessionsCreatedAtIdx = index('idx_sessions_created_at').on(sessionsTable.created_at) +export const sessionsMainAgentIdIdx = index('idx_sessions_agent_id').on(sessionsTable.agent_id) +export const sessionsModelIdx = index('idx_sessions_model').on(sessionsTable.model) + +export type SessionRow = typeof sessionsTable.$inferSelect +export type InsertSessionRow = typeof sessionsTable.$inferInsert diff --git a/src/main/services/agents/database/sessionMessageRepository.ts b/src/main/services/agents/database/sessionMessageRepository.ts new file mode 100644 index 0000000000..4567c61ec0 --- /dev/null +++ b/src/main/services/agents/database/sessionMessageRepository.ts @@ -0,0 +1,257 @@ +import { loggerService } from '@logger' +import type { + AgentMessageAssistantPersistPayload, + AgentMessagePersistExchangePayload, + AgentMessagePersistExchangeResult, + AgentMessageUserPersistPayload, + AgentPersistedMessage, + AgentSessionMessageEntity +} from '@types' +import { and, asc, eq } from 'drizzle-orm' + +import { BaseService } from '../BaseService' +import type { InsertSessionMessageRow, SessionMessageRow } from './schema' +import { sessionMessagesTable } from './schema' + +const logger = loggerService.withContext('AgentMessageRepository') + +type TxClient = any + +export type PersistUserMessageParams = AgentMessageUserPersistPayload & { + sessionId: string + agentSessionId?: string + tx?: TxClient +} + +export type PersistAssistantMessageParams = AgentMessageAssistantPersistPayload & { + sessionId: string + agentSessionId: string + tx?: TxClient +} + +type PersistExchangeParams = AgentMessagePersistExchangePayload & { + tx?: TxClient +} + +type PersistExchangeResult = AgentMessagePersistExchangeResult + +class AgentMessageRepository extends BaseService { + private static instance: AgentMessageRepository | null = null + + static getInstance(): AgentMessageRepository { + if (!AgentMessageRepository.instance) { + AgentMessageRepository.instance = new AgentMessageRepository() + } + + return AgentMessageRepository.instance + } + + private serializeMessage(payload: AgentPersistedMessage): string { + return JSON.stringify(payload) + } + + private serializeMetadata(metadata?: Record): string | undefined { + if (!metadata) { + return undefined + } + + try { + return JSON.stringify(metadata) + } catch (error) { + logger.warn('Failed to serialize session message metadata', error as Error) + return undefined + } + } + + private deserialize(row: any): AgentSessionMessageEntity { + if (!row) return row + + const deserialized = { ...row } + + if (typeof deserialized.content === 'string') { + try { + deserialized.content = JSON.parse(deserialized.content) + } catch (error) { + logger.warn('Failed to parse session message content JSON', error as Error) + } + } + + if (typeof deserialized.metadata === 'string') { + try { + deserialized.metadata = JSON.parse(deserialized.metadata) + } catch (error) { + logger.warn('Failed to parse session message metadata JSON', error as Error) + } + } + + return deserialized + } + + private getWriter(tx?: TxClient): TxClient { + return tx ?? this.database + } + + private async findExistingMessageRow( + writer: TxClient, + sessionId: string, + role: string, + messageId: string + ): Promise { + const candidateRows: SessionMessageRow[] = await writer + .select() + .from(sessionMessagesTable) + .where(and(eq(sessionMessagesTable.session_id, sessionId), eq(sessionMessagesTable.role, role))) + .orderBy(asc(sessionMessagesTable.created_at)) + + for (const row of candidateRows) { + if (!row?.content) continue + + try { + const parsed = JSON.parse(row.content) as AgentPersistedMessage | undefined + if (parsed?.message?.id === messageId) { + return row + } + } catch (error) { + logger.warn('Failed to parse session message content JSON during lookup', error as Error) + } + } + + return null + } + + private async upsertMessage( + params: PersistUserMessageParams | PersistAssistantMessageParams + ): Promise { + await AgentMessageRepository.initialize() + this.ensureInitialized() + + const { sessionId, agentSessionId = '', payload, metadata, createdAt, tx } = params + + if (!payload?.message?.role) { + throw new Error('Message payload missing role') + } + + if (!payload.message.id) { + throw new Error('Message payload missing id') + } + + const writer = this.getWriter(tx) + const now = createdAt ?? payload.message.createdAt ?? new Date().toISOString() + const serializedPayload = this.serializeMessage(payload) + const serializedMetadata = this.serializeMetadata(metadata) + + const existingRow = await this.findExistingMessageRow(writer, sessionId, payload.message.role, payload.message.id) + + if (existingRow) { + const metadataToPersist = serializedMetadata ?? existingRow.metadata ?? undefined + const agentSessionToPersist = agentSessionId || existingRow.agent_session_id || '' + + await writer + .update(sessionMessagesTable) + .set({ + content: serializedPayload, + metadata: metadataToPersist, + agent_session_id: agentSessionToPersist, + updated_at: now + }) + .where(eq(sessionMessagesTable.id, existingRow.id)) + + return this.deserialize({ + ...existingRow, + content: serializedPayload, + metadata: metadataToPersist, + agent_session_id: agentSessionToPersist, + updated_at: now + }) + } + + const insertData: InsertSessionMessageRow = { + session_id: sessionId, + role: payload.message.role, + content: serializedPayload, + agent_session_id: agentSessionId, + metadata: serializedMetadata, + created_at: now, + updated_at: now + } + + const [saved] = await writer.insert(sessionMessagesTable).values(insertData).returning() + + return this.deserialize(saved) + } + + async persistUserMessage(params: PersistUserMessageParams): Promise { + return this.upsertMessage({ ...params, agentSessionId: params.agentSessionId ?? '' }) + } + + async persistAssistantMessage(params: PersistAssistantMessageParams): Promise { + return this.upsertMessage(params) + } + + async persistExchange(params: PersistExchangeParams): Promise { + await AgentMessageRepository.initialize() + this.ensureInitialized() + + const { sessionId, agentSessionId, user, assistant } = params + + const result = await this.database.transaction(async (tx) => { + const exchangeResult: PersistExchangeResult = {} + + if (user?.payload) { + exchangeResult.userMessage = await this.persistUserMessage({ + sessionId, + agentSessionId, + payload: user.payload, + metadata: user.metadata, + createdAt: user.createdAt, + tx + }) + } + + if (assistant?.payload) { + exchangeResult.assistantMessage = await this.persistAssistantMessage({ + sessionId, + agentSessionId, + payload: assistant.payload, + metadata: assistant.metadata, + createdAt: assistant.createdAt, + tx + }) + } + + return exchangeResult + }) + + return result + } + + async getSessionHistory(sessionId: string): Promise { + await AgentMessageRepository.initialize() + this.ensureInitialized() + + try { + const rows = await this.database + .select() + .from(sessionMessagesTable) + .where(eq(sessionMessagesTable.session_id, sessionId)) + .orderBy(asc(sessionMessagesTable.created_at)) + + const messages: AgentPersistedMessage[] = [] + + for (const row of rows) { + const deserialized = this.deserialize(row) + if (deserialized?.content) { + messages.push(deserialized.content as AgentPersistedMessage) + } + } + + logger.info(`Loaded ${messages.length} messages for session ${sessionId}`) + return messages + } catch (error) { + logger.error('Failed to load session history', error as Error) + throw error + } + } +} + +export const agentMessageRepository = AgentMessageRepository.getInstance() diff --git a/src/main/services/agents/drizzle.config.ts b/src/main/services/agents/drizzle.config.ts new file mode 100644 index 0000000000..e12518c069 --- /dev/null +++ b/src/main/services/agents/drizzle.config.ts @@ -0,0 +1,31 @@ +/** + * Drizzle Kit configuration for agents database + */ + +import os from 'node:os' +import path from 'node:path' + +import { defineConfig } from 'drizzle-kit' +import { app } from 'electron' + +function getDbPath() { + if (process.env.NODE_ENV === 'development') { + return path.join(os.homedir(), '.cherrystudio', 'data', 'agents.db') + } + return path.join(app.getPath('userData'), 'agents.db') +} + +const resolvedDbPath = getDbPath() + +export const dbPath = resolvedDbPath + +export default defineConfig({ + dialect: 'sqlite', + schema: './src/main/services/agents/database/schema/index.ts', + out: './resources/database/drizzle', + dbCredentials: { + url: `file:${resolvedDbPath}` + }, + verbose: true, + strict: true +}) diff --git a/src/main/services/agents/errors.ts b/src/main/services/agents/errors.ts new file mode 100644 index 0000000000..b0df2341d7 --- /dev/null +++ b/src/main/services/agents/errors.ts @@ -0,0 +1,22 @@ +import { ModelValidationError } from '@main/apiServer/utils' +import { AgentType } from '@types' + +export type AgentModelField = 'model' | 'plan_model' | 'small_model' + +export interface AgentModelValidationContext { + agentType: AgentType + field: AgentModelField + model?: string +} + +export class AgentModelValidationError extends Error { + readonly context: AgentModelValidationContext + readonly detail: ModelValidationError + + constructor(context: AgentModelValidationContext, detail: ModelValidationError) { + super(`Validation failed for ${context.agentType}.${context.field}: ${detail.message}`) + this.name = 'AgentModelValidationError' + this.context = context + this.detail = detail + } +} diff --git a/src/main/services/agents/index.ts b/src/main/services/agents/index.ts new file mode 100644 index 0000000000..00409e7c64 --- /dev/null +++ b/src/main/services/agents/index.ts @@ -0,0 +1,25 @@ +/** + * Agents Service Module + * + * This module provides a complete autonomous agent management system with: + * - Agent lifecycle management (CRUD operations) + * - Session handling with conversation history + * - Comprehensive logging and audit trails + * - Database operations with Drizzle ORM and migration support + * - RESTful API endpoints for external integration + */ + +// === Core Services === +// Main service classes and singleton instances +export * from './services' + +// === Error Types === +export { type AgentModelField, AgentModelValidationError } from './errors' + +// === Base Infrastructure === +// Shared database utilities and base service class +export { BaseService } from './BaseService' + +// === Database Layer === +// Drizzle ORM schemas, migrations, and database utilities +export * as Database from './database' diff --git a/src/main/services/agents/interfaces/AgentStreamInterface.ts b/src/main/services/agents/interfaces/AgentStreamInterface.ts new file mode 100644 index 0000000000..1b9c6f136d --- /dev/null +++ b/src/main/services/agents/interfaces/AgentStreamInterface.ts @@ -0,0 +1,31 @@ +// Agent-agnostic streaming interface +// This interface should be implemented by all agent services + +import { EventEmitter } from 'node:events' + +import { GetAgentSessionResponse } from '@types' +import type { TextStreamPart } from 'ai' + +// Generic agent stream event that works with any agent type +export interface AgentStreamEvent { + type: 'chunk' | 'error' | 'complete' | 'cancelled' + chunk?: TextStreamPart // Standard AI SDK chunk for UI consumption + error?: Error +} + +// Agent stream interface that all agents should implement +export interface AgentStream extends EventEmitter { + emit(event: 'data', data: AgentStreamEvent): boolean + on(event: 'data', listener: (data: AgentStreamEvent) => void): this + once(event: 'data', listener: (data: AgentStreamEvent) => void): this +} + +// Base agent service interface +export interface AgentServiceInterface { + invoke( + prompt: string, + session: GetAgentSessionResponse, + abortController: AbortController, + lastAgentSessionId?: string + ): Promise +} diff --git a/src/main/services/agents/services/AgentService.ts b/src/main/services/agents/services/AgentService.ts new file mode 100644 index 0000000000..78e7acadb8 --- /dev/null +++ b/src/main/services/agents/services/AgentService.ts @@ -0,0 +1,195 @@ +import path from 'node:path' + +import { getDataPath } from '@main/utils' +import { + AgentBaseSchema, + AgentEntity, + CreateAgentRequest, + CreateAgentResponse, + GetAgentResponse, + ListOptions, + UpdateAgentRequest, + UpdateAgentResponse +} from '@types' +import { count, eq } from 'drizzle-orm' + +import { BaseService } from '../BaseService' +import { type AgentRow, agentsTable, type InsertAgentRow } from '../database/schema' +import { AgentModelField } from '../errors' + +export class AgentService extends BaseService { + private static instance: AgentService | null = null + private readonly modelFields: AgentModelField[] = ['model', 'plan_model', 'small_model'] + + static getInstance(): AgentService { + if (!AgentService.instance) { + AgentService.instance = new AgentService() + } + return AgentService.instance + } + + async initialize(): Promise { + await BaseService.initialize() + } + + // Agent Methods + async createAgent(req: CreateAgentRequest): Promise { + this.ensureInitialized() + + const id = `agent_${Date.now()}_${Math.random().toString(36).substring(2, 11)}` + const now = new Date().toISOString() + + if (!req.accessible_paths || req.accessible_paths.length === 0) { + const defaultPath = path.join(getDataPath(), 'agents', id) + req.accessible_paths = [defaultPath] + } + + if (req.accessible_paths !== undefined) { + req.accessible_paths = this.ensurePathsExist(req.accessible_paths) + } + + await this.validateAgentModels(req.type, { + model: req.model, + plan_model: req.plan_model, + small_model: req.small_model + }) + + const serializedReq = this.serializeJsonFields(req) + + const insertData: InsertAgentRow = { + id, + type: req.type, + name: req.name || 'New Agent', + description: req.description, + instructions: req.instructions || 'You are a helpful assistant.', + model: req.model, + plan_model: req.plan_model, + small_model: req.small_model, + configuration: serializedReq.configuration, + accessible_paths: serializedReq.accessible_paths, + created_at: now, + updated_at: now + } + + await this.database.insert(agentsTable).values(insertData) + const result = await this.database.select().from(agentsTable).where(eq(agentsTable.id, id)).limit(1) + if (!result[0]) { + throw new Error('Failed to create agent') + } + + const agent = this.deserializeJsonFields(result[0]) as AgentEntity + return agent + } + + async getAgent(id: string): Promise { + this.ensureInitialized() + + const result = await this.database.select().from(agentsTable).where(eq(agentsTable.id, id)).limit(1) + + if (!result[0]) { + return null + } + + const agent = this.deserializeJsonFields(result[0]) as GetAgentResponse + agent.tools = await this.listMcpTools(agent.type, agent.mcps) + return agent + } + + async listAgents(options: ListOptions = {}): Promise<{ agents: AgentEntity[]; total: number }> { + this.ensureInitialized() // Build query with pagination + + const totalResult = await this.database.select({ count: count() }).from(agentsTable) + + const baseQuery = this.database.select().from(agentsTable).orderBy(agentsTable.created_at) + + const result = + options.limit !== undefined + ? options.offset !== undefined + ? await baseQuery.limit(options.limit).offset(options.offset) + : await baseQuery.limit(options.limit) + : await baseQuery + + const agents = result.map((row) => this.deserializeJsonFields(row)) as GetAgentResponse[] + + for (const agent of agents) { + agent.tools = await this.listMcpTools(agent.type, agent.mcps) + } + + return { agents, total: totalResult[0].count } + } + + async updateAgent( + id: string, + updates: UpdateAgentRequest, + options: { replace?: boolean } = {} + ): Promise { + this.ensureInitialized() + + // Check if agent exists + const existing = await this.getAgent(id) + if (!existing) { + return null + } + + const now = new Date().toISOString() + + if (updates.accessible_paths !== undefined) { + updates.accessible_paths = this.ensurePathsExist(updates.accessible_paths) + } + + const modelUpdates: Partial> = {} + for (const field of this.modelFields) { + if (Object.prototype.hasOwnProperty.call(updates, field)) { + modelUpdates[field] = updates[field as keyof UpdateAgentRequest] as string | undefined + } + } + + if (Object.keys(modelUpdates).length > 0) { + await this.validateAgentModels(existing.type, modelUpdates) + } + + const serializedUpdates = this.serializeJsonFields(updates) + + const updateData: Partial = { + updated_at: now + } + const replaceableFields = Object.keys(AgentBaseSchema.shape) as (keyof AgentRow)[] + const shouldReplace = options.replace ?? false + + for (const field of replaceableFields) { + if (shouldReplace || Object.prototype.hasOwnProperty.call(serializedUpdates, field)) { + if (Object.prototype.hasOwnProperty.call(serializedUpdates, field)) { + const value = serializedUpdates[field as keyof typeof serializedUpdates] + ;(updateData as Record)[field] = value ?? null + } else if (shouldReplace) { + ;(updateData as Record)[field] = null + } + } + } + + await this.database.update(agentsTable).set(updateData).where(eq(agentsTable.id, id)) + return await this.getAgent(id) + } + + async deleteAgent(id: string): Promise { + this.ensureInitialized() + + const result = await this.database.delete(agentsTable).where(eq(agentsTable.id, id)) + + return result.rowsAffected > 0 + } + + async agentExists(id: string): Promise { + this.ensureInitialized() + + const result = await this.database + .select({ id: agentsTable.id }) + .from(agentsTable) + .where(eq(agentsTable.id, id)) + .limit(1) + + return result.length > 0 + } +} + +export const agentService = AgentService.getInstance() diff --git a/src/main/services/agents/services/SessionMessageService.ts b/src/main/services/agents/services/SessionMessageService.ts new file mode 100644 index 0000000000..f7d44e1612 --- /dev/null +++ b/src/main/services/agents/services/SessionMessageService.ts @@ -0,0 +1,321 @@ +import { loggerService } from '@logger' +import type { + AgentSessionMessageEntity, + CreateSessionMessageRequest, + GetAgentSessionResponse, + ListOptions +} from '@types' +import { TextStreamPart } from 'ai' +import { and, desc, eq, not } from 'drizzle-orm' + +import { BaseService } from '../BaseService' +import { sessionMessagesTable } from '../database/schema' +import { AgentStreamEvent } from '../interfaces/AgentStreamInterface' +import ClaudeCodeService from './claudecode' + +const logger = loggerService.withContext('SessionMessageService') + +type SessionStreamResult = { + stream: ReadableStream>> + completion: Promise<{ + userMessage?: AgentSessionMessageEntity + assistantMessage?: AgentSessionMessageEntity + }> +} + +// Ensure errors emitted through SSE are serializable +function serializeError(error: unknown): { message: string; name?: string; stack?: string } { + if (error instanceof Error) { + return { + message: error.message, + name: error.name, + stack: error.stack + } + } + + if (typeof error === 'string') { + return { message: error } + } + + return { + message: 'Unknown error' + } +} + +class TextStreamAccumulator { + private textBuffer = '' + private totalText = '' + private readonly toolCalls = new Map() + private readonly toolResults = new Map() + + add(part: TextStreamPart>): void { + switch (part.type) { + case 'text-start': + this.textBuffer = '' + break + case 'text-delta': + if (part.text) { + this.textBuffer += part.text + } + break + case 'text-end': { + const blockText = (part.providerMetadata?.text?.value as string | undefined) ?? this.textBuffer + if (blockText) { + this.totalText += blockText + } + this.textBuffer = '' + break + } + case 'tool-call': + if (part.toolCallId) { + const legacyPart = part as typeof part & { + args?: unknown + providerMetadata?: { raw?: { input?: unknown } } + } + this.toolCalls.set(part.toolCallId, { + toolName: part.toolName, + input: part.input ?? legacyPart.args ?? legacyPart.providerMetadata?.raw?.input + }) + } + break + case 'tool-result': + if (part.toolCallId) { + const legacyPart = part as typeof part & { + result?: unknown + providerMetadata?: { raw?: unknown } + } + this.toolResults.set(part.toolCallId, part.output ?? legacyPart.result ?? legacyPart.providerMetadata?.raw) + } + break + default: + break + } + } +} + +export class SessionMessageService extends BaseService { + private static instance: SessionMessageService | null = null + private cc: ClaudeCodeService = new ClaudeCodeService() + + static getInstance(): SessionMessageService { + if (!SessionMessageService.instance) { + SessionMessageService.instance = new SessionMessageService() + } + return SessionMessageService.instance + } + + async initialize(): Promise { + await BaseService.initialize() + } + + async sessionMessageExists(id: number): Promise { + this.ensureInitialized() + + const result = await this.database + .select({ id: sessionMessagesTable.id }) + .from(sessionMessagesTable) + .where(eq(sessionMessagesTable.id, id)) + .limit(1) + + return result.length > 0 + } + + async listSessionMessages( + sessionId: string, + options: ListOptions = {} + ): Promise<{ messages: AgentSessionMessageEntity[] }> { + this.ensureInitialized() + + // Get messages with pagination + const baseQuery = this.database + .select() + .from(sessionMessagesTable) + .where(eq(sessionMessagesTable.session_id, sessionId)) + .orderBy(sessionMessagesTable.created_at) + + const result = + options.limit !== undefined + ? options.offset !== undefined + ? await baseQuery.limit(options.limit).offset(options.offset) + : await baseQuery.limit(options.limit) + : await baseQuery + + const messages = result.map((row) => this.deserializeSessionMessage(row)) as AgentSessionMessageEntity[] + + return { messages } + } + + async deleteSessionMessage(sessionId: string, messageId: number): Promise { + this.ensureInitialized() + + const result = await this.database + .delete(sessionMessagesTable) + .where(and(eq(sessionMessagesTable.id, messageId), eq(sessionMessagesTable.session_id, sessionId))) + + return result.rowsAffected > 0 + } + + async createSessionMessage( + session: GetAgentSessionResponse, + messageData: CreateSessionMessageRequest, + abortController: AbortController + ): Promise { + this.ensureInitialized() + + return await this.startSessionMessageStream(session, messageData, abortController) + } + + private async startSessionMessageStream( + session: GetAgentSessionResponse, + req: CreateSessionMessageRequest, + abortController: AbortController + ): Promise { + const agentSessionId = await this.getLastAgentSessionId(session.id) + logger.debug('Session Message stream message data:', { message: req, session_id: agentSessionId }) + + if (session.agent_type !== 'claude-code') { + // TODO: Implement support for other agent types + logger.error('Unsupported agent type for streaming:', { agent_type: session.agent_type }) + throw new Error('Unsupported agent type for streaming') + } + + const claudeStream = await this.cc.invoke(req.content, session, abortController, agentSessionId) + const accumulator = new TextStreamAccumulator() + + let resolveCompletion!: (value: { + userMessage?: AgentSessionMessageEntity + assistantMessage?: AgentSessionMessageEntity + }) => void + let rejectCompletion!: (reason?: unknown) => void + + const completion = new Promise<{ + userMessage?: AgentSessionMessageEntity + assistantMessage?: AgentSessionMessageEntity + }>((resolve, reject) => { + resolveCompletion = resolve + rejectCompletion = reject + }) + + let finished = false + + const cleanup = () => { + if (finished) return + finished = true + claudeStream.removeAllListeners() + } + + const stream = new ReadableStream>>({ + start: (controller) => { + claudeStream.on('data', async (event: AgentStreamEvent) => { + if (finished) return + try { + switch (event.type) { + case 'chunk': { + const chunk = event.chunk as TextStreamPart> | undefined + if (!chunk) { + logger.warn('Received agent chunk event without chunk payload') + return + } + + accumulator.add(chunk) + controller.enqueue(chunk) + break + } + + case 'error': { + const stderrMessage = (event as any)?.data?.stderr as string | undefined + const underlyingError = event.error ?? (stderrMessage ? new Error(stderrMessage) : undefined) + cleanup() + const streamError = underlyingError ?? new Error('Stream error') + controller.error(streamError) + rejectCompletion(serializeError(streamError)) + break + } + + case 'complete': { + cleanup() + controller.close() + resolveCompletion({}) + break + } + + case 'cancelled': { + cleanup() + controller.close() + resolveCompletion({}) + break + } + + default: + logger.warn('Unknown event type from Claude Code service:', { + type: event.type + }) + break + } + } catch (error) { + cleanup() + controller.error(error) + rejectCompletion(serializeError(error)) + } + }) + }, + cancel: (reason) => { + cleanup() + abortController.abort(typeof reason === 'string' ? reason : 'stream cancelled') + resolveCompletion({}) + } + }) + + return { stream, completion } + } + + private async getLastAgentSessionId(sessionId: string): Promise { + this.ensureInitialized() + + try { + const result = await this.database + .select({ agent_session_id: sessionMessagesTable.agent_session_id }) + .from(sessionMessagesTable) + .where(and(eq(sessionMessagesTable.session_id, sessionId), not(eq(sessionMessagesTable.agent_session_id, '')))) + .orderBy(desc(sessionMessagesTable.created_at)) + .limit(1) + + logger.silly('Last agent session ID result:', { agentSessionId: result[0]?.agent_session_id, sessionId }) + return result[0]?.agent_session_id || '' + } catch (error) { + logger.error('Failed to get last agent session ID', { + sessionId, + error + }) + return '' + } + } + + private deserializeSessionMessage(data: any): AgentSessionMessageEntity { + if (!data) return data + + const deserialized = { ...data } + + // Parse content JSON + if (deserialized.content && typeof deserialized.content === 'string') { + try { + deserialized.content = JSON.parse(deserialized.content) + } catch (error) { + logger.warn(`Failed to parse content JSON:`, error as Error) + } + } + + // Parse metadata JSON + if (deserialized.metadata && typeof deserialized.metadata === 'string') { + try { + deserialized.metadata = JSON.parse(deserialized.metadata) + } catch (error) { + logger.warn(`Failed to parse metadata JSON:`, error as Error) + } + } + + return deserialized + } +} + +export const sessionMessageService = SessionMessageService.getInstance() diff --git a/src/main/services/agents/services/SessionService.ts b/src/main/services/agents/services/SessionService.ts new file mode 100644 index 0000000000..5fcb60600d --- /dev/null +++ b/src/main/services/agents/services/SessionService.ts @@ -0,0 +1,235 @@ +import { + AgentBaseSchema, + type AgentEntity, + type AgentSessionEntity, + type CreateSessionRequest, + type GetAgentSessionResponse, + type ListOptions, + type UpdateSessionRequest, + UpdateSessionResponse +} from '@types' +import { and, count, desc, eq, type SQL } from 'drizzle-orm' + +import { BaseService } from '../BaseService' +import { agentsTable, type InsertSessionRow, type SessionRow, sessionsTable } from '../database/schema' +import { AgentModelField } from '../errors' + +export class SessionService extends BaseService { + private static instance: SessionService | null = null + private readonly modelFields: AgentModelField[] = ['model', 'plan_model', 'small_model'] + + static getInstance(): SessionService { + if (!SessionService.instance) { + SessionService.instance = new SessionService() + } + return SessionService.instance + } + + async initialize(): Promise { + await BaseService.initialize() + } + + async createSession( + agentId: string, + req: Partial = {} + ): Promise { + this.ensureInitialized() + + // Validate agent exists - we'll need to import AgentService for this check + // For now, we'll skip this validation to avoid circular dependencies + // The database foreign key constraint will handle this + + const agents = await this.database.select().from(agentsTable).where(eq(agentsTable.id, agentId)).limit(1) + if (!agents[0]) { + throw new Error('Agent not found') + } + const agent = this.deserializeJsonFields(agents[0]) as AgentEntity + + const id = `session_${Date.now()}_${Math.random().toString(36).substring(2, 11)}` + const now = new Date().toISOString() + + // inherit configuration from agent by default, can be overridden by sessionData + const sessionData: Partial = { + ...agent, + ...req + } + + await this.validateAgentModels(agent.type, { + model: sessionData.model, + plan_model: sessionData.plan_model, + small_model: sessionData.small_model + }) + + if (sessionData.accessible_paths !== undefined) { + sessionData.accessible_paths = this.ensurePathsExist(sessionData.accessible_paths) + } + + const serializedData = this.serializeJsonFields(sessionData) + + const insertData: InsertSessionRow = { + id, + agent_id: agentId, + agent_type: agent.type, + name: serializedData.name || null, + description: serializedData.description || null, + accessible_paths: serializedData.accessible_paths || null, + instructions: serializedData.instructions || null, + model: serializedData.model || null, + plan_model: serializedData.plan_model || null, + small_model: serializedData.small_model || null, + mcps: serializedData.mcps || null, + configuration: serializedData.configuration || null, + created_at: now, + updated_at: now + } + + await this.database.insert(sessionsTable).values(insertData) + + const result = await this.database.select().from(sessionsTable).where(eq(sessionsTable.id, id)).limit(1) + + if (!result[0]) { + throw new Error('Failed to create session') + } + + const session = this.deserializeJsonFields(result[0]) + return await this.getSession(agentId, session.id) + } + + async getSession(agentId: string, id: string): Promise { + this.ensureInitialized() + + const result = await this.database + .select() + .from(sessionsTable) + .where(and(eq(sessionsTable.id, id), eq(sessionsTable.agent_id, agentId))) + .limit(1) + + if (!result[0]) { + return null + } + + const session = this.deserializeJsonFields(result[0]) as GetAgentSessionResponse + session.tools = await this.listMcpTools(session.agent_type, session.mcps) + session.slash_commands = await this.listSlashCommands(session.agent_type) + return session + } + + async listSessions( + agentId?: string, + options: ListOptions = {} + ): Promise<{ sessions: AgentSessionEntity[]; total: number }> { + this.ensureInitialized() + + // Build where conditions + const whereConditions: SQL[] = [] + if (agentId) { + whereConditions.push(eq(sessionsTable.agent_id, agentId)) + } + + const whereClause = + whereConditions.length > 1 + ? and(...whereConditions) + : whereConditions.length === 1 + ? whereConditions[0] + : undefined + + // Get total count + const totalResult = await this.database.select({ count: count() }).from(sessionsTable).where(whereClause) + + const total = totalResult[0].count + + // Build list query with pagination - sort by updated_at descending (latest first) + const baseQuery = this.database + .select() + .from(sessionsTable) + .where(whereClause) + .orderBy(desc(sessionsTable.updated_at)) + + const result = + options.limit !== undefined + ? options.offset !== undefined + ? await baseQuery.limit(options.limit).offset(options.offset) + : await baseQuery.limit(options.limit) + : await baseQuery + + const sessions = result.map((row) => this.deserializeJsonFields(row)) as GetAgentSessionResponse[] + + return { sessions, total } + } + + async updateSession( + agentId: string, + id: string, + updates: UpdateSessionRequest + ): Promise { + this.ensureInitialized() + + // Check if session exists + const existing = await this.getSession(agentId, id) + if (!existing) { + return null + } + + // Validate agent exists if changing main_agent_id + // We'll skip this validation for now to avoid circular dependencies + + const now = new Date().toISOString() + + if (updates.accessible_paths !== undefined) { + updates.accessible_paths = this.ensurePathsExist(updates.accessible_paths) + } + + const modelUpdates: Partial> = {} + for (const field of this.modelFields) { + if (Object.prototype.hasOwnProperty.call(updates, field)) { + modelUpdates[field] = updates[field as keyof UpdateSessionRequest] as string | undefined + } + } + + if (Object.keys(modelUpdates).length > 0) { + await this.validateAgentModels(existing.agent_type, modelUpdates) + } + + const serializedUpdates = this.serializeJsonFields(updates) + + const updateData: Partial = { + updated_at: now + } + const replaceableFields = Object.keys(AgentBaseSchema.shape) as (keyof SessionRow)[] + + for (const field of replaceableFields) { + if (Object.prototype.hasOwnProperty.call(serializedUpdates, field)) { + const value = serializedUpdates[field as keyof typeof serializedUpdates] + ;(updateData as Record)[field] = value ?? null + } + } + + await this.database.update(sessionsTable).set(updateData).where(eq(sessionsTable.id, id)) + + return await this.getSession(agentId, id) + } + + async deleteSession(agentId: string, id: string): Promise { + this.ensureInitialized() + + const result = await this.database + .delete(sessionsTable) + .where(and(eq(sessionsTable.id, id), eq(sessionsTable.agent_id, agentId))) + + return result.rowsAffected > 0 + } + + async sessionExists(agentId: string, id: string): Promise { + this.ensureInitialized() + + const result = await this.database + .select({ id: sessionsTable.id }) + .from(sessionsTable) + .where(and(eq(sessionsTable.id, id), eq(sessionsTable.agent_id, agentId))) + .limit(1) + + return result.length > 0 + } +} + +export const sessionService = SessionService.getInstance() diff --git a/src/main/services/agents/services/claudecode/__tests__/transform.test.ts b/src/main/services/agents/services/claudecode/__tests__/transform.test.ts new file mode 100644 index 0000000000..1c5c2ade6b --- /dev/null +++ b/src/main/services/agents/services/claudecode/__tests__/transform.test.ts @@ -0,0 +1,290 @@ +import type { SDKMessage } from '@anthropic-ai/claude-agent-sdk' +import { describe, expect, it } from 'vitest' + +import { ClaudeStreamState, transformSDKMessageToStreamParts } from '../transform' + +const baseStreamMetadata = { + parent_tool_use_id: null, + session_id: 'session-123' +} + +const uuid = (n: number) => `00000000-0000-0000-0000-${n.toString().padStart(12, '0')}` + +describe('Claude → AiSDK transform', () => { + it('handles tool call streaming lifecycle', () => { + const state = new ClaudeStreamState() + const parts: ReturnType[number][] = [] + + const messages: SDKMessage[] = [ + { + ...baseStreamMetadata, + type: 'stream_event', + uuid: uuid(1), + event: { + type: 'message_start', + message: { + id: 'msg-start', + type: 'message', + role: 'assistant', + model: 'claude-test', + content: [], + stop_reason: null, + stop_sequence: null, + usage: {} + } + } + } as unknown as SDKMessage, + { + ...baseStreamMetadata, + type: 'stream_event', + uuid: uuid(2), + event: { + type: 'content_block_start', + index: 0, + content_block: { + type: 'tool_use', + id: 'tool-1', + name: 'Bash', + input: {} + } + } + } as unknown as SDKMessage, + { + ...baseStreamMetadata, + type: 'stream_event', + uuid: uuid(3), + event: { + type: 'content_block_delta', + index: 0, + delta: { + type: 'input_json_delta', + partial_json: '{"command":"ls"}' + } + } + } as unknown as SDKMessage, + { + ...baseStreamMetadata, + type: 'assistant', + uuid: uuid(4), + message: { + id: 'msg-tool', + type: 'message', + role: 'assistant', + model: 'claude-test', + content: [ + { + type: 'tool_use', + id: 'tool-1', + name: 'Bash', + input: { + command: 'ls' + } + } + ], + stop_reason: 'tool_use', + stop_sequence: null, + usage: { + input_tokens: 1, + output_tokens: 0 + } + } + } as unknown as SDKMessage, + { + ...baseStreamMetadata, + type: 'stream_event', + uuid: uuid(5), + event: { + type: 'content_block_stop', + index: 0 + } + } as unknown as SDKMessage, + { + ...baseStreamMetadata, + type: 'stream_event', + uuid: uuid(6), + event: { + type: 'message_delta', + delta: { + stop_reason: 'tool_use', + stop_sequence: null + }, + usage: { + input_tokens: 1, + output_tokens: 5 + } + } + } as unknown as SDKMessage, + { + ...baseStreamMetadata, + type: 'stream_event', + uuid: uuid(7), + event: { + type: 'message_stop' + } + } as unknown as SDKMessage, + { + ...baseStreamMetadata, + type: 'user', + uuid: uuid(8), + message: { + role: 'user', + content: [ + { + type: 'tool_result', + tool_use_id: 'tool-1', + content: 'ok', + is_error: false + } + ] + } + } as SDKMessage + ] + + for (const message of messages) { + const transformed = transformSDKMessageToStreamParts(message, state) + for (const part of transformed) { + parts.push(part) + } + } + + const types = parts.map((part) => part.type) + expect(types).toEqual([ + 'start-step', + 'tool-input-start', + 'tool-input-delta', + 'tool-call', + 'tool-input-end', + 'finish-step', + 'tool-result' + ]) + + const finishStep = parts.find((part) => part.type === 'finish-step') as Extract< + (typeof parts)[number], + { type: 'finish-step' } + > + expect(finishStep.finishReason).toBe('tool-calls') + expect(finishStep.usage).toEqual({ inputTokens: 1, outputTokens: 5, totalTokens: 6 }) + + const toolResult = parts.find((part) => part.type === 'tool-result') as Extract< + (typeof parts)[number], + { type: 'tool-result' } + > + expect(toolResult.toolCallId).toBe('tool-1') + expect(toolResult.toolName).toBe('Bash') + expect(toolResult.input).toEqual({ command: 'ls' }) + expect(toolResult.output).toBe('ok') + }) + + it('handles streaming text completion', () => { + const state = new ClaudeStreamState() + const parts: ReturnType[number][] = [] + + const messages: SDKMessage[] = [ + { + ...baseStreamMetadata, + type: 'stream_event', + uuid: uuid(9), + event: { + type: 'message_start', + message: { + id: 'msg-text', + type: 'message', + role: 'assistant', + model: 'claude-text', + content: [], + stop_reason: null, + stop_sequence: null, + usage: {} + } + } + } as unknown as SDKMessage, + { + ...baseStreamMetadata, + type: 'stream_event', + uuid: uuid(10), + event: { + type: 'content_block_start', + index: 0, + content_block: { + type: 'text', + text: '' + } + } + } as unknown as SDKMessage, + { + ...baseStreamMetadata, + type: 'stream_event', + uuid: uuid(11), + event: { + type: 'content_block_delta', + index: 0, + delta: { + type: 'text_delta', + text: 'Hello' + } + } + } as unknown as SDKMessage, + { + ...baseStreamMetadata, + type: 'stream_event', + uuid: uuid(12), + event: { + type: 'content_block_delta', + index: 0, + delta: { + type: 'text_delta', + text: ' world' + } + } + } as unknown as SDKMessage, + { + ...baseStreamMetadata, + type: 'stream_event', + uuid: uuid(13), + event: { + type: 'content_block_stop', + index: 0 + } + } as unknown as SDKMessage, + { + ...baseStreamMetadata, + type: 'stream_event', + uuid: uuid(14), + event: { + type: 'message_delta', + delta: { + stop_reason: 'end_turn', + stop_sequence: null + }, + usage: { + input_tokens: 2, + output_tokens: 4 + } + } + } as unknown as SDKMessage, + { + ...baseStreamMetadata, + type: 'stream_event', + uuid: uuid(15), + event: { + type: 'message_stop' + } + } as SDKMessage + ] + + for (const message of messages) { + const transformed = transformSDKMessageToStreamParts(message, state) + parts.push(...transformed) + } + + const types = parts.map((part) => part.type) + expect(types).toEqual(['start-step', 'text-start', 'text-delta', 'text-delta', 'text-end', 'finish-step']) + + const finishStep = parts.find((part) => part.type === 'finish-step') as Extract< + (typeof parts)[number], + { type: 'finish-step' } + > + expect(finishStep.finishReason).toBe('stop') + expect(finishStep.usage).toEqual({ inputTokens: 2, outputTokens: 4, totalTokens: 6 }) + }) +}) diff --git a/src/main/services/agents/services/claudecode/claude-stream-state.ts b/src/main/services/agents/services/claudecode/claude-stream-state.ts new file mode 100644 index 0000000000..078f048ce8 --- /dev/null +++ b/src/main/services/agents/services/claudecode/claude-stream-state.ts @@ -0,0 +1,241 @@ +/** + * Lightweight state container shared by the Claude → AiSDK transformer. Anthropic does not send + * deterministic identifiers for intermediate content blocks, so we stitch one together by tracking + * block indices and associated AiSDK ids. This class also keeps: + * • incremental text / reasoning buffers so we can emit only deltas while retaining the full + * aggregate for later tool-call emission; + * • a reverse lookup for tool calls so `tool_result` snapshots can recover their metadata; + * • pending usage + finish reason from `message_delta` events until the corresponding + * `message_stop` arrives. + * Every Claude turn gets its own instance. `resetStep` should be invoked once the finish event has + * been emitted to avoid leaking state into the next turn. + */ +import type { FinishReason, LanguageModelUsage, ProviderMetadata } from 'ai' + +/** + * Shared fields for every block that Claude can stream (text, reasoning, tool). + */ +type BaseBlockState = { + id: string + index: number +} + +type TextBlockState = BaseBlockState & { + kind: 'text' + text: string +} + +type ReasoningBlockState = BaseBlockState & { + kind: 'reasoning' + text: string + redacted: boolean +} + +type ToolBlockState = BaseBlockState & { + kind: 'tool' + toolCallId: string + toolName: string + inputBuffer: string + providerMetadata?: ProviderMetadata + resolvedInput?: unknown +} + +export type BlockState = TextBlockState | ReasoningBlockState | ToolBlockState + +type PendingUsageState = { + usage?: LanguageModelUsage + finishReason?: FinishReason +} + +type PendingToolCall = { + toolCallId: string + toolName: string + input: unknown + providerMetadata?: ProviderMetadata +} + +/** + * Tracks the lifecycle of Claude streaming blocks (text, thinking, tool calls) + * across individual websocket events. The transformer relies on this class to + * stitch together deltas, manage pending tool inputs/results, and propagate + * usage/finish metadata once Anthropic closes a message. + */ +export class ClaudeStreamState { + private blocksByIndex = new Map() + private toolIndexById = new Map() + private pendingUsage: PendingUsageState = {} + private pendingToolCalls = new Map() + private stepActive = false + + /** Marks the beginning of a new AiSDK step. */ + beginStep(): void { + this.stepActive = true + } + + hasActiveStep(): boolean { + return this.stepActive + } + + /** Creates a text block placeholder so future deltas can accumulate into it. */ + openTextBlock(index: number, id: string): TextBlockState { + const block: TextBlockState = { + kind: 'text', + id, + index, + text: '' + } + this.blocksByIndex.set(index, block) + return block + } + + /** Starts tracking an Anthropic "thinking" block, optionally flagged as redacted. */ + openReasoningBlock(index: number, id: string, redacted: boolean): ReasoningBlockState { + const block: ReasoningBlockState = { + kind: 'reasoning', + id, + index, + redacted, + text: '' + } + this.blocksByIndex.set(index, block) + return block + } + + /** Caches tool metadata so subsequent input deltas and results can find it. */ + openToolBlock( + index: number, + params: { toolCallId: string; toolName: string; providerMetadata?: ProviderMetadata } + ): ToolBlockState { + const block: ToolBlockState = { + kind: 'tool', + id: params.toolCallId, + index, + toolCallId: params.toolCallId, + toolName: params.toolName, + inputBuffer: '', + providerMetadata: params.providerMetadata + } + this.blocksByIndex.set(index, block) + this.toolIndexById.set(params.toolCallId, index) + return block + } + + getBlock(index: number): BlockState | undefined { + return this.blocksByIndex.get(index) + } + + getToolBlockById(toolCallId: string): ToolBlockState | undefined { + const index = this.toolIndexById.get(toolCallId) + if (index === undefined) return undefined + const block = this.blocksByIndex.get(index) + if (!block || block.kind !== 'tool') return undefined + return block + } + + /** Appends streamed text to a text block, returning the updated state when present. */ + appendTextDelta(index: number, text: string): TextBlockState | undefined { + const block = this.blocksByIndex.get(index) + if (!block || block.kind !== 'text') return undefined + block.text += text + return block + } + + /** Appends streamed "thinking" content to the tracked reasoning block. */ + appendReasoningDelta(index: number, text: string): ReasoningBlockState | undefined { + const block = this.blocksByIndex.get(index) + if (!block || block.kind !== 'reasoning') return undefined + block.text += text + return block + } + + /** Concatenates incremental JSON payloads for tool input blocks. */ + appendToolInputDelta(index: number, jsonDelta: string): ToolBlockState | undefined { + const block = this.blocksByIndex.get(index) + if (!block || block.kind !== 'tool') return undefined + block.inputBuffer += jsonDelta + return block + } + + /** Records a tool call to be consumed once its result arrives from the user. */ + registerToolCall( + toolCallId: string, + payload: { toolName: string; input: unknown; providerMetadata?: ProviderMetadata } + ): void { + this.pendingToolCalls.set(toolCallId, { + toolCallId, + toolName: payload.toolName, + input: payload.input, + providerMetadata: payload.providerMetadata + }) + } + + /** Retrieves and clears the buffered tool call metadata for the given id. */ + consumePendingToolCall(toolCallId: string): PendingToolCall | undefined { + const entry = this.pendingToolCalls.get(toolCallId) + if (entry) { + this.pendingToolCalls.delete(toolCallId) + } + return entry + } + + /** + * Persists the final input payload for a tool block once the provider signals + * completion so that downstream tool results can reference the original call. + */ + completeToolBlock(toolCallId: string, input: unknown, providerMetadata?: ProviderMetadata): void { + this.registerToolCall(toolCallId, { + toolName: this.getToolBlockById(toolCallId)?.toolName ?? 'unknown', + input, + providerMetadata + }) + const block = this.getToolBlockById(toolCallId) + if (block) { + block.resolvedInput = input + } + } + + /** Removes a block from the active index map when Claude signals it is done. */ + closeBlock(index: number): BlockState | undefined { + const block = this.blocksByIndex.get(index) + if (!block) return undefined + this.blocksByIndex.delete(index) + if (block.kind === 'tool') { + this.toolIndexById.delete(block.toolCallId) + } + return block + } + + /** Stores interim usage metrics so they can be emitted with the `finish-step`. */ + setPendingUsage(usage?: LanguageModelUsage, finishReason?: FinishReason): void { + if (usage) { + this.pendingUsage.usage = usage + } + if (finishReason) { + this.pendingUsage.finishReason = finishReason + } + } + + getPendingUsage(): PendingUsageState { + return { ...this.pendingUsage } + } + + /** Clears any accumulated usage values for the next streamed message. */ + resetPendingUsage(): void { + this.pendingUsage = {} + } + + /** Drops cached block metadata for the currently active message. */ + resetBlocks(): void { + this.blocksByIndex.clear() + this.toolIndexById.clear() + } + + /** Resets the entire step lifecycle after emitting a terminal frame. */ + resetStep(): void { + this.resetBlocks() + this.resetPendingUsage() + this.stepActive = false + } +} + +export type { PendingToolCall } diff --git a/src/main/services/agents/services/claudecode/commands.ts b/src/main/services/agents/services/claudecode/commands.ts new file mode 100644 index 0000000000..ce90e0978a --- /dev/null +++ b/src/main/services/agents/services/claudecode/commands.ts @@ -0,0 +1,25 @@ +import { SlashCommand } from '@types' + +export const builtinSlashCommands: SlashCommand[] = [ + { command: '/add-dir', description: 'Add additional working directories' }, + { command: '/agents', description: 'Manage custom AI subagents for specialized tasks' }, + { command: '/bug', description: 'Report bugs (sends conversation to Anthropic)' }, + { command: '/clear', description: 'Clear conversation history' }, + { command: '/compact', description: 'Compact conversation with optional focus instructions' }, + { command: '/config', description: 'View/modify configuration' }, + { command: '/cost', description: 'Show token usage statistics' }, + { command: '/doctor', description: 'Checks the health of your Claude Code installation' }, + { command: '/help', description: 'Get usage help' }, + { command: '/init', description: 'Initialize project with CLAUDE.md guide' }, + { command: '/login', description: 'Switch Anthropic accounts' }, + { command: '/logout', description: 'Sign out from your Anthropic account' }, + { command: '/mcp', description: 'Manage MCP server connections and OAuth authentication' }, + { command: '/memory', description: 'Edit CLAUDE.md memory files' }, + { command: '/model', description: 'Select or change the AI model' }, + { command: '/permissions', description: 'View or update permissions' }, + { command: '/pr_comments', description: 'View pull request comments' }, + { command: '/review', description: 'Request code review' }, + { command: '/status', description: 'View account and system statuses' }, + { command: '/terminal-setup', description: 'Install Shift+Enter key binding for newlines (iTerm2 and VSCode only)' }, + { command: '/vim', description: 'Enter vim mode for alternating insert and command modes' } +] diff --git a/src/main/services/agents/services/claudecode/index.ts b/src/main/services/agents/services/claudecode/index.ts new file mode 100644 index 0000000000..7b2f119afb --- /dev/null +++ b/src/main/services/agents/services/claudecode/index.ts @@ -0,0 +1,296 @@ +// src/main/services/agents/services/claudecode/index.ts +import { EventEmitter } from 'node:events' +import { createRequire } from 'node:module' + +import { McpHttpServerConfig, Options, query, SDKMessage } from '@anthropic-ai/claude-agent-sdk' +import { loggerService } from '@logger' +import { config as apiConfigService } from '@main/apiServer/config' +import { validateModelId } from '@main/apiServer/utils' +import getLoginShellEnvironment from '@main/utils/shell-env' +import { app } from 'electron' + +import { GetAgentSessionResponse } from '../..' +import { AgentServiceInterface, AgentStream, AgentStreamEvent } from '../../interfaces/AgentStreamInterface' +import { ClaudeStreamState, transformSDKMessageToStreamParts } from './transform' + +const require_ = createRequire(import.meta.url) +const logger = loggerService.withContext('ClaudeCodeService') + +class ClaudeCodeStream extends EventEmitter implements AgentStream { + declare emit: (event: 'data', data: AgentStreamEvent) => boolean + declare on: (event: 'data', listener: (data: AgentStreamEvent) => void) => this + declare once: (event: 'data', listener: (data: AgentStreamEvent) => void) => this +} + +class ClaudeCodeService implements AgentServiceInterface { + private claudeExecutablePath: string + + constructor() { + // Resolve Claude Code CLI robustly (works in dev and in asar) + this.claudeExecutablePath = require_.resolve('@anthropic-ai/claude-agent-sdk/cli.js') + if (app.isPackaged) { + this.claudeExecutablePath = this.claudeExecutablePath.replace(/\.asar([\\/])/, '.asar.unpacked$1') + } + } + + async invoke( + prompt: string, + session: GetAgentSessionResponse, + abortController: AbortController, + lastAgentSessionId?: string + ): Promise { + const aiStream = new ClaudeCodeStream() + + // Validate session accessible paths and make sure it exists as a directory + const cwd = session.accessible_paths[0] + if (!cwd) { + aiStream.emit('data', { + type: 'error', + error: new Error('No accessible paths defined for the agent session') + }) + return aiStream + } + + // Validate model info + const modelInfo = await validateModelId(session.model) + if (!modelInfo.valid) { + aiStream.emit('data', { + type: 'error', + error: new Error(`Invalid model ID '${session.model}': ${JSON.stringify(modelInfo.error)}`) + }) + return aiStream + } + if ( + (modelInfo.provider?.type !== 'anthropic' && + (modelInfo.provider?.anthropicApiHost === undefined || modelInfo.provider.anthropicApiHost.trim() === '')) || + modelInfo.provider.apiKey === '' + ) { + logger.error('Anthropic provider configuration is missing', { + modelInfo + }) + + aiStream.emit('data', { + type: 'error', + error: new Error(`Invalid provider type '${modelInfo.provider?.type}'. Expected 'anthropic' provider type.`) + }) + return aiStream + } + + const apiConfig = await apiConfigService.get() + const loginShellEnv = await getLoginShellEnvironment() + const loginShellEnvWithoutProxies = Object.fromEntries( + Object.entries(loginShellEnv).filter(([key]) => !key.toLowerCase().endsWith('_proxy')) + ) as Record + + const env = { + ...loginShellEnvWithoutProxies, + // TODO: fix the proxy api server + // ANTHROPIC_API_KEY: apiConfig.apiKey, + // ANTHROPIC_AUTH_TOKEN: apiConfig.apiKey, + // ANTHROPIC_BASE_URL: `http://${apiConfig.host}:${apiConfig.port}/${modelInfo.provider.id}`, + ANTHROPIC_API_KEY: modelInfo.provider.apiKey, + ANTHROPIC_AUTH_TOKEN: modelInfo.provider.apiKey, + ANTHROPIC_BASE_URL: modelInfo.provider.anthropicApiHost?.trim() || modelInfo.provider.apiHost, + ANTHROPIC_MODEL: modelInfo.modelId, + ANTHROPIC_SMALL_FAST_MODEL: modelInfo.modelId, + ELECTRON_RUN_AS_NODE: '1', + ELECTRON_NO_ATTACH_CONSOLE: '1' + } + + const errorChunks: string[] = [] + + // Build SDK options from parameters + const options: Options = { + abortController, + cwd, + env, + // model: modelInfo.modelId, + pathToClaudeCodeExecutable: this.claudeExecutablePath, + stderr: (chunk: string) => { + logger.warn('claude stderr', { chunk }) + errorChunks.push(chunk) + }, + systemPrompt: session.instructions + ? { + type: 'preset', + preset: 'claude_code', + append: session.instructions + } + : { type: 'preset', preset: 'claude_code' }, + settingSources: ['project'], + includePartialMessages: true, + permissionMode: session.configuration?.permission_mode, + maxTurns: session.configuration?.max_turns, + allowedTools: session.allowed_tools + } + + if (session.accessible_paths.length > 1) { + options.additionalDirectories = session.accessible_paths.slice(1) + } + + if (session.mcps && session.mcps.length > 0) { + // mcp configs + const mcpList: Record = {} + for (const mcpId of session.mcps) { + mcpList[mcpId] = { + type: 'http', + url: `http://${apiConfig.host}:${apiConfig.port}/v1/mcps/${mcpId}/mcp`, + headers: { + Authorization: `Bearer ${apiConfig.apiKey}` + } + } + } + options.mcpServers = mcpList + options.strictMcpConfig = true + } + + if (lastAgentSessionId) { + options.resume = lastAgentSessionId + // TODO: use fork session when we support branching sessions + // options.forkSession = true + } + + logger.info('Starting Claude Code SDK query', { + prompt, + cwd: options.cwd, + model: options.model, + permissionMode: options.permissionMode, + maxTurns: options.maxTurns, + allowedTools: options.allowedTools, + resume: options.resume + }) + + // Start async processing on the next tick so listeners can subscribe first + setImmediate(() => { + this.processSDKQuery(prompt, options, aiStream, errorChunks).catch((error) => { + logger.error('Unhandled Claude Code stream error', { + error: error instanceof Error ? { name: error.name, message: error.message } : String(error) + }) + aiStream.emit('data', { + type: 'error', + error: error instanceof Error ? error : new Error(String(error)) + }) + }) + }) + + return aiStream + } + + private async *userMessages(prompt: string) { + { + yield { + type: 'user' as const, + parent_tool_use_id: null, + session_id: '', + message: { + role: 'user' as const, + content: prompt + } + } + } + } + + /** + * Process SDK query and emit stream events + */ + private async processSDKQuery( + prompt: string, + options: Options, + stream: ClaudeCodeStream, + errorChunks: string[] + ): Promise { + const jsonOutput: SDKMessage[] = [] + let hasCompleted = false + const startTime = Date.now() + + const streamState = new ClaudeStreamState() + try { + // Process streaming responses using SDK query + for await (const message of query({ + prompt: this.userMessages(prompt), + options + })) { + if (hasCompleted) break + + jsonOutput.push(message) + + if (message.type === 'assistant' || message.type === 'user') { + logger.silly('claude response', { + message, + content: JSON.stringify(message.message.content) + }) + } else if (message.type === 'stream_event') { + logger.silly('Claude stream event', { + message, + event: JSON.stringify(message.event) + }) + } else { + logger.silly('Claude response', { + message, + event: JSON.stringify(message) + }) + } + + // Transform SDKMessage to UIMessageChunks + const chunks = transformSDKMessageToStreamParts(message, streamState) + for (const chunk of chunks) { + stream.emit('data', { + type: 'chunk', + chunk + }) + } + } + + // Successfully completed + hasCompleted = true + const duration = Date.now() - startTime + + logger.debug('SDK query completed successfully', { + duration, + messageCount: jsonOutput.length + }) + + // Emit completion event + stream.emit('data', { + type: 'complete' + }) + } catch (error) { + if (hasCompleted) return + hasCompleted = true + + const duration = Date.now() - startTime + + // Check if this is an abort error + const errorObj = error as any + const isAborted = + errorObj?.name === 'AbortError' || + errorObj?.message?.includes('aborted') || + options.abortController?.signal.aborted + + if (isAborted) { + logger.info('SDK query aborted by client disconnect', { duration }) + // Simply cleanup and return - don't emit error events + stream.emit('data', { + type: 'cancelled', + error: new Error('Request aborted by client') + }) + return + } + + errorChunks.push(errorObj instanceof Error ? errorObj.message : String(errorObj)) + const errorMessage = errorChunks.join('\n\n') + logger.error('SDK query failed', { + duration, + error: errorObj instanceof Error ? { name: errorObj.name, message: errorObj.message } : String(errorObj), + stderr: errorChunks + }) + // Emit error event + stream.emit('data', { + type: 'error', + error: new Error(errorMessage) + }) + } + } +} + +export default ClaudeCodeService diff --git a/src/main/services/agents/services/claudecode/map-claude-code-finish-reason.ts b/src/main/services/agents/services/claudecode/map-claude-code-finish-reason.ts new file mode 100644 index 0000000000..04748fbb55 --- /dev/null +++ b/src/main/services/agents/services/claudecode/map-claude-code-finish-reason.ts @@ -0,0 +1,34 @@ +// ported from https://github.com/ben-vargas/ai-sdk-provider-claude-code/blob/main/src/map-claude-code-finish-reason.ts#L22 +import type { LanguageModelV2FinishReason } from '@ai-sdk/provider' + +/** + * Maps Claude Code SDK result subtypes to AI SDK finish reasons. + * + * @param subtype - The result subtype from Claude Code SDK + * @returns The corresponding AI SDK finish reason + * + * @example + * ```typescript + * const finishReason = mapClaudeCodeFinishReason('error_max_turns'); + * // Returns: 'length' + * ``` + * + * @remarks + * Mappings: + * - 'success' -> 'stop' (normal completion) + * - 'error_max_turns' -> 'length' (hit turn limit) + * - 'error_during_execution' -> 'error' (execution error) + * - default -> 'stop' (unknown subtypes treated as normal completion) + */ +export function mapClaudeCodeFinishReason(subtype?: string): LanguageModelV2FinishReason { + switch (subtype) { + case 'success': + return 'stop' + case 'error_max_turns': + return 'length' + case 'error_during_execution': + return 'error' + default: + return 'stop' + } +} diff --git a/src/main/services/agents/services/claudecode/tools.ts b/src/main/services/agents/services/claudecode/tools.ts new file mode 100644 index 0000000000..0785827cd5 --- /dev/null +++ b/src/main/services/agents/services/claudecode/tools.ts @@ -0,0 +1,84 @@ +import { Tool } from '@types' + +// https://docs.anthropic.com/en/docs/claude-code/settings#tools-available-to-claude +export const builtinTools: Tool[] = [ + { + id: 'Bash', + name: 'Bash', + description: 'Executes shell commands in your environment', + requirePermissions: true, + type: 'builtin' + }, + { + id: 'Edit', + name: 'Edit', + description: 'Makes targeted edits to specific files', + requirePermissions: true, + type: 'builtin' + }, + { + id: 'Glob', + name: 'Glob', + description: 'Finds files based on pattern matching', + requirePermissions: false, + type: 'builtin' + }, + { + id: 'Grep', + name: 'Grep', + description: 'Searches for patterns in file contents', + requirePermissions: false, + type: 'builtin' + }, + { + id: 'MultiEdit', + name: 'MultiEdit', + description: 'Performs multiple edits on a single file atomically', + requirePermissions: true, + type: 'builtin' + }, + { + id: 'NotebookEdit', + name: 'NotebookEdit', + description: 'Modifies Jupyter notebook cells', + requirePermissions: true, + type: 'builtin' + }, + { + id: 'NotebookRead', + name: 'NotebookRead', + description: 'Reads and displays Jupyter notebook contents', + requirePermissions: false, + type: 'builtin' + }, + { id: 'Read', name: 'Read', description: 'Reads the contents of files', requirePermissions: false, type: 'builtin' }, + { + id: 'Task', + name: 'Task', + description: 'Runs a sub-agent to handle complex, multi-step tasks', + requirePermissions: false, + type: 'builtin' + }, + { + id: 'TodoWrite', + name: 'TodoWrite', + description: 'Creates and manages structured task lists', + requirePermissions: false, + type: 'builtin' + }, + { + id: 'WebFetch', + name: 'WebFetch', + description: 'Fetches content from a specified URL', + requirePermissions: true, + type: 'builtin' + }, + { + id: 'WebSearch', + name: 'WebSearch', + description: 'Performs web searches with domain filtering', + requirePermissions: true, + type: 'builtin' + }, + { id: 'Write', name: 'Write', description: 'Creates or overwrites files', requirePermissions: true, type: 'builtin' } +] diff --git a/src/main/services/agents/services/claudecode/transform.ts b/src/main/services/agents/services/claudecode/transform.ts new file mode 100644 index 0000000000..4af3716c1d --- /dev/null +++ b/src/main/services/agents/services/claudecode/transform.ts @@ -0,0 +1,703 @@ +/** + * Translates Anthropic Claude Code streaming messages into the generic AiSDK stream + * parts that the agent runtime understands. The transformer coordinates batched + * text/tool payloads, keeps per-message state using {@link ClaudeStreamState}, + * and normalises usage metadata and finish reasons so downstream consumers do + * not need to reason about Anthropic-specific payload shapes. + * + * Stream lifecycle cheatsheet (per Claude turn): + * 1. `stream_event.message_start` → emit `start-step` and mark the state as active. + * 2. `content_block_start` (by index) → open a stateful block; emits one of + * `text-start` | `reasoning-start` | `tool-input-start`. + * 3. `content_block_delta` → append incremental text / reasoning / tool JSON, + * emitting only the delta to minimise UI churn. + * 4. `content_block_stop` → emit the matching `*-end` event and release the block. + * 5. `message_delta` → capture usage + stop reason but defer emission. + * 6. `message_stop` → emit `finish-step` with cached usage & reason, then reset. + * 7. Assistant snapshots with `tool_use` finalise the tool block (`tool-call`). + * 8. User snapshots with `tool_result` emit `tool-result`/`tool-error` using the cached payload. + * 9. Assistant snapshots with plain text (when no stream events were provided) fall back to + * emitting `text-*` parts and a synthetic `finish-step`. + */ + +import { SDKMessage } from '@anthropic-ai/claude-agent-sdk' +import type { BetaStopReason } from '@anthropic-ai/sdk/resources/beta/messages/messages.mjs' +import { loggerService } from '@logger' +import type { FinishReason, LanguageModelUsage, ProviderMetadata, TextStreamPart } from 'ai' +import { v4 as uuidv4 } from 'uuid' + +import { ClaudeStreamState } from './claude-stream-state' +import { mapClaudeCodeFinishReason } from './map-claude-code-finish-reason' + +const logger = loggerService.withContext('ClaudeCodeTransform') + +type AgentStreamPart = TextStreamPart> + +type ToolUseContent = { + type: 'tool_use' + id: string + name: string + input: unknown +} + +type ToolResultContent = { + type: 'tool_result' + tool_use_id: string + content: unknown + is_error?: boolean +} + +/** + * Maps Anthropic stop reasons to the AiSDK equivalents so higher level + * consumers can treat completion states uniformly across providers. + */ +const finishReasonMapping: Record = { + end_turn: 'stop', + max_tokens: 'length', + stop_sequence: 'stop', + tool_use: 'tool-calls', + pause_turn: 'unknown', + refusal: 'content-filter' +} + +const emptyUsage: LanguageModelUsage = { + inputTokens: 0, + outputTokens: 0, + totalTokens: 0 +} + +/** + * Generates deterministic-ish message identifiers that are compatible with the + * AiSDK text stream contract. Anthropic deltas sometimes omit ids, so we create + * our own to ensure the downstream renderer can stitch chunks together. + */ +const generateMessageId = (): string => `msg_${uuidv4().replace(/-/g, '')}` + +/** + * Extracts provider metadata from the raw Claude message so we can surface it + * on every emitted stream part for observability and debugging purposes. + */ +const sdkMessageToProviderMetadata = (message: SDKMessage): ProviderMetadata => { + return { + anthropic: { + uuid: message.uuid || generateMessageId(), + session_id: message.session_id + }, + raw: message as Record + } +} + +/** + * Central entrypoint that receives Claude Code websocket events and converts + * them into AiSDK `TextStreamPart`s. The state machine tracks outstanding + * blocks across calls so that incremental deltas can be correlated correctly. + */ +export function transformSDKMessageToStreamParts(sdkMessage: SDKMessage, state: ClaudeStreamState): AgentStreamPart[] { + switch (sdkMessage.type) { + case 'assistant': + return handleAssistantMessage(sdkMessage, state) + case 'user': + return handleUserMessage(sdkMessage, state) + case 'stream_event': + return handleStreamEvent(sdkMessage, state) + case 'system': + return handleSystemMessage(sdkMessage) + case 'result': + return handleResultMessage(sdkMessage) + default: + logger.warn('Unknown SDKMessage type', { type: (sdkMessage as any).type }) + return [] + } +} + +/** + * Handles aggregated assistant messages that arrive outside of the streaming + * protocol (e.g. after a tool call finishes). We emit the appropriate + * text/tool events and close the active step once the payload is fully + * processed. + */ +function handleAssistantMessage( + message: Extract, + state: ClaudeStreamState +): AgentStreamPart[] { + const chunks: AgentStreamPart[] = [] + const providerMetadata = sdkMessageToProviderMetadata(message) + const content = message.message.content + const isStreamingActive = state.hasActiveStep() + + if (typeof content === 'string') { + if (!content) { + return chunks + } + + if (!isStreamingActive) { + state.beginStep() + chunks.push({ + type: 'start-step', + request: { body: '' }, + warnings: [] + }) + } + + const textId = message.uuid?.toString() || generateMessageId() + chunks.push({ + type: 'text-start', + id: textId, + providerMetadata + }) + chunks.push({ + type: 'text-delta', + id: textId, + text: content, + providerMetadata + }) + chunks.push({ + type: 'text-end', + id: textId, + providerMetadata + }) + return finalizeNonStreamingStep(message, state, chunks) + } + + if (!Array.isArray(content)) { + return chunks + } + + const textBlocks: string[] = [] + + for (const block of content) { + switch (block.type) { + case 'text': + if (!isStreamingActive) { + textBlocks.push(block.text) + } + break + case 'tool_use': + handleAssistantToolUse(block as ToolUseContent, providerMetadata, state, chunks) + break + default: + logger.warn('Unhandled assistant content block', { type: (block as any).type }) + break + } + } + + if (!isStreamingActive && textBlocks.length > 0) { + const id = message.uuid?.toString() || generateMessageId() + state.beginStep() + chunks.push({ + type: 'start-step', + request: { body: '' }, + warnings: [] + }) + chunks.push({ + type: 'text-start', + id, + providerMetadata + }) + chunks.push({ + type: 'text-delta', + id, + text: textBlocks.join(''), + providerMetadata + }) + chunks.push({ + type: 'text-end', + id, + providerMetadata + }) + return finalizeNonStreamingStep(message, state, chunks) + } + + return chunks +} + +/** + * Registers tool invocations with the stream state so that later tool results + * can be matched with the originating call. + */ +function handleAssistantToolUse( + block: ToolUseContent, + providerMetadata: ProviderMetadata, + state: ClaudeStreamState, + chunks: AgentStreamPart[] +): void { + chunks.push({ + type: 'tool-call', + toolCallId: block.id, + toolName: block.name, + input: block.input, + providerExecuted: true, + providerMetadata + }) + state.completeToolBlock(block.id, block.input, providerMetadata) +} + +/** + * Emits the terminating `finish-step` frame for non-streamed responses and + * clears the currently active step in the state tracker. + */ +function finalizeNonStreamingStep( + message: Extract, + state: ClaudeStreamState, + chunks: AgentStreamPart[] +): AgentStreamPart[] { + const usage = calculateUsageFromMessage(message) + const finishReason = inferFinishReason(message.message.stop_reason) + chunks.push({ + type: 'finish-step', + response: { + id: message.uuid, + timestamp: new Date(), + modelId: message.message.model ?? '' + }, + usage: usage ?? emptyUsage, + finishReason, + providerMetadata: sdkMessageToProviderMetadata(message) + }) + state.resetStep() + return chunks +} + +/** + * Converts user-originated websocket frames (text, tool results, etc.) into + * the AiSDK format. Tool results are matched back to pending tool calls via the + * shared `ClaudeStreamState` instance. + */ +function handleUserMessage( + message: Extract, + state: ClaudeStreamState +): AgentStreamPart[] { + const chunks: AgentStreamPart[] = [] + const providerMetadata = sdkMessageToProviderMetadata(message) + const content = message.message.content + + if (typeof content === 'string') { + if (!content) { + return chunks + } + + const id = message.uuid?.toString() || generateMessageId() + chunks.push({ + type: 'text-start', + id, + providerMetadata + }) + chunks.push({ + type: 'text-delta', + id, + text: content, + providerMetadata + }) + chunks.push({ + type: 'text-end', + id, + providerMetadata + }) + return chunks + } + + if (!Array.isArray(content)) { + return chunks + } + + for (const block of content) { + if (block.type === 'tool_result') { + const toolResult = block as ToolResultContent + const pendingCall = state.consumePendingToolCall(toolResult.tool_use_id) + if (toolResult.is_error) { + chunks.push({ + type: 'tool-error', + toolCallId: toolResult.tool_use_id, + toolName: pendingCall?.toolName ?? 'unknown', + input: pendingCall?.input, + error: toolResult.content, + providerExecuted: true + } as AgentStreamPart) + } else { + chunks.push({ + type: 'tool-result', + toolCallId: toolResult.tool_use_id, + toolName: pendingCall?.toolName ?? 'unknown', + input: pendingCall?.input, + output: toolResult.content, + providerExecuted: true + }) + } + } else if (block.type === 'text') { + const id = message.uuid?.toString() || generateMessageId() + chunks.push({ + type: 'text-start', + id, + providerMetadata + }) + chunks.push({ + type: 'text-delta', + id, + text: (block as { text: string }).text, + providerMetadata + }) + chunks.push({ + type: 'text-end', + id, + providerMetadata + }) + } else { + logger.warn('Unhandled user content block', { type: (block as any).type }) + } + } + + return chunks +} + +/** + * Handles the fine-grained real-time streaming protocol where Anthropic emits + * discrete events for message lifecycle, content blocks, and usage deltas. + */ +function handleStreamEvent( + message: Extract, + state: ClaudeStreamState +): AgentStreamPart[] { + const chunks: AgentStreamPart[] = [] + const providerMetadata = sdkMessageToProviderMetadata(message) + const { event } = message + + switch (event.type) { + case 'message_start': + state.beginStep() + chunks.push({ + type: 'start-step', + request: { body: '' }, + warnings: [] + }) + break + + case 'content_block_start': + handleContentBlockStart(event.index, event.content_block, providerMetadata, state, chunks) + break + + case 'content_block_delta': + handleContentBlockDelta(event.index, event.delta, providerMetadata, state, chunks) + break + + case 'content_block_stop': { + const block = state.closeBlock(event.index) + if (!block) { + logger.warn('Received content_block_stop for unknown index', { index: event.index }) + break + } + + switch (block.kind) { + case 'text': + chunks.push({ + type: 'text-end', + id: block.id, + providerMetadata + }) + break + case 'reasoning': + chunks.push({ + type: 'reasoning-end', + id: block.id, + providerMetadata + }) + break + case 'tool': + chunks.push({ + type: 'tool-input-end', + id: block.toolCallId, + providerMetadata + }) + break + default: + break + } + break + } + + case 'message_delta': { + const finishReason = event.delta.stop_reason + ? mapStopReason(event.delta.stop_reason as BetaStopReason) + : undefined + const usage = convertUsage(event.usage) + state.setPendingUsage(usage, finishReason) + break + } + + case 'message_stop': { + const pending = state.getPendingUsage() + chunks.push({ + type: 'finish-step', + response: { + id: message.uuid, + timestamp: new Date(), + modelId: '' + }, + usage: pending.usage ?? emptyUsage, + finishReason: pending.finishReason ?? 'stop', + providerMetadata + }) + state.resetStep() + break + } + + default: + logger.warn('Unknown stream event type', { type: (event as any).type }) + break + } + + return chunks +} + +/** + * Opens the appropriate block type when Claude starts streaming a new content + * section so later deltas know which logical entity to append to. + */ +function handleContentBlockStart( + index: number, + contentBlock: any, + providerMetadata: ProviderMetadata, + state: ClaudeStreamState, + chunks: AgentStreamPart[] +): void { + switch (contentBlock.type) { + case 'text': { + const block = state.openTextBlock(index, generateMessageId()) + chunks.push({ + type: 'text-start', + id: block.id, + providerMetadata + }) + break + } + case 'thinking': + case 'redacted_thinking': { + const block = state.openReasoningBlock(index, generateMessageId(), contentBlock.type === 'redacted_thinking') + chunks.push({ + type: 'reasoning-start', + id: block.id, + providerMetadata + }) + break + } + case 'tool_use': { + const block = state.openToolBlock(index, { + toolCallId: contentBlock.id, + toolName: contentBlock.name, + providerMetadata + }) + chunks.push({ + type: 'tool-input-start', + id: block.toolCallId, + toolName: block.toolName, + providerMetadata + }) + break + } + default: + logger.warn('Unhandled content_block_start type', { type: contentBlock.type }) + break + } +} + +/** + * Applies incremental deltas to the active block (text, thinking, tool input) + * and emits the translated AiSDK chunk immediately. + */ +function handleContentBlockDelta( + index: number, + delta: any, + providerMetadata: ProviderMetadata, + state: ClaudeStreamState, + chunks: AgentStreamPart[] +): void { + switch (delta.type) { + case 'text_delta': { + const block = state.appendTextDelta(index, delta.text) + if (!block) { + logger.warn('Received text_delta for unknown block', { index }) + return + } + chunks.push({ + type: 'text-delta', + id: block.id, + text: block.text, + providerMetadata + }) + break + } + case 'thinking_delta': { + const block = state.appendReasoningDelta(index, delta.thinking) + if (!block) { + logger.warn('Received thinking_delta for unknown block', { index }) + return + } + chunks.push({ + type: 'reasoning-delta', + id: block.id, + text: delta.thinking, + providerMetadata + }) + break + } + case 'signature_delta': { + const block = state.getBlock(index) + if (block && block.kind === 'reasoning') { + chunks.push({ + type: 'reasoning-delta', + id: block.id, + text: '', + providerMetadata + }) + } + break + } + case 'input_json_delta': { + const block = state.appendToolInputDelta(index, delta.partial_json) + if (!block) { + logger.warn('Received input_json_delta for unknown block', { index }) + return + } + chunks.push({ + type: 'tool-input-delta', + id: block.toolCallId, + delta: block.inputBuffer, + providerMetadata + }) + break + } + default: + logger.warn('Unhandled content_block_delta type', { type: delta.type }) + break + } +} + +/** + * System messages currently only deliver the session bootstrap payload. We + * forward it as both a `start` marker and a raw snapshot for diagnostics. + */ +function handleSystemMessage(message: Extract): AgentStreamPart[] { + const chunks: AgentStreamPart[] = [] + if (message.subtype === 'init') { + chunks.push({ + type: 'start' + }) + chunks.push({ + type: 'raw', + rawValue: { + type: 'init', + session_id: message.session_id, + slash_commands: message.slash_commands, + tools: message.tools, + raw: message + } + }) + } else if (message.subtype === 'compact_boundary') { + chunks.push({ + type: 'raw', + rawValue: { + type: 'compact', + session_id: message.session_id, + raw: message + } + }) + } + return chunks +} + +/** + * Terminal result messages arrive once the Claude Code session concludes. + * Successful runs yield a `finish` frame with aggregated usage metrics, while + * failures are surfaced as `error` frames. + */ +function handleResultMessage(message: Extract): AgentStreamPart[] { + const chunks: AgentStreamPart[] = [] + + let usage: LanguageModelUsage | undefined + if ('usage' in message) { + usage = { + inputTokens: message.usage.input_tokens ?? 0, + outputTokens: message.usage.output_tokens ?? 0, + totalTokens: (message.usage.input_tokens ?? 0) + (message.usage.output_tokens ?? 0) + } + } + + if (message.subtype === 'success') { + chunks.push({ + type: 'finish', + totalUsage: usage ?? emptyUsage, + finishReason: mapClaudeCodeFinishReason(message.subtype), + providerMetadata: { + ...sdkMessageToProviderMetadata(message), + usage: message.usage, + durationMs: message.duration_ms, + costUsd: message.total_cost_usd, + raw: message + } + } as AgentStreamPart) + } else { + chunks.push({ + type: 'error', + error: { + message: `${message.subtype}: Process failed after ${message.num_turns} turns` + } + } as AgentStreamPart) + } + return chunks +} + +/** + * Normalises usage payloads so the caller always receives numeric values even + * when the provider omits certain fields. + */ +function convertUsage( + usage?: { + input_tokens?: number | null + output_tokens?: number | null + } | null +): LanguageModelUsage | undefined { + if (!usage) { + return undefined + } + const inputTokens = usage.input_tokens ?? 0 + const outputTokens = usage.output_tokens ?? 0 + return { + inputTokens, + outputTokens, + totalTokens: inputTokens + outputTokens + } +} + +/** + * Anthropic-only wrapper around {@link finishReasonMapping} that defaults to + * `unknown` to avoid surprising downstream consumers when new stop reasons are + * introduced. + */ +function mapStopReason(reason: BetaStopReason): FinishReason { + return finishReasonMapping[reason] ?? 'unknown' +} + +/** + * Extracts token accounting details from an assistant message, if available. + */ +function calculateUsageFromMessage( + message: Extract +): LanguageModelUsage | undefined { + const usage = message.message.usage + if (!usage) return undefined + return { + inputTokens: usage.input_tokens ?? 0, + outputTokens: usage.output_tokens ?? 0, + totalTokens: (usage.input_tokens ?? 0) + (usage.output_tokens ?? 0) + } +} + +/** + * Converts Anthropic stop reasons into AiSDK finish reasons, falling back to a + * generic `stop` if the provider omits the detail entirely. + */ +function inferFinishReason(stopReason: BetaStopReason | null | undefined): FinishReason { + if (!stopReason) return 'stop' + return mapStopReason(stopReason) +} + +export { ClaudeStreamState } diff --git a/src/main/services/agents/services/index.ts b/src/main/services/agents/services/index.ts new file mode 100644 index 0000000000..e6e545a442 --- /dev/null +++ b/src/main/services/agents/services/index.ts @@ -0,0 +1,26 @@ +/** + * Agent Services Module + * + * This module provides service classes for managing agents, sessions, and session messages. + * All services extend BaseService and provide database operations with proper error handling. + */ + +// Service classes +export { AgentService } from './AgentService' +export { SessionMessageService } from './SessionMessageService' +export { SessionService } from './SessionService' + +// Service instances (singletons) +export { agentService } from './AgentService' +export { sessionMessageService } from './SessionMessageService' +export { sessionService } from './SessionService' + +// Type definitions for service requests and responses +export type { AgentEntity, AgentSessionEntity, CreateAgentRequest, UpdateAgentRequest } from '@types' +export type { + AgentSessionMessageEntity, + CreateSessionRequest, + GetAgentSessionResponse, + ListOptions as SessionListOptions, + UpdateSessionRequest +} from '@types' diff --git a/src/main/services/mcp/shell-env.ts b/src/main/services/mcp/shell-env.ts deleted file mode 100644 index 831cb76b61..0000000000 --- a/src/main/services/mcp/shell-env.ts +++ /dev/null @@ -1,122 +0,0 @@ -import { loggerService } from '@logger' -import { spawn } from 'child_process' -import os from 'os' - -const logger = loggerService.withContext('ShellEnv') - -/** - * Spawns a login shell in the user's home directory to capture its environment variables. - * @returns {Promise} A promise that resolves with an object containing - * the environment variables, or rejects with an error. - */ -function getLoginShellEnvironment(): Promise> { - return new Promise((resolve, reject) => { - const homeDirectory = os.homedir() - if (!homeDirectory) { - return reject(new Error("Could not determine user's home directory.")) - } - - let shellPath = process.env.SHELL - let commandArgs - let shellCommandToGetEnv - - const platform = os.platform() - - if (platform === 'win32') { - // On Windows, 'cmd.exe' is the common shell. - // The 'set' command lists environment variables. - // We don't typically talk about "login shells" in the same way, - // but cmd will load the user's environment. - shellPath = process.env.COMSPEC || 'cmd.exe' - shellCommandToGetEnv = 'set' - commandArgs = ['/c', shellCommandToGetEnv] // /c Carries out the command specified by string and then terminates - } else { - // For POSIX systems (Linux, macOS) - if (!shellPath) { - // Fallback if process.env.SHELL is not set (less common for interactive users) - // Defaulting to bash, but this might not be the user's actual login shell. - // A more robust solution might involve checking /etc/passwd or similar, - // but that's more complex and often requires higher privileges or native modules. - logger.warn("process.env.SHELL is not set. Defaulting to /bin/bash. This might not be the user's login shell.") - shellPath = '/bin/bash' // A common default - } - // -l: Make it a login shell. This sources profile files like .profile, .bash_profile, .zprofile etc. - // -i: Make it interactive. Some shells or profile scripts behave differently. - // 'env': The command to print environment variables. - // Using 'env -0' would be more robust for parsing if values contain newlines, - // but requires splitting by null character. For simplicity, we'll use 'env'. - shellCommandToGetEnv = 'env' - commandArgs = ['-ilc', shellCommandToGetEnv] // -i for interactive, -l for login, -c to execute command - } - - logger.debug(`Spawning shell: ${shellPath} with args: ${commandArgs.join(' ')} in ${homeDirectory}`) - - const child = spawn(shellPath, commandArgs, { - cwd: homeDirectory, // Run the command in the user's home directory - detached: true, // Allows the parent to exit independently of the child - stdio: ['ignore', 'pipe', 'pipe'], // stdin, stdout, stderr - shell: false // We are specifying the shell command directly - }) - - let output = '' - let errorOutput = '' - - child.stdout.on('data', (data) => { - output += data.toString() - }) - - child.stderr.on('data', (data) => { - errorOutput += data.toString() - }) - - child.on('error', (error) => { - logger.error(`Failed to start shell process: ${shellPath}`, error) - reject(new Error(`Failed to start shell: ${error.message}`)) - }) - - child.on('close', (code) => { - if (code !== 0) { - const errorMessage = `Shell process exited with code ${code}. Shell: ${shellPath}. Args: ${commandArgs.join(' ')}. CWD: ${homeDirectory}. Stderr: ${errorOutput.trim()}` - logger.error(errorMessage) - return reject(new Error(errorMessage)) - } - - if (errorOutput.trim()) { - // Some shells might output warnings or non-fatal errors to stderr - // during profile loading. Log it, but proceed if exit code is 0. - logger.warn(`Shell process stderr output (even with exit code 0):\n${errorOutput.trim()}`) - } - - const env: Record = {} - const lines = output.split('\n') - - lines.forEach((line) => { - const trimmedLine = line.trim() - if (trimmedLine) { - const separatorIndex = trimmedLine.indexOf('=') - if (separatorIndex > 0) { - // Ensure '=' is present and it's not the first character - const key = trimmedLine.substring(0, separatorIndex) - const value = trimmedLine.substring(separatorIndex + 1) - env[key] = value - } - } - }) - - if (Object.keys(env).length === 0 && output.length < 100) { - // Arbitrary small length check - // This might indicate an issue if no env vars were parsed or output was minimal - logger.warn( - 'Parsed environment is empty or output was very short. This might indicate an issue with shell execution or environment variable retrieval.' - ) - logger.warn(`Raw output from shell:\n${output}`) - } - - env.PATH = env.Path || env.PATH || '' - - resolve(env) - }) - }) -} - -export default getLoginShellEnvironment diff --git a/src/main/utils/ocr.ts b/src/main/utils/ocr.ts index 16eeee7b60..fdb5ff1743 100644 --- a/src/main/utils/ocr.ts +++ b/src/main/utils/ocr.ts @@ -1,8 +1,9 @@ import type { ImageFileMetadata } from '@types' import { readFile } from 'fs/promises' -import sharp from 'sharp' const preprocessImage = async (buffer: Buffer): Promise => { + // Delayed loading: The Sharp module is only loaded when the OCR functionality is actually needed, not at app startup + const sharp = (await import('sharp')).default return sharp(buffer) .grayscale() // 转为灰度 .normalize() diff --git a/src/main/utils/shell-env.ts b/src/main/utils/shell-env.ts new file mode 100644 index 0000000000..8fa5a46f3e --- /dev/null +++ b/src/main/utils/shell-env.ts @@ -0,0 +1,260 @@ +import os from 'node:os' +import path from 'node:path' + +import { loggerService } from '@logger' +import { isMac, isWin } from '@main/constant' +import { spawn } from 'child_process' +import { memoize } from 'lodash' + +const logger = loggerService.withContext('ShellEnv') + +// Give shells enough time to source profile files, but fail fast when they hang. +const SHELL_ENV_TIMEOUT_MS = 15_000 + +/** + * Ensures the Cherry Studio bin directory is appended to the user's PATH while + * preserving the original key casing and avoiding duplicate segments. + */ +const appendCherryBinToPath = (env: Record) => { + const pathSeparator = isWin ? ';' : ':' + const homeDirFromEnv = env.HOME || env.Home || env.USERPROFILE || env.UserProfile || os.homedir() + const cherryBinPath = path.join(homeDirFromEnv, '.cherrystudio', 'bin') + const pathKeys = Object.keys(env).filter((key) => key.toLowerCase() === 'path') + const canonicalPathKey = pathKeys[0] || (isWin ? 'Path' : 'PATH') + const existingPathValue = env[canonicalPathKey] || env.PATH || '' + + const normaliseSegment = (segment: string) => { + const normalized = path.normalize(segment) + return isWin ? normalized.toLowerCase() : normalized + } + + const uniqueSegments: string[] = [] + const seenSegments = new Set() + const pushIfUnique = (segment: string) => { + if (!segment) { + return + } + const canonicalSegment = normaliseSegment(segment) + if (!seenSegments.has(canonicalSegment)) { + seenSegments.add(canonicalSegment) + uniqueSegments.push(segment) + } + } + + existingPathValue + .split(pathSeparator) + .map((segment) => segment.trim()) + .forEach(pushIfUnique) + + pushIfUnique(cherryBinPath) + + const updatedPath = uniqueSegments.join(pathSeparator) + + if (pathKeys.length > 0) { + pathKeys.forEach((key) => { + env[key] = updatedPath + }) + } else { + env[canonicalPathKey] = updatedPath + } + + if (!isWin) { + env.PATH = updatedPath + } +} + +/** + * Spawns a login shell in the user's home directory to capture its environment variables. + * + * We explicitly run a login + interactive shell so it sources the same init files that a user + * would typically rely on inside their terminal. Many CLIs export PATH or other variables from + * these scripts; capturing them keeps spawned processes aligned with the user’s expectations. + * + * Timeout handling is important because profile scripts might block forever (e.g. misconfigured + * `read` or prompts). We proactively kill the shell and surface an error in that case so that + * the app does not hang. + * @returns {Promise} A promise that resolves with an object containing + * the environment variables, or rejects with an error. + */ +function getLoginShellEnvironment(): Promise> { + return new Promise((resolve, reject) => { + const homeDirectory = + process.env.HOME || process.env.Home || process.env.USERPROFILE || process.env.UserProfile || os.homedir() + if (!homeDirectory) { + return reject(new Error("Could not determine user's home directory.")) + } + + let shellPath = process.env.SHELL + let commandArgs + let shellCommandToGetEnv + + if (isWin) { + // On Windows, 'cmd.exe' is the common shell. + // The 'set' command lists environment variables. + // We don't typically talk about "login shells" in the same way, + // but cmd will load the user's environment. + shellPath = process.env.COMSPEC || 'cmd.exe' + shellCommandToGetEnv = 'set' + commandArgs = ['/c', shellCommandToGetEnv] // /c Carries out the command specified by string and then terminates + } else { + // For POSIX systems (Linux, macOS) + if (!shellPath) { + // Fallback if process.env.SHELL is not set (less common for interactive users) + // A more robust solution might involve checking /etc/passwd or similar, + // but that's more complex and often requires higher privileges or native modules. + if (isMac) { + // macOS defaults to zsh since Catalina (10.15) + logger.warn( + "process.env.SHELL is not set. Defaulting to /bin/zsh for macOS. This might not be the user's login shell." + ) + shellPath = '/bin/zsh' + } else { + // Other POSIX systems (Linux) default to bash + logger.warn( + "process.env.SHELL is not set. Defaulting to /bin/bash. This might not be the user's login shell." + ) + shellPath = '/bin/bash' + } + } + // -l: Make it a login shell. This sources profile files like .profile, .bash_profile, .zprofile etc. + // -i: Make it interactive. Some shells or profile scripts behave differently. + // 'env': The command to print environment variables. + // Using 'env -0' would be more robust for parsing if values contain newlines, + // but requires splitting by null character. For simplicity, we'll use 'env'. + shellCommandToGetEnv = 'env' + commandArgs = ['-ilc', shellCommandToGetEnv] // -i for interactive, -l for login, -c to execute command + } + + logger.debug(`Spawning shell: ${shellPath} with args: ${commandArgs.join(' ')} in ${homeDirectory}`) + + let settled = false + let timeoutId: NodeJS.Timeout | undefined + + const cleanup = () => { + if (timeoutId) { + clearTimeout(timeoutId) + timeoutId = undefined + } + } + + const resolveOnce = (value: Record) => { + if (settled) { + return + } + settled = true + cleanup() + resolve(value) + } + + const rejectOnce = (error: Error) => { + if (settled) { + return + } + settled = true + cleanup() + reject(error) + } + + const child = spawn(shellPath, commandArgs, { + cwd: homeDirectory, // Run the command in the user's home directory + detached: false, // Stay attached so we can clean up reliably + stdio: ['ignore', 'pipe', 'pipe'], // stdin, stdout, stderr + shell: false // We are specifying the shell command directly + }) + + let output = '' + let errorOutput = '' + + // Protects against shells that wait for user input or hang during profile sourcing. + timeoutId = setTimeout(() => { + const errorMessage = `Timed out after ${SHELL_ENV_TIMEOUT_MS}ms while retrieving shell environment. Shell: ${shellPath}. Args: ${commandArgs.join( + ' ' + )}. CWD: ${homeDirectory}` + logger.error(errorMessage) + child.kill() + rejectOnce(new Error(errorMessage)) + }, SHELL_ENV_TIMEOUT_MS) + + child.stdout.on('data', (data) => { + output += data.toString() + }) + + child.stderr.on('data', (data) => { + errorOutput += data.toString() + }) + + child.on('error', (error) => { + logger.error(`Failed to start shell process: ${shellPath}`, error) + rejectOnce(new Error(`Failed to start shell: ${error.message}`)) + }) + + child.on('close', (code) => { + if (settled) { + return + } + + if (code !== 0) { + const errorMessage = `Shell process exited with code ${code}. Shell: ${shellPath}. Args: ${commandArgs.join(' ')}. CWD: ${homeDirectory}. Stderr: ${errorOutput.trim()}` + logger.error(errorMessage) + return rejectOnce(new Error(errorMessage)) + } + + if (errorOutput.trim()) { + // Some shells might output warnings or non-fatal errors to stderr + // during profile loading. Log it, but proceed if exit code is 0. + logger.warn(`Shell process stderr output (even with exit code 0):\n${errorOutput.trim()}`) + } + + // Convert each VAR=VALUE line into our env map. + const env: Record = {} + const lines = output.split(/\r?\n/) + + lines.forEach((line) => { + const trimmedLine = line.trim() + if (trimmedLine) { + const separatorIndex = trimmedLine.indexOf('=') + if (separatorIndex > 0) { + // Ensure '=' is present and it's not the first character + const key = trimmedLine.substring(0, separatorIndex) + const value = trimmedLine.substring(separatorIndex + 1) + env[key] = value + } + } + }) + + if (Object.keys(env).length === 0 && output.length < 100) { + // Arbitrary small length check + // This might indicate an issue if no env vars were parsed or output was minimal + logger.warn( + 'Parsed environment is empty or output was very short. This might indicate an issue with shell execution or environment variable retrieval.' + ) + logger.warn(`Raw output from shell:\n${output}`) + } + + appendCherryBinToPath(env) + + resolveOnce(env) + }) + }) +} + +const memoizedGetShellEnvs = memoize(async () => { + try { + return await getLoginShellEnvironment() + } catch (error) { + logger.error('Failed to get shell environment, falling back to process.env', { error }) + // Fallback to current process environment with cherry studio bin path + const fallbackEnv: Record = {} + for (const key in process.env) { + fallbackEnv[key] = process.env[key] || '' + } + appendCherryBinToPath(fallbackEnv) + return fallbackEnv + } +}) + +export default memoizedGetShellEnvs + +export const refreshShellEnvCache = () => { + memoizedGetShellEnvs.cache.clear?.() +} diff --git a/src/preload/index.ts b/src/preload/index.ts index 41cdfb124d..e10352619a 100644 --- a/src/preload/index.ts +++ b/src/preload/index.ts @@ -173,7 +173,8 @@ const api = { openPath: (path: string) => ipcRenderer.invoke(IpcChannel.File_OpenPath, path), save: (path: string, content: string | NodeJS.ArrayBufferView, options?: any) => ipcRenderer.invoke(IpcChannel.File_Save, path, content, options), - selectFolder: (options?: OpenDialogOptions) => ipcRenderer.invoke(IpcChannel.File_SelectFolder, options), + selectFolder: (options?: OpenDialogOptions): Promise => + ipcRenderer.invoke(IpcChannel.File_SelectFolder, options), saveImage: (name: string, data: string) => ipcRenderer.invoke(IpcChannel.File_SaveImage, name, data), binaryImage: (fileId: string) => ipcRenderer.invoke(IpcChannel.File_BinaryImage, fileId), base64Image: (fileId: string): Promise<{ mime: string; base64: string; data: string }> => @@ -204,7 +205,8 @@ const api = { } ipcRenderer.on('file-change', listener) return () => ipcRenderer.off('file-change', listener) - } + }, + showInFolder: (path: string): Promise => ipcRenderer.invoke(IpcChannel.File_ShowInFolder, path) }, fs: { read: (pathOrUrl: string, encoding?: BufferEncoding) => ipcRenderer.invoke(IpcChannel.Fs_Read, pathOrUrl, encoding), diff --git a/src/renderer/src/Router.tsx b/src/renderer/src/Router.tsx index 41a94de7c3..fb555d8bc3 100644 --- a/src/renderer/src/Router.tsx +++ b/src/renderer/src/Router.tsx @@ -9,7 +9,6 @@ import { ErrorBoundary } from './components/ErrorBoundary' import TabsContainer from './components/Tab/TabContainer' import NavigationHandler from './handler/NavigationHandler' import { useNavbarPosition } from './hooks/useNavbar' -import AgentsPage from './pages/agents/AgentsPage' import CodeToolsPage from './pages/code/CodeToolsPage' import FilesPage from './pages/files/FilesPage' import HomePage from './pages/home/HomePage' @@ -20,6 +19,7 @@ import MinAppsPage from './pages/minapps/MinAppsPage' import NotesPage from './pages/notes/NotesPage' import PaintingsRoutePage from './pages/paintings/PaintingsRoutePage' import SettingsPage from './pages/settings/SettingsPage' +import AssistantPresetsPage from './pages/store/assistants/presets/AssistantPresetsPage' import TranslatePage from './pages/translate/TranslatePage' const Router: FC = () => { @@ -30,7 +30,7 @@ const Router: FC = () => { } /> - } /> + } /> } /> } /> } /> diff --git a/src/renderer/src/aiCore/chunk/AiSdkToChunkAdapter.ts b/src/renderer/src/aiCore/chunk/AiSdkToChunkAdapter.ts index 1f01806402..4fac440b03 100644 --- a/src/renderer/src/aiCore/chunk/AiSdkToChunkAdapter.ts +++ b/src/renderer/src/aiCore/chunk/AiSdkToChunkAdapter.ts @@ -9,6 +9,7 @@ import { WebSearchSource } from '@renderer/types' import type { Chunk } from '@renderer/types/chunk' import { ChunkType } from '@renderer/types/chunk' import { convertLinks, flushLinkConverterBuffer } from '@renderer/utils/linkConverter' +import type { ClaudeCodeRawValue } from '@shared/agents/claudecode/types' import type { TextStreamPart, ToolSet } from 'ai' import { ToolCallChunkHandler } from './handleToolCallChunk' @@ -24,6 +25,7 @@ export class AiSdkToChunkAdapter { private accumulate: boolean | undefined private isFirstChunk = true private enableWebSearch: boolean = false + private onSessionUpdate?: (sessionId: string) => void private responseStartTimestamp: number | null = null private firstTokenTimestamp: number | null = null @@ -31,11 +33,13 @@ export class AiSdkToChunkAdapter { private onChunk: (chunk: Chunk) => void, mcpTools: MCPTool[] = [], accumulate?: boolean, - enableWebSearch?: boolean + enableWebSearch?: boolean, + onSessionUpdate?: (sessionId: string) => void ) { this.toolCallHandler = new ToolCallChunkHandler(onChunk, mcpTools) this.accumulate = accumulate this.enableWebSearch = enableWebSearch || false + this.onSessionUpdate = onSessionUpdate } private markFirstTokenIfNeeded() { @@ -119,6 +123,17 @@ export class AiSdkToChunkAdapter { ) { logger.silly(`AI SDK chunk type: ${chunk.type}`, chunk) switch (chunk.type) { + case 'raw': { + const agentRawMessage = chunk.rawValue as ClaudeCodeRawValue + if (agentRawMessage.type === 'init' && agentRawMessage.session_id) { + this.onSessionUpdate?.(agentRawMessage.session_id) + } + this.onChunk({ + type: ChunkType.RAW, + content: agentRawMessage + }) + break + } // === 文本相关事件 === case 'text-start': this.onChunk({ diff --git a/src/renderer/src/aiCore/index_new.ts b/src/renderer/src/aiCore/index_new.ts index 813ac31b9d..a210b956d0 100644 --- a/src/renderer/src/aiCore/index_new.ts +++ b/src/renderer/src/aiCore/index_new.ts @@ -14,6 +14,7 @@ import { addSpan, endSpan } from '@renderer/services/SpanManagerService' import type { StartSpanParams } from '@renderer/trace/types/ModelSpanEntity' import type { Assistant, GenerateImageParams, Model, Provider } from '@renderer/types' import type { AiSdkModel, StreamTextParams } from '@renderer/types/aiCoreTypes' +import { buildClaudeCodeSystemModelMessage } from '@shared/anthropic' import { type ImageModel, type LanguageModel, type Provider as AiSdkProvider, wrapLanguageModel } from 'ai' import AiSdkToChunkAdapter from './chunk/AiSdkToChunkAdapter' @@ -22,7 +23,6 @@ import type { CompletionsParams, CompletionsResult } from './legacy/middleware/s import type { AiSdkMiddlewareConfig } from './middleware/AiSdkMiddlewareBuilder' import { buildAiSdkMiddlewares } from './middleware/AiSdkMiddlewareBuilder' import { buildPlugins } from './plugins/PluginBuilder' -import { buildClaudeCodeSystemMessage } from './provider/config/anthropic' import { createAiSdkProvider } from './provider/factory' import { getActualProvider, @@ -123,13 +123,9 @@ export default class ModernAiProvider { } if (this.actualProvider.id === 'anthropic' && this.actualProvider.authType === 'oauth') { - const claudeCodeSystemMessage = buildClaudeCodeSystemMessage(params.system) + const claudeCodeSystemMessage = buildClaudeCodeSystemModelMessage(params.system) params.system = undefined // 清除原有system,避免重复 - if (Array.isArray(params.messages)) { - params.messages = [...claudeCodeSystemMessage, ...params.messages] - } else { - params.messages = claudeCodeSystemMessage - } + params.messages = [...claudeCodeSystemMessage, ...(params.messages || [])] } if (config.topicId && (await preferenceService.get('app.developer_mode.enabled'))) { diff --git a/src/renderer/src/aiCore/legacy/clients/anthropic/AnthropicAPIClient.ts b/src/renderer/src/aiCore/legacy/clients/anthropic/AnthropicAPIClient.ts index 246f4a0f8f..2d64b1ef7c 100644 --- a/src/renderer/src/aiCore/legacy/clients/anthropic/AnthropicAPIClient.ts +++ b/src/renderer/src/aiCore/legacy/clients/anthropic/AnthropicAPIClient.ts @@ -66,6 +66,7 @@ import { mcpToolsToAnthropicTools } from '@renderer/utils/mcp-tools' import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find' +import { buildClaudeCodeSystemMessage, getSdkClient } from '@shared/anthropic' import { t } from 'i18next' import type { GenericChunk } from '../../middleware/schemas' @@ -84,8 +85,8 @@ export class AnthropicAPIClient extends BaseApiClient< ToolUnion > { oauthToken: string | undefined = undefined - isOAuthMode: boolean = false sdkInstance: Anthropic | AnthropicVertex | undefined = undefined + constructor(provider: Provider) { super(provider) } @@ -94,84 +95,25 @@ export class AnthropicAPIClient extends BaseApiClient< if (this.sdkInstance) { return this.sdkInstance } - if (this.provider.authType === 'oauth') { - if (!this.oauthToken) { - throw new Error('OAuth token is not available') - } - this.sdkInstance = new Anthropic({ - authToken: this.oauthToken, - baseURL: 'https://api.anthropic.com', - dangerouslyAllowBrowser: true, - defaultHeaders: { - 'Content-Type': 'application/json', - 'anthropic-version': '2023-06-01', - 'anthropic-beta': 'oauth-2025-04-20' - // ...this.provider.extra_headers - } - }) - } else { - this.sdkInstance = new Anthropic({ - apiKey: this.apiKey, - baseURL: this.getBaseURL(), - dangerouslyAllowBrowser: true, - defaultHeaders: { - 'anthropic-beta': 'output-128k-2025-02-19', - ...this.provider.extra_headers - } - }) + this.oauthToken = await window.api.anthropic_oauth.getAccessToken() } - + this.sdkInstance = getSdkClient(this.provider, this.oauthToken) return this.sdkInstance } - private buildClaudeCodeSystemMessage(system?: string | Array): string | Array { - const defaultClaudeCodeSystem = `You are Claude Code, Anthropic's official CLI for Claude.` - if (!system) { - return defaultClaudeCodeSystem - } - - if (typeof system === 'string') { - if (system.trim() === defaultClaudeCodeSystem) { - return system - } - return [ - { - type: 'text', - text: defaultClaudeCodeSystem - }, - { - type: 'text', - text: system - } - ] - } - - if (system[0].text.trim() != defaultClaudeCodeSystem) { - system.unshift({ - type: 'text', - text: defaultClaudeCodeSystem - }) - } - - return system - } - override async createCompletions( payload: AnthropicSdkParams, options?: Anthropic.RequestOptions ): Promise { if (this.provider.authType === 'oauth') { - this.oauthToken = await window.api.anthropic_oauth.getAccessToken() - this.isOAuthMode = true - logger.info('[Anthropic Provider] Using OAuth token for authentication') - payload.system = this.buildClaudeCodeSystemMessage(payload.system) + payload.system = buildClaudeCodeSystemMessage(payload.system) } const sdk = (await this.getSdkInstance()) as Anthropic if (payload.stream) { return sdk.messages.stream(payload, options) } - return await sdk.messages.create(payload, options) + return sdk.messages.create(payload, options) } // @ts-ignore sdk未提供 @@ -181,14 +123,8 @@ export class AnthropicAPIClient extends BaseApiClient< } override async listModels(): Promise { - if (this.provider.authType === 'oauth') { - this.oauthToken = await window.api.anthropic_oauth.getAccessToken() - this.isOAuthMode = true - logger.info('[Anthropic Provider] Using OAuth token for authentication') - } const sdk = (await this.getSdkInstance()) as Anthropic const response = await sdk.models.list() - return response.data } diff --git a/src/renderer/src/aiCore/provider/config/anthropic.ts b/src/renderer/src/aiCore/provider/config/anthropic.ts deleted file mode 100644 index 3a8927de7b..0000000000 --- a/src/renderer/src/aiCore/provider/config/anthropic.ts +++ /dev/null @@ -1,24 +0,0 @@ -import type { SystemModelMessage } from 'ai' - -export function buildClaudeCodeSystemMessage(system?: string): Array { - const defaultClaudeCodeSystem = `You are Claude Code, Anthropic's official CLI for Claude.` - if (!system || system.trim() === defaultClaudeCodeSystem) { - return [ - { - role: 'system', - content: defaultClaudeCodeSystem - } - ] - } - - return [ - { - role: 'system', - content: defaultClaudeCodeSystem - }, - { - role: 'system', - content: system - } - ] -} diff --git a/src/renderer/src/aiCore/provider/providerConfig.ts b/src/renderer/src/aiCore/provider/providerConfig.ts index ccbdc008a1..020cd2a65c 100644 --- a/src/renderer/src/aiCore/provider/providerConfig.ts +++ b/src/renderer/src/aiCore/provider/providerConfig.ts @@ -79,9 +79,37 @@ function handleSpecialProviders(model: Model, provider: Provider): Provider { /** * 格式化provider的API Host */ +function formatAnthropicApiHost(host: string): string { + const trimmedHost = host?.trim() + + if (!trimmedHost) { + return '' + } + + if (trimmedHost.endsWith('/')) { + return trimmedHost + } + + if (trimmedHost.endsWith('/v1')) { + return `${trimmedHost}/` + } + + return formatApiHost(trimmedHost) +} + function formatProviderApiHost(provider: Provider): Provider { const formatted = { ...provider } - if (formatted.type === 'gemini') { + if (formatted.anthropicApiHost) { + formatted.anthropicApiHost = formatAnthropicApiHost(formatted.anthropicApiHost) + } + + if (formatted.type === 'anthropic') { + const baseHost = formatted.anthropicApiHost || formatted.apiHost + formatted.apiHost = formatAnthropicApiHost(baseHost) + if (!formatted.anthropicApiHost) { + formatted.anthropicApiHost = formatted.apiHost + } + } else if (formatted.type === 'gemini') { formatted.apiHost = formatApiHost(formatted.apiHost, 'v1beta') } else { formatted.apiHost = formatApiHost(formatted.apiHost) diff --git a/src/renderer/src/api/agent.ts b/src/renderer/src/api/agent.ts new file mode 100644 index 0000000000..2b31873ce1 --- /dev/null +++ b/src/renderer/src/api/agent.ts @@ -0,0 +1,245 @@ +import { loggerService } from '@logger' +import { formatAgentServerError } from '@renderer/utils/error' +import { + AddAgentForm, + AgentServerErrorSchema, + ApiModelsFilter, + ApiModelsResponse, + ApiModelsResponseSchema, + CreateAgentRequest, + CreateAgentResponse, + CreateAgentResponseSchema, + CreateSessionForm, + CreateSessionRequest, + GetAgentResponse, + GetAgentResponseSchema, + GetAgentSessionResponse, + GetAgentSessionResponseSchema, + ListAgentSessionsResponse, + ListAgentSessionsResponseSchema, + type ListAgentsResponse, + ListAgentsResponseSchema, + objectEntries, + objectKeys, + UpdateAgentForm, + UpdateAgentRequest, + UpdateAgentResponse, + UpdateAgentResponseSchema, + UpdateSessionForm, + UpdateSessionRequest +} from '@types' +import axios, { Axios, AxiosRequestConfig, isAxiosError } from 'axios' +import { ZodError } from 'zod' + +type ApiVersion = 'v1' + +const logger = loggerService.withContext('AgentApiClient') + +// const logger = loggerService.withContext('AgentClient') +const processError = (error: unknown, fallbackMessage: string) => { + logger.error(fallbackMessage, error as Error) + if (isAxiosError(error)) { + const result = AgentServerErrorSchema.safeParse(error.response?.data) + if (result.success) { + return new Error(formatAgentServerError(result.data)) + } + } else if (error instanceof ZodError) { + return error + } + return new Error(fallbackMessage, { cause: error }) +} + +export class AgentApiClient { + private axios: Axios + private apiVersion: ApiVersion = 'v1' + constructor(config: AxiosRequestConfig, apiVersion?: ApiVersion) { + if (!config.baseURL || !config.headers?.Authorization) { + throw new Error('Please pass in baseUrl and Authroization header.') + } + if (config.baseURL.endsWith('/')) { + throw new Error('baseURL should not end with /') + } + this.axios = axios.create(config) + if (apiVersion) { + this.apiVersion = apiVersion + } + } + + public agentPaths = { + base: `/${this.apiVersion}/agents`, + withId: (id: string) => `/${this.apiVersion}/agents/${id}` + } + + public getSessionPaths = (agentId: string) => ({ + base: `/${this.apiVersion}/agents/${agentId}/sessions`, + withId: (id: string) => `/${this.apiVersion}/agents/${agentId}/sessions/${id}` + }) + + public getSessionMessagesPaths = (agentId: string, sessionId: string) => ({ + base: `/${this.apiVersion}/agents/${agentId}/sessions/${sessionId}/messages`, + withId: (id: number) => `/${this.apiVersion}/agents/${agentId}/sessions/${sessionId}/messages/${id}` + }) + + public getModelsPath = (props?: ApiModelsFilter) => { + const base = `/${this.apiVersion}/models` + if (!props) return base + if (objectKeys(props).length > 0) { + const params = objectEntries(props) + .map(([key, value]) => `${key}=${value}`) + .join('&') + return `${base}?${params}` + } else { + return base + } + } + + public async listAgents(): Promise { + const url = this.agentPaths.base + try { + const response = await this.axios.get(url) + const result = ListAgentsResponseSchema.safeParse(response.data) + if (!result.success) { + throw new Error('Not a valid Agents array.') + } + return result.data + } catch (error) { + throw processError(error, 'Failed to list agents.') + } + } + + public async createAgent(form: AddAgentForm): Promise { + const url = this.agentPaths.base + try { + const payload = form satisfies CreateAgentRequest + const response = await this.axios.post(url, payload) + const data = CreateAgentResponseSchema.parse(response.data) + return data + } catch (error) { + throw processError(error, 'Failed to create agent.') + } + } + + public async getAgent(id: string): Promise { + const url = this.agentPaths.withId(id) + try { + const response = await this.axios.get(url) + const data = GetAgentResponseSchema.parse(response.data) + if (data.id !== id) { + throw new Error('Agent ID mismatch in response') + } + return data + } catch (error) { + throw processError(error, 'Failed to get agent.') + } + } + + public async deleteAgent(id: string): Promise { + const url = this.agentPaths.withId(id) + try { + await this.axios.delete(url) + } catch (error) { + throw processError(error, 'Failed to delete agent.') + } + } + + public async updateAgent(form: UpdateAgentForm): Promise { + const url = this.agentPaths.withId(form.id) + try { + const payload = form satisfies UpdateAgentRequest + const response = await this.axios.patch(url, payload) + const data = UpdateAgentResponseSchema.parse(response.data) + if (data.id !== form.id) { + throw new Error('Agent ID mismatch in response') + } + return data + } catch (error) { + throw processError(error, 'Failed to updateAgent.') + } + } + + public async listSessions(agentId: string): Promise { + const url = this.getSessionPaths(agentId).base + try { + const response = await this.axios.get(url) + const result = ListAgentSessionsResponseSchema.safeParse(response.data) + if (!result.success) { + throw new Error('Not a valid Sessions array.') + } + return result.data + } catch (error) { + throw processError(error, 'Failed to list sessions.') + } + } + + public async createSession(agentId: string, session: CreateSessionForm): Promise { + const url = this.getSessionPaths(agentId).base + try { + const payload = session satisfies CreateSessionRequest + const response = await this.axios.post(url, payload) + const data = GetAgentSessionResponseSchema.parse(response.data) + return data + } catch (error) { + throw processError(error, 'Failed to add session.') + } + } + + public async getSession(agentId: string, sessionId: string): Promise { + const url = this.getSessionPaths(agentId).withId(sessionId) + try { + const response = await this.axios.get(url) + // const data = GetAgentSessionResponseSchema.parse(response.data) + // TODO: enable validation + const data = response.data + if (sessionId !== data.id) { + throw new Error('Session ID mismatch in response') + } + return data + } catch (error) { + throw processError(error, 'Failed to get session.') + } + } + + public async deleteSession(agentId: string, sessionId: string): Promise { + const url = this.getSessionPaths(agentId).withId(sessionId) + try { + await this.axios.delete(url) + } catch (error) { + throw processError(error, 'Failed to delete session.') + } + } + + public async deleteSessionMessage(agentId: string, sessionId: string, messageId: number): Promise { + const url = this.getSessionMessagesPaths(agentId, sessionId).withId(messageId) + try { + await this.axios.delete(url) + } catch (error) { + throw processError(error, 'Failed to delete session message.') + } + } + + public async updateSession(agentId: string, session: UpdateSessionForm): Promise { + const url = this.getSessionPaths(agentId).withId(session.id) + try { + const payload = session satisfies UpdateSessionRequest + const response = await this.axios.patch(url, payload) + const data = GetAgentSessionResponseSchema.parse(response.data) + if (session.id !== data.id) { + throw new Error('Session ID mismatch in response') + } + return data + } catch (error) { + throw processError(error, 'Failed to update session.') + } + } + + public async getModels(props?: ApiModelsFilter): Promise { + const url = this.getModelsPath(props) + try { + const response = await this.axios.get(url) + const data = ApiModelsResponseSchema.parse(response.data) + return data + } catch (error) { + throw processError(error, 'Failed to get models.') + } + } +} diff --git a/src/renderer/src/assets/styles/index.css b/src/renderer/src/assets/styles/index.css index 828342a06a..1f03a6671d 100644 --- a/src/renderer/src/assets/styles/index.css +++ b/src/renderer/src/assets/styles/index.css @@ -12,17 +12,23 @@ @import '../fonts/country-flag-fonts/flag.css'; @layer base { - *, - *::before, - *::after { - box-sizing: border-box; - /* margin: 0; */ - font-weight: normal; + @layer base { + *, + *::before, + *::after { + box-sizing: border-box; + /* margin: 0; */ + font-weight: normal; + } + } + + .lucide:not(.lucide-custom) { + color: var(--color-icon); } } *:focus { - outline: none; + outline-style: none; } * { -webkit-tap-highlight-color: transparent; diff --git a/src/renderer/src/components/ApiModelLabel.tsx b/src/renderer/src/components/ApiModelLabel.tsx new file mode 100644 index 0000000000..68cc3dbbaf --- /dev/null +++ b/src/renderer/src/components/ApiModelLabel.tsx @@ -0,0 +1,28 @@ +import { Avatar, cn } from '@heroui/react' +import { getModelLogo } from '@renderer/config/models' +import { ApiModel } from '@renderer/types' +import React from 'react' + +import Ellipsis from './Ellipsis' + +export interface ModelLabelProps extends Omit, 'children'> { + model?: ApiModel + classNames?: { + container?: string + avatar?: string + modelName?: string + divider?: string + providerName?: string + } +} + +export const ApiModelLabel: React.FC = ({ model, className, classNames, ...props }) => { + return ( +
+ + {model?.name} + | + {model?.provider_name} +
+ ) +} diff --git a/src/renderer/src/components/Avatar/EmojiAvatarWithPicker.tsx b/src/renderer/src/components/Avatar/EmojiAvatarWithPicker.tsx new file mode 100644 index 0000000000..6735d86a4e --- /dev/null +++ b/src/renderer/src/components/Avatar/EmojiAvatarWithPicker.tsx @@ -0,0 +1,22 @@ +import { Button, Popover, PopoverContent, PopoverTrigger } from '@heroui/react' +import React from 'react' + +import EmojiPicker from '../EmojiPicker' + +type Props = { + emoji: string + onPick: (emoji: string) => void +} + +export const EmojiAvatarWithPicker: React.FC = ({ emoji, onPick }) => { + return ( + + + + + {form.accessible_paths.length > 0 ? ( +
+ {form.accessible_paths.map((path) => ( +
+ + {path} + + +
+ ))} +
+ ) : ( +

{t('agent.session.accessible_paths.empty')}

+ )} + +