package trie

import (
	"errors"
	"strings"
)

const (
	// wildcard is a label that matches exactly one domain label,
	// e.g. "*.example.com" or "sub.*.example.com".
	wildcard = "*"
	// dotWildcard is the label stored for a leading dot (".example.com");
	// during search it matches any subdomain below the remaining suffix.
	dotWildcard = ""
	// complexWildcard ("+.example.com") is expanded by Insert into both the
	// bare suffix ("example.com") and its dotWildcard form (".example.com").
	complexWildcard = "+"
	// domainStep separates the labels of a domain name.
	domainStep = "."
)

var (
	// ErrInvalidDomain is returned by Insert when the domain cannot be split
	// into valid labels (see ValidAndSplitDomain).
	ErrInvalidDomain = errors.New("invalid domain")
)

// DomainTrie maps domain names to values of type T and understands wildcard
// patterns such as "*.google.com", ".google.com" and "+.google.com".
//
// Labels are stored right-to-left (last label first), so domains that share
// a suffix share trie nodes. Use New to obtain an initialized trie; the zero
// value has a nil root.
type DomainTrie[T any] struct {
	root *Node[T] // entry node; its children are the last labels of stored domains
}

// ValidAndSplitDomain lower-cases domain and splits it into its
// dot-separated labels.
//
// It reports false when the domain is empty, ends with a dot, or contains an
// empty label anywhere except the first position. A leading empty label is
// accepted so that dot-wildcard patterns like ".example.com" can be expressed.
func ValidAndSplitDomain(domain string) ([]string, bool) {
	if strings.HasSuffix(domain, domainStep) {
		return nil, false
	}

	labels := strings.Split(strings.ToLower(domain), domainStep)
	if len(labels) == 1 && labels[0] == "" {
		return nil, false
	}

	// Only the first label may be empty (the leading-dot wildcard form).
	for i := 1; i < len(labels); i++ {
		if labels[i] == "" {
			return nil, false
		}
	}

	return labels, true
}

// Insert stores data under domain, returning ErrInvalidDomain when the
// domain does not pass ValidAndSplitDomain.
//
// Supported patterns:
//  1. www.example.com          - exact domain
//  2. *.example.com            - "*" stands for exactly one label
//  3. subdomain.*.example.com  - "*" may also appear in the middle
//  4. .example.com             - any subdomain of example.com, not example.com itself
//  5. +.example.com            - example.com and any of its subdomains
func (t *DomainTrie[T]) Insert(domain string, data T) error {
	parts, ok := ValidAndSplitDomain(domain)
	if !ok {
		return ErrInvalidDomain
	}

	if parts[0] != complexWildcard {
		t.insert(parts, data)
		return nil
	}

	// "+.example.com" is shorthand for both "example.com" and ".example.com".
	t.insert(parts[1:], data)
	parts[0] = dotWildcard
	t.insert(parts, data)
	return nil
}

// insert walks (creating as needed) the path for parts, last label first,
// and stores data on the final node.
func (t *DomainTrie[T]) insert(parts []string, data T) {
	cur := t.root
	// Storing labels right-to-left lets domains with a common suffix share nodes.
	for idx := len(parts); idx > 0; idx-- {
		cur = cur.getOrNewChild(parts[idx-1])
	}
	cur.setData(data)
}

// Search returns the node holding data for the pattern that best matches
// domain, or nil when nothing matches or the domain is invalid. Queries that
// start with a dot are rejected.
//
// At each label, candidates are tried in this priority order:
//  1. the exact (static) label
//  2. the wildcard "*"
//  3. the dot wildcard (".example.com" style)
func (t *DomainTrie[T]) Search(domain string) *Node[T] {
	parts, ok := ValidAndSplitDomain(domain)
	if !ok || parts[0] == "" {
		return nil
	}

	if found := t.search(t.root, parts); !found.isEmpty() {
		return found
	}
	return nil
}

// search consumes parts from the right, starting at node. It first tries the
// exact label and then "*", recursing on the remaining labels. If neither
// yields a non-empty node, it falls back to node's dot-wildcard child, which
// may be nil.
func (t *DomainTrie[T]) search(node *Node[T], parts []string) *Node[T] {
	last := len(parts) - 1
	if last < 0 {
		return node
	}

	rest := parts[:last]
	// Order matters: an exact label beats a single-label wildcard.
	for _, key := range [...]string{parts[last], wildcard} {
		child := node.getChild(key)
		if child == nil {
			continue
		}
		if found := t.search(child, rest); !found.isEmpty() {
			return found
		}
	}

	return node.getChild(dotWildcard)
}

// Optimize delegates to the root node's optimize method. Its exact effect is
// defined on Node (not visible here). Presumably it compacts node storage
// once inserts are finished; TODO confirm against Node.optimize.
func (t *DomainTrie[T]) Optimize() {
	root := t.root
	root.optimize()
}

// Foreach calls fn once for every domain that holds data, passing the domain
// name and its stored value. Iteration stops as soon as fn returns false,
// including when that happens deep inside a subtree. No particular visiting
// order is guaranteed by this function.
func (t *DomainTrie[T]) Foreach(fn func(domain string, data T) bool) {
	for key, data := range t.root.getChildren() {
		// Visit everything below this top-level label first. recursion reports
		// false when fn asked to stop, and that request must be honored here
		// instead of moving on to the next label.
		if !recursion([]string{key}, data, fn) {
			return
		}
		if data != nil && data.inited {
			// A single label needs no joining; the domain is the key itself.
			if !fn(key, data.data) {
				return
			}
		}
	}
}

// recursion visits every descendant of node. items holds the labels already
// walked (in domain order), and fn is invoked for each descendant that holds
// data. Names that begin with a dot (the dot-wildcard label) are reported with
// a "+" prefix. It returns false as soon as fn returns false, otherwise true.
func recursion[T any](items []string, node *Node[T], fn func(domain string, data T) bool) bool {
	for label, child := range node.getChildren() {
		// Prepend the child's label: the trie is stored right-to-left.
		path := make([]string, 0, len(items)+1)
		path = append(path, label)
		path = append(path, items...)

		if child != nil && child.inited {
			name := joinDomain(path)
			if name[0] == domainStepByte {
				name = complexWildcard + name
			}
			if keepGoing := fn(name, child.Data()); !keepGoing {
				return false
			}
		}

		if !recursion(path, child, fn) {
			return false
		}
	}
	return true
}

// joinDomain glues labels back into a dotted domain name, e.g.
// ["www", "example", "com"] becomes "www.example.com".
func joinDomain(labels []string) string {
	return strings.Join(labels, domainStep)
}

// New returns an empty DomainTrie with an initialized root node, ready for
// Insert and Search.
func New[T any]() *DomainTrie[T] {
	trie := &DomainTrie[T]{}
	trie.root = newNode[T]()
	return trie
}