diff --git a/lib/common/rust/frb_generated.io.dart b/lib/common/rust/frb_generated.io.dart index 190df88..c03e956 100644 --- a/lib/common/rust/frb_generated.io.dart +++ b/lib/common/rust/frb_generated.io.dart @@ -584,7 +584,7 @@ abstract class RustLibApiImplPlatform extends BaseApiImpl { // AUTO GENERATED FILE, DO NOT EDIT. // // Generated by `package:ffigen`. -// ignore_for_file: type=lint +// ignore_for_file: type=lint, unused_import /// generated by flutter_rust_bridge class RustLibWire implements BaseWire { diff --git a/lib/main.dart b/lib/main.dart index 9cecfd4..956782c 100644 --- a/lib/main.dart +++ b/lib/main.dart @@ -97,14 +97,7 @@ class App extends HookConsumerWidget with WindowListener { @override Future onWindowClose() async { debugPrint("onWindowClose"); - if (await windowManager.isPreventClose()) { - final windows = await DesktopMultiWindow.getAllSubWindowIds(); - for (final id in windows) { - await WindowController.fromWindowId(id).close(); - } - await windowManager.destroy(); - exit(0); - } + exit(0); super.onWindowClose(); } } diff --git a/lib/provider/aria2c.dart b/lib/provider/aria2c.dart index 3e4f75a..1d859d9 100644 --- a/lib/provider/aria2c.dart +++ b/lib/provider/aria2c.dart @@ -8,11 +8,11 @@ import 'package:flutter/foundation.dart'; import 'package:hive_ce/hive.dart'; import 'package:starcitizen_doctor/api/api.dart'; import 'package:starcitizen_doctor/common/helper/system_helper.dart'; -import 'package:starcitizen_doctor/common/rust/api/rs_process.dart' - as rs_process; +import 'package:starcitizen_doctor/common/rust/api/rs_process.dart' as rs_process; import 'package:starcitizen_doctor/common/utils/log.dart'; import 'package:starcitizen_doctor/common/utils/provider.dart'; +import 'package:starcitizen_doctor/ui/home/downloader/home_downloader_ui_model.dart'; part 'aria2c.g.dart'; @@ -20,11 +20,8 @@ part 'aria2c.freezed.dart'; @freezed abstract class Aria2cModelState with _$Aria2cModelState { - const factory Aria2cModelState({ - required String aria2cDir, - Aria2c? aria2c, - Aria2GlobalStat? aria2globalStat, - }) = _Aria2cModelState; + const factory Aria2cModelState({required String aria2cDir, Aria2c? aria2c, Aria2GlobalStat? aria2globalStat}) = + _Aria2cModelState; } extension Aria2cModelExt on Aria2cModelState { @@ -32,10 +29,8 @@ extension Aria2cModelExt on Aria2cModelState { bool get hasDownloadTask => aria2globalStat != null && aria2TotalTaskNum > 0; - int get aria2TotalTaskNum => aria2globalStat == null - ? 0 - : ((aria2globalStat!.numActive ?? 0) + - (aria2globalStat!.numWaiting ?? 0)); + int get aria2TotalTaskNum => + aria2globalStat == null ? 0 : ((aria2globalStat!.numActive ?? 0) + (aria2globalStat!.numWaiting ?? 0)); } @riverpod @@ -57,8 +52,7 @@ class Aria2cModel extends _$Aria2cModel { try { final sessionFile = File("$aria2cDir\\aria2.session"); // 有下载任务则第一时间初始化 - if (await sessionFile.exists() && - (await sessionFile.readAsString()).trim().isNotEmpty) { + if (await sessionFile.exists() && (await sessionFile.readAsString()).trim().isNotEmpty) { dPrint("launch Aria2c daemon"); await launchDaemon(appGlobalState.applicationBinaryModuleDir!); } else { @@ -74,8 +68,7 @@ class Aria2cModel extends _$Aria2cModel { Future launchDaemon(String applicationBinaryModuleDir) async { if (state.aria2c != null) return; - await BinaryModuleConf.extractModule( - ["aria2c"], applicationBinaryModuleDir); + await BinaryModuleConf.extractModule(["aria2c"], applicationBinaryModuleDir); /// skip for debug hot reload if (kDebugMode) { @@ -99,30 +92,30 @@ class Aria2cModel extends _$Aria2cModel { dPrint("Aria2cManager .----- aria2c start $port------"); final stream = rs_process.start( - executable: exePath, - arguments: [ - "-V", - "-c", - "-x 16", - "--dir=${state.aria2cDir}\\downloads", - "--disable-ipv6", - "--enable-rpc", - "--pause", - "--rpc-listen-port=$port", - "--rpc-secret=$pwd", - "--input-file=${sessionFile.absolute.path.trim()}", - "--save-session=${sessionFile.absolute.path.trim()}", - "--save-session-interval=60", - "--file-allocation=trunc", - "--seed-time=0", - ], - workingDirectory: state.aria2cDir); + executable: exePath, + arguments: [ + "-V", + "-c", + "-x 16", + "--dir=${state.aria2cDir}\\downloads", + "--disable-ipv6", + "--enable-rpc", + "--pause", + "--rpc-listen-port=$port", + "--rpc-secret=$pwd", + "--input-file=${sessionFile.absolute.path.trim()}", + "--save-session=${sessionFile.absolute.path.trim()}", + "--save-session-interval=60", + "--file-allocation=trunc", + "--seed-time=0", + ], + workingDirectory: state.aria2cDir, + ); String launchError = ""; stream.listen((event) { - dPrint( - "Aria2cManager.rs_process event === [${event.rsPid}] ${event.dataType} >> ${event.data}"); + dPrint("Aria2cManager.rs_process event === [${event.rsPid}] ${event.dataType} >> ${event.data}"); switch (event.dataType) { case rs_process.RsProcessStreamDataType.output: if (event.data.contains("IPv4 RPC: listening on TCP port")) { @@ -155,8 +148,7 @@ class Aria2cModel extends _$Aria2cModel { } String generateRandomPassword(int length) { - const String charset = - "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"; + const String charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"; Random random = Random(); StringBuffer buffer = StringBuffer(); for (int i = 0; i < length; i++) { @@ -190,12 +182,12 @@ class Aria2cModel extends _$Aria2cModel { _listenState(aria2c); }); final box = await Hive.openBox("app_conf"); - aria2c.changeGlobalOption(Aria2Option() - ..maxOverallUploadLimit = - textToByte(box.get("downloader_up_limit", defaultValue: "0")) - ..maxOverallDownloadLimit = - textToByte(box.get("downloader_down_limit", defaultValue: "0")) - ..btTracker = trackerList); + aria2c.changeGlobalOption( + Aria2Option() + ..maxOverallUploadLimit = textToByte(box.get("downloader_up_limit", defaultValue: "0")) + ..maxOverallDownloadLimit = textToByte(box.get("downloader_down_limit", defaultValue: "0")) + ..btTracker = trackerList, + ); } Future _listenState(Aria2c aria2c) async { @@ -214,4 +206,14 @@ class Aria2cModel extends _$Aria2cModel { await Future.delayed(const Duration(seconds: 1)); } } + + Future isNameInTask(String name) async { + final aria2c = state.aria2c; + if (aria2c == null) return false; + for (var value in [...await aria2c.tellActive(), ...await aria2c.tellWaiting(0, 100000)]) { + final t = HomeDownloaderUIModel.getTaskTypeAndName(value); + return (t.key == "torrent" && t.value.contains(name)); + } + return false; + } } diff --git a/lib/provider/aria2c.g.dart b/lib/provider/aria2c.g.dart index c3877b2..ae3f316 100644 --- a/lib/provider/aria2c.g.dart +++ b/lib/provider/aria2c.g.dart @@ -41,7 +41,7 @@ final class Aria2cModelProvider } } -String _$aria2cModelHash() => r'3d51aeefd92e5291dca1f01db961f9c5496ec24f'; +String _$aria2cModelHash() => r'eb45d6aa9fc641abceb34ad5685aab57aa7a870b'; abstract class _$Aria2cModel extends $Notifier { Aria2cModelState build(); diff --git a/lib/ui/home/downloader/home_downloader_ui_model.dart b/lib/ui/home/downloader/home_downloader_ui_model.dart index d98f5a8..0d3dbcc 100644 --- a/lib/ui/home/downloader/home_downloader_ui_model.dart +++ b/lib/ui/home/downloader/home_downloader_ui_model.dart @@ -79,9 +79,10 @@ class HomeDownloaderUIModel extends _$HomeDownloaderUIModel { return; case "cancel_all": final userOK = await showConfirmDialogs( - context, - S.current.downloader_action_confirm_cancel_all_tasks, - Text(S.current.downloader_info_manual_file_deletion_note)); + context, + S.current.downloader_action_confirm_cancel_all_tasks, + Text(S.current.downloader_info_manual_file_deletion_note), + ); if (userOK == true) { if (!aria2cState.isRunning) return; try { @@ -101,31 +102,19 @@ class HomeDownloaderUIModel extends _$HomeDownloaderUIModel { } int getTasksLen() { - return state.tasks.length + - state.waitingTasks.length + - state.stoppedTasks.length; + return state.tasks.length + state.waitingTasks.length + state.stoppedTasks.length; } (Aria2Task, String, bool) getTaskAndType(int index) { - final tempList = [ - ...state.tasks, - ...state.waitingTasks, - ...state.stoppedTasks - ]; + final tempList = [...state.tasks, ...state.waitingTasks, ...state.stoppedTasks]; if (index >= 0 && index < state.tasks.length) { return (tempList[index], "active", index == 0); } - if (index >= state.tasks.length && - index < state.tasks.length + state.waitingTasks.length) { + if (index >= state.tasks.length && index < state.tasks.length + state.waitingTasks.length) { return (tempList[index], "waiting", index == state.tasks.length); } - if (index >= state.tasks.length + state.waitingTasks.length && - index < tempList.length) { - return ( - tempList[index], - "stopped", - index == state.tasks.length + state.waitingTasks.length - ); + if (index >= state.tasks.length + state.waitingTasks.length && index < tempList.length) { + return (tempList[index], "stopped", index == state.tasks.length + state.waitingTasks.length); } throw Exception("Index out of range or element is null"); } @@ -148,8 +137,7 @@ class HomeDownloaderUIModel extends _$HomeDownloaderUIModel { int getETA(Aria2Task task) { if (task.downloadSpeed == null || task.downloadSpeed == 0) return 0; - final remainingBytes = - (task.totalLength ?? 0) - (task.completedLength ?? 0); + final remainingBytes = (task.totalLength ?? 0) - (task.completedLength ?? 0); return remainingBytes ~/ (task.downloadSpeed!); } @@ -172,9 +160,10 @@ class HomeDownloaderUIModel extends _$HomeDownloaderUIModel { if (gid != null) { if (!context.mounted) return; final ok = await showConfirmDialogs( - context, - S.current.downloader_action_confirm_cancel_download, - Text(S.current.downloader_info_manual_file_deletion_note)); + context, + S.current.downloader_action_confirm_cancel_download, + Text(S.current.downloader_info_manual_file_deletion_note), + ); if (ok == true) { final aria2c = ref.read(aria2cModelProvider).aria2c; await aria2c?.remove(gid); @@ -204,8 +193,8 @@ class HomeDownloaderUIModel extends _$HomeDownloaderUIModel { Future _listenDownloader() async { try { while (true) { - final aria2cState = ref.read(aria2cModelProvider); if (_disposed) return; + final aria2cState = ref.read(aria2cModelProvider); if (aria2cState.isRunning) { final aria2c = aria2cState.aria2c!; final tasks = await aria2c.tellActive(); @@ -219,12 +208,7 @@ class HomeDownloaderUIModel extends _$HomeDownloaderUIModel { globalStat: globalStat, ); } else { - state = state.copyWith( - tasks: [], - waitingTasks: [], - stoppedTasks: [], - globalStat: null, - ); + state = state.copyWith(tasks: [], waitingTasks: [], stoppedTasks: [], globalStat: null); } await Future.delayed(const Duration(seconds: 1)); } @@ -236,72 +220,64 @@ class HomeDownloaderUIModel extends _$HomeDownloaderUIModel { Future _showDownloadSpeedSettings(BuildContext context) async { final box = await Hive.openBox("app_conf"); - final upCtrl = TextEditingController( - text: box.get("downloader_up_limit", defaultValue: "")); - final downCtrl = TextEditingController( - text: box.get("downloader_down_limit", defaultValue: "")); + final upCtrl = TextEditingController(text: box.get("downloader_up_limit", defaultValue: "")); + final downCtrl = TextEditingController(text: box.get("downloader_down_limit", defaultValue: "")); final ifr = FilteringTextInputFormatter.allow(RegExp(r'^\d*[km]?$')); if (!context.mounted) return; final ok = await showConfirmDialogs( - context, - S.current.downloader_speed_limit_settings, - Column( - mainAxisSize: MainAxisSize.min, - crossAxisAlignment: CrossAxisAlignment.start, - children: [ - Text( - S.current.downloader_info_p2p_network_note, - style: TextStyle( - fontSize: 14, - color: Colors.white.withValues(alpha: .6), - ), - ), - const SizedBox(height: 24), - Text(S.current.downloader_info_download_unit_input_prompt), - const SizedBox(height: 12), - Text(S.current.downloader_input_upload_speed_limit), - const SizedBox(height: 6), - TextFormBox( - placeholder: "1、100k、10m、0", - controller: upCtrl, - placeholderStyle: - TextStyle(color: Colors.white.withValues(alpha: .6)), - inputFormatters: [ifr], - ), - const SizedBox(height: 12), - Text(S.current.downloader_input_download_speed_limit), - const SizedBox(height: 6), - TextFormBox( - placeholder: "1、100k、10m、0", - controller: downCtrl, - placeholderStyle: - TextStyle(color: Colors.white.withValues(alpha: .6)), - inputFormatters: [ifr], - ), - const SizedBox(height: 24), - Text( - S.current.downloader_input_info_p2p_upload_note, - style: TextStyle( - fontSize: 13, - color: Colors.white.withValues(alpha: .6), - ), - ) - ], - )); + context, + S.current.downloader_speed_limit_settings, + Column( + mainAxisSize: MainAxisSize.min, + crossAxisAlignment: CrossAxisAlignment.start, + children: [ + Text( + S.current.downloader_info_p2p_network_note, + style: TextStyle(fontSize: 14, color: Colors.white.withValues(alpha: .6)), + ), + const SizedBox(height: 24), + Text(S.current.downloader_info_download_unit_input_prompt), + const SizedBox(height: 12), + Text(S.current.downloader_input_upload_speed_limit), + const SizedBox(height: 6), + TextFormBox( + placeholder: "1、100k、10m、0", + controller: upCtrl, + placeholderStyle: TextStyle(color: Colors.white.withValues(alpha: .6)), + inputFormatters: [ifr], + ), + const SizedBox(height: 12), + Text(S.current.downloader_input_download_speed_limit), + const SizedBox(height: 6), + TextFormBox( + placeholder: "1、100k、10m、0", + controller: downCtrl, + placeholderStyle: TextStyle(color: Colors.white.withValues(alpha: .6)), + inputFormatters: [ifr], + ), + const SizedBox(height: 24), + Text( + S.current.downloader_input_info_p2p_upload_note, + style: TextStyle(fontSize: 13, color: Colors.white.withValues(alpha: .6)), + ), + ], + ), + ); if (ok == true) { final aria2cState = ref.read(aria2cModelProvider); final aria2cModel = ref.read(aria2cModelProvider.notifier); - await aria2cModel - .launchDaemon(appGlobalState.applicationBinaryModuleDir!); + await aria2cModel.launchDaemon(appGlobalState.applicationBinaryModuleDir!); final aria2c = aria2cState.aria2c!; final upByte = aria2cModel.textToByte(upCtrl.text.trim()); final downByte = aria2cModel.textToByte(downCtrl.text.trim()); final r = await aria2c - .changeGlobalOption(Aria2Option() - ..maxOverallUploadLimit = upByte - ..maxOverallDownloadLimit = downByte) + .changeGlobalOption( + Aria2Option() + ..maxOverallUploadLimit = upByte + ..maxOverallDownloadLimit = downByte, + ) .unwrap(); if (r != null) { await box.put('downloader_up_limit', upCtrl.text.trim()); diff --git a/lib/ui/home/downloader/home_downloader_ui_model.g.dart b/lib/ui/home/downloader/home_downloader_ui_model.g.dart index 47f226a..67e4a2a 100644 --- a/lib/ui/home/downloader/home_downloader_ui_model.g.dart +++ b/lib/ui/home/downloader/home_downloader_ui_model.g.dart @@ -42,7 +42,7 @@ final class HomeDownloaderUIModelProvider } String _$homeDownloaderUIModelHash() => - r'5b410cd38315d94279b18f147903eca4b09bd445'; + r'cb5d0973d56bbf40673afc2a734b49f5d034ab98'; abstract class _$HomeDownloaderUIModel extends $Notifier { diff --git a/lib/ui/home/home_ui_model.g.dart b/lib/ui/home/home_ui_model.g.dart index 00590f0..41f41b5 100644 --- a/lib/ui/home/home_ui_model.g.dart +++ b/lib/ui/home/home_ui_model.g.dart @@ -41,7 +41,7 @@ final class HomeUIModelProvider } } -String _$homeUIModelHash() => r'9dc8191f358c2d8e21ed931b3755e08ce394558e'; +String _$homeUIModelHash() => r'7dfe73383f7be2e520a42d176e199a8db208f008'; abstract class _$HomeUIModel extends $Notifier { HomeUIModelState build(); diff --git a/lib/ui/home/input_method/input_method_dialog_ui.dart b/lib/ui/home/input_method/input_method_dialog_ui.dart index 7a46900..630b47e 100644 --- a/lib/ui/home/input_method/input_method_dialog_ui.dart +++ b/lib/ui/home/input_method/input_method_dialog_ui.dart @@ -3,6 +3,7 @@ import 'package:flutter/services.dart'; import 'package:flutter_hooks/flutter_hooks.dart'; import 'package:go_router/go_router.dart'; import 'package:hooks_riverpod/hooks_riverpod.dart'; +import 'package:starcitizen_doctor/common/utils/log.dart'; import 'package:starcitizen_doctor/ui/home/input_method/input_method_dialog_ui_model.dart'; import 'package:starcitizen_doctor/ui/home/input_method/server.dart'; import 'package:starcitizen_doctor/widgets/widgets.dart'; @@ -60,7 +61,9 @@ class InputMethodDialogUI extends HookConsumerWidget { ), SizedBox(height: 12), TextFormBox( - placeholder: S.current.input_method_input_placeholder, + placeholder: state.isEnableAutoTranslate + ? "${S.current.input_method_input_placeholder}\n\n本地翻译模型对中英混合处理能力较差,如有需要,建议分开发送。" + : S.current.input_method_input_placeholder, controller: srcTextCtrl, maxLines: 5, placeholderStyle: TextStyle(color: Colors.white.withValues(alpha: .6)), @@ -68,7 +71,9 @@ class InputMethodDialogUI extends HookConsumerWidget { onChanged: (str) async { final text = model.onTextChange("src", str); destTextCtrl.text = text ?? ""; - if (text != null) {} + if (text != null) { + model.checkAutoTranslate(); + } }, ), SizedBox(height: 16), @@ -91,17 +96,23 @@ class InputMethodDialogUI extends HookConsumerWidget { placeholderStyle: TextStyle(color: Colors.white.withValues(alpha: .6)), style: TextStyle(fontSize: 16, color: Colors.white), enabled: true, - onChanged: (str) { - // final text = model.onTextChange("dest", str); - // if (text != null) { - // srcTextCtrl.text = text; - // } - }, + onChanged: (str) {}, ), SizedBox(height: 24), Row( mainAxisAlignment: MainAxisAlignment.end, children: [ + Row( + children: [ + Text(S.current.input_method_auto_translate), + SizedBox(width: 6), + ToggleSwitch( + checked: state.isEnableAutoTranslate, + onChanged: (b) => _onSwitchAutoTranslate(context, model, b), + ), + ], + ), + SizedBox(width: 24), Row( children: [ Text(S.current.input_method_remote_input_service), @@ -194,4 +205,52 @@ class InputMethodDialogUI extends HookConsumerWidget { await serverModel.stopServer().unwrap(context: context); } } + + void _onSwitchAutoTranslate(BuildContext context, InputMethodDialogUIModel model, bool b) async { + if (b) { + // 检查下载任务 + if (await model.isTranslateModelDownloading()) { + if (!context.mounted) return; + showToast(context, "模型正在下载中,请稍后..."); + return; + } + // 打开,检查本地模型 + if (!await model.checkLocalTranslateModelAvailable()) { + if (!context.mounted) return; + // 询问用户是否下载模型 + final userOK = await showConfirmDialogs( + context, + "是否下载 AI 模型以使用翻译功能?", + Text( + "大约需要 200MB 的本地空间。" + "\n\n我们使用本地模型进行翻译,您的翻译数据不会发送给任何第三方。" + "\n\n模型未对游戏术语优化,请自行判断使用。", + ), + ); + if (userOK) { + try { + final guid = await model.doDownloadTranslateModel(); + if (guid.isNotEmpty) { + if (!context.mounted) return; + context.go("/index/downloader"); + await Future.delayed(Duration(seconds: 1)).then((_) { + if (!context.mounted) return; + showToast(context, "下载已开始,请在模型下载完成后重新启用翻译功能。"); + }); + return; + } + } catch (e) { + dPrint("下载模型失败:$e"); + if (context.mounted) { + showToast(context, "下载模型失败:$e"); + } + return; + } + } + return; + } + } + if (!context.mounted) return; + model.toggleAutoTranslate(b, context: context).unwrap(context: context); + } } diff --git a/lib/ui/home/input_method/input_method_dialog_ui_model.dart b/lib/ui/home/input_method/input_method_dialog_ui_model.dart index b9865f0..eab612a 100644 --- a/lib/ui/home/input_method/input_method_dialog_ui_model.dart +++ b/lib/ui/home/input_method/input_method_dialog_ui_model.dart @@ -1,12 +1,22 @@ +// ignore_for_file: avoid_build_context_in_providers, use_build_context_synchronously import 'dart:async'; +import 'dart:convert'; +import 'dart:io'; import 'package:flutter/services.dart'; import 'package:flutter/widgets.dart'; import 'package:freezed_annotation/freezed_annotation.dart'; import 'package:hive_ce/hive.dart'; import 'package:riverpod_annotation/riverpod_annotation.dart'; +import 'package:starcitizen_doctor/api/api.dart'; +import 'package:starcitizen_doctor/common/io/rs_http.dart'; +import 'package:starcitizen_doctor/common/utils/async.dart'; +import 'package:starcitizen_doctor/common/utils/base_utils.dart'; import 'package:starcitizen_doctor/common/utils/log.dart'; +import 'package:starcitizen_doctor/common/utils/provider.dart'; +import 'package:starcitizen_doctor/provider/aria2c.dart'; import 'package:starcitizen_doctor/ui/home/localization/localization_ui_model.dart'; +import 'package:starcitizen_doctor/common/rust/api/ort_api.dart' as ort; part 'input_method_dialog_ui_model.g.dart'; @@ -44,7 +54,8 @@ class InputMethodDialogUIModel extends _$InputMethodDialogUIModel { final worldMaps = keyMaps?.map((key, value) => MapEntry(value.trim(), key)); final appBox = await Hive.openBox("app_conf"); final enableAutoCopy = appBox.get("enableAutoCopy", defaultValue: false); - final isEnableAutoTranslate = appBox.get("isEnableAutoTranslate", defaultValue: false); + final isEnableAutoTranslate = appBox.get("isEnableAutoTranslate_v2", defaultValue: false); + _checkAutoTranslateOnInit(); state = state.copyWith( keyMaps: keyMaps, worldMaps: worldMaps, @@ -134,14 +145,216 @@ class InputMethodDialogUIModel extends _$InputMethodDialogUIModel { _srcTextCtrl?.text = text; _destTextCtrl?.text = onTextChange("src", text) ?? ""; if (_destTextCtrl?.text.isEmpty ?? true) return; + checkAutoTranslate(webMessage: true); if (autoCopy && !state.isAutoTranslateWorking) { Clipboard.setData(ClipboardData(text: _destTextCtrl?.text ?? "")); } } - Future toggleAutoTranslate(bool b) async { + // ignore: duplicate_ignore + // ignore: avoid_build_context_in_providers + Future toggleAutoTranslate(bool b, {BuildContext? context}) async { state = state.copyWith(isEnableAutoTranslate: b); final appConf = await Hive.openBox("app_conf"); - await appConf.put("isEnableAutoTranslate", b); + await appConf.put("isEnableAutoTranslate_v2", b); + if (b) { + mountOnnxTranslationProvider(_localTranslateModelDir, _localTranslateModelName, context: context); + } + } + + Timer? _translateTimer; + + Future checkAutoTranslate({bool webMessage = false}) async { + final sourceText = _srcTextCtrl?.text ?? ""; + final content = _destTextCtrl?.text ?? ""; + if (sourceText.trim().isEmpty) return; + if (state.isEnableAutoTranslate) { + if (_translateTimer != null) _translateTimer?.cancel(); + state = state.copyWith(isAutoTranslateWorking: true); + _translateTimer = Timer(Duration(milliseconds: webMessage ? 1 : 400), () async { + try { + final inputText = sourceText.replaceAll("\n", " "); + final r = await doTranslateText(inputText); + if (r != null) { + String resultText = r; + // resultText 首字母大写 + if (content.isNotEmpty) { + final firstChar = resultText.characters.first; + resultText = resultText.replaceFirst(firstChar, firstChar.toUpperCase()); + } + _destTextCtrl?.text = "$content \n[en] $resultText"; + if (state.enableAutoCopy || webMessage) { + Clipboard.setData(ClipboardData(text: _destTextCtrl?.text ?? "")); + } + } + } catch (e) { + dPrint("[InputMethodDialogUIModel] AutoTranslate error: $e"); + } + state = state.copyWith(isAutoTranslateWorking: false); + }); + } + } + + String get _localTranslateModelName => "opus-mt-zh-en_onnx"; + + String get _localTranslateModelDir => "${appGlobalState.applicationSupportDir}/onnx_models"; + + OnnxTranslationProvider get _localTranslateModelProvider => + onnxTranslationProvider(_localTranslateModelDir, _localTranslateModelName); + + void _checkAutoTranslateOnInit() { + // 检查模型文件是否存在,不存在则关闭自动翻译 + if (state.isEnableAutoTranslate) { + checkLocalTranslateModelAvailable().then((available) { + if (!available) { + toggleAutoTranslate(false); + } + }); + } + } + + Future checkLocalTranslateModelAvailable() async { + final fileCheckList = const [ + "config.json", + "tokenizer.json", + "vocab.json", + "onnx/decoder_model_q4f16.onnx", + "onnx/encoder_model_q4f16.onnx", + ]; + var allExist = true; + for (var fileName in fileCheckList) { + final filePath = "$_localTranslateModelDir/$_localTranslateModelName/$fileName"; + if (!await File(filePath).exists()) { + allExist = false; + break; + } + } + return allExist; + } + + Future doDownloadTranslateModel() async { + state = state.copyWith(isAutoTranslateWorking: true); + try { + final aria2cManager = ref.read(aria2cModelProvider.notifier); + await aria2cManager.launchDaemon(appGlobalState.applicationBinaryModuleDir!); + final aria2c = ref.read(aria2cModelProvider).aria2c!; + + if (await aria2cManager.isNameInTask(_localTranslateModelName)) { + throw Exception("Model is already downloading"); + } + + final l = await Api.getAppTorrentDataList(); + final modelTorrent = l.firstWhere( + (element) => element.name == _localTranslateModelName, + orElse: () => throw Exception("Model torrent not found"), + ); + final torrentUrl = modelTorrent.url; + if (torrentUrl?.isEmpty ?? true) { + throw Exception("Get model torrent url failed"); + } + // get torrent Data + final data = await RSHttp.get(torrentUrl!); + final b64Str = base64Encode(data.data!); + final gid = await aria2c.addTorrent(b64Str, extraParams: {"dir": _localTranslateModelDir}); + return gid; + } catch (e) { + dPrint("[InputMethodDialogUIModel] doDownloadTranslateModel error: $e"); + rethrow; + } finally { + state = state.copyWith(isAutoTranslateWorking: false); + } + } + + Future mountOnnxTranslationProvider( + String localTranslateModelDir, + String localTranslateModelName, { + BuildContext? context, + }) async { + if (!ref.exists(_localTranslateModelProvider)) { + ref.listen(_localTranslateModelProvider, ((_, _) {})); + final err = await ref.read(_localTranslateModelProvider.notifier).initModel(); + _handleTranslateModel(context, err); + } else { + // 重新加载 + final err = await ref.read(_localTranslateModelProvider.notifier).initModel(); + _handleTranslateModel(context, err); + } + } + + Future _handleTranslateModel(BuildContext? context, String? err) async { + if (err != null) { + dPrint("[InputMethodDialogUIModel] mountOnnxTranslationProvider failed to init model"); + if (context != null) { + if (!context.mounted) return; + final userOK = await showConfirmDialogs(context, "翻译模型加载失败", Text("是否删除本地文件,稍后您可以尝试重新下载。错误信息:\n$err")); + if (userOK) { + // 删除文件,并禁用开关 + final dir = Directory("$_localTranslateModelDir/$_localTranslateModelName"); + if (await dir.exists()) { + await dir.delete(recursive: true); + dPrint("[InputMethodDialogUIModel] Deleted local translate model files."); + toggleAutoTranslate(false); + } + } + } else { + // 禁用开关 + toggleAutoTranslate(false); + } + } + } + + Future doTranslateText(String text) async { + if (!ref.exists(_localTranslateModelProvider)) { + await mountOnnxTranslationProvider(_localTranslateModelDir, _localTranslateModelName); + } + final onnxTranslationState = ref.read(_localTranslateModelProvider); + if (!onnxTranslationState) { + return null; + } + try { + final result = await ort.translateText(modelKey: _localTranslateModelName, text: text); + return result; + } catch (e) { + dPrint("[InputMethodDialogUIModel] doTranslateText error: $e"); + return null; + } + } + + Future isTranslateModelDownloading() async { + final aria2cManager = ref.read(aria2cModelProvider.notifier); + return await aria2cManager.isNameInTask(_localTranslateModelName); + } +} + +@riverpod +class OnnxTranslation extends _$OnnxTranslation { + @override + bool build(String modelDir, String modelName) { + dPrint("[OnnxTranslation] Build provider for model: $modelName"); + ref.onDispose(disposeModel); + return false; + } + + Future initModel() async { + dPrint("[OnnxTranslation] Load model: $modelName from $modelDir"); + String? errorMessage; + try { + await ort.loadTranslationModel( + modelPath: "$modelDir/$modelName", + modelKey: modelName, + quantizationSuffix: "_q4f16", + ); + state = true; + } catch (e) { + dPrint("[OnnxTranslation] Load model error: $e"); + errorMessage = e.toString(); + state = false; + } + return errorMessage; + } + + Future disposeModel() async { + await ort.unloadTranslationModel(modelKey: modelName).unwrap(); + dPrint("[OnnxTranslation] Unload model: $modelName"); } } diff --git a/lib/ui/home/input_method/input_method_dialog_ui_model.g.dart b/lib/ui/home/input_method/input_method_dialog_ui_model.g.dart index 9c335da..5e49b7e 100644 --- a/lib/ui/home/input_method/input_method_dialog_ui_model.g.dart +++ b/lib/ui/home/input_method/input_method_dialog_ui_model.g.dart @@ -43,7 +43,7 @@ final class InputMethodDialogUIModelProvider } String _$inputMethodDialogUIModelHash() => - r'c07ef2474866bdb3944892460879121e0f90591f'; + r'39b7fc1446c09514b837c0f181488d34a4391751'; abstract class _$InputMethodDialogUIModel extends $Notifier { @@ -65,3 +65,102 @@ abstract class _$InputMethodDialogUIModel element.handleValue(ref, created); } } + +@ProviderFor(OnnxTranslation) +const onnxTranslationProvider = OnnxTranslationFamily._(); + +final class OnnxTranslationProvider + extends $NotifierProvider { + const OnnxTranslationProvider._({ + required OnnxTranslationFamily super.from, + required (String, String) super.argument, + }) : super( + retry: null, + name: r'onnxTranslationProvider', + isAutoDispose: true, + dependencies: null, + $allTransitiveDependencies: null, + ); + + @override + String debugGetCreateSourceHash() => _$onnxTranslationHash(); + + @override + String toString() { + return r'onnxTranslationProvider' + '' + '$argument'; + } + + @$internal + @override + OnnxTranslation create() => OnnxTranslation(); + + /// {@macro riverpod.override_with_value} + Override overrideWithValue(bool value) { + return $ProviderOverride( + origin: this, + providerOverride: $SyncValueProvider(value), + ); + } + + @override + bool operator ==(Object other) { + return other is OnnxTranslationProvider && other.argument == argument; + } + + @override + int get hashCode { + return argument.hashCode; + } +} + +String _$onnxTranslationHash() => r'05b7b063a1013eed1ee4daae5212b3b6c555cd82'; + +final class OnnxTranslationFamily extends $Family + with + $ClassFamilyOverride< + OnnxTranslation, + bool, + bool, + bool, + (String, String) + > { + const OnnxTranslationFamily._() + : super( + retry: null, + name: r'onnxTranslationProvider', + dependencies: null, + $allTransitiveDependencies: null, + isAutoDispose: true, + ); + + OnnxTranslationProvider call(String modelDir, String modelName) => + OnnxTranslationProvider._(argument: (modelDir, modelName), from: this); + + @override + String toString() => r'onnxTranslationProvider'; +} + +abstract class _$OnnxTranslation extends $Notifier { + late final _$args = ref.$arg as (String, String); + String get modelDir => _$args.$1; + String get modelName => _$args.$2; + + bool build(String modelDir, String modelName); + @$mustCallSuper + @override + void runBuild() { + final created = build(_$args.$1, _$args.$2); + final ref = this.ref as $Ref; + final element = + ref.element + as $ClassProviderElement< + AnyNotifier, + bool, + Object?, + Object? + >; + element.handleValue(ref, created); + } +} diff --git a/lib/ui/home/localization/advanced_localization_ui_model.g.dart b/lib/ui/home/localization/advanced_localization_ui_model.g.dart index 38fb741..4b5d99d 100644 --- a/lib/ui/home/localization/advanced_localization_ui_model.g.dart +++ b/lib/ui/home/localization/advanced_localization_ui_model.g.dart @@ -47,7 +47,7 @@ final class AdvancedLocalizationUIModelProvider } String _$advancedLocalizationUIModelHash() => - r'2f890c854bc56e506c441acabc2014438a163617'; + r'c7cca8935ac7df2281e83297b11b6b82d94f7a59'; abstract class _$AdvancedLocalizationUIModel extends $Notifier { diff --git a/lib/ui/home/localization/localization_ui_model.dart b/lib/ui/home/localization/localization_ui_model.dart index 680dd18..2954421 100644 --- a/lib/ui/home/localization/localization_ui_model.dart +++ b/lib/ui/home/localization/localization_ui_model.dart @@ -66,6 +66,10 @@ class LocalizationUIModel extends _$LocalizationUIModel { @override LocalizationUIState build() { state = LocalizationUIState(selectedLanguage: languageSupport.keys.first); + ref.onDispose(() { + _customizeDirListenSub?.cancel(); + _customizeDirListenSub = null; + }); _init(); return state; } @@ -74,10 +78,6 @@ class LocalizationUIModel extends _$LocalizationUIModel { if (_scInstallPath == "not_install") { return; } - ref.onDispose(() { - _customizeDirListenSub?.cancel(); - _customizeDirListenSub = null; - }); final appConfBox = await Hive.openBox("app_conf"); final lang = await appConfBox.get("localization_selectedLanguage", defaultValue: languageSupport.keys.first); state = state.copyWith(selectedLanguage: lang); diff --git a/lib/ui/home/localization/localization_ui_model.g.dart b/lib/ui/home/localization/localization_ui_model.g.dart index 88a4385..eee7afd 100644 --- a/lib/ui/home/localization/localization_ui_model.g.dart +++ b/lib/ui/home/localization/localization_ui_model.g.dart @@ -42,7 +42,7 @@ final class LocalizationUIModelProvider } String _$localizationUIModelHash() => - r'd3797a7ff3d31dd1d4b05aed4a9969f4be6853c5'; + r'3d3f0ed7fa3631eca4e10d456c437f6fca8eedff'; abstract class _$LocalizationUIModel extends $Notifier { LocalizationUIState build(); diff --git a/lib/ui/home/performance/performance_ui_model.g.dart b/lib/ui/home/performance/performance_ui_model.g.dart index 2733524..70dc3d2 100644 --- a/lib/ui/home/performance/performance_ui_model.g.dart +++ b/lib/ui/home/performance/performance_ui_model.g.dart @@ -42,7 +42,7 @@ final class HomePerformanceUIModelProvider } String _$homePerformanceUIModelHash() => - r'c3c55c0470ef8c8be4915a1878deba332653ecde'; + r'4c5c33fe7d85dc8f6bf0d019c1b870d285d594ff'; abstract class _$HomePerformanceUIModel extends $Notifier { diff --git a/lib/ui/tools/tools_ui_model.dart b/lib/ui/tools/tools_ui_model.dart index e9eb890..bbe837e 100644 --- a/lib/ui/tools/tools_ui_model.dart +++ b/lib/ui/tools/tools_ui_model.dart @@ -19,7 +19,6 @@ import 'package:starcitizen_doctor/common/utils/log.dart'; import 'package:starcitizen_doctor/common/utils/multi_window_manager.dart'; import 'package:starcitizen_doctor/common/utils/provider.dart'; import 'package:starcitizen_doctor/provider/aria2c.dart'; -import 'package:starcitizen_doctor/ui/home/downloader/home_downloader_ui_model.dart'; import 'package:starcitizen_doctor/widgets/widgets.dart'; import 'package:url_launcher/url_launcher_string.dart'; import 'package:xml/xml.dart'; @@ -423,14 +422,11 @@ class ToolsUIModel extends _$ToolsUIModel { final aria2c = ref.read(aria2cModelProvider).aria2c!; // check download task list - for (var value in [...await aria2c.tellActive(), ...await aria2c.tellWaiting(0, 100000)]) { - final t = HomeDownloaderUIModel.getTaskTypeAndName(value); - if (t.key == "torrent" && t.value.contains("Data.p4k")) { - if (!context.mounted) return; - showToast(context, S.current.tools_action_info_p4k_download_in_progress); - state = state.copyWith(working: false); - return; - } + if (await aria2cManager.isNameInTask("Data.p4k")) { + if (!context.mounted) return; + showToast(context, S.current.tools_action_info_p4k_download_in_progress); + state = state.copyWith(working: false); + return; } if (torrentUrl == "") { diff --git a/lib/ui/tools/tools_ui_model.g.dart b/lib/ui/tools/tools_ui_model.g.dart index 8ae0dff..367e182 100644 --- a/lib/ui/tools/tools_ui_model.g.dart +++ b/lib/ui/tools/tools_ui_model.g.dart @@ -41,7 +41,7 @@ final class ToolsUIModelProvider } } -String _$toolsUIModelHash() => r'81a73aeccf978f7e620681eaf1a3d4182ff48f9e'; +String _$toolsUIModelHash() => r'885596b0df27191f2c69c571b0a1f60d9c6e31de'; abstract class _$ToolsUIModel extends $Notifier { ToolsUIState build(); diff --git a/rust/Cargo.lock b/rust/Cargo.lock index e005048..5bc9c55 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -17,6 +17,20 @@ version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" +[[package]] +name = "ahash" +version = "0.8.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a15f179cd60c4584b8a8c596927aadc462e27f2ca70c04e0071964a73ba7a75" +dependencies = [ + "cfg-if", + "getrandom 0.3.4", + "once_cell", + "serde", + "version_check", + "zerocopy", +] + [[package]] name = "aho-corasick" version = "1.1.3" @@ -401,6 +415,15 @@ version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" +[[package]] +name = "castaway" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dec551ab6e7578819132c713a93c022a05d60159dc86e7a7050223577484c55a" +dependencies = [ + "rustversion", +] + [[package]] name = "cc" version = "1.2.43" @@ -508,6 +531,21 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" +[[package]] +name = "compact_str" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fdb1325a1cece981e8a296ab8f0f9b63ae357bd0784a9faaf548cc7b480707a" +dependencies = [ + "castaway", + "cfg-if", + "itoa", + "rustversion", + "ryu", + "serde", + "static_assertions", +] + [[package]] name = "compression-codecs" version = "0.4.31" @@ -756,6 +794,15 @@ dependencies = [ "cc", ] +[[package]] +name = "dary_heap" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06d2e3287df1c007e74221c49ca10a95d557349e54b3a75dc2fb14712c751f04" +dependencies = [ + "serde", +] + [[package]] name = "dashmap" version = "5.5.3" @@ -1322,7 +1369,7 @@ dependencies = [ "idna", "ipnet", "once_cell", - "rand 0.9.2", + "rand", "ring", "thiserror 2.0.17", "tinyvec", @@ -1344,7 +1391,7 @@ dependencies = [ "moka", "once_cell", "parking_lot", - "rand 0.9.2", + "rand", "resolv-conf", "smallvec 1.15.1", "thiserror 2.0.17", @@ -1693,9 +1740,9 @@ dependencies = [ [[package]] name = "itertools" -version = "0.12.1" +version = "0.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" +checksum = "2b192c782037fadd9cfa75548310488aabdbf3d2da73885b31bd0abd03351285" dependencies = [ "either", ] @@ -1946,6 +1993,21 @@ dependencies = [ "rawpointer", ] +[[package]] +name = "ndarray" +version = "0.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c7c9125e8f6f10c9da3aad044cc918cf8784fa34de857b1aa68038eb05a50a9" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "portable-atomic", + "portable-atomic-util", + "rawpointer", +] + [[package]] name = "nix" version = "0.30.1" @@ -2172,7 +2234,7 @@ version = "2.0.0-rc.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fa7e49bd669d32d7bc2a15ec540a527e7764aec722a45467814005725bcd721" dependencies = [ - "ndarray", + "ndarray 0.16.1", "ort-sys", "smallvec 2.0.0-alpha.10", "tracing", @@ -2421,7 +2483,7 @@ dependencies = [ "bytes", "getrandom 0.3.4", "lru-slab", - "rand 0.9.2", + "rand", "ring", "rustc-hash", "rustls", @@ -2462,35 +2524,14 @@ version = "5.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" -[[package]] -name = "rand" -version = "0.8.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" -dependencies = [ - "libc", - "rand_chacha 0.3.1", - "rand_core 0.6.4", -] - [[package]] name = "rand" version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" dependencies = [ - "rand_chacha 0.9.0", - "rand_core 0.9.3", -] - -[[package]] -name = "rand_chacha" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" -dependencies = [ - "ppv-lite86", - "rand_core 0.6.4", + "rand_chacha", + "rand_core", ] [[package]] @@ -2500,16 +2541,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" dependencies = [ "ppv-lite86", - "rand_core 0.9.3", -] - -[[package]] -name = "rand_core" -version = "0.6.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" -dependencies = [ - "getrandom 0.2.16", + "rand_core", ] [[package]] @@ -2539,12 +2571,12 @@ dependencies = [ [[package]] name = "rayon-cond" -version = "0.3.0" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "059f538b55efd2309c9794130bc149c6a553db90e9d99c2030785c82f0bd7df9" +checksum = "2964d0cf57a3e7a06e8183d14a8b527195c706b7983549cd5462d5aa3747438f" dependencies = [ "either", - "itertools 0.11.0", + "itertools 0.14.0", "rayon", ] @@ -2695,7 +2727,7 @@ dependencies = [ "flutter_rust_bridge", "futures", "hickory-resolver", - "ndarray", + "ndarray 0.17.1", "notify-rust", "once_cell", "ort", @@ -3301,22 +3333,24 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokenizers" -version = "0.20.4" +version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b08cc37428a476fc9e20ac850132a513a2e1ce32b6a31addf2b74fa7033b905" +checksum = "6475a27088c98ea96d00b39a9ddfb63780d1ad4cceb6f48374349a96ab2b7842" dependencies = [ + "ahash", "aho-corasick", + "compact_str", + "dary_heap", "derive_builder", "esaxx-rs", - "getrandom 0.2.16", - "itertools 0.12.1", - "lazy_static", + "getrandom 0.3.4", + "itertools 0.14.0", "log", "macro_rules_attribute", "monostate", "onig", "paste", - "rand 0.8.5", + "rand", "rayon", "rayon-cond", "regex", @@ -3324,7 +3358,7 @@ dependencies = [ "serde", "serde_json", "spm_precompiled", - "thiserror 1.0.69", + "thiserror 2.0.17", "unicode-normalization-alignments", "unicode-segmentation", "unicode_categories", diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 3bee994..204095f 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -24,8 +24,8 @@ notify-rust = "4" asar = "0.3.0" walkdir = "2.5.0" ort = { version = "2.0.0-rc.10", features = ["xnnpack", "download-binaries", "ndarray"] } -tokenizers = { version = "0.20", default-features = false, features = ["onig"] } -ndarray = "0.16" +tokenizers = { version = "0.22", default-features = false, features = ["onig"] } +ndarray = "0.17" serde_json = "1.0" [target.'cfg(windows)'.dependencies] diff --git a/rust/src/ort_models/opus_mt.rs b/rust/src/ort_models/opus_mt.rs index 63ae56e..ef274d9 100644 --- a/rust/src/ort_models/opus_mt.rs +++ b/rust/src/ort_models/opus_mt.rs @@ -245,10 +245,16 @@ impl OpusMtModel { .context("Failed to create attention_mask array")?; // 3. 运行 encoder - let input_ids_value = - Value::from_array(input_ids_array).context("Failed to create input_ids value")?; - let attention_mask_value = Value::from_array(attention_mask_array.clone()) - .context("Failed to create attention_mask value")?; + let input_ids_value = Value::from_array(( + input_ids_array.shape().to_vec(), + input_ids_array.into_raw_vec_and_offset().0, + )) + .context("Failed to create input_ids value")?; + let attention_mask_value = Value::from_array(( + attention_mask_array.shape().to_vec(), + attention_mask_array.clone().into_raw_vec_and_offset().0, + )) + .context("Failed to create attention_mask value")?; let encoder_inputs = ort::inputs![ "input_ids" => input_ids_value, @@ -303,12 +309,21 @@ impl OpusMtModel { .context("Failed to create decoder input_ids")?; // 创建 ORT Value - let decoder_input_value = Value::from_array(decoder_input_ids) - .context("Failed to create decoder input value")?; - let encoder_hidden_value = Value::from_array(encoder_hidden_states.clone()) - .context("Failed to create encoder hidden value")?; - let encoder_mask_value = Value::from_array(encoder_attention_mask.clone()) - .context("Failed to create encoder mask value")?; + let decoder_input_value = Value::from_array(( + decoder_input_ids.shape().to_vec(), + decoder_input_ids.into_raw_vec_and_offset().0, + )) + .context("Failed to create decoder input value")?; + let encoder_hidden_value = Value::from_array(( + encoder_hidden_states.shape().to_vec(), + encoder_hidden_states.clone().into_raw_vec_and_offset().0, + )) + .context("Failed to create encoder hidden value")?; + let encoder_mask_value = Value::from_array(( + encoder_attention_mask.shape().to_vec(), + encoder_attention_mask.clone().into_raw_vec_and_offset().0, + )) + .context("Failed to create encoder mask value")?; // 运行 decoder let decoder_inputs = ort::inputs![