statement.go 34 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252
  1. // Copyright 2015 The Xorm Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package xorm
  5. import (
  6. "database/sql/driver"
  7. "fmt"
  8. "reflect"
  9. "strings"
  10. "time"
  11. "xorm.io/builder"
  12. "xorm.io/core"
  13. )
  14. // Statement save all the sql info for executing SQL
  15. type Statement struct {
  16. RefTable *core.Table
  17. Engine *Engine
  18. Start int
  19. LimitN int
  20. idParam *core.PK
  21. OrderStr string
  22. JoinStr string
  23. joinArgs []interface{}
  24. GroupByStr string
  25. HavingStr string
  26. ColumnStr string
  27. selectStr string
  28. useAllCols bool
  29. OmitStr string
  30. AltTableName string
  31. tableName string
  32. RawSQL string
  33. RawParams []interface{}
  34. UseCascade bool
  35. UseAutoJoin bool
  36. StoreEngine string
  37. Charset string
  38. UseCache bool
  39. UseAutoTime bool
  40. noAutoCondition bool
  41. IsDistinct bool
  42. IsForUpdate bool
  43. TableAlias string
  44. allUseBool bool
  45. checkVersion bool
  46. unscoped bool
  47. columnMap columnMap
  48. omitColumnMap columnMap
  49. mustColumnMap map[string]bool
  50. nullableMap map[string]bool
  51. incrColumns exprParams
  52. decrColumns exprParams
  53. exprColumns exprParams
  54. cond builder.Cond
  55. bufferSize int
  56. context ContextCache
  57. lastError error
  58. }
  59. // Init reset all the statement's fields
  60. func (statement *Statement) Init() {
  61. statement.RefTable = nil
  62. statement.Start = 0
  63. statement.LimitN = 0
  64. statement.OrderStr = ""
  65. statement.UseCascade = true
  66. statement.JoinStr = ""
  67. statement.joinArgs = make([]interface{}, 0)
  68. statement.GroupByStr = ""
  69. statement.HavingStr = ""
  70. statement.ColumnStr = ""
  71. statement.OmitStr = ""
  72. statement.columnMap = columnMap{}
  73. statement.omitColumnMap = columnMap{}
  74. statement.AltTableName = ""
  75. statement.tableName = ""
  76. statement.idParam = nil
  77. statement.RawSQL = ""
  78. statement.RawParams = make([]interface{}, 0)
  79. statement.UseCache = true
  80. statement.UseAutoTime = true
  81. statement.noAutoCondition = false
  82. statement.IsDistinct = false
  83. statement.IsForUpdate = false
  84. statement.TableAlias = ""
  85. statement.selectStr = ""
  86. statement.allUseBool = false
  87. statement.useAllCols = false
  88. statement.mustColumnMap = make(map[string]bool)
  89. statement.nullableMap = make(map[string]bool)
  90. statement.checkVersion = true
  91. statement.unscoped = false
  92. statement.incrColumns = exprParams{}
  93. statement.decrColumns = exprParams{}
  94. statement.exprColumns = exprParams{}
  95. statement.cond = builder.NewCond()
  96. statement.bufferSize = 0
  97. statement.context = nil
  98. statement.lastError = nil
  99. }
  100. // NoAutoCondition if you do not want convert bean's field as query condition, then use this function
  101. func (statement *Statement) NoAutoCondition(no ...bool) *Statement {
  102. statement.noAutoCondition = true
  103. if len(no) > 0 {
  104. statement.noAutoCondition = no[0]
  105. }
  106. return statement
  107. }
  108. // Alias set the table alias
  109. func (statement *Statement) Alias(alias string) *Statement {
  110. statement.TableAlias = alias
  111. return statement
  112. }
  113. // SQL adds raw sql statement
  114. func (statement *Statement) SQL(query interface{}, args ...interface{}) *Statement {
  115. switch query.(type) {
  116. case (*builder.Builder):
  117. var err error
  118. statement.RawSQL, statement.RawParams, err = query.(*builder.Builder).ToSQL()
  119. if err != nil {
  120. statement.lastError = err
  121. }
  122. case string:
  123. statement.RawSQL = query.(string)
  124. statement.RawParams = args
  125. default:
  126. statement.lastError = ErrUnSupportedSQLType
  127. }
  128. return statement
  129. }
  130. // Where add Where statement
  131. func (statement *Statement) Where(query interface{}, args ...interface{}) *Statement {
  132. return statement.And(query, args...)
  133. }
  134. // And add Where & and statement
  135. func (statement *Statement) And(query interface{}, args ...interface{}) *Statement {
  136. switch query.(type) {
  137. case string:
  138. cond := builder.Expr(query.(string), args...)
  139. statement.cond = statement.cond.And(cond)
  140. case map[string]interface{}:
  141. cond := builder.Eq(query.(map[string]interface{}))
  142. statement.cond = statement.cond.And(cond)
  143. case builder.Cond:
  144. cond := query.(builder.Cond)
  145. statement.cond = statement.cond.And(cond)
  146. for _, v := range args {
  147. if vv, ok := v.(builder.Cond); ok {
  148. statement.cond = statement.cond.And(vv)
  149. }
  150. }
  151. default:
  152. statement.lastError = ErrConditionType
  153. }
  154. return statement
  155. }
  156. // Or add Where & Or statement
  157. func (statement *Statement) Or(query interface{}, args ...interface{}) *Statement {
  158. switch query.(type) {
  159. case string:
  160. cond := builder.Expr(query.(string), args...)
  161. statement.cond = statement.cond.Or(cond)
  162. case map[string]interface{}:
  163. cond := builder.Eq(query.(map[string]interface{}))
  164. statement.cond = statement.cond.Or(cond)
  165. case builder.Cond:
  166. cond := query.(builder.Cond)
  167. statement.cond = statement.cond.Or(cond)
  168. for _, v := range args {
  169. if vv, ok := v.(builder.Cond); ok {
  170. statement.cond = statement.cond.Or(vv)
  171. }
  172. }
  173. default:
  174. // TODO: not support condition type
  175. }
  176. return statement
  177. }
  178. // In generate "Where column IN (?) " statement
  179. func (statement *Statement) In(column string, args ...interface{}) *Statement {
  180. in := builder.In(statement.Engine.Quote(column), args...)
  181. statement.cond = statement.cond.And(in)
  182. return statement
  183. }
  184. // NotIn generate "Where column NOT IN (?) " statement
  185. func (statement *Statement) NotIn(column string, args ...interface{}) *Statement {
  186. notIn := builder.NotIn(statement.Engine.Quote(column), args...)
  187. statement.cond = statement.cond.And(notIn)
  188. return statement
  189. }
  190. func (statement *Statement) setRefValue(v reflect.Value) error {
  191. var err error
  192. statement.RefTable, err = statement.Engine.autoMapType(reflect.Indirect(v))
  193. if err != nil {
  194. return err
  195. }
  196. statement.tableName = statement.Engine.TableName(v, true)
  197. return nil
  198. }
  199. func (statement *Statement) setRefBean(bean interface{}) error {
  200. var err error
  201. statement.RefTable, err = statement.Engine.autoMapType(rValue(bean))
  202. if err != nil {
  203. return err
  204. }
  205. statement.tableName = statement.Engine.TableName(bean, true)
  206. return nil
  207. }
  208. // Auto generating update columnes and values according a struct
  209. func (statement *Statement) buildUpdates(bean interface{},
  210. includeVersion, includeUpdated, includeNil,
  211. includeAutoIncr, update bool) ([]string, []interface{}) {
  212. engine := statement.Engine
  213. table := statement.RefTable
  214. allUseBool := statement.allUseBool
  215. useAllCols := statement.useAllCols
  216. mustColumnMap := statement.mustColumnMap
  217. nullableMap := statement.nullableMap
  218. columnMap := statement.columnMap
  219. omitColumnMap := statement.omitColumnMap
  220. unscoped := statement.unscoped
  221. var colNames = make([]string, 0)
  222. var args = make([]interface{}, 0)
  223. for _, col := range table.Columns() {
  224. if !includeVersion && col.IsVersion {
  225. continue
  226. }
  227. if col.IsCreated {
  228. continue
  229. }
  230. if !includeUpdated && col.IsUpdated {
  231. continue
  232. }
  233. if !includeAutoIncr && col.IsAutoIncrement {
  234. continue
  235. }
  236. if col.IsDeleted && !unscoped {
  237. continue
  238. }
  239. if omitColumnMap.contain(col.Name) {
  240. continue
  241. }
  242. if len(columnMap) > 0 && !columnMap.contain(col.Name) {
  243. continue
  244. }
  245. if col.MapType == core.ONLYFROMDB {
  246. continue
  247. }
  248. if statement.incrColumns.isColExist(col.Name) {
  249. continue
  250. } else if statement.decrColumns.isColExist(col.Name) {
  251. continue
  252. } else if statement.exprColumns.isColExist(col.Name) {
  253. continue
  254. }
  255. fieldValuePtr, err := col.ValueOf(bean)
  256. if err != nil {
  257. engine.logger.Error(err)
  258. continue
  259. }
  260. fieldValue := *fieldValuePtr
  261. fieldType := reflect.TypeOf(fieldValue.Interface())
  262. if fieldType == nil {
  263. continue
  264. }
  265. requiredField := useAllCols
  266. includeNil := useAllCols
  267. if b, ok := getFlagForColumn(mustColumnMap, col); ok {
  268. if b {
  269. requiredField = true
  270. } else {
  271. continue
  272. }
  273. }
  274. // !evalphobia! set fieldValue as nil when column is nullable and zero-value
  275. if b, ok := getFlagForColumn(nullableMap, col); ok {
  276. if b && col.Nullable && isZero(fieldValue.Interface()) {
  277. var nilValue *int
  278. fieldValue = reflect.ValueOf(nilValue)
  279. fieldType = reflect.TypeOf(fieldValue.Interface())
  280. includeNil = true
  281. }
  282. }
  283. var val interface{}
  284. if fieldValue.CanAddr() {
  285. if structConvert, ok := fieldValue.Addr().Interface().(core.Conversion); ok {
  286. data, err := structConvert.ToDB()
  287. if err != nil {
  288. engine.logger.Error(err)
  289. } else {
  290. val = data
  291. }
  292. goto APPEND
  293. }
  294. }
  295. if structConvert, ok := fieldValue.Interface().(core.Conversion); ok {
  296. data, err := structConvert.ToDB()
  297. if err != nil {
  298. engine.logger.Error(err)
  299. } else {
  300. val = data
  301. }
  302. goto APPEND
  303. }
  304. if fieldType.Kind() == reflect.Ptr {
  305. if fieldValue.IsNil() {
  306. if includeNil {
  307. args = append(args, nil)
  308. colNames = append(colNames, fmt.Sprintf("%v=?", engine.Quote(col.Name)))
  309. }
  310. continue
  311. } else if !fieldValue.IsValid() {
  312. continue
  313. } else {
  314. // dereference ptr type to instance type
  315. fieldValue = fieldValue.Elem()
  316. fieldType = reflect.TypeOf(fieldValue.Interface())
  317. requiredField = true
  318. }
  319. }
  320. switch fieldType.Kind() {
  321. case reflect.Bool:
  322. if allUseBool || requiredField {
  323. val = fieldValue.Interface()
  324. } else {
  325. // if a bool in a struct, it will not be as a condition because it default is false,
  326. // please use Where() instead
  327. continue
  328. }
  329. case reflect.String:
  330. if !requiredField && fieldValue.String() == "" {
  331. continue
  332. }
  333. // for MyString, should convert to string or panic
  334. if fieldType.String() != reflect.String.String() {
  335. val = fieldValue.String()
  336. } else {
  337. val = fieldValue.Interface()
  338. }
  339. case reflect.Int8, reflect.Int16, reflect.Int, reflect.Int32, reflect.Int64:
  340. if !requiredField && fieldValue.Int() == 0 {
  341. continue
  342. }
  343. val = fieldValue.Interface()
  344. case reflect.Float32, reflect.Float64:
  345. if !requiredField && fieldValue.Float() == 0.0 {
  346. continue
  347. }
  348. val = fieldValue.Interface()
  349. case reflect.Uint8, reflect.Uint16, reflect.Uint, reflect.Uint32, reflect.Uint64:
  350. if !requiredField && fieldValue.Uint() == 0 {
  351. continue
  352. }
  353. t := int64(fieldValue.Uint())
  354. val = reflect.ValueOf(&t).Interface()
  355. case reflect.Struct:
  356. if fieldType.ConvertibleTo(core.TimeType) {
  357. t := fieldValue.Convert(core.TimeType).Interface().(time.Time)
  358. if !requiredField && (t.IsZero() || !fieldValue.IsValid()) {
  359. continue
  360. }
  361. val = engine.formatColTime(col, t)
  362. } else if nulType, ok := fieldValue.Interface().(driver.Valuer); ok {
  363. val, _ = nulType.Value()
  364. } else {
  365. if !col.SQLType.IsJson() {
  366. engine.autoMapType(fieldValue)
  367. if table, ok := engine.Tables[fieldValue.Type()]; ok {
  368. if len(table.PrimaryKeys) == 1 {
  369. pkField := reflect.Indirect(fieldValue).FieldByName(table.PKColumns()[0].FieldName)
  370. // fix non-int pk issues
  371. if pkField.IsValid() && (!requiredField && !isZero(pkField.Interface())) {
  372. val = pkField.Interface()
  373. } else {
  374. continue
  375. }
  376. } else {
  377. // TODO: how to handler?
  378. panic("not supported")
  379. }
  380. } else {
  381. val = fieldValue.Interface()
  382. }
  383. } else {
  384. // Blank struct could not be as update data
  385. if requiredField || !isStructZero(fieldValue) {
  386. bytes, err := DefaultJSONHandler.Marshal(fieldValue.Interface())
  387. if err != nil {
  388. panic(fmt.Sprintf("mashal %v failed", fieldValue.Interface()))
  389. }
  390. if col.SQLType.IsText() {
  391. val = string(bytes)
  392. } else if col.SQLType.IsBlob() {
  393. val = bytes
  394. }
  395. } else {
  396. continue
  397. }
  398. }
  399. }
  400. case reflect.Array, reflect.Slice, reflect.Map:
  401. if !requiredField {
  402. if fieldValue == reflect.Zero(fieldType) {
  403. continue
  404. }
  405. if fieldType.Kind() == reflect.Array {
  406. if isArrayValueZero(fieldValue) {
  407. continue
  408. }
  409. } else if fieldValue.IsNil() || !fieldValue.IsValid() || fieldValue.Len() == 0 {
  410. continue
  411. }
  412. }
  413. if col.SQLType.IsText() {
  414. bytes, err := DefaultJSONHandler.Marshal(fieldValue.Interface())
  415. if err != nil {
  416. engine.logger.Error(err)
  417. continue
  418. }
  419. val = string(bytes)
  420. } else if col.SQLType.IsBlob() {
  421. var bytes []byte
  422. var err error
  423. if fieldType.Kind() == reflect.Slice &&
  424. fieldType.Elem().Kind() == reflect.Uint8 {
  425. if fieldValue.Len() > 0 {
  426. val = fieldValue.Bytes()
  427. } else {
  428. continue
  429. }
  430. } else if fieldType.Kind() == reflect.Array &&
  431. fieldType.Elem().Kind() == reflect.Uint8 {
  432. val = fieldValue.Slice(0, 0).Interface()
  433. } else {
  434. bytes, err = DefaultJSONHandler.Marshal(fieldValue.Interface())
  435. if err != nil {
  436. engine.logger.Error(err)
  437. continue
  438. }
  439. val = bytes
  440. }
  441. } else {
  442. continue
  443. }
  444. default:
  445. val = fieldValue.Interface()
  446. }
  447. APPEND:
  448. args = append(args, val)
  449. if col.IsPrimaryKey && engine.dialect.DBType() == "ql" {
  450. continue
  451. }
  452. colNames = append(colNames, fmt.Sprintf("%v = ?", engine.Quote(col.Name)))
  453. }
  454. return colNames, args
  455. }
  456. func (statement *Statement) needTableName() bool {
  457. return len(statement.JoinStr) > 0
  458. }
  459. func (statement *Statement) colName(col *core.Column, tableName string) string {
  460. if statement.needTableName() {
  461. var nm = tableName
  462. if len(statement.TableAlias) > 0 {
  463. nm = statement.TableAlias
  464. }
  465. return statement.Engine.Quote(nm) + "." + statement.Engine.Quote(col.Name)
  466. }
  467. return statement.Engine.Quote(col.Name)
  468. }
  469. // TableName return current tableName
  470. func (statement *Statement) TableName() string {
  471. if statement.AltTableName != "" {
  472. return statement.AltTableName
  473. }
  474. return statement.tableName
  475. }
  476. // ID generate "where id = ? " statement or for composite key "where key1 = ? and key2 = ?"
  477. func (statement *Statement) ID(id interface{}) *Statement {
  478. idValue := reflect.ValueOf(id)
  479. idType := reflect.TypeOf(idValue.Interface())
  480. switch idType {
  481. case ptrPkType:
  482. if pkPtr, ok := (id).(*core.PK); ok {
  483. statement.idParam = pkPtr
  484. return statement
  485. }
  486. case pkType:
  487. if pk, ok := (id).(core.PK); ok {
  488. statement.idParam = &pk
  489. return statement
  490. }
  491. }
  492. switch idType.Kind() {
  493. case reflect.String:
  494. statement.idParam = &core.PK{idValue.Convert(reflect.TypeOf("")).Interface()}
  495. return statement
  496. }
  497. statement.idParam = &core.PK{id}
  498. return statement
  499. }
  500. // Incr Generate "Update ... Set column = column + arg" statement
  501. func (statement *Statement) Incr(column string, arg ...interface{}) *Statement {
  502. if len(arg) > 0 {
  503. statement.incrColumns.addParam(column, arg[0])
  504. } else {
  505. statement.incrColumns.addParam(column, 1)
  506. }
  507. return statement
  508. }
  509. // Decr Generate "Update ... Set column = column - arg" statement
  510. func (statement *Statement) Decr(column string, arg ...interface{}) *Statement {
  511. if len(arg) > 0 {
  512. statement.decrColumns.addParam(column, arg[0])
  513. } else {
  514. statement.decrColumns.addParam(column, 1)
  515. }
  516. return statement
  517. }
  518. // SetExpr Generate "Update ... Set column = {expression}" statement
  519. func (statement *Statement) SetExpr(column string, expression interface{}) *Statement {
  520. statement.exprColumns.addParam(column, expression)
  521. return statement
  522. }
  523. func (statement *Statement) col2NewColsWithQuote(columns ...string) []string {
  524. newColumns := make([]string, 0)
  525. quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`")
  526. for _, col := range columns {
  527. newColumns = append(newColumns, statement.Engine.Quote(eraseAny(col, quotes...)))
  528. }
  529. return newColumns
  530. }
  531. func (statement *Statement) colmap2NewColsWithQuote() []string {
  532. newColumns := make([]string, len(statement.columnMap), len(statement.columnMap))
  533. copy(newColumns, statement.columnMap)
  534. for i := 0; i < len(statement.columnMap); i++ {
  535. newColumns[i] = statement.Engine.Quote(newColumns[i])
  536. }
  537. return newColumns
  538. }
  539. // Distinct generates "DISTINCT col1, col2 " statement
  540. func (statement *Statement) Distinct(columns ...string) *Statement {
  541. statement.IsDistinct = true
  542. statement.Cols(columns...)
  543. return statement
  544. }
  545. // ForUpdate generates "SELECT ... FOR UPDATE" statement
  546. func (statement *Statement) ForUpdate() *Statement {
  547. statement.IsForUpdate = true
  548. return statement
  549. }
  550. // Select replace select
  551. func (statement *Statement) Select(str string) *Statement {
  552. statement.selectStr = str
  553. return statement
  554. }
  555. // Cols generate "col1, col2" statement
  556. func (statement *Statement) Cols(columns ...string) *Statement {
  557. cols := col2NewCols(columns...)
  558. for _, nc := range cols {
  559. statement.columnMap.add(nc)
  560. }
  561. newColumns := statement.colmap2NewColsWithQuote()
  562. statement.ColumnStr = strings.Join(newColumns, ", ")
  563. statement.ColumnStr = strings.Replace(statement.ColumnStr, statement.Engine.quote("*"), "*", -1)
  564. return statement
  565. }
  566. // AllCols update use only: update all columns
  567. func (statement *Statement) AllCols() *Statement {
  568. statement.useAllCols = true
  569. return statement
  570. }
  571. // MustCols update use only: must update columns
  572. func (statement *Statement) MustCols(columns ...string) *Statement {
  573. newColumns := col2NewCols(columns...)
  574. for _, nc := range newColumns {
  575. statement.mustColumnMap[strings.ToLower(nc)] = true
  576. }
  577. return statement
  578. }
  579. // UseBool indicates that use bool fields as update contents and query contiditions
  580. func (statement *Statement) UseBool(columns ...string) *Statement {
  581. if len(columns) > 0 {
  582. statement.MustCols(columns...)
  583. } else {
  584. statement.allUseBool = true
  585. }
  586. return statement
  587. }
  588. // Omit do not use the columns
  589. func (statement *Statement) Omit(columns ...string) {
  590. newColumns := col2NewCols(columns...)
  591. for _, nc := range newColumns {
  592. statement.omitColumnMap = append(statement.omitColumnMap, nc)
  593. }
  594. statement.OmitStr = statement.Engine.Quote(strings.Join(newColumns, statement.Engine.Quote(", ")))
  595. }
  596. // Nullable Update use only: update columns to null when value is nullable and zero-value
  597. func (statement *Statement) Nullable(columns ...string) {
  598. newColumns := col2NewCols(columns...)
  599. for _, nc := range newColumns {
  600. statement.nullableMap[strings.ToLower(nc)] = true
  601. }
  602. }
  603. // Top generate LIMIT limit statement
  604. func (statement *Statement) Top(limit int) *Statement {
  605. statement.Limit(limit)
  606. return statement
  607. }
  608. // Limit generate LIMIT start, limit statement
  609. func (statement *Statement) Limit(limit int, start ...int) *Statement {
  610. statement.LimitN = limit
  611. if len(start) > 0 {
  612. statement.Start = start[0]
  613. }
  614. return statement
  615. }
  616. // OrderBy generate "Order By order" statement
  617. func (statement *Statement) OrderBy(order string) *Statement {
  618. if len(statement.OrderStr) > 0 {
  619. statement.OrderStr += ", "
  620. }
  621. statement.OrderStr += order
  622. return statement
  623. }
  624. // Desc generate `ORDER BY xx DESC`
  625. func (statement *Statement) Desc(colNames ...string) *Statement {
  626. var buf strings.Builder
  627. if len(statement.OrderStr) > 0 {
  628. fmt.Fprint(&buf, statement.OrderStr, ", ")
  629. }
  630. newColNames := statement.col2NewColsWithQuote(colNames...)
  631. fmt.Fprintf(&buf, "%v DESC", strings.Join(newColNames, " DESC, "))
  632. statement.OrderStr = buf.String()
  633. return statement
  634. }
  635. // Asc provide asc order by query condition, the input parameters are columns.
  636. func (statement *Statement) Asc(colNames ...string) *Statement {
  637. var buf strings.Builder
  638. if len(statement.OrderStr) > 0 {
  639. fmt.Fprint(&buf, statement.OrderStr, ", ")
  640. }
  641. newColNames := statement.col2NewColsWithQuote(colNames...)
  642. fmt.Fprintf(&buf, "%v ASC", strings.Join(newColNames, " ASC, "))
  643. statement.OrderStr = buf.String()
  644. return statement
  645. }
  646. // Table tempororily set table name, the parameter could be a string or a pointer of struct
  647. func (statement *Statement) Table(tableNameOrBean interface{}) *Statement {
  648. v := rValue(tableNameOrBean)
  649. t := v.Type()
  650. if t.Kind() == reflect.Struct {
  651. var err error
  652. statement.RefTable, err = statement.Engine.autoMapType(v)
  653. if err != nil {
  654. statement.Engine.logger.Error(err)
  655. return statement
  656. }
  657. }
  658. statement.AltTableName = statement.Engine.TableName(tableNameOrBean, true)
  659. return statement
  660. }
  661. // Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN
  662. func (statement *Statement) Join(joinOP string, tablename interface{}, condition string, args ...interface{}) *Statement {
  663. var buf strings.Builder
  664. if len(statement.JoinStr) > 0 {
  665. fmt.Fprintf(&buf, "%v %v JOIN ", statement.JoinStr, joinOP)
  666. } else {
  667. fmt.Fprintf(&buf, "%v JOIN ", joinOP)
  668. }
  669. switch tp := tablename.(type) {
  670. case builder.Builder:
  671. subSQL, subQueryArgs, err := tp.ToSQL()
  672. if err != nil {
  673. statement.lastError = err
  674. return statement
  675. }
  676. tbs := strings.Split(tp.TableName(), ".")
  677. quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`")
  678. var aliasName = strings.Trim(tbs[len(tbs)-1], strings.Join(quotes, ""))
  679. fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition)
  680. statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
  681. case *builder.Builder:
  682. subSQL, subQueryArgs, err := tp.ToSQL()
  683. if err != nil {
  684. statement.lastError = err
  685. return statement
  686. }
  687. tbs := strings.Split(tp.TableName(), ".")
  688. quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`")
  689. var aliasName = strings.Trim(tbs[len(tbs)-1], strings.Join(quotes, ""))
  690. fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition)
  691. statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
  692. default:
  693. tbName := statement.Engine.TableName(tablename, true)
  694. fmt.Fprintf(&buf, "%s ON %v", tbName, condition)
  695. }
  696. statement.JoinStr = buf.String()
  697. statement.joinArgs = append(statement.joinArgs, args...)
  698. return statement
  699. }
  700. // GroupBy generate "Group By keys" statement
  701. func (statement *Statement) GroupBy(keys string) *Statement {
  702. statement.GroupByStr = keys
  703. return statement
  704. }
  705. // Having generate "Having conditions" statement
  706. func (statement *Statement) Having(conditions string) *Statement {
  707. statement.HavingStr = fmt.Sprintf("HAVING %v", conditions)
  708. return statement
  709. }
  710. // Unscoped always disable struct tag "deleted"
  711. func (statement *Statement) Unscoped() *Statement {
  712. statement.unscoped = true
  713. return statement
  714. }
  715. func (statement *Statement) genColumnStr() string {
  716. if statement.RefTable == nil {
  717. return ""
  718. }
  719. var buf strings.Builder
  720. columns := statement.RefTable.Columns()
  721. for _, col := range columns {
  722. if statement.omitColumnMap.contain(col.Name) {
  723. continue
  724. }
  725. if len(statement.columnMap) > 0 && !statement.columnMap.contain(col.Name) {
  726. continue
  727. }
  728. if col.MapType == core.ONLYTODB {
  729. continue
  730. }
  731. if buf.Len() != 0 {
  732. buf.WriteString(", ")
  733. }
  734. if statement.JoinStr != "" {
  735. if statement.TableAlias != "" {
  736. buf.WriteString(statement.TableAlias)
  737. } else {
  738. buf.WriteString(statement.TableName())
  739. }
  740. buf.WriteString(".")
  741. }
  742. statement.Engine.QuoteTo(&buf, col.Name)
  743. }
  744. return buf.String()
  745. }
  746. func (statement *Statement) genCreateTableSQL() string {
  747. return statement.Engine.dialect.CreateTableSql(statement.RefTable, statement.TableName(),
  748. statement.StoreEngine, statement.Charset)
  749. }
  750. func (statement *Statement) genIndexSQL() []string {
  751. var sqls []string
  752. tbName := statement.TableName()
  753. for _, index := range statement.RefTable.Indexes {
  754. if index.Type == core.IndexType {
  755. sql := statement.Engine.dialect.CreateIndexSql(tbName, index)
  756. /*idxTBName := strings.Replace(tbName, ".", "_", -1)
  757. idxTBName = strings.Replace(idxTBName, `"`, "", -1)
  758. sql := fmt.Sprintf("CREATE INDEX %v ON %v (%v);", quote(indexName(idxTBName, idxName)),
  759. quote(tbName), quote(strings.Join(index.Cols, quote(","))))*/
  760. sqls = append(sqls, sql)
  761. }
  762. }
  763. return sqls
  764. }
  765. func uniqueName(tableName, uqeName string) string {
  766. return fmt.Sprintf("UQE_%v_%v", tableName, uqeName)
  767. }
  768. func (statement *Statement) genUniqueSQL() []string {
  769. var sqls []string
  770. tbName := statement.TableName()
  771. for _, index := range statement.RefTable.Indexes {
  772. if index.Type == core.UniqueType {
  773. sql := statement.Engine.dialect.CreateIndexSql(tbName, index)
  774. sqls = append(sqls, sql)
  775. }
  776. }
  777. return sqls
  778. }
  779. func (statement *Statement) genDelIndexSQL() []string {
  780. var sqls []string
  781. tbName := statement.TableName()
  782. idxPrefixName := strings.Replace(tbName, `"`, "", -1)
  783. idxPrefixName = strings.Replace(idxPrefixName, `.`, "_", -1)
  784. for idxName, index := range statement.RefTable.Indexes {
  785. var rIdxName string
  786. if index.Type == core.UniqueType {
  787. rIdxName = uniqueName(idxPrefixName, idxName)
  788. } else if index.Type == core.IndexType {
  789. rIdxName = indexName(idxPrefixName, idxName)
  790. }
  791. sql := fmt.Sprintf("DROP INDEX %v", statement.Engine.Quote(statement.Engine.TableName(rIdxName, true)))
  792. if statement.Engine.dialect.IndexOnTable() {
  793. sql += fmt.Sprintf(" ON %v", statement.Engine.Quote(tbName))
  794. }
  795. sqls = append(sqls, sql)
  796. }
  797. return sqls
  798. }
  799. func (statement *Statement) genAddColumnStr(col *core.Column) (string, []interface{}) {
  800. quote := statement.Engine.Quote
  801. sql := fmt.Sprintf("ALTER TABLE %v ADD %v", quote(statement.TableName()),
  802. col.String(statement.Engine.dialect))
  803. if statement.Engine.dialect.DBType() == core.MYSQL && len(col.Comment) > 0 {
  804. sql += " COMMENT '" + col.Comment + "'"
  805. }
  806. sql += ";"
  807. return sql, []interface{}{}
  808. }
  809. func (statement *Statement) buildConds(table *core.Table, bean interface{}, includeVersion bool, includeUpdated bool, includeNil bool, includeAutoIncr bool, addedTableName bool) (builder.Cond, error) {
  810. return statement.Engine.buildConds(table, bean, includeVersion, includeUpdated, includeNil, includeAutoIncr, statement.allUseBool, statement.useAllCols,
  811. statement.unscoped, statement.mustColumnMap, statement.TableName(), statement.TableAlias, addedTableName)
  812. }
  813. func (statement *Statement) mergeConds(bean interface{}) error {
  814. if !statement.noAutoCondition {
  815. var addedTableName = (len(statement.JoinStr) > 0)
  816. autoCond, err := statement.buildConds(statement.RefTable, bean, true, true, false, true, addedTableName)
  817. if err != nil {
  818. return err
  819. }
  820. statement.cond = statement.cond.And(autoCond)
  821. }
  822. if err := statement.processIDParam(); err != nil {
  823. return err
  824. }
  825. return nil
  826. }
  827. func (statement *Statement) genConds(bean interface{}) (string, []interface{}, error) {
  828. if err := statement.mergeConds(bean); err != nil {
  829. return "", nil, err
  830. }
  831. return builder.ToSQL(statement.cond)
  832. }
  833. func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}, error) {
  834. v := rValue(bean)
  835. isStruct := v.Kind() == reflect.Struct
  836. if isStruct {
  837. statement.setRefBean(bean)
  838. }
  839. var columnStr = statement.ColumnStr
  840. if len(statement.selectStr) > 0 {
  841. columnStr = statement.selectStr
  842. } else {
  843. // TODO: always generate column names, not use * even if join
  844. if len(statement.JoinStr) == 0 {
  845. if len(columnStr) == 0 {
  846. if len(statement.GroupByStr) > 0 {
  847. columnStr = statement.Engine.quoteColumns(statement.GroupByStr)
  848. } else {
  849. columnStr = statement.genColumnStr()
  850. }
  851. }
  852. } else {
  853. if len(columnStr) == 0 {
  854. if len(statement.GroupByStr) > 0 {
  855. columnStr = statement.Engine.quoteColumns(statement.GroupByStr)
  856. }
  857. }
  858. }
  859. }
  860. if len(columnStr) == 0 {
  861. columnStr = "*"
  862. }
  863. if isStruct {
  864. if err := statement.mergeConds(bean); err != nil {
  865. return "", nil, err
  866. }
  867. } else {
  868. if err := statement.processIDParam(); err != nil {
  869. return "", nil, err
  870. }
  871. }
  872. condSQL, condArgs, err := builder.ToSQL(statement.cond)
  873. if err != nil {
  874. return "", nil, err
  875. }
  876. sqlStr, err := statement.genSelectSQL(columnStr, condSQL, true, true)
  877. if err != nil {
  878. return "", nil, err
  879. }
  880. return sqlStr, append(statement.joinArgs, condArgs...), nil
  881. }
  882. func (statement *Statement) genCountSQL(beans ...interface{}) (string, []interface{}, error) {
  883. var condSQL string
  884. var condArgs []interface{}
  885. var err error
  886. if len(beans) > 0 {
  887. statement.setRefBean(beans[0])
  888. condSQL, condArgs, err = statement.genConds(beans[0])
  889. } else {
  890. condSQL, condArgs, err = builder.ToSQL(statement.cond)
  891. }
  892. if err != nil {
  893. return "", nil, err
  894. }
  895. var selectSQL = statement.selectStr
  896. if len(selectSQL) <= 0 {
  897. if statement.IsDistinct {
  898. selectSQL = fmt.Sprintf("count(DISTINCT %s)", statement.ColumnStr)
  899. } else {
  900. selectSQL = "count(*)"
  901. }
  902. }
  903. sqlStr, err := statement.genSelectSQL(selectSQL, condSQL, false, false)
  904. if err != nil {
  905. return "", nil, err
  906. }
  907. return sqlStr, append(statement.joinArgs, condArgs...), nil
  908. }
  909. func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (string, []interface{}, error) {
  910. statement.setRefBean(bean)
  911. var sumStrs = make([]string, 0, len(columns))
  912. for _, colName := range columns {
  913. if !strings.Contains(colName, " ") && !strings.Contains(colName, "(") {
  914. colName = statement.Engine.Quote(colName)
  915. }
  916. sumStrs = append(sumStrs, fmt.Sprintf("COALESCE(sum(%s),0)", colName))
  917. }
  918. sumSelect := strings.Join(sumStrs, ", ")
  919. condSQL, condArgs, err := statement.genConds(bean)
  920. if err != nil {
  921. return "", nil, err
  922. }
  923. sqlStr, err := statement.genSelectSQL(sumSelect, condSQL, true, true)
  924. if err != nil {
  925. return "", nil, err
  926. }
  927. return sqlStr, append(statement.joinArgs, condArgs...), nil
  928. }
  929. func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, needOrderBy bool) (string, error) {
  930. var (
  931. distinct string
  932. dialect = statement.Engine.Dialect()
  933. quote = statement.Engine.Quote
  934. fromStr = " FROM "
  935. top, mssqlCondi, whereStr string
  936. )
  937. if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") {
  938. distinct = "DISTINCT "
  939. }
  940. if len(condSQL) > 0 {
  941. whereStr = " WHERE " + condSQL
  942. }
  943. if dialect.DBType() == core.MSSQL && strings.Contains(statement.TableName(), "..") {
  944. fromStr += statement.TableName()
  945. } else {
  946. fromStr += quote(statement.TableName())
  947. }
  948. if statement.TableAlias != "" {
  949. if dialect.DBType() == core.ORACLE {
  950. fromStr += " " + quote(statement.TableAlias)
  951. } else {
  952. fromStr += " AS " + quote(statement.TableAlias)
  953. }
  954. }
  955. if statement.JoinStr != "" {
  956. fromStr = fmt.Sprintf("%v %v", fromStr, statement.JoinStr)
  957. }
  958. if dialect.DBType() == core.MSSQL {
  959. if statement.LimitN > 0 {
  960. top = fmt.Sprintf("TOP %d ", statement.LimitN)
  961. }
  962. if statement.Start > 0 {
  963. var column string
  964. if len(statement.RefTable.PKColumns()) == 0 {
  965. for _, index := range statement.RefTable.Indexes {
  966. if len(index.Cols) == 1 {
  967. column = index.Cols[0]
  968. break
  969. }
  970. }
  971. if len(column) == 0 {
  972. column = statement.RefTable.ColumnsSeq()[0]
  973. }
  974. } else {
  975. column = statement.RefTable.PKColumns()[0].Name
  976. }
  977. if statement.needTableName() {
  978. if len(statement.TableAlias) > 0 {
  979. column = statement.TableAlias + "." + column
  980. } else {
  981. column = statement.TableName() + "." + column
  982. }
  983. }
  984. var orderStr string
  985. if needOrderBy && len(statement.OrderStr) > 0 {
  986. orderStr = " ORDER BY " + statement.OrderStr
  987. }
  988. var groupStr string
  989. if len(statement.GroupByStr) > 0 {
  990. groupStr = " GROUP BY " + statement.GroupByStr
  991. }
  992. mssqlCondi = fmt.Sprintf("(%s NOT IN (SELECT TOP %d %s%s%s%s%s))",
  993. column, statement.Start, column, fromStr, whereStr, orderStr, groupStr)
  994. }
  995. }
  996. var buf strings.Builder
  997. fmt.Fprintf(&buf, "SELECT %v%v%v%v%v", distinct, top, columnStr, fromStr, whereStr)
  998. if len(mssqlCondi) > 0 {
  999. if len(whereStr) > 0 {
  1000. fmt.Fprint(&buf, " AND ", mssqlCondi)
  1001. } else {
  1002. fmt.Fprint(&buf, " WHERE ", mssqlCondi)
  1003. }
  1004. }
  1005. if statement.GroupByStr != "" {
  1006. fmt.Fprint(&buf, " GROUP BY ", statement.GroupByStr)
  1007. }
  1008. if statement.HavingStr != "" {
  1009. fmt.Fprint(&buf, " ", statement.HavingStr)
  1010. }
  1011. if needOrderBy && statement.OrderStr != "" {
  1012. fmt.Fprint(&buf, " ORDER BY ", statement.OrderStr)
  1013. }
  1014. if needLimit {
  1015. if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE {
  1016. if statement.Start > 0 {
  1017. fmt.Fprintf(&buf, " LIMIT %v OFFSET %v", statement.LimitN, statement.Start)
  1018. } else if statement.LimitN > 0 {
  1019. fmt.Fprint(&buf, " LIMIT ", statement.LimitN)
  1020. }
  1021. } else if dialect.DBType() == core.ORACLE {
  1022. if statement.Start != 0 || statement.LimitN != 0 {
  1023. oldString := buf.String()
  1024. buf.Reset()
  1025. rawColStr := columnStr
  1026. if rawColStr == "*" {
  1027. rawColStr = "at.*"
  1028. }
  1029. fmt.Fprintf(&buf, "SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d",
  1030. columnStr, rawColStr, oldString, statement.Start+statement.LimitN, statement.Start)
  1031. }
  1032. }
  1033. }
  1034. if statement.IsForUpdate {
  1035. return dialect.ForUpdateSql(buf.String()), nil
  1036. }
  1037. return buf.String(), nil
  1038. }
  1039. func (statement *Statement) processIDParam() error {
  1040. if statement.idParam == nil || statement.RefTable == nil {
  1041. return nil
  1042. }
  1043. if len(statement.RefTable.PrimaryKeys) != len(*statement.idParam) {
  1044. return fmt.Errorf("ID condition is error, expect %d primarykeys, there are %d",
  1045. len(statement.RefTable.PrimaryKeys),
  1046. len(*statement.idParam),
  1047. )
  1048. }
  1049. for i, col := range statement.RefTable.PKColumns() {
  1050. var colName = statement.colName(col, statement.TableName())
  1051. statement.cond = statement.cond.And(builder.Eq{colName: (*(statement.idParam))[i]})
  1052. }
  1053. return nil
  1054. }
  1055. func (statement *Statement) joinColumns(cols []*core.Column, includeTableName bool) string {
  1056. var colnames = make([]string, len(cols))
  1057. for i, col := range cols {
  1058. if includeTableName {
  1059. colnames[i] = statement.Engine.Quote(statement.TableName()) +
  1060. "." + statement.Engine.Quote(col.Name)
  1061. } else {
  1062. colnames[i] = statement.Engine.Quote(col.Name)
  1063. }
  1064. }
  1065. return strings.Join(colnames, ", ")
  1066. }
  1067. func (statement *Statement) convertIDSQL(sqlStr string) string {
  1068. if statement.RefTable != nil {
  1069. cols := statement.RefTable.PKColumns()
  1070. if len(cols) == 0 {
  1071. return ""
  1072. }
  1073. colstrs := statement.joinColumns(cols, false)
  1074. sqls := splitNNoCase(sqlStr, " from ", 2)
  1075. if len(sqls) != 2 {
  1076. return ""
  1077. }
  1078. var top string
  1079. if statement.LimitN > 0 && statement.Engine.dialect.DBType() == core.MSSQL {
  1080. top = fmt.Sprintf("TOP %d ", statement.LimitN)
  1081. }
  1082. newsql := fmt.Sprintf("SELECT %s%s FROM %v", top, colstrs, sqls[1])
  1083. return newsql
  1084. }
  1085. return ""
  1086. }
  1087. func (statement *Statement) convertUpdateSQL(sqlStr string) (string, string) {
  1088. if statement.RefTable == nil || len(statement.RefTable.PrimaryKeys) != 1 {
  1089. return "", ""
  1090. }
  1091. colstrs := statement.joinColumns(statement.RefTable.PKColumns(), true)
  1092. sqls := splitNNoCase(sqlStr, "where", 2)
  1093. if len(sqls) != 2 {
  1094. if len(sqls) == 1 {
  1095. return sqls[0], fmt.Sprintf("SELECT %v FROM %v",
  1096. colstrs, statement.Engine.Quote(statement.TableName()))
  1097. }
  1098. return "", ""
  1099. }
  1100. var whereStr = sqls[1]
  1101. // TODO: for postgres only, if any other database?
  1102. var paraStr string
  1103. if statement.Engine.dialect.DBType() == core.POSTGRES {
  1104. paraStr = "$"
  1105. } else if statement.Engine.dialect.DBType() == core.MSSQL {
  1106. paraStr = ":"
  1107. }
  1108. if paraStr != "" {
  1109. if strings.Contains(sqls[1], paraStr) {
  1110. dollers := strings.Split(sqls[1], paraStr)
  1111. whereStr = dollers[0]
  1112. for i, c := range dollers[1:] {
  1113. ccs := strings.SplitN(c, " ", 2)
  1114. whereStr += fmt.Sprintf(paraStr+"%v %v", i+1, ccs[1])
  1115. }
  1116. }
  1117. }
  1118. return sqls[0], fmt.Sprintf("SELECT %v FROM %v WHERE %v",
  1119. colstrs, statement.Engine.Quote(statement.TableName()),
  1120. whereStr)
  1121. }