Skip to content

Commit

Permalink
[fix] a more robust parser for views (Snowflake-Labs#185)
Browse files Browse the repository at this point in the history
[fix] a more robust parser for views

In order to extract the select statement from views, we need a more robust parser.

So far it can parse all query shapes generated by this project and many others, which are relevant since they can come in via imports.

## Test Plan
* [x] acceptance tests
* [x] tested on our internal snowflake infra repo

## References
*  https://docs.snowflake.com/en/sql-reference/sql/create-view.html
  • Loading branch information
ryanking authored May 1, 2020
1 parent 5b558c8 commit 58f5f3d
Show file tree
Hide file tree
Showing 5 changed files with 315 additions and 4 deletions.
6 changes: 5 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ coverage: ## run the go coverage tool, reading file coverage.out
.PHONY: coverage

test: fmt deps ## run the tests
go test -race -coverprofile=coverage.txt -covermode=atomic ./...
go test -race -coverprofile=coverage.txt -covermode=atomic $(TESTARGS) ./...
.PHONY: test

test-acceptance: fmt deps ## runs all tests, including the acceptance tests which create and destroys real resources
Expand All @@ -79,6 +79,10 @@ install-tf: build ## installs plugin where terraform can find it
cp ./$(BASE_BINARY_NAME) $(HOME)/.terraform.d/plugins/$(BASE_BINARY_NAME)
.PHONY: install-tf

# Removes the plugin binary copied by install-tf. `rm -f` keeps the target successful when the
# binary is already absent while still reporting real failures (e.g. permission errors).
uninstall-tf: build ## uninstalls plugin from where terraform can find it
	rm -f $(HOME)/.terraform.d/plugins/$(BASE_BINARY_NAME)
.PHONY: uninstall-tf

# Self-documenting help: grep selects every "target: ... ## description" line from the loaded
# makefiles, the lines are sorted, and awk splits each one on the ":...## " separator to print the
# target name (cyan, padded to 30 columns) followed by its description.
help: ## display help for this makefile
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
.PHONY: help
Expand Down
10 changes: 7 additions & 3 deletions pkg/resources/view.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,13 @@ func ReadView(data *schema.ResourceData, meta interface{}) error {
}

// Want to only capture the Select part of the query because before that is the Create part of the view which we no longer care about
cleanString := space.ReplaceAllString(text.String, " ")
indexOfSelect := strings.Index(strings.ToUpper(cleanString), " AS SELECT")
substringOfQuery := cleanString[indexOfSelect+4:]

extractor := snowflake.NewViewSelectStatementExtractor(text.String)
substringOfQuery, err := extractor.Extract()
if err != nil {
return err
}

err = data.Set("statement", substringOfQuery)
if err != nil {
return err
Expand Down
139 changes: 139 additions & 0 deletions pkg/snowflake/parser.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
package snowflake

import (
"fmt"
"strings"
"unicode"
)

// ViewSelectStatementExtractor is a simplistic parser that only exists to extract the select
// statement from a create view statement.
//
// The implementation is optimized for understandable and predictable behavior. So far we only seek
// to support queries of the sort that are generated by this project.
//
// Also there is little error handling and we assume queries are well-formed.
type ViewSelectStatementExtractor struct {
// input is the statement being parsed, held as runes so that positions index characters rather
// than bytes.
input []rune
// pos is the index into input of the next rune that has not been consumed yet.
pos int
}

// NewViewSelectStatementExtractor returns an extractor positioned at the start of input. The
// statement is converted to a rune slice once, up front, so the parser can index it by character.
func NewViewSelectStatementExtractor(input string) *ViewSelectStatementExtractor {
	extractor := new(ViewSelectStatementExtractor)
	extractor.input = []rune(input)
	return extractor
}

// Extract walks over the leading part of a create view statement and returns everything that
// follows it, i.e. the query after the "as" keyword.
//
// Every element of the preamble is attempted in a fixed order and is optional: an element that is
// not present simply leaves the position untouched. Nothing is validated, so the returned error is
// always nil; for input that does not look like a create view statement the result is whatever
// remains after the elements that did match.
func (e *ViewSelectStatementExtractor) Extract() (string, error) {
	fmt.Printf("[DEBUG] extracting view query %s\n", string(e.input))

	// Leading keywords, each one preceded by optional whitespace.
	preamble := []string{"create", "or replace", "secure", "recursive", "view", "if not exists"}
	for _, keyword := range preamble {
		e.consumeSpace()
		e.consumeToken(keyword)
	}

	e.consumeSpace()
	e.consumeIdentifier()
	// TODO column list
	// TODO copy grants

	// The comment clause is attempted twice, each attempt followed by optional whitespace.
	for attempt := 0; attempt < 2; attempt++ {
		e.consumeComment()
		e.consumeSpace()
	}

	e.consumeToken("as")
	e.consumeSpace()

	remainder := e.input[e.pos:]
	return string(remainder), nil
}

// consumeToken will move e.pos forward iff the token is the next part of the input. Comparison is
// case-insensitive. Will return true if consumed.
//
// When fewer runes remain in the input than the token needs (including an empty input or a
// position at the very end), nothing is consumed and false is returned; the input is never indexed
// out of range.
func (e *ViewSelectStatementExtractor) consumeToken(t string) bool {
	// Measure the token in runes: e.input and e.pos are rune-indexed, whereas len(t) counts bytes.
	n := len([]rune(t))

	// Not enough input left for the token to fit, so it cannot match.
	if e.pos < 0 || e.pos+n > len(e.input) {
		return false
	}

	// Compare the whole candidate window at once; EqualFold applies Unicode case folding.
	if !strings.EqualFold(string(e.input[e.pos:e.pos+n]), t) {
		return false
	}

	e.pos += n
	return true
}

// consumeSpace moves e.pos past any run of whitespace (as defined by unicode.IsSpace) that starts
// at the current position. It stops at the first non-space rune or at the end of the input, and
// leaves e.pos unchanged when there is no whitespace to skip.
func (e *ViewSelectStatementExtractor) consumeSpace() {
	next := e.pos
	for next < len(e.input) && unicode.IsSpace(e.input[next]) {
		next++
	}
	e.pos = next
}

// consumeIdentifier advances past an identifier. It delegates to consumeNonSpace, so the
// identifier is taken to be everything up to the next whitespace rune or the end of the input.
// NOTE(review): a quoted identifier that itself contains whitespace would therefore only be
// consumed up to that whitespace.
func (e *ViewSelectStatementExtractor) consumeIdentifier() {
e.consumeNonSpace()
}

// consumeNonSpace moves e.pos past every rune up to, but not including, the next whitespace rune
// (as defined by unicode.IsSpace). It stops at the end of the input if no whitespace follows, and
// leaves e.pos unchanged when the current rune is already whitespace.
func (e *ViewSelectStatementExtractor) consumeNonSpace() {
	next := e.pos
	for next < len(e.input) && !unicode.IsSpace(e.input[next]) {
		next++
	}
	e.pos = next
}

// consumeComment advances past a clause of the form
//
//	comment = '<text>'
//
// if one starts at the current position. The keyword is matched case-insensitively and whitespace
// around the equals sign is optional. If the keyword, the equals sign or the opening quote is
// missing, the method returns early; whatever part did match stays consumed (there is no
// backtracking).
//
// Inside the quoted text two escape forms are honored so that an embedded single quote does not
// end the comment: a backslash escapes the rune that follows it, and two adjacent single quotes
// stand for one literal quote. For an unterminated comment the position ends up at the end of the
// input.
func (e *ViewSelectStatementExtractor) consumeComment() {
	if !e.consumeToken("comment") {
		return
	}
	e.consumeSpace()

	if !e.consumeToken("=") {
		return
	}
	e.consumeSpace()

	if !e.consumeToken("'") {
		return
	}

	end := len(e.input)
	i := e.pos
	escaped := false
	for i < end {
		r := e.input[i]
		switch {
		case escaped:
			// This rune was escaped by the preceding backslash; take it literally.
			escaped = false
		case r == '\\':
			escaped = true
		case r == '\'':
			if i+1 >= end || e.input[i+1] != '\'' {
				// Closing quote: step over it directly rather than via consumeToken, since we
				// already know what is there.
				e.pos = i + 1
				return
			}
			// Doubled quote: skip the first one here, the second via the increment below.
			i++
		}
		i++
	}

	// No closing quote was found; everything up to the end of the input was comment text.
	e.pos = i
}
3 changes: 3 additions & 0 deletions pkg/snowflake/parser_internal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package snowflake

// Internal tests for ViewSelectStatementExtractor
161 changes: 161 additions & 0 deletions pkg/snowflake/parser_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
package snowflake

import (
"fmt"
"testing"
)

// TestViewSelectStatementExtractor_Extract runs Extract over a table of create view statements
// (different keyword combinations, casing, layouts and comment clauses) and checks that exactly
// the query following "as" is returned.
func TestViewSelectStatementExtractor_Extract(t *testing.T) {
	// Expected result shared by most of the single-line cases.
	const plainSelect = "select * from bar;"

	// Multi-line inputs are raw strings; their lines deliberately start at column zero because the
	// expected output preserves the original line breaks verbatim.
	multilineInput := `
create view foo as
select *
from bar;`

	multilineCommentInput := `
create view foo as
-- comment
select *
from bar;`

	cases := []struct {
		name    string
		input   string
		want    string
		wantErr bool
	}{
		{name: "basic", input: "create view foo as select * from bar;", want: plainSelect},
		{name: "caps", input: "CREATE VIEW FOO AS SELECT * FROM BAR;", want: "SELECT * FROM BAR;"},
		{name: "parens", input: "create view foo as (select * from bar);", want: "(select * from bar);"},
		{name: "multiline", input: multilineInput, want: "select *\nfrom bar;"},
		{name: "multilineComment", input: multilineCommentInput, want: "-- comment\nselect *\nfrom bar;"},
		{name: "secure", input: "create secure view foo as select * from bar;", want: plainSelect},
		{name: "replace", input: "create or replace view foo as select * from bar;", want: plainSelect},
		{name: "recursive", input: "create recursive view foo as select * from bar;", want: plainSelect},
		{name: "ine", input: "create view if not exists foo as select * from bar;", want: plainSelect},
		{name: "comment", input: `create view foo comment='asdf' as select * from bar;`, want: plainSelect},
		{name: "commentEscape", input: `create view foo comment='asdf\'s are fun' as select * from bar;`, want: plainSelect},
		{name: "identifier", input: `create view "foo"."bar"."bam" comment='asdf\'s are fun' as select * from bar;`, want: plainSelect},
		{
			name:  "full",
			input: `CREATE SECURE VIEW "rgdxfmnfhh"."PUBLIC"."rgdxfmnfhh" COMMENT = 'Terraform test resource' AS SELECT ROLE_NAME, ROLE_OWNER FROM INFORMATION_SCHEMA.APPLICABLE_ROLES`,
			want:  "SELECT ROLE_NAME, ROLE_OWNER FROM INFORMATION_SCHEMA.APPLICABLE_ROLES",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := NewViewSelectStatementExtractor(tc.input).Extract()
			if gotErr := err != nil; gotErr != tc.wantErr {
				t.Errorf("ViewSelectStatementExtractor.Extract() error = %v, wantErr %v", err, tc.wantErr)
				return
			}
			if got != tc.want {
				t.Errorf("ViewSelectStatementExtractor.Extract() = '%v', want '%v'", got, tc.want)
			}
		})
	}
}

// TestViewSelectStatementExtractor_consumeToken checks both results of consumeToken: the boolean
// it returns and where it leaves pos. A token that matches must advance pos by its length; a token
// that does not match must leave pos untouched.
//
// Every case has a distinct name so that a failing subtest can be identified and re-run on its own
// with -run.
func TestViewSelectStatementExtractor_consumeToken(t *testing.T) {
	type fields struct {
		input []rune
		pos   int
	}
	type args struct {
		t string
	}
	tests := []struct {
		name     string
		fields   fields
		args     args
		consumed bool
		posAfter int
	}{
		{"basic - found", fields{[]rune("foo"), 0}, args{"foo"}, true, 3},
		{"basic - not found", fields{[]rune("foo"), 0}, args{"bar"}, false, 0},
		// The first two runes match; the mismatch on the last one must still consume nothing.
		{"basic - mismatch on last rune", fields{[]rune("fob"), 0}, args{"foo"}, false, 0},
		{"case-insensitive - found", fields{[]rune("FOO"), 0}, args{"foo"}, true, 3},
		{"offset - found", fields{[]rune("a foo"), 2}, args{"foo"}, true, 5},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			e := &ViewSelectStatementExtractor{
				input: tt.fields.input,
				pos:   tt.fields.pos,
			}
			consumed := e.consumeToken(tt.args.t)

			if consumed != tt.consumed {
				t.Errorf("consumed = %v, want %v", consumed, tt.consumed)
			}
			if e.pos != tt.posAfter {
				t.Errorf("pos after = %v, want %v", e.pos, tt.posAfter)
			}
		})
	}
}

// TestViewSelectStatementExtractor_consumeSpace checks how far consumeSpace advances pos for
// inputs with leading whitespace, embedded mixed whitespace, and no whitespace at all.
func TestViewSelectStatementExtractor_consumeSpace(t *testing.T) {
	type fields struct {
		input []rune
		pos   int
	}
	tests := []struct {
		name     string
		fields   fields
		posAfter int
	}{
		// Exactly three leading spaces, so pos must land on the 'f' at index 3.
		{"simple", fields{[]rune("   foo"), 0}, 3},
		{"empty", fields{[]rune(""), 0}, 0},
		// Starting right after "foo": space, tab, newline, space are skipped, landing on 'b'.
		{"middle", fields{[]rune("foo \t\n bar"), 3}, 7},
		{"no leading space", fields{[]rune("foo"), 0}, 0},
		{"only space", fields{[]rune("  "), 0}, 2},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			fmt.Println(tt.name)
			e := &ViewSelectStatementExtractor{
				input: tt.fields.input,
				pos:   tt.fields.pos,
			}
			e.consumeSpace()

			if e.pos != tt.posAfter {
				t.Errorf("pos after = %v, want %v", e.pos, tt.posAfter)
			}
		})
	}
}

// TestViewSelectStatementExtractor_consumeComment checks that consumeComment advances pos to just
// past the closing quote of a comment clause, including when the text contains a
// backslash-escaped quote.
func TestViewSelectStatementExtractor_consumeComment(t *testing.T) {
	cases := []struct {
		name     string
		input    string
		startPos int
		posAfter int
	}{
		{name: "basic", input: "comment='foo'", startPos: 0, posAfter: 13},
		{name: "escaped", input: `comment='fo\'o'`, startPos: 0, posAfter: 15},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			extractor := &ViewSelectStatementExtractor{input: []rune(tc.input), pos: tc.startPos}
			extractor.consumeComment()

			if got := extractor.pos; got != tc.posAfter {
				t.Errorf("pos after = %v, want %v", got, tc.posAfter)
			}
		})
	}
}

0 comments on commit 58f5f3d

Please sign in to comment.