Add XNN Pack toggle switch for ONNX inference acceleration (#155)

* Initial plan

* Add XNN Pack switch for ONNX inference acceleration

Co-authored-by: xkeyC <39891083+xkeyC@users.noreply.github.com>

* Refactor Rust ONNX session creation to reduce code duplication

Co-authored-by: xkeyC <39891083+xkeyC@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: xkeyC <39891083+xkeyC@users.noreply.github.com>
This commit is contained in:
Copilot
2025-11-28 21:23:31 +08:00
committed by GitHub
parent db024f19bd
commit db89100402
26 changed files with 197 additions and 74 deletions

View File

@@ -199,8 +199,13 @@ class InputMethodDialogUIModel extends _$InputMethodDialogUIModel {
String get _localTranslateModelDir => "${appGlobalState.applicationSupportDir}/onnx_models";
/// Whether the user has enabled XNNPACK for ONNX inference.
///
/// Reads the `isEnableOnnxXnnPack` key from the `app_conf` Hive box and falls
/// back to `true` when the key has never been written, so the switch is on by
/// default.
// NOTE(review): assumes the `app_conf` box is already open and that the stored
// value is a bool — confirm against the code that opens and writes this box.
bool get _isEnableOnnxXnnPack {
final userBox = Hive.box("app_conf");
return userBox.get("isEnableOnnxXnnPack", defaultValue: true);
}
OnnxTranslationProvider get _localTranslateModelProvider =>
onnxTranslationProvider(_localTranslateModelDir, _localTranslateModelName);
onnxTranslationProvider(_localTranslateModelDir, _localTranslateModelName, _isEnableOnnxXnnPack);
void _checkAutoTranslateOnInit() {
// 检查模型文件是否存在,不存在则关闭自动翻译
@@ -333,20 +338,21 @@ class InputMethodDialogUIModel extends _$InputMethodDialogUIModel {
@riverpod
class OnnxTranslation extends _$OnnxTranslation {
@override
bool build(String modelDir, String modelName) {
dPrint("[OnnxTranslation] Build provider for model: $modelName");
bool build(String modelDir, String modelName, bool useXnnPack) {
dPrint("[OnnxTranslation] Build provider for model: $modelName, useXnnPack: $useXnnPack");
ref.onDispose(disposeModel);
return false;
}
Future<String?> initModel() async {
dPrint("[OnnxTranslation] Load model: $modelName from $modelDir");
dPrint("[OnnxTranslation] Load model: $modelName from $modelDir, useXnnPack: $useXnnPack");
String? errorMessage;
try {
await ort.loadTranslationModel(
modelPath: "$modelDir/$modelName",
modelKey: modelName,
quantizationSuffix: "_q4f16",
useXnnpack: useXnnPack,
);
state = true;
} catch (e) {

View File

@@ -43,7 +43,7 @@ final class InputMethodDialogUIModelProvider
}
String _$inputMethodDialogUIModelHash() =>
r'51f1708f22a90f7c2f879ad3d2a87a8e2f81b9e9';
r'bd96c85ef2073d80de6eba71748b41adb8861e1c';
abstract class _$InputMethodDialogUIModel
extends $Notifier<InputMethodDialogUIState> {
@@ -73,7 +73,7 @@ final class OnnxTranslationProvider
extends $NotifierProvider<OnnxTranslation, bool> {
const OnnxTranslationProvider._({
required OnnxTranslationFamily super.from,
required (String, String) super.argument,
required (String, String, bool) super.argument,
}) : super(
retry: null,
name: r'onnxTranslationProvider',
@@ -115,7 +115,7 @@ final class OnnxTranslationProvider
}
}
String _$onnxTranslationHash() => r'4f3dc0e361dca2d6b00f557496bdf006cc6c235c';
String _$onnxTranslationHash() => r'd4946a47240ab42dd65c35fa3dda365e4c491462';
final class OnnxTranslationFamily extends $Family
with
@@ -124,7 +124,7 @@ final class OnnxTranslationFamily extends $Family
bool,
bool,
bool,
(String, String)
(String, String, bool)
> {
const OnnxTranslationFamily._()
: super(
@@ -135,23 +135,30 @@ final class OnnxTranslationFamily extends $Family
isAutoDispose: true,
);
OnnxTranslationProvider call(String modelDir, String modelName) =>
OnnxTranslationProvider._(argument: (modelDir, modelName), from: this);
OnnxTranslationProvider call(
String modelDir,
String modelName,
bool useXnnPack,
) => OnnxTranslationProvider._(
argument: (modelDir, modelName, useXnnPack),
from: this,
);
@override
String toString() => r'onnxTranslationProvider';
}
abstract class _$OnnxTranslation extends $Notifier<bool> {
late final _$args = ref.$arg as (String, String);
late final _$args = ref.$arg as (String, String, bool);
String get modelDir => _$args.$1;
String get modelName => _$args.$2;
bool get useXnnPack => _$args.$3;
bool build(String modelDir, String modelName);
bool build(String modelDir, String modelName, bool useXnnPack);
@$mustCallSuper
@override
void runBuild() {
final created = build(_$args.$1, _$args.$2);
final created = build(_$args.$1, _$args.$2, _$args.$3);
final ref = this.ref as $Ref<bool, bool>;
final element =
ref.element