Skip to content

Commit

Permalink
Adjust for review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
moroten committed Jan 24, 2025
1 parent 47d3807 commit de70210
Show file tree
Hide file tree
Showing 5 changed files with 281 additions and 76 deletions.
82 changes: 49 additions & 33 deletions pkg/grpc/remote_authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,11 @@ type remoteAuthenticator struct {
type RemoteAuthenticatorCacheKey [sha256.Size]byte

type remoteAuthCacheEntry struct {
ready <-chan struct{}
response remoteAuthResponse
// ready is closed when the remote request has finished.
ready <-chan struct{}
// response is nil if the request is ongoing or has failed and should be
// retried.
response *remoteAuthResponse
}

type remoteAuthResponse struct {
Expand All @@ -45,16 +48,24 @@ type remoteAuthResponse struct {
err error
}

func (ce *remoteAuthCacheEntry) HasExpired(now time.Time) bool {
func (ce *remoteAuthCacheEntry) IsReady() bool {
select {
case <-ce.ready:
return ce.response.expirationTime.Before(now)
return true
default:
// Ongoing remote requests have not expired by definition.
return false
}
}

// IsValid returns false if a new remote request should be made.
func (ce *remoteAuthCacheEntry) IsValid(now time.Time) bool {
if ce.response == nil {
// Error response on the remote request, make a new request.
return false
}
return now.Before(ce.response.expirationTime)
}

// NewRemoteAuthenticator creates a new RemoteAuthenticator for incoming
// requests that forwards headers to a remote service for authentication. The
// result from the remote service is cached.
Expand Down Expand Up @@ -95,37 +106,46 @@ func (a *remoteAuthenticator) Authenticate(ctx context.Context, headers map[stri
// keeping credentials in the memory.
requestKey := sha256.Sum256(requestBytes)

a.lock.Lock()
now := a.clock.Now()
entry := a.getAndTouchCacheEntry(requestKey)
if entry != nil && entry.HasExpired(now) {
entry = nil
}
if entry == nil {
// No valid cache entry available. Deduplicate requests by creating a
// pending cached response.
responseReady := make(chan struct{})
entry = &remoteAuthCacheEntry{
ready: responseReady,
for {
a.lock.Lock()
entry := a.getAndTouchCacheEntry(requestKey)
if entry == nil || (entry.IsReady() && !entry.IsValid(now)) {
// No valid cache entry available. Deduplicate requests by creating a
// pending cached response.
responseReady := make(chan struct{})
entry = &remoteAuthCacheEntry{
ready: responseReady,
}
a.cachedResponses[requestKey] = entry
a.lock.Unlock()

// Perform the remote authentication request.
response, err := a.authenticateRemotely(ctx, request)
if err != nil {
close(responseReady)
return nil, err
}
entry.response = response
close(responseReady)
return response.authMetadata, response.err
}
a.cachedResponses[requestKey] = entry
a.lock.Unlock()

// Perform the remote authentication request.
entry.response = a.authenticateRemotely(ctx, request)
close(responseReady)
} else {
a.lock.Unlock()

// Wait for the remote request to finish.
select {
case <-ctx.Done():
return nil, util.StatusWrapWithCode(ctx.Err(), codes.Unauthenticated, "Context cancelled")
return nil, util.StatusFromContext(ctx)
case <-entry.ready:
// Noop
// Check whether the remote authentication call succeeded.
// Otherwise, retry with our own ctx.
if entry.response != nil {
// Note that the expiration time is not checked, as the response
// is as fresh as it can be.
return entry.response.authMetadata, entry.response.err
}
}
}
return entry.response.authMetadata, entry.response.err
}

func (a *remoteAuthenticator) getAndTouchCacheEntry(requestKey RemoteAuthenticatorCacheKey) *remoteAuthCacheEntry {
Expand All @@ -145,16 +165,15 @@ func (a *remoteAuthenticator) getAndTouchCacheEntry(requestKey RemoteAuthenticat
return nil
}

func (a *remoteAuthenticator) authenticateRemotely(ctx context.Context, request *auth_pb.AuthenticateRequest) remoteAuthResponse {
func (a *remoteAuthenticator) authenticateRemotely(ctx context.Context, request *auth_pb.AuthenticateRequest) (*remoteAuthResponse, error) {
ret := remoteAuthResponse{
// The default expirationTime has already passed.
expirationTime: time.Time{},
}

response, err := a.remoteAuthClient.Authenticate(ctx, request)
if err != nil {
ret.err = util.StatusWrapWithCode(err, codes.Unauthenticated, "Remote authentication failed")
return ret
return nil, util.StatusWrapWithCode(err, codes.Unauthenticated, "Remote authentication failed")
}

// An invalid expiration time indicates that the response should not be cached.
Expand All @@ -168,14 +187,11 @@ func (a *remoteAuthenticator) authenticateRemotely(ctx context.Context, request
ret.authMetadata, err = auth.NewAuthenticationMetadataFromProto(verdict.Allow)
if err != nil {
ret.err = util.StatusWrapWithCode(err, codes.Unauthenticated, "Bad authentication response")
return ret
}
case *auth_pb.AuthenticateResponse_Deny:
ret.err = status.Error(codes.Unauthenticated, verdict.Deny)
return ret
default:
ret.err = status.Error(codes.Unauthenticated, "Invalid authentication verdict")
return ret
}
return ret
return &ret, nil
}
88 changes: 85 additions & 3 deletions pkg/grpc/remote_authenticator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,14 +88,14 @@ func TestRemoteAuthenticatorSuccess(t *testing.T) {
Public: structpb.NewStringValue("You're totally who you say you are: " + token),
},
},
CacheExpirationTime: timestamppb.New(time.Unix(1001, 0)),
CacheExpirationTime: timestamppb.New(time.Unix(1002, 0)),
})
} else if strings.HasPrefix(token, "deny") {
proto.Merge(reply.(proto.Message), &auth_pb.AuthenticateResponse{
Verdict: &auth_pb.AuthenticateResponse_Deny{
Deny: "You are an alien: " + token,
},
CacheExpirationTime: timestamppb.New(time.Unix(1001, 0)),
CacheExpirationTime: timestamppb.New(time.Unix(1002, 0)),
})
}
return nil
Expand Down Expand Up @@ -292,7 +292,7 @@ func TestRemoteAuthenticatorSuccess(t *testing.T) {
_, err := authenticator.Authenticate(ctx1c, map[string][]string{"Authorization": {"token1"}})
testutil.RequireEqualStatus(
t,
status.Error(codes.Unauthenticated, "Context cancelled: context canceled"),
status.Error(codes.Canceled, "context canceled"),
err)
close(done1c)
}()
Expand Down Expand Up @@ -327,4 +327,86 @@ func TestRemoteAuthenticatorSuccess(t *testing.T) {
<-done1a
<-done1b
})

t.Run("SkipDeduplicateErrors", func(t *testing.T) {
client := mock.NewMockClientConnInterface(ctrl)
clock := mock.NewMockClock(ctrl)

authCalled := make(chan struct{})
authRelease := make(map[string]chan struct{})

client.EXPECT().Invoke(
ctx, "/buildbarn.auth.Authentication/Authenticate", gomock.Any(), gomock.Any(),
).DoAndReturn(func(ctx context.Context, method string, args, reply interface{}, opts ...grpc.CallOption) error {
token := args.(*auth_pb.AuthenticateRequest).RequestMetadata["Authorization"].Value[0]
proto.Merge(reply.(proto.Message), &auth_pb.AuthenticateResponse{})
authCalled <- struct{}{}
<-authRelease[token]
return status.Error(codes.DataLoss, "Data loss")
})
client.EXPECT().Invoke(
ctx, "/buildbarn.auth.Authentication/Authenticate", gomock.Any(), gomock.Any(),
).DoAndReturn(func(ctx context.Context, method string, args, reply interface{}, opts ...grpc.CallOption) error {
token := args.(*auth_pb.AuthenticateRequest).RequestMetadata["Authorization"].Value[0]
proto.Merge(reply.(proto.Message), &auth_pb.AuthenticateResponse{})
authCalled <- struct{}{}
<-authRelease[token]
return nil
})

clock.EXPECT().Now().Return(time.Unix(1000, 0)).AnyTimes()

authenticator := bb_grpc.NewRemoteAuthenticator(
client,
structpb.NewStringValue("auth-scope"),
clock,
eviction.NewLRUSet[bb_grpc.RemoteAuthenticatorCacheKey](),
100,
)
doAuth := func(token string, done chan<- struct{}, verdict string) {
_, err := authenticator.Authenticate(ctx, map[string][]string{"Authorization": {token}})
defer close(done)
testutil.RequireEqualStatus(
t,
status.Error(codes.Unauthenticated, verdict),
err)
}

authRelease["token1"] = make(chan struct{})
done1a := make(chan struct{})
done1b := make(chan struct{})
done1c := make(chan struct{})
go doAuth("token1", done1a, "Remote authentication failed: Data loss")
<-authCalled // token1a
go doAuth("token1", done1b, "Invalid authentication verdict")
go doAuth("token1", done1c, "Invalid authentication verdict")
// Nothing done yet.
time.Sleep(100 * time.Millisecond)
select {
case <-done1a:
t.Error("done1a too early")
case <-done1b:
t.Error("done1b too early")
case <-done1c:
t.Error("done1c too early")
case <-authCalled:
t.Error("authCalled second time too early")
default:
// Noop.
}
close(authRelease["token1"])
// token1 still blocked.
time.Sleep(100 * time.Millisecond)
select {
case <-done1b:
t.Error("done1b too early")
case <-done1c:
t.Error("done1c too early")
case <-authCalled:
// token1b released.
// Noop.
}
<-done1b
<-done1c
})
}
82 changes: 52 additions & 30 deletions pkg/grpc/remote_authorizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,20 +37,29 @@ type RemoteAuthorizerCacheKey [sha256.Size]byte

type remoteAuthorizerCacheEntry struct {
ready <-chan struct{}
valid bool
expirationTime time.Time
err error
}

func (ce *remoteAuthorizerCacheEntry) HasExpired(now time.Time) bool {
func (ce *remoteAuthorizerCacheEntry) IsReady() bool {
select {
case <-ce.ready:
return ce.expirationTime.Before(now)
return true
default:
// Ongoing remote requests have not expired by definition.
return false
}
}

// IsValid returns false if a new remote request should be made.
func (ce *remoteAuthorizerCacheEntry) IsValid(now time.Time) bool {
if !ce.valid {
// Error response on the remote request, make a new request.
return false
}
return now.Before(ce.expirationTime)
}

// NewRemoteAuthorizer creates a new Authorizer which asks a remote gRPC
// service for authorize response. The result from the remote service is
// cached.
Expand Down Expand Up @@ -95,37 +104,50 @@ func (a *remoteAuthorizer) authorizeSingle(ctx context.Context, instanceName dig
// Hash the request to use as a cache key to save memory.
requestKey := sha256.Sum256(requestBytes)

a.lock.Lock()
now := a.clock.Now()
entry := a.getAndTouchCacheEntry(requestKey)
if entry != nil && entry.HasExpired(now) {
entry = nil
}
if entry == nil {
// No valid cache entry available. Deduplicate requests by creating a
// pending cached response.
responseReady := make(chan struct{})
entry = &remoteAuthorizerCacheEntry{
ready: responseReady,
for {
a.lock.Lock()
now := a.clock.Now()
entry := a.getAndTouchCacheEntry(requestKey)
if entry == nil || (entry.IsReady() && !entry.IsValid(now)) {
// No valid cache entry available. Deduplicate requests by creating a
// pending cached response.
responseReady := make(chan struct{})
entry = &remoteAuthorizerCacheEntry{
ready: responseReady,
}
a.cachedResponses[requestKey] = entry
a.lock.Unlock()

// Perform the remote authentication request.
expirationTime, err := a.authorizeRemotely(ctx, request)
if expirationTime == nil {
// The response should not be cached.
entry.valid = false
close(responseReady)
return err
}
entry.valid = true
entry.expirationTime = *expirationTime
entry.err = err
close(responseReady)
return entry.err
}
a.cachedResponses[requestKey] = entry
a.lock.Unlock()

// Perform the remote authentication request.
entry.expirationTime, entry.err = a.authorizeRemotely(ctx, request)
close(responseReady)
} else {
a.lock.Unlock()

// Wait for the remote request to finish.
select {
case <-ctx.Done():
return util.StatusWrapWithCode(ctx.Err(), codes.PermissionDenied, "Context cancelled")
return util.StatusFromContext(ctx)
case <-entry.ready:
// Noop
// Check whether the remote authentication call succeeded, otherwise
// retry with our own ctx.
if entry.valid {
// Note that the expiration time is not checked, as the response
// is as fresh as it can be.
return entry.err
}
}
}
return entry.err
}

func (a *remoteAuthorizer) getAndTouchCacheEntry(requestKey RemoteAuthorizerCacheKey) *remoteAuthorizerCacheEntry {
Expand All @@ -145,13 +167,13 @@ func (a *remoteAuthorizer) getAndTouchCacheEntry(requestKey RemoteAuthorizerCach
return nil
}

func (a *remoteAuthorizer) authorizeRemotely(ctx context.Context, request *auth_pb.AuthorizeRequest) (time.Time, error) {
func (a *remoteAuthorizer) authorizeRemotely(ctx context.Context, request *auth_pb.AuthorizeRequest) (*time.Time, error) {
// The default expirationTime has already passed.
expirationTime := time.Time{}

response, err := a.remoteAuthClient.Authorize(ctx, request)
if err != nil {
return expirationTime, util.StatusWrapWithCode(err, codes.PermissionDenied, "Remote authorization failed")
return nil, util.StatusWrapWithCode(err, codes.PermissionDenied, "Remote authorization failed")
}

// An invalid expiration time indicates that the response should not be cached.
Expand All @@ -162,10 +184,10 @@ func (a *remoteAuthorizer) authorizeRemotely(ctx context.Context, request *auth_

switch verdict := response.GetVerdict().(type) {
case *auth_pb.AuthorizeResponse_Allow:
return expirationTime, nil
return &expirationTime, nil
case *auth_pb.AuthorizeResponse_Deny:
return expirationTime, status.Error(codes.PermissionDenied, verdict.Deny)
return &expirationTime, status.Error(codes.PermissionDenied, verdict.Deny)
default:
return expirationTime, status.Error(codes.PermissionDenied, "Invalid authorize verdict")
return &expirationTime, status.Error(codes.PermissionDenied, "Invalid authorize verdict")
}
}
Loading

0 comments on commit de70210

Please sign in to comment.