diff --git a/lib/common/rust/api/ort_api.dart b/lib/common/rust/api/ort_api.dart new file mode 100644 index 0000000..49973bf --- /dev/null +++ b/lib/common/rust/api/ort_api.dart @@ -0,0 +1,72 @@ +// This file is automatically generated, so please do not edit it. +// @generated by `flutter_rust_bridge`@ 2.11.1. + +// ignore_for_file: invalid_use_of_internal_member, unused_import, unnecessary_import + +import '../frb_generated.dart'; +import 'package:flutter_rust_bridge/flutter_rust_bridge_for_generated.dart'; + +/// 加载 ONNX 翻译模型 +/// +/// # Arguments +/// * `model_path` - 模型文件夹路径 +/// * `model_key` - 模型缓存键(用于标识模型,如 "zh-en") +/// * `quantization_suffix` - 量化后缀(如 "_q4", "_q8",空字符串表示使用默认模型) +/// +Future loadTranslationModel({ + required String modelPath, + required String modelKey, + required String quantizationSuffix, +}) => RustLib.instance.api.crateApiOrtApiLoadTranslationModel( + modelPath: modelPath, + modelKey: modelKey, + quantizationSuffix: quantizationSuffix, +); + +/// 翻译文本 +/// +/// # Arguments +/// * `model_key` - 模型缓存键(如 "zh-en") +/// * `text` - 要翻译的文本 +/// +/// # Returns +/// * `Result` - 翻译后的文本 +Future translateText({ + required String modelKey, + required String text, +}) => RustLib.instance.api.crateApiOrtApiTranslateText( + modelKey: modelKey, + text: text, +); + +/// 批量翻译文本 +/// +/// # Arguments +/// * `model_key` - 模型缓存键(如 "zh-en") +/// * `texts` - 要翻译的文本列表 +/// +/// # Returns +/// * `Result>` - 翻译后的文本列表 +Future> translateTextBatch({ + required String modelKey, + required List texts, +}) => RustLib.instance.api.crateApiOrtApiTranslateTextBatch( + modelKey: modelKey, + texts: texts, +); + +/// 卸载模型 +/// +/// # Arguments +/// * `model_key` - 模型缓存键(如 "zh-en") +/// +Future unloadTranslationModel({required String modelKey}) => RustLib + .instance + .api + .crateApiOrtApiUnloadTranslationModel(modelKey: modelKey); + +/// 清空所有已加载的模型 +/// +/// # Returns +Future clearAllModels() => + RustLib.instance.api.crateApiOrtApiClearAllModels(); diff --git a/lib/common/rust/frb_generated.dart b/lib/common/rust/frb_generated.dart index f7d0ba7..5cbea7e 100644 --- a/lib/common/rust/frb_generated.dart +++ b/lib/common/rust/frb_generated.dart @@ -5,6 +5,7 @@ import 'api/asar_api.dart'; import 'api/http_api.dart'; +import 'api/ort_api.dart'; import 'api/rs_process.dart'; import 'api/win32_api.dart'; import 'dart:async'; @@ -68,7 +69,7 @@ class RustLib extends BaseEntrypoint { String get codegenVersion => '2.11.1'; @override - int get rustContentHash => 1832496273; + int get rustContentHash => -706588047; static const kDefaultExternalLibraryLoaderConfig = ExternalLibraryLoaderConfig( @@ -79,6 +80,8 @@ class RustLib extends BaseEntrypoint { } abstract class RustLibApi extends BaseApi { + Future crateApiOrtApiClearAllModels(); + Future> crateApiHttpApiDnsLookupIps({required String host}); Future> crateApiHttpApiDnsLookupTxt({required String host}); @@ -96,6 +99,12 @@ abstract class RustLibApi extends BaseApi { required String asarPath, }); + Future crateApiOrtApiLoadTranslationModel({ + required String modelPath, + required String modelKey, + required String quantizationSuffix, + }); + Future crateApiAsarApiRsiLauncherAsarDataWriteMainJs({ required RsiLauncherAsarData that, required List content, @@ -122,6 +131,18 @@ abstract class RustLibApi extends BaseApi { required String workingDirectory, }); + Future crateApiOrtApiTranslateText({ + required String modelKey, + required String text, + }); + + Future> crateApiOrtApiTranslateTextBatch({ + required String modelKey, + required List texts, + }); + + Future crateApiOrtApiUnloadTranslationModel({required String modelKey}); + Future crateApiRsProcessWrite({ required int rsPid, required String data, @@ -136,6 +157,27 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { required super.portManager, }); + @override + Future crateApiOrtApiClearAllModels() { + return handler.executeNormal( + NormalTask( + callFfi: (port_) { + return wire.wire__crate__api__ort_api__clear_all_models(port_); + }, + codec: DcoCodec( + decodeSuccessData: dco_decode_unit, + decodeErrorData: dco_decode_AnyhowException, + ), + constMeta: kCrateApiOrtApiClearAllModelsConstMeta, + argValues: [], + apiImpl: this, + ), + ); + } + + TaskConstMeta get kCrateApiOrtApiClearAllModelsConstMeta => + const TaskConstMeta(debugName: "clear_all_models", argNames: []); + @override Future> crateApiHttpApiDnsLookupIps({required String host}) { return handler.executeNormal( @@ -268,6 +310,42 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { argNames: ["asarPath"], ); + @override + Future crateApiOrtApiLoadTranslationModel({ + required String modelPath, + required String modelKey, + required String quantizationSuffix, + }) { + return handler.executeNormal( + NormalTask( + callFfi: (port_) { + var arg0 = cst_encode_String(modelPath); + var arg1 = cst_encode_String(modelKey); + var arg2 = cst_encode_String(quantizationSuffix); + return wire.wire__crate__api__ort_api__load_translation_model( + port_, + arg0, + arg1, + arg2, + ); + }, + codec: DcoCodec( + decodeSuccessData: dco_decode_unit, + decodeErrorData: dco_decode_AnyhowException, + ), + constMeta: kCrateApiOrtApiLoadTranslationModelConstMeta, + argValues: [modelPath, modelKey, quantizationSuffix], + apiImpl: this, + ), + ); + } + + TaskConstMeta get kCrateApiOrtApiLoadTranslationModelConstMeta => + const TaskConstMeta( + debugName: "load_translation_model", + argNames: ["modelPath", "modelKey", "quantizationSuffix"], + ); + @override Future crateApiAsarApiRsiLauncherAsarDataWriteMainJs({ required RsiLauncherAsarData that, @@ -443,6 +521,102 @@ class RustLibApiImpl extends RustLibApiImplPlatform implements RustLibApi { argNames: ["executable", "arguments", "workingDirectory", "streamSink"], ); + @override + Future crateApiOrtApiTranslateText({ + required String modelKey, + required String text, + }) { + return handler.executeNormal( + NormalTask( + callFfi: (port_) { + var arg0 = cst_encode_String(modelKey); + var arg1 = cst_encode_String(text); + return wire.wire__crate__api__ort_api__translate_text( + port_, + arg0, + arg1, + ); + }, + codec: DcoCodec( + decodeSuccessData: dco_decode_String, + decodeErrorData: dco_decode_AnyhowException, + ), + constMeta: kCrateApiOrtApiTranslateTextConstMeta, + argValues: [modelKey, text], + apiImpl: this, + ), + ); + } + + TaskConstMeta get kCrateApiOrtApiTranslateTextConstMeta => + const TaskConstMeta( + debugName: "translate_text", + argNames: ["modelKey", "text"], + ); + + @override + Future> crateApiOrtApiTranslateTextBatch({ + required String modelKey, + required List texts, + }) { + return handler.executeNormal( + NormalTask( + callFfi: (port_) { + var arg0 = cst_encode_String(modelKey); + var arg1 = cst_encode_list_String(texts); + return wire.wire__crate__api__ort_api__translate_text_batch( + port_, + arg0, + arg1, + ); + }, + codec: DcoCodec( + decodeSuccessData: dco_decode_list_String, + decodeErrorData: dco_decode_AnyhowException, + ), + constMeta: kCrateApiOrtApiTranslateTextBatchConstMeta, + argValues: [modelKey, texts], + apiImpl: this, + ), + ); + } + + TaskConstMeta get kCrateApiOrtApiTranslateTextBatchConstMeta => + const TaskConstMeta( + debugName: "translate_text_batch", + argNames: ["modelKey", "texts"], + ); + + @override + Future crateApiOrtApiUnloadTranslationModel({ + required String modelKey, + }) { + return handler.executeNormal( + NormalTask( + callFfi: (port_) { + var arg0 = cst_encode_String(modelKey); + return wire.wire__crate__api__ort_api__unload_translation_model( + port_, + arg0, + ); + }, + codec: DcoCodec( + decodeSuccessData: dco_decode_unit, + decodeErrorData: dco_decode_AnyhowException, + ), + constMeta: kCrateApiOrtApiUnloadTranslationModelConstMeta, + argValues: [modelKey], + apiImpl: this, + ), + ); + } + + TaskConstMeta get kCrateApiOrtApiUnloadTranslationModelConstMeta => + const TaskConstMeta( + debugName: "unload_translation_model", + argNames: ["modelKey"], + ); + @override Future crateApiRsProcessWrite({ required int rsPid, diff --git a/lib/common/rust/frb_generated.io.dart b/lib/common/rust/frb_generated.io.dart index e9a0a38..190df88 100644 --- a/lib/common/rust/frb_generated.io.dart +++ b/lib/common/rust/frb_generated.io.dart @@ -5,6 +5,7 @@ import 'api/asar_api.dart'; import 'api/http_api.dart'; +import 'api/ort_api.dart'; import 'api/rs_process.dart'; import 'api/win32_api.dart'; import 'dart:async'; @@ -614,6 +615,18 @@ class RustLibWire implements BaseWire { late final _store_dart_post_cobject = _store_dart_post_cobjectPtr .asFunction(); + void wire__crate__api__ort_api__clear_all_models(int port_) { + return _wire__crate__api__ort_api__clear_all_models(port_); + } + + late final _wire__crate__api__ort_api__clear_all_modelsPtr = + _lookup>( + 'frbgen_starcitizen_doctor_wire__crate__api__ort_api__clear_all_models', + ); + late final _wire__crate__api__ort_api__clear_all_models = + _wire__crate__api__ort_api__clear_all_modelsPtr + .asFunction(); + void wire__crate__api__http_api__dns_lookup_ips( int port_, ffi.Pointer host, @@ -733,6 +746,44 @@ class RustLibWire implements BaseWire { void Function(int, ffi.Pointer) >(); + void wire__crate__api__ort_api__load_translation_model( + int port_, + ffi.Pointer model_path, + ffi.Pointer model_key, + ffi.Pointer quantization_suffix, + ) { + return _wire__crate__api__ort_api__load_translation_model( + port_, + model_path, + model_key, + quantization_suffix, + ); + } + + late final _wire__crate__api__ort_api__load_translation_modelPtr = + _lookup< + ffi.NativeFunction< + ffi.Void Function( + ffi.Int64, + ffi.Pointer, + ffi.Pointer, + ffi.Pointer, + ) + > + >( + 'frbgen_starcitizen_doctor_wire__crate__api__ort_api__load_translation_model', + ); + late final _wire__crate__api__ort_api__load_translation_model = + _wire__crate__api__ort_api__load_translation_modelPtr + .asFunction< + void Function( + int, + ffi.Pointer, + ffi.Pointer, + ffi.Pointer, + ) + >(); + void wire__crate__api__asar_api__rsi_launcher_asar_data_write_main_js( int port_, ffi.Pointer that, @@ -898,6 +949,95 @@ class RustLibWire implements BaseWire { ) >(); + void wire__crate__api__ort_api__translate_text( + int port_, + ffi.Pointer model_key, + ffi.Pointer text, + ) { + return _wire__crate__api__ort_api__translate_text(port_, model_key, text); + } + + late final _wire__crate__api__ort_api__translate_textPtr = + _lookup< + ffi.NativeFunction< + ffi.Void Function( + ffi.Int64, + ffi.Pointer, + ffi.Pointer, + ) + > + >('frbgen_starcitizen_doctor_wire__crate__api__ort_api__translate_text'); + late final _wire__crate__api__ort_api__translate_text = + _wire__crate__api__ort_api__translate_textPtr + .asFunction< + void Function( + int, + ffi.Pointer, + ffi.Pointer, + ) + >(); + + void wire__crate__api__ort_api__translate_text_batch( + int port_, + ffi.Pointer model_key, + ffi.Pointer texts, + ) { + return _wire__crate__api__ort_api__translate_text_batch( + port_, + model_key, + texts, + ); + } + + late final _wire__crate__api__ort_api__translate_text_batchPtr = + _lookup< + ffi.NativeFunction< + ffi.Void Function( + ffi.Int64, + ffi.Pointer, + ffi.Pointer, + ) + > + >( + 'frbgen_starcitizen_doctor_wire__crate__api__ort_api__translate_text_batch', + ); + late final _wire__crate__api__ort_api__translate_text_batch = + _wire__crate__api__ort_api__translate_text_batchPtr + .asFunction< + void Function( + int, + ffi.Pointer, + ffi.Pointer, + ) + >(); + + void wire__crate__api__ort_api__unload_translation_model( + int port_, + ffi.Pointer model_key, + ) { + return _wire__crate__api__ort_api__unload_translation_model( + port_, + model_key, + ); + } + + late final _wire__crate__api__ort_api__unload_translation_modelPtr = + _lookup< + ffi.NativeFunction< + ffi.Void Function( + ffi.Int64, + ffi.Pointer, + ) + > + >( + 'frbgen_starcitizen_doctor_wire__crate__api__ort_api__unload_translation_model', + ); + late final _wire__crate__api__ort_api__unload_translation_model = + _wire__crate__api__ort_api__unload_translation_modelPtr + .asFunction< + void Function(int, ffi.Pointer) + >(); + void wire__crate__api__rs_process__write( int port_, int rs_pid, diff --git a/lib/main.dart b/lib/main.dart index 5299e5f..9cecfd4 100644 --- a/lib/main.dart +++ b/lib/main.dart @@ -29,10 +29,7 @@ void main(List args) async { Future _initWindow() async { await windowManager.ensureInitialized(); - await windowManager.setTitleBarStyle( - TitleBarStyle.hidden, - windowButtonVisibility: false, - ); + await windowManager.setTitleBarStyle(TitleBarStyle.hidden, windowButtonVisibility: false); await windowManager.setSize(const Size(1280, 810)); await windowManager.setMinimumSize(const Size(1280, 810)); await windowManager.center(animate: true); @@ -73,18 +70,22 @@ class App extends HookConsumerWidget with WindowListener { ); }, theme: FluentThemeData( - brightness: Brightness.dark, - fontFamily: "SourceHanSansCN-Regular", - navigationPaneTheme: NavigationPaneThemeData( - backgroundColor: appState.themeConf.backgroundColor, + brightness: Brightness.dark, + fontFamily: "SourceHanSansCN-Regular", + navigationPaneTheme: NavigationPaneThemeData(backgroundColor: appState.themeConf.backgroundColor), + menuColor: appState.themeConf.menuColor, + micaBackgroundColor: appState.themeConf.micaColor, + buttonTheme: ButtonThemeData( + defaultButtonStyle: ButtonStyle( + shape: WidgetStateProperty.all( + RoundedRectangleBorder( + borderRadius: BorderRadius.circular(4), + side: BorderSide(color: Colors.white.withValues(alpha: .01)), + ), + ), ), - menuColor: appState.themeConf.menuColor, - micaBackgroundColor: appState.themeConf.micaColor, - buttonTheme: ButtonThemeData( - defaultButtonStyle: ButtonStyle( - shape: WidgetStateProperty.all(RoundedRectangleBorder( - borderRadius: BorderRadius.circular(4), side: BorderSide(color: Colors.white.withValues(alpha: .01)))), - ))), + ), + ), locale: appState.appLocale, debugShowCheckedModeBanner: false, routeInformationParser: router.routeInformationParser, @@ -112,42 +113,28 @@ Widget _defaultWebviewTitleBar(BuildContext context) { final state = TitleBarWebViewState.of(context); final controller = TitleBarWebViewController.of(context); return FluentTheme( - data: FluentThemeData.dark(), - child: Row( - crossAxisAlignment: CrossAxisAlignment.center, - children: [ - if (Platform.isMacOS) const SizedBox(width: 96), - IconButton( - onPressed: !state.canGoBack ? null : controller.back, - icon: const Icon(FluentIcons.chevron_left), - ), - const SizedBox(width: 12), - IconButton( - onPressed: !state.canGoForward ? null : controller.forward, - icon: const Icon(FluentIcons.chevron_right), - ), - const SizedBox(width: 12), - if (state.isLoading) - IconButton( - onPressed: controller.stop, - icon: const Icon(FluentIcons.chrome_close), - ) - else - IconButton( - onPressed: controller.reload, - icon: const Icon(FluentIcons.refresh), - ), - const SizedBox(width: 12), - (state.isLoading) - ? const SizedBox( - width: 24, - height: 24, - child: ProgressRing(), - ) - : const SizedBox(width: 24), - const SizedBox(width: 12), - SelectableText(state.url ?? ""), - const Spacer() - ], - )); + data: FluentThemeData.dark(), + child: Row( + crossAxisAlignment: CrossAxisAlignment.center, + children: [ + if (Platform.isMacOS) const SizedBox(width: 96), + IconButton(onPressed: !state.canGoBack ? null : controller.back, icon: const Icon(FluentIcons.chevron_left)), + const SizedBox(width: 12), + IconButton( + onPressed: !state.canGoForward ? null : controller.forward, + icon: const Icon(FluentIcons.chevron_right), + ), + const SizedBox(width: 12), + if (state.isLoading) + IconButton(onPressed: controller.stop, icon: const Icon(FluentIcons.chrome_close)) + else + IconButton(onPressed: controller.reload, icon: const Icon(FluentIcons.refresh)), + const SizedBox(width: 12), + (state.isLoading) ? const SizedBox(width: 24, height: 24, child: ProgressRing()) : const SizedBox(width: 24), + const SizedBox(width: 12), + SelectableText(state.url ?? ""), + const Spacer(), + ], + ), + ); } diff --git a/pubspec.lock b/pubspec.lock index 9d7b740..d8e8ad2 100644 --- a/pubspec.lock +++ b/pubspec.lock @@ -5,34 +5,34 @@ packages: dependency: transitive description: name: _fe_analyzer_shared - sha256: da0d9209ca76bde579f2da330aeb9df62b6319c834fa7baae052021b0462401f + sha256: f0bb5d1648339c8308cc0b9838d8456b3cfe5c91f9dc1a735b4d003269e5da9a url: "https://pub.dev" source: hosted - version: "85.0.0" + version: "88.0.0" analyzer: dependency: transitive description: name: analyzer - sha256: f4ad0fea5f102201015c9aae9d93bc02f75dd9491529a8c21f88d17a8523d44c + sha256: "0b7b9c329d2879f8f05d6c05b32ee9ec025f39b077864bdb5ac9a7b63418a98f" url: "https://pub.dev" source: hosted - version: "7.6.0" + version: "8.1.1" analyzer_buffer: dependency: transitive description: name: analyzer_buffer - sha256: f7833bee67c03c37241c67f8741b17cc501b69d9758df7a5a4a13ed6c947be43 + sha256: aba2f75e63b3135fd1efaa8b6abefe1aa6e41b6bd9806221620fa48f98156033 url: "https://pub.dev" source: hosted - version: "0.1.10" + version: "0.1.11" analyzer_plugin: dependency: transitive description: name: analyzer_plugin - sha256: a5ab7590c27b779f3d4de67f31c4109dbe13dd7339f86461a6f2a8ab2594d8ce + sha256: dd574a0ab77de88b7d9c12bc4b626109a5ca9078216a79041a5c24c3a1bd103c url: "https://pub.dev" source: hosted - version: "0.13.4" + version: "0.13.7" archive: dependency: "direct main" description: @@ -278,42 +278,42 @@ packages: dependency: "direct dev" description: name: custom_lint - sha256: "78085fbe842de7c5bef92de811ca81536968dbcbbcdac5c316711add2d15e796" + sha256: "751ee9440920f808266c3ec2553420dea56d3c7837dd2d62af76b11be3fcece5" url: "https://pub.dev" source: hosted - version: "0.8.0" + version: "0.8.1" custom_lint_builder: dependency: transitive description: name: custom_lint_builder - sha256: cc5532d5733d4eccfccaaec6070a1926e9f21e613d93ad0927fad020b95c9e52 + sha256: "1128db6f58e71d43842f3b9be7465c83f0c47f4dd8918f878dd6ad3b72a32072" url: "https://pub.dev" source: hosted - version: "0.8.0" + version: "0.8.1" custom_lint_core: dependency: transitive description: name: custom_lint_core - sha256: cc4684d22ca05bf0a4a51127e19a8aea576b42079ed2bc9e956f11aaebe35dd1 + sha256: "85b339346154d5646952d44d682965dfe9e12cae5febd706f0db3aa5010d6423" url: "https://pub.dev" source: hosted - version: "0.8.0" + version: "0.8.1" custom_lint_visitor: dependency: transitive description: name: custom_lint_visitor - sha256: "4a86a0d8415a91fbb8298d6ef03e9034dc8e323a599ddc4120a0e36c433983a2" + sha256: "446d68322747ec1c36797090de776aa72228818d3d80685a91ff524d163fee6d" url: "https://pub.dev" source: hosted - version: "1.0.0+7.7.0" + version: "1.0.0+8.1.1" dart_style: dependency: transitive description: name: dart_style - sha256: "8a0e5fba27e8ee025d2ffb4ee820b4e6e2cf5e4246a6b1a477eb66866947e0bb" + sha256: c87dfe3d56f183ffe9106a18aebc6db431fc7c98c31a54b952a77f3d54a85697 url: "https://pub.dev" source: hosted - version: "3.1.1" + version: "3.1.2" dbus: dependency: transitive description: @@ -422,10 +422,10 @@ packages: dependency: "direct main" description: name: file_picker - sha256: f2d9f173c2c14635cc0e9b14c143c49ef30b4934e8d1d274d6206fcb0086a06f + sha256: f8f4ea435f791ab1f817b4e338ed958cb3d04ba43d6736ffc39958d950754967 url: "https://pub.dev" source: hosted - version: "10.3.3" + version: "10.3.6" file_sizes: dependency: "direct main" description: @@ -618,10 +618,10 @@ packages: dependency: "direct main" description: name: hexcolor - sha256: c07f4bbb9095df87eeca87e7c69e8c3d60f70c66102d7b8d61c4af0453add3f6 + sha256: "0f237eed7db96ebacd8fda00d17f5ae262aaa84c213d53457c06b1dcbdfa81f2" url: "https://pub.dev" source: hosted - version: "3.0.1" + version: "3.0.2" highlight: dependency: transitive description: @@ -666,10 +666,10 @@ packages: dependency: "direct overridden" description: name: http - sha256: bb2ce4590bc2667c96f318d68cac1b5a7987ec819351d32b1c987239a815e007 + sha256: "87721a4a50b19c7f1d49001e51409bddc46303966ce89a65af4f4e6004896412" url: "https://pub.dev" source: hosted - version: "1.5.0" + version: "1.6.0" http_client_helper: dependency: transitive description: @@ -874,10 +874,10 @@ packages: dependency: "direct main" description: name: meta - sha256: e3641ec5d63ebf0d9b41bd43201a66e3fc79a65db5f61fc181f04cd27aab950c + sha256: "23f08335362185a5ea2ad3a4e597f1375e78bce8a040df5c600c8d3552ef2394" url: "https://pub.dev" source: hosted - version: "1.16.0" + version: "1.17.0" mime: dependency: transitive description: @@ -890,10 +890,10 @@ packages: dependency: transitive description: name: mockito - sha256: "2314cbe9165bcd16106513df9cf3c3224713087f09723b128928dc11a4379f99" + sha256: "4feb43bc4eb6c03e832f5fcd637d1abb44b98f9cfa245c58e27382f58859f8f6" url: "https://pub.dev" source: hosted - version: "5.5.0" + version: "5.5.1" msix: dependency: "direct dev" description: @@ -1261,10 +1261,10 @@ packages: dependency: transitive description: name: source_gen - sha256: "7b19d6ba131c6eb98bfcbf8d56c1a7002eba438af2e7ae6f8398b2b0f4f381e3" + sha256: "9098ab86015c4f1d8af6486b547b11100e73b193e1899015033cb3e14ad20243" url: "https://pub.dev" source: hosted - version: "3.1.0" + version: "4.0.2" source_helper: dependency: transitive description: @@ -1297,14 +1297,6 @@ packages: url: "https://pub.dev" source: hosted version: "1.10.1" - sprintf: - dependency: transitive - description: - name: sprintf - sha256: "1fc9ffe69d4df602376b52949af107d8f5703b77cda567c4d7d86a0693120f23" - url: "https://pub.dev" - source: hosted - version: "7.0.0" stack_trace: dependency: transitive description: @@ -1373,26 +1365,26 @@ packages: dependency: transitive description: name: test - sha256: "65e29d831719be0591f7b3b1a32a3cda258ec98c58c7b25f7b84241bc31215bb" + sha256: "75906bf273541b676716d1ca7627a17e4c4070a3a16272b7a3dc7da3b9f3f6b7" url: "https://pub.dev" source: hosted - version: "1.26.2" + version: "1.26.3" test_api: dependency: transitive description: name: test_api - sha256: "522f00f556e73044315fa4585ec3270f1808a4b186c936e612cab0b565ff1e00" + sha256: ab2726c1a94d3176a45960b6234466ec367179b87dd74f1611adb1f3b5fb9d55 url: "https://pub.dev" source: hosted - version: "0.7.6" + version: "0.7.7" test_core: dependency: transitive description: name: test_core - sha256: "80bf5a02b60af04b09e14f6fe68b921aad119493e26e490deaca5993fef1b05a" + sha256: "0cc24b5ff94b38d2ae73e1eb43cc302b77964fbf67abad1e296025b78deb53d0" url: "https://pub.dev" source: hosted - version: "0.6.11" + version: "0.6.12" timing: dependency: transitive description: @@ -1485,10 +1477,10 @@ packages: dependency: "direct main" description: name: uuid - sha256: a5be9ef6618a7ac1e964353ef476418026db906c4facdedaa299b7a2e71690ff + sha256: a11b666489b1954e01d992f3d601b1804a33937b5a8fe677bd26b8a9f96f96e8 url: "https://pub.dev" source: hosted - version: "4.5.1" + version: "4.5.2" vector_graphics: dependency: transitive description: diff --git a/pubspec.yaml b/pubspec.yaml index 5df54c1..fdc03f6 100644 --- a/pubspec.yaml +++ b/pubspec.yaml @@ -33,20 +33,20 @@ dependencies: markdown_widget: ^2.3.2+8 extended_image: ^10.0.1 device_info_plus: ^12.2.0 - file_picker: ^10.3.3 + file_picker: ^10.3.6 file_sizes: ^1.0.6 desktop_webview_window: ^0.2.3 flutter_svg: ^2.2.2 archive: ^4.0.7 jwt_decode: ^0.3.1 - uuid: ^4.5.1 + uuid: ^4.5.2 flutter_tilt: ^3.3.2 card_swiper: ^3.0.1 ffi: ^2.1.4 flutter_rust_bridge: ^2.11.1 freezed_annotation: ^3.1.0 - meta: ^1.16.0 - hexcolor: ^3.0.1 + meta: ^1.17.0 + hexcolor: ^3.0.2 html: ^0.15.6 fixnum: ^1.1.1 rust_builder: @@ -54,7 +54,6 @@ dependencies: aria2: git: https://github.com/xkeyC/dart_aria2_rpc.git # path: ../../xkeyC/dart_aria2_rpc - # path: ../../xkeyC/dart_aria2_rpc intl: any synchronized: ^3.4.0 super_sliver_list: ^0.4.1 @@ -69,7 +68,7 @@ dependencies: crypto: ^3.0.7 xml: ^6.6.1 dependency_overrides: - http: ^1.5.0 + http: ^1.6.0 intl: ^0.20.2 dev_dependencies: @@ -81,7 +80,7 @@ dev_dependencies: freezed: ^3.2.3 json_serializable: ^6.11.1 riverpod_generator: ^3.0.3 - custom_lint: ^0.8.0 + custom_lint: ^0.8.1 riverpod_lint: ^3.0.3 ffigen: ^20.0.0 sct_dev_tools: diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 0da8580..e005048 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -316,12 +316,24 @@ dependencies = [ "windows-link 0.2.1", ] +[[package]] +name = "base64" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" + [[package]] name = "base64" version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" +[[package]] +name = "base64ct" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55248b47b0caf0546f7988906588779981c43bb1bc9d0c44087278f80cdb44ba" + [[package]] name = "bitflags" version = "2.10.0" @@ -630,6 +642,16 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "crossbeam-deque" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9dd111b7b7f7d55b72c0a6ae361660ee5853c9af73f70c3c2ef6858b950e2e51" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + [[package]] name = "crossbeam-epoch" version = "0.9.18" @@ -655,14 +677,38 @@ dependencies = [ "typenum", ] +[[package]] +name = "darling" +version = "0.20.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee" +dependencies = [ + "darling_core 0.20.11", + "darling_macro 0.20.11", +] + [[package]] name = "darling" version = "0.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9cdf337090841a411e2a7f3deb9187445851f91b309c0c0a29e05f74a00a48c0" dependencies = [ - "darling_core", - "darling_macro", + "darling_core 0.21.3", + "darling_macro 0.21.3", +] + +[[package]] +name = "darling_core" +version = "0.20.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d00b9596d185e565c2207a0b01f8bd1a135483d02d9b7b0a54b11da8d53412e" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim", + "syn", ] [[package]] @@ -679,13 +725,24 @@ dependencies = [ "syn", ] +[[package]] +name = "darling_macro" +version = "0.20.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead" +dependencies = [ + "darling_core 0.20.11", + "quote", + "syn", +] + [[package]] name = "darling_macro" version = "0.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d38308df82d1080de0afee5d069fa14b0326a88c14f15c5ccda35b4a6c414c81" dependencies = [ - "darling_core", + "darling_core 0.21.3", "quote", "syn", ] @@ -729,6 +786,16 @@ dependencies = [ "syn", ] +[[package]] +name = "der" +version = "0.7.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7c1832837b905bbfb5101e07cc24c8deddf52f93225eee6ead5f4d63d53ddcb" +dependencies = [ + "pem-rfc7468", + "zeroize", +] + [[package]] name = "deranged" version = "0.5.4" @@ -739,6 +806,37 @@ dependencies = [ "serde_core", ] +[[package]] +name = "derive_builder" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "507dfb09ea8b7fa618fcf76e953f4f5e192547945816d5358edffe39f6f94947" +dependencies = [ + "derive_builder_macro", +] + +[[package]] +name = "derive_builder_core" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d5bcf7b024d6835cfb3d473887cd966994907effbe9227e8c8219824d06c4e8" +dependencies = [ + "darling 0.20.11", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "derive_builder_macro" +version = "0.20.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab63b0e2bf4d5928aff72e83a7dace85d7bba5fe12dcc3c5a572d78caffd3f3c" +dependencies = [ + "derive_builder_core", + "syn", +] + [[package]] name = "digest" version = "0.10.7" @@ -865,6 +963,12 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "esaxx-rs" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d817e038c30374a4bcb22f94d0a8a0e216958d4c3dcde369b1439fec4bdda6e6" + [[package]] name = "event-listener" version = "5.4.1" @@ -902,6 +1006,18 @@ version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" +[[package]] +name = "filetime" +version = "0.2.26" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc0505cd1b6fa6580283f6bdf70a73fcf4aba1184038c90902b92b3dd0df63ed" +dependencies = [ + "cfg-if", + "libc", + "libredox", + "windows-sys 0.60.2", +] + [[package]] name = "find-msvc-tools" version = "0.1.4" @@ -1206,7 +1322,7 @@ dependencies = [ "idna", "ipnet", "once_cell", - "rand", + "rand 0.9.2", "ring", "thiserror 2.0.17", "tinyvec", @@ -1228,9 +1344,9 @@ dependencies = [ "moka", "once_cell", "parking_lot", - "rand", + "rand 0.9.2", "resolv-conf", - "smallvec", + "smallvec 1.15.1", "thiserror 2.0.17", "tokio", "tracing", @@ -1293,7 +1409,7 @@ dependencies = [ "itoa", "pin-project-lite", "pin-utils", - "smallvec", + "smallvec 1.15.1", "tokio", "want", ] @@ -1337,7 +1453,7 @@ version = "0.1.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c6995591a8f1380fcb4ba966a252a4b29188d51d2b89e3a252f5305be65aea8" dependencies = [ - "base64", + "base64 0.22.1", "bytes", "futures-channel", "futures-core", @@ -1418,7 +1534,7 @@ dependencies = [ "icu_normalizer_data", "icu_properties", "icu_provider", - "smallvec", + "smallvec 1.15.1", "zerovec", ] @@ -1480,7 +1596,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3b0875f23caa03898994f6ddc501886a45c7d3d62d04d2d90788d47be1b1e4de" dependencies = [ "idna_adapter", - "smallvec", + "smallvec 1.15.1", "utf8_iter", ] @@ -1575,6 +1691,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.15" @@ -1603,6 +1728,17 @@ version = "0.2.177" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2874a2af47a2325c2001a6e6fad9b16a53b802102b528163885171cf92b15976" +[[package]] +name = "libredox" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "416f7e718bdb06000964960ffa43b4335ad4012ae8b99060261aa4a8088d5ccb" +dependencies = [ + "bitflags", + "libc", + "redox_syscall", +] + [[package]] name = "linux-raw-sys" version = "0.11.0" @@ -1654,6 +1790,32 @@ dependencies = [ "time", ] +[[package]] +name = "macro_rules_attribute" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65049d7923698040cd0b1ddcced9b0eb14dd22c5f86ae59c3740eab64a676520" +dependencies = [ + "macro_rules_attribute-proc_macro", + "paste", +] + +[[package]] +name = "macro_rules_attribute-proc_macro" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "670fdfda89751bc4a84ac13eaa63e205cf0fd22b4c9a5fbfa085b63c1f1d3a30" + +[[package]] +name = "matrixmultiply" +version = "0.3.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08" +dependencies = [ + "autocfg", + "rawpointer", +] + [[package]] name = "md-5" version = "0.10.6" @@ -1725,11 +1887,33 @@ dependencies = [ "parking_lot", "portable-atomic", "rustc_version", - "smallvec", + "smallvec 1.15.1", "tagptr", "uuid", ] +[[package]] +name = "monostate" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3341a273f6c9d5bef1908f17b7267bbab0e95c9bf69a0d4dcf8e9e1b2c76ef67" +dependencies = [ + "monostate-impl", + "serde", + "serde_core", +] + +[[package]] +name = "monostate-impl" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e4db6d5580af57bf992f59068d4ea26fd518574ff48d7639b255a36f9de6e7e9" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "native-tls" version = "0.2.14" @@ -1747,6 +1931,21 @@ dependencies = [ "tempfile", ] +[[package]] +name = "ndarray" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "portable-atomic", + "portable-atomic-util", + "rawpointer", +] + [[package]] name = "nix" version = "0.30.1" @@ -1784,12 +1983,30 @@ dependencies = [ "zbus", ] +[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", +] + [[package]] name = "num-conv" version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.19" @@ -1873,6 +2090,28 @@ version = "1.70.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" +[[package]] +name = "onig" +version = "6.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "336b9c63443aceef14bea841b899035ae3abe89b7c486aaf4c5bd8aafedac3f0" +dependencies = [ + "bitflags", + "libc", + "once_cell", + "onig_sys", +] + +[[package]] +name = "onig_sys" +version = "69.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7f86c6eef3d6df15f23bcfb6af487cbd2fed4e5581d58d5bf1f5f8b7f6727dc" +dependencies = [ + "cc", + "pkg-config", +] + [[package]] name = "openssl" version = "0.10.74" @@ -1927,6 +2166,31 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "ort" +version = "2.0.0-rc.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fa7e49bd669d32d7bc2a15ec540a527e7764aec722a45467814005725bcd721" +dependencies = [ + "ndarray", + "ort-sys", + "smallvec 2.0.0-alpha.10", + "tracing", +] + +[[package]] +name = "ort-sys" +version = "2.0.0-rc.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2aba9f5c7c479925205799216e7e5d07cc1d4fa76ea8058c60a9a30f6a4e890" +dependencies = [ + "flate2", + "pkg-config", + "sha2", + "tar", + "ureq", +] + [[package]] name = "oslog" version = "0.2.0" @@ -1969,10 +2233,25 @@ dependencies = [ "cfg-if", "libc", "redox_syscall", - "smallvec", + "smallvec 1.15.1", "windows-link 0.2.1", ] +[[package]] +name = "paste" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" + +[[package]] +name = "pem-rfc7468" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88b39c9bfcfc231068454382784bb460aae594343fb030d46e9f50a645418412" +dependencies = [ + "base64ct", +] + [[package]] name = "percent-encoding" version = "2.3.2" @@ -2037,6 +2316,15 @@ version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483" +[[package]] +name = "portable-atomic-util" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8a2f0d8d040d7848a709caf78912debcc3f33ee4b3cac47d73d1e1069e83507" +dependencies = [ + "portable-atomic", +] + [[package]] name = "potential_utf" version = "0.1.3" @@ -2133,7 +2421,7 @@ dependencies = [ "bytes", "getrandom 0.3.4", "lru-slab", - "rand", + "rand 0.9.2", "ring", "rustc-hash", "rustls", @@ -2174,14 +2462,35 @@ version = "5.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha 0.3.1", + "rand_core 0.6.4", +] + [[package]] name = "rand" version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" dependencies = [ - "rand_chacha", - "rand_core", + "rand_chacha 0.9.0", + "rand_core 0.9.3", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core 0.6.4", ] [[package]] @@ -2191,7 +2500,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" dependencies = [ "ppv-lite86", - "rand_core", + "rand_core 0.9.3", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom 0.2.16", ] [[package]] @@ -2203,6 +2521,43 @@ dependencies = [ "getrandom 0.3.4", ] +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + +[[package]] +name = "rayon" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-cond" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "059f538b55efd2309c9794130bc149c6a553db90e9d99c2030785c82f0bd7df9" +dependencies = [ + "either", + "itertools 0.11.0", + "rayon", +] + +[[package]] +name = "rayon-core" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22e18b0f0062d30d4230b2e85ff77fdfe4326feb054b9783a3460d8435c8ab91" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + [[package]] name = "redox_syscall" version = "0.5.18" @@ -2268,7 +2623,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9d0946410b9f7b082a427e4ef5c8ff541a88b357bc6c637c40db3a68ac70a36f" dependencies = [ "async-compression", - "base64", + "base64 0.22.1", "bytes", "cookie", "cookie_store", @@ -2340,10 +2695,14 @@ dependencies = [ "flutter_rust_bridge", "futures", "hickory-resolver", + "ndarray", "notify-rust", "once_cell", + "ort", "reqwest", "scopeguard", + "serde_json", + "tokenizers", "tokio", "url", "walkdir", @@ -2581,7 +2940,7 @@ version = "3.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "aa66c845eee442168b2c8134fec70ac50dc20e760769c8ba0ad1319ca1959b04" dependencies = [ - "base64", + "base64 0.22.1", "chrono", "hex", "indexmap 1.9.3", @@ -2600,7 +2959,7 @@ version = "3.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b91a903660542fced4e99881aa481bdbaec1634568ee02e0b8bd57c64cb38955" dependencies = [ - "darling", + "darling 0.21.3", "proc-macro2", "quote", "syn", @@ -2659,6 +3018,12 @@ version = "1.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" +[[package]] +name = "smallvec" +version = "2.0.0-alpha.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51d44cfb396c3caf6fbfd0ab422af02631b69ddd96d2eff0b0f0724f9024051b" + [[package]] name = "socket2" version = "0.5.10" @@ -2679,6 +3044,29 @@ dependencies = [ "windows-sys 0.60.2", ] +[[package]] +name = "socks" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0c3dbbd9ae980613c6dd8e28a9407b50509d3803b57624d5dfe8315218cd58b" +dependencies = [ + "byteorder", + "libc", + "winapi", +] + +[[package]] +name = "spm_precompiled" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5851699c4033c63636f7ea4cf7b7c1f1bf06d0cc03cfb42e711de5a5c46cf326" +dependencies = [ + "base64 0.13.1", + "nom", + "serde", + "unicode-segmentation", +] + [[package]] name = "stable_deref_trait" version = "1.2.1" @@ -2761,6 +3149,17 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417" +[[package]] +name = "tar" +version = "0.4.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d863878d212c87a19c1a610eb53bb01fe12951c0501cf5a0d65f724914a667a" +dependencies = [ + "filetime", + "libc", + "xattr", +] + [[package]] name = "tauri-winrt-notification" version = "0.7.2" @@ -2900,6 +3299,37 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" +[[package]] +name = "tokenizers" +version = "0.20.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b08cc37428a476fc9e20ac850132a513a2e1ce32b6a31addf2b74fa7033b905" +dependencies = [ + "aho-corasick", + "derive_builder", + "esaxx-rs", + "getrandom 0.2.16", + "itertools 0.12.1", + "lazy_static", + "log", + "macro_rules_attribute", + "monostate", + "onig", + "paste", + "rand 0.8.5", + "rayon", + "rayon-cond", + "regex", + "regex-syntax", + "serde", + "serde_json", + "spm_precompiled", + "thiserror 1.0.69", + "unicode-normalization-alignments", + "unicode-segmentation", + "unicode_categories", +] + [[package]] name = "tokio" version = "1.48.0" @@ -3117,18 +3547,69 @@ version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "462eeb75aeb73aea900253ce739c8e18a67423fadf006037cd3ff27e82748a06" +[[package]] +name = "unicode-normalization-alignments" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43f613e4fa046e69818dd287fdc4bc78175ff20331479dab6e1b0f98d57062de" +dependencies = [ + "smallvec 1.15.1", +] + +[[package]] +name = "unicode-segmentation" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" + [[package]] name = "unicode-xid" version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" +[[package]] +name = "unicode_categories" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" + [[package]] name = "untrusted" version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" +[[package]] +name = "ureq" +version = "3.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d39cb1dbab692d82a977c0392ffac19e188bd9186a9f32806f0aaa859d75585a" +dependencies = [ + "base64 0.22.1", + "der", + "log", + "native-tls", + "percent-encoding", + "rustls-pki-types", + "socks", + "ureq-proto", + "utf-8", + "webpki-root-certs", +] + +[[package]] +name = "ureq-proto" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60b4531c118335662134346048ddb0e54cc86bd7e81866757873055f0e38f5d2" +dependencies = [ + "base64 0.22.1", + "http", + "httparse", + "log", +] + [[package]] name = "url" version = "2.5.7" @@ -3141,6 +3622,12 @@ dependencies = [ "serde", ] +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "utf8_iter" version = "1.0.4" @@ -3309,7 +3796,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8d12a78aa0bab22d2f26ed1a96df7ab58e8a93506a3e20adb47c51a93b4e1357" dependencies = [ "const_format", - "itertools", + "itertools 0.11.0", "nom", "pori", "regex", @@ -3337,6 +3824,15 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "webpki-root-certs" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee3e3b5f5e80bc89f30ce8d0343bf4e5f12341c51f3e26cbeecbc7c85443e85b" +dependencies = [ + "rustls-pki-types", +] + [[package]] name = "webpki-roots" version = "1.0.3" @@ -3865,6 +4361,16 @@ version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ea2f10b9bb0928dfb1b42b65e1f9e36f7f54dbdf08457afefb38afcdec4fa2bb" +[[package]] +name = "xattr" +version = "1.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32e45ad4206f6d2479085147f02bc2ef834ac85886624a23575ae137c8aa8156" +dependencies = [ + "libc", + "rustix", +] + [[package]] name = "yoke" version = "0.8.0" diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 5b2cf8d..3bee994 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -23,6 +23,10 @@ scopeguard = "1.2" notify-rust = "4" asar = "0.3.0" walkdir = "2.5.0" +ort = { version = "2.0.0-rc.10", features = ["xnnpack", "download-binaries", "ndarray"] } +tokenizers = { version = "0.20", default-features = false, features = ["onig"] } +ndarray = "0.16" +serde_json = "1.0" [target.'cfg(windows)'.dependencies] windows = { version = "0.62.2", features = ["Win32_UI_WindowsAndMessaging"] } diff --git a/rust/src/api/mod.rs b/rust/src/api/mod.rs index cb0e8f9..1c73f59 100644 --- a/rust/src/api/mod.rs +++ b/rust/src/api/mod.rs @@ -5,3 +5,4 @@ pub mod http_api; pub mod rs_process; pub mod win32_api; pub mod asar_api; +pub mod ort_api; diff --git a/rust/src/api/ort_api.rs b/rust/src/api/ort_api.rs new file mode 100644 index 0000000..8a661e5 --- /dev/null +++ b/rust/src/api/ort_api.rs @@ -0,0 +1,107 @@ +use anyhow::Result; +use once_cell::sync::Lazy; +use std::collections::HashMap; +use std::sync::Mutex; + +use crate::ort_models::opus_mt::OpusMtModel; + +/// 全局模型缓存 +static MODEL_CACHE: Lazy>> = + Lazy::new(|| Mutex::new(HashMap::new())); + +/// 加载 ONNX 翻译模型 +/// +/// # Arguments +/// * `model_path` - 模型文件夹路径 +/// * `model_key` - 模型缓存键(用于标识模型,如 "zh-en") +/// * `quantization_suffix` - 量化后缀(如 "_q4", "_q8",空字符串表示使用默认模型) +/// +pub fn load_translation_model( + model_path: String, + model_key: String, + quantization_suffix: String, +) -> Result<()> { + let model = OpusMtModel::new(&model_path, &quantization_suffix)?; + + let mut cache = MODEL_CACHE + .lock() + .map_err(|e| anyhow::anyhow!("Failed to lock model cache: {}", e))?; + + cache.insert(model_key, model); + + Ok(()) +} + +/// 翻译文本 +/// +/// # Arguments +/// * `model_key` - 模型缓存键(如 "zh-en") +/// * `text` - 要翻译的文本 +/// +/// # Returns +/// * `Result` - 翻译后的文本 +pub fn translate_text(model_key: String, text: String) -> Result { + let cache = MODEL_CACHE + .lock() + .map_err(|e| anyhow::anyhow!("Failed to lock model cache: {}", e))?; + + let model = cache.get(&model_key).ok_or_else(|| { + anyhow::anyhow!( + "Model not found: {}. Please load the model first.", + model_key + ) + })?; + + model.translate(&text) +} + +/// 批量翻译文本 +/// +/// # Arguments +/// * `model_key` - 模型缓存键(如 "zh-en") +/// * `texts` - 要翻译的文本列表 +/// +/// # Returns +/// * `Result>` - 翻译后的文本列表 +pub fn translate_text_batch(model_key: String, texts: Vec) -> Result> { + let cache = MODEL_CACHE + .lock() + .map_err(|e| anyhow::anyhow!("Failed to lock model cache: {}", e))?; + + let model = cache.get(&model_key).ok_or_else(|| { + anyhow::anyhow!( + "Model not found: {}. Please load the model first.", + model_key + ) + })?; + + model.translate_batch(&texts) +} + +/// 卸载模型 +/// +/// # Arguments +/// * `model_key` - 模型缓存键(如 "zh-en") +/// +pub fn unload_translation_model(model_key: String) -> Result<()> { + let mut cache = MODEL_CACHE + .lock() + .map_err(|e| anyhow::anyhow!("Failed to lock model cache: {}", e))?; + + cache.remove(&model_key); + + Ok(()) +} + +/// 清空所有已加载的模型 +/// +/// # Returns +pub fn clear_all_models() -> Result<()> { + let mut cache = MODEL_CACHE + .lock() + .map_err(|e| anyhow::anyhow!("Failed to lock model cache: {}", e))?; + + cache.clear(); + + Ok(()) +} diff --git a/rust/src/frb_generated.rs b/rust/src/frb_generated.rs index 69a3d28..6e2f10b 100644 --- a/rust/src/frb_generated.rs +++ b/rust/src/frb_generated.rs @@ -37,7 +37,7 @@ flutter_rust_bridge::frb_generated_boilerplate!( default_rust_auto_opaque = RustAutoOpaqueNom, ); pub(crate) const FLUTTER_RUST_BRIDGE_CODEGEN_VERSION: &str = "2.11.1"; -pub(crate) const FLUTTER_RUST_BRIDGE_CODEGEN_CONTENT_HASH: i32 = 1832496273; +pub(crate) const FLUTTER_RUST_BRIDGE_CODEGEN_CONTENT_HASH: i32 = -706588047; // Section: executor @@ -45,6 +45,27 @@ flutter_rust_bridge::frb_generated_default_handler!(); // Section: wire_funcs +fn wire__crate__api__ort_api__clear_all_models_impl( + port_: flutter_rust_bridge::for_generated::MessagePort, +) { + FLUTTER_RUST_BRIDGE_HANDLER.wrap_normal::( + flutter_rust_bridge::for_generated::TaskInfo { + debug_name: "clear_all_models", + port: Some(port_), + mode: flutter_rust_bridge::for_generated::FfiCallMode::Normal, + }, + move || { + move |context| { + transform_result_dco::<_, _, flutter_rust_bridge::for_generated::anyhow::Error>( + (move || { + let output_ok = crate::api::ort_api::clear_all_models()?; + Ok(output_ok) + })(), + ) + } + }, + ) +} fn wire__crate__api__http_api__dns_lookup_ips_impl( port_: flutter_rust_bridge::for_generated::MessagePort, host: impl CstDecode, @@ -161,6 +182,37 @@ fn wire__crate__api__asar_api__get_rsi_launcher_asar_data_impl( }, ) } +fn wire__crate__api__ort_api__load_translation_model_impl( + port_: flutter_rust_bridge::for_generated::MessagePort, + model_path: impl CstDecode, + model_key: impl CstDecode, + quantization_suffix: impl CstDecode, +) { + FLUTTER_RUST_BRIDGE_HANDLER.wrap_normal::( + flutter_rust_bridge::for_generated::TaskInfo { + debug_name: "load_translation_model", + port: Some(port_), + mode: flutter_rust_bridge::for_generated::FfiCallMode::Normal, + }, + move || { + let api_model_path = model_path.cst_decode(); + let api_model_key = model_key.cst_decode(); + let api_quantization_suffix = quantization_suffix.cst_decode(); + move |context| { + transform_result_dco::<_, _, flutter_rust_bridge::for_generated::anyhow::Error>( + (move || { + let output_ok = crate::api::ort_api::load_translation_model( + api_model_path, + api_model_key, + api_quantization_suffix, + )?; + Ok(output_ok) + })(), + ) + } + }, + ) +} fn wire__crate__api__asar_api__rsi_launcher_asar_data_write_main_js_impl( port_: flutter_rust_bridge::for_generated::MessagePort, that: impl CstDecode, @@ -315,6 +367,82 @@ fn wire__crate__api__rs_process__start_impl( }, ) } +fn wire__crate__api__ort_api__translate_text_impl( + port_: flutter_rust_bridge::for_generated::MessagePort, + model_key: impl CstDecode, + text: impl CstDecode, +) { + FLUTTER_RUST_BRIDGE_HANDLER.wrap_normal::( + flutter_rust_bridge::for_generated::TaskInfo { + debug_name: "translate_text", + port: Some(port_), + mode: flutter_rust_bridge::for_generated::FfiCallMode::Normal, + }, + move || { + let api_model_key = model_key.cst_decode(); + let api_text = text.cst_decode(); + move |context| { + transform_result_dco::<_, _, flutter_rust_bridge::for_generated::anyhow::Error>( + (move || { + let output_ok = + crate::api::ort_api::translate_text(api_model_key, api_text)?; + Ok(output_ok) + })(), + ) + } + }, + ) +} +fn wire__crate__api__ort_api__translate_text_batch_impl( + port_: flutter_rust_bridge::for_generated::MessagePort, + model_key: impl CstDecode, + texts: impl CstDecode>, +) { + FLUTTER_RUST_BRIDGE_HANDLER.wrap_normal::( + flutter_rust_bridge::for_generated::TaskInfo { + debug_name: "translate_text_batch", + port: Some(port_), + mode: flutter_rust_bridge::for_generated::FfiCallMode::Normal, + }, + move || { + let api_model_key = model_key.cst_decode(); + let api_texts = texts.cst_decode(); + move |context| { + transform_result_dco::<_, _, flutter_rust_bridge::for_generated::anyhow::Error>( + (move || { + let output_ok = + crate::api::ort_api::translate_text_batch(api_model_key, api_texts)?; + Ok(output_ok) + })(), + ) + } + }, + ) +} +fn wire__crate__api__ort_api__unload_translation_model_impl( + port_: flutter_rust_bridge::for_generated::MessagePort, + model_key: impl CstDecode, +) { + FLUTTER_RUST_BRIDGE_HANDLER.wrap_normal::( + flutter_rust_bridge::for_generated::TaskInfo { + debug_name: "unload_translation_model", + port: Some(port_), + mode: flutter_rust_bridge::for_generated::FfiCallMode::Normal, + }, + move || { + let api_model_key = model_key.cst_decode(); + move |context| { + transform_result_dco::<_, _, flutter_rust_bridge::for_generated::anyhow::Error>( + (move || { + let output_ok = + crate::api::ort_api::unload_translation_model(api_model_key)?; + Ok(output_ok) + })(), + ) + } + }, + ) +} fn wire__crate__api__rs_process__write_impl( port_: flutter_rust_bridge::for_generated::MessagePort, rs_pid: impl CstDecode, @@ -1361,6 +1489,13 @@ mod io { } } + #[unsafe(no_mangle)] + pub extern "C" fn frbgen_starcitizen_doctor_wire__crate__api__ort_api__clear_all_models( + port_: i64, + ) { + wire__crate__api__ort_api__clear_all_models_impl(port_) + } + #[unsafe(no_mangle)] pub extern "C" fn frbgen_starcitizen_doctor_wire__crate__api__http_api__dns_lookup_ips( port_: i64, @@ -1406,6 +1541,21 @@ mod io { wire__crate__api__asar_api__get_rsi_launcher_asar_data_impl(port_, asar_path) } + #[unsafe(no_mangle)] + pub extern "C" fn frbgen_starcitizen_doctor_wire__crate__api__ort_api__load_translation_model( + port_: i64, + model_path: *mut wire_cst_list_prim_u_8_strict, + model_key: *mut wire_cst_list_prim_u_8_strict, + quantization_suffix: *mut wire_cst_list_prim_u_8_strict, + ) { + wire__crate__api__ort_api__load_translation_model_impl( + port_, + model_path, + model_key, + quantization_suffix, + ) + } + #[unsafe(no_mangle)] pub extern "C" fn frbgen_starcitizen_doctor_wire__crate__api__asar_api__rsi_launcher_asar_data_write_main_js( port_: i64, @@ -1459,6 +1609,32 @@ mod io { ) } + #[unsafe(no_mangle)] + pub extern "C" fn frbgen_starcitizen_doctor_wire__crate__api__ort_api__translate_text( + port_: i64, + model_key: *mut wire_cst_list_prim_u_8_strict, + text: *mut wire_cst_list_prim_u_8_strict, + ) { + wire__crate__api__ort_api__translate_text_impl(port_, model_key, text) + } + + #[unsafe(no_mangle)] + pub extern "C" fn frbgen_starcitizen_doctor_wire__crate__api__ort_api__translate_text_batch( + port_: i64, + model_key: *mut wire_cst_list_prim_u_8_strict, + texts: *mut wire_cst_list_String, + ) { + wire__crate__api__ort_api__translate_text_batch_impl(port_, model_key, texts) + } + + #[unsafe(no_mangle)] + pub extern "C" fn frbgen_starcitizen_doctor_wire__crate__api__ort_api__unload_translation_model( + port_: i64, + model_key: *mut wire_cst_list_prim_u_8_strict, + ) { + wire__crate__api__ort_api__unload_translation_model_impl(port_, model_key) + } + #[unsafe(no_mangle)] pub extern "C" fn frbgen_starcitizen_doctor_wire__crate__api__rs_process__write( port_: i64, diff --git a/rust/src/lib.rs b/rust/src/lib.rs index 2c46138..118a7c7 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -1,3 +1,4 @@ pub mod api; mod frb_generated; pub mod http_package; +pub mod ort_models; diff --git a/rust/src/ort_models/mod.rs b/rust/src/ort_models/mod.rs new file mode 100644 index 0000000..5fe1506 --- /dev/null +++ b/rust/src/ort_models/mod.rs @@ -0,0 +1 @@ +pub mod opus_mt; \ No newline at end of file diff --git a/rust/src/ort_models/opus_mt.rs b/rust/src/ort_models/opus_mt.rs new file mode 100644 index 0000000..63ae56e --- /dev/null +++ b/rust/src/ort_models/opus_mt.rs @@ -0,0 +1,388 @@ +use anyhow::{anyhow, Context, Result}; +use ndarray::{Array2, ArrayD}; +use ort::{ + execution_providers::XNNPACKExecutionProvider, session::builder::GraphOptimizationLevel, + session::Session, value::Value, +}; +use std::path::Path; +use std::sync::Mutex; +use tokenizers::Tokenizer; + +/// Opus-MT 翻译模型的推理结构 +pub struct OpusMtModel { + encoder_session: Mutex, + decoder_session: Mutex, + tokenizer: Tokenizer, + config: ModelConfig, +} + +/// 模型配置 +#[derive(Debug, Clone)] +pub struct ModelConfig { + pub max_length: usize, + pub num_beams: usize, + pub decoder_start_token_id: i64, + pub eos_token_id: i64, + pub pad_token_id: i64, +} + +impl Default for ModelConfig { + fn default() -> Self { + Self { + max_length: 512, + num_beams: 1, + decoder_start_token_id: 0, + eos_token_id: 0, + pad_token_id: 0, + } + } +} + +impl OpusMtModel { + /// 从模型路径创建新的 OpusMT 模型实例 + /// + /// # Arguments + /// * `model_path` - 模型文件夹路径(应包含 onnx 子文件夹) + /// * `quantization_suffix` - 量化后缀,如 "_q4", "_q8",为空字符串则使用默认模型 + /// + /// # Returns + /// * `Result` - 成功返回模型实例,失败返回错误 + pub fn new>(model_path: P, quantization_suffix: &str) -> Result { + let model_path = model_path.as_ref(); + + // onnx-community 标准:模型在 onnx 子文件夹中 + let onnx_dir = model_path.join("onnx"); + + // 加载 tokenizer(在根目录) + let tokenizer_path = model_path.join("tokenizer.json"); + + // 动态加载并修复 tokenizer + let tokenizer = + Self::load_tokenizer(&tokenizer_path).context("Failed to load tokenizer")?; + + // 构建模型文件名 + let encoder_filename = if quantization_suffix.is_empty() { + "encoder_model.onnx".to_string() + } else { + format!("encoder_model{}.onnx", quantization_suffix) + }; + + let decoder_filename = if quantization_suffix.is_empty() { + "decoder_model.onnx".to_string() + } else { + format!("decoder_model{}.onnx", quantization_suffix) + }; + + // 加载 encoder 模型(在 onnx 子目录) + let encoder_path = onnx_dir.join(&encoder_filename); + if !encoder_path.exists() { + return Err(anyhow!( + "Encoder model not found: {}", + encoder_path.display() + )); + } + + let encoder_session = Session::builder() + .context("Failed to create encoder session builder")? + .with_optimization_level(GraphOptimizationLevel::Level3) + .context("Failed to set optimization level")? + .with_intra_threads(4) + .context("Failed to set intra threads")? + .with_execution_providers([XNNPACKExecutionProvider::default().build()]) + .context("Failed to register XNNPACK execution provider")? + .commit_from_file(&encoder_path) + .context(format!( + "Failed to load encoder model: {}", + encoder_filename + ))?; + + // 加载 decoder 模型(在 onnx 子目录) + let decoder_path = onnx_dir.join(&decoder_filename); + if !decoder_path.exists() { + return Err(anyhow!( + "Decoder model not found: {}", + decoder_path.display() + )); + } + + let decoder_session = Session::builder() + .context("Failed to create decoder session builder")? + .with_optimization_level(GraphOptimizationLevel::Level3) + .context("Failed to set optimization level")? + .with_intra_threads(4) + .context("Failed to set intra threads")? + .with_execution_providers([XNNPACKExecutionProvider::default().build()]) + .context("Failed to register XNNPACK execution provider")? + .commit_from_file(&decoder_path) + .context(format!( + "Failed to load decoder model: {}", + decoder_filename + ))?; + + // 加载配置(如果存在,在根目录) + let config = Self::load_config(model_path)?; + + Ok(Self { + encoder_session: Mutex::new(encoder_session), + decoder_session: Mutex::new(decoder_session), + tokenizer, + config, + }) + } + + /// 动态加载 tokenizer,自动修复常见问题 + fn load_tokenizer(tokenizer_path: &Path) -> Result { + use std::fs; + + // 读取原始文件 + let content = + fs::read_to_string(tokenizer_path).context("Failed to read tokenizer.json")?; + + // 解析为 JSON + let mut json: serde_json::Value = + serde_json::from_str(&content).context("Failed to parse tokenizer.json")?; + + let mut needs_fix = false; + + // 修复 normalizer 中的问题 + if let Some(obj) = json.as_object_mut() { + if let Some(normalizer) = obj.get("normalizer") { + let mut should_remove_normalizer = false; + + if normalizer.is_null() { + // normalizer 是 null,需要移除 + should_remove_normalizer = true; + } else if let Some(norm_obj) = normalizer.as_object() { + // 检查是否是有问题的 Precompiled 类型 + if let Some(type_val) = norm_obj.get("type") { + if type_val.as_str() == Some("Precompiled") { + // 检查 precompiled_charsmap 字段 + if let Some(precompiled) = norm_obj.get("precompiled_charsmap") { + if precompiled.is_null() { + // precompiled_charsmap 是 null,移除整个 normalizer + should_remove_normalizer = true; + } + } else { + // 缺少 precompiled_charsmap 字段,移除整个 normalizer + should_remove_normalizer = true; + } + } + } + } + + if should_remove_normalizer { + obj.remove("normalizer"); + needs_fix = true; + } + } + } + + // 从修复后的 JSON 字符串加载 tokenizer + let json_str = if needs_fix { + serde_json::to_string(&json).context("Failed to serialize fixed tokenizer")? + } else { + content + }; + + // 从字节数组加载 tokenizer + Tokenizer::from_bytes(json_str.as_bytes()) + .map_err(|e| anyhow!("Failed to load tokenizer: {}", e)) + } + + /// 从配置文件加载模型配置 + fn load_config(model_path: &Path) -> Result { + let config_path = model_path.join("config.json"); + + if config_path.exists() { + let config_str = + std::fs::read_to_string(config_path).context("Failed to read config.json")?; + let config_json: serde_json::Value = + serde_json::from_str(&config_str).context("Failed to parse config.json")?; + + Ok(ModelConfig { + max_length: config_json["max_length"].as_u64().unwrap_or(512) as usize, + num_beams: config_json["num_beams"].as_u64().unwrap_or(1) as usize, + decoder_start_token_id: config_json["decoder_start_token_id"].as_i64().unwrap_or(0), + eos_token_id: config_json["eos_token_id"].as_i64().unwrap_or(0), + pad_token_id: config_json["pad_token_id"].as_i64().unwrap_or(0), + }) + } else { + Ok(ModelConfig::default()) + } + } + + /// 翻译文本 + /// + /// # Arguments + /// * `text` - 要翻译的文本 + /// + /// # Returns + /// * `Result` - 翻译后的文本 + pub fn translate(&self, text: &str) -> Result { + // 1. Tokenize 输入文本 + let encoding = self + .tokenizer + .encode(text, true) + .map_err(|e| anyhow!("Failed to encode text: {}", e))?; + + let input_ids = encoding.get_ids(); + let attention_mask = encoding.get_attention_mask(); + + // 2. 准备 encoder 输入 + let batch_size = 1; + let seq_len = input_ids.len(); + + let input_ids_array: Array2 = Array2::from_shape_vec( + (batch_size, seq_len), + input_ids.iter().map(|&id| id as i64).collect(), + ) + .context("Failed to create input_ids array")?; + + let attention_mask_array: Array2 = Array2::from_shape_vec( + (batch_size, seq_len), + attention_mask.iter().map(|&mask| mask as i64).collect(), + ) + .context("Failed to create attention_mask array")?; + + // 3. 运行 encoder + let input_ids_value = + Value::from_array(input_ids_array).context("Failed to create input_ids value")?; + let attention_mask_value = Value::from_array(attention_mask_array.clone()) + .context("Failed to create attention_mask value")?; + + let encoder_inputs = ort::inputs![ + "input_ids" => input_ids_value, + "attention_mask" => attention_mask_value, + ]; + + let mut encoder_session = self + .encoder_session + .lock() + .map_err(|e| anyhow!("Failed to lock encoder session: {}", e))?; + let encoder_outputs = encoder_session + .run(encoder_inputs) + .context("Failed to run encoder")?; + + let encoder_hidden_states = encoder_outputs["last_hidden_state"] + .try_extract_tensor::() + .context("Failed to extract encoder hidden states")?; + + // 将 tensor 转换为 ArrayD + let (shape, data) = encoder_hidden_states; + let shape_vec: Vec = shape.iter().map(|&x| x as usize).collect(); + let encoder_array = ArrayD::from_shape_vec(shape_vec, data.to_vec()) + .context("Failed to create encoder array")?; + + // 4. 贪婪解码生成输出 + let output_ids = self.greedy_decode(encoder_array, &attention_mask_array)?; + + // 5. Decode 输出 token IDs + let output_tokens: Vec = output_ids.iter().map(|&id| id as u32).collect(); + let decoded = self + .tokenizer + .decode(&output_tokens, true) + .map_err(|e| anyhow!("Failed to decode output: {}", e))?; + + Ok(decoded) + } + + /// 贪婪解码 + fn greedy_decode( + &self, + encoder_hidden_states: ArrayD, + encoder_attention_mask: &Array2, + ) -> Result> { + let batch_size = 1; + let mut generated_ids = vec![self.config.decoder_start_token_id]; + + for _ in 0..self.config.max_length { + // 准备 decoder 输入 + let decoder_input_len = generated_ids.len(); + let decoder_input_ids: Array2 = + Array2::from_shape_vec((batch_size, decoder_input_len), generated_ids.clone()) + .context("Failed to create decoder input_ids")?; + + // 创建 ORT Value + let decoder_input_value = Value::from_array(decoder_input_ids) + .context("Failed to create decoder input value")?; + let encoder_hidden_value = Value::from_array(encoder_hidden_states.clone()) + .context("Failed to create encoder hidden value")?; + let encoder_mask_value = Value::from_array(encoder_attention_mask.clone()) + .context("Failed to create encoder mask value")?; + + // 运行 decoder + let decoder_inputs = ort::inputs![ + "input_ids" => decoder_input_value, + "encoder_hidden_states" => encoder_hidden_value, + "encoder_attention_mask" => encoder_mask_value, + ]; + + let mut decoder_session = self + .decoder_session + .lock() + .map_err(|e| anyhow!("Failed to lock decoder session: {}", e))?; + let decoder_outputs = decoder_session + .run(decoder_inputs) + .context("Failed to run decoder")?; + + // 获取 logits + let logits_tensor = decoder_outputs["logits"] + .try_extract_tensor::() + .context("Failed to extract logits")?; + + let (logits_shape, logits_data) = logits_tensor; + let vocab_size = logits_shape[2] as usize; + + // 获取最后一个 token 的 logits + let last_token_idx = decoder_input_len - 1; + let last_logits_start = last_token_idx * vocab_size; + let last_logits_end = last_logits_start + vocab_size; + + let last_logits_slice = &logits_data[last_logits_start..last_logits_end]; + + // 找到最大概率的 token + let next_token_id = last_logits_slice + .iter() + .enumerate() + .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap()) + .map(|(idx, _)| idx as i64) + .context("Failed to find max token")?; + + // 检查是否到达结束 token + if next_token_id == self.config.eos_token_id { + break; + } + + generated_ids.push(next_token_id); + } + + Ok(generated_ids) + } + + /// 批量翻译文本 + /// + /// # Arguments + /// * `texts` - 要翻译的文本列表 + /// + /// # Returns + /// * `Result>` - 翻译后的文本列表 + pub fn translate_batch(&self, texts: &[String]) -> Result> { + texts.iter().map(|text| self.translate(text)).collect() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_translation() { + let model = OpusMtModel::new( + "C:\\Users\\xkeyc\\Downloads\\onnx_models\\opus-mt-zh-en", + "_q4f16", + ) + .unwrap(); + let result = model.translate("你好世界").unwrap(); + println!("Translation: {}", result); + } +}