123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348 |
- package trojan
- import (
- "context"
- "crypto/sha256"
- "crypto/tls"
- "encoding/binary"
- "encoding/hex"
- "errors"
- "io"
- "net"
- "net/http"
- "sync"
- N "github.com/metacubex/mihomo/common/net"
- "github.com/metacubex/mihomo/common/pool"
- "github.com/metacubex/mihomo/component/ca"
- tlsC "github.com/metacubex/mihomo/component/tls"
- C "github.com/metacubex/mihomo/constant"
- "github.com/metacubex/mihomo/transport/socks5"
- "github.com/metacubex/mihomo/transport/vmess"
- )
// maxLength is the largest UDP payload carried by a single Trojan frame.
// WritePacket splits longer payloads into several frames, and ReadPacket
// rejects frames whose declared length exceeds it.
const maxLength = 8 * 1024
// ALPN protocol lists offered during the TLS handshake when Option.ALPN is
// empty: one for plain TLS streams and one for the WebSocket transport.
var (
	defaultALPN          = []string{"h2", "http/1.1"}
	defaultWebsocketALPN = []string{"http/1.1"}
)

// crlf terminates the Trojan request header fields and separates a UDP
// frame's length field from its payload.
var crlf = []byte("\r\n")
// Command is the single request-type byte that WriteHeader places after the
// password line. It is an alias, so Command and byte are interchangeable.
type Command = byte

// Request commands understood by WriteHeader.
const (
	CommandTCP Command = 1
	CommandUDP Command = 3

	// Additional command values; they are not referenced in this part of the
	// file. NOTE(review): confirm their meaning where they are used.
	commandXRD Command = 0xf0
	commandXRO Command = 0xf1
)
- type Option struct {
- Password string
- ALPN []string
- ServerName string
- SkipCertVerify bool
- Fingerprint string
- ClientFingerprint string
- Reality *tlsC.RealityConfig
- }
// WebsocketOption describes the WebSocket (or V2Ray HTTPUpgrade) transport
// used by StreamWebsocketConn; each field is forwarded to
// vmess.WebsocketConfig unchanged.
type WebsocketOption struct {
	Host    string
	Port    string
	Path    string
	Headers http.Header
	// V2rayHttpUpgrade and V2rayHttpUpgradeFastOpen toggle the HTTPUpgrade
	// variants of the transport.
	V2rayHttpUpgrade         bool
	V2rayHttpUpgradeFastOpen bool
}
- type Trojan struct {
- option *Option
- hexPassword []byte
- }
- func (t *Trojan) StreamConn(ctx context.Context, conn net.Conn) (net.Conn, error) {
- alpn := defaultALPN
- if len(t.option.ALPN) != 0 {
- alpn = t.option.ALPN
- }
- tlsConfig := &tls.Config{
- NextProtos: alpn,
- MinVersion: tls.VersionTLS12,
- InsecureSkipVerify: t.option.SkipCertVerify,
- ServerName: t.option.ServerName,
- }
- var err error
- tlsConfig, err = ca.GetSpecifiedFingerprintTLSConfig(tlsConfig, t.option.Fingerprint)
- if err != nil {
- return nil, err
- }
- if len(t.option.ClientFingerprint) != 0 {
- if t.option.Reality == nil {
- utlsConn, valid := vmess.GetUTLSConn(conn, t.option.ClientFingerprint, tlsConfig)
- if valid {
- ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout)
- defer cancel()
- err := utlsConn.(*tlsC.UConn).HandshakeContext(ctx)
- return utlsConn, err
- }
- } else {
- ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout)
- defer cancel()
- return tlsC.GetRealityConn(ctx, conn, t.option.ClientFingerprint, tlsConfig, t.option.Reality)
- }
- }
- if t.option.Reality != nil {
- return nil, errors.New("REALITY is based on uTLS, please set a client-fingerprint")
- }
- tlsConn := tls.Client(conn, tlsConfig)
-
- ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout)
- defer cancel()
- err = tlsConn.HandshakeContext(ctx)
- return tlsConn, err
- }
- func (t *Trojan) StreamWebsocketConn(ctx context.Context, conn net.Conn, wsOptions *WebsocketOption) (net.Conn, error) {
- alpn := defaultWebsocketALPN
- if len(t.option.ALPN) != 0 {
- alpn = t.option.ALPN
- }
- tlsConfig := &tls.Config{
- NextProtos: alpn,
- MinVersion: tls.VersionTLS12,
- InsecureSkipVerify: t.option.SkipCertVerify,
- ServerName: t.option.ServerName,
- }
- var err error
- tlsConfig, err = ca.GetSpecifiedFingerprintTLSConfig(tlsConfig, t.option.Fingerprint)
- if err != nil {
- return nil, err
- }
- return vmess.StreamWebsocketConn(ctx, conn, &vmess.WebsocketConfig{
- Host: wsOptions.Host,
- Port: wsOptions.Port,
- Path: wsOptions.Path,
- Headers: wsOptions.Headers,
- V2rayHttpUpgrade: wsOptions.V2rayHttpUpgrade,
- V2rayHttpUpgradeFastOpen: wsOptions.V2rayHttpUpgradeFastOpen,
- TLS: true,
- TLSConfig: tlsConfig,
- ClientFingerprint: t.option.ClientFingerprint,
- })
- }
- func (t *Trojan) WriteHeader(w io.Writer, command Command, socks5Addr []byte) error {
- buf := pool.GetBuffer()
- defer pool.PutBuffer(buf)
- buf.Write(t.hexPassword)
- buf.Write(crlf)
- buf.WriteByte(command)
- buf.Write(socks5Addr)
- buf.Write(crlf)
- _, err := w.Write(buf.Bytes())
- return err
- }
- func (t *Trojan) PacketConn(conn net.Conn) net.PacketConn {
- return &PacketConn{
- Conn: conn,
- }
- }
- func writePacket(w io.Writer, socks5Addr, payload []byte) (int, error) {
- buf := pool.GetBuffer()
- defer pool.PutBuffer(buf)
- buf.Write(socks5Addr)
- binary.Write(buf, binary.BigEndian, uint16(len(payload)))
- buf.Write(crlf)
- buf.Write(payload)
- return w.Write(buf.Bytes())
- }
- func WritePacket(w io.Writer, socks5Addr, payload []byte) (int, error) {
- if len(payload) <= maxLength {
- return writePacket(w, socks5Addr, payload)
- }
- offset := 0
- total := len(payload)
- for {
- cursor := offset + maxLength
- if cursor > total {
- cursor = total
- }
- n, err := writePacket(w, socks5Addr, payload[offset:cursor])
- if err != nil {
- return offset + n, err
- }
- offset = cursor
- if offset == total {
- break
- }
- }
- return total, nil
- }
- func ReadPacket(r io.Reader, payload []byte) (net.Addr, int, int, error) {
- addr, err := socks5.ReadAddr(r, payload)
- if err != nil {
- return nil, 0, 0, errors.New("read addr error")
- }
- uAddr := addr.UDPAddr()
- if uAddr == nil {
- return nil, 0, 0, errors.New("parse addr error")
- }
- if _, err = io.ReadFull(r, payload[:2]); err != nil {
- return nil, 0, 0, errors.New("read length error")
- }
- total := int(binary.BigEndian.Uint16(payload[:2]))
- if total > maxLength {
- return nil, 0, 0, errors.New("packet invalid")
- }
-
- if _, err = io.ReadFull(r, payload[:2]); err != nil {
- return nil, 0, 0, errors.New("read crlf error")
- }
- length := len(payload)
- if total < length {
- length = total
- }
- if _, err = io.ReadFull(r, payload[:length]); err != nil {
- return nil, 0, 0, errors.New("read packet error")
- }
- return uAddr, length, total - length, nil
- }
- func New(option *Option) *Trojan {
- return &Trojan{option, hexSha224([]byte(option.Password))}
- }
- var _ N.EnhancePacketConn = (*PacketConn)(nil)
- type PacketConn struct {
- net.Conn
- remain int
- rAddr net.Addr
- mux sync.Mutex
- }
- func (pc *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) {
- return WritePacket(pc, socks5.ParseAddr(addr.String()), b)
- }
- func (pc *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
- pc.mux.Lock()
- defer pc.mux.Unlock()
- if pc.remain != 0 {
- length := len(b)
- if pc.remain < length {
- length = pc.remain
- }
- n, err := pc.Conn.Read(b[:length])
- if err != nil {
- return 0, nil, err
- }
- pc.remain -= n
- addr := pc.rAddr
- if pc.remain == 0 {
- pc.rAddr = nil
- }
- return n, addr, nil
- }
- addr, n, remain, err := ReadPacket(pc.Conn, b)
- if err != nil {
- return 0, nil, err
- }
- if remain != 0 {
- pc.remain = remain
- pc.rAddr = addr
- }
- return n, addr, nil
- }
- func (pc *PacketConn) WaitReadFrom() (data []byte, put func(), addr net.Addr, err error) {
- pc.mux.Lock()
- defer pc.mux.Unlock()
- destination, err := socks5.ReadAddr0(pc.Conn)
- if err != nil {
- return nil, nil, nil, err
- }
- addr = destination.UDPAddr()
- data = pool.Get(pool.UDPBufferSize)
- put = func() {
- _ = pool.Put(data)
- }
- _, err = io.ReadFull(pc.Conn, data[:2+2])
- if err != nil {
- if put != nil {
- put()
- }
- return nil, nil, nil, err
- }
- length := binary.BigEndian.Uint16(data)
- if length > 0 {
- data = data[:length]
- _, err = io.ReadFull(pc.Conn, data)
- if err != nil {
- if put != nil {
- put()
- }
- return nil, nil, nil, err
- }
- } else {
- if put != nil {
- put()
- }
- return nil, nil, addr, nil
- }
- return
- }
// hexSha224 returns the lowercase hexadecimal encoding of the SHA-224 digest
// of data as a 56-byte ASCII slice.
func hexSha224(data []byte) []byte {
	sum := sha256.Sum224(data)
	out := make([]byte, 2*len(sum))
	hex.Encode(out, sum[:])
	return out
}
|