aead.go 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. package vmess
  2. import (
  3. "crypto/cipher"
  4. "encoding/binary"
  5. "errors"
  6. "io"
  7. "sync"
  8. "github.com/metacubex/mihomo/common/pool"
  9. )
  10. type aeadWriter struct {
  11. io.Writer
  12. cipher.AEAD
  13. nonce [32]byte
  14. count uint16
  15. iv []byte
  16. writeLock sync.Mutex
  17. }
  18. func newAEADWriter(w io.Writer, aead cipher.AEAD, iv []byte) *aeadWriter {
  19. return &aeadWriter{Writer: w, AEAD: aead, iv: iv}
  20. }
  21. func (w *aeadWriter) Write(b []byte) (n int, err error) {
  22. w.writeLock.Lock()
  23. buf := pool.Get(pool.RelayBufferSize)
  24. defer func() {
  25. w.writeLock.Unlock()
  26. pool.Put(buf)
  27. }()
  28. length := len(b)
  29. for {
  30. if length == 0 {
  31. break
  32. }
  33. readLen := chunkSize - w.Overhead()
  34. if length < readLen {
  35. readLen = length
  36. }
  37. payloadBuf := buf[lenSize : lenSize+chunkSize-w.Overhead()]
  38. copy(payloadBuf, b[n:n+readLen])
  39. binary.BigEndian.PutUint16(buf[:lenSize], uint16(readLen+w.Overhead()))
  40. binary.BigEndian.PutUint16(w.nonce[:2], w.count)
  41. copy(w.nonce[2:], w.iv[2:12])
  42. w.Seal(payloadBuf[:0], w.nonce[:w.NonceSize()], payloadBuf[:readLen], nil)
  43. w.count++
  44. _, err = w.Writer.Write(buf[:lenSize+readLen+w.Overhead()])
  45. if err != nil {
  46. break
  47. }
  48. n += readLen
  49. length -= readLen
  50. }
  51. return
  52. }
  53. type aeadReader struct {
  54. io.Reader
  55. cipher.AEAD
  56. nonce [32]byte
  57. buf []byte
  58. offset int
  59. iv []byte
  60. sizeBuf []byte
  61. count uint16
  62. }
  63. func newAEADReader(r io.Reader, aead cipher.AEAD, iv []byte) *aeadReader {
  64. return &aeadReader{Reader: r, AEAD: aead, iv: iv, sizeBuf: make([]byte, lenSize)}
  65. }
  66. func (r *aeadReader) Read(b []byte) (int, error) {
  67. if r.buf != nil {
  68. n := copy(b, r.buf[r.offset:])
  69. r.offset += n
  70. if r.offset == len(r.buf) {
  71. pool.Put(r.buf)
  72. r.buf = nil
  73. }
  74. return n, nil
  75. }
  76. _, err := io.ReadFull(r.Reader, r.sizeBuf)
  77. if err != nil {
  78. return 0, err
  79. }
  80. size := int(binary.BigEndian.Uint16(r.sizeBuf))
  81. if size > maxSize {
  82. return 0, errors.New("buffer is larger than standard")
  83. }
  84. buf := pool.Get(size)
  85. _, err = io.ReadFull(r.Reader, buf[:size])
  86. if err != nil {
  87. pool.Put(buf)
  88. return 0, err
  89. }
  90. binary.BigEndian.PutUint16(r.nonce[:2], r.count)
  91. copy(r.nonce[2:], r.iv[2:12])
  92. _, err = r.Open(buf[:0], r.nonce[:r.NonceSize()], buf[:size], nil)
  93. r.count++
  94. if err != nil {
  95. return 0, err
  96. }
  97. realLen := size - r.Overhead()
  98. n := copy(b, buf[:realLen])
  99. if len(b) >= realLen {
  100. pool.Put(buf)
  101. return n, nil
  102. }
  103. r.offset = n
  104. r.buf = buf[:realLen]
  105. return n, nil
  106. }