middleware.go 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. package dns
  2. import (
  3. "net/netip"
  4. "strings"
  5. "time"
  6. "github.com/metacubex/mihomo/common/lru"
  7. "github.com/metacubex/mihomo/common/nnip"
  8. "github.com/metacubex/mihomo/component/fakeip"
  9. R "github.com/metacubex/mihomo/component/resolver"
  10. C "github.com/metacubex/mihomo/constant"
  11. "github.com/metacubex/mihomo/context"
  12. "github.com/metacubex/mihomo/log"
  13. D "github.com/miekg/dns"
  14. )
  15. type (
  16. handler func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error)
  17. middleware func(next handler) handler
  18. )
  19. func withHosts(hosts R.Hosts, mapping *lru.LruCache[netip.Addr, string]) middleware {
  20. return func(next handler) handler {
  21. return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) {
  22. q := r.Question[0]
  23. if !isIPRequest(q) {
  24. return next(ctx, r)
  25. }
  26. host := strings.TrimRight(q.Name, ".")
  27. handleCName := func(resp *D.Msg, domain string) {
  28. rr := &D.CNAME{}
  29. rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeCNAME, Class: D.ClassINET, Ttl: 10}
  30. rr.Target = domain + "."
  31. resp.Answer = append([]D.RR{rr}, resp.Answer...)
  32. }
  33. record, ok := hosts.Search(host, q.Qtype != D.TypeA && q.Qtype != D.TypeAAAA)
  34. if !ok {
  35. if record != nil && record.IsDomain {
  36. // replace request domain
  37. newR := r.Copy()
  38. newR.Question[0].Name = record.Domain + "."
  39. resp, err := next(ctx, newR)
  40. if err == nil {
  41. resp.Id = r.Id
  42. resp.Question = r.Question
  43. handleCName(resp, record.Domain)
  44. }
  45. return resp, err
  46. }
  47. return next(ctx, r)
  48. }
  49. msg := r.Copy()
  50. handleIPs := func() {
  51. for _, ipAddr := range record.IPs {
  52. if ipAddr.Is4() && q.Qtype == D.TypeA {
  53. rr := &D.A{}
  54. rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: 10}
  55. rr.A = ipAddr.AsSlice()
  56. msg.Answer = append(msg.Answer, rr)
  57. if mapping != nil {
  58. mapping.SetWithExpire(ipAddr, host, time.Now().Add(time.Second*10))
  59. }
  60. } else if q.Qtype == D.TypeAAAA {
  61. rr := &D.AAAA{}
  62. rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeAAAA, Class: D.ClassINET, Ttl: 10}
  63. ip := ipAddr.As16()
  64. rr.AAAA = ip[:]
  65. msg.Answer = append(msg.Answer, rr)
  66. if mapping != nil {
  67. mapping.SetWithExpire(ipAddr, host, time.Now().Add(time.Second*10))
  68. }
  69. }
  70. }
  71. }
  72. switch q.Qtype {
  73. case D.TypeA:
  74. handleIPs()
  75. case D.TypeAAAA:
  76. handleIPs()
  77. case D.TypeCNAME:
  78. handleCName(r, record.Domain)
  79. default:
  80. return next(ctx, r)
  81. }
  82. ctx.SetType(context.DNSTypeHost)
  83. msg.SetRcode(r, D.RcodeSuccess)
  84. msg.Authoritative = true
  85. msg.RecursionAvailable = true
  86. return msg, nil
  87. }
  88. }
  89. }
  90. func withMapping(mapping *lru.LruCache[netip.Addr, string]) middleware {
  91. return func(next handler) handler {
  92. return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) {
  93. q := r.Question[0]
  94. if !isIPRequest(q) {
  95. return next(ctx, r)
  96. }
  97. msg, err := next(ctx, r)
  98. if err != nil {
  99. return nil, err
  100. }
  101. host := strings.TrimRight(q.Name, ".")
  102. for _, ans := range msg.Answer {
  103. var ip netip.Addr
  104. var ttl uint32
  105. switch a := ans.(type) {
  106. case *D.A:
  107. ip = nnip.IpToAddr(a.A)
  108. ttl = a.Hdr.Ttl
  109. case *D.AAAA:
  110. ip = nnip.IpToAddr(a.AAAA)
  111. ttl = a.Hdr.Ttl
  112. default:
  113. continue
  114. }
  115. if ttl < 1 {
  116. ttl = 1
  117. }
  118. mapping.SetWithExpire(ip, host, time.Now().Add(time.Second*time.Duration(ttl)))
  119. }
  120. return msg, nil
  121. }
  122. }
  123. }
  124. func withFakeIP(fakePool *fakeip.Pool) middleware {
  125. return func(next handler) handler {
  126. return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) {
  127. q := r.Question[0]
  128. host := strings.TrimRight(q.Name, ".")
  129. if fakePool.ShouldSkipped(host) {
  130. return next(ctx, r)
  131. }
  132. switch q.Qtype {
  133. case D.TypeAAAA, D.TypeSVCB, D.TypeHTTPS:
  134. return handleMsgWithEmptyAnswer(r), nil
  135. }
  136. if q.Qtype != D.TypeA {
  137. return next(ctx, r)
  138. }
  139. rr := &D.A{}
  140. rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: dnsDefaultTTL}
  141. ip := fakePool.Lookup(host)
  142. rr.A = ip.AsSlice()
  143. msg := r.Copy()
  144. msg.Answer = []D.RR{rr}
  145. ctx.SetType(context.DNSTypeFakeIP)
  146. setMsgTTL(msg, 1)
  147. msg.SetRcode(r, D.RcodeSuccess)
  148. msg.Authoritative = true
  149. msg.RecursionAvailable = true
  150. return msg, nil
  151. }
  152. }
  153. }
  154. func withResolver(resolver *Resolver) handler {
  155. return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) {
  156. ctx.SetType(context.DNSTypeRaw)
  157. q := r.Question[0]
  158. // return a empty AAAA msg when ipv6 disabled
  159. if !resolver.ipv6 && q.Qtype == D.TypeAAAA {
  160. return handleMsgWithEmptyAnswer(r), nil
  161. }
  162. msg, err := resolver.ExchangeContext(ctx, r)
  163. if err != nil {
  164. log.Debugln("[DNS Server] Exchange %s failed: %v", q.String(), err)
  165. return msg, err
  166. }
  167. msg.SetRcode(r, msg.Rcode)
  168. msg.Authoritative = true
  169. return msg, nil
  170. }
  171. }
  172. func compose(middlewares []middleware, endpoint handler) handler {
  173. length := len(middlewares)
  174. h := endpoint
  175. for i := length - 1; i >= 0; i-- {
  176. middleware := middlewares[i]
  177. h = middleware(h)
  178. }
  179. return h
  180. }
  181. func NewHandler(resolver *Resolver, mapper *ResolverEnhancer) handler {
  182. middlewares := []middleware{}
  183. if resolver.hosts != nil {
  184. middlewares = append(middlewares, withHosts(R.NewHosts(resolver.hosts), mapper.mapping))
  185. }
  186. if mapper.mode == C.DNSFakeIP {
  187. middlewares = append(middlewares, withFakeIP(mapper.fakePool))
  188. }
  189. if mapper.mode != C.DNSNormal {
  190. middlewares = append(middlewares, withMapping(mapper.mapping))
  191. }
  192. return compose(middlewares, withResolver(resolver))
  193. }