import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
import {
  getOpenAIWebSearchParams,
  isGrokReasoningModel,
  isHunyuanSearchModel,
  isOpenAIoSeries,
  isOpenAIWebSearch,
  isReasoningModel,
  isSupportedModel,
  isVisionModel,
  isZhipuModel
} from '@renderer/config/models'
import { getStoreSetting } from '@renderer/hooks/useSettings'
import i18n from '@renderer/i18n'
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
import { EVENT_NAMES } from '@renderer/services/EventService'
import {
  filterContextMessages,
  filterEmptyMessages,
  filterUserRoleStartMessages
} from '@renderer/services/MessagesService'
import store from '@renderer/store'
import { Assistant, FileTypes, GenerateImageParams, MCPToolResponse, Model, Provider, Suggestion } from '@renderer/types'
import { Message } from '@renderer/types/newMessageTypes'
import { removeSpecialCharactersForTopicName } from '@renderer/utils'
import { addImageFileToContents } from '@renderer/utils/formats'
import { mcpToolCallResponseToOpenAIMessage, parseAndCallTools } from '@renderer/utils/mcp-tools'
import { findFileBlocks, findImageBlocks, getMessageContent } from '@renderer/utils/messageUtils/find'
import { buildSystemPrompt } from '@renderer/utils/prompt'
import { takeRight } from 'lodash'
import OpenAI, { AzureOpenAI } from 'openai'
import {
  ChatCompletionContentPart,
  ChatCompletionCreateParamsNonStreaming,
  ChatCompletionMessageParam
} from 'openai/resources'

import { CompletionsParams } from '.'
import BaseProvider from './BaseProvider'

type ReasoningEffort = 'high' | 'medium' | 'low'

export default class OpenAIProvider extends BaseProvider {
  private sdk: OpenAI

  constructor(provider: Provider) {
    super(provider)

    if (provider.id === 'azure-openai' || provider.type === 'azure-openai') {
      this.sdk = new AzureOpenAI({
        dangerouslyAllowBrowser: true,
        apiKey: this.apiKey,
        apiVersion: provider.apiVersion,
        endpoint: provider.apiHost
      })
      return
    }

    this.sdk = new OpenAI({
      dangerouslyAllowBrowser: true,
      apiKey: this.apiKey,
      baseURL: this.getBaseURL(),
      defaultHeaders: {
        ...this.defaultHeaders(),
        ...(this.provider.id === 'copilot' ? { 'editor-version': 'vscode/1.97.2' } : {})
      }
    })
  }
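  // Construction sketch (illustrative values only; real Provider records come
  // from the renderer store, and BaseProvider resolves apiKey and the base URL):
  //   const provider = new OpenAIProvider({ id: 'openai', type: 'openai', apiHost: 'https://api.openai.com' } as Provider)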
  /**
   * Check if the provider does not support files
   * @returns True if the provider does not support files, false otherwise
   */
  private get isNotSupportFiles() {
    if (this.provider?.isNotSupportArrayContent) {
      return true
    }

    const providers = ['deepseek', 'baichuan', 'minimax', 'xirang']

    return providers.includes(this.provider.id)
  }

  /**
   * Extract the file content from the message
   * @param message - The message
   * @returns The file content
   */
  private async extractFileContent(message: Message) {
    const fileBlocks = findFileBlocks(message)

    if (fileBlocks.length > 0) {
      const textFileBlocks = fileBlocks.filter(
        (fb) => fb.file && [FileTypes.TEXT, FileTypes.DOCUMENT].includes(fb.file.type)
      )

      if (textFileBlocks.length > 0) {
        let text = ''
        const divider = '\n\n---\n\n'

        for (const fileBlock of textFileBlocks) {
          const file = fileBlock.file
          const fileContent = (await window.api.file.read(file.id + file.ext)).trim()
          const fileNameRow = 'file: ' + file.origin_name + '\n\n'
          text = text + fileNameRow + fileContent + divider
        }

        return text
      }
    }

    return ''
  }

  /**
   * Get the message parameter
   * @param message - The message
   * @param model - The model
   * @returns The message parameter
   */
  private async getMessageParam(message: Message, model: Model): Promise<ChatCompletionMessageParam> {
    const isVision = isVisionModel(model)
    const content = await this.getMessageContent(message)
    const fileBlocks = findFileBlocks(message)
    const imageBlocks = findImageBlocks(message)

    if (fileBlocks.length === 0 && imageBlocks.length === 0) {
      return {
        role: message.role === 'system' ? 'user' : message.role,
        content
      }
    }

    // If the model does not support files, extract the file content
    if (this.isNotSupportFiles) {
      const fileContent = await this.extractFileContent(message)

      return {
        role: message.role === 'system' ? 'user' : message.role,
        content: content + '\n\n---\n\n' + fileContent
      }
    }

    // If the model supports files, add the file content to the message
    const parts: ChatCompletionContentPart[] = []

    if (content) {
      parts.push({ type: 'text', text: content })
    }

    for (const imageBlock of imageBlocks) {
      if (isVision) {
        if (imageBlock.file) {
          const image = await window.api.file.base64Image(imageBlock.file.id + imageBlock.file.ext)
          parts.push({ type: 'image_url', image_url: { url: image.data } })
        } else if (imageBlock.url && imageBlock.url.startsWith('data:')) {
          parts.push({ type: 'image_url', image_url: { url: imageBlock.url } })
        }
      }
    }

    for (const fileBlock of fileBlocks) {
      const file = fileBlock.file
      if (!file) continue

      if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) {
        const fileContent = (await window.api.file.read(file.id + file.ext)).trim()
        parts.push({ type: 'text', text: file.origin_name + '\n' + fileContent })
      }
    }

    return {
      role: message.role === 'system' ? 'user' : message.role,
      content: parts
    } as ChatCompletionMessageParam
  }

  /**
   * Get the temperature for the assistant
   * @param assistant - The assistant
   * @param model - The model
   * @returns The temperature
   */
  private getTemperature(assistant: Assistant, model: Model) {
    return isReasoningModel(model) || isOpenAIWebSearch(model) ? undefined : assistant?.settings?.temperature
  }
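  // Shape sketch for getMessageParam above: for a vision model with one image
  // block, the returned param carries a content-part array, e.g.
  //   {
  //     role: 'user',
  //     content: [
  //       { type: 'text', text: 'Describe this image' },
  //       { type: 'image_url', image_url: { url: 'data:image/png;base64,...' } }
  //     ]
  //   }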
  /**
   * Get the provider specific parameters for the assistant
   * @param assistant - The assistant
   * @param model - The model
   * @returns The provider specific parameters
   */
  private getProviderSpecificParameters(assistant: Assistant, model: Model) {
    const { maxTokens } = getAssistantSettings(assistant)

    if (this.provider.id === 'openrouter') {
      if (model.id.includes('deepseek-r1')) {
        return { include_reasoning: true }
      }
    }

    if (this.isOpenAIReasoning(model)) {
      return {
        max_tokens: undefined,
        max_completion_tokens: maxTokens
      }
    }

    return {}
  }

  /**
   * Get the top P for the assistant
   * @param assistant - The assistant
   * @param model - The model
   * @returns The top P
   */
  private getTopP(assistant: Assistant, model: Model) {
    if (isReasoningModel(model) || isOpenAIWebSearch(model)) return undefined

    return assistant?.settings?.topP
  }

  /**
   * Get the reasoning effort for the assistant
   * @param assistant - The assistant
   * @param model - The model
   * @returns The reasoning effort
   */
  private getReasoningEffort(assistant: Assistant, model: Model) {
    if (this.provider.id === 'groq') {
      return {}
    }

    if (isReasoningModel(model)) {
      if (model.provider === 'openrouter') {
        return {
          reasoning: {
            effort: assistant?.settings?.reasoning_effort
          }
        }
      }

      if (isGrokReasoningModel(model)) {
        return {
          reasoning_effort: assistant?.settings?.reasoning_effort
        }
      }

      if (isOpenAIoSeries(model)) {
        return {
          reasoning_effort: assistant?.settings?.reasoning_effort
        }
      }

      if (model.id.includes('claude-3.7-sonnet') || model.id.includes('claude-3-7-sonnet')) {
        const effortRatios: Record<ReasoningEffort, number> = {
          high: 0.8,
          medium: 0.5,
          low: 0.2
        }

        const effort = assistant?.settings?.reasoning_effort as ReasoningEffort
        const effortRatio = effortRatios[effort]

        if (!effortRatio) {
          return {}
        }

        const maxTokens = assistant?.settings?.maxTokens || DEFAULT_MAX_TOKENS
        const budgetTokens = Math.trunc(Math.max(Math.min(maxTokens * effortRatio, 32000), 1024))

        return {
          thinking: {
            type: 'enabled',
            budget_tokens: budgetTokens
          }
        }
      }

      return {}
    }

    return {}
  }

  /**
   * Check if the model is an OpenAI reasoning model
   * @param model - The model
   * @returns True if the model is an OpenAI reasoning model, false otherwise
   */
  private isOpenAIReasoning(model: Model) {
    return model.id.startsWith('o1') || model.id.startsWith('o3')
  }
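  // Worked example for the Claude 3.7 thinking budget in getReasoningEffort:
  // with maxTokens = 8192 and reasoning_effort = 'medium',
  //   budget_tokens = trunc(max(min(8192 * 0.5, 32000), 1024)) = 4096
  // so the budget is always clamped between 1024 and 32000 tokens.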
  /**
   * Generate completions for the assistant
   * @param messages - The messages
   * @param assistant - The assistant
   * @param mcpTools - The MCP tools
   * @param onChunk - The onChunk callback
   * @param onFilterMessages - The onFilterMessages callback
   * @returns The completions
   */
  async completions({ messages, assistant, mcpTools, onChunk, onFilterMessages }: CompletionsParams): Promise<void> {
    const defaultModel = getDefaultModel()
    const model = assistant.model || defaultModel
    const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant)

    messages = addImageFileToContents(messages)

    let systemMessage = { role: 'system', content: assistant.prompt || '' }
    if (isOpenAIoSeries(model)) {
      systemMessage = {
        role: 'developer',
        content: `Formatting re-enabled${systemMessage ? '\n' + systemMessage.content : ''}`
      }
    }

    if (mcpTools && mcpTools.length > 0) {
      systemMessage.content = buildSystemPrompt(systemMessage.content || '', mcpTools)
    }

    const userMessages: ChatCompletionMessageParam[] = []
    const _messages = filterUserRoleStartMessages(
      filterEmptyMessages(filterContextMessages(takeRight(messages, contextCount + 1)))
    )

    onFilterMessages(_messages)

    for (const message of _messages) {
      userMessages.push(await this.getMessageParam(message, model))
    }

    const isOpenAIReasoning = this.isOpenAIReasoning(model)

    const isSupportStreamOutput = () => {
      if (isOpenAIReasoning) {
        return false
      }
      return streamOutput
    }

    let hasReasoningContent = false
    let lastChunk = ''
    const isReasoningJustDone = (
      delta: OpenAI.Chat.Completions.ChatCompletionChunk.Choice.Delta & {
        reasoning_content?: string
        reasoning?: string
        thinking?: string
      }
    ) => {
      if (!delta?.content) return false

      // Check whether the current chunk combined with the previous one forms the ###Response marker
      const combinedChunks = lastChunk + delta.content
      lastChunk = delta.content

      // Detect the end of the thinking phase
      if (combinedChunks.includes('###Response') || delta.content === '') {
        return true
      }

      // reasoning_content / reasoning / thinking means the model is still thinking
      if (delta?.reasoning_content || delta?.reasoning || delta?.thinking) {
        hasReasoningContent = true
      }

      // Reasoning content seen earlier plus plain content now means thinking just finished
      if (hasReasoningContent && delta.content) {
        return true
      }

      return false
    }

    let time_first_token_millsec = 0
    let time_first_content_millsec = 0
    const start_time_millsec = new Date().getTime()
    const lastUserMessage = _messages.findLast((m) => m.role === 'user')
    const { abortController, cleanup, signalPromise } = this.createAbortController(lastUserMessage?.id, true)
    const { signal } = abortController
    await this.checkIsCopilot()

    const reqMessages: ChatCompletionMessageParam[] = [systemMessage, ...userMessages].filter(
      Boolean
    ) as ChatCompletionMessageParam[]

    const toolResponses: MCPToolResponse[] = []
    let firstChunk = true

    const processToolUses = async (content: string, idx: number) => {
      const toolResults = await parseAndCallTools(
        content,
        toolResponses,
        onChunk,
        idx,
        mcpToolCallResponseToOpenAIMessage,
        mcpTools,
        isVisionModel(model)
      )

      if (toolResults.length > 0) {
        reqMessages.push({
          role: 'assistant',
          content: content
        } as ChatCompletionMessageParam)
        toolResults.forEach((ts) => reqMessages.push(ts as ChatCompletionMessageParam))

        const newStream = await this.sdk.chat.completions
          // @ts-ignore key is not typed
          .create(
            {
              model: model.id,
              messages: reqMessages,
              temperature: this.getTemperature(assistant, model),
              top_p: this.getTopP(assistant, model),
              max_tokens: maxTokens,
              keep_alive: this.keepAliveTime,
              stream: isSupportStreamOutput(),
              // tools: tools,
              ...getOpenAIWebSearchParams(assistant, model),
              ...this.getReasoningEffort(assistant, model),
              ...this.getProviderSpecificParameters(assistant, model),
              ...this.getCustomParameters(assistant)
            },
            { signal }
          )
        await processStream(newStream, idx + 1)
      }
    }
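    // Note: processToolUses and processStream recurse into each other. When the
    // streamed text contains MCP tool calls, the tool results are appended to
    // reqMessages and a follow-up completion is issued with idx + 1, so a single
    // user turn can fan out into several round trips.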
    const processStream = async (stream: any, idx: number) => {
      if (!isSupportStreamOutput()) {
        const time_completion_millsec = new Date().getTime() - start_time_millsec
        return onChunk({
          text: stream.choices[0].message?.content || '',
          usage: stream.usage,
          metrics: {
            completion_tokens: stream.usage?.completion_tokens,
            time_completion_millsec,
            time_first_token_millsec: 0
          }
        })
      }

      let content = ''
      for await (const chunk of stream) {
        if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) {
          break
        }

        const delta = chunk.choices[0]?.delta

        if (delta?.content) {
          content += delta.content
        }

        if (delta?.reasoning_content || delta?.reasoning) {
          hasReasoningContent = true
        }

        if (time_first_token_millsec == 0) {
          time_first_token_millsec = new Date().getTime() - start_time_millsec
        }

        if (time_first_content_millsec == 0 && isReasoningJustDone(delta)) {
          time_first_content_millsec = new Date().getTime()
        }

        const time_completion_millsec = new Date().getTime() - start_time_millsec
        const time_thinking_millsec = time_first_content_millsec ? time_first_content_millsec - start_time_millsec : 0

        // Extract citations from the raw response if available
        const citations = (chunk as OpenAI.Chat.Completions.ChatCompletionChunk & { citations?: string[] })?.citations

        const finishReason = chunk.choices[0]?.finish_reason

        let webSearch: any[] | undefined = undefined

        if (assistant.enableWebSearch && isZhipuModel(model) && finishReason === 'stop') {
          webSearch = chunk?.web_search
        }

        if (firstChunk && assistant.enableWebSearch && isHunyuanSearchModel(model)) {
          webSearch = chunk?.search_info?.search_results
          firstChunk = false
        }

        onChunk({
          text: delta?.content || '',
          reasoning_content: delta?.reasoning_content || delta?.reasoning || '',
          usage: chunk.usage,
          metrics: {
            completion_tokens: chunk.usage?.completion_tokens,
            time_completion_millsec,
            time_first_token_millsec,
            time_thinking_millsec
          },
          webSearch,
          annotations: delta?.annotations,
          citations,
          mcpToolResponse: toolResponses
        })
      }

      await processToolUses(content, idx)
    }

    const stream = await this.sdk.chat.completions
      // @ts-ignore key is not typed
      .create(
        {
          model: model.id,
          messages: reqMessages,
          temperature: this.getTemperature(assistant, model),
          top_p: this.getTopP(assistant, model),
          max_tokens: maxTokens,
          keep_alive: this.keepAliveTime,
          stream: isSupportStreamOutput(),
          // tools: tools,
          ...getOpenAIWebSearchParams(assistant, model),
          ...this.getReasoningEffort(assistant, model),
          ...this.getProviderSpecificParameters(assistant, model),
          ...this.getCustomParameters(assistant)
        },
        { signal }
      )

    await processStream(stream, 0).finally(cleanup)

    // Re-throw any error captured by the abort signal
    await signalPromise?.promise?.catch((error) => {
      throw error
    })
  }
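  // Call sketch (illustrative; real CompletionsParams values come from the chat UI):
  //   await provider.completions({
  //     messages,
  //     assistant,
  //     mcpTools: [],
  //     onChunk: (chunk) => console.log(chunk.text),
  //     onFilterMessages: () => {}
  //   })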
  /**
   * Translate a message
   * @param message - The message
   * @param assistant - The assistant
   * @param onResponse - The onResponse callback
   * @returns The translated message
   */
  async translate(message: Message, assistant: Assistant, onResponse?: (text: string) => void) {
    const defaultModel = getDefaultModel()
    const model = assistant.model || defaultModel
    const content = await this.getMessageContent(message)

    const messagesForApi = content
      ? [
          { role: 'system', content: assistant.prompt },
          { role: 'user', content }
        ]
      : [{ role: 'user', content: assistant.prompt }]

    const isOpenAIReasoning = this.isOpenAIReasoning(model)

    const isSupportedStreamOutput = () => {
      if (!onResponse) {
        return false
      }
      if (isOpenAIReasoning) {
        return false
      }
      return true
    }

    const stream = isSupportedStreamOutput()

    await this.checkIsCopilot()

    // @ts-ignore key is not typed
    const response = await this.sdk.chat.completions.create({
      model: model.id,
      messages: messagesForApi as ChatCompletionMessageParam[],
      stream,
      keep_alive: this.keepAliveTime,
      temperature: assistant?.settings?.temperature
    })

    if (!stream) {
      return response.choices[0].message?.content || ''
    }

    let text = ''
    let isThinking = false
    const isReasoning = isReasoningModel(model)

    for await (const chunk of response) {
      const deltaContent = chunk.choices[0]?.delta?.content || ''

      if (isReasoning) {
        if (deltaContent.includes('<think>')) {
          isThinking = true
        }

        if (!isThinking) {
          text += deltaContent
          onResponse?.(text)
        }

        if (deltaContent.includes('</think>')) {
          isThinking = false
        }
      } else {
        text += deltaContent
        onResponse?.(text)
      }
    }

    return text
  }

  /**
   * Summarize a message
   * @param messages - The messages
   * @param assistant - The assistant
   * @returns The summary
   */
  public async summaries(messages: Message[], assistant: Assistant): Promise<string> {
    const model = getTopNamingModel() || assistant.model || getDefaultModel()

    const userMessages = takeRight(messages, 5)
      .filter((message) => !message.isPreset)
      .map((message) => ({
        role: message.role,
        content: getMessageContent(message)
      }))

    const userMessageContent = userMessages.reduce((prev, curr) => {
      const content = curr.role === 'user' ? `User: ${curr.content}` : `Assistant: ${curr.content}`
      return prev + (prev ? '\n' : '') + content
    }, '')

    const systemMessage = {
      role: 'system',
      content: getStoreSetting('topicNamingPrompt') || i18n.t('prompts.title')
    }

    const userMessage = {
      role: 'user',
      content: userMessageContent
    }

    await this.checkIsCopilot()

    // @ts-ignore key is not typed
    const response = await this.sdk.chat.completions.create({
      model: model.id,
      messages: [systemMessage, userMessage] as ChatCompletionMessageParam[],
      stream: false,
      keep_alive: this.keepAliveTime,
      max_tokens: 1000
    })

    // For reasoning-model responses, keep only the content after the </think> tag in the summary
    let content = response.choices[0].message?.content || ''
    content = content.replace(/^<think>(.*?)<\/think>/s, '')

    return removeSpecialCharactersForTopicName(content.substring(0, 50))
  }

  /**
   * Summarize a message for search
   * @param messages - The messages
   * @param assistant - The assistant
   * @returns The summary
   */
  public async summaryForSearch(messages: Message[], assistant: Assistant): Promise<string> {
    const model = assistant.model || getDefaultModel()

    const systemMessage = { role: 'system', content: assistant.prompt }

    const messageContents = messages.map((m) => getMessageContent(m))
    const userMessageContent = messageContents.join('\n')

    const userMessage = {
      role: 'user',
      content: userMessageContent
    }

    // @ts-ignore key is not typed
    const response = await this.sdk.chat.completions.create(
      {
        model: model.id,
        messages: [systemMessage, userMessage] as ChatCompletionMessageParam[],
        stream: false,
        keep_alive: this.keepAliveTime,
        max_tokens: 1000
      },
      { timeout: 20 * 1000 }
    )

    // For reasoning-model responses, keep only the content after the </think> tag in the summary
    let content = response.choices[0].message?.content || ''
    content = content.replace(/^<think>(.*?)<\/think>/s, '')

    return content
  }
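  // Example of the <think> stripping used in summaries/summaryForSearch: a
  // reasoning model may return
  //   '<think>the user asked about cats</think>Cats'
  // and the regex above reduces it to 'Cats' before further trimming.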
  /**
   * Generate text
   * @param prompt - The prompt
   * @param content - The content
   * @returns The generated text
   */
  public async generateText({ prompt, content }: { prompt: string; content: string }): Promise<string> {
    const model = getDefaultModel()

    await this.checkIsCopilot()

    const response = await this.sdk.chat.completions.create({
      model: model.id,
      stream: false,
      messages: [
        { role: 'system', content: prompt },
        { role: 'user', content }
      ]
    })

    return response.choices[0].message?.content || ''
  }

  /**
   * Generate suggestions
   * @param messages - The messages
   * @param assistant - The assistant
   * @returns The suggestions
   */
  async suggestions(messages: Message[], assistant: Assistant): Promise<Suggestion[]> {
    const model = assistant.model

    if (!model) {
      return []
    }

    await this.checkIsCopilot()

    const userMessagesForApi = messages
      .filter((m) => m.role === 'user')
      .map((m) => ({
        role: m.role,
        content: getMessageContent(m)
      }))

    const response: any = await this.sdk.request({
      method: 'post',
      path: '/advice_questions',
      body: {
        messages: userMessagesForApi,
        model: model.id,
        max_tokens: 0,
        temperature: 0,
        n: 0
      }
    })

    return response?.questions?.filter(Boolean)?.map((q: any) => ({ content: q })) || []
  }

  /**
   * Check if the model is valid
   * @param model - The model
   * @returns The validity of the model
   */
  public async check(model: Model): Promise<{ valid: boolean; error: Error | null }> {
    if (!model) {
      return { valid: false, error: new Error('No model found') }
    }

    const body = {
      model: model.id,
      messages: [{ role: 'user', content: 'hi' }],
      stream: false
    }

    try {
      await this.checkIsCopilot()
      const response = await this.sdk.chat.completions.create(body as ChatCompletionCreateParamsNonStreaming)
      return {
        valid: Boolean(response?.choices[0].message),
        error: null
      }
    } catch (error: any) {
      return {
        valid: false,
        error
      }
    }
  }

  /**
   * Get the models
   * @returns The models
   */
  public async models(): Promise<OpenAI.Models.Model[]> {
    try {
      await this.checkIsCopilot()
      const response = await this.sdk.models.list()

      if (this.provider.id === 'github') {
        // @ts-ignore key is not typed
        return response.body
          .map((model) => ({
            id: model.name,
            description: model.summary,
            object: 'model',
            owned_by: model.publisher
          }))
          .filter(isSupportedModel)
      }

      if (this.provider.id === 'together') {
        // @ts-ignore key is not typed
        return response?.body
          .map((model: any) => ({
            id: model.id,
            description: model.display_name,
            object: 'model',
            owned_by: model.organization
          }))
          .filter(isSupportedModel)
      }

      const models = response?.data || []

      return models.filter(isSupportedModel)
    } catch (error) {
      return []
    }
  }

  /**
   * Generate an image
   * @param params - The parameters
   * @returns The generated image
   */
  public async generateImage({
    model,
    prompt,
    negativePrompt,
    imageSize,
    batchSize,
    seed,
    numInferenceSteps,
    guidanceScale,
    signal,
    promptEnhancement
  }: GenerateImageParams): Promise<string[]> {
    const response = (await this.sdk.request({
      method: 'post',
      path: '/images/generations',
      signal,
      body: {
        model,
        prompt,
        negative_prompt: negativePrompt,
        image_size: imageSize,
        batch_size: batchSize,
        seed: seed ? parseInt(seed) : undefined,
        num_inference_steps: numInferenceSteps,
        guidance_scale: guidanceScale,
        prompt_enhancement: promptEnhancement
      }
    })) as { data: Array<{ url: string }> }

    return response.data.map((item) => item.url)
  }
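  // Note on generateImage above: negative_prompt, image_size, batch_size,
  // num_inference_steps and guidance_scale are not standard OpenAI
  // /images/generations parameters; they are forwarded as-is for
  // OpenAI-compatible backends that accept these extensions.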
  /**
   * Get the embedding dimensions
   * @param model - The model
   * @returns The embedding dimensions
   */
  public async getEmbeddingDimensions(model: Model): Promise<number> {
    await this.checkIsCopilot()

    const data = await this.sdk.embeddings.create({
      model: model.id,
      input: model?.provider === 'baidu-cloud' ? ['hi'] : 'hi'
    })

    return data.data[0].embedding.length
  }

  public async checkIsCopilot() {
    if (this.provider.id !== 'copilot') return
    const defaultHeaders = store.getState().copilot.defaultHeaders
    // Copilot tokens embed a timestamp, so a fresh token must be fetched before every request
    const { token } = await window.api.copilot.getToken(defaultHeaders)
    this.sdk.apiKey = token
  }
}
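// Validation sketch (illustrative; real Model records carry more fields):
//   const { valid, error } = await provider.check({ id: 'gpt-4o-mini' } as Model)
//   const dims = await provider.getEmbeddingDimensions({ id: 'text-embedding-3-small' } as Model)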