icmpping.go 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
  1. package nettools
  2. import (
  3. "bytes"
  4. "errors"
  5. "fmt"
  6. "golang.org/x/net/context"
  7. "golang.org/x/net/icmp"
  8. "golang.org/x/net/ipv4"
  9. "golang.org/x/net/ipv6"
  10. "math/rand"
  11. "net"
  12. "os"
  13. "syscall"
  14. "time"
  15. )
  // IcmpPingResult is the outcome of a single ICMP echo exchange.
  type IcmpPingResult struct {
  	Time int    // round-trip time in milliseconds
  	Err  error  // non-nil when the ping failed
  	IP   net.IP // address reported for the reply
  	TTL  int    // TTL (IPv4) or hop limit (IPv6) of the reply; -1 when unknown
  }

  // Result returns the round-trip time in milliseconds.
  func (icmpR *IcmpPingResult) Result() int { return icmpR.Time }

  // Error returns the failure, or nil when the ping succeeded.
  func (icmpR *IcmpPingResult) Error() error { return icmpR.Err }

  // String renders the error text for a failed ping, and otherwise
  // "<ip>: time=<ms> ms, TTL=<ttl>".
  func (icmpR *IcmpPingResult) String() string {
  	if failure := icmpR.Err; failure != nil {
  		return fmt.Sprintf("%s", failure)
  	}
  	addr := icmpR.IP.String()
  	return fmt.Sprintf("%s: time=%d ms, TTL=%d", addr, icmpR.Time, icmpR.TTL)
  }
  // IcmpPing pings a single host with ICMP echo requests.
  type IcmpPing struct {
  	host string // target exactly as passed to SetHost
  	// Timeout is added to the current time to form the socket deadline
  	// covering both sending the request and waiting for the reply.
  	Timeout time.Duration
  	ip      net.IP // host parsed as an IP literal; nil when host is a name
  	// Privileged selects ping_root (raw "ip" socket) instead of ping_rootless.
  	Privileged bool
  }

  // SetHost sets the target and caches its parsed IP form, if it is a literal.
  func (icmpC *IcmpPing) SetHost(host string) {
  	icmpC.ip = net.ParseIP(host)
  	icmpC.host = host
  }

  // Host returns the target as last passed to SetHost.
  func (icmpC *IcmpPing) Host() string { return icmpC.host }

  // NewIcmpPing returns an unprivileged pinger for host with the given timeout.
  func NewIcmpPing(host string, timeout time.Duration) *IcmpPing {
  	var pinger IcmpPing
  	pinger.Timeout = timeout
  	pinger.SetHost(host)
  	return &pinger
  }
  55. func (icmpC *IcmpPing) Ping() IPingResult {
  56. return icmpC.PingContext(context.Background())
  57. }
  58. func (icmpC *IcmpPing) PingContext(ctx context.Context) IPingResult {
  59. pingfunc := icmpC.ping_rootless
  60. if icmpC.Privileged {
  61. pingfunc = icmpC.ping_root
  62. }
  63. return pingfunc(ctx)
  64. }
  65. func (icmpC *IcmpPing) ping_root(ctx context.Context) IPingResult {
  66. return icmpC.rawping("ip")
  67. }
  68. // https://github.com/sparrc/go-ping/blob/master/ping.go
  69. func (icmpC *IcmpPing) rawping(network string) IPingResult {
  70. // 解析IP
  71. ip, isipv6, err := icmpC.parseip()
  72. if err != nil {
  73. return icmpC.errorResult(err)
  74. }
  75. // 创建连接
  76. conn, err := icmpC.getconn(network, ip, isipv6)
  77. if err != nil {
  78. return icmpC.errorResult(err)
  79. }
  80. defer conn.Close()
  81. conn.SetDeadline(time.Now().Add(icmpC.Timeout))
  82. // 发送
  83. r := rand.New(rand.NewSource(time.Now().UnixNano()))
  84. sendData := make([]byte, 32)
  85. r.Read(sendData)
  86. id := os.Getpid() & 0xffff
  87. sendMsg := icmpC.getmsg(isipv6, id, 0, sendData)
  88. sendMsgBytes, err := sendMsg.Marshal(nil)
  89. if err != nil {
  90. return icmpC.errorResult(err)
  91. }
  92. var dst net.Addr = &net.IPAddr{IP: ip}
  93. if network == "udp" {
  94. dst = &net.UDPAddr{IP: ip}
  95. }
  96. sendAt := time.Now()
  97. for {
  98. if _, err := conn.WriteTo(sendMsgBytes, dst); err != nil {
  99. if neterr, ok := err.(*net.OpError); ok {
  100. if neterr.Err == syscall.ENOBUFS {
  101. continue
  102. }
  103. }
  104. }
  105. break
  106. }
  107. recvBytes := make([]byte, 1500)
  108. recvSize := 0
  109. for {
  110. ttl := -1
  111. var peer net.Addr
  112. if isipv6 {
  113. var cm *ipv6.ControlMessage
  114. recvSize, cm, peer, err = conn.IPv6PacketConn().ReadFrom(recvBytes)
  115. if cm != nil {
  116. ttl = cm.HopLimit
  117. }
  118. } else {
  119. var cm *ipv4.ControlMessage
  120. recvSize, cm, peer, err = conn.IPv4PacketConn().ReadFrom(recvBytes)
  121. if cm != nil {
  122. ttl = cm.TTL
  123. }
  124. }
  125. if err != nil {
  126. return icmpC.errorResult(err)
  127. }
  128. recvAt := time.Now()
  129. recvProto := 1
  130. if isipv6 {
  131. recvProto = 58
  132. }
  133. recvMsg, err := icmp.ParseMessage(recvProto, recvBytes[:recvSize])
  134. if err != nil {
  135. return icmpC.errorResult(err)
  136. }
  137. recvData, recvID, recvType := icmpC.parserecvmsg(isipv6, recvMsg)
  138. // 修正数据长度
  139. if len(recvData) > len(sendData) {
  140. recvData = recvData[len(recvData)-len(sendData):]
  141. }
  142. // 收到的数据和发送的数据不一致,继续接收
  143. if !bytes.Equal(recvData, sendData) {
  144. continue
  145. }
  146. // 是 echo 回复,但 ID 不一致,继续接收
  147. if recvType == 1 && network == "ip" && recvID != id {
  148. continue
  149. }
  150. if peer != nil {
  151. if _ip := net.ParseIP(peer.String()); _ip != nil {
  152. ip = _ip
  153. }
  154. }
  155. switch recvType {
  156. case 1:
  157. // echo
  158. return &IcmpPingResult{
  159. TTL: ttl,
  160. Time: int(recvAt.Sub(sendAt).Milliseconds()),
  161. IP: ip,
  162. }
  163. case 2:
  164. // destination unreachable
  165. return icmpC.errorResult(errors.New(fmt.Sprintf("%s: destination unreachable", ip.String())))
  166. case 3:
  167. // time exceeded
  168. return icmpC.errorResult(errors.New(fmt.Sprintf("%s: time exceeded", ip.String())))
  169. }
  170. }
  171. }
  172. func (icmpC *IcmpPing) parseip() (ip net.IP, ipv6 bool, err error) {
  173. err = nil
  174. ip = cloneIP(icmpC.ip)
  175. if ip == nil {
  176. ip, err = LookupFunc(icmpC.host)
  177. if err != nil {
  178. return
  179. }
  180. }
  181. if isIPv4(ip) {
  182. ipv6 = false
  183. } else if isIPv6(ip) {
  184. ipv6 = true
  185. } else {
  186. err = errors.New("lookup ip failed")
  187. }
  188. return
  189. }
  190. func (icmpC *IcmpPing) getconn(network string, ip net.IP, isipv6 bool) (*icmp.PacketConn, error) {
  191. ipv4Proto := map[string]string{"ip": "ip4:icmp", "udp": "udp4"}
  192. ipv6Proto := map[string]string{"ip": "ip6:ipv6-icmp", "udp": "udp6"}
  193. icmpnetwork := ""
  194. if isipv6 {
  195. icmpnetwork = ipv6Proto[network]
  196. } else {
  197. icmpnetwork = ipv4Proto[network]
  198. }
  199. conn, err := icmp.ListenPacket(icmpnetwork, "")
  200. if err != nil {
  201. return nil, err
  202. }
  203. if isipv6 {
  204. conn.IPv6PacketConn().SetControlMessage(ipv6.FlagHopLimit, true)
  205. } else {
  206. conn.IPv4PacketConn().SetControlMessage(ipv4.FlagTTL, true)
  207. }
  208. return conn, nil
  209. }
  210. func (icmpC *IcmpPing) getmsg(isipv6 bool, id, seq int, data []byte) *icmp.Message {
  211. var msgtype icmp.Type = ipv4.ICMPTypeEcho
  212. if isipv6 {
  213. msgtype = ipv6.ICMPTypeEchoRequest
  214. }
  215. body := &icmp.Echo{
  216. ID: id,
  217. Seq: seq,
  218. Data: data,
  219. }
  220. msg := &icmp.Message{
  221. Type: msgtype,
  222. Code: 0,
  223. Body: body,
  224. }
  225. return msg
  226. }
  227. func (icmpC *IcmpPing) parserecvmsg(isipv6 bool, msg *icmp.Message) (data []byte, id, msgtype int) {
  228. id = 0
  229. data = nil
  230. msgtype = 0
  231. if isipv6 {
  232. switch msg.Type {
  233. case ipv6.ICMPTypeEchoReply:
  234. msgtype = 1
  235. case ipv6.ICMPTypeDestinationUnreachable:
  236. msgtype = 2
  237. case ipv6.ICMPTypeTimeExceeded:
  238. msgtype = 3
  239. }
  240. } else {
  241. switch msg.Type {
  242. case ipv4.ICMPTypeEchoReply:
  243. msgtype = 1
  244. case ipv4.ICMPTypeDestinationUnreachable:
  245. msgtype = 2
  246. case ipv4.ICMPTypeTimeExceeded:
  247. msgtype = 3
  248. }
  249. }
  250. switch msgtype {
  251. case 1:
  252. if tempmsg, ok := msg.Body.(*icmp.Echo); ok {
  253. data = tempmsg.Data
  254. id = tempmsg.ID
  255. }
  256. case 2:
  257. if tempmsg, ok := msg.Body.(*icmp.DstUnreach); ok {
  258. data = tempmsg.Data
  259. }
  260. case 3:
  261. if tempmsg, ok := msg.Body.(*icmp.TimeExceeded); ok {
  262. data = tempmsg.Data
  263. }
  264. }
  265. return
  266. }
  267. func (icmpC *IcmpPing) errorResult(err error) IPingResult {
  268. r := &IcmpPingResult{}
  269. r.Err = err
  270. return r
  271. }
  272. var (
  273. _ IPing = (*IcmpPing)(nil)
  274. _ IPingResult = (*IcmpPingResult)(nil)
  275. )