websocket.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595
  1. package vmess
  2. import (
  3. "bufio"
  4. "bytes"
  5. "context"
  6. "crypto/rand"
  7. "crypto/sha1"
  8. "crypto/tls"
  9. "encoding/base64"
  10. "encoding/binary"
  11. "errors"
  12. "fmt"
  13. "io"
  14. "net"
  15. "net/http"
  16. "net/url"
  17. "strconv"
  18. "strings"
  19. "time"
  20. "github.com/metacubex/mihomo/common/buf"
  21. N "github.com/metacubex/mihomo/common/net"
  22. tlsC "github.com/metacubex/mihomo/component/tls"
  23. "github.com/metacubex/mihomo/log"
  24. "github.com/gobwas/ws"
  25. "github.com/gobwas/ws/wsutil"
  26. "github.com/metacubex/randv2"
  27. )
  28. type websocketConn struct {
  29. net.Conn
  30. state ws.State
  31. reader *wsutil.Reader
  32. controlHandler wsutil.FrameHandlerFunc
  33. rawWriter N.ExtendedWriter
  34. }
  35. type websocketWithEarlyDataConn struct {
  36. net.Conn
  37. wsWriter N.ExtendedWriter
  38. underlay net.Conn
  39. closed bool
  40. dialed chan bool
  41. cancel context.CancelFunc
  42. ctx context.Context
  43. config *WebsocketConfig
  44. }
  45. type WebsocketConfig struct {
  46. Host string
  47. Port string
  48. Path string
  49. Headers http.Header
  50. TLS bool
  51. TLSConfig *tls.Config
  52. MaxEarlyData int
  53. EarlyDataHeaderName string
  54. ClientFingerprint string
  55. V2rayHttpUpgrade bool
  56. V2rayHttpUpgradeFastOpen bool
  57. }
  // Read implements net.Conn.Read by returning payload bytes from data frames.
// Adapted from gobwas/ws/wsutil.readData.
//
// While the reader asks to be advanced (wsutil.ErrNoFrameAdvance), the next
// frame header is read. Control frames are handed to controlHandler. Frames
// whose opcode has neither the binary nor the text bit set are discarded. A
// panic raised inside gobwas/ws (pbytes.GetLen) is recovered and returned as
// an error.
func (wsc *websocketConn) Read(b []byte) (n int, err error) {
	defer func() {
		if r := recover(); r != nil {
			err = fmt.Errorf("websocket error: %s", r)
		}
	}()
	for {
		n, err = wsc.reader.Read(b)
		if errors.Is(err, io.EOF) {
			// gobwas/ws returns io.EOF once the current message is fully read.
			// Later frames may still carry data, so this is not a stream EOF.
			err = nil
		}
		if !errors.Is(err, wsutil.ErrNoFrameAdvance) {
			return n, err
		}

		var hdr ws.Header
		if hdr, err = wsc.reader.NextFrame(); err != nil {
			return n, err
		}
		switch {
		case hdr.OpCode.IsControl():
			err = wsc.controlHandler(hdr, wsc.reader)
		case hdr.OpCode&(ws.OpBinary|ws.OpText) == 0:
			err = wsc.reader.Discard()
		}
		if err != nil {
			return n, err
		}
	}
}
  97. // Write implements io.Writer.
  98. func (wsc *websocketConn) Write(b []byte) (n int, err error) {
  99. err = wsutil.WriteMessage(wsc.Conn, wsc.state, ws.OpBinary, b)
  100. if err != nil {
  101. return
  102. }
  103. n = len(b)
  104. return
  105. }
  // WriteBuffer sends buffer's contents as a single final binary frame. The
// frame header is written into the buffer's front headroom (ExtendHeader), so
// no copy of the payload is made. On the client side a random 4-byte mask key
// is placed in the header and the payload is masked in place.
func (wsc *websocketConn) WriteBuffer(buffer *buf.Buffer) error {
	// Capture the payload before the header is prepended.
	payload := buffer.Bytes()
	payloadLen := buffer.Len()
	clientSide := wsc.state.ClientSide()

	// Bytes occupied by the length: the MASK/len byte plus 0, 2 or 8 extended bytes.
	var lenFieldSize int
	switch {
	case payloadLen < 126:
		lenFieldSize = 1
	case payloadLen < 65536:
		lenFieldSize = 3
	default:
		lenFieldSize = 9
	}

	frameHeaderLen := 1 + lenFieldSize // FIN/RSV/OPCODE byte + length field
	if clientSide {
		frameHeaderLen += 4 // masking key
	}
	hdr := buffer.ExtendHeader(frameHeaderLen)
	hdr[0] = 0x80 | byte(ws.OpBinary)

	var maskBit byte
	if clientSide {
		maskBit = 0x80
	}
	switch lenFieldSize {
	case 1:
		hdr[1] = maskBit | byte(payloadLen)
	case 3:
		hdr[1] = maskBit | 126
		binary.BigEndian.PutUint16(hdr[2:], uint16(payloadLen))
	default:
		hdr[1] = maskBit | 127
		binary.BigEndian.PutUint64(hdr[2:], uint64(payloadLen))
	}

	if clientSide {
		key := randv2.Uint32()
		binary.LittleEndian.PutUint32(hdr[1+lenFieldSize:], key)
		N.MaskWebSocket(key, payload)
	}
	return wsc.rawWriter.WriteBuffer(buffer)
}
  146. func (wsc *websocketConn) FrontHeadroom() int {
  147. return 14
  148. }
  149. func (wsc *websocketConn) Upstream() any {
  150. return wsc.Conn
  151. }
  // Close makes a best-effort attempt to send a normal-closure close frame,
// giving the write at most 5 seconds, and then closes the underlying
// connection. Every failure is deliberately ignored and nil is returned.
func (wsc *websocketConn) Close() error {
	deadline := time.Now().Add(5 * time.Second)
	_ = wsc.Conn.SetWriteDeadline(deadline)
	closeBody := ws.NewCloseFrameBody(ws.StatusNormalClosure, "")
	_ = wsutil.WriteMessage(wsc.Conn, wsc.state, ws.OpClose, closeBody)
	_ = wsc.Conn.Close()
	return nil
}
  158. func (wsedc *websocketWithEarlyDataConn) Dial(earlyData []byte) error {
  159. base64DataBuf := &bytes.Buffer{}
  160. base64EarlyDataEncoder := base64.NewEncoder(base64.RawURLEncoding, base64DataBuf)
  161. earlyDataBuf := bytes.NewBuffer(earlyData)
  162. if _, err := base64EarlyDataEncoder.Write(earlyDataBuf.Next(wsedc.config.MaxEarlyData)); err != nil {
  163. return fmt.Errorf("failed to encode early data: %w", err)
  164. }
  165. if errc := base64EarlyDataEncoder.Close(); errc != nil {
  166. return fmt.Errorf("failed to encode early data tail: %w", errc)
  167. }
  168. var err error
  169. if wsedc.Conn, err = streamWebsocketConn(wsedc.ctx, wsedc.underlay, wsedc.config, base64DataBuf); err != nil {
  170. wsedc.Close()
  171. return fmt.Errorf("failed to dial WebSocket: %w", err)
  172. }
  173. wsedc.dialed <- true
  174. wsedc.wsWriter = N.NewExtendedWriter(wsedc.Conn)
  175. if earlyDataBuf.Len() != 0 {
  176. _, err = wsedc.Conn.Write(earlyDataBuf.Bytes())
  177. }
  178. return err
  179. }
  // Write sends b over the websocket. The first call dials, using b as early
// data, and reports len(b) when the dial succeeds.
func (wsedc *websocketWithEarlyDataConn) Write(b []byte) (int, error) {
	switch {
	case wsedc.closed:
		return 0, io.ErrClosedPipe
	case wsedc.Conn != nil:
		return wsedc.Conn.Write(b)
	}
	if err := wsedc.Dial(b); err != nil {
		return 0, err
	}
	return len(b), nil
}
  // WriteBuffer is the buf.Buffer variant of Write. The first call dials with
// the buffer's bytes as early data; later calls go through wsWriter.
func (wsedc *websocketWithEarlyDataConn) WriteBuffer(buffer *buf.Buffer) error {
	switch {
	case wsedc.closed:
		return io.ErrClosedPipe
	case wsedc.Conn == nil:
		return wsedc.Dial(buffer.Bytes())
	default:
		return wsedc.wsWriter.WriteBuffer(buffer)
	}
}
  // Read reads from the websocket. Before the first Write has dialed, it blocks
// until either the dial completes or ctx is cancelled (e.g. by Close or a
// failed Dial). In the cancelled case it returns io.ErrUnexpectedEOF.
func (wsedc *websocketWithEarlyDataConn) Read(b []byte) (int, error) {
	if wsedc.closed {
		return 0, io.ErrClosedPipe
	}
	if wsedc.Conn != nil {
		return wsedc.Conn.Read(b)
	}
	select {
	case <-wsedc.dialed:
		return wsedc.Conn.Read(b)
	case <-wsedc.ctx.Done():
		return 0, io.ErrUnexpectedEOF
	}
}
  217. func (wsedc *websocketWithEarlyDataConn) Close() error {
  218. wsedc.closed = true
  219. wsedc.cancel()
  220. if wsedc.Conn == nil {
  221. return nil
  222. }
  223. return wsedc.Conn.Close()
  224. }
  // LocalAddr reports the websocket conn's local address, or the underlying
// transport's before the handshake has happened.
func (wsedc *websocketWithEarlyDataConn) LocalAddr() net.Addr {
	if wsedc.Conn != nil {
		return wsedc.Conn.LocalAddr()
	}
	return wsedc.underlay.LocalAddr()
}
  // RemoteAddr reports the websocket conn's remote address, or the underlying
// transport's before the handshake has happened.
func (wsedc *websocketWithEarlyDataConn) RemoteAddr() net.Addr {
	if wsedc.Conn != nil {
		return wsedc.Conn.RemoteAddr()
	}
	return wsedc.underlay.RemoteAddr()
}
  // SetDeadline applies t as both the read and the write deadline, stopping at
// the first error.
func (wsedc *websocketWithEarlyDataConn) SetDeadline(t time.Time) error {
	err := wsedc.SetReadDeadline(t)
	if err == nil {
		err = wsedc.SetWriteDeadline(t)
	}
	return err
}
  // SetReadDeadline forwards t to the websocket conn. Before Dial there is no
// connection to apply it to, so the call is a no-op.
func (wsedc *websocketWithEarlyDataConn) SetReadDeadline(t time.Time) error {
	if conn := wsedc.Conn; conn != nil {
		return conn.SetReadDeadline(t)
	}
	return nil
}
  // SetWriteDeadline forwards t to the websocket conn. Before Dial there is no
// connection to apply it to, so the call is a no-op.
func (wsedc *websocketWithEarlyDataConn) SetWriteDeadline(t time.Time) error {
	if conn := wsedc.Conn; conn != nil {
		return conn.SetWriteDeadline(t)
	}
	return nil
}
  255. func (wsedc *websocketWithEarlyDataConn) FrontHeadroom() int {
  256. return 14
  257. }
  258. func (wsedc *websocketWithEarlyDataConn) Upstream() any {
  259. return wsedc.underlay
  260. }
  261. //func (wsedc *websocketWithEarlyDataConn) LazyHeadroom() bool {
  262. // return wsedc.Conn == nil
  263. //}
  264. //
  265. //func (wsedc *websocketWithEarlyDataConn) Upstream() any {
  266. // if wsedc.Conn == nil { // ensure return a nil interface not an interface with nil value
  267. // return nil
  268. // }
  269. // return wsedc.Conn
  270. //}
  271. func (wsedc *websocketWithEarlyDataConn) NeedHandshake() bool {
  272. return wsedc.Conn == nil
  273. }
  // streamWebsocketWithEarlyDataConn returns a conn that performs the websocket
// handshake over conn lazily, on the first write, so early data can ride in
// the upgrade request. The error result is always nil.
func streamWebsocketWithEarlyDataConn(conn net.Conn, c *WebsocketConfig) (net.Conn, error) {
	ctx, cancel := context.WithCancel(context.Background())
	lazy := &websocketWithEarlyDataConn{
		underlay: conn,
		config:   c,
		ctx:      ctx,
		cancel:   cancel,
		dialed:   make(chan bool, 1),
	}
	// websocketWithEarlyDataConn cannot handle deadlines correctly: a deadline
	// set before Dial() is not re-applied afterwards. N.NewDeadlineConn adds a
	// safe wrapper for that.
	return N.NewDeadlineConn(lazy), nil
}
  // streamWebsocketConn performs the client side of a websocket (or v2ray
// "httpupgrade") handshake over conn and returns the upgraded connection.
//
// The request targets c.Host:c.Port with the path and query taken from
// c.Path. When c.TLS is set, conn is first wrapped in TLS by
// websocketTLSClient. When earlyData is non-nil, its contents are either
// appended to the request path (c.EarlyDataHeaderName == "") or sent in the
// named request header.
//
// The returned conn depends on the mode:
//   - c.V2rayHttpUpgrade && c.V2rayHttpUpgradeFastOpen: the buffered conn is
//     returned at once, and the 101 response is validated lazily by the
//     N.NewEarlyConn callback.
//   - c.V2rayHttpUpgrade: the buffered conn, after a validated 101 response.
//   - otherwise: a websocket framing conn built on the buffered conn and
//     wrapped by N.NewDeadlineConn.
//
// If ctx has a Done channel, it is attached to conn with
// N.SetupContextForConn for the duration of this call.
func streamWebsocketConn(ctx context.Context, conn net.Conn, c *WebsocketConfig, earlyData *bytes.Buffer) (net.Conn, error) {
	u, err := url.Parse(c.Path)
	if err != nil {
		return nil, fmt.Errorf("parse url %s error: %w", c.Path, err)
	}
	uri := url.URL{
		Scheme:   "ws",
		Host:     net.JoinHostPort(c.Host, c.Port),
		Path:     u.Path,
		RawQuery: u.RawQuery,
	}
	if !strings.HasPrefix(uri.Path, "/") {
		uri.Path = "/" + uri.Path
	}
	if c.TLS {
		uri.Scheme = "wss"
		if conn, err = websocketTLSClient(ctx, conn, c, uri.Host); err != nil {
			return nil, err
		}
	}

	// http.Header(nil).Clone() returns nil, and Set on a nil map panics, so
	// guarantee a writable header map even when c.Headers is unset.
	header := c.Headers.Clone()
	if header == nil {
		header = make(http.Header)
	}
	request := &http.Request{
		Method: http.MethodGet,
		URL:    &uri,
		Header: header,
		Host:   c.Host,
	}
	request.Header.Set("Connection", "Upgrade")
	request.Header.Set("Upgrade", "websocket")
	if host := request.Header.Get("Host"); host != "" {
		// A user-supplied Host header overrides c.Host. Request.Write takes the
		// Host line from request.Host, so move the value there.
		request.Host = host
	}
	request.Header.Del("Host")

	var secKey string
	if !c.V2rayHttpUpgrade {
		const nonceKeySize = 16
		// NOTE: bts does not escape.
		bts := make([]byte, nonceKeySize)
		if _, err = rand.Read(bts); err != nil {
			return nil, fmt.Errorf("rand read error: %w", err)
		}
		secKey = base64.StdEncoding.EncodeToString(bts)
		request.Header.Set("Sec-WebSocket-Version", "13")
		request.Header.Set("Sec-WebSocket-Key", secKey)
	}

	if earlyData != nil {
		earlyDataString := earlyData.String()
		if c.EarlyDataHeaderName == "" {
			// request.URL points at uri, so this extends the path that is sent.
			uri.Path += earlyDataString
		} else {
			request.Header.Set(c.EarlyDataHeaderName, earlyDataString)
		}
	}

	if ctx.Done() != nil {
		done := N.SetupContextForConn(ctx, conn)
		// NOTE(review): done observes the local err, which is not assigned on
		// the `return nil, fmt.Errorf(...)` / `errors.New` paths below. Confirm
		// that this is intended.
		defer done(&err)
	}

	err = request.Write(conn)
	if err != nil {
		return nil, err
	}

	bufferedConn := N.NewBufferedConn(conn)
	if c.V2rayHttpUpgrade && c.V2rayHttpUpgradeFastOpen {
		return N.NewEarlyConn(bufferedConn, func() error {
			response, err := http.ReadResponse(bufferedConn.Reader(), request)
			if err != nil {
				return err
			}
			return checkWebsocketUpgradeResponse(response)
		}), nil
	}

	response, err := http.ReadResponse(bufferedConn.Reader(), request)
	if err != nil {
		return nil, err
	}
	if statusErr := checkWebsocketUpgradeResponse(response); statusErr != nil {
		return nil, statusErr
	}
	if c.V2rayHttpUpgrade {
		return bufferedConn, nil
	}

	if log.Level() == log.DEBUG { // we might not check this for performance
		secAccept := response.Header.Get("Sec-Websocket-Accept")
		const acceptSize = 28 // base64.StdEncoding.EncodedLen(sha1.Size)
		if lenSecAccept := len(secAccept); lenSecAccept != acceptSize {
			return nil, fmt.Errorf("unexpected Sec-Websocket-Accept length: %d", lenSecAccept)
		}
		if getSecAccept(secKey) != secAccept {
			return nil, errors.New("unexpected Sec-Websocket-Accept")
		}
	}

	// Build on bufferedConn, not conn: http.ReadResponse may already have
	// pulled bytes that follow the 101 response (e.g. the server's first
	// frame, common with early data) into the bufio reader. Reading from the
	// raw conn would silently drop them.
	conn = newWebsocketConn(bufferedConn, ws.StateClientSide)
	// websocketConn can't correctly handle ReadDeadline,
	// so call N.NewDeadlineConn to add a safe wrapper.
	return N.NewDeadlineConn(conn), nil
}

// websocketTLSClient wraps conn in a TLS client for a "wss" upgrade and runs
// the handshake under ctx.
//
// A nil c.TLSConfig defaults to an HTTP/1.1-only config. An empty ServerName
// defaults to host unless InsecureSkipVerify is set. A uTLS client is used
// when c.ClientFingerprint names a fingerprint known to tlsC; otherwise
// crypto/tls is used, so the connection is never left unencrypted.
func websocketTLSClient(ctx context.Context, conn net.Conn, c *WebsocketConfig, host string) (net.Conn, error) {
	config := c.TLSConfig
	if config == nil { // The config cannot be nil
		config = &tls.Config{NextProtos: []string{"http/1.1"}}
	}
	if config.ServerName == "" && !config.InsecureSkipVerify { // users must set either ServerName or InsecureSkipVerify in the config.
		config = config.Clone()
		config.ServerName = host
	}

	var tlsConn net.Conn
	if len(c.ClientFingerprint) != 0 {
		if fingerprint, exists := tlsC.GetFingerprint(c.ClientFingerprint); exists {
			utlsConn := tlsC.UClient(conn, config, fingerprint)
			if err := utlsConn.BuildWebsocketHandshakeState(); err != nil {
				return nil, fmt.Errorf("build websocket handshake state: %w", err)
			}
			tlsConn = utlsConn
		}
	}
	if tlsConn == nil {
		// No fingerprint configured, or the name is unknown: fall back to
		// crypto/tls so the "wss" request is never sent in plaintext.
		tlsConn = tls.Client(conn, config)
	}

	if hs, ok := tlsConn.(interface {
		HandshakeContext(ctx context.Context) error
	}); ok {
		if err := hs.HandshakeContext(ctx); err != nil {
			return nil, err
		}
	}
	return tlsConn, nil
}

// checkWebsocketUpgradeResponse returns an error unless response is a
// 101 Switching Protocols with "Connection: upgrade" and "Upgrade: websocket".
// Header values are compared case-insensitively.
func checkWebsocketUpgradeResponse(response *http.Response) error {
	if response.StatusCode != http.StatusSwitchingProtocols ||
		!strings.EqualFold(response.Header.Get("Connection"), "upgrade") ||
		!strings.EqualFold(response.Header.Get("Upgrade"), "websocket") {
		return fmt.Errorf("unexpected status: %s", response.Status)
	}
	return nil
}
  // getSecAccept computes the Sec-WebSocket-Accept value for secKey as defined
// by RFC 6455: base64(SHA-1(secKey + GUID)).
//
// The key is used exactly as received. The previous implementation forced it
// into 24 bytes, which truncated or zero-padded non-standard key lengths and
// produced a wrong accept value for them. For standard 24-character keys the
// result is unchanged.
func getSecAccept(secKey string) string {
	const magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
	sum := sha1.Sum([]byte(secKey + magic))
	return base64.StdEncoding.EncodeToString(sum[:])
}
  // StreamWebsocketConn upgrades conn to the websocket (or v2ray httpupgrade)
// connection described by c.
//
// An integer "ed" query parameter in c.Path (Xray-style 0-RTT) is consumed.
// Its value becomes MaxEarlyData, early data is sent in the
// Sec-WebSocket-Protocol header, and "ed" is removed from the path. With
// MaxEarlyData > 0 the handshake is deferred until the first write; otherwise
// it happens before this function returns.
//
// c is not modified: the "ed" rewrite is applied to a shallow copy. Writing
// into the caller's config would race when one config is shared by
// concurrent dials.
func StreamWebsocketConn(ctx context.Context, conn net.Conn, c *WebsocketConfig) (net.Conn, error) {
	cfg := *c
	if u, err := url.Parse(cfg.Path); err == nil {
		q := u.Query()
		if edValue := q.Get("ed"); edValue != "" {
			if ed, err := strconv.Atoi(edValue); err == nil {
				cfg.MaxEarlyData = ed
				cfg.EarlyDataHeaderName = "Sec-WebSocket-Protocol"
				q.Del("ed")
				u.RawQuery = q.Encode()
				cfg.Path = u.String()
			}
		}
	}
	if cfg.MaxEarlyData > 0 {
		return streamWebsocketWithEarlyDataConn(conn, &cfg)
	}
	return streamWebsocketConn(ctx, conn, &cfg, nil)
}
  // newWebsocketConn wraps an already upgraded conn for the given side of the
// protocol. One control-frame handler is shared by the reader (for control
// frames met mid-message) and by Read. Header checks and UTF-8 validation
// are turned off in the reader.
func newWebsocketConn(conn net.Conn, state ws.State) *websocketConn {
	onControl := wsutil.ControlFrameHandler(conn, state)
	reader := &wsutil.Reader{
		Source:          conn,
		State:           state,
		SkipHeaderCheck: true,
		CheckUTF8:       false,
		OnIntermediate:  onControl,
	}
	return &websocketConn{
		Conn:           conn,
		state:          state,
		reader:         reader,
		controlHandler: onControl,
		rawWriter:      N.NewExtendedWriter(conn),
	}
}
  459. var replacer = strings.NewReplacer("+", "-", "/", "_", "=", "")
  460. func decodeEd(s string) ([]byte, error) {
  461. return base64.RawURLEncoding.DecodeString(replacer.Replace(s))
  462. }
  // decodeXray0rtt extracts Xray-style 0-RTT early data, which is carried
// base64-encoded in the Sec-WebSocket-Protocol request header. It returns nil
// when the header is absent or does not decode.
func decodeXray0rtt(requestHeader http.Header) []byte {
	secProtocol := requestHeader.Get("Sec-WebSocket-Protocol")
	if secProtocol == "" {
		return nil
	}
	edBuf, err := decodeEd(secProtocol)
	if err != nil {
		return nil
	}
	return edBuf
}
  // IsWebSocketUpgrade reports whether r asks to upgrade to the websocket
// protocol. RFC 6455 section 4.2.1 treats the Upgrade value as
// case-insensitive, so "WebSocket" is accepted as well. This matches the
// client-side response check in this file, which also uses strings.EqualFold.
func IsWebSocketUpgrade(r *http.Request) bool {
	return strings.EqualFold(r.Header.Get("Upgrade"), "websocket")
}
  // IsV2rayHttpUpdate reports whether r is a v2ray "httpupgrade" request: a
// websocket upgrade request (per IsWebSocketUpgrade) that carries no
// Sec-WebSocket-Key. The exported name keeps its historical spelling.
func IsV2rayHttpUpdate(r *http.Request) bool {
	if !IsWebSocketUpgrade(r) {
		return false
	}
	return r.Header.Get("Sec-WebSocket-Key") == ""
}
  478. func StreamUpgradedWebsocketConn(w http.ResponseWriter, r *http.Request) (net.Conn, error) {
  479. var conn net.Conn
  480. var rw *bufio.ReadWriter
  481. var err error
  482. isRaw := IsV2rayHttpUpdate(r)
  483. w.Header().Set("Connection", "upgrade")
  484. w.Header().Set("Upgrade", "websocket")
  485. if !isRaw {
  486. w.Header().Set("Sec-Websocket-Accept", getSecAccept(r.Header.Get("Sec-WebSocket-Key")))
  487. }
  488. w.WriteHeader(http.StatusSwitchingProtocols)
  489. if flusher, isFlusher := w.(interface{ FlushError() error }); isFlusher {
  490. err = flusher.FlushError()
  491. if err != nil {
  492. return nil, fmt.Errorf("flush response: %w", err)
  493. }
  494. }
  495. hijacker, canHijack := w.(http.Hijacker)
  496. if !canHijack {
  497. return nil, errors.New("invalid connection, maybe HTTP/2")
  498. }
  499. conn, rw, err = hijacker.Hijack()
  500. if err != nil {
  501. return nil, fmt.Errorf("hijack failed: %w", err)
  502. }
  503. // rw.Writer was flushed, so we only need warp rw.Reader
  504. conn = N.WarpConnWithBioReader(conn, rw.Reader)
  505. if !isRaw {
  506. conn = newWebsocketConn(conn, ws.StateServerSide)
  507. // websocketConn can't correct handle ReadDeadline
  508. // so call N.NewDeadlineConn to add a safe wrapper
  509. conn = N.NewDeadlineConn(conn)
  510. }
  511. if edBuf := decodeXray0rtt(r.Header); len(edBuf) > 0 {
  512. appendOk := false
  513. if bufConn, ok := conn.(*N.BufferedConn); ok {
  514. appendOk = bufConn.AppendData(edBuf)
  515. }
  516. if !appendOk {
  517. conn = N.NewCachedConn(conn, edBuf)
  518. }
  519. }
  520. return conn, nil
  521. }