123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236 |
- package dns
import (
	"context"
	"crypto/tls"
	"errors"
	"fmt"
	"net"
	"net/netip"
	"strings"
	"sync/atomic"
	"time"

	"github.com/metacubex/mihomo/common/nnip"
	"github.com/metacubex/mihomo/common/picker"
	"github.com/metacubex/mihomo/component/dialer"
	"github.com/metacubex/mihomo/component/resolver"
	"github.com/metacubex/mihomo/log"

	D "github.com/miekg/dns"
	"github.com/samber/lo"
)
const (
	// MaxMsgSize is 65535 bytes, the largest value representable in 16 bits.
	// NOTE(review): presumably used as an upper bound for DNS message buffers;
	// its users are not visible in this file section.
	MaxMsgSize = 65535

	// serverFailureCacheTTL is the lifetime, in seconds, given to cached
	// SERVFAIL responses by putMsgToCache.
	serverFailureCacheTTL uint32 = 5
)
- func minimalTTL(records []D.RR) uint32 {
- rr := lo.MinBy(records, func(r1 D.RR, r2 D.RR) bool {
- return r1.Header().Ttl < r2.Header().Ttl
- })
- if rr == nil {
- return 0
- }
- return rr.Header().Ttl
- }
- func updateTTL(records []D.RR, ttl uint32) {
- if len(records) == 0 {
- return
- }
- delta := minimalTTL(records) - ttl
- for i := range records {
- records[i].Header().Ttl = lo.Clamp(records[i].Header().Ttl-delta, 1, records[i].Header().Ttl)
- }
- }
- func putMsgToCache(c dnsCache, key string, q D.Question, msg *D.Msg) {
- // skip dns cache for acme challenge
- if q.Qtype == D.TypeTXT && strings.HasPrefix(q.Name, "_acme-challenge.") {
- log.Debugln("[DNS] dns cache ignored because of acme challenge for: %s", q.Name)
- return
- }
- var ttl uint32
- if msg.Rcode == D.RcodeServerFailure {
- // [...] a resolver MAY cache a server failure response.
- // If it does so it MUST NOT cache it for longer than five (5) minutes [...]
- ttl = serverFailureCacheTTL
- } else {
- ttl = minimalTTL(append(append(msg.Answer, msg.Ns...), msg.Extra...))
- }
- if ttl == 0 {
- return
- }
- c.SetWithExpire(key, msg.Copy(), time.Now().Add(time.Duration(ttl)*time.Second))
- }
- func setMsgTTL(msg *D.Msg, ttl uint32) {
- for _, answer := range msg.Answer {
- answer.Header().Ttl = ttl
- }
- for _, ns := range msg.Ns {
- ns.Header().Ttl = ttl
- }
- for _, extra := range msg.Extra {
- extra.Header().Ttl = ttl
- }
- }
- func updateMsgTTL(msg *D.Msg, ttl uint32) {
- updateTTL(msg.Answer, ttl)
- updateTTL(msg.Ns, ttl)
- updateTTL(msg.Extra, ttl)
- }
- func isIPRequest(q D.Question) bool {
- return q.Qclass == D.ClassINET && (q.Qtype == D.TypeA || q.Qtype == D.TypeAAAA || q.Qtype == D.TypeCNAME)
- }
- func transform(servers []NameServer, resolver *Resolver) []dnsClient {
- ret := make([]dnsClient, 0, len(servers))
- for _, s := range servers {
- switch s.Net {
- case "https":
- ret = append(ret, newDoHClient(s.Addr, resolver, s.PreferH3, s.Params, s.ProxyAdapter, s.ProxyName))
- continue
- case "dhcp":
- ret = append(ret, newDHCPClient(s.Addr))
- continue
- case "system":
- ret = append(ret, newSystemClient())
- continue
- case "rcode":
- ret = append(ret, newRCodeClient(s.Addr))
- continue
- case "quic":
- if doq, err := newDoQ(resolver, s.Addr, s.ProxyAdapter, s.ProxyName); err == nil {
- ret = append(ret, doq)
- } else {
- log.Fatalln("DoQ format error: %v", err)
- }
- continue
- }
- var options []dialer.Option
- if s.Interface != "" {
- options = append(options, dialer.WithInterface(s.Interface))
- }
- host, port, _ := net.SplitHostPort(s.Addr)
- ret = append(ret, &client{
- Client: &D.Client{
- Net: s.Net,
- TLSConfig: &tls.Config{
- ServerName: host,
- },
- UDPSize: 4096,
- Timeout: 5 * time.Second,
- },
- port: port,
- host: host,
- dialer: newDNSDialer(resolver, s.ProxyAdapter, s.ProxyName, options...),
- })
- }
- return ret
- }
- func handleMsgWithEmptyAnswer(r *D.Msg) *D.Msg {
- msg := &D.Msg{}
- msg.Answer = []D.RR{}
- msg.SetRcode(r, D.RcodeSuccess)
- msg.Authoritative = true
- msg.RecursionAvailable = true
- return msg
- }
- func msgToIP(msg *D.Msg) []netip.Addr {
- ips := []netip.Addr{}
- for _, answer := range msg.Answer {
- switch ans := answer.(type) {
- case *D.AAAA:
- ips = append(ips, nnip.IpToAddr(ans.AAAA))
- case *D.A:
- ips = append(ips, nnip.IpToAddr(ans.A))
- }
- }
- return ips
- }
- func msgToDomain(msg *D.Msg) string {
- if len(msg.Question) > 0 {
- return strings.TrimRight(msg.Question[0].Name, ".")
- }
- return ""
- }
- func msgToQtype(msg *D.Msg) (uint16, string) {
- if len(msg.Question) > 0 {
- qType := msg.Question[0].Qtype
- return qType, D.Type(qType).String()
- }
- return 0, ""
- }
- func batchExchange(ctx context.Context, clients []dnsClient, m *D.Msg) (msg *D.Msg, cache bool, err error) {
- cache = true
- fast, ctx := picker.WithTimeout[*D.Msg](ctx, resolver.DefaultDNSTimeout)
- defer fast.Close()
- domain := msgToDomain(m)
- qType, qTypeStr := msgToQtype(m)
- var noIpMsg *D.Msg
- for _, client := range clients {
- if _, isRCodeClient := client.(rcodeClient); isRCodeClient {
- msg, err = client.ExchangeContext(ctx, m)
- return msg, false, err
- }
- client := client // shadow define client to ensure the value captured by the closure will not be changed in the next loop
- fast.Go(func() (*D.Msg, error) {
- log.Debugln("[DNS] resolve %s %s from %s", domain, qTypeStr, client.Address())
- m, err := client.ExchangeContext(ctx, m)
- if err != nil {
- return nil, err
- } else if cache && (m.Rcode == D.RcodeServerFailure || m.Rcode == D.RcodeRefused) {
- // currently, cache indicates whether this msg was from a RCode client,
- // so we would ignore RCode errors from RCode clients.
- return nil, errors.New("server failure: " + D.RcodeToString[m.Rcode])
- }
- ips := msgToIP(m)
- log.Debugln("[DNS] %s --> %s %s from %s", domain, ips, qTypeStr, client.Address())
- switch qType {
- case D.TypeAAAA:
- if len(ips) == 0 {
- noIpMsg = m
- return nil, resolver.ErrIPNotFound
- }
- case D.TypeA:
- if len(ips) == 0 {
- noIpMsg = m
- return nil, resolver.ErrIPNotFound
- }
- }
- return m, nil
- })
- }
- msg = fast.Wait()
- if msg == nil {
- if noIpMsg != nil {
- return noIpMsg, false, nil
- }
- err = errors.New("all DNS requests failed")
- if fErr := fast.Error(); fErr != nil {
- err = fmt.Errorf("%w, first error: %w", err, fErr)
- }
- }
- return
- }
|