Skip to content

Commit

Permalink
feat: added integration/snapshot tests (#31)
Browse files Browse the repository at this point in the history
  • Loading branch information
Will-Mann-16 authored Dec 29, 2023
1 parent 666db2e commit 49a07b0
Show file tree
Hide file tree
Showing 103 changed files with 2,396 additions and 114 deletions.
49 changes: 49 additions & 0 deletions astTraversal/packageNode.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ type PackageNode struct {
Package *packages.Package
Edges []*PackageNode
Files []*FileNode

// TypeDocMap is a map of type names to their documentation
// We cache this to save iterating over types every time we need to find the documentation
TypeDocMap map[string]string
}

func (p *PackageNode) Path() string {
Expand Down Expand Up @@ -192,3 +196,48 @@ func (p *PackageNode) ASTAtPos(pos token.Pos) (ast.Node, error) {

return nil, fmt.Errorf("node at %s not found in package %s", node, p.Path())
}

// FindDocForType finds the documentation for a type in the package
func (p *PackageNode) FindDocForType(typeName string) (string, bool) {
p.populateTypeDocMap()

doc, ok := p.TypeDocMap[typeName]
return doc, ok
}

// populateTypeDocMap populates the TypeDocMap for the package
func (p *PackageNode) populateTypeDocMap() {
// If the map is already populated, we don't need to do anything
if p.TypeDocMap != nil {
return
}

// Otherwise, we need to populate it
p.TypeDocMap = make(map[string]string)

// Loop over every file
for _, f := range p.Files {
// Loop over every declaration
for _, decl := range f.AST.Decls {
// If the declaration is a GenDecl, it's a const/var/type declaration (top level)
if genDecl, ok := decl.(*ast.GenDecl); ok {
// Loop over every spec
for _, spec := range genDecl.Specs {
// If the spec is a type spec, it's a type declaration
if typeSpec, ok := spec.(*ast.TypeSpec); ok {
// If the type spec has no documentation, and the overarching declaration has no documentation, we skip it
// TypeSpecs can have documentation, but it's not common
// It's more common for the GenDecl to have the documentation
if typeSpec.Doc == nil && genDecl.Doc == nil {
continue
} else if typeSpec.Doc != nil { // The TypeSpec has priority over the GenDecl
p.TypeDocMap[typeSpec.Name.Name] = typeSpec.Doc.Text()
} else {
p.TypeDocMap[typeSpec.Name.Name] = genDecl.Doc.Text()
}
}
}
}
}
}
}
2 changes: 1 addition & 1 deletion astTraversal/packageNode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,6 @@ func TestPackageNode_FindASTForType(t *testing.T) {
// Find the AST for the type "otherpkg1.Foo"
_, _, err := traverser.ActiveFile().Package.FindASTForType("otherpkg1.Foo")
assert.Error(t, err)
assert.ErrorContains(t, err, "type otherpkg1.Foo not found in package github.com/ls6-events/astra/testfiles")
assert.ErrorContains(t, err, "type otherpkg1.Foo not found in package github.com/ls6-events/astra/astTraversal/testfiles")
})
}
2 changes: 1 addition & 1 deletion astTraversal/testUtils.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import (
"path"
)

const relativeTestFilePath = "testfiles"
const relativeTestFilePath = "astTraversal/testfiles"

func createTraverserFromTestFile(testFilePath string) (*BaseTraverser, error) {
wd, err := os.Getwd() // Will work as tests are only run from the root of the project
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ package testfiles

import (
"fmt"
"github.com/ls6-events/astra/testfiles/otherpkg1"
"github.com/ls6-events/astra/astTraversal/testfiles/otherpkg1"
"strings"
)

Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package testfiles

import (
"fmt"
"github.com/ls6-events/astra/testfiles/otherpkg1"
"github.com/ls6-events/astra/astTraversal/testfiles/otherpkg1"
"strings"
)

Expand Down
File renamed without changes.
63 changes: 8 additions & 55 deletions astTraversal/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,10 +233,11 @@ func (t *TypeTraverser) Result() (Result, error) {
// TODO - this isn't working entirely well for external packages
// Needs investigation
node, err := structFieldResult.Package.ASTAtPos(pos)
if err != nil || node == nil {
t.Traverser.Log.Warn().Err(err).Msgf("failed to get AST at position %d for field %s", pos, f.Id())
} else if field, ok := node.(*ast.Field); ok {
structFieldResult.Doc = FormatDoc(field.Doc.Text())
if err == nil && node != nil {
if field, ok := node.(*ast.Field); ok {
t.Traverser.Log.Debug().Str("field", field.Names[0].Name).Msg("Found doc for field")
structFieldResult.Doc = FormatDoc(field.Doc.Text())
}
}
}

Expand Down Expand Up @@ -280,58 +281,10 @@ func (t *TypeTraverser) Doc() (string, error) {
return "", err
}

node, err := pkg.ASTAtPos(named.Obj().Pos())
if err != nil || node == nil {
t.Traverser.Log.Warn().Err(err).Msgf("failed to get AST for type %s", t.Node.String())
for expr, info := range pkg.Package.TypesInfo.Types {
if info.Type.String() == named.String() {
node = expr
break
}
}
}

for node != nil {
switch n := node.(type) {
case *ast.TypeSpec:
doc := n.Doc.Text()
if doc != "" {
return FormatDoc(doc), nil
}

// If the doc doesn't exist, we need to find the declaration of the type
// and get the doc from there
// This is because the doc is attached to the declaration, not the AST type
for _, file := range pkg.Package.Syntax {
if pkg.Package.Fset.Position(file.Pos()).Filename == pkg.Package.Fset.Position(n.Pos()).Filename {
for _, decl := range file.Decls {
if genDecl, ok := decl.(*ast.GenDecl); ok {
for _, spec := range genDecl.Specs {
if typeSpec, ok := spec.(*ast.TypeSpec); ok {
if typeSpec.Name.Name == n.Name.Name {
return FormatDoc(genDecl.Doc.Text()), nil
}
}
}
}
}
}
}

node = nil
case *ast.CompositeLit:
node = n.Type
case *ast.Ident:
if n.Obj != nil {
node = n.Obj.Decl.(ast.Node)
} else {
node = nil
}
default:
node = nil
}
doc, ok := pkg.FindDocForType(named.Obj().Name())
if ok {
return FormatDoc(doc), nil
}

}

return "", nil
Expand Down
6 changes: 1 addition & 5 deletions formatTypes.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ type TypeFormat struct {
Format string
}

// predefinedTypeMap is the map of the standard go types that are accepted by OpenAPI
// PredefinedTypeMap is the map of the standard go types that are accepted by OpenAPI
// It contains the go type as a string and the corresponding OpenAPI type as the value - also including the format
var PredefinedTypeMap = map[string]TypeFormat{
"string": {
Expand Down Expand Up @@ -54,10 +54,6 @@ var PredefinedTypeMap = map[string]TypeFormat{
Type: "integer",
Format: "uint64",
},
"float": {
Type: "number",
Format: "float",
},
"float32": {
Type: "number",
Format: "float32",
Expand Down
3 changes: 3 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ module github.com/ls6-events/astra
go 1.21

require (
github.com/Jeffail/gabs/v2 v2.7.0
github.com/gin-gonic/gin v1.9.1
github.com/google/go-cmp v0.5.5
github.com/google/uuid v1.5.0
github.com/iancoleman/strcase v0.3.0
github.com/rs/zerolog v1.31.0
github.com/stretchr/testify v1.8.4
Expand Down
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
github.com/Jeffail/gabs/v2 v2.7.0 h1:Y2edYaTcE8ZpRsR2AtmPu5xQdFDIthFG0jYhu5PY8kg=
github.com/Jeffail/gabs/v2 v2.7.0/go.mod h1:dp5ocw1FvBBQYssgHsG7I1WYsiLRtkUaB1FEtSwvNUw=
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
github.com/bytedance/sonic v1.10.0-rc/go.mod h1:ElCzW+ufi8qKqNW0FY314xriJhyJhuoJ3gFZdAHF7NM=
github.com/bytedance/sonic v1.10.1 h1:7a1wuFXL1cMy7a3f7/VFcEtriuXQnUBhtoVfOZiaysc=
Expand Down Expand Up @@ -33,6 +35,8 @@ github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaS
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.5.0 h1:1p67kYwdtXjb0gL0BPiP1Av9wiZPo5A8z2cWkTZ+eyU=
github.com/google/uuid v1.5.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/iancoleman/strcase v0.3.0 h1:nTXanmYxhfFAMjZL34Ov6gkzEsSJZ5DbhxWjvSASxEI=
github.com/iancoleman/strcase v0.3.0/go.mod h1:iwCmte+B7n89clKwxIoIXy/HfoL7AsD47ZCWhYzw7ho=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
Expand Down
56 changes: 15 additions & 41 deletions inputs/gin/parseFunction.go
Original file line number Diff line number Diff line change
Expand Up @@ -241,9 +241,7 @@ func parseFunction(s *astra.Service, funcTraverser *astTraversal.FunctionTravers
return route, nil
})
// Query Param methods
case "GetQuery":
fallthrough
case "Query":
case "GetQuery", "Query":
currRoute, err = funcBuilder.Value().Build(func(route *astra.Route, params []any) (*astra.Route, error) {
name := params[0].(string)

Expand All @@ -261,10 +259,7 @@ func parseFunction(s *astra.Service, funcTraverser *astTraversal.FunctionTravers
if err != nil {
return false
}

case "GetQueryArray":
fallthrough
case "QueryArray":
case "GetQueryArray", "QueryArray":
currRoute, err = funcBuilder.Value().Build(func(route *astra.Route, params []any) (*astra.Route, error) {
name := params[0].(string)

Expand All @@ -283,10 +278,7 @@ func parseFunction(s *astra.Service, funcTraverser *astTraversal.FunctionTravers
if err != nil {
return false
}

case "GetQueryMap":
fallthrough
case "QueryMap":
case "GetQueryMap", "QueryMap":
currRoute, err = funcBuilder.Value().Build(func(route *astra.Route, params []any) (*astra.Route, error) {
name := params[0].(string)

Expand All @@ -305,10 +297,7 @@ func parseFunction(s *astra.Service, funcTraverser *astTraversal.FunctionTravers
if err != nil {
return false
}

case "ShouldBindQuery":
fallthrough
case "BindQuery":
case "ShouldBindQuery", "BindQuery":
currRoute, err = funcBuilder.ExpressionResult().Build(func(route *astra.Route, params []any) (*astra.Route, error) {
field := astra.ParseResultToField(params[0].(astTraversal.Result))

Expand All @@ -324,9 +313,7 @@ func parseFunction(s *astra.Service, funcTraverser *astTraversal.FunctionTravers
}

// Body Param methods
case "ShouldBind":
fallthrough
case "Bind":
case "ShouldBind", "Bind":
currRoute, err = funcBuilder.ExpressionResult().Build(func(route *astra.Route, params []any) (*astra.Route, error) {
field := astra.ParseResultToField(params[0].(astTraversal.Result))

Expand Down Expand Up @@ -362,9 +349,7 @@ func parseFunction(s *astra.Service, funcTraverser *astTraversal.FunctionTravers
if err != nil {
return false
}
case "ShouldBindJSON":
fallthrough
case "BindJSON":
case "ShouldBindJSON", "BindJSON":
currRoute, err = funcBuilder.ExpressionResult().Build(func(route *astra.Route, params []any) (*astra.Route, error) {
field := astra.ParseResultToField(params[0].(astTraversal.Result))

Expand All @@ -379,9 +364,7 @@ func parseFunction(s *astra.Service, funcTraverser *astTraversal.FunctionTravers
if err != nil {
return false
}
case "ShouldBindXML":
fallthrough
case "BindXML":
case "ShouldBindXML", "BindXML":
currRoute, err = funcBuilder.ExpressionResult().Build(func(route *astra.Route, params []any) (*astra.Route, error) {
field := astra.ParseResultToField(params[0].(astTraversal.Result))

Expand All @@ -396,9 +379,7 @@ func parseFunction(s *astra.Service, funcTraverser *astTraversal.FunctionTravers
if err != nil {
return false
}
case "ShouldBindYAML":
fallthrough
case "BindYAML":
case "ShouldBindYAML", "BindYAML":
currRoute, err = funcBuilder.ExpressionResult().Build(func(route *astra.Route, params []any) (*astra.Route, error) {
field := astra.ParseResultToField(params[0].(astTraversal.Result))

Expand All @@ -413,9 +394,7 @@ func parseFunction(s *astra.Service, funcTraverser *astTraversal.FunctionTravers
if err != nil {
return false
}
case "GetPostForm":
fallthrough
case "PostForm":
case "GetPostForm", "PostForm":
currRoute, err = funcBuilder.Value().Build(func(route *astra.Route, params []any) (*astra.Route, error) {
name := params[0].(string)

Expand All @@ -434,9 +413,7 @@ func parseFunction(s *astra.Service, funcTraverser *astTraversal.FunctionTravers
if err != nil {
return false
}
case "GetPostFormArray":
fallthrough
case "PostFormArray":
case "GetPostFormArray", "PostFormArray":
currRoute, err = funcBuilder.Value().Build(func(route *astra.Route, params []any) (*astra.Route, error) {
name := params[0].(string)

Expand All @@ -456,9 +433,7 @@ func parseFunction(s *astra.Service, funcTraverser *astTraversal.FunctionTravers
if err != nil {
return false
}
case "GetPostFormMap":
fallthrough
case "PostFormMap":
case "GetPostFormMap", "PostFormMap":
currRoute, err = funcBuilder.Value().Build(func(route *astra.Route, params []any) (*astra.Route, error) {
name := params[0].(string)

Expand Down Expand Up @@ -496,9 +471,7 @@ func parseFunction(s *astra.Service, funcTraverser *astTraversal.FunctionTravers
if err != nil {
return false
}
case "ShouldBindHeader":
fallthrough
case "BindHeader":
case "ShouldBindHeader", "BindHeader":
currRoute, err = funcBuilder.ExpressionResult().Build(func(route *astra.Route, params []any) (*astra.Route, error) {
field := astra.ParseResultToField(params[0].(astTraversal.Result))

Expand Down Expand Up @@ -569,8 +542,9 @@ func parseFunction(s *astra.Service, funcTraverser *astTraversal.FunctionTravers
result := params[1].(astTraversal.Result)

returnType := astra.ReturnType{
StatusCode: statusCode,
Field: astra.ParseResultToField(result),
ContentType: "application/json",
StatusCode: statusCode,
Field: astra.ParseResultToField(result),
}

route.ReturnTypes = astra.AddReturnType(route.ReturnTypes, returnType)
Expand Down
Loading

0 comments on commit 49a07b0

Please sign in to comment.