Source Code Reading: pond

Let's start with the API documentation. It shows that pond has only a handful of core data structures and interfaces.

Global Variable Definitions

Opening pool.go, we first find the following global definitions:

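The first block holds default values; the second holds error definitions. This is a good habit: explicitly enumerating every error the system can produce makes it easy for callers to check for a specific error, or to wrap it and pass it on.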

const (
	DefaultQueueSize        = 0
	DefaultNonBlocking      = false
	LinkedBufferInitialSize = 1024
	LinkedBufferMaxCapacity = 100 * 1024
)

var (
	ErrQueueFull             = errors.New("queue is full")
	ErrQueueEmpty            = errors.New("queue is empty")
	ErrPoolStopped           = errors.New("pool stopped")
	ErrMaxConcurrencyReached = errors.New("max concurrency reached")

	poolStoppedFuture = func() Task {
		future, resolve := future.NewFuture(context.Background())
		resolve(ErrPoolStopped)
		return future
	}()
)

The Library's Entry Point: Pool

The library defines two interfaces that together break down the concept of a "goroutine pool":

// basePool is the base interface for all pool types.
type basePool interface {
	// Returns the number of worker goroutines that are currently active (executing a task) in the pool.
	RunningWorkers() int64

	// Returns the total number of tasks submitted to the pool since its creation.
	SubmittedTasks() uint64

	// Returns the number of tasks that are currently waiting in the pool's queue.
	WaitingTasks() uint64

	// Returns the number of tasks that have completed with an error.
	FailedTasks() uint64

	// Returns the number of tasks that have completed successfully.
	SuccessfulTasks() uint64

	// Returns the total number of tasks that have completed (either successfully or with an error).
	CompletedTasks() uint64

	// Returns the maximum concurrency of the pool.
	MaxConcurrency() int

	// Returns the size of the task queue.
	QueueSize() int

	// Returns true if the pool is non-blocking, meaning that it will not block when the task queue is full.
	// In a non-blocking pool, tasks that cannot be submitted to the queue will be dropped.
	// By default, pools are blocking, meaning that they will block when the task queue is full.
	NonBlocking() bool

	// Returns the context associated with this pool.
	Context() context.Context

	// Stops the pool and returns a future that can be used to wait for all tasks pending to complete.
	Stop() Task

	// Stops the pool and waits for all tasks to complete.
	StopAndWait()

	// Returns true if the pool has been stopped or its context has been cancelled.
	Stopped() bool

	// Resizes the pool by changing the maximum concurrency (number of workers) of the pool.
	// The new max concurrency must be greater than 0.
	// If the new max concurrency is less than the current number of running workers, the pool will continue to run with the new max concurrency.
	Resize(maxConcurrency int)
}

// Represents a pool of goroutines that can execute tasks concurrently.
type Pool interface {
	basePool

	// Submits a task to the pool without waiting for it to complete.
	Go(task func()) error

	// Submits a task to the pool and returns a future that can be used to wait for the task to complete.
	Submit(task func()) Task

	// Submits a task to the pool and returns a future that can be used to wait for the task to complete.
	SubmitErr(task func() error) Task

	// Creates a new subpool with the specified maximum concurrency and options.
	NewSubpool(maxConcurrency int, options ...Option) Pool

	// Creates a new task group.
	NewGroup() TaskGroup

	// Creates a new task group with the specified context.
	NewGroupContext(ctx context.Context) TaskGroup
}

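basePool defines the pool-wide methods: querying state and configuration, plus lifecycle control such as stopping and resizing. Pool, in turn, defines the "actions", such as submitting tasks and creating task groups.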

Why split the interfaces this way?

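Because there are two kinds of pools, Pool and ResultPool. Their task-submission methods take different parameters, but every other method is the same.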
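A minimal sketch of why the split pays off. The method sets are heavily trimmed, and the types stats, pool and resultPool are hypothetical toys whose Submit runs the task inline; in pond the submission methods instead return Task / Result[R] futures. The point is only that the shared basePool methods are declared and implemented once, while each pool type gets its own Submit signature.

```go
package main

import "fmt"

// basePool: methods shared by every pool type, independent of the task signature.
type basePool interface {
	SubmittedTasks() uint64
}

// Pool: tasks return nothing.
type Pool interface {
	basePool
	Submit(task func())
}

// ResultPool: tasks return a value of type R, so Submit needs a different signature.
type ResultPool[R any] interface {
	basePool
	Submit(task func() R) R
}

// stats is embedded by both toy implementations, so the shared method is written once.
type stats struct{ submitted uint64 }

func (s *stats) SubmittedTasks() uint64 { return s.submitted }

// pool is a toy Pool that runs each task inline.
type pool struct{ stats }

func (p *pool) Submit(task func()) { p.submitted++; task() }

// resultPool is a toy ResultPool that runs each task inline and returns its value.
type resultPool[R any] struct{ stats }

func (p *resultPool[R]) Submit(task func() R) R { p.submitted++; return task() }

func main() {
	var p Pool = &pool{}
	p.Submit(func() { fmt.Println("hello") })

	var rp ResultPool[int] = &resultPool[int]{}
	fmt.Println(rp.Submit(func() int { return 21 * 2 })) // 42

	// Both satisfy basePool, so monitoring code can treat them uniformly.
	for _, bp := range []basePool{p, rp} {
		fmt.Println(bp.SubmittedTasks()) // 1
	}
}
```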

It then defines an internal struct, pool, which implements all of these methods.

Pool Initialization

In NewPool (which delegates to newPool), we find the following logic:

func newPool(maxConcurrency int, parent *pool, options ...Option) *pool {

	// ...

	pool := &pool{
		ctx:            context.Background(),
		nonBlocking:    DefaultNonBlocking,
		maxConcurrency: maxConcurrency,
		queueSize:      DefaultQueueSize,
		submitWaiters:  make(chan struct{}),
	}

	// ...

	for _, option := range options {
		option(pool)
	}

	// ...

	return pool
}

func NewPool(maxConcurrency int, options ...Option) Pool {
	return newPool(maxConcurrency, nil, options...)
}

Following this into pooloptions.go, we find Option and the related options defined as follows:

type Option func(*pool)

// WithContext sets the context for the pool.
func WithContext(ctx context.Context) Option {
	return func(p *pool) {
		p.ctx = ctx
	}
}

// WithQueueSize sets the max number of elements that can be queued in the pool.
func WithQueueSize(size int) Option {
	return func(p *pool) {
		p.queueSize = size
	}
}

// WithNonBlocking sets the pool to be non-blocking when the queue is full.
// This option is only effective when the queue size is set.
func WithNonBlocking(nonBlocking bool) Option {
	return func(p *pool) {
		p.nonBlocking = nonBlocking
	}
}

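Sharp-eyed readers may recognize this at once as the Builder design pattern, in a Go-native form (in the Go community it is usually called the functional options pattern). Because functions are first-class citizens in Go, the builder can be reduced to this more concise form, which avoids the telescoping constructor problem.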

Inside pool

type pool struct {
	mutex               sync.Mutex
	parent              *pool
	ctx                 context.Context
	cancel              context.CancelCauseFunc
	nonBlocking         bool
	maxConcurrency      int
	closed              atomic.Bool
	workerCount         atomic.Int64
	workerWaitGroup     sync.WaitGroup
	submitWaiters       chan struct{}
	queueSize           int
	tasks               *linkedbuffer.LinkedBuffer[any]
	submittedTaskCount  atomic.Uint64
	successfulTaskCount atomic.Uint64
	failedTaskCount     atomic.Uint64
}

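Only a few key fields are covered here, starting with tasks, the task queue, whose type is *linkedbuffer.LinkedBuffer[any]: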

// buffer implements a generic buffer that can store any type of data.
// It is not thread-safe and should be used with a mutex.
// It is used by LinkedBuffer to store data and is not intended to be used directly.
type buffer[T any] struct {
	data           []T
	nextWriteIndex int
	nextReadIndex  int
	next           *buffer[T]
}

// LinkedBuffer implements an unbounded generic buffer that can be written to and read from concurrently.
// It is implemented using a linked list of buffers.
type LinkedBuffer[T any] struct {
	// Reader points to the buffer that is currently being read
	readBuffer *buffer[T]

	// Writer points to the buffer that is currently being written
	writeBuffer *buffer[T]

	maxCapacity int
	writeCount  atomic.Uint64
	readCount   atomic.Uint64
}

func NewLinkedBuffer[T any](initialCapacity, maxCapacity int) *LinkedBuffer[T] {
	initialBuffer := newBuffer[T](initialCapacity)

	buffer := &LinkedBuffer[T]{
		readBuffer:  initialBuffer,
		writeBuffer: initialBuffer,
		maxCapacity: maxCapacity,
	}

	return buffer
}

At first glance the definitions may look odd: why does buffer keep its own read and write indices, and why does it use next to point to another buffer?

Next, let's look at LinkedBuffer's methods:

// Write writes values to the buffer
func (b *LinkedBuffer[T]) Write(value T) {

	// Write elements
	err := b.writeBuffer.Write(value)

	if err == ErrEOF {
		// Increase next buffer capacity
		var newCapacity int
		capacity := b.writeBuffer.Cap()
		if capacity < 1024 {
			newCapacity = capacity * 2
		} else {
			newCapacity = capacity + capacity/2
		}
		if newCapacity > b.maxCapacity {
			newCapacity = b.maxCapacity
		}

		if b.writeBuffer.next == nil {
			b.writeBuffer.next = newBuffer[T](newCapacity)
			b.writeBuffer = b.writeBuffer.next
		}

		// Retry writing
		b.Write(value)
		return
	}

	// Increment written count
	b.writeCount.Add(1)
}

// Read reads values from the buffer and returns the number of elements read
func (b *LinkedBuffer[T]) Read() (value T, err error) {
	// Read element
	value, err = b.readBuffer.Read()

	if err == ErrEOF {
		if b.readBuffer.next == nil {
			// No more elements to read
			return
		}
		// Move to next read buffer
		if b.readBuffer != b.readBuffer.next {
			b.readBuffer = b.readBuffer.next
		}

		// Retry reading
		return b.Read()
	}

	// Increment read count
	b.readCount.Add(1)

	return
}

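As you can see, when the current buffer is "full", Write allocates a new buffer segment and advances writeBuffer to it, without touching the current readBuffer at all; likewise, once Read has drained the current segment, it moves on to the next one. This answers the earlier question: each segment keeps its own read and write indices so that reading and writing can happen in different segments, and next chains the segments into a queue that grows without ever copying existing elements. Each new segment's capacity is double the previous one below 1024 and 1.5 times it from 1024 on, capped at maxCapacity.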

The Biggest Change in v2: Task

Anyone who has used pond v1 knows its biggest limitation: you could not wait on one specific task. v2 therefore introduces Task and Result:

// Task represents a task that can be waited on. If the task fails, the error can be retrieved.
type Task interface {

	// Done returns a channel that is closed when the task is complete or has failed.
	Done() <-chan struct{}

	// Wait waits for the task to complete and returns any error that occurred.
	Wait() error
}

// Result represents a task that yields a result. If the task fails, the error can be retrieved.
type Result[R any] interface {

	// Done returns a channel that is closed when the task is complete or has failed.
	Done() <-chan struct{}

	// Wait waits for the task to complete and returns the result and any error that occurred.
	Wait() (R, error)
}

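Now you can explicitly call Wait to wait for a particular task to finish and retrieve its result.

Under the hood, this is built on context, which is used to deliver the result.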

A Questionable Change

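v2 removes the TrySubmit method. Instead, in non-blocking mode, a task passed to Submit fails immediately when the queue is full, and the error, queue is full, can only be retrieved through task.Wait().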

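In my opinion this change is a mistake. The queue is full error is produced by the act of submitting, so Submit itself should return it directly. Reporting a submission failure through the task's result muddles the semantics.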

It is still possible to meet the requirement "return an error when the task queue is full", but the code becomes rather contorted:

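package main

import (
	"errors"
	"fmt"
	"time"

	"github.com/alitto/pond/v2"
)

func main() {
	pool := pond.NewPool(1, pond.WithQueueSize(5), pond.WithNonBlocking(true))
	for i := range 10 { // range over int needs Go 1.22+, where i is already per-iteration
		task := pool.Submit(func() {
			fmt.Printf("Running task #%d\n", i)
			time.Sleep(time.Second)
		})
		// A task rejected because the queue is full comes back already failed,
		// so the rejection can only be detected by inspecting the task's result.
		select {
		case <-task.Done():
			if err := task.Wait(); errors.Is(err, pond.ErrQueueFull) {
				fmt.Printf("Error: %v\n", err)
			}
		default:
			// Not done yet: the task was accepted and is still queued or running.
		}
	}
	pool.StopAndWait()
}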

I have opened issue #103 to discuss this.