mirror of https://github.com/CherryHQ/cherry-studio.git (synced 2026-01-06 13:19:33 +08:00)
refactor: rework the large-file upload logic to target the OpenAI-standard file service, and add native upload support for the qwen-long and qwen-doc model series (#9997)
* refactor: rework the large-file upload logic for the OpenAI-standard file service; add native upload support for qwen-long and qwen-doc
* chore: tidy up the large-file upload logic and types; release uploaded file resources more reliably
* fix: fix an error thrown during native upload when the user enters no text content
* chore: improve function naming
This commit is contained in:
parent 4a72d40394
commit 79592e2c27
@@ -3,6 +3,7 @@ import { Provider } from '@types'
 import { BaseFileService } from './BaseFileService'
 import { GeminiService } from './GeminiService'
 import { MistralService } from './MistralService'
+import { OpenaiService } from './OpenAIService'

 export class FileServiceManager {
   private static instance: FileServiceManager
@@ -30,6 +31,9 @@ export class FileServiceManager {
       case 'mistral':
         service = new MistralService(provider)
         break
+      case 'openai':
+        service = new OpenaiService(provider)
+        break
       default:
         throw new Error(`Unsupported service type: ${type}`)
     }
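With the new case, a Provider whose type is 'openai' now resolves to the OpenaiService defined below instead of throwing `Unsupported service type`. A minimal sketch of what that branch constructs; the provider literal is illustrative and the import path is inferred from the new file's location, neither is confirmed by this diff:

```ts
import { OpenaiService } from '@main/services/remotefile/OpenAIService' // path inferred from the diff
import { Provider } from '@types'

// Any OpenAI-compatible endpoint works, since apiHost becomes the client baseURL.
// DashScope's compatible-mode endpoint is the relevant one for qwen-long/qwen-doc:
const provider = {
  type: 'openai',
  apiKey: process.env.DASHSCOPE_API_KEY ?? '',
  apiHost: 'https://dashscope.aliyuncs.com/compatible-mode/v1'
} as Provider

const service = new OpenaiService(provider) // what the 'openai' switch branch constructs
```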
src/main/services/remotefile/OpenAIService.ts (new file, +125 lines)
@@ -0,0 +1,125 @@
import { loggerService } from '@logger'
import { fileStorage } from '@main/services/FileStorage'
import { FileListResponse, FileMetadata, FileUploadResponse, Provider } from '@types'
import * as fs from 'fs'
import OpenAI from 'openai'

import { CacheService } from '../CacheService'
import { BaseFileService } from './BaseFileService'

const logger = loggerService.withContext('OpenAIService')

export class OpenaiService extends BaseFileService {
  private static readonly FILE_CACHE_DURATION = 7 * 24 * 60 * 60 * 1000 // 7 days in ms
  private static readonly generateUIFileIdCacheKey = (fileId: string) => `ui_file_id_${fileId}`
  private readonly client: OpenAI

  constructor(provider: Provider) {
    super(provider)
    this.client = new OpenAI({
      apiKey: provider.apiKey,
      baseURL: provider.apiHost
    })
  }

  async uploadFile(file: FileMetadata): Promise<FileUploadResponse> {
    let fileReadStream: fs.ReadStream | undefined
    try {
      fileReadStream = fs.createReadStream(fileStorage.getFilePathById(file))
      // Restore the original file name so the model understands the file better
      const fileStreamWithMeta = Object.assign(fileReadStream, {
        name: file.origin_name
      })
      const response = await this.client.files.create({
        file: fileStreamWithMeta,
        purpose: file.purpose || 'assistants'
      })
      if (!response.id) {
        throw new Error('File id not found in response')
      }
      // Cache the remote file id under the UI file id
      CacheService.set<string>(
        OpenaiService.generateUIFileIdCacheKey(file.id),
        response.id,
        OpenaiService.FILE_CACHE_DURATION
      )
      return {
        fileId: response.id,
        displayName: file.origin_name,
        status: 'success',
        originalFile: {
          type: 'openai',
          file: response
        }
      }
    } catch (error) {
      logger.error('Error uploading file:', error as Error)
      return {
        fileId: '',
        displayName: file.origin_name,
        status: 'failed'
      }
    } finally {
      // Destroy the file stream
      if (fileReadStream) fileReadStream.destroy()
    }
  }

  async listFiles(): Promise<FileListResponse> {
    try {
      const response = await this.client.files.list()
      return {
        files: response.data.map((file) => ({
          id: file.id,
          displayName: file.filename || '',
          size: file.bytes,
          status: 'success', // All listed files are processed
          originalFile: {
            type: 'openai',
            file
          }
        }))
      }
    } catch (error) {
      logger.error('Error listing files:', error as Error)
      return { files: [] }
    }
  }

  async deleteFile(fileId: string): Promise<void> {
    try {
      const cachedRemoteFileId = CacheService.get<string>(OpenaiService.generateUIFileIdCacheKey(fileId))
      await this.client.files.delete(cachedRemoteFileId || fileId)
      logger.debug(`File ${fileId} deleted`)
    } catch (error) {
      logger.error('Error deleting file:', error as Error)
      throw error
    }
  }

  async retrieveFile(fileId: string): Promise<FileUploadResponse> {
    try {
      // Try to map the UI file id back to the remote file id
      const cachedRemoteFileId = CacheService.get<string>(OpenaiService.generateUIFileIdCacheKey(fileId))
      const response = await this.client.files.retrieve(cachedRemoteFileId || fileId)

      return {
        fileId: response.id,
        displayName: response.filename,
        status: 'success',
        originalFile: {
          type: 'openai',
          file: response
        }
      }
    } catch (error) {
      logger.error('Error retrieving file:', error as Error)
      return {
        fileId: fileId,
        displayName: '',
        status: 'failed',
        originalFile: undefined
      }
    }
  }
}
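The class keeps a one-way cache from Cherry Studio's UI file id to the provider's remote file id (with a 7-day TTL), so the later deleteFile/retrieveFile calls accept either id. A self-contained sketch of that resolution logic, with a plain Map standing in for CacheService:

```ts
// Illustration only: a Map stands in for CacheService; the real code adds a 7-day TTL.
const uiToRemote = new Map<string, string>()
const cacheKey = (uiFileId: string) => `ui_file_id_${uiFileId}`

// uploadFile stores: cacheKey(file.id) -> response.id
function rememberUpload(uiFileId: string, remoteFileId: string): void {
  uiToRemote.set(cacheKey(uiFileId), remoteFileId)
}

// deleteFile/retrieveFile resolve: use the mapping if present, otherwise assume the
// caller already passed a remote id (mirrors `cachedRemoteFileId || fileId` above).
function resolveRemoteId(fileId: string): string {
  return uiToRemote.get(cacheKey(fileId)) ?? fileId
}

rememberUpload('ui-123', 'file-abc')
console.log(resolveRemoteId('ui-123')) // 'file-abc'
console.log(resolveRemoteId('file-xyz')) // 'file-xyz' (no mapping, passed through)
```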
@@ -10,6 +10,7 @@ import { FileTypes } from '@renderer/types'
 import { FileMessageBlock } from '@renderer/types/newMessage'
 import { findFileBlocks } from '@renderer/utils/messageUtils/find'
 import type { FilePart, TextPart } from 'ai'
+import type OpenAI from 'openai'

 import { getAiSdkProviderId } from '../provider/factory'
 import { getFileSizeLimit, supportsImageInput, supportsLargeFileUpload, supportsPdfInput } from './modelCapabilities'
@@ -112,6 +113,86 @@ export async function handleGeminiFileUpload(file: FileMetadata, model: Model):
   return null
 }

+/**
+ * Handle large-file upload for OpenAI-compatible providers
+ */
+export async function handleOpenAILargeFileUpload(
+  file: FileMetadata,
+  model: Model
+): Promise<(FilePart & { id?: string }) | null> {
+  const provider = getProviderByModel(model)
+  // For the qwen-long/qwen-doc series, the docs require purpose to be 'file-extract'
+  if (['qwen-long', 'qwen-doc'].some((modelName) => model.name.includes(modelName))) {
+    file = {
+      ...file,
+      // This purpose is not part of the OpenAI definitions but matches the SDK contract, so force the assertion
+      purpose: 'file-extract' as OpenAI.FilePurpose
+    }
+  }
+  try {
+    // Check whether the file has already been uploaded
+    const fileMetadata = await window.api.fileService.retrieve(provider, file.id)
+    if (fileMetadata.status === 'success' && fileMetadata.originalFile?.file) {
+      // Assert the OpenAI file object
+      const remoteFile = fileMetadata.originalFile.file as OpenAI.Files.FileObject
+      // Make sure the purposes match
+      if (remoteFile.purpose !== file.purpose) {
+        logger.warn(`File ${file.origin_name} purpose mismatch: ${remoteFile.purpose} vs ${file.purpose}`)
+        throw new Error('File purpose mismatch')
+      }
+      return {
+        type: 'file',
+        filename: file.origin_name,
+        mediaType: '',
+        data: `fileid://${remoteFile.id}`
+      }
+    }
+  } catch (error) {
+    logger.error(`Failed to retrieve file ${file.origin_name}:`, error as Error)
+    return null
+  }
+  try {
+    // The file has not been uploaded yet, so upload it
+    const uploadResult = await window.api.fileService.upload(provider, file)
+    if (uploadResult.originalFile?.file) {
+      // Assert the OpenAI file object
+      const remoteFile = uploadResult.originalFile.file as OpenAI.Files.FileObject
+      logger.info(`File ${file.origin_name} uploaded.`)
+      return {
+        type: 'file',
+        filename: remoteFile.filename,
+        mediaType: '',
+        data: `fileid://${remoteFile.id}`
+      }
+    }
+  } catch (error) {
+    logger.error(`Failed to upload file ${file.origin_name}:`, error as Error)
+  }
+
+  return null
+}
+
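Note the success path never returns file bytes: data carries a `fileid://` reference that DashScope's OpenAI-compatible endpoint resolves server side, and mediaType is deliberately empty. A sketch of the resulting value, with a made-up id:

```ts
import type { FilePart } from 'ai'

// Illustrative shape of a successful handleOpenAILargeFileUpload result:
const part: FilePart & { id?: string } = {
  type: 'file',
  filename: 'quarterly-report.pdf', // example name
  mediaType: '', // intentionally empty; the reference below replaces inline content
  data: 'fileid://file-fe9d2c4a' // made-up remote id returned by the files API
}
```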
+/**
+ * Large-file upload router
+ */
+export async function handleLargeFileUpload(
+  file: FileMetadata,
+  model: Model
+): Promise<(FilePart & { id?: string }) | null> {
+  const provider = getProviderByModel(model)
+  const aiSdkId = getAiSdkProviderId(provider)
+
+  if (['google', 'google-generative-ai', 'google-vertex'].includes(aiSdkId)) {
+    return await handleGeminiFileUpload(file, model)
+  }
+
+  if (provider.type === 'openai') {
+    return await handleOpenAILargeFileUpload(file, model)
+  }
+
+  return null
+}
+
 /**
  * Convert a file block to a FilePart (for native file support)
  */
@@ -127,7 +208,7 @@ export async function convertFileBlockToFilePart(fileBlock: FileMessageBlock, mo
   // If large-file upload is supported (e.g. the Gemini File API), try to upload
   if (supportsLargeFileUpload(model)) {
     logger.info(`Large PDF file ${file.origin_name} (${file.size} bytes) attempting File API upload`)
-    const uploadResult = await handleGeminiFileUpload(file, model)
+    const uploadResult = await handleLargeFileUpload(file, model)
     if (uploadResult) {
       return uploadResult
     }
@@ -13,7 +13,15 @@ import {
   findThinkingBlocks,
   getMainTextContent
 } from '@renderer/utils/messageUtils/find'
-import type { AssistantModelMessage, FilePart, ImagePart, ModelMessage, TextPart, UserModelMessage } from 'ai'
+import type {
+  AssistantModelMessage,
+  FilePart,
+  ImagePart,
+  ModelMessage,
+  SystemModelMessage,
+  TextPart,
+  UserModelMessage
+} from 'ai'

 import { convertFileBlockToFilePart, convertFileBlockToTextPart } from './fileProcessor'

@@ -27,7 +35,7 @@ export async function convertMessageToSdkParam(
   message: Message,
   isVisionModel = false,
   model?: Model
-): Promise<ModelMessage> {
+): Promise<ModelMessage | ModelMessage[]> {
   const content = getMainTextContent(message)
   const fileBlocks = findFileBlocks(message)
   const imageBlocks = findImageBlocks(message)
@@ -48,7 +56,7 @@ async function convertMessageToUserModelMessage(
   imageBlocks: ImageMessageBlock[],
   isVisionModel = false,
   model?: Model
-): Promise<UserModelMessage> {
+): Promise<UserModelMessage | (UserModelMessage | SystemModelMessage)[]> {
   const parts: Array<TextPart | FilePart | ImagePart> = []
   if (content) {
     parts.push({ type: 'text', text: content })
@@ -85,6 +93,19 @@ async function convertMessageToUserModelMessage(
       if (model) {
         const filePart = await convertFileBlockToFilePart(fileBlock, model)
         if (filePart) {
+          // Check whether filePart.data is a string carrying a fileid:// reference
+          if (typeof filePart.data === 'string' && filePart.data.startsWith('fileid://')) {
+            return [
+              {
+                role: 'system',
+                content: filePart.data
+              },
+              {
+                role: 'user',
+                content: parts.length > 0 ? parts : ''
+              }
+            ]
+          }
           parts.push(filePart)
           logger.debug(`File ${file.origin_name} processed as native file format`)
           processed = true
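So for qwen-long-style native uploads, one user message fans out into two SDK messages: a system message carrying the fileid:// reference, then the user's own content. The `parts.length > 0 ? parts : ''` guard is what fixes the empty-input error mentioned in the commit body. Roughly, with illustrative values:

```ts
import type { SystemModelMessage, UserModelMessage } from 'ai'

// Illustrative output for a message with one natively uploaded file plus typed text:
const converted: (SystemModelMessage | UserModelMessage)[] = [
  { role: 'system', content: 'fileid://file-fe9d2c4a' }, // the model reads the document via this id
  { role: 'user', content: [{ type: 'text', text: 'Summarize the attached report.' }] }
]
```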
@@ -159,7 +180,7 @@ export async function convertMessagesToSdkMessages(messages: Message[], model: M

   for (const message of messages) {
     const sdkMessage = await convertMessageToSdkParam(message, isVision, model)
-    sdkMessages.push(sdkMessage)
+    sdkMessages.push(...(Array.isArray(sdkMessage) ? sdkMessage : [sdkMessage]))
   }

   return sdkMessages
@@ -10,26 +10,61 @@ import { FileTypes } from '@renderer/types'

 import { getAiSdkProviderId } from '../provider/factory'

+// Utility: decide whether a model supports a feature, based on model name and provider
+function modelSupportValidator(
+  model: Model,
+  {
+    supportedModels = [],
+    unsupportedModels = [],
+    supportedProviders = [],
+    unsupportedProviders = []
+  }: {
+    supportedModels?: string[]
+    unsupportedModels?: string[]
+    supportedProviders?: string[]
+    unsupportedProviders?: string[]
+  }
+): boolean {
+  const provider = getProviderByModel(model)
+  const aiSdkId = getAiSdkProviderId(provider)
+
+  // Blacklist: reject immediately when the model is known not to support the feature
+  if (unsupportedModels.some((name) => model.name.includes(name))) {
+    return false
+  }
+
+  // Blacklist: reject unsupported providers; useful when a provider hosts a same-named
+  // model that lacks features of the original
+  if (unsupportedProviders.includes(aiSdkId)) {
+    return false
+  }
+
+  // Whitelist: accept when the model name is known to support the feature
+  if (supportedModels.some((name) => model.name.includes(name))) {
+    return true
+  }
+
+  // Fall back to the provider check
+  return supportedProviders.includes(aiSdkId)
+}
+
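The check order matters: blacklists run before whitelists, and a model-name hit beats the provider fallback. A self-contained mirror of that precedence using plain strings; the model and provider names are examples:

```ts
// Standalone mirror of modelSupportValidator's precedence, for illustration only.
function checkSupport(
  modelName: string,
  aiSdkId: string,
  opts: {
    supportedModels?: string[]
    unsupportedModels?: string[]
    supportedProviders?: string[]
    unsupportedProviders?: string[]
  }
): boolean {
  const { supportedModels = [], unsupportedModels = [], supportedProviders = [], unsupportedProviders = [] } = opts
  if (unsupportedModels.some((n) => modelName.includes(n))) return false // model blacklist wins first
  if (unsupportedProviders.includes(aiSdkId)) return false // then the provider blacklist
  if (supportedModels.some((n) => modelName.includes(n))) return true // then the model whitelist
  return supportedProviders.includes(aiSdkId) // finally the provider fallback
}

// qwen-long is whitelisted by name, so it passes even on an unlisted provider:
console.assert(checkSupport('qwen-long-latest', 'dashscope', { supportedModels: ['qwen-long'] }))
// an unlisted model on an unlisted provider falls through to false:
console.assert(!checkSupport('gpt-4o', 'dashscope', { supportedModels: ['qwen-long'] }))
```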
 /**
  * Check whether the model natively supports PDF input
  */
 export function supportsPdfInput(model: Model): boolean {
-  // Based on the AI SDK docs, these providers support PDF input
-  const supportedProviders = [
-    'openai',
-    'azure-openai',
-    'anthropic',
-    'google',
-    'google-generative-ai',
-    'google-vertex',
-    'bedrock',
-    'amazon-bedrock'
-  ]
-
-  const provider = getProviderByModel(model)
-  const aiSdkId = getAiSdkProviderId(provider)
-
-  return supportedProviders.some((provider) => aiSdkId === provider)
+  // Based on the AI SDK docs, the following models and providers support PDF input
+  return modelSupportValidator(model, {
+    supportedModels: ['qwen-long', 'qwen-doc'],
+    supportedProviders: [
+      'openai',
+      'azure-openai',
+      'anthropic',
+      'google',
+      'google-generative-ai',
+      'google-vertex',
+      'bedrock',
+      'amazon-bedrock'
+    ]
+  })
 }

 /**
@@ -43,11 +78,11 @@ export function supportsImageInput(model: Model): boolean {
  * Check whether the provider supports large-file upload (e.g. the Gemini File API)
  */
 export function supportsLargeFileUpload(model: Model): boolean {
-  const provider = getProviderByModel(model)
-  const aiSdkId = getAiSdkProviderId(provider)
-
-  // Currently it is mainly the Gemini series that supports large-file upload
-  return ['google', 'google-generative-ai', 'google-vertex'].includes(aiSdkId)
+  // Based on the AI SDK docs, the following models and providers support large-file upload
+  return modelSupportValidator(model, {
+    supportedModels: ['qwen-long', 'qwen-doc'],
+    supportedProviders: ['google', 'google-generative-ai', 'google-vertex']
+  })
 }

 /**
@@ -67,6 +102,11 @@ export function getFileSizeLimit(model: Model, fileType: FileTypes): number {
     return 20 * 1024 * 1024 // 20MB
   }

+  // DashScope: when the model supports large-file upload, prefer the File API
+  if (aiSdkId === 'dashscope' && supportsLargeFileUpload(model)) {
+    return 0 // a zero limit routes every file through the File API upload path
+  }
+
   // Other providers have no explicit limit, so use a large default
   // This matches the Legacy-architecture behavior and lets the provider enforce file sizes itself
   return Infinity
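The zero limit is an intentional sentinel: any file size exceeds it, so documents on these DashScope models always leave the inline path. A sketch of the comparison a caller is assumed to make; the helper below is hypothetical, and only the 0/Infinity values come from the function above:

```ts
// Hypothetical caller shape: route by comparing the file size to the limit.
function shouldUseFileApi(fileSizeBytes: number, limitBytes: number, largeUploadSupported: boolean): boolean {
  // limit 0 (DashScope + qwen-long/qwen-doc) makes every file "too large",
  // so all documents take the File API path; Infinity means always inline.
  return fileSizeBytes > limitBytes && largeUploadSupported
}

console.assert(shouldUseFileApi(1024, 0, true) === true) // DashScope qwen-long: always File API
console.assert(shouldUseFileApi(1024, Infinity, false) === false) // default: stay inline
```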
@@ -1,5 +1,6 @@
 import type { File } from '@google/genai'
 import type { FileSchema } from '@mistralai/mistralai/models/components'
+import type OpenAI from 'openai'

 export type RemoteFile =
   | {
@@ -10,13 +11,17 @@ export type RemoteFile =
       type: 'mistral'
       file: FileSchema
     }
+  | {
+      type: 'openai'
+      file: OpenAI.Files.FileObject
+    }

 /**
  * Type guard to check if a RemoteFile is a Gemini file
  * @param file - The RemoteFile to check
  * @returns True if the file is a Gemini file (file property is of type File)
  */
-export const isGeminiFile = (file: RemoteFile): file is RemoteFile & { type: 'gemini'; file: File } => {
+export const isGeminiFile = (file: RemoteFile): file is { type: 'gemini'; file: File } => {
   return file.type === 'gemini'
 }

@@ -25,10 +30,18 @@ export const isGeminiFile = (file: RemoteFile): file is RemoteFile & { type: 'ge
  * @param file - The RemoteFile to check
  * @returns True if the file is a Mistral file (file property is of type FileSchema)
  */
-export const isMistralFile = (file: RemoteFile): file is RemoteFile & { type: 'mistral'; file: FileSchema } => {
+export const isMistralFile = (file: RemoteFile): file is { type: 'mistral'; file: FileSchema } => {
   return file.type === 'mistral'
 }

+/**
+ * Type guard to check if a RemoteFile is an OpenAI file
+ * @param file - The RemoteFile to check
+ * @returns True if the file is an OpenAI file (file property is of type OpenAI.Files.FileObject)
+ */
+export const isOpenAIFile = (file: RemoteFile): file is { type: 'openai'; file: OpenAI.Files.FileObject } => {
+  return file.type === 'openai'
+}
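Returning the bare object type rather than `RemoteFile & {...}` keeps the guards equivalent but simpler; either way, they let callers narrow the discriminated union without casts. A short sketch; the import location is an assumption:

```ts
import { isOpenAIFile, type RemoteFile } from '@types' // assumed export location

function describeRemoteFile(remote: RemoteFile): string {
  if (isOpenAIFile(remote)) {
    // Narrowed: remote.file is OpenAI.Files.FileObject in this branch
    return `${remote.file.filename} (${remote.file.bytes} bytes)`
  }
  return `remote file of type ${remote.type}`
}
```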

 export type FileStatus = 'success' | 'processing' | 'failed' | 'unknown'

 export interface FileUploadResponse {
@@ -93,6 +106,10 @@ export interface FileMetadata {
    * The estimated token count of the file (optional)
    */
   tokens?: number
+  /**
+   * The purpose of the file
+   */
+  purpose?: OpenAI.FilePurpose
 }

 export interface FileType extends FileMetadata {}