Skip to content

Commit

Permalink
Start working on persister tests
Browse files Browse the repository at this point in the history
  • Loading branch information
supercairos committed Jan 11, 2024
1 parent 5191714 commit edf0738
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 17 deletions.
33 changes: 16 additions & 17 deletions persistence/sql/persister_consent.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,21 +226,12 @@ func (p *Persister) GetDeviceGrantRequestByVerifier(ctx context.Context, verifie
defer span.End()

var dgr flow.DeviceGrantRequest
return &dgr, p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
if err := c.Where("verifier = ?", verifier).First(&dgr); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return errorsx.WithStack(x.ErrNotFound)
}
return sqlcon.HandleError(err)
}

return nil
})
return &dgr, sqlcon.HandleError(p.QueryWithNetwork(ctx).Where("verifier = ?", verifier).First(&dgr))
}

func (p *Persister) AcceptDeviceGrantRequest(ctx context.Context, challenge string, device_code_signature string, client_id string, requested_scopes fosite.Arguments, requested_aud fosite.Arguments) (*flow.DeviceGrantRequest, error) {
var dgr flow.DeviceGrantRequest
return &dgr, p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
if err := p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
if err := p.QueryWithNetwork(ctx).Where("challenge = ?", challenge).First(&dgr); err != nil {
return sqlcon.HandleError(err)
}
Expand All @@ -257,18 +248,26 @@ func (p *Persister) AcceptDeviceGrantRequest(ctx context.Context, challenge stri
return errorsx.WithStack(x.ErrNotFound)
}
return err
})
}); err != nil {
return nil, sqlcon.HandleError(err)
}

return p.GetDeviceGrantRequestByVerifier(ctx, dgr.Verifier)
}

func (p *Persister) VerifyAndInvalidateDeviceGrantRequest(ctx context.Context, verifier string) (*flow.DeviceGrantRequest, error) {
var d flow.DeviceGrantRequest
return &d, p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
if err := p.QueryWithNetwork(ctx).Where("verifier = ?", verifier).First(&d); err != nil {
var dgr flow.DeviceGrantRequest
if err := p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
if err := p.QueryWithNetwork(ctx).Where("verifier = ?", verifier).First(&dgr); err != nil {
return sqlcon.HandleError(err)
}

return sqlcon.HandleError(p.QueryWithNetwork(ctx).Where("verifier = ?", verifier).Delete(&d))
})
return sqlcon.HandleError(p.QueryWithNetwork(ctx).Where("verifier = ?", verifier).Delete(&dgr))
}); err != nil {
return nil, err
}

return &dgr, nil
}

func (p *Persister) CreateLoginRequest(ctx context.Context, req *flow.LoginRequest) (*flow.Flow, error) {
Expand Down
64 changes: 64 additions & 0 deletions persistence/sql/persister_nid_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,37 @@ func (s *PersisterTestSuite) TestAcceptLogoutRequest() {
}
}

func (s *PersisterTestSuite) TestAcceptDeviceGrantRequest() {
t := s.T()
dgr := newDeviceGrantRequest()

for k, r := range s.registries {
t.Run("dialect="+k, func(*testing.T) {
require.NoError(t, r.ConsentManager().CreateDeviceGrantRequest(s.t1, dgr))

expected, err := r.ConsentManager().GetDeviceGrantRequestByVerifier(s.t1, dgr.Verifier)
require.NoError(t, err)
require.Equal(t, false, expected.Accepted)

client := &client.Client{ID: "client-id", Secret: "secret"}
require.NoError(t, r.Persister().CreateClient(s.t1, client))

dgrAccepted, err := r.ConsentManager().AcceptDeviceGrantRequest(s.t2, dgr.ID, "aaaa", "client-id", fosite.Arguments{"openid"}, fosite.Arguments{""})
require.Error(t, err)
require.Equal(t, &flow.DeviceGrantRequest{}, dgrAccepted)

actual, err := r.ConsentManager().GetDeviceGrantRequestByVerifier(s.t1, dgr.Verifier)
require.NoError(t, err)
require.Equal(t, expected, actual)

dgrAccepted, err = r.ConsentManager().AcceptDeviceGrantRequest(s.t1, dgr.ID, "aaaa", "client-id", fosite.Arguments{"openid"}, fosite.Arguments{""})
require.NoError(t, err)
require.Equal(t, dgrAccepted, actual)
require.Equal(t, true, actual.Accepted)
})
}
}

func (s *PersisterTestSuite) TestAddKeyGetKeyDeleteKey() {
t := s.T()
key := newKey("test-ks", "test")
Expand Down Expand Up @@ -2071,6 +2102,32 @@ func (s *PersisterTestSuite) TestVerifyAndInvalidateLogoutRequest() {
}
}

func (s *PersisterTestSuite) TestVerifyAndInvalidateDeviceGrantRequest() {
t := s.T()
for k, r := range s.registries {
t.Run(k, func(t *testing.T) {
dgr := newDeviceGrantRequest()

actual := &flow.DeviceGrantRequest{}

require.NoError(t, r.ConsentManager().CreateDeviceGrantRequest(s.t1, dgr))

expected, err := r.ConsentManager().GetDeviceGrantRequestByVerifier(s.t1, dgr.Verifier)
require.NoError(t, err)

dgrInvalidated, err := r.ConsentManager().VerifyAndInvalidateDeviceGrantRequest(s.t2, dgr.Verifier)
require.Error(t, err)
require.Nil(t, dgrInvalidated)
require.NoError(t, r.Persister().Connection(context.Background()).Find(actual, dgr.ID))
require.Equal(t, expected, actual)

_, err = r.ConsentManager().VerifyAndInvalidateDeviceGrantRequest(s.t1, dgr.Verifier)
require.NoError(t, err)
require.Error(t, x.ErrNotFound, r.Persister().Connection(context.Background()).Find(actual, dgr.ID))
})
}
}

func (s *PersisterTestSuite) TestWithFallbackNetworkID() {
t := s.T()
for k, r := range s.registries {
Expand Down Expand Up @@ -2133,6 +2190,13 @@ func newGrant(keySet string, keyID string) trust.Grant {
}
}

func newDeviceGrantRequest() *flow.DeviceGrantRequest {
return &flow.DeviceGrantRequest{
ID: uuid.Must(uuid.NewV4()).String(),
Verifier: uuid.Must(uuid.NewV4()).String(),
}
}

func newLogoutRequest() *flow.LogoutRequest {
return &flow.LogoutRequest{
ID: uuid.Must(uuid.NewV4()).String(),
Expand Down

0 comments on commit edf0738

Please sign in to comment.