feat: enhance OpenAI provider handling and add providerParams utility module

- Updated the `createBaseModel` function to handle OpenAI provider responses in strict mode.
- Modified `providerToAiSdkConfig` to include specific options for OpenAI when in strict mode.
- Introduced a new utility module `providerParams.ts` for managing provider-specific parameters, including OpenAI, Anthropic, and Gemini configurations.
- Added functions to retrieve service tiers, specific parameters, and reasoning efforts for various providers, improving overall provider management.
This commit is contained in:
MyPrototypeWhat 2025-07-02 16:43:06 +08:00
parent b660e9d524
commit ff3b1fc38f
3 changed files with 333 additions and 6 deletions

View File

@ -65,10 +65,13 @@ export async function createBaseModel(
`Creator function "${providerConfig.creatorFunctionName}" not found in the imported module for provider "${effectiveProviderId}"`
)
}
// 创建provider实例
const provider = creatorFunction(options)
let provider = creatorFunction(options)
// 加一个特判
if (providerConfig.id === 'openai' && options.compatibility === 'strict') {
provider = provider.responses
}
// 返回模型实例
if (typeof provider === 'function') {
let model: LanguageModelV1 = provider(modelId)

View File

@ -53,9 +53,16 @@ function providerToAiSdkConfig(provider: Provider): {
const aiSdkProviderId = getAiSdkProviderId(actualProvider)
if (AiCore.isSupported(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, actualProvider)
// 如果provider是openai则使用strict模式并且默认responses api
const openaiResponseOptions =
aiSdkProviderId === 'openai'
? {
compatibility: 'strict'
}
: undefined
if (AiCore.isSupported(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, actualProvider, openaiResponseOptions)
return {
providerId: aiSdkProviderId as ProviderId,
options
@ -183,13 +190,13 @@ export default class ModernAiProvider {
// 动态构建中间件数组
const middlewares = buildAiSdkMiddlewares(middlewareConfig)
console.log('构建的中间件:', middlewares)
// console.log('构建的中间件:', middlewares)
// 创建带有中间件的执行器
if (middlewareConfig.onChunk) {
// 流式处理 - 使用适配器
const adapter = new AiSdkToChunkAdapter(middlewareConfig.onChunk)
console.log('最终params', params)
const streamResult = await executor.streamText(
modelId,
params,

View File

@ -0,0 +1,317 @@
/**
*
* API AI SDK providerOptions
*/
import { Modality } from '@google/genai'
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
import {
findTokenLimit,
GEMINI_FLASH_MODEL_REGEX,
getOpenAIWebSearchParams,
isGeminiReasoningModel,
isOpenAIModel,
isOpenAIReasoningModel,
isReasoningModel,
isSupportedFlexServiceTier,
isSupportedReasoningEffortOpenAIModel
} from '@renderer/config/models'
import { getStoreSetting } from '@renderer/hooks/useSettings'
import { getAssistantSettings } from '@renderer/services/AssistantService'
import type { SettingsState } from '@renderer/store/settings'
import { Assistant, EFFORT_RATIO, Model, OpenAIServiceTier } from '@renderer/types'
// ===== OpenAI-specific parameters =====
/**
 * Resolve the OpenAI service tier for a request.
 * Migrated from BaseApiClient.getServiceTier.
 *
 * @param model - target model; non-OpenAI models (and GitHub/Copilot hosted ones) never get a tier
 * @returns the service tier to send, or undefined when tiers do not apply to this model
 */
export function getServiceTier(model: Model): OpenAIServiceTier | undefined {
  if (!isOpenAIModel(model) || model.provider === 'github' || model.provider === 'copilot') {
    return undefined
  }
  const openAI = getStoreSetting('openAI') as SettingsState['openAI']
  // 'flex' is only valid for models that support the flex tier; otherwise fall back to 'auto'.
  if (openAI?.serviceTier === 'flex') {
    return isSupportedFlexServiceTier(model) ? 'flex' : 'auto'
  }
  // Fix: the original read openAI.serviceTier without a guard here, although the
  // branch above already treated `openAI` as possibly undefined — default to 'auto'
  // when the settings slice is missing.
  return openAI?.serviceTier ?? 'auto'
}
/**
 * Build OpenAI-flavoured request parameters for the given assistant/model pair.
 * Migrated from OpenAIBaseClient.getProviderSpecificParameters.
 *
 * @returns provider-specific extra parameters, or an empty object when none apply
 */
export function getOpenAIProviderSpecificParameters(assistant: Assistant, model: Model) {
  const { maxTokens } = getAssistantSettings(assistant)
  // OpenRouter-hosted deepseek-r1 needs an explicit flag to surface reasoning tokens.
  const isOpenRouterDeepseekR1 = model.provider === 'openrouter' && model.id.includes('deepseek-r1')
  if (isOpenRouterDeepseekR1) {
    return { include_reasoning: true }
  }
  // Reasoning models reject max_tokens; they take max_completion_tokens instead.
  if (isOpenAIReasoningModel(model)) {
    return { max_tokens: undefined, max_completion_tokens: maxTokens }
  }
  return {}
}
/**
 * Build the OpenAI reasoning-effort options for a request.
 * Migrated from OpenAIBaseClient.getReasoningEffort.
 *
 * @returns `{ reasoning: { effort, summary } }` when the model supports the
 *   reasoning-effort parameter and the assistant configured one, otherwise `{}`
 */
export function getOpenAIReasoningEffort(assistant: Assistant, model: Model) {
  if (!isSupportedReasoningEffortOpenAIModel(model)) {
    return {}
  }
  const reasoningEffort = assistant?.settings?.reasoning_effort
  if (!reasoningEffort) {
    return {}
  }
  const openAI = getStoreSetting('openAI') as SettingsState['openAI']
  const summaryText = openAI?.summaryText || 'off'
  // 'off' means the user disabled summaries; o1-pro does not accept a summary at all.
  const summary: string | undefined =
    summaryText === 'off' || model.id.includes('o1-pro') ? undefined : summaryText
  // Fix: removed the second isSupportedReasoningEffortOpenAIModel(model) check — it
  // was redundant (the guard at the top already returned for unsupported models),
  // which made the original trailing `return {}` unreachable.
  return {
    reasoning: {
      effort: reasoningEffort,
      summary: summary
    }
  }
}
// ===== Anthropic-specific parameters =====
/**
* Anthropic
* AnthropicAPIClient.getBudgetToken
*/
export function getAnthropicBudgetToken(assistant: Assistant, model: Model) {
if (!isReasoningModel(model)) {
return undefined
}
const { maxTokens } = getAssistantSettings(assistant)
const reasoningEffort = assistant?.settings?.reasoning_effort
if (reasoningEffort === undefined) {
return {
type: 'disabled'
}
}
const effortRatio = EFFORT_RATIO[reasoningEffort]
const tokenLimit = findTokenLimit(model.id)
if (!tokenLimit) {
return {
type: 'enabled',
budget_tokens: Math.max(1024, Math.floor((maxTokens || DEFAULT_MAX_TOKENS) * effortRatio))
}
}
const budgetTokens = Math.max(
1024,
Math.floor(
Math.min(
(tokenLimit.max - tokenLimit.min) * effortRatio + tokenLimit.min,
(maxTokens || DEFAULT_MAX_TOKENS) * effortRatio
)
)
)
return {
type: 'enabled',
budget_tokens: budgetTokens
}
}
// ===== Gemini-specific parameters =====
/**
* Gemini
* GeminiAPIClient.getBudgetToken
*/
export function getGeminiBudgetToken(assistant: Assistant, model: Model) {
if (isGeminiReasoningModel(model)) {
const reasoningEffort = assistant?.settings?.reasoning_effort
// 如果 thinking_budget 是 undefined不思考
if (reasoningEffort === undefined) {
return {
thinkingConfig: {
includeThoughts: false,
...(GEMINI_FLASH_MODEL_REGEX.test(model.id) ? { thinkingBudget: 0 } : {})
}
}
}
const effortRatio = EFFORT_RATIO[reasoningEffort]
if (effortRatio > 1) {
return {
thinkingConfig: {
includeThoughts: true
}
}
}
const tokenLimit = findTokenLimit(model.id)
const { min = 0, max = 0 } = tokenLimit || {}
// 计算 budgetTokens确保不低于 min
const budget = Math.floor((max - min) * effortRatio + min)
return {
thinkingConfig: {
...(budget > 0 ? { thinkingBudget: budget } : {}),
includeThoughts: true
}
}
}
return {}
}
/**
 * Gemini safety settings: every harm category is set to BLOCK_NONE so the
 * provider never filters responses on our behalf.
 * Migrated from GeminiAPIClient.getSafetySettings.
 */
export function getGeminiSafetySettings() {
  const categories = [
    'HARM_CATEGORY_HARASSMENT',
    'HARM_CATEGORY_HATE_SPEECH',
    'HARM_CATEGORY_SEXUALLY_EXPLICIT',
    'HARM_CATEGORY_DANGEROUS_CONTENT'
  ]
  return categories.map((category) => ({
    category,
    threshold: 'BLOCK_NONE'
  }))
}
/**
 * Build the Gemini parameters for image-generation requests.
 * Migrated from GeminiAPIClient.getGenerateImageParameter.
 */
export function getGeminiGenerateImageParameter() {
  // Image generation ignores system instructions and must request both
  // text and image modalities in the response.
  const responseModalities = [Modality.TEXT, Modality.IMAGE]
  return {
    systemInstruction: undefined,
    responseModalities,
    responseMimeType: 'text/plain'
  }
}
// ===== Web-search parameters =====
/**
 * Map OpenAI web-search settings into AI SDK providerOptions.
 * Thin wrapper around the existing getOpenAIWebSearchParams helper so all
 * provider-option builders live in this module.
 */
export function getOpenAIWebSearchProviderOptions(model: Model, enableWebSearch: boolean): Record<string, any> {
  const providerOptions = getOpenAIWebSearchParams(model, enableWebSearch)
  return providerOptions
}
/**
 * Build the Anthropic web-search providerOptions.
 * Anthropic implements search as a tool; the AI SDK surfaces it uniformly
 * through providerOptions instead.
 *
 * @returns an empty object when web search is disabled
 */
export function getAnthropicWebSearchProviderOptions(model: Model, enableWebSearch: boolean): Record<string, any> {
  if (!enableWebSearch) {
    return {}
  }
  const webSearch = {
    enabled: true,
    toolType: 'web_search_20250305',
    maxUses: 5
  }
  return { webSearch }
}
/**
 * Build the Gemini web-search providerOptions.
 * Gemini implements search via the googleSearch tool; the AI SDK surfaces it
 * uniformly through providerOptions instead.
 *
 * @returns an empty object when web search is disabled
 */
export function getGeminiWebSearchProviderOptions(model: Model, enableWebSearch: boolean): Record<string, any> {
  if (!enableWebSearch) {
    return {}
  }
  const webSearch = {
    enabled: true,
    toolType: 'googleSearch'
  }
  return { webSearch }
}
// ===== Generic helpers =====
/**
 * Collect the assistant's user-defined custom parameters into a plain object.
 * Migrated from BaseApiClient.getCustomParameters.
 *
 * Parameters whose name is empty (after trimming) are skipped. For 'json'
 * parameters the literal string 'undefined' maps to undefined, valid JSON is
 * parsed, and unparseable text is kept as the raw string. All other types are
 * passed through unchanged.
 *
 * @returns merged custom parameters, `{}` when none are configured
 */
export function getCustomParameters(assistant: Assistant) {
  // Fix: the original spread the accumulator on every reduce step
  // ({ ...acc, [k]: v }), which is accidentally O(n^2); a mutable
  // accumulator produces the identical object in O(n).
  const result: Record<string, unknown> = {}
  for (const param of assistant?.settings?.customParameters ?? []) {
    if (!param.name?.trim()) {
      continue
    }
    if (param.type === 'json') {
      const raw = param.value as string
      if (raw === 'undefined') {
        result[param.name] = undefined
      } else {
        try {
          result[param.name] = JSON.parse(raw)
        } catch {
          // Unparseable JSON falls back to the raw string.
          result[param.name] = raw
        }
      }
    } else {
      result[param.name] = param.value
    }
  }
  return result
}