Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,25 @@ test-unit-%: download-tokenizer install-python-deps check-dependencies ## Run un
PYTHONPATH="$$KV_CACHE_PKG:$(VENV_DIR)/lib/python$(PYTHON_VERSION)/site-packages" \
CGO_CFLAGS=${$*_CGO_CFLAGS} CGO_LDFLAGS=${$*_CGO_LDFLAGS} go test $($*_LDFLAGS) -v $$($($*_TEST_FILES) | tr '\n' ' ')

# test-filter runs only the unit tests whose names match PATTERN (the value is
# passed to `go test -run`). TYPE selects the build flavor and test-file set:
# "epp" (default) or "sidecar"; both branches set the same PYTHONPATH so the
# kv-cache chat_completions preprocessing package and the venv site-packages
# are importable during the tests.
# NOTE(review): the recipe lines below must be tab-indented in the real
# Makefile — confirm tabs survived copy/paste.
.PHONY: test-filter
test-filter: download-tokenizer install-python-deps check-dependencies ## Run filtered unit tests (usage: make test-filter PATTERN=TestName TYPE=epp)
@if [ -z "$(PATTERN)" ]; then \
echo "ERROR: PATTERN is required. Usage: make test-filter PATTERN=TestName [TYPE=epp|sidecar]"; \
exit 1; \
fi
@TEST_TYPE="$(if $(TYPE),$(TYPE),epp)"; \
printf "\033[33;1m==== Running Filtered Tests (pattern: $(PATTERN), type: $$TEST_TYPE) ====\033[0m\n"; \
KV_CACHE_PKG=$$(go list -m -f '{{.Dir}}/pkg/preprocessing/chat_completions' github.com/llm-d/llm-d-kv-cache-manager 2>/dev/null || echo ""); \
if [ "$$TEST_TYPE" = "epp" ]; then \
PYTHONPATH="$$KV_CACHE_PKG:$(VENV_DIR)/lib/python$(PYTHON_VERSION)/site-packages" \
CGO_CFLAGS=$(epp_CGO_CFLAGS) CGO_LDFLAGS=$(epp_CGO_LDFLAGS) \
go test $(epp_LDFLAGS) -v -run "$(PATTERN)" $$($(epp_TEST_FILES) | tr '\n' ' '); \
else \
PYTHONPATH="$$KV_CACHE_PKG:$(VENV_DIR)/lib/python$(PYTHON_VERSION)/site-packages" \
CGO_CFLAGS=$(sidecar_CGO_CFLAGS) CGO_LDFLAGS=$(sidecar_CGO_LDFLAGS) \
go test $(sidecar_LDFLAGS) -v -run "$(PATTERN)" $$($(sidecar_TEST_FILES) | tr '\n' ' '); \
fi

.PHONY: test-integration
test-integration: download-tokenizer check-dependencies ## Run integration tests
@printf "\033[33;1m==== Running Integration Tests ====\033[0m\n"
Expand Down
71 changes: 44 additions & 27 deletions pkg/plugins/scorer/active_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"strings"
"sync"
"time"

Expand Down Expand Up @@ -38,13 +39,13 @@ type ActiveRequestParameters struct {

// requestEntry represents a single in-flight request tracked by the cache.
// PodNames holds the primary target pod chosen for each scheduling profile;
// RequestID is the unique identifier assigned to the request.
type requestEntry struct {
	PodNames  []string
	RequestID string
}

// String returns a log-friendly representation of the entry in the form
// "<request-id>:<pod1>.<pod2>...".
func (r requestEntry) String() string {
	return fmt.Sprintf("%s:%s", r.RequestID, strings.Join(r.PodNames, "."))
}

// compile-time type assertion
Expand Down Expand Up @@ -97,7 +98,9 @@ func NewActiveRequest(ctx context.Context, params *ActiveRequestParameters) *Act
requestCache.OnEviction(func(_ context.Context, reason ttlcache.EvictionReason,
item *ttlcache.Item[string, *requestEntry]) {
if reason == ttlcache.EvictionReasonExpired {
scorer.decrementPodCount(item.Value().PodName)
for _, podName := range item.Value().PodNames {
scorer.decrementPodCount(podName)
}
}
})

Expand Down Expand Up @@ -166,47 +169,61 @@ func (s *ActiveRequest) Score(ctx context.Context, _ *types.CycleState, _ *types
// PreRequest is called before a request is sent to the target pod.
// It creates a new request entry in the cache with its own TTL and
// increments the pod count for fast lookup.
func (s *ActiveRequest) PreRequest(ctx context.Context, request *types.LLMRequest,
schedulingResult *types.SchedulingResult) {
func (s *ActiveRequest) PreRequest(
ctx context.Context,
request *types.LLMRequest,
schedulingResult *types.SchedulingResult,
) {
debugLogger := log.FromContext(ctx).V(logutil.DEBUG)

for _, profileResult := range schedulingResult.ProfileResults { // schedulingResult guaranteed not to be nil
if profileResult == nil || profileResult.TargetPods == nil || len(profileResult.TargetPods) == 0 {
podNames := make([]string, 0, len(schedulingResult.ProfileResults))
for profileName, profileResult := range schedulingResult.ProfileResults {
if profileResult == nil || len(profileResult.TargetPods) == 0 {
continue
}

// create request entry for first pod only. TODO: support fallback pods
entry := &requestEntry{
PodName: profileResult.TargetPods[0].GetPod().NamespacedName.String(),
RequestID: request.RequestId,
}

// add to request cache with TTL
s.requestCache.Set(entry.String(), entry, 0) // Use default TTL
s.incrementPodCount(entry.PodName)

debugLogger.Info("Added request to cache", "requestEntry", entry.String())
podName := profileResult.TargetPods[0].GetPod().NamespacedName.String()
podNames = append(podNames, podName)
s.incrementPodCount(podName)
debugLogger.Info(
"Added request to cache",
"requestId", request.RequestId,
"podName", podName,
"profileName", profileName,
)
}

// add to request cache
s.requestCache.Set(request.RequestId, &requestEntry{PodNames: podNames, RequestID: request.RequestId}, 0) // Use default TTL
}

// ResponseComplete is called after a response is sent to the client.
// It removes the specific request entry from the cache and decrements
// the pod count.
func (s *ActiveRequest) ResponseComplete(ctx context.Context, request *types.LLMRequest,
_ *requestcontrol.Response, targetPod *backend.Pod) {
func (s *ActiveRequest) ResponseComplete(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vMaroon just to make sure, this is called once per request correct? This PR makes this assumption

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes 👍🏻

ctx context.Context,
request *types.LLMRequest,
_ *requestcontrol.Response,
targetPod *backend.Pod,
) {
debugLogger := log.FromContext(ctx).V(logutil.DEBUG).WithName("ActiveRequest.ResponseComplete")
if targetPod == nil {
debugLogger.Info("Skipping ResponseComplete because targetPod is nil")
return
}

entry := requestEntry{targetPod.NamespacedName.String(), request.RequestId}

if _, found := s.requestCache.GetAndDelete(entry.String()); found {
s.decrementPodCount(entry.PodName)
debugLogger.Info("Removed request from cache", "requestEntry", entry.String())
if item, found := s.requestCache.GetAndDelete(request.RequestId); found {
entry := item.Value()
if entry != nil {
for _, podName := range entry.PodNames {
s.decrementPodCount(podName)
}
debugLogger.Info("Removed request from cache", "requestEntry", entry.String())
} else {
debugLogger.Info("Request entry value is nil", "requestId", request.RequestId)
}
} else {
debugLogger.Info("Request not found in cache", "requestEntry", entry.String())
debugLogger.Info("Request not found in cache", "requestId", request.RequestId)
}
}

Expand Down
Loading