diff --git a/dns/util.go b/dns/util.go index e960a8e1..d46cd65c 100644 --- a/dns/util.go +++ b/dns/util.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net/netip" + "strconv" "strings" "time" @@ -15,6 +16,7 @@ import ( D "github.com/miekg/dns" "github.com/samber/lo" + "golang.org/x/exp/slices" ) const ( @@ -108,37 +110,75 @@ func transform(servers []NameServer, resolver *Resolver) []dnsClient { } c = warpClientWithEdns0Subnet(c, s.Params) - - if s.Params["disable-ipv4"] == "true" { - c = warpClientWithDisableType(c, D.TypeA) - } - - if s.Params["disable-ipv6"] == "true" { - c = warpClientWithDisableType(c, D.TypeAAAA) - } + c = warpClientWithDisableTypes(c, s.Params) ret = append(ret, c) } return ret } -type clientWithDisableType struct { +type clientWithDisableTypes struct { dnsClient - qType uint16 + disableTypes map[uint16]struct{} } -func (c clientWithDisableType) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) { - if len(m.Question) > 0 { - q := m.Question[0] - if q.Qtype == c.qType { - return handleMsgWithEmptyAnswer(m), nil +func (c clientWithDisableTypes) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) { + // filter dns request + if slices.ContainsFunc(m.Question, c.inQuestion) { + // In fact, DNS requests are not allowed to contain multiple questions: + // https://stackoverflow.com/questions/4082081/requesting-a-and-aaaa-records-in-single-dns-query/4083071 + // so, when we find a question containing the type, we can simply discard the entire dns request. + return handleMsgWithEmptyAnswer(m), nil + } + + // do real exchange + msg, err = c.dnsClient.ExchangeContext(ctx, m) + if err != nil { + return + } + + // filter dns response + msg.Answer = slices.DeleteFunc(msg.Answer, c.inRR) + msg.Ns = slices.DeleteFunc(msg.Ns, c.inRR) + msg.Extra = slices.DeleteFunc(msg.Extra, c.inRR) + return +} + +func (c clientWithDisableTypes) inQuestion(q D.Question) bool { + _, ok := c.disableTypes[q.Qtype] + return ok +} + +func (c clientWithDisableTypes) inRR(rr D.RR) bool { + _, ok := c.disableTypes[rr.Header().Rrtype] + return ok +} + +func warpClientWithDisableTypes(c dnsClient, params map[string]string) dnsClient { + disableTypes := make(map[uint16]struct{}) + if params["disable-ipv4"] == "true" { + disableTypes[D.TypeA] = struct{}{} + } + if params["disable-ipv6"] == "true" { + disableTypes[D.TypeAAAA] = struct{}{} + } + for key, value := range params { + const prefix = "disable-qtype-" + if strings.HasPrefix(key, prefix) && value == "true" { // eg: disable-qtype-65=true + qType, err := strconv.ParseUint(key[len(prefix):], 10, 16) + if err != nil { + continue + } + if _, ok := D.TypeToRR[uint16(qType)]; !ok { // check valid RR_Header.Rrtype and Question.qtype + continue + } + disableTypes[uint16(qType)] = struct{}{} } } - return c.dnsClient.ExchangeContext(ctx, m) -} - -func warpClientWithDisableType(c dnsClient, qType uint16) dnsClient { - return clientWithDisableType{c, qType} + if len(disableTypes) > 0 { + return clientWithDisableTypes{c, disableTypes} + } + return c } type clientWithEdns0Subnet struct {