package shadowtls

import (
	"context"
	"crypto/hmac"
	"crypto/sha1"
	"crypto/tls"
	"encoding/binary"
	"fmt"
	"hash"
	"io"
	"net"

	"github.com/metacubex/mihomo/common/pool"
	C "github.com/metacubex/mihomo/constant"
)

// chunkSize is the largest payload Write places into a single outgoing
// record; longer buffers are split into several records.
const chunkSize = 8192 // 1 << 13

// Mode is the string "shadow-tls". It is not referenced in this file —
// presumably used by callers to identify this transport; verify there.
const Mode string = "shadow-tls"

const (
	// hashLen is how many leading bytes of the handshake HMAC are prefixed
	// to the payload of the first record written after the handshake.
	hashLen int = 8
	// tlsHeaderLen is the size of a TLS record header: content type (1),
	// version (2) and payload length (2).
	tlsHeaderLen int = 5
)

// DefaultALPN holds the ALPN protocol names "h2" and "http/1.1". It is not
// referenced in this file — presumably used when building tls.Config; verify
// against callers.
var DefaultALPN = []string{"h2", "http/1.1"}

// ShadowTLS is shadow-tls implementation: a client-side wrapper around a
// net.Conn. On the first non-empty Write it runs a TLS client handshake over
// the wrapped connection (tracking an HMAC of the bytes it reads), discards
// the resulting tls.Conn, and from then on frames all traffic itself as TLS
// application-data records sent directly on the wrapped connection.
type ShadowTLS struct {
	// Conn is the underlying transport; records are read from and written
	// to it directly.
	net.Conn
	// password keys the HMAC-SHA1 computed over handshake reads.
	password     []byte
	// remain is the number of payload bytes of the current inbound record
	// that have not yet been returned by Read.
	remain       int
	// firstRequest is true until write has completed the TLS handshake.
	firstRequest bool
	// tlsConfig is used only for the handshake performed in write.
	tlsConfig    *tls.Config
}

// HashedConn wraps a net.Conn and feeds every byte read from it into an
// HMAC-SHA1 hasher. Writes and all other methods pass straight through to the
// embedded connection.
type HashedConn struct {
	net.Conn
	hasher hash.Hash
}

// newHashedStream returns a HashedConn over conn whose hasher is HMAC-SHA1
// keyed with password.
func newHashedStream(conn net.Conn, password []byte) HashedConn {
	var hc HashedConn
	hc.Conn = conn
	hc.hasher = hmac.New(sha1.New, password)
	return hc
}

// Read reads from the wrapped connection and adds the bytes actually received
// to the hasher, even when err is non-nil. It then returns n and err unchanged.
// A value receiver is sufficient because hasher is an interface holding the
// shared hash state.
func (h HashedConn) Read(b []byte) (int, error) {
	got, err := h.Conn.Read(b)
	// hash.Hash.Write is documented never to return an error.
	_, _ = h.hasher.Write(b[:got])
	return got, err
}

// read parses the next TLS record header from the underlying connection and
// copies that record's payload into b.
//
// If the payload is larger than b, a single Read is issued on the connection
// and the number of payload bytes still pending is stored in s.remain, so that
// Read can drain them on later calls. Otherwise the whole payload is read with
// io.ReadFull. Zero-length records carry no data and are skipped, rather than
// returning the io.Reader-discouraged (0, nil).
//
// Errors:
//   - A connection closed cleanly before any header byte yields io.EOF itself
//     (unwrapped, as the io.Reader contract requires).
//   - A connection closed after a header but before its payload is complete
//     yields io.ErrUnexpectedEOF.
//   - A header that is not application data with version 0x0303 is rejected.
func (s *ShadowTLS) read(b []byte) (int, error) {
	var header [tlsHeaderLen]byte
	for {
		if _, err := io.ReadFull(s.Conn, header[:]); err != nil {
			if err == io.EOF {
				// Graceful close on a record boundary.
				return 0, io.EOF
			}
			return 0, fmt.Errorf("shadowtls read failed %w", err)
		}
		// Accept only TLS application_data records (0x17) with version bytes 0x03 0x03.
		if header[0] != 0x17 || header[1] != 0x3 || header[2] != 0x3 {
			return 0, fmt.Errorf("invalid shadowtls header %v", header)
		}
		length := int(binary.BigEndian.Uint16(header[3:]))
		if length == 0 {
			continue
		}

		if length > len(b) {
			n, err := s.Conn.Read(b)
			s.remain = length - n
			if err == io.EOF {
				// The record still has s.remain > 0 bytes outstanding: truncated.
				err = io.ErrUnexpectedEOF
			}
			return n, err
		}

		n, err := io.ReadFull(s.Conn, b[:length])
		if err == io.EOF {
			// ReadFull reports io.EOF only when nothing was read; a header
			// without its payload is a truncated record, not a clean close.
			err = io.ErrUnexpectedEOF
		}
		return n, err
	}
}

// Read implements io.Reader over the shadow-tls record stream.
//
// If a previous record was only partially delivered (s.remain > 0), up to
// min(s.remain, len(b)) of its remaining payload bytes are read with
// io.ReadFull. Otherwise the next record is parsed by read.
//
// An empty b returns (0, nil) immediately without touching the connection.
// If the peer closes in the middle of a record, the wrapped error is
// io.ErrUnexpectedEOF. It is never a wrapped io.EOF, which errors.Is-based
// callers would mistake for a graceful end of stream.
func (s *ShadowTLS) Read(b []byte) (int, error) {
	if len(b) == 0 {
		return 0, nil
	}
	if s.remain <= 0 {
		return s.read(b)
	}

	length := s.remain
	if length > len(b) {
		length = len(b)
	}

	n, err := io.ReadFull(s.Conn, b[:length])
	// Account for whatever was actually consumed, even on error.
	s.remain -= n
	if err != nil {
		if err == io.EOF {
			err = io.ErrUnexpectedEOF
		}
		return n, fmt.Errorf("shadowtls Read failed with %w", err)
	}
	return n, nil
}

// Write implements io.Writer. It splits b into chunks of at most chunkSize
// bytes and sends each one as a separate record via write.
//
// On success it returns len(b). If a chunk fails, the returned count is the
// number of bytes of b sent before that chunk plus whatever write reported for
// it. Earlier chunks were sent in full, so io.Writer's "bytes written from p"
// contract holds.
func (s *ShadowTLS) Write(b []byte) (int, error) {
	length := len(b)
	for i := 0; i < length; i += chunkSize {
		end := i + chunkSize
		if end > length {
			end = length
		}

		n, err := s.write(b[i:end])
		if err != nil {
			return i + n, fmt.Errorf("shadowtls Write failed with %w, i=%d, end=%d, n=%d", err, i, end, n)
		}
	}
	return length, nil
}

// write sends b as a single TLS application-data record on the underlying
// connection.
//
// On the first call (s.firstRequest) it first runs a TLS client handshake
// through a HashedConn, bounded by C.DefaultTLSTimeout. It then prefixes the
// record payload with the first hashLen bytes of the HMAC-SHA1 (keyed by
// s.password) of everything the handshake read. The tls.Conn itself is
// discarded.
//
// It returns len(b) on success. A failed connection write is reported with a
// count of 0, because a partially sent record leaves the stream unrecoverable.
func (s *ShadowTLS) write(b []byte) (int, error) {
	var tag []byte
	if s.firstRequest {
		hc := newHashedStream(s.Conn, s.password)
		handshaker := tls.Client(hc, s.tlsConfig)
		// Bound the handshake so it cannot hang forever.
		ctx, cancel := context.WithTimeout(context.Background(), C.DefaultTLSTimeout)
		defer cancel()
		if err := handshaker.HandshakeContext(ctx); err != nil {
			return 0, fmt.Errorf("tls connect failed with %w", err)
		}
		tag = hc.hasher.Sum(nil)[:hashLen]
		s.firstRequest = false
	}

	// Record header: content type 0x17, version 0x0303, big-endian payload length.
	var header [tlsHeaderLen]byte
	header[0], header[1], header[2] = 0x17, 0x03, 0x03
	binary.BigEndian.PutUint16(header[3:], uint16(len(b)+len(tag)))

	buf := pool.GetBuffer()
	defer pool.PutBuffer(buf)
	buf.Write(header[:])
	buf.Write(tag)
	buf.Write(b)
	if _, err := s.Conn.Write(buf.Bytes()); err != nil {
		return 0, err
	}
	return len(b), nil
}

// NewShadowTLS wraps conn in a ShadowTLS. password keys the handshake HMAC and
// tlsConfig configures the TLS handshake. That handshake is deferred until the
// first non-empty Write; this constructor performs no I/O.
func NewShadowTLS(conn net.Conn, password string, tlsConfig *tls.Config) net.Conn {
	st := new(ShadowTLS)
	st.Conn = conn
	st.password = []byte(password)
	st.tlsConfig = tlsConfig
	st.firstRequest = true
	return st
}