connector.go 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. // Go MySQL Driver - A MySQL-Driver for Go's database/sql package
  2. //
  3. // Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved.
  4. //
  5. // This Source Code Form is subject to the terms of the Mozilla Public
  6. // License, v. 2.0. If a copy of the MPL was not distributed with this file,
  7. // You can obtain one at http://mozilla.org/MPL/2.0/.
  8. package mysql
  9. import (
  10. "context"
  11. "database/sql/driver"
  12. "net"
  13. "os"
  14. "strconv"
  15. "strings"
  16. )
  17. type connector struct {
  18. cfg *Config // immutable private copy.
  19. encodedAttributes string // Encoded connection attributes.
  20. }
  21. func encodeConnectionAttributes(cfg *Config) string {
  22. connAttrsBuf := make([]byte, 0)
  23. // default connection attributes
  24. connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrClientName)
  25. connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrClientNameValue)
  26. connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrOS)
  27. connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrOSValue)
  28. connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPlatform)
  29. connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPlatformValue)
  30. connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPid)
  31. connAttrsBuf = appendLengthEncodedString(connAttrsBuf, strconv.Itoa(os.Getpid()))
  32. serverHost, _, _ := net.SplitHostPort(cfg.Addr)
  33. if serverHost != "" {
  34. connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrServerHost)
  35. connAttrsBuf = appendLengthEncodedString(connAttrsBuf, serverHost)
  36. }
  37. // user-defined connection attributes
  38. for _, connAttr := range strings.Split(cfg.ConnectionAttributes, ",") {
  39. k, v, found := strings.Cut(connAttr, ":")
  40. if !found {
  41. continue
  42. }
  43. connAttrsBuf = appendLengthEncodedString(connAttrsBuf, k)
  44. connAttrsBuf = appendLengthEncodedString(connAttrsBuf, v)
  45. }
  46. return string(connAttrsBuf)
  47. }
  48. func newConnector(cfg *Config) *connector {
  49. encodedAttributes := encodeConnectionAttributes(cfg)
  50. return &connector{
  51. cfg: cfg,
  52. encodedAttributes: encodedAttributes,
  53. }
  54. }
  55. // Connect implements driver.Connector interface.
  56. // Connect returns a connection to the database.
  57. func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
  58. var err error
  59. // Invoke beforeConnect if present, with a copy of the configuration
  60. cfg := c.cfg
  61. if c.cfg.beforeConnect != nil {
  62. cfg = c.cfg.Clone()
  63. err = c.cfg.beforeConnect(ctx, cfg)
  64. if err != nil {
  65. return nil, err
  66. }
  67. }
  68. // New mysqlConn
  69. mc := &mysqlConn{
  70. maxAllowedPacket: maxPacketSize,
  71. maxWriteSize: maxPacketSize - 1,
  72. closech: make(chan struct{}),
  73. cfg: cfg,
  74. connector: c,
  75. }
  76. mc.parseTime = mc.cfg.ParseTime
  77. // Connect to Server
  78. dialsLock.RLock()
  79. dial, ok := dials[mc.cfg.Net]
  80. dialsLock.RUnlock()
  81. if ok {
  82. dctx := ctx
  83. if mc.cfg.Timeout > 0 {
  84. var cancel context.CancelFunc
  85. dctx, cancel = context.WithTimeout(ctx, c.cfg.Timeout)
  86. defer cancel()
  87. }
  88. mc.netConn, err = dial(dctx, mc.cfg.Addr)
  89. } else {
  90. nd := net.Dialer{Timeout: mc.cfg.Timeout}
  91. mc.netConn, err = nd.DialContext(ctx, mc.cfg.Net, mc.cfg.Addr)
  92. }
  93. if err != nil {
  94. return nil, err
  95. }
  96. mc.rawConn = mc.netConn
  97. // Enable TCP Keepalives on TCP connections
  98. if tc, ok := mc.netConn.(*net.TCPConn); ok {
  99. if err := tc.SetKeepAlive(true); err != nil {
  100. c.cfg.Logger.Print(err)
  101. }
  102. }
  103. // Call startWatcher for context support (From Go 1.8)
  104. mc.startWatcher()
  105. if err := mc.watchCancel(ctx); err != nil {
  106. mc.cleanup()
  107. return nil, err
  108. }
  109. defer mc.finish()
  110. mc.buf = newBuffer(mc.netConn)
  111. // Set I/O timeouts
  112. mc.buf.timeout = mc.cfg.ReadTimeout
  113. mc.writeTimeout = mc.cfg.WriteTimeout
  114. // Reading Handshake Initialization Packet
  115. authData, plugin, err := mc.readHandshakePacket()
  116. if err != nil {
  117. mc.cleanup()
  118. return nil, err
  119. }
  120. if plugin == "" {
  121. plugin = defaultAuthPlugin
  122. }
  123. // Send Client Authentication Packet
  124. authResp, err := mc.auth(authData, plugin)
  125. if err != nil {
  126. // try the default auth plugin, if using the requested plugin failed
  127. c.cfg.Logger.Print("could not use requested auth plugin '"+plugin+"': ", err.Error())
  128. plugin = defaultAuthPlugin
  129. authResp, err = mc.auth(authData, plugin)
  130. if err != nil {
  131. mc.cleanup()
  132. return nil, err
  133. }
  134. }
  135. if err = mc.writeHandshakeResponsePacket(authResp, plugin); err != nil {
  136. mc.cleanup()
  137. return nil, err
  138. }
  139. // Handle response to auth packet, switch methods if possible
  140. if err = mc.handleAuthResult(authData, plugin); err != nil {
  141. // Authentication failed and MySQL has already closed the connection
  142. // (https://dev.mysql.com/doc/internals/en/authentication-fails.html).
  143. // Do not send COM_QUIT, just cleanup and return the error.
  144. mc.cleanup()
  145. return nil, err
  146. }
  147. if mc.cfg.MaxAllowedPacket > 0 {
  148. mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket
  149. } else {
  150. // Get max allowed packet size
  151. maxap, err := mc.getSystemVar("max_allowed_packet")
  152. if err != nil {
  153. mc.Close()
  154. return nil, err
  155. }
  156. mc.maxAllowedPacket = stringToInt(maxap) - 1
  157. }
  158. if mc.maxAllowedPacket < maxPacketSize {
  159. mc.maxWriteSize = mc.maxAllowedPacket
  160. }
  161. // Handle DSN Params
  162. err = mc.handleParams()
  163. if err != nil {
  164. mc.Close()
  165. return nil, err
  166. }
  167. return mc, nil
  168. }
  169. // Driver implements driver.Connector interface.
  170. // Driver returns &MySQLDriver{}.
  171. func (c *connector) Driver() driver.Driver {
  172. return &MySQLDriver{}
  173. }