dialer.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409
  1. package dialer
  2. import (
  3. "context"
  4. "errors"
  5. "fmt"
  6. "net"
  7. "net/netip"
  8. "os"
  9. "strconv"
  10. "strings"
  11. "sync"
  12. "time"
  13. "github.com/metacubex/mihomo/component/resolver"
  14. "github.com/metacubex/mihomo/log"
  15. )
  // Default timeouts exported by the dialer package. The UDP value is defined
  // in terms of the TCP one so the two always stay equal.
const (
	// DefaultTCPTimeout is the package's default TCP timeout (5 seconds).
	DefaultTCPTimeout time.Duration = 5 * time.Second
	// DefaultUDPTimeout mirrors DefaultTCPTimeout.
	DefaultUDPTimeout = DefaultTCPTimeout
)
  20. type dialFunc func(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error)
  21. var (
  22. dialMux sync.Mutex
  23. IP4PEnable bool
  24. actualSingleStackDialContext = serialSingleStackDialContext
  25. actualDualStackDialContext = serialDualStackDialContext
  26. tcpConcurrent = false
  27. fallbackTimeout = 300 * time.Millisecond
  28. )
  29. func applyOptions(options ...Option) *option {
  30. opt := &option{
  31. interfaceName: DefaultInterface.Load(),
  32. routingMark: int(DefaultRoutingMark.Load()),
  33. }
  34. for _, o := range DefaultOptions {
  35. o(opt)
  36. }
  37. for _, o := range options {
  38. o(opt)
  39. }
  40. return opt
  41. }
  42. func DialContext(ctx context.Context, network, address string, options ...Option) (net.Conn, error) {
  43. opt := applyOptions(options...)
  44. if opt.network == 4 || opt.network == 6 {
  45. if strings.Contains(network, "tcp") {
  46. network = "tcp"
  47. } else {
  48. network = "udp"
  49. }
  50. network = fmt.Sprintf("%s%d", network, opt.network)
  51. }
  52. ips, port, err := parseAddr(ctx, network, address, opt.resolver)
  53. if err != nil {
  54. return nil, err
  55. }
  56. switch network {
  57. case "tcp4", "tcp6", "udp4", "udp6":
  58. return actualSingleStackDialContext(ctx, network, ips, port, opt)
  59. case "tcp", "udp":
  60. return actualDualStackDialContext(ctx, network, ips, port, opt)
  61. default:
  62. return nil, ErrorInvalidedNetworkStack
  63. }
  64. }
  65. func ListenPacket(ctx context.Context, network, address string, rAddrPort netip.AddrPort, options ...Option) (net.PacketConn, error) {
  66. if DefaultSocketHook != nil {
  67. return listenPacketHooked(ctx, network, address)
  68. }
  69. cfg := applyOptions(options...)
  70. lc := &net.ListenConfig{}
  71. if cfg.interfaceName != "" {
  72. bind := bindIfaceToListenConfig
  73. if cfg.fallbackBind {
  74. bind = fallbackBindIfaceToListenConfig
  75. }
  76. addr, err := bind(cfg.interfaceName, lc, network, address, rAddrPort)
  77. if err != nil {
  78. return nil, err
  79. }
  80. address = addr
  81. }
  82. if cfg.addrReuse {
  83. addrReuseToListenConfig(lc)
  84. }
  85. if cfg.routingMark != 0 {
  86. bindMarkToListenConfig(cfg.routingMark, lc, network, address)
  87. }
  88. return lc.ListenPacket(ctx, network, address)
  89. }
  90. func SetTcpConcurrent(concurrent bool) {
  91. dialMux.Lock()
  92. defer dialMux.Unlock()
  93. tcpConcurrent = concurrent
  94. if concurrent {
  95. actualSingleStackDialContext = concurrentSingleStackDialContext
  96. actualDualStackDialContext = concurrentDualStackDialContext
  97. } else {
  98. actualSingleStackDialContext = serialSingleStackDialContext
  99. actualDualStackDialContext = serialDualStackDialContext
  100. }
  101. }
  102. func GetTcpConcurrent() bool {
  103. dialMux.Lock()
  104. defer dialMux.Unlock()
  105. return tcpConcurrent
  106. }
  107. func dialContext(ctx context.Context, network string, destination netip.Addr, port string, opt *option) (net.Conn, error) {
  108. if DefaultSocketHook != nil {
  109. return dialContextHooked(ctx, network, destination, port)
  110. }
  111. var address string
  112. if IP4PEnable {
  113. destination, port = lookupIP4P(destination, port)
  114. }
  115. address = net.JoinHostPort(destination.String(), port)
  116. netDialer := opt.netDialer
  117. switch netDialer.(type) {
  118. case nil:
  119. netDialer = &net.Dialer{}
  120. case *net.Dialer:
  121. _netDialer := *netDialer.(*net.Dialer)
  122. netDialer = &_netDialer // make a copy
  123. default:
  124. return netDialer.DialContext(ctx, network, address)
  125. }
  126. dialer := netDialer.(*net.Dialer)
  127. if opt.interfaceName != "" {
  128. bind := bindIfaceToDialer
  129. if opt.fallbackBind {
  130. bind = fallbackBindIfaceToDialer
  131. }
  132. if err := bind(opt.interfaceName, dialer, network, destination); err != nil {
  133. return nil, err
  134. }
  135. }
  136. if opt.routingMark != 0 {
  137. bindMarkToDialer(opt.routingMark, dialer, network, destination)
  138. }
  139. if opt.mpTcp {
  140. setMultiPathTCP(dialer)
  141. }
  142. if opt.tfo && !DisableTFO {
  143. return dialTFO(ctx, *dialer, network, address)
  144. }
  145. return dialer.DialContext(ctx, network, address)
  146. }
  147. func serialSingleStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) {
  148. return serialDialContext(ctx, network, ips, port, opt)
  149. }
  150. func serialDualStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) {
  151. return dualStackDialContext(ctx, serialDialContext, network, ips, port, opt)
  152. }
  153. func concurrentSingleStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) {
  154. return parallelDialContext(ctx, network, ips, port, opt)
  155. }
  156. func concurrentDualStackDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) {
  157. if opt.prefer != 4 && opt.prefer != 6 {
  158. return parallelDialContext(ctx, network, ips, port, opt)
  159. }
  160. return dualStackDialContext(ctx, parallelDialContext, network, ips, port, opt)
  161. }
  162. func dualStackDialContext(ctx context.Context, dialFn dialFunc, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) {
  163. ipv4s, ipv6s := resolver.SortationAddr(ips)
  164. if len(ipv4s) == 0 && len(ipv6s) == 0 {
  165. return nil, ErrorNoIpAddress
  166. }
  167. preferIPVersion := opt.prefer
  168. fallbackTicker := time.NewTicker(fallbackTimeout)
  169. defer fallbackTicker.Stop()
  170. results := make(chan dialResult)
  171. returned := make(chan struct{})
  172. defer close(returned)
  173. var wg sync.WaitGroup
  174. racer := func(ips []netip.Addr, isPrimary bool) {
  175. defer wg.Done()
  176. result := dialResult{isPrimary: isPrimary}
  177. defer func() {
  178. select {
  179. case results <- result:
  180. case <-returned:
  181. if result.Conn != nil && result.error == nil {
  182. _ = result.Conn.Close()
  183. }
  184. }
  185. }()
  186. result.Conn, result.error = dialFn(ctx, network, ips, port, opt)
  187. }
  188. if len(ipv4s) != 0 {
  189. wg.Add(1)
  190. go racer(ipv4s, preferIPVersion != 6)
  191. }
  192. if len(ipv6s) != 0 {
  193. wg.Add(1)
  194. go racer(ipv6s, preferIPVersion != 4)
  195. }
  196. go func() {
  197. wg.Wait()
  198. close(results)
  199. }()
  200. var fallback dialResult
  201. var errs []error
  202. loop:
  203. for {
  204. select {
  205. case <-fallbackTicker.C:
  206. if fallback.error == nil && fallback.Conn != nil {
  207. return fallback.Conn, nil
  208. }
  209. case res, ok := <-results:
  210. if !ok {
  211. break loop
  212. }
  213. if res.error == nil {
  214. if res.isPrimary {
  215. return res.Conn, nil
  216. }
  217. fallback = res
  218. } else {
  219. if res.isPrimary {
  220. errs = append([]error{fmt.Errorf("connect failed: %w", res.error)}, errs...)
  221. } else {
  222. errs = append(errs, fmt.Errorf("connect failed: %w", res.error))
  223. }
  224. }
  225. }
  226. }
  227. if fallback.error == nil && fallback.Conn != nil {
  228. return fallback.Conn, nil
  229. }
  230. return nil, errors.Join(errs...)
  231. }
  232. func parallelDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) {
  233. if len(ips) == 0 {
  234. return nil, ErrorNoIpAddress
  235. }
  236. results := make(chan dialResult)
  237. returned := make(chan struct{})
  238. defer close(returned)
  239. racer := func(ctx context.Context, ip netip.Addr) {
  240. result := dialResult{isPrimary: true, ip: ip}
  241. defer func() {
  242. select {
  243. case results <- result:
  244. case <-returned:
  245. if result.Conn != nil && result.error == nil {
  246. _ = result.Conn.Close()
  247. }
  248. }
  249. }()
  250. result.Conn, result.error = dialContext(ctx, network, ip, port, opt)
  251. }
  252. for _, ip := range ips {
  253. go racer(ctx, ip)
  254. }
  255. var errs []error
  256. for i := 0; i < len(ips); i++ {
  257. res := <-results
  258. if res.error == nil {
  259. return res.Conn, nil
  260. }
  261. errs = append(errs, res.error)
  262. }
  263. if len(errs) > 0 {
  264. return nil, errors.Join(errs...)
  265. }
  266. return nil, os.ErrDeadlineExceeded
  267. }
  268. func serialDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) {
  269. if len(ips) == 0 {
  270. return nil, ErrorNoIpAddress
  271. }
  272. var errs []error
  273. for _, ip := range ips {
  274. if conn, err := dialContext(ctx, network, ip, port, opt); err == nil {
  275. return conn, nil
  276. } else {
  277. errs = append(errs, err)
  278. }
  279. }
  280. return nil, errors.Join(errs...)
  281. }
  // dialResult is what a racing dial goroutine reports back: the connection and
// error from its dial attempt, plus bookkeeping used by the racers.
// Field order is kept as-is in case other files build it positionally.
type dialResult struct {
	// ip is the address dialed (set by parallelDialContext's racers).
	ip netip.Addr
	// Embedded dial outcome; exactly one is normally non-nil.
	net.Conn
	error
	// isPrimary marks a result from the preferred address family.
	isPrimary bool
}
  288. func parseAddr(ctx context.Context, network, address string, preferResolver resolver.Resolver) ([]netip.Addr, string, error) {
  289. host, port, err := net.SplitHostPort(address)
  290. if err != nil {
  291. return nil, "-1", err
  292. }
  293. var ips []netip.Addr
  294. switch network {
  295. case "tcp4", "udp4":
  296. if preferResolver == nil {
  297. ips, err = resolver.LookupIPv4ProxyServerHost(ctx, host)
  298. } else {
  299. ips, err = resolver.LookupIPv4WithResolver(ctx, host, preferResolver)
  300. }
  301. case "tcp6", "udp6":
  302. if preferResolver == nil {
  303. ips, err = resolver.LookupIPv6ProxyServerHost(ctx, host)
  304. } else {
  305. ips, err = resolver.LookupIPv6WithResolver(ctx, host, preferResolver)
  306. }
  307. default:
  308. if preferResolver == nil {
  309. ips, err = resolver.LookupIPProxyServerHost(ctx, host)
  310. } else {
  311. ips, err = resolver.LookupIPWithResolver(ctx, host, preferResolver)
  312. }
  313. }
  314. if err != nil {
  315. return nil, "-1", fmt.Errorf("dns resolve failed: %w", err)
  316. }
  317. for i, ip := range ips {
  318. if ip.Is4In6() {
  319. ips[i] = ip.Unmap()
  320. }
  321. }
  322. return ips, port, nil
  323. }
  324. type Dialer struct {
  325. Opt option
  326. }
  327. func (d Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
  328. return DialContext(ctx, network, address, WithOption(d.Opt))
  329. }
  330. func (d Dialer) ListenPacket(ctx context.Context, network, address string, rAddrPort netip.AddrPort) (net.PacketConn, error) {
  331. opt := d.Opt // make a copy
  332. if rAddrPort.Addr().Unmap().IsLoopback() {
  333. // avoid "The requested address is not valid in its context."
  334. WithInterface("")(&opt)
  335. }
  336. return ListenPacket(ctx, ParseNetwork(network, rAddrPort.Addr()), address, rAddrPort, WithOption(opt))
  337. }
  338. func NewDialer(options ...Option) Dialer {
  339. opt := applyOptions(options...)
  340. return Dialer{Opt: *opt}
  341. }
  342. func GetIP4PEnable(enableIP4PConvert bool) {
  343. IP4PEnable = enableIP4PConvert
  344. }
  345. // kanged from https://github.com/heiher/frp/blob/ip4p/client/ip4p.go
  346. func lookupIP4P(addr netip.Addr, port string) (netip.Addr, string) {
  347. ip := addr.AsSlice()
  348. if ip[0] == 0x20 && ip[1] == 0x01 &&
  349. ip[2] == 0x00 && ip[3] == 0x00 {
  350. addr = netip.AddrFrom4([4]byte{ip[12], ip[13], ip[14], ip[15]})
  351. port = strconv.Itoa(int(ip[10])<<8 + int(ip[11]))
  352. log.Debugln("Convert IP4P address %s to %s", ip, net.JoinHostPort(addr.String(), port))
  353. return addr, port
  354. }
  355. return addr, port
  356. }