mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-19 14:41:24 +08:00
feat: add middleware support for provider (#6176)
* feat: add middleware support for OpenAIProvider with logging capabilities - Introduced middleware functionality in OpenAIProvider to enhance completions processing. - Created AiProviderMiddlewareTypes for defining middleware interfaces and contexts. - Implemented sampleLoggingMiddleware for logging message content and processing times. - Updated OpenAIProvider constructor to accept middleware as an optional parameter. - Refactored completions method to utilize middleware for improved extensibility and logging. * refactor: streamline OpenAIProvider initialization and middleware application - Removed optional middleware parameter from OpenAIProvider constructor for simplicity. - Refactored ProviderFactory to create instances of providers and apply logging middleware consistently. - Enhanced completions method visibility by changing it from private to public. - Cleaned up unused code related to middleware handling in OpenAIProvider. * feat: enhance AiProvider with new middleware capabilities and completion context - Added public getter for provider info in BaseProvider. - Introduced finalizeSdkRequestParams hook for middleware to modify SDK-specific request parameters. - Refactored completions method in OpenAIProvider to accept a context object, improving middleware integration. - Updated middleware types to include new context structure and callback functions for better extensibility. - Enhanced logging middleware to utilize new context structure for improved logging capabilities. * refactor: enhance middleware structure and context handling in AiProvider - Updated BaseProvider and AiProvider to utilize AiProviderMiddlewareCompletionsContext for completions method. - Introduced new utility functions for middleware context creation and execution. - Refactored middleware application logic to improve extensibility and maintainability. - Replaced sampleLoggingMiddleware with a more robust LoggingMiddleware implementation. - Added new context management features for better middleware integration. * refactor: update AiProvider and middleware structure for improved completions handling - Refactored BaseProvider and AiProvider to change completions method signature from context to params. - Removed unused AiProviderMiddlewareCompletionsContext and related code for cleaner implementation. - Enhanced middleware configuration by introducing a dedicated middleware registration file. - Implemented logging middleware for completions to improve observability during processing. - Streamlined middleware application logic in ProviderFactory for better maintainability. * docs: 添加中间件编写指南文档 - 新增《如何为 AI Provider 编写中间件》文档,详细介绍中间件架构、类型及编写示例。 - 说明了中间件的执行顺序、注册方法及最佳实践,旨在帮助开发者有效创建和维护中间件。 * refactor: update completions method signatures and introduce CompletionsResult type - Changed the completions method signature in BaseProvider and AiProvider to return CompletionsResult instead of void. - Added CompletionsResult type definition to encapsulate streaming and usage metrics. - Updated middleware and related components to handle the new CompletionsResult structure, ensuring compatibility with existing functionality. - Introduced new middleware for stream adaptation to enhance chunk processing during completions. * refactor: enhance AiProvider middleware and streaming handling - Updated CompletionsResult type to support both OpenAI SDK stream and ReadableStream. - Modified CompletionsMiddleware to return CompletionsResult, improving type safety. - Introduced StreamAdapterMiddleware to adapt OpenAI SDK streams to application-specific chunk streams. - Enhanced logging in CompletionsLoggingMiddleware to capture and return results from next middleware calls. * refactor: update AiProvider and middleware for OpenAI completions handling - Renamed CompletionsResult to CompletionsOpenAIResult for clarity and updated its structure to support both OpenAI SDK and application-specific streams. - Modified completions method signatures in AiProvider and OpenAIProvider to return CompletionsOpenAIResult. - Enhanced middleware to process and adapt OpenAI SDK streams into standard chunk formats, improving overall streaming handling. - Introduced new middleware components: FinalChunkConsumerAndNotifierMiddleware and OpenAISDKChunkToStandardChunkMiddleware for better chunk processing and logging. * 删除 ExtractReasoningCompletionsMiddleware.ts 文件,清理未使用的中间件代码以提高代码整洁性和可维护性。 * refactor: consolidate middleware types and improve imports - Replaced references to AiProviderMiddlewareTypes with the new middlewareTypes file across various middleware components for better organization. - Introduced TextChunkMiddleware to enhance chunk processing from OpenAI SDK streams. - Cleaned up imports in multiple files to reflect the new structure, improving code clarity and maintainability. * feat: enhance abort handling with AbortController in middleware chain - Update CompletionsOpenAIResult interface to use AbortController instead of AbortSignal - Modify OpenAIProvider to pass abortController in completions method return - Update AbortHandlerMiddleware to use controller from upstream result - Improve abort handling flexibility by exposing full controller capabilities - Enable middleware to actively control abort operations beyond passive monitoring This change provides better control over request cancellation and enables more sophisticated abort handling patterns in the middleware pipeline. * refactor: enhance AiProvider and middleware for improved completions handling - Updated BaseProvider to expose additional methods and properties, including getMessageParam and createAbortController. - Modified OpenAIProvider to streamline completions processing and integrate new middleware for tool handling. - Introduced TransformParamsBeforeCompletions middleware to standardize parameter transformation before completions. - Added McpToolChunkMiddleware for managing tool calls within the completions stream. - Enhanced middleware types to support new functionalities and improve overall structure. These changes improve the flexibility and maintainability of the AiProvider and its middleware, facilitating better handling of OpenAI completions and tool interactions. * refactor: enhance middleware for recursive handling and internal state management - Introduced internal state management in middleware to support recursive calls, including enhanced dispatch functionality. - Updated middleware types to include new internal fields for managing recursion depth and call status. - Improved logging for better traceability of recursive calls and state transitions. - Adjusted various middleware components to utilize the new internal state, ensuring consistent behavior during recursive processing. These changes enhance the middleware's ability to handle complex scenarios involving recursive calls, improving overall robustness and maintainability. * fix(OpenAIProvider): return empty object for missing sdkParams in completions handling - Updated OpenAIProvider to return an empty object instead of undefined when sdkParams are not found, ensuring consistent return types. - Enhanced TransformParamsBeforeCompletions middleware to include a flag for built-in web search functionality based on assistant settings. * refactor(OpenAIProvider): enhance completions handling and middleware integration - Updated the completions method in OpenAIProvider to include an onChunk callback for improved streaming support. - Enabled the ThinkChunkMiddleware in the middleware registration for better handling of reasoning content. - Increased the maximum recursion depth in McpToolChunkMiddleware to prevent infinite loops. - Refined TextChunkMiddleware to directly enqueue chunks without unnecessary type checks. - Improved the ThinkChunkMiddleware to better manage reasoning tags and streamline chunk processing. These changes enhance the overall functionality and robustness of the AI provider and middleware components. * feat(WebSearchMiddleware): add web search handling and integration - Introduced WebSearchMiddleware to process various web search results, including annotations and citations, and generate LLM_WEB_SEARCH_COMPLETE chunks. - Enhanced TextChunkMiddleware to support link conversion based on the model and assistant settings, improving the handling of TEXT_DELTA chunks. - Updated middleware registration to include WebSearchMiddleware for comprehensive search result processing. These changes enhance the AI provider's capabilities in handling web search functionalities and improve the overall middleware architecture. * fix(middleware): improve optional chaining for chunk processing - Updated McpToolChunkMiddleware and ThinkChunkMiddleware to use optional chaining for accessing choices, enhancing robustness against undefined values. - Removed commented-out code in ThinkChunkMiddleware to streamline the chunk handling process. These changes improve the reliability of middleware when processing OpenAI API responses. * feat(middleware): enhance AbortHandlerMiddleware with recursion handling - Added logic to detect and handle recursive calls, preventing unnecessary creation of AbortControllers. - Improved logging for better visibility into middleware operations, including recursion depth and cleanup processes. - Streamlined cleanup process for non-stream responses to ensure resources are released promptly. These changes enhance the robustness and efficiency of the AbortHandlerMiddleware in managing API requests. * docs(middleware): 迁移步骤 * feat(middleware): implement FinalChunkConsumerMiddleware for usage and metrics accumulation - Introduced FinalChunkConsumerMiddleware to replace the deprecated FinalChunkConsumerAndNotifierMiddleware. - This new middleware accumulates usage and metrics data from OpenAI API responses, enhancing tracking capabilities. - Updated middleware registration to utilize the new FinalChunkConsumerMiddleware, ensuring proper integration. - Added support for handling recursive calls and improved logging for better debugging and monitoring. These changes enhance the middleware's ability to manage and report usage metrics effectively during API interactions. * refactor(migrate): update API request and response structures to TypeScript types - Changed the definitions of `CoreCompletionsRequest` and `Chunk` to use TypeScript types instead of Zod Schemas for better type safety and clarity. - Updated middleware and service classes to handle the new `Chunk` type, ensuring compatibility with the revised API client structure. - Enhanced the response processing logic to standardize the handling of raw SDK chunks into application-level `Chunk` objects. - Adjusted middleware to consume the new `Chunk` type, streamlining the overall architecture and improving maintainability. These changes facilitate a more robust and type-safe integration with AI provider APIs. * feat(AiProvider): implement API client architecture - Introduced ApiClientFactory for creating instances of API clients based on provider configuration. - Added BaseApiClient as an abstract class to provide common functionality for specific client implementations. - Implemented OpenAIApiClient for OpenAI and Azure OpenAI, including request and response handling. - Defined types and interfaces for API client operations, enhancing type safety and clarity. - Established middleware schemas for standardized request processing across AI providers. These changes lay the groundwork for a modular and extensible API client architecture, improving the integration of various AI providers. * refactor(StreamAdapterMiddleware): simplify stream adaptation logic - Updated StreamAdapterMiddleware to directly use AsyncIterable instead of wrapping it with rawSdkChunkAdapter, streamlining the adaptation process. - Modified asyncGeneratorToReadableStream to accept AsyncIterable, enhancing its flexibility and usability. These changes improve the efficiency of stream handling in the middleware. * refactor(AiProvider): simplify ResponseChunkTransformer interface and streamline OpenAIApiClient response handling - Changed ResponseChunkTransformer from an interface to a type for improved clarity and simplicity. - Refactored OpenAIApiClient to streamline the response transformation logic, reducing unnecessary complexity in handling tool calls and reasoning content. - Enhanced type safety by ensuring consistent handling of optional properties in response processing. These changes improve the maintainability and readability of the codebase while ensuring robust response handling in the API client. * doc(technicalArchitecture): add comprehensive documentation for AI Provider architecture * feat(architecture): introduce AI Core Design documentation and middleware specification - Added a comprehensive technical architecture document for the new AI Provider (`aiCore`), outlining core design principles, component details, and execution flow. - Established a middleware specification document to define the design, implementation, and usage of middleware within the `aiCore` module, promoting a flexible and maintainable system. - These additions provide clarity and guidance for future development and integration of AI functionalities within Cherry Studio. * refactor(middleware): consolidate and enhance middleware architecture - Removed deprecated extractReasoningMiddleware and integrated its functionality into existing middleware. - Streamlined middleware registration and improved type definitions for better clarity and maintainability. - Introduced new middleware components for handling chunk processing, web search, and reasoning tags, enhancing overall functionality. - Updated various middleware to utilize the new structures and improve logging for better debugging. These changes enhance the middleware's efficiency and maintainability, providing a more robust framework for API interactions. * refactor(AiProvider): enhance API client and middleware integration - Updated ApiClientFactory to include new SDK types for improved type safety and clarity. - Refactored BaseApiClient to support additional parameters in the completions method, enhancing flexibility for processing states. - Streamlined OpenAIApiClient to better handle tool calls and responses, including the introduction of new chunk types for tool management. - Improved middleware architecture by integrating processing states and refining message handling, ensuring a more robust interaction with the API. These changes enhance the overall maintainability and functionality of the API client and middleware, providing a more efficient framework for AI interactions. * fix(McpToolChunkMiddleware): remove redundant logging in recursion state update * refactor(McpToolChunkMiddleware): update tool call handling and type definitions - Replaced ChatCompletionMessageToolCall with SdkToolCall for improved type consistency. - Updated return types of executeToolCalls and executeToolUses functions to SdkMessage[], enhancing clarity in message handling. - Removed unused import to streamline the code. These changes enhance the maintainability and type safety of the middleware, ensuring better integration with the SDK. * refactor(middleware): enhance middleware structure and type handling - Updated middleware components to utilize new SDK types, improving type safety and clarity across the board. - Refactored various middleware to streamline processing logic, including enhanced handling of SDK messages and tool calls. - Improved logging and error handling for better debugging and maintainability. - Consolidated middleware functions to reduce redundancy and improve overall architecture. These changes enhance the robustness and maintainability of the middleware framework, ensuring a more efficient interaction with the API. * refactor(middleware): unify type imports and enhance middleware structure - Updated middleware components to import types from a unified 'types' file, improving consistency and clarity across the codebase. - Removed the deprecated 'type.ts' file to streamline the middleware structure. - Enhanced middleware registration and export mechanisms for better accessibility and maintainability. These changes contribute to a more organized and efficient middleware framework, facilitating easier future development and integration. * refactor(AiProvider): enhance API client and middleware integration - Updated AiProvider components to support new SDK types, improving type safety and clarity. - Refactored middleware to streamline processing logic, including enhanced handling of tool calls and responses. - Introduced new middleware for tool use extraction and raw stream listening, improving overall functionality. - Improved logging and error handling for better debugging and maintainability. These changes enhance the robustness and maintainability of the API client and middleware, ensuring a more efficient interaction with the API. * feat(middleware): add new middleware components for raw stream listening and tool use extraction - Introduced RawStreamListenerMiddleware and ToolUseExtractionMiddleware to enhance middleware capabilities. - Updated MiddlewareRegistry to include new middleware entries, improving overall functionality and extensibility. These changes expand the middleware framework, facilitating better handling of streaming and tool usage scenarios. * refactor(AiProvider): integrate new API client and middleware architecture - Replaced BaseProvider with ApiClientFactory to enhance API client instantiation. - Updated completions method to utilize new middleware architecture for improved processing. - Added TODOs for refactoring remaining methods to align with the new API client structure. - Removed deprecated middleware wrapping logic from ApiClientFactory for cleaner implementation. These changes improve the overall structure and maintainability of the AiProvider, facilitating better integration with the new middleware system. * refactor(middleware): update middleware architecture and documentation - Revised middleware naming conventions and introduced a centralized MiddlewareRegistry for better management and accessibility. - Enhanced MiddlewareBuilder to support named middleware and streamline the construction of middleware chains. - Updated documentation to reflect changes in middleware usage and structure, improving clarity for future development. These changes improve the organization and usability of the middleware framework, facilitating easier integration and maintenance. * refactor(AiProvider): enhance completions middleware logic and API client handling - Updated the completions method to conditionally remove middleware based on parameters, improving flexibility in processing. - Refactored the response chunk transformer in OpenAIApiClient and AnthropicAPIClient to utilize a more streamlined approach with TransformStream. - Simplified middleware context handling by removing unnecessary custom state management. - Improved logging and error handling across middleware components for better debugging and maintainability. These changes enhance the efficiency and clarity of the AiProvider's middleware integration, ensuring a more adaptable and robust processing framework. * refactor(AiProvider, middleware): clean up logging and improve method naming - Removed unnecessary logging of parameters in AiProvider to streamline the code. - Updated method name assignment in middleware to enhance clarity and consistency. These changes contribute to a cleaner codebase and improve the readability of the middleware and provider components. * feat(middleware): enhance middleware types and add RawStreamListenerMiddleware - Introduced RawStreamListenerMiddleware to the MiddlewareName enum for improved middleware capabilities. - Updated type definitions across middleware components to enhance type safety and clarity, including the addition of new SDK types. - Refactored context and middleware API interfaces to support more specific type parameters, improving overall maintainability. These changes expand the middleware framework, facilitating better handling of streaming scenarios and enhancing type safety across the codebase. * refactor(messageThunk): convert callback functions to async and handle errors during database updates This commit updates several callback functions in the messageThunk to be asynchronous, ensuring that block transitions are awaited properly. Additionally, error handling is added for the database update function to log any failures when saving blocks. This improves the reliability and responsiveness of the message processing flow. * refactor: enhance message block handling in messageThunk This commit refactors the message processing logic in messageThunk to improve the management of message blocks. Key changes include the introduction of dedicated IDs for different block types (main text, thinking, tool, and image) to streamline updates and transitions. The handling of placeholder blocks has been improved, ensuring that they are correctly converted to their respective types during processing. Additionally, error handling has been enhanced for better reliability in database updates. * feat(AiProvider): add default timeout configuration and enhance API client aborthandler - Introduced a default timeout constant to the configuration for improved API client timeout management. - Updated BaseApiClient and its derived classes to utilize the new timeout setting, ensuring consistent timeout behavior across different API clients. - Enhanced middleware to pass the timeout value during API calls, improving error handling and responsiveness. These changes improve the overall robustness and configurability of the API client interactions, facilitating better control over request timeouts. * feat(GeminiProvider): implement Gemini API client and enhance file handling - Introduced GeminiAPIClient to facilitate interactions with the Gemini API, replacing the previous GoogleGenAI integration. - Refactored GeminiProvider to utilize the new API client, improving code organization and maintainability. - Enhanced file handling capabilities, including support for PDF uploads and retrieval of file metadata. - Updated message processing to accommodate new SDK types and improve content generation logic. These changes significantly enhance the functionality and robustness of the GeminiProvider, enabling better integration with the Gemini API and improving overall user experience. * refactor(AiProvider, middleware): streamline API client and middleware integration - Removed deprecated methods and types from various API clients, enhancing code clarity and maintainability. - Updated the CompletionsParams interface to support messages as a string or array, improving flexibility in message handling. - Refactored middleware components to eliminate unnecessary state management and improve type safety. - Enhanced the handling of streaming responses and added utility functions for better stream management. These changes contribute to a more robust and efficient architecture for the AiProvider and its associated middleware, facilitating improved API interactions and user experience. * refactor(middleware): translation 适配 - Deleted SdkCallMiddleware to streamline middleware architecture and improve maintainability. - Commented out references to SdkCallModule in examples and registration files to prevent usage. - Enhanced logging in AbortHandlerMiddleware for better debugging and tracking of middleware execution. - Updated parameters in ResponseTransformMiddleware to improve flexibility in handling response settings. These changes contribute to a cleaner and more efficient middleware framework, facilitating better integration and performance. * refactor(ApiCheck): streamline API validation and error handling - Updated the API check logic to simplify validation processes and improve error handling across various components. - Refactored the `checkApi` function to throw errors directly instead of returning validation objects, enhancing clarity in error management. - Improved the handling of API key checks in `checkModelWithMultipleKeys` to provide more informative error messages. - Added a new method `getEmbeddingDimensions` in the `AiProvider` class to facilitate embedding dimension retrieval, enhancing model compatibility checks. These changes contribute to a more robust and maintainable API validation framework, improving overall user experience and error reporting. * refactor(HealthCheckService, ModelService): improve error handling and performance metrics - Updated error handling in `checkModelWithMultipleKeys` to truncate error messages for better readability. - Refactored `performModelCheck` to remove unnecessary error handling, focusing on performance metrics by returning only latency. - Enhanced the `checkModel` function to ensure consistent return types, improving clarity in API interactions. These changes contribute to a more efficient and user-friendly error reporting and performance tracking system. * refactor(AiProvider, models): enhance model handling and API client integration - Updated the `listModels` method in various API clients to improve model retrieval and ensure consistent return types. - Refactored the `EditModelsPopup` component to handle model properties more robustly, including fallback options for `id`, `name`, and other attributes. - Enhanced type definitions for models in the SDK to support new integrations and improve type safety. These changes contribute to a more reliable and maintainable model management system within the AiProvider, enhancing overall user experience and API interactions. * refactor(AiProvider, clients): implement image generation functionality - Refactored the `generateImage` method in the `AiProvider` class to utilize the `apiClient` for image generation, replacing the previous placeholder implementation. - Updated the `BaseApiClient` to include an abstract `generateImage` method, ensuring all derived clients implement this functionality. - Implemented the `generateImage` method in `GeminiAPIClient` and `OpenAIAPIClient`, providing specific logic for image generation based on the respective SDKs. - Added type definitions for `GenerateImageParams` across relevant files to enhance type safety and clarity in image generation parameters. These changes enhance the image generation capabilities of the AiProvider, improving integration with various API clients and overall user experience. * refactor(AiProvider, clients): restructure API client architecture and remove deprecated components - Refactored the `ProviderFactory` and removed the `AihubmixProvider` to streamline the API client architecture. - Updated the import paths for `isOpenAIProvider` to reflect the new structure. - Introduced `AihubmixAPIClient` and `OpenAIResponseAPIClient` to enhance client handling based on model types. - Improved the `AiProvider` class to utilize the new clients for better model-specific API interactions. - Enhanced type definitions and error handling across various components to improve maintainability and clarity. These changes contribute to a more efficient and organized API client structure, enhancing overall integration and user experience. * fix: update system prompt handling in API clients to use await for asynchronous operations - Modified the `AnthropicAPIClient`, `GeminiAPIClient`, `OpenAIAPIClient`, and `OpenAIResponseAPIClient` to ensure `buildSystemPrompt` is awaited, improving the handling of system prompts. - Adjusted the `fetchMessagesSummary` function to utilize the last five user messages for better context in API calls and added a utility function to clean up topic names. These changes enhance the reliability of prompt generation and improve the overall API interaction experience. * refactor(middleware): remove examples.ts to streamline middleware documentation - Deleted the `examples.ts` file containing various middleware usage examples to simplify the middleware structure and documentation. - This change contributes to a cleaner codebase and focuses on essential middleware components, enhancing maintainability. * refactor(AiProvider, middleware): enhance middleware handling and error management - Updated the `CompletionsParams` interface to include a new `callType` property for better middleware decision-making based on the context of the API call. - Introduced `ErrorHandlerMiddleware` to standardize error handling across middleware, allowing errors to be captured and processed as `ErrorChunk` objects. - Modified the `AbortHandlerMiddleware` to conditionally remove itself based on the `callType`, improving middleware efficiency. - Cleaned up logging in `AbortHandlerMiddleware` to reduce console output and enhance performance. - Updated middleware registration to include the new `ErrorHandlerMiddleware`, ensuring comprehensive error management in the middleware pipeline. These changes contribute to a more robust and maintainable middleware architecture, improving error handling and overall API interaction efficiency. * feat: implement token estimation for message handling - Added an abstract method `estimateMessageTokens` to the `BaseApiClient` class for estimating token usage based on message content. - Implemented the `estimateMessageTokens` method in `AnthropicAPIClient`, `GeminiAPIClient`, `OpenAIAPIClient`, and `OpenAIResponseAPIClient` to calculate token consumption for various message types. - Enhanced middleware to accumulate token usage for new messages, improving tracking of API call costs. These changes improve the efficiency of message processing and provide better insights into token usage across different API clients. * feat: add support for image generation and model handling - Introduced `SUPPORTED_DISABLE_GENERATION_MODELS` to manage models that disable image generation. - Updated `isSupportedDisableGenerationModel` function to check model compatibility. - Enhanced `Inputbar` logic to conditionally enable image generation based on model support. - Modified API clients to handle image generation calls and responses, including new chunk types for image data. - Updated middleware and service layers to incorporate image generation parameters and improve overall processing. These changes enhance the application's capabilities for image generation and improve the handling of various model types. * feat: enhance GeminiAPIClient for image generation support - Added `getGenerateImageParameter` method to configure image generation parameters. - Updated request handling in `GeminiAPIClient` to include image generation options. - Enhanced response processing to handle image data and enqueue it correctly. These changes improve the GeminiAPIClient's capabilities for generating and processing images, aligning with recent enhancements in image generation support. * feat: enhance image generation handling in OpenAIResponseAPIClient and middleware - Updated OpenAIResponseAPIClient to improve user message processing for image generation. - Added handling for image creation events in TransformCoreToSdkParamsMiddleware. - Adjusted ApiService to streamline image generation event handling. - Modified messageThunk to reflect changes in image block status during processing. These enhancements improve the integration and responsiveness of image generation features across the application. * refactor: remove unused AI provider classes - Deleted `AihubmixProvider`, `AnthropicProvider`, `BaseProvider`, `GeminiProvider`, and `OpenAIProvider` as they are no longer utilized in the codebase. - This cleanup reduces code complexity and improves maintainability by removing obsolete components related to AI provider functionality. * chore: remove obsolete test files for middleware - Deleted test files for `AbortHandlerMiddleware`, `LoggingMiddleware`, `TextChunkMiddleware`, `ThinkChunkMiddleware`, and `WebSearchMiddleware` as they are no longer needed. - This cleanup helps streamline the codebase and reduces maintenance overhead by removing outdated tests. * chore: remove Suggestions component and related functionality - Deleted the `Suggestions` component from the home page as it is no longer needed. - Removed associated imports and functions related to suggestion fetching, streamlining the codebase. - This cleanup helps improve maintainability by eliminating unused components. * feat: enhance OpenAIAPIClient and StreamProcessingService for tool call handling - Updated OpenAIAPIClient to conditionally include tool calls in the assistant message, improving message processing logic. - Enhanced tool call handling in the response transformer to correctly manage and enqueue tool call data. - Added a new callback for LLM response completion in StreamProcessingService, allowing better integration of response handling. These changes improve the functionality and responsiveness of the OpenAI API client and stream processing capabilities. * fix: copilot error * fix: improve chunk handling in TextChunkMiddleware and ThinkChunkMiddleware - Updated TextChunkMiddleware to enqueue LLM_RESPONSE_COMPLETE chunks based on accumulated text content. - Refactored ThinkChunkMiddleware to generate THINKING_COMPLETE chunks when receiving non-THINKING_DELTA chunks, ensuring proper handling of accumulated thinking content. - These changes enhance the middleware's responsiveness and accuracy in processing text and thinking chunks. * chore: update dependencies and improve styling - Updated `selection-hook` dependency to version 0.9.23 in `package.json` and `yarn.lock`. - Removed unused styles from `container.scss` and adjusted padding in `index.scss`. - Enhanced message rendering and layout in various components, including `Message`, `MessageHeader`, and `MessageMenubar`. - Added tooltip support for message divider settings in `SettingsTab`. - Improved handling of citation display in `CitationsList` and `CitationBlock`. These changes streamline the codebase and enhance the user interface for better usability. * feat: implement image generation middleware and enhance model handling - Added `ImageGenerationMiddleware` to handle dedicated image generation models, integrating image processing and OpenAI's image generation API. - Updated `AiProvider` to utilize the new middleware for dedicated image models, ensuring proper middleware chaining. - Introduced constants for dedicated image models in `models.ts` to streamline model identification. - Refactored error handling in `ErrorHandlerMiddleware` to use a utility function for better error management. - Cleaned up imports and removed unused code in various files for improved maintainability. * fix: update dedicated image models identification logic - Modified the `DEDICATED_IMAGE_MODELS` array to include 'grok-2-image' for improved model handling. - Enhanced the `isDedicatedImageGenerationModel` function to use a more robust check for model identification, ensuring better accuracy in middleware processing. * refactor: remove OpenAIResponseProvider class - Deleted the `OpenAIResponseProvider` class from the `AiProvider` module, streamlining the codebase by eliminating unused code. - This change enhances maintainability and reduces complexity in the provider architecture. * fix: usermessage * refactor: simplify AbortHandlerMiddleware for improved abort handling - Removed direct dependency on ApiClient for creating AbortController, enhancing modularity. - Introduced utility functions to manage abort controllers, streamlining the middleware's responsibilities. - Delegated abort signal handling to downstream middlewares, allowing for cleaner separation of concerns. * refactor(aiCore): Consolidate AI provider and middleware architecture This commit refactors the AI-related modules by unifying the `clients` and `middleware` directories under a single `aiCore` directory. This change simplifies the project structure, improves modularity, and makes the architecture more cohesive. Key changes: - Relocated provider-specific clients and middleware into the `aiCore` directory, removing the previous `providers/AiProvider` structure. - Updated the architectural documentation (`AI_CORE_DESIGN.md`) to accurately reflect the new, streamlined directory layout and execution flow. - The main `AiProvider` class is now the primary export of `aiCore/index.ts`, serving as the central access point for AI functionalities. * refactor: update imports and enhance middleware functionality - Adjusted import statements in `AnthropicAPIClient` and `GeminiAPIClient` for better organization. - Improved `AbortHandlerMiddleware` to handle abort signals more effectively, including the conversion of streams to handle abort scenarios. - Enhanced `ErrorHandlerMiddleware` to differentiate between abort errors and other types, ensuring proper error handling. - Cleaned up commented-out code in `FinalChunkConsumerMiddleware` for better readability and maintainability. * refactor: streamline middleware logging and improve error handling - Removed excessive debug logging from various middleware components, including `AbortHandlerMiddleware`, `FinalChunkConsumerMiddleware`, and `McpToolChunkMiddleware`, to enhance readability and performance. - Updated logging levels to use warnings for potential issues in `ResponseTransformMiddleware`, `TextChunkMiddleware`, and `ThinkChunkMiddleware`, ensuring better visibility of important messages. - Cleaned up commented-out code and unnecessary debug statements across multiple middleware files for improved maintainability. --------- Co-authored-by: suyao <sy20010504@gmail.com> Co-authored-by: eeee0717 <chentao020717Work@outlook.com> Co-authored-by: lizhixuan <zhixuan.li@banosuperapp.com>
This commit is contained in:
parent
6ad9044cd1
commit
5f4d73b00d
1
.vscode/launch.json
vendored
1
.vscode/launch.json
vendored
@ -7,7 +7,6 @@
|
||||
"request": "launch",
|
||||
"cwd": "${workspaceRoot}",
|
||||
"runtimeExecutable": "${workspaceRoot}/node_modules/.bin/electron-vite",
|
||||
"runtimeVersion": "20",
|
||||
"windows": {
|
||||
"runtimeExecutable": "${workspaceRoot}/node_modules/.bin/electron-vite.cmd"
|
||||
},
|
||||
|
||||
214
docs/technical/how-to-write-middlewares.md
Normal file
214
docs/technical/how-to-write-middlewares.md
Normal file
@ -0,0 +1,214 @@
|
||||
# 如何为 AI Provider 编写中间件
|
||||
|
||||
本文档旨在指导开发者如何为我们的 AI Provider 框架创建和集成自定义中间件。中间件提供了一种强大而灵活的方式来增强、修改或观察 Provider 方法的调用过程,例如日志记录、缓存、请求/响应转换、错误处理等。
|
||||
|
||||
## 架构概览
|
||||
|
||||
我们的中间件架构借鉴了 Redux 的三段式设计,并结合了 JavaScript Proxy 来动态地将中间件应用于 Provider 的方法。
|
||||
|
||||
- **Proxy**: 拦截对 Provider 方法的调用,并将调用引导至中间件链。
|
||||
- **中间件链**: 一系列按顺序执行的中间件函数。每个中间件都可以处理请求/响应,然后将控制权传递给链中的下一个中间件,或者在某些情况下提前终止链。
|
||||
- **上下文 (Context)**: 一个在中间件之间传递的对象,携带了关于当前调用的信息(如方法名、原始参数、Provider 实例、以及中间件自定义的数据)。
|
||||
|
||||
## 中间件的类型
|
||||
|
||||
目前主要支持两种类型的中间件,它们共享相似的结构但针对不同的场景:
|
||||
|
||||
1. **`CompletionsMiddleware`**: 专门为 `completions` 方法设计。这是最常用的中间件类型,因为它允许对 AI 模型的核心聊天/文本生成功能进行精细控制。
|
||||
2. **`ProviderMethodMiddleware`**: 通用中间件,可以应用于 Provider 上的任何其他方法(例如,`translate`, `summarize` 等,如果这些方法也通过中间件系统包装)。
|
||||
|
||||
## 编写一个 `CompletionsMiddleware`
|
||||
|
||||
`CompletionsMiddleware` 的基本签名(TypeScript 类型)如下:
|
||||
|
||||
```typescript
|
||||
import { AiProviderMiddlewareCompletionsContext, CompletionsParams, MiddlewareAPI } from './AiProviderMiddlewareTypes' // 假设类型定义文件路径
|
||||
|
||||
export type CompletionsMiddleware = (
|
||||
api: MiddlewareAPI<AiProviderMiddlewareCompletionsContext, [CompletionsParams]>
|
||||
) => (
|
||||
next: (context: AiProviderMiddlewareCompletionsContext, params: CompletionsParams) => Promise<any> // next 返回 Promise<any> 代表原始SDK响应或下游中间件的结果
|
||||
) => (context: AiProviderMiddlewareCompletionsContext, params: CompletionsParams) => Promise<void> // 最内层函数通常返回 Promise<void>,因为结果通过 onChunk 或 context 副作用传递
|
||||
```
|
||||
|
||||
让我们分解这个三段式结构:
|
||||
|
||||
1. **第一层函数 `(api) => { ... }`**:
|
||||
|
||||
- 接收一个 `api` 对象。
|
||||
- `api` 对象提供了以下方法:
|
||||
- `api.getContext()`: 获取当前调用的上下文对象 (`AiProviderMiddlewareCompletionsContext`)。
|
||||
- `api.getOriginalArgs()`: 获取传递给 `completions` 方法的原始参数数组 (即 `[CompletionsParams]`)。
|
||||
- `api.getProviderId()`: 获取当前 Provider 的 ID。
|
||||
- `api.getProviderInstance()`: 获取原始的 Provider 实例。
|
||||
- 此函数通常用于进行一次性的设置或获取所需的服务/配置。它返回第二层函数。
|
||||
|
||||
2. **第二层函数 `(next) => { ... }`**:
|
||||
|
||||
- 接收一个 `next` 函数。
|
||||
- `next` 函数代表了中间件链中的下一个环节。调用 `next(context, params)` 会将控制权传递给下一个中间件,或者如果当前中间件是链中的最后一个,则会调用核心的 Provider 方法逻辑 (例如,实际的 SDK 调用)。
|
||||
- `next` 函数接收当前的 `context` 和 `params` (这些可能已被上游中间件修改)。
|
||||
- **重要的是**:`next` 的返回类型通常是 `Promise<any>`。对于 `completions` 方法,如果 `next` 调用了实际的 SDK,它将返回原始的 SDK 响应(例如,OpenAI 的流对象或 JSON 对象)。你需要处理这个响应。
|
||||
- 此函数返回第三层(也是最核心的)函数。
|
||||
|
||||
3. **第三层函数 `(context, params) => { ... }`**:
|
||||
- 这是执行中间件主要逻辑的地方。
|
||||
- 它接收当前的 `context` (`AiProviderMiddlewareCompletionsContext`) 和 `params` (`CompletionsParams`)。
|
||||
- 在此函数中,你可以:
|
||||
- **在调用 `next` 之前**:
|
||||
- 读取或修改 `params`。例如,添加默认参数、转换消息格式。
|
||||
- 读取或修改 `context`。例如,设置一个时间戳用于后续计算延迟。
|
||||
- 执行某些检查,如果不满足条件,可以不调用 `next` 而直接返回或抛出错误(例如,参数校验失败)。
|
||||
- **调用 `await next(context, params)`**:
|
||||
- 这是将控制权传递给下游的关键步骤。
|
||||
- `next` 的返回值是原始的 SDK 响应或下游中间件的结果,你需要根据情况处理它(例如,如果是流,则开始消费流)。
|
||||
- **在调用 `next` 之后**:
|
||||
- 处理 `next` 的返回结果。例如,如果 `next` 返回了一个流,你可以在这里开始迭代处理这个流,并通过 `context.onChunk` 发送数据块。
|
||||
- 基于 `context` 的变化或 `next` 的结果执行进一步操作。例如,计算总耗时、记录日志。
|
||||
- 修改最终结果(尽管对于 `completions`,结果通常通过 `onChunk` 副作用发出)。
|
||||
|
||||
### 示例:一个简单的日志中间件
|
||||
|
||||
```typescript
|
||||
import {
|
||||
AiProviderMiddlewareCompletionsContext,
|
||||
CompletionsParams,
|
||||
MiddlewareAPI,
|
||||
OnChunkFunction // 假设 OnChunkFunction 类型被导出
|
||||
} from './AiProviderMiddlewareTypes' // 调整路径
|
||||
import { ChunkType } from '@renderer/types' // 调整路径
|
||||
|
||||
export const createSimpleLoggingMiddleware = (): CompletionsMiddleware => {
|
||||
return (api: MiddlewareAPI<AiProviderMiddlewareCompletionsContext, [CompletionsParams]>) => {
|
||||
// console.log(`[LoggingMiddleware] Initialized for provider: ${api.getProviderId()}`);
|
||||
|
||||
return (next: (context: AiProviderMiddlewareCompletionsContext, params: CompletionsParams) => Promise<any>) => {
|
||||
return async (context: AiProviderMiddlewareCompletionsContext, params: CompletionsParams): Promise<void> => {
|
||||
const startTime = Date.now()
|
||||
// 从 context 中获取 onChunk (它最初来自 params.onChunk)
|
||||
const onChunk = context.onChunk
|
||||
|
||||
console.log(
|
||||
`[LoggingMiddleware] Request for ${context.methodName} with params:`,
|
||||
params.messages?.[params.messages.length - 1]?.content
|
||||
)
|
||||
|
||||
try {
|
||||
// 调用下一个中间件或核心逻辑
|
||||
// `rawSdkResponse` 是来自下游的原始响应 (例如 OpenAIStream 或 ChatCompletion 对象)
|
||||
const rawSdkResponse = await next(context, params)
|
||||
|
||||
// 此处简单示例不处理 rawSdkResponse,假设下游中间件 (如 StreamingResponseHandler)
|
||||
// 会处理它并通过 onChunk 发送数据。
|
||||
// 如果这个日志中间件在 StreamingResponseHandler 之后,那么流已经被处理。
|
||||
// 如果在之前,那么它需要自己处理 rawSdkResponse 或确保下游会处理。
|
||||
|
||||
const duration = Date.now() - startTime
|
||||
console.log(`[LoggingMiddleware] Request for ${context.methodName} completed in ${duration}ms.`)
|
||||
|
||||
// 假设下游已经通过 onChunk 发送了所有数据。
|
||||
// 如果这个中间件是链的末端,并且需要确保 BLOCK_COMPLETE 被发送,
|
||||
// 它可能需要更复杂的逻辑来跟踪何时所有数据都已发送。
|
||||
} catch (error) {
|
||||
const duration = Date.now() - startTime
|
||||
console.error(`[LoggingMiddleware] Request for ${context.methodName} failed after ${duration}ms:`, error)
|
||||
|
||||
// 如果 onChunk 可用,可以尝试发送一个错误块
|
||||
if (onChunk) {
|
||||
onChunk({
|
||||
type: ChunkType.ERROR,
|
||||
error: { message: (error as Error).message, name: (error as Error).name, stack: (error as Error).stack }
|
||||
})
|
||||
// 考虑是否还需要发送 BLOCK_COMPLETE 来结束流
|
||||
onChunk({ type: ChunkType.BLOCK_COMPLETE, response: {} })
|
||||
}
|
||||
throw error // 重新抛出错误,以便上层或全局错误处理器可以捕获
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### `AiProviderMiddlewareCompletionsContext` 的重要性
|
||||
|
||||
`AiProviderMiddlewareCompletionsContext` 是在中间件之间传递状态和数据的核心。它通常包含:
|
||||
|
||||
- `methodName`: 当前调用的方法名 (总是 `'completions'`)。
|
||||
- `originalArgs`: 传递给 `completions` 的原始参数数组。
|
||||
- `providerId`: Provider 的 ID。
|
||||
- `_providerInstance`: Provider 实例。
|
||||
- `onChunk`: 从原始 `CompletionsParams` 传入的回调函数,用于流式发送数据块。**所有中间件都应该通过 `context.onChunk` 来发送数据。**
|
||||
- `messages`, `model`, `assistant`, `mcpTools`: 从原始 `CompletionsParams` 中提取的常用字段,方便访问。
|
||||
- **自定义字段**: 中间件可以向上下文中添加自定义字段,以供后续中间件使用。例如,一个缓存中间件可能会添加 `context.cacheHit = true`。
|
||||
|
||||
**关键**: 当你在中间件中修改 `params` 或 `context` 时,这些修改会向下游中间件传播(如果它们在 `next` 调用之前修改)。
|
||||
|
||||
### 中间件的顺序
|
||||
|
||||
中间件的执行顺序非常重要。它们在 `AiProviderMiddlewareConfig` 的数组中定义的顺序就是它们的执行顺序。
|
||||
|
||||
- 请求首先通过第一个中间件,然后是第二个,依此类推。
|
||||
- 响应(或 `next` 的调用结果)则以相反的顺序"冒泡"回来。
|
||||
|
||||
例如,如果链是 `[AuthMiddleware, CacheMiddleware, LoggingMiddleware]`:
|
||||
|
||||
1. `AuthMiddleware` 先执行其 "调用 `next` 之前" 的逻辑。
|
||||
2. 然后 `CacheMiddleware` 执行其 "调用 `next` 之前" 的逻辑。
|
||||
3. 然后 `LoggingMiddleware` 执行其 "调用 `next` 之前" 的逻辑。
|
||||
4. 核心SDK调用(或链的末端)。
|
||||
5. `LoggingMiddleware` 先接收到结果,执行其 "调用 `next` 之后" 的逻辑。
|
||||
6. 然后 `CacheMiddleware` 接收到结果(可能已被 LoggingMiddleware 修改的上下文),执行其 "调用 `next` 之后" 的逻辑(例如,存储结果)。
|
||||
7. 最后 `AuthMiddleware` 接收到结果,执行其 "调用 `next` 之后" 的逻辑。
|
||||
|
||||
### 注册中间件
|
||||
|
||||
中间件在 `src/renderer/src/providers/middleware/register.ts` (或其他类似的配置文件) 中进行注册。
|
||||
|
||||
```typescript
|
||||
// register.ts
|
||||
import { AiProviderMiddlewareConfig } from './AiProviderMiddlewareTypes'
|
||||
import { createSimpleLoggingMiddleware } from './common/SimpleLoggingMiddleware' // 假设你创建了这个文件
|
||||
import { createCompletionsLoggingMiddleware } from './common/CompletionsLoggingMiddleware' // 已有的
|
||||
|
||||
const middlewareConfig: AiProviderMiddlewareConfig = {
|
||||
completions: [
|
||||
createSimpleLoggingMiddleware(), // 你新加的中间件
|
||||
createCompletionsLoggingMiddleware() // 已有的日志中间件
|
||||
// ... 其他 completions 中间件
|
||||
],
|
||||
methods: {
|
||||
// translate: [createGenericLoggingMiddleware()],
|
||||
// ... 其他方法的中间件
|
||||
}
|
||||
}
|
||||
|
||||
export default middlewareConfig
|
||||
```
|
||||
|
||||
### 最佳实践
|
||||
|
||||
1. **单一职责**: 每个中间件应专注于一个特定的功能(例如,日志、缓存、转换特定数据)。
|
||||
2. **无副作用 (尽可能)**: 除了通过 `context` 或 `onChunk` 明确的副作用外,尽量避免修改全局状态或产生其他隐蔽的副作用。
|
||||
3. **错误处理**:
|
||||
- 在中间件内部使用 `try...catch` 来处理可能发生的错误。
|
||||
- 决定是自行处理错误(例如,通过 `onChunk` 发送错误块)还是将错误重新抛出给上游。
|
||||
- 如果重新抛出,确保错误对象包含足够的信息。
|
||||
4. **性能考虑**: 中间件会增加请求处理的开销。避免在中间件中执行非常耗时的同步操作。对于IO密集型操作,确保它们是异步的。
|
||||
5. **可配置性**: 使中间件的行为可通过参数或配置进行调整。例如,日志中间件可以接受一个日志级别参数。
|
||||
6. **上下文管理**:
|
||||
- 谨慎地向 `context` 添加数据。避免污染 `context` 或添加过大的对象。
|
||||
- 明确你添加到 `context` 的字段的用途和生命周期。
|
||||
7. **`next` 的调用**:
|
||||
- 除非你有充分的理由提前终止请求(例如,缓存命中、授权失败),否则**总是确保调用 `await next(context, params)`**。否则,下游的中间件和核心逻辑将不会执行。
|
||||
- 理解 `next` 的返回值并正确处理它,特别是当它是一个流时。你需要负责消费这个流或将其传递给另一个能够消费它的组件/中间件。
|
||||
8. **命名清晰**: 给你的中间件和它们创建的函数起描述性的名字。
|
||||
9. **文档和注释**: 对复杂的中间件逻辑添加注释,解释其工作原理和目的。
|
||||
|
||||
### 调试技巧
|
||||
|
||||
- 在中间件的关键点使用 `console.log` 或调试器来检查 `params`、`context` 的状态以及 `next` 的返回值。
|
||||
- 暂时简化中间件链,只保留你正在调试的中间件和最简单的核心逻辑,以隔离问题。
|
||||
- 编写单元测试来独立验证每个中间件的行为。
|
||||
|
||||
通过遵循这些指南,你应该能够有效地为我们的系统创建强大且可维护的中间件。如果你有任何疑问或需要进一步的帮助,请咨询团队。
|
||||
@ -408,3 +408,4 @@ export enum FeedUrl {
|
||||
PRODUCTION = 'https://releases.cherry-ai.com',
|
||||
EARLY_ACCESS = 'https://github.com/CherryHQ/cherry-studio/releases/latest/download'
|
||||
}
|
||||
export const defaultTimeout = 5 * 1000 * 60
|
||||
|
||||
@ -4,6 +4,7 @@ import { arch } from 'node:os'
|
||||
import { isMac, isWin } from '@main/constant'
|
||||
import { getBinaryPath, isBinaryExists, runInstallScript } from '@main/utils/process'
|
||||
import { handleZoomFactor } from '@main/utils/zoom'
|
||||
import { FeedUrl } from '@shared/config/constant'
|
||||
import { IpcChannel } from '@shared/IpcChannel'
|
||||
import { Shortcut, ThemeMode } from '@types'
|
||||
import { BrowserWindow, ipcMain, session, shell } from 'electron'
|
||||
@ -34,7 +35,6 @@ import { calculateDirectorySize, getResourcePath } from './utils'
|
||||
import { decrypt, encrypt } from './utils/aes'
|
||||
import { getCacheDir, getConfigDir, getFilesDir } from './utils/file'
|
||||
import { compress, decompress } from './utils/zip'
|
||||
import { FeedUrl } from '@shared/config/constant'
|
||||
|
||||
const fileManager = new FileStorage()
|
||||
const backupManager = new BackupManager()
|
||||
|
||||
223
src/renderer/src/aiCore/AI_CORE_DESIGN.md
Normal file
223
src/renderer/src/aiCore/AI_CORE_DESIGN.md
Normal file
@ -0,0 +1,223 @@
|
||||
# Cherry Studio AI Provider 技术架构文档 (新方案)
|
||||
|
||||
## 1. 核心设计理念与目标
|
||||
|
||||
本架构旨在重构 Cherry Studio 的 AI Provider(现称为 `aiCore`)层,以实现以下目标:
|
||||
|
||||
- **职责清晰**:明确划分各组件的职责,降低耦合度。
|
||||
- **高度复用**:最大化业务逻辑和通用处理逻辑的复用,减少重复代码。
|
||||
- **易于扩展**:方便快捷地接入新的 AI Provider (LLM供应商) 和添加新的 AI 功能 (如翻译、摘要、图像生成等)。
|
||||
- **易于维护**:简化单个组件的复杂性,提高代码的可读性和可维护性。
|
||||
- **标准化**:统一内部数据流和接口,简化不同 Provider 之间的差异处理。
|
||||
|
||||
核心思路是将纯粹的 **SDK 适配层 (`XxxApiClient`)**、**通用逻辑处理与智能解析层 (中间件)** 以及 **统一业务功能入口层 (`AiCoreService`)** 清晰地分离开来。
|
||||
|
||||
## 2. 核心组件详解
|
||||
|
||||
### 2.1. `aiCore` (原 `AiProvider` 文件夹)
|
||||
|
||||
这是整个 AI 功能的核心模块。
|
||||
|
||||
#### 2.1.1. `XxxApiClient` (例如 `aiCore/clients/openai/OpenAIApiClient.ts`)
|
||||
|
||||
- **职责**:作为特定 AI Provider SDK 的纯粹适配层。
|
||||
- **参数适配**:将应用内部统一的 `CoreRequest` 对象 (见下文) 转换为特定 SDK 所需的请求参数格式。
|
||||
- **基础响应转换**:将 SDK 返回的原始数据块 (`RawSdkChunk`,例如 `OpenAI.Chat.Completions.ChatCompletionChunk`) 转换为一组最基础、最直接的应用层 `Chunk` 对象 (定义于 `src/renderer/src/types/chunk.ts`)。
|
||||
- 例如:SDK 的 `delta.content` -> `TextDeltaChunk`;SDK 的 `delta.reasoning_content` -> `ThinkingDeltaChunk`;SDK 的 `delta.tool_calls` -> `RawToolCallChunk` (包含原始工具调用数据)。
|
||||
- **关键**:`XxxApiClient` **不处理**耦合在文本内容中的复杂结构,如 `<think>` 或 `<tool_use>` 标签。
|
||||
- **特点**:极度轻量化,代码量少,易于实现和维护新的 Provider 适配。
|
||||
|
||||
#### 2.1.2. `ApiClient.ts` (或 `BaseApiClient.ts` 的核心接口)
|
||||
|
||||
- 定义了所有 `XxxApiClient` 必须实现的接口,如:
|
||||
- `getSdkInstance(): Promise<TSdkInstance> | TSdkInstance`
|
||||
- `getRequestTransformer(): RequestTransformer<TSdkParams>`
|
||||
- `getResponseChunkTransformer(): ResponseChunkTransformer<TRawChunk, TResponseContext>`
|
||||
- 其他可选的、与特定 Provider 相关的辅助方法 (如工具调用转换)。
|
||||
|
||||
#### 2.1.3. `ApiClientFactory.ts`
|
||||
|
||||
- 根据 Provider 配置动态创建和返回相应的 `XxxApiClient` 实例。
|
||||
|
||||
#### 2.1.4. `AiCoreService.ts` (`aiCore/index.ts`)
|
||||
|
||||
- **职责**:作为所有 AI 相关业务功能的统一入口。
|
||||
- 提供面向应用的高层接口,例如:
|
||||
- `executeCompletions(params: CompletionsParams): Promise<AggregatedCompletionsResult>`
|
||||
- `translateText(params: TranslateParams): Promise<AggregatedTranslateResult>`
|
||||
- `summarizeText(params: SummarizeParams): Promise<AggregatedSummarizeResult>`
|
||||
- 未来可能的 `generateImage(prompt: string): Promise<ImageResult>` 等。
|
||||
- **返回 `Promise`**:每个服务方法返回一个 `Promise`,该 `Promise` 会在整个(可能是流式的)操作完成后,以包含所有聚合结果(如完整文本、工具调用详情、最终的`usage`/`metrics`等)的对象来 `resolve`。
|
||||
- **支持流式回调**:服务方法的参数 (如 `CompletionsParams`) 依然包含 `onChunk` 回调,用于向调用方实时推送处理过程中的 `Chunk` 数据,实现流式UI更新。
|
||||
- **封装特定任务的提示工程 (Prompt Engineering)**:
|
||||
- 例如,`translateText` 方法内部会构建一个包含特定翻译指令的 `CoreRequest`。
|
||||
- **编排和调用中间件链**:通过内部的 `MiddlewareBuilder` (参见 `middleware/BUILDER_USAGE.md`) 实例,根据调用的业务方法和参数,动态构建和组织合适的中间件序列,然后通过 `applyCompletionsMiddlewares` 等组合函数执行。
|
||||
- 获取 `ApiClient` 实例并将其注入到中间件上游的 `Context` 中。
|
||||
- **将 `Promise` 的 `resolve` 和 `reject` 函数传递给中间件链** (通过 `Context`),以便 `FinalChunkConsumerAndNotifierMiddleware` 可以在操作完成或发生错误时结束该 `Promise`。
|
||||
- **优势**:
|
||||
- 业务逻辑(如翻译、摘要的提示构建和流程控制)只需实现一次,即可支持所有通过 `ApiClient` 接入的底层 Provider。
|
||||
- **支持外部编排**:调用方可以 `await` 服务方法以获取最终聚合结果,然后将此结果作为后续操作的输入,轻松实现多步骤工作流。
|
||||
- **支持内部组合**:服务自身也可以通过 `await` 调用其他原子服务方法来构建更复杂的组合功能。
|
||||
|
||||
#### 2.1.5. `coreRequestTypes.ts` (或 `types.ts`)
|
||||
|
||||
- 定义核心的、Provider 无关的内部请求结构,例如:
|
||||
- `CoreCompletionsRequest`: 包含标准化后的消息列表、模型配置、工具列表、最大Token数、是否流式输出等。
|
||||
- `CoreTranslateRequest`, `CoreSummarizeRequest` 等 (如果与 `CoreCompletionsRequest` 结构差异较大,否则可复用并添加任务类型标记)。
|
||||
|
||||
### 2.2. `middleware`
|
||||
|
||||
中间件层负责处理请求和响应流中的通用逻辑和特定特性。其设计和使用遵循 `middleware/BUILDER_USAGE.md` 中定义的规范。
|
||||
|
||||
**核心组件包括:**
|
||||
|
||||
- **`MiddlewareBuilder`**: 一个通用的、提供流式API的类,用于动态构建中间件链。它支持从基础链开始,根据条件添加、插入、替换或移除中间件。
|
||||
- **`applyCompletionsMiddlewares`**: 负责接收 `MiddlewareBuilder` 构建的链并按顺序执行,专门用于 Completions 流程。
|
||||
- **`MiddlewareRegistry`**: 集中管理所有可用中间件的注册表,提供统一的中间件访问接口。
|
||||
- **各种独立的中间件模块** (存放于 `common/`, `core/`, `feat/` 子目录)。
|
||||
|
||||
#### 2.2.1. `middlewareTypes.ts`
|
||||
|
||||
- 定义中间件的核心类型,如 `AiProviderMiddlewareContext` (扩展后包含 `_apiClientInstance` 和 `_coreRequest`)、`MiddlewareAPI`、`CompletionsMiddleware` 等。
|
||||
|
||||
#### 2.2.2. 核心中间件 (`middleware/core/`)
|
||||
|
||||
- **`TransformCoreToSdkParamsMiddleware.ts`**: 调用 `ApiClient.getRequestTransformer()` 将 `CoreRequest` 转换为特定 SDK 的参数,并存入上下文。
|
||||
- **`RequestExecutionMiddleware.ts`**: 调用 `ApiClient.getSdkInstance()` 获取 SDK 实例,并使用转换后的参数执行实际的 API 调用,返回原始 SDK 流。
|
||||
- **`StreamAdapterMiddleware.ts`**: 将各种形态的原始 SDK 流 (如异步迭代器) 统一适配为 `ReadableStream<RawSdkChunk>`。
|
||||
- **`RawSdkChunk`**:指特定AI提供商SDK在流式响应中返回的、未经应用层统一处理的原始数据块格式 (例如 OpenAI 的 `ChatCompletionChunk`,Gemini 的 `GenerateContentResponse` 中的部分等)。
|
||||
- **`RawSdkChunkToAppChunkMiddleware.ts`**: (新增) 消费 `ReadableStream<RawSdkChunk>`,在其内部对每个 `RawSdkChunk` 调用 `ApiClient.getResponseChunkTransformer()`,将其转换为一个或多个基础的应用层 `Chunk` 对象,并输出 `ReadableStream<Chunk>`。
|
||||
|
||||
#### 2.2.3. 特性中间件 (`middleware/feat/`)
|
||||
|
||||
这些中间件消费由 `ResponseTransformMiddleware` 输出的、相对标准化的 `Chunk` 流,并处理更复杂的逻辑。
|
||||
|
||||
- **`ThinkingTagExtractionMiddleware.ts`**: 检查 `TextDeltaChunk`,解析其中可能包含的 `<think>...</think>` 文本内嵌标签,生成 `ThinkingDeltaChunk` 和 `ThinkingCompleteChunk`。
|
||||
- **`ToolUseExtractionMiddleware.ts`**: 检查 `TextDeltaChunk`,解析其中可能包含的 `<tool_use>...</tool_use>` 文本内嵌标签,生成工具调用相关的 Chunk。如果 `ApiClient` 输出了原生工具调用数据,此中间件也负责将其转换为标准格式。
|
||||
|
||||
#### 2.2.4. 核心处理中间件 (`middleware/core/`)
|
||||
|
||||
- **`TransformCoreToSdkParamsMiddleware.ts`**: 调用 `ApiClient.getRequestTransformer()` 将 `CoreRequest` 转换为特定 SDK 的参数,并存入上下文。
|
||||
- **`SdkCallMiddleware.ts`**: 调用 `ApiClient.getSdkInstance()` 获取 SDK 实例,并使用转换后的参数执行实际的 API 调用,返回原始 SDK 流。
|
||||
- **`StreamAdapterMiddleware.ts`**: 将各种形态的原始 SDK 流统一适配为标准流格式。
|
||||
- **`ResponseTransformMiddleware.ts`**: 将原始 SDK 响应转换为应用层标准 `Chunk` 对象。
|
||||
- **`TextChunkMiddleware.ts`**: 处理文本相关的 Chunk 流。
|
||||
- **`ThinkChunkMiddleware.ts`**: 处理思考相关的 Chunk 流。
|
||||
- **`McpToolChunkMiddleware.ts`**: 处理工具调用相关的 Chunk 流。
|
||||
- **`WebSearchMiddleware.ts`**: 处理 Web 搜索相关逻辑。
|
||||
|
||||
#### 2.2.5. 通用中间件 (`middleware/common/`)
|
||||
|
||||
- **`LoggingMiddleware.ts`**: 请求和响应日志。
|
||||
- **`AbortHandlerMiddleware.ts`**: 处理请求中止。
|
||||
- **`FinalChunkConsumerMiddleware.ts`**: 消费最终的 `Chunk` 流,通过 `context.onChunk` 回调通知应用层实时数据。
|
||||
- **累积数据**:在流式处理过程中,累积关键数据,如文本片段、工具调用信息、`usage`/`metrics` 等。
|
||||
- **结束 `Promise`**:当输入流结束时,使用累积的聚合结果来完成整个处理流程。
|
||||
- 在流结束时,发送包含最终累加信息的完成信号。
|
||||
|
||||
### 2.3. `types/chunk.ts`
|
||||
|
||||
- 定义应用全局统一的 `Chunk` 类型及其所有变体。这包括基础类型 (如 `TextDeltaChunk`, `ThinkingDeltaChunk`)、SDK原生数据传递类型 (如 `RawToolCallChunk`, `RawFinishChunk` - 作为 `ApiClient` 转换的中间产物),以及功能性类型 (如 `McpToolCallRequestChunk`, `WebSearchCompleteChunk`)。
|
||||
|
||||
## 3. 核心执行流程 (以 `AiCoreService.executeCompletions` 为例)
|
||||
|
||||
```markdown
|
||||
**应用层 (例如 UI 组件)**
|
||||
||
|
||||
\\/
|
||||
**`AiProvider.completions` (`aiCore/index.ts`)**
|
||||
(1. prepare ApiClient instance. 2. use `CompletionsMiddlewareBuilder.withDefaults()` to build middleware chain. 3. call `applyCompletionsMiddlewares`)
|
||||
||
|
||||
\\/
|
||||
**`applyCompletionsMiddlewares` (`middleware/composer.ts`)**
|
||||
(接收构建好的链、ApiClient实例、原始SDK方法,开始按序执行中间件)
|
||||
||
|
||||
\\/
|
||||
**[ 预处理阶段中间件 ]**
|
||||
(例如: `FinalChunkConsumerMiddleware`, `TransformCoreToSdkParamsMiddleware`, `AbortHandlerMiddleware`)
|
||||
|| (Context 中准备好 SDK 请求参数)
|
||||
\\/
|
||||
**[ 处理阶段中间件 ]**
|
||||
(例如: `McpToolChunkMiddleware`, `WebSearchMiddleware`, `TextChunkMiddleware`, `ThinkingTagExtractionMiddleware`)
|
||||
|| (处理各种特性和Chunk类型)
|
||||
\\/
|
||||
**[ SDK调用阶段中间件 ]**
|
||||
(例如: `ResponseTransformMiddleware`, `StreamAdapterMiddleware`, `SdkCallMiddleware`)
|
||||
|| (输出: 标准化的应用层Chunk流)
|
||||
\\/
|
||||
**`FinalChunkConsumerMiddleware` (核心)**
|
||||
(消费最终的 `Chunk` 流, 通过 `context.onChunk` 回调通知应用层, 并在流结束时完成处理)
|
||||
||
|
||||
\\/
|
||||
**`AiProvider.completions` 返回 `Promise<CompletionsResult>`**
|
||||
```
|
||||
|
||||
## 4. 建议的文件/目录结构
|
||||
|
||||
```
|
||||
src/renderer/src/
|
||||
└── aiCore/
|
||||
├── clients/
|
||||
│ ├── openai/
|
||||
│ ├── gemini/
|
||||
│ ├── anthropic/
|
||||
│ ├── BaseApiClient.ts
|
||||
│ ├── ApiClientFactory.ts
|
||||
│ ├── AihubmixAPIClient.ts
|
||||
│ ├── index.ts
|
||||
│ └── types.ts
|
||||
├── middleware/
|
||||
│ ├── common/
|
||||
│ ├── core/
|
||||
│ ├── feat/
|
||||
│ ├── builder.ts
|
||||
│ ├── composer.ts
|
||||
│ ├── index.ts
|
||||
│ ├── register.ts
|
||||
│ ├── schemas.ts
|
||||
│ ├── types.ts
|
||||
│ └── utils.ts
|
||||
├── types/
|
||||
│ ├── chunk.ts
|
||||
│ └── ...
|
||||
└── index.ts
|
||||
```
|
||||
|
||||
## 5. 迁移和实施建议
|
||||
|
||||
- **小步快跑,逐步迭代**:优先完成核心流程的重构(例如 `completions`),再逐步迁移其他功能(`translate` 等)和其他 Provider。
|
||||
- **优先定义核心类型**:`CoreRequest`, `Chunk`, `ApiClient` 接口是整个架构的基石。
|
||||
- **为 `ApiClient` 瘦身**:将现有 `XxxProvider` 中的复杂逻辑剥离到新的中间件或 `AiCoreService` 中。
|
||||
- **强化中间件**:让中间件承担起更多解析和特性处理的责任。
|
||||
- **编写单元测试和集成测试**:确保每个组件和整体流程的正确性。
|
||||
|
||||
此架构旨在提供一个更健壮、更灵活、更易于维护的 AI 功能核心,支撑 Cherry Studio 未来的发展。
|
||||
|
||||
## 6. 迁移策略与实施建议
|
||||
|
||||
本节内容提炼自早期的 `migrate.md` 文档,并根据最新的架构讨论进行了调整。
|
||||
|
||||
**目标架构核心组件回顾:**
|
||||
|
||||
与第 2 节描述的核心组件一致,主要包括 `XxxApiClient`, `AiCoreService`, 中间件链, `CoreRequest` 类型, 和标准化的 `Chunk` 类型。
|
||||
|
||||
**迁移步骤:**
|
||||
|
||||
**Phase 0: 准备工作和类型定义**
|
||||
|
||||
1. **定义核心数据结构 (TypeScript 类型):**
|
||||
- `CoreCompletionsRequest` (Type):定义应用内部统一的对话请求结构。
|
||||
- `Chunk` (Type - 检查并按需扩展现有 `src/renderer/src/types/chunk.ts`):定义所有可能的通用Chunk类型。
|
||||
- 为其他API(翻译、总结)定义类似的 `CoreXxxRequest` (Type)。
|
||||
2. **定义 `ApiClient` 接口:** 明确 `getRequestTransformer`, `getResponseChunkTransformer`, `getSdkInstance` 等核心方法。
|
||||
3. **调整 `AiProviderMiddlewareContext`:**
|
||||
- 确保包含 `_apiClientInstance: ApiClient<any,any,any>`。
|
||||
- 确保包含 `_coreRequest: CoreRequestType`。
|
||||
- 考虑添加 `resolvePromise: (value: AggregatedResultType) => void` 和 `rejectPromise: (reason?: any) => void` 用于 `AiCoreService` 的 Promise 返回。
|
||||
|
||||
**Phase 1: 实现第一个 `ApiClient` (以 `OpenAIApiClient` 为例)**
|
||||
|
||||
1. **创建 `OpenAIApiClient` 类:** 实现 `ApiClient` 接口。
|
||||
2. **迁移SDK实例和配置。**
|
||||
3. **实现 `getRequestTransformer()`:** 将 `CoreCompletionsRequest` 转换为 OpenAI SDK 参数。
|
||||
4. **实现 `getResponseChunkTransformer()`:** 将 `OpenAI.Chat.Completions.ChatCompletionChunk` 转换为基础的 `
|
||||
207
src/renderer/src/aiCore/clients/AihubmixAPIClient.ts
Normal file
207
src/renderer/src/aiCore/clients/AihubmixAPIClient.ts
Normal file
@ -0,0 +1,207 @@
|
||||
import { isOpenAILLMModel } from '@renderer/config/models'
|
||||
import {
|
||||
GenerateImageParams,
|
||||
MCPCallToolResponse,
|
||||
MCPTool,
|
||||
MCPToolResponse,
|
||||
Model,
|
||||
Provider,
|
||||
ToolCallResponse
|
||||
} from '@renderer/types'
|
||||
import {
|
||||
RequestOptions,
|
||||
SdkInstance,
|
||||
SdkMessageParam,
|
||||
SdkModel,
|
||||
SdkParams,
|
||||
SdkRawChunk,
|
||||
SdkRawOutput,
|
||||
SdkTool,
|
||||
SdkToolCall
|
||||
} from '@renderer/types/sdk'
|
||||
|
||||
import { AnthropicAPIClient } from './anthropic/AnthropicAPIClient'
|
||||
import { BaseApiClient } from './BaseApiClient'
|
||||
import { GeminiAPIClient } from './gemini/GeminiAPIClient'
|
||||
import { OpenAIAPIClient } from './openai/OpenAIApiClient'
|
||||
import { OpenAIResponseAPIClient } from './openai/OpenAIResponseAPIClient'
|
||||
import { RequestTransformer, ResponseChunkTransformer } from './types'
|
||||
|
||||
/**
|
||||
* AihubmixAPIClient - 根据模型类型自动选择合适的ApiClient
|
||||
* 使用装饰器模式实现,在ApiClient层面进行模型路由
|
||||
*/
|
||||
export class AihubmixAPIClient extends BaseApiClient {
|
||||
// 使用联合类型而不是any,保持类型安全
|
||||
private clients: Map<string, AnthropicAPIClient | GeminiAPIClient | OpenAIResponseAPIClient | OpenAIAPIClient> =
|
||||
new Map()
|
||||
private defaultClient: OpenAIAPIClient
|
||||
private currentClient: BaseApiClient
|
||||
|
||||
constructor(provider: Provider) {
|
||||
super(provider)
|
||||
|
||||
// 初始化各个client - 现在有类型安全
|
||||
const claudeClient = new AnthropicAPIClient(provider)
|
||||
const geminiClient = new GeminiAPIClient({ ...provider, apiHost: 'https://aihubmix.com/gemini' })
|
||||
const openaiClient = new OpenAIResponseAPIClient(provider)
|
||||
const defaultClient = new OpenAIAPIClient(provider)
|
||||
|
||||
this.clients.set('claude', claudeClient)
|
||||
this.clients.set('gemini', geminiClient)
|
||||
this.clients.set('openai', openaiClient)
|
||||
this.clients.set('default', defaultClient)
|
||||
|
||||
// 设置默认client
|
||||
this.defaultClient = defaultClient
|
||||
this.currentClient = this.defaultClient as BaseApiClient
|
||||
}
|
||||
|
||||
/**
|
||||
* 类型守卫:确保client是BaseApiClient的实例
|
||||
*/
|
||||
private isValidClient(client: unknown): client is BaseApiClient {
|
||||
return (
|
||||
client !== null &&
|
||||
client !== undefined &&
|
||||
typeof client === 'object' &&
|
||||
'createCompletions' in client &&
|
||||
'getRequestTransformer' in client &&
|
||||
'getResponseChunkTransformer' in client
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* 根据模型获取合适的client
|
||||
*/
|
||||
private getClient(model: Model): BaseApiClient {
|
||||
const id = model.id.toLowerCase()
|
||||
|
||||
// claude开头
|
||||
if (id.startsWith('claude')) {
|
||||
const client = this.clients.get('claude')
|
||||
if (!client || !this.isValidClient(client)) {
|
||||
throw new Error('Claude client not properly initialized')
|
||||
}
|
||||
return client
|
||||
}
|
||||
|
||||
// gemini开头 且不以-nothink、-search结尾
|
||||
if ((id.startsWith('gemini') || id.startsWith('imagen')) && !id.endsWith('-nothink') && !id.endsWith('-search')) {
|
||||
const client = this.clients.get('gemini')
|
||||
if (!client || !this.isValidClient(client)) {
|
||||
throw new Error('Gemini client not properly initialized')
|
||||
}
|
||||
return client
|
||||
}
|
||||
|
||||
// OpenAI系列模型
|
||||
if (isOpenAILLMModel(model)) {
|
||||
const client = this.clients.get('openai')
|
||||
if (!client || !this.isValidClient(client)) {
|
||||
throw new Error('OpenAI client not properly initialized')
|
||||
}
|
||||
return client
|
||||
}
|
||||
|
||||
return this.defaultClient as BaseApiClient
|
||||
}
|
||||
|
||||
/**
|
||||
* 根据模型选择合适的client并委托调用
|
||||
*/
|
||||
public getClientForModel(model: Model): BaseApiClient {
|
||||
this.currentClient = this.getClient(model)
|
||||
return this.currentClient
|
||||
}
|
||||
|
||||
// ============ BaseApiClient 抽象方法实现 ============
|
||||
|
||||
async createCompletions(payload: SdkParams, options?: RequestOptions): Promise<SdkRawOutput> {
|
||||
// 尝试从payload中提取模型信息来选择client
|
||||
const modelId = this.extractModelFromPayload(payload)
|
||||
if (modelId) {
|
||||
const modelObj = { id: modelId } as Model
|
||||
const targetClient = this.getClient(modelObj)
|
||||
return targetClient.createCompletions(payload, options)
|
||||
}
|
||||
|
||||
// 如果无法从payload中提取模型,使用当前设置的client
|
||||
return this.currentClient.createCompletions(payload, options)
|
||||
}
|
||||
|
||||
/**
|
||||
* 从SDK payload中提取模型ID
|
||||
*/
|
||||
private extractModelFromPayload(payload: SdkParams): string | null {
|
||||
// 不同的SDK可能有不同的字段名
|
||||
if ('model' in payload && typeof payload.model === 'string') {
|
||||
return payload.model
|
||||
}
|
||||
return null
|
||||
}
|
||||
|
||||
async generateImage(params: GenerateImageParams): Promise<string[]> {
|
||||
return this.currentClient.generateImage(params)
|
||||
}
|
||||
|
||||
async getEmbeddingDimensions(model?: Model): Promise<number> {
|
||||
const client = model ? this.getClient(model) : this.currentClient
|
||||
return client.getEmbeddingDimensions(model)
|
||||
}
|
||||
|
||||
async listModels(): Promise<SdkModel[]> {
|
||||
// 可以聚合所有client的模型,或者使用默认client
|
||||
return this.defaultClient.listModels()
|
||||
}
|
||||
|
||||
async getSdkInstance(): Promise<SdkInstance> {
|
||||
return this.currentClient.getSdkInstance()
|
||||
}
|
||||
|
||||
getRequestTransformer(): RequestTransformer<SdkParams, SdkMessageParam> {
|
||||
return this.currentClient.getRequestTransformer()
|
||||
}
|
||||
|
||||
getResponseChunkTransformer(): ResponseChunkTransformer<SdkRawChunk> {
|
||||
return this.currentClient.getResponseChunkTransformer()
|
||||
}
|
||||
|
||||
convertMcpToolsToSdkTools(mcpTools: MCPTool[]): SdkTool[] {
|
||||
return this.currentClient.convertMcpToolsToSdkTools(mcpTools)
|
||||
}
|
||||
|
||||
convertSdkToolCallToMcp(toolCall: SdkToolCall, mcpTools: MCPTool[]): MCPTool | undefined {
|
||||
return this.currentClient.convertSdkToolCallToMcp(toolCall, mcpTools)
|
||||
}
|
||||
|
||||
convertSdkToolCallToMcpToolResponse(toolCall: SdkToolCall, mcpTool: MCPTool): ToolCallResponse {
|
||||
return this.currentClient.convertSdkToolCallToMcpToolResponse(toolCall, mcpTool)
|
||||
}
|
||||
|
||||
buildSdkMessages(
|
||||
currentReqMessages: SdkMessageParam[],
|
||||
output: SdkRawOutput | string,
|
||||
toolResults: SdkMessageParam[],
|
||||
toolCalls?: SdkToolCall[]
|
||||
): SdkMessageParam[] {
|
||||
return this.currentClient.buildSdkMessages(currentReqMessages, output, toolResults, toolCalls)
|
||||
}
|
||||
|
||||
convertMcpToolResponseToSdkMessageParam(
|
||||
mcpToolResponse: MCPToolResponse,
|
||||
resp: MCPCallToolResponse,
|
||||
model: Model
|
||||
): SdkMessageParam | undefined {
|
||||
const client = this.getClient(model)
|
||||
return client.convertMcpToolResponseToSdkMessageParam(mcpToolResponse, resp, model)
|
||||
}
|
||||
|
||||
extractMessagesFromSdkPayload(sdkPayload: SdkParams): SdkMessageParam[] {
|
||||
return this.currentClient.extractMessagesFromSdkPayload(sdkPayload)
|
||||
}
|
||||
|
||||
estimateMessageTokens(message: SdkMessageParam): number {
|
||||
return this.currentClient.estimateMessageTokens(message)
|
||||
}
|
||||
}
|
||||
62
src/renderer/src/aiCore/clients/ApiClientFactory.ts
Normal file
62
src/renderer/src/aiCore/clients/ApiClientFactory.ts
Normal file
@ -0,0 +1,62 @@
|
||||
import { Provider } from '@renderer/types'
|
||||
|
||||
import { AihubmixAPIClient } from './AihubmixAPIClient'
|
||||
import { AnthropicAPIClient } from './anthropic/AnthropicAPIClient'
|
||||
import { BaseApiClient } from './BaseApiClient'
|
||||
import { GeminiAPIClient } from './gemini/GeminiAPIClient'
|
||||
import { OpenAIAPIClient } from './openai/OpenAIApiClient'
|
||||
import { OpenAIResponseAPIClient } from './openai/OpenAIResponseAPIClient'
|
||||
|
||||
/**
|
||||
* Factory for creating ApiClient instances based on provider configuration
|
||||
* 根据提供者配置创建ApiClient实例的工厂
|
||||
*/
|
||||
export class ApiClientFactory {
|
||||
/**
|
||||
* Create an ApiClient instance for the given provider
|
||||
* 为给定的提供者创建ApiClient实例
|
||||
*/
|
||||
static create(provider: Provider): BaseApiClient {
|
||||
console.log(`[ApiClientFactory] Creating ApiClient for provider:`, {
|
||||
id: provider.id,
|
||||
type: provider.type
|
||||
})
|
||||
|
||||
let instance: BaseApiClient
|
||||
|
||||
// 首先检查特殊的provider id
|
||||
if (provider.id === 'aihubmix') {
|
||||
console.log(`[ApiClientFactory] Creating AihubmixAPIClient for provider: ${provider.id}`)
|
||||
instance = new AihubmixAPIClient(provider) as BaseApiClient
|
||||
return instance
|
||||
}
|
||||
|
||||
// 然后检查标准的provider type
|
||||
switch (provider.type) {
|
||||
case 'openai':
|
||||
case 'azure-openai':
|
||||
console.log(`[ApiClientFactory] Creating OpenAIApiClient for provider: ${provider.id}`)
|
||||
instance = new OpenAIAPIClient(provider) as BaseApiClient
|
||||
break
|
||||
case 'openai-response':
|
||||
instance = new OpenAIResponseAPIClient(provider) as BaseApiClient
|
||||
break
|
||||
case 'gemini':
|
||||
instance = new GeminiAPIClient(provider) as BaseApiClient
|
||||
break
|
||||
case 'anthropic':
|
||||
instance = new AnthropicAPIClient(provider) as BaseApiClient
|
||||
break
|
||||
default:
|
||||
console.log(`[ApiClientFactory] Using default OpenAIApiClient for provider: ${provider.id}`)
|
||||
instance = new OpenAIAPIClient(provider) as BaseApiClient
|
||||
break
|
||||
}
|
||||
|
||||
return instance
|
||||
}
|
||||
}
|
||||
|
||||
export function isOpenAIProvider(provider: Provider) {
|
||||
return !['anthropic', 'gemini'].includes(provider.type)
|
||||
}
|
||||
@ -1,40 +1,69 @@
|
||||
import Logger from '@renderer/config/logger'
|
||||
import { isFunctionCallingModel, isNotSupportTemperatureAndTopP } from '@renderer/config/models'
|
||||
import {
|
||||
isFunctionCallingModel,
|
||||
isNotSupportTemperatureAndTopP,
|
||||
isOpenAIModel,
|
||||
isSupportedFlexServiceTier
|
||||
} from '@renderer/config/models'
|
||||
import { REFERENCE_PROMPT } from '@renderer/config/prompts'
|
||||
import { getLMStudioKeepAliveTime } from '@renderer/hooks/useLMStudio'
|
||||
import type {
|
||||
import { getStoreSetting } from '@renderer/hooks/useSettings'
|
||||
import { SettingsState } from '@renderer/store/settings'
|
||||
import {
|
||||
Assistant,
|
||||
FileTypes,
|
||||
GenerateImageParams,
|
||||
KnowledgeReference,
|
||||
MCPCallToolResponse,
|
||||
MCPTool,
|
||||
MCPToolResponse,
|
||||
Model,
|
||||
OpenAIServiceTier,
|
||||
Provider,
|
||||
Suggestion,
|
||||
ToolCallResponse,
|
||||
WebSearchProviderResponse,
|
||||
WebSearchResponse
|
||||
} from '@renderer/types'
|
||||
import { ChunkType } from '@renderer/types/chunk'
|
||||
import type { Message } from '@renderer/types/newMessage'
|
||||
import { delay, isJSON, parseJSON } from '@renderer/utils'
|
||||
import { Message } from '@renderer/types/newMessage'
|
||||
import {
|
||||
RequestOptions,
|
||||
SdkInstance,
|
||||
SdkMessageParam,
|
||||
SdkModel,
|
||||
SdkParams,
|
||||
SdkRawChunk,
|
||||
SdkRawOutput,
|
||||
SdkTool,
|
||||
SdkToolCall
|
||||
} from '@renderer/types/sdk'
|
||||
import { isJSON, parseJSON } from '@renderer/utils'
|
||||
import { addAbortController, removeAbortController } from '@renderer/utils/abortController'
|
||||
import { formatApiHost } from '@renderer/utils/api'
|
||||
import { getMainTextContent } from '@renderer/utils/messageUtils/find'
|
||||
import { findFileBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
|
||||
import { defaultTimeout } from '@shared/config/constant'
|
||||
import Logger from 'electron-log/renderer'
|
||||
import { isEmpty } from 'lodash'
|
||||
import type OpenAI from 'openai'
|
||||
|
||||
import type { CompletionsParams } from '.'
|
||||
import { ApiClient, RawStreamListener, RequestTransformer, ResponseChunkTransformer } from './types'
|
||||
|
||||
export default abstract class BaseProvider {
|
||||
// Threshold for determining whether to use system prompt for tools
|
||||
/**
|
||||
* Abstract base class for API clients.
|
||||
* Provides common functionality and structure for specific client implementations.
|
||||
*/
|
||||
export abstract class BaseApiClient<
|
||||
TSdkInstance extends SdkInstance = SdkInstance,
|
||||
TSdkParams extends SdkParams = SdkParams,
|
||||
TRawOutput extends SdkRawOutput = SdkRawOutput,
|
||||
TRawChunk extends SdkRawChunk = SdkRawChunk,
|
||||
TMessageParam extends SdkMessageParam = SdkMessageParam,
|
||||
TToolCall extends SdkToolCall = SdkToolCall,
|
||||
TSdkSpecificTool extends SdkTool = SdkTool
|
||||
> implements ApiClient<TSdkInstance, TSdkParams, TRawOutput, TRawChunk, TMessageParam, TToolCall, TSdkSpecificTool>
|
||||
{
|
||||
private static readonly SYSTEM_PROMPT_THRESHOLD: number = 128
|
||||
|
||||
protected provider: Provider
|
||||
public provider: Provider
|
||||
protected host: string
|
||||
protected apiKey: string
|
||||
|
||||
protected useSystemPromptForTools: boolean = true
|
||||
protected sdkInstance?: TSdkInstance
|
||||
public useSystemPromptForTools: boolean = true
|
||||
|
||||
constructor(provider: Provider) {
|
||||
this.provider = provider
|
||||
@ -42,32 +71,81 @@ export default abstract class BaseProvider {
|
||||
this.apiKey = this.getApiKey()
|
||||
}
|
||||
|
||||
abstract completions({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams): Promise<void>
|
||||
abstract translate(
|
||||
content: string,
|
||||
assistant: Assistant,
|
||||
onResponse?: (text: string, isComplete: boolean) => void
|
||||
): Promise<string>
|
||||
abstract summaries(messages: Message[], assistant: Assistant): Promise<string>
|
||||
abstract summaryForSearch(messages: Message[], assistant: Assistant): Promise<string | null>
|
||||
abstract suggestions(messages: Message[], assistant: Assistant): Promise<Suggestion[]>
|
||||
abstract generateText({ prompt, content }: { prompt: string; content: string }): Promise<string>
|
||||
abstract check(model: Model, stream: boolean): Promise<{ valid: boolean; error: Error | null }>
|
||||
abstract models(): Promise<OpenAI.Models.Model[]>
|
||||
abstract generateImage(params: GenerateImageParams): Promise<string[]>
|
||||
abstract generateImageByChat({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams): Promise<void>
|
||||
// 由于现在出现了一些能够选择嵌入维度的嵌入模型,这个不考虑dimensions参数的方法将只能应用于那些不支持dimensions的模型
|
||||
abstract getEmbeddingDimensions(model: Model): Promise<number>
|
||||
public abstract convertMcpTools<T>(mcpTools: MCPTool[]): T[]
|
||||
public abstract mcpToolCallResponseToMessage(
|
||||
// // 核心的completions方法 - 在中间件架构中,这通常只是一个占位符
|
||||
// abstract completions(params: CompletionsParams, internal?: ProcessingState): Promise<CompletionsResult>
|
||||
|
||||
/**
|
||||
* 核心API Endpoint
|
||||
**/
|
||||
|
||||
abstract createCompletions(payload: TSdkParams, options?: RequestOptions): Promise<TRawOutput>
|
||||
|
||||
abstract generateImage(generateImageParams: GenerateImageParams): Promise<string[]>
|
||||
|
||||
abstract getEmbeddingDimensions(model?: Model): Promise<number>
|
||||
|
||||
abstract listModels(): Promise<SdkModel[]>
|
||||
|
||||
abstract getSdkInstance(): Promise<TSdkInstance> | TSdkInstance
|
||||
|
||||
/**
|
||||
* 中间件
|
||||
**/
|
||||
|
||||
// 在 CoreRequestToSdkParamsMiddleware中使用
|
||||
abstract getRequestTransformer(): RequestTransformer<TSdkParams, TMessageParam>
|
||||
// 在RawSdkChunkToGenericChunkMiddleware中使用
|
||||
abstract getResponseChunkTransformer(): ResponseChunkTransformer<TRawChunk>
|
||||
|
||||
/**
|
||||
* 工具转换
|
||||
**/
|
||||
|
||||
// Optional tool conversion methods - implement if needed by the specific provider
|
||||
abstract convertMcpToolsToSdkTools(mcpTools: MCPTool[]): TSdkSpecificTool[]
|
||||
|
||||
abstract convertSdkToolCallToMcp(toolCall: TToolCall, mcpTools: MCPTool[]): MCPTool | undefined
|
||||
|
||||
abstract convertSdkToolCallToMcpToolResponse(toolCall: TToolCall, mcpTool: MCPTool): ToolCallResponse
|
||||
|
||||
abstract buildSdkMessages(
|
||||
currentReqMessages: TMessageParam[],
|
||||
output: TRawOutput | string,
|
||||
toolResults: TMessageParam[],
|
||||
toolCalls?: TToolCall[]
|
||||
): TMessageParam[]
|
||||
|
||||
abstract estimateMessageTokens(message: TMessageParam): number
|
||||
|
||||
abstract convertMcpToolResponseToSdkMessageParam(
|
||||
mcpToolResponse: MCPToolResponse,
|
||||
resp: MCPCallToolResponse,
|
||||
model: Model
|
||||
): any
|
||||
): TMessageParam | undefined
|
||||
|
||||
/**
|
||||
* 从SDK载荷中提取消息数组(用于中间件中的类型安全访问)
|
||||
* 不同的提供商可能使用不同的字段名(如messages、history等)
|
||||
*/
|
||||
abstract extractMessagesFromSdkPayload(sdkPayload: TSdkParams): TMessageParam[]
|
||||
|
||||
/**
|
||||
* 附加原始流监听器
|
||||
*/
|
||||
public attachRawStreamListener<TListener extends RawStreamListener<TRawChunk>>(
|
||||
rawOutput: TRawOutput,
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
_listener: TListener
|
||||
): TRawOutput {
|
||||
return rawOutput
|
||||
}
|
||||
|
||||
/**
|
||||
* 通用函数
|
||||
**/
|
||||
|
||||
public getBaseURL(): string {
|
||||
const host = this.provider.apiHost
|
||||
return formatApiHost(host)
|
||||
return this.provider.apiHost
|
||||
}
|
||||
|
||||
public getApiKey() {
|
||||
@ -112,14 +190,32 @@ export default abstract class BaseProvider {
|
||||
return isNotSupportTemperatureAndTopP(model) ? undefined : assistant.settings?.topP
|
||||
}
|
||||
|
||||
public async fakeCompletions({ onChunk }: CompletionsParams) {
|
||||
for (let i = 0; i < 100; i++) {
|
||||
await delay(0.01)
|
||||
onChunk({
|
||||
response: { text: i + '\n', usage: { completion_tokens: 0, prompt_tokens: 0, total_tokens: 0 } },
|
||||
type: ChunkType.BLOCK_COMPLETE
|
||||
})
|
||||
protected getServiceTier(model: Model) {
|
||||
if (!isOpenAIModel(model) || model.provider === 'github' || model.provider === 'copilot') {
|
||||
return undefined
|
||||
}
|
||||
|
||||
const openAI = getStoreSetting('openAI') as SettingsState['openAI']
|
||||
let serviceTier = 'auto' as OpenAIServiceTier
|
||||
|
||||
if (openAI && openAI?.serviceTier === 'flex') {
|
||||
if (isSupportedFlexServiceTier(model)) {
|
||||
serviceTier = 'flex'
|
||||
} else {
|
||||
serviceTier = 'auto'
|
||||
}
|
||||
} else {
|
||||
serviceTier = openAI.serviceTier
|
||||
}
|
||||
|
||||
return serviceTier
|
||||
}
|
||||
|
||||
protected getTimeout(model: Model) {
|
||||
if (isSupportedFlexServiceTier(model)) {
|
||||
return 15 * 1000 * 60
|
||||
}
|
||||
return defaultTimeout
|
||||
}
|
||||
|
||||
public async getMessageContent(message: Message): Promise<string> {
|
||||
@ -149,6 +245,36 @@ export default abstract class BaseProvider {
|
||||
return content
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract the file content from the message
|
||||
* @param message - The message
|
||||
* @returns The file content
|
||||
*/
|
||||
protected async extractFileContent(message: Message) {
|
||||
const fileBlocks = findFileBlocks(message)
|
||||
if (fileBlocks.length > 0) {
|
||||
const textFileBlocks = fileBlocks.filter(
|
||||
(fb) => fb.file && [FileTypes.TEXT, FileTypes.DOCUMENT].includes(fb.file.type)
|
||||
)
|
||||
|
||||
if (textFileBlocks.length > 0) {
|
||||
let text = ''
|
||||
const divider = '\n\n---\n\n'
|
||||
|
||||
for (const fileBlock of textFileBlocks) {
|
||||
const file = fileBlock.file
|
||||
const fileContent = (await window.api.file.read(file.id + file.ext)).trim()
|
||||
const fileNameRow = 'file: ' + file.origin_name + '\n\n'
|
||||
text = text + fileNameRow + fileContent + divider
|
||||
}
|
||||
|
||||
return text
|
||||
}
|
||||
}
|
||||
|
||||
return ''
|
||||
}
|
||||
|
||||
private async getWebSearchReferencesFromCache(message: Message) {
|
||||
const content = getMainTextContent(message)
|
||||
if (isEmpty(content)) {
|
||||
@ -210,7 +336,7 @@ export default abstract class BaseProvider {
|
||||
)
|
||||
}
|
||||
|
||||
protected createAbortController(messageId?: string, isAddEventListener?: boolean) {
|
||||
public createAbortController(messageId?: string, isAddEventListener?: boolean) {
|
||||
const abortController = new AbortController()
|
||||
const abortFn = () => abortController.abort()
|
||||
|
||||
@ -256,11 +382,11 @@ export default abstract class BaseProvider {
|
||||
}
|
||||
|
||||
// Setup tools configuration based on provided parameters
|
||||
protected setupToolsConfig<T>(params: { mcpTools?: MCPTool[]; model: Model; enableToolUse?: boolean }): {
|
||||
tools: T[]
|
||||
public setupToolsConfig(params: { mcpTools?: MCPTool[]; model: Model; enableToolUse?: boolean }): {
|
||||
tools: TSdkSpecificTool[]
|
||||
} {
|
||||
const { mcpTools, model, enableToolUse } = params
|
||||
let tools: T[] = []
|
||||
let tools: TSdkSpecificTool[] = []
|
||||
|
||||
// If there are no tools, return an empty array
|
||||
if (!mcpTools?.length) {
|
||||
@ -268,14 +394,14 @@ export default abstract class BaseProvider {
|
||||
}
|
||||
|
||||
// If the number of tools exceeds the threshold, use the system prompt
|
||||
if (mcpTools.length > BaseProvider.SYSTEM_PROMPT_THRESHOLD) {
|
||||
if (mcpTools.length > BaseApiClient.SYSTEM_PROMPT_THRESHOLD) {
|
||||
this.useSystemPromptForTools = true
|
||||
return { tools }
|
||||
}
|
||||
|
||||
// If the model supports function calling and tool usage is enabled
|
||||
if (isFunctionCallingModel(model) && enableToolUse) {
|
||||
tools = this.convertMcpTools<T>(mcpTools)
|
||||
tools = this.convertMcpToolsToSdkTools(mcpTools)
|
||||
this.useSystemPromptForTools = false
|
||||
}
|
||||
|
||||
714
src/renderer/src/aiCore/clients/anthropic/AnthropicAPIClient.ts
Normal file
714
src/renderer/src/aiCore/clients/anthropic/AnthropicAPIClient.ts
Normal file
@ -0,0 +1,714 @@
|
||||
import Anthropic from '@anthropic-ai/sdk'
|
||||
import {
|
||||
Base64ImageSource,
|
||||
ImageBlockParam,
|
||||
MessageParam,
|
||||
TextBlockParam,
|
||||
ToolResultBlockParam,
|
||||
ToolUseBlock,
|
||||
WebSearchTool20250305
|
||||
} from '@anthropic-ai/sdk/resources'
|
||||
import {
|
||||
ContentBlock,
|
||||
ContentBlockParam,
|
||||
MessageCreateParams,
|
||||
MessageCreateParamsBase,
|
||||
RedactedThinkingBlockParam,
|
||||
ServerToolUseBlockParam,
|
||||
ThinkingBlockParam,
|
||||
ThinkingConfigParam,
|
||||
ToolUnion,
|
||||
ToolUseBlockParam,
|
||||
WebSearchResultBlock,
|
||||
WebSearchToolResultBlockParam,
|
||||
WebSearchToolResultError
|
||||
} from '@anthropic-ai/sdk/resources/messages'
|
||||
import { MessageStream } from '@anthropic-ai/sdk/resources/messages/messages'
|
||||
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
|
||||
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
|
||||
import Logger from '@renderer/config/logger'
|
||||
import { findTokenLimit, isClaudeReasoningModel, isReasoningModel, isWebSearchModel } from '@renderer/config/models'
|
||||
import { getAssistantSettings } from '@renderer/services/AssistantService'
|
||||
import FileManager from '@renderer/services/FileManager'
|
||||
import { estimateTextTokens } from '@renderer/services/TokenService'
|
||||
import {
|
||||
Assistant,
|
||||
EFFORT_RATIO,
|
||||
FileTypes,
|
||||
MCPCallToolResponse,
|
||||
MCPTool,
|
||||
MCPToolResponse,
|
||||
Model,
|
||||
Provider,
|
||||
ToolCallResponse,
|
||||
WebSearchSource
|
||||
} from '@renderer/types'
|
||||
import {
|
||||
ChunkType,
|
||||
ErrorChunk,
|
||||
LLMWebSearchCompleteChunk,
|
||||
LLMWebSearchInProgressChunk,
|
||||
MCPToolCreatedChunk,
|
||||
TextDeltaChunk,
|
||||
ThinkingDeltaChunk
|
||||
} from '@renderer/types/chunk'
|
||||
import type { Message } from '@renderer/types/newMessage'
|
||||
import {
|
||||
AnthropicSdkMessageParam,
|
||||
AnthropicSdkParams,
|
||||
AnthropicSdkRawChunk,
|
||||
AnthropicSdkRawOutput
|
||||
} from '@renderer/types/sdk'
|
||||
import { addImageFileToContents } from '@renderer/utils/formats'
|
||||
import {
|
||||
anthropicToolUseToMcpTool,
|
||||
isEnabledToolUse,
|
||||
mcpToolCallResponseToAnthropicMessage,
|
||||
mcpToolsToAnthropicTools
|
||||
} from '@renderer/utils/mcp-tools'
|
||||
import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
|
||||
import { buildSystemPrompt } from '@renderer/utils/prompt'
|
||||
|
||||
import { BaseApiClient } from '../BaseApiClient'
|
||||
import { AnthropicStreamListener, RawStreamListener, RequestTransformer, ResponseChunkTransformer } from '../types'
|
||||
|
||||
export class AnthropicAPIClient extends BaseApiClient<
|
||||
Anthropic,
|
||||
AnthropicSdkParams,
|
||||
AnthropicSdkRawOutput,
|
||||
AnthropicSdkRawChunk,
|
||||
AnthropicSdkMessageParam,
|
||||
ToolUseBlock,
|
||||
ToolUnion
|
||||
> {
|
||||
constructor(provider: Provider) {
|
||||
super(provider)
|
||||
}
|
||||
|
||||
async getSdkInstance(): Promise<Anthropic> {
|
||||
if (this.sdkInstance) {
|
||||
return this.sdkInstance
|
||||
}
|
||||
this.sdkInstance = new Anthropic({
|
||||
apiKey: this.getApiKey(),
|
||||
baseURL: this.getBaseURL(),
|
||||
dangerouslyAllowBrowser: true,
|
||||
defaultHeaders: {
|
||||
'anthropic-beta': 'output-128k-2025-02-19'
|
||||
}
|
||||
})
|
||||
return this.sdkInstance
|
||||
}
|
||||
|
||||
override async createCompletions(
|
||||
payload: AnthropicSdkParams,
|
||||
options?: Anthropic.RequestOptions
|
||||
): Promise<AnthropicSdkRawOutput> {
|
||||
const sdk = await this.getSdkInstance()
|
||||
if (payload.stream) {
|
||||
return sdk.messages.stream(payload, options)
|
||||
}
|
||||
return await sdk.messages.create(payload, options)
|
||||
}
|
||||
|
||||
// @ts-ignore sdk未提供
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
override async generateImage(generateImageParams: GenerateImageParams): Promise<string[]> {
|
||||
return []
|
||||
}
|
||||
|
||||
override async listModels(): Promise<Anthropic.ModelInfo[]> {
|
||||
const sdk = await this.getSdkInstance()
|
||||
const response = await sdk.models.list()
|
||||
return response.data
|
||||
}
|
||||
|
||||
// @ts-ignore sdk未提供
|
||||
override async getEmbeddingDimensions(): Promise<number> {
|
||||
return 0
|
||||
}
|
||||
|
||||
override getTemperature(assistant: Assistant, model: Model): number | undefined {
|
||||
if (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model)) {
|
||||
return undefined
|
||||
}
|
||||
return assistant.settings?.temperature
|
||||
}
|
||||
|
||||
override getTopP(assistant: Assistant, model: Model): number | undefined {
|
||||
if (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model)) {
|
||||
return undefined
|
||||
}
|
||||
return assistant.settings?.topP
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the reasoning effort
|
||||
* @param assistant - The assistant
|
||||
* @param model - The model
|
||||
* @returns The reasoning effort
|
||||
*/
|
||||
private getBudgetToken(assistant: Assistant, model: Model): ThinkingConfigParam | undefined {
|
||||
if (!isReasoningModel(model)) {
|
||||
return undefined
|
||||
}
|
||||
const { maxTokens } = getAssistantSettings(assistant)
|
||||
|
||||
const reasoningEffort = assistant?.settings?.reasoning_effort
|
||||
|
||||
if (reasoningEffort === undefined) {
|
||||
return {
|
||||
type: 'disabled'
|
||||
}
|
||||
}
|
||||
|
||||
const effortRatio = EFFORT_RATIO[reasoningEffort]
|
||||
|
||||
const budgetTokens = Math.max(
|
||||
1024,
|
||||
Math.floor(
|
||||
Math.min(
|
||||
(findTokenLimit(model.id)?.max! - findTokenLimit(model.id)?.min!) * effortRatio +
|
||||
findTokenLimit(model.id)?.min!,
|
||||
(maxTokens || DEFAULT_MAX_TOKENS) * effortRatio
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
return {
|
||||
type: 'enabled',
|
||||
budget_tokens: budgetTokens
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the message parameter
|
||||
* @param message - The message
|
||||
* @param model - The model
|
||||
* @returns The message parameter
|
||||
*/
|
||||
public async convertMessageToSdkParam(message: Message): Promise<AnthropicSdkMessageParam> {
|
||||
const parts: MessageParam['content'] = [
|
||||
{
|
||||
type: 'text',
|
||||
text: getMainTextContent(message)
|
||||
}
|
||||
]
|
||||
|
||||
// Get and process image blocks
|
||||
const imageBlocks = findImageBlocks(message)
|
||||
for (const imageBlock of imageBlocks) {
|
||||
if (imageBlock.file) {
|
||||
// Handle uploaded file
|
||||
const file = imageBlock.file
|
||||
const base64Data = await window.api.file.base64Image(file.id + file.ext)
|
||||
parts.push({
|
||||
type: 'image',
|
||||
source: {
|
||||
data: base64Data.base64,
|
||||
media_type: base64Data.mime.replace('jpg', 'jpeg') as any,
|
||||
type: 'base64'
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
// Get and process file blocks
|
||||
const fileBlocks = findFileBlocks(message)
|
||||
for (const fileBlock of fileBlocks) {
|
||||
const { file } = fileBlock
|
||||
if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) {
|
||||
if (file.ext === '.pdf' && file.size < 32 * 1024 * 1024) {
|
||||
const base64Data = await FileManager.readBase64File(file)
|
||||
parts.push({
|
||||
type: 'document',
|
||||
source: {
|
||||
type: 'base64',
|
||||
media_type: 'application/pdf',
|
||||
data: base64Data
|
||||
}
|
||||
})
|
||||
} else {
|
||||
const fileContent = await (await window.api.file.read(file.id + file.ext)).trim()
|
||||
parts.push({
|
||||
type: 'text',
|
||||
text: file.origin_name + '\n' + fileContent
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
role: message.role === 'system' ? 'user' : message.role,
|
||||
content: parts
|
||||
}
|
||||
}
|
||||
|
||||
public convertMcpToolsToSdkTools(mcpTools: MCPTool[]): ToolUnion[] {
|
||||
return mcpToolsToAnthropicTools(mcpTools)
|
||||
}
|
||||
|
||||
public convertMcpToolResponseToSdkMessageParam(
|
||||
mcpToolResponse: MCPToolResponse,
|
||||
resp: MCPCallToolResponse,
|
||||
model: Model
|
||||
): AnthropicSdkMessageParam | undefined {
|
||||
if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) {
|
||||
return mcpToolCallResponseToAnthropicMessage(mcpToolResponse, resp, model)
|
||||
} else if ('toolCallId' in mcpToolResponse) {
|
||||
return {
|
||||
role: 'user',
|
||||
content: [
|
||||
{
|
||||
type: 'tool_result',
|
||||
tool_use_id: mcpToolResponse.toolCallId!,
|
||||
content: resp.content
|
||||
.map((item) => {
|
||||
if (item.type === 'text') {
|
||||
return {
|
||||
type: 'text',
|
||||
text: item.text || ''
|
||||
} satisfies TextBlockParam
|
||||
}
|
||||
if (item.type === 'image') {
|
||||
return {
|
||||
type: 'image',
|
||||
source: {
|
||||
data: item.data || '',
|
||||
media_type: (item.mimeType || 'image/png') as Base64ImageSource['media_type'],
|
||||
type: 'base64'
|
||||
}
|
||||
} satisfies ImageBlockParam
|
||||
}
|
||||
return
|
||||
})
|
||||
.filter((n) => typeof n !== 'undefined'),
|
||||
is_error: resp.isError
|
||||
} satisfies ToolResultBlockParam
|
||||
]
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Implementing abstract methods from BaseApiClient
|
||||
convertSdkToolCallToMcp(toolCall: ToolUseBlock, mcpTools: MCPTool[]): MCPTool | undefined {
|
||||
// Based on anthropicToolUseToMcpTool logic in AnthropicProvider
|
||||
// This might need adjustment based on how tool calls are specifically handled in the new structure
|
||||
const mcpTool = anthropicToolUseToMcpTool(mcpTools, toolCall)
|
||||
return mcpTool
|
||||
}
|
||||
|
||||
convertSdkToolCallToMcpToolResponse(toolCall: ToolUseBlock, mcpTool: MCPTool): ToolCallResponse {
|
||||
return {
|
||||
id: toolCall.id,
|
||||
toolCallId: toolCall.id,
|
||||
tool: mcpTool,
|
||||
arguments: toolCall.input as Record<string, unknown>,
|
||||
status: 'pending'
|
||||
} as ToolCallResponse
|
||||
}
|
||||
|
||||
override buildSdkMessages(
|
||||
currentReqMessages: AnthropicSdkMessageParam[],
|
||||
output: Anthropic.Message,
|
||||
toolResults: AnthropicSdkMessageParam[]
|
||||
): AnthropicSdkMessageParam[] {
|
||||
const assistantMessage: AnthropicSdkMessageParam = {
|
||||
role: output.role,
|
||||
content: convertContentBlocksToParams(output.content)
|
||||
}
|
||||
|
||||
const newMessages: AnthropicSdkMessageParam[] = [...currentReqMessages, assistantMessage]
|
||||
if (toolResults && toolResults.length > 0) {
|
||||
newMessages.push(...toolResults)
|
||||
}
|
||||
return newMessages
|
||||
}
|
||||
|
||||
override estimateMessageTokens(message: AnthropicSdkMessageParam): number {
|
||||
if (typeof message.content === 'string') {
|
||||
return estimateTextTokens(message.content)
|
||||
}
|
||||
return message.content
|
||||
.map((content) => {
|
||||
switch (content.type) {
|
||||
case 'text':
|
||||
return estimateTextTokens(content.text)
|
||||
case 'image':
|
||||
if (content.source.type === 'base64') {
|
||||
return estimateTextTokens(content.source.data)
|
||||
} else {
|
||||
return estimateTextTokens(content.source.url)
|
||||
}
|
||||
case 'tool_use':
|
||||
return estimateTextTokens(JSON.stringify(content.input))
|
||||
case 'tool_result':
|
||||
return estimateTextTokens(JSON.stringify(content.content))
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
})
|
||||
.reduce((acc, curr) => acc + curr, 0)
|
||||
}
|
||||
|
||||
public buildAssistantMessage(message: Anthropic.Message): AnthropicSdkMessageParam {
|
||||
const messageParam: AnthropicSdkMessageParam = {
|
||||
role: message.role,
|
||||
content: convertContentBlocksToParams(message.content)
|
||||
}
|
||||
return messageParam
|
||||
}
|
||||
|
||||
public extractMessagesFromSdkPayload(sdkPayload: AnthropicSdkParams): AnthropicSdkMessageParam[] {
|
||||
return sdkPayload.messages || []
|
||||
}
|
||||
|
||||
/**
|
||||
* Anthropic专用的原始流监听器
|
||||
* 处理MessageStream对象的特定事件
|
||||
*/
|
||||
override attachRawStreamListener(
|
||||
rawOutput: AnthropicSdkRawOutput,
|
||||
listener: RawStreamListener<AnthropicSdkRawChunk>
|
||||
): AnthropicSdkRawOutput {
|
||||
console.log(`[AnthropicApiClient] 附加流监听器到原始输出`)
|
||||
|
||||
// 检查是否为MessageStream
|
||||
if (rawOutput instanceof MessageStream) {
|
||||
console.log(`[AnthropicApiClient] 检测到 Anthropic MessageStream,附加专用监听器`)
|
||||
|
||||
if (listener.onStart) {
|
||||
listener.onStart()
|
||||
}
|
||||
|
||||
if (listener.onChunk) {
|
||||
rawOutput.on('streamEvent', (event: AnthropicSdkRawChunk) => {
|
||||
listener.onChunk!(event)
|
||||
})
|
||||
}
|
||||
|
||||
// 专用的Anthropic事件处理
|
||||
const anthropicListener = listener as AnthropicStreamListener
|
||||
|
||||
if (anthropicListener.onContentBlock) {
|
||||
rawOutput.on('contentBlock', anthropicListener.onContentBlock)
|
||||
}
|
||||
|
||||
if (anthropicListener.onMessage) {
|
||||
rawOutput.on('finalMessage', anthropicListener.onMessage)
|
||||
}
|
||||
|
||||
if (listener.onEnd) {
|
||||
rawOutput.on('end', () => {
|
||||
listener.onEnd!()
|
||||
})
|
||||
}
|
||||
|
||||
if (listener.onError) {
|
||||
rawOutput.on('error', (error: Error) => {
|
||||
listener.onError!(error)
|
||||
})
|
||||
}
|
||||
|
||||
return rawOutput
|
||||
}
|
||||
|
||||
// 对于非MessageStream响应
|
||||
return rawOutput
|
||||
}
|
||||
|
||||
private async getWebSearchParams(model: Model): Promise<WebSearchTool20250305 | undefined> {
|
||||
if (!isWebSearchModel(model)) {
|
||||
return undefined
|
||||
}
|
||||
return {
|
||||
type: 'web_search_20250305',
|
||||
name: 'web_search',
|
||||
max_uses: 5
|
||||
} as WebSearchTool20250305
|
||||
}
|
||||
|
||||
getRequestTransformer(): RequestTransformer<AnthropicSdkParams, AnthropicSdkMessageParam> {
|
||||
return {
|
||||
transform: async (
|
||||
coreRequest,
|
||||
assistant,
|
||||
model,
|
||||
isRecursiveCall,
|
||||
recursiveSdkMessages
|
||||
): Promise<{
|
||||
payload: AnthropicSdkParams
|
||||
messages: AnthropicSdkMessageParam[]
|
||||
metadata: Record<string, any>
|
||||
}> => {
|
||||
const { messages, mcpTools, maxTokens, streamOutput, enableWebSearch } = coreRequest
|
||||
// 1. 处理系统消息
|
||||
let systemPrompt = assistant.prompt
|
||||
|
||||
// 2. 设置工具
|
||||
const { tools } = this.setupToolsConfig({
|
||||
mcpTools: mcpTools,
|
||||
model,
|
||||
enableToolUse: isEnabledToolUse(assistant)
|
||||
})
|
||||
|
||||
if (this.useSystemPromptForTools) {
|
||||
systemPrompt = await buildSystemPrompt(systemPrompt, mcpTools)
|
||||
}
|
||||
|
||||
const systemMessage: TextBlockParam | undefined = systemPrompt
|
||||
? { type: 'text', text: systemPrompt }
|
||||
: undefined
|
||||
|
||||
// 3. 处理用户消息
|
||||
const sdkMessages: AnthropicSdkMessageParam[] = []
|
||||
if (typeof messages === 'string') {
|
||||
sdkMessages.push({ role: 'user', content: messages })
|
||||
} else {
|
||||
const processedMessages = addImageFileToContents(messages)
|
||||
for (const message of processedMessages) {
|
||||
sdkMessages.push(await this.convertMessageToSdkParam(message))
|
||||
}
|
||||
}
|
||||
|
||||
if (enableWebSearch) {
|
||||
const webSearchTool = await this.getWebSearchParams(model)
|
||||
if (webSearchTool) {
|
||||
tools.push(webSearchTool)
|
||||
}
|
||||
}
|
||||
|
||||
const commonParams: MessageCreateParamsBase = {
|
||||
model: model.id,
|
||||
messages:
|
||||
isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0
|
||||
? recursiveSdkMessages
|
||||
: sdkMessages,
|
||||
max_tokens: maxTokens || DEFAULT_MAX_TOKENS,
|
||||
temperature: this.getTemperature(assistant, model),
|
||||
top_p: this.getTopP(assistant, model),
|
||||
system: systemMessage ? [systemMessage] : undefined,
|
||||
thinking: this.getBudgetToken(assistant, model),
|
||||
tools: tools.length > 0 ? tools : undefined,
|
||||
...this.getCustomParameters(assistant)
|
||||
}
|
||||
|
||||
const finalParams: MessageCreateParams = streamOutput
|
||||
? {
|
||||
...commonParams,
|
||||
stream: true
|
||||
}
|
||||
: {
|
||||
...commonParams,
|
||||
stream: false
|
||||
}
|
||||
|
||||
const timeout = this.getTimeout(model)
|
||||
return { payload: finalParams, messages: sdkMessages, metadata: { timeout } }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
getResponseChunkTransformer(): ResponseChunkTransformer<AnthropicSdkRawChunk> {
|
||||
return () => {
|
||||
let accumulatedJson = ''
|
||||
const toolCalls: Record<number, ToolUseBlock> = {}
|
||||
|
||||
return {
|
||||
async transform(rawChunk: AnthropicSdkRawChunk, controller: TransformStreamDefaultController<GenericChunk>) {
|
||||
switch (rawChunk.type) {
|
||||
case 'message': {
|
||||
for (const content of rawChunk.content) {
|
||||
switch (content.type) {
|
||||
case 'text': {
|
||||
controller.enqueue({
|
||||
type: ChunkType.TEXT_DELTA,
|
||||
text: content.text
|
||||
} as TextDeltaChunk)
|
||||
break
|
||||
}
|
||||
case 'tool_use': {
|
||||
toolCalls[0] = content
|
||||
break
|
||||
}
|
||||
case 'thinking': {
|
||||
controller.enqueue({
|
||||
type: ChunkType.THINKING_DELTA,
|
||||
text: content.thinking
|
||||
} as ThinkingDeltaChunk)
|
||||
break
|
||||
}
|
||||
case 'web_search_tool_result': {
|
||||
controller.enqueue({
|
||||
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
|
||||
llm_web_search: {
|
||||
results: content.content,
|
||||
source: WebSearchSource.ANTHROPIC
|
||||
}
|
||||
} as LLMWebSearchCompleteChunk)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
case 'content_block_start': {
|
||||
const contentBlock = rawChunk.content_block
|
||||
switch (contentBlock.type) {
|
||||
case 'server_tool_use': {
|
||||
if (contentBlock.name === 'web_search') {
|
||||
controller.enqueue({
|
||||
type: ChunkType.LLM_WEB_SEARCH_IN_PROGRESS
|
||||
} as LLMWebSearchInProgressChunk)
|
||||
}
|
||||
break
|
||||
}
|
||||
case 'web_search_tool_result': {
|
||||
if (
|
||||
contentBlock.content &&
|
||||
(contentBlock.content as WebSearchToolResultError).type === 'web_search_tool_result_error'
|
||||
) {
|
||||
controller.enqueue({
|
||||
type: ChunkType.ERROR,
|
||||
error: {
|
||||
code: (contentBlock.content as WebSearchToolResultError).error_code,
|
||||
message: (contentBlock.content as WebSearchToolResultError).error_code
|
||||
}
|
||||
} as ErrorChunk)
|
||||
} else {
|
||||
controller.enqueue({
|
||||
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
|
||||
llm_web_search: {
|
||||
results: contentBlock.content as Array<WebSearchResultBlock>,
|
||||
source: WebSearchSource.ANTHROPIC
|
||||
}
|
||||
} as LLMWebSearchCompleteChunk)
|
||||
}
|
||||
break
|
||||
}
|
||||
case 'tool_use': {
|
||||
toolCalls[rawChunk.index] = contentBlock
|
||||
break
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
case 'content_block_delta': {
|
||||
const messageDelta = rawChunk.delta
|
||||
switch (messageDelta.type) {
|
||||
case 'text_delta': {
|
||||
if (messageDelta.text) {
|
||||
controller.enqueue({
|
||||
type: ChunkType.TEXT_DELTA,
|
||||
text: messageDelta.text
|
||||
} as TextDeltaChunk)
|
||||
}
|
||||
break
|
||||
}
|
||||
case 'thinking_delta': {
|
||||
if (messageDelta.thinking) {
|
||||
controller.enqueue({
|
||||
type: ChunkType.THINKING_DELTA,
|
||||
text: messageDelta.thinking
|
||||
} as ThinkingDeltaChunk)
|
||||
}
|
||||
break
|
||||
}
|
||||
case 'input_json_delta': {
|
||||
if (messageDelta.partial_json) {
|
||||
accumulatedJson += messageDelta.partial_json
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
case 'content_block_stop': {
|
||||
const toolCall = toolCalls[rawChunk.index]
|
||||
if (toolCall) {
|
||||
try {
|
||||
toolCall.input = JSON.parse(accumulatedJson)
|
||||
Logger.debug(`Tool call id: ${toolCall.id}, accumulated json: ${accumulatedJson}`)
|
||||
controller.enqueue({
|
||||
type: ChunkType.MCP_TOOL_CREATED,
|
||||
tool_calls: [toolCall]
|
||||
} as MCPToolCreatedChunk)
|
||||
} catch (error) {
|
||||
Logger.error(`Error parsing tool call input: ${error}`)
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
case 'message_delta': {
|
||||
controller.enqueue({
|
||||
type: ChunkType.LLM_RESPONSE_COMPLETE,
|
||||
response: {
|
||||
usage: {
|
||||
prompt_tokens: rawChunk.usage.input_tokens || 0,
|
||||
completion_tokens: rawChunk.usage.output_tokens || 0,
|
||||
total_tokens: (rawChunk.usage.input_tokens || 0) + (rawChunk.usage.output_tokens || 0)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 将 ContentBlock 数组转换为 ContentBlockParam 数组
|
||||
* 去除服务器生成的额外字段,只保留发送给API所需的字段
|
||||
*/
|
||||
function convertContentBlocksToParams(contentBlocks: ContentBlock[]): ContentBlockParam[] {
|
||||
return contentBlocks.map((block): ContentBlockParam => {
|
||||
switch (block.type) {
|
||||
case 'text':
|
||||
// TextBlock -> TextBlockParam,去除 citations 等服务器字段
|
||||
return {
|
||||
type: 'text',
|
||||
text: block.text
|
||||
} satisfies TextBlockParam
|
||||
case 'tool_use':
|
||||
// ToolUseBlock -> ToolUseBlockParam
|
||||
return {
|
||||
type: 'tool_use',
|
||||
id: block.id,
|
||||
name: block.name,
|
||||
input: block.input
|
||||
} satisfies ToolUseBlockParam
|
||||
case 'thinking':
|
||||
// ThinkingBlock -> ThinkingBlockParam
|
||||
return {
|
||||
type: 'thinking',
|
||||
thinking: block.thinking,
|
||||
signature: block.signature
|
||||
} satisfies ThinkingBlockParam
|
||||
case 'redacted_thinking':
|
||||
// RedactedThinkingBlock -> RedactedThinkingBlockParam
|
||||
return {
|
||||
type: 'redacted_thinking',
|
||||
data: block.data
|
||||
} satisfies RedactedThinkingBlockParam
|
||||
case 'server_tool_use':
|
||||
// ServerToolUseBlock -> ServerToolUseBlockParam
|
||||
return {
|
||||
type: 'server_tool_use',
|
||||
id: block.id,
|
||||
name: block.name,
|
||||
input: block.input
|
||||
} satisfies ServerToolUseBlockParam
|
||||
case 'web_search_tool_result':
|
||||
// WebSearchToolResultBlock -> WebSearchToolResultBlockParam
|
||||
return {
|
||||
type: 'web_search_tool_result',
|
||||
tool_use_id: block.tool_use_id,
|
||||
content: block.content
|
||||
} satisfies WebSearchToolResultBlockParam
|
||||
default:
|
||||
return block as ContentBlockParam
|
||||
}
|
||||
})
|
||||
}
|
||||
781
src/renderer/src/aiCore/clients/gemini/GeminiAPIClient.ts
Normal file
781
src/renderer/src/aiCore/clients/gemini/GeminiAPIClient.ts
Normal file
@ -0,0 +1,781 @@
|
||||
import {
|
||||
Content,
|
||||
File,
|
||||
FileState,
|
||||
FunctionCall,
|
||||
GenerateContentConfig,
|
||||
GenerateImagesConfig,
|
||||
GoogleGenAI,
|
||||
HarmBlockThreshold,
|
||||
HarmCategory,
|
||||
Modality,
|
||||
Model as GeminiModel,
|
||||
Pager,
|
||||
Part,
|
||||
SafetySetting,
|
||||
SendMessageParameters,
|
||||
ThinkingConfig,
|
||||
Tool
|
||||
} from '@google/genai'
|
||||
import { nanoid } from '@reduxjs/toolkit'
|
||||
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
|
||||
import { findTokenLimit, isGeminiReasoningModel, isGemmaModel, isVisionModel } from '@renderer/config/models'
|
||||
import { CacheService } from '@renderer/services/CacheService'
|
||||
import { estimateTextTokens } from '@renderer/services/TokenService'
|
||||
import {
|
||||
Assistant,
|
||||
EFFORT_RATIO,
|
||||
FileType,
|
||||
FileTypes,
|
||||
GenerateImageParams,
|
||||
MCPCallToolResponse,
|
||||
MCPTool,
|
||||
MCPToolResponse,
|
||||
Model,
|
||||
Provider,
|
||||
ToolCallResponse,
|
||||
WebSearchSource
|
||||
} from '@renderer/types'
|
||||
import { ChunkType, LLMWebSearchCompleteChunk } from '@renderer/types/chunk'
|
||||
import { Message } from '@renderer/types/newMessage'
|
||||
import {
|
||||
GeminiOptions,
|
||||
GeminiSdkMessageParam,
|
||||
GeminiSdkParams,
|
||||
GeminiSdkRawChunk,
|
||||
GeminiSdkRawOutput,
|
||||
GeminiSdkToolCall
|
||||
} from '@renderer/types/sdk'
|
||||
import {
|
||||
geminiFunctionCallToMcpTool,
|
||||
isEnabledToolUse,
|
||||
mcpToolCallResponseToGeminiMessage,
|
||||
mcpToolsToGeminiTools
|
||||
} from '@renderer/utils/mcp-tools'
|
||||
import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
|
||||
import { buildSystemPrompt } from '@renderer/utils/prompt'
|
||||
import { MB } from '@shared/config/constant'
|
||||
|
||||
import { BaseApiClient } from '../BaseApiClient'
|
||||
import { RequestTransformer, ResponseChunkTransformer } from '../types'
|
||||
|
||||
export class GeminiAPIClient extends BaseApiClient<
|
||||
GoogleGenAI,
|
||||
GeminiSdkParams,
|
||||
GeminiSdkRawOutput,
|
||||
GeminiSdkRawChunk,
|
||||
GeminiSdkMessageParam,
|
||||
GeminiSdkToolCall,
|
||||
Tool
|
||||
> {
|
||||
constructor(provider: Provider) {
|
||||
super(provider)
|
||||
}
|
||||
|
||||
override async createCompletions(payload: GeminiSdkParams, options?: GeminiOptions): Promise<GeminiSdkRawOutput> {
|
||||
const sdk = await this.getSdkInstance()
|
||||
const { model, history, ...rest } = payload
|
||||
const realPayload: Omit<GeminiSdkParams, 'model'> = {
|
||||
...rest,
|
||||
config: {
|
||||
...rest.config,
|
||||
abortSignal: options?.abortSignal,
|
||||
httpOptions: {
|
||||
...rest.config?.httpOptions,
|
||||
timeout: options?.timeout
|
||||
}
|
||||
}
|
||||
} satisfies SendMessageParameters
|
||||
|
||||
const streamOutput = options?.streamOutput
|
||||
|
||||
const chat = sdk.chats.create({
|
||||
model: model,
|
||||
history: history
|
||||
})
|
||||
|
||||
if (streamOutput) {
|
||||
const stream = chat.sendMessageStream(realPayload)
|
||||
return stream
|
||||
} else {
|
||||
const response = await chat.sendMessage(realPayload)
|
||||
return response
|
||||
}
|
||||
}
|
||||
|
||||
override async generateImage(generateImageParams: GenerateImageParams): Promise<string[]> {
|
||||
const sdk = await this.getSdkInstance()
|
||||
try {
|
||||
const { model, prompt, imageSize, batchSize, signal } = generateImageParams
|
||||
const config: GenerateImagesConfig = {
|
||||
numberOfImages: batchSize,
|
||||
aspectRatio: imageSize,
|
||||
abortSignal: signal,
|
||||
httpOptions: {
|
||||
timeout: 5 * 60 * 1000
|
||||
}
|
||||
}
|
||||
const response = await sdk.models.generateImages({
|
||||
model: model,
|
||||
prompt,
|
||||
config
|
||||
})
|
||||
|
||||
if (!response.generatedImages || response.generatedImages.length === 0) {
|
||||
return []
|
||||
}
|
||||
|
||||
const images = response.generatedImages
|
||||
.filter((image) => image.image?.imageBytes)
|
||||
.map((image) => {
|
||||
const dataPrefix = `data:${image.image?.mimeType || 'image/png'};base64,`
|
||||
return dataPrefix + image.image?.imageBytes
|
||||
})
|
||||
// console.log(response?.generatedImages?.[0]?.image?.imageBytes);
|
||||
return images
|
||||
} catch (error) {
|
||||
console.error('[generateImage] error:', error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
override async getEmbeddingDimensions(model: Model): Promise<number> {
|
||||
const sdk = await this.getSdkInstance()
|
||||
try {
|
||||
const data = await sdk.models.embedContent({
|
||||
model: model.id,
|
||||
contents: [{ role: 'user', parts: [{ text: 'hi' }] }]
|
||||
})
|
||||
return data.embeddings?.[0]?.values?.length || 0
|
||||
} catch (e) {
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
override async listModels(): Promise<GeminiModel[]> {
|
||||
const sdk = await this.getSdkInstance()
|
||||
const response = await sdk.models.list()
|
||||
const models: GeminiModel[] = []
|
||||
for await (const model of response) {
|
||||
models.push(model)
|
||||
}
|
||||
return models
|
||||
}
|
||||
|
||||
override async getSdkInstance() {
|
||||
if (this.sdkInstance) {
|
||||
return this.sdkInstance
|
||||
}
|
||||
|
||||
this.sdkInstance = new GoogleGenAI({
|
||||
vertexai: false,
|
||||
apiKey: this.apiKey,
|
||||
httpOptions: { baseUrl: this.getBaseURL() }
|
||||
})
|
||||
|
||||
return this.sdkInstance
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle a PDF file
|
||||
* @param file - The file
|
||||
* @returns The part
|
||||
*/
|
||||
private async handlePdfFile(file: FileType): Promise<Part> {
|
||||
const smallFileSize = 20 * MB
|
||||
const isSmallFile = file.size < smallFileSize
|
||||
|
||||
if (isSmallFile) {
|
||||
const { data, mimeType } = await this.base64File(file)
|
||||
return {
|
||||
inlineData: {
|
||||
data,
|
||||
mimeType
|
||||
} as Part['inlineData']
|
||||
}
|
||||
}
|
||||
|
||||
// Retrieve file from Gemini uploaded files
|
||||
const fileMetadata: File | undefined = await this.retrieveFile(file)
|
||||
|
||||
if (fileMetadata) {
|
||||
return {
|
||||
fileData: {
|
||||
fileUri: fileMetadata.uri,
|
||||
mimeType: fileMetadata.mimeType
|
||||
} as Part['fileData']
|
||||
}
|
||||
}
|
||||
|
||||
// If file is not found, upload it to Gemini
|
||||
const result = await this.uploadFile(file)
|
||||
|
||||
return {
|
||||
fileData: {
|
||||
fileUri: result.uri,
|
||||
mimeType: result.mimeType
|
||||
} as Part['fileData']
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the message contents
|
||||
* @param message - The message
|
||||
* @returns The message contents
|
||||
*/
|
||||
private async convertMessageToSdkParam(message: Message): Promise<Content> {
|
||||
const role = message.role === 'user' ? 'user' : 'model'
|
||||
const parts: Part[] = [{ text: await this.getMessageContent(message) }]
|
||||
// Add any generated images from previous responses
|
||||
const imageBlocks = findImageBlocks(message)
|
||||
for (const imageBlock of imageBlocks) {
|
||||
if (
|
||||
imageBlock.metadata?.generateImageResponse?.images &&
|
||||
imageBlock.metadata.generateImageResponse.images.length > 0
|
||||
) {
|
||||
for (const imageUrl of imageBlock.metadata.generateImageResponse.images) {
|
||||
if (imageUrl && imageUrl.startsWith('data:')) {
|
||||
// Extract base64 data and mime type from the data URL
|
||||
const matches = imageUrl.match(/^data:(.+);base64,(.*)$/)
|
||||
if (matches && matches.length === 3) {
|
||||
const mimeType = matches[1]
|
||||
const base64Data = matches[2]
|
||||
parts.push({
|
||||
inlineData: {
|
||||
data: base64Data,
|
||||
mimeType: mimeType
|
||||
} as Part['inlineData']
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
const file = imageBlock.file
|
||||
if (file) {
|
||||
const base64Data = await window.api.file.base64Image(file.id + file.ext)
|
||||
parts.push({
|
||||
inlineData: {
|
||||
data: base64Data.base64,
|
||||
mimeType: base64Data.mime
|
||||
} as Part['inlineData']
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const fileBlocks = findFileBlocks(message)
|
||||
for (const fileBlock of fileBlocks) {
|
||||
const file = fileBlock.file
|
||||
if (file.type === FileTypes.IMAGE) {
|
||||
const base64Data = await window.api.file.base64Image(file.id + file.ext)
|
||||
parts.push({
|
||||
inlineData: {
|
||||
data: base64Data.base64,
|
||||
mimeType: base64Data.mime
|
||||
} as Part['inlineData']
|
||||
})
|
||||
}
|
||||
|
||||
if (file.ext === '.pdf') {
|
||||
parts.push(await this.handlePdfFile(file))
|
||||
continue
|
||||
}
|
||||
if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) {
|
||||
const fileContent = await (await window.api.file.read(file.id + file.ext)).trim()
|
||||
parts.push({
|
||||
text: file.origin_name + '\n' + fileContent
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
role,
|
||||
parts: parts
|
||||
}
|
||||
}
|
||||
|
||||
// @ts-ignore unused
|
||||
private async getImageFileContents(message: Message): Promise<Content> {
|
||||
const role = message.role === 'user' ? 'user' : 'model'
|
||||
const content = getMainTextContent(message)
|
||||
const parts: Part[] = [{ text: content }]
|
||||
const imageBlocks = findImageBlocks(message)
|
||||
for (const imageBlock of imageBlocks) {
|
||||
if (
|
||||
imageBlock.metadata?.generateImageResponse?.images &&
|
||||
imageBlock.metadata.generateImageResponse.images.length > 0
|
||||
) {
|
||||
for (const imageUrl of imageBlock.metadata.generateImageResponse.images) {
|
||||
if (imageUrl && imageUrl.startsWith('data:')) {
|
||||
// Extract base64 data and mime type from the data URL
|
||||
const matches = imageUrl.match(/^data:(.+);base64,(.*)$/)
|
||||
if (matches && matches.length === 3) {
|
||||
const mimeType = matches[1]
|
||||
const base64Data = matches[2]
|
||||
parts.push({
|
||||
inlineData: {
|
||||
data: base64Data,
|
||||
mimeType: mimeType
|
||||
} as Part['inlineData']
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
const file = imageBlock.file
|
||||
if (file) {
|
||||
const base64Data = await window.api.file.base64Image(file.id + file.ext)
|
||||
parts.push({
|
||||
inlineData: {
|
||||
data: base64Data.base64,
|
||||
mimeType: base64Data.mime
|
||||
} as Part['inlineData']
|
||||
})
|
||||
}
|
||||
}
|
||||
return {
|
||||
role,
|
||||
parts: parts
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the safety settings
|
||||
* @returns The safety settings
|
||||
*/
|
||||
private getSafetySettings(): SafetySetting[] {
|
||||
const safetyThreshold = 'OFF' as HarmBlockThreshold
|
||||
|
||||
return [
|
||||
{
|
||||
category: HarmCategory.HARM_CATEGORY_HATE_SPEECH,
|
||||
threshold: safetyThreshold
|
||||
},
|
||||
{
|
||||
category: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
|
||||
threshold: safetyThreshold
|
||||
},
|
||||
{
|
||||
category: HarmCategory.HARM_CATEGORY_HARASSMENT,
|
||||
threshold: safetyThreshold
|
||||
},
|
||||
{
|
||||
category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
|
||||
threshold: safetyThreshold
|
||||
},
|
||||
{
|
||||
category: HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY,
|
||||
threshold: HarmBlockThreshold.BLOCK_NONE
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the reasoning effort for the assistant
|
||||
* @param assistant - The assistant
|
||||
* @param model - The model
|
||||
* @returns The reasoning effort
|
||||
*/
|
||||
private getBudgetToken(assistant: Assistant, model: Model) {
|
||||
if (isGeminiReasoningModel(model)) {
|
||||
const reasoningEffort = assistant?.settings?.reasoning_effort
|
||||
const GEMINI_FLASH_MODEL_REGEX = new RegExp('gemini-.*-flash.*$')
|
||||
|
||||
// 如果thinking_budget是undefined,不思考
|
||||
if (reasoningEffort === undefined) {
|
||||
return {
|
||||
thinkingConfig: {
|
||||
includeThoughts: false,
|
||||
...(GEMINI_FLASH_MODEL_REGEX.test(model.id) ? { thinkingBudget: 0 } : {})
|
||||
} as ThinkingConfig
|
||||
}
|
||||
}
|
||||
|
||||
const effortRatio = EFFORT_RATIO[reasoningEffort]
|
||||
|
||||
if (effortRatio > 1) {
|
||||
return {
|
||||
thinkingConfig: {
|
||||
includeThoughts: true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const { max } = findTokenLimit(model.id) || { max: 0 }
|
||||
const budget = Math.floor(max * effortRatio)
|
||||
|
||||
return {
|
||||
thinkingConfig: {
|
||||
...(budget > 0 ? { thinkingBudget: budget } : {}),
|
||||
includeThoughts: true
|
||||
} as ThinkingConfig
|
||||
}
|
||||
}
|
||||
|
||||
return {}
|
||||
}
|
||||
|
||||
private getGenerateImageParameter(): Partial<GenerateContentConfig> {
|
||||
return {
|
||||
systemInstruction: undefined,
|
||||
responseModalities: [Modality.TEXT, Modality.IMAGE],
|
||||
responseMimeType: 'text/plain'
|
||||
}
|
||||
}
|
||||
|
||||
getRequestTransformer(): RequestTransformer<GeminiSdkParams, GeminiSdkMessageParam> {
|
||||
return {
|
||||
transform: async (
|
||||
coreRequest,
|
||||
assistant,
|
||||
model,
|
||||
isRecursiveCall,
|
||||
recursiveSdkMessages
|
||||
): Promise<{
|
||||
payload: GeminiSdkParams
|
||||
messages: GeminiSdkMessageParam[]
|
||||
metadata: Record<string, any>
|
||||
}> => {
|
||||
const { messages, mcpTools, maxTokens, enableWebSearch, enableGenerateImage } = coreRequest
|
||||
// 1. 处理系统消息
|
||||
let systemInstruction = assistant.prompt
|
||||
|
||||
// 2. 设置工具
|
||||
const { tools } = this.setupToolsConfig({
|
||||
mcpTools,
|
||||
model,
|
||||
enableToolUse: isEnabledToolUse(assistant)
|
||||
})
|
||||
|
||||
if (this.useSystemPromptForTools) {
|
||||
systemInstruction = await buildSystemPrompt(assistant.prompt || '', mcpTools)
|
||||
}
|
||||
|
||||
let messageContents: Content
|
||||
const history: Content[] = []
|
||||
// 3. 处理用户消息
|
||||
if (typeof messages === 'string') {
|
||||
messageContents = {
|
||||
role: 'user',
|
||||
parts: [{ text: messages }]
|
||||
}
|
||||
} else {
|
||||
const userLastMessage = messages.pop()!
|
||||
messageContents = await this.convertMessageToSdkParam(userLastMessage)
|
||||
for (const message of messages) {
|
||||
history.push(await this.convertMessageToSdkParam(message))
|
||||
}
|
||||
}
|
||||
|
||||
if (enableWebSearch) {
|
||||
tools.push({
|
||||
googleSearch: {}
|
||||
})
|
||||
}
|
||||
|
||||
if (isGemmaModel(model) && assistant.prompt) {
|
||||
const isFirstMessage = history.length === 0
|
||||
if (isFirstMessage && messageContents) {
|
||||
const systemMessage = [
|
||||
{
|
||||
text:
|
||||
'<start_of_turn>user\n' +
|
||||
systemInstruction +
|
||||
'<end_of_turn>\n' +
|
||||
'<start_of_turn>user\n' +
|
||||
(messageContents?.parts?.[0] as Part).text +
|
||||
'<end_of_turn>'
|
||||
}
|
||||
] as Part[]
|
||||
if (messageContents && messageContents.parts) {
|
||||
messageContents.parts[0] = systemMessage[0]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const newHistory =
|
||||
isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0
|
||||
? recursiveSdkMessages.slice(0, recursiveSdkMessages.length - 1)
|
||||
: history
|
||||
|
||||
const newMessageContents =
|
||||
isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0
|
||||
? {
|
||||
...messageContents,
|
||||
parts: [
|
||||
...(messageContents.parts || []),
|
||||
...(recursiveSdkMessages[recursiveSdkMessages.length - 1].parts || [])
|
||||
]
|
||||
}
|
||||
: messageContents
|
||||
|
||||
const generateContentConfig: GenerateContentConfig = {
|
||||
safetySettings: this.getSafetySettings(),
|
||||
systemInstruction: isGemmaModel(model) ? undefined : systemInstruction,
|
||||
temperature: this.getTemperature(assistant, model),
|
||||
topP: this.getTopP(assistant, model),
|
||||
maxOutputTokens: maxTokens,
|
||||
tools: tools,
|
||||
...(enableGenerateImage ? this.getGenerateImageParameter() : {}),
|
||||
...this.getBudgetToken(assistant, model),
|
||||
...this.getCustomParameters(assistant)
|
||||
}
|
||||
|
||||
const param: GeminiSdkParams = {
|
||||
model: model.id,
|
||||
config: generateContentConfig,
|
||||
history: newHistory,
|
||||
message: newMessageContents.parts!
|
||||
}
|
||||
|
||||
return {
|
||||
payload: param,
|
||||
messages: [messageContents],
|
||||
metadata: {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
getResponseChunkTransformer(): ResponseChunkTransformer<GeminiSdkRawChunk> {
|
||||
return () => ({
|
||||
async transform(chunk: GeminiSdkRawChunk, controller: TransformStreamDefaultController<GenericChunk>) {
|
||||
let toolCalls: FunctionCall[] = []
|
||||
if (chunk.candidates && chunk.candidates.length > 0) {
|
||||
for (const candidate of chunk.candidates) {
|
||||
if (candidate.content) {
|
||||
candidate.content.parts?.forEach((part) => {
|
||||
const text = part.text || ''
|
||||
if (part.thought) {
|
||||
controller.enqueue({
|
||||
type: ChunkType.THINKING_DELTA,
|
||||
text: text
|
||||
})
|
||||
} else if (part.text) {
|
||||
controller.enqueue({
|
||||
type: ChunkType.TEXT_DELTA,
|
||||
text: text
|
||||
})
|
||||
} else if (part.inlineData) {
|
||||
controller.enqueue({
|
||||
type: ChunkType.IMAGE_COMPLETE,
|
||||
image: {
|
||||
type: 'base64',
|
||||
images: [
|
||||
part.inlineData?.data?.startsWith('data:')
|
||||
? part.inlineData?.data
|
||||
: `data:${part.inlineData?.mimeType || 'image/png'};base64,${part.inlineData?.data}`
|
||||
]
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
if (candidate.finishReason) {
|
||||
if (candidate.groundingMetadata) {
|
||||
controller.enqueue({
|
||||
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
|
||||
llm_web_search: {
|
||||
results: candidate.groundingMetadata,
|
||||
source: WebSearchSource.GEMINI
|
||||
}
|
||||
} as LLMWebSearchCompleteChunk)
|
||||
}
|
||||
if (chunk.functionCalls) {
|
||||
toolCalls = toolCalls.concat(chunk.functionCalls)
|
||||
}
|
||||
controller.enqueue({
|
||||
type: ChunkType.LLM_RESPONSE_COMPLETE,
|
||||
response: {
|
||||
usage: {
|
||||
prompt_tokens: chunk.usageMetadata?.promptTokenCount || 0,
|
||||
completion_tokens:
|
||||
(chunk.usageMetadata?.totalTokenCount || 0) - (chunk.usageMetadata?.promptTokenCount || 0),
|
||||
total_tokens: chunk.usageMetadata?.totalTokenCount || 0
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (toolCalls.length > 0) {
|
||||
controller.enqueue({
|
||||
type: ChunkType.MCP_TOOL_CREATED,
|
||||
tool_calls: toolCalls
|
||||
})
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
public convertMcpToolsToSdkTools(mcpTools: MCPTool[]): Tool[] {
|
||||
return mcpToolsToGeminiTools(mcpTools)
|
||||
}
|
||||
|
||||
public convertSdkToolCallToMcp(toolCall: GeminiSdkToolCall, mcpTools: MCPTool[]): MCPTool | undefined {
|
||||
return geminiFunctionCallToMcpTool(mcpTools, toolCall)
|
||||
}
|
||||
|
||||
public convertSdkToolCallToMcpToolResponse(toolCall: GeminiSdkToolCall, mcpTool: MCPTool): ToolCallResponse {
|
||||
const parsedArgs = (() => {
|
||||
try {
|
||||
return typeof toolCall.args === 'string' ? JSON.parse(toolCall.args) : toolCall.args
|
||||
} catch {
|
||||
return toolCall.args
|
||||
}
|
||||
})()
|
||||
|
||||
return {
|
||||
id: toolCall.id || nanoid(),
|
||||
toolCallId: toolCall.id,
|
||||
tool: mcpTool,
|
||||
arguments: parsedArgs,
|
||||
status: 'pending'
|
||||
} as ToolCallResponse
|
||||
}
|
||||
|
||||
public convertMcpToolResponseToSdkMessageParam(
|
||||
mcpToolResponse: MCPToolResponse,
|
||||
resp: MCPCallToolResponse,
|
||||
model: Model
|
||||
): GeminiSdkMessageParam | undefined {
|
||||
if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) {
|
||||
return mcpToolCallResponseToGeminiMessage(mcpToolResponse, resp, isVisionModel(model))
|
||||
} else if ('toolCallId' in mcpToolResponse) {
|
||||
return {
|
||||
role: 'user',
|
||||
parts: [
|
||||
{
|
||||
functionResponse: {
|
||||
id: mcpToolResponse.toolCallId,
|
||||
name: mcpToolResponse.tool.id,
|
||||
response: {
|
||||
output: !resp.isError ? resp.content : undefined,
|
||||
error: resp.isError ? resp.content : undefined
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
} satisfies Content
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
public buildSdkMessages(
|
||||
currentReqMessages: Content[],
|
||||
output: string,
|
||||
toolResults: Content[],
|
||||
toolCalls: FunctionCall[]
|
||||
): Content[] {
|
||||
const parts: Part[] = []
|
||||
if (output) {
|
||||
parts.push({
|
||||
text: output
|
||||
})
|
||||
}
|
||||
toolCalls.forEach((toolCall) => {
|
||||
parts.push({
|
||||
functionCall: toolCall
|
||||
})
|
||||
})
|
||||
parts.push(
|
||||
...toolResults
|
||||
.map((ts) => ts.parts)
|
||||
.flat()
|
||||
.filter((p) => p !== undefined)
|
||||
)
|
||||
|
||||
const userMessage: Content = {
|
||||
role: 'user',
|
||||
parts: parts
|
||||
}
|
||||
|
||||
return [...currentReqMessages, userMessage]
|
||||
}
|
||||
|
||||
override estimateMessageTokens(message: GeminiSdkMessageParam): number {
|
||||
return (
|
||||
message.parts?.reduce((acc, part) => {
|
||||
if (part.text) {
|
||||
return acc + estimateTextTokens(part.text)
|
||||
}
|
||||
if (part.functionCall) {
|
||||
return acc + estimateTextTokens(JSON.stringify(part.functionCall))
|
||||
}
|
||||
if (part.functionResponse) {
|
||||
return acc + estimateTextTokens(JSON.stringify(part.functionResponse.response))
|
||||
}
|
||||
if (part.inlineData) {
|
||||
return acc + estimateTextTokens(part.inlineData.data || '')
|
||||
}
|
||||
if (part.fileData) {
|
||||
return acc + estimateTextTokens(part.fileData.fileUri || '')
|
||||
}
|
||||
return acc
|
||||
}, 0) || 0
|
||||
)
|
||||
}
|
||||
|
||||
public extractMessagesFromSdkPayload(sdkPayload: GeminiSdkParams): GeminiSdkMessageParam[] {
|
||||
return sdkPayload.history || []
|
||||
}
|
||||
|
||||
private async uploadFile(file: FileType): Promise<File> {
|
||||
return await this.sdkInstance!.files.upload({
|
||||
file: file.path,
|
||||
config: {
|
||||
mimeType: 'application/pdf',
|
||||
name: file.id,
|
||||
displayName: file.origin_name
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
private async base64File(file: FileType) {
|
||||
const { data } = await window.api.file.base64File(file.id + file.ext)
|
||||
return {
|
||||
data,
|
||||
mimeType: 'application/pdf'
|
||||
}
|
||||
}
|
||||
|
||||
private async retrieveFile(file: FileType): Promise<File | undefined> {
|
||||
const cachedResponse = CacheService.get<any>('gemini_file_list')
|
||||
|
||||
if (cachedResponse) {
|
||||
return this.processResponse(cachedResponse, file)
|
||||
}
|
||||
|
||||
const response = await this.sdkInstance!.files.list()
|
||||
CacheService.set('gemini_file_list', response, 3000)
|
||||
|
||||
return this.processResponse(response, file)
|
||||
}
|
||||
|
||||
private async processResponse(response: Pager<File>, file: FileType) {
|
||||
for await (const f of response) {
|
||||
if (f.state === FileState.ACTIVE) {
|
||||
if (f.displayName === file.origin_name && Number(f.sizeBytes) === file.size) {
|
||||
return f
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return undefined
|
||||
}
|
||||
|
||||
// @ts-ignore unused
|
||||
private async listFiles(): Promise<File[]> {
|
||||
const files: File[] = []
|
||||
for await (const f of await this.sdkInstance!.files.list()) {
|
||||
files.push(f)
|
||||
}
|
||||
return files
|
||||
}
|
||||
|
||||
// @ts-ignore unused
|
||||
private async deleteFile(fileId: string) {
|
||||
await this.sdkInstance!.files.delete({ name: fileId })
|
||||
}
|
||||
}
|
||||
6
src/renderer/src/aiCore/clients/index.ts
Normal file
6
src/renderer/src/aiCore/clients/index.ts
Normal file
@ -0,0 +1,6 @@
|
||||
export * from './ApiClientFactory'
|
||||
export * from './BaseApiClient'
|
||||
export * from './types'
|
||||
|
||||
// Export specific clients from subdirectories
|
||||
export * from './openai/OpenAIApiClient'
|
||||
646
src/renderer/src/aiCore/clients/openai/OpenAIApiClient.ts
Normal file
646
src/renderer/src/aiCore/clients/openai/OpenAIApiClient.ts
Normal file
@ -0,0 +1,646 @@
|
||||
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
|
||||
import Logger from '@renderer/config/logger'
|
||||
import {
|
||||
findTokenLimit,
|
||||
getOpenAIWebSearchParams,
|
||||
isReasoningModel,
|
||||
isSupportedReasoningEffortGrokModel,
|
||||
isSupportedReasoningEffortModel,
|
||||
isSupportedReasoningEffortOpenAIModel,
|
||||
isSupportedThinkingTokenClaudeModel,
|
||||
isSupportedThinkingTokenGeminiModel,
|
||||
isSupportedThinkingTokenModel,
|
||||
isSupportedThinkingTokenQwenModel,
|
||||
isVisionModel
|
||||
} from '@renderer/config/models'
|
||||
import { processPostsuffixQwen3Model, processReqMessages } from '@renderer/services/ModelMessageService'
|
||||
import { estimateTextTokens } from '@renderer/services/TokenService'
|
||||
// For Copilot token
|
||||
import {
|
||||
Assistant,
|
||||
EFFORT_RATIO,
|
||||
FileTypes,
|
||||
MCPCallToolResponse,
|
||||
MCPTool,
|
||||
MCPToolResponse,
|
||||
Model,
|
||||
Provider,
|
||||
ToolCallResponse,
|
||||
WebSearchSource
|
||||
} from '@renderer/types'
|
||||
import { ChunkType } from '@renderer/types/chunk'
|
||||
import { Message } from '@renderer/types/newMessage'
|
||||
import {
|
||||
OpenAISdkMessageParam,
|
||||
OpenAISdkParams,
|
||||
OpenAISdkRawChunk,
|
||||
OpenAISdkRawContentSource,
|
||||
OpenAISdkRawOutput,
|
||||
ReasoningEffortOptionalParams
|
||||
} from '@renderer/types/sdk'
|
||||
import { addImageFileToContents } from '@renderer/utils/formats'
|
||||
import {
|
||||
isEnabledToolUse,
|
||||
mcpToolCallResponseToOpenAICompatibleMessage,
|
||||
mcpToolsToOpenAIChatTools,
|
||||
openAIToolsToMcpTool
|
||||
} from '@renderer/utils/mcp-tools'
|
||||
import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find'
|
||||
import { buildSystemPrompt } from '@renderer/utils/prompt'
|
||||
import OpenAI, { AzureOpenAI } from 'openai'
|
||||
import { ChatCompletionContentPart, ChatCompletionContentPartRefusal, ChatCompletionTool } from 'openai/resources'
|
||||
|
||||
import { GenericChunk } from '../../middleware/schemas'
|
||||
import { RequestTransformer, ResponseChunkTransformer, ResponseChunkTransformerContext } from '../types'
|
||||
import { OpenAIBaseClient } from './OpenAIBaseClient'
|
||||
|
||||
export class OpenAIAPIClient extends OpenAIBaseClient<
|
||||
OpenAI | AzureOpenAI,
|
||||
OpenAISdkParams,
|
||||
OpenAISdkRawOutput,
|
||||
OpenAISdkRawChunk,
|
||||
OpenAISdkMessageParam,
|
||||
OpenAI.Chat.Completions.ChatCompletionMessageToolCall,
|
||||
ChatCompletionTool
|
||||
> {
|
||||
constructor(provider: Provider) {
|
||||
super(provider)
|
||||
}
|
||||
|
||||
override async createCompletions(
|
||||
payload: OpenAISdkParams,
|
||||
options?: OpenAI.RequestOptions
|
||||
): Promise<OpenAISdkRawOutput> {
|
||||
const sdk = await this.getSdkInstance()
|
||||
// @ts-ignore - SDK参数可能有额外的字段
|
||||
return await sdk.chat.completions.create(payload, options)
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the reasoning effort for the assistant
|
||||
* @param assistant - The assistant
|
||||
* @param model - The model
|
||||
* @returns The reasoning effort
|
||||
*/
|
||||
// Method for reasoning effort, moved from OpenAIProvider
|
||||
override getReasoningEffort(assistant: Assistant, model: Model): ReasoningEffortOptionalParams {
|
||||
if (this.provider.id === 'groq') {
|
||||
return {}
|
||||
}
|
||||
|
||||
if (!isReasoningModel(model)) {
|
||||
return {}
|
||||
}
|
||||
const reasoningEffort = assistant?.settings?.reasoning_effort
|
||||
if (!reasoningEffort) {
|
||||
if (isSupportedThinkingTokenQwenModel(model)) {
|
||||
return { enable_thinking: false }
|
||||
}
|
||||
|
||||
if (isSupportedThinkingTokenClaudeModel(model)) {
|
||||
return {}
|
||||
}
|
||||
|
||||
if (isSupportedThinkingTokenGeminiModel(model)) {
|
||||
// openrouter没有提供一个不推理的选项,先隐藏
|
||||
if (this.provider.id === 'openrouter') {
|
||||
return { reasoning: { max_tokens: 0, exclude: true } }
|
||||
}
|
||||
return {
|
||||
reasoning_effort: 'none'
|
||||
}
|
||||
}
|
||||
|
||||
return {}
|
||||
}
|
||||
const effortRatio = EFFORT_RATIO[reasoningEffort]
|
||||
const budgetTokens = Math.floor(
|
||||
(findTokenLimit(model.id)?.max! - findTokenLimit(model.id)?.min!) * effortRatio + findTokenLimit(model.id)?.min!
|
||||
)
|
||||
|
||||
// OpenRouter models
|
||||
if (model.provider === 'openrouter') {
|
||||
if (isSupportedReasoningEffortModel(model) || isSupportedThinkingTokenModel(model)) {
|
||||
return {
|
||||
reasoning: {
|
||||
effort: reasoningEffort === 'auto' ? 'medium' : reasoningEffort
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Qwen models
|
||||
if (isSupportedThinkingTokenQwenModel(model)) {
|
||||
return {
|
||||
enable_thinking: true,
|
||||
thinking_budget: budgetTokens
|
||||
}
|
||||
}
|
||||
|
||||
// Grok models
|
||||
if (isSupportedReasoningEffortGrokModel(model)) {
|
||||
return {
|
||||
reasoning_effort: reasoningEffort
|
||||
}
|
||||
}
|
||||
|
||||
// OpenAI models
|
||||
if (isSupportedReasoningEffortOpenAIModel(model) || isSupportedThinkingTokenGeminiModel(model)) {
|
||||
return {
|
||||
reasoning_effort: reasoningEffort
|
||||
}
|
||||
}
|
||||
|
||||
// Claude models
|
||||
if (isSupportedThinkingTokenClaudeModel(model)) {
|
||||
const maxTokens = assistant.settings?.maxTokens
|
||||
return {
|
||||
thinking: {
|
||||
type: 'enabled',
|
||||
budget_tokens: Math.floor(
|
||||
Math.max(1024, Math.min(budgetTokens, (maxTokens || DEFAULT_MAX_TOKENS) * effortRatio))
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Default case: no special thinking settings
|
||||
return {}
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if the provider does not support files
|
||||
* @returns True if the provider does not support files, false otherwise
|
||||
*/
|
||||
private get isNotSupportFiles() {
|
||||
if (this.provider?.isNotSupportArrayContent) {
|
||||
return true
|
||||
}
|
||||
|
||||
const providers = ['deepseek', 'baichuan', 'minimax', 'xirang']
|
||||
|
||||
return providers.includes(this.provider.id)
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the message parameter
|
||||
* @param message - The message
|
||||
* @param model - The model
|
||||
* @returns The message parameter
|
||||
*/
|
||||
public async convertMessageToSdkParam(message: Message, model: Model): Promise<OpenAISdkMessageParam> {
|
||||
const isVision = isVisionModel(model)
|
||||
const content = await this.getMessageContent(message)
|
||||
const fileBlocks = findFileBlocks(message)
|
||||
const imageBlocks = findImageBlocks(message)
|
||||
|
||||
if (fileBlocks.length === 0 && imageBlocks.length === 0) {
|
||||
return {
|
||||
role: message.role === 'system' ? 'user' : message.role,
|
||||
content
|
||||
} as OpenAISdkMessageParam
|
||||
}
|
||||
|
||||
// If the model does not support files, extract the file content
|
||||
if (this.isNotSupportFiles) {
|
||||
const fileContent = await this.extractFileContent(message)
|
||||
|
||||
return {
|
||||
role: message.role === 'system' ? 'user' : message.role,
|
||||
content: content + '\n\n---\n\n' + fileContent
|
||||
} as OpenAISdkMessageParam
|
||||
}
|
||||
|
||||
// If the model supports files, add the file content to the message
|
||||
const parts: ChatCompletionContentPart[] = []
|
||||
|
||||
if (content) {
|
||||
parts.push({ type: 'text', text: content })
|
||||
}
|
||||
|
||||
for (const imageBlock of imageBlocks) {
|
||||
if (isVision) {
|
||||
if (imageBlock.file) {
|
||||
const image = await window.api.file.base64Image(imageBlock.file.id + imageBlock.file.ext)
|
||||
parts.push({ type: 'image_url', image_url: { url: image.data } })
|
||||
} else if (imageBlock.url && imageBlock.url.startsWith('data:')) {
|
||||
parts.push({ type: 'image_url', image_url: { url: imageBlock.url } })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (const fileBlock of fileBlocks) {
|
||||
const file = fileBlock.file
|
||||
if (!file) {
|
||||
continue
|
||||
}
|
||||
|
||||
if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) {
|
||||
const fileContent = await (await window.api.file.read(file.id + file.ext)).trim()
|
||||
parts.push({
|
||||
type: 'text',
|
||||
text: file.origin_name + '\n' + fileContent
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
role: message.role === 'system' ? 'user' : message.role,
|
||||
content: parts
|
||||
} as OpenAISdkMessageParam
|
||||
}
|
||||
|
||||
public convertMcpToolsToSdkTools(mcpTools: MCPTool[]): ChatCompletionTool[] {
|
||||
return mcpToolsToOpenAIChatTools(mcpTools)
|
||||
}
|
||||
|
||||
public convertSdkToolCallToMcp(
|
||||
toolCall: OpenAI.Chat.Completions.ChatCompletionMessageToolCall,
|
||||
mcpTools: MCPTool[]
|
||||
): MCPTool | undefined {
|
||||
return openAIToolsToMcpTool(mcpTools, toolCall)
|
||||
}
|
||||
|
||||
public convertSdkToolCallToMcpToolResponse(
|
||||
toolCall: OpenAI.Chat.Completions.ChatCompletionMessageToolCall,
|
||||
mcpTool: MCPTool
|
||||
): ToolCallResponse {
|
||||
let parsedArgs: any
|
||||
try {
|
||||
parsedArgs = JSON.parse(toolCall.function.arguments)
|
||||
} catch {
|
||||
parsedArgs = toolCall.function.arguments
|
||||
}
|
||||
return {
|
||||
id: toolCall.id,
|
||||
toolCallId: toolCall.id,
|
||||
tool: mcpTool,
|
||||
arguments: parsedArgs,
|
||||
status: 'pending'
|
||||
} as ToolCallResponse
|
||||
}
|
||||
|
||||
public convertMcpToolResponseToSdkMessageParam(
|
||||
mcpToolResponse: MCPToolResponse,
|
||||
resp: MCPCallToolResponse,
|
||||
model: Model
|
||||
): OpenAISdkMessageParam | undefined {
|
||||
if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) {
|
||||
// This case is for Anthropic/Claude like tool usage, OpenAI uses tool_call_id
|
||||
// For OpenAI, we primarily expect toolCallId. This might need adjustment if mixing provider concepts.
|
||||
return mcpToolCallResponseToOpenAICompatibleMessage(mcpToolResponse, resp, isVisionModel(model))
|
||||
} else if ('toolCallId' in mcpToolResponse && mcpToolResponse.toolCallId) {
|
||||
return {
|
||||
role: 'tool',
|
||||
tool_call_id: mcpToolResponse.toolCallId,
|
||||
content: JSON.stringify(resp.content)
|
||||
} as OpenAI.Chat.Completions.ChatCompletionToolMessageParam
|
||||
}
|
||||
return undefined
|
||||
}
|
||||
|
||||
public buildSdkMessages(
|
||||
currentReqMessages: OpenAISdkMessageParam[],
|
||||
output: string,
|
||||
toolResults: OpenAISdkMessageParam[],
|
||||
toolCalls: OpenAI.Chat.Completions.ChatCompletionMessageToolCall[]
|
||||
): OpenAISdkMessageParam[] {
|
||||
const assistantMessage: OpenAISdkMessageParam = {
|
||||
role: 'assistant',
|
||||
content: output,
|
||||
tool_calls: toolCalls.length > 0 ? toolCalls : undefined
|
||||
}
|
||||
const newReqMessages = [...currentReqMessages, assistantMessage, ...toolResults]
|
||||
return newReqMessages
|
||||
}
|
||||
|
||||
override estimateMessageTokens(message: OpenAISdkMessageParam): number {
|
||||
let sum = 0
|
||||
if (typeof message.content === 'string') {
|
||||
sum += estimateTextTokens(message.content)
|
||||
} else if (Array.isArray(message.content)) {
|
||||
sum += (message.content || [])
|
||||
.map((part: ChatCompletionContentPart | ChatCompletionContentPartRefusal) => {
|
||||
switch (part.type) {
|
||||
case 'text':
|
||||
return estimateTextTokens(part.text)
|
||||
case 'image_url':
|
||||
return estimateTextTokens(part.image_url.url)
|
||||
case 'input_audio':
|
||||
return estimateTextTokens(part.input_audio.data)
|
||||
case 'file':
|
||||
return estimateTextTokens(part.file.file_data || '')
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
})
|
||||
.reduce((acc, curr) => acc + curr, 0)
|
||||
}
|
||||
if ('tool_calls' in message && message.tool_calls) {
|
||||
sum += message.tool_calls.reduce((acc, toolCall) => {
|
||||
return acc + estimateTextTokens(JSON.stringify(toolCall.function.arguments))
|
||||
}, 0)
|
||||
}
|
||||
return sum
|
||||
}
|
||||
|
||||
public extractMessagesFromSdkPayload(sdkPayload: OpenAISdkParams): OpenAISdkMessageParam[] {
|
||||
return sdkPayload.messages || []
|
||||
}
|
||||
|
||||
getRequestTransformer(): RequestTransformer<OpenAISdkParams, OpenAISdkMessageParam> {
|
||||
return {
|
||||
transform: async (
|
||||
coreRequest,
|
||||
assistant,
|
||||
model,
|
||||
isRecursiveCall,
|
||||
recursiveSdkMessages
|
||||
): Promise<{
|
||||
payload: OpenAISdkParams
|
||||
messages: OpenAISdkMessageParam[]
|
||||
metadata: Record<string, any>
|
||||
}> => {
|
||||
const { messages, mcpTools, maxTokens, streamOutput, enableWebSearch } = coreRequest
|
||||
// 1. 处理系统消息
|
||||
let systemMessage = { role: 'system', content: assistant.prompt || '' }
|
||||
|
||||
if (isSupportedReasoningEffortOpenAIModel(model)) {
|
||||
systemMessage = {
|
||||
role: 'developer',
|
||||
content: `Formatting re-enabled${systemMessage ? '\n' + systemMessage.content : ''}`
|
||||
}
|
||||
}
|
||||
|
||||
if (model.id.includes('o1-mini') || model.id.includes('o1-preview')) {
|
||||
systemMessage.role = 'assistant'
|
||||
}
|
||||
|
||||
// 2. 设置工具(必须在this.usesystemPromptForTools前面)
|
||||
const { tools } = this.setupToolsConfig({
|
||||
mcpTools: mcpTools,
|
||||
model,
|
||||
enableToolUse: isEnabledToolUse(assistant)
|
||||
})
|
||||
|
||||
if (this.useSystemPromptForTools) {
|
||||
systemMessage.content = await buildSystemPrompt(systemMessage.content || '', mcpTools)
|
||||
}
|
||||
|
||||
// 3. 处理用户消息
|
||||
const userMessages: OpenAISdkMessageParam[] = []
|
||||
if (typeof messages === 'string') {
|
||||
userMessages.push({ role: 'user', content: messages })
|
||||
} else {
|
||||
const processedMessages = addImageFileToContents(messages)
|
||||
for (const message of processedMessages) {
|
||||
userMessages.push(await this.convertMessageToSdkParam(message, model))
|
||||
}
|
||||
}
|
||||
|
||||
const lastUserMsg = userMessages.findLast((m) => m.role === 'user')
|
||||
if (lastUserMsg && isSupportedThinkingTokenQwenModel(model)) {
|
||||
const postsuffix = '/no_think'
|
||||
const qwenThinkModeEnabled = assistant.settings?.qwenThinkMode === true
|
||||
const currentContent = lastUserMsg.content
|
||||
|
||||
lastUserMsg.content = processPostsuffixQwen3Model(currentContent, postsuffix, qwenThinkModeEnabled) as any
|
||||
}
|
||||
|
||||
// 4. 最终请求消息
|
||||
let reqMessages: OpenAISdkMessageParam[]
|
||||
if (!systemMessage.content) {
|
||||
reqMessages = [...userMessages]
|
||||
} else {
|
||||
reqMessages = [systemMessage, ...userMessages].filter(Boolean) as OpenAISdkMessageParam[]
|
||||
}
|
||||
|
||||
reqMessages = processReqMessages(model, reqMessages)
|
||||
|
||||
// 5. 创建通用参数
|
||||
const commonParams = {
|
||||
model: model.id,
|
||||
messages:
|
||||
isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0
|
||||
? recursiveSdkMessages
|
||||
: reqMessages,
|
||||
temperature: this.getTemperature(assistant, model),
|
||||
top_p: this.getTopP(assistant, model),
|
||||
max_tokens: maxTokens,
|
||||
tools: tools.length > 0 ? tools : undefined,
|
||||
service_tier: this.getServiceTier(model),
|
||||
...this.getProviderSpecificParameters(assistant, model),
|
||||
...this.getReasoningEffort(assistant, model),
|
||||
...getOpenAIWebSearchParams(model, enableWebSearch),
|
||||
...this.getCustomParameters(assistant)
|
||||
}
|
||||
|
||||
// Create the appropriate parameters object based on whether streaming is enabled
|
||||
const sdkParams: OpenAISdkParams = streamOutput
|
||||
? {
|
||||
...commonParams,
|
||||
stream: true
|
||||
}
|
||||
: {
|
||||
...commonParams,
|
||||
stream: false
|
||||
}
|
||||
|
||||
const timeout = this.getTimeout(model)
|
||||
|
||||
return { payload: sdkParams, messages: reqMessages, metadata: { timeout } }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 在RawSdkChunkToGenericChunkMiddleware中使用
|
||||
getResponseChunkTransformer = (): ResponseChunkTransformer<OpenAISdkRawChunk> => {
|
||||
let hasBeenCollectedWebSearch = false
|
||||
const collectWebSearchData = (
|
||||
chunk: OpenAISdkRawChunk,
|
||||
contentSource: OpenAISdkRawContentSource,
|
||||
context: ResponseChunkTransformerContext
|
||||
) => {
|
||||
if (hasBeenCollectedWebSearch) {
|
||||
return
|
||||
}
|
||||
// OpenAI annotations
|
||||
// @ts-ignore - annotations may not be in standard type definitions
|
||||
const annotations = contentSource.annotations || chunk.annotations
|
||||
if (annotations && annotations.length > 0 && annotations[0].type === 'url_citation') {
|
||||
hasBeenCollectedWebSearch = true
|
||||
return {
|
||||
results: annotations,
|
||||
source: WebSearchSource.OPENAI
|
||||
}
|
||||
}
|
||||
|
||||
// Grok citations
|
||||
// @ts-ignore - citations may not be in standard type definitions
|
||||
if (context.provider?.id === 'grok' && chunk.citations) {
|
||||
hasBeenCollectedWebSearch = true
|
||||
return {
|
||||
// @ts-ignore - citations may not be in standard type definitions
|
||||
results: chunk.citations,
|
||||
source: WebSearchSource.GROK
|
||||
}
|
||||
}
|
||||
|
||||
// Perplexity citations
|
||||
// @ts-ignore - citations may not be in standard type definitions
|
||||
if (context.provider?.id === 'perplexity' && chunk.citations && chunk.citations.length > 0) {
|
||||
hasBeenCollectedWebSearch = true
|
||||
return {
|
||||
// @ts-ignore - citations may not be in standard type definitions
|
||||
results: chunk.citations,
|
||||
source: WebSearchSource.PERPLEXITY
|
||||
}
|
||||
}
|
||||
|
||||
// OpenRouter citations
|
||||
// @ts-ignore - citations may not be in standard type definitions
|
||||
if (context.provider?.id === 'openrouter' && chunk.citations && chunk.citations.length > 0) {
|
||||
hasBeenCollectedWebSearch = true
|
||||
return {
|
||||
// @ts-ignore - citations may not be in standard type definitions
|
||||
results: chunk.citations,
|
||||
source: WebSearchSource.OPENROUTER
|
||||
}
|
||||
}
|
||||
|
||||
// Zhipu web search
|
||||
// @ts-ignore - web_search may not be in standard type definitions
|
||||
if (context.provider?.id === 'zhipu' && chunk.web_search) {
|
||||
hasBeenCollectedWebSearch = true
|
||||
return {
|
||||
// @ts-ignore - web_search may not be in standard type definitions
|
||||
results: chunk.web_search,
|
||||
source: WebSearchSource.ZHIPU
|
||||
}
|
||||
}
|
||||
|
||||
// Hunyuan web search
|
||||
// @ts-ignore - search_info may not be in standard type definitions
|
||||
if (context.provider?.id === 'hunyuan' && chunk.search_info?.search_results) {
|
||||
hasBeenCollectedWebSearch = true
|
||||
return {
|
||||
// @ts-ignore - search_info may not be in standard type definitions
|
||||
results: chunk.search_info.search_results,
|
||||
source: WebSearchSource.HUNYUAN
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: 放到AnthropicApiClient中
|
||||
// // Other providers...
|
||||
// // @ts-ignore - web_search may not be in standard type definitions
|
||||
// if (chunk.web_search) {
|
||||
// const sourceMap: Record<string, string> = {
|
||||
// openai: 'openai',
|
||||
// anthropic: 'anthropic',
|
||||
// qwenlm: 'qwen'
|
||||
// }
|
||||
// const source = sourceMap[context.provider?.id] || 'openai_response'
|
||||
// return {
|
||||
// results: chunk.web_search,
|
||||
// source: source as const
|
||||
// }
|
||||
// }
|
||||
|
||||
return null
|
||||
}
|
||||
const toolCalls: OpenAI.Chat.Completions.ChatCompletionMessageToolCall[] = []
|
||||
return (context: ResponseChunkTransformerContext) => ({
|
||||
async transform(chunk: OpenAISdkRawChunk, controller: TransformStreamDefaultController<GenericChunk>) {
|
||||
// 处理chunk
|
||||
if ('choices' in chunk && chunk.choices && chunk.choices.length > 0) {
|
||||
const choice = chunk.choices[0]
|
||||
|
||||
if (!choice) return
|
||||
|
||||
// 对于流式响应,使用delta;对于非流式响应,使用message
|
||||
const contentSource: OpenAISdkRawContentSource | null =
|
||||
'delta' in choice ? choice.delta : 'message' in choice ? choice.message : null
|
||||
|
||||
if (!contentSource) return
|
||||
|
||||
const webSearchData = collectWebSearchData(chunk, contentSource, context)
|
||||
if (webSearchData) {
|
||||
controller.enqueue({
|
||||
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
|
||||
llm_web_search: webSearchData
|
||||
})
|
||||
}
|
||||
|
||||
// 处理推理内容 (e.g. from OpenRouter DeepSeek-R1)
|
||||
// @ts-ignore - reasoning_content is not in standard OpenAI types but some providers use it
|
||||
const reasoningText = contentSource.reasoning_content || contentSource.reasoning
|
||||
if (reasoningText) {
|
||||
controller.enqueue({
|
||||
type: ChunkType.THINKING_DELTA,
|
||||
text: reasoningText
|
||||
})
|
||||
}
|
||||
|
||||
// 处理文本内容
|
||||
if (contentSource.content) {
|
||||
controller.enqueue({
|
||||
type: ChunkType.TEXT_DELTA,
|
||||
text: contentSource.content
|
||||
})
|
||||
}
|
||||
|
||||
// 处理工具调用
|
||||
if (contentSource.tool_calls) {
|
||||
for (const toolCall of contentSource.tool_calls) {
|
||||
if ('index' in toolCall) {
|
||||
const { id, index, function: fun } = toolCall
|
||||
if (fun?.name) {
|
||||
toolCalls[index] = {
|
||||
id: id || '',
|
||||
function: {
|
||||
name: fun.name,
|
||||
arguments: fun.arguments || ''
|
||||
},
|
||||
type: 'function'
|
||||
}
|
||||
} else if (fun?.arguments) {
|
||||
toolCalls[index].function.arguments += fun.arguments
|
||||
}
|
||||
} else {
|
||||
toolCalls.push(toolCall)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 处理finish_reason,发送流结束信号
|
||||
if ('finish_reason' in choice && choice.finish_reason) {
|
||||
Logger.debug(`[OpenAIApiClient] Stream finished with reason: ${choice.finish_reason}`)
|
||||
if (toolCalls.length > 0) {
|
||||
controller.enqueue({
|
||||
type: ChunkType.MCP_TOOL_CREATED,
|
||||
tool_calls: toolCalls
|
||||
})
|
||||
}
|
||||
const webSearchData = collectWebSearchData(chunk, contentSource, context)
|
||||
if (webSearchData) {
|
||||
controller.enqueue({
|
||||
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
|
||||
llm_web_search: webSearchData
|
||||
})
|
||||
}
|
||||
controller.enqueue({
|
||||
type: ChunkType.LLM_RESPONSE_COMPLETE,
|
||||
response: {
|
||||
usage: {
|
||||
prompt_tokens: chunk.usage?.prompt_tokens || 0,
|
||||
completion_tokens: chunk.usage?.completion_tokens || 0,
|
||||
total_tokens: (chunk.usage?.prompt_tokens || 0) + (chunk.usage?.completion_tokens || 0)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
258
src/renderer/src/aiCore/clients/openai/OpenAIBaseClient.ts
Normal file
258
src/renderer/src/aiCore/clients/openai/OpenAIBaseClient.ts
Normal file
@ -0,0 +1,258 @@
|
||||
import {
|
||||
isClaudeReasoningModel,
|
||||
isNotSupportTemperatureAndTopP,
|
||||
isOpenAIReasoningModel,
|
||||
isSupportedModel,
|
||||
isSupportedReasoningEffortOpenAIModel
|
||||
} from '@renderer/config/models'
|
||||
import { getStoreSetting } from '@renderer/hooks/useSettings'
|
||||
import { getAssistantSettings } from '@renderer/services/AssistantService'
|
||||
import store from '@renderer/store'
|
||||
import { SettingsState } from '@renderer/store/settings'
|
||||
import { Assistant, GenerateImageParams, Model, Provider } from '@renderer/types'
|
||||
import {
|
||||
OpenAIResponseSdkMessageParam,
|
||||
OpenAIResponseSdkParams,
|
||||
OpenAIResponseSdkRawChunk,
|
||||
OpenAIResponseSdkRawOutput,
|
||||
OpenAIResponseSdkTool,
|
||||
OpenAIResponseSdkToolCall,
|
||||
OpenAISdkMessageParam,
|
||||
OpenAISdkParams,
|
||||
OpenAISdkRawChunk,
|
||||
OpenAISdkRawOutput,
|
||||
ReasoningEffortOptionalParams
|
||||
} from '@renderer/types/sdk'
|
||||
import { formatApiHost } from '@renderer/utils/api'
|
||||
import OpenAI, { AzureOpenAI } from 'openai'
|
||||
|
||||
import { BaseApiClient } from '../BaseApiClient'
|
||||
|
||||
/**
|
||||
* 抽象的OpenAI基础客户端类,包含两个OpenAI客户端之间的共享功能
|
||||
*/
|
||||
export abstract class OpenAIBaseClient<
|
||||
TSdkInstance extends OpenAI | AzureOpenAI,
|
||||
TSdkParams extends OpenAISdkParams | OpenAIResponseSdkParams,
|
||||
TRawOutput extends OpenAISdkRawOutput | OpenAIResponseSdkRawOutput,
|
||||
TRawChunk extends OpenAISdkRawChunk | OpenAIResponseSdkRawChunk,
|
||||
TMessageParam extends OpenAISdkMessageParam | OpenAIResponseSdkMessageParam,
|
||||
TToolCall extends OpenAI.Chat.Completions.ChatCompletionMessageToolCall | OpenAIResponseSdkToolCall,
|
||||
TSdkSpecificTool extends OpenAI.Chat.Completions.ChatCompletionTool | OpenAIResponseSdkTool
|
||||
> extends BaseApiClient<TSdkInstance, TSdkParams, TRawOutput, TRawChunk, TMessageParam, TToolCall, TSdkSpecificTool> {
|
||||
constructor(provider: Provider) {
|
||||
super(provider)
|
||||
}
|
||||
|
||||
// 仅适用于openai
|
||||
override getBaseURL(): string {
|
||||
const host = this.provider.apiHost
|
||||
return formatApiHost(host)
|
||||
}
|
||||
|
||||
override async generateImage({
|
||||
model,
|
||||
prompt,
|
||||
negativePrompt,
|
||||
imageSize,
|
||||
batchSize,
|
||||
seed,
|
||||
numInferenceSteps,
|
||||
guidanceScale,
|
||||
signal,
|
||||
promptEnhancement
|
||||
}: GenerateImageParams): Promise<string[]> {
|
||||
const sdk = await this.getSdkInstance()
|
||||
const response = (await sdk.request({
|
||||
method: 'post',
|
||||
path: '/images/generations',
|
||||
signal,
|
||||
body: {
|
||||
model,
|
||||
prompt,
|
||||
negative_prompt: negativePrompt,
|
||||
image_size: imageSize,
|
||||
batch_size: batchSize,
|
||||
seed: seed ? parseInt(seed) : undefined,
|
||||
num_inference_steps: numInferenceSteps,
|
||||
guidance_scale: guidanceScale,
|
||||
prompt_enhancement: promptEnhancement
|
||||
}
|
||||
})) as { data: Array<{ url: string }> }
|
||||
|
||||
return response.data.map((item) => item.url)
|
||||
}
|
||||
|
||||
override async getEmbeddingDimensions(model: Model): Promise<number> {
|
||||
const sdk = await this.getSdkInstance()
|
||||
try {
|
||||
const data = await sdk.embeddings.create({
|
||||
model: model.id,
|
||||
input: model?.provider === 'baidu-cloud' ? ['hi'] : 'hi',
|
||||
encoding_format: 'float'
|
||||
})
|
||||
return data.data[0].embedding.length
|
||||
} catch (e) {
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
override async listModels(): Promise<OpenAI.Models.Model[]> {
|
||||
try {
|
||||
const sdk = await this.getSdkInstance()
|
||||
const response = await sdk.models.list()
|
||||
if (this.provider.id === 'github') {
|
||||
// @ts-ignore key is not typed
|
||||
return response?.body
|
||||
.map((model) => ({
|
||||
id: model.name,
|
||||
description: model.summary,
|
||||
object: 'model',
|
||||
owned_by: model.publisher
|
||||
}))
|
||||
.filter(isSupportedModel)
|
||||
}
|
||||
if (this.provider.id === 'together') {
|
||||
// @ts-ignore key is not typed
|
||||
return response?.body.map((model) => ({
|
||||
id: model.id,
|
||||
description: model.display_name,
|
||||
object: 'model',
|
||||
owned_by: model.organization
|
||||
}))
|
||||
}
|
||||
const models = response.data || []
|
||||
models.forEach((model) => {
|
||||
model.id = model.id.trim()
|
||||
})
|
||||
|
||||
return models.filter(isSupportedModel)
|
||||
} catch (error) {
|
||||
console.error('Error listing models:', error)
|
||||
return []
|
||||
}
|
||||
}
|
||||
|
||||
override async getSdkInstance() {
|
||||
if (this.sdkInstance) {
|
||||
return this.sdkInstance
|
||||
}
|
||||
|
||||
let apiKeyForSdkInstance = this.provider.apiKey
|
||||
|
||||
if (this.provider.id === 'copilot') {
|
||||
const defaultHeaders = store.getState().copilot.defaultHeaders
|
||||
const { token } = await window.api.copilot.getToken(defaultHeaders)
|
||||
// this.provider.apiKey不允许修改
|
||||
// this.provider.apiKey = token
|
||||
apiKeyForSdkInstance = token
|
||||
}
|
||||
|
||||
if (this.provider.id === 'azure-openai' || this.provider.type === 'azure-openai') {
|
||||
this.sdkInstance = new AzureOpenAI({
|
||||
dangerouslyAllowBrowser: true,
|
||||
apiKey: apiKeyForSdkInstance,
|
||||
apiVersion: this.provider.apiVersion,
|
||||
endpoint: this.provider.apiHost
|
||||
}) as TSdkInstance
|
||||
} else {
|
||||
this.sdkInstance = new OpenAI({
|
||||
dangerouslyAllowBrowser: true,
|
||||
apiKey: apiKeyForSdkInstance,
|
||||
baseURL: this.getBaseURL(),
|
||||
defaultHeaders: {
|
||||
...this.defaultHeaders(),
|
||||
...(this.provider.id === 'copilot' ? { 'editor-version': 'vscode/1.97.2' } : {}),
|
||||
...(this.provider.id === 'copilot' ? { 'copilot-vision-request': 'true' } : {})
|
||||
}
|
||||
}) as TSdkInstance
|
||||
}
|
||||
return this.sdkInstance
|
||||
}
|
||||
|
||||
override getTemperature(assistant: Assistant, model: Model): number | undefined {
|
||||
if (
|
||||
isNotSupportTemperatureAndTopP(model) ||
|
||||
(assistant.settings?.reasoning_effort && isClaudeReasoningModel(model))
|
||||
) {
|
||||
return undefined
|
||||
}
|
||||
return assistant.settings?.temperature
|
||||
}
|
||||
|
||||
override getTopP(assistant: Assistant, model: Model): number | undefined {
|
||||
if (
|
||||
isNotSupportTemperatureAndTopP(model) ||
|
||||
(assistant.settings?.reasoning_effort && isClaudeReasoningModel(model))
|
||||
) {
|
||||
return undefined
|
||||
}
|
||||
return assistant.settings?.topP
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the provider specific parameters for the assistant
|
||||
* @param assistant - The assistant
|
||||
* @param model - The model
|
||||
* @returns The provider specific parameters
|
||||
*/
|
||||
protected getProviderSpecificParameters(assistant: Assistant, model: Model) {
|
||||
const { maxTokens } = getAssistantSettings(assistant)
|
||||
|
||||
if (this.provider.id === 'openrouter') {
|
||||
if (model.id.includes('deepseek-r1')) {
|
||||
return {
|
||||
include_reasoning: true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (isOpenAIReasoningModel(model)) {
|
||||
return {
|
||||
max_tokens: undefined,
|
||||
max_completion_tokens: maxTokens
|
||||
}
|
||||
}
|
||||
|
||||
return {}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the reasoning effort for the assistant
|
||||
* @param assistant - The assistant
|
||||
* @param model - The model
|
||||
* @returns The reasoning effort
|
||||
*/
|
||||
protected getReasoningEffort(assistant: Assistant, model: Model): ReasoningEffortOptionalParams {
|
||||
if (!isSupportedReasoningEffortOpenAIModel(model)) {
|
||||
return {}
|
||||
}
|
||||
|
||||
const openAI = getStoreSetting('openAI') as SettingsState['openAI']
|
||||
const summaryText = openAI?.summaryText || 'off'
|
||||
|
||||
let summary: string | undefined = undefined
|
||||
|
||||
if (summaryText === 'off' || model.id.includes('o1-pro')) {
|
||||
summary = undefined
|
||||
} else {
|
||||
summary = summaryText
|
||||
}
|
||||
|
||||
const reasoningEffort = assistant?.settings?.reasoning_effort
|
||||
if (!reasoningEffort) {
|
||||
return {}
|
||||
}
|
||||
|
||||
if (isSupportedReasoningEffortOpenAIModel(model)) {
|
||||
return {
|
||||
reasoning: {
|
||||
effort: reasoningEffort as OpenAI.ReasoningEffort,
|
||||
summary: summary
|
||||
} as OpenAI.Reasoning
|
||||
}
|
||||
}
|
||||
|
||||
return {}
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,532 @@
|
||||
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
|
||||
import {
|
||||
isOpenAIChatCompletionOnlyModel,
|
||||
isSupportedReasoningEffortOpenAIModel,
|
||||
isVisionModel
|
||||
} from '@renderer/config/models'
|
||||
import { estimateTextTokens } from '@renderer/services/TokenService'
|
||||
import {
|
||||
FileTypes,
|
||||
MCPCallToolResponse,
|
||||
MCPTool,
|
||||
MCPToolResponse,
|
||||
Model,
|
||||
Provider,
|
||||
ToolCallResponse,
|
||||
WebSearchSource
|
||||
} from '@renderer/types'
|
||||
import { ChunkType } from '@renderer/types/chunk'
|
||||
import { Message } from '@renderer/types/newMessage'
|
||||
import {
|
||||
OpenAIResponseSdkMessageParam,
|
||||
OpenAIResponseSdkParams,
|
||||
OpenAIResponseSdkRawChunk,
|
||||
OpenAIResponseSdkRawOutput,
|
||||
OpenAIResponseSdkTool,
|
||||
OpenAIResponseSdkToolCall
|
||||
} from '@renderer/types/sdk'
|
||||
import { addImageFileToContents } from '@renderer/utils/formats'
|
||||
import {
|
||||
isEnabledToolUse,
|
||||
mcpToolCallResponseToOpenAIMessage,
|
||||
mcpToolsToOpenAIResponseTools,
|
||||
openAIToolsToMcpTool
|
||||
} from '@renderer/utils/mcp-tools'
|
||||
import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find'
|
||||
import { buildSystemPrompt } from '@renderer/utils/prompt'
|
||||
import { isEmpty } from 'lodash'
|
||||
import OpenAI from 'openai'
|
||||
|
||||
import { RequestTransformer, ResponseChunkTransformer } from '../types'
|
||||
import { OpenAIAPIClient } from './OpenAIApiClient'
|
||||
import { OpenAIBaseClient } from './OpenAIBaseClient'
|
||||
|
||||
export class OpenAIResponseAPIClient extends OpenAIBaseClient<
|
||||
OpenAI,
|
||||
OpenAIResponseSdkParams,
|
||||
OpenAIResponseSdkRawOutput,
|
||||
OpenAIResponseSdkRawChunk,
|
||||
OpenAIResponseSdkMessageParam,
|
||||
OpenAIResponseSdkToolCall,
|
||||
OpenAIResponseSdkTool
|
||||
> {
|
||||
private client: OpenAIAPIClient
|
||||
constructor(provider: Provider) {
|
||||
super(provider)
|
||||
this.client = new OpenAIAPIClient(provider)
|
||||
}
|
||||
|
||||
/**
|
||||
* 根据模型特征选择合适的客户端
|
||||
*/
|
||||
public getClient(model: Model) {
|
||||
if (isOpenAIChatCompletionOnlyModel(model)) {
|
||||
return this.client
|
||||
} else {
|
||||
return this
|
||||
}
|
||||
}
|
||||
|
||||
override async getSdkInstance() {
|
||||
if (this.sdkInstance) {
|
||||
return this.sdkInstance
|
||||
}
|
||||
|
||||
return new OpenAI({
|
||||
dangerouslyAllowBrowser: true,
|
||||
apiKey: this.provider.apiKey,
|
||||
baseURL: this.getBaseURL(),
|
||||
defaultHeaders: {
|
||||
...this.defaultHeaders()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
override async createCompletions(
|
||||
payload: OpenAIResponseSdkParams,
|
||||
options?: OpenAI.RequestOptions
|
||||
): Promise<OpenAIResponseSdkRawOutput> {
|
||||
const sdk = await this.getSdkInstance()
|
||||
return await sdk.responses.create(payload, options)
|
||||
}
|
||||
|
||||
public async convertMessageToSdkParam(message: Message, model: Model): Promise<OpenAIResponseSdkMessageParam> {
|
||||
const isVision = isVisionModel(model)
|
||||
const content = await this.getMessageContent(message)
|
||||
const fileBlocks = findFileBlocks(message)
|
||||
const imageBlocks = findImageBlocks(message)
|
||||
|
||||
if (fileBlocks.length === 0 && imageBlocks.length === 0) {
|
||||
if (message.role === 'assistant') {
|
||||
return {
|
||||
role: 'assistant',
|
||||
content: content
|
||||
}
|
||||
} else {
|
||||
return {
|
||||
role: message.role === 'system' ? 'user' : message.role,
|
||||
content: content ? [{ type: 'input_text', text: content }] : []
|
||||
} as OpenAI.Responses.EasyInputMessage
|
||||
}
|
||||
}
|
||||
|
||||
const parts: OpenAI.Responses.ResponseInputContent[] = []
|
||||
if (content) {
|
||||
parts.push({
|
||||
type: 'input_text',
|
||||
text: content
|
||||
})
|
||||
}
|
||||
|
||||
for (const imageBlock of imageBlocks) {
|
||||
if (isVision) {
|
||||
if (imageBlock.file) {
|
||||
const image = await window.api.file.base64Image(imageBlock.file.id + imageBlock.file.ext)
|
||||
parts.push({
|
||||
detail: 'auto',
|
||||
type: 'input_image',
|
||||
image_url: image.data as string
|
||||
})
|
||||
} else if (imageBlock.url && imageBlock.url.startsWith('data:')) {
|
||||
parts.push({
|
||||
detail: 'auto',
|
||||
type: 'input_image',
|
||||
image_url: imageBlock.url
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (const fileBlock of fileBlocks) {
|
||||
const file = fileBlock.file
|
||||
if (!file) continue
|
||||
|
||||
if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) {
|
||||
const fileContent = (await window.api.file.read(file.id + file.ext)).trim()
|
||||
parts.push({
|
||||
type: 'input_text',
|
||||
text: file.origin_name + '\n' + fileContent
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
role: message.role === 'system' ? 'user' : message.role,
|
||||
content: parts
|
||||
}
|
||||
}
|
||||
|
||||
public convertMcpToolsToSdkTools(mcpTools: MCPTool[]): OpenAI.Responses.Tool[] {
|
||||
return mcpToolsToOpenAIResponseTools(mcpTools)
|
||||
}
|
||||
|
||||
public convertSdkToolCallToMcp(toolCall: OpenAIResponseSdkToolCall, mcpTools: MCPTool[]): MCPTool | undefined {
|
||||
return openAIToolsToMcpTool(mcpTools, toolCall)
|
||||
}
|
||||
public convertSdkToolCallToMcpToolResponse(toolCall: OpenAIResponseSdkToolCall, mcpTool: MCPTool): ToolCallResponse {
|
||||
const parsedArgs = (() => {
|
||||
try {
|
||||
return JSON.parse(toolCall.arguments)
|
||||
} catch {
|
||||
return toolCall.arguments
|
||||
}
|
||||
})()
|
||||
|
||||
return {
|
||||
id: toolCall.call_id,
|
||||
toolCallId: toolCall.call_id,
|
||||
tool: mcpTool,
|
||||
arguments: parsedArgs,
|
||||
status: 'pending'
|
||||
}
|
||||
}
|
||||
|
||||
public convertMcpToolResponseToSdkMessageParam(
|
||||
mcpToolResponse: MCPToolResponse,
|
||||
resp: MCPCallToolResponse,
|
||||
model: Model
|
||||
): OpenAIResponseSdkMessageParam | undefined {
|
||||
if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) {
|
||||
return mcpToolCallResponseToOpenAIMessage(mcpToolResponse, resp, isVisionModel(model))
|
||||
} else if ('toolCallId' in mcpToolResponse && mcpToolResponse.toolCallId) {
|
||||
return {
|
||||
type: 'function_call_output',
|
||||
call_id: mcpToolResponse.toolCallId,
|
||||
output: JSON.stringify(resp.content)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
public buildSdkMessages(
|
||||
currentReqMessages: OpenAIResponseSdkMessageParam[],
|
||||
output: string,
|
||||
toolResults: OpenAIResponseSdkMessageParam[],
|
||||
toolCalls: OpenAIResponseSdkToolCall[]
|
||||
): OpenAIResponseSdkMessageParam[] {
|
||||
const assistantMessage: OpenAIResponseSdkMessageParam = {
|
||||
role: 'assistant',
|
||||
content: [{ type: 'input_text', text: output }]
|
||||
}
|
||||
const newReqMessages = [...currentReqMessages, assistantMessage, ...(toolCalls || []), ...(toolResults || [])]
|
||||
return newReqMessages
|
||||
}
|
||||
|
||||
override estimateMessageTokens(message: OpenAIResponseSdkMessageParam): number {
|
||||
let sum = 0
|
||||
if ('content' in message) {
|
||||
if (typeof message.content === 'string') {
|
||||
sum += estimateTextTokens(message.content)
|
||||
} else if (Array.isArray(message.content)) {
|
||||
for (const part of message.content) {
|
||||
switch (part.type) {
|
||||
case 'input_text':
|
||||
sum += estimateTextTokens(part.text)
|
||||
break
|
||||
case 'input_image':
|
||||
sum += estimateTextTokens(part.image_url || '')
|
||||
break
|
||||
default:
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
switch (message.type) {
|
||||
case 'function_call_output':
|
||||
sum += estimateTextTokens(message.output)
|
||||
break
|
||||
case 'function_call':
|
||||
sum += estimateTextTokens(message.arguments)
|
||||
break
|
||||
default:
|
||||
break
|
||||
}
|
||||
return sum
|
||||
}
|
||||
|
||||
public extractMessagesFromSdkPayload(sdkPayload: OpenAIResponseSdkParams): OpenAIResponseSdkMessageParam[] {
|
||||
if (typeof sdkPayload.input === 'string') {
|
||||
return [{ role: 'user', content: sdkPayload.input }]
|
||||
}
|
||||
return sdkPayload.input
|
||||
}
|
||||
|
||||
getRequestTransformer(): RequestTransformer<OpenAIResponseSdkParams, OpenAIResponseSdkMessageParam> {
|
||||
return {
|
||||
transform: async (
|
||||
coreRequest,
|
||||
assistant,
|
||||
model,
|
||||
isRecursiveCall,
|
||||
recursiveSdkMessages
|
||||
): Promise<{
|
||||
payload: OpenAIResponseSdkParams
|
||||
messages: OpenAIResponseSdkMessageParam[]
|
||||
metadata: Record<string, any>
|
||||
}> => {
|
||||
const { messages, mcpTools, maxTokens, streamOutput, enableWebSearch, enableGenerateImage } = coreRequest
|
||||
// 1. 处理系统消息
|
||||
const systemMessage: OpenAI.Responses.EasyInputMessage = {
|
||||
role: 'system',
|
||||
content: []
|
||||
}
|
||||
|
||||
const systemMessageContent: OpenAI.Responses.ResponseInputMessageContentList = []
|
||||
const systemMessageInput: OpenAI.Responses.ResponseInputText = {
|
||||
text: assistant.prompt || '',
|
||||
type: 'input_text'
|
||||
}
|
||||
if (isSupportedReasoningEffortOpenAIModel(model)) {
|
||||
systemMessage.role = 'developer'
|
||||
}
|
||||
|
||||
// 2. 设置工具
|
||||
let tools: OpenAI.Responses.Tool[] = []
|
||||
const { tools: extraTools } = this.setupToolsConfig({
|
||||
mcpTools: mcpTools,
|
||||
model,
|
||||
enableToolUse: isEnabledToolUse(assistant)
|
||||
})
|
||||
|
||||
if (this.useSystemPromptForTools) {
|
||||
systemMessageInput.text = await buildSystemPrompt(systemMessageInput.text || '', mcpTools)
|
||||
}
|
||||
systemMessageContent.push(systemMessageInput)
|
||||
systemMessage.content = systemMessageContent
|
||||
|
||||
// 3. 处理用户消息
|
||||
let userMessage: OpenAI.Responses.ResponseInputItem[] = []
|
||||
if (typeof messages === 'string') {
|
||||
userMessage.push({ role: 'user', content: messages })
|
||||
} else {
|
||||
const processedMessages = addImageFileToContents(messages)
|
||||
for (const message of processedMessages) {
|
||||
userMessage.push(await this.convertMessageToSdkParam(message, model))
|
||||
}
|
||||
}
|
||||
// FIXME: 最好还是直接使用previous_response_id来处理(或者在数据库中存储image_generation_call的id)
|
||||
if (enableGenerateImage) {
|
||||
const finalAssistantMessage = userMessage.findLast(
|
||||
(m) => (m as OpenAI.Responses.EasyInputMessage).role === 'assistant'
|
||||
) as OpenAI.Responses.EasyInputMessage
|
||||
const finalUserMessage = userMessage.pop() as OpenAI.Responses.EasyInputMessage
|
||||
if (
|
||||
finalAssistantMessage &&
|
||||
Array.isArray(finalAssistantMessage.content) &&
|
||||
finalUserMessage &&
|
||||
Array.isArray(finalUserMessage.content)
|
||||
) {
|
||||
finalAssistantMessage.content = [...finalAssistantMessage.content, ...finalUserMessage.content]
|
||||
}
|
||||
// 这里是故意将上条助手消息的内容(包含图片和文件)作为用户消息发送
|
||||
userMessage = [{ ...finalAssistantMessage, role: 'user' } as OpenAI.Responses.EasyInputMessage]
|
||||
}
|
||||
|
||||
// 4. 最终请求消息
|
||||
let reqMessages: OpenAI.Responses.ResponseInput
|
||||
if (!systemMessage.content) {
|
||||
reqMessages = [...userMessage]
|
||||
} else {
|
||||
reqMessages = [systemMessage, ...userMessage].filter(Boolean) as OpenAI.Responses.EasyInputMessage[]
|
||||
}
|
||||
|
||||
if (enableWebSearch) {
|
||||
tools.push({
|
||||
type: 'web_search_preview'
|
||||
})
|
||||
}
|
||||
|
||||
if (enableGenerateImage) {
|
||||
tools.push({
|
||||
type: 'image_generation',
|
||||
partial_images: streamOutput ? 2 : undefined
|
||||
})
|
||||
}
|
||||
|
||||
const toolChoices: OpenAI.Responses.ToolChoiceTypes = {
|
||||
type: 'web_search_preview'
|
||||
}
|
||||
|
||||
tools = tools.concat(extraTools)
|
||||
const commonParams = {
|
||||
model: model.id,
|
||||
input:
|
||||
isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0
|
||||
? recursiveSdkMessages
|
||||
: reqMessages,
|
||||
temperature: this.getTemperature(assistant, model),
|
||||
top_p: this.getTopP(assistant, model),
|
||||
max_output_tokens: maxTokens,
|
||||
stream: streamOutput,
|
||||
tools: !isEmpty(tools) ? tools : undefined,
|
||||
tool_choice: enableWebSearch ? toolChoices : undefined,
|
||||
service_tier: this.getServiceTier(model),
|
||||
...(this.getReasoningEffort(assistant, model) as OpenAI.Reasoning),
|
||||
...this.getCustomParameters(assistant)
|
||||
}
|
||||
const sdkParams: OpenAIResponseSdkParams = streamOutput
|
||||
? {
|
||||
...commonParams,
|
||||
stream: true
|
||||
}
|
||||
: {
|
||||
...commonParams,
|
||||
stream: false
|
||||
}
|
||||
const timeout = this.getTimeout(model)
|
||||
return { payload: sdkParams, messages: reqMessages, metadata: { timeout } }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
getResponseChunkTransformer(): ResponseChunkTransformer<OpenAIResponseSdkRawChunk> {
|
||||
const toolCalls: OpenAIResponseSdkToolCall[] = []
|
||||
const outputItems: OpenAI.Responses.ResponseOutputItem[] = []
|
||||
return () => ({
|
||||
async transform(chunk: OpenAIResponseSdkRawChunk, controller: TransformStreamDefaultController<GenericChunk>) {
|
||||
// 处理chunk
|
||||
if ('output' in chunk) {
|
||||
for (const output of chunk.output) {
|
||||
switch (output.type) {
|
||||
case 'message':
|
||||
if (output.content[0].type === 'output_text') {
|
||||
controller.enqueue({
|
||||
type: ChunkType.TEXT_DELTA,
|
||||
text: output.content[0].text
|
||||
})
|
||||
if (output.content[0].annotations && output.content[0].annotations.length > 0) {
|
||||
controller.enqueue({
|
||||
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
|
||||
llm_web_search: {
|
||||
source: WebSearchSource.OPENAI_RESPONSE,
|
||||
results: output.content[0].annotations
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
break
|
||||
case 'reasoning':
|
||||
controller.enqueue({
|
||||
type: ChunkType.THINKING_DELTA,
|
||||
text: output.summary.map((s) => s.text).join('\n')
|
||||
})
|
||||
break
|
||||
case 'function_call':
|
||||
toolCalls.push(output)
|
||||
break
|
||||
case 'image_generation_call':
|
||||
controller.enqueue({
|
||||
type: ChunkType.IMAGE_CREATED
|
||||
})
|
||||
controller.enqueue({
|
||||
type: ChunkType.IMAGE_COMPLETE,
|
||||
image: {
|
||||
type: 'base64',
|
||||
images: [`data:image/png;base64,${output.result}`]
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
} else {
|
||||
switch (chunk.type) {
|
||||
case 'response.output_item.added':
|
||||
if (chunk.item.type === 'function_call') {
|
||||
outputItems.push(chunk.item)
|
||||
}
|
||||
break
|
||||
case 'response.reasoning_summary_text.delta':
|
||||
controller.enqueue({
|
||||
type: ChunkType.THINKING_DELTA,
|
||||
text: chunk.delta
|
||||
})
|
||||
break
|
||||
case 'response.image_generation_call.generating':
|
||||
controller.enqueue({
|
||||
type: ChunkType.IMAGE_CREATED
|
||||
})
|
||||
break
|
||||
case 'response.image_generation_call.partial_image':
|
||||
controller.enqueue({
|
||||
type: ChunkType.IMAGE_DELTA,
|
||||
image: {
|
||||
type: 'base64',
|
||||
images: [`data:image/png;base64,${chunk.partial_image_b64}`]
|
||||
}
|
||||
})
|
||||
break
|
||||
case 'response.image_generation_call.completed':
|
||||
controller.enqueue({
|
||||
type: ChunkType.IMAGE_COMPLETE
|
||||
})
|
||||
break
|
||||
case 'response.output_text.delta': {
|
||||
controller.enqueue({
|
||||
type: ChunkType.TEXT_DELTA,
|
||||
text: chunk.delta
|
||||
})
|
||||
break
|
||||
}
|
||||
case 'response.function_call_arguments.done': {
|
||||
const outputItem: OpenAI.Responses.ResponseOutputItem | undefined = outputItems.find(
|
||||
(item) => item.id === chunk.item_id
|
||||
)
|
||||
if (outputItem) {
|
||||
if (outputItem.type === 'function_call') {
|
||||
toolCalls.push({
|
||||
...outputItem,
|
||||
arguments: chunk.arguments
|
||||
})
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
case 'response.content_part.done': {
|
||||
if (chunk.part.type === 'output_text' && chunk.part.annotations && chunk.part.annotations.length > 0) {
|
||||
controller.enqueue({
|
||||
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
|
||||
llm_web_search: {
|
||||
source: WebSearchSource.OPENAI_RESPONSE,
|
||||
results: chunk.part.annotations
|
||||
}
|
||||
})
|
||||
}
|
||||
if (toolCalls.length > 0) {
|
||||
controller.enqueue({
|
||||
type: ChunkType.MCP_TOOL_CREATED,
|
||||
tool_calls: toolCalls
|
||||
})
|
||||
}
|
||||
break
|
||||
}
|
||||
case 'response.completed': {
|
||||
const completion_tokens = chunk.response.usage?.output_tokens || 0
|
||||
const total_tokens = chunk.response.usage?.total_tokens || 0
|
||||
controller.enqueue({
|
||||
type: ChunkType.LLM_RESPONSE_COMPLETE,
|
||||
response: {
|
||||
usage: {
|
||||
prompt_tokens: chunk.response.usage?.input_tokens || 0,
|
||||
completion_tokens: completion_tokens,
|
||||
total_tokens: total_tokens
|
||||
}
|
||||
}
|
||||
})
|
||||
break
|
||||
}
|
||||
case 'error': {
|
||||
controller.enqueue({
|
||||
type: ChunkType.ERROR,
|
||||
error: {
|
||||
message: chunk.message,
|
||||
code: chunk.code
|
||||
}
|
||||
})
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
129
src/renderer/src/aiCore/clients/types.ts
Normal file
129
src/renderer/src/aiCore/clients/types.ts
Normal file
@ -0,0 +1,129 @@
|
||||
import Anthropic from '@anthropic-ai/sdk'
|
||||
import { Assistant, MCPTool, MCPToolResponse, Model, ToolCallResponse } from '@renderer/types'
|
||||
import { Provider } from '@renderer/types'
|
||||
import {
|
||||
AnthropicSdkRawChunk,
|
||||
OpenAISdkRawChunk,
|
||||
SdkMessageParam,
|
||||
SdkParams,
|
||||
SdkRawChunk,
|
||||
SdkRawOutput,
|
||||
SdkTool,
|
||||
SdkToolCall
|
||||
} from '@renderer/types/sdk'
|
||||
import OpenAI from 'openai'
|
||||
|
||||
import { CompletionsParams, GenericChunk } from '../middleware/schemas'
|
||||
|
||||
/**
|
||||
* 原始流监听器接口
|
||||
*/
|
||||
export interface RawStreamListener<TRawChunk = SdkRawChunk> {
|
||||
onChunk?: (chunk: TRawChunk) => void
|
||||
onStart?: () => void
|
||||
onEnd?: () => void
|
||||
onError?: (error: Error) => void
|
||||
}
|
||||
|
||||
/**
|
||||
* OpenAI 专用的流监听器
|
||||
*/
|
||||
export interface OpenAIStreamListener extends RawStreamListener<OpenAISdkRawChunk> {
|
||||
onChoice?: (choice: OpenAI.Chat.Completions.ChatCompletionChunk.Choice) => void
|
||||
onFinishReason?: (reason: string) => void
|
||||
}
|
||||
|
||||
/**
|
||||
* Anthropic 专用的流监听器
|
||||
*/
|
||||
export interface AnthropicStreamListener<TChunk extends AnthropicSdkRawChunk = AnthropicSdkRawChunk>
|
||||
extends RawStreamListener<TChunk> {
|
||||
onContentBlock?: (contentBlock: Anthropic.Messages.ContentBlock) => void
|
||||
onMessage?: (message: Anthropic.Messages.Message) => void
|
||||
}
|
||||
|
||||
/**
|
||||
* 请求转换器接口
|
||||
*/
|
||||
export interface RequestTransformer<
|
||||
TSdkParams extends SdkParams = SdkParams,
|
||||
TMessageParam extends SdkMessageParam = SdkMessageParam
|
||||
> {
|
||||
transform(
|
||||
completionsParams: CompletionsParams,
|
||||
assistant: Assistant,
|
||||
model: Model,
|
||||
isRecursiveCall?: boolean,
|
||||
recursiveSdkMessages?: TMessageParam[]
|
||||
): Promise<{
|
||||
payload: TSdkParams
|
||||
messages: TMessageParam[]
|
||||
metadata?: Record<string, any>
|
||||
}>
|
||||
}
|
||||
|
||||
/**
|
||||
* 响应块转换器接口
|
||||
*/
|
||||
export type ResponseChunkTransformer<TRawChunk extends SdkRawChunk = SdkRawChunk, TContext = any> = (
|
||||
context?: TContext
|
||||
) => Transformer<TRawChunk, GenericChunk>
|
||||
|
||||
export interface ResponseChunkTransformerContext {
|
||||
isStreaming: boolean
|
||||
isEnabledToolCalling: boolean
|
||||
isEnabledWebSearch: boolean
|
||||
isEnabledReasoning: boolean
|
||||
mcpTools: MCPTool[]
|
||||
provider: Provider
|
||||
}
|
||||
|
||||
/**
|
||||
* API客户端接口
|
||||
*/
|
||||
export interface ApiClient<
|
||||
TSdkInstance = any,
|
||||
TSdkParams extends SdkParams = SdkParams,
|
||||
TRawOutput extends SdkRawOutput = SdkRawOutput,
|
||||
TRawChunk extends SdkRawChunk = SdkRawChunk,
|
||||
TMessageParam extends SdkMessageParam = SdkMessageParam,
|
||||
TToolCall extends SdkToolCall = SdkToolCall,
|
||||
TSdkSpecificTool extends SdkTool = SdkTool
|
||||
> {
|
||||
provider: Provider
|
||||
|
||||
// 核心方法 - 在中间件架构中,这个方法可能只是一个占位符
|
||||
// 实际的SDK调用由SdkCallMiddleware处理
|
||||
// completions(params: CompletionsParams): Promise<CompletionsResult>
|
||||
|
||||
createCompletions(payload: TSdkParams): Promise<TRawOutput>
|
||||
|
||||
// SDK相关方法
|
||||
getSdkInstance(): Promise<TSdkInstance> | TSdkInstance
|
||||
getRequestTransformer(): RequestTransformer<TSdkParams, TMessageParam>
|
||||
getResponseChunkTransformer(): ResponseChunkTransformer<TRawChunk>
|
||||
|
||||
// 原始流监听方法
|
||||
attachRawStreamListener?(rawOutput: TRawOutput, listener: RawStreamListener<TRawChunk>): TRawOutput
|
||||
|
||||
// 工具转换相关方法 (保持可选,因为不是所有Provider都支持工具)
|
||||
convertMcpToolsToSdkTools(mcpTools: MCPTool[]): TSdkSpecificTool[]
|
||||
convertMcpToolResponseToSdkMessageParam?(
|
||||
mcpToolResponse: MCPToolResponse,
|
||||
resp: any,
|
||||
model: Model
|
||||
): TMessageParam | undefined
|
||||
convertSdkToolCallToMcp?(toolCall: TToolCall, mcpTools: MCPTool[]): MCPTool | undefined
|
||||
convertSdkToolCallToMcpToolResponse(toolCall: TToolCall, mcpTool: MCPTool): ToolCallResponse
|
||||
|
||||
// 构建SDK特定的消息列表,用于工具调用后的递归调用
|
||||
buildSdkMessages(
|
||||
currentReqMessages: TMessageParam[],
|
||||
output: TRawOutput | string,
|
||||
toolResults: TMessageParam[],
|
||||
toolCalls?: TToolCall[]
|
||||
): TMessageParam[]
|
||||
|
||||
// 从SDK载荷中提取消息数组(用于中间件中的类型安全访问)
|
||||
extractMessagesFromSdkPayload(sdkPayload: TSdkParams): TMessageParam[]
|
||||
}
|
||||
130
src/renderer/src/aiCore/index.ts
Normal file
130
src/renderer/src/aiCore/index.ts
Normal file
@ -0,0 +1,130 @@
|
||||
import { ApiClientFactory } from '@renderer/aiCore/clients/ApiClientFactory'
|
||||
import { BaseApiClient } from '@renderer/aiCore/clients/BaseApiClient'
|
||||
import { isDedicatedImageGenerationModel, isFunctionCallingModel } from '@renderer/config/models'
|
||||
import type { GenerateImageParams, Model, Provider } from '@renderer/types'
|
||||
import { RequestOptions, SdkModel } from '@renderer/types/sdk'
|
||||
import { isEnabledToolUse } from '@renderer/utils/mcp-tools'
|
||||
|
||||
import { OpenAIAPIClient } from './clients'
|
||||
import { AihubmixAPIClient } from './clients/AihubmixAPIClient'
|
||||
import { AnthropicAPIClient } from './clients/anthropic/AnthropicAPIClient'
|
||||
import { OpenAIResponseAPIClient } from './clients/openai/OpenAIResponseAPIClient'
|
||||
import { CompletionsMiddlewareBuilder } from './middleware/builder'
|
||||
import { MIDDLEWARE_NAME as AbortHandlerMiddlewareName } from './middleware/common/AbortHandlerMiddleware'
|
||||
import { MIDDLEWARE_NAME as FinalChunkConsumerMiddlewareName } from './middleware/common/FinalChunkConsumerMiddleware'
|
||||
import { applyCompletionsMiddlewares } from './middleware/composer'
|
||||
import { MIDDLEWARE_NAME as McpToolChunkMiddlewareName } from './middleware/core/McpToolChunkMiddleware'
|
||||
import { MIDDLEWARE_NAME as RawStreamListenerMiddlewareName } from './middleware/core/RawStreamListenerMiddleware'
|
||||
import { MIDDLEWARE_NAME as ThinkChunkMiddlewareName } from './middleware/core/ThinkChunkMiddleware'
|
||||
import { MIDDLEWARE_NAME as WebSearchMiddlewareName } from './middleware/core/WebSearchMiddleware'
|
||||
import { MIDDLEWARE_NAME as ImageGenerationMiddlewareName } from './middleware/feat/ImageGenerationMiddleware'
|
||||
import { MIDDLEWARE_NAME as ThinkingTagExtractionMiddlewareName } from './middleware/feat/ThinkingTagExtractionMiddleware'
|
||||
import { MIDDLEWARE_NAME as ToolUseExtractionMiddlewareName } from './middleware/feat/ToolUseExtractionMiddleware'
|
||||
import { MiddlewareRegistry } from './middleware/register'
|
||||
import { CompletionsParams, CompletionsResult } from './middleware/schemas'
|
||||
|
||||
export default class AiProvider {
|
||||
private apiClient: BaseApiClient
|
||||
|
||||
constructor(provider: Provider) {
|
||||
// Use the new ApiClientFactory to get a BaseApiClient instance
|
||||
this.apiClient = ApiClientFactory.create(provider)
|
||||
}
|
||||
|
||||
public async completions(params: CompletionsParams, options?: RequestOptions): Promise<CompletionsResult> {
|
||||
// 1. 根据模型识别正确的客户端
|
||||
const model = params.assistant.model
|
||||
if (!model) {
|
||||
return Promise.reject(new Error('Model is required'))
|
||||
}
|
||||
|
||||
// 根据client类型选择合适的处理方式
|
||||
let client: BaseApiClient
|
||||
|
||||
if (this.apiClient instanceof AihubmixAPIClient) {
|
||||
// AihubmixAPIClient: 根据模型选择合适的子client
|
||||
client = this.apiClient.getClientForModel(model)
|
||||
if (client instanceof OpenAIResponseAPIClient) {
|
||||
client = client.getClient(model) as BaseApiClient
|
||||
}
|
||||
} else if (this.apiClient instanceof OpenAIResponseAPIClient) {
|
||||
// OpenAIResponseAPIClient: 根据模型特征选择API类型
|
||||
client = this.apiClient.getClient(model) as BaseApiClient
|
||||
} else {
|
||||
// 其他client直接使用
|
||||
client = this.apiClient
|
||||
}
|
||||
|
||||
// 2. 构建中间件链
|
||||
const builder = CompletionsMiddlewareBuilder.withDefaults()
|
||||
// images api
|
||||
if (isDedicatedImageGenerationModel(model)) {
|
||||
builder.clear()
|
||||
builder
|
||||
.add(MiddlewareRegistry[FinalChunkConsumerMiddlewareName])
|
||||
.add(MiddlewareRegistry[AbortHandlerMiddlewareName])
|
||||
.add(MiddlewareRegistry[ImageGenerationMiddlewareName])
|
||||
} else {
|
||||
// Existing logic for other models
|
||||
if (!params.enableReasoning) {
|
||||
builder.remove(ThinkingTagExtractionMiddlewareName)
|
||||
builder.remove(ThinkChunkMiddlewareName)
|
||||
}
|
||||
// 注意:用client判断会导致typescript类型收窄
|
||||
if (!(this.apiClient instanceof OpenAIAPIClient)) {
|
||||
builder.remove(ThinkingTagExtractionMiddlewareName)
|
||||
}
|
||||
if (!(this.apiClient instanceof AnthropicAPIClient)) {
|
||||
builder.remove(RawStreamListenerMiddlewareName)
|
||||
}
|
||||
if (!params.enableWebSearch) {
|
||||
builder.remove(WebSearchMiddlewareName)
|
||||
}
|
||||
if (!params.mcpTools?.length) {
|
||||
builder.remove(ToolUseExtractionMiddlewareName)
|
||||
builder.remove(McpToolChunkMiddlewareName)
|
||||
}
|
||||
if (isEnabledToolUse(params.assistant) && isFunctionCallingModel(model)) {
|
||||
builder.remove(ToolUseExtractionMiddlewareName)
|
||||
}
|
||||
if (params.callType !== 'chat') {
|
||||
builder.remove(AbortHandlerMiddlewareName)
|
||||
}
|
||||
}
|
||||
|
||||
const middlewares = builder.build()
|
||||
|
||||
// 3. Create the wrapped SDK method with middlewares
|
||||
const wrappedCompletionMethod = applyCompletionsMiddlewares(client, client.createCompletions, middlewares)
|
||||
|
||||
// 4. Execute the wrapped method with the original params
|
||||
return wrappedCompletionMethod(params, options)
|
||||
}
|
||||
|
||||
public async models(): Promise<SdkModel[]> {
|
||||
return this.apiClient.listModels()
|
||||
}
|
||||
|
||||
public async getEmbeddingDimensions(model: Model): Promise<number> {
|
||||
try {
|
||||
// Use the SDK instance to test embedding capabilities
|
||||
const dimensions = await this.apiClient.getEmbeddingDimensions(model)
|
||||
return dimensions
|
||||
} catch (error) {
|
||||
console.error('Error getting embedding dimensions:', error)
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
public async generateImage(params: GenerateImageParams): Promise<string[]> {
|
||||
return this.apiClient.generateImage(params)
|
||||
}
|
||||
|
||||
public getBaseURL(): string {
|
||||
return this.apiClient.getBaseURL()
|
||||
}
|
||||
|
||||
public getApiKey(): string {
|
||||
return this.apiClient.getApiKey()
|
||||
}
|
||||
}
|
||||
182
src/renderer/src/aiCore/middleware/BUILDER_USAGE.md
Normal file
182
src/renderer/src/aiCore/middleware/BUILDER_USAGE.md
Normal file
@ -0,0 +1,182 @@
|
||||
# MiddlewareBuilder 使用指南
|
||||
|
||||
`MiddlewareBuilder` 是一个用于动态构建和管理中间件链的工具,提供灵活的中间件组织和配置能力。
|
||||
|
||||
## 主要特性
|
||||
|
||||
### 1. 统一的中间件命名
|
||||
|
||||
所有中间件都通过导出的 `MIDDLEWARE_NAME` 常量标识:
|
||||
|
||||
```typescript
|
||||
// 中间件文件示例
|
||||
export const MIDDLEWARE_NAME = 'SdkCallMiddleware'
|
||||
export const SdkCallMiddleware: CompletionsMiddleware = ...
|
||||
```
|
||||
|
||||
### 2. NamedMiddleware 接口
|
||||
|
||||
中间件使用统一的 `NamedMiddleware` 接口格式:
|
||||
|
||||
```typescript
|
||||
interface NamedMiddleware<TMiddleware = any> {
|
||||
name: string
|
||||
middleware: TMiddleware
|
||||
}
|
||||
```
|
||||
|
||||
### 3. 中间件注册表
|
||||
|
||||
通过 `MiddlewareRegistry` 集中管理所有可用中间件:
|
||||
|
||||
```typescript
|
||||
import { MiddlewareRegistry } from './register'
|
||||
|
||||
// 通过名称获取中间件
|
||||
const sdkCallMiddleware = MiddlewareRegistry['SdkCallMiddleware']
|
||||
```
|
||||
|
||||
## 基本用法
|
||||
|
||||
### 1. 使用默认中间件链
|
||||
|
||||
```typescript
|
||||
import { CompletionsMiddlewareBuilder } from './builder'
|
||||
|
||||
const builder = CompletionsMiddlewareBuilder.withDefaults()
|
||||
const middlewares = builder.build()
|
||||
```
|
||||
|
||||
### 2. 自定义中间件链
|
||||
|
||||
```typescript
|
||||
import { createCompletionsBuilder, MiddlewareRegistry } from './builder'
|
||||
|
||||
const builder = createCompletionsBuilder([
|
||||
MiddlewareRegistry['AbortHandlerMiddleware'],
|
||||
MiddlewareRegistry['TextChunkMiddleware']
|
||||
])
|
||||
|
||||
const middlewares = builder.build()
|
||||
```
|
||||
|
||||
### 3. 动态调整中间件链
|
||||
|
||||
```typescript
|
||||
const builder = CompletionsMiddlewareBuilder.withDefaults()
|
||||
|
||||
// 根据条件添加、移除、替换中间件
|
||||
if (needsLogging) {
|
||||
builder.prepend(MiddlewareRegistry['GenericLoggingMiddleware'])
|
||||
}
|
||||
|
||||
if (disableTools) {
|
||||
builder.remove('McpToolChunkMiddleware')
|
||||
}
|
||||
|
||||
if (customThinking) {
|
||||
builder.replace('ThinkingTagExtractionMiddleware', customThinkingMiddleware)
|
||||
}
|
||||
|
||||
const middlewares = builder.build()
|
||||
```
|
||||
|
||||
### 4. 链式操作
|
||||
|
||||
```typescript
|
||||
const middlewares = CompletionsMiddlewareBuilder.withDefaults()
|
||||
.add(MiddlewareRegistry['CustomMiddleware'])
|
||||
.insertBefore('SdkCallMiddleware', MiddlewareRegistry['SecurityCheckMiddleware'])
|
||||
.remove('WebSearchMiddleware')
|
||||
.build()
|
||||
```
|
||||
|
||||
## API 参考
|
||||
|
||||
### CompletionsMiddlewareBuilder
|
||||
|
||||
**静态方法:**
|
||||
|
||||
- `static withDefaults()`: 创建带有默认中间件链的构建器
|
||||
|
||||
**实例方法:**
|
||||
|
||||
- `add(middleware: NamedMiddleware)`: 在链末尾添加中间件
|
||||
- `prepend(middleware: NamedMiddleware)`: 在链开头添加中间件
|
||||
- `insertAfter(targetName: string, middleware: NamedMiddleware)`: 在指定中间件后插入
|
||||
- `insertBefore(targetName: string, middleware: NamedMiddleware)`: 在指定中间件前插入
|
||||
- `replace(targetName: string, middleware: NamedMiddleware)`: 替换指定中间件
|
||||
- `remove(targetName: string)`: 移除指定中间件
|
||||
- `has(name: string)`: 检查是否包含指定中间件
|
||||
- `build()`: 构建最终的中间件数组
|
||||
- `getChain()`: 获取当前链(包含名称信息)
|
||||
- `clear()`: 清空中间件链
|
||||
- `execute(context, params, middlewareExecutor)`: 直接执行构建好的中间件链
|
||||
|
||||
### 工厂函数
|
||||
|
||||
- `createCompletionsBuilder(baseChain?)`: 创建 Completions 中间件构建器
|
||||
- `createMethodBuilder(baseChain?)`: 创建通用方法中间件构建器
|
||||
- `addMiddlewareName(middleware, name)`: 为中间件添加名称属性的辅助函数
|
||||
|
||||
### 中间件注册表
|
||||
|
||||
- `MiddlewareRegistry`: 所有注册中间件的集中访问点
|
||||
- `getMiddleware(name)`: 根据名称获取中间件
|
||||
- `getRegisteredMiddlewareNames()`: 获取所有注册的中间件名称
|
||||
- `DefaultCompletionsNamedMiddlewares`: 默认的 Completions 中间件链(NamedMiddleware 格式)
|
||||
|
||||
## 类型安全
|
||||
|
||||
构建器提供完整的 TypeScript 类型支持:
|
||||
|
||||
- `CompletionsMiddlewareBuilder` 专门用于 `CompletionsMiddleware` 类型
|
||||
- `MethodMiddlewareBuilder` 用于通用的 `MethodMiddleware` 类型
|
||||
- 所有中间件操作都基于 `NamedMiddleware<TMiddleware>` 接口
|
||||
|
||||
## 默认中间件链
|
||||
|
||||
默认的 Completions 中间件执行顺序:
|
||||
|
||||
1. `FinalChunkConsumerMiddleware` - 最终消费者
|
||||
2. `TransformCoreToSdkParamsMiddleware` - 参数转换
|
||||
3. `AbortHandlerMiddleware` - 中止处理
|
||||
4. `McpToolChunkMiddleware` - 工具处理
|
||||
5. `WebSearchMiddleware` - Web搜索处理
|
||||
6. `TextChunkMiddleware` - 文本处理
|
||||
7. `ThinkingTagExtractionMiddleware` - 思考标签提取处理
|
||||
8. `ThinkChunkMiddleware` - 思考处理
|
||||
9. `ResponseTransformMiddleware` - 响应转换
|
||||
10. `StreamAdapterMiddleware` - 流适配器
|
||||
11. `SdkCallMiddleware` - SDK调用
|
||||
|
||||
## 在 AiProvider 中的使用
|
||||
|
||||
```typescript
|
||||
export default class AiProvider {
|
||||
public async completions(params: CompletionsParams): Promise<CompletionsResult> {
|
||||
// 1. 构建中间件链
|
||||
const builder = CompletionsMiddlewareBuilder.withDefaults()
|
||||
|
||||
// 2. 根据参数动态调整
|
||||
if (params.enableCustomFeature) {
|
||||
builder.insertAfter('StreamAdapterMiddleware', customFeatureMiddleware)
|
||||
}
|
||||
|
||||
// 3. 应用中间件
|
||||
const middlewares = builder.build()
|
||||
const wrappedMethod = applyCompletionsMiddlewares(this.apiClient, this.apiClient.createCompletions, middlewares)
|
||||
|
||||
return wrappedMethod(params)
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## 注意事项
|
||||
|
||||
1. **类型兼容性**:`MethodMiddleware` 和 `CompletionsMiddleware` 不兼容,需要使用对应的构建器
|
||||
2. **中间件名称**:所有中间件必须导出 `MIDDLEWARE_NAME` 常量用于标识
|
||||
3. **注册表管理**:新增中间件需要在 `register.ts` 中注册
|
||||
4. **默认链**:默认链通过 `DefaultCompletionsNamedMiddlewares` 提供,支持延迟加载避免循环依赖
|
||||
|
||||
这种设计使得中间件链的构建既灵活又类型安全,同时保持了简洁的 API 接口。
|
||||
175
src/renderer/src/aiCore/middleware/MIDDLEWARE_SPECIFICATION.md
Normal file
175
src/renderer/src/aiCore/middleware/MIDDLEWARE_SPECIFICATION.md
Normal file
@ -0,0 +1,175 @@
|
||||
# Cherry Studio 中间件规范
|
||||
|
||||
本文档定义了 Cherry Studio `aiCore` 模块中中间件的设计、实现和使用规范。目标是建立一个灵活、可维护且易于扩展的中间件系统。
|
||||
|
||||
## 1. 核心概念
|
||||
|
||||
### 1.1. 中间件 (Middleware)
|
||||
|
||||
中间件是一个函数或对象,它在 AI 请求的处理流程中的特定阶段执行,可以访问和修改请求上下文 (`AiProviderMiddlewareContext`)、请求参数 (`Params`),并控制是否将请求传递给下一个中间件或终止流程。
|
||||
|
||||
每个中间件应该专注于一个单一的横切关注点,例如日志记录、错误处理、流适配、特性解析等。
|
||||
|
||||
### 1.2. `AiProviderMiddlewareContext` (上下文对象)
|
||||
|
||||
这是一个在整个中间件链执行过程中传递的对象,包含以下核心信息:
|
||||
|
||||
- `_apiClientInstance: ApiClient<any,any,any>`: 当前选定的、已实例化的 AI Provider 客户端。
|
||||
- `_coreRequest: CoreRequestType`: 标准化的内部核心请求对象。
|
||||
- `resolvePromise: (value: AggregatedResultType) => void`: 用于在整个操作成功完成时解析 `AiCoreService` 返回的 Promise。
|
||||
- `rejectPromise: (reason?: any) => void`: 用于在发生错误时拒绝 `AiCoreService` 返回的 Promise。
|
||||
- `onChunk?: (chunk: Chunk) => void`: 应用层提供的流式数据块回调。
|
||||
- `abortController?: AbortController`: 用于中止请求的控制器。
|
||||
- 其他中间件可能读写的、与当前请求相关的动态数据。
|
||||
|
||||
### 1.3. `MiddlewareName` (中间件名称)
|
||||
|
||||
为了方便动态操作(如插入、替换、移除)中间件,每个重要的、可能被其他逻辑引用的中间件都应该有一个唯一的、可识别的名称。推荐使用 TypeScript 的 `enum` 来定义:
|
||||
|
||||
```typescript
|
||||
// example
|
||||
export enum MiddlewareName {
|
||||
LOGGING_START = 'LoggingStartMiddleware',
|
||||
LOGGING_END = 'LoggingEndMiddleware',
|
||||
ERROR_HANDLING = 'ErrorHandlingMiddleware',
|
||||
ABORT_HANDLER = 'AbortHandlerMiddleware',
|
||||
// Core Flow
|
||||
TRANSFORM_CORE_TO_SDK_PARAMS = 'TransformCoreToSdkParamsMiddleware',
|
||||
REQUEST_EXECUTION = 'RequestExecutionMiddleware',
|
||||
STREAM_ADAPTER = 'StreamAdapterMiddleware',
|
||||
RAW_SDK_CHUNK_TO_APP_CHUNK = 'RawSdkChunkToAppChunkMiddleware',
|
||||
// Features
|
||||
THINKING_TAG_EXTRACTION = 'ThinkingTagExtractionMiddleware',
|
||||
TOOL_USE_TAG_EXTRACTION = 'ToolUseTagExtractionMiddleware',
|
||||
MCP_TOOL_HANDLER = 'McpToolHandlerMiddleware',
|
||||
// Finalization
|
||||
FINAL_CHUNK_CONSUMER = 'FinalChunkConsumerAndNotifierMiddleware'
|
||||
// Add more as needed
|
||||
}
|
||||
```
|
||||
|
||||
中间件实例需要某种方式暴露其 `MiddlewareName`,例如通过一个 `name` 属性。
|
||||
|
||||
### 1.4. 中间件执行结构
|
||||
|
||||
我们采用一种灵活的中间件执行结构。一个中间件通常是一个函数,它接收 `Context`、`Params`,以及一个 `next` 函数(用于调用链中的下一个中间件)。
|
||||
|
||||
```typescript
|
||||
// 简化形式的中间件函数签名
|
||||
type MiddlewareFunction = (
|
||||
context: AiProviderMiddlewareContext,
|
||||
params: any, // e.g., CompletionsParams
|
||||
next: () => Promise<void> // next 通常返回 Promise 以支持异步操作
|
||||
) => Promise<void> // 中间件自身也可能返回 Promise
|
||||
|
||||
// 或者更经典的 Koa/Express 风格 (三段式)
|
||||
// type MiddlewareFactory = (api?: MiddlewareApi) =>
|
||||
// (nextMiddleware: (ctx: AiProviderMiddlewareContext, params: any) => Promise<void>) =>
|
||||
// (context: AiProviderMiddlewareContext, params: any) => Promise<void>;
|
||||
// 当前设计更倾向于上述简化的 MiddlewareFunction,由 MiddlewareExecutor 负责 next 的编排。
|
||||
```
|
||||
|
||||
`MiddlewareExecutor` (或 `applyMiddlewares`) 会负责管理 `next` 的调用。
|
||||
|
||||
## 2. `MiddlewareBuilder` (通用中间件构建器)
|
||||
|
||||
为了动态构建和管理中间件链,我们引入一个通用的 `MiddlewareBuilder` 类。
|
||||
|
||||
### 2.1. 设计理念
|
||||
|
||||
`MiddlewareBuilder` 提供了一个流式 API,用于以声明式的方式构建中间件链。它允许从一个基础链开始,然后根据特定条件添加、插入、替换或移除中间件。
|
||||
|
||||
### 2.2. API 概览
|
||||
|
||||
```typescript
|
||||
class MiddlewareBuilder {
|
||||
constructor(baseChain?: Middleware[])
|
||||
|
||||
add(middleware: Middleware): this
|
||||
prepend(middleware: Middleware): this
|
||||
insertAfter(targetName: MiddlewareName, middlewareToInsert: Middleware): this
|
||||
insertBefore(targetName: MiddlewareName, middlewareToInsert: Middleware): this
|
||||
replace(targetName: MiddlewareName, newMiddleware: Middleware): this
|
||||
remove(targetName: MiddlewareName): this
|
||||
|
||||
build(): Middleware[] // 返回构建好的中间件数组
|
||||
|
||||
// 可选:直接执行链
|
||||
execute(
|
||||
context: AiProviderMiddlewareContext,
|
||||
params: any,
|
||||
middlewareExecutor: (chain: Middleware[], context: AiProviderMiddlewareContext, params: any) => void
|
||||
): void
|
||||
}
|
||||
```
|
||||
|
||||
### 2.3. 使用示例
|
||||
|
||||
```typescript
|
||||
// 1. 定义一些中间件实例 (假设它们有 .name 属性)
|
||||
const loggingStart = { name: MiddlewareName.LOGGING_START, fn: loggingStartFn }
|
||||
const requestExec = { name: MiddlewareName.REQUEST_EXECUTION, fn: requestExecFn }
|
||||
const streamAdapter = { name: MiddlewareName.STREAM_ADAPTER, fn: streamAdapterFn }
|
||||
const customFeature = { name: MiddlewareName.CUSTOM_FEATURE, fn: customFeatureFn } // 假设自定义
|
||||
|
||||
// 2. 定义一个基础链 (可选)
|
||||
const BASE_CHAIN: Middleware[] = [loggingStart, requestExec, streamAdapter]
|
||||
|
||||
// 3. 使用 MiddlewareBuilder
|
||||
const builder = new MiddlewareBuilder(BASE_CHAIN)
|
||||
|
||||
if (params.needsCustomFeature) {
|
||||
builder.insertAfter(MiddlewareName.STREAM_ADAPTER, customFeature)
|
||||
}
|
||||
|
||||
if (params.isHighSecurityContext) {
|
||||
builder.insertBefore(MiddlewareName.REQUEST_EXECUTION, высокоSecurityCheckMiddleware)
|
||||
}
|
||||
|
||||
if (params.overrideLogging) {
|
||||
builder.replace(MiddlewareName.LOGGING_START, newSpecialLoggingMiddleware)
|
||||
}
|
||||
|
||||
// 4. 获取最终链
|
||||
const finalChain = builder.build()
|
||||
|
||||
// 5. 执行 (通过外部执行器)
|
||||
// middlewareExecutor(finalChain, context, params);
|
||||
// 或者 builder.execute(context, params, middlewareExecutor);
|
||||
```
|
||||
|
||||
## 3. `MiddlewareExecutor` / `applyMiddlewares` (中间件执行器)
|
||||
|
||||
这是负责接收 `MiddlewareBuilder` 构建的中间件链并实际执行它们的组件。
|
||||
|
||||
### 3.1. 职责
|
||||
|
||||
- 接收 `Middleware[]`, `AiProviderMiddlewareContext`, `Params`。
|
||||
- 按顺序迭代中间件。
|
||||
- 为每个中间件提供正确的 `next` 函数,该函数在被调用时会执行链中的下一个中间件。
|
||||
- 处理中间件执行过程中的Promise(如果中间件是异步的)。
|
||||
- 基础的错误捕获(具体错误处理应由链内的 `ErrorHandlingMiddleware` 负责)。
|
||||
|
||||
## 4. 在 `AiCoreService` 中使用
|
||||
|
||||
`AiCoreService` 中的每个核心业务方法 (如 `executeCompletions`) 将负责:
|
||||
|
||||
1. 准备基础数据:实例化 `ApiClient`,转换 `Params` 为 `CoreRequest`。
|
||||
2. 实例化 `MiddlewareBuilder`,可能会传入一个特定于该业务方法的基础中间件链。
|
||||
3. 根据 `Params` 和 `CoreRequest` 中的条件,调用 `MiddlewareBuilder` 的方法来动态调整中间件链。
|
||||
4. 调用 `MiddlewareBuilder.build()` 获取最终的中间件链。
|
||||
5. 创建完整的 `AiProviderMiddlewareContext` (包含 `resolvePromise`, `rejectPromise` 等)。
|
||||
6. 调用 `MiddlewareExecutor` (或 `applyMiddlewares`) 来执行构建好的链。
|
||||
|
||||
## 5. 组合功能
|
||||
|
||||
对于组合功能(例如 "Completions then Translate"):
|
||||
|
||||
- 不推荐创建一个单一、庞大的 `MiddlewareBuilder` 来处理整个组合流程。
|
||||
- 推荐在 `AiCoreService` 中创建一个新的方法,该方法按顺序 `await` 调用底层的原子 `AiCoreService` 方法(例如,先 `await this.executeCompletions(...)`,然后用其结果 `await this.translateText(...)`)。
|
||||
- 每个被调用的原子方法内部会使用其自身的 `MiddlewareBuilder` 实例来构建和执行其特定阶段的中间件链。
|
||||
- 这种方式最大化了复用,并保持了各部分职责的清晰。
|
||||
|
||||
## 6. 中间件命名和发现
|
||||
|
||||
为中间件赋予唯一的 `MiddlewareName` 对于 `MiddlewareBuilder` 的 `insertAfter`, `insertBefore`, `replace`, `remove` 等操作至关重要。确保中间件实例能够以某种方式暴露其名称(例如,一个 `name` 属性)。
|
||||
241
src/renderer/src/aiCore/middleware/builder.ts
Normal file
241
src/renderer/src/aiCore/middleware/builder.ts
Normal file
@ -0,0 +1,241 @@
|
||||
import { DefaultCompletionsNamedMiddlewares } from './register'
|
||||
import { BaseContext, CompletionsMiddleware, MethodMiddleware } from './types'
|
||||
|
||||
/**
|
||||
* 带有名称标识的中间件接口
|
||||
*/
|
||||
export interface NamedMiddleware<TMiddleware = any> {
|
||||
name: string
|
||||
middleware: TMiddleware
|
||||
}
|
||||
|
||||
/**
|
||||
* 中间件执行器函数类型
|
||||
*/
|
||||
export type MiddlewareExecutor<TContext extends BaseContext = BaseContext> = (
|
||||
chain: any[],
|
||||
context: TContext,
|
||||
params: any
|
||||
) => Promise<any>
|
||||
|
||||
/**
|
||||
* 通用中间件构建器类
|
||||
* 提供流式 API 用于动态构建和管理中间件链
|
||||
*
|
||||
* 注意:所有中间件都通过 MiddlewareRegistry 管理,使用 NamedMiddleware 格式
|
||||
*/
|
||||
export class MiddlewareBuilder<TMiddleware = any> {
|
||||
private middlewares: NamedMiddleware<TMiddleware>[]
|
||||
|
||||
/**
|
||||
* 构造函数
|
||||
* @param baseChain - 可选的基础中间件链(NamedMiddleware 格式)
|
||||
*/
|
||||
constructor(baseChain?: NamedMiddleware<TMiddleware>[]) {
|
||||
this.middlewares = baseChain ? [...baseChain] : []
|
||||
}
|
||||
|
||||
/**
|
||||
* 在链的末尾添加中间件
|
||||
* @param middleware - 要添加的具名中间件
|
||||
* @returns this,支持链式调用
|
||||
*/
|
||||
add(middleware: NamedMiddleware<TMiddleware>): this {
|
||||
this.middlewares.push(middleware)
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 在链的开头添加中间件
|
||||
* @param middleware - 要添加的具名中间件
|
||||
* @returns this,支持链式调用
|
||||
*/
|
||||
prepend(middleware: NamedMiddleware<TMiddleware>): this {
|
||||
this.middlewares.unshift(middleware)
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 在指定中间件之后插入新中间件
|
||||
* @param targetName - 目标中间件名称
|
||||
* @param middlewareToInsert - 要插入的具名中间件
|
||||
* @returns this,支持链式调用
|
||||
*/
|
||||
insertAfter(targetName: string, middlewareToInsert: NamedMiddleware<TMiddleware>): this {
|
||||
const index = this.findMiddlewareIndex(targetName)
|
||||
if (index !== -1) {
|
||||
this.middlewares.splice(index + 1, 0, middlewareToInsert)
|
||||
} else {
|
||||
console.warn(`MiddlewareBuilder: 未找到名为 '${targetName}' 的中间件,无法插入`)
|
||||
}
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 在指定中间件之前插入新中间件
|
||||
* @param targetName - 目标中间件名称
|
||||
* @param middlewareToInsert - 要插入的具名中间件
|
||||
* @returns this,支持链式调用
|
||||
*/
|
||||
insertBefore(targetName: string, middlewareToInsert: NamedMiddleware<TMiddleware>): this {
|
||||
const index = this.findMiddlewareIndex(targetName)
|
||||
if (index !== -1) {
|
||||
this.middlewares.splice(index, 0, middlewareToInsert)
|
||||
} else {
|
||||
console.warn(`MiddlewareBuilder: 未找到名为 '${targetName}' 的中间件,无法插入`)
|
||||
}
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 替换指定的中间件
|
||||
* @param targetName - 要替换的中间件名称
|
||||
* @param newMiddleware - 新的具名中间件
|
||||
* @returns this,支持链式调用
|
||||
*/
|
||||
replace(targetName: string, newMiddleware: NamedMiddleware<TMiddleware>): this {
|
||||
const index = this.findMiddlewareIndex(targetName)
|
||||
if (index !== -1) {
|
||||
this.middlewares[index] = newMiddleware
|
||||
} else {
|
||||
console.warn(`MiddlewareBuilder: 未找到名为 '${targetName}' 的中间件,无法替换`)
|
||||
}
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 移除指定的中间件
|
||||
* @param targetName - 要移除的中间件名称
|
||||
* @returns this,支持链式调用
|
||||
*/
|
||||
remove(targetName: string): this {
|
||||
const index = this.findMiddlewareIndex(targetName)
|
||||
if (index !== -1) {
|
||||
this.middlewares.splice(index, 1)
|
||||
}
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建最终的中间件数组
|
||||
* @returns 构建好的中间件数组
|
||||
*/
|
||||
build(): TMiddleware[] {
|
||||
return this.middlewares.map((item) => item.middleware)
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取当前中间件链的副本(包含名称信息)
|
||||
* @returns 当前中间件链的副本
|
||||
*/
|
||||
getChain(): NamedMiddleware<TMiddleware>[] {
|
||||
return [...this.middlewares]
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查是否包含指定名称的中间件
|
||||
* @param name - 中间件名称
|
||||
* @returns 是否包含该中间件
|
||||
*/
|
||||
has(name: string): boolean {
|
||||
return this.findMiddlewareIndex(name) !== -1
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取中间件链的长度
|
||||
* @returns 中间件数量
|
||||
*/
|
||||
get length(): number {
|
||||
return this.middlewares.length
|
||||
}
|
||||
|
||||
/**
|
||||
* 清空中间件链
|
||||
* @returns this,支持链式调用
|
||||
*/
|
||||
clear(): this {
|
||||
this.middlewares = []
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 直接执行构建好的中间件链
|
||||
* @param context - 中间件上下文
|
||||
* @param params - 参数
|
||||
* @param middlewareExecutor - 中间件执行器
|
||||
* @returns 执行结果
|
||||
*/
|
||||
execute<TContext extends BaseContext>(
|
||||
context: TContext,
|
||||
params: any,
|
||||
middlewareExecutor: MiddlewareExecutor<TContext>
|
||||
): Promise<any> {
|
||||
const chain = this.build()
|
||||
return middlewareExecutor(chain, context, params)
|
||||
}
|
||||
|
||||
/**
|
||||
* 查找中间件在链中的索引
|
||||
* @param name - 中间件名称
|
||||
* @returns 索引,如果未找到返回 -1
|
||||
*/
|
||||
private findMiddlewareIndex(name: string): number {
|
||||
return this.middlewares.findIndex((item) => item.name === name)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Completions 中间件构建器
|
||||
*/
|
||||
export class CompletionsMiddlewareBuilder extends MiddlewareBuilder<CompletionsMiddleware> {
|
||||
constructor(baseChain?: NamedMiddleware<CompletionsMiddleware>[]) {
|
||||
super(baseChain)
|
||||
}
|
||||
|
||||
/**
|
||||
* 使用默认的 Completions 中间件链
|
||||
* @returns CompletionsMiddlewareBuilder 实例
|
||||
*/
|
||||
static withDefaults(): CompletionsMiddlewareBuilder {
|
||||
return new CompletionsMiddlewareBuilder(DefaultCompletionsNamedMiddlewares)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 通用方法中间件构建器
|
||||
*/
|
||||
export class MethodMiddlewareBuilder extends MiddlewareBuilder<MethodMiddleware> {
|
||||
constructor(baseChain?: NamedMiddleware<MethodMiddleware>[]) {
|
||||
super(baseChain)
|
||||
}
|
||||
}
|
||||
|
||||
// 便捷的工厂函数
|
||||
|
||||
/**
|
||||
* 创建 Completions 中间件构建器
|
||||
* @param baseChain - 可选的基础链
|
||||
* @returns Completions 中间件构建器实例
|
||||
*/
|
||||
export function createCompletionsBuilder(
|
||||
baseChain?: NamedMiddleware<CompletionsMiddleware>[]
|
||||
): CompletionsMiddlewareBuilder {
|
||||
return new CompletionsMiddlewareBuilder(baseChain)
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建通用方法中间件构建器
|
||||
* @param baseChain - 可选的基础链
|
||||
* @returns 通用方法中间件构建器实例
|
||||
*/
|
||||
export function createMethodBuilder(baseChain?: NamedMiddleware<MethodMiddleware>[]): MethodMiddlewareBuilder {
|
||||
return new MethodMiddlewareBuilder(baseChain)
|
||||
}
|
||||
|
||||
/**
|
||||
* 为中间件添加名称属性的辅助函数
|
||||
* 可以用于给现有的中间件添加名称属性
|
||||
*/
|
||||
export function addMiddlewareName<T extends object>(middleware: T, name: string): T & { MIDDLEWARE_NAME: string } {
|
||||
return Object.assign(middleware, { MIDDLEWARE_NAME: name })
|
||||
}
|
||||
@ -0,0 +1,106 @@
|
||||
import { Chunk, ChunkType, ErrorChunk } from '@renderer/types/chunk'
|
||||
import { addAbortController, removeAbortController } from '@renderer/utils/abortController'
|
||||
|
||||
import { CompletionsParams, CompletionsResult } from '../schemas'
|
||||
import type { CompletionsContext, CompletionsMiddleware } from '../types'
|
||||
|
||||
export const MIDDLEWARE_NAME = 'AbortHandlerMiddleware'
|
||||
|
||||
export const AbortHandlerMiddleware: CompletionsMiddleware =
|
||||
() =>
|
||||
(next) =>
|
||||
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
|
||||
const isRecursiveCall = ctx._internal?.toolProcessingState?.isRecursiveCall || false
|
||||
|
||||
// 在递归调用中,跳过 AbortController 的创建,直接使用已有的
|
||||
if (isRecursiveCall) {
|
||||
const result = await next(ctx, params)
|
||||
return result
|
||||
}
|
||||
|
||||
// 获取当前消息的ID用于abort管理
|
||||
// 优先使用处理过的消息,如果没有则使用原始消息
|
||||
let messageId: string | undefined
|
||||
|
||||
if (typeof params.messages === 'string') {
|
||||
messageId = `message-${Date.now()}-${Math.random().toString(36).substring(2, 9)}`
|
||||
} else {
|
||||
const processedMessages = params.messages
|
||||
const lastUserMessage = processedMessages.findLast((m) => m.role === 'user')
|
||||
messageId = lastUserMessage?.id
|
||||
}
|
||||
|
||||
if (!messageId) {
|
||||
console.warn(`[${MIDDLEWARE_NAME}] No messageId found, abort functionality will not be available.`)
|
||||
return next(ctx, params)
|
||||
}
|
||||
|
||||
const abortController = new AbortController()
|
||||
const abortFn = (): void => abortController.abort()
|
||||
|
||||
addAbortController(messageId, abortFn)
|
||||
|
||||
let abortSignal: AbortSignal | null = abortController.signal
|
||||
|
||||
const cleanup = (): void => {
|
||||
removeAbortController(messageId as string, abortFn)
|
||||
if (ctx._internal?.flowControl) {
|
||||
ctx._internal.flowControl.abortController = undefined
|
||||
ctx._internal.flowControl.abortSignal = undefined
|
||||
ctx._internal.flowControl.cleanup = undefined
|
||||
}
|
||||
abortSignal = null
|
||||
}
|
||||
|
||||
// 将controller添加到_internal中的flowControl状态
|
||||
if (!ctx._internal.flowControl) {
|
||||
ctx._internal.flowControl = {}
|
||||
}
|
||||
ctx._internal.flowControl.abortController = abortController
|
||||
ctx._internal.flowControl.abortSignal = abortSignal
|
||||
ctx._internal.flowControl.cleanup = cleanup
|
||||
|
||||
const result = await next(ctx, params)
|
||||
|
||||
const error = new DOMException('Request was aborted', 'AbortError')
|
||||
|
||||
const streamWithAbortHandler = (result.stream as ReadableStream<Chunk>).pipeThrough(
|
||||
new TransformStream<Chunk, Chunk | ErrorChunk>({
|
||||
transform(chunk, controller) {
|
||||
// 检查 abort 状态
|
||||
if (abortSignal?.aborted) {
|
||||
// 转换为 ErrorChunk
|
||||
const errorChunk: ErrorChunk = {
|
||||
type: ChunkType.ERROR,
|
||||
error
|
||||
}
|
||||
|
||||
controller.enqueue(errorChunk)
|
||||
cleanup()
|
||||
return
|
||||
}
|
||||
|
||||
// 正常传递 chunk
|
||||
controller.enqueue(chunk)
|
||||
},
|
||||
|
||||
flush(controller) {
|
||||
// 在流结束时再次检查 abort 状态
|
||||
if (abortSignal?.aborted) {
|
||||
const errorChunk: ErrorChunk = {
|
||||
type: ChunkType.ERROR,
|
||||
error
|
||||
}
|
||||
controller.enqueue(errorChunk)
|
||||
}
|
||||
// 在流完全处理完成后清理 AbortController
|
||||
cleanup()
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
return {
|
||||
...result,
|
||||
stream: streamWithAbortHandler
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,60 @@
|
||||
import { Chunk } from '@renderer/types/chunk'
|
||||
import { isAbortError } from '@renderer/utils/error'
|
||||
|
||||
import { CompletionsResult } from '../schemas'
|
||||
import { CompletionsContext } from '../types'
|
||||
import { createErrorChunk } from '../utils'
|
||||
|
||||
export const MIDDLEWARE_NAME = 'ErrorHandlerMiddleware'
|
||||
|
||||
/**
|
||||
* 创建一个错误处理中间件。
|
||||
*
|
||||
* 这是一个高阶函数,它接收配置并返回一个标准的中间件。
|
||||
* 它的主要职责是捕获下游中间件或API调用中发生的任何错误。
|
||||
*
|
||||
* @param config - 中间件的配置。
|
||||
* @returns 一个配置好的CompletionsMiddleware。
|
||||
*/
|
||||
export const ErrorHandlerMiddleware =
|
||||
() =>
|
||||
(next) =>
|
||||
async (ctx: CompletionsContext, params): Promise<CompletionsResult> => {
|
||||
const { shouldThrow } = params
|
||||
|
||||
try {
|
||||
// 尝试执行下一个中间件
|
||||
return await next(ctx, params)
|
||||
} catch (error: any) {
|
||||
let errorStream: ReadableStream<Chunk> | undefined
|
||||
// 有些sdk的abort error 是直接抛出的
|
||||
if (!isAbortError(error)) {
|
||||
// 1. 使用通用的工具函数将错误解析为标准格式
|
||||
const errorChunk = createErrorChunk(error)
|
||||
// 2. 调用从外部传入的 onError 回调
|
||||
if (params.onError) {
|
||||
params.onError(error)
|
||||
}
|
||||
|
||||
// 3. 根据配置决定是重新抛出错误,还是将其作为流的一部分向下传递
|
||||
if (shouldThrow) {
|
||||
throw error
|
||||
}
|
||||
|
||||
// 如果不抛出,则创建一个只包含该错误块的流并向下传递
|
||||
errorStream = new ReadableStream<Chunk>({
|
||||
start(controller) {
|
||||
controller.enqueue(errorChunk)
|
||||
controller.close()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
return {
|
||||
rawOutput: undefined,
|
||||
stream: errorStream, // 将包含错误的流传递下去
|
||||
controller: undefined,
|
||||
getText: () => '' // 错误情况下没有文本结果
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,183 @@
|
||||
import Logger from '@renderer/config/logger'
|
||||
import { Usage } from '@renderer/types'
|
||||
import type { Chunk } from '@renderer/types/chunk'
|
||||
import { ChunkType } from '@renderer/types/chunk'
|
||||
|
||||
import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
|
||||
import { CompletionsContext, CompletionsMiddleware } from '../types'
|
||||
|
||||
export const MIDDLEWARE_NAME = 'FinalChunkConsumerAndNotifierMiddleware'
|
||||
|
||||
/**
|
||||
* 最终Chunk消费和通知中间件
|
||||
*
|
||||
* 职责:
|
||||
* 1. 消费所有GenericChunk流中的chunks并转发给onChunk回调
|
||||
* 2. 累加usage/metrics数据(从原始SDK chunks或GenericChunk中提取)
|
||||
* 3. 在检测到LLM_RESPONSE_COMPLETE时发送包含累计数据的BLOCK_COMPLETE
|
||||
* 4. 处理MCP工具调用的多轮请求中的数据累加
|
||||
*/
|
||||
const FinalChunkConsumerMiddleware: CompletionsMiddleware =
|
||||
() =>
|
||||
(next) =>
|
||||
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
|
||||
const isRecursiveCall =
|
||||
params._internal?.toolProcessingState?.isRecursiveCall ||
|
||||
ctx._internal?.toolProcessingState?.isRecursiveCall ||
|
||||
false
|
||||
|
||||
// 初始化累计数据(只在顶层调用时初始化)
|
||||
if (!isRecursiveCall) {
|
||||
if (!ctx._internal.customState) {
|
||||
ctx._internal.customState = {}
|
||||
}
|
||||
ctx._internal.observer = {
|
||||
usage: {
|
||||
prompt_tokens: 0,
|
||||
completion_tokens: 0,
|
||||
total_tokens: 0
|
||||
},
|
||||
metrics: {
|
||||
completion_tokens: 0,
|
||||
time_completion_millsec: 0,
|
||||
time_first_token_millsec: 0,
|
||||
time_thinking_millsec: 0
|
||||
}
|
||||
}
|
||||
// 初始化文本累积器
|
||||
ctx._internal.customState.accumulatedText = ''
|
||||
ctx._internal.customState.startTimestamp = Date.now()
|
||||
}
|
||||
|
||||
// 调用下游中间件
|
||||
const result = await next(ctx, params)
|
||||
|
||||
// 响应后处理:处理GenericChunk流式响应
|
||||
if (result.stream) {
|
||||
const resultFromUpstream = result.stream
|
||||
|
||||
if (resultFromUpstream && resultFromUpstream instanceof ReadableStream) {
|
||||
const reader = resultFromUpstream.getReader()
|
||||
|
||||
try {
|
||||
while (true) {
|
||||
const { done, value: chunk } = await reader.read()
|
||||
if (done) {
|
||||
Logger.debug(`[${MIDDLEWARE_NAME}] Input stream finished.`)
|
||||
break
|
||||
}
|
||||
|
||||
if (chunk) {
|
||||
const genericChunk = chunk as GenericChunk
|
||||
// 提取并累加usage/metrics数据
|
||||
extractAndAccumulateUsageMetrics(ctx, genericChunk)
|
||||
|
||||
const shouldSkipChunk =
|
||||
isRecursiveCall &&
|
||||
(genericChunk.type === ChunkType.BLOCK_COMPLETE ||
|
||||
genericChunk.type === ChunkType.LLM_RESPONSE_COMPLETE)
|
||||
|
||||
if (!shouldSkipChunk) params.onChunk?.(genericChunk)
|
||||
} else {
|
||||
Logger.warn(`[${MIDDLEWARE_NAME}] Received undefined chunk before stream was done.`)
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
Logger.error(`[${MIDDLEWARE_NAME}] Error consuming stream:`, error)
|
||||
throw error
|
||||
} finally {
|
||||
if (params.onChunk && !isRecursiveCall) {
|
||||
params.onChunk({
|
||||
type: ChunkType.BLOCK_COMPLETE,
|
||||
response: {
|
||||
usage: ctx._internal.observer?.usage ? { ...ctx._internal.observer.usage } : undefined,
|
||||
metrics: ctx._internal.observer?.metrics ? { ...ctx._internal.observer.metrics } : undefined
|
||||
}
|
||||
} as Chunk)
|
||||
if (ctx._internal.toolProcessingState) {
|
||||
ctx._internal.toolProcessingState = {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 为流式输出添加getText方法
|
||||
const modifiedResult = {
|
||||
...result,
|
||||
stream: new ReadableStream<GenericChunk>({
|
||||
start(controller) {
|
||||
controller.close()
|
||||
}
|
||||
}),
|
||||
getText: () => {
|
||||
return ctx._internal.customState?.accumulatedText || ''
|
||||
}
|
||||
}
|
||||
|
||||
return modifiedResult
|
||||
} else {
|
||||
Logger.debug(`[${MIDDLEWARE_NAME}] No GenericChunk stream to process.`)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
/**
|
||||
* 从GenericChunk或原始SDK chunks中提取usage/metrics数据并累加
|
||||
*/
|
||||
function extractAndAccumulateUsageMetrics(ctx: CompletionsContext, chunk: GenericChunk): void {
|
||||
if (!ctx._internal.observer?.usage || !ctx._internal.observer?.metrics) {
|
||||
return
|
||||
}
|
||||
|
||||
try {
|
||||
if (ctx._internal.customState && !ctx._internal.customState?.firstTokenTimestamp) {
|
||||
ctx._internal.customState.firstTokenTimestamp = Date.now()
|
||||
Logger.debug(`[${MIDDLEWARE_NAME}] First token timestamp: ${ctx._internal.customState.firstTokenTimestamp}`)
|
||||
}
|
||||
if (chunk.type === ChunkType.LLM_RESPONSE_COMPLETE) {
|
||||
Logger.debug(`[${MIDDLEWARE_NAME}] LLM_RESPONSE_COMPLETE chunk received:`, ctx._internal)
|
||||
// 从LLM_RESPONSE_COMPLETE chunk中提取usage数据
|
||||
if (chunk.response?.usage) {
|
||||
accumulateUsage(ctx._internal.observer.usage, chunk.response.usage)
|
||||
}
|
||||
|
||||
if (ctx._internal.customState && ctx._internal.customState?.firstTokenTimestamp) {
|
||||
ctx._internal.observer.metrics.time_first_token_millsec =
|
||||
ctx._internal.customState.firstTokenTimestamp - ctx._internal.customState.startTimestamp
|
||||
ctx._internal.observer.metrics.time_completion_millsec +=
|
||||
Date.now() - ctx._internal.customState.firstTokenTimestamp
|
||||
}
|
||||
}
|
||||
|
||||
// 也可以从其他chunk类型中提取metrics数据
|
||||
if (chunk.type === ChunkType.THINKING_COMPLETE && chunk.thinking_millsec && ctx._internal.observer?.metrics) {
|
||||
ctx._internal.observer.metrics.time_thinking_millsec = Math.max(
|
||||
ctx._internal.observer.metrics.time_thinking_millsec || 0,
|
||||
chunk.thinking_millsec
|
||||
)
|
||||
}
|
||||
} catch (error) {
|
||||
console.error(`[${MIDDLEWARE_NAME}] Error extracting usage/metrics from chunk:`, error)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 累加usage数据
|
||||
*/
|
||||
function accumulateUsage(accumulated: Usage, newUsage: Usage): void {
|
||||
if (newUsage.prompt_tokens !== undefined) {
|
||||
accumulated.prompt_tokens += newUsage.prompt_tokens
|
||||
}
|
||||
if (newUsage.completion_tokens !== undefined) {
|
||||
accumulated.completion_tokens += newUsage.completion_tokens
|
||||
}
|
||||
if (newUsage.total_tokens !== undefined) {
|
||||
accumulated.total_tokens += newUsage.total_tokens
|
||||
}
|
||||
if (newUsage.thoughts_tokens !== undefined) {
|
||||
accumulated.thoughts_tokens = (accumulated.thoughts_tokens || 0) + newUsage.thoughts_tokens
|
||||
}
|
||||
}
|
||||
|
||||
export default FinalChunkConsumerMiddleware
|
||||
@ -0,0 +1,64 @@
|
||||
import { BaseContext, MethodMiddleware, MiddlewareAPI } from '../types'
|
||||
|
||||
export const MIDDLEWARE_NAME = 'GenericLoggingMiddlewares'
|
||||
|
||||
/**
|
||||
* Helper function to safely stringify arguments for logging, handling circular references and large objects.
|
||||
* 安全地字符串化日志参数的辅助函数,处理循环引用和大型对象。
|
||||
* @param args - The arguments array to stringify. 要字符串化的参数数组。
|
||||
* @returns A string representation of the arguments. 参数的字符串表示形式。
|
||||
*/
|
||||
const stringifyArgsForLogging = (args: any[]): string => {
|
||||
try {
|
||||
return args
|
||||
.map((arg) => {
|
||||
if (typeof arg === 'function') return '[Function]'
|
||||
if (typeof arg === 'object' && arg !== null && arg.constructor === Object && Object.keys(arg).length > 20) {
|
||||
return '[Object with >20 keys]'
|
||||
}
|
||||
// Truncate long strings to avoid flooding logs 截断长字符串以避免日志泛滥
|
||||
const stringifiedArg = JSON.stringify(arg, null, 2)
|
||||
return stringifiedArg && stringifiedArg.length > 200 ? stringifiedArg.substring(0, 200) + '...' : stringifiedArg
|
||||
})
|
||||
.join(', ')
|
||||
} catch (e) {
|
||||
return '[Error serializing arguments]' // Handle potential errors during stringification 处理字符串化期间的潜在错误
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Generic logging middleware for provider methods.
|
||||
* 为提供者方法创建一个通用的日志中间件。
|
||||
* This middleware logs the initiation, success/failure, and duration of a method call.
|
||||
* 此中间件记录方法调用的启动、成功/失败以及持续时间。
|
||||
*/
|
||||
|
||||
/**
|
||||
* Creates a generic logging middleware for provider methods.
|
||||
* 为提供者方法创建一个通用的日志中间件。
|
||||
* @returns A `MethodMiddleware` instance. 一个 `MethodMiddleware` 实例。
|
||||
*/
|
||||
export const createGenericLoggingMiddleware: () => MethodMiddleware = () => {
|
||||
const middlewareName = 'GenericLoggingMiddleware'
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
return (_: MiddlewareAPI<BaseContext, any[]>) => (next) => async (ctx, args) => {
|
||||
const methodName = ctx.methodName
|
||||
const logPrefix = `[${middlewareName} (${methodName})]`
|
||||
console.log(`${logPrefix} Initiating. Args:`, stringifyArgsForLogging(args))
|
||||
const startTime = Date.now()
|
||||
try {
|
||||
const result = await next(ctx, args)
|
||||
const duration = Date.now() - startTime
|
||||
// Log successful completion of the method call with duration. /
|
||||
// 记录方法调用成功完成及其持续时间。
|
||||
console.log(`${logPrefix} Successful. Duration: ${duration}ms`)
|
||||
return result
|
||||
} catch (error) {
|
||||
const duration = Date.now() - startTime
|
||||
// Log failure of the method call with duration and error information. /
|
||||
// 记录方法调用失败及其持续时间和错误信息。
|
||||
console.error(`${logPrefix} Failed. Duration: ${duration}ms`, error)
|
||||
throw error // Re-throw the error to be handled by subsequent layers or the caller / 重新抛出错误,由后续层或调用者处理
|
||||
}
|
||||
}
|
||||
}
|
||||
285
src/renderer/src/aiCore/middleware/composer.ts
Normal file
285
src/renderer/src/aiCore/middleware/composer.ts
Normal file
@ -0,0 +1,285 @@
|
||||
import {
|
||||
RequestOptions,
|
||||
SdkInstance,
|
||||
SdkMessageParam,
|
||||
SdkParams,
|
||||
SdkRawChunk,
|
||||
SdkRawOutput,
|
||||
SdkTool,
|
||||
SdkToolCall
|
||||
} from '@renderer/types/sdk'
|
||||
|
||||
import { BaseApiClient } from '../clients'
|
||||
import { CompletionsParams, CompletionsResult } from './schemas'
|
||||
import {
|
||||
BaseContext,
|
||||
CompletionsContext,
|
||||
CompletionsMiddleware,
|
||||
MethodMiddleware,
|
||||
MIDDLEWARE_CONTEXT_SYMBOL,
|
||||
MiddlewareAPI
|
||||
} from './types'
|
||||
|
||||
/**
|
||||
* Creates the initial context for a method call, populating method-specific fields. /
|
||||
* 为方法调用创建初始上下文,并填充特定于该方法的字段。
|
||||
* @param methodName - The name of the method being called. / 被调用的方法名。
|
||||
* @param originalCallArgs - The actual arguments array from the proxy/method call. / 代理/方法调用的实际参数数组。
|
||||
* @param providerId - The ID of the provider, if available. / 提供者的ID(如果可用)。
|
||||
* @param providerInstance - The instance of the provider. / 提供者实例。
|
||||
* @param specificContextFactory - An optional factory function to create a specific context type from the base context and original call arguments. / 一个可选的工厂函数,用于从基础上下文和原始调用参数创建特定的上下文类型。
|
||||
* @returns The created context object. / 创建的上下文对象。
|
||||
*/
|
||||
function createInitialCallContext<TContext extends BaseContext, TCallArgs extends unknown[]>(
|
||||
methodName: string,
|
||||
originalCallArgs: TCallArgs, // Renamed from originalArgs to avoid confusion with context.originalArgs
|
||||
// Factory to create specific context from base and the *original call arguments array*
|
||||
specificContextFactory?: (base: BaseContext, callArgs: TCallArgs) => TContext
|
||||
): TContext {
|
||||
const baseContext: BaseContext = {
|
||||
[MIDDLEWARE_CONTEXT_SYMBOL]: true,
|
||||
methodName,
|
||||
originalArgs: originalCallArgs // Store the full original arguments array in the context
|
||||
}
|
||||
|
||||
if (specificContextFactory) {
|
||||
return specificContextFactory(baseContext, originalCallArgs)
|
||||
}
|
||||
return baseContext as TContext // Fallback to base context if no specific factory
|
||||
}
|
||||
|
||||
/**
|
||||
* Composes an array of functions from right to left. /
|
||||
* 从右到左组合一个函数数组。
|
||||
* `compose(f, g, h)` is `(...args) => f(g(h(...args)))`. /
|
||||
* `compose(f, g, h)` 等同于 `(...args) => f(g(h(...args)))`。
|
||||
* Each function in funcs is expected to take the result of the next function
|
||||
* (or the initial value for the rightmost function) as its argument. /
|
||||
* `funcs` 中的每个函数都期望接收下一个函数的结果(或最右侧函数的初始值)作为其参数。
|
||||
* @param funcs - Array of functions to compose. / 要组合的函数数组。
|
||||
* @returns The composed function. / 组合后的函数。
|
||||
*/
|
||||
function compose(...funcs: Array<(...args: any[]) => any>): (...args: any[]) => any {
|
||||
if (funcs.length === 0) {
|
||||
// If no functions to compose, return a function that returns its first argument, or undefined if no args. /
|
||||
// 如果没有要组合的函数,则返回一个函数,该函数返回其第一个参数,如果没有参数则返回undefined。
|
||||
return (...args: any[]) => (args.length > 0 ? args[0] : undefined)
|
||||
}
|
||||
if (funcs.length === 1) {
|
||||
return funcs[0]
|
||||
}
|
||||
return funcs.reduce(
|
||||
(a, b) =>
|
||||
(...args: any[]) =>
|
||||
a(b(...args))
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* Applies an array of Redux-style middlewares to a generic provider method. /
|
||||
* 将一组Redux风格的中间件应用于一个通用的提供者方法。
|
||||
* This version keeps arguments as an array throughout the middleware chain. /
|
||||
* 此版本在整个中间件链中将参数保持为数组形式。
|
||||
* @param originalProviderInstance - The original provider instance. / 原始提供者实例。
|
||||
* @param methodName - The name of the method to be enhanced. / 需要增强的方法名。
|
||||
* @param originalMethod - The original method to be wrapped. / 需要包装的原始方法。
|
||||
* @param middlewares - An array of `ProviderMethodMiddleware` to apply. / 要应用的 `ProviderMethodMiddleware` 数组。
|
||||
* @param specificContextFactory - An optional factory to create a specific context for this method. / 可选的工厂函数,用于为此方法创建特定的上下文。
|
||||
* @returns An enhanced method with the middlewares applied. / 应用了中间件的增强方法。
|
||||
*/
|
||||
export function applyMethodMiddlewares<
|
||||
TArgs extends unknown[] = unknown[], // Original method's arguments array type / 原始方法的参数数组类型
|
||||
TResult = unknown,
|
||||
TContext extends BaseContext = BaseContext
|
||||
>(
|
||||
methodName: string,
|
||||
originalMethod: (...args: TArgs) => Promise<TResult>,
|
||||
middlewares: MethodMiddleware[], // Expects generic middlewares / 期望通用中间件
|
||||
specificContextFactory?: (base: BaseContext, callArgs: TArgs) => TContext
|
||||
): (...args: TArgs) => Promise<TResult> {
|
||||
// Returns a function matching the original method signature. /
|
||||
// 返回一个与原始方法签名匹配的函数。
|
||||
return async function enhancedMethod(...methodCallArgs: TArgs): Promise<TResult> {
|
||||
const ctx = createInitialCallContext<TContext, TArgs>(
|
||||
methodName,
|
||||
methodCallArgs, // Pass the actual call arguments array / 传递实际的调用参数数组
|
||||
specificContextFactory
|
||||
)
|
||||
|
||||
const api: MiddlewareAPI<TContext, TArgs> = {
|
||||
getContext: () => ctx,
|
||||
getOriginalArgs: () => methodCallArgs // API provides the original arguments array / API提供原始参数数组
|
||||
}
|
||||
|
||||
// `finalDispatch` is the function that will ultimately call the original provider method. /
|
||||
// `finalDispatch` 是最终将调用原始提供者方法的函数。
|
||||
// It receives the current context and arguments, which may have been transformed by middlewares. /
|
||||
// 它接收当前的上下文和参数,这些参数可能已被中间件转换。
|
||||
const finalDispatch = async (
|
||||
_: TContext,
|
||||
currentArgs: TArgs // Generic final dispatch expects args array / 通用finalDispatch期望参数数组
|
||||
): Promise<TResult> => {
|
||||
return originalMethod.apply(currentArgs)
|
||||
}
|
||||
|
||||
const chain = middlewares.map((middleware) => middleware(api)) // Cast API if TContext/TArgs mismatch general ProviderMethodMiddleware / 如果TContext/TArgs与通用的ProviderMethodMiddleware不匹配,则转换API
|
||||
const composedMiddlewareLogic = compose(...chain)
|
||||
const enhancedDispatch = composedMiddlewareLogic(finalDispatch)
|
||||
|
||||
return enhancedDispatch(ctx, methodCallArgs) // Pass context and original args array / 传递上下文和原始参数数组
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Applies an array of `CompletionsMiddleware` to the `completions` method. /
|
||||
* 将一组 `CompletionsMiddleware` 应用于 `completions` 方法。
|
||||
* This version adapts for `CompletionsMiddleware` expecting a single `params` object. /
|
||||
* 此版本适配了期望单个 `params` 对象的 `CompletionsMiddleware`。
|
||||
* @param originalProviderInstance - The original provider instance. / 原始提供者实例。
|
||||
* @param originalCompletionsMethod - The original SDK `createCompletions` method. / 原始的 SDK `createCompletions` 方法。
|
||||
* @param middlewares - An array of `CompletionsMiddleware` to apply. / 要应用的 `CompletionsMiddleware` 数组。
|
||||
* @returns An enhanced `completions` method with the middlewares applied. / 应用了中间件的增强版 `completions` 方法。
|
||||
*/
|
||||
export function applyCompletionsMiddlewares<
|
||||
TSdkInstance extends SdkInstance = SdkInstance,
|
||||
TSdkParams extends SdkParams = SdkParams,
|
||||
TRawOutput extends SdkRawOutput = SdkRawOutput,
|
||||
TRawChunk extends SdkRawChunk = SdkRawChunk,
|
||||
TMessageParam extends SdkMessageParam = SdkMessageParam,
|
||||
TToolCall extends SdkToolCall = SdkToolCall,
|
||||
TSdkSpecificTool extends SdkTool = SdkTool
|
||||
>(
|
||||
originalApiClientInstance: BaseApiClient<
|
||||
TSdkInstance,
|
||||
TSdkParams,
|
||||
TRawOutput,
|
||||
TRawChunk,
|
||||
TMessageParam,
|
||||
TToolCall,
|
||||
TSdkSpecificTool
|
||||
>,
|
||||
originalCompletionsMethod: (payload: TSdkParams, options?: RequestOptions) => Promise<TRawOutput>,
|
||||
middlewares: CompletionsMiddleware<
|
||||
TSdkParams,
|
||||
TMessageParam,
|
||||
TToolCall,
|
||||
TSdkInstance,
|
||||
TRawOutput,
|
||||
TRawChunk,
|
||||
TSdkSpecificTool
|
||||
>[]
|
||||
): (params: CompletionsParams, options?: RequestOptions) => Promise<CompletionsResult> {
|
||||
// Returns a function matching the original method signature. /
|
||||
// 返回一个与原始方法签名匹配的函数。
|
||||
|
||||
const methodName = 'completions'
|
||||
|
||||
// Factory to create AiProviderMiddlewareCompletionsContext. /
|
||||
// 用于创建 AiProviderMiddlewareCompletionsContext 的工厂函数。
|
||||
const completionsContextFactory = (
|
||||
base: BaseContext,
|
||||
callArgs: [CompletionsParams]
|
||||
): CompletionsContext<
|
||||
TSdkParams,
|
||||
TMessageParam,
|
||||
TToolCall,
|
||||
TSdkInstance,
|
||||
TRawOutput,
|
||||
TRawChunk,
|
||||
TSdkSpecificTool
|
||||
> => {
|
||||
return {
|
||||
...base,
|
||||
methodName,
|
||||
apiClientInstance: originalApiClientInstance,
|
||||
originalArgs: callArgs,
|
||||
_internal: {
|
||||
toolProcessingState: {
|
||||
recursionDepth: 0,
|
||||
isRecursiveCall: false
|
||||
},
|
||||
observer: {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return async function enhancedCompletionsMethod(
|
||||
params: CompletionsParams,
|
||||
options?: RequestOptions
|
||||
): Promise<CompletionsResult> {
|
||||
// `originalCallArgs` for context creation is `[params]`. /
|
||||
// 用于上下文创建的 `originalCallArgs` 是 `[params]`。
|
||||
const originalCallArgs: [CompletionsParams] = [params]
|
||||
const baseContext: BaseContext = {
|
||||
[MIDDLEWARE_CONTEXT_SYMBOL]: true,
|
||||
methodName,
|
||||
originalArgs: originalCallArgs
|
||||
}
|
||||
const ctx = completionsContextFactory(baseContext, originalCallArgs)
|
||||
|
||||
const api: MiddlewareAPI<
|
||||
CompletionsContext<TSdkParams, TMessageParam, TToolCall, TSdkInstance, TRawOutput, TRawChunk, TSdkSpecificTool>,
|
||||
[CompletionsParams]
|
||||
> = {
|
||||
getContext: () => ctx,
|
||||
getOriginalArgs: () => originalCallArgs // API provides [CompletionsParams] / API提供 `[CompletionsParams]`
|
||||
}
|
||||
|
||||
// `finalDispatch` for CompletionsMiddleware: expects (context, params) not (context, args_array). /
|
||||
// `CompletionsMiddleware` 的 `finalDispatch`:期望 (context, params) 而不是 (context, args_array)。
|
||||
const finalDispatch = async (
|
||||
context: CompletionsContext<
|
||||
TSdkParams,
|
||||
TMessageParam,
|
||||
TToolCall,
|
||||
TSdkInstance,
|
||||
TRawOutput,
|
||||
TRawChunk,
|
||||
TSdkSpecificTool
|
||||
> // Context passed through / 上下文透传
|
||||
// _currentParams: CompletionsParams // Directly takes params / 直接接收参数 (unused but required for middleware signature)
|
||||
): Promise<CompletionsResult> => {
|
||||
// At this point, middleware should have transformed CompletionsParams to SDK params
|
||||
// and stored them in context. If no transformation happened, we need to handle it.
|
||||
// 此时,中间件应该已经将 CompletionsParams 转换为 SDK 参数并存储在上下文中。
|
||||
// 如果没有进行转换,我们需要处理它。
|
||||
|
||||
const sdkPayload = context._internal?.sdkPayload
|
||||
if (!sdkPayload) {
|
||||
throw new Error('SDK payload not found in context. Middleware chain should have transformed parameters.')
|
||||
}
|
||||
|
||||
const abortSignal = context._internal.flowControl?.abortSignal
|
||||
const timeout = context._internal.customState?.sdkMetadata?.timeout
|
||||
|
||||
// Call the original SDK method with transformed parameters
|
||||
// 使用转换后的参数调用原始 SDK 方法
|
||||
const rawOutput = await originalCompletionsMethod.call(originalApiClientInstance, sdkPayload, {
|
||||
...options,
|
||||
signal: abortSignal,
|
||||
timeout
|
||||
})
|
||||
|
||||
// Return result wrapped in CompletionsResult format
|
||||
// 以 CompletionsResult 格式返回包装的结果
|
||||
return {
|
||||
rawOutput
|
||||
} as CompletionsResult
|
||||
}
|
||||
|
||||
const chain = middlewares.map((middleware) => middleware(api))
|
||||
const composedMiddlewareLogic = compose(...chain)
|
||||
|
||||
// `enhancedDispatch` has the signature `(context, params) => Promise<CompletionsResult>`. /
|
||||
// `enhancedDispatch` 的签名为 `(context, params) => Promise<CompletionsResult>`。
|
||||
const enhancedDispatch = composedMiddlewareLogic(finalDispatch)
|
||||
|
||||
// 将 enhancedDispatch 保存到 context 中,供中间件进行递归调用
|
||||
// 这样可以避免重复执行整个中间件链
|
||||
ctx._internal.enhancedDispatch = enhancedDispatch
|
||||
|
||||
// Execute with context and the single params object. /
|
||||
// 使用上下文和单个参数对象执行。
|
||||
return enhancedDispatch(ctx, params)
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,306 @@
|
||||
import Logger from '@renderer/config/logger'
|
||||
import { MCPTool, MCPToolResponse, Model, ToolCallResponse } from '@renderer/types'
|
||||
import { ChunkType, MCPToolCreatedChunk } from '@renderer/types/chunk'
|
||||
import { SdkMessageParam, SdkRawOutput, SdkToolCall } from '@renderer/types/sdk'
|
||||
import { parseAndCallTools } from '@renderer/utils/mcp-tools'
|
||||
|
||||
import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
|
||||
import { CompletionsContext, CompletionsMiddleware } from '../types'
|
||||
|
||||
export const MIDDLEWARE_NAME = 'McpToolChunkMiddleware'
|
||||
const MAX_TOOL_RECURSION_DEPTH = 20 // 防止无限递归
|
||||
|
||||
/**
|
||||
* MCP工具处理中间件
|
||||
*
|
||||
* 职责:
|
||||
* 1. 检测并拦截MCP工具进展chunk(Function Call方式和Tool Use方式)
|
||||
* 2. 执行工具调用
|
||||
* 3. 递归处理工具结果
|
||||
* 4. 管理工具调用状态和递归深度
|
||||
*/
|
||||
export const McpToolChunkMiddleware: CompletionsMiddleware =
|
||||
() =>
|
||||
(next) =>
|
||||
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
|
||||
const mcpTools = params.mcpTools || []
|
||||
|
||||
// 如果没有工具,直接调用下一个中间件
|
||||
if (!mcpTools || mcpTools.length === 0) {
|
||||
return next(ctx, params)
|
||||
}
|
||||
|
||||
const executeWithToolHandling = async (currentParams: CompletionsParams, depth = 0): Promise<CompletionsResult> => {
|
||||
if (depth >= MAX_TOOL_RECURSION_DEPTH) {
|
||||
Logger.error(`🔧 [${MIDDLEWARE_NAME}] Maximum recursion depth ${MAX_TOOL_RECURSION_DEPTH} exceeded`)
|
||||
throw new Error(`Maximum tool recursion depth ${MAX_TOOL_RECURSION_DEPTH} exceeded`)
|
||||
}
|
||||
|
||||
let result: CompletionsResult
|
||||
|
||||
if (depth === 0) {
|
||||
result = await next(ctx, currentParams)
|
||||
} else {
|
||||
const enhancedCompletions = ctx._internal.enhancedDispatch
|
||||
if (!enhancedCompletions) {
|
||||
Logger.error(`🔧 [${MIDDLEWARE_NAME}] Enhanced completions method not found, cannot perform recursive call`)
|
||||
throw new Error('Enhanced completions method not found')
|
||||
}
|
||||
|
||||
ctx._internal.toolProcessingState!.isRecursiveCall = true
|
||||
ctx._internal.toolProcessingState!.recursionDepth = depth
|
||||
|
||||
result = await enhancedCompletions(ctx, currentParams)
|
||||
}
|
||||
|
||||
if (!result.stream) {
|
||||
Logger.error(`🔧 [${MIDDLEWARE_NAME}] No stream returned from enhanced completions`)
|
||||
throw new Error('No stream returned from enhanced completions')
|
||||
}
|
||||
|
||||
const resultFromUpstream = result.stream as ReadableStream<GenericChunk>
|
||||
const toolHandlingStream = resultFromUpstream.pipeThrough(
|
||||
createToolHandlingTransform(ctx, currentParams, mcpTools, depth, executeWithToolHandling)
|
||||
)
|
||||
|
||||
return {
|
||||
...result,
|
||||
stream: toolHandlingStream
|
||||
}
|
||||
}
|
||||
|
||||
return executeWithToolHandling(params, 0)
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建工具处理的 TransformStream
|
||||
*/
|
||||
function createToolHandlingTransform(
|
||||
ctx: CompletionsContext,
|
||||
currentParams: CompletionsParams,
|
||||
mcpTools: MCPTool[],
|
||||
depth: number,
|
||||
executeWithToolHandling: (params: CompletionsParams, depth: number) => Promise<CompletionsResult>
|
||||
): TransformStream<GenericChunk, GenericChunk> {
|
||||
const toolCalls: SdkToolCall[] = []
|
||||
const toolUseResponses: MCPToolResponse[] = []
|
||||
const allToolResponses: MCPToolResponse[] = [] // 统一的工具响应状态管理数组
|
||||
let hasToolCalls = false
|
||||
let hasToolUseResponses = false
|
||||
let streamEnded = false
|
||||
|
||||
return new TransformStream({
|
||||
async transform(chunk: GenericChunk, controller) {
|
||||
try {
|
||||
// 处理MCP工具进展chunk
|
||||
if (chunk.type === ChunkType.MCP_TOOL_CREATED) {
|
||||
const createdChunk = chunk as MCPToolCreatedChunk
|
||||
|
||||
// 1. 处理Function Call方式的工具调用
|
||||
if (createdChunk.tool_calls && createdChunk.tool_calls.length > 0) {
|
||||
toolCalls.push(...createdChunk.tool_calls)
|
||||
hasToolCalls = true
|
||||
}
|
||||
|
||||
// 2. 处理Tool Use方式的工具调用
|
||||
if (createdChunk.tool_use_responses && createdChunk.tool_use_responses.length > 0) {
|
||||
toolUseResponses.push(...createdChunk.tool_use_responses)
|
||||
hasToolUseResponses = true
|
||||
}
|
||||
|
||||
// 不转发MCP工具进展chunks,避免重复处理
|
||||
return
|
||||
}
|
||||
|
||||
// 转发其他所有chunk
|
||||
controller.enqueue(chunk)
|
||||
} catch (error) {
|
||||
console.error(`🔧 [${MIDDLEWARE_NAME}] Error processing chunk:`, error)
|
||||
controller.error(error)
|
||||
}
|
||||
},
|
||||
|
||||
async flush(controller) {
|
||||
const shouldExecuteToolCalls = hasToolCalls && toolCalls.length > 0
|
||||
const shouldExecuteToolUseResponses = hasToolUseResponses && toolUseResponses.length > 0
|
||||
|
||||
if (!streamEnded && (shouldExecuteToolCalls || shouldExecuteToolUseResponses)) {
|
||||
streamEnded = true
|
||||
|
||||
try {
|
||||
let toolResult: SdkMessageParam[] = []
|
||||
|
||||
if (shouldExecuteToolCalls) {
|
||||
toolResult = await executeToolCalls(
|
||||
ctx,
|
||||
toolCalls,
|
||||
mcpTools,
|
||||
allToolResponses,
|
||||
currentParams.onChunk,
|
||||
currentParams.assistant.model!
|
||||
)
|
||||
} else if (shouldExecuteToolUseResponses) {
|
||||
toolResult = await executeToolUseResponses(
|
||||
ctx,
|
||||
toolUseResponses,
|
||||
mcpTools,
|
||||
allToolResponses,
|
||||
currentParams.onChunk,
|
||||
currentParams.assistant.model!
|
||||
)
|
||||
}
|
||||
|
||||
if (toolResult.length > 0) {
|
||||
const output = ctx._internal.toolProcessingState?.output
|
||||
|
||||
const newParams = buildParamsWithToolResults(ctx, currentParams, output!, toolResult, toolCalls)
|
||||
await executeWithToolHandling(newParams, depth + 1)
|
||||
}
|
||||
} catch (error) {
|
||||
console.error(`🔧 [${MIDDLEWARE_NAME}] Error in tool processing:`, error)
|
||||
controller.error(error)
|
||||
} finally {
|
||||
hasToolCalls = false
|
||||
hasToolUseResponses = false
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* 执行工具调用(Function Call 方式)
|
||||
*/
|
||||
async function executeToolCalls(
|
||||
ctx: CompletionsContext,
|
||||
toolCalls: SdkToolCall[],
|
||||
mcpTools: MCPTool[],
|
||||
allToolResponses: MCPToolResponse[],
|
||||
onChunk: CompletionsParams['onChunk'],
|
||||
model: Model
|
||||
): Promise<SdkMessageParam[]> {
|
||||
// 转换为MCPToolResponse格式
|
||||
const mcpToolResponses: ToolCallResponse[] = toolCalls
|
||||
.map((toolCall) => {
|
||||
const mcpTool = ctx.apiClientInstance.convertSdkToolCallToMcp(toolCall, mcpTools)
|
||||
if (!mcpTool) {
|
||||
return undefined
|
||||
}
|
||||
return ctx.apiClientInstance.convertSdkToolCallToMcpToolResponse(toolCall, mcpTool)
|
||||
})
|
||||
.filter((t): t is ToolCallResponse => typeof t !== 'undefined')
|
||||
|
||||
if (mcpToolResponses.length === 0) {
|
||||
console.warn(`🔧 [${MIDDLEWARE_NAME}] No valid MCP tool responses to execute`)
|
||||
return []
|
||||
}
|
||||
|
||||
// 使用现有的parseAndCallTools函数执行工具
|
||||
const toolResults = await parseAndCallTools(
|
||||
mcpToolResponses,
|
||||
allToolResponses,
|
||||
onChunk,
|
||||
(mcpToolResponse, resp, model) => {
|
||||
return ctx.apiClientInstance.convertMcpToolResponseToSdkMessageParam(mcpToolResponse, resp, model)
|
||||
},
|
||||
model,
|
||||
mcpTools
|
||||
)
|
||||
|
||||
return toolResults
|
||||
}
|
||||
|
||||
/**
|
||||
* 执行工具使用响应(Tool Use Response 方式)
|
||||
* 处理已经解析好的 ToolUseResponse[],不需要重新解析字符串
|
||||
*/
|
||||
async function executeToolUseResponses(
|
||||
ctx: CompletionsContext,
|
||||
toolUseResponses: MCPToolResponse[],
|
||||
mcpTools: MCPTool[],
|
||||
allToolResponses: MCPToolResponse[],
|
||||
onChunk: CompletionsParams['onChunk'],
|
||||
model: Model
|
||||
): Promise<SdkMessageParam[]> {
|
||||
// 直接使用parseAndCallTools函数处理已经解析好的ToolUseResponse
|
||||
const toolResults = await parseAndCallTools(
|
||||
toolUseResponses,
|
||||
allToolResponses,
|
||||
onChunk,
|
||||
(mcpToolResponse, resp, model) => {
|
||||
return ctx.apiClientInstance.convertMcpToolResponseToSdkMessageParam(mcpToolResponse, resp, model)
|
||||
},
|
||||
model,
|
||||
mcpTools
|
||||
)
|
||||
|
||||
return toolResults
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建包含工具结果的新参数
|
||||
*/
|
||||
function buildParamsWithToolResults(
|
||||
ctx: CompletionsContext,
|
||||
currentParams: CompletionsParams,
|
||||
output: SdkRawOutput | string,
|
||||
toolResults: SdkMessageParam[],
|
||||
toolCalls: SdkToolCall[]
|
||||
): CompletionsParams {
|
||||
// 获取当前已经转换好的reqMessages,如果没有则使用原始messages
|
||||
const currentReqMessages = getCurrentReqMessages(ctx)
|
||||
|
||||
const apiClient = ctx.apiClientInstance
|
||||
|
||||
// 从回复中构建助手消息
|
||||
const newReqMessages = apiClient.buildSdkMessages(currentReqMessages, output, toolResults, toolCalls)
|
||||
|
||||
// 估算新增消息的 token 消耗并累加到 usage 中
|
||||
if (ctx._internal.observer?.usage && newReqMessages.length > currentReqMessages.length) {
|
||||
try {
|
||||
const newMessages = newReqMessages.slice(currentReqMessages.length)
|
||||
const additionalTokens = newMessages.reduce((acc, message) => {
|
||||
return acc + ctx.apiClientInstance.estimateMessageTokens(message)
|
||||
}, 0)
|
||||
|
||||
if (additionalTokens > 0) {
|
||||
ctx._internal.observer.usage.prompt_tokens += additionalTokens
|
||||
ctx._internal.observer.usage.total_tokens += additionalTokens
|
||||
}
|
||||
} catch (error) {
|
||||
Logger.error(`🔧 [${MIDDLEWARE_NAME}] Error estimating token usage for new messages:`, error)
|
||||
}
|
||||
}
|
||||
|
||||
// 更新递归状态
|
||||
if (!ctx._internal.toolProcessingState) {
|
||||
ctx._internal.toolProcessingState = {}
|
||||
}
|
||||
ctx._internal.toolProcessingState.isRecursiveCall = true
|
||||
ctx._internal.toolProcessingState.recursionDepth = (ctx._internal.toolProcessingState?.recursionDepth || 0) + 1
|
||||
|
||||
return {
|
||||
...currentParams,
|
||||
_internal: {
|
||||
...ctx._internal,
|
||||
sdkPayload: ctx._internal.sdkPayload,
|
||||
newReqMessages: newReqMessages
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 类型安全地获取当前请求消息
|
||||
* 使用API客户端提供的抽象方法,保持中间件的provider无关性
|
||||
*/
|
||||
function getCurrentReqMessages(ctx: CompletionsContext): SdkMessageParam[] {
|
||||
const sdkPayload = ctx._internal.sdkPayload
|
||||
if (!sdkPayload) {
|
||||
return []
|
||||
}
|
||||
|
||||
// 使用API客户端的抽象方法来提取消息,保持provider无关性
|
||||
return ctx.apiClientInstance.extractMessagesFromSdkPayload(sdkPayload)
|
||||
}
|
||||
|
||||
export default McpToolChunkMiddleware
|
||||
@ -0,0 +1,48 @@
|
||||
import { AnthropicAPIClient } from '@renderer/aiCore/clients/anthropic/AnthropicAPIClient'
|
||||
import { AnthropicSdkRawChunk, AnthropicSdkRawOutput } from '@renderer/types/sdk'
|
||||
|
||||
import { AnthropicStreamListener } from '../../clients/types'
|
||||
import { CompletionsParams, CompletionsResult } from '../schemas'
|
||||
import { CompletionsContext, CompletionsMiddleware } from '../types'
|
||||
|
||||
export const MIDDLEWARE_NAME = 'RawStreamListenerMiddleware'
|
||||
|
||||
export const RawStreamListenerMiddleware: CompletionsMiddleware =
|
||||
() =>
|
||||
(next) =>
|
||||
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
|
||||
const result = await next(ctx, params)
|
||||
|
||||
// 在这里可以监听到从SDK返回的最原始流
|
||||
if (result.rawOutput) {
|
||||
console.log(`[${MIDDLEWARE_NAME}] 检测到原始SDK输出,准备附加监听器`)
|
||||
|
||||
const providerType = ctx.apiClientInstance.provider.type
|
||||
// TODO: 后面下放到AnthropicAPIClient
|
||||
if (providerType === 'anthropic') {
|
||||
const anthropicListener: AnthropicStreamListener<AnthropicSdkRawChunk> = {
|
||||
onMessage: (message) => {
|
||||
if (ctx._internal?.toolProcessingState) {
|
||||
ctx._internal.toolProcessingState.output = message
|
||||
}
|
||||
}
|
||||
// onContentBlock: (contentBlock) => {
|
||||
// console.log(`[${MIDDLEWARE_NAME}] 📝 Anthropic content block:`, contentBlock.type)
|
||||
// }
|
||||
}
|
||||
|
||||
const specificApiClient = ctx.apiClientInstance as AnthropicAPIClient
|
||||
|
||||
const monitoredOutput = specificApiClient.attachRawStreamListener(
|
||||
result.rawOutput as AnthropicSdkRawOutput,
|
||||
anthropicListener
|
||||
)
|
||||
return {
|
||||
...result,
|
||||
rawOutput: monitoredOutput
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
@ -0,0 +1,85 @@
|
||||
import Logger from '@renderer/config/logger'
|
||||
import { SdkRawChunk } from '@renderer/types/sdk'
|
||||
|
||||
import { ResponseChunkTransformerContext } from '../../clients/types'
|
||||
import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
|
||||
import { CompletionsContext, CompletionsMiddleware } from '../types'
|
||||
|
||||
export const MIDDLEWARE_NAME = 'ResponseTransformMiddleware'
|
||||
|
||||
/**
|
||||
* 响应转换中间件
|
||||
*
|
||||
* 职责:
|
||||
* 1. 检测ReadableStream类型的响应流
|
||||
* 2. 使用ApiClient的getResponseChunkTransformer()将原始SDK响应块转换为通用格式
|
||||
* 3. 将转换后的ReadableStream保存到ctx._internal.apiCall.genericChunkStream,供下游中间件使用
|
||||
*
|
||||
* 注意:此中间件应该在StreamAdapterMiddleware之后执行
|
||||
*/
|
||||
export const ResponseTransformMiddleware: CompletionsMiddleware =
|
||||
() =>
|
||||
(next) =>
|
||||
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
|
||||
// 调用下游中间件
|
||||
const result = await next(ctx, params)
|
||||
|
||||
// 响应后处理:转换原始SDK响应块
|
||||
if (result.stream) {
|
||||
const adaptedStream = result.stream
|
||||
|
||||
// 处理ReadableStream类型的流
|
||||
if (adaptedStream instanceof ReadableStream) {
|
||||
const apiClient = ctx.apiClientInstance
|
||||
if (!apiClient) {
|
||||
console.error(`[${MIDDLEWARE_NAME}] ApiClient instance not found in context`)
|
||||
throw new Error('ApiClient instance not found in context')
|
||||
}
|
||||
|
||||
// 获取响应转换器
|
||||
const responseChunkTransformer = apiClient.getResponseChunkTransformer?.()
|
||||
if (!responseChunkTransformer) {
|
||||
Logger.warn(`[${MIDDLEWARE_NAME}] No ResponseChunkTransformer available, skipping transformation`)
|
||||
return result
|
||||
}
|
||||
|
||||
const assistant = params.assistant
|
||||
const model = assistant?.model
|
||||
|
||||
if (!assistant || !model) {
|
||||
console.error(`[${MIDDLEWARE_NAME}] Assistant or Model not found for transformation`)
|
||||
throw new Error('Assistant or Model not found for transformation')
|
||||
}
|
||||
|
||||
const transformerContext: ResponseChunkTransformerContext = {
|
||||
isStreaming: params.streamOutput || false,
|
||||
isEnabledToolCalling: (params.mcpTools && params.mcpTools.length > 0) || false,
|
||||
isEnabledWebSearch: params.enableWebSearch || false,
|
||||
isEnabledReasoning: params.enableReasoning || false,
|
||||
mcpTools: params.mcpTools || [],
|
||||
provider: ctx.apiClientInstance?.provider
|
||||
}
|
||||
|
||||
console.log(`[${MIDDLEWARE_NAME}] Transforming raw SDK chunks with context:`, transformerContext)
|
||||
|
||||
try {
|
||||
// 创建转换后的流
|
||||
const genericChunkTransformStream = (adaptedStream as ReadableStream<SdkRawChunk>).pipeThrough<GenericChunk>(
|
||||
new TransformStream<SdkRawChunk, GenericChunk>(responseChunkTransformer(transformerContext))
|
||||
)
|
||||
|
||||
// 将转换后的ReadableStream保存到result,供下游中间件使用
|
||||
return {
|
||||
...result,
|
||||
stream: genericChunkTransformStream
|
||||
}
|
||||
} catch (error) {
|
||||
Logger.error(`[${MIDDLEWARE_NAME}] Error during chunk transformation:`, error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有流或不是ReadableStream,返回原始结果
|
||||
return result
|
||||
}
|
||||
@ -0,0 +1,57 @@
|
||||
import { SdkRawChunk } from '@renderer/types/sdk'
|
||||
import { asyncGeneratorToReadableStream, createSingleChunkReadableStream } from '@renderer/utils/stream'
|
||||
|
||||
import { CompletionsParams, CompletionsResult } from '../schemas'
|
||||
import { CompletionsContext, CompletionsMiddleware } from '../types'
|
||||
import { isAsyncIterable } from '../utils'
|
||||
|
||||
export const MIDDLEWARE_NAME = 'StreamAdapterMiddleware'
|
||||
|
||||
/**
|
||||
* 流适配器中间件
|
||||
*
|
||||
* 职责:
|
||||
* 1. 检测ctx._internal.apiCall.rawSdkOutput(优先)或原始AsyncIterable流
|
||||
* 2. 将AsyncIterable转换为WHATWG ReadableStream
|
||||
* 3. 更新响应结果中的stream
|
||||
*
|
||||
* 注意:如果ResponseTransformMiddleware已处理过,会优先使用transformedStream
|
||||
*/
|
||||
export const StreamAdapterMiddleware: CompletionsMiddleware =
|
||||
() =>
|
||||
(next) =>
|
||||
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
|
||||
// TODO:调用开始,因为这个是最靠近接口请求的地方,next执行代表着开始接口请求了
|
||||
// 但是这个中间件的职责是流适配,是否在这调用优待商榷
|
||||
// 调用下游中间件
|
||||
const result = await next(ctx, params)
|
||||
|
||||
if (
|
||||
result.rawOutput &&
|
||||
!(result.rawOutput instanceof ReadableStream) &&
|
||||
isAsyncIterable<SdkRawChunk>(result.rawOutput)
|
||||
) {
|
||||
const whatwgReadableStream: ReadableStream<SdkRawChunk> = asyncGeneratorToReadableStream<SdkRawChunk>(
|
||||
result.rawOutput
|
||||
)
|
||||
return {
|
||||
...result,
|
||||
stream: whatwgReadableStream
|
||||
}
|
||||
} else if (result.rawOutput && result.rawOutput instanceof ReadableStream) {
|
||||
return {
|
||||
...result,
|
||||
stream: result.rawOutput
|
||||
}
|
||||
} else if (result.rawOutput) {
|
||||
// 非流式输出,强行变为可读流
|
||||
const whatwgReadableStream: ReadableStream<SdkRawChunk> = createSingleChunkReadableStream<SdkRawChunk>(
|
||||
result.rawOutput
|
||||
)
|
||||
return {
|
||||
...result,
|
||||
stream: whatwgReadableStream
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
@ -0,0 +1,99 @@
|
||||
import Logger from '@renderer/config/logger'
|
||||
import { ChunkType, TextDeltaChunk } from '@renderer/types/chunk'
|
||||
|
||||
import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
|
||||
import { CompletionsContext, CompletionsMiddleware } from '../types'
|
||||
|
||||
export const MIDDLEWARE_NAME = 'TextChunkMiddleware'
|
||||
|
||||
/**
|
||||
* 文本块处理中间件
|
||||
*
|
||||
* 职责:
|
||||
* 1. 累积文本内容(TEXT_DELTA)
|
||||
* 2. 对文本内容进行智能链接转换
|
||||
* 3. 生成TEXT_COMPLETE事件
|
||||
* 4. 暂存Web搜索结果,用于最终链接完善
|
||||
* 5. 处理 onResponse 回调,实时发送文本更新和最终完整文本
|
||||
*/
|
||||
export const TextChunkMiddleware: CompletionsMiddleware =
|
||||
() =>
|
||||
(next) =>
|
||||
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
|
||||
// 调用下游中间件
|
||||
const result = await next(ctx, params)
|
||||
|
||||
// 响应后处理:转换流式响应中的文本内容
|
||||
if (result.stream) {
|
||||
const resultFromUpstream = result.stream as ReadableStream<GenericChunk>
|
||||
|
||||
if (resultFromUpstream && resultFromUpstream instanceof ReadableStream) {
|
||||
const assistant = params.assistant
|
||||
const model = params.assistant?.model
|
||||
|
||||
if (!assistant || !model) {
|
||||
Logger.warn(`[${MIDDLEWARE_NAME}] Missing assistant or model information, skipping text processing`)
|
||||
return result
|
||||
}
|
||||
|
||||
// 用于跨chunk的状态管理
|
||||
let accumulatedTextContent = ''
|
||||
let hasEnqueue = false
|
||||
const enhancedTextStream = resultFromUpstream.pipeThrough(
|
||||
new TransformStream<GenericChunk, GenericChunk>({
|
||||
transform(chunk: GenericChunk, controller) {
|
||||
if (chunk.type === ChunkType.TEXT_DELTA) {
|
||||
const textChunk = chunk as TextDeltaChunk
|
||||
accumulatedTextContent += textChunk.text
|
||||
|
||||
// 处理 onResponse 回调 - 发送增量文本更新
|
||||
if (params.onResponse) {
|
||||
params.onResponse(accumulatedTextContent, false)
|
||||
}
|
||||
|
||||
// 创建新的chunk,包含处理后的文本
|
||||
controller.enqueue(chunk)
|
||||
} else if (accumulatedTextContent) {
|
||||
if (chunk.type !== ChunkType.LLM_RESPONSE_COMPLETE) {
|
||||
controller.enqueue(chunk)
|
||||
hasEnqueue = true
|
||||
}
|
||||
const finalText = accumulatedTextContent
|
||||
ctx._internal.customState!.accumulatedText = finalText
|
||||
if (ctx._internal.toolProcessingState && !ctx._internal.toolProcessingState?.output) {
|
||||
ctx._internal.toolProcessingState.output = finalText
|
||||
}
|
||||
|
||||
// 处理 onResponse 回调 - 发送最终完整文本
|
||||
if (params.onResponse) {
|
||||
params.onResponse(finalText, true)
|
||||
}
|
||||
|
||||
controller.enqueue({
|
||||
type: ChunkType.TEXT_COMPLETE,
|
||||
text: finalText
|
||||
})
|
||||
accumulatedTextContent = ''
|
||||
if (!hasEnqueue) {
|
||||
controller.enqueue(chunk)
|
||||
}
|
||||
} else {
|
||||
// 其他类型的chunk直接传递
|
||||
controller.enqueue(chunk)
|
||||
}
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
// 更新响应结果
|
||||
return {
|
||||
...result,
|
||||
stream: enhancedTextStream
|
||||
}
|
||||
} else {
|
||||
Logger.warn(`[${MIDDLEWARE_NAME}] No stream to process or not a ReadableStream. Returning original result.`)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
101
src/renderer/src/aiCore/middleware/core/ThinkChunkMiddleware.ts
Normal file
101
src/renderer/src/aiCore/middleware/core/ThinkChunkMiddleware.ts
Normal file
@ -0,0 +1,101 @@
|
||||
import Logger from '@renderer/config/logger'
|
||||
import { ChunkType, ThinkingCompleteChunk, ThinkingDeltaChunk } from '@renderer/types/chunk'
|
||||
|
||||
import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
|
||||
import { CompletionsContext, CompletionsMiddleware } from '../types'
|
||||
|
||||
export const MIDDLEWARE_NAME = 'ThinkChunkMiddleware'
|
||||
|
||||
/**
|
||||
* 处理思考内容的中间件
|
||||
*
|
||||
* 注意:从 v2 版本开始,流结束语义的判断已移至 ApiClient 层处理
|
||||
* 此中间件现在主要负责:
|
||||
* 1. 处理原始SDK chunk中的reasoning字段
|
||||
* 2. 计算准确的思考时间
|
||||
* 3. 在思考内容结束时生成THINKING_COMPLETE事件
|
||||
*
|
||||
* 职责:
|
||||
* 1. 累积思考内容(THINKING_DELTA)
|
||||
* 2. 监听流结束信号,生成THINKING_COMPLETE事件
|
||||
* 3. 计算准确的思考时间
|
||||
*
|
||||
*/
|
||||
export const ThinkChunkMiddleware: CompletionsMiddleware =
|
||||
() =>
|
||||
(next) =>
|
||||
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
|
||||
// 调用下游中间件
|
||||
const result = await next(ctx, params)
|
||||
|
||||
// 响应后处理:处理思考内容
|
||||
if (result.stream) {
|
||||
const resultFromUpstream = result.stream as ReadableStream<GenericChunk>
|
||||
|
||||
// 检查是否启用reasoning
|
||||
const enableReasoning = params.enableReasoning || false
|
||||
if (!enableReasoning) {
|
||||
return result
|
||||
}
|
||||
|
||||
// 检查是否有流需要处理
|
||||
if (resultFromUpstream && resultFromUpstream instanceof ReadableStream) {
|
||||
// thinking 处理状态
|
||||
let accumulatedThinkingContent = ''
|
||||
let hasThinkingContent = false
|
||||
let thinkingStartTime = 0
|
||||
|
||||
const processedStream = resultFromUpstream.pipeThrough(
|
||||
new TransformStream<GenericChunk, GenericChunk>({
|
||||
transform(chunk: GenericChunk, controller) {
|
||||
if (chunk.type === ChunkType.THINKING_DELTA) {
|
||||
const thinkingChunk = chunk as ThinkingDeltaChunk
|
||||
|
||||
// 第一次接收到思考内容时记录开始时间
|
||||
if (!hasThinkingContent) {
|
||||
hasThinkingContent = true
|
||||
thinkingStartTime = Date.now()
|
||||
}
|
||||
|
||||
accumulatedThinkingContent += thinkingChunk.text
|
||||
|
||||
// 更新思考时间并传递
|
||||
const enhancedChunk: ThinkingDeltaChunk = {
|
||||
...thinkingChunk,
|
||||
thinking_millsec: thinkingStartTime > 0 ? Date.now() - thinkingStartTime : 0
|
||||
}
|
||||
controller.enqueue(enhancedChunk)
|
||||
} else if (hasThinkingContent && thinkingStartTime > 0) {
|
||||
// 收到任何非THINKING_DELTA的chunk时,如果有累积的思考内容,生成THINKING_COMPLETE
|
||||
const thinkingCompleteChunk: ThinkingCompleteChunk = {
|
||||
type: ChunkType.THINKING_COMPLETE,
|
||||
text: accumulatedThinkingContent,
|
||||
thinking_millsec: thinkingStartTime > 0 ? Date.now() - thinkingStartTime : 0
|
||||
}
|
||||
controller.enqueue(thinkingCompleteChunk)
|
||||
hasThinkingContent = false
|
||||
accumulatedThinkingContent = ''
|
||||
thinkingStartTime = 0
|
||||
|
||||
// 继续传递当前chunk
|
||||
controller.enqueue(chunk)
|
||||
} else {
|
||||
// 其他情况直接传递
|
||||
controller.enqueue(chunk)
|
||||
}
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
// 更新响应结果
|
||||
return {
|
||||
...result,
|
||||
stream: processedStream
|
||||
}
|
||||
} else {
|
||||
Logger.warn(`[${MIDDLEWARE_NAME}] No generic chunk stream to process or not a ReadableStream.`)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
@ -0,0 +1,83 @@
|
||||
import Logger from '@renderer/config/logger'
|
||||
import { ChunkType } from '@renderer/types/chunk'
|
||||
|
||||
import { CompletionsParams, CompletionsResult } from '../schemas'
|
||||
import { CompletionsContext, CompletionsMiddleware } from '../types'
|
||||
|
||||
export const MIDDLEWARE_NAME = 'TransformCoreToSdkParamsMiddleware'
|
||||
|
||||
/**
|
||||
* 中间件:将CoreCompletionsRequest转换为SDK特定的参数
|
||||
* 使用上下文中ApiClient实例的requestTransformer进行转换
|
||||
*/
|
||||
export const TransformCoreToSdkParamsMiddleware: CompletionsMiddleware =
|
||||
() =>
|
||||
(next) =>
|
||||
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
|
||||
Logger.debug(`🔄 [${MIDDLEWARE_NAME}] Starting core to SDK params transformation:`, ctx)
|
||||
|
||||
const internal = ctx._internal
|
||||
|
||||
// 🔧 检测递归调用:检查 params 中是否携带了预处理的 SDK 消息
|
||||
const isRecursiveCall = internal?.toolProcessingState?.isRecursiveCall || false
|
||||
const newSdkMessages = params._internal?.newReqMessages
|
||||
|
||||
const apiClient = ctx.apiClientInstance
|
||||
|
||||
if (!apiClient) {
|
||||
Logger.error(`🔄 [${MIDDLEWARE_NAME}] ApiClient instance not found in context.`)
|
||||
throw new Error('ApiClient instance not found in context')
|
||||
}
|
||||
|
||||
// 检查是否有requestTransformer方法
|
||||
const requestTransformer = apiClient.getRequestTransformer()
|
||||
if (!requestTransformer) {
|
||||
Logger.warn(
|
||||
`🔄 [${MIDDLEWARE_NAME}] ApiClient does not have getRequestTransformer method, skipping transformation`
|
||||
)
|
||||
const result = await next(ctx, params)
|
||||
return result
|
||||
}
|
||||
|
||||
// 确保assistant和model可用,它们是transformer所需的
|
||||
const assistant = params.assistant
|
||||
const model = params.assistant.model
|
||||
|
||||
if (!assistant || !model) {
|
||||
console.error(`🔄 [${MIDDLEWARE_NAME}] Assistant or Model not found for transformation.`)
|
||||
throw new Error('Assistant or Model not found for transformation')
|
||||
}
|
||||
|
||||
try {
|
||||
const transformResult = await requestTransformer.transform(
|
||||
params,
|
||||
assistant,
|
||||
model,
|
||||
isRecursiveCall,
|
||||
newSdkMessages
|
||||
)
|
||||
|
||||
const { payload: sdkPayload, metadata } = transformResult
|
||||
|
||||
// 将SDK特定的payload和metadata存储在状态中,供下游中间件使用
|
||||
ctx._internal.sdkPayload = sdkPayload
|
||||
|
||||
if (metadata) {
|
||||
ctx._internal.customState = {
|
||||
...ctx._internal.customState,
|
||||
sdkMetadata: metadata
|
||||
}
|
||||
}
|
||||
|
||||
if (params.enableGenerateImage) {
|
||||
params.onChunk?.({
|
||||
type: ChunkType.IMAGE_CREATED
|
||||
})
|
||||
}
|
||||
return next(ctx, params)
|
||||
} catch (error) {
|
||||
Logger.error(`🔄 [${MIDDLEWARE_NAME}] Error during request transformation:`, error)
|
||||
// 让错误向上传播,或者可以在这里进行特定的错误处理
|
||||
throw error
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,76 @@
|
||||
import { ChunkType } from '@renderer/types/chunk'
|
||||
import { smartLinkConverter } from '@renderer/utils/linkConverter'
|
||||
|
||||
import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
|
||||
import { CompletionsContext, CompletionsMiddleware } from '../types'
|
||||
|
||||
export const MIDDLEWARE_NAME = 'WebSearchMiddleware'
|
||||
|
||||
/**
|
||||
* Web搜索处理中间件 - 基于GenericChunk流处理
|
||||
*
|
||||
* 职责:
|
||||
* 1. 监听和记录Web搜索事件
|
||||
* 2. 可以在此处添加Web搜索结果的后处理逻辑
|
||||
* 3. 维护Web搜索相关的状态
|
||||
*
|
||||
* 注意:Web搜索结果的识别和生成已在ApiClient的响应转换器中处理
|
||||
*/
|
||||
export const WebSearchMiddleware: CompletionsMiddleware =
|
||||
() =>
|
||||
(next) =>
|
||||
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
|
||||
ctx._internal.webSearchState = {
|
||||
results: undefined
|
||||
}
|
||||
// 调用下游中间件
|
||||
const result = await next(ctx, params)
|
||||
|
||||
const model = params.assistant?.model!
|
||||
let isFirstChunk = true
|
||||
|
||||
// 响应后处理:记录Web搜索事件
|
||||
if (result.stream) {
|
||||
const resultFromUpstream = result.stream
|
||||
|
||||
if (resultFromUpstream && resultFromUpstream instanceof ReadableStream) {
|
||||
// Web搜索状态跟踪
|
||||
const enhancedStream = (resultFromUpstream as ReadableStream<GenericChunk>).pipeThrough(
|
||||
new TransformStream<GenericChunk, GenericChunk>({
|
||||
transform(chunk: GenericChunk, controller) {
|
||||
if (chunk.type === ChunkType.TEXT_DELTA) {
|
||||
const providerType = model.provider || 'openai'
|
||||
// 使用当前可用的Web搜索结果进行链接转换
|
||||
const text = chunk.text
|
||||
const processedText = smartLinkConverter(text, providerType, isFirstChunk)
|
||||
if (isFirstChunk) {
|
||||
isFirstChunk = false
|
||||
}
|
||||
controller.enqueue({
|
||||
...chunk,
|
||||
text: processedText
|
||||
})
|
||||
} else if (chunk.type === ChunkType.LLM_WEB_SEARCH_COMPLETE) {
|
||||
// 暂存Web搜索结果用于链接完善
|
||||
ctx._internal.webSearchState!.results = chunk.llm_web_search
|
||||
|
||||
// 将Web搜索完成事件继续传递下去
|
||||
controller.enqueue(chunk)
|
||||
} else {
|
||||
controller.enqueue(chunk)
|
||||
}
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
return {
|
||||
...result,
|
||||
stream: enhancedStream
|
||||
}
|
||||
} else {
|
||||
console.log(`[${MIDDLEWARE_NAME}] No stream to process or not a ReadableStream.`)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
@ -0,0 +1,132 @@
|
||||
import { BaseApiClient } from '@renderer/aiCore/clients/BaseApiClient'
|
||||
import { isDedicatedImageGenerationModel } from '@renderer/config/models'
|
||||
import { ChunkType } from '@renderer/types/chunk'
|
||||
import { findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
|
||||
import OpenAI from 'openai'
|
||||
import { toFile } from 'openai/uploads'
|
||||
|
||||
import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
|
||||
import { CompletionsContext, CompletionsMiddleware } from '../types'
|
||||
|
||||
export const MIDDLEWARE_NAME = 'ImageGenerationMiddleware'
|
||||
|
||||
export const ImageGenerationMiddleware: CompletionsMiddleware =
|
||||
() =>
|
||||
(next) =>
|
||||
async (context: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
|
||||
const { assistant, messages } = params
|
||||
const client = context.apiClientInstance as BaseApiClient<OpenAI>
|
||||
const signal = context._internal?.flowControl?.abortSignal
|
||||
|
||||
if (!assistant.model || !isDedicatedImageGenerationModel(assistant.model) || typeof messages === 'string') {
|
||||
return next(context, params)
|
||||
}
|
||||
|
||||
const stream = new ReadableStream<GenericChunk>({
|
||||
async start(controller) {
|
||||
const enqueue = (chunk: GenericChunk) => controller.enqueue(chunk)
|
||||
|
||||
try {
|
||||
if (!assistant.model) {
|
||||
throw new Error('Assistant model is not defined.')
|
||||
}
|
||||
|
||||
const sdk = await client.getSdkInstance()
|
||||
const lastUserMessage = messages.findLast((m) => m.role === 'user')
|
||||
const lastAssistantMessage = messages.findLast((m) => m.role === 'assistant')
|
||||
|
||||
if (!lastUserMessage) {
|
||||
throw new Error('No user message found for image generation.')
|
||||
}
|
||||
|
||||
const prompt = getMainTextContent(lastUserMessage)
|
||||
let imageFiles: Blob[] = []
|
||||
|
||||
// Collect images from user message
|
||||
const userImageBlocks = findImageBlocks(lastUserMessage)
|
||||
const userImages = await Promise.all(
|
||||
userImageBlocks.map(async (block) => {
|
||||
if (!block.file) return null
|
||||
const binaryData: Uint8Array = await window.api.file.binaryImage(block.file.id)
|
||||
const mimeType = `${block.file.type}/${block.file.ext.slice(1)}`
|
||||
return await toFile(new Blob([binaryData]), block.file.origin_name || 'image.png', { type: mimeType })
|
||||
})
|
||||
)
|
||||
imageFiles = imageFiles.concat(userImages.filter(Boolean) as Blob[])
|
||||
|
||||
// Collect images from last assistant message
|
||||
if (lastAssistantMessage) {
|
||||
const assistantImageBlocks = findImageBlocks(lastAssistantMessage)
|
||||
const assistantImages = await Promise.all(
|
||||
assistantImageBlocks.map(async (block) => {
|
||||
const b64 = block.url?.replace(/^data:image\/\w+;base64,/, '')
|
||||
if (!b64) return null
|
||||
const binary = atob(b64)
|
||||
const bytes = new Uint8Array(binary.length)
|
||||
for (let i = 0; i < binary.length; i++) bytes[i] = binary.charCodeAt(i)
|
||||
return await toFile(new Blob([bytes]), 'assistant_image.png', { type: 'image/png' })
|
||||
})
|
||||
)
|
||||
imageFiles = imageFiles.concat(assistantImages.filter(Boolean) as Blob[])
|
||||
}
|
||||
|
||||
enqueue({ type: ChunkType.IMAGE_CREATED })
|
||||
|
||||
const startTime = Date.now()
|
||||
let response: OpenAI.Images.ImagesResponse
|
||||
|
||||
const options = { signal, timeout: 300_000 }
|
||||
|
||||
if (imageFiles.length > 0) {
|
||||
response = await sdk.images.edit(
|
||||
{
|
||||
model: assistant.model.id,
|
||||
image: imageFiles,
|
||||
prompt: prompt || ''
|
||||
},
|
||||
options
|
||||
)
|
||||
} else {
|
||||
response = await sdk.images.generate(
|
||||
{
|
||||
model: assistant.model.id,
|
||||
prompt: prompt || '',
|
||||
response_format: assistant.model.id.includes('gpt-image-1') ? undefined : 'b64_json'
|
||||
},
|
||||
options
|
||||
)
|
||||
}
|
||||
|
||||
const b64_json_array = response.data?.map((item) => `data:image/png;base64,${item.b64_json}`) || []
|
||||
|
||||
enqueue({
|
||||
type: ChunkType.IMAGE_COMPLETE,
|
||||
image: { type: 'base64', images: b64_json_array }
|
||||
})
|
||||
|
||||
const usage = (response as any).usage || { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 }
|
||||
|
||||
enqueue({
|
||||
type: ChunkType.LLM_RESPONSE_COMPLETE,
|
||||
response: {
|
||||
usage,
|
||||
metrics: {
|
||||
completion_tokens: usage.completion_tokens,
|
||||
time_first_token_millsec: 0,
|
||||
time_completion_millsec: Date.now() - startTime
|
||||
}
|
||||
}
|
||||
})
|
||||
} catch (error: any) {
|
||||
enqueue({ type: ChunkType.ERROR, error })
|
||||
} finally {
|
||||
controller.close()
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return {
|
||||
stream,
|
||||
getText: () => ''
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,136 @@
|
||||
import { Model } from '@renderer/types'
|
||||
import { ChunkType, TextDeltaChunk, ThinkingCompleteChunk, ThinkingDeltaChunk } from '@renderer/types/chunk'
|
||||
import { TagConfig, TagExtractor } from '@renderer/utils/tagExtraction'
|
||||
import Logger from 'electron-log/renderer'
|
||||
|
||||
import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
|
||||
import { CompletionsContext, CompletionsMiddleware } from '../types'
|
||||
|
||||
export const MIDDLEWARE_NAME = 'ThinkingTagExtractionMiddleware'
|
||||
|
||||
// 不同模型的思考标签配置
|
||||
const reasoningTags: TagConfig[] = [
|
||||
{ openingTag: '<think>', closingTag: '</think>', separator: '\n' },
|
||||
{ openingTag: '###Thinking', closingTag: '###Response', separator: '\n' }
|
||||
]
|
||||
|
||||
const getAppropriateTag = (model?: Model): TagConfig => {
|
||||
if (model?.id?.includes('qwen3')) return reasoningTags[0]
|
||||
// 可以在这里添加更多模型特定的标签配置
|
||||
return reasoningTags[0] // 默认使用 <think> 标签
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理文本流中思考标签提取的中间件
|
||||
*
|
||||
* 该中间件专门处理文本流中的思考标签内容(如 <think>...</think>)
|
||||
* 主要用于 OpenAI 等支持思考标签的 provider
|
||||
*
|
||||
* 职责:
|
||||
* 1. 从文本流中提取思考标签内容
|
||||
* 2. 将标签内的内容转换为 THINKING_DELTA chunk
|
||||
* 3. 将标签外的内容作为正常文本输出
|
||||
* 4. 处理不同模型的思考标签格式
|
||||
* 5. 在思考内容结束时生成 THINKING_COMPLETE 事件
|
||||
*/
|
||||
export const ThinkingTagExtractionMiddleware: CompletionsMiddleware =
|
||||
() =>
|
||||
(next) =>
|
||||
async (context: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
|
||||
// 调用下游中间件
|
||||
const result = await next(context, params)
|
||||
|
||||
// 响应后处理:处理思考标签提取
|
||||
if (result.stream) {
|
||||
const resultFromUpstream = result.stream as ReadableStream<GenericChunk>
|
||||
|
||||
// 检查是否有流需要处理
|
||||
if (resultFromUpstream && resultFromUpstream instanceof ReadableStream) {
|
||||
// 获取当前模型的思考标签配置
|
||||
const model = params.assistant?.model
|
||||
const reasoningTag = getAppropriateTag(model)
|
||||
|
||||
// 创建标签提取器
|
||||
const tagExtractor = new TagExtractor(reasoningTag)
|
||||
|
||||
// thinking 处理状态
|
||||
let hasThinkingContent = false
|
||||
let thinkingStartTime = 0
|
||||
|
||||
const processedStream = resultFromUpstream.pipeThrough(
|
||||
new TransformStream<GenericChunk, GenericChunk>({
|
||||
transform(chunk: GenericChunk, controller) {
|
||||
if (chunk.type === ChunkType.TEXT_DELTA) {
|
||||
const textChunk = chunk as TextDeltaChunk
|
||||
|
||||
// 使用 TagExtractor 处理文本
|
||||
const extractionResults = tagExtractor.processText(textChunk.text)
|
||||
|
||||
for (const extractionResult of extractionResults) {
|
||||
if (extractionResult.complete && extractionResult.tagContentExtracted) {
|
||||
// 生成 THINKING_COMPLETE 事件
|
||||
const thinkingCompleteChunk: ThinkingCompleteChunk = {
|
||||
type: ChunkType.THINKING_COMPLETE,
|
||||
text: extractionResult.tagContentExtracted,
|
||||
thinking_millsec: thinkingStartTime > 0 ? Date.now() - thinkingStartTime : 0
|
||||
}
|
||||
controller.enqueue(thinkingCompleteChunk)
|
||||
|
||||
// 重置思考状态
|
||||
hasThinkingContent = false
|
||||
thinkingStartTime = 0
|
||||
} else if (extractionResult.content.length > 0) {
|
||||
if (extractionResult.isTagContent) {
|
||||
// 第一次接收到思考内容时记录开始时间
|
||||
if (!hasThinkingContent) {
|
||||
hasThinkingContent = true
|
||||
thinkingStartTime = Date.now()
|
||||
}
|
||||
|
||||
const thinkingDeltaChunk: ThinkingDeltaChunk = {
|
||||
type: ChunkType.THINKING_DELTA,
|
||||
text: extractionResult.content,
|
||||
thinking_millsec: thinkingStartTime > 0 ? Date.now() - thinkingStartTime : 0
|
||||
}
|
||||
controller.enqueue(thinkingDeltaChunk)
|
||||
} else {
|
||||
// 发送清理后的文本内容
|
||||
const cleanTextChunk: TextDeltaChunk = {
|
||||
...textChunk,
|
||||
text: extractionResult.content
|
||||
}
|
||||
controller.enqueue(cleanTextChunk)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 其他类型的chunk直接传递(包括 THINKING_DELTA, THINKING_COMPLETE 等)
|
||||
controller.enqueue(chunk)
|
||||
}
|
||||
},
|
||||
flush(controller) {
|
||||
// 处理可能剩余的思考内容
|
||||
const finalResult = tagExtractor.finalize()
|
||||
if (finalResult?.tagContentExtracted) {
|
||||
const thinkingCompleteChunk: ThinkingCompleteChunk = {
|
||||
type: ChunkType.THINKING_COMPLETE,
|
||||
text: finalResult.tagContentExtracted,
|
||||
thinking_millsec: thinkingStartTime > 0 ? Date.now() - thinkingStartTime : 0
|
||||
}
|
||||
controller.enqueue(thinkingCompleteChunk)
|
||||
}
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
// 更新响应结果
|
||||
return {
|
||||
...result,
|
||||
stream: processedStream
|
||||
}
|
||||
} else {
|
||||
Logger.warn(`[${MIDDLEWARE_NAME}] No generic chunk stream to process or not a ReadableStream.`)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
@ -0,0 +1,124 @@
|
||||
import { MCPTool } from '@renderer/types'
|
||||
import { ChunkType, MCPToolCreatedChunk, TextDeltaChunk } from '@renderer/types/chunk'
|
||||
import { parseToolUse } from '@renderer/utils/mcp-tools'
|
||||
import { TagConfig, TagExtractor } from '@renderer/utils/tagExtraction'
|
||||
|
||||
import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
|
||||
import { CompletionsContext, CompletionsMiddleware } from '../types'
|
||||
|
||||
export const MIDDLEWARE_NAME = 'ToolUseExtractionMiddleware'
|
||||
|
||||
// 工具使用标签配置
|
||||
const TOOL_USE_TAG_CONFIG: TagConfig = {
|
||||
openingTag: '<tool_use>',
|
||||
closingTag: '</tool_use>',
|
||||
separator: '\n'
|
||||
}
|
||||
|
||||
/**
|
||||
* 工具使用提取中间件
|
||||
*
|
||||
* 职责:
|
||||
* 1. 从文本流中检测并提取 <tool_use></tool_use> 标签
|
||||
* 2. 解析工具调用信息并转换为 ToolUseResponse 格式
|
||||
* 3. 生成 MCP_TOOL_CREATED chunk 供 McpToolChunkMiddleware 处理
|
||||
* 4. 清理文本流,移除工具使用标签但保留正常文本
|
||||
*
|
||||
* 注意:此中间件只负责提取和转换,实际工具调用由 McpToolChunkMiddleware 处理
|
||||
*/
|
||||
export const ToolUseExtractionMiddleware: CompletionsMiddleware =
|
||||
() =>
|
||||
(next) =>
|
||||
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
|
||||
const mcpTools = params.mcpTools || []
|
||||
|
||||
// 如果没有工具,直接调用下一个中间件
|
||||
if (!mcpTools || mcpTools.length === 0) return next(ctx, params)
|
||||
|
||||
// 调用下游中间件
|
||||
const result = await next(ctx, params)
|
||||
|
||||
// 响应后处理:处理工具使用标签提取
|
||||
if (result.stream) {
|
||||
const resultFromUpstream = result.stream as ReadableStream<GenericChunk>
|
||||
|
||||
const processedStream = resultFromUpstream.pipeThrough(createToolUseExtractionTransform(ctx, mcpTools))
|
||||
|
||||
return {
|
||||
...result,
|
||||
stream: processedStream
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建工具使用提取的 TransformStream
|
||||
*/
|
||||
function createToolUseExtractionTransform(
|
||||
_ctx: CompletionsContext,
|
||||
mcpTools: MCPTool[]
|
||||
): TransformStream<GenericChunk, GenericChunk> {
|
||||
const tagExtractor = new TagExtractor(TOOL_USE_TAG_CONFIG)
|
||||
|
||||
return new TransformStream({
|
||||
async transform(chunk: GenericChunk, controller) {
|
||||
try {
|
||||
// 处理文本内容,检测工具使用标签
|
||||
if (chunk.type === ChunkType.TEXT_DELTA) {
|
||||
const textChunk = chunk as TextDeltaChunk
|
||||
const extractionResults = tagExtractor.processText(textChunk.text)
|
||||
|
||||
for (const result of extractionResults) {
|
||||
if (result.complete && result.tagContentExtracted) {
|
||||
// 提取到完整的工具使用内容,解析并转换为 SDK ToolCall 格式
|
||||
const toolUseResponses = parseToolUse(result.tagContentExtracted, mcpTools)
|
||||
|
||||
if (toolUseResponses.length > 0) {
|
||||
// 生成 MCP_TOOL_CREATED chunk,复用现有的处理流程
|
||||
const mcpToolCreatedChunk: MCPToolCreatedChunk = {
|
||||
type: ChunkType.MCP_TOOL_CREATED,
|
||||
tool_use_responses: toolUseResponses
|
||||
}
|
||||
controller.enqueue(mcpToolCreatedChunk)
|
||||
}
|
||||
} else if (!result.isTagContent && result.content) {
|
||||
// 发送标签外的正常文本内容
|
||||
const cleanTextChunk: TextDeltaChunk = {
|
||||
...textChunk,
|
||||
text: result.content
|
||||
}
|
||||
controller.enqueue(cleanTextChunk)
|
||||
}
|
||||
// 注意:标签内的内容不会作为TEXT_DELTA转发,避免重复显示
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// 转发其他所有chunk
|
||||
controller.enqueue(chunk)
|
||||
} catch (error) {
|
||||
console.error(`🔧 [${MIDDLEWARE_NAME}] Error processing chunk:`, error)
|
||||
controller.error(error)
|
||||
}
|
||||
},
|
||||
|
||||
async flush(controller) {
|
||||
// 检查是否有未完成的标签内容
|
||||
const finalResult = tagExtractor.finalize()
|
||||
if (finalResult && finalResult.tagContentExtracted) {
|
||||
const toolUseResponses = parseToolUse(finalResult.tagContentExtracted, mcpTools)
|
||||
if (toolUseResponses.length > 0) {
|
||||
const mcpToolCreatedChunk: MCPToolCreatedChunk = {
|
||||
type: ChunkType.MCP_TOOL_CREATED,
|
||||
tool_use_responses: toolUseResponses
|
||||
}
|
||||
controller.enqueue(mcpToolCreatedChunk)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
export default ToolUseExtractionMiddleware
|
||||
88
src/renderer/src/aiCore/middleware/index.ts
Normal file
88
src/renderer/src/aiCore/middleware/index.ts
Normal file
@ -0,0 +1,88 @@
|
||||
import { CompletionsMiddleware, MethodMiddleware } from './types'
|
||||
|
||||
// /**
|
||||
// * Wraps a provider instance with middlewares.
|
||||
// */
|
||||
// export function wrapProviderWithMiddleware(
|
||||
// apiClientInstance: BaseApiClient,
|
||||
// middlewareConfig: MiddlewareConfig
|
||||
// ): BaseApiClient {
|
||||
// console.log(`[wrapProviderWithMiddleware] Wrapping provider: ${apiClientInstance.provider?.id}`)
|
||||
// console.log(`[wrapProviderWithMiddleware] Middleware config:`, {
|
||||
// completions: middlewareConfig.completions?.length || 0,
|
||||
// methods: Object.keys(middlewareConfig.methods || {}).length
|
||||
// })
|
||||
|
||||
// // Cache for already wrapped methods to avoid re-wrapping on every access.
|
||||
// const wrappedMethodsCache = new Map<string, (...args: any[]) => Promise<any>>()
|
||||
|
||||
// const proxy = new Proxy(apiClientInstance, {
|
||||
// get(target, propKey, receiver) {
|
||||
// const methodName = typeof propKey === 'string' ? propKey : undefined
|
||||
|
||||
// if (!methodName) {
|
||||
// return Reflect.get(target, propKey, receiver)
|
||||
// }
|
||||
|
||||
// if (wrappedMethodsCache.has(methodName)) {
|
||||
// console.log(`[wrapProviderWithMiddleware] Using cached wrapped method: ${methodName}`)
|
||||
// return wrappedMethodsCache.get(methodName)
|
||||
// }
|
||||
|
||||
// const originalMethod = Reflect.get(target, propKey, receiver)
|
||||
|
||||
// // If the property is not a function, return it directly.
|
||||
// if (typeof originalMethod !== 'function') {
|
||||
// return originalMethod
|
||||
// }
|
||||
|
||||
// let wrappedMethod: ((...args: any[]) => Promise<any>) | undefined
|
||||
|
||||
// // Handle completions method
|
||||
// if (methodName === 'completions' && middlewareConfig.completions?.length) {
|
||||
// console.log(
|
||||
// `[wrapProviderWithMiddleware] Wrapping completions method with ${middlewareConfig.completions.length} middlewares`
|
||||
// )
|
||||
// const completionsOriginalMethod = originalMethod as (params: CompletionsParams) => Promise<any>
|
||||
// wrappedMethod = applyCompletionsMiddlewares(target, completionsOriginalMethod, middlewareConfig.completions)
|
||||
// }
|
||||
// // Handle other methods
|
||||
// else {
|
||||
// const methodMiddlewares = middlewareConfig.methods?.[methodName]
|
||||
// if (methodMiddlewares?.length) {
|
||||
// console.log(
|
||||
// `[wrapProviderWithMiddleware] Wrapping method ${methodName} with ${methodMiddlewares.length} middlewares`
|
||||
// )
|
||||
// const genericOriginalMethod = originalMethod as (...args: any[]) => Promise<any>
|
||||
// wrappedMethod = applyMethodMiddlewares(target, methodName, genericOriginalMethod, methodMiddlewares)
|
||||
// }
|
||||
// }
|
||||
|
||||
// if (wrappedMethod) {
|
||||
// console.log(`[wrapProviderWithMiddleware] Successfully wrapped method: ${methodName}`)
|
||||
// wrappedMethodsCache.set(methodName, wrappedMethod)
|
||||
// return wrappedMethod
|
||||
// }
|
||||
|
||||
// // If no middlewares are configured for this method, return the original method bound to the target. /
|
||||
// // 如果没有为此方法配置中间件,则返回绑定到目标的原始方法。
|
||||
// console.log(`[wrapProviderWithMiddleware] No middlewares for method ${methodName}, returning original`)
|
||||
// return originalMethod.bind(target)
|
||||
// }
|
||||
// })
|
||||
// return proxy as BaseApiClient
|
||||
// }
|
||||
|
||||
// Export types for external use
|
||||
export type { CompletionsMiddleware, MethodMiddleware }
|
||||
|
||||
// Export MiddlewareBuilder related types and classes
|
||||
export {
|
||||
CompletionsMiddlewareBuilder,
|
||||
createCompletionsBuilder,
|
||||
createMethodBuilder,
|
||||
MethodMiddlewareBuilder,
|
||||
MiddlewareBuilder,
|
||||
type MiddlewareExecutor,
|
||||
type NamedMiddleware
|
||||
} from './builder'
|
||||
149
src/renderer/src/aiCore/middleware/register.ts
Normal file
149
src/renderer/src/aiCore/middleware/register.ts
Normal file
@ -0,0 +1,149 @@
|
||||
import * as AbortHandlerModule from './common/AbortHandlerMiddleware'
|
||||
import * as ErrorHandlerModule from './common/ErrorHandlerMiddleware'
|
||||
import * as FinalChunkConsumerModule from './common/FinalChunkConsumerMiddleware'
|
||||
import * as LoggingModule from './common/LoggingMiddleware'
|
||||
import * as McpToolChunkModule from './core/McpToolChunkMiddleware'
|
||||
import * as RawStreamListenerModule from './core/RawStreamListenerMiddleware'
|
||||
import * as ResponseTransformModule from './core/ResponseTransformMiddleware'
|
||||
// import * as SdkCallModule from './core/SdkCallMiddleware'
|
||||
import * as StreamAdapterModule from './core/StreamAdapterMiddleware'
|
||||
import * as TextChunkModule from './core/TextChunkMiddleware'
|
||||
import * as ThinkChunkModule from './core/ThinkChunkMiddleware'
|
||||
import * as TransformCoreToSdkParamsModule from './core/TransformCoreToSdkParamsMiddleware'
|
||||
import * as WebSearchModule from './core/WebSearchMiddleware'
|
||||
import * as ImageGenerationModule from './feat/ImageGenerationMiddleware'
|
||||
import * as ThinkingTagExtractionModule from './feat/ThinkingTagExtractionMiddleware'
|
||||
import * as ToolUseExtractionMiddleware from './feat/ToolUseExtractionMiddleware'
|
||||
|
||||
/**
|
||||
* 中间件注册表 - 提供所有可用中间件的集中访问
|
||||
* 注意:目前中间件文件还未导出 MIDDLEWARE_NAME,会有 linter 错误,这是正常的
|
||||
*/
|
||||
export const MiddlewareRegistry = {
|
||||
[ErrorHandlerModule.MIDDLEWARE_NAME]: {
|
||||
name: ErrorHandlerModule.MIDDLEWARE_NAME,
|
||||
middleware: ErrorHandlerModule.ErrorHandlerMiddleware
|
||||
},
|
||||
// 通用中间件
|
||||
[AbortHandlerModule.MIDDLEWARE_NAME]: {
|
||||
name: AbortHandlerModule.MIDDLEWARE_NAME,
|
||||
middleware: AbortHandlerModule.AbortHandlerMiddleware
|
||||
},
|
||||
[FinalChunkConsumerModule.MIDDLEWARE_NAME]: {
|
||||
name: FinalChunkConsumerModule.MIDDLEWARE_NAME,
|
||||
middleware: FinalChunkConsumerModule.default
|
||||
},
|
||||
|
||||
// 核心流程中间件
|
||||
[TransformCoreToSdkParamsModule.MIDDLEWARE_NAME]: {
|
||||
name: TransformCoreToSdkParamsModule.MIDDLEWARE_NAME,
|
||||
middleware: TransformCoreToSdkParamsModule.TransformCoreToSdkParamsMiddleware
|
||||
},
|
||||
// [SdkCallModule.MIDDLEWARE_NAME]: {
|
||||
// name: SdkCallModule.MIDDLEWARE_NAME,
|
||||
// middleware: SdkCallModule.SdkCallMiddleware
|
||||
// },
|
||||
[StreamAdapterModule.MIDDLEWARE_NAME]: {
|
||||
name: StreamAdapterModule.MIDDLEWARE_NAME,
|
||||
middleware: StreamAdapterModule.StreamAdapterMiddleware
|
||||
},
|
||||
[RawStreamListenerModule.MIDDLEWARE_NAME]: {
|
||||
name: RawStreamListenerModule.MIDDLEWARE_NAME,
|
||||
middleware: RawStreamListenerModule.RawStreamListenerMiddleware
|
||||
},
|
||||
[ResponseTransformModule.MIDDLEWARE_NAME]: {
|
||||
name: ResponseTransformModule.MIDDLEWARE_NAME,
|
||||
middleware: ResponseTransformModule.ResponseTransformMiddleware
|
||||
},
|
||||
|
||||
// 特性处理中间件
|
||||
[ThinkingTagExtractionModule.MIDDLEWARE_NAME]: {
|
||||
name: ThinkingTagExtractionModule.MIDDLEWARE_NAME,
|
||||
middleware: ThinkingTagExtractionModule.ThinkingTagExtractionMiddleware
|
||||
},
|
||||
[ToolUseExtractionMiddleware.MIDDLEWARE_NAME]: {
|
||||
name: ToolUseExtractionMiddleware.MIDDLEWARE_NAME,
|
||||
middleware: ToolUseExtractionMiddleware.ToolUseExtractionMiddleware
|
||||
},
|
||||
[ThinkChunkModule.MIDDLEWARE_NAME]: {
|
||||
name: ThinkChunkModule.MIDDLEWARE_NAME,
|
||||
middleware: ThinkChunkModule.ThinkChunkMiddleware
|
||||
},
|
||||
[McpToolChunkModule.MIDDLEWARE_NAME]: {
|
||||
name: McpToolChunkModule.MIDDLEWARE_NAME,
|
||||
middleware: McpToolChunkModule.McpToolChunkMiddleware
|
||||
},
|
||||
[WebSearchModule.MIDDLEWARE_NAME]: {
|
||||
name: WebSearchModule.MIDDLEWARE_NAME,
|
||||
middleware: WebSearchModule.WebSearchMiddleware
|
||||
},
|
||||
[TextChunkModule.MIDDLEWARE_NAME]: {
|
||||
name: TextChunkModule.MIDDLEWARE_NAME,
|
||||
middleware: TextChunkModule.TextChunkMiddleware
|
||||
},
|
||||
[ImageGenerationModule.MIDDLEWARE_NAME]: {
|
||||
name: ImageGenerationModule.MIDDLEWARE_NAME,
|
||||
middleware: ImageGenerationModule.ImageGenerationMiddleware
|
||||
}
|
||||
} as const
|
||||
|
||||
/**
|
||||
* 根据名称获取中间件
|
||||
* @param name - 中间件名称
|
||||
* @returns 对应的中间件信息
|
||||
*/
|
||||
export function getMiddleware(name: string) {
|
||||
return MiddlewareRegistry[name]
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取所有注册的中间件名称
|
||||
* @returns 中间件名称列表
|
||||
*/
|
||||
export function getRegisteredMiddlewareNames(): string[] {
|
||||
return Object.keys(MiddlewareRegistry)
|
||||
}
|
||||
|
||||
/**
|
||||
* 默认的 Completions 中间件配置 - NamedMiddleware 格式,用于 MiddlewareBuilder
|
||||
*/
|
||||
export const DefaultCompletionsNamedMiddlewares = [
|
||||
MiddlewareRegistry[FinalChunkConsumerModule.MIDDLEWARE_NAME], // 最终消费者
|
||||
MiddlewareRegistry[ErrorHandlerModule.MIDDLEWARE_NAME], // 错误处理
|
||||
MiddlewareRegistry[TransformCoreToSdkParamsModule.MIDDLEWARE_NAME], // 参数转换
|
||||
MiddlewareRegistry[AbortHandlerModule.MIDDLEWARE_NAME], // 中止处理
|
||||
MiddlewareRegistry[McpToolChunkModule.MIDDLEWARE_NAME], // 工具处理
|
||||
MiddlewareRegistry[TextChunkModule.MIDDLEWARE_NAME], // 文本处理
|
||||
MiddlewareRegistry[WebSearchModule.MIDDLEWARE_NAME], // Web搜索处理
|
||||
MiddlewareRegistry[ToolUseExtractionMiddleware.MIDDLEWARE_NAME], // 工具使用提取处理
|
||||
MiddlewareRegistry[ThinkingTagExtractionModule.MIDDLEWARE_NAME], // 思考标签提取处理(特定provider)
|
||||
MiddlewareRegistry[ThinkChunkModule.MIDDLEWARE_NAME], // 思考处理(通用SDK)
|
||||
MiddlewareRegistry[ResponseTransformModule.MIDDLEWARE_NAME], // 响应转换
|
||||
MiddlewareRegistry[StreamAdapterModule.MIDDLEWARE_NAME], // 流适配器
|
||||
MiddlewareRegistry[RawStreamListenerModule.MIDDLEWARE_NAME] // 原始流监听器
|
||||
]
|
||||
|
||||
/**
|
||||
* 默认的通用方法中间件 - 例如翻译、摘要等
|
||||
*/
|
||||
export const DefaultMethodMiddlewares = {
|
||||
translate: [LoggingModule.createGenericLoggingMiddleware()],
|
||||
summaries: [LoggingModule.createGenericLoggingMiddleware()]
|
||||
}
|
||||
|
||||
/**
|
||||
* 导出所有中间件模块,方便外部使用
|
||||
*/
|
||||
export {
|
||||
AbortHandlerModule,
|
||||
FinalChunkConsumerModule,
|
||||
LoggingModule,
|
||||
McpToolChunkModule,
|
||||
ResponseTransformModule,
|
||||
StreamAdapterModule,
|
||||
TextChunkModule,
|
||||
ThinkChunkModule,
|
||||
ThinkingTagExtractionModule,
|
||||
TransformCoreToSdkParamsModule,
|
||||
WebSearchModule
|
||||
}
|
||||
77
src/renderer/src/aiCore/middleware/schemas.ts
Normal file
77
src/renderer/src/aiCore/middleware/schemas.ts
Normal file
@ -0,0 +1,77 @@
|
||||
import { Assistant, MCPTool } from '@renderer/types'
|
||||
import { Chunk } from '@renderer/types/chunk'
|
||||
import { Message } from '@renderer/types/newMessage'
|
||||
import { SdkRawChunk, SdkRawOutput } from '@renderer/types/sdk'
|
||||
|
||||
import { ProcessingState } from './types'
|
||||
|
||||
// ============================================================================
|
||||
// Core Request Types - 核心请求结构
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* 标准化的内部核心请求结构,用于所有AI Provider的统一处理
|
||||
* 这是应用层参数转换后的标准格式,不包含回调函数和控制逻辑
|
||||
*/
|
||||
export interface CompletionsParams {
|
||||
/**
|
||||
* 调用的业务场景类型,用于中间件判断是否执行
|
||||
* 'chat': 主要对话流程
|
||||
* 'translate': 翻译
|
||||
* 'summary': 摘要
|
||||
* 'search': 搜索摘要
|
||||
* 'generate': 生成
|
||||
* 'check': API检查
|
||||
*/
|
||||
callType?: 'chat' | 'translate' | 'summary' | 'search' | 'generate' | 'check'
|
||||
|
||||
// 基础对话数据
|
||||
messages: Message[] | string // 联合类型方便判断是否为空
|
||||
|
||||
assistant: Assistant // 助手为基本单位
|
||||
// model: Model
|
||||
|
||||
onChunk?: (chunk: Chunk) => void
|
||||
onResponse?: (text: string, isComplete: boolean) => void
|
||||
|
||||
// 错误相关
|
||||
onError?: (error: Error) => void
|
||||
shouldThrow?: boolean
|
||||
|
||||
// 工具相关
|
||||
mcpTools?: MCPTool[]
|
||||
|
||||
// 生成参数
|
||||
temperature?: number
|
||||
topP?: number
|
||||
maxTokens?: number
|
||||
|
||||
// 功能开关
|
||||
streamOutput: boolean
|
||||
enableWebSearch?: boolean
|
||||
enableReasoning?: boolean
|
||||
enableGenerateImage?: boolean
|
||||
|
||||
// 上下文控制
|
||||
contextCount?: number
|
||||
|
||||
_internal?: ProcessingState
|
||||
}
|
||||
|
||||
export interface CompletionsResult {
|
||||
rawOutput?: SdkRawOutput
|
||||
stream?: ReadableStream<SdkRawChunk> | ReadableStream<Chunk> | AsyncIterable<Chunk>
|
||||
controller?: AbortController
|
||||
|
||||
getText: () => string
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Generic Chunk Types - 通用数据块结构
|
||||
// ============================================================================
|
||||
|
||||
/**
|
||||
* 通用数据块类型
|
||||
* 复用现有的 Chunk 类型,这是所有AI Provider都应该输出的标准化数据块格式
|
||||
*/
|
||||
export type GenericChunk = Chunk
|
||||
166
src/renderer/src/aiCore/middleware/types.ts
Normal file
166
src/renderer/src/aiCore/middleware/types.ts
Normal file
@ -0,0 +1,166 @@
|
||||
import { MCPToolResponse, Metrics, Usage, WebSearchResponse } from '@renderer/types'
|
||||
import { Chunk, ErrorChunk } from '@renderer/types/chunk'
|
||||
import {
|
||||
SdkInstance,
|
||||
SdkMessageParam,
|
||||
SdkParams,
|
||||
SdkRawChunk,
|
||||
SdkRawOutput,
|
||||
SdkTool,
|
||||
SdkToolCall
|
||||
} from '@renderer/types/sdk'
|
||||
|
||||
import { BaseApiClient } from '../clients'
|
||||
import { CompletionsParams, CompletionsResult } from './schemas'
|
||||
|
||||
/**
|
||||
* Symbol to uniquely identify middleware context objects.
|
||||
*/
|
||||
export const MIDDLEWARE_CONTEXT_SYMBOL = Symbol.for('AiProviderMiddlewareContext')
|
||||
|
||||
/**
|
||||
* Defines the structure for the onChunk callback function.
|
||||
*/
|
||||
export type OnChunkFunction = (chunk: Chunk | ErrorChunk) => void
|
||||
|
||||
/**
|
||||
* Base context that carries information about the current method call.
|
||||
*/
|
||||
export interface BaseContext {
|
||||
[MIDDLEWARE_CONTEXT_SYMBOL]: true
|
||||
methodName: string
|
||||
originalArgs: Readonly<any[]>
|
||||
}
|
||||
|
||||
/**
|
||||
* Processing state shared between middlewares.
|
||||
*/
|
||||
export interface ProcessingState<
|
||||
TParams extends SdkParams = SdkParams,
|
||||
TMessageParam extends SdkMessageParam = SdkMessageParam,
|
||||
TToolCall extends SdkToolCall = SdkToolCall
|
||||
> {
|
||||
sdkPayload?: TParams
|
||||
newReqMessages?: TMessageParam[]
|
||||
observer?: {
|
||||
usage?: Usage
|
||||
metrics?: Metrics
|
||||
}
|
||||
toolProcessingState?: {
|
||||
pendingToolCalls?: Array<TToolCall>
|
||||
executingToolCalls?: Array<{
|
||||
sdkToolCall: TToolCall
|
||||
mcpToolResponse: MCPToolResponse
|
||||
}>
|
||||
output?: SdkRawOutput | string
|
||||
isRecursiveCall?: boolean
|
||||
recursionDepth?: number
|
||||
}
|
||||
webSearchState?: {
|
||||
results?: WebSearchResponse
|
||||
}
|
||||
flowControl?: {
|
||||
abortController?: AbortController
|
||||
abortSignal?: AbortSignal
|
||||
cleanup?: () => void
|
||||
}
|
||||
enhancedDispatch?: (context: CompletionsContext, params: CompletionsParams) => Promise<CompletionsResult>
|
||||
customState?: Record<string, any>
|
||||
}
|
||||
|
||||
/**
|
||||
* Extended context for completions method.
|
||||
*/
|
||||
export interface CompletionsContext<
|
||||
TSdkParams extends SdkParams = SdkParams,
|
||||
TSdkMessageParam extends SdkMessageParam = SdkMessageParam,
|
||||
TSdkToolCall extends SdkToolCall = SdkToolCall,
|
||||
TSdkInstance extends SdkInstance = SdkInstance,
|
||||
TRawOutput extends SdkRawOutput = SdkRawOutput,
|
||||
TRawChunk extends SdkRawChunk = SdkRawChunk,
|
||||
TSdkSpecificTool extends SdkTool = SdkTool
|
||||
> extends BaseContext {
|
||||
readonly methodName: 'completions' // 强制方法名为 'completions'
|
||||
|
||||
apiClientInstance: BaseApiClient<
|
||||
TSdkInstance,
|
||||
TSdkParams,
|
||||
TRawOutput,
|
||||
TRawChunk,
|
||||
TSdkMessageParam,
|
||||
TSdkToolCall,
|
||||
TSdkSpecificTool
|
||||
>
|
||||
|
||||
// --- Mutable internal state for the duration of the middleware chain ---
|
||||
_internal: ProcessingState<TSdkParams, TSdkMessageParam, TSdkToolCall> // 包含所有可变的处理状态
|
||||
}
|
||||
|
||||
export interface MiddlewareAPI<Ctx extends BaseContext = BaseContext, Args extends any[] = any[]> {
|
||||
getContext: () => Ctx // Function to get the current context / 获取当前上下文的函数
|
||||
getOriginalArgs: () => Args // Function to get the original arguments of the method call / 获取方法调用原始参数的函数
|
||||
}
|
||||
|
||||
/**
|
||||
* Base middleware type.
|
||||
*/
|
||||
export type Middleware<TContext extends BaseContext> = (
|
||||
api: MiddlewareAPI<TContext>
|
||||
) => (
|
||||
next: (context: TContext, args: any[]) => Promise<unknown>
|
||||
) => (context: TContext, args: any[]) => Promise<unknown>
|
||||
|
||||
export type MethodMiddleware = Middleware<BaseContext>
|
||||
|
||||
/**
|
||||
* Completions middleware type.
|
||||
*/
|
||||
export type CompletionsMiddleware<
|
||||
TSdkParams extends SdkParams = SdkParams,
|
||||
TSdkMessageParam extends SdkMessageParam = SdkMessageParam,
|
||||
TSdkToolCall extends SdkToolCall = SdkToolCall,
|
||||
TSdkInstance extends SdkInstance = SdkInstance,
|
||||
TRawOutput extends SdkRawOutput = SdkRawOutput,
|
||||
TRawChunk extends SdkRawChunk = SdkRawChunk,
|
||||
TSdkSpecificTool extends SdkTool = SdkTool
|
||||
> = (
|
||||
api: MiddlewareAPI<
|
||||
CompletionsContext<
|
||||
TSdkParams,
|
||||
TSdkMessageParam,
|
||||
TSdkToolCall,
|
||||
TSdkInstance,
|
||||
TRawOutput,
|
||||
TRawChunk,
|
||||
TSdkSpecificTool
|
||||
>,
|
||||
[CompletionsParams]
|
||||
>
|
||||
) => (
|
||||
next: (
|
||||
context: CompletionsContext<
|
||||
TSdkParams,
|
||||
TSdkMessageParam,
|
||||
TSdkToolCall,
|
||||
TSdkInstance,
|
||||
TRawOutput,
|
||||
TRawChunk,
|
||||
TSdkSpecificTool
|
||||
>,
|
||||
params: CompletionsParams
|
||||
) => Promise<CompletionsResult>
|
||||
) => (
|
||||
context: CompletionsContext<
|
||||
TSdkParams,
|
||||
TSdkMessageParam,
|
||||
TSdkToolCall,
|
||||
TSdkInstance,
|
||||
TRawOutput,
|
||||
TRawChunk,
|
||||
TSdkSpecificTool
|
||||
>,
|
||||
params: CompletionsParams
|
||||
) => Promise<CompletionsResult>
|
||||
|
||||
// Re-export for convenience
|
||||
export type { Chunk as OnChunkArg } from '@renderer/types/chunk'
|
||||
57
src/renderer/src/aiCore/middleware/utils.ts
Normal file
57
src/renderer/src/aiCore/middleware/utils.ts
Normal file
@ -0,0 +1,57 @@
|
||||
import { ChunkType, ErrorChunk } from '@renderer/types/chunk'
|
||||
|
||||
/**
|
||||
* Creates an ErrorChunk object with a standardized structure.
|
||||
* @param error The error object or message.
|
||||
* @param chunkType The type of chunk, defaults to ChunkType.ERROR.
|
||||
* @returns An ErrorChunk object.
|
||||
*/
|
||||
export function createErrorChunk(error: any, chunkType: ChunkType = ChunkType.ERROR): ErrorChunk {
|
||||
let errorDetails: Record<string, any> = {}
|
||||
|
||||
if (error instanceof Error) {
|
||||
errorDetails = {
|
||||
message: error.message,
|
||||
name: error.name,
|
||||
stack: error.stack
|
||||
}
|
||||
} else if (typeof error === 'string') {
|
||||
errorDetails = { message: error }
|
||||
} else if (typeof error === 'object' && error !== null) {
|
||||
errorDetails = Object.getOwnPropertyNames(error).reduce(
|
||||
(acc, key) => {
|
||||
acc[key] = error[key]
|
||||
return acc
|
||||
},
|
||||
{} as Record<string, any>
|
||||
)
|
||||
if (!errorDetails.message && error.toString && typeof error.toString === 'function') {
|
||||
const errMsg = error.toString()
|
||||
if (errMsg !== '[object Object]') {
|
||||
errorDetails.message = errMsg
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
type: chunkType,
|
||||
error: errorDetails
|
||||
} as ErrorChunk
|
||||
}
|
||||
|
||||
// Helper to capitalize method names for hook construction
|
||||
export function capitalize(str: string): string {
|
||||
if (!str) return ''
|
||||
return str.charAt(0).toUpperCase() + str.slice(1)
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查对象是否实现了AsyncIterable接口
|
||||
*/
|
||||
export function isAsyncIterable<T = unknown>(obj: unknown): obj is AsyncIterable<T> {
|
||||
return (
|
||||
obj !== null &&
|
||||
typeof obj === 'object' &&
|
||||
typeof (obj as Record<symbol, unknown>)[Symbol.asyncIterator] === 'function'
|
||||
)
|
||||
}
|
||||
@ -143,7 +143,7 @@ import YiModelLogoDark from '@renderer/assets/images/models/yi_dark.png'
|
||||
import YoudaoLogo from '@renderer/assets/images/providers/netease-youdao.svg'
|
||||
import NomicLogo from '@renderer/assets/images/providers/nomic.png'
|
||||
import { getProviderByModel } from '@renderer/services/AssistantService'
|
||||
import { Assistant, Model } from '@renderer/types'
|
||||
import { Model } from '@renderer/types'
|
||||
import OpenAI from 'openai'
|
||||
|
||||
import { WEB_SEARCH_PROMPT_FOR_OPENROUTER } from './prompts'
|
||||
@ -199,6 +199,11 @@ export const VISION_REGEX = new RegExp(
|
||||
'i'
|
||||
)
|
||||
|
||||
// For middleware to identify models that must use the dedicated Image API
|
||||
export const DEDICATED_IMAGE_MODELS = ['grok-2-image', 'dall-e-3', 'dall-e-2', 'gpt-image-1']
|
||||
export const isDedicatedImageGenerationModel = (model: Model): boolean =>
|
||||
DEDICATED_IMAGE_MODELS.filter((m) => model.id.includes(m)).length > 0
|
||||
|
||||
// Text to image models
|
||||
export const TEXT_TO_IMAGE_REGEX = /flux|diffusion|stabilityai|sd-|dall|cogview|janus/i
|
||||
|
||||
@ -2246,14 +2251,24 @@ export const TEXT_TO_IMAGES_MODELS_SUPPORT_IMAGE_ENHANCEMENT = [
|
||||
'stabilityai/stable-diffusion-xl-base-1.0'
|
||||
]
|
||||
|
||||
export const SUPPORTED_DISABLE_GENERATION_MODELS = [
|
||||
'gemini-2.0-flash-exp',
|
||||
'gpt-4o',
|
||||
'gpt-4o-mini',
|
||||
'gpt-4.1',
|
||||
'gpt-4.1-mini',
|
||||
'gpt-4.1-nano',
|
||||
'o3'
|
||||
]
|
||||
|
||||
export const GENERATE_IMAGE_MODELS = [
|
||||
'gemini-2.0-flash-exp-image-generation',
|
||||
'gemini-2.0-flash-preview-image-generation',
|
||||
'gemini-2.0-flash-exp',
|
||||
'grok-2-image-1212',
|
||||
'grok-2-image',
|
||||
'grok-2-image-latest',
|
||||
'gpt-image-1'
|
||||
'gpt-image-1',
|
||||
...SUPPORTED_DISABLE_GENERATION_MODELS
|
||||
]
|
||||
|
||||
export const GEMINI_SEARCH_MODELS = [
|
||||
@ -2362,10 +2377,32 @@ export function isSupportedReasoningEffortOpenAIModel(model: Model): boolean {
|
||||
)
|
||||
}
|
||||
|
||||
export function isOpenAIWebSearch(model: Model): boolean {
|
||||
export function isOpenAIChatCompletionOnlyModel(model: Model): boolean {
|
||||
if (!model) {
|
||||
return false
|
||||
}
|
||||
|
||||
return (
|
||||
model.id.includes('gpt-4o-search-preview') ||
|
||||
model.id.includes('gpt-4o-mini-search-preview') ||
|
||||
model.id.includes('o1-mini') ||
|
||||
model.id.includes('o1-preview')
|
||||
)
|
||||
}
|
||||
|
||||
export function isOpenAIWebSearchChatCompletionOnlyModel(model: Model): boolean {
|
||||
return model.id.includes('gpt-4o-search-preview') || model.id.includes('gpt-4o-mini-search-preview')
|
||||
}
|
||||
|
||||
export function isOpenAIWebSearchModel(model: Model): boolean {
|
||||
return (
|
||||
model.id.includes('gpt-4o-search-preview') ||
|
||||
model.id.includes('gpt-4o-mini-search-preview') ||
|
||||
(model.id.includes('gpt-4.1') && !model.id.includes('gpt-4.1-nano')) ||
|
||||
(model.id.includes('gpt-4o') && !model.id.includes('gpt-4o-image'))
|
||||
)
|
||||
}
|
||||
|
||||
export function isSupportedThinkingTokenModel(model?: Model): boolean {
|
||||
if (!model) {
|
||||
return false
|
||||
@ -2506,7 +2543,7 @@ export function isNotSupportTemperatureAndTopP(model: Model): boolean {
|
||||
return true
|
||||
}
|
||||
|
||||
if (isOpenAIReasoningModel(model) || isOpenAIWebSearch(model)) {
|
||||
if (isOpenAIReasoningModel(model) || isOpenAIChatCompletionOnlyModel(model)) {
|
||||
return true
|
||||
}
|
||||
|
||||
@ -2536,17 +2573,13 @@ export function isWebSearchModel(model: Model): boolean {
|
||||
return false
|
||||
}
|
||||
|
||||
// 不管哪个供应商都判断了
|
||||
if (model.id.includes('claude')) {
|
||||
return CLAUDE_SUPPORTED_WEBSEARCH_REGEX.test(model.id)
|
||||
}
|
||||
|
||||
if (provider.type === 'openai-response') {
|
||||
if (
|
||||
isOpenAILLMModel(model) &&
|
||||
!isTextToImageModel(model) &&
|
||||
!isOpenAIReasoningModel(model) &&
|
||||
!GENERATE_IMAGE_MODELS.includes(model.id)
|
||||
) {
|
||||
if (isOpenAIWebSearchModel(model)) {
|
||||
return true
|
||||
}
|
||||
|
||||
@ -2558,12 +2591,7 @@ export function isWebSearchModel(model: Model): boolean {
|
||||
}
|
||||
|
||||
if (provider.id === 'aihubmix') {
|
||||
if (
|
||||
isOpenAILLMModel(model) &&
|
||||
!isTextToImageModel(model) &&
|
||||
!isOpenAIReasoningModel(model) &&
|
||||
!GENERATE_IMAGE_MODELS.includes(model.id)
|
||||
) {
|
||||
if (isOpenAIWebSearchModel(model)) {
|
||||
return true
|
||||
}
|
||||
|
||||
@ -2572,7 +2600,7 @@ export function isWebSearchModel(model: Model): boolean {
|
||||
}
|
||||
|
||||
if (provider?.type === 'openai') {
|
||||
if (GEMINI_SEARCH_MODELS.includes(model?.id) || isOpenAIWebSearch(model)) {
|
||||
if (GEMINI_SEARCH_MODELS.includes(model?.id) || isOpenAIWebSearchModel(model)) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
@ -2606,6 +2634,20 @@ export function isWebSearchModel(model: Model): boolean {
|
||||
return false
|
||||
}
|
||||
|
||||
export function isOpenRouterBuiltInWebSearchModel(model: Model): boolean {
|
||||
if (!model) {
|
||||
return false
|
||||
}
|
||||
|
||||
const provider = getProviderByModel(model)
|
||||
|
||||
if (provider.id !== 'openrouter') {
|
||||
return false
|
||||
}
|
||||
|
||||
return isOpenAIWebSearchModel(model) || model.id.includes('sonar')
|
||||
}
|
||||
|
||||
export function isGenerateImageModel(model: Model): boolean {
|
||||
if (!model) {
|
||||
return false
|
||||
@ -2628,56 +2670,60 @@ export function isGenerateImageModel(model: Model): boolean {
|
||||
return false
|
||||
}
|
||||
|
||||
export function getOpenAIWebSearchParams(assistant: Assistant, model: Model): Record<string, any> {
|
||||
if (isWebSearchModel(model)) {
|
||||
if (assistant.enableWebSearch) {
|
||||
const webSearchTools = getWebSearchTools(model)
|
||||
export function isSupportedDisableGenerationModel(model: Model): boolean {
|
||||
if (!model) {
|
||||
return false
|
||||
}
|
||||
|
||||
if (model.provider === 'grok') {
|
||||
return {
|
||||
search_parameters: {
|
||||
mode: 'auto',
|
||||
return_citations: true,
|
||||
sources: [{ type: 'web' }, { type: 'x' }, { type: 'news' }]
|
||||
}
|
||||
}
|
||||
}
|
||||
return SUPPORTED_DISABLE_GENERATION_MODELS.includes(model.id)
|
||||
}
|
||||
|
||||
if (model.provider === 'hunyuan') {
|
||||
return { enable_enhancement: true, citation: true, search_info: true }
|
||||
}
|
||||
export function getOpenAIWebSearchParams(model: Model, isEnableWebSearch?: boolean): Record<string, any> {
|
||||
if (!isEnableWebSearch) {
|
||||
return {}
|
||||
}
|
||||
|
||||
if (model.provider === 'dashscope') {
|
||||
return {
|
||||
enable_search: true,
|
||||
search_options: {
|
||||
forced_search: true
|
||||
}
|
||||
}
|
||||
}
|
||||
const webSearchTools = getWebSearchTools(model)
|
||||
|
||||
if (model.provider === 'openrouter') {
|
||||
return {
|
||||
plugins: [{ id: 'web', search_prompts: WEB_SEARCH_PROMPT_FOR_OPENROUTER }]
|
||||
}
|
||||
}
|
||||
|
||||
if (isOpenAIWebSearch(model)) {
|
||||
return {
|
||||
web_search_options: {}
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
tools: webSearchTools
|
||||
}
|
||||
} else {
|
||||
if (model.provider === 'hunyuan') {
|
||||
return { enable_enhancement: false }
|
||||
if (model.provider === 'grok') {
|
||||
return {
|
||||
search_parameters: {
|
||||
mode: 'auto',
|
||||
return_citations: true,
|
||||
sources: [{ type: 'web' }, { type: 'x' }, { type: 'news' }]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (model.provider === 'hunyuan') {
|
||||
return { enable_enhancement: true, citation: true, search_info: true }
|
||||
}
|
||||
|
||||
if (model.provider === 'dashscope') {
|
||||
return {
|
||||
enable_search: true,
|
||||
search_options: {
|
||||
forced_search: true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (isOpenAIWebSearchChatCompletionOnlyModel(model)) {
|
||||
return {
|
||||
web_search_options: {}
|
||||
}
|
||||
}
|
||||
|
||||
if (model.provider === 'openrouter') {
|
||||
return {
|
||||
plugins: [{ id: 'web', search_prompts: WEB_SEARCH_PROMPT_FOR_OPENROUTER }]
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
tools: webSearchTools
|
||||
}
|
||||
|
||||
return {}
|
||||
}
|
||||
|
||||
|
||||
@ -1,118 +0,0 @@
|
||||
// Modified from https://github.com/vercel/ai/blob/845080d80b8538bb9c7e527d2213acb5f33ac9c2/packages/ai/core/middleware/extract-reasoning-middleware.ts
|
||||
|
||||
import { getPotentialStartIndex } from '../utils/getPotentialIndex'
|
||||
|
||||
export interface ExtractReasoningMiddlewareOptions {
|
||||
openingTag: string
|
||||
closingTag: string
|
||||
separator?: string
|
||||
enableReasoning?: boolean
|
||||
}
|
||||
|
||||
function escapeRegExp(str: string) {
|
||||
return str.replace(/[.*+?^${}()|[\\]\\]/g, '\\$&')
|
||||
}
|
||||
|
||||
// 支持泛型 T,默认 T = { type: string; textDelta: string }
|
||||
export function extractReasoningMiddleware<
|
||||
T extends { type: string } & (
|
||||
| { type: 'text-delta' | 'reasoning'; textDelta: string }
|
||||
| { type: string } // 其他类型
|
||||
) = { type: string; textDelta: string }
|
||||
>({ openingTag, closingTag, separator = '\n', enableReasoning }: ExtractReasoningMiddlewareOptions) {
|
||||
const openingTagEscaped = escapeRegExp(openingTag)
|
||||
const closingTagEscaped = escapeRegExp(closingTag)
|
||||
|
||||
return {
|
||||
wrapGenerate: async ({ doGenerate }: { doGenerate: () => Promise<{ text: string } & Record<string, any>> }) => {
|
||||
const { text: rawText, ...rest } = await doGenerate()
|
||||
if (rawText == null) {
|
||||
return { text: rawText, ...rest }
|
||||
}
|
||||
const text = rawText
|
||||
const regexp = new RegExp(`${openingTagEscaped}(.*?)${closingTagEscaped}`, 'gs')
|
||||
const matches = Array.from(text.matchAll(regexp))
|
||||
if (!matches.length) {
|
||||
return { text, ...rest }
|
||||
}
|
||||
const reasoning = matches.map((match: RegExpMatchArray) => match[1]).join(separator)
|
||||
let textWithoutReasoning = text
|
||||
for (let i = matches.length - 1; i >= 0; i--) {
|
||||
const match = matches[i] as RegExpMatchArray
|
||||
const beforeMatch = textWithoutReasoning.slice(0, match.index as number)
|
||||
const afterMatch = textWithoutReasoning.slice((match.index as number) + match[0].length)
|
||||
textWithoutReasoning =
|
||||
beforeMatch + (beforeMatch.length > 0 && afterMatch.length > 0 ? separator : '') + afterMatch
|
||||
}
|
||||
return { ...rest, text: textWithoutReasoning, reasoning }
|
||||
},
|
||||
wrapStream: async ({
|
||||
doStream
|
||||
}: {
|
||||
doStream: () => Promise<{ stream: ReadableStream<T> } & Record<string, any>>
|
||||
}) => {
|
||||
const { stream, ...rest } = await doStream()
|
||||
if (!enableReasoning) {
|
||||
return {
|
||||
stream,
|
||||
...rest
|
||||
}
|
||||
}
|
||||
let isFirstReasoning = true
|
||||
let isFirstText = true
|
||||
let afterSwitch = false
|
||||
let isReasoning = false
|
||||
let buffer = ''
|
||||
return {
|
||||
stream: stream.pipeThrough(
|
||||
new TransformStream<T, T>({
|
||||
transform: (chunk, controller) => {
|
||||
if (chunk.type !== 'text-delta') {
|
||||
controller.enqueue(chunk)
|
||||
return
|
||||
}
|
||||
// textDelta 只在 text-delta/reasoning chunk 上
|
||||
buffer += (chunk as { textDelta: string }).textDelta
|
||||
function publish(text: string) {
|
||||
if (text.length > 0) {
|
||||
const prefix = afterSwitch && (isReasoning ? !isFirstReasoning : !isFirstText) ? separator : ''
|
||||
controller.enqueue({
|
||||
...chunk,
|
||||
type: isReasoning ? 'reasoning' : 'text-delta',
|
||||
textDelta: prefix + text
|
||||
} as T)
|
||||
afterSwitch = false
|
||||
if (isReasoning) {
|
||||
isFirstReasoning = false
|
||||
} else {
|
||||
isFirstText = false
|
||||
}
|
||||
}
|
||||
}
|
||||
while (true) {
|
||||
const nextTag = isReasoning ? closingTag : openingTag
|
||||
const startIndex = getPotentialStartIndex(buffer, nextTag)
|
||||
if (startIndex == null) {
|
||||
publish(buffer)
|
||||
buffer = ''
|
||||
break
|
||||
}
|
||||
publish(buffer.slice(0, startIndex))
|
||||
const foundFullMatch = startIndex + nextTag.length <= buffer.length
|
||||
if (foundFullMatch) {
|
||||
buffer = buffer.slice(startIndex + nextTag.length)
|
||||
isReasoning = !isReasoning
|
||||
afterSwitch = true
|
||||
} else {
|
||||
buffer = buffer.slice(startIndex)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
),
|
||||
...rest
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -4,6 +4,7 @@ import TranslateButton from '@renderer/components/TranslateButton'
|
||||
import Logger from '@renderer/config/logger'
|
||||
import {
|
||||
isGenerateImageModel,
|
||||
isSupportedDisableGenerationModel,
|
||||
isSupportedReasoningEffortModel,
|
||||
isSupportedThinkingTokenModel,
|
||||
isVisionModel,
|
||||
@ -727,7 +728,7 @@ const Inputbar: FC<Props> = ({ assistant: _assistant, setActiveTopic, topic }) =
|
||||
if (!isGenerateImageModel(model) && assistant.enableGenerateImage) {
|
||||
updateAssistant({ ...assistant, enableGenerateImage: false })
|
||||
}
|
||||
if (isGenerateImageModel(model) && !assistant.enableGenerateImage && model.id !== 'gemini-2.0-flash-exp') {
|
||||
if (isGenerateImageModel(model) && !assistant.enableGenerateImage && !isSupportedDisableGenerationModel(model)) {
|
||||
updateAssistant({ ...assistant, enableGenerateImage: true })
|
||||
}
|
||||
}, [assistant, model, updateAssistant])
|
||||
|
||||
@ -40,7 +40,18 @@ function CitationBlock({ block }: { block: CitationMessageBlock }) {
|
||||
__html:
|
||||
(block.response?.results as GroundingMetadata)?.searchEntryPoint?.renderedContent
|
||||
?.replace(/@media \(prefers-color-scheme: light\)/g, 'body[theme-mode="light"]')
|
||||
.replace(/@media \(prefers-color-scheme: dark\)/g, 'body[theme-mode="dark"]') || ''
|
||||
.replace(/@media \(prefers-color-scheme: dark\)/g, 'body[theme-mode="dark"]')
|
||||
.replace(
|
||||
/background-color\s*:\s*#[0-9a-fA-F]{3,6}\b|\bbackground-color\s*:\s*[a-zA-Z-]+\b/g,
|
||||
'background-color: var(--color-background-soft)'
|
||||
)
|
||||
.replace(/\.gradient\s*{[^}]*background\s*:\s*[^};]+[;}]/g, (match) => {
|
||||
// Remove the background property while preserving the rest
|
||||
return match.replace(/background\s*:\s*[^};]+;?\s*/g, '')
|
||||
})
|
||||
.replace(/\.chip {\n/g, '.chip {\n background-color: var(--color-background)!important;\n')
|
||||
.replace(/border-color\s*:\s*[^};]+;?\s*/g, '')
|
||||
.replace(/border\s*:\s*[^};]+;?\s*/g, '') || ''
|
||||
}}
|
||||
/>
|
||||
</>
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import SvgSpinners180Ring from '@renderer/components/Icons/SvgSpinners180Ring'
|
||||
import ImageViewer from '@renderer/components/ImageViewer'
|
||||
import type { ImageMessageBlock } from '@renderer/types/newMessage'
|
||||
import { type ImageMessageBlock, MessageBlockStatus } from '@renderer/types/newMessage'
|
||||
import React from 'react'
|
||||
import styled from 'styled-components'
|
||||
|
||||
@ -9,23 +9,28 @@ interface Props {
|
||||
}
|
||||
|
||||
const ImageBlock: React.FC<Props> = ({ block }) => {
|
||||
if (block.status !== 'success') return <SvgSpinners180Ring />
|
||||
const images = block.metadata?.generateImageResponse?.images?.length
|
||||
? block.metadata?.generateImageResponse?.images
|
||||
: block?.file?.path
|
||||
? [`file://${block?.file?.path}`]
|
||||
: []
|
||||
return (
|
||||
<Container style={{ marginBottom: 8 }}>
|
||||
{images.map((src, index) => (
|
||||
<ImageViewer
|
||||
src={src}
|
||||
key={`image-${index}`}
|
||||
style={{ maxWidth: 500, maxHeight: 500, padding: 5, borderRadius: 8 }}
|
||||
/>
|
||||
))}
|
||||
</Container>
|
||||
)
|
||||
if (block.status === MessageBlockStatus.STREAMING || block.status === MessageBlockStatus.PROCESSING)
|
||||
return <SvgSpinners180Ring />
|
||||
if (block.status === MessageBlockStatus.SUCCESS) {
|
||||
const images = block.metadata?.generateImageResponse?.images?.length
|
||||
? block.metadata?.generateImageResponse?.images
|
||||
: block?.file?.path
|
||||
? [`file://${block?.file?.path}`]
|
||||
: []
|
||||
return (
|
||||
<Container style={{ marginBottom: 8 }}>
|
||||
{images.map((src, index) => (
|
||||
<ImageViewer
|
||||
src={src}
|
||||
key={`image-${index}`}
|
||||
style={{ maxWidth: 500, maxHeight: 500, padding: 5, borderRadius: 8 }}
|
||||
/>
|
||||
))}
|
||||
</Container>
|
||||
)
|
||||
} else {
|
||||
return <></>
|
||||
}
|
||||
}
|
||||
const Container = styled.div`
|
||||
display: flex;
|
||||
|
||||
@ -1,124 +0,0 @@
|
||||
import SvgSpinners180Ring from '@renderer/components/Icons/SvgSpinners180Ring'
|
||||
import { fetchSuggestions } from '@renderer/services/ApiService'
|
||||
import { getUserMessage } from '@renderer/services/MessagesService'
|
||||
import { useAppDispatch } from '@renderer/store'
|
||||
import { sendMessage } from '@renderer/store/thunk/messageThunk'
|
||||
import { Assistant, Suggestion } from '@renderer/types'
|
||||
import type { Message } from '@renderer/types/newMessage'
|
||||
import { last } from 'lodash'
|
||||
import { FC, memo, useEffect, useState } from 'react'
|
||||
import styled from 'styled-components'
|
||||
|
||||
interface Props {
|
||||
assistant: Assistant
|
||||
messages: Message[]
|
||||
}
|
||||
|
||||
const suggestionsMap = new Map<string, Suggestion[]>()
|
||||
|
||||
const Suggestions: FC<Props> = ({ assistant, messages }) => {
|
||||
const dispatch = useAppDispatch()
|
||||
|
||||
const [suggestions, setSuggestions] = useState<Suggestion[]>(
|
||||
suggestionsMap.get(messages[messages.length - 1]?.id) || []
|
||||
)
|
||||
const [loadingSuggestions, setLoadingSuggestions] = useState(false)
|
||||
|
||||
const handleSuggestionClick = async (content: string) => {
|
||||
const { message: userMessage, blocks } = getUserMessage({
|
||||
assistant,
|
||||
topic: assistant.topics[0],
|
||||
content
|
||||
})
|
||||
|
||||
await dispatch(sendMessage(userMessage, blocks, assistant, assistant.topics[0].id))
|
||||
}
|
||||
|
||||
const suggestionsHandle = async () => {
|
||||
if (loadingSuggestions) return
|
||||
try {
|
||||
setLoadingSuggestions(true)
|
||||
const _suggestions = await fetchSuggestions({
|
||||
assistant,
|
||||
messages
|
||||
})
|
||||
if (_suggestions.length) {
|
||||
setSuggestions(_suggestions)
|
||||
suggestionsMap.set(messages[messages.length - 1].id, _suggestions)
|
||||
}
|
||||
} finally {
|
||||
setLoadingSuggestions(false)
|
||||
}
|
||||
}
|
||||
|
||||
useEffect(() => {
|
||||
suggestionsHandle()
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [])
|
||||
|
||||
useEffect(() => {
|
||||
setSuggestions(suggestionsMap.get(messages[messages.length - 1]?.id) || [])
|
||||
}, [messages])
|
||||
|
||||
if (last(messages)?.status !== 'success') {
|
||||
return null
|
||||
}
|
||||
if (loadingSuggestions) {
|
||||
return (
|
||||
<Container>
|
||||
<SvgSpinners180Ring color="var(--color-text-2)" />
|
||||
</Container>
|
||||
)
|
||||
}
|
||||
|
||||
if (suggestions.length === 0) {
|
||||
return null
|
||||
}
|
||||
|
||||
return (
|
||||
<Container>
|
||||
<SuggestionsContainer>
|
||||
{suggestions.map((s, i) => (
|
||||
<SuggestionItem key={i} onClick={() => handleSuggestionClick(s.content)}>
|
||||
{s.content} →
|
||||
</SuggestionItem>
|
||||
))}
|
||||
</SuggestionsContainer>
|
||||
</Container>
|
||||
)
|
||||
}
|
||||
|
||||
const Container = styled.div`
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
padding: 10px 10px 20px 65px;
|
||||
display: flex;
|
||||
width: 100%;
|
||||
flex-direction: row;
|
||||
flex-wrap: wrap;
|
||||
gap: 15px;
|
||||
`
|
||||
|
||||
const SuggestionsContainer = styled.div`
|
||||
display: flex;
|
||||
flex-direction: row;
|
||||
flex-wrap: wrap;
|
||||
gap: 10px;
|
||||
`
|
||||
|
||||
const SuggestionItem = styled.div`
|
||||
display: flex;
|
||||
align-items: center;
|
||||
width: fit-content;
|
||||
padding: 5px 10px;
|
||||
border-radius: 12px;
|
||||
font-size: 12px;
|
||||
color: var(--color-text);
|
||||
background: var(--color-background-mute);
|
||||
cursor: pointer;
|
||||
&:hover {
|
||||
opacity: 0.9;
|
||||
}
|
||||
`
|
||||
|
||||
export default memo(Suggestions)
|
||||
@ -1,3 +1,4 @@
|
||||
import AiProvider from '@renderer/aiCore'
|
||||
import { TopView } from '@renderer/components/TopView'
|
||||
import { DEFAULT_KNOWLEDGE_DOCUMENT_COUNT } from '@renderer/config/constant'
|
||||
import { isEmbeddingModel, isRerankModel } from '@renderer/config/models'
|
||||
@ -6,7 +7,6 @@ import { NOT_SUPPORTED_REANK_PROVIDERS } from '@renderer/config/providers'
|
||||
import { useKnowledgeBases } from '@renderer/hooks/useKnowledge'
|
||||
import { useProviders } from '@renderer/hooks/useProvider'
|
||||
import { SettingHelpText } from '@renderer/pages/settings'
|
||||
import AiProvider from '@renderer/providers/AiProvider'
|
||||
import { getKnowledgeBaseParams } from '@renderer/services/KnowledgeService'
|
||||
import { getModelUniqId } from '@renderer/services/ModelService'
|
||||
import { KnowledgeBase, Model } from '@renderer/types'
|
||||
|
||||
@ -11,7 +11,7 @@ import { usePaintings } from '@renderer/hooks/usePaintings'
|
||||
import { useAllProviders } from '@renderer/hooks/useProvider'
|
||||
import { useRuntime } from '@renderer/hooks/useRuntime'
|
||||
import { useSettings } from '@renderer/hooks/useSettings'
|
||||
import AiProvider from '@renderer/providers/AiProvider'
|
||||
import AiProvider from '@renderer/aiCore'
|
||||
import FileManager from '@renderer/services/FileManager'
|
||||
import { translateText } from '@renderer/services/TranslateService'
|
||||
import { useAppDispatch } from '@renderer/store'
|
||||
@ -182,11 +182,9 @@ const AihubmixPage: FC<{ Options: string[] }> = ({ Options }) => {
|
||||
const base64s = await AI.generateImage({
|
||||
prompt,
|
||||
model: painting.model,
|
||||
config: {
|
||||
aspectRatio: painting.aspectRatio?.replace('ASPECT_', '').replace('_', ':'),
|
||||
numberOfImages: painting.model.startsWith('imagen-4.0-ultra-generate-exp') ? 1 : painting.numberOfImages,
|
||||
personGeneration: painting.personGeneration
|
||||
}
|
||||
imageSize: painting.aspectRatio?.replace('ASPECT_', '').replace('_', ':') || '1:1',
|
||||
batchSize: painting.model.startsWith('imagen-4.0-ultra-generate-exp') ? 1 : painting.numberOfImages || 1,
|
||||
personGeneration: painting.personGeneration
|
||||
})
|
||||
if (base64s?.length > 0) {
|
||||
const validFiles = await Promise.all(
|
||||
|
||||
@ -16,7 +16,7 @@ import { usePaintings } from '@renderer/hooks/usePaintings'
|
||||
import { useAllProviders } from '@renderer/hooks/useProvider'
|
||||
import { useRuntime } from '@renderer/hooks/useRuntime'
|
||||
import { useSettings } from '@renderer/hooks/useSettings'
|
||||
import AiProvider from '@renderer/providers/AiProvider'
|
||||
import AiProvider from '@renderer/aiCore'
|
||||
import { getProviderByModel } from '@renderer/services/AssistantService'
|
||||
import FileManager from '@renderer/services/FileManager'
|
||||
import { translateText } from '@renderer/services/TranslateService'
|
||||
|
||||
@ -51,8 +51,8 @@ const PopupContainer: React.FC<Props> = ({ title, provider, model, apiKeys, type
|
||||
try {
|
||||
let valid = false
|
||||
if (type === 'provider' && model) {
|
||||
const result = await checkApi({ ...(provider as Provider), apiKey: status.key }, model)
|
||||
valid = result.valid
|
||||
await checkApi({ ...(provider as Provider), apiKey: status.key }, model)
|
||||
valid = true
|
||||
} else {
|
||||
const result = await WebSearchService.checkSearch({
|
||||
...(provider as WebSearchProvider),
|
||||
@ -65,7 +65,7 @@ const PopupContainer: React.FC<Props> = ({ title, provider, model, apiKeys, type
|
||||
setKeyStatuses((prev) => prev.map((s, idx) => (idx === i ? { ...s, checking: false, isValid: valid } : s)))
|
||||
|
||||
return { index: i, valid }
|
||||
} catch (error) {
|
||||
} catch (error: unknown) {
|
||||
// 处理错误情况
|
||||
setKeyStatuses((prev) => prev.map((s, idx) => (idx === i ? { ...s, checking: false, isValid: false } : s)))
|
||||
return { index: i, valid: false }
|
||||
@ -90,8 +90,8 @@ const PopupContainer: React.FC<Props> = ({ title, provider, model, apiKeys, type
|
||||
try {
|
||||
let valid = false
|
||||
if (type === 'provider' && model) {
|
||||
const result = await checkApi({ ...(provider as Provider), apiKey: keyStatuses[keyIndex].key }, model)
|
||||
valid = result.valid
|
||||
await checkApi({ ...(provider as Provider), apiKey: keyStatuses[keyIndex].key }, model)
|
||||
valid = true
|
||||
} else {
|
||||
const result = await WebSearchService.checkSearch({
|
||||
...(provider as WebSearchProvider),
|
||||
@ -103,7 +103,7 @@ const PopupContainer: React.FC<Props> = ({ title, provider, model, apiKeys, type
|
||||
setKeyStatuses((prev) =>
|
||||
prev.map((status, idx) => (idx === keyIndex ? { ...status, checking: false, isValid: valid } : status))
|
||||
)
|
||||
} catch (error) {
|
||||
} catch (error: unknown) {
|
||||
setKeyStatuses((prev) =>
|
||||
prev.map((status, idx) => (idx === keyIndex ? { ...status, checking: false, isValid: false } : status))
|
||||
)
|
||||
|
||||
@ -145,14 +145,17 @@ const PopupContainer: React.FC<Props> = ({ provider: _provider, resolve }) => {
|
||||
setListModels(
|
||||
models
|
||||
.map((model) => ({
|
||||
id: model.id,
|
||||
// @ts-ignore modelId
|
||||
id: model?.id || model?.name,
|
||||
// @ts-ignore name
|
||||
name: model.name || model.id,
|
||||
name: model?.display_name || model?.displayName || model?.name || model?.id,
|
||||
provider: _provider.id,
|
||||
group: getDefaultGroupName(model.id, _provider.id),
|
||||
// @ts-ignore name
|
||||
description: model?.description,
|
||||
owned_by: model?.owned_by
|
||||
// @ts-ignore group
|
||||
group: getDefaultGroupName(model?.id || model?.name, _provider.id),
|
||||
// @ts-ignore description
|
||||
description: model?.description || '',
|
||||
// @ts-ignore owned_by
|
||||
owned_by: model?.owned_by || ''
|
||||
}))
|
||||
.filter((model) => !isEmpty(model.name))
|
||||
)
|
||||
|
||||
@ -7,7 +7,7 @@ import { PROVIDER_CONFIG } from '@renderer/config/providers'
|
||||
import { useTheme } from '@renderer/context/ThemeProvider'
|
||||
import { useAllProviders, useProvider, useProviders } from '@renderer/hooks/useProvider'
|
||||
import i18n from '@renderer/i18n'
|
||||
import { isOpenAIProvider } from '@renderer/providers/AiProvider/ProviderFactory'
|
||||
import { isOpenAIProvider } from '@renderer/aiCore/clients/ApiClientFactory'
|
||||
import { checkApi, formatApiKeys } from '@renderer/services/ApiService'
|
||||
import { checkModelsHealth, getModelCheckSummary } from '@renderer/services/HealthCheckService'
|
||||
import { isProviderSupportAuth } from '@renderer/services/ProviderService'
|
||||
@ -231,22 +231,32 @@ const ProviderSetting: FC<Props> = ({ provider: _provider }) => {
|
||||
} else {
|
||||
setApiChecking(true)
|
||||
|
||||
const { valid, error } = await checkApi({ ...provider, apiKey, apiHost }, model)
|
||||
try {
|
||||
await checkApi({ ...provider, apiKey, apiHost }, model)
|
||||
|
||||
const errorMessage = error && error?.message ? ' ' + error?.message : ''
|
||||
window.message.success({
|
||||
key: 'api-check',
|
||||
style: { marginTop: '3vh' },
|
||||
duration: 2,
|
||||
content: i18n.t('message.api.connection.success')
|
||||
})
|
||||
|
||||
window.message[valid ? 'success' : 'error']({
|
||||
key: 'api-check',
|
||||
style: { marginTop: '3vh' },
|
||||
duration: valid ? 2 : 8,
|
||||
content: valid
|
||||
? i18n.t('message.api.connection.success')
|
||||
: i18n.t('message.api.connection.failed') + errorMessage
|
||||
})
|
||||
setApiValid(true)
|
||||
setTimeout(() => setApiValid(false), 3000)
|
||||
} catch (error: any) {
|
||||
const errorMessage = error?.message ? ' ' + error.message : ''
|
||||
|
||||
setApiValid(valid)
|
||||
setApiChecking(false)
|
||||
setTimeout(() => setApiValid(false), 3000)
|
||||
window.message.error({
|
||||
key: 'api-check',
|
||||
style: { marginTop: '3vh' },
|
||||
duration: 8,
|
||||
content: i18n.t('message.api.connection.failed') + errorMessage
|
||||
})
|
||||
|
||||
setApiValid(false)
|
||||
} finally {
|
||||
setApiChecking(false)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -1,117 +0,0 @@
|
||||
import { isOpenAILLMModel } from '@renderer/config/models'
|
||||
import { getDefaultModel } from '@renderer/services/AssistantService'
|
||||
import { Assistant, MCPCallToolResponse, MCPTool, MCPToolResponse, Model, Provider, Suggestion } from '@renderer/types'
|
||||
import { Message } from '@renderer/types/newMessage'
|
||||
import OpenAI from 'openai'
|
||||
|
||||
import { CompletionsParams } from '.'
|
||||
import AnthropicProvider from './AnthropicProvider'
|
||||
import BaseProvider from './BaseProvider'
|
||||
import GeminiProvider from './GeminiProvider'
|
||||
import OpenAIProvider from './OpenAIProvider'
|
||||
import OpenAIResponseProvider from './OpenAIResponseProvider'
|
||||
|
||||
/**
|
||||
* AihubmixProvider - 根据模型类型自动选择合适的提供商
|
||||
* 使用装饰器模式实现
|
||||
*/
|
||||
export default class AihubmixProvider extends BaseProvider {
|
||||
private providers: Map<string, BaseProvider> = new Map()
|
||||
private defaultProvider: BaseProvider
|
||||
private currentProvider: BaseProvider
|
||||
|
||||
constructor(provider: Provider) {
|
||||
super(provider)
|
||||
|
||||
// 初始化各个提供商
|
||||
this.providers.set('claude', new AnthropicProvider(provider))
|
||||
this.providers.set('gemini', new GeminiProvider({ ...provider, apiHost: 'https://aihubmix.com/gemini' }))
|
||||
this.providers.set('openai', new OpenAIResponseProvider(provider))
|
||||
this.providers.set('default', new OpenAIProvider(provider))
|
||||
|
||||
// 设置默认提供商
|
||||
this.defaultProvider = this.providers.get('default')!
|
||||
this.currentProvider = this.defaultProvider
|
||||
}
|
||||
|
||||
/**
|
||||
* 根据模型获取合适的提供商
|
||||
*/
|
||||
private getProvider(model: Model): BaseProvider {
|
||||
const id = model.id.toLowerCase()
|
||||
// claude开头
|
||||
if (id.startsWith('claude')) {
|
||||
return this.providers.get('claude')!
|
||||
}
|
||||
// gemini开头 或 imagen开头 且不以-nothink、-search结尾
|
||||
if ((id.startsWith('gemini') || id.startsWith('imagen')) && !id.endsWith('-nothink') && !id.endsWith('-search')) {
|
||||
return this.providers.get('gemini')!
|
||||
}
|
||||
if (isOpenAILLMModel(model)) {
|
||||
return this.providers.get('openai')!
|
||||
}
|
||||
|
||||
return this.defaultProvider
|
||||
}
|
||||
|
||||
// 直接使用默认提供商的方法
|
||||
public async models(): Promise<OpenAI.Models.Model[]> {
|
||||
return this.defaultProvider.models()
|
||||
}
|
||||
|
||||
public async generateText(params: { prompt: string; content: string }): Promise<string> {
|
||||
return this.defaultProvider.generateText(params)
|
||||
}
|
||||
|
||||
public async generateImage(params: any): Promise<string[]> {
|
||||
return this.getProvider({
|
||||
id: params.model
|
||||
} as unknown as Model).generateImage(params)
|
||||
}
|
||||
|
||||
public async generateImageByChat(params: any): Promise<void> {
|
||||
return this.defaultProvider.generateImageByChat(params)
|
||||
}
|
||||
|
||||
public async completions(params: CompletionsParams): Promise<void> {
|
||||
const model = params.assistant.model
|
||||
this.currentProvider = this.getProvider(model!)
|
||||
return this.currentProvider.completions(params)
|
||||
}
|
||||
|
||||
public async translate(
|
||||
content: string,
|
||||
assistant: Assistant,
|
||||
onResponse?: (text: string, isComplete: boolean) => void
|
||||
): Promise<string> {
|
||||
return this.getProvider(assistant.model || getDefaultModel()).translate(content, assistant, onResponse)
|
||||
}
|
||||
|
||||
public async summaries(messages: Message[], assistant: Assistant): Promise<string> {
|
||||
return this.getProvider(assistant.model || getDefaultModel()).summaries(messages, assistant)
|
||||
}
|
||||
|
||||
public async summaryForSearch(messages: Message[], assistant: Assistant): Promise<string | null> {
|
||||
return this.getProvider(assistant.model || getDefaultModel()).summaryForSearch(messages, assistant)
|
||||
}
|
||||
|
||||
public async suggestions(messages: Message[], assistant: Assistant): Promise<Suggestion[]> {
|
||||
return this.getProvider(assistant.model || getDefaultModel()).suggestions(messages, assistant)
|
||||
}
|
||||
|
||||
public async check(model: Model, stream: boolean = false): Promise<{ valid: boolean; error: Error | null }> {
|
||||
return this.getProvider(model).check(model, stream)
|
||||
}
|
||||
|
||||
public async getEmbeddingDimensions(model: Model): Promise<number> {
|
||||
return this.getProvider(model).getEmbeddingDimensions(model)
|
||||
}
|
||||
|
||||
public convertMcpTools<T>(mcpTools: MCPTool[]) {
|
||||
return this.currentProvider.convertMcpTools(mcpTools) as T[]
|
||||
}
|
||||
|
||||
public mcpToolCallResponseToMessage(mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) {
|
||||
return this.currentProvider.mcpToolCallResponseToMessage(mcpToolResponse, resp, model)
|
||||
}
|
||||
}
|
||||
@ -1,802 +0,0 @@
|
||||
import Anthropic from '@anthropic-ai/sdk'
|
||||
import {
|
||||
Base64ImageSource,
|
||||
ImageBlockParam,
|
||||
MessageCreateParamsNonStreaming,
|
||||
MessageParam,
|
||||
TextBlockParam,
|
||||
ToolResultBlockParam,
|
||||
ToolUnion,
|
||||
ToolUseBlock,
|
||||
WebSearchResultBlock,
|
||||
WebSearchTool20250305,
|
||||
WebSearchToolResultError
|
||||
} from '@anthropic-ai/sdk/resources'
|
||||
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
|
||||
import { findTokenLimit, isClaudeReasoningModel, isReasoningModel, isWebSearchModel } from '@renderer/config/models'
|
||||
import { getStoreSetting } from '@renderer/hooks/useSettings'
|
||||
import i18n from '@renderer/i18n'
|
||||
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
|
||||
import FileManager from '@renderer/services/FileManager'
|
||||
import {
|
||||
filterContextMessages,
|
||||
filterEmptyMessages,
|
||||
filterUserRoleStartMessages
|
||||
} from '@renderer/services/MessagesService'
|
||||
import {
|
||||
Assistant,
|
||||
EFFORT_RATIO,
|
||||
FileTypes,
|
||||
MCPCallToolResponse,
|
||||
MCPTool,
|
||||
MCPToolResponse,
|
||||
Metrics,
|
||||
Model,
|
||||
Provider,
|
||||
Suggestion,
|
||||
ToolCallResponse,
|
||||
Usage,
|
||||
WebSearchSource
|
||||
} from '@renderer/types'
|
||||
import { ChunkType } from '@renderer/types/chunk'
|
||||
import type { Message } from '@renderer/types/newMessage'
|
||||
import { removeSpecialCharactersForTopicName } from '@renderer/utils'
|
||||
import {
|
||||
anthropicToolUseToMcpTool,
|
||||
isEnabledToolUse,
|
||||
mcpToolCallResponseToAnthropicMessage,
|
||||
mcpToolsToAnthropicTools,
|
||||
parseAndCallTools
|
||||
} from '@renderer/utils/mcp-tools'
|
||||
import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
|
||||
import { buildSystemPrompt } from '@renderer/utils/prompt'
|
||||
import { first, flatten, takeRight } from 'lodash'
|
||||
import OpenAI from 'openai'
|
||||
|
||||
import { CompletionsParams } from '.'
|
||||
import BaseProvider from './BaseProvider'
|
||||
|
||||
interface ReasoningConfig {
|
||||
type: 'enabled' | 'disabled'
|
||||
budget_tokens?: number
|
||||
}
|
||||
|
||||
export default class AnthropicProvider extends BaseProvider {
|
||||
private sdk: Anthropic
|
||||
|
||||
constructor(provider: Provider) {
|
||||
super(provider)
|
||||
this.sdk = new Anthropic({
|
||||
apiKey: this.apiKey,
|
||||
baseURL: this.getBaseURL(),
|
||||
dangerouslyAllowBrowser: true,
|
||||
defaultHeaders: {
|
||||
'anthropic-beta': 'output-128k-2025-02-19'
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
public getBaseURL(): string {
|
||||
return this.provider.apiHost
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the message parameter
|
||||
* @param message - The message
|
||||
* @returns The message parameter
|
||||
*/
|
||||
private async getMessageParam(message: Message): Promise<MessageParam> {
|
||||
const parts: MessageParam['content'] = [
|
||||
{
|
||||
type: 'text',
|
||||
text: getMainTextContent(message)
|
||||
}
|
||||
]
|
||||
|
||||
// Get and process image blocks
|
||||
const imageBlocks = findImageBlocks(message)
|
||||
for (const imageBlock of imageBlocks) {
|
||||
if (imageBlock.file) {
|
||||
// Handle uploaded file
|
||||
const file = imageBlock.file
|
||||
const base64Data = await window.api.file.base64Image(file.id + file.ext)
|
||||
parts.push({
|
||||
type: 'image',
|
||||
source: {
|
||||
data: base64Data.base64,
|
||||
media_type: base64Data.mime.replace('jpg', 'jpeg') as any,
|
||||
type: 'base64'
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
// Get and process file blocks
|
||||
const fileBlocks = findFileBlocks(message)
|
||||
for (const fileBlock of fileBlocks) {
|
||||
const { file } = fileBlock
|
||||
if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) {
|
||||
if (file.ext === '.pdf' && file.size < 32 * 1024 * 1024) {
|
||||
const base64Data = await FileManager.readBase64File(file)
|
||||
parts.push({
|
||||
type: 'document',
|
||||
source: {
|
||||
type: 'base64',
|
||||
media_type: 'application/pdf',
|
||||
data: base64Data
|
||||
}
|
||||
})
|
||||
} else {
|
||||
const fileContent = await (await window.api.file.read(file.id + file.ext)).trim()
|
||||
parts.push({
|
||||
type: 'text',
|
||||
text: file.origin_name + '\n' + fileContent
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
role: message.role === 'system' ? 'user' : message.role,
|
||||
content: parts
|
||||
}
|
||||
}
|
||||
|
||||
private async getWebSearchParams(model: Model): Promise<WebSearchTool20250305 | undefined> {
|
||||
if (!isWebSearchModel(model)) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
return {
|
||||
type: 'web_search_20250305',
|
||||
name: 'web_search',
|
||||
max_uses: 5
|
||||
} as WebSearchTool20250305
|
||||
}
|
||||
|
||||
override getTemperature(assistant: Assistant, model: Model): number | undefined {
|
||||
if (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model)) {
|
||||
return undefined
|
||||
}
|
||||
return assistant.settings?.temperature
|
||||
}
|
||||
|
||||
override getTopP(assistant: Assistant, model: Model): number | undefined {
|
||||
if (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model)) {
|
||||
return undefined
|
||||
}
|
||||
return assistant.settings?.topP
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the reasoning effort
|
||||
* @param assistant - The assistant
|
||||
* @param model - The model
|
||||
* @returns The reasoning effort
|
||||
*/
|
||||
private getBudgetToken(assistant: Assistant, model: Model): ReasoningConfig | undefined {
|
||||
if (!isReasoningModel(model)) {
|
||||
return undefined
|
||||
}
|
||||
const { maxTokens } = getAssistantSettings(assistant)
|
||||
|
||||
const reasoningEffort = assistant?.settings?.reasoning_effort
|
||||
|
||||
if (reasoningEffort === undefined) {
|
||||
return {
|
||||
type: 'disabled'
|
||||
}
|
||||
}
|
||||
|
||||
const effortRatio = EFFORT_RATIO[reasoningEffort]
|
||||
|
||||
const budgetTokens = Math.max(
|
||||
1024,
|
||||
Math.floor(
|
||||
Math.min(
|
||||
(findTokenLimit(model.id)?.max! - findTokenLimit(model.id)?.min!) * effortRatio +
|
||||
findTokenLimit(model.id)?.min!,
|
||||
(maxTokens || DEFAULT_MAX_TOKENS) * effortRatio
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
return {
|
||||
type: 'enabled',
|
||||
budget_tokens: budgetTokens
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate completions
|
||||
* @param messages - The messages
|
||||
* @param assistant - The assistant
|
||||
* @param mcpTools - The MCP tools
|
||||
* @param onChunk - The onChunk callback
|
||||
* @param onFilterMessages - The onFilterMessages callback
|
||||
*/
|
||||
public async completions({ messages, assistant, mcpTools, onChunk, onFilterMessages }: CompletionsParams) {
|
||||
const defaultModel = getDefaultModel()
|
||||
const model = assistant.model || defaultModel
|
||||
const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant)
|
||||
|
||||
const userMessagesParams: MessageParam[] = []
|
||||
|
||||
const _messages = filterUserRoleStartMessages(
|
||||
filterContextMessages(filterEmptyMessages(takeRight(messages, contextCount + 2)))
|
||||
)
|
||||
|
||||
onFilterMessages(_messages)
|
||||
|
||||
for (const message of _messages) {
|
||||
userMessagesParams.push(await this.getMessageParam(message))
|
||||
}
|
||||
|
||||
const userMessages = flatten(userMessagesParams)
|
||||
const lastUserMessage = _messages.findLast((m) => m.role === 'user')
|
||||
|
||||
let systemPrompt = assistant.prompt
|
||||
|
||||
const { tools } = this.setupToolsConfig<ToolUnion>({
|
||||
model,
|
||||
mcpTools,
|
||||
enableToolUse: isEnabledToolUse(assistant)
|
||||
})
|
||||
|
||||
if (this.useSystemPromptForTools && mcpTools && mcpTools.length) {
|
||||
systemPrompt = await buildSystemPrompt(systemPrompt, mcpTools)
|
||||
}
|
||||
|
||||
let systemMessage: TextBlockParam | undefined = undefined
|
||||
if (systemPrompt) {
|
||||
systemMessage = {
|
||||
type: 'text',
|
||||
text: systemPrompt
|
||||
}
|
||||
}
|
||||
|
||||
const isEnabledBuiltinWebSearch = assistant.enableWebSearch && isWebSearchModel(model)
|
||||
|
||||
if (isEnabledBuiltinWebSearch) {
|
||||
const webSearchTool = await this.getWebSearchParams(model)
|
||||
if (webSearchTool) {
|
||||
tools.push(webSearchTool)
|
||||
}
|
||||
}
|
||||
|
||||
const body: MessageCreateParamsNonStreaming = {
|
||||
model: model.id,
|
||||
messages: userMessages,
|
||||
max_tokens: maxTokens || DEFAULT_MAX_TOKENS,
|
||||
temperature: this.getTemperature(assistant, model),
|
||||
top_p: this.getTopP(assistant, model),
|
||||
system: systemMessage ? [systemMessage] : undefined,
|
||||
// @ts-ignore thinking
|
||||
thinking: this.getBudgetToken(assistant, model),
|
||||
tools: tools,
|
||||
...this.getCustomParameters(assistant)
|
||||
}
|
||||
|
||||
const { abortController, cleanup } = this.createAbortController(lastUserMessage?.id)
|
||||
const { signal } = abortController
|
||||
|
||||
const finalUsage: Usage = {
|
||||
completion_tokens: 0,
|
||||
prompt_tokens: 0,
|
||||
total_tokens: 0
|
||||
}
|
||||
|
||||
const finalMetrics: Metrics = {
|
||||
completion_tokens: 0,
|
||||
time_completion_millsec: 0,
|
||||
time_first_token_millsec: 0
|
||||
}
|
||||
const toolResponses: MCPToolResponse[] = []
|
||||
|
||||
const processStream = async (body: MessageCreateParamsNonStreaming, idx: number) => {
|
||||
let time_first_token_millsec = 0
|
||||
|
||||
if (!streamOutput) {
|
||||
const message = await this.sdk.messages.create({ ...body, stream: false })
|
||||
const time_completion_millsec = new Date().getTime() - start_time_millsec
|
||||
|
||||
let text = ''
|
||||
let reasoning_content = ''
|
||||
|
||||
if (message.content && message.content.length > 0) {
|
||||
const thinkingBlock = message.content.find((block) => block.type === 'thinking')
|
||||
const textBlock = message.content.find((block) => block.type === 'text')
|
||||
|
||||
if (thinkingBlock && 'thinking' in thinkingBlock) {
|
||||
reasoning_content = thinkingBlock.thinking
|
||||
}
|
||||
|
||||
if (textBlock && 'text' in textBlock) {
|
||||
text = textBlock.text
|
||||
}
|
||||
}
|
||||
|
||||
return onChunk({
|
||||
type: ChunkType.BLOCK_COMPLETE,
|
||||
response: {
|
||||
text,
|
||||
reasoning_content,
|
||||
usage: message.usage as any,
|
||||
metrics: {
|
||||
completion_tokens: message.usage?.output_tokens || 0,
|
||||
time_completion_millsec,
|
||||
time_first_token_millsec: 0
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
let thinking_content = ''
|
||||
let isFirstChunk = true
|
||||
|
||||
return new Promise<void>((resolve, reject) => {
|
||||
// 等待接口返回流
|
||||
const toolCalls: ToolUseBlock[] = []
|
||||
|
||||
this.sdk.messages
|
||||
.stream({ ...body, stream: true }, { signal, timeout: 5 * 60 * 1000 })
|
||||
.on('text', (text) => {
|
||||
if (isFirstChunk) {
|
||||
isFirstChunk = false
|
||||
if (time_first_token_millsec == 0) {
|
||||
time_first_token_millsec = new Date().getTime()
|
||||
} else {
|
||||
onChunk({
|
||||
type: ChunkType.THINKING_COMPLETE,
|
||||
text: thinking_content,
|
||||
thinking_millsec: new Date().getTime() - time_first_token_millsec
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
onChunk({ type: ChunkType.TEXT_DELTA, text })
|
||||
})
|
||||
.on('contentBlock', (block) => {
|
||||
if (block.type === 'server_tool_use' && block.name === 'web_search') {
|
||||
onChunk({
|
||||
type: ChunkType.LLM_WEB_SEARCH_IN_PROGRESS
|
||||
})
|
||||
} else if (block.type === 'web_search_tool_result') {
|
||||
if (
|
||||
block.content &&
|
||||
(block.content as WebSearchToolResultError).type === 'web_search_tool_result_error'
|
||||
) {
|
||||
onChunk({
|
||||
type: ChunkType.ERROR,
|
||||
error: {
|
||||
code: (block.content as WebSearchToolResultError).error_code,
|
||||
message: (block.content as WebSearchToolResultError).error_code
|
||||
}
|
||||
})
|
||||
} else {
|
||||
onChunk({
|
||||
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
|
||||
llm_web_search: {
|
||||
results: block.content as Array<WebSearchResultBlock>,
|
||||
source: WebSearchSource.ANTHROPIC
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
if (block.type === 'tool_use') {
|
||||
toolCalls.push(block)
|
||||
}
|
||||
})
|
||||
.on('thinking', (thinking) => {
|
||||
if (time_first_token_millsec == 0) {
|
||||
time_first_token_millsec = new Date().getTime()
|
||||
}
|
||||
|
||||
onChunk({
|
||||
type: ChunkType.THINKING_DELTA,
|
||||
text: thinking,
|
||||
thinking_millsec: new Date().getTime() - time_first_token_millsec
|
||||
})
|
||||
thinking_content += thinking
|
||||
})
|
||||
.on('finalMessage', async (message) => {
|
||||
const toolResults: Awaited<ReturnType<typeof parseAndCallTools>> = []
|
||||
// tool call
|
||||
if (toolCalls.length > 0) {
|
||||
const mcpToolResponses = toolCalls
|
||||
.map((toolCall) => {
|
||||
const mcpTool = anthropicToolUseToMcpTool(mcpTools, toolCall)
|
||||
if (!mcpTool) {
|
||||
return undefined
|
||||
}
|
||||
return {
|
||||
id: toolCall.id,
|
||||
toolCallId: toolCall.id,
|
||||
tool: mcpTool,
|
||||
arguments: toolCall.input as Record<string, unknown>,
|
||||
status: 'pending'
|
||||
} as ToolCallResponse
|
||||
})
|
||||
.filter((t) => typeof t !== 'undefined')
|
||||
toolResults.push(
|
||||
...(await parseAndCallTools(
|
||||
mcpToolResponses,
|
||||
toolResponses,
|
||||
onChunk,
|
||||
this.mcpToolCallResponseToMessage,
|
||||
model,
|
||||
mcpTools
|
||||
))
|
||||
)
|
||||
}
|
||||
|
||||
// tool use
|
||||
const content = message.content[0]
|
||||
if (content && content.type === 'text') {
|
||||
onChunk({ type: ChunkType.TEXT_COMPLETE, text: content.text })
|
||||
toolResults.push(
|
||||
...(await parseAndCallTools(
|
||||
content.text,
|
||||
toolResponses,
|
||||
onChunk,
|
||||
this.mcpToolCallResponseToMessage,
|
||||
model,
|
||||
mcpTools
|
||||
))
|
||||
)
|
||||
}
|
||||
|
||||
if (thinking_content) {
|
||||
onChunk({
|
||||
type: ChunkType.THINKING_COMPLETE,
|
||||
text: thinking_content,
|
||||
thinking_millsec: new Date().getTime() - time_first_token_millsec
|
||||
})
|
||||
}
|
||||
|
||||
userMessages.push({
|
||||
role: message.role,
|
||||
content: message.content
|
||||
})
|
||||
|
||||
if (toolResults.length > 0) {
|
||||
toolResults.forEach((ts) => userMessages.push(ts as MessageParam))
|
||||
const newBody = body
|
||||
newBody.messages = userMessages
|
||||
|
||||
onChunk({ type: ChunkType.LLM_RESPONSE_CREATED })
|
||||
try {
|
||||
await processStream(newBody, idx + 1)
|
||||
} catch (error) {
|
||||
console.error('Error processing stream:', error)
|
||||
reject(error)
|
||||
}
|
||||
}
|
||||
|
||||
// 直接修改finalUsage对象会报错,TypeError: Cannot assign to read only property 'prompt_tokens' of object '#<Object>'
|
||||
// 暂未找到原因
|
||||
const updatedUsage: Usage = {
|
||||
...finalUsage,
|
||||
prompt_tokens: finalUsage.prompt_tokens + (message.usage?.input_tokens || 0),
|
||||
completion_tokens: finalUsage.completion_tokens + (message.usage?.output_tokens || 0)
|
||||
}
|
||||
updatedUsage.total_tokens = updatedUsage.prompt_tokens + updatedUsage.completion_tokens
|
||||
|
||||
const updatedMetrics: Metrics = {
|
||||
...finalMetrics,
|
||||
completion_tokens: updatedUsage.completion_tokens,
|
||||
time_completion_millsec:
|
||||
finalMetrics.time_completion_millsec + (new Date().getTime() - start_time_millsec),
|
||||
time_first_token_millsec: time_first_token_millsec - start_time_millsec
|
||||
}
|
||||
|
||||
Object.assign(finalUsage, updatedUsage)
|
||||
Object.assign(finalMetrics, updatedMetrics)
|
||||
|
||||
onChunk({
|
||||
type: ChunkType.BLOCK_COMPLETE,
|
||||
response: {
|
||||
usage: updatedUsage,
|
||||
metrics: updatedMetrics
|
||||
}
|
||||
})
|
||||
resolve()
|
||||
})
|
||||
.on('error', (error) => reject(error))
|
||||
.on('abort', () => {
|
||||
reject(new Error('Request was aborted.'))
|
||||
})
|
||||
})
|
||||
}
|
||||
onChunk({ type: ChunkType.LLM_RESPONSE_CREATED })
|
||||
const start_time_millsec = new Date().getTime()
|
||||
await processStream(body, 0).finally(() => {
|
||||
cleanup()
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Translate a message
|
||||
* @param content
|
||||
* @param assistant - The assistant
|
||||
* @param onResponse - The onResponse callback
|
||||
* @returns The translated message
|
||||
*/
|
||||
public async translate(
|
||||
content: string,
|
||||
assistant: Assistant,
|
||||
onResponse?: (text: string, isComplete: boolean) => void
|
||||
) {
|
||||
const defaultModel = getDefaultModel()
|
||||
const model = assistant.model || defaultModel
|
||||
|
||||
const messagesForApi = [{ role: 'user' as const, content: content }]
|
||||
|
||||
const stream = !!onResponse
|
||||
|
||||
const body: MessageCreateParamsNonStreaming = {
|
||||
model: model.id,
|
||||
messages: messagesForApi,
|
||||
max_tokens: 4096,
|
||||
temperature: assistant?.settings?.temperature,
|
||||
system: assistant.prompt
|
||||
}
|
||||
|
||||
if (!stream) {
|
||||
const response = await this.sdk.messages.create({ ...body, stream: false })
|
||||
return response.content[0].type === 'text' ? response.content[0].text : ''
|
||||
}
|
||||
|
||||
let text = ''
|
||||
|
||||
return new Promise<string>((resolve, reject) => {
|
||||
this.sdk.messages
|
||||
.stream({ ...body, stream: true })
|
||||
.on('text', (_text) => {
|
||||
text += _text
|
||||
onResponse?.(text, false)
|
||||
})
|
||||
.on('finalMessage', () => {
|
||||
onResponse?.(text, true)
|
||||
resolve(text)
|
||||
})
|
||||
.on('error', (error) => reject(error))
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Summarize a message
|
||||
* @param messages - The messages
|
||||
* @param assistant - The assistant
|
||||
* @returns The summary
|
||||
*/
|
||||
public async summaries(messages: Message[], assistant: Assistant): Promise<string> {
|
||||
const model = getTopNamingModel() || assistant.model || getDefaultModel()
|
||||
|
||||
const userMessages = takeRight(messages, 5).map((message) => ({
|
||||
role: message.role,
|
||||
content: getMainTextContent(message)
|
||||
}))
|
||||
|
||||
if (first(userMessages)?.role === 'assistant') {
|
||||
userMessages.shift()
|
||||
}
|
||||
|
||||
const userMessageContent = userMessages.reduce((prev, curr) => {
|
||||
const currentContent = curr.role === 'user' ? `User: ${curr.content}` : `Assistant: ${curr.content}`
|
||||
return prev + (prev ? '\n' : '') + currentContent
|
||||
}, '')
|
||||
|
||||
const systemMessage = {
|
||||
role: 'system',
|
||||
content: (getStoreSetting('topicNamingPrompt') as string) || i18n.t('prompts.title')
|
||||
}
|
||||
|
||||
const userMessage = {
|
||||
role: 'user',
|
||||
content: userMessageContent
|
||||
}
|
||||
|
||||
const message = await this.sdk.messages.create({
|
||||
messages: [userMessage] as Anthropic.Messages.MessageParam[],
|
||||
model: model.id,
|
||||
system: systemMessage.content,
|
||||
stream: false,
|
||||
max_tokens: 4096
|
||||
})
|
||||
|
||||
const responseContent = message.content[0].type === 'text' ? message.content[0].text : ''
|
||||
return removeSpecialCharactersForTopicName(responseContent)
|
||||
}
|
||||
|
||||
/**
|
||||
* Summarize a message for search
|
||||
* @param messages - The messages
|
||||
* @param assistant - The assistant
|
||||
* @returns The summary
|
||||
*/
|
||||
public async summaryForSearch(messages: Message[], assistant: Assistant): Promise<string | null> {
|
||||
const model = assistant.model || getDefaultModel()
|
||||
const systemMessage = { content: assistant.prompt }
|
||||
|
||||
const userMessageContent = messages.map((m) => getMainTextContent(m)).join('\n')
|
||||
|
||||
const userMessage = {
|
||||
role: 'user' as const,
|
||||
content: userMessageContent
|
||||
}
|
||||
const lastUserMessage = messages[messages.length - 1]
|
||||
const { abortController, cleanup } = this.createAbortController(lastUserMessage?.id)
|
||||
const { signal } = abortController
|
||||
|
||||
const response = await this.sdk.messages
|
||||
.create(
|
||||
{
|
||||
messages: [userMessage],
|
||||
model: model.id,
|
||||
system: systemMessage.content,
|
||||
stream: false,
|
||||
max_tokens: 4096
|
||||
},
|
||||
{ timeout: 20 * 1000, signal }
|
||||
)
|
||||
.finally(cleanup)
|
||||
|
||||
return response.content[0].type === 'text' ? response.content[0].text : ''
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate text
|
||||
* @param prompt - The prompt
|
||||
* @param content - The content
|
||||
* @returns The generated text
|
||||
*/
|
||||
public async generateText({ prompt, content }: { prompt: string; content: string }): Promise<string> {
|
||||
const model = getDefaultModel()
|
||||
|
||||
const message = await this.sdk.messages.create({
|
||||
model: model.id,
|
||||
system: prompt,
|
||||
stream: false,
|
||||
max_tokens: 4096,
|
||||
messages: [
|
||||
{
|
||||
role: 'user',
|
||||
content
|
||||
}
|
||||
]
|
||||
})
|
||||
|
||||
return message.content[0].type === 'text' ? message.content[0].text : ''
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate an image
|
||||
* @returns The generated image
|
||||
*/
|
||||
public async generateImage(): Promise<string[]> {
|
||||
return []
|
||||
}
|
||||
|
||||
public async generateImageByChat(): Promise<void> {
|
||||
throw new Error('Method not implemented.')
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate suggestions
|
||||
* @returns The suggestions
|
||||
*/
|
||||
public async suggestions(): Promise<Suggestion[]> {
|
||||
return []
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if the model is valid
|
||||
* @param model - The model
|
||||
* @param stream - Whether to use streaming interface
|
||||
* @returns The validity of the model
|
||||
*/
|
||||
public async check(model: Model, stream: boolean = false): Promise<{ valid: boolean; error: Error | null }> {
|
||||
if (!model) {
|
||||
return { valid: false, error: new Error('No model found') }
|
||||
}
|
||||
|
||||
const body = {
|
||||
model: model.id,
|
||||
messages: [{ role: 'user' as const, content: 'hi' }],
|
||||
max_tokens: 2, // api文档写的 x>1
|
||||
stream
|
||||
}
|
||||
|
||||
try {
|
||||
if (!stream) {
|
||||
const message = await this.sdk.messages.create(body as MessageCreateParamsNonStreaming)
|
||||
return {
|
||||
valid: message.content.length > 0,
|
||||
error: null
|
||||
}
|
||||
} else {
|
||||
return await new Promise((resolve, reject) => {
|
||||
let hasContent = false
|
||||
this.sdk.messages
|
||||
.stream(body)
|
||||
.on('text', (text) => {
|
||||
if (!hasContent && text) {
|
||||
hasContent = true
|
||||
resolve({ valid: true, error: null })
|
||||
}
|
||||
})
|
||||
.on('finalMessage', (message) => {
|
||||
if (!hasContent && message.content && message.content.length > 0) {
|
||||
hasContent = true
|
||||
resolve({ valid: true, error: null })
|
||||
}
|
||||
if (!hasContent) {
|
||||
reject(new Error('Empty streaming response'))
|
||||
}
|
||||
})
|
||||
.on('error', (error) => reject(error))
|
||||
})
|
||||
}
|
||||
} catch (error: any) {
|
||||
return {
|
||||
valid: false,
|
||||
error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the models
|
||||
* @returns The models
|
||||
*/
|
||||
public async models(): Promise<OpenAI.Models.Model[]> {
|
||||
return []
|
||||
}
|
||||
|
||||
public async getEmbeddingDimensions(): Promise<number> {
|
||||
return 0
|
||||
}
|
||||
|
||||
public convertMcpTools<T>(mcpTools: MCPTool[]): T[] {
|
||||
return mcpToolsToAnthropicTools(mcpTools) as T[]
|
||||
}
|
||||
|
||||
public mcpToolCallResponseToMessage = (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => {
|
||||
if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) {
|
||||
return mcpToolCallResponseToAnthropicMessage(mcpToolResponse, resp, model)
|
||||
} else if ('toolCallId' in mcpToolResponse) {
|
||||
return {
|
||||
role: 'user',
|
||||
content: [
|
||||
{
|
||||
type: 'tool_result',
|
||||
tool_use_id: mcpToolResponse.toolCallId!,
|
||||
content: resp.content
|
||||
.map((item) => {
|
||||
if (item.type === 'text') {
|
||||
return {
|
||||
type: 'text',
|
||||
text: item.text || ''
|
||||
} satisfies TextBlockParam
|
||||
}
|
||||
if (item.type === 'image') {
|
||||
return {
|
||||
type: 'image',
|
||||
source: {
|
||||
data: item.data || '',
|
||||
media_type: (item.mimeType || 'image/png') as Base64ImageSource['media_type'],
|
||||
type: 'base64'
|
||||
}
|
||||
} satisfies ImageBlockParam
|
||||
}
|
||||
return
|
||||
})
|
||||
.filter((n) => typeof n !== 'undefined'),
|
||||
is_error: resp.isError
|
||||
} satisfies ToolResultBlockParam
|
||||
]
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -1,33 +0,0 @@
|
||||
import { Provider } from '@renderer/types'
|
||||
|
||||
import AihubmixProvider from './AihubmixProvider'
|
||||
import AnthropicProvider from './AnthropicProvider'
|
||||
import BaseProvider from './BaseProvider'
|
||||
import GeminiProvider from './GeminiProvider'
|
||||
import OpenAIProvider from './OpenAIProvider'
|
||||
import OpenAIResponseProvider from './OpenAIResponseProvider'
|
||||
|
||||
export default class ProviderFactory {
|
||||
static create(provider: Provider): BaseProvider {
|
||||
if (provider.id === 'aihubmix') {
|
||||
return new AihubmixProvider(provider)
|
||||
}
|
||||
|
||||
switch (provider.type) {
|
||||
case 'openai':
|
||||
return new OpenAIProvider(provider)
|
||||
case 'openai-response':
|
||||
return new OpenAIResponseProvider(provider)
|
||||
case 'anthropic':
|
||||
return new AnthropicProvider(provider)
|
||||
case 'gemini':
|
||||
return new GeminiProvider(provider)
|
||||
default:
|
||||
return new OpenAIProvider(provider)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export function isOpenAIProvider(provider: Provider) {
|
||||
return !['anthropic', 'gemini'].includes(provider.type)
|
||||
}
|
||||
@ -1,94 +0,0 @@
|
||||
import { GenerateImagesParameters } from '@google/genai'
|
||||
import BaseProvider from '@renderer/providers/AiProvider/BaseProvider'
|
||||
import ProviderFactory from '@renderer/providers/AiProvider/ProviderFactory'
|
||||
import type { Assistant, GenerateImageParams, MCPTool, Model, Provider, Suggestion } from '@renderer/types'
|
||||
import { Chunk } from '@renderer/types/chunk'
|
||||
import type { Message } from '@renderer/types/newMessage'
|
||||
import OpenAI from 'openai'
|
||||
|
||||
export interface CompletionsParams {
|
||||
messages: Message[]
|
||||
assistant: Assistant
|
||||
onChunk: (chunk: Chunk) => void
|
||||
onFilterMessages: (messages: Message[]) => void
|
||||
mcpTools?: MCPTool[]
|
||||
}
|
||||
|
||||
export default class AiProvider {
|
||||
private sdk: BaseProvider
|
||||
|
||||
constructor(provider: Provider) {
|
||||
this.sdk = ProviderFactory.create(provider)
|
||||
}
|
||||
|
||||
public async fakeCompletions(params: CompletionsParams): Promise<void> {
|
||||
return this.sdk.fakeCompletions(params)
|
||||
}
|
||||
|
||||
public async completions({
|
||||
messages,
|
||||
assistant,
|
||||
mcpTools,
|
||||
onChunk,
|
||||
onFilterMessages
|
||||
}: CompletionsParams): Promise<void> {
|
||||
return this.sdk.completions({ messages, assistant, mcpTools, onChunk, onFilterMessages })
|
||||
}
|
||||
|
||||
public async translate(
|
||||
content: string,
|
||||
assistant: Assistant,
|
||||
onResponse?: (text: string, isComplete: boolean) => void
|
||||
): Promise<string> {
|
||||
return this.sdk.translate(content, assistant, onResponse)
|
||||
}
|
||||
|
||||
public async summaries(messages: Message[], assistant: Assistant): Promise<string> {
|
||||
return this.sdk.summaries(messages, assistant)
|
||||
}
|
||||
|
||||
public async summaryForSearch(messages: Message[], assistant: Assistant): Promise<string | null> {
|
||||
return this.sdk.summaryForSearch(messages, assistant)
|
||||
}
|
||||
|
||||
public async suggestions(messages: Message[], assistant: Assistant): Promise<Suggestion[]> {
|
||||
return this.sdk.suggestions(messages, assistant)
|
||||
}
|
||||
|
||||
public async generateText({ prompt, content }: { prompt: string; content: string }): Promise<string> {
|
||||
return this.sdk.generateText({ prompt, content })
|
||||
}
|
||||
|
||||
public async check(model: Model, stream: boolean = false): Promise<{ valid: boolean; error: Error | null }> {
|
||||
return this.sdk.check(model, stream)
|
||||
}
|
||||
|
||||
public async models(): Promise<OpenAI.Models.Model[]> {
|
||||
return this.sdk.models()
|
||||
}
|
||||
|
||||
public getApiKey(): string {
|
||||
return this.sdk.getApiKey()
|
||||
}
|
||||
|
||||
public async generateImage(params: GenerateImageParams | GenerateImagesParameters): Promise<string[]> {
|
||||
return this.sdk.generateImage(params as GenerateImageParams)
|
||||
}
|
||||
|
||||
public async generateImageByChat({
|
||||
messages,
|
||||
assistant,
|
||||
onChunk,
|
||||
onFilterMessages
|
||||
}: CompletionsParams): Promise<void> {
|
||||
return this.sdk.generateImageByChat({ messages, assistant, onChunk, onFilterMessages })
|
||||
}
|
||||
|
||||
public async getEmbeddingDimensions(model: Model): Promise<number> {
|
||||
return this.sdk.getEmbeddingDimensions(model)
|
||||
}
|
||||
|
||||
public getBaseURL(): string {
|
||||
return this.sdk.getBaseURL()
|
||||
}
|
||||
}
|
||||
@ -1,10 +1,21 @@
|
||||
import { CompletionsParams } from '@renderer/aiCore/middleware/schemas'
|
||||
import Logger from '@renderer/config/logger'
|
||||
import { getOpenAIWebSearchParams, isOpenAIWebSearch } from '@renderer/config/models'
|
||||
import {
|
||||
isEmbeddingModel,
|
||||
isGenerateImageModel,
|
||||
isOpenRouterBuiltInWebSearchModel,
|
||||
isReasoningModel,
|
||||
isSupportedDisableGenerationModel,
|
||||
isSupportedReasoningEffortModel,
|
||||
isSupportedThinkingTokenModel,
|
||||
isWebSearchModel
|
||||
} from '@renderer/config/models'
|
||||
import {
|
||||
SEARCH_SUMMARY_PROMPT,
|
||||
SEARCH_SUMMARY_PROMPT_KNOWLEDGE_ONLY,
|
||||
SEARCH_SUMMARY_PROMPT_WEB_ONLY
|
||||
} from '@renderer/config/prompts'
|
||||
import { getStoreSetting } from '@renderer/hooks/useSettings'
|
||||
import i18n from '@renderer/i18n'
|
||||
import {
|
||||
Assistant,
|
||||
@ -13,20 +24,22 @@ import {
|
||||
MCPTool,
|
||||
Model,
|
||||
Provider,
|
||||
Suggestion,
|
||||
WebSearchResponse,
|
||||
WebSearchSource
|
||||
} from '@renderer/types'
|
||||
import { type Chunk, ChunkType } from '@renderer/types/chunk'
|
||||
import { Message } from '@renderer/types/newMessage'
|
||||
import { SdkModel } from '@renderer/types/sdk'
|
||||
import { removeSpecialCharactersForTopicName } from '@renderer/utils'
|
||||
import { isAbortError } from '@renderer/utils/error'
|
||||
import { extractInfoFromXML, ExtractResults } from '@renderer/utils/extract'
|
||||
import { getKnowledgeBaseIds, getMainTextContent } from '@renderer/utils/messageUtils/find'
|
||||
import { findLast, isEmpty } from 'lodash'
|
||||
import { findLast, isEmpty, takeRight } from 'lodash'
|
||||
|
||||
import AiProvider from '../providers/AiProvider'
|
||||
import AiProvider from '../aiCore'
|
||||
import {
|
||||
getAssistantProvider,
|
||||
getAssistantSettings,
|
||||
getDefaultModel,
|
||||
getProviderByModel,
|
||||
getTopNamingModel,
|
||||
@ -34,7 +47,13 @@ import {
|
||||
} from './AssistantService'
|
||||
import { getDefaultAssistant } from './AssistantService'
|
||||
import { processKnowledgeSearch } from './KnowledgeService'
|
||||
import { filterContextMessages, filterMessages, filterUsefulMessages } from './MessagesService'
|
||||
import {
|
||||
filterContextMessages,
|
||||
filterEmptyMessages,
|
||||
filterMessages,
|
||||
filterUsefulMessages,
|
||||
filterUserRoleStartMessages
|
||||
} from './MessagesService'
|
||||
import WebSearchService from './WebSearchService'
|
||||
|
||||
// TODO:考虑拆开
|
||||
@ -50,6 +69,7 @@ async function fetchExternalTool(
|
||||
const knowledgeRecognition = assistant.knowledgeRecognition || 'on'
|
||||
const webSearchProvider = WebSearchService.getWebSearchProvider(assistant.webSearchProviderId)
|
||||
|
||||
// 使用外部搜索工具
|
||||
const shouldWebSearch = !!assistant.webSearchProviderId && webSearchProvider !== null
|
||||
const shouldKnowledgeSearch = hasKnowledgeBase
|
||||
|
||||
@ -83,14 +103,14 @@ async function fetchExternalTool(
|
||||
summaryAssistant.prompt = prompt
|
||||
|
||||
try {
|
||||
const keywords = await fetchSearchSummary({
|
||||
const result = await fetchSearchSummary({
|
||||
messages: lastAnswer ? [lastAnswer, lastUserMessage] : [lastUserMessage],
|
||||
assistant: summaryAssistant
|
||||
})
|
||||
|
||||
if (!keywords) return getFallbackResult()
|
||||
if (!result) return getFallbackResult()
|
||||
|
||||
const extracted = extractInfoFromXML(keywords)
|
||||
const extracted = extractInfoFromXML(result.getText())
|
||||
// 根据需求过滤结果
|
||||
return {
|
||||
websearch: needWebExtract ? extracted?.websearch : undefined,
|
||||
@ -134,12 +154,6 @@ async function fetchExternalTool(
|
||||
return undefined
|
||||
}
|
||||
|
||||
// Pass the guaranteed model to the check function
|
||||
const webSearchParams = getOpenAIWebSearchParams(assistant, assistant.model)
|
||||
if (!isEmpty(webSearchParams) || isOpenAIWebSearch(assistant.model)) {
|
||||
return
|
||||
}
|
||||
|
||||
try {
|
||||
// Use the consolidated processWebsearch function
|
||||
WebSearchService.createAbortSignal(lastUserMessage.id)
|
||||
@ -238,7 +252,7 @@ async function fetchExternalTool(
|
||||
|
||||
// Get MCP tools (Fix duplicate declaration)
|
||||
let mcpTools: MCPTool[] = [] // Initialize as empty array
|
||||
const enabledMCPs = lastUserMessage?.enabledMCPs
|
||||
const enabledMCPs = assistant.mcpServers
|
||||
if (enabledMCPs && enabledMCPs.length > 0) {
|
||||
try {
|
||||
const toolPromises = enabledMCPs.map(async (mcpServer) => {
|
||||
@ -301,17 +315,52 @@ export async function fetchChatCompletion({
|
||||
// NOTE: The search results are NOT added to the messages sent to the AI here.
|
||||
// They will be retrieved and used by the messageThunk later to create CitationBlocks.
|
||||
const { mcpTools } = await fetchExternalTool(lastUserMessage, assistant, onChunkReceived, lastAnswer)
|
||||
const model = assistant.model || getDefaultModel()
|
||||
|
||||
const { maxTokens, contextCount } = getAssistantSettings(assistant)
|
||||
|
||||
const filteredMessages = filterUsefulMessages(messages)
|
||||
|
||||
const _messages = filterUserRoleStartMessages(
|
||||
filterEmptyMessages(filterContextMessages(takeRight(filteredMessages, contextCount + 2))) // 取原来几个provider的最大值
|
||||
)
|
||||
|
||||
const enableReasoning =
|
||||
((isSupportedThinkingTokenModel(model) || isSupportedReasoningEffortModel(model)) &&
|
||||
assistant.settings?.reasoning_effort !== undefined) ||
|
||||
(isReasoningModel(model) && (!isSupportedThinkingTokenModel(model) || !isSupportedReasoningEffortModel(model)))
|
||||
|
||||
const enableWebSearch =
|
||||
(assistant.enableWebSearch && isWebSearchModel(model)) ||
|
||||
isOpenRouterBuiltInWebSearchModel(model) ||
|
||||
model.id.includes('sonar') ||
|
||||
false
|
||||
|
||||
const enableGenerateImage =
|
||||
isGenerateImageModel(model) && (isSupportedDisableGenerationModel(model) ? assistant.enableGenerateImage : true)
|
||||
|
||||
// --- Call AI Completions ---
|
||||
await AI.completions({
|
||||
messages: filteredMessages,
|
||||
assistant,
|
||||
onFilterMessages: () => {},
|
||||
onChunk: onChunkReceived,
|
||||
mcpTools: mcpTools
|
||||
})
|
||||
onChunkReceived({ type: ChunkType.LLM_RESPONSE_CREATED })
|
||||
if (enableWebSearch) {
|
||||
onChunkReceived({ type: ChunkType.LLM_WEB_SEARCH_IN_PROGRESS })
|
||||
}
|
||||
await AI.completions(
|
||||
{
|
||||
callType: 'chat',
|
||||
messages: _messages,
|
||||
assistant,
|
||||
onChunk: onChunkReceived,
|
||||
mcpTools: mcpTools,
|
||||
maxTokens,
|
||||
streamOutput: assistant.settings?.streamOutput || false,
|
||||
enableReasoning,
|
||||
enableWebSearch,
|
||||
enableGenerateImage
|
||||
},
|
||||
{
|
||||
streamOutput: assistant.settings?.streamOutput || false
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
interface FetchTranslateProps {
|
||||
@ -321,7 +370,7 @@ interface FetchTranslateProps {
|
||||
}
|
||||
|
||||
export async function fetchTranslate({ content, assistant, onResponse }: FetchTranslateProps) {
|
||||
const model = getTranslateModel()
|
||||
const model = getTranslateModel() || assistant.model || getDefaultModel()
|
||||
|
||||
if (!model) {
|
||||
throw new Error(i18n.t('error.provider_disabled'))
|
||||
@ -333,17 +382,42 @@ export async function fetchTranslate({ content, assistant, onResponse }: FetchTr
|
||||
throw new Error(i18n.t('error.no_api_key'))
|
||||
}
|
||||
|
||||
const isSupportedStreamOutput = () => {
|
||||
if (!onResponse) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
const stream = isSupportedStreamOutput()
|
||||
const enableReasoning =
|
||||
((isSupportedThinkingTokenModel(model) || isSupportedReasoningEffortModel(model)) &&
|
||||
assistant.settings?.reasoning_effort !== undefined) ||
|
||||
(isReasoningModel(model) && (!isSupportedThinkingTokenModel(model) || !isSupportedReasoningEffortModel(model)))
|
||||
|
||||
const params: CompletionsParams = {
|
||||
callType: 'translate',
|
||||
messages: content,
|
||||
assistant: { ...assistant, model },
|
||||
streamOutput: stream,
|
||||
enableReasoning,
|
||||
onResponse
|
||||
}
|
||||
|
||||
const AI = new AiProvider(provider)
|
||||
|
||||
try {
|
||||
return await AI.translate(content, assistant, onResponse)
|
||||
return (await AI.completions(params)).getText() || ''
|
||||
} catch (error: any) {
|
||||
return ''
|
||||
}
|
||||
}
|
||||
|
||||
export async function fetchMessagesSummary({ messages, assistant }: { messages: Message[]; assistant: Assistant }) {
|
||||
const prompt = (getStoreSetting('topicNamingPrompt') as string) || i18n.t('prompts.title')
|
||||
const model = getTopNamingModel() || assistant.model || getDefaultModel()
|
||||
const userMessages = takeRight(messages, 5)
|
||||
|
||||
const provider = getProviderByModel(model)
|
||||
|
||||
if (!hasApiKey(provider)) {
|
||||
@ -352,9 +426,18 @@ export async function fetchMessagesSummary({ messages, assistant }: { messages:
|
||||
|
||||
const AI = new AiProvider(provider)
|
||||
|
||||
const params: CompletionsParams = {
|
||||
callType: 'summary',
|
||||
messages: filterMessages(userMessages),
|
||||
assistant: { ...assistant, prompt, model },
|
||||
maxTokens: 1000,
|
||||
streamOutput: false
|
||||
}
|
||||
|
||||
try {
|
||||
const text = await AI.summaries(filterMessages(messages), assistant)
|
||||
return text?.replace(/["']/g, '') || null
|
||||
const { getText } = await AI.completions(params)
|
||||
const text = getText()
|
||||
return removeSpecialCharactersForTopicName(text) || null
|
||||
} catch (error: any) {
|
||||
return null
|
||||
}
|
||||
@ -370,7 +453,14 @@ export async function fetchSearchSummary({ messages, assistant }: { messages: Me
|
||||
|
||||
const AI = new AiProvider(provider)
|
||||
|
||||
return await AI.summaryForSearch(messages, assistant)
|
||||
const params: CompletionsParams = {
|
||||
callType: 'search',
|
||||
messages: messages,
|
||||
assistant,
|
||||
streamOutput: false
|
||||
}
|
||||
|
||||
return await AI.completions(params)
|
||||
}
|
||||
|
||||
export async function fetchGenerate({ prompt, content }: { prompt: string; content: string }): Promise<string> {
|
||||
@ -383,42 +473,32 @@ export async function fetchGenerate({ prompt, content }: { prompt: string; conte
|
||||
|
||||
const AI = new AiProvider(provider)
|
||||
|
||||
const assistant = getDefaultAssistant()
|
||||
assistant.model = model
|
||||
assistant.prompt = prompt
|
||||
|
||||
const params: CompletionsParams = {
|
||||
callType: 'generate',
|
||||
messages: content,
|
||||
assistant,
|
||||
streamOutput: false
|
||||
}
|
||||
|
||||
try {
|
||||
return await AI.generateText({ prompt, content })
|
||||
const result = await AI.completions(params)
|
||||
return result.getText() || ''
|
||||
} catch (error: any) {
|
||||
return ''
|
||||
}
|
||||
}
|
||||
|
||||
export async function fetchSuggestions({
|
||||
messages,
|
||||
assistant
|
||||
}: {
|
||||
messages: Message[]
|
||||
assistant: Assistant
|
||||
}): Promise<Suggestion[]> {
|
||||
const model = assistant.model
|
||||
if (!model || model.id.endsWith('global')) {
|
||||
return []
|
||||
}
|
||||
|
||||
const provider = getAssistantProvider(assistant)
|
||||
const AI = new AiProvider(provider)
|
||||
|
||||
try {
|
||||
return await AI.suggestions(filterMessages(messages), assistant)
|
||||
} catch (error: any) {
|
||||
return []
|
||||
}
|
||||
}
|
||||
|
||||
function hasApiKey(provider: Provider) {
|
||||
if (!provider) return false
|
||||
if (provider.id === 'ollama' || provider.id === 'lmstudio') return true
|
||||
return !isEmpty(provider.apiKey)
|
||||
}
|
||||
|
||||
export async function fetchModels(provider: Provider) {
|
||||
export async function fetchModels(provider: Provider): Promise<SdkModel[]> {
|
||||
const AI = new AiProvider(provider)
|
||||
|
||||
try {
|
||||
@ -432,68 +512,69 @@ export const formatApiKeys = (value: string) => {
|
||||
return value.replaceAll(',', ',').replaceAll(' ', ',').replaceAll(' ', '').replaceAll('\n', ',')
|
||||
}
|
||||
|
||||
export function checkApiProvider(provider: Provider): {
|
||||
valid: boolean
|
||||
error: Error | null
|
||||
} {
|
||||
export function checkApiProvider(provider: Provider): void {
|
||||
const key = 'api-check'
|
||||
const style = { marginTop: '3vh' }
|
||||
|
||||
if (provider.id !== 'ollama' && provider.id !== 'lmstudio') {
|
||||
if (!provider.apiKey) {
|
||||
window.message.error({ content: i18n.t('message.error.enter.api.key'), key, style })
|
||||
return {
|
||||
valid: false,
|
||||
error: new Error(i18n.t('message.error.enter.api.key'))
|
||||
}
|
||||
throw new Error(i18n.t('message.error.enter.api.key'))
|
||||
}
|
||||
}
|
||||
|
||||
if (!provider.apiHost) {
|
||||
window.message.error({ content: i18n.t('message.error.enter.api.host'), key, style })
|
||||
return {
|
||||
valid: false,
|
||||
error: new Error(i18n.t('message.error.enter.api.host'))
|
||||
}
|
||||
throw new Error(i18n.t('message.error.enter.api.host'))
|
||||
}
|
||||
|
||||
if (isEmpty(provider.models)) {
|
||||
window.message.error({ content: i18n.t('message.error.enter.model'), key, style })
|
||||
return {
|
||||
valid: false,
|
||||
error: new Error(i18n.t('message.error.enter.model'))
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
valid: true,
|
||||
error: null
|
||||
throw new Error(i18n.t('message.error.enter.model'))
|
||||
}
|
||||
}
|
||||
|
||||
export async function checkApi(provider: Provider, model: Model): Promise<{ valid: boolean; error: Error | null }> {
|
||||
const validation = checkApiProvider(provider)
|
||||
if (!validation.valid) {
|
||||
return {
|
||||
valid: validation.valid,
|
||||
error: validation.error
|
||||
}
|
||||
}
|
||||
export async function checkApi(provider: Provider, model: Model): Promise<void> {
|
||||
checkApiProvider(provider)
|
||||
|
||||
const ai = new AiProvider(provider)
|
||||
|
||||
// Try streaming check first
|
||||
const result = await ai.check(model, true)
|
||||
const assistant = getDefaultAssistant()
|
||||
assistant.model = model
|
||||
try {
|
||||
if (isEmbeddingModel(model)) {
|
||||
const result = await ai.getEmbeddingDimensions(model)
|
||||
if (result === 0) {
|
||||
throw new Error(i18n.t('message.error.enter.model'))
|
||||
}
|
||||
} else {
|
||||
const params: CompletionsParams = {
|
||||
callType: 'check',
|
||||
messages: 'hi',
|
||||
assistant,
|
||||
streamOutput: true
|
||||
}
|
||||
|
||||
if (result.valid && !result.error) {
|
||||
return result
|
||||
}
|
||||
|
||||
// 不应该假设错误由流式引发。多次发起检测请求可能触发429,掩盖了真正的问题。
|
||||
// 但这里错误类型做的很粗糙,暂时先这样
|
||||
if (result.error && result.error.message.includes('stream')) {
|
||||
return ai.check(model, false)
|
||||
} else {
|
||||
return result
|
||||
// Try streaming check first
|
||||
const result = await ai.completions(params)
|
||||
if (!result.getText()) {
|
||||
throw new Error('No response received')
|
||||
}
|
||||
}
|
||||
} catch (error: any) {
|
||||
if (error.message.includes('stream')) {
|
||||
const params: CompletionsParams = {
|
||||
callType: 'check',
|
||||
messages: 'hi',
|
||||
assistant,
|
||||
streamOutput: false
|
||||
}
|
||||
const result = await ai.completions(params)
|
||||
if (!result.getText()) {
|
||||
throw new Error('No response received')
|
||||
}
|
||||
} else {
|
||||
throw error
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -98,14 +98,20 @@ export async function checkModelWithMultipleKeys(
|
||||
if (isParallel) {
|
||||
// Check all API keys in parallel
|
||||
const keyPromises = apiKeys.map(async (key) => {
|
||||
const result = await checkModel({ ...provider, apiKey: key }, model)
|
||||
|
||||
return {
|
||||
key,
|
||||
isValid: result.valid,
|
||||
error: result.error?.message,
|
||||
latency: result.latency
|
||||
} as ApiKeyCheckStatus
|
||||
try {
|
||||
const result = await checkModel({ ...provider, apiKey: key }, model)
|
||||
return {
|
||||
key,
|
||||
isValid: true,
|
||||
latency: result.latency
|
||||
} as ApiKeyCheckStatus
|
||||
} catch (error: unknown) {
|
||||
return {
|
||||
key,
|
||||
isValid: false,
|
||||
error: error instanceof Error ? error.message.slice(0, 20) + '...' : String(error).slice(0, 20) + '...'
|
||||
} as ApiKeyCheckStatus
|
||||
}
|
||||
})
|
||||
|
||||
const results = await Promise.allSettled(keyPromises)
|
||||
@ -125,14 +131,20 @@ export async function checkModelWithMultipleKeys(
|
||||
} else {
|
||||
// Check all API keys serially
|
||||
for (const key of apiKeys) {
|
||||
const result = await checkModel({ ...provider, apiKey: key }, model)
|
||||
|
||||
keyResults.push({
|
||||
key,
|
||||
isValid: result.valid,
|
||||
error: result.error?.message,
|
||||
latency: result.latency
|
||||
})
|
||||
try {
|
||||
const result = await checkModel({ ...provider, apiKey: key }, model)
|
||||
keyResults.push({
|
||||
key,
|
||||
isValid: true,
|
||||
latency: result.latency
|
||||
})
|
||||
} catch (error: unknown) {
|
||||
keyResults.push({
|
||||
key,
|
||||
isValid: false,
|
||||
error: error instanceof Error ? error.message.slice(0, 20) + '...' : String(error).slice(0, 20) + '...'
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -1,8 +1,8 @@
|
||||
import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
|
||||
import AiProvider from '@renderer/aiCore'
|
||||
import { DEFAULT_KNOWLEDGE_DOCUMENT_COUNT, DEFAULT_KNOWLEDGE_THRESHOLD } from '@renderer/config/constant'
|
||||
import { getEmbeddingMaxContext } from '@renderer/config/embedings'
|
||||
import Logger from '@renderer/config/logger'
|
||||
import AiProvider from '@renderer/providers/AiProvider'
|
||||
import store from '@renderer/store'
|
||||
import { FileType, KnowledgeBase, KnowledgeBaseParams, KnowledgeReference } from '@renderer/types'
|
||||
import { ExtractResults } from '@renderer/utils/extract'
|
||||
|
||||
@ -1,11 +1,9 @@
|
||||
import { isEmbeddingModel } from '@renderer/config/models'
|
||||
import AiProvider from '@renderer/providers/AiProvider'
|
||||
import store from '@renderer/store'
|
||||
import { Model, Provider } from '@renderer/types'
|
||||
import { t } from 'i18next'
|
||||
import { pick } from 'lodash'
|
||||
|
||||
import { checkApiProvider } from './ApiService'
|
||||
import { checkApi } from './ApiService'
|
||||
|
||||
export const getModelUniqId = (m?: Model) => {
|
||||
return m?.id ? JSON.stringify(pick(m, ['id', 'provider'])) : ''
|
||||
@ -33,64 +31,23 @@ export function getModelName(model?: Model) {
|
||||
return modelName
|
||||
}
|
||||
|
||||
// Generic function to perform model checks
|
||||
// Abstracts provider validation and error handling, allowing different types of check logic
|
||||
// Generic function to perform model checks with exception handling
|
||||
async function performModelCheck<T>(
|
||||
provider: Provider,
|
||||
model: Model,
|
||||
checkFn: (ai: AiProvider, model: Model) => Promise<T>,
|
||||
processResult: (result: T) => { valid: boolean; error: Error | null }
|
||||
): Promise<{ valid: boolean; error: Error | null; latency?: number }> {
|
||||
const validation = checkApiProvider(provider)
|
||||
if (!validation.valid) {
|
||||
return {
|
||||
valid: validation.valid,
|
||||
error: validation.error
|
||||
}
|
||||
}
|
||||
checkFn: (provider: Provider, model: Model) => Promise<T>
|
||||
): Promise<{ latency: number }> {
|
||||
const startTime = performance.now()
|
||||
await checkFn(provider, model)
|
||||
const latency = performance.now() - startTime
|
||||
|
||||
const AI = new AiProvider(provider)
|
||||
|
||||
try {
|
||||
const startTime = performance.now()
|
||||
const result = await checkFn(AI, model)
|
||||
const latency = performance.now() - startTime
|
||||
|
||||
return {
|
||||
...processResult(result),
|
||||
latency
|
||||
}
|
||||
} catch (error: any) {
|
||||
return {
|
||||
valid: false,
|
||||
error
|
||||
}
|
||||
}
|
||||
return { latency }
|
||||
}
|
||||
|
||||
// Unified model check function
|
||||
// Automatically selects appropriate check method based on model type
|
||||
export async function checkModel(provider: Provider, model: Model) {
|
||||
if (isEmbeddingModel(model)) {
|
||||
return performModelCheck(
|
||||
provider,
|
||||
model,
|
||||
(ai, model) => ai.getEmbeddingDimensions(model),
|
||||
(dimensions) => ({ valid: dimensions > 0, error: null })
|
||||
)
|
||||
} else {
|
||||
return performModelCheck(
|
||||
provider,
|
||||
model,
|
||||
async (ai, model) => {
|
||||
// Try streaming check first
|
||||
const result = await ai.check(model, true)
|
||||
if (result.valid && !result.error) {
|
||||
return result
|
||||
}
|
||||
return ai.check(model, false)
|
||||
},
|
||||
({ valid, error }) => ({ valid, error: error || null })
|
||||
)
|
||||
}
|
||||
export async function checkModel(provider: Provider, model: Model): Promise<{ latency: number }> {
|
||||
return performModelCheck(provider, model, async (provider, model) => {
|
||||
await checkApi(provider, model)
|
||||
})
|
||||
}
|
||||
|
||||
@ -28,7 +28,9 @@ export interface StreamProcessorCallbacks {
|
||||
onLLMWebSearchComplete?: (llmWebSearchResult: WebSearchResponse) => void
|
||||
// Image generation chunk received
|
||||
onImageCreated?: () => void
|
||||
onImageGenerated?: (imageData: GenerateImageResponse) => void
|
||||
onImageDelta?: (imageData: GenerateImageResponse) => void
|
||||
onImageGenerated?: (imageData?: GenerateImageResponse) => void
|
||||
onLLMResponseComplete?: (response?: Response) => void
|
||||
// Called when an error occurs during chunk processing
|
||||
onError?: (error: any) => void
|
||||
// Called when the entire stream processing is signaled as complete (success or failure)
|
||||
@ -40,59 +42,84 @@ export function createStreamProcessor(callbacks: StreamProcessorCallbacks = {})
|
||||
// The returned function processes a single chunk or a final signal
|
||||
return (chunk: Chunk) => {
|
||||
try {
|
||||
// Logger.log(`[${new Date().toLocaleString()}] createStreamProcessor ${chunk.type}`, chunk)
|
||||
// 1. Handle the manual final signal first
|
||||
if (chunk?.type === ChunkType.BLOCK_COMPLETE) {
|
||||
callbacks.onComplete?.(AssistantMessageStatus.SUCCESS, chunk?.response)
|
||||
return
|
||||
const data = chunk
|
||||
switch (data.type) {
|
||||
case ChunkType.BLOCK_COMPLETE: {
|
||||
if (callbacks.onComplete) callbacks.onComplete(AssistantMessageStatus.SUCCESS, data?.response)
|
||||
break
|
||||
}
|
||||
case ChunkType.LLM_RESPONSE_CREATED: {
|
||||
if (callbacks.onLLMResponseCreated) callbacks.onLLMResponseCreated()
|
||||
break
|
||||
}
|
||||
case ChunkType.TEXT_DELTA: {
|
||||
if (callbacks.onTextChunk) callbacks.onTextChunk(data.text)
|
||||
break
|
||||
}
|
||||
case ChunkType.TEXT_COMPLETE: {
|
||||
if (callbacks.onTextComplete) callbacks.onTextComplete(data.text)
|
||||
break
|
||||
}
|
||||
case ChunkType.THINKING_DELTA: {
|
||||
if (callbacks.onThinkingChunk) callbacks.onThinkingChunk(data.text, data.thinking_millsec)
|
||||
break
|
||||
}
|
||||
case ChunkType.THINKING_COMPLETE: {
|
||||
if (callbacks.onThinkingComplete) callbacks.onThinkingComplete(data.text, data.thinking_millsec)
|
||||
break
|
||||
}
|
||||
case ChunkType.MCP_TOOL_IN_PROGRESS: {
|
||||
if (callbacks.onToolCallInProgress)
|
||||
data.responses.forEach((toolResp) => callbacks.onToolCallInProgress!(toolResp))
|
||||
break
|
||||
}
|
||||
case ChunkType.MCP_TOOL_COMPLETE: {
|
||||
if (callbacks.onToolCallComplete && data.responses.length > 0) {
|
||||
data.responses.forEach((toolResp) => callbacks.onToolCallComplete!(toolResp))
|
||||
}
|
||||
break
|
||||
}
|
||||
case ChunkType.EXTERNEL_TOOL_IN_PROGRESS: {
|
||||
if (callbacks.onExternalToolInProgress) callbacks.onExternalToolInProgress()
|
||||
break
|
||||
}
|
||||
case ChunkType.EXTERNEL_TOOL_COMPLETE: {
|
||||
if (callbacks.onExternalToolComplete) callbacks.onExternalToolComplete(data.external_tool)
|
||||
break
|
||||
}
|
||||
case ChunkType.LLM_WEB_SEARCH_IN_PROGRESS: {
|
||||
if (callbacks.onLLMWebSearchInProgress) callbacks.onLLMWebSearchInProgress()
|
||||
break
|
||||
}
|
||||
case ChunkType.LLM_WEB_SEARCH_COMPLETE: {
|
||||
if (callbacks.onLLMWebSearchComplete) callbacks.onLLMWebSearchComplete(data.llm_web_search)
|
||||
break
|
||||
}
|
||||
case ChunkType.IMAGE_CREATED: {
|
||||
if (callbacks.onImageCreated) callbacks.onImageCreated()
|
||||
break
|
||||
}
|
||||
case ChunkType.IMAGE_DELTA: {
|
||||
if (callbacks.onImageDelta) callbacks.onImageDelta(data.image)
|
||||
break
|
||||
}
|
||||
case ChunkType.IMAGE_COMPLETE: {
|
||||
if (callbacks.onImageGenerated) callbacks.onImageGenerated(data.image)
|
||||
break
|
||||
}
|
||||
case ChunkType.LLM_RESPONSE_COMPLETE: {
|
||||
if (callbacks.onLLMResponseComplete) callbacks.onLLMResponseComplete(data.response)
|
||||
break
|
||||
}
|
||||
case ChunkType.ERROR: {
|
||||
if (callbacks.onError) callbacks.onError(data.error)
|
||||
break
|
||||
}
|
||||
default: {
|
||||
// Handle unknown chunk types or log an error
|
||||
console.warn(`Unknown chunk type: ${data.type}`)
|
||||
}
|
||||
}
|
||||
// 2. Process the actual ChunkCallbackData
|
||||
const data = chunk // Cast after checking for 'final'
|
||||
// Invoke callbacks based on the fields present in the chunk data
|
||||
if (data.type === ChunkType.LLM_RESPONSE_CREATED && callbacks.onLLMResponseCreated) {
|
||||
callbacks.onLLMResponseCreated()
|
||||
}
|
||||
if (data.type === ChunkType.TEXT_DELTA && callbacks.onTextChunk) {
|
||||
callbacks.onTextChunk(data.text)
|
||||
}
|
||||
if (data.type === ChunkType.TEXT_COMPLETE && callbacks.onTextComplete) {
|
||||
callbacks.onTextComplete(data.text)
|
||||
}
|
||||
if (data.type === ChunkType.THINKING_DELTA && callbacks.onThinkingChunk) {
|
||||
callbacks.onThinkingChunk(data.text, data.thinking_millsec)
|
||||
}
|
||||
if (data.type === ChunkType.THINKING_COMPLETE && callbacks.onThinkingComplete) {
|
||||
callbacks.onThinkingComplete(data.text, data.thinking_millsec)
|
||||
}
|
||||
if (data.type === ChunkType.MCP_TOOL_IN_PROGRESS && callbacks.onToolCallInProgress) {
|
||||
data.responses.forEach((toolResp) => callbacks.onToolCallInProgress!(toolResp))
|
||||
}
|
||||
if (data.type === ChunkType.MCP_TOOL_COMPLETE && data.responses.length > 0 && callbacks.onToolCallComplete) {
|
||||
data.responses.forEach((toolResp) => callbacks.onToolCallComplete!(toolResp))
|
||||
}
|
||||
if (data.type === ChunkType.EXTERNEL_TOOL_IN_PROGRESS && callbacks.onExternalToolInProgress) {
|
||||
callbacks.onExternalToolInProgress()
|
||||
}
|
||||
if (data.type === ChunkType.EXTERNEL_TOOL_COMPLETE && callbacks.onExternalToolComplete) {
|
||||
callbacks.onExternalToolComplete(data.external_tool)
|
||||
}
|
||||
if (data.type === ChunkType.LLM_WEB_SEARCH_IN_PROGRESS && callbacks.onLLMWebSearchInProgress) {
|
||||
callbacks.onLLMWebSearchInProgress()
|
||||
}
|
||||
if (data.type === ChunkType.LLM_WEB_SEARCH_COMPLETE && callbacks.onLLMWebSearchComplete) {
|
||||
callbacks.onLLMWebSearchComplete(data.llm_web_search)
|
||||
}
|
||||
if (data.type === ChunkType.IMAGE_CREATED && callbacks.onImageCreated) {
|
||||
callbacks.onImageCreated()
|
||||
}
|
||||
if (data.type === ChunkType.IMAGE_COMPLETE && callbacks.onImageGenerated) {
|
||||
callbacks.onImageGenerated(data.image)
|
||||
}
|
||||
if (data.type === ChunkType.ERROR && callbacks.onError) {
|
||||
callbacks.onError(data.error)
|
||||
}
|
||||
// Note: Usage and Metrics are usually handled at the end or accumulated differently,
|
||||
// so direct callbacks might not be the best fit here. They are often part of the final message state.
|
||||
} catch (error) {
|
||||
console.error('Error processing stream chunk:', error)
|
||||
callbacks.onError?.(error)
|
||||
|
||||
@ -8,7 +8,6 @@ import { createStreamProcessor, type StreamProcessorCallbacks } from '@renderer/
|
||||
import { estimateMessagesUsage } from '@renderer/services/TokenService'
|
||||
import store from '@renderer/store'
|
||||
import type { Assistant, ExternalToolResult, FileType, MCPToolResponse, Model, Topic } from '@renderer/types'
|
||||
import { WebSearchSource } from '@renderer/types'
|
||||
import type {
|
||||
CitationMessageBlock,
|
||||
FileMessageBlock,
|
||||
@ -22,7 +21,6 @@ import { AssistantMessageStatus, MessageBlockStatus, MessageBlockType } from '@r
|
||||
import { Response } from '@renderer/types/newMessage'
|
||||
import { uuid } from '@renderer/utils'
|
||||
import { formatErrorMessage, isAbortError } from '@renderer/utils/error'
|
||||
import { extractUrlsFromMarkdown } from '@renderer/utils/linkConverter'
|
||||
import {
|
||||
createAssistantMessage,
|
||||
createBaseMessageBlock,
|
||||
@ -35,7 +33,8 @@ import {
|
||||
createTranslationBlock,
|
||||
resetAssistantMessage
|
||||
} from '@renderer/utils/messageUtils/create'
|
||||
import { getTopicQueue, waitForTopicQueue } from '@renderer/utils/queue'
|
||||
import { getMainTextContent } from '@renderer/utils/messageUtils/find'
|
||||
import { getTopicQueue } from '@renderer/utils/queue'
|
||||
import { isOnHomePage } from '@renderer/utils/window'
|
||||
import { t } from 'i18next'
|
||||
import { isEmpty, throttle } from 'lodash'
|
||||
@ -45,10 +44,10 @@ import type { AppDispatch, RootState } from '../index'
|
||||
import { removeManyBlocks, updateOneBlock, upsertManyBlocks, upsertOneBlock } from '../messageBlock'
|
||||
import { newMessagesActions, selectMessagesForTopic } from '../newMessage'
|
||||
|
||||
const handleChangeLoadingOfTopic = async (topicId: string) => {
|
||||
await waitForTopicQueue(topicId)
|
||||
store.dispatch(newMessagesActions.setTopicLoading({ topicId, loading: false }))
|
||||
}
|
||||
// const handleChangeLoadingOfTopic = async (topicId: string) => {
|
||||
// await waitForTopicQueue(topicId)
|
||||
// store.dispatch(newMessagesActions.setTopicLoading({ topicId, loading: false }))
|
||||
// }
|
||||
// TODO: 后续可以将db操作移到Listener Middleware中
|
||||
export const saveMessageAndBlocksToDB = async (message: Message, blocks: MessageBlock[], messageIndex: number = -1) => {
|
||||
try {
|
||||
@ -337,10 +336,17 @@ const fetchAndProcessAssistantResponseImpl = async (
|
||||
|
||||
let accumulatedContent = ''
|
||||
let accumulatedThinking = ''
|
||||
// 专注于管理UI焦点和块切换
|
||||
let lastBlockId: string | null = null
|
||||
let lastBlockType: MessageBlockType | null = null
|
||||
// 专注于块内部的生命周期处理
|
||||
let initialPlaceholderBlockId: string | null = null
|
||||
let citationBlockId: string | null = null
|
||||
let mainTextBlockId: string | null = null
|
||||
let thinkingBlockId: string | null = null
|
||||
let imageBlockId: string | null = null
|
||||
let toolBlockId: string | null = null
|
||||
let hasWebSearch = false
|
||||
const toolCallIdToBlockIdMap = new Map<string, string>()
|
||||
const notificationService = NotificationService.getInstance()
|
||||
|
||||
@ -400,129 +406,129 @@ const fetchAndProcessAssistantResponseImpl = async (
|
||||
}
|
||||
|
||||
callbacks = {
|
||||
onLLMResponseCreated: () => {
|
||||
onLLMResponseCreated: async () => {
|
||||
const baseBlock = createBaseMessageBlock(assistantMsgId, MessageBlockType.UNKNOWN, {
|
||||
status: MessageBlockStatus.PROCESSING
|
||||
})
|
||||
handleBlockTransition(baseBlock as PlaceholderMessageBlock, MessageBlockType.UNKNOWN)
|
||||
initialPlaceholderBlockId = baseBlock.id
|
||||
await handleBlockTransition(baseBlock as PlaceholderMessageBlock, MessageBlockType.UNKNOWN)
|
||||
},
|
||||
onTextChunk: (text) => {
|
||||
onTextChunk: async (text) => {
|
||||
accumulatedContent += text
|
||||
if (lastBlockId) {
|
||||
if (lastBlockType === MessageBlockType.UNKNOWN) {
|
||||
const initialChanges: Partial<MessageBlock> = {
|
||||
type: MessageBlockType.MAIN_TEXT,
|
||||
content: accumulatedContent,
|
||||
status: MessageBlockStatus.STREAMING,
|
||||
citationReferences: citationBlockId ? [{ citationBlockId }] : []
|
||||
}
|
||||
mainTextBlockId = lastBlockId
|
||||
lastBlockType = MessageBlockType.MAIN_TEXT
|
||||
dispatch(updateOneBlock({ id: lastBlockId, changes: initialChanges }))
|
||||
saveUpdatedBlockToDB(lastBlockId, assistantMsgId, topicId, getState)
|
||||
} else if (lastBlockType === MessageBlockType.MAIN_TEXT) {
|
||||
const blockChanges: Partial<MessageBlock> = {
|
||||
content: accumulatedContent,
|
||||
status: MessageBlockStatus.STREAMING
|
||||
}
|
||||
throttledBlockUpdate(lastBlockId, blockChanges)
|
||||
// throttledBlockDbUpdate(lastBlockId, blockChanges)
|
||||
} else {
|
||||
const newBlock = createMainTextBlock(assistantMsgId, accumulatedContent, {
|
||||
status: MessageBlockStatus.STREAMING,
|
||||
citationReferences: citationBlockId ? [{ citationBlockId }] : []
|
||||
})
|
||||
handleBlockTransition(newBlock, MessageBlockType.MAIN_TEXT)
|
||||
mainTextBlockId = newBlock.id
|
||||
if (mainTextBlockId) {
|
||||
const blockChanges: Partial<MessageBlock> = {
|
||||
content: accumulatedContent,
|
||||
status: MessageBlockStatus.STREAMING
|
||||
}
|
||||
throttledBlockUpdate(mainTextBlockId, blockChanges)
|
||||
} else if (initialPlaceholderBlockId) {
|
||||
// 将占位块转换为主文本块
|
||||
const initialChanges: Partial<MessageBlock> = {
|
||||
type: MessageBlockType.MAIN_TEXT,
|
||||
content: accumulatedContent,
|
||||
status: MessageBlockStatus.STREAMING,
|
||||
citationReferences: citationBlockId ? [{ citationBlockId }] : []
|
||||
}
|
||||
mainTextBlockId = initialPlaceholderBlockId
|
||||
// 清理占位块
|
||||
initialPlaceholderBlockId = null
|
||||
lastBlockType = MessageBlockType.MAIN_TEXT
|
||||
dispatch(updateOneBlock({ id: mainTextBlockId, changes: initialChanges }))
|
||||
saveUpdatedBlockToDB(mainTextBlockId, assistantMsgId, topicId, getState)
|
||||
} else {
|
||||
const newBlock = createMainTextBlock(assistantMsgId, accumulatedContent, {
|
||||
status: MessageBlockStatus.STREAMING,
|
||||
citationReferences: citationBlockId ? [{ citationBlockId }] : []
|
||||
})
|
||||
mainTextBlockId = newBlock.id // 立即设置ID,防止竞态条件
|
||||
await handleBlockTransition(newBlock, MessageBlockType.MAIN_TEXT)
|
||||
}
|
||||
},
|
||||
onTextComplete: async (finalText) => {
|
||||
if (lastBlockType === MessageBlockType.MAIN_TEXT && lastBlockId) {
|
||||
if (mainTextBlockId) {
|
||||
const changes = {
|
||||
content: finalText,
|
||||
status: MessageBlockStatus.SUCCESS
|
||||
}
|
||||
cancelThrottledBlockUpdate(lastBlockId)
|
||||
dispatch(updateOneBlock({ id: lastBlockId, changes }))
|
||||
saveUpdatedBlockToDB(lastBlockId, assistantMsgId, topicId, getState)
|
||||
|
||||
if (assistant.enableWebSearch && assistant.model?.provider === 'openrouter') {
|
||||
const extractedUrls = extractUrlsFromMarkdown(finalText)
|
||||
if (extractedUrls.length > 0) {
|
||||
const citationBlock = createCitationBlock(
|
||||
assistantMsgId,
|
||||
{ response: { source: WebSearchSource.OPENROUTER, results: extractedUrls } },
|
||||
{ status: MessageBlockStatus.SUCCESS }
|
||||
)
|
||||
await handleBlockTransition(citationBlock, MessageBlockType.CITATION)
|
||||
// saveUpdatedBlockToDB(citationBlock.id, assistantMsgId, topicId, getState)
|
||||
}
|
||||
}
|
||||
cancelThrottledBlockUpdate(mainTextBlockId)
|
||||
dispatch(updateOneBlock({ id: mainTextBlockId, changes }))
|
||||
saveUpdatedBlockToDB(mainTextBlockId, assistantMsgId, topicId, getState)
|
||||
mainTextBlockId = null
|
||||
} else {
|
||||
console.warn(
|
||||
`[onTextComplete] Received text.complete but last block was not MAIN_TEXT (was ${lastBlockType}) or lastBlockId is null.`
|
||||
`[onTextComplete] Received text.complete but last block was not MAIN_TEXT (was ${lastBlockType}) or lastBlockId is null.`
|
||||
)
|
||||
}
|
||||
},
|
||||
onThinkingChunk: (text, thinking_millsec) => {
|
||||
accumulatedThinking += text
|
||||
if (lastBlockId) {
|
||||
if (lastBlockType === MessageBlockType.UNKNOWN) {
|
||||
// First chunk for this block: Update type and status immediately
|
||||
lastBlockType = MessageBlockType.THINKING
|
||||
const initialChanges: Partial<MessageBlock> = {
|
||||
type: MessageBlockType.THINKING,
|
||||
content: accumulatedThinking,
|
||||
status: MessageBlockStatus.STREAMING
|
||||
}
|
||||
dispatch(updateOneBlock({ id: lastBlockId, changes: initialChanges }))
|
||||
saveUpdatedBlockToDB(lastBlockId, assistantMsgId, topicId, getState)
|
||||
} else if (lastBlockType === MessageBlockType.THINKING) {
|
||||
const blockChanges: Partial<MessageBlock> = {
|
||||
content: accumulatedThinking,
|
||||
status: MessageBlockStatus.STREAMING,
|
||||
thinking_millsec: thinking_millsec
|
||||
}
|
||||
throttledBlockUpdate(lastBlockId, blockChanges)
|
||||
// throttledBlockDbUpdate(lastBlockId, blockChanges)
|
||||
} else {
|
||||
const newBlock = createThinkingBlock(assistantMsgId, accumulatedThinking, {
|
||||
status: MessageBlockStatus.STREAMING,
|
||||
thinking_millsec: 0
|
||||
})
|
||||
handleBlockTransition(newBlock, MessageBlockType.THINKING)
|
||||
if (citationBlockId && !hasWebSearch) {
|
||||
const changes: Partial<CitationMessageBlock> = {
|
||||
status: MessageBlockStatus.SUCCESS
|
||||
}
|
||||
dispatch(updateOneBlock({ id: citationBlockId, changes }))
|
||||
saveUpdatedBlockToDB(citationBlockId, assistantMsgId, topicId, getState)
|
||||
citationBlockId = null
|
||||
}
|
||||
},
|
||||
onThinkingChunk: async (text, thinking_millsec) => {
|
||||
accumulatedThinking += text
|
||||
if (thinkingBlockId) {
|
||||
const blockChanges: Partial<MessageBlock> = {
|
||||
content: accumulatedThinking,
|
||||
status: MessageBlockStatus.STREAMING,
|
||||
thinking_millsec: thinking_millsec
|
||||
}
|
||||
throttledBlockUpdate(thinkingBlockId, blockChanges)
|
||||
} else if (initialPlaceholderBlockId) {
|
||||
// First chunk for this block: Update type and status immediately
|
||||
lastBlockType = MessageBlockType.THINKING
|
||||
const initialChanges: Partial<MessageBlock> = {
|
||||
type: MessageBlockType.THINKING,
|
||||
content: accumulatedThinking,
|
||||
status: MessageBlockStatus.STREAMING
|
||||
}
|
||||
thinkingBlockId = initialPlaceholderBlockId
|
||||
initialPlaceholderBlockId = null
|
||||
dispatch(updateOneBlock({ id: thinkingBlockId, changes: initialChanges }))
|
||||
saveUpdatedBlockToDB(thinkingBlockId, assistantMsgId, topicId, getState)
|
||||
} else {
|
||||
const newBlock = createThinkingBlock(assistantMsgId, accumulatedThinking, {
|
||||
status: MessageBlockStatus.STREAMING,
|
||||
thinking_millsec: 0
|
||||
})
|
||||
thinkingBlockId = newBlock.id // 立即设置ID,防止竞态条件
|
||||
await handleBlockTransition(newBlock, MessageBlockType.THINKING)
|
||||
}
|
||||
},
|
||||
onThinkingComplete: (finalText, final_thinking_millsec) => {
|
||||
if (lastBlockType === MessageBlockType.THINKING && lastBlockId) {
|
||||
if (thinkingBlockId) {
|
||||
const changes = {
|
||||
type: MessageBlockType.THINKING,
|
||||
content: finalText,
|
||||
status: MessageBlockStatus.SUCCESS,
|
||||
thinking_millsec: final_thinking_millsec
|
||||
}
|
||||
cancelThrottledBlockUpdate(lastBlockId)
|
||||
dispatch(updateOneBlock({ id: lastBlockId, changes }))
|
||||
saveUpdatedBlockToDB(lastBlockId, assistantMsgId, topicId, getState)
|
||||
cancelThrottledBlockUpdate(thinkingBlockId)
|
||||
dispatch(updateOneBlock({ id: thinkingBlockId, changes }))
|
||||
saveUpdatedBlockToDB(thinkingBlockId, assistantMsgId, topicId, getState)
|
||||
} else {
|
||||
console.warn(
|
||||
`[onThinkingComplete] Received thinking.complete but last block was not THINKING (was ${lastBlockType}) or lastBlockId is null.`
|
||||
`[onThinkingComplete] Received thinking.complete but last block was not THINKING (was ${lastBlockType}) or lastBlockId is null.`
|
||||
)
|
||||
}
|
||||
thinkingBlockId = null
|
||||
},
|
||||
onToolCallInProgress: (toolResponse: MCPToolResponse) => {
|
||||
if (lastBlockType === MessageBlockType.UNKNOWN && lastBlockId) {
|
||||
if (initialPlaceholderBlockId) {
|
||||
lastBlockType = MessageBlockType.TOOL
|
||||
const changes = {
|
||||
type: MessageBlockType.TOOL,
|
||||
status: MessageBlockStatus.PROCESSING,
|
||||
metadata: { rawMcpToolResponse: toolResponse }
|
||||
}
|
||||
dispatch(updateOneBlock({ id: lastBlockId, changes }))
|
||||
saveUpdatedBlockToDB(lastBlockId, assistantMsgId, topicId, getState)
|
||||
toolCallIdToBlockIdMap.set(toolResponse.id, lastBlockId)
|
||||
toolBlockId = initialPlaceholderBlockId
|
||||
initialPlaceholderBlockId = null
|
||||
dispatch(updateOneBlock({ id: toolBlockId, changes }))
|
||||
saveUpdatedBlockToDB(toolBlockId, assistantMsgId, topicId, getState)
|
||||
toolCallIdToBlockIdMap.set(toolResponse.id, toolBlockId)
|
||||
} else if (toolResponse.status === 'invoking') {
|
||||
const toolBlock = createToolBlock(assistantMsgId, toolResponse.id, {
|
||||
toolName: toolResponse.tool.name,
|
||||
@ -539,6 +545,7 @@ const fetchAndProcessAssistantResponseImpl = async (
|
||||
},
|
||||
onToolCallComplete: (toolResponse: MCPToolResponse) => {
|
||||
const existingBlockId = toolCallIdToBlockIdMap.get(toolResponse.id)
|
||||
toolCallIdToBlockIdMap.delete(toolResponse.id)
|
||||
if (toolResponse.status === 'done' || toolResponse.status === 'error') {
|
||||
if (!existingBlockId) {
|
||||
console.error(
|
||||
@ -564,10 +571,10 @@ const fetchAndProcessAssistantResponseImpl = async (
|
||||
)
|
||||
}
|
||||
},
|
||||
onExternalToolInProgress: () => {
|
||||
onExternalToolInProgress: async () => {
|
||||
const citationBlock = createCitationBlock(assistantMsgId, {}, { status: MessageBlockStatus.PROCESSING })
|
||||
citationBlockId = citationBlock.id
|
||||
handleBlockTransition(citationBlock, MessageBlockType.CITATION)
|
||||
await handleBlockTransition(citationBlock, MessageBlockType.CITATION)
|
||||
// saveUpdatedBlockToDB(citationBlock.id, assistantMsgId, topicId, getState)
|
||||
},
|
||||
onExternalToolComplete: (externalToolResult: ExternalToolResult) => {
|
||||
@ -583,35 +590,39 @@ const fetchAndProcessAssistantResponseImpl = async (
|
||||
console.error('[onExternalToolComplete] citationBlockId is null. Cannot update.')
|
||||
}
|
||||
},
|
||||
onLLMWebSearchInProgress: () => {
|
||||
const citationBlock = createCitationBlock(assistantMsgId, {}, { status: MessageBlockStatus.PROCESSING })
|
||||
citationBlockId = citationBlock.id
|
||||
handleBlockTransition(citationBlock, MessageBlockType.CITATION)
|
||||
// saveUpdatedBlockToDB(citationBlock.id, assistantMsgId, topicId, getState)
|
||||
onLLMWebSearchInProgress: async () => {
|
||||
if (initialPlaceholderBlockId) {
|
||||
lastBlockType = MessageBlockType.CITATION
|
||||
citationBlockId = initialPlaceholderBlockId
|
||||
const changes = {
|
||||
type: MessageBlockType.CITATION,
|
||||
status: MessageBlockStatus.PROCESSING
|
||||
}
|
||||
lastBlockType = MessageBlockType.CITATION
|
||||
dispatch(updateOneBlock({ id: initialPlaceholderBlockId, changes }))
|
||||
saveUpdatedBlockToDB(initialPlaceholderBlockId, assistantMsgId, topicId, getState)
|
||||
initialPlaceholderBlockId = null
|
||||
} else {
|
||||
const citationBlock = createCitationBlock(assistantMsgId, {}, { status: MessageBlockStatus.PROCESSING })
|
||||
citationBlockId = citationBlock.id
|
||||
await handleBlockTransition(citationBlock, MessageBlockType.CITATION)
|
||||
}
|
||||
},
|
||||
onLLMWebSearchComplete: async (llmWebSearchResult) => {
|
||||
if (citationBlockId) {
|
||||
hasWebSearch = true
|
||||
const changes: Partial<CitationMessageBlock> = {
|
||||
response: llmWebSearchResult,
|
||||
status: MessageBlockStatus.SUCCESS
|
||||
}
|
||||
dispatch(updateOneBlock({ id: citationBlockId, changes }))
|
||||
saveUpdatedBlockToDB(citationBlockId, assistantMsgId, topicId, getState)
|
||||
} else {
|
||||
const citationBlock = createCitationBlock(
|
||||
assistantMsgId,
|
||||
{ response: llmWebSearchResult },
|
||||
{ status: MessageBlockStatus.SUCCESS }
|
||||
)
|
||||
citationBlockId = citationBlock.id
|
||||
handleBlockTransition(citationBlock, MessageBlockType.CITATION)
|
||||
}
|
||||
if (mainTextBlockId) {
|
||||
const state = getState()
|
||||
const existingMainTextBlock = state.messageBlocks.entities[mainTextBlockId]
|
||||
if (existingMainTextBlock && existingMainTextBlock.type === MessageBlockType.MAIN_TEXT) {
|
||||
const currentRefs = existingMainTextBlock.citationReferences || []
|
||||
if (!currentRefs.some((ref) => ref.citationBlockId === citationBlockId)) {
|
||||
|
||||
if (mainTextBlockId) {
|
||||
const state = getState()
|
||||
const existingMainTextBlock = state.messageBlocks.entities[mainTextBlockId]
|
||||
if (existingMainTextBlock && existingMainTextBlock.type === MessageBlockType.MAIN_TEXT) {
|
||||
const currentRefs = existingMainTextBlock.citationReferences || []
|
||||
const mainTextChanges = {
|
||||
citationReferences: [
|
||||
...currentRefs,
|
||||
@ -621,40 +632,64 @@ const fetchAndProcessAssistantResponseImpl = async (
|
||||
dispatch(updateOneBlock({ id: mainTextBlockId, changes: mainTextChanges }))
|
||||
saveUpdatedBlockToDB(mainTextBlockId, assistantMsgId, topicId, getState)
|
||||
}
|
||||
mainTextBlockId = null
|
||||
}
|
||||
}
|
||||
},
|
||||
onImageCreated: () => {
|
||||
if (lastBlockId) {
|
||||
if (lastBlockType === MessageBlockType.UNKNOWN) {
|
||||
const initialChanges: Partial<MessageBlock> = {
|
||||
type: MessageBlockType.IMAGE,
|
||||
status: MessageBlockStatus.STREAMING
|
||||
}
|
||||
lastBlockType = MessageBlockType.IMAGE
|
||||
dispatch(updateOneBlock({ id: lastBlockId, changes: initialChanges }))
|
||||
saveUpdatedBlockToDB(lastBlockId, assistantMsgId, topicId, getState)
|
||||
} else {
|
||||
const imageBlock = createImageBlock(assistantMsgId, {
|
||||
status: MessageBlockStatus.PROCESSING
|
||||
})
|
||||
handleBlockTransition(imageBlock, MessageBlockType.IMAGE)
|
||||
onImageCreated: async () => {
|
||||
if (initialPlaceholderBlockId) {
|
||||
lastBlockType = MessageBlockType.IMAGE
|
||||
const initialChanges: Partial<MessageBlock> = {
|
||||
type: MessageBlockType.IMAGE,
|
||||
status: MessageBlockStatus.STREAMING
|
||||
}
|
||||
lastBlockType = MessageBlockType.IMAGE
|
||||
imageBlockId = initialPlaceholderBlockId
|
||||
initialPlaceholderBlockId = null
|
||||
dispatch(updateOneBlock({ id: imageBlockId, changes: initialChanges }))
|
||||
saveUpdatedBlockToDB(imageBlockId, assistantMsgId, topicId, getState)
|
||||
} else if (!imageBlockId) {
|
||||
const imageBlock = createImageBlock(assistantMsgId, {
|
||||
status: MessageBlockStatus.STREAMING
|
||||
})
|
||||
imageBlockId = imageBlock.id
|
||||
await handleBlockTransition(imageBlock, MessageBlockType.IMAGE)
|
||||
}
|
||||
},
|
||||
onImageGenerated: (imageData) => {
|
||||
onImageDelta: (imageData) => {
|
||||
const imageUrl = imageData.images?.[0] || 'placeholder_image_url'
|
||||
if (lastBlockId && lastBlockType === MessageBlockType.IMAGE) {
|
||||
if (imageBlockId) {
|
||||
const changes: Partial<ImageMessageBlock> = {
|
||||
url: imageUrl,
|
||||
metadata: { generateImageResponse: imageData },
|
||||
status: MessageBlockStatus.SUCCESS
|
||||
status: MessageBlockStatus.STREAMING
|
||||
}
|
||||
dispatch(updateOneBlock({ id: imageBlockId, changes }))
|
||||
saveUpdatedBlockToDB(imageBlockId, assistantMsgId, topicId, getState)
|
||||
}
|
||||
},
|
||||
onImageGenerated: (imageData) => {
|
||||
if (imageBlockId) {
|
||||
if (!imageData) {
|
||||
const changes: Partial<ImageMessageBlock> = {
|
||||
status: MessageBlockStatus.SUCCESS
|
||||
}
|
||||
dispatch(updateOneBlock({ id: imageBlockId, changes }))
|
||||
saveUpdatedBlockToDB(imageBlockId, assistantMsgId, topicId, getState)
|
||||
} else {
|
||||
const imageUrl = imageData.images?.[0] || 'placeholder_image_url'
|
||||
const changes: Partial<ImageMessageBlock> = {
|
||||
url: imageUrl,
|
||||
metadata: { generateImageResponse: imageData },
|
||||
status: MessageBlockStatus.SUCCESS
|
||||
}
|
||||
dispatch(updateOneBlock({ id: imageBlockId, changes }))
|
||||
saveUpdatedBlockToDB(imageBlockId, assistantMsgId, topicId, getState)
|
||||
}
|
||||
dispatch(updateOneBlock({ id: lastBlockId, changes }))
|
||||
saveUpdatedBlockToDB(lastBlockId, assistantMsgId, topicId, getState)
|
||||
} else {
|
||||
console.error('[onImageGenerated] Last block was not an Image block or ID is missing.')
|
||||
}
|
||||
imageBlockId = null
|
||||
},
|
||||
onError: async (error) => {
|
||||
console.dir(error, { depth: null })
|
||||
@ -683,15 +718,16 @@ const fetchAndProcessAssistantResponseImpl = async (
|
||||
source: 'assistant'
|
||||
})
|
||||
}
|
||||
|
||||
if (lastBlockId) {
|
||||
const possibleBlockId =
|
||||
mainTextBlockId || thinkingBlockId || toolBlockId || imageBlockId || citationBlockId || lastBlockId
|
||||
if (possibleBlockId) {
|
||||
// 更改上一个block的状态为ERROR
|
||||
const changes: Partial<MessageBlock> = {
|
||||
status: isErrorTypeAbort ? MessageBlockStatus.PAUSED : MessageBlockStatus.ERROR
|
||||
}
|
||||
cancelThrottledBlockUpdate(lastBlockId)
|
||||
dispatch(updateOneBlock({ id: lastBlockId, changes }))
|
||||
saveUpdatedBlockToDB(lastBlockId, assistantMsgId, topicId, getState)
|
||||
cancelThrottledBlockUpdate(possibleBlockId)
|
||||
dispatch(updateOneBlock({ id: possibleBlockId, changes }))
|
||||
saveUpdatedBlockToDB(possibleBlockId, assistantMsgId, topicId, getState)
|
||||
}
|
||||
|
||||
const errorBlock = createErrorBlock(assistantMsgId, serializableError, { status: MessageBlockStatus.SUCCESS })
|
||||
@ -721,35 +757,45 @@ const fetchAndProcessAssistantResponseImpl = async (
|
||||
const contextForUsage = userMsgIndex !== -1 ? orderedMsgs.slice(0, userMsgIndex + 1) : []
|
||||
const finalContextWithAssistant = [...contextForUsage, finalAssistantMsg]
|
||||
|
||||
if (lastBlockId) {
|
||||
const possibleBlockId =
|
||||
mainTextBlockId || thinkingBlockId || toolBlockId || imageBlockId || citationBlockId || lastBlockId
|
||||
if (possibleBlockId) {
|
||||
const changes: Partial<MessageBlock> = {
|
||||
status: MessageBlockStatus.SUCCESS
|
||||
}
|
||||
cancelThrottledBlockUpdate(lastBlockId)
|
||||
dispatch(updateOneBlock({ id: lastBlockId, changes }))
|
||||
saveUpdatedBlockToDB(lastBlockId, assistantMsgId, topicId, getState)
|
||||
cancelThrottledBlockUpdate(possibleBlockId)
|
||||
dispatch(updateOneBlock({ id: possibleBlockId, changes }))
|
||||
saveUpdatedBlockToDB(possibleBlockId, assistantMsgId, topicId, getState)
|
||||
}
|
||||
|
||||
// const content = getMainTextContent(finalAssistantMsg)
|
||||
// if (!isOnHomePage()) {
|
||||
// await notificationService.send({
|
||||
// id: uuid(),
|
||||
// type: 'success',
|
||||
// title: t('notification.assistant'),
|
||||
// message: content.length > 50 ? content.slice(0, 47) + '...' : content,
|
||||
// silent: false,
|
||||
// timestamp: Date.now(),
|
||||
// source: 'assistant'
|
||||
// })
|
||||
// }
|
||||
const endTime = Date.now()
|
||||
const duration = endTime - startTime
|
||||
const content = getMainTextContent(finalAssistantMsg)
|
||||
if (!isOnHomePage() && duration > 60 * 1000) {
|
||||
await notificationService.send({
|
||||
id: uuid(),
|
||||
type: 'success',
|
||||
title: t('notification.assistant'),
|
||||
message: content.length > 50 ? content.slice(0, 47) + '...' : content,
|
||||
silent: false,
|
||||
timestamp: Date.now(),
|
||||
source: 'assistant'
|
||||
})
|
||||
}
|
||||
|
||||
// 更新topic的name
|
||||
autoRenameTopic(assistant, topicId)
|
||||
|
||||
if (response && response.usage?.total_tokens === 0) {
|
||||
if (
|
||||
response &&
|
||||
(response.usage?.total_tokens === 0 ||
|
||||
response?.usage?.prompt_tokens === 0 ||
|
||||
response?.usage?.completion_tokens === 0)
|
||||
) {
|
||||
const usage = await estimateMessagesUsage({ assistant, messages: finalContextWithAssistant })
|
||||
response.usage = usage
|
||||
}
|
||||
dispatch(newMessagesActions.setTopicLoading({ topicId, loading: false }))
|
||||
}
|
||||
if (response && response.metrics) {
|
||||
if (response.metrics.completion_tokens === 0 && response.usage?.completion_tokens) {
|
||||
@ -779,6 +825,7 @@ const fetchAndProcessAssistantResponseImpl = async (
|
||||
|
||||
const streamProcessorCallbacks = createStreamProcessor(callbacks)
|
||||
|
||||
const startTime = Date.now()
|
||||
await fetchChatCompletion({
|
||||
messages: messagesForContext,
|
||||
assistant: assistant,
|
||||
@ -833,9 +880,10 @@ export const sendMessage =
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error in sendMessage thunk:', error)
|
||||
} finally {
|
||||
handleChangeLoadingOfTopic(topicId)
|
||||
}
|
||||
// finally {
|
||||
// handleChangeLoadingOfTopic(topicId)
|
||||
// }
|
||||
}
|
||||
|
||||
/**
|
||||
@ -1069,9 +1117,10 @@ export const resendMessageThunk =
|
||||
}
|
||||
} catch (error) {
|
||||
console.error(`[resendMessageThunk] Error resending user message ${userMessageToResend.id}:`, error)
|
||||
} finally {
|
||||
handleChangeLoadingOfTopic(topicId)
|
||||
}
|
||||
// finally {
|
||||
// handleChangeLoadingOfTopic(topicId)
|
||||
// }
|
||||
}
|
||||
|
||||
/**
|
||||
@ -1179,10 +1228,11 @@ export const regenerateAssistantResponseThunk =
|
||||
`[regenerateAssistantResponseThunk] Error regenerating response for assistant message ${assistantMessageToRegenerate.id}:`,
|
||||
error
|
||||
)
|
||||
dispatch(newMessagesActions.setTopicLoading({ topicId, loading: false }))
|
||||
} finally {
|
||||
handleChangeLoadingOfTopic(topicId)
|
||||
// dispatch(newMessagesActions.setTopicLoading({ topicId, loading: false }))
|
||||
}
|
||||
// finally {
|
||||
// handleChangeLoadingOfTopic(topicId)
|
||||
// }
|
||||
}
|
||||
|
||||
// --- Thunk to initiate translation and create the initial block ---
|
||||
@ -1348,9 +1398,10 @@ export const appendAssistantResponseThunk =
|
||||
console.error(`[appendAssistantResponseThunk] Error appending assistant response:`, error)
|
||||
// Optionally dispatch an error action or notification
|
||||
// Resetting loading state should be handled by the underlying fetchAndProcessAssistantResponseImpl
|
||||
} finally {
|
||||
handleChangeLoadingOfTopic(topicId)
|
||||
}
|
||||
// finally {
|
||||
// handleChangeLoadingOfTopic(topicId)
|
||||
// }
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import { ExternalToolResult, KnowledgeReference, MCPToolResponse, WebSearchResponse } from '.'
|
||||
import { ExternalToolResult, KnowledgeReference, MCPToolResponse, ToolUseResponse, WebSearchResponse } from '.'
|
||||
import { Response, ResponseError } from './newMessage'
|
||||
import { SdkToolCall } from './sdk'
|
||||
|
||||
// Define Enum for Chunk Types
|
||||
// 目前用到的,并没有列出完整的生命周期
|
||||
@ -11,6 +12,7 @@ export enum ChunkType {
|
||||
WEB_SEARCH_COMPLETE = 'web_search_complete',
|
||||
KNOWLEDGE_SEARCH_IN_PROGRESS = 'knowledge_search_in_progress',
|
||||
KNOWLEDGE_SEARCH_COMPLETE = 'knowledge_search_complete',
|
||||
MCP_TOOL_CREATED = 'mcp_tool_created',
|
||||
MCP_TOOL_IN_PROGRESS = 'mcp_tool_in_progress',
|
||||
MCP_TOOL_COMPLETE = 'mcp_tool_complete',
|
||||
EXTERNEL_TOOL_COMPLETE = 'externel_tool_complete',
|
||||
@ -118,7 +120,7 @@ export interface ImageDeltaChunk {
|
||||
/**
|
||||
* A chunk of Base64 encoded image data
|
||||
*/
|
||||
image: string
|
||||
image: { type: 'base64'; images: string[] }
|
||||
|
||||
/**
|
||||
* The type of the chunk
|
||||
@ -135,7 +137,7 @@ export interface ImageCompleteChunk {
|
||||
/**
|
||||
* The image content of the chunk
|
||||
*/
|
||||
image: { type: 'base64'; images: string[] }
|
||||
image?: { type: 'base64'; images: string[] }
|
||||
}
|
||||
|
||||
export interface ThinkingDeltaChunk {
|
||||
@ -253,6 +255,12 @@ export interface ExternalToolCompleteChunk {
|
||||
type: ChunkType.EXTERNEL_TOOL_COMPLETE
|
||||
}
|
||||
|
||||
export interface MCPToolCreatedChunk {
|
||||
type: ChunkType.MCP_TOOL_CREATED
|
||||
tool_calls?: SdkToolCall[] // 工具调用
|
||||
tool_use_responses?: ToolUseResponse[] // 工具使用响应
|
||||
}
|
||||
|
||||
export interface MCPToolInProgressChunk {
|
||||
/**
|
||||
* The type of the chunk
|
||||
@ -345,6 +353,7 @@ export type Chunk =
|
||||
| WebSearchCompleteChunk // 互联网搜索完成
|
||||
| KnowledgeSearchInProgressChunk // 知识库搜索进行中
|
||||
| KnowledgeSearchCompleteChunk // 知识库搜索完成
|
||||
| MCPToolCreatedChunk // MCP工具被大模型创建
|
||||
| MCPToolInProgressChunk // MCP工具调用中
|
||||
| MCPToolCompleteChunk // MCP工具调用完成
|
||||
| ExternalToolCompleteChunk // 外部工具调用完成,外部工具包含搜索互联网,知识库,MCP服务器
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import type { WebSearchResultBlock } from '@anthropic-ai/sdk/resources'
|
||||
import type { GenerateImagesConfig, GroundingMetadata } from '@google/genai'
|
||||
import type { GenerateImagesConfig, GroundingMetadata, PersonGeneration } from '@google/genai'
|
||||
import type OpenAI from 'openai'
|
||||
import type { CSSProperties } from 'react'
|
||||
|
||||
@ -444,10 +444,11 @@ export type GenerateImageParams = {
|
||||
imageSize: string
|
||||
batchSize: number
|
||||
seed?: string
|
||||
numInferenceSteps: number
|
||||
guidanceScale: number
|
||||
numInferenceSteps?: number
|
||||
guidanceScale?: number
|
||||
signal?: AbortSignal
|
||||
promptEnhancement?: boolean
|
||||
personGeneration?: PersonGeneration
|
||||
}
|
||||
|
||||
export type GenerateImageResponse = {
|
||||
@ -520,7 +521,7 @@ export enum WebSearchSource {
|
||||
}
|
||||
|
||||
export type WebSearchResponse = {
|
||||
results: WebSearchResults
|
||||
results?: WebSearchResults
|
||||
source: WebSearchSource
|
||||
}
|
||||
|
||||
|
||||
107
src/renderer/src/types/sdk.ts
Normal file
107
src/renderer/src/types/sdk.ts
Normal file
@ -0,0 +1,107 @@
|
||||
import Anthropic from '@anthropic-ai/sdk'
|
||||
import {
|
||||
Message,
|
||||
MessageCreateParams,
|
||||
MessageParam,
|
||||
RawMessageStreamEvent,
|
||||
ToolUnion,
|
||||
ToolUseBlock
|
||||
} from '@anthropic-ai/sdk/resources'
|
||||
import { MessageStream } from '@anthropic-ai/sdk/resources/messages/messages'
|
||||
import {
|
||||
Content,
|
||||
CreateChatParameters,
|
||||
FunctionCall,
|
||||
GenerateContentResponse,
|
||||
GoogleGenAI,
|
||||
Model as GeminiModel,
|
||||
SendMessageParameters,
|
||||
Tool
|
||||
} from '@google/genai'
|
||||
import OpenAI, { AzureOpenAI } from 'openai'
|
||||
import { Stream } from 'openai/streaming'
|
||||
|
||||
export type SdkInstance = OpenAI | AzureOpenAI | Anthropic | GoogleGenAI
|
||||
export type SdkParams = OpenAISdkParams | OpenAIResponseSdkParams | AnthropicSdkParams | GeminiSdkParams
|
||||
export type SdkRawChunk = OpenAISdkRawChunk | OpenAIResponseSdkRawChunk | AnthropicSdkRawChunk | GeminiSdkRawChunk
|
||||
export type SdkRawOutput = OpenAISdkRawOutput | OpenAIResponseSdkRawOutput | AnthropicSdkRawOutput | GeminiSdkRawOutput
|
||||
export type SdkMessageParam =
|
||||
| OpenAISdkMessageParam
|
||||
| OpenAIResponseSdkMessageParam
|
||||
| AnthropicSdkMessageParam
|
||||
| GeminiSdkMessageParam
|
||||
export type SdkToolCall =
|
||||
| OpenAI.Chat.Completions.ChatCompletionMessageToolCall
|
||||
| ToolUseBlock
|
||||
| FunctionCall
|
||||
| OpenAIResponseSdkToolCall
|
||||
export type SdkTool = OpenAI.Chat.Completions.ChatCompletionTool | ToolUnion | Tool | OpenAIResponseSdkTool
|
||||
export type SdkModel = OpenAI.Models.Model | Anthropic.ModelInfo | GeminiModel
|
||||
|
||||
export type RequestOptions = Anthropic.RequestOptions | OpenAI.RequestOptions | GeminiOptions
|
||||
|
||||
/**
|
||||
* OpenAI
|
||||
*/
|
||||
|
||||
type OpenAIParamsWithoutReasoningEffort = Omit<OpenAI.Chat.Completions.ChatCompletionCreateParams, 'reasoning_effort'>
|
||||
|
||||
export type ReasoningEffortOptionalParams = {
|
||||
thinking?: { type: 'disabled' | 'enabled'; budget_tokens?: number }
|
||||
reasoning?: { max_tokens?: number; exclude?: boolean; effort?: string } | OpenAI.Reasoning
|
||||
reasoning_effort?: OpenAI.Chat.Completions.ChatCompletionCreateParams['reasoning_effort'] | 'none' | 'auto'
|
||||
enable_thinking?: boolean
|
||||
thinking_budget?: number
|
||||
enable_reasoning?: boolean
|
||||
// Add any other potential reasoning-related keys here if they exist
|
||||
}
|
||||
|
||||
export type OpenAISdkParams = OpenAIParamsWithoutReasoningEffort & ReasoningEffortOptionalParams
|
||||
export type OpenAISdkRawChunk =
|
||||
| OpenAI.Chat.Completions.ChatCompletionChunk
|
||||
| ({
|
||||
_request_id?: string | null | undefined
|
||||
} & OpenAI.ChatCompletion)
|
||||
|
||||
export type OpenAISdkRawOutput = Stream<OpenAI.Chat.Completions.ChatCompletionChunk> | OpenAI.ChatCompletion
|
||||
export type OpenAISdkRawContentSource =
|
||||
| OpenAI.Chat.Completions.ChatCompletionChunk.Choice.Delta
|
||||
| OpenAI.Chat.Completions.ChatCompletionMessage
|
||||
|
||||
export type OpenAISdkMessageParam = OpenAI.Chat.Completions.ChatCompletionMessageParam
|
||||
|
||||
/**
|
||||
* OpenAI Response
|
||||
*/
|
||||
|
||||
export type OpenAIResponseSdkParams = OpenAI.Responses.ResponseCreateParams
|
||||
export type OpenAIResponseSdkRawOutput = Stream<OpenAI.Responses.ResponseStreamEvent> | OpenAI.Responses.Response
|
||||
export type OpenAIResponseSdkRawChunk = OpenAI.Responses.ResponseStreamEvent | OpenAI.Responses.Response
|
||||
export type OpenAIResponseSdkMessageParam = OpenAI.Responses.ResponseInputItem
|
||||
export type OpenAIResponseSdkToolCall = OpenAI.Responses.ResponseFunctionToolCall
|
||||
export type OpenAIResponseSdkTool = OpenAI.Responses.Tool
|
||||
|
||||
/**
|
||||
* Anthropic
|
||||
*/
|
||||
|
||||
export type AnthropicSdkParams = MessageCreateParams
|
||||
export type AnthropicSdkRawOutput = MessageStream | Message
|
||||
export type AnthropicSdkRawChunk = RawMessageStreamEvent | Message
|
||||
export type AnthropicSdkMessageParam = MessageParam
|
||||
|
||||
/**
|
||||
* Gemini
|
||||
*/
|
||||
|
||||
export type GeminiSdkParams = SendMessageParameters & CreateChatParameters
|
||||
export type GeminiSdkRawOutput = AsyncGenerator<GenerateContentResponse> | GenerateContentResponse
|
||||
export type GeminiSdkRawChunk = GenerateContentResponse
|
||||
export type GeminiSdkMessageParam = Content
|
||||
export type GeminiSdkToolCall = FunctionCall
|
||||
|
||||
export type GeminiOptions = {
|
||||
streamOutput: boolean
|
||||
abortSignal?: AbortSignal
|
||||
timeout?: number
|
||||
}
|
||||
@ -369,3 +369,99 @@ export function cleanLinkCommas(text: string): string {
|
||||
// 匹配两个 Markdown 链接之间的英文逗号(可能包含空格)
|
||||
return text.replace(/\]\(([^)]+)\)\s*,\s*\[/g, ']($1)[')
|
||||
}
|
||||
|
||||
/**
|
||||
* 从文本中识别各种格式的Web搜索引用占位符
|
||||
* 支持的格式包括:[1], [ref_1], [1](@ref), [1,2,3](@ref) 等
|
||||
* @param {string} text 要分析的文本
|
||||
* @returns {Array} 识别到的引用信息数组
|
||||
*/
|
||||
export function extractWebSearchReferences(text: string): Array<{
|
||||
match: string
|
||||
placeholder: string
|
||||
numbers: number[]
|
||||
startIndex: number
|
||||
endIndex: number
|
||||
}> {
|
||||
const references: Array<{
|
||||
match: string
|
||||
placeholder: string
|
||||
numbers: number[]
|
||||
startIndex: number
|
||||
endIndex: number
|
||||
}> = []
|
||||
|
||||
// 匹配各种引用格式的正则表达式
|
||||
const patterns = [
|
||||
// [1], [2], [3] - 简单数字引用
|
||||
{ regex: /\[(\d+)\]/g, type: 'simple' },
|
||||
// [ref_1], [ref_2] - Zhipu格式
|
||||
{ regex: /\[ref_(\d+)\]/g, type: 'zhipu' },
|
||||
// [1](@ref), [2](@ref) - Hunyuan单个引用格式
|
||||
{ regex: /\[(\d+)\]\(@ref\)/g, type: 'hunyuan_single' },
|
||||
// [1,2,3](@ref) - Hunyuan多个引用格式
|
||||
{ regex: /\[([\d,\s]+)\]\(@ref\)/g, type: 'hunyuan_multiple' }
|
||||
]
|
||||
|
||||
patterns.forEach(({ regex, type }) => {
|
||||
let match
|
||||
while ((match = regex.exec(text)) !== null) {
|
||||
let numbers: number[] = []
|
||||
|
||||
if (type === 'hunyuan_multiple') {
|
||||
// 解析逗号分隔的数字
|
||||
numbers = match[1]
|
||||
.split(',')
|
||||
.map((num) => parseInt(num.trim()))
|
||||
.filter((num) => !isNaN(num))
|
||||
} else {
|
||||
// 单个数字
|
||||
numbers = [parseInt(match[1])]
|
||||
}
|
||||
|
||||
references.push({
|
||||
match: match[0],
|
||||
placeholder: match[0],
|
||||
numbers: numbers,
|
||||
startIndex: match.index!,
|
||||
endIndex: match.index! + match[0].length
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
// 按位置排序
|
||||
return references.sort((a, b) => a.startIndex - b.startIndex)
|
||||
}
|
||||
|
||||
/**
|
||||
* 智能链接转换器 - 根据文本中的引用模式和Web搜索结果自动选择合适的转换策略
|
||||
* @param {string} text 当前文本块
|
||||
* @param {any[]} webSearchResults Web搜索结果数组
|
||||
* @param {string} providerType Provider类型 ('openai', 'zhipu', 'hunyuan', 'openrouter', etc.)
|
||||
* @param {boolean} resetCounter 是否重置计数器
|
||||
* @returns {string} 转换后的文本
|
||||
*/
|
||||
export function smartLinkConverter(
|
||||
text: string,
|
||||
providerType: string = 'openai',
|
||||
resetCounter: boolean = false
|
||||
): string {
|
||||
// 检测文本中的引用模式
|
||||
const references = extractWebSearchReferences(text)
|
||||
|
||||
if (references.length === 0) {
|
||||
// 如果没有特定的引用模式,使用通用转换
|
||||
return convertLinks(text, resetCounter)
|
||||
}
|
||||
|
||||
// 根据检测到的引用模式选择合适的转换器
|
||||
const hasZhipuPattern = references.some((ref) => ref.placeholder.includes('ref_'))
|
||||
|
||||
if (hasZhipuPattern) {
|
||||
return convertLinksToZhipu(text, resetCounter)
|
||||
} else if (providerType === 'openrouter') {
|
||||
return convertLinksToOpenRouter(text, resetCounter)
|
||||
} else {
|
||||
return convertLinks(text, resetCounter)
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,10 +1,4 @@
|
||||
import {
|
||||
ContentBlockParam,
|
||||
MessageParam,
|
||||
ToolResultBlockParam,
|
||||
ToolUnion,
|
||||
ToolUseBlock
|
||||
} from '@anthropic-ai/sdk/resources'
|
||||
import { ContentBlockParam, MessageParam, ToolUnion, ToolUseBlock } from '@anthropic-ai/sdk/resources'
|
||||
import { Content, FunctionCall, Part, Tool, Type as GeminiSchemaType } from '@google/genai'
|
||||
import Logger from '@renderer/config/logger'
|
||||
import { isFunctionCallingModel, isVisionModel } from '@renderer/config/models'
|
||||
@ -21,6 +15,7 @@ import {
|
||||
} from '@renderer/types'
|
||||
import type { MCPToolCompleteChunk, MCPToolInProgressChunk } from '@renderer/types/chunk'
|
||||
import { ChunkType } from '@renderer/types/chunk'
|
||||
import { SdkMessageParam } from '@renderer/types/sdk'
|
||||
import { isArray, isObject, pull, transform } from 'lodash'
|
||||
import { nanoid } from 'nanoid'
|
||||
import OpenAI from 'openai'
|
||||
@ -31,7 +26,7 @@ import {
|
||||
ChatCompletionTool
|
||||
} from 'openai/resources'
|
||||
|
||||
import { CompletionsParams } from '../providers/AiProvider'
|
||||
import { CompletionsParams } from '../aiCore/middleware/schemas'
|
||||
|
||||
const MCP_AUTO_INSTALL_SERVER_NAME = '@cherry/mcp-auto-install'
|
||||
const EXTRA_SCHEMA_KEYS = ['schema', 'headers']
|
||||
@ -449,13 +444,25 @@ export function parseToolUse(content: string, mcpTools: MCPTool[]): ToolUseRespo
|
||||
if (!content || !mcpTools || mcpTools.length === 0) {
|
||||
return []
|
||||
}
|
||||
|
||||
// 支持两种格式:
|
||||
// 1. 完整的 <tool_use></tool_use> 标签包围的内容
|
||||
// 2. 只有内部内容(从 TagExtractor 提取出来的)
|
||||
|
||||
let contentToProcess = content
|
||||
|
||||
// 如果内容不包含 <tool_use> 标签,说明是从 TagExtractor 提取的内部内容,需要包装
|
||||
if (!content.includes('<tool_use>')) {
|
||||
contentToProcess = `<tool_use>\n${content}\n</tool_use>`
|
||||
}
|
||||
|
||||
const toolUsePattern =
|
||||
/<tool_use>([\s\S]*?)<name>([\s\S]*?)<\/name>([\s\S]*?)<arguments>([\s\S]*?)<\/arguments>([\s\S]*?)<\/tool_use>/g
|
||||
const tools: ToolUseResponse[] = []
|
||||
let match
|
||||
let idx = 0
|
||||
// Find all tool use blocks
|
||||
while ((match = toolUsePattern.exec(content)) !== null) {
|
||||
while ((match = toolUsePattern.exec(contentToProcess)) !== null) {
|
||||
// const fullMatch = match[0]
|
||||
const toolName = match[2].trim()
|
||||
const toolArgs = match[4].trim()
|
||||
@ -497,9 +504,7 @@ export async function parseAndCallTools<R>(
|
||||
convertToMessage: (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => R | undefined,
|
||||
model: Model,
|
||||
mcpTools?: MCPTool[]
|
||||
): Promise<
|
||||
(ChatCompletionMessageParam | MessageParam | Content | OpenAI.Responses.ResponseInputItem | ToolResultBlockParam)[]
|
||||
>
|
||||
): Promise<SdkMessageParam[]>
|
||||
|
||||
export async function parseAndCallTools<R>(
|
||||
content: string,
|
||||
@ -508,9 +513,7 @@ export async function parseAndCallTools<R>(
|
||||
convertToMessage: (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => R | undefined,
|
||||
model: Model,
|
||||
mcpTools?: MCPTool[]
|
||||
): Promise<
|
||||
(ChatCompletionMessageParam | MessageParam | Content | OpenAI.Responses.ResponseInputItem | ToolResultBlockParam)[]
|
||||
>
|
||||
): Promise<SdkMessageParam[]>
|
||||
|
||||
export async function parseAndCallTools<R>(
|
||||
content: string | MCPToolResponse[],
|
||||
@ -539,7 +542,7 @@ export async function parseAndCallTools<R>(
|
||||
...toolResponse,
|
||||
status: 'invoking'
|
||||
},
|
||||
onChunk
|
||||
onChunk!
|
||||
)
|
||||
}
|
||||
|
||||
@ -553,7 +556,7 @@ export async function parseAndCallTools<R>(
|
||||
status: 'done',
|
||||
response: toolCallResponse
|
||||
},
|
||||
onChunk
|
||||
onChunk!
|
||||
)
|
||||
|
||||
for (const content of toolCallResponse.content) {
|
||||
@ -563,10 +566,10 @@ export async function parseAndCallTools<R>(
|
||||
}
|
||||
|
||||
if (images.length) {
|
||||
onChunk({
|
||||
onChunk?.({
|
||||
type: ChunkType.IMAGE_CREATED
|
||||
})
|
||||
onChunk({
|
||||
onChunk?.({
|
||||
type: ChunkType.IMAGE_COMPLETE,
|
||||
image: {
|
||||
type: 'base64',
|
||||
|
||||
@ -101,7 +101,7 @@ export function isEmoji(str: string): boolean {
|
||||
* @returns {string} 处理后的字符串
|
||||
*/
|
||||
export function removeSpecialCharactersForTopicName(str: string): string {
|
||||
return str.replace(/[\r\n]+/g, ' ').trim()
|
||||
return str.replace(/["'\r\n]+/g, ' ').trim()
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@ -31,10 +31,12 @@ export function readableStreamAsyncIterable<T>(stream: any): AsyncIterableIterat
|
||||
}
|
||||
}
|
||||
|
||||
export function asyncGeneratorToReadableStream<T>(gen: AsyncGenerator<T>): ReadableStream<T> {
|
||||
export function asyncGeneratorToReadableStream<T>(gen: AsyncIterable<T>): ReadableStream<T> {
|
||||
const iterator = gen[Symbol.asyncIterator]()
|
||||
|
||||
return new ReadableStream<T>({
|
||||
async pull(controller) {
|
||||
const { value, done } = await gen.next()
|
||||
const { value, done } = await iterator.next()
|
||||
if (done) {
|
||||
controller.close()
|
||||
} else {
|
||||
@ -43,3 +45,17 @@ export function asyncGeneratorToReadableStream<T>(gen: AsyncGenerator<T>): Reada
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* 将单个数据项转换为可读流
|
||||
* @param data 要转换为流的单个数据项
|
||||
* @returns 包含单个数据项的ReadableStream
|
||||
*/
|
||||
export function createSingleChunkReadableStream<T>(data: T): ReadableStream<T> {
|
||||
return new ReadableStream<T>({
|
||||
start(controller) {
|
||||
controller.enqueue(data)
|
||||
controller.close()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
168
src/renderer/src/utils/tagExtraction.ts
Normal file
168
src/renderer/src/utils/tagExtraction.ts
Normal file
@ -0,0 +1,168 @@
|
||||
import { getPotentialStartIndex } from './getPotentialIndex'
|
||||
|
||||
export interface TagConfig {
|
||||
openingTag: string
|
||||
closingTag: string
|
||||
separator?: string
|
||||
}
|
||||
|
||||
export interface TagExtractionState {
|
||||
textBuffer: string
|
||||
isInsideTag: boolean
|
||||
isFirstTag: boolean
|
||||
isFirstText: boolean
|
||||
afterSwitch: boolean
|
||||
accumulatedTagContent: string
|
||||
hasTagContent: boolean
|
||||
}
|
||||
|
||||
export interface TagExtractionResult {
|
||||
content: string
|
||||
isTagContent: boolean
|
||||
complete: boolean
|
||||
tagContentExtracted?: string
|
||||
}
|
||||
|
||||
/**
|
||||
* 通用标签提取处理器
|
||||
* 可以处理各种形式的标签对,如 <think>...</think>, <tool_use>...</tool_use> 等
|
||||
*/
|
||||
export class TagExtractor {
|
||||
private config: TagConfig
|
||||
private state: TagExtractionState
|
||||
|
||||
constructor(config: TagConfig) {
|
||||
this.config = config
|
||||
this.state = {
|
||||
textBuffer: '',
|
||||
isInsideTag: false,
|
||||
isFirstTag: true,
|
||||
isFirstText: true,
|
||||
afterSwitch: false,
|
||||
accumulatedTagContent: '',
|
||||
hasTagContent: false
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理文本块,返回处理结果
|
||||
*/
|
||||
processText(newText: string): TagExtractionResult[] {
|
||||
this.state.textBuffer += newText
|
||||
const results: TagExtractionResult[] = []
|
||||
|
||||
// 处理标签提取逻辑
|
||||
while (true) {
|
||||
const nextTag = this.state.isInsideTag ? this.config.closingTag : this.config.openingTag
|
||||
const startIndex = getPotentialStartIndex(this.state.textBuffer, nextTag)
|
||||
|
||||
if (startIndex == null) {
|
||||
const content = this.state.textBuffer
|
||||
if (content.length > 0) {
|
||||
results.push({
|
||||
content: this.addPrefix(content),
|
||||
isTagContent: this.state.isInsideTag,
|
||||
complete: false
|
||||
})
|
||||
|
||||
if (this.state.isInsideTag) {
|
||||
this.state.accumulatedTagContent += this.addPrefix(content)
|
||||
this.state.hasTagContent = true
|
||||
}
|
||||
}
|
||||
this.state.textBuffer = ''
|
||||
break
|
||||
}
|
||||
|
||||
// 处理标签前的内容
|
||||
const contentBeforeTag = this.state.textBuffer.slice(0, startIndex)
|
||||
if (contentBeforeTag.length > 0) {
|
||||
results.push({
|
||||
content: this.addPrefix(contentBeforeTag),
|
||||
isTagContent: this.state.isInsideTag,
|
||||
complete: false
|
||||
})
|
||||
|
||||
if (this.state.isInsideTag) {
|
||||
this.state.accumulatedTagContent += this.addPrefix(contentBeforeTag)
|
||||
this.state.hasTagContent = true
|
||||
}
|
||||
}
|
||||
|
||||
const foundFullMatch = startIndex + nextTag.length <= this.state.textBuffer.length
|
||||
|
||||
if (foundFullMatch) {
|
||||
// 如果找到完整的标签
|
||||
this.state.textBuffer = this.state.textBuffer.slice(startIndex + nextTag.length)
|
||||
|
||||
// 如果刚刚结束一个标签内容,生成完整的标签内容结果
|
||||
if (this.state.isInsideTag && this.state.hasTagContent) {
|
||||
results.push({
|
||||
content: '',
|
||||
isTagContent: false,
|
||||
complete: true,
|
||||
tagContentExtracted: this.state.accumulatedTagContent
|
||||
})
|
||||
this.state.accumulatedTagContent = ''
|
||||
this.state.hasTagContent = false
|
||||
}
|
||||
|
||||
this.state.isInsideTag = !this.state.isInsideTag
|
||||
this.state.afterSwitch = true
|
||||
|
||||
if (this.state.isInsideTag) {
|
||||
this.state.isFirstTag = false
|
||||
} else {
|
||||
this.state.isFirstText = false
|
||||
}
|
||||
} else {
|
||||
this.state.textBuffer = this.state.textBuffer.slice(startIndex)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
/**
|
||||
* 完成处理,返回任何剩余的标签内容
|
||||
*/
|
||||
finalize(): TagExtractionResult | null {
|
||||
if (this.state.hasTagContent && this.state.accumulatedTagContent) {
|
||||
const result = {
|
||||
content: '',
|
||||
isTagContent: false,
|
||||
complete: true,
|
||||
tagContentExtracted: this.state.accumulatedTagContent
|
||||
}
|
||||
this.state.accumulatedTagContent = ''
|
||||
this.state.hasTagContent = false
|
||||
return result
|
||||
}
|
||||
return null
|
||||
}
|
||||
|
||||
private addPrefix(text: string): string {
|
||||
const needsPrefix =
|
||||
this.state.afterSwitch && (this.state.isInsideTag ? !this.state.isFirstTag : !this.state.isFirstText)
|
||||
|
||||
const prefix = needsPrefix && this.config.separator ? this.config.separator : ''
|
||||
this.state.afterSwitch = false
|
||||
return prefix + text
|
||||
}
|
||||
|
||||
/**
|
||||
* 重置状态
|
||||
*/
|
||||
reset(): void {
|
||||
this.state = {
|
||||
textBuffer: '',
|
||||
isInsideTag: false,
|
||||
isFirstTag: true,
|
||||
isFirstText: true,
|
||||
afterSwitch: false,
|
||||
accumulatedTagContent: '',
|
||||
hasTagContent: false
|
||||
}
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user