Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove the WithSQLTransformer option #696

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 0 additions & 38 deletions internal/testutils/sql_transformer.go

This file was deleted.

4 changes: 0 additions & 4 deletions internal/testutils/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,10 +243,6 @@ func WithMigratorAndConnectionToContainer(t *testing.T, fn func(mig *roll.Roll,
WithMigratorInSchemaAndConnectionToContainerWithOptions(t, "public", []roll.Option{roll.WithLockTimeoutMs(500)}, fn)
}

func WithMigratorAndConnectionToContainerWithOptions(t *testing.T, opts []roll.Option, fn func(mig *roll.Roll, db *sql.DB)) {
WithMigratorInSchemaAndConnectionToContainerWithOptions(t, "public", opts, fn)
}

// setupTestDatabase creates a new database in the test container and returns:
// - a connection to the new database
// - the connection string to the new database
Expand Down
21 changes: 4 additions & 17 deletions pkg/migrations/migrations.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,16 @@ type Operation interface {
// version in the database (through a view)
// update the given views to expose the new schema version
// Returns the table that requires backfilling, if any.
Start(ctx context.Context, conn db.DB, latestSchema string, tr SQLTransformer, s *schema.Schema) (*schema.Table, error)
Start(ctx context.Context, conn db.DB, latestSchema string, s *schema.Schema) (*schema.Table, error)

// Complete will update the database schema to match the current version
// after calling Start.
// This method should be called once the previous version is no longer used.
Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error
Complete(ctx context.Context, conn db.DB, s *schema.Schema) error

// Rollback will revert the changes made by Start. It is not possible to
// rollback a completed migration.
Rollback(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error
Rollback(ctx context.Context, conn db.DB, s *schema.Schema) error

// Validate returns a descriptive error if the operation cannot be applied to the given schema.
Validate(ctx context.Context, s *schema.Schema) error
Expand All @@ -48,18 +48,6 @@ type RequiresSchemaRefreshOperation interface {
RequiresSchemaRefresh()
}

// SQLTransformer is an interface that can be used to transform SQL statements.
type SQLTransformer interface {
// TransformSQL will transform the given SQL statement.
TransformSQL(sql string) (string, error)
}

type SQLTransformerFunc func(string) (string, error)

func (fn SQLTransformerFunc) TransformSQL(sql string) (string, error) {
return fn(sql)
}

type (
Operations []Operation
Migration struct {
Expand Down Expand Up @@ -94,12 +82,11 @@ func (m *Migration) Validate(ctx context.Context, s *schema.Schema) error {
// made by the migration. No changes are made to the physical database.
func (m *Migration) UpdateVirtualSchema(ctx context.Context, s *schema.Schema) error {
db := &db.FakeDB{}
tr := SQLTransformerFunc(func(sql string) (string, error) { return sql, nil })

// Run `Start` on each operation using the fake DB. Updates will be made to
// the in-memory schema `s` without touching the physical database.
for _, op := range m.Operations {
if _, err := op.Start(ctx, db, "", tr, s); err != nil {
if _, err := op.Start(ctx, db, "", s); err != nil {
return err
}
}
Expand Down
23 changes: 9 additions & 14 deletions pkg/migrations/op_add_column.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@ import (

var _ Operation = (*OpAddColumn)(nil)

func (o *OpAddColumn) Start(ctx context.Context, conn db.DB, latestSchema string, tr SQLTransformer, s *schema.Schema) (*schema.Table, error) {
func (o *OpAddColumn) Start(ctx context.Context, conn db.DB, latestSchema string, s *schema.Schema) (*schema.Table, error) {
table := s.GetTable(o.Table)
if table == nil {
return nil, TableDoesNotExistError{Name: o.Table}
}

if err := addColumn(ctx, conn, *o, table, tr); err != nil {
if err := addColumn(ctx, conn, *o, table); err != nil {
return nil, fmt.Errorf("failed to start add column operation: %w", err)
}

Expand Down Expand Up @@ -53,7 +53,7 @@ func (o *OpAddColumn) Start(ctx context.Context, conn db.DB, latestSchema string

var tableToBackfill *schema.Table
if o.Up != "" {
err := createTrigger(ctx, conn, tr, triggerConfig{
err := createTrigger(ctx, conn, triggerConfig{
Name: TriggerName(o.Table, o.Column.Name),
Direction: TriggerDirectionUp,
Columns: table.Columns,
Expand All @@ -76,7 +76,7 @@ func (o *OpAddColumn) Start(ctx context.Context, conn db.DB, latestSchema string
return tableToBackfill, nil
}

func (o *OpAddColumn) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error {
func (o *OpAddColumn) Complete(ctx context.Context, conn db.DB, s *schema.Schema) error {
tempName := TemporaryName(o.Column.Name)

_, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s RENAME COLUMN %s TO %s",
Expand Down Expand Up @@ -148,7 +148,7 @@ func (o *OpAddColumn) Complete(ctx context.Context, conn db.DB, tr SQLTransforme
return nil
}

func (o *OpAddColumn) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error {
func (o *OpAddColumn) Rollback(ctx context.Context, conn db.DB, s *schema.Schema) error {
table := s.GetTable(o.Table)
if table == nil {
return TableDoesNotExistError{Name: o.Table}
Expand Down Expand Up @@ -247,7 +247,7 @@ func (o *OpAddColumn) Validate(ctx context.Context, s *schema.Schema) error {
return nil
}

func addColumn(ctx context.Context, conn db.DB, o OpAddColumn, t *schema.Table, tr SQLTransformer) error {
func addColumn(ctx context.Context, conn db.DB, o OpAddColumn, t *schema.Table) error {
// don't add non-nullable columns with no default directly
// they are handled by:
// - adding the column as nullable
Expand Down Expand Up @@ -280,7 +280,7 @@ func addColumn(ctx context.Context, conn db.DB, o OpAddColumn, t *schema.Table,
o.Column.Unique = false

o.Column.Name = TemporaryName(o.Column.Name)
columnWriter := ColumnSQLWriter{WithPK: true, Transformer: tr}
columnWriter := ColumnSQLWriter{WithPK: true}
colSQL, err := columnWriter.Write(o.Column)
if err != nil {
return err
Expand Down Expand Up @@ -331,8 +331,7 @@ func IsNotNullConstraintName(name string) bool {
// It can optionally include the primary key constraint
// When creating a table, the primary key constraint is not added to the column definition
type ColumnSQLWriter struct {
WithPK bool
Transformer SQLTransformer
WithPK bool
}

func (w ColumnSQLWriter) Write(col Column) (string, error) {
Expand All @@ -349,11 +348,7 @@ func (w ColumnSQLWriter) Write(col Column) (string, error) {
sql += " NOT NULL"
}
if col.Default != nil {
d, err := w.Transformer.TransformSQL(*col.Default)
if err != nil {
return "", err
}
sql += fmt.Sprintf(" DEFAULT %s", d)
sql += fmt.Sprintf(" DEFAULT %s", *col.Default)
}

if col.Generated != nil {
Expand Down
101 changes: 0 additions & 101 deletions pkg/migrations/op_add_column_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"github.com/xataio/pgroll/internal/testutils"
"github.com/xataio/pgroll/pkg/backfill"
"github.com/xataio/pgroll/pkg/migrations"
"github.com/xataio/pgroll/pkg/roll"
)

func TestAddColumn(t *testing.T) {
Expand Down Expand Up @@ -1611,106 +1610,6 @@ func TestAddColumnWithComment(t *testing.T) {
}})
}

func TestAddColumnDefaultTransformation(t *testing.T) {
t.Parallel()

sqlTransformer := testutils.NewMockSQLTransformer(map[string]string{
"'default value 1'": "'rewritten'",
"'default value 2'": testutils.MockSQLTransformerError,
})

ExecuteTests(t, TestCases{
{
name: "column default is rewritten by the SQL transformer",
migrations: []migrations.Migration{
{
Name: "01_create_table",
Operations: migrations.Operations{
&migrations.OpCreateTable{
Name: "users",
Columns: []migrations.Column{
{
Name: "id",
Type: "serial",
Pk: true,
},
},
},
},
},
{
Name: "02_add_column",
Operations: migrations.Operations{
&migrations.OpAddColumn{
Table: "users",
Column: migrations.Column{
Name: "name",
Type: "text",
Default: ptr("'default value 1'"),
},
},
},
},
},
afterStart: func(t *testing.T, db *sql.DB, schema string) {
// Insert some data into the table
MustInsert(t, db, schema, "02_add_column", "users", map[string]string{
"id": "1",
})

// Ensure the row has the rewritten default value.
rows := MustSelect(t, db, schema, "02_add_column", "users")
assert.Equal(t, []map[string]any{
{"id": 1, "name": "rewritten"},
}, rows)
},
afterRollback: func(t *testing.T, db *sql.DB, schema string) {
},
afterComplete: func(t *testing.T, db *sql.DB, schema string) {
// Ensure the row has the rewritten default value.
rows := MustSelect(t, db, schema, "02_add_column", "users")
assert.Equal(t, []map[string]any{
{"id": 1, "name": "rewritten"},
}, rows)
},
},
{
name: "operation fails when the SQL transformer returns an error",
migrations: []migrations.Migration{
{
Name: "01_create_table",
Operations: migrations.Operations{
&migrations.OpCreateTable{
Name: "users",
Columns: []migrations.Column{
{
Name: "id",
Type: "serial",
Pk: true,
},
},
},
},
},
{
Name: "02_add_column",
Operations: migrations.Operations{
&migrations.OpAddColumn{
Table: "users",
Column: migrations.Column{
Name: "name",
Type: "text",
Default: ptr("'default value 2'"),
},
},
},
},
},
wantStartErr: testutils.ErrMockSQLTransformer,
},
}, roll.WithSQLTransformer(sqlTransformer))
}

func TestAddColumnInMultiOperationMigrations(t *testing.T) {
t.Parallel()

Expand Down
16 changes: 8 additions & 8 deletions pkg/migrations/op_alter_column.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (

var _ Operation = (*OpAlterColumn)(nil)

func (o *OpAlterColumn) Start(ctx context.Context, conn db.DB, latestSchema string, tr SQLTransformer, s *schema.Schema) (*schema.Table, error) {
func (o *OpAlterColumn) Start(ctx context.Context, conn db.DB, latestSchema string, s *schema.Schema) (*schema.Table, error) {
table := s.GetTable(o.Table)
if table == nil {
return nil, TableDoesNotExistError{Name: o.Table}
Expand All @@ -34,7 +34,7 @@ func (o *OpAlterColumn) Start(ctx context.Context, conn db.DB, latestSchema stri
}

// Add a trigger to copy values from the old column to the new, rewriting values using the `up` SQL.
err := createTrigger(ctx, conn, tr, triggerConfig{
err := createTrigger(ctx, conn, triggerConfig{
Name: TriggerName(o.Table, o.Column),
Direction: TriggerDirectionUp,
Columns: table.Columns,
Expand All @@ -58,7 +58,7 @@ func (o *OpAlterColumn) Start(ctx context.Context, conn db.DB, latestSchema stri
})

// Add a trigger to copy values from the new column to the old.
err = createTrigger(ctx, conn, tr, triggerConfig{
err = createTrigger(ctx, conn, triggerConfig{
Name: TriggerName(o.Table, TemporaryName(o.Column)),
Direction: TriggerDirectionDown,
Columns: table.Columns,
Expand All @@ -74,20 +74,20 @@ func (o *OpAlterColumn) Start(ctx context.Context, conn db.DB, latestSchema stri

// perform any operation specific start steps
for _, op := range ops {
if _, err := op.Start(ctx, conn, latestSchema, tr, s); err != nil {
if _, err := op.Start(ctx, conn, latestSchema, s); err != nil {
return nil, err
}
}

return table, nil
}

func (o *OpAlterColumn) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error {
func (o *OpAlterColumn) Complete(ctx context.Context, conn db.DB, s *schema.Schema) error {
ops := o.subOperations()

// Perform any operation specific completion steps
for _, op := range ops {
if err := op.Complete(ctx, conn, tr, s); err != nil {
if err := op.Complete(ctx, conn, s); err != nil {
return err
}
}
Expand Down Expand Up @@ -142,7 +142,7 @@ func (o *OpAlterColumn) Complete(ctx context.Context, conn db.DB, tr SQLTransfor
return nil
}

func (o *OpAlterColumn) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error {
func (o *OpAlterColumn) Rollback(ctx context.Context, conn db.DB, s *schema.Schema) error {
table := s.GetTable(o.Table)
if table == nil {
return TableDoesNotExistError{Name: o.Table}
Expand All @@ -155,7 +155,7 @@ func (o *OpAlterColumn) Rollback(ctx context.Context, conn db.DB, tr SQLTransfor
// Perform any operation specific rollback steps
ops := o.subOperations()
for _, ops := range ops {
if err := ops.Rollback(ctx, conn, tr, nil); err != nil {
if err := ops.Rollback(ctx, conn, nil); err != nil {
return err
}
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/migrations/op_change_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ type OpChangeType struct {

var _ Operation = (*OpChangeType)(nil)

func (o *OpChangeType) Start(ctx context.Context, conn db.DB, latestSchema string, tr SQLTransformer, s *schema.Schema) (*schema.Table, error) {
func (o *OpChangeType) Start(ctx context.Context, conn db.DB, latestSchema string, s *schema.Schema) (*schema.Table, error) {
table := s.GetTable(o.Table)
if table == nil {
return nil, TableDoesNotExistError{Name: o.Table}
Expand All @@ -28,11 +28,11 @@ func (o *OpChangeType) Start(ctx context.Context, conn db.DB, latestSchema strin
return table, nil
}

func (o *OpChangeType) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error {
func (o *OpChangeType) Complete(ctx context.Context, conn db.DB, s *schema.Schema) error {
return nil
}

func (o *OpChangeType) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error {
func (o *OpChangeType) Rollback(ctx context.Context, conn db.DB, s *schema.Schema) error {
return nil
}

Expand Down
Loading