pool.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557
  1. package pool
  2. import (
  3. "context"
  4. "errors"
  5. "net"
  6. "sync"
  7. "sync/atomic"
  8. "time"
  9. "github.com/go-redis/redis/v8/internal"
  10. )
  var (
  	// ErrClosed is returned when an operation is attempted on a client (and
  	// its connection pool) after it has been closed.
  	ErrClosed = errors.New("redis: client is closed")

  	// ErrPoolTimeout is returned when no connection slot became available in
  	// the pool within the configured pool timeout.
  	ErrPoolTimeout = errors.New("redis: connection pool timeout")
  )
  // timers recycles *time.Timer values so that the slow path of waitTurn does
  // not allocate a fresh timer on every wait. Timers produced by New are
  // already stopped; users Reset them before use and must put them back only
  // once they are stopped and their channel is drained.
  var timers = sync.Pool{
  	New: func() interface{} {
  		// The initial duration is irrelevant because the timer is stopped at once.
  		stopped := time.NewTimer(time.Hour)
  		stopped.Stop()
  		return stopped
  	},
  }
  // Stats is a snapshot of a pool's current size together with counters that
  // accumulate over the pool's lifetime.
  type Stats struct {
  	Hits     uint32 // times a free connection was found in the pool
  	Misses   uint32 // times no free connection was found and one had to be dialed
  	Timeouts uint32 // times waiting for a free connection slot timed out

  	TotalConns uint32 // total connections currently in the pool
  	IdleConns  uint32 // idle connections currently in the pool
  	StaleConns uint32 // stale connections removed from the pool so far
  }
  33. type Pooler interface {
  34. NewConn(context.Context) (*Conn, error)
  35. CloseConn(*Conn) error
  36. Get(context.Context) (*Conn, error)
  37. Put(context.Context, *Conn)
  38. Remove(context.Context, *Conn, error)
  39. Len() int
  40. IdleLen() int
  41. Stats() *Stats
  42. Close() error
  43. }
  44. type Options struct {
  45. Dialer func(context.Context) (net.Conn, error)
  46. OnClose func(*Conn) error
  47. PoolFIFO bool
  48. PoolSize int
  49. MinIdleConns int
  50. MaxConnAge time.Duration
  51. PoolTimeout time.Duration
  52. IdleTimeout time.Duration
  53. IdleCheckFrequency time.Duration
  54. }
  // lastDialErrorWrap boxes an error for storage in an atomic.Value.
  // atomic.Value panics when storing nil or values whose concrete types
  // differ between calls, and dial errors come in many concrete types, so
  // every error is stored inside this single wrapper type instead.
  type lastDialErrorWrap struct {
  	err error // the most recent dial error
  }
  58. type ConnPool struct {
  59. opt *Options
  60. dialErrorsNum uint32 // atomic
  61. lastDialError atomic.Value
  62. queue chan struct{}
  63. connsMu sync.Mutex
  64. conns []*Conn
  65. idleConns []*Conn
  66. poolSize int
  67. idleConnsLen int
  68. stats Stats
  69. _closed uint32 // atomic
  70. closedCh chan struct{}
  71. }
  72. var _ Pooler = (*ConnPool)(nil)
  73. func NewConnPool(opt *Options) *ConnPool {
  74. p := &ConnPool{
  75. opt: opt,
  76. queue: make(chan struct{}, opt.PoolSize),
  77. conns: make([]*Conn, 0, opt.PoolSize),
  78. idleConns: make([]*Conn, 0, opt.PoolSize),
  79. closedCh: make(chan struct{}),
  80. }
  81. p.connsMu.Lock()
  82. p.checkMinIdleConns()
  83. p.connsMu.Unlock()
  84. if opt.IdleTimeout > 0 && opt.IdleCheckFrequency > 0 {
  85. go p.reaper(opt.IdleCheckFrequency)
  86. }
  87. return p
  88. }
  89. func (p *ConnPool) checkMinIdleConns() {
  90. if p.opt.MinIdleConns == 0 {
  91. return
  92. }
  93. for p.poolSize < p.opt.PoolSize && p.idleConnsLen < p.opt.MinIdleConns {
  94. p.poolSize++
  95. p.idleConnsLen++
  96. go func() {
  97. err := p.addIdleConn()
  98. if err != nil && err != ErrClosed {
  99. p.connsMu.Lock()
  100. p.poolSize--
  101. p.idleConnsLen--
  102. p.connsMu.Unlock()
  103. }
  104. }()
  105. }
  106. }
  107. func (p *ConnPool) addIdleConn() error {
  108. cn, err := p.dialConn(context.TODO(), true)
  109. if err != nil {
  110. return err
  111. }
  112. p.connsMu.Lock()
  113. defer p.connsMu.Unlock()
  114. // It is not allowed to add new connections to the closed connection pool.
  115. if p.closed() {
  116. _ = cn.Close()
  117. return ErrClosed
  118. }
  119. p.conns = append(p.conns, cn)
  120. p.idleConns = append(p.idleConns, cn)
  121. return nil
  122. }
  123. func (p *ConnPool) NewConn(ctx context.Context) (*Conn, error) {
  124. return p.newConn(ctx, false)
  125. }
  126. func (p *ConnPool) newConn(ctx context.Context, pooled bool) (*Conn, error) {
  127. cn, err := p.dialConn(ctx, pooled)
  128. if err != nil {
  129. return nil, err
  130. }
  131. p.connsMu.Lock()
  132. defer p.connsMu.Unlock()
  133. // It is not allowed to add new connections to the closed connection pool.
  134. if p.closed() {
  135. _ = cn.Close()
  136. return nil, ErrClosed
  137. }
  138. p.conns = append(p.conns, cn)
  139. if pooled {
  140. // If pool is full remove the cn on next Put.
  141. if p.poolSize >= p.opt.PoolSize {
  142. cn.pooled = false
  143. } else {
  144. p.poolSize++
  145. }
  146. }
  147. return cn, nil
  148. }
  149. func (p *ConnPool) dialConn(ctx context.Context, pooled bool) (*Conn, error) {
  150. if p.closed() {
  151. return nil, ErrClosed
  152. }
  153. if atomic.LoadUint32(&p.dialErrorsNum) >= uint32(p.opt.PoolSize) {
  154. return nil, p.getLastDialError()
  155. }
  156. netConn, err := p.opt.Dialer(ctx)
  157. if err != nil {
  158. p.setLastDialError(err)
  159. if atomic.AddUint32(&p.dialErrorsNum, 1) == uint32(p.opt.PoolSize) {
  160. go p.tryDial()
  161. }
  162. return nil, err
  163. }
  164. cn := NewConn(netConn)
  165. cn.pooled = pooled
  166. return cn, nil
  167. }
  168. func (p *ConnPool) tryDial() {
  169. for {
  170. if p.closed() {
  171. return
  172. }
  173. conn, err := p.opt.Dialer(context.Background())
  174. if err != nil {
  175. p.setLastDialError(err)
  176. time.Sleep(time.Second)
  177. continue
  178. }
  179. atomic.StoreUint32(&p.dialErrorsNum, 0)
  180. _ = conn.Close()
  181. return
  182. }
  183. }
  184. func (p *ConnPool) setLastDialError(err error) {
  185. p.lastDialError.Store(&lastDialErrorWrap{err: err})
  186. }
  187. func (p *ConnPool) getLastDialError() error {
  188. err, _ := p.lastDialError.Load().(*lastDialErrorWrap)
  189. if err != nil {
  190. return err.err
  191. }
  192. return nil
  193. }
  194. // Get returns existed connection from the pool or creates a new one.
  195. func (p *ConnPool) Get(ctx context.Context) (*Conn, error) {
  196. if p.closed() {
  197. return nil, ErrClosed
  198. }
  199. if err := p.waitTurn(ctx); err != nil {
  200. return nil, err
  201. }
  202. for {
  203. p.connsMu.Lock()
  204. cn, err := p.popIdle()
  205. p.connsMu.Unlock()
  206. if err != nil {
  207. return nil, err
  208. }
  209. if cn == nil {
  210. break
  211. }
  212. if p.isStaleConn(cn) {
  213. _ = p.CloseConn(cn)
  214. continue
  215. }
  216. atomic.AddUint32(&p.stats.Hits, 1)
  217. return cn, nil
  218. }
  219. atomic.AddUint32(&p.stats.Misses, 1)
  220. newcn, err := p.newConn(ctx, true)
  221. if err != nil {
  222. p.freeTurn()
  223. return nil, err
  224. }
  225. return newcn, nil
  226. }
  227. func (p *ConnPool) getTurn() {
  228. p.queue <- struct{}{}
  229. }
  230. func (p *ConnPool) waitTurn(ctx context.Context) error {
  231. select {
  232. case <-ctx.Done():
  233. return ctx.Err()
  234. default:
  235. }
  236. select {
  237. case p.queue <- struct{}{}:
  238. return nil
  239. default:
  240. }
  241. timer := timers.Get().(*time.Timer)
  242. timer.Reset(p.opt.PoolTimeout)
  243. select {
  244. case <-ctx.Done():
  245. if !timer.Stop() {
  246. <-timer.C
  247. }
  248. timers.Put(timer)
  249. return ctx.Err()
  250. case p.queue <- struct{}{}:
  251. if !timer.Stop() {
  252. <-timer.C
  253. }
  254. timers.Put(timer)
  255. return nil
  256. case <-timer.C:
  257. timers.Put(timer)
  258. atomic.AddUint32(&p.stats.Timeouts, 1)
  259. return ErrPoolTimeout
  260. }
  261. }
  262. func (p *ConnPool) freeTurn() {
  263. <-p.queue
  264. }
  265. func (p *ConnPool) popIdle() (*Conn, error) {
  266. if p.closed() {
  267. return nil, ErrClosed
  268. }
  269. n := len(p.idleConns)
  270. if n == 0 {
  271. return nil, nil
  272. }
  273. var cn *Conn
  274. if p.opt.PoolFIFO {
  275. cn = p.idleConns[0]
  276. copy(p.idleConns, p.idleConns[1:])
  277. p.idleConns = p.idleConns[:n-1]
  278. } else {
  279. idx := n - 1
  280. cn = p.idleConns[idx]
  281. p.idleConns = p.idleConns[:idx]
  282. }
  283. p.idleConnsLen--
  284. p.checkMinIdleConns()
  285. return cn, nil
  286. }
  287. func (p *ConnPool) Put(ctx context.Context, cn *Conn) {
  288. if cn.rd.Buffered() > 0 {
  289. internal.Logger.Printf(ctx, "Conn has unread data")
  290. p.Remove(ctx, cn, BadConnError{})
  291. return
  292. }
  293. if !cn.pooled {
  294. p.Remove(ctx, cn, nil)
  295. return
  296. }
  297. p.connsMu.Lock()
  298. p.idleConns = append(p.idleConns, cn)
  299. p.idleConnsLen++
  300. p.connsMu.Unlock()
  301. p.freeTurn()
  302. }
  303. func (p *ConnPool) Remove(ctx context.Context, cn *Conn, reason error) {
  304. p.removeConnWithLock(cn)
  305. p.freeTurn()
  306. _ = p.closeConn(cn)
  307. }
  308. func (p *ConnPool) CloseConn(cn *Conn) error {
  309. p.removeConnWithLock(cn)
  310. return p.closeConn(cn)
  311. }
  312. func (p *ConnPool) removeConnWithLock(cn *Conn) {
  313. p.connsMu.Lock()
  314. p.removeConn(cn)
  315. p.connsMu.Unlock()
  316. }
  317. func (p *ConnPool) removeConn(cn *Conn) {
  318. for i, c := range p.conns {
  319. if c == cn {
  320. p.conns = append(p.conns[:i], p.conns[i+1:]...)
  321. if cn.pooled {
  322. p.poolSize--
  323. p.checkMinIdleConns()
  324. }
  325. return
  326. }
  327. }
  328. }
  329. func (p *ConnPool) closeConn(cn *Conn) error {
  330. if p.opt.OnClose != nil {
  331. _ = p.opt.OnClose(cn)
  332. }
  333. return cn.Close()
  334. }
  335. // Len returns total number of connections.
  336. func (p *ConnPool) Len() int {
  337. p.connsMu.Lock()
  338. n := len(p.conns)
  339. p.connsMu.Unlock()
  340. return n
  341. }
  342. // IdleLen returns number of idle connections.
  343. func (p *ConnPool) IdleLen() int {
  344. p.connsMu.Lock()
  345. n := p.idleConnsLen
  346. p.connsMu.Unlock()
  347. return n
  348. }
  349. func (p *ConnPool) Stats() *Stats {
  350. idleLen := p.IdleLen()
  351. return &Stats{
  352. Hits: atomic.LoadUint32(&p.stats.Hits),
  353. Misses: atomic.LoadUint32(&p.stats.Misses),
  354. Timeouts: atomic.LoadUint32(&p.stats.Timeouts),
  355. TotalConns: uint32(p.Len()),
  356. IdleConns: uint32(idleLen),
  357. StaleConns: atomic.LoadUint32(&p.stats.StaleConns),
  358. }
  359. }
  360. func (p *ConnPool) closed() bool {
  361. return atomic.LoadUint32(&p._closed) == 1
  362. }
  363. func (p *ConnPool) Filter(fn func(*Conn) bool) error {
  364. p.connsMu.Lock()
  365. defer p.connsMu.Unlock()
  366. var firstErr error
  367. for _, cn := range p.conns {
  368. if fn(cn) {
  369. if err := p.closeConn(cn); err != nil && firstErr == nil {
  370. firstErr = err
  371. }
  372. }
  373. }
  374. return firstErr
  375. }
  376. func (p *ConnPool) Close() error {
  377. if !atomic.CompareAndSwapUint32(&p._closed, 0, 1) {
  378. return ErrClosed
  379. }
  380. close(p.closedCh)
  381. var firstErr error
  382. p.connsMu.Lock()
  383. for _, cn := range p.conns {
  384. if err := p.closeConn(cn); err != nil && firstErr == nil {
  385. firstErr = err
  386. }
  387. }
  388. p.conns = nil
  389. p.poolSize = 0
  390. p.idleConns = nil
  391. p.idleConnsLen = 0
  392. p.connsMu.Unlock()
  393. return firstErr
  394. }
  395. func (p *ConnPool) reaper(frequency time.Duration) {
  396. ticker := time.NewTicker(frequency)
  397. defer ticker.Stop()
  398. for {
  399. select {
  400. case <-ticker.C:
  401. // It is possible that ticker and closedCh arrive together,
  402. // and select pseudo-randomly pick ticker case, we double
  403. // check here to prevent being executed after closed.
  404. if p.closed() {
  405. return
  406. }
  407. _, err := p.ReapStaleConns()
  408. if err != nil {
  409. internal.Logger.Printf(context.Background(), "ReapStaleConns failed: %s", err)
  410. continue
  411. }
  412. case <-p.closedCh:
  413. return
  414. }
  415. }
  416. }
  417. func (p *ConnPool) ReapStaleConns() (int, error) {
  418. var n int
  419. for {
  420. p.getTurn()
  421. p.connsMu.Lock()
  422. cn := p.reapStaleConn()
  423. p.connsMu.Unlock()
  424. p.freeTurn()
  425. if cn != nil {
  426. _ = p.closeConn(cn)
  427. n++
  428. } else {
  429. break
  430. }
  431. }
  432. atomic.AddUint32(&p.stats.StaleConns, uint32(n))
  433. return n, nil
  434. }
  435. func (p *ConnPool) reapStaleConn() *Conn {
  436. if len(p.idleConns) == 0 {
  437. return nil
  438. }
  439. cn := p.idleConns[0]
  440. if !p.isStaleConn(cn) {
  441. return nil
  442. }
  443. p.idleConns = append(p.idleConns[:0], p.idleConns[1:]...)
  444. p.idleConnsLen--
  445. p.removeConn(cn)
  446. return cn
  447. }
  448. func (p *ConnPool) isStaleConn(cn *Conn) bool {
  449. if p.opt.IdleTimeout == 0 && p.opt.MaxConnAge == 0 {
  450. return false
  451. }
  452. now := time.Now()
  453. if p.opt.IdleTimeout > 0 && now.Sub(cn.UsedAt()) >= p.opt.IdleTimeout {
  454. return true
  455. }
  456. if p.opt.MaxConnAge > 0 && now.Sub(cn.createdAt) >= p.opt.MaxConnAge {
  457. return true
  458. }
  459. return false
  460. }