1 Star 0 Fork 0

小毛驴 / gorm

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
statement.go 19.79 KB
一键复制 编辑 原始数据 按行查看 历史
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742
package gorm
import (
"context"
"database/sql"
"database/sql/driver"
"fmt"
"reflect"
"regexp"
"sort"
"strconv"
"strings"
"sync"
"gorm.io/gorm/clause"
"gorm.io/gorm/logger"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils"
)
// Statement holds the state used to build and run a single SQL statement:
// the target table/model, the clauses collected so far, select/omit/join/
// preload options, and the SQL text plus bound variables produced by
// Build/AddVar.
//
// It embeds a sync.Map and a strings.Builder, so it must not be copied by
// value; clone() copies it field by field instead.
type Statement struct {
*DB
// TableExpr, when non-nil, is built in place of Table wherever the current
// table (clause.CurrentTable) is quoted as a clause.Table. It is set by
// ParseWithSpecialTableName for two-part "schema.table" names.
TableExpr *clause.Expr
// Table is the unquoted table name that clause.CurrentTable resolves to.
Table string
Model interface{}
Unscoped bool
// Dest is the destination value read by Changed and written by SetColumn.
Dest interface{}
ReflectValue reflect.Value
// Clauses maps a clause name to its merged clause (see AddClause / Build).
Clauses map[string]clause.Clause
BuildClauses []string
Distinct bool
Selects []string // selected columns
Omits []string // omit columns
Joins []join
Preloads map[string][]interface{}
Settings sync.Map
ConnPool ConnPool
Schema *schema.Schema
Context context.Context
RaiseErrorOnNotFound bool
SkipHooks bool
// SQL accumulates the generated SQL text.
SQL strings.Builder
// Vars holds the bound values appended by AddVar.
Vars []interface{}
// CurDestIndex selects the element of a slice/array ReflectValue that
// SetColumn and Changed operate on.
CurDestIndex int
attrs []interface{}
assigns []interface{}
scopes []func(*DB) *DB
}
// join records one join requested on the statement; clone() copies the
// Joins slice element by element.
// NOTE(review): the code that consumes these fields is not in this section;
// presumably Name is an association or table name and Conds its arguments,
// with Selects/Omits restricting the joined columns — confirm against callers.
type join struct {
Name string
Conds []interface{}
On *clause.Where
Selects []string
Omits []string
JoinType clause.JoinType
}
// StatementModifier is implemented by clauses that change the Statement
// directly instead of being merged into Statement.Clauses; AddClause calls
// ModifyStatement for such values and stores nothing.
type StatementModifier interface {
ModifyStatement(*Statement)
}
// WriteString appends str to the statement's SQL buffer and returns the
// number of bytes written along with the buffer's error result, so the
// Statement itself can be used as the destination for generated SQL.
func (stmt *Statement) WriteString(str string) (int, error) {
	n, err := stmt.SQL.WriteString(str)
	return n, err
}
// WriteByte appends the single byte c to the statement's SQL buffer and
// returns the buffer's error result.
func (stmt *Statement) WriteByte(c byte) error {
	sqlBuf := &stmt.SQL
	return sqlBuf.WriteByte(c)
}
// WriteQuoted quotes value (see QuoteTo) directly into the statement's SQL
// buffer.
func (stmt *Statement) WriteQuoted(value interface{}) {
	sqlBuf := &stmt.SQL
	stmt.QuoteTo(sqlBuf, value)
}
// QuoteTo writes field to writer as a quoted identifier using the
// dialector's quoting rules.
//
// Supported values:
//   - clause.Table: clause.CurrentTable resolves to stmt.TableExpr (when set)
//     or stmt.Table; a non-empty Alias is appended after a single space.
//   - clause.Column: an optional table qualifier (clause.CurrentTable resolves
//     to stmt.Table) followed by '.', then the name; clause.PrimaryKey
//     resolves to the schema's PrioritizedPrimaryField or, failing that, its
//     first DB name. A non-empty Alias is appended as " AS alias".
//   - []clause.Column and []string: a parenthesized, comma-separated list.
//   - clause.Expr: built through its Build method.
//   - string: quoted as-is; any other value is formatted with fmt.Sprint first.
//
// Table/Column values with Raw set are written verbatim, without quoting.
// If the primary key cannot be resolved, an error (ErrModelValueRequired or
// ErrModelAccessibleFieldsRequired) is added to stmt.DB and no name is written.
//
// NOTE(review): the TableExpr and clause.Expr cases call Build(stmt), which
// presumably writes into stmt.SQL rather than into writer; when writer is some
// other buffer (e.g. via Quote) that output would not reach writer — confirm.
func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) {
// write emits str verbatim when raw, otherwise quoted by the dialector.
write := func(raw bool, str string) {
if raw {
writer.WriteString(str)
} else {
stmt.DB.Dialector.QuoteTo(writer, str)
}
}
switch v := field.(type) {
case clause.Table:
if v.Name == clause.CurrentTable {
if stmt.TableExpr != nil {
stmt.TableExpr.Build(stmt)
} else {
write(v.Raw, stmt.Table)
}
} else {
write(v.Raw, v.Name)
}
if v.Alias != "" {
writer.WriteByte(' ')
write(v.Raw, v.Alias)
}
case clause.Column:
if v.Table != "" {
if v.Table == clause.CurrentTable {
write(v.Raw, stmt.Table)
} else {
write(v.Raw, v.Table)
}
writer.WriteByte('.')
}
if v.Name == clause.PrimaryKey {
// Resolve the primary-key placeholder against the parsed schema.
if stmt.Schema == nil {
stmt.DB.AddError(ErrModelValueRequired)
} else if stmt.Schema.PrioritizedPrimaryField != nil {
write(v.Raw, stmt.Schema.PrioritizedPrimaryField.DBName)
} else if len(stmt.Schema.DBNames) > 0 {
write(v.Raw, stmt.Schema.DBNames[0])
} else {
stmt.DB.AddError(ErrModelAccessibleFieldsRequired) //nolint:typecheck,errcheck
}
} else {
write(v.Raw, v.Name)
}
if v.Alias != "" {
writer.WriteString(" AS ")
write(v.Raw, v.Alias)
}
case []clause.Column:
writer.WriteByte('(')
for idx, d := range v {
if idx > 0 {
writer.WriteByte(',')
}
stmt.QuoteTo(writer, d)
}
writer.WriteByte(')')
case clause.Expr:
v.Build(stmt)
case string:
stmt.DB.Dialector.QuoteTo(writer, v)
case []string:
writer.WriteByte('(')
for idx, d := range v {
if idx > 0 {
writer.WriteByte(',')
}
stmt.DB.Dialector.QuoteTo(writer, d)
}
writer.WriteByte(')')
default:
stmt.DB.Dialector.QuoteTo(writer, fmt.Sprint(field))
}
}
// Quote renders field as a quoted identifier (see QuoteTo) into a fresh
// buffer and returns the resulting text.
func (stmt *Statement) Quote(field interface{}) string {
	quoted := new(strings.Builder)
	stmt.QuoteTo(quoted, field)
	return quoted.String()
}
// AddVar writes each of vars to writer, comma-separated. Depending on its
// type, each value becomes either a bound placeholder (its value is appended
// to stmt.Vars) or inline SQL:
//   - sql.NamedArg: only its Value is appended to stmt.Vars; no placeholder
//     is written for it.
//   - clause.Column / clause.Table: written as a quoted identifier.
//   - Valuer: replaced by its GormValue and added again; a nil pointer is
//     added as nil.
//   - clause.Interface: merged into a clause of the same name and built.
//   - clause.Expression: built in place.
//   - driver.Valuer, []byte, and other non-slice values: bound as one variable.
//   - []interface{} and other slices/arrays: written as "(a,b,...)", or
//     "(NULL)" when empty; a slice/array whose element type is uint8 is bound
//     as one variable.
//   - *DB: rendered inline as a sub-query whose variables are merged into
//     stmt.Vars.
func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
for idx, v := range vars {
if idx > 0 {
writer.WriteByte(',')
}
switch v := v.(type) {
case sql.NamedArg:
stmt.Vars = append(stmt.Vars, v.Value)
case clause.Column, clause.Table:
stmt.QuoteTo(writer, v)
case Valuer:
reflectValue := reflect.ValueOf(v)
if reflectValue.Kind() == reflect.Ptr && reflectValue.IsNil() {
stmt.AddVar(writer, nil)
} else {
stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB))
}
case clause.Interface:
c := clause.Clause{Name: v.Name()}
v.MergeClause(&c)
c.Build(stmt)
case clause.Expression:
v.Build(stmt)
case driver.Valuer:
stmt.Vars = append(stmt.Vars, v)
stmt.DB.Dialector.BindVarTo(writer, stmt, v)
case []byte:
stmt.Vars = append(stmt.Vars, v)
stmt.DB.Dialector.BindVarTo(writer, stmt, v)
case []interface{}:
if len(v) > 0 {
writer.WriteByte('(')
stmt.AddVar(writer, v...)
writer.WriteByte(')')
} else {
writer.WriteString("(NULL)")
}
case *DB:
// Sub-query: render it in a silent, dry-run session.
subdb := v.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance()
if v.Statement.SQL.Len() > 0 {
// The sub-query already has SQL. Regenerate each of its dialect
// placeholders, replace the first occurrence of each with "?", then
// rebuild the SQL on top of stmt.Vars so the placeholders are
// regenerated relative to the outer statement's variables.
var (
vars = subdb.Statement.Vars
sql = v.Statement.SQL.String()
)
subdb.Statement.Vars = make([]interface{}, 0, len(vars))
for _, vv := range vars {
subdb.Statement.Vars = append(subdb.Statement.Vars, vv)
bindvar := strings.Builder{}
v.Dialector.BindVarTo(&bindvar, subdb.Statement, vv)
sql = strings.Replace(sql, bindvar.String(), "?", 1)
}
subdb.Statement.SQL.Reset()
subdb.Statement.Vars = stmt.Vars
if strings.Contains(sql, "@") {
clause.NamedExpr{SQL: sql, Vars: vars}.Build(subdb.Statement)
} else {
clause.Expr{SQL: sql, Vars: vars}.Build(subdb.Statement)
}
} else {
// No SQL yet: run the query callbacks (dry run) with the outer
// variables in front so the sub-query's placeholders follow them.
subdb.Statement.Vars = append(stmt.Vars, subdb.Statement.Vars...)
subdb.callbacks.Query().Execute(subdb)
}
writer.WriteString(subdb.Statement.SQL.String())
stmt.Vars = subdb.Statement.Vars
default:
switch rv := reflect.ValueOf(v); rv.Kind() {
case reflect.Slice, reflect.Array:
if rv.Len() == 0 {
writer.WriteString("(NULL)")
} else if rv.Type().Elem() == reflect.TypeOf(uint8(0)) {
// Byte-like slices/arrays are bound as a single value.
stmt.Vars = append(stmt.Vars, v)
stmt.DB.Dialector.BindVarTo(writer, stmt, v)
} else {
writer.WriteByte('(')
for i := 0; i < rv.Len(); i++ {
if i > 0 {
writer.WriteByte(',')
}
stmt.AddVar(writer, rv.Index(i).Interface())
}
writer.WriteByte(')')
}
default:
stmt.Vars = append(stmt.Vars, v)
stmt.DB.Dialector.BindVarTo(writer, stmt, v)
}
}
}
}
// AddClause merges v into the clause of the same name in stmt.Clauses,
// creating that entry if it does not exist yet. Values that implement
// StatementModifier modify stmt directly instead and are not stored.
func (stmt *Statement) AddClause(v clause.Interface) {
	if modifier, isModifier := v.(StatementModifier); isModifier {
		modifier.ModifyStatement(stmt)
		return
	}

	clauseName := v.Name()
	existing := stmt.Clauses[clauseName]
	existing.Name = clauseName
	v.MergeClause(&existing)
	stmt.Clauses[clauseName] = existing
}
// AddClauseIfNotExists calls AddClause for v unless stmt.Clauses already has
// an entry under v's name with a non-nil Expression.
func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) {
	existing, found := stmt.Clauses[v.Name()]
	if found && existing.Expression != nil {
		return
	}
	stmt.AddClause(v)
}
// BuildCondition converts the arguments of a Where/Not/Or-style call into
// clause expressions.
//
// When query is a string that does not parse as an integer:
//   - "" with no args yields no condition;
//   - with no args, or containing "?", it is raw SQL (clause.Expr);
//   - with args and containing "@", it is a named expression (clause.NamedExpr);
//   - containing a space, it is raw SQL;
//   - with exactly one arg, it is "column = args[0]" (clause.Eq).
//
// Any other query is handled together with args as one list. Each element may
// be a clause.Expression, a *DB (its WHERE clause is reused), a map, a model
// struct or slice of structs, or primary-key value(s). For a model passed as
// the first element, any trailing string args name the only fields to use. If
// no condition has been produced yet, unrecognized values fall back to a
// primary-key IN condition.
//
// All produced conditions are ANDed into a single expression; nil is returned
// when nothing was produced.
func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []clause.Expression {
	if s, ok := query.(string); ok {
		// if it is a number, then treats it as primary key
		if _, err := strconv.Atoi(s); err != nil {
			if s == "" && len(args) == 0 {
				return nil
			}

			if len(args) == 0 || strings.Contains(s, "?") {
				// looks like a where condition
				return []clause.Expression{clause.Expr{SQL: s, Vars: args}}
			}

			// From here on len(args) > 0.
			if strings.Contains(s, "@") {
				// looks like a named query
				return []clause.Expression{clause.NamedExpr{SQL: s, Vars: args}}
			}

			if strings.Contains(strings.TrimSpace(s), " ") {
				// looks like a where condition
				return []clause.Expression{clause.Expr{SQL: s, Vars: args}}
			}

			if len(args) == 1 {
				return []clause.Expression{clause.Eq{Column: s, Value: args[0]}}
			}
		}
	}

	conds := make([]clause.Expression, 0, 4)
	args = append([]interface{}{query}, args...)
argsLoop:
	for idx, arg := range args {
		if arg == nil {
			continue
		}
		// Unwrap driver.Valuer args to their underlying value.
		// NOTE(review): the error returned by Value() is discarded here.
		if valuer, ok := arg.(driver.Valuer); ok {
			arg, _ = valuer.Value()
		}

		switch v := arg.(type) {
		case clause.Expression:
			conds = append(conds, v)
		case *DB:
			v.executeScopes()

			if cs, ok := v.Statement.Clauses["WHERE"]; ok {
				if where, ok := cs.Expression.(clause.Where); ok {
					// A lone OR group is converted to an AND group before wrapping.
					if len(where.Exprs) == 1 {
						if orConds, ok := where.Exprs[0].(clause.OrConditions); ok {
							where.Exprs[0] = clause.AndConditions(orConds)
						}
					}
					conds = append(conds, clause.And(where.Exprs...))
				} else if cs.Expression != nil {
					conds = append(conds, cs.Expression)
				}
			}
		case map[interface{}]interface{}:
			// NOTE(review): map iteration order is random, so the order of these
			// conditions (and thus of the generated SQL) varies between calls.
			for i, j := range v {
				conds = append(conds, clause.Eq{Column: i, Value: j})
			}
		case map[string]string:
			// Sort keys so the generated SQL is deterministic.
			keys := make([]string, 0, len(v))
			for i := range v {
				keys = append(keys, i)
			}
			sort.Strings(keys)

			for _, key := range keys {
				conds = append(conds, clause.Eq{Column: key, Value: v[key]})
			}
		case map[string]interface{}:
			// Sort keys so the generated SQL is deterministic.
			keys := make([]string, 0, len(v))
			for i := range v {
				keys = append(keys, i)
			}
			sort.Strings(keys)

			for _, key := range keys {
				reflectValue := reflect.Indirect(reflect.ValueOf(v[key]))
				switch reflectValue.Kind() {
				case reflect.Slice, reflect.Array:
					if _, ok := v[key].(driver.Valuer); ok {
						conds = append(conds, clause.Eq{Column: key, Value: v[key]})
					} else if _, ok := v[key].(Valuer); ok {
						conds = append(conds, clause.Eq{Column: key, Value: v[key]})
					} else {
						// Plain slices/arrays become "key IN (...)".
						// optimize reflect value length
						valueLen := reflectValue.Len()
						values := make([]interface{}, valueLen)
						for i := 0; i < valueLen; i++ {
							values[i] = reflectValue.Index(i).Interface()
						}

						conds = append(conds, clause.IN{Column: key, Values: values})
					}
				default:
					conds = append(conds, clause.Eq{Column: key, Value: v[key]})
				}
			}
		default:
			reflectValue := reflect.Indirect(reflect.ValueOf(arg))
			for reflectValue.Kind() == reflect.Ptr {
				reflectValue = reflectValue.Elem()
			}

			if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil {
				// When the model is the first element, trailing string args name
				// the only fields to use, and those fields are used even when zero.
				selectedColumns := map[string]bool{}
				if idx == 0 {
					for _, v := range args[1:] {
						if vs, ok := v.(string); ok {
							selectedColumns[vs] = true
						}
					}
				}
				restricted := len(selectedColumns) != 0

				switch reflectValue.Kind() {
				case reflect.Struct:
					for _, field := range s.Fields {
						selected := selectedColumns[field.DBName] || selectedColumns[field.Name]
						if selected || (!restricted && field.Readable) {
							if v, isZero := field.ValueOf(stmt.Context, reflectValue); !isZero || selected {
								if field.DBName != "" {
									conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
								} else if field.DataType != "" {
									conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
								}
							}
						}
					}
				case reflect.Slice, reflect.Array:
					for i := 0; i < reflectValue.Len(); i++ {
						for _, field := range s.Fields {
							selected := selectedColumns[field.DBName] || selectedColumns[field.Name]
							if selected || (!restricted && field.Readable) {
								if v, isZero := field.ValueOf(stmt.Context, reflectValue.Index(i)); !isZero || selected {
									if field.DBName != "" {
										conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
									} else if field.DataType != "" {
										conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
									}
								}
							}
						}
					}
				}

				if restricted {
					// The remaining args are the field names consumed above, not
					// further conditions, so stop processing the list. A bare
					// `break` here only exits the type switch (a no-op). That let
					// those name strings reach the primary-key fallback below when
					// the model produced no conditions.
					break argsLoop
				}
			} else if !reflectValue.IsValid() {
				stmt.AddError(ErrInvalidData)
			} else if len(conds) == 0 {
				// Fallback: treat the value(s) as primary keys.
				if len(args) == 1 {
					switch reflectValue.Kind() {
					case reflect.Slice, reflect.Array:
						// optimize reflect value length
						valueLen := reflectValue.Len()
						values := make([]interface{}, valueLen)
						for i := 0; i < valueLen; i++ {
							values[i] = reflectValue.Index(i).Interface()
						}

						if len(values) > 0 {
							conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values})
							return []clause.Expression{clause.And(conds...)}
						}
						return nil
					}
				}

				conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args})
			}
		}
	}

	if len(conds) > 0 {
		return []clause.Expression{clause.And(conds...)}
	}
	return nil
}
// Build writes the named clauses into the statement's SQL in the order
// given, skipping names absent from stmt.Clauses and separating written
// clauses with a single space. A builder registered under the same name in
// stmt.DB.ClauseBuilders takes precedence over the clause's own Build.
func (stmt *Statement) Build(clauses ...string) {
	written := 0
	for _, clauseName := range clauses {
		c, present := stmt.Clauses[clauseName]
		if !present {
			continue
		}

		if written > 0 {
			stmt.WriteByte(' ')
		}
		written++

		if builder, custom := stmt.DB.ClauseBuilders[clauseName]; custom {
			builder(c, stmt)
			continue
		}
		c.Build(stmt)
	}
}
// Parse parses value's schema into stmt.Schema and, when stmt.Table is still
// empty, derives the table name from it. It is ParseWithSpecialTableName
// with an empty special table name.
func (stmt *Statement) Parse(value interface{}) (err error) {
	const noSpecialTableName = ""
	err = stmt.ParseWithSpecialTableName(value, noSpecialTableName)
	return err
}
// ParseWithSpecialTableName parses value's schema with
// schema.ParseWithSpecialTableName and assigns the result to stmt.Schema,
// even when parsing fails. On success, if stmt.Table is empty it is derived
// from the schema's table name. For a two-part "schema.table" name, TableExpr
// is set to the quoted full name and Table to the part after the dot.
func (stmt *Statement) ParseWithSpecialTableName(value interface{}, specialTableName string) (err error) {
	stmt.Schema, err = schema.ParseWithSpecialTableName(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, specialTableName)
	if err != nil || stmt.Table != "" {
		return err
	}

	fullName := stmt.Schema.Table
	if parts := strings.Split(fullName, "."); len(parts) == 2 {
		stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(fullName)}
		stmt.Table = parts[1]
	} else {
		stmt.Table = fullName
	}
	return nil
}
// clone returns a copy of stmt for a new chained call.
//
// Clauses, Preloads, Joins, scopes and Settings are copied into new
// containers whose elements are shared with the original. The SQL text and
// Vars are copied only when SQL has been written. Selects, Omits, Dest, Model,
// Schema and the remaining listed fields are assigned directly, so slices are
// shared. DB, BuildClauses, CurDestIndex, attrs and assigns are left at their
// zero values.
func (stmt *Statement) clone() *Statement {
	cloned := &Statement{
		TableExpr:            stmt.TableExpr,
		Table:                stmt.Table,
		Model:                stmt.Model,
		Unscoped:             stmt.Unscoped,
		Dest:                 stmt.Dest,
		ReflectValue:         stmt.ReflectValue,
		Distinct:             stmt.Distinct,
		Selects:              stmt.Selects,
		Omits:                stmt.Omits,
		ConnPool:             stmt.ConnPool,
		Schema:               stmt.Schema,
		Context:              stmt.Context,
		RaiseErrorOnNotFound: stmt.RaiseErrorOnNotFound,
		SkipHooks:            stmt.SkipHooks,
		Clauses:              make(map[string]clause.Clause, len(stmt.Clauses)),
		Preloads:             make(map[string][]interface{}, len(stmt.Preloads)),
	}

	// Carry over raw SQL and its bound values only when something was written.
	if stmt.SQL.Len() > 0 {
		cloned.SQL.WriteString(stmt.SQL.String())
		cloned.Vars = append(make([]interface{}, 0, len(stmt.Vars)), stmt.Vars...)
	}

	for name, c := range stmt.Clauses {
		cloned.Clauses[name] = c
	}
	for name, conds := range stmt.Preloads {
		cloned.Preloads[name] = conds
	}

	if n := len(stmt.Joins); n > 0 {
		cloned.Joins = make([]join, n)
		copy(cloned.Joins, stmt.Joins)
	}
	if n := len(stmt.scopes); n > 0 {
		cloned.scopes = make([]func(*DB) *DB, n)
		copy(cloned.scopes, stmt.scopes)
	}

	stmt.Settings.Range(func(key, value interface{}) bool {
		cloned.Settings.Store(key, value)
		return true
	})
	return cloned
}
// SetColumn sets the value of column name on the statement's destination.
//
// For a map[string]interface{} or []map[string]interface{} Dest, the key name
// is set directly (in every map of the slice). Otherwise name is resolved
// with stmt.Schema.LookUpField. The value is written to Dest when Dest is not
// the same value as stmt.ReflectValue, and also to stmt.ReflectValue. When
// ReflectValue is a slice/array, passing any fromCallbacks argument (its value
// is not inspected) sets every element; otherwise only the element at
// CurDestIndex is set. Failures are recorded with stmt.AddError
// (ErrInvalidField, ErrInvalidData, ErrInvalidValue) rather than returned.
//
//	stmt.SetColumn("Name", "jinzhu") // Hooks Method
//	stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method
func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks ...bool) {
if v, ok := stmt.Dest.(map[string]interface{}); ok {
v[name] = value
} else if v, ok := stmt.Dest.([]map[string]interface{}); ok {
for _, m := range v {
m[name] = value
}
} else if stmt.Schema != nil {
if field := stmt.Schema.LookUpField(name); field != nil {
destValue := reflect.ValueOf(stmt.Dest)
for destValue.Kind() == reflect.Ptr {
destValue = destValue.Elem()
}
if stmt.ReflectValue != destValue {
// A non-addressable Dest cannot be set in place: replace stmt.Dest
// with a pointer to an addressable copy and set that instead.
if !destValue.CanAddr() {
destValueCanAddr := reflect.New(destValue.Type())
destValueCanAddr.Elem().Set(destValue)
stmt.Dest = destValueCanAddr.Interface()
destValue = destValueCanAddr.Elem()
}
switch destValue.Kind() {
case reflect.Struct:
stmt.AddError(field.Set(stmt.Context, destValue, value))
default:
stmt.AddError(ErrInvalidData)
}
}
switch stmt.ReflectValue.Kind() {
case reflect.Slice, reflect.Array:
// Only len(fromCallbacks) is checked, not the bool values themselves.
if len(fromCallbacks) > 0 {
for i := 0; i < stmt.ReflectValue.Len(); i++ {
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue.Index(i), value))
}
} else {
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue.Index(stmt.CurDestIndex), value))
}
case reflect.Struct:
if !stmt.ReflectValue.CanAddr() {
stmt.AddError(ErrInvalidValue)
return
}
stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, value))
}
} else {
stmt.AddError(ErrInvalidField)
}
} else {
stmt.AddError(ErrInvalidField)
}
}
// Changed reports whether the update destination (stmt.Dest) differs from
// the current model value for the named fields, or for any field in
// stmt.Schema.FieldsByDBName when no names are given. Names that
// LookUpField cannot resolve are ignored.
//
// The model value is stmt.ReflectValue, or its element at CurDestIndex for
// slices/arrays. Only fields permitted by SelectAndOmitColumns(false, true)
// are compared. Those are fields explicitly selected, or any field not
// mentioned when selection is unrestricted. For a map[string]interface{} Dest,
// a field counts as changed only when its Name (or else DBName) key is present
// and its value differs. For other Dest values the field is read from Dest;
// unless it was explicitly selected, a zero Dest value never counts as a change.
//
// NOTE(review): stmt.Schema is dereferenced without a nil check.
func (stmt *Statement) Changed(fields ...string) bool {
modelValue := stmt.ReflectValue
switch modelValue.Kind() {
case reflect.Slice, reflect.Array:
modelValue = stmt.ReflectValue.Index(stmt.CurDestIndex)
}
selectColumns, restricted := stmt.SelectAndOmitColumns(false, true)
changed := func(field *schema.Field) bool {
fieldValue, _ := field.ValueOf(stmt.Context, modelValue)
// v is true only when the column was explicitly selected.
if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
if mv, mok := stmt.Dest.(map[string]interface{}); mok {
if fv, ok := mv[field.Name]; ok {
return !utils.AssertEqual(fv, fieldValue)
} else if fv, ok := mv[field.DBName]; ok {
return !utils.AssertEqual(fv, fieldValue)
}
} else {
destValue := reflect.ValueOf(stmt.Dest)
for destValue.Kind() == reflect.Ptr {
destValue = destValue.Elem()
}
changedValue, zero := field.ValueOf(stmt.Context, destValue)
if v {
return !utils.AssertEqual(changedValue, fieldValue)
}
return !zero && !utils.AssertEqual(changedValue, fieldValue)
}
}
return false
}
if len(fields) == 0 {
for _, field := range stmt.Schema.FieldsByDBName {
if changed(field) {
return true
}
}
} else {
for _, name := range fields {
if field := stmt.Schema.LookUpField(name); field != nil {
if changed(field) {
return true
}
}
}
}
return false
}
// matchName splits a possibly table-qualified column reference such as
// "users.name", "`users`.`name`" or "users.*" into its table and column
// parts. Each part may be wrapped in one non-word character (e.g. a quote).
// table is "" for an unqualified reference. Both results are "" when the
// input does not have this shape. The pattern is compiled once, at package
// initialization.
var matchName = func() func(tableColumn string) (table, column string) {
	re := regexp.MustCompile(`^(?:\W?(\w+?)\W?\.)?(?:(\*)|\W?(\w+?)\W?)$`)
	return func(tableColumn string) (table, column string) {
		groups := re.FindStringSubmatch(tableColumn)
		if len(groups) != 4 {
			return "", ""
		}
		table = groups[1]
		if groups[2] != "" { // the column is a bare "*"
			return table, groups[2]
		}
		return table, groups[3]
	}
}()
// SelectAndOmitColumns computes which columns are selected (true) or omitted
// (false).
//
// Entries come from stmt.Selects, then stmt.Omits, so an omit overrides a
// select of the same column. Without a schema, each entry is used as-is.
// With a schema:
//   - "*" applies to every schema DB name;
//   - clause.Associations applies to every relationship name;
//   - a field name or DB name maps to the field's DB name;
//   - "table.column" or "table.*" for the current (or no) table maps to that
//     column (or all DB names);
//   - anything else is used as-is.
//
// When requireCreate / requireUpdate is set, schema fields that are not
// Creatable / Updatable are forced to false, keyed by DB name (or field name
// when the DB name is empty).
//
// The second result reports whether selection is restricted: Selects is
// non-empty, and "*" was not (last) selected under a schema.
func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) {
results := map[string]bool{}
notRestricted := false
processColumn := func(column string, result bool) {
if stmt.Schema == nil {
results[column] = result
} else if column == "*" {
// Selecting "*" lifts the restriction; omitting it restores it.
notRestricted = result
for _, dbName := range stmt.Schema.DBNames {
results[dbName] = result
}
} else if column == clause.Associations {
for _, rel := range stmt.Schema.Relationships.Relations {
results[rel.Name] = result
}
} else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" {
results[field.DBName] = result
} else if table, col := matchName(column); col != "" && (table == stmt.Table || table == "") {
if col == "*" {
for _, dbName := range stmt.Schema.DBNames {
results[dbName] = result
}
} else {
results[col] = result
}
} else {
results[column] = result
}
}
// select columns
for _, column := range stmt.Selects {
processColumn(column, true)
}
// omit columns
for _, column := range stmt.Omits {
processColumn(column, false)
}
if stmt.Schema != nil {
for _, field := range stmt.Schema.FieldsByName {
name := field.DBName
if name == "" {
name = field.Name
}
if requireCreate && !field.Creatable {
results[name] = false
} else if requireUpdate && !field.Updatable {
results[name] = false
}
}
}
return results, !notRestricted && len(stmt.Selects) > 0
}
1
https://gitee.com/dream_xml/gorm.git
git@gitee.com:dream_xml/gorm.git
dream_xml
gorm
gorm
master

搜索帮助