Add XNN Pack toggle switch for ONNX inference acceleration (#155)

* Initial plan

* Add XNN Pack switch for ONNX inference acceleration

Co-authored-by: xkeyC <39891083+xkeyC@users.noreply.github.com>

* Refactor Rust ONNX session creation to reduce code duplication

Co-authored-by: xkeyC <39891083+xkeyC@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: xkeyC <39891083+xkeyC@users.noreply.github.com>
This commit is contained in:
Copilot
2025-11-28 21:23:31 +08:00
committed by GitHub
parent db024f19bd
commit db89100402
26 changed files with 197 additions and 74 deletions

View File

@@ -40,7 +40,7 @@ Future<List<String>> dnsLookupIps({required String host}) =>
/// # Arguments
/// * `urls` - List of base URLs to test
/// * `path_suffix` - Optional path suffix to append to each URL (e.g., "/api/version")
/// If None, tests the base URL directly
/// If None, tests the base URL directly
Future<String?> getFasterUrl({
required List<String> urls,
String? pathSuffix,

View File

@@ -12,15 +12,18 @@ import 'package:flutter_rust_bridge/flutter_rust_bridge_for_generated.dart';
/// * `model_path` - 模型文件夹路径
/// * `model_key` - 模型缓存键(用于标识模型,如 "zh-en"
/// * `quantization_suffix` - 量化后缀(如 "_q4", "_q8",空字符串表示使用默认模型)
/// * `use_xnnpack` - 是否使用 XNNPACK 加速
///
Future<void> loadTranslationModel({
required String modelPath,
required String modelKey,
required String quantizationSuffix,
required bool useXnnpack,
}) => RustLib.instance.api.crateApiOrtApiLoadTranslationModel(
modelPath: modelPath,
modelKey: modelKey,
quantizationSuffix: quantizationSuffix,
useXnnpack: useXnnpack,
);
/// 翻译文本

View File

@@ -116,6 +116,7 @@ abstract class RustLibApi extends BaseApi {
required String modelPath,
required String modelKey,
required String quantizationSuffix,
required bool useXnnpack,
});
Future<void> crateApiAsarApiRsiLauncherAsarDataWriteMainJs({
@@ -421,6 +422,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi {
required String modelPath,
required String modelKey,
required String quantizationSuffix,
required bool useXnnpack,
}) {
return handler.executeNormal(
NormalTask(
@@ -428,11 +430,13 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi {
var arg0 = cst_encode_String(modelPath);
var arg1 = cst_encode_String(modelKey);
var arg2 = cst_encode_String(quantizationSuffix);
var arg3 = cst_encode_bool(useXnnpack);
return wire.wire__crate__api__ort_api__load_translation_model(
port_,
arg0,
arg1,
arg2,
arg3,
);
},
codec: DcoCodec(
@@ -440,7 +444,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi {
decodeErrorData: dco_decode_AnyhowException,
),
constMeta: kCrateApiOrtApiLoadTranslationModelConstMeta,
argValues: [modelPath, modelKey, quantizationSuffix],
argValues: [modelPath, modelKey, quantizationSuffix, useXnnpack],
apiImpl: this,
),
);
@@ -449,7 +453,7 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi {
TaskConstMeta get kCrateApiOrtApiLoadTranslationModelConstMeta =>
const TaskConstMeta(
debugName: "load_translation_model",
argNames: ["modelPath", "modelKey", "quantizationSuffix"],
argNames: ["modelPath", "modelKey", "quantizationSuffix", "useXnnpack"],
);
@override

View File

@@ -880,12 +880,14 @@ class RustLibWire implements BaseWire {
ffi.Pointer<wire_cst_list_prim_u_8_strict> model_path,
ffi.Pointer<wire_cst_list_prim_u_8_strict> model_key,
ffi.Pointer<wire_cst_list_prim_u_8_strict> quantization_suffix,
bool use_xnnpack,
) {
return _wire__crate__api__ort_api__load_translation_model(
port_,
model_path,
model_key,
quantization_suffix,
use_xnnpack,
);
}
@@ -897,6 +899,7 @@ class RustLibWire implements BaseWire {
ffi.Pointer<wire_cst_list_prim_u_8_strict>,
ffi.Pointer<wire_cst_list_prim_u_8_strict>,
ffi.Pointer<wire_cst_list_prim_u_8_strict>,
ffi.Bool,
)
>
>(
@@ -910,6 +913,7 @@ class RustLibWire implements BaseWire {
ffi.Pointer<wire_cst_list_prim_u_8_strict>,
ffi.Pointer<wire_cst_list_prim_u_8_strict>,
ffi.Pointer<wire_cst_list_prim_u_8_strict>,
bool,
)
>();