socks5.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495
  1. package socks5
  2. import (
  3. "bytes"
  4. "encoding/binary"
  5. "errors"
  6. "io"
  7. "net"
  8. "net/netip"
  9. "strconv"
  10. "github.com/metacubex/mihomo/component/auth"
  11. )
  // Error represents a SOCKS error, stored as a single-byte code.
  type Error byte

  // Error implements the error interface by rendering the code in decimal
  // after a fixed "SOCKS error: " prefix.
  func (err Error) Error() string {
  	code := strconv.FormatInt(int64(err), 10)
  	return "SOCKS error: " + code
  }
  // Command is a SOCKS request command as defined in RFC 1928 section 4.
  // It is an alias of uint8, so raw protocol bytes can be used directly.
  type Command = uint8

  // Version is the SOCKS protocol version number used by this package.
  const Version = 5

  // SOCKS request commands as defined in RFC 1928 section 4.
  // The values are consecutive, starting at 1.
  const (
  	CmdConnect Command = iota + 1
  	CmdBind
  	CmdUDPAssociate
  )
  // SOCKS address types as defined in RFC 1928 section 5.
  const (
  	AtypIPv4       = 1
  	AtypDomainName = 3
  	AtypIPv6       = 4
  )

  // MaxAddrLen is the maximum size of SOCKS address in bytes: one type byte,
  // one length byte, up to 255 domain-name bytes and a two-byte port.
  const MaxAddrLen = 1 + 1 + 255 + 2

  // MaxAuthLen is the maximum size of user/password field in SOCKS5 Auth
  const MaxAuthLen = 255

  // Addr represents a SOCKS address as defined in RFC 1928 section 5.
  //
  // The first byte is the address type. It is followed by the host (4 raw
  // bytes for IPv4, 16 for IPv6, or a one-byte length plus that many bytes for
  // a domain name) and a big-endian two-byte port.
  type Addr []byte

  // String renders the address as "host:port" using net.JoinHostPort.
  //
  // An address that is empty, truncated, or of an unknown type yields an empty
  // host and port (the string ":") instead of panicking.
  func (a Addr) String() string {
  	var host, port string

  	// hostStart/hostEnd delimit the host bytes. hostEnd stays 0 when the
  	// address type is unknown or the length byte of a domain is missing.
  	hostStart, hostEnd := 0, 0
  	if len(a) > 0 {
  		switch a[0] {
  		case AtypDomainName:
  			if len(a) > 1 {
  				hostStart, hostEnd = 2, 2+int(a[1])
  			}
  		case AtypIPv4:
  			hostStart, hostEnd = 1, 1+net.IPv4len
  		case AtypIPv6:
  			hostStart, hostEnd = 1, 1+net.IPv6len
  		}
  	}

  	// Only decode when the slice really holds the host and the 2-byte port.
  	if hostEnd > 0 && len(a) >= hostEnd+2 {
  		if a[0] == AtypDomainName {
  			host = string(a[hostStart:hostEnd])
  		} else {
  			host = net.IP(a[hostStart:hostEnd]).String()
  		}
  		port = strconv.Itoa(int(binary.BigEndian.Uint16(a[hostEnd:])))
  	}
  	return net.JoinHostPort(host, port)
  }

  // UDPAddr converts a socks5.Addr to *net.UDPAddr.
  //
  // It returns nil for an empty address, for any type other than IPv4/IPv6
  // (domain names included), and for an address too short to hold the IP and
  // port. The IP is copied, so the result does not alias a.
  func (a Addr) UDPAddr() *net.UDPAddr {
  	if len(a) == 0 {
  		return nil
  	}

  	var ipLen int
  	switch a[0] {
  	case AtypIPv4:
  		ipLen = net.IPv4len
  	case AtypIPv6:
  		ipLen = net.IPv6len
  	default:
  		// Other Atyp
  		return nil
  	}
  	if len(a) < 1+ipLen+2 {
  		return nil
  	}

  	ip := make(net.IP, ipLen)
  	copy(ip, a[1:1+ipLen])
  	return &net.UDPAddr{
  		IP:   ip,
  		Port: int(binary.BigEndian.Uint16(a[1+ipLen : 1+ipLen+2])),
  	}
  }
  72. // SOCKS errors as defined in RFC 1928 section 6.
  73. const (
  74. ErrGeneralFailure = Error(1)
  75. ErrConnectionNotAllowed = Error(2)
  76. ErrNetworkUnreachable = Error(3)
  77. ErrHostUnreachable = Error(4)
  78. ErrConnectionRefused = Error(5)
  79. ErrTTLExpired = Error(6)
  80. ErrCommandNotSupported = Error(7)
  81. ErrAddressNotSupported = Error(8)
  82. )
  // ErrAuth is the sentinel error used to report a specific "auth failed" condition.
  var ErrAuth = errors.New("auth failed")

  // User holds a username/password pair.
  type User struct {
  	Username, Password string
  }
  89. // ServerHandshake fast-tracks SOCKS initialization to get target address to connect on server side.
  90. func ServerHandshake(rw net.Conn, authenticator auth.Authenticator) (addr Addr, command Command, user string, err error) {
  91. // Read RFC 1928 for request and reply structure and sizes.
  92. buf := make([]byte, MaxAddrLen)
  93. // read VER, NMETHODS, METHODS
  94. if _, err = io.ReadFull(rw, buf[:2]); err != nil {
  95. return
  96. }
  97. nmethods := buf[1]
  98. if _, err = io.ReadFull(rw, buf[:nmethods]); err != nil {
  99. return
  100. }
  101. // write VER METHOD
  102. if authenticator != nil {
  103. if _, err = rw.Write([]byte{5, 2}); err != nil {
  104. return
  105. }
  106. // Get header
  107. header := make([]byte, 2)
  108. if _, err = io.ReadFull(rw, header); err != nil {
  109. return
  110. }
  111. authBuf := make([]byte, MaxAuthLen)
  112. // Get username
  113. userLen := int(header[1])
  114. if userLen <= 0 {
  115. rw.Write([]byte{1, 1})
  116. err = ErrAuth
  117. return
  118. }
  119. if _, err = io.ReadFull(rw, authBuf[:userLen]); err != nil {
  120. return
  121. }
  122. user = string(authBuf[:userLen])
  123. // Get password
  124. if _, err = rw.Read(header[:1]); err != nil {
  125. return
  126. }
  127. passLen := int(header[0])
  128. if passLen <= 0 {
  129. rw.Write([]byte{1, 1})
  130. err = ErrAuth
  131. return
  132. }
  133. if _, err = io.ReadFull(rw, authBuf[:passLen]); err != nil {
  134. return
  135. }
  136. pass := string(authBuf[:passLen])
  137. // Verify
  138. if ok := authenticator.Verify(string(user), string(pass)); !ok {
  139. rw.Write([]byte{1, 1})
  140. err = ErrAuth
  141. return
  142. }
  143. // Response auth state
  144. if _, err = rw.Write([]byte{1, 0}); err != nil {
  145. return
  146. }
  147. } else {
  148. if _, err = rw.Write([]byte{5, 0}); err != nil {
  149. return
  150. }
  151. }
  152. // read VER CMD RSV ATYP DST.ADDR DST.PORT
  153. if _, err = io.ReadFull(rw, buf[:3]); err != nil {
  154. return
  155. }
  156. command = buf[1]
  157. addr, err = ReadAddr(rw, buf)
  158. if err != nil {
  159. return
  160. }
  161. switch command {
  162. case CmdConnect, CmdUDPAssociate:
  163. // Acquire server listened address info
  164. localAddr := ParseAddr(rw.LocalAddr().String())
  165. if localAddr == nil {
  166. err = ErrAddressNotSupported
  167. } else {
  168. // write VER REP RSV ATYP BND.ADDR BND.PORT
  169. _, err = rw.Write(bytes.Join([][]byte{{5, 0, 0}, localAddr}, []byte{}))
  170. }
  171. case CmdBind:
  172. fallthrough
  173. default:
  174. err = ErrCommandNotSupported
  175. }
  176. return
  177. }
  178. // ClientHandshake fast-tracks SOCKS initialization to get target address to connect on client side.
  179. func ClientHandshake(rw io.ReadWriter, addr Addr, command Command, user *User) (Addr, error) {
  180. buf := make([]byte, MaxAddrLen)
  181. var err error
  182. // VER, NMETHODS, METHODS
  183. if user != nil {
  184. _, err = rw.Write([]byte{5, 1, 2})
  185. } else {
  186. _, err = rw.Write([]byte{5, 1, 0})
  187. }
  188. if err != nil {
  189. return nil, err
  190. }
  191. // VER, METHOD
  192. if _, err := io.ReadFull(rw, buf[:2]); err != nil {
  193. return nil, err
  194. }
  195. if buf[0] != 5 {
  196. return nil, errors.New("SOCKS version error")
  197. }
  198. if buf[1] == 2 {
  199. if user == nil {
  200. return nil, ErrAuth
  201. }
  202. // password protocol version
  203. authMsg := &bytes.Buffer{}
  204. authMsg.WriteByte(1)
  205. authMsg.WriteByte(uint8(len(user.Username)))
  206. authMsg.WriteString(user.Username)
  207. authMsg.WriteByte(uint8(len(user.Password)))
  208. authMsg.WriteString(user.Password)
  209. if _, err := rw.Write(authMsg.Bytes()); err != nil {
  210. return nil, err
  211. }
  212. if _, err := io.ReadFull(rw, buf[:2]); err != nil {
  213. return nil, err
  214. }
  215. if buf[1] != 0 {
  216. return nil, errors.New("rejected username/password")
  217. }
  218. } else if buf[1] != 0 {
  219. return nil, errors.New("SOCKS need auth")
  220. }
  221. // VER, CMD, RSV, ADDR
  222. if _, err := rw.Write(bytes.Join([][]byte{{5, command, 0}, addr}, []byte{})); err != nil {
  223. return nil, err
  224. }
  225. // VER, REP, RSV
  226. if _, err := io.ReadFull(rw, buf[:3]); err != nil {
  227. return nil, err
  228. }
  229. return ReadAddr(rw, buf)
  230. }
  231. func ReadAddr(r io.Reader, b []byte) (Addr, error) {
  232. if len(b) < MaxAddrLen {
  233. return nil, io.ErrShortBuffer
  234. }
  235. _, err := io.ReadFull(r, b[:1]) // read 1st byte for address type
  236. if err != nil {
  237. return nil, err
  238. }
  239. switch b[0] {
  240. case AtypDomainName:
  241. _, err = io.ReadFull(r, b[1:2]) // read 2nd byte for domain length
  242. if err != nil {
  243. return nil, err
  244. }
  245. domainLength := uint16(b[1])
  246. _, err = io.ReadFull(r, b[2:2+domainLength+2])
  247. return b[:1+1+domainLength+2], err
  248. case AtypIPv4:
  249. _, err = io.ReadFull(r, b[1:1+net.IPv4len+2])
  250. return b[:1+net.IPv4len+2], err
  251. case AtypIPv6:
  252. _, err = io.ReadFull(r, b[1:1+net.IPv6len+2])
  253. return b[:1+net.IPv6len+2], err
  254. }
  255. return nil, ErrAddressNotSupported
  256. }
  257. func ReadAddr0(r io.Reader) (Addr, error) {
  258. aType, err := ReadByte(r) // read 1st byte for address type
  259. if err != nil {
  260. return nil, err
  261. }
  262. switch aType {
  263. case AtypDomainName:
  264. var domainLength byte
  265. domainLength, err = ReadByte(r) // read 2nd byte for domain length
  266. if err != nil {
  267. return nil, err
  268. }
  269. b := make([]byte, 1+1+uint16(domainLength)+2)
  270. _, err = io.ReadFull(r, b[2:])
  271. b[0] = aType
  272. b[1] = domainLength
  273. return b, err
  274. case AtypIPv4:
  275. var b [1 + net.IPv4len + 2]byte
  276. _, err = io.ReadFull(r, b[1:])
  277. b[0] = aType
  278. return b[:], err
  279. case AtypIPv6:
  280. var b [1 + net.IPv6len + 2]byte
  281. _, err = io.ReadFull(r, b[1:])
  282. b[0] = aType
  283. return b[:], err
  284. }
  285. return nil, ErrAddressNotSupported
  286. }
  // ReadByte reads a single byte from reader.
  //
  // When reader implements io.ByteReader its ReadByte method is used directly.
  // Otherwise one byte is read with io.ReadFull. On failure it returns 0 and
  // the error.
  func ReadByte(reader io.Reader) (byte, error) {
  	switch src := reader.(type) {
  	case io.ByteReader:
  		return src.ReadByte()
  	default:
  		one := make([]byte, 1)
  		if _, err := io.ReadFull(reader, one); err != nil {
  			return 0, err
  		}
  		return one[0], nil
  	}
  }
  297. // SplitAddr slices a SOCKS address from beginning of b. Returns nil if failed.
  298. func SplitAddr(b []byte) Addr {
  299. addrLen := 1
  300. if len(b) < addrLen {
  301. return nil
  302. }
  303. switch b[0] {
  304. case AtypDomainName:
  305. if len(b) < 2 {
  306. return nil
  307. }
  308. addrLen = 1 + 1 + int(b[1]) + 2
  309. case AtypIPv4:
  310. addrLen = 1 + net.IPv4len + 2
  311. case AtypIPv6:
  312. addrLen = 1 + net.IPv6len + 2
  313. default:
  314. return nil
  315. }
  316. if len(b) < addrLen {
  317. return nil
  318. }
  319. return b[:addrLen]
  320. }
  321. // ParseAddr parses the address in string s. Returns nil if failed.
  322. func ParseAddr(s string) Addr {
  323. var addr Addr
  324. host, port, err := net.SplitHostPort(s)
  325. if err != nil {
  326. return nil
  327. }
  328. if ip := net.ParseIP(host); ip != nil {
  329. if ip4 := ip.To4(); ip4 != nil {
  330. addr = make([]byte, 1+net.IPv4len+2)
  331. addr[0] = AtypIPv4
  332. copy(addr[1:], ip4)
  333. } else {
  334. addr = make([]byte, 1+net.IPv6len+2)
  335. addr[0] = AtypIPv6
  336. copy(addr[1:], ip)
  337. }
  338. } else {
  339. if len(host) > 255 {
  340. return nil
  341. }
  342. addr = make([]byte, 1+1+len(host)+2)
  343. addr[0] = AtypDomainName
  344. addr[1] = byte(len(host))
  345. copy(addr[2:], host)
  346. }
  347. portnum, err := strconv.ParseUint(port, 10, 16)
  348. if err != nil {
  349. return nil
  350. }
  351. addr[len(addr)-2], addr[len(addr)-1] = byte(portnum>>8), byte(portnum)
  352. return addr
  353. }
  354. // ParseAddrToSocksAddr parse a socks addr from net.addr
  355. // This is a fast path of ParseAddr(addr.String())
  356. func ParseAddrToSocksAddr(addr net.Addr) Addr {
  357. var hostip net.IP
  358. var port int
  359. if udpaddr, ok := addr.(*net.UDPAddr); ok {
  360. hostip = udpaddr.IP
  361. port = udpaddr.Port
  362. } else if tcpaddr, ok := addr.(*net.TCPAddr); ok {
  363. hostip = tcpaddr.IP
  364. port = tcpaddr.Port
  365. }
  366. // fallback parse
  367. if hostip == nil {
  368. return ParseAddr(addr.String())
  369. }
  370. var parsed Addr
  371. if ip4 := hostip.To4(); ip4.DefaultMask() != nil {
  372. parsed = make([]byte, 1+net.IPv4len+2)
  373. parsed[0] = AtypIPv4
  374. copy(parsed[1:], ip4)
  375. binary.BigEndian.PutUint16(parsed[1+net.IPv4len:], uint16(port))
  376. } else {
  377. parsed = make([]byte, 1+net.IPv6len+2)
  378. parsed[0] = AtypIPv6
  379. copy(parsed[1:], hostip)
  380. binary.BigEndian.PutUint16(parsed[1+net.IPv6len:], uint16(port))
  381. }
  382. return parsed
  383. }
  384. func AddrFromStdAddrPort(addrPort netip.AddrPort) Addr {
  385. addr := addrPort.Addr()
  386. if addr.Is4() {
  387. ip4 := addr.As4()
  388. return []byte{AtypIPv4, ip4[0], ip4[1], ip4[2], ip4[3], byte(addrPort.Port() >> 8), byte(addrPort.Port())}
  389. }
  390. buf := make([]byte, 1+net.IPv6len+2)
  391. buf[0] = AtypIPv6
  392. copy(buf[1:], addr.AsSlice())
  393. buf[1+net.IPv6len] = byte(addrPort.Port() >> 8)
  394. buf[1+net.IPv6len+1] = byte(addrPort.Port())
  395. return buf
  396. }
  397. // DecodeUDPPacket split `packet` to addr payload, and this function is mutable with `packet`
  398. func DecodeUDPPacket(packet []byte) (addr Addr, payload []byte, err error) {
  399. if len(packet) < 5 {
  400. err = errors.New("insufficient length of packet")
  401. return
  402. }
  403. // packet[0] and packet[1] are reserved
  404. if !bytes.Equal(packet[:2], []byte{0, 0}) {
  405. err = errors.New("reserved fields should be zero")
  406. return
  407. }
  408. if packet[2] != 0 /* fragments */ {
  409. err = errors.New("discarding fragmented payload")
  410. return
  411. }
  412. addr = SplitAddr(packet[3:])
  413. if addr == nil {
  414. err = errors.New("failed to read UDP header")
  415. }
  416. payload = packet[3+len(addr):]
  417. return
  418. }
  419. func EncodeUDPPacket(addr Addr, payload []byte) (packet []byte, err error) {
  420. if addr == nil {
  421. err = errors.New("address is invalid")
  422. return
  423. }
  424. packet = bytes.Join([][]byte{{0, 0, 0}, addr, payload}, []byte{})
  425. return
  426. }