statement.go 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220
  1. // Go MySQL Driver - A MySQL-Driver for Go's database/sql package
  2. //
  3. // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
  4. //
  5. // This Source Code Form is subject to the terms of the Mozilla Public
  6. // License, v. 2.0. If a copy of the MPL was not distributed with this file,
  7. // You can obtain one at http://mozilla.org/MPL/2.0/.
  8. package mysql
  9. import (
  10. "database/sql/driver"
  11. "encoding/json"
  12. "fmt"
  13. "io"
  14. "reflect"
  15. )
  16. type mysqlStmt struct {
  17. mc *mysqlConn
  18. id uint32
  19. paramCount int
  20. }
  21. func (stmt *mysqlStmt) Close() error {
  22. if stmt.mc == nil || stmt.mc.closed.Load() {
  23. // driver.Stmt.Close can be called more than once, thus this function
  24. // has to be idempotent.
  25. // See also Issue #450 and golang/go#16019.
  26. //errLog.Print(ErrInvalidConn)
  27. return driver.ErrBadConn
  28. }
  29. err := stmt.mc.writeCommandPacketUint32(comStmtClose, stmt.id)
  30. stmt.mc = nil
  31. return err
  32. }
  33. func (stmt *mysqlStmt) NumInput() int {
  34. return stmt.paramCount
  35. }
  36. func (stmt *mysqlStmt) ColumnConverter(idx int) driver.ValueConverter {
  37. return converter{}
  38. }
  39. func (stmt *mysqlStmt) CheckNamedValue(nv *driver.NamedValue) (err error) {
  40. nv.Value, err = converter{}.ConvertValue(nv.Value)
  41. return
  42. }
  43. func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
  44. if stmt.mc.closed.Load() {
  45. errLog.Print(ErrInvalidConn)
  46. return nil, driver.ErrBadConn
  47. }
  48. // Send command
  49. err := stmt.writeExecutePacket(args)
  50. if err != nil {
  51. return nil, stmt.mc.markBadConn(err)
  52. }
  53. mc := stmt.mc
  54. mc.affectedRows = 0
  55. mc.insertId = 0
  56. // Read Result
  57. resLen, err := mc.readResultSetHeaderPacket()
  58. if err != nil {
  59. return nil, err
  60. }
  61. if resLen > 0 {
  62. // Columns
  63. if err = mc.readUntilEOF(); err != nil {
  64. return nil, err
  65. }
  66. // Rows
  67. if err := mc.readUntilEOF(); err != nil {
  68. return nil, err
  69. }
  70. }
  71. if err := mc.discardResults(); err != nil {
  72. return nil, err
  73. }
  74. return &mysqlResult{
  75. affectedRows: int64(mc.affectedRows),
  76. insertId: int64(mc.insertId),
  77. }, nil
  78. }
  79. func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
  80. return stmt.query(args)
  81. }
  82. func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) {
  83. if stmt.mc.closed.Load() {
  84. errLog.Print(ErrInvalidConn)
  85. return nil, driver.ErrBadConn
  86. }
  87. // Send command
  88. err := stmt.writeExecutePacket(args)
  89. if err != nil {
  90. return nil, stmt.mc.markBadConn(err)
  91. }
  92. mc := stmt.mc
  93. // Read Result
  94. resLen, err := mc.readResultSetHeaderPacket()
  95. if err != nil {
  96. return nil, err
  97. }
  98. rows := new(binaryRows)
  99. if resLen > 0 {
  100. rows.mc = mc
  101. rows.rs.columns, err = mc.readColumns(resLen)
  102. } else {
  103. rows.rs.done = true
  104. switch err := rows.NextResultSet(); err {
  105. case nil, io.EOF:
  106. return rows, nil
  107. default:
  108. return nil, err
  109. }
  110. }
  111. return rows, err
  112. }
  113. var jsonType = reflect.TypeOf(json.RawMessage{})
  114. type converter struct{}
  115. // ConvertValue mirrors the reference/default converter in database/sql/driver
  116. // with _one_ exception. We support uint64 with their high bit and the default
  117. // implementation does not. This function should be kept in sync with
  118. // database/sql/driver defaultConverter.ConvertValue() except for that
  119. // deliberate difference.
  120. func (c converter) ConvertValue(v interface{}) (driver.Value, error) {
  121. if driver.IsValue(v) {
  122. return v, nil
  123. }
  124. if vr, ok := v.(driver.Valuer); ok {
  125. sv, err := callValuerValue(vr)
  126. if err != nil {
  127. return nil, err
  128. }
  129. if driver.IsValue(sv) {
  130. return sv, nil
  131. }
  132. // A value returned from the Valuer interface can be "a type handled by
  133. // a database driver's NamedValueChecker interface" so we should accept
  134. // uint64 here as well.
  135. if u, ok := sv.(uint64); ok {
  136. return u, nil
  137. }
  138. return nil, fmt.Errorf("non-Value type %T returned from Value", sv)
  139. }
  140. rv := reflect.ValueOf(v)
  141. switch rv.Kind() {
  142. case reflect.Ptr:
  143. // indirect pointers
  144. if rv.IsNil() {
  145. return nil, nil
  146. } else {
  147. return c.ConvertValue(rv.Elem().Interface())
  148. }
  149. case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
  150. return rv.Int(), nil
  151. case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
  152. return rv.Uint(), nil
  153. case reflect.Float32, reflect.Float64:
  154. return rv.Float(), nil
  155. case reflect.Bool:
  156. return rv.Bool(), nil
  157. case reflect.Slice:
  158. switch t := rv.Type(); {
  159. case t == jsonType:
  160. return v, nil
  161. case t.Elem().Kind() == reflect.Uint8:
  162. return rv.Bytes(), nil
  163. default:
  164. return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, t.Elem().Kind())
  165. }
  166. case reflect.String:
  167. return rv.String(), nil
  168. }
  169. return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind())
  170. }
  171. var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
  172. // callValuerValue returns vr.Value(), with one exception:
  173. // If vr.Value is an auto-generated method on a pointer type and the
  174. // pointer is nil, it would panic at runtime in the panicwrap
  175. // method. Treat it like nil instead.
  176. //
  177. // This is so people can implement driver.Value on value types and
  178. // still use nil pointers to those types to mean nil/NULL, just like
  179. // string/*string.
  180. //
  181. // This is an exact copy of the same-named unexported function from the
  182. // database/sql package.
  183. func callValuerValue(vr driver.Valuer) (v driver.Value, err error) {
  184. if rv := reflect.ValueOf(vr); rv.Kind() == reflect.Ptr &&
  185. rv.IsNil() &&
  186. rv.Type().Elem().Implements(valuerReflectType) {
  187. return nil, nil
  188. }
  189. return vr.Value()
  190. }