123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197 |
- // Go MySQL Driver - A MySQL-Driver for Go's database/sql package
- //
- // Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved.
- //
- // This Source Code Form is subject to the terms of the Mozilla Public
- // License, v. 2.0. If a copy of the MPL was not distributed with this file,
- // You can obtain one at http://mozilla.org/MPL/2.0/.
- package mysql
- import (
- "context"
- "database/sql/driver"
- "net"
- "os"
- "strconv"
- "strings"
- )
- type connector struct {
- cfg *Config // immutable private copy.
- encodedAttributes string // Encoded connection attributes.
- }
- func encodeConnectionAttributes(cfg *Config) string {
- connAttrsBuf := make([]byte, 0)
- // default connection attributes
- connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrClientName)
- connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrClientNameValue)
- connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrOS)
- connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrOSValue)
- connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPlatform)
- connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPlatformValue)
- connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrPid)
- connAttrsBuf = appendLengthEncodedString(connAttrsBuf, strconv.Itoa(os.Getpid()))
- serverHost, _, _ := net.SplitHostPort(cfg.Addr)
- if serverHost != "" {
- connAttrsBuf = appendLengthEncodedString(connAttrsBuf, connAttrServerHost)
- connAttrsBuf = appendLengthEncodedString(connAttrsBuf, serverHost)
- }
- // user-defined connection attributes
- for _, connAttr := range strings.Split(cfg.ConnectionAttributes, ",") {
- k, v, found := strings.Cut(connAttr, ":")
- if !found {
- continue
- }
- connAttrsBuf = appendLengthEncodedString(connAttrsBuf, k)
- connAttrsBuf = appendLengthEncodedString(connAttrsBuf, v)
- }
- return string(connAttrsBuf)
- }
- func newConnector(cfg *Config) *connector {
- encodedAttributes := encodeConnectionAttributes(cfg)
- return &connector{
- cfg: cfg,
- encodedAttributes: encodedAttributes,
- }
- }
- // Connect implements driver.Connector interface.
- // Connect returns a connection to the database.
- func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
- var err error
- // Invoke beforeConnect if present, with a copy of the configuration
- cfg := c.cfg
- if c.cfg.beforeConnect != nil {
- cfg = c.cfg.Clone()
- err = c.cfg.beforeConnect(ctx, cfg)
- if err != nil {
- return nil, err
- }
- }
- // New mysqlConn
- mc := &mysqlConn{
- maxAllowedPacket: maxPacketSize,
- maxWriteSize: maxPacketSize - 1,
- closech: make(chan struct{}),
- cfg: cfg,
- connector: c,
- }
- mc.parseTime = mc.cfg.ParseTime
- // Connect to Server
- dialsLock.RLock()
- dial, ok := dials[mc.cfg.Net]
- dialsLock.RUnlock()
- if ok {
- dctx := ctx
- if mc.cfg.Timeout > 0 {
- var cancel context.CancelFunc
- dctx, cancel = context.WithTimeout(ctx, c.cfg.Timeout)
- defer cancel()
- }
- mc.netConn, err = dial(dctx, mc.cfg.Addr)
- } else {
- nd := net.Dialer{Timeout: mc.cfg.Timeout}
- mc.netConn, err = nd.DialContext(ctx, mc.cfg.Net, mc.cfg.Addr)
- }
- if err != nil {
- return nil, err
- }
- mc.rawConn = mc.netConn
- // Enable TCP Keepalives on TCP connections
- if tc, ok := mc.netConn.(*net.TCPConn); ok {
- if err := tc.SetKeepAlive(true); err != nil {
- c.cfg.Logger.Print(err)
- }
- }
- // Call startWatcher for context support (From Go 1.8)
- mc.startWatcher()
- if err := mc.watchCancel(ctx); err != nil {
- mc.cleanup()
- return nil, err
- }
- defer mc.finish()
- mc.buf = newBuffer(mc.netConn)
- // Set I/O timeouts
- mc.buf.timeout = mc.cfg.ReadTimeout
- mc.writeTimeout = mc.cfg.WriteTimeout
- // Reading Handshake Initialization Packet
- authData, plugin, err := mc.readHandshakePacket()
- if err != nil {
- mc.cleanup()
- return nil, err
- }
- if plugin == "" {
- plugin = defaultAuthPlugin
- }
- // Send Client Authentication Packet
- authResp, err := mc.auth(authData, plugin)
- if err != nil {
- // try the default auth plugin, if using the requested plugin failed
- c.cfg.Logger.Print("could not use requested auth plugin '"+plugin+"': ", err.Error())
- plugin = defaultAuthPlugin
- authResp, err = mc.auth(authData, plugin)
- if err != nil {
- mc.cleanup()
- return nil, err
- }
- }
- if err = mc.writeHandshakeResponsePacket(authResp, plugin); err != nil {
- mc.cleanup()
- return nil, err
- }
- // Handle response to auth packet, switch methods if possible
- if err = mc.handleAuthResult(authData, plugin); err != nil {
- // Authentication failed and MySQL has already closed the connection
- // (https://dev.mysql.com/doc/internals/en/authentication-fails.html).
- // Do not send COM_QUIT, just cleanup and return the error.
- mc.cleanup()
- return nil, err
- }
- if mc.cfg.MaxAllowedPacket > 0 {
- mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket
- } else {
- // Get max allowed packet size
- maxap, err := mc.getSystemVar("max_allowed_packet")
- if err != nil {
- mc.Close()
- return nil, err
- }
- mc.maxAllowedPacket = stringToInt(maxap) - 1
- }
- if mc.maxAllowedPacket < maxPacketSize {
- mc.maxWriteSize = mc.maxAllowedPacket
- }
- // Handle DSN Params
- err = mc.handleParams()
- if err != nil {
- mc.Close()
- return nil, err
- }
- return mc, nil
- }
- // Driver implements driver.Connector interface.
- // Driver returns &MySQLDriver{}.
- func (c *connector) Driver() driver.Driver {
- return &MySQLDriver{}
- }
|