diff --git a/Makefile b/Makefile index 1778634ff7..a53aa5918a 100644 --- a/Makefile +++ b/Makefile @@ -55,7 +55,7 @@ coverage: ## run the go coverage tool, reading file coverage.out .PHONY: coverage test: fmt deps ## run the tests - go test -race -coverprofile=coverage.txt -covermode=atomic ./... + go test -race -coverprofile=coverage.txt -covermode=atomic $(TESTARGS) ./... .PHONY: test test-acceptance: fmt deps ## runs all tests, including the acceptance tests which create and destroys real resources @@ -79,6 +79,10 @@ install-tf: build ## installs plugin where terraform can find it cp ./$(BASE_BINARY_NAME) $(HOME)/.terraform.d/plugins/$(BASE_BINARY_NAME) .PHONY: install-tf +uninstall-tf: build ## uninstalls plugin from where terraform can find it + rm $(HOME)/.terraform.d/plugins/$(BASE_BINARY_NAME) 2>/dev/null +.PHONY: install-tf + help: ## display help for this makefile @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' .PHONY: help diff --git a/pkg/resources/view.go b/pkg/resources/view.go index 4e709e242b..5733c16f30 100644 --- a/pkg/resources/view.go +++ b/pkg/resources/view.go @@ -160,9 +160,13 @@ func ReadView(data *schema.ResourceData, meta interface{}) error { } // Want to only capture the Select part of the query because before that is the Create part of the view which we no longer care about - cleanString := space.ReplaceAllString(text.String, " ") - indexOfSelect := strings.Index(strings.ToUpper(cleanString), " AS SELECT") - substringOfQuery := cleanString[indexOfSelect+4:] + + extractor := snowflake.NewViewSelectStatementExtractor(text.String) + substringOfQuery, err := extractor.Extract() + if err != nil { + return err + } + err = data.Set("statement", substringOfQuery) if err != nil { return err diff --git a/pkg/snowflake/parser.go b/pkg/snowflake/parser.go new file mode 100644 index 0000000000..af97bf3ded --- /dev/null +++ b/pkg/snowflake/parser.go @@ -0,0 +1,139 @@ +package snowflake + +import ( + "fmt" + "strings" + "unicode" +) + +// ViewSelectStatementExtractor is a simplistic parser that only exists to extract the select +// statement from a create view statement +// +// The implementation is optimized for undertandable and predictable behavior. So far we only seek +// to support queries of the sort that are generated by this project. +// +// Also there is little error handling and we assume queries are well-formed. +type ViewSelectStatementExtractor struct { + input []rune + pos int +} + +func NewViewSelectStatementExtractor(input string) *ViewSelectStatementExtractor { + return &ViewSelectStatementExtractor{ + input: []rune(input), + } +} + +func (e *ViewSelectStatementExtractor) Extract() (string, error) { + fmt.Printf("[DEBUG] extracting view query %s\n", string(e.input)) + e.consumeSpace() + e.consumeToken("create") + e.consumeSpace() + e.consumeToken("or replace") + e.consumeSpace() + e.consumeToken("secure") + e.consumeSpace() + e.consumeToken("recursive") + e.consumeSpace() + e.consumeToken("view") + e.consumeSpace() + e.consumeToken("if not exists") + e.consumeSpace() + e.consumeIdentifier() + // TODO column list + // TODO copy grants + e.consumeComment() + e.consumeSpace() + e.consumeComment() + e.consumeSpace() + e.consumeToken("as") + e.consumeSpace() + + return string(e.input[e.pos:]), nil +} + +// consumeToken will move e.pos forward iff the token is the next part of the input. Comparison is +// case-insensitive. Will return true if consumed. +func (e *ViewSelectStatementExtractor) consumeToken(t string) bool { + found := 0 + for i, r := range t { + // it is annoying that we have to convert the runes back to strings to do a case-insensitive + // comparison. Hopefully I am just missing something in the docs. + if e.pos+i > len(e.input) || !strings.EqualFold(string(r), string(e.input[e.pos+i])) { + break + } + found += 1 + } + + if found == len(t) { + e.pos += len(t) + return true + } + return false +} + +func (e *ViewSelectStatementExtractor) consumeSpace() { + found := 0 + for { + if e.pos+found > len(e.input)-1 || !unicode.IsSpace(e.input[e.pos+found]) { + break + } + found += 1 + } + e.pos += found +} + +func (e *ViewSelectStatementExtractor) consumeIdentifier() { + e.consumeNonSpace() +} + +func (e *ViewSelectStatementExtractor) consumeNonSpace() { + found := 0 + for { + if e.pos+found > len(e.input)-1 || unicode.IsSpace(e.input[e.pos+found]) { + break + } + found += 1 + } + e.pos += found +} + +func (e *ViewSelectStatementExtractor) consumeComment() { + if c := e.consumeToken("comment"); !c { + return + } + + e.consumeSpace() + + if c := e.consumeToken("="); !c { + return + } + + e.consumeSpace() + + if c := e.consumeToken("'"); !c { + return + } + + found := 0 + escaped := false + for { + if e.pos+found > len(e.input)-1 { + break + } + + if escaped { + escaped = false + } else if e.input[e.pos+found] == '\\' { + escaped = true + } else if e.input[e.pos+found] == '\'' { + break + } + found += 1 + } + e.pos += found + + if c := e.consumeToken("'"); !c { + return + } +} diff --git a/pkg/snowflake/parser_internal_test.go b/pkg/snowflake/parser_internal_test.go new file mode 100644 index 0000000000..c2a7a09d9b --- /dev/null +++ b/pkg/snowflake/parser_internal_test.go @@ -0,0 +1,3 @@ +package snowflake + +// Internal tests for ViewSelectStatementExtractor diff --git a/pkg/snowflake/parser_test.go b/pkg/snowflake/parser_test.go new file mode 100644 index 0000000000..9164a40210 --- /dev/null +++ b/pkg/snowflake/parser_test.go @@ -0,0 +1,161 @@ +package snowflake + +import ( + "fmt" + "testing" +) + +func TestViewSelectStatementExtractor_Extract(t *testing.T) { + basic := "create view foo as select * from bar;" + caps := "CREATE VIEW FOO AS SELECT * FROM BAR;" + parens := "create view foo as (select * from bar);" + multiline := ` +create view foo as +select * +from bar;` + + multilineComment := ` +create view foo as +-- comment +select * +from bar;` + + secure := "create secure view foo as select * from bar;" + replace := "create or replace view foo as select * from bar;" + recursive := "create recursive view foo as select * from bar;" + ine := "create view if not exists foo as select * from bar;" + + comment := `create view foo comment='asdf' as select * from bar;` + commentEscape := `create view foo comment='asdf\'s are fun' as select * from bar;` + identifier := `create view "foo"."bar"."bam" comment='asdf\'s are fun' as select * from bar;` + + full := `CREATE SECURE VIEW "rgdxfmnfhh"."PUBLIC"."rgdxfmnfhh" COMMENT = 'Terraform test resource' AS SELECT ROLE_NAME, ROLE_OWNER FROM INFORMATION_SCHEMA.APPLICABLE_ROLES` + + type args struct { + input string + } + tests := []struct { + name string + args args + want string + wantErr bool + }{ + {"basic", args{basic}, "select * from bar;", false}, + {"caps", args{caps}, "SELECT * FROM BAR;", false}, + {"parens", args{parens}, "(select * from bar);", false}, + {"multiline", args{multiline}, "select *\nfrom bar;", false}, + {"multilineComment", args{multilineComment}, "-- comment\nselect *\nfrom bar;", false}, + {"secure", args{secure}, "select * from bar;", false}, + {"replace", args{replace}, "select * from bar;", false}, + {"recursive", args{recursive}, "select * from bar;", false}, + {"ine", args{ine}, "select * from bar;", false}, + {"comment", args{comment}, "select * from bar;", false}, + {"commentEscape", args{commentEscape}, "select * from bar;", false}, + {"identifier", args{identifier}, "select * from bar;", false}, + {"full", args{full}, "SELECT ROLE_NAME, ROLE_OWNER FROM INFORMATION_SCHEMA.APPLICABLE_ROLES", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := NewViewSelectStatementExtractor(tt.args.input) + got, err := e.Extract() + if (err != nil) != tt.wantErr { + t.Errorf("ViewSelectStatementExtractor.Extract() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("ViewSelectStatementExtractor.Extract() = '%v', want '%v'", got, tt.want) + } + }) + } +} + +func TestViewSelectStatementExtractor_consumeToken(t *testing.T) { + type fields struct { + input []rune + pos int + } + type args struct { + t string + } + tests := []struct { + name string + fields fields + args args + posAfter int + }{ + {"basic - found", fields{[]rune("foo"), 0}, args{"foo"}, 3}, + {"basic - not found", fields{[]rune("foo"), 0}, args{"bar"}, 0}, + {"basic - not found", fields{[]rune("fob"), 0}, args{"foo"}, 0}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := &ViewSelectStatementExtractor{ + input: tt.fields.input, + pos: tt.fields.pos, + } + e.consumeToken(tt.args.t) + + if e.pos != tt.posAfter { + t.Errorf("pos after = %v, want %v", e.pos, tt.posAfter) + } + }) + } +} + +func TestViewSelectStatementExtractor_consumeSpace(t *testing.T) { + type fields struct { + input []rune + pos int + } + tests := []struct { + name string + fields fields + posAfter int + }{ + {"simple", fields{[]rune(" foo"), 0}, 3}, + {"empty", fields{[]rune(""), 0}, 0}, + {"middle", fields{[]rune("foo \t\n bar"), 3}, 7}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + fmt.Println(tt.name) + e := &ViewSelectStatementExtractor{ + input: tt.fields.input, + pos: tt.fields.pos, + } + e.consumeSpace() + + if e.pos != tt.posAfter { + t.Errorf("pos after = %v, want %v", e.pos, tt.posAfter) + } + }) + } +} + +func TestViewSelectStatementExtractor_consumeComment(t *testing.T) { + type fields struct { + input []rune + pos int + } + tests := []struct { + name string + fields fields + posAfter int + }{ + {"basic", fields{[]rune("comment='foo'"), 0}, 13}, + {"escaped", fields{[]rune(`comment='fo\'o'`), 0}, 15}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := &ViewSelectStatementExtractor{ + input: tt.fields.input, + pos: tt.fields.pos, + } + e.consumeComment() + + if e.pos != tt.posAfter { + t.Errorf("pos after = %v, want %v", e.pos, tt.posAfter) + } + }) + } +}