batch.go 1.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. package batch
  2. import (
  3. "context"
  4. "sync"
  5. )
  6. type Option[T any] func(b *Batch[T])
  // Result is the outcome of one task submitted through Batch.Go: the value
  // the task function returned and the error it returned with it.
  type Result[T any] struct {
  	Value T     // value returned by the task function
  	Err   error // error returned by the task function; nil on success
  }
  // Error describes the first task failure recorded by a Batch.
  //
  // Key is the key that was passed to Batch.Go for the failing task and Err is
  // the error its function returned. *Error implements the error interface and
  // unwraps to Err, so callers can return it as an error and inspect it with
  // errors.Is / errors.As.
  type Error struct {
  	Key string
  	Err error
  }

  // Error formats the failure as "<key>: <err>", or just "<key>" when Err is nil.
  func (e *Error) Error() string {
  	if e.Err == nil {
  		return e.Key
  	}
  	return e.Key + ": " + e.Err.Error()
  }

  // Unwrap returns the underlying task error so errors.Is and errors.As can
  // look through the Error wrapper.
  func (e *Error) Unwrap() error {
  	return e.Err
  }
  15. func WithConcurrencyNum[T any](n int) Option[T] {
  16. return func(b *Batch[T]) {
  17. q := make(chan struct{}, n)
  18. for i := 0; i < n; i++ {
  19. q <- struct{}{}
  20. }
  21. b.queue = q
  22. }
  23. }
  24. // Batch similar to errgroup, but can control the maximum number of concurrent
  25. type Batch[T any] struct {
  26. result map[string]Result[T]
  27. queue chan struct{}
  28. wg sync.WaitGroup
  29. mux sync.Mutex
  30. err *Error
  31. once sync.Once
  32. cancel func()
  33. }
  34. func (b *Batch[T]) Go(key string, fn func() (T, error)) {
  35. b.wg.Add(1)
  36. go func() {
  37. defer b.wg.Done()
  38. if b.queue != nil {
  39. <-b.queue
  40. defer func() {
  41. b.queue <- struct{}{}
  42. }()
  43. }
  44. value, err := fn()
  45. if err != nil {
  46. b.once.Do(func() {
  47. b.err = &Error{key, err}
  48. if b.cancel != nil {
  49. b.cancel()
  50. }
  51. })
  52. }
  53. ret := Result[T]{value, err}
  54. b.mux.Lock()
  55. defer b.mux.Unlock()
  56. b.result[key] = ret
  57. }()
  58. }
  59. func (b *Batch[T]) Wait() *Error {
  60. b.wg.Wait()
  61. if b.cancel != nil {
  62. b.cancel()
  63. }
  64. return b.err
  65. }
  66. func (b *Batch[T]) WaitAndGetResult() (map[string]Result[T], *Error) {
  67. err := b.Wait()
  68. return b.Result(), err
  69. }
  70. func (b *Batch[T]) Result() map[string]Result[T] {
  71. b.mux.Lock()
  72. defer b.mux.Unlock()
  73. copyM := map[string]Result[T]{}
  74. for k, v := range b.result {
  75. copyM[k] = v
  76. }
  77. return copyM
  78. }
  79. func New[T any](ctx context.Context, opts ...Option[T]) (*Batch[T], context.Context) {
  80. ctx, cancel := context.WithCancel(ctx)
  81. b := &Batch[T]{
  82. result: map[string]Result[T]{},
  83. }
  84. for _, o := range opts {
  85. o(b)
  86. }
  87. b.cancel = cancel
  88. return b, ctx
  89. }