Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions contracts/database/orm/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ type Event interface {
GetAttribute(key string) any
// GetOriginal returns the original attribute value for the given key.
GetOriginal(key string, def ...any) any
// IsDirty returns true if the given column is dirty.
IsDirty(columns ...string) bool
// IsClean returns true if the given column is clean.
IsClean(columns ...string) bool
// IsDirty returns true if the given column is dirty.
IsDirty(columns ...string) bool
// Query returns the query instance.
Query() Query
// SetAttribute sets the attribute value for the given key.
Expand Down
7 changes: 4 additions & 3 deletions contracts/database/orm/orm.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ type Query interface {
// Cursor returns a cursor, use scan to iterate over the returned rows.
Cursor() (chan Cursor, error)
// Delete deletes records matching given conditions, if the conditions are empty will delete all records.
Delete(value any, conds ...any) (*Result, error)
Delete(value ...any) (*Result, error)
// Distinct specifies distinct fields to query.
Distinct(args ...any) Query
// Driver gets the driver for the query.
Expand All @@ -63,7 +63,7 @@ type Query interface {
// return a new instance of the model initialized with those attributes.
FirstOrNew(dest any, attributes any, values ...any) error
// ForceDelete forces delete records matching given conditions.
ForceDelete(value any, conds ...any) (*Result, error)
ForceDelete(value ...any) (*Result, error)
// Get retrieves all rows from the database.
Get(dest any) error
// Group specifies the group method on the query.
Expand Down Expand Up @@ -193,9 +193,10 @@ type Result struct {
type ToSql interface {
Count() string
Create(value any) string
Delete(value any, conds ...any) string
Delete(value ...any) string
Find(dest any, conds ...any) string
First(dest any) string
ForceDelete(value ...any) string
Get(dest any) string
Pluck(column string, dest any) string
Save(value any) string
Expand Down
193 changes: 100 additions & 93 deletions database/gorm/event.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,50 +28,12 @@ func NewEvent(query *QueryImpl, model, dest any) *Event {
}
}

func (e *Event) ColumnNames() map[string]string {
if e.columnNames != nil {
return e.columnNames
}

if e.model != nil {
return fetchColumnNames(e.model)
} else {
return fetchColumnNames(e.dest)
}
}

func (e *Event) Context() context.Context {
return e.query.ctx
}

func (e *Event) DestOfMap() map[string]any {
if e.destOfMap != nil {
return e.destOfMap
}

destOfMap := make(map[string]any)
if destMap, ok := e.dest.(map[string]any); ok {
for key, value := range destMap {
destOfMap[key] = value
destOfMap[str.Of(key).Snake().String()] = value
}
} else {
destType := reflect.TypeOf(e.dest)
if destType.Kind() == reflect.Pointer {
destType = destType.Elem()
}
if destType.Kind() == reflect.Struct {
destOfMap = structToMap(e.dest)
}
}

e.destOfMap = destOfMap

return e.destOfMap
}

func (e *Event) GetAttribute(key string) any {
destOfMap := e.DestOfMap()
destOfMap := e.getDestOfMap()
value, exist := destOfMap[e.toDBColumnName(key)]
if exist && e.validColumn(key) && e.validValue(key, value) {
return value
Expand All @@ -81,7 +43,7 @@ func (e *Event) GetAttribute(key string) any {
}

func (e *Event) GetOriginal(key string, def ...any) any {
modelOfMap := e.ModelOfMap()
modelOfMap := e.getModelOfMap()
value, exist := modelOfMap[e.toDBColumnName(key)]
if exist {
return value
Expand All @@ -94,8 +56,12 @@ func (e *Event) GetOriginal(key string, def ...any) any {
return nil
}

func (e *Event) IsClean(fields ...string) bool {
return !e.IsDirty(fields...)
}

func (e *Event) IsDirty(columns ...string) bool {
destOfMap := e.DestOfMap()
destOfMap := e.getDestOfMap()

if len(columns) == 0 {
for destColumn, destValue := range destOfMap {
Expand Down Expand Up @@ -125,30 +91,16 @@ func (e *Event) IsDirty(columns ...string) bool {
return false
}

func (e *Event) IsClean(fields ...string) bool {
return !e.IsDirty(fields...)
}

func (e *Event) ModelOfMap() map[string]any {
if e.modelOfMap != nil {
return e.modelOfMap
}

if e.model == nil {
return map[string]any{}
}

e.modelOfMap = structToMap(e.model)

return e.modelOfMap
}

func (e *Event) Query() orm.Query {
return NewQueryImpl(e.query.ctx, e.query.config, e.query.connection, e.query.instance.Session(&gorm.Session{NewDB: true}), nil)
}

func (e *Event) SetAttribute(key string, value any) {
destOfMap := e.DestOfMap()
if e.dest == nil {
return
}

destOfMap := e.getDestOfMap()
destOfMap[e.toDBColumnName(key)] = value
e.destOfMap = destOfMap

Expand Down Expand Up @@ -187,7 +139,7 @@ func (e *Event) SetAttribute(key string, value any) {
}

func (e *Event) dirty(destColumn string, destValue any) bool {
modelOfMap := e.ModelOfMap()
modelOfMap := e.getModelOfMap()
dbDestColumn := e.toDBColumnName(destColumn)

if modelValue, exist := modelOfMap[dbDestColumn]; exist {
Expand All @@ -208,8 +160,63 @@ func (e *Event) equalColumnName(origin, source string) bool {
return originDbColumnName == sourceDbColumnName
}

func (e *Event) getColumnNames() map[string]string {
if e.columnNames == nil {
if e.model != nil {
e.columnNames = fetchColumnNames(e.model)
} else {
e.columnNames = fetchColumnNames(e.dest)
}
}

return e.columnNames
}

func (e *Event) getDestOfMap() map[string]any {
if e.dest == nil {
return nil
}
if e.destOfMap != nil {
return e.destOfMap
}

destOfMap := make(map[string]any)
if destMap, ok := e.dest.(map[string]any); ok {
for key, value := range destMap {
destOfMap[key] = value
destOfMap[str.Of(key).Snake().String()] = value
}
} else {
destType := reflect.TypeOf(e.dest)
if destType.Kind() == reflect.Pointer {
destType = destType.Elem()
}
if destType.Kind() == reflect.Struct {
destOfMap = structToMap(e.dest)
}
}

e.destOfMap = destOfMap

return e.destOfMap
}

func (e *Event) getModelOfMap() map[string]any {
if e.modelOfMap != nil {
return e.modelOfMap
}

if e.model == nil {
return map[string]any{}
}

e.modelOfMap = structToMap(e.model)

return e.modelOfMap
}

func (e *Event) toDBColumnName(name string) string {
dbColumnName, exist := e.ColumnNames()[name]
dbColumnName, exist := e.getColumnNames()[name]
if exist {
return dbColumnName
}
Expand Down Expand Up @@ -277,6 +284,37 @@ func (e *Event) validValue(name string, value any) bool {
return !valueValue.IsZero()
}

func fetchColumnNames(model any) map[string]string {
res := make(map[string]string)
modelType := reflect.TypeOf(model)
modelValue := reflect.ValueOf(model)
if modelType.Kind() == reflect.Pointer {
modelType = modelType.Elem()
modelValue = modelValue.Elem()
}

for i := 0; i < modelType.NumField(); i++ {
if !modelType.Field(i).IsExported() {
continue
}
fieldType := modelType.Field(i)
fieldValue := modelValue.Field(i)
if fieldValue.Kind() == reflect.Struct && fieldType.Anonymous {
subStructMap := fetchColumnNames(fieldValue.Interface())
for key, value := range subStructMap {
res[key] = value
}
continue
}

dbColumn := structNameToDbColumnName(modelType.Field(i).Name, modelType.Field(i).Tag.Get("gorm"))
res[modelType.Field(i).Name] = dbColumn
res[dbColumn] = dbColumn
}

return res
}

func structToMap(data any) map[string]any {
res := make(map[string]any)
modelType := reflect.TypeOf(data)
Expand Down Expand Up @@ -334,34 +372,3 @@ func structNameToDbColumnName(structName, tag string) string {

return str.Of(structName).Snake().String()
}

func fetchColumnNames(model any) map[string]string {
res := make(map[string]string)
modelType := reflect.TypeOf(model)
modelValue := reflect.ValueOf(model)
if modelType.Kind() == reflect.Pointer {
modelType = modelType.Elem()
modelValue = modelValue.Elem()
}

for i := 0; i < modelType.NumField(); i++ {
if !modelType.Field(i).IsExported() {
continue
}
fieldType := modelType.Field(i)
fieldValue := modelValue.Field(i)
if fieldValue.Kind() == reflect.Struct && fieldType.Anonymous {
subStructMap := fetchColumnNames(fieldValue.Interface())
for key, value := range subStructMap {
res[key] = value
}
continue
}

dbColumn := structNameToDbColumnName(modelType.Field(i).Name, modelType.Field(i).Tag.Get("gorm"))
res[modelType.Field(i).Name] = dbColumn
res[dbColumn] = dbColumn
}

return res
}
2 changes: 1 addition & 1 deletion database/gorm/event_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ func (s *EventTestSuite) TestColumnNames() {
"admin_at": "admin_at",
"ManageAt": "manage_at",
"manage_at": "manage_at",
}, event.ColumnNames())
}, event.getColumnNames())
}
}

Expand Down
Loading