mockbroker.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472
  1. package sarama
  2. import (
  3. "bytes"
  4. "encoding/binary"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "net"
  9. "reflect"
  10. "strconv"
  11. "sync"
  12. "syscall"
  13. "time"
  14. "github.com/davecgh/go-spew/spew"
  15. )
  16. const (
  17. expectationTimeout = 500 * time.Millisecond
  18. )
  19. type GSSApiHandlerFunc func([]byte) []byte
  20. type requestHandlerFunc func(req *request) (res encoderWithHeader)
  21. // RequestNotifierFunc is invoked when a mock broker processes a request successfully
  22. // and will provides the number of bytes read and written.
  23. type RequestNotifierFunc func(bytesRead, bytesWritten int)
  24. // MockBroker is a mock Kafka broker that is used in unit tests. It is exposed
  25. // to facilitate testing of higher level or specialized consumers and producers
  26. // built on top of Sarama. Note that it does not 'mimic' the Kafka API protocol,
  27. // but rather provides a facility to do that. It takes care of the TCP
  28. // transport, request unmarshalling, response marshaling, and makes it the test
  29. // writer responsibility to program correct according to the Kafka API protocol
  30. // MockBroker behavior.
  31. //
  32. // MockBroker is implemented as a TCP server listening on a kernel-selected
  33. // localhost port that can accept many connections. It reads Kafka requests
  34. // from that connection and returns responses programmed by the SetHandlerByMap
  35. // function. If a MockBroker receives a request that it has no programmed
  36. // response for, then it returns nothing and the request times out.
  37. //
  38. // A set of MockRequest builders to define mappings used by MockBroker is
  39. // provided by Sarama. But users can develop MockRequests of their own and use
  40. // them along with or instead of the standard ones.
  41. //
  42. // When running tests with MockBroker it is strongly recommended to specify
  43. // a timeout to `go test` so that if the broker hangs waiting for a response,
  44. // the test panics.
  45. //
  46. // It is not necessary to prefix message length or correlation ID to your
  47. // response bytes, the server does that automatically as a convenience.
  48. type MockBroker struct {
  49. brokerID int32
  50. port int32
  51. closing chan none
  52. stopper chan none
  53. expectations chan encoderWithHeader
  54. listener net.Listener
  55. t TestReporter
  56. latency time.Duration
  57. handler requestHandlerFunc
  58. notifier RequestNotifierFunc
  59. history []RequestResponse
  60. lock sync.Mutex
  61. gssApiHandler GSSApiHandlerFunc
  62. }
  63. // RequestResponse represents a Request/Response pair processed by MockBroker.
  64. type RequestResponse struct {
  65. Request protocolBody
  66. Response encoder
  67. }
  68. // SetLatency makes broker pause for the specified period every time before
  69. // replying.
  70. func (b *MockBroker) SetLatency(latency time.Duration) {
  71. b.latency = latency
  72. }
  73. // SetHandlerByMap defines mapping of Request types to MockResponses. When a
  74. // request is received by the broker, it looks up the request type in the map
  75. // and uses the found MockResponse instance to generate an appropriate reply.
  76. // If the request type is not found in the map then nothing is sent.
  77. func (b *MockBroker) SetHandlerByMap(handlerMap map[string]MockResponse) {
  78. fnMap := make(map[string]MockResponse)
  79. for k, v := range handlerMap {
  80. fnMap[k] = v
  81. }
  82. b.setHandler(func(req *request) (res encoderWithHeader) {
  83. reqTypeName := reflect.TypeOf(req.body).Elem().Name()
  84. mockResponse := fnMap[reqTypeName]
  85. if mockResponse == nil {
  86. return nil
  87. }
  88. return mockResponse.For(req.body)
  89. })
  90. }
  91. // SetHandlerFuncByMap defines mapping of Request types to RequestHandlerFunc. When a
  92. // request is received by the broker, it looks up the request type in the map
  93. // and invoke the found RequestHandlerFunc instance to generate an appropriate reply.
  94. func (b *MockBroker) SetHandlerFuncByMap(handlerMap map[string]requestHandlerFunc) {
  95. fnMap := make(map[string]requestHandlerFunc)
  96. for k, v := range handlerMap {
  97. fnMap[k] = v
  98. }
  99. b.setHandler(func(req *request) (res encoderWithHeader) {
  100. reqTypeName := reflect.TypeOf(req.body).Elem().Name()
  101. return fnMap[reqTypeName](req)
  102. })
  103. }
  104. // SetNotifier set a function that will get invoked whenever a request has been
  105. // processed successfully and will provide the number of bytes read and written
  106. func (b *MockBroker) SetNotifier(notifier RequestNotifierFunc) {
  107. b.lock.Lock()
  108. b.notifier = notifier
  109. b.lock.Unlock()
  110. }
  111. // BrokerID returns broker ID assigned to the broker.
  112. func (b *MockBroker) BrokerID() int32 {
  113. return b.brokerID
  114. }
  115. // History returns a slice of RequestResponse pairs in the order they were
  116. // processed by the broker. Note that in case of multiple connections to the
  117. // broker the order expected by a test can be different from the order recorded
  118. // in the history, unless some synchronization is implemented in the test.
  119. func (b *MockBroker) History() []RequestResponse {
  120. b.lock.Lock()
  121. history := make([]RequestResponse, len(b.history))
  122. copy(history, b.history)
  123. b.lock.Unlock()
  124. return history
  125. }
  126. // Port returns the TCP port number the broker is listening for requests on.
  127. func (b *MockBroker) Port() int32 {
  128. return b.port
  129. }
  130. // Addr returns the broker connection string in the form "<address>:<port>".
  131. func (b *MockBroker) Addr() string {
  132. return b.listener.Addr().String()
  133. }
  134. // Close terminates the broker blocking until it stops internal goroutines and
  135. // releases all resources.
  136. func (b *MockBroker) Close() {
  137. close(b.expectations)
  138. if len(b.expectations) > 0 {
  139. buf := bytes.NewBufferString(fmt.Sprintf("mockbroker/%d: not all expectations were satisfied! Still waiting on:\n", b.BrokerID()))
  140. for e := range b.expectations {
  141. _, _ = buf.WriteString(spew.Sdump(e))
  142. }
  143. b.t.Error(buf.String())
  144. }
  145. close(b.closing)
  146. <-b.stopper
  147. }
  148. // setHandler sets the specified function as the request handler. Whenever
  149. // a mock broker reads a request from the wire it passes the request to the
  150. // function and sends back whatever the handler function returns.
  151. func (b *MockBroker) setHandler(handler requestHandlerFunc) {
  152. b.lock.Lock()
  153. b.handler = handler
  154. b.lock.Unlock()
  155. }
  156. func (b *MockBroker) serverLoop() {
  157. defer close(b.stopper)
  158. var err error
  159. var conn net.Conn
  160. go func() {
  161. <-b.closing
  162. err := b.listener.Close()
  163. if err != nil {
  164. b.t.Error(err)
  165. }
  166. }()
  167. wg := &sync.WaitGroup{}
  168. i := 0
  169. for conn, err = b.listener.Accept(); err == nil; conn, err = b.listener.Accept() {
  170. wg.Add(1)
  171. go b.handleRequests(conn, i, wg)
  172. i++
  173. }
  174. wg.Wait()
  175. if !isConnectionClosedError(err) {
  176. Logger.Printf("*** mockbroker/%d: listener closed, err=%v", b.BrokerID(), err)
  177. }
  178. }
  179. func (b *MockBroker) SetGSSAPIHandler(handler GSSApiHandlerFunc) {
  180. b.gssApiHandler = handler
  181. }
  182. func (b *MockBroker) readToBytes(r io.Reader) ([]byte, error) {
  183. var (
  184. bytesRead int
  185. lengthBytes = make([]byte, 4)
  186. )
  187. if _, err := io.ReadFull(r, lengthBytes); err != nil {
  188. return nil, err
  189. }
  190. bytesRead += len(lengthBytes)
  191. length := int32(binary.BigEndian.Uint32(lengthBytes))
  192. if length <= 4 || length > MaxRequestSize {
  193. return nil, PacketDecodingError{fmt.Sprintf("message of length %d too large or too small", length)}
  194. }
  195. encodedReq := make([]byte, length)
  196. if _, err := io.ReadFull(r, encodedReq); err != nil {
  197. return nil, err
  198. }
  199. bytesRead += len(encodedReq)
  200. fullBytes := append(lengthBytes, encodedReq...)
  201. return fullBytes, nil
  202. }
  203. func (b *MockBroker) isGSSAPI(buffer []byte) bool {
  204. return buffer[4] == 0x60 || bytes.Equal(buffer[4:6], []byte{0x05, 0x04})
  205. }
  206. func (b *MockBroker) handleRequests(conn io.ReadWriteCloser, idx int, wg *sync.WaitGroup) {
  207. defer wg.Done()
  208. defer func() {
  209. _ = conn.Close()
  210. }()
  211. s := spew.NewDefaultConfig()
  212. s.MaxDepth = 1
  213. Logger.Printf("*** mockbroker/%d/%d: connection opened", b.BrokerID(), idx)
  214. var err error
  215. abort := make(chan none)
  216. defer close(abort)
  217. go func() {
  218. select {
  219. case <-b.closing:
  220. _ = conn.Close()
  221. case <-abort:
  222. }
  223. }()
  224. var bytesWritten int
  225. var bytesRead int
  226. for {
  227. buffer, err := b.readToBytes(conn)
  228. if err != nil {
  229. if !isConnectionClosedError(err) {
  230. Logger.Printf("*** mockbroker/%d/%d: invalid request: err=%+v, %+v", b.brokerID, idx, err, spew.Sdump(buffer))
  231. b.serverError(err)
  232. }
  233. break
  234. }
  235. bytesWritten = 0
  236. if !b.isGSSAPI(buffer) {
  237. req, br, err := decodeRequest(bytes.NewReader(buffer))
  238. bytesRead = br
  239. if err != nil {
  240. if !isConnectionClosedError(err) {
  241. Logger.Printf("*** mockbroker/%d/%d: invalid request: err=%+v, %+v", b.brokerID, idx, err, spew.Sdump(req))
  242. b.serverError(err)
  243. }
  244. break
  245. }
  246. if b.latency > 0 {
  247. time.Sleep(b.latency)
  248. }
  249. b.lock.Lock()
  250. res := b.handler(req)
  251. b.history = append(b.history, RequestResponse{req.body, res})
  252. b.lock.Unlock()
  253. if res == nil {
  254. Logger.Printf("*** mockbroker/%d/%d: ignored %v", b.brokerID, idx, spew.Sdump(req))
  255. continue
  256. }
  257. Logger.Printf(
  258. "*** mockbroker/%d/%d: replied to %T with %T\n-> %s\n-> %s",
  259. b.brokerID, idx, req.body, res,
  260. s.Sprintf("%#v", req.body),
  261. s.Sprintf("%#v", res),
  262. )
  263. encodedRes, err := encode(res, nil)
  264. if err != nil {
  265. b.serverError(fmt.Errorf("failed to encode %T - %w", res, err))
  266. break
  267. }
  268. if len(encodedRes) == 0 {
  269. b.lock.Lock()
  270. if b.notifier != nil {
  271. b.notifier(bytesRead, 0)
  272. }
  273. b.lock.Unlock()
  274. continue
  275. }
  276. resHeader := b.encodeHeader(res.headerVersion(), req.correlationID, uint32(len(encodedRes)))
  277. if _, err = conn.Write(resHeader); err != nil {
  278. b.serverError(err)
  279. break
  280. }
  281. if _, err = conn.Write(encodedRes); err != nil {
  282. b.serverError(err)
  283. break
  284. }
  285. bytesWritten = len(resHeader) + len(encodedRes)
  286. } else {
  287. // GSSAPI is not part of kafka protocol, but is supported for authentication proposes.
  288. // Don't support history for this kind of request as is only used for test GSSAPI authentication mechanism
  289. b.lock.Lock()
  290. res := b.gssApiHandler(buffer)
  291. b.lock.Unlock()
  292. if res == nil {
  293. Logger.Printf("*** mockbroker/%d/%d: ignored %v", b.brokerID, idx, spew.Sdump(buffer))
  294. continue
  295. }
  296. if _, err = conn.Write(res); err != nil {
  297. b.serverError(err)
  298. break
  299. }
  300. bytesWritten = len(res)
  301. }
  302. b.lock.Lock()
  303. if b.notifier != nil {
  304. b.notifier(bytesRead, bytesWritten)
  305. }
  306. b.lock.Unlock()
  307. }
  308. Logger.Printf("*** mockbroker/%d/%d: connection closed, err=%v", b.BrokerID(), idx, err)
  309. }
  310. func (b *MockBroker) encodeHeader(headerVersion int16, correlationId int32, payloadLength uint32) []byte {
  311. headerLength := uint32(8)
  312. if headerVersion >= 1 {
  313. headerLength = 9
  314. }
  315. resHeader := make([]byte, headerLength)
  316. binary.BigEndian.PutUint32(resHeader, payloadLength+headerLength-4)
  317. binary.BigEndian.PutUint32(resHeader[4:], uint32(correlationId))
  318. if headerVersion >= 1 {
  319. binary.PutUvarint(resHeader[8:], 0)
  320. }
  321. return resHeader
  322. }
  323. func (b *MockBroker) defaultRequestHandler(req *request) (res encoderWithHeader) {
  324. select {
  325. case res, ok := <-b.expectations:
  326. if !ok {
  327. return nil
  328. }
  329. return res
  330. case <-time.After(expectationTimeout):
  331. return nil
  332. }
  333. }
  334. func isConnectionClosedError(err error) bool {
  335. var result bool
  336. opError := &net.OpError{}
  337. if errors.As(err, &opError) {
  338. result = true
  339. } else if errors.Is(err, io.EOF) {
  340. result = true
  341. } else if err.Error() == "use of closed network connection" {
  342. result = true
  343. }
  344. return result
  345. }
  346. func (b *MockBroker) serverError(err error) {
  347. b.t.Helper()
  348. if isConnectionClosedError(err) {
  349. return
  350. }
  351. b.t.Errorf(err.Error())
  352. }
  353. // NewMockBroker launches a fake Kafka broker. It takes a TestReporter as provided by the
  354. // test framework and a channel of responses to use. If an error occurs it is
  355. // simply logged to the TestReporter and the broker exits.
  356. func NewMockBroker(t TestReporter, brokerID int32) *MockBroker {
  357. return NewMockBrokerAddr(t, brokerID, "localhost:0")
  358. }
  359. // NewMockBrokerAddr behaves like newMockBroker but listens on the address you give
  360. // it rather than just some ephemeral port.
  361. func NewMockBrokerAddr(t TestReporter, brokerID int32, addr string) *MockBroker {
  362. var (
  363. listener net.Listener
  364. err error
  365. )
  366. // retry up to 20 times if address already in use (e.g., if replacing broker which hasn't cleanly shutdown)
  367. for i := 0; i < 20; i++ {
  368. listener, err = net.Listen("tcp", addr)
  369. if err != nil {
  370. if errors.Is(err, syscall.EADDRINUSE) {
  371. Logger.Printf("*** mockbroker/%d waiting for %s (address already in use)", brokerID, addr)
  372. time.Sleep(time.Millisecond * 100)
  373. continue
  374. }
  375. t.Fatal(err)
  376. }
  377. break
  378. }
  379. if err != nil {
  380. t.Fatal(err)
  381. }
  382. return NewMockBrokerListener(t, brokerID, listener)
  383. }
  384. // NewMockBrokerListener behaves like newMockBrokerAddr but accepts connections on the listener specified.
  385. func NewMockBrokerListener(t TestReporter, brokerID int32, listener net.Listener) *MockBroker {
  386. var err error
  387. broker := &MockBroker{
  388. closing: make(chan none),
  389. stopper: make(chan none),
  390. t: t,
  391. brokerID: brokerID,
  392. expectations: make(chan encoderWithHeader, 512),
  393. listener: listener,
  394. }
  395. broker.handler = broker.defaultRequestHandler
  396. Logger.Printf("*** mockbroker/%d listening on %s\n", brokerID, broker.listener.Addr().String())
  397. _, portStr, err := net.SplitHostPort(broker.listener.Addr().String())
  398. if err != nil {
  399. t.Fatal(err)
  400. }
  401. tmp, err := strconv.ParseInt(portStr, 10, 32)
  402. if err != nil {
  403. t.Fatal(err)
  404. }
  405. broker.port = int32(tmp)
  406. go broker.serverLoop()
  407. return broker
  408. }
  409. func (b *MockBroker) Returns(e encoderWithHeader) {
  410. b.expectations <- e
  411. }