From 85e6d25de5d801fe02551deedc741139df3d1b68 Mon Sep 17 00:00:00 2001 From: wwqgtxx Date: Fri, 6 Jun 2025 00:45:58 +0800 Subject: [PATCH] feat: all dns client support `ecs` and `ecs-override` params --- dns/doh.go | 33 ++------------------------------- dns/util.go | 53 +++++++++++++++++++++++++++++++++++++++++++++++------ 2 files changed, 49 insertions(+), 37 deletions(-) diff --git a/dns/doh.go b/dns/doh.go index c6f7c67b..7c804f26 100644 --- a/dns/doh.go +++ b/dns/doh.go @@ -9,7 +9,6 @@ import ( "io" "net" "net/http" - "net/netip" "net/url" "runtime" "strconv" @@ -71,8 +70,6 @@ type dnsOverHTTPS struct { dialer *dnsDialer addr string skipCertVerify bool - ecsPrefix netip.Prefix - ecsOverride bool } // type check @@ -105,28 +102,6 @@ func newDoHClient(urlString string, r *Resolver, preferH3 bool, params map[strin doh.skipCertVerify = true } - if ecs := params["ecs"]; ecs != "" { - prefix, err := netip.ParsePrefix(ecs) - if err != nil { - addr, err := netip.ParseAddr(ecs) - if err != nil { - log.Warnln("DOH [%s] config with invalid ecs: %s", doh.addr, ecs) - } else { - doh.ecsPrefix = netip.PrefixFrom(addr, addr.BitLen()) - } - } else { - doh.ecsPrefix = prefix - } - } - - if doh.ecsPrefix.IsValid() { - log.Debugln("DOH [%s] config with ecs: %s", doh.addr, doh.ecsPrefix) - } - - if params["ecs-override"] == "true" { - doh.ecsOverride = true - } - runtime.SetFinalizer(doh, (*dnsOverHTTPS).Close) return doh @@ -154,10 +129,6 @@ func (doh *dnsOverHTTPS) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D. } }() - if doh.ecsPrefix.IsValid() { - setEdns0Subnet(m, doh.ecsPrefix, doh.ecsOverride) - } - // Check if there was already an active client before sending the request. // We'll only attempt to re-connect if there was one. client, isCached, err := doh.getClient(ctx) @@ -552,8 +523,8 @@ func (doh *dnsOverHTTPS) createTransportH3( Dial: func( ctx context.Context, - // Ignore the address and always connect to the one that we got - // from the bootstrapper. + // Ignore the address and always connect to the one that we got + // from the bootstrapper. _ string, tlsCfg *tlsC.Config, cfg *quic.Config, diff --git a/dns/util.go b/dns/util.go index 7a5519be..97496965 100644 --- a/dns/util.go +++ b/dns/util.go @@ -106,12 +106,14 @@ func transform(servers []NameServer, resolver *Resolver) []dnsClient { c = newClient(s.Addr, resolver, s.Net, s.ProxyAdapter, s.ProxyName) } + c = warpClientWithEdns0Subnet(c, s.Params) + if s.Params["disable-ipv4"] == "true" { - c = newDisableTypeClient(c, D.TypeA) + c = warpClientWithDisableType(c, D.TypeA) } if s.Params["disable-ipv6"] == "true" { - c = newDisableTypeClient(c, D.TypeAAAA) + c = warpClientWithDisableType(c, D.TypeAAAA) } ret = append(ret, c) @@ -119,12 +121,12 @@ func transform(servers []NameServer, resolver *Resolver) []dnsClient { return ret } -type disableTypeClient struct { +type clientWithDisableType struct { dnsClient qType uint16 } -func (c disableTypeClient) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) { +func (c clientWithDisableType) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) { if len(m.Question) > 0 { q := m.Question[0] if q.Qtype == c.qType { @@ -134,8 +136,47 @@ func (c disableTypeClient) ExchangeContext(ctx context.Context, m *D.Msg) (msg * return c.dnsClient.ExchangeContext(ctx, m) } -func newDisableTypeClient(c dnsClient, qType uint16) dnsClient { - return disableTypeClient{c, qType} +func warpClientWithDisableType(c dnsClient, qType uint16) dnsClient { + return clientWithDisableType{c, qType} +} + +type clientWithEdns0Subnet struct { + dnsClient + ecsPrefix netip.Prefix + ecsOverride bool +} + +func (c clientWithEdns0Subnet) ExchangeContext(ctx context.Context, m *D.Msg) (*D.Msg, error) { + m = m.Copy() + setEdns0Subnet(m, c.ecsPrefix, c.ecsOverride) + return c.dnsClient.ExchangeContext(ctx, m) +} + +func warpClientWithEdns0Subnet(c dnsClient, params map[string]string) dnsClient { + var ecsPrefix netip.Prefix + var ecsOverride bool + if ecs := params["ecs"]; ecs != "" { + prefix, err := netip.ParsePrefix(ecs) + if err != nil { + addr, err := netip.ParseAddr(ecs) + if err != nil { + log.Warnln("DNS [%s] config with invalid ecs: %s", c.Address(), ecs) + } else { + ecsPrefix = netip.PrefixFrom(addr, addr.BitLen()) + } + } else { + ecsPrefix = prefix + } + } + + if ecsPrefix.IsValid() { + log.Debugln("DNS [%s] config with ecs: %s", c.Address(), ecsPrefix) + if params["ecs-override"] == "true" { + ecsOverride = true + } + return clientWithEdns0Subnet{c, ecsPrefix, ecsOverride} + } + return c } func handleMsgWithEmptyAnswer(r *D.Msg) *D.Msg {