mirror of
https://github.com/StarCitizenToolBox/app.git
synced 2026-02-12 18:20:24 +00:00
feat: ORT Local Translate
This commit is contained in:
@@ -245,10 +245,16 @@ impl OpusMtModel {
|
||||
.context("Failed to create attention_mask array")?;
|
||||
|
||||
// 3. 运行 encoder
|
||||
let input_ids_value =
|
||||
Value::from_array(input_ids_array).context("Failed to create input_ids value")?;
|
||||
let attention_mask_value = Value::from_array(attention_mask_array.clone())
|
||||
.context("Failed to create attention_mask value")?;
|
||||
let input_ids_value = Value::from_array((
|
||||
input_ids_array.shape().to_vec(),
|
||||
input_ids_array.into_raw_vec_and_offset().0,
|
||||
))
|
||||
.context("Failed to create input_ids value")?;
|
||||
let attention_mask_value = Value::from_array((
|
||||
attention_mask_array.shape().to_vec(),
|
||||
attention_mask_array.clone().into_raw_vec_and_offset().0,
|
||||
))
|
||||
.context("Failed to create attention_mask value")?;
|
||||
|
||||
let encoder_inputs = ort::inputs![
|
||||
"input_ids" => input_ids_value,
|
||||
@@ -303,12 +309,21 @@ impl OpusMtModel {
|
||||
.context("Failed to create decoder input_ids")?;
|
||||
|
||||
// 创建 ORT Value
|
||||
let decoder_input_value = Value::from_array(decoder_input_ids)
|
||||
.context("Failed to create decoder input value")?;
|
||||
let encoder_hidden_value = Value::from_array(encoder_hidden_states.clone())
|
||||
.context("Failed to create encoder hidden value")?;
|
||||
let encoder_mask_value = Value::from_array(encoder_attention_mask.clone())
|
||||
.context("Failed to create encoder mask value")?;
|
||||
let decoder_input_value = Value::from_array((
|
||||
decoder_input_ids.shape().to_vec(),
|
||||
decoder_input_ids.into_raw_vec_and_offset().0,
|
||||
))
|
||||
.context("Failed to create decoder input value")?;
|
||||
let encoder_hidden_value = Value::from_array((
|
||||
encoder_hidden_states.shape().to_vec(),
|
||||
encoder_hidden_states.clone().into_raw_vec_and_offset().0,
|
||||
))
|
||||
.context("Failed to create encoder hidden value")?;
|
||||
let encoder_mask_value = Value::from_array((
|
||||
encoder_attention_mask.shape().to_vec(),
|
||||
encoder_attention_mask.clone().into_raw_vec_and_offset().0,
|
||||
))
|
||||
.context("Failed to create encoder mask value")?;
|
||||
|
||||
// 运行 decoder
|
||||
let decoder_inputs = ort::inputs![
|
||||
|
||||
Reference in New Issue
Block a user