123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231 |
- package dns
- import (
- "net/netip"
- "strings"
- "time"
- "github.com/metacubex/mihomo/common/lru"
- "github.com/metacubex/mihomo/common/nnip"
- "github.com/metacubex/mihomo/component/fakeip"
- R "github.com/metacubex/mihomo/component/resolver"
- C "github.com/metacubex/mihomo/constant"
- "github.com/metacubex/mihomo/context"
- "github.com/metacubex/mihomo/log"
- D "github.com/miekg/dns"
- )
- type (
- handler func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error)
- middleware func(next handler) handler
- )
- func withHosts(hosts R.Hosts, mapping *lru.LruCache[netip.Addr, string]) middleware {
- return func(next handler) handler {
- return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) {
- q := r.Question[0]
- if !isIPRequest(q) {
- return next(ctx, r)
- }
- host := strings.TrimRight(q.Name, ".")
- handleCName := func(resp *D.Msg, domain string) {
- rr := &D.CNAME{}
- rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeCNAME, Class: D.ClassINET, Ttl: 10}
- rr.Target = domain + "."
- resp.Answer = append([]D.RR{rr}, resp.Answer...)
- }
- record, ok := hosts.Search(host, q.Qtype != D.TypeA && q.Qtype != D.TypeAAAA)
- if !ok {
- if record != nil && record.IsDomain {
- // replace request domain
- newR := r.Copy()
- newR.Question[0].Name = record.Domain + "."
- resp, err := next(ctx, newR)
- if err == nil {
- resp.Id = r.Id
- resp.Question = r.Question
- handleCName(resp, record.Domain)
- }
- return resp, err
- }
- return next(ctx, r)
- }
- msg := r.Copy()
- handleIPs := func() {
- for _, ipAddr := range record.IPs {
- if ipAddr.Is4() && q.Qtype == D.TypeA {
- rr := &D.A{}
- rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: 10}
- rr.A = ipAddr.AsSlice()
- msg.Answer = append(msg.Answer, rr)
- if mapping != nil {
- mapping.SetWithExpire(ipAddr, host, time.Now().Add(time.Second*10))
- }
- } else if q.Qtype == D.TypeAAAA {
- rr := &D.AAAA{}
- rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeAAAA, Class: D.ClassINET, Ttl: 10}
- ip := ipAddr.As16()
- rr.AAAA = ip[:]
- msg.Answer = append(msg.Answer, rr)
- if mapping != nil {
- mapping.SetWithExpire(ipAddr, host, time.Now().Add(time.Second*10))
- }
- }
- }
- }
- switch q.Qtype {
- case D.TypeA:
- handleIPs()
- case D.TypeAAAA:
- handleIPs()
- case D.TypeCNAME:
- handleCName(r, record.Domain)
- default:
- return next(ctx, r)
- }
- ctx.SetType(context.DNSTypeHost)
- msg.SetRcode(r, D.RcodeSuccess)
- msg.Authoritative = true
- msg.RecursionAvailable = true
- return msg, nil
- }
- }
- }
- func withMapping(mapping *lru.LruCache[netip.Addr, string]) middleware {
- return func(next handler) handler {
- return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) {
- q := r.Question[0]
- if !isIPRequest(q) {
- return next(ctx, r)
- }
- msg, err := next(ctx, r)
- if err != nil {
- return nil, err
- }
- host := strings.TrimRight(q.Name, ".")
- for _, ans := range msg.Answer {
- var ip netip.Addr
- var ttl uint32
- switch a := ans.(type) {
- case *D.A:
- ip = nnip.IpToAddr(a.A)
- ttl = a.Hdr.Ttl
- case *D.AAAA:
- ip = nnip.IpToAddr(a.AAAA)
- ttl = a.Hdr.Ttl
- default:
- continue
- }
- if ttl < 1 {
- ttl = 1
- }
- mapping.SetWithExpire(ip, host, time.Now().Add(time.Second*time.Duration(ttl)))
- }
- return msg, nil
- }
- }
- }
- func withFakeIP(fakePool *fakeip.Pool) middleware {
- return func(next handler) handler {
- return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) {
- q := r.Question[0]
- host := strings.TrimRight(q.Name, ".")
- if fakePool.ShouldSkipped(host) {
- return next(ctx, r)
- }
- switch q.Qtype {
- case D.TypeAAAA, D.TypeSVCB, D.TypeHTTPS:
- return handleMsgWithEmptyAnswer(r), nil
- }
- if q.Qtype != D.TypeA {
- return next(ctx, r)
- }
- rr := &D.A{}
- rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: dnsDefaultTTL}
- ip := fakePool.Lookup(host)
- rr.A = ip.AsSlice()
- msg := r.Copy()
- msg.Answer = []D.RR{rr}
- ctx.SetType(context.DNSTypeFakeIP)
- setMsgTTL(msg, 1)
- msg.SetRcode(r, D.RcodeSuccess)
- msg.Authoritative = true
- msg.RecursionAvailable = true
- return msg, nil
- }
- }
- }
- func withResolver(resolver *Resolver) handler {
- return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) {
- ctx.SetType(context.DNSTypeRaw)
- q := r.Question[0]
- // return a empty AAAA msg when ipv6 disabled
- if !resolver.ipv6 && q.Qtype == D.TypeAAAA {
- return handleMsgWithEmptyAnswer(r), nil
- }
- msg, err := resolver.ExchangeContext(ctx, r)
- if err != nil {
- log.Debugln("[DNS Server] Exchange %s failed: %v", q.String(), err)
- return msg, err
- }
- msg.SetRcode(r, msg.Rcode)
- msg.Authoritative = true
- return msg, nil
- }
- }
- func compose(middlewares []middleware, endpoint handler) handler {
- length := len(middlewares)
- h := endpoint
- for i := length - 1; i >= 0; i-- {
- middleware := middlewares[i]
- h = middleware(h)
- }
- return h
- }
- func NewHandler(resolver *Resolver, mapper *ResolverEnhancer) handler {
- middlewares := []middleware{}
- if resolver.hosts != nil {
- middlewares = append(middlewares, withHosts(R.NewHosts(resolver.hosts), mapper.mapping))
- }
- if mapper.mode == C.DNSFakeIP {
- middlewares = append(middlewares, withFakeIP(mapper.fakePool))
- }
- if mapper.mode != C.DNSNormal {
- middlewares = append(middlewares, withMapping(mapper.mapping))
- }
- return compose(middlewares, withResolver(resolver))
- }
|