diff --git a/pkg/sdk/poc/generator/interface.go b/pkg/sdk/poc/generator/interface.go index 98c213367e..faffb3dabf 100644 --- a/pkg/sdk/poc/generator/interface.go +++ b/pkg/sdk/poc/generator/interface.go @@ -1,5 +1,31 @@ package generator +import "fmt" + +type objectIdentifierKind string + +const ( + AccountObjectIdentifier objectIdentifierKind = "AccountObjectIdentifier" + DatabaseObjectIdentifier objectIdentifierKind = "DatabaseObjectIdentifier" + SchemaObjectIdentifier objectIdentifierKind = "SchemaObjectIdentifier" + SchemaObjectIdentifierWithArguments objectIdentifierKind = "SchemaObjectIdentifierWithArguments" +) + +func toObjectIdentifierKind(s string) (objectIdentifierKind, error) { + switch s { + case "AccountObjectIdentifier": + return AccountObjectIdentifier, nil + case "DatabaseObjectIdentifier": + return DatabaseObjectIdentifier, nil + case "SchemaObjectIdentifier": + return SchemaObjectIdentifier, nil + case "SchemaObjectIdentifierWithArguments": + return SchemaObjectIdentifierWithArguments, nil + default: + return "", fmt.Errorf("invalid string identifier type: %s", s) + } +} + // Interface groups operations for particular object or objects family (e.g. DATABASE ROLE) type Interface struct { // Name is the interface's name, e.g. "DatabaseRoles" @@ -28,6 +54,5 @@ func (i *Interface) NameLowerCased() string { // ObjectIdentifierKind returns the level of the object identifier (e.g. for DatabaseObjectIdentifier, it returns the prefix "Database") func (i *Interface) ObjectIdentifierPrefix() idPrefix { - // return strings.Replace(i.IdentifierKind, "ObjectIdentifier", "", 1) return identifierStringToPrefix(i.IdentifierKind) } diff --git a/pkg/sdk/poc/generator/operation.go b/pkg/sdk/poc/generator/operation.go index dfc20ca034..07a7b25641 100644 --- a/pkg/sdk/poc/generator/operation.go +++ b/pkg/sdk/poc/generator/operation.go @@ -79,6 +79,11 @@ func (s *Operation) withHelperStructs(helperStructs ...*Field) *Operation { return s } +func (s *Operation) withObjectInterface(objectInterface *Interface) *Operation { + s.ObjectInterface = objectInterface + return s +} + func addShowMapping(op *Operation, from, to *Field) { op.ShowMapping = newMapping("convert", from, to) } @@ -170,9 +175,9 @@ func (i *Interface) ShowByIdOperationWithNoFiltering() *Interface { // ShowByIdOperationWithFiltering adds a ShowByID operation to the interface with filtering. Should be used for objects that implement filtering options e.g. Like or In. func (i *Interface) ShowByIdOperationWithFiltering(filter ShowByIDFilteringKind, filtering ...ShowByIDFilteringKind) *Interface { - op := newNoSqlOperation(string(OperationKindShowByID)) - op.ObjectInterface = i - op.withFiltering(append([]ShowByIDFilteringKind{filter}, filtering...)...) + op := newNoSqlOperation(string(OperationKindShowByID)). + withObjectInterface(i). + withFiltering(append(filtering, filter)...) i.Operations = append(i.Operations, op) return i } diff --git a/pkg/sdk/poc/generator/show_object_methods.go b/pkg/sdk/poc/generator/show_object_methods.go new file mode 100644 index 0000000000..0b263de958 --- /dev/null +++ b/pkg/sdk/poc/generator/show_object_methods.go @@ -0,0 +1,60 @@ +package generator + +import ( + "log" + "slices" +) + +type ShowObjectIdMethod struct { + StructName string + IdentifierKind objectIdentifierKind + Args []string +} + +func newShowObjectIDMethod(structName string, idType objectIdentifierKind) *ShowObjectIdMethod { + return &ShowObjectIdMethod{ + StructName: structName, + IdentifierKind: idType, + Args: idTypeParts[idType], + } +} + +var idTypeParts map[objectIdentifierKind][]string = map[objectIdentifierKind][]string{ + AccountObjectIdentifier: {"Name"}, + DatabaseObjectIdentifier: {"DatabaseName", "Name"}, + SchemaObjectIdentifier: {"DatabaseName", "SchemaName", "Name"}, +} + +func checkRequiredFieldsForIdMethod(structName string, helperStructs []*Field, idKind objectIdentifierKind) bool { + if requiredFields, ok := idTypeParts[idKind]; ok { + for _, field := range helperStructs { + if field.Name == structName { + return containsFieldNames(field.Fields, requiredFields...) + } + } + } + log.Printf("[WARN] no required fields mapping defined for identifier %s", idKind) + return false +} + +func containsFieldNames(fields []*Field, names ...string) bool { + fieldNames := []string{} + for _, field := range fields { + fieldNames = append(fieldNames, field.Name) + } + + for _, name := range names { + if !slices.Contains(fieldNames, name) { + return false + } + } + return true +} + +type ShowObjectTypeMethod struct { + StructName string +} + +func newShowObjectTypeMethod(structName string) *ShowObjectTypeMethod { + return &ShowObjectTypeMethod{StructName: structName} +} diff --git a/pkg/sdk/poc/generator/show_object_methods_test.go b/pkg/sdk/poc/generator/show_object_methods_test.go new file mode 100644 index 0000000000..806834bdd0 --- /dev/null +++ b/pkg/sdk/poc/generator/show_object_methods_test.go @@ -0,0 +1,50 @@ +package generator + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestIdentifierStringToObjectIdentifier(t *testing.T) { + tests := []struct { + input string + expected objectIdentifierKind + }{ + {"AccountObjectIdentifier", AccountObjectIdentifier}, + {"DatabaseObjectIdentifier", DatabaseObjectIdentifier}, + {"SchemaObjectIdentifier", SchemaObjectIdentifier}, + {"SchemaObjectIdentifierWithArguments", SchemaObjectIdentifierWithArguments}, + } + + for _, test := range tests { + t.Run(test.input, func(t *testing.T) { + result, err := toObjectIdentifierKind(test.input) + require.NoError(t, err) + require.Equal(t, test.expected, result) + }) + } +} + +func TestIdentifierStringToObjectIdentifier_Invalid(t *testing.T) { + tests := []struct { + input string + err string + }{ + {"accountobjectidentifier", "invalid string identifier type: accountobjectidentifier"}, + {"Account", "invalid string identifier type: Account"}, + {"databaseobjectidentifier", "invalid string identifier type: databaseobjectidentifier"}, + {"Database", "invalid string identifier type: Database"}, + {"schemaobjectidentifier", "invalid string identifier type: schemaobjectidentifier"}, + {"Schema", "invalid string identifier type: Schema"}, + {"schemaobjectidentifierwitharguments", "invalid string identifier type: schemaobjectidentifierwitharguments"}, + {"schemawitharguemnts", "invalid string identifier type: schemawitharguemnts"}, + } + + for _, tc := range tests { + t.Run(tc.input, func(t *testing.T) { + _, err := toObjectIdentifierKind(tc.input) + require.ErrorContains(t, err, tc.err) + }) + } +} diff --git a/pkg/sdk/poc/generator/template_executors.go b/pkg/sdk/poc/generator/template_executors.go index a2fba04cc7..d2d5455055 100644 --- a/pkg/sdk/poc/generator/template_executors.go +++ b/pkg/sdk/poc/generator/template_executors.go @@ -21,9 +21,29 @@ func GenerateInterface(writer io.Writer, def *Interface) { if o.OptsField != nil { generateOptionsStruct(writer, o) } + + if o.Name == string(OperationKindShow) { + idKind, err := toObjectIdentifierKind(def.IdentifierKind) + if err != nil { + log.Printf("[WARN] for showObjectIdMethod: %v", err) + } + if checkRequiredFieldsForIdMethod(def.NameSingular, o.HelperStructs, idKind) { + generateShowObjectIdMethod(writer, newShowObjectIDMethod(def.NameSingular, idKind)) + } + + generateShowObjectTypeMethod(writer, newShowObjectTypeMethod(def.NameSingular)) + } } } +func generateShowObjectIdMethod(writer io.Writer, m *ShowObjectIdMethod) { + printTo(writer, ShowObjectIdMethodTemplate, m) +} + +func generateShowObjectTypeMethod(writer io.Writer, m *ShowObjectTypeMethod) { + printTo(writer, ShowObjectTypeMethodTemplate, m) +} + func generateOptionsStruct(writer io.Writer, operation *Operation) { printTo(writer, OperationStructTemplate, operation) diff --git a/pkg/sdk/poc/generator/templates.go b/pkg/sdk/poc/generator/templates.go index 8e0188bcb2..7e50da1777 100644 --- a/pkg/sdk/poc/generator/templates.go +++ b/pkg/sdk/poc/generator/templates.go @@ -24,6 +24,18 @@ var ( structTemplateContent string StructTemplate, _ = template.New("structTemplate").Parse(structTemplateContent) + //go:embed templates/show_object_method.tmpl + showObjectMethodTemplateContent string + ShowObjectMethodTemplate, _ = template.New("helperMethodTemplate").Parse(showObjectMethodTemplateContent) + + //go:embed templates/show_object_id_method.tmpl + showObjectIdMethodTemplateContent string + ShowObjectIdMethodTemplate, _ = template.New("showObjectIdMethodTemplate").Parse(showObjectIdMethodTemplateContent) + + //go:embed templates/show_object_type_method.tmpl + showObjectTypeMethodTemplateContent string + ShowObjectTypeMethodTemplate, _ = template.New("showObjectTypeMethodTemplate").Parse(showObjectTypeMethodTemplateContent) + //go:embed templates/dto_declarations.tmpl dtoDeclarationsTemplateContent string DtoTemplate, _ = template.New("dtoTemplate").Parse(dtoDeclarationsTemplateContent) diff --git a/pkg/sdk/poc/generator/templates/show_object_id_method.tmpl b/pkg/sdk/poc/generator/templates/show_object_id_method.tmpl new file mode 100644 index 0000000000..5c8aedf177 --- /dev/null +++ b/pkg/sdk/poc/generator/templates/show_object_id_method.tmpl @@ -0,0 +1,5 @@ +{{- /*gotype: github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/poc/generator.ShowObjectIdMethod*/ -}} + +func (v *{{ .StructName }}) ID() {{ .IdentifierKind }} { + return New{{ .IdentifierKind }}({{ range .Args }}v.{{ . }}, {{ end }}) +} diff --git a/pkg/sdk/poc/generator/templates/show_object_method.tmpl b/pkg/sdk/poc/generator/templates/show_object_method.tmpl new file mode 100644 index 0000000000..1447d2e2a2 --- /dev/null +++ b/pkg/sdk/poc/generator/templates/show_object_method.tmpl @@ -0,0 +1,6 @@ +{{- /*gotype: github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/poc/generator.ShowObjectMethod*/ -}} + +func (v *{{ .StructName }}) ObjectType() ObjectType { + return ObjectType{{ .StructName }} +} + diff --git a/pkg/sdk/poc/generator/templates/show_object_type_method.tmpl b/pkg/sdk/poc/generator/templates/show_object_type_method.tmpl new file mode 100644 index 0000000000..3b1f15d353 --- /dev/null +++ b/pkg/sdk/poc/generator/templates/show_object_type_method.tmpl @@ -0,0 +1,5 @@ +{{- /*gotype: github.com/Snowflake-Labs/terraform-provider-snowflake/pkg/sdk/poc/generator.ShowObjectTypeMethod*/ -}} + +func (v *{{ .StructName }}) ObjectType() ObjectType { + return ObjectType{{ .StructName }} +} diff --git a/pkg/sdk/secrets_gen.go b/pkg/sdk/secrets_gen.go index 30d467f25e..17e2f0a110 100644 --- a/pkg/sdk/secrets_gen.go +++ b/pkg/sdk/secrets_gen.go @@ -30,9 +30,11 @@ type CreateWithOAuthClientCredentialsFlowSecretOptions struct { OauthScopes *OauthScopesList `ddl:"parameter,parentheses" sql:"OAUTH_SCOPES"` Comment *string `ddl:"parameter,single_quotes" sql:"COMMENT"` } + type ApiIntegrationScope struct { Scope string `ddl:"keyword,single_quotes"` } + type OauthScopesList struct { OauthScopesList []ApiIntegrationScope `ddl:"list,must_parentheses"` } @@ -85,30 +87,37 @@ type AlterSecretOptions struct { Set *SecretSet `ddl:"keyword" sql:"SET"` Unset *SecretUnset `ddl:"keyword"` } + type SecretSet struct { Comment *string `ddl:"parameter,single_quotes" sql:"COMMENT"` SetForFlow *SetForFlow `ddl:"keyword"` } + type SetForFlow struct { SetForOAuthClientCredentials *SetForOAuthClientCredentials `ddl:"keyword"` SetForOAuthAuthorization *SetForOAuthAuthorization `ddl:"keyword"` SetForBasicAuthentication *SetForBasicAuthentication `ddl:"keyword"` SetForGenericString *SetForGenericString `ddl:"keyword"` } + type SetForOAuthClientCredentials struct { OauthScopes *OauthScopesList `ddl:"parameter,parentheses" sql:"OAUTH_SCOPES"` } + type SetForOAuthAuthorization struct { OauthRefreshToken *string `ddl:"parameter,single_quotes" sql:"OAUTH_REFRESH_TOKEN"` OauthRefreshTokenExpiryTime *string `ddl:"parameter,single_quotes" sql:"OAUTH_REFRESH_TOKEN_EXPIRY_TIME"` } + type SetForBasicAuthentication struct { Username *string `ddl:"parameter,single_quotes" sql:"USERNAME"` Password *string `ddl:"parameter,single_quotes" sql:"PASSWORD"` } + type SetForGenericString struct { SecretString *string `ddl:"parameter,single_quotes" sql:"SECRET_STRING"` } + type SecretUnset struct { Comment *bool `ddl:"keyword" sql:"SET COMMENT = NULL"` } @@ -128,6 +137,7 @@ type ShowSecretOptions struct { Like *Like `ddl:"keyword" sql:"LIKE"` In *ExtendedIn `ddl:"keyword" sql:"IN"` } + type secretDBRow struct { CreatedOn time.Time `db:"created_on"` Name string `db:"name"` @@ -139,6 +149,7 @@ type secretDBRow struct { OauthScopes sql.NullString `db:"oauth_scopes"` OwnerRoleType string `db:"owner_role_type"` } + type Secret struct { CreatedOn time.Time Name string @@ -151,11 +162,11 @@ type Secret struct { OwnerRoleType string } -func (s *Secret) ID() SchemaObjectIdentifier { - return NewSchemaObjectIdentifier(s.DatabaseName, s.SchemaName, s.Name) +func (v *Secret) ID() SchemaObjectIdentifier { + return NewSchemaObjectIdentifier(v.DatabaseName, v.SchemaName, v.Name) } -func (s *Secret) ObjectType() ObjectType { +func (v *Secret) ObjectType() ObjectType { return ObjectTypeSecret } diff --git a/pkg/sdk/streamlits_gen.go b/pkg/sdk/streamlits_gen.go index ce08a1c4c2..8f90b6a6a4 100644 --- a/pkg/sdk/streamlits_gen.go +++ b/pkg/sdk/streamlits_gen.go @@ -28,6 +28,7 @@ type CreateStreamlitOptions struct { Title *string `ddl:"parameter,single_quotes" sql:"TITLE"` Comment *string `ddl:"parameter,single_quotes" sql:"COMMENT"` } + type ExternalAccessIntegrations struct { ExternalAccessIntegrations []AccountObjectIdentifier `ddl:"list,must_parentheses"` } @@ -42,6 +43,7 @@ type AlterStreamlitOptions struct { Unset *StreamlitUnset `ddl:"list,no_parentheses" sql:"UNSET"` RenameTo *SchemaObjectIdentifier `ddl:"identifier" sql:"RENAME TO"` } + type StreamlitSet struct { RootLocation *string `ddl:"parameter,single_quotes" sql:"ROOT_LOCATION"` MainFile *string `ddl:"parameter,single_quotes" sql:"MAIN_FILE"` @@ -50,6 +52,7 @@ type StreamlitSet struct { Comment *string `ddl:"parameter,single_quotes" sql:"COMMENT"` Title *string `ddl:"parameter,single_quotes" sql:"TITLE"` } + type StreamlitUnset struct { QueryWarehouse *bool `ddl:"keyword" sql:"QUERY_WAREHOUSE"` Comment *bool `ddl:"keyword" sql:"COMMENT"` @@ -73,6 +76,7 @@ type ShowStreamlitOptions struct { In *In `ddl:"keyword" sql:"IN"` Limit *LimitFrom `ddl:"keyword" sql:"LIMIT"` } + type streamlitsRow struct { CreatedOn string `db:"created_on"` Name string `db:"name"` @@ -85,6 +89,7 @@ type streamlitsRow struct { UrlId string `db:"url_id"` OwnerRoleType string `db:"owner_role_type"` } + type Streamlit struct { CreatedOn string Name string @@ -98,12 +103,17 @@ type Streamlit struct { OwnerRoleType string } +func (v *Streamlit) ID() SchemaObjectIdentifier { + return NewSchemaObjectIdentifier(v.DatabaseName, v.SchemaName, v.Name) +} + // DescribeStreamlitOptions is based on https://docs.snowflake.com/en/sql-reference/sql/desc-streamlit. type DescribeStreamlitOptions struct { describe bool `ddl:"static" sql:"DESCRIBE"` streamlit bool `ddl:"static" sql:"STREAMLIT"` name SchemaObjectIdentifier `ddl:"identifier"` } + type streamlitsDetailRow struct { Name string `db:"name"` Title sql.NullString `db:"title"` @@ -117,6 +127,7 @@ type streamlitsDetailRow struct { ExternalAccessIntegrations string `db:"external_access_integrations"` ExternalAccessSecrets string `db:"external_access_secrets"` } + type StreamlitDetail struct { Name string Title string @@ -130,7 +141,3 @@ type StreamlitDetail struct { ExternalAccessIntegrations []string ExternalAccessSecrets string } - -func (s *Streamlit) ID() SchemaObjectIdentifier { - return NewSchemaObjectIdentifier(s.DatabaseName, s.SchemaName, s.Name) -}