123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163 |
- package trie
- import (
- "errors"
- "strings"
- )
// Special labels and the separator understood by the trie.
const (
	// domainStep separates the labels of a domain name.
	domainStep = "."
	// wildcard is the label that stands in for one arbitrary label,
	// as in "*.example.com".
	wildcard = "*"
	// complexWildcard, when it is the first label ("+.example.com"), makes
	// Insert store both the bare domain and its leading-dot form.
	complexWildcard = "+"
	// dotWildcard is the empty first label that a leading-dot domain such as
	// ".example.com" produces when split.
	dotWildcard = ""
)
// ErrInvalidDomain is returned by Insert when the domain to be inserted does
// not pass validation.
var ErrInvalidDomain = errors.New("invalid domain")
- // DomainTrie contains the main logic for adding and searching nodes for domain segments.
- // support wildcard domain (e.g *.google.com)
- type DomainTrie[T any] struct {
- root *Node[T]
- }
- func ValidAndSplitDomain(domain string) ([]string, bool) {
- if domain != "" && domain[len(domain)-1] == '.' {
- return nil, false
- }
- domain = strings.ToLower(domain)
- parts := strings.Split(domain, domainStep)
- if len(parts) == 1 {
- if parts[0] == "" {
- return nil, false
- }
- return parts, true
- }
- for _, part := range parts[1:] {
- if part == "" {
- return nil, false
- }
- }
- return parts, true
- }
- // Insert adds a node to the trie.
- // Support
- // 1. www.example.com
- // 2. *.example.com
- // 3. subdomain.*.example.com
- // 4. .example.com
- // 5. +.example.com
- func (t *DomainTrie[T]) Insert(domain string, data T) error {
- parts, valid := ValidAndSplitDomain(domain)
- if !valid {
- return ErrInvalidDomain
- }
- if parts[0] == complexWildcard {
- t.insert(parts[1:], data)
- parts[0] = dotWildcard
- t.insert(parts, data)
- } else {
- t.insert(parts, data)
- }
- return nil
- }
- func (t *DomainTrie[T]) insert(parts []string, data T) {
- node := t.root
- // reverse storage domain part to save space
- for i := len(parts) - 1; i >= 0; i-- {
- part := parts[i]
- node = node.getOrNewChild(part)
- }
- node.setData(data)
- }
- // Search is the most important part of the Trie.
- // Priority as:
- // 1. static part
- // 2. wildcard domain
- // 2. dot wildcard domain
- func (t *DomainTrie[T]) Search(domain string) *Node[T] {
- parts, valid := ValidAndSplitDomain(domain)
- if !valid || parts[0] == "" {
- return nil
- }
- n := t.search(t.root, parts)
- if n.isEmpty() {
- return nil
- }
- return n
- }
- func (t *DomainTrie[T]) search(node *Node[T], parts []string) *Node[T] {
- if len(parts) == 0 {
- return node
- }
- if c := node.getChild(parts[len(parts)-1]); c != nil {
- if n := t.search(c, parts[:len(parts)-1]); !n.isEmpty() {
- return n
- }
- }
- if c := node.getChild(wildcard); c != nil {
- if n := t.search(c, parts[:len(parts)-1]); !n.isEmpty() {
- return n
- }
- }
- return node.getChild(dotWildcard)
- }
- func (t *DomainTrie[T]) Optimize() {
- t.root.optimize()
- }
- func (t *DomainTrie[T]) Foreach(fn func(domain string, data T) bool) {
- for key, data := range t.root.getChildren() {
- recursion([]string{key}, data, fn)
- if data != nil && data.inited {
- if !fn(joinDomain([]string{key}), data.data) {
- return
- }
- }
- }
- }
- func recursion[T any](items []string, node *Node[T], fn func(domain string, data T) bool) bool {
- for key, data := range node.getChildren() {
- newItems := append([]string{key}, items...)
- if data != nil && data.inited {
- domain := joinDomain(newItems)
- if domain[0] == domainStepByte {
- domain = complexWildcard + domain
- }
- if !fn(domain, data.Data()) {
- return false
- }
- }
- if !recursion(newItems, data, fn) {
- return false
- }
- }
- return true
- }
- func joinDomain(items []string) string {
- return strings.Join(items, domainStep)
- }
- // New returns a new, empty Trie.
- func New[T any]() *DomainTrie[T] {
- return &DomainTrie[T]{root: newNode[T]()}
- }
|