sql.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. // Copyright 2018 The Xorm Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package builder
  5. import (
  6. sql2 "database/sql"
  7. "fmt"
  8. "reflect"
  9. "strings"
  10. "time"
  11. )
  12. func condToSQL(cond Cond) (string, []interface{}, error) {
  13. if cond == nil || !cond.IsValid() {
  14. return "", nil, nil
  15. }
  16. w := NewWriter()
  17. if err := cond.WriteTo(w); err != nil {
  18. return "", nil, err
  19. }
  20. return w.String(), w.args, nil
  21. }
  22. func condToBoundSQL(cond Cond) (string, error) {
  23. if cond == nil || !cond.IsValid() {
  24. return "", nil
  25. }
  26. w := NewWriter()
  27. if err := cond.WriteTo(w); err != nil {
  28. return "", err
  29. }
  30. return ConvertToBoundSQL(w.String(), w.args)
  31. }
  32. // ToSQL convert a builder or conditions to SQL and args
  33. func ToSQL(cond interface{}) (string, []interface{}, error) {
  34. switch cond.(type) {
  35. case Cond:
  36. return condToSQL(cond.(Cond))
  37. case *Builder:
  38. return cond.(*Builder).ToSQL()
  39. }
  40. return "", nil, ErrNotSupportType
  41. }
  42. // ToBoundSQL convert a builder or conditions to parameters bound SQL
  43. func ToBoundSQL(cond interface{}) (string, error) {
  44. switch cond.(type) {
  45. case Cond:
  46. return condToBoundSQL(cond.(Cond))
  47. case *Builder:
  48. return cond.(*Builder).ToBoundSQL()
  49. }
  50. return "", ErrNotSupportType
  51. }
  // noSQLQuoteNeeded reports whether a can be written into SQL text verbatim,
  // without surrounding single quotes. It returns true when the underlying
  // kind of a is a signed or unsigned integer (excluding uintptr), a float or
  // a bool, and false for everything else (strings, time.Time, pointers,
  // structs, ...) and for a nil interface.
  //
  // The decision is made on reflect.Kind, so named types such as `type ID int64`
  // behave like their underlying builtin type.
  func noSQLQuoteNeeded(a interface{}) bool {
  	// reflect.TypeOf(nil) returns a nil Type, and calling Kind on it panics.
  	if a == nil {
  		return false
  	}
  	switch reflect.TypeOf(a).Kind() {
  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
  		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
  		reflect.Float32, reflect.Float64,
  		reflect.Bool:
  		return true
  	default:
  		return false
  	}
  }
  82. // ConvertToBoundSQL will convert SQL and args to a bound SQL
  83. func ConvertToBoundSQL(sql string, args []interface{}) (string, error) {
  84. buf := strings.Builder{}
  85. var i, j, start int
  86. for ; i < len(sql); i++ {
  87. if sql[i] == '?' {
  88. _, err := buf.WriteString(sql[start:i])
  89. if err != nil {
  90. return "", err
  91. }
  92. start = i + 1
  93. if len(args) == j {
  94. return "", ErrNeedMoreArguments
  95. }
  96. arg := args[j]
  97. if namedArg, ok := arg.(sql2.NamedArg); ok {
  98. arg = namedArg.Value
  99. }
  100. if noSQLQuoteNeeded(arg) {
  101. _, err = fmt.Fprint(&buf, arg)
  102. } else {
  103. // replace ' -> '' (standard replacement) to avoid critical SQL injection,
  104. // NOTICE: may allow some injection like % (or _) in LIKE query
  105. _, err = fmt.Fprintf(&buf, "'%v'", strings.Replace(fmt.Sprintf("%v", arg), "'",
  106. "''", -1))
  107. }
  108. if err != nil {
  109. return "", err
  110. }
  111. j = j + 1
  112. }
  113. }
  114. _, err := buf.WriteString(sql[start:])
  115. if err != nil {
  116. return "", err
  117. }
  118. return buf.String(), nil
  119. }
  // ConvertPlaceholder rewrites every '?' in sql as prefix followed by the
  // placeholder's 1-based position. For example, prefix "$" yields $1, $2, ...
  // and prefix ":" yields :1, :2, ....
  //
  // Every '?' byte is rewritten, including one inside a quoted literal. The
  // returned error is always nil.
  func ConvertPlaceholder(sql, prefix string) (string, error) {
  	var out strings.Builder
  	rest := sql
  	for n := 1; ; n++ {
  		idx := strings.IndexByte(rest, '?')
  		if idx < 0 {
  			break
  		}
  		// strings.Builder writes cannot fail.
  		out.WriteString(rest[:idx])
  		fmt.Fprintf(&out, "%v%d", prefix, n)
  		rest = rest[idx+1:]
  	}
  	out.WriteString(rest)
  	return out.String(), nil
  }