mirror of
https://github.com/StarCitizenToolBox/app.git
synced 2026-01-13 11:40:27 +00:00
Add XNN Pack toggle switch for ONNX inference acceleration (#155)
* Initial plan * Add XNN Pack switch for ONNX inference acceleration Co-authored-by: xkeyC <39891083+xkeyC@users.noreply.github.com> * Refactor Rust ONNX session creation to reduce code duplication Co-authored-by: xkeyC <39891083+xkeyC@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: xkeyC <39891083+xkeyC@users.noreply.github.com>
This commit is contained in:
parent
db024f19bd
commit
db89100402
@ -82,7 +82,7 @@ final class AppGlobalModelProvider
|
||||
}
|
||||
}
|
||||
|
||||
String _$appGlobalModelHash() => r'51f72c5d8538e2a4f11d256802b1a1f2e04d03be';
|
||||
String _$appGlobalModelHash() => r'9729c3ffb891e5899abbb3dc7d2d25ef13a442e7';
|
||||
|
||||
abstract class _$AppGlobalModel extends $Notifier<AppGlobalState> {
|
||||
AppGlobalState build();
|
||||
|
||||
@ -40,7 +40,7 @@ Future<List<String>> dnsLookupIps({required String host}) =>
|
||||
/// # Arguments
|
||||
/// * `urls` - List of base URLs to test
|
||||
/// * `path_suffix` - Optional path suffix to append to each URL (e.g., "/api/version")
|
||||
/// If None, tests the base URL directly
|
||||
/// If None, tests the base URL directly
|
||||
Future<String?> getFasterUrl({
|
||||
required List<String> urls,
|
||||
String? pathSuffix,
|
||||
|
||||
@ -12,15 +12,18 @@ import 'package:flutter_rust_bridge/flutter_rust_bridge_for_generated.dart';
|
||||
/// * `model_path` - 模型文件夹路径
|
||||
/// * `model_key` - 模型缓存键(用于标识模型,如 "zh-en")
|
||||
/// * `quantization_suffix` - 量化后缀(如 "_q4", "_q8",空字符串表示使用默认模型)
|
||||
/// * `use_xnnpack` - 是否使用 XNNPACK 加速
|
||||
///
|
||||
Future<void> loadTranslationModel({
|
||||
required String modelPath,
|
||||
required String modelKey,
|
||||
required String quantizationSuffix,
|
||||
required bool useXnnpack,
|
||||
}) => RustLib.instance.api.crateApiOrtApiLoadTranslationModel(
|
||||
modelPath: modelPath,
|
||||
modelKey: modelKey,
|
||||
quantizationSuffix: quantizationSuffix,
|
||||
useXnnpack: useXnnpack,
|
||||
);
|
||||
|
||||
/// 翻译文本
|
||||
|
||||
@ -116,6 +116,7 @@ abstract class RustLibApi extends BaseApi {
|
||||
required String modelPath,
|
||||
required String modelKey,
|
||||
required String quantizationSuffix,
|
||||
required bool useXnnpack,
|
||||
});
|
||||
|
||||
Future<void> crateApiAsarApiRsiLauncherAsarDataWriteMainJs({
|
||||
@ -421,6 +422,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi {
|
||||
required String modelPath,
|
||||
required String modelKey,
|
||||
required String quantizationSuffix,
|
||||
required bool useXnnpack,
|
||||
}) {
|
||||
return handler.executeNormal(
|
||||
NormalTask(
|
||||
@ -428,11 +430,13 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi {
|
||||
var arg0 = cst_encode_String(modelPath);
|
||||
var arg1 = cst_encode_String(modelKey);
|
||||
var arg2 = cst_encode_String(quantizationSuffix);
|
||||
var arg3 = cst_encode_bool(useXnnpack);
|
||||
return wire.wire__crate__api__ort_api__load_translation_model(
|
||||
port_,
|
||||
arg0,
|
||||
arg1,
|
||||
arg2,
|
||||
arg3,
|
||||
);
|
||||
},
|
||||
codec: DcoCodec(
|
||||
@ -440,7 +444,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi {
|
||||
decodeErrorData: dco_decode_AnyhowException,
|
||||
),
|
||||
constMeta: kCrateApiOrtApiLoadTranslationModelConstMeta,
|
||||
argValues: [modelPath, modelKey, quantizationSuffix],
|
||||
argValues: [modelPath, modelKey, quantizationSuffix, useXnnpack],
|
||||
apiImpl: this,
|
||||
),
|
||||
);
|
||||
@ -449,7 +453,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi {
|
||||
TaskConstMeta get kCrateApiOrtApiLoadTranslationModelConstMeta =>
|
||||
const TaskConstMeta(
|
||||
debugName: "load_translation_model",
|
||||
argNames: ["modelPath", "modelKey", "quantizationSuffix"],
|
||||
argNames: ["modelPath", "modelKey", "quantizationSuffix", "useXnnpack"],
|
||||
);
|
||||
|
||||
@override
|
||||
|
||||
@ -880,12 +880,14 @@ class RustLibWire implements BaseWire {
|
||||
ffi.Pointer<wire_cst_list_prim_u_8_strict> model_path,
|
||||
ffi.Pointer<wire_cst_list_prim_u_8_strict> model_key,
|
||||
ffi.Pointer<wire_cst_list_prim_u_8_strict> quantization_suffix,
|
||||
bool use_xnnpack,
|
||||
) {
|
||||
return _wire__crate__api__ort_api__load_translation_model(
|
||||
port_,
|
||||
model_path,
|
||||
model_key,
|
||||
quantization_suffix,
|
||||
use_xnnpack,
|
||||
);
|
||||
}
|
||||
|
||||
@ -897,6 +899,7 @@ class RustLibWire implements BaseWire {
|
||||
ffi.Pointer<wire_cst_list_prim_u_8_strict>,
|
||||
ffi.Pointer<wire_cst_list_prim_u_8_strict>,
|
||||
ffi.Pointer<wire_cst_list_prim_u_8_strict>,
|
||||
ffi.Bool,
|
||||
)
|
||||
>
|
||||
>(
|
||||
@ -910,6 +913,7 @@ class RustLibWire implements BaseWire {
|
||||
ffi.Pointer<wire_cst_list_prim_u_8_strict>,
|
||||
ffi.Pointer<wire_cst_list_prim_u_8_strict>,
|
||||
ffi.Pointer<wire_cst_list_prim_u_8_strict>,
|
||||
bool,
|
||||
)
|
||||
>();
|
||||
|
||||
|
||||
@ -1410,6 +1410,12 @@ class MessageLookup extends MessageLookupByLibrary {
|
||||
"settings_item_dns_info": MessageLookupByLibrary.simpleMessage(
|
||||
"When enabled, may solve DNS pollution issues in some regions",
|
||||
),
|
||||
"settings_item_onnx_xnn_pack": MessageLookupByLibrary.simpleMessage(
|
||||
"Use XNN to accelerate ONNX inference",
|
||||
),
|
||||
"settings_item_onnx_xnn_pack_info": MessageLookupByLibrary.simpleMessage(
|
||||
"Disabling this option may solve some compatibility issues",
|
||||
),
|
||||
"settings_title_game": MessageLookupByLibrary.simpleMessage("Game"),
|
||||
"settings_title_general": MessageLookupByLibrary.simpleMessage("General"),
|
||||
"support_dev_alipay": MessageLookupByLibrary.simpleMessage("Alipay"),
|
||||
|
||||
@ -1249,6 +1249,12 @@ class MessageLookup extends MessageLookupByLibrary {
|
||||
"settings_item_dns_info": MessageLookupByLibrary.simpleMessage(
|
||||
"有効にすると、一部の地域でのDNS汚染問題を解決する可能性があります",
|
||||
),
|
||||
"settings_item_onnx_xnn_pack": MessageLookupByLibrary.simpleMessage(
|
||||
"XNNを使用してONNX推論を高速化",
|
||||
),
|
||||
"settings_item_onnx_xnn_pack_info": MessageLookupByLibrary.simpleMessage(
|
||||
"このオプションをオフにすると、一部の互換性問題が解決する場合があります",
|
||||
),
|
||||
"settings_title_game": MessageLookupByLibrary.simpleMessage("ゲーム"),
|
||||
"settings_title_general": MessageLookupByLibrary.simpleMessage("一般"),
|
||||
"support_dev_alipay": MessageLookupByLibrary.simpleMessage("Alipay"),
|
||||
|
||||
@ -1414,6 +1414,12 @@ class MessageLookup extends MessageLookupByLibrary {
|
||||
"settings_item_dns_info": MessageLookupByLibrary.simpleMessage(
|
||||
"При включении может решить проблемы с DNS-загрязнением в некоторых регионах",
|
||||
),
|
||||
"settings_item_onnx_xnn_pack": MessageLookupByLibrary.simpleMessage(
|
||||
"Использовать XNN для ускорения ONNX",
|
||||
),
|
||||
"settings_item_onnx_xnn_pack_info": MessageLookupByLibrary.simpleMessage(
|
||||
"Отключение этой опции может решить некоторые проблемы совместимости",
|
||||
),
|
||||
"settings_title_game": MessageLookupByLibrary.simpleMessage("Игра"),
|
||||
"settings_title_general": MessageLookupByLibrary.simpleMessage("Общие"),
|
||||
"support_dev_alipay": MessageLookupByLibrary.simpleMessage("Alipay"),
|
||||
|
||||
@ -1206,6 +1206,12 @@ class MessageLookup extends MessageLookupByLibrary {
|
||||
"settings_item_dns_info": MessageLookupByLibrary.simpleMessage(
|
||||
"开启后可能解决部分地区 DNS 污染的问题",
|
||||
),
|
||||
"settings_item_onnx_xnn_pack": MessageLookupByLibrary.simpleMessage(
|
||||
"使用 XNN 加速 ONNX 推理",
|
||||
),
|
||||
"settings_item_onnx_xnn_pack_info": MessageLookupByLibrary.simpleMessage(
|
||||
"关闭此选项或许可以解决一些兼容问题",
|
||||
),
|
||||
"settings_title_game": MessageLookupByLibrary.simpleMessage("游戏"),
|
||||
"settings_title_general": MessageLookupByLibrary.simpleMessage("通用"),
|
||||
"support_dev_alipay": MessageLookupByLibrary.simpleMessage("支付宝"),
|
||||
|
||||
@ -1200,6 +1200,12 @@ class MessageLookup extends MessageLookupByLibrary {
|
||||
"settings_item_dns_info": MessageLookupByLibrary.simpleMessage(
|
||||
"開啟後可能解決部分地區 DNS 污染的問題",
|
||||
),
|
||||
"settings_item_onnx_xnn_pack": MessageLookupByLibrary.simpleMessage(
|
||||
"使用 XNN 加速 ONNX 推理",
|
||||
),
|
||||
"settings_item_onnx_xnn_pack_info": MessageLookupByLibrary.simpleMessage(
|
||||
"關閉此選項或許可以解決一些相容性問題",
|
||||
),
|
||||
"settings_title_game": MessageLookupByLibrary.simpleMessage("遊戲"),
|
||||
"settings_title_general": MessageLookupByLibrary.simpleMessage("通用"),
|
||||
"support_dev_alipay": MessageLookupByLibrary.simpleMessage("支付寶"),
|
||||
|
||||
@ -28,10 +28,9 @@ class S {
|
||||
static const AppLocalizationDelegate delegate = AppLocalizationDelegate();
|
||||
|
||||
static Future<S> load(Locale locale) {
|
||||
final name =
|
||||
(locale.countryCode?.isEmpty ?? false)
|
||||
? locale.languageCode
|
||||
: locale.toString();
|
||||
final name = (locale.countryCode?.isEmpty ?? false)
|
||||
? locale.languageCode
|
||||
: locale.toString();
|
||||
final localeName = Intl.canonicalizedLocale(name);
|
||||
return initializeMessages(localeName).then((_) {
|
||||
Intl.defaultLocale = localeName;
|
||||
@ -6057,6 +6056,26 @@ class S {
|
||||
args: [],
|
||||
);
|
||||
}
|
||||
|
||||
/// `Use XNN to accelerate ONNX inference`
|
||||
String get settings_item_onnx_xnn_pack {
|
||||
return Intl.message(
|
||||
'Use XNN to accelerate ONNX inference',
|
||||
name: 'settings_item_onnx_xnn_pack',
|
||||
desc: '',
|
||||
args: [],
|
||||
);
|
||||
}
|
||||
|
||||
/// `Disabling this option may solve some compatibility issues`
|
||||
String get settings_item_onnx_xnn_pack_info {
|
||||
return Intl.message(
|
||||
'Disabling this option may solve some compatibility issues',
|
||||
name: 'settings_item_onnx_xnn_pack_info',
|
||||
desc: '',
|
||||
args: [],
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
class AppLocalizationDelegate extends LocalizationsDelegate<S> {
|
||||
|
||||
@ -1200,5 +1200,8 @@
|
||||
"user_unregister_success": "Account unregistered successfully",
|
||||
"@user_unregister_success": {},
|
||||
"user_unregister_failed": "Account unregistration failed",
|
||||
"@user_unregister_failed": {}
|
||||
"@user_unregister_failed": {},
|
||||
"settings_item_onnx_xnn_pack": "Use XNN to accelerate ONNX inference",
|
||||
"@settings_item_onnx_xnn_pack": {},
|
||||
"settings_item_onnx_xnn_pack_info": "Disabling this option may solve some compatibility issues"
|
||||
}
|
||||
|
||||
@ -1158,5 +1158,8 @@
|
||||
"log_analyzer_description": "プレイ記録を分析(ログイン、死亡、キルなどの情報)",
|
||||
"@log_analyzer_description": {},
|
||||
"log_analyzer_window_title": "SCToolbox: logアナライザ",
|
||||
"@log_analyzer_window_title": {}
|
||||
"@log_analyzer_window_title": {},
|
||||
"settings_item_onnx_xnn_pack": "XNNを使用してONNX推論を高速化",
|
||||
"@settings_item_onnx_xnn_pack": {},
|
||||
"settings_item_onnx_xnn_pack_info": "このオプションをオフにすると、一部の互換性問題が解決する場合があります"
|
||||
}
|
||||
|
||||
@ -1158,5 +1158,8 @@
|
||||
"log_analyzer_description": "Анализ ваших игровых записей (логин, смерти, убийства и другая информация)",
|
||||
"@log_analyzer_description": {},
|
||||
"log_analyzer_window_title": "SCToolbox: Анализатор логов",
|
||||
"@log_analyzer_window_title": {}
|
||||
"@log_analyzer_window_title": {},
|
||||
"settings_item_onnx_xnn_pack": "Использовать XNN для ускорения ONNX",
|
||||
"@settings_item_onnx_xnn_pack": {},
|
||||
"settings_item_onnx_xnn_pack_info": "Отключение этой опции может решить некоторые проблемы совместимости"
|
||||
}
|
||||
|
||||
@ -915,5 +915,8 @@
|
||||
"user_confirm_unregister_title": "确认注销",
|
||||
"user_confirm_unregister_message": "您确定要注销账户吗?此操作不可撤销,如需再次登录,需重新验证 RSI 账号。",
|
||||
"user_unregister_success": "账户注销成功",
|
||||
"user_unregister_failed": "账户注销失败"
|
||||
"user_unregister_failed": "账户注销失败",
|
||||
"settings_item_onnx_xnn_pack": "使用 XNN 加速 ONNX 推理",
|
||||
"@settings_item_onnx_xnn_pack": {},
|
||||
"settings_item_onnx_xnn_pack_info": "关闭此选项或许可以解决一些兼容问题"
|
||||
}
|
||||
@ -1180,5 +1180,8 @@
|
||||
"tools_vehicle_sorting_search": "搜索載具",
|
||||
"@tools_vehicle_sorting_search": {},
|
||||
"tools_vehicle_sorting_sorted": "已排序載具",
|
||||
"@tools_vehicle_sorting_sorted": {}
|
||||
"@tools_vehicle_sorting_sorted": {},
|
||||
"settings_item_onnx_xnn_pack": "使用 XNN 加速 ONNX 推理",
|
||||
"@settings_item_onnx_xnn_pack": {},
|
||||
"settings_item_onnx_xnn_pack_info": "關閉此選項或許可以解決一些相容性問題"
|
||||
}
|
||||
|
||||
@ -44,7 +44,7 @@ final class PartyRoomProvider
|
||||
}
|
||||
}
|
||||
|
||||
String _$partyRoomHash() => r'5640c173d0820c681f3bc68872a2ab4f2fa29285';
|
||||
String _$partyRoomHash() => r'2ce3ac365bec3af8f7e1d350b53262c8e4e2872d';
|
||||
|
||||
/// PartyRoom Provider
|
||||
|
||||
|
||||
@ -199,8 +199,13 @@ class InputMethodDialogUIModel extends _$InputMethodDialogUIModel {
|
||||
|
||||
String get _localTranslateModelDir => "${appGlobalState.applicationSupportDir}/onnx_models";
|
||||
|
||||
bool get _isEnableOnnxXnnPack {
|
||||
final userBox = Hive.box("app_conf");
|
||||
return userBox.get("isEnableOnnxXnnPack", defaultValue: true);
|
||||
}
|
||||
|
||||
OnnxTranslationProvider get _localTranslateModelProvider =>
|
||||
onnxTranslationProvider(_localTranslateModelDir, _localTranslateModelName);
|
||||
onnxTranslationProvider(_localTranslateModelDir, _localTranslateModelName, _isEnableOnnxXnnPack);
|
||||
|
||||
void _checkAutoTranslateOnInit() {
|
||||
// 检查模型文件是否存在,不存在则关闭自动翻译
|
||||
@ -333,20 +338,21 @@ class InputMethodDialogUIModel extends _$InputMethodDialogUIModel {
|
||||
@riverpod
|
||||
class OnnxTranslation extends _$OnnxTranslation {
|
||||
@override
|
||||
bool build(String modelDir, String modelName) {
|
||||
dPrint("[OnnxTranslation] Build provider for model: $modelName");
|
||||
bool build(String modelDir, String modelName, bool useXnnPack) {
|
||||
dPrint("[OnnxTranslation] Build provider for model: $modelName, useXnnPack: $useXnnPack");
|
||||
ref.onDispose(disposeModel);
|
||||
return false;
|
||||
}
|
||||
|
||||
Future<String?> initModel() async {
|
||||
dPrint("[OnnxTranslation] Load model: $modelName from $modelDir");
|
||||
dPrint("[OnnxTranslation] Load model: $modelName from $modelDir, useXnnPack: $useXnnPack");
|
||||
String? errorMessage;
|
||||
try {
|
||||
await ort.loadTranslationModel(
|
||||
modelPath: "$modelDir/$modelName",
|
||||
modelKey: modelName,
|
||||
quantizationSuffix: "_q4f16",
|
||||
useXnnpack: useXnnPack,
|
||||
);
|
||||
state = true;
|
||||
} catch (e) {
|
||||
|
||||
@ -43,7 +43,7 @@ final class InputMethodDialogUIModelProvider
|
||||
}
|
||||
|
||||
String _$inputMethodDialogUIModelHash() =>
|
||||
r'51f1708f22a90f7c2f879ad3d2a87a8e2f81b9e9';
|
||||
r'bd96c85ef2073d80de6eba71748b41adb8861e1c';
|
||||
|
||||
abstract class _$InputMethodDialogUIModel
|
||||
extends $Notifier<InputMethodDialogUIState> {
|
||||
@ -73,7 +73,7 @@ final class OnnxTranslationProvider
|
||||
extends $NotifierProvider<OnnxTranslation, bool> {
|
||||
const OnnxTranslationProvider._({
|
||||
required OnnxTranslationFamily super.from,
|
||||
required (String, String) super.argument,
|
||||
required (String, String, bool) super.argument,
|
||||
}) : super(
|
||||
retry: null,
|
||||
name: r'onnxTranslationProvider',
|
||||
@ -115,7 +115,7 @@ final class OnnxTranslationProvider
|
||||
}
|
||||
}
|
||||
|
||||
String _$onnxTranslationHash() => r'4f3dc0e361dca2d6b00f557496bdf006cc6c235c';
|
||||
String _$onnxTranslationHash() => r'd4946a47240ab42dd65c35fa3dda365e4c491462';
|
||||
|
||||
final class OnnxTranslationFamily extends $Family
|
||||
with
|
||||
@ -124,7 +124,7 @@ final class OnnxTranslationFamily extends $Family
|
||||
bool,
|
||||
bool,
|
||||
bool,
|
||||
(String, String)
|
||||
(String, String, bool)
|
||||
> {
|
||||
const OnnxTranslationFamily._()
|
||||
: super(
|
||||
@ -135,23 +135,30 @@ final class OnnxTranslationFamily extends $Family
|
||||
isAutoDispose: true,
|
||||
);
|
||||
|
||||
OnnxTranslationProvider call(String modelDir, String modelName) =>
|
||||
OnnxTranslationProvider._(argument: (modelDir, modelName), from: this);
|
||||
OnnxTranslationProvider call(
|
||||
String modelDir,
|
||||
String modelName,
|
||||
bool useXnnPack,
|
||||
) => OnnxTranslationProvider._(
|
||||
argument: (modelDir, modelName, useXnnPack),
|
||||
from: this,
|
||||
);
|
||||
|
||||
@override
|
||||
String toString() => r'onnxTranslationProvider';
|
||||
}
|
||||
|
||||
abstract class _$OnnxTranslation extends $Notifier<bool> {
|
||||
late final _$args = ref.$arg as (String, String);
|
||||
late final _$args = ref.$arg as (String, String, bool);
|
||||
String get modelDir => _$args.$1;
|
||||
String get modelName => _$args.$2;
|
||||
bool get useXnnPack => _$args.$3;
|
||||
|
||||
bool build(String modelDir, String modelName);
|
||||
bool build(String modelDir, String modelName, bool useXnnPack);
|
||||
@$mustCallSuper
|
||||
@override
|
||||
void runBuild() {
|
||||
final created = build(_$args.$1, _$args.$2);
|
||||
final created = build(_$args.$1, _$args.$2, _$args.$3);
|
||||
final ref = this.ref as $Ref<bool, bool>;
|
||||
final element =
|
||||
ref.element
|
||||
|
||||
@ -39,6 +39,13 @@ class SettingsUI extends HookConsumerWidget {
|
||||
onSwitch: model.onChangeUseInternalDNS,
|
||||
onTap: () => model.onChangeUseInternalDNS(!sate.isUseInternalDNS)),
|
||||
const SizedBox(height: 12),
|
||||
makeSettingsItem(const Icon(FluentIcons.processing, size: 20),
|
||||
S.current.settings_item_onnx_xnn_pack,
|
||||
subTitle: S.current.settings_item_onnx_xnn_pack_info,
|
||||
switchStatus: sate.isEnableOnnxXnnPack,
|
||||
onSwitch: model.onChangeOnnxXnnPack,
|
||||
onTap: () => model.onChangeOnnxXnnPack(!sate.isEnableOnnxXnnPack)),
|
||||
const SizedBox(height: 12),
|
||||
makeSettingsItem(const Icon(FluentIcons.delete, size: 20),
|
||||
S.current.setting_action_clear_translation_file_cache,
|
||||
subTitle: S.current.setting_action_info_cache_clearing_info(
|
||||
|
||||
@ -26,6 +26,7 @@ abstract class SettingsUIState with _$SettingsUIState {
|
||||
String? customGamePath,
|
||||
@Default(0) int locationCacheSize,
|
||||
@Default(false) bool isUseInternalDNS,
|
||||
@Default(true) bool isEnableOnnxXnnPack,
|
||||
}) = _SettingsUIState;
|
||||
}
|
||||
|
||||
@ -44,6 +45,7 @@ class SettingsUIModel extends _$SettingsUIModel {
|
||||
await _loadLocationCacheSize();
|
||||
await _loadToolSiteMirrorState();
|
||||
await _loadUseInternalDNS();
|
||||
await _loadOnnxXnnPackState();
|
||||
}
|
||||
|
||||
Future<void> setGameLaunchECore(BuildContext context) async {
|
||||
@ -227,4 +229,17 @@ class SettingsUIModel extends _$SettingsUIModel {
|
||||
userBox.get("isUseInternalDNS", defaultValue: false);
|
||||
state = state.copyWith(isUseInternalDNS: isUseInternalDNS);
|
||||
}
|
||||
|
||||
void onChangeOnnxXnnPack(bool? b) {
|
||||
final userBox = Hive.box("app_conf");
|
||||
userBox.put("isEnableOnnxXnnPack", b ?? true);
|
||||
_initState();
|
||||
}
|
||||
|
||||
Future _loadOnnxXnnPackState() async {
|
||||
final userBox = await Hive.openBox("app_conf");
|
||||
final isEnableOnnxXnnPack =
|
||||
userBox.get("isEnableOnnxXnnPack", defaultValue: true);
|
||||
state = state.copyWith(isEnableOnnxXnnPack: isEnableOnnxXnnPack);
|
||||
}
|
||||
}
|
||||
|
||||
@ -14,7 +14,7 @@ T _$identity<T>(T value) => value;
|
||||
/// @nodoc
|
||||
mixin _$SettingsUIState {
|
||||
|
||||
bool get isEnableToolSiteMirrors; String get inputGameLaunchECore; String? get customLauncherPath; String? get customGamePath; int get locationCacheSize; bool get isUseInternalDNS;
|
||||
bool get isEnableToolSiteMirrors; String get inputGameLaunchECore; String? get customLauncherPath; String? get customGamePath; int get locationCacheSize; bool get isUseInternalDNS; bool get isEnableOnnxXnnPack;
|
||||
/// Create a copy of SettingsUIState
|
||||
/// with the given fields replaced by the non-null parameter values.
|
||||
@JsonKey(includeFromJson: false, includeToJson: false)
|
||||
@ -25,16 +25,16 @@ $SettingsUIStateCopyWith<SettingsUIState> get copyWith => _$SettingsUIStateCopyW
|
||||
|
||||
@override
|
||||
bool operator ==(Object other) {
|
||||
return identical(this, other) || (other.runtimeType == runtimeType&&other is SettingsUIState&&(identical(other.isEnableToolSiteMirrors, isEnableToolSiteMirrors) || other.isEnableToolSiteMirrors == isEnableToolSiteMirrors)&&(identical(other.inputGameLaunchECore, inputGameLaunchECore) || other.inputGameLaunchECore == inputGameLaunchECore)&&(identical(other.customLauncherPath, customLauncherPath) || other.customLauncherPath == customLauncherPath)&&(identical(other.customGamePath, customGamePath) || other.customGamePath == customGamePath)&&(identical(other.locationCacheSize, locationCacheSize) || other.locationCacheSize == locationCacheSize)&&(identical(other.isUseInternalDNS, isUseInternalDNS) || other.isUseInternalDNS == isUseInternalDNS));
|
||||
return identical(this, other) || (other.runtimeType == runtimeType&&other is SettingsUIState&&(identical(other.isEnableToolSiteMirrors, isEnableToolSiteMirrors) || other.isEnableToolSiteMirrors == isEnableToolSiteMirrors)&&(identical(other.inputGameLaunchECore, inputGameLaunchECore) || other.inputGameLaunchECore == inputGameLaunchECore)&&(identical(other.customLauncherPath, customLauncherPath) || other.customLauncherPath == customLauncherPath)&&(identical(other.customGamePath, customGamePath) || other.customGamePath == customGamePath)&&(identical(other.locationCacheSize, locationCacheSize) || other.locationCacheSize == locationCacheSize)&&(identical(other.isUseInternalDNS, isUseInternalDNS) || other.isUseInternalDNS == isUseInternalDNS)&&(identical(other.isEnableOnnxXnnPack, isEnableOnnxXnnPack) || other.isEnableOnnxXnnPack == isEnableOnnxXnnPack));
|
||||
}
|
||||
|
||||
|
||||
@override
|
||||
int get hashCode => Object.hash(runtimeType,isEnableToolSiteMirrors,inputGameLaunchECore,customLauncherPath,customGamePath,locationCacheSize,isUseInternalDNS);
|
||||
int get hashCode => Object.hash(runtimeType,isEnableToolSiteMirrors,inputGameLaunchECore,customLauncherPath,customGamePath,locationCacheSize,isUseInternalDNS,isEnableOnnxXnnPack);
|
||||
|
||||
@override
|
||||
String toString() {
|
||||
return 'SettingsUIState(isEnableToolSiteMirrors: $isEnableToolSiteMirrors, inputGameLaunchECore: $inputGameLaunchECore, customLauncherPath: $customLauncherPath, customGamePath: $customGamePath, locationCacheSize: $locationCacheSize, isUseInternalDNS: $isUseInternalDNS)';
|
||||
return 'SettingsUIState(isEnableToolSiteMirrors: $isEnableToolSiteMirrors, inputGameLaunchECore: $inputGameLaunchECore, customLauncherPath: $customLauncherPath, customGamePath: $customGamePath, locationCacheSize: $locationCacheSize, isUseInternalDNS: $isUseInternalDNS, isEnableOnnxXnnPack: $isEnableOnnxXnnPack)';
|
||||
}
|
||||
|
||||
|
||||
@ -45,7 +45,7 @@ abstract mixin class $SettingsUIStateCopyWith<$Res> {
|
||||
factory $SettingsUIStateCopyWith(SettingsUIState value, $Res Function(SettingsUIState) _then) = _$SettingsUIStateCopyWithImpl;
|
||||
@useResult
|
||||
$Res call({
|
||||
bool isEnableToolSiteMirrors, String inputGameLaunchECore, String? customLauncherPath, String? customGamePath, int locationCacheSize, bool isUseInternalDNS
|
||||
bool isEnableToolSiteMirrors, String inputGameLaunchECore, String? customLauncherPath, String? customGamePath, int locationCacheSize, bool isUseInternalDNS, bool isEnableOnnxXnnPack
|
||||
});
|
||||
|
||||
|
||||
@ -62,7 +62,7 @@ class _$SettingsUIStateCopyWithImpl<$Res>
|
||||
|
||||
/// Create a copy of SettingsUIState
|
||||
/// with the given fields replaced by the non-null parameter values.
|
||||
@pragma('vm:prefer-inline') @override $Res call({Object? isEnableToolSiteMirrors = null,Object? inputGameLaunchECore = null,Object? customLauncherPath = freezed,Object? customGamePath = freezed,Object? locationCacheSize = null,Object? isUseInternalDNS = null,}) {
|
||||
@pragma('vm:prefer-inline') @override $Res call({Object? isEnableToolSiteMirrors = null,Object? inputGameLaunchECore = null,Object? customLauncherPath = freezed,Object? customGamePath = freezed,Object? locationCacheSize = null,Object? isUseInternalDNS = null,Object? isEnableOnnxXnnPack = null,}) {
|
||||
return _then(_self.copyWith(
|
||||
isEnableToolSiteMirrors: null == isEnableToolSiteMirrors ? _self.isEnableToolSiteMirrors : isEnableToolSiteMirrors // ignore: cast_nullable_to_non_nullable
|
||||
as bool,inputGameLaunchECore: null == inputGameLaunchECore ? _self.inputGameLaunchECore : inputGameLaunchECore // ignore: cast_nullable_to_non_nullable
|
||||
@ -70,6 +70,7 @@ as String,customLauncherPath: freezed == customLauncherPath ? _self.customLaunch
|
||||
as String?,customGamePath: freezed == customGamePath ? _self.customGamePath : customGamePath // ignore: cast_nullable_to_non_nullable
|
||||
as String?,locationCacheSize: null == locationCacheSize ? _self.locationCacheSize : locationCacheSize // ignore: cast_nullable_to_non_nullable
|
||||
as int,isUseInternalDNS: null == isUseInternalDNS ? _self.isUseInternalDNS : isUseInternalDNS // ignore: cast_nullable_to_non_nullable
|
||||
as bool,isEnableOnnxXnnPack: null == isEnableOnnxXnnPack ? _self.isEnableOnnxXnnPack : isEnableOnnxXnnPack // ignore: cast_nullable_to_non_nullable
|
||||
as bool,
|
||||
));
|
||||
}
|
||||
@ -155,10 +156,10 @@ return $default(_that);case _:
|
||||
/// }
|
||||
/// ```
|
||||
|
||||
@optionalTypeArgs TResult maybeWhen<TResult extends Object?>(TResult Function( bool isEnableToolSiteMirrors, String inputGameLaunchECore, String? customLauncherPath, String? customGamePath, int locationCacheSize, bool isUseInternalDNS)? $default,{required TResult orElse(),}) {final _that = this;
|
||||
@optionalTypeArgs TResult maybeWhen<TResult extends Object?>(TResult Function( bool isEnableToolSiteMirrors, String inputGameLaunchECore, String? customLauncherPath, String? customGamePath, int locationCacheSize, bool isUseInternalDNS, bool isEnableOnnxXnnPack)? $default,{required TResult orElse(),}) {final _that = this;
|
||||
switch (_that) {
|
||||
case _SettingsUIState() when $default != null:
|
||||
return $default(_that.isEnableToolSiteMirrors,_that.inputGameLaunchECore,_that.customLauncherPath,_that.customGamePath,_that.locationCacheSize,_that.isUseInternalDNS);case _:
|
||||
return $default(_that.isEnableToolSiteMirrors,_that.inputGameLaunchECore,_that.customLauncherPath,_that.customGamePath,_that.locationCacheSize,_that.isUseInternalDNS,_that.isEnableOnnxXnnPack);case _:
|
||||
return orElse();
|
||||
|
||||
}
|
||||
@ -176,10 +177,10 @@ return $default(_that.isEnableToolSiteMirrors,_that.inputGameLaunchECore,_that.c
|
||||
/// }
|
||||
/// ```
|
||||
|
||||
@optionalTypeArgs TResult when<TResult extends Object?>(TResult Function( bool isEnableToolSiteMirrors, String inputGameLaunchECore, String? customLauncherPath, String? customGamePath, int locationCacheSize, bool isUseInternalDNS) $default,) {final _that = this;
|
||||
@optionalTypeArgs TResult when<TResult extends Object?>(TResult Function( bool isEnableToolSiteMirrors, String inputGameLaunchECore, String? customLauncherPath, String? customGamePath, int locationCacheSize, bool isUseInternalDNS, bool isEnableOnnxXnnPack) $default,) {final _that = this;
|
||||
switch (_that) {
|
||||
case _SettingsUIState():
|
||||
return $default(_that.isEnableToolSiteMirrors,_that.inputGameLaunchECore,_that.customLauncherPath,_that.customGamePath,_that.locationCacheSize,_that.isUseInternalDNS);case _:
|
||||
return $default(_that.isEnableToolSiteMirrors,_that.inputGameLaunchECore,_that.customLauncherPath,_that.customGamePath,_that.locationCacheSize,_that.isUseInternalDNS,_that.isEnableOnnxXnnPack);case _:
|
||||
throw StateError('Unexpected subclass');
|
||||
|
||||
}
|
||||
@ -196,10 +197,10 @@ return $default(_that.isEnableToolSiteMirrors,_that.inputGameLaunchECore,_that.c
|
||||
/// }
|
||||
/// ```
|
||||
|
||||
@optionalTypeArgs TResult? whenOrNull<TResult extends Object?>(TResult? Function( bool isEnableToolSiteMirrors, String inputGameLaunchECore, String? customLauncherPath, String? customGamePath, int locationCacheSize, bool isUseInternalDNS)? $default,) {final _that = this;
|
||||
@optionalTypeArgs TResult? whenOrNull<TResult extends Object?>(TResult? Function( bool isEnableToolSiteMirrors, String inputGameLaunchECore, String? customLauncherPath, String? customGamePath, int locationCacheSize, bool isUseInternalDNS, bool isEnableOnnxXnnPack)? $default,) {final _that = this;
|
||||
switch (_that) {
|
||||
case _SettingsUIState() when $default != null:
|
||||
return $default(_that.isEnableToolSiteMirrors,_that.inputGameLaunchECore,_that.customLauncherPath,_that.customGamePath,_that.locationCacheSize,_that.isUseInternalDNS);case _:
|
||||
return $default(_that.isEnableToolSiteMirrors,_that.inputGameLaunchECore,_that.customLauncherPath,_that.customGamePath,_that.locationCacheSize,_that.isUseInternalDNS,_that.isEnableOnnxXnnPack);case _:
|
||||
return null;
|
||||
|
||||
}
|
||||
@ -211,7 +212,7 @@ return $default(_that.isEnableToolSiteMirrors,_that.inputGameLaunchECore,_that.c
|
||||
|
||||
|
||||
class _SettingsUIState implements SettingsUIState {
|
||||
_SettingsUIState({this.isEnableToolSiteMirrors = false, this.inputGameLaunchECore = "0", this.customLauncherPath, this.customGamePath, this.locationCacheSize = 0, this.isUseInternalDNS = false});
|
||||
_SettingsUIState({this.isEnableToolSiteMirrors = false, this.inputGameLaunchECore = "0", this.customLauncherPath, this.customGamePath, this.locationCacheSize = 0, this.isUseInternalDNS = false, this.isEnableOnnxXnnPack = true});
|
||||
|
||||
|
||||
@override@JsonKey() final bool isEnableToolSiteMirrors;
|
||||
@ -220,6 +221,7 @@ class _SettingsUIState implements SettingsUIState {
|
||||
@override final String? customGamePath;
|
||||
@override@JsonKey() final int locationCacheSize;
|
||||
@override@JsonKey() final bool isUseInternalDNS;
|
||||
@override@JsonKey() final bool isEnableOnnxXnnPack;
|
||||
|
||||
/// Create a copy of SettingsUIState
|
||||
/// with the given fields replaced by the non-null parameter values.
|
||||
@ -231,16 +233,16 @@ _$SettingsUIStateCopyWith<_SettingsUIState> get copyWith => __$SettingsUIStateCo
|
||||
|
||||
@override
|
||||
bool operator ==(Object other) {
|
||||
return identical(this, other) || (other.runtimeType == runtimeType&&other is _SettingsUIState&&(identical(other.isEnableToolSiteMirrors, isEnableToolSiteMirrors) || other.isEnableToolSiteMirrors == isEnableToolSiteMirrors)&&(identical(other.inputGameLaunchECore, inputGameLaunchECore) || other.inputGameLaunchECore == inputGameLaunchECore)&&(identical(other.customLauncherPath, customLauncherPath) || other.customLauncherPath == customLauncherPath)&&(identical(other.customGamePath, customGamePath) || other.customGamePath == customGamePath)&&(identical(other.locationCacheSize, locationCacheSize) || other.locationCacheSize == locationCacheSize)&&(identical(other.isUseInternalDNS, isUseInternalDNS) || other.isUseInternalDNS == isUseInternalDNS));
|
||||
return identical(this, other) || (other.runtimeType == runtimeType&&other is _SettingsUIState&&(identical(other.isEnableToolSiteMirrors, isEnableToolSiteMirrors) || other.isEnableToolSiteMirrors == isEnableToolSiteMirrors)&&(identical(other.inputGameLaunchECore, inputGameLaunchECore) || other.inputGameLaunchECore == inputGameLaunchECore)&&(identical(other.customLauncherPath, customLauncherPath) || other.customLauncherPath == customLauncherPath)&&(identical(other.customGamePath, customGamePath) || other.customGamePath == customGamePath)&&(identical(other.locationCacheSize, locationCacheSize) || other.locationCacheSize == locationCacheSize)&&(identical(other.isUseInternalDNS, isUseInternalDNS) || other.isUseInternalDNS == isUseInternalDNS)&&(identical(other.isEnableOnnxXnnPack, isEnableOnnxXnnPack) || other.isEnableOnnxXnnPack == isEnableOnnxXnnPack));
|
||||
}
|
||||
|
||||
|
||||
@override
|
||||
int get hashCode => Object.hash(runtimeType,isEnableToolSiteMirrors,inputGameLaunchECore,customLauncherPath,customGamePath,locationCacheSize,isUseInternalDNS);
|
||||
int get hashCode => Object.hash(runtimeType,isEnableToolSiteMirrors,inputGameLaunchECore,customLauncherPath,customGamePath,locationCacheSize,isUseInternalDNS,isEnableOnnxXnnPack);
|
||||
|
||||
@override
|
||||
String toString() {
|
||||
return 'SettingsUIState(isEnableToolSiteMirrors: $isEnableToolSiteMirrors, inputGameLaunchECore: $inputGameLaunchECore, customLauncherPath: $customLauncherPath, customGamePath: $customGamePath, locationCacheSize: $locationCacheSize, isUseInternalDNS: $isUseInternalDNS)';
|
||||
return 'SettingsUIState(isEnableToolSiteMirrors: $isEnableToolSiteMirrors, inputGameLaunchECore: $inputGameLaunchECore, customLauncherPath: $customLauncherPath, customGamePath: $customGamePath, locationCacheSize: $locationCacheSize, isUseInternalDNS: $isUseInternalDNS, isEnableOnnxXnnPack: $isEnableOnnxXnnPack)';
|
||||
}
|
||||
|
||||
|
||||
@ -251,7 +253,7 @@ abstract mixin class _$SettingsUIStateCopyWith<$Res> implements $SettingsUIState
|
||||
factory _$SettingsUIStateCopyWith(_SettingsUIState value, $Res Function(_SettingsUIState) _then) = __$SettingsUIStateCopyWithImpl;
|
||||
@override @useResult
|
||||
$Res call({
|
||||
bool isEnableToolSiteMirrors, String inputGameLaunchECore, String? customLauncherPath, String? customGamePath, int locationCacheSize, bool isUseInternalDNS
|
||||
bool isEnableToolSiteMirrors, String inputGameLaunchECore, String? customLauncherPath, String? customGamePath, int locationCacheSize, bool isUseInternalDNS, bool isEnableOnnxXnnPack
|
||||
});
|
||||
|
||||
|
||||
@ -268,7 +270,7 @@ class __$SettingsUIStateCopyWithImpl<$Res>
|
||||
|
||||
/// Create a copy of SettingsUIState
|
||||
/// with the given fields replaced by the non-null parameter values.
|
||||
@override @pragma('vm:prefer-inline') $Res call({Object? isEnableToolSiteMirrors = null,Object? inputGameLaunchECore = null,Object? customLauncherPath = freezed,Object? customGamePath = freezed,Object? locationCacheSize = null,Object? isUseInternalDNS = null,}) {
|
||||
@override @pragma('vm:prefer-inline') $Res call({Object? isEnableToolSiteMirrors = null,Object? inputGameLaunchECore = null,Object? customLauncherPath = freezed,Object? customGamePath = freezed,Object? locationCacheSize = null,Object? isUseInternalDNS = null,Object? isEnableOnnxXnnPack = null,}) {
|
||||
return _then(_SettingsUIState(
|
||||
isEnableToolSiteMirrors: null == isEnableToolSiteMirrors ? _self.isEnableToolSiteMirrors : isEnableToolSiteMirrors // ignore: cast_nullable_to_non_nullable
|
||||
as bool,inputGameLaunchECore: null == inputGameLaunchECore ? _self.inputGameLaunchECore : inputGameLaunchECore // ignore: cast_nullable_to_non_nullable
|
||||
@ -276,6 +278,7 @@ as String,customLauncherPath: freezed == customLauncherPath ? _self.customLaunch
|
||||
as String?,customGamePath: freezed == customGamePath ? _self.customGamePath : customGamePath // ignore: cast_nullable_to_non_nullable
|
||||
as String?,locationCacheSize: null == locationCacheSize ? _self.locationCacheSize : locationCacheSize // ignore: cast_nullable_to_non_nullable
|
||||
as int,isUseInternalDNS: null == isUseInternalDNS ? _self.isUseInternalDNS : isUseInternalDNS // ignore: cast_nullable_to_non_nullable
|
||||
as bool,isEnableOnnxXnnPack: null == isEnableOnnxXnnPack ? _self.isEnableOnnxXnnPack : isEnableOnnxXnnPack // ignore: cast_nullable_to_non_nullable
|
||||
as bool,
|
||||
));
|
||||
}
|
||||
|
||||
@ -41,7 +41,7 @@ final class SettingsUIModelProvider
|
||||
}
|
||||
}
|
||||
|
||||
String _$settingsUIModelHash() => r'd19104d924f018a9230548d0372692fc344adacd';
|
||||
String _$settingsUIModelHash() => r'72947d5ed36290df865cb010b056dc632f5dccec';
|
||||
|
||||
abstract class _$SettingsUIModel extends $Notifier<SettingsUIState> {
|
||||
SettingsUIState build();
|
||||
|
||||
@ -15,13 +15,15 @@ static MODEL_CACHE: Lazy<Mutex<HashMap<String, OpusMtModel>>> =
|
||||
/// * `model_path` - 模型文件夹路径
|
||||
/// * `model_key` - 模型缓存键(用于标识模型,如 "zh-en")
|
||||
/// * `quantization_suffix` - 量化后缀(如 "_q4", "_q8",空字符串表示使用默认模型)
|
||||
/// * `use_xnnpack` - 是否使用 XNNPACK 加速
|
||||
///
|
||||
pub fn load_translation_model(
|
||||
model_path: String,
|
||||
model_key: String,
|
||||
quantization_suffix: String,
|
||||
use_xnnpack: bool,
|
||||
) -> Result<()> {
|
||||
let model = OpusMtModel::new(&model_path, &quantization_suffix)?;
|
||||
let model = OpusMtModel::new(&model_path, &quantization_suffix, use_xnnpack)?;
|
||||
|
||||
let mut cache = MODEL_CACHE
|
||||
.lock()
|
||||
|
||||
@ -262,6 +262,7 @@ fn wire__crate__api__ort_api__load_translation_model_impl(
|
||||
model_path: impl CstDecode<String>,
|
||||
model_key: impl CstDecode<String>,
|
||||
quantization_suffix: impl CstDecode<String>,
|
||||
use_xnnpack: impl CstDecode<bool>,
|
||||
) {
|
||||
FLUTTER_RUST_BRIDGE_HANDLER.wrap_normal::<flutter_rust_bridge::for_generated::DcoCodec, _, _>(
|
||||
flutter_rust_bridge::for_generated::TaskInfo {
|
||||
@ -273,6 +274,7 @@ fn wire__crate__api__ort_api__load_translation_model_impl(
|
||||
let api_model_path = model_path.cst_decode();
|
||||
let api_model_key = model_key.cst_decode();
|
||||
let api_quantization_suffix = quantization_suffix.cst_decode();
|
||||
let api_use_xnnpack = use_xnnpack.cst_decode();
|
||||
move |context| {
|
||||
transform_result_dco::<_, _, flutter_rust_bridge::for_generated::anyhow::Error>(
|
||||
(move || {
|
||||
@ -280,6 +282,7 @@ fn wire__crate__api__ort_api__load_translation_model_impl(
|
||||
api_model_path,
|
||||
api_model_key,
|
||||
api_quantization_suffix,
|
||||
api_use_xnnpack,
|
||||
)?;
|
||||
Ok(output_ok)
|
||||
})(),
|
||||
@ -1750,12 +1753,14 @@ mod io {
|
||||
model_path: *mut wire_cst_list_prim_u_8_strict,
|
||||
model_key: *mut wire_cst_list_prim_u_8_strict,
|
||||
quantization_suffix: *mut wire_cst_list_prim_u_8_strict,
|
||||
use_xnnpack: bool,
|
||||
) {
|
||||
wire__crate__api__ort_api__load_translation_model_impl(
|
||||
port_,
|
||||
model_path,
|
||||
model_key,
|
||||
quantization_suffix,
|
||||
use_xnnpack,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@ -38,16 +38,42 @@ impl Default for ModelConfig {
|
||||
}
|
||||
}
|
||||
|
||||
/// 创建 ONNX 会话的辅助函数
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `model_path` - 模型文件路径
|
||||
/// * `model_name` - 模型名称(用于错误消息)
|
||||
/// * `use_xnnpack` - 是否使用 XNNPACK 加速
|
||||
fn create_session(model_path: &Path, model_name: &str, use_xnnpack: bool) -> Result<Session> {
|
||||
let mut builder = Session::builder()
|
||||
.context(format!("Failed to create {} session builder", model_name))?
|
||||
.with_optimization_level(GraphOptimizationLevel::Level3)
|
||||
.context("Failed to set optimization level")?
|
||||
.with_intra_threads(4)
|
||||
.context("Failed to set intra threads")?;
|
||||
|
||||
if use_xnnpack {
|
||||
builder = builder
|
||||
.with_execution_providers([XNNPACKExecutionProvider::default().build()])
|
||||
.context("Failed to register XNNPACK execution provider")?;
|
||||
}
|
||||
|
||||
builder
|
||||
.commit_from_file(model_path)
|
||||
.context(format!("Failed to load {} model", model_name))
|
||||
}
|
||||
|
||||
impl OpusMtModel {
|
||||
/// 从模型路径创建新的 OpusMT 模型实例
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `model_path` - 模型文件夹路径(应包含 onnx 子文件夹)
|
||||
/// * `quantization_suffix` - 量化后缀,如 "_q4", "_q8",为空字符串则使用默认模型
|
||||
/// * `use_xnnpack` - 是否使用 XNNPACK 加速
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Result<Self>` - 成功返回模型实例,失败返回错误
|
||||
pub fn new<P: AsRef<Path>>(model_path: P, quantization_suffix: &str) -> Result<Self> {
|
||||
pub fn new<P: AsRef<Path>>(model_path: P, quantization_suffix: &str, use_xnnpack: bool) -> Result<Self> {
|
||||
let model_path = model_path.as_ref();
|
||||
|
||||
// onnx-community 标准:模型在 onnx 子文件夹中
|
||||
@ -82,19 +108,7 @@ impl OpusMtModel {
|
||||
));
|
||||
}
|
||||
|
||||
let encoder_session = Session::builder()
|
||||
.context("Failed to create encoder session builder")?
|
||||
.with_optimization_level(GraphOptimizationLevel::Level3)
|
||||
.context("Failed to set optimization level")?
|
||||
.with_intra_threads(4)
|
||||
.context("Failed to set intra threads")?
|
||||
.with_execution_providers([XNNPACKExecutionProvider::default().build()])
|
||||
.context("Failed to register XNNPACK execution provider")?
|
||||
.commit_from_file(&encoder_path)
|
||||
.context(format!(
|
||||
"Failed to load encoder model: {}",
|
||||
encoder_filename
|
||||
))?;
|
||||
let encoder_session = create_session(&encoder_path, "encoder", use_xnnpack)?;
|
||||
|
||||
// 加载 decoder 模型(在 onnx 子目录)
|
||||
let decoder_path = onnx_dir.join(&decoder_filename);
|
||||
@ -105,19 +119,7 @@ impl OpusMtModel {
|
||||
));
|
||||
}
|
||||
|
||||
let decoder_session = Session::builder()
|
||||
.context("Failed to create decoder session builder")?
|
||||
.with_optimization_level(GraphOptimizationLevel::Level3)
|
||||
.context("Failed to set optimization level")?
|
||||
.with_intra_threads(4)
|
||||
.context("Failed to set intra threads")?
|
||||
.with_execution_providers([XNNPACKExecutionProvider::default().build()])
|
||||
.context("Failed to register XNNPACK execution provider")?
|
||||
.commit_from_file(&decoder_path)
|
||||
.context(format!(
|
||||
"Failed to load decoder model: {}",
|
||||
decoder_filename
|
||||
))?;
|
||||
let decoder_session = create_session(&decoder_path, "decoder", use_xnnpack)?;
|
||||
|
||||
// 加载配置(如果存在,在根目录)
|
||||
let config = Self::load_config(model_path)?;
|
||||
@ -395,6 +397,7 @@ mod tests {
|
||||
let model = OpusMtModel::new(
|
||||
"E:\\Project\\StarCtizen\\Opus-MT-StarCitizen\\results\\final_model",
|
||||
"_q4f16",
|
||||
true,
|
||||
)
|
||||
.unwrap();
|
||||
let result = model.translate("北极星要炸了,快撤!").unwrap();
|
||||
|
||||
Loading…
Reference in New Issue
Block a user