123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557 |
- package pool
- import (
- "context"
- "errors"
- "net"
- "sync"
- "sync/atomic"
- "time"
- "github.com/go-redis/redis/v8/internal"
- )
// ErrClosed is returned by any operation performed on a pool (and therefore
// on the client that owns it) after the pool has been closed.
var ErrClosed = errors.New("redis: client is closed")

// ErrPoolTimeout is returned when no connection slot becomes available in the
// pool within Options.PoolTimeout.
var ErrPoolTimeout = errors.New("redis: connection pool timeout")
// timers recycles *time.Timer values used by waitTurn. Freshly created timers
// are already stopped, so a caller must Reset a timer before waiting on it.
var timers = sync.Pool{
	New: func() interface{} {
		timer := time.NewTimer(time.Hour)
		// The hour-long duration is only a placeholder; stop the timer at
		// once so it never fires before the first Reset.
		timer.Stop()
		return timer
	},
}
- // Stats contains pool state information and accumulated stats.
- type Stats struct {
- Hits uint32 // number of times free connection was found in the pool
- Misses uint32 // number of times free connection was NOT found in the pool
- Timeouts uint32 // number of times a wait timeout occurred
- TotalConns uint32 // number of total connections in the pool
- IdleConns uint32 // number of idle connections in the pool
- StaleConns uint32 // number of stale connections removed from the pool
- }
- type Pooler interface {
- NewConn(context.Context) (*Conn, error)
- CloseConn(*Conn) error
- Get(context.Context) (*Conn, error)
- Put(context.Context, *Conn)
- Remove(context.Context, *Conn, error)
- Len() int
- IdleLen() int
- Stats() *Stats
- Close() error
- }
- type Options struct {
- Dialer func(context.Context) (net.Conn, error)
- OnClose func(*Conn) error
- PoolFIFO bool
- PoolSize int
- MinIdleConns int
- MaxConnAge time.Duration
- PoolTimeout time.Duration
- IdleTimeout time.Duration
- IdleCheckFrequency time.Duration
- }
- type lastDialErrorWrap struct {
- err error
- }
- type ConnPool struct {
- opt *Options
- dialErrorsNum uint32 // atomic
- lastDialError atomic.Value
- queue chan struct{}
- connsMu sync.Mutex
- conns []*Conn
- idleConns []*Conn
- poolSize int
- idleConnsLen int
- stats Stats
- _closed uint32 // atomic
- closedCh chan struct{}
- }
- var _ Pooler = (*ConnPool)(nil)
- func NewConnPool(opt *Options) *ConnPool {
- p := &ConnPool{
- opt: opt,
- queue: make(chan struct{}, opt.PoolSize),
- conns: make([]*Conn, 0, opt.PoolSize),
- idleConns: make([]*Conn, 0, opt.PoolSize),
- closedCh: make(chan struct{}),
- }
- p.connsMu.Lock()
- p.checkMinIdleConns()
- p.connsMu.Unlock()
- if opt.IdleTimeout > 0 && opt.IdleCheckFrequency > 0 {
- go p.reaper(opt.IdleCheckFrequency)
- }
- return p
- }
- func (p *ConnPool) checkMinIdleConns() {
- if p.opt.MinIdleConns == 0 {
- return
- }
- for p.poolSize < p.opt.PoolSize && p.idleConnsLen < p.opt.MinIdleConns {
- p.poolSize++
- p.idleConnsLen++
- go func() {
- err := p.addIdleConn()
- if err != nil && err != ErrClosed {
- p.connsMu.Lock()
- p.poolSize--
- p.idleConnsLen--
- p.connsMu.Unlock()
- }
- }()
- }
- }
- func (p *ConnPool) addIdleConn() error {
- cn, err := p.dialConn(context.TODO(), true)
- if err != nil {
- return err
- }
- p.connsMu.Lock()
- defer p.connsMu.Unlock()
- // It is not allowed to add new connections to the closed connection pool.
- if p.closed() {
- _ = cn.Close()
- return ErrClosed
- }
- p.conns = append(p.conns, cn)
- p.idleConns = append(p.idleConns, cn)
- return nil
- }
- func (p *ConnPool) NewConn(ctx context.Context) (*Conn, error) {
- return p.newConn(ctx, false)
- }
- func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) {
- cn, err := p.dialConn(ctx, pooled)
- if err != nil {
- return nil, err
- }
- p.connsMu.Lock()
- defer p.connsMu.Unlock()
- // It is not allowed to add new connections to the closed connection pool.
- if p.closed() {
- _ = cn.Close()
- return nil, ErrClosed
- }
- p.conns = append(p.conns, cn)
- if pooled {
- // If pool is full remove the cn on next Put.
- if p.poolSize >= p.opt.PoolSize {
- cn.pooled = false
- } else {
- p.poolSize++
- }
- }
- return cn, nil
- }
- func (p *ConnPool) dialConn(ctx context.Context, pooled bool) (*Conn, error) {
- if p.closed() {
- return nil, ErrClosed
- }
- if atomic.LoadUint32(&p.dialErrorsNum) >= uint32(p.opt.PoolSize) {
- return nil, p.getLastDialError()
- }
- netConn, err := p.opt.Dialer(ctx)
- if err != nil {
- p.setLastDialError(err)
- if atomic.AddUint32(&p.dialErrorsNum, 1) == uint32(p.opt.PoolSize) {
- go p.tryDial()
- }
- return nil, err
- }
- cn := NewConn(netConn)
- cn.pooled = pooled
- return cn, nil
- }
- func (p *ConnPool) tryDial() {
- for {
- if p.closed() {
- return
- }
- conn, err := p.opt.Dialer(context.Background())
- if err != nil {
- p.setLastDialError(err)
- time.Sleep(time.Second)
- continue
- }
- atomic.StoreUint32(&p.dialErrorsNum, 0)
- _ = conn.Close()
- return
- }
- }
- func (p *ConnPool) setLastDialError(err error) {
- p.lastDialError.Store(&lastDialErrorWrap{err: err})
- }
- func (p *ConnPool) getLastDialError() error {
- err, _ := p.lastDialError.Load().(*lastDialErrorWrap)
- if err != nil {
- return err.err
- }
- return nil
- }
- // Get returns existed connection from the pool or creates a new one.
- func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
- if p.closed() {
- return nil, ErrClosed
- }
- if err := p.waitTurn(ctx); err != nil {
- return nil, err
- }
- for {
- p.connsMu.Lock()
- cn, err := p.popIdle()
- p.connsMu.Unlock()
- if err != nil {
- return nil, err
- }
- if cn == nil {
- break
- }
- if p.isStaleConn(cn) {
- _ = p.CloseConn(cn)
- continue
- }
- atomic.AddUint32(&p.stats.Hits, 1)
- return cn, nil
- }
- atomic.AddUint32(&p.stats.Misses, 1)
- newcn, err := p.newConn(ctx, true)
- if err != nil {
- p.freeTurn()
- return nil, err
- }
- return newcn, nil
- }
- func (p *ConnPool) getTurn() {
- p.queue <- struct{}{}
- }
- func (p *ConnPool) waitTurn(ctx context.Context) error {
- select {
- case <-ctx.Done():
- return ctx.Err()
- default:
- }
- select {
- case p.queue <- struct{}{}:
- return nil
- default:
- }
- timer := timers.Get().(*time.Timer)
- timer.Reset(p.opt.PoolTimeout)
- select {
- case <-ctx.Done():
- if !timer.Stop() {
- <-timer.C
- }
- timers.Put(timer)
- return ctx.Err()
- case p.queue <- struct{}{}:
- if !timer.Stop() {
- <-timer.C
- }
- timers.Put(timer)
- return nil
- case <-timer.C:
- timers.Put(timer)
- atomic.AddUint32(&p.stats.Timeouts, 1)
- return ErrPoolTimeout
- }
- }
- func (p *ConnPool) freeTurn() {
- <-p.queue
- }
- func (p *ConnPool) popIdle() (*Conn, error) {
- if p.closed() {
- return nil, ErrClosed
- }
- n := len(p.idleConns)
- if n == 0 {
- return nil, nil
- }
- var cn *Conn
- if p.opt.PoolFIFO {
- cn = p.idleConns[0]
- copy(p.idleConns, p.idleConns[1:])
- p.idleConns = p.idleConns[:n-1]
- } else {
- idx := n - 1
- cn = p.idleConns[idx]
- p.idleConns = p.idleConns[:idx]
- }
- p.idleConnsLen--
- p.checkMinIdleConns()
- return cn, nil
- }
- func (p *ConnPool) Put(ctx context.Context, cn *Conn) {
- if cn.rd.Buffered() > 0 {
- internal.Logger.Printf(ctx, "Conn has unread data")
- p.Remove(ctx, cn, BadConnError{})
- return
- }
- if !cn.pooled {
- p.Remove(ctx, cn, nil)
- return
- }
- p.connsMu.Lock()
- p.idleConns = append(p.idleConns, cn)
- p.idleConnsLen++
- p.connsMu.Unlock()
- p.freeTurn()
- }
- func (p *ConnPool) Remove(ctx context.Context, cn *Conn, reason error) {
- p.removeConnWithLock(cn)
- p.freeTurn()
- _ = p.closeConn(cn)
- }
- func (p *ConnPool) CloseConn(cn *Conn) error {
- p.removeConnWithLock(cn)
- return p.closeConn(cn)
- }
- func (p *ConnPool) removeConnWithLock(cn *Conn) {
- p.connsMu.Lock()
- p.removeConn(cn)
- p.connsMu.Unlock()
- }
- func (p *ConnPool) removeConn(cn *Conn) {
- for i, c := range p.conns {
- if c == cn {
- p.conns = append(p.conns[:i], p.conns[i+1:]...)
- if cn.pooled {
- p.poolSize--
- p.checkMinIdleConns()
- }
- return
- }
- }
- }
- func (p *ConnPool) closeConn(cn *Conn) error {
- if p.opt.OnClose != nil {
- _ = p.opt.OnClose(cn)
- }
- return cn.Close()
- }
- // Len returns total number of connections.
- func (p *ConnPool) Len() int {
- p.connsMu.Lock()
- n := len(p.conns)
- p.connsMu.Unlock()
- return n
- }
- // IdleLen returns number of idle connections.
- func (p *ConnPool) IdleLen() int {
- p.connsMu.Lock()
- n := p.idleConnsLen
- p.connsMu.Unlock()
- return n
- }
- func (p *ConnPool) Stats() *Stats {
- idleLen := p.IdleLen()
- return &Stats{
- Hits: atomic.LoadUint32(&p.stats.Hits),
- Misses: atomic.LoadUint32(&p.stats.Misses),
- Timeouts: atomic.LoadUint32(&p.stats.Timeouts),
- TotalConns: uint32(p.Len()),
- IdleConns: uint32(idleLen),
- StaleConns: atomic.LoadUint32(&p.stats.StaleConns),
- }
- }
- func (p *ConnPool) closed() bool {
- return atomic.LoadUint32(&p._closed) == 1
- }
- func (p *ConnPool) Filter(fn func(*Conn) bool) error {
- p.connsMu.Lock()
- defer p.connsMu.Unlock()
- var firstErr error
- for _, cn := range p.conns {
- if fn(cn) {
- if err := p.closeConn(cn); err != nil && firstErr == nil {
- firstErr = err
- }
- }
- }
- return firstErr
- }
- func (p *ConnPool) Close() error {
- if !atomic.CompareAndSwapUint32(&p._closed, 0, 1) {
- return ErrClosed
- }
- close(p.closedCh)
- var firstErr error
- p.connsMu.Lock()
- for _, cn := range p.conns {
- if err := p.closeConn(cn); err != nil && firstErr == nil {
- firstErr = err
- }
- }
- p.conns = nil
- p.poolSize = 0
- p.idleConns = nil
- p.idleConnsLen = 0
- p.connsMu.Unlock()
- return firstErr
- }
- func (p *ConnPool) reaper(frequency time.Duration) {
- ticker := time.NewTicker(frequency)
- defer ticker.Stop()
- for {
- select {
- case <-ticker.C:
- // It is possible that ticker and closedCh arrive together,
- // and select pseudo-randomly pick ticker case, we double
- // check here to prevent being executed after closed.
- if p.closed() {
- return
- }
- _, err := p.ReapStaleConns()
- if err != nil {
- internal.Logger.Printf(context.Background(), "ReapStaleConns failed: %s", err)
- continue
- }
- case <-p.closedCh:
- return
- }
- }
- }
- func (p *ConnPool) ReapStaleConns() (int, error) {
- var n int
- for {
- p.getTurn()
- p.connsMu.Lock()
- cn := p.reapStaleConn()
- p.connsMu.Unlock()
- p.freeTurn()
- if cn != nil {
- _ = p.closeConn(cn)
- n++
- } else {
- break
- }
- }
- atomic.AddUint32(&p.stats.StaleConns, uint32(n))
- return n, nil
- }
- func (p *ConnPool) reapStaleConn() *Conn {
- if len(p.idleConns) == 0 {
- return nil
- }
- cn := p.idleConns[0]
- if !p.isStaleConn(cn) {
- return nil
- }
- p.idleConns = append(p.idleConns[:0], p.idleConns[1:]...)
- p.idleConnsLen--
- p.removeConn(cn)
- return cn
- }
- func (p *ConnPool) isStaleConn(cn *Conn) bool {
- if p.opt.IdleTimeout == 0 && p.opt.MaxConnAge == 0 {
- return false
- }
- now := time.Now()
- if p.opt.IdleTimeout > 0 && now.Sub(cn.UsedAt()) >= p.opt.IdleTimeout {
- return true
- }
- if p.opt.MaxConnAge > 0 && now.Sub(cn.createdAt) >= p.opt.MaxConnAge {
- return true
- }
- return false
- }
|