记忆功能

This commit is contained in:
1600822305 2025-04-12 22:03:13 +08:00
parent 8b95a131ec
commit b8dffce149
7 changed files with 404 additions and 6 deletions

View File

@ -0,0 +1,303 @@
// src/main/mcpServers/simpleremember.ts
import { getConfigDir } from '@main/utils/file'
import { Server } from '@modelcontextprotocol/sdk/server/index.js'
import { CallToolRequestSchema, ListToolsRequestSchema, ListPromptsRequestSchema, McpError, ErrorCode } from '@modelcontextprotocol/sdk/types.js'
import { promises as fs } from 'fs'
import path from 'path'
import { Mutex } from 'async-mutex'
// 定义记忆文件路径
const defaultMemoryPath = path.join(getConfigDir(), 'simpleremember.json')
// 记忆项接口
interface Memory {
content: string;
createdAt: string;
}
// 记忆存储结构
interface MemoryStorage {
memories: Memory[];
}
class SimpleRememberManager {
private memoryPath: string;
private memories: Memory[] = [];
private fileMutex: Mutex = new Mutex();
constructor(memoryPath: string) {
this.memoryPath = memoryPath;
}
// 静态工厂方法用于初始化
public static async create(memoryPath: string): Promise<SimpleRememberManager> {
const manager = new SimpleRememberManager(memoryPath);
await manager._ensureMemoryPathExists();
await manager._loadMemoriesFromDisk();
return manager;
}
// 确保记忆文件存在
private async _ensureMemoryPathExists(): Promise<void> {
try {
const directory = path.dirname(this.memoryPath);
await fs.mkdir(directory, { recursive: true });
try {
await fs.access(this.memoryPath);
} catch (error) {
// 文件不存在,创建一个空文件
await fs.writeFile(this.memoryPath, JSON.stringify({ memories: [] }, null, 2));
}
} catch (error) {
console.error('Failed to ensure memory path exists:', error);
throw new McpError(ErrorCode.InternalError, `Failed to ensure memory path: ${error instanceof Error ? error.message : String(error)}`);
}
}
// 从磁盘加载记忆
private async _loadMemoriesFromDisk(): Promise<void> {
try {
const data = await fs.readFile(this.memoryPath, 'utf-8');
// 处理空文件情况
if (data.trim() === '') {
this.memories = [];
await this._persistMemories();
return;
}
const storage: MemoryStorage = JSON.parse(data);
this.memories = storage.memories || [];
} catch (error) {
if (error instanceof Error && 'code' in error && (error as any).code === 'ENOENT') {
this.memories = [];
await this._persistMemories();
} else if (error instanceof SyntaxError) {
console.error('Failed to parse simpleremember.json, initializing with empty memories:', error);
this.memories = [];
await this._persistMemories();
} else {
console.error('Unexpected error loading memories:', error);
throw new McpError(ErrorCode.InternalError, `Failed to load memories: ${error instanceof Error ? error.message : String(error)}`);
}
}
}
// 将记忆持久化到磁盘
private async _persistMemories(): Promise<void> {
const release = await this.fileMutex.acquire();
try {
const storage: MemoryStorage = {
memories: this.memories
};
await fs.writeFile(this.memoryPath, JSON.stringify(storage, null, 2));
} catch (error) {
console.error('Failed to save memories:', error);
throw new McpError(ErrorCode.InternalError, `Failed to save memories: ${error instanceof Error ? error.message : String(error)}`);
} finally {
release();
}
}
// 添加新记忆
async remember(memory: string): Promise<Memory> {
const newMemory: Memory = {
content: memory,
createdAt: new Date().toISOString()
};
this.memories.push(newMemory);
await this._persistMemories();
return newMemory;
}
// 获取所有记忆
async getAllMemories(): Promise<Memory[]> {
return [...this.memories];
}
// 获取记忆 - 这个方法会被get_memories工具调用
async get_memories(): Promise<Memory[]> {
return this.getAllMemories();
}
}
// 定义工具 - 按照MCP规范定义工具
const REMEMBER_TOOL = {
name: 'remember',
description: '用于记忆长期有用信息的工具。这个工具会自动应用记忆,无需显式调用。只用于存储长期有用的信息,不适合临时信息。',
inputSchema: {
type: 'object',
properties: {
memory: {
type: 'string',
description: '要记住的简洁(1句话)记忆内容'
}
},
required: ['memory']
}
};
const GET_MEMORIES_TOOL = {
name: 'get_memories',
description: '获取所有已存储的记忆',
inputSchema: {
type: 'object',
properties: {}
}
};
// 添加日志以便调试
console.log("[SimpleRemember] Defined tools:", { REMEMBER_TOOL, GET_MEMORIES_TOOL });
class SimpleRememberServer {
public server: Server;
private simpleRememberManager: SimpleRememberManager | null = null;
private initializationPromise: Promise<void>;
constructor(envPath: string = '') {
const memoryPath = envPath
? path.isAbsolute(envPath)
? envPath
: path.resolve(envPath)
: defaultMemoryPath;
console.log("[SimpleRemember] Creating server with memory path:", memoryPath);
// 初始化服务器
this.server = new Server(
{
name: 'simple-remember-server',
version: '1.0.0'
},
{
capabilities: {
tools: {
// 按照MCP规范声明工具能力
listChanged: true
},
// 添加空的prompts能力表示支持提示词功能但没有实际的提示词
prompts: {}
}
}
);
console.log("[SimpleRemember] Server initialized with tools capability");
// 手动添加工具到服务器的工具列表中
console.log("[SimpleRemember] Adding tools to server");
// 先设置请求处理程序,再初始化管理器
this.setupRequestHandlers();
this.initializationPromise = this._initializeManager(memoryPath);
console.log("[SimpleRemember] Server initialization complete");
// 打印工具信息以确认它们已注册
console.log("[SimpleRemember] Tools registered:", [REMEMBER_TOOL.name, GET_MEMORIES_TOOL.name]);
}
private async _initializeManager(memoryPath: string): Promise<void> {
try {
this.simpleRememberManager = await SimpleRememberManager.create(memoryPath);
console.log("SimpleRememberManager initialized successfully.");
} catch (error) {
console.error("Failed to initialize SimpleRememberManager:", error);
this.simpleRememberManager = null;
}
}
private async _getManager(): Promise<SimpleRememberManager> {
if (!this.simpleRememberManager) {
await this.initializationPromise;
if (!this.simpleRememberManager) {
throw new McpError(ErrorCode.InternalError, "SimpleRememberManager is not initialized");
}
}
return this.simpleRememberManager;
}
setupRequestHandlers() {
// 添加对prompts/list请求的处理
this.server.setRequestHandler(ListPromptsRequestSchema, async (request) => {
console.log("[SimpleRemember] Listing prompts request received", request);
// 返回空的提示词列表
return {
prompts: []
};
});
this.server.setRequestHandler(ListToolsRequestSchema, async (request) => {
// 直接返回工具列表,不需要等待管理器初始化
console.log("[SimpleRemember] Listing tools request received", request);
// 打印工具定义以确保它们存在
console.log("[SimpleRemember] REMEMBER_TOOL:", JSON.stringify(REMEMBER_TOOL));
console.log("[SimpleRemember] GET_MEMORIES_TOOL:", JSON.stringify(GET_MEMORIES_TOOL));
const toolsList = [REMEMBER_TOOL, GET_MEMORIES_TOOL];
console.log("[SimpleRemember] Returning tools:", JSON.stringify(toolsList));
// 按照MCP规范返回工具列表
return {
tools: toolsList,
// 如果有分页可以添加nextCursor
// nextCursor: "next-page-cursor"
};
});
this.server.setRequestHandler(CallToolRequestSchema, async (request) => {
const { name, arguments: args } = request.params;
console.log(`[SimpleRemember] Received tool call: ${name}`, args);
try {
const manager = await this._getManager();
if (name === 'remember') {
if (!args || typeof args.memory !== 'string') {
console.error(`[SimpleRemember] Invalid arguments for ${name}:`, args);
throw new McpError(ErrorCode.InvalidParams, `Invalid arguments for ${name}: 'memory' string is required.`);
}
console.log(`[SimpleRemember] Remembering: "${args.memory}"`);
const result = await manager.remember(args.memory);
console.log(`[SimpleRemember] Memory saved successfully:`, result);
// 按照MCP规范返回工具调用结果
return {
content: [{
type: 'text',
text: `记忆已保存: "${args.memory}"`
}],
isError: false
};
}
if (name === 'get_memories') {
console.log(`[SimpleRemember] Getting all memories`);
const memories = await manager.get_memories();
console.log(`[SimpleRemember] Retrieved ${memories.length} memories`);
// 按照MCP规范返回工具调用结果
return {
content: [{
type: 'text',
text: JSON.stringify(memories, null, 2)
}],
isError: false
};
}
console.error(`[SimpleRemember] Unknown tool: ${name}`);
throw new McpError(ErrorCode.MethodNotFound, `Unknown tool: ${name}`);
} catch (error) {
console.error(`[SimpleRemember] Error handling tool call ${name}:`, error);
// 按照MCP规范返回工具调用错误
return {
content: [{
type: 'text',
text: error instanceof Error ? error.message : String(error)
}],
isError: true
};
}
});
}
}
export default SimpleRememberServer;

View File

@ -14,6 +14,8 @@ import { Assistant, FileTypes, MCPToolResponse, Message, Model, Provider, Sugges
import { removeSpecialCharactersForTopicName } from '@renderer/utils'
import { parseAndCallTools } from '@renderer/utils/mcp-tools'
import { buildSystemPrompt } from '@renderer/utils/prompt'
import store from '@renderer/store'
import { getActiveServers } from '@renderer/store/mcp'
import { first, flatten, sum, takeRight } from 'lodash'
import OpenAI from 'openai'
@ -177,7 +179,7 @@ export default class AnthropicProvider extends BaseProvider {
let systemPrompt = assistant.prompt
if (mcpTools && mcpTools.length > 0) {
systemPrompt = buildSystemPrompt(systemPrompt, mcpTools)
systemPrompt = await buildSystemPrompt(systemPrompt, mcpTools, getActiveServers(store.getState()))
}
const body: MessageCreateParamsNonStreaming = {

View File

@ -24,6 +24,8 @@ import { getStoreSetting } from '@renderer/hooks/useSettings'
import i18n from '@renderer/i18n'
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
import { EVENT_NAMES } from '@renderer/services/EventService'
import store from '@renderer/store'
import { getActiveServers } from '@renderer/store/mcp'
import {
filterContextMessages,
filterEmptyMessages,
@ -228,7 +230,7 @@ export default class GeminiProvider extends BaseProvider {
let systemInstruction = assistant.prompt
if (mcpTools && mcpTools.length > 0) {
systemInstruction = buildSystemPrompt(assistant.prompt || '', mcpTools)
systemInstruction = await buildSystemPrompt(assistant.prompt || '', mcpTools, getActiveServers(store.getState()))
}
// const tools = mcpToolsToGeminiTools(mcpTools)

View File

@ -20,6 +20,7 @@ import {
filterUserRoleStartMessages
} from '@renderer/services/MessagesService'
import store from '@renderer/store'
import { getActiveServers } from '@renderer/store/mcp'
import {
Assistant,
FileTypes,
@ -318,7 +319,7 @@ export default class OpenAIProvider extends BaseProvider {
}
}
if (mcpTools && mcpTools.length > 0) {
systemMessage.content = buildSystemPrompt(systemMessage.content || '', mcpTools)
systemMessage.content = await buildSystemPrompt(systemMessage.content || '', mcpTools, getActiveServers(store.getState()))
}
const userMessages: ChatCompletionMessageParam[] = []

View File

@ -97,6 +97,13 @@ export const builtinMCPServers: MCPServer[] = [
type: 'inMemory',
description: '实现文件系统操作的模型上下文协议MCP的 Node.js 服务器',
isActive: false
},
{
id: nanoid(),
name: '@cherry/simpleremember',
type: 'inMemory',
description: '自动记忆工具,功能跟上面的记忆工具差不多。这个记忆会自动应用到对话中,无需显式调用。适合记住用户偏好、项目背景等长期有用信息.可以跨对话。',
isActive: true
}
]

View File

@ -147,12 +147,28 @@ ${availableTools}
</tools>`
}
export const buildSystemPrompt = (userSystemPrompt: string, tools: MCPTool[]): string => {
import { MCPServer } from '@renderer/types'
import { getRememberedMemories } from './remember-utils'
export const buildSystemPrompt = async (userSystemPrompt: string, tools: MCPTool[], mcpServers: MCPServer[] = []): Promise<string> => {
// 获取记忆
let memoriesPrompt = '';
try {
memoriesPrompt = await getRememberedMemories(mcpServers);
} catch (error) {
console.error('Error getting memories:', error);
}
// 添加记忆工具的使用说明
const rememberInstructions = '\n\n您可以使用remember工具记住用户的长期偏好和重要信息。当用户说"请记住..."或"记住..."时使用remember工具存储这些信息。记忆会自动应用到所有对话中无需显式调用。';
const enhancedPrompt = userSystemPrompt + rememberInstructions + memoriesPrompt;
if (tools && tools.length > 0) {
return SYSTEM_PROMPT.replace('{{ USER_SYSTEM_PROMPT }}', userSystemPrompt)
return SYSTEM_PROMPT.replace('{{ USER_SYSTEM_PROMPT }}', enhancedPrompt)
.replace('{{ TOOL_USE_EXAMPLES }}', ToolUseExamples)
.replace('{{ AVAILABLE_TOOLS }}', AvailableTools(tools))
}
return userSystemPrompt
return enhancedPrompt
}

View File

@ -0,0 +1,67 @@
// src/renderer/src/utils/remember-utils.ts
import { MCPServer } from '@renderer/types'
export async function getRememberedMemories(mcpServers: MCPServer[]): Promise<string> {
try {
// 查找simpleremember服务器
const rememberServer = mcpServers.find(server => server.name === '@cherry/simpleremember' && server.isActive);
if (!rememberServer) {
console.log('[SimpleRemember] Server not found or not active');
return '';
}
console.log('[SimpleRemember] Found server:', rememberServer.name, 'isActive:', rememberServer.isActive);
// 调用get_memories工具
try {
console.log('[SimpleRemember] Calling get_memories tool...');
const response = await window.api.mcp.callTool({
server: rememberServer,
name: 'get_memories',
args: {}
});
console.log('[SimpleRemember] get_memories response:', response);
if (response.isError) {
console.error('[SimpleRemember] Error getting memories:', response);
return '';
}
// 解析记忆
// 根据MCP规范工具返回的是content数组而不是data
let memories = [];
if (response.content && response.content.length > 0 && response.content[0].text) {
try {
memories = JSON.parse(response.content[0].text);
} catch (parseError) {
console.error('[SimpleRemember] Failed to parse memories JSON:', parseError);
return '';
}
} else if (response.data) {
// 兼容旧版本的返回格式
memories = response.data;
}
console.log('[SimpleRemember] Parsed memories:', memories);
if (!Array.isArray(memories) || memories.length === 0) {
console.log('[SimpleRemember] No memories found or invalid format');
return '';
}
// 构建记忆提示词
const memoryPrompt = memories.map(memory => `- ${memory.content}`).join('\n');
console.log('[SimpleRemember] Generated memory prompt:', memoryPrompt);
return `\n\n用户的记忆:\n${memoryPrompt}`;
} catch (toolError) {
console.error('[SimpleRemember] Error calling get_memories tool:', toolError);
return '';
}
} catch (error) {
console.error('[SimpleRemember] Error in getRememberedMemories:', error);
return '';
}
}