mirror of
https://github.com/StarCitizenToolBox/app.git
synced 2026-02-11 17:50:23 +00:00
Add XNN Pack toggle switch for ONNX inference acceleration (#155)
* Initial plan * Add XNN Pack switch for ONNX inference acceleration Co-authored-by: xkeyC <39891083+xkeyC@users.noreply.github.com> * Refactor Rust ONNX session creation to reduce code duplication Co-authored-by: xkeyC <39891083+xkeyC@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: xkeyC <39891083+xkeyC@users.noreply.github.com>
This commit is contained in:
@@ -15,13 +15,15 @@ static MODEL_CACHE: Lazy<Mutex<HashMap<String, OpusMtModel>>> =
|
||||
/// * `model_path` - 模型文件夹路径
|
||||
/// * `model_key` - 模型缓存键(用于标识模型,如 "zh-en")
|
||||
/// * `quantization_suffix` - 量化后缀(如 "_q4", "_q8",空字符串表示使用默认模型)
|
||||
/// * `use_xnnpack` - 是否使用 XNNPACK 加速
|
||||
///
|
||||
pub fn load_translation_model(
|
||||
model_path: String,
|
||||
model_key: String,
|
||||
quantization_suffix: String,
|
||||
use_xnnpack: bool,
|
||||
) -> Result<()> {
|
||||
let model = OpusMtModel::new(&model_path, &quantization_suffix)?;
|
||||
let model = OpusMtModel::new(&model_path, &quantization_suffix, use_xnnpack)?;
|
||||
|
||||
let mut cache = MODEL_CACHE
|
||||
.lock()
|
||||
|
||||
@@ -262,6 +262,7 @@ fn wire__crate__api__ort_api__load_translation_model_impl(
|
||||
model_path: impl CstDecode<String>,
|
||||
model_key: impl CstDecode<String>,
|
||||
quantization_suffix: impl CstDecode<String>,
|
||||
use_xnnpack: impl CstDecode<bool>,
|
||||
) {
|
||||
FLUTTER_RUST_BRIDGE_HANDLER.wrap_normal::<flutter_rust_bridge::for_generated::DcoCodec, _, _>(
|
||||
flutter_rust_bridge::for_generated::TaskInfo {
|
||||
@@ -273,6 +274,7 @@ fn wire__crate__api__ort_api__load_translation_model_impl(
|
||||
let api_model_path = model_path.cst_decode();
|
||||
let api_model_key = model_key.cst_decode();
|
||||
let api_quantization_suffix = quantization_suffix.cst_decode();
|
||||
let api_use_xnnpack = use_xnnpack.cst_decode();
|
||||
move |context| {
|
||||
transform_result_dco::<_, _, flutter_rust_bridge::for_generated::anyhow::Error>(
|
||||
(move || {
|
||||
@@ -280,6 +282,7 @@ fn wire__crate__api__ort_api__load_translation_model_impl(
|
||||
api_model_path,
|
||||
api_model_key,
|
||||
api_quantization_suffix,
|
||||
api_use_xnnpack,
|
||||
)?;
|
||||
Ok(output_ok)
|
||||
})(),
|
||||
@@ -1750,12 +1753,14 @@ mod io {
|
||||
model_path: *mut wire_cst_list_prim_u_8_strict,
|
||||
model_key: *mut wire_cst_list_prim_u_8_strict,
|
||||
quantization_suffix: *mut wire_cst_list_prim_u_8_strict,
|
||||
use_xnnpack: bool,
|
||||
) {
|
||||
wire__crate__api__ort_api__load_translation_model_impl(
|
||||
port_,
|
||||
model_path,
|
||||
model_key,
|
||||
quantization_suffix,
|
||||
use_xnnpack,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -38,16 +38,42 @@ impl Default for ModelConfig {
|
||||
}
|
||||
}
|
||||
|
||||
/// 创建 ONNX 会话的辅助函数
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `model_path` - 模型文件路径
|
||||
/// * `model_name` - 模型名称(用于错误消息)
|
||||
/// * `use_xnnpack` - 是否使用 XNNPACK 加速
|
||||
fn create_session(model_path: &Path, model_name: &str, use_xnnpack: bool) -> Result<Session> {
|
||||
let mut builder = Session::builder()
|
||||
.context(format!("Failed to create {} session builder", model_name))?
|
||||
.with_optimization_level(GraphOptimizationLevel::Level3)
|
||||
.context("Failed to set optimization level")?
|
||||
.with_intra_threads(4)
|
||||
.context("Failed to set intra threads")?;
|
||||
|
||||
if use_xnnpack {
|
||||
builder = builder
|
||||
.with_execution_providers([XNNPACKExecutionProvider::default().build()])
|
||||
.context("Failed to register XNNPACK execution provider")?;
|
||||
}
|
||||
|
||||
builder
|
||||
.commit_from_file(model_path)
|
||||
.context(format!("Failed to load {} model", model_name))
|
||||
}
|
||||
|
||||
impl OpusMtModel {
|
||||
/// 从模型路径创建新的 OpusMT 模型实例
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `model_path` - 模型文件夹路径(应包含 onnx 子文件夹)
|
||||
/// * `quantization_suffix` - 量化后缀,如 "_q4", "_q8",为空字符串则使用默认模型
|
||||
/// * `use_xnnpack` - 是否使用 XNNPACK 加速
|
||||
///
|
||||
/// # Returns
|
||||
/// * `Result<Self>` - 成功返回模型实例,失败返回错误
|
||||
pub fn new<P: AsRef<Path>>(model_path: P, quantization_suffix: &str) -> Result<Self> {
|
||||
pub fn new<P: AsRef<Path>>(model_path: P, quantization_suffix: &str, use_xnnpack: bool) -> Result<Self> {
|
||||
let model_path = model_path.as_ref();
|
||||
|
||||
// onnx-community 标准:模型在 onnx 子文件夹中
|
||||
@@ -82,19 +108,7 @@ impl OpusMtModel {
|
||||
));
|
||||
}
|
||||
|
||||
let encoder_session = Session::builder()
|
||||
.context("Failed to create encoder session builder")?
|
||||
.with_optimization_level(GraphOptimizationLevel::Level3)
|
||||
.context("Failed to set optimization level")?
|
||||
.with_intra_threads(4)
|
||||
.context("Failed to set intra threads")?
|
||||
.with_execution_providers([XNNPACKExecutionProvider::default().build()])
|
||||
.context("Failed to register XNNPACK execution provider")?
|
||||
.commit_from_file(&encoder_path)
|
||||
.context(format!(
|
||||
"Failed to load encoder model: {}",
|
||||
encoder_filename
|
||||
))?;
|
||||
let encoder_session = create_session(&encoder_path, "encoder", use_xnnpack)?;
|
||||
|
||||
// 加载 decoder 模型(在 onnx 子目录)
|
||||
let decoder_path = onnx_dir.join(&decoder_filename);
|
||||
@@ -105,19 +119,7 @@ impl OpusMtModel {
|
||||
));
|
||||
}
|
||||
|
||||
let decoder_session = Session::builder()
|
||||
.context("Failed to create decoder session builder")?
|
||||
.with_optimization_level(GraphOptimizationLevel::Level3)
|
||||
.context("Failed to set optimization level")?
|
||||
.with_intra_threads(4)
|
||||
.context("Failed to set intra threads")?
|
||||
.with_execution_providers([XNNPACKExecutionProvider::default().build()])
|
||||
.context("Failed to register XNNPACK execution provider")?
|
||||
.commit_from_file(&decoder_path)
|
||||
.context(format!(
|
||||
"Failed to load decoder model: {}",
|
||||
decoder_filename
|
||||
))?;
|
||||
let decoder_session = create_session(&decoder_path, "decoder", use_xnnpack)?;
|
||||
|
||||
// 加载配置(如果存在,在根目录)
|
||||
let config = Self::load_config(model_path)?;
|
||||
@@ -395,6 +397,7 @@ mod tests {
|
||||
let model = OpusMtModel::new(
|
||||
"E:\\Project\\StarCtizen\\Opus-MT-StarCitizen\\results\\final_model",
|
||||
"_q4f16",
|
||||
true,
|
||||
)
|
||||
.unwrap();
|
||||
let result = model.translate("北极星要炸了,快撤!").unwrap();
|
||||
|
||||
Reference in New Issue
Block a user