diff --git a/pkg/server/api/api.go b/pkg/server/api/api.go index 29d207e685..b6f424f0b5 100644 --- a/pkg/server/api/api.go +++ b/pkg/server/api/api.go @@ -13,6 +13,9 @@ import ( // AuthorizedEntryFetcher is the interface to fetch authorized entries type AuthorizedEntryFetcher interface { + // LookupAuthorizedEntries fetches the entries in entrIDs that the + // specified SPIFFE ID is authorized for + LookupAuthorizedEntries(ctx context.Context, id spiffeid.ID, entryIDs map[string]struct{}) (map[string]*types.Entry, error) // FetchAuthorizedEntries fetches the entries that the specified // SPIFFE ID is authorized for FetchAuthorizedEntries(ctx context.Context, id spiffeid.ID) ([]*types.Entry, error) diff --git a/pkg/server/api/entry/v1/service_test.go b/pkg/server/api/entry/v1/service_test.go index 79de20f35c..7936fd56cf 100644 --- a/pkg/server/api/entry/v1/service_test.go +++ b/pkg/server/api/entry/v1/service_test.go @@ -4840,6 +4840,20 @@ type entryFetcher struct { entries []*types.Entry } +func (f *entryFetcher) LookupAuthorizedEntries(ctx context.Context, agentID spiffeid.ID, _ map[string]struct{}) (map[string]*types.Entry, error) { + entries, err := f.FetchAuthorizedEntries(ctx, agentID) + if err != nil { + return nil, err + } + + entriesMap := make(map[string]*types.Entry) + for _, entry := range entries { + entriesMap[entry.GetId()] = entry + } + + return entriesMap, nil +} + func (f *entryFetcher) FetchAuthorizedEntries(ctx context.Context, agentID spiffeid.ID) ([]*types.Entry, error) { if f.err != "" { return nil, status.Error(codes.Internal, f.err) diff --git a/pkg/server/api/svid/v1/service.go b/pkg/server/api/svid/v1/service.go index 0b04139178..0e8dc6c550 100644 --- a/pkg/server/api/svid/v1/service.go +++ b/pkg/server/api/svid/v1/service.go @@ -168,8 +168,13 @@ func (s *Service) BatchNewX509SVID(ctx context.Context, req *svidv1.BatchNewX509 return nil, api.MakeErr(log, status.Code(err), "rejecting request due to certificate signing rate limiting", err) } + requestedEntries := make(map[string]struct{}) + for _, svidParam := range req.Params { + requestedEntries[svidParam.GetEntryId()] = struct{}{} + } + // Fetch authorized entries - entriesMap, err := s.fetchEntries(ctx, log) + entriesMap, err := s.findEntries(ctx, log, requestedEntries) if err != nil { return nil, err } @@ -205,24 +210,17 @@ func (s *Service) BatchNewX509SVID(ctx context.Context, req *svidv1.BatchNewX509 return &svidv1.BatchNewX509SVIDResponse{Results: results}, nil } -// fetchEntries fetches authorized entries using caller ID from context -func (s *Service) fetchEntries(ctx context.Context, log logrus.FieldLogger) (map[string]*types.Entry, error) { +func (s *Service) findEntries(ctx context.Context, log logrus.FieldLogger, entries map[string]struct{}) (map[string]*types.Entry, error) { callerID, ok := rpccontext.CallerID(ctx) if !ok { return nil, api.MakeErr(log, codes.Internal, "caller ID missing from request context", nil) } - entries, err := s.ef.FetchAuthorizedEntries(ctx, callerID) + foundEntries, err := s.ef.LookupAuthorizedEntries(ctx, callerID, entries) if err != nil { return nil, api.MakeErr(log, codes.Internal, "failed to fetch registration entries", err) } - - entriesMap := make(map[string]*types.Entry, len(entries)) - for _, entry := range entries { - entriesMap[entry.Id] = entry - } - - return entriesMap, nil + return foundEntries, nil } // newX509SVID creates an X509-SVID using data from registration entry and key from CSR @@ -262,7 +260,7 @@ func (s *Service) newX509SVID(ctx context.Context, param *svidv1.NewX509SVIDPara } } - spiffeID, err := api.TrustDomainMemberIDFromProto(ctx, s.td, entry.SpiffeId) + spiffeID, err := api.TrustDomainMemberIDFromProto(ctx, s.td, entry.GetSpiffeId()) if err != nil { // This shouldn't be the case unless there is invalid data in the datastore return &svidv1.BatchNewX509SVIDResponse_Result{ @@ -274,8 +272,8 @@ func (s *Service) newX509SVID(ctx context.Context, param *svidv1.NewX509SVIDPara x509Svid, err := s.ca.SignWorkloadX509SVID(ctx, ca.WorkloadX509SVIDParams{ SPIFFEID: spiffeID, PublicKey: csr.PublicKey, - DNSNames: entry.DnsNames, - TTL: time.Duration(entry.X509SvidTtl) * time.Second, + DNSNames: entry.GetDnsNames(), + TTL: time.Duration(entry.GetX509SvidTtl()) * time.Second, }) if err != nil { return &svidv1.BatchNewX509SVIDResponse_Result{ @@ -285,12 +283,12 @@ func (s *Service) newX509SVID(ctx context.Context, param *svidv1.NewX509SVIDPara log.WithField(telemetry.Expiration, x509Svid[0].NotAfter.Format(time.RFC3339)). WithField(telemetry.SerialNumber, x509Svid[0].SerialNumber.String()). - WithField(telemetry.RevisionNumber, entry.RevisionNumber). + WithField(telemetry.RevisionNumber, entry.GetRevisionNumber()). Debug("Signed X509 SVID") return &svidv1.BatchNewX509SVIDResponse_Result{ Svid: &types.X509SVID{ - Id: entry.SpiffeId, + Id: entry.GetSpiffeId(), CertChain: x509util.RawCertsFromCertificates(x509Svid), ExpiresAt: x509Svid[0].NotAfter.Unix(), }, @@ -350,8 +348,12 @@ func (s *Service) NewJWTSVID(ctx context.Context, req *svidv1.NewJWTSVIDRequest) return nil, api.MakeErr(log, status.Code(err), "rejecting request due to JWT signing request rate limiting", err) } + entries := map[string]struct{}{ + req.EntryId: {}, + } + // Fetch authorized entries - entriesMap, err := s.fetchEntries(ctx, log) + entriesMap, err := s.findEntries(ctx, log, entries) if err != nil { return nil, err } @@ -361,12 +363,12 @@ func (s *Service) NewJWTSVID(ctx context.Context, req *svidv1.NewJWTSVIDRequest) return nil, api.MakeErr(log, codes.NotFound, "entry not found or not authorized", nil) } - jwtsvid, err := s.mintJWTSVID(ctx, entry.SpiffeId, req.Audience, entry.JwtSvidTtl) + jwtsvid, err := s.mintJWTSVID(ctx, entry.GetSpiffeId(), req.Audience, entry.GetJwtSvidTtl()) if err != nil { return nil, err } rpccontext.AuditRPCWithFields(ctx, logrus.Fields{ - telemetry.TTL: entry.JwtSvidTtl, + telemetry.TTL: entry.GetJwtSvidTtl(), }) return &svidv1.NewJWTSVIDResponse{ diff --git a/pkg/server/api/svid/v1/service_test.go b/pkg/server/api/svid/v1/service_test.go index 68951297a6..3759a56ad5 100644 --- a/pkg/server/api/svid/v1/service_test.go +++ b/pkg/server/api/svid/v1/service_test.go @@ -2173,6 +2173,20 @@ type entryFetcher struct { entries []*types.Entry } +func (f *entryFetcher) LookupAuthorizedEntries(ctx context.Context, agentID spiffeid.ID, _ map[string]struct{}) (map[string]*types.Entry, error) { + entries, err := f.FetchAuthorizedEntries(ctx, agentID) + if err != nil { + return nil, err + } + + entriesMap := make(map[string]*types.Entry) + for _, entry := range entries { + entriesMap[entry.GetId()] = entry + } + + return entriesMap, nil +} + func (f *entryFetcher) FetchAuthorizedEntries(ctx context.Context, agentID spiffeid.ID) ([]*types.Entry, error) { if f.err != "" { return nil, status.Error(codes.Internal, f.err) diff --git a/pkg/server/authorizedentries/cache.go b/pkg/server/authorizedentries/cache.go index 7a3261782b..9295aec8d3 100644 --- a/pkg/server/authorizedentries/cache.go +++ b/pkg/server/authorizedentries/cache.go @@ -59,6 +59,37 @@ func NewCache(clk clock.Clock) *Cache { } } +func (c *Cache) LookupAuthorizedEntries(agentID spiffeid.ID, entryIDs map[string]struct{}) map[string]*types.Entry { + c.mu.RLock() + defer c.mu.RUnlock() + + // Load up the agent selectors. If the agent info does not exist, it is + // likely that the cache is still catching up to a recent attestation. + // Since the calling agent has already been authorized and authenticated, + // it is safe to continue with the authorized entry crawl to obtain entries + // that are directly parented against the agent. Any entries that would be + // obtained via node aliasing will not be returned until the cache is + // updated with the node selectors for the agent. + agent, _ := c.agentsByID.Get(agentRecord{ID: agentID.String()}) + + parentSeen := allocStringSet() + defer freeStringSet(parentSeen) + + records := allocRecordSlice() + defer freeRecordSlice(records) + + foundEntries := make(map[string]*types.Entry) + + c.findDescendents(records, foundEntries, agentID.String(), entryIDs, parentSeen) + + agentAliases := c.getAgentAliases(agent.Selectors) + for _, alias := range agentAliases { + c.findDescendents(records, foundEntries, alias.AliasID, entryIDs, parentSeen) + } + + return foundEntries +} + func (c *Cache) GetAuthorizedEntries(agentID spiffeid.ID) []*types.Entry { c.mu.RLock() defer c.mu.RUnlock() @@ -164,6 +195,24 @@ func (c *Cache) appendDescendents(records []entryRecord, parentID string, parent return records } +func (c *Cache) findDescendents(records []entryRecord, foundEntries map[string]*types.Entry, parentID string, entryIDs map[string]struct{}, parentSeen stringSet) []entryRecord { + if _, ok := parentSeen[parentID]; ok { + return records + } + parentSeen[parentID] = struct{}{} + + lenBefore := len(records) + records = c.appendEntryRecordsForParentID(records, parentID) + // Crawl the children that were appended to get their descendents + for _, entry := range records[lenBefore:] { + if _, ok := entryIDs[entry.EntryID]; ok { + foundEntries[entry.EntryID] = cloneEntry(entry.EntryCloneOnly) + } + records = c.findDescendents(records, foundEntries, entry.SPIFFEID, entryIDs, parentSeen) + } + return records +} + func (c *Cache) appendEntryRecordsForParentID(records []entryRecord, parentID string) []entryRecord { pivot := entryRecord{ParentID: parentID} c.entriesByParentID.AscendGreaterOrEqual(pivot, func(record entryRecord) bool { diff --git a/pkg/server/cache/entrycache/fullcache.go b/pkg/server/cache/entrycache/fullcache.go index 19fa461cfe..77f0a215e6 100644 --- a/pkg/server/cache/entrycache/fullcache.go +++ b/pkg/server/cache/entrycache/fullcache.go @@ -2,6 +2,7 @@ package entrycache import ( "context" + "maps" "sync" "github.com/spiffe/go-spiffe/v2/spiffeid" @@ -28,6 +29,7 @@ var _ Cache = (*FullEntryCache)(nil) // Cache contains a snapshot of all registration entries and Agent selectors from the data source // at a particular moment in time. type Cache interface { + LookupAuthorizedEntries(agentID spiffeid.ID, entries map[string]struct{}) map[string]*types.Entry GetAuthorizedEntries(agentID spiffeid.ID) []*types.Entry } @@ -173,6 +175,13 @@ func Build(ctx context.Context, entryIter EntryIterator, agentIter AgentIterator }, nil } +func (c *FullEntryCache) LookupAuthorizedEntries(agentID spiffeid.ID, entries map[string]struct{}) map[string]*types.Entry { + seen := allocSeenSet() + defer freeSeenSet(seen) + + return c.lookupAuthorizedEntries(spiffeIDFromID(agentID), entries, seen) +} + // GetAuthorizedEntries gets all authorized registration entries for a given Agent SPIFFE ID. func (c *FullEntryCache) GetAuthorizedEntries(agentID spiffeid.ID) []*types.Entry { seen := allocSeenSet() @@ -181,6 +190,22 @@ func (c *FullEntryCache) GetAuthorizedEntries(agentID spiffeid.ID) []*types.Entr return cloneEntries(c.getAuthorizedEntries(spiffeIDFromID(agentID), seen)) } +func (c *FullEntryCache) lookupAuthorizedEntries(id spiffeID, entries map[string]struct{}, seen map[spiffeID]struct{}) map[string]*types.Entry { + foundEntries := make(map[string]*types.Entry) + + for _, descendant := range c.crawl(id, seen) { + if _, ok := entries[descendant.Id]; ok { + foundEntries[descendant.Id] = descendant + } + maps.Copy(foundEntries, c.lookupAuthorizedEntries(spiffeIDFromProto(descendant.SpiffeId), entries, seen)) + } + + for _, alias := range c.aliases[id] { + maps.Copy(foundEntries, c.lookupAuthorizedEntries(alias.id, entries, seen)) + } + return foundEntries +} + func (c *FullEntryCache) getAuthorizedEntries(id spiffeID, seen map[spiffeID]struct{}) []*types.Entry { entries := c.crawl(id, seen) for _, descendant := range entries { diff --git a/pkg/server/cache/entrycache/fullcache_test.go b/pkg/server/cache/entrycache/fullcache_test.go index 3b2c34f7c8..636b2f6adb 100644 --- a/pkg/server/cache/entrycache/fullcache_test.go +++ b/pkg/server/cache/entrycache/fullcache_test.go @@ -128,7 +128,7 @@ func TestCache(t *testing.T) { expected := entries[:3] expected = append(expected, entries[4]) - assertAuthorizedEntries(t, cache, rootID, expected...) + assertAuthorizedEntries(t, cache, rootID, entries, expected...) } func TestCacheReturnsClonedEntries(t *testing.T) { @@ -232,9 +232,9 @@ func TestFullCacheNodeAliasing(t *testing.T) { cache, err := BuildFromDataStore(context.Background(), ds) assert.NoError(t, err) - assertAuthorizedEntries(t, cache, agentIDs[0], workloadEntries[:2]...) - assertAuthorizedEntries(t, cache, agentIDs[1], workloadEntries[1]) - assertAuthorizedEntries(t, cache, agentIDs[2], workloadEntries[2]) + assertAuthorizedEntries(t, cache, agentIDs[0], workloadEntries, workloadEntries[:2]...) + assertAuthorizedEntries(t, cache, agentIDs[1], workloadEntries, workloadEntries[1]) + assertAuthorizedEntries(t, cache, agentIDs[2], workloadEntries, workloadEntries[2]) } func TestFullCacheExcludesNodeSelectorMappedEntriesForExpiredAgents(t *testing.T) { @@ -795,7 +795,7 @@ func newSQLPlugin(ctx context.Context, tb testing.TB) datastore.DataStore { return p } -func assertAuthorizedEntries(tb testing.TB, cache Cache, agentID spiffeid.ID, entries ...*common.RegistrationEntry) { +func assertAuthorizedEntries(tb testing.TB, cache Cache, agentID spiffeid.ID, allEntries []*common.RegistrationEntry, entries ...*common.RegistrationEntry) { tb.Helper() expected, err := api.RegistrationEntriesToProto(entries) require.NoError(tb, err) @@ -806,6 +806,22 @@ func assertAuthorizedEntries(tb testing.TB, cache Cache, agentID spiffeid.ID, en sortEntries(authorizedEntries) spiretest.AssertProtoListEqual(tb, expected, authorizedEntries) + + assertLookupEntries(tb, cache, agentID, allEntries, entries...) +} + +func assertLookupEntries(tb testing.TB, cache Cache, agentID spiffeid.ID, lookup []*common.RegistrationEntry, entries ...*common.RegistrationEntry) { + tb.Helper() + expected, err := api.RegistrationEntriesToProto(entries) + require.NoError(tb, err) + sortEntries(expected) + + lookupEntries := make(map[string]struct{}) + for _, entry := range lookup { + lookupEntries[entry.EntryId] = struct{}{} + } + foundEntries := cache.LookupAuthorizedEntries(agentID, lookupEntries) + require.Len(tb, foundEntries, len(entries)) } func sortEntries(es []*types.Entry) { diff --git a/pkg/server/endpoints/authorized_entryfetcher.go b/pkg/server/endpoints/authorized_entryfetcher.go index 0d31853129..7099561d54 100644 --- a/pkg/server/endpoints/authorized_entryfetcher.go +++ b/pkg/server/endpoints/authorized_entryfetcher.go @@ -56,6 +56,10 @@ func NewAuthorizedEntryFetcherWithEventsBasedCache(ctx context.Context, log logr }, nil } +func (a *AuthorizedEntryFetcherWithEventsBasedCache) LookupAuthorizedEntries(ctx context.Context, agentID spiffeid.ID, entryIDs map[string]struct{}) (map[string]*types.Entry, error) { + return a.cache.LookupAuthorizedEntries(agentID, entryIDs), nil +} + func (a *AuthorizedEntryFetcherWithEventsBasedCache) FetchAuthorizedEntries(_ context.Context, agentID spiffeid.ID) ([]*types.Entry, error) { return a.cache.GetAuthorizedEntries(agentID), nil } diff --git a/pkg/server/endpoints/entryfetcher.go b/pkg/server/endpoints/entryfetcher.go index 19e2e0b6b9..bd2f76e2c4 100644 --- a/pkg/server/endpoints/entryfetcher.go +++ b/pkg/server/endpoints/entryfetcher.go @@ -49,6 +49,12 @@ func NewAuthorizedEntryFetcherWithFullCache(ctx context.Context, buildCache entr }, nil } +func (a *AuthorizedEntryFetcherWithFullCache) LookupAuthorizedEntries(ctx context.Context, agentID spiffeid.ID, entryIDs map[string]struct{}) (map[string]*types.Entry, error) { + a.mu.RLock() + defer a.mu.RUnlock() + return a.cache.LookupAuthorizedEntries(agentID, entryIDs), nil +} + func (a *AuthorizedEntryFetcherWithFullCache) FetchAuthorizedEntries(_ context.Context, agentID spiffeid.ID) ([]*types.Entry, error) { a.mu.RLock() defer a.mu.RUnlock() diff --git a/pkg/server/endpoints/entryfetcher_test.go b/pkg/server/endpoints/entryfetcher_test.go index 5262f72676..6be4725ad4 100644 --- a/pkg/server/endpoints/entryfetcher_test.go +++ b/pkg/server/endpoints/entryfetcher_test.go @@ -28,6 +28,17 @@ type staticEntryCache struct { entries map[spiffeid.ID][]*types.Entry } +func (f *staticEntryCache) LookupAuthorizedEntries(agentID spiffeid.ID, _ map[string]struct{}) map[string]*types.Entry { + entries := f.entries[agentID] + + entriesMap := make(map[string]*types.Entry) + for _, entry := range entries { + entriesMap[entry.GetId()] = entry + } + + return entriesMap +} + func (sef *staticEntryCache) GetAuthorizedEntries(agentID spiffeid.ID) []*types.Entry { return sef.entries[agentID] }