feat: ORT Local Translate

This commit is contained in:
xkeyC
2025-11-15 17:58:42 +08:00
parent 58da84c0a6
commit 3219129094
19 changed files with 619 additions and 232 deletions

View File

@@ -245,10 +245,16 @@ impl OpusMtModel {
.context("Failed to create attention_mask array")?;
// 3. 运行 encoder
let input_ids_value =
Value::from_array(input_ids_array).context("Failed to create input_ids value")?;
let attention_mask_value = Value::from_array(attention_mask_array.clone())
.context("Failed to create attention_mask value")?;
let input_ids_value = Value::from_array((
input_ids_array.shape().to_vec(),
input_ids_array.into_raw_vec_and_offset().0,
))
.context("Failed to create input_ids value")?;
let attention_mask_value = Value::from_array((
attention_mask_array.shape().to_vec(),
attention_mask_array.clone().into_raw_vec_and_offset().0,
))
.context("Failed to create attention_mask value")?;
let encoder_inputs = ort::inputs![
"input_ids" => input_ids_value,
@@ -303,12 +309,21 @@ impl OpusMtModel {
.context("Failed to create decoder input_ids")?;
// 创建 ORT Value
let decoder_input_value = Value::from_array(decoder_input_ids)
.context("Failed to create decoder input value")?;
let encoder_hidden_value = Value::from_array(encoder_hidden_states.clone())
.context("Failed to create encoder hidden value")?;
let encoder_mask_value = Value::from_array(encoder_attention_mask.clone())
.context("Failed to create encoder mask value")?;
let decoder_input_value = Value::from_array((
decoder_input_ids.shape().to_vec(),
decoder_input_ids.into_raw_vec_and_offset().0,
))
.context("Failed to create decoder input value")?;
let encoder_hidden_value = Value::from_array((
encoder_hidden_states.shape().to_vec(),
encoder_hidden_states.clone().into_raw_vec_and_offset().0,
))
.context("Failed to create encoder hidden value")?;
let encoder_mask_value = Value::from_array((
encoder_attention_mask.shape().to_vec(),
encoder_attention_mask.clone().into_raw_vec_and_offset().0,
))
.context("Failed to create encoder mask value")?;
// 运行 decoder
let decoder_inputs = ort::inputs![