diff --git a/pkg/cache/tas_flavor_snapshot.go b/pkg/cache/tas_flavor_snapshot.go index dc07eef0b1..20f29eddb8 100644 --- a/pkg/cache/tas_flavor_snapshot.go +++ b/pkg/cache/tas_flavor_snapshot.go @@ -333,7 +333,9 @@ func (s *TASFlavorSnapshot) findTopologyAssignment( return nil, fmt.Sprintf("no requested topology level: %s", *key) } // phase 1 - determine the number of pods which can fit in each topology domain - s.fillInCounts(requests, assumedUsage, simulateEmpty, append(podSetTolerations, s.tolerations...)) + if err := s.fillInCounts(requests, assumedUsage, simulateEmpty, append(podSetTolerations, s.tolerations...)); err != nil { + return nil, fmt.Sprintf("unexpected pods fitting calculation failure: %v", err) + } // phase 2a: determine the level at which the assignment is done along with // the domains which can accommodate all pods @@ -508,7 +510,7 @@ func (s *TASFlavorSnapshot) sortedDomains(domains []*domain) []*domain { func (s *TASFlavorSnapshot) fillInCounts(requests resources.Requests, assumedUsage map[utiltas.TopologyDomainID]resources.Requests, simulateEmpty bool, - tolerations []corev1.Toleration) { + tolerations []corev1.Toleration) error { for _, domain := range s.domains { // cleanup the state in case some remaining values are present from computing // assignments for previous PodSets. @@ -529,11 +531,15 @@ func (s *TASFlavorSnapshot) fillInCounts(requests resources.Requests, if leafAssumedUsage, found := assumedUsage[leaf.domain.id]; found { remainingCapacity.Sub(leafAssumedUsage) } - leaf.state = requests.CountIn(remainingCapacity) + var err error + if leaf.state, err = requests.CountIn(remainingCapacity); err != nil { + return err + } } for _, root := range s.roots { root.state = s.fillInCountsHelper(root) } + return nil } func (s *TASFlavorSnapshot) fillInCountsHelper(domain *domain) int32 { diff --git a/pkg/resources/requests.go b/pkg/resources/requests.go index 204705c0a3..5aef56c884 100644 --- a/pkg/resources/requests.go +++ b/pkg/resources/requests.go @@ -17,6 +17,7 @@ limitations under the License. package resources import ( + "errors" "maps" "strings" @@ -25,6 +26,10 @@ import ( "k8s.io/utils/ptr" ) +var ( + errorRequestsHasTwoOrMorePodsCount = errors.New("requests have 2 or more Pods count") +) + // The following resources calculations are inspired on // https://github.com/kubernetes/kubernetes/blob/master/pkg/scheduler/framework/types.go @@ -103,12 +108,15 @@ func ResourceQuantityString(name corev1.ResourceName, v int64) string { return rq.String() } -func (req Requests) CountIn(capacity Requests) int32 { +func (r Requests) CountIn(capacity Requests) (int32, error) { + if count, ok := r[corev1.ResourcePods]; ok && count > 1 { + return 0, errorRequestsHasTwoOrMorePodsCount + } var result *int32 - for rName, rValue := range req { + for rName, rValue := range r { capacity, found := capacity[rName] if !found { - return 0 + return 0, nil } // find the minimum count matching all the resource quota. count := int32(capacity / rValue) @@ -116,5 +124,5 @@ func (req Requests) CountIn(capacity Requests) int32 { result = ptr.To(count) } } - return ptr.Deref(result, 0) + return ptr.Deref(result, 0), nil } diff --git a/pkg/resources/requests_test.go b/pkg/resources/requests_test.go index ad0f9c046f..1418d356a4 100644 --- a/pkg/resources/requests_test.go +++ b/pkg/resources/requests_test.go @@ -19,6 +19,8 @@ package resources import ( "testing" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" corev1 "k8s.io/api/core/v1" ) @@ -27,6 +29,7 @@ func TestCountIn(t *testing.T) { requests Requests capacity Requests wantResult int32 + wantError error }{ "requests equal capacity": { requests: Requests{ @@ -80,10 +83,25 @@ func TestCountIn(t *testing.T) { }, wantResult: 2, }, + "requests have 2 or more Pods count": { + requests: Requests{ + corev1.ResourceCPU: 2, + corev1.ResourcePods: 2, + }, + capacity: Requests{ + corev1.ResourceCPU: 5, + corev1.ResourcePods: 10, + }, + wantResult: 0, + wantError: errorRequestsHasTwoOrMorePodsCount, + }, } for name, tc := range cases { t.Run(name, func(t *testing.T) { - gotResult := tc.requests.CountIn(tc.capacity) + gotResult, err := tc.requests.CountIn(tc.capacity) + if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 { + t.Errorf("unexpected error (-want,+got\n):%s", diff) + } if tc.wantResult != gotResult { t.Errorf("unexpected result, want=%d, got=%d", tc.wantResult, gotResult) }