resolver.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584
  1. package dns
  2. import (
  3. "context"
  4. "errors"
  5. "net/netip"
  6. "strings"
  7. "time"
  8. "github.com/metacubex/mihomo/common/arc"
  9. "github.com/metacubex/mihomo/common/lru"
  10. "github.com/metacubex/mihomo/component/fakeip"
  11. "github.com/metacubex/mihomo/component/geodata/router"
  12. "github.com/metacubex/mihomo/component/resolver"
  13. "github.com/metacubex/mihomo/component/trie"
  14. C "github.com/metacubex/mihomo/constant"
  15. "github.com/metacubex/mihomo/constant/provider"
  16. "github.com/metacubex/mihomo/log"
  17. D "github.com/miekg/dns"
  18. "github.com/samber/lo"
  19. orderedmap "github.com/wk8/go-ordered-map/v2"
  20. "golang.org/x/exp/maps"
  21. "golang.org/x/sync/singleflight"
  22. )
  23. type dnsClient interface {
  24. ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err error)
  25. Address() string
  26. }
  27. type dnsCache interface {
  28. GetWithExpire(key string) (*D.Msg, time.Time, bool)
  29. SetWithExpire(key string, value *D.Msg, expire time.Time)
  30. }
  31. type result struct {
  32. Msg *D.Msg
  33. Error error
  34. }
  35. type Resolver struct {
  36. ipv6 bool
  37. ipv6Timeout time.Duration
  38. hosts *trie.DomainTrie[resolver.HostValue]
  39. main []dnsClient
  40. fallback []dnsClient
  41. fallbackDomainFilters []fallbackDomainFilter
  42. fallbackIPFilters []fallbackIPFilter
  43. group singleflight.Group
  44. cache dnsCache
  45. policy []dnsPolicy
  46. proxyServer []dnsClient
  47. }
  48. func (r *Resolver) LookupIPPrimaryIPv4(ctx context.Context, host string) (ips []netip.Addr, err error) {
  49. ch := make(chan []netip.Addr, 1)
  50. go func() {
  51. defer close(ch)
  52. ip, err := r.lookupIP(ctx, host, D.TypeAAAA)
  53. if err != nil {
  54. return
  55. }
  56. ch <- ip
  57. }()
  58. ips, err = r.lookupIP(ctx, host, D.TypeA)
  59. if err == nil {
  60. return
  61. }
  62. ip, open := <-ch
  63. if !open {
  64. return nil, resolver.ErrIPNotFound
  65. }
  66. return ip, nil
  67. }
  68. func (r *Resolver) LookupIP(ctx context.Context, host string) (ips []netip.Addr, err error) {
  69. ch := make(chan []netip.Addr, 1)
  70. go func() {
  71. defer close(ch)
  72. ip, err := r.lookupIP(ctx, host, D.TypeAAAA)
  73. if err != nil {
  74. return
  75. }
  76. ch <- ip
  77. }()
  78. ips, err = r.lookupIP(ctx, host, D.TypeA)
  79. var waitIPv6 *time.Timer
  80. if r != nil && r.ipv6Timeout > 0 {
  81. waitIPv6 = time.NewTimer(r.ipv6Timeout)
  82. } else {
  83. waitIPv6 = time.NewTimer(100 * time.Millisecond)
  84. }
  85. defer waitIPv6.Stop()
  86. select {
  87. case ipv6s, open := <-ch:
  88. if !open && err != nil {
  89. return nil, resolver.ErrIPNotFound
  90. }
  91. ips = append(ips, ipv6s...)
  92. case <-waitIPv6.C:
  93. // wait ipv6 result
  94. }
  95. return ips, nil
  96. }
  97. // LookupIPv4 request with TypeA
  98. func (r *Resolver) LookupIPv4(ctx context.Context, host string) ([]netip.Addr, error) {
  99. return r.lookupIP(ctx, host, D.TypeA)
  100. }
  101. // LookupIPv6 request with TypeAAAA
  102. func (r *Resolver) LookupIPv6(ctx context.Context, host string) ([]netip.Addr, error) {
  103. return r.lookupIP(ctx, host, D.TypeAAAA)
  104. }
  105. func (r *Resolver) shouldIPFallback(ip netip.Addr) bool {
  106. for _, filter := range r.fallbackIPFilters {
  107. if filter.Match(ip) {
  108. return true
  109. }
  110. }
  111. return false
  112. }
  113. // ExchangeContext a batch of dns request with context.Context, and it use cache
  114. func (r *Resolver) ExchangeContext(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) {
  115. if len(m.Question) == 0 {
  116. return nil, errors.New("should have one question at least")
  117. }
  118. continueFetch := false
  119. defer func() {
  120. if continueFetch || errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled) {
  121. go func() {
  122. ctx, cancel := context.WithTimeout(context.Background(), resolver.DefaultDNSTimeout)
  123. defer cancel()
  124. _, _ = r.exchangeWithoutCache(ctx, m) // ignore result, just for putMsgToCache
  125. }()
  126. }
  127. }()
  128. q := m.Question[0]
  129. domain := msgToDomain(m)
  130. _, qTypeStr := msgToQtype(m)
  131. cacheM, expireTime, hit := r.cache.GetWithExpire(q.String())
  132. if hit {
  133. ips := msgToIP(cacheM)
  134. log.Debugln("[DNS] cache hit %s --> %s %s, expire at %s", domain, ips, qTypeStr, expireTime.Format("2006-01-02 15:04:05"))
  135. now := time.Now()
  136. msg = cacheM.Copy()
  137. if expireTime.Before(now) {
  138. setMsgTTL(msg, uint32(1)) // Continue fetch
  139. continueFetch = true
  140. } else {
  141. // updating TTL by subtracting common delta time from each DNS record
  142. updateMsgTTL(msg, uint32(time.Until(expireTime).Seconds()))
  143. }
  144. return
  145. }
  146. return r.exchangeWithoutCache(ctx, m)
  147. }
  148. // ExchangeWithoutCache a batch of dns request, and it do NOT GET from cache
  149. func (r *Resolver) exchangeWithoutCache(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) {
  150. q := m.Question[0]
  151. retryNum := 0
  152. retryMax := 3
  153. fn := func() (result any, err error) {
  154. ctx, cancel := context.WithTimeout(context.Background(), resolver.DefaultDNSTimeout) // reset timeout in singleflight
  155. defer cancel()
  156. cache := false
  157. defer func() {
  158. if err != nil {
  159. result = retryNum
  160. retryNum++
  161. return
  162. }
  163. msg := result.(*D.Msg)
  164. if cache {
  165. // OPT RRs MUST NOT be cached, forwarded, or stored in or loaded from master files.
  166. msg.Extra = lo.Filter(msg.Extra, func(rr D.RR, index int) bool {
  167. return rr.Header().Rrtype != D.TypeOPT
  168. })
  169. putMsgToCache(r.cache, q.String(), q, msg)
  170. }
  171. }()
  172. isIPReq := isIPRequest(q)
  173. if isIPReq {
  174. cache = true
  175. return r.ipExchange(ctx, m)
  176. }
  177. if matched := r.matchPolicy(m); len(matched) != 0 {
  178. result, cache, err = batchExchange(ctx, matched, m)
  179. return
  180. }
  181. result, cache, err = batchExchange(ctx, r.main, m)
  182. return
  183. }
  184. ch := r.group.DoChan(q.String(), fn)
  185. var result singleflight.Result
  186. select {
  187. case result = <-ch:
  188. break
  189. case <-ctx.Done():
  190. select {
  191. case result = <-ch: // maybe ctxDone and chFinish in same time, get DoChan's result as much as possible
  192. break
  193. default:
  194. go func() { // start a retrying monitor in background
  195. result := <-ch
  196. ret, err, shared := result.Val, result.Err, result.Shared
  197. if err != nil && !shared && ret.(int) < retryMax { // retry
  198. r.group.DoChan(q.String(), fn)
  199. }
  200. }()
  201. return nil, ctx.Err()
  202. }
  203. }
  204. ret, err, shared := result.Val, result.Err, result.Shared
  205. if err != nil && !shared && ret.(int) < retryMax { // retry
  206. r.group.DoChan(q.String(), fn)
  207. }
  208. if err == nil {
  209. msg = ret.(*D.Msg)
  210. if shared {
  211. msg = msg.Copy()
  212. }
  213. }
  214. return
  215. }
  216. func (r *Resolver) matchPolicy(m *D.Msg) []dnsClient {
  217. if r.policy == nil {
  218. return nil
  219. }
  220. domain := msgToDomain(m)
  221. if domain == "" {
  222. return nil
  223. }
  224. for _, policy := range r.policy {
  225. if dnsClients := policy.Match(domain); len(dnsClients) > 0 {
  226. return dnsClients
  227. }
  228. }
  229. return nil
  230. }
  231. func (r *Resolver) shouldOnlyQueryFallback(m *D.Msg) bool {
  232. if r.fallback == nil || len(r.fallbackDomainFilters) == 0 {
  233. return false
  234. }
  235. domain := msgToDomain(m)
  236. if domain == "" {
  237. return false
  238. }
  239. for _, df := range r.fallbackDomainFilters {
  240. if df.Match(domain) {
  241. return true
  242. }
  243. }
  244. return false
  245. }
  246. func (r *Resolver) ipExchange(ctx context.Context, m *D.Msg) (msg *D.Msg, err error) {
  247. if matched := r.matchPolicy(m); len(matched) != 0 {
  248. res := <-r.asyncExchange(ctx, matched, m)
  249. return res.Msg, res.Error
  250. }
  251. onlyFallback := r.shouldOnlyQueryFallback(m)
  252. if onlyFallback {
  253. res := <-r.asyncExchange(ctx, r.fallback, m)
  254. return res.Msg, res.Error
  255. }
  256. msgCh := r.asyncExchange(ctx, r.main, m)
  257. if r.fallback == nil || len(r.fallback) == 0 { // directly return if no fallback servers are available
  258. res := <-msgCh
  259. msg, err = res.Msg, res.Error
  260. return
  261. }
  262. res := <-msgCh
  263. if res.Error == nil {
  264. if ips := msgToIP(res.Msg); len(ips) != 0 {
  265. shouldNotFallback := lo.EveryBy(ips, func(ip netip.Addr) bool {
  266. return !r.shouldIPFallback(ip)
  267. })
  268. if shouldNotFallback {
  269. msg, err = res.Msg, res.Error // no need to wait for fallback result
  270. return
  271. }
  272. }
  273. }
  274. res = <-r.asyncExchange(ctx, r.fallback, m)
  275. msg, err = res.Msg, res.Error
  276. return
  277. }
  278. func (r *Resolver) lookupIP(ctx context.Context, host string, dnsType uint16) (ips []netip.Addr, err error) {
  279. ip, err := netip.ParseAddr(host)
  280. if err == nil {
  281. isIPv4 := ip.Is4() || ip.Is4In6()
  282. if dnsType == D.TypeAAAA && !isIPv4 {
  283. return []netip.Addr{ip}, nil
  284. } else if dnsType == D.TypeA && isIPv4 {
  285. return []netip.Addr{ip}, nil
  286. } else {
  287. return []netip.Addr{}, resolver.ErrIPVersion
  288. }
  289. }
  290. query := &D.Msg{}
  291. query.SetQuestion(D.Fqdn(host), dnsType)
  292. msg, err := r.ExchangeContext(ctx, query)
  293. if err != nil {
  294. return []netip.Addr{}, err
  295. }
  296. ips = msgToIP(msg)
  297. ipLength := len(ips)
  298. if ipLength == 0 {
  299. return []netip.Addr{}, resolver.ErrIPNotFound
  300. }
  301. return
  302. }
  303. func (r *Resolver) asyncExchange(ctx context.Context, client []dnsClient, msg *D.Msg) <-chan *result {
  304. ch := make(chan *result, 1)
  305. go func() {
  306. res, _, err := batchExchange(ctx, client, msg)
  307. ch <- &result{Msg: res, Error: err}
  308. }()
  309. return ch
  310. }
  311. // Invalid return this resolver can or can't be used
  312. func (r *Resolver) Invalid() bool {
  313. if r == nil {
  314. return false
  315. }
  316. return len(r.main) > 0
  317. }
  318. type NameServer struct {
  319. Net string
  320. Addr string
  321. Interface string
  322. ProxyAdapter C.ProxyAdapter
  323. ProxyName string
  324. Params map[string]string
  325. PreferH3 bool
  326. }
  327. func (ns NameServer) Equal(ns2 NameServer) bool {
  328. defer func() {
  329. // C.ProxyAdapter compare maybe panic, just ignore
  330. recover()
  331. }()
  332. if ns.Net == ns2.Net &&
  333. ns.Addr == ns2.Addr &&
  334. ns.Interface == ns2.Interface &&
  335. ns.ProxyAdapter == ns2.ProxyAdapter &&
  336. ns.ProxyName == ns2.ProxyName &&
  337. maps.Equal(ns.Params, ns2.Params) &&
  338. ns.PreferH3 == ns2.PreferH3 {
  339. return true
  340. }
  341. return false
  342. }
  343. type FallbackFilter struct {
  344. GeoIP bool
  345. GeoIPCode string
  346. IPCIDR []netip.Prefix
  347. Domain []string
  348. GeoSite []router.DomainMatcher
  349. }
  350. type Config struct {
  351. Main, Fallback []NameServer
  352. Default []NameServer
  353. ProxyServer []NameServer
  354. IPv6 bool
  355. IPv6Timeout uint
  356. EnhancedMode C.DNSMode
  357. FallbackFilter FallbackFilter
  358. Pool *fakeip.Pool
  359. Hosts *trie.DomainTrie[resolver.HostValue]
  360. Policy *orderedmap.OrderedMap[string, []NameServer]
  361. Tunnel provider.Tunnel
  362. CacheAlgorithm string
  363. }
  364. func NewResolver(config Config) *Resolver {
  365. var cache dnsCache
  366. if config.CacheAlgorithm == "lru" {
  367. cache = lru.New(lru.WithSize[string, *D.Msg](4096), lru.WithStale[string, *D.Msg](true))
  368. } else {
  369. cache = arc.New(arc.WithSize[string, *D.Msg](4096))
  370. }
  371. defaultResolver := &Resolver{
  372. main: transform(config.Default, nil),
  373. cache: cache,
  374. ipv6Timeout: time.Duration(config.IPv6Timeout) * time.Millisecond,
  375. }
  376. var nameServerCache []struct {
  377. NameServer
  378. dnsClient
  379. }
  380. cacheTransform := func(nameserver []NameServer) (result []dnsClient) {
  381. LOOP:
  382. for _, ns := range nameserver {
  383. for _, nsc := range nameServerCache {
  384. if nsc.NameServer.Equal(ns) {
  385. result = append(result, nsc.dnsClient)
  386. continue LOOP
  387. }
  388. }
  389. // not in cache
  390. dc := transform([]NameServer{ns}, defaultResolver)
  391. if len(dc) > 0 {
  392. dc := dc[0]
  393. nameServerCache = append(nameServerCache, struct {
  394. NameServer
  395. dnsClient
  396. }{NameServer: ns, dnsClient: dc})
  397. result = append(result, dc)
  398. }
  399. }
  400. return
  401. }
  402. if config.CacheAlgorithm == "" || config.CacheAlgorithm == "lru" {
  403. cache = lru.New(lru.WithSize[string, *D.Msg](4096), lru.WithStale[string, *D.Msg](true))
  404. } else {
  405. cache = arc.New(arc.WithSize[string, *D.Msg](4096))
  406. }
  407. r := &Resolver{
  408. ipv6: config.IPv6,
  409. main: cacheTransform(config.Main),
  410. cache: cache,
  411. hosts: config.Hosts,
  412. ipv6Timeout: time.Duration(config.IPv6Timeout) * time.Millisecond,
  413. }
  414. if len(config.Fallback) != 0 {
  415. r.fallback = cacheTransform(config.Fallback)
  416. }
  417. if len(config.ProxyServer) != 0 {
  418. r.proxyServer = cacheTransform(config.ProxyServer)
  419. }
  420. if config.Policy.Len() != 0 {
  421. r.policy = make([]dnsPolicy, 0)
  422. var triePolicy *trie.DomainTrie[[]dnsClient]
  423. insertPolicy := func(policy dnsPolicy) {
  424. if triePolicy != nil {
  425. triePolicy.Optimize()
  426. r.policy = append(r.policy, domainTriePolicy{triePolicy})
  427. triePolicy = nil
  428. }
  429. if policy != nil {
  430. r.policy = append(r.policy, policy)
  431. }
  432. }
  433. for pair := config.Policy.Oldest(); pair != nil; pair = pair.Next() {
  434. domain, nameserver := pair.Key, pair.Value
  435. if temp := strings.Split(domain, ":"); len(temp) == 2 {
  436. prefix := temp[0]
  437. key := temp[1]
  438. switch prefix {
  439. case "rule-set":
  440. if _, ok := config.Tunnel.RuleProviders()[key]; ok {
  441. log.Debugln("Adding rule-set policy: %s ", key)
  442. insertPolicy(domainSetPolicy{
  443. tunnel: config.Tunnel,
  444. name: key,
  445. dnsClients: cacheTransform(nameserver),
  446. })
  447. continue
  448. } else {
  449. log.Warnln("Can't found ruleset policy: %s", key)
  450. }
  451. case "geosite":
  452. inverse := false
  453. if strings.HasPrefix(key, "!") {
  454. inverse = true
  455. key = key[1:]
  456. }
  457. log.Debugln("Adding geosite policy: %s inversed %t", key, inverse)
  458. matcher, err := NewGeoSite(key)
  459. if err != nil {
  460. log.Warnln("adding geosite policy %s error: %s", key, err)
  461. continue
  462. }
  463. insertPolicy(geositePolicy{
  464. matcher: matcher,
  465. inverse: inverse,
  466. dnsClients: cacheTransform(nameserver),
  467. })
  468. continue // skip triePolicy new
  469. }
  470. }
  471. if triePolicy == nil {
  472. triePolicy = trie.New[[]dnsClient]()
  473. }
  474. _ = triePolicy.Insert(domain, cacheTransform(nameserver))
  475. }
  476. insertPolicy(nil)
  477. }
  478. fallbackIPFilters := []fallbackIPFilter{}
  479. if config.FallbackFilter.GeoIP {
  480. fallbackIPFilters = append(fallbackIPFilters, &geoipFilter{
  481. code: config.FallbackFilter.GeoIPCode,
  482. })
  483. }
  484. for _, ipnet := range config.FallbackFilter.IPCIDR {
  485. fallbackIPFilters = append(fallbackIPFilters, &ipnetFilter{ipnet: ipnet})
  486. }
  487. r.fallbackIPFilters = fallbackIPFilters
  488. fallbackDomainFilters := []fallbackDomainFilter{}
  489. if len(config.FallbackFilter.Domain) != 0 {
  490. fallbackDomainFilters = append(fallbackDomainFilters, NewDomainFilter(config.FallbackFilter.Domain))
  491. }
  492. if len(config.FallbackFilter.GeoSite) != 0 {
  493. fallbackDomainFilters = append(fallbackDomainFilters, &geoSiteFilter{
  494. matchers: config.FallbackFilter.GeoSite,
  495. })
  496. }
  497. r.fallbackDomainFilters = fallbackDomainFilters
  498. return r
  499. }
  500. func NewProxyServerHostResolver(old *Resolver) *Resolver {
  501. r := &Resolver{
  502. ipv6: old.ipv6,
  503. main: old.proxyServer,
  504. cache: old.cache,
  505. hosts: old.hosts,
  506. ipv6Timeout: old.ipv6Timeout,
  507. }
  508. return r
  509. }
  510. var ParseNameServer func(servers []string) ([]NameServer, error) // define in config/config.go