Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
Signed-off-by: Sorin Dumitru <sdumitru@bloomberg.net>
  • Loading branch information
sorindumitru committed Feb 3, 2025
1 parent 925bc9c commit d8cc2ac
Show file tree
Hide file tree
Showing 10 changed files with 168 additions and 24 deletions.
3 changes: 3 additions & 0 deletions pkg/server/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ import (

// AuthorizedEntryFetcher is the interface to fetch authorized entries
type AuthorizedEntryFetcher interface {
// LookupAuthorizedEntries fetches the entries in entrIDs that the
// specified SPIFFE ID is authorized for
LookupAuthorizedEntries(ctx context.Context, id spiffeid.ID, entryIDs map[string]struct{}) (map[string]*types.Entry, error)
// FetchAuthorizedEntries fetches the entries that the specified
// SPIFFE ID is authorized for
FetchAuthorizedEntries(ctx context.Context, id spiffeid.ID) ([]*types.Entry, error)
Expand Down
14 changes: 14 additions & 0 deletions pkg/server/api/entry/v1/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4840,6 +4840,20 @@ type entryFetcher struct {
entries []*types.Entry
}

func (f *entryFetcher) LookupAuthorizedEntries(ctx context.Context, agentID spiffeid.ID, _ map[string]struct{}) (map[string]*types.Entry, error) {
entries, err := f.FetchAuthorizedEntries(ctx, agentID)
if err != nil {
return nil, err
}

entriesMap := make(map[string]*types.Entry)
for _, entry := range entries {
entriesMap[entry.GetId()] = entry
}

return entriesMap, nil
}

func (f *entryFetcher) FetchAuthorizedEntries(ctx context.Context, agentID spiffeid.ID) ([]*types.Entry, error) {
if f.err != "" {
return nil, status.Error(codes.Internal, f.err)
Expand Down
40 changes: 21 additions & 19 deletions pkg/server/api/svid/v1/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,13 @@ func (s *Service) BatchNewX509SVID(ctx context.Context, req *svidv1.BatchNewX509
return nil, api.MakeErr(log, status.Code(err), "rejecting request due to certificate signing rate limiting", err)
}

requestedEntries := make(map[string]struct{})
for _, svidParam := range req.Params {
requestedEntries[svidParam.GetEntryId()] = struct{}{}
}

// Fetch authorized entries
entriesMap, err := s.fetchEntries(ctx, log)
entriesMap, err := s.findEntries(ctx, log, requestedEntries)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -205,24 +210,17 @@ func (s *Service) BatchNewX509SVID(ctx context.Context, req *svidv1.BatchNewX509
return &svidv1.BatchNewX509SVIDResponse{Results: results}, nil
}

// fetchEntries fetches authorized entries using caller ID from context
func (s *Service) fetchEntries(ctx context.Context, log logrus.FieldLogger) (map[string]*types.Entry, error) {
func (s *Service) findEntries(ctx context.Context, log logrus.FieldLogger, entries map[string]struct{}) (map[string]*types.Entry, error) {
callerID, ok := rpccontext.CallerID(ctx)
if !ok {
return nil, api.MakeErr(log, codes.Internal, "caller ID missing from request context", nil)
}

entries, err := s.ef.FetchAuthorizedEntries(ctx, callerID)
foundEntries, err := s.ef.LookupAuthorizedEntries(ctx, callerID, entries)
if err != nil {
return nil, api.MakeErr(log, codes.Internal, "failed to fetch registration entries", err)
}

entriesMap := make(map[string]*types.Entry, len(entries))
for _, entry := range entries {
entriesMap[entry.Id] = entry
}

return entriesMap, nil
return foundEntries, nil
}

// newX509SVID creates an X509-SVID using data from registration entry and key from CSR
Expand Down Expand Up @@ -262,7 +260,7 @@ func (s *Service) newX509SVID(ctx context.Context, param *svidv1.NewX509SVIDPara
}
}

spiffeID, err := api.TrustDomainMemberIDFromProto(ctx, s.td, entry.SpiffeId)
spiffeID, err := api.TrustDomainMemberIDFromProto(ctx, s.td, entry.GetSpiffeId())
if err != nil {
// This shouldn't be the case unless there is invalid data in the datastore
return &svidv1.BatchNewX509SVIDResponse_Result{
Expand All @@ -274,8 +272,8 @@ func (s *Service) newX509SVID(ctx context.Context, param *svidv1.NewX509SVIDPara
x509Svid, err := s.ca.SignWorkloadX509SVID(ctx, ca.WorkloadX509SVIDParams{
SPIFFEID: spiffeID,
PublicKey: csr.PublicKey,
DNSNames: entry.DnsNames,
TTL: time.Duration(entry.X509SvidTtl) * time.Second,
DNSNames: entry.GetDnsNames(),
TTL: time.Duration(entry.GetX509SvidTtl()) * time.Second,
})
if err != nil {
return &svidv1.BatchNewX509SVIDResponse_Result{
Expand All @@ -285,12 +283,12 @@ func (s *Service) newX509SVID(ctx context.Context, param *svidv1.NewX509SVIDPara

log.WithField(telemetry.Expiration, x509Svid[0].NotAfter.Format(time.RFC3339)).
WithField(telemetry.SerialNumber, x509Svid[0].SerialNumber.String()).
WithField(telemetry.RevisionNumber, entry.RevisionNumber).
WithField(telemetry.RevisionNumber, entry.GetRevisionNumber()).
Debug("Signed X509 SVID")

return &svidv1.BatchNewX509SVIDResponse_Result{
Svid: &types.X509SVID{
Id: entry.SpiffeId,
Id: entry.GetSpiffeId(),
CertChain: x509util.RawCertsFromCertificates(x509Svid),
ExpiresAt: x509Svid[0].NotAfter.Unix(),
},
Expand Down Expand Up @@ -350,8 +348,12 @@ func (s *Service) NewJWTSVID(ctx context.Context, req *svidv1.NewJWTSVIDRequest)
return nil, api.MakeErr(log, status.Code(err), "rejecting request due to JWT signing request rate limiting", err)
}

entries := map[string]struct{}{
req.EntryId: {},
}

// Fetch authorized entries
entriesMap, err := s.fetchEntries(ctx, log)
entriesMap, err := s.findEntries(ctx, log, entries)
if err != nil {
return nil, err
}
Expand All @@ -361,12 +363,12 @@ func (s *Service) NewJWTSVID(ctx context.Context, req *svidv1.NewJWTSVIDRequest)
return nil, api.MakeErr(log, codes.NotFound, "entry not found or not authorized", nil)
}

jwtsvid, err := s.mintJWTSVID(ctx, entry.SpiffeId, req.Audience, entry.JwtSvidTtl)
jwtsvid, err := s.mintJWTSVID(ctx, entry.GetSpiffeId(), req.Audience, entry.GetJwtSvidTtl())
if err != nil {
return nil, err
}
rpccontext.AuditRPCWithFields(ctx, logrus.Fields{
telemetry.TTL: entry.JwtSvidTtl,
telemetry.TTL: entry.GetJwtSvidTtl(),
})

return &svidv1.NewJWTSVIDResponse{
Expand Down
14 changes: 14 additions & 0 deletions pkg/server/api/svid/v1/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2173,6 +2173,20 @@ type entryFetcher struct {
entries []*types.Entry
}

func (f *entryFetcher) LookupAuthorizedEntries(ctx context.Context, agentID spiffeid.ID, _ map[string]struct{}) (map[string]*types.Entry, error) {
entries, err := f.FetchAuthorizedEntries(ctx, agentID)
if err != nil {
return nil, err
}

entriesMap := make(map[string]*types.Entry)
for _, entry := range entries {
entriesMap[entry.GetId()] = entry
}

return entriesMap, nil
}

func (f *entryFetcher) FetchAuthorizedEntries(ctx context.Context, agentID spiffeid.ID) ([]*types.Entry, error) {
if f.err != "" {
return nil, status.Error(codes.Internal, f.err)
Expand Down
49 changes: 49 additions & 0 deletions pkg/server/authorizedentries/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,37 @@ func NewCache(clk clock.Clock) *Cache {
}
}

func (c *Cache) LookupAuthorizedEntries(agentID spiffeid.ID, entryIDs map[string]struct{}) map[string]*types.Entry {
c.mu.RLock()
defer c.mu.RUnlock()

// Load up the agent selectors. If the agent info does not exist, it is
// likely that the cache is still catching up to a recent attestation.
// Since the calling agent has already been authorized and authenticated,
// it is safe to continue with the authorized entry crawl to obtain entries
// that are directly parented against the agent. Any entries that would be
// obtained via node aliasing will not be returned until the cache is
// updated with the node selectors for the agent.
agent, _ := c.agentsByID.Get(agentRecord{ID: agentID.String()})

parentSeen := allocStringSet()
defer freeStringSet(parentSeen)

records := allocRecordSlice()
defer freeRecordSlice(records)

foundEntries := make(map[string]*types.Entry)

c.findDescendents(records, foundEntries, agentID.String(), entryIDs, parentSeen)

agentAliases := c.getAgentAliases(agent.Selectors)
for _, alias := range agentAliases {
c.findDescendents(records, foundEntries, alias.AliasID, entryIDs, parentSeen)
}

return foundEntries
}

func (c *Cache) GetAuthorizedEntries(agentID spiffeid.ID) []*types.Entry {
c.mu.RLock()
defer c.mu.RUnlock()
Expand Down Expand Up @@ -164,6 +195,24 @@ func (c *Cache) appendDescendents(records []entryRecord, parentID string, parent
return records
}

func (c *Cache) findDescendents(records []entryRecord, foundEntries map[string]*types.Entry, parentID string, entryIDs map[string]struct{}, parentSeen stringSet) []entryRecord {
if _, ok := parentSeen[parentID]; ok {
return records
}
parentSeen[parentID] = struct{}{}

lenBefore := len(records)
records = c.appendEntryRecordsForParentID(records, parentID)
// Crawl the children that were appended to get their descendents
for _, entry := range records[lenBefore:] {
if _, ok := entryIDs[entry.EntryID]; ok {
foundEntries[entry.EntryID] = cloneEntry(entry.EntryCloneOnly)
}
records = c.findDescendents(records, foundEntries, entry.SPIFFEID, entryIDs, parentSeen)
}
return records
}

func (c *Cache) appendEntryRecordsForParentID(records []entryRecord, parentID string) []entryRecord {
pivot := entryRecord{ParentID: parentID}
c.entriesByParentID.AscendGreaterOrEqual(pivot, func(record entryRecord) bool {
Expand Down
25 changes: 25 additions & 0 deletions pkg/server/cache/entrycache/fullcache.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package entrycache

import (
"context"
"maps"
"sync"

"github.com/spiffe/go-spiffe/v2/spiffeid"
Expand All @@ -28,6 +29,7 @@ var _ Cache = (*FullEntryCache)(nil)
// Cache contains a snapshot of all registration entries and Agent selectors from the data source
// at a particular moment in time.
type Cache interface {
LookupAuthorizedEntries(agentID spiffeid.ID, entries map[string]struct{}) map[string]*types.Entry
GetAuthorizedEntries(agentID spiffeid.ID) []*types.Entry
}

Expand Down Expand Up @@ -173,6 +175,13 @@ func Build(ctx context.Context, entryIter EntryIterator, agentIter AgentIterator
}, nil
}

func (c *FullEntryCache) LookupAuthorizedEntries(agentID spiffeid.ID, entries map[string]struct{}) map[string]*types.Entry {
seen := allocSeenSet()
defer freeSeenSet(seen)

return c.lookupAuthorizedEntries(spiffeIDFromID(agentID), entries, seen)
}

// GetAuthorizedEntries gets all authorized registration entries for a given Agent SPIFFE ID.
func (c *FullEntryCache) GetAuthorizedEntries(agentID spiffeid.ID) []*types.Entry {
seen := allocSeenSet()
Expand All @@ -181,6 +190,22 @@ func (c *FullEntryCache) GetAuthorizedEntries(agentID spiffeid.ID) []*types.Entr
return cloneEntries(c.getAuthorizedEntries(spiffeIDFromID(agentID), seen))
}

func (c *FullEntryCache) lookupAuthorizedEntries(id spiffeID, entries map[string]struct{}, seen map[spiffeID]struct{}) map[string]*types.Entry {
foundEntries := make(map[string]*types.Entry)

for _, descendant := range c.crawl(id, seen) {
if _, ok := entries[descendant.Id]; ok {
foundEntries[descendant.Id] = descendant
}
maps.Copy(foundEntries, c.lookupAuthorizedEntries(spiffeIDFromProto(descendant.SpiffeId), entries, seen))
}

for _, alias := range c.aliases[id] {
maps.Copy(foundEntries, c.lookupAuthorizedEntries(alias.id, entries, seen))
}
return foundEntries
}

func (c *FullEntryCache) getAuthorizedEntries(id spiffeID, seen map[spiffeID]struct{}) []*types.Entry {
entries := c.crawl(id, seen)
for _, descendant := range entries {
Expand Down
26 changes: 21 additions & 5 deletions pkg/server/cache/entrycache/fullcache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ func TestCache(t *testing.T) {

expected := entries[:3]
expected = append(expected, entries[4])
assertAuthorizedEntries(t, cache, rootID, expected...)
assertAuthorizedEntries(t, cache, rootID, entries, expected...)
}

func TestCacheReturnsClonedEntries(t *testing.T) {
Expand Down Expand Up @@ -232,9 +232,9 @@ func TestFullCacheNodeAliasing(t *testing.T) {
cache, err := BuildFromDataStore(context.Background(), ds)
assert.NoError(t, err)

assertAuthorizedEntries(t, cache, agentIDs[0], workloadEntries[:2]...)
assertAuthorizedEntries(t, cache, agentIDs[1], workloadEntries[1])
assertAuthorizedEntries(t, cache, agentIDs[2], workloadEntries[2])
assertAuthorizedEntries(t, cache, agentIDs[0], workloadEntries, workloadEntries[:2]...)
assertAuthorizedEntries(t, cache, agentIDs[1], workloadEntries, workloadEntries[1])
assertAuthorizedEntries(t, cache, agentIDs[2], workloadEntries, workloadEntries[2])
}

func TestFullCacheExcludesNodeSelectorMappedEntriesForExpiredAgents(t *testing.T) {
Expand Down Expand Up @@ -795,7 +795,7 @@ func newSQLPlugin(ctx context.Context, tb testing.TB) datastore.DataStore {
return p
}

func assertAuthorizedEntries(tb testing.TB, cache Cache, agentID spiffeid.ID, entries ...*common.RegistrationEntry) {
func assertAuthorizedEntries(tb testing.TB, cache Cache, agentID spiffeid.ID, allEntries []*common.RegistrationEntry, entries ...*common.RegistrationEntry) {
tb.Helper()
expected, err := api.RegistrationEntriesToProto(entries)
require.NoError(tb, err)
Expand All @@ -806,6 +806,22 @@ func assertAuthorizedEntries(tb testing.TB, cache Cache, agentID spiffeid.ID, en
sortEntries(authorizedEntries)

spiretest.AssertProtoListEqual(tb, expected, authorizedEntries)

assertLookupEntries(tb, cache, agentID, allEntries, entries...)
}

func assertLookupEntries(tb testing.TB, cache Cache, agentID spiffeid.ID, lookup []*common.RegistrationEntry, entries ...*common.RegistrationEntry) {
tb.Helper()
expected, err := api.RegistrationEntriesToProto(entries)
require.NoError(tb, err)
sortEntries(expected)

lookupEntries := make(map[string]struct{})
for _, entry := range lookup {
lookupEntries[entry.EntryId] = struct{}{}
}
foundEntries := cache.LookupAuthorizedEntries(agentID, lookupEntries)
require.Len(tb, foundEntries, len(entries))
}

func sortEntries(es []*types.Entry) {
Expand Down
4 changes: 4 additions & 0 deletions pkg/server/endpoints/authorized_entryfetcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ func NewAuthorizedEntryFetcherWithEventsBasedCache(ctx context.Context, log logr
}, nil
}

func (a *AuthorizedEntryFetcherWithEventsBasedCache) LookupAuthorizedEntries(ctx context.Context, agentID spiffeid.ID, entryIDs map[string]struct{}) (map[string]*types.Entry, error) {
return a.cache.LookupAuthorizedEntries(agentID, entryIDs), nil
}

func (a *AuthorizedEntryFetcherWithEventsBasedCache) FetchAuthorizedEntries(_ context.Context, agentID spiffeid.ID) ([]*types.Entry, error) {
return a.cache.GetAuthorizedEntries(agentID), nil
}
Expand Down
6 changes: 6 additions & 0 deletions pkg/server/endpoints/entryfetcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ func NewAuthorizedEntryFetcherWithFullCache(ctx context.Context, buildCache entr
}, nil
}

func (a *AuthorizedEntryFetcherWithFullCache) LookupAuthorizedEntries(ctx context.Context, agentID spiffeid.ID, entryIDs map[string]struct{}) (map[string]*types.Entry, error) {
a.mu.RLock()
defer a.mu.RUnlock()
return a.cache.LookupAuthorizedEntries(agentID, entryIDs), nil
}

func (a *AuthorizedEntryFetcherWithFullCache) FetchAuthorizedEntries(_ context.Context, agentID spiffeid.ID) ([]*types.Entry, error) {
a.mu.RLock()
defer a.mu.RUnlock()
Expand Down
11 changes: 11 additions & 0 deletions pkg/server/endpoints/entryfetcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,17 @@ type staticEntryCache struct {
entries map[spiffeid.ID][]*types.Entry
}

func (f *staticEntryCache) LookupAuthorizedEntries(agentID spiffeid.ID, _ map[string]struct{}) map[string]*types.Entry {
entries := f.entries[agentID]

entriesMap := make(map[string]*types.Entry)
for _, entry := range entries {
entriesMap[entry.GetId()] = entry
}

return entriesMap
}

func (sef *staticEntryCache) GetAuthorizedEntries(agentID spiffeid.ID) []*types.Entry {
return sef.entries[agentID]
}
Expand Down

0 comments on commit d8cc2ac

Please sign in to comment.