Skip to content

Commit 4d8e577

Browse files
authored
RDS provider uses writer for dump output (#42)
* RDS provider uses writer for dump output * Linting
1 parent abf212f commit 4d8e577

9 files changed

Lines changed: 230 additions & 156 deletions

File tree

internal/mysql/provider/provider.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
package provider
22

3-
import "context"
3+
import (
4+
"context"
5+
"io"
6+
)
47

58
// Interface implements the required functionality for a Provider.
69
type Interface interface {
7-
GetSelectQueryForTable(ctx context.Context, table string, params DumpParams) (string, error)
10+
WriteTableData(ctx context.Context, w io.Writer, table string, params DumpParams) error
811
}
912

1013
// DumpParams is used to pass parameters to the Dump function.

internal/mysql/provider/rds/provider.go

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"database/sql"
66
"fmt"
7+
"io"
78
"log"
89
"strings"
910

@@ -31,11 +32,30 @@ func NewClient(conn *sql.Conn, logger *log.Logger, region, uri string) *Client {
3132
}
3233
}
3334

34-
// GetSelectQueryForTable will return a complete SELECT query to export data from a table.
35-
func (d *Client) GetSelectQueryForTable(ctx context.Context, table string, params provider.DumpParams) (string, error) {
35+
// WriteTableData will export the data from a table to S3 and write the LOAD DATA query to the provided writer.
36+
func (d *Client) WriteTableData(ctx context.Context, w io.Writer, table string, params provider.DumpParams) error {
37+
// Push table data to s3.
38+
err := d.exportTableData(ctx, table, params)
39+
if err != nil {
40+
return err
41+
}
42+
43+
// Write the import query to the writer.
44+
err = d.writeLoadQueryForTable(w, table)
45+
if err != nil {
46+
return err
47+
}
48+
49+
return nil
50+
}
51+
52+
// Export the data from a table to S3.
53+
func (d *Client) exportTableData(ctx context.Context, table string, params provider.DumpParams) error {
54+
d.Logger.Printf("Exporting data to S3 for table: %s", table)
55+
3656
cols, err := providerutils.QueryColumnsForTable(ctx, d.Conn, table, params)
3757
if err != nil {
38-
return "", err
58+
return err
3959
}
4060

4161
query := fmt.Sprintf("SELECT %s", strings.Join(cols, ", "))
@@ -50,26 +70,28 @@ func (d *Client) GetSelectQueryForTable(ctx context.Context, table string, param
5070
query = fmt.Sprintf("%s MANIFEST ON", query)
5171
query = fmt.Sprintf("%s OVERWRITE ON", query)
5272

53-
importQuery, err := d.GetLoadQueryForTable(table)
73+
_, err = d.Conn.QueryContext(ctx, query)
5474
if err != nil {
55-
return "", err
75+
return fmt.Errorf("error exporting data to S3 for table %s: %w", table, err)
5676
}
5777

58-
fmt.Println(importQuery)
59-
return query, nil
78+
return nil
6079
}
6180

62-
// GetLoadQueryForTable will return a complete SELECT query to fetch data from a table.
63-
func (d *Client) GetLoadQueryForTable(table string) (string, error) {
81+
// Write a LOAD DATA FROM S3 query to the provided writer.
82+
func (d *Client) writeLoadQueryForTable(w io.Writer, table string) error {
6483
if table == "" {
65-
return "", fmt.Errorf("error: no table specified")
84+
return fmt.Errorf("error: no table specified")
6685
}
86+
6787
if d.Region == "" || len(strings.Split(d.Region, "-")) != 3 {
68-
return "", fmt.Errorf("error: region is not configured correctly")
88+
return fmt.Errorf("error: region is not configured correctly")
6989
}
90+
7091
path := strings.TrimPrefix(d.URI, "s3://")
7192
query := fmt.Sprintf("LOAD DATA FROM S3 MANIFEST 'S3-%s://%s/%s.csv.manifest' INTO TABLE `%s` CHARACTER SET utf8mb4", d.Region, path, table, table)
7293
query = fmt.Sprintf("%s FIELDS TERMINATED BY ',' ENCLOSED BY '\"' LINES TERMINATED BY '\\n'", query)
7394

74-
return query, nil
95+
_, err := fmt.Fprintln(w, query)
96+
return err
7597
}
Lines changed: 56 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,69 @@
11
package rds
22

33
import (
4+
"bytes"
45
"context"
6+
"io"
57
"log"
6-
"os"
8+
"regexp"
79
"testing"
810

911
"github.com/DATA-DOG/go-sqlmock"
10-
"github.com/stretchr/testify/assert"
11-
12-
"github.com/skpr/mtk/internal/mysql/mock"
1312
"github.com/skpr/mtk/internal/mysql/provider"
1413
)
1514

16-
func TestMySQLGetExportSelectQueryFor(t *testing.T) {
17-
db, mock := mock.GetDB(t)
18-
dumper := NewClient(db, log.New(os.Stdout, "", 0), "ap-southheast-2", "s3://path/to/bucket")
19-
mock.ExpectQuery("SELECT \\* FROM `table` LIMIT 1").WillReturnRows(
20-
sqlmock.NewRows([]string{"c1", "c2"}).AddRow("a", "b"))
21-
query, err := dumper.GetSelectQueryForTable(context.TODO(), "table", provider.DumpParams{
22-
SelectMap: map[string]map[string]string{"table": {"c2": "NOW()"}},
23-
WhereMap: map[string]string{"table": "c1 > 0"},
24-
})
25-
assert.Nil(t, err)
26-
assert.Equal(t, "SELECT `c1`, NOW() AS `c2` FROM `table` WHERE c1 > 0 INTO OUTFILE S3 's3://path/to/bucket/table.csv' FIELDS TERMINATED BY ',' ENCLOSED BY '\"' LINES TERMINATED BY '\\n' MANIFEST ON OVERWRITE ON", query)
27-
}
15+
func TestWriteTableData(t *testing.T) {
16+
db, mock, err := sqlmock.New()
17+
if err != nil {
18+
t.Fatalf("sqlmock.New error: %v", err)
19+
}
20+
defer db.Close()
21+
22+
conn, err := db.Conn(context.Background())
23+
if err != nil {
24+
t.Fatalf("db.Conn error: %v", err)
25+
}
26+
27+
// 1) QueryColumnsForTable runs: expect the probe query and return columns.
28+
mock.
29+
ExpectQuery(regexp.QuoteMeta("SELECT * FROM `users` LIMIT 1")).
30+
WillReturnRows(sqlmock.NewRows([]string{"id", "name"})) // Columns only
31+
32+
// 2) exportTableData runs: expect the exact OUTFILE S3 query with backticked cols.
33+
exportSQL := "SELECT `id`, `name` FROM `users` INTO OUTFILE S3 's3://my-bucket/prefix/users.csv' " +
34+
"FIELDS TERMINATED BY ',' ENCLOSED BY '\"' LINES TERMINATED BY '\\n' MANIFEST ON OVERWRITE ON"
35+
mock.
36+
ExpectQuery(regexp.QuoteMeta(exportSQL)).
37+
WillReturnRows(sqlmock.NewRows([]string{}))
38+
39+
// Client + buffer for the LOAD DATA query
40+
buf := &bytes.Buffer{}
41+
c := &Client{
42+
Conn: conn,
43+
Logger: log.New(io.Discard, "", 0),
44+
Region: "ap-southeast-2",
45+
URI: "s3://my-bucket/prefix",
46+
}
47+
48+
// Run the full flow
49+
if err := c.WriteTableData(context.Background(), buf, "users", provider.DumpParams{
50+
WhereMap: map[string]string{},
51+
SelectMap: map[string]map[string]string{},
52+
}); err != nil {
53+
t.Fatalf("WriteTableData error: %v", err)
54+
}
55+
56+
// Assert the LOAD DATA query is exactly written (including newline)
57+
got := buf.String()
58+
want := "LOAD DATA FROM S3 MANIFEST 'S3-ap-southeast-2://my-bucket/prefix/users.csv.manifest' " +
59+
"INTO TABLE `users` CHARACTER SET utf8mb4 FIELDS TERMINATED BY ',' ENCLOSED BY '\"' " +
60+
"LINES TERMINATED BY '\\n'\n"
61+
if got != want {
62+
t.Fatalf("unexpected load query:\n--- got ---\n%s\n--- want ---\n%s", got, want)
63+
}
2864

29-
func TestMySQLGetLoadQueryFor(t *testing.T) {
30-
db, _ := mock.GetDB(t)
31-
dumper := NewClient(db, log.New(os.Stdout, "", 0), "ap-southeast-4", "s3://path/to/bucket")
32-
query, err := dumper.GetLoadQueryForTable("table_name")
33-
assert.Nil(t, err)
34-
assert.Equal(t, "LOAD DATA FROM S3 MANIFEST 'S3-ap-southeast-4://path/to/bucket/table_name.csv.manifest' INTO TABLE `table_name` CHARACTER SET utf8mb4 FIELDS TERMINATED BY ',' ENCLOSED BY '\"' LINES TERMINATED BY '\\n'", query)
65+
// Ensure all expectations met
66+
if err := mock.ExpectationsWereMet(); err != nil {
67+
t.Fatalf("unmet sqlmock expectations: %v", err)
68+
}
3569
}

internal/mysql/provider/stdout/provider.go

Lines changed: 72 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"database/sql"
66
"fmt"
7+
"io"
78
"log"
89
"strings"
910

@@ -26,11 +27,11 @@ func NewClient(conn *sql.Conn, logger *log.Logger) *Client {
2627
}
2728
}
2829

29-
// GetSelectQueryForTable will return a complete SELECT query to fetch data from a table.
30-
func (d *Client) GetSelectQueryForTable(ctx context.Context, table string, params provider.DumpParams) (string, error) {
30+
// WriteTableData will write the data from a table to the provided writer.
31+
func (d *Client) WriteTableData(ctx context.Context, w io.Writer, table string, params provider.DumpParams) error {
3132
cols, err := providerutils.QueryColumnsForTable(ctx, d.Conn, table, params)
3233
if err != nil {
33-
return "", err
34+
return err
3435
}
3536

3637
query := fmt.Sprintf("SELECT %s FROM `%s`", strings.Join(cols, ", "), table)
@@ -39,5 +40,72 @@ func (d *Client) GetSelectQueryForTable(ctx context.Context, table string, param
3940
query = fmt.Sprintf("%s WHERE %s", query, where)
4041
}
4142

42-
return query, nil
43+
rows, err := d.Conn.QueryContext(ctx, query)
44+
if err != nil {
45+
return err
46+
}
47+
48+
columns, err := rows.Columns()
49+
if err != nil {
50+
return err
51+
}
52+
53+
defer rows.Close()
54+
55+
values := make([]*sql.RawBytes, len(columns))
56+
scanArgs := make([]interface{}, len(values))
57+
58+
for i := range values {
59+
scanArgs[i] = &values[i]
60+
}
61+
62+
var (
63+
counter = 0
64+
firstRun = true
65+
)
66+
67+
for rows.Next() {
68+
// We have already done a loop and need to close the previous insert statement.
69+
if counter >= params.ExtendedInsertRows {
70+
fmt.Fprintln(w, ";")
71+
counter = 0
72+
} else {
73+
if !firstRun {
74+
fmt.Fprint(w, ",")
75+
}
76+
}
77+
78+
if counter == 0 {
79+
fmt.Fprintf(w, "INSERT INTO `%s` VALUES ", table)
80+
}
81+
82+
counter++
83+
84+
firstRun = false
85+
86+
if err = rows.Scan(scanArgs...); err != nil {
87+
return err
88+
}
89+
90+
var vals []string
91+
92+
for _, col := range values {
93+
val := "NULL"
94+
95+
if col != nil {
96+
val, err = getValue(string(*col))
97+
if err != nil {
98+
return err
99+
}
100+
}
101+
102+
vals = append(vals, val)
103+
}
104+
105+
fmt.Fprintf(w, "(%s)", strings.Join(vals, ","))
106+
}
107+
108+
fmt.Fprintln(w, ";")
109+
110+
return nil
43111
}
Lines changed: 59 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,72 @@
11
package stdout
22

33
import (
4+
"bytes"
45
"context"
5-
"errors"
6+
"io"
67
"log"
7-
"os"
8+
"regexp"
89
"testing"
910

1011
"github.com/DATA-DOG/go-sqlmock"
11-
"github.com/stretchr/testify/assert"
12-
13-
"github.com/skpr/mtk/internal/mysql/mock"
1412
"github.com/skpr/mtk/internal/mysql/provider"
1513
)
1614

17-
func TestMySQLGetSelectQueryFor(t *testing.T) {
18-
db, mock := mock.GetDB(t)
19-
dumper := NewClient(db, log.New(os.Stdout, "", 0))
20-
mock.ExpectQuery("SELECT \\* FROM `table` LIMIT 1").WillReturnRows(
21-
sqlmock.NewRows([]string{"c1", "c2"}).AddRow("a", "b"))
22-
query, err := dumper.GetSelectQueryForTable(context.TODO(), "table", provider.DumpParams{
23-
SelectMap: map[string]map[string]string{"table": {"c2": "NOW()"}},
24-
WhereMap: map[string]string{"table": "c1 > 0"},
25-
})
26-
assert.Nil(t, err)
27-
assert.Equal(t, "SELECT `c1`, NOW() AS `c2` FROM `table` WHERE c1 > 0", query)
28-
}
15+
func TestWriteTableData(t *testing.T) {
16+
db, mock, err := sqlmock.New()
17+
if err != nil {
18+
t.Fatalf("sqlmock.New error: %v", err)
19+
}
20+
defer db.Close()
21+
22+
conn, err := db.Conn(context.Background())
23+
if err != nil {
24+
t.Fatalf("db.Conn error: %v", err)
25+
}
26+
27+
// Probe from utils.QueryColumnsForTable
28+
mock.ExpectQuery(regexp.QuoteMeta("SELECT * FROM `users` LIMIT 1")).
29+
WillReturnRows(sqlmock.NewRows([]string{"id", "name"}))
30+
31+
// Main SELECT that GetSelectQueryForTable issues
32+
mainSelect := "SELECT `id`, `name` FROM `users`"
33+
rows := sqlmock.NewRows([]string{"id", "name"}).
34+
AddRow(1, "Alice").
35+
AddRow(2, "Bob").
36+
AddRow(3, "Charlie").
37+
AddRow(4, "Dana").
38+
AddRow(5, "Eve")
39+
mock.ExpectQuery(regexp.QuoteMeta(mainSelect)).WillReturnRows(rows)
40+
41+
var out bytes.Buffer
42+
c := &Client{
43+
Conn: conn,
44+
Logger: log.New(io.Discard, "", 0),
45+
}
46+
47+
params := provider.DumpParams{
48+
ExtendedInsertRows: 2,
49+
WhereMap: map[string]string{},
50+
SelectMap: map[string]map[string]string{},
51+
}
52+
53+
if err := c.WriteTableData(context.Background(), &out, "users", params); err != nil {
54+
t.Fatalf("WriteTableData returned error: %v", err)
55+
}
56+
57+
got := out.String()
58+
59+
// Expect two batches of 2 and one batch of 1 (with final semicolon inline).
60+
want :=
61+
"INSERT INTO `users` VALUES (1,'Alice'),(2,'Bob');\n" +
62+
"INSERT INTO `users` VALUES (3,'Charlie'),(4,'Dana');\n" +
63+
"INSERT INTO `users` VALUES (5,'Eve');\n"
64+
65+
if got != want {
66+
t.Fatalf("unexpected writer output:\n--- got ---\n%s--- want ---\n%s", got, want)
67+
}
2968

30-
func TestMySQLGetSelectQueryForHandlingError(t *testing.T) {
31-
db, mock := mock.GetDB(t)
32-
dumper := NewClient(db, log.New(os.Stdout, "", 0))
33-
e := errors.New("broken")
34-
mock.ExpectQuery("SELECT \\* FROM `table` LIMIT 1").WillReturnError(e)
35-
query, err := dumper.GetSelectQueryForTable(context.TODO(), "table", provider.DumpParams{
36-
SelectMap: map[string]map[string]string{"table": {"c2": "NOW()"}},
37-
WhereMap: map[string]string{"table": "c1 > 0"},
38-
})
39-
assert.Equal(t, e, err)
40-
assert.Equal(t, "", query)
69+
if err := mock.ExpectationsWereMet(); err != nil {
70+
t.Fatalf("unmet sqlmock expectations: %v", err)
71+
}
4172
}
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package mysql
1+
package stdout
22

33
import (
44
"bytes"

internal/mysql/utils_test.go renamed to internal/mysql/provider/stdout/utils_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package mysql
1+
package stdout
22

33
import (
44
"testing"

0 commit comments

Comments
 (0)