mirror of
https://github.com/CherryHQ/cherry-studio.git
synced 2025-12-29 14:31:35 +08:00
记忆功能
This commit is contained in:
parent
8b95a131ec
commit
b8dffce149
303
src/main/mcpServers/simpleremember.ts
Normal file
303
src/main/mcpServers/simpleremember.ts
Normal file
@ -0,0 +1,303 @@
|
||||
// src/main/mcpServers/simpleremember.ts
|
||||
import { getConfigDir } from '@main/utils/file'
|
||||
import { Server } from '@modelcontextprotocol/sdk/server/index.js'
|
||||
import { CallToolRequestSchema, ListToolsRequestSchema, ListPromptsRequestSchema, McpError, ErrorCode } from '@modelcontextprotocol/sdk/types.js'
|
||||
import { promises as fs } from 'fs'
|
||||
import path from 'path'
|
||||
import { Mutex } from 'async-mutex'
|
||||
|
||||
// 定义记忆文件路径
|
||||
const defaultMemoryPath = path.join(getConfigDir(), 'simpleremember.json')
|
||||
|
||||
// 记忆项接口
|
||||
interface Memory {
|
||||
content: string;
|
||||
createdAt: string;
|
||||
}
|
||||
|
||||
// 记忆存储结构
|
||||
interface MemoryStorage {
|
||||
memories: Memory[];
|
||||
}
|
||||
|
||||
class SimpleRememberManager {
|
||||
private memoryPath: string;
|
||||
private memories: Memory[] = [];
|
||||
private fileMutex: Mutex = new Mutex();
|
||||
|
||||
constructor(memoryPath: string) {
|
||||
this.memoryPath = memoryPath;
|
||||
}
|
||||
|
||||
// 静态工厂方法用于初始化
|
||||
public static async create(memoryPath: string): Promise<SimpleRememberManager> {
|
||||
const manager = new SimpleRememberManager(memoryPath);
|
||||
await manager._ensureMemoryPathExists();
|
||||
await manager._loadMemoriesFromDisk();
|
||||
return manager;
|
||||
}
|
||||
|
||||
// 确保记忆文件存在
|
||||
private async _ensureMemoryPathExists(): Promise<void> {
|
||||
try {
|
||||
const directory = path.dirname(this.memoryPath);
|
||||
await fs.mkdir(directory, { recursive: true });
|
||||
try {
|
||||
await fs.access(this.memoryPath);
|
||||
} catch (error) {
|
||||
// 文件不存在,创建一个空文件
|
||||
await fs.writeFile(this.memoryPath, JSON.stringify({ memories: [] }, null, 2));
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Failed to ensure memory path exists:', error);
|
||||
throw new McpError(ErrorCode.InternalError, `Failed to ensure memory path: ${error instanceof Error ? error.message : String(error)}`);
|
||||
}
|
||||
}
|
||||
|
||||
// 从磁盘加载记忆
|
||||
private async _loadMemoriesFromDisk(): Promise<void> {
|
||||
try {
|
||||
const data = await fs.readFile(this.memoryPath, 'utf-8');
|
||||
// 处理空文件情况
|
||||
if (data.trim() === '') {
|
||||
this.memories = [];
|
||||
await this._persistMemories();
|
||||
return;
|
||||
}
|
||||
const storage: MemoryStorage = JSON.parse(data);
|
||||
this.memories = storage.memories || [];
|
||||
} catch (error) {
|
||||
if (error instanceof Error && 'code' in error && (error as any).code === 'ENOENT') {
|
||||
this.memories = [];
|
||||
await this._persistMemories();
|
||||
} else if (error instanceof SyntaxError) {
|
||||
console.error('Failed to parse simpleremember.json, initializing with empty memories:', error);
|
||||
this.memories = [];
|
||||
await this._persistMemories();
|
||||
} else {
|
||||
console.error('Unexpected error loading memories:', error);
|
||||
throw new McpError(ErrorCode.InternalError, `Failed to load memories: ${error instanceof Error ? error.message : String(error)}`);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 将记忆持久化到磁盘
|
||||
private async _persistMemories(): Promise<void> {
|
||||
const release = await this.fileMutex.acquire();
|
||||
try {
|
||||
const storage: MemoryStorage = {
|
||||
memories: this.memories
|
||||
};
|
||||
await fs.writeFile(this.memoryPath, JSON.stringify(storage, null, 2));
|
||||
} catch (error) {
|
||||
console.error('Failed to save memories:', error);
|
||||
throw new McpError(ErrorCode.InternalError, `Failed to save memories: ${error instanceof Error ? error.message : String(error)}`);
|
||||
} finally {
|
||||
release();
|
||||
}
|
||||
}
|
||||
|
||||
// 添加新记忆
|
||||
async remember(memory: string): Promise<Memory> {
|
||||
const newMemory: Memory = {
|
||||
content: memory,
|
||||
createdAt: new Date().toISOString()
|
||||
};
|
||||
this.memories.push(newMemory);
|
||||
await this._persistMemories();
|
||||
return newMemory;
|
||||
}
|
||||
|
||||
// 获取所有记忆
|
||||
async getAllMemories(): Promise<Memory[]> {
|
||||
return [...this.memories];
|
||||
}
|
||||
|
||||
// 获取记忆 - 这个方法会被get_memories工具调用
|
||||
async get_memories(): Promise<Memory[]> {
|
||||
return this.getAllMemories();
|
||||
}
|
||||
}
|
||||
|
||||
// 定义工具 - 按照MCP规范定义工具
|
||||
const REMEMBER_TOOL = {
|
||||
name: 'remember',
|
||||
description: '用于记忆长期有用信息的工具。这个工具会自动应用记忆,无需显式调用。只用于存储长期有用的信息,不适合临时信息。',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
memory: {
|
||||
type: 'string',
|
||||
description: '要记住的简洁(1句话)记忆内容'
|
||||
}
|
||||
},
|
||||
required: ['memory']
|
||||
}
|
||||
};
|
||||
|
||||
const GET_MEMORIES_TOOL = {
|
||||
name: 'get_memories',
|
||||
description: '获取所有已存储的记忆',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {}
|
||||
}
|
||||
};
|
||||
|
||||
// 添加日志以便调试
|
||||
console.log("[SimpleRemember] Defined tools:", { REMEMBER_TOOL, GET_MEMORIES_TOOL });
|
||||
|
||||
class SimpleRememberServer {
|
||||
public server: Server;
|
||||
private simpleRememberManager: SimpleRememberManager | null = null;
|
||||
private initializationPromise: Promise<void>;
|
||||
|
||||
constructor(envPath: string = '') {
|
||||
const memoryPath = envPath
|
||||
? path.isAbsolute(envPath)
|
||||
? envPath
|
||||
: path.resolve(envPath)
|
||||
: defaultMemoryPath;
|
||||
|
||||
console.log("[SimpleRemember] Creating server with memory path:", memoryPath);
|
||||
|
||||
// 初始化服务器
|
||||
this.server = new Server(
|
||||
{
|
||||
name: 'simple-remember-server',
|
||||
version: '1.0.0'
|
||||
},
|
||||
{
|
||||
capabilities: {
|
||||
tools: {
|
||||
// 按照MCP规范声明工具能力
|
||||
listChanged: true
|
||||
},
|
||||
// 添加空的prompts能力,表示支持提示词功能但没有实际的提示词
|
||||
prompts: {}
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
console.log("[SimpleRemember] Server initialized with tools capability");
|
||||
|
||||
// 手动添加工具到服务器的工具列表中
|
||||
console.log("[SimpleRemember] Adding tools to server");
|
||||
|
||||
// 先设置请求处理程序,再初始化管理器
|
||||
this.setupRequestHandlers();
|
||||
this.initializationPromise = this._initializeManager(memoryPath);
|
||||
|
||||
console.log("[SimpleRemember] Server initialization complete");
|
||||
// 打印工具信息以确认它们已注册
|
||||
console.log("[SimpleRemember] Tools registered:", [REMEMBER_TOOL.name, GET_MEMORIES_TOOL.name]);
|
||||
}
|
||||
|
||||
private async _initializeManager(memoryPath: string): Promise<void> {
|
||||
try {
|
||||
this.simpleRememberManager = await SimpleRememberManager.create(memoryPath);
|
||||
console.log("SimpleRememberManager initialized successfully.");
|
||||
} catch (error) {
|
||||
console.error("Failed to initialize SimpleRememberManager:", error);
|
||||
this.simpleRememberManager = null;
|
||||
}
|
||||
}
|
||||
|
||||
private async _getManager(): Promise<SimpleRememberManager> {
|
||||
if (!this.simpleRememberManager) {
|
||||
await this.initializationPromise;
|
||||
if (!this.simpleRememberManager) {
|
||||
throw new McpError(ErrorCode.InternalError, "SimpleRememberManager is not initialized");
|
||||
}
|
||||
}
|
||||
return this.simpleRememberManager;
|
||||
}
|
||||
|
||||
setupRequestHandlers() {
|
||||
// 添加对prompts/list请求的处理
|
||||
this.server.setRequestHandler(ListPromptsRequestSchema, async (request) => {
|
||||
console.log("[SimpleRemember] Listing prompts request received", request);
|
||||
|
||||
// 返回空的提示词列表
|
||||
return {
|
||||
prompts: []
|
||||
};
|
||||
});
|
||||
|
||||
this.server.setRequestHandler(ListToolsRequestSchema, async (request) => {
|
||||
// 直接返回工具列表,不需要等待管理器初始化
|
||||
console.log("[SimpleRemember] Listing tools request received", request);
|
||||
|
||||
// 打印工具定义以确保它们存在
|
||||
console.log("[SimpleRemember] REMEMBER_TOOL:", JSON.stringify(REMEMBER_TOOL));
|
||||
console.log("[SimpleRemember] GET_MEMORIES_TOOL:", JSON.stringify(GET_MEMORIES_TOOL));
|
||||
|
||||
const toolsList = [REMEMBER_TOOL, GET_MEMORIES_TOOL];
|
||||
console.log("[SimpleRemember] Returning tools:", JSON.stringify(toolsList));
|
||||
|
||||
// 按照MCP规范返回工具列表
|
||||
return {
|
||||
tools: toolsList,
|
||||
// 如果有分页,可以添加nextCursor
|
||||
// nextCursor: "next-page-cursor"
|
||||
};
|
||||
});
|
||||
|
||||
this.server.setRequestHandler(CallToolRequestSchema, async (request) => {
|
||||
const { name, arguments: args } = request.params;
|
||||
|
||||
console.log(`[SimpleRemember] Received tool call: ${name}`, args);
|
||||
|
||||
try {
|
||||
const manager = await this._getManager();
|
||||
|
||||
if (name === 'remember') {
|
||||
if (!args || typeof args.memory !== 'string') {
|
||||
console.error(`[SimpleRemember] Invalid arguments for ${name}:`, args);
|
||||
throw new McpError(ErrorCode.InvalidParams, `Invalid arguments for ${name}: 'memory' string is required.`);
|
||||
}
|
||||
console.log(`[SimpleRemember] Remembering: "${args.memory}"`);
|
||||
const result = await manager.remember(args.memory);
|
||||
console.log(`[SimpleRemember] Memory saved successfully:`, result);
|
||||
// 按照MCP规范返回工具调用结果
|
||||
return {
|
||||
content: [{
|
||||
type: 'text',
|
||||
text: `记忆已保存: "${args.memory}"`
|
||||
}],
|
||||
isError: false
|
||||
};
|
||||
}
|
||||
|
||||
if (name === 'get_memories') {
|
||||
console.log(`[SimpleRemember] Getting all memories`);
|
||||
const memories = await manager.get_memories();
|
||||
console.log(`[SimpleRemember] Retrieved ${memories.length} memories`);
|
||||
// 按照MCP规范返回工具调用结果
|
||||
return {
|
||||
content: [{
|
||||
type: 'text',
|
||||
text: JSON.stringify(memories, null, 2)
|
||||
}],
|
||||
isError: false
|
||||
};
|
||||
}
|
||||
|
||||
console.error(`[SimpleRemember] Unknown tool: ${name}`);
|
||||
throw new McpError(ErrorCode.MethodNotFound, `Unknown tool: ${name}`);
|
||||
} catch (error) {
|
||||
console.error(`[SimpleRemember] Error handling tool call ${name}:`, error);
|
||||
// 按照MCP规范返回工具调用错误
|
||||
return {
|
||||
content: [{
|
||||
type: 'text',
|
||||
text: error instanceof Error ? error.message : String(error)
|
||||
}],
|
||||
isError: true
|
||||
};
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
export default SimpleRememberServer;
|
||||
@ -14,6 +14,8 @@ import { Assistant, FileTypes, MCPToolResponse, Message, Model, Provider, Sugges
|
||||
import { removeSpecialCharactersForTopicName } from '@renderer/utils'
|
||||
import { parseAndCallTools } from '@renderer/utils/mcp-tools'
|
||||
import { buildSystemPrompt } from '@renderer/utils/prompt'
|
||||
import store from '@renderer/store'
|
||||
import { getActiveServers } from '@renderer/store/mcp'
|
||||
import { first, flatten, sum, takeRight } from 'lodash'
|
||||
import OpenAI from 'openai'
|
||||
|
||||
@ -177,7 +179,7 @@ export default class AnthropicProvider extends BaseProvider {
|
||||
|
||||
let systemPrompt = assistant.prompt
|
||||
if (mcpTools && mcpTools.length > 0) {
|
||||
systemPrompt = buildSystemPrompt(systemPrompt, mcpTools)
|
||||
systemPrompt = await buildSystemPrompt(systemPrompt, mcpTools, getActiveServers(store.getState()))
|
||||
}
|
||||
|
||||
const body: MessageCreateParamsNonStreaming = {
|
||||
|
||||
@ -24,6 +24,8 @@ import { getStoreSetting } from '@renderer/hooks/useSettings'
|
||||
import i18n from '@renderer/i18n'
|
||||
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
|
||||
import { EVENT_NAMES } from '@renderer/services/EventService'
|
||||
import store from '@renderer/store'
|
||||
import { getActiveServers } from '@renderer/store/mcp'
|
||||
import {
|
||||
filterContextMessages,
|
||||
filterEmptyMessages,
|
||||
@ -228,7 +230,7 @@ export default class GeminiProvider extends BaseProvider {
|
||||
let systemInstruction = assistant.prompt
|
||||
|
||||
if (mcpTools && mcpTools.length > 0) {
|
||||
systemInstruction = buildSystemPrompt(assistant.prompt || '', mcpTools)
|
||||
systemInstruction = await buildSystemPrompt(assistant.prompt || '', mcpTools, getActiveServers(store.getState()))
|
||||
}
|
||||
|
||||
// const tools = mcpToolsToGeminiTools(mcpTools)
|
||||
|
||||
@ -20,6 +20,7 @@ import {
|
||||
filterUserRoleStartMessages
|
||||
} from '@renderer/services/MessagesService'
|
||||
import store from '@renderer/store'
|
||||
import { getActiveServers } from '@renderer/store/mcp'
|
||||
import {
|
||||
Assistant,
|
||||
FileTypes,
|
||||
@ -318,7 +319,7 @@ export default class OpenAIProvider extends BaseProvider {
|
||||
}
|
||||
}
|
||||
if (mcpTools && mcpTools.length > 0) {
|
||||
systemMessage.content = buildSystemPrompt(systemMessage.content || '', mcpTools)
|
||||
systemMessage.content = await buildSystemPrompt(systemMessage.content || '', mcpTools, getActiveServers(store.getState()))
|
||||
}
|
||||
|
||||
const userMessages: ChatCompletionMessageParam[] = []
|
||||
|
||||
@ -97,6 +97,13 @@ export const builtinMCPServers: MCPServer[] = [
|
||||
type: 'inMemory',
|
||||
description: '实现文件系统操作的模型上下文协议(MCP)的 Node.js 服务器',
|
||||
isActive: false
|
||||
},
|
||||
{
|
||||
id: nanoid(),
|
||||
name: '@cherry/simpleremember',
|
||||
type: 'inMemory',
|
||||
description: '自动记忆工具,功能跟上面的记忆工具差不多。这个记忆会自动应用到对话中,无需显式调用。适合记住用户偏好、项目背景等长期有用信息.可以跨对话。',
|
||||
isActive: true
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
@ -147,12 +147,28 @@ ${availableTools}
|
||||
</tools>`
|
||||
}
|
||||
|
||||
export const buildSystemPrompt = (userSystemPrompt: string, tools: MCPTool[]): string => {
|
||||
import { MCPServer } from '@renderer/types'
|
||||
import { getRememberedMemories } from './remember-utils'
|
||||
|
||||
export const buildSystemPrompt = async (userSystemPrompt: string, tools: MCPTool[], mcpServers: MCPServer[] = []): Promise<string> => {
|
||||
// 获取记忆
|
||||
let memoriesPrompt = '';
|
||||
try {
|
||||
memoriesPrompt = await getRememberedMemories(mcpServers);
|
||||
} catch (error) {
|
||||
console.error('Error getting memories:', error);
|
||||
}
|
||||
|
||||
// 添加记忆工具的使用说明
|
||||
const rememberInstructions = '\n\n您可以使用remember工具记住用户的长期偏好和重要信息。当用户说"请记住..."或"记住..."时,使用remember工具存储这些信息。记忆会自动应用到所有对话中,无需显式调用。';
|
||||
|
||||
const enhancedPrompt = userSystemPrompt + rememberInstructions + memoriesPrompt;
|
||||
|
||||
if (tools && tools.length > 0) {
|
||||
return SYSTEM_PROMPT.replace('{{ USER_SYSTEM_PROMPT }}', userSystemPrompt)
|
||||
return SYSTEM_PROMPT.replace('{{ USER_SYSTEM_PROMPT }}', enhancedPrompt)
|
||||
.replace('{{ TOOL_USE_EXAMPLES }}', ToolUseExamples)
|
||||
.replace('{{ AVAILABLE_TOOLS }}', AvailableTools(tools))
|
||||
}
|
||||
|
||||
return userSystemPrompt
|
||||
return enhancedPrompt
|
||||
}
|
||||
|
||||
67
src/renderer/src/utils/remember-utils.ts
Normal file
67
src/renderer/src/utils/remember-utils.ts
Normal file
@ -0,0 +1,67 @@
|
||||
// src/renderer/src/utils/remember-utils.ts
|
||||
import { MCPServer } from '@renderer/types'
|
||||
|
||||
export async function getRememberedMemories(mcpServers: MCPServer[]): Promise<string> {
|
||||
try {
|
||||
// 查找simpleremember服务器
|
||||
const rememberServer = mcpServers.find(server => server.name === '@cherry/simpleremember' && server.isActive);
|
||||
|
||||
if (!rememberServer) {
|
||||
console.log('[SimpleRemember] Server not found or not active');
|
||||
return '';
|
||||
}
|
||||
|
||||
console.log('[SimpleRemember] Found server:', rememberServer.name, 'isActive:', rememberServer.isActive);
|
||||
|
||||
// 调用get_memories工具
|
||||
try {
|
||||
console.log('[SimpleRemember] Calling get_memories tool...');
|
||||
const response = await window.api.mcp.callTool({
|
||||
server: rememberServer,
|
||||
name: 'get_memories',
|
||||
args: {}
|
||||
});
|
||||
|
||||
console.log('[SimpleRemember] get_memories response:', response);
|
||||
|
||||
if (response.isError) {
|
||||
console.error('[SimpleRemember] Error getting memories:', response);
|
||||
return '';
|
||||
}
|
||||
|
||||
// 解析记忆
|
||||
// 根据MCP规范,工具返回的是content数组,而不是data
|
||||
let memories = [];
|
||||
if (response.content && response.content.length > 0 && response.content[0].text) {
|
||||
try {
|
||||
memories = JSON.parse(response.content[0].text);
|
||||
} catch (parseError) {
|
||||
console.error('[SimpleRemember] Failed to parse memories JSON:', parseError);
|
||||
return '';
|
||||
}
|
||||
} else if (response.data) {
|
||||
// 兼容旧版本的返回格式
|
||||
memories = response.data;
|
||||
}
|
||||
|
||||
console.log('[SimpleRemember] Parsed memories:', memories);
|
||||
|
||||
if (!Array.isArray(memories) || memories.length === 0) {
|
||||
console.log('[SimpleRemember] No memories found or invalid format');
|
||||
return '';
|
||||
}
|
||||
|
||||
// 构建记忆提示词
|
||||
const memoryPrompt = memories.map(memory => `- ${memory.content}`).join('\n');
|
||||
console.log('[SimpleRemember] Generated memory prompt:', memoryPrompt);
|
||||
|
||||
return `\n\n用户的记忆:\n${memoryPrompt}`;
|
||||
} catch (toolError) {
|
||||
console.error('[SimpleRemember] Error calling get_memories tool:', toolError);
|
||||
return '';
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('[SimpleRemember] Error in getRememberedMemories:', error);
|
||||
return '';
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user