diff --git a/database/db/query.go b/database/db/query.go index 04b6a7089..491458eb6 100644 --- a/database/db/query.go +++ b/database/db/query.go @@ -81,7 +81,9 @@ func (r *Query) Chunk(size uint64, callback func(rows []db.Row) error) error { } func (r *Query) Count() (int64, error) { - r.conditions.Selects = []string{"COUNT(*)"} + if err := buildSelectForCount(r); err != nil { + return 0, err + } sql, args, err := r.buildSelect() if err != nil { @@ -1304,3 +1306,35 @@ func (r *Query) trace(builder db.CommonBuilder, sql string, args []any, now *car r.logger.Trace(r.ctx, now, builder.Explain(sql, args...), rowsAffected, err) } } + +func buildSelectForCount(query *Query) error { + distinct := query.conditions.Distinct != nil && *query.conditions.Distinct + + // If selectColumns only contains a raw select with spaces (rename), gorm will fail, but this case will appear when calling Paginate, so use COUNT(*) here. + // If there are multiple selectColumns, gorm will transform them into *, so no need to handle that case. + // For example: Select("name as n").Count() will fail, but Select("name", "age as a").Count() will be treated as Select("*").Count() + if len(query.conditions.Selects) > 1 { + query.conditions.Selects = []string{"COUNT(*)"} + } else if len(query.conditions.Selects) == 1 { + column := query.conditions.Selects[0] + if str.Of(query.conditions.Selects[0]).Trim().Contains(" ") { + column = str.Of(query.conditions.Selects[0]).Split(" ")[0] + } + + if distinct { + query.conditions.Selects = []string{fmt.Sprintf("COUNT(DISTINCT %s)", column)} + } else { + query.conditions.Selects = []string{fmt.Sprintf("COUNT(%s)", column)} + } + } else { + if distinct { + return errors.DatabaseCountDistinctWithoutColumns + } else { + query.conditions.Selects = []string{"COUNT(*)"} + } + } + + query.conditions.Distinct = nil + + return nil +} diff --git a/database/db/query_test.go b/database/db/query_test.go index 093f49126..6897a8363 100644 --- a/database/db/query_test.go +++ b/database/db/query_test.go @@ -114,19 +114,69 @@ func (s *QueryTestSuite) TestAddWhere() { } func (s *QueryTestSuite) TestCount() { - var count int64 + s.Run("without select", func() { + var count int64 - s.mockGrammar.EXPECT().CompilePlaceholderFormat().Return(nil).Once() - s.mockReadBuilder.EXPECT().GetContext(s.ctx, &count, "SELECT COUNT(*) FROM users WHERE name = ?", "John").Run(func(ctx context.Context, dest any, query string, args ...any) { - destCount := dest.(*int64) - *destCount = 1 - }).Return(nil).Once() - s.mockReadBuilder.EXPECT().Explain("SELECT COUNT(*) FROM users WHERE name = ?", "John").Return("SELECT COUNT(*) FROM users WHERE name = \"John\"").Once() - s.mockLogger.EXPECT().Trace(s.ctx, s.now, "SELECT COUNT(*) FROM users WHERE name = \"John\"", int64(-1), nil).Return().Once() + s.mockGrammar.EXPECT().CompilePlaceholderFormat().Return(nil).Once() + s.mockReadBuilder.EXPECT().GetContext(s.ctx, &count, "SELECT COUNT(*) FROM users WHERE name = ?", "John").Run(func(ctx context.Context, dest any, query string, args ...any) { + destCount := dest.(*int64) + *destCount = 1 + }).Return(nil).Once() + s.mockReadBuilder.EXPECT().Explain("SELECT COUNT(*) FROM users WHERE name = ?", "John").Return("SELECT COUNT(*) FROM users WHERE name = \"John\"").Once() + s.mockLogger.EXPECT().Trace(s.ctx, s.now, "SELECT COUNT(*) FROM users WHERE name = \"John\"", int64(-1), nil).Return().Once() - count, err := s.query.Where("name", "John").Count() - s.NoError(err) - s.Equal(int64(1), count) + count, err := s.query.Where("name", "John").Count() + s.NoError(err) + s.Equal(int64(1), count) + }) + + s.Run("with select - one column", func() { + var count int64 + + s.mockGrammar.EXPECT().CompilePlaceholderFormat().Return(nil).Once() + s.mockReadBuilder.EXPECT().GetContext(s.ctx, &count, "SELECT COUNT(name) FROM users WHERE name = ?", "John").Run(func(ctx context.Context, dest any, query string, args ...any) { + destCount := dest.(*int64) + *destCount = 1 + }).Return(nil).Once() + s.mockReadBuilder.EXPECT().Explain("SELECT COUNT(name) FROM users WHERE name = ?", "John").Return("SELECT COUNT(name) FROM users WHERE name = \"John\"").Once() + s.mockLogger.EXPECT().Trace(s.ctx, s.now, "SELECT COUNT(name) FROM users WHERE name = \"John\"", int64(-1), nil).Return().Once() + + count, err := s.query.Select("name").Where("name", "John").Count() + s.NoError(err) + s.Equal(int64(1), count) + }) + + s.Run("with select - one column with rename", func() { + var count int64 + + s.mockGrammar.EXPECT().CompilePlaceholderFormat().Return(nil).Once() + s.mockReadBuilder.EXPECT().GetContext(s.ctx, &count, "SELECT COUNT(name) FROM users WHERE name = ?", "John").Run(func(ctx context.Context, dest any, query string, args ...any) { + destCount := dest.(*int64) + *destCount = 1 + }).Return(nil).Once() + s.mockReadBuilder.EXPECT().Explain("SELECT COUNT(name) FROM users WHERE name = ?", "John").Return("SELECT COUNT(name) FROM users WHERE name = \"John\"").Once() + s.mockLogger.EXPECT().Trace(s.ctx, s.now, "SELECT COUNT(name) FROM users WHERE name = \"John\"", int64(-1), nil).Return().Once() + + count, err := s.query.Select("name as name").Where("name", "John").Count() + s.NoError(err) + s.Equal(int64(1), count) + }) + + s.Run("with select - multiple columns", func() { + var count int64 + + s.mockGrammar.EXPECT().CompilePlaceholderFormat().Return(nil).Once() + s.mockReadBuilder.EXPECT().GetContext(s.ctx, &count, "SELECT COUNT(*) FROM users WHERE name = ?", "John").Run(func(ctx context.Context, dest any, query string, args ...any) { + destCount := dest.(*int64) + *destCount = 1 + }).Return(nil).Once() + s.mockReadBuilder.EXPECT().Explain("SELECT COUNT(*) FROM users WHERE name = ?", "John").Return("SELECT COUNT(*) FROM users WHERE name = \"John\"").Once() + s.mockLogger.EXPECT().Trace(s.ctx, s.now, "SELECT COUNT(*) FROM users WHERE name = \"John\"", int64(-1), nil).Return().Once() + + count, err := s.query.Select("name", "avatar").Where("name", "John").Count() + s.NoError(err) + s.Equal(int64(1), count) + }) } func (s *QueryTestSuite) TestCrossJoin() { @@ -201,15 +251,98 @@ func (s *QueryTestSuite) TestDelete() { } func (s *QueryTestSuite) TestDistinct() { - var users TestUser + s.Run("without column", func() { + var users TestUser - s.mockGrammar.EXPECT().CompilePlaceholderFormat().Return(nil).Once() - s.mockReadBuilder.EXPECT().GetContext(s.ctx, &users, "SELECT DISTINCT * FROM users WHERE name = ?", "John").Return(nil).Once() - s.mockReadBuilder.EXPECT().Explain("SELECT DISTINCT * FROM users WHERE name = ?", "John").Return("SELECT DISTINCT * FROM users WHERE name = \"John\"").Once() - s.mockLogger.EXPECT().Trace(s.ctx, s.now, "SELECT DISTINCT * FROM users WHERE name = \"John\"", int64(1), nil).Return().Once() + s.mockGrammar.EXPECT().CompilePlaceholderFormat().Return(nil).Once() + s.mockReadBuilder.EXPECT().GetContext(s.ctx, &users, "SELECT DISTINCT * FROM users WHERE name = ?", "John").Return(nil).Once() + s.mockReadBuilder.EXPECT().Explain("SELECT DISTINCT * FROM users WHERE name = ?", "John").Return("SELECT DISTINCT * FROM users WHERE name = \"John\"").Once() + s.mockLogger.EXPECT().Trace(s.ctx, s.now, "SELECT DISTINCT * FROM users WHERE name = \"John\"", int64(1), nil).Return().Once() - err := s.query.Where("name", "John").Distinct().First(&users) - s.NoError(err) + err := s.query.Where("name", "John").Distinct().First(&users) + s.NoError(err) + }) + + s.Run("with one column", func() { + var users TestUser + + s.mockGrammar.EXPECT().CompilePlaceholderFormat().Return(nil).Once() + s.mockReadBuilder.EXPECT().GetContext(s.ctx, &users, "SELECT DISTINCT name FROM users WHERE name = ?", "John").Return(nil).Once() + s.mockReadBuilder.EXPECT().Explain("SELECT DISTINCT name FROM users WHERE name = ?", "John").Return("SELECT DISTINCT name FROM users WHERE name = \"John\"").Once() + s.mockLogger.EXPECT().Trace(s.ctx, s.now, "SELECT DISTINCT name FROM users WHERE name = \"John\"", int64(1), nil).Return().Once() + + err := s.query.Where("name", "John").Distinct("name").First(&users) + s.NoError(err) + }) + + s.Run("with multiple columns", func() { + var users TestUser + + s.mockGrammar.EXPECT().CompilePlaceholderFormat().Return(nil).Once() + s.mockReadBuilder.EXPECT().GetContext(s.ctx, &users, "SELECT DISTINCT name, age FROM users WHERE name = ?", "John").Return(nil).Once() + s.mockReadBuilder.EXPECT().Explain("SELECT DISTINCT name, age FROM users WHERE name = ?", "John").Return("SELECT DISTINCT name, age FROM users WHERE name = \"John\"").Once() + s.mockLogger.EXPECT().Trace(s.ctx, s.now, "SELECT DISTINCT name, age FROM users WHERE name = \"John\"", int64(1), nil).Return().Once() + + err := s.query.Where("name", "John").Distinct("name", "age").First(&users) + s.NoError(err) + }) + + s.Run("Count - without column", func() { + count, err := s.query.Where("name", "John").Distinct().Count() + s.Equal(errors.DatabaseCountDistinctWithoutColumns, err) + s.Equal(int64(0), count) + }) + + s.Run("Count - with one column", func() { + var count int64 + + s.mockGrammar.EXPECT().CompilePlaceholderFormat().Return(nil).Once() + s.mockReadBuilder.EXPECT().GetContext(s.ctx, &count, "SELECT COUNT(DISTINCT name) FROM users WHERE name = ?", "John").RunAndReturn(func(ctx context.Context, i1 interface{}, s string, i2 ...interface{}) error { + destCount := i1.(*int64) + *destCount = 1 + return nil + }).Once() + s.mockReadBuilder.EXPECT().Explain("SELECT COUNT(DISTINCT name) FROM users WHERE name = ?", "John").Return("SELECT COUNT(DISTINCT name) FROM users WHERE name = \"John\"").Once() + s.mockLogger.EXPECT().Trace(s.ctx, s.now, "SELECT COUNT(DISTINCT name) FROM users WHERE name = \"John\"", int64(-1), nil).Return().Once() + + res, err := s.query.Where("name", "John").Distinct("name").Count() + s.NoError(err) + s.Equal(int64(1), res) + }) + + s.Run("Count - with one column and rename", func() { + var count int64 + + s.mockGrammar.EXPECT().CompilePlaceholderFormat().Return(nil).Once() + s.mockReadBuilder.EXPECT().GetContext(s.ctx, &count, "SELECT COUNT(DISTINCT name) FROM users WHERE name = ?", "John").RunAndReturn(func(ctx context.Context, i1 interface{}, s string, i2 ...interface{}) error { + destCount := i1.(*int64) + *destCount = 1 + return nil + }).Once() + s.mockReadBuilder.EXPECT().Explain("SELECT COUNT(DISTINCT name) FROM users WHERE name = ?", "John").Return("SELECT COUNT(DISTINCT name) FROM users WHERE name = \"John\"").Once() + s.mockLogger.EXPECT().Trace(s.ctx, s.now, "SELECT COUNT(DISTINCT name) FROM users WHERE name = \"John\"", int64(-1), nil).Return().Once() + + res, err := s.query.Where("name", "John").Distinct("name as name").Count() + s.NoError(err) + s.Equal(int64(1), res) + }) + + s.Run("Count - with multiple columns", func() { + var count int64 + + s.mockGrammar.EXPECT().CompilePlaceholderFormat().Return(nil).Once() + s.mockReadBuilder.EXPECT().GetContext(s.ctx, &count, "SELECT COUNT(*) FROM users WHERE name = ?", "John").RunAndReturn(func(ctx context.Context, i1 interface{}, s string, i2 ...interface{}) error { + destCount := i1.(*int64) + *destCount = 1 + return nil + }).Once() + s.mockReadBuilder.EXPECT().Explain("SELECT COUNT(*) FROM users WHERE name = ?", "John").Return("SELECT COUNT(*) FROM users WHERE name = \"John\"").Once() + s.mockLogger.EXPECT().Trace(s.ctx, s.now, "SELECT COUNT(*) FROM users WHERE name = \"John\"", int64(-1), nil).Return().Once() + + res, err := s.query.Where("name", "John").Distinct("name", "age").Count() + s.NoError(err) + s.Equal(int64(1), res) + }) } func (s *QueryTestSuite) TestExists() { @@ -1036,28 +1169,111 @@ func (s *QueryTestSuite) TestOrWhereRaw() { } func (s *QueryTestSuite) TestPaginate() { - var users []TestUser - var total int64 + s.Run("without Select", func() { + var users []TestUser + var total int64 - s.mockGrammar.EXPECT().CompilePlaceholderFormat().Return(nil).Twice() - s.mockReadBuilder.EXPECT().GetContext(s.ctx, &total, "SELECT COUNT(*) FROM users WHERE name = ?", "John").Run(func(ctx context.Context, dest any, query string, args ...any) { - destTotal := dest.(*int64) - *destTotal = 2 - }).Return(nil).Once() - s.mockReadBuilder.EXPECT().Explain("SELECT COUNT(*) FROM users WHERE name = ?", "John").Return("SELECT COUNT(*) FROM users WHERE name = \"John\"").Once() - s.mockLogger.EXPECT().Trace(s.ctx, s.now, "SELECT COUNT(*) FROM users WHERE name = \"John\"", int64(-1), nil).Return().Once() + s.mockGrammar.EXPECT().CompilePlaceholderFormat().Return(nil).Twice() + s.mockReadBuilder.EXPECT().GetContext(s.ctx, &total, "SELECT COUNT(*) FROM users WHERE name = ?", "John").Run(func(ctx context.Context, dest any, query string, args ...any) { + destTotal := dest.(*int64) + *destTotal = 2 + }).Return(nil).Once() + s.mockReadBuilder.EXPECT().Explain("SELECT COUNT(*) FROM users WHERE name = ?", "John").Return("SELECT COUNT(*) FROM users WHERE name = \"John\"").Once() + s.mockLogger.EXPECT().Trace(s.ctx, s.now, "SELECT COUNT(*) FROM users WHERE name = \"John\"", int64(-1), nil).Return().Once() - s.mockReadBuilder.EXPECT().SelectContext(s.ctx, &users, "SELECT * FROM users WHERE name = ? LIMIT 10 OFFSET 0", "John").Run(func(ctx context.Context, dest any, query string, args ...any) { - destUsers := dest.(*[]TestUser) - *destUsers = []TestUser{{ID: 1, Name: "John", Age: 25}, {ID: 2, Name: "Jane", Age: 30}} - }).Return(nil).Once() - s.mockReadBuilder.EXPECT().Explain("SELECT * FROM users WHERE name = ? LIMIT 10 OFFSET 0", "John").Return("SELECT * FROM users WHERE name = \"John\" LIMIT 10 OFFSET 0").Once() - s.mockLogger.EXPECT().Trace(s.ctx, s.now, "SELECT * FROM users WHERE name = \"John\" LIMIT 10 OFFSET 0", int64(2), nil).Return().Once() + s.mockReadBuilder.EXPECT().SelectContext(s.ctx, &users, "SELECT * FROM users WHERE name = ? LIMIT 10 OFFSET 0", "John").Run(func(ctx context.Context, dest any, query string, args ...any) { + destUsers := dest.(*[]TestUser) + *destUsers = []TestUser{{ID: 1, Name: "John", Age: 25}, {ID: 2, Name: "Jane", Age: 30}} + }).Return(nil).Once() + s.mockReadBuilder.EXPECT().Explain("SELECT * FROM users WHERE name = ? LIMIT 10 OFFSET 0", "John").Return("SELECT * FROM users WHERE name = \"John\" LIMIT 10 OFFSET 0").Once() + s.mockLogger.EXPECT().Trace(s.ctx, s.now, "SELECT * FROM users WHERE name = \"John\" LIMIT 10 OFFSET 0", int64(2), nil).Return().Once() - err := s.query.Where("name", "John").Paginate(1, 10, &users, &total) - s.Nil(err) - s.Equal(int64(2), total) - s.Equal(2, len(users)) + err := s.query.Where("name", "John").Paginate(1, 10, &users, &total) + s.Nil(err) + s.Equal(int64(2), total) + s.Equal(2, len(users)) + }) + + s.Run("with Select - one column", func() { + var users []TestUser + var total int64 + + s.mockGrammar.EXPECT().CompilePlaceholderFormat().Return(nil).Twice() + s.mockReadBuilder.EXPECT().GetContext(s.ctx, &total, "SELECT COUNT(name) FROM users WHERE name = ?", "John"). + Run(func(ctx context.Context, dest any, query string, args ...any) { + destTotal := dest.(*int64) + *destTotal = 2 + }).Return(nil).Once() + s.mockReadBuilder.EXPECT().Explain("SELECT COUNT(name) FROM users WHERE name = ?", "John").Return("SELECT COUNT(name) FROM users WHERE name = \"John\"").Once() + s.mockLogger.EXPECT().Trace(s.ctx, s.now, "SELECT COUNT(name) FROM users WHERE name = \"John\"", int64(-1), nil).Return().Once() + + s.mockReadBuilder.EXPECT().SelectContext(s.ctx, &users, "SELECT name FROM users WHERE name = ? LIMIT 10 OFFSET 0", "John"). + Run(func(ctx context.Context, dest any, query string, args ...any) { + destUsers := dest.(*[]TestUser) + *destUsers = []TestUser{{ID: 1, Name: "John", Age: 25}, {ID: 2, Name: "Jane", Age: 30}} + }).Return(nil).Once() + s.mockReadBuilder.EXPECT().Explain("SELECT name FROM users WHERE name = ? LIMIT 10 OFFSET 0", "John").Return("SELECT name FROM users WHERE name = \"John\" LIMIT 10 OFFSET 0").Once() + s.mockLogger.EXPECT().Trace(s.ctx, s.now, "SELECT name FROM users WHERE name = \"John\" LIMIT 10 OFFSET 0", int64(2), nil).Return().Once() + + err := s.query.Select("name").Where("name", "John").Paginate(1, 10, &users, &total) + s.Nil(err) + s.Equal(int64(2), total) + s.Equal(2, len(users)) + }) + + s.Run("with Select - one column with rename", func() { + var users []TestUser + var total int64 + + s.mockGrammar.EXPECT().CompilePlaceholderFormat().Return(nil).Twice() + s.mockReadBuilder.EXPECT().GetContext(s.ctx, &total, "SELECT COUNT(name) FROM users WHERE name = ?", "John"). + Run(func(ctx context.Context, dest any, query string, args ...any) { + destTotal := dest.(*int64) + *destTotal = 2 + }).Return(nil).Once() + s.mockReadBuilder.EXPECT().Explain("SELECT COUNT(name) FROM users WHERE name = ?", "John").Return("SELECT COUNT(name) FROM users WHERE name = \"John\"").Once() + s.mockLogger.EXPECT().Trace(s.ctx, s.now, "SELECT COUNT(name) FROM users WHERE name = \"John\"", int64(-1), nil).Return().Once() + + s.mockReadBuilder.EXPECT().SelectContext(s.ctx, &users, "SELECT name as name FROM users WHERE name = ? LIMIT 10 OFFSET 0", "John"). + Run(func(ctx context.Context, dest any, query string, args ...any) { + destUsers := dest.(*[]TestUser) + *destUsers = []TestUser{{ID: 1, Name: "John", Age: 25}, {ID: 2, Name: "Jane", Age: 30}} + }).Return(nil).Once() + s.mockReadBuilder.EXPECT().Explain("SELECT name as name FROM users WHERE name = ? LIMIT 10 OFFSET 0", "John").Return("SELECT name as name FROM users WHERE name = \"John\" LIMIT 10 OFFSET 0").Once() + s.mockLogger.EXPECT().Trace(s.ctx, s.now, "SELECT name as name FROM users WHERE name = \"John\" LIMIT 10 OFFSET 0", int64(2), nil).Return().Once() + + err := s.query.Select("name as name").Where("name", "John").Paginate(1, 10, &users, &total) + s.Nil(err) + s.Equal(int64(2), total) + s.Equal(2, len(users)) + }) + + s.Run("with Select - multiple columns", func() { + var users []TestUser + var total int64 + + s.mockGrammar.EXPECT().CompilePlaceholderFormat().Return(nil).Twice() + s.mockReadBuilder.EXPECT().GetContext(s.ctx, &total, "SELECT COUNT(*) FROM users WHERE name = ?", "John"). + Run(func(ctx context.Context, dest any, query string, args ...any) { + destTotal := dest.(*int64) + *destTotal = 2 + }).Return(nil).Once() + s.mockReadBuilder.EXPECT().Explain("SELECT COUNT(*) FROM users WHERE name = ?", "John").Return("SELECT COUNT(*) FROM users WHERE name = \"John\"").Once() + s.mockLogger.EXPECT().Trace(s.ctx, s.now, "SELECT COUNT(*) FROM users WHERE name = \"John\"", int64(-1), nil).Return().Once() + + s.mockReadBuilder.EXPECT().SelectContext(s.ctx, &users, "SELECT name, age FROM users WHERE name = ? LIMIT 10 OFFSET 0", "John"). + Run(func(ctx context.Context, dest any, query string, args ...any) { + destUsers := dest.(*[]TestUser) + *destUsers = []TestUser{{ID: 1, Name: "John", Age: 25}, {ID: 2, Name: "Jane", Age: 30}} + }).Return(nil).Once() + s.mockReadBuilder.EXPECT().Explain("SELECT name, age FROM users WHERE name = ? LIMIT 10 OFFSET 0", "John").Return("SELECT name, age FROM users WHERE name = \"John\" LIMIT 10 OFFSET 0").Once() + s.mockLogger.EXPECT().Trace(s.ctx, s.now, "SELECT name, age FROM users WHERE name = \"John\" LIMIT 10 OFFSET 0", int64(2), nil).Return().Once() + + err := s.query.Select("name", "age").Where("name", "John").Paginate(1, 10, &users, &total) + s.Nil(err) + s.Equal(int64(2), total) + s.Equal(2, len(users)) + }) } func (s *QueryTestSuite) TestPluck() { @@ -1152,10 +1368,32 @@ func (s *QueryTestSuite) TestSum() { func (s *QueryTestSuite) TestToSql() { s.Run("Count", func() { - s.mockGrammar.EXPECT().CompilePlaceholderFormat().Return(nil).Once() + s.mockGrammar.EXPECT().CompilePlaceholderFormat().Return(nil).Times(7) sql := s.query.Where("name", "John").ToSql().Count() s.Equal("SELECT COUNT(*) FROM users WHERE name = ?", sql) + + s.mockLogger.EXPECT().Errorf(s.ctx, "failed to get sql: cannot use Count with Distinct without specifying columns").Once() + sql = s.query.Distinct().Where("name", "John").ToSql().Count() + s.Empty(sql) + + sql = s.query.Distinct("name").Where("name", "John").ToSql().Count() + s.Equal("SELECT COUNT(DISTINCT name) FROM users WHERE name = ?", sql) + + sql = s.query.Distinct("name", "avatar").Where("name", "John").ToSql().Count() + s.Equal("SELECT COUNT(*) FROM users WHERE name = ?", sql) + + sql = s.query.Select("name", "avatar").Where("name", "John").ToSql().Count() + s.Equal("SELECT COUNT(*) FROM users WHERE name = ?", sql) + + sql = s.query.Select("name as n").Where("name", "John").ToSql().Count() + s.Equal("SELECT COUNT(name) FROM users WHERE name = ?", sql) + + sql = s.query.Select("name n").Where("name", "John").ToSql().Count() + s.Equal("SELECT COUNT(name) FROM users WHERE name = ?", sql) + + sql = s.query.Select("name").Where("name", "John").ToSql().Count() + s.Equal("SELECT COUNT(name) FROM users WHERE name = ?", sql) }) s.Run("Delete", func() { diff --git a/database/db/to_sql.go b/database/db/to_sql.go index 6af7e94ce..3a5239e7a 100644 --- a/database/db/to_sql.go +++ b/database/db/to_sql.go @@ -15,7 +15,10 @@ func NewToSql(query *Query, raw bool) *ToSql { } func (r *ToSql) Count() string { - r.query.conditions.Selects = []string{"COUNT(*)"} + if err := buildSelectForCount(r.query); err != nil { + return r.generate(r.query.readBuilder, "", nil, err) + } + sql, args, err := r.query.buildSelect() return r.generate(r.query.readBuilder, sql, args, err) diff --git a/database/gorm/query.go b/database/gorm/query.go index 14589ffb5..bbfcf6503 100644 --- a/database/gorm/query.go +++ b/database/gorm/query.go @@ -25,6 +25,7 @@ import ( "github.com/goravel/framework/support/collect" "github.com/goravel/framework/support/database" "github.com/goravel/framework/support/deep" + "github.com/goravel/framework/support/str" ) const Associations = clause.Associations @@ -120,12 +121,10 @@ func (r *Query) Commit() error { } func (r *Query) Count() (int64, error) { - query := r.resetSelect().addGlobalScopes().buildConditions() + query := buildSelectForCount(r) var count int64 - - err := query.instance.Count(&count).Error - if err != nil { + if err := query.instance.Count(&count).Error; err != nil { return 0, err } @@ -1680,13 +1679,6 @@ func (r *Query) refreshConnection() (*Query, error) { return query, nil } -func (r *Query) resetSelect() *Query { - conditions := r.conditions - conditions.selectColumns = nil - - return r.setConditions(conditions) -} - func (r *Query) restored(dest any) error { return r.event(contractsorm.EventRestored, r.conditions.model, dest) } @@ -1838,6 +1830,19 @@ func (r *Query) update(values any) (*contractsdb.Result, error) { }, result.Error } +func buildSelectForCount(query *Query) *Query { + conditions := query.conditions + + // If selectColumns only contains a raw select with spaces (rename), gorm will fail, but this case will appear when calling Paginate, so use COUNT(*) here. + // If there are multiple selectColumns, gorm will transform them into *, so no need to handle that case. + // For example: Select("name as n").Count() will fail, but Select("name", "age as a").Count() will be treated as Select("*").Count() + if len(conditions.selectColumns) == 1 && str.Of(conditions.selectColumns[0]).Trim().Contains(" ") { + conditions.selectColumns = []string{str.Of(conditions.selectColumns[0]).Split(" ")[0]} + } + + return query.setConditions(conditions).addGlobalScopes().buildConditions() +} + func filterFindConditions(conds ...any) error { if len(conds) > 0 { switch cond := conds[0].(type) { diff --git a/database/gorm/to_sql.go b/database/gorm/to_sql.go index 75bc3fcd9..8f90b3a43 100644 --- a/database/gorm/to_sql.go +++ b/database/gorm/to_sql.go @@ -22,7 +22,7 @@ func NewToSql(query *Query, log log.Log, raw bool) *ToSql { } func (r *ToSql) Count() string { - query := r.query.addGlobalScopes().buildConditions() + query := buildSelectForCount(r.query) var count int64 return r.sql(query.instance.Session(&gorm.Session{DryRun: true}).Count(&count)) diff --git a/errors/list.go b/errors/list.go index 7e4461832..4ade58bd3 100644 --- a/errors/list.go +++ b/errors/list.go @@ -56,6 +56,7 @@ var ( CryptMissingValueKey = New("decrypt payload error: missing value key") DatabaseConfigNotFound = New("not found database configuration") + DatabaseCountDistinctWithoutColumns = New("cannot use Count with Distinct without specifying columns") DatabaseTableIsRequired = New("table is required") DatabaseForceIsRequiredInProduction = New("application in production use --force to run this command") DatabaseSeederNotFound = New("not found %s seeder") diff --git a/support/constant.go b/support/constant.go index 8f7a940b6..40001d6fb 100644 --- a/support/constant.go +++ b/support/constant.go @@ -1,7 +1,7 @@ package support const ( - Version = "v1.16.6" + Version = "v1.16.7" RuntimeArtisan = "artisan" RuntimeTest = "test" diff --git a/tests/db_test.go b/tests/db_test.go index ae8f3752d..a0678c99e 100644 --- a/tests/db_test.go +++ b/tests/db_test.go @@ -62,9 +62,34 @@ func (s *DBTestSuite) TestCount() { {Name: "count_product1"}, {Name: "count_product2"}, }) + count, err := query.DB().Table("products").Count() s.NoError(err) s.Equal(int64(2), count) + + count, err = query.DB().Table("products").Where("name", "count_product1").Count() + s.NoError(err) + s.Equal(int64(1), count) + + count, err = query.DB().Table("products").Select("name", "weight").Where("name", "count_product1").Count() + s.NoError(err) + s.Equal(int64(1), count) + + count, err = query.DB().Table("products").Select("name as n", "weight").Where("name", "count_product1").Count() + s.NoError(err) + s.Equal(int64(1), count) + + count, err = query.DB().Table("products").Select("name as n").Where("name", "count_product1").Count() + s.NoError(err) + s.Equal(int64(1), count) + + count, err = query.DB().Table("products").Select("name n").Where("name", "count_product1").Count() + s.NoError(err) + s.Equal(int64(1), count) + + count, err = query.DB().Table("products").Select("name").Where("name", "count_product1").Count() + s.NoError(err) + s.Equal(int64(1), count) }) } } @@ -308,21 +333,57 @@ func (s *DBTestSuite) TestDistinct() { for driver, query := range s.queries { s.Run(driver, func() { query.DB().Table("products").Insert([]Product{ + {Name: "distinct_product", Weight: convert.Pointer(1)}, {Name: "distinct_product"}, {Name: "distinct_product"}, }) var products []Product - err := query.DB().Table("products").Distinct().Select("name").Get(&products) + + err := query.DB().Table("products").Distinct().OrderBy("id").Get(&products) + s.NoError(err) + s.Equal(3, len(products)) + s.Equal("distinct_product", products[0].Name) + s.Equal(1, *products[0].Weight) + s.Equal("distinct_product", products[1].Name) + s.Nil(products[1].Weight) + s.Equal("distinct_product", products[2].Name) + s.Nil(products[2].Weight) + + err = query.DB().Table("products").Distinct().Select("name").Get(&products) s.NoError(err) s.Equal(1, len(products)) s.Equal("distinct_product", products[0].Name) - var products1 []Product - err = query.DB().Table("products").Distinct("name").Get(&products1) + err = query.DB().Table("products").Distinct("name").Get(&products) s.NoError(err) s.Equal(1, len(products)) s.Equal("distinct_product", products[0].Name) + + err = query.DB().Table("products").Distinct("name", "weight").Get(&products) + s.NoError(err) + s.Equal(2, len(products)) + + count, err := query.DB().Table("products").Distinct().Count() + s.Error(err) + s.Equal(int64(0), count) + + count, err = query.DB().Table("products").Distinct("name as name").Count() + s.NoError(err) + s.Equal(int64(1), count) + + count, err = query.DB().Table("products").Distinct("name").Count() + s.NoError(err) + s.Equal(int64(1), count) + + count, err = query.DB().Table("products").Distinct("name").Select("name").Count() + s.NoError(err) + s.Equal(int64(1), count) + + // Gorm cannot support multiple distinct fields count directly, the sql will be COUNT(*), keep consistent here. + count, err = query.DB().Table("products").Distinct("name", "weight").Count() + s.NoError(err) + s.Equal(int64(3), count) }) } } @@ -923,7 +984,6 @@ func (s *DBTestSuite) TestPaginate() { s.Equal("paginate_product1", products[0].Name) s.Equal("paginate_product2", products[1].Name) - products = []Product{} err = query.DB().Table("products").WhereLike("name", "paginate_product%").Paginate(2, 2, &products, &total) s.NoError(err) s.Equal(2, len(products)) @@ -932,13 +992,27 @@ func (s *DBTestSuite) TestPaginate() { s.Equal("paginate_product4", products[1].Name) // Fix: https://github.com/goravel/goravel/issues/842 - products = []Product{} + err = query.DB().Table("products").Select("name as name", "weight").WhereLike("name", "paginate_product%").Paginate(2, 2, &products, &total) + s.NoError(err) + s.Equal(2, len(products)) + s.Equal(int64(5), total) + s.Equal("paginate_product3", products[0].Name) + s.Equal("paginate_product4", products[1].Name) + err = query.DB().Table("products").Select("name as name").WhereLike("name", "paginate_product%").Paginate(2, 2, &products, &total) s.NoError(err) s.Equal(2, len(products)) s.Equal(int64(5), total) s.Equal("paginate_product3", products[0].Name) s.Equal("paginate_product4", products[1].Name) + + err = query.DB().Table("products").Select("name name").WhereLike("name", "paginate_product%").Paginate(2, 2, &products, &total) + s.NoError(err) + s.Equal(2, len(products)) + s.Equal(int64(5), total) + s.Equal("paginate_product3", products[0].Name) + s.Equal("paginate_product4", products[1].Name) + }) } } diff --git a/tests/query.go b/tests/query.go index d156de02d..b580fd229 100644 --- a/tests/query.go +++ b/tests/query.go @@ -130,15 +130,15 @@ func NewTestQueryBuilder() *TestQueryBuilder { func (r *TestQueryBuilder) All(prefix string, singular bool) map[string]*TestQuery { postgresTestQuery := r.Postgres(prefix, singular) - mysqlTestQuery := r.Mysql(prefix, singular) - sqlserverTestQuery := r.Sqlserver(prefix, singular) - sqliteTestQuery := r.Sqlite(prefix, singular) + // mysqlTestQuery := r.Mysql(prefix, singular) + // sqlserverTestQuery := r.Sqlserver(prefix, singular) + // sqliteTestQuery := r.Sqlite(prefix, singular) return map[string]*TestQuery{ - postgresTestQuery.Driver().Pool().Writers[0].Driver: postgresTestQuery, - mysqlTestQuery.Driver().Pool().Writers[0].Driver: mysqlTestQuery, - sqlserverTestQuery.Driver().Pool().Writers[0].Driver: sqlserverTestQuery, - sqliteTestQuery.Driver().Pool().Writers[0].Driver: sqliteTestQuery, + postgresTestQuery.Driver().Pool().Writers[0].Driver: postgresTestQuery, + // mysqlTestQuery.Driver().Pool().Writers[0].Driver: mysqlTestQuery, + // sqlserverTestQuery.Driver().Pool().Writers[0].Driver: sqlserverTestQuery, + // sqliteTestQuery.Driver().Pool().Writers[0].Driver: sqliteTestQuery, } } diff --git a/tests/query_test.go b/tests/query_test.go index 7103c9994..c27fdceba 100644 --- a/tests/query_test.go +++ b/tests/query_test.go @@ -314,13 +314,33 @@ func (s *QueryTestSuite) TestCount() { s.Nil(query.Query().Create(&user1)) s.True(user1.ID > 0) - count, err := query.Query().Model(&User{}).Where("name = ?", "count_user").Count() + count, err := query.Query().Model(&User{}).Where("name", "count_user").Count() s.Nil(err) - s.True(count > 0) + s.Equal(int64(2), count) - count, err = query.Query().Table("users").Where("name = ?", "count_user").Count() + count, err = query.Query().Table("users").Where("avatar", "count_avatar1").Count() + s.Nil(err) + s.Equal(int64(1), count) + + count, err = query.Query().Model(&User{}).Select("name", "avatar").Where("avatar", "count_avatar1").Count() + s.Nil(err) + s.Equal(int64(1), count) + + count, err = query.Query().Model(&User{}).Select("name as n", "avatar").Where("avatar", "count_avatar1").Count() + s.Nil(err) + s.Equal(int64(1), count) + + count, err = query.Query().Model(&User{}).Select("name as n").Where("avatar", "count_avatar1").Count() + s.Nil(err) + s.Equal(int64(1), count) + + count, err = query.Query().Model(&User{}).Select("name n").Where("avatar", "count_avatar1").Count() + s.Nil(err) + s.Equal(int64(1), count) + + count, err = query.Query().Model(&User{}).Select("name").Where("avatar", "count_avatar1").Count() s.Nil(err) - s.True(count > 0) + s.Equal(int64(1), count) }) } } @@ -1000,13 +1020,45 @@ func (s *QueryTestSuite) TestDistinct() { s.Nil(query.Query().Create(&user1)) s.True(user1.ID > 0) + user2 := User{Name: "distinct_user", Avatar: "distinct_avatar"} + s.Nil(query.Query().Create(&user2)) + s.True(user2.ID > 0) + var users []User + + s.Nil(query.Query().Distinct().Find(&users)) + s.Equal(3, len(users)) + s.Nil(query.Query().Distinct("name").Find(&users, []uint{user.ID, user1.ID})) s.Equal(1, len(users)) - var users1 []User - s.Nil(query.Query().Distinct().Select("name").Find(&users1, []uint{user.ID, user1.ID})) - s.Equal(1, len(users1)) + s.Nil(query.Query().Distinct("name", "avatar").Find(&users, []uint{user.ID, user1.ID})) + s.Equal(2, len(users)) + + s.Nil(query.Query().Distinct().Select("name").Find(&users, []uint{user.ID, user1.ID})) + s.Equal(1, len(users)) + + // Select should be set when calling Count with Distinct + count, err := query.Query().Model(&User{}).Distinct().Count() + s.Error(err) + s.Equal(int64(0), count) + + count, err = query.Query().Model(&User{}).Distinct("avatar as avatar").Count() + s.NoError(err) + s.Equal(int64(2), count) + + count, err = query.Query().Model(&User{}).Distinct("name").Count() + s.Nil(err) + s.Equal(int64(1), count) + + count, err = query.Query().Model(&User{}).Distinct("name").Select("name").Count() + s.Nil(err) + s.Equal(int64(1), count) + + // Gorm cannot support multiple distinct fields count directly, the sql will be COUNT(*). + count, err = query.Query().Model(&User{}).Distinct("name", "avatar").Count() + s.Nil(err) + s.Equal(int64(3), count) }) } } @@ -3018,9 +3070,21 @@ func (s *QueryTestSuite) TestPaginate() { // Fix: https://github.com/goravel/goravel/issues/842 var users4 []User var total4 int64 - s.Nil(query.Query().Model(&User{}).Select("name as name").Where("name", "paginate_user").Paginate(1, 3, &users4, &total4)) + s.Nil(query.Query().Model(&User{}).Select("name as name", "avatar").Where("name", "paginate_user").Paginate(1, 3, &users4, &total4)) s.Equal(3, len(users4)) s.Equal(int64(4), total4) + + var users5 []User + var total5 int64 + s.Nil(query.Query().Model(&User{}).Select("name as name").Where("name", "paginate_user").Paginate(1, 3, &users5, &total5)) + s.Equal(3, len(users5)) + s.Equal(int64(4), total5) + + var users6 []User + var total6 int64 + s.Nil(query.Query().Model(&User{}).Select("name name").Where("name", "paginate_user").Paginate(1, 3, &users6, &total6)) + s.Equal(3, len(users6)) + s.Equal(int64(4), total6) }) } } diff --git a/tests/to_sql_test.go b/tests/to_sql_test.go index 9cfb59499..997fe9bc6 100644 --- a/tests/to_sql_test.go +++ b/tests/to_sql_test.go @@ -39,6 +39,27 @@ func (s *ToSqlTestSuite) TestCount() { toSql = gorm.NewToSql(s.query.Model(&User{}).Where("id", 1).(*gorm.Query), s.mockLog, true) s.Equal("SELECT count(*) FROM \"users\" WHERE \"id\" = 1 AND \"users\".\"deleted_at\" IS NULL", toSql.Count()) + toSql = gorm.NewToSql(s.query.Model(&User{}).Distinct().Where("id", 1).(*gorm.Query), s.mockLog, true) + s.Equal("SELECT COUNT(DISTINCT(\"*\")) FROM \"users\" WHERE \"id\" = 1 AND \"users\".\"deleted_at\" IS NULL", toSql.Count()) + + toSql = gorm.NewToSql(s.query.Model(&User{}).Distinct("name").Where("id", 1).(*gorm.Query), s.mockLog, true) + s.Equal("SELECT COUNT(DISTINCT(\"name\")) FROM \"users\" WHERE \"id\" = 1 AND \"users\".\"deleted_at\" IS NULL", toSql.Count()) + + toSql = gorm.NewToSql(s.query.Model(&User{}).Distinct("name", "avatar").Where("id", 1).(*gorm.Query), s.mockLog, true) + s.Equal("SELECT count(*) FROM \"users\" WHERE \"id\" = 1 AND \"users\".\"deleted_at\" IS NULL", toSql.Count()) + + toSql = gorm.NewToSql(s.query.Model(&User{}).Select("name", "avatar").Where("id", 1).(*gorm.Query), s.mockLog, true) + s.Equal("SELECT count(*) FROM \"users\" WHERE \"id\" = 1 AND \"users\".\"deleted_at\" IS NULL", toSql.Count()) + + toSql = gorm.NewToSql(s.query.Model(&User{}).Select("name as n").Where("id", 1).(*gorm.Query), s.mockLog, true) + s.Equal("SELECT COUNT(\"name\") FROM \"users\" WHERE \"id\" = 1 AND \"users\".\"deleted_at\" IS NULL", toSql.Count()) + + toSql = gorm.NewToSql(s.query.Model(&User{}).Select("name n").Where("id", 1).(*gorm.Query), s.mockLog, true) + s.Equal("SELECT COUNT(\"name\") FROM \"users\" WHERE \"id\" = 1 AND \"users\".\"deleted_at\" IS NULL", toSql.Count()) + + toSql = gorm.NewToSql(s.query.Model(&User{}).Select("name").Where("id", 1).(*gorm.Query), s.mockLog, true) + s.Equal("SELECT COUNT(\"name\") FROM \"users\" WHERE \"id\" = 1 AND \"users\".\"deleted_at\" IS NULL", toSql.Count()) + // global scopes toSql = gorm.NewToSql(s.query.Model(&GlobalScope{}).Where("id", 1).(*gorm.Query), s.mockLog, false) s.Equal("SELECT count(*) FROM \"global_scopes\" WHERE \"id\" = $1 AND \"name\" = $2 AND \"global_scopes\".\"deleted_at\" IS NULL", toSql.Count())