diff --git a/internal/testutils/sql_transformer.go b/internal/testutils/sql_transformer.go deleted file mode 100644 index bdaf6151a..000000000 --- a/internal/testutils/sql_transformer.go +++ /dev/null @@ -1,38 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 - -package testutils - -import "errors" - -type MockSQLTransformer struct { - transformations map[string]string -} - -const MockSQLTransformerError = "ERROR" - -var ErrMockSQLTransformer = errors.New("SQL transformer error") - -// NewMockSQLTransformer creates a MockSQLTransformer with the given transformations. -// The transformations map is a map of input SQL to output SQL. If the output -// SQL is "ERROR", the transformer will return an error on that input. -func NewMockSQLTransformer(ts map[string]string) *MockSQLTransformer { - return &MockSQLTransformer{ - transformations: ts, - } -} - -// TransformSQL transforms the given SQL string according to the transformations -// provided to the MockSQLTransformer. If the input SQL is not in the transformations -// map, the input SQL is returned unchanged. -func (s *MockSQLTransformer) TransformSQL(sql string) (string, error) { - out, found := s.transformations[sql] - if !found { - return sql, nil - } - - if out == MockSQLTransformerError { - return "", ErrMockSQLTransformer - } - - return out, nil -} diff --git a/internal/testutils/util.go b/internal/testutils/util.go index 4893878a4..12d81796e 100644 --- a/internal/testutils/util.go +++ b/internal/testutils/util.go @@ -243,10 +243,6 @@ func WithMigratorAndConnectionToContainer(t *testing.T, fn func(mig *roll.Roll, WithMigratorInSchemaAndConnectionToContainerWithOptions(t, "public", []roll.Option{roll.WithLockTimeoutMs(500)}, fn) } -func WithMigratorAndConnectionToContainerWithOptions(t *testing.T, opts []roll.Option, fn func(mig *roll.Roll, db *sql.DB)) { - WithMigratorInSchemaAndConnectionToContainerWithOptions(t, "public", opts, fn) -} - // setupTestDatabase creates a new database in the test container and returns: // - a connection to the new database // - the connection string to the new database diff --git a/pkg/migrations/migrations.go b/pkg/migrations/migrations.go index d62f38d87..12c1efd82 100644 --- a/pkg/migrations/migrations.go +++ b/pkg/migrations/migrations.go @@ -18,16 +18,16 @@ type Operation interface { // version in the database (through a view) // update the given views to expose the new schema version // Returns the table that requires backfilling, if any. - Start(ctx context.Context, conn db.DB, latestSchema string, tr SQLTransformer, s *schema.Schema) (*schema.Table, error) + Start(ctx context.Context, conn db.DB, latestSchema string, s *schema.Schema) (*schema.Table, error) // Complete will update the database schema to match the current version // after calling Start. // This method should be called once the previous version is no longer used. - Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error + Complete(ctx context.Context, conn db.DB, s *schema.Schema) error // Rollback will revert the changes made by Start. It is not possible to // rollback a completed migration. - Rollback(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error + Rollback(ctx context.Context, conn db.DB, s *schema.Schema) error // Validate returns a descriptive error if the operation cannot be applied to the given schema. Validate(ctx context.Context, s *schema.Schema) error @@ -48,18 +48,6 @@ type RequiresSchemaRefreshOperation interface { RequiresSchemaRefresh() } -// SQLTransformer is an interface that can be used to transform SQL statements. -type SQLTransformer interface { - // TransformSQL will transform the given SQL statement. - TransformSQL(sql string) (string, error) -} - -type SQLTransformerFunc func(string) (string, error) - -func (fn SQLTransformerFunc) TransformSQL(sql string) (string, error) { - return fn(sql) -} - type ( Operations []Operation Migration struct { @@ -94,12 +82,11 @@ func (m *Migration) Validate(ctx context.Context, s *schema.Schema) error { // made by the migration. No changes are made to the physical database. func (m *Migration) UpdateVirtualSchema(ctx context.Context, s *schema.Schema) error { db := &db.FakeDB{} - tr := SQLTransformerFunc(func(sql string) (string, error) { return sql, nil }) // Run `Start` on each operation using the fake DB. Updates will be made to // the in-memory schema `s` without touching the physical database. for _, op := range m.Operations { - if _, err := op.Start(ctx, db, "", tr, s); err != nil { + if _, err := op.Start(ctx, db, "", s); err != nil { return err } } diff --git a/pkg/migrations/op_add_column.go b/pkg/migrations/op_add_column.go index 2a47ffa53..dd8aa14db 100644 --- a/pkg/migrations/op_add_column.go +++ b/pkg/migrations/op_add_column.go @@ -17,13 +17,13 @@ import ( var _ Operation = (*OpAddColumn)(nil) -func (o *OpAddColumn) Start(ctx context.Context, conn db.DB, latestSchema string, tr SQLTransformer, s *schema.Schema) (*schema.Table, error) { +func (o *OpAddColumn) Start(ctx context.Context, conn db.DB, latestSchema string, s *schema.Schema) (*schema.Table, error) { table := s.GetTable(o.Table) if table == nil { return nil, TableDoesNotExistError{Name: o.Table} } - if err := addColumn(ctx, conn, *o, table, tr); err != nil { + if err := addColumn(ctx, conn, *o, table); err != nil { return nil, fmt.Errorf("failed to start add column operation: %w", err) } @@ -53,7 +53,7 @@ func (o *OpAddColumn) Start(ctx context.Context, conn db.DB, latestSchema string var tableToBackfill *schema.Table if o.Up != "" { - err := createTrigger(ctx, conn, tr, triggerConfig{ + err := createTrigger(ctx, conn, triggerConfig{ Name: TriggerName(o.Table, o.Column.Name), Direction: TriggerDirectionUp, Columns: table.Columns, @@ -76,7 +76,7 @@ func (o *OpAddColumn) Start(ctx context.Context, conn db.DB, latestSchema string return tableToBackfill, nil } -func (o *OpAddColumn) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpAddColumn) Complete(ctx context.Context, conn db.DB, s *schema.Schema) error { tempName := TemporaryName(o.Column.Name) _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s RENAME COLUMN %s TO %s", @@ -148,7 +148,7 @@ func (o *OpAddColumn) Complete(ctx context.Context, conn db.DB, tr SQLTransforme return nil } -func (o *OpAddColumn) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpAddColumn) Rollback(ctx context.Context, conn db.DB, s *schema.Schema) error { table := s.GetTable(o.Table) if table == nil { return TableDoesNotExistError{Name: o.Table} @@ -247,7 +247,7 @@ func (o *OpAddColumn) Validate(ctx context.Context, s *schema.Schema) error { return nil } -func addColumn(ctx context.Context, conn db.DB, o OpAddColumn, t *schema.Table, tr SQLTransformer) error { +func addColumn(ctx context.Context, conn db.DB, o OpAddColumn, t *schema.Table) error { // don't add non-nullable columns with no default directly // they are handled by: // - adding the column as nullable @@ -280,7 +280,7 @@ func addColumn(ctx context.Context, conn db.DB, o OpAddColumn, t *schema.Table, o.Column.Unique = false o.Column.Name = TemporaryName(o.Column.Name) - columnWriter := ColumnSQLWriter{WithPK: true, Transformer: tr} + columnWriter := ColumnSQLWriter{WithPK: true} colSQL, err := columnWriter.Write(o.Column) if err != nil { return err @@ -331,8 +331,7 @@ func IsNotNullConstraintName(name string) bool { // It can optionally include the primary key constraint // When creating a table, the primary key constraint is not added to the column definition type ColumnSQLWriter struct { - WithPK bool - Transformer SQLTransformer + WithPK bool } func (w ColumnSQLWriter) Write(col Column) (string, error) { @@ -349,11 +348,7 @@ func (w ColumnSQLWriter) Write(col Column) (string, error) { sql += " NOT NULL" } if col.Default != nil { - d, err := w.Transformer.TransformSQL(*col.Default) - if err != nil { - return "", err - } - sql += fmt.Sprintf(" DEFAULT %s", d) + sql += fmt.Sprintf(" DEFAULT %s", *col.Default) } if col.Generated != nil { diff --git a/pkg/migrations/op_add_column_test.go b/pkg/migrations/op_add_column_test.go index 8313c1299..24b161e4c 100644 --- a/pkg/migrations/op_add_column_test.go +++ b/pkg/migrations/op_add_column_test.go @@ -12,7 +12,6 @@ import ( "github.com/xataio/pgroll/internal/testutils" "github.com/xataio/pgroll/pkg/backfill" "github.com/xataio/pgroll/pkg/migrations" - "github.com/xataio/pgroll/pkg/roll" ) func TestAddColumn(t *testing.T) { @@ -1611,106 +1610,6 @@ func TestAddColumnWithComment(t *testing.T) { }}) } -func TestAddColumnDefaultTransformation(t *testing.T) { - t.Parallel() - - sqlTransformer := testutils.NewMockSQLTransformer(map[string]string{ - "'default value 1'": "'rewritten'", - "'default value 2'": testutils.MockSQLTransformerError, - }) - - ExecuteTests(t, TestCases{ - { - name: "column default is rewritten by the SQL transformer", - migrations: []migrations.Migration{ - { - Name: "01_create_table", - Operations: migrations.Operations{ - &migrations.OpCreateTable{ - Name: "users", - Columns: []migrations.Column{ - { - Name: "id", - Type: "serial", - Pk: true, - }, - }, - }, - }, - }, - { - Name: "02_add_column", - Operations: migrations.Operations{ - &migrations.OpAddColumn{ - Table: "users", - Column: migrations.Column{ - Name: "name", - Type: "text", - Default: ptr("'default value 1'"), - }, - }, - }, - }, - }, - afterStart: func(t *testing.T, db *sql.DB, schema string) { - // Insert some data into the table - MustInsert(t, db, schema, "02_add_column", "users", map[string]string{ - "id": "1", - }) - - // Ensure the row has the rewritten default value. - rows := MustSelect(t, db, schema, "02_add_column", "users") - assert.Equal(t, []map[string]any{ - {"id": 1, "name": "rewritten"}, - }, rows) - }, - afterRollback: func(t *testing.T, db *sql.DB, schema string) { - }, - afterComplete: func(t *testing.T, db *sql.DB, schema string) { - // Ensure the row has the rewritten default value. - rows := MustSelect(t, db, schema, "02_add_column", "users") - assert.Equal(t, []map[string]any{ - {"id": 1, "name": "rewritten"}, - }, rows) - }, - }, - { - name: "operation fails when the SQL transformer returns an error", - migrations: []migrations.Migration{ - { - Name: "01_create_table", - Operations: migrations.Operations{ - &migrations.OpCreateTable{ - Name: "users", - Columns: []migrations.Column{ - { - Name: "id", - Type: "serial", - Pk: true, - }, - }, - }, - }, - }, - { - Name: "02_add_column", - Operations: migrations.Operations{ - &migrations.OpAddColumn{ - Table: "users", - Column: migrations.Column{ - Name: "name", - Type: "text", - Default: ptr("'default value 2'"), - }, - }, - }, - }, - }, - wantStartErr: testutils.ErrMockSQLTransformer, - }, - }, roll.WithSQLTransformer(sqlTransformer)) -} - func TestAddColumnInMultiOperationMigrations(t *testing.T) { t.Parallel() diff --git a/pkg/migrations/op_alter_column.go b/pkg/migrations/op_alter_column.go index 10e768f5a..a2fd14c41 100644 --- a/pkg/migrations/op_alter_column.go +++ b/pkg/migrations/op_alter_column.go @@ -15,7 +15,7 @@ import ( var _ Operation = (*OpAlterColumn)(nil) -func (o *OpAlterColumn) Start(ctx context.Context, conn db.DB, latestSchema string, tr SQLTransformer, s *schema.Schema) (*schema.Table, error) { +func (o *OpAlterColumn) Start(ctx context.Context, conn db.DB, latestSchema string, s *schema.Schema) (*schema.Table, error) { table := s.GetTable(o.Table) if table == nil { return nil, TableDoesNotExistError{Name: o.Table} @@ -34,7 +34,7 @@ func (o *OpAlterColumn) Start(ctx context.Context, conn db.DB, latestSchema stri } // Add a trigger to copy values from the old column to the new, rewriting values using the `up` SQL. - err := createTrigger(ctx, conn, tr, triggerConfig{ + err := createTrigger(ctx, conn, triggerConfig{ Name: TriggerName(o.Table, o.Column), Direction: TriggerDirectionUp, Columns: table.Columns, @@ -58,7 +58,7 @@ func (o *OpAlterColumn) Start(ctx context.Context, conn db.DB, latestSchema stri }) // Add a trigger to copy values from the new column to the old. - err = createTrigger(ctx, conn, tr, triggerConfig{ + err = createTrigger(ctx, conn, triggerConfig{ Name: TriggerName(o.Table, TemporaryName(o.Column)), Direction: TriggerDirectionDown, Columns: table.Columns, @@ -74,7 +74,7 @@ func (o *OpAlterColumn) Start(ctx context.Context, conn db.DB, latestSchema stri // perform any operation specific start steps for _, op := range ops { - if _, err := op.Start(ctx, conn, latestSchema, tr, s); err != nil { + if _, err := op.Start(ctx, conn, latestSchema, s); err != nil { return nil, err } } @@ -82,12 +82,12 @@ func (o *OpAlterColumn) Start(ctx context.Context, conn db.DB, latestSchema stri return table, nil } -func (o *OpAlterColumn) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpAlterColumn) Complete(ctx context.Context, conn db.DB, s *schema.Schema) error { ops := o.subOperations() // Perform any operation specific completion steps for _, op := range ops { - if err := op.Complete(ctx, conn, tr, s); err != nil { + if err := op.Complete(ctx, conn, s); err != nil { return err } } @@ -142,7 +142,7 @@ func (o *OpAlterColumn) Complete(ctx context.Context, conn db.DB, tr SQLTransfor return nil } -func (o *OpAlterColumn) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpAlterColumn) Rollback(ctx context.Context, conn db.DB, s *schema.Schema) error { table := s.GetTable(o.Table) if table == nil { return TableDoesNotExistError{Name: o.Table} @@ -155,7 +155,7 @@ func (o *OpAlterColumn) Rollback(ctx context.Context, conn db.DB, tr SQLTransfor // Perform any operation specific rollback steps ops := o.subOperations() for _, ops := range ops { - if err := ops.Rollback(ctx, conn, tr, nil); err != nil { + if err := ops.Rollback(ctx, conn, nil); err != nil { return err } } diff --git a/pkg/migrations/op_change_type.go b/pkg/migrations/op_change_type.go index 741ecbba3..423533c59 100644 --- a/pkg/migrations/op_change_type.go +++ b/pkg/migrations/op_change_type.go @@ -19,7 +19,7 @@ type OpChangeType struct { var _ Operation = (*OpChangeType)(nil) -func (o *OpChangeType) Start(ctx context.Context, conn db.DB, latestSchema string, tr SQLTransformer, s *schema.Schema) (*schema.Table, error) { +func (o *OpChangeType) Start(ctx context.Context, conn db.DB, latestSchema string, s *schema.Schema) (*schema.Table, error) { table := s.GetTable(o.Table) if table == nil { return nil, TableDoesNotExistError{Name: o.Table} @@ -28,11 +28,11 @@ func (o *OpChangeType) Start(ctx context.Context, conn db.DB, latestSchema strin return table, nil } -func (o *OpChangeType) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpChangeType) Complete(ctx context.Context, conn db.DB, s *schema.Schema) error { return nil } -func (o *OpChangeType) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpChangeType) Rollback(ctx context.Context, conn db.DB, s *schema.Schema) error { return nil } diff --git a/pkg/migrations/op_create_constraint.go b/pkg/migrations/op_create_constraint.go index d7095a772..77d0b41db 100644 --- a/pkg/migrations/op_create_constraint.go +++ b/pkg/migrations/op_create_constraint.go @@ -15,7 +15,7 @@ import ( var _ Operation = (*OpCreateConstraint)(nil) -func (o *OpCreateConstraint) Start(ctx context.Context, conn db.DB, latestSchema string, tr SQLTransformer, s *schema.Schema) (*schema.Table, error) { +func (o *OpCreateConstraint) Start(ctx context.Context, conn db.DB, latestSchema string, s *schema.Schema) (*schema.Table, error) { table := s.GetTable(o.Table) if table == nil { return nil, TableDoesNotExistError{Name: o.Table} @@ -41,7 +41,7 @@ func (o *OpCreateConstraint) Start(ctx context.Context, conn db.DB, latestSchema // Setup triggers for _, colName := range o.Columns { upSQL := o.Up[colName] - err := createTrigger(ctx, conn, tr, triggerConfig{ + err := createTrigger(ctx, conn, triggerConfig{ Name: TriggerName(o.Table, colName), Direction: TriggerDirectionUp, Columns: table.Columns, @@ -65,7 +65,7 @@ func (o *OpCreateConstraint) Start(ctx context.Context, conn db.DB, latestSchema }) downSQL := o.Down[colName] - err = createTrigger(ctx, conn, tr, triggerConfig{ + err = createTrigger(ctx, conn, triggerConfig{ Name: TriggerName(o.Table, TemporaryName(colName)), Direction: TriggerDirectionDown, Columns: table.Columns, @@ -92,14 +92,14 @@ func (o *OpCreateConstraint) Start(ctx context.Context, conn db.DB, latestSchema return table, nil } -func (o *OpCreateConstraint) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpCreateConstraint) Complete(ctx context.Context, conn db.DB, s *schema.Schema) error { switch o.Type { case OpCreateConstraintTypeUnique: uniqueOp := &OpSetUnique{ Table: o.Table, Name: o.Name, } - err := uniqueOp.Complete(ctx, conn, tr, s) + err := uniqueOp.Complete(ctx, conn, s) if err != nil { return err } @@ -110,7 +110,7 @@ func (o *OpCreateConstraint) Complete(ctx context.Context, conn db.DB, tr SQLTra Name: o.Name, }, } - err := checkOp.Complete(ctx, conn, tr, s) + err := checkOp.Complete(ctx, conn, s) if err != nil { return err } @@ -121,7 +121,7 @@ func (o *OpCreateConstraint) Complete(ctx context.Context, conn db.DB, tr SQLTra Name: o.Name, }, } - err := fkOp.Complete(ctx, conn, tr, s) + err := fkOp.Complete(ctx, conn, s) if err != nil { return err } @@ -169,7 +169,7 @@ func (o *OpCreateConstraint) Complete(ctx context.Context, conn db.DB, tr SQLTra return err } -func (o *OpCreateConstraint) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpCreateConstraint) Rollback(ctx context.Context, conn db.DB, s *schema.Schema) error { table := s.GetTable(o.Table) if table == nil { return TableDoesNotExistError{Name: o.Table} diff --git a/pkg/migrations/op_create_index.go b/pkg/migrations/op_create_index.go index 91412b4cd..56f6fe28a 100644 --- a/pkg/migrations/op_create_index.go +++ b/pkg/migrations/op_create_index.go @@ -15,7 +15,7 @@ import ( var _ Operation = (*OpCreateIndex)(nil) -func (o *OpCreateIndex) Start(ctx context.Context, conn db.DB, latestSchema string, tr SQLTransformer, s *schema.Schema) (*schema.Table, error) { +func (o *OpCreateIndex) Start(ctx context.Context, conn db.DB, latestSchema string, s *schema.Schema) (*schema.Table, error) { table := s.GetTable(o.Table) if table == nil { return nil, TableDoesNotExistError{Name: o.Table} @@ -50,12 +50,12 @@ func (o *OpCreateIndex) Start(ctx context.Context, conn db.DB, latestSchema stri return nil, err } -func (o *OpCreateIndex) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpCreateIndex) Complete(ctx context.Context, conn db.DB, s *schema.Schema) error { // No-op return nil } -func (o *OpCreateIndex) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpCreateIndex) Rollback(ctx context.Context, conn db.DB, s *schema.Schema) error { // drop the index concurrently _, err := conn.ExecContext(ctx, fmt.Sprintf("DROP INDEX CONCURRENTLY IF EXISTS %s", pq.QuoteIdentifier(o.Name))) diff --git a/pkg/migrations/op_create_table.go b/pkg/migrations/op_create_table.go index 4669b8faa..124b5e550 100644 --- a/pkg/migrations/op_create_table.go +++ b/pkg/migrations/op_create_table.go @@ -16,9 +16,9 @@ import ( var _ Operation = (*OpCreateTable)(nil) -func (o *OpCreateTable) Start(ctx context.Context, conn db.DB, latestSchema string, tr SQLTransformer, s *schema.Schema) (*schema.Table, error) { +func (o *OpCreateTable) Start(ctx context.Context, conn db.DB, latestSchema string, s *schema.Schema) (*schema.Table, error) { // Generate SQL for the columns in the table - columnsSQL, err := columnsToSQL(o.Columns, tr) + columnsSQL, err := columnsToSQL(o.Columns) if err != nil { return nil, fmt.Errorf("failed to create columns SQL: %w", err) } @@ -59,12 +59,12 @@ func (o *OpCreateTable) Start(ctx context.Context, conn db.DB, latestSchema stri return nil, nil } -func (o *OpCreateTable) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpCreateTable) Complete(ctx context.Context, conn db.DB, s *schema.Schema) error { // No-op return nil } -func (o *OpCreateTable) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpCreateTable) Rollback(ctx context.Context, conn db.DB, s *schema.Schema) error { _, err := conn.ExecContext(ctx, fmt.Sprintf("DROP TABLE IF EXISTS %s", pq.QuoteIdentifier(o.Name))) return err @@ -267,10 +267,10 @@ func (o *OpCreateTable) updateSchema(s *schema.Schema) *schema.Schema { return s } -func columnsToSQL(cols []Column, tr SQLTransformer) (string, error) { +func columnsToSQL(cols []Column) (string, error) { var sql string var primaryKeys []string - columnWriter := ColumnSQLWriter{WithPK: false, Transformer: tr} + columnWriter := ColumnSQLWriter{WithPK: false} for i, col := range cols { if i > 0 { sql += ", " diff --git a/pkg/migrations/op_create_table_test.go b/pkg/migrations/op_create_table_test.go index acbb91722..5119786c4 100644 --- a/pkg/migrations/op_create_table_test.go +++ b/pkg/migrations/op_create_table_test.go @@ -12,7 +12,6 @@ import ( "github.com/xataio/pgroll/internal/testutils" "github.com/xataio/pgroll/pkg/migrations" - "github.com/xataio/pgroll/pkg/roll" ) func TestCreateTable(t *testing.T) { @@ -1500,95 +1499,6 @@ func TestCreateTableValidation(t *testing.T) { }) } -func TestCreateTableColumnDefaultTransformation(t *testing.T) { - t.Parallel() - - sqlTransformer := testutils.NewMockSQLTransformer(map[string]string{ - "'default value 1'": "'rewritten'", - "'default value 2'": testutils.MockSQLTransformerError, - }) - - ExecuteTests(t, TestCases{ - { - name: "column default is rewritten by the SQL transformer", - migrations: []migrations.Migration{ - { - Name: "01_create_table", - Operations: migrations.Operations{ - &migrations.OpCreateTable{ - Name: "users", - Columns: []migrations.Column{ - { - Name: "id", - Type: "serial", - Pk: true, - }, - { - Name: "name", - Type: "text", - Default: ptr("'default value 1'"), - }, - }, - }, - }, - }, - }, - afterStart: func(t *testing.T, db *sql.DB, schema string) { - // Insert some data into the table - MustInsert(t, db, schema, "01_create_table", "users", map[string]string{ - "id": "1", - }) - - // Ensure the row has the rewritten default value. - rows := MustSelect(t, db, schema, "01_create_table", "users") - assert.Equal(t, []map[string]any{ - {"id": 1, "name": "rewritten"}, - }, rows) - }, - afterRollback: func(t *testing.T, db *sql.DB, schema string) { - }, - afterComplete: func(t *testing.T, db *sql.DB, schema string) { - // Insert some data into the table - MustInsert(t, db, schema, "01_create_table", "users", map[string]string{ - "id": "1", - }) - - // Ensure the row has the rewritten default value. - rows := MustSelect(t, db, schema, "01_create_table", "users") - assert.Equal(t, []map[string]any{ - {"id": 1, "name": "rewritten"}, - }, rows) - }, - }, - { - name: "create table fails when the SQL transformer returns an error", - migrations: []migrations.Migration{ - { - Name: "01_create_table", - Operations: migrations.Operations{ - &migrations.OpCreateTable{ - Name: "users", - Columns: []migrations.Column{ - { - Name: "id", - Type: "serial", - Pk: true, - }, - { - Name: "name", - Type: "text", - Default: ptr("'default value 2'"), - }, - }, - }, - }, - }, - }, - wantStartErr: testutils.ErrMockSQLTransformer, - }, - }, roll.WithSQLTransformer(sqlTransformer)) -} - func TestCreateTableValidationInMultiOperationMigrations(t *testing.T) { t.Parallel() diff --git a/pkg/migrations/op_drop_column.go b/pkg/migrations/op_drop_column.go index 0d56a04ed..d5a2c8695 100644 --- a/pkg/migrations/op_drop_column.go +++ b/pkg/migrations/op_drop_column.go @@ -13,9 +13,9 @@ import ( var _ Operation = (*OpDropColumn)(nil) -func (o *OpDropColumn) Start(ctx context.Context, conn db.DB, latestSchema string, tr SQLTransformer, s *schema.Schema) (*schema.Table, error) { +func (o *OpDropColumn) Start(ctx context.Context, conn db.DB, latestSchema string, s *schema.Schema) (*schema.Table, error) { if o.Down != "" { - err := createTrigger(ctx, conn, tr, triggerConfig{ + err := createTrigger(ctx, conn, triggerConfig{ Name: TriggerName(o.Table, o.Column), Direction: TriggerDirectionDown, Columns: s.GetTable(o.Table).Columns, @@ -43,7 +43,7 @@ func (o *OpDropColumn) Start(ctx context.Context, conn db.DB, latestSchema strin return nil, nil } -func (o *OpDropColumn) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpDropColumn) Complete(ctx context.Context, conn db.DB, s *schema.Schema) error { _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s", pq.QuoteIdentifier(o.Table), pq.QuoteIdentifier(o.Column))) @@ -68,7 +68,7 @@ func (o *OpDropColumn) Complete(ctx context.Context, conn db.DB, tr SQLTransform return nil } -func (o *OpDropColumn) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpDropColumn) Rollback(ctx context.Context, conn db.DB, s *schema.Schema) error { table := s.GetTable(o.Table) _, err := conn.ExecContext(ctx, fmt.Sprintf("DROP FUNCTION IF EXISTS %s CASCADE", diff --git a/pkg/migrations/op_drop_constraint.go b/pkg/migrations/op_drop_constraint.go index a99bbe77c..f79325265 100644 --- a/pkg/migrations/op_drop_constraint.go +++ b/pkg/migrations/op_drop_constraint.go @@ -14,7 +14,7 @@ import ( var _ Operation = (*OpDropConstraint)(nil) -func (o *OpDropConstraint) Start(ctx context.Context, conn db.DB, latestSchema string, tr SQLTransformer, s *schema.Schema) (*schema.Table, error) { +func (o *OpDropConstraint) Start(ctx context.Context, conn db.DB, latestSchema string, s *schema.Schema) (*schema.Table, error) { table := s.GetTable(o.Table) if table == nil { return nil, TableDoesNotExistError{Name: o.Table} @@ -34,7 +34,7 @@ func (o *OpDropConstraint) Start(ctx context.Context, conn db.DB, latestSchema s } // Add a trigger to copy values from the old column to the new, rewriting values using the `up` SQL. - err := createTrigger(ctx, conn, tr, triggerConfig{ + err := createTrigger(ctx, conn, triggerConfig{ Name: TriggerName(o.Table, column.Name), Direction: TriggerDirectionUp, Columns: table.Columns, @@ -56,7 +56,7 @@ func (o *OpDropConstraint) Start(ctx context.Context, conn db.DB, latestSchema s }) // Add a trigger to copy values from the new column to the old, rewriting values using the `down` SQL. - err = createTrigger(ctx, conn, tr, triggerConfig{ + err = createTrigger(ctx, conn, triggerConfig{ Name: TriggerName(o.Table, TemporaryName(column.Name)), Direction: TriggerDirectionDown, Columns: table.Columns, @@ -72,7 +72,7 @@ func (o *OpDropConstraint) Start(ctx context.Context, conn db.DB, latestSchema s return table, nil } -func (o *OpDropConstraint) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpDropConstraint) Complete(ctx context.Context, conn db.DB, s *schema.Schema) error { // We have already validated that there is single column related to this constraint. table := s.GetTable(o.Table) column := table.GetColumn(table.GetConstraintColumns(o.Name)[0]) @@ -119,7 +119,7 @@ func (o *OpDropConstraint) Complete(ctx context.Context, conn db.DB, tr SQLTrans return err } -func (o *OpDropConstraint) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpDropConstraint) Rollback(ctx context.Context, conn db.DB, s *schema.Schema) error { // We have already validated that there is single column related to this constraint. table := s.GetTable(o.Table) columnName := table.GetConstraintColumns(o.Name)[0] diff --git a/pkg/migrations/op_drop_index.go b/pkg/migrations/op_drop_index.go index ea27a9413..1456e4ac4 100644 --- a/pkg/migrations/op_drop_index.go +++ b/pkg/migrations/op_drop_index.go @@ -13,12 +13,12 @@ import ( var _ Operation = (*OpDropIndex)(nil) -func (o *OpDropIndex) Start(ctx context.Context, conn db.DB, latestSchema string, tr SQLTransformer, s *schema.Schema) (*schema.Table, error) { +func (o *OpDropIndex) Start(ctx context.Context, conn db.DB, latestSchema string, s *schema.Schema) (*schema.Table, error) { // no-op return nil, nil } -func (o *OpDropIndex) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpDropIndex) Complete(ctx context.Context, conn db.DB, s *schema.Schema) error { // drop the index concurrently _, err := conn.ExecContext(ctx, fmt.Sprintf("DROP INDEX CONCURRENTLY IF EXISTS %s", pq.QuoteIdentifier(o.Name))) @@ -26,7 +26,7 @@ func (o *OpDropIndex) Complete(ctx context.Context, conn db.DB, tr SQLTransforme return err } -func (o *OpDropIndex) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpDropIndex) Rollback(ctx context.Context, conn db.DB, s *schema.Schema) error { // no-op return nil } diff --git a/pkg/migrations/op_drop_multicolumn_constraint.go b/pkg/migrations/op_drop_multicolumn_constraint.go index dc263fa60..bc1328be6 100644 --- a/pkg/migrations/op_drop_multicolumn_constraint.go +++ b/pkg/migrations/op_drop_multicolumn_constraint.go @@ -14,7 +14,7 @@ import ( var _ Operation = (*OpDropMultiColumnConstraint)(nil) -func (o *OpDropMultiColumnConstraint) Start(ctx context.Context, conn db.DB, latestSchema string, tr SQLTransformer, s *schema.Schema) (*schema.Table, error) { +func (o *OpDropMultiColumnConstraint) Start(ctx context.Context, conn db.DB, latestSchema string, s *schema.Schema) (*schema.Table, error) { table := s.GetTable(o.Table) if table == nil { return nil, TableDoesNotExistError{Name: o.Table} @@ -44,7 +44,7 @@ func (o *OpDropMultiColumnConstraint) Start(ctx context.Context, conn db.DB, lat // Create triggers for each column covered by the constraint to be dropped for _, columnName := range table.GetConstraintColumns(o.Name) { // Add a trigger to copy values from the old column to the new, rewriting values using the `up` SQL. - err := createTrigger(ctx, conn, tr, triggerConfig{ + err := createTrigger(ctx, conn, triggerConfig{ Name: TriggerName(o.Table, columnName), Direction: TriggerDirectionUp, Columns: table.Columns, @@ -68,7 +68,7 @@ func (o *OpDropMultiColumnConstraint) Start(ctx context.Context, conn db.DB, lat }) // Add a trigger to copy values from the new column to the old, rewriting values using the `down` SQL. - err = createTrigger(ctx, conn, tr, triggerConfig{ + err = createTrigger(ctx, conn, triggerConfig{ Name: TriggerName(o.Table, TemporaryName(columnName)), Direction: TriggerDirectionDown, Columns: table.Columns, @@ -86,7 +86,7 @@ func (o *OpDropMultiColumnConstraint) Start(ctx context.Context, conn db.DB, lat return table, nil } -func (o *OpDropMultiColumnConstraint) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpDropMultiColumnConstraint) Complete(ctx context.Context, conn db.DB, s *schema.Schema) error { table := s.GetTable(o.Table) for _, columnName := range table.GetConstraintColumns(o.Name) { @@ -134,7 +134,7 @@ func (o *OpDropMultiColumnConstraint) Complete(ctx context.Context, conn db.DB, return nil } -func (o *OpDropMultiColumnConstraint) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpDropMultiColumnConstraint) Rollback(ctx context.Context, conn db.DB, s *schema.Schema) error { table := s.GetTable(o.Table) for _, columnName := range table.GetConstraintColumns(o.Name) { diff --git a/pkg/migrations/op_drop_not_null.go b/pkg/migrations/op_drop_not_null.go index e5487570d..0ff0749d5 100644 --- a/pkg/migrations/op_drop_not_null.go +++ b/pkg/migrations/op_drop_not_null.go @@ -19,7 +19,7 @@ type OpDropNotNull struct { var _ Operation = (*OpDropNotNull)(nil) -func (o *OpDropNotNull) Start(ctx context.Context, conn db.DB, latestSchema string, tr SQLTransformer, s *schema.Schema) (*schema.Table, error) { +func (o *OpDropNotNull) Start(ctx context.Context, conn db.DB, latestSchema string, s *schema.Schema) (*schema.Table, error) { table := s.GetTable(o.Table) if table == nil { return nil, TableDoesNotExistError{Name: o.Table} @@ -28,11 +28,11 @@ func (o *OpDropNotNull) Start(ctx context.Context, conn db.DB, latestSchema stri return table, nil } -func (o *OpDropNotNull) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpDropNotNull) Complete(ctx context.Context, conn db.DB, s *schema.Schema) error { return nil } -func (o *OpDropNotNull) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpDropNotNull) Rollback(ctx context.Context, conn db.DB, s *schema.Schema) error { return nil } diff --git a/pkg/migrations/op_drop_table.go b/pkg/migrations/op_drop_table.go index be7ed969a..d41b06371 100644 --- a/pkg/migrations/op_drop_table.go +++ b/pkg/migrations/op_drop_table.go @@ -13,7 +13,7 @@ import ( var _ Operation = (*OpDropTable)(nil) -func (o *OpDropTable) Start(ctx context.Context, conn db.DB, latestSchema string, tr SQLTransformer, s *schema.Schema) (*schema.Table, error) { +func (o *OpDropTable) Start(ctx context.Context, conn db.DB, latestSchema string, s *schema.Schema) (*schema.Table, error) { table := s.GetTable(o.Name) if table == nil { return nil, TableDoesNotExistError{Name: o.Name} @@ -32,7 +32,7 @@ func (o *OpDropTable) Start(ctx context.Context, conn db.DB, latestSchema string return nil, nil } -func (o *OpDropTable) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpDropTable) Complete(ctx context.Context, conn db.DB, s *schema.Schema) error { deletionName := DeletionName(o.Name) // Perform the actual deletion of the soft-deleted table @@ -41,7 +41,7 @@ func (o *OpDropTable) Complete(ctx context.Context, conn db.DB, tr SQLTransforme return err } -func (o *OpDropTable) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpDropTable) Rollback(ctx context.Context, conn db.DB, s *schema.Schema) error { // Mark the table as no longer deleted so that it is visible to preceding // Rollbacks in the same migration s.UnRemoveTable(o.Name) diff --git a/pkg/migrations/op_raw_sql.go b/pkg/migrations/op_raw_sql.go index 52a372743..f2437c5ef 100644 --- a/pkg/migrations/op_raw_sql.go +++ b/pkg/migrations/op_raw_sql.go @@ -11,45 +11,30 @@ import ( var _ Operation = (*OpRawSQL)(nil) -func (o *OpRawSQL) Start(ctx context.Context, conn db.DB, latestSchema string, tr SQLTransformer, s *schema.Schema) (*schema.Table, error) { +func (o *OpRawSQL) Start(ctx context.Context, conn db.DB, latestSchema string, s *schema.Schema) (*schema.Table, error) { if o.OnComplete { return nil, nil } - up, err := tr.TransformSQL(o.Up) - if err != nil { - return nil, err - } - - _, err = conn.ExecContext(ctx, up) + _, err := conn.ExecContext(ctx, o.Up) return nil, err } -func (o *OpRawSQL) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpRawSQL) Complete(ctx context.Context, conn db.DB, s *schema.Schema) error { if !o.OnComplete { return nil } - up, err := tr.TransformSQL(o.Up) - if err != nil { - return err - } - - _, err = conn.ExecContext(ctx, up) + _, err := conn.ExecContext(ctx, o.Up) return err } -func (o *OpRawSQL) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpRawSQL) Rollback(ctx context.Context, conn db.DB, s *schema.Schema) error { if o.Down == "" { return nil } - down, err := tr.TransformSQL(o.Down) - if err != nil { - return err - } - - _, err = conn.ExecContext(ctx, down) + _, err := conn.ExecContext(ctx, o.Down) return err } diff --git a/pkg/migrations/op_raw_sql_test.go b/pkg/migrations/op_raw_sql_test.go index b134dc8d8..53fc8951a 100644 --- a/pkg/migrations/op_raw_sql_test.go +++ b/pkg/migrations/op_raw_sql_test.go @@ -6,10 +6,7 @@ import ( "database/sql" "testing" - "github.com/xataio/pgroll/internal/testutils" - "github.com/xataio/pgroll/pkg/migrations" - "github.com/xataio/pgroll/pkg/roll" ) func TestRawSQL(t *testing.T) { @@ -187,106 +184,3 @@ func TestRawSQL(t *testing.T) { }, }) } - -func TestRawSQLTransformation(t *testing.T) { - t.Parallel() - - sqlTransformer := testutils.NewMockSQLTransformer(map[string]string{ - "CREATE TABLE people(id int)": "CREATE TABLE users(id int)", - "DROP TABLE people": "DROP TABLE users", - "CREATE TABLE restricted(id int)": testutils.MockSQLTransformerError, - }) - - ExecuteTests(t, TestCases{ - { - name: "SQL transformer rewrites up and down SQL", - migrations: []migrations.Migration{ - { - Name: "01_create_table", - Operations: migrations.Operations{ - &migrations.OpRawSQL{ - Up: "CREATE TABLE people(id int)", - Down: "DROP TABLE people", - }, - }, - }, - }, - afterStart: func(t *testing.T, db *sql.DB, schema string) { - // The transformed `up` SQL was used in place of the original SQL - TableMustExist(t, db, schema, "users") - }, - afterRollback: func(t *testing.T, db *sql.DB, schema string) { - // The transformed `down` SQL was used in place of the original SQL - TableMustNotExist(t, db, schema, "users") - }, - afterComplete: func(t *testing.T, db *sql.DB, schema string) { - }, - }, - { - name: "SQL transformer rewrites up SQL when up is run on completion", - migrations: []migrations.Migration{ - { - Name: "01_create_table", - Operations: migrations.Operations{ - &migrations.OpRawSQL{ - Up: "CREATE TABLE people(id int)", - OnComplete: true, - }, - }, - }, - }, - afterStart: func(t *testing.T, db *sql.DB, schema string) { - }, - afterRollback: func(t *testing.T, db *sql.DB, schema string) { - }, - afterComplete: func(t *testing.T, db *sql.DB, schema string) { - // The transformed `up` SQL was used in place of the original SQL - TableMustExist(t, db, schema, "users") - }, - }, - { - name: "raw SQL operation fails when SQL transformer returns an error on up SQL", - migrations: []migrations.Migration{ - { - Name: "01_create_table", - Operations: migrations.Operations{ - &migrations.OpRawSQL{ - Up: "CREATE TABLE restricted(id int)", - }, - }, - }, - }, - wantStartErr: testutils.ErrMockSQLTransformer, - }, - { - name: "raw SQL operation fails when SQL transformer returns an error on down SQL", - migrations: []migrations.Migration{ - { - Name: "01_create_table", - Operations: migrations.Operations{ - &migrations.OpRawSQL{ - Up: "CREATE TABLE products(id int)", - Down: "CREATE TABLE restricted(id int)", - }, - }, - }, - }, - wantRollbackErr: testutils.ErrMockSQLTransformer, - }, - { - name: "raw SQL onComplete operation fails when SQL transformer returns an error on up SQL", - migrations: []migrations.Migration{ - { - Name: "01_create_table", - Operations: migrations.Operations{ - &migrations.OpRawSQL{ - Up: "CREATE TABLE restricted(id int)", - OnComplete: true, - }, - }, - }, - }, - wantCompleteErr: testutils.ErrMockSQLTransformer, - }, - }, roll.WithSQLTransformer(sqlTransformer)) -} diff --git a/pkg/migrations/op_rename_column.go b/pkg/migrations/op_rename_column.go index fc2f42987..e21f11f11 100644 --- a/pkg/migrations/op_rename_column.go +++ b/pkg/migrations/op_rename_column.go @@ -12,7 +12,7 @@ import ( var _ Operation = (*OpRenameColumn)(nil) -func (o *OpRenameColumn) Start(ctx context.Context, conn db.DB, latestSchema string, tr SQLTransformer, s *schema.Schema) (*schema.Table, error) { +func (o *OpRenameColumn) Start(ctx context.Context, conn db.DB, latestSchema string, s *schema.Schema) (*schema.Table, error) { // Rename the table in the in-memory schema. table := s.GetTable(o.Table) if table == nil { @@ -31,7 +31,7 @@ func (o *OpRenameColumn) Start(ctx context.Context, conn db.DB, latestSchema str return nil, nil } -func (o *OpRenameColumn) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpRenameColumn) Complete(ctx context.Context, conn db.DB, s *schema.Schema) error { // Rename the column _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s RENAME COLUMN %s TO %s", pq.QuoteIdentifier(o.Table), @@ -41,7 +41,7 @@ func (o *OpRenameColumn) Complete(ctx context.Context, conn db.DB, tr SQLTransfo return err } -func (o *OpRenameColumn) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpRenameColumn) Rollback(ctx context.Context, conn db.DB, s *schema.Schema) error { // Rename the column back to the original name in the in-memory schema. table := s.GetTable(o.Table) table.RenameColumn(o.To, o.From) diff --git a/pkg/migrations/op_rename_constraint.go b/pkg/migrations/op_rename_constraint.go index 2e87b755c..3bdf2cd7f 100644 --- a/pkg/migrations/op_rename_constraint.go +++ b/pkg/migrations/op_rename_constraint.go @@ -14,12 +14,12 @@ import ( var _ Operation = (*OpRenameConstraint)(nil) -func (o *OpRenameConstraint) Start(ctx context.Context, conn db.DB, latestSchema string, tr SQLTransformer, s *schema.Schema) (*schema.Table, error) { +func (o *OpRenameConstraint) Start(ctx context.Context, conn db.DB, latestSchema string, s *schema.Schema) (*schema.Table, error) { // no-op return nil, nil } -func (o *OpRenameConstraint) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpRenameConstraint) Complete(ctx context.Context, conn db.DB, s *schema.Schema) error { // rename the constraint in the underlying table _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE %s RENAME CONSTRAINT %s TO %s", pq.QuoteIdentifier(o.Table), @@ -28,7 +28,7 @@ func (o *OpRenameConstraint) Complete(ctx context.Context, conn db.DB, tr SQLTra return err } -func (o *OpRenameConstraint) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpRenameConstraint) Rollback(ctx context.Context, conn db.DB, s *schema.Schema) error { // no-op return nil } diff --git a/pkg/migrations/op_rename_table.go b/pkg/migrations/op_rename_table.go index bfb565f4d..36956998b 100644 --- a/pkg/migrations/op_rename_table.go +++ b/pkg/migrations/op_rename_table.go @@ -14,11 +14,11 @@ import ( var _ Operation = (*OpRenameTable)(nil) -func (o *OpRenameTable) Start(ctx context.Context, conn db.DB, latestSchema string, tr SQLTransformer, s *schema.Schema) (*schema.Table, error) { +func (o *OpRenameTable) Start(ctx context.Context, conn db.DB, latestSchema string, s *schema.Schema) (*schema.Table, error) { return nil, s.RenameTable(o.From, o.To) } -func (o *OpRenameTable) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpRenameTable) Complete(ctx context.Context, conn db.DB, s *schema.Schema) error { _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s RENAME TO %s", pq.QuoteIdentifier(o.From), pq.QuoteIdentifier(o.To))) @@ -26,7 +26,7 @@ func (o *OpRenameTable) Complete(ctx context.Context, conn db.DB, tr SQLTransfor return err } -func (o *OpRenameTable) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpRenameTable) Rollback(ctx context.Context, conn db.DB, s *schema.Schema) error { s.RenameTable(o.To, o.From) return nil } diff --git a/pkg/migrations/op_set_check.go b/pkg/migrations/op_set_check.go index 56f6ea423..48d2274a5 100644 --- a/pkg/migrations/op_set_check.go +++ b/pkg/migrations/op_set_check.go @@ -23,7 +23,7 @@ type OpSetCheckConstraint struct { var _ Operation = (*OpSetCheckConstraint)(nil) -func (o *OpSetCheckConstraint) Start(ctx context.Context, conn db.DB, latestSchema string, tr SQLTransformer, s *schema.Schema) (*schema.Table, error) { +func (o *OpSetCheckConstraint) Start(ctx context.Context, conn db.DB, latestSchema string, s *schema.Schema) (*schema.Table, error) { table := s.GetTable(o.Table) if table == nil { return nil, TableDoesNotExistError{Name: o.Table} @@ -37,7 +37,7 @@ func (o *OpSetCheckConstraint) Start(ctx context.Context, conn db.DB, latestSche return table, nil } -func (o *OpSetCheckConstraint) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpSetCheckConstraint) Complete(ctx context.Context, conn db.DB, s *schema.Schema) error { // Validate the check constraint _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s VALIDATE CONSTRAINT %s", pq.QuoteIdentifier(o.Table), @@ -49,7 +49,7 @@ func (o *OpSetCheckConstraint) Complete(ctx context.Context, conn db.DB, tr SQLT return nil } -func (o *OpSetCheckConstraint) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpSetCheckConstraint) Rollback(ctx context.Context, conn db.DB, s *schema.Schema) error { return nil } diff --git a/pkg/migrations/op_set_comment.go b/pkg/migrations/op_set_comment.go index 84b00b792..07e9e52f7 100644 --- a/pkg/migrations/op_set_comment.go +++ b/pkg/migrations/op_set_comment.go @@ -20,7 +20,7 @@ type OpSetComment struct { var _ Operation = (*OpSetComment)(nil) -func (o *OpSetComment) Start(ctx context.Context, conn db.DB, latestSchema string, tr SQLTransformer, s *schema.Schema) (*schema.Table, error) { +func (o *OpSetComment) Start(ctx context.Context, conn db.DB, latestSchema string, s *schema.Schema) (*schema.Table, error) { tbl := s.GetTable(o.Table) if tbl == nil { return nil, TableDoesNotExistError{Name: o.Table} @@ -29,11 +29,11 @@ func (o *OpSetComment) Start(ctx context.Context, conn db.DB, latestSchema strin return tbl, addCommentToColumn(ctx, conn, o.Table, TemporaryName(o.Column), o.Comment) } -func (o *OpSetComment) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpSetComment) Complete(ctx context.Context, conn db.DB, s *schema.Schema) error { return nil } -func (o *OpSetComment) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpSetComment) Rollback(ctx context.Context, conn db.DB, s *schema.Schema) error { return nil } diff --git a/pkg/migrations/op_set_default.go b/pkg/migrations/op_set_default.go index 4adb7fa06..ce7918c75 100644 --- a/pkg/migrations/op_set_default.go +++ b/pkg/migrations/op_set_default.go @@ -22,7 +22,7 @@ type OpSetDefault struct { var _ Operation = (*OpSetDefault)(nil) -func (o *OpSetDefault) Start(ctx context.Context, conn db.DB, latestSchema string, tr SQLTransformer, s *schema.Schema) (*schema.Table, error) { +func (o *OpSetDefault) Start(ctx context.Context, conn db.DB, latestSchema string, s *schema.Schema) (*schema.Table, error) { table := s.GetTable(o.Table) if table == nil { return nil, TableDoesNotExistError{Name: o.Table} @@ -50,11 +50,11 @@ func (o *OpSetDefault) Start(ctx context.Context, conn db.DB, latestSchema strin return table, nil } -func (o *OpSetDefault) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpSetDefault) Complete(ctx context.Context, conn db.DB, s *schema.Schema) error { return nil } -func (o *OpSetDefault) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpSetDefault) Rollback(ctx context.Context, conn db.DB, s *schema.Schema) error { return nil } diff --git a/pkg/migrations/op_set_fk.go b/pkg/migrations/op_set_fk.go index 5d450bf34..3df9fa551 100644 --- a/pkg/migrations/op_set_fk.go +++ b/pkg/migrations/op_set_fk.go @@ -22,7 +22,7 @@ type OpSetForeignKey struct { var _ Operation = (*OpSetForeignKey)(nil) -func (o *OpSetForeignKey) Start(ctx context.Context, conn db.DB, latestSchema string, tr SQLTransformer, s *schema.Schema) (*schema.Table, error) { +func (o *OpSetForeignKey) Start(ctx context.Context, conn db.DB, latestSchema string, s *schema.Schema) (*schema.Table, error) { table := s.GetTable(o.Table) // Create a NOT VALID foreign key constraint on the new column. @@ -33,7 +33,7 @@ func (o *OpSetForeignKey) Start(ctx context.Context, conn db.DB, latestSchema st return table, nil } -func (o *OpSetForeignKey) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpSetForeignKey) Complete(ctx context.Context, conn db.DB, s *schema.Schema) error { // Validate the foreign key constraint _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s VALIDATE CONSTRAINT %s", pq.QuoteIdentifier(o.Table), @@ -45,7 +45,7 @@ func (o *OpSetForeignKey) Complete(ctx context.Context, conn db.DB, tr SQLTransf return nil } -func (o *OpSetForeignKey) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpSetForeignKey) Rollback(ctx context.Context, conn db.DB, s *schema.Schema) error { return nil } diff --git a/pkg/migrations/op_set_notnull.go b/pkg/migrations/op_set_notnull.go index 6f87ea537..e54c85c33 100644 --- a/pkg/migrations/op_set_notnull.go +++ b/pkg/migrations/op_set_notnull.go @@ -20,7 +20,7 @@ type OpSetNotNull struct { var _ Operation = (*OpSetNotNull)(nil) -func (o *OpSetNotNull) Start(ctx context.Context, conn db.DB, latestSchema string, tr SQLTransformer, s *schema.Schema) (*schema.Table, error) { +func (o *OpSetNotNull) Start(ctx context.Context, conn db.DB, latestSchema string, s *schema.Schema) (*schema.Table, error) { table := s.GetTable(o.Table) if table == nil { return nil, TableDoesNotExistError{Name: o.Table} @@ -38,7 +38,7 @@ func (o *OpSetNotNull) Start(ctx context.Context, conn db.DB, latestSchema strin return table, nil } -func (o *OpSetNotNull) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpSetNotNull) Complete(ctx context.Context, conn db.DB, s *schema.Schema) error { // Validate the NOT NULL constraint on the old column. // The constraint must be valid because: // * Existing NULL values in the old column were rewritten using the `up` SQL during backfill. @@ -69,7 +69,7 @@ func (o *OpSetNotNull) Complete(ctx context.Context, conn db.DB, tr SQLTransform return nil } -func (o *OpSetNotNull) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpSetNotNull) Rollback(ctx context.Context, conn db.DB, s *schema.Schema) error { return nil } diff --git a/pkg/migrations/op_set_replica_identity.go b/pkg/migrations/op_set_replica_identity.go index 3aaa4bd53..9c995229b 100644 --- a/pkg/migrations/op_set_replica_identity.go +++ b/pkg/migrations/op_set_replica_identity.go @@ -15,7 +15,7 @@ import ( var _ Operation = (*OpSetReplicaIdentity)(nil) -func (o *OpSetReplicaIdentity) Start(ctx context.Context, conn db.DB, latestSchema string, tr SQLTransformer, s *schema.Schema) (*schema.Table, error) { +func (o *OpSetReplicaIdentity) Start(ctx context.Context, conn db.DB, latestSchema string, s *schema.Schema) (*schema.Table, error) { // build the correct form of the `SET REPLICA IDENTITY` statement based on the`identity type identitySQL := strings.ToUpper(o.Identity.Type) if identitySQL == "INDEX" { @@ -29,12 +29,12 @@ func (o *OpSetReplicaIdentity) Start(ctx context.Context, conn db.DB, latestSche return nil, err } -func (o *OpSetReplicaIdentity) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpSetReplicaIdentity) Complete(ctx context.Context, conn db.DB, s *schema.Schema) error { // No-op return nil } -func (o *OpSetReplicaIdentity) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpSetReplicaIdentity) Rollback(ctx context.Context, conn db.DB, s *schema.Schema) error { // No-op return nil } diff --git a/pkg/migrations/op_set_unique.go b/pkg/migrations/op_set_unique.go index c06f9dfe0..7c3519534 100644 --- a/pkg/migrations/op_set_unique.go +++ b/pkg/migrations/op_set_unique.go @@ -21,7 +21,7 @@ type OpSetUnique struct { var _ Operation = (*OpSetUnique)(nil) -func (o *OpSetUnique) Start(ctx context.Context, conn db.DB, latestSchema string, tr SQLTransformer, s *schema.Schema) (*schema.Table, error) { +func (o *OpSetUnique) Start(ctx context.Context, conn db.DB, latestSchema string, s *schema.Schema) (*schema.Table, error) { table := s.GetTable(o.Table) if table == nil { return nil, TableDoesNotExistError{Name: o.Table} @@ -34,7 +34,7 @@ func (o *OpSetUnique) Start(ctx context.Context, conn db.DB, latestSchema string return table, createUniqueIndexConcurrently(ctx, conn, s.Name, o.Name, table.Name, []string{column.Name}) } -func (o *OpSetUnique) Complete(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpSetUnique) Complete(ctx context.Context, conn db.DB, s *schema.Schema) error { // Create a unique constraint using the unique index _, err := conn.ExecContext(ctx, fmt.Sprintf("ALTER TABLE IF EXISTS %s ADD CONSTRAINT %s UNIQUE USING INDEX %s", pq.QuoteIdentifier(o.Table), @@ -47,7 +47,7 @@ func (o *OpSetUnique) Complete(ctx context.Context, conn db.DB, tr SQLTransforme return err } -func (o *OpSetUnique) Rollback(ctx context.Context, conn db.DB, tr SQLTransformer, s *schema.Schema) error { +func (o *OpSetUnique) Rollback(ctx context.Context, conn db.DB, s *schema.Schema) error { return nil } diff --git a/pkg/migrations/trigger.go b/pkg/migrations/trigger.go index 710e32161..1ca705a72 100644 --- a/pkg/migrations/trigger.go +++ b/pkg/migrations/trigger.go @@ -37,17 +37,12 @@ type triggerConfig struct { NeedsBackfillColumn string } -func createTrigger(ctx context.Context, conn db.DB, tr SQLTransformer, cfg triggerConfig) error { - expr, err := tr.TransformSQL(cfg.SQL) - if err != nil { - return err - } - - if len(expr) > 0 && expr[0] != '(' { - expr = "(" + expr + ")" +func createTrigger(ctx context.Context, conn db.DB, cfg triggerConfig) error { + // Parenthesize the up/down SQL if it's not parenthesized already + if len(cfg.SQL) > 0 && cfg.SQL[0] != '(' { + cfg.SQL = "(" + cfg.SQL + ")" } - cfg.SQL = expr cfg.NeedsBackfillColumn = CNeedsBackfillColumn funcSQL, err := buildFunction(cfg) diff --git a/pkg/roll/execute.go b/pkg/roll/execute.go index 386087506..19654ab57 100644 --- a/pkg/roll/execute.go +++ b/pkg/roll/execute.go @@ -86,7 +86,7 @@ func (m *Roll) StartDDLOperations(ctx context.Context, migration *migrations.Mig // execute operations var tablesToBackfill []*schema.Table for _, op := range migration.Operations { - table, err := op.Start(ctx, m.pgConn, latestSchema, m.sqlTransformer, newSchema) + table, err := op.Start(ctx, m.pgConn, latestSchema, newSchema) if err != nil { errRollback := m.Rollback(ctx) @@ -189,7 +189,7 @@ func (m *Roll) Complete(ctx context.Context) error { // execute operations refreshViews := false for _, op := range migration.Operations { - err := op.Complete(ctx, m.pgConn, m.sqlTransformer, currentSchema) + err := op.Complete(ctx, m.pgConn, currentSchema) if err != nil { return fmt.Errorf("unable to execute complete operation: %w", err) } @@ -267,7 +267,7 @@ func (m *Roll) Rollback(ctx context.Context) error { // roll back operations in reverse order for i := len(migration.Operations) - 1; i >= 0; i-- { - err := migration.Operations[i].Rollback(ctx, m.pgConn, m.sqlTransformer, schema) + err := migration.Operations[i].Rollback(ctx, m.pgConn, schema) if err != nil { return fmt.Errorf("unable to execute rollback operation: %w", err) } diff --git a/pkg/roll/execute_test.go b/pkg/roll/execute_test.go index 4775f0d89..3d78d010f 100644 --- a/pkg/roll/execute_test.go +++ b/pkg/roll/execute_test.go @@ -710,112 +710,6 @@ func TestRollSchemaMethodReturnsCorrectSchema(t *testing.T) { }) } -func TestSQLTransformerOptionIsUsedWhenCreatingTriggers(t *testing.T) { - t.Parallel() - - t.Run("when the SQL transformer is used to rewrite SQL", func(t *testing.T) { - t.Parallel() - - sqlTransformer := testutils.NewMockSQLTransformer(map[string]string{ - "'apples'": "'rewritten'", - }) - opts := []roll.Option{roll.WithSQLTransformer(sqlTransformer)} - - testutils.WithMigratorAndConnectionToContainerWithOptions(t, opts, func(mig *roll.Roll, db *sql.DB) { - ctx := context.Background() - - // Start a create table migration - err := mig.Start(ctx, &migrations.Migration{ - Name: "01_create_table", - Operations: migrations.Operations{createTableOp("table1")}, - }, backfill.NewConfig()) - require.NoError(t, err) - - // Complete the migration - err = mig.Complete(ctx) - require.NoError(t, err) - - // Insert some data - _, err = db.ExecContext(ctx, "INSERT INTO table1 (id, name) VALUES (1, 'alice'), (2, 'bob')") - require.NoError(t, err) - - // Start an add column migration that requires a backfill - err = mig.Start(ctx, &migrations.Migration{ - Name: "02_add_column", - Operations: migrations.Operations{ - &migrations.OpAddColumn{ - Table: "table1", - Up: "'apples'", - Column: migrations.Column{ - Name: "description", - Type: "text", - Nullable: false, - }, - }, - }, - }, backfill.NewConfig()) - require.NoError(t, err) - - // Complete the migration - err = mig.Complete(ctx) - require.NoError(t, err) - - // Ensure that the backfill used the SQL rewritten by the transformer - rows := MustSelect(t, db, "public", "02_add_column", "table1") - assert.Equal(t, []map[string]any{ - {"id": 1, "name": "alice", "description": "rewritten"}, - {"id": 2, "name": "bob", "description": "rewritten"}, - }, rows) - }) - }) - - t.Run("when the SQL transformer returns an error", func(t *testing.T) { - t.Parallel() - - sqlTransformer := testutils.NewMockSQLTransformer(map[string]string{ - "'apples'": testutils.MockSQLTransformerError, - }) - opts := []roll.Option{roll.WithSQLTransformer(sqlTransformer)} - - testutils.WithMigratorAndConnectionToContainerWithOptions(t, opts, func(mig *roll.Roll, db *sql.DB) { - ctx := context.Background() - - // Start a create table migration - err := mig.Start(ctx, &migrations.Migration{ - Name: "01_create_table", - Operations: migrations.Operations{createTableOp("table1")}, - }, backfill.NewConfig()) - require.NoError(t, err) - - // Complete the migration - err = mig.Complete(ctx) - require.NoError(t, err) - - // Insert some data - _, err = db.ExecContext(ctx, "INSERT INTO table1 (id, name) VALUES (1, 'alice'), (2, 'bob')") - require.NoError(t, err) - - // Start an add column migration that requires a backfill - err = mig.Start(ctx, &migrations.Migration{ - Name: "02_add_column", - Operations: migrations.Operations{ - &migrations.OpAddColumn{ - Table: "table1", - Up: "'apples'", - Column: migrations.Column{ - Name: "description", - Type: "text", - Nullable: false, - }, - }, - }, - }, backfill.NewConfig()) - // Ensure that the start phase has failed with a SQL transformer error - require.ErrorIs(t, err, testutils.ErrMockSQLTransformer) - }) - }) -} - func TestWithSearchPathOptionIsRespected(t *testing.T) { t.Parallel() diff --git a/pkg/roll/options.go b/pkg/roll/options.go index bc5a0fedc..e8b42acac 100644 --- a/pkg/roll/options.go +++ b/pkg/roll/options.go @@ -2,10 +2,6 @@ package roll -import ( - "github.com/xataio/pgroll/pkg/migrations" -) - type options struct { // lock timeout in milliseconds for pgroll DDL operations lockTimeoutMs int @@ -13,9 +9,6 @@ type options struct { // optional role to set before executing migrations role string - // optional SQL transformer to apply to all user-defined SQL statements - sqlTransformer migrations.SQLTransformer - // disable pgroll version schemas creation and deletion disableVersionSchemas bool @@ -84,16 +77,6 @@ func WithMigrationHooks(hooks MigrationHooks) Option { } } -// WithSQLTransformer sets the SQL transformer to apply to all user-defined SQL -// statements before they are executed. -// This is useful to sanitize or modify user defined SQL statements before they -// are executed. -func WithSQLTransformer(transformer migrations.SQLTransformer) Option { - return func(o *options) { - o.sqlTransformer = transformer - } -} - // WithSearchPath sets the search_path to use during migration execution. The // schema in which the migration is run is always included in the search path, // regardless of this setting. diff --git a/pkg/roll/roll.go b/pkg/roll/roll.go index 06897d293..5aade2a28 100644 --- a/pkg/roll/roll.go +++ b/pkg/roll/roll.go @@ -11,7 +11,6 @@ import ( "github.com/lib/pq" "github.com/xataio/pgroll/pkg/db" - "github.com/xataio/pgroll/pkg/migrations" "github.com/xataio/pgroll/pkg/state" ) @@ -38,7 +37,6 @@ type Roll struct { migrationHooks MigrationHooks state *state.State pgVersion PGVersion - sqlTransformer migrations.SQLTransformer skipValidation bool } @@ -60,19 +58,11 @@ func New(ctx context.Context, pgURL, schema string, state *state.State, opts ... return nil, fmt.Errorf("unable to retrieve postgres version: %w", err) } - var sqlTransformer migrations.SQLTransformer = migrations.SQLTransformerFunc( - func(sql string) (string, error) { return sql, nil }, - ) - if rollOpts.sqlTransformer != nil { - sqlTransformer = rollOpts.sqlTransformer - } - return &Roll{ pgConn: &db.RDB{DB: conn}, schema: schema, state: state, pgVersion: pgMajorVersion, - sqlTransformer: sqlTransformer, disableVersionSchemas: rollOpts.disableVersionSchemas, noVersionSchemaForRawSQL: rollOpts.noVersionSchemaForRawSQL, migrationHooks: rollOpts.migrationHooks,