Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions pgstmt/select.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,31 @@ type SelectStatement interface {
From(table ...string)
FromSelect(f func(b SelectStatement), as string)
FromValues(f func(b Values), as string)

Join(table string) Join
InnerJoin(table string) Join
FullOuterJoin(table string) Join
LeftJoin(table string) Join
RightJoin(table string) Join

JoinSelect(f func(b SelectStatement), as string) Join
InnerJoinSelect(f func(b SelectStatement), as string) Join
FullOuterJoinSelect(f func(b SelectStatement), as string) Join
LeftJoinSelect(f func(b SelectStatement), as string) Join
RightJoinSelect(f func(b SelectStatement), as string) Join

JoinLateralSelect(f func(b SelectStatement), as string) Join
InnerJoinLateralSelect(f func(b SelectStatement), as string) Join
FullOuterJoinLateralSelect(f func(b SelectStatement), as string) Join
LeftJoinLateralSelect(f func(b SelectStatement), as string) Join
RightJoinLateralSelect(f func(b SelectStatement), as string) Join

JoinUnion(f func(b UnionStatement), as string) Join
InnerJoinUnion(f func(b UnionStatement), as string) Join
FullOuterJoinUnion(f func(b UnionStatement), as string) Join
LeftJoinUnion(f func(b UnionStatement), as string) Join
RightJoinUnion(f func(b UnionStatement), as string) Join

Where(f func(b Cond))
GroupBy(col ...string)
Having(f func(b Cond))
Expand Down Expand Up @@ -217,6 +227,44 @@ func (st *selectStmt) RightJoinLateralSelect(f func(b SelectStatement), as strin
return st.joinSelect("right join lateral", f, as)
}

func (st *selectStmt) joinUnion(typ string, f func(b UnionStatement), as string) Join {
var x unionStmt
f(&x)

var b buffer
b.push(paren(x.make()))
if as != "" {
b.push(as)
}

j := join{
typ: typ,
table: &b,
}
st.joins.push(&j)
return &j
}

func (st *selectStmt) JoinUnion(f func(b UnionStatement), as string) Join {
return st.joinUnion("join", f, as)
}

func (st *selectStmt) InnerJoinUnion(f func(b UnionStatement), as string) Join {
return st.joinUnion("inner join", f, as)
}

func (st *selectStmt) FullOuterJoinUnion(f func(b UnionStatement), as string) Join {
return st.joinUnion("full outer join", f, as)
}

func (st *selectStmt) LeftJoinUnion(f func(b UnionStatement), as string) Join {
return st.joinUnion("left join", f, as)
}

func (st *selectStmt) RightJoinUnion(f func(b UnionStatement), as string) Join {
return st.joinUnion("right join", f, as)
}

func (st *selectStmt) Where(f func(b Cond)) {
f(&st.where)
}
Expand Down
31 changes: 31 additions & 0 deletions pgstmt/select_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,37 @@ func TestSelect(t *testing.T) {
`,
nil,
},
{
"inner join union",
pgstmt.Select(func(b pgstmt.SelectStatement) {
b.Columns("id")
b.From("table1")
b.InnerJoinUnion(func(b pgstmt.UnionStatement) {
b.Select(func(b pgstmt.SelectStatement) {
b.Columns("id")
b.From("table2")
})
b.AllSelect(func(b pgstmt.SelectStatement) {
b.Columns("id")
b.From("table3")
})
b.OrderBy("id").Desc()
b.Limit(100)
}, "t").Using("id")
}),
`
select id
from table1
inner join (
(select id from table2)
union all
(select id from table3)
order by id desc
limit 100
) t using (id)
`,
nil,
},
}

for _, tC := range cases {
Expand Down
99 changes: 99 additions & 0 deletions pgstmt/union.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package pgstmt

func Union(f func(b UnionStatement)) *Result {
var st unionStmt
f(&st)
return newResult(build(st.make()))
}

type UnionStatement interface {
Select(f func(b SelectStatement))
AllSelect(f func(b SelectStatement))
Union(f func(b UnionStatement))
AllUnion(f func(b UnionStatement))
OrderBy(col string) OrderBy
Limit(n int64)
Offset(n int64)
}

type unionStmt struct {
b buffer
orderBy group
limit *int64
offset *int64
}

func (st *unionStmt) Select(f func(b SelectStatement)) {
var x selectStmt
f(&x)

if st.b.empty() {
st.b.push(paren(x.make()))
} else {
st.b.push("union", paren(x.make()))
}
}

func (st *unionStmt) AllSelect(f func(b SelectStatement)) {
var x selectStmt
f(&x)

if st.b.empty() {
st.b.push(paren(x.make()))
} else {
st.b.push("union all", paren(x.make()))
}
}

func (st *unionStmt) Union(f func(b UnionStatement)) {
var x unionStmt
f(&x)

if st.b.empty() {
st.b.push(paren(x.make()))
} else {
st.b.push("union", paren(x.make()))
}
}

func (st *unionStmt) AllUnion(f func(b UnionStatement)) {
var x unionStmt
f(&x)

if st.b.empty() {
st.b.push(paren(x.make()))
} else {
st.b.push("union all", paren(x.make()))
}
}

func (st *unionStmt) OrderBy(col string) OrderBy {
p := orderBy{
col: col,
}
st.orderBy.push(&p)
return &p
}

func (st *unionStmt) Limit(n int64) {
st.limit = &n
}

func (st *unionStmt) Offset(n int64) {
st.offset = &n
}

func (st *unionStmt) make() *buffer {
var b buffer
b.push(&st.b)
if !st.orderBy.empty() {
b.push("order by", &st.orderBy)
}
if st.limit != nil {
b.push("limit", *st.limit)
}
if st.offset != nil {
b.push("offset", *st.offset)
}
return &b
}
94 changes: 94 additions & 0 deletions pgstmt/union_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package pgstmt_test

import (
"testing"

"github.com/stretchr/testify/assert"

"github.com/acoshift/pgsql/pgstmt"
)

func TestUnion(t *testing.T) {
t.Parallel()

cases := []struct {
name string
result *pgstmt.Result
query string
args []interface{}
}{
{
"union select",
pgstmt.Union(func(b pgstmt.UnionStatement) {
b.Select(func(b pgstmt.SelectStatement) {
b.Columns("id")
b.From("table1")
})
b.AllSelect(func(b pgstmt.SelectStatement) {
b.Columns("id")
b.From("table2")
})
b.OrderBy("id")
b.Limit(10)
b.Offset(2)
}),
`
(select id from table1)
union all (select id from table2)
order by id
limit 10 offset 2
`,
nil,
},
{
"union nested",
pgstmt.Union(func(b pgstmt.UnionStatement) {
b.Union(func(b pgstmt.UnionStatement) {
b.Select(func(b pgstmt.SelectStatement) {
b.Columns("id")
b.From("table1")
})
b.Select(func(b pgstmt.SelectStatement) {
b.Columns("id")
b.From("table2")
})
})
b.Select(func(b pgstmt.SelectStatement) {
b.Columns("id")
b.From("table3")
})
b.AllUnion(func(b pgstmt.UnionStatement) {
b.Select(func(b pgstmt.SelectStatement) {
b.Columns("id")
b.From("table4")
})
b.Select(func(b pgstmt.SelectStatement) {
b.Columns("id")
b.From("table5")
})
})
}),
`
(
(select id from table1)
union (select id from table2)
)
union (select id from table3)
union all (
(select id from table4)
union
(select id from table5)
)
`,
nil,
},
}

for _, tC := range cases {
t.Run(tC.name, func(t *testing.T) {
q, args := tC.result.SQL()
assert.Equal(t, stripSpace(tC.query), q)
assert.EqualValues(t, tC.args, args)
})
}
}
2 changes: 2 additions & 0 deletions pgstmt/util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,7 @@ func stripSpace(s string) string {
}
s = p
}
s = strings.ReplaceAll(s, "( ", "(")
s = strings.ReplaceAll(s, " )", ")")
return s
}