From bbf59e3b4625c04e86f3d78d2944f5839bc06f74 Mon Sep 17 00:00:00 2001 From: armv9 <48624112+arm64v8a@users.noreply.github.com> Date: Sat, 6 Sep 2025 21:02:29 +0900 Subject: [PATCH] update and fix local dns --- .../moe/matsuri/nb4a/net/LocalResolverImpl.kt | 198 ++++++++++-------- libcore/box_include.go | 5 +- libcore/dns_android.go | 114 ++++++++++ libcore/dns_box.go | 148 ++++++------- 4 files changed, 296 insertions(+), 169 deletions(-) create mode 100644 libcore/dns_android.go diff --git a/app/src/main/java/moe/matsuri/nb4a/net/LocalResolverImpl.kt b/app/src/main/java/moe/matsuri/nb4a/net/LocalResolverImpl.kt index 27c08e3..7de35b4 100644 --- a/app/src/main/java/moe/matsuri/nb4a/net/LocalResolverImpl.kt +++ b/app/src/main/java/moe/matsuri/nb4a/net/LocalResolverImpl.kt @@ -6,16 +6,14 @@ import android.os.CancellationSignal import android.system.ErrnoException import androidx.annotation.RequiresApi import io.nekohasekai.sagernet.SagerNet -import io.nekohasekai.sagernet.ktx.tryResumeWithException +import io.nekohasekai.sagernet.ktx.Logs +import io.nekohasekai.sagernet.ktx.runOnIoDispatcher import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.asExecutor -import kotlinx.coroutines.runBlocking import libcore.ExchangeContext import libcore.LocalDNSTransport import java.net.InetAddress import java.net.UnknownHostException -import kotlin.coroutines.resume -import kotlin.coroutines.suspendCoroutine object LocalResolverImpl : LocalDNSTransport { @@ -27,110 +25,126 @@ object LocalResolverImpl : LocalDNSTransport { return Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q } + override fun networkHandle(): Long { + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q) { + return SagerNet.underlyingNetwork?.networkHandle ?: 0 + } + return 0 + } + @RequiresApi(Build.VERSION_CODES.Q) override fun exchange(ctx: ExchangeContext, message: ByteArray) { - return runBlocking { - suspendCoroutine { continuation -> - val signal = CancellationSignal() - ctx.onCancel(signal::cancel) - val callback = object : DnsResolver.Callback { - override fun onAnswer(answer: ByteArray, rcode: Int) { - // exchange don't generate rcode error - ctx.rawSuccess(answer) - continuation.resume(Unit) - } + val signal = CancellationSignal() + ctx.onCancel(signal::cancel) - override fun onError(error: DnsResolver.DnsException) { - when (val cause = error.cause) { - is ErrnoException -> { - ctx.errnoCode(cause.errno) - continuation.resume(Unit) - return - } + val callback = object : DnsResolver.Callback { + override fun onAnswer(answer: ByteArray, rcode: Int) { + ctx.rawSuccess(answer) + } + + override fun onError(error: DnsResolver.DnsException) { + val cause = error.cause + if (cause is ErrnoException) { + ctx.errnoCode(cause.errno) + } else { + Logs.w(error) + ctx.errnoCode(114514) + } + } + } + + DnsResolver.getInstance().rawQuery( + SagerNet.underlyingNetwork, + message, + DnsResolver.FLAG_NO_RETRY, + Dispatchers.IO.asExecutor(), + signal, + callback + ) + } + + override fun lookup(ctx: ExchangeContext, network: String, domain: String) { + if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q) { + val signal = CancellationSignal() + ctx.onCancel(signal::cancel) + + val callback = object : DnsResolver.Callback> { + override fun onAnswer(answer: Collection, rcode: Int) { + try { + if (rcode == 0) { + ctx.success(answer.mapNotNull { it.hostAddress }.joinToString("\n")) + } else { + ctx.errorCode(rcode) } - continuation.tryResumeWithException(error) + } catch (e: Exception) { + Logs.w(e) + ctx.errnoCode(114514) } } - DnsResolver.getInstance().rawQuery( + + override fun onError(error: DnsResolver.DnsException) { + try { + val cause = error.cause + if (cause is ErrnoException) { + ctx.errnoCode(cause.errno) + } else { + Logs.w(error) + ctx.errnoCode(114514) + } + } catch (e: Exception) { + Logs.w(e) + ctx.errnoCode(114514) + } + } + } + + val type = when { + network.endsWith("4") -> DnsResolver.TYPE_A + network.endsWith("6") -> DnsResolver.TYPE_AAAA + else -> null + } + if (type != null) { + DnsResolver.getInstance().query( SagerNet.underlyingNetwork, - message, + domain, + type, + DnsResolver.FLAG_NO_RETRY, + Dispatchers.IO.asExecutor(), + signal, + callback + ) + } else { + DnsResolver.getInstance().query( + SagerNet.underlyingNetwork, + domain, DnsResolver.FLAG_NO_RETRY, Dispatchers.IO.asExecutor(), signal, callback ) } - } - } - - override fun lookup(ctx: ExchangeContext, network: String, domain: String) { - if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q) { - return runBlocking { - suspendCoroutine { continuation -> - val signal = CancellationSignal() - ctx.onCancel(signal::cancel) - val callback = object : DnsResolver.Callback> { - override fun onAnswer(answer: Collection, rcode: Int) { - if (rcode == 0) { - ctx.success((answer as Collection).mapNotNull { it?.hostAddress } - .joinToString("\n")) - } else { - ctx.errorCode(rcode) - } - continuation.resume(Unit) - } - - override fun onError(error: DnsResolver.DnsException) { - when (val cause = error.cause) { - is ErrnoException -> { - ctx.errnoCode(cause.errno) - continuation.resume(Unit) - return - } - } - continuation.tryResumeWithException(error) - } - } - val type = when { - network.endsWith("4") -> DnsResolver.TYPE_A - network.endsWith("6") -> DnsResolver.TYPE_AAAA - else -> null - } - if (type != null) { - DnsResolver.getInstance().query( - SagerNet.underlyingNetwork, - domain, - type, - DnsResolver.FLAG_NO_RETRY, - Dispatchers.IO.asExecutor(), - signal, - callback - ) - } else { - DnsResolver.getInstance().query( - SagerNet.underlyingNetwork, - domain, - DnsResolver.FLAG_NO_RETRY, - Dispatchers.IO.asExecutor(), - signal, - callback - ) - } - } - } } else { - val answer = try { - val u = SagerNet.underlyingNetwork - if (u != null) { - u.getAllByName(domain) - } else { - InetAddress.getAllByName(domain) + runOnIoDispatcher { + // 老版本系统,继续用阻塞的 InetAddress + try { + val u = SagerNet.underlyingNetwork + val answer = if (u != null) { + u.getAllByName(domain) + } else { + InetAddress.getAllByName(domain) + } + if (answer != null) { + ctx.success(answer.mapNotNull { it.hostAddress }.joinToString("\n")) + } else { + ctx.errnoCode(114514) + } + } catch (e: UnknownHostException) { + ctx.errorCode(RCODE_NXDOMAIN) + } catch (e: Exception) { + Logs.w(e) + ctx.errnoCode(114514) } - } catch (e: UnknownHostException) { - ctx.errorCode(RCODE_NXDOMAIN) - return } - ctx.success(answer.mapNotNull { it.hostAddress }.joinToString("\n")) } } diff --git a/libcore/box_include.go b/libcore/box_include.go index 082705d..5a95cd8 100644 --- a/libcore/box_include.go +++ b/libcore/box_include.go @@ -114,10 +114,7 @@ func nekoboxAndroidDNSTransportRegistry(localTransport LocalDNSTransport) *dns.T local.RegisterTransport(registry) } else { dns.RegisterTransport(registry, "local", func(ctx context.Context, logger log.ContextLogger, tag string, options option.LocalDNSServerOptions) (adapter.DNSTransport, error) { - return &platformLocalDNSTransport{ - iif: localTransport, - tag: tag, - }, nil + return newPlatformTransport(localTransport, tag, options), nil }) } diff --git a/libcore/dns_android.go b/libcore/dns_android.go new file mode 100644 index 0000000..184c4ab --- /dev/null +++ b/libcore/dns_android.go @@ -0,0 +1,114 @@ +//go:build android && cgo + +package libcore + +/* +#include +#include +#include +#include + +typedef int (*android_res_nsend_t)(uint64_t network, const uint8_t* msg, size_t msglen, int flags); +typedef int (*android_res_nresult_t)(int fd, int* rcode, uint8_t* resp, size_t resp_len); + +static int call_android_res_nsend(void* sym, uint64_t network, const uint8_t* msg, size_t msglen, int flags) { + android_res_nsend_t f = (android_res_nsend_t)sym; + if (!f) return -1; + return f(network, msg, msglen, flags); +} + +static int call_android_res_nresult(void* sym, int fd, int* rcode, uint8_t* resp, size_t resp_len) { + android_res_nresult_t f = (android_res_nresult_t)sym; + if (!f) return -1; + return f(fd, rcode, resp, resp_len); +} +*/ +import "C" + +import ( + "context" + "errors" + "os" + "unsafe" + + "golang.org/x/sys/unix" +) + +func init() { + libname := C.CString("libandroid.so") + defer C.free(unsafe.Pointer(libname)) + + libHandle := C.dlopen(libname, C.int(C.RTLD_NOW)) + if libHandle == nil { + return + } + + symNameSend := C.CString("android_res_nsend") + defer C.free(unsafe.Pointer(symNameSend)) + androidResNSendSym := C.dlsym(libHandle, symNameSend) + if androidResNSendSym == nil { + return + } + + symNameResult := C.CString("android_res_nresult") + defer C.free(unsafe.Pointer(symNameResult)) + androidResNResultSym := C.dlsym(libHandle, symNameResult) + if androidResNResultSym == nil { + return + } + + callAndroidResNSend := func(network uint64, msg []byte) (int, error) { + if len(msg) == 0 { + return 0, errors.New("empty payload") + } + msgPtr := (*C.uint8_t)(unsafe.Pointer(&msg[0])) + msgLen := C.size_t(len(msg)) + ret := C.call_android_res_nsend(androidResNSendSym, C.uint64_t(network), msgPtr, msgLen, C.int(0)) + return int(ret), nil + } + + callAndroidResNResult := func(fd int, resp []byte) (int, int) { + if len(resp) == 0 { + return 0, 0 + } + respPtr := (*C.uint8_t)(unsafe.Pointer(&resp[0])) + respLen := C.size_t(len(resp)) + var rcode C.int + n := C.call_android_res_nresult(androidResNResultSym, C.int(fd), &rcode, respPtr, respLen) + return int(rcode), int(n) + } + + // set rawQueryFunc + rawQueryFunc = func(networkHandle int64, request []byte) ([]byte, error) { + fd, err := callAndroidResNSend(uint64(networkHandle), request) + if err != nil { + return nil, err + } + if fd < 0 { + return nil, unix.Errno(-fd) + } + + // wait for response (timeout 5000 ms) + pfds := []unix.PollFd{{Fd: int32(fd), Events: unix.POLLIN | unix.POLLERR}} + nReady, err := unix.Poll(pfds, 5000) + if err != nil { + unix.Close(fd) + return nil, err + } + if nReady == 0 { + unix.Close(fd) + return nil, context.DeadlineExceeded + } + + // read response into buffer + response := make([]byte, 8192) + _, n := callAndroidResNResult(fd, response) + if n < 0 { + return nil, unix.Errno(-n) + } + if n == 0 { + return nil, os.ErrInvalid + } + return response[:n], nil + } +} diff --git a/libcore/dns_box.go b/libcore/dns_box.go index 8d2ff48..3f1cd5f 100644 --- a/libcore/dns_box.go +++ b/libcore/dns_box.go @@ -6,11 +6,13 @@ import ( "context" "net/netip" "strings" + "sync" "syscall" "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/dns" + "github.com/sagernet/sing-box/option" "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" @@ -19,8 +21,11 @@ import ( mDNS "github.com/miekg/dns" ) +var rawQueryFunc func(networkHandle int64, request []byte) ([]byte, error) + type LocalDNSTransport interface { Raw() bool + NetworkHandle() int64 Lookup(ctx *ExchangeContext, network string, domain string) error Exchange(ctx *ExchangeContext, message []byte) error } @@ -28,98 +33,91 @@ type LocalDNSTransport interface { var gLocalDNSTransport *platformLocalDNSTransport = nil type platformLocalDNSTransport struct { + dns.TransportAdapter iif LocalDNSTransport - tag string + raw bool } -func (p *platformLocalDNSTransport) Start(adapter.StartStage) error { +func newPlatformTransport(iif LocalDNSTransport, tag string, options option.LocalDNSServerOptions) *platformLocalDNSTransport { + return &platformLocalDNSTransport{ + TransportAdapter: dns.NewTransportAdapterWithLocalOptions(constant.DNSTypeLocal, tag, options), + iif: iif, + raw: iif.Raw(), + } +} + +func (p *platformLocalDNSTransport) Start(stage adapter.StartStage) error { return nil } -func (p *platformLocalDNSTransport) Reset() { -} - func (p *platformLocalDNSTransport) Close() error { return nil } -func (p *platformLocalDNSTransport) Raw() bool { - return p.iif.Raw() -} - func (p *platformLocalDNSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { - messageBytes, err := message.Pack() - if err != nil { - return nil, err - } - response := &ExchangeContext{ - context: ctx, - } - var responseMessage *mDNS.Msg - return responseMessage, task.Run(ctx, func() error { - err = p.iif.Exchange(response, messageBytes) - if err != nil { - return err - } - if response.error != nil { - return response.error - } - responseMessage = &response.message - return nil - }) -} + if p.raw && rawQueryFunc != nil { + // Raw - Android 10 及以上才有 -func (p *platformLocalDNSTransport) Lookup(ctx context.Context, domain string, strategy constant.DomainStrategy) ([]netip.Addr, error) { - var network string - switch strategy { - case constant.DomainStrategyIPv4Only: - network = "ip4" - case constant.DomainStrategyPreferIPv6: - network = "ip6" - default: - network = "ip" - } - response := &ExchangeContext{ - context: ctx, - } - var responseAddr []netip.Addr - return responseAddr, task.Run(ctx, func() error { - err := p.iif.Lookup(response, network, domain) + messageBytes, err := message.Pack() if err != nil { - return err + return nil, err } - if response.error != nil { - return response.error + msg, err := rawQueryFunc(p.iif.NetworkHandle(), messageBytes) + if err != nil { + return nil, err } - switch strategy { - case constant.DomainStrategyIPv4Only: - responseAddr = common.Filter(response.addresses, func(it netip.Addr) bool { - return it.Is4() - }) - case constant.DomainStrategyPreferIPv6: - responseAddr = common.Filter(response.addresses, func(it netip.Addr) bool { - return it.Is6() - }) + responseMessage := new(mDNS.Msg) + err = responseMessage.Unpack(msg) + if err != nil { + return nil, err + } + return responseMessage, nil + } else { + // Lookup - Android 10 以下 + + question := message.Question[0] + var network string + switch question.Qtype { + case mDNS.TypeA: + network = "ip4" + case mDNS.TypeAAAA: + network = "ip6" default: - responseAddr = response.addresses + return nil, E.New("only IP queries are supported by current version of Android") } - /*if len(responseAddr) == 0 { - response.error = dns.RCodeSuccess - }*/ - return nil - }) -} -func (p *platformLocalDNSTransport) Tag() string { - return p.tag -} + done := make(chan struct{}) + response := &ExchangeContext{ + context: ctx, + done: sync.OnceFunc(func() { + close(done) + }), + } -func (p *platformLocalDNSTransport) Type() string { - return "local" -} - -func (p *platformLocalDNSTransport) Dependencies() []string { - return nil + var responseAddrs []netip.Addr + var group task.Group + group.Append0(func(ctx context.Context) error { + err := p.iif.Lookup(response, network, question.Name) + if err != nil { + return err + } + select { + case <-done: + case <-ctx.Done(): + return context.Canceled + } + if response.error != nil { + return response.error + } + responseAddrs = response.addresses + return nil + }) + err := group.Run(ctx) + if err != nil { + return nil, err + } + return dns.FixedResponse(message.Id, question, responseAddrs, constant.DefaultDNSTTL), nil + } } type Func interface { @@ -131,6 +129,7 @@ type ExchangeContext struct { message mDNS.Msg addresses []netip.Addr error error + done func() } func (c *ExchangeContext) OnCancel(callback Func) { @@ -153,12 +152,15 @@ func (c *ExchangeContext) RawSuccess(result []byte) { if err != nil { c.error = E.Cause(err, "parse response") } + c.done() } func (c *ExchangeContext) ErrorCode(code int32) { c.error = dns.RcodeError(code) + c.done() } func (c *ExchangeContext) ErrnoCode(code int32) { c.error = syscall.Errno(code) + c.done() }