Skip to content

Commit

Permalink
Adding tests
Browse files Browse the repository at this point in the history
  • Loading branch information
traut committed Nov 14, 2024
1 parent 9566f16 commit c975997
Show file tree
Hide file tree
Showing 5 changed files with 400 additions and 57 deletions.
46 changes: 0 additions & 46 deletions internal/microsoft/client/microsoftonline_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,52 +14,6 @@ const (
authURL = "https://login.microsoftonline.com"
)

// type MicrosoftCredentials struct {
// TenantID string `url:"-"`
// ClientID string `url:"-"`
// ClientSecret string `url:"-"`
// }

// type GetClientCredentialsTokenRes struct {
// AccessToken string `json:"access_token"`
// }

// func AuthorizeClient(ctx context.Context, req *MicrosoftCredentials) (*GetClientCredentialsTokenRes, error) {
// format := "/%s/oauth2/token"
// u, err := url.Parse(authURL + fmt.Sprintf(format, req.TenantID))
// if err != nil {
// return nil, err
// }
// payload := url.Values{
// "grant_type": {"client_credentials"},
// "client_id": {req.ClientID},
// "client_secret": {req.ClientSecret},
// "resource": {baseURLAzure},
// }
// r, err := http.NewRequestWithContext(ctx, http.MethodPost, u.String(), strings.NewReader(payload.Encode()))
// if err != nil {
// return nil, err
// }
// c.prepare(r)
// r.Header.Set("Content-Type", "application/x-www-form-urlencoded")
// client := http.Client{
// Timeout: 15 * time.Second,
// }
// res, err := client.Do(r)
// if err != nil {
// return nil, err
// }
// if res.StatusCode != http.StatusOK {
// return nil, fmt.Errorf("Microsoft OAuth2 endpoint returned status code: %d", res.StatusCode)
// }
// defer res.Body.Close()
// var data GetClientCredentialsTokenRes
// if err := json.NewDecoder(res.Body).Decode(&data); err != nil {
// return nil, err
// }
// return &data, nil
// }

type AcquireTokenFn func(ctx context.Context, tenantId string, clientId string, cred confidential.Credential, scopes []string) (string, error)

var AcquireAzureToken AcquireTokenFn = func(ctx context.Context, tenantId string, clientId string, cred confidential.Credential, scopes []string) (accessToken string, err error) {
Expand Down
1 change: 0 additions & 1 deletion internal/microsoft/data_microsoft_graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ func fetchMicrosoftGraph(loader MicrosoftGraphClientLoadFn) plugin.RetrieveDataF

var response plugindata.Data


queryParamsAttr := params.Args.GetAttrVal("query_params")
queryParams := url.Values{}

Expand Down
25 changes: 15 additions & 10 deletions internal/microsoft/data_microsoft_graph_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/zclconf/go-cty/cty"

"github.com/blackstork-io/fabric/internal/microsoft"
"github.com/blackstork-io/fabric/internal/microsoft/client"
client_mocks "github.com/blackstork-io/fabric/mocks/internalpkg/microsoft"
"github.com/blackstork-io/fabric/pkg/diagnostics/diagtest"
"github.com/blackstork-io/fabric/plugin"
Expand Down Expand Up @@ -161,14 +162,15 @@ func (s *MicrosoftGraphDataSourceTestSuite) TestMissingConfig() {
}

func (s *MicrosoftGraphDataSourceTestSuite) TestMissingCredentials() {
expectedData := plugindata.List{
plugindata.Map{
"severity": plugindata.String("High"),
"displayName": plugindata.String("Incident 1"),
},
}
s.cli.On("QueryObjects", mock.Anything, "/security/incidents", url.Values{"$top": []string{"10"}}, 1).
Return(expectedData, nil)
s.plugin = microsoft.Plugin(
"1.0.0",
nil,
nil,
microsoft.MakeDefaultMicrosoftGraphClientLoader(client.AcquireAzureToken),
nil,
)
s.schema = s.plugin.DataSources["microsoft_graph"]

ctx := context.Background()
result, diags := s.schema.DataFunc(ctx, &plugin.RetrieveDataParams{
Config: plugintest.NewTestDecoder(s.T(), s.schema.Config).
Expand All @@ -182,6 +184,9 @@ func (s *MicrosoftGraphDataSourceTestSuite) TestMissingCredentials() {
SetAttr("size", cty.NumberIntVal(1)).
Decode(),
})
s.Nil(diags)
s.Equal(expectedData, result.AsPluginData())
s.Nil(result)
diagtest.Asserts{{
diagtest.IsError,
diagtest.DetailContains("Either `client_secret` or `private_key` / `private_key_file` arguments must be provide"),
}}.AssertMatch(s.T(), diags, nil)
}
168 changes: 168 additions & 0 deletions internal/microsoft/data_microsoft_security_query_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
package microsoft_test

import (
"context"
"errors"
"testing"

"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/suite"
"github.com/zclconf/go-cty/cty"

"github.com/blackstork-io/fabric/internal/microsoft"
"github.com/blackstork-io/fabric/internal/microsoft/client"
client_mocks "github.com/blackstork-io/fabric/mocks/internalpkg/microsoft"
"github.com/blackstork-io/fabric/pkg/diagnostics/diagtest"
"github.com/blackstork-io/fabric/plugin"
"github.com/blackstork-io/fabric/plugin/dataspec"
"github.com/blackstork-io/fabric/plugin/plugindata"
"github.com/blackstork-io/fabric/plugin/plugintest"
)

type MicrosoftSecurityQueryDataSourceTestSuite struct {
suite.Suite
plugin *plugin.Schema
schema *plugin.DataSource
cli *client_mocks.MicrosoftSecurityClient
}

func TestMicrosoftSecurityQueryDataSourceTestSuite(t *testing.T) {
suite.Run(t, &MicrosoftSecurityQueryDataSourceTestSuite{})
}

func (s *MicrosoftSecurityQueryDataSourceTestSuite) SetupSuite() {
s.plugin = microsoft.Plugin(
"1.0.0",
nil,
nil,
nil,
(func(ctx context.Context, cfg *dataspec.Block) (client microsoft.MicrosoftSecurityClient, err error) {
return s.cli, nil
}),
)
s.schema = s.plugin.DataSources["microsoft_security_query"]
}

func (s *MicrosoftSecurityQueryDataSourceTestSuite) SetupTest() {
s.cli = &client_mocks.MicrosoftSecurityClient{}
}

func (s *MicrosoftSecurityQueryDataSourceTestSuite) TearDownTest() {
s.cli.AssertExpectations(s.T())
}

func (s *MicrosoftSecurityQueryDataSourceTestSuite) TestSchema() {
s.Require().NotNil(s.plugin)
s.Require().NotNil(s.schema)
s.NotNil(s.schema.Args)
s.NotNil(s.schema.DataFunc)
s.NotNil(s.schema.Config)
}

func (s *MicrosoftSecurityQueryDataSourceTestSuite) TestBasic() {
testQuery := "test-query"
expectedData := plugindata.List{
plugindata.Map{
"foo": plugindata.String("bar"),
},
}

rawResponse := plugindata.Map{
"Results": plugindata.List{
plugindata.Map{
"foo": plugindata.String("bar"),
},
},
}
s.cli.On("RunAdvancedQuery", mock.Anything, testQuery).Return(rawResponse, nil)
ctx := context.Background()
result, diags := s.schema.DataFunc(ctx, &plugin.RetrieveDataParams{
Config: plugintest.NewTestDecoder(s.T(), s.schema.Config).
SetAttr("client_id", cty.StringVal("cid")).
SetAttr("tenant_id", cty.StringVal("tid")).
SetAttr("client_secret", cty.StringVal("csecret")).
Decode(),
Args: plugintest.NewTestDecoder(s.T(), s.schema.Args).
SetAttr("query", cty.StringVal(testQuery)).
Decode(),
})
s.Nil(diags)
s.NotNil(result)
s.Equal(expectedData, result.AsPluginData())
}

func (s *MicrosoftSecurityQueryDataSourceTestSuite) TestClientError() {
testQuery := "test-query"
s.cli.On("RunAdvancedQuery", mock.Anything, testQuery).
Return(nil, errors.New("Microsoft Security query API returned status code: 400"))

ctx := context.Background()
result, diags := s.schema.DataFunc(ctx, &plugin.RetrieveDataParams{
Config: plugintest.NewTestDecoder(s.T(), s.schema.Config).
SetAttr("client_id", cty.StringVal("cid")).
SetAttr("tenant_id", cty.StringVal("tid")).
SetAttr("client_secret", cty.StringVal("csecret")).
Decode(),
Args: plugintest.NewTestDecoder(s.T(), s.schema.Args).
SetAttr("query", cty.StringVal(testQuery)).
Decode(),
})
s.Nil(result)
diagtest.Asserts{{
diagtest.IsError,
diagtest.DetailContains("Microsoft Security query API returned status code: 400"),
}}.AssertMatch(s.T(), diags, nil)
}

func (s *MicrosoftSecurityQueryDataSourceTestSuite) TestMissingArgs() {
plugintest.NewTestDecoder(
s.T(),
s.schema.Args,
).Decode([]diagtest.Assert{
diagtest.IsError,
diagtest.SummaryEquals("Missing required attribute"),
diagtest.DetailContains("query"),
})
}

func (s *MicrosoftSecurityQueryDataSourceTestSuite) TestMissingConfig() {
plugintest.NewTestDecoder(
s.T(),
s.schema.Config,
).Decode([]diagtest.Assert{
diagtest.IsError,
diagtest.SummaryEquals("Missing required attribute"),
diagtest.DetailContains("client_id"),
}, []diagtest.Assert{
diagtest.IsError,
diagtest.SummaryEquals("Missing required attribute"),
diagtest.DetailContains("tenant_id"),
})
}

func (s *MicrosoftSecurityQueryDataSourceTestSuite) TestMissingCredentials() {
s.plugin = microsoft.Plugin(
"1.0.0",
nil,
nil,
nil,
microsoft.MakeDefaultMicrosoftSecurityClientLoader(client.AcquireAzureToken),
)
s.schema = s.plugin.DataSources["microsoft_security_query"]

ctx := context.Background()
result, diags := s.schema.DataFunc(ctx, &plugin.RetrieveDataParams{
Config: plugintest.NewTestDecoder(s.T(), s.schema.Config).
SetAttr("client_id", cty.StringVal("cid")).
SetAttr("tenant_id", cty.StringVal("tid")).
Decode(),
Args: plugintest.NewTestDecoder(s.T(), s.schema.Args).
SetAttr("query", cty.StringVal("some")).
Decode(),
})
s.Nil(result)
diagtest.Asserts{{
diagtest.IsError,
diagtest.DetailContains("Either `client_secret` or `private_key` / `private_key_file` arguments must be provide"),
}}.AssertMatch(s.T(), diags, nil)
}
Loading

0 comments on commit c975997

Please sign in to comment.