update and fix local dns

This commit is contained in:
armv9 2025-09-06 21:02:29 +09:00
parent a75a15146e
commit bbf59e3b46
4 changed files with 296 additions and 169 deletions

View File

@ -6,16 +6,14 @@ import android.os.CancellationSignal
import android.system.ErrnoException import android.system.ErrnoException
import androidx.annotation.RequiresApi import androidx.annotation.RequiresApi
import io.nekohasekai.sagernet.SagerNet import io.nekohasekai.sagernet.SagerNet
import io.nekohasekai.sagernet.ktx.tryResumeWithException import io.nekohasekai.sagernet.ktx.Logs
import io.nekohasekai.sagernet.ktx.runOnIoDispatcher
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.asExecutor import kotlinx.coroutines.asExecutor
import kotlinx.coroutines.runBlocking
import libcore.ExchangeContext import libcore.ExchangeContext
import libcore.LocalDNSTransport import libcore.LocalDNSTransport
import java.net.InetAddress import java.net.InetAddress
import java.net.UnknownHostException import java.net.UnknownHostException
import kotlin.coroutines.resume
import kotlin.coroutines.suspendCoroutine
object LocalResolverImpl : LocalDNSTransport { object LocalResolverImpl : LocalDNSTransport {
@ -27,110 +25,126 @@ object LocalResolverImpl : LocalDNSTransport {
return Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q return Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q
} }
override fun networkHandle(): Long {
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q) {
return SagerNet.underlyingNetwork?.networkHandle ?: 0
}
return 0
}
@RequiresApi(Build.VERSION_CODES.Q) @RequiresApi(Build.VERSION_CODES.Q)
override fun exchange(ctx: ExchangeContext, message: ByteArray) { override fun exchange(ctx: ExchangeContext, message: ByteArray) {
return runBlocking { val signal = CancellationSignal()
suspendCoroutine { continuation -> ctx.onCancel(signal::cancel)
val signal = CancellationSignal()
ctx.onCancel(signal::cancel)
val callback = object : DnsResolver.Callback<ByteArray> {
override fun onAnswer(answer: ByteArray, rcode: Int) {
// exchange don't generate rcode error
ctx.rawSuccess(answer)
continuation.resume(Unit)
}
override fun onError(error: DnsResolver.DnsException) { val callback = object : DnsResolver.Callback<ByteArray> {
when (val cause = error.cause) { override fun onAnswer(answer: ByteArray, rcode: Int) {
is ErrnoException -> { ctx.rawSuccess(answer)
ctx.errnoCode(cause.errno) }
continuation.resume(Unit)
return override fun onError(error: DnsResolver.DnsException) {
} val cause = error.cause
if (cause is ErrnoException) {
ctx.errnoCode(cause.errno)
} else {
Logs.w(error)
ctx.errnoCode(114514)
}
}
}
DnsResolver.getInstance().rawQuery(
SagerNet.underlyingNetwork,
message,
DnsResolver.FLAG_NO_RETRY,
Dispatchers.IO.asExecutor(),
signal,
callback
)
}
override fun lookup(ctx: ExchangeContext, network: String, domain: String) {
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q) {
val signal = CancellationSignal()
ctx.onCancel(signal::cancel)
val callback = object : DnsResolver.Callback<Collection<InetAddress>> {
override fun onAnswer(answer: Collection<InetAddress>, rcode: Int) {
try {
if (rcode == 0) {
ctx.success(answer.mapNotNull { it.hostAddress }.joinToString("\n"))
} else {
ctx.errorCode(rcode)
} }
continuation.tryResumeWithException(error) } catch (e: Exception) {
Logs.w(e)
ctx.errnoCode(114514)
} }
} }
DnsResolver.getInstance().rawQuery(
override fun onError(error: DnsResolver.DnsException) {
try {
val cause = error.cause
if (cause is ErrnoException) {
ctx.errnoCode(cause.errno)
} else {
Logs.w(error)
ctx.errnoCode(114514)
}
} catch (e: Exception) {
Logs.w(e)
ctx.errnoCode(114514)
}
}
}
val type = when {
network.endsWith("4") -> DnsResolver.TYPE_A
network.endsWith("6") -> DnsResolver.TYPE_AAAA
else -> null
}
if (type != null) {
DnsResolver.getInstance().query(
SagerNet.underlyingNetwork, SagerNet.underlyingNetwork,
message, domain,
type,
DnsResolver.FLAG_NO_RETRY,
Dispatchers.IO.asExecutor(),
signal,
callback
)
} else {
DnsResolver.getInstance().query(
SagerNet.underlyingNetwork,
domain,
DnsResolver.FLAG_NO_RETRY, DnsResolver.FLAG_NO_RETRY,
Dispatchers.IO.asExecutor(), Dispatchers.IO.asExecutor(),
signal, signal,
callback callback
) )
} }
}
}
override fun lookup(ctx: ExchangeContext, network: String, domain: String) {
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q) {
return runBlocking {
suspendCoroutine { continuation ->
val signal = CancellationSignal()
ctx.onCancel(signal::cancel)
val callback = object : DnsResolver.Callback<Collection<InetAddress>> {
override fun onAnswer(answer: Collection<InetAddress>, rcode: Int) {
if (rcode == 0) {
ctx.success((answer as Collection<InetAddress?>).mapNotNull { it?.hostAddress }
.joinToString("\n"))
} else {
ctx.errorCode(rcode)
}
continuation.resume(Unit)
}
override fun onError(error: DnsResolver.DnsException) {
when (val cause = error.cause) {
is ErrnoException -> {
ctx.errnoCode(cause.errno)
continuation.resume(Unit)
return
}
}
continuation.tryResumeWithException(error)
}
}
val type = when {
network.endsWith("4") -> DnsResolver.TYPE_A
network.endsWith("6") -> DnsResolver.TYPE_AAAA
else -> null
}
if (type != null) {
DnsResolver.getInstance().query(
SagerNet.underlyingNetwork,
domain,
type,
DnsResolver.FLAG_NO_RETRY,
Dispatchers.IO.asExecutor(),
signal,
callback
)
} else {
DnsResolver.getInstance().query(
SagerNet.underlyingNetwork,
domain,
DnsResolver.FLAG_NO_RETRY,
Dispatchers.IO.asExecutor(),
signal,
callback
)
}
}
}
} else { } else {
val answer = try { runOnIoDispatcher {
val u = SagerNet.underlyingNetwork // 老版本系统,继续用阻塞的 InetAddress
if (u != null) { try {
u.getAllByName(domain) val u = SagerNet.underlyingNetwork
} else { val answer = if (u != null) {
InetAddress.getAllByName(domain) u.getAllByName(domain)
} else {
InetAddress.getAllByName(domain)
}
if (answer != null) {
ctx.success(answer.mapNotNull { it.hostAddress }.joinToString("\n"))
} else {
ctx.errnoCode(114514)
}
} catch (e: UnknownHostException) {
ctx.errorCode(RCODE_NXDOMAIN)
} catch (e: Exception) {
Logs.w(e)
ctx.errnoCode(114514)
} }
} catch (e: UnknownHostException) {
ctx.errorCode(RCODE_NXDOMAIN)
return
} }
ctx.success(answer.mapNotNull { it.hostAddress }.joinToString("\n"))
} }
} }

View File

@ -114,10 +114,7 @@ func nekoboxAndroidDNSTransportRegistry(localTransport LocalDNSTransport) *dns.T
local.RegisterTransport(registry) local.RegisterTransport(registry)
} else { } else {
dns.RegisterTransport(registry, "local", func(ctx context.Context, logger log.ContextLogger, tag string, options option.LocalDNSServerOptions) (adapter.DNSTransport, error) { dns.RegisterTransport(registry, "local", func(ctx context.Context, logger log.ContextLogger, tag string, options option.LocalDNSServerOptions) (adapter.DNSTransport, error) {
return &platformLocalDNSTransport{ return newPlatformTransport(localTransport, tag, options), nil
iif: localTransport,
tag: tag,
}, nil
}) })
} }

114
libcore/dns_android.go Normal file
View File

@ -0,0 +1,114 @@
//go:build android && cgo
package libcore
/*
#include <stdint.h>
#include <stddef.h>
#include <stdlib.h>
#include <dlfcn.h>
typedef int (*android_res_nsend_t)(uint64_t network, const uint8_t* msg, size_t msglen, int flags);
typedef int (*android_res_nresult_t)(int fd, int* rcode, uint8_t* resp, size_t resp_len);
static int call_android_res_nsend(void* sym, uint64_t network, const uint8_t* msg, size_t msglen, int flags) {
android_res_nsend_t f = (android_res_nsend_t)sym;
if (!f) return -1;
return f(network, msg, msglen, flags);
}
static int call_android_res_nresult(void* sym, int fd, int* rcode, uint8_t* resp, size_t resp_len) {
android_res_nresult_t f = (android_res_nresult_t)sym;
if (!f) return -1;
return f(fd, rcode, resp, resp_len);
}
*/
import "C"
import (
"context"
"errors"
"os"
"unsafe"
"golang.org/x/sys/unix"
)
func init() {
libname := C.CString("libandroid.so")
defer C.free(unsafe.Pointer(libname))
libHandle := C.dlopen(libname, C.int(C.RTLD_NOW))
if libHandle == nil {
return
}
symNameSend := C.CString("android_res_nsend")
defer C.free(unsafe.Pointer(symNameSend))
androidResNSendSym := C.dlsym(libHandle, symNameSend)
if androidResNSendSym == nil {
return
}
symNameResult := C.CString("android_res_nresult")
defer C.free(unsafe.Pointer(symNameResult))
androidResNResultSym := C.dlsym(libHandle, symNameResult)
if androidResNResultSym == nil {
return
}
callAndroidResNSend := func(network uint64, msg []byte) (int, error) {
if len(msg) == 0 {
return 0, errors.New("empty payload")
}
msgPtr := (*C.uint8_t)(unsafe.Pointer(&msg[0]))
msgLen := C.size_t(len(msg))
ret := C.call_android_res_nsend(androidResNSendSym, C.uint64_t(network), msgPtr, msgLen, C.int(0))
return int(ret), nil
}
callAndroidResNResult := func(fd int, resp []byte) (int, int) {
if len(resp) == 0 {
return 0, 0
}
respPtr := (*C.uint8_t)(unsafe.Pointer(&resp[0]))
respLen := C.size_t(len(resp))
var rcode C.int
n := C.call_android_res_nresult(androidResNResultSym, C.int(fd), &rcode, respPtr, respLen)
return int(rcode), int(n)
}
// set rawQueryFunc
rawQueryFunc = func(networkHandle int64, request []byte) ([]byte, error) {
fd, err := callAndroidResNSend(uint64(networkHandle), request)
if err != nil {
return nil, err
}
if fd < 0 {
return nil, unix.Errno(-fd)
}
// wait for response (timeout 5000 ms)
pfds := []unix.PollFd{{Fd: int32(fd), Events: unix.POLLIN | unix.POLLERR}}
nReady, err := unix.Poll(pfds, 5000)
if err != nil {
unix.Close(fd)
return nil, err
}
if nReady == 0 {
unix.Close(fd)
return nil, context.DeadlineExceeded
}
// read response into buffer
response := make([]byte, 8192)
_, n := callAndroidResNResult(fd, response)
if n < 0 {
return nil, unix.Errno(-n)
}
if n == 0 {
return nil, os.ErrInvalid
}
return response[:n], nil
}
}

View File

@ -6,11 +6,13 @@ import (
"context" "context"
"net/netip" "net/netip"
"strings" "strings"
"sync"
"syscall" "syscall"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/dns" "github.com/sagernet/sing-box/dns"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
@ -19,8 +21,11 @@ import (
mDNS "github.com/miekg/dns" mDNS "github.com/miekg/dns"
) )
var rawQueryFunc func(networkHandle int64, request []byte) ([]byte, error)
type LocalDNSTransport interface { type LocalDNSTransport interface {
Raw() bool Raw() bool
NetworkHandle() int64
Lookup(ctx *ExchangeContext, network string, domain string) error Lookup(ctx *ExchangeContext, network string, domain string) error
Exchange(ctx *ExchangeContext, message []byte) error Exchange(ctx *ExchangeContext, message []byte) error
} }
@ -28,98 +33,91 @@ type LocalDNSTransport interface {
var gLocalDNSTransport *platformLocalDNSTransport = nil var gLocalDNSTransport *platformLocalDNSTransport = nil
type platformLocalDNSTransport struct { type platformLocalDNSTransport struct {
dns.TransportAdapter
iif LocalDNSTransport iif LocalDNSTransport
tag string raw bool
} }
func (p *platformLocalDNSTransport) Start(adapter.StartStage) error { func newPlatformTransport(iif LocalDNSTransport, tag string, options option.LocalDNSServerOptions) *platformLocalDNSTransport {
return &platformLocalDNSTransport{
TransportAdapter: dns.NewTransportAdapterWithLocalOptions(constant.DNSTypeLocal, tag, options),
iif: iif,
raw: iif.Raw(),
}
}
func (p *platformLocalDNSTransport) Start(stage adapter.StartStage) error {
return nil return nil
} }
func (p *platformLocalDNSTransport) Reset() {
}
func (p *platformLocalDNSTransport) Close() error { func (p *platformLocalDNSTransport) Close() error {
return nil return nil
} }
func (p *platformLocalDNSTransport) Raw() bool {
return p.iif.Raw()
}
func (p *platformLocalDNSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { func (p *platformLocalDNSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
messageBytes, err := message.Pack() if p.raw && rawQueryFunc != nil {
if err != nil { // Raw - Android 10 及以上才有
return nil, err
}
response := &ExchangeContext{
context: ctx,
}
var responseMessage *mDNS.Msg
return responseMessage, task.Run(ctx, func() error {
err = p.iif.Exchange(response, messageBytes)
if err != nil {
return err
}
if response.error != nil {
return response.error
}
responseMessage = &response.message
return nil
})
}
func (p *platformLocalDNSTransport) Lookup(ctx context.Context, domain string, strategy constant.DomainStrategy) ([]netip.Addr, error) { messageBytes, err := message.Pack()
var network string
switch strategy {
case constant.DomainStrategyIPv4Only:
network = "ip4"
case constant.DomainStrategyPreferIPv6:
network = "ip6"
default:
network = "ip"
}
response := &ExchangeContext{
context: ctx,
}
var responseAddr []netip.Addr
return responseAddr, task.Run(ctx, func() error {
err := p.iif.Lookup(response, network, domain)
if err != nil { if err != nil {
return err return nil, err
} }
if response.error != nil { msg, err := rawQueryFunc(p.iif.NetworkHandle(), messageBytes)
return response.error if err != nil {
return nil, err
} }
switch strategy { responseMessage := new(mDNS.Msg)
case constant.DomainStrategyIPv4Only: err = responseMessage.Unpack(msg)
responseAddr = common.Filter(response.addresses, func(it netip.Addr) bool { if err != nil {
return it.Is4() return nil, err
}) }
case constant.DomainStrategyPreferIPv6: return responseMessage, nil
responseAddr = common.Filter(response.addresses, func(it netip.Addr) bool { } else {
return it.Is6() // Lookup - Android 10 以下
})
question := message.Question[0]
var network string
switch question.Qtype {
case mDNS.TypeA:
network = "ip4"
case mDNS.TypeAAAA:
network = "ip6"
default: default:
responseAddr = response.addresses return nil, E.New("only IP queries are supported by current version of Android")
} }
/*if len(responseAddr) == 0 {
response.error = dns.RCodeSuccess
}*/
return nil
})
}
func (p *platformLocalDNSTransport) Tag() string { done := make(chan struct{})
return p.tag response := &ExchangeContext{
} context: ctx,
done: sync.OnceFunc(func() {
close(done)
}),
}
func (p *platformLocalDNSTransport) Type() string { var responseAddrs []netip.Addr
return "local" var group task.Group
} group.Append0(func(ctx context.Context) error {
err := p.iif.Lookup(response, network, question.Name)
func (p *platformLocalDNSTransport) Dependencies() []string { if err != nil {
return nil return err
}
select {
case <-done:
case <-ctx.Done():
return context.Canceled
}
if response.error != nil {
return response.error
}
responseAddrs = response.addresses
return nil
})
err := group.Run(ctx)
if err != nil {
return nil, err
}
return dns.FixedResponse(message.Id, question, responseAddrs, constant.DefaultDNSTTL), nil
}
} }
type Func interface { type Func interface {
@ -131,6 +129,7 @@ type ExchangeContext struct {
message mDNS.Msg message mDNS.Msg
addresses []netip.Addr addresses []netip.Addr
error error error error
done func()
} }
func (c *ExchangeContext) OnCancel(callback Func) { func (c *ExchangeContext) OnCancel(callback Func) {
@ -153,12 +152,15 @@ func (c *ExchangeContext) RawSuccess(result []byte) {
if err != nil { if err != nil {
c.error = E.Cause(err, "parse response") c.error = E.Cause(err, "parse response")
} }
c.done()
} }
func (c *ExchangeContext) ErrorCode(code int32) { func (c *ExchangeContext) ErrorCode(code int32) {
c.error = dns.RcodeError(code) c.error = dns.RcodeError(code)
c.done()
} }
func (c *ExchangeContext) ErrnoCode(code int32) { func (c *ExchangeContext) ErrnoCode(code int32) {
c.error = syscall.Errno(code) c.error = syscall.Errno(code)
c.done()
} }