Skip to content

Commit 779e85d

Browse files
authored
add config validation in predicted-latency-scorer plugin (#1904)
Signed-off-by: CYJiang <googs1025@gmail.com>
1 parent a88e4fb commit 779e85d

File tree

2 files changed

+209
-0
lines changed

2 files changed

+209
-0
lines changed

pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer.go

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package slo_aware_router
1919
import (
2020
"context"
2121
"encoding/json"
22+
"errors"
2223
"fmt"
2324
"math/rand"
2425
"sync"
@@ -93,6 +94,10 @@ func SLOAwareRouterFactory(name string, rawParameters json.RawMessage, handle pl
9394
}
9495
}
9596

97+
if err := parameters.validate(); err != nil {
98+
return nil, fmt.Errorf("invalid SLOAwareRouter config: %w", err)
99+
}
100+
96101
predictor, err := startPredictor(handle)
97102
if err != nil {
98103
return nil, fmt.Errorf("failed to start latency predictor: %w", err)
@@ -101,6 +106,50 @@ func SLOAwareRouterFactory(name string, rawParameters json.RawMessage, handle pl
101106
return NewSLOAwareRouter(parameters, predictor).WithName(name), nil
102107
}
103108

109+
func (c *Config) validate() error {
110+
var errs []error
111+
112+
if c.SamplingMean <= 0 {
113+
errs = append(errs, fmt.Errorf("samplingMean must be > 0, got %f", c.SamplingMean))
114+
}
115+
116+
if c.MaxSampledTokens <= 0 {
117+
errs = append(errs, fmt.Errorf("maxSampledTokens must be > 0, got %d", c.MaxSampledTokens))
118+
}
119+
120+
if c.SLOBufferFactor <= 0 {
121+
errs = append(errs, fmt.Errorf("sloBufferFactor must be > 0, got %f", c.SLOBufferFactor))
122+
}
123+
124+
if c.NegHeadroomTTFTWeight < 0 || c.NegHeadroomTPOTWeight < 0 ||
125+
c.HeadroomTTFTWeight < 0 || c.HeadroomTPOTWeight < 0 {
126+
errs = append(errs, errors.New("all headroom weights must be >= 0"))
127+
}
128+
129+
if c.CompositeKVWeight < 0 || c.CompositeQueueWeight < 0 || c.CompositePrefixWeight < 0 {
130+
errs = append(errs, errors.New("composite weights must be >= 0"))
131+
}
132+
133+
if c.EpsilonExploreSticky < 0 || c.EpsilonExploreSticky > 1 {
134+
errs = append(errs, fmt.Errorf("epsilonExploreSticky must be in [0, 1], got %f", c.EpsilonExploreSticky))
135+
}
136+
if c.EpsilonExploreNeg < 0 || c.EpsilonExploreNeg > 1 {
137+
errs = append(errs, fmt.Errorf("epsilonExploreNeg must be in [0, 1], got %f", c.EpsilonExploreNeg))
138+
}
139+
140+
if c.AffinityGateTau < 0 || c.AffinityGateTau > 1 {
141+
errs = append(errs, fmt.Errorf("affinityGateTau must be in [0, 1], got %f", c.AffinityGateTau))
142+
}
143+
if c.AffinityGateTauGlobal <= 0 || c.AffinityGateTauGlobal > 1 {
144+
errs = append(errs, fmt.Errorf("affinityGateTauGlobal must be in (0, 1], got %f", c.AffinityGateTauGlobal))
145+
}
146+
147+
if len(errs) > 0 {
148+
return errors.Join(errs...)
149+
}
150+
return nil
151+
}
152+
104153
func NewSLOAwareRouter(config Config, predictor latencypredictor.PredictorInterface) *SLOAwareRouter {
105154
strategy := headroomStrategy(config.HeadroomSelectionStrategy)
106155
if strategy == "" {

pkg/epp/scheduling/framework/plugins/multi/slo_aware_router/scorer_test.go

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package slo_aware_router
1818

1919
import (
2020
"context"
21+
"encoding/json"
2122
"errors"
2223
"fmt"
2324
"testing"
@@ -30,6 +31,7 @@ import (
3031
schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
3132
requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request"
3233
latencypredictor "sigs.k8s.io/gateway-api-inference-extension/sidecars/latencypredictorasync"
34+
"sigs.k8s.io/gateway-api-inference-extension/test/utils"
3335
)
3436

3537
// mockPredictor implements PredictorInterface for testing
@@ -535,3 +537,161 @@ func TestSLOAwareRouter_GetPrefixCacheScoreForPod(t *testing.T) {
535537
})
536538
}
537539
}
540+
541+
func TestSLOAwareRouterFactory(t *testing.T) {
542+
tests := []struct {
543+
name string
544+
pluginName string
545+
jsonParams string
546+
expectErr bool
547+
}{
548+
{
549+
name: "valid config with all fields",
550+
pluginName: "full-config",
551+
jsonParams: `{
552+
"samplingMean": 150.0,
553+
"maxSampledTokens": 30,
554+
"sloBufferFactor": 1.2,
555+
"negHeadroomTTFTWeight": 0.7,
556+
"negHeadroomTPOTWeight": 0.3,
557+
"headroomTTFTWeight": 0.9,
558+
"headroomTPOTWeight": 0.1,
559+
"headroomSelectionStrategy": "least",
560+
"compositeKVWeight": 1.0,
561+
"compositeQueueWeight": 0.8,
562+
"compositePrefixWeight": 0.5,
563+
"epsilonExploreSticky": 0.02,
564+
"epsilonExploreNeg": 0.03,
565+
"affinityGateTau": 0.85,
566+
"affinityGateTauGlobal": 0.95,
567+
"selectionMode": "linear"
568+
}`,
569+
expectErr: false,
570+
},
571+
{
572+
name: "valid config with minimal override (uses defaults)",
573+
pluginName: "minimal",
574+
jsonParams: `{}`,
575+
expectErr: false,
576+
},
577+
{
578+
name: "valid config with composite strategy",
579+
pluginName: "composite",
580+
jsonParams: `{
581+
"headroomSelectionStrategy": "composite-least",
582+
"selectionMode": "linear"
583+
}`,
584+
expectErr: false,
585+
},
586+
{
587+
name: "invalid samplingMean <= 0",
588+
pluginName: "bad-sampling-mean",
589+
jsonParams: `{"samplingMean": -1.0}`,
590+
expectErr: true,
591+
},
592+
{
593+
name: "invalid maxSampledTokens <= 0",
594+
pluginName: "bad-max-tokens",
595+
jsonParams: `{"maxSampledTokens": 0}`,
596+
expectErr: true,
597+
},
598+
{
599+
name: "invalid sloBufferFactor <= 0",
600+
pluginName: "bad-buffer",
601+
jsonParams: `{"sloBufferFactor": 0}`,
602+
expectErr: true,
603+
},
604+
{
605+
name: "negative headroom weight",
606+
pluginName: "neg-weight",
607+
jsonParams: `{"negHeadroomTTFTWeight": -0.1}`,
608+
expectErr: true,
609+
},
610+
{
611+
name: "epsilonExploreSticky > 1",
612+
pluginName: "epsilon-too-high",
613+
jsonParams: `{"epsilonExploreSticky": 1.1}`,
614+
expectErr: true,
615+
},
616+
{
617+
name: "epsilonExploreNeg < 0",
618+
pluginName: "epsilon-negative",
619+
jsonParams: `{"epsilonExploreNeg": -0.1}`,
620+
expectErr: true,
621+
},
622+
{
623+
name: "affinityGateTau out of (0,1]",
624+
pluginName: "tau-invalid",
625+
jsonParams: `{"affinityGateTau": 1.5}`,
626+
expectErr: true,
627+
},
628+
{
629+
name: "affinityGateTauGlobal <= 0",
630+
pluginName: "tau-global-zero",
631+
jsonParams: `{"affinityGateTauGlobal": 0}`,
632+
expectErr: true,
633+
},
634+
{
635+
name: "multiple validation errors",
636+
pluginName: "multi-error",
637+
jsonParams: `{
638+
"samplingMean": -1,
639+
"maxSampledTokens": 0,
640+
"epsilonExploreSticky": 2.0,
641+
"headroomSelectionStrategy": "unknown"
642+
}`,
643+
expectErr: true,
644+
},
645+
}
646+
647+
for _, tt := range tests {
648+
t.Run(tt.name, func(t *testing.T) {
649+
handle := utils.NewTestHandle(context.Background())
650+
rawParams := json.RawMessage(tt.jsonParams)
651+
plugin, err := SLOAwareRouterFactory(tt.pluginName, rawParams, handle)
652+
653+
if tt.expectErr {
654+
assert.Error(t, err)
655+
assert.Nil(t, plugin)
656+
} else {
657+
assert.NoError(t, err)
658+
assert.NotNil(t, plugin)
659+
}
660+
})
661+
}
662+
}
663+
664+
func TestSLOAwareRouterFactoryInvalidJSON(t *testing.T) {
665+
invalidTests := []struct {
666+
name string
667+
jsonParams string
668+
}{
669+
{
670+
name: "malformed JSON",
671+
jsonParams: `{"samplingMean": 100.0, "maxSampledTokens":`, // incomplete
672+
},
673+
{
674+
name: "samplingMean as string",
675+
jsonParams: `{"samplingMean": "100"}`,
676+
},
677+
{
678+
name: "maxSampledTokens as float",
679+
jsonParams: `{"maxSampledTokens": 20.5}`,
680+
},
681+
{
682+
name: "headroomSelectionStrategy as number",
683+
jsonParams: `{"headroomSelectionStrategy": 123}`,
684+
},
685+
}
686+
687+
for _, tt := range invalidTests {
688+
t.Run(tt.name, func(t *testing.T) {
689+
handle := utils.NewTestHandle(context.Background())
690+
rawParams := json.RawMessage(tt.jsonParams)
691+
plugin, err := SLOAwareRouterFactory("test", rawParams, handle)
692+
693+
assert.Error(t, err)
694+
assert.Nil(t, plugin)
695+
})
696+
}
697+
}

0 commit comments

Comments
 (0)