From 05ada917414ea7c574be3974c7de4f09535961fd Mon Sep 17 00:00:00 2001 From: Artur Sawicki Date: Thu, 5 Dec 2024 11:13:45 +0100 Subject: [PATCH] feat: Rework data types (#3244) Rework data types handling inside our SDK: - Extract in a separate package with two main entry points: - `func ParseDataType(raw string) (DataType, error)` - responsible for proper datatype parsing and storing the attributes if any (falling back to the hardcoded default for now - this is a continuation from #3020 where diff suppressions for number and text types were set exactly this way; for V1 we should stay with hardcoded defaults and we can think of improving them later on without the direct influence on our users) - `func AreTheSame(a DataType, b DataType) bool` - responsible for comparing any two data types; used mostly for diff suppressions and changes discovery - Replace all invocations of the old `ToDataType` in a compatible way (temporary function `LegacyDataTypeFrom` that underneath calls `ToLegacyDataTypeSql()` which returns the same output as old data types used to) - Test on the integration level the consistency of data types' behavior with Snowflake docs - Use new parser and comparator in validation and diff suppression functions (replacing the old ones in all resources) Next steps/future improvements: - Use new DataType in SDK SQL builder - Replace the rest of the old DataType usages (especially in the SDK Requests/Opts) - Tackle SNOW-1596962 TODOs regarding VECTOR type - Improve parsing functions for lists of arguments (and defaults) - will be done on a case-by-case basis with all the resources - Clean up the code in new data types (TODOs with SNOW-1843440) --- MIGRATION_GUIDE.md | 26 + pkg/helpers/helpers.go | 3 +- pkg/resources/diff_suppressions.go | 24 +- pkg/resources/external_function.go | 12 +- .../external_function_state_upgraders.go | 5 +- pkg/resources/external_table.go | 21 +- pkg/resources/function.go | 20 +- pkg/resources/function_state_upgraders.go | 5 +- 
pkg/resources/helpers.go | 66 - pkg/resources/helpers_test.go | 15 +- pkg/resources/masking_policy.go | 23 +- pkg/resources/procedure.go | 26 +- pkg/resources/procedure_state_upgraders.go | 5 +- pkg/resources/row_access_policy.go | 15 +- pkg/resources/table.go | 46 +- pkg/resources/table_acceptance_test.go | 5 +- pkg/resources/validators.go | 20 +- pkg/resources/validators_test.go | 42 - pkg/sdk/common_types.go | 6 +- pkg/sdk/data_types.go | 176 --- pkg/sdk/data_types_deprecated.go | 51 + pkg/sdk/data_types_deprecated_test.go | 56 + pkg/sdk/data_types_test.go | 218 ---- pkg/sdk/datatypes/array.go | 21 + pkg/sdk/datatypes/binary.go | 51 + pkg/sdk/datatypes/boolean.go | 21 + pkg/sdk/datatypes/data_types.go | 149 +++ pkg/sdk/datatypes/data_types_test.go | 1148 +++++++++++++++++ pkg/sdk/datatypes/date.go | 21 + pkg/sdk/datatypes/float.go | 21 + pkg/sdk/datatypes/geography.go | 21 + pkg/sdk/datatypes/geometry.go | 21 + pkg/sdk/datatypes/legacy.go | 19 + pkg/sdk/datatypes/number.go | 104 ++ pkg/sdk/datatypes/object.go | 21 + pkg/sdk/datatypes/text.go | 69 + pkg/sdk/datatypes/time.go | 51 + pkg/sdk/datatypes/timestamp.go | 3 + pkg/sdk/datatypes/timestamp_ltz.go | 49 + pkg/sdk/datatypes/timestamp_ntz.go | 49 + pkg/sdk/datatypes/timestamp_tz.go | 49 + pkg/sdk/datatypes/variant.go | 21 + pkg/sdk/datatypes/vector.go | 65 + pkg/sdk/dynamic_table.go | 5 +- pkg/sdk/identifier_helpers.go | 10 +- pkg/sdk/masking_policy.go | 5 +- pkg/sdk/tables_test.go | 8 +- .../testint/data_types_integration_test.go | 349 +++++ .../external_tables_integration_test.go | 2 +- pkg/sdk/testint/functions_integration_test.go | 7 +- ...ow_access_policies_gen_integration_test.go | 10 +- pkg/sdk/testint/tables_integration_test.go | 5 +- .../testint/warehouses_integration_test.go | 2 +- pkg/sdk/validations.go | 4 +- 54 files changed, 2627 insertions(+), 640 deletions(-) delete mode 100644 pkg/sdk/data_types.go create mode 100644 pkg/sdk/data_types_deprecated.go create mode 100644 
pkg/sdk/data_types_deprecated_test.go delete mode 100644 pkg/sdk/data_types_test.go create mode 100644 pkg/sdk/datatypes/array.go create mode 100644 pkg/sdk/datatypes/binary.go create mode 100644 pkg/sdk/datatypes/boolean.go create mode 100644 pkg/sdk/datatypes/data_types.go create mode 100644 pkg/sdk/datatypes/data_types_test.go create mode 100644 pkg/sdk/datatypes/date.go create mode 100644 pkg/sdk/datatypes/float.go create mode 100644 pkg/sdk/datatypes/geography.go create mode 100644 pkg/sdk/datatypes/geometry.go create mode 100644 pkg/sdk/datatypes/legacy.go create mode 100644 pkg/sdk/datatypes/number.go create mode 100644 pkg/sdk/datatypes/object.go create mode 100644 pkg/sdk/datatypes/text.go create mode 100644 pkg/sdk/datatypes/time.go create mode 100644 pkg/sdk/datatypes/timestamp.go create mode 100644 pkg/sdk/datatypes/timestamp_ltz.go create mode 100644 pkg/sdk/datatypes/timestamp_ntz.go create mode 100644 pkg/sdk/datatypes/timestamp_tz.go create mode 100644 pkg/sdk/datatypes/variant.go create mode 100644 pkg/sdk/datatypes/vector.go create mode 100644 pkg/sdk/testint/data_types_integration_test.go diff --git a/MIGRATION_GUIDE.md b/MIGRATION_GUIDE.md index a73c7d1d0d..dbfef7d7f6 100644 --- a/MIGRATION_GUIDE.md +++ b/MIGRATION_GUIDE.md @@ -77,6 +77,32 @@ resource "snowflake_tag_association" "table_association" { The state is migrated automatically. Please adjust your configuration files. +### Data type changes + +As part of reworking functions, procedures, and any other resource utilizing Snowflake data types, we adjusted the parsing of data types to be more aligned with Snowflake (according to [docs](https://docs.snowflake.com/en/sql-reference/intro-summary-data-types)). + +Affected resources: +- `snowflake_function` +- `snowflake_procedure` +- `snowflake_table` +- `snowflake_external_function` +- `snowflake_masking_policy` +- `snowflake_row_access_policy` +- `snowflake_dynamic_table` +You may encounter non-empty plans in these resources after bumping. 
+ +Changes to the previous implementation/limitations: +- `BOOL` is no longer supported; use `BOOLEAN` instead. +- Following the change described [here](#bugfix-handle-data-type-diff-suppression-better-for-text-and-number), comparing and suppressing changes of data types was extended for all other data types with the following rules: + - `CHARACTER`, `CHAR`, `NCHAR` now have the default size set to 1 if not provided (following the [docs](https://docs.snowflake.com/en/sql-reference/data-types-text#char-character-nchar)) + - `BINARY` has default size set to 8388608 if not provided (following the [docs](https://docs.snowflake.com/en/sql-reference/data-types-text#binary)) + - `TIME` has default precision set to 9 if not provided (following the [docs](https://docs.snowflake.com/en/sql-reference/data-types-datetime#time)) + - `TIMESTAMP_LTZ` has default precision set to 9 if not provided (following the [docs](https://docs.snowflake.com/en/sql-reference/data-types-datetime#timestamp)); supported aliases: `TIMESTAMPLTZ`, `TIMESTAMP WITH LOCAL TIME ZONE`. + - `TIMESTAMP_NTZ` has default precision set to 9 if not provided (following the [docs](https://docs.snowflake.com/en/sql-reference/data-types-datetime#timestamp)); supported aliases: `TIMESTAMPNTZ`, `TIMESTAMP WITHOUT TIME ZONE`, `DATETIME`. + - `TIMESTAMP_TZ` has default precision set to 9 if not provided (following the [docs](https://docs.snowflake.com/en/sql-reference/data-types-datetime#timestamp)); supported aliases: `TIMESTAMPTZ`, `TIMESTAMP WITH TIME ZONE`. 
+- The session-settable `TIMESTAMP` is NOT supported ([docs](https://docs.snowflake.com/en/sql-reference/data-types-datetime#timestamp)) +- `VECTOR` type still is limited and will be addressed soon (probably before the release so it will be edited) + ## v0.98.0 ➞ v0.99.0 ### snowflake_tasks data source changes diff --git a/pkg/helpers/helpers.go b/pkg/helpers/helpers.go index 6b3e39d4cf..b4dde1acd7 100644 --- a/pkg/helpers/helpers.go +++ b/pkg/helpers/helpers.go @@ -142,7 +142,8 @@ func DecodeSnowflakeAccountIdentifier(identifier string) (sdk.AccountIdentifier, } } -// TODO: use slices.Concat in Go 1.22 +// ConcatSlices is a temporary replacement for slices.Concat that will be available after we migrate to Go 1.22. +// TODO [SNOW-1844769]: use slices.Concat func ConcatSlices[T any](slices ...[]T) []T { var tmp []T for _, s := range slices { diff --git a/pkg/resources/diff_suppressions.go b/pkg/resources/diff_suppressions.go index 597529e4ec..14efa760b2 100644 --- a/pkg/resources/diff_suppressions.go +++ b/pkg/resources/diff_suppressions.go @@ -9,10 +9,15 @@ import ( "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/collections" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" ) func NormalizeAndCompare[T comparable](normalize func(string) (T, error)) schema.SchemaDiffSuppressFunc { + return NormalizeAndCompareUsingFunc(normalize, func(a, b T) bool { return a == b }) +} + +func NormalizeAndCompareUsingFunc[T any](normalize func(string) (T, error), compareFunc func(a, b T) bool) schema.SchemaDiffSuppressFunc { return func(_, oldValue, newValue string, _ *schema.ResourceData) bool { oldNormalized, err := normalize(oldValue) if err != nil { @@ -22,10 +27,15 @@ func NormalizeAndCompare[T comparable](normalize func(string) (T, 
error)) schema if err != nil { return false } - return oldNormalized == newNormalized + + return compareFunc(oldNormalized, newNormalized) } } +// DiffSuppressDataTypes handles data type suppression taking into account data type attributes for each type. +// It falls back to Snowflake defaults for arguments if no arguments were provided for the data type. +var DiffSuppressDataTypes = NormalizeAndCompareUsingFunc(datatypes.ParseDataType, datatypes.AreTheSame) + // NormalizeAndCompareIdentifiersInSet is a diff suppression function that should be used at top-level TypeSet fields that // hold identifiers to avoid diffs like: // - "DATABASE"."SCHEMA"."OBJECT" @@ -254,3 +264,15 @@ func IgnoreNewEmptyListOrSubfields(ignoredSubfields ...string) schema.SchemaDiff return len(parts) == 3 && slices.Contains(ignoredSubfields, parts[2]) && new == "" } } + +func ignoreTrimSpaceSuppressFunc(_, old, new string, _ *schema.ResourceData) bool { + return strings.TrimSpace(old) == strings.TrimSpace(new) +} + +func ignoreCaseSuppressFunc(_, old, new string, _ *schema.ResourceData) bool { + return strings.EqualFold(old, new) +} + +func ignoreCaseAndTrimSpaceSuppressFunc(_, old, new string, _ *schema.ResourceData) bool { + return strings.EqualFold(strings.TrimSpace(old), strings.TrimSpace(new)) +} diff --git a/pkg/resources/external_function.go b/pkg/resources/external_function.go index 9459adab6a..2580fb6141 100644 --- a/pkg/resources/external_function.go +++ b/pkg/resources/external_function.go @@ -8,11 +8,11 @@ import ( "strconv" "strings" - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/schemas" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" 
"github.com/hashicorp/go-cty/cty" "github.com/hashicorp/terraform-plugin-sdk/v2/diag" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" @@ -227,11 +227,11 @@ func CreateContextExternalFunction(ctx context.Context, d *schema.ResourceData, for _, arg := range v.([]interface{}) { argName := arg.(map[string]interface{})["name"].(string) argType := arg.(map[string]interface{})["type"].(string) - argDataType, err := sdk.ToDataType(argType) + argDataType, err := datatypes.ParseDataType(argType) if err != nil { return diag.FromErr(err) } - args = append(args, sdk.ExternalFunctionArgumentRequest{ArgName: argName, ArgDataType: argDataType}) + args = append(args, sdk.ExternalFunctionArgumentRequest{ArgName: argName, ArgDataType: sdk.LegacyDataTypeFrom(argDataType)}) } } argTypes := make([]sdk.DataType, 0, len(args)) @@ -241,13 +241,13 @@ func CreateContextExternalFunction(ctx context.Context, d *schema.ResourceData, id := sdk.NewSchemaObjectIdentifierWithArguments(database, schemaName, name, argTypes...) 
returnType := d.Get("return_type").(string) - resultDataType, err := sdk.ToDataType(returnType) + resultDataType, err := datatypes.ParseDataType(returnType) if err != nil { return diag.FromErr(err) } apiIntegration := sdk.NewAccountObjectIdentifier(d.Get("api_integration").(string)) urlOfProxyAndResource := d.Get("url_of_proxy_and_resource").(string) - req := sdk.NewCreateExternalFunctionRequest(id.SchemaObjectId(), resultDataType, &apiIntegration, urlOfProxyAndResource) + req := sdk.NewCreateExternalFunctionRequest(id.SchemaObjectId(), sdk.LegacyDataTypeFrom(resultDataType), &apiIntegration, urlOfProxyAndResource) // Set optionals if len(args) > 0 { diff --git a/pkg/resources/external_function_state_upgraders.go b/pkg/resources/external_function_state_upgraders.go index aba74585aa..315d7f6caa 100644 --- a/pkg/resources/external_function_state_upgraders.go +++ b/pkg/resources/external_function_state_upgraders.go @@ -6,6 +6,7 @@ import ( "strings" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" ) type v085ExternalFunctionId struct { @@ -52,11 +53,11 @@ func v085ExternalFunctionStateUpgrader(ctx context.Context, rawState map[string] argDataTypes := make([]sdk.DataType, 0) if parsedV085ExternalFunctionId.ExternalFunctionArgTypes != "" { for _, argType := range strings.Split(parsedV085ExternalFunctionId.ExternalFunctionArgTypes, "-") { - argDataType, err := sdk.ToDataType(argType) + argDataType, err := datatypes.ParseDataType(argType) if err != nil { return nil, err } - argDataTypes = append(argDataTypes, argDataType) + argDataTypes = append(argDataTypes, sdk.LegacyDataTypeFrom(argDataType)) } } diff --git a/pkg/resources/external_table.go b/pkg/resources/external_table.go index 833b8c4dd8..56404cf703 100644 --- a/pkg/resources/external_table.go +++ b/pkg/resources/external_table.go @@ -5,17 +5,14 @@ import ( "fmt" "log" - 
"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" - "github.com/hashicorp/terraform-plugin-sdk/v2/diag" - + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/schemas" - - "github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation" - - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/hashicorp/terraform-plugin-sdk/v2/diag" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" + "github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation" ) var externalTableSchema = map[string]*schema.Schema{ @@ -59,11 +56,11 @@ var externalTableSchema = map[string]*schema.Schema{ ForceNew: true, }, "type": { - Type: schema.TypeString, - Required: true, - Description: "Column type, e.g. VARIANT", - ForceNew: true, - ValidateFunc: IsDataType(), + Type: schema.TypeString, + Required: true, + Description: "Column type, e.g. 
VARIANT", + ForceNew: true, + ValidateDiagFunc: IsDataTypeValid, }, "as": { Type: schema.TypeString, diff --git a/pkg/resources/function.go b/pkg/resources/function.go index 314439b96d..19415f91f1 100644 --- a/pkg/resources/function.go +++ b/pkg/resources/function.go @@ -7,11 +7,11 @@ import ( "regexp" "strings" - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/schemas" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/snowflake" "github.com/hashicorp/go-cty/cty" "github.com/hashicorp/terraform-plugin-sdk/v2/diag" @@ -311,7 +311,7 @@ func createScalaFunction(ctx context.Context, d *schema.ResourceData, meta inter functionDefinition := d.Get("statement").(string) handler := d.Get("handler").(string) // create request with required - request := sdk.NewCreateForScalaFunctionRequest(id, returnDataType, handler) + request := sdk.NewCreateForScalaFunctionRequest(id, sdk.LegacyDataTypeFrom(returnDataType), handler) request.WithFunctionDefinition(functionDefinition) // Set optionals @@ -739,16 +739,16 @@ func parseFunctionArguments(d *schema.ResourceData) ([]sdk.FunctionArgumentReque if diags != nil { return nil, diags } - args = append(args, sdk.FunctionArgumentRequest{ArgName: argName, ArgDataType: argDataType}) + args = append(args, sdk.FunctionArgumentRequest{ArgName: argName, ArgDataType: sdk.LegacyDataTypeFrom(argDataType)}) } } return args, nil } -func convertFunctionDataType(s string) (sdk.DataType, diag.Diagnostics) { - dataType, err := sdk.ToDataType(s) +func convertFunctionDataType(s string) (datatypes.DataType, diag.Diagnostics) { + dataType, err := 
datatypes.ParseDataType(s) if err != nil { - return dataType, diag.FromErr(err) + return nil, diag.FromErr(err) } return dataType, nil } @@ -759,13 +759,13 @@ func convertFunctionColumns(s string) ([]sdk.FunctionColumn, diag.Diagnostics) { var columns []sdk.FunctionColumn for _, match := range matches { if len(match) == 3 { - dataType, err := sdk.ToDataType(match[2]) + dataType, err := datatypes.ParseDataType(match[2]) if err != nil { return nil, diag.FromErr(err) } columns = append(columns, sdk.FunctionColumn{ ColumnName: match[1], - ColumnDataType: dataType, + ColumnDataType: sdk.LegacyDataTypeFrom(dataType), }) } } @@ -789,7 +789,7 @@ func parseFunctionReturnsRequest(s string) (*sdk.FunctionReturnsRequest, diag.Di if diags != nil { return nil, diags } - returns.WithResultDataType(*sdk.NewFunctionReturnsResultDataTypeRequest(returnDataType)) + returns.WithResultDataType(*sdk.NewFunctionReturnsResultDataTypeRequest(sdk.LegacyDataTypeFrom(returnDataType))) } return returns, nil } diff --git a/pkg/resources/function_state_upgraders.go b/pkg/resources/function_state_upgraders.go index 501e44f1dc..7be3c5b9b8 100644 --- a/pkg/resources/function_state_upgraders.go +++ b/pkg/resources/function_state_upgraders.go @@ -6,6 +6,7 @@ import ( "strings" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" ) type v085FunctionId struct { @@ -48,11 +49,11 @@ func v085FunctionIdStateUpgrader(ctx context.Context, rawState map[string]interf argDataTypes := make([]sdk.DataType, len(parsedV085FunctionId.ArgTypes)) for i, argType := range parsedV085FunctionId.ArgTypes { - argDataType, err := sdk.ToDataType(argType) + argDataType, err := datatypes.ParseDataType(argType) if err != nil { return nil, err } - argDataTypes[i] = argDataType + argDataTypes[i] = sdk.LegacyDataTypeFrom(argDataType) } schemaObjectIdentifierWithArguments := 
sdk.NewSchemaObjectIdentifierWithArgumentsOld(parsedV085FunctionId.DatabaseName, parsedV085FunctionId.SchemaName, parsedV085FunctionId.FunctionName, argDataTypes) diff --git a/pkg/resources/helpers.go b/pkg/resources/helpers.go index c436f8d084..6752840e1e 100644 --- a/pkg/resources/helpers.go +++ b/pkg/resources/helpers.go @@ -3,79 +3,13 @@ package resources import ( "fmt" "slices" - "strings" - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/logging" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/snowflake" "github.com/hashicorp/terraform-plugin-sdk/v2/diag" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" ) -func dataTypeValidateFunc(val interface{}, _ string) (warns []string, errs []error) { - if ok := sdk.IsValidDataType(val.(string)); !ok { - errs = append(errs, fmt.Errorf("%v is not a valid data type", val)) - } - return -} - -func dataTypeDiffSuppressFunc(_, old, new string, _ *schema.ResourceData) bool { - oldDT, err := sdk.ToDataType(old) - if err != nil { - return false - } - newDT, err := sdk.ToDataType(new) - if err != nil { - return false - } - return oldDT == newDT -} - -// DataTypeIssue3007DiffSuppressFunc is a temporary solution to handle data type suppression problems. -// Currently, it handles only number and text data types. -// It falls back to Snowflake defaults for arguments if no arguments were provided for the data type. 
-// TODO [SNOW-1348103 or SNOW-1348106]: visit with functions and procedures rework -func DataTypeIssue3007DiffSuppressFunc(_, old, new string, _ *schema.ResourceData) bool { - oldDataType, err := sdk.ToDataType(old) - if err != nil { - return false - } - newDataType, err := sdk.ToDataType(new) - if err != nil { - return false - } - if oldDataType != newDataType { - return false - } - switch v := oldDataType; v { - case sdk.DataTypeNumber: - logging.DebugLogger.Printf("[DEBUG] DataTypeIssue3007DiffSuppressFunc: Handling number data type diff suppression") - oldPrecision, oldScale := sdk.ParseNumberDataTypeRaw(old) - newPrecision, newScale := sdk.ParseNumberDataTypeRaw(new) - return oldPrecision == newPrecision && oldScale == newScale - case sdk.DataTypeVARCHAR: - logging.DebugLogger.Printf("[DEBUG] DataTypeIssue3007DiffSuppressFunc: Handling text data type diff suppression") - oldLength := sdk.ParseVarcharDataTypeRaw(old) - newLength := sdk.ParseVarcharDataTypeRaw(new) - return oldLength == newLength - default: - logging.DebugLogger.Printf("[DEBUG] DataTypeIssue3007DiffSuppressFunc: Diff suppression for %s can't be currently handled", v) - } - return true -} - -func ignoreTrimSpaceSuppressFunc(_, old, new string, _ *schema.ResourceData) bool { - return strings.TrimSpace(old) == strings.TrimSpace(new) -} - -func ignoreCaseSuppressFunc(_, old, new string, _ *schema.ResourceData) bool { - return strings.EqualFold(old, new) -} - -func ignoreCaseAndTrimSpaceSuppressFunc(_, old, new string, _ *schema.ResourceData) bool { - return strings.EqualFold(strings.TrimSpace(old), strings.TrimSpace(new)) -} - func getTagObjectIdentifier(obj map[string]any) sdk.ObjectIdentifier { database := obj["database"].(string) schema := obj["schema"].(string) diff --git a/pkg/resources/helpers_test.go b/pkg/resources/helpers_test.go index c60807c9ac..a29a8658a4 100644 --- a/pkg/resources/helpers_test.go +++ b/pkg/resources/helpers_test.go @@ -11,6 +11,7 @@ import ( 
"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/resources" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" "github.com/hashicorp/terraform-plugin-testing/terraform" "github.com/stretchr/testify/assert" @@ -337,7 +338,7 @@ func TestListDiffWithCommonItems(t *testing.T) { } } -func Test_DataTypeIssue3007DiffSuppressFunc(t *testing.T) { +func Test_DataTypeDiffSuppressFunc(t *testing.T) { testCases := []struct { name string old string @@ -401,7 +402,7 @@ func Test_DataTypeIssue3007DiffSuppressFunc(t *testing.T) { { name: "synonym number data type precision implicit and same", old: "NUMBER", - new: fmt.Sprintf("DECIMAL(%d)", sdk.DefaultNumberPrecision), + new: fmt.Sprintf("DECIMAL(%d)", datatypes.DefaultNumberPrecision), expected: true, }, { @@ -425,7 +426,7 @@ func Test_DataTypeIssue3007DiffSuppressFunc(t *testing.T) { { name: "synonym number data type default scale implicit and explicit", old: "NUMBER(30)", - new: fmt.Sprintf("DECIMAL(30, %d)", sdk.DefaultNumberScale), + new: fmt.Sprintf("DECIMAL(30, %d)", datatypes.DefaultNumberScale), expected: true, }, { @@ -437,13 +438,13 @@ func Test_DataTypeIssue3007DiffSuppressFunc(t *testing.T) { { name: "synonym number data type both precision and scale implicit and explicit", old: "NUMBER", - new: fmt.Sprintf("DECIMAL(%d, %d)", sdk.DefaultNumberPrecision, sdk.DefaultNumberScale), + new: fmt.Sprintf("DECIMAL(%d, %d)", datatypes.DefaultNumberPrecision, datatypes.DefaultNumberScale), expected: true, }, { name: "synonym number data type both precision and scale implicit and scale different", old: "NUMBER", - new: fmt.Sprintf("DECIMAL(%d, 2)", sdk.DefaultNumberPrecision), + new: fmt.Sprintf("DECIMAL(%d, 2)", datatypes.DefaultNumberPrecision), expected: false, }, { @@ -461,7 +462,7 @@ func 
Test_DataTypeIssue3007DiffSuppressFunc(t *testing.T) { { name: "synonym text data type length implicit and same", old: "VARCHAR", - new: fmt.Sprintf("TEXT(%d)", sdk.DefaultVarcharLength), + new: fmt.Sprintf("TEXT(%d)", datatypes.DefaultVarcharLength), expected: true, }, { @@ -475,7 +476,7 @@ func Test_DataTypeIssue3007DiffSuppressFunc(t *testing.T) { for _, tc := range testCases { tc := tc t.Run(tc.name, func(t *testing.T) { - result := resources.DataTypeIssue3007DiffSuppressFunc("", tc.old, tc.new, nil) + result := resources.DiffSuppressDataTypes("", tc.old, tc.new, nil) require.Equal(t, tc.expected, result) }) } diff --git a/pkg/resources/masking_policy.go b/pkg/resources/masking_policy.go index 4acd32bbb1..f35df310cd 100644 --- a/pkg/resources/masking_policy.go +++ b/pkg/resources/masking_policy.go @@ -6,13 +6,12 @@ import ( "fmt" "log" - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" - + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/schemas" - - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" "github.com/hashicorp/go-cty/cty" "github.com/hashicorp/terraform-plugin-sdk/v2/diag" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/customdiff" @@ -55,8 +54,8 @@ var maskingPolicySchema = map[string]*schema.Schema{ "type": { Type: schema.TypeString, Required: true, - DiffSuppressFunc: NormalizeAndCompare(sdk.ToDataType), - ValidateDiagFunc: sdkValidation(sdk.ToDataType), + DiffSuppressFunc: DiffSuppressDataTypes, + ValidateDiagFunc: IsDataTypeValid, Description: dataTypeFieldDescription("The argument type. 
VECTOR data types are not yet supported."), ForceNew: true, }, @@ -77,8 +76,8 @@ var maskingPolicySchema = map[string]*schema.Schema{ Required: true, Description: dataTypeFieldDescription("The return data type must match the input data type of the first column that is specified as an input column."), ForceNew: true, - DiffSuppressFunc: NormalizeAndCompare(sdk.ToDataType), - ValidateDiagFunc: sdkValidation(sdk.ToDataType), + DiffSuppressFunc: DiffSuppressDataTypes, + ValidateDiagFunc: IsDataTypeValid, }, "exempt_other_policies": { Type: schema.TypeString, @@ -198,17 +197,17 @@ func CreateMaskingPolicy(ctx context.Context, d *schema.ResourceData, meta any) args := make([]sdk.TableColumnSignature, 0) for _, arg := range arguments { v := arg.(map[string]any) - dataType, err := sdk.ToDataType(v["type"].(string)) + dataType, err := datatypes.ParseDataType(v["type"].(string)) if err != nil { return diag.FromErr(err) } args = append(args, sdk.TableColumnSignature{ Name: v["name"].(string), - Type: dataType, + Type: sdk.LegacyDataTypeFrom(dataType), }) } - returns, err := sdk.ToDataType(returnDataType) + returns, err := datatypes.ParseDataType(returnDataType) if err != nil { return diag.FromErr(err) } @@ -226,7 +225,7 @@ func CreateMaskingPolicy(ctx context.Context, d *schema.ResourceData, meta any) opts.ExemptOtherPolicies = sdk.Pointer(parsed) } - err = client.MaskingPolicies.Create(ctx, id, args, returns, expression, opts) + err = client.MaskingPolicies.Create(ctx, id, args, sdk.LegacyDataTypeFrom(returns), expression, opts) if err != nil { return diag.FromErr(err) } diff --git a/pkg/resources/procedure.go b/pkg/resources/procedure.go index adb80061bb..aa8b557250 100644 --- a/pkg/resources/procedure.go +++ b/pkg/resources/procedure.go @@ -8,11 +8,11 @@ import ( "slices" "strings" - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" + 
"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/schemas" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" "github.com/hashicorp/go-cty/cty" "github.com/hashicorp/terraform-plugin-sdk/v2/diag" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/customdiff" @@ -60,9 +60,9 @@ var procedureSchema = map[string]*schema.Schema{ "type": { Type: schema.TypeString, Required: true, - ValidateFunc: dataTypeValidateFunc, - DiffSuppressFunc: dataTypeDiffSuppressFunc, Description: "The argument type", + ValidateDiagFunc: IsDataTypeValid, + DiffSuppressFunc: DiffSuppressDataTypes, }, }, }, @@ -322,7 +322,7 @@ func createJavaScriptProcedure(ctx context.Context, d *schema.ResourceData, meta return diags } procedureDefinition := d.Get("statement").(string) - req := sdk.NewCreateForJavaScriptProcedureRequest(id.SchemaObjectId(), returnDataType, procedureDefinition) + req := sdk.NewCreateForJavaScriptProcedureRequest(id.SchemaObjectId(), sdk.LegacyDataTypeFrom(returnDataType), procedureDefinition) if len(args) > 0 { req.WithArguments(args) } @@ -735,16 +735,16 @@ func getProcedureArguments(d *schema.ResourceData) ([]sdk.ProcedureArgumentReque if diags != nil { return nil, diags } - args = append(args, sdk.ProcedureArgumentRequest{ArgName: argName, ArgDataType: argDataType}) + args = append(args, sdk.ProcedureArgumentRequest{ArgName: argName, ArgDataType: sdk.LegacyDataTypeFrom(argDataType)}) } } return args, nil } -func convertProcedureDataType(s string) (sdk.DataType, diag.Diagnostics) { - dataType, err := sdk.ToDataType(s) +func convertProcedureDataType(s string) (datatypes.DataType, diag.Diagnostics) { + dataType, err := datatypes.ParseDataType(s) if err != nil { - return dataType, diag.FromErr(err) + return nil, diag.FromErr(err) } return dataType, nil } @@ -755,13 +755,13 @@ func 
convertProcedureColumns(s string) ([]sdk.ProcedureColumn, diag.Diagnostics) var columns []sdk.ProcedureColumn for _, match := range matches { if len(match) == 3 { - dataType, err := sdk.ToDataType(match[2]) + dataType, err := datatypes.ParseDataType(match[2]) if err != nil { return nil, diag.FromErr(err) } columns = append(columns, sdk.ProcedureColumn{ ColumnName: match[1], - ColumnDataType: dataType, + ColumnDataType: sdk.LegacyDataTypeFrom(dataType), }) } } @@ -785,7 +785,7 @@ func parseProcedureReturnsRequest(s string) (*sdk.ProcedureReturnsRequest, diag. if diags != nil { return nil, diags } - returns.WithResultDataType(*sdk.NewProcedureReturnsResultDataTypeRequest(returnDataType)) + returns.WithResultDataType(*sdk.NewProcedureReturnsResultDataTypeRequest(sdk.LegacyDataTypeFrom(returnDataType))) } return returns, nil } @@ -807,7 +807,7 @@ func parseProcedureSQLReturnsRequest(s string) (*sdk.ProcedureSQLReturnsRequest, if diags != nil { return nil, diags } - returns.WithResultDataType(*sdk.NewProcedureReturnsResultDataTypeRequest(returnDataType)) + returns.WithResultDataType(*sdk.NewProcedureReturnsResultDataTypeRequest(sdk.LegacyDataTypeFrom(returnDataType))) } return returns, nil } diff --git a/pkg/resources/procedure_state_upgraders.go b/pkg/resources/procedure_state_upgraders.go index 24e47d7d9f..610822401d 100644 --- a/pkg/resources/procedure_state_upgraders.go +++ b/pkg/resources/procedure_state_upgraders.go @@ -6,6 +6,7 @@ import ( "strings" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" ) type v085ProcedureId struct { @@ -48,11 +49,11 @@ func v085ProcedureStateUpgrader(ctx context.Context, rawState map[string]interfa argDataTypes := make([]sdk.DataType, len(parsedV085ProcedureId.ArgTypes)) for i, argType := range parsedV085ProcedureId.ArgTypes { - argDataType, err := sdk.ToDataType(argType) + argDataType, err := datatypes.ParseDataType(argType) if err != nil { 
return nil, err } - argDataTypes[i] = argDataType + argDataTypes[i] = sdk.LegacyDataTypeFrom(argDataType) } schemaObjectIdentifierWithArguments := sdk.NewSchemaObjectIdentifierWithArgumentsOld(parsedV085ProcedureId.DatabaseName, parsedV085ProcedureId.SchemaName, parsedV085ProcedureId.ProcedureName, argDataTypes) diff --git a/pkg/resources/row_access_policy.go b/pkg/resources/row_access_policy.go index 12c3050cb3..5722b148ba 100644 --- a/pkg/resources/row_access_policy.go +++ b/pkg/resources/row_access_policy.go @@ -6,13 +6,12 @@ import ( "fmt" "log" - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" - + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/schemas" - - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" "github.com/hashicorp/go-cty/cty" "github.com/hashicorp/terraform-plugin-sdk/v2/diag" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/customdiff" @@ -54,8 +53,8 @@ var rowAccessPolicySchema = map[string]*schema.Schema{ "type": { Type: schema.TypeString, Required: true, - DiffSuppressFunc: NormalizeAndCompare(sdk.ToDataType), - ValidateDiagFunc: sdkValidation(sdk.ToDataType), + DiffSuppressFunc: DiffSuppressDataTypes, + ValidateDiagFunc: IsDataTypeValid, Description: dataTypeFieldDescription("The argument type. 
VECTOR data types are not yet supported."), ForceNew: true, }, @@ -179,11 +178,11 @@ func CreateRowAccessPolicy(ctx context.Context, d *schema.ResourceData, meta any args := make([]sdk.CreateRowAccessPolicyArgsRequest, 0) for _, arg := range arguments { v := arg.(map[string]any) - dataType, err := sdk.ToDataType(v["type"].(string)) + dataType, err := datatypes.ParseDataType(v["type"].(string)) if err != nil { return diag.FromErr(err) } - args = append(args, *sdk.NewCreateRowAccessPolicyArgsRequest(v["name"].(string), dataType)) + args = append(args, *sdk.NewCreateRowAccessPolicyArgsRequest(v["name"].(string), sdk.LegacyDataTypeFrom(dataType))) } createRequest := sdk.NewCreateRowAccessPolicyRequest(id, args, rowAccessExpression) diff --git a/pkg/resources/table.go b/pkg/resources/table.go index 017ad799d8..f0d4d77ea4 100644 --- a/pkg/resources/table.go +++ b/pkg/resources/table.go @@ -7,16 +7,15 @@ import ( "strconv" "strings" - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" - "github.com/hashicorp/terraform-plugin-sdk/v2/diag" - "github.com/hashicorp/terraform-plugin-sdk/v2/helper/customdiff" - + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/schemas" - - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/snowflake" + "github.com/hashicorp/terraform-plugin-sdk/v2/diag" + "github.com/hashicorp/terraform-plugin-sdk/v2/helper/customdiff" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation" ) @@ -64,8 +63,8 @@ 
var tableSchema = map[string]*schema.Schema{ Type: schema.TypeString, Required: true, Description: "Column type, e.g. VARIANT. For a full list of column types, see [Summary of Data Types](https://docs.snowflake.com/en/sql-reference/intro-summary-data-types).", - ValidateFunc: dataTypeValidateFunc, - DiffSuppressFunc: DataTypeIssue3007DiffSuppressFunc, + ValidateDiagFunc: IsDataTypeValid, + DiffSuppressFunc: DiffSuppressDataTypes, }, "nullable": { Type: schema.TypeBool, @@ -388,9 +387,13 @@ func getColumns(from interface{}) (to columns) { return to } -func getTableColumnRequest(from interface{}) *sdk.TableColumnRequest { +func getTableColumnRequest(from interface{}) (*sdk.TableColumnRequest, error) { c := from.(map[string]interface{}) _type := c["type"].(string) + dataType, err := datatypes.ParseDataType(_type) + if err != nil { + return nil, err + } nameInQuotes := fmt.Sprintf(`"%v"`, snowflake.EscapeString(c["name"].(string))) request := sdk.NewTableColumnRequest(nameInQuotes, sdk.DataType(_type)) @@ -400,7 +403,7 @@ func getTableColumnRequest(from interface{}) *sdk.TableColumnRequest { if len(_default) == 1 { if c, ok := _default[0].(map[string]interface{})["constant"]; ok { if constant, ok := c.(string); ok && len(constant) > 0 { - if sdk.IsStringType(_type) { + if datatypes.IsTextDataType(dataType) { expression = snowflake.EscapeSnowflakeString(constant) } else { expression = constant @@ -415,7 +418,7 @@ func getTableColumnRequest(from interface{}) *sdk.TableColumnRequest { } if s, ok := _default[0].(map[string]interface{})["sequence"]; ok { - if seq := s.(string); ok && len(seq) > 0 { + if seq, ok2 := s.(string); ok2 && len(seq) > 0 { expression = fmt.Sprintf(`%v.NEXTVAL`, seq) } } @@ -435,22 +438,26 @@ func getTableColumnRequest(from interface{}) *sdk.TableColumnRequest { request.WithMaskingPolicy(sdk.NewColumnMaskingPolicyRequest(sdk.NewSchemaObjectIdentifierFromFullyQualifiedName(maskingPolicy))) } - if sdk.IsStringType(_type) { + if 
datatypes.IsTextDataType(dataType) { request.WithCollate(sdk.String(c["collate"].(string))) } return request. WithNotNull(sdk.Bool(!c["nullable"].(bool))). - WithComment(sdk.String(c["comment"].(string))) + WithComment(sdk.String(c["comment"].(string))), nil } -func getTableColumnRequests(from interface{}) []sdk.TableColumnRequest { +func getTableColumnRequests(from interface{}) ([]sdk.TableColumnRequest, error) { cols := from.([]interface{}) to := make([]sdk.TableColumnRequest, len(cols)) for i, c := range cols { - to[i] = *getTableColumnRequest(c) + cReq, err := getTableColumnRequest(c) + if err != nil { + return nil, err + } + to[i] = *cReq } - return to + return to, nil } type primarykey struct { @@ -577,7 +584,10 @@ func CreateTable(ctx context.Context, d *schema.ResourceData, meta any) diag.Dia name := d.Get("name").(string) id := sdk.NewSchemaObjectIdentifier(databaseName, schemaName, name) - tableColumnRequests := getTableColumnRequests(d.Get("column").([]interface{})) + tableColumnRequests, err := getTableColumnRequests(d.Get("column").([]interface{})) + if err != nil { + return diag.FromErr(err) + } createRequest := sdk.NewCreateTableRequest(id, tableColumnRequests) @@ -620,7 +630,7 @@ func CreateTable(ctx context.Context, d *schema.ResourceData, meta any) diag.Dia createRequest.WithTags(tagAssociationRequests) } - err := client.Tables.Create(ctx, createRequest) + err = client.Tables.Create(ctx, createRequest) if err != nil { return diag.FromErr(fmt.Errorf("error creating table %v err = %w", name, err)) } diff --git a/pkg/resources/table_acceptance_test.go b/pkg/resources/table_acceptance_test.go index caeeb1daaf..8ba4f24a8a 100644 --- a/pkg/resources/table_acceptance_test.go +++ b/pkg/resources/table_acceptance_test.go @@ -15,6 +15,7 @@ import ( "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/provider/resources" 
"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" "github.com/hashicorp/terraform-plugin-testing/config" "github.com/hashicorp/terraform-plugin-testing/helper/resource" "github.com/hashicorp/terraform-plugin-testing/plancheck" @@ -2097,7 +2098,7 @@ func TestAcc_Table_issue3007_textColumn(t *testing.T) { tableId := acc.TestClient().Ids.RandomSchemaObjectIdentifier() resourceName := "snowflake_table.test_table" - defaultVarchar := fmt.Sprintf("VARCHAR(%d)", sdk.DefaultVarcharLength) + defaultVarchar := fmt.Sprintf("VARCHAR(%d)", datatypes.DefaultVarcharLength) resource.Test(t, resource.TestCase{ PreCheck: func() { acc.TestAccPreCheck(t) }, @@ -2170,7 +2171,7 @@ func TestAcc_Table_issue3007_numberColumn(t *testing.T) { tableId := acc.TestClient().Ids.RandomSchemaObjectIdentifier() resourceName := "snowflake_table.test_table" - defaultNumber := fmt.Sprintf("NUMBER(%d,%d)", sdk.DefaultNumberPrecision, sdk.DefaultNumberScale) + defaultNumber := fmt.Sprintf("NUMBER(%d,%d)", datatypes.DefaultNumberPrecision, datatypes.DefaultNumberScale) resource.Test(t, resource.TestCase{ PreCheck: func() { acc.TestAccPreCheck(t) }, diff --git a/pkg/resources/validators.go b/pkg/resources/validators.go index 071a73a33c..d9fcb14edb 100644 --- a/pkg/resources/validators.go +++ b/pkg/resources/validators.go @@ -7,28 +7,12 @@ import ( "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/provider/validators" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" "github.com/hashicorp/go-cty/cty" "github.com/hashicorp/terraform-plugin-sdk/v2/diag" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema" ) -func IsDataType() schema.SchemaValidateFunc { //nolint:staticcheck - return func(value any, key string) (warnings []string, 
errors []error) { - stringValue, ok := value.(string) - if !ok { - errors = append(errors, fmt.Errorf("expected type of %s to be string, got %T", key, value)) - return warnings, errors - } - - _, err := sdk.ToDataType(stringValue) - if err != nil { - errors = append(errors, fmt.Errorf("expected %s to be one of %T values, got %s", key, sdk.DataTypeString, stringValue)) - } - - return warnings, errors - } -} - func IsValidIdentifier[T sdk.AccountObjectIdentifier | sdk.DatabaseObjectIdentifier | sdk.SchemaObjectIdentifier | sdk.TableColumnIdentifier]() schema.SchemaValidateDiagFunc { return validators.IsValidIdentifier[T]() } @@ -98,6 +82,8 @@ func sdkValidation[T any](normalize func(string) (T, error)) schema.SchemaValida return validators.NormalizeValidation(normalize) } +var IsDataTypeValid = sdkValidation(datatypes.ParseDataType) + func isNotEqualTo(notExpectedValue string, errorMessage string) schema.SchemaValidateDiagFunc { return func(value any, path cty.Path) diag.Diagnostics { if value != nil { diff --git a/pkg/resources/validators_test.go b/pkg/resources/validators_test.go index 9125bfd98c..af2408697e 100644 --- a/pkg/resources/validators_test.go +++ b/pkg/resources/validators_test.go @@ -9,48 +9,6 @@ import ( "github.com/stretchr/testify/assert" ) -func TestIsDataType(t *testing.T) { - isDataType := IsDataType() - key := "tag" - - testCases := []struct { - Name string - Value any - Error string - }{ - { - Name: "validation: correct DataType value", - Value: "NUMBER", - }, - { - Name: "validation: correct DataType value in lowercase", - Value: "number", - }, - { - Name: "validation: incorrect DataType value", - Value: "invalid data type", - Error: "expected tag to be one of", - }, - { - Name: "validation: incorrect value type", - Value: 123, - Error: "expected type of tag to be string", - }, - } - - for _, tt := range testCases { - t.Run(tt.Name, func(t *testing.T) { - _, errors := isDataType(tt.Value, key) - if tt.Error != "" { - assert.Len(t, errors, 1) - 
assert.ErrorContains(t, errors[0], tt.Error) - } else { - assert.Len(t, errors, 0) - } - }) - } -} - func Test_sdkValidation(t *testing.T) { genericNormalize := func(value string) (any, error) { if value == "ok" { diff --git a/pkg/sdk/common_types.go b/pkg/sdk/common_types.go index 1627ab9d2a..7a4975a78e 100644 --- a/pkg/sdk/common_types.go +++ b/pkg/sdk/common_types.go @@ -6,6 +6,8 @@ import ( "strconv" "strings" "time" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" ) var ( @@ -91,13 +93,13 @@ func ParseTableColumnSignature(signature string) ([]TableColumnSignature, error) if len(parts) < 2 { return []TableColumnSignature{}, fmt.Errorf("expected argument name and type, got %s", elem) } - dataType, err := ToDataType(parts[len(parts)-1]) + dataType, err := datatypes.ParseDataType(parts[len(parts)-1]) if err != nil { return []TableColumnSignature{}, err } arguments[i] = TableColumnSignature{ Name: strings.Join(parts[:len(parts)-1], " "), - Type: dataType, + Type: LegacyDataTypeFrom(dataType), } } return arguments, nil diff --git a/pkg/sdk/data_types.go b/pkg/sdk/data_types.go deleted file mode 100644 index 40ced4cb83..0000000000 --- a/pkg/sdk/data_types.go +++ /dev/null @@ -1,176 +0,0 @@ -package sdk - -import ( - "fmt" - "slices" - "strconv" - "strings" - - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/logging" - "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/util" -) - -// DataType is based on https://docs.snowflake.com/en/sql-reference/intro-summary-data-types. 
-type DataType string - -var allowedVectorInnerTypes = []DataType{ - DataTypeInt, - DataTypeFloat, -} - -const ( - DataTypeNumber DataType = "NUMBER" - DataTypeInt DataType = "INT" - DataTypeFloat DataType = "FLOAT" - DataTypeVARCHAR DataType = "VARCHAR" - DataTypeString DataType = "STRING" - DataTypeBinary DataType = "BINARY" - DataTypeBoolean DataType = "BOOLEAN" - DataTypeDate DataType = "DATE" - DataTypeTime DataType = "TIME" - DataTypeTimestamp DataType = "TIMESTAMP" - DataTypeTimestampLTZ DataType = "TIMESTAMP_LTZ" - DataTypeTimestampNTZ DataType = "TIMESTAMP_NTZ" - DataTypeTimestampTZ DataType = "TIMESTAMP_TZ" - DataTypeVariant DataType = "VARIANT" - DataTypeObject DataType = "OBJECT" - DataTypeArray DataType = "ARRAY" - DataTypeGeography DataType = "GEOGRAPHY" - DataTypeGeometry DataType = "GEOMETRY" -) - -var ( - DataTypeNumberSynonyms = []string{"NUMBER", "DECIMAL", "NUMERIC", "INT", "INTEGER", "BIGINT", "SMALLINT", "TINYINT", "BYTEINT"} - DataTypeFloatSynonyms = []string{"FLOAT", "FLOAT4", "FLOAT8", "DOUBLE", "DOUBLE PRECISION", "REAL"} - DataTypeVarcharSynonyms = []string{"VARCHAR", "CHAR", "CHARACTER", "STRING", "TEXT"} - DataTypeBinarySynonyms = []string{"BINARY", "VARBINARY"} - DataTypeBooleanSynonyms = []string{"BOOLEAN", "BOOL"} - DataTypeTimestampLTZSynonyms = []string{"TIMESTAMP_LTZ"} - DataTypeTimestampTZSynonyms = []string{"TIMESTAMP_TZ"} - DataTypeTimestampNTZSynonyms = []string{"DATETIME", "TIMESTAMP", "TIMESTAMP_NTZ"} - DataTypeTimeSynonyms = []string{"TIME"} - DataTypeVectorSynonyms = []string{"VECTOR"} -) - -const ( - DefaultNumberPrecision = 38 - DefaultNumberScale = 0 - DefaultVarcharLength = 16777216 -) - -func ToDataType(s string) (DataType, error) { - dType := strings.ToUpper(s) - - switch dType { - case "DATE": - return DataTypeDate, nil - case "VARIANT": - return DataTypeVariant, nil - case "OBJECT": - return DataTypeObject, nil - case "ARRAY": - return DataTypeArray, nil - case "GEOGRAPHY": - return DataTypeGeography, nil - case 
"GEOMETRY": - return DataTypeGeometry, nil - } - - if slices.ContainsFunc(DataTypeNumberSynonyms, func(s string) bool { return strings.HasPrefix(dType, s) }) { - return DataTypeNumber, nil - } - if slices.ContainsFunc(DataTypeFloatSynonyms, func(s string) bool { return strings.HasPrefix(dType, s) }) { - return DataTypeFloat, nil - } - if slices.ContainsFunc(DataTypeVarcharSynonyms, func(s string) bool { return strings.HasPrefix(dType, s) }) { - return DataTypeVARCHAR, nil - } - if slices.ContainsFunc(DataTypeBinarySynonyms, func(s string) bool { return strings.HasPrefix(dType, s) }) { - return DataTypeBinary, nil - } - if slices.Contains(DataTypeBooleanSynonyms, dType) { - return DataTypeBoolean, nil - } - if slices.ContainsFunc(DataTypeTimestampLTZSynonyms, func(s string) bool { return strings.HasPrefix(dType, s) }) { - return DataTypeTimestampLTZ, nil - } - if slices.ContainsFunc(DataTypeTimestampTZSynonyms, func(s string) bool { return strings.HasPrefix(dType, s) }) { - return DataTypeTimestampTZ, nil - } - if slices.ContainsFunc(DataTypeTimestampNTZSynonyms, func(s string) bool { return strings.HasPrefix(dType, s) }) { - return DataTypeTimestampNTZ, nil - } - if slices.ContainsFunc(DataTypeTimeSynonyms, func(s string) bool { return strings.HasPrefix(dType, s) }) { - return DataTypeTime, nil - } - if slices.ContainsFunc(DataTypeVectorSynonyms, func(e string) bool { return strings.HasPrefix(dType, e) }) { - return DataType(dType), nil - } - return "", fmt.Errorf("invalid data type: %s", s) -} - -func IsStringType(_type string) bool { - t := strings.ToUpper(_type) - return strings.HasPrefix(t, "STRING") || - strings.HasPrefix(t, "VARCHAR") || - strings.HasPrefix(t, "CHAR") || - strings.HasPrefix(t, "TEXT") || - strings.HasPrefix(t, "NVARCHAR") || - strings.HasPrefix(t, "NCHAR") -} - -// ParseNumberDataTypeRaw extracts precision and scale from the raw number data type input. 
-// It returns defaults if it can't parse arguments, data type is different, or no arguments were provided. -// TODO [SNOW-1348103 or SNOW-1348106]: visit with functions and procedures rework -func ParseNumberDataTypeRaw(rawDataType string) (int, int) { - r := util.TrimAllPrefixes(strings.TrimSpace(strings.ToUpper(rawDataType)), DataTypeNumberSynonyms...) - r = strings.TrimSpace(r) - if strings.HasPrefix(r, "(") && strings.HasSuffix(r, ")") { - parts := strings.Split(r[1:len(r)-1], ",") - switch l := len(parts); l { - case 1: - precision, err := strconv.Atoi(strings.TrimSpace(parts[0])) - if err == nil { - return precision, DefaultNumberScale - } else { - logging.DebugLogger.Printf(`[DEBUG] Could not parse number precision "%s", err: %v`, parts[0], err) - } - case 2: - precision, err1 := strconv.Atoi(strings.TrimSpace(parts[0])) - scale, err2 := strconv.Atoi(strings.TrimSpace(parts[1])) - if err1 == nil && err2 == nil { - return precision, scale - } else { - logging.DebugLogger.Printf(`[DEBUG] Could not parse number precision "%s" or scale "%s", errs: %v, %v`, parts[0], parts[1], err1, err2) - } - default: - logging.DebugLogger.Printf("[DEBUG] Unexpected length of number arguments") - } - } - logging.DebugLogger.Printf("[DEBUG] Returning default number precision and scale") - return DefaultNumberPrecision, DefaultNumberScale -} - -// ParseVarcharDataTypeRaw extracts length from the raw text data type input. -// It returns default if it can't parse arguments, data type is different, or no length argument was provided. -// TODO [SNOW-1348103 or SNOW-1348106]: visit with functions and procedures rework -func ParseVarcharDataTypeRaw(rawDataType string) int { - r := util.TrimAllPrefixes(strings.TrimSpace(strings.ToUpper(rawDataType)), DataTypeVarcharSynonyms...) 
- r = strings.TrimSpace(r) - if strings.HasPrefix(r, "(") && strings.HasSuffix(r, ")") { - parts := strings.Split(r[1:len(r)-1], ",") - switch l := len(parts); l { - case 1: - length, err := strconv.Atoi(strings.TrimSpace(parts[0])) - if err == nil { - return length - } else { - logging.DebugLogger.Printf(`[DEBUG] Could not parse varchar length "%s", err: %v`, parts[0], err) - } - default: - logging.DebugLogger.Printf("[DEBUG] Unexpected length of varchar arguments") - } - } - logging.DebugLogger.Printf("[DEBUG] Returning default varchar length") - return DefaultVarcharLength -} diff --git a/pkg/sdk/data_types_deprecated.go b/pkg/sdk/data_types_deprecated.go new file mode 100644 index 0000000000..0d0315ad5e --- /dev/null +++ b/pkg/sdk/data_types_deprecated.go @@ -0,0 +1,51 @@ +package sdk + +import ( + "strings" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" +) + +// DataType is based on https://docs.snowflake.com/en/sql-reference/intro-summary-data-types. +type DataType string + +var allowedVectorInnerTypes = []DataType{ + DataTypeInt, + DataTypeFloat, +} + +const ( + DataTypeNumber DataType = "NUMBER" + DataTypeInt DataType = "INT" + DataTypeFloat DataType = "FLOAT" + DataTypeVARCHAR DataType = "VARCHAR" + DataTypeString DataType = "STRING" + DataTypeBinary DataType = "BINARY" + DataTypeBoolean DataType = "BOOLEAN" + DataTypeDate DataType = "DATE" + DataTypeTime DataType = "TIME" + DataTypeTimestampLTZ DataType = "TIMESTAMP_LTZ" + DataTypeTimestampNTZ DataType = "TIMESTAMP_NTZ" + DataTypeTimestampTZ DataType = "TIMESTAMP_TZ" + DataTypeVariant DataType = "VARIANT" + DataTypeObject DataType = "OBJECT" + DataTypeArray DataType = "ARRAY" + DataTypeGeography DataType = "GEOGRAPHY" + DataTypeGeometry DataType = "GEOMETRY" +) + +// IsStringType is a legacy method. datatypes.IsTextDataType should be used instead. 
+// TODO [SNOW-1348114]: remove with tables rework +func IsStringType(_type string) bool { + t := strings.ToUpper(_type) + return strings.HasPrefix(t, "STRING") || + strings.HasPrefix(t, "VARCHAR") || + strings.HasPrefix(t, "CHAR") || + strings.HasPrefix(t, "TEXT") || + strings.HasPrefix(t, "NVARCHAR") || + strings.HasPrefix(t, "NCHAR") +} + +func LegacyDataTypeFrom(newDataType datatypes.DataType) DataType { + return DataType(newDataType.ToLegacyDataTypeSql()) +} diff --git a/pkg/sdk/data_types_deprecated_test.go b/pkg/sdk/data_types_deprecated_test.go new file mode 100644 index 0000000000..ae4445f746 --- /dev/null +++ b/pkg/sdk/data_types_deprecated_test.go @@ -0,0 +1,56 @@ +package sdk + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestIsStringType(t *testing.T) { + type test struct { + input string + want bool + } + + tests := []test{ + // case insensitive. + {input: "STRING", want: true}, + {input: "string", want: true}, + {input: "String", want: true}, + + // varchar types. + {input: "VARCHAR", want: true}, + {input: "NVARCHAR", want: true}, + {input: "NVARCHAR2", want: true}, + {input: "CHAR", want: true}, + {input: "NCHAR", want: true}, + {input: "CHAR VARYING", want: true}, + {input: "NCHAR VARYING", want: true}, + {input: "TEXT", want: true}, + + // with length + {input: "VARCHAR(100)", want: true}, + {input: "NVARCHAR(100)", want: true}, + {input: "NVARCHAR2(100)", want: true}, + {input: "CHAR(100)", want: true}, + {input: "NCHAR(100)", want: true}, + {input: "CHAR VARYING(100)", want: true}, + {input: "NCHAR VARYING(100)", want: true}, + {input: "TEXT(100)", want: true}, + + // binary is not string types. 
+ {input: "binary", want: false}, + {input: "varbinary", want: false}, + + // other types + {input: "boolean", want: false}, + {input: "number", want: false}, + } + + for _, tc := range tests { + t.Run(tc.input, func(t *testing.T) { + got := IsStringType(tc.input) + require.Equal(t, tc.want, got) + }) + } +} diff --git a/pkg/sdk/data_types_test.go b/pkg/sdk/data_types_test.go deleted file mode 100644 index 156e5dc8f2..0000000000 --- a/pkg/sdk/data_types_test.go +++ /dev/null @@ -1,218 +0,0 @@ -package sdk - -import ( - "fmt" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestToDataType(t *testing.T) { - type test struct { - input string - want DataType - } - - tests := []test{ - // case insensitive. - {input: "STRING", want: DataTypeVARCHAR}, - {input: "string", want: DataTypeVARCHAR}, - {input: "String", want: DataTypeVARCHAR}, - - // number types. - {input: "number", want: DataTypeNumber}, - {input: "decimal", want: DataTypeNumber}, - {input: "numeric", want: DataTypeNumber}, - {input: "int", want: DataTypeNumber}, - {input: "integer", want: DataTypeNumber}, - {input: "bigint", want: DataTypeNumber}, - {input: "smallint", want: DataTypeNumber}, - {input: "tinyint", want: DataTypeNumber}, - {input: "byteint", want: DataTypeNumber}, - - // float types. - {input: "float", want: DataTypeFloat}, - {input: "float4", want: DataTypeFloat}, - {input: "float8", want: DataTypeFloat}, - {input: "double", want: DataTypeFloat}, - {input: "double precision", want: DataTypeFloat}, - {input: "real", want: DataTypeFloat}, - - // varchar types. - {input: "varchar", want: DataTypeVARCHAR}, - {input: "char", want: DataTypeVARCHAR}, - {input: "character", want: DataTypeVARCHAR}, - {input: "string", want: DataTypeVARCHAR}, - {input: "text", want: DataTypeVARCHAR}, - - // binary types. 
- {input: "binary", want: DataTypeBinary}, - {input: "varbinary", want: DataTypeBinary}, - {input: "boolean", want: DataTypeBoolean}, - - // boolean types. - {input: "boolean", want: DataTypeBoolean}, - {input: "bool", want: DataTypeBoolean}, - - // timestamp ntz types. - {input: "datetime", want: DataTypeTimestampNTZ}, - {input: "timestamp", want: DataTypeTimestampNTZ}, - {input: "timestamp_ntz", want: DataTypeTimestampNTZ}, - - // timestamp tz types. - {input: "timestamp_tz", want: DataTypeTimestampTZ}, - {input: "timestamp_tz(9)", want: DataTypeTimestampTZ}, - - // timestamp ltz types. - {input: "timestamp_ltz", want: DataTypeTimestampLTZ}, - {input: "timestamp_ltz(9)", want: DataTypeTimestampLTZ}, - - // time types. - {input: "time", want: DataTypeTime}, - {input: "time(9)", want: DataTypeTime}, - - // all othertypes - {input: "date", want: DataTypeDate}, - {input: "variant", want: DataTypeVariant}, - {input: "object", want: DataTypeObject}, - {input: "array", want: DataTypeArray}, - {input: "geography", want: DataTypeGeography}, - {input: "geometry", want: DataTypeGeometry}, - {input: "VECTOR(INT, 10)", want: "VECTOR(INT, 10)"}, - {input: "VECTOR(INT, 20)", want: "VECTOR(INT, 20)"}, - {input: "VECTOR(FLOAT, 10)", want: "VECTOR(FLOAT, 10)"}, - {input: "VECTOR(FLOAT, 20)", want: "VECTOR(FLOAT, 20)"}, - } - - for _, tc := range tests { - t.Run(tc.input, func(t *testing.T) { - got, err := ToDataType(tc.input) - require.NoError(t, err) - require.Equal(t, tc.want, got) - }) - } -} - -func TestIsStringType(t *testing.T) { - type test struct { - input string - want bool - } - - tests := []test{ - // case insensitive. - {input: "STRING", want: true}, - {input: "string", want: true}, - {input: "String", want: true}, - - // varchar types. 
- {input: "VARCHAR", want: true}, - {input: "NVARCHAR", want: true}, - {input: "NVARCHAR2", want: true}, - {input: "CHAR", want: true}, - {input: "NCHAR", want: true}, - {input: "CHAR VARYING", want: true}, - {input: "NCHAR VARYING", want: true}, - {input: "TEXT", want: true}, - - // with length - {input: "VARCHAR(100)", want: true}, - {input: "NVARCHAR(100)", want: true}, - {input: "NVARCHAR2(100)", want: true}, - {input: "CHAR(100)", want: true}, - {input: "NCHAR(100)", want: true}, - {input: "CHAR VARYING(100)", want: true}, - {input: "NCHAR VARYING(100)", want: true}, - {input: "TEXT(100)", want: true}, - - // binary is not string types. - {input: "binary", want: false}, - {input: "varbinary", want: false}, - - // other types - {input: "boolean", want: false}, - {input: "number", want: false}, - } - - for _, tc := range tests { - t.Run(tc.input, func(t *testing.T) { - got := IsStringType(tc.input) - require.Equal(t, tc.want, got) - }) - } -} - -func Test_ParseNumberDataTypeRaw(t *testing.T) { - type test struct { - input string - expectedPrecision int - expectedScale int - } - defaults := func(input string) test { - return test{input: input, expectedPrecision: DefaultNumberPrecision, expectedScale: DefaultNumberScale} - } - - tests := []test{ - {input: "NUMBER(30)", expectedPrecision: 30, expectedScale: DefaultNumberScale}, - {input: "NUMBER(30, 2)", expectedPrecision: 30, expectedScale: 2}, - {input: "decimal(30, 2)", expectedPrecision: 30, expectedScale: 2}, - {input: "NUMBER( 30 , 2 )", expectedPrecision: 30, expectedScale: 2}, - {input: " NUMBER ( 30 , 2 ) ", expectedPrecision: 30, expectedScale: 2}, - - // returns defaults if it can't parse arguments, data type is different, or no arguments were provided - defaults("VARCHAR(1, 2)"), - defaults("VARCHAR(1)"), - defaults("VARCHAR"), - defaults("NUMBER"), - defaults("NUMBER()"), - defaults("NUMBER(x)"), - defaults(fmt.Sprintf("NUMBER(%d)", DefaultNumberPrecision)), - defaults(fmt.Sprintf("NUMBER(%d, x)", 
DefaultNumberPrecision)), - defaults(fmt.Sprintf("NUMBER(x, %d)", DefaultNumberScale)), - defaults("NUMBER(1, 2, 3)"), - } - - for _, tc := range tests { - tc := tc - t.Run(tc.input, func(t *testing.T) { - precision, scale := ParseNumberDataTypeRaw(tc.input) - assert.Equal(t, tc.expectedPrecision, precision) - assert.Equal(t, tc.expectedScale, scale) - }) - } -} - -func Test_ParseVarcharDataTypeRaw(t *testing.T) { - type test struct { - input string - expectedLength int - } - defaults := func(input string) test { - return test{input: input, expectedLength: DefaultVarcharLength} - } - - tests := []test{ - {input: "VARCHAR(30)", expectedLength: 30}, - {input: "text(30)", expectedLength: 30}, - {input: "VARCHAR( 30 )", expectedLength: 30}, - {input: " VARCHAR ( 30 ) ", expectedLength: 30}, - - // returns defaults if it can't parse arguments, data type is different, or no arguments were provided - defaults("VARCHAR(1, 2)"), - defaults("VARCHAR(x)"), - defaults("VARCHAR"), - defaults("NUMBER"), - defaults("NUMBER()"), - defaults("NUMBER(x)"), - defaults(fmt.Sprintf("VARCHAR(%d)", DefaultVarcharLength)), - } - - for _, tc := range tests { - tc := tc - t.Run(tc.input, func(t *testing.T) { - length := ParseVarcharDataTypeRaw(tc.input) - assert.Equal(t, tc.expectedLength, length) - }) - } -} diff --git a/pkg/sdk/datatypes/array.go b/pkg/sdk/datatypes/array.go new file mode 100644 index 0000000000..eb7247f6e6 --- /dev/null +++ b/pkg/sdk/datatypes/array.go @@ -0,0 +1,21 @@ +package datatypes + +// ArrayDataType is based on https://docs.snowflake.com/en/sql-reference/data-types-semistructured#array +// It does not have synonyms. It does not have any attributes. 
+type ArrayDataType struct { + underlyingType string +} + +func (t *ArrayDataType) ToSql() string { + return t.underlyingType +} + +func (t *ArrayDataType) ToLegacyDataTypeSql() string { + return ArrayLegacyDataType +} + +var ArrayDataTypeSynonyms = []string{ArrayLegacyDataType} + +func parseArrayDataTypeRaw(raw sanitizedDataTypeRaw) (*ArrayDataType, error) { + return &ArrayDataType{raw.matchedByType}, nil +} diff --git a/pkg/sdk/datatypes/binary.go b/pkg/sdk/datatypes/binary.go new file mode 100644 index 0000000000..c50dba0570 --- /dev/null +++ b/pkg/sdk/datatypes/binary.go @@ -0,0 +1,51 @@ +package datatypes + +import ( + "fmt" + "strconv" + "strings" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/logging" +) + +const DefaultBinarySize = 8388608 + +// BinaryDataType is based on https://docs.snowflake.com/en/sql-reference/data-types-text#data-types-for-binary-strings +// It does have synonyms that allow specifying size. +type BinaryDataType struct { + size int + underlyingType string +} + +func (t *BinaryDataType) ToSql() string { + return fmt.Sprintf("%s(%d)", t.underlyingType, t.size) +} + +func (t *BinaryDataType) ToLegacyDataTypeSql() string { + return BinaryLegacyDataType +} + +var BinaryDataTypeSynonyms = []string{BinaryLegacyDataType, "VARBINARY"} + +func parseBinaryDataTypeRaw(raw sanitizedDataTypeRaw) (*BinaryDataType, error) { + r := strings.TrimSpace(strings.TrimPrefix(raw.raw, raw.matchedByType)) + if r == "" { + logging.DebugLogger.Printf("[DEBUG] Returning default size for binary") + return &BinaryDataType{DefaultBinarySize, raw.matchedByType}, nil + } + if !strings.HasPrefix(r, "(") || !strings.HasSuffix(r, ")") { + logging.DebugLogger.Printf(`binary %s could not be parsed, use "%s(size)" format`, raw.raw, raw.matchedByType) + return nil, fmt.Errorf(`binary %s could not be parsed, use "%s(size)" format`, raw.raw, raw.matchedByType) + } + sizeRaw := r[1 : len(r)-1] + size, err := strconv.Atoi(strings.TrimSpace(sizeRaw)) + if 
err != nil { + logging.DebugLogger.Printf(`[DEBUG] Could not parse binary size "%s", err: %v`, sizeRaw, err) + return nil, fmt.Errorf(`could not parse the binary's size: "%s", err: %w`, sizeRaw, err) + } + return &BinaryDataType{size, raw.matchedByType}, nil +} + +func areBinaryDataTypesTheSame(a, b *BinaryDataType) bool { + return a.size == b.size +} diff --git a/pkg/sdk/datatypes/boolean.go b/pkg/sdk/datatypes/boolean.go new file mode 100644 index 0000000000..4e84979f40 --- /dev/null +++ b/pkg/sdk/datatypes/boolean.go @@ -0,0 +1,21 @@ +package datatypes + +// BooleanDataType is based on https://docs.snowflake.com/en/sql-reference/data-types-logical +// It does not have synonyms. It does not have any attributes. +type BooleanDataType struct { + underlyingType string +} + +func (t *BooleanDataType) ToSql() string { + return t.underlyingType +} + +func (t *BooleanDataType) ToLegacyDataTypeSql() string { + return BooleanLegacyDataType +} + +var BooleanDataTypeSynonyms = []string{BooleanLegacyDataType} + +func parseBooleanDataTypeRaw(raw sanitizedDataTypeRaw) (*BooleanDataType, error) { + return &BooleanDataType{raw.matchedByType}, nil +} diff --git a/pkg/sdk/datatypes/data_types.go b/pkg/sdk/datatypes/data_types.go new file mode 100644 index 0000000000..e1c0065855 --- /dev/null +++ b/pkg/sdk/datatypes/data_types.go @@ -0,0 +1,149 @@ +package datatypes + +import ( + "fmt" + "reflect" + "slices" + "strings" +) + +// TODO [SNOW-1843440]: generalize definitions for different types; generalize the ParseDataType function +// TODO [SNOW-1843440]: generalize implementation in types (i.e. 
the internal struct implementing ToLegacyDataTypeSql and containing the underlyingType) +// TODO [SNOW-1843440]: consider known/unknown to use Snowflake defaults and allow better handling in terraform resources +// TODO [SNOW-1843440]: replace old DataTypes + +// DataType is the common interface that represents all Snowflake datatypes documented in https://docs.snowflake.com/en/sql-reference/intro-summary-data-types. +type DataType interface { + ToSql() string + ToLegacyDataTypeSql() string +} + +type sanitizedDataTypeRaw struct { + raw string + matchedByType string +} + +// ParseDataType is the entry point to get the implementation of the DataType from input raw string. +// TODO [SNOW-1843440]: order currently matters (e.g. HasPrefix(TIME) can match also TIMESTAMP*, make the checks more precise and order-independent) +func ParseDataType(raw string) (DataType, error) { + dataTypeRaw := strings.TrimSpace(strings.ToUpper(raw)) + + if idx := slices.IndexFunc(AllNumberDataTypes, func(s string) bool { return strings.HasPrefix(dataTypeRaw, s) }); idx >= 0 { + return parseNumberDataTypeRaw(sanitizedDataTypeRaw{dataTypeRaw, AllNumberDataTypes[idx]}) + } + if slices.Contains(FloatDataTypeSynonyms, dataTypeRaw) { + return parseFloatDataTypeRaw(sanitizedDataTypeRaw{dataTypeRaw, dataTypeRaw}) + } + if idx := slices.IndexFunc(AllTextDataTypes, func(s string) bool { return strings.HasPrefix(dataTypeRaw, s) }); idx >= 0 { + return parseTextDataTypeRaw(sanitizedDataTypeRaw{dataTypeRaw, AllTextDataTypes[idx]}) + } + if idx := slices.IndexFunc(BinaryDataTypeSynonyms, func(s string) bool { return strings.HasPrefix(dataTypeRaw, s) }); idx >= 0 { + return parseBinaryDataTypeRaw(sanitizedDataTypeRaw{dataTypeRaw, BinaryDataTypeSynonyms[idx]}) + } + if slices.Contains(BooleanDataTypeSynonyms, dataTypeRaw) { + return parseBooleanDataTypeRaw(sanitizedDataTypeRaw{dataTypeRaw, dataTypeRaw}) + } + if slices.Contains(DateDataTypeSynonyms, dataTypeRaw) { + return 
parseDateDataTypeRaw(sanitizedDataTypeRaw{dataTypeRaw, dataTypeRaw}) + } + if idx := slices.IndexFunc(TimestampLtzDataTypeSynonyms, func(s string) bool { return strings.HasPrefix(dataTypeRaw, s) }); idx >= 0 { + return parseTimestampLtzDataTypeRaw(sanitizedDataTypeRaw{dataTypeRaw, TimestampLtzDataTypeSynonyms[idx]}) + } + if idx := slices.IndexFunc(TimestampNtzDataTypeSynonyms, func(s string) bool { return strings.HasPrefix(dataTypeRaw, s) }); idx >= 0 { + return parseTimestampNtzDataTypeRaw(sanitizedDataTypeRaw{dataTypeRaw, TimestampNtzDataTypeSynonyms[idx]}) + } + if idx := slices.IndexFunc(TimestampTzDataTypeSynonyms, func(s string) bool { return strings.HasPrefix(dataTypeRaw, s) }); idx >= 0 { + return parseTimestampTzDataTypeRaw(sanitizedDataTypeRaw{dataTypeRaw, TimestampTzDataTypeSynonyms[idx]}) + } + if idx := slices.IndexFunc(TimeDataTypeSynonyms, func(s string) bool { return strings.HasPrefix(dataTypeRaw, s) }); idx >= 0 { + return parseTimeDataTypeRaw(sanitizedDataTypeRaw{dataTypeRaw, TimeDataTypeSynonyms[idx]}) + } + if slices.Contains(VariantDataTypeSynonyms, dataTypeRaw) { + return parseVariantDataTypeRaw(sanitizedDataTypeRaw{dataTypeRaw, dataTypeRaw}) + } + if slices.Contains(ObjectDataTypeSynonyms, dataTypeRaw) { + return parseObjectDataTypeRaw(sanitizedDataTypeRaw{dataTypeRaw, dataTypeRaw}) + } + if slices.Contains(ArrayDataTypeSynonyms, dataTypeRaw) { + return parseArrayDataTypeRaw(sanitizedDataTypeRaw{dataTypeRaw, dataTypeRaw}) + } + if slices.Contains(GeographyDataTypeSynonyms, dataTypeRaw) { + return parseGeographyDataTypeRaw(sanitizedDataTypeRaw{dataTypeRaw, dataTypeRaw}) + } + if slices.Contains(GeometryDataTypeSynonyms, dataTypeRaw) { + return parseGeometryDataTypeRaw(sanitizedDataTypeRaw{dataTypeRaw, dataTypeRaw}) + } + if idx := slices.IndexFunc(VectorDataTypeSynonyms, func(s string) bool { return strings.HasPrefix(dataTypeRaw, s) }); idx >= 0 { + return parseVectorDataTypeRaw(sanitizedDataTypeRaw{dataTypeRaw, VectorDataTypeSynonyms[idx]}) 
+ } + + return nil, fmt.Errorf("invalid data type: %s", raw) +} + +// AreTheSame compares any two data types. +// If both data types are nil it returns true. +// If only one data type is nil it returns false. +// It returns false for different underlying types. +// For the same type it performs type-specific comparison. +func AreTheSame(a DataType, b DataType) bool { + if a == nil && b == nil { + return true + } + if a == nil && b != nil || a != nil && b == nil { + return false + } + if reflect.TypeOf(a) != reflect.TypeOf(b) { + return false + } + switch v := a.(type) { + case *ArrayDataType: + return castSuccessfully(v, b, noArgsDataTypesAreTheSame) + case *BinaryDataType: + return castSuccessfully(v, b, areBinaryDataTypesTheSame) + case *BooleanDataType: + return castSuccessfully(v, b, noArgsDataTypesAreTheSame) + case *DateDataType: + return castSuccessfully(v, b, noArgsDataTypesAreTheSame) + case *FloatDataType: + return castSuccessfully(v, b, noArgsDataTypesAreTheSame) + case *GeographyDataType: + return castSuccessfully(v, b, noArgsDataTypesAreTheSame) + case *GeometryDataType: + return castSuccessfully(v, b, noArgsDataTypesAreTheSame) + case *NumberDataType: + return castSuccessfully(v, b, areNumberDataTypesTheSame) + case *ObjectDataType: + return castSuccessfully(v, b, noArgsDataTypesAreTheSame) + case *TextDataType: + return castSuccessfully(v, b, areTextDataTypesTheSame) + case *TimeDataType: + return castSuccessfully(v, b, areTimeDataTypesTheSame) + case *TimestampLtzDataType: + return castSuccessfully(v, b, areTimestampLtzDataTypesTheSame) + case *TimestampNtzDataType: + return castSuccessfully(v, b, areTimestampNtzDataTypesTheSame) + case *TimestampTzDataType: + return castSuccessfully(v, b, areTimestampTzDataTypesTheSame) + case *VariantDataType: + return castSuccessfully(v, b, noArgsDataTypesAreTheSame) + case *VectorDataType: + return castSuccessfully(v, b, areVectorDataTypesTheSame) + } + return false +} + +func IsTextDataType(a DataType) bool { + 
_, ok := a.(*TextDataType) + return ok +} + +func castSuccessfully[T any](a T, b DataType, invoke func(a T, b T) bool) bool { + if dCasted, ok := b.(T); ok { + return invoke(a, dCasted) + } + return false +} + +func noArgsDataTypesAreTheSame[T DataType](_ T, _ T) bool { + return true +} diff --git a/pkg/sdk/datatypes/data_types_test.go b/pkg/sdk/datatypes/data_types_test.go new file mode 100644 index 0000000000..21525fded8 --- /dev/null +++ b/pkg/sdk/datatypes/data_types_test.go @@ -0,0 +1,1148 @@ +package datatypes + +import ( + "fmt" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_ParseDataType_Number(t *testing.T) { + type test struct { + input string + expectedPrecision int + expectedScale int + expectedUnderlyingType string + } + defaults := func(input string) test { + return test{ + input: input, + expectedPrecision: DefaultNumberPrecision, + expectedScale: DefaultNumberScale, + expectedUnderlyingType: strings.TrimSpace(strings.ToUpper(input)), + } + } + negative := func(input string) test { + return test{input: input} + } + + positiveTestCases := []test{ + {input: "NUMBER(30)", expectedPrecision: 30, expectedScale: DefaultNumberScale, expectedUnderlyingType: "NUMBER"}, + {input: "NUMBER(30, 2)", expectedPrecision: 30, expectedScale: 2, expectedUnderlyingType: "NUMBER"}, + {input: "dec(30)", expectedPrecision: 30, expectedScale: DefaultNumberScale, expectedUnderlyingType: "DEC"}, + {input: "dec(30, 2)", expectedPrecision: 30, expectedScale: 2, expectedUnderlyingType: "DEC"}, + {input: "decimal(30)", expectedPrecision: 30, expectedScale: DefaultNumberScale, expectedUnderlyingType: "DECIMAL"}, + {input: "decimal(30, 2)", expectedPrecision: 30, expectedScale: 2, expectedUnderlyingType: "DECIMAL"}, + {input: "NuMeRiC(30)", expectedPrecision: 30, expectedScale: DefaultNumberScale, expectedUnderlyingType: "NUMERIC"}, + {input: "NuMeRiC(30, 2)", expectedPrecision: 30, expectedScale: 2, 
expectedUnderlyingType: "NUMERIC"}, + {input: "NUMBER( 30 , 2 )", expectedPrecision: 30, expectedScale: 2, expectedUnderlyingType: "NUMBER"}, + {input: " NUMBER ( 30 , 2 ) ", expectedPrecision: 30, expectedScale: 2, expectedUnderlyingType: "NUMBER"}, + {input: fmt.Sprintf("NUMBER(%d)", DefaultNumberPrecision), expectedPrecision: DefaultNumberPrecision, expectedScale: DefaultNumberScale, expectedUnderlyingType: "NUMBER"}, + {input: fmt.Sprintf("NUMBER(%d, %d)", DefaultNumberPrecision, DefaultNumberScale), expectedPrecision: DefaultNumberPrecision, expectedScale: DefaultNumberScale, expectedUnderlyingType: "NUMBER"}, + + defaults("NUMBER"), + defaults("DEC"), + defaults("DECIMAL"), + defaults("NUMERIC"), + defaults(" NUMBER "), + + defaults("INT"), + defaults("INTEGER"), + defaults("BIGINT"), + defaults("SMALLINT"), + defaults("TINYINT"), + defaults("BYTEINT"), + defaults("int"), + defaults("integer"), + defaults("bigint"), + defaults("smallint"), + defaults("tinyint"), + defaults("byteint"), + } + + negativeTestCases := []test{ + negative("other(1, 2)"), + negative("other(1)"), + negative("other"), + negative("NUMBER()"), + negative("NUMBER(x)"), + negative(fmt.Sprintf("NUMBER(%d, x)", DefaultNumberPrecision)), + negative(fmt.Sprintf("NUMBER(x, %d)", DefaultNumberScale)), + negative("NUMBER(1, 2, 3)"), + negative("NUMBER("), + negative("NUMBER)"), + negative("NUM BER"), + negative("INT(30)"), + negative("INT(30, 2)"), + } + + for _, tc := range positiveTestCases { + tc := tc + t.Run(tc.input, func(t *testing.T) { + parsed, err := ParseDataType(tc.input) + + require.NoError(t, err) + require.IsType(t, &NumberDataType{}, parsed) + + assert.Equal(t, tc.expectedPrecision, parsed.(*NumberDataType).precision) + assert.Equal(t, tc.expectedScale, parsed.(*NumberDataType).scale) + assert.Equal(t, tc.expectedUnderlyingType, parsed.(*NumberDataType).underlyingType) + + assert.Equal(t, NumberLegacyDataType, parsed.ToLegacyDataTypeSql()) + assert.Equal(t, fmt.Sprintf("%s(%d, 
%d)", parsed.(*NumberDataType).underlyingType, parsed.(*NumberDataType).precision, parsed.(*NumberDataType).scale), parsed.ToSql()) + }) + } + + for _, tc := range negativeTestCases { + tc := tc + t.Run(tc.input, func(t *testing.T) { + parsed, err := ParseDataType(tc.input) + + require.Error(t, err) + require.Nil(t, parsed) + }) + } +} + +func Test_ParseDataType_Float(t *testing.T) { + type test struct { + input string + expectedUnderlyingType string + } + defaults := func(input string) test { + return test{ + input: input, + expectedUnderlyingType: strings.TrimSpace(strings.ToUpper(input)), + } + } + negative := func(input string) test { + return test{input: input} + } + + positiveTestCases := []test{ + defaults(" FLOAT "), + defaults("FLOAT"), + defaults("FLOAT4"), + defaults("FLOAT8"), + defaults("DOUBLE PRECISION"), + defaults("DOUBLE"), + defaults("REAL"), + defaults("float"), + defaults("float4"), + defaults("float8"), + defaults("double precision"), + defaults("double"), + defaults("real"), + } + + negativeTestCases := []test{ + negative("FLOAT(38, 0)"), + negative("FLOAT(38, 2)"), + negative("FLOAT(38)"), + negative("FLOAT()"), + negative("F L O A T"), + negative("other"), + } + + for _, tc := range positiveTestCases { + tc := tc + t.Run(tc.input, func(t *testing.T) { + parsed, err := ParseDataType(tc.input) + + require.NoError(t, err) + require.IsType(t, &FloatDataType{}, parsed) + + assert.Equal(t, tc.expectedUnderlyingType, parsed.(*FloatDataType).underlyingType) + + assert.Equal(t, FloatLegacyDataType, parsed.ToLegacyDataTypeSql()) + assert.Equal(t, tc.expectedUnderlyingType, parsed.ToSql()) + }) + } + + for _, tc := range negativeTestCases { + tc := tc + t.Run(tc.input, func(t *testing.T) { + parsed, err := ParseDataType(tc.input) + + require.Error(t, err) + require.Nil(t, parsed) + }) + } +} + +func Test_ParseDataType_Text(t *testing.T) { + type test struct { + input string + expectedLength int + expectedUnderlyingType string + } + defaultsVarchar := 
func(input string) test { + return test{ + input: input, + expectedLength: DefaultVarcharLength, + expectedUnderlyingType: strings.TrimSpace(strings.ToUpper(input)), + } + } + defaultsChar := func(input string) test { + return test{ + input: input, + expectedLength: DefaultCharLength, + expectedUnderlyingType: strings.TrimSpace(strings.ToUpper(input)), + } + } + negative := func(input string) test { + return test{input: input} + } + + positiveTestCases := []test{ + {input: "VARCHAR(30)", expectedLength: 30, expectedUnderlyingType: "VARCHAR"}, + {input: "string(30)", expectedLength: 30, expectedUnderlyingType: "STRING"}, + {input: "VARCHAR( 30 )", expectedLength: 30, expectedUnderlyingType: "VARCHAR"}, + {input: " VARCHAR ( 30 ) ", expectedLength: 30, expectedUnderlyingType: "VARCHAR"}, + {input: fmt.Sprintf("VARCHAR(%d)", DefaultVarcharLength), expectedLength: DefaultVarcharLength, expectedUnderlyingType: "VARCHAR"}, + + {input: "CHAR(30)", expectedLength: 30, expectedUnderlyingType: "CHAR"}, + {input: "character(30)", expectedLength: 30, expectedUnderlyingType: "CHARACTER"}, + {input: "CHAR( 30 )", expectedLength: 30, expectedUnderlyingType: "CHAR"}, + {input: " CHAR ( 30 ) ", expectedLength: 30, expectedUnderlyingType: "CHAR"}, + {input: fmt.Sprintf("CHAR(%d)", DefaultCharLength), expectedLength: DefaultCharLength, expectedUnderlyingType: "CHAR"}, + + defaultsVarchar(" VARCHAR "), + defaultsVarchar("VARCHAR"), + defaultsVarchar("STRING"), + defaultsVarchar("TEXT"), + defaultsVarchar("NVARCHAR"), + defaultsVarchar("NVARCHAR2"), + defaultsVarchar("CHAR VARYING"), + defaultsVarchar("NCHAR VARYING"), + defaultsVarchar("varchar"), + defaultsVarchar("string"), + defaultsVarchar("text"), + defaultsVarchar("nvarchar"), + defaultsVarchar("nvarchar2"), + defaultsVarchar("char varying"), + defaultsVarchar("nchar varying"), + + defaultsChar(" CHAR "), + defaultsChar("CHAR"), + defaultsChar("CHARACTER"), + defaultsChar("NCHAR"), + defaultsChar("char"), + 
defaultsChar("character"), + defaultsChar("nchar"), + } + + negativeTestCases := []test{ + negative("other(1, 2)"), + negative("other(1)"), + negative("other"), + negative("VARCHAR()"), + negative("VARCHAR(x)"), + negative("VARCHAR( )"), + negative("CHAR()"), + negative("CHAR(x)"), + negative("CHAR( )"), + negative("VARCHAR(1, 2)"), + negative("VARCHAR("), + negative("VARCHAR)"), + negative("VAR CHAR"), + negative("CHAR(1, 2)"), + negative("CHAR("), + negative("CHAR)"), + negative("CH AR"), + } + + for _, tc := range positiveTestCases { + tc := tc + t.Run(tc.input, func(t *testing.T) { + parsed, err := ParseDataType(tc.input) + + require.NoError(t, err) + require.IsType(t, &TextDataType{}, parsed) + + assert.Equal(t, tc.expectedLength, parsed.(*TextDataType).length) + assert.Equal(t, tc.expectedUnderlyingType, parsed.(*TextDataType).underlyingType) + + assert.Equal(t, VarcharLegacyDataType, parsed.ToLegacyDataTypeSql()) + assert.Equal(t, fmt.Sprintf("%s(%d)", parsed.(*TextDataType).underlyingType, parsed.(*TextDataType).length), parsed.ToSql()) + }) + } + + for _, tc := range negativeTestCases { + tc := tc + t.Run(tc.input, func(t *testing.T) { + parsed, err := ParseDataType(tc.input) + + require.Error(t, err) + require.Nil(t, parsed) + }) + } +} + +func Test_ParseDataType_Binary(t *testing.T) { + type test struct { + input string + expectedSize int + expectedUnderlyingType string + } + defaults := func(input string) test { + return test{ + input: input, + expectedSize: DefaultBinarySize, + expectedUnderlyingType: strings.TrimSpace(strings.ToUpper(input)), + } + } + negative := func(input string) test { + return test{input: input} + } + + positiveTestCases := []test{ + {input: "BINARY(30)", expectedSize: 30, expectedUnderlyingType: "BINARY"}, + {input: "varbinary(30)", expectedSize: 30, expectedUnderlyingType: "VARBINARY"}, + {input: "BINARY( 30 )", expectedSize: 30, expectedUnderlyingType: "BINARY"}, + {input: " BINARY ( 30 ) ", expectedSize: 30, 
expectedUnderlyingType: "BINARY"}, + {input: fmt.Sprintf("BINARY(%d)", DefaultBinarySize), expectedSize: DefaultBinarySize, expectedUnderlyingType: "BINARY"}, + + defaults(" BINARY "), + defaults("BINARY"), + defaults("VARBINARY"), + defaults("binary"), + defaults("varbinary"), + } + + negativeTestCases := []test{ + negative("other(1, 2)"), + negative("other(1)"), + negative("other"), + negative("BINARY()"), + negative("BINARY(x)"), + negative("BINARY( )"), + negative("BINARY(1, 2)"), + negative("BINARY("), + negative("BINARY)"), + negative("BIN ARY"), + } + + for _, tc := range positiveTestCases { + tc := tc + t.Run(tc.input, func(t *testing.T) { + parsed, err := ParseDataType(tc.input) + + require.NoError(t, err) + require.IsType(t, &BinaryDataType{}, parsed) + + assert.Equal(t, tc.expectedSize, parsed.(*BinaryDataType).size) + assert.Equal(t, tc.expectedUnderlyingType, parsed.(*BinaryDataType).underlyingType) + + assert.Equal(t, BinaryLegacyDataType, parsed.ToLegacyDataTypeSql()) + assert.Equal(t, fmt.Sprintf("%s(%d)", parsed.(*BinaryDataType).underlyingType, parsed.(*BinaryDataType).size), parsed.ToSql()) + }) + } + + for _, tc := range negativeTestCases { + tc := tc + t.Run(tc.input, func(t *testing.T) { + parsed, err := ParseDataType(tc.input) + + require.Error(t, err) + require.Nil(t, parsed) + }) + } +} + +func Test_ParseDataType_Boolean(t *testing.T) { + type test struct { + input string + expectedUnderlyingType string + } + defaults := func(input string) test { + return test{ + input: input, + expectedUnderlyingType: strings.TrimSpace(strings.ToUpper(input)), + } + } + negative := func(input string) test { + return test{input: input} + } + + positiveTestCases := []test{ + defaults(" BOOLEAN "), + defaults("BOOLEAN"), + defaults("boolean"), + } + + negativeTestCases := []test{ + negative("BOOLEAN(38, 0)"), + negative("BOOLEAN(38, 2)"), + negative("BOOLEAN(38)"), + negative("BOOLEAN()"), + negative("BOOL"), + negative("bool"), + negative("B O O L E A N"), + 
negative("other"), + } + + for _, tc := range positiveTestCases { + tc := tc + t.Run(tc.input, func(t *testing.T) { + parsed, err := ParseDataType(tc.input) + + require.NoError(t, err) + require.IsType(t, &BooleanDataType{}, parsed) + + assert.Equal(t, tc.expectedUnderlyingType, parsed.(*BooleanDataType).underlyingType) + + assert.Equal(t, BooleanLegacyDataType, parsed.ToLegacyDataTypeSql()) + assert.Equal(t, tc.expectedUnderlyingType, parsed.ToSql()) + }) + } + + for _, tc := range negativeTestCases { + tc := tc + t.Run(tc.input, func(t *testing.T) { + parsed, err := ParseDataType(tc.input) + + require.Error(t, err) + require.Nil(t, parsed) + }) + } +} + +func Test_ParseDataType_Date(t *testing.T) { + type test struct { + input string + expectedUnderlyingType string + } + defaults := func(input string) test { + return test{ + input: input, + expectedUnderlyingType: strings.TrimSpace(strings.ToUpper(input)), + } + } + negative := func(input string) test { + return test{input: input} + } + + positiveTestCases := []test{ + defaults(" DATE "), + defaults("DATE"), + defaults("date"), + } + + negativeTestCases := []test{ + negative("DATE(38, 0)"), + negative("DATE(38, 2)"), + negative("DATE(38)"), + negative("DATE()"), + negative("D A T E"), + negative("other"), + } + + for _, tc := range positiveTestCases { + tc := tc + t.Run(tc.input, func(t *testing.T) { + parsed, err := ParseDataType(tc.input) + + require.NoError(t, err) + require.IsType(t, &DateDataType{}, parsed) + + assert.Equal(t, tc.expectedUnderlyingType, parsed.(*DateDataType).underlyingType) + + assert.Equal(t, DateLegacyDataType, parsed.ToLegacyDataTypeSql()) + assert.Equal(t, tc.expectedUnderlyingType, parsed.ToSql()) + }) + } + + for _, tc := range negativeTestCases { + tc := tc + t.Run(tc.input, func(t *testing.T) { + parsed, err := ParseDataType(tc.input) + + require.Error(t, err) + require.Nil(t, parsed) + }) + } +} + +func Test_ParseDataType_Time(t *testing.T) { + type test struct { + input string + 
expectedPrecision int + expectedUnderlyingType string + } + defaults := func(input string) test { + return test{ + input: input, + expectedPrecision: DefaultTimePrecision, + expectedUnderlyingType: strings.TrimSpace(strings.ToUpper(input)), + } + } + negative := func(input string) test { + return test{input: input} + } + + positiveTestCases := []test{ + defaults(" TIME "), + defaults("TIME"), + defaults("time"), + {input: "TIME(5)", expectedPrecision: 5, expectedUnderlyingType: "TIME"}, + {input: "time(5)", expectedPrecision: 5, expectedUnderlyingType: "TIME"}, + } + + negativeTestCases := []test{ + negative("TIME(38, 0)"), + negative("TIME(38, 2)"), + negative("TIME()"), + negative("T I M E"), + negative("other"), + } + + for _, tc := range positiveTestCases { + tc := tc + t.Run(tc.input, func(t *testing.T) { + parsed, err := ParseDataType(tc.input) + + require.NoError(t, err) + require.IsType(t, &TimeDataType{}, parsed) + + assert.Equal(t, tc.expectedUnderlyingType, parsed.(*TimeDataType).underlyingType) + assert.Equal(t, tc.expectedPrecision, parsed.(*TimeDataType).precision) + + assert.Equal(t, TimeLegacyDataType, parsed.ToLegacyDataTypeSql()) + assert.Equal(t, fmt.Sprintf("%s(%d)", tc.expectedUnderlyingType, tc.expectedPrecision), parsed.ToSql()) + }) + } + + for _, tc := range negativeTestCases { + tc := tc + t.Run(tc.input, func(t *testing.T) { + parsed, err := ParseDataType(tc.input) + + require.Error(t, err) + require.Nil(t, parsed) + }) + } +} + +func Test_ParseDataType_TimestampLtz(t *testing.T) { + type test struct { + input string + expectedPrecision int + expectedUnderlyingType string + } + defaults := func(input string) test { + return test{ + input: input, + expectedPrecision: DefaultTimestampPrecision, + expectedUnderlyingType: strings.TrimSpace(strings.ToUpper(input)), + } + } + negative := func(input string) test { + return test{input: input} + } + + positiveTestCases := []test{ + {input: "TIMESTAMP_LTZ(4)", expectedPrecision: 4, 
expectedUnderlyingType: "TIMESTAMP_LTZ"}, + {input: "timestamp with local time zone(5)", expectedPrecision: 5, expectedUnderlyingType: "TIMESTAMP WITH LOCAL TIME ZONE"}, + {input: "TIMESTAMP_LTZ( 2 )", expectedPrecision: 2, expectedUnderlyingType: "TIMESTAMP_LTZ"}, + {input: " TIMESTAMP_LTZ ( 7 ) ", expectedPrecision: 7, expectedUnderlyingType: "TIMESTAMP_LTZ"}, + {input: fmt.Sprintf("TIMESTAMP_LTZ(%d)", DefaultTimestampPrecision), expectedPrecision: DefaultTimestampPrecision, expectedUnderlyingType: "TIMESTAMP_LTZ"}, + + defaults(" TIMESTAMP_LTZ "), + defaults("TIMESTAMP_LTZ"), + defaults("TIMESTAMPLTZ"), + defaults("TIMESTAMP WITH LOCAL TIME ZONE"), + defaults("timestamp_ltz"), + defaults("timestampltz"), + defaults("timestamp with local time zone"), + } + + negativeTestCases := []test{ + negative("TIMESTAMP_LTZ(38, 0)"), + negative("TIMESTAMP_LTZ(38, 2)"), + negative("TIMESTAMP_LTZ()"), + negative("T I M E S T A M P _ L T Z"), + negative("other"), + negative("other(3)"), + } + + for _, tc := range positiveTestCases { + tc := tc + t.Run(tc.input, func(t *testing.T) { + parsed, err := ParseDataType(tc.input) + + require.NoError(t, err) + require.IsType(t, &TimestampLtzDataType{}, parsed) + + assert.Equal(t, tc.expectedPrecision, parsed.(*TimestampLtzDataType).precision) + assert.Equal(t, tc.expectedUnderlyingType, parsed.(*TimestampLtzDataType).underlyingType) + + assert.Equal(t, TimestampLtzLegacyDataType, parsed.ToLegacyDataTypeSql()) + assert.Equal(t, fmt.Sprintf("%s(%d)", parsed.(*TimestampLtzDataType).underlyingType, parsed.(*TimestampLtzDataType).precision), parsed.ToSql()) + }) + } + + for _, tc := range negativeTestCases { + tc := tc + t.Run(tc.input, func(t *testing.T) { + parsed, err := ParseDataType(tc.input) + + require.Error(t, err) + require.Nil(t, parsed) + }) + } +} + +func Test_ParseDataType_TimestampNtz(t *testing.T) { + type test struct { + input string + expectedPrecision int + expectedUnderlyingType string + } + defaults := func(input string) 
test { + return test{ + input: input, + expectedPrecision: DefaultTimestampPrecision, + expectedUnderlyingType: strings.TrimSpace(strings.ToUpper(input)), + } + } + negative := func(input string) test { + return test{input: input} + } + + positiveTestCases := []test{ + {input: "TIMESTAMP_NTZ(4)", expectedPrecision: 4, expectedUnderlyingType: "TIMESTAMP_NTZ"}, + {input: "timestamp without time zone(5)", expectedPrecision: 5, expectedUnderlyingType: "TIMESTAMP WITHOUT TIME ZONE"}, + {input: "TIMESTAMP_NTZ( 2 )", expectedPrecision: 2, expectedUnderlyingType: "TIMESTAMP_NTZ"}, + {input: " TIMESTAMP_NTZ ( 7 ) ", expectedPrecision: 7, expectedUnderlyingType: "TIMESTAMP_NTZ"}, + {input: fmt.Sprintf("TIMESTAMP_NTZ(%d)", DefaultTimestampPrecision), expectedPrecision: DefaultTimestampPrecision, expectedUnderlyingType: "TIMESTAMP_NTZ"}, + + defaults(" TIMESTAMP_NTZ "), + defaults("TIMESTAMP_NTZ"), + defaults("TIMESTAMPNTZ"), + defaults("TIMESTAMP WITHOUT TIME ZONE"), + defaults("DATETIME"), + defaults("timestamp_ntz"), + defaults("timestampntz"), + defaults("timestamp without time zone"), + defaults("datetime"), + } + + negativeTestCases := []test{ + negative("TIMESTAMP_NTZ(38, 0)"), + negative("TIMESTAMP_NTZ(38, 2)"), + negative("TIMESTAMP_NTZ()"), + negative("T I M E S T A M P _ N T Z"), + negative("other"), + negative("other(3)"), + } + + for _, tc := range positiveTestCases { + tc := tc + t.Run(tc.input, func(t *testing.T) { + parsed, err := ParseDataType(tc.input) + + require.NoError(t, err) + require.IsType(t, &TimestampNtzDataType{}, parsed) + + assert.Equal(t, tc.expectedPrecision, parsed.(*TimestampNtzDataType).precision) + assert.Equal(t, tc.expectedUnderlyingType, parsed.(*TimestampNtzDataType).underlyingType) + + assert.Equal(t, TimestampNtzLegacyDataType, parsed.ToLegacyDataTypeSql()) + assert.Equal(t, fmt.Sprintf("%s(%d)", parsed.(*TimestampNtzDataType).underlyingType, parsed.(*TimestampNtzDataType).precision), parsed.ToSql()) + }) + } + + for _, tc := range 
negativeTestCases { + tc := tc + t.Run(tc.input, func(t *testing.T) { + parsed, err := ParseDataType(tc.input) + + require.Error(t, err) + require.Nil(t, parsed) + }) + } +} + +func Test_ParseDataType_TimestampTz(t *testing.T) { + type test struct { + input string + expectedPrecision int + expectedUnderlyingType string + } + defaults := func(input string) test { + return test{ + input: input, + expectedPrecision: DefaultTimestampPrecision, + expectedUnderlyingType: strings.TrimSpace(strings.ToUpper(input)), + } + } + negative := func(input string) test { + return test{input: input} + } + + positiveTestCases := []test{ + {input: "TIMESTAMP_TZ(4)", expectedPrecision: 4, expectedUnderlyingType: "TIMESTAMP_TZ"}, + {input: "timestamp with time zone(5)", expectedPrecision: 5, expectedUnderlyingType: "TIMESTAMP WITH TIME ZONE"}, + {input: "TIMESTAMP_TZ( 2 )", expectedPrecision: 2, expectedUnderlyingType: "TIMESTAMP_TZ"}, + {input: " TIMESTAMP_TZ ( 7 ) ", expectedPrecision: 7, expectedUnderlyingType: "TIMESTAMP_TZ"}, + {input: fmt.Sprintf("TIMESTAMP_TZ(%d)", DefaultTimestampPrecision), expectedPrecision: DefaultTimestampPrecision, expectedUnderlyingType: "TIMESTAMP_TZ"}, + + defaults(" TIMESTAMP_TZ "), + defaults("TIMESTAMP_TZ"), + defaults("TIMESTAMPTZ"), + defaults("TIMESTAMP WITH TIME ZONE"), + defaults("timestamp_tz"), + defaults("timestamptz"), + defaults("timestamp with time zone"), + } + + negativeTestCases := []test{ + negative("TIMESTAMP_TZ(38, 0)"), + negative("TIMESTAMP_TZ(38, 2)"), + negative("TIMESTAMP_TZ()"), + negative("T I M E S T A M P _ T Z"), + negative("other"), + negative("other(3)"), + } + + for _, tc := range positiveTestCases { + tc := tc + t.Run(tc.input, func(t *testing.T) { + parsed, err := ParseDataType(tc.input) + + require.NoError(t, err) + require.IsType(t, &TimestampTzDataType{}, parsed) + + assert.Equal(t, tc.expectedPrecision, parsed.(*TimestampTzDataType).precision) + assert.Equal(t, tc.expectedUnderlyingType, 
parsed.(*TimestampTzDataType).underlyingType) + + assert.Equal(t, TimestampTzLegacyDataType, parsed.ToLegacyDataTypeSql()) + assert.Equal(t, fmt.Sprintf("%s(%d)", parsed.(*TimestampTzDataType).underlyingType, parsed.(*TimestampTzDataType).precision), parsed.ToSql()) + }) + } + + for _, tc := range negativeTestCases { + tc := tc + t.Run(tc.input, func(t *testing.T) { + parsed, err := ParseDataType(tc.input) + + require.Error(t, err) + require.Nil(t, parsed) + }) + } +} + +func Test_ParseDataType_Variant(t *testing.T) { + type test struct { + input string + expectedUnderlyingType string + } + defaults := func(input string) test { + return test{ + input: input, + expectedUnderlyingType: strings.TrimSpace(strings.ToUpper(input)), + } + } + negative := func(input string) test { + return test{input: input} + } + + positiveTestCases := []test{ + defaults(" VARIANT "), + defaults("VARIANT"), + defaults("variant"), + } + + negativeTestCases := []test{ + negative("VARIANT(38, 0)"), + negative("VARIANT(38, 2)"), + negative("VARIANT(38)"), + negative("VARIANT()"), + negative("V A R I A N T"), + negative("other"), + } + + for _, tc := range positiveTestCases { + tc := tc + t.Run(tc.input, func(t *testing.T) { + parsed, err := ParseDataType(tc.input) + + require.NoError(t, err) + require.IsType(t, &VariantDataType{}, parsed) + + assert.Equal(t, tc.expectedUnderlyingType, parsed.(*VariantDataType).underlyingType) + + assert.Equal(t, VariantLegacyDataType, parsed.ToLegacyDataTypeSql()) + assert.Equal(t, tc.expectedUnderlyingType, parsed.ToSql()) + }) + } + + for _, tc := range negativeTestCases { + tc := tc + t.Run(tc.input, func(t *testing.T) { + parsed, err := ParseDataType(tc.input) + + require.Error(t, err) + require.Nil(t, parsed) + }) + } +} + +func Test_ParseDataType_Object(t *testing.T) { + type test struct { + input string + expectedUnderlyingType string + } + defaults := func(input string) test { + return test{ + input: input, + expectedUnderlyingType: 
strings.TrimSpace(strings.ToUpper(input)), + } + } + negative := func(input string) test { + return test{input: input} + } + + positiveTestCases := []test{ + defaults(" OBJECT "), + defaults("OBJECT"), + defaults("object"), + } + + negativeTestCases := []test{ + negative("OBJECT(38, 0)"), + negative("OBJECT(38, 2)"), + negative("OBJECT(38)"), + negative("OBJECT()"), + negative("O B J E C T"), + negative("other"), + } + + for _, tc := range positiveTestCases { + tc := tc + t.Run(tc.input, func(t *testing.T) { + parsed, err := ParseDataType(tc.input) + + require.NoError(t, err) + require.IsType(t, &ObjectDataType{}, parsed) + + assert.Equal(t, tc.expectedUnderlyingType, parsed.(*ObjectDataType).underlyingType) + + assert.Equal(t, ObjectLegacyDataType, parsed.ToLegacyDataTypeSql()) + assert.Equal(t, tc.expectedUnderlyingType, parsed.ToSql()) + }) + } + + for _, tc := range negativeTestCases { + tc := tc + t.Run(tc.input, func(t *testing.T) { + parsed, err := ParseDataType(tc.input) + + require.Error(t, err) + require.Nil(t, parsed) + }) + } +} + +func Test_ParseDataType_Array(t *testing.T) { + type test struct { + input string + expectedUnderlyingType string + } + defaults := func(input string) test { + return test{ + input: input, + expectedUnderlyingType: strings.TrimSpace(strings.ToUpper(input)), + } + } + negative := func(input string) test { + return test{input: input} + } + + positiveTestCases := []test{ + defaults(" ARRAY "), + defaults("ARRAY"), + defaults("array"), + } + + negativeTestCases := []test{ + negative("ARRAY(38, 0)"), + negative("ARRAY(38, 2)"), + negative("ARRAY(38)"), + negative("ARRAY()"), + negative("A R R A Y"), + negative("other"), + } + + for _, tc := range positiveTestCases { + tc := tc + t.Run(tc.input, func(t *testing.T) { + parsed, err := ParseDataType(tc.input) + + require.NoError(t, err) + require.IsType(t, &ArrayDataType{}, parsed) + + assert.Equal(t, tc.expectedUnderlyingType, parsed.(*ArrayDataType).underlyingType) + + 
assert.Equal(t, ArrayLegacyDataType, parsed.ToLegacyDataTypeSql()) + assert.Equal(t, tc.expectedUnderlyingType, parsed.ToSql()) + }) + } + + for _, tc := range negativeTestCases { + tc := tc + t.Run(tc.input, func(t *testing.T) { + parsed, err := ParseDataType(tc.input) + + require.Error(t, err) + require.Nil(t, parsed) + }) + } +} + +func Test_ParseDataType_Geography(t *testing.T) { + type test struct { + input string + expectedUnderlyingType string + } + defaults := func(input string) test { + return test{ + input: input, + expectedUnderlyingType: strings.TrimSpace(strings.ToUpper(input)), + } + } + negative := func(input string) test { + return test{input: input} + } + + positiveTestCases := []test{ + defaults(" GEOGRAPHY "), + defaults("GEOGRAPHY"), + defaults("geography"), + } + + negativeTestCases := []test{ + negative("GEOGRAPHY(38, 0)"), + negative("GEOGRAPHY(38, 2)"), + negative("GEOGRAPHY(38)"), + negative("GEOGRAPHY()"), + negative("G E O G R A P H Y"), + negative("other"), + } + + for _, tc := range positiveTestCases { + tc := tc + t.Run(tc.input, func(t *testing.T) { + parsed, err := ParseDataType(tc.input) + + require.NoError(t, err) + require.IsType(t, &GeographyDataType{}, parsed) + + assert.Equal(t, tc.expectedUnderlyingType, parsed.(*GeographyDataType).underlyingType) + + assert.Equal(t, GeographyLegacyDataType, parsed.ToLegacyDataTypeSql()) + assert.Equal(t, tc.expectedUnderlyingType, parsed.ToSql()) + }) + } + + for _, tc := range negativeTestCases { + tc := tc + t.Run(tc.input, func(t *testing.T) { + parsed, err := ParseDataType(tc.input) + + require.Error(t, err) + require.Nil(t, parsed) + }) + } +} + +func Test_ParseDataType_Geometry(t *testing.T) { + type test struct { + input string + expectedUnderlyingType string + } + defaults := func(input string) test { + return test{ + input: input, + expectedUnderlyingType: strings.TrimSpace(strings.ToUpper(input)), + } + } + negative := func(input string) test { + return test{input: input} + } + + 
positiveTestCases := []test{ + defaults(" GEOMETRY "), + defaults("GEOMETRY"), + defaults("geometry"), + } + + negativeTestCases := []test{ + negative("GEOMETRY(38, 0)"), + negative("GEOMETRY(38, 2)"), + negative("GEOMETRY(38)"), + negative("GEOMETRY()"), + negative("G E O M E T R Y"), + negative("other"), + } + + for _, tc := range positiveTestCases { + tc := tc + t.Run(tc.input, func(t *testing.T) { + parsed, err := ParseDataType(tc.input) + + require.NoError(t, err) + require.IsType(t, &GeometryDataType{}, parsed) + + assert.Equal(t, tc.expectedUnderlyingType, parsed.(*GeometryDataType).underlyingType) + + assert.Equal(t, GeometryLegacyDataType, parsed.ToLegacyDataTypeSql()) + assert.Equal(t, tc.expectedUnderlyingType, parsed.ToSql()) + }) + } + + for _, tc := range negativeTestCases { + tc := tc + t.Run(tc.input, func(t *testing.T) { + parsed, err := ParseDataType(tc.input) + + require.Error(t, err) + require.Nil(t, parsed) + }) + } +} + +func Test_ParseDataType_Vector(t *testing.T) { + type test struct { + input string + expectedInnerType string + expectedDimension int + } + negative := func(input string) test { + return test{input: input} + } + + positiveTestCases := []test{ + {input: "VECTOR(INT, 2)", expectedInnerType: "INT", expectedDimension: 2}, + {input: "VECTOR(FLOAT, 2)", expectedInnerType: "FLOAT", expectedDimension: 2}, + {input: "VeCtOr ( InT , 40 )", expectedInnerType: "INT", expectedDimension: 40}, + {input: " VECTOR ( INT , 40 )", expectedInnerType: "INT", expectedDimension: 40}, + } + + negativeTestCases := []test{ + negative("VECTOR(1, 2)"), + negative("VECTOR(1)"), + negative("VECTOR(2, INT)"), + negative("VECTOR()"), + negative("VECTOR"), + negative("VECTOR(INT, 2, 3)"), + negative("VECTOR(INT)"), + negative("VECTOR(x, 2)"), + negative("VECTOR("), + negative("VECTOR)"), + negative("VEC TOR"), + } + + for _, tc := range positiveTestCases { + tc := tc + t.Run(tc.input, func(t *testing.T) { + parsed, err := ParseDataType(tc.input) + + 
require.NoError(t, err) + require.IsType(t, &VectorDataType{}, parsed) + + assert.Equal(t, tc.expectedInnerType, parsed.(*VectorDataType).innerType) + assert.Equal(t, tc.expectedDimension, parsed.(*VectorDataType).dimension) + assert.Equal(t, "VECTOR", parsed.(*VectorDataType).underlyingType) + + assert.Equal(t, fmt.Sprintf("%s(%s, %d)", parsed.(*VectorDataType).underlyingType, parsed.(*VectorDataType).innerType, parsed.(*VectorDataType).dimension), parsed.ToLegacyDataTypeSql()) + assert.Equal(t, fmt.Sprintf("%s(%s, %d)", parsed.(*VectorDataType).underlyingType, parsed.(*VectorDataType).innerType, parsed.(*VectorDataType).dimension), parsed.ToSql()) + }) + } + + for _, tc := range negativeTestCases { + tc := tc + t.Run(tc.input, func(t *testing.T) { + parsed, err := ParseDataType(tc.input) + + require.Error(t, err) + require.Nil(t, parsed) + }) + } +} + +func Test_AreTheSame(t *testing.T) { + // empty d1/d2 means nil DataType input + type test struct { + d1 string + d2 string + expectedOutcome bool + } + + testCases := []test{ + {d1: "", d2: "", expectedOutcome: true}, + {d1: "", d2: "NUMBER", expectedOutcome: false}, + {d1: "NUMBER", d2: "", expectedOutcome: false}, + + {d1: "NUMBER(20)", d2: "NUMBER(20, 2)", expectedOutcome: false}, + {d1: "NUMBER(20, 1)", d2: "NUMBER(20, 2)", expectedOutcome: false}, + {d1: "NUMBER", d2: "NUMBER(20, 2)", expectedOutcome: false}, + {d1: "NUMBER", d2: fmt.Sprintf("NUMBER(%d, %d)", DefaultNumberPrecision, DefaultNumberScale), expectedOutcome: true}, + {d1: fmt.Sprintf("NUMBER(%d)", DefaultNumberPrecision), d2: fmt.Sprintf("NUMBER(%d, %d)", DefaultNumberPrecision, DefaultNumberScale), expectedOutcome: true}, + {d1: "NUMBER", d2: "NUMBER", expectedOutcome: true}, + {d1: "NUMBER(20)", d2: "NUMBER(20)", expectedOutcome: true}, + {d1: "NUMBER(20, 2)", d2: "NUMBER(20, 2)", expectedOutcome: true}, + {d1: "INT", d2: "NUMBER", expectedOutcome: true}, + {d1: "INT", d2: fmt.Sprintf("NUMBER(%d, %d)", DefaultNumberPrecision, 
DefaultNumberScale), expectedOutcome: true}, + {d1: "INT", d2: "NUMBER(20)", expectedOutcome: false}, + {d1: "NUMBER", d2: "VARCHAR", expectedOutcome: false}, + {d1: "NUMBER(20)", d2: "VARCHAR(20)", expectedOutcome: false}, + {d1: "CHAR", d2: "VARCHAR", expectedOutcome: false}, + {d1: "CHAR", d2: fmt.Sprintf("VARCHAR(%d)", DefaultCharLength), expectedOutcome: true}, + {d1: fmt.Sprintf("CHAR(%d)", DefaultVarcharLength), d2: "VARCHAR", expectedOutcome: true}, + {d1: "BINARY", d2: "BINARY", expectedOutcome: true}, + {d1: "BINARY", d2: "VARBINARY", expectedOutcome: true}, + {d1: "BINARY(20)", d2: "BINARY(20)", expectedOutcome: true}, + {d1: "BINARY(20)", d2: "BINARY(30)", expectedOutcome: false}, + {d1: "BINARY", d2: "BINARY(30)", expectedOutcome: false}, + {d1: fmt.Sprintf("BINARY(%d)", DefaultBinarySize), d2: "BINARY", expectedOutcome: true}, + {d1: "FLOAT", d2: "FLOAT4", expectedOutcome: true}, + {d1: "DOUBLE", d2: "FLOAT8", expectedOutcome: true}, + {d1: "DOUBLE PRECISION", d2: "REAL", expectedOutcome: true}, + {d1: "TIMESTAMPLTZ", d2: "TIMESTAMPNTZ", expectedOutcome: false}, + {d1: "TIMESTAMPLTZ", d2: "TIMESTAMPTZ", expectedOutcome: false}, + {d1: "TIMESTAMPLTZ", d2: fmt.Sprintf("TIMESTAMPLTZ(%d)", DefaultTimestampPrecision), expectedOutcome: true}, + {d1: "VECTOR(INT, 20)", d2: "VECTOR(INT, 20)", expectedOutcome: true}, + {d1: "VECTOR(INT, 20)", d2: "VECTOR(INT, 30)", expectedOutcome: false}, + {d1: "VECTOR(FLOAT, 20)", d2: "VECTOR(INT, 30)", expectedOutcome: false}, + {d1: "VECTOR(FLOAT, 20)", d2: "VECTOR(INT, 20)", expectedOutcome: false}, + {d1: "VECTOR(FLOAT, 20)", d2: "VECTOR(FLOAT, 20)", expectedOutcome: true}, + {d1: "VECTOR(FLOAT, 20)", d2: "FLOAT", expectedOutcome: false}, + {d1: "TIME", d2: "TIME", expectedOutcome: true}, + {d1: "TIME", d2: "TIME(5)", expectedOutcome: false}, + {d1: "TIME", d2: fmt.Sprintf("TIME(%d)", DefaultTimePrecision), expectedOutcome: true}, + } + + for _, tc := range testCases { + tc := tc + t.Run(fmt.Sprintf(`compare "%s" with 
"%s" expecting %t`, tc.d1, tc.d2, tc.expectedOutcome), func(t *testing.T) { + var p1, p2 DataType + var err error + + if tc.d1 != "" { + p1, err = ParseDataType(tc.d1) + require.NoError(t, err) + } + + if tc.d2 != "" { + p2, err = ParseDataType(tc.d2) + require.NoError(t, err) + } + + require.Equal(t, tc.expectedOutcome, AreTheSame(p1, p2)) + }) + } +} diff --git a/pkg/sdk/datatypes/date.go b/pkg/sdk/datatypes/date.go new file mode 100644 index 0000000000..92ee7c27bc --- /dev/null +++ b/pkg/sdk/datatypes/date.go @@ -0,0 +1,21 @@ +package datatypes + +// DateDataType is based on https://docs.snowflake.com/en/sql-reference/data-types-datetime#date +// It does not have synonyms. It does not have any attributes. +type DateDataType struct { + underlyingType string +} + +func (t *DateDataType) ToSql() string { + return t.underlyingType +} + +func (t *DateDataType) ToLegacyDataTypeSql() string { + return DateLegacyDataType +} + +var DateDataTypeSynonyms = []string{DateLegacyDataType} + +func parseDateDataTypeRaw(raw sanitizedDataTypeRaw) (*DateDataType, error) { + return &DateDataType{raw.matchedByType}, nil +} diff --git a/pkg/sdk/datatypes/float.go b/pkg/sdk/datatypes/float.go new file mode 100644 index 0000000000..a0ca84863b --- /dev/null +++ b/pkg/sdk/datatypes/float.go @@ -0,0 +1,21 @@ +package datatypes + +// FloatDataType is based on https://docs.snowflake.com/en/sql-reference/data-types-numeric#data-types-for-floating-point-numbers +// It does have synonyms. It does not have any attributes. 
+type FloatDataType struct { + underlyingType string +} + +func (t *FloatDataType) ToSql() string { + return t.underlyingType +} + +func (t *FloatDataType) ToLegacyDataTypeSql() string { + return FloatLegacyDataType +} + +var FloatDataTypeSynonyms = []string{"FLOAT8", "FLOAT4", FloatLegacyDataType, "DOUBLE PRECISION", "DOUBLE", "REAL"} + +func parseFloatDataTypeRaw(raw sanitizedDataTypeRaw) (*FloatDataType, error) { + return &FloatDataType{raw.matchedByType}, nil +} diff --git a/pkg/sdk/datatypes/geography.go b/pkg/sdk/datatypes/geography.go new file mode 100644 index 0000000000..4a024a20b0 --- /dev/null +++ b/pkg/sdk/datatypes/geography.go @@ -0,0 +1,21 @@ +package datatypes + +// GeographyDataType is based on https://docs.snowflake.com/en/sql-reference/data-types-geospatial#geography-data-type +// It does not have synonyms. It does not have any attributes. +type GeographyDataType struct { + underlyingType string +} + +func (t *GeographyDataType) ToSql() string { + return t.underlyingType +} + +func (t *GeographyDataType) ToLegacyDataTypeSql() string { + return GeographyLegacyDataType +} + +var GeographyDataTypeSynonyms = []string{GeographyLegacyDataType} + +func parseGeographyDataTypeRaw(raw sanitizedDataTypeRaw) (*GeographyDataType, error) { + return &GeographyDataType{raw.matchedByType}, nil +} diff --git a/pkg/sdk/datatypes/geometry.go b/pkg/sdk/datatypes/geometry.go new file mode 100644 index 0000000000..d09ebd3eea --- /dev/null +++ b/pkg/sdk/datatypes/geometry.go @@ -0,0 +1,21 @@ +package datatypes + +// GeometryDataType is based on https://docs.snowflake.com/en/sql-reference/data-types-geospatial#geometry-data-type +// It does not have synonyms. It does not have any attributes. 
+type GeometryDataType struct { + underlyingType string +} + +func (t *GeometryDataType) ToSql() string { + return t.underlyingType +} + +func (t *GeometryDataType) ToLegacyDataTypeSql() string { + return GeometryLegacyDataType +} + +var GeometryDataTypeSynonyms = []string{GeometryLegacyDataType} + +func parseGeometryDataTypeRaw(raw sanitizedDataTypeRaw) (*GeometryDataType, error) { + return &GeometryDataType{raw.matchedByType}, nil +} diff --git a/pkg/sdk/datatypes/legacy.go b/pkg/sdk/datatypes/legacy.go new file mode 100644 index 0000000000..5a0e249cd7 --- /dev/null +++ b/pkg/sdk/datatypes/legacy.go @@ -0,0 +1,19 @@ +package datatypes + +const ( + ArrayLegacyDataType = "ARRAY" + BinaryLegacyDataType = "BINARY" + BooleanLegacyDataType = "BOOLEAN" + DateLegacyDataType = "DATE" + FloatLegacyDataType = "FLOAT" + GeographyLegacyDataType = "GEOGRAPHY" + GeometryLegacyDataType = "GEOMETRY" + NumberLegacyDataType = "NUMBER" + ObjectLegacyDataType = "OBJECT" + VarcharLegacyDataType = "VARCHAR" + TimeLegacyDataType = "TIME" + TimestampLtzLegacyDataType = "TIMESTAMP_LTZ" + TimestampNtzLegacyDataType = "TIMESTAMP_NTZ" + TimestampTzLegacyDataType = "TIMESTAMP_TZ" + VariantLegacyDataType = "VARIANT" +) diff --git a/pkg/sdk/datatypes/number.go b/pkg/sdk/datatypes/number.go new file mode 100644 index 0000000000..14ac2696fc --- /dev/null +++ b/pkg/sdk/datatypes/number.go @@ -0,0 +1,104 @@ +package datatypes + +import ( + "fmt" + "slices" + "strconv" + "strings" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/logging" +) + +const ( + DefaultNumberPrecision = 38 + DefaultNumberScale = 0 +) + +// NumberDataType is based on https://docs.snowflake.com/en/sql-reference/data-types-numeric#data-types-for-fixed-point-numbers +// It does have synonyms that allow specifying precision and scale; here called synonyms. +// It does have synonyms that does not allow specifying precision and scale; here called subtypes. 
+type NumberDataType struct { + precision int + scale int + underlyingType string +} + +func (t *NumberDataType) ToSql() string { + return fmt.Sprintf("%s(%d, %d)", t.underlyingType, t.precision, t.scale) +} + +func (t *NumberDataType) ToLegacyDataTypeSql() string { + return NumberLegacyDataType +} + +var ( + NumberDataTypeSynonyms = []string{NumberLegacyDataType, "DECIMAL", "DEC", "NUMERIC"} + NumberDataTypeSubTypes = []string{"INTEGER", "INT", "BIGINT", "SMALLINT", "TINYINT", "BYTEINT"} + AllNumberDataTypes = append(NumberDataTypeSynonyms, NumberDataTypeSubTypes...) +) + +func parseNumberDataTypeRaw(raw sanitizedDataTypeRaw) (*NumberDataType, error) { + switch { + case slices.Contains(NumberDataTypeSynonyms, raw.matchedByType): + return parseNumberDataTypeWithPrecisionAndScale(raw) + case slices.Contains(NumberDataTypeSubTypes, raw.matchedByType): + return parseNumberDataTypeWithoutPrecisionAndScale(raw) + default: + return nil, fmt.Errorf("unknown number data type: %s", raw.raw) + } +} + +// parseNumberDataTypeWithPrecisionAndScale extracts precision and scale from the raw number data type input. +// It returns defaults if no arguments were provided. It returns error if any part is not parseable. 
+func parseNumberDataTypeWithPrecisionAndScale(raw sanitizedDataTypeRaw) (*NumberDataType, error) { + r := strings.TrimSpace(strings.TrimPrefix(raw.raw, raw.matchedByType)) + if r == "" { + logging.DebugLogger.Printf("[DEBUG] Returning default number precision and scale") + return &NumberDataType{DefaultNumberPrecision, DefaultNumberScale, raw.matchedByType}, nil + } + if !strings.HasPrefix(r, "(") || !strings.HasSuffix(r, ")") { + logging.DebugLogger.Printf(`number %s could not be parsed, use "%s(precision, scale)" format`, raw.raw, raw.matchedByType) + return nil, fmt.Errorf(`number %s could not be parsed, use "%s(precision, scale)" format`, raw.raw, raw.matchedByType) + } + onlyArgs := r[1 : len(r)-1] + parts := strings.Split(onlyArgs, ",") + switch l := len(parts); l { + case 1: + precision, err := strconv.Atoi(strings.TrimSpace(parts[0])) + if err != nil { + logging.DebugLogger.Printf(`[DEBUG] Could not parse number precision "%s", err: %v`, parts[0], err) + return nil, fmt.Errorf(`could not parse the number's precision: "%s", err: %w`, parts[0], err) + } + return &NumberDataType{precision, DefaultNumberScale, raw.matchedByType}, nil + case 2: + precision, err := strconv.Atoi(strings.TrimSpace(parts[0])) + if err != nil { + logging.DebugLogger.Printf(`[DEBUG] Could not parse number precision "%s", err: %v`, parts[0], err) + return nil, fmt.Errorf(`could not parse the number's precision: "%s", err: %w`, parts[0], err) + } + scale, err := strconv.Atoi(strings.TrimSpace(parts[1])) + if err != nil { + logging.DebugLogger.Printf(`[DEBUG] Could not parse number scale "%s", err: %v`, parts[1], err) + return nil, fmt.Errorf(`could not parse the number's scale: "%s", err: %w`, parts[1], err) + } + return &NumberDataType{precision, scale, raw.matchedByType}, nil + default: + logging.DebugLogger.Printf("[DEBUG] Unexpected length of number arguments") + return nil, fmt.Errorf(`number cannot have %d arguments: "%s"; only precision and scale are allowed`, l, onlyArgs) + } 
+} + +func parseNumberDataTypeWithoutPrecisionAndScale(raw sanitizedDataTypeRaw) (*NumberDataType, error) { + if raw.raw != raw.matchedByType { + args := strings.TrimPrefix(raw.raw, raw.matchedByType) + logging.DebugLogger.Printf("[DEBUG] Number type %s cannot have arguments: %s", raw.matchedByType, args) + return nil, fmt.Errorf("number type %s cannot have arguments: %s", raw.matchedByType, args) + } else { + logging.DebugLogger.Printf("[DEBUG] Returning default number precision and scale") + return &NumberDataType{DefaultNumberPrecision, DefaultNumberScale, raw.matchedByType}, nil + } +} + +func areNumberDataTypesTheSame(a, b *NumberDataType) bool { + return a.precision == b.precision && a.scale == b.scale +} diff --git a/pkg/sdk/datatypes/object.go b/pkg/sdk/datatypes/object.go new file mode 100644 index 0000000000..fe333aa7b0 --- /dev/null +++ b/pkg/sdk/datatypes/object.go @@ -0,0 +1,21 @@ +package datatypes + +// ObjectDataType is based on https://docs.snowflake.com/en/sql-reference/data-types-semistructured#object +// It does not have synonyms. It does not have any attributes. 
+type ObjectDataType struct { + underlyingType string +} + +func (t *ObjectDataType) ToSql() string { + return t.underlyingType +} + +func (t *ObjectDataType) ToLegacyDataTypeSql() string { + return ObjectLegacyDataType +} + +var ObjectDataTypeSynonyms = []string{ObjectLegacyDataType} + +func parseObjectDataTypeRaw(raw sanitizedDataTypeRaw) (*ObjectDataType, error) { + return &ObjectDataType{raw.matchedByType}, nil +} diff --git a/pkg/sdk/datatypes/text.go b/pkg/sdk/datatypes/text.go new file mode 100644 index 0000000000..2598253101 --- /dev/null +++ b/pkg/sdk/datatypes/text.go @@ -0,0 +1,69 @@ +package datatypes + +import ( + "fmt" + "slices" + "strconv" + "strings" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/logging" +) + +const ( + DefaultVarcharLength = 16777216 + DefaultCharLength = 1 +) + +// TextDataType is based on https://docs.snowflake.com/en/sql-reference/data-types-text#data-types-for-text-strings +// It does have synonyms that allow specifying length. +// It does have synonyms that allow specifying length but differ with the default length when length is omitted; here called subtypes. +type TextDataType struct { + length int + underlyingType string +} + +func (t *TextDataType) ToSql() string { + return fmt.Sprintf("%s(%d)", t.underlyingType, t.length) +} + +func (t *TextDataType) ToLegacyDataTypeSql() string { + return VarcharLegacyDataType +} + +var ( + TextDataTypeSynonyms = []string{VarcharLegacyDataType, "STRING", "TEXT", "NVARCHAR2", "NVARCHAR", "CHAR VARYING", "NCHAR VARYING"} + TextDataTypeSubtypes = []string{"CHARACTER", "CHAR", "NCHAR"} + AllTextDataTypes = append(TextDataTypeSynonyms, TextDataTypeSubtypes...) +) + +// parseTextDataTypeRaw extracts length from the raw text data type input. +// It returns default if it can't parse arguments, data type is different, or no length argument was provided. 
+func parseTextDataTypeRaw(raw sanitizedDataTypeRaw) (*TextDataType, error) { + r := strings.TrimSpace(strings.TrimPrefix(raw.raw, raw.matchedByType)) + if r == "" { + logging.DebugLogger.Printf("[DEBUG] Returning default length for text") + switch { + case slices.Contains(TextDataTypeSynonyms, raw.matchedByType): + return &TextDataType{DefaultVarcharLength, raw.matchedByType}, nil + case slices.Contains(TextDataTypeSubtypes, raw.matchedByType): + return &TextDataType{DefaultCharLength, raw.matchedByType}, nil + default: + return nil, fmt.Errorf("unknown text data type: %s", raw.raw) + } + } + if !strings.HasPrefix(r, "(") || !strings.HasSuffix(r, ")") { + logging.DebugLogger.Printf(`text %s could not be parsed, use "%s(length)" format`, raw.raw, raw.matchedByType) + return nil, fmt.Errorf(`text %s could not be parsed, use "%s(length)" format`, raw.raw, raw.matchedByType) + } + lengthRaw := r[1 : len(r)-1] + length, err := strconv.Atoi(strings.TrimSpace(lengthRaw)) + if err != nil { + logging.DebugLogger.Printf(`[DEBUG] Could not parse varchar length "%s", err: %v`, lengthRaw, err) + return nil, fmt.Errorf(`could not parse the varchar's length: "%s", err: %w`, lengthRaw, err) + } + return &TextDataType{length, raw.matchedByType}, nil +} + +func areTextDataTypesTheSame(a, b *TextDataType) bool { + return a.length == b.length +} diff --git a/pkg/sdk/datatypes/time.go b/pkg/sdk/datatypes/time.go new file mode 100644 index 0000000000..ee79421122 --- /dev/null +++ b/pkg/sdk/datatypes/time.go @@ -0,0 +1,51 @@ +package datatypes + +import ( + "fmt" + "strconv" + "strings" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/logging" +) + +const DefaultTimePrecision = 9 + +// TimeDataType is based on https://docs.snowflake.com/en/sql-reference/data-types-datetime#time +// It does not have synonyms. It does have optional precision attribute. 
+type TimeDataType struct { + precision int + underlyingType string +} + +func (t *TimeDataType) ToSql() string { + return fmt.Sprintf("%s(%d)", t.underlyingType, t.precision) +} + +func (t *TimeDataType) ToLegacyDataTypeSql() string { + return TimeLegacyDataType +} + +var TimeDataTypeSynonyms = []string{TimeLegacyDataType} + +func parseTimeDataTypeRaw(raw sanitizedDataTypeRaw) (*TimeDataType, error) { + r := strings.TrimSpace(strings.TrimPrefix(raw.raw, raw.matchedByType)) + if r == "" { + logging.DebugLogger.Printf("[DEBUG] Returning default precision for time") + return &TimeDataType{DefaultTimePrecision, raw.matchedByType}, nil + } + if !strings.HasPrefix(r, "(") || !strings.HasSuffix(r, ")") { + logging.DebugLogger.Printf(`time %s could not be parsed, use "%s(precision)" format`, raw.raw, raw.matchedByType) + return nil, fmt.Errorf(`time %s could not be parsed, use "%s(precision)" format`, raw.raw, raw.matchedByType) + } + precisionRaw := r[1 : len(r)-1] + precision, err := strconv.Atoi(strings.TrimSpace(precisionRaw)) + if err != nil { + logging.DebugLogger.Printf(`[DEBUG] Could not parse time precision "%s", err: %v`, precisionRaw, err) + return nil, fmt.Errorf(`could not parse the time's precision: "%s", err: %w`, precisionRaw, err) + } + return &TimeDataType{precision, raw.matchedByType}, nil +} + +func areTimeDataTypesTheSame(a, b *TimeDataType) bool { + return a.precision == b.precision +} diff --git a/pkg/sdk/datatypes/timestamp.go b/pkg/sdk/datatypes/timestamp.go new file mode 100644 index 0000000000..82b22b74d2 --- /dev/null +++ b/pkg/sdk/datatypes/timestamp.go @@ -0,0 +1,3 @@ +package datatypes + +const DefaultTimestampPrecision = 9 diff --git a/pkg/sdk/datatypes/timestamp_ltz.go b/pkg/sdk/datatypes/timestamp_ltz.go new file mode 100644 index 0000000000..f844ec537f --- /dev/null +++ b/pkg/sdk/datatypes/timestamp_ltz.go @@ -0,0 +1,49 @@ +package datatypes + +import ( + "fmt" + "strconv" + "strings" + + 
"github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/logging" +) + +// TimestampLtzDataType is based on https://docs.snowflake.com/en/sql-reference/data-types-datetime#timestamp-ltz-timestamp-ntz-timestamp-tz +// It does have synonyms. It does have optional precision attribute. +type TimestampLtzDataType struct { + precision int + underlyingType string +} + +func (t *TimestampLtzDataType) ToSql() string { + return fmt.Sprintf("%s(%d)", t.underlyingType, t.precision) +} + +func (t *TimestampLtzDataType) ToLegacyDataTypeSql() string { + return TimestampLtzLegacyDataType +} + +var TimestampLtzDataTypeSynonyms = []string{TimestampLtzLegacyDataType, "TIMESTAMPLTZ", "TIMESTAMP WITH LOCAL TIME ZONE"} + +func parseTimestampLtzDataTypeRaw(raw sanitizedDataTypeRaw) (*TimestampLtzDataType, error) { + r := strings.TrimSpace(strings.TrimPrefix(raw.raw, raw.matchedByType)) + if r == "" { + logging.DebugLogger.Printf("[DEBUG] Returning default precision for timestamp ltz") + return &TimestampLtzDataType{DefaultTimestampPrecision, raw.matchedByType}, nil + } + if !strings.HasPrefix(r, "(") || !strings.HasSuffix(r, ")") { + logging.DebugLogger.Printf(`timestamp ltz %s could not be parsed, use "%s(precision)" format`, raw.raw, raw.matchedByType) + return nil, fmt.Errorf(`timestamp ltz %s could not be parsed, use "%s(precision)" format`, raw.raw, raw.matchedByType) + } + precisionRaw := r[1 : len(r)-1] + precision, err := strconv.Atoi(strings.TrimSpace(precisionRaw)) + if err != nil { + logging.DebugLogger.Printf(`[DEBUG] Could not parse timestamp ltz precision "%s", err: %v`, precisionRaw, err) + return nil, fmt.Errorf(`could not parse the timestamp's precision: "%s", err: %w`, precisionRaw, err) + } + return &TimestampLtzDataType{precision, raw.matchedByType}, nil +} + +func areTimestampLtzDataTypesTheSame(a, b *TimestampLtzDataType) bool { + return a.precision == b.precision +} diff --git a/pkg/sdk/datatypes/timestamp_ntz.go b/pkg/sdk/datatypes/timestamp_ntz.go new 
file mode 100644 index 0000000000..86aa5f0a0c --- /dev/null +++ b/pkg/sdk/datatypes/timestamp_ntz.go @@ -0,0 +1,49 @@ +package datatypes + +import ( + "fmt" + "strconv" + "strings" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/logging" +) + +// TimestampNtzDataType is based on https://docs.snowflake.com/en/sql-reference/data-types-datetime#timestamp-ltz-timestamp-ntz-timestamp-tz +// It does have synonyms. It does have optional precision attribute. +type TimestampNtzDataType struct { + precision int + underlyingType string +} + +func (t *TimestampNtzDataType) ToSql() string { + return fmt.Sprintf("%s(%d)", t.underlyingType, t.precision) +} + +func (t *TimestampNtzDataType) ToLegacyDataTypeSql() string { + return TimestampNtzLegacyDataType +} + +var TimestampNtzDataTypeSynonyms = []string{TimestampNtzLegacyDataType, "TIMESTAMPNTZ", "TIMESTAMP WITHOUT TIME ZONE", "DATETIME"} + +func parseTimestampNtzDataTypeRaw(raw sanitizedDataTypeRaw) (*TimestampNtzDataType, error) { + r := strings.TrimSpace(strings.TrimPrefix(raw.raw, raw.matchedByType)) + if r == "" { + logging.DebugLogger.Printf("[DEBUG] Returning default precision for timestamp ntz") + return &TimestampNtzDataType{DefaultTimestampPrecision, raw.matchedByType}, nil + } + if !strings.HasPrefix(r, "(") || !strings.HasSuffix(r, ")") { + logging.DebugLogger.Printf(`timestamp ntz %s could not be parsed, use "%s(precision)" format`, raw.raw, raw.matchedByType) + return nil, fmt.Errorf(`timestamp ntz %s could not be parsed, use "%s(precision)" format`, raw.raw, raw.matchedByType) + } + precisionRaw := r[1 : len(r)-1] + precision, err := strconv.Atoi(strings.TrimSpace(precisionRaw)) + if err != nil { + logging.DebugLogger.Printf(`[DEBUG] Could not parse timestamp ntz precision "%s", err: %v`, precisionRaw, err) + return nil, fmt.Errorf(`could not parse the timestamp's precision: "%s", err: %w`, precisionRaw, err) + } + return &TimestampNtzDataType{precision, raw.matchedByType}, nil +} + +func 
areTimestampNtzDataTypesTheSame(a, b *TimestampNtzDataType) bool { + return a.precision == b.precision +} diff --git a/pkg/sdk/datatypes/timestamp_tz.go b/pkg/sdk/datatypes/timestamp_tz.go new file mode 100644 index 0000000000..44e6cafeb6 --- /dev/null +++ b/pkg/sdk/datatypes/timestamp_tz.go @@ -0,0 +1,49 @@ +package datatypes + +import ( + "fmt" + "strconv" + "strings" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/logging" +) + +// TimestampTzDataType is based on https://docs.snowflake.com/en/sql-reference/data-types-datetime#timestamp-ltz-timestamp-ntz-timestamp-tz +// It does have synonyms. It does have optional precision attribute. +type TimestampTzDataType struct { + precision int + underlyingType string +} + +func (t *TimestampTzDataType) ToSql() string { + return fmt.Sprintf("%s(%d)", t.underlyingType, t.precision) +} + +func (t *TimestampTzDataType) ToLegacyDataTypeSql() string { + return TimestampTzLegacyDataType +} + +var TimestampTzDataTypeSynonyms = []string{TimestampTzLegacyDataType, "TIMESTAMPTZ", "TIMESTAMP WITH TIME ZONE"} + +func parseTimestampTzDataTypeRaw(raw sanitizedDataTypeRaw) (*TimestampTzDataType, error) { + r := strings.TrimSpace(strings.TrimPrefix(raw.raw, raw.matchedByType)) + if r == "" { + logging.DebugLogger.Printf("[DEBUG] Returning default precision for timestamp tz") + return &TimestampTzDataType{DefaultTimestampPrecision, raw.matchedByType}, nil + } + if !strings.HasPrefix(r, "(") || !strings.HasSuffix(r, ")") { + logging.DebugLogger.Printf(`timestamp tz %s could not be parsed, use "%s(precision)" format`, raw.raw, raw.matchedByType) + return nil, fmt.Errorf(`timestamp tz %s could not be parsed, use "%s(precision)" format`, raw.raw, raw.matchedByType) + } + precisionRaw := r[1 : len(r)-1] + precision, err := strconv.Atoi(strings.TrimSpace(precisionRaw)) + if err != nil { + logging.DebugLogger.Printf(`[DEBUG] Could not parse timestamp tz precision "%s", err: %v`, precisionRaw, err) + return nil, 
fmt.Errorf(`could not parse the timestamp's precision: "%s", err: %w`, precisionRaw, err) + } + return &TimestampTzDataType{precision, raw.matchedByType}, nil +} + +func areTimestampTzDataTypesTheSame(a, b *TimestampTzDataType) bool { + return a.precision == b.precision +} diff --git a/pkg/sdk/datatypes/variant.go b/pkg/sdk/datatypes/variant.go new file mode 100644 index 0000000000..b096084934 --- /dev/null +++ b/pkg/sdk/datatypes/variant.go @@ -0,0 +1,21 @@ +package datatypes + +// VariantDataType is based on https://docs.snowflake.com/en/sql-reference/data-types-semistructured#variant +// It does not have synonyms. It does not have any attributes. +type VariantDataType struct { + underlyingType string +} + +func (t *VariantDataType) ToSql() string { + return t.underlyingType +} + +func (t *VariantDataType) ToLegacyDataTypeSql() string { + return VariantLegacyDataType +} + +var VariantDataTypeSynonyms = []string{VariantLegacyDataType} + +func parseVariantDataTypeRaw(raw sanitizedDataTypeRaw) (*VariantDataType, error) { + return &VariantDataType{raw.matchedByType}, nil +} diff --git a/pkg/sdk/datatypes/vector.go b/pkg/sdk/datatypes/vector.go new file mode 100644 index 0000000000..a535ca2b58 --- /dev/null +++ b/pkg/sdk/datatypes/vector.go @@ -0,0 +1,65 @@ +package datatypes + +import ( + "fmt" + "slices" + "strconv" + "strings" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/logging" +) + +// VectorDataType is based on https://docs.snowflake.com/en/sql-reference/data-types-vector#vector +// It does not have synonyms. It does have type (int or float) and dimension required attributes. 
+type VectorDataType struct { + innerType string + dimension int + underlyingType string +} + +func (t *VectorDataType) ToSql() string { + return fmt.Sprintf("%s(%s, %d)", t.underlyingType, t.innerType, t.dimension) +} + +// ToLegacyDataTypeSql for vector is the only one correct because in the old implementation it was returned as DataType(dType), so a proper format. +func (t *VectorDataType) ToLegacyDataTypeSql() string { + return t.ToSql() +} + +var ( + VectorDataTypeSynonyms = []string{"VECTOR"} + VectorAllowedInnerTypes = []string{"INT", "FLOAT"} +) + +// parseVectorDataTypeRaw extracts type and dimension from the raw vector data type input. +// Both attributes are required so no defaults are returned in case any of them is missing. +func parseVectorDataTypeRaw(raw sanitizedDataTypeRaw) (*VectorDataType, error) { + r := strings.TrimSpace(strings.TrimPrefix(raw.raw, raw.matchedByType)) + if r == "" || (!strings.HasPrefix(r, "(") || !strings.HasSuffix(r, ")")) { + logging.DebugLogger.Printf(`vector %s could not be parsed, use "%s(type, dimension)" format`, raw.raw, raw.matchedByType) + return nil, fmt.Errorf(`vector %s could not be parsed, use "%s(type, dimension)" format`, raw.raw, raw.matchedByType) + } + onlyArgs := r[1 : len(r)-1] + parts := strings.Split(onlyArgs, ",") + switch l := len(parts); l { + case 2: + vectorType := strings.TrimSpace(parts[0]) + if !slices.Contains(VectorAllowedInnerTypes, vectorType) { + logging.DebugLogger.Printf(`[DEBUG] Inner type for vector could not be recognized: "%s"; use one of %s`, parts[0], strings.Join(VectorAllowedInnerTypes, ",")) + return nil, fmt.Errorf(`could not parse vector's inner type': "%s"; use one of %s`, parts[0], strings.Join(VectorAllowedInnerTypes, ",")) + } + dimension, err := strconv.Atoi(strings.TrimSpace(parts[1])) + if err != nil { + logging.DebugLogger.Printf(`[DEBUG] Could not parse vector's dimension "%s", err: %v`, parts[1], err) + return nil, fmt.Errorf(`could not parse the vector's dimension: 
"%s", err: %w`, parts[1], err) + } + return &VectorDataType{vectorType, dimension, raw.matchedByType}, nil + default: + logging.DebugLogger.Printf("[DEBUG] Unexpected length of vector arguments") + return nil, fmt.Errorf(`vector cannot have %d arguments: "%s"; use "%s(type, dimension)" format`, l, onlyArgs, raw.matchedByType) + } +} + +func areVectorDataTypesTheSame(a, b *VectorDataType) bool { + return a.innerType == b.innerType && a.dimension == b.dimension +} diff --git a/pkg/sdk/dynamic_table.go b/pkg/sdk/dynamic_table.go index dac13dc576..35457cff6a 100644 --- a/pkg/sdk/dynamic_table.go +++ b/pkg/sdk/dynamic_table.go @@ -6,6 +6,7 @@ import ( "time" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/tracking" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" ) type DynamicTables interface { @@ -232,10 +233,10 @@ type dynamicTableDetailsRow struct { } func (row dynamicTableDetailsRow) convert() *DynamicTableDetails { - typ, _ := ToDataType(row.Type) + typ, _ := datatypes.ParseDataType(row.Type) dtd := &DynamicTableDetails{ Name: row.Name, - Type: typ, + Type: LegacyDataTypeFrom(typ), Kind: row.Kind, IsNull: row.IsNull == "Y", PrimaryKey: row.PrimaryKey, diff --git a/pkg/sdk/identifier_helpers.go b/pkg/sdk/identifier_helpers.go index 95ea8e894f..90d1acdf44 100644 --- a/pkg/sdk/identifier_helpers.go +++ b/pkg/sdk/identifier_helpers.go @@ -4,6 +4,8 @@ import ( "fmt" "log" "strings" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" ) type Identifier interface { @@ -247,8 +249,8 @@ func NewSchemaObjectIdentifierFromFullyQualifiedName(fullyQualifiedName string) if trimmedArg == "" { continue } - dt, _ := ToDataType(trimmedArg) - id.arguments = append(id.arguments, dt) + dt, _ := datatypes.ParseDataType(trimmedArg) + id.arguments = append(id.arguments, LegacyDataTypeFrom(dt)) } } else { // this is every other kind of schema object id.name = strings.Trim(parts[2], `"`) @@ -318,11 +320,11 @@ func 
NewSchemaObjectIdentifierWithArguments(databaseName, schemaName, name strin // Arguments have to be "normalized" with ToDataType, so the signature would match with the one returned by Snowflake. normalizedArguments := make([]DataType, len(argumentDataTypes)) for i, argument := range argumentDataTypes { - normalizedArgument, err := ToDataType(string(argument)) + normalizedArgument, err := datatypes.ParseDataType(string(argument)) if err != nil { log.Printf("[DEBUG] failed to normalize argument %d: %v, err = %v", i, argument, err) } - normalizedArguments[i] = normalizedArgument + normalizedArguments[i] = LegacyDataTypeFrom(normalizedArgument) } return SchemaObjectIdentifierWithArguments{ databaseName: strings.Trim(databaseName, `"`), diff --git a/pkg/sdk/masking_policy.go b/pkg/sdk/masking_policy.go index b6c87a2d0a..f92b2e89b6 100644 --- a/pkg/sdk/masking_policy.go +++ b/pkg/sdk/masking_policy.go @@ -9,6 +9,7 @@ import ( "time" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/collections" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" ) var _ MaskingPolicies = (*maskingPolicies)(nil) @@ -360,14 +361,14 @@ type maskingPolicyDetailsRow struct { } func (row maskingPolicyDetailsRow) toMaskingPolicyDetails() *MaskingPolicyDetails { - dataType, err := ToDataType(row.ReturnType) + dataType, err := datatypes.ParseDataType(row.ReturnType) if err != nil { return nil } v := &MaskingPolicyDetails{ Name: row.Name, Signature: []TableColumnSignature{}, - ReturnType: dataType, + ReturnType: LegacyDataTypeFrom(dataType), Body: row.Body, } diff --git a/pkg/sdk/tables_test.go b/pkg/sdk/tables_test.go index bcd807a75f..9c2c32ed7e 100644 --- a/pkg/sdk/tables_test.go +++ b/pkg/sdk/tables_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/acceptance/helpers/random" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" "github.com/stretchr/testify/assert" 
"github.com/stretchr/testify/require" ) @@ -379,7 +380,9 @@ func TestTableCreate(t *testing.T) { tableComment := random.Comment() collation := "de" columnName := "FIRST_COLUMN" - columnType, err := ToDataType("VARCHAR") + columnTypeRaw, err := datatypes.ParseDataType("VARCHAR") + require.NoError(t, err) + columnType := LegacyDataTypeFrom(columnTypeRaw) maskingPolicy := ColumnMaskingPolicy{ Name: randomSchemaObjectIdentifier(), Using: []string{"FOO", "BAR"}, @@ -551,8 +554,9 @@ func TestTableCreateAsSelect(t *testing.T) { t.Run("with complete options", func(t *testing.T) { id := randomSchemaObjectIdentifier() columnName := "FIRST_COLUMN" - columnType, err := ToDataType("VARCHAR") + columnTypeRaw, err := datatypes.ParseDataType("VARCHAR") require.NoError(t, err) + columnType := LegacyDataTypeFrom(columnTypeRaw) maskingPolicy := TableAsSelectColumnMaskingPolicy{ Name: randomSchemaObjectIdentifier(), } diff --git a/pkg/sdk/testint/data_types_integration_test.go b/pkg/sdk/testint/data_types_integration_test.go new file mode 100644 index 0000000000..e62a59ef5f --- /dev/null +++ b/pkg/sdk/testint/data_types_integration_test.go @@ -0,0 +1,349 @@ +package testint + +import ( + "fmt" + "slices" + "testing" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/helpers" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestInt_DataTypes(t *testing.T) { + client := testClient(t) + ctx := testContext(t) + + incorrectBooleanDatatypes := []string{ + "BOOLEAN()", + "BOOLEAN(1)", + "BOOL", + } + incorrectFloatDatatypes := []string{ + "DOUBLE()", + "DOUBLE(1)", + "DOUBLE PRECISION(1)", + } + incorrectlyCorrectFloatDatatypes := []string{ + "FLOAT()", + "FLOAT(20)", + "FLOAT4(20)", + "FLOAT8(20)", + "REAL(20)", + } + incorrectNumberDatatypes := []string{ + "NUMBER()", + "NUMBER(x)", + "INT()", + "NUMBER(36, 5, 7)", + } + incorrectTextDatatypes := []string{ + 
"VARCHAR()", + "VARCHAR(x)", + "VARCHAR(36, 5)", + } + vectorInnerTypesSynonyms := helpers.ConcatSlices(datatypes.AllNumberDataTypes, datatypes.FloatDataTypeSynonyms) + vectorInnerTypeSynonymsThatWork := []string{ + "INTEGER", + "INT", + "FLOAT8", + "FLOAT4", + "FLOAT", + } + + for _, c := range datatypes.ArrayDataTypeSynonyms { + t.Run(fmt.Sprintf("check behavior of array datatype: %s", c), func(t *testing.T) { + sql := fmt.Sprintf("SELECT []::%s", c) + _, err := client.QueryUnsafe(ctx, sql) + assert.NoError(t, err) + + sql = fmt.Sprintf("SELECT []::%s(36)", c) + _, err = client.QueryUnsafe(ctx, sql) + assert.ErrorContains(t, err, "SQL compilation error") + assert.ErrorContains(t, err, "unexpected '36'") + }) + } + + for _, c := range datatypes.BinaryDataTypeSynonyms { + t.Run(fmt.Sprintf("check behavior of binary datatype: %s", c), func(t *testing.T) { + sql := fmt.Sprintf("SELECT TO_BINARY('AB')::%s", c) + _, err := client.QueryUnsafe(ctx, sql) + assert.NoError(t, err) + + sql = fmt.Sprintf("SELECT TO_BINARY('AB')::%s(36)", c) + _, err = client.QueryUnsafe(ctx, sql) + assert.NoError(t, err) + + sql = fmt.Sprintf("SELECT TO_BINARY('AB')::%s(36, 2)", c) + _, err = client.QueryUnsafe(ctx, sql) + assert.ErrorContains(t, err, "SQL compilation error") + assert.ErrorContains(t, err, "','") + assert.ErrorContains(t, err, "')'") + }) + } + + for _, c := range datatypes.BooleanDataTypeSynonyms { + t.Run(fmt.Sprintf("check behavior of boolean datatype: %s", c), func(t *testing.T) { + sql := fmt.Sprintf("SELECT TRUE::%s", c) + _, err := client.QueryUnsafe(ctx, sql) + assert.NoError(t, err) + }) + } + + for _, c := range incorrectBooleanDatatypes { + t.Run(fmt.Sprintf("check behavior of boolean datatype: %s", c), func(t *testing.T) { + sql := fmt.Sprintf("SELECT TRUE::%s", c) + _, err := client.QueryUnsafe(ctx, sql) + require.Error(t, err) + }) + } + + for _, c := range datatypes.DateDataTypeSynonyms { + t.Run(fmt.Sprintf("check behavior of date datatype: %s", c), func(t 
*testing.T) {
+			sql := fmt.Sprintf("SELECT '2024-12-02'::%s", c)
+			_, err := client.QueryUnsafe(ctx, sql)
+			assert.NoError(t, err)
+		})
+	}
+
+	for _, c := range datatypes.FloatDataTypeSynonyms {
+		t.Run(fmt.Sprintf("check behavior of float datatype: %s", c), func(t *testing.T) {
+			sql := fmt.Sprintf("SELECT 1.1::%s", c)
+			_, err := client.QueryUnsafe(ctx, sql)
+			assert.NoError(t, err)
+		})
+	}
+
+	for _, c := range incorrectFloatDatatypes {
+		t.Run(fmt.Sprintf("check behavior of float datatype: %s", c), func(t *testing.T) {
+			sql := fmt.Sprintf("SELECT 1.1::%s", c)
+			_, err := client.QueryUnsafe(ctx, sql)
+			require.Error(t, err)
+		})
+	}
+
+	// There is no attribute documented for float numbers: https://docs.snowflake.com/en/sql-reference/data-types-numeric#float-float4-float8.
+	// However, adding it succeeds for FLOAT, FLOAT4, FLOAT8, and REAL, but it fails both for DOUBLE and DOUBLE PRECISION.
+	for _, c := range incorrectlyCorrectFloatDatatypes {
+		t.Run(fmt.Sprintf("document incorrect behavior of float datatype: %s", c), func(t *testing.T) {
+			sql := fmt.Sprintf("SELECT 1.1::%s", c)
+			_, err := client.QueryUnsafe(ctx, sql)
+			require.NoError(t, err)
+		})
+	}
+
+	// Testing on table creation here because casting (::GEOGRAPHY) was ending with errors (even for the "correct" cases). 
+ for _, c := range datatypes.GeographyDataTypeSynonyms { + t.Run(fmt.Sprintf("check behavior of geography datatype: %s", c), func(t *testing.T) { + tableId := testClientHelper().Ids.RandomSchemaObjectIdentifier() + sql := fmt.Sprintf("CREATE TABLE %s (i %s)", tableId.FullyQualifiedName(), c) + _, err := client.QueryUnsafe(ctx, sql) + assert.NoError(t, err) + t.Cleanup(testClientHelper().Table.DropFunc(t, tableId)) + + tableId = testClientHelper().Ids.RandomSchemaObjectIdentifier() + sql = fmt.Sprintf("CREATE TABLE %s (i %s())", tableId.FullyQualifiedName(), c) + _, err = client.QueryUnsafe(ctx, sql) + assert.ErrorContains(t, err, "SQL compilation error") + assert.ErrorContains(t, err, "unexpected '('") + t.Cleanup(testClientHelper().Table.DropFunc(t, tableId)) + }) + } + + // Testing on table creation here because casting (::GEOMETRY) was ending with errors (even for the "correct" cases). + for _, c := range datatypes.GeometryDataTypeSynonyms { + t.Run(fmt.Sprintf("check behavior of geometry datatype: %s", c), func(t *testing.T) { + tableId := testClientHelper().Ids.RandomSchemaObjectIdentifier() + sql := fmt.Sprintf("CREATE TABLE %s (i %s)", tableId.FullyQualifiedName(), c) + _, err := client.QueryUnsafe(ctx, sql) + assert.NoError(t, err) + t.Cleanup(testClientHelper().Table.DropFunc(t, tableId)) + + tableId = testClientHelper().Ids.RandomSchemaObjectIdentifier() + sql = fmt.Sprintf("CREATE TABLE %s (i %s())", tableId.FullyQualifiedName(), c) + _, err = client.QueryUnsafe(ctx, sql) + assert.ErrorContains(t, err, "SQL compilation error") + assert.ErrorContains(t, err, "unexpected '('") + t.Cleanup(testClientHelper().Table.DropFunc(t, tableId)) + }) + } + + for _, c := range datatypes.NumberDataTypeSynonyms { + t.Run(fmt.Sprintf("check behavior of number datatype: %s", c), func(t *testing.T) { + sql := fmt.Sprintf("SELECT 1::%s", c) + _, err := client.QueryUnsafe(ctx, sql) + assert.NoError(t, err) + + sql = fmt.Sprintf("SELECT 1::%s(36)", c) + _, err = 
client.QueryUnsafe(ctx, sql)
+			assert.NoError(t, err)
+
+			sql = fmt.Sprintf("SELECT 1::%s(36, 5)", c)
+			_, err = client.QueryUnsafe(ctx, sql)
+			assert.NoError(t, err)
+		})
+	}
+
+	for _, c := range datatypes.NumberDataTypeSubTypes {
+		t.Run(fmt.Sprintf("check behavior of number data type subtype: %s", c), func(t *testing.T) {
+			sql := fmt.Sprintf("SELECT 1::%s", c)
+			_, err := client.QueryUnsafe(ctx, sql)
+			assert.NoError(t, err)
+
+			sql = fmt.Sprintf("SELECT 1::%s(36)", c)
+			_, err = client.QueryUnsafe(ctx, sql)
+			assert.ErrorContains(t, err, "SQL compilation error")
+			assert.ErrorContains(t, err, "unexpected '36'")
+		})
+	}
+
+	for _, c := range incorrectNumberDatatypes {
+		t.Run(fmt.Sprintf("check behavior of number datatype: %s", c), func(t *testing.T) {
+			sql := fmt.Sprintf("SELECT 1::%s", c)
+			_, err := client.QueryUnsafe(ctx, sql)
+			require.Error(t, err)
+		})
+	}
+
+	for _, c := range datatypes.ObjectDataTypeSynonyms {
+		t.Run(fmt.Sprintf("check behavior of object data type: %s", c), func(t *testing.T) {
+			sql := fmt.Sprintf("SELECT {}::%s", c)
+			_, err := client.QueryUnsafe(ctx, sql)
+			assert.NoError(t, err)
+
+			sql = fmt.Sprintf("SELECT {}::%s(36)", c)
+			_, err = client.QueryUnsafe(ctx, sql)
+			assert.ErrorContains(t, err, "SQL compilation error")
+			assert.ErrorContains(t, err, "unexpected '36'")
+		})
+	}
+
+	for _, c := range datatypes.AllTextDataTypes {
+		t.Run(fmt.Sprintf("check behavior of text data type: %s", c), func(t *testing.T) {
+			sql := fmt.Sprintf("SELECT 'A'::%s", c)
+			_, err := client.QueryUnsafe(ctx, sql)
+			assert.NoError(t, err)
+
+			sql = fmt.Sprintf("SELECT 'ABC'::%s(36)", c)
+			_, err = client.QueryUnsafe(ctx, sql)
+			assert.NoError(t, err)
+		})
+	}
+
+	for _, c := range incorrectTextDatatypes {
+		t.Run(fmt.Sprintf("check behavior of text datatype: %s", c), func(t *testing.T) {
+			sql := fmt.Sprintf("SELECT 'ABC'::%s", c)
+			_, err := client.QueryUnsafe(ctx, sql)
+			require.Error(t, err)
+		})
+	}
+
+	for _, c := range 
datatypes.TimeDataTypeSynonyms { + t.Run(fmt.Sprintf("check behavior of time data type: %s", c), func(t *testing.T) { + sql := fmt.Sprintf("SELECT '00:00:00'::%s", c) + _, err := client.QueryUnsafe(ctx, sql) + assert.NoError(t, err) + + sql = fmt.Sprintf("SELECT '00:00:00'::%s(5)", c) + _, err = client.QueryUnsafe(ctx, sql) + assert.NoError(t, err) + }) + } + + for _, c := range datatypes.TimestampLtzDataTypeSynonyms { + t.Run(fmt.Sprintf("check behavior of timestamp ltz data types: %s", c), func(t *testing.T) { + sql := fmt.Sprintf("SELECT '2024-12-02 00:00:00 +0000'::%s", c) + _, err := client.QueryUnsafe(ctx, sql) + assert.NoError(t, err) + + sql = fmt.Sprintf("SELECT '2024-12-02 00:00:00 +0000'::%s(3)", c) + _, err = client.QueryUnsafe(ctx, sql) + assert.NoError(t, err) + }) + } + + for _, c := range datatypes.TimestampNtzDataTypeSynonyms { + t.Run(fmt.Sprintf("check behavior of timestamp ntz data types: %s", c), func(t *testing.T) { + sql := fmt.Sprintf("SELECT '2024-12-02 00:00:00 +0000'::%s", c) + _, err := client.QueryUnsafe(ctx, sql) + assert.NoError(t, err) + + sql = fmt.Sprintf("SELECT '2024-12-02 00:00:00 +0000'::%s(3)", c) + _, err = client.QueryUnsafe(ctx, sql) + assert.NoError(t, err) + }) + } + + for _, c := range datatypes.TimestampTzDataTypeSynonyms { + t.Run(fmt.Sprintf("check behavior of timestamp tz data types: %s", c), func(t *testing.T) { + sql := fmt.Sprintf("SELECT '2024-12-02 00:00:00 +0000'::%s", c) + _, err := client.QueryUnsafe(ctx, sql) + assert.NoError(t, err) + + sql = fmt.Sprintf("SELECT '2024-12-02 00:00:00 +0000'::%s(3)", c) + _, err = client.QueryUnsafe(ctx, sql) + assert.NoError(t, err) + }) + } + + for _, c := range datatypes.VariantDataTypeSynonyms { + t.Run(fmt.Sprintf("check behavior of variant data type: %s", c), func(t *testing.T) { + sql := fmt.Sprintf("SELECT TO_VARIANT(1)::%s", c) + _, err := client.QueryUnsafe(ctx, sql) + assert.NoError(t, err) + + sql = fmt.Sprintf("SELECT TO_VARIANT(1)::%s(36)", c) + _, err = 
client.QueryUnsafe(ctx, sql) + assert.ErrorContains(t, err, "SQL compilation error") + assert.ErrorContains(t, err, "unexpected '36'") + }) + } + + // Testing on table creation here because apparently VECTOR is not supported as query in the gosnowflake driver. + // It ends with "unsupported data type" from https://github.com/snowflakedb/gosnowflake/blob/171ddf2540f3a24f2a990e8453dc425ea864a4a0/converter.go#L1599. + for _, c := range datatypes.VectorDataTypeSynonyms { + for _, inner := range datatypes.VectorAllowedInnerTypes { + t.Run(fmt.Sprintf("check behavior of vector data type: %s, %s", c, inner), func(t *testing.T) { + tableId := testClientHelper().Ids.RandomSchemaObjectIdentifier() + sql := fmt.Sprintf("CREATE TABLE %s (i %s(%s, 2))", tableId.FullyQualifiedName(), c, inner) + _, err := client.QueryUnsafe(ctx, sql) + assert.NoError(t, err) + t.Cleanup(testClientHelper().Table.DropFunc(t, tableId)) + + tableId = testClientHelper().Ids.RandomSchemaObjectIdentifier() + sql = fmt.Sprintf("CREATE TABLE %s (i %s(%s))", tableId.FullyQualifiedName(), c, inner) + _, err = client.QueryUnsafe(ctx, sql) + assert.ErrorContains(t, err, "SQL compilation error") + assert.ErrorContains(t, err, "unexpected ')'") + t.Cleanup(testClientHelper().Table.DropFunc(t, tableId)) + }) + } + } + + // Testing on table creation here because apparently VECTOR is not supported as query in the gosnowflake driver. + // It ends with "unsupported data type" from https://github.com/snowflakedb/gosnowflake/blob/171ddf2540f3a24f2a990e8453dc425ea864a4a0/converter.go#L1599. 
+ for _, c := range vectorInnerTypesSynonyms { + t.Run(fmt.Sprintf("document behavior of vector data type synonyms: %s", c), func(t *testing.T) { + tableId := testClientHelper().Ids.RandomSchemaObjectIdentifier() + sql := fmt.Sprintf("CREATE TABLE %s (i VECTOR(%s, 3))", tableId.FullyQualifiedName(), c) + _, err := client.QueryUnsafe(ctx, sql) + if slices.Contains(vectorInnerTypeSynonymsThatWork, c) { + assert.NoError(t, err) + } else { + assert.ErrorContains(t, err, "SQL compilation error") + switch { + case slices.Contains(datatypes.NumberDataTypeSynonyms, c): + assert.ErrorContains(t, err, fmt.Sprintf("unexpected '%s'", c)) + case slices.Contains(datatypes.NumberDataTypeSubTypes, c): + assert.ErrorContains(t, err, "Unsupported vector element type 'NUMBER(38,0)'") + case slices.Contains(datatypes.FloatDataTypeSynonyms, c): + assert.ErrorContains(t, err, "Unsupported vector element type 'FLOAT'") + default: + t.Fail() + } + } + t.Cleanup(testClientHelper().Table.DropFunc(t, tableId)) + }) + } +} diff --git a/pkg/sdk/testint/external_tables_integration_test.go b/pkg/sdk/testint/external_tables_integration_test.go index 26f0e0a720..d614aa10d9 100644 --- a/pkg/sdk/testint/external_tables_integration_test.go +++ b/pkg/sdk/testint/external_tables_integration_test.go @@ -27,7 +27,7 @@ func TestInt_ExternalTables(t *testing.T) { return []*sdk.ExternalTableColumnRequest{ sdk.NewExternalTableColumnRequest("filename", sdk.DataTypeString, "metadata$filename::string"), sdk.NewExternalTableColumnRequest("city", sdk.DataTypeString, "value:city:findname::string"), - sdk.NewExternalTableColumnRequest("time", sdk.DataTypeTimestamp, "to_timestamp(value:time::int)"), + sdk.NewExternalTableColumnRequest("time", sdk.DataTypeTimestampLTZ, "to_timestamp_ltz(value:time::int)"), sdk.NewExternalTableColumnRequest("weather", sdk.DataTypeVariant, "value:weather::variant"), } } diff --git a/pkg/sdk/testint/functions_integration_test.go b/pkg/sdk/testint/functions_integration_test.go index 
1fe0a04bb9..44bb8b898a 100644 --- a/pkg/sdk/testint/functions_integration_test.go +++ b/pkg/sdk/testint/functions_integration_test.go @@ -9,6 +9,7 @@ import ( "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/collections" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -518,7 +519,6 @@ func TestInt_FunctionsShowByID(t *testing.T) { *sdk.NewFunctionArgumentRequest("M", sdk.DataTypeDate), *sdk.NewFunctionArgumentRequest("N", "DATETIME"), *sdk.NewFunctionArgumentRequest("O", sdk.DataTypeTime), - *sdk.NewFunctionArgumentRequest("P", sdk.DataTypeTimestamp), *sdk.NewFunctionArgumentRequest("R", sdk.DataTypeTimestampLTZ), *sdk.NewFunctionArgumentRequest("S", sdk.DataTypeTimestampNTZ), *sdk.NewFunctionArgumentRequest("T", sdk.DataTypeTimestampTZ), @@ -536,14 +536,15 @@ func TestInt_FunctionsShowByID(t *testing.T) { "add", ). WithArguments(args). - WithFunctionDefinition("def add(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, R, S, T, U, V, W, X, Y, Z): A + A"), + WithFunctionDefinition("def add(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, R, S, T, U, V, W, X, Y, Z): A + A"), ) require.NoError(t, err) dataTypes := make([]sdk.DataType, len(args)) for i, arg := range args { - dataTypes[i], err = sdk.ToDataType(string(arg.ArgDataType)) + dataType, err := datatypes.ParseDataType(string(arg.ArgDataType)) require.NoError(t, err) + dataTypes[i] = sdk.LegacyDataTypeFrom(dataType) } idWithArguments := sdk.NewSchemaObjectIdentifierWithArguments(id.DatabaseName(), id.SchemaName(), id.Name(), dataTypes...) 
diff --git a/pkg/sdk/testint/row_access_policies_gen_integration_test.go b/pkg/sdk/testint/row_access_policies_gen_integration_test.go index 4241a7eab1..2833210fe8 100644 --- a/pkg/sdk/testint/row_access_policies_gen_integration_test.go +++ b/pkg/sdk/testint/row_access_policies_gen_integration_test.go @@ -6,6 +6,7 @@ import ( "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/acceptance/helpers/random" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/collections" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -221,7 +222,7 @@ func TestInt_RowAccessPolicies(t *testing.T) { t.Run("describe row access policy: with timestamp data type normalization", func(t *testing.T) { argName := random.AlphaN(5) - argType := sdk.DataTypeTimestamp + argType := sdk.DataTypeTimestampLTZ args := sdk.NewCreateRowAccessPolicyArgsRequest(argName, argType) body := "true" @@ -234,7 +235,7 @@ func TestInt_RowAccessPolicies(t *testing.T) { assertRowAccessPolicyDescription(t, returnedRowAccessPolicyDescription, rowAccessPolicy.ID(), []sdk.TableColumnSignature{{ Name: argName, - Type: sdk.DataTypeTimestampNTZ, + Type: sdk.DataTypeTimestampLTZ, }}, body) }) @@ -317,7 +318,6 @@ func TestInt_RowAccessPoliciesDescribe(t *testing.T) { *sdk.NewCreateRowAccessPolicyArgsRequest("M", sdk.DataTypeDate), *sdk.NewCreateRowAccessPolicyArgsRequest("N", "DATETIME"), *sdk.NewCreateRowAccessPolicyArgsRequest("O", sdk.DataTypeTime), - *sdk.NewCreateRowAccessPolicyArgsRequest("P", sdk.DataTypeTimestamp), *sdk.NewCreateRowAccessPolicyArgsRequest("R", sdk.DataTypeTimestampLTZ), *sdk.NewCreateRowAccessPolicyArgsRequest("S", sdk.DataTypeTimestampNTZ), *sdk.NewCreateRowAccessPolicyArgsRequest("T", sdk.DataTypeTimestampTZ), @@ -342,11 +342,11 @@ func TestInt_RowAccessPoliciesDescribe(t *testing.T) { require.NoError(t, 
err) wantArgs := make([]sdk.TableColumnSignature, len(args)) for i, arg := range args { - dataType, err := sdk.ToDataType(string(arg.Type)) + dataType, err := datatypes.ParseDataType(string(arg.Type)) require.NoError(t, err) wantArgs[i] = sdk.TableColumnSignature{ Name: arg.Name, - Type: dataType, + Type: sdk.LegacyDataTypeFrom(dataType), } } assert.Equal(t, wantArgs, policyDetails.Signature) diff --git a/pkg/sdk/testint/tables_integration_test.go b/pkg/sdk/testint/tables_integration_test.go index 19d7e8732f..7678ea5c36 100644 --- a/pkg/sdk/testint/tables_integration_test.go +++ b/pkg/sdk/testint/tables_integration_test.go @@ -13,6 +13,7 @@ import ( "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/collections" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/internal/snowflakeroles" "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk" + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -43,9 +44,9 @@ func TestInt_Table(t *testing.T) { require.Len(t, createdColumns, len(expectedColumns)) for i, expectedColumn := range expectedColumns { assert.Equal(t, strings.ToUpper(expectedColumn.Name), createdColumns[i].ColumnName) - createdColumnDataType, err := sdk.ToDataType(createdColumns[i].DataType) + createdColumnDataType, err := datatypes.ParseDataType(createdColumns[i].DataType) assert.NoError(t, err) - assert.Equal(t, expectedColumn.Type, createdColumnDataType) + assert.Equal(t, expectedColumn.Type, sdk.LegacyDataTypeFrom(createdColumnDataType)) } } diff --git a/pkg/sdk/testint/warehouses_integration_test.go b/pkg/sdk/testint/warehouses_integration_test.go index 659705f775..998bca7c09 100644 --- a/pkg/sdk/testint/warehouses_integration_test.go +++ b/pkg/sdk/testint/warehouses_integration_test.go @@ -616,7 +616,7 @@ func TestInt_Warehouses(t *testing.T) { require.NoError(t, err) assert.Equal(t, 1, result.Running) 
assert.Equal(t, 0, result.Queued) - assert.Equal(t, sdk.WarehouseStateSuspended, result.State) + assert.Eventually(t, func() bool { return sdk.WarehouseStateSuspended == result.State }, 5*time.Second, time.Second) }) t.Run("alter: resize with a long running-query", func(t *testing.T) { diff --git a/pkg/sdk/validations.go b/pkg/sdk/validations.go index ada355f2d2..d8199f2d24 100644 --- a/pkg/sdk/validations.go +++ b/pkg/sdk/validations.go @@ -2,10 +2,12 @@ package sdk import ( "reflect" + + "github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/datatypes" ) func IsValidDataType(v string) bool { - _, err := ToDataType(v) + _, err := datatypes.ParseDataType(v) return err == nil }