config.go 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. package ca
  2. import (
  3. "bytes"
  4. "crypto/sha256"
  5. "crypto/tls"
  6. "crypto/x509"
  7. _ "embed"
  8. "encoding/hex"
  9. "errors"
  10. "fmt"
  11. "os"
  12. "strconv"
  13. "strings"
  14. "sync"
  15. C "github.com/metacubex/mihomo/constant"
  16. )
  17. var trustCerts []*x509.Certificate
  18. var globalCertPool *x509.CertPool
  19. var mutex sync.RWMutex
  20. var errNotMatch = errors.New("certificate fingerprints do not match")
  21. //go:embed ca-certificates.crt
  22. var _CaCertificates []byte
  23. var DisableEmbedCa, _ = strconv.ParseBool(os.Getenv("DISABLE_EMBED_CA"))
  24. var DisableSystemCa, _ = strconv.ParseBool(os.Getenv("DISABLE_SYSTEM_CA"))
  25. func AddCertificate(certificate string) error {
  26. mutex.Lock()
  27. defer mutex.Unlock()
  28. if certificate == "" {
  29. return fmt.Errorf("certificate is empty")
  30. }
  31. if cert, err := x509.ParseCertificate([]byte(certificate)); err == nil {
  32. trustCerts = append(trustCerts, cert)
  33. return nil
  34. } else {
  35. return fmt.Errorf("add certificate failed")
  36. }
  37. }
  38. func initializeCertPool() {
  39. var err error
  40. if DisableSystemCa {
  41. globalCertPool = x509.NewCertPool()
  42. } else {
  43. globalCertPool, err = x509.SystemCertPool()
  44. if err != nil {
  45. globalCertPool = x509.NewCertPool()
  46. }
  47. }
  48. for _, cert := range trustCerts {
  49. globalCertPool.AddCert(cert)
  50. }
  51. if !DisableEmbedCa {
  52. globalCertPool.AppendCertsFromPEM(_CaCertificates)
  53. }
  54. }
  55. func ResetCertificate() {
  56. mutex.Lock()
  57. defer mutex.Unlock()
  58. trustCerts = nil
  59. initializeCertPool()
  60. }
  61. func getCertPool() *x509.CertPool {
  62. if globalCertPool == nil {
  63. mutex.Lock()
  64. defer mutex.Unlock()
  65. if globalCertPool != nil {
  66. return globalCertPool
  67. }
  68. initializeCertPool()
  69. }
  70. return globalCertPool
  71. }
  72. func verifyFingerprint(fingerprint *[32]byte) func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
  73. return func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
  74. // ssl pining
  75. for i := range rawCerts {
  76. rawCert := rawCerts[i]
  77. cert, err := x509.ParseCertificate(rawCert)
  78. if err == nil {
  79. hash := sha256.Sum256(cert.Raw)
  80. if bytes.Equal(fingerprint[:], hash[:]) {
  81. return nil
  82. }
  83. }
  84. }
  85. return errNotMatch
  86. }
  87. }
  88. func convertFingerprint(fingerprint string) (*[32]byte, error) {
  89. fingerprint = strings.TrimSpace(strings.Replace(fingerprint, ":", "", -1))
  90. fpByte, err := hex.DecodeString(fingerprint)
  91. if err != nil {
  92. return nil, err
  93. }
  94. if len(fpByte) != 32 {
  95. return nil, fmt.Errorf("fingerprint string length error,need sha256 fingerprint")
  96. }
  97. return (*[32]byte)(fpByte), nil
  98. }
  99. // GetTLSConfig specified fingerprint, customCA and customCAString
  100. func GetTLSConfig(tlsConfig *tls.Config, fingerprint string, customCA string, customCAString string) (*tls.Config, error) {
  101. if tlsConfig == nil {
  102. tlsConfig = &tls.Config{}
  103. }
  104. var certificate []byte
  105. var err error
  106. if len(customCA) > 0 {
  107. certificate, err = os.ReadFile(C.Path.Resolve(customCA))
  108. if err != nil {
  109. return nil, fmt.Errorf("load ca error: %w", err)
  110. }
  111. } else if customCAString != "" {
  112. certificate = []byte(customCAString)
  113. }
  114. if len(certificate) > 0 {
  115. certPool := x509.NewCertPool()
  116. if !certPool.AppendCertsFromPEM(certificate) {
  117. return nil, fmt.Errorf("failed to parse certificate:\n\n %s", certificate)
  118. }
  119. tlsConfig.RootCAs = certPool
  120. } else {
  121. tlsConfig.RootCAs = getCertPool()
  122. }
  123. if len(fingerprint) > 0 {
  124. var fingerprintBytes *[32]byte
  125. fingerprintBytes, err = convertFingerprint(fingerprint)
  126. if err != nil {
  127. return nil, err
  128. }
  129. tlsConfig = GetGlobalTLSConfig(tlsConfig)
  130. tlsConfig.VerifyPeerCertificate = verifyFingerprint(fingerprintBytes)
  131. tlsConfig.InsecureSkipVerify = true
  132. }
  133. return tlsConfig, nil
  134. }
  135. // GetSpecifiedFingerprintTLSConfig specified fingerprint
  136. func GetSpecifiedFingerprintTLSConfig(tlsConfig *tls.Config, fingerprint string) (*tls.Config, error) {
  137. return GetTLSConfig(tlsConfig, fingerprint, "", "")
  138. }
  139. func GetGlobalTLSConfig(tlsConfig *tls.Config) *tls.Config {
  140. tlsConfig, _ = GetTLSConfig(tlsConfig, "", "", "")
  141. return tlsConfig
  142. }