123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668 |
- package redis
- import (
- "context"
- "fmt"
- "strings"
- "sync"
- "time"
- "github.com/go-redis/redis/v8/internal"
- "github.com/go-redis/redis/v8/internal/pool"
- "github.com/go-redis/redis/v8/internal/proto"
- )
- // PubSub implements Pub/Sub commands as described in
- // http://redis.io/topics/pubsub. Message receiving is NOT safe
- // for concurrent use by multiple goroutines.
- //
- // PubSub automatically reconnects to Redis Server and resubscribes
- // to the channels in case of network errors.
- type PubSub struct {
- opt *Options
- newConn func(ctx context.Context, channels []string) (*pool.Conn, error)
- closeConn func(*pool.Conn) error
- mu sync.Mutex
- cn *pool.Conn
- channels map[string]struct{}
- patterns map[string]struct{}
- closed bool
- exit chan struct{}
- cmd *Cmd
- chOnce sync.Once
- msgCh *channel
- allCh *channel
- }
- func (c *PubSub) init() {
- c.exit = make(chan struct{})
- }
- func (c *PubSub) String() string {
- channels := mapKeys(c.channels)
- channels = append(channels, mapKeys(c.patterns)...)
- return fmt.Sprintf("PubSub(%s)", strings.Join(channels, ", "))
- }
- func (c *PubSub) connWithLock(ctx context.Context) (*pool.Conn, error) {
- c.mu.Lock()
- cn, err := c.conn(ctx, nil)
- c.mu.Unlock()
- return cn, err
- }
- func (c *PubSub) conn(ctx context.Context, newChannels []string) (*pool.Conn, error) {
- if c.closed {
- return nil, pool.ErrClosed
- }
- if c.cn != nil {
- return c.cn, nil
- }
- channels := mapKeys(c.channels)
- channels = append(channels, newChannels...)
- cn, err := c.newConn(ctx, channels)
- if err != nil {
- return nil, err
- }
- if err := c.resubscribe(ctx, cn); err != nil {
- _ = c.closeConn(cn)
- return nil, err
- }
- c.cn = cn
- return cn, nil
- }
- func (c *PubSub) writeCmd(ctx context.Context, cn *pool.Conn, cmd Cmder) error {
- return cn.WithWriter(ctx, c.opt.WriteTimeout, func(wr *proto.Writer) error {
- return writeCmd(wr, cmd)
- })
- }
- func (c *PubSub) resubscribe(ctx context.Context, cn *pool.Conn) error {
- var firstErr error
- if len(c.channels) > 0 {
- firstErr = c._subscribe(ctx, cn, "subscribe", mapKeys(c.channels))
- }
- if len(c.patterns) > 0 {
- err := c._subscribe(ctx, cn, "psubscribe", mapKeys(c.patterns))
- if err != nil && firstErr == nil {
- firstErr = err
- }
- }
- return firstErr
- }
// mapKeys returns the keys of m as a new, non-nil slice in unspecified order.
func mapKeys(m map[string]struct{}) []string {
	keys := make([]string, 0, len(m))
	for key := range m {
		keys = append(keys, key)
	}
	return keys
}
- func (c *PubSub) _subscribe(
- ctx context.Context, cn *pool.Conn, redisCmd string, channels []string,
- ) error {
- args := make([]interface{}, 0, 1+len(channels))
- args = append(args, redisCmd)
- for _, channel := range channels {
- args = append(args, channel)
- }
- cmd := NewSliceCmd(ctx, args...)
- return c.writeCmd(ctx, cn, cmd)
- }
- func (c *PubSub) releaseConnWithLock(
- ctx context.Context,
- cn *pool.Conn,
- err error,
- allowTimeout bool,
- ) {
- c.mu.Lock()
- c.releaseConn(ctx, cn, err, allowTimeout)
- c.mu.Unlock()
- }
- func (c *PubSub) releaseConn(ctx context.Context, cn *pool.Conn, err error, allowTimeout bool) {
- if c.cn != cn {
- return
- }
- if isBadConn(err, allowTimeout, c.opt.Addr) {
- c.reconnect(ctx, err)
- }
- }
- func (c *PubSub) reconnect(ctx context.Context, reason error) {
- _ = c.closeTheCn(reason)
- _, _ = c.conn(ctx, nil)
- }
- func (c *PubSub) closeTheCn(reason error) error {
- if c.cn == nil {
- return nil
- }
- if !c.closed {
- internal.Logger.Printf(c.getContext(), "redis: discarding bad PubSub connection: %s", reason)
- }
- err := c.closeConn(c.cn)
- c.cn = nil
- return err
- }
- func (c *PubSub) Close() error {
- c.mu.Lock()
- defer c.mu.Unlock()
- if c.closed {
- return pool.ErrClosed
- }
- c.closed = true
- close(c.exit)
- return c.closeTheCn(pool.ErrClosed)
- }
- // Subscribe the client to the specified channels. It returns
- // empty subscription if there are no channels.
- func (c *PubSub) Subscribe(ctx context.Context, channels ...string) error {
- c.mu.Lock()
- defer c.mu.Unlock()
- err := c.subscribe(ctx, "subscribe", channels...)
- if c.channels == nil {
- c.channels = make(map[string]struct{})
- }
- for _, s := range channels {
- c.channels[s] = struct{}{}
- }
- return err
- }
- // PSubscribe the client to the given patterns. It returns
- // empty subscription if there are no patterns.
- func (c *PubSub) PSubscribe(ctx context.Context, patterns ...string) error {
- c.mu.Lock()
- defer c.mu.Unlock()
- err := c.subscribe(ctx, "psubscribe", patterns...)
- if c.patterns == nil {
- c.patterns = make(map[string]struct{})
- }
- for _, s := range patterns {
- c.patterns[s] = struct{}{}
- }
- return err
- }
- // Unsubscribe the client from the given channels, or from all of
- // them if none is given.
- func (c *PubSub) Unsubscribe(ctx context.Context, channels ...string) error {
- c.mu.Lock()
- defer c.mu.Unlock()
- for _, channel := range channels {
- delete(c.channels, channel)
- }
- err := c.subscribe(ctx, "unsubscribe", channels...)
- return err
- }
- // PUnsubscribe the client from the given patterns, or from all of
- // them if none is given.
- func (c *PubSub) PUnsubscribe(ctx context.Context, patterns ...string) error {
- c.mu.Lock()
- defer c.mu.Unlock()
- for _, pattern := range patterns {
- delete(c.patterns, pattern)
- }
- err := c.subscribe(ctx, "punsubscribe", patterns...)
- return err
- }
- func (c *PubSub) subscribe(ctx context.Context, redisCmd string, channels ...string) error {
- cn, err := c.conn(ctx, channels)
- if err != nil {
- return err
- }
- err = c._subscribe(ctx, cn, redisCmd, channels)
- c.releaseConn(ctx, cn, err, false)
- return err
- }
- func (c *PubSub) Ping(ctx context.Context, payload ...string) error {
- args := []interface{}{"ping"}
- if len(payload) == 1 {
- args = append(args, payload[0])
- }
- cmd := NewCmd(ctx, args...)
- c.mu.Lock()
- defer c.mu.Unlock()
- cn, err := c.conn(ctx, nil)
- if err != nil {
- return err
- }
- err = c.writeCmd(ctx, cn, cmd)
- c.releaseConn(ctx, cn, err, false)
- return err
- }
// Subscription is the confirmation received after a subscribe or
// unsubscribe command has been processed.
type Subscription struct {
	// Kind is one of "subscribe", "unsubscribe", "psubscribe" or "punsubscribe".
	Kind string
	// Channel is the channel name the confirmation refers to.
	Channel string
	// Count is the number of channels we are currently subscribed to.
	Count int
}

// String renders the subscription as "<kind>: <channel>".
func (s *Subscription) String() string {
	// Sprint inserts no separators here because every operand is a string.
	return fmt.Sprint(s.Kind, ": ", s.Channel)
}
// Message received as result of a PUBLISH command issued by another client.
type Message struct {
	// Channel is the channel the message was delivered on.
	Channel string
	// Pattern is the matching pattern; set only for pattern ("pmessage") deliveries.
	Pattern string
	// Payload holds the message body when it arrives as a single string.
	Payload string
	// PayloadSlice holds the message body when it arrives as a list of strings.
	PayloadSlice []string
}

// String renders the message as "Message<channel: payload>".
func (msg *Message) String() string {
	body := msg.Channel + ": " + msg.Payload
	return fmt.Sprintf("Message<%s>", body)
}
// Pong is the reply received as result of a PING command.
type Pong struct {
	// Payload is the text carried by the reply; it may be empty.
	Payload string
}

// String returns "Pong" for an empty payload and "Pong<payload>" otherwise.
func (p *Pong) String() string {
	if p.Payload == "" {
		return "Pong"
	}
	return fmt.Sprintf("Pong<%s>", p.Payload)
}
- func (c *PubSub) newMessage(reply interface{}) (interface{}, error) {
- switch reply := reply.(type) {
- case string:
- return &Pong{
- Payload: reply,
- }, nil
- case []interface{}:
- switch kind := reply[0].(string); kind {
- case "subscribe", "unsubscribe", "psubscribe", "punsubscribe":
- // Can be nil in case of "unsubscribe".
- channel, _ := reply[1].(string)
- return &Subscription{
- Kind: kind,
- Channel: channel,
- Count: int(reply[2].(int64)),
- }, nil
- case "message":
- switch payload := reply[2].(type) {
- case string:
- return &Message{
- Channel: reply[1].(string),
- Payload: payload,
- }, nil
- case []interface{}:
- ss := make([]string, len(payload))
- for i, s := range payload {
- ss[i] = s.(string)
- }
- return &Message{
- Channel: reply[1].(string),
- PayloadSlice: ss,
- }, nil
- default:
- return nil, fmt.Errorf("redis: unsupported pubsub message payload: %T", payload)
- }
- case "pmessage":
- return &Message{
- Pattern: reply[1].(string),
- Channel: reply[2].(string),
- Payload: reply[3].(string),
- }, nil
- case "pong":
- return &Pong{
- Payload: reply[1].(string),
- }, nil
- default:
- return nil, fmt.Errorf("redis: unsupported pubsub message: %q", kind)
- }
- default:
- return nil, fmt.Errorf("redis: unsupported pubsub message: %#v", reply)
- }
- }
- // ReceiveTimeout acts like Receive but returns an error if message
- // is not received in time. This is low-level API and in most cases
- // Channel should be used instead.
- func (c *PubSub) ReceiveTimeout(ctx context.Context, timeout time.Duration) (interface{}, error) {
- if c.cmd == nil {
- c.cmd = NewCmd(ctx)
- }
- // Don't hold the lock to allow subscriptions and pings.
- cn, err := c.connWithLock(ctx)
- if err != nil {
- return nil, err
- }
- err = cn.WithReader(ctx, timeout, func(rd *proto.Reader) error {
- return c.cmd.readReply(rd)
- })
- c.releaseConnWithLock(ctx, cn, err, timeout > 0)
- if err != nil {
- return nil, err
- }
- return c.newMessage(c.cmd.Val())
- }
- // Receive returns a message as a Subscription, Message, Pong or error.
- // See PubSub example for details. This is low-level API and in most cases
- // Channel should be used instead.
- func (c *PubSub) Receive(ctx context.Context) (interface{}, error) {
- return c.ReceiveTimeout(ctx, 0)
- }
- // ReceiveMessage returns a Message or error ignoring Subscription and Pong
- // messages. This is low-level API and in most cases Channel should be used
- // instead.
- func (c *PubSub) ReceiveMessage(ctx context.Context) (*Message, error) {
- for {
- msg, err := c.Receive(ctx)
- if err != nil {
- return nil, err
- }
- switch msg := msg.(type) {
- case *Subscription:
- // Ignore.
- case *Pong:
- // Ignore.
- case *Message:
- return msg, nil
- default:
- err := fmt.Errorf("redis: unknown message: %T", msg)
- return nil, err
- }
- }
- }
- func (c *PubSub) getContext() context.Context {
- if c.cmd != nil {
- return c.cmd.ctx
- }
- return context.Background()
- }
- //------------------------------------------------------------------------------
- // Channel returns a Go channel for concurrently receiving messages.
- // The channel is closed together with the PubSub. If the Go channel
- // is blocked full for 30 seconds the message is dropped.
- // Receive* APIs can not be used after channel is created.
- //
- // go-redis periodically sends ping messages to test connection health
- // and re-subscribes if ping can not not received for 30 seconds.
- func (c *PubSub) Channel(opts ...ChannelOption) <-chan *Message {
- c.chOnce.Do(func() {
- c.msgCh = newChannel(c, opts...)
- c.msgCh.initMsgChan()
- })
- if c.msgCh == nil {
- err := fmt.Errorf("redis: Channel can't be called after ChannelWithSubscriptions")
- panic(err)
- }
- return c.msgCh.msgCh
- }
- // ChannelSize is like Channel, but creates a Go channel
- // with specified buffer size.
- //
- // Deprecated: use Channel(WithChannelSize(size)), remove in v9.
- func (c *PubSub) ChannelSize(size int) <-chan *Message {
- return c.Channel(WithChannelSize(size))
- }
- // ChannelWithSubscriptions is like Channel, but message type can be either
- // *Subscription or *Message. Subscription messages can be used to detect
- // reconnections.
- //
- // ChannelWithSubscriptions can not be used together with Channel or ChannelSize.
- func (c *PubSub) ChannelWithSubscriptions(_ context.Context, size int) <-chan interface{} {
- c.chOnce.Do(func() {
- c.allCh = newChannel(c, WithChannelSize(size))
- c.allCh.initAllChan()
- })
- if c.allCh == nil {
- err := fmt.Errorf("redis: ChannelWithSubscriptions can't be called after Channel")
- panic(err)
- }
- return c.allCh.allCh
- }
- type ChannelOption func(c *channel)
- // WithChannelSize specifies the Go chan size that is used to buffer incoming messages.
- //
- // The default is 100 messages.
- func WithChannelSize(size int) ChannelOption {
- return func(c *channel) {
- c.chanSize = size
- }
- }
- // WithChannelHealthCheckInterval specifies the health check interval.
- // PubSub will ping Redis Server if it does not receive any messages within the interval.
- // To disable health check, use zero interval.
- //
- // The default is 3 seconds.
- func WithChannelHealthCheckInterval(d time.Duration) ChannelOption {
- return func(c *channel) {
- c.checkInterval = d
- }
- }
- // WithChannelSendTimeout specifies the channel send timeout after which
- // the message is dropped.
- //
- // The default is 60 seconds.
- func WithChannelSendTimeout(d time.Duration) ChannelOption {
- return func(c *channel) {
- c.chanSendTimeout = d
- }
- }
- type channel struct {
- pubSub *PubSub
- msgCh chan *Message
- allCh chan interface{}
- ping chan struct{}
- chanSize int
- chanSendTimeout time.Duration
- checkInterval time.Duration
- }
- func newChannel(pubSub *PubSub, opts ...ChannelOption) *channel {
- c := &channel{
- pubSub: pubSub,
- chanSize: 100,
- chanSendTimeout: time.Minute,
- checkInterval: 3 * time.Second,
- }
- for _, opt := range opts {
- opt(c)
- }
- if c.checkInterval > 0 {
- c.initHealthCheck()
- }
- return c
- }
- func (c *channel) initHealthCheck() {
- ctx := context.TODO()
- c.ping = make(chan struct{}, 1)
- go func() {
- timer := time.NewTimer(time.Minute)
- timer.Stop()
- for {
- timer.Reset(c.checkInterval)
- select {
- case <-c.ping:
- if !timer.Stop() {
- <-timer.C
- }
- case <-timer.C:
- if pingErr := c.pubSub.Ping(ctx); pingErr != nil {
- c.pubSub.mu.Lock()
- c.pubSub.reconnect(ctx, pingErr)
- c.pubSub.mu.Unlock()
- }
- case <-c.pubSub.exit:
- return
- }
- }
- }()
- }
- // initMsgChan must be in sync with initAllChan.
- func (c *channel) initMsgChan() {
- ctx := context.TODO()
- c.msgCh = make(chan *Message, c.chanSize)
- go func() {
- timer := time.NewTimer(time.Minute)
- timer.Stop()
- var errCount int
- for {
- msg, err := c.pubSub.Receive(ctx)
- if err != nil {
- if err == pool.ErrClosed {
- close(c.msgCh)
- return
- }
- if errCount > 0 {
- time.Sleep(100 * time.Millisecond)
- }
- errCount++
- continue
- }
- errCount = 0
- // Any message is as good as a ping.
- select {
- case c.ping <- struct{}{}:
- default:
- }
- switch msg := msg.(type) {
- case *Subscription:
- // Ignore.
- case *Pong:
- // Ignore.
- case *Message:
- timer.Reset(c.chanSendTimeout)
- select {
- case c.msgCh <- msg:
- if !timer.Stop() {
- <-timer.C
- }
- case <-timer.C:
- internal.Logger.Printf(
- ctx, "redis: %s channel is full for %s (message is dropped)",
- c, c.chanSendTimeout)
- }
- default:
- internal.Logger.Printf(ctx, "redis: unknown message type: %T", msg)
- }
- }
- }()
- }
- // initAllChan must be in sync with initMsgChan.
- func (c *channel) initAllChan() {
- ctx := context.TODO()
- c.allCh = make(chan interface{}, c.chanSize)
- go func() {
- timer := time.NewTimer(time.Minute)
- timer.Stop()
- var errCount int
- for {
- msg, err := c.pubSub.Receive(ctx)
- if err != nil {
- if err == pool.ErrClosed {
- close(c.allCh)
- return
- }
- if errCount > 0 {
- time.Sleep(100 * time.Millisecond)
- }
- errCount++
- continue
- }
- errCount = 0
- // Any message is as good as a ping.
- select {
- case c.ping <- struct{}{}:
- default:
- }
- switch msg := msg.(type) {
- case *Pong:
- // Ignore.
- case *Subscription, *Message:
- timer.Reset(c.chanSendTimeout)
- select {
- case c.allCh <- msg:
- if !timer.Stop() {
- <-timer.C
- }
- case <-timer.C:
- internal.Logger.Printf(
- ctx, "redis: %s channel is full for %s (message is dropped)",
- c, c.chanSendTimeout)
- }
- default:
- internal.Logger.Printf(ctx, "redis: unknown message type: %T", msg)
- }
- }
- }()
- }
|