Skip to content

Commit d2b7ecb

Browse files
committed
feat: [#708] facades.DB() Support JSON Where Clauses
1 parent d5bbc40 commit d2b7ecb

File tree

12 files changed

+993
-137
lines changed

12 files changed

+993
-137
lines changed

contracts/database/db/db.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,16 @@ type Query interface {
110110
OrWhereColumn(column1 string, column2 ...string) Query
111111
// OrWhereIn adds an "or where column in" clause to the query.
112112
OrWhereIn(column string, values []any) Query
113+
// OrWhereJsonContains adds an "or where JSON contains" clause to the query.
114+
OrWhereJsonContains(column string, value any) Query
115+
// OrWhereJsonContainsKey add a clause that determines if a JSON path exists to the query.
116+
OrWhereJsonContainsKey(column string) Query
117+
// OrWhereJsonDoesntContain add an "or where JSON not contains" clause to the query.
118+
OrWhereJsonDoesntContain(column string, value any) Query
119+
// OrWhereJsonDoesntContainKey add a clause that determines if a JSON path does not exist to the query.
120+
OrWhereJsonDoesntContainKey(column string) Query
121+
// OrWhereJsonLength add an "or where JSON length" clause to the query.
122+
OrWhereJsonLength(column string, length int) Query
113123
// OrWhereLike adds an "or where column like" clause to the query.
114124
OrWhereLike(column string, value string) Query
115125
// OrWhereNot adds an "or where not" clause to the query.
@@ -161,6 +171,16 @@ type Query interface {
161171
WhereExists(func() Query) Query
162172
// WhereIn adds a "where column in" clause to the query.
163173
WhereIn(column string, values []any) Query
174+
// WhereJsonContains add a "where JSON contains" clause to the query.
175+
WhereJsonContains(column string, value any) Query
176+
// WhereJsonContainsKey add a clause that determines if a JSON path exists to the query.
177+
WhereJsonContainsKey(column string) Query
178+
// WhereJsonDoesntContain add a "where JSON not contains" clause to the query.
179+
WhereJsonDoesntContain(column string, value any) Query
180+
// WhereJsonDoesntContainKey add a clause that determines if a JSON path does not exist to the query.
181+
WhereJsonDoesntContainKey(column string) Query
182+
// WhereJsonLength add a "where JSON length" clause to the query.
183+
WhereJsonLength(column string, length int) Query
164184
// WhereLike adds a "where like" clause to the query.
165185
WhereLike(column string, value string) Query
166186
// WhereNot adds a basic "where not" clause to the query.

contracts/database/driver/conditions.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,15 @@
11
package driver
22

3+
// WhereType type of where condition
4+
type WhereType int
5+
6+
const (
7+
WhereTypeBase WhereType = iota
8+
WhereTypeJsonContains
9+
WhereTypeJsonContainsKey
10+
WhereTypeJsonLength
11+
)
12+
313
type Conditions struct {
414
CrossJoin []Join
515
Distinct *bool
@@ -30,7 +40,9 @@ type Join struct {
3040
}
3141

3242
type Where struct {
43+
Type WhereType
3344
Query any
3445
Args []any
3546
Or bool
47+
IsNot bool
3648
}

database/db/builder.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,18 @@ package db
33
import (
44
"github.com/jmoiron/sqlx"
55
"gorm.io/gorm"
6+
7+
"github.com/goravel/framework/support/str"
68
)
79

10+
var NameMapper = func(s string) string {
11+
if s == "ID" {
12+
return "id"
13+
}
14+
15+
return str.Of(s).Snake().String()
16+
}
17+
818
type Builder struct {
919
*sqlx.DB
1020
gormDB *gorm.DB
@@ -17,6 +27,7 @@ func NewBuilder(gormDB *gorm.DB, driver string) (*Builder, error) {
1727
}
1828

1929
dbx := sqlx.NewDb(db, driver)
30+
dbx.MapperFunc(NameMapper)
2031

2132
return &Builder{
2233
DB: dbx,
@@ -40,6 +51,7 @@ func NewTxBuilder(gormDB *gorm.DB, driver string) (*TxBuilder, error) {
4051
}
4152

4253
dbx := sqlx.NewDb(db, driver)
54+
dbx.MapperFunc(NameMapper)
4355
tx, err := dbx.Beginx()
4456
if err != nil {
4557
return nil, err

database/db/query.go

Lines changed: 121 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -537,14 +537,11 @@ func (r *Query) OrderByRaw(raw string) db.Query {
537537
}
538538

539539
func (r *Query) OrWhere(query any, args ...any) db.Query {
540-
q := r.clone()
541-
q.conditions.Where = append(q.conditions.Where, contractsdriver.Where{
540+
return r.addWhere(contractsdriver.Where{
542541
Query: query,
543542
Args: args,
544543
Or: true,
545544
})
546-
547-
return q
548545
}
549546

550547
func (r *Query) OrWhereBetween(column string, x, y any) db.Query {
@@ -568,6 +565,51 @@ func (r *Query) OrWhereIn(column string, values []any) db.Query {
568565
return r.OrWhere(column, values)
569566
}
570567

568+
func (r *Query) OrWhereJsonContains(column string, value any) db.Query {
569+
return r.addWhere(contractsdriver.Where{
570+
Type: contractsdriver.WhereTypeJsonContains,
571+
Query: column,
572+
Args: []any{value},
573+
Or: true,
574+
})
575+
}
576+
577+
func (r *Query) OrWhereJsonContainsKey(column string) db.Query {
578+
return r.addWhere(contractsdriver.Where{
579+
Type: contractsdriver.WhereTypeJsonContainsKey,
580+
Query: column,
581+
Or: true,
582+
})
583+
}
584+
585+
func (r *Query) OrWhereJsonDoesntContain(column string, value any) db.Query {
586+
return r.addWhere(contractsdriver.Where{
587+
Type: contractsdriver.WhereTypeJsonContains,
588+
Query: column,
589+
Args: []any{value},
590+
IsNot: true,
591+
Or: true,
592+
})
593+
}
594+
595+
func (r *Query) OrWhereJsonDoesntContainKey(column string) db.Query {
596+
return r.addWhere(contractsdriver.Where{
597+
Type: contractsdriver.WhereTypeJsonContainsKey,
598+
Query: column,
599+
IsNot: true,
600+
Or: true,
601+
})
602+
}
603+
604+
func (r *Query) OrWhereJsonLength(column string, length int) db.Query {
605+
return r.addWhere(contractsdriver.Where{
606+
Type: contractsdriver.WhereTypeJsonLength,
607+
Query: column,
608+
Args: []any{length},
609+
Or: true,
610+
})
611+
}
612+
571613
func (r *Query) OrWhereLike(column string, value string) db.Query {
572614
return r.OrWhere(sq.Like{column: value})
573615
}
@@ -769,13 +811,10 @@ func (r *Query) When(condition bool, callback func(query db.Query) db.Query, fal
769811
}
770812

771813
func (r *Query) Where(query any, args ...any) db.Query {
772-
q := r.clone()
773-
q.conditions.Where = append(q.conditions.Where, contractsdriver.Where{
814+
return r.addWhere(contractsdriver.Where{
774815
Query: query,
775816
Args: args,
776817
})
777-
778-
return q
779818
}
780819

781820
func (r *Query) WhereBetween(column string, x, y any) db.Query {
@@ -812,6 +851,46 @@ func (r *Query) WhereIn(column string, values []any) db.Query {
812851
return r.Where(column, values)
813852
}
814853

854+
func (r *Query) WhereJsonContains(column string, value any) db.Query {
855+
return r.addWhere(contractsdriver.Where{
856+
Type: contractsdriver.WhereTypeJsonContains,
857+
Query: column,
858+
Args: []any{value},
859+
})
860+
}
861+
862+
func (r *Query) WhereJsonContainsKey(column string) db.Query {
863+
return r.addWhere(contractsdriver.Where{
864+
Type: contractsdriver.WhereTypeJsonContainsKey,
865+
Query: column,
866+
})
867+
}
868+
869+
func (r *Query) WhereJsonDoesntContain(column string, value any) db.Query {
870+
return r.addWhere(contractsdriver.Where{
871+
Type: contractsdriver.WhereTypeJsonContains,
872+
Query: column,
873+
Args: []any{value},
874+
IsNot: true,
875+
})
876+
}
877+
878+
func (r *Query) WhereJsonDoesntContainKey(column string) db.Query {
879+
return r.addWhere(contractsdriver.Where{
880+
Type: contractsdriver.WhereTypeJsonContainsKey,
881+
Query: column,
882+
IsNot: true,
883+
})
884+
}
885+
886+
func (r *Query) WhereJsonLength(column string, length int) db.Query {
887+
return r.addWhere(contractsdriver.Where{
888+
Type: contractsdriver.WhereTypeJsonLength,
889+
Query: column,
890+
Args: []any{length},
891+
})
892+
}
893+
815894
func (r *Query) WhereLike(column string, value string) db.Query {
816895
return r.Where(sq.Like{column: value})
817896
}
@@ -865,6 +944,13 @@ func (r *Query) WhereRaw(raw string, args []any) db.Query {
865944
return r.Where(sq.Expr(raw, args...))
866945
}
867946

947+
func (r *Query) addWhere(where contractsdriver.Where) db.Query {
948+
q := r.clone()
949+
q.conditions.Where = append(q.conditions.Where, where)
950+
951+
return q
952+
}
953+
868954
func (r *Query) buildDelete() (sql string, args []any, err error) {
869955
if r.err != nil {
870956
return "", nil, r.err
@@ -1043,13 +1129,39 @@ func (r *Query) buildUpdate(data map[string]any) (sql string, args []any, err er
10431129
func (r *Query) buildWhere(where contractsdriver.Where) (any, []any, error) {
10441130
switch query := where.Query.(type) {
10451131
case string:
1046-
if !str.Of(query).Trim().Contains(" ", "?") {
1132+
switch where.Type {
1133+
case contractsdriver.WhereTypeJsonContains:
1134+
var err error
1135+
query, where.Args, err = r.grammar.CompileJsonContains(query, where.Args[0], where.IsNot)
1136+
if err != nil {
1137+
return nil, nil, errors.OrmJsonContainsInvalidBinding.Args(err)
1138+
}
1139+
case contractsdriver.WhereTypeJsonContainsKey:
1140+
query = str.Of(r.grammar.CompileJsonContainsKey(query, where.IsNot)).Replace("?", "??").String()
1141+
case contractsdriver.WhereTypeJsonLength:
1142+
segments := strings.SplitN(query, " ", 2)
1143+
segments[0] = r.grammar.CompileJsonLength(segments[0])
1144+
query = strings.Join(segments, " ")
1145+
default:
1146+
if str.Of(query).Trim().Contains("->") {
1147+
segments := strings.Split(query, " ")
1148+
for i := range segments {
1149+
if strings.Contains(segments[i], "->") {
1150+
segments[i] = r.grammar.CompileJsonSelector(segments[i])
1151+
}
1152+
}
1153+
query = strings.Join(segments, " ")
1154+
where.Args = r.grammar.CompileJsonValues(where.Args...)
1155+
}
1156+
}
1157+
if !str.Of(query).Trim().Contains("?") {
10471158
if len(where.Args) > 1 {
10481159
return sq.Eq{query: where.Args}, nil, nil
10491160
} else if len(where.Args) == 1 {
10501161
return sq.Eq{query: where.Args[0]}, nil, nil
10511162
}
10521163
}
1164+
10531165
return query, where.Args, nil
10541166
case map[string]any:
10551167
return sq.Eq(query), nil, nil

database/db/utils.go

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -123,14 +123,18 @@ func convertToMap(data any) (map[string]any, error) {
123123

124124
// Get field name from db tag or use field name
125125
tag := field.Tag.Get("db")
126-
if tag == "" || tag == "-" {
126+
if tag == "-" {
127127
continue
128128
}
129129
var fieldName string
130-
if comma := strings.Index(tag, ","); comma != -1 {
131-
fieldName = tag[:comma]
130+
if tag != "" {
131+
if comma := strings.Index(tag, ","); comma != -1 {
132+
fieldName = tag[:comma]
133+
} else {
134+
fieldName = tag
135+
}
132136
} else {
133-
fieldName = tag
137+
fieldName = NameMapper(field.Name)
134138
}
135139

136140
fieldValue := val.Field(i)

database/db/utils_test.go

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ type Body struct {
1515
Weight string `db:"weight"`
1616
Head *string `db:"head"`
1717
Height int `db:"-"`
18-
Age uint
19-
DateTime carbon.DateTime `db:"date_time"`
20-
leg int `db:"leg"`
18+
Age uint `db:"-"`
19+
DateTime carbon.DateTime
20+
leg int `db:"leg"`
2121
}
2222

2323
type User struct {
@@ -65,8 +65,8 @@ func TestConvertToSliceMap(t *testing.T) {
6565
TestTimestamps: TestTimestamps{CreatedAt: dateTime, UpdatedAt: dateTime}},
6666
},
6767
want: []map[string]any{
68-
{"id": 1, "weight": "100kg", "head": &head, "date_time": *dateTime, "created_at": dateTime, "updated_at": dateTime, "deleted_at": deletedAt},
69-
{"id": 2, "length": 1, "weight": "90kg", "head": &head, "date_time": *dateTime, "created_at": dateTime, "updated_at": dateTime, "deleted_at": deletedAt},
68+
{"id": 1, "email": "john@example.com", "weight": "100kg", "head": &head, "date_time": *dateTime, "created_at": dateTime, "updated_at": dateTime, "deleted_at": deletedAt},
69+
{"id": 2, "email": "jane@example.com", "length": 1, "weight": "90kg", "head": &head, "date_time": *dateTime, "created_at": dateTime, "updated_at": dateTime, "deleted_at": deletedAt},
7070
},
7171
},
7272
{
@@ -80,8 +80,8 @@ func TestConvertToSliceMap(t *testing.T) {
8080
TestTimestamps: TestTimestamps{CreatedAt: dateTime, UpdatedAt: dateTime}},
8181
},
8282
want: []map[string]any{
83-
{"id": 1, "weight": "100kg", "head": &head, "date_time": *dateTime, "created_at": dateTime, "updated_at": dateTime, "deleted_at": deletedAt},
84-
{"id": 2, "length": 1, "weight": "90kg", "head": &head, "date_time": *dateTime, "created_at": dateTime, "updated_at": dateTime, "deleted_at": deletedAt},
83+
{"id": 1, "email": "john@example.com", "weight": "100kg", "head": &head, "date_time": *dateTime, "created_at": dateTime, "updated_at": dateTime, "deleted_at": deletedAt},
84+
{"id": 2, "email": "jane@example.com", "length": 1, "weight": "90kg", "head": &head, "date_time": *dateTime, "created_at": dateTime, "updated_at": dateTime, "deleted_at": deletedAt},
8585
},
8686
},
8787
{
@@ -126,9 +126,11 @@ func TestConvertToSliceMap(t *testing.T) {
126126
},
127127
}
128128

129-
for _, test := range tests {
130-
sliceMap, err := convertToSliceMap(test.data)
131-
assert.NoError(t, err)
132-
assert.Equal(t, test.want, sliceMap)
129+
for _, tt := range tests {
130+
t.Run(tt.name, func(t *testing.T) {
131+
sliceMap, err := convertToSliceMap(tt.data)
132+
assert.NoError(t, err)
133+
assert.Equal(t, tt.want, sliceMap)
134+
})
133135
}
134136
}

0 commit comments

Comments
 (0)