Skip to content

Commit 037fadd

Browse files
committed
Prefix aware scorer data structure fix
Signed-off-by: Ricardo Noriega <rnoriega@redhat.com>
1 parent 4966409 commit 037fadd

File tree

9 files changed

+100
-11
lines changed

9 files changed

+100
-11
lines changed

cmd/epp/main.go

+8-1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
"net/http"
2424
"os"
2525
"strconv"
26+
"time"
2627

2728
"github.com/go-logr/logr"
2829
"github.com/prometheus/client_golang/prometheus/promhttp"
@@ -42,6 +43,7 @@ import (
4243
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
4344
runserver "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/server"
4445
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
46+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling"
4547
)
4648

4749
const (
@@ -162,7 +164,12 @@ func run() error {
162164

163165
pmf := backendmetrics.NewPodMetricsFactory(&backendmetrics.PodMetricsClientImpl{MetricMapping: mapping}, *refreshMetricsInterval)
164166
// Setup runner.
165-
datastore := datastore.NewDatastore(ctx, pmf)
167+
datastore := datastore.NewDatastore(ctx, pmf, scheduling.NewPrefixStore(scheduling.PrefixStoreConfig{
168+
MaxEntries: 1000,
169+
MinPrefixLen: 1,
170+
MaxPrefixLen: 100,
171+
EntryTTL: 5 * time.Minute,
172+
}))
166173

167174
serverRunner := &runserver.ExtProcServerRunner{
168175
GrpcPort: *grpcPort,

pkg/epp/datastore/datastore.go

+20-1
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ import (
3232
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
3333
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
3434
podutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/pod"
35+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling"
36+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
3537
)
3638

3739
const (
@@ -75,14 +77,15 @@ type Datastore interface {
7577
Clear()
7678
}
7779

78-
func NewDatastore(parentCtx context.Context, pmf *backendmetrics.PodMetricsFactory) Datastore {
80+
func NewDatastore(parentCtx context.Context, pmf *backendmetrics.PodMetricsFactory, prefixStore *scheduling.PrefixStore) Datastore {
7981
store := &datastore{
8082
parentCtx: parentCtx,
8183
poolAndModelsMu: sync.RWMutex{},
8284
models: make(map[string]*v1alpha2.InferenceModel),
8385
pods: &sync.Map{},
8486
sessions: &sync.Map{},
8587
pmf: pmf,
88+
prefixStore: prefixStore,
8689
}
8790

8891
go store.cleanupSessions(sessionKeepAliveCheckFrequency, sessionKeepAliveTime, parentCtx)
@@ -103,6 +106,8 @@ type datastore struct {
103106
// key: session id, value: *backendmetrics.Pod
104107
sessions *sync.Map
105108
pmf *backendmetrics.PodMetricsFactory
109+
// prefixStore is used to store and lookup model prefixes
110+
prefixStore *scheduling.PrefixStore
106111
}
107112

108113
func (ds *datastore) Clear() {
@@ -256,6 +261,19 @@ func (ds *datastore) PodUpdateOrAddIfNotExist(pod *corev1.Pod, pool *v1alpha2.In
256261
}
257262
// Update pod properties if anything changed.
258263
pm.UpdatePod(pod)
264+
265+
// Update prefix store with pod's active models
266+
if ds.prefixStore != nil {
267+
metrics := pm.GetMetrics()
268+
for model := range metrics.ActiveModels {
269+
// Add a prefix for each active model
270+
// The prefix is the model name itself
271+
ds.prefixStore.AddPrefix(ds.parentCtx, model, namespacedName, model)
272+
}
273+
} else {
274+
log.FromContext(ds.parentCtx).V(logging.DEBUG).Info("Prefix store is nil, skipping prefix updates", "pod", namespacedName)
275+
}
276+
259277
return ok
260278
}
261279

@@ -301,6 +319,7 @@ func (ds *datastore) PodDelete(namespacedName types.NamespacedName) {
301319
pmr := v.(backendmetrics.PodMetrics)
302320
pmr.StopRefreshLoop()
303321
}
322+
log.FromContext(ds.parentCtx).V(logging.DEBUG).Info("Pod removed or not added", "name", namespacedName)
304323
}
305324

306325
type sessionInfo struct {

pkg/epp/handlers/request.go

+5
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ func (s *StreamingServer) HandleRequestBody(
5151

5252
modelName := model
5353

54+
// Extract prompt from request body
55+
prompt, _ := requestBodyMap["prompt"].(string) // We don't require prompt to be present
56+
reqCtx.Prompt = prompt
57+
5458
// NOTE: The nil checking for the modelObject means that we DO allow passthrough currently.
5559
// This might be a security risk in the future where adapters not registered in the InferenceModel
5660
// are able to be requested by using their distinct name.
@@ -67,6 +71,7 @@ func (s *StreamingServer) HandleRequestBody(
6771

6872
llmReq := &schedulingtypes.LLMRequest{
6973
Model: model,
74+
Prompt: prompt,
7075
ResolvedTargetModel: modelName,
7176
Critical: modelObj.Spec.Criticality != nil && *modelObj.Spec.Criticality == v1alpha2.Critical,
7277
SessionID: reqCtx.SessionID,

pkg/epp/handlers/response.go

+27
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import (
2525
"sigs.k8s.io/controller-runtime/pkg/log"
2626
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/metrics"
2727
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
28+
"k8s.io/apimachinery/pkg/types"
2829
)
2930

3031
const (
@@ -62,6 +63,32 @@ func (s *StreamingServer) HandleResponseBody(
6263
// will add the processing for streaming case.
6364
reqCtx.ResponseComplete = true
6465

66+
// Add the prompt to the prefix store if we have a target pod and model
67+
if reqCtx.TargetPod != "" && reqCtx.ResolvedTargetModel != "" {
68+
// Get the prompt from the request context
69+
prompt := reqCtx.Prompt
70+
if prompt != "" {
71+
// Convert TargetPod string to NamespacedName
72+
parts := strings.Split(reqCtx.TargetPod, "/")
73+
if len(parts) != 2 {
74+
logger.Error(nil, "Invalid TargetPod format", "targetPod", reqCtx.TargetPod)
75+
return reqCtx, nil
76+
}
77+
podName := types.NamespacedName{
78+
Namespace: parts[0],
79+
Name: parts[1],
80+
}
81+
82+
// Add the prefix to the store
83+
err := s.scheduler.GetPrefixStore().AddPrefix(ctx, prompt, podName, reqCtx.ResolvedTargetModel)
84+
if err != nil {
85+
logger.Error(err, "Failed to add prefix to store", "prefix", prompt, "pod", reqCtx.TargetPod, "model", reqCtx.ResolvedTargetModel)
86+
} else {
87+
logger.Info("Added prefix to store", "prefix", prompt, "pod", reqCtx.TargetPod, "model", reqCtx.ResolvedTargetModel)
88+
}
89+
}
90+
}
91+
6592
reqCtx.respBodyResp = &extProcPb.ProcessingResponse{
6693
// The Endpoint Picker supports two approaches to communicating the target endpoint, as a request header
6794
// and as an unstructure ext-proc response metadata key/value pair. This enables different integration

pkg/epp/handlers/server.go

+3
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ import (
4040
schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
4141
errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error"
4242
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
43+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling"
4344
)
4445

4546
func NewStreamingServer(scheduler Scheduler, destinationEndpointHintMetadataNamespace, destinationEndpointHintKey string, datastore datastore.Datastore) *StreamingServer {
@@ -66,6 +67,7 @@ type StreamingServer struct {
6667

6768
type Scheduler interface {
6869
Schedule(ctx context.Context, b *schedulingtypes.LLMRequest) (targetPod schedulingtypes.Pod, err error)
70+
GetPrefixStore() *scheduling.PrefixStore
6971
}
7072

7173
// RequestContext stores context information during the life time of an HTTP request.
@@ -75,6 +77,7 @@ type RequestContext struct {
7577
Model string
7678
SessionID string
7779
ResolvedTargetModel string
80+
Prompt string
7881
RequestReceivedTimestamp time.Time
7982
ResponseCompleteTimestamp time.Time
8083
RequestSize int

pkg/epp/scheduling/prefix_aware_scorer.go

+11-1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package scheduling
1919
import (
2020
"sigs.k8s.io/controller-runtime/pkg/log"
2121
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
22+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
2223
)
2324

2425
// PrefixAwareScorer is a routing scorer that scores pods based on the longest prefix match
@@ -44,7 +45,9 @@ func (s *PrefixAwareScorer) ScoreTargets(ctx *types.Context, pods []*types.PodMe
4445

4546
// Get the prompt from the request
4647
prompt := ctx.Req.Prompt
48+
4749
if prompt == "" {
50+
logger.V(logging.DEBUG).Info("Empty prompt, returning zero scores for all pods")
4851
// If no prompt, return zero scores for all pods
4952
for i, pod := range pods {
5053
scoredPods[i] = PodScore{
@@ -58,6 +61,7 @@ func (s *PrefixAwareScorer) ScoreTargets(ctx *types.Context, pods []*types.PodMe
5861
// Find the best matching pod for the prompt
5962
matchedPod, found := s.prefixStore.FindPodForPrefix(ctx, prompt, ctx.Req.ResolvedTargetModel)
6063
if !found {
64+
logger.V(logging.DEBUG).Info("No matching prefix found, returning zero scores for all pods")
6165
// If no matching prefix found, return zero scores for all pods
6266
for i, pod := range pods {
6367
scoredPods[i] = PodScore{
@@ -71,12 +75,18 @@ func (s *PrefixAwareScorer) ScoreTargets(ctx *types.Context, pods []*types.PodMe
7175
// Assign scores based on pod match
7276
for i, pod := range pods {
7377
if pod.NamespacedName == matchedPod {
74-
logger.Info("Pod found for prefix", "prompt", prompt, "pod", pod.NamespacedName.String())
78+
logger.V(logging.DEBUG).Info("Pod matched for prefix",
79+
"prompt", prompt,
80+
"pod", pod.NamespacedName.String(),
81+
"score", s.weight)
7582
scoredPods[i] = PodScore{
7683
Score: s.weight, // Use the configured weight for the matching pod
7784
Pod: pod,
7885
}
7986
} else {
87+
logger.V(logging.DEBUG).Info("Pod did not match",
88+
"pod", pod.NamespacedName.String(),
89+
"score", 0)
8090
scoredPods[i] = PodScore{
8191
Score: 0, // Zero score for non-matching pods
8292
Pod: pod,

pkg/epp/scheduling/prefix_store.go

+19-7
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"k8s.io/apimachinery/pkg/types"
1010
"sigs.k8s.io/controller-runtime/pkg/log"
1111
errutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/error"
12+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
1213
)
1314

1415
// PrefixEntry represents a single entry in the prefix store
@@ -35,6 +36,8 @@ type PrefixStore struct {
3536

3637
// NewPrefixStore creates a new PrefixStore with the given configuration
3738
func NewPrefixStore(config PrefixStoreConfig) *PrefixStore {
39+
logger := log.FromContext(context.Background())
40+
logger.V(logging.DEBUG).Info("Creating new PrefixStore", "config", config)
3841
return &PrefixStore{
3942
tree: radix.New(),
4043
config: config,
@@ -56,13 +59,15 @@ func (ps *PrefixStore) AddPrefix(ctx context.Context, prefix string, pod types.N
5659
}
5760
}
5861
if len(prefix) > ps.config.MaxPrefixLen {
62+
logger.V(logging.DEBUG).Info("Truncating prefix", "originalLength", len(prefix), "maxLength", ps.config.MaxPrefixLen)
5963
prefix = prefix[:ps.config.MaxPrefixLen]
6064
}
6165

6266
// Check if we're updating an existing entry
6367
if val, exists := ps.tree.Get(prefix); exists {
6468
entry := val.(*PrefixEntry)
6569
if entry.PodRef == pod && entry.ModelName == modelName {
70+
logger.V(logging.DEBUG).Info("Updating existing entry", "prefix", prefix, "pod", pod.String())
6671
entry.LastUsed = time.Now()
6772
ps.tree.Insert(prefix, entry)
6873
return nil
@@ -71,6 +76,7 @@ func (ps *PrefixStore) AddPrefix(ctx context.Context, prefix string, pod types.N
7176

7277
// Check total entries limit
7378
if ps.tree.Len() >= ps.config.MaxEntries {
79+
logger.V(logging.DEBUG).Info("Store at capacity, evicting oldest entry", "currentSize", ps.tree.Len(), "maxSize", ps.config.MaxEntries)
7480
ps.evictOldest()
7581
}
7682

@@ -82,7 +88,7 @@ func (ps *PrefixStore) AddPrefix(ctx context.Context, prefix string, pod types.N
8288
}
8389
ps.tree.Insert(prefix, entry)
8490

85-
logger.Info("Added prefix entry", "prefix", prefix, "pod", pod.String(), "model", modelName)
91+
logger.V(logging.DEBUG).Info("Successfully added new prefix entry", "prefix", prefix, "pod", pod.String(), "model", modelName, "totalEntries", ps.tree.Len())
8692
return nil
8793
}
8894

@@ -94,32 +100,36 @@ func (ps *PrefixStore) FindPodForPrefix(ctx context.Context, prefix string, mode
94100
logger := log.FromContext(ctx)
95101

96102
if len(prefix) < ps.config.MinPrefixLen {
103+
logger.V(logging.DEBUG).Info("Prefix too short", "prefix", prefix, "minLength", ps.config.MinPrefixLen)
97104
return types.NamespacedName{}, false
98105
}
99106

100107
if len(prefix) > ps.config.MaxPrefixLen {
108+
logger.V(logging.DEBUG).Info("Truncating prefix", "originalLength", len(prefix), "maxLength", ps.config.MaxPrefixLen)
101109
prefix = prefix[:ps.config.MaxPrefixLen]
102110
}
103111

104112
// Use LongestPrefix to find the best match
105113
matchedPrefix, val, found := ps.tree.LongestPrefix(prefix)
106114
if !found {
115+
logger.V(logging.DEBUG).Info("No matching prefix found", "prefix", prefix)
107116
return types.NamespacedName{}, false
108117
}
109118

110119
entry := val.(*PrefixEntry)
111120

112121
// Check if entry has expired or model doesn't match
113-
if time.Since(entry.LastUsed) > ps.config.EntryTTL || entry.ModelName != modelName {
114-
// Don't remove here to avoid write lock
122+
if time.Since(entry.LastUsed) > ps.config.EntryTTL {
123+
return types.NamespacedName{}, false
124+
}
125+
if entry.ModelName != modelName {
115126
return types.NamespacedName{}, false
116127
}
117128

118129
// Update LastUsed time for the matched entry
119130
entry.LastUsed = time.Now()
120131
ps.tree.Insert(matchedPrefix, entry)
121132

122-
logger.Info("Found pod for prefix", "prefix", prefix, "matchedPrefix", matchedPrefix, "pod", entry.PodRef.String(), "model", modelName)
123133
return entry.PodRef, true
124134
}
125135

@@ -169,7 +179,9 @@ func (ps *PrefixStore) cleanupExpired(ctx context.Context) {
169179
}
170180

171181
if len(keysToDelete) > 0 {
172-
logger.Info("Cleaned up expired entries", "count", len(keysToDelete))
182+
logger.V(logging.DEBUG).Info("Cleaned up expired entries", "count", len(keysToDelete), "remainingEntries", ps.tree.Len())
183+
} else {
184+
logger.V(logging.DEBUG).Info("No expired entries found", "totalEntries", ps.tree.Len())
173185
}
174186
}
175187

@@ -179,14 +191,14 @@ func (ps *PrefixStore) RunMaintenance(ctx context.Context) {
179191
ticker := time.NewTicker(ps.config.EntryTTL / 2)
180192
defer ticker.Stop()
181193

194+
logger.V(logging.DEBUG).Info("Starting maintenance routine", "interval", ps.config.EntryTTL/2)
195+
182196
for {
183197
select {
184198
case <-ctx.Done():
185-
logger.Info("Maintenance routine stopping")
186199
return
187200
case <-ticker.C:
188201
ps.cleanupExpired(ctx)
189-
logger.Info("Completed maintenance cycle")
190202
}
191203
}
192204
}

pkg/epp/scheduling/scheduler.go

+6-1
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ type Scheduler struct {
137137

138138
func NewScheduler(datastore Datastore) *Scheduler {
139139
sMng := NewScorerMng()
140-
sMng.addScorer(NewSessionAffinityScorer(1, datastore))
140+
sMng.addScorer(NewSessionAffinityScorer(0, datastore))
141141

142142
// Initialize prefix store with configuration from environment variables
143143
prefixStore := NewPrefixStore(PrefixStoreConfig{
@@ -193,3 +193,8 @@ func (s *Scheduler) Schedule(ctx context.Context, req *types.LLMRequest) (target
193193

194194
return selectedPod, nil
195195
}
196+
197+
// GetPrefixStore returns the prefix store for this scheduler
198+
func (s *Scheduler) GetPrefixStore() *PrefixStore {
199+
return s.prefixStore
200+
}

pkg/epp/scheduling/scorer.go

+1
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ func (sm *ScorerMng) scoreTargets(ctx *types.Context, pods []*types.PodMetrics)
9292
if isFirst {
9393
maxScore = score
9494
highestScoreTargets = []*types.PodMetrics{pod}
95+
isFirst = false
9596
} else {
9697
if score > maxScore {
9798
maxScore = score

0 commit comments

Comments
 (0)