diff --git a/.dockerignore b/.dockerignore
new file mode 100644
index 000000000..73ae7ffbe
--- /dev/null
+++ b/.dockerignore
@@ -0,0 +1,21 @@
+# Git
+.git
+.gitignore
+
+# Build artifacts
+bin
+build
+
+# IDE and OS files
+.idea
+.vscode
+*.DS_Store
+
+# Local virtual environments
+venv
+
+# Python cache files
+__pycache__
+
+# Docker files
+Dockerfile
diff --git a/.github/workflows/auto-assign.yaml b/.github/workflows/auto-assign.yaml
index 9fa5e0447..4c9eca29f 100644
--- a/.github/workflows/auto-assign.yaml
+++ b/.github/workflows/auto-assign.yaml
@@ -2,9 +2,9 @@ name: Auto Assign Reviewers
on:
pull_request:
- types: [opened, reopened, synchronize]
+ types: [opened]
pull_request_target:
- types: [opened, reopened, synchronize]
+ types: [opened]
permissions:
contents: read
diff --git a/.github/workflows/check-typos.yaml b/.github/workflows/check-typos.yaml
index ea4fb6859..6bb82afe5 100644
--- a/.github/workflows/check-typos.yaml
+++ b/.github/workflows/check-typos.yaml
@@ -13,5 +13,5 @@ jobs:
uses: actions/checkout@v6
- name: Check typos
- uses: crate-ci/typos@v1.42.0
+ uses: crate-ci/typos@v1.43.0
diff --git a/.github/workflows/ci-pr-checks.yaml b/.github/workflows/ci-pr-checks.yaml
index dbf6e3dd9..ccbdf1e05 100644
--- a/.github/workflows/ci-pr-checks.yaml
+++ b/.github/workflows/ci-pr-checks.yaml
@@ -12,7 +12,7 @@ jobs:
check-changes:
runs-on: ubuntu-latest
outputs:
- docs: ${{ steps.filter.outputs.docs }}
+ src: ${{ steps.filter.outputs.src }}
steps:
- name: Checkout source
uses: actions/checkout@v6
@@ -20,14 +20,23 @@ jobs:
id: filter
with:
filters: |
- docs:
- - 'README.md'
- - 'docs/**'
+ src:
+ - '**/*.go'
+ - '**/*.py'
+ - Dockerfile.epp
+ - Dockerfile.sidecar
+ - Makefile*
+ - go.mod
lint-and-test:
needs: check-changes
- if: ${{ needs.check-changes.outputs.docs == 'false' }}
+ if: ${{ needs.check-changes.outputs.src == 'true' }}
runs-on: ubuntu-latest
steps:
+ - name: Free Disk Space (Ubuntu)
+ uses: jlumbroso/free-disk-space@main
+ with:
+ tool-cache: false
+
- name: Checkout source
uses: actions/checkout@v6
@@ -43,9 +52,6 @@ jobs:
go-version: "${{ env.GO_VERSION }}"
cache-dependency-path: ./go.sum
- - name: Install dependencies
- run: sudo make install-dependencies
-
- name: Configure CGO for Python
run: |
PYTHON_INCLUDE=$(python3 -c "import sysconfig; print(sysconfig.get_path('include'))")
@@ -57,14 +63,17 @@ jobs:
- name: Set PKG_CONFIG_PATH
run: echo "PKG_CONFIG_PATH=/usr/lib/pkgconfig" >> $GITHUB_ENV
- - name: go mod tidy
- run: go mod tidy
+ - name: Install dependencies
+ run: |
+ go mod tidy
+ sudo -E env "PATH=$PATH" make install-dependencies install-python-deps
- name: Run lint checks
uses: golangci/golangci-lint-action@v9
with:
- version: 'v2.1.6'
+ version: "v2.1.6"
args: "--config=./.golangci.yml"
+ skip-cache: true
env:
CGO_ENABLED: ${{ env.CGO_ENABLED }}
CGO_CFLAGS: ${{ env.CGO_CFLAGS }}
@@ -74,10 +83,8 @@ jobs:
- name: Run make build
shell: bash
- run: |
- make build
+ run: make build
- name: Run make test
shell: bash
- run: |
- make test
+ run: make test
diff --git a/.github/workflows/ci-release.yaml b/.github/workflows/ci-release.yaml
index 233debc30..6734738e6 100644
--- a/.github/workflows/ci-release.yaml
+++ b/.github/workflows/ci-release.yaml
@@ -11,6 +11,11 @@ jobs:
docker-build-and-push:
runs-on: ubuntu-latest
steps:
+ - name: Free Disk Space (Ubuntu)
+ uses: jlumbroso/free-disk-space@main
+ with:
+ tool-cache: false
+
- name: Checkout source
uses: actions/checkout@v6
diff --git a/.github/workflows/dispatch-on-lgtm.yml b/.github/workflows/dispatch-on-lgtm.yml
new file mode 100644
index 000000000..61e66a751
--- /dev/null
+++ b/.github/workflows/dispatch-on-lgtm.yml
@@ -0,0 +1,18 @@
+name: ChatOps Dispatcher
+on:
+ issue_comment:
+ types: [created]
+
+jobs:
+ dispatch:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Slash Command Dispatch
+ uses: peter-evans/slash-command-dispatch@v3
+ with:
+ token: ${{ secrets.GITHUB_TOKEN }}
+ commands: lgtm
+ # 1. Basic Security: Only allow users with 'write' access to even trigger this
+ permission: write
+ issue-type: pull-request
+ reactions: false
diff --git a/.github/workflows/lgtm-command.yml b/.github/workflows/lgtm-command.yml
new file mode 100644
index 000000000..5f7d93e2c
--- /dev/null
+++ b/.github/workflows/lgtm-command.yml
@@ -0,0 +1,134 @@
+# ============================================================================
+# LGTM Command Worker
+# ============================================================================
+# Handles /lgtm commands from chatops-dispatcher
+#
+# Flow:
+# 1. User comments /lgtm on PR
+# 2. chatops-dispatcher catches it and dispatches here
+# 3. This workflow:
+# - Verifies user is in OWNERS file
+# - Checks if PR is draft
+# - Checks for blocking labels (hold)
+# - Adds lgtm label
+# - Enables auto-merge
+# ============================================================================
+
+name: LGTM Command Worker
+on:
+ repository_dispatch:
+ types: [lgtm-command]
+
+env:
+ BLOCKING_LABELS: "hold"
+
+jobs:
+ apply-lgtm:
+ runs-on: ubuntu-latest
+ permissions:
+ contents: write
+ pull-requests: write
+ issues: write
+ steps:
+ - uses: actions/checkout@v4
+ - uses: tibdex/github-app-token@v1
+ id: generate-token
+ with:
+ app_id: ${{ secrets.VLLMD_BOT_APP_ID }}
+ private_key: ${{ secrets.VLLMD_BOT_APP_PRIVATE_KEY }}
+ repository: ${{ github.repository }}
+
+ # -----------------------------------------------------------------------
+ # STEP 1: AUTHORIZATION - Verify user is in OWNERS file
+ # -----------------------------------------------------------------------
+ # Only users listed as approvers in OWNERS can use /lgtm
+ # This prevents unauthorized users from approving PRs
+ - name: Check Permissions
+ env:
+ ACTOR: ${{ github.event.client_payload.github.actor }}
+ run: |
+ # Extract only the approvers section and check if ACTOR is listed
+ APPROVERS=$(sed -n '/^approvers:/,/^[^ -]/p' OWNERS | grep -v '^approvers:')
+ if echo "$APPROVERS" | grep -q "^\s*-\s*$ACTOR\s*$"; then
+ echo "User $ACTOR is authorized"
+ else
+ echo "::error:: User $ACTOR is not an approver."
+ exit 1
+ fi
+
+ # -----------------------------------------------------------------------
+ # STEP 2: VALIDATION - Check if PR is in draft mode
+ # -----------------------------------------------------------------------
+ # Draft PRs cannot be approved - they must be marked ready for review
+ # This prevents accidental approval of incomplete work
+ - name: Check Draft Status
+ env:
+ GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+ PR_NUMBER: ${{ github.event.client_payload.github.payload.issue.number }}
+ REPO: ${{ github.repository }}
+ run: |
+ IS_DRAFT=$(gh pr view $PR_NUMBER --repo "$REPO" --json isDraft --jq '.isDraft')
+ if [ "$IS_DRAFT" = "true" ]; then
+ echo "::error:: Cannot LGTM a Draft PR."
+            gh issue comment $PR_NUMBER --repo "$REPO" --body "⚠️ **LGTM Failed**: PR is a Draft."
+ exit 1
+ fi
+
+ # -----------------------------------------------------------------------
+ # STEP 3: BLOCKING LABELS - Check for hold
+ # -----------------------------------------------------------------------
+ # If any blocking label exists, fail immediately
+ # This prevents approving PRs that are explicitly marked as not ready
+ - name: Check for Blocking Labels
+ env:
+ GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+ PR_URL: ${{ github.event.client_payload.github.payload.issue.html_url }}
+ run: |
+ LABELS=$(gh pr view "$PR_URL" --json labels --jq '.labels[].name')
+ if echo "$LABELS" | grep -Eiq "^($BLOCKING_LABELS)$"; then
+ echo "::error:: PR is blocked by label."
+            gh issue comment "$PR_URL" --body "🛑 **Merge Blocked**: Please remove the \`hold\` label before merging."
+ exit 1
+ fi
+
+ # -----------------------------------------------------------------------
+ # STEP 4: APPLY LGTM - Add label, wait, then enable auto-merge
+ # -----------------------------------------------------------------------
+ # 1. Add lgtm label (triggers gatekeeper to validate)
+ # 2. Enable auto-merge (PR will merge when all checks pass)
+ - name: Apply or Cancel LGTM
+        env:
+          GH_TOKEN: ${{ steps.generate-token.outputs.token }}
+          PR_NUMBER: ${{ github.event.client_payload.github.payload.issue.number }}
+          PR_URL: ${{ github.event.client_payload.github.payload.issue.html_url }}
+          # Actor is needed for the retraction notification message below
+          ACTOR: ${{ github.event.client_payload.github.actor }}
+          # Extract the full comment body from the dispatcher payload
+          COMMENT_BODY: ${{ github.event.client_payload.github.payload.comment.body }}
+ run: |
+ # Check if the command is a cancellation
+ if echo "$COMMENT_BODY" | grep -q "/lgtm cancel"; then
+            echo "🚨 Retracting LGTM status..."
+
+ # 1. Remove lgtm label
+ gh issue edit "$PR_NUMBER" --remove-label "lgtm" || echo "Label already gone"
+
+ # 2. Disable Auto-Merge
+ gh pr merge --disable-auto "$PR_URL" || echo "Auto-merge was not enabled"
+
+ # 3. Notify user
+ gh issue comment "$PR_URL" --body "Retracted: **LGTM** label removed and auto-merge disabled by @$ACTOR."
+
+ else
+            echo "✅ Applying LGTM status..."
+
+ # 1. Add lgtm label
+ gh issue edit "$PR_NUMBER" --add-label "lgtm"
+
+ # 2. Enable auto-merge (Squash)
+ if ! gh pr merge --auto --squash "$PR_URL" 2>&1 | tee merge_output.txt; then
+ ERROR_MSG=$(cat merge_output.txt)
+              gh issue comment "$PR_URL" --body "⚠️ **Auto-merge failed**: $ERROR_MSG"
+ exit 1
+ fi
+
+            gh issue comment "$PR_URL" --body "✅ **LGTM**: Approval recorded and auto-merge enabled."
+ fi
\ No newline at end of file
diff --git a/.github/workflows/lgtm-gatekeeper.yml b/.github/workflows/lgtm-gatekeeper.yml
new file mode 100644
index 000000000..ed687c627
--- /dev/null
+++ b/.github/workflows/lgtm-gatekeeper.yml
@@ -0,0 +1,45 @@
+# ============================================================================
+# LGTM Gatekeeper - Required Status Check
+# ============================================================================
+# Rules Enforced:
+# 1. PR MUST have "lgtm" label
+# 2. PR MUST NOT have blocking labels (hold)
+# ============================================================================
+
+name: LGTM Gatekeeper
+on:
+ pull_request:
+ # Run on PR open/reopen and label changes
+ # NOT on synchronize (handled by lgtm-reset.yml)
+ types: [opened, labeled, unlabeled, reopened]
+
+env:
+ BLOCKING_LABELS: "hold"
+
+jobs:
+ validate-pr:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Enforce LGTM & Blockers
+ env:
+ GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+ PR_NUMBER: ${{ github.event.pull_request.number }}
+ REPO: ${{ github.repository }}
+ run: |
+ # Fetch current labels
+ LABELS=$(gh pr view $PR_NUMBER --repo "$REPO" --json labels --jq '.labels[].name')
+
+ # Check 1: IS IT BLOCKED?
+ if echo "$LABELS" | grep -Eiq "^($BLOCKING_LABELS)$"; then
+            echo "::error:: ❌ FAILED: PR is blocked by a 'hold' label."
+ exit 1
+ fi
+
+ # Check 2: IS IT APPROVED?
+ # If Reset workflow removed the label, this check fails immediately.
+ if ! echo "$LABELS" | grep -Fqx "lgtm"; then
+            echo "::error:: ❌ FAILED: PR is missing the 'lgtm' label."
+ exit 1
+ fi
+
+          echo "✅ PASSED: LGTM present and no blockers."
diff --git a/.github/workflows/lgtm-reset.yml b/.github/workflows/lgtm-reset.yml
new file mode 100644
index 000000000..6fc69dbbc
--- /dev/null
+++ b/.github/workflows/lgtm-reset.yml
@@ -0,0 +1,46 @@
+# ============================================================================
+# LGTM Reset - Auto-Remove LGTM on New Commits
+# ============================================================================
+# Kubernetes Prow behavior: When new commits are pushed, approval is invalidated
+#
+# What It Does:
+# 1. Detects when new commits are pushed to a PR
+# 2. Removes the "lgtm" label (if present)
+# 3. Disables auto-merge (safety net)
+# 4. Posts a comment explaining why
+# ============================================================================
+
+name: LGTM Reset
+on:
+ pull_request:
+ types: [synchronize] # Triggers instantly on new commits
+
+jobs:
+ reset-lgtm:
+ runs-on: ubuntu-latest
+ permissions:
+ pull-requests: write
+ steps:
+ - uses: tibdex/github-app-token@v1
+ id: generate-token
+ with:
+ app_id: ${{ secrets.VLLMD_BOT_APP_ID }}
+ private_key: ${{ secrets.VLLMD_BOT_APP_PRIVATE_KEY }}
+ repository: ${{ github.repository }}
+ - name: Invalidate LGTM
+ env:
+ GH_TOKEN: ${{ steps.generate-token.outputs.token }}
+ PR_NUMBER: ${{ github.event.pull_request.number }}
+ REPO: ${{ github.repository }}
+ run: |
+          echo "🚨 New code pushed. Resetting LGTM status..."
+
+ # 1. Remove the label (This triggers the Gatekeeper to run again)
+ gh pr edit $PR_NUMBER --repo "$REPO" --remove-label "lgtm" || true
+
+ # 2. Disable Auto-Merge (Safety net)
+ gh pr merge --disable-auto $PR_NUMBER --repo "$REPO" || true
+
+ # 3. Notify user
+          gh issue comment $PR_NUMBER --repo "$REPO" --body "🔄 **Reset**: New commits pushed. LGTM removed."
+
\ No newline at end of file
diff --git a/.github/workflows/prow-github.yml b/.github/workflows/prow-github.yml
index 0c5f11fd8..42a37cc70 100644
--- a/.github/workflows/prow-github.yml
+++ b/.github/workflows/prow-github.yml
@@ -18,7 +18,7 @@ jobs:
steps:
- uses: jpmcb/prow-github-actions@v2.0.0
with:
- github-token: "${{ secrets.GITHUB_TOKEN }}"
+ github-token: "${{ secrets.BOT_TOKEN }}"
prow-commands: "/assign
/unassign
/approve
@@ -27,7 +27,6 @@ jobs:
/kind
/priority
/remove
- /lgtm
/close
/reopen
/lock
diff --git a/.github/workflows/prow-pr-automerge.yml b/.github/workflows/prow-pr-automerge.yml
deleted file mode 100644
index c9bb0972b..000000000
--- a/.github/workflows/prow-pr-automerge.yml
+++ /dev/null
@@ -1,18 +0,0 @@
-# This Github workflow will check every 5m for PRs with the lgtm label and will attempt to automatically merge them.
-# If the hold label is present, it will block automatic merging.
-
-name: "Prow merge on lgtm label"
-on:
- schedule:
- - cron: "*/5 * * * *" # every 5 minutes
-
-jobs:
- auto-merge:
- runs-on: ubuntu-latest
- steps:
- - uses: jpmcb/prow-github-actions@v2.0.0
- with:
- jobs: 'lgtm'
- github-token: "${{ secrets.GITHUB_TOKEN }}"
- merge-method: 'squash'
-
diff --git a/.github/workflows/prow-pr-remove-lgtm.yml b/.github/workflows/prow-pr-remove-lgtm.yml
deleted file mode 100644
index caf208f36..000000000
--- a/.github/workflows/prow-pr-remove-lgtm.yml
+++ /dev/null
@@ -1,11 +0,0 @@
-name: Run Jobs on PR
-on: pull_request
-
-jobs:
- execute:
- runs-on: ubuntu-latest
- steps:
- - uses: jpmcb/prow-github-actions@v2.0.0
- with:
- jobs: lgtm
- github-token: '${{ secrets.GITHUB_TOKEN }}'
diff --git a/.gitignore b/.gitignore
index f94c6c6ce..7f5a0e944 100644
--- a/.gitignore
+++ b/.gitignore
@@ -11,6 +11,8 @@
main
bin/
+*debug_bin*
+
# Test binary, built with `go test -c`
*.test
@@ -27,6 +29,7 @@ go.work.sum
# Environment Files
.DS_Store
.env
+CLAUDE.md
# IDE files
.idea
diff --git a/.golangci.yml b/.golangci.yml
index bc928c25b..47e51d394 100644
--- a/.golangci.yml
+++ b/.golangci.yml
@@ -6,28 +6,60 @@ run:
formatters:
enable:
- - goimports
- - gofmt
+ - goimports
+ - gofmt
linters:
enable:
- - copyloopvar
- - dupword
- - durationcheck
- - fatcontext
- - ginkgolinter
- - gocritic
- - govet
- - loggercheck
- - misspell
- - perfsprint
- - revive
- - unconvert
- - makezero
- - errcheck
- - goconst
- - ineffassign
- - nakedret
- - prealloc
- - unparam
- - unused
+ - bodyclose
+ - copyloopvar
+ - dupword
+ - durationcheck
+ - errcheck
+ - fatcontext
+ - ginkgolinter
+ - goconst
+ - gocritic
+ - govet
+ - ineffassign
+ - loggercheck
+ - makezero
+ - misspell
+ - nakedret
+ - nilnil
+ - perfsprint
+ - prealloc
+ - revive
+ - staticcheck
+ - unparam
+ - unused
+ - unconvert
+ settings:
+ revive: # see https://github.com/mgechev/revive#available-rules for all options
+ rules:
+ - name: blank-imports
+ - name: context-as-argument
+ - name: context-keys-type
+ - name: dot-imports
+ - name: error-return
+ - name: error-strings
+ - name: error-naming
+ - name: exported
+ - name: if-return
+ - name: increment-decrement
+ - name: var-naming
+ - name: var-declaration
+ - name: package-comments
+ - name: range
+ - name: receiver-naming
+ - name: time-naming
+ - name: unexported-return
+ - name: indent-error-flow
+ - name: errorf
+
+issues:
+ # do not limit the number of issues, ensure even identical are reported
+ max-issues-per-linter: 0
+ max-same-issues: 0
+ # Note: 'new' setting is controlled via Makefile LINT_NEW_ONLY variable
+ # Set to true to only check new code, false (default) to check all code
diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md
index caef5f9d7..bf0921699 100644
--- a/DEVELOPMENT.md
+++ b/DEVELOPMENT.md
@@ -356,3 +356,20 @@ helm uninstall kgateway-crds -n kgateway-system
```
For more details, see the Gateway API inference Extension [getting started guide](https://gateway-api-inference-extension.sigs.k8s.io/guides/)
+
+## PR Approval Process
+
+The project uses a Prow-inspired ChatOps system to manage PR approvals via comment commands.
+
+### Available Commands
+
+| Command | Policy | Description |
+|---------|--------|-------------|
+| `/lgtm` | OWNERS approvers | Adds the `lgtm` label and enables auto-merge (squash). The PR merges automatically once all checks pass. |
+| `/lgtm cancel` | OWNERS approvers | Removes the `lgtm` label and disables auto-merge. |
+| `/hold` | Anyone with write access | Adds the `hold` label to prevent the PR from merging. |
+| `/hold cancel` | Anyone with write access | Removes the `hold` label. |
+
+### Approval Reset on New Commits
+
+When new commits are pushed to an approved PR, the `lgtm` label is automatically removed and auto-merge is disabled. This ensures approvals always reflect the latest code. The author must request a new `/lgtm` after pushing changes.
diff --git a/Dockerfile.epp b/Dockerfile.epp
index 915a34a0a..4996c6fbe 100644
--- a/Dockerfile.epp
+++ b/Dockerfile.epp
@@ -1,13 +1,52 @@
## Minimal runtime Dockerfile (microdnf-only, no torch, wrapper in site-packages)
-# Build Stage: using Go 1.24 image
-FROM quay.io/projectquay/golang:1.24 AS builder
+# Go dependencies stage: download go modules and extract kv-cache
+FROM quay.io/projectquay/golang:1.24 AS go-deps
+
+WORKDIR /workspace
+
+# Copy the Go Modules manifests
+COPY go.mod go.mod
+COPY go.sum go.sum
+
+# Copy the go source
+COPY cmd/ cmd/
+COPY pkg/ pkg/
+
+RUN go mod download
+
+# Copy Python wrapper and requirements from llm-d-kv-cache dependency
+# Extract version dynamically and copy to a known location
+RUN KV_CACHE_PKG=$(go list -m -f '{{.Dir}}' github.com/llm-d/llm-d-kv-cache) && \
+ mkdir -p /workspace/kv-cache && \
+ cp -r $KV_CACHE_PKG/* /workspace/kv-cache
+
+FROM python:3.12-slim AS python-builder
+
+ARG TARGETARCH
+
+COPY --from=go-deps /workspace/kv-cache /workspace/kv-cache
+WORKDIR /workspace/kv-cache
+
+# Create venv and install vLLM based on architecture using pre-built wheels
+RUN python3.12 -m venv /workspace/kv-cache/build/venv && \
+ . /workspace/kv-cache/build/venv/bin/activate && \
+ pip install --upgrade pip && \
+ VLLM_VERSION="0.14.0" && \
+ if [ "$TARGETARCH" = "arm64" ]; then \
+ pip install https://github.com/vllm-project/vllm/releases/download/v${VLLM_VERSION}/vllm-${VLLM_VERSION}+cpu-cp38-abi3-manylinux_2_35_aarch64.whl; \
+ elif [ "$TARGETARCH" = "amd64" ]; then \
+ pip install https://github.com/vllm-project/vllm/releases/download/v${VLLM_VERSION}/vllm-${VLLM_VERSION}+cpu-cp38-abi3-manylinux_2_35_x86_64.whl --extra-index-url https://download.pytorch.org/whl/cpu; \
+ else \
+ echo "ERROR: Unsupported architecture: $TARGETARCH. Only arm64 and amd64 are supported." && exit 1; \
+ fi
+
+# Go build stage
+FROM quay.io/projectquay/golang:1.24 AS go-builder
ARG TARGETOS
ARG TARGETARCH
ARG PYTHON_VERSION=3.12
-
ENV PYTHON=python${PYTHON_VERSION}
-ENV PYTHONPATH=/usr/lib64/${PYTHON}/site-packages:/usr/lib/${PYTHON}/site-packages
# Install build tools
# The builder is based on UBI8, so we need epel-release-8.
@@ -16,52 +55,22 @@ RUN dnf install -y 'https://dl.fedoraproject.org/pub/epel/epel-release-latest-8.
dnf install -y gcc-c++ libstdc++ libstdc++-devel clang zeromq-devel pkgconfig ${PYTHON}-devel ${PYTHON}-pip git && \
dnf clean all
+COPY --from=go-deps /workspace /workspace
+COPY --from=go-deps /go/pkg/mod /go/pkg/mod
WORKDIR /workspace
-# Copy the Go Modules manifests
-COPY go.mod go.mod
-COPY go.sum go.sum
+COPY Makefile* ./
-# Copy the go source
-COPY cmd/ cmd/
-COPY pkg/ pkg/
+COPY --from=python-builder /workspace/kv-cache/pkg/preprocessing/chat_completions /workspace/kv-cache/pkg/preprocessing/chat_completions
+RUN make setup-venv
+COPY --from=python-builder /workspace/kv-cache/build/venv/lib/python3.12/site-packages /workspace/build/venv/lib/python3.12/site-packages
-RUN go mod download
+ENV PYTHONPATH=/workspace/kv-cache/pkg/preprocessing/chat_completions:/workspace/build/venv/lib/python3.12/site-packages
+RUN python3.12 -c "import tokenizer_wrapper" # verify tokenizer_wrapper is correctly installed
-# Copy Python wrapper and requirements from llm-d-kv-cache-manager dependency
-# Extract version dynamically and copy to a known location
-# We need to keep llm-d-kv-cache-manager as go module path is kept the old name
-RUN KVCACHE_MANAGER_VERSION=$(go list -m -f '{{.Version}}' github.com/llm-d/llm-d-kv-cache-manager) && \
- mkdir -p /workspace/kv-cache && \
- cp /go/pkg/mod/github.com/llm-d/llm-d-kv-cache-manager@${KVCACHE_MANAGER_VERSION}/pkg/preprocessing/chat_completions/render_jinja_template_wrapper.py \
- /workspace/kv-cache/render_jinja_template_wrapper.py && \
- cp /go/pkg/mod/github.com/llm-d/llm-d-kv-cache-manager@${KVCACHE_MANAGER_VERSION}/pkg/preprocessing/chat_completions/requirements.txt \
- /workspace/kv-cache/requirements.txt
-
-# HuggingFace tokenizer bindings (static lib)
-RUN mkdir -p lib
-# Ensure that the RELEASE_VERSION matches the one used in the imported llm-d-kv-cache-manager version
ARG RELEASE_VERSION=v1.22.1
-RUN curl -L https://github.com/daulet/tokenizers/releases/download/${RELEASE_VERSION}/libtokenizers.${TARGETOS}-${TARGETARCH}.tar.gz | tar -xz -C lib
-RUN ranlib lib/*.a
-
-# Build
-# the GOARCH has not a default value to allow the binary be built according to the host where the command
-# was called. For example, if we call make image-build in a local env which has the Apple Silicon M1 SO
-# the docker BUILDPLATFORM arg will be linux/arm64 when for Apple x86 it will be linux/amd64. Therefore,
-# by leaving it empty we can ensure that the container and binary shipped on it will have the same platform.
-ENV CGO_ENABLED=1
-ENV GOOS=${TARGETOS:-linux}
-ENV GOARCH=${TARGETARCH}
-
-
-ARG COMMIT_SHA=unknown
-ARG BUILD_REF
-RUN CGO_CFLAGS="$(${PYTHON}-config --cflags) -I/workspace/lib" && \
- CGO_LDFLAGS="$(${PYTHON}-config --ldflags --embed) -L/workspace/lib -ltokenizers -ldl -lm" && \
- export CGO_CFLAGS CGO_LDFLAGS && \
- go build -a -o bin/epp -ldflags="-extldflags '-L$(pwd)/lib' -X sigs.k8s.io/gateway-api-inference-extension/version.CommitSHA=${COMMIT_SHA} -X sigs.k8s.io/gateway-api-inference-extension/version.BuildRef=${BUILD_REF}" cmd/epp/main.go
+RUN TOKENIZER_VERSION=${RELEASE_VERSION} make build-epp
# Runtime stage
# Use ubi9 as a minimal base image to package the manager binary
@@ -69,7 +78,7 @@ RUN CGO_CFLAGS="$(${PYTHON}-config --cflags) -I/workspace/lib" && \
FROM registry.access.redhat.com/ubi9/ubi-minimal:9.7
ARG PYTHON_VERSION=3.12
WORKDIR /
-COPY --from=builder /workspace/bin/epp /app/epp
+COPY --from=go-builder /workspace/bin/epp /app/epp
USER root
@@ -87,25 +96,16 @@ RUN curl -L -o /tmp/epel-release.rpm https://dl.fedoraproject.org/pub/epel/epel-
ln -sf /usr/bin/${PYTHON} /usr/bin/python3 && \
ln -sf /usr/bin/${PYTHON} /usr/bin/python
+# Copy Python kv-cache package and site-packages from the python-builder stage
+COPY --from=python-builder /workspace/kv-cache /workspace/kv-cache
+ENV PYTHONPATH=/workspace/kv-cache/pkg/preprocessing/chat_completions:/workspace/kv-cache/build/venv/lib/python3.12/site-packages
+RUN ${PYTHON} -c "import tokenizer_wrapper" # verify tokenizer_wrapper is correctly installed
-# Install wrapper as a module in site-packages
-RUN mkdir -p /usr/local/lib/${PYTHON}/site-packages/
-COPY --from=builder /workspace/kv-cache/render_jinja_template_wrapper.py /usr/local/lib/${PYTHON}/site-packages/
-
-# Python deps (no cache, single target) β filter out torch
-ENV PIP_NO_CACHE_DIR=1 PIP_DISABLE_PIP_VERSION_CHECK=1
-COPY --from=builder /workspace/kv-cache/requirements.txt /tmp/requirements.txt
-RUN sed '/^torch\b/d' /tmp/requirements.txt > /tmp/requirements.notorch.txt && \
- ${PYTHON} -m pip install --no-cache-dir --upgrade pip setuptools wheel && \
- ${PYTHON} -m pip install --no-cache-dir --target /usr/local/lib/${PYTHON}/site-packages -r /tmp/requirements.notorch.txt && \
- ${PYTHON} -m pip install --no-cache-dir --target /usr/local/lib/${PYTHON}/site-packages PyYAML && \
- rm /tmp/requirements.txt /tmp/requirements.notorch.txt && \
- rm -rf /root/.cache/pip
-
-# Python env
-ENV PYTHONPATH="/usr/local/lib/${PYTHON}/site-packages:/usr/lib/${PYTHON}/site-packages"
-ENV PATH=/usr/bin:/usr/local/bin:$PATH
ENV HF_HOME="/tmp/.cache"
+# used by kv-cache-manager
+ENV LOCAL_TOKENIZER_DIR="/tmp/.cache"
+# Create cache directory and set permissions for non-root user
+RUN mkdir -p /tmp/.cache && chown -R 65532:65532 ${HF_HOME}
USER 65532:65532
@@ -117,4 +117,3 @@ EXPOSE 9090
EXPOSE 5557
ENTRYPOINT ["/app/epp"]
-
diff --git a/Makefile b/Makefile
index edbd8ec56..89c60c651 100644
--- a/Makefile
+++ b/Makefile
@@ -25,10 +25,11 @@ export EPP_IMAGE ?= $(IMAGE_TAG_BASE):$(EPP_TAG)
SIDECAR_TAG ?= dev
SIDECAR_IMAGE_TAG_BASE ?= $(IMAGE_REGISTRY)/$(SIDECAR_IMAGE_NAME)
export SIDECAR_IMAGE ?= $(SIDECAR_IMAGE_TAG_BASE):$(SIDECAR_TAG)
-VLLM_SIMULATOR_TAG ?= v0.6.1
+VLLM_SIMULATOR_TAG ?= latest
VLLM_SIMULATOR_TAG_BASE ?= $(IMAGE_REGISTRY)/$(VLLM_SIMULATOR_IMAGE_NAME)
export VLLM_SIMULATOR_IMAGE ?= $(VLLM_SIMULATOR_TAG_BASE):$(VLLM_SIMULATOR_TAG)
NAMESPACE ?= hc4ai-operator
+LINT_NEW_ONLY ?= false # Set to true to only lint new code, false to lint all code (default matches CI behavior)
# Map go arch to platform-specific arch
ifeq ($(TARGETOS),darwin)
@@ -139,9 +140,15 @@ format: check-golangci-lint ## Format Go source files
$(GOLANGCI_LINT) fmt
.PHONY: lint
-lint: check-golangci-lint check-typos ## Run lint
+lint: check-golangci-lint check-typos ## Run lint (use LINT_NEW_ONLY=true to only check new code)
@printf "\033[33;1m==== Running linting ====\033[0m\n"
- CGO_CFLAGS="${CGO_CFLAGS}" $(GOLANGCI_LINT) run
+ @if [ "$(LINT_NEW_ONLY)" = "true" ]; then \
+ printf "\033[33mChecking new code only (LINT_NEW_ONLY=true)\033[0m\n"; \
+ CGO_CFLAGS="${CGO_CFLAGS}" $(GOLANGCI_LINT) run --new; \
+ else \
+ printf "\033[33mChecking all code (LINT_NEW_ONLY=false, default)\033[0m\n"; \
+ CGO_CFLAGS="${CGO_CFLAGS}" $(GOLANGCI_LINT) run; \
+ fi
$(TYPOS)
.PHONY: install-hooks
@@ -157,10 +164,29 @@ test-unit: test-unit-epp test-unit-sidecar ## Run unit tests
.PHONY: test-unit-%
test-unit-%: download-tokenizer install-python-deps check-dependencies ## Run unit tests
@printf "\033[33;1m==== Running Unit Tests ====\033[0m\n"
- @KV_CACHE_PKG=$$(go list -m -f '{{.Dir}}/pkg/preprocessing/chat_completions' github.com/llm-d/llm-d-kv-cache-manager 2>/dev/null || echo ""); \
+ @KV_CACHE_PKG=$$(go list -m -f '{{.Dir}}/pkg/preprocessing/chat_completions' github.com/llm-d/llm-d-kv-cache 2>/dev/null || echo ""); \
PYTHONPATH="$$KV_CACHE_PKG:$(VENV_DIR)/lib/python$(PYTHON_VERSION)/site-packages" \
CGO_CFLAGS=${$*_CGO_CFLAGS} CGO_LDFLAGS=${$*_CGO_LDFLAGS} go test $($*_LDFLAGS) -v $$($($*_TEST_FILES) | tr '\n' ' ')
+.PHONY: test-filter
+test-filter: download-tokenizer install-python-deps check-dependencies ## Run filtered unit tests (usage: make test-filter PATTERN=TestName TYPE=epp)
+ @if [ -z "$(PATTERN)" ]; then \
+ echo "ERROR: PATTERN is required. Usage: make test-filter PATTERN=TestName [TYPE=epp|sidecar]"; \
+ exit 1; \
+ fi
+ @TEST_TYPE="$(if $(TYPE),$(TYPE),epp)"; \
+ printf "\033[33;1m==== Running Filtered Tests (pattern: $(PATTERN), type: $$TEST_TYPE) ====\033[0m\n"; \
+ KV_CACHE_PKG=$$(go list -m -f '{{.Dir}}/pkg/preprocessing/chat_completions' github.com/llm-d/llm-d-kv-cache 2>/dev/null || echo ""); \
+ if [ "$$TEST_TYPE" = "epp" ]; then \
+ PYTHONPATH="$$KV_CACHE_PKG:$(VENV_DIR)/lib/python$(PYTHON_VERSION)/site-packages" \
+ CGO_CFLAGS=$(epp_CGO_CFLAGS) CGO_LDFLAGS=$(epp_CGO_LDFLAGS) \
+ go test $(epp_LDFLAGS) -v -run "$(PATTERN)" $$($(epp_TEST_FILES) | tr '\n' ' '); \
+ else \
+ PYTHONPATH="$$KV_CACHE_PKG:$(VENV_DIR)/lib/python$(PYTHON_VERSION)/site-packages" \
+ CGO_CFLAGS=$(sidecar_CGO_CFLAGS) CGO_LDFLAGS=$(sidecar_CGO_LDFLAGS) \
+ go test $(sidecar_LDFLAGS) -v -run "$(PATTERN)" $$($(sidecar_TEST_FILES) | tr '\n' ' '); \
+ fi
+
.PHONY: test-integration
test-integration: download-tokenizer check-dependencies ## Run integration tests
@printf "\033[33;1m==== Running Integration Tests ====\033[0m\n"
diff --git a/Makefile.tools.mk b/Makefile.tools.mk
index e750cd41b..7bf41b369 100644
--- a/Makefile.tools.mk
+++ b/Makefile.tools.mk
@@ -18,11 +18,12 @@ GINKGO_VERSION ?= v2.27.2
GOLANGCI_LINT_VERSION ?= v2.1.6
KUSTOMIZE_VERSION ?= v5.5.0
TYPOS_VERSION ?= v1.34.0
+VLLM_VERSION ?= 0.14.0
## Python Configuration
PYTHON_VERSION ?= 3.12
# Extract RELEASE_VERSION from Dockerfile
-TOKENIZER_VERSION := $(shell grep '^ARG RELEASE_VERSION=' Dockerfile.epp | cut -d'=' -f2)
+TOKENIZER_VERSION ?= $(shell grep '^ARG RELEASE_VERSION=' Dockerfile.epp | cut -d'=' -f2)
# Python executable for creating venv
PYTHON_EXE := $(shell command -v python$(PYTHON_VERSION) || command -v python3)
@@ -151,33 +152,79 @@ $(TOKENIZER_LIB): | $(LOCALLIB)
@ranlib $(LOCALLIB)/*.a
@echo "Tokenizer bindings downloaded successfully."
-
-.PHONY: install-python-deps
-install-python-deps: ## Sets up Python virtual environment and installs dependencies
- @printf "\033[33;1m==== Setting up Python virtual environment in $(VENV_DIR) ====\033[0m\n"
+.PHONY: detect-python
+detect-python: ## Detects Python and prints the configuration.
+ @printf "\033[33;1m==== Python Configuration ====\033[0m\n"
@if [ -z "$(PYTHON_EXE)" ]; then \
echo "ERROR: Python 3 not found in PATH."; \
exit 1; \
fi
+ @# Verify the version of the found python executable using its exit code
+ @if ! $(PYTHON_EXE) -c "import sys; sys.exit(0 if sys.version_info[:2] == ($(shell echo $(PYTHON_VERSION) | cut -d. -f1), $(shell echo $(PYTHON_VERSION) | cut -d. -f2)) else 1)"; then \
+ echo "ERROR: Found Python at '$(PYTHON_EXE)' but it is not version $(PYTHON_VERSION)."; \
+ echo "Please ensure 'python$(PYTHON_VERSION)' or a compatible 'python3' is in your PATH."; \
+ exit 1; \
+ fi
+ @echo "Python executable: $(PYTHON_EXE) ($$($(PYTHON_EXE) --version))"
+ @echo "Python CFLAGS: $(PYTHON_CFLAGS)"
+ @echo "Python LDFLAGS: $(PYTHON_LDFLAGS)"
+ @if [ -z "$(PYTHON_CFLAGS)" ]; then \
+ echo "ERROR: Python development headers not found. See installation instructions above."; \
+ exit 1; \
+ fi
+ @printf "\033[33;1m==============================\033[0m\n"
+
+.PHONY: setup-venv
+setup-venv: detect-python ## Sets up the Python virtual environment.
+ @printf "\033[33;1m==== Setting up Python virtual environment in $(VENV_DIR) ====\033[0m\n"
@if [ ! -f "$(VENV_BIN)/pip" ]; then \
echo "Creating virtual environment..."; \
$(PYTHON_EXE) -m venv $(VENV_DIR) || { \
echo "ERROR: Failed to create virtual environment."; \
echo "Your Python installation may be missing the 'venv' module."; \
+ echo "Try: 'sudo apt install python$(PYTHON_VERSION)-venv' or 'sudo dnf install python$(PYTHON_VERSION)-devel'"; \
exit 1; \
}; \
fi
- @echo "Upgrading pip and installing dependencies..."
- @$(VENV_BIN)/pip install --upgrade pip --quiet
- @KV_CACHE_PKG=$$(go list -m -f '{{.Dir}}' github.com/llm-d/llm-d-kv-cache-manager 2>/dev/null); \
- if [ -n "$$KV_CACHE_PKG" ] && [ -f "$$KV_CACHE_PKG/pkg/preprocessing/chat_completions/requirements.txt" ]; then \
- echo "Installing Python dependencies from kv-cache-manager..."; \
- $(VENV_BIN)/pip install --quiet -r "$$KV_CACHE_PKG/pkg/preprocessing/chat_completions/requirements.txt"; \
+ @echo "Upgrading pip..."
+ @$(VENV_BIN)/pip install --upgrade pip
+ @echo "Python virtual environment setup complete."
+
+.PHONY: install-python-deps
+install-python-deps: setup-venv ## installs dependencies.
+ @printf "\033[33;1m==== Setting up Python virtual environment in $(VENV_DIR) ====\033[0m\n"
+ @echo "install vllm..."
+ @KV_CACHE_PKG=$${KV_CACHE_PKG:-$$(go list -m -f '{{.Dir}}' github.com/llm-d/llm-d-kv-cache 2>/dev/null)}; \
+ if [ -z "$$KV_CACHE_PKG" ]; then \
+ echo "ERROR: kv-cache package not found."; \
+ exit 1; \
+ fi; \
+ if [ "$(TARGETOS)" = "darwin" ]; then \
+ if [ -f "$$KV_CACHE_PKG/pkg/preprocessing/chat_completions/setup.sh" ]; then \
+ echo "Running kv-cache setup script for macOS..."; \
+ cp "$$KV_CACHE_PKG/pkg/preprocessing/chat_completions/setup.sh" build/kv-cache-setup.sh; \
+ chmod +wx build/kv-cache-setup.sh; \
+ cd build && PATH=$(VENV_BIN):$$PATH ./kv-cache-setup.sh && cd ..; \
+ else \
+ echo "ERROR: setup script not found at $$KV_CACHE_PKG/pkg/preprocessing/chat_completions/setup.sh"; \
+ exit 1; \
+ fi; \
else \
- echo "WARNING: Could not find kv-cache-manager requirements.txt, installing minimal deps..."; \
- $(VENV_BIN)/pip install --quiet 'transformers>=4.53.0' 'jinja2>=2.11'; \
+ echo "Installing vLLM for Linux $(TARGETARCH)..."; \
+ if [ "$(TARGETARCH)" = "arm64" ]; then \
+ $(VENV_BIN)/pip install https://github.com/vllm-project/vllm/releases/download/v$(VLLM_VERSION)/vllm-$(VLLM_VERSION)+cpu-cp38-abi3-manylinux_2_35_aarch64.whl; \
+ elif [ "$(TARGETARCH)" = "amd64" ]; then \
+ $(VENV_BIN)/pip install https://github.com/vllm-project/vllm/releases/download/v$(VLLM_VERSION)/vllm-$(VLLM_VERSION)+cpu-cp38-abi3-manylinux_2_35_x86_64.whl --extra-index-url https://download.pytorch.org/whl/cpu; \
+ else \
+ echo "ERROR: Unsupported architecture: $(TARGETARCH). Only arm64 and amd64 are supported."; \
+ exit 1; \
+ fi; \
fi
- @echo "✅ Python dependencies installed in venv"
+ @echo "Verifying vllm installation..."
+ @$(VENV_BIN)/python -c "import vllm; print('✅ vllm version ' + vllm.__version__ + ' installed.')" || { \
+ echo "ERROR: vllm library not properly installed in venv."; \
+ exit 1; \
+ }
.PHONY: check-tools
check-tools: check-go check-ginkgo check-golangci-lint check-kustomize check-envsubst check-container-tool check-kubectl check-buildah check-typos ## Check that all required tools are installed
diff --git a/OWNERS b/OWNERS
index 84ec0bd47..4262fe677 100644
--- a/OWNERS
+++ b/OWNERS
@@ -1,4 +1,5 @@
approvers:
+- revit13
- elevran
- kfswain
- nilig
@@ -23,4 +24,4 @@ auto-assign:
- nilig
- nirrozenbaum
- shmuelk
- - vMaroon
\ No newline at end of file
+ - vMaroon
diff --git a/README.md b/README.md
index 3be2337aa..c1e291a38 100644
--- a/README.md
+++ b/README.md
@@ -5,7 +5,7 @@
# Inference Scheduler
-This scheduler makes optimized routing decisions for inference requests to
+This scheduler makes optimized routing decisions for inference requests to
the llm-d inference framework.
## About
diff --git a/cmd/pd-sidecar/main.go b/cmd/pd-sidecar/main.go
index cd086eb7d..8e9e6533c 100644
--- a/cmd/pd-sidecar/main.go
+++ b/cmd/pd-sidecar/main.go
@@ -35,7 +35,7 @@ var (
// supportedConnectors defines all valid P/D connector types
supportedConnectors = []string{
proxy.ConnectorNIXLV2,
- proxy.ConnectorLMCache,
+ proxy.ConnectorSharedStorage,
proxy.ConnectorSGLang,
}
)
diff --git a/deploy/components/crds-gie/kustomization.yaml b/deploy/components/crds-gie/kustomization.yaml
index f2e0a4629..e9897ce37 100644
--- a/deploy/components/crds-gie/kustomization.yaml
+++ b/deploy/components/crds-gie/kustomization.yaml
@@ -10,4 +10,4 @@ apiVersion: kustomize.config.k8s.io/v1beta1
kind: Kustomization
resources:
-- https://github.com/kubernetes-sigs/gateway-api-inference-extension/config/crd?ref=v1.2.1
+- https://github.com/kubernetes-sigs/gateway-api-inference-extension/config/crd?ref=v1.3.0
diff --git a/deploy/config/dp-epp-config.yaml b/deploy/config/dp-epp-config.yaml
index 703a44f67..6e8418866 100644
--- a/deploy/config/dp-epp-config.yaml
+++ b/deploy/config/dp-epp-config.yaml
@@ -3,7 +3,13 @@
apiVersion: inference.networking.x-k8s.io/v1alpha1
kind: EndpointPickerConfig
plugins:
-- type: prefix-cache-scorer
+- type: precise-prefix-cache-scorer
+ parameters:
+ indexerConfig:
+ tokenProcessorConfig:
+ blockSize: 5
+ kvBlockIndexConfig:
+ maxPrefixBlocksToMatch: 256
- type: decode-filter
- type: max-score-picker
- type: data-parallel-profile-handler
@@ -14,5 +20,5 @@ schedulingProfiles:
plugins:
- pluginRef: decode-filter
- pluginRef: max-score-picker
- - pluginRef: prefix-cache-scorer
+ - pluginRef: precise-prefix-cache-scorer
weight: 2
diff --git a/deploy/config/epp-precise-prefix-cache-config.yaml b/deploy/config/epp-precise-prefix-cache-config.yaml
index 157505278..39b0bb285 100644
--- a/deploy/config/epp-precise-prefix-cache-config.yaml
+++ b/deploy/config/epp-precise-prefix-cache-config.yaml
@@ -7,10 +7,10 @@ plugins:
- type: decode-filter
- type: precise-prefix-cache-scorer
parameters:
+ tokenProcessorConfig:
+ blockSize: 64 # must match vLLM block size
+ hashSeed: "42" # must match vLLM PYTHONHASHSEED env var
indexerConfig:
- tokenProcessorConfig:
- blockSize: 64 # must match vLLM block size
- hashSeed: "42" # must match vLLM PYTHONHASHSEED env var
kvBlockIndexConfig:
enableMetrics: true # enable kv-block index metrics (prometheus)
- type: kv-cache-utilization-scorer
diff --git a/deploy/config/pd-epp-config.yaml b/deploy/config/pd-epp-config.yaml
index 9732be2f3..35a94a3be 100644
--- a/deploy/config/pd-epp-config.yaml
+++ b/deploy/config/pd-epp-config.yaml
@@ -1,13 +1,25 @@
# Sample EPP configuration for tunning with P/D
apiVersion: inference.networking.x-k8s.io/v1alpha1
kind: EndpointPickerConfig
+featureGates:
+- prepareDataPlugins
plugins:
- type: prefill-header-handler
- type: prefix-cache-scorer
+ parameters:
+ maxPrefixBlocksToMatch: 256
+ lruCapacityPerServer: 31250
+- type: queue-scorer
- type: prefill-filter
- type: decode-filter
- type: max-score-picker
+- type: prefix-based-pd-decider
+ parameters:
+ nonCachedTokens: 16
- type: pd-profile-handler
+ parameters:
+ primaryPort: ${PRIMARY_PORT}
+ deciderPluginName: prefix-based-pd-decider
schedulingProfiles:
- name: prefill
plugins:
@@ -15,9 +27,13 @@ schedulingProfiles:
- pluginRef: max-score-picker
- pluginRef: prefix-cache-scorer
weight: 2
+ - pluginRef: queue-scorer
+ weight: 1
- name: decode
plugins:
- pluginRef: decode-filter
- pluginRef: max-score-picker
- pluginRef: prefix-cache-scorer
weight: 2
+ - pluginRef: queue-scorer
+ weight: 1
diff --git a/deploy/config/sim-epp-kvcache-config.yaml b/deploy/config/sim-epp-kvcache-config.yaml
index 566e92437..f582c543f 100644
--- a/deploy/config/sim-epp-kvcache-config.yaml
+++ b/deploy/config/sim-epp-kvcache-config.yaml
@@ -3,18 +3,16 @@
apiVersion: inference.networking.x-k8s.io/v1alpha1
kind: EndpointPickerConfig
plugins:
-- type: prefix-cache-scorer
+- type: precise-prefix-cache-scorer
parameters:
- mode: cache_tracking
kvEventsConfig:
zmqEndpoint: tcp://0.0.0.0:5557
indexerConfig:
- prefixStoreConfig:
- blockSize: 16
tokenProcessorConfig:
blockSize: 16 # must match vLLM block size if not default (16)
hashSeed: "42" # must match PYTHONHASHSEED in vLLM pods
tokenizersPoolConfig:
+ modelName: TinyLlama/TinyLlama-1.1B-Chat-v1.0 # replace value to use different model for tokenizer loading
hf:
tokenizersCacheDir: "/cache/tokenizers"
kvBlockIndexConfig:
@@ -28,5 +26,5 @@ schedulingProfiles:
plugins:
- pluginRef: decode-filter
- pluginRef: max-score-picker
- - pluginRef: prefix-cache-scorer
+ - pluginRef: precise-prefix-cache-scorer
weight: 10
diff --git a/deploy/config/sim-epp-no-hit-lru.yaml b/deploy/config/sim-epp-no-hit-lru.yaml
index 8d0224411..e10ec5062 100644
--- a/deploy/config/sim-epp-no-hit-lru.yaml
+++ b/deploy/config/sim-epp-no-hit-lru.yaml
@@ -3,11 +3,13 @@
apiVersion: inference.networking.x-k8s.io/v1alpha1
kind: EndpointPickerConfig
plugins:
-- type: prefix-cache-scorer
+- type: precise-prefix-cache-scorer
parameters:
- hashBlockSize: 5
- maxPrefixBlocksToMatch: 256
- lruCapacityPerServer: 31250
+ indexerConfig:
+ tokenProcessorConfig:
+ blockSize: 5
+ kvBlockIndexConfig:
+ maxPrefixBlocksToMatch: 256
- type: no-hit-lru-scorer
parameters:
lruSize: 2048
@@ -19,7 +21,7 @@ schedulingProfiles:
plugins:
- pluginRef: decode-filter
- pluginRef: max-score-picker
- - pluginRef: prefix-cache-scorer
+ - pluginRef: precise-prefix-cache-scorer
weight: 2
- pluginRef: no-hit-lru-scorer
weight: 1
diff --git a/deploy/config/sim-pd-epp-config.yaml b/deploy/config/sim-pd-epp-config.yaml
index 2d6a85dd9..2f93504a1 100644
--- a/deploy/config/sim-pd-epp-config.yaml
+++ b/deploy/config/sim-pd-epp-config.yaml
@@ -2,21 +2,27 @@
# Use with small hash block size for simulation purposes
apiVersion: inference.networking.x-k8s.io/v1alpha1
kind: EndpointPickerConfig
+featureGates:
+- prepareDataPlugins
plugins:
- type: prefill-header-handler
- type: prefix-cache-scorer
parameters:
- hashBlockSize: 5
+ blockSizeTokens: 16
+ autoTune: false
maxPrefixBlocksToMatch: 256
lruCapacityPerServer: 31250
+- type: queue-scorer
- type: prefill-filter
- type: decode-filter
- type: max-score-picker
+- type: prefix-based-pd-decider
+ parameters:
+ nonCachedTokens: 16
- type: pd-profile-handler
parameters:
- threshold: 10
- hashBlockSize: 5
primaryPort: ${PRIMARY_PORT}
+ deciderPluginName: prefix-based-pd-decider
schedulingProfiles:
- name: prefill
plugins:
@@ -24,9 +30,13 @@ schedulingProfiles:
- pluginRef: max-score-picker
- pluginRef: prefix-cache-scorer
weight: 2
+ - pluginRef: queue-scorer
+ weight: 1
- name: decode
plugins:
- pluginRef: decode-filter
- pluginRef: max-score-picker
- pluginRef: prefix-cache-scorer
weight: 2
+ - pluginRef: queue-scorer
+ weight: 1
diff --git a/docs/architecture.md b/docs/architecture.md
index b3215815c..c6ab17198 100644
--- a/docs/architecture.md
+++ b/docs/architecture.md
@@ -161,11 +161,13 @@ A complete configuration might look like this:
apiVersion: inference.networking.x-k8s.io/v1alpha1
kind: EndpointPickerConfig
plugins:
-- type: prefix-cache-scorer
+- type: precise-prefix-cache-scorer
parameters:
- hashBlockSize: 5
- maxPrefixBlocksToMatch: 256
- lruCapacityPerServer: 31250
+ indexerConfig:
+ tokenProcessorConfig:
+ blockSize: 5
+ kvBlockIndexConfig:
+ maxPrefixBlocksToMatch: 256
- type: decode-filter
- type: max-score-picker
- type: single-profile-handler
@@ -174,7 +176,7 @@ schedulingProfiles:
plugins:
- pluginRef: decode-filter
- pluginRef: max-score-picker
- - pluginRef: prefix-cache-scorer
+ - pluginRef: precise-prefix-cache-scorer
weight: 50
```
@@ -204,15 +206,41 @@ Selects the profiles to use when running with disaggregated prefill/decode
- **Type**: `pd-profile-handler`
- **Parameters**:
- - `threshold`: specifies the threshold at which there are enough new input tokens to send the request to prefill and then decode, vs just to decode.
- - `hashBlockSize`: specifies the length of the prompt chunk that a block is keyed by. This must the same value used for the PrefixCachePlugin.
- `decodeProfile`: specifies the name of the profile used for the decode scheduling. Only needed if the decode profile is not named `decode`.
- `prefillProfile`: specifies the name of the profile used for the prefill scheduling. Only needed if the prefill profile is not named `prefill`.
+ - `deciderPluginName`: specifies the name of the decider plugin. Decider determines whether disaggregated PD should be executed
+ - `primaryPort`: the base port number used for data parallel communication.
**Note:** When using this plugin you must also have a PrefixCachePlugin configured in the prefill and decode scheduling profiles.
---
+#### Prefix Based Decider Plugin
+
+Type: `prefix-based-pd-decider`
+
+**Parameters**
+- `nonCachedTokens`: length, in tokens, of the uncached part of the user input above which disaggregated PD is triggered.
+
+Note: `prepareDataPlugins` feature gate should be enabled
+
+**Example**
+```yaml
+kind: EndpointPickerConfig
+featureGates:
+- prepareDataPlugins
+plugins:
+- type: prefix-based-pd-decider
+ parameters:
+ nonCachedTokens: 4
+- type: pd-profile-handler
+ parameters:
+ primaryPort: 8000
+ deciderPluginName: prefix-based-pd-decider
+```
+
+---
+
#### ByLabelSelector
Filters out pods using a standard Kubernetes label selector.
@@ -308,12 +336,14 @@ Configuration:
- **Type**: `precise-prefix-cache-scorer`
- **Parameters**:
+ - `tokenProcessorConfig`: Configuration for the `kvblock.TokenProcessor`.
- `indexerConfig`: Configuration for the `kvcache.Indexer`.
- `kvEventsConfig`: Configuration for the `kvevents.Pool`.
See list of parameters at [llm-d-kv-cache/docs/configuration.md](https://github.com/llm-d/llm-d-kv-cache/blob/fa85b60207ba0a09daf23071e10ccb62d7977b40/docs/configuration.md).
Note that in most cases you will only need to set:
+- Model name in the `tokenizersPoolConfig` to match the model used in the vLLM deployment.
- HuggingFace token for the `tokenizersPoolConfig` or the `tokenizersCacheDir` to a mounted directory containing the tokenizers.
- For the HuggingFace token, the inference-scheduler also accepts the environment variable `HF_TOKEN` - this is the practical option for security.
- **IMPORTANT**: Token processor's block-size and hash-seed to match those used in the vLLM deployment.
@@ -325,15 +355,41 @@ Example configuration with the above parameters set:
plugins:
- type: precise-prefix-cache-scorer
parameters:
+ tokenProcessorConfig:
+ blockSize: 64 # must match vLLM block size
+ hashSeed: "12345" # must match vLLM PYTHONHASHSEED env var
indexerConfig:
- tokenProcessorConfig:
- blockSize: 64
- hashSeed: "12345"
- tokenizersPoolConfig:
- hf:
- huggingFaceToken: your_hf_token_here # automatically set by `HF_TOKEN` environment variable
- kvBlockIndexConfig:
- enableMetrics: true
+ kvBlockIndexConfig:
+ enableMetrics: true
+ tokenizersPoolConfig:
+ modelName: hf-repo/model-name
+ hf:
+ huggingFaceToken: your_hf_token_here # automatically set by `HF_TOKEN` environment variable
+```
+
+Example configuration for automatic pod discovery in active-active multi-replica scheduler deployments:
+```yaml
+ - type: precise-prefix-cache-scorer
+ parameters:
+ tokenProcessorConfig:
+ blockSize: 64
+ hashSeed: "42"
+ indexerConfig:
+ tokenizersPoolConfig:
+ modelName: "Qwen/Qwen3-32B"
+ hf:
+ tokenizersCacheDir: "/tmp/tokenizers"
+ kvEventsConfig:
+ topicFilter: "kv@"
+ concurrency: 4
+ discoverPods: true # enables automatic pod discovery for active-active HA
+ podDiscoveryConfig:
+ socketPort: 5556
+```
+
+Where the vLLM engines are configured to emit KV-Events on port `5556` as follows:
+```yaml
+ --kv-events-config "{\"enable_kv_cache_events\":true,\"publisher\":\"zmq\",\"endpoint\":\"tcp://*:5556\",\"topic\":\"kv@${POD_IP}@Qwen/Qwen3-32B\"}"
```
Example configuration with all parameters set:
@@ -342,23 +398,26 @@ Example configuration with all parameters set:
plugins:
- type: precise-prefix-cache-scorer
parameters:
+ tokenProcessorConfig:
+ blockSize: 16
+ hashSeed: "12345"
kvEventsConfig:
- zmqEndpoint: tcp://*:5557
- topicFilter: kv@
- concurrency: 8
- kvCacheIndexerConfig:
+ topicFilter: "kv@"
+ concurrency: 4
+ discoverPods: true # enables automatic pod discovery for active-active HA
+ podDiscoveryConfig:
+ socketPort: 5556
+ indexerConfig:
prefixStoreConfig:
cacheSize: 500000
blockSize: 256
- tokenProcessorConfig:
- blockSize: 16
- hashSeed: "12345"
kvBlockIndexConfig:
inMemoryConfig:
size: 100000000
podCacheSize: 10
enableMetrics: true
tokenizersPoolConfig:
+ modelName: hf-repo/model-name
workersCount: 8
hf:
huggingFaceToken: your_hf_token_here # automatically set by `HF_TOKEN` environment variable
@@ -434,11 +493,13 @@ Example configuration:
```yaml
plugins:
- - type: prefix-cache-scorer
+ - type: precise-prefix-cache-scorer
parameters:
- hashBlockSize: 5
- maxPrefixBlocksToMatch: 256
- lruCapacityPerServer: 31250
+ indexerConfig:
+ tokenProcessorConfig:
+ blockSize: 5
+ kvBlockIndexConfig:
+ maxPrefixBlocksToMatch: 256
- type: no-hit-lru-scorer
parameters:
lruSize: 2048
@@ -450,7 +511,7 @@ schedulingProfiles:
plugins:
- pluginRef: decode-filter
- pluginRef: max-score-picker
- - pluginRef: prefix-cache-scorer
+ - pluginRef: precise-prefix-cache-scorer
weight: 2
- pluginRef: no-hit-lru-scorer
weight: 1
diff --git a/docs/disagg_pd.md b/docs/disagg_pd.md
index 383b9c193..4c7cabd23 100644
--- a/docs/disagg_pd.md
+++ b/docs/disagg_pd.md
@@ -1,8 +1,8 @@
-# Disaggregated Prefill/Decode Inference Serving in llm-d
+# Disaggregated Prefill/Decode Inference Serving in LLM-D
## Overview
-This document describes the architecture and request lifecycle for enabling **disaggregated prefill and decode (P/D)** inference execution in the llm-d router. The architecture aims to improve flexibility, scalability, and performance by enabling separation of prefill and decode stages onto different workers.
+This document describes the architecture and request lifecycle for enabling **disaggregated prefill and decode (P/D)** inference execution in the LLM-D router. The architecture aims to improve flexibility, scalability, and performance by enabling separation of prefill and decode stages onto different workers.
This evolved version removes the requirement for sidecars on the **prefill node**, simplifying deployment while maintaining orchestration from the **decode node**.
@@ -25,7 +25,7 @@ This evolved version removes the requirement for sidecars on the **prefill node*
| **Decode Worker** | Handles decode stage and contains the sidecar for coordination |
| **Sidecar (Decode)** | Orchestrates communication with prefill worker and manages lifecycle |
| **Envoy Proxy** | Accepts OpenAI-style requests and forwards them to EPP |
-| **EPP** | End Point Picker, makes scheduling decisions |
+| **EPP** | Endpoint Picker, makes scheduling decisions |
---
@@ -37,7 +37,7 @@ This evolved version removes the requirement for sidecars on the **prefill node*
2. **EPP Scheduling Decision**
- EPP evaluates:
- Prompt length
- - KV cache hit probability
+ - KV-cache hit probability
- System and pod load
- Selects either:
- **Single node** path (decode handles all)
@@ -47,10 +47,10 @@ This evolved version removes the requirement for sidecars on the **prefill node*
3. **Execution**
- Request lands on Decode Worker (as selected by EPP)
- Decode sidecar coordinates:
- - If `prefill_worker_id == nil`, runs both stages locally by passing request to local vllm
- - If split:
- - Sends prefill job to Prefill Worker with a special header `do_remote_decode=true`
- - Upon receiving response from Prefill Worker runs decode stage
+ - If `x-prefiller-host-port` header doesn't exist, runs both stages locally by passing request to local vLLM
+ - If `x-prefiller-host-port` header exists:
+ - Sends the prefill job to the selected Prefill Worker with a special request field `do_remote_decode=true`
+ - Upon receiving the response from the Prefill Worker runs the decode stage
4. **Response Flow**
- Response flows from decode sidecar → Envoy → EPP → User
@@ -59,11 +59,34 @@ This evolved version removes the requirement for sidecars on the **prefill node*
## Architectural Details
+
+```mermaid
+sequenceDiagram
+ participant C as Client
+ participant I as Inference Gateway
+ participant DS as Decode Worker Sidecar
+ participant D as Decode Worker(vLLM)
+ participant P as Prefill Worker(vLLM)
+
+
+ C->>I: Inference Request
+ I->>DS: Request is sent to the Decode Worker Sidecar with the selected Prefill worker set in a header.
+ DS->>P: Remote Prefill with prompt(max_tokens=1)
+ P-->>P: Run prefill
+ P->>DS: Remote kv parameters
+ DS->> D: Request is sent to the Decode Worker (vLLM) with remote_prefill true, prefill ID and memory block IDs
+ D-->>P: Read kv-cache
+ D-->>D: Schedule decode into queue & run decode
+ D->>DS: Inference Response
+ DS->>I: Inference Response
+ I->>C: Inference Response
+```
+
### Sidecar Responsibilities (Decode Only)
- Receives EPP metadata (decode pod, optional prefill pod)
- Sends request to prefill
-- Waits and validates result
+- Waits for the result and validates it
- Launches local decode job
- Sends final response
@@ -73,33 +96,21 @@ This evolved version removes the requirement for sidecars on the **prefill node*
## Worker Selection Logic
-- **Decode Worker**:
- - Prefer longest prefix match / KV cache utilization (depends on available scorers)
-
-- **Prefill Worker**:
- - High prefix-cache hit rate
- - Low load
+- **Decode/Prefill Worker**:
+ - Prefer longest prefix match/kv-cache utilization (depends on available scorers) and low load
> **Skip prefill worker** when:
-> - Prefix match/kv cache hit is high
+> - Prefix match/kv-cache hit is high
> - Prompt is very short
---
-## vLLM and LMCache Integration
-
-- **vLLM changes** (or wrapper APIs):
- - `save()`, `load()` APIs
- - `done_sending`, `done_receiving`
- - Connector API supporting async transfer
-
----
## Drawbacks & Limitations
-- Slight increase in TTFT for split P/D
+- Slight increase in TTFT for disaggregated P/D
- Possibility of stranded memory on prefill crash
-- Need for timeout and retry logic
+- The need for timeout and retry logic
---
@@ -115,17 +126,17 @@ This evolved version removes the requirement for sidecars on the **prefill node*
## Future Considerations
- Cache coordinate
-- Pre allocation of kv blocks in decode node , push cache from prefill to decode worker during calculation
+- Pre-allocation of kv blocks in the decode node, push cache from the prefill to the decode worker during calculation
---
## Integrating External Prefill/Decode Workloads
-The llm-d inference scheduler supports integration with external disaggregated prefill/decode (P/D) workloads other inference frameworks that follow the same P/D separation pattern but use **different Kubernetes Pod labeling conventions**.
+The LLM-D inference scheduler supports integration with external disaggregated prefill/decode (P/D) workloads other inference frameworks that follow the same P/D separation pattern but use **different Kubernetes Pod labeling conventions**.
### Labeling Convention Flexibility
-By default, llm-d uses the label key `llm-d.ai/role` with values:
+By default, LLM-D uses the label key `llm-d.ai/role` with values:
- `"prefill"` → prefill-only pods
- `"decode"` or `"both"` → decode-capable pods
@@ -144,6 +155,8 @@ Below is a minimal `EndpointPickerConfig` that enables integration with workload
```yaml
apiVersion: inference.networking.x-k8s.io/v1alpha1
kind: EndpointPickerConfig
+featureGates:
+- prepareDataPlugins
plugins:
# Prefill selection: match Pods with label role=prefill
- type: by-label
@@ -159,15 +172,18 @@ plugins:
validValues: ["decode"]
- type: prefix-cache-scorer
parameters:
- hashBlockSize: 5
+ autoTune: false
+ blockSize: 5
maxPrefixBlocksToMatch: 256
lruCapacityPerServer: 31250
- type: max-score-picker
- type: prefill-header-handler
- - type: pd-profile-handler
+ - type: prefix-based-pd-decider
parameters:
- threshold: 0
- hashBlockSize: 5
+ nonCachedTokens: 8
+ - type: pd-profile-handler
+ parameters:
+ deciderPluginName: prefix-based-pd-decider
primaryPort: 8000
schedulingProfiles:
- name: prefill
@@ -175,13 +191,11 @@ schedulingProfiles:
- pluginRef: "prefill-pods"
- pluginRef: "max-score-picker"
- pluginRef: "prefix-cache-scorer"
- weight: 2
- name: decode
plugins:
- pluginRef: "decode-pods"
- pluginRef: "max-score-picker"
- pluginRef: "prefix-cache-scorer"
- weight: 2
```
---
@@ -190,6 +204,59 @@ schedulingProfiles:

+---
+## PD Deciders
+
+PD deciders are pd handler plugins responsible for determining whether disaggregated P/D should be executed for a given request, based on the properties of the request prompt.
+
+
+### Prefix-Based PD Decider
+
+The `prefix-based-pd-decider` plugin makes the disaggregation decision according to the length of the non-cached suffix of the prompt relative to tokens already cached on the selected decode pod.
+
+**How It Works**
+- Once a decode pod is selected, the decider checks how many tokens from the incoming prompt have already been sent to this pod
+
+- If the remaining non-cached suffix length is longer than the configured threshold (nonCachedTokens), disaggregation is triggered → the prefill will run remotely on a prefill pod, and decode locally on the decode pod
+
+- If the non-cached suffix is shorter or equal to the threshold, the full request runs locally on the decode worker without remote prefill
+
+**Configuration**
+```yaml
+- type: prefix-based-pd-decider
+ parameters:
+ nonCachedTokens: 8
+```
+
+**Parameter:**
+
+- `nonCachedTokens`: Number of non-cached tokens that trigger disaggregation
+ - If set to 0, disaggregation always occurs for all requests
+
+**Feature Gate Requirement**
+To activate this decider, ensure the following feature gate is enabled in your EndpointPickerConfig
+
+```yaml
+featureGates:
+- prepareDataPlugins
+```
+
+
+### Always-Disagg PD Decider
+The `always-disagg-pd-decider` is a simpler alternative used mainly for testing or benchmarking.
+It always triggers disaggregation, regardless of prefix cache state or prompt characteristics.
+
+**Configuration example:**
+
+```yaml
+- type: always-disagg-pd-decider
+```
+
+**Notes:**
+This plugin accepts no parameters.
+
+It's useful for validating end-to-end prefill/decode splitting and comparing system performance under forced disaggregation.
+
---
## References
diff --git a/go.mod b/go.mod
index ec6a25010..57c502280 100644
--- a/go.mod
+++ b/go.mod
@@ -1,8 +1,8 @@
module github.com/llm-d/llm-d-inference-scheduler
-go 1.24.1
+go 1.24.9
-toolchain go1.24.2
+toolchain go1.24.12
require (
github.com/go-logr/logr v1.4.3
@@ -10,9 +10,9 @@ require (
github.com/google/uuid v1.6.0
github.com/hashicorp/golang-lru/v2 v2.0.7
github.com/jellydator/ttlcache/v3 v3.4.0
- github.com/llm-d/llm-d-kv-cache-manager v0.4.0
- github.com/onsi/ginkgo/v2 v2.27.4
- github.com/onsi/gomega v1.39.0
+ github.com/llm-d/llm-d-kv-cache v0.5.0
+ github.com/onsi/ginkgo/v2 v2.28.1
+ github.com/onsi/gomega v1.39.1
github.com/openai/openai-go v1.12.0
github.com/prometheus/client_golang v1.23.2
github.com/stretchr/testify v1.11.1
@@ -23,10 +23,10 @@ require (
k8s.io/apimachinery v0.34.3
k8s.io/client-go v0.34.3
k8s.io/component-base v0.34.3
- k8s.io/utils v0.0.0-20250820121507-0af2bda4dd1d
- sigs.k8s.io/controller-runtime v0.22.4
+ k8s.io/utils v0.0.0-20251002143259-bc988d571ff4
+ sigs.k8s.io/controller-runtime v0.22.5
sigs.k8s.io/gateway-api v1.4.1
- sigs.k8s.io/gateway-api-inference-extension v1.3.0-rc.2
+ sigs.k8s.io/gateway-api-inference-extension v0.0.0-20260128235548-fd30cb97714a
)
require (
@@ -46,7 +46,7 @@ require (
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/emicklei/go-restful/v3 v3.13.0 // indirect
github.com/envoyproxy/go-control-plane/envoy v1.36.0 // indirect
- github.com/envoyproxy/protoc-gen-validate v1.2.1 // indirect
+ github.com/envoyproxy/protoc-gen-validate v1.3.0 // indirect
github.com/evanphx/json-patch/v5 v5.9.11 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/fsnotify/fsnotify v1.9.0 // indirect
@@ -54,23 +54,32 @@ require (
github.com/go-errors/errors v1.4.2 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/go-logr/zapr v1.3.0 // indirect
- github.com/go-openapi/jsonpointer v0.21.2 // indirect
- github.com/go-openapi/jsonreference v0.21.0 // indirect
- github.com/go-openapi/swag v0.23.1 // indirect
+ github.com/go-openapi/jsonpointer v0.22.1 // indirect
+ github.com/go-openapi/jsonreference v0.21.3 // indirect
+ github.com/go-openapi/swag v0.25.4 // indirect
+ github.com/go-openapi/swag/cmdutils v0.25.4 // indirect
+ github.com/go-openapi/swag/conv v0.25.4 // indirect
+ github.com/go-openapi/swag/fileutils v0.25.4 // indirect
+ github.com/go-openapi/swag/jsonname v0.25.4 // indirect
+ github.com/go-openapi/swag/jsonutils v0.25.4 // indirect
+ github.com/go-openapi/swag/loading v0.25.4 // indirect
+ github.com/go-openapi/swag/mangling v0.25.4 // indirect
+ github.com/go-openapi/swag/netutils v0.25.4 // indirect
+ github.com/go-openapi/swag/stringutils v0.25.4 // indirect
+ github.com/go-openapi/swag/typeutils v0.25.4 // indirect
+ github.com/go-openapi/swag/yamlutils v0.25.4 // indirect
github.com/go-task/slim-sprig/v3 v3.0.0 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/google/btree v1.1.3 // indirect
github.com/google/cel-go v0.26.0 // indirect
github.com/google/gnostic-models v0.7.0 // indirect
- github.com/google/pprof v0.0.0-20250923004556-9e5a51aed1e8 // indirect
+ github.com/google/pprof v0.0.0-20260115054156-294ebfa9ad83 // indirect
github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674 // indirect
github.com/grafana/regexp v0.0.0-20250905093917-f7b3be9d1853 // indirect
github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
- github.com/josharian/intern v1.0.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/kylelemons/godebug v1.1.0 // indirect
- github.com/mailru/easyjson v0.9.0 // indirect
github.com/moby/spdystream v0.5.0 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee // indirect
@@ -83,7 +92,7 @@ require (
github.com/prometheus/client_model v0.6.2 // indirect
github.com/prometheus/common v0.67.5 // indirect
github.com/prometheus/procfs v0.17.0 // indirect
- github.com/prometheus/prometheus v0.308.1 // indirect
+ github.com/prometheus/prometheus v0.309.1 // indirect
github.com/redis/go-redis/v9 v9.11.0 // indirect
github.com/spf13/cobra v1.9.1 // indirect
github.com/spf13/pflag v1.0.10 // indirect
@@ -97,7 +106,7 @@ require (
github.com/x448/float16 v0.8.4 // indirect
github.com/xlab/treeprint v1.2.0 // indirect
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
- go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0 // indirect
+ go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0 // indirect
go.opentelemetry.io/otel v1.39.0 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.39.0 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.39.0 // indirect
@@ -112,16 +121,16 @@ require (
go.yaml.in/yaml/v2 v2.4.3 // indirect
go.yaml.in/yaml/v3 v3.0.4 // indirect
golang.org/x/exp v0.0.0-20250808145144-a408d31f581a // indirect
- golang.org/x/mod v0.30.0 // indirect
- golang.org/x/net v0.48.0 // indirect
+ golang.org/x/mod v0.32.0 // indirect
+ golang.org/x/net v0.49.0 // indirect
golang.org/x/oauth2 v0.34.0 // indirect
- golang.org/x/sys v0.39.0 // indirect
- golang.org/x/term v0.38.0 // indirect
- golang.org/x/text v0.32.0 // indirect
- golang.org/x/time v0.13.0 // indirect
- golang.org/x/tools v0.39.0 // indirect
+ golang.org/x/sys v0.40.0 // indirect
+ golang.org/x/term v0.39.0 // indirect
+ golang.org/x/text v0.33.0 // indirect
+ golang.org/x/time v0.14.0 // indirect
+ golang.org/x/tools v0.41.0 // indirect
gomodules.xyz/jsonpatch/v2 v2.4.0 // indirect
- google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217 // indirect
+ google.golang.org/genproto/googleapis/api v0.0.0-20251213004720-97cd9d5aeac2 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 // indirect
google.golang.org/protobuf v1.36.11 // indirect
gopkg.in/evanphx/json-patch.v4 v4.13.0 // indirect
diff --git a/go.sum b/go.sum
index dd2983e89..17c6e64b2 100644
--- a/go.sum
+++ b/go.sum
@@ -6,14 +6,14 @@ cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIi
cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c=
cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs=
cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10=
-github.com/Azure/azure-sdk-for-go/sdk/azcore v1.19.1 h1:5YTBM8QDVIBN3sxBil89WfdAAqDZbyJTgh688DSxX5w=
-github.com/Azure/azure-sdk-for-go/sdk/azcore v1.19.1/go.mod h1:YD5h/ldMsG0XiIw7PdyNhLxaM317eFh5yNLccNfGdyw=
-github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.12.0 h1:wL5IEG5zb7BVv1Kv0Xm92orq+5hB5Nipn3B5tn4Rqfk=
-github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.12.0/go.mod h1:J7MUC/wtRpfGVbQ5sIItY5/FuVWmvzlY21WAOfQnq/I=
+github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0 h1:JXg2dwJUmPB9JmtVmdEB16APJ7jurfbY5jnfXpJoRMc=
+github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0/go.mod h1:YD5h/ldMsG0XiIw7PdyNhLxaM317eFh5yNLccNfGdyw=
+github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1 h1:Hk5QBxZQC1jb2Fwj6mpzme37xbCDdNTxU7O9eb5+LB4=
+github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1/go.mod h1:IYus9qsFobWIc2YVwe/WPjcnyCkPKtnHAqUYeebc8z0=
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 h1:9iefClla7iYpfYWdzPCRDozdmndjTm8DXdpCzPajMgA=
github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2/go.mod h1:XtLgD3ZD34DAaVIIAyG3objl5DynM3CQ/vMcbBNJZGI=
-github.com/AzureAD/microsoft-authentication-library-for-go v1.5.0 h1:XkkQbfMyuH2jTSjQjSoihryI8GINRcs4xp8lNawg0FI=
-github.com/AzureAD/microsoft-authentication-library-for-go v1.5.0/go.mod h1:HKpQxkWaGLJ+D/5H8QRpyQXA1eKjxkFlOMwck5+33Jk=
+github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 h1:XRzhVemXdgvJqCH0sFfrBUTnUJSBrBf7++ypk+twtRs=
+github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0/go.mod h1:HKpQxkWaGLJ+D/5H8QRpyQXA1eKjxkFlOMwck5+33Jk=
github.com/Masterminds/semver/v3 v3.4.0 h1:Zog+i5UMtVoCU8oKka5P7i9q9HgrJeGzI9SA1Xbatp0=
github.com/Masterminds/semver/v3 v3.4.0/go.mod h1:4V+yj/TJE1HU9XfppCwVMZq3I84lprf4nC11bSS5beM=
github.com/alecthomas/units v0.0.0-20240927000941-0f3dac36c52b h1:mimo19zliBX/vSQ6PWWSL9lK8qwHozUj03+zLoEB8O0=
@@ -24,32 +24,34 @@ github.com/antlr4-go/antlr/v4 v4.13.0 h1:lxCg3LAv+EUK6t1i0y1V6/SLeUi0eKEKdhQAlS8
github.com/antlr4-go/antlr/v4 v4.13.0/go.mod h1:pfChB/xh/Unjila75QW7+VU4TSnWnnk9UTnmpPaOR2g=
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio=
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs=
-github.com/aws/aws-sdk-go-v2 v1.39.6 h1:2JrPCVgWJm7bm83BDwY5z8ietmeJUbh3O2ACnn+Xsqk=
-github.com/aws/aws-sdk-go-v2 v1.39.6/go.mod h1:c9pm7VwuW0UPxAEYGyTmyurVcNrbF6Rt/wixFqDhcjE=
-github.com/aws/aws-sdk-go-v2/config v1.31.17 h1:QFl8lL6RgakNK86vusim14P2k8BFSxjvUkcWLDjgz9Y=
-github.com/aws/aws-sdk-go-v2/config v1.31.17/go.mod h1:V8P7ILjp/Uef/aX8TjGk6OHZN6IKPM5YW6S78QnRD5c=
-github.com/aws/aws-sdk-go-v2/credentials v1.18.21 h1:56HGpsgnmD+2/KpG0ikvvR8+3v3COCwaF4r+oWwOeNA=
-github.com/aws/aws-sdk-go-v2/credentials v1.18.21/go.mod h1:3YELwedmQbw7cXNaII2Wywd+YY58AmLPwX4LzARgmmA=
-github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.13 h1:T1brd5dR3/fzNFAQch/iBKeX07/ffu/cLu+q+RuzEWk=
-github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.13/go.mod h1:Peg/GBAQ6JDt+RoBf4meB1wylmAipb7Kg2ZFakZTlwk=
-github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.13 h1:a+8/MLcWlIxo1lF9xaGt3J/u3yOZx+CdSveSNwjhD40=
-github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.13/go.mod h1:oGnKwIYZ4XttyU2JWxFrwvhF6YKiK/9/wmE3v3Iu9K8=
-github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.13 h1:HBSI2kDkMdWz4ZM7FjwE7e/pWDEZ+nR95x8Ztet1ooY=
-github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.13/go.mod h1:YE94ZoDArI7awZqJzBAZ3PDD2zSfuP7w6P2knOzIn8M=
+github.com/aws/aws-sdk-go-v2 v1.41.0 h1:tNvqh1s+v0vFYdA1xq0aOJH+Y5cRyZ5upu6roPgPKd4=
+github.com/aws/aws-sdk-go-v2 v1.41.0/go.mod h1:MayyLB8y+buD9hZqkCW3kX1AKq07Y5pXxtgB+rRFhz0=
+github.com/aws/aws-sdk-go-v2/config v1.32.6 h1:hFLBGUKjmLAekvi1evLi5hVvFQtSo3GYwi+Bx4lpJf8=
+github.com/aws/aws-sdk-go-v2/config v1.32.6/go.mod h1:lcUL/gcd8WyjCrMnxez5OXkO3/rwcNmvfno62tnXNcI=
+github.com/aws/aws-sdk-go-v2/credentials v1.19.6 h1:F9vWao2TwjV2MyiyVS+duza0NIRtAslgLUM0vTA1ZaE=
+github.com/aws/aws-sdk-go-v2/credentials v1.19.6/go.mod h1:SgHzKjEVsdQr6Opor0ihgWtkWdfRAIwxYzSJ8O85VHY=
+github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.16 h1:80+uETIWS1BqjnN9uJ0dBUaETh+P1XwFy5vwHwK5r9k=
+github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.16/go.mod h1:wOOsYuxYuB/7FlnVtzeBYRcjSRtQpAW0hCP7tIULMwo=
+github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.16 h1:rgGwPzb82iBYSvHMHXc8h9mRoOUBZIGFgKb9qniaZZc=
+github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.16/go.mod h1:L/UxsGeKpGoIj6DxfhOWHWQ/kGKcd4I1VncE4++IyKA=
+github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.16 h1:1jtGzuV7c82xnqOVfx2F0xmJcOw5374L7N6juGW6x6U=
+github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.16/go.mod h1:M2E5OQf+XLe+SZGmmpaI2yy+J326aFf6/+54PoxSANc=
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 h1:WKuaxf++XKWlHWu9ECbMlha8WOEGm0OUEZqm4K/Gcfk=
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4/go.mod h1:ZWy7j6v1vWGmPReu0iSGvRiise4YI5SkR3OHKTZ6Wuc=
-github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.3 h1:x2Ibm/Af8Fi+BH+Hsn9TXGdT+hKbDd5XOTZxTMxDk7o=
-github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.3/go.mod h1:IW1jwyrQgMdhisceG8fQLmQIydcT/jWY21rFhzgaKwo=
-github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.13 h1:kDqdFvMY4AtKoACfzIGD8A0+hbT41KTKF//gq7jITfM=
-github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.13/go.mod h1:lmKuogqSU3HzQCwZ9ZtcqOc5XGMqtDK7OIc2+DxiUEg=
-github.com/aws/aws-sdk-go-v2/service/sso v1.30.1 h1:0JPwLz1J+5lEOfy/g0SURC9cxhbQ1lIMHMa+AHZSzz0=
-github.com/aws/aws-sdk-go-v2/service/sso v1.30.1/go.mod h1:fKvyjJcz63iL/ftA6RaM8sRCtN4r4zl4tjL3qw5ec7k=
-github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.5 h1:OWs0/j2UYR5LOGi88sD5/lhN6TDLG6SfA7CqsQO9zF0=
-github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.5/go.mod h1:klO+ejMvYsB4QATfEOIXk8WAEwN4N0aBfJpvC+5SZBo=
-github.com/aws/aws-sdk-go-v2/service/sts v1.39.1 h1:mLlUgHn02ue8whiR4BmxxGJLR2gwU6s6ZzJ5wDamBUs=
-github.com/aws/aws-sdk-go-v2/service/sts v1.39.1/go.mod h1:E19xDjpzPZC7LS2knI9E6BaRFDK43Eul7vd6rSq2HWk=
-github.com/aws/smithy-go v1.23.2 h1:Crv0eatJUQhaManss33hS5r40CG3ZFH+21XSkqMrIUM=
-github.com/aws/smithy-go v1.23.2/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0=
+github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.4 h1:0ryTNEdJbzUCEWkVXEXoqlXV72J5keC1GvILMOuD00E=
+github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.4/go.mod h1:HQ4qwNZh32C3CBeO6iJLQlgtMzqeG17ziAA/3KDJFow=
+github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.16 h1:oHjJHeUy0ImIV0bsrX0X91GkV5nJAyv1l1CC9lnO0TI=
+github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.16/go.mod h1:iRSNGgOYmiYwSCXxXaKb9HfOEj40+oTKn8pTxMlYkRM=
+github.com/aws/aws-sdk-go-v2/service/signin v1.0.4 h1:HpI7aMmJ+mm1wkSHIA2t5EaFFv5EFYXePW30p1EIrbQ=
+github.com/aws/aws-sdk-go-v2/service/signin v1.0.4/go.mod h1:C5RdGMYGlfM0gYq/tifqgn4EbyX99V15P2V3R+VHbQU=
+github.com/aws/aws-sdk-go-v2/service/sso v1.30.8 h1:aM/Q24rIlS3bRAhTyFurowU8A0SMyGDtEOY/l/s/1Uw=
+github.com/aws/aws-sdk-go-v2/service/sso v1.30.8/go.mod h1:+fWt2UHSb4kS7Pu8y+BMBvJF0EWx+4H0hzNwtDNRTrg=
+github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.12 h1:AHDr0DaHIAo8c9t1emrzAlVDFp+iMMKnPdYy6XO4MCE=
+github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.12/go.mod h1:GQ73XawFFiWxyWXMHWfhiomvP3tXtdNar/fi8z18sx0=
+github.com/aws/aws-sdk-go-v2/service/sts v1.41.5 h1:SciGFVNZ4mHdm7gpD1dgZYnCuVdX1s+lFTg4+4DOy70=
+github.com/aws/aws-sdk-go-v2/service/sts v1.41.5/go.mod h1:iW40X4QBmUxdP+fZNOpfmkdMZqsovezbAeO+Ubiv2pk=
+github.com/aws/smithy-go v1.24.0 h1:LpilSUItNPFr1eY85RYgTIg5eIEPtvFbskaFcmmIUnk=
+github.com/aws/smithy-go v1.24.0/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0=
github.com/bboreham/go-loser v0.0.0-20230920113527-fcc2c21820a3 h1:6df1vn4bBlDDo4tARvBm7l6KA9iVMnE3NWizDeWSrps=
github.com/bboreham/go-loser v0.0.0-20230920113527-fcc2c21820a3/go.mod h1:CIWtjkly68+yqLPbvwwR/fjNJA/idrtULjZWh2v1ys0=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
@@ -87,8 +89,8 @@ github.com/emicklei/go-restful/v3 v3.13.0 h1:C4Bl2xDndpU6nJ4bc1jXd+uTmYPVUwkD6bF
github.com/emicklei/go-restful/v3 v3.13.0/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc=
github.com/envoyproxy/go-control-plane/envoy v1.36.0 h1:yg/JjO5E7ubRyKX3m07GF3reDNEnfOboJ0QySbH736g=
github.com/envoyproxy/go-control-plane/envoy v1.36.0/go.mod h1:ty89S1YCCVruQAm9OtKeEkQLTb+Lkz0k8v9W0Oxsv98=
-github.com/envoyproxy/protoc-gen-validate v1.2.1 h1:DEo3O99U8j4hBFwbJfrz9VtgcDfUKS7KJ7spH3d86P8=
-github.com/envoyproxy/protoc-gen-validate v1.2.1/go.mod h1:d/C80l/jxXLdfEIhX1W2TmLfsJ31lvEjwamM4DxlWXU=
+github.com/envoyproxy/protoc-gen-validate v1.3.0 h1:TvGH1wof4H33rezVKWSpqKz5NXWg5VPuZ0uONDT6eb4=
+github.com/envoyproxy/protoc-gen-validate v1.3.0/go.mod h1:HvYl7zwPa5mffgyeTUHA9zHIH36nmrm7oCbo4YKoSWA=
github.com/evanphx/json-patch v0.5.2 h1:xVCHIVMUu1wtM/VkR9jVZ45N3FhZfYMMYGorLCR8P3k=
github.com/evanphx/json-patch v0.5.2/go.mod h1:ZWS5hhDbVDyob71nXKNL0+PWn6ToqBHMikGIFbs31qQ=
github.com/evanphx/json-patch/v5 v5.9.11 h1:/8HVnzMq13/3x9TPvjG08wUGqBTmZBsCWzjTM0wiaDU=
@@ -114,12 +116,40 @@ github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/go-logr/zapr v1.3.0 h1:XGdV8XW8zdwFiwOA2Dryh1gj2KRQyOOoNmBy4EplIcQ=
github.com/go-logr/zapr v1.3.0/go.mod h1:YKepepNBd1u/oyhd/yQmtjVXmm9uML4IXUgMOwR8/Gg=
-github.com/go-openapi/jsonpointer v0.21.2 h1:AqQaNADVwq/VnkCmQg6ogE+M3FOsKTytwges0JdwVuA=
-github.com/go-openapi/jsonpointer v0.21.2/go.mod h1:50I1STOfbY1ycR8jGz8DaMeLCdXiI6aDteEdRNNzpdk=
-github.com/go-openapi/jsonreference v0.21.0 h1:Rs+Y7hSXT83Jacb7kFyjn4ijOuVGSvOdF2+tg1TRrwQ=
-github.com/go-openapi/jsonreference v0.21.0/go.mod h1:LmZmgsrTkVg9LG4EaHeY8cBDslNPMo06cago5JNLkm4=
-github.com/go-openapi/swag v0.23.1 h1:lpsStH0n2ittzTnbaSloVZLuB5+fvSY/+hnagBjSNZU=
-github.com/go-openapi/swag v0.23.1/go.mod h1:STZs8TbRvEQQKUA+JZNAm3EWlgaOBGpyFDqQnDHMef0=
+github.com/go-openapi/jsonpointer v0.22.1 h1:sHYI1He3b9NqJ4wXLoJDKmUmHkWy/L7rtEo92JUxBNk=
+github.com/go-openapi/jsonpointer v0.22.1/go.mod h1:pQT9OsLkfz1yWoMgYFy4x3U5GY5nUlsOn1qSBH5MkCM=
+github.com/go-openapi/jsonreference v0.21.3 h1:96Dn+MRPa0nYAR8DR1E03SblB5FJvh7W6krPI0Z7qMc=
+github.com/go-openapi/jsonreference v0.21.3/go.mod h1:RqkUP0MrLf37HqxZxrIAtTWW4ZJIK1VzduhXYBEeGc4=
+github.com/go-openapi/swag v0.25.4 h1:OyUPUFYDPDBMkqyxOTkqDYFnrhuhi9NR6QVUvIochMU=
+github.com/go-openapi/swag v0.25.4/go.mod h1:zNfJ9WZABGHCFg2RnY0S4IOkAcVTzJ6z2Bi+Q4i6qFQ=
+github.com/go-openapi/swag/cmdutils v0.25.4 h1:8rYhB5n6WawR192/BfUu2iVlxqVR9aRgGJP6WaBoW+4=
+github.com/go-openapi/swag/cmdutils v0.25.4/go.mod h1:pdae/AFo6WxLl5L0rq87eRzVPm/XRHM3MoYgRMvG4A0=
+github.com/go-openapi/swag/conv v0.25.4 h1:/Dd7p0LZXczgUcC/Ikm1+YqVzkEeCc9LnOWjfkpkfe4=
+github.com/go-openapi/swag/conv v0.25.4/go.mod h1:3LXfie/lwoAv0NHoEuY1hjoFAYkvlqI/Bn5EQDD3PPU=
+github.com/go-openapi/swag/fileutils v0.25.4 h1:2oI0XNW5y6UWZTC7vAxC8hmsK/tOkWXHJQH4lKjqw+Y=
+github.com/go-openapi/swag/fileutils v0.25.4/go.mod h1:cdOT/PKbwcysVQ9Tpr0q20lQKH7MGhOEb6EwmHOirUk=
+github.com/go-openapi/swag/jsonname v0.25.4 h1:bZH0+MsS03MbnwBXYhuTttMOqk+5KcQ9869Vye1bNHI=
+github.com/go-openapi/swag/jsonname v0.25.4/go.mod h1:GPVEk9CWVhNvWhZgrnvRA6utbAltopbKwDu8mXNUMag=
+github.com/go-openapi/swag/jsonutils v0.25.4 h1:VSchfbGhD4UTf4vCdR2F4TLBdLwHyUDTd1/q4i+jGZA=
+github.com/go-openapi/swag/jsonutils v0.25.4/go.mod h1:7OYGXpvVFPn4PpaSdPHJBtF0iGnbEaTk8AvBkoWnaAY=
+github.com/go-openapi/swag/jsonutils/fixtures_test v0.25.4 h1:IACsSvBhiNJwlDix7wq39SS2Fh7lUOCJRmx/4SN4sVo=
+github.com/go-openapi/swag/jsonutils/fixtures_test v0.25.4/go.mod h1:Mt0Ost9l3cUzVv4OEZG+WSeoHwjWLnarzMePNDAOBiM=
+github.com/go-openapi/swag/loading v0.25.4 h1:jN4MvLj0X6yhCDduRsxDDw1aHe+ZWoLjW+9ZQWIKn2s=
+github.com/go-openapi/swag/loading v0.25.4/go.mod h1:rpUM1ZiyEP9+mNLIQUdMiD7dCETXvkkC30z53i+ftTE=
+github.com/go-openapi/swag/mangling v0.25.4 h1:2b9kBJk9JvPgxr36V23FxJLdwBrpijI26Bx5JH4Hp48=
+github.com/go-openapi/swag/mangling v0.25.4/go.mod h1:6dxwu6QyORHpIIApsdZgb6wBk/DPU15MdyYj/ikn0Hg=
+github.com/go-openapi/swag/netutils v0.25.4 h1:Gqe6K71bGRb3ZQLusdI8p/y1KLgV4M/k+/HzVSqT8H0=
+github.com/go-openapi/swag/netutils v0.25.4/go.mod h1:m2W8dtdaoX7oj9rEttLyTeEFFEBvnAx9qHd5nJEBzYg=
+github.com/go-openapi/swag/stringutils v0.25.4 h1:O6dU1Rd8bej4HPA3/CLPciNBBDwZj9HiEpdVsb8B5A8=
+github.com/go-openapi/swag/stringutils v0.25.4/go.mod h1:GTsRvhJW5xM5gkgiFe0fV3PUlFm0dr8vki6/VSRaZK0=
+github.com/go-openapi/swag/typeutils v0.25.4 h1:1/fbZOUN472NTc39zpa+YGHn3jzHWhv42wAJSN91wRw=
+github.com/go-openapi/swag/typeutils v0.25.4/go.mod h1:Ou7g//Wx8tTLS9vG0UmzfCsjZjKhpjxayRKTHXf2pTE=
+github.com/go-openapi/swag/yamlutils v0.25.4 h1:6jdaeSItEUb7ioS9lFoCZ65Cne1/RZtPBZ9A56h92Sw=
+github.com/go-openapi/swag/yamlutils v0.25.4/go.mod h1:MNzq1ulQu+yd8Kl7wPOut/YHAAU/H6hL91fF+E2RFwc=
+github.com/go-openapi/testify/enable/yaml/v2 v2.0.2 h1:0+Y41Pz1NkbTHz8NngxTuAXxEodtNSI1WG1c/m5Akw4=
+github.com/go-openapi/testify/enable/yaml/v2 v2.0.2/go.mod h1:kme83333GCtJQHXQ8UKX3IBZu6z8T5Dvy5+CW3NLUUg=
+github.com/go-openapi/testify/v2 v2.0.2 h1:X999g3jeLcoY8qctY/c/Z8iBHTbwLz7R2WXd6Ub6wls=
+github.com/go-openapi/testify/v2 v2.0.2/go.mod h1:HCPmvFFnheKK2BuwSA0TbbdxJ3I16pjwMkYkP4Ywn54=
github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI=
github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8=
github.com/goccy/go-yaml v1.18.0 h1:8W7wMFS12Pcas7KU+VVkaiCng+kG8QiFeFwzFb+rwuw=
@@ -143,14 +173,14 @@ github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0=
github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
-github.com/google/pprof v0.0.0-20250923004556-9e5a51aed1e8 h1:ZI8gCoCjGzPsum4L21jHdQs8shFBIQih1TM9Rd/c+EQ=
-github.com/google/pprof v0.0.0-20250923004556-9e5a51aed1e8/go.mod h1:I6V7YzU0XDpsHqbsyrghnFZLO1gwK6NPTNvmetQIk9U=
+github.com/google/pprof v0.0.0-20260115054156-294ebfa9ad83 h1:z2ogiKUYzX5Is6zr/vP9vJGqPwcdqsWjOt+V8J7+bTc=
+github.com/google/pprof v0.0.0-20260115054156-294ebfa9ad83/go.mod h1:MxpfABSjhmINe3F1It9d+8exIHFvUqtLIRCdOGNXqiI=
github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0=
github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
-github.com/googleapis/enterprise-certificate-proxy v0.3.6 h1:GW/XbdyBFQ8Qe+YAmFU9uHLo7OnF5tL52HFAgMmyrf4=
-github.com/googleapis/enterprise-certificate-proxy v0.3.6/go.mod h1:MkHOF77EYAE7qfSuSS9PU6g4Nt4e11cnsDUowfwewLA=
+github.com/googleapis/enterprise-certificate-proxy v0.3.7 h1:zrn2Ee/nWmHulBx5sAVrGgAa0f2/R35S4DJwfFaUPFQ=
+github.com/googleapis/enterprise-certificate-proxy v0.3.7/go.mod h1:MkHOF77EYAE7qfSuSS9PU6g4Nt4e11cnsDUowfwewLA=
github.com/googleapis/gax-go/v2 v2.15.0 h1:SyjDc1mGgZU5LncH8gimWo9lW1DtIfPibOG81vgd/bo=
github.com/googleapis/gax-go/v2 v2.15.0/go.mod h1:zVVkkxAQHa1RQpg9z2AUCMnKhi0Qld9rcmyfL1OZhoc=
github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674 h1:JeSE6pjso5THxAzdVpqr6/geYxZytqFMBCOtn/ujyeo=
@@ -165,8 +195,6 @@ github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jellydator/ttlcache/v3 v3.4.0 h1:YS4P125qQS0tNhtL6aeYkheEaB/m8HCqdMMP4mnWdTY=
github.com/jellydator/ttlcache/v3 v3.4.0/go.mod h1:Hw9EgjymziQD3yGsQdf1FqFdpp7YjFMd4Srg5EJlgD4=
-github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY=
-github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
github.com/joshdk/go-junit v1.0.0 h1:S86cUKIdwBHWwA6xCmFlf3RTLfVXYQfvanM5Uh+K6GE=
github.com/joshdk/go-junit v1.0.0/go.mod h1:TiiV0PqkaNfFXjEiyjWM3XXrhVyCa1K4Zfga6W52ung=
github.com/jpillora/backoff v1.0.0 h1:uvFg412JmmHBHw7iwprIxkPMI+sGQ4kzOWsMeHnm2EA=
@@ -183,10 +211,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
-github.com/llm-d/llm-d-kv-cache-manager v0.4.0 h1:MBWVpDW0PWsqNJEEAW1esrJW+Xavb0a7w14tCJWWyRY=
-github.com/llm-d/llm-d-kv-cache-manager v0.4.0/go.mod h1:ZlK7MCuz5D/weLeHyNKEmVF/eJZDyYn3XyRowTihq9o=
-github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4=
-github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU=
+github.com/llm-d/llm-d-kv-cache v0.5.0 h1:XQpkbg1yedGxn2w7QS/v/2YtrOZGp16Sw49KvMlQ1s0=
+github.com/llm-d/llm-d-kv-cache v0.5.0/go.mod h1:XyhzHBYeOWamBMPkuRySB5nJ0zzQpK/mbuXKqJRFT6A=
github.com/maruel/natural v1.1.1 h1:Hja7XhhmvEFhcByqDoHz9QZbkWey+COd9xWfCfn1ioo=
github.com/maruel/natural v1.1.1/go.mod h1:v+Rfd79xlw1AgVBjbO0BEQmptqb5HvL/k9GRHB7ZKEg=
github.com/mfridman/tparse v0.18.0 h1:wh6dzOKaIwkUGyKgOntDW4liXSo37qg5AXbIhkMV3vE=
@@ -210,10 +236,10 @@ github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+
github.com/oklog/ulid v1.3.1 h1:EGfNDEx6MqHz8B3uNV6QAib1UR2Lm97sHi3ocA6ESJ4=
github.com/oklog/ulid/v2 v2.1.1 h1:suPZ4ARWLOJLegGFiZZ1dFAkqzhMjL3J1TzI+5wHz8s=
github.com/oklog/ulid/v2 v2.1.1/go.mod h1:rcEKHmBBKfef9DhnvX7y1HZBYxjXb0cP5ExxNsTT1QQ=
-github.com/onsi/ginkgo/v2 v2.27.4 h1:fcEcQW/A++6aZAZQNUmNjvA9PSOzefMJBerHJ4t8v8Y=
-github.com/onsi/ginkgo/v2 v2.27.4/go.mod h1:ArE1D/XhNXBXCBkKOLkbsb2c81dQHCRcF5zwn/ykDRo=
-github.com/onsi/gomega v1.39.0 h1:y2ROC3hKFmQZJNFeGAMeHZKkjBL65mIZcvrLQBF9k6Q=
-github.com/onsi/gomega v1.39.0/go.mod h1:ZCU1pkQcXDO5Sl9/VVEGlDyp+zm0m1cmeG5TOzLgdh4=
+github.com/onsi/ginkgo/v2 v2.28.1 h1:S4hj+HbZp40fNKuLUQOYLDgZLwNUVn19N3Atb98NCyI=
+github.com/onsi/ginkgo/v2 v2.28.1/go.mod h1:CLtbVInNckU3/+gC8LzkGUb9oF+e8W8TdUsxPwvdOgE=
+github.com/onsi/gomega v1.39.1 h1:1IJLAad4zjPn2PsnhH70V4DKRFlrCzGBNrNaru+Vf28=
+github.com/onsi/gomega v1.39.1/go.mod h1:hL6yVALoTOxeWudERyfppUcZXjMwIMLnuSfruD2lcfg=
github.com/openai/openai-go v1.12.0 h1:NBQCnXzqOTv5wsgNC36PrFEiskGfO5wccfCWDo9S1U0=
github.com/openai/openai-go v1.12.0/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y=
github.com/pebbe/zmq4 v1.4.0 h1:gO5P92Ayl8GXpPZdYcD62Cwbq0slSBVVQRIXwGSJ6eQ=
@@ -239,8 +265,8 @@ github.com/prometheus/otlptranslator v1.0.0 h1:s0LJW/iN9dkIH+EnhiD3BlkkP5QVIUVEo
github.com/prometheus/otlptranslator v1.0.0/go.mod h1:vRYWnXvI6aWGpsdY/mOT/cbeVRBlPWtBNDb7kGR3uKM=
github.com/prometheus/procfs v0.17.0 h1:FuLQ+05u4ZI+SS/w9+BWEM2TXiHKsUQ9TADiRH7DuK0=
github.com/prometheus/procfs v0.17.0/go.mod h1:oPQLaDAMRbA+u8H5Pbfq+dl3VDAvHxMUOVhe0wYB2zw=
-github.com/prometheus/prometheus v0.308.1 h1:ApMNI/3/es3Ze90Z7CMb+wwU2BsSYur0m5VKeqHj7h4=
-github.com/prometheus/prometheus v0.308.1/go.mod h1:aHjYCDz9zKRyoUXvMWvu13K9XHOkBB12XrEqibs3e0A=
+github.com/prometheus/prometheus v0.309.1 h1:jutK6eCYDpWdPTUbVbkcQsNCMO9CCkSwjQRMLds4jSo=
+github.com/prometheus/prometheus v0.309.1/go.mod h1:d+dOGiVhuNDa4MaFXHVdnUBy/CzqlcNTooR8oM1wdTU=
github.com/prometheus/sigv4 v0.3.0 h1:QIG7nTbu0JTnNidGI1Uwl5AGVIChWUACxn2B/BQ1kms=
github.com/prometheus/sigv4 v0.3.0/go.mod h1:fKtFYDus2M43CWKMNtGvFNHGXnAJJEGZbiYCmVp/F8I=
github.com/redis/go-redis/v9 v9.11.0 h1:E3S08Gl/nJNn5vkxd2i78wZxWAPNZgUNTp8WIJUAiIs=
@@ -293,8 +319,8 @@ github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M
github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw=
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
-go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0 h1:RbKq8BG0FI8OiXhBfcRtqqHcZcka+gU3cskNuf05R18=
-go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0/go.mod h1:h06DGIukJOevXaj/xrNjhi/2098RZzcLTbc0jDAUbsg=
+go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0 h1:ssfIgGNANqpVFCndZvcuyKbl0g+UAVcbBcqGkG28H0Y=
+go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0/go.mod h1:GQ/474YrbE4Jx8gZ4q5I4hrhUzM6UPzyrqJYV2AqPoQ=
go.opentelemetry.io/otel v1.39.0 h1:8yPrr/S0ND9QEfTfdP9V+SiwT4E0G7Y5MO7p85nis48=
go.opentelemetry.io/otel v1.39.0/go.mod h1:kLlFTywNWrFyEdH0oj2xK0bFYZtHRYUdv1NklR/tgc8=
go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.39.0 h1:f0cb2XPmrqn4XMy9PNliTgRKJgS5WcL/u0/WRYGz4t0=
@@ -328,20 +354,20 @@ go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
-golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU=
-golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0=
+golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8=
+golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A=
golang.org/x/exp v0.0.0-20250808145144-a408d31f581a h1:Y+7uR/b1Mw2iSXZ3G//1haIiSElDQZ8KWh0h+sZPG90=
golang.org/x/exp v0.0.0-20250808145144-a408d31f581a/go.mod h1:rT6SFzZ7oxADUDx58pcaKFTcZ+inxAa9fTrYx/uVYwg=
golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
-golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk=
-golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc=
+golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c=
+golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
-golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
-golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
+golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o=
+golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8=
golang.org/x/oauth2 v0.34.0 h1:hqK/t4AKgbqWkdkcAeI8XLmbK+4m4G5YeQRrmiotGlw=
golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -352,22 +378,22 @@ golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
-golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
-golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
-golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q=
-golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg=
+golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ=
+golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
+golang.org/x/term v0.39.0 h1:RclSuaJf32jOqZz74CkPA9qFuVTX7vhLlpfj/IGWlqY=
+golang.org/x/term v0.39.0/go.mod h1:yxzUCTP/U+FzoxfdKmLaA0RV1WgE0VY7hXBwKtY/4ww=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
-golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU=
-golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY=
-golang.org/x/time v0.13.0 h1:eUlYslOIt32DgYD6utsuUeHs4d7AsEYLuIAdg7FlYgI=
-golang.org/x/time v0.13.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
+golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE=
+golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8=
+golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI=
+golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE=
golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA=
-golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ=
-golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ=
+golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc=
+golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@@ -376,10 +402,10 @@ gomodules.xyz/jsonpatch/v2 v2.4.0 h1:Ci3iUJyx9UeRx7CeFN8ARgGbkESwJK+KB9lLcWxY/Zw
gomodules.xyz/jsonpatch/v2 v2.4.0/go.mod h1:AH3dM2RI6uoBZxn3LVrfvJ3E0/9dG4cSrbuBJT4moAY=
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
-google.golang.org/api v0.252.0 h1:xfKJeAJaMwb8OC9fesr369rjciQ704AjU/psjkKURSI=
-google.golang.org/api v0.252.0/go.mod h1:dnHOv81x5RAmumZ7BWLShB/u7JZNeyalImxHmtTHxqw=
-google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217 h1:fCvbg86sFXwdrl5LgVcTEvNC+2txB5mgROGmRL5mrls=
-google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217/go.mod h1:+rXWjjaukWZun3mLfjmVnQi18E1AsFbDN9QdJ5YXLto=
+google.golang.org/api v0.257.0 h1:8Y0lzvHlZps53PEaw+G29SsQIkuKrumGWs9puiexNAA=
+google.golang.org/api v0.257.0/go.mod h1:4eJrr+vbVaZSqs7vovFd1Jb/A6ml6iw2e6FBYf3GAO4=
+google.golang.org/genproto/googleapis/api v0.0.0-20251213004720-97cd9d5aeac2 h1:7LRqPCEdE4TP4/9psdaB7F2nhZFfBiGJomA5sojLWdU=
+google.golang.org/genproto/googleapis/api v0.0.0-20251213004720-97cd9d5aeac2/go.mod h1:+rXWjjaukWZun3mLfjmVnQi18E1AsFbDN9QdJ5YXLto=
google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 h1:gRkg/vSppuSQoDjxyiGfN4Upv/h/DQmIR10ZU8dh4Ww=
google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk=
google.golang.org/grpc v1.78.0 h1:K1XZG/yGDJnzMdd/uZHAkVqJE+xIDOcmdSFZkBUicNc=
@@ -412,16 +438,16 @@ k8s.io/klog/v2 v2.130.1 h1:n9Xl7H1Xvksem4KFG4PYbdQCQxqc/tTUyrgXaOhHSzk=
k8s.io/klog/v2 v2.130.1/go.mod h1:3Jpz1GvMt720eyJH1ckRHK1EDfpxISzJ7I9OYgaDtPE=
k8s.io/kube-openapi v0.0.0-20250814151709-d7b6acb124c3 h1:liMHz39T5dJO1aOKHLvwaCjDbf07wVh6yaUlTpunnkE=
k8s.io/kube-openapi v0.0.0-20250814151709-d7b6acb124c3/go.mod h1:UZ2yyWbFTpuhSbFhv24aGNOdoRdJZgsIObGBUaYVsts=
-k8s.io/utils v0.0.0-20250820121507-0af2bda4dd1d h1:wAhiDyZ4Tdtt7e46e9M5ZSAJ/MnPGPs+Ki1gHw4w1R0=
-k8s.io/utils v0.0.0-20250820121507-0af2bda4dd1d/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0=
+k8s.io/utils v0.0.0-20251002143259-bc988d571ff4 h1:SjGebBtkBqHFOli+05xYbK8YF1Dzkbzn+gDM4X9T4Ck=
+k8s.io/utils v0.0.0-20251002143259-bc988d571ff4/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0=
sigs.k8s.io/apiserver-network-proxy/konnectivity-client v0.31.2 h1:jpcvIRr3GLoUoEKRkHKSmGjxb6lWwrBlJsXc+eUYQHM=
sigs.k8s.io/apiserver-network-proxy/konnectivity-client v0.31.2/go.mod h1:Ve9uj1L+deCXFrPOk1LpFXqTg7LCFzFso6PA48q/XZw=
-sigs.k8s.io/controller-runtime v0.22.4 h1:GEjV7KV3TY8e+tJ2LCTxUTanW4z/FmNB7l327UfMq9A=
-sigs.k8s.io/controller-runtime v0.22.4/go.mod h1:+QX1XUpTXN4mLoblf4tqr5CQcyHPAki2HLXqQMY6vh8=
+sigs.k8s.io/controller-runtime v0.22.5 h1:v3nfSUMowX/2WMp27J9slwGFyAt7IV0YwBxAkrUr0GE=
+sigs.k8s.io/controller-runtime v0.22.5/go.mod h1:pc5SoYWnWI6I+cBHYYdZ7B6YHZVY5xNfll88JB+vniI=
sigs.k8s.io/gateway-api v1.4.1 h1:NPxFutNkKNa8UfLd2CMlEuhIPMQgDQ6DXNKG9sHbJU8=
sigs.k8s.io/gateway-api v1.4.1/go.mod h1:AR5RSqciWP98OPckEjOjh2XJhAe2Na4LHyXD2FUY7Qk=
-sigs.k8s.io/gateway-api-inference-extension v1.3.0-rc.2 h1:/UBpLm3Z1HqCTfawkLKwY+PFFGawU55gvjJNJpf6LyM=
-sigs.k8s.io/gateway-api-inference-extension v1.3.0-rc.2/go.mod h1:Cyex0AlEzhuXFklzl0y5Hdf5zVY8PUtSKhzMvHh5D9M=
+sigs.k8s.io/gateway-api-inference-extension v0.0.0-20260128235548-fd30cb97714a h1:Ce5CZ0R3c5H475uEuJ92FMgux3j99wDrSsI4ivTBEXQ=
+sigs.k8s.io/gateway-api-inference-extension v0.0.0-20260128235548-fd30cb97714a/go.mod h1:lvMpB9a+Lk+xBi5Pk6teUG+NqA16WR8nRpmBNFJbflU=
sigs.k8s.io/json v0.0.0-20250730193827-2d320260d730 h1:IpInykpT6ceI+QxKBbEflcR5EXP7sU1kvOlxwZh5txg=
sigs.k8s.io/json v0.0.0-20250730193827-2d320260d730/go.mod h1:mdzfpAEoE6DHQEN0uh9ZbOCuHbLK5wOm7dK4ctXE9Tg=
sigs.k8s.io/kustomize/api v0.21.0 h1:I7nry5p8iDJbuRdYS7ez8MUvw7XVNPcIP5GkzzuXIIQ=
diff --git a/pkg/plugins/datalayer/models/datasource_test.go b/pkg/plugins/datalayer/models/datasource_test.go
new file mode 100644
index 000000000..1c398ad02
--- /dev/null
+++ b/pkg/plugins/datalayer/models/datasource_test.go
@@ -0,0 +1,49 @@
+// Package models
+package models
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ "github.com/stretchr/testify/assert"
+ "k8s.io/apimachinery/pkg/types"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer/http"
+ fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer"
+)
+
+func TestDatasource(t *testing.T) {
+ source := http.NewHTTPDataSource("https", "/models", true, ModelsDataSourceType,
+ "models-data-source", parseModels, ModelsResponseType)
+ extractor, err := NewModelExtractor()
+ assert.Nil(t, err, "failed to create extractor")
+
+ err = source.AddExtractor(extractor)
+ assert.Nil(t, err, "failed to add extractor")
+
+ err = source.AddExtractor(extractor)
+ assert.NotNil(t, err, "expected to fail to add the same extractor twice")
+
+ extractors := source.Extractors()
+ assert.Len(t, extractors, 1)
+ assert.Equal(t, extractor.TypedName().String(), extractors[0])
+
+ err = datalayer.RegisterSource(source)
+ assert.Nil(t, err, "failed to register")
+
+ ctx := context.Background()
+ factory := datalayer.NewEndpointFactory([]fwkdl.DataSource{source}, 100*time.Hour)
+ pod := &fwkdl.EndpointMetadata{
+ NamespacedName: types.NamespacedName{
+ Name: "pod1",
+ Namespace: "default",
+ },
+ Address: "1.2.3.4:5678",
+ }
+ endpoint := factory.NewEndpoint(ctx, pod, nil)
+ assert.NotNil(t, endpoint, "failed to create endpoint")
+
+ err = source.Collect(ctx, endpoint)
+ assert.NotNil(t, err, "expected to fail to collect metrics")
+}
diff --git a/pkg/plugins/datalayer/models/extractor.go b/pkg/plugins/datalayer/models/extractor.go
new file mode 100644
index 000000000..317240ab2
--- /dev/null
+++ b/pkg/plugins/datalayer/models/extractor.go
@@ -0,0 +1,102 @@
+package models
+
+import (
+ "context"
+ "fmt"
+ "reflect"
+ "strings"
+
+ fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer"
+ fwkplugin "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin"
+)
+
+const modelsAttributeKey = "/v1/models"
+
+// ModelInfoCollection defines models' data returned from /v1/models API
+type ModelInfoCollection []ModelInfo
+
+// ModelInfo defines model's data returned from /v1/models API
+type ModelInfo struct {
+ ID string `json:"id"`
+ Parent string `json:"parent,omitempty"`
+}
+
+// String returns a string representation of the model info
+func (m *ModelInfo) String() string {
+ return fmt.Sprintf("%+v", *m)
+}
+
+// Clone returns a full copy of the object
+func (m ModelInfoCollection) Clone() fwkdl.Cloneable {
+ if m == nil {
+ return nil
+ }
+ clone := make([]ModelInfo, len(m))
+ copy(clone, m)
+ return (*ModelInfoCollection)(&clone)
+}
+
+func (m ModelInfoCollection) String() string {
+ if m == nil {
+ return "[]"
+ }
+ parts := make([]string, len(m))
+ for i, p := range m {
+ parts[i] = p.String()
+ }
+ return "[" + strings.Join(parts, ", ") + "]"
+}
+
+// ModelResponse is the response from /v1/models API
+type ModelResponse struct {
+ Object string `json:"object"`
+ Data []ModelInfo `json:"data"`
+}
+
+// ModelsResponseType is the type of models response
+var (
+ ModelsResponseType = reflect.TypeOf(ModelResponse{})
+)
+
+// ModelExtractor implements the models extraction.
+type ModelExtractor struct {
+ typedName fwkplugin.TypedName
+}
+
+// NewModelExtractor returns a new model extractor.
+func NewModelExtractor() (*ModelExtractor, error) {
+ return &ModelExtractor{
+ typedName: fwkplugin.TypedName{
+ Type: ModelsExtractorType,
+ Name: ModelsExtractorType,
+ },
+ }, nil
+}
+
+// TypedName returns the type and name of the ModelExtractor.
+func (me *ModelExtractor) TypedName() fwkplugin.TypedName {
+ return me.typedName
+}
+
+// WithName sets the name of the extractor.
+func (me *ModelExtractor) WithName(name string) *ModelExtractor {
+ me.typedName.Name = name
+ return me
+}
+
+// ExpectedInputType defines the type expected by ModelExtractor.
+func (me *ModelExtractor) ExpectedInputType() reflect.Type {
+ return ModelsResponseType
+}
+
+// Extract transforms the data source output into a concrete attribute that
+// is stored on the given endpoint.
+func (me *ModelExtractor) Extract(_ context.Context, data any, ep fwkdl.Endpoint) error {
+ models, ok := data.(*ModelResponse)
+ if !ok {
+ return fmt.Errorf("unexpected input in Extract: %T", data)
+ }
+
+ ep.GetAttributes().Put(modelsAttributeKey, ModelInfoCollection(models.Data))
+ return nil
+}
diff --git a/pkg/plugins/datalayer/models/extractor_test.go b/pkg/plugins/datalayer/models/extractor_test.go
new file mode 100644
index 000000000..4f075f3d0
--- /dev/null
+++ b/pkg/plugins/datalayer/models/extractor_test.go
@@ -0,0 +1,113 @@
+package models
+
+import (
+ "context"
+ "testing"
+
+ "github.com/google/go-cmp/cmp"
+
+ fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer"
+)
+
+func TestExtractorExtract(t *testing.T) {
+ ctx := context.Background()
+
+ extractor, err := NewModelExtractor()
+ if err != nil {
+ t.Fatalf("failed to create extractor: %v", err)
+ }
+
+ if exType := extractor.TypedName().Type; exType == "" {
+ t.Error("empty extractor type")
+ }
+
+ if exName := extractor.TypedName().Name; exName == "" {
+ t.Error("empty extractor name")
+ }
+
+ if inputType := extractor.ExpectedInputType(); inputType != ModelsResponseType {
+ t.Errorf("incorrect expected input type: %v", inputType)
+ }
+
+ ep := fwkdl.NewEndpoint(nil, nil)
+ if ep == nil {
+ t.Fatal("expected non-nil endpoint")
+ }
+
+ model := "food-review"
+
+ tests := []struct {
+ name string
+ data any
+ wantErr bool
+ updated bool // whether the stored models attribute is expected to change
+ }{
+ {
+ name: "nil data",
+ data: nil,
+ wantErr: true,
+ updated: false,
+ },
+ {
+ name: "empty ModelsResponse",
+ data: &ModelResponse{},
+ wantErr: false,
+ updated: false,
+ },
+ {
+ name: "valid models response",
+ data: &ModelResponse{
+ Object: "list",
+ Data: []ModelInfo{
+ {
+ ID: model,
+ },
+ {
+ ID: "lora1",
+ Parent: model,
+ },
+ },
+ },
+ wantErr: false,
+ updated: true,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ defer func() {
+ if r := recover(); r != nil {
+ t.Errorf("Extract panicked: %v", r)
+ }
+ }()
+
+ attr := ep.GetAttributes()
+ before, ok := attr.Get(modelsAttributeKey)
+ if ok && before != nil {
+ t.Error("expected empty attributes")
+ }
+ err := extractor.Extract(ctx, tt.data, ep)
+ after, ok := attr.Get(modelsAttributeKey)
+ if !ok && tt.updated {
+ t.Error("expected updated attributes")
+ }
+
+ if tt.wantErr && err == nil {
+ t.Errorf("expected error but got nil")
+ }
+ if !tt.wantErr && err != nil {
+ t.Errorf("unexpected error: %v", err)
+ }
+
+ if tt.updated {
+ if diff := cmp.Diff(before, after); diff == "" {
+ t.Errorf("expected models to be updated, but no change detected")
+ }
+ } else {
+ if diff := cmp.Diff(before, after); diff != "" {
+ t.Errorf("expected no models update, but got changes:\n%s", diff)
+ }
+ }
+ })
+ }
+}
diff --git a/pkg/plugins/datalayer/models/factories.go b/pkg/plugins/datalayer/models/factories.go
new file mode 100644
index 000000000..49cd4edad
--- /dev/null
+++ b/pkg/plugins/datalayer/models/factories.go
@@ -0,0 +1,69 @@
+package models
+
+import (
+ "encoding/json"
+ "fmt"
+ "io"
+
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer/http"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin"
+)
+
+const (
+ // ModelsDataSourceType is the models data source type
+ ModelsDataSourceType = "models-data-source"
+ // ModelsExtractorType is the models extractor type
+ ModelsExtractorType = "model-server-protocol-models"
+)
+
+// Configuration parameters for models data source.
+type modelsDatasourceParams struct {
+ // Scheme defines the protocol scheme used in models retrieval (e.g., "http").
+ Scheme string `json:"scheme"`
+ // Path defines the URL path used in models retrieval (e.g., "/v1/models").
+ Path string `json:"path"`
+ // InsecureSkipVerify defines whether model server certificate should be verified or not.
+ InsecureSkipVerify bool `json:"insecureSkipVerify"`
+}
+
+// ModelDataSourceFactory is a factory function used to instantiate data layer's
+// models data source plugins specified in a configuration.
+func ModelDataSourceFactory(name string, parameters json.RawMessage, _ plugin.Handle) (plugin.Plugin, error) {
+ cfg := defaultDataSourceConfigParams()
+ if parameters != nil { // overlay the defaults with configured values
+ if err := json.Unmarshal(parameters, cfg); err != nil {
+ return nil, err
+ }
+ }
+ if cfg.Scheme != "http" && cfg.Scheme != "https" {
+ return nil, fmt.Errorf("unsupported scheme: %s", cfg.Scheme)
+ }
+
+ ds := http.NewHTTPDataSource(cfg.Scheme, cfg.Path, cfg.InsecureSkipVerify, ModelsDataSourceType,
+ name, parseModels, ModelsResponseType)
+ return ds, nil
+}
+
+// ModelServerExtractorFactory is a factory function used to instantiate data layer's models
+// extractor plugins specified in a configuration.
+func ModelServerExtractorFactory(name string, _ json.RawMessage, _ plugin.Handle) (plugin.Plugin, error) {
+ extractor, err := NewModelExtractor()
+ if err != nil {
+ return nil, err
+ }
+ return extractor.WithName(name), nil
+}
+
+func defaultDataSourceConfigParams() *modelsDatasourceParams {
+ return &modelsDatasourceParams{Scheme: "http", Path: "/v1/models", InsecureSkipVerify: true}
+}
+
+func parseModels(data io.Reader) (any, error) {
+ body, err := io.ReadAll(data)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read response body: %v", err)
+ }
+ var modelsResponse ModelResponse
+ err = json.Unmarshal(body, &modelsResponse)
+ return &modelsResponse, err
+}
diff --git a/pkg/plugins/filter/by_label.go b/pkg/plugins/filter/by_label.go
index 070b6a627..464bc81a7 100644
--- a/pkg/plugins/filter/by_label.go
+++ b/pkg/plugins/filter/by_label.go
@@ -5,9 +5,8 @@ import (
"encoding/json"
"fmt"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling"
)
const (
@@ -21,10 +20,10 @@ type byLabelParameters struct {
AllowsNoLabel bool `json:"allowsNoLabel"`
}
-var _ framework.Filter = &ByLabel{} // validate interface conformance
+var _ scheduling.Filter = &ByLabel{} // validate interface conformance
// ByLabelFactory defines the factory function for the ByLabel filter.
-func ByLabelFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
+func ByLabelFactory(name string, rawParameters json.RawMessage, _ plugin.Handle) (plugin.Plugin, error) {
parameters := byLabelParameters{}
if rawParameters != nil {
if err := json.Unmarshal(rawParameters, ¶meters); err != nil {
@@ -47,7 +46,7 @@ func ByLabelFactory(name string, rawParameters json.RawMessage, _ plugins.Handle
// NewByLabel creates and returns an instance of the RoleBasedFilter based on the input parameters
// name - the filter name
// labelName - the name of the label to use
-// allowsNoLabel - if true pods without given label will be considered as valid (not filtered out)
+// allowsNoLabel - if true endpoints without given label will be considered as valid (not filtered out)
// validValuesApp - list of valid values
func NewByLabel(name string, labelName string, allowsNoLabel bool, validValues ...string) *ByLabel {
validValuesMap := map[string]struct{}{}
@@ -57,27 +56,27 @@ func NewByLabel(name string, labelName string, allowsNoLabel bool, validValues .
}
return &ByLabel{
- typedName: plugins.TypedName{Type: ByLabelType, Name: name},
+ typedName: plugin.TypedName{Type: ByLabelType, Name: name},
labelName: labelName,
allowsNoLabel: allowsNoLabel,
validValues: validValuesMap,
}
}
-// ByLabel - filters out pods based on the values defined by the given label
+// ByLabel - filters out endpoints based on the values defined by the given label
type ByLabel struct {
// name defines the filter typed name
- typedName plugins.TypedName
+ typedName plugin.TypedName
// labelName defines the name of the label to be checked
labelName string
// validValues defines list of valid label values
validValues map[string]struct{}
- // allowsNoLabel - if true pods without given label will be considered as valid (not filtered out)
+ // allowsNoLabel - if true endpoints without given label will be considered as valid (not filtered out)
allowsNoLabel bool
}
// TypedName returns the typed name of the plugin
-func (f *ByLabel) TypedName() plugins.TypedName {
+func (f *ByLabel) TypedName() plugin.TypedName {
return f.typedName
}
@@ -87,19 +86,19 @@ func (f *ByLabel) WithName(name string) *ByLabel {
return f
}
-// Filter filters out all pods that are not marked with one of roles from the validRoles collection
+// Filter filters out all endpoints that are not marked with one of roles from the validRoles collection
// or has no role label in case allowsNoRolesLabel is true
-func (f *ByLabel) Filter(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) []types.Pod {
- filteredPods := []types.Pod{}
+func (f *ByLabel) Filter(_ context.Context, _ *scheduling.CycleState, _ *scheduling.LLMRequest, endpoints []scheduling.Endpoint) []scheduling.Endpoint {
+ filteredEndpoints := []scheduling.Endpoint{}
- for _, pod := range pods {
- val, labelDefined := pod.GetPod().Labels[f.labelName]
+ for _, endpoint := range endpoints {
+ val, labelDefined := endpoint.GetMetadata().Labels[f.labelName]
_, valueExists := f.validValues[val]
if (!labelDefined && f.allowsNoLabel) || valueExists {
- filteredPods = append(filteredPods, pod)
+ filteredEndpoints = append(filteredEndpoints, endpoint)
}
}
- return filteredPods
+ return filteredEndpoints
}
diff --git a/pkg/plugins/filter/by_label_selector.go b/pkg/plugins/filter/by_label_selector.go
index 98b95d418..ceb53a0e3 100644
--- a/pkg/plugins/filter/by_label_selector.go
+++ b/pkg/plugins/filter/by_label_selector.go
@@ -8,9 +8,8 @@ import (
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/labels"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling"
)
const (
@@ -19,10 +18,10 @@ const (
)
// compile-time type assertion
-var _ framework.Filter = &ByLabelSelector{}
+var _ scheduling.Filter = &ByLabelSelector{}
// ByLabelSelectorFactory defines the factory function for the ByLabelSelector filter
-func ByLabelSelectorFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
+func ByLabelSelectorFactory(name string, rawParameters json.RawMessage, _ plugin.Handle) (plugin.Plugin, error) {
parameters := metav1.LabelSelector{}
if rawParameters != nil {
if err := json.Unmarshal(rawParameters, ¶meters); err != nil {
@@ -44,30 +43,30 @@ func NewByLabelSelector(name string, selector *metav1.LabelSelector) (*ByLabelSe
}
return &ByLabelSelector{
- typedName: plugins.TypedName{Type: ByLabelSelectorType, Name: name},
+ typedName: plugin.TypedName{Type: ByLabelSelectorType, Name: name},
selector: labelSelector,
}, nil
}
-// ByLabelSelector filters out pods that do not match its label selector criteria
+// ByLabelSelector filters out endpoints that do not match its label selector criteria
type ByLabelSelector struct {
- typedName plugins.TypedName
+ typedName plugin.TypedName
selector labels.Selector
}
// TypedName returns the typed name of the plugin
-func (blf *ByLabelSelector) TypedName() plugins.TypedName {
+func (blf *ByLabelSelector) TypedName() plugin.TypedName {
return blf.typedName
}
-// Filter filters out all pods that do not satisfy the label selector
-func (blf *ByLabelSelector) Filter(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) []types.Pod {
- filtered := []types.Pod{}
+// Filter filters out all endpoints that do not satisfy the label selector
+func (blf *ByLabelSelector) Filter(_ context.Context, _ *scheduling.CycleState, _ *scheduling.LLMRequest, endpoints []scheduling.Endpoint) []scheduling.Endpoint {
+ filtered := []scheduling.Endpoint{}
- for _, pod := range pods {
- labels := labels.Set(pod.GetPod().Labels)
+ for _, endpoint := range endpoints {
+ labels := labels.Set(endpoint.GetMetadata().Labels)
if blf.selector.Matches(labels) {
- filtered = append(filtered, pod)
+ filtered = append(filtered, endpoint)
}
}
return filtered
diff --git a/pkg/plugins/filter/by_label_selector_test.go b/pkg/plugins/filter/by_label_selector_test.go
index 3dc16e9d0..0d7947f09 100644
--- a/pkg/plugins/filter/by_label_selector_test.go
+++ b/pkg/plugins/filter/by_label_selector_test.go
@@ -9,9 +9,8 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
k8stypes "k8s.io/apimachinery/pkg/types"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
- backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
+ fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling"
"github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/filter"
"github.com/llm-d/llm-d-inference-scheduler/test/utils"
@@ -147,35 +146,35 @@ func TestByLabelSelectorFactoryWithInvalidJSON(t *testing.T) {
}
func TestByLabelSelectorFiltering(t *testing.T) {
- pods := []types.Pod{
- createPod(k8stypes.NamespacedName{Namespace: "default", Name: "nginx-1"},
+ endpoints := []scheduling.Endpoint{
+ createEndpoint(k8stypes.NamespacedName{Namespace: "default", Name: "nginx-1"},
"10.0.0.1",
map[string]string{
"app": "nginx",
"version": "v1.0",
"tier": "frontend",
}),
- createPod(k8stypes.NamespacedName{Namespace: "default", Name: "nginx-2"},
+ createEndpoint(k8stypes.NamespacedName{Namespace: "default", Name: "nginx-2"},
"10.0.0.2",
map[string]string{
"app": "nginx",
"version": "v1.1",
"tier": "frontend",
}),
- createPod(k8stypes.NamespacedName{Namespace: "kube-system", Name: "coredns-1"},
+ createEndpoint(k8stypes.NamespacedName{Namespace: "kube-system", Name: "coredns-1"},
"10.0.0.3",
map[string]string{
"app": "coredns",
"tier": "system",
}),
- createPod(k8stypes.NamespacedName{Namespace: "default", Name: "redis-1"},
+ createEndpoint(k8stypes.NamespacedName{Namespace: "default", Name: "redis-1"},
"10.0.0.4",
map[string]string{
"app": "redis",
"tier": "backend",
"deprecated": "true",
}),
- createPod(k8stypes.NamespacedName{Namespace: "default", Name: "web-1"},
+ createEndpoint(k8stypes.NamespacedName{Namespace: "default", Name: "web-1"},
"10.0.0.5",
map[string]string{
"app": "web",
@@ -301,17 +300,17 @@ func TestByLabelSelectorFiltering(t *testing.T) {
ctx := utils.NewTestContext(t)
- filteredPods := blf.Filter(ctx, nil, nil, pods)
+ filteredEndpoints := blf.Filter(ctx, nil, nil, endpoints)
- var actualPodNames []string
- for _, pod := range filteredPods {
- actualPodNames = append(actualPodNames, pod.GetPod().NamespacedName.Name)
+ var actualEndpointNames []string
+ for _, endpoint := range filteredEndpoints {
+ actualEndpointNames = append(actualEndpointNames, endpoint.GetMetadata().NamespacedName.Name)
}
- assert.ElementsMatch(t, tt.expectedPods, actualPodNames,
- "filtered pods should match expected pods")
- assert.Len(t, filteredPods, len(tt.expectedPods),
- "filtered pods count should match expected count")
+ assert.ElementsMatch(t, tt.expectedPods, actualEndpointNames,
+ "filtered endpoints should match expected endpoints")
+ assert.Len(t, filteredEndpoints, len(tt.expectedPods),
+ "filtered endpoints count should match expected count")
})
}
}
@@ -326,26 +325,26 @@ func TestByLabelSelectorFilterEdgeCases(t *testing.T) {
ctx := utils.NewTestContext(t)
- t.Run("empty pods slice", func(t *testing.T) {
- result := blf.Filter(ctx, nil, nil, []types.Pod{})
+ t.Run("empty endpoints slice", func(t *testing.T) {
+ result := blf.Filter(ctx, nil, nil, []scheduling.Endpoint{})
assert.Empty(t, result)
})
- t.Run("nil pods slice", func(t *testing.T) {
+ t.Run("nil endpoints slice", func(t *testing.T) {
result := blf.Filter(ctx, nil, nil, nil)
assert.Empty(t, result)
})
- t.Run("pods with nil labels", func(t *testing.T) {
- pods := []types.Pod{createPod(k8stypes.NamespacedName{Name: "pod-1"}, "10.0.0.1", nil)}
- result := blf.Filter(ctx, nil, nil, pods)
- assert.Empty(t, result, "pod with nil labels should not match")
+ t.Run("endpoints with nil labels", func(t *testing.T) {
+ endpoints := []scheduling.Endpoint{createEndpoint(k8stypes.NamespacedName{Name: "pod-1"}, "10.0.0.1", nil)}
+ result := blf.Filter(ctx, nil, nil, endpoints)
+ assert.Empty(t, result, "endpoint with nil labels should not match")
})
- t.Run("pods with empty labels", func(t *testing.T) {
- pods := []types.Pod{createPod(k8stypes.NamespacedName{Name: "pod-1"}, "10.0.0.1", map[string]string{})}
- result := blf.Filter(ctx, nil, nil, pods)
- assert.Empty(t, result, "pod with empty labels should not match")
+ t.Run("endpoints with empty labels", func(t *testing.T) {
+ endpoints := []scheduling.Endpoint{createEndpoint(k8stypes.NamespacedName{Name: "pod-1"}, "10.0.0.1", map[string]string{})}
+ result := blf.Filter(ctx, nil, nil, endpoints)
+ assert.Empty(t, result, "endpoint with empty labels should not match")
})
}
@@ -371,7 +370,7 @@ func ExamplePrefillDecodeRolesInLWS() {
plugin, _ = filter.ByLabelSelectorFactory("prefill-role", prefillWorkerJSON, nil)
prefillworker, _ := plugin.(*filter.ByLabelSelector)
- pods := []types.Pod{createPod(k8stypes.NamespacedName{Namespace: "default", Name: "vllm"},
+ endpoints := []scheduling.Endpoint{createEndpoint(k8stypes.NamespacedName{Namespace: "default", Name: "vllm"},
"10.0.0.1",
map[string]string{
"app.kubernetes.io/component": "vllm-worker",
@@ -383,7 +382,7 @@ func ExamplePrefillDecodeRolesInLWS() {
name := ""
for _, blf := range []*filter.ByLabelSelector{decodeLeader, decodeFollower, prefillworker} {
- filtered := PrefillDecodeRolesInLWS(blf, pods)
+ filtered := PrefillDecodeRolesInLWS(blf, endpoints)
if len(filtered) > 0 {
name = blf.TypedName().Name
}
@@ -395,17 +394,18 @@ func ExamplePrefillDecodeRolesInLWS() {
}
// Helper functions
-func createPod(nsn k8stypes.NamespacedName, ipaddr string, labels map[string]string) types.Pod {
- return &types.PodMetrics{
- Pod: &backend.Pod{
+func createEndpoint(nsn k8stypes.NamespacedName, ipaddr string, labels map[string]string) scheduling.Endpoint {
+ return scheduling.NewEndpoint(
+ &fwkdl.EndpointMetadata{
NamespacedName: nsn,
Address: ipaddr,
Labels: labels,
},
- MetricsState: &backendmetrics.MetricsState{},
- }
+ &fwkdl.Metrics{},
+ nil,
+ )
}
-func PrefillDecodeRolesInLWS(blf *filter.ByLabelSelector, pods []types.Pod) []types.Pod {
- return blf.Filter(context.Background(), nil, nil, pods)
+func PrefillDecodeRolesInLWS(blf *filter.ByLabelSelector, endpoints []scheduling.Endpoint) []scheduling.Endpoint {
+ return blf.Filter(context.Background(), nil, nil, endpoints)
}
diff --git a/pkg/plugins/filter/by_label_test.go b/pkg/plugins/filter/by_label_test.go
index a3af75a4b..933871b61 100644
--- a/pkg/plugins/filter/by_label_test.go
+++ b/pkg/plugins/filter/by_label_test.go
@@ -8,9 +8,8 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
k8stypes "k8s.io/apimachinery/pkg/types"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
- backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
+ fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling"
"github.com/llm-d/llm-d-inference-scheduler/test/utils"
)
@@ -139,54 +138,55 @@ func TestByLabelFactoryInvalidJSON(t *testing.T) {
}
// Helper functions
-func createPod(nsn k8stypes.NamespacedName, ipaddr string, labels map[string]string) types.Pod {
- return &types.PodMetrics{
- Pod: &backend.Pod{
+func createEndpoint(nsn k8stypes.NamespacedName, ipaddr string, labels map[string]string) scheduling.Endpoint {
+ return scheduling.NewEndpoint(
+ &fwkdl.EndpointMetadata{
NamespacedName: nsn,
Address: ipaddr,
Labels: labels,
},
- MetricsState: &backendmetrics.MetricsState{},
- }
+ &fwkdl.Metrics{},
+ nil,
+ )
}
func TestByLabelFiltering(t *testing.T) {
- pods := []types.Pod{
- createPod(k8stypes.NamespacedName{Namespace: "default", Name: "nginx-1"},
+ endpoints := []scheduling.Endpoint{
+ createEndpoint(k8stypes.NamespacedName{Namespace: "default", Name: "nginx-1"},
"10.0.0.1",
map[string]string{
"app": "nginx",
"version": "v1.0",
"tier": "frontend",
}),
- createPod(k8stypes.NamespacedName{Namespace: "default", Name: "nginx-2"},
+ createEndpoint(k8stypes.NamespacedName{Namespace: "default", Name: "nginx-2"},
"10.0.0.2",
map[string]string{
"app": "nginx",
"version": "v1.1",
"tier": "frontend",
}),
- createPod(k8stypes.NamespacedName{Namespace: "kube-system", Name: "coredns-1"},
+ createEndpoint(k8stypes.NamespacedName{Namespace: "kube-system", Name: "coredns-1"},
"10.0.0.3",
map[string]string{
"app": "coredns",
"tier": "system",
}),
- createPod(k8stypes.NamespacedName{Namespace: "default", Name: "redis-1"},
+ createEndpoint(k8stypes.NamespacedName{Namespace: "default", Name: "redis-1"},
"10.0.0.4",
map[string]string{
"app": "redis",
"tier": "backend",
"deprecated": "true",
}),
- createPod(k8stypes.NamespacedName{Namespace: "default", Name: "web-1"},
+ createEndpoint(k8stypes.NamespacedName{Namespace: "default", Name: "web-1"},
"10.0.0.5",
map[string]string{
"app": "web",
"tier": "frontend",
"environment": "production",
}),
- createPod(k8stypes.NamespacedName{Namespace: "default", Name: "no-tier-pod"},
+ createEndpoint(k8stypes.NamespacedName{Namespace: "default", Name: "no-tier-pod"},
"10.0.0.6",
map[string]string{
"app": "unknown",
@@ -247,17 +247,17 @@ func TestByLabelFiltering(t *testing.T) {
ctx := utils.NewTestContext(t)
- filteredPods := blf.Filter(ctx, nil, nil, pods)
+ filteredEndpoints := blf.Filter(ctx, nil, nil, endpoints)
- var actualPodNames []string
- for _, pod := range filteredPods {
- actualPodNames = append(actualPodNames, pod.GetPod().NamespacedName.Name)
+ var actualEndpointNames []string
+ for _, endpoint := range filteredEndpoints {
+ actualEndpointNames = append(actualEndpointNames, endpoint.GetMetadata().NamespacedName.Name)
}
- assert.ElementsMatch(t, tt.expectedPods, actualPodNames,
- "filtered pods should match expected pods")
- assert.Len(t, filteredPods, len(tt.expectedPods),
- "filtered pods count should match expected count")
+ assert.ElementsMatch(t, tt.expectedPods, actualEndpointNames,
+ "filtered endpoints should match expected endpoints")
+ assert.Len(t, filteredEndpoints, len(tt.expectedPods),
+ "filtered endpoints count should match expected count")
})
}
}
diff --git a/pkg/plugins/filter/pd_role.go b/pkg/plugins/filter/pd_role.go
index cc4cf7448..da96e7893 100644
--- a/pkg/plugins/filter/pd_role.go
+++ b/pkg/plugins/filter/pd_role.go
@@ -3,7 +3,7 @@ package filter
import (
"encoding/json"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin"
)
const (
@@ -23,7 +23,7 @@ const (
)
// PrefillRoleFactory defines the factory function for the Prefill filter.
-func PrefillRoleFactory(name string, _ json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
+func PrefillRoleFactory(name string, _ json.RawMessage, _ plugin.Handle) (plugin.Plugin, error) {
return NewPrefillRole().WithName(name), nil
}
@@ -33,7 +33,7 @@ func NewPrefillRole() *ByLabel {
}
// DecodeRoleFactory defines the factory function for the Decode filter.
-func DecodeRoleFactory(name string, _ json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
+func DecodeRoleFactory(name string, _ json.RawMessage, _ plugin.Handle) (plugin.Plugin, error) {
return NewDecodeRole().WithName(name), nil
}
diff --git a/pkg/plugins/pre-request/pd_prerequest.go b/pkg/plugins/pre-request/pd_prerequest.go
index beebbe46c..c77fc700f 100644
--- a/pkg/plugins/pre-request/pd_prerequest.go
+++ b/pkg/plugins/pre-request/pd_prerequest.go
@@ -7,9 +7,9 @@ import (
"fmt"
"net"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/requestcontrol"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling"
"github.com/llm-d/llm-d-inference-scheduler/pkg/common"
)
@@ -29,7 +29,7 @@ type prefillHeaderHandlerParameters struct {
var _ requestcontrol.PreRequest = &PrefillHeaderHandler{}
// PrefillHeaderHandlerFactory defines the factory function for the PrefillHeaderHandler
-func PrefillHeaderHandlerFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
+func PrefillHeaderHandlerFactory(name string, rawParameters json.RawMessage, _ plugin.Handle) (plugin.Plugin, error) {
parameters := prefillHeaderHandlerParameters{
PrefillProfile: defaultPrefillProfile,
}
@@ -44,19 +44,19 @@ func PrefillHeaderHandlerFactory(name string, rawParameters json.RawMessage, _ p
// NewPrefillHeaderHandler initializes a new PrefillHeaderHandler and returns its pointer.
func NewPrefillHeaderHandler(prefillProfile string) *PrefillHeaderHandler {
return &PrefillHeaderHandler{
- typedName: plugins.TypedName{Type: PrefillHeaderHandlerType},
+ typedName: plugin.TypedName{Type: PrefillHeaderHandlerType},
prefillProfile: prefillProfile,
}
}
// PrefillHeaderHandler PreRequest plugin
type PrefillHeaderHandler struct {
- typedName plugins.TypedName
+ typedName plugin.TypedName
prefillProfile string
}
// TypedName returns the typed name of the plugin.
-func (p *PrefillHeaderHandler) TypedName() plugins.TypedName {
+func (p *PrefillHeaderHandler) TypedName() plugin.TypedName {
return p.typedName
}
@@ -67,7 +67,7 @@ func (p *PrefillHeaderHandler) WithName(name string) *PrefillHeaderHandler {
}
// PreRequest wires prefill SchedulerProfile result into a header to indicate prefill worker
-func (p *PrefillHeaderHandler) PreRequest(_ context.Context, request *types.LLMRequest, schedulingResult *types.SchedulingResult) {
+func (p *PrefillHeaderHandler) PreRequest(_ context.Context, request *scheduling.LLMRequest, schedulingResult *scheduling.SchedulingResult) {
if _, found := request.Headers[common.PrefillPodHeader]; found {
request.Headers[common.PrefillPodHeader] = "" // clear header, if already set
}
@@ -77,7 +77,7 @@ func (p *PrefillHeaderHandler) PreRequest(_ context.Context, request *types.LLMR
return // prefill profile failed to run or we chose not to run it, no-op in this case
}
- targetPod := prefillProfileRunResult.TargetPods[0].GetPod()
+ targetPod := prefillProfileRunResult.TargetEndpoints[0].GetMetadata()
prefillHostPort := net.JoinHostPort(targetPod.Address, targetPod.Port)
request.Headers[common.PrefillPodHeader] = prefillHostPort // in the form of
}
diff --git a/pkg/plugins/profile/always_disagg_decider.go b/pkg/plugins/profile/always_disagg_decider.go
new file mode 100644
index 000000000..1fefd47fe
--- /dev/null
+++ b/pkg/plugins/profile/always_disagg_decider.go
@@ -0,0 +1,48 @@
+package profile
+
+import (
+ "context"
+ "encoding/json"
+
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling"
+)
+
+const (
+ // AlwaysDisaggDeciderPluginType is the type-name of the AlwaysDisaggPDDecider plugin.
+ AlwaysDisaggDeciderPluginType = "always-disagg-pd-decider"
+)
+
+// compile-time type assertion
+var _ pdDeciderPlugin = &AlwaysDisaggPDDecider{}
+
+// AlwaysDisaggPDDecider is a PD decider plugin that always decides to disaggregate PD
+type AlwaysDisaggPDDecider struct {
+ typedName plugin.TypedName
+}
+
+// AlwaysDisaggPDDeciderPluginFactory defines the factory function for creating
+// a new instance of the AlwaysDisaggPDDecider.
+func AlwaysDisaggPDDeciderPluginFactory(name string, _ json.RawMessage,
+ _ plugin.Handle) (plugin.Plugin, error) {
+ return newAlwaysDisaggPDDecider().WithName(name), nil
+}
+
+func newAlwaysDisaggPDDecider() *AlwaysDisaggPDDecider {
+ return &AlwaysDisaggPDDecider{}
+}
+
+// TypedName returns the typed name of the plugin.
+func (d *AlwaysDisaggPDDecider) TypedName() plugin.TypedName {
+ return d.typedName
+}
+
+// WithName sets the name of the plugin.
+func (d *AlwaysDisaggPDDecider) WithName(name string) *AlwaysDisaggPDDecider {
+ d.typedName.Name = name
+ return d
+}
+
+func (d *AlwaysDisaggPDDecider) disaggregate(ctx context.Context, inputTokens int, endpoint scheduling.Endpoint) bool {
+ return true
+}
diff --git a/pkg/plugins/profile/dp_profile_handler.go b/pkg/plugins/profile/dp_profile_handler.go
index 5d71384ca..836872584 100644
--- a/pkg/plugins/profile/dp_profile_handler.go
+++ b/pkg/plugins/profile/dp_profile_handler.go
@@ -8,9 +8,9 @@ import (
"net"
"strconv"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
+ "sigs.k8s.io/controller-runtime/pkg/log"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling"
"github.com/llm-d/llm-d-inference-scheduler/pkg/common"
)
@@ -25,10 +25,10 @@ type dataParallelProfileHandlerParameters struct {
}
// compile-time type assertion
-var _ framework.ProfileHandler = &DataParallelProfileHandler{}
+var _ scheduling.ProfileHandler = &DataParallelProfileHandler{}
// DataParallelProfileHandlerFactory defines the factory function for the DataParallelProfileHandler
-func DataParallelProfileHandlerFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
+func DataParallelProfileHandlerFactory(name string, rawParameters json.RawMessage, _ plugin.Handle) (plugin.Plugin, error) {
parameters := dataParallelProfileHandlerParameters{
PrimaryPort: 8000,
}
@@ -50,19 +50,19 @@ func DataParallelProfileHandlerFactory(name string, rawParameters json.RawMessag
// NewDataParallelProfileHandler initializes a new PdProfileHandler and returns its pointer.
func NewDataParallelProfileHandler(primaryPort int) *DataParallelProfileHandler {
return &DataParallelProfileHandler{
- typedName: plugins.TypedName{Type: DataParallelProfileHandlerType},
+ typedName: plugin.TypedName{Type: DataParallelProfileHandlerType},
primaryPort: strconv.Itoa(primaryPort),
}
}
// DataParallelProfileHandler handles scheduler profiles for Data Parallel.
type DataParallelProfileHandler struct {
- typedName plugins.TypedName
+ typedName plugin.TypedName
primaryPort string
}
// TypedName returns the typed name of the plugin.
-func (h *DataParallelProfileHandler) TypedName() plugins.TypedName {
+func (h *DataParallelProfileHandler) TypedName() plugin.TypedName {
return h.typedName
}
@@ -74,12 +74,19 @@ func (h *DataParallelProfileHandler) WithName(name string) *DataParallelProfileH
// Pick selects the SchedulingProfiles to run from the list of candidate profiles, while taking into consideration the request properties and the
// previously executed cycles along with their results.
-func (h *DataParallelProfileHandler) Pick(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, profiles map[string]*framework.SchedulerProfile,
- profileResults map[string]*types.ProfileRunResult) map[string]*framework.SchedulerProfile {
+func (h *DataParallelProfileHandler) Pick(ctx context.Context, _ *scheduling.CycleState, _ *scheduling.LLMRequest, profiles map[string]scheduling.SchedulerProfile,
+ profileResults map[string]*scheduling.ProfileRunResult) map[string]scheduling.SchedulerProfile {
if len(profiles) == len(profileResults) { // all profiles have been executed already in previous call
- return map[string]*framework.SchedulerProfile{}
+ return map[string]scheduling.SchedulerProfile{}
}
- // return all profiles
+ // Validate that only one profile is configured for Data Parallel mode
+ if len(profiles) != 1 {
+ log.FromContext(ctx).Error(nil, "Data Parallel profile handler requires exactly one scheduling profile",
+ "profileCount", len(profiles),
+ )
+ return map[string]scheduling.SchedulerProfile{} // return empty map for fast exit in later steps
+ }
+ // return only one profile
return profiles
}
@@ -87,8 +94,8 @@ func (h *DataParallelProfileHandler) Pick(_ context.Context, _ *types.CycleState
// It may aggregate results, log test profile outputs, or apply custom logic. It specifies in the SchedulingResult the
// key of the primary profile that should be used to get the request selected destination.
// When a profile run fails, its result in the profileResults map is nil.
-func (h *DataParallelProfileHandler) ProcessResults(_ context.Context, _ *types.CycleState, request *types.LLMRequest,
- profileResults map[string]*types.ProfileRunResult) (*types.SchedulingResult, error) {
+func (h *DataParallelProfileHandler) ProcessResults(_ context.Context, _ *scheduling.CycleState, request *scheduling.LLMRequest,
+ profileResults map[string]*scheduling.ProfileRunResult) (*scheduling.SchedulingResult, error) {
if len(profileResults) != 1 {
return nil, errors.New("data parallel profile handler is intended to be used with a single profile, failed to process multiple profiles")
}
@@ -104,23 +111,23 @@ func (h *DataParallelProfileHandler) ProcessResults(_ context.Context, _ *types.
return nil, fmt.Errorf("failed to run scheduler profile '%s'", singleProfileName)
}
- newResult := types.ProfileRunResult{
- TargetPods: []types.Pod{},
+ newResult := scheduling.ProfileRunResult{
+ TargetEndpoints: []scheduling.Endpoint{},
}
- targetPod := profileResult.TargetPods[0].GetPod()
+ targetPod := profileResult.TargetEndpoints[0].GetMetadata()
request.Headers[common.DataParallelPodHeader] = net.JoinHostPort(targetPod.Address, targetPod.Port)
- for _, target := range profileResult.TargetPods {
- newPodInfo := target.GetPod().Clone()
- newPodInfo.Port = h.primaryPort
- targetPod := &types.PodMetrics{Pod: newPodInfo, MetricsState: target.GetMetrics().Clone()}
- newResult.TargetPods = append(newResult.TargetPods, targetPod)
+ for _, target := range profileResult.TargetEndpoints {
+ newMetadata := target.GetMetadata().Clone()
+ newMetadata.Port = h.primaryPort
+ targetEndpoint := scheduling.NewEndpoint(newMetadata, target.GetMetrics().Clone(), nil)
+ newResult.TargetEndpoints = append(newResult.TargetEndpoints, targetEndpoint)
}
- modifiedResults := map[string]*types.ProfileRunResult{singleProfileName: &newResult}
+ modifiedResults := map[string]*scheduling.ProfileRunResult{singleProfileName: &newResult}
- return &types.SchedulingResult{
+ return &scheduling.SchedulingResult{
ProfileResults: modifiedResults,
PrimaryProfileName: singleProfileName,
}, nil
diff --git a/pkg/plugins/profile/dp_profile_handler_test.go b/pkg/plugins/profile/dp_profile_handler_test.go
index 30a7a6da4..9bb942327 100644
--- a/pkg/plugins/profile/dp_profile_handler_test.go
+++ b/pkg/plugins/profile/dp_profile_handler_test.go
@@ -8,7 +8,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling"
"github.com/llm-d/llm-d-inference-scheduler/pkg/common"
)
@@ -120,17 +120,86 @@ func TestDataParallelProfileHandlerFactoryInvalidJSON(t *testing.T) {
}
}
+func Test_DataParallelProfileHandler_Pick(t *testing.T) {
+ tests := []struct {
+ name string
+ profiles map[string]scheduling.SchedulerProfile
+ profileResults map[string]*scheduling.ProfileRunResult
+ expectEmptyResult bool
+ expectLogError bool
+ description string
+ }{
+ {
+ name: "success: single profile, first call",
+ profiles: map[string]scheduling.SchedulerProfile{
+ "default": newMockSchedulerProfile(),
+ },
+ profileResults: map[string]*scheduling.ProfileRunResult{},
+ expectEmptyResult: false,
+ expectLogError: false,
+ description: "Should return the single profile to run",
+ },
+ {
+ name: "success: single profile, second call (all already executed)",
+ profiles: map[string]scheduling.SchedulerProfile{
+ "default": newMockSchedulerProfile(),
+ },
+ profileResults: map[string]*scheduling.ProfileRunResult{
+ "default": newMockProfileRunResult(DefaultTestPodPort, "pod1"),
+ },
+ expectEmptyResult: true,
+ expectLogError: false,
+ description: "Should return empty map since all profiles have been executed already in previous call",
+ },
+ {
+ name: "error: multiple profiles configured in EPP",
+ profiles: map[string]scheduling.SchedulerProfile{
+ "profile1": newMockSchedulerProfile(),
+ "profile2": newMockSchedulerProfile(),
+ },
+ profileResults: map[string]*scheduling.ProfileRunResult{},
+ expectEmptyResult: true,
+ expectLogError: true,
+ description: "Should return empty map and log error for multiple profiles",
+ },
+ {
+ name: "error: zero profiles configured in EPP",
+ profiles: map[string]scheduling.SchedulerProfile{},
+ profileResults: map[string]*scheduling.ProfileRunResult{},
+ expectEmptyResult: true,
+ expectLogError: true,
+ description: "Should return empty map and log error for zero profiles",
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ handler := NewDataParallelProfileHandler(8000).WithName("test-handler")
+ ctx := context.Background()
+
+ result := handler.Pick(ctx, &scheduling.CycleState{}, &scheduling.LLMRequest{}, tt.profiles, tt.profileResults)
+
+ if tt.expectEmptyResult {
+ assert.Empty(t, result, tt.description)
+ } else {
+ assert.NotEmpty(t, result, tt.description)
+ assert.Equal(t, len(tt.profiles), len(result), "Should return all profiles when valid")
+ }
+ })
+ }
+}
+
func Test_DataParallelProfileHandler_ProcessResults(t *testing.T) {
tests := []struct {
name string
primaryPort int
- profileResults map[string]*types.ProfileRunResult
+ profileResults map[string]*scheduling.ProfileRunResult
expectError bool
- checkResult func(*testing.T, *types.SchedulingResult, map[string]string)
+ checkResult func(*testing.T, *scheduling.SchedulingResult, map[string]string)
}{
{
name: "error: multiple profiles not supported",
- profileResults: map[string]*types.ProfileRunResult{
+ profileResults: map[string]*scheduling.ProfileRunResult{
"profile1": newMockProfileRunResult(DefaultTestPodPort, "pod1"),
"profile2": newMockProfileRunResult(DefaultTestPodPort, "pod2"),
},
@@ -138,7 +207,7 @@ func Test_DataParallelProfileHandler_ProcessResults(t *testing.T) {
},
{
name: "error: single profile but result is nil",
- profileResults: map[string]*types.ProfileRunResult{
+ profileResults: map[string]*scheduling.ProfileRunResult{
"nil-profile": nil,
},
expectError: true,
@@ -146,16 +215,16 @@ func Test_DataParallelProfileHandler_ProcessResults(t *testing.T) {
{
name: "success: single profile with primaryPort β port overridden, header set",
primaryPort: 9000,
- profileResults: map[string]*types.ProfileRunResult{
+ profileResults: map[string]*scheduling.ProfileRunResult{
"dp-profile": newMockProfileRunResult(DefaultTestPodPort, "pod1"),
},
expectError: false,
- checkResult: func(t *testing.T, res *types.SchedulingResult, headers map[string]string) {
+ checkResult: func(t *testing.T, res *scheduling.SchedulingResult, headers map[string]string) {
assert.Equal(t, "dp-profile", res.PrimaryProfileName)
- pods := res.ProfileResults["dp-profile"].TargetPods
+ pods := res.ProfileResults["dp-profile"].TargetEndpoints
require.Len(t, pods, 1)
- assert.Equal(t, "9000", pods[0].GetPod().Port) // overridden
+ assert.Equal(t, "9000", pods[0].GetMetadata().Port) // overridden
expectedHeader := net.JoinHostPort("10.0.0.1", DefaultTestPodPort) // original
assert.Equal(t, expectedHeader, headers[common.DataParallelPodHeader])
},
@@ -163,28 +232,28 @@ func Test_DataParallelProfileHandler_ProcessResults(t *testing.T) {
{
name: "success: primaryPort=0 β port becomes '0'",
primaryPort: 0,
- profileResults: map[string]*types.ProfileRunResult{
+ profileResults: map[string]*scheduling.ProfileRunResult{
"dp": newMockProfileRunResult("8080", "pod1"),
},
expectError: false,
- checkResult: func(t *testing.T, res *types.SchedulingResult, headers map[string]string) {
- pod := res.ProfileResults["dp"].TargetPods[0]
- assert.Equal(t, "0", pod.GetPod().Port)
+ checkResult: func(t *testing.T, res *scheduling.SchedulingResult, headers map[string]string) {
+ pod := res.ProfileResults["dp"].TargetEndpoints[0]
+ assert.Equal(t, "0", pod.GetMetadata().Port)
assert.Equal(t, "10.0.0.1:8080", headers[common.DataParallelPodHeader])
},
},
{
name: "success: multiple target pods β all ports overridden",
primaryPort: 8080,
- profileResults: map[string]*types.ProfileRunResult{
+ profileResults: map[string]*scheduling.ProfileRunResult{
"dp-profile": newMockProfileRunResult(DefaultTestPodPort, "pod1", "pod2"),
},
expectError: false,
- checkResult: func(t *testing.T, res *types.SchedulingResult, headers map[string]string) {
- pods := res.ProfileResults["dp-profile"].TargetPods
+ checkResult: func(t *testing.T, res *scheduling.SchedulingResult, headers map[string]string) {
+ pods := res.ProfileResults["dp-profile"].TargetEndpoints
assert.Len(t, pods, 2)
for _, p := range pods {
- assert.Equal(t, "8080", p.GetPod().Port)
+ assert.Equal(t, "8080", p.GetMetadata().Port)
}
assert.Equal(t, net.JoinHostPort("10.0.0.1", DefaultTestPodPort), headers[common.DataParallelPodHeader])
},
@@ -195,8 +264,8 @@ func Test_DataParallelProfileHandler_ProcessResults(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
handler := NewDataParallelProfileHandler(tt.primaryPort).WithName("test-handler")
headers := make(map[string]string)
- req := &types.LLMRequest{Headers: headers}
- result, err := handler.ProcessResults(context.Background(), &types.CycleState{}, req, tt.profileResults)
+ req := &scheduling.LLMRequest{Headers: headers}
+ result, err := handler.ProcessResults(context.Background(), &scheduling.CycleState{}, req, tt.profileResults)
if tt.expectError {
assert.Error(t, err)
diff --git a/pkg/plugins/profile/pd_profile_handler.go b/pkg/plugins/profile/pd_profile_handler.go
index 8dff33e43..3f5a4b932 100644
--- a/pkg/plugins/profile/pd_profile_handler.go
+++ b/pkg/plugins/profile/pd_profile_handler.go
@@ -9,49 +9,56 @@ import (
"net"
"strconv"
- "github.com/llm-d/llm-d-inference-scheduler/pkg/metrics"
-
"sigs.k8s.io/controller-runtime/pkg/log"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/multi/prefix"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
- logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
+ logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/common/util/logging"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/plugins/scheduling/scorer/prefix"
"github.com/llm-d/llm-d-inference-scheduler/pkg/common"
+ "github.com/llm-d/llm-d-inference-scheduler/pkg/metrics"
)
const (
// PdProfileHandlerType is the type of the PdProfileHandler
PdProfileHandlerType = "pd-profile-handler"
- defaultDecodeProfile = "decode"
- defaultPrefillProfile = "prefill"
- defaultPrefixPluginType = prefix.PrefixCachePluginType
+ defaultDecodeProfile = "decode"
+ defaultPrefillProfile = "prefill"
+ defaultPrefixPluginType = prefix.PrefixCachePluginType
+ defaultDeciderPluginName = AlwaysDisaggDeciderPluginType
+
+ // AverageCharactersPerToken is the estimated average number of characters per token, used since the cached request is not tokenized.
+ AverageCharactersPerToken = 4
)
+// pdDeciderPlugin interface for pd decider plugins
+type pdDeciderPlugin interface {
+ plugin.Plugin
+ // disaggregate checks if disaggregated PD is required for the given request and endpoint.
+ disaggregate(ctx context.Context, inputTokens int, endpoint scheduling.Endpoint) bool
+}
+
type pdProfileHandlerParameters struct {
- Threshold int `json:"threshold"`
- DecodeProfile string `json:"decodeProfile"`
- PrefillProfile string `json:"prefillProfile"`
- PrefixPluginType string `json:"prefixPluginType"`
- PrefixPluginName string `json:"prefixPluginName"`
- HashBlockSize int `json:"hashBlockSize"`
- PrimaryPort int `json:"primaryPort"`
+ DecodeProfile string `json:"decodeProfile"`
+ PrefillProfile string `json:"prefillProfile"`
+ PrefixPluginType string `json:"prefixPluginType"`
+ PrefixPluginName string `json:"prefixPluginName"`
+ PrimaryPort int `json:"primaryPort"`
+ DeciderPluginName string `json:"deciderPluginName"`
}
// compile-time type assertion
-var _ framework.ProfileHandler = &PdProfileHandler{}
+var _ scheduling.ProfileHandler = &PdProfileHandler{}
// PdProfileHandlerFactory defines the factory function for the PdProfileHandler
-func PdProfileHandlerFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
+func PdProfileHandlerFactory(name string, rawParameters json.RawMessage, handle plugin.Handle) (plugin.Plugin, error) {
parameters := pdProfileHandlerParameters{
- Threshold: 0,
- DecodeProfile: defaultDecodeProfile,
- PrefillProfile: defaultPrefillProfile,
- PrefixPluginType: defaultPrefixPluginType,
- HashBlockSize: prefix.DefaultBlockSize,
- PrimaryPort: 0,
+ DecodeProfile: defaultDecodeProfile,
+ PrefillProfile: defaultPrefillProfile,
+ PrefixPluginType: defaultPrefixPluginType,
+ PrimaryPort: 0,
+ DeciderPluginName: defaultDeciderPluginName,
}
if rawParameters != nil {
if err := json.Unmarshal(rawParameters, ¶meters); err != nil {
@@ -63,54 +70,66 @@ func PdProfileHandlerFactory(name string, rawParameters json.RawMessage, _ plugi
parameters.PrefixPluginName = parameters.PrefixPluginType
}
- if parameters.Threshold < 0 {
- return nil, fmt.Errorf("invalid threshold: must be >= 0, got %d", parameters.Threshold)
- }
-
- if parameters.HashBlockSize <= 0 {
- return nil, fmt.Errorf("invalid hashBlockSize: must be > 0, got %d", parameters.HashBlockSize)
- }
-
if parameters.PrimaryPort != 0 {
if parameters.PrimaryPort < 1 || parameters.PrimaryPort > 65535 {
return nil, fmt.Errorf("invalid primaryPort: must be between 1 and 65535, got %d", parameters.PrimaryPort)
}
}
- return NewPdProfileHandler(parameters.PrefillProfile, parameters.DecodeProfile, parameters.PrefixPluginType, parameters.PrefixPluginName,
- parameters.Threshold, parameters.HashBlockSize, parameters.PrimaryPort).WithName(name), nil
+ if parameters.DeciderPluginName == "" {
+ return nil, errors.New("decider plugin name is not defined")
+ }
+
+ plugin := handle.Plugin(parameters.DeciderPluginName)
+ if plugin == nil {
+ return nil, fmt.Errorf("invalid decider plugin type: %s", parameters.DeciderPluginName)
+ }
+
+ deciderPlugin, ok := plugin.(pdDeciderPlugin)
+ if !ok {
+ return nil, fmt.Errorf("decider plugin of type: %s does not implement pdDeciderPlugin", parameters.DeciderPluginName)
+ }
+
+ handler, err := NewPdProfileHandler(parameters.PrefillProfile, parameters.DecodeProfile, parameters.PrefixPluginType, parameters.PrefixPluginName,
+ parameters.PrimaryPort, deciderPlugin)
+
+ if err != nil {
+ return nil, err
+ }
+
+ return handler.WithName(name), nil
+
}
// NewPdProfileHandler initializes a new PdProfileHandler and returns its pointer.
-func NewPdProfileHandler(prefillProfile, decodeProfile, prefixPluginType, prefixPluginName string, pdThreshold, hashBlockSize, primaryPort int) *PdProfileHandler {
+func NewPdProfileHandler(prefillProfile, decodeProfile, prefixPluginType, prefixPluginName string,
+ primaryPort int, deciderPlugin pdDeciderPlugin) (*PdProfileHandler, error) {
result := &PdProfileHandler{
- typedName: plugins.TypedName{Type: PdProfileHandlerType},
- prefixPluginTypedName: plugins.TypedName{Type: prefixPluginType, Name: prefixPluginName},
+ typedName: plugin.TypedName{Type: PdProfileHandlerType},
+ prefixPluginTypedName: plugin.TypedName{Type: prefixPluginType, Name: prefixPluginName},
decodeProfile: decodeProfile,
prefillProfile: prefillProfile,
- pdThreshold: pdThreshold,
- hashBlockSize: hashBlockSize,
+ decider: deciderPlugin,
}
if primaryPort != 0 {
result.primaryPort = strconv.Itoa(primaryPort)
}
- return result
+ return result, nil
}
// PdProfileHandler handles scheduler profiles for PD.
type PdProfileHandler struct {
- typedName plugins.TypedName
- prefixPluginTypedName plugins.TypedName
+ typedName plugin.TypedName
+ prefixPluginTypedName plugin.TypedName
decodeProfile string
prefillProfile string
- pdThreshold int
- hashBlockSize int
primaryPort string
+ decider pdDeciderPlugin
}
// TypedName returns the typed name of the plugin.
-func (h *PdProfileHandler) TypedName() plugins.TypedName {
+func (h *PdProfileHandler) TypedName() plugin.TypedName {
return h.typedName
}
@@ -122,11 +141,11 @@ func (h *PdProfileHandler) WithName(name string) *PdProfileHandler {
// Pick selects the SchedulingProfiles to run from the list of candidate profiles, while taking into consideration the request properties and the
// previously executed cycles along with their results.
-func (h *PdProfileHandler) Pick(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, profiles map[string]*framework.SchedulerProfile,
- profileResults map[string]*types.ProfileRunResult) map[string]*framework.SchedulerProfile {
+func (h *PdProfileHandler) Pick(ctx context.Context, _ *scheduling.CycleState, request *scheduling.LLMRequest, profiles map[string]scheduling.SchedulerProfile,
+ profileResults map[string]*scheduling.ProfileRunResult) map[string]scheduling.SchedulerProfile {
if _, executed := profileResults[h.decodeProfile]; !executed {
// if decode profile was not executed yet, first let the scheduler run the decode profile
- return map[string]*framework.SchedulerProfile{
+ return map[string]scheduling.SchedulerProfile{
h.decodeProfile: profiles[h.decodeProfile],
}
}
@@ -135,75 +154,56 @@ func (h *PdProfileHandler) Pick(ctx context.Context, cycleState *types.CycleStat
// when a profile run fails its result value is nil. we need to check decode result before continuing to prefill
// check if all configured profiles have been executed, or if decode failed, no need to run more profiles.
if len(profiles) == len(profileResults) || profileResults[h.decodeProfile] == nil {
- return map[string]*framework.SchedulerProfile{}
+ return map[string]scheduling.SchedulerProfile{}
}
- if h.pdThreshold > 0 {
- userInput, err := getUserInputBytes(request)
- if err != nil {
- log.FromContext(ctx).V(logutil.DEBUG).Error(err, "Failed to get user input bytes")
- return nil
- }
-
- // if we're here that means decode profile ran successfully, and we have additional profile configured that didn't run yet,
- // which means PD is enabled (otherwise, prefill profile is not configured at all and this profile handler is not used).
- // inspect decode execution result to decide if prefill should run or not.
- // if the request is short enough, use decode results only and don't run the prefill profile.
- hitPercentagePrefix := 0.0 // default to 0, meaning no prefix cache hit
- prefixState, err := types.ReadCycleStateKey[*prefix.SchedulingContextState](cycleState, plugins.StateKey(h.prefixPluginTypedName.String()))
- if err != nil {
- log.FromContext(ctx).Error(err, "unable to read prefix state")
- } else {
- decodePod := profileResults[h.decodeProfile].TargetPods[0].GetPod().NamespacedName
- hitPrefix := max(prefixState.PrefixCacheServers[prefix.ServerID(decodePod)]-1, 0) // The first hit is always the model name
- hitPercentagePrefix = float64(hitPrefix*h.hashBlockSize) / float64(len(userInput))
- log.FromContext(ctx).V(logutil.DEBUG).Info("Computed hit percentage for prefix cache", "hitPercentage", hitPercentagePrefix,
- "promptLength", len(userInput))
- }
+ inputTokens, err := getUserInputLenInTokens(request)
+ if err != nil {
+ log.FromContext(ctx).V(logutil.DEBUG).Error(err, "Failed to get user input")
+ return nil
+ }
- if (1.0-hitPercentagePrefix)*float64(len(userInput)) < float64(h.pdThreshold) {
- log.FromContext(ctx).Info("Non-cached suffix is smaller than threshold, using decode profile only", "hitPercentage", hitPercentagePrefix)
- metrics.RecordPDDecision(request.TargetModel, metrics.DecisionTypeDecodeOnly)
- return map[string]*framework.SchedulerProfile{} // do not run prefill
+ if h.decider != nil && h.decider.disaggregate(ctx, inputTokens, profileResults[h.decodeProfile].TargetEndpoints[0]) {
+ metrics.RecordPDDecision(request.TargetModel, metrics.DecisionTypePrefillDecode)
+ // run the prefill profile
+ return map[string]scheduling.SchedulerProfile{
+ h.prefillProfile: profiles[h.prefillProfile],
}
}
- metrics.RecordPDDecision(request.TargetModel, metrics.DecisionTypePrefillDecode)
- // run the prefill profile
- return map[string]*framework.SchedulerProfile{
- h.prefillProfile: profiles[h.prefillProfile],
- }
+ metrics.RecordPDDecision(request.TargetModel, metrics.DecisionTypeDecodeOnly)
+ return map[string]scheduling.SchedulerProfile{} // do not run prefill
}
// ProcessResults handles the outcome of the profile runs after the selected profiles ran.
// In case of an error in any of the profiles, the matching entry in the profileResults will contain nil, to indicate there was
// an error while running the profile.
-func (h *PdProfileHandler) ProcessResults(_ context.Context, _ *types.CycleState, request *types.LLMRequest,
- profileResults map[string]*types.ProfileRunResult) (*types.SchedulingResult, error) {
+func (h *PdProfileHandler) ProcessResults(_ context.Context, _ *scheduling.CycleState, request *scheduling.LLMRequest,
+ profileResults map[string]*scheduling.ProfileRunResult) (*scheduling.SchedulingResult, error) {
decodeRunResults := profileResults[h.decodeProfile]
if decodeRunResults == nil { // if decode profile failed to run, we should fail
return nil, errors.New("failed to find available decode workers")
}
// otherwise, decode ran successfully
- updatedResults := map[string]*types.ProfileRunResult{}
+ updatedResults := map[string]*scheduling.ProfileRunResult{}
// Add decode profile to result
if h.primaryPort != "" {
// Data Parallel is active
- targetPod := decodeRunResults.TargetPods[0].GetPod()
- request.Headers[common.DataParallelPodHeader] = net.JoinHostPort(targetPod.Address, targetPod.Port)
+ targetEndpoint := decodeRunResults.TargetEndpoints[0].GetMetadata()
+ request.Headers[common.DataParallelPodHeader] = net.JoinHostPort(targetEndpoint.Address, targetEndpoint.Port)
- updatedResult := types.ProfileRunResult{
- TargetPods: []types.Pod{},
+ updatedResult := scheduling.ProfileRunResult{
+ TargetEndpoints: []scheduling.Endpoint{},
}
- for _, target := range decodeRunResults.TargetPods {
- updatedPodInfo := target.GetPod().Clone()
- updatedPodInfo.Port = h.primaryPort
- targetPod := &types.PodMetrics{Pod: updatedPodInfo, MetricsState: target.GetMetrics().Clone()}
- updatedResult.TargetPods = append(updatedResult.TargetPods, targetPod)
+ for _, target := range decodeRunResults.TargetEndpoints {
+ updatedEndpointInfo := target.GetMetadata().Clone()
+ updatedEndpointInfo.Port = h.primaryPort
+ targetEndpoint := scheduling.NewEndpoint(updatedEndpointInfo, target.GetMetrics().Clone(), nil)
+ updatedResult.TargetEndpoints = append(updatedResult.TargetEndpoints, targetEndpoint)
}
updatedResults[h.decodeProfile] = &updatedResult
} else {
@@ -216,17 +216,24 @@ func (h *PdProfileHandler) ProcessResults(_ context.Context, _ *types.CycleState
updatedResults[h.prefillProfile] = prefillRunResult
}
- return &types.SchedulingResult{
+ return &scheduling.SchedulingResult{
PrimaryProfileName: h.decodeProfile,
ProfileResults: updatedResults,
}, nil
}
-func getUserInputBytes(request *types.LLMRequest) ([]byte, error) {
+// getUserInputLenInTokens returns the estimated length of the user input in tokens.
+func getUserInputLenInTokens(request *scheduling.LLMRequest) (int, error) {
if request.Body.Completions != nil { // assumed to be valid if not nil
- return []byte(request.Body.Completions.Prompt), nil
+ return len([]byte(request.Body.Completions.Prompt)) / AverageCharactersPerToken, nil
}
// must be chat-completions request at this point, return bytes of entire messages
- return json.Marshal(request.Body.ChatCompletions.Messages)
+ prompt, err := json.Marshal(request.Body.ChatCompletions.Messages)
+
+ if err != nil {
+ return 0, err
+ }
+
+ return len(prompt) / AverageCharactersPerToken, nil
}
diff --git a/pkg/plugins/profile/pd_profile_handler_test.go b/pkg/plugins/profile/pd_profile_handler_test.go
index 09068104b..c932c8064 100644
--- a/pkg/plugins/profile/pd_profile_handler_test.go
+++ b/pkg/plugins/profile/pd_profile_handler_test.go
@@ -8,130 +8,120 @@ import (
"github.com/stretchr/testify/assert"
k8stypes "k8s.io/apimachinery/pkg/types"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
- backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/multi/prefix"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer/plugins/approximateprefix"
+ fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/plugins/scheduling/scorer/prefix"
"github.com/llm-d/llm-d-inference-scheduler/pkg/common"
"github.com/llm-d/llm-d-inference-scheduler/test/utils"
)
func TestPdProfileHandlerFactory(t *testing.T) {
+ ctx := utils.NewTestContext(t)
tests := []struct {
name string
pluginName string
- jsonParams string
+ params map[string]any
expectErr bool
}{
{
name: "valid configuration with all defaults",
pluginName: "default-handler",
- jsonParams: "{}",
+ params: map[string]any{},
expectErr: false,
},
{
name: "valid configuration with custom values",
pluginName: "custom-handler",
- jsonParams: `{
- "threshold": 100,
- "decodeProfile": "my-decode",
- "prefillProfile": "my-prefill",
- "prefixPluginName": "my-prefix-cache",
- "hashBlockSize": 32,
- "primaryPort": 8080
- }`,
+ params: map[string]any{
+ "decodeProfile": "my-decode",
+ "prefillProfile": "my-prefill",
+ "prefixPluginName": "my-prefix-cache",
+ "primaryPort": 8080,
+ "deciderPluginName": PrefixBasedPDDeciderPluginType,
+ },
expectErr: false,
},
{
name: "zero primaryPort is allowed",
pluginName: "zero-port",
- jsonParams: `{"primaryPort": 0}`,
- expectErr: false,
- },
- {
- name: "threshold = 0 is allowed",
- pluginName: "zero-threshold",
- jsonParams: `{"threshold": 0}`,
- expectErr: false,
- },
- {
- name: "negative threshold should error",
- pluginName: "neg-threshold",
- jsonParams: `{"threshold": -1}`,
- expectErr: true,
- },
- {
- name: "hashBlockSize = 0 should error",
- pluginName: "zero-block-size",
- jsonParams: `{"hashBlockSize": 0}`,
- expectErr: true,
+ params: map[string]any{
+ "primaryPort": 0,
+ },
+ expectErr: false,
},
{
- name: "negative hashBlockSize should error",
- pluginName: "neg-block-size",
- jsonParams: `{"hashBlockSize": -5}`,
- expectErr: true,
+ name: "nonCachedTokens = 0 is allowed",
+ pluginName: "zero-non-cached-tokens",
+ params: map[string]any{
+ "deciderPluginName": PrefixBasedPDDeciderPluginType,
+ },
+ expectErr: false,
},
{
name: "primaryPort below range should error",
pluginName: "port-too-low",
- jsonParams: `{"primaryPort": 0}`, // OK
+ params: map[string]any{"primaryPort": 0}, // OK
expectErr: false,
},
{
name: "primaryPort = 1 is valid",
pluginName: "port-min",
- jsonParams: `{"primaryPort": 1}`,
+ params: map[string]any{"primaryPort": 1},
expectErr: false,
},
{
name: "primaryPort = 65535 is valid",
pluginName: "port-max",
- jsonParams: `{"primaryPort": 65535}`,
+ params: map[string]any{"primaryPort": 65535},
expectErr: false,
},
{
name: "empty decodeProfile is valid",
pluginName: "empty-decode",
- jsonParams: `{"decodeProfile": ""}`,
+ params: map[string]any{"decodeProfile": ""},
expectErr: false,
},
{
name: "empty prefillProfile is valid",
pluginName: "empty-prefill",
- jsonParams: `{"prefillProfile": ""}`,
+ params: map[string]any{"prefillProfile": ""},
expectErr: false,
},
{
name: "empty prefixPluginName is valid",
pluginName: "empty-prefix-plugin",
- jsonParams: `{"prefixPluginName": ""}`,
+ params: map[string]any{"prefixPluginName": ""},
expectErr: false,
},
{
name: "primaryPort = 65536 should error",
pluginName: "port-too-high",
- jsonParams: `{"primaryPort": 65536}`,
+ params: map[string]any{"primaryPort": 65536},
expectErr: true,
},
{
name: "primaryPort = -10 should error",
pluginName: "port-negative",
- jsonParams: `{"primaryPort": -10}`,
+ params: map[string]any{"primaryPort": -10},
expectErr: true,
},
}
+ handle, err := createHandleWithDeciderPlugins(ctx)
+ assert.NoError(t, err)
+
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var rawParams json.RawMessage
- if tt.jsonParams != "" {
- rawParams = json.RawMessage(tt.jsonParams)
+ if tt.params != nil {
+ bytes, err := json.Marshal(tt.params)
+ assert.NoError(t, err)
+ rawParams = json.RawMessage(bytes)
}
- plugin, err := PdProfileHandlerFactory(tt.pluginName, rawParams, nil)
+ plugin, err := PdProfileHandlerFactory(tt.pluginName, rawParams, handle)
if tt.expectErr {
assert.Error(t, err)
@@ -145,21 +135,19 @@ func TestPdProfileHandlerFactory(t *testing.T) {
}
func TestPdProfileHandlerFactoryInvalidJSON(t *testing.T) {
+ ctx := utils.NewTestContext(t)
+
invalidTests := []struct {
name string
jsonParams string
}{
{
name: "malformed JSON",
- jsonParams: `{"threshold": 100, "hashBlockSize":`, // incomplete
- },
- {
- name: "threshold as string instead of int",
- jsonParams: `{"threshold": "100"}`,
+ jsonParams: `{"deciderPluginName": `, // incomplete
},
{
- name: "hashBlockSize as boolean",
- jsonParams: `{"hashBlockSize": true}`,
+ name: "invalid decider plugin type",
+ jsonParams: `{"deciderPluginName": "INVALID"}`,
},
{
name: "primaryPort as float",
@@ -167,10 +155,13 @@ func TestPdProfileHandlerFactoryInvalidJSON(t *testing.T) {
},
}
+ handle, err := createHandleWithDeciderPlugins(ctx)
+ assert.NoError(t, err)
+
for _, tt := range invalidTests {
t.Run(tt.name, func(t *testing.T) {
rawParams := json.RawMessage(tt.jsonParams)
- plugin, err := PdProfileHandlerFactory("test", rawParams, nil)
+ plugin, err := PdProfileHandlerFactory("test", rawParams, handle)
assert.Error(t, err)
assert.Nil(t, plugin)
@@ -180,164 +171,280 @@ func TestPdProfileHandlerFactoryInvalidJSON(t *testing.T) {
const DefaultTestPodPort = "8000"
-// createPod creates a mock Pod with customizable IP and port.
-func createPod(nsn k8stypes.NamespacedName, ipaddr, port string, labels map[string]string) types.Pod {
- return &types.PodMetrics{
- Pod: &backend.Pod{
+// createEndpoint creates a mock Endpoint with customizable IP and port.
+func createEndpoint(nsn k8stypes.NamespacedName, ipaddr, port string, labels map[string]string) scheduling.Endpoint {
+ return scheduling.NewEndpoint(
+ &fwkdl.EndpointMetadata{
NamespacedName: nsn,
Address: ipaddr,
Port: port,
Labels: labels,
},
- MetricsState: &backendmetrics.MetricsState{},
- }
+ nil,
+ fwkdl.NewAttributes(),
+ )
}
// newMockProfileRunResult creates a ProfileRunResult with Pods using the given port.
-func newMockProfileRunResult(port string, podNames ...string) *types.ProfileRunResult {
- pods := make([]types.Pod, 0, len(podNames))
- for i, name := range podNames {
+func newMockProfileRunResult(port string, endpointNames ...string) *scheduling.ProfileRunResult {
+ endpoints := make([]scheduling.Endpoint, 0, len(endpointNames))
+ for i, name := range endpointNames {
ip := fmt.Sprintf("10.0.0.%d", i+1)
- pods = append(pods, createPod(
+ endpoints = append(endpoints, createEndpoint(
k8stypes.NamespacedName{Namespace: "default", Name: name},
ip,
port,
map[string]string{},
))
}
- return &types.ProfileRunResult{
- TargetPods: pods,
+ return &scheduling.ProfileRunResult{
+ TargetEndpoints: endpoints,
}
}
-func newMockSchedulerProfile() *framework.SchedulerProfile {
- return &framework.SchedulerProfile{}
+func newMockSchedulerProfile() scheduling.SchedulerProfile {
+ return &mockSchedulerProfile{}
}
-func TestPdProfileHandler_Pick(t *testing.T) {
- ctx := utils.NewTestContext(t)
- request := &types.LLMRequest{
- Body: &types.LLMRequestBody{
- Completions: &types.CompletionsRequest{
- Prompt: "hello world",
+type mockSchedulerProfile struct{}
+
+func (p *mockSchedulerProfile) Run(_ context.Context, _ *scheduling.LLMRequest, _ *scheduling.CycleState, _ []scheduling.Endpoint) (*scheduling.ProfileRunResult, error) {
+ return &scheduling.ProfileRunResult{}, nil
+}
+
+// creates and returns an llm completion request for the given prompt
+func createRequest(prompt string) *scheduling.LLMRequest {
+ return &scheduling.LLMRequest{
+ Body: &scheduling.LLMRequestBody{
+ Completions: &scheduling.CompletionsRequest{
+ Prompt: prompt,
},
},
}
+}
- profiles := map[string]*framework.SchedulerProfile{
+// returns array of profile names in the given profile pick result
+func getProfilesFromResult(result map[string]scheduling.SchedulerProfile) []string {
+ profiles := make([]string, len(result))
+ index := 0
+
+ for name := range result {
+ profiles[index] = name
+ index++
+ }
+
+ return profiles
+}
+
+func TestPdProfileHandler_Pick(t *testing.T) {
+ ctx := utils.NewTestContext(t)
+ request := createRequest("hello world hello world hello world")
+
+ profiles := map[string]scheduling.SchedulerProfile{
"decode": newMockSchedulerProfile(),
"prefill": newMockSchedulerProfile(),
}
tests := []struct {
- name string
- pdThreshold int
- hashBlockSize int
- prefixPluginType string
- prefixPluginName string
- setupPrefixState func(*types.CycleState)
- profileResults map[string]*types.ProfileRunResult
- expectedProfiles []string
+ name string
+ nonCachedTokensLimit int
+ prefixPluginType string
+ prefixPluginName string
+ cachedTokens int
+ profileResults map[string]*scheduling.ProfileRunResult
+ expectedProfiles []string
}{
{
- name: "decode not executed yet β run decode",
- pdThreshold: 100,
- hashBlockSize: 16,
- prefixPluginType: prefix.PrefixCachePluginType,
- prefixPluginName: prefix.PrefixCachePluginType,
- profileResults: map[string]*types.ProfileRunResult{},
- expectedProfiles: []string{"decode"},
+ name: "decode not executed yet β run decode",
+ nonCachedTokensLimit: 10,
+ prefixPluginType: prefix.PrefixCachePluginType,
+ prefixPluginName: prefix.PrefixCachePluginType,
+ profileResults: map[string]*scheduling.ProfileRunResult{},
+ expectedProfiles: []string{defaultDecodeProfile},
},
{
- name: "decode failed (nil result) β run nothing",
- pdThreshold: 100,
- hashBlockSize: 16,
- prefixPluginType: prefix.PrefixCachePluginType,
- prefixPluginName: prefix.PrefixCachePluginType,
- profileResults: map[string]*types.ProfileRunResult{
- "decode": nil,
+ name: "decode failed (nil result) β run nothing",
+ nonCachedTokensLimit: 10,
+ prefixPluginType: prefix.PrefixCachePluginType,
+ prefixPluginName: prefix.PrefixCachePluginType,
+ profileResults: map[string]*scheduling.ProfileRunResult{
+ defaultDecodeProfile: nil,
},
expectedProfiles: []string{},
},
{
- name: "all profiles already executed β run nothing",
- pdThreshold: 100,
- hashBlockSize: 16,
- prefixPluginType: prefix.PrefixCachePluginType,
- prefixPluginName: prefix.PrefixCachePluginType,
- profileResults: map[string]*types.ProfileRunResult{
- "decode": newMockProfileRunResult(DefaultTestPodPort, "pod1"),
- "prefill": newMockProfileRunResult(DefaultTestPodPort, "pod2"),
+ name: "all profiles already executed β run nothing",
+ nonCachedTokensLimit: 10,
+ prefixPluginType: prefix.PrefixCachePluginType,
+ prefixPluginName: prefix.PrefixCachePluginType,
+ profileResults: map[string]*scheduling.ProfileRunResult{
+ defaultDecodeProfile: newMockProfileRunResult(DefaultTestPodPort, "pod1"),
+ defaultPrefillProfile: newMockProfileRunResult(DefaultTestPodPort, "pod2"),
},
expectedProfiles: []string{},
},
{
- name: "pd threshold NOT triggered β run prefill",
- pdThreshold: 5,
- hashBlockSize: 16,
- prefixPluginType: prefix.PrefixCachePluginType,
- prefixPluginName: prefix.PrefixCachePluginType,
- setupPrefixState: func(cs *types.CycleState) {
- state := &prefix.SchedulingContextState{
- PrefixCacheServers: map[prefix.ServerID]int{
- prefix.ServerID(k8stypes.NamespacedName{Name: "pod1", Namespace: "default"}): 1,
- },
- }
- key := plugins.StateKey(fmt.Sprintf("%s/%s", prefix.PrefixCachePluginType, prefix.PrefixCachePluginType))
- cs.Write(key, state)
+ name: "has enough not-cached tokens β run prefill",
+ // Need at least 4 non-cached tokens (16+ chars) to trigger disaggregated prefill
+ // In this case: prompt length is 35 chars (8 tokens), cached length is 2 tokens -> disaggregated prefill should trigger
+ nonCachedTokensLimit: 4,
+ cachedTokens: 2,
+ prefixPluginType: prefix.PrefixCachePluginType,
+ prefixPluginName: prefix.PrefixCachePluginType,
+ profileResults: map[string]*scheduling.ProfileRunResult{
+ defaultDecodeProfile: newMockProfileRunResult(DefaultTestPodPort, "pod1"),
},
- profileResults: map[string]*types.ProfileRunResult{
- "decode": newMockProfileRunResult(DefaultTestPodPort, "pod1"),
- },
- expectedProfiles: []string{"prefill"},
+ expectedProfiles: []string{defaultPrefillProfile},
},
{
- name: "pd threshold triggered (short non-cached suffix) β skip prefill",
- pdThreshold: 100,
- hashBlockSize: 16,
- prefixPluginType: prefix.PrefixCachePluginType,
- prefixPluginName: prefix.PrefixCachePluginType,
- setupPrefixState: func(cs *types.CycleState) {
- state := &prefix.SchedulingContextState{
- PrefixCacheServers: map[prefix.ServerID]int{
- prefix.ServerID(k8stypes.NamespacedName{Name: "pod1", Namespace: "default"}): 5,
- },
- }
- key := plugins.StateKey(fmt.Sprintf("%s/%s", prefix.PrefixCachePluginType, prefix.PrefixCachePluginType))
- cs.Write(key, state)
- },
- profileResults: map[string]*types.ProfileRunResult{
- "decode": newMockProfileRunResult(DefaultTestPodPort, "pod1"),
+ name: "short non-cached suffix β skip prefill",
+ // Need at least 4 non-cached tokens (16+ chars) to trigger disaggregated prefill
+ // In this case: prompt length is 35 chars (8 tokens), cached length is 5 tokens -> skip prefill
+ nonCachedTokensLimit: 4,
+ cachedTokens: 5,
+ prefixPluginType: prefix.PrefixCachePluginType,
+ prefixPluginName: prefix.PrefixCachePluginType,
+ profileResults: map[string]*scheduling.ProfileRunResult{
+ defaultDecodeProfile: newMockProfileRunResult(DefaultTestPodPort, "pod1"),
},
expectedProfiles: []string{},
},
}
for _, tt := range tests {
+ deciderPlugin, err := NewPrefixBasedPDDecider(PrefixBasedPDDeciderConfig{NonCachedTokens: tt.nonCachedTokensLimit})
+ assert.NoError(t, err)
+
t.Run(tt.name, func(t *testing.T) {
- handler := NewPdProfileHandler(
- "prefill",
- "decode",
+ handler, err := NewPdProfileHandler(
+ defaultPrefillProfile,
+ defaultDecodeProfile,
tt.prefixPluginType,
tt.prefixPluginName,
- tt.pdThreshold,
- tt.hashBlockSize,
0,
- ).WithName("test-handler")
+ deciderPlugin,
+ )
+ assert.NoError(t, err)
+
+ // set prefix to the given cached tokens number for pod "pod1" in decode profile results
+ inputTokens := len(request.Body.Completions.Prompt) / AverageCharactersPerToken
- cs := &types.CycleState{}
- if tt.setupPrefixState != nil {
- tt.setupPrefixState(cs)
+ for profileName, profileRes := range tt.profileResults {
+ if profileName == defaultDecodeProfile && profileRes != nil {
+ for _, pod := range profileRes.TargetEndpoints {
+ pod.Put(approximateprefix.PrefixCacheMatchInfoKey,
+ approximateprefix.NewPrefixCacheMatchInfo(tt.cachedTokens, inputTokens, 1))
+ }
+ }
}
+ result := handler.Pick(ctx, nil, request, profiles, tt.profileResults)
+ assert.ElementsMatch(t, tt.expectedProfiles, getProfilesFromResult(result))
+ })
+ }
+}
- result := handler.Pick(ctx, cs, request, profiles, tt.profileResults)
+func TestPdProfileHandler_PickSeries(t *testing.T) {
+ ctx := context.Background()
+ prompt := "hello world, hello world, hello world, hello world, hello world, hello world, hello world!"
+ request := createRequest(prompt)
+ longerRequest := createRequest(prompt + "123")
+ longRequest := createRequest(prompt + prompt)
- var actual []string
- for name := range result {
- actual = append(actual, name)
- }
+ profiles := map[string]scheduling.SchedulerProfile{
+ defaultDecodeProfile: newMockSchedulerProfile(),
+ defaultPrefillProfile: newMockSchedulerProfile(),
+ }
+ profileResults := map[string]*scheduling.ProfileRunResult{
+ defaultDecodeProfile: newMockProfileRunResult(DefaultTestPodPort, "pod1"),
+ }
+
+ type testData struct {
+ request *scheduling.LLMRequest
+ cachedTokens int
+ expectedProfiles []string
+ }
+ tests := []struct {
+ name string
+ nonCachedTokensLimit int
+ tests []testData
+ }{
+ {
+ name: "same request twice",
+ nonCachedTokensLimit: 2,
+ tests: []testData{{
+ request: request,
+ cachedTokens: 0,
+ expectedProfiles: []string{defaultPrefillProfile},
+ }, {
+ request: request,
+ cachedTokens: len(request.Body.Completions.Prompt) / AverageCharactersPerToken,
+ expectedProfiles: []string{},
+ }},
+ }, {
+ name: "short request and a little bit longer after it",
+ // Need at least 2 non-cached tokens (8+ chars) to trigger disaggregated prefill
+ // In this case: longer request is longer in 4 chars than the request -> no disaggregated prefill
+ nonCachedTokensLimit: 2,
+ tests: []testData{{
+ request: request,
+ cachedTokens: 0,
+ expectedProfiles: []string{defaultPrefillProfile},
+ }, {
+ request: longerRequest,
+ cachedTokens: len(request.Body.Completions.Prompt) / AverageCharactersPerToken,
+ expectedProfiles: []string{},
+ }},
+ }, {
+ name: "short request and a long one after it",
+ // Need at least 2 non-cached tokens (8+ chars) to trigger disaggregated prefill
+ // In this case: long request is longer enough than the request -> should have disaggregated prefill
+ nonCachedTokensLimit: 2,
+ tests: []testData{{
+ request: request,
+ cachedTokens: 0,
+ expectedProfiles: []string{defaultPrefillProfile},
+ }, {
+ request: longRequest,
+ cachedTokens: len(request.Body.Completions.Prompt) / AverageCharactersPerToken,
+ expectedProfiles: []string{defaultPrefillProfile},
+ }},
+ },
+ }
- assert.ElementsMatch(t, tt.expectedProfiles, actual)
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ deciderPlugin, err := NewPrefixBasedPDDecider(PrefixBasedPDDeciderConfig{NonCachedTokens: tt.nonCachedTokensLimit})
+ assert.NoError(t, err)
+
+ handler, err := NewPdProfileHandler(
+ defaultPrefillProfile,
+ defaultDecodeProfile,
+ prefix.PrefixCachePluginType,
+ prefix.PrefixCachePluginType,
+ 0,
+ deciderPlugin,
+ )
+ assert.NoError(t, err)
+
+ // run sequences of requests
+ for _, innerTest := range tt.tests {
+ cs := &scheduling.CycleState{}
+
+ // set prefix to the given cached tokens number for pod "pod1" in decode profile results
+ inputTokens := len(innerTest.request.Body.Completions.Prompt) / AverageCharactersPerToken
+
+ for profileName, profileRes := range profileResults {
+ if profileName == defaultDecodeProfile && profileRes != nil {
+ for _, endpoint := range profileRes.TargetEndpoints {
+ endpoint.Put(approximateprefix.PrefixCacheMatchInfoKey,
+ approximateprefix.NewPrefixCacheMatchInfo(innerTest.cachedTokens, inputTokens, 1))
+ }
+ }
+ }
+
+ result := handler.Pick(ctx, cs, innerTest.request, profiles, profileResults)
+ assert.ElementsMatch(t, innerTest.expectedProfiles, getProfilesFromResult(result))
+ }
})
}
}
@@ -346,57 +453,57 @@ func TestPdProfileHandler_ProcessResults(t *testing.T) {
tests := []struct {
name string
primaryPort int
- profileResults map[string]*types.ProfileRunResult
+ profileResults map[string]*scheduling.ProfileRunResult
expectError bool
- checkResult func(*testing.T, *types.SchedulingResult, map[string]string)
+ checkResult func(*testing.T, *scheduling.SchedulingResult, map[string]string)
}{
{
name: "decode failed β error",
- profileResults: map[string]*types.ProfileRunResult{
- "decode": nil,
+ profileResults: map[string]*scheduling.ProfileRunResult{
+ defaultDecodeProfile: nil,
},
expectError: true,
},
{
name: "decode success, no prefill, no primaryPort",
primaryPort: 0,
- profileResults: map[string]*types.ProfileRunResult{
- "decode": newMockProfileRunResult(DefaultTestPodPort, "pod1"),
+ profileResults: map[string]*scheduling.ProfileRunResult{
+ defaultDecodeProfile: newMockProfileRunResult(DefaultTestPodPort, "pod1"),
},
expectError: false,
- checkResult: func(t *testing.T, res *types.SchedulingResult, headers map[string]string) {
- assert.Equal(t, "decode", res.PrimaryProfileName)
- assert.Contains(t, res.ProfileResults, "decode")
- assert.NotContains(t, res.ProfileResults, "prefill")
- pod := res.ProfileResults["decode"].TargetPods[0].GetPod()
- assert.Equal(t, DefaultTestPodPort, pod.Port)
+ checkResult: func(t *testing.T, res *scheduling.SchedulingResult, headers map[string]string) {
+ assert.Equal(t, defaultDecodeProfile, res.PrimaryProfileName)
+ assert.Contains(t, res.ProfileResults, defaultDecodeProfile)
+ assert.NotContains(t, res.ProfileResults, defaultPrefillProfile)
+ metadata := res.ProfileResults[defaultDecodeProfile].TargetEndpoints[0].GetMetadata()
+ assert.Equal(t, DefaultTestPodPort, metadata.Port)
assert.Empty(t, headers[common.DataParallelPodHeader])
},
},
{
name: "decode success, with prefill",
primaryPort: 0,
- profileResults: map[string]*types.ProfileRunResult{
- "decode": newMockProfileRunResult(DefaultTestPodPort, "pod1"),
- "prefill": newMockProfileRunResult(DefaultTestPodPort, "pod2"),
+ profileResults: map[string]*scheduling.ProfileRunResult{
+ defaultDecodeProfile: newMockProfileRunResult(DefaultTestPodPort, "pod1"),
+ defaultPrefillProfile: newMockProfileRunResult(DefaultTestPodPort, "pod2"),
},
expectError: false,
- checkResult: func(t *testing.T, res *types.SchedulingResult, _ map[string]string) {
- assert.Equal(t, "decode", res.PrimaryProfileName)
- assert.Contains(t, res.ProfileResults, "decode")
- assert.Contains(t, res.ProfileResults, "prefill")
+ checkResult: func(t *testing.T, res *scheduling.SchedulingResult, _ map[string]string) {
+ assert.Equal(t, defaultDecodeProfile, res.PrimaryProfileName)
+ assert.Contains(t, res.ProfileResults, defaultDecodeProfile)
+ assert.Contains(t, res.ProfileResults, defaultPrefillProfile)
},
},
{
name: "with primaryPort β port updated and header set",
primaryPort: 9000,
- profileResults: map[string]*types.ProfileRunResult{
- "decode": newMockProfileRunResult(DefaultTestPodPort, "pod1"),
+ profileResults: map[string]*scheduling.ProfileRunResult{
+ defaultDecodeProfile: newMockProfileRunResult(DefaultTestPodPort, "pod1"),
},
expectError: false,
- checkResult: func(t *testing.T, res *types.SchedulingResult, headers map[string]string) {
- pod := res.ProfileResults["decode"].TargetPods[0].GetPod()
- assert.Equal(t, "9000", pod.Port)
+ checkResult: func(t *testing.T, res *scheduling.SchedulingResult, headers map[string]string) {
+ metadata := res.ProfileResults[defaultDecodeProfile].TargetEndpoints[0].GetMetadata()
+ assert.Equal(t, "9000", metadata.Port)
hostPort := headers[common.DataParallelPodHeader]
assert.Equal(t, "10.0.0.1:8000", hostPort)
@@ -405,22 +512,25 @@ func TestPdProfileHandler_ProcessResults(t *testing.T) {
}
for _, tt := range tests {
+ deciderPlugin, err := NewPrefixBasedPDDecider(PrefixBasedPDDeciderConfig{NonCachedTokens: 0})
+ assert.NoError(t, err)
+
t.Run(tt.name, func(t *testing.T) {
- handler := NewPdProfileHandler(
- "prefill",
- "decode",
+ handler, err := NewPdProfileHandler(
+ defaultPrefillProfile,
+ defaultDecodeProfile,
prefix.PrefixCachePluginType,
prefix.PrefixCachePluginType,
- 0,
- prefix.DefaultBlockSize,
tt.primaryPort,
- ).WithName("test-handler")
+ deciderPlugin,
+ )
+ assert.NoError(t, err)
headers := make(map[string]string)
- req := &types.LLMRequest{
+ req := &scheduling.LLMRequest{
Headers: headers,
}
- result, err := handler.ProcessResults(context.Background(), &types.CycleState{}, req, tt.profileResults)
+ result, err := handler.ProcessResults(context.Background(), &scheduling.CycleState{}, req, tt.profileResults)
if tt.expectError {
assert.Error(t, err)
@@ -433,3 +543,16 @@ func TestPdProfileHandler_ProcessResults(t *testing.T) {
})
}
}
+
+func createHandleWithDeciderPlugins(ctx context.Context) (plugin.Handle, error) {
+ handle := plugin.NewEppHandle(ctx, nil)
+ plugin1, err := NewPrefixBasedPDDecider(PrefixBasedPDDeciderConfig{NonCachedTokens: 4})
+ if err != nil {
+ return nil, err
+ }
+ handle.AddPlugin(PrefixBasedPDDeciderPluginType, plugin1)
+ plugin2 := newAlwaysDisaggPDDecider()
+ handle.AddPlugin(AlwaysDisaggDeciderPluginType, plugin2)
+
+ return handle, nil
+}
diff --git a/pkg/plugins/profile/prefix_based_pd_decider.go b/pkg/plugins/profile/prefix_based_pd_decider.go
new file mode 100644
index 000000000..8948b4d37
--- /dev/null
+++ b/pkg/plugins/profile/prefix_based_pd_decider.go
@@ -0,0 +1,137 @@
+package profile
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+
+ "sigs.k8s.io/controller-runtime/pkg/log"
+ logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/common/util/logging"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer/plugins/approximateprefix"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling"
+)
+
+const (
+ // PrefixBasedPDDeciderPluginType is the type-name of the prefixBasedPDDecider plugin.
+ PrefixBasedPDDeciderPluginType = "prefix-based-pd-decider"
+)
+
+// PrefixBasedPDDeciderConfig holds the configuration for the prefixBasedPDDecider plugin.
+type PrefixBasedPDDeciderConfig struct {
+ // NonCachedTokens non cached minimum tokens that triggers disaggregated PD
+ NonCachedTokens int `json:"nonCachedTokens"`
+}
+
+func (p PrefixBasedPDDeciderConfig) validate() error {
+ if p.NonCachedTokens < 0 {
+ return errors.New("nonCachedTokens parameter of prefix disaggregation decider cannot be negative")
+ }
+
+ return nil
+}
+
+// compile-time type assertion
+var _ pdDeciderPlugin = &PrefixBasedPDDecider{}
+
+// PrefixBasedPDDecider is a PD decider plugin whose decision is based on prefix-cache awareness
+type PrefixBasedPDDecider struct {
+ typedName plugin.TypedName
+ config PrefixBasedPDDeciderConfig
+}
+
+// PrefixBasedPDDeciderPluginFactory defines the factory function for creating
+// a new instance of the prefixBasedPDDecider.
+func PrefixBasedPDDeciderPluginFactory(name string, rawParameters json.RawMessage,
+ handle plugin.Handle) (plugin.Plugin, error) {
+ config := PrefixBasedPDDeciderConfig{
+ NonCachedTokens: 0,
+ }
+
+ if rawParameters != nil {
+ if err := json.Unmarshal(rawParameters, &config); err != nil {
+ return nil, fmt.Errorf("failed to parse %s plugin config: %w", PrefixBasedPDDeciderPluginType, err)
+ }
+ }
+
+ decider, err := NewPrefixBasedPDDecider(config)
+ if err != nil {
+ return nil, fmt.Errorf("failed to create %s plugin: %w", PrefixBasedPDDeciderPluginType, err)
+ }
+
+ return decider.WithName(name), nil
+}
+
+// NewPrefixBasedPDDecider initializes a prefix-based PD decider plugin and returns its pointer.
+// If the configuration is invalid an error is returned.
+func NewPrefixBasedPDDecider(config PrefixBasedPDDeciderConfig) (*PrefixBasedPDDecider, error) {
+ if err := config.validate(); err != nil {
+ return nil, err
+ }
+
+ return &PrefixBasedPDDecider{
+ config: config,
+ }, nil
+}
+
+// TypedName returns the typed name of the plugin.
+func (d *PrefixBasedPDDecider) TypedName() plugin.TypedName {
+ return d.typedName
+}
+
+// WithName sets the name of the plugin.
+func (d *PrefixBasedPDDecider) WithName(name string) *PrefixBasedPDDecider {
+ d.typedName.Name = name
+ return d
+}
+
+func (d *PrefixBasedPDDecider) disaggregate(ctx context.Context, inputTokens int, endpoint scheduling.Endpoint) bool {
+ logger := log.FromContext(ctx)
+ debugLogger := log.FromContext(ctx).V(logutil.DEBUG)
+
+ if d.config.NonCachedTokens <= 0 { // always use disaggregation when the non-cached tokens threshold is 0
+ return true
+ }
+ if endpoint == nil {
+ logger.Error(nil, "prefix decider: endpoint is nil")
+ return false
+ }
+ if inputTokens < d.config.NonCachedTokens {
+ debugLogger.Info("Input is shorter than the nonCachedToken, no disaggregated PD")
+ return false
+ }
+ // inspect the decode endpoint to decide if prefill should run or not.
+ // if the non-cached part is short enough - no disaggregation.
+ prefixInfoRaw, ok := endpoint.Get(approximateprefix.PrefixCacheMatchInfoKey)
+ if !ok || prefixInfoRaw == nil {
+ logger.Error(nil, "unable to read prefix cache state")
+ return false
+ }
+ prefixCacheMatchInfo, ok := prefixInfoRaw.(*approximateprefix.PrefixCacheMatchInfo)
+ if !ok {
+ logger.Error(nil, "wrong type of prefix cache match info")
+ return false
+ }
+
+ // number of cached tokens
+ hitPrefixTokens := prefixCacheMatchInfo.MatchBlocks() * prefixCacheMatchInfo.BlockSizeTokens()
+ // length of non-cached suffix in tokens
+ nonCachedTokens := inputTokens - hitPrefixTokens
+
+ debugLogger.Info("Computed hit percentage for prefix cache",
+ "absolute hit prefix len (tokens)", hitPrefixTokens,
+ "prompt length (token)", inputTokens)
+
+ if nonCachedTokens < d.config.NonCachedTokens {
+ debugLogger.Info("Non-cached suffix is smaller than threshold, using decode profile only")
+ return false // do not run prefill
+ }
+
+ return true
+}
+
+// Consumes defines data types consumed by this plugin
+func (*PrefixBasedPDDecider) Consumes() map[string]any {
+ return map[string]any{approximateprefix.PrefixCacheMatchInfoKey: approximateprefix.PrefixCacheMatchInfo{}}
+}
diff --git a/pkg/plugins/register.go b/pkg/plugins/register.go
index f78c4afc7..d3c50fd31 100644
--- a/pkg/plugins/register.go
+++ b/pkg/plugins/register.go
@@ -1,25 +1,31 @@
package plugins
import (
+ "github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/datalayer/models"
"github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/filter"
prerequest "github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/pre-request"
"github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/profile"
"github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/scorer"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin"
)
// RegisterAllPlugins registers the factory functions of all plugins in this repository.
func RegisterAllPlugins() {
- plugins.Register(filter.ByLabelType, filter.ByLabelFactory)
- plugins.Register(filter.ByLabelSelectorType, filter.ByLabelSelectorFactory)
- plugins.Register(filter.DecodeRoleType, filter.DecodeRoleFactory)
- plugins.Register(filter.PrefillRoleType, filter.PrefillRoleFactory)
- plugins.Register(prerequest.PrefillHeaderHandlerType, prerequest.PrefillHeaderHandlerFactory)
- plugins.Register(profile.DataParallelProfileHandlerType, profile.DataParallelProfileHandlerFactory)
- plugins.Register(profile.PdProfileHandlerType, profile.PdProfileHandlerFactory)
- plugins.Register(scorer.PrecisePrefixCachePluginType, scorer.PrecisePrefixCachePluginFactory)
- plugins.Register(scorer.LoadAwareType, scorer.LoadAwareFactory)
- plugins.Register(scorer.SessionAffinityType, scorer.SessionAffinityFactory)
- plugins.Register(scorer.ActiveRequestType, scorer.ActiveRequestFactory)
- plugins.Register(scorer.NoHitLRUType, scorer.NoHitLRUFactory)
+ plugin.Register(filter.ByLabelType, filter.ByLabelFactory)
+ plugin.Register(filter.ByLabelSelectorType, filter.ByLabelSelectorFactory)
+ plugin.Register(filter.DecodeRoleType, filter.DecodeRoleFactory)
+ plugin.Register(filter.PrefillRoleType, filter.PrefillRoleFactory)
+ plugin.Register(prerequest.PrefillHeaderHandlerType, prerequest.PrefillHeaderHandlerFactory)
+ plugin.Register(profile.DataParallelProfileHandlerType, profile.DataParallelProfileHandlerFactory)
+ plugin.Register(profile.PdProfileHandlerType, profile.PdProfileHandlerFactory)
+ plugin.Register(scorer.PrecisePrefixCachePluginType, scorer.PrecisePrefixCachePluginFactory)
+ plugin.Register(scorer.LoadAwareType, scorer.LoadAwareFactory)
+ plugin.Register(scorer.SessionAffinityType, scorer.SessionAffinityFactory)
+ plugin.Register(scorer.ActiveRequestType, scorer.ActiveRequestFactory)
+ plugin.Register(scorer.NoHitLRUType, scorer.NoHitLRUFactory)
+ plugin.Register(models.ModelsDataSourceType, models.ModelDataSourceFactory)
+ plugin.Register(models.ModelsExtractorType, models.ModelServerExtractorFactory)
+ // pd decider plugins
+ plugin.Register(profile.PrefixBasedPDDeciderPluginType, profile.PrefixBasedPDDeciderPluginFactory)
+ plugin.Register(profile.AlwaysDisaggDeciderPluginType, profile.AlwaysDisaggPDDeciderPluginFactory)
}
diff --git a/pkg/plugins/scorer/active_request.go b/pkg/plugins/scorer/active_request.go
index 14e2e4169..a35791b56 100644
--- a/pkg/plugins/scorer/active_request.go
+++ b/pkg/plugins/scorer/active_request.go
@@ -4,17 +4,17 @@ import (
"context"
"encoding/json"
"fmt"
+ "strings"
"sync"
"time"
"github.com/jellydator/ttlcache/v3"
"sigs.k8s.io/controller-runtime/pkg/log"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
- logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
+ logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/common/util/logging"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/requestcontrol"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling"
)
const (
@@ -38,22 +38,34 @@ type ActiveRequestParameters struct {
// requestEntry represents a single request in the cache
type requestEntry struct {
- PodName string
+ PodNames []string
RequestID string
}
// String returns a string representation of the request entry.
-func (r *requestEntry) String() string {
- return fmt.Sprintf("%s.%s", r.PodName, r.RequestID)
+func (r requestEntry) String() string {
+ return fmt.Sprintf("%s:%s", r.RequestID, strings.Join(r.PodNames, "."))
+}
+
+// endpointScores implements logr.Marshaler to lazily convert endpoint keys
+// to strings only when the log line is actually written.
+type endpointScores map[scheduling.Endpoint]float64
+
+func (s endpointScores) MarshalLog() interface{} {
+ result := make(map[string]float64, len(s))
+ for ep, score := range s {
+ result[ep.GetMetadata().NamespacedName.String()] = score
+ }
+ return result
}
// compile-time type assertion
-var _ framework.Scorer = &ActiveRequest{}
+var _ scheduling.Scorer = &ActiveRequest{}
var _ requestcontrol.PreRequest = &ActiveRequest{}
var _ requestcontrol.ResponseComplete = &ActiveRequest{}
// ActiveRequestFactory defines the factory function for the ActiveRequest scorer.
-func ActiveRequestFactory(name string, rawParameters json.RawMessage, handle plugins.Handle) (plugins.Plugin, error) {
+func ActiveRequestFactory(name string, rawParameters json.RawMessage, handle plugin.Handle) (plugin.Plugin, error) {
parameters := ActiveRequestParameters{}
if rawParameters != nil {
if err := json.Unmarshal(rawParameters, ¶meters); err != nil {
@@ -86,18 +98,20 @@ func NewActiveRequest(ctx context.Context, params *ActiveRequestParameters) *Act
)
scorer := &ActiveRequest{
- typedName: plugins.TypedName{Type: ActiveRequestType},
- requestCache: requestCache,
- podCounts: make(map[string]int),
- mutex: &sync.RWMutex{},
+ typedName: plugin.TypedName{Type: ActiveRequestType},
+ requestCache: requestCache,
+ endpointCounts: make(map[string]int),
+ mutex: &sync.RWMutex{},
}
// callback to decrement count when requests expire
// most requests will be removed in ResponseComplete, but this ensures
- // that we don't leak pod counts if ResponseComplete is not called
+ // that we don't leak endpoint counts if ResponseComplete is not called
requestCache.OnEviction(func(_ context.Context, reason ttlcache.EvictionReason,
item *ttlcache.Item[string, *requestEntry]) {
if reason == ttlcache.EvictionReasonExpired {
- scorer.decrementPodCount(item.Value().PodName)
+ for _, endpointName := range item.Value().PodNames {
+ scorer.decrementPodCount(endpointName)
+ }
}
})
@@ -107,20 +121,20 @@ func NewActiveRequest(ctx context.Context, params *ActiveRequestParameters) *Act
}
// ActiveRequest keeps track of individual requests being served
-// per pod.
+// per endpoint.
type ActiveRequest struct {
- typedName plugins.TypedName
+ typedName plugin.TypedName
- // requestCache stores individual request entries with unique composite keys (podName.requestID)
+ // requestCache stores individual request entries with unique composite keys (endpointName.requestID)
requestCache *ttlcache.Cache[string, *requestEntry]
- // podCounts maintains fast lookup for request counts per pod
- podCounts map[string]int
- mutex *sync.RWMutex
+ // endpointCounts maintains fast lookup for request counts per endpoint
+ endpointCounts map[string]int
+ mutex *sync.RWMutex
}
// TypedName returns the typed name of the plugin.
-func (s *ActiveRequest) TypedName() plugins.TypedName {
+func (s *ActiveRequest) TypedName() plugin.TypedName {
return s.typedName
}
@@ -130,110 +144,131 @@ func (s *ActiveRequest) WithName(name string) *ActiveRequest {
return s
}
-// Score scores the given pods based on the number of active requests
-// being served by each pod. The score is normalized to a range of 0-1.
-func (s *ActiveRequest) Score(ctx context.Context, _ *types.CycleState, _ *types.LLMRequest,
- pods []types.Pod) map[types.Pod]float64 {
- scoredPods := make(map[string]int)
+// Category returns the preference the scorer applies when scoring candidate endpoints.
+func (s *ActiveRequest) Category() scheduling.ScorerCategory {
+ return scheduling.Distribution
+}
+
+// Score scores the given endpoints based on the number of active requests
+// being served by each endpoint. The score is normalized to a range of 0-1.
+func (s *ActiveRequest) Score(ctx context.Context, _ *scheduling.CycleState, _ *scheduling.LLMRequest,
+ endpoints []scheduling.Endpoint) map[scheduling.Endpoint]float64 {
+ scoredEndpoints := make(map[string]int)
maxCount := 0
s.mutex.RLock()
- for podName, count := range s.podCounts {
- scoredPods[podName] = count
+ for endpointName, count := range s.endpointCounts {
+ scoredEndpoints[endpointName] = count
if count >= maxCount {
maxCount = count
}
}
s.mutex.RUnlock()
- scoredPodsMap := make(map[types.Pod]float64, len(pods))
- for _, pod := range pods {
- podName := pod.GetPod().NamespacedName.String()
- if count, exists := scoredPods[podName]; exists {
+ log.FromContext(ctx).V(logutil.DEBUG).Info("Active request counts", "endpointCounts", scoredEndpoints, "maxCount", maxCount)
+
+ scoredEndpointsMap := make(map[scheduling.Endpoint]float64, len(endpoints))
+ for _, endpoint := range endpoints {
+ endpointName := endpoint.GetMetadata().NamespacedName.String()
+ if count, exists := scoredEndpoints[endpointName]; exists {
if count == 0 || maxCount == 0 {
- scoredPodsMap[pod] = 1.0 // no requests means highest score
+ scoredEndpointsMap[endpoint] = 1.0 // no requests means highest score
} else {
- scoredPodsMap[pod] = float64(maxCount-count) / float64(maxCount)
+ scoredEndpointsMap[endpoint] = float64(maxCount-count) / float64(maxCount)
}
} else {
- scoredPodsMap[pod] = 1.0
+ scoredEndpointsMap[endpoint] = 1.0
}
}
- log.FromContext(ctx).V(logutil.DEBUG).Info("Scored pods", "scores", scoredPodsMap)
- return scoredPodsMap
+ log.FromContext(ctx).V(logutil.DEBUG).Info("Scored endpoints", "scores", endpointScores(scoredEndpointsMap))
+ return scoredEndpointsMap
}
-// PreRequest is called before a request is sent to the target pod.
+// PreRequest is called before a request is sent to the target endpoint.
// It creates a new request entry in the cache with its own TTL and
-// increments the pod count for fast lookup.
-func (s *ActiveRequest) PreRequest(ctx context.Context, request *types.LLMRequest,
- schedulingResult *types.SchedulingResult) {
+// increments the endpoint count for fast lookup.
+func (s *ActiveRequest) PreRequest(
+ ctx context.Context,
+ request *scheduling.LLMRequest,
+ schedulingResult *scheduling.SchedulingResult,
+) {
debugLogger := log.FromContext(ctx).V(logutil.DEBUG)
- for _, profileResult := range schedulingResult.ProfileResults { // schedulingResult guaranteed not to be nil
- if profileResult == nil || profileResult.TargetPods == nil || len(profileResult.TargetPods) == 0 {
+ endpointNames := make([]string, 0, len(schedulingResult.ProfileResults))
+ for profileName, profileResult := range schedulingResult.ProfileResults {
+ if profileResult == nil || len(profileResult.TargetEndpoints) == 0 {
continue
}
- // create request entry for first pod only. TODO: support fallback pods
- entry := &requestEntry{
- PodName: profileResult.TargetPods[0].GetPod().NamespacedName.String(),
- RequestID: request.RequestId,
- }
-
- // add to request cache with TTL
- s.requestCache.Set(entry.String(), entry, 0) // Use default TTL
- s.incrementPodCount(entry.PodName)
-
- debugLogger.Info("Added request to cache", "requestEntry", entry.String())
+ endpointName := profileResult.TargetEndpoints[0].GetMetadata().NamespacedName.String()
+ endpointNames = append(endpointNames, endpointName)
+ s.incrementPodCount(endpointName)
+ debugLogger.Info(
+ "Added request to cache",
+ "requestId", request.RequestId,
+ "endpointName", endpointName,
+ "profileName", profileName,
+ )
}
+
+ // add to request cache
+ s.requestCache.Set(request.RequestId, &requestEntry{PodNames: endpointNames, RequestID: request.RequestId}, 0) // Use default TTL
}
// ResponseComplete is called after a response is sent to the client.
// It removes the specific request entry from the cache and decrements
-// the pod count.
-func (s *ActiveRequest) ResponseComplete(ctx context.Context, request *types.LLMRequest,
- _ *requestcontrol.Response, targetPod *backend.Pod) {
+// the endpoint count.
+func (s *ActiveRequest) ResponseComplete(
+ ctx context.Context,
+ request *scheduling.LLMRequest,
+ _ *requestcontrol.Response,
+ targetPod *datalayer.EndpointMetadata,
+) {
debugLogger := log.FromContext(ctx).V(logutil.DEBUG).WithName("ActiveRequest.ResponseComplete")
if targetPod == nil {
debugLogger.Info("Skipping ResponseComplete because targetPod is nil")
return
}
- entry := requestEntry{targetPod.NamespacedName.String(), request.RequestId}
-
- if _, found := s.requestCache.GetAndDelete(entry.String()); found {
- s.decrementPodCount(entry.PodName)
- debugLogger.Info("Removed request from cache", "requestEntry", entry.String())
+ if item, found := s.requestCache.GetAndDelete(request.RequestId); found {
+ entry := item.Value()
+ if entry != nil {
+ for _, endpointName := range entry.PodNames {
+ s.decrementPodCount(endpointName)
+ }
+ debugLogger.Info("Removed request from cache", "requestEntry", entry.String())
+ } else {
+ debugLogger.Info("Request entry value is nil", "requestId", request.RequestId)
+ }
} else {
- debugLogger.Info("Request not found in cache", "requestEntry", entry.String())
+ debugLogger.Info("Request not found in cache", "requestId", request.RequestId)
}
}
-// incrementPodCount increments the request count for a pod.
-func (s *ActiveRequest) incrementPodCount(podName string) {
+// incrementPodCount increments the request count for an endpoint.
+func (s *ActiveRequest) incrementPodCount(endpointName string) {
s.mutex.Lock()
defer s.mutex.Unlock()
- s.podCounts[podName]++
+ s.endpointCounts[endpointName]++
}
-// decrementPodCount decrements the request count for a pod and removes
+// decrementPodCount decrements the request count for an endpoint and removes
// the entry if count reaches zero.
-func (s *ActiveRequest) decrementPodCount(podName string) {
+func (s *ActiveRequest) decrementPodCount(endpointName string) {
s.mutex.Lock()
defer s.mutex.Unlock()
- if count, exists := s.podCounts[podName]; exists {
+ if count, exists := s.endpointCounts[endpointName]; exists {
if count <= 1 {
- delete(s.podCounts, podName)
+ delete(s.endpointCounts, endpointName)
} else {
- s.podCounts[podName] = count - 1
+ s.endpointCounts[endpointName] = count - 1
}
}
}
-func cleanCachePeriodically(ctx context.Context, cache *ttlcache.Cache[string, *requestEntry], requestTimeout time.Duration) {
+func cleanCachePeriodically[K comparable, V any](ctx context.Context, cache *ttlcache.Cache[K, V], requestTimeout time.Duration) {
ticker := time.NewTicker(requestTimeout)
defer ticker.Stop()
diff --git a/pkg/plugins/scorer/active_request_test.go b/pkg/plugins/scorer/active_request_test.go
index e7215ce1a..12ab609aa 100644
--- a/pkg/plugins/scorer/active_request_test.go
+++ b/pkg/plugins/scorer/active_request_test.go
@@ -4,84 +4,112 @@ import (
"testing"
"time"
- "github.com/google/go-cmp/cmp"
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/require"
k8stypes "k8s.io/apimachinery/pkg/types"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
- backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
+ fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/requestcontrol"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling"
"github.com/llm-d/llm-d-inference-scheduler/test/utils"
)
-func TestActiveRequestScorer_Score(t *testing.T) {
- podA := &types.PodMetrics{
- Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-a", Namespace: "default"}},
- MetricsState: &backendmetrics.MetricsState{
- WaitingQueueSize: 2,
+// Test helper functions
+
+func newTestEndpoint(name string, queueSize int) scheduling.Endpoint {
+ return scheduling.NewEndpoint(
+ &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: name, Namespace: "default"}},
+ &fwkdl.Metrics{
+ WaitingQueueSize: queueSize,
},
+ nil,
+ )
+}
+
+func newTestRequest(id string) *scheduling.LLMRequest {
+ return &scheduling.LLMRequest{
+ RequestId: id,
}
- podB := &types.PodMetrics{
- Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-b", Namespace: "default"}},
- MetricsState: &backendmetrics.MetricsState{
- WaitingQueueSize: 0,
- },
+}
+
+func newTestSchedulingResult(profileEndpoints map[string]scheduling.Endpoint) *scheduling.SchedulingResult {
+ profileResults := make(map[string]*scheduling.ProfileRunResult)
+ for profile, endpoint := range profileEndpoints {
+ profileResults[profile] = &scheduling.ProfileRunResult{
+ TargetEndpoints: []scheduling.Endpoint{endpoint},
+ }
}
- podC := &types.PodMetrics{
- Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-c", Namespace: "default"}},
- MetricsState: &backendmetrics.MetricsState{
- WaitingQueueSize: 15,
- },
+ return &scheduling.SchedulingResult{
+ ProfileResults: profileResults,
}
+}
+
+func (s *ActiveRequest) getPodCount(endpointName string) int {
+ s.mutex.RLock()
+ defer s.mutex.RUnlock()
+ return s.endpointCounts[endpointName]
+}
+
+func (s *ActiveRequest) hasPodCount(endpointName string) bool {
+ s.mutex.RLock()
+ defer s.mutex.RUnlock()
+ _, exists := s.endpointCounts[endpointName]
+ return exists
+}
+
+func TestActiveRequestScorer_Score(t *testing.T) {
+ endpointA := newTestEndpoint("pod-a", 2)
+ endpointB := newTestEndpoint("pod-b", 0)
+ endpointC := newTestEndpoint("pod-c", 15)
tests := []struct {
name string
setupCache func(*ActiveRequest)
- input []types.Pod
- wantScores map[types.Pod]float64
+ input []scheduling.Endpoint
+ wantScores map[scheduling.Endpoint]float64
}{
{
- name: "no pods in cache",
+ name: "no endpoints in cache",
setupCache: func(_ *ActiveRequest) {
// Cache is empty
},
- input: []types.Pod{podA, podB, podC},
- wantScores: map[types.Pod]float64{
- podA: 1,
- podB: 1,
- podC: 1,
+ input: []scheduling.Endpoint{endpointA, endpointB, endpointC},
+ wantScores: map[scheduling.Endpoint]float64{
+ endpointA: 1,
+ endpointB: 1,
+ endpointC: 1,
},
},
{
- name: "all pods in cache with different request counts",
+ name: "all endpoints in cache with different request counts",
setupCache: func(s *ActiveRequest) {
s.mutex.Lock()
- s.podCounts["default/pod-a"] = 3
- s.podCounts["default/pod-b"] = 0
- s.podCounts["default/pod-c"] = 6
+ s.endpointCounts["default/pod-a"] = 3
+ s.endpointCounts["default/pod-b"] = 0
+ s.endpointCounts["default/pod-c"] = 6
s.mutex.Unlock()
},
- input: []types.Pod{podA, podB, podC},
- wantScores: map[types.Pod]float64{
- podA: 0.5,
- podB: 1.0,
- podC: 0.0,
+ input: []scheduling.Endpoint{endpointA, endpointB, endpointC},
+ wantScores: map[scheduling.Endpoint]float64{
+ endpointA: 0.5,
+ endpointB: 1.0,
+ endpointC: 0.0,
},
},
{
- name: "some pods in cache",
+ name: "some endpoints in cache",
setupCache: func(s *ActiveRequest) {
s.mutex.Lock()
- s.podCounts["default/pod-a"] = 4
- s.podCounts["default/pod-c"] = 1
+ s.endpointCounts["default/pod-a"] = 4
+ s.endpointCounts["default/pod-c"] = 1
// pod-b not in cache
s.mutex.Unlock()
},
- input: []types.Pod{podA, podB, podC},
- wantScores: map[types.Pod]float64{
- podA: 0.0,
- podB: 1.0,
- podC: 0.75,
+ input: []scheduling.Endpoint{endpointA, endpointB, endpointC},
+ wantScores: map[scheduling.Endpoint]float64{
+ endpointA: 0.0,
+ endpointB: 1.0,
+ endpointC: 0.75,
},
},
}
@@ -95,9 +123,7 @@ func TestActiveRequestScorer_Score(t *testing.T) {
got := scorer.Score(ctx, nil, nil, test.input)
- if diff := cmp.Diff(test.wantScores, got); diff != "" {
- t.Errorf("Unexpected output (-want +got): %v", diff)
- }
+ assert.Equal(t, test.wantScores, got)
})
}
}
@@ -106,124 +132,57 @@ func TestActiveRequestScorer_PreRequest(t *testing.T) {
ctx := utils.NewTestContext(t)
scorer := NewActiveRequest(ctx, nil)
- podA := &types.PodMetrics{
- Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-a", Namespace: "default"}},
- MetricsState: &backendmetrics.MetricsState{
- WaitingQueueSize: 2,
- },
- }
-
- request := &types.LLMRequest{
- RequestId: "test-request-1",
- }
-
- schedulingResult := &types.SchedulingResult{
- ProfileResults: map[string]*types.ProfileRunResult{
- "test-profile": {
- TargetPods: []types.Pod{podA},
- },
- },
- }
+ endpointA := newTestEndpoint("pod-a", 2)
+ endpointB := newTestEndpoint("pod-b", 0)
- // First request
- scorer.PreRequest(ctx, request, schedulingResult)
+ testProfile := "test-profile"
- // Check cache and pod counts
- compositeKey := "default/pod-a.test-request-1"
- if !scorer.requestCache.Has(compositeKey) {
- t.Errorf("Expected request to be in cache with key %s", compositeKey)
- }
+ t.Run("First request", func(t *testing.T) {
+ request := newTestRequest("test-request-1")
+ schedulingResult := newTestSchedulingResult(map[string]scheduling.Endpoint{
+ testProfile: endpointA,
+ })
- scorer.mutex.RLock()
- count := scorer.podCounts["default/pod-a"]
- scorer.mutex.RUnlock()
- if count != 1 {
- t.Errorf("Expected pod-a count to be 1, got %d", count)
- }
+ scorer.PreRequest(ctx, request, schedulingResult)
- // Second request with different ID to same pod
- request2 := &types.LLMRequest{
- RequestId: "test-request-2",
- }
- schedulingResult2 := &types.SchedulingResult{
- ProfileResults: map[string]*types.ProfileRunResult{
- "test-profile": {
- TargetPods: []types.Pod{podA},
- },
- },
- }
+ assert.True(t, scorer.requestCache.Has(request.RequestId), "Expected request to be in cache")
+ assert.Equal(t, 1, scorer.getPodCount(endpointA.GetMetadata().NamespacedName.String()))
+ })
- scorer.PreRequest(ctx, request2, schedulingResult2)
+ t.Run("Second request to multiple endpoints", func(t *testing.T) {
+ request := newTestRequest("test-request-2")
+ schedulingResult := newTestSchedulingResult(map[string]scheduling.Endpoint{
+ testProfile: endpointA,
+ "prefill": endpointB,
+ })
- // Check incremented count
- scorer.mutex.RLock()
- count = scorer.podCounts["default/pod-a"]
- scorer.mutex.RUnlock()
- if count != 2 {
- t.Errorf("Expected pod-a count to be 2, got %d", count)
- }
+ scorer.PreRequest(ctx, request, schedulingResult)
- // Check both requests are in cache
- compositeKey2 := "default/pod-a.test-request-2"
- if !scorer.requestCache.Has(compositeKey2) {
- t.Errorf("Expected second request to be in cache with key %s", compositeKey2)
- }
+ assert.True(t, scorer.requestCache.Has(request.RequestId), "Expected request to be in cache")
+ assert.Equal(t, 2, scorer.getPodCount(endpointA.GetMetadata().NamespacedName.String()))
+ assert.Equal(t, 1, scorer.getPodCount(endpointB.GetMetadata().NamespacedName.String()))
+ })
}
func TestActiveRequestScorer_ResponseComplete(t *testing.T) {
ctx := utils.NewTestContext(t)
-
scorer := NewActiveRequest(ctx, nil)
- request := &types.LLMRequest{
- RequestId: "test-request-1",
- }
+ endpointA := newTestEndpoint("pod-a", 2)
+ request := newTestRequest("test-request-1")
- podA := &types.PodMetrics{
- Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-a", Namespace: "default"}},
- MetricsState: &backendmetrics.MetricsState{
- WaitingQueueSize: 2,
- },
- }
// Setup initial state: add request through PreRequest
- schedulingResult := &types.SchedulingResult{
- ProfileResults: map[string]*types.ProfileRunResult{
- "test-profile": {
- TargetPods: []types.Pod{podA},
- },
- },
- }
-
+ schedulingResult := newTestSchedulingResult(map[string]scheduling.Endpoint{
+ "test-profile": endpointA,
+ })
scorer.PreRequest(ctx, request, schedulingResult)
- // Verify initial state
- compositeKey := "default/pod-a.test-request-1"
- if !scorer.requestCache.Has(compositeKey) {
- t.Fatal("Request should be in cache before ResponseComplete")
- }
-
- scorer.mutex.RLock()
- initialCount := scorer.podCounts["default/pod-a"]
- scorer.mutex.RUnlock()
- if initialCount != 1 {
- t.Fatalf("Expected initial count to be 1, got %d", initialCount)
- }
-
- // Call PostResponse
- scorer.ResponseComplete(ctx, request, &requestcontrol.Response{}, podA.GetPod())
+ // Call ResponseComplete
+ scorer.ResponseComplete(ctx, request, &requestcontrol.Response{}, endpointA.GetMetadata())
- // Check request is removed from cache
- if scorer.requestCache.Has(compositeKey) {
- t.Errorf("Request should be removed from cache after ResponseComplete")
- }
-
- // Check pod count is decremented and removed (since it was 1)
- scorer.mutex.RLock()
- _, exists := scorer.podCounts["default/pod-a"]
- scorer.mutex.RUnlock()
- if exists {
- t.Errorf("Pod should be removed from podCounts when count reaches 0")
- }
+ assert.False(t, scorer.requestCache.Has(request.RequestId))
+ assert.False(t, scorer.hasPodCount(endpointA.GetMetadata().NamespacedName.String()),
+ "Pod count should be removed after decrement to zero")
}
func TestActiveRequestScorer_TTLExpiration(t *testing.T) {
@@ -231,34 +190,19 @@ func TestActiveRequestScorer_TTLExpiration(t *testing.T) {
// Use very short timeout for test
params := &ActiveRequestParameters{RequestTimeout: "1s"}
- scorer := NewActiveRequest(ctx, params) // 1 second timeout
-
- request := &types.LLMRequest{
- RequestId: "test-request-ttl",
- }
+ scorer := NewActiveRequest(ctx, params)
- podA := &types.PodMetrics{
- Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-a", Namespace: "default"}},
- }
-
- schedulingResult := &types.SchedulingResult{
- ProfileResults: map[string]*types.ProfileRunResult{
- "test-profile": {
- TargetPods: []types.Pod{podA},
- },
- },
- }
+ endpointA := newTestEndpoint("pod-a", 0)
+ request := newTestRequest("test-request-ttl")
+ schedulingResult := newTestSchedulingResult(map[string]scheduling.Endpoint{
+ "test-profile": endpointA,
+ })
// Add request
scorer.PreRequest(ctx, request, schedulingResult)
// Verify request is added
- scorer.mutex.RLock()
- initialCount := scorer.podCounts["default/pod-a"]
- scorer.mutex.RUnlock()
- if initialCount != 1 {
- t.Fatalf("Expected initial count to be 1, got %d", initialCount)
- }
+ require.Equal(t, 1, scorer.getPodCount("default/pod-a"), "Expected initial count to be 1")
// Wait for TTL expiration
time.Sleep(2 * time.Second)
@@ -266,13 +210,9 @@ func TestActiveRequestScorer_TTLExpiration(t *testing.T) {
// Trigger cleanup
scorer.requestCache.DeleteExpired()
- // Check that pod count is decremented due to TTL expiration
- scorer.mutex.RLock()
- _, exists := scorer.podCounts["default/pod-a"]
- scorer.mutex.RUnlock()
- if exists {
- t.Errorf("Pod should be removed from podCounts after TTL expiration")
- }
+ // Check that endpoint count is decremented due to TTL expiration
+ assert.False(t, scorer.hasPodCount("default/pod-a"),
+ "Pod should be removed from endpointCounts after TTL expiration")
}
func TestNewActiveRequestScorer_InvalidTimeout(t *testing.T) {
@@ -282,9 +222,7 @@ func TestNewActiveRequestScorer_InvalidTimeout(t *testing.T) {
scorer := NewActiveRequest(ctx, params)
// Should use default timeout when invalid value is provided
- if scorer == nil {
- t.Error("Expected scorer to be created even with invalid timeout")
- }
+ assert.NotNil(t, scorer, "Expected scorer to be created even with invalid timeout")
}
func TestActiveRequestScorer_TypedName(t *testing.T) {
@@ -292,10 +230,7 @@ func TestActiveRequestScorer_TypedName(t *testing.T) {
scorer := NewActiveRequest(ctx, nil)
- typedName := scorer.TypedName()
- if typedName.Type != ActiveRequestType {
- t.Errorf("Expected type %s, got %s", ActiveRequestType, typedName.Type)
- }
+ assert.Equal(t, ActiveRequestType, scorer.TypedName().Type)
}
func TestActiveRequestScorer_WithName(t *testing.T) {
@@ -306,7 +241,5 @@ func TestActiveRequestScorer_WithName(t *testing.T) {
scorer = scorer.WithName(testName)
- if scorer.TypedName().Name != testName {
- t.Errorf("Expected name %s, got %s", testName, scorer.TypedName().Name)
- }
+ assert.Equal(t, testName, scorer.TypedName().Name)
}
diff --git a/pkg/plugins/scorer/load_aware.go b/pkg/plugins/scorer/load_aware.go
index c4f86d0bb..4f3ef918b 100644
--- a/pkg/plugins/scorer/load_aware.go
+++ b/pkg/plugins/scorer/load_aware.go
@@ -6,10 +6,9 @@ import (
"fmt"
"sigs.k8s.io/controller-runtime/pkg/log"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
- logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
+ logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/common/util/logging"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling"
)
const (
@@ -25,10 +24,10 @@ type loadAwareParameters struct {
}
// compile-time type assertion
-var _ framework.Scorer = &LoadAware{}
+var _ scheduling.Scorer = &LoadAware{}
// LoadAwareFactory defines the factory function for the LoadAware
-func LoadAwareFactory(name string, rawParameters json.RawMessage, handle plugins.Handle) (plugins.Plugin, error) {
+func LoadAwareFactory(name string, rawParameters json.RawMessage, handle plugin.Handle) (plugin.Plugin, error) {
parameters := loadAwareParameters{Threshold: QueueThresholdDefault}
if rawParameters != nil {
if err := json.Unmarshal(rawParameters, ¶meters); err != nil {
@@ -47,19 +46,19 @@ func NewLoadAware(ctx context.Context, queueThreshold int) *LoadAware {
}
return &LoadAware{
- typedName: plugins.TypedName{Type: LoadAwareType},
+ typedName: plugin.TypedName{Type: LoadAwareType},
queueThreshold: float64(queueThreshold),
}
}
// LoadAware scorer that is based on load
type LoadAware struct {
- typedName plugins.TypedName
+ typedName plugin.TypedName
queueThreshold float64
}
// TypedName returns the typed name of the plugin.
-func (s *LoadAware) TypedName() plugins.TypedName {
+func (s *LoadAware) TypedName() plugin.TypedName {
return s.typedName
}
@@ -69,6 +68,11 @@ func (s *LoadAware) WithName(name string) *LoadAware {
return s
}
+// Category returns the preference the scorer applies when scoring candidate endpoints.
+func (s *LoadAware) Category() scheduling.ScorerCategory {
+ return scheduling.Distribution
+}
+
// Score scores the given pod in range of 0-1
// Currently metrics contains number of requests waiting in the queue, there is no information about number of requests
// that can be processed in the given pod immediately.
@@ -76,20 +80,20 @@ func (s *LoadAware) WithName(name string) *LoadAware {
// Pod with requests in the queue will get score between 0.5 and 0.
// Score 0 will get pod with number of requests in the queue equal to the threshold used in load-based filter
// In the future, pods with additional capacity will get score higher than 0.5
-func (s *LoadAware) Score(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
- scoredPods := make(map[types.Pod]float64)
+func (s *LoadAware) Score(_ context.Context, _ *scheduling.CycleState, _ *scheduling.LLMRequest, endpoints []scheduling.Endpoint) map[scheduling.Endpoint]float64 {
+ scoredEndpoints := make(map[scheduling.Endpoint]float64)
- for _, pod := range pods {
- waitingRequests := float64(pod.GetMetrics().WaitingQueueSize)
+ for _, endpoint := range endpoints {
+ waitingRequests := float64(endpoint.GetMetrics().WaitingQueueSize)
if waitingRequests == 0 {
- scoredPods[pod] = 0.5
+ scoredEndpoints[endpoint] = 0.5
} else {
if waitingRequests > s.queueThreshold {
waitingRequests = s.queueThreshold
}
- scoredPods[pod] = 0.5 * (1.0 - (waitingRequests / s.queueThreshold))
+ scoredEndpoints[endpoint] = 0.5 * (1.0 - (waitingRequests / s.queueThreshold))
}
}
- return scoredPods
+ return scoredEndpoints
}
diff --git a/pkg/plugins/scorer/load_aware_test.go b/pkg/plugins/scorer/load_aware_test.go
index e693e99b5..c3e43a94b 100644
--- a/pkg/plugins/scorer/load_aware_test.go
+++ b/pkg/plugins/scorer/load_aware_test.go
@@ -6,56 +6,57 @@ import (
"github.com/google/go-cmp/cmp"
- k8stypes "k8s.io/apimachinery/pkg/types"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
- backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" // Import config for thresholds
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
+ k8stypes "k8s.io/apimachinery/pkg/types" // Import config for thresholds
+ fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling"
"github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/scorer"
"github.com/llm-d/llm-d-inference-scheduler/test/utils"
)
func TestLoadBasedScorer(t *testing.T) {
- podA := &types.PodMetrics{
- Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-a"}},
- MetricsState: &backendmetrics.MetricsState{
+ endpointA := scheduling.NewEndpoint(
+ &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod-a"}},
+ &fwkdl.Metrics{
WaitingQueueSize: 2,
},
- }
- podB := &types.PodMetrics{
- Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-b"}},
- MetricsState: &backendmetrics.MetricsState{
+ nil,
+ )
+ endpointB := scheduling.NewEndpoint(
+ &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod-b"}},
+ &fwkdl.Metrics{
WaitingQueueSize: 0,
},
- }
- podC := &types.PodMetrics{
- Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-c"}},
- MetricsState: &backendmetrics.MetricsState{
+ nil,
+ )
+ endpointC := scheduling.NewEndpoint(
+ &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod-c"}},
+ &fwkdl.Metrics{
WaitingQueueSize: 15,
},
- }
+ nil,
+ )
tests := []struct {
name string
- scorer framework.Scorer
- req *types.LLMRequest
- input []types.Pod
- wantScores map[types.Pod]float64
+ scorer scheduling.Scorer
+ req *scheduling.LLMRequest
+ input []scheduling.Endpoint
+ wantScores map[scheduling.Endpoint]float64
}{
{
name: "load based scorer",
scorer: scorer.NewLoadAware(utils.NewTestContext(t), 10),
- req: &types.LLMRequest{
+ req: &scheduling.LLMRequest{
TargetModel: "critical",
},
- input: []types.Pod{
- podA, podB, podC,
+ input: []scheduling.Endpoint{
+ endpointA, endpointB, endpointC,
},
- wantScores: map[types.Pod]float64{
- podA: 0.4,
- podB: 0.5,
- podC: 0,
+ wantScores: map[scheduling.Endpoint]float64{
+ endpointA: 0.4,
+ endpointB: 0.5,
+ endpointC: 0,
},
},
}
diff --git a/pkg/plugins/scorer/no_hit_lru.go b/pkg/plugins/scorer/no_hit_lru.go
index 417cf05a5..65182199a 100644
--- a/pkg/plugins/scorer/no_hit_lru.go
+++ b/pkg/plugins/scorer/no_hit_lru.go
@@ -7,24 +7,29 @@ import (
lru "github.com/hashicorp/golang-lru/v2"
"sigs.k8s.io/controller-runtime/pkg/log"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/multi/prefix"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
- logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
+ logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/common/util/logging"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/requestcontrol"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/plugins/scheduling/scorer/prefix"
)
const (
// NoHitLRUType is the type of the NoHitLRU scorer
NoHitLRUType = "no-hit-lru-scorer"
- // defaultLRUSize is the maximum number of pods we'll consider in the cache
+ // defaultLRUSize is the maximum number of endpoints we'll consider in the cache
defaultLRUSize = 1024
+
+ // defaultPrefillProfile is the name of the prefill profile
+ //
+ // This is currently hardcoded until we have a defined proper config interface.
+ // (See also https://github.com/kubernetes-sigs/gateway-api-inference-extension/pull/2104/ )
+ defaultPrefillProfile = "prefill"
)
// compile-time type assertions
-var _ framework.Scorer = &NoHitLRU{}
+var _ scheduling.Scorer = &NoHitLRU{}
var _ requestcontrol.PreRequest = &NoHitLRU{}
// NoHitLRUParameters defines the parameters for the NoHitLRU scorer.
@@ -36,7 +41,7 @@ type NoHitLRUParameters struct {
// Defaults to "prefix-cache-scorer".
PrefixPluginName string `json:"prefixPluginName"`
- // LRUSize defines the maximum number of pods to track in the LRU cache.
+ // LRUSize defines the maximum number of endpoints to track in the LRU cache.
LRUSize int `json:"lruSize"`
}
@@ -46,13 +51,13 @@ type coldRequestState struct {
isCold bool
}
-// Clone implements the plugins.StateData interface
-func (c *coldRequestState) Clone() plugins.StateData {
+// Clone implements the plugin.StateData interface
+func (c *coldRequestState) Clone() plugin.StateData {
return &coldRequestState{isCold: c.isCold}
}
// NoHitLRUFactory defines the factory function for the NoHitLRU
-func NoHitLRUFactory(name string, rawParameters json.RawMessage, handle plugins.Handle) (plugins.Plugin, error) {
+func NoHitLRUFactory(name string, rawParameters json.RawMessage, handle plugin.Handle) (plugin.Plugin, error) {
parameters := NoHitLRUParameters{}
if rawParameters != nil {
if err := json.Unmarshal(rawParameters, ¶meters); err != nil {
@@ -95,25 +100,25 @@ func NewNoHitLRU(ctx context.Context, params *NoHitLRUParameters) *NoHitLRU {
}
return &NoHitLRU{
- typedName: plugins.TypedName{Type: NoHitLRUType},
+ typedName: plugin.TypedName{Type: NoHitLRUType},
lruCache: lruCache,
- prefixPluginTypedName: plugins.TypedName{Type: prefixPluginType, Name: prefixPluginName},
- pluginState: plugins.NewPluginState(ctx),
+ prefixPluginTypedName: plugin.TypedName{Type: prefixPluginType, Name: prefixPluginName},
+ pluginState: plugin.NewPluginState(ctx),
}
}
-// NoHitLRU scorer that favors pods that were least recently used for cold requests.
+// NoHitLRU scorer that favors endpoints that were least recently used for cold requests.
// This can help evenly distribute cache growth, since cold requests result in more
// new KV blocks.
type NoHitLRU struct {
- typedName plugins.TypedName
- lruCache *lru.Cache[string, struct{}] // pod name -> dummy value (we only care about order)
- prefixPluginTypedName plugins.TypedName
- pluginState *plugins.PluginState
+ typedName plugin.TypedName
+ lruCache *lru.Cache[string, struct{}] // endpoint name -> dummy value (we only care about order)
+ prefixPluginTypedName plugin.TypedName
+ pluginState *plugin.PluginState
}
// TypedName returns the typed name of the plugin.
-func (s *NoHitLRU) TypedName() plugins.TypedName {
+func (s *NoHitLRU) TypedName() plugin.TypedName {
return s.typedName
}
@@ -123,14 +128,19 @@ func (s *NoHitLRU) WithName(name string) *NoHitLRU {
return s
}
+// Category returns the preference the scorer applies when scoring candidate endpoints.
+func (s *NoHitLRU) Category() scheduling.ScorerCategory {
+ return scheduling.Distribution
+}
+
// isColdRequest determines if a request is cold by reading the prefix cache state.
// Returns true if no prefix cache hits were found, or if prefix cache state is unavailable.
-func (s *NoHitLRU) isColdRequest(ctx context.Context, cycleState *types.CycleState) bool {
+func (s *NoHitLRU) isColdRequest(ctx context.Context, cycleState *scheduling.CycleState) bool {
logger := log.FromContext(ctx).V(logutil.DEBUG)
// Read prefix cache state to determine if this is a cold request
// This is treated as an optimization - if the state isn't available, we assume cold request
- prefixState, err := types.ReadCycleStateKey[*prefix.SchedulingContextState](cycleState, plugins.StateKey(s.prefixPluginTypedName.String()))
+ prefixState, err := scheduling.ReadCycleStateKey[*prefix.SchedulingContextState](cycleState, plugin.StateKey(s.prefixPluginTypedName.String()))
if err != nil {
logger.Info("No prefix cache state found, treating as cold request for LRU optimization", "error", err)
@@ -141,17 +151,17 @@ func (s *NoHitLRU) isColdRequest(ctx context.Context, cycleState *types.CycleSta
return len(prefixState.PrefixCacheServers) == 0
}
-// scoreNeutral returns neutral scores (0.5) for all pods.
+// scoreNeutral returns neutral scores (0.5) for all endpoints.
// Used when a request has cache hits and LRU optimization should not apply.
-func (s *NoHitLRU) scoreNeutral(pods []types.Pod) map[types.Pod]float64 {
- scoredPods := make(map[types.Pod]float64, len(pods))
- for _, pod := range pods {
- scoredPods[pod] = 0.5
+func (s *NoHitLRU) scoreNeutral(endpoints []scheduling.Endpoint) map[scheduling.Endpoint]float64 {
+ scoredEndpoints := make(map[scheduling.Endpoint]float64, len(endpoints))
+ for _, endpoint := range endpoints {
+ scoredEndpoints[endpoint] = 0.5
}
- return scoredPods
+ return scoredEndpoints
}
-// getLRUPositions returns a map of pod names to their LRU position.
+// getLRUPositions returns a map of endpoint names to their LRU position.
// Position 0 represents the oldest (least recently used) entry.
func (s *NoHitLRU) getLRUPositions() map[string]int {
// Get all keys from LRU cache in order (oldest first)
@@ -165,105 +175,105 @@ func (s *NoHitLRU) getLRUPositions() map[string]int {
return lruPosition
}
-// partitionPodsByUsage separates pods into those that have received cold requests
+// partitionPodsByUsage separates endpoints into those that have received cold requests
// (usedPods) and those that have never received cold requests (neverUsedPods).
-func (s *NoHitLRU) partitionPodsByUsage(pods []types.Pod, lruPosition map[string]int) (usedPods, neverUsedPods []types.Pod) {
- for _, pod := range pods {
- podName := pod.GetPod().NamespacedName.String()
- if _, exists := lruPosition[podName]; exists {
- usedPods = append(usedPods, pod)
+func (s *NoHitLRU) partitionPodsByUsage(endpoints []scheduling.Endpoint, lruPosition map[string]int) (usedEndpoints, neverUsedEndpoints []scheduling.Endpoint) {
+ for _, endpoint := range endpoints {
+ endpointName := endpoint.GetMetadata().NamespacedName.String()
+ if _, exists := lruPosition[endpointName]; exists {
+ usedEndpoints = append(usedEndpoints, endpoint)
} else {
- neverUsedPods = append(neverUsedPods, pod)
+ neverUsedEndpoints = append(neverUsedEndpoints, endpoint)
}
}
- return usedPods, neverUsedPods
+ return usedEndpoints, neverUsedEndpoints
}
-// scoreNeverUsedPods assigns scores to pods that have never received a cold request.
-// The first never-used pod gets the highest score (1.0), with subsequent pods
+// scoreNeverUsedEndpoints assigns scores to endpoints that have never received a cold request.
+// The first never-used endpoint gets the highest score (1.0), with subsequent endpoints
// receiving progressively lower scores.
-func (s *NoHitLRU) scoreNeverUsedPods(scoredPods map[types.Pod]float64, neverUsedPods []types.Pod, totalPods int) {
+func (s *NoHitLRU) scoreNeverUsedPods(scoredPods map[scheduling.Endpoint]float64, neverUsedPods []scheduling.Endpoint, totalEndpoints int) {
// Avoid possibility of dividing by zero.
- if totalPods <= 1 {
+ if totalEndpoints <= 1 {
return
}
- for i, pod := range neverUsedPods {
- score := 1.0 - float64(i)/float64(totalPods-1)
- scoredPods[pod] = score
+ for i, endpoint := range neverUsedPods {
+ score := 1.0 - float64(i)/float64(totalEndpoints-1)
+ scoredPods[endpoint] = score
}
}
-// scoreUsedPods assigns scores to pods based on their LRU position.
+// scoreUsedPods assigns scores to endpoints based on their LRU position.
// Pods that were least recently used for cold requests receive higher scores.
-func (s *NoHitLRU) scoreUsedPods(scoredPods map[types.Pod]float64, usedPods []types.Pod, lruPosition map[string]int, neverUsedCount, totalPods int) {
+func (s *NoHitLRU) scoreUsedPods(scoredEndpoints map[scheduling.Endpoint]float64, usedPods []scheduling.Endpoint, lruPosition map[string]int, neverUsedCount, totalEndpoints int) {
// Avoid possibility of dividing by zero.
- if totalPods <= 1 {
+ if totalEndpoints <= 1 {
return
}
- for _, pod := range usedPods {
- podName := pod.GetPod().NamespacedName.String()
- lruPos := lruPosition[podName]
+ for _, endpoint := range usedPods {
+ endpointName := endpoint.GetMetadata().NamespacedName.String()
+ lruPos := lruPosition[endpointName]
// LRU keys are oldest to newest so rank 0 = oldest
- // The never used pod count is added to the rank so that
- // a never-used pod will always have the highest score.
+ // The never used endpoint count is added to the rank so that
+ // a never-used endpoint will always have the highest score.
rank := neverUsedCount + lruPos
- score := 1.0 - float64(rank)/float64(totalPods-1)
+ score := 1.0 - float64(rank)/float64(totalEndpoints-1)
if score < 0 {
score = 0
}
- scoredPods[pod] = score
+ scoredEndpoints[endpoint] = score
}
}
-// scoreColdRequestByLRU scores pods based on their LRU position for cold requests.
+// scoreColdRequestByLRU scores endpoints based on their LRU position for cold requests.
// Pods that have never received a cold request get the highest scores.
-// Among previously used pods, least recently used ones get higher scores.
-func (s *NoHitLRU) scoreColdRequestByLRU(pods []types.Pod) map[types.Pod]float64 {
- scoredPods := make(map[types.Pod]float64, len(pods))
- totalPods := len(pods)
+// Among previously used endpoints, least recently used ones get higher scores.
+func (s *NoHitLRU) scoreColdRequestByLRU(endpoints []scheduling.Endpoint) map[scheduling.Endpoint]float64 {
+ scoredEndpoints := make(map[scheduling.Endpoint]float64, len(endpoints))
+ totalEndpoints := len(endpoints)
// Avoid possibility of dividing by zero.
- if totalPods == 1 {
- scoredPods[pods[0]] = 1.0
- return scoredPods
+ if totalEndpoints == 1 {
+ scoredEndpoints[endpoints[0]] = 1.0
+ return scoredEndpoints
}
lruPosition := s.getLRUPositions()
- usedPods, neverUsedPods := s.partitionPodsByUsage(pods, lruPosition)
+ usedEndpoints, neverUsedEndpoints := s.partitionPodsByUsage(endpoints, lruPosition)
- s.scoreNeverUsedPods(scoredPods, neverUsedPods, totalPods)
- s.scoreUsedPods(scoredPods, usedPods, lruPosition, len(neverUsedPods), totalPods)
+ s.scoreNeverUsedPods(scoredEndpoints, neverUsedEndpoints, totalEndpoints)
+ s.scoreUsedPods(scoredEndpoints, usedEndpoints, lruPosition, len(neverUsedEndpoints), totalEndpoints)
- return scoredPods
+ return scoredEndpoints
}
-// Score scores the given pods based on LRU for cold requests.
-// For cache hits, returns neutral scores (0.5) for all pods.
-// For cache misses, ranks pods by their LRU order.
-// - LRU ordering is with respect to when a pod last received a cold request.
-// - Least recently used (or never used) pods get highest score (1.0)
-// - Most recently used pods get lowest score (approaching 0.0)
-func (s *NoHitLRU) Score(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
+// Score scores the given endpoints based on LRU for cold requests.
+// For cache hits, returns neutral scores (0.5) for all endpoints.
+// For cache misses, ranks endpoints by their LRU order.
+// - LRU ordering is with respect to when a endpoint last received a cold request.
+// - Least recently used (or never used) endpoints get highest score (1.0)
+// - Most recently used endpoints get lowest score (approaching 0.0)
+func (s *NoHitLRU) Score(ctx context.Context, cycleState *scheduling.CycleState, request *scheduling.LLMRequest, endpoints []scheduling.Endpoint) map[scheduling.Endpoint]float64 {
logger := log.FromContext(ctx).V(logutil.DEBUG)
isCold := s.isColdRequest(ctx, cycleState)
// Store the cold request state in plugin state for PreRequest to use
coldState := &coldRequestState{isCold: isCold}
- s.pluginState.Write(request.RequestId, plugins.StateKey(s.typedName.String()), coldState)
+ s.pluginState.Write(request.RequestId, plugin.StateKey(s.typedName.String()), coldState)
if !isCold {
logger.Info("Cache hit detected, returning neutral scores")
- return s.scoreNeutral(pods)
+ return s.scoreNeutral(endpoints)
}
- logger.Info("Cold request detected, scoring pods by LRU")
- return s.scoreColdRequestByLRU(pods)
+ logger.Info("Cold request detected, scoring endpoints by LRU")
+ return s.scoreColdRequestByLRU(endpoints)
}
-// PreRequest is called before a request is sent to the target pod.
-// For cold requests, it updates the LRU cache to track which pods have been used recently.
-func (s *NoHitLRU) PreRequest(ctx context.Context, request *types.LLMRequest, schedulingResult *types.SchedulingResult) {
+// PreRequest is called before a request is sent to the target endpoint.
+// For cold requests, it updates the LRU cache to track which endpoints have been used recently.
+func (s *NoHitLRU) PreRequest(ctx context.Context, request *scheduling.LLMRequest, schedulingResult *scheduling.SchedulingResult) {
logger := log.FromContext(ctx).V(logutil.DEBUG)
if schedulingResult == nil || len(schedulingResult.ProfileResults) == 0 {
@@ -272,7 +282,7 @@ func (s *NoHitLRU) PreRequest(ctx context.Context, request *types.LLMRequest, sc
}
// Read the cold request state we stored in Score
- coldState, err := plugins.ReadPluginStateKey[*coldRequestState](s.pluginState, request.RequestId, plugins.StateKey(s.typedName.String()))
+ coldState, err := plugin.ReadPluginStateKey[*coldRequestState](s.pluginState, request.RequestId, plugin.StateKey(s.typedName.String()))
// After fetching the cold state, drop it from the plugin state immediately (otherwise it will hang around until it becomes stale).
s.pluginState.Delete(request.RequestId)
@@ -286,19 +296,23 @@ func (s *NoHitLRU) PreRequest(ctx context.Context, request *types.LLMRequest, sc
return
}
- // Get the primary profile's target pod
- primaryProfile := schedulingResult.ProfileResults[schedulingResult.PrimaryProfileName]
- if primaryProfile == nil || len(primaryProfile.TargetPods) == 0 {
- logger.Info("No target pod in primary profile")
- return
+ if targetProfile, ok := schedulingResult.ProfileResults[schedulingResult.PrimaryProfileName]; ok && targetProfile != nil && len(targetProfile.TargetEndpoints) != 0 {
+ s.moveTargetPodToFront(ctx, request, targetProfile, schedulingResult.PrimaryProfileName)
}
+ if targetProfile, ok := schedulingResult.ProfileResults[defaultPrefillProfile]; ok && targetProfile != nil && len(targetProfile.TargetEndpoints) != 0 {
+ s.moveTargetPodToFront(ctx, request, targetProfile, defaultPrefillProfile)
+ }
+}
+
+func (s *NoHitLRU) moveTargetPodToFront(ctx context.Context, request *scheduling.LLMRequest, targetProfile *scheduling.ProfileRunResult, profileName string) {
+ logger := log.FromContext(ctx).V(logutil.DEBUG)
- targetPod := primaryProfile.TargetPods[0]
- podName := targetPod.GetPod().NamespacedName.String()
+ targetPod := targetProfile.TargetEndpoints[0]
+ endpointName := targetPod.GetMetadata().NamespacedName.String()
- // Move the pod to the front of the LRU.
+ // Move the endpoint to the front of the LRU.
var present struct{} // dummy value
- s.lruCache.Add(podName, present)
+ s.lruCache.Add(endpointName, present)
- logger.Info("Updated LRU cache for cold request", "pod", podName, "requestId", request.RequestId)
+ logger.Info("Updated LRU cache for cold request", "profile", profileName, "endpoint", endpointName, "requestId", request.RequestId)
}
diff --git a/pkg/plugins/scorer/no_hit_lru_test.go b/pkg/plugins/scorer/no_hit_lru_test.go
index 6890c998c..2af62b021 100644
--- a/pkg/plugins/scorer/no_hit_lru_test.go
+++ b/pkg/plugins/scorer/no_hit_lru_test.go
@@ -9,49 +9,47 @@ import (
"github.com/google/go-cmp/cmp"
k8stypes "k8s.io/apimachinery/pkg/types"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
- backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/multi/prefix"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
+ fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/plugins/scheduling/scorer/prefix"
"github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/scorer"
"github.com/llm-d/llm-d-inference-scheduler/test/utils"
)
-var _ plugins.Handle = &fakeHandle{}
+var _ plugin.Handle = &fakeHandle{}
type fakeHandle struct {
ctx context.Context
- plugins map[string]plugins.Plugin
+ plugins map[string]plugin.Plugin
}
func newFakeHandle(ctx context.Context) *fakeHandle {
- return &fakeHandle{ctx: ctx, plugins: map[string]plugins.Plugin{}}
+ return &fakeHandle{ctx: ctx, plugins: map[string]plugin.Plugin{}}
}
func (h *fakeHandle) Context() context.Context {
return h.ctx
}
-func (h *fakeHandle) Plugin(name string) plugins.Plugin {
+func (h *fakeHandle) Plugin(name string) plugin.Plugin {
return h.plugins[name]
}
-func (h *fakeHandle) AddPlugin(name string, plugin plugins.Plugin) {
+func (h *fakeHandle) AddPlugin(name string, plugin plugin.Plugin) {
h.plugins[name] = plugin
}
-func (h *fakeHandle) GetAllPlugins() []plugins.Plugin {
- result := make([]plugins.Plugin, 0, len(h.plugins))
+func (h *fakeHandle) GetAllPlugins() []plugin.Plugin {
+ result := make([]plugin.Plugin, 0, len(h.plugins))
for _, plugin := range h.plugins {
result = append(result, plugin)
}
return result
}
-func (h *fakeHandle) GetAllPluginsWithNames() map[string]plugins.Plugin {
+func (h *fakeHandle) GetAllPluginsWithNames() map[string]plugin.Plugin {
return h.plugins
}
@@ -60,10 +58,10 @@ func (h *fakeHandle) PodList() []k8stypes.NamespacedName {
}
type stubPlugin struct {
- name plugins.TypedName
+ name plugin.TypedName
}
-func (p *stubPlugin) TypedName() plugins.TypedName {
+func (p *stubPlugin) TypedName() plugin.TypedName {
return p.name
}
@@ -84,7 +82,7 @@ func TestNoHitLRUFactoryDependencyValidation(t *testing.T) {
name: "prefix plugin present - should work",
handle: func() *fakeHandle {
h := newFakeHandle(utils.NewTestContext(t))
- h.AddPlugin(prefix.PrefixCachePluginType, &stubPlugin{name: plugins.TypedName{Type: prefix.PrefixCachePluginType, Name: prefix.PrefixCachePluginType}})
+ h.AddPlugin(prefix.PrefixCachePluginType, &stubPlugin{name: plugin.TypedName{Type: prefix.PrefixCachePluginType, Name: prefix.PrefixCachePluginType}})
return h
}(),
expectError: false,
@@ -123,87 +121,90 @@ func TestNoHitLRUFactoryDependencyValidation(t *testing.T) {
}
func TestNoHitLRUScorer(t *testing.T) {
- podA := &types.PodMetrics{
- Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-a"}},
- MetricsState: &backendmetrics.MetricsState{},
- }
- podB := &types.PodMetrics{
- Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-b"}},
- MetricsState: &backendmetrics.MetricsState{},
- }
- podC := &types.PodMetrics{
- Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-c"}},
- MetricsState: &backendmetrics.MetricsState{},
- }
+ endpointA := scheduling.NewEndpoint(
+ &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod-a"}},
+ &fwkdl.Metrics{},
+ nil,
+ )
+ endpointB := scheduling.NewEndpoint(
+ &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod-b"}},
+ &fwkdl.Metrics{},
+ nil,
+ )
+ endpointC := scheduling.NewEndpoint(
+ &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod-c"}},
+ &fwkdl.Metrics{},
+ nil,
+ )
tests := []struct {
name string
- scorer framework.Scorer
- req *types.LLMRequest
- input []types.Pod
+ scorer scheduling.Scorer
+ req *scheduling.LLMRequest
+ input []scheduling.Endpoint
prefixState *prefix.SchedulingContextState
- wantScores map[types.Pod]float64
+ wantScores map[scheduling.Endpoint]float64
description string
}{
{
- name: "cold request - all pods never used",
+ name: "cold request - all endpoints never used",
scorer: scorer.NewNoHitLRU(utils.NewTestContext(t), nil),
- req: &types.LLMRequest{
+ req: &scheduling.LLMRequest{
TargetModel: "test-model",
},
- input: []types.Pod{podA, podB, podC},
+ input: []scheduling.Endpoint{endpointA, endpointB, endpointC},
prefixState: &prefix.SchedulingContextState{
PrefixCacheServers: make(map[prefix.ServerID]int), // empty = cold request
},
- wantScores: map[types.Pod]float64{
- podA: 1.0, // All never-used pods get high scores
- podB: 0.5,
- podC: 0.0,
+ wantScores: map[scheduling.Endpoint]float64{
+ endpointA: 1.0, // All never-used endpoints get high scores
+ endpointB: 0.5,
+ endpointC: 0.0,
},
- description: "Never-used pods should get high scores for cold requests",
+ description: "Never-used endpoints should get high scores for cold requests",
},
{
name: "cache hit - neutral scores",
scorer: scorer.NewNoHitLRU(utils.NewTestContext(t), nil),
- req: &types.LLMRequest{
+ req: &scheduling.LLMRequest{
TargetModel: "test-model",
},
- input: []types.Pod{podA, podB, podC},
+ input: []scheduling.Endpoint{endpointA, endpointB, endpointC},
prefixState: &prefix.SchedulingContextState{
PrefixCacheServers: map[prefix.ServerID]int{
{Name: "server1", Namespace: "default"}: 5, // non-empty = cache hit
},
},
- wantScores: map[types.Pod]float64{
- podA: 0.5, // All pods get neutral scores for cache hits
- podB: 0.5,
- podC: 0.5,
+ wantScores: map[scheduling.Endpoint]float64{
+ endpointA: 0.5, // All endpoints get neutral scores for cache hits
+ endpointB: 0.5,
+ endpointC: 0.5,
},
description: "Cache hits should return neutral scores",
},
{
- name: "single pod - max score",
+ name: "single endpoint - max score",
scorer: scorer.NewNoHitLRU(utils.NewTestContext(t), nil),
- req: &types.LLMRequest{
+ req: &scheduling.LLMRequest{
TargetModel: "test-model",
},
- input: []types.Pod{podA},
+ input: []scheduling.Endpoint{endpointA},
prefixState: &prefix.SchedulingContextState{
PrefixCacheServers: make(map[prefix.ServerID]int), // empty = cold request
},
- wantScores: map[types.Pod]float64{
- podA: 1.0, // Single pod gets max score
+ wantScores: map[scheduling.Endpoint]float64{
+ endpointA: 1.0, // Single endpoint gets max score
},
- description: "Single pod should get maximum score",
+ description: "Single endpoint should get maximum score",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
// Create cycle state and set prefix state
- cycleState := &types.CycleState{}
+ cycleState := &scheduling.CycleState{}
if test.prefixState != nil {
- cycleState.Write(plugins.StateKey(plugins.TypedName{Type: prefix.PrefixCachePluginType,
+ cycleState.Write(plugin.StateKey(plugin.TypedName{Type: prefix.PrefixCachePluginType,
Name: prefix.PrefixCachePluginType}.String()), test.prefixState)
}
@@ -221,42 +222,44 @@ func TestNoHitLRUBasicFunctionality(t *testing.T) {
scorer := scorer.NewNoHitLRU(ctx, nil)
- podA := &types.PodMetrics{
- Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-a"}},
- MetricsState: &backendmetrics.MetricsState{},
- }
- podB := &types.PodMetrics{
- Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-b"}},
- MetricsState: &backendmetrics.MetricsState{},
- }
+ endpointA := scheduling.NewEndpoint(
+ &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod-a"}},
+ &fwkdl.Metrics{},
+ nil,
+ )
+ endpointB := scheduling.NewEndpoint(
+ &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod-b"}},
+ &fwkdl.Metrics{},
+ nil,
+ )
- pods := []types.Pod{podA, podB}
+ endpoints := []scheduling.Endpoint{endpointA, endpointB}
// Test basic scoring for cold request (no crashes, returns valid scores)
coldPrefixState := &prefix.SchedulingContextState{
PrefixCacheServers: make(map[prefix.ServerID]int), // empty = cold request
}
- cycleState := &types.CycleState{}
- cycleState.Write(plugins.StateKey(plugins.TypedName{Type: prefix.PrefixCachePluginType,
+ cycleState := &scheduling.CycleState{}
+ cycleState.Write(plugin.StateKey(plugin.TypedName{Type: prefix.PrefixCachePluginType,
Name: prefix.PrefixCachePluginType}.String()), coldPrefixState)
- scores := scorer.Score(ctx, cycleState, &types.LLMRequest{}, pods)
+ scores := scorer.Score(ctx, cycleState, &scheduling.LLMRequest{}, endpoints)
- // Should return scores for all pods
+ // Should return scores for all endpoints
if len(scores) != 2 {
t.Errorf("Expected 2 scores, got %d", len(scores))
}
// All scores should be valid (between 0 and 1)
- for pod, score := range scores {
+ for endpoint, score := range scores {
if score < 0 || score > 1 {
- t.Errorf("Invalid score %f for pod %s", score, pod.GetPod().NamespacedName.String())
+ t.Errorf("Invalid score %f for endpoint %s", score, endpoint.GetMetadata().NamespacedName.String())
}
}
- // For never-used pods, should have different scores (to provide ordering)
- if scores[podA] == scores[podB] {
- t.Errorf("Expected different scores for different pods, both got %f", scores[podA])
+ // For never-used endpoints, should have different scores (to provide ordering)
+ if scores[endpointA] == scores[endpointB] {
+ t.Errorf("Expected different scores for different endpoints, both got %f", scores[endpointA])
}
}
@@ -264,16 +267,17 @@ func TestNoPrefixCacheStateFound(t *testing.T) {
ctx := utils.NewTestContext(t)
scorer := scorer.NewNoHitLRU(ctx, nil)
- podA := &types.PodMetrics{
- Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-a"}},
- MetricsState: &backendmetrics.MetricsState{},
- }
- pods := []types.Pod{podA}
- cycleState := &types.CycleState{}
+ endpointA := scheduling.NewEndpoint(
+ &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod-a"}},
+ &fwkdl.Metrics{},
+ nil,
+ )
+ endpoints := []scheduling.Endpoint{endpointA}
+ cycleState := &scheduling.CycleState{}
- scores := scorer.Score(ctx, cycleState, &types.LLMRequest{}, pods)
+ scores := scorer.Score(ctx, cycleState, &scheduling.LLMRequest{}, endpoints)
- if scores[podA] != 1.0 {
+ if scores[endpointA] != 1.0 {
t.Errorf("Failure to find a prefix cache should result in scoring as a cold request.")
}
}
@@ -282,120 +286,123 @@ func TestNoHitLRUPreferLeastRecentlyUsedAfterColdRequests(t *testing.T) {
ctx := utils.NewTestContext(t)
scorer := scorer.NewNoHitLRU(ctx, nil)
- podA := &types.PodMetrics{
- Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-a", Namespace: "default"}},
- MetricsState: &backendmetrics.MetricsState{},
- }
- podB := &types.PodMetrics{
- Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-b", Namespace: "default"}},
- MetricsState: &backendmetrics.MetricsState{},
- }
- podC := &types.PodMetrics{
- Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-c", Namespace: "default"}},
- MetricsState: &backendmetrics.MetricsState{},
- }
- pods := []types.Pod{podA, podB, podC}
+ endpointA := scheduling.NewEndpoint(
+ &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod-a", Namespace: "default"}},
+ &fwkdl.Metrics{},
+ nil,
+ )
+ endpointB := scheduling.NewEndpoint(
+ &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod-b", Namespace: "default"}},
+ &fwkdl.Metrics{},
+ nil,
+ )
+ endpointC := scheduling.NewEndpoint(
+ &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod-c", Namespace: "default"}},
+ &fwkdl.Metrics{},
+ nil,
+ )
+ endpoints := []scheduling.Endpoint{endpointA, endpointB, endpointC}
primaryProfile := "primary-profile"
- toPrefixState := func(entries map[prefix.ServerID]int) *types.CycleState {
- cycle := &types.CycleState{}
- cycle.Write(plugins.StateKey(plugins.TypedName{Type: prefix.PrefixCachePluginType,
+ toPrefixState := func(entries map[prefix.ServerID]int) *scheduling.CycleState {
+ cycle := &scheduling.CycleState{}
+ cycle.Write(plugin.StateKey(plugin.TypedName{Type: prefix.PrefixCachePluginType,
Name: prefix.PrefixCachePluginType}.String()), &prefix.SchedulingContextState{PrefixCacheServers: entries})
return cycle
}
- requestToPod := func(target types.Pod) *types.SchedulingResult {
- return &types.SchedulingResult{
+ requestToEndpoint := func(target scheduling.Endpoint) *scheduling.SchedulingResult {
+ return &scheduling.SchedulingResult{
PrimaryProfileName: primaryProfile,
- ProfileResults: map[string]*types.ProfileRunResult{
+ ProfileResults: map[string]*scheduling.ProfileRunResult{
primaryProfile: {
- TargetPods: []types.Pod{target},
+ TargetEndpoints: []scheduling.Endpoint{target},
},
},
}
}
// Test LRU behavior indirectly through scoring rather than internal state
- assertHighestScoredPod := func(expectedPod types.Pod, testName string) {
+ assertHighestScoredPod := func(expectedEndpoint scheduling.Endpoint, testName string) {
t.Helper()
- coldReq := &types.LLMRequest{RequestId: testName + "-scoring-check"}
- scores := scorer.Score(ctx, toPrefixState(make(map[prefix.ServerID]int)), coldReq, pods)
+ coldReq := &scheduling.LLMRequest{RequestId: testName + "-scoring-check"}
+ scores := scorer.Score(ctx, toPrefixState(make(map[prefix.ServerID]int)), coldReq, endpoints)
highestScore := -1.0
- var highestPod types.Pod
- for pod, score := range scores {
+ var highestEndpoint scheduling.Endpoint
+ for endpoint, score := range scores {
if score > highestScore {
highestScore = score
- highestPod = pod
+ highestEndpoint = endpoint
}
}
- if highestPod.GetPod().NamespacedName.String() != expectedPod.GetPod().NamespacedName.String() {
+ if highestEndpoint.GetMetadata().NamespacedName.String() != expectedEndpoint.GetMetadata().NamespacedName.String() {
t.Fatalf("expected %s to have highest score for LRU behavior, but %s had highest score (%f). All scores: %+v",
- expectedPod.GetPod().NamespacedName.String(),
- highestPod.GetPod().NamespacedName.String(),
+ expectedEndpoint.GetMetadata().NamespacedName.String(),
+ highestEndpoint.GetMetadata().NamespacedName.String(),
highestScore,
scores)
}
}
t.Run("initial cold request seeds cache", func(_ *testing.T) {
- coldReqA := &types.LLMRequest{RequestId: "cold-1"}
- scorer.Score(ctx, toPrefixState(make(map[prefix.ServerID]int)), coldReqA, pods)
- scorer.PreRequest(ctx, coldReqA, requestToPod(podA))
- // After podA handles a cold request, other pods should score higher for new cold requests
- assertHighestScoredPod(podB, "after-podA-used")
+ coldReqA := &scheduling.LLMRequest{RequestId: "cold-1"}
+ scorer.Score(ctx, toPrefixState(make(map[prefix.ServerID]int)), coldReqA, endpoints)
+ scorer.PreRequest(ctx, coldReqA, requestToEndpoint(endpointA))
+ // After endpointA handles a cold request, other endpoints should score higher for new cold requests
+ assertHighestScoredPod(endpointB, "after-endpointA-used")
})
- t.Run("unused pods rank above existing ones", func(t *testing.T) {
- coldReqCheck := &types.LLMRequest{RequestId: "cold-check"}
- coldScores := scorer.Score(ctx, toPrefixState(make(map[prefix.ServerID]int)), coldReqCheck, pods)
- if coldScores[podB] <= coldScores[podA] {
- t.Fatalf("expected pod-b to outrank pod-a after pod-a handled previous cold request, scores=%+v", coldScores)
+ t.Run("unused endpoints rank above existing ones", func(t *testing.T) {
+ coldReqCheck := &scheduling.LLMRequest{RequestId: "cold-check"}
+ coldScores := scorer.Score(ctx, toPrefixState(make(map[prefix.ServerID]int)), coldReqCheck, endpoints)
+ if coldScores[endpointB] <= coldScores[endpointA] {
+ t.Fatalf("expected endpoint-b to outrank endpoint-a after endpoint-a handled previous cold request, scores=%+v", coldScores)
}
- if coldScores[podB] != 1.0 {
- t.Fatalf("expected pod-b to score 1.0, scores=%+v", coldScores)
+ if coldScores[endpointB] != 1.0 {
+ t.Fatalf("expected endpoint-b to score 1.0, scores=%+v", coldScores)
}
- if coldScores[podC] != 0.5 {
- t.Fatalf("expected pod-c to score 0.5, scores=%+v", coldScores)
+ if coldScores[endpointC] != 0.5 {
+ t.Fatalf("expected endpoint-c to score 0.5, scores=%+v", coldScores)
}
})
t.Run("warm request leaves LRU untouched", func(t *testing.T) {
- warmReq := &types.LLMRequest{RequestId: "warm-1"}
+ warmReq := &scheduling.LLMRequest{RequestId: "warm-1"}
warmState := map[prefix.ServerID]int{
{Name: "server1", Namespace: "default"}: 1,
}
- warmScores := scorer.Score(ctx, toPrefixState(warmState), warmReq, pods)
+ warmScores := scorer.Score(ctx, toPrefixState(warmState), warmReq, endpoints)
for _, score := range warmScores {
if score != 0.5 {
t.Fatalf("expected neutral score for warm request, got %f", score)
}
}
- scorer.PreRequest(ctx, warmReq, requestToPod(podB))
- postWarmReq := &types.LLMRequest{RequestId: "cold-after-warm"}
- postWarmScores := scorer.Score(ctx, toPrefixState(make(map[prefix.ServerID]int)), postWarmReq, pods)
- if postWarmScores[podB] <= postWarmScores[podA] {
+ scorer.PreRequest(ctx, warmReq, requestToEndpoint(endpointB))
+ postWarmReq := &scheduling.LLMRequest{RequestId: "cold-after-warm"}
+ postWarmScores := scorer.Score(ctx, toPrefixState(make(map[prefix.ServerID]int)), postWarmReq, endpoints)
+ if postWarmScores[endpointB] <= postWarmScores[endpointA] {
t.Fatalf("expected warm request to leave ordering unchanged, scores=%+v", postWarmScores)
}
})
- t.Run("second cold request rotates to podB", func(_ *testing.T) {
- // Simulate podB handling a cold request
- coldReqB := &types.LLMRequest{RequestId: "cold-2"}
- scorer.Score(ctx, toPrefixState(make(map[prefix.ServerID]int)), coldReqB, pods)
- scorer.PreRequest(ctx, coldReqB, requestToPod(podB))
- // Now podC should score highest since both podA and podB have been used
- assertHighestScoredPod(podC, "after-podB-used")
+ t.Run("second cold request rotates to endpointB", func(_ *testing.T) {
+ // Simulate endpointB handling a cold request
+ coldReqB := &scheduling.LLMRequest{RequestId: "cold-2"}
+ scorer.Score(ctx, toPrefixState(make(map[prefix.ServerID]int)), coldReqB, endpoints)
+ scorer.PreRequest(ctx, coldReqB, requestToEndpoint(endpointB))
+ // Now endpointC should score highest since both endpointA and endpointB have been used
+ assertHighestScoredPod(endpointC, "after-endpointB-used")
})
- t.Run("third cold request rotates back to podA", func(_ *testing.T) {
- // Simulate podC handling a cold request
- coldReqC := &types.LLMRequest{RequestId: "cold-3"}
- scorer.Score(ctx, toPrefixState(make(map[prefix.ServerID]int)), coldReqC, pods)
- scorer.PreRequest(ctx, coldReqC, requestToPod(podC))
- // Now podA should score highest again (LRU rotation)
- assertHighestScoredPod(podA, "after-podC-used")
+ t.Run("third cold request rotates back to endpointA", func(_ *testing.T) {
+ // Simulate endpointC handling a cold request
+ coldReqC := &scheduling.LLMRequest{RequestId: "cold-3"}
+ scorer.Score(ctx, toPrefixState(make(map[prefix.ServerID]int)), coldReqC, endpoints)
+ scorer.PreRequest(ctx, coldReqC, requestToEndpoint(endpointC))
+ // Now endpointA should score highest again (LRU rotation)
+ assertHighestScoredPod(endpointA, "after-endpointC-used")
})
}
@@ -403,55 +410,177 @@ func TestNoHitLRUEdgeCases(t *testing.T) {
ctx := utils.NewTestContext(t)
scorer := scorer.NewNoHitLRU(ctx, nil)
- podA := &types.PodMetrics{
- Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-a"}},
- MetricsState: &backendmetrics.MetricsState{},
- }
+ endpointA := scheduling.NewEndpoint(
+ &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod-a"}},
+ &fwkdl.Metrics{},
+ nil,
+ )
- t.Run("empty pods list", func(t *testing.T) {
- emptyPods := []types.Pod{}
- cycleState := &types.CycleState{}
- cycleState.Write(plugins.StateKey(plugins.TypedName{Type: prefix.PrefixCachePluginType,
+ t.Run("empty endpoints list", func(t *testing.T) {
+ emptyEndpoints := []scheduling.Endpoint{}
+ cycleState := &scheduling.CycleState{}
+ cycleState.Write(plugin.StateKey(plugin.TypedName{Type: prefix.PrefixCachePluginType,
Name: prefix.PrefixCachePluginType}.String()), &prefix.SchedulingContextState{
PrefixCacheServers: make(map[prefix.ServerID]int), // cold request
})
- scores := scorer.Score(ctx, cycleState, &types.LLMRequest{}, emptyPods)
+ scores := scorer.Score(ctx, cycleState, &scheduling.LLMRequest{}, emptyEndpoints)
if len(scores) != 0 {
- t.Errorf("Expected empty scores for empty pods list, got %d scores", len(scores))
+ t.Errorf("Expected empty scores for empty endpoints list, got %d scores", len(scores))
}
})
- t.Run("nil pods list", func(t *testing.T) {
- cycleState := &types.CycleState{}
- cycleState.Write(plugins.StateKey(plugins.TypedName{Type: prefix.PrefixCachePluginType,
+ t.Run("nil endpoints list", func(t *testing.T) {
+ cycleState := &scheduling.CycleState{}
+ cycleState.Write(plugin.StateKey(plugin.TypedName{Type: prefix.PrefixCachePluginType,
Name: prefix.PrefixCachePluginType}.String()), &prefix.SchedulingContextState{
PrefixCacheServers: make(map[prefix.ServerID]int), // cold request
})
- scores := scorer.Score(ctx, cycleState, &types.LLMRequest{}, nil)
+ scores := scorer.Score(ctx, cycleState, &scheduling.LLMRequest{}, nil)
if scores == nil {
- t.Errorf("Expected non-nil scores map for nil pods list")
+ t.Errorf("Expected non-nil scores map for nil endpoints list")
}
if len(scores) != 0 {
- t.Errorf("Expected empty scores for nil pods list, got %d scores", len(scores))
+ t.Errorf("Expected empty scores for nil endpoints list, got %d scores", len(scores))
}
})
- t.Run("single pod returns 1.0", func(t *testing.T) {
- pods := []types.Pod{podA}
- cycleState := &types.CycleState{}
- cycleState.Write(plugins.StateKey(plugins.TypedName{Type: prefix.PrefixCachePluginType,
+ t.Run("single endpoint returns 1.0", func(t *testing.T) {
+ endpoints := []scheduling.Endpoint{endpointA}
+ cycleState := &scheduling.CycleState{}
+ cycleState.Write(plugin.StateKey(plugin.TypedName{Type: prefix.PrefixCachePluginType,
Name: prefix.PrefixCachePluginType}.String()), &prefix.SchedulingContextState{
PrefixCacheServers: make(map[prefix.ServerID]int), // cold request
})
- scores := scorer.Score(ctx, cycleState, &types.LLMRequest{}, pods)
+ scores := scorer.Score(ctx, cycleState, &scheduling.LLMRequest{}, endpoints)
+
+ if scores[endpointA] != 1.0 {
+ t.Errorf("Expected single endpoint to get score 1.0, got %f", scores[endpointA])
+ }
+ })
+}
+
+func TestNoHitLRUPrefillDecodeTracking(t *testing.T) {
+ // Prefill worker endpoints
+ prefillEndpointA := scheduling.NewEndpoint(
+ &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "prefill-a", Namespace: "default"}},
+ &fwkdl.Metrics{},
+ nil,
+ )
+ prefillEndpointB := scheduling.NewEndpoint(
+ &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "prefill-b", Namespace: "default"}},
+ &fwkdl.Metrics{},
+ nil,
+ )
+
+ // Decode worker endpoints
+ decodeEndpointA := scheduling.NewEndpoint(
+ &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "decode-a", Namespace: "default"}},
+ &fwkdl.Metrics{},
+ nil,
+ )
+ decodeEndpointB := scheduling.NewEndpoint(
+ &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "decode-b", Namespace: "default"}},
+ &fwkdl.Metrics{},
+ nil,
+ )
+
+ prefillEndpoints := []scheduling.Endpoint{prefillEndpointA, prefillEndpointB}
+ decodeEndpoints := []scheduling.Endpoint{decodeEndpointA, decodeEndpointB}
+
+ coldPrefixState := &scheduling.CycleState{}
+ coldPrefixState.Write(plugin.StateKey(prefix.PrefixCachePluginType), &prefix.SchedulingContextState{
+ PrefixCacheServers: make(map[prefix.ServerID]int), // empty = cold request
+ })
+
+ ctx := context.Background()
+
+ t.Run("P/D scenario - both profiles tracked separately", func(t *testing.T) {
+ scorer := scorer.NewNoHitLRU(ctx, nil)
+
+ // First cold request with P/D
+ req1 := &scheduling.LLMRequest{RequestId: "pd-request-1"}
+ scorer.Score(ctx, coldPrefixState, req1, append(prefillEndpoints, decodeEndpoints...))
+
+ // Simulate scheduling result with both prefill and decode profiles
+ pdResult := &scheduling.SchedulingResult{
+ PrimaryProfileName: "decode",
+ ProfileResults: map[string]*scheduling.ProfileRunResult{
+ "prefill": {
+ TargetEndpoints: []scheduling.Endpoint{prefillEndpointA},
+ },
+ "decode": {
+ TargetEndpoints: []scheduling.Endpoint{decodeEndpointA},
+ },
+ },
+ }
+ scorer.PreRequest(ctx, req1, pdResult)
+
+ // Second cold request - both prefillEndpointB and decodeEndpointB should score higher
+ // since prefillEndpointA and decodeEndpointA were just used
+ req2 := &scheduling.LLMRequest{RequestId: "pd-request-2"}
+ prefillScores := scorer.Score(ctx, coldPrefixState, req2, prefillEndpoints)
+ decodeScores := scorer.Score(ctx, coldPrefixState, req2, decodeEndpoints)
+
+ if prefillScores[prefillEndpointB] <= prefillScores[prefillEndpointA] {
+ t.Errorf("Expected prefill-b to score higher than prefill-a after prefill-a was used: %+v", prefillScores)
+ }
+
+ if decodeScores[decodeEndpointB] <= decodeScores[decodeEndpointA] {
+ t.Errorf("Expected decode-b to score higher than decode-a after decode-a was used: %+v", decodeScores)
+ }
+ })
+
+ t.Run("non-P/D scenario - only primary profile exists", func(t *testing.T) {
+ req := &scheduling.LLMRequest{RequestId: "non-pd-request"}
+ scorer := scorer.NewNoHitLRU(ctx, nil)
+ scorer.Score(ctx, coldPrefixState, req, decodeEndpoints)
+
+ // Scheduling result with only decode profile (no prefill)
+ result := &scheduling.SchedulingResult{
+ PrimaryProfileName: "decode",
+ ProfileResults: map[string]*scheduling.ProfileRunResult{
+ "decode": {
+ TargetEndpoints: []scheduling.Endpoint{decodeEndpointA},
+ },
+ // No "prefill" profile in results
+ },
+ }
+ // Should not panic when prefill profile doesn't exist
+ scorer.PreRequest(ctx, req, result)
+
+ // Verify decodeEndpointA was tracked
+ req2 := &scheduling.LLMRequest{RequestId: "non-pd-request-2"}
+ scores := scorer.Score(ctx, coldPrefixState, req2, decodeEndpoints)
+
+ if scores[decodeEndpointB] <= scores[decodeEndpointA] {
+ t.Errorf("Expected decode-b to score higher than decode-a: %+v", scores)
+ }
+ })
+
+ t.Run("nil scheduling result - graceful handling", func(_ *testing.T) {
+ req := &scheduling.LLMRequest{RequestId: "nil-result"}
+ scorer := scorer.NewNoHitLRU(ctx, nil)
+ scorer.Score(ctx, coldPrefixState, req, decodeEndpoints)
+
+ // Should not panic with nil result
+ scorer.PreRequest(ctx, req, nil)
+ })
+
+ t.Run("empty profile results - graceful handling", func(_ *testing.T) {
+ req := &scheduling.LLMRequest{RequestId: "empty-results"}
+ scorer := scorer.NewNoHitLRU(ctx, nil)
+ scorer.Score(ctx, coldPrefixState, req, decodeEndpoints)
- if scores[podA] != 1.0 {
- t.Errorf("Expected single pod to get score 1.0, got %f", scores[podA])
+ result := &scheduling.SchedulingResult{
+ PrimaryProfileName: "decode",
+ ProfileResults: map[string]*scheduling.ProfileRunResult{},
}
+ // Should not panic with empty profile results
+ scorer.PreRequest(ctx, req, result)
})
}
diff --git a/pkg/plugins/scorer/precise_prefix_cache.go b/pkg/plugins/scorer/precise_prefix_cache.go
index 2ce6551c0..f839625d8 100644
--- a/pkg/plugins/scorer/precise_prefix_cache.go
+++ b/pkg/plugins/scorer/precise_prefix_cache.go
@@ -6,16 +6,18 @@ import (
"errors"
"fmt"
"os"
+ "time"
- "github.com/llm-d/llm-d-kv-cache-manager/pkg/kvcache"
- "github.com/llm-d/llm-d-kv-cache-manager/pkg/kvcache/kvevents"
- preprocessing "github.com/llm-d/llm-d-kv-cache-manager/pkg/preprocessing/chat_completions"
+ "github.com/jellydator/ttlcache/v3"
+ "github.com/llm-d/llm-d-kv-cache/pkg/kvcache"
+ "github.com/llm-d/llm-d-kv-cache/pkg/kvcache/kvblock"
+ "github.com/llm-d/llm-d-kv-cache/pkg/kvevents"
+ preprocessing "github.com/llm-d/llm-d-kv-cache/pkg/preprocessing/chat_completions"
"sigs.k8s.io/controller-runtime/pkg/log"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/multi/prefix"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
- logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
+ logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/common/util/logging"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/plugins/scheduling/scorer/prefix"
)
const (
@@ -26,8 +28,11 @@ const (
// PrecisePrefixCachePluginConfig holds the configuration for the
// PrecisePrefixCacheScorer plugin.
type PrecisePrefixCachePluginConfig struct {
+ // TokenProcessorConfig holds the configuration for the `kvblock.TokenProcessor` which is
+ // used to process tokens into KV-block keys.
+ TokenProcessorConfig *kvblock.TokenProcessorConfig `json:"tokenProcessorConfig"`
// IndexerConfig holds the configuration for the `kvcache.Indexer` which is
- // used to score pods based on the KV-cache index state.
+ // used to score endpoints based on the KV-cache index state.
IndexerConfig *kvcache.Config `json:"indexerConfig"`
// KVEventsConfig holds the configuration for the `kvevents.Pool` which is
// used to subscribe to KV-cache events and update the internal KV-cache
@@ -36,13 +41,12 @@ type PrecisePrefixCachePluginConfig struct {
}
// compile-time type assertion
-var _ framework.Scorer = &PrecisePrefixCacheScorer{}
+var _ scheduling.Scorer = &PrecisePrefixCacheScorer{}
// PrecisePrefixCachePluginFactory defines the factory function for creating
// a new instance of the PrefixCacheTrackingPlugin.
func PrecisePrefixCachePluginFactory(name string, rawParameters json.RawMessage,
- handle plugins.Handle) (plugins.Plugin, error) {
-
+ handle plugin.Handle) (plugin.Plugin, error) {
indexerConfig, err := kvcache.NewDefaultConfig()
if err != nil {
return nil, fmt.Errorf("failed to initialize indexer config: %w", err)
@@ -53,18 +57,24 @@ func PrecisePrefixCachePluginFactory(name string, rawParameters json.RawMessage,
KVEventsConfig: kvevents.DefaultConfig(),
}
- // read hugging face token from environment variable if set
+ if rawParameters != nil {
+ if err := json.Unmarshal(rawParameters, ¶meters); err != nil {
+ return nil, fmt.Errorf("failed to parse %s plugin config: %w", PrecisePrefixCachePluginType, err)
+ }
+ }
+
+ // Apply HF token from environment if not already set
if token := os.Getenv("HF_TOKEN"); token != "" &&
parameters.IndexerConfig != nil &&
parameters.IndexerConfig.TokenizersPoolConfig != nil &&
- parameters.IndexerConfig.TokenizersPoolConfig.HFTokenizerConfig != nil {
+ parameters.IndexerConfig.TokenizersPoolConfig.HFTokenizerConfig != nil &&
+ parameters.IndexerConfig.TokenizersPoolConfig.HFTokenizerConfig.HuggingFaceToken == "" {
parameters.IndexerConfig.TokenizersPoolConfig.HFTokenizerConfig.HuggingFaceToken = token
}
- if rawParameters != nil {
- if err := json.Unmarshal(rawParameters, ¶meters); err != nil {
- return nil, fmt.Errorf("failed to parse %s plugin config: %w", PrecisePrefixCachePluginType, err)
- }
+ // Validate model name is set
+ if parameters.IndexerConfig == nil || parameters.IndexerConfig.TokenizersPoolConfig == nil || parameters.IndexerConfig.TokenizersPoolConfig.ModelName == "" {
+ return nil, errors.New("modelName is required in indexerConfig.tokenizersPoolConfig")
}
scorer, err := New(handle.Context(), parameters)
@@ -80,13 +90,19 @@ func PrecisePrefixCachePluginFactory(name string, rawParameters json.RawMessage,
// based on the provided configuration. The `kvevents.Pool` is started
// in a goroutine to listen for KV-cache events and update the internal
// KV-cache index state. The `kvcache.Indexer` is also started in a goroutine
-// to score pods based on the KV-cache index state.
+// to score endpoints based on the KV-cache index state.
//
// If the configuration is invalid or if the indexer fails to initialize,
// an error is returned.
func New(ctx context.Context, config PrecisePrefixCachePluginConfig) (*PrecisePrefixCacheScorer, error) {
+ if config.TokenProcessorConfig == nil {
+ config.TokenProcessorConfig = kvblock.DefaultTokenProcessorConfig()
+ }
+
+ tokenProcessor := kvblock.NewChunkedTokenDatabase(config.TokenProcessorConfig)
+
// initialize the indexer
- kvCacheIndexer, err := kvcache.NewKVCacheIndexer(ctx, config.IndexerConfig)
+ kvCacheIndexer, err := kvcache.NewKVCacheIndexer(ctx, config.IndexerConfig, tokenProcessor)
if err != nil {
return nil, fmt.Errorf("failed to create `kvcache.Indexer`: %w", err)
}
@@ -94,27 +110,66 @@ func New(ctx context.Context, config PrecisePrefixCachePluginConfig) (*PrecisePr
go kvCacheIndexer.Run(ctx)
// initialize the KV-events pool
- pool := kvevents.NewPool(config.KVEventsConfig, kvCacheIndexer.KVBlockIndex())
+ pool := kvevents.NewPool(config.KVEventsConfig, kvCacheIndexer.KVBlockIndex(), tokenProcessor)
pool.Start(ctx)
+ subscribersManager := kvevents.NewSubscriberManager(pool)
+ var subscribersCache *ttlcache.Cache[string, struct{}]
+
+ // initialize the subscribers cache only if endpoint discovery is enabled
+ if config.KVEventsConfig.DiscoverPods {
+ // initialize the subscribers TTL cache
+ subscriptionTimeout := 10 * time.Minute
+ subscribersCache = ttlcache.New[string, struct{}](
+ ttlcache.WithTTL[string, struct{}](subscriptionTimeout),
+ )
+ subscribersCache.OnEviction(func(ctx context.Context, reason ttlcache.EvictionReason,
+ item *ttlcache.Item[string, struct{}],
+ ) {
+ if reason == ttlcache.EvictionReasonExpired {
+ subscribersManager.RemoveSubscriber(ctx, item.Key())
+ }
+ })
+ go cleanCachePeriodically(ctx, subscribersCache, subscriptionTimeout)
+ }
+ if config.KVEventsConfig.ZMQEndpoint != "" {
+ // setup local subscriber to support global socket mode
+ if err := subscribersManager.EnsureSubscriber(ctx, "local-subscriber",
+ config.KVEventsConfig.ZMQEndpoint, config.KVEventsConfig.TopicFilter, false); err != nil {
+ return nil, fmt.Errorf("failed to create local subscriber for global socket mode: %w", err)
+ }
+ }
+
return &PrecisePrefixCacheScorer{
- typedName: plugins.TypedName{Type: PrecisePrefixCachePluginType},
- kvCacheIndexer: kvCacheIndexer,
+ typedName: plugin.TypedName{Type: PrecisePrefixCachePluginType},
+ kvCacheIndexer: kvCacheIndexer,
+ subscribersCache: subscribersCache,
+ subscribersManager: subscribersManager,
+ kvEventsConfig: config.KVEventsConfig,
}, nil
}
// PrecisePrefixCacheScorer implements the framework.Scorer interface.
// The scorer implements precise prefix-cache KV-block locality scoring.
-// It uses the `kvcache.Indexer` to score pods based on the KV-cache index
+// It uses the `kvcache.Indexer` to score endpoints based on the KV-cache index
// state, and the `kvevents.Pool` to subscribe to KV-cache events
// to keep the internal KV-cache index state up-to-date.
type PrecisePrefixCacheScorer struct {
- typedName plugins.TypedName
+ typedName plugin.TypedName
kvCacheIndexer *kvcache.Indexer
+
+ // until the IGW data-layer is ready to provide endpoint events,
+ // we maintain a TTL cache of known endpoints that are discovered through
+ // the scoring process. If an endpoint is not in the received endpoints list
+ // during scoring for a certain period, we consider it gone and
+ // stop its KV events subscription.
+ subscribersCache *ttlcache.Cache[string, struct{}]
+ subscribersManager *kvevents.SubscriberManager
+ kvEventsConfig *kvevents.Config
}
// TypedName returns the typed name of the plugin.
-func (s *PrecisePrefixCacheScorer) TypedName() plugins.TypedName {
+func (s *PrecisePrefixCacheScorer) TypedName() plugin.TypedName {
return s.typedName
}
@@ -124,12 +179,37 @@ func (s *PrecisePrefixCacheScorer) WithName(name string) *PrecisePrefixCacheScor
return s
}
-// Score scores the provided pod based on the KVCache index state.
+// Category returns the preference the scorer applies when scoring candidate endpoints.
+func (s *PrecisePrefixCacheScorer) Category() scheduling.ScorerCategory {
+ return scheduling.Affinity
+}
+
+// Score scores the provided endpoint based on the KVCache index state.
// The returned scores are normalized to a range of 0-1.
-func (s *PrecisePrefixCacheScorer) Score(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
+func (s *PrecisePrefixCacheScorer) Score(ctx context.Context, cycleState *scheduling.CycleState, request *scheduling.LLMRequest, endpoints []scheduling.Endpoint) map[scheduling.Endpoint]float64 {
logger := log.FromContext(ctx).WithName(s.typedName.String())
debugLogger := logger.V(logutil.DEBUG)
+ if s.kvEventsConfig.DiscoverPods {
+ // update subscribers here temporarily
+ for _, endpoint := range endpoints {
+ endpointObj := endpoint.GetMetadata()
+ if endpointObj == nil {
+ continue
+ }
+ endpointKey := endpointObj.NamespacedName.String()
+ s.subscribersCache.Set(endpointKey, struct{}{}, 0) // use default TTL
+
+ if err := s.subscribersManager.EnsureSubscriber(context.Background(), endpointKey, // don't use request ctx
+ fmt.Sprintf("tcp://%s:%d", endpointObj.Address, s.kvEventsConfig.PodDiscoveryConfig.SocketPort),
+ s.kvEventsConfig.TopicFilter, true); err != nil {
+ logger.Error(err, "Failed to ensure KV-events subscriber for endpoint", "endpoint", endpointKey,
+ "endpoint", endpointObj.Address)
+ continue
+ }
+ }
+ }
+
if request == nil {
debugLogger.Info("Request is nil, skipping scoring")
return nil
@@ -137,41 +217,41 @@ func (s *PrecisePrefixCacheScorer) Score(ctx context.Context, cycleState *types.
scores, err := s.getScores(ctx, request)
if err != nil {
- logger.Error(err, "Failed to get pod scores")
+ logger.Error(err, "Failed to get endpoint scores")
return nil
}
- debugLogger.Info("Got pod scores", "scores", scores)
+ debugLogger.Info("Got endpoint scores", "scores", scores)
- podToKey := func(pod types.Pod) (string, bool) {
- metricsPod := pod.GetPod()
- if metricsPod == nil {
+ endpointToKey := func(endpoint scheduling.Endpoint) (string, bool) {
+ metadata := endpoint.GetMetadata()
+ if metadata == nil {
return "", false
}
- return metricsPod.Address, true
+ return metadata.Address, true
}
state := &prefix.SchedulingContextState{
PrefixHashes: []prefix.BlockHash{},
PrefixCacheServers: map[prefix.ServerID]int{},
}
- for _, pod := range pods {
- key, ok := podToKey(pod)
+ for _, endpoint := range endpoints {
+ key, ok := endpointToKey(endpoint)
if !ok {
continue
}
- state.PrefixCacheServers[prefix.ServerID(pod.GetPod().NamespacedName)] = int(scores[key])
+ state.PrefixCacheServers[prefix.ServerID(endpoint.GetMetadata().NamespacedName)] = int(scores[key])
}
- cycleState.Write(plugins.StateKey(s.typedName.String()), state)
+ cycleState.Write(plugin.StateKey(s.typedName.String()), state)
- return indexedScoresToNormalizedScoredPods(pods, podToKey, scores)
+ return indexedScoresToNormalizedScoredPods(endpoints, endpointToKey, scores)
}
-// getScores retrieves the pod scores from the KV-cache indexer
+// getScores retrieves the endpoint scores from the KV-cache indexer
// based on the provided LLM request.
// If the request contains chat completions, it processes them accordingly.
// If the request contains regular completions, it uses the prompt directly.
-func (s *PrecisePrefixCacheScorer) getScores(ctx context.Context, request *types.LLMRequest) (map[string]float64, error) {
+func (s *PrecisePrefixCacheScorer) getScores(ctx context.Context, request *scheduling.LLMRequest) (map[string]float64, error) {
logger := log.FromContext(ctx).WithName(s.typedName.String())
traceLogger := logger.V(logutil.TRACE)
@@ -186,8 +266,17 @@ func (s *PrecisePrefixCacheScorer) getScores(ctx context.Context, request *types
traceLogger.Info("Both chat/completions and completions present; defaulting to chat/completions")
}
- renderReq := &preprocessing.RenderJinjaTemplateRequest{
- Conversations: make([]preprocessing.ChatMessage, 0),
+ // Convert messages to conversation format
+ conversations := make([]preprocessing.Conversation, len(request.Body.ChatCompletions.Messages))
+ for i, msg := range request.Body.ChatCompletions.Messages {
+ conversations[i] = preprocessing.Conversation{
+ Role: msg.Role,
+ Content: msg.Content.Raw,
+ }
+ }
+
+ renderReq := &preprocessing.ApplyChatTemplateRequest{
+ Conversation: [][]preprocessing.Conversation{conversations},
Tools: request.Body.ChatCompletions.Tools,
Documents: request.Body.ChatCompletions.Documents,
ChatTemplate: request.Body.ChatCompletions.ChatTemplate,
@@ -197,22 +286,14 @@ func (s *PrecisePrefixCacheScorer) getScores(ctx context.Context, request *types
ChatTemplateKWArgs: request.Body.ChatCompletions.ChatTemplateKWArgs,
}
- // Convert messages to the format expected by the renderer
- for _, msg := range request.Body.ChatCompletions.Messages {
- renderReq.Conversations = append(renderReq.Conversations, preprocessing.ChatMessage{
- Role: msg.Role,
- Content: msg.Content.Raw,
- })
- }
-
traceLogger.Info("Processing chat completion request",
- "messagesCount", len(renderReq.Conversations),
+ "messagesCount", len(conversations),
"toolsCount", len(renderReq.Tools),
"documentsCount", len(renderReq.Documents))
scores, err := s.kvCacheIndexer.GetPodScores(ctx, renderReq, "", request.TargetModel, nil)
if err != nil {
- return nil, fmt.Errorf("failed to get pod scores for chat/completions: %w", err)
+ return nil, fmt.Errorf("failed to get endpoint scores for chat/completions: %w", err)
}
return scores, nil
}
@@ -224,7 +305,7 @@ func (s *PrecisePrefixCacheScorer) getScores(ctx context.Context, request *types
scores, err := s.kvCacheIndexer.GetPodScores(ctx, nil, prompt, request.TargetModel, nil)
if err != nil {
- return nil, fmt.Errorf("failed to get pod scores for completions: %w", err)
+ return nil, fmt.Errorf("failed to get endpoint scores for completions: %w", err)
}
return scores, nil
}
diff --git a/pkg/plugins/scorer/precise_prefix_cache_test.go b/pkg/plugins/scorer/precise_prefix_cache_test.go
index eb7284b95..968497798 100644
--- a/pkg/plugins/scorer/precise_prefix_cache_test.go
+++ b/pkg/plugins/scorer/precise_prefix_cache_test.go
@@ -6,16 +6,15 @@ import (
"testing"
"github.com/google/go-cmp/cmp"
- "github.com/llm-d/llm-d-kv-cache-manager/pkg/kvcache"
- "github.com/llm-d/llm-d-kv-cache-manager/pkg/kvcache/kvblock"
- "github.com/llm-d/llm-d-kv-cache-manager/pkg/kvcache/kvevents"
- preprocessing "github.com/llm-d/llm-d-kv-cache-manager/pkg/preprocessing/chat_completions"
- "github.com/llm-d/llm-d-kv-cache-manager/pkg/tokenization"
+ "github.com/llm-d/llm-d-kv-cache/pkg/kvcache"
+ "github.com/llm-d/llm-d-kv-cache/pkg/kvcache/kvblock"
+ "github.com/llm-d/llm-d-kv-cache/pkg/kvevents"
+ preprocessing "github.com/llm-d/llm-d-kv-cache/pkg/preprocessing/chat_completions"
+ "github.com/llm-d/llm-d-kv-cache/pkg/tokenization"
"github.com/stretchr/testify/require"
k8stypes "k8s.io/apimachinery/pkg/types"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
- backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
+ fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling"
"github.com/llm-d/llm-d-inference-scheduler/test/utils"
)
@@ -37,34 +36,38 @@ func TestPrefixCacheTracking_Score(t *testing.T) {
testcases := []struct {
name string
- pods []types.Pod
- request *types.LLMRequest
- kvBlockData func(req *types.LLMRequestBody, model string) map[kvblock.Key][]kvblock.PodEntry
+ endpoints []scheduling.Endpoint
+ request *scheduling.LLMRequest
+ kvBlockData func(req *scheduling.LLMRequestBody, model string) map[kvblock.BlockHash][]kvblock.PodEntry
wantScoresByAddress map[string]float64
}{
{
name: "nil request",
- pods: []types.Pod{
- &types.PodMetrics{
- Pod: &backend.Pod{
+ endpoints: []scheduling.Endpoint{
+ scheduling.NewEndpoint(
+ &fwkdl.EndpointMetadata{
NamespacedName: k8stypes.NamespacedName{Name: "pod-a"},
Address: "10.0.0.1:8080",
},
- },
+ nil,
+ nil,
+ ),
},
wantScoresByAddress: map[string]float64{}, // empty map
},
{
name: "empty request body",
- pods: []types.Pod{
- &types.PodMetrics{
- Pod: &backend.Pod{
+ endpoints: []scheduling.Endpoint{
+ scheduling.NewEndpoint(
+ &fwkdl.EndpointMetadata{
NamespacedName: k8stypes.NamespacedName{Name: "pod-a"},
Address: "10.0.0.1:8080",
},
- },
+ nil,
+ nil,
+ ),
},
- request: &types.LLMRequest{
+ request: &scheduling.LLMRequest{
RequestId: "test-request",
TargetModel: "test-model",
Body: nil,
@@ -73,63 +76,66 @@ func TestPrefixCacheTracking_Score(t *testing.T) {
},
{
name: "longest prefix scorer (default scorer)",
- pods: []types.Pod{
- &types.PodMetrics{
- Pod: &backend.Pod{
+ endpoints: []scheduling.Endpoint{
+ scheduling.NewEndpoint(
+ &fwkdl.EndpointMetadata{
NamespacedName: k8stypes.NamespacedName{Name: "pod-a"},
Address: "10.0.0.1:8080",
},
- MetricsState: &backendmetrics.MetricsState{
+ &fwkdl.Metrics{
WaitingQueueSize: 0,
},
- },
- &types.PodMetrics{
- Pod: &backend.Pod{
+ nil,
+ ),
+ scheduling.NewEndpoint(
+ &fwkdl.EndpointMetadata{
NamespacedName: k8stypes.NamespacedName{Name: "pod-b"},
Address: "10.0.0.2:8080",
},
- MetricsState: &backendmetrics.MetricsState{
+ &fwkdl.Metrics{
WaitingQueueSize: 1,
},
- },
- &types.PodMetrics{
- Pod: &backend.Pod{
+ nil,
+ ),
+ scheduling.NewEndpoint(
+ &fwkdl.EndpointMetadata{
NamespacedName: k8stypes.NamespacedName{Name: "pod-c"},
Address: "10.0.0.3:8080",
},
- MetricsState: &backendmetrics.MetricsState{
+ &fwkdl.Metrics{
WaitingQueueSize: 2,
},
- },
+ nil,
+ ),
},
- request: &types.LLMRequest{
+ request: &scheduling.LLMRequest{
RequestId: "test-request",
TargetModel: "test-model",
- Body: &types.LLMRequestBody{
- Completions: &types.CompletionsRequest{
+ Body: &scheduling.LLMRequestBody{
+ Completions: &scheduling.CompletionsRequest{
Prompt: prompt,
},
},
},
- kvBlockData: func(req *types.LLMRequestBody, model string) map[kvblock.Key][]kvblock.PodEntry {
+ kvBlockData: func(req *scheduling.LLMRequestBody, model string) map[kvblock.BlockHash][]kvblock.PodEntry {
require.NotNil(t, req.Completions, "req expected to use Completions API")
prompt := req.Completions.Prompt
- testTokenizer, err := tokenization.NewCachedLocalTokenizer(localTokenizerConfig)
+ testTokenizer, err := tokenization.NewCachedLocalTokenizer(t.Context(), model, localTokenizerConfig)
require.NoError(t, err)
// use the actual tokenizer on the test prompt
- tokens, _, err := testTokenizer.Encode(prompt, model)
+ tokens, _, err := testTokenizer.Encode(prompt, model, true)
require.NoError(t, err)
// compute chunk hashes using the default block size
tokenProcessor := kvblock.NewChunkedTokenDatabase(kvblock.DefaultTokenProcessorConfig())
- chunkKeys := tokenProcessor.TokensToKVBlockKeys(tokens, model)
+ chunkKeys := tokenProcessor.TokensToKVBlockKeys(kvblock.EmptyBlockHash, tokens, model)
require.GreaterOrEqual(t, len(chunkKeys), 3, "Need at least 3 chunks for test")
// populate kvblock.Index to test longest prefix matching:
- // - chunk0 (first chunk): all pods have it (common prefix start)
+ // - chunk0 (first chunk): all endpoints have it (common prefix start)
// - chunk1: pod-a and pod-b have it (pod-c drops off after chunk0)
// - chunk2: only pod-a has it (pod-b drops off after chunk1)
// LongestPrefixScorer uses intersection, so:
@@ -138,17 +144,17 @@ func TestPrefixCacheTracking_Score(t *testing.T) {
// pod-c: 1 chunk (0) -> score 1
// Normalized: (3-1)/(3-1) = 1.0, (2-1)/(3-1) = 0.5, (1-1)/(3-1) = 0.0
- return map[kvblock.Key][]kvblock.PodEntry{
- {ModelName: model, ChunkHash: chunkKeys[0].ChunkHash}: {
+ return map[kvblock.BlockHash][]kvblock.PodEntry{
+ chunkKeys[0]: {
{PodIdentifier: "10.0.0.1:8080"},
{PodIdentifier: "10.0.0.2:8080"},
{PodIdentifier: "10.0.0.3:8080"},
},
- {ModelName: model, ChunkHash: chunkKeys[1].ChunkHash}: {
+ chunkKeys[1]: {
{PodIdentifier: "10.0.0.1:8080"},
{PodIdentifier: "10.0.0.2:8080"},
},
- {ModelName: model, ChunkHash: chunkKeys[2].ChunkHash}: {
+ chunkKeys[2]: {
{PodIdentifier: "10.0.0.1:8080"},
},
}
@@ -161,90 +167,99 @@ func TestPrefixCacheTracking_Score(t *testing.T) {
},
{
name: "chat completions request",
- pods: []types.Pod{
- &types.PodMetrics{
- Pod: &backend.Pod{
+ endpoints: []scheduling.Endpoint{
+ scheduling.NewEndpoint(
+ &fwkdl.EndpointMetadata{
NamespacedName: k8stypes.NamespacedName{Name: "pod-a"},
Address: "10.0.0.1:8080",
},
- MetricsState: &backendmetrics.MetricsState{
+ &fwkdl.Metrics{
WaitingQueueSize: 0,
},
- },
- &types.PodMetrics{
- Pod: &backend.Pod{
+ nil,
+ ),
+ scheduling.NewEndpoint(
+ &fwkdl.EndpointMetadata{
NamespacedName: k8stypes.NamespacedName{Name: "pod-b"},
Address: "10.0.0.2:8080",
},
- MetricsState: &backendmetrics.MetricsState{
+ &fwkdl.Metrics{
WaitingQueueSize: 1,
},
- },
+ nil,
+ ),
},
- request: &types.LLMRequest{
+ request: &scheduling.LLMRequest{
RequestId: "test-request",
TargetModel: "test-model",
- Body: &types.LLMRequestBody{
- ChatCompletions: &types.ChatCompletionsRequest{
+ Body: &scheduling.LLMRequestBody{
+ ChatCompletions: &scheduling.ChatCompletionsRequest{
ChatTemplate: `{% for message in messages %}{{ message.role }}: {{ message.content }}
-{% endfor %}`,
- Messages: []types.Message{
+ {% endfor %}`,
+ Messages: []scheduling.Message{
{
Role: "user",
- Content: types.Content{Raw: "Hello, how are you?"},
+ Content: scheduling.Content{Raw: "Hello, how are you?"},
},
{
Role: "assistant",
- Content: types.Content{Raw: "I'm doing well, thank you for asking!"},
+ Content: scheduling.Content{Raw: "I'm doing well, thank you for asking!"},
},
{
Role: "user",
- Content: types.Content{Raw: "Can you help me with a question about prefix caching in LLM inference?"},
+ Content: scheduling.Content{Raw: "Can you help me with a question about prefix caching in LLM inference?"},
},
},
},
},
},
- kvBlockData: func(req *types.LLMRequestBody, model string) map[kvblock.Key][]kvblock.PodEntry {
+ kvBlockData: func(req *scheduling.LLMRequestBody, model string) map[kvblock.BlockHash][]kvblock.PodEntry {
require.NotNil(t, req.ChatCompletions, "req expected to use ChatCompletions API")
// convert to preprocessing format
- var chatMessages []preprocessing.ChatMessage
+ var conversations []preprocessing.Conversation
for _, msg := range req.ChatCompletions.Messages {
- chatMessages = append(chatMessages, preprocessing.ChatMessage{
+ conversations = append(conversations, preprocessing.Conversation{
Role: msg.Role,
Content: msg.Content.Raw,
})
}
+ processor := preprocessing.NewChatTemplatingProcessor()
+ tokenizerCacheKey, err := processor.GetOrCreateTokenizerKey(t.Context(), &preprocessing.GetOrCreateTokenizerKeyRequest{
+ IsLocal: true,
+ Model: "testdata/" + model,
+ })
+ require.NoError(t, err)
+
// render the chat template
- renderReq := &preprocessing.RenderJinjaTemplateRequest{
- Conversations: chatMessages,
- ChatTemplate: req.ChatCompletions.ChatTemplate,
+ renderReq := &preprocessing.ApplyChatTemplateRequest{
+ Key: tokenizerCacheKey,
+ Conversation: [][]preprocessing.Conversation{conversations},
+ ChatTemplate: req.ChatCompletions.ChatTemplate,
}
- processor := preprocessing.NewChatTemplatingProcessor()
- rendered, err := processor.RenderChatTemplate(t.Context(), renderReq)
+ rendered, err := processor.ApplyChatTemplate(t.Context(), renderReq)
require.NoError(t, err)
// tokenize rendered prompt
- testTokenizer, err := tokenization.NewCachedLocalTokenizer(localTokenizerConfig)
+ testTokenizer, err := tokenization.NewCachedLocalTokenizer(t.Context(), model, localTokenizerConfig)
require.NoError(t, err)
- tokens, _, err := testTokenizer.Encode(rendered.RenderedChats[0], model)
+ tokens, _, err := testTokenizer.Encode(rendered, model, false)
require.NoError(t, err)
tokenProcessor := kvblock.NewChunkedTokenDatabase(kvblock.DefaultTokenProcessorConfig())
- chunkKeys := tokenProcessor.TokensToKVBlockKeys(tokens, model)
+ chunkKeys := tokenProcessor.TokensToKVBlockKeys(kvblock.EmptyBlockHash, tokens, model)
require.GreaterOrEqual(t, len(chunkKeys), 2, "Need at least 2 chunks for test")
// pod-a has both chunks, pod-b has only the first
- return map[kvblock.Key][]kvblock.PodEntry{
- {ModelName: model, ChunkHash: chunkKeys[0].ChunkHash}: {
+ return map[kvblock.BlockHash][]kvblock.PodEntry{
+ chunkKeys[0]: {
{PodIdentifier: "10.0.0.1:8080"},
{PodIdentifier: "10.0.0.2:8080"},
},
- {ModelName: model, ChunkHash: chunkKeys[1].ChunkHash}: {
+ chunkKeys[1]: {
{PodIdentifier: "10.0.0.1:8080"},
},
}
@@ -256,60 +271,63 @@ func TestPrefixCacheTracking_Score(t *testing.T) {
},
{
name: "partial prefix",
- pods: []types.Pod{
- &types.PodMetrics{
- Pod: &backend.Pod{
+ endpoints: []scheduling.Endpoint{
+ scheduling.NewEndpoint(
+ &fwkdl.EndpointMetadata{
NamespacedName: k8stypes.NamespacedName{Name: "pod-a"},
Address: "10.0.0.1:8080",
},
- MetricsState: &backendmetrics.MetricsState{
+ &fwkdl.Metrics{
WaitingQueueSize: 0,
},
- },
- &types.PodMetrics{
- Pod: &backend.Pod{
+ nil,
+ ),
+ scheduling.NewEndpoint(
+ &fwkdl.EndpointMetadata{
NamespacedName: k8stypes.NamespacedName{Name: "pod-b"},
Address: "10.0.0.2:8080",
},
- MetricsState: &backendmetrics.MetricsState{
+ &fwkdl.Metrics{
WaitingQueueSize: 1,
},
- },
- &types.PodMetrics{
- Pod: &backend.Pod{
+ nil,
+ ),
+ scheduling.NewEndpoint(
+ &fwkdl.EndpointMetadata{
NamespacedName: k8stypes.NamespacedName{Name: "pod-c"},
Address: "10.0.0.3:8080",
},
- MetricsState: &backendmetrics.MetricsState{
+ &fwkdl.Metrics{
WaitingQueueSize: 2,
},
- },
+ nil,
+ ),
},
- request: &types.LLMRequest{
+ request: &scheduling.LLMRequest{
RequestId: "test-request",
TargetModel: "test-model",
- Body: &types.LLMRequestBody{
- Completions: &types.CompletionsRequest{
+ Body: &scheduling.LLMRequestBody{
+ Completions: &scheduling.CompletionsRequest{
Prompt: prompt,
},
},
},
- kvBlockData: func(req *types.LLMRequestBody, model string) map[kvblock.Key][]kvblock.PodEntry {
+ kvBlockData: func(req *scheduling.LLMRequestBody, model string) map[kvblock.BlockHash][]kvblock.PodEntry {
require.NotNil(t, req.Completions, "req expected to use Completions API")
- testTokenizer, err := tokenization.NewCachedLocalTokenizer(localTokenizerConfig)
+ testTokenizer, err := tokenization.NewCachedLocalTokenizer(t.Context(), model, localTokenizerConfig)
require.NoError(t, err)
- tokens, _, err := testTokenizer.Encode(req.Completions.Prompt, model)
+ tokens, _, err := testTokenizer.Encode(req.Completions.Prompt, model, true)
require.NoError(t, err)
tokenProcessor := kvblock.NewChunkedTokenDatabase(kvblock.DefaultTokenProcessorConfig())
- chunkKeys := tokenProcessor.TokensToKVBlockKeys(tokens, model)
+ chunkKeys := tokenProcessor.TokensToKVBlockKeys(kvblock.EmptyBlockHash, tokens, model)
require.GreaterOrEqual(t, len(chunkKeys), 3, "Need at least 3 chunks for test")
// Test partial prefix cache scenario:
- // - chunk0: all pods (common prefix start)
+ // - chunk0: all endpoints (common prefix start)
// - chunk1: only pod-a (creates a gap for pod-b and pod-c)
// - chunk2: pod-a and pod-b (pod-b has this but missing chunk1)
//
@@ -317,16 +335,16 @@ func TestPrefixCacheTracking_Score(t *testing.T) {
// pod-a: has chunks 0,1,2 contiguously -> score 3
// pod-b: has chunks 0,2 (missing 1) -> prefix stops at chunk0 -> score 1
// pod-c: has only chunk 0 -> score 1
- return map[kvblock.Key][]kvblock.PodEntry{
- {ModelName: model, ChunkHash: chunkKeys[0].ChunkHash}: {
+ return map[kvblock.BlockHash][]kvblock.PodEntry{
+ chunkKeys[0]: {
{PodIdentifier: "10.0.0.1:8080"},
{PodIdentifier: "10.0.0.2:8080"},
{PodIdentifier: "10.0.0.3:8080"},
},
- {ModelName: model, ChunkHash: chunkKeys[1].ChunkHash}: {
+ chunkKeys[1]: {
{PodIdentifier: "10.0.0.1:8080"}, // only pod-a has chunk1
},
- {ModelName: model, ChunkHash: chunkKeys[2].ChunkHash}: {
+ chunkKeys[2]: {
{PodIdentifier: "10.0.0.1:8080"},
{PodIdentifier: "10.0.0.2:8080"}, // pod-b has chunk2 but missing chunk1
},
@@ -342,204 +360,161 @@ func TestPrefixCacheTracking_Score(t *testing.T) {
},
},
{
- name: "different model names",
- pods: []types.Pod{
- &types.PodMetrics{
- Pod: &backend.Pod{
- NamespacedName: k8stypes.NamespacedName{Name: "pod-a"},
- Address: "10.0.0.1:8080",
- },
- },
- &types.PodMetrics{
- Pod: &backend.Pod{
- NamespacedName: k8stypes.NamespacedName{Name: "pod-b"},
- Address: "10.0.0.2:8080",
- },
- },
- },
- request: &types.LLMRequest{
- RequestId: "test-request",
- TargetModel: "test-model",
- Body: &types.LLMRequestBody{
- Completions: &types.CompletionsRequest{
- Prompt: prompt,
- },
- },
- },
- kvBlockData: func(req *types.LLMRequestBody, model string) map[kvblock.Key][]kvblock.PodEntry {
- require.NotNil(t, req.Completions, "req expected to use Completions API")
-
- testTokenizer, err := tokenization.NewCachedLocalTokenizer(localTokenizerConfig)
- require.NoError(t, err)
-
- tokens, _, err := testTokenizer.Encode(req.Completions.Prompt, model)
- require.NoError(t, err)
-
- tokenProcessor := kvblock.NewChunkedTokenDatabase(kvblock.DefaultTokenProcessorConfig())
- chunkKeys := tokenProcessor.TokensToKVBlockKeys(tokens, model)
-
- require.GreaterOrEqual(t, len(chunkKeys), 1, "Need at least 1 chunk for test")
-
- // Populate the index with blocks for model `different-model`
- // The request will ask for "test-model" but the cache only has "different-model"
- // This should result in no cache hits since models don't share cache
- return map[kvblock.Key][]kvblock.PodEntry{
- {ModelName: "different-model", ChunkHash: chunkKeys[0].ChunkHash}: {
- {PodIdentifier: "10.0.0.1:8080"},
- {PodIdentifier: "10.0.0.2:8080"},
- },
- }
- },
- wantScoresByAddress: map[string]float64{
- // Even though both pods have the chunk cached, it's for a different model
- // so there should be no cache hits for the requested model
- "10.0.0.1:8080": 0.0,
- "10.0.0.2:8080": 0.0,
- },
- },
- {
- name: "single pod",
- pods: []types.Pod{
- &types.PodMetrics{
- Pod: &backend.Pod{
+ name: "single endpoint",
+ endpoints: []scheduling.Endpoint{
+ scheduling.NewEndpoint(
+ &fwkdl.EndpointMetadata{
NamespacedName: k8stypes.NamespacedName{Name: "pod-a"},
Address: "10.0.0.1:8080",
},
- MetricsState: &backendmetrics.MetricsState{
+ &fwkdl.Metrics{
WaitingQueueSize: 0,
},
- },
+ nil,
+ ),
},
- request: &types.LLMRequest{
+ request: &scheduling.LLMRequest{
RequestId: "test-request",
TargetModel: "test-model",
- Body: &types.LLMRequestBody{
- Completions: &types.CompletionsRequest{
+ Body: &scheduling.LLMRequestBody{
+ Completions: &scheduling.CompletionsRequest{
Prompt: prompt,
},
},
},
- kvBlockData: func(req *types.LLMRequestBody, model string) map[kvblock.Key][]kvblock.PodEntry {
+ kvBlockData: func(req *scheduling.LLMRequestBody, model string) map[kvblock.BlockHash][]kvblock.PodEntry {
require.NotNil(t, req.Completions, "req expected to use Completions API")
- testTokenizer, err := tokenization.NewCachedLocalTokenizer(localTokenizerConfig)
+ testTokenizer, err := tokenization.NewCachedLocalTokenizer(t.Context(), model, localTokenizerConfig)
require.NoError(t, err)
- tokens, _, err := testTokenizer.Encode(req.Completions.Prompt, model)
+ tokens, _, err := testTokenizer.Encode(req.Completions.Prompt, model, true)
require.NoError(t, err)
tokenProcessor := kvblock.NewChunkedTokenDatabase(kvblock.DefaultTokenProcessorConfig())
- chunkKeys := tokenProcessor.TokensToKVBlockKeys(tokens, model)
+ chunkKeys := tokenProcessor.TokensToKVBlockKeys(kvblock.EmptyBlockHash, tokens, model)
require.GreaterOrEqual(t, len(chunkKeys), 2, "Need at least 2 chunks for test")
- // Single pod has 2 chunks cached
- return map[kvblock.Key][]kvblock.PodEntry{
- {ModelName: model, ChunkHash: chunkKeys[0].ChunkHash}: {
+ // Single endpoint has 2 chunks cached
+ return map[kvblock.BlockHash][]kvblock.PodEntry{
+ chunkKeys[0]: {
{PodIdentifier: "10.0.0.1:8080"},
},
- {ModelName: model, ChunkHash: chunkKeys[1].ChunkHash}: {
+ chunkKeys[1]: {
{PodIdentifier: "10.0.0.1:8080"},
},
}
},
wantScoresByAddress: map[string]float64{
- // with only one pod, minScore == maxScore, so normalization returns 1.0
+ // with only one endpoint, minScore == maxScore, so normalization returns 1.0
"10.0.0.1:8080": 1.0,
},
},
{
name: "no cache hits (empty index)",
- pods: []types.Pod{
- &types.PodMetrics{
- Pod: &backend.Pod{
+ endpoints: []scheduling.Endpoint{
+ scheduling.NewEndpoint(
+ &fwkdl.EndpointMetadata{
NamespacedName: k8stypes.NamespacedName{Name: "pod-a"},
Address: "10.0.0.1:8080",
},
- },
- &types.PodMetrics{
- Pod: &backend.Pod{
+ nil,
+ nil,
+ ),
+ scheduling.NewEndpoint(
+ &fwkdl.EndpointMetadata{
NamespacedName: k8stypes.NamespacedName{Name: "pod-b"},
Address: "10.0.0.2:8080",
},
- },
- &types.PodMetrics{
- Pod: &backend.Pod{
+ nil,
+ nil,
+ ),
+ scheduling.NewEndpoint(
+ &fwkdl.EndpointMetadata{
NamespacedName: k8stypes.NamespacedName{Name: "pod-c"},
Address: "10.0.0.3:8080",
},
- },
+ nil,
+ nil,
+ ),
},
- request: &types.LLMRequest{
+ request: &scheduling.LLMRequest{
RequestId: "test-request",
TargetModel: "test-model",
- Body: &types.LLMRequestBody{
- Completions: &types.CompletionsRequest{
- Prompt: "This prompt has never been cached before on any pod.",
+ Body: &scheduling.LLMRequestBody{
+ Completions: &scheduling.CompletionsRequest{
+ Prompt: "This prompt has never been cached before on any endpoint.",
},
},
},
kvBlockData: nil, // no cached data
wantScoresByAddress: map[string]float64{
- // when no pods have any cache hits, all should get equal scores (0.0)
+ // when no endpoints have any cache hits, all should get equal scores (0.0)
"10.0.0.1:8080": 0.0,
"10.0.0.2:8080": 0.0,
"10.0.0.3:8080": 0.0,
},
},
{
- name: "all pods have equal prefix length",
- pods: []types.Pod{
- &types.PodMetrics{
- Pod: &backend.Pod{
+ name: "all endpoints have equal prefix length",
+ endpoints: []scheduling.Endpoint{
+ scheduling.NewEndpoint(
+ &fwkdl.EndpointMetadata{
NamespacedName: k8stypes.NamespacedName{Name: "pod-a"},
Address: "10.0.0.1:8080",
},
- },
- &types.PodMetrics{
- Pod: &backend.Pod{
+ nil,
+ nil,
+ ),
+ scheduling.NewEndpoint(
+ &fwkdl.EndpointMetadata{
NamespacedName: k8stypes.NamespacedName{Name: "pod-b"},
Address: "10.0.0.2:8080",
},
- },
- &types.PodMetrics{
- Pod: &backend.Pod{
+ nil,
+ nil,
+ ),
+ scheduling.NewEndpoint(
+ &fwkdl.EndpointMetadata{
NamespacedName: k8stypes.NamespacedName{Name: "pod-c"},
Address: "10.0.0.3:8080",
},
- },
+ nil,
+ nil,
+ ),
},
- request: &types.LLMRequest{
+ request: &scheduling.LLMRequest{
RequestId: "test-request",
TargetModel: "test-model",
- Body: &types.LLMRequestBody{
- Completions: &types.CompletionsRequest{
+ Body: &scheduling.LLMRequestBody{
+ Completions: &scheduling.CompletionsRequest{
Prompt: prompt,
},
},
},
- kvBlockData: func(req *types.LLMRequestBody, model string) map[kvblock.Key][]kvblock.PodEntry {
+ kvBlockData: func(req *scheduling.LLMRequestBody, model string) map[kvblock.BlockHash][]kvblock.PodEntry {
require.NotNil(t, req.Completions, "req expected to use Completions API")
- testTokenizer, err := tokenization.NewCachedLocalTokenizer(localTokenizerConfig)
+ testTokenizer, err := tokenization.NewCachedLocalTokenizer(t.Context(), model, localTokenizerConfig)
require.NoError(t, err)
- tokens, _, err := testTokenizer.Encode(req.Completions.Prompt, model)
+ tokens, _, err := testTokenizer.Encode(req.Completions.Prompt, model, true)
require.NoError(t, err)
tokenProcessor := kvblock.NewChunkedTokenDatabase(kvblock.DefaultTokenProcessorConfig())
- chunkKeys := tokenProcessor.TokensToKVBlockKeys(tokens, model)
+ chunkKeys := tokenProcessor.TokensToKVBlockKeys(kvblock.EmptyBlockHash, tokens, model)
require.GreaterOrEqual(t, len(chunkKeys), 2, "Need at least 2 chunks for test")
- // all pods have the same 2 chunks cached
- return map[kvblock.Key][]kvblock.PodEntry{
- {ModelName: model, ChunkHash: chunkKeys[0].ChunkHash}: {
+ // all endpoints have the same 2 chunks cached
+ return map[kvblock.BlockHash][]kvblock.PodEntry{
+ chunkKeys[0]: {
{PodIdentifier: "10.0.0.1:8080"},
{PodIdentifier: "10.0.0.2:8080"},
{PodIdentifier: "10.0.0.3:8080"},
},
- {ModelName: model, ChunkHash: chunkKeys[1].ChunkHash}: {
+ chunkKeys[1]: {
{PodIdentifier: "10.0.0.1:8080"},
{PodIdentifier: "10.0.0.2:8080"},
{PodIdentifier: "10.0.0.3:8080"},
@@ -547,8 +522,8 @@ func TestPrefixCacheTracking_Score(t *testing.T) {
}
},
wantScoresByAddress: map[string]float64{
- // when all pods have equal cache (minScore == maxScore), the implementation
- // returns 1.0 for all pods to avoid division by zero
+ // when all endpoints have equal cache (minScore == maxScore), the implementation
+ // returns 1.0 for all endpoints to avoid division by zero
"10.0.0.1:8080": 1.0,
"10.0.0.2:8080": 1.0,
"10.0.0.3:8080": 1.0,
@@ -562,6 +537,7 @@ func TestPrefixCacheTracking_Score(t *testing.T) {
kvcacheConfig, err := kvcache.NewDefaultConfig()
kvcacheConfig.TokenizersPoolConfig = &tokenization.Config{
+ ModelName: "test-model",
WorkersCount: 1,
MinPrefixOverlapRatio: 0.8,
LocalTokenizerConfig: &localTokenizerConfig,
@@ -580,17 +556,17 @@ func TestPrefixCacheTracking_Score(t *testing.T) {
kvBlockIndex := prefixCacheScorer.kvCacheIndexer.KVBlockIndex()
blockData := tt.kvBlockData(tt.request.Body, tt.request.TargetModel)
for key, entries := range blockData {
- err := kvBlockIndex.Add(ctx, []kvblock.Key{key}, entries)
+ err := kvBlockIndex.Add(ctx, []kvblock.BlockHash{kvblock.EmptyBlockHash}, []kvblock.BlockHash{key}, entries)
require.NoError(t, err)
}
}
- got := prefixCacheScorer.Score(ctx, types.NewCycleState(), tt.request, tt.pods)
+ got := prefixCacheScorer.Score(ctx, scheduling.NewCycleState(), tt.request, tt.endpoints)
gotByAddress := make(map[string]float64)
- for pod, score := range got {
- if podMetrics, ok := pod.(*types.PodMetrics); ok && podMetrics.GetPod() != nil {
- gotByAddress[podMetrics.GetPod().Address] = score
+ for endpoint, score := range got {
+ if endpoint.GetMetadata() != nil {
+ gotByAddress[endpoint.GetMetadata().Address] = score
}
}
diff --git a/pkg/plugins/scorer/session_affinity.go b/pkg/plugins/scorer/session_affinity.go
index 3ac9230c6..87e9d2be9 100644
--- a/pkg/plugins/scorer/session_affinity.go
+++ b/pkg/plugins/scorer/session_affinity.go
@@ -6,12 +6,11 @@ import (
"encoding/json"
"sigs.k8s.io/controller-runtime/pkg/log"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
- logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
+ logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/common/util/logging"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/requestcontrol"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling"
)
const (
@@ -22,18 +21,18 @@ const (
)
// compile-time type assertion
-var _ framework.Scorer = &SessionAffinity{}
+var _ scheduling.Scorer = &SessionAffinity{}
var _ requestcontrol.ResponseComplete = &SessionAffinity{}
// SessionAffinityFactory defines the factory function for SessionAffinity scorer.
-func SessionAffinityFactory(name string, _ json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) {
+func SessionAffinityFactory(name string, _ json.RawMessage, _ plugin.Handle) (plugin.Plugin, error) {
return NewSessionAffinity().WithName(name), nil
}
// NewSessionAffinity returns a scorer
func NewSessionAffinity() *SessionAffinity {
return &SessionAffinity{
- typedName: plugins.TypedName{Type: SessionAffinityType},
+ typedName: plugin.TypedName{Type: SessionAffinityType},
}
}
@@ -42,11 +41,11 @@ func NewSessionAffinity() *SessionAffinity {
// session was sent to, by giving that pod the specified weight and assigning
// zero score to the rest of the targets
type SessionAffinity struct {
- typedName plugins.TypedName
+ typedName plugin.TypedName
}
// TypedName returns the typed name of the plugin.
-func (s *SessionAffinity) TypedName() plugins.TypedName {
+func (s *SessionAffinity) TypedName() plugin.TypedName {
return s.typedName
}
@@ -56,9 +55,14 @@ func (s *SessionAffinity) WithName(name string) *SessionAffinity {
return s
}
+// Category returns the preference the scorer applies when scoring candidate endpoints.
+func (s *SessionAffinity) Category() scheduling.ScorerCategory {
+ return scheduling.Affinity
+}
+
// Score assign a high score to the pod used in previous requests and zero to others
-func (s *SessionAffinity) Score(ctx context.Context, _ *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 {
- scoredPods := make(map[types.Pod]float64)
+func (s *SessionAffinity) Score(ctx context.Context, _ *scheduling.CycleState, request *scheduling.LLMRequest, endpoints []scheduling.Endpoint) map[scheduling.Endpoint]float64 {
+ scoredEndpoints := make(map[scheduling.Endpoint]float64)
sessionToken := request.Headers[sessionTokenHeader]
podName := ""
@@ -70,21 +74,21 @@ func (s *SessionAffinity) Score(ctx context.Context, _ *types.CycleState, reques
podName = string(decodedBytes)
}
}
- for _, pod := range pods {
- scoredPods[pod] = 0.0 // initial value
- if pod.GetPod().NamespacedName.String() == podName {
- scoredPods[pod] = 1.0
+ for _, endpoint := range endpoints {
+ scoredEndpoints[endpoint] = 0.0 // initial value
+ if endpoint.GetMetadata().NamespacedName.String() == podName {
+ scoredEndpoints[endpoint] = 1.0
}
}
- return scoredPods
+ return scoredEndpoints
}
// ResponseComplete sets the session header on the response sent to the client
// TODO: this should be using a cookie and ensure not overriding any other
// cookie values if present.
// Tracked in https://github.com/llm-d/llm-d-inference-scheduler/issues/28
-func (s *SessionAffinity) ResponseComplete(ctx context.Context, _ *types.LLMRequest, response *requestcontrol.Response, targetPod *backend.Pod) {
+func (s *SessionAffinity) ResponseComplete(ctx context.Context, _ *scheduling.LLMRequest, response *requestcontrol.Response, targetPod *datalayer.EndpointMetadata) {
if response == nil || targetPod == nil {
reqID := "undefined"
if response != nil {
diff --git a/pkg/plugins/scorer/session_affinity_test.go b/pkg/plugins/scorer/session_affinity_test.go
index 943b06eb4..d7acf3468 100644
--- a/pkg/plugins/scorer/session_affinity_test.go
+++ b/pkg/plugins/scorer/session_affinity_test.go
@@ -8,79 +8,80 @@ import (
"github.com/google/go-cmp/cmp"
k8stypes "k8s.io/apimachinery/pkg/types"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
- backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
+ fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/requestcontrol"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling"
"github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/scorer"
"github.com/llm-d/llm-d-inference-scheduler/test/utils"
)
func TestSessionAffinity_Score(t *testing.T) {
- podA := &types.PodMetrics{
- Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-a"}},
- MetricsState: &backendmetrics.MetricsState{},
- }
- podB := &types.PodMetrics{
- Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-b"}},
- MetricsState: &backendmetrics.MetricsState{},
- }
-
- inputPods := []types.Pod{podA, podB}
-
- // valid session token for podB
- validSessionTokenForPodB := base64.StdEncoding.EncodeToString([]byte(podB.GetPod().NamespacedName.String()))
+ endpointA := scheduling.NewEndpoint(
+ &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod-a"}},
+ &fwkdl.Metrics{},
+ nil,
+ )
+ endpointB := scheduling.NewEndpoint(
+ &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod-b"}},
+ &fwkdl.Metrics{},
+ nil,
+ )
+
+ inputEndpoints := []scheduling.Endpoint{endpointA, endpointB}
+
+ // valid session token for endpointB
+ validSessionTokenForEndpointB := base64.StdEncoding.EncodeToString([]byte(endpointB.GetMetadata().NamespacedName.String()))
sessionAffinityScorer := scorer.NewSessionAffinity()
tests := []struct {
name string
- req *types.LLMRequest
- input []types.Pod
- wantScores map[types.Pod]float64
+ req *scheduling.LLMRequest
+ input []scheduling.Endpoint
+ wantScores map[scheduling.Endpoint]float64
}{
{
- name: "selects correct pod : podB",
- req: &types.LLMRequest{
- Headers: map[string]string{"x-session-token": validSessionTokenForPodB},
+ name: "selects correct endpoint : endpointB",
+ req: &scheduling.LLMRequest{
+ Headers: map[string]string{"x-session-token": validSessionTokenForEndpointB},
},
- input: inputPods,
- wantScores: map[types.Pod]float64{
- podA: 0.0,
- podB: 1.0,
+ input: inputEndpoints,
+ wantScores: map[scheduling.Endpoint]float64{
+ endpointA: 0.0,
+ endpointB: 1.0,
},
},
{
name: "no session token",
- req: &types.LLMRequest{
+ req: &scheduling.LLMRequest{
Headers: map[string]string{},
},
- // both pods get score 0.0
- input: inputPods,
- wantScores: map[types.Pod]float64{
- podA: 0.0,
- podB: 0.0,
+ // both endpoints get score 0.0
+ input: inputEndpoints,
+ wantScores: map[scheduling.Endpoint]float64{
+ endpointA: 0.0,
+ endpointB: 0.0,
},
},
{
name: "invalid session token",
- req: &types.LLMRequest{
+ req: &scheduling.LLMRequest{
Headers: map[string]string{"x-session-token": "garbage-token"},
},
// expect same behavior as no session token
- input: inputPods,
- wantScores: map[types.Pod]float64{
- podA: 0.0,
- podB: 0.0,
+ input: inputEndpoints,
+ wantScores: map[scheduling.Endpoint]float64{
+ endpointA: 0.0,
+ endpointB: 0.0,
},
},
{
- name: "no pods available",
- req: &types.LLMRequest{},
- input: []types.Pod{},
+ name: "no endpoints available",
+ req: &scheduling.LLMRequest{},
+ input: []scheduling.Endpoint{},
// returns empty score map
- wantScores: map[types.Pod]float64{},
+ wantScores: map[scheduling.Endpoint]float64{},
},
}
@@ -97,30 +98,30 @@ func TestSessionAffinity_Score(t *testing.T) {
func TestSessionAffinity_ResponseComplete(t *testing.T) {
- targetPod := &backend.Pod{
+ targetEndpoint := &fwkdl.EndpointMetadata{
NamespacedName: k8stypes.NamespacedName{Name: "pod1"},
Address: "1.2.3.4",
}
// expected token to be set in response header
- wantToken := base64.StdEncoding.EncodeToString([]byte(targetPod.NamespacedName.String()))
+ wantToken := base64.StdEncoding.EncodeToString([]byte(targetEndpoint.NamespacedName.String()))
tests := []struct {
name string
initialResponse *requestcontrol.Response
- targetPod *backend.Pod
+ targetPod *fwkdl.EndpointMetadata
wantHeaders map[string]string
}{
{
name: "standard case with existing headers map",
initialResponse: &requestcontrol.Response{RequestId: "req-1", Headers: make(map[string]string)},
- targetPod: targetPod,
+ targetPod: targetEndpoint,
wantHeaders: map[string]string{"x-session-token": wantToken},
},
{
name: "response with nil headers map",
initialResponse: &requestcontrol.Response{RequestId: "req-2", Headers: nil},
- targetPod: targetPod,
+ targetPod: targetEndpoint,
wantHeaders: map[string]string{"x-session-token": wantToken},
},
{
diff --git a/pkg/plugins/scorer/utils.go b/pkg/plugins/scorer/utils.go
index 31a721b71..4d4b3c741 100644
--- a/pkg/plugins/scorer/utils.go
+++ b/pkg/plugins/scorer/utils.go
@@ -3,41 +3,41 @@ package scorer
import (
"math"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling"
)
-// podToKey is a function type that converts a Pod to a string key.
+// endpointToKey is a function type that converts an Endpoint to a string key.
// It returns the key and a boolean indicating success.
-type podToKeyFunc func(pod types.Pod) (string, bool)
+type endpointToKeyFunc func(endpoint scheduling.Endpoint) (string, bool)
// indexedScoresToNormalizedScoredPods converts a map of pod scores to a map of
// normalized scores. The function takes a list of pods, a function to convert
// a pod to a key, and a map of scores indexed by those keys. It returns a map
// of pods to their normalized scores.
-func indexedScoresToNormalizedScoredPods(pods []types.Pod, podToKey podToKeyFunc,
- scores map[string]float64) map[types.Pod]float64 {
- scoredPods := make(map[types.Pod]float64)
+func indexedScoresToNormalizedScoredPods(endpoints []scheduling.Endpoint, endpointToKey endpointToKeyFunc,
+ scores map[string]float64) map[scheduling.Endpoint]float64 {
+ scoredEndpoints := make(map[scheduling.Endpoint]float64)
minScore, maxScore := getMinMax(scores)
- for _, pod := range pods {
- key, ok := podToKey(pod)
+ for _, endpoint := range endpoints {
+ key, ok := endpointToKey(endpoint)
if !ok {
continue
}
if score, ok := scores[key]; ok {
if minScore == maxScore {
- scoredPods[pod] = 1.0
+ scoredEndpoints[endpoint] = 1.0
continue
}
- scoredPods[pod] = (score - minScore) / (maxScore - minScore)
+ scoredEndpoints[endpoint] = (score - minScore) / (maxScore - minScore)
} else {
- scoredPods[pod] = 0.0
+ scoredEndpoints[endpoint] = 0.0
}
}
- return scoredPods
+ return scoredEndpoints
}
func getMinMax(scores map[string]float64) (float64, float64) {
diff --git a/pkg/scheduling/pd/scheduler_test.go b/pkg/scheduling/pd/scheduler_test.go
index bd1f6b1c1..e068ade12 100644
--- a/pkg/scheduling/pd/scheduler_test.go
+++ b/pkg/scheduling/pd/scheduler_test.go
@@ -11,14 +11,13 @@ import (
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
k8stypes "k8s.io/apimachinery/pkg/types"
- "sigs.k8s.io/controller-runtime/pkg/log"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
- backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" // Import config for thresholds
+ "sigs.k8s.io/controller-runtime/pkg/log" // Import config for thresholds
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer/plugins/approximateprefix"
+ fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer"
+ fwkschd "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/plugins/scheduling/picker"
+ "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/plugins/scheduling/scorer/prefix"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/multi/prefix"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/picker"
- "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
"github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/filter"
"github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/profile"
@@ -32,42 +31,45 @@ const (
// Tests the scheduler expected behavior.
func TestPDSchedule(t *testing.T) {
- pod1 := &types.PodMetrics{
- Pod: &backend.Pod{
- NamespacedName: k8stypes.NamespacedName{Name: "pod1"},
+ endpoint1 := fwkschd.NewEndpoint(
+ &fwkdl.EndpointMetadata{
+ NamespacedName: k8stypes.NamespacedName{Name: "endpoint1"},
Address: "1.2.3.4",
Labels: map[string]string{filter.RoleLabel: filter.RolePrefill},
},
- MetricsState: &backendmetrics.MetricsState{WaitingQueueSize: 0},
- }
- pod2 := &types.PodMetrics{
- Pod: &backend.Pod{
- NamespacedName: k8stypes.NamespacedName{Name: "pod2"},
+ &fwkdl.Metrics{WaitingQueueSize: 0},
+ fwkdl.NewAttributes(),
+ )
+ endpoint2 := fwkschd.NewEndpoint(
+ &fwkdl.EndpointMetadata{
+ NamespacedName: k8stypes.NamespacedName{Name: "endpoint2"},
Address: "5.6.7.8",
Labels: map[string]string{filter.RoleLabel: filter.RoleDecode},
},
- MetricsState: &backendmetrics.MetricsState{WaitingQueueSize: 0},
- }
- noRolePod1 := &types.PodMetrics{
- Pod: &backend.Pod{
- NamespacedName: k8stypes.NamespacedName{Name: "noRolePod1"},
+ &fwkdl.Metrics{WaitingQueueSize: 0},
+ fwkdl.NewAttributes(),
+ )
+ noRoleEndpoint1 := fwkschd.NewEndpoint(
+ &fwkdl.EndpointMetadata{
+ NamespacedName: k8stypes.NamespacedName{Name: "noRoleEndpoint1"},
Address: "1.1.1.1",
},
- MetricsState: &backendmetrics.MetricsState{WaitingQueueSize: 2},
- }
+ &fwkdl.Metrics{WaitingQueueSize: 2},
+ fwkdl.NewAttributes(),
+ )
- prefillDecodeResult := &types.SchedulingResult{
- ProfileResults: map[string]*types.ProfileRunResult{
- decode: {TargetPods: []types.Pod{
- &types.ScoredPod{
- Pod: pod2,
+ prefillDecodeResult := &fwkschd.SchedulingResult{
+ ProfileResults: map[string]*fwkschd.ProfileRunResult{
+ decode: {TargetEndpoints: []fwkschd.Endpoint{
+ &fwkschd.ScoredEndpoint{
+ Endpoint: endpoint2,
},
},
},
prefill: {
- TargetPods: []types.Pod{
- &types.ScoredPod{
- Pod: pod1,
+ TargetEndpoints: []fwkschd.Endpoint{
+ &fwkschd.ScoredEndpoint{
+ Endpoint: endpoint1,
},
},
},
@@ -76,12 +78,12 @@ func TestPDSchedule(t *testing.T) {
PrimaryProfileName: decode,
}
- decodeResult := &types.SchedulingResult{
- ProfileResults: map[string]*types.ProfileRunResult{
+ decodeResult := &fwkschd.SchedulingResult{
+ ProfileResults: map[string]*fwkschd.ProfileRunResult{
decode: {
- TargetPods: []types.Pod{
- &types.ScoredPod{
- Pod: pod2,
+ TargetEndpoints: []fwkschd.Endpoint{
+ &fwkschd.ScoredEndpoint{
+ Endpoint: endpoint2,
},
},
},
@@ -91,114 +93,114 @@ func TestPDSchedule(t *testing.T) {
tests := []struct {
name string
- req *types.LLMRequest
- input []types.Pod
- wantRes *types.SchedulingResult
- wantRes2 *types.SchedulingResult // a subsequent call to check prefix cache and how it affects PD
+ req *fwkschd.LLMRequest
+ input []fwkschd.Endpoint
+ wantRes *fwkschd.SchedulingResult
+ wantRes2 *fwkschd.SchedulingResult // a subsequent call to check prefix cache and how it affects PD
err bool
}{
{
- name: "no candidate pods",
- req: &types.LLMRequest{
+ name: "no candidate endpoints",
+ req: &fwkschd.LLMRequest{
RequestId: uuid.NewString(),
TargetModel: "any-model",
- Body: &types.LLMRequestBody{
- Completions: &types.CompletionsRequest{
+ Body: &fwkschd.LLMRequestBody{
+ Completions: &fwkschd.CompletionsRequest{
Prompt: "12345678901",
},
},
},
- input: []types.Pod{},
+ input: []fwkschd.Endpoint{},
err: true,
},
{
- name: "one decode pod, long prompt",
- req: &types.LLMRequest{
+ name: "one decode endpoint, long prompt",
+ req: &fwkschd.LLMRequest{
RequestId: uuid.NewString(),
TargetModel: "critical",
- Body: &types.LLMRequestBody{
- Completions: &types.CompletionsRequest{
+ Body: &fwkschd.LLMRequestBody{
+ Completions: &fwkschd.CompletionsRequest{
Prompt: "12345678901",
},
},
},
- // pod2 will be picked because it is the only pod with Decode role
- input: []types.Pod{pod2},
+ // endpoint2 will be picked because it is the only endpoint with Decode role
+ input: []fwkschd.Endpoint{endpoint2},
wantRes: decodeResult,
},
{
- name: "one prefill pod, long prompt",
- req: &types.LLMRequest{
+ name: "one prefill endpoint, long prompt",
+ req: &fwkschd.LLMRequest{
RequestId: uuid.NewString(),
TargetModel: "critical",
- Body: &types.LLMRequestBody{
- Completions: &types.CompletionsRequest{
+ Body: &fwkschd.LLMRequestBody{
+ Completions: &fwkschd.CompletionsRequest{
Prompt: "12345678901",
},
},
},
- // no Decode pod
- input: []types.Pod{pod1},
+ // no Decode endpoint
+ input: []fwkschd.Endpoint{endpoint1},
err: true,
},
{
name: "1P1D - long prompt",
- req: &types.LLMRequest{
+ req: &fwkschd.LLMRequest{
RequestId: uuid.NewString(),
TargetModel: "critical",
- Body: &types.LLMRequestBody{
- Completions: &types.CompletionsRequest{
+ Body: &fwkschd.LLMRequestBody{
+ Completions: &fwkschd.CompletionsRequest{
Prompt: "12345678906",
},
},
},
- // pod2 will be picked in the decode profile result, pod1 will be in the prefill profile result
- input: []types.Pod{pod1, pod2},
+ // endpoint2 will be picked in the decode profile result, endpoint1 will be in the prefill profile result
+ input: []fwkschd.Endpoint{endpoint1, endpoint2},
wantRes: prefillDecodeResult,
wantRes2: decodeResult,
},
{
name: "1P1Dshort",
- req: &types.LLMRequest{
+ req: &fwkschd.LLMRequest{
RequestId: uuid.NewString(),
TargetModel: "critical",
- Body: &types.LLMRequestBody{
- Completions: &types.CompletionsRequest{
+ Body: &fwkschd.LLMRequestBody{
+ Completions: &fwkschd.CompletionsRequest{
Prompt: "12345",
},
},
},
- // pod2 will be picked because it is the decode pod, pod1 shouldn't be picked,
+ // endpoint2 will be picked because it is the decode endpoint, endpoint1 shouldn't be picked,
// because the prompt is too short
- input: []types.Pod{pod1, pod2},
+ input: []fwkschd.Endpoint{endpoint1, endpoint2},
wantRes: decodeResult,
wantRes2: decodeResult,
},
{
name: "TestRolesWithNoDecode",
- req: &types.LLMRequest{
+ req: &fwkschd.LLMRequest{
RequestId: uuid.NewString(),
TargetModel: "critical",
- Body: &types.LLMRequestBody{
- Completions: &types.CompletionsRequest{
+ Body: &fwkschd.LLMRequestBody{
+ Completions: &fwkschd.CompletionsRequest{
Prompt: "12345678901",
},
},
},
- input: []types.Pod{pod1, noRolePod1},
- wantRes: &types.SchedulingResult{
- ProfileResults: map[string]*types.ProfileRunResult{
+ input: []fwkschd.Endpoint{endpoint1, noRoleEndpoint1},
+ wantRes: &fwkschd.SchedulingResult{
+ ProfileResults: map[string]*fwkschd.ProfileRunResult{
decode: {
- TargetPods: []types.Pod{
- &types.ScoredPod{
- Pod: noRolePod1,
+ TargetEndpoints: []fwkschd.Endpoint{
+ &fwkschd.ScoredEndpoint{
+ Endpoint: noRoleEndpoint1,
},
},
},
prefill: {
- TargetPods: []types.Pod{
- &types.ScoredPod{
- Pod: pod1,
+ TargetEndpoints: []fwkschd.Endpoint{
+ &fwkschd.ScoredEndpoint{
+ Endpoint: endpoint1,
},
},
},
@@ -208,18 +210,18 @@ func TestPDSchedule(t *testing.T) {
},
{
name: "1P2D - long prompt",
- req: &types.LLMRequest{
+ req: &fwkschd.LLMRequest{
RequestId: uuid.NewString(),
TargetModel: "critical",
- Body: &types.LLMRequestBody{
- Completions: &types.CompletionsRequest{
- Prompt: "12345678906",
+ Body: &fwkschd.LLMRequestBody{
+ Completions: &fwkschd.CompletionsRequest{
+ Prompt: "1234567890123456789012345678901234567890",
},
},
},
- // pod2 will be picked in the decode profile result cause it has higher score than noRolePod1
- // pod1 will be in the prefill profile result
- input: []types.Pod{pod1, pod2, noRolePod1},
+			// endpoint2 will be picked in the decode profile result because it has a higher score than noRoleEndpoint1
+ // endpoint1 will be in the prefill profile result
+ input: []fwkschd.Endpoint{endpoint1, endpoint2, noRoleEndpoint1},
wantRes: prefillDecodeResult,
wantRes2: decodeResult,
},
@@ -232,49 +234,64 @@ func TestPDSchedule(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
// initialize scheduler with config
- prefixScorer := prefix.New(ctx, prefix.Config{BlockSize: 5, MaxPrefixBlocksToMatch: 256, LRUCapacityPerServer: 31250})
+ prefixScorer, err := prefix.New(ctx, prefix.Config{AutoTune: false, BlockSizeTokens: 2, MaxPrefixBlocksToMatch: 256, LRUCapacityPerServer: 31250})
+ assert.NoError(t, err, "Prefix plugin creation returned unexpected error")
- prefillSchedulerProfile := framework.NewSchedulerProfile().
+ prefillSchedulerProfile := scheduling.NewSchedulerProfile().
WithFilters(filter.NewPrefillRole()).
WithPicker(picker.NewMaxScorePicker(picker.DefaultMaxNumOfEndpoints))
- err := prefillSchedulerProfile.AddPlugins(framework.NewWeightedScorer(prefixScorer, 50))
+ err = prefillSchedulerProfile.AddPlugins(scheduling.NewWeightedScorer(prefixScorer, 50))
assert.NoError(t, err, "SchedulerProfile AddPlugins returned unexpected error")
- decodeSchedulerProfile := framework.NewSchedulerProfile().
+ decodeSchedulerProfile := scheduling.NewSchedulerProfile().
WithFilters(filter.NewDecodeRole()).
- WithScorers(framework.NewWeightedScorer(scorer.NewLoadAware(ctx, scorer.QueueThresholdDefault), 1)).
+ WithScorers(scheduling.NewWeightedScorer(scorer.NewLoadAware(ctx, scorer.QueueThresholdDefault), 1)).
WithPicker(picker.NewMaxScorePicker(picker.DefaultMaxNumOfEndpoints))
- err = decodeSchedulerProfile.AddPlugins(framework.NewWeightedScorer(prefixScorer, 0))
+ err = decodeSchedulerProfile.AddPlugins(scheduling.NewWeightedScorer(prefixScorer, 0))
assert.NoError(t, err, "SchedulerProfile AddPlugins returned unexpected error")
- profileHandle := profile.NewPdProfileHandler(prefill, decode, prefixScorer.TypedName().Type, prefixScorer.TypedName().Name, 10, 5, 0)
+ deciderPlugin, err := profile.NewPrefixBasedPDDecider(profile.PrefixBasedPDDeciderConfig{NonCachedTokens: 2})
+ assert.NoError(t, err)
+
+ profileHandle, err := profile.NewPdProfileHandler(prefill, decode, prefixScorer.TypedName().Type, prefixScorer.TypedName().Name,
+ 0, deciderPlugin)
+ assert.NoError(t, err)
- schedulerConfig := scheduling.NewSchedulerConfig(profileHandle, map[string]*framework.SchedulerProfile{
+ schedulerConfig := scheduling.NewSchedulerConfig(profileHandle, map[string]fwkschd.SchedulerProfile{
prefill: prefillSchedulerProfile,
decode: decodeSchedulerProfile,
})
scheduler := scheduling.NewSchedulerWithConfig(schedulerConfig)
+
+ inputTokens := len(test.req.Body.Completions.Prompt) / profile.AverageCharactersPerToken
+ for _, pod := range test.input {
+ pod.Put(approximateprefix.PrefixCacheMatchInfoKey, approximateprefix.NewPrefixCacheMatchInfo(0, inputTokens, 1))
+ }
got, err := scheduler.Schedule(ctx, test.req, test.input)
if test.err != (err != nil) {
t.Errorf("Unexpected error, got %v, want %v", err, test.err)
}
- if diff := cmp.Diff(test.wantRes, got, cmpopts.IgnoreFields(types.ScoredPod{}, "Score")); diff != "" {
+ if diff := cmp.Diff(test.wantRes, got, cmpopts.IgnoreUnexported(fwkdl.Attributes{}), cmpopts.IgnoreFields(fwkschd.ScoredEndpoint{}, "Score")); diff != "" {
t.Errorf("Unexpected output (-want +got): %v", diff)
}
-
if test.wantRes2 != nil { // Checking the prefix match in the decode pod.
// make sure prefix plugin stores the prefix hit in cache, so we can test it in the following schedule call
prefixScorer.PreRequest(ctx, test.req, got)
time.Sleep(time.Second)
+ // update number of cached tokens "stored" in the first schedule execution
+ for _, pod := range test.input {
+ pod.Put(approximateprefix.PrefixCacheMatchInfoKey, approximateprefix.NewPrefixCacheMatchInfo(inputTokens, inputTokens, 1))
+ }
+
got, err = scheduler.Schedule(ctx, test.req, test.input)
if test.err != (err != nil) {
t.Errorf("Unexpected error in schedule call, got %v, want %v", err, test.err)
}
- if diff := cmp.Diff(test.wantRes2, got, cmpopts.IgnoreFields(types.ScoredPod{}, "Score")); diff != "" {
+ if diff := cmp.Diff(test.wantRes2, got, cmpopts.IgnoreUnexported(fwkdl.Attributes{}), cmpopts.IgnoreFields(fwkschd.ScoredEndpoint{}, "Score")); diff != "" {
t.Errorf("Unexpected output in subsequent schedule call (-want +got): %v", diff)
}
}
diff --git a/pkg/sidecar/proxy/connector_lmcache.go b/pkg/sidecar/proxy/connector_lmcache.go
deleted file mode 100644
index f19412e83..000000000
--- a/pkg/sidecar/proxy/connector_lmcache.go
+++ /dev/null
@@ -1,91 +0,0 @@
-/*
-Copyright 2025 The llm-d Authors.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-*/
-
-package proxy
-
-import (
- "encoding/json"
- "io"
- "net/http"
- "strings"
-)
-
-func (s *Server) runLMCacheProtocol(w http.ResponseWriter, r *http.Request, prefillPodHostPort string) {
- s.logger.Info("running LMCache protocol")
-
- // Read and parse request body
- defer r.Body.Close() //nolint:all
- original, err := io.ReadAll(r.Body)
- if err != nil {
- w.WriteHeader(http.StatusBadRequest) // TODO: check FastAPI error code when failing to read body
- w.Write([]byte(err.Error())) //nolint:all
- return
- }
-
- // Parse completion request
- var completionRequest map[string]any
- if err := json.Unmarshal(original, &completionRequest); err != nil {
- if err := errorJSONInvalid(err, w); err != nil {
- s.logger.Error(err, "failed to send error response to client")
- }
- return
- }
-
- // Create prefiller request. Set max_tokens to 1.
-
- ctx := r.Context()
- preq := r.Clone(ctx)
-
- completionRequest[requestFieldMaxTokens] = 1
- completionRequest[requestFieldMaxCompletionTokens] = 1
-
- pbody, err := json.Marshal(completionRequest)
- if err != nil {
- if err := errorJSONInvalid(err, w); err != nil {
- s.logger.Error(err, "failed to send error response to client")
- }
- return
- }
- preq.Body = io.NopCloser(strings.NewReader(string(pbody)))
- preq.ContentLength = int64(len(pbody))
-
- // Forward request to prefiller
-
- prefillHandler, err := s.prefillerProxyHandler(prefillPodHostPort)
- if err != nil {
- if err := errorBadGateway(err, w); err != nil {
- s.logger.Error(err, "failed to send error response to client")
- }
- return
- }
- s.logger.V(4).Info("sending prefill request", "to", prefillPodHostPort)
- pw := &bufferedResponseWriter{}
- prefillHandler.ServeHTTP(pw, preq)
-
- if pw.statusCode < 200 || pw.statusCode >= 300 {
- s.logger.Error(err, "request failed", "code", pw.statusCode)
- w.WriteHeader(pw.statusCode)
- return
- }
-
- // Forward original request to local decoder
-
- r.Body = io.NopCloser(strings.NewReader(string(original)))
- if !s.forwardDataParallel || !s.dataParallelHandler(w, r) {
- s.logger.V(4).Info("sending request to decoder", "to", s.decoderURL.Host)
- s.decoderProxy.ServeHTTP(w, r)
- }
-}
diff --git a/pkg/sidecar/proxy/connector_nixlv2.go b/pkg/sidecar/proxy/connector_nixlv2.go
index 265072bbf..f8543a711 100644
--- a/pkg/sidecar/proxy/connector_nixlv2.go
+++ b/pkg/sidecar/proxy/connector_nixlv2.go
@@ -107,7 +107,7 @@ func (s *Server) runNIXLProtocolV2(w http.ResponseWriter, r *http.Request, prefi
pw := &bufferedResponseWriter{}
prefillHandler.ServeHTTP(pw, preq)
- if pw.statusCode < 200 || pw.statusCode >= 300 {
+ if isHTTPError(pw.statusCode) {
s.logger.Error(err, "request failed", "code", pw.statusCode)
w.WriteHeader(pw.statusCode)
return
diff --git a/pkg/sidecar/proxy/connector_shared_storage.go b/pkg/sidecar/proxy/connector_shared_storage.go
new file mode 100644
index 000000000..5245f2993
--- /dev/null
+++ b/pkg/sidecar/proxy/connector_shared_storage.go
@@ -0,0 +1,276 @@
+/*
+Copyright 2025 The llm-d Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package proxy
+
+import (
+ "bytes"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "io"
+ "maps"
+ "net/http"
+ "strings"
+)
+
+func (s *Server) runSharedStorageProtocol(w http.ResponseWriter, r *http.Request, prefillPodHostPort string) {
+ s.logger.V(4).Info("running Shared Storage protocol", "url", prefillPodHostPort)
+
+ // Read and parse request body
+ defer r.Body.Close() //nolint:all
+ original, err := io.ReadAll(r.Body)
+ if err != nil {
+ w.WriteHeader(http.StatusBadRequest) // TODO: check FastAPI error code when failing to read body
+ w.Write([]byte(err.Error())) //nolint:all
+ return
+ }
+
+ // Parse completion request
+ var completionRequest map[string]any
+ if err := json.Unmarshal(original, &completionRequest); err != nil {
+ if err := errorJSONInvalid(err, w); err != nil {
+ s.logger.Error(err, "failed to send Invalid JSON error response to client")
+ }
+ return
+ }
+
+ // If "cache_hit_threshold" is present in the request, we try to decode first. The decode node must meet the cache hit threshold in order to execute.
+ // If the decode node is below the threshold, it won't process the request and return a "cache_threshold" finish reason. In that case,
+ // we fall back to P/D disaggregation: perform prefill and then decode.
+ // For more information refer to the RFC https://github.com/vllm-project/vllm/issues/24256
+ if cacheHitThreshold, hasCacheHitThreshold := completionRequest[requestFieldCacheHitThreshold]; hasCacheHitThreshold {
+ s.logger.V(4).Info("cache_hit_threshold field found in the request, trying to decode first", requestFieldCacheHitThreshold, cacheHitThreshold)
+ decodeReq := cloneRequestWithBody(r, original)
+ needsPrefill, err := s.tryDecode(w, decodeReq, completionRequest)
+ if err != nil {
+ return
+ }
+ if !needsPrefill {
+ s.logger.V(4).Info("decode succeeded without prefill")
+ return
+ }
+ s.logger.V(4).Info("decode failed due to failing to meet the cache hit threshold", requestFieldCacheHitThreshold, cacheHitThreshold)
+ }
+
+ // we clone the completion request to avoid modifying the original request
+ prefillRequest := maps.Clone(completionRequest)
+ if err := s.prefill(w, r, prefillPodHostPort, prefillRequest); err != nil {
+ s.logger.Error(err, "prefill failed")
+ return
+ }
+
+ s.logger.V(4).Info("forwarding to decoder after prefill")
+ completionRequest[requestFieldCacheHitThreshold] = 0
+ decodeRequestBody, err := json.Marshal(completionRequest)
+ if err != nil {
+ if err := errorJSONInvalid(err, w); err != nil {
+ s.logger.Error(err, "failed to send Invalid JSON error response to client")
+ }
+ return
+ }
+
+ decodeReq := cloneRequestWithBody(r, decodeRequestBody)
+ s.decoderProxy.ServeHTTP(w, decodeReq)
+}
+
+// tryDecode attempts to decode and returns whether prefill is needed.
+func (s *Server) tryDecode(w http.ResponseWriter, r *http.Request, completionRequest map[string]any) (bool, error) {
+ if isStreaming, _ := completionRequest[requestFieldStream].(bool); isStreaming {
+ if flusher, ok := w.(flushableResponseWriter); ok {
+ bw := newResponseWriterWithBuffer(flusher)
+ return s.tryDecodeStreaming(bw, r)
+ }
+ }
+ return s.tryDecodeBuffered(w, r)
+}
+
+// tryDecodeBuffered handles non-streaming decode attempts.
+// It buffers the entire response before inspecting it.
+func (s *Server) tryDecodeBuffered(w http.ResponseWriter, r *http.Request) (bool, error) {
+ dw := &bufferedResponseWriter{}
+ s.decoderProxy.ServeHTTP(dw, r)
+
+ if isHTTPError(dw.statusCode) {
+
+ w.WriteHeader(dw.statusCode)
+ if dw.buffer.Len() > 0 {
+ w.Write([]byte(dw.buffer.String())) //nolint:all
+ }
+
+ err := errors.New("decode request failed")
+ s.logger.Error(err, "unexpected status code", "code", dw.statusCode)
+
+ return false, err
+ }
+
+ // Parse response to check finish_reason
+ var response map[string]any
+ if err := json.Unmarshal([]byte(dw.buffer.String()), &response); err != nil {
+ s.logger.Error(err, "failed to unmarshal decode response", "response", dw.buffer.String())
+
+ if err := errorInternalServerError(err, w); err != nil {
+ s.logger.Error(err, "failed to send error response to client")
+ }
+ return false, err
+ }
+
+ // Check for cache_threshold finish reason
+ if s.hasCacheThresholdFinishReason(response) {
+ return true, nil
+ }
+
+ // Decode succeeded, write response to client
+ maps.Copy(w.Header(), dw.headers)
+ w.Write([]byte(dw.buffer.String())) //nolint:all
+
+ return false, nil
+}
+
+// tryDecodeStreaming handles streaming decode attempts.
+// It buffers the initial response to check for cache_threshold, then switches
+// to direct streaming mode if decode succeeds.
+func (s *Server) tryDecodeStreaming(w *responseWriterWithBuffer, r *http.Request) (bool, error) {
+ // Run ServeHTTP in a goroutine so we can inspect the initial choice to determine if we need to prefill.
+ done := make(chan struct{})
+ go func() {
+ defer close(done)
+ s.decoderProxy.ServeHTTP(w, r)
+ }()
+
+ // Wait for either:
+ // - firstChunkReady(): first body data is available in buffer
+ // - done: request completed (possibly with no body, e.g., error response)
+ select {
+ case <-w.firstChunkReady():
+ case <-done:
+ s.logger.V(4).Info("request completed without body data")
+ }
+
+ statusCode := w.getStatusCode()
+ if isHTTPError(statusCode) {
+ if err := w.flushBufferAndGoDirect(); err != nil {
+ s.logger.Error(err, "failed to flush buffer to client")
+ return false, err
+ }
+ return false, fmt.Errorf("decode request failed with status code: %d", statusCode)
+ }
+
+ // Check buffered SSE content for cache_threshold finish reason.
+ if s.checkBufferedResponseForCacheThreshold(w.buffered()) {
+ s.logger.V(4).Info("finish reason cache_threshold detected, needs prefill")
+ return true, nil
+ }
+
+ // No cache_threshold finish reason found, flush buffer and switch to direct mode
+ // to let the rest of the response stream through.
+ s.logger.V(4).Info("first response for request shows success without cache_threshold finish reason")
+ if err := w.flushBufferAndGoDirect(); err != nil {
+ s.logger.Error(err, "failed to flush buffer to client and switch to direct mode")
+ return false, err
+ }
+ <-done
+ return false, nil
+}
+
+// hasCacheThresholdFinishReason checks if a parsed response contains cache_threshold finish reason.
+func (s *Server) hasCacheThresholdFinishReason(response map[string]any) bool {
+ choices, ok := response[responseFieldChoices].([]any)
+ if !ok || len(choices) == 0 {
+ return false
+ }
+
+ choice, ok := choices[0].(map[string]any)
+ if !ok {
+ return false
+ }
+
+ finishReason, ok := choice[responseFieldFinishReason].(string)
+ return ok && finishReason == finishReasonCacheThreshold
+}
+
+// checkBufferedResponseForCacheThreshold checks the buffered SSE response for cache_threshold finish reason.
+// This is only called for streaming responses, so data is always in SSE format.
+func (s *Server) checkBufferedResponseForCacheThreshold(data string) bool {
+ // Parse SSE format: "data: {...json...}\n\ndata: {...json...}\n\n"
+ for _, line := range strings.Split(data, "\n") {
+ line = strings.TrimSpace(line)
+ if line == "" || line == "data: [DONE]" || !strings.HasPrefix(line, "data: ") {
+ continue
+ }
+
+ jsonData := strings.TrimPrefix(line, "data: ")
+ var response map[string]any
+ if err := json.Unmarshal([]byte(jsonData), &response); err != nil {
+ s.logger.V(4).Info("skipping malformed SSE chunk", "chunk", jsonData)
+ continue
+ }
+
+ if s.hasCacheThresholdFinishReason(response) {
+ return true
+ }
+ }
+ return false
+}
+
+// prefill routes a request to a prefill node
+func (s *Server) prefill(w http.ResponseWriter, r *http.Request, prefillPodHostPort string, completionRequest map[string]any) error {
+ // Prepare prefill request
+ completionRequest[requestFieldMaxTokens] = 1
+ completionRequest[requestFieldMaxCompletionTokens] = 1
+ completionRequest[requestFieldCacheHitThreshold] = 0
+
+ pbody, err := json.Marshal(completionRequest)
+ if err != nil {
+ if err := errorJSONInvalid(err, w); err != nil {
+ s.logger.Error(err, "failed to send Invalid JSON error response to client")
+ }
+ return err
+ }
+ preq := cloneRequestWithBody(r, pbody)
+
+ prefillHandler, err := s.prefillerProxyHandler(prefillPodHostPort)
+ if err != nil {
+ if err := errorBadGateway(err, w); err != nil {
+ s.logger.Error(err, "failed to send Bad Gateway error response to client")
+ }
+ return err
+ }
+
+ // send prefill request
+ s.logger.V(4).Info("sending prefill request", "to", prefillPodHostPort)
+ pw := &bufferedResponseWriter{}
+ prefillHandler.ServeHTTP(pw, preq)
+
+ if isHTTPError(pw.statusCode) {
+ s.logger.Error(nil, "prefill request failed", "code", pw.statusCode)
+ w.WriteHeader(pw.statusCode)
+ if pw.buffer.Len() > 0 {
+ w.Write([]byte(pw.buffer.String())) //nolint:all
+ }
+ return fmt.Errorf("prefill request failed with status code: %d", pw.statusCode)
+ }
+
+ s.logger.V(4).Info("prefill completed successfully")
+ return nil
+}
+
+func cloneRequestWithBody(r *http.Request, body []byte) *http.Request {
+ cloned := r.Clone(r.Context())
+ cloned.Body = io.NopCloser(bytes.NewReader(body))
+ cloned.ContentLength = int64(len(body))
+ return cloned
+}
diff --git a/pkg/sidecar/proxy/connector_test.go b/pkg/sidecar/proxy/connector_test.go
index 64ba26879..b2fa407f9 100644
--- a/pkg/sidecar/proxy/connector_test.go
+++ b/pkg/sidecar/proxy/connector_test.go
@@ -44,7 +44,7 @@ type sidecarTestInfo struct {
proxy *Server
}
-var connectors = []string{ConnectorLMCache, ConnectorNIXLV2}
+var connectors = []string{ConnectorSharedStorage, ConnectorNIXLV2}
var _ = Describe("Common Connector tests", func() {
diff --git a/pkg/sidecar/proxy/errors.go b/pkg/sidecar/proxy/errors.go
index 0ff35ae75..394e30072 100644
--- a/pkg/sidecar/proxy/errors.go
+++ b/pkg/sidecar/proxy/errors.go
@@ -50,6 +50,10 @@ func errorBadGateway(err error, w http.ResponseWriter) error {
return sendError(err, "BadGateway", http.StatusBadGateway, w)
}
+func errorInternalServerError(err error, w http.ResponseWriter) error {
+ return sendError(err, "InternalServerError", http.StatusInternalServerError, w)
+}
+
// sendError simulates vLLM errors
//
// Example:
diff --git a/pkg/sidecar/proxy/proxy.go b/pkg/sidecar/proxy/proxy.go
index bd90dd3a9..8e3b8e733 100644
--- a/pkg/sidecar/proxy/proxy.go
+++ b/pkg/sidecar/proxy/proxy.go
@@ -46,6 +46,12 @@ const (
requestFieldRemotePort = "remote_port"
requestFieldStream = "stream"
requestFieldStreamOptions = "stream_options"
+ requestFieldCacheHitThreshold = "cache_hit_threshold"
+
+ responseFieldChoices = "choices"
+ responseFieldFinishReason = "finish_reason"
+
+ finishReasonCacheThreshold = "cache_threshold"
// SGLang bootstrap fields
requestFieldBootstrapHost = "bootstrap_host"
@@ -55,8 +61,8 @@ const (
// ConnectorNIXLV2 enables the P/D NIXL v2 protocol
ConnectorNIXLV2 = "nixlv2"
- // ConnectorLMCache enables (now deprecated) P/D LMCache protocol
- ConnectorLMCache = "lmcache"
+ // ConnectorSharedStorage enables (now deprecated) P/D Shared Storage protocol
+ ConnectorSharedStorage = "shared-storage"
// ConnectorSGLang enables SGLang P/D disaggregation protocol
ConnectorSGLang = "sglang"
@@ -177,8 +183,8 @@ func (s *Server) Clone() *Server {
func (s *Server) setConnector() {
switch s.config.Connector {
- case ConnectorLMCache:
- s.runConnectorProtocol = s.runLMCacheProtocol
+ case ConnectorSharedStorage:
+ s.runConnectorProtocol = s.runSharedStorageProtocol
case ConnectorSGLang:
s.runConnectorProtocol = s.runSGLangProtocol
case ConnectorNIXLV2:
diff --git a/pkg/sidecar/proxy/proxy_helpers.go b/pkg/sidecar/proxy/proxy_helpers.go
index 5f68bb4ea..30ba76494 100644
--- a/pkg/sidecar/proxy/proxy_helpers.go
+++ b/pkg/sidecar/proxy/proxy_helpers.go
@@ -124,3 +124,8 @@ func (s *Server) createDecoderProxyHandler(decoderURL *url.URL, decoderInsecureS
}
return decoderProxy
}
+
+// isHTTPError returns true if the status code indicates an error (not in the 2xx range).
+func isHTTPError(statusCode int) bool {
+ return statusCode < http.StatusOK || statusCode >= http.StatusMultipleChoices
+}
diff --git a/pkg/sidecar/proxy/status_response_writer.go b/pkg/sidecar/proxy/status_response_writer.go
index 721e898e6..24d96d263 100644
--- a/pkg/sidecar/proxy/status_response_writer.go
+++ b/pkg/sidecar/proxy/status_response_writer.go
@@ -19,8 +19,12 @@ package proxy
import (
"net/http"
"strings"
+ "sync"
+ "sync/atomic"
)
+const sseEventDelimiter = "\n\n"
+
// bufferedResponseWriter receives responses from prefillers
type bufferedResponseWriter struct {
headers http.Header
@@ -45,3 +49,185 @@ func (w *bufferedResponseWriter) Write(b []byte) (int, error) {
func (w *bufferedResponseWriter) WriteHeader(statusCode int) {
w.statusCode = statusCode
}
+
+type flushableResponseWriter interface {
+ http.ResponseWriter
+ http.Flusher
+}
+
+// responseWriterWithBuffer wraps an http.ResponseWriter to buffer initial writes.
+// Start in buffer mode to inspect the first chunk, then call flushBufferAndGoDirect()
+// to write buffered content and switch to direct pass-through mode.
+type responseWriterWithBuffer struct {
+ writerFlusher flushableResponseWriter
+
+ // buffering is checked atomically to allow lock-free fast paths
+ // in direct mode (Write and Flush).
+ buffering atomic.Bool
+
+ // mu protects buffer, statusCode, and wroteHeader during buffering mode
+ // and during the transition to direct mode.
+ mu sync.Mutex
+ buffer strings.Builder
+ statusCode int
+ wroteHeader bool
+
+	// ready is signaled (sent to once, then closed) when enough body data has
+	// been buffered for inspection; see shouldSignal and firstChunkReady.
+ ready chan struct{}
+ readyOnce sync.Once
+}
+
+// newResponseWriterWithBuffer creates a new writer starting in buffer mode.
+func newResponseWriterWithBuffer(w flushableResponseWriter) *responseWriterWithBuffer {
+ rw := &responseWriterWithBuffer{
+ writerFlusher: w,
+ ready: make(chan struct{}, 1), // buffered to avoid blocking sender
+ }
+ rw.buffering.Store(true)
+ return rw
+}
+
+func (w *responseWriterWithBuffer) Header() http.Header {
+ return w.writerFlusher.Header()
+}
+
+func (w *responseWriterWithBuffer) Write(b []byte) (int, error) {
+ if !w.buffering.Load() {
+ return w.writerFlusher.Write(b)
+ }
+
+ // Buffering mode, need lock to protect buffer
+ w.mu.Lock()
+ defer w.mu.Unlock()
+
+ // Double-check after acquiring lock (may have transitioned to direct mode)
+ if !w.buffering.Load() {
+ return w.writerFlusher.Write(b)
+ }
+
+ if w.statusCode == 0 {
+ w.statusCode = http.StatusOK
+ }
+
+ // Write() always returns a nil error
+ n, _ := w.buffer.Write(b)
+
+ // Signal ready when buffer contains at least 2 complete SSE events.
+ // For SSE streaming, the first chunk is just the role announcement with
+ // finish_reason:null. We need the second chunk to see if cache_threshold
+ // was returned (early abort) or if decode is proceeding normally.
+ if shouldSignal(w.buffer.String()) {
+ w.signalReady()
+ }
+
+ return n, nil
+}
+
+func (w *responseWriterWithBuffer) WriteHeader(statusCode int) {
+ if !w.buffering.Load() {
+ w.writerFlusher.WriteHeader(statusCode)
+ return
+ }
+
+ w.mu.Lock()
+ defer w.mu.Unlock()
+
+ if !w.buffering.Load() {
+ w.writerFlusher.WriteHeader(statusCode)
+ return
+ }
+
+ if w.statusCode == 0 {
+ w.statusCode = statusCode
+ }
+}
+
+func (w *responseWriterWithBuffer) Flush() {
+ if w.buffering.Load() {
+ // Apply same logic as Write(): only signal when we have at least 2 SSE events.
+ w.mu.Lock()
+ shouldSignal := shouldSignal(w.buffer.String())
+ w.mu.Unlock()
+ if shouldSignal {
+ w.signalReady()
+ }
+ return
+ }
+ w.writerFlusher.Flush()
+}
+
+// firstChunkReady returns a channel that is signaled (and then closed) once
+// enough body data is buffered for inspection. For SSE streaming responses,
+// this happens when the buffer holds at least two complete events, i.e. two
+// "\n\n" delimiters (see shouldSignal); Flush() applies the same check.
+func (w *responseWriterWithBuffer) firstChunkReady() <-chan struct{} {
+ return w.ready
+}
+
+func (w *responseWriterWithBuffer) signalReady() {
+ w.readyOnce.Do(func() {
+ w.ready <- struct{}{}
+ close(w.ready)
+ })
+}
+
+func (w *responseWriterWithBuffer) writeHeaderOnce() {
+ if w.wroteHeader {
+ return
+ }
+ w.wroteHeader = true
+ if w.statusCode != 0 {
+ w.writerFlusher.WriteHeader(w.statusCode)
+ }
+}
+
+// buffered returns the currently buffered content for inspection.
+func (w *responseWriterWithBuffer) buffered() string {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+ return w.buffer.String()
+}
+
+// getStatusCode returns the status code that was set (0 if not set).
+func (w *responseWriterWithBuffer) getStatusCode() int {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+ return w.statusCode
+}
+
+// flushBufferAndGoDirect writes any buffered content to the underlying writer
+// and switches to direct mode for all subsequent writes.
+func (w *responseWriterWithBuffer) flushBufferAndGoDirect() error {
+ w.mu.Lock()
+ defer w.mu.Unlock()
+
+ if !w.buffering.Load() {
+ return nil // already in direct mode
+ }
+
+ w.writeHeaderOnce()
+
+ // Write buffered content to underlying writer
+ if w.buffer.Len() > 0 {
+ _, err := w.writerFlusher.Write([]byte(w.buffer.String()))
+ if err != nil {
+ return err
+ }
+ }
+
+ // Flush BEFORE switching to direct mode.
+ // This prevents concurrent Flush() calls on the underlying writer,
+ // since the proxy goroutine's Flush() will no-op while buffering=true.
+ w.writerFlusher.Flush()
+
+ // Switch to direct mode. After this, the proxy goroutine handles
+ // all writes and flushes directly (no concurrency with us).
+ w.buffering.Store(false)
+
+ return nil
+}
+
+func shouldSignal(data string) bool {
+ return strings.Count(data, sseEventDelimiter) >= 2
+}
diff --git a/scripts/fetch-python-wrapper.sh b/scripts/fetch-python-wrapper.sh
deleted file mode 100755
index 0c8bc1d68..000000000
--- a/scripts/fetch-python-wrapper.sh
+++ /dev/null
@@ -1,47 +0,0 @@
-#!/usr/bin/env bash
-# fetch-python-wrapper.sh
-# Fetches the Python wrapper file (render_jinja_template_wrapper.py) from llm-d-kv-cache-manager
-# for use in Docker builds and local development.
-# Version can be provided as CLI arg or via KVCACHE_MANAGER_VERSION env var (default v0.3.2).
-#
-# This script replicates the original Dockerfile logic:
-# 1. Creates a temporary directory
-# 2. Clones the repo into that directory
-# 3. Creates the output directory structure
-# 4. Copies the wrapper file to the output location
-# 5. Cleans up the temporary directory
-
-set -euo pipefail
-
-VERSION="${1:-${KVCACHE_MANAGER_VERSION:-v0.3.2}}"
-OUTPUT_DIR="${2:-llm-d-kv-cache-manager/pkg/preprocessing/chat_completions}"
-
-REPO_URL="https://github.com/llm-d/llm-d-kv-cache-manager.git"
-WRAPPER_FILE="pkg/preprocessing/chat_completions/render_jinja_template_wrapper.py"
-
-# Create temporary directory (equivalent to: mkdir -p /tmp/kv-cache-manager)
-# TEMP_DIR will be an absolute path like /tmp/tmp.XXXXXX
-TEMP_DIR=$(mktemp -d)
-trap "rm -rf ${TEMP_DIR}" EXIT
-
-echo "Fetching Python wrapper from llm-d-kv-cache-manager@${VERSION}..."
-
-# Equivalent to: cd /tmp/kv-cache-manager && git clone ... .
-# (clones repo contents directly into TEMP_DIR - using absolute path, no need to cd)
-git clone --depth 1 --branch "${VERSION}" "${REPO_URL}" "${TEMP_DIR}"
-
-# Create output directory if it doesn't exist
-# (equivalent to: mkdir -p /workspace/llm-d-kv-cache-manager/pkg/preprocessing/chat_completions)
-# OUTPUT_DIR is relative to current working directory (relative path, same as original)
-mkdir -p "${OUTPUT_DIR}"
-
-# Copy wrapper file
-# Source: absolute path ${TEMP_DIR}/${WRAPPER_FILE} (e.g., /tmp/tmp.XXXXXX/pkg/.../wrapper.py)
-# Destination: relative path ${OUTPUT_DIR}/ (e.g., llm-d-kv-cache-manager/pkg/.../)
-# (equivalent to original: cp pkg/.../wrapper.py /workspace/... from within temp dir)
-cp "${TEMP_DIR}/${WRAPPER_FILE}" "${OUTPUT_DIR}/"
-
-# Cleanup happens automatically via trap (equivalent to: rm -rf /tmp/kv-cache-manager)
-
-echo "Successfully fetched render_jinja_template_wrapper.py to ${OUTPUT_DIR}/"
-
diff --git a/scripts/kind-dev-env.sh b/scripts/kind-dev-env.sh
index 5ef91726d..6b0d22a04 100755
--- a/scripts/kind-dev-env.sh
+++ b/scripts/kind-dev-env.sh
@@ -37,7 +37,7 @@ EPP_IMAGE="${EPP_IMAGE:-${IMAGE_REGISTRY}/llm-d-inference-scheduler:${EPP_TAG}}"
export EPP_IMAGE
# Set the model name to deploy
-export MODEL_NAME="${MODEL_NAME:-food-review}"
+export MODEL_NAME="${MODEL_NAME:-TinyLlama/TinyLlama-1.1B-Chat-v1.0}"
# Extract model family (e.g., "meta-llama" from "meta-llama/Llama-3.1-8B-Instruct")
export MODEL_FAMILY="${MODEL_NAME%%/*}"
# Extract model ID (e.g., "Llama-3.1-8B-Instruct")
@@ -74,32 +74,41 @@ export VLLM_REPLICA_COUNT_D="${VLLM_REPLICA_COUNT_D:-2}"
# Data Parallel size
export VLLM_DATA_PARALLEL_SIZE="${VLLM_DATA_PARALLEL_SIZE:-1}"
-PRIMARY_PORT="0"
-if [ "${PD_ENABLED}" != "\"true\"" ] && [ ${VLLM_DATA_PARALLEL_SIZE} -eq 1 ]; then
- if [ "${KV_CACHE_ENABLED}" != "true" ]; then
- DEFAULT_EPP_CONFIG="deploy/config/sim-epp-config.yaml"
- else
- DEFAULT_EPP_CONFIG="deploy/config/sim-epp-kvcache-config.yaml"
- fi
-else
- if [ "${KV_CACHE_ENABLED}" != "true" ]; then
- if [ "${PD_ENABLED}" == "\"true\"" ]; then
- DEFAULT_EPP_CONFIG="deploy/config/sim-pd-epp-config.yaml"
- if [ ${VLLM_DATA_PARALLEL_SIZE} -ne 1 ]; then
- PRIMARY_PORT="8000"
- fi
- else
- DEFAULT_EPP_CONFIG="deploy/config/dp-epp-config.yaml"
- fi
- else
+# Validate configuration constraints
+if [ "${KV_CACHE_ENABLED}" == "true" ]; then
+ # KV cache requires simple mode: no PD and DP size must be 1
+ if [ "${PD_ENABLED}" == "\"true\"" ] || [ ${VLLM_DATA_PARALLEL_SIZE} -ne 1 ]; then
echo "Invalid configuration: PD_ENABLED=true and KV_CACHE_ENABLED=true is not supported"
exit 1
fi
fi
-export EPP_CONFIG="${EPP_CONFIG:-${DEFAULT_EPP_CONFIG}}"
+# Set PRIMARY_PORT based on PD mode with data parallelism
+if [ "${PD_ENABLED}" == "\"true\"" ] && [ ${VLLM_DATA_PARALLEL_SIZE} -ne 1 ]; then
+ PRIMARY_PORT="8000"
+else
+ PRIMARY_PORT="0"
+fi
export PRIMARY_PORT
+# Determine EPP config file based on feature flags
+if [ "${KV_CACHE_ENABLED}" == "true" ]; then
+ # KV cache mode (simple mode only)
+ DEFAULT_EPP_CONFIG="deploy/config/sim-epp-kvcache-config.yaml"
+elif [ "${PD_ENABLED}" == "\"true\"" ]; then
+ # Prefill-Decode mode
+ DEFAULT_EPP_CONFIG="deploy/config/sim-pd-epp-config.yaml"
+elif [ ${VLLM_DATA_PARALLEL_SIZE} -ne 1 ]; then
+  # Data Parallel mode (this config is only needed for Istio versions before 1.28.1)
+  # Not exercised by "make env-dev-kind" in kind, which uses docker.io/istio/pilot:1.28.1
+ DEFAULT_EPP_CONFIG="deploy/config/dp-epp-config.yaml"
+else
+ # Simple mode
+ DEFAULT_EPP_CONFIG="deploy/config/sim-epp-config.yaml"
+fi
+
+export EPP_CONFIG="${EPP_CONFIG:-${DEFAULT_EPP_CONFIG}}"
+
# ------------------------------------------------------------------------------
# Setup & Requirement Checks
# ------------------------------------------------------------------------------
diff --git a/scripts/kubernetes-dev-env.sh b/scripts/kubernetes-dev-env.sh
index 6cd8c4456..215fe86a4 100755
--- a/scripts/kubernetes-dev-env.sh
+++ b/scripts/kubernetes-dev-env.sh
@@ -24,7 +24,7 @@ if [[ -z "${HF_TOKEN:-}" ]]; then
exit 1
fi
-export VLLM_CHART_DIR="${VLLM_CHART_DIR:-../llm-d-kv-cache-manager/vllm-setup-helm}"
+export VLLM_CHART_DIR="${VLLM_CHART_DIR:-../llm-d-kv-cache/vllm-setup-helm}"
# Check that Chart.yaml exists
if [[ ! -f "$VLLM_CHART_DIR/Chart.yaml" ]]; then
echo "Chart.yaml not found in $VLLM_CHART_DIR"
diff --git a/scripts/pull_images.sh b/scripts/pull_images.sh
index 3acf439c1..e82337f32 100755
--- a/scripts/pull_images.sh
+++ b/scripts/pull_images.sh
@@ -7,7 +7,7 @@ echo "Using container tool: ${CONTAINER_RUNTIME}"
# Set a default EPP_TAG if not provided
EPP_TAG="${EPP_TAG:-dev}"
# Set a default VLLM_SIMULATOR_TAG if not provided
-VLLM_SIMULATOR_TAG="${VLLM_SIMULATOR_TAG:-v0.6.1}"
+VLLM_SIMULATOR_TAG="${VLLM_SIMULATOR_TAG:-latest}"
# Set the default routing side car image tag
SIDECAR_TAG="${SIDECAR_TAG:-dev}"
diff --git a/test-blocking-labels-1771128456.txt b/test-blocking-labels-1771128456.txt
new file mode 100644
index 000000000..e5dba9286
--- /dev/null
+++ b/test-blocking-labels-1771128456.txt
@@ -0,0 +1 @@
+Test file for: blocking-labels - Sun Feb 15 06:07:39 IST 2026
diff --git a/test-blocking-labels-1771128602.txt b/test-blocking-labels-1771128602.txt
new file mode 100644
index 000000000..b546b6479
--- /dev/null
+++ b/test-blocking-labels-1771128602.txt
@@ -0,0 +1 @@
+Test file for: blocking-labels - Sun Feb 15 06:10:06 IST 2026
diff --git a/test-lgtm-1770898248.txt b/test-lgtm-1770898248.txt
new file mode 100644
index 000000000..b0ce6a34d
--- /dev/null
+++ b/test-lgtm-1770898248.txt
@@ -0,0 +1 @@
+This is a test file for LGTM workflow automation - Thu Feb 12 14:10:50 IST 2026
diff --git a/test-lgtm-1770899120.txt b/test-lgtm-1770899120.txt
new file mode 100644
index 000000000..086324f21
--- /dev/null
+++ b/test-lgtm-1770899120.txt
@@ -0,0 +1 @@
+This is a test file for LGTM workflow automation - Thu Feb 12 14:25:22 IST 2026
diff --git a/test-lgtm-1770899284.txt b/test-lgtm-1770899284.txt
new file mode 100644
index 000000000..034733efc
--- /dev/null
+++ b/test-lgtm-1770899284.txt
@@ -0,0 +1 @@
+This is a test file for LGTM workflow automation - Thu Feb 12 14:28:06 IST 2026
diff --git a/test-lgtm-1770900845.txt b/test-lgtm-1770900845.txt
new file mode 100644
index 000000000..d6ca35750
--- /dev/null
+++ b/test-lgtm-1770900845.txt
@@ -0,0 +1 @@
+This is a test file for LGTM workflow automation - Thu Feb 12 14:54:07 IST 2026
diff --git a/test-lgtm-1770901687.txt b/test-lgtm-1770901687.txt
new file mode 100644
index 000000000..1b0ff35e7
--- /dev/null
+++ b/test-lgtm-1770901687.txt
@@ -0,0 +1 @@
+This is a test file for LGTM workflow automation - Thu Feb 12 15:08:08 IST 2026
diff --git a/test-lgtm-1770902441.txt b/test-lgtm-1770902441.txt
new file mode 100644
index 000000000..50c087c16
--- /dev/null
+++ b/test-lgtm-1770902441.txt
@@ -0,0 +1 @@
+This is a test file for LGTM workflow automation - Thu Feb 12 15:20:43 IST 2026
diff --git a/test-lgtm-1771123253.txt b/test-lgtm-1771123253.txt
new file mode 100644
index 000000000..5891c2bf6
--- /dev/null
+++ b/test-lgtm-1771123253.txt
@@ -0,0 +1 @@
+This is a test file for LGTM workflow automation - Sun Feb 15 04:40:55 IST 2026
diff --git a/test-lgtm-1771123487.txt b/test-lgtm-1771123487.txt
new file mode 100644
index 000000000..b52683765
--- /dev/null
+++ b/test-lgtm-1771123487.txt
@@ -0,0 +1 @@
+This is a test file for LGTM workflow automation - Sun Feb 15 04:44:49 IST 2026
diff --git a/test-lgtm-1771123805.txt b/test-lgtm-1771123805.txt
new file mode 100644
index 000000000..d6e5201db
--- /dev/null
+++ b/test-lgtm-1771123805.txt
@@ -0,0 +1 @@
+This is a test file for LGTM workflow automation - Sun Feb 15 04:50:07 IST 2026
diff --git a/test-lgtm-1771123942.txt b/test-lgtm-1771123942.txt
new file mode 100644
index 000000000..a880ada61
--- /dev/null
+++ b/test-lgtm-1771123942.txt
@@ -0,0 +1 @@
+This is a test file for LGTM workflow automation - Sun Feb 15 04:52:23 IST 2026
diff --git a/test-lgtm-1771124326.txt b/test-lgtm-1771124326.txt
new file mode 100644
index 000000000..b84903b8b
--- /dev/null
+++ b/test-lgtm-1771124326.txt
@@ -0,0 +1 @@
+This is a test file for LGTM workflow automation - Sun Feb 15 04:58:48 IST 2026
diff --git a/test-lgtm-1771125160.txt b/test-lgtm-1771125160.txt
new file mode 100644
index 000000000..ac1a8cf75
--- /dev/null
+++ b/test-lgtm-1771125160.txt
@@ -0,0 +1 @@
+This is a test file for LGTM workflow automation - Sun Feb 15 05:12:42 IST 2026
diff --git a/test-lgtm-1771125246.txt b/test-lgtm-1771125246.txt
new file mode 100644
index 000000000..15e6d5010
--- /dev/null
+++ b/test-lgtm-1771125246.txt
@@ -0,0 +1 @@
+This is a test file for LGTM workflow automation - Sun Feb 15 05:14:08 IST 2026
diff --git a/test-lgtm-1771125307.txt b/test-lgtm-1771125307.txt
new file mode 100644
index 000000000..7676ff8dd
--- /dev/null
+++ b/test-lgtm-1771125307.txt
@@ -0,0 +1 @@
+This is a test file for LGTM workflow automation - Sun Feb 15 05:15:08 IST 2026
diff --git a/test-lgtm-1771126707.txt b/test-lgtm-1771126707.txt
new file mode 100644
index 000000000..262706925
--- /dev/null
+++ b/test-lgtm-1771126707.txt
@@ -0,0 +1 @@
+This is a test file for LGTM workflow automation - Sun Feb 15 05:38:29 IST 2026
diff --git a/test-lgtm-1771127615.txt b/test-lgtm-1771127615.txt
new file mode 100644
index 000000000..fe1386765
--- /dev/null
+++ b/test-lgtm-1771127615.txt
@@ -0,0 +1 @@
+This is a test file for LGTM workflow automation - Sun Feb 15 05:53:36 IST 2026
diff --git a/test-open-pr-1771133038.txt b/test-open-pr-1771133038.txt
new file mode 100644
index 000000000..58d095262
--- /dev/null
+++ b/test-open-pr-1771133038.txt
@@ -0,0 +1 @@
+Test file for: open-pr - Sun Feb 15 07:24:04 IST 2026
diff --git a/test-open-pr-1771133183.txt b/test-open-pr-1771133183.txt
new file mode 100644
index 000000000..8c635fe9c
--- /dev/null
+++ b/test-open-pr-1771133183.txt
@@ -0,0 +1 @@
+Test file for: open-pr - Sun Feb 15 07:26:29 IST 2026
diff --git a/test-open-pr-1771135205.txt b/test-open-pr-1771135205.txt
new file mode 100644
index 000000000..153db7286
--- /dev/null
+++ b/test-open-pr-1771135205.txt
@@ -0,0 +1 @@
+Test file for: open-pr - Sun Feb 15 08:00:10 IST 2026
diff --git a/test-open-pr-1771137224.txt b/test-open-pr-1771137224.txt
new file mode 100644
index 000000000..40e1c3735
--- /dev/null
+++ b/test-open-pr-1771137224.txt
@@ -0,0 +1 @@
+Test file for: open-pr - Sun Feb 15 08:33:50 IST 2026
diff --git a/test-success-path-1771126029.txt b/test-success-path-1771126029.txt
new file mode 100644
index 000000000..bd613da8c
--- /dev/null
+++ b/test-success-path-1771126029.txt
@@ -0,0 +1 @@
+Test file for: success-path - Sun Feb 15 05:27:12 IST 2026
diff --git a/test-success-path-1771126298.txt b/test-success-path-1771126298.txt
new file mode 100644
index 000000000..59aa2a331
--- /dev/null
+++ b/test-success-path-1771126298.txt
@@ -0,0 +1 @@
+Test file for: success-path - Sun Feb 15 05:31:42 IST 2026
diff --git a/test/config/prefix_cache_mode_test.go b/test/config/prefix_cache_mode_test.go
index 5a13d0087..3965f9246 100644
--- a/test/config/prefix_cache_mode_test.go
+++ b/test/config/prefix_cache_mode_test.go
@@ -7,7 +7,7 @@ import (
"github.com/go-logr/logr"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/config/loader"
- giePlugins "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins"
+ giePlugins "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin"
"sigs.k8s.io/gateway-api-inference-extension/test/utils"
"github.com/llm-d/llm-d-inference-scheduler/pkg/plugins"
diff --git a/test/e2e/e2e_suite_test.go b/test/e2e/e2e_suite_test.go
index f4f5c9247..ff400872e 100644
--- a/test/e2e/e2e_suite_test.go
+++ b/test/e2e/e2e_suite_test.go
@@ -27,6 +27,8 @@ import (
)
const (
+ // kindClusterName is the name of the Kind cluster created for e2e tests.
+ kindClusterName = "e2e-tests"
// defaultReadyTimeout is the default timeout for a resource to report a ready state.
defaultReadyTimeout = 3 * time.Minute
// defaultInterval is the default interval to check if a resource exists or ready conditions.
@@ -61,9 +63,8 @@ var (
containerRuntime = env.GetEnvString("CONTAINER_RUNTIME", "docker", ginkgo.GinkgoLogr)
eppImage = env.GetEnvString("EPP_IMAGE", "ghcr.io/llm-d/llm-d-inference-scheduler:dev", ginkgo.GinkgoLogr)
- vllmSimImage = env.GetEnvString("VLLM_SIMULATOR_IMAGE", "ghcr.io/llm-d/llm-d-inference-sim:dev", ginkgo.GinkgoLogr)
+ vllmSimImage = env.GetEnvString("VLLM_SIMULATOR_IMAGE", "ghcr.io/llm-d/llm-d-inference-sim:latest", ginkgo.GinkgoLogr)
sideCarImage = env.GetEnvString("SIDECAR_IMAGE", "ghcr.io/llm-d/llm-d-routing-sidecar:dev", ginkgo.GinkgoLogr)
-
// nsName is the namespace in which the K8S objects will be created
nsName = env.GetEnvString("NAMESPACE", "default", ginkgo.GinkgoLogr)
@@ -110,8 +111,18 @@ var _ = ginkgo.BeforeSuite(func() {
})
var _ = ginkgo.AfterSuite(func() {
- if k8sContext != "" {
- // Used an existing Kubernetes context
+ if k8sContext == "" {
+ // delete kind cluster we created
+ ginkgo.By("Deleting kind cluster " + kindClusterName)
+ command := exec.Command("kind", "delete", "cluster", "--name", kindClusterName)
+ session, err := gexec.Start(command, ginkgo.GinkgoWriter, ginkgo.GinkgoWriter)
+ if err != nil {
+ ginkgo.GinkgoLogr.Error(err, "Failed to delete kind cluster")
+ } else {
+ gomega.Eventually(session).WithTimeout(60 * time.Second).Should(gexec.Exit())
+ }
+ } else {
+ // Used an existing Kubernetes context, clean up created resources
// Stop port-forward
if portForwardSession != nil {
portForwardSession.Terminate()
@@ -135,18 +146,12 @@ var _ = ginkgo.AfterSuite(func() {
err := testConfig.KubeCli.CoreV1().Namespaces().Delete(testConfig.Context, nsName, metav1.DeleteOptions{})
gomega.Expect(err).ShouldNot(gomega.HaveOccurred())
}
- return
}
-
- command := exec.Command("kind", "delete", "cluster", "--name", "e2e-tests")
- session, err := gexec.Start(command, ginkgo.GinkgoWriter, ginkgo.GinkgoWriter)
- gomega.Expect(err).ShouldNot(gomega.HaveOccurred())
- gomega.Eventually(session).WithTimeout(600 * time.Second).Should(gexec.Exit(0))
})
// Create the Kubernetes cluster for the E2E tests and load the local images
func setupK8sCluster() {
- command := exec.Command("kind", "create", "cluster", "--name", "e2e-tests", "--config", "-")
+ command := exec.Command("kind", "create", "cluster", "--name", kindClusterName, "--config", "-")
stdin, err := command.StdinPipe()
gomega.Expect(err).ShouldNot(gomega.HaveOccurred())
go func() {
@@ -166,17 +171,27 @@ func setupK8sCluster() {
kindLoadImage(vllmSimImage)
kindLoadImage(eppImage)
kindLoadImage(sideCarImage)
+ kindLoadImage(vllmSimImage)
}
func kindLoadImage(image string) {
tempDir := ginkgo.GinkgoT().TempDir()
target := tempDir + "/container.tar"
- ginkgo.By(fmt.Sprintf("Loading %s into the cluster e2e-tests using %s", image, containerRuntime))
+ ginkgo.By(fmt.Sprintf("Loading %s into the cluster %s using %s", image, kindClusterName, containerRuntime))
_, err := exec.LookPath(containerRuntime)
gomega.Expect(err).ShouldNot(gomega.HaveOccurred(), "Could not find %s in PATH", containerRuntime)
+ // Pull the image first to ensure it's available locally
+ ginkgo.By(fmt.Sprintf("Pulling image %s if not available locally", image))
+ pullCommand := exec.Command(containerRuntime, "pull", image)
+ pullSession, pullErr := gexec.Start(pullCommand, ginkgo.GinkgoWriter, ginkgo.GinkgoWriter)
+ if pullErr == nil {
+		// Wait for the pull process to exit, accepting any exit code
+		// (the image may already exist locally or be a local-only build)
+ gomega.Eventually(pullSession).WithTimeout(600 * time.Second).Should(gexec.Exit())
+ }
+
saveArgs := []string{"save", "--output", target}
if containerRuntime == "docker" {
// The platform flag is required for docker save to work but it is an unsupported flag for podman
@@ -189,7 +204,7 @@ func kindLoadImage(image string) {
gomega.Expect(err).ShouldNot(gomega.HaveOccurred())
gomega.Eventually(session).WithTimeout(600 * time.Second).Should(gexec.Exit(0))
- command = exec.Command("kind", "--name", "e2e-tests", "load", "image-archive", target)
+ command = exec.Command("kind", "--name", kindClusterName, "load", "image-archive", target)
session, err = gexec.Start(command, ginkgo.GinkgoWriter, ginkgo.GinkgoWriter)
gomega.Expect(err).ShouldNot(gomega.HaveOccurred())
gomega.Eventually(session).WithTimeout(600 * time.Second).Should(gexec.Exit(0))
@@ -278,10 +293,11 @@ func createInferencePool(numTargetPorts int, toDelete bool) []string {
}
infPoolYaml := testutils.ReadYaml(inferExtManifest)
- targetPorts := ""
+ var b strings.Builder
for idx := range numTargetPorts {
- targetPorts += fmt.Sprintf("\n - number: %d", 8000+idx)
+ fmt.Fprintf(&b, "\n - number: %d", 8000+idx)
}
+ targetPorts := b.String()
infPoolYaml = substituteMany(infPoolYaml,
map[string]string{
"${POOL_NAME}": poolName,
diff --git a/test/e2e/e2e_test.go b/test/e2e/e2e_test.go
index d23b20a58..44671ed71 100644
--- a/test/e2e/e2e_test.go
+++ b/test/e2e/e2e_test.go
@@ -25,7 +25,7 @@ const (
// running the vLLM simulator without PD
simDeployment = "./yaml/vllm-sim.yaml"
// simPDDeployment references the YAML file for the deployment
- // running the vLLM simulator with PD
+ // running the vLLM simulator with PD (connector type is configurable via ${CONNECTOR_TYPE})
simPDDeployment = "./yaml/vllm-sim-pd.yaml"
// simDPDeployment references the YAML file for the deployment
// running the vLLM simulator with Data Parallel
@@ -126,8 +126,178 @@ var _ = ginkgo.Describe("Run end to end tests", ginkgo.Ordered, func() {
labelFilter2 := fmt.Sprintf(`decision_type="decode-only",model_name="%s"`, modelName)
decodeOnlyCount := getCounterMetric(metricsURL, "llm_d_inference_scheduler_pd_decision_total", labelFilter2)
- gomega.Expect(prefillDecodeCount).Should(gomega.Equal(6))
- gomega.Expect(decodeOnlyCount).Should(gomega.Equal(0))
+ gomega.Expect(prefillDecodeCount).Should(gomega.Equal(4))
+ gomega.Expect(decodeOnlyCount).Should(gomega.Equal(2))
+
+ testutils.DeleteObjects(testConfig, epp)
+ testutils.DeleteObjects(testConfig, modelServers)
+ })
+ })
+
+ ginkgo.When("Running a PD configuration with shared-storage connector", func() {
+ ginkgo.It("should run regular (non-streaming) requests successfully", func() {
+ infPoolObjects = createInferencePool(1, true)
+
+ prefillReplicas := 1
+ decodeReplicas := 2
+ modelServers := createModelServersWithConnector(true, false, false, 0, prefillReplicas, decodeReplicas, "shared-storage")
+
+ epp := createEndPointPicker(pdConfig)
+
+ prefillPods, decodePods := getModelServerPods(podSelector, prefillSelector, decodeSelector)
+ gomega.Expect(prefillPods).Should(gomega.HaveLen(prefillReplicas))
+ gomega.Expect(decodePods).Should(gomega.HaveLen(decodeReplicas))
+
+ // Test regular completion request
+ nsHdr, podHdrCompletion, _ := runCompletion(simplePrompt, modelName)
+ gomega.Expect(nsHdr).Should(gomega.Equal(nsName))
+ gomega.Expect(podHdrCompletion).Should(gomega.BeElementOf(decodePods))
+
+ // Test regular chat completion request
+ nsHdr, podHdrChat, _ := runChatCompletion(simplePrompt)
+ gomega.Expect(nsHdr).Should(gomega.Equal(nsName))
+ gomega.Expect(podHdrChat).Should(gomega.BeElementOf(decodePods))
+
+ // Run completion with a different prompt
+ nsHdr, podHdr, _ := runCompletion(extraPrompt, modelName)
+ gomega.Expect(nsHdr).Should(gomega.Equal(nsName))
+ gomega.Expect(podHdr).Should(gomega.BeElementOf(decodePods))
+
+ // Run completion with original prompt (should go to same pod due to prefix cache)
+ nsHdr, podHdr, _ = runCompletion(simplePrompt, modelName)
+ gomega.Expect(nsHdr).Should(gomega.Equal(nsName))
+ gomega.Expect(podHdr).Should(gomega.BeElementOf(decodePods))
+ gomega.Expect(podHdr).Should(gomega.Equal(podHdrCompletion))
+
+ testutils.DeleteObjects(testConfig, epp)
+ testutils.DeleteObjects(testConfig, modelServers)
+ })
+
+ ginkgo.It("should run streaming requests successfully", func() {
+ infPoolObjects = createInferencePool(1, true)
+
+ prefillReplicas := 1
+ decodeReplicas := 2
+ modelServers := createModelServersWithConnector(true, false, false, 0, prefillReplicas, decodeReplicas, "shared-storage")
+
+ epp := createEndPointPicker(pdConfig)
+
+ prefillPods, decodePods := getModelServerPods(podSelector, prefillSelector, decodeSelector)
+ gomega.Expect(prefillPods).Should(gomega.HaveLen(prefillReplicas))
+ gomega.Expect(decodePods).Should(gomega.HaveLen(decodeReplicas))
+
+ // Test streaming completion request
+ nsHdr, podHdr := runStreamingCompletion(simplePrompt, modelName)
+ gomega.Expect(nsHdr).Should(gomega.Equal(nsName))
+ gomega.Expect(podHdr).Should(gomega.BeElementOf(decodePods))
+
+ // Test streaming chat completion request
+ nsHdr, podHdr = runStreamingChatCompletion(simplePrompt)
+ gomega.Expect(nsHdr).Should(gomega.Equal(nsName))
+ gomega.Expect(podHdr).Should(gomega.BeElementOf(decodePods))
+
+ // Run streaming completion with a different prompt
+ nsHdr, podHdr = runStreamingCompletion(extraPrompt, modelName)
+ gomega.Expect(nsHdr).Should(gomega.Equal(nsName))
+ gomega.Expect(podHdr).Should(gomega.BeElementOf(decodePods))
+
+ testutils.DeleteObjects(testConfig, epp)
+ testutils.DeleteObjects(testConfig, modelServers)
+ })
+
+ ginkgo.It("should handle decode-first success scenario with cache_hit_threshold", func() {
+ // This test verifies the decode-first optimization:
+ // When cache_hit_threshold is set and the decode succeeds (cache hit),
+ // the request should complete without falling back to P/D.
+ // IMPORTANT: The prefill pod should NOT process any requests in this scenario.
+ infPoolObjects = createInferencePool(1, true)
+
+ prefillReplicas := 1
+ decodeReplicas := 2
+ modelServers := createModelServersWithConnector(true, false, false, 0, prefillReplicas, decodeReplicas, "shared-storage")
+
+ epp := createEndPointPicker(pdConfig)
+
+ prefillPods, decodePods := getModelServerPods(podSelector, prefillSelector, decodeSelector)
+ gomega.Expect(prefillPods).Should(gomega.HaveLen(prefillReplicas))
+ gomega.Expect(decodePods).Should(gomega.HaveLen(decodeReplicas))
+
+ // Get prefill request count BEFORE the test
+ prefillCountBefore := getPrefillRequestCount(prefillPods[0])
+ ginkgo.By(fmt.Sprintf("Prefill request count before decode-first test: %d", prefillCountBefore))
+
+ // Test decode-first success: cache_hit_threshold is set, but simulator returns "stop"
+ // (without X-Cache-Threshold header), meaning decode succeeded without prefill
+ nsHdr, podHdr, finishReason := runCompletionWithCacheThreshold(simplePrompt, 0.5, false)
+ gomega.Expect(nsHdr).Should(gomega.Equal(nsName))
+ gomega.Expect(podHdr).Should(gomega.BeElementOf(decodePods))
+ gomega.Expect(finishReason).ShouldNot(gomega.Equal("cache_threshold"))
+
+ // Test streaming decode-first success
+ nsHdr, podHdr, finishReason = runStreamingCompletionWithCacheThreshold(simplePrompt, 0.5, false)
+ gomega.Expect(nsHdr).Should(gomega.Equal(nsName))
+ gomega.Expect(podHdr).Should(gomega.BeElementOf(decodePods))
+ gomega.Expect(finishReason).ShouldNot(gomega.Equal("cache_threshold"))
+
+ // Get prefill request count AFTER the test
+ prefillCountAfter := getPrefillRequestCount(prefillPods[0])
+ ginkgo.By(fmt.Sprintf("Prefill request count after decode-first test: %d", prefillCountAfter))
+
+ // VERIFY: Prefill pod should NOT have processed any new requests
+ // (decode-first succeeded, so no P/D fallback occurred)
+ gomega.Expect(prefillCountAfter).Should(gomega.Equal(prefillCountBefore),
+ "Prefill pod should NOT process requests when cache threshold is met (decode-first success)")
+
+ testutils.DeleteObjects(testConfig, epp)
+ testutils.DeleteObjects(testConfig, modelServers)
+ })
+
+ ginkgo.It("should handle decode-first fallback to P/D when cache threshold not met", func() {
+ // This test verifies the decode-first fallback scenario:
+ // When cache_hit_threshold is set and the decode returns cache_threshold finish_reason,
+ // the sidecar should fall back to P/D disaggregation.
+ // IMPORTANT: The prefill pod SHOULD process requests in this scenario.
+ infPoolObjects = createInferencePool(1, true)
+
+ prefillReplicas := 1
+ decodeReplicas := 2
+ modelServers := createModelServersWithConnector(true, false, false, 0, prefillReplicas, decodeReplicas, "shared-storage")
+
+ epp := createEndPointPicker(pdConfig)
+
+ prefillPods, decodePods := getModelServerPods(podSelector, prefillSelector, decodeSelector)
+ gomega.Expect(prefillPods).Should(gomega.HaveLen(prefillReplicas))
+ gomega.Expect(decodePods).Should(gomega.HaveLen(decodeReplicas))
+
+ // Get prefill request count BEFORE the test
+ prefillCountBefore := getPrefillRequestCount(prefillPods[0])
+ ginkgo.By(fmt.Sprintf("Prefill request count before P/D fallback test: %d", prefillCountBefore))
+
+ // Test decode-first fallback: cache_hit_threshold is set AND X-Cache-Threshold header
+ // forces simulator to return "cache_threshold" finish_reason, triggering P/D fallback
+ nsHdr, podHdr, finishReason := runCompletionWithCacheThreshold(simplePrompt, 0.5, true)
+ gomega.Expect(nsHdr).Should(gomega.Equal(nsName))
+ gomega.Expect(podHdr).Should(gomega.BeElementOf(decodePods))
+ // The sidecar completes the P/D flow but returns cache_threshold as the finish_reason
+ // from the initial decode attempt (which triggered the fallback)
+ gomega.Expect(finishReason).Should(gomega.Equal("cache_threshold"))
+
+ // Test streaming decode-first fallback
+ nsHdr, podHdr, finishReason = runStreamingCompletionWithCacheThreshold(extraPrompt, 0.5, true)
+ gomega.Expect(nsHdr).Should(gomega.Equal(nsName))
+ gomega.Expect(podHdr).Should(gomega.BeElementOf(decodePods))
+ gomega.Expect(finishReason).Should(gomega.Equal("cache_threshold"))
+
+ // Get prefill request count AFTER the test
+ prefillCountAfter := getPrefillRequestCount(prefillPods[0])
+ ginkgo.By(fmt.Sprintf("Prefill request count after P/D fallback test: %d", prefillCountAfter))
+
+ // VERIFY: Prefill pod SHOULD have processed 2 new requests (1 regular + 1 streaming)
+ // (decode-first failed, so P/D fallback occurred and prefill was invoked)
+ gomega.Expect(prefillCountAfter).Should(gomega.BeNumerically(">", prefillCountBefore),
+ "Prefill pod SHOULD process requests when cache threshold is NOT met (P/D fallback)")
+ gomega.Expect(prefillCountAfter-prefillCountBefore).Should(gomega.Equal(2),
+ "Prefill pod should have processed exactly 2 requests (1 regular + 1 streaming)")
testutils.DeleteObjects(testConfig, epp)
testutils.DeleteObjects(testConfig, modelServers)
@@ -265,7 +435,13 @@ var _ = ginkgo.Describe("Run end to end tests", ginkgo.Ordered, func() {
})
// createModelServers creates the model server resources used for testing from the given filePaths.
+// Uses the default connector (nixlv2) for P/D deployments.
func createModelServers(withPD, withKV, withDP bool, vllmReplicas, prefillReplicas, decodeReplicas int) []string {
+ return createModelServersWithConnector(withPD, withKV, withDP, vllmReplicas, prefillReplicas, decodeReplicas, "nixlv2")
+}
+
+// createModelServersWithConnector creates model server resources with a specific connector type.
+func createModelServersWithConnector(withPD, withKV, withDP bool, vllmReplicas, prefillReplicas, decodeReplicas int, connector string) []string {
theModelName := modelName
theSafeModelName := modelName
if withKV {
@@ -286,6 +462,7 @@ func createModelServers(withPD, withKV, withDP bool, vllmReplicas, prefillReplic
"${MODEL_NAME_SAFE}": theSafeModelName,
"${POOL_NAME}": poolName,
"${KV_CACHE_ENABLED}": strconv.FormatBool(withKV),
+ "${CONNECTOR_TYPE}": connector,
"${SIDECAR_IMAGE}": sideCarImage,
"${VLLM_REPLICA_COUNT}": strconv.Itoa(vllmReplicas),
"${VLLM_REPLICA_COUNT_D}": strconv.Itoa(decodeReplicas),
@@ -429,6 +606,219 @@ func getCounterMetric(metricsURL, metricName, labelMatch string) int {
return 0
}
+func runStreamingCompletion(prompt string, theModel openai.CompletionNewParamsModel) (string, string) {
+ ginkgo.By(fmt.Sprintf("Sending Streaming Completion Request: (port %s) model=%s", port, theModel))
+
+ // Use raw HTTP for streaming to capture headers
+ body := fmt.Sprintf(`{"model":"%s","prompt":"%s","max_tokens":50,"stream":true}`, theModel, prompt)
+ req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("http://localhost:%s/v1/completions", port), strings.NewReader(body))
+ gomega.Expect(err).ShouldNot(gomega.HaveOccurred())
+ req.Header.Set("Content-Type", "application/json")
+
+ resp, err := http.DefaultClient.Do(req)
+ gomega.Expect(err).ShouldNot(gomega.HaveOccurred())
+ defer func() {
+ err := resp.Body.Close()
+ gomega.Expect(err).ToNot(gomega.HaveOccurred())
+ }()
+
+ gomega.Expect(resp.StatusCode).Should(gomega.Equal(http.StatusOK))
+
+ namespaceHeader := resp.Header.Get("x-inference-namespace")
+ podHeader := resp.Header.Get("x-inference-pod")
+
+ // Read and verify the streaming response
+ respBody, err := io.ReadAll(resp.Body)
+ gomega.Expect(err).ShouldNot(gomega.HaveOccurred())
+
+ ginkgo.By(fmt.Sprintf("Streaming Completion received response length: %d bytes", len(respBody)))
+
+ return namespaceHeader, podHeader
+}
+
+func runStreamingChatCompletion(prompt string) (string, string) {
+ ginkgo.By(fmt.Sprintf("Sending Streaming Chat Completion Request: (port %s)", port))
+
+ // Use raw HTTP for streaming to capture headers
+ body := fmt.Sprintf(`{"model":"%s","messages":[{"role":"user","content":"%s"}],"stream":true}`, modelName, prompt)
+ req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("http://localhost:%s/v1/chat/completions", port), strings.NewReader(body))
+ gomega.Expect(err).ShouldNot(gomega.HaveOccurred())
+ req.Header.Set("Content-Type", "application/json")
+
+ resp, err := http.DefaultClient.Do(req)
+ gomega.Expect(err).ShouldNot(gomega.HaveOccurred())
+ defer func() {
+ err := resp.Body.Close()
+ gomega.Expect(err).ToNot(gomega.HaveOccurred())
+ }()
+
+ gomega.Expect(resp.StatusCode).Should(gomega.Equal(http.StatusOK))
+
+ namespaceHeader := resp.Header.Get("x-inference-namespace")
+ podHeader := resp.Header.Get("x-inference-pod")
+
+ // Read and verify the streaming response
+ respBody, err := io.ReadAll(resp.Body)
+ gomega.Expect(err).ShouldNot(gomega.HaveOccurred())
+
+ ginkgo.By(fmt.Sprintf("Streaming Chat Completion received response length: %d bytes", len(respBody)))
+
+ return namespaceHeader, podHeader
+}
+
+// runCompletionWithCacheThreshold sends a completion request with cache_hit_threshold parameter.
+// This triggers the decode-first optimization in the shared-storage connector.
+// Returns namespace header, pod header, and the finish reason from the response.
+func runCompletionWithCacheThreshold(prompt string, cacheHitThreshold float64, forceCacheThresholdFinishReason bool) (string, string, string) {
+ ginkgo.By(fmt.Sprintf("Sending Completion Request with cache_hit_threshold=%v, forceCacheThreshold=%v", cacheHitThreshold, forceCacheThresholdFinishReason))
+
+ body := fmt.Sprintf(`{"model":"%s","prompt":"%s","max_tokens":10,"cache_hit_threshold":%v}`, modelName, prompt, cacheHitThreshold)
+ req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("http://localhost:%s/v1/completions", port), strings.NewReader(body))
+ gomega.Expect(err).ShouldNot(gomega.HaveOccurred())
+ req.Header.Set("Content-Type", "application/json")
+
+ // Add X-Cache-Threshold header to force the simulator to return cache_threshold finish_reason
+ if forceCacheThresholdFinishReason {
+ req.Header.Set("X-Cache-Threshold-Finish-Reason", "true")
+ }
+
+ resp, err := http.DefaultClient.Do(req)
+ gomega.Expect(err).ShouldNot(gomega.HaveOccurred())
+ defer func() {
+ err := resp.Body.Close()
+ gomega.Expect(err).ToNot(gomega.HaveOccurred())
+ }()
+
+ gomega.Expect(resp.StatusCode).Should(gomega.Equal(http.StatusOK))
+
+ namespaceHeader := resp.Header.Get("x-inference-namespace")
+ podHeader := resp.Header.Get("x-inference-pod")
+
+ // Parse response to get finish_reason
+ respBody, err := io.ReadAll(resp.Body)
+ gomega.Expect(err).ShouldNot(gomega.HaveOccurred())
+
+ // Extract finish_reason from JSON response
+ finishReason := extractFinishReason(string(respBody))
+
+ ginkgo.By(fmt.Sprintf("Completion Response: ns=%s, pod=%s, finish_reason=%s", namespaceHeader, podHeader, finishReason))
+
+ return namespaceHeader, podHeader, finishReason
+}
+
+// runStreamingCompletionWithCacheThreshold sends a streaming completion request with cache_hit_threshold.
+func runStreamingCompletionWithCacheThreshold(prompt string, cacheHitThreshold float64, forceCacheThresholdFinishReason bool) (string, string, string) {
+ ginkgo.By(fmt.Sprintf("Sending Streaming Completion Request with cache_hit_threshold=%v, forceCacheThreshold=%v", cacheHitThreshold, forceCacheThresholdFinishReason))
+
+ body := fmt.Sprintf(`{"model":"%s","prompt":"%s","max_tokens":10,"stream":true,"cache_hit_threshold":%v}`, modelName, prompt, cacheHitThreshold)
+ req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("http://localhost:%s/v1/completions", port), strings.NewReader(body))
+ gomega.Expect(err).ShouldNot(gomega.HaveOccurred())
+ req.Header.Set("Content-Type", "application/json")
+
+ if forceCacheThresholdFinishReason {
+ req.Header.Set("X-Cache-Threshold-Finish-Reason", "true")
+ }
+
+ resp, err := http.DefaultClient.Do(req)
+ gomega.Expect(err).ShouldNot(gomega.HaveOccurred())
+ defer func() {
+ err := resp.Body.Close()
+ gomega.Expect(err).ToNot(gomega.HaveOccurred())
+ }()
+
+ gomega.Expect(resp.StatusCode).Should(gomega.Equal(http.StatusOK))
+
+ namespaceHeader := resp.Header.Get("x-inference-namespace")
+ podHeader := resp.Header.Get("x-inference-pod")
+
+ // Read streaming response and extract finish_reason from the last data chunk
+ respBody, err := io.ReadAll(resp.Body)
+ gomega.Expect(err).ShouldNot(gomega.HaveOccurred())
+
+ finishReason := extractFinishReasonFromStreaming(string(respBody))
+
+ ginkgo.By(fmt.Sprintf("Streaming Completion Response: ns=%s, pod=%s, finish_reason=%s", namespaceHeader, podHeader, finishReason))
+
+ return namespaceHeader, podHeader, finishReason
+}
+
+// extractFinishReason extracts the finish_reason field from a JSON response string.
+func extractFinishReason(jsonStr string) string {
+ // Simple extraction - look for "finish_reason":"value" pattern
+ idx := strings.Index(jsonStr, `"finish_reason":"`)
+ if idx == -1 {
+ // Try with null value
+ if strings.Contains(jsonStr, `"finish_reason":null`) {
+ return "null"
+ }
+ return ""
+ }
+ start := idx + len(`"finish_reason":"`)
+ end := strings.Index(jsonStr[start:], `"`)
+ if end == -1 {
+ return ""
+ }
+ return jsonStr[start : start+end]
+}
+
+// extractFinishReasonFromStreaming extracts the finish_reason from the last SSE data chunk.
+func extractFinishReasonFromStreaming(sseData string) string {
+ // Find the last "finish_reason" that is not null
+ lines := strings.Split(sseData, "\n")
+ lastFinishReason := ""
+ for _, line := range lines {
+ if strings.HasPrefix(line, "data: ") && !strings.Contains(line, "[DONE]") {
+ fr := extractFinishReason(line)
+ if fr != "" && fr != "null" {
+ lastFinishReason = fr
+ }
+ }
+ }
+ return lastFinishReason
+}
+
+// getPrefillRequestCount gets the total request count from a prefill pod's metrics endpoint.
+// This is used to verify whether a request was processed by the prefill pod.
+func getPrefillRequestCount(prefillPodName string) int {
+ ginkgo.By("Getting request count from prefill pod: " + prefillPodName)
+
+ // Use Kubernetes API proxy to access the metrics endpoint
+ output, err := testConfig.KubeCli.CoreV1().RESTClient().
+ Get().
+ Namespace(nsName).
+ Resource("pods").
+ Name(prefillPodName + ":8000").
+ SubResource("proxy").
+ Suffix("metrics").
+ DoRaw(testConfig.Context)
+ if err != nil {
+ ginkgo.By(fmt.Sprintf("Warning: Could not get metrics from prefill pod %s: %v", prefillPodName, err))
+ return -1
+ }
+
+ return parseRequestCountFromMetrics(string(output))
+}
+
+// parseRequestCountFromMetrics extracts the request count from Prometheus metrics output.
+func parseRequestCountFromMetrics(metricsOutput string) int {
+ // Look for vllm:e2e_request_latency_seconds_count{model_name="food-review"}
+ lines := strings.Split(metricsOutput, "\n")
+ for _, line := range lines {
+ if strings.Contains(line, "vllm:e2e_request_latency_seconds_count") &&
+ strings.Contains(line, "food-review") {
+ // Extract the count value after the last space
+ parts := strings.Fields(line)
+ if len(parts) >= 2 {
+ count, err := strconv.Atoi(parts[len(parts)-1])
+ if err == nil {
+ return count
+ }
+ }
+ }
+ }
+ return 0
+}
+
// Simple EPP configuration for running without P/D
const simpleConfig = `apiVersion: inference.networking.x-k8s.io/v1alpha1
kind: EndpointPickerConfig
@@ -453,19 +843,24 @@ schedulingProfiles:
// EPP configuration for running with P/D
const pdConfig = `apiVersion: inference.networking.x-k8s.io/v1alpha1
kind: EndpointPickerConfig
+featureGates:
+- prepareDataPlugins
plugins:
- type: prefill-header-handler
- type: prefix-cache-scorer
parameters:
- hashBlockSize: 10
+ blockSizeTokens: 16
maxPrefixBlocksToMatch: 256
lruCapacityPerServer: 256
- type: prefill-filter
- type: decode-filter
- type: max-score-picker
+- type: prefix-based-pd-decider
+ parameters:
+ nonCachedTokens: 16
- type: pd-profile-handler
parameters:
- threshold: 10
+ deciderPluginName: prefix-based-pd-decider
schedulingProfiles:
- name: prefill
plugins:
@@ -487,15 +882,16 @@ kind: EndpointPickerConfig
plugins:
- type: precise-prefix-cache-scorer
parameters:
+ tokenProcessorConfig:
+ blockSize: 16
+ hashSeed: "42"
kvEventsConfig:
zmqEndpoint: tcp://0.0.0.0:5557
indexerConfig:
prefixStoreConfig:
blockSize: 16
- tokenProcessorConfig:
- blockSize: 16 # must match vLLM block size if not default (16)
- hashSeed: "42" # must match PYTHONHASHSEED in vLLM pods
tokenizersPoolConfig:
+ modelName: Qwen/Qwen2.5-1.5B-Instruct
hf:
tokenizersCacheDir: "/cache/tokenizers"
kvBlockIndexConfig:
diff --git a/test/e2e/yaml/vllm-sim-pd.yaml b/test/e2e/yaml/vllm-sim-pd.yaml
index 0f757c2f2..7a8abf14c 100644
--- a/test/e2e/yaml/vllm-sim-pd.yaml
+++ b/test/e2e/yaml/vllm-sim-pd.yaml
@@ -65,7 +65,7 @@ spec:
args:
- "--port=8000"
- "--vllm-port=8200"
- - "--connector=nixlv2"
+ - "--connector=${CONNECTOR_TYPE}"
- "--secure-proxy=false"
- "--decoder-use-tls=false"
ports:
diff --git a/test/e2e/yaml/vllm-sim.yaml b/test/e2e/yaml/vllm-sim.yaml
index 36036996c..af6ad5d89 100644
--- a/test/e2e/yaml/vllm-sim.yaml
+++ b/test/e2e/yaml/vllm-sim.yaml
@@ -15,41 +15,45 @@ spec:
app: ${POOL_NAME}
spec:
containers:
- - name: vllm
- image: ${VLLM_SIMULATOR_IMAGE}
- imagePullPolicy: IfNotPresent
- args:
- - "--mode=echo"
- - "--enable-kvcache=${KV_CACHE_ENABLED}"
- - "--port=8000"
- - "--model=${MODEL_NAME}"
- - "--kv-cache-size=1024"
- - "--block-size=16"
- - "--zmq-endpoint=tcp://e2e-epp.default.svc.cluster.local:5557"
- - "--event-batch-size=16"
- - "--tokenizers-cache-dir=/tokenizer-cache"
- ports:
- - name: http
- containerPort: 8000
- protocol: TCP
- env:
- - name: PORT
- value: "8000"
- - name: PYTHONHASHSEED
- value: "42"
- - name: POD_NAME
- valueFrom:
- fieldRef:
- apiVersion: v1
- fieldPath: metadata.name
- - name: POD_NAMESPACE
- valueFrom:
- fieldRef:
- apiVersion: v1
- fieldPath: metadata.namespace
- volumeMounts:
- - name: tokenizer-cache
- mountPath: /tokenizer-cache
+ - name: vllm
+ image: ${VLLM_SIMULATOR_IMAGE}
+ imagePullPolicy: IfNotPresent
+ args:
+ - "--mode=echo"
+ - "--enable-kvcache=${KV_CACHE_ENABLED}"
+ - "--port=8000"
+ - "--model=${MODEL_NAME}"
+ - "--kv-cache-size=1024"
+ - "--block-size=16"
+ - "--zmq-endpoint=tcp://e2e-epp.default.svc.cluster.local:5557"
+ - "--event-batch-size=16"
+ - "--tokenizers-cache-dir=/tokenizer-cache"
+ ports:
+ - name: http
+ containerPort: 8000
+ protocol: TCP
+ env:
+ - name: POD_IP
+ valueFrom:
+ fieldRef:
+ fieldPath: status.podIP
+ - name: PORT
+ value: "8000"
+ - name: PYTHONHASHSEED
+ value: "42"
+ - name: POD_NAME
+ valueFrom:
+ fieldRef:
+ apiVersion: v1
+ fieldPath: metadata.name
+ - name: POD_NAMESPACE
+ valueFrom:
+ fieldRef:
+ apiVersion: v1
+ fieldPath: metadata.namespace
+ volumeMounts:
+ - name: tokenizer-cache
+ mountPath: /tokenizer-cache
volumes:
- - name: tokenizer-cache
- emptyDir: {}
+ - name: tokenizer-cache
+ emptyDir: {}
diff --git a/test/scripts/run_e2e.sh b/test/scripts/run_e2e.sh
index 5d81e4ce6..a51fe0f4b 100755
--- a/test/scripts/run_e2e.sh
+++ b/test/scripts/run_e2e.sh
@@ -1,5 +1,17 @@
#!/bin/bash
+set -euo pipefail
+
+cleanup() {
+ echo "Interrupted! Cleaning up kind cluster..."
+ kind delete cluster --name e2e-tests 2>/dev/null || true
+ exit 130 # 128 + SIGINT(2) by shell convention; NOTE: trap also fires on TERM, which reuses this code
+}
+
+# Set trap only for interruption signals
+# Normally kind cluster cleanup is done by AfterSuite
+trap cleanup INT TERM
+
echo "Running end to end tests"
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
diff --git a/test/sidecar/config/nixl/qwen-decoder-pod.yaml b/test/sidecar/config/nixl/qwen-decoder-pod.yaml
index aaf7e1a80..6e24e5a85 100644
--- a/test/sidecar/config/nixl/qwen-decoder-pod.yaml
+++ b/test/sidecar/config/nixl/qwen-decoder-pod.yaml
@@ -58,13 +58,13 @@ spec:
fieldRef:
fieldPath: status.podIP
- name: VLLM_NIXL_SIDE_CHANNEL_PORT
- value: "5557"
+ value: "5600"
- name: HF_HUB_CACHE
value: /vllm-workspace/models
- name: VLLM_LOGGING_LEVEL
value: DEBUG
ports:
- - containerPort: 5557
+ - containerPort: 5600
protocol: TCP
volumeMounts:
- name: model-cache
diff --git a/test/sidecar/config/nixl/qwen-prefiller-pod.yaml b/test/sidecar/config/nixl/qwen-prefiller-pod.yaml
index 1792a1c85..d5bb900e5 100644
--- a/test/sidecar/config/nixl/qwen-prefiller-pod.yaml
+++ b/test/sidecar/config/nixl/qwen-prefiller-pod.yaml
@@ -38,7 +38,7 @@ spec:
- name: UCX_TLS
value: "cuda_ipc,cuda_copy,tcp"
- name: VLLM_NIXL_SIDE_CHANNEL_PORT
- value: "5557"
+ value: "5600"
- name: VLLM_NIXL_SIDE_CHANNEL_HOST
valueFrom:
fieldRef:
@@ -53,7 +53,7 @@ spec:
ports:
- containerPort: 8000
protocol: TCP
- - containerPort: 5557
+ - containerPort: 5600
protocol: TCP
resources:
limits:
diff --git a/test/sidecar/mock/chat_completions_handler.go b/test/sidecar/mock/chat_completions_handler.go
index ddc94f917..6d5b8cd35 100644
--- a/test/sidecar/mock/chat_completions_handler.go
+++ b/test/sidecar/mock/chat_completions_handler.go
@@ -132,8 +132,8 @@ func (cc *ChatCompletionHandler) ServeHTTP(w http.ResponseWriter, r *http.Reques
}
- case "lmcache":
- // LMCache protocol just returns empty response
+ case "shared-storage":
+ // Shared Storage protocol just returns empty response
rawResponse = `{}`
default: