diff --git a/component/resolver/local.go b/component/resolver/service.go similarity index 73% rename from component/resolver/local.go rename to component/resolver/service.go index e8505118..8b8f1158 100644 --- a/component/resolver/local.go +++ b/component/resolver/service.go @@ -6,15 +6,15 @@ import ( D "github.com/miekg/dns" ) -var DefaultLocalServer LocalServer +var DefaultService Service -type LocalServer interface { +type Service interface { ServeMsg(ctx context.Context, msg *D.Msg) (*D.Msg, error) } // ServeMsg with a dns.Msg, return resolve dns.Msg func ServeMsg(ctx context.Context, msg *D.Msg) (*D.Msg, error) { - if server := DefaultLocalServer; server != nil { + if server := DefaultService; server != nil { return server.ServeMsg(ctx, msg) } diff --git a/context/dns.go b/context/dns.go index 1cc2067d..15143102 100644 --- a/context/dns.go +++ b/context/dns.go @@ -2,10 +2,10 @@ package context import ( "context" + "github.com/metacubex/mihomo/common/utils" "github.com/gofrs/uuid/v5" - "github.com/miekg/dns" ) const ( @@ -17,17 +17,15 @@ const ( type DNSContext struct { context.Context - id uuid.UUID - msg *dns.Msg - tp string + id uuid.UUID + tp string } -func NewDNSContext(ctx context.Context, msg *dns.Msg) *DNSContext { +func NewDNSContext(ctx context.Context) *DNSContext { return &DNSContext{ Context: ctx, - id: utils.NewUUIDV4(), - msg: msg, + id: utils.NewUUIDV4(), } } diff --git a/dns/local.go b/dns/local.go deleted file mode 100644 index 37b5d41b..00000000 --- a/dns/local.go +++ /dev/null @@ -1,20 +0,0 @@ -package dns - -import ( - "context" - - D "github.com/miekg/dns" -) - -type LocalServer struct { - handler handler -} - -// ServeMsg implement resolver.LocalServer ResolveMsg -func (s *LocalServer) ServeMsg(ctx context.Context, msg *D.Msg) (*D.Msg, error) { - return handlerWithContext(ctx, s.handler, msg) -} - -func NewLocalServer(resolver *Resolver, mapper *ResolverEnhancer) *LocalServer { - return &LocalServer{handler: NewHandler(resolver, mapper)} -} diff --git a/dns/middleware.go b/dns/middleware.go index 5b0c0d85..502a37e5 100644 --- a/dns/middleware.go +++ b/dns/middleware.go @@ -7,22 +7,22 @@ import ( "github.com/metacubex/mihomo/common/lru" "github.com/metacubex/mihomo/component/fakeip" - R "github.com/metacubex/mihomo/component/resolver" + "github.com/metacubex/mihomo/component/resolver" C "github.com/metacubex/mihomo/constant" - "github.com/metacubex/mihomo/context" + icontext "github.com/metacubex/mihomo/context" "github.com/metacubex/mihomo/log" D "github.com/miekg/dns" ) type ( - handler func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) + handler func(ctx *icontext.DNSContext, r *D.Msg) (*D.Msg, error) middleware func(next handler) handler ) func withHosts(mapping *lru.LruCache[netip.Addr, string]) middleware { return func(next handler) handler { - return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) { + return func(ctx *icontext.DNSContext, r *D.Msg) (*D.Msg, error) { q := r.Question[0] if !isIPRequest(q) { @@ -36,7 +36,7 @@ func withHosts(mapping *lru.LruCache[netip.Addr, string]) middleware { rr.Target = domain + "." resp.Answer = append([]D.RR{rr}, resp.Answer...) } - record, ok := R.DefaultHosts.Search(host, q.Qtype != D.TypeA && q.Qtype != D.TypeAAAA) + record, ok := resolver.DefaultHosts.Search(host, q.Qtype != D.TypeA && q.Qtype != D.TypeAAAA) if !ok { if record != nil && record.IsDomain { // replace request domain @@ -88,7 +88,7 @@ func withHosts(mapping *lru.LruCache[netip.Addr, string]) middleware { return next(ctx, r) } - ctx.SetType(context.DNSTypeHost) + ctx.SetType(icontext.DNSTypeHost) msg.SetRcode(r, D.RcodeSuccess) msg.Authoritative = true msg.RecursionAvailable = true @@ -99,7 +99,7 @@ func withHosts(mapping *lru.LruCache[netip.Addr, string]) middleware { func withMapping(mapping *lru.LruCache[netip.Addr, string]) middleware { return func(next handler) handler { - return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) { + return func(ctx *icontext.DNSContext, r *D.Msg) (*D.Msg, error) { q := r.Question[0] if !isIPRequest(q) { @@ -149,7 +149,7 @@ func withMapping(mapping *lru.LruCache[netip.Addr, string]) middleware { func withFakeIP(fakePool *fakeip.Pool) middleware { return func(next handler) handler { - return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) { + return func(ctx *icontext.DNSContext, r *D.Msg) (*D.Msg, error) { q := r.Question[0] host := strings.TrimRight(q.Name, ".") @@ -173,7 +173,7 @@ func withFakeIP(fakePool *fakeip.Pool) middleware { msg := r.Copy() msg.Answer = []D.RR{rr} - ctx.SetType(context.DNSTypeFakeIP) + ctx.SetType(icontext.DNSTypeFakeIP) setMsgTTL(msg, 1) msg.SetRcode(r, D.RcodeSuccess) msg.Authoritative = true @@ -185,8 +185,8 @@ func withFakeIP(fakePool *fakeip.Pool) middleware { } func withResolver(resolver *Resolver) handler { - return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) { - ctx.SetType(context.DNSTypeRaw) + return func(ctx *icontext.DNSContext, r *D.Msg) (*D.Msg, error) { + ctx.SetType(icontext.DNSTypeRaw) q := r.Question[0] @@ -218,8 +218,8 @@ func compose(middlewares []middleware, endpoint handler) handler { return h } -func NewHandler(resolver *Resolver, mapper *ResolverEnhancer) handler { - middlewares := []middleware{} +func newHandler(resolver *Resolver, mapper *ResolverEnhancer) handler { + var middlewares []middleware if mapper.useHosts { middlewares = append(middlewares, withHosts(mapper.mapping)) diff --git a/dns/server.go b/dns/server.go index 541aeee4..b1224c62 100644 --- a/dns/server.go +++ b/dns/server.go @@ -1,13 +1,12 @@ package dns import ( - stdContext "context" - "errors" + "context" "net" "github.com/metacubex/mihomo/adapter/inbound" "github.com/metacubex/mihomo/common/sockopt" - "github.com/metacubex/mihomo/context" + "github.com/metacubex/mihomo/component/resolver" "github.com/metacubex/mihomo/log" D "github.com/miekg/dns" @@ -21,39 +20,32 @@ var ( ) type Server struct { - handler handler + service resolver.Service tcpServer *D.Server udpServer *D.Server } // ServeDNS implement D.Handler ServeDNS func (s *Server) ServeDNS(w D.ResponseWriter, r *D.Msg) { - msg, err := handlerWithContext(stdContext.Background(), s.handler, r) + msg, err := s.service.ServeMsg(context.Background(), r) if err != nil { - D.HandleFailed(w, r) + m := new(D.Msg) + m.SetRcode(r, D.RcodeServerFailure) + // does not matter if this write fails + w.WriteMsg(m) return } msg.Compress = true w.WriteMsg(msg) } -func handlerWithContext(stdCtx stdContext.Context, handler handler, msg *D.Msg) (*D.Msg, error) { - if len(msg.Question) == 0 { - return nil, errors.New("at least one question is required") - } - - ctx := context.NewDNSContext(stdCtx, msg) - return handler(ctx, msg) +func (s *Server) SetService(service resolver.Service) { + s.service = service } -func (s *Server) SetHandler(handler handler) { - s.handler = handler -} - -func ReCreateServer(addr string, resolver *Resolver, mapper *ResolverEnhancer) { - if addr == address && resolver != nil && mapper != nil { - handler := NewHandler(resolver, mapper) - server.SetHandler(handler) +func ReCreateServer(addr string, service resolver.Service) { + if addr == address && service != nil { + server.SetService(service) return } @@ -67,10 +59,10 @@ func ReCreateServer(addr string, resolver *Resolver, mapper *ResolverEnhancer) { server.udpServer = nil } - server.handler = nil + server.service = nil address = "" - if addr == "" || resolver == nil || mapper == nil { + if addr == "" || service == nil { return } @@ -87,8 +79,7 @@ func ReCreateServer(addr string, resolver *Resolver, mapper *ResolverEnhancer) { } address = addr - handler := NewHandler(resolver, mapper) - server = &Server{handler: handler} + server = &Server{service: service} go func() { p, err := inbound.ListenPacket("udp", addr) diff --git a/dns/service.go b/dns/service.go new file mode 100644 index 00000000..4a7c1bb2 --- /dev/null +++ b/dns/service.go @@ -0,0 +1,29 @@ +package dns + +import ( + "context" + "errors" + + "github.com/metacubex/mihomo/component/resolver" + icontext "github.com/metacubex/mihomo/context" + D "github.com/miekg/dns" +) + +type Service struct { + handler handler +} + +// ServeMsg implement [resolver.Service] ResolveMsg +func (s *Service) ServeMsg(ctx context.Context, msg *D.Msg) (*D.Msg, error) { + if len(msg.Question) == 0 { + return nil, errors.New("at least one question is required") + } + + return s.handler(icontext.NewDNSContext(ctx), msg) +} + +var _ resolver.Service = (*Service)(nil) + +func NewService(resolver *Resolver, mapper *ResolverEnhancer) *Service { + return &Service{handler: newHandler(resolver, mapper)} +} diff --git a/hub/executor/executor.go b/hub/executor/executor.go index 9921973d..041e6fc3 100644 --- a/hub/executor/executor.go +++ b/hub/executor/executor.go @@ -240,10 +240,10 @@ func updateDNS(c *config.DNS, generalIPv6 bool) { if !c.Enable { resolver.DefaultResolver = nil resolver.DefaultHostMapper = nil - resolver.DefaultLocalServer = nil + resolver.DefaultService = nil resolver.ProxyServerHostResolver = nil resolver.DirectHostResolver = nil - dns.ReCreateServer("", nil, nil) + dns.ReCreateServer("", nil) return } @@ -273,9 +273,11 @@ func updateDNS(c *config.DNS, generalIPv6 bool) { m.PatchFrom(old.(*dns.ResolverEnhancer)) } + s := dns.NewService(r.Resolver, m) + resolver.DefaultResolver = r resolver.DefaultHostMapper = m - resolver.DefaultLocalServer = dns.NewLocalServer(r.Resolver, m) + resolver.DefaultService = s resolver.UseSystemHosts = c.UseSystemHosts if r.ProxyResolver.Invalid() { @@ -290,7 +292,7 @@ func updateDNS(c *config.DNS, generalIPv6 bool) { resolver.DirectHostResolver = r.Resolver } - dns.ReCreateServer(c.Listen, r.Resolver, m) + dns.ReCreateServer(c.Listen, s) } func updateHosts(tree *trie.DomainTrie[resolver.HostValue]) {