diff --git a/core/server/gen/libcore.proto b/core/server/gen/libcore.proto index 4c07518..2cde5ad 100644 --- a/core/server/gen/libcore.proto +++ b/core/server/gen/libcore.proto @@ -19,6 +19,7 @@ service LibcoreService { // rpc SpeedTest(SpeedTestRequest) returns(SpeedTestResponse); rpc QuerySpeedTest(EmptyReq) returns(QuerySpeedTestResponse); + rpc QueryCountryTest(EmptyReq) returns(QueryCountryTestResponse); } message EmptyReq {} @@ -100,6 +101,8 @@ message SpeedTestRequest { optional bool simple_download = 7 [default = false]; optional string simple_download_addr = 8 [default = ""]; optional int32 timeout_ms = 9 [default = 0]; + optional bool only_country = 10 [default = false]; + optional int32 country_concurrency = 11 [default = 0]; } message SpeedTestResult { @@ -122,6 +125,10 @@ message QuerySpeedTestResponse { optional bool is_running = 2 [default = false]; } +message QueryCountryTestResponse { + repeated SpeedTestResult results = 1; +} + message QueryURLTestResponse { repeated URLTestResp results = 1; } diff --git a/core/server/server.go b/core/server/server.go index 9deed14..7c7e7ad 100644 --- a/core/server/server.go +++ b/core/server/server.go @@ -309,7 +309,7 @@ func (s *server) IsPrivileged(in *gen.EmptyReq, out *gen.IsPrivilegedResponse) e } func (s *server) SpeedTest(in *gen.SpeedTestRequest, out *gen.SpeedTestResponse) error { - if !*in.TestDownload && !*in.TestUpload && !*in.SimpleDownload { + if !*in.TestDownload && !*in.TestUpload && !*in.SimpleDownload && !*in.OnlyCountry { return errors.New("cannot run empty test") } var testInstance *boxbox.Box @@ -339,7 +339,7 @@ func (s *server) SpeedTest(in *gen.SpeedTestRequest, out *gen.SpeedTestResponse) outboundTags = []string{outbound.Tag()} } - results := BatchSpeedTest(testCtx, testInstance, outboundTags, *in.TestDownload, *in.TestUpload, *in.SimpleDownload, *in.SimpleDownloadAddr, time.Duration(*in.TimeoutMs)*time.Millisecond) + results := BatchSpeedTest(testCtx, testInstance, outboundTags, *in.TestDownload, *in.TestUpload, *in.SimpleDownload, *in.SimpleDownloadAddr, time.Duration(*in.TimeoutMs)*time.Millisecond, *in.OnlyCountry, *in.CountryConcurrency) res := make([]*gen.SpeedTestResult, 0) for _, data := range results { @@ -382,3 +382,24 @@ func (s *server) QuerySpeedTest(in *gen.EmptyReq, out *gen.QuerySpeedTestRespons out.IsRunning = To(isRunning) return nil } + +func (s *server) QueryCountryTest(in *gen.EmptyReq, out *gen.QueryCountryTestResponse) error { + results := CountryResults.Results() + for _, res := range results { + var errStr string + if res.Error != nil { + errStr = res.Error.Error() + } + out.Results = append(out.Results, &gen.SpeedTestResult{ + DlSpeed: To(res.DlSpeed), + UlSpeed: To(res.UlSpeed), + Latency: To(res.Latency), + OutboundTag: To(res.Tag), + Error: To(errStr), + ServerName: To(res.ServerName), + ServerCountry: To(res.ServerCountry), + Cancelled: To(res.Cancelled), + }) + } + return nil +} diff --git a/core/server/test_utils.go b/core/server/test_utils.go index 4da92d2..f7d7864 100644 --- a/core/server/test_utils.go +++ b/core/server/test_utils.go @@ -22,6 +22,7 @@ var testCtx context.Context var cancelTests context.CancelFunc var SpTQuerier SpeedTestResultQuerier var URLReporter URLTestReporter +var CountryResults CountryTestResults const URLTestTimeout = 3 * time.Second const FetchServersTimeout = 8 * time.Second @@ -85,6 +86,25 @@ func (s *SpeedTestResultQuerier) setIsRunning(isRunning bool) { s.isRunning = isRunning } +type CountryTestResults struct { + results []*SpeedTestResult + mu sync.Mutex +} + +func (c *CountryTestResults) AddResult(result *SpeedTestResult) { + c.mu.Lock() + defer c.mu.Unlock() + c.results = append(c.results, result) +} + +func (c *CountryTestResults) Results() []*SpeedTestResult { + c.mu.Lock() + defer c.mu.Unlock() + cp := c.results + c.results = nil + return cp +} + func BatchURLTest(ctx context.Context, i *boxbox.Box, outboundTags []string, url string, maxConcurrency int, twice bool, timeout time.Duration) []*URLTestResult { if timeout <= 0 { timeout = URLTestTimeout @@ -171,9 +191,17 @@ func getNetDialer(dialer func(ctx context.Context, network string, destination m } } -func BatchSpeedTest(ctx context.Context, i *boxbox.Box, outboundTags []string, testDl, testUl bool, simpleDL bool, simpleAddress string, timeout time.Duration) []*SpeedTestResult { +func BatchSpeedTest(ctx context.Context, i *boxbox.Box, outboundTags []string, testDl, testUl bool, simpleDL bool, simpleAddress string, timeout time.Duration, countryOnly bool, countryConcurrency int32) []*SpeedTestResult { outbounds := service.FromContext[adapter.OutboundManager](i.Context()) results := make([]*SpeedTestResult, 0) + var queuer chan struct{} + wg := &sync.WaitGroup{} + if countryOnly { + if countryConcurrency <= 0 { + countryConcurrency = 5 + } + queuer = make(chan struct{}, countryConcurrency) + } for _, tag := range outboundTags { select { @@ -190,6 +218,21 @@ func BatchSpeedTest(ctx context.Context, i *boxbox.Box, outboundTags []string, t results = append(results, res) var err error + if countryOnly { + queuer <- struct{}{} + wg.Add(1) + go func(res *SpeedTestResult, outbound adapter.Outbound) { + defer func() { <-queuer }() + defer wg.Done() + err := countryTest(ctx, getNetDialer(outbound.DialContext), res) + if err != nil && !errors.Is(err, context.Canceled) { + res.Error = err + fmt.Println("Failed to countryTest with err:", err) + } + CountryResults.AddResult(res) + }(res, outbound) + continue + } if simpleDL { err = simpleDownloadTest(ctx, getNetDialer(outbound.DialContext), res, simpleAddress, timeout) } else { @@ -206,6 +249,7 @@ func BatchSpeedTest(ctx context.Context, i *boxbox.Box, outboundTags []string, t res.UlSpeed = "" } } + wg.Wait() return results } @@ -273,7 +317,7 @@ func simpleDownloadTest(ctx context.Context, dialer func(ctx context.Context, ne } } -func speedTestWithDialer(ctx context.Context, dialer func(ctx context.Context, network string, address string) (net.Conn, error), res *SpeedTestResult, testDl, testUl bool, timeout time.Duration) error { +func getSpeedtestServer(ctx context.Context, dialer func(ctx context.Context, network string, address string) (net.Conn, error)) (*speedtest.Server, error) { clt := speedtest.New(speedtest.WithUserConfig(&speedtest.UserConfig{ DialContextFunc: dialer, PingMode: speedtest.HTTP, @@ -283,18 +327,38 @@ func speedTestWithDialer(ctx context.Context, dialer func(ctx context.Context, n defer cancel() srv, err := clt.FetchServerListContext(fetchCtx) if err != nil { - return err + return nil, err } srv, err = srv.FindServer(nil) if err != nil { - return err + return nil, err } if srv.Len() == 0 { - return errors.New("no server found for speedTest") + return nil, errors.New("no server found for speedTest") } - res.ServerName = srv[0].Name - res.ServerCountry = srv[0].Country + + return srv[0], nil +} + +func countryTest(ctx context.Context, dialer func(ctx context.Context, network string, address string) (net.Conn, error), res *SpeedTestResult) error { + srv, err := getSpeedtestServer(ctx, dialer) + if err != nil { + return err + } + res.ServerName = srv.Name + res.ServerCountry = srv.Country + res.Latency = int32(srv.Latency.Milliseconds()) + return nil +} + +func speedTestWithDialer(ctx context.Context, dialer func(ctx context.Context, network string, address string) (net.Conn, error), res *SpeedTestResult, testDl, testUl bool, timeout time.Duration) error { + srv, err := getSpeedtestServer(ctx, dialer) + if err != nil { + return err + } + res.ServerName = srv.Name + res.ServerCountry = srv.Country done := make(chan struct{}) @@ -306,7 +370,7 @@ func speedTestWithDialer(ctx context.Context, dialer func(ctx context.Context, n if testDl { timeoutCtx, cancel := context.WithTimeout(ctx, timeout) defer cancel() - err = srv[0].DownloadTestContext(timeoutCtx) + err = srv.DownloadTestContext(timeoutCtx) if err != nil && !errors.Is(err, context.Canceled) { res.Error = err return @@ -315,7 +379,7 @@ func speedTestWithDialer(ctx context.Context, dialer func(ctx context.Context, n if testUl { timeoutCtx, cancel := context.WithTimeout(ctx, timeout) defer cancel() - err = srv[0].UploadTestContext(timeoutCtx) + err = srv.UploadTestContext(timeoutCtx) if err != nil && !errors.Is(err, context.Canceled) { res.Error = err return @@ -329,18 +393,18 @@ func speedTestWithDialer(ctx context.Context, dialer func(ctx context.Context, n for { select { case <-done: - res.DlSpeed = internal.BrateToStr(float64(srv[0].DLSpeed)) - res.UlSpeed = internal.BrateToStr(float64(srv[0].ULSpeed)) - res.Latency = int32(srv[0].Latency.Milliseconds()) + res.DlSpeed = internal.BrateToStr(float64(srv.DLSpeed)) + res.UlSpeed = internal.BrateToStr(float64(srv.ULSpeed)) + res.Latency = int32(srv.Latency.Milliseconds()) SpTQuerier.storeResult(res) return nil case <-ctx.Done(): res.Cancelled = true return ctx.Err() case <-ticker.C: - res.DlSpeed = internal.BrateToStr(srv[0].Context.GetEWMADownloadRate()) - res.UlSpeed = internal.BrateToStr(srv[0].Context.GetEWMAUploadRate()) - res.Latency = int32(srv[0].Latency.Milliseconds()) + res.DlSpeed = internal.BrateToStr(srv.Context.GetEWMADownloadRate()) + res.UlSpeed = internal.BrateToStr(srv.Context.GetEWMAUploadRate()) + res.Latency = int32(srv.Latency.Milliseconds()) SpTQuerier.storeResult(res) } } diff --git a/include/api/RPC.h b/include/api/RPC.h index 8e9297e..f588540 100644 --- a/include/api/RPC.h +++ b/include/api/RPC.h @@ -37,6 +37,8 @@ namespace API { libcore::QuerySpeedTestResponse QueryCurrentSpeedTests(bool *rpcOK); + libcore::QueryCountryTestResponse QueryCountryTestResults(bool *rpcOK); + private: std::function()> make_rpc_client; std::function onError; diff --git a/include/global/Const.hpp b/include/global/Const.hpp index 20d6861..3a03fc7 100644 --- a/include/global/Const.hpp +++ b/include/global/Const.hpp @@ -37,6 +37,7 @@ namespace Configs { DL, UL, SIMPLEDL, + COUNTRY, }; } diff --git a/include/ui/mainwindow.h b/include/ui/mainwindow.h index 9e6b281..56e1052 100644 --- a/include/ui/mainwindow.h +++ b/include/ui/mainwindow.h @@ -276,6 +276,10 @@ private: void setupConnectionList(); + void querySpeedtest(QDateTime lastProxyListUpdate, const QMap& tag2entID, bool testCurrent); + + void queryCountryTest(const QMap& tag2entID, bool testCurrent); + protected: bool eventFilter(QObject *obj, QEvent *event) override; diff --git a/include/ui/setting/dialog_basic_settings.ui b/include/ui/setting/dialog_basic_settings.ui index 5932ad2..621b2e7 100644 --- a/include/ui/setting/dialog_basic_settings.ui +++ b/include/ui/setting/dialog_basic_settings.ui @@ -217,6 +217,11 @@ Simple Download + + + Only Country + + diff --git a/src/api/RPC.cpp b/src/api/RPC.cpp index eeec570..00f5461 100644 --- a/src/api/RPC.cpp +++ b/src/api/RPC.cpp @@ -222,5 +222,22 @@ if (!Configs::dataStore->core_running) MW_show_log("Cannot invoke method " + QSt } } + libcore::QueryCountryTestResponse Client::QueryCountryTestResults(bool* rpcOK) + { + CHECK("QueryCountryTestResults") + const libcore::EmptyReq request; + libcore::QueryCountryTestResponse reply; + std::string resp, req = spb::pb::serialize(request); + auto err = make_rpc_client()->CallMethod("LibcoreService.QueryCountryTest", &req, &resp); + + if(err.IsNil()) { + reply = spb::pb::deserialize< libcore::QueryCountryTestResponse >( resp ); + *rpcOK = true; + return reply; + } else { + NOT_OK + return reply; + } + } } // namespace API diff --git a/src/ui/mainwindow_grpc.cpp b/src/ui/mainwindow_grpc.cpp index dfd752c..bb072de 100644 --- a/src/ui/mainwindow_grpc.cpp +++ b/src/ui/mainwindow_grpc.cpp @@ -267,6 +267,66 @@ void MainWindow::speedtest_current_group(const QList& tag2entID, bool testCurrent) +{ + bool ok; + auto res = defaultClient->QueryCurrentSpeedTests(&ok); + if (!ok || !res.is_running.value()) + { + return; + } + auto profile = testCurrent ? running : Configs::profileManager->GetProfile(tag2entID[QString::fromStdString(res.result.value().outbound_tag.value())]); + if (profile == nullptr) + { + return; + } + runOnUiThread([=, this, &lastProxyListUpdate] + { + showSpeedtestData = true; + currentSptProfileName = profile->bean->name; + currentTestResult = res.result.value(); + UpdateDataView(); + + if (res.result.value().error.value().empty() && !res.result.value().cancelled.value() && lastProxyListUpdate.msecsTo(QDateTime::currentDateTime()) >= 500) + { + if (!res.result.value().dl_speed.value().empty()) profile->dl_speed = QString::fromStdString(res.result.value().dl_speed.value()); + if (!res.result.value().ul_speed.value().empty()) profile->ul_speed = QString::fromStdString(res.result.value().ul_speed.value()); + if (profile->latency <= 0 && res.result.value().latency.value() > 0) profile->latency = res.result.value().latency.value(); + if (!res.result->server_country.value().empty()) profile->test_country = CountryNameToCode(QString::fromStdString(res.result.value().server_country.value())); + refresh_proxy_list(profile->id); + lastProxyListUpdate = QDateTime::currentDateTime(); + } + }); +} + +void MainWindow::queryCountryTest(const QMap& tag2entID, bool testCurrent) +{ + bool ok; + auto res = defaultClient->QueryCountryTestResults(&ok); + if (!ok || res.results.empty()) + { + return; + } + for (const auto& result : res.results) + { + auto profile = testCurrent ? running : Configs::profileManager->GetProfile(tag2entID[QString::fromStdString(result.outbound_tag.value())]); + if (profile == nullptr) + { + return; + } + runOnUiThread([=, this] + { + if (result.error.value().empty() && !result.cancelled.value()) + { + if (profile->latency <= 0 && result.latency.value() > 0) profile->latency = result.latency.value(); + if (!result.server_country.value().empty()) profile->test_country = CountryNameToCode(QString::fromStdString(result.server_country.value())); + refresh_proxy_list(profile->id); + } + }); + } +} + + void MainWindow::runSpeedTest(const QString& config, bool useDefault, bool testCurrent, const QStringList& outboundTags, const QMap& tag2entID, int entID) { if (stopSpeedtest.load()) { @@ -287,6 +347,8 @@ void MainWindow::runSpeedTest(const QString& config, bool useDefault, bool testC req.simple_download_addr = Configs::dataStore->simple_dl_url.toStdString(); req.test_current = testCurrent; req.timeout_ms = Configs::dataStore->speed_test_timeout_ms; + req.only_country = speedtestConf == Configs::TestConfig::COUNTRY; + req.country_concurrency = Configs::dataStore->test_concurrent; // loop query result auto doneMu = new QMutex; @@ -294,40 +356,19 @@ void MainWindow::runSpeedTest(const QString& config, bool useDefault, bool testC runOnNewThread([=,this] { QDateTime lastProxyListUpdate = QDateTime::currentDateTime(); - bool ok; while (true) { QThread::msleep(100); if (doneMu->tryLock()) { break; } - auto res = defaultClient->QueryCurrentSpeedTests(&ok); - if (!ok || !res.is_running.value()) + if (speedtestConf == Configs::TestConfig::COUNTRY) { - continue; + queryCountryTest(tag2entID, testCurrent); + } else + { + querySpeedtest(lastProxyListUpdate, tag2entID, testCurrent); } - auto profile = testCurrent ? running : Configs::profileManager->GetProfile(tag2entID[QString::fromStdString(res.result.value().outbound_tag.value())]); - if (profile == nullptr) - { - continue; - } - runOnUiThread([=, this, &lastProxyListUpdate] - { - showSpeedtestData = true; - currentSptProfileName = profile->bean->name; - currentTestResult = res.result.value(); - UpdateDataView(); - - if (res.result.value().error.value().empty() && !res.result.value().cancelled.value() && lastProxyListUpdate.msecsTo(QDateTime::currentDateTime()) >= 500) - { - if (!res.result.value().dl_speed.value().empty()) profile->dl_speed = QString::fromStdString(res.result.value().dl_speed.value()); - if (!res.result.value().ul_speed.value().empty()) profile->ul_speed = QString::fromStdString(res.result.value().ul_speed.value()); - if (profile->latency <= 0 && res.result.value().latency.value() > 0) profile->latency = res.result.value().latency.value(); - if (!res.result->server_country.value().empty()) profile->test_country = CountryNameToCode(QString::fromStdString(res.result.value().server_country.value())); - refresh_proxy_list(profile->id); - lastProxyListUpdate = QDateTime::currentDateTime(); - } - }); } runOnUiThread([=, this] {