refactor: rework the large-file upload logic, adapt it to the OpenAI-standard file service, and add native upload support for the qwen-long and qwen-doc model series (#9997)

* refactor: rework the large-file upload logic, adapt it to the OpenAI-standard file service, and add native upload support for qwen-long and qwen-doc

* chore: refine the large-file upload logic and types, and improve the release of uploaded file resources

* fix: fix an error caused during native upload when the user has entered no content

* chore: improve function naming
Carlton 2025-09-08 00:16:48 +08:00 committed by GitHub
parent 4a72d40394
commit 79592e2c27
6 changed files with 316 additions and 28 deletions


@@ -3,6 +3,7 @@ import { Provider } from '@types'
 import { BaseFileService } from './BaseFileService'
 import { GeminiService } from './GeminiService'
 import { MistralService } from './MistralService'
+import { OpenaiService } from './OpenAIService'

 export class FileServiceManager {
   private static instance: FileServiceManager
@@ -30,6 +31,9 @@ export class FileServiceManager {
       case 'mistral':
         service = new MistralService(provider)
         break
+      case 'openai':
+        service = new OpenaiService(provider)
+        break
       default:
         throw new Error(`Unsupported service type: ${type}`)
     }
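
For orientation, a minimal sketch of how the new branch might be exercised from calling code. The getService name and the Provider literal are assumptions for illustration; only the 'openai' case itself comes from this commit:

// Hypothetical caller; the diff shows only the singleton field and the switch body
const provider = { id: 'my-openai', type: 'openai', apiKey: '...', apiHost: 'https://example.com/v1' } as Provider
const service = FileServiceManager.getInstance().getService('openai', provider) // assumed factory method
// service is an OpenaiService exposing uploadFile / listFiles / retrieveFile / deleteFile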


@@ -0,0 +1,125 @@
+import { loggerService } from '@logger'
+import { fileStorage } from '@main/services/FileStorage'
+import { FileListResponse, FileMetadata, FileUploadResponse, Provider } from '@types'
+import * as fs from 'fs'
+import OpenAI from 'openai'
+
+import { CacheService } from '../CacheService'
+import { BaseFileService } from './BaseFileService'
+
+const logger = loggerService.withContext('OpenAIService')
+
+export class OpenaiService extends BaseFileService {
+  private static readonly FILE_CACHE_DURATION = 7 * 24 * 60 * 60 * 1000 // 7 days, in milliseconds
+  private static readonly generateUIFileIdCacheKey = (fileId: string) => `ui_file_id_${fileId}`
+
+  private readonly client: OpenAI
+
+  constructor(provider: Provider) {
+    super(provider)
+    this.client = new OpenAI({
+      apiKey: provider.apiKey,
+      baseURL: provider.apiHost
+    })
+  }
+
+  async uploadFile(file: FileMetadata): Promise<FileUploadResponse> {
+    let fileReadStream: fs.ReadStream | undefined
+    try {
+      fileReadStream = fs.createReadStream(fileStorage.getFilePathById(file))
+      // Restore the file's original name to improve the model's understanding of the file
+      const fileStreamWithMeta = Object.assign(fileReadStream, {
+        name: file.origin_name
+      })
+      const response = await this.client.files.create({
+        file: fileStreamWithMeta,
+        purpose: file.purpose || 'assistants'
+      })
+
+      if (!response.id) {
+        throw new Error('File id not found in response')
+      }
+
+      // Map the remote file id onto the UI file id
+      CacheService.set<string>(
+        OpenaiService.generateUIFileIdCacheKey(file.id),
+        response.id,
+        OpenaiService.FILE_CACHE_DURATION
+      )
+
+      return {
+        fileId: response.id,
+        displayName: file.origin_name,
+        status: 'success',
+        originalFile: {
+          type: 'openai',
+          file: response
+        }
+      }
+    } catch (error) {
+      logger.error('Error uploading file:', error as Error)
+      return {
+        fileId: '',
+        displayName: file.origin_name,
+        status: 'failed'
+      }
+    } finally {
+      // Destroy the file stream
+      if (fileReadStream) fileReadStream.destroy()
+    }
+  }
+
+  async listFiles(): Promise<FileListResponse> {
+    try {
+      const response = await this.client.files.list()
+      return {
+        files: response.data.map((file) => ({
+          id: file.id,
+          displayName: file.filename || '',
+          size: file.bytes,
+          status: 'success', // all listed files are processed
+          originalFile: {
+            type: 'openai',
+            file
+          }
+        }))
+      }
+    } catch (error) {
+      logger.error('Error listing files:', error as Error)
+      return { files: [] }
+    }
+  }
+
+  async deleteFile(fileId: string): Promise<void> {
+    try {
+      const cachedRemoteFileId = CacheService.get<string>(OpenaiService.generateUIFileIdCacheKey(fileId))
+      await this.client.files.delete(cachedRemoteFileId || fileId)
+      logger.debug(`File ${fileId} deleted`)
+    } catch (error) {
+      logger.error('Error deleting file:', error as Error)
+      throw error
+    }
+  }
+
+  async retrieveFile(fileId: string): Promise<FileUploadResponse> {
+    try {
+      // Try to reverse-map the UI file id to the remote file id
+      const cachedRemoteFileId = CacheService.get<string>(OpenaiService.generateUIFileIdCacheKey(fileId))
+      const response = await this.client.files.retrieve(cachedRemoteFileId || fileId)
+      return {
+        fileId: response.id,
+        displayName: response.filename,
+        status: 'success',
+        originalFile: {
+          type: 'openai',
+          file: response
+        }
+      }
+    } catch (error) {
+      logger.error('Error retrieving file:', error as Error)
+      return {
+        fileId: fileId,
+        displayName: '',
+        status: 'failed',
+        originalFile: undefined
+      }
+    }
+  }
+}
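
A hypothetical round trip through OpenaiService, showing the UI-id to remote-id cache mapping above in action; the FileMetadata literal is an assumption based on the fields uploadFile reads:

const service = new OpenaiService(provider)
const meta = { id: 'ui-123', origin_name: 'report.pdf' } as FileMetadata

const uploaded = await service.uploadFile(meta)
if (uploaded.status === 'success') {
  // retrieveFile and deleteFile accept the UI file id and resolve the cached remote id transparently
  const remote = await service.retrieveFile(meta.id)
  console.log(remote.fileId) // remote id, e.g. 'file-abc123' (invented)
  await service.deleteFile(meta.id)
}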


@@ -10,6 +10,7 @@ import { FileTypes } from '@renderer/types'
 import { FileMessageBlock } from '@renderer/types/newMessage'
 import { findFileBlocks } from '@renderer/utils/messageUtils/find'
 import type { FilePart, TextPart } from 'ai'
+import type OpenAI from 'openai'

 import { getAiSdkProviderId } from '../provider/factory'
 import { getFileSizeLimit, supportsImageInput, supportsLargeFileUpload, supportsPdfInput } from './modelCapabilities'
@@ -112,6 +113,86 @@ export async function handleGeminiFileUpload(file: FileMetadata, model: Model):
   return null
 }

+/**
+ * Large-file upload for OpenAI-compatible file services
+ */
+export async function handleOpenAILargeFileUpload(
+  file: FileMetadata,
+  model: Model
+): Promise<(FilePart & { id?: string }) | null> {
+  const provider = getProviderByModel(model)
+  // For qwen-long series models, the docs require purpose to be 'file-extract'
+  if (['qwen-long', 'qwen-doc'].some((modelName) => model.name.includes(modelName))) {
+    file = {
+      ...file,
+      // This value is not part of the OpenAI type definitions but conforms to the SDK contract, so force the assertion
+      purpose: 'file-extract' as OpenAI.FilePurpose
+    }
+  }
+  try {
+    // Check whether the file has already been uploaded
+    const fileMetadata = await window.api.fileService.retrieve(provider, file.id)
+    if (fileMetadata.status === 'success' && fileMetadata.originalFile?.file) {
+      // Assert the OpenAI FileObject
+      const remoteFile = fileMetadata.originalFile.file as OpenAI.Files.FileObject
+      // Verify that the purposes match
+      if (remoteFile.purpose !== file.purpose) {
+        logger.warn(`File ${file.origin_name} purpose mismatch: ${remoteFile.purpose} vs ${file.purpose}`)
+        throw new Error('File purpose mismatch')
+      }
+      return {
+        type: 'file',
+        filename: file.origin_name,
+        mediaType: '',
+        data: `fileid://${remoteFile.id}`
+      }
+    }
+  } catch (error) {
+    logger.error(`Failed to retrieve file ${file.origin_name}:`, error as Error)
+    return null
+  }
+  try {
+    // The file has not been uploaded yet, so upload it
+    const uploadResult = await window.api.fileService.upload(provider, file)
+    if (uploadResult.originalFile?.file) {
+      // Assert the OpenAI FileObject
+      const remoteFile = uploadResult.originalFile.file as OpenAI.Files.FileObject
+      logger.info(`File ${file.origin_name} uploaded.`)
+      return {
+        type: 'file',
+        filename: remoteFile.filename,
+        mediaType: '',
+        data: `fileid://${remoteFile.id}`
+      }
+    }
+  } catch (error) {
+    logger.error(`Failed to upload file ${file.origin_name}:`, error as Error)
+  }
+  return null
+}
+
+/**
+ * Large-file upload, dispatched by provider
+ */
+export async function handleLargeFileUpload(
+  file: FileMetadata,
+  model: Model
+): Promise<(FilePart & { id?: string }) | null> {
+  const provider = getProviderByModel(model)
+  const aiSdkId = getAiSdkProviderId(provider)
+  if (['google', 'google-generative-ai', 'google-vertex'].includes(aiSdkId)) {
+    return await handleGeminiFileUpload(file, model)
+  }
+  if (provider.type === 'openai') {
+    return await handleOpenAILargeFileUpload(file, model)
+  }
+  return null
+}
+
 /**
  * FilePart
  */
@@ -127,7 +208,7 @@ export async function convertFileBlockToFilePart(fileBlock: FileMessageBlock, model
   // If large-file upload is supported (e.g. Gemini File API), try uploading
   if (supportsLargeFileUpload(model)) {
     logger.info(`Large PDF file ${file.origin_name} (${file.size} bytes) attempting File API upload`)
-    const uploadResult = await handleGeminiFileUpload(file, model)
+    const uploadResult = await handleLargeFileUpload(file, model)
     if (uploadResult) {
       return uploadResult
     }
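
For a qwen-long or qwen-doc model, the path above resolves to a FilePart that carries a remote reference rather than inline bytes; a sketch with an invented id:

const part: FilePart = {
  type: 'file',
  filename: 'report.pdf',
  mediaType: '', // intentionally empty: data is a reference, not file content
  data: 'fileid://file-abc123' // remote file reference consumed by the message converter
}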


@@ -13,7 +13,15 @@ import {
   findThinkingBlocks,
   getMainTextContent
 } from '@renderer/utils/messageUtils/find'
-import type { AssistantModelMessage, FilePart, ImagePart, ModelMessage, TextPart, UserModelMessage } from 'ai'
+import type {
+  AssistantModelMessage,
+  FilePart,
+  ImagePart,
+  ModelMessage,
+  SystemModelMessage,
+  TextPart,
+  UserModelMessage
+} from 'ai'

 import { convertFileBlockToFilePart, convertFileBlockToTextPart } from './fileProcessor'
@@ -27,7 +35,7 @@ export async function convertMessageToSdkParam(
   message: Message,
   isVisionModel = false,
   model?: Model
-): Promise<ModelMessage> {
+): Promise<ModelMessage | ModelMessage[]> {
   const content = getMainTextContent(message)
   const fileBlocks = findFileBlocks(message)
   const imageBlocks = findImageBlocks(message)
@@ -48,7 +56,7 @@ async function convertMessageToUserModelMessage(
   imageBlocks: ImageMessageBlock[],
   isVisionModel = false,
   model?: Model
-): Promise<UserModelMessage> {
+): Promise<UserModelMessage | (UserModelMessage | SystemModelMessage)[]> {
   const parts: Array<TextPart | FilePart | ImagePart> = []
   if (content) {
     parts.push({ type: 'text', text: content })
@@ -85,6 +93,19 @@ async function convertMessageToUserModelMessage(
       if (model) {
         const filePart = await convertFileBlockToFilePart(fileBlock, model)
         if (filePart) {
+          // Check whether the file part carries a remote file id instead of inline data
+          if (typeof filePart.data === 'string' && filePart.data.startsWith('fileid://')) {
+            return [
+              {
+                role: 'system',
+                content: filePart.data
+              },
+              {
+                role: 'user',
+                content: parts.length > 0 ? parts : ''
+              }
+            ]
+          }
           parts.push(filePart)
           logger.debug(`File ${file.origin_name} processed as native file format`)
           processed = true
@@ -159,7 +180,7 @@ export async function convertMessagesToSdkMessages(messages: Message[], model: Model
   for (const message of messages) {
     const sdkMessage = await convertMessageToSdkParam(message, isVision, model)
-    sdkMessages.push(sdkMessage)
+    sdkMessages.push(...(Array.isArray(sdkMessage) ? sdkMessage : [sdkMessage]))
   }

   return sdkMessages
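
When the fileid:// branch above fires, convertMessageToUserModelMessage returns a two-element array instead of a single message, which the loop then spreads into the final list. A sketch of the resulting shape, with an invented id:

const result: (SystemModelMessage | UserModelMessage)[] = [
  { role: 'system', content: 'fileid://file-abc123' }, // qwen-long reads the file reference from a system message
  { role: 'user', content: [{ type: 'text', text: 'Summarize this document.' }] }
]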


@@ -10,26 +10,61 @@ import { FileTypes } from '@renderer/types'
 import { getAiSdkProviderId } from '../provider/factory'

+// Utility: decide whether a model supports a feature, based on model name and provider
+function modelSupportValidator(
+  model: Model,
+  {
+    supportedModels = [],
+    unsupportedModels = [],
+    supportedProviders = [],
+    unsupportedProviders = []
+  }: {
+    supportedModels?: string[]
+    unsupportedModels?: string[]
+    supportedProviders?: string[]
+    unsupportedProviders?: string[]
+  }
+): boolean {
+  const provider = getProviderByModel(model)
+  const aiSdkId = getAiSdkProviderId(provider)
+  // Blacklist: reject immediately if the model name matches an unsupported model
+  if (unsupportedModels.some((name) => model.name.includes(name))) {
+    return false
+  }
+  // Blacklist: reject immediately on an unsupported provider; useful when a provider hosts a
+  // same-named model that lacks some features of the original
+  if (unsupportedProviders.includes(aiSdkId)) {
+    return false
+  }
+  // Whitelist: accept if the model name matches a supported model
+  if (supportedModels.some((name) => model.name.includes(name))) {
+    return true
+  }
+  // Fall back to the provider check
+  return supportedProviders.includes(aiSdkId)
+}
+
 /**
  * PDF input
  */
 export function supportsPdfInput(model: Model): boolean {
-  // Per the AI SDK docs, these providers support PDF input
-  const supportedProviders = [
-    'openai',
-    'azure-openai',
-    'anthropic',
-    'google',
-    'google-generative-ai',
-    'google-vertex',
-    'bedrock',
-    'amazon-bedrock'
-  ]
-
-  const provider = getProviderByModel(model)
-  const aiSdkId = getAiSdkProviderId(provider)
-  return supportedProviders.some((provider) => aiSdkId === provider)
+  // Per the AI SDK docs, the following models or providers support PDF input
+  return modelSupportValidator(model, {
+    supportedModels: ['qwen-long', 'qwen-doc'],
+    supportedProviders: [
+      'openai',
+      'azure-openai',
+      'anthropic',
+      'google',
+      'google-generative-ai',
+      'google-vertex',
+      'bedrock',
+      'amazon-bedrock'
+    ]
+  })
 }

 /**
@@ -43,11 +78,11 @@ export function supportsImageInput(model: Model): boolean {
  * Gemini File API
  */
 export function supportsLargeFileUpload(model: Model): boolean {
-  const provider = getProviderByModel(model)
-  const aiSdkId = getAiSdkProviderId(provider)
-
-  // Currently it is mainly the Gemini series that supports large-file upload
-  return ['google', 'google-generative-ai', 'google-vertex'].includes(aiSdkId)
+  // Per the AI SDK docs, the following models or providers support large-file upload
+  return modelSupportValidator(model, {
+    supportedModels: ['qwen-long', 'qwen-doc'],
+    supportedProviders: ['google', 'google-generative-ai', 'google-vertex']
+  })
 }

 /**
@@ -67,6 +102,11 @@ export function getFileSizeLimit(model: Model, fileType: FileTypes): number {
     return 20 * 1024 * 1024 // 20MB
   }

+  // DashScope: if the model supports large-file upload, prefer uploading via the File API
+  if (aiSdkId === 'dashscope' && supportsLargeFileUpload(model)) {
+    return 0 // use a smaller default value
+  }
+
   // Other providers have no explicit limit; use a large default
   // Consistent with the legacy architecture: let providers handle file size themselves
   return Infinity
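
As a sketch of how modelSupportValidator generalizes, a hypothetical capability check; every model name and provider id below is illustrative, not from this commit. Note the evaluation order: both blacklists are checked before the model-name whitelist, which in turn wins over the provider fallback:

function supportsHypotheticalFeature(model: Model): boolean {
  return modelSupportValidator(model, {
    unsupportedModels: ['some-lite-model'], // checked first: always rejected
    unsupportedProviders: ['some-proxy-provider'], // rejected even for whitelisted model names
    supportedModels: ['some-model'], // model-name whitelist
    supportedProviders: ['some-provider'] // provider fallback
  })
}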


@@ -1,5 +1,6 @@
 import type { File } from '@google/genai'
 import type { FileSchema } from '@mistralai/mistralai/models/components'
+import type OpenAI from 'openai'

 export type RemoteFile =
   | {
@@ -10,13 +11,17 @@ export type RemoteFile =
       type: 'mistral'
       file: FileSchema
     }
+  | {
+      type: 'openai'
+      file: OpenAI.Files.FileObject
+    }

 /**
  * Type guard to check if a RemoteFile is a Gemini file
  * @param file - The RemoteFile to check
  * @returns True if the file is a Gemini file (file property is of type File)
  */
-export const isGeminiFile = (file: RemoteFile): file is RemoteFile & { type: 'gemini'; file: File } => {
+export const isGeminiFile = (file: RemoteFile): file is { type: 'gemini'; file: File } => {
   return file.type === 'gemini'
 }

@@ -25,10 +30,18 @@ export const isGeminiFile = (file: RemoteFile): file is RemoteFile & { type: 'ge
  * @param file - The RemoteFile to check
  * @returns True if the file is a Mistral file (file property is of type FileSchema)
  */
-export const isMistralFile = (file: RemoteFile): file is RemoteFile & { type: 'mistral'; file: FileSchema } => {
+export const isMistralFile = (file: RemoteFile): file is { type: 'mistral'; file: FileSchema } => {
   return file.type === 'mistral'
 }

+/**
+ * Type guard to check if a RemoteFile is an OpenAI file
+ * @param file - The RemoteFile to check
+ * @returns True if the file is an OpenAI file (file property is of type OpenAI.Files.FileObject)
+ */
+export const isOpenAIFile = (file: RemoteFile): file is { type: 'openai'; file: OpenAI.Files.FileObject } => {
+  return file.type === 'openai'
+}
+
 export type FileStatus = 'success' | 'processing' | 'failed' | 'unknown'

 export interface FileUploadResponse {
@@ -93,6 +106,10 @@ export interface FileMetadata {
    * token size
    */
   tokens?: number
+  /**
+   * File purpose used when uploading to an OpenAI-compatible file service
+   */
+  purpose?: OpenAI.FilePurpose
 }

 export interface FileType extends FileMetadata {}
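
A small usage sketch for the new guard alongside the existing ones; the field names read off the Gemini and Mistral payloads come from their SDK types and are assumptions here:

function remoteFileName(remote: RemoteFile): string {
  if (isOpenAIFile(remote)) {
    return remote.file.filename // OpenAI.Files.FileObject
  }
  if (isMistralFile(remote)) {
    return remote.file.filename // Mistral FileSchema
  }
  return remote.file.displayName ?? '' // Gemini File (assumed field)
}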