From 40b2cde2b21ee5d47f1ebd11b29d2fe724aa7545 Mon Sep 17 00:00:00 2001 From: wwqgtxx Date: Fri, 19 Sep 2025 21:07:43 +0800 Subject: [PATCH] chore: cleanup dns client code --- dns/client.go | 97 +++++++++++++++++++++++---------------------------- 1 file changed, 44 insertions(+), 53 deletions(-) diff --git a/dns/client.go b/dns/client.go index cc07f7ef..9cc7df2d 100644 --- a/dns/client.go +++ b/dns/client.go @@ -16,53 +16,49 @@ import ( ) type client struct { - *D.Client - port string - host string - dialer *dnsDialer - addr string + port string + host string + dialer *dnsDialer + schema string + skipCertVerify bool } var _ dnsClient = (*client)(nil) // Address implements dnsClient func (c *client) Address() string { - if len(c.addr) != 0 { - return c.addr - } - schema := "udp" - if strings.HasPrefix(c.Client.Net, "tcp") { - schema = "tcp" - if strings.HasSuffix(c.Client.Net, "tls") { - schema = "tls" - } - } - - c.addr = fmt.Sprintf("%s://%s", schema, net.JoinHostPort(c.host, c.port)) - return c.addr + return fmt.Sprintf("%s://%s", c.schema, net.JoinHostPort(c.host, c.port)) } func (c *client) ExchangeContext(ctx context.Context, m *D.Msg) (*D.Msg, error) { network := "udp" - if strings.HasPrefix(c.Client.Net, "tcp") { + if c.schema != "udp" { network = "tcp" } - tlsConfig, err := ca.GetTLSConfig(ca.Option{ - TLSConfig: c.Client.TLSConfig, - }) - if err != nil { - return nil, err - } - addr := net.JoinHostPort(c.host, c.port) conn, err := c.dialer.DialContext(ctx, network, addr) if err != nil { return nil, err } - defer func() { - _ = conn.Close() - }() + defer conn.Close() + + if c.schema == "tls" { + tlsConfig, err := ca.GetTLSConfig(ca.Option{ + TLSConfig: &tls.Config{ + ServerName: c.host, + InsecureSkipVerify: c.skipCertVerify, + }, + }) + if err != nil { + return nil, err + } + tlsConn := tls.Client(conn, tlsConfig) + if err := tlsConn.HandshakeContext(ctx); err != nil { + return nil, err + } + conn = tlsConn + } // miekg/dns ExchangeContext doesn't respond to context cancel. // this is a workaround @@ -72,34 +68,30 @@ func (c *client) ExchangeContext(ctx context.Context, m *D.Msg) (*D.Msg, error) } ch := make(chan result, 1) go func() { - if strings.HasSuffix(c.Client.Net, "tls") { - conn = tls.Client(conn, tlsConfig) + dClient := &D.Client{ + UDPSize: 4096, + Timeout: 5 * time.Second, } - dConn := &D.Conn{ - Conn: conn, - UDPSize: c.Client.UDPSize, - TsigSecret: c.Client.TsigSecret, - TsigProvider: c.Client.TsigProvider, + Conn: conn, + UDPSize: dClient.UDPSize, } - msg, _, err := c.Client.ExchangeWithConn(m, dConn) + msg, _, err := dClient.ExchangeWithConn(m, dConn) // Resolvers MUST resend queries over TCP if they receive a truncated UDP response (with TC=1 set)! if msg != nil && msg.Truncated && network == "udp" { - tcpClient := *c.Client // copy a client - tcpClient.Net = "tcp" network = "tcp" log.Debugln("[DNS] Truncated reply from %s:%s for %s over UDP, retrying over TCP", c.host, c.port, m.Question[0].String()) - dConn.Conn, err = c.dialer.DialContext(ctx, network, addr) + var tcpConn net.Conn + tcpConn, err = c.dialer.DialContext(ctx, network, addr) if err != nil { ch <- result{msg, err} return } - defer func() { - _ = conn.Close() - }() - msg, _, err = tcpClient.ExchangeWithConn(m, dConn) + defer tcpConn.Close() + dConn.Conn = tcpConn + msg, _, err = dClient.ExchangeWithConn(m, dConn) } ch <- result{msg, err} @@ -118,20 +110,19 @@ func (c *client) ResetConnection() {} func newClient(addr string, resolver *Resolver, netType string, params map[string]string, proxyAdapter C.ProxyAdapter, proxyName string) *client { host, port, _ := net.SplitHostPort(addr) c := &client{ - Client: &D.Client{ - Net: netType, - TLSConfig: &tls.Config{ - ServerName: host, - }, - UDPSize: 4096, - Timeout: 5 * time.Second, - }, port: port, host: host, dialer: newDNSDialer(resolver, proxyAdapter, proxyName), + schema: "udp", + } + if strings.HasPrefix(netType, "tcp") { + c.schema = "tcp" + if strings.HasSuffix(netType, "tls") { + c.schema = "tls" + } } if params["skip-cert-verify"] == "true" { - c.TLSConfig.InsecureSkipVerify = true + c.skipCertVerify = true } return c }