@@ -18,6 +18,7 @@ package slo_aware_router
1818
1919import (
2020 "context"
21+ "encoding/json"
2122 "errors"
2223 "fmt"
2324 "testing"
@@ -30,6 +31,7 @@ import (
3031 schedulingtypes "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
3132 requtil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/request"
3233 latencypredictor "sigs.k8s.io/gateway-api-inference-extension/sidecars/latencypredictorasync"
34+ "sigs.k8s.io/gateway-api-inference-extension/test/utils"
3335)
3436
3537// mockPredictor implements PredictorInterface for testing
@@ -535,3 +537,161 @@ func TestSLOAwareRouter_GetPrefixCacheScoreForPod(t *testing.T) {
535537 })
536538 }
537539}
540+
541+ func TestSLOAwareRouterFactory (t * testing.T ) {
542+ tests := []struct {
543+ name string
544+ pluginName string
545+ jsonParams string
546+ expectErr bool
547+ }{
548+ {
549+ name : "valid config with all fields" ,
550+ pluginName : "full-config" ,
551+ jsonParams : `{
552+ "samplingMean": 150.0,
553+ "maxSampledTokens": 30,
554+ "sloBufferFactor": 1.2,
555+ "negHeadroomTTFTWeight": 0.7,
556+ "negHeadroomTPOTWeight": 0.3,
557+ "headroomTTFTWeight": 0.9,
558+ "headroomTPOTWeight": 0.1,
559+ "headroomSelectionStrategy": "least",
560+ "compositeKVWeight": 1.0,
561+ "compositeQueueWeight": 0.8,
562+ "compositePrefixWeight": 0.5,
563+ "epsilonExploreSticky": 0.02,
564+ "epsilonExploreNeg": 0.03,
565+ "affinityGateTau": 0.85,
566+ "affinityGateTauGlobal": 0.95,
567+ "selectionMode": "linear"
568+ }` ,
569+ expectErr : false ,
570+ },
571+ {
572+ name : "valid config with minimal override (uses defaults)" ,
573+ pluginName : "minimal" ,
574+ jsonParams : `{}` ,
575+ expectErr : false ,
576+ },
577+ {
578+ name : "valid config with composite strategy" ,
579+ pluginName : "composite" ,
580+ jsonParams : `{
581+ "headroomSelectionStrategy": "composite-least",
582+ "selectionMode": "linear"
583+ }` ,
584+ expectErr : false ,
585+ },
586+ {
587+ name : "invalid samplingMean <= 0" ,
588+ pluginName : "bad-sampling-mean" ,
589+ jsonParams : `{"samplingMean": -1.0}` ,
590+ expectErr : true ,
591+ },
592+ {
593+ name : "invalid maxSampledTokens <= 0" ,
594+ pluginName : "bad-max-tokens" ,
595+ jsonParams : `{"maxSampledTokens": 0}` ,
596+ expectErr : true ,
597+ },
598+ {
599+ name : "invalid sloBufferFactor <= 0" ,
600+ pluginName : "bad-buffer" ,
601+ jsonParams : `{"sloBufferFactor": 0}` ,
602+ expectErr : true ,
603+ },
604+ {
605+ name : "negative headroom weight" ,
606+ pluginName : "neg-weight" ,
607+ jsonParams : `{"negHeadroomTTFTWeight": -0.1}` ,
608+ expectErr : true ,
609+ },
610+ {
611+ name : "epsilonExploreSticky > 1" ,
612+ pluginName : "epsilon-too-high" ,
613+ jsonParams : `{"epsilonExploreSticky": 1.1}` ,
614+ expectErr : true ,
615+ },
616+ {
617+ name : "epsilonExploreNeg < 0" ,
618+ pluginName : "epsilon-negative" ,
619+ jsonParams : `{"epsilonExploreNeg": -0.1}` ,
620+ expectErr : true ,
621+ },
622+ {
623+ name : "affinityGateTau out of (0,1]" ,
624+ pluginName : "tau-invalid" ,
625+ jsonParams : `{"affinityGateTau": 1.5}` ,
626+ expectErr : true ,
627+ },
628+ {
629+ name : "affinityGateTauGlobal <= 0" ,
630+ pluginName : "tau-global-zero" ,
631+ jsonParams : `{"affinityGateTauGlobal": 0}` ,
632+ expectErr : true ,
633+ },
634+ {
635+ name : "multiple validation errors" ,
636+ pluginName : "multi-error" ,
637+ jsonParams : `{
638+ "samplingMean": -1,
639+ "maxSampledTokens": 0,
640+ "epsilonExploreSticky": 2.0,
641+ "headroomSelectionStrategy": "unknown"
642+ }` ,
643+ expectErr : true ,
644+ },
645+ }
646+
647+ for _ , tt := range tests {
648+ t .Run (tt .name , func (t * testing.T ) {
649+ handle := utils .NewTestHandle (context .Background ())
650+ rawParams := json .RawMessage (tt .jsonParams )
651+ plugin , err := SLOAwareRouterFactory (tt .pluginName , rawParams , handle )
652+
653+ if tt .expectErr {
654+ assert .Error (t , err )
655+ assert .Nil (t , plugin )
656+ } else {
657+ assert .NoError (t , err )
658+ assert .NotNil (t , plugin )
659+ }
660+ })
661+ }
662+ }
663+
664+ func TestSLOAwareRouterFactoryInvalidJSON (t * testing.T ) {
665+ invalidTests := []struct {
666+ name string
667+ jsonParams string
668+ }{
669+ {
670+ name : "malformed JSON" ,
671+ jsonParams : `{"samplingMean": 100.0, "maxSampledTokens":` , // incomplete
672+ },
673+ {
674+ name : "samplingMean as string" ,
675+ jsonParams : `{"samplingMean": "100"}` ,
676+ },
677+ {
678+ name : "maxSampledTokens as float" ,
679+ jsonParams : `{"maxSampledTokens": 20.5}` ,
680+ },
681+ {
682+ name : "headroomSelectionStrategy as number" ,
683+ jsonParams : `{"headroomSelectionStrategy": 123}` ,
684+ },
685+ }
686+
687+ for _ , tt := range invalidTests {
688+ t .Run (tt .name , func (t * testing.T ) {
689+ handle := utils .NewTestHandle (context .Background ())
690+ rawParams := json .RawMessage (tt .jsonParams )
691+ plugin , err := SLOAwareRouterFactory ("test" , rawParams , handle )
692+
693+ assert .Error (t , err )
694+ assert .Nil (t , plugin )
695+ })
696+ }
697+ }
0 commit comments