Add XNNPACK toggle switch for ONNX inference acceleration (#155)

* Initial plan

* Add XNNPACK switch for ONNX inference acceleration

Co-authored-by: xkeyC <39891083+xkeyC@users.noreply.github.com>

* Refactor Rust ONNX session creation to reduce code duplication

Co-authored-by: xkeyC <39891083+xkeyC@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: xkeyC <39891083+xkeyC@users.noreply.github.com>
Authored by Copilot on 2025-11-28 21:23:31 +08:00, committed by GitHub
parent db024f19bd · commit db89100402
26 changed files with 197 additions and 74 deletions

View File

@@ -15,13 +15,15 @@ static MODEL_CACHE: Lazy<Mutex<HashMap<String, OpusMtModel>>> =
/// * `model_path` - Path to the model folder
/// * `model_key` - Model cache key (identifies the model, e.g. "zh-en")
/// * `quantization_suffix` - Quantization suffix (e.g. "_q4", "_q8"; an empty string selects the default model)
/// * `use_xnnpack` - Whether to enable XNNPACK acceleration
///
pub fn load_translation_model(
model_path: String,
model_key: String,
quantization_suffix: String,
use_xnnpack: bool,
) -> Result<()> {
let model = OpusMtModel::new(&model_path, &quantization_suffix)?;
let model = OpusMtModel::new(&model_path, &quantization_suffix, use_xnnpack)?;
let mut cache = MODEL_CACHE
.lock()

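For orientation (not part of the diff): a minimal caller-side sketch of the updated API, assuming anyhow::Result as the error type; the path, cache key, and quantization suffix below are placeholders, not values from this repository.

// Hypothetical caller-side sketch: warm up a zh-en model with XNNPACK enabled.
// Path, key, and suffix are placeholder values.
fn warm_up_translator() -> anyhow::Result<()> {
    load_translation_model(
        "/models/opus-mt-zh-en".to_string(), // model_path: folder containing the onnx/ subfolder
        "zh-en".to_string(),                 // model_key: key under which the model is cached
        "_q4".to_string(),                   // quantization_suffix: empty string selects the default model
        true,                                // use_xnnpack: register the XNNPACK execution provider
    )
}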
View File

@@ -262,6 +262,7 @@ fn wire__crate__api__ort_api__load_translation_model_impl(
model_path: impl CstDecode<String>,
model_key: impl CstDecode<String>,
quantization_suffix: impl CstDecode<String>,
use_xnnpack: impl CstDecode<bool>,
) {
FLUTTER_RUST_BRIDGE_HANDLER.wrap_normal::<flutter_rust_bridge::for_generated::DcoCodec, _, _>(
flutter_rust_bridge::for_generated::TaskInfo {
@@ -273,6 +274,7 @@ fn wire__crate__api__ort_api__load_translation_model_impl(
let api_model_path = model_path.cst_decode();
let api_model_key = model_key.cst_decode();
let api_quantization_suffix = quantization_suffix.cst_decode();
let api_use_xnnpack = use_xnnpack.cst_decode();
move |context| {
transform_result_dco::<_, _, flutter_rust_bridge::for_generated::anyhow::Error>(
(move || {
@@ -280,6 +282,7 @@ fn wire__crate__api__ort_api__load_translation_model_impl(
api_model_path,
api_model_key,
api_quantization_suffix,
api_use_xnnpack,
)?;
Ok(output_ok)
})(),
@@ -1750,12 +1753,14 @@ mod io {
model_path: *mut wire_cst_list_prim_u_8_strict,
model_key: *mut wire_cst_list_prim_u_8_strict,
quantization_suffix: *mut wire_cst_list_prim_u_8_strict,
use_xnnpack: bool,
) {
wire__crate__api__ort_api__load_translation_model_impl(
port_,
model_path,
model_key,
quantization_suffix,
use_xnnpack,
)
}

View File

@@ -38,16 +38,42 @@ impl Default for ModelConfig {
}
}
/// Helper function for creating an ONNX session
///
/// # Arguments
/// * `model_path` - Path to the model file
/// * `model_name` - Model name (used in error messages)
/// * `use_xnnpack` - Whether to enable XNNPACK acceleration
fn create_session(model_path: &Path, model_name: &str, use_xnnpack: bool) -> Result<Session> {
let mut builder = Session::builder()
.context(format!("Failed to create {} session builder", model_name))?
.with_optimization_level(GraphOptimizationLevel::Level3)
.context("Failed to set optimization level")?
.with_intra_threads(4)
.context("Failed to set intra threads")?;
if use_xnnpack {
builder = builder
.with_execution_providers([XNNPACKExecutionProvider::default().build()])
.context("Failed to register XNNPACK execution provider")?;
}
builder
.commit_from_file(model_path)
.context(format!("Failed to load {} model", model_name))
}
impl OpusMtModel {
/// Creates a new OpusMT model instance from a model path
///
/// # Arguments
/// * `model_path` - Path to the model folder (should contain an `onnx` subfolder)
/// * `quantization_suffix` - Quantization suffix, e.g. "_q4" or "_q8"; an empty string selects the default model
/// * `use_xnnpack` - Whether to enable XNNPACK acceleration
///
/// # Returns
/// * `Result<Self>` - The model instance on success, an error otherwise
pub fn new<P: AsRef<Path>>(model_path: P, quantization_suffix: &str) -> Result<Self> {
pub fn new<P: AsRef<Path>>(model_path: P, quantization_suffix: &str, use_xnnpack: bool) -> Result<Self> {
let model_path = model_path.as_ref();
// onnx-community convention: the model files live in the onnx subfolder
@@ -82,19 +108,7 @@ impl OpusMtModel {
));
}
let encoder_session = Session::builder()
.context("Failed to create encoder session builder")?
.with_optimization_level(GraphOptimizationLevel::Level3)
.context("Failed to set optimization level")?
.with_intra_threads(4)
.context("Failed to set intra threads")?
.with_execution_providers([XNNPACKExecutionProvider::default().build()])
.context("Failed to register XNNPACK execution provider")?
.commit_from_file(&encoder_path)
.context(format!(
"Failed to load encoder model: {}",
encoder_filename
))?;
let encoder_session = create_session(&encoder_path, "encoder", use_xnnpack)?;
// Load the decoder model (in the onnx subdirectory)
let decoder_path = onnx_dir.join(&decoder_filename);
@@ -105,19 +119,7 @@ impl OpusMtModel {
));
}
let decoder_session = Session::builder()
.context("Failed to create decoder session builder")?
.with_optimization_level(GraphOptimizationLevel::Level3)
.context("Failed to set optimization level")?
.with_intra_threads(4)
.context("Failed to set intra threads")?
.with_execution_providers([XNNPACKExecutionProvider::default().build()])
.context("Failed to register XNNPACK execution provider")?
.commit_from_file(&decoder_path)
.context(format!(
"Failed to load decoder model: {}",
decoder_filename
))?;
let decoder_session = create_session(&decoder_path, "decoder", use_xnnpack)?;
// Load the config (if present, in the root directory)
let config = Self::load_config(model_path)?;
@@ -395,6 +397,7 @@ mod tests {
let model = OpusMtModel::new(
"E:\\Project\\StarCtizen\\Opus-MT-StarCitizen\\results\\final_model",
"_q4f16",
true,
)
.unwrap();
let result = model.translate("北极星要炸了,快撤!").unwrap();
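For illustration only, a sketch of how a caller might pick the new flag at runtime, under the assumption that XNNPACK is mainly beneficial on ARM CPUs; the cfg! condition, model directory, and input text are illustrative and not part of this change.

// Hypothetical: enable XNNPACK only on ARM targets, falling back to the default
// CPU execution provider elsewhere. The model directory is a placeholder.
let use_xnnpack = cfg!(any(target_arch = "aarch64", target_arch = "arm"));
let model = OpusMtModel::new("/models/opus-mt-zh-en", "_q4", use_xnnpack)?;
let translated = model.translate("你好,世界")?;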