feat: [rust] add Opus-MT

xkeyC
2025-11-15 17:56:58 +08:00
parent 028660cbba
commit 58da84c0a6
14 changed files with 1676 additions and 128 deletions


@@ -5,3 +5,4 @@ pub mod http_api;
pub mod rs_process;
pub mod win32_api;
pub mod asar_api;
pub mod ort_api;

rust/src/api/ort_api.rs (new file, 107 lines)

@@ -0,0 +1,107 @@
use anyhow::Result;
use once_cell::sync::Lazy;
use std::collections::HashMap;
use std::sync::Mutex;
use crate::ort_models::opus_mt::OpusMtModel;
/// Global cache of loaded models
static MODEL_CACHE: Lazy<Mutex<HashMap<String, OpusMtModel>>> =
Lazy::new(|| Mutex::new(HashMap::new()));
/// Load an ONNX translation model.
///
/// # Arguments
/// * `model_path` - path to the model folder
/// * `model_key` - cache key identifying the model (e.g. "zh-en")
/// * `quantization_suffix` - quantization suffix (e.g. "_q4", "_q8"; an empty string selects the default model)
///
pub fn load_translation_model(
model_path: String,
model_key: String,
quantization_suffix: String,
) -> Result<()> {
let model = OpusMtModel::new(&model_path, &quantization_suffix)?;
let mut cache = MODEL_CACHE
.lock()
.map_err(|e| anyhow::anyhow!("Failed to lock model cache: {}", e))?;
cache.insert(model_key, model);
Ok(())
}
/// Translate a single text.
///
/// # Arguments
/// * `model_key` - cache key of a previously loaded model (e.g. "zh-en")
/// * `text` - the text to translate
///
/// # Returns
/// * `Result<String>` - the translated text
pub fn translate_text(model_key: String, text: String) -> Result<String> {
let cache = MODEL_CACHE
.lock()
.map_err(|e| anyhow::anyhow!("Failed to lock model cache: {}", e))?;
let model = cache.get(&model_key).ok_or_else(|| {
anyhow::anyhow!(
"Model not found: {}. Please load the model first.",
model_key
)
})?;
model.translate(&text)
}
/// Translate a batch of texts.
///
/// # Arguments
/// * `model_key` - cache key of a previously loaded model (e.g. "zh-en")
/// * `texts` - the texts to translate
///
/// # Returns
/// * `Result<Vec<String>>` - the translated texts
pub fn translate_text_batch(model_key: String, texts: Vec<String>) -> Result<Vec<String>> {
let cache = MODEL_CACHE
.lock()
.map_err(|e| anyhow::anyhow!("Failed to lock model cache: {}", e))?;
let model = cache.get(&model_key).ok_or_else(|| {
anyhow::anyhow!(
"Model not found: {}. Please load the model first.",
model_key
)
})?;
model.translate_batch(&texts)
}
/// Unload a model from the cache.
///
/// # Arguments
/// * `model_key` - cache key of the model to unload (e.g. "zh-en")
///
pub fn unload_translation_model(model_key: String) -> Result<()> {
let mut cache = MODEL_CACHE
.lock()
.map_err(|e| anyhow::anyhow!("Failed to lock model cache: {}", e))?;
cache.remove(&model_key);
Ok(())
}
/// Clear all loaded models from the cache.
///
/// # Returns
/// * `Result<()>` - `Ok(())` once the cache has been emptied
pub fn clear_all_models() -> Result<()> {
let mut cache = MODEL_CACHE
.lock()
.map_err(|e| anyhow::anyhow!("Failed to lock model cache: {}", e))?;
cache.clear();
Ok(())
}
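// A minimal usage sketch for the cache-backed API above; the model folder,
// cache key and quantization suffix are placeholder values.
#[allow(dead_code)]
fn example_translate() -> Result<()> {
    // Load once; later calls look the cached sessions up via the "zh-en" key.
    load_translation_model(
        "/path/to/opus-mt-zh-en".to_string(), // folder containing tokenizer.json and onnx/
        "zh-en".to_string(),
        "_q4".to_string(), // or "" for the non-quantized model
    )?;
    let one = translate_text("zh-en".to_string(), "你好,世界".to_string())?;
    let many = translate_text_batch("zh-en".to_string(), vec!["第一句".to_string()])?;
    println!("{one} / {many:?}");
    // Drop the cached sessions when they are no longer needed.
    unload_translation_model("zh-en".to_string())
}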


@@ -37,7 +37,7 @@ flutter_rust_bridge::frb_generated_boilerplate!(
default_rust_auto_opaque = RustAutoOpaqueNom,
);
pub(crate) const FLUTTER_RUST_BRIDGE_CODEGEN_VERSION: &str = "2.11.1";
pub(crate) const FLUTTER_RUST_BRIDGE_CODEGEN_CONTENT_HASH: i32 = 1832496273;
pub(crate) const FLUTTER_RUST_BRIDGE_CODEGEN_CONTENT_HASH: i32 = -706588047;
// Section: executor
@@ -45,6 +45,27 @@ flutter_rust_bridge::frb_generated_default_handler!();
// Section: wire_funcs
fn wire__crate__api__ort_api__clear_all_models_impl(
port_: flutter_rust_bridge::for_generated::MessagePort,
) {
FLUTTER_RUST_BRIDGE_HANDLER.wrap_normal::<flutter_rust_bridge::for_generated::DcoCodec, _, _>(
flutter_rust_bridge::for_generated::TaskInfo {
debug_name: "clear_all_models",
port: Some(port_),
mode: flutter_rust_bridge::for_generated::FfiCallMode::Normal,
},
move || {
move |context| {
transform_result_dco::<_, _, flutter_rust_bridge::for_generated::anyhow::Error>(
(move || {
let output_ok = crate::api::ort_api::clear_all_models()?;
Ok(output_ok)
})(),
)
}
},
)
}
fn wire__crate__api__http_api__dns_lookup_ips_impl(
port_: flutter_rust_bridge::for_generated::MessagePort,
host: impl CstDecode<String>,
@@ -161,6 +182,37 @@ fn wire__crate__api__asar_api__get_rsi_launcher_asar_data_impl(
},
)
}
fn wire__crate__api__ort_api__load_translation_model_impl(
port_: flutter_rust_bridge::for_generated::MessagePort,
model_path: impl CstDecode<String>,
model_key: impl CstDecode<String>,
quantization_suffix: impl CstDecode<String>,
) {
FLUTTER_RUST_BRIDGE_HANDLER.wrap_normal::<flutter_rust_bridge::for_generated::DcoCodec, _, _>(
flutter_rust_bridge::for_generated::TaskInfo {
debug_name: "load_translation_model",
port: Some(port_),
mode: flutter_rust_bridge::for_generated::FfiCallMode::Normal,
},
move || {
let api_model_path = model_path.cst_decode();
let api_model_key = model_key.cst_decode();
let api_quantization_suffix = quantization_suffix.cst_decode();
move |context| {
transform_result_dco::<_, _, flutter_rust_bridge::for_generated::anyhow::Error>(
(move || {
let output_ok = crate::api::ort_api::load_translation_model(
api_model_path,
api_model_key,
api_quantization_suffix,
)?;
Ok(output_ok)
})(),
)
}
},
)
}
fn wire__crate__api__asar_api__rsi_launcher_asar_data_write_main_js_impl(
port_: flutter_rust_bridge::for_generated::MessagePort,
that: impl CstDecode<crate::api::asar_api::RsiLauncherAsarData>,
@@ -315,6 +367,82 @@ fn wire__crate__api__rs_process__start_impl(
},
)
}
fn wire__crate__api__ort_api__translate_text_impl(
port_: flutter_rust_bridge::for_generated::MessagePort,
model_key: impl CstDecode<String>,
text: impl CstDecode<String>,
) {
FLUTTER_RUST_BRIDGE_HANDLER.wrap_normal::<flutter_rust_bridge::for_generated::DcoCodec, _, _>(
flutter_rust_bridge::for_generated::TaskInfo {
debug_name: "translate_text",
port: Some(port_),
mode: flutter_rust_bridge::for_generated::FfiCallMode::Normal,
},
move || {
let api_model_key = model_key.cst_decode();
let api_text = text.cst_decode();
move |context| {
transform_result_dco::<_, _, flutter_rust_bridge::for_generated::anyhow::Error>(
(move || {
let output_ok =
crate::api::ort_api::translate_text(api_model_key, api_text)?;
Ok(output_ok)
})(),
)
}
},
)
}
fn wire__crate__api__ort_api__translate_text_batch_impl(
port_: flutter_rust_bridge::for_generated::MessagePort,
model_key: impl CstDecode<String>,
texts: impl CstDecode<Vec<String>>,
) {
FLUTTER_RUST_BRIDGE_HANDLER.wrap_normal::<flutter_rust_bridge::for_generated::DcoCodec, _, _>(
flutter_rust_bridge::for_generated::TaskInfo {
debug_name: "translate_text_batch",
port: Some(port_),
mode: flutter_rust_bridge::for_generated::FfiCallMode::Normal,
},
move || {
let api_model_key = model_key.cst_decode();
let api_texts = texts.cst_decode();
move |context| {
transform_result_dco::<_, _, flutter_rust_bridge::for_generated::anyhow::Error>(
(move || {
let output_ok =
crate::api::ort_api::translate_text_batch(api_model_key, api_texts)?;
Ok(output_ok)
})(),
)
}
},
)
}
fn wire__crate__api__ort_api__unload_translation_model_impl(
port_: flutter_rust_bridge::for_generated::MessagePort,
model_key: impl CstDecode<String>,
) {
FLUTTER_RUST_BRIDGE_HANDLER.wrap_normal::<flutter_rust_bridge::for_generated::DcoCodec, _, _>(
flutter_rust_bridge::for_generated::TaskInfo {
debug_name: "unload_translation_model",
port: Some(port_),
mode: flutter_rust_bridge::for_generated::FfiCallMode::Normal,
},
move || {
let api_model_key = model_key.cst_decode();
move |context| {
transform_result_dco::<_, _, flutter_rust_bridge::for_generated::anyhow::Error>(
(move || {
let output_ok =
crate::api::ort_api::unload_translation_model(api_model_key)?;
Ok(output_ok)
})(),
)
}
},
)
}
fn wire__crate__api__rs_process__write_impl(
port_: flutter_rust_bridge::for_generated::MessagePort,
rs_pid: impl CstDecode<u32>,
@@ -1361,6 +1489,13 @@ mod io {
}
}
#[unsafe(no_mangle)]
pub extern "C" fn frbgen_starcitizen_doctor_wire__crate__api__ort_api__clear_all_models(
port_: i64,
) {
wire__crate__api__ort_api__clear_all_models_impl(port_)
}
#[unsafe(no_mangle)]
pub extern "C" fn frbgen_starcitizen_doctor_wire__crate__api__http_api__dns_lookup_ips(
port_: i64,
@@ -1406,6 +1541,21 @@ mod io {
wire__crate__api__asar_api__get_rsi_launcher_asar_data_impl(port_, asar_path)
}
#[unsafe(no_mangle)]
pub extern "C" fn frbgen_starcitizen_doctor_wire__crate__api__ort_api__load_translation_model(
port_: i64,
model_path: *mut wire_cst_list_prim_u_8_strict,
model_key: *mut wire_cst_list_prim_u_8_strict,
quantization_suffix: *mut wire_cst_list_prim_u_8_strict,
) {
wire__crate__api__ort_api__load_translation_model_impl(
port_,
model_path,
model_key,
quantization_suffix,
)
}
#[unsafe(no_mangle)]
pub extern "C" fn frbgen_starcitizen_doctor_wire__crate__api__asar_api__rsi_launcher_asar_data_write_main_js(
port_: i64,
@@ -1459,6 +1609,32 @@ mod io {
)
}
#[unsafe(no_mangle)]
pub extern "C" fn frbgen_starcitizen_doctor_wire__crate__api__ort_api__translate_text(
port_: i64,
model_key: *mut wire_cst_list_prim_u_8_strict,
text: *mut wire_cst_list_prim_u_8_strict,
) {
wire__crate__api__ort_api__translate_text_impl(port_, model_key, text)
}
#[unsafe(no_mangle)]
pub extern "C" fn frbgen_starcitizen_doctor_wire__crate__api__ort_api__translate_text_batch(
port_: i64,
model_key: *mut wire_cst_list_prim_u_8_strict,
texts: *mut wire_cst_list_String,
) {
wire__crate__api__ort_api__translate_text_batch_impl(port_, model_key, texts)
}
#[unsafe(no_mangle)]
pub extern "C" fn frbgen_starcitizen_doctor_wire__crate__api__ort_api__unload_translation_model(
port_: i64,
model_key: *mut wire_cst_list_prim_u_8_strict,
) {
wire__crate__api__ort_api__unload_translation_model_impl(port_, model_key)
}
#[unsafe(no_mangle)]
pub extern "C" fn frbgen_starcitizen_doctor_wire__crate__api__rs_process__write(
port_: i64,


@@ -1,3 +1,4 @@
pub mod api;
mod frb_generated;
pub mod http_package;
pub mod ort_models;


@@ -0,0 +1 @@
pub mod opus_mt;


@@ -0,0 +1,388 @@
use anyhow::{anyhow, Context, Result};
use ndarray::{Array2, ArrayD};
use ort::{
execution_providers::XNNPACKExecutionProvider, session::builder::GraphOptimizationLevel,
session::Session, value::Value,
};
use std::path::Path;
use std::sync::Mutex;
use tokenizers::Tokenizer;
/// Inference wrapper around an Opus-MT translation model
pub struct OpusMtModel {
encoder_session: Mutex<Session>,
decoder_session: Mutex<Session>,
tokenizer: Tokenizer,
config: ModelConfig,
}
/// Model configuration
#[derive(Debug, Clone)]
pub struct ModelConfig {
pub max_length: usize,
pub num_beams: usize,
pub decoder_start_token_id: i64,
pub eos_token_id: i64,
pub pad_token_id: i64,
}
impl Default for ModelConfig {
fn default() -> Self {
Self {
max_length: 512,
num_beams: 1,
decoder_start_token_id: 0,
eos_token_id: 0,
pad_token_id: 0,
}
}
}
impl OpusMtModel {
/// Create a new Opus-MT model instance from a model directory.
///
/// # Arguments
/// * `model_path` - path to the model folder (must contain an `onnx` subfolder)
/// * `quantization_suffix` - quantization suffix such as "_q4" or "_q8"; an empty string selects the default model
///
/// # Returns
/// * `Result<Self>` - the model instance on success, otherwise an error
pub fn new<P: AsRef<Path>>(model_path: P, quantization_suffix: &str) -> Result<Self> {
let model_path = model_path.as_ref();
// onnx-community convention: the ONNX models live in the `onnx` subfolder
let onnx_dir = model_path.join("onnx");
// Locate the tokenizer (in the root of the model folder)
let tokenizer_path = model_path.join("tokenizer.json");
// Load the tokenizer, repairing known issues on the fly
let tokenizer =
Self::load_tokenizer(&tokenizer_path).context("Failed to load tokenizer")?;
// Build the model file names
let encoder_filename = if quantization_suffix.is_empty() {
"encoder_model.onnx".to_string()
} else {
format!("encoder_model{}.onnx", quantization_suffix)
};
let decoder_filename = if quantization_suffix.is_empty() {
"decoder_model.onnx".to_string()
} else {
format!("decoder_model{}.onnx", quantization_suffix)
};
// Load the encoder model (from the onnx subfolder)
let encoder_path = onnx_dir.join(&encoder_filename);
if !encoder_path.exists() {
return Err(anyhow!(
"Encoder model not found: {}",
encoder_path.display()
));
}
let encoder_session = Session::builder()
.context("Failed to create encoder session builder")?
.with_optimization_level(GraphOptimizationLevel::Level3)
.context("Failed to set optimization level")?
.with_intra_threads(4)
.context("Failed to set intra threads")?
.with_execution_providers([XNNPACKExecutionProvider::default().build()])
.context("Failed to register XNNPACK execution provider")?
.commit_from_file(&encoder_path)
.context(format!(
"Failed to load encoder model: {}",
encoder_filename
))?;
// Load the decoder model (from the onnx subfolder)
let decoder_path = onnx_dir.join(&decoder_filename);
if !decoder_path.exists() {
return Err(anyhow!(
"Decoder model not found: {}",
decoder_path.display()
));
}
let decoder_session = Session::builder()
.context("Failed to create decoder session builder")?
.with_optimization_level(GraphOptimizationLevel::Level3)
.context("Failed to set optimization level")?
.with_intra_threads(4)
.context("Failed to set intra threads")?
.with_execution_providers([XNNPACKExecutionProvider::default().build()])
.context("Failed to register XNNPACK execution provider")?
.commit_from_file(&decoder_path)
.context(format!(
"Failed to load decoder model: {}",
decoder_filename
))?;
// Load the configuration, if present (in the root folder)
let config = Self::load_config(model_path)?;
Ok(Self {
encoder_session: Mutex::new(encoder_session),
decoder_session: Mutex::new(decoder_session),
tokenizer,
config,
})
}
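// On-disk layout that `new` expects, sketched from the path handling above
// (the "_q4" suffix is only an example; whatever suffix the caller passes is
// substituted into the file names):
//
//   <model_path>/
//   ├── config.json                 (optional: max_length, num_beams, *_token_id fields)
//   ├── tokenizer.json
//   └── onnx/
//       ├── encoder_model_q4.onnx   (encoder_model.onnx when the suffix is empty)
//       └── decoder_model_q4.onnx   (decoder_model.onnx when the suffix is empty)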
/// Load the tokenizer, automatically repairing common issues in tokenizer.json
fn load_tokenizer(tokenizer_path: &Path) -> Result<Tokenizer> {
use std::fs;
// Read the raw file
let content =
fs::read_to_string(tokenizer_path).context("Failed to read tokenizer.json")?;
// Parse it as JSON
let mut json: serde_json::Value =
serde_json::from_str(&content).context("Failed to parse tokenizer.json")?;
let mut needs_fix = false;
// Repair problems in the normalizer entry
if let Some(obj) = json.as_object_mut() {
if let Some(normalizer) = obj.get("normalizer") {
let mut should_remove_normalizer = false;
if normalizer.is_null() {
// normalizer is null; it has to be removed
should_remove_normalizer = true;
} else if let Some(norm_obj) = normalizer.as_object() {
// Check for the problematic Precompiled normalizer type
if let Some(type_val) = norm_obj.get("type") {
if type_val.as_str() == Some("Precompiled") {
// Inspect the precompiled_charsmap field
if let Some(precompiled) = norm_obj.get("precompiled_charsmap") {
if precompiled.is_null() {
// precompiled_charsmap is null; remove the whole normalizer
should_remove_normalizer = true;
}
} else {
// precompiled_charsmap field is missing; remove the whole normalizer
should_remove_normalizer = true;
}
}
}
}
if should_remove_normalizer {
obj.remove("normalizer");
needs_fix = true;
}
}
}
// Serialize the repaired JSON if a fix was applied; otherwise keep the original content
let json_str = if needs_fix {
serde_json::to_string(&json).context("Failed to serialize fixed tokenizer")?
} else {
content
};
// Load the tokenizer from the JSON bytes
Tokenizer::from_bytes(json_str.as_bytes())
.map_err(|e| anyhow!("Failed to load tokenizer: {}", e))
}
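// Sketch of the repair above, on an assumed (abbreviated) tokenizer.json shape:
//   before: { ..., "normalizer": { "type": "Precompiled", "precompiled_charsmap": null } }
//   after:  { ... }   // the "normalizer" entry is dropped so Tokenizer::from_bytes can parse it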
/// Load the model configuration from config.json
fn load_config(model_path: &Path) -> Result<ModelConfig> {
let config_path = model_path.join("config.json");
if config_path.exists() {
let config_str =
std::fs::read_to_string(config_path).context("Failed to read config.json")?;
let config_json: serde_json::Value =
serde_json::from_str(&config_str).context("Failed to parse config.json")?;
Ok(ModelConfig {
max_length: config_json["max_length"].as_u64().unwrap_or(512) as usize,
num_beams: config_json["num_beams"].as_u64().unwrap_or(1) as usize,
decoder_start_token_id: config_json["decoder_start_token_id"].as_i64().unwrap_or(0),
eos_token_id: config_json["eos_token_id"].as_i64().unwrap_or(0),
pad_token_id: config_json["pad_token_id"].as_i64().unwrap_or(0),
})
} else {
Ok(ModelConfig::default())
}
}
/// Translate a single text.
///
/// # Arguments
/// * `text` - the text to translate
///
/// # Returns
/// * `Result<String>` - the translated text
pub fn translate(&self, text: &str) -> Result<String> {
// 1. Tokenize the input text
let encoding = self
.tokenizer
.encode(text, true)
.map_err(|e| anyhow!("Failed to encode text: {}", e))?;
let input_ids = encoding.get_ids();
let attention_mask = encoding.get_attention_mask();
// 2. Prepare the encoder inputs
let batch_size = 1;
let seq_len = input_ids.len();
let input_ids_array: Array2<i64> = Array2::from_shape_vec(
(batch_size, seq_len),
input_ids.iter().map(|&id| id as i64).collect(),
)
.context("Failed to create input_ids array")?;
let attention_mask_array: Array2<i64> = Array2::from_shape_vec(
(batch_size, seq_len),
attention_mask.iter().map(|&mask| mask as i64).collect(),
)
.context("Failed to create attention_mask array")?;
// 3. Run the encoder
let input_ids_value =
Value::from_array(input_ids_array).context("Failed to create input_ids value")?;
let attention_mask_value = Value::from_array(attention_mask_array.clone())
.context("Failed to create attention_mask value")?;
let encoder_inputs = ort::inputs![
"input_ids" => input_ids_value,
"attention_mask" => attention_mask_value,
];
let mut encoder_session = self
.encoder_session
.lock()
.map_err(|e| anyhow!("Failed to lock encoder session: {}", e))?;
let encoder_outputs = encoder_session
.run(encoder_inputs)
.context("Failed to run encoder")?;
let encoder_hidden_states = encoder_outputs["last_hidden_state"]
.try_extract_tensor::<f32>()
.context("Failed to extract encoder hidden states")?;
// Convert the tensor into an ArrayD
let (shape, data) = encoder_hidden_states;
let shape_vec: Vec<usize> = shape.iter().map(|&x| x as usize).collect();
let encoder_array = ArrayD::from_shape_vec(shape_vec, data.to_vec())
.context("Failed to create encoder array")?;
// 4. Generate the output with greedy decoding
let output_ids = self.greedy_decode(encoder_array, &attention_mask_array)?;
// 5. Decode the output token IDs
let output_tokens: Vec<u32> = output_ids.iter().map(|&id| id as u32).collect();
let decoded = self
.tokenizer
.decode(&output_tokens, true)
.map_err(|e| anyhow!("Failed to decode output: {}", e))?;
Ok(decoded)
}
/// Greedy decoding loop
fn greedy_decode(
&self,
encoder_hidden_states: ArrayD<f32>,
encoder_attention_mask: &Array2<i64>,
) -> Result<Vec<i64>> {
let batch_size = 1;
let mut generated_ids = vec![self.config.decoder_start_token_id];
for _ in 0..self.config.max_length {
// Prepare the decoder inputs
let decoder_input_len = generated_ids.len();
let decoder_input_ids: Array2<i64> =
Array2::from_shape_vec((batch_size, decoder_input_len), generated_ids.clone())
.context("Failed to create decoder input_ids")?;
// Create the ORT values
let decoder_input_value = Value::from_array(decoder_input_ids)
.context("Failed to create decoder input value")?;
let encoder_hidden_value = Value::from_array(encoder_hidden_states.clone())
.context("Failed to create encoder hidden value")?;
let encoder_mask_value = Value::from_array(encoder_attention_mask.clone())
.context("Failed to create encoder mask value")?;
// Run the decoder
let decoder_inputs = ort::inputs![
"input_ids" => decoder_input_value,
"encoder_hidden_states" => encoder_hidden_value,
"encoder_attention_mask" => encoder_mask_value,
];
let mut decoder_session = self
.decoder_session
.lock()
.map_err(|e| anyhow!("Failed to lock decoder session: {}", e))?;
let decoder_outputs = decoder_session
.run(decoder_inputs)
.context("Failed to run decoder")?;
// Extract the logits
let logits_tensor = decoder_outputs["logits"]
.try_extract_tensor::<f32>()
.context("Failed to extract logits")?;
let (logits_shape, logits_data) = logits_tensor;
let vocab_size = logits_shape[2] as usize;
// Take the logits of the last generated position
let last_token_idx = decoder_input_len - 1;
let last_logits_start = last_token_idx * vocab_size;
let last_logits_end = last_logits_start + vocab_size;
let last_logits_slice = &logits_data[last_logits_start..last_logits_end];
// Pick the token with the highest logit (greedy argmax)
let next_token_id = last_logits_slice
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
.map(|(idx, _)| idx as i64)
.context("Failed to find max token")?;
// Stop once the end-of-sequence token is produced
if next_token_id == self.config.eos_token_id {
break;
}
generated_ids.push(next_token_id);
}
Ok(generated_ids)
}
/// Translate a batch of texts.
///
/// # Arguments
/// * `texts` - the texts to translate
///
/// # Returns
/// * `Result<Vec<String>>` - the translated texts
pub fn translate_batch(&self, texts: &[String]) -> Result<Vec<String>> {
texts.iter().map(|text| self.translate(text)).collect()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_translation() {
let model = OpusMtModel::new(
"C:\\Users\\xkeyc\\Downloads\\onnx_models\\opus-mt-zh-en",
"_q4f16",
)
.unwrap();
let result = model.translate("你好世界").unwrap();
println!("Translation: {}", result);
}
}