dispatcher.go 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. package sniffer
  2. import (
  3. "errors"
  4. "fmt"
  5. "net"
  6. "net/netip"
  7. "time"
  8. "github.com/metacubex/mihomo/common/lru"
  9. N "github.com/metacubex/mihomo/common/net"
  10. "github.com/metacubex/mihomo/component/trie"
  11. C "github.com/metacubex/mihomo/constant"
  12. "github.com/metacubex/mihomo/constant/sniffer"
  13. "github.com/metacubex/mihomo/log"
  14. )
  15. var (
  16. ErrorUnsupportedSniffer = errors.New("unsupported sniffer")
  17. ErrorSniffFailed = errors.New("all sniffer failed")
  18. ErrNoClue = errors.New("not enough information for making a decision")
  19. )
  20. var Dispatcher *SnifferDispatcher
  21. type SnifferDispatcher struct {
  22. enable bool
  23. sniffers map[sniffer.Sniffer]SnifferConfig
  24. forceDomain *trie.DomainSet
  25. skipSNI *trie.DomainSet
  26. skipList *lru.LruCache[string, uint8]
  27. forceDnsMapping bool
  28. parsePureIp bool
  29. }
  30. func (sd *SnifferDispatcher) shouldOverride(metadata *C.Metadata) bool {
  31. return (metadata.Host == "" && sd.parsePureIp) ||
  32. sd.forceDomain.Has(metadata.Host) ||
  33. (metadata.DNSMode == C.DNSMapping && sd.forceDnsMapping)
  34. }
  35. func (sd *SnifferDispatcher) UDPSniff(packet C.PacketAdapter) bool {
  36. metadata := packet.Metadata()
  37. if sd.shouldOverride(packet.Metadata()) {
  38. for sniffer, config := range sd.sniffers {
  39. if sniffer.SupportNetwork() == C.UDP || sniffer.SupportNetwork() == C.ALLNet {
  40. inWhitelist := sniffer.SupportPort(metadata.DstPort)
  41. overrideDest := config.OverrideDest
  42. if inWhitelist {
  43. host, err := sniffer.SniffData(packet.Data())
  44. if err != nil {
  45. continue
  46. }
  47. sd.replaceDomain(metadata, host, overrideDest)
  48. return true
  49. }
  50. }
  51. }
  52. }
  53. return false
  54. }
  55. // TCPSniff returns true if the connection is sniffed to have a domain
  56. func (sd *SnifferDispatcher) TCPSniff(conn *N.BufferedConn, metadata *C.Metadata) bool {
  57. if sd.shouldOverride(metadata) {
  58. inWhitelist := false
  59. overrideDest := false
  60. for sniffer, config := range sd.sniffers {
  61. if sniffer.SupportNetwork() == C.TCP || sniffer.SupportNetwork() == C.ALLNet {
  62. inWhitelist = sniffer.SupportPort(metadata.DstPort)
  63. if inWhitelist {
  64. overrideDest = config.OverrideDest
  65. break
  66. }
  67. }
  68. }
  69. if !inWhitelist {
  70. return false
  71. }
  72. dst := fmt.Sprintf("%s:%d", metadata.DstIP, metadata.DstPort)
  73. if count, ok := sd.skipList.Get(dst); ok && count > 5 {
  74. log.Debugln("[Sniffer] Skip sniffing[%s] due to multiple failures", dst)
  75. return false
  76. }
  77. if host, err := sd.sniffDomain(conn, metadata); err != nil {
  78. sd.cacheSniffFailed(metadata)
  79. log.Debugln("[Sniffer] All sniffing sniff failed with from [%s:%d] to [%s:%d]", metadata.SrcIP, metadata.SrcPort, metadata.String(), metadata.DstPort)
  80. return false
  81. } else {
  82. if sd.skipSNI.Has(host) {
  83. log.Debugln("[Sniffer] Skip sni[%s]", host)
  84. return false
  85. }
  86. sd.skipList.Delete(dst)
  87. sd.replaceDomain(metadata, host, overrideDest)
  88. return true
  89. }
  90. }
  91. return false
  92. }
  93. func (sd *SnifferDispatcher) replaceDomain(metadata *C.Metadata, host string, overrideDest bool) {
  94. metadata.SniffHost = host
  95. if overrideDest {
  96. log.Debugln("[Sniffer] Sniff %s [%s]-->[%s] success, replace domain [%s]-->[%s]",
  97. metadata.NetWork,
  98. metadata.SourceDetail(),
  99. metadata.RemoteAddress(),
  100. metadata.Host, host)
  101. metadata.Host = host
  102. }
  103. metadata.DNSMode = C.DNSNormal
  104. }
  105. func (sd *SnifferDispatcher) Enable() bool {
  106. return sd.enable
  107. }
  108. func (sd *SnifferDispatcher) sniffDomain(conn *N.BufferedConn, metadata *C.Metadata) (string, error) {
  109. for s := range sd.sniffers {
  110. if s.SupportNetwork() == C.TCP {
  111. _ = conn.SetReadDeadline(time.Now().Add(1 * time.Second))
  112. _, err := conn.Peek(1)
  113. _ = conn.SetReadDeadline(time.Time{})
  114. if err != nil {
  115. _, ok := err.(*net.OpError)
  116. if ok {
  117. sd.cacheSniffFailed(metadata)
  118. log.Errorln("[Sniffer] [%s] may not have any sent data, Consider adding skip", metadata.DstIP.String())
  119. _ = conn.Close()
  120. }
  121. return "", err
  122. }
  123. bufferedLen := conn.Buffered()
  124. bytes, err := conn.Peek(bufferedLen)
  125. if err != nil {
  126. log.Debugln("[Sniffer] the data length not enough")
  127. continue
  128. }
  129. host, err := s.SniffData(bytes)
  130. if err != nil {
  131. //log.Debugln("[Sniffer] [%s] Sniff data failed %s", s.Protocol(), metadata.DstIP)
  132. continue
  133. }
  134. _, err = netip.ParseAddr(host)
  135. if err == nil {
  136. //log.Debugln("[Sniffer] [%s] Sniff data failed %s", s.Protocol(), metadata.DstIP)
  137. continue
  138. }
  139. return host, nil
  140. }
  141. }
  142. return "", ErrorSniffFailed
  143. }
  144. func (sd *SnifferDispatcher) cacheSniffFailed(metadata *C.Metadata) {
  145. dst := fmt.Sprintf("%s:%d", metadata.DstIP, metadata.DstPort)
  146. sd.skipList.Compute(dst, func(oldValue uint8, loaded bool) (newValue uint8, delete bool) {
  147. if oldValue <= 5 {
  148. oldValue++
  149. }
  150. return oldValue, false
  151. })
  152. }
  153. func NewCloseSnifferDispatcher() (*SnifferDispatcher, error) {
  154. dispatcher := SnifferDispatcher{
  155. enable: false,
  156. }
  157. return &dispatcher, nil
  158. }
  159. func NewSnifferDispatcher(snifferConfig map[sniffer.Type]SnifferConfig,
  160. forceDomain *trie.DomainSet, skipSNI *trie.DomainSet,
  161. forceDnsMapping bool, parsePureIp bool) (*SnifferDispatcher, error) {
  162. dispatcher := SnifferDispatcher{
  163. enable: true,
  164. forceDomain: forceDomain,
  165. skipSNI: skipSNI,
  166. skipList: lru.New(lru.WithSize[string, uint8](128), lru.WithAge[string, uint8](600)),
  167. forceDnsMapping: forceDnsMapping,
  168. parsePureIp: parsePureIp,
  169. sniffers: make(map[sniffer.Sniffer]SnifferConfig, 0),
  170. }
  171. for snifferName, config := range snifferConfig {
  172. s, err := NewSniffer(snifferName, config)
  173. if err != nil {
  174. log.Errorln("Sniffer name[%s] is error", snifferName)
  175. return &SnifferDispatcher{enable: false}, err
  176. }
  177. dispatcher.sniffers[s] = config
  178. }
  179. return &dispatcher, nil
  180. }
  181. func NewSniffer(name sniffer.Type, snifferConfig SnifferConfig) (sniffer.Sniffer, error) {
  182. switch name {
  183. case sniffer.TLS:
  184. return NewTLSSniffer(snifferConfig)
  185. case sniffer.HTTP:
  186. return NewHTTPSniffer(snifferConfig)
  187. case sniffer.QUIC:
  188. return NewQuicSniffer(snifferConfig)
  189. default:
  190. return nil, ErrorUnsupportedSniffer
  191. }
  192. }