update and fix local dns

This commit is contained in:
armv9 2025-09-06 21:02:29 +09:00
parent a75a15146e
commit bbf59e3b46
4 changed files with 296 additions and 169 deletions

View File

@ -6,16 +6,14 @@ import android.os.CancellationSignal
import android.system.ErrnoException
import androidx.annotation.RequiresApi
import io.nekohasekai.sagernet.SagerNet
import io.nekohasekai.sagernet.ktx.tryResumeWithException
import io.nekohasekai.sagernet.ktx.Logs
import io.nekohasekai.sagernet.ktx.runOnIoDispatcher
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.asExecutor
import kotlinx.coroutines.runBlocking
import libcore.ExchangeContext
import libcore.LocalDNSTransport
import java.net.InetAddress
import java.net.UnknownHostException
import kotlin.coroutines.resume
import kotlin.coroutines.suspendCoroutine
object LocalResolverImpl : LocalDNSTransport {
@ -27,110 +25,126 @@ object LocalResolverImpl : LocalDNSTransport {
return Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q
}
override fun networkHandle(): Long {
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q) {
return SagerNet.underlyingNetwork?.networkHandle ?: 0
}
return 0
}
@RequiresApi(Build.VERSION_CODES.Q)
override fun exchange(ctx: ExchangeContext, message: ByteArray) {
return runBlocking {
suspendCoroutine { continuation ->
val signal = CancellationSignal()
ctx.onCancel(signal::cancel)
val callback = object : DnsResolver.Callback<ByteArray> {
override fun onAnswer(answer: ByteArray, rcode: Int) {
// exchange don't generate rcode error
ctx.rawSuccess(answer)
continuation.resume(Unit)
}
val signal = CancellationSignal()
ctx.onCancel(signal::cancel)
override fun onError(error: DnsResolver.DnsException) {
when (val cause = error.cause) {
is ErrnoException -> {
ctx.errnoCode(cause.errno)
continuation.resume(Unit)
return
}
val callback = object : DnsResolver.Callback<ByteArray> {
override fun onAnswer(answer: ByteArray, rcode: Int) {
ctx.rawSuccess(answer)
}
override fun onError(error: DnsResolver.DnsException) {
val cause = error.cause
if (cause is ErrnoException) {
ctx.errnoCode(cause.errno)
} else {
Logs.w(error)
ctx.errnoCode(114514)
}
}
}
DnsResolver.getInstance().rawQuery(
SagerNet.underlyingNetwork,
message,
DnsResolver.FLAG_NO_RETRY,
Dispatchers.IO.asExecutor(),
signal,
callback
)
}
override fun lookup(ctx: ExchangeContext, network: String, domain: String) {
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q) {
val signal = CancellationSignal()
ctx.onCancel(signal::cancel)
val callback = object : DnsResolver.Callback<Collection<InetAddress>> {
override fun onAnswer(answer: Collection<InetAddress>, rcode: Int) {
try {
if (rcode == 0) {
ctx.success(answer.mapNotNull { it.hostAddress }.joinToString("\n"))
} else {
ctx.errorCode(rcode)
}
continuation.tryResumeWithException(error)
} catch (e: Exception) {
Logs.w(e)
ctx.errnoCode(114514)
}
}
DnsResolver.getInstance().rawQuery(
override fun onError(error: DnsResolver.DnsException) {
try {
val cause = error.cause
if (cause is ErrnoException) {
ctx.errnoCode(cause.errno)
} else {
Logs.w(error)
ctx.errnoCode(114514)
}
} catch (e: Exception) {
Logs.w(e)
ctx.errnoCode(114514)
}
}
}
val type = when {
network.endsWith("4") -> DnsResolver.TYPE_A
network.endsWith("6") -> DnsResolver.TYPE_AAAA
else -> null
}
if (type != null) {
DnsResolver.getInstance().query(
SagerNet.underlyingNetwork,
message,
domain,
type,
DnsResolver.FLAG_NO_RETRY,
Dispatchers.IO.asExecutor(),
signal,
callback
)
} else {
DnsResolver.getInstance().query(
SagerNet.underlyingNetwork,
domain,
DnsResolver.FLAG_NO_RETRY,
Dispatchers.IO.asExecutor(),
signal,
callback
)
}
}
}
override fun lookup(ctx: ExchangeContext, network: String, domain: String) {
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q) {
return runBlocking {
suspendCoroutine { continuation ->
val signal = CancellationSignal()
ctx.onCancel(signal::cancel)
val callback = object : DnsResolver.Callback<Collection<InetAddress>> {
override fun onAnswer(answer: Collection<InetAddress>, rcode: Int) {
if (rcode == 0) {
ctx.success((answer as Collection<InetAddress?>).mapNotNull { it?.hostAddress }
.joinToString("\n"))
} else {
ctx.errorCode(rcode)
}
continuation.resume(Unit)
}
override fun onError(error: DnsResolver.DnsException) {
when (val cause = error.cause) {
is ErrnoException -> {
ctx.errnoCode(cause.errno)
continuation.resume(Unit)
return
}
}
continuation.tryResumeWithException(error)
}
}
val type = when {
network.endsWith("4") -> DnsResolver.TYPE_A
network.endsWith("6") -> DnsResolver.TYPE_AAAA
else -> null
}
if (type != null) {
DnsResolver.getInstance().query(
SagerNet.underlyingNetwork,
domain,
type,
DnsResolver.FLAG_NO_RETRY,
Dispatchers.IO.asExecutor(),
signal,
callback
)
} else {
DnsResolver.getInstance().query(
SagerNet.underlyingNetwork,
domain,
DnsResolver.FLAG_NO_RETRY,
Dispatchers.IO.asExecutor(),
signal,
callback
)
}
}
}
} else {
val answer = try {
val u = SagerNet.underlyingNetwork
if (u != null) {
u.getAllByName(domain)
} else {
InetAddress.getAllByName(domain)
runOnIoDispatcher {
// 老版本系统,继续用阻塞的 InetAddress
try {
val u = SagerNet.underlyingNetwork
val answer = if (u != null) {
u.getAllByName(domain)
} else {
InetAddress.getAllByName(domain)
}
if (answer != null) {
ctx.success(answer.mapNotNull { it.hostAddress }.joinToString("\n"))
} else {
ctx.errnoCode(114514)
}
} catch (e: UnknownHostException) {
ctx.errorCode(RCODE_NXDOMAIN)
} catch (e: Exception) {
Logs.w(e)
ctx.errnoCode(114514)
}
} catch (e: UnknownHostException) {
ctx.errorCode(RCODE_NXDOMAIN)
return
}
ctx.success(answer.mapNotNull { it.hostAddress }.joinToString("\n"))
}
}

View File

@ -114,10 +114,7 @@ func nekoboxAndroidDNSTransportRegistry(localTransport LocalDNSTransport) *dns.T
local.RegisterTransport(registry)
} else {
dns.RegisterTransport(registry, "local", func(ctx context.Context, logger log.ContextLogger, tag string, options option.LocalDNSServerOptions) (adapter.DNSTransport, error) {
return &platformLocalDNSTransport{
iif: localTransport,
tag: tag,
}, nil
return newPlatformTransport(localTransport, tag, options), nil
})
}

114
libcore/dns_android.go Normal file
View File

@ -0,0 +1,114 @@
//go:build android && cgo
package libcore
/*
#include <stdint.h>
#include <stddef.h>
#include <stdlib.h>
#include <dlfcn.h>
typedef int (*android_res_nsend_t)(uint64_t network, const uint8_t* msg, size_t msglen, int flags);
typedef int (*android_res_nresult_t)(int fd, int* rcode, uint8_t* resp, size_t resp_len);
static int call_android_res_nsend(void* sym, uint64_t network, const uint8_t* msg, size_t msglen, int flags) {
android_res_nsend_t f = (android_res_nsend_t)sym;
if (!f) return -1;
return f(network, msg, msglen, flags);
}
static int call_android_res_nresult(void* sym, int fd, int* rcode, uint8_t* resp, size_t resp_len) {
android_res_nresult_t f = (android_res_nresult_t)sym;
if (!f) return -1;
return f(fd, rcode, resp, resp_len);
}
*/
import "C"
import (
"context"
"errors"
"os"
"unsafe"
"golang.org/x/sys/unix"
)
func init() {
libname := C.CString("libandroid.so")
defer C.free(unsafe.Pointer(libname))
libHandle := C.dlopen(libname, C.int(C.RTLD_NOW))
if libHandle == nil {
return
}
symNameSend := C.CString("android_res_nsend")
defer C.free(unsafe.Pointer(symNameSend))
androidResNSendSym := C.dlsym(libHandle, symNameSend)
if androidResNSendSym == nil {
return
}
symNameResult := C.CString("android_res_nresult")
defer C.free(unsafe.Pointer(symNameResult))
androidResNResultSym := C.dlsym(libHandle, symNameResult)
if androidResNResultSym == nil {
return
}
callAndroidResNSend := func(network uint64, msg []byte) (int, error) {
if len(msg) == 0 {
return 0, errors.New("empty payload")
}
msgPtr := (*C.uint8_t)(unsafe.Pointer(&msg[0]))
msgLen := C.size_t(len(msg))
ret := C.call_android_res_nsend(androidResNSendSym, C.uint64_t(network), msgPtr, msgLen, C.int(0))
return int(ret), nil
}
callAndroidResNResult := func(fd int, resp []byte) (int, int) {
if len(resp) == 0 {
return 0, 0
}
respPtr := (*C.uint8_t)(unsafe.Pointer(&resp[0]))
respLen := C.size_t(len(resp))
var rcode C.int
n := C.call_android_res_nresult(androidResNResultSym, C.int(fd), &rcode, respPtr, respLen)
return int(rcode), int(n)
}
// set rawQueryFunc
rawQueryFunc = func(networkHandle int64, request []byte) ([]byte, error) {
fd, err := callAndroidResNSend(uint64(networkHandle), request)
if err != nil {
return nil, err
}
if fd < 0 {
return nil, unix.Errno(-fd)
}
// wait for response (timeout 5000 ms)
pfds := []unix.PollFd{{Fd: int32(fd), Events: unix.POLLIN | unix.POLLERR}}
nReady, err := unix.Poll(pfds, 5000)
if err != nil {
unix.Close(fd)
return nil, err
}
if nReady == 0 {
unix.Close(fd)
return nil, context.DeadlineExceeded
}
// read response into buffer
response := make([]byte, 8192)
_, n := callAndroidResNResult(fd, response)
if n < 0 {
return nil, unix.Errno(-n)
}
if n == 0 {
return nil, os.ErrInvalid
}
return response[:n], nil
}
}

View File

@ -6,11 +6,13 @@ import (
"context"
"net/netip"
"strings"
"sync"
"syscall"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/dns"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
@ -19,8 +21,11 @@ import (
mDNS "github.com/miekg/dns"
)
var rawQueryFunc func(networkHandle int64, request []byte) ([]byte, error)
type LocalDNSTransport interface {
Raw() bool
NetworkHandle() int64
Lookup(ctx *ExchangeContext, network string, domain string) error
Exchange(ctx *ExchangeContext, message []byte) error
}
@ -28,98 +33,91 @@ type LocalDNSTransport interface {
var gLocalDNSTransport *platformLocalDNSTransport = nil
type platformLocalDNSTransport struct {
dns.TransportAdapter
iif LocalDNSTransport
tag string
raw bool
}
func (p *platformLocalDNSTransport) Start(adapter.StartStage) error {
func newPlatformTransport(iif LocalDNSTransport, tag string, options option.LocalDNSServerOptions) *platformLocalDNSTransport {
return &platformLocalDNSTransport{
TransportAdapter: dns.NewTransportAdapterWithLocalOptions(constant.DNSTypeLocal, tag, options),
iif: iif,
raw: iif.Raw(),
}
}
func (p *platformLocalDNSTransport) Start(stage adapter.StartStage) error {
return nil
}
func (p *platformLocalDNSTransport) Reset() {
}
func (p *platformLocalDNSTransport) Close() error {
return nil
}
func (p *platformLocalDNSTransport) Raw() bool {
return p.iif.Raw()
}
func (p *platformLocalDNSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
messageBytes, err := message.Pack()
if err != nil {
return nil, err
}
response := &ExchangeContext{
context: ctx,
}
var responseMessage *mDNS.Msg
return responseMessage, task.Run(ctx, func() error {
err = p.iif.Exchange(response, messageBytes)
if err != nil {
return err
}
if response.error != nil {
return response.error
}
responseMessage = &response.message
return nil
})
}
if p.raw && rawQueryFunc != nil {
// Raw - Android 10 及以上才有
func (p *platformLocalDNSTransport) Lookup(ctx context.Context, domain string, strategy constant.DomainStrategy) ([]netip.Addr, error) {
var network string
switch strategy {
case constant.DomainStrategyIPv4Only:
network = "ip4"
case constant.DomainStrategyPreferIPv6:
network = "ip6"
default:
network = "ip"
}
response := &ExchangeContext{
context: ctx,
}
var responseAddr []netip.Addr
return responseAddr, task.Run(ctx, func() error {
err := p.iif.Lookup(response, network, domain)
messageBytes, err := message.Pack()
if err != nil {
return err
return nil, err
}
if response.error != nil {
return response.error
msg, err := rawQueryFunc(p.iif.NetworkHandle(), messageBytes)
if err != nil {
return nil, err
}
switch strategy {
case constant.DomainStrategyIPv4Only:
responseAddr = common.Filter(response.addresses, func(it netip.Addr) bool {
return it.Is4()
})
case constant.DomainStrategyPreferIPv6:
responseAddr = common.Filter(response.addresses, func(it netip.Addr) bool {
return it.Is6()
})
responseMessage := new(mDNS.Msg)
err = responseMessage.Unpack(msg)
if err != nil {
return nil, err
}
return responseMessage, nil
} else {
// Lookup - Android 10 以下
question := message.Question[0]
var network string
switch question.Qtype {
case mDNS.TypeA:
network = "ip4"
case mDNS.TypeAAAA:
network = "ip6"
default:
responseAddr = response.addresses
return nil, E.New("only IP queries are supported by current version of Android")
}
/*if len(responseAddr) == 0 {
response.error = dns.RCodeSuccess
}*/
return nil
})
}
func (p *platformLocalDNSTransport) Tag() string {
return p.tag
}
done := make(chan struct{})
response := &ExchangeContext{
context: ctx,
done: sync.OnceFunc(func() {
close(done)
}),
}
func (p *platformLocalDNSTransport) Type() string {
return "local"
}
func (p *platformLocalDNSTransport) Dependencies() []string {
return nil
var responseAddrs []netip.Addr
var group task.Group
group.Append0(func(ctx context.Context) error {
err := p.iif.Lookup(response, network, question.Name)
if err != nil {
return err
}
select {
case <-done:
case <-ctx.Done():
return context.Canceled
}
if response.error != nil {
return response.error
}
responseAddrs = response.addresses
return nil
})
err := group.Run(ctx)
if err != nil {
return nil, err
}
return dns.FixedResponse(message.Id, question, responseAddrs, constant.DefaultDNSTTL), nil
}
}
type Func interface {
@ -131,6 +129,7 @@ type ExchangeContext struct {
message mDNS.Msg
addresses []netip.Addr
error error
done func()
}
func (c *ExchangeContext) OnCancel(callback Func) {
@ -153,12 +152,15 @@ func (c *ExchangeContext) RawSuccess(result []byte) {
if err != nil {
c.error = E.Cause(err, "parse response")
}
c.done()
}
func (c *ExchangeContext) ErrorCode(code int32) {
c.error = dns.RcodeError(code)
c.done()
}
func (c *ExchangeContext) ErrnoCode(code int32) {
c.error = syscall.Errno(code)
c.done()
}