diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 000000000..73ae7ffbe --- /dev/null +++ b/.dockerignore @@ -0,0 +1,21 @@ +# Git +.git +.gitignore + +# Build artifacts +bin +build + +# IDE and OS files +.idea +.vscode +*.DS_Store + +# Local virtual environments +venv + +# Python cache files +__pycache__ + +# Docker files +Dockerfile diff --git a/.github/workflows/auto-assign.yaml b/.github/workflows/auto-assign.yaml index 9fa5e0447..4c9eca29f 100644 --- a/.github/workflows/auto-assign.yaml +++ b/.github/workflows/auto-assign.yaml @@ -2,9 +2,9 @@ name: Auto Assign Reviewers on: pull_request: - types: [opened, reopened, synchronize] + types: [opened] pull_request_target: - types: [opened, reopened, synchronize] + types: [opened] permissions: contents: read diff --git a/.github/workflows/check-typos.yaml b/.github/workflows/check-typos.yaml index ea4fb6859..6bb82afe5 100644 --- a/.github/workflows/check-typos.yaml +++ b/.github/workflows/check-typos.yaml @@ -13,5 +13,5 @@ jobs: uses: actions/checkout@v6 - name: Check typos - uses: crate-ci/typos@v1.42.0 + uses: crate-ci/typos@v1.43.0 diff --git a/.github/workflows/ci-pr-checks.yaml b/.github/workflows/ci-pr-checks.yaml index dbf6e3dd9..ccbdf1e05 100644 --- a/.github/workflows/ci-pr-checks.yaml +++ b/.github/workflows/ci-pr-checks.yaml @@ -12,7 +12,7 @@ jobs: check-changes: runs-on: ubuntu-latest outputs: - docs: ${{ steps.filter.outputs.docs }} + src: ${{ steps.filter.outputs.src }} steps: - name: Checkout source uses: actions/checkout@v6 @@ -20,14 +20,23 @@ jobs: id: filter with: filters: | - docs: - - 'README.md' - - 'docs/**' + src: + - '**/*.go' + - '**/*.py' + - Dockerfile.epp + - Dockerfile.sidecar + - Makefile* + - go.mod lint-and-test: needs: check-changes - if: ${{ needs.check-changes.outputs.docs == 'false' }} + if: ${{ needs.check-changes.outputs.src == 'true' }} runs-on: ubuntu-latest steps: + - name: Free Disk Space (Ubuntu) + uses: jlumbroso/free-disk-space@main + with: + tool-cache: 
false + - name: Checkout source uses: actions/checkout@v6 @@ -43,9 +52,6 @@ jobs: go-version: "${{ env.GO_VERSION }}" cache-dependency-path: ./go.sum - - name: Install dependencies - run: sudo make install-dependencies - - name: Configure CGO for Python run: | PYTHON_INCLUDE=$(python3 -c "import sysconfig; print(sysconfig.get_path('include'))") @@ -57,14 +63,17 @@ jobs: - name: Set PKG_CONFIG_PATH run: echo "PKG_CONFIG_PATH=/usr/lib/pkgconfig" >> $GITHUB_ENV - - name: go mod tidy - run: go mod tidy + - name: Install dependencies + run: | + go mod tidy + sudo -E env "PATH=$PATH" make install-dependencies install-python-deps - name: Run lint checks uses: golangci/golangci-lint-action@v9 with: - version: 'v2.1.6' + version: "v2.1.6" args: "--config=./.golangci.yml" + skip-cache: true env: CGO_ENABLED: ${{ env.CGO_ENABLED }} CGO_CFLAGS: ${{ env.CGO_CFLAGS }} @@ -74,10 +83,8 @@ jobs: - name: Run make build shell: bash - run: | - make build + run: make build - name: Run make test shell: bash - run: | - make test + run: make test diff --git a/.github/workflows/ci-release.yaml b/.github/workflows/ci-release.yaml index 233debc30..6734738e6 100644 --- a/.github/workflows/ci-release.yaml +++ b/.github/workflows/ci-release.yaml @@ -11,6 +11,11 @@ jobs: docker-build-and-push: runs-on: ubuntu-latest steps: + - name: Free Disk Space (Ubuntu) + uses: jlumbroso/free-disk-space@main + with: + tool-cache: false + - name: Checkout source uses: actions/checkout@v6 diff --git a/.github/workflows/dispatch-on-lgtm.yml b/.github/workflows/dispatch-on-lgtm.yml new file mode 100644 index 000000000..61e66a751 --- /dev/null +++ b/.github/workflows/dispatch-on-lgtm.yml @@ -0,0 +1,18 @@ +name: ChatOps Dispatcher +on: + issue_comment: + types: [created] + +jobs: + dispatch: + runs-on: ubuntu-latest + steps: + - name: Slash Command Dispatch + uses: peter-evans/slash-command-dispatch@v3 + with: + token: ${{ secrets.GITHUB_TOKEN }} + commands: lgtm + # 1. 
Basic Security: Only allow users with 'write' access to even trigger this + permission: write + issue-type: pull-request + reactions: false diff --git a/.github/workflows/lgtm-command.yml b/.github/workflows/lgtm-command.yml new file mode 100644 index 000000000..5f7d93e2c --- /dev/null +++ b/.github/workflows/lgtm-command.yml @@ -0,0 +1,134 @@ +# ============================================================================ +# LGTM Command Worker +# ============================================================================ +# Handles /lgtm commands from chatops-dispatcher +# +# Flow: +# 1. User comments /lgtm on PR +# 2. chatops-dispatcher catches it and dispatches here +# 3. This workflow: +# - Verifies user is in OWNERS file +# - Checks if PR is draft +# - Checks for blocking labels (hold) +# - Adds lgtm label +# - Enables auto-merge +# ============================================================================ + +name: LGTM Command Worker +on: + repository_dispatch: + types: [lgtm-command] + +env: + BLOCKING_LABELS: "hold" + +jobs: + apply-lgtm: + runs-on: ubuntu-latest + permissions: + contents: write + pull-requests: write + issues: write + steps: + - uses: actions/checkout@v4 + - uses: tibdex/github-app-token@v1 + id: generate-token + with: + app_id: ${{ secrets.VLLMD_BOT_APP_ID }} + private_key: ${{ secrets.VLLMD_BOT_APP_PRIVATE_KEY }} + repository: ${{ github.repository }} + + # ----------------------------------------------------------------------- + # STEP 1: AUTHORIZATION - Verify user is in OWNERS file + # ----------------------------------------------------------------------- + # Only users listed as approvers in OWNERS can use /lgtm + # This prevents unauthorized users from approving PRs + - name: Check Permissions + env: + ACTOR: ${{ github.event.client_payload.github.actor }} + run: | + # Extract only the approvers section and check if ACTOR is listed + APPROVERS=$(sed -n '/^approvers:/,/^[^ -]/p' OWNERS | grep -v '^approvers:') + if echo 
"$APPROVERS" | grep -q "^\s*-\s*$ACTOR\s*$"; then + echo "User $ACTOR is authorized" + else + echo "::error:: User $ACTOR is not an approver." + exit 1 + fi + + # ----------------------------------------------------------------------- + # STEP 2: VALIDATION - Check if PR is in draft mode + # ----------------------------------------------------------------------- + # Draft PRs cannot be approved - they must be marked ready for review + # This prevents accidental approval of incomplete work + - name: Check Draft Status + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + PR_NUMBER: ${{ github.event.client_payload.github.payload.issue.number }} + REPO: ${{ github.repository }} + run: | + IS_DRAFT=$(gh pr view $PR_NUMBER --repo "$REPO" --json isDraft --jq '.isDraft') + if [ "$IS_DRAFT" = "true" ]; then + echo "::error:: Cannot LGTM a Draft PR." + gh issue comment $PR_NUMBER --repo "$REPO" --body "⚠️ **LGTM Failed**: PR is a Draft." + exit 1 + fi + + # ----------------------------------------------------------------------- + # STEP 3: BLOCKING LABELS - Check for hold + # ----------------------------------------------------------------------- + # If any blocking label exists, fail immediately + # This prevents approving PRs that are explicitly marked as not ready + - name: Check for Blocking Labels + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + PR_URL: ${{ github.event.client_payload.github.payload.issue.html_url }} + run: | + LABELS=$(gh pr view "$PR_URL" --json labels --jq '.labels[].name') + if echo "$LABELS" | grep -Eiq "^($BLOCKING_LABELS)$"; then + echo "::error:: PR is blocked by label." + gh issue comment "$PR_URL" --body " **Merge Blocked**: Please remove the \`hold\` label before merging." + exit 1 + fi + + # ----------------------------------------------------------------------- + # STEP 4: APPLY LGTM - Add label, wait, then enable auto-merge + # ----------------------------------------------------------------------- + # 1. 
Add lgtm label (triggers gatekeeper to validate) + # 2. Enable auto-merge (PR will merge when all checks pass) + - name: Apply or Cancel LGTM + env: + GH_TOKEN: ${{ steps.generate-token.outputs.token }} + PR_NUMBER: ${{ github.event.client_payload.github.payload.issue.number }} + PR_URL: ${{ github.event.client_payload.github.payload.issue.html_url }} + # Extract the full comment body from the dispatcher payload + COMMENT_BODY: ${{ github.event.client_payload.github.payload.comment.body }} + run: | + # Check if the command is a cancellation + if echo "$COMMENT_BODY" | grep -q "/lgtm cancel"; then + echo "🚨 Retracting LGTM status..." + + # 1. Remove lgtm label + gh issue edit "$PR_NUMBER" --remove-label "lgtm" || echo "Label already gone" + + # 2. Disable Auto-Merge + gh pr merge --disable-auto "$PR_URL" || echo "Auto-merge was not enabled" + + # 3. Notify user + gh issue comment "$PR_URL" --body "Retracted: **LGTM** label removed and auto-merge disabled by @$ACTOR." + + else + echo "βœ… Applying LGTM status..." + + # 1. Add lgtm label + gh issue edit "$PR_NUMBER" --add-label "lgtm" + + # 2. Enable auto-merge (Squash) + if ! gh pr merge --auto --squash "$PR_URL" 2>&1 | tee merge_output.txt; then + ERROR_MSG=$(cat merge_output.txt) + gh issue comment "$PR_URL" --body "⚠️ **Auto-merge failed**: $ERROR_MSG" + exit 1 + fi + + gh issue comment "$PR_URL" --body "βœ… **LGTM**: Approval recorded and auto-merge enabled." + fi \ No newline at end of file diff --git a/.github/workflows/lgtm-gatekeeper.yml b/.github/workflows/lgtm-gatekeeper.yml new file mode 100644 index 000000000..ed687c627 --- /dev/null +++ b/.github/workflows/lgtm-gatekeeper.yml @@ -0,0 +1,45 @@ +# ============================================================================ +# LGTM Gatekeeper - Required Status Check +# ============================================================================ +# Rules Enforced: +# 1. PR MUST have "lgtm" label +# 2. 
PR MUST NOT have blocking labels (hold) +# ============================================================================ + +name: LGTM Gatekeeper +on: + pull_request: + # Run on PR open/reopen and label changes + # NOT on synchronize (handled by lgtm-reset.yml) + types: [opened, labeled, unlabeled, reopened] + +env: + BLOCKING_LABELS: "hold" + +jobs: + validate-pr: + runs-on: ubuntu-latest + steps: + - name: Enforce LGTM & Blockers + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + PR_NUMBER: ${{ github.event.pull_request.number }} + REPO: ${{ github.repository }} + run: | + # Fetch current labels + LABELS=$(gh pr view $PR_NUMBER --repo "$REPO" --json labels --jq '.labels[].name') + + # Check 1: IS IT BLOCKED? + if echo "$LABELS" | grep -Eiq "^($BLOCKING_LABELS)$"; then + echo "::error:: β›” FAILED: PR is blocked by a 'hold' label." + exit 1 + fi + + # Check 2: IS IT APPROVED? + # If Reset workflow removed the label, this check fails immediately. + if ! echo "$LABELS" | grep -Fqx "lgtm"; then + echo "::error:: β›” FAILED: PR is missing the 'lgtm' label." + exit 1 + fi + + echo "βœ… PASSED: LGTM present and no blockers." diff --git a/.github/workflows/lgtm-reset.yml b/.github/workflows/lgtm-reset.yml new file mode 100644 index 000000000..6fc69dbbc --- /dev/null +++ b/.github/workflows/lgtm-reset.yml @@ -0,0 +1,46 @@ +# ============================================================================ +# LGTM Reset - Auto-Remove LGTM on New Commits +# ============================================================================ +# Kubernetes Prow behavior: When new commits are pushed, approval is invalidated +# +# What It Does: +# 1. Detects when new commits are pushed to a PR +# 2. Removes the "lgtm" label (if present) +# 3. Disables auto-merge (safety net) +# 4. 
Posts a comment explaining why +# ============================================================================ + +name: LGTM Reset +on: + pull_request: + types: [synchronize] # Triggers instantly on new commits + +jobs: + reset-lgtm: + runs-on: ubuntu-latest + permissions: + pull-requests: write + steps: + - uses: tibdex/github-app-token@v1 + id: generate-token + with: + app_id: ${{ secrets.VLLMD_BOT_APP_ID }} + private_key: ${{ secrets.VLLMD_BOT_APP_PRIVATE_KEY }} + repository: ${{ github.repository }} + - name: Invalidate LGTM + env: + GH_TOKEN: ${{ steps.generate-token.outputs.token }} + PR_NUMBER: ${{ github.event.pull_request.number }} + REPO: ${{ github.repository }} + run: | + echo "🚨 New code pushed. Resetting LGTM status..." + + # 1. Remove the label (This triggers the Gatekeeper to run again) + gh pr edit $PR_NUMBER --repo "$REPO" --remove-label "lgtm" || true + + # 2. Disable Auto-Merge (Safety net) + gh pr merge --disable-auto $PR_NUMBER --repo "$REPO" || true + + # 3. Notify user + gh issue comment $PR_NUMBER --repo "$REPO" --body "πŸ”„ **Reset**: New commits pushed. LGTM removed." + \ No newline at end of file diff --git a/.github/workflows/prow-github.yml b/.github/workflows/prow-github.yml index 0c5f11fd8..42a37cc70 100644 --- a/.github/workflows/prow-github.yml +++ b/.github/workflows/prow-github.yml @@ -18,7 +18,7 @@ jobs: steps: - uses: jpmcb/prow-github-actions@v2.0.0 with: - github-token: "${{ secrets.GITHUB_TOKEN }}" + github-token: "${{ secrets.BOT_TOKEN }}" prow-commands: "/assign /unassign /approve @@ -27,7 +27,6 @@ jobs: /kind /priority /remove - /lgtm /close /reopen /lock diff --git a/.github/workflows/prow-pr-automerge.yml b/.github/workflows/prow-pr-automerge.yml deleted file mode 100644 index c9bb0972b..000000000 --- a/.github/workflows/prow-pr-automerge.yml +++ /dev/null @@ -1,18 +0,0 @@ -# This Github workflow will check every 5m for PRs with the lgtm label and will attempt to automatically merge them. 
-# If the hold label is present, it will block automatic merging. - -name: "Prow merge on lgtm label" -on: - schedule: - - cron: "*/5 * * * *" # every 5 minutes - -jobs: - auto-merge: - runs-on: ubuntu-latest - steps: - - uses: jpmcb/prow-github-actions@v2.0.0 - with: - jobs: 'lgtm' - github-token: "${{ secrets.GITHUB_TOKEN }}" - merge-method: 'squash' - diff --git a/.github/workflows/prow-pr-remove-lgtm.yml b/.github/workflows/prow-pr-remove-lgtm.yml deleted file mode 100644 index caf208f36..000000000 --- a/.github/workflows/prow-pr-remove-lgtm.yml +++ /dev/null @@ -1,11 +0,0 @@ -name: Run Jobs on PR -on: pull_request - -jobs: - execute: - runs-on: ubuntu-latest - steps: - - uses: jpmcb/prow-github-actions@v2.0.0 - with: - jobs: lgtm - github-token: '${{ secrets.GITHUB_TOKEN }}' diff --git a/.gitignore b/.gitignore index f94c6c6ce..7f5a0e944 100644 --- a/.gitignore +++ b/.gitignore @@ -11,6 +11,8 @@ main bin/ +*debug_bin* + # Test binary, built with `go test -c` *.test @@ -27,6 +29,7 @@ go.work.sum # Environment Files .DS_Store .env +CLAUDE.md # IDE files .idea diff --git a/.golangci.yml b/.golangci.yml index bc928c25b..47e51d394 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -6,28 +6,60 @@ run: formatters: enable: - - goimports - - gofmt + - goimports + - gofmt linters: enable: - - copyloopvar - - dupword - - durationcheck - - fatcontext - - ginkgolinter - - gocritic - - govet - - loggercheck - - misspell - - perfsprint - - revive - - unconvert - - makezero - - errcheck - - goconst - - ineffassign - - nakedret - - prealloc - - unparam - - unused + - bodyclose + - copyloopvar + - dupword + - durationcheck + - errcheck + - fatcontext + - ginkgolinter + - goconst + - gocritic + - govet + - ineffassign + - loggercheck + - makezero + - misspell + - nakedret + - nilnil + - perfsprint + - prealloc + - revive + - staticcheck + - unparam + - unused + - unconvert + settings: + revive: # see https://github.com/mgechev/revive#available-rules for all options + rules: + - 
name: blank-imports + - name: context-as-argument + - name: context-keys-type + - name: dot-imports + - name: error-return + - name: error-strings + - name: error-naming + - name: exported + - name: if-return + - name: increment-decrement + - name: var-naming + - name: var-declaration + - name: package-comments + - name: range + - name: receiver-naming + - name: time-naming + - name: unexported-return + - name: indent-error-flow + - name: errorf + +issues: + # do not limit the number of issues, ensure even identical are reported + max-issues-per-linter: 0 + max-same-issues: 0 + # Note: 'new' setting is controlled via Makefile LINT_NEW_ONLY variable + # Set to true to only check new code, false (default) to check all code diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md index caef5f9d7..bf0921699 100644 --- a/DEVELOPMENT.md +++ b/DEVELOPMENT.md @@ -356,3 +356,20 @@ helm uninstall kgateway-crds -n kgateway-system ``` For more details, see the Gateway API inference Extension [getting started guide](https://gateway-api-inference-extension.sigs.k8s.io/guides/) + +## PR Approval Process + +The project uses a Prow-inspired ChatOps system to manage PR approvals via comment commands. + +### Available Commands + +| Command | Policy | Description | +|---------|--------|-------------| +| `/lgtm` | OWNERS approvers | Adds the `lgtm` label and enables auto-merge (squash). The PR merges automatically once all checks pass. | +| `/lgtm cancel` | OWNERS approvers | Removes the `lgtm` label and disables auto-merge. | +| `/hold` | Anyone with write access | Adds the `hold` label to prevent the PR from merging. | +| `/hold cancel` | Anyone with write access | Removes the `hold` label. | + +### Approval Reset on New Commits + +When new commits are pushed to an approved PR, the `lgtm` label is automatically removed and auto-merge is disabled. This ensures approvals always reflect the latest code. The author must request a new `/lgtm` after pushing changes. 
diff --git a/Dockerfile.epp b/Dockerfile.epp index 915a34a0a..4996c6fbe 100644 --- a/Dockerfile.epp +++ b/Dockerfile.epp @@ -1,13 +1,52 @@ ## Minimal runtime Dockerfile (microdnf-only, no torch, wrapper in site-packages) -# Build Stage: using Go 1.24 image -FROM quay.io/projectquay/golang:1.24 AS builder +# Go dependencies stage: download go modules and extract kv-cache +FROM quay.io/projectquay/golang:1.24 AS go-deps + +WORKDIR /workspace + +# Copy the Go Modules manifests +COPY go.mod go.mod +COPY go.sum go.sum + +# Copy the go source +COPY cmd/ cmd/ +COPY pkg/ pkg/ + +RUN go mod download + +# Copy Python wrapper and requirements from llm-d-kv-cache dependency +# Extract version dynamically and copy to a known location +RUN KV_CACHE_PKG=$(go list -m -f '{{.Dir}}' github.com/llm-d/llm-d-kv-cache) && \ + mkdir -p /workspace/kv-cache && \ + cp -r $KV_CACHE_PKG/* /workspace/kv-cache + +FROM python:3.12-slim AS python-builder + +ARG TARGETARCH + +COPY --from=go-deps /workspace/kv-cache /workspace/kv-cache +WORKDIR /workspace/kv-cache + +# Create venv and install vLLM based on architecture using pre-built wheels +RUN python3.12 -m venv /workspace/kv-cache/build/venv && \ + . /workspace/kv-cache/build/venv/bin/activate && \ + pip install --upgrade pip && \ + VLLM_VERSION="0.14.0" && \ + if [ "$TARGETARCH" = "arm64" ]; then \ + pip install https://github.com/vllm-project/vllm/releases/download/v${VLLM_VERSION}/vllm-${VLLM_VERSION}+cpu-cp38-abi3-manylinux_2_35_aarch64.whl; \ + elif [ "$TARGETARCH" = "amd64" ]; then \ + pip install https://github.com/vllm-project/vllm/releases/download/v${VLLM_VERSION}/vllm-${VLLM_VERSION}+cpu-cp38-abi3-manylinux_2_35_x86_64.whl --extra-index-url https://download.pytorch.org/whl/cpu; \ + else \ + echo "ERROR: Unsupported architecture: $TARGETARCH. Only arm64 and amd64 are supported." 
&& exit 1; \ + fi + +# Go build stage +FROM quay.io/projectquay/golang:1.24 AS go-builder ARG TARGETOS ARG TARGETARCH ARG PYTHON_VERSION=3.12 - ENV PYTHON=python${PYTHON_VERSION} -ENV PYTHONPATH=/usr/lib64/${PYTHON}/site-packages:/usr/lib/${PYTHON}/site-packages # Install build tools # The builder is based on UBI8, so we need epel-release-8. @@ -16,52 +55,22 @@ RUN dnf install -y 'https://dl.fedoraproject.org/pub/epel/epel-release-latest-8. dnf install -y gcc-c++ libstdc++ libstdc++-devel clang zeromq-devel pkgconfig ${PYTHON}-devel ${PYTHON}-pip git && \ dnf clean all +COPY --from=go-deps /workspace /workspace +COPY --from=go-deps /go/pkg/mod /go/pkg/mod WORKDIR /workspace -# Copy the Go Modules manifests -COPY go.mod go.mod -COPY go.sum go.sum +COPY Makefile* ./ -# Copy the go source -COPY cmd/ cmd/ -COPY pkg/ pkg/ +COPY --from=python-builder /workspace/kv-cache/pkg/preprocessing/chat_completions /workspace/kv-cache/pkg/preprocessing/chat_completions +RUN make setup-venv +COPY --from=python-builder /workspace/kv-cache/build/venv/lib/python3.12/site-packages /workspace/build/venv/lib/python3.12/site-packages -RUN go mod download +ENV PYTHONPATH=/workspace/kv-cache/pkg/preprocessing/chat_completions:/workspace/build/venv/lib/python3.12/site-packages +RUN python3.12 -c "import tokenizer_wrapper" # verify tokenizer_wrapper is correctly installed -# Copy Python wrapper and requirements from llm-d-kv-cache-manager dependency -# Extract version dynamically and copy to a known location -# We need to keep llm-d-kv-cache-manager as go module path is kept the old name -RUN KVCACHE_MANAGER_VERSION=$(go list -m -f '{{.Version}}' github.com/llm-d/llm-d-kv-cache-manager) && \ - mkdir -p /workspace/kv-cache && \ - cp /go/pkg/mod/github.com/llm-d/llm-d-kv-cache-manager@${KVCACHE_MANAGER_VERSION}/pkg/preprocessing/chat_completions/render_jinja_template_wrapper.py \ - /workspace/kv-cache/render_jinja_template_wrapper.py && \ - cp 
/go/pkg/mod/github.com/llm-d/llm-d-kv-cache-manager@${KVCACHE_MANAGER_VERSION}/pkg/preprocessing/chat_completions/requirements.txt \ - /workspace/kv-cache/requirements.txt - -# HuggingFace tokenizer bindings (static lib) -RUN mkdir -p lib -# Ensure that the RELEASE_VERSION matches the one used in the imported llm-d-kv-cache-manager version ARG RELEASE_VERSION=v1.22.1 -RUN curl -L https://github.com/daulet/tokenizers/releases/download/${RELEASE_VERSION}/libtokenizers.${TARGETOS}-${TARGETARCH}.tar.gz | tar -xz -C lib -RUN ranlib lib/*.a - -# Build -# the GOARCH has not a default value to allow the binary be built according to the host where the command -# was called. For example, if we call make image-build in a local env which has the Apple Silicon M1 SO -# the docker BUILDPLATFORM arg will be linux/arm64 when for Apple x86 it will be linux/amd64. Therefore, -# by leaving it empty we can ensure that the container and binary shipped on it will have the same platform. -ENV CGO_ENABLED=1 -ENV GOOS=${TARGETOS:-linux} -ENV GOARCH=${TARGETARCH} - - -ARG COMMIT_SHA=unknown -ARG BUILD_REF -RUN CGO_CFLAGS="$(${PYTHON}-config --cflags) -I/workspace/lib" && \ - CGO_LDFLAGS="$(${PYTHON}-config --ldflags --embed) -L/workspace/lib -ltokenizers -ldl -lm" && \ - export CGO_CFLAGS CGO_LDFLAGS && \ - go build -a -o bin/epp -ldflags="-extldflags '-L$(pwd)/lib' -X sigs.k8s.io/gateway-api-inference-extension/version.CommitSHA=${COMMIT_SHA} -X sigs.k8s.io/gateway-api-inference-extension/version.BuildRef=${BUILD_REF}" cmd/epp/main.go +RUN TOKENIZER_VERSION=${RELEASE_VERSION} make build-epp # Runtime stage # Use ubi9 as a minimal base image to package the manager binary @@ -69,7 +78,7 @@ RUN CGO_CFLAGS="$(${PYTHON}-config --cflags) -I/workspace/lib" && \ FROM registry.access.redhat.com/ubi9/ubi-minimal:9.7 ARG PYTHON_VERSION=3.12 WORKDIR / -COPY --from=builder /workspace/bin/epp /app/epp +COPY --from=go-builder /workspace/bin/epp /app/epp USER root @@ -87,25 +96,16 @@ RUN curl -L -o 
/tmp/epel-release.rpm https://dl.fedoraproject.org/pub/epel/epel- ln -sf /usr/bin/${PYTHON} /usr/bin/python3 && \ ln -sf /usr/bin/${PYTHON} /usr/bin/python +# Copy Python kv-cache package and site-packages from the python-builder stage +COPY --from=python-builder /workspace/kv-cache /workspace/kv-cache +ENV PYTHONPATH=/workspace/kv-cache/pkg/preprocessing/chat_completions:/workspace/kv-cache/build/venv/lib/python3.12/site-packages +RUN ${PYTHON} -c "import tokenizer_wrapper" # verify tokenizer_wrapper is correctly installed -# Install wrapper as a module in site-packages -RUN mkdir -p /usr/local/lib/${PYTHON}/site-packages/ -COPY --from=builder /workspace/kv-cache/render_jinja_template_wrapper.py /usr/local/lib/${PYTHON}/site-packages/ - -# Python deps (no cache, single target) – filter out torch -ENV PIP_NO_CACHE_DIR=1 PIP_DISABLE_PIP_VERSION_CHECK=1 -COPY --from=builder /workspace/kv-cache/requirements.txt /tmp/requirements.txt -RUN sed '/^torch\b/d' /tmp/requirements.txt > /tmp/requirements.notorch.txt && \ - ${PYTHON} -m pip install --no-cache-dir --upgrade pip setuptools wheel && \ - ${PYTHON} -m pip install --no-cache-dir --target /usr/local/lib/${PYTHON}/site-packages -r /tmp/requirements.notorch.txt && \ - ${PYTHON} -m pip install --no-cache-dir --target /usr/local/lib/${PYTHON}/site-packages PyYAML && \ - rm /tmp/requirements.txt /tmp/requirements.notorch.txt && \ - rm -rf /root/.cache/pip - -# Python env -ENV PYTHONPATH="/usr/local/lib/${PYTHON}/site-packages:/usr/lib/${PYTHON}/site-packages" -ENV PATH=/usr/bin:/usr/local/bin:$PATH ENV HF_HOME="/tmp/.cache" +# used by kv-cache-manager +ENV LOCAL_TOKENIZER_DIR="/tmp/.cache" +# Create cache directory and set permissions for non-root user +RUN mkdir -p /tmp/.cache && chown -R 65532:65532 ${HF_HOME} USER 65532:65532 @@ -117,4 +117,3 @@ EXPOSE 9090 EXPOSE 5557 ENTRYPOINT ["/app/epp"] - diff --git a/Makefile b/Makefile index edbd8ec56..89c60c651 100644 --- a/Makefile +++ b/Makefile @@ -25,10 +25,11 @@ export 
EPP_IMAGE ?= $(IMAGE_TAG_BASE):$(EPP_TAG) SIDECAR_TAG ?= dev SIDECAR_IMAGE_TAG_BASE ?= $(IMAGE_REGISTRY)/$(SIDECAR_IMAGE_NAME) export SIDECAR_IMAGE ?= $(SIDECAR_IMAGE_TAG_BASE):$(SIDECAR_TAG) -VLLM_SIMULATOR_TAG ?= v0.6.1 +VLLM_SIMULATOR_TAG ?= latest VLLM_SIMULATOR_TAG_BASE ?= $(IMAGE_REGISTRY)/$(VLLM_SIMULATOR_IMAGE_NAME) export VLLM_SIMULATOR_IMAGE ?= $(VLLM_SIMULATOR_TAG_BASE):$(VLLM_SIMULATOR_TAG) NAMESPACE ?= hc4ai-operator +LINT_NEW_ONLY ?= false # Set to true to only lint new code, false to lint all code (default matches CI behavior) # Map go arch to platform-specific arch ifeq ($(TARGETOS),darwin) @@ -139,9 +140,15 @@ format: check-golangci-lint ## Format Go source files $(GOLANGCI_LINT) fmt .PHONY: lint -lint: check-golangci-lint check-typos ## Run lint +lint: check-golangci-lint check-typos ## Run lint (use LINT_NEW_ONLY=true to only check new code) @printf "\033[33;1m==== Running linting ====\033[0m\n" - CGO_CFLAGS="${CGO_CFLAGS}" $(GOLANGCI_LINT) run + @if [ "$(LINT_NEW_ONLY)" = "true" ]; then \ + printf "\033[33mChecking new code only (LINT_NEW_ONLY=true)\033[0m\n"; \ + CGO_CFLAGS="${CGO_CFLAGS}" $(GOLANGCI_LINT) run --new; \ + else \ + printf "\033[33mChecking all code (LINT_NEW_ONLY=false, default)\033[0m\n"; \ + CGO_CFLAGS="${CGO_CFLAGS}" $(GOLANGCI_LINT) run; \ + fi $(TYPOS) .PHONY: install-hooks @@ -157,10 +164,29 @@ test-unit: test-unit-epp test-unit-sidecar ## Run unit tests .PHONY: test-unit-% test-unit-%: download-tokenizer install-python-deps check-dependencies ## Run unit tests @printf "\033[33;1m==== Running Unit Tests ====\033[0m\n" - @KV_CACHE_PKG=$$(go list -m -f '{{.Dir}}/pkg/preprocessing/chat_completions' github.com/llm-d/llm-d-kv-cache-manager 2>/dev/null || echo ""); \ + @KV_CACHE_PKG=$$(go list -m -f '{{.Dir}}/pkg/preprocessing/chat_completions' github.com/llm-d/llm-d-kv-cache 2>/dev/null || echo ""); \ PYTHONPATH="$$KV_CACHE_PKG:$(VENV_DIR)/lib/python$(PYTHON_VERSION)/site-packages" \ CGO_CFLAGS=${$*_CGO_CFLAGS} 
CGO_LDFLAGS=${$*_CGO_LDFLAGS} go test $($*_LDFLAGS) -v $$($($*_TEST_FILES) | tr '\n' ' ') +.PHONY: test-filter +test-filter: download-tokenizer install-python-deps check-dependencies ## Run filtered unit tests (usage: make test-filter PATTERN=TestName TYPE=epp) + @if [ -z "$(PATTERN)" ]; then \ + echo "ERROR: PATTERN is required. Usage: make test-filter PATTERN=TestName [TYPE=epp|sidecar]"; \ + exit 1; \ + fi + @TEST_TYPE="$(if $(TYPE),$(TYPE),epp)"; \ + printf "\033[33;1m==== Running Filtered Tests (pattern: $(PATTERN), type: $$TEST_TYPE) ====\033[0m\n"; \ + KV_CACHE_PKG=$$(go list -m -f '{{.Dir}}/pkg/preprocessing/chat_completions' github.com/llm-d/llm-d-kv-cache 2>/dev/null || echo ""); \ + if [ "$$TEST_TYPE" = "epp" ]; then \ + PYTHONPATH="$$KV_CACHE_PKG:$(VENV_DIR)/lib/python$(PYTHON_VERSION)/site-packages" \ + CGO_CFLAGS=$(epp_CGO_CFLAGS) CGO_LDFLAGS=$(epp_CGO_LDFLAGS) \ + go test $(epp_LDFLAGS) -v -run "$(PATTERN)" $$($(epp_TEST_FILES) | tr '\n' ' '); \ + else \ + PYTHONPATH="$$KV_CACHE_PKG:$(VENV_DIR)/lib/python$(PYTHON_VERSION)/site-packages" \ + CGO_CFLAGS=$(sidecar_CGO_CFLAGS) CGO_LDFLAGS=$(sidecar_CGO_LDFLAGS) \ + go test $(sidecar_LDFLAGS) -v -run "$(PATTERN)" $$($(sidecar_TEST_FILES) | tr '\n' ' '); \ + fi + .PHONY: test-integration test-integration: download-tokenizer check-dependencies ## Run integration tests @printf "\033[33;1m==== Running Integration Tests ====\033[0m\n" diff --git a/Makefile.tools.mk b/Makefile.tools.mk index e750cd41b..7bf41b369 100644 --- a/Makefile.tools.mk +++ b/Makefile.tools.mk @@ -18,11 +18,12 @@ GINKGO_VERSION ?= v2.27.2 GOLANGCI_LINT_VERSION ?= v2.1.6 KUSTOMIZE_VERSION ?= v5.5.0 TYPOS_VERSION ?= v1.34.0 +VLLM_VERSION ?= 0.14.0 ## Python Configuration PYTHON_VERSION ?= 3.12 # Extract RELEASE_VERSION from Dockerfile -TOKENIZER_VERSION := $(shell grep '^ARG RELEASE_VERSION=' Dockerfile.epp | cut -d'=' -f2) +TOKENIZER_VERSION ?= $(shell grep '^ARG RELEASE_VERSION=' Dockerfile.epp | cut -d'=' -f2) # Python executable for 
creating venv PYTHON_EXE := $(shell command -v python$(PYTHON_VERSION) || command -v python3) @@ -151,33 +152,79 @@ $(TOKENIZER_LIB): | $(LOCALLIB) @ranlib $(LOCALLIB)/*.a @echo "Tokenizer bindings downloaded successfully." - -.PHONY: install-python-deps -install-python-deps: ## Sets up Python virtual environment and installs dependencies - @printf "\033[33;1m==== Setting up Python virtual environment in $(VENV_DIR) ====\033[0m\n" +.PHONY: detect-python +detect-python: ## Detects Python and prints the configuration. + @printf "\033[33;1m==== Python Configuration ====\033[0m\n" @if [ -z "$(PYTHON_EXE)" ]; then \ echo "ERROR: Python 3 not found in PATH."; \ exit 1; \ fi + @# Verify the version of the found python executable using its exit code + @if ! $(PYTHON_EXE) -c "import sys; sys.exit(0 if sys.version_info[:2] == ($(shell echo $(PYTHON_VERSION) | cut -d. -f1), $(shell echo $(PYTHON_VERSION) | cut -d. -f2)) else 1)"; then \ + echo "ERROR: Found Python at '$(PYTHON_EXE)' but it is not version $(PYTHON_VERSION)."; \ + echo "Please ensure 'python$(PYTHON_VERSION)' or a compatible 'python3' is in your PATH."; \ + exit 1; \ + fi + @echo "Python executable: $(PYTHON_EXE) ($$($(PYTHON_EXE) --version))" + @echo "Python CFLAGS: $(PYTHON_CFLAGS)" + @echo "Python LDFLAGS: $(PYTHON_LDFLAGS)" + @if [ -z "$(PYTHON_CFLAGS)" ]; then \ + echo "ERROR: Python development headers not found. See installation instructions above."; \ + exit 1; \ + fi + @printf "\033[33;1m==============================\033[0m\n" + +.PHONY: setup-venv +setup-venv: detect-python ## Sets up the Python virtual environment. + @printf "\033[33;1m==== Setting up Python virtual environment in $(VENV_DIR) ====\033[0m\n" @if [ ! 
-f "$(VENV_BIN)/pip" ]; then \ echo "Creating virtual environment..."; \ $(PYTHON_EXE) -m venv $(VENV_DIR) || { \ echo "ERROR: Failed to create virtual environment."; \ echo "Your Python installation may be missing the 'venv' module."; \ + echo "Try: 'sudo apt install python$(PYTHON_VERSION)-venv' or 'sudo dnf install python$(PYTHON_VERSION)-devel'"; \ exit 1; \ }; \ fi - @echo "Upgrading pip and installing dependencies..." - @$(VENV_BIN)/pip install --upgrade pip --quiet - @KV_CACHE_PKG=$$(go list -m -f '{{.Dir}}' github.com/llm-d/llm-d-kv-cache-manager 2>/dev/null); \ - if [ -n "$$KV_CACHE_PKG" ] && [ -f "$$KV_CACHE_PKG/pkg/preprocessing/chat_completions/requirements.txt" ]; then \ - echo "Installing Python dependencies from kv-cache-manager..."; \ - $(VENV_BIN)/pip install --quiet -r "$$KV_CACHE_PKG/pkg/preprocessing/chat_completions/requirements.txt"; \ + @echo "Upgrading pip..." + @$(VENV_BIN)/pip install --upgrade pip + @echo "Python virtual environment setup complete." + +.PHONY: install-python-deps +install-python-deps: setup-venv ## installs dependencies. + @printf "\033[33;1m==== Setting up Python virtual environment in $(VENV_DIR) ====\033[0m\n" + @echo "install vllm..." 
+ @KV_CACHE_PKG=$${KV_CACHE_PKG:-$$(go list -m -f '{{.Dir}}' github.com/llm-d/llm-d-kv-cache 2>/dev/null)}; \ + if [ -z "$$KV_CACHE_PKG" ]; then \ + echo "ERROR: kv-cache package not found."; \ + exit 1; \ + fi; \ + if [ "$(TARGETOS)" = "darwin" ]; then \ + if [ -f "$$KV_CACHE_PKG/pkg/preprocessing/chat_completions/setup.sh" ]; then \ + echo "Running kv-cache setup script for macOS..."; \ + cp "$$KV_CACHE_PKG/pkg/preprocessing/chat_completions/setup.sh" build/kv-cache-setup.sh; \ + chmod +wx build/kv-cache-setup.sh; \ + cd build && PATH=$(VENV_BIN):$$PATH ./kv-cache-setup.sh && cd ..; \ + else \ + echo "ERROR: setup script not found at $$KV_CACHE_PKG/pkg/preprocessing/chat_completions/setup.sh"; \ + exit 1; \ + fi; \ else \ - echo "WARNING: Could not find kv-cache-manager requirements.txt, installing minimal deps..."; \ - $(VENV_BIN)/pip install --quiet 'transformers>=4.53.0' 'jinja2>=2.11'; \ + echo "Installing vLLM for Linux $(TARGETARCH)..."; \ + if [ "$(TARGETARCH)" = "arm64" ]; then \ + $(VENV_BIN)/pip install https://github.com/vllm-project/vllm/releases/download/v$(VLLM_VERSION)/vllm-$(VLLM_VERSION)+cpu-cp38-abi3-manylinux_2_35_aarch64.whl; \ + elif [ "$(TARGETARCH)" = "amd64" ]; then \ + $(VENV_BIN)/pip install https://github.com/vllm-project/vllm/releases/download/v$(VLLM_VERSION)/vllm-$(VLLM_VERSION)+cpu-cp38-abi3-manylinux_2_35_x86_64.whl --extra-index-url https://download.pytorch.org/whl/cpu; \ + else \ + echo "ERROR: Unsupported architecture: $(TARGETARCH). Only arm64 and amd64 are supported."; \ + exit 1; \ + fi; \ fi - @echo "βœ… Python dependencies installed in venv" + @echo "Verifying vllm installation..." 
+ @$(VENV_BIN)/python -c "import vllm; print('βœ… vllm version ' + vllm.__version__ + ' installed.')" || { \ + echo "ERROR: vllm library not properly installed in venv."; \ + exit 1; \ + } .PHONY: check-tools check-tools: check-go check-ginkgo check-golangci-lint check-kustomize check-envsubst check-container-tool check-kubectl check-buildah check-typos ## Check that all required tools are installed diff --git a/OWNERS b/OWNERS index 84ec0bd47..4262fe677 100644 --- a/OWNERS +++ b/OWNERS @@ -1,4 +1,5 @@ approvers: +- revit13 - elevran - kfswain - nilig @@ -23,4 +24,4 @@ auto-assign: - nilig - nirrozenbaum - shmuelk - - vMaroon \ No newline at end of file + - vMaroon diff --git a/README.md b/README.md index 3be2337aa..c1e291a38 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ # Inference Scheduler -This scheduler makes optimized routing decisions for inference requests to +This scheduler makes optimized routing decisions for inference requests to the llm-d inference framework. ## About diff --git a/cmd/pd-sidecar/main.go b/cmd/pd-sidecar/main.go index cd086eb7d..8e9e6533c 100644 --- a/cmd/pd-sidecar/main.go +++ b/cmd/pd-sidecar/main.go @@ -35,7 +35,7 @@ var ( // supportedConnectors defines all valid P/D connector types supportedConnectors = []string{ proxy.ConnectorNIXLV2, - proxy.ConnectorLMCache, + proxy.ConnectorSharedStorage, proxy.ConnectorSGLang, } ) diff --git a/deploy/components/crds-gie/kustomization.yaml b/deploy/components/crds-gie/kustomization.yaml index f2e0a4629..e9897ce37 100644 --- a/deploy/components/crds-gie/kustomization.yaml +++ b/deploy/components/crds-gie/kustomization.yaml @@ -10,4 +10,4 @@ apiVersion: kustomize.config.k8s.io/v1beta1 kind: Kustomization resources: -- https://github.com/kubernetes-sigs/gateway-api-inference-extension/config/crd?ref=v1.2.1 +- https://github.com/kubernetes-sigs/gateway-api-inference-extension/config/crd?ref=v1.3.0 diff --git a/deploy/config/dp-epp-config.yaml b/deploy/config/dp-epp-config.yaml index 
703a44f67..6e8418866 100644 --- a/deploy/config/dp-epp-config.yaml +++ b/deploy/config/dp-epp-config.yaml @@ -3,7 +3,13 @@ apiVersion: inference.networking.x-k8s.io/v1alpha1 kind: EndpointPickerConfig plugins: -- type: prefix-cache-scorer +- type: precise-prefix-cache-scorer + parameters: + indexerConfig: + tokenProcessorConfig: + blockSize: 5 + kvBlockIndexConfig: + maxPrefixBlocksToMatch: 256 - type: decode-filter - type: max-score-picker - type: data-parallel-profile-handler @@ -14,5 +20,5 @@ schedulingProfiles: plugins: - pluginRef: decode-filter - pluginRef: max-score-picker - - pluginRef: prefix-cache-scorer + - pluginRef: precise-prefix-cache-scorer weight: 2 diff --git a/deploy/config/epp-precise-prefix-cache-config.yaml b/deploy/config/epp-precise-prefix-cache-config.yaml index 157505278..39b0bb285 100644 --- a/deploy/config/epp-precise-prefix-cache-config.yaml +++ b/deploy/config/epp-precise-prefix-cache-config.yaml @@ -7,10 +7,10 @@ plugins: - type: decode-filter - type: precise-prefix-cache-scorer parameters: + tokenProcessorConfig: + blockSize: 64 # must match vLLM block size + hashSeed: "42" # must match vLLM PYTHONHASHSEED env var indexerConfig: - tokenProcessorConfig: - blockSize: 64 # must match vLLM block size - hashSeed: "42" # must match vLLM PYTHONHASHSEED env var kvBlockIndexConfig: enableMetrics: true # enable kv-block index metrics (prometheus) - type: kv-cache-utilization-scorer diff --git a/deploy/config/pd-epp-config.yaml b/deploy/config/pd-epp-config.yaml index 9732be2f3..35a94a3be 100644 --- a/deploy/config/pd-epp-config.yaml +++ b/deploy/config/pd-epp-config.yaml @@ -1,13 +1,25 @@ # Sample EPP configuration for tunning with P/D apiVersion: inference.networking.x-k8s.io/v1alpha1 kind: EndpointPickerConfig +featureGates: +- prepareDataPlugins plugins: - type: prefill-header-handler - type: prefix-cache-scorer + parameters: + maxPrefixBlocksToMatch: 256 + lruCapacityPerServer: 31250 +- type: queue-scorer - type: prefill-filter - type: 
decode-filter - type: max-score-picker +- type: prefix-based-pd-decider + parameters: + nonCachedTokens: 16 - type: pd-profile-handler + parameters: + primaryPort: ${PRIMARY_PORT} + deciderPluginName: prefix-based-pd-decider schedulingProfiles: - name: prefill plugins: @@ -15,9 +27,13 @@ schedulingProfiles: - pluginRef: max-score-picker - pluginRef: prefix-cache-scorer weight: 2 + - pluginRef: queue-scorer + weight: 1 - name: decode plugins: - pluginRef: decode-filter - pluginRef: max-score-picker - pluginRef: prefix-cache-scorer weight: 2 + - pluginRef: queue-scorer + weight: 1 diff --git a/deploy/config/sim-epp-kvcache-config.yaml b/deploy/config/sim-epp-kvcache-config.yaml index 566e92437..f582c543f 100644 --- a/deploy/config/sim-epp-kvcache-config.yaml +++ b/deploy/config/sim-epp-kvcache-config.yaml @@ -3,18 +3,16 @@ apiVersion: inference.networking.x-k8s.io/v1alpha1 kind: EndpointPickerConfig plugins: -- type: prefix-cache-scorer +- type: precise-prefix-cache-scorer parameters: - mode: cache_tracking kvEventsConfig: zmqEndpoint: tcp://0.0.0.0:5557 indexerConfig: - prefixStoreConfig: - blockSize: 16 tokenProcessorConfig: blockSize: 16 # must match vLLM block size if not default (16) hashSeed: "42" # must match PYTHONHASHSEED in vLLM pods tokenizersPoolConfig: + modelName: TinyLlama/TinyLlama-1.1B-Chat-v1.0 # replace value to use different model for tokenizer loading hf: tokenizersCacheDir: "/cache/tokenizers" kvBlockIndexConfig: @@ -28,5 +26,5 @@ schedulingProfiles: plugins: - pluginRef: decode-filter - pluginRef: max-score-picker - - pluginRef: prefix-cache-scorer + - pluginRef: precise-prefix-cache-scorer weight: 10 diff --git a/deploy/config/sim-epp-no-hit-lru.yaml b/deploy/config/sim-epp-no-hit-lru.yaml index 8d0224411..e10ec5062 100644 --- a/deploy/config/sim-epp-no-hit-lru.yaml +++ b/deploy/config/sim-epp-no-hit-lru.yaml @@ -3,11 +3,13 @@ apiVersion: inference.networking.x-k8s.io/v1alpha1 kind: EndpointPickerConfig plugins: -- type: prefix-cache-scorer +- 
type: precise-prefix-cache-scorer parameters: - hashBlockSize: 5 - maxPrefixBlocksToMatch: 256 - lruCapacityPerServer: 31250 + indexerConfig: + tokenProcessorConfig: + blockSize: 5 + kvBlockIndexConfig: + maxPrefixBlocksToMatch: 256 - type: no-hit-lru-scorer parameters: lruSize: 2048 @@ -19,7 +21,7 @@ schedulingProfiles: plugins: - pluginRef: decode-filter - pluginRef: max-score-picker - - pluginRef: prefix-cache-scorer + - pluginRef: precise-prefix-cache-scorer weight: 2 - pluginRef: no-hit-lru-scorer weight: 1 diff --git a/deploy/config/sim-pd-epp-config.yaml b/deploy/config/sim-pd-epp-config.yaml index 2d6a85dd9..2f93504a1 100644 --- a/deploy/config/sim-pd-epp-config.yaml +++ b/deploy/config/sim-pd-epp-config.yaml @@ -2,21 +2,27 @@ # Use with small hash block size for simulation purposes apiVersion: inference.networking.x-k8s.io/v1alpha1 kind: EndpointPickerConfig +featureGates: +- prepareDataPlugins plugins: - type: prefill-header-handler - type: prefix-cache-scorer parameters: - hashBlockSize: 5 + blockSizeTokens: 16 + autoTune: false maxPrefixBlocksToMatch: 256 lruCapacityPerServer: 31250 +- type: queue-scorer - type: prefill-filter - type: decode-filter - type: max-score-picker +- type: prefix-based-pd-decider + parameters: + nonCachedTokens: 16 - type: pd-profile-handler parameters: - threshold: 10 - hashBlockSize: 5 primaryPort: ${PRIMARY_PORT} + deciderPluginName: prefix-based-pd-decider schedulingProfiles: - name: prefill plugins: @@ -24,9 +30,13 @@ schedulingProfiles: - pluginRef: max-score-picker - pluginRef: prefix-cache-scorer weight: 2 + - pluginRef: queue-scorer + weight: 1 - name: decode plugins: - pluginRef: decode-filter - pluginRef: max-score-picker - pluginRef: prefix-cache-scorer weight: 2 + - pluginRef: queue-scorer + weight: 1 diff --git a/docs/architecture.md b/docs/architecture.md index b3215815c..c6ab17198 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -161,11 +161,13 @@ A complete configuration might look like this: 
apiVersion: inference.networking.x-k8s.io/v1alpha1 kind: EndpointPickerConfig plugins: -- type: prefix-cache-scorer +- type: precise-prefix-cache-scorer parameters: - hashBlockSize: 5 - maxPrefixBlocksToMatch: 256 - lruCapacityPerServer: 31250 + indexerConfig: + tokenProcessorConfig: + blockSize: 5 + kvBlockIndexConfig: + maxPrefixBlocksToMatch: 256 - type: decode-filter - type: max-score-picker - type: single-profile-handler @@ -174,7 +176,7 @@ schedulingProfiles: plugins: - pluginRef: decode-filter - pluginRef: max-score-picker - - pluginRef: prefix-cache-scorer + - pluginRef: precise-prefix-cache-scorer weight: 50 ``` @@ -204,15 +206,41 @@ Selects the profiles to use when running with disaggregated prefill/decode - **Type**: `pd-profile-handler` - **Parameters**: - - `threshold`: specifies the threshold at which there are enough new input tokens to send the request to prefill and then decode, vs just to decode. - - `hashBlockSize`: specifies the length of the prompt chunk that a block is keyed by. This must the same value used for the PrefixCachePlugin. - `decodeProfile`: specifies the name of the profile used for the decode scheduling. Only needed if the decode profile is not named `decode`. - `prefillProfile`: specifies the name of the profile used for the prefill scheduling. Only needed if the prefill profile is not named `prefill`. + - `deciderPluginName`: specifies the name of the decider plugin. Decider determines whether disaggregated PD should be executed + - `primaryPort`: the base port number used for data parallel communication. **Note:** When using this plugin you must also have a PrefixCachePlugin configured in the prefill and decode scheduling profiles. --- +#### Prefix Based Decider Plugin + +Type: `prefix-based-pd-decider` + +**Parameters** +- `nonCachedTokens`: length, in token, of the uncached part of the user input above which disaggregated PD is triggered. 
+ +Note: `prepareDataPlugins` feature gate should be enabled + +**Example** +```yaml +kind: EndpointPickerConfig +featureGates: +- prepareDataPlugins +plugins: +- type: prefix-based-pd-decider + parameters: + nonCachedTokens: 4 +- type: pd-profile-handler + parameters: + primaryPort: 8000 + deciderPluginName: prefix-based-pd-decider +``` + +--- + #### ByLabelSelector Filters out pods using a standard Kubernetes label selector. @@ -308,12 +336,14 @@ Configuration: - **Type**: `precise-prefix-cache-scorer` - **Parameters**: + - `tokenProcessorConfig`: Configuration for the `kvblock.TokenProcessor`. - `indexerConfig`: Configuration for the `kvcache.Indexer`. - `kvEventsConfig`: Configuration for the `kvevents.Pool`. See list of parameters at [llm-d-kv-cache/docs/configuration.md](https://github.com/llm-d/llm-d-kv-cache/blob/fa85b60207ba0a09daf23071e10ccb62d7977b40/docs/configuration.md). Note that in most cases you will only need to set: +- Model name in the `tokenizersPoolConfig` to match the model used in the vLLM deployment. - HuggingFace token for the `tokenizersPoolConfig` or the `tokenizersCacheDir` to a mounted directory containing the tokenizers. - For the HuggingFace token, the inference-scheduler also accepts the environment variable `HF_TOKEN` - this is the practical option for security. - **IMPORTANT**: Token processor's block-size and hash-seed to match those used in the vLLM deployment. 
@@ -325,15 +355,41 @@ Example configuration with the above parameters set: plugins: - type: precise-prefix-cache-scorer parameters: + tokenProcessorConfig: + blockSize: 64 # must match vLLM block size + hashSeed: "12345" # must match vLLM PYTHONHASHSEED env var indexerConfig: - tokenProcessorConfig: - blockSize: 64 - hashSeed: "12345" - tokenizersPoolConfig: - hf: - huggingFaceToken: your_hf_token_here # automatically set by `HF_TOKEN` environment variable - kvBlockIndexConfig: - enableMetrics: true + kvBlockIndexConfig: + enableMetrics: true + tokenizersPoolConfig: + modelName: hf-repo/model-name + hf: + huggingFaceToken: your_hf_token_here # automatically set by `HF_TOKEN` environment variable +``` + +Example configuration for automatic pod discovery in active-active multi-replica scheduler deployments: +```yaml + - type: precise-prefix-cache-scorer + parameters: + tokenProcessorConfig: + blockSize: 64 + hashSeed: "42" + indexerConfig: + tokenizersPoolConfig: + modelName: "Qwen/Qwen3-32B" + hf: + tokenizersCacheDir: "/tmp/tokenizers" + kvEventsConfig: + topicFilter: "kv@" + concurrency: 4 + discoverPods: true # enables automatic pod discovery for active-active HA + podDiscoveryConfig: + socketPort: 5556 +``` + +Where the vLLM engines are configured to emit KV-Events on port `5556` as follows: +```yaml + --kv-events-config "{\"enable_kv_cache_events\":true,\"publisher\":\"zmq\",\"endpoint\":\"tcp://*:5556\",\"topic\":\"kv@${POD_IP}@Qwen/Qwen3-32B\"}" ``` Example configuration with all parameters set: @@ -342,23 +398,26 @@ Example configuration with all parameters set: plugins: - type: precise-prefix-cache-scorer parameters: + tokenProcessorConfig: + blockSize: 16 + hashSeed: "12345" kvEventsConfig: - zmqEndpoint: tcp://*:5557 - topicFilter: kv@ - concurrency: 8 - kvCacheIndexerConfig: + topicFilter: "kv@" + concurrency: 4 + discoverPods: true # enables automatic pod discovery for active-active HA + podDiscoveryConfig: + socketPort: 5556 + indexerConfig: 
prefixStoreConfig: cacheSize: 500000 blockSize: 256 - tokenProcessorConfig: - blockSize: 16 - hashSeed: "12345" kvBlockIndexConfig: inMemoryConfig: size: 100000000 podCacheSize: 10 enableMetrics: true tokenizersPoolConfig: + modelName: hf-repo/model-name workersCount: 8 hf: huggingFaceToken: your_hf_token_here # automatically set by `HF_TOKEN` environment variable @@ -434,11 +493,13 @@ Example configuration: ```yaml plugins: - - type: prefix-cache-scorer + - type: precise-prefix-cache-scorer parameters: - hashBlockSize: 5 - maxPrefixBlocksToMatch: 256 - lruCapacityPerServer: 31250 + indexerConfig: + tokenProcessorConfig: + blockSize: 5 + kvBlockIndexConfig: + maxPrefixBlocksToMatch: 256 - type: no-hit-lru-scorer parameters: lruSize: 2048 @@ -450,7 +511,7 @@ schedulingProfiles: plugins: - pluginRef: decode-filter - pluginRef: max-score-picker - - pluginRef: prefix-cache-scorer + - pluginRef: precise-prefix-cache-scorer weight: 2 - pluginRef: no-hit-lru-scorer weight: 1 diff --git a/docs/disagg_pd.md b/docs/disagg_pd.md index 383b9c193..4c7cabd23 100644 --- a/docs/disagg_pd.md +++ b/docs/disagg_pd.md @@ -1,8 +1,8 @@ -# Disaggregated Prefill/Decode Inference Serving in llm-d +# Disaggregated Prefill/Decode Inference Serving in LLM-D ## Overview -This document describes the architecture and request lifecycle for enabling **disaggregated prefill and decode (P/D)** inference execution in the llm-d router. The architecture aims to improve flexibility, scalability, and performance by enabling separation of prefill and decode stages onto different workers. +This document describes the architecture and request lifecycle for enabling **disaggregated prefill and decode (P/D)** inference execution in the LLM-D router. The architecture aims to improve flexibility, scalability, and performance by enabling separation of prefill and decode stages onto different workers. 
This evolved version removes the requirement for sidecars on the **prefill node**, simplifying deployment while maintaining orchestration from the **decode node**. @@ -25,7 +25,7 @@ This evolved version removes the requirement for sidecars on the **prefill node* | **Decode Worker** | Handles decode stage and contains the sidecar for coordination | | **Sidecar (Decode)** | Orchestrates communication with prefill worker and manages lifecycle | | **Envoy Proxy** | Accepts OpenAI-style requests and forwards them to EPP | -| **EPP** | End Point Picker, makes scheduling decisions | +| **EPP** | Endpoint Picker, makes scheduling decisions | --- @@ -37,7 +37,7 @@ This evolved version removes the requirement for sidecars on the **prefill node* 2. **EPP Scheduling Decision** - EPP evaluates: - Prompt length - - KV cache hit probability + - KV-cache hit probability - System and pod load - Selects either: - **Single node** path (decode handles all) @@ -47,10 +47,10 @@ This evolved version removes the requirement for sidecars on the **prefill node* 3. **Execution** - Request lands on Decode Worker (as selected by EPP) - Decode sidecar coordinates: - - If `prefill_worker_id == nil`, runs both stages locally by passing request to local vllm - - If split: - - Sends prefill job to Prefill Worker with a special header `do_remote_decode=true` - - Upon receiving response from Prefill Worker runs decode stage + - If `x-prefiller-host-port` header doesn't exist, runs both stages locally by passing request to local vLLM + - If `x-prefiller-host-port` header exists: + - Sends the prefill job to the selected Prefill Worker with a special request field `do_remote_decode=true` + - Upon receiving the response from the Prefill Worker runs the decode stage 4. 
**Response Flow** - Response flows from decode sidecar β†’ Envoy β†’ EPP β†’ User @@ -59,11 +59,34 @@ This evolved version removes the requirement for sidecars on the **prefill node* ## Architectural Details + +```mermaid +sequenceDiagram + participant C as Client + participant I as Inference Gateway + participant DS as Decode Worker Sidecar + participant D as Decode Worker(vLLM) + participant P as Prefill Worker(vLLM) + + + C->>I: Inference Request + I->>DS: Request is sent to the Decode Worker Sidecar
with the selected Prefill worker set in a header. + DS->>P: Remote Prefill with prompt(max_tokens=1) + P-->>P: Run prefill + P->>DS: Remote kv parameters + DS->> D: Request is sent to the Decode Worker (vLLM) with remote_prefill true,
prefill ID and memory block IDs + D-->>P: Read kv-cache + D-->>D: Schedule decode into queue & run decode + D->>DS: Inference Response + DS->>I: Inference Response + I->>C: Inference Response +``` + ### Sidecar Responsibilities (Decode Only) - Receives EPP metadata (decode pod, optional prefill pod) - Sends request to prefill -- Waits and validates result +- Waits for the result and validates it - Launches local decode job - Sends final response @@ -73,33 +96,21 @@ This evolved version removes the requirement for sidecars on the **prefill node* ## Worker Selection Logic -- **Decode Worker**: - - Prefer longest prefix match / KV cache utilization (depends on available scorers) - -- **Prefill Worker**: - - High prefix-cache hit rate - - Low load +- **Decode/Prefill Worker**: + - Prefer longest prefix match/kv-cache utilization (depends on available scorers) and low load > **Skip prefill worker** when: -> - Prefix match/kv cache hit is high +> - Prefix match/kv-cache hit is high > - Prompt is very short --- -## vLLM and LMCache Integration - -- **vLLM changes** (or wrapper APIs): - - `save()`, `load()` APIs - - `done_sending`, `done_receiving` - - Connector API supporting async transfer - ---- ## Drawbacks & Limitations -- Slight increase in TTFT for split P/D +- Slight increase in TTFT for disaggregated P/D - Possibility of stranded memory on prefill crash -- Need for timeout and retry logic +- The need for timeout and retry logic --- @@ -115,17 +126,17 @@ This evolved version removes the requirement for sidecars on the **prefill node* ## Future Considerations - Cache coordinate -- Pre allocation of kv blocks in decode node , push cache from prefill to decode worker during calculation +- Pre-allocation of kv blocks in the decode node, push cache from the prefill to the decode worker during calculation --- ## Integrating External Prefill/Decode Workloads -The llm-d inference scheduler supports integration with external disaggregated prefill/decode (P/D) workloads 
other inference frameworks that follow the same P/D separation pattern but use **different Kubernetes Pod labeling conventions**. +The LLM-D inference scheduler supports integration with external disaggregated prefill/decode (P/D) workloads other inference frameworks that follow the same P/D separation pattern but use **different Kubernetes Pod labeling conventions**. ### Labeling Convention Flexibility -By default, llm-d uses the label key `llm-d.ai/role` with values: +By default, LLM-D uses the label key `llm-d.ai/role` with values: - `"prefill"` β†’ prefill-only pods - `"decode"` or `"both"` β†’ decode-capable pods @@ -144,6 +155,8 @@ Below is a minimal `EndpointPickerConfig` that enables integration with workload ```yaml apiVersion: inference.networking.x-k8s.io/v1alpha1 kind: EndpointPickerConfig +featureGates: +- prepareDataPlugins plugins: # Prefill selection: match Pods with label role=prefill - type: by-label @@ -159,15 +172,18 @@ plugins: validValues: ["decode"] - type: prefix-cache-scorer parameters: - hashBlockSize: 5 + autoTune: false + blockSize: 5 maxPrefixBlocksToMatch: 256 lruCapacityPerServer: 31250 - type: max-score-picker - type: prefill-header-handler - - type: pd-profile-handler + - type: prefix-based-pd-decider parameters: - threshold: 0 - hashBlockSize: 5 + nonCachedTokens: 8 + - type: pd-profile-handler + parameters: + deciderPluginName: prefix-based-pd-decider primaryPort: 8000 schedulingProfiles: - name: prefill @@ -175,13 +191,11 @@ schedulingProfiles: - pluginRef: "prefill-pods" - pluginRef: "max-score-picker" - pluginRef: "prefix-cache-scorer" - weight: 2 - name: decode plugins: - pluginRef: "decode-pods" - pluginRef: "max-score-picker" - pluginRef: "prefix-cache-scorer" - weight: 2 ``` --- @@ -190,6 +204,59 @@ schedulingProfiles: ![Disaggregated Prefill/Decode Architecture](./images/dp_architecture.png) +--- +## PD Deciders + +PD deciders are pd handler plugins responsible for determining whether disaggregated P/D should be executed 
for a given request, based on the properties of the request prompt. + + +### Prefix-Based PD Decider + +The `prefix-based-pd-decider` plugin makes the disaggregation decision according to the length of the non-cached suffix of the prompt relative to tokens already cached on the selected decode pod. + +**How It Works** +- Once a decode pod is selected, the decider checks how many tokens from the incoming prompt have already been sent to this pod + +- If the remaining non-cached suffix length is longer than the configured threshold (nonCachedTokens), disaggregation is triggered β€” the prefill will run remotely on a prefill pod, and decode locally on the decode pod + +- If the non-cached suffix is shorter or equal to the threshold, the full request runs locally on the decode worker without remote prefill + +**Configuration** +```yaml +- type: prefix-based-pd-decider + parameters: + nonCachedTokens: 8 +``` + +**Parameter:** + +- `nonCachedTokens`: Number of non-cached tokens that trigger disaggregation + - If set to 0, disaggregation always occurs for all requests + +**Feature Gate Requirement** +To activate this decider, ensure the following feature gate is enabled in your EndpointPickerConfig + +```yaml +featureGates: +- prepareDataPlugins +``` + + +### Always-Disagg PD Decider +The `always-disagg-pd-decider` is a simpler alternative used mainly for testing or benchmarking. +It always triggers disaggregation, regardless of prefix cache state or prompt characteristics. + +**Configuration example:** + +```yaml +- type: always-disagg-pd-decider +``` + +**Notes:** +This plugin accepts no parameters. + +It’s useful for validating end-to-end prefill/decode splitting and comparing system performance under forced disaggregation. 
+ --- ## References diff --git a/go.mod b/go.mod index ec6a25010..57c502280 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,8 @@ module github.com/llm-d/llm-d-inference-scheduler -go 1.24.1 +go 1.24.9 -toolchain go1.24.2 +toolchain go1.24.12 require ( github.com/go-logr/logr v1.4.3 @@ -10,9 +10,9 @@ require ( github.com/google/uuid v1.6.0 github.com/hashicorp/golang-lru/v2 v2.0.7 github.com/jellydator/ttlcache/v3 v3.4.0 - github.com/llm-d/llm-d-kv-cache-manager v0.4.0 - github.com/onsi/ginkgo/v2 v2.27.4 - github.com/onsi/gomega v1.39.0 + github.com/llm-d/llm-d-kv-cache v0.5.0 + github.com/onsi/ginkgo/v2 v2.28.1 + github.com/onsi/gomega v1.39.1 github.com/openai/openai-go v1.12.0 github.com/prometheus/client_golang v1.23.2 github.com/stretchr/testify v1.11.1 @@ -23,10 +23,10 @@ require ( k8s.io/apimachinery v0.34.3 k8s.io/client-go v0.34.3 k8s.io/component-base v0.34.3 - k8s.io/utils v0.0.0-20250820121507-0af2bda4dd1d - sigs.k8s.io/controller-runtime v0.22.4 + k8s.io/utils v0.0.0-20251002143259-bc988d571ff4 + sigs.k8s.io/controller-runtime v0.22.5 sigs.k8s.io/gateway-api v1.4.1 - sigs.k8s.io/gateway-api-inference-extension v1.3.0-rc.2 + sigs.k8s.io/gateway-api-inference-extension v0.0.0-20260128235548-fd30cb97714a ) require ( @@ -46,7 +46,7 @@ require ( github.com/dustin/go-humanize v1.0.1 // indirect github.com/emicklei/go-restful/v3 v3.13.0 // indirect github.com/envoyproxy/go-control-plane/envoy v1.36.0 // indirect - github.com/envoyproxy/protoc-gen-validate v1.2.1 // indirect + github.com/envoyproxy/protoc-gen-validate v1.3.0 // indirect github.com/evanphx/json-patch/v5 v5.9.11 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/fsnotify/fsnotify v1.9.0 // indirect @@ -54,23 +54,32 @@ require ( github.com/go-errors/errors v1.4.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-logr/zapr v1.3.0 // indirect - github.com/go-openapi/jsonpointer v0.21.2 // indirect - github.com/go-openapi/jsonreference v0.21.0 // indirect - 
github.com/go-openapi/swag v0.23.1 // indirect + github.com/go-openapi/jsonpointer v0.22.1 // indirect + github.com/go-openapi/jsonreference v0.21.3 // indirect + github.com/go-openapi/swag v0.25.4 // indirect + github.com/go-openapi/swag/cmdutils v0.25.4 // indirect + github.com/go-openapi/swag/conv v0.25.4 // indirect + github.com/go-openapi/swag/fileutils v0.25.4 // indirect + github.com/go-openapi/swag/jsonname v0.25.4 // indirect + github.com/go-openapi/swag/jsonutils v0.25.4 // indirect + github.com/go-openapi/swag/loading v0.25.4 // indirect + github.com/go-openapi/swag/mangling v0.25.4 // indirect + github.com/go-openapi/swag/netutils v0.25.4 // indirect + github.com/go-openapi/swag/stringutils v0.25.4 // indirect + github.com/go-openapi/swag/typeutils v0.25.4 // indirect + github.com/go-openapi/swag/yamlutils v0.25.4 // indirect github.com/go-task/slim-sprig/v3 v3.0.0 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/google/btree v1.1.3 // indirect github.com/google/cel-go v0.26.0 // indirect github.com/google/gnostic-models v0.7.0 // indirect - github.com/google/pprof v0.0.0-20250923004556-9e5a51aed1e8 // indirect + github.com/google/pprof v0.0.0-20260115054156-294ebfa9ad83 // indirect github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674 // indirect github.com/grafana/regexp v0.0.0-20250905093917-f7b3be9d1853 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect - github.com/josharian/intern v1.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/kylelemons/godebug v1.1.0 // indirect - github.com/mailru/easyjson v0.9.0 // indirect github.com/moby/spdystream v0.5.0 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.3-0.20250322232337-35a7c28c31ee // indirect @@ -83,7 +92,7 @@ require ( github.com/prometheus/client_model v0.6.2 // indirect 
github.com/prometheus/common v0.67.5 // indirect github.com/prometheus/procfs v0.17.0 // indirect - github.com/prometheus/prometheus v0.308.1 // indirect + github.com/prometheus/prometheus v0.309.1 // indirect github.com/redis/go-redis/v9 v9.11.0 // indirect github.com/spf13/cobra v1.9.1 // indirect github.com/spf13/pflag v1.0.10 // indirect @@ -97,7 +106,7 @@ require ( github.com/x448/float16 v0.8.4 // indirect github.com/xlab/treeprint v1.2.0 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect - go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0 // indirect + go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0 // indirect go.opentelemetry.io/otel v1.39.0 // indirect go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.39.0 // indirect go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.39.0 // indirect @@ -112,16 +121,16 @@ require ( go.yaml.in/yaml/v2 v2.4.3 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect golang.org/x/exp v0.0.0-20250808145144-a408d31f581a // indirect - golang.org/x/mod v0.30.0 // indirect - golang.org/x/net v0.48.0 // indirect + golang.org/x/mod v0.32.0 // indirect + golang.org/x/net v0.49.0 // indirect golang.org/x/oauth2 v0.34.0 // indirect - golang.org/x/sys v0.39.0 // indirect - golang.org/x/term v0.38.0 // indirect - golang.org/x/text v0.32.0 // indirect - golang.org/x/time v0.13.0 // indirect - golang.org/x/tools v0.39.0 // indirect + golang.org/x/sys v0.40.0 // indirect + golang.org/x/term v0.39.0 // indirect + golang.org/x/text v0.33.0 // indirect + golang.org/x/time v0.14.0 // indirect + golang.org/x/tools v0.41.0 // indirect gomodules.xyz/jsonpatch/v2 v2.4.0 // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20251213004720-97cd9d5aeac2 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 // indirect google.golang.org/protobuf v1.36.11 // 
indirect gopkg.in/evanphx/json-patch.v4 v4.13.0 // indirect diff --git a/go.sum b/go.sum index dd2983e89..17c6e64b2 100644 --- a/go.sum +++ b/go.sum @@ -6,14 +6,14 @@ cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIi cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c= cloud.google.com/go/compute/metadata v0.9.0 h1:pDUj4QMoPejqq20dK0Pg2N4yG9zIkYGdBtwLoEkH9Zs= cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10= -github.com/Azure/azure-sdk-for-go/sdk/azcore v1.19.1 h1:5YTBM8QDVIBN3sxBil89WfdAAqDZbyJTgh688DSxX5w= -github.com/Azure/azure-sdk-for-go/sdk/azcore v1.19.1/go.mod h1:YD5h/ldMsG0XiIw7PdyNhLxaM317eFh5yNLccNfGdyw= -github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.12.0 h1:wL5IEG5zb7BVv1Kv0Xm92orq+5hB5Nipn3B5tn4Rqfk= -github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.12.0/go.mod h1:J7MUC/wtRpfGVbQ5sIItY5/FuVWmvzlY21WAOfQnq/I= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0 h1:JXg2dwJUmPB9JmtVmdEB16APJ7jurfbY5jnfXpJoRMc= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.20.0/go.mod h1:YD5h/ldMsG0XiIw7PdyNhLxaM317eFh5yNLccNfGdyw= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1 h1:Hk5QBxZQC1jb2Fwj6mpzme37xbCDdNTxU7O9eb5+LB4= +github.com/Azure/azure-sdk-for-go/sdk/azidentity v1.13.1/go.mod h1:IYus9qsFobWIc2YVwe/WPjcnyCkPKtnHAqUYeebc8z0= github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2 h1:9iefClla7iYpfYWdzPCRDozdmndjTm8DXdpCzPajMgA= github.com/Azure/azure-sdk-for-go/sdk/internal v1.11.2/go.mod h1:XtLgD3ZD34DAaVIIAyG3objl5DynM3CQ/vMcbBNJZGI= -github.com/AzureAD/microsoft-authentication-library-for-go v1.5.0 h1:XkkQbfMyuH2jTSjQjSoihryI8GINRcs4xp8lNawg0FI= -github.com/AzureAD/microsoft-authentication-library-for-go v1.5.0/go.mod h1:HKpQxkWaGLJ+D/5H8QRpyQXA1eKjxkFlOMwck5+33Jk= +github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0 h1:XRzhVemXdgvJqCH0sFfrBUTnUJSBrBf7++ypk+twtRs= 
+github.com/AzureAD/microsoft-authentication-library-for-go v1.6.0/go.mod h1:HKpQxkWaGLJ+D/5H8QRpyQXA1eKjxkFlOMwck5+33Jk= github.com/Masterminds/semver/v3 v3.4.0 h1:Zog+i5UMtVoCU8oKka5P7i9q9HgrJeGzI9SA1Xbatp0= github.com/Masterminds/semver/v3 v3.4.0/go.mod h1:4V+yj/TJE1HU9XfppCwVMZq3I84lprf4nC11bSS5beM= github.com/alecthomas/units v0.0.0-20240927000941-0f3dac36c52b h1:mimo19zliBX/vSQ6PWWSL9lK8qwHozUj03+zLoEB8O0= @@ -24,32 +24,34 @@ github.com/antlr4-go/antlr/v4 v4.13.0 h1:lxCg3LAv+EUK6t1i0y1V6/SLeUi0eKEKdhQAlS8 github.com/antlr4-go/antlr/v4 v4.13.0/go.mod h1:pfChB/xh/Unjila75QW7+VU4TSnWnnk9UTnmpPaOR2g= github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio= github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs= -github.com/aws/aws-sdk-go-v2 v1.39.6 h1:2JrPCVgWJm7bm83BDwY5z8ietmeJUbh3O2ACnn+Xsqk= -github.com/aws/aws-sdk-go-v2 v1.39.6/go.mod h1:c9pm7VwuW0UPxAEYGyTmyurVcNrbF6Rt/wixFqDhcjE= -github.com/aws/aws-sdk-go-v2/config v1.31.17 h1:QFl8lL6RgakNK86vusim14P2k8BFSxjvUkcWLDjgz9Y= -github.com/aws/aws-sdk-go-v2/config v1.31.17/go.mod h1:V8P7ILjp/Uef/aX8TjGk6OHZN6IKPM5YW6S78QnRD5c= -github.com/aws/aws-sdk-go-v2/credentials v1.18.21 h1:56HGpsgnmD+2/KpG0ikvvR8+3v3COCwaF4r+oWwOeNA= -github.com/aws/aws-sdk-go-v2/credentials v1.18.21/go.mod h1:3YELwedmQbw7cXNaII2Wywd+YY58AmLPwX4LzARgmmA= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.13 h1:T1brd5dR3/fzNFAQch/iBKeX07/ffu/cLu+q+RuzEWk= -github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.13/go.mod h1:Peg/GBAQ6JDt+RoBf4meB1wylmAipb7Kg2ZFakZTlwk= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.13 h1:a+8/MLcWlIxo1lF9xaGt3J/u3yOZx+CdSveSNwjhD40= -github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.13/go.mod h1:oGnKwIYZ4XttyU2JWxFrwvhF6YKiK/9/wmE3v3Iu9K8= -github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.13 h1:HBSI2kDkMdWz4ZM7FjwE7e/pWDEZ+nR95x8Ztet1ooY= 
-github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.13/go.mod h1:YE94ZoDArI7awZqJzBAZ3PDD2zSfuP7w6P2knOzIn8M= +github.com/aws/aws-sdk-go-v2 v1.41.0 h1:tNvqh1s+v0vFYdA1xq0aOJH+Y5cRyZ5upu6roPgPKd4= +github.com/aws/aws-sdk-go-v2 v1.41.0/go.mod h1:MayyLB8y+buD9hZqkCW3kX1AKq07Y5pXxtgB+rRFhz0= +github.com/aws/aws-sdk-go-v2/config v1.32.6 h1:hFLBGUKjmLAekvi1evLi5hVvFQtSo3GYwi+Bx4lpJf8= +github.com/aws/aws-sdk-go-v2/config v1.32.6/go.mod h1:lcUL/gcd8WyjCrMnxez5OXkO3/rwcNmvfno62tnXNcI= +github.com/aws/aws-sdk-go-v2/credentials v1.19.6 h1:F9vWao2TwjV2MyiyVS+duza0NIRtAslgLUM0vTA1ZaE= +github.com/aws/aws-sdk-go-v2/credentials v1.19.6/go.mod h1:SgHzKjEVsdQr6Opor0ihgWtkWdfRAIwxYzSJ8O85VHY= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.16 h1:80+uETIWS1BqjnN9uJ0dBUaETh+P1XwFy5vwHwK5r9k= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.16/go.mod h1:wOOsYuxYuB/7FlnVtzeBYRcjSRtQpAW0hCP7tIULMwo= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.16 h1:rgGwPzb82iBYSvHMHXc8h9mRoOUBZIGFgKb9qniaZZc= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.16/go.mod h1:L/UxsGeKpGoIj6DxfhOWHWQ/kGKcd4I1VncE4++IyKA= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.16 h1:1jtGzuV7c82xnqOVfx2F0xmJcOw5374L7N6juGW6x6U= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.16/go.mod h1:M2E5OQf+XLe+SZGmmpaI2yy+J326aFf6/+54PoxSANc= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 h1:WKuaxf++XKWlHWu9ECbMlha8WOEGm0OUEZqm4K/Gcfk= github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4/go.mod h1:ZWy7j6v1vWGmPReu0iSGvRiise4YI5SkR3OHKTZ6Wuc= -github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.3 h1:x2Ibm/Af8Fi+BH+Hsn9TXGdT+hKbDd5XOTZxTMxDk7o= -github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.3/go.mod h1:IW1jwyrQgMdhisceG8fQLmQIydcT/jWY21rFhzgaKwo= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.13 h1:kDqdFvMY4AtKoACfzIGD8A0+hbT41KTKF//gq7jITfM= -github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.13/go.mod 
h1:lmKuogqSU3HzQCwZ9ZtcqOc5XGMqtDK7OIc2+DxiUEg= -github.com/aws/aws-sdk-go-v2/service/sso v1.30.1 h1:0JPwLz1J+5lEOfy/g0SURC9cxhbQ1lIMHMa+AHZSzz0= -github.com/aws/aws-sdk-go-v2/service/sso v1.30.1/go.mod h1:fKvyjJcz63iL/ftA6RaM8sRCtN4r4zl4tjL3qw5ec7k= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.5 h1:OWs0/j2UYR5LOGi88sD5/lhN6TDLG6SfA7CqsQO9zF0= -github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.5/go.mod h1:klO+ejMvYsB4QATfEOIXk8WAEwN4N0aBfJpvC+5SZBo= -github.com/aws/aws-sdk-go-v2/service/sts v1.39.1 h1:mLlUgHn02ue8whiR4BmxxGJLR2gwU6s6ZzJ5wDamBUs= -github.com/aws/aws-sdk-go-v2/service/sts v1.39.1/go.mod h1:E19xDjpzPZC7LS2knI9E6BaRFDK43Eul7vd6rSq2HWk= -github.com/aws/smithy-go v1.23.2 h1:Crv0eatJUQhaManss33hS5r40CG3ZFH+21XSkqMrIUM= -github.com/aws/smithy-go v1.23.2/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.4 h1:0ryTNEdJbzUCEWkVXEXoqlXV72J5keC1GvILMOuD00E= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.4/go.mod h1:HQ4qwNZh32C3CBeO6iJLQlgtMzqeG17ziAA/3KDJFow= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.16 h1:oHjJHeUy0ImIV0bsrX0X91GkV5nJAyv1l1CC9lnO0TI= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.16/go.mod h1:iRSNGgOYmiYwSCXxXaKb9HfOEj40+oTKn8pTxMlYkRM= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.4 h1:HpI7aMmJ+mm1wkSHIA2t5EaFFv5EFYXePW30p1EIrbQ= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.4/go.mod h1:C5RdGMYGlfM0gYq/tifqgn4EbyX99V15P2V3R+VHbQU= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.8 h1:aM/Q24rIlS3bRAhTyFurowU8A0SMyGDtEOY/l/s/1Uw= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.8/go.mod h1:+fWt2UHSb4kS7Pu8y+BMBvJF0EWx+4H0hzNwtDNRTrg= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.12 h1:AHDr0DaHIAo8c9t1emrzAlVDFp+iMMKnPdYy6XO4MCE= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.12/go.mod h1:GQ73XawFFiWxyWXMHWfhiomvP3tXtdNar/fi8z18sx0= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.5 
h1:SciGFVNZ4mHdm7gpD1dgZYnCuVdX1s+lFTg4+4DOy70= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.5/go.mod h1:iW40X4QBmUxdP+fZNOpfmkdMZqsovezbAeO+Ubiv2pk= +github.com/aws/smithy-go v1.24.0 h1:LpilSUItNPFr1eY85RYgTIg5eIEPtvFbskaFcmmIUnk= +github.com/aws/smithy-go v1.24.0/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0= github.com/bboreham/go-loser v0.0.0-20230920113527-fcc2c21820a3 h1:6df1vn4bBlDDo4tARvBm7l6KA9iVMnE3NWizDeWSrps= github.com/bboreham/go-loser v0.0.0-20230920113527-fcc2c21820a3/go.mod h1:CIWtjkly68+yqLPbvwwR/fjNJA/idrtULjZWh2v1ys0= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= @@ -87,8 +89,8 @@ github.com/emicklei/go-restful/v3 v3.13.0 h1:C4Bl2xDndpU6nJ4bc1jXd+uTmYPVUwkD6bF github.com/emicklei/go-restful/v3 v3.13.0/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc= github.com/envoyproxy/go-control-plane/envoy v1.36.0 h1:yg/JjO5E7ubRyKX3m07GF3reDNEnfOboJ0QySbH736g= github.com/envoyproxy/go-control-plane/envoy v1.36.0/go.mod h1:ty89S1YCCVruQAm9OtKeEkQLTb+Lkz0k8v9W0Oxsv98= -github.com/envoyproxy/protoc-gen-validate v1.2.1 h1:DEo3O99U8j4hBFwbJfrz9VtgcDfUKS7KJ7spH3d86P8= -github.com/envoyproxy/protoc-gen-validate v1.2.1/go.mod h1:d/C80l/jxXLdfEIhX1W2TmLfsJ31lvEjwamM4DxlWXU= +github.com/envoyproxy/protoc-gen-validate v1.3.0 h1:TvGH1wof4H33rezVKWSpqKz5NXWg5VPuZ0uONDT6eb4= +github.com/envoyproxy/protoc-gen-validate v1.3.0/go.mod h1:HvYl7zwPa5mffgyeTUHA9zHIH36nmrm7oCbo4YKoSWA= github.com/evanphx/json-patch v0.5.2 h1:xVCHIVMUu1wtM/VkR9jVZ45N3FhZfYMMYGorLCR8P3k= github.com/evanphx/json-patch v0.5.2/go.mod h1:ZWS5hhDbVDyob71nXKNL0+PWn6ToqBHMikGIFbs31qQ= github.com/evanphx/json-patch/v5 v5.9.11 h1:/8HVnzMq13/3x9TPvjG08wUGqBTmZBsCWzjTM0wiaDU= @@ -114,12 +116,40 @@ github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/go-logr/zapr v1.3.0 h1:XGdV8XW8zdwFiwOA2Dryh1gj2KRQyOOoNmBy4EplIcQ= 
github.com/go-logr/zapr v1.3.0/go.mod h1:YKepepNBd1u/oyhd/yQmtjVXmm9uML4IXUgMOwR8/Gg= -github.com/go-openapi/jsonpointer v0.21.2 h1:AqQaNADVwq/VnkCmQg6ogE+M3FOsKTytwges0JdwVuA= -github.com/go-openapi/jsonpointer v0.21.2/go.mod h1:50I1STOfbY1ycR8jGz8DaMeLCdXiI6aDteEdRNNzpdk= -github.com/go-openapi/jsonreference v0.21.0 h1:Rs+Y7hSXT83Jacb7kFyjn4ijOuVGSvOdF2+tg1TRrwQ= -github.com/go-openapi/jsonreference v0.21.0/go.mod h1:LmZmgsrTkVg9LG4EaHeY8cBDslNPMo06cago5JNLkm4= -github.com/go-openapi/swag v0.23.1 h1:lpsStH0n2ittzTnbaSloVZLuB5+fvSY/+hnagBjSNZU= -github.com/go-openapi/swag v0.23.1/go.mod h1:STZs8TbRvEQQKUA+JZNAm3EWlgaOBGpyFDqQnDHMef0= +github.com/go-openapi/jsonpointer v0.22.1 h1:sHYI1He3b9NqJ4wXLoJDKmUmHkWy/L7rtEo92JUxBNk= +github.com/go-openapi/jsonpointer v0.22.1/go.mod h1:pQT9OsLkfz1yWoMgYFy4x3U5GY5nUlsOn1qSBH5MkCM= +github.com/go-openapi/jsonreference v0.21.3 h1:96Dn+MRPa0nYAR8DR1E03SblB5FJvh7W6krPI0Z7qMc= +github.com/go-openapi/jsonreference v0.21.3/go.mod h1:RqkUP0MrLf37HqxZxrIAtTWW4ZJIK1VzduhXYBEeGc4= +github.com/go-openapi/swag v0.25.4 h1:OyUPUFYDPDBMkqyxOTkqDYFnrhuhi9NR6QVUvIochMU= +github.com/go-openapi/swag v0.25.4/go.mod h1:zNfJ9WZABGHCFg2RnY0S4IOkAcVTzJ6z2Bi+Q4i6qFQ= +github.com/go-openapi/swag/cmdutils v0.25.4 h1:8rYhB5n6WawR192/BfUu2iVlxqVR9aRgGJP6WaBoW+4= +github.com/go-openapi/swag/cmdutils v0.25.4/go.mod h1:pdae/AFo6WxLl5L0rq87eRzVPm/XRHM3MoYgRMvG4A0= +github.com/go-openapi/swag/conv v0.25.4 h1:/Dd7p0LZXczgUcC/Ikm1+YqVzkEeCc9LnOWjfkpkfe4= +github.com/go-openapi/swag/conv v0.25.4/go.mod h1:3LXfie/lwoAv0NHoEuY1hjoFAYkvlqI/Bn5EQDD3PPU= +github.com/go-openapi/swag/fileutils v0.25.4 h1:2oI0XNW5y6UWZTC7vAxC8hmsK/tOkWXHJQH4lKjqw+Y= +github.com/go-openapi/swag/fileutils v0.25.4/go.mod h1:cdOT/PKbwcysVQ9Tpr0q20lQKH7MGhOEb6EwmHOirUk= +github.com/go-openapi/swag/jsonname v0.25.4 h1:bZH0+MsS03MbnwBXYhuTttMOqk+5KcQ9869Vye1bNHI= +github.com/go-openapi/swag/jsonname v0.25.4/go.mod h1:GPVEk9CWVhNvWhZgrnvRA6utbAltopbKwDu8mXNUMag= 
+github.com/go-openapi/swag/jsonutils v0.25.4 h1:VSchfbGhD4UTf4vCdR2F4TLBdLwHyUDTd1/q4i+jGZA= +github.com/go-openapi/swag/jsonutils v0.25.4/go.mod h1:7OYGXpvVFPn4PpaSdPHJBtF0iGnbEaTk8AvBkoWnaAY= +github.com/go-openapi/swag/jsonutils/fixtures_test v0.25.4 h1:IACsSvBhiNJwlDix7wq39SS2Fh7lUOCJRmx/4SN4sVo= +github.com/go-openapi/swag/jsonutils/fixtures_test v0.25.4/go.mod h1:Mt0Ost9l3cUzVv4OEZG+WSeoHwjWLnarzMePNDAOBiM= +github.com/go-openapi/swag/loading v0.25.4 h1:jN4MvLj0X6yhCDduRsxDDw1aHe+ZWoLjW+9ZQWIKn2s= +github.com/go-openapi/swag/loading v0.25.4/go.mod h1:rpUM1ZiyEP9+mNLIQUdMiD7dCETXvkkC30z53i+ftTE= +github.com/go-openapi/swag/mangling v0.25.4 h1:2b9kBJk9JvPgxr36V23FxJLdwBrpijI26Bx5JH4Hp48= +github.com/go-openapi/swag/mangling v0.25.4/go.mod h1:6dxwu6QyORHpIIApsdZgb6wBk/DPU15MdyYj/ikn0Hg= +github.com/go-openapi/swag/netutils v0.25.4 h1:Gqe6K71bGRb3ZQLusdI8p/y1KLgV4M/k+/HzVSqT8H0= +github.com/go-openapi/swag/netutils v0.25.4/go.mod h1:m2W8dtdaoX7oj9rEttLyTeEFFEBvnAx9qHd5nJEBzYg= +github.com/go-openapi/swag/stringutils v0.25.4 h1:O6dU1Rd8bej4HPA3/CLPciNBBDwZj9HiEpdVsb8B5A8= +github.com/go-openapi/swag/stringutils v0.25.4/go.mod h1:GTsRvhJW5xM5gkgiFe0fV3PUlFm0dr8vki6/VSRaZK0= +github.com/go-openapi/swag/typeutils v0.25.4 h1:1/fbZOUN472NTc39zpa+YGHn3jzHWhv42wAJSN91wRw= +github.com/go-openapi/swag/typeutils v0.25.4/go.mod h1:Ou7g//Wx8tTLS9vG0UmzfCsjZjKhpjxayRKTHXf2pTE= +github.com/go-openapi/swag/yamlutils v0.25.4 h1:6jdaeSItEUb7ioS9lFoCZ65Cne1/RZtPBZ9A56h92Sw= +github.com/go-openapi/swag/yamlutils v0.25.4/go.mod h1:MNzq1ulQu+yd8Kl7wPOut/YHAAU/H6hL91fF+E2RFwc= +github.com/go-openapi/testify/enable/yaml/v2 v2.0.2 h1:0+Y41Pz1NkbTHz8NngxTuAXxEodtNSI1WG1c/m5Akw4= +github.com/go-openapi/testify/enable/yaml/v2 v2.0.2/go.mod h1:kme83333GCtJQHXQ8UKX3IBZu6z8T5Dvy5+CW3NLUUg= +github.com/go-openapi/testify/v2 v2.0.2 h1:X999g3jeLcoY8qctY/c/Z8iBHTbwLz7R2WXd6Ub6wls= +github.com/go-openapi/testify/v2 v2.0.2/go.mod h1:HCPmvFFnheKK2BuwSA0TbbdxJ3I16pjwMkYkP4Ywn54= 
github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI= github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8= github.com/goccy/go-yaml v1.18.0 h1:8W7wMFS12Pcas7KU+VVkaiCng+kG8QiFeFwzFb+rwuw= @@ -143,14 +173,14 @@ github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= -github.com/google/pprof v0.0.0-20250923004556-9e5a51aed1e8 h1:ZI8gCoCjGzPsum4L21jHdQs8shFBIQih1TM9Rd/c+EQ= -github.com/google/pprof v0.0.0-20250923004556-9e5a51aed1e8/go.mod h1:I6V7YzU0XDpsHqbsyrghnFZLO1gwK6NPTNvmetQIk9U= +github.com/google/pprof v0.0.0-20260115054156-294ebfa9ad83 h1:z2ogiKUYzX5Is6zr/vP9vJGqPwcdqsWjOt+V8J7+bTc= +github.com/google/pprof v0.0.0-20260115054156-294ebfa9ad83/go.mod h1:MxpfABSjhmINe3F1It9d+8exIHFvUqtLIRCdOGNXqiI= github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0= github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/googleapis/enterprise-certificate-proxy v0.3.6 h1:GW/XbdyBFQ8Qe+YAmFU9uHLo7OnF5tL52HFAgMmyrf4= -github.com/googleapis/enterprise-certificate-proxy v0.3.6/go.mod h1:MkHOF77EYAE7qfSuSS9PU6g4Nt4e11cnsDUowfwewLA= +github.com/googleapis/enterprise-certificate-proxy v0.3.7 h1:zrn2Ee/nWmHulBx5sAVrGgAa0f2/R35S4DJwfFaUPFQ= +github.com/googleapis/enterprise-certificate-proxy v0.3.7/go.mod h1:MkHOF77EYAE7qfSuSS9PU6g4Nt4e11cnsDUowfwewLA= github.com/googleapis/gax-go/v2 v2.15.0 h1:SyjDc1mGgZU5LncH8gimWo9lW1DtIfPibOG81vgd/bo= github.com/googleapis/gax-go/v2 v2.15.0/go.mod 
h1:zVVkkxAQHa1RQpg9z2AUCMnKhi0Qld9rcmyfL1OZhoc= github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674 h1:JeSE6pjso5THxAzdVpqr6/geYxZytqFMBCOtn/ujyeo= @@ -165,8 +195,6 @@ github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2 github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jellydator/ttlcache/v3 v3.4.0 h1:YS4P125qQS0tNhtL6aeYkheEaB/m8HCqdMMP4mnWdTY= github.com/jellydator/ttlcache/v3 v3.4.0/go.mod h1:Hw9EgjymziQD3yGsQdf1FqFdpp7YjFMd4Srg5EJlgD4= -github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= -github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/joshdk/go-junit v1.0.0 h1:S86cUKIdwBHWwA6xCmFlf3RTLfVXYQfvanM5Uh+K6GE= github.com/joshdk/go-junit v1.0.0/go.mod h1:TiiV0PqkaNfFXjEiyjWM3XXrhVyCa1K4Zfga6W52ung= github.com/jpillora/backoff v1.0.0 h1:uvFg412JmmHBHw7iwprIxkPMI+sGQ4kzOWsMeHnm2EA= @@ -183,10 +211,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= -github.com/llm-d/llm-d-kv-cache-manager v0.4.0 h1:MBWVpDW0PWsqNJEEAW1esrJW+Xavb0a7w14tCJWWyRY= -github.com/llm-d/llm-d-kv-cache-manager v0.4.0/go.mod h1:ZlK7MCuz5D/weLeHyNKEmVF/eJZDyYn3XyRowTihq9o= -github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4= -github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= +github.com/llm-d/llm-d-kv-cache v0.5.0 h1:XQpkbg1yedGxn2w7QS/v/2YtrOZGp16Sw49KvMlQ1s0= +github.com/llm-d/llm-d-kv-cache v0.5.0/go.mod h1:XyhzHBYeOWamBMPkuRySB5nJ0zzQpK/mbuXKqJRFT6A= github.com/maruel/natural v1.1.1 h1:Hja7XhhmvEFhcByqDoHz9QZbkWey+COd9xWfCfn1ioo= github.com/maruel/natural 
v1.1.1/go.mod h1:v+Rfd79xlw1AgVBjbO0BEQmptqb5HvL/k9GRHB7ZKEg= github.com/mfridman/tparse v0.18.0 h1:wh6dzOKaIwkUGyKgOntDW4liXSo37qg5AXbIhkMV3vE= @@ -210,10 +236,10 @@ github.com/mxk/go-flowrate v0.0.0-20140419014527-cca7078d478f/go.mod h1:ZdcZmHo+ github.com/oklog/ulid v1.3.1 h1:EGfNDEx6MqHz8B3uNV6QAib1UR2Lm97sHi3ocA6ESJ4= github.com/oklog/ulid/v2 v2.1.1 h1:suPZ4ARWLOJLegGFiZZ1dFAkqzhMjL3J1TzI+5wHz8s= github.com/oklog/ulid/v2 v2.1.1/go.mod h1:rcEKHmBBKfef9DhnvX7y1HZBYxjXb0cP5ExxNsTT1QQ= -github.com/onsi/ginkgo/v2 v2.27.4 h1:fcEcQW/A++6aZAZQNUmNjvA9PSOzefMJBerHJ4t8v8Y= -github.com/onsi/ginkgo/v2 v2.27.4/go.mod h1:ArE1D/XhNXBXCBkKOLkbsb2c81dQHCRcF5zwn/ykDRo= -github.com/onsi/gomega v1.39.0 h1:y2ROC3hKFmQZJNFeGAMeHZKkjBL65mIZcvrLQBF9k6Q= -github.com/onsi/gomega v1.39.0/go.mod h1:ZCU1pkQcXDO5Sl9/VVEGlDyp+zm0m1cmeG5TOzLgdh4= +github.com/onsi/ginkgo/v2 v2.28.1 h1:S4hj+HbZp40fNKuLUQOYLDgZLwNUVn19N3Atb98NCyI= +github.com/onsi/ginkgo/v2 v2.28.1/go.mod h1:CLtbVInNckU3/+gC8LzkGUb9oF+e8W8TdUsxPwvdOgE= +github.com/onsi/gomega v1.39.1 h1:1IJLAad4zjPn2PsnhH70V4DKRFlrCzGBNrNaru+Vf28= +github.com/onsi/gomega v1.39.1/go.mod h1:hL6yVALoTOxeWudERyfppUcZXjMwIMLnuSfruD2lcfg= github.com/openai/openai-go v1.12.0 h1:NBQCnXzqOTv5wsgNC36PrFEiskGfO5wccfCWDo9S1U0= github.com/openai/openai-go v1.12.0/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y= github.com/pebbe/zmq4 v1.4.0 h1:gO5P92Ayl8GXpPZdYcD62Cwbq0slSBVVQRIXwGSJ6eQ= @@ -239,8 +265,8 @@ github.com/prometheus/otlptranslator v1.0.0 h1:s0LJW/iN9dkIH+EnhiD3BlkkP5QVIUVEo github.com/prometheus/otlptranslator v1.0.0/go.mod h1:vRYWnXvI6aWGpsdY/mOT/cbeVRBlPWtBNDb7kGR3uKM= github.com/prometheus/procfs v0.17.0 h1:FuLQ+05u4ZI+SS/w9+BWEM2TXiHKsUQ9TADiRH7DuK0= github.com/prometheus/procfs v0.17.0/go.mod h1:oPQLaDAMRbA+u8H5Pbfq+dl3VDAvHxMUOVhe0wYB2zw= -github.com/prometheus/prometheus v0.308.1 h1:ApMNI/3/es3Ze90Z7CMb+wwU2BsSYur0m5VKeqHj7h4= -github.com/prometheus/prometheus v0.308.1/go.mod h1:aHjYCDz9zKRyoUXvMWvu13K9XHOkBB12XrEqibs3e0A= 
+github.com/prometheus/prometheus v0.309.1 h1:jutK6eCYDpWdPTUbVbkcQsNCMO9CCkSwjQRMLds4jSo= +github.com/prometheus/prometheus v0.309.1/go.mod h1:d+dOGiVhuNDa4MaFXHVdnUBy/CzqlcNTooR8oM1wdTU= github.com/prometheus/sigv4 v0.3.0 h1:QIG7nTbu0JTnNidGI1Uwl5AGVIChWUACxn2B/BQ1kms= github.com/prometheus/sigv4 v0.3.0/go.mod h1:fKtFYDus2M43CWKMNtGvFNHGXnAJJEGZbiYCmVp/F8I= github.com/redis/go-redis/v9 v9.11.0 h1:E3S08Gl/nJNn5vkxd2i78wZxWAPNZgUNTp8WIJUAiIs= @@ -293,8 +319,8 @@ github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw= go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0 h1:RbKq8BG0FI8OiXhBfcRtqqHcZcka+gU3cskNuf05R18= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.63.0/go.mod h1:h06DGIukJOevXaj/xrNjhi/2098RZzcLTbc0jDAUbsg= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0 h1:ssfIgGNANqpVFCndZvcuyKbl0g+UAVcbBcqGkG28H0Y= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0/go.mod h1:GQ/474YrbE4Jx8gZ4q5I4hrhUzM6UPzyrqJYV2AqPoQ= go.opentelemetry.io/otel v1.39.0 h1:8yPrr/S0ND9QEfTfdP9V+SiwT4E0G7Y5MO7p85nis48= go.opentelemetry.io/otel v1.39.0/go.mod h1:kLlFTywNWrFyEdH0oj2xK0bFYZtHRYUdv1NklR/tgc8= go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.39.0 h1:f0cb2XPmrqn4XMy9PNliTgRKJgS5WcL/u0/WRYGz4t0= @@ -328,20 +354,20 @@ go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod 
h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= -golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU= -golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0= +golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8= +golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A= golang.org/x/exp v0.0.0-20250808145144-a408d31f581a h1:Y+7uR/b1Mw2iSXZ3G//1haIiSElDQZ8KWh0h+sZPG90= golang.org/x/exp v0.0.0-20250808145144-a408d31f581a/go.mod h1:rT6SFzZ7oxADUDx58pcaKFTcZ+inxAa9fTrYx/uVYwg= golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk= -golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc= +golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c= +golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU= -golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY= +golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o= +golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= golang.org/x/oauth2 v0.34.0 h1:hqK/t4AKgbqWkdkcAeI8XLmbK+4m4G5YeQRrmiotGlw= golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= golang.org/x/sync 
v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -352,22 +378,22 @@ golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= -golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q= -golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg= +golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= +golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/term v0.39.0 h1:RclSuaJf32jOqZz74CkPA9qFuVTX7vhLlpfj/IGWlqY= +golang.org/x/term v0.39.0/go.mod h1:yxzUCTP/U+FzoxfdKmLaA0RV1WgE0VY7hXBwKtY/4ww= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= -golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= -golang.org/x/time v0.13.0 h1:eUlYslOIt32DgYD6utsuUeHs4d7AsEYLuIAdg7FlYgI= -golang.org/x/time v0.13.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= +golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE= +golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= +golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= +golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod 
h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= -golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ= -golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ= +golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc= +golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -376,10 +402,10 @@ gomodules.xyz/jsonpatch/v2 v2.4.0 h1:Ci3iUJyx9UeRx7CeFN8ARgGbkESwJK+KB9lLcWxY/Zw gomodules.xyz/jsonpatch/v2 v2.4.0/go.mod h1:AH3dM2RI6uoBZxn3LVrfvJ3E0/9dG4cSrbuBJT4moAY= gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= -google.golang.org/api v0.252.0 h1:xfKJeAJaMwb8OC9fesr369rjciQ704AjU/psjkKURSI= -google.golang.org/api v0.252.0/go.mod h1:dnHOv81x5RAmumZ7BWLShB/u7JZNeyalImxHmtTHxqw= -google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217 h1:fCvbg86sFXwdrl5LgVcTEvNC+2txB5mgROGmRL5mrls= -google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217/go.mod h1:+rXWjjaukWZun3mLfjmVnQi18E1AsFbDN9QdJ5YXLto= +google.golang.org/api v0.257.0 h1:8Y0lzvHlZps53PEaw+G29SsQIkuKrumGWs9puiexNAA= +google.golang.org/api v0.257.0/go.mod h1:4eJrr+vbVaZSqs7vovFd1Jb/A6ml6iw2e6FBYf3GAO4= 
+google.golang.org/genproto/googleapis/api v0.0.0-20251213004720-97cd9d5aeac2 h1:7LRqPCEdE4TP4/9psdaB7F2nhZFfBiGJomA5sojLWdU= +google.golang.org/genproto/googleapis/api v0.0.0-20251213004720-97cd9d5aeac2/go.mod h1:+rXWjjaukWZun3mLfjmVnQi18E1AsFbDN9QdJ5YXLto= google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 h1:gRkg/vSppuSQoDjxyiGfN4Upv/h/DQmIR10ZU8dh4Ww= google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk= google.golang.org/grpc v1.78.0 h1:K1XZG/yGDJnzMdd/uZHAkVqJE+xIDOcmdSFZkBUicNc= @@ -412,16 +438,16 @@ k8s.io/klog/v2 v2.130.1 h1:n9Xl7H1Xvksem4KFG4PYbdQCQxqc/tTUyrgXaOhHSzk= k8s.io/klog/v2 v2.130.1/go.mod h1:3Jpz1GvMt720eyJH1ckRHK1EDfpxISzJ7I9OYgaDtPE= k8s.io/kube-openapi v0.0.0-20250814151709-d7b6acb124c3 h1:liMHz39T5dJO1aOKHLvwaCjDbf07wVh6yaUlTpunnkE= k8s.io/kube-openapi v0.0.0-20250814151709-d7b6acb124c3/go.mod h1:UZ2yyWbFTpuhSbFhv24aGNOdoRdJZgsIObGBUaYVsts= -k8s.io/utils v0.0.0-20250820121507-0af2bda4dd1d h1:wAhiDyZ4Tdtt7e46e9M5ZSAJ/MnPGPs+Ki1gHw4w1R0= -k8s.io/utils v0.0.0-20250820121507-0af2bda4dd1d/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0= +k8s.io/utils v0.0.0-20251002143259-bc988d571ff4 h1:SjGebBtkBqHFOli+05xYbK8YF1Dzkbzn+gDM4X9T4Ck= +k8s.io/utils v0.0.0-20251002143259-bc988d571ff4/go.mod h1:OLgZIPagt7ERELqWJFomSt595RzquPNLL48iOWgYOg0= sigs.k8s.io/apiserver-network-proxy/konnectivity-client v0.31.2 h1:jpcvIRr3GLoUoEKRkHKSmGjxb6lWwrBlJsXc+eUYQHM= sigs.k8s.io/apiserver-network-proxy/konnectivity-client v0.31.2/go.mod h1:Ve9uj1L+deCXFrPOk1LpFXqTg7LCFzFso6PA48q/XZw= -sigs.k8s.io/controller-runtime v0.22.4 h1:GEjV7KV3TY8e+tJ2LCTxUTanW4z/FmNB7l327UfMq9A= -sigs.k8s.io/controller-runtime v0.22.4/go.mod h1:+QX1XUpTXN4mLoblf4tqr5CQcyHPAki2HLXqQMY6vh8= +sigs.k8s.io/controller-runtime v0.22.5 h1:v3nfSUMowX/2WMp27J9slwGFyAt7IV0YwBxAkrUr0GE= +sigs.k8s.io/controller-runtime v0.22.5/go.mod h1:pc5SoYWnWI6I+cBHYYdZ7B6YHZVY5xNfll88JB+vniI= 
sigs.k8s.io/gateway-api v1.4.1 h1:NPxFutNkKNa8UfLd2CMlEuhIPMQgDQ6DXNKG9sHbJU8= sigs.k8s.io/gateway-api v1.4.1/go.mod h1:AR5RSqciWP98OPckEjOjh2XJhAe2Na4LHyXD2FUY7Qk= -sigs.k8s.io/gateway-api-inference-extension v1.3.0-rc.2 h1:/UBpLm3Z1HqCTfawkLKwY+PFFGawU55gvjJNJpf6LyM= -sigs.k8s.io/gateway-api-inference-extension v1.3.0-rc.2/go.mod h1:Cyex0AlEzhuXFklzl0y5Hdf5zVY8PUtSKhzMvHh5D9M= +sigs.k8s.io/gateway-api-inference-extension v0.0.0-20260128235548-fd30cb97714a h1:Ce5CZ0R3c5H475uEuJ92FMgux3j99wDrSsI4ivTBEXQ= +sigs.k8s.io/gateway-api-inference-extension v0.0.0-20260128235548-fd30cb97714a/go.mod h1:lvMpB9a+Lk+xBi5Pk6teUG+NqA16WR8nRpmBNFJbflU= sigs.k8s.io/json v0.0.0-20250730193827-2d320260d730 h1:IpInykpT6ceI+QxKBbEflcR5EXP7sU1kvOlxwZh5txg= sigs.k8s.io/json v0.0.0-20250730193827-2d320260d730/go.mod h1:mdzfpAEoE6DHQEN0uh9ZbOCuHbLK5wOm7dK4ctXE9Tg= sigs.k8s.io/kustomize/api v0.21.0 h1:I7nry5p8iDJbuRdYS7ez8MUvw7XVNPcIP5GkzzuXIIQ= diff --git a/pkg/plugins/datalayer/models/datasource_test.go b/pkg/plugins/datalayer/models/datasource_test.go new file mode 100644 index 000000000..1c398ad02 --- /dev/null +++ b/pkg/plugins/datalayer/models/datasource_test.go @@ -0,0 +1,49 @@ +// Package models +package models + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "k8s.io/apimachinery/pkg/types" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer/http" + fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer" +) + +func TestDatasource(t *testing.T) { + source := http.NewHTTPDataSource("https", "/models", true, ModelsDataSourceType, + "models-data-source", parseModels, ModelsResponseType) + extractor, err := NewModelExtractor() + assert.Nil(t, err, "failed to create extractor") + + err = source.AddExtractor(extractor) + assert.Nil(t, err, "failed to add extractor") + + err = source.AddExtractor(extractor) + assert.NotNil(t, err, "expected 
to fail to add the same extractor twice") + + extractors := source.Extractors() + assert.Len(t, extractors, 1) + assert.Equal(t, extractor.TypedName().String(), extractors[0]) + + err = datalayer.RegisterSource(source) + assert.Nil(t, err, "failed to register") + + ctx := context.Background() + factory := datalayer.NewEndpointFactory([]fwkdl.DataSource{source}, 100*time.Hour) + pod := &fwkdl.EndpointMetadata{ + NamespacedName: types.NamespacedName{ + Name: "pod1", + Namespace: "default", + }, + Address: "1.2.3.4:5678", + } + endpoint := factory.NewEndpoint(ctx, pod, nil) + assert.NotNil(t, endpoint, "failed to create endpoint") + + err = source.Collect(ctx, endpoint) + assert.NotNil(t, err, "expected to fail to collect metrics") +} diff --git a/pkg/plugins/datalayer/models/extractor.go b/pkg/plugins/datalayer/models/extractor.go new file mode 100644 index 000000000..317240ab2 --- /dev/null +++ b/pkg/plugins/datalayer/models/extractor.go @@ -0,0 +1,102 @@ +package models + +import ( + "context" + "fmt" + "reflect" + "strings" + + fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer" + fwkplugin "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin" +) + +const modelsAttributeKey = "/v1/models" + +// ModelInfoCollection defines models' data returned from /v1/models API +type ModelInfoCollection []ModelInfo + +// ModelInfo defines model's data returned from /v1/models API +type ModelInfo struct { + ID string `json:"id"` + Parent string `json:"parent,omitempty"` +} + +// String returns a string representation of the model info +func (m *ModelInfo) String() string { + return fmt.Sprintf("%+v", *m) +} + +// Clone returns a full copy of the object +func (m ModelInfoCollection) Clone() fwkdl.Cloneable { + if m == nil { + return nil + } + clone := make([]ModelInfo, len(m)) + copy(clone, m) + return (*ModelInfoCollection)(&clone) +} + +func (m ModelInfoCollection) String() string { + if m == nil { + return "[]" + 
} + parts := make([]string, len(m)) + for i, p := range m { + parts[i] = p.String() + } + return "[" + strings.Join(parts, ", ") + "]" +} + +// ModelResponse is the response from /v1/models API +type ModelResponse struct { + Object string `json:"object"` + Data []ModelInfo `json:"data"` +} + +// ModelsResponseType is the type of models response +var ( + ModelsResponseType = reflect.TypeOf(ModelResponse{}) +) + +// ModelExtractor implements the models extraction. +type ModelExtractor struct { + typedName fwkplugin.TypedName +} + +// NewModelExtractor returns a new model extractor. +func NewModelExtractor() (*ModelExtractor, error) { + return &ModelExtractor{ + typedName: fwkplugin.TypedName{ + Type: ModelsExtractorType, + Name: ModelsExtractorType, + }, + }, nil +} + +// TypedName returns the type and name of the ModelExtractor. +func (me *ModelExtractor) TypedName() fwkplugin.TypedName { + return me.typedName +} + +// WithName sets the name of the extractor. +func (me *ModelExtractor) WithName(name string) *ModelExtractor { + me.typedName.Name = name + return me +} + +// ExpectedInputType defines the type expected by ModelExtractor. +func (me *ModelExtractor) ExpectedInputType() reflect.Type { + return ModelsResponseType +} + +// Extract transforms the data source output into a concrete attribute that +// is stored on the given endpoint. 
+func (me *ModelExtractor) Extract(_ context.Context, data any, ep fwkdl.Endpoint) error { + models, ok := data.(*ModelResponse) + if !ok { + return fmt.Errorf("unexpected input in Extract: %T", data) + } + + ep.GetAttributes().Put(modelsAttributeKey, ModelInfoCollection(models.Data)) + return nil +} diff --git a/pkg/plugins/datalayer/models/extractor_test.go b/pkg/plugins/datalayer/models/extractor_test.go new file mode 100644 index 000000000..4f075f3d0 --- /dev/null +++ b/pkg/plugins/datalayer/models/extractor_test.go @@ -0,0 +1,113 @@ +package models + +import ( + "context" + "testing" + + "github.com/google/go-cmp/cmp" + + fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer" +) + +func TestExtractorExtract(t *testing.T) { + ctx := context.Background() + + extractor, err := NewModelExtractor() + if err != nil { + t.Fatalf("failed to create extractor: %v", err) + } + + if exType := extractor.TypedName().Type; exType == "" { + t.Error("empty extractor type") + } + + if exName := extractor.TypedName().Name; exName == "" { + t.Error("empty extractor name") + } + + if inputType := extractor.ExpectedInputType(); inputType != ModelsResponseType { + t.Errorf("incorrect expected input type: %v", inputType) + } + + ep := fwkdl.NewEndpoint(nil, nil) + if ep == nil { + t.Fatal("expected non-nil endpoint") + } + + model := "food-review" + + tests := []struct { + name string + data any + wantErr bool + updated bool // whether metrics are expected to change + }{ + { + name: "nil data", + data: nil, + wantErr: true, + updated: false, + }, + { + name: "empty ModelsResponse", + data: &ModelResponse{}, + wantErr: false, + updated: false, + }, + { + name: "valid models response", + data: &ModelResponse{ + Object: "list", + Data: []ModelInfo{ + { + ID: model, + }, + { + ID: "lora1", + Parent: model, + }, + }, + }, + wantErr: false, + updated: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + if 
r := recover(); r != nil { + t.Errorf("Extract panicked: %v", r) + } + }() + + attr := ep.GetAttributes() + before, ok := attr.Get(modelsAttributeKey) + if ok && before != nil { + t.Error("expected empty attributes") + } + err := extractor.Extract(ctx, tt.data, ep) + after, ok := attr.Get(modelsAttributeKey) + if !ok && tt.updated { + t.Error("expected updated attributes") + } + + if tt.wantErr && err == nil { + t.Errorf("expected error but got nil") + } + if !tt.wantErr && err != nil { + t.Errorf("unexpected error: %v", err) + } + + if tt.updated { + if diff := cmp.Diff(before, after); diff == "" { + t.Errorf("expected models to be updated, but no change detected") + } + } else { + if diff := cmp.Diff(before, after); diff != "" { + t.Errorf("expected no models update, but got changes:\n%s", diff) + } + } + }) + } +} diff --git a/pkg/plugins/datalayer/models/factories.go b/pkg/plugins/datalayer/models/factories.go new file mode 100644 index 000000000..49cd4edad --- /dev/null +++ b/pkg/plugins/datalayer/models/factories.go @@ -0,0 +1,69 @@ +package models + +import ( + "encoding/json" + "fmt" + "io" + + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer/http" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin" +) + +const ( + // ModelsDataSourceType is models data source type + ModelsDataSourceType = "models-data-source" + // ModelsExtractorType is models extractor type + ModelsExtractorType = "model-server-protocol-models" +) + +// Configuration parameters for models data source. +type modelsDatasourceParams struct { + // Scheme defines the protocol scheme used in models retrieval (e.g., "http"). + Scheme string `json:"scheme"` + // Path defines the URL path used in models retrieval (e.g., "/v1/models"). + Path string `json:"path"` + // InsecureSkipVerify defines whether model server certificate should be verified or not. 
+ InsecureSkipVerify bool `json:"insecureSkipVerify"` +} + +// ModelDataSourceFactory is a factory function used to instantiate data layer's +// models data source plugins specified in a configuration. +func ModelDataSourceFactory(name string, parameters json.RawMessage, _ plugin.Handle) (plugin.Plugin, error) { + cfg := defaultDataSourceConfigParams() + if parameters != nil { // overlay the defaults with configured values + if err := json.Unmarshal(parameters, cfg); err != nil { + return nil, err + } + } + if cfg.Scheme != "http" && cfg.Scheme != "https" { + return nil, fmt.Errorf("unsupported scheme: %s", cfg.Scheme) + } + + ds := http.NewHTTPDataSource(cfg.Scheme, cfg.Path, cfg.InsecureSkipVerify, ModelsDataSourceType, + name, parseModels, ModelsResponseType) + return ds, nil +} + +// ModelServerExtractorFactory is a factory function used to instantiate data layer's models +// Extractor plugins specified in a configuration. +func ModelServerExtractorFactory(name string, _ json.RawMessage, _ plugin.Handle) (plugin.Plugin, error) { + extractor, err := NewModelExtractor() + if err != nil { + return nil, err + } + return extractor.WithName(name), nil +} + +func defaultDataSourceConfigParams() *modelsDatasourceParams { + return &modelsDatasourceParams{Scheme: "http", Path: "/v1/models", InsecureSkipVerify: true} +} + +func parseModels(data io.Reader) (any, error) { + body, err := io.ReadAll(data) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %v", err) + } + var modelsResponse ModelResponse + err = json.Unmarshal(body, &modelsResponse) + return &modelsResponse, err +} diff --git a/pkg/plugins/filter/by_label.go b/pkg/plugins/filter/by_label.go index 070b6a627..464bc81a7 100644 --- a/pkg/plugins/filter/by_label.go +++ b/pkg/plugins/filter/by_label.go @@ -5,9 +5,8 @@ import ( "encoding/json" "fmt" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" - 
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling" ) const ( @@ -21,10 +20,10 @@ type byLabelParameters struct { AllowsNoLabel bool `json:"allowsNoLabel"` } -var _ framework.Filter = &ByLabel{} // validate interface conformance +var _ scheduling.Filter = &ByLabel{} // validate interface conformance // ByLabelFactory defines the factory function for the ByLabel filter. -func ByLabelFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) { +func ByLabelFactory(name string, rawParameters json.RawMessage, _ plugin.Handle) (plugin.Plugin, error) { parameters := byLabelParameters{} if rawParameters != nil { if err := json.Unmarshal(rawParameters, ¶meters); err != nil { @@ -47,7 +46,7 @@ func ByLabelFactory(name string, rawParameters json.RawMessage, _ plugins.Handle // NewByLabel creates and returns an instance of the RoleBasedFilter based on the input parameters // name - the filter name // labelName - the name of the label to use -// allowsNoLabel - if true pods without given label will be considered as valid (not filtered out) +// allowsNoLabel - if true endpoints without given label will be considered as valid (not filtered out) // validValuesApp - list of valid values func NewByLabel(name string, labelName string, allowsNoLabel bool, validValues ...string) *ByLabel { validValuesMap := map[string]struct{}{} @@ -57,27 +56,27 @@ func NewByLabel(name string, labelName string, allowsNoLabel bool, validValues . 
} return &ByLabel{ - typedName: plugins.TypedName{Type: ByLabelType, Name: name}, + typedName: plugin.TypedName{Type: ByLabelType, Name: name}, labelName: labelName, allowsNoLabel: allowsNoLabel, validValues: validValuesMap, } } -// ByLabel - filters out pods based on the values defined by the given label +// ByLabel - filters out endpoints based on the values defined by the given label type ByLabel struct { // name defines the filter typed name - typedName plugins.TypedName + typedName plugin.TypedName // labelName defines the name of the label to be checked labelName string // validValues defines list of valid label values validValues map[string]struct{} - // allowsNoLabel - if true pods without given label will be considered as valid (not filtered out) + // allowsNoLabel - if true endpoints without given label will be considered as valid (not filtered out) allowsNoLabel bool } // TypedName returns the typed name of the plugin -func (f *ByLabel) TypedName() plugins.TypedName { +func (f *ByLabel) TypedName() plugin.TypedName { return f.typedName } @@ -87,19 +86,19 @@ func (f *ByLabel) WithName(name string) *ByLabel { return f } -// Filter filters out all pods that are not marked with one of roles from the validRoles collection +// Filter filters out all endpoints that are not marked with one of roles from the validRoles collection // or has no role label in case allowsNoRolesLabel is true -func (f *ByLabel) Filter(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) []types.Pod { - filteredPods := []types.Pod{} +func (f *ByLabel) Filter(_ context.Context, _ *scheduling.CycleState, _ *scheduling.LLMRequest, endpoints []scheduling.Endpoint) []scheduling.Endpoint { + filteredEndpoints := []scheduling.Endpoint{} - for _, pod := range pods { - val, labelDefined := pod.GetPod().Labels[f.labelName] + for _, endpoint := range endpoints { + val, labelDefined := endpoint.GetMetadata().Labels[f.labelName] _, valueExists := f.validValues[val] if 
(!labelDefined && f.allowsNoLabel) || valueExists { - filteredPods = append(filteredPods, pod) + filteredEndpoints = append(filteredEndpoints, endpoint) } } - return filteredPods + return filteredEndpoints } diff --git a/pkg/plugins/filter/by_label_selector.go b/pkg/plugins/filter/by_label_selector.go index 98b95d418..ceb53a0e3 100644 --- a/pkg/plugins/filter/by_label_selector.go +++ b/pkg/plugins/filter/by_label_selector.go @@ -8,9 +8,8 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/labels" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling" ) const ( @@ -19,10 +18,10 @@ const ( ) // compile-time type assertion -var _ framework.Filter = &ByLabelSelector{} +var _ scheduling.Filter = &ByLabelSelector{} // ByLabelSelectorFactory defines the factory function for the ByLabelSelector filter -func ByLabelSelectorFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) { +func ByLabelSelectorFactory(name string, rawParameters json.RawMessage, _ plugin.Handle) (plugin.Plugin, error) { parameters := metav1.LabelSelector{} if rawParameters != nil { if err := json.Unmarshal(rawParameters, ¶meters); err != nil { @@ -44,30 +43,30 @@ func NewByLabelSelector(name string, selector *metav1.LabelSelector) (*ByLabelSe } return &ByLabelSelector{ - typedName: plugins.TypedName{Type: ByLabelSelectorType, Name: name}, + typedName: plugin.TypedName{Type: ByLabelSelectorType, Name: name}, selector: labelSelector, }, nil } -// ByLabelSelector filters out pods that do not match its label selector criteria +// ByLabelSelector filters out endpoints that do not match its label selector criteria 
type ByLabelSelector struct { - typedName plugins.TypedName + typedName plugin.TypedName selector labels.Selector } // TypedName returns the typed name of the plugin -func (blf *ByLabelSelector) TypedName() plugins.TypedName { +func (blf *ByLabelSelector) TypedName() plugin.TypedName { return blf.typedName } -// Filter filters out all pods that do not satisfy the label selector -func (blf *ByLabelSelector) Filter(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) []types.Pod { - filtered := []types.Pod{} +// Filter filters out all endpoints that do not satisfy the label selector +func (blf *ByLabelSelector) Filter(_ context.Context, _ *scheduling.CycleState, _ *scheduling.LLMRequest, endpoints []scheduling.Endpoint) []scheduling.Endpoint { + filtered := []scheduling.Endpoint{} - for _, pod := range pods { - labels := labels.Set(pod.GetPod().Labels) + for _, endpoint := range endpoints { + labels := labels.Set(endpoint.GetMetadata().Labels) if blf.selector.Matches(labels) { - filtered = append(filtered, pod) + filtered = append(filtered, endpoint) } } return filtered diff --git a/pkg/plugins/filter/by_label_selector_test.go b/pkg/plugins/filter/by_label_selector_test.go index 3dc16e9d0..0d7947f09 100644 --- a/pkg/plugins/filter/by_label_selector_test.go +++ b/pkg/plugins/filter/by_label_selector_test.go @@ -9,9 +9,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" k8stypes "k8s.io/apimachinery/pkg/types" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" - backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling" "github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/filter" 
"github.com/llm-d/llm-d-inference-scheduler/test/utils" @@ -147,35 +146,35 @@ func TestByLabelSelectorFactoryWithInvalidJSON(t *testing.T) { } func TestByLabelSelectorFiltering(t *testing.T) { - pods := []types.Pod{ - createPod(k8stypes.NamespacedName{Namespace: "default", Name: "nginx-1"}, + endpoints := []scheduling.Endpoint{ + createEndpoint(k8stypes.NamespacedName{Namespace: "default", Name: "nginx-1"}, "10.0.0.1", map[string]string{ "app": "nginx", "version": "v1.0", "tier": "frontend", }), - createPod(k8stypes.NamespacedName{Namespace: "default", Name: "nginx-2"}, + createEndpoint(k8stypes.NamespacedName{Namespace: "default", Name: "nginx-2"}, "10.0.0.2", map[string]string{ "app": "nginx", "version": "v1.1", "tier": "frontend", }), - createPod(k8stypes.NamespacedName{Namespace: "kube-system", Name: "coredns-1"}, + createEndpoint(k8stypes.NamespacedName{Namespace: "kube-system", Name: "coredns-1"}, "10.0.0.3", map[string]string{ "app": "coredns", "tier": "system", }), - createPod(k8stypes.NamespacedName{Namespace: "default", Name: "redis-1"}, + createEndpoint(k8stypes.NamespacedName{Namespace: "default", Name: "redis-1"}, "10.0.0.4", map[string]string{ "app": "redis", "tier": "backend", "deprecated": "true", }), - createPod(k8stypes.NamespacedName{Namespace: "default", Name: "web-1"}, + createEndpoint(k8stypes.NamespacedName{Namespace: "default", Name: "web-1"}, "10.0.0.5", map[string]string{ "app": "web", @@ -301,17 +300,17 @@ func TestByLabelSelectorFiltering(t *testing.T) { ctx := utils.NewTestContext(t) - filteredPods := blf.Filter(ctx, nil, nil, pods) + filteredEndpoints := blf.Filter(ctx, nil, nil, endpoints) - var actualPodNames []string - for _, pod := range filteredPods { - actualPodNames = append(actualPodNames, pod.GetPod().NamespacedName.Name) + var actualEndpointNames []string + for _, endpoint := range filteredEndpoints { + actualEndpointNames = append(actualEndpointNames, endpoint.GetMetadata().NamespacedName.Name) } - assert.ElementsMatch(t, 
tt.expectedPods, actualPodNames, - "filtered pods should match expected pods") - assert.Len(t, filteredPods, len(tt.expectedPods), - "filtered pods count should match expected count") + assert.ElementsMatch(t, tt.expectedPods, actualEndpointNames, + "filtered endpoints should match expected endpoints") + assert.Len(t, filteredEndpoints, len(tt.expectedPods), + "filtered endpoints count should match expected count") }) } } @@ -326,26 +325,26 @@ func TestByLabelSelectorFilterEdgeCases(t *testing.T) { ctx := utils.NewTestContext(t) - t.Run("empty pods slice", func(t *testing.T) { - result := blf.Filter(ctx, nil, nil, []types.Pod{}) + t.Run("empty endpoints slice", func(t *testing.T) { + result := blf.Filter(ctx, nil, nil, []scheduling.Endpoint{}) assert.Empty(t, result) }) - t.Run("nil pods slice", func(t *testing.T) { + t.Run("nil endpoints slice", func(t *testing.T) { result := blf.Filter(ctx, nil, nil, nil) assert.Empty(t, result) }) - t.Run("pods with nil labels", func(t *testing.T) { - pods := []types.Pod{createPod(k8stypes.NamespacedName{Name: "pod-1"}, "10.0.0.1", nil)} - result := blf.Filter(ctx, nil, nil, pods) - assert.Empty(t, result, "pod with nil labels should not match") + t.Run("endpoints with nil labels", func(t *testing.T) { + endpoints := []scheduling.Endpoint{createEndpoint(k8stypes.NamespacedName{Name: "pod-1"}, "10.0.0.1", nil)} + result := blf.Filter(ctx, nil, nil, endpoints) + assert.Empty(t, result, "endpoint with nil labels should not match") }) - t.Run("pods with empty labels", func(t *testing.T) { - pods := []types.Pod{createPod(k8stypes.NamespacedName{Name: "pod-1"}, "10.0.0.1", map[string]string{})} - result := blf.Filter(ctx, nil, nil, pods) - assert.Empty(t, result, "pod with empty labels should not match") + t.Run("endpoints with empty labels", func(t *testing.T) { + endpoints := []scheduling.Endpoint{createEndpoint(k8stypes.NamespacedName{Name: "pod-1"}, "10.0.0.1", map[string]string{})} + result := blf.Filter(ctx, nil, nil, endpoints) 
+ assert.Empty(t, result, "endpoint with empty labels should not match") }) } @@ -371,7 +370,7 @@ func ExamplePrefillDecodeRolesInLWS() { plugin, _ = filter.ByLabelSelectorFactory("prefill-role", prefillWorkerJSON, nil) prefillworker, _ := plugin.(*filter.ByLabelSelector) - pods := []types.Pod{createPod(k8stypes.NamespacedName{Namespace: "default", Name: "vllm"}, + endpoints := []scheduling.Endpoint{createEndpoint(k8stypes.NamespacedName{Namespace: "default", Name: "vllm"}, "10.0.0.1", map[string]string{ "app.kubernetes.io/component": "vllm-worker", @@ -383,7 +382,7 @@ func ExamplePrefillDecodeRolesInLWS() { name := "" for _, blf := range []*filter.ByLabelSelector{decodeLeader, decodeFollower, prefillworker} { - filtered := PrefillDecodeRolesInLWS(blf, pods) + filtered := PrefillDecodeRolesInLWS(blf, endpoints) if len(filtered) > 0 { name = blf.TypedName().Name } @@ -395,17 +394,18 @@ func ExamplePrefillDecodeRolesInLWS() { } // Helper functions -func createPod(nsn k8stypes.NamespacedName, ipaddr string, labels map[string]string) types.Pod { - return &types.PodMetrics{ - Pod: &backend.Pod{ +func createEndpoint(nsn k8stypes.NamespacedName, ipaddr string, labels map[string]string) scheduling.Endpoint { + return scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{ NamespacedName: nsn, Address: ipaddr, Labels: labels, }, - MetricsState: &backendmetrics.MetricsState{}, - } + &fwkdl.Metrics{}, + nil, + ) } -func PrefillDecodeRolesInLWS(blf *filter.ByLabelSelector, pods []types.Pod) []types.Pod { - return blf.Filter(context.Background(), nil, nil, pods) +func PrefillDecodeRolesInLWS(blf *filter.ByLabelSelector, endpoints []scheduling.Endpoint) []scheduling.Endpoint { + return blf.Filter(context.Background(), nil, nil, endpoints) } diff --git a/pkg/plugins/filter/by_label_test.go b/pkg/plugins/filter/by_label_test.go index a3af75a4b..933871b61 100644 --- a/pkg/plugins/filter/by_label_test.go +++ b/pkg/plugins/filter/by_label_test.go @@ -8,9 +8,8 @@ import ( 
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" k8stypes "k8s.io/apimachinery/pkg/types" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" - backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling" "github.com/llm-d/llm-d-inference-scheduler/test/utils" ) @@ -139,54 +138,55 @@ func TestByLabelFactoryInvalidJSON(t *testing.T) { } // Helper functions -func createPod(nsn k8stypes.NamespacedName, ipaddr string, labels map[string]string) types.Pod { - return &types.PodMetrics{ - Pod: &backend.Pod{ +func createEndpoint(nsn k8stypes.NamespacedName, ipaddr string, labels map[string]string) scheduling.Endpoint { + return scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{ NamespacedName: nsn, Address: ipaddr, Labels: labels, }, - MetricsState: &backendmetrics.MetricsState{}, - } + &fwkdl.Metrics{}, + nil, + ) } func TestByLabelFiltering(t *testing.T) { - pods := []types.Pod{ - createPod(k8stypes.NamespacedName{Namespace: "default", Name: "nginx-1"}, + endpoints := []scheduling.Endpoint{ + createEndpoint(k8stypes.NamespacedName{Namespace: "default", Name: "nginx-1"}, "10.0.0.1", map[string]string{ "app": "nginx", "version": "v1.0", "tier": "frontend", }), - createPod(k8stypes.NamespacedName{Namespace: "default", Name: "nginx-2"}, + createEndpoint(k8stypes.NamespacedName{Namespace: "default", Name: "nginx-2"}, "10.0.0.2", map[string]string{ "app": "nginx", "version": "v1.1", "tier": "frontend", }), - createPod(k8stypes.NamespacedName{Namespace: "kube-system", Name: "coredns-1"}, + createEndpoint(k8stypes.NamespacedName{Namespace: "kube-system", Name: "coredns-1"}, "10.0.0.3", map[string]string{ "app": "coredns", "tier": "system", }), - 
createPod(k8stypes.NamespacedName{Namespace: "default", Name: "redis-1"}, + createEndpoint(k8stypes.NamespacedName{Namespace: "default", Name: "redis-1"}, "10.0.0.4", map[string]string{ "app": "redis", "tier": "backend", "deprecated": "true", }), - createPod(k8stypes.NamespacedName{Namespace: "default", Name: "web-1"}, + createEndpoint(k8stypes.NamespacedName{Namespace: "default", Name: "web-1"}, "10.0.0.5", map[string]string{ "app": "web", "tier": "frontend", "environment": "production", }), - createPod(k8stypes.NamespacedName{Namespace: "default", Name: "no-tier-pod"}, + createEndpoint(k8stypes.NamespacedName{Namespace: "default", Name: "no-tier-pod"}, "10.0.0.6", map[string]string{ "app": "unknown", @@ -247,17 +247,17 @@ func TestByLabelFiltering(t *testing.T) { ctx := utils.NewTestContext(t) - filteredPods := blf.Filter(ctx, nil, nil, pods) + filteredEndpoints := blf.Filter(ctx, nil, nil, endpoints) - var actualPodNames []string - for _, pod := range filteredPods { - actualPodNames = append(actualPodNames, pod.GetPod().NamespacedName.Name) + var actualEndpointNames []string + for _, endpoint := range filteredEndpoints { + actualEndpointNames = append(actualEndpointNames, endpoint.GetMetadata().NamespacedName.Name) } - assert.ElementsMatch(t, tt.expectedPods, actualPodNames, - "filtered pods should match expected pods") - assert.Len(t, filteredPods, len(tt.expectedPods), - "filtered pods count should match expected count") + assert.ElementsMatch(t, tt.expectedPods, actualEndpointNames, + "filtered endpoints should match expected endpoints") + assert.Len(t, filteredEndpoints, len(tt.expectedPods), + "filtered endpoints count should match expected count") }) } } diff --git a/pkg/plugins/filter/pd_role.go b/pkg/plugins/filter/pd_role.go index cc4cf7448..da96e7893 100644 --- a/pkg/plugins/filter/pd_role.go +++ b/pkg/plugins/filter/pd_role.go @@ -3,7 +3,7 @@ package filter import ( "encoding/json" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" + 
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin" ) const ( @@ -23,7 +23,7 @@ const ( ) // PrefillRoleFactory defines the factory function for the Prefill filter. -func PrefillRoleFactory(name string, _ json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) { +func PrefillRoleFactory(name string, _ json.RawMessage, _ plugin.Handle) (plugin.Plugin, error) { return NewPrefillRole().WithName(name), nil } @@ -33,7 +33,7 @@ func NewPrefillRole() *ByLabel { } // DecodeRoleFactory defines the factory function for the Decode filter. -func DecodeRoleFactory(name string, _ json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) { +func DecodeRoleFactory(name string, _ json.RawMessage, _ plugin.Handle) (plugin.Plugin, error) { return NewDecodeRole().WithName(name), nil } diff --git a/pkg/plugins/pre-request/pd_prerequest.go b/pkg/plugins/pre-request/pd_prerequest.go index beebbe46c..c77fc700f 100644 --- a/pkg/plugins/pre-request/pd_prerequest.go +++ b/pkg/plugins/pre-request/pd_prerequest.go @@ -7,9 +7,9 @@ import ( "fmt" "net" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/requestcontrol" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling" "github.com/llm-d/llm-d-inference-scheduler/pkg/common" ) @@ -29,7 +29,7 @@ type prefillHeaderHandlerParameters struct { var _ requestcontrol.PreRequest = &PrefillHeaderHandler{} // PrefillHeaderHandlerFactory defines the factory function for the PrefillHeaderHandler -func PrefillHeaderHandlerFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) { +func PrefillHeaderHandlerFactory(name string, rawParameters 
json.RawMessage, _ plugin.Handle) (plugin.Plugin, error) { parameters := prefillHeaderHandlerParameters{ PrefillProfile: defaultPrefillProfile, } @@ -44,19 +44,19 @@ func PrefillHeaderHandlerFactory(name string, rawParameters json.RawMessage, _ p // NewPrefillHeaderHandler initializes a new PrefillHeaderHandler and returns its pointer. func NewPrefillHeaderHandler(prefillProfile string) *PrefillHeaderHandler { return &PrefillHeaderHandler{ - typedName: plugins.TypedName{Type: PrefillHeaderHandlerType}, + typedName: plugin.TypedName{Type: PrefillHeaderHandlerType}, prefillProfile: prefillProfile, } } // PrefillHeaderHandler PreRequest plugin type PrefillHeaderHandler struct { - typedName plugins.TypedName + typedName plugin.TypedName prefillProfile string } // TypedName returns the typed name of the plugin. -func (p *PrefillHeaderHandler) TypedName() plugins.TypedName { +func (p *PrefillHeaderHandler) TypedName() plugin.TypedName { return p.typedName } @@ -67,7 +67,7 @@ func (p *PrefillHeaderHandler) WithName(name string) *PrefillHeaderHandler { } // PreRequest wires prefill SchedulerProfile result into a header to indicate prefill worker -func (p *PrefillHeaderHandler) PreRequest(_ context.Context, request *types.LLMRequest, schedulingResult *types.SchedulingResult) { +func (p *PrefillHeaderHandler) PreRequest(_ context.Context, request *scheduling.LLMRequest, schedulingResult *scheduling.SchedulingResult) { if _, found := request.Headers[common.PrefillPodHeader]; found { request.Headers[common.PrefillPodHeader] = "" // clear header, if already set } @@ -77,7 +77,7 @@ func (p *PrefillHeaderHandler) PreRequest(_ context.Context, request *types.LLMR return // prefill profile failed to run or we chose not to run it, no-op in this case } - targetPod := prefillProfileRunResult.TargetPods[0].GetPod() + targetPod := prefillProfileRunResult.TargetEndpoints[0].GetMetadata() prefillHostPort := net.JoinHostPort(targetPod.Address, targetPod.Port) 
request.Headers[common.PrefillPodHeader] = prefillHostPort // in the form of } diff --git a/pkg/plugins/profile/always_disagg_decider.go b/pkg/plugins/profile/always_disagg_decider.go new file mode 100644 index 000000000..1fefd47fe --- /dev/null +++ b/pkg/plugins/profile/always_disagg_decider.go @@ -0,0 +1,48 @@ +package profile + +import ( + "context" + "encoding/json" + + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling" +) + +const ( + // AlwaysDisaggDeciderPluginType is the type-name of the alwaysDisaggPDDecider plugin. + AlwaysDisaggDeciderPluginType = "always-disagg-pd-decider" +) + +// compile-time type assertion +var _ pdDeciderPlugin = &AlwaysDisaggPDDecider{} + +// AlwaysDisaggPDDecider is a PD decider plugin which always decide to disaggregate PD +type AlwaysDisaggPDDecider struct { + typedName plugin.TypedName +} + +// AlwaysDisaggPDDeciderPluginFactory defines the factory function for creating +// a new instance of the AlwaysDisaggPDDecider. +func AlwaysDisaggPDDeciderPluginFactory(name string, _ json.RawMessage, + _ plugin.Handle) (plugin.Plugin, error) { + return newAlwaysDisaggPDDecider().WithName(name), nil +} + +func newAlwaysDisaggPDDecider() *AlwaysDisaggPDDecider { + return &AlwaysDisaggPDDecider{} +} + +// TypedName returns the typed name of the plugin. +func (d *AlwaysDisaggPDDecider) TypedName() plugin.TypedName { + return d.typedName +} + +// WithName sets the name of the plugin. 
+func (d *AlwaysDisaggPDDecider) WithName(name string) *AlwaysDisaggPDDecider { + d.typedName.Name = name + return d +} + +func (d *AlwaysDisaggPDDecider) disaggregate(ctx context.Context, inputTokens int, endpoint scheduling.Endpoint) bool { + return true +} diff --git a/pkg/plugins/profile/dp_profile_handler.go b/pkg/plugins/profile/dp_profile_handler.go index 5d71384ca..836872584 100644 --- a/pkg/plugins/profile/dp_profile_handler.go +++ b/pkg/plugins/profile/dp_profile_handler.go @@ -8,9 +8,9 @@ import ( "net" "strconv" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + "sigs.k8s.io/controller-runtime/pkg/log" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling" "github.com/llm-d/llm-d-inference-scheduler/pkg/common" ) @@ -25,10 +25,10 @@ type dataParallelProfileHandlerParameters struct { } // compile-time type assertion -var _ framework.ProfileHandler = &DataParallelProfileHandler{} +var _ scheduling.ProfileHandler = &DataParallelProfileHandler{} // DataParallelProfileHandlerFactory defines the factory function for the DataParallelProfileHandler -func DataParallelProfileHandlerFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) { +func DataParallelProfileHandlerFactory(name string, rawParameters json.RawMessage, _ plugin.Handle) (plugin.Plugin, error) { parameters := dataParallelProfileHandlerParameters{ PrimaryPort: 8000, } @@ -50,19 +50,19 @@ func DataParallelProfileHandlerFactory(name string, rawParameters json.RawMessag // NewDataParallelProfileHandler initializes a new PdProfileHandler and returns its pointer. 
func NewDataParallelProfileHandler(primaryPort int) *DataParallelProfileHandler { return &DataParallelProfileHandler{ - typedName: plugins.TypedName{Type: DataParallelProfileHandlerType}, + typedName: plugin.TypedName{Type: DataParallelProfileHandlerType}, primaryPort: strconv.Itoa(primaryPort), } } // DataParallelProfileHandler handles scheduler profiles for Data Parallel. type DataParallelProfileHandler struct { - typedName plugins.TypedName + typedName plugin.TypedName primaryPort string } // TypedName returns the typed name of the plugin. -func (h *DataParallelProfileHandler) TypedName() plugins.TypedName { +func (h *DataParallelProfileHandler) TypedName() plugin.TypedName { return h.typedName } @@ -74,12 +74,19 @@ func (h *DataParallelProfileHandler) WithName(name string) *DataParallelProfileH // Pick selects the SchedulingProfiles to run from the list of candidate profiles, while taking into consideration the request properties and the // previously executed cycles along with their results. 
-func (h *DataParallelProfileHandler) Pick(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, profiles map[string]*framework.SchedulerProfile, - profileResults map[string]*types.ProfileRunResult) map[string]*framework.SchedulerProfile { +func (h *DataParallelProfileHandler) Pick(ctx context.Context, _ *scheduling.CycleState, _ *scheduling.LLMRequest, profiles map[string]scheduling.SchedulerProfile, + profileResults map[string]*scheduling.ProfileRunResult) map[string]scheduling.SchedulerProfile { if len(profiles) == len(profileResults) { // all profiles have been executed already in previous call - return map[string]*framework.SchedulerProfile{} + return map[string]scheduling.SchedulerProfile{} } - // return all profiles + // Validate that only one profile is configured for Data Parallel mode + if len(profiles) != 1 { + log.FromContext(ctx).Error(nil, "Data Parallel profile handler requires exactly one scheduling profile", + "profileCount", len(profiles), + ) + return map[string]scheduling.SchedulerProfile{} // return empty map for fast exit in later steps + } + // return only one profile return profiles } @@ -87,8 +94,8 @@ func (h *DataParallelProfileHandler) Pick(_ context.Context, _ *types.CycleState // It may aggregate results, log test profile outputs, or apply custom logic. It specifies in the SchedulingResult the // key of the primary profile that should be used to get the request selected destination. // When a profile run fails, its result in the profileResults map is nil. 
-func (h *DataParallelProfileHandler) ProcessResults(_ context.Context, _ *types.CycleState, request *types.LLMRequest, - profileResults map[string]*types.ProfileRunResult) (*types.SchedulingResult, error) { +func (h *DataParallelProfileHandler) ProcessResults(_ context.Context, _ *scheduling.CycleState, request *scheduling.LLMRequest, + profileResults map[string]*scheduling.ProfileRunResult) (*scheduling.SchedulingResult, error) { if len(profileResults) != 1 { return nil, errors.New("data parallel profile handler is intended to be used with a single profile, failed to process multiple profiles") } @@ -104,23 +111,23 @@ func (h *DataParallelProfileHandler) ProcessResults(_ context.Context, _ *types. return nil, fmt.Errorf("failed to run scheduler profile '%s'", singleProfileName) } - newResult := types.ProfileRunResult{ - TargetPods: []types.Pod{}, + newResult := scheduling.ProfileRunResult{ + TargetEndpoints: []scheduling.Endpoint{}, } - targetPod := profileResult.TargetPods[0].GetPod() + targetPod := profileResult.TargetEndpoints[0].GetMetadata() request.Headers[common.DataParallelPodHeader] = net.JoinHostPort(targetPod.Address, targetPod.Port) - for _, target := range profileResult.TargetPods { - newPodInfo := target.GetPod().Clone() - newPodInfo.Port = h.primaryPort - targetPod := &types.PodMetrics{Pod: newPodInfo, MetricsState: target.GetMetrics().Clone()} - newResult.TargetPods = append(newResult.TargetPods, targetPod) + for _, target := range profileResult.TargetEndpoints { + newMetadata := target.GetMetadata().Clone() + newMetadata.Port = h.primaryPort + targetEndpoint := scheduling.NewEndpoint(newMetadata, target.GetMetrics().Clone(), nil) + newResult.TargetEndpoints = append(newResult.TargetEndpoints, targetEndpoint) } - modifiedResults := map[string]*types.ProfileRunResult{singleProfileName: &newResult} + modifiedResults := map[string]*scheduling.ProfileRunResult{singleProfileName: &newResult} - return &types.SchedulingResult{ + return 
&scheduling.SchedulingResult{ ProfileResults: modifiedResults, PrimaryProfileName: singleProfileName, }, nil diff --git a/pkg/plugins/profile/dp_profile_handler_test.go b/pkg/plugins/profile/dp_profile_handler_test.go index 30a7a6da4..9bb942327 100644 --- a/pkg/plugins/profile/dp_profile_handler_test.go +++ b/pkg/plugins/profile/dp_profile_handler_test.go @@ -8,7 +8,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling" "github.com/llm-d/llm-d-inference-scheduler/pkg/common" ) @@ -120,17 +120,86 @@ func TestDataParallelProfileHandlerFactoryInvalidJSON(t *testing.T) { } } +func Test_DataParallelProfileHandler_Pick(t *testing.T) { + tests := []struct { + name string + profiles map[string]scheduling.SchedulerProfile + profileResults map[string]*scheduling.ProfileRunResult + expectEmptyResult bool + expectLogError bool + description string + }{ + { + name: "success: single profile, first call", + profiles: map[string]scheduling.SchedulerProfile{ + "default": newMockSchedulerProfile(), + }, + profileResults: map[string]*scheduling.ProfileRunResult{}, + expectEmptyResult: false, + expectLogError: false, + description: "Should return the single profile to run", + }, + { + name: "success: single profile, second call (all already executed)", + profiles: map[string]scheduling.SchedulerProfile{ + "default": newMockSchedulerProfile(), + }, + profileResults: map[string]*scheduling.ProfileRunResult{ + "default": newMockProfileRunResult(DefaultTestPodPort, "pod1"), + }, + expectEmptyResult: true, + expectLogError: false, + description: "Should return empty map since all profiles have been executed already in previous call", + }, + { + name: "error: multiple profiles configured in EPP", + profiles: map[string]scheduling.SchedulerProfile{ + "profile1": newMockSchedulerProfile(), + 
"profile2": newMockSchedulerProfile(), + }, + profileResults: map[string]*scheduling.ProfileRunResult{}, + expectEmptyResult: true, + expectLogError: true, + description: "Should return empty map and log error for multiple profiles", + }, + { + name: "error: zero profiles configured in EPP", + profiles: map[string]scheduling.SchedulerProfile{}, + profileResults: map[string]*scheduling.ProfileRunResult{}, + expectEmptyResult: true, + expectLogError: true, + description: "Should return empty map and log error for zero profiles", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + handler := NewDataParallelProfileHandler(8000).WithName("test-handler") + ctx := context.Background() + + result := handler.Pick(ctx, &scheduling.CycleState{}, &scheduling.LLMRequest{}, tt.profiles, tt.profileResults) + + if tt.expectEmptyResult { + assert.Empty(t, result, tt.description) + } else { + assert.NotEmpty(t, result, tt.description) + assert.Equal(t, len(tt.profiles), len(result), "Should return all profiles when valid") + } + }) + } +} + func Test_DataParallelProfileHandler_ProcessResults(t *testing.T) { tests := []struct { name string primaryPort int - profileResults map[string]*types.ProfileRunResult + profileResults map[string]*scheduling.ProfileRunResult expectError bool - checkResult func(*testing.T, *types.SchedulingResult, map[string]string) + checkResult func(*testing.T, *scheduling.SchedulingResult, map[string]string) }{ { name: "error: multiple profiles not supported", - profileResults: map[string]*types.ProfileRunResult{ + profileResults: map[string]*scheduling.ProfileRunResult{ "profile1": newMockProfileRunResult(DefaultTestPodPort, "pod1"), "profile2": newMockProfileRunResult(DefaultTestPodPort, "pod2"), }, @@ -138,7 +207,7 @@ func Test_DataParallelProfileHandler_ProcessResults(t *testing.T) { }, { name: "error: single profile but result is nil", - profileResults: map[string]*types.ProfileRunResult{ + profileResults: 
map[string]*scheduling.ProfileRunResult{ "nil-profile": nil, }, expectError: true, @@ -146,16 +215,16 @@ func Test_DataParallelProfileHandler_ProcessResults(t *testing.T) { { name: "success: single profile with primaryPort β†’ port overridden, header set", primaryPort: 9000, - profileResults: map[string]*types.ProfileRunResult{ + profileResults: map[string]*scheduling.ProfileRunResult{ "dp-profile": newMockProfileRunResult(DefaultTestPodPort, "pod1"), }, expectError: false, - checkResult: func(t *testing.T, res *types.SchedulingResult, headers map[string]string) { + checkResult: func(t *testing.T, res *scheduling.SchedulingResult, headers map[string]string) { assert.Equal(t, "dp-profile", res.PrimaryProfileName) - pods := res.ProfileResults["dp-profile"].TargetPods + pods := res.ProfileResults["dp-profile"].TargetEndpoints require.Len(t, pods, 1) - assert.Equal(t, "9000", pods[0].GetPod().Port) // overridden + assert.Equal(t, "9000", pods[0].GetMetadata().Port) // overridden expectedHeader := net.JoinHostPort("10.0.0.1", DefaultTestPodPort) // original assert.Equal(t, expectedHeader, headers[common.DataParallelPodHeader]) }, @@ -163,28 +232,28 @@ func Test_DataParallelProfileHandler_ProcessResults(t *testing.T) { { name: "success: primaryPort=0 β†’ port becomes '0'", primaryPort: 0, - profileResults: map[string]*types.ProfileRunResult{ + profileResults: map[string]*scheduling.ProfileRunResult{ "dp": newMockProfileRunResult("8080", "pod1"), }, expectError: false, - checkResult: func(t *testing.T, res *types.SchedulingResult, headers map[string]string) { - pod := res.ProfileResults["dp"].TargetPods[0] - assert.Equal(t, "0", pod.GetPod().Port) + checkResult: func(t *testing.T, res *scheduling.SchedulingResult, headers map[string]string) { + pod := res.ProfileResults["dp"].TargetEndpoints[0] + assert.Equal(t, "0", pod.GetMetadata().Port) assert.Equal(t, "10.0.0.1:8080", headers[common.DataParallelPodHeader]) }, }, { name: "success: multiple target pods β†’ all ports 
overridden", primaryPort: 8080, - profileResults: map[string]*types.ProfileRunResult{ + profileResults: map[string]*scheduling.ProfileRunResult{ "dp-profile": newMockProfileRunResult(DefaultTestPodPort, "pod1", "pod2"), }, expectError: false, - checkResult: func(t *testing.T, res *types.SchedulingResult, headers map[string]string) { - pods := res.ProfileResults["dp-profile"].TargetPods + checkResult: func(t *testing.T, res *scheduling.SchedulingResult, headers map[string]string) { + pods := res.ProfileResults["dp-profile"].TargetEndpoints assert.Len(t, pods, 2) for _, p := range pods { - assert.Equal(t, "8080", p.GetPod().Port) + assert.Equal(t, "8080", p.GetMetadata().Port) } assert.Equal(t, net.JoinHostPort("10.0.0.1", DefaultTestPodPort), headers[common.DataParallelPodHeader]) }, @@ -195,8 +264,8 @@ func Test_DataParallelProfileHandler_ProcessResults(t *testing.T) { t.Run(tt.name, func(t *testing.T) { handler := NewDataParallelProfileHandler(tt.primaryPort).WithName("test-handler") headers := make(map[string]string) - req := &types.LLMRequest{Headers: headers} - result, err := handler.ProcessResults(context.Background(), &types.CycleState{}, req, tt.profileResults) + req := &scheduling.LLMRequest{Headers: headers} + result, err := handler.ProcessResults(context.Background(), &scheduling.CycleState{}, req, tt.profileResults) if tt.expectError { assert.Error(t, err) diff --git a/pkg/plugins/profile/pd_profile_handler.go b/pkg/plugins/profile/pd_profile_handler.go index 8dff33e43..3f5a4b932 100644 --- a/pkg/plugins/profile/pd_profile_handler.go +++ b/pkg/plugins/profile/pd_profile_handler.go @@ -9,49 +9,56 @@ import ( "net" "strconv" - "github.com/llm-d/llm-d-inference-scheduler/pkg/metrics" - "sigs.k8s.io/controller-runtime/pkg/log" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" - 
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/multi/prefix" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" - logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/common/util/logging" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/plugins/scheduling/scorer/prefix" "github.com/llm-d/llm-d-inference-scheduler/pkg/common" + "github.com/llm-d/llm-d-inference-scheduler/pkg/metrics" ) const ( // PdProfileHandlerType is the type of the PdProfileHandler PdProfileHandlerType = "pd-profile-handler" - defaultDecodeProfile = "decode" - defaultPrefillProfile = "prefill" - defaultPrefixPluginType = prefix.PrefixCachePluginType + defaultDecodeProfile = "decode" + defaultPrefillProfile = "prefill" + defaultPrefixPluginType = prefix.PrefixCachePluginType + defaultDeciderPluginName = AlwaysDisaggDeciderPluginType + + // AverageCharactersPerToken is an estimated average characters per token, used since the request we cached is not tokenized. + AverageCharactersPerToken = 4 ) +// pdDeciderPlugin interface for pd decider plugins +type pdDeciderPlugin interface { + plugin.Plugin + // disaggregate checks if disaggregated PD is required for the given request and endpoint. 
+ disaggregate(ctx context.Context, inputTokens int, endpoint scheduling.Endpoint) bool +} + type pdProfileHandlerParameters struct { - Threshold int `json:"threshold"` - DecodeProfile string `json:"decodeProfile"` - PrefillProfile string `json:"prefillProfile"` - PrefixPluginType string `json:"prefixPluginType"` - PrefixPluginName string `json:"prefixPluginName"` - HashBlockSize int `json:"hashBlockSize"` - PrimaryPort int `json:"primaryPort"` + DecodeProfile string `json:"decodeProfile"` + PrefillProfile string `json:"prefillProfile"` + PrefixPluginType string `json:"prefixPluginType"` + PrefixPluginName string `json:"prefixPluginName"` + PrimaryPort int `json:"primaryPort"` + DeciderPluginName string `json:"deciderPluginName"` } // compile-time type assertion -var _ framework.ProfileHandler = &PdProfileHandler{} +var _ scheduling.ProfileHandler = &PdProfileHandler{} // PdProfileHandlerFactory defines the factory function for the PdProfileHandler -func PdProfileHandlerFactory(name string, rawParameters json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) { +func PdProfileHandlerFactory(name string, rawParameters json.RawMessage, handle plugin.Handle) (plugin.Plugin, error) { parameters := pdProfileHandlerParameters{ - Threshold: 0, - DecodeProfile: defaultDecodeProfile, - PrefillProfile: defaultPrefillProfile, - PrefixPluginType: defaultPrefixPluginType, - HashBlockSize: prefix.DefaultBlockSize, - PrimaryPort: 0, + DecodeProfile: defaultDecodeProfile, + PrefillProfile: defaultPrefillProfile, + PrefixPluginType: defaultPrefixPluginType, + PrimaryPort: 0, + DeciderPluginName: defaultDeciderPluginName, } if rawParameters != nil { if err := json.Unmarshal(rawParameters, ¶meters); err != nil { @@ -63,54 +70,66 @@ func PdProfileHandlerFactory(name string, rawParameters json.RawMessage, _ plugi parameters.PrefixPluginName = parameters.PrefixPluginType } - if parameters.Threshold < 0 { - return nil, fmt.Errorf("invalid threshold: must be >= 0, got %d", 
parameters.Threshold) - } - - if parameters.HashBlockSize <= 0 { - return nil, fmt.Errorf("invalid hashBlockSize: must be > 0, got %d", parameters.HashBlockSize) - } - if parameters.PrimaryPort != 0 { if parameters.PrimaryPort < 1 || parameters.PrimaryPort > 65535 { return nil, fmt.Errorf("invalid primaryPort: must be between 1 and 65535, got %d", parameters.PrimaryPort) } } - return NewPdProfileHandler(parameters.PrefillProfile, parameters.DecodeProfile, parameters.PrefixPluginType, parameters.PrefixPluginName, - parameters.Threshold, parameters.HashBlockSize, parameters.PrimaryPort).WithName(name), nil + if parameters.DeciderPluginName == "" { + return nil, errors.New("decider plugin name is not defined") + } + + plugin := handle.Plugin(parameters.DeciderPluginName) + if plugin == nil { + return nil, fmt.Errorf("invalid decider plugin type: %s", parameters.DeciderPluginName) + } + + deciderPlugin, ok := plugin.(pdDeciderPlugin) + if !ok { + return nil, fmt.Errorf("decider plugin of type: %s does not implement pdDeciderPlugin", parameters.DeciderPluginName) + } + + handler, err := NewPdProfileHandler(parameters.PrefillProfile, parameters.DecodeProfile, parameters.PrefixPluginType, parameters.PrefixPluginName, + parameters.PrimaryPort, deciderPlugin) + + if err != nil { + return nil, err + } + + return handler.WithName(name), nil + } // NewPdProfileHandler initializes a new PdProfileHandler and returns its pointer. 
-func NewPdProfileHandler(prefillProfile, decodeProfile, prefixPluginType, prefixPluginName string, pdThreshold, hashBlockSize, primaryPort int) *PdProfileHandler { +func NewPdProfileHandler(prefillProfile, decodeProfile, prefixPluginType, prefixPluginName string, + primaryPort int, deciderPlugin pdDeciderPlugin) (*PdProfileHandler, error) { result := &PdProfileHandler{ - typedName: plugins.TypedName{Type: PdProfileHandlerType}, - prefixPluginTypedName: plugins.TypedName{Type: prefixPluginType, Name: prefixPluginName}, + typedName: plugin.TypedName{Type: PdProfileHandlerType}, + prefixPluginTypedName: plugin.TypedName{Type: prefixPluginType, Name: prefixPluginName}, decodeProfile: decodeProfile, prefillProfile: prefillProfile, - pdThreshold: pdThreshold, - hashBlockSize: hashBlockSize, + decider: deciderPlugin, } if primaryPort != 0 { result.primaryPort = strconv.Itoa(primaryPort) } - return result + return result, nil } // PdProfileHandler handles scheduler profiles for PD. type PdProfileHandler struct { - typedName plugins.TypedName - prefixPluginTypedName plugins.TypedName + typedName plugin.TypedName + prefixPluginTypedName plugin.TypedName decodeProfile string prefillProfile string - pdThreshold int - hashBlockSize int primaryPort string + decider pdDeciderPlugin } // TypedName returns the typed name of the plugin. -func (h *PdProfileHandler) TypedName() plugins.TypedName { +func (h *PdProfileHandler) TypedName() plugin.TypedName { return h.typedName } @@ -122,11 +141,11 @@ func (h *PdProfileHandler) WithName(name string) *PdProfileHandler { // Pick selects the SchedulingProfiles to run from the list of candidate profiles, while taking into consideration the request properties and the // previously executed cycles along with their results. 
-func (h *PdProfileHandler) Pick(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, profiles map[string]*framework.SchedulerProfile, - profileResults map[string]*types.ProfileRunResult) map[string]*framework.SchedulerProfile { +func (h *PdProfileHandler) Pick(ctx context.Context, _ *scheduling.CycleState, request *scheduling.LLMRequest, profiles map[string]scheduling.SchedulerProfile, + profileResults map[string]*scheduling.ProfileRunResult) map[string]scheduling.SchedulerProfile { if _, executed := profileResults[h.decodeProfile]; !executed { // if decode profile was not executed yet, first let the scheduler run the decode profile - return map[string]*framework.SchedulerProfile{ + return map[string]scheduling.SchedulerProfile{ h.decodeProfile: profiles[h.decodeProfile], } } @@ -135,75 +154,56 @@ func (h *PdProfileHandler) Pick(ctx context.Context, cycleState *types.CycleStat // when a profile run fails its result value is nil. we need to check decode result before continuing to prefill // check if all configured profiles have been executed, or if decode failed, no need to run more profiles. if len(profiles) == len(profileResults) || profileResults[h.decodeProfile] == nil { - return map[string]*framework.SchedulerProfile{} + return map[string]scheduling.SchedulerProfile{} } - if h.pdThreshold > 0 { - userInput, err := getUserInputBytes(request) - if err != nil { - log.FromContext(ctx).V(logutil.DEBUG).Error(err, "Failed to get user input bytes") - return nil - } - - // if we're here that means decode profile ran successfully, and we have additional profile configured that didn't run yet, - // which means PD is enabled (otherwise, prefill profile is not configured at all and this profile handler is not used). - // inspect decode execution result to decide if prefill should run or not. - // if the request is short enough, use decode results only and don't run the prefill profile. 
- hitPercentagePrefix := 0.0 // default to 0, meaning no prefix cache hit - prefixState, err := types.ReadCycleStateKey[*prefix.SchedulingContextState](cycleState, plugins.StateKey(h.prefixPluginTypedName.String())) - if err != nil { - log.FromContext(ctx).Error(err, "unable to read prefix state") - } else { - decodePod := profileResults[h.decodeProfile].TargetPods[0].GetPod().NamespacedName - hitPrefix := max(prefixState.PrefixCacheServers[prefix.ServerID(decodePod)]-1, 0) // The first hit is always the model name - hitPercentagePrefix = float64(hitPrefix*h.hashBlockSize) / float64(len(userInput)) - log.FromContext(ctx).V(logutil.DEBUG).Info("Computed hit percentage for prefix cache", "hitPercentage", hitPercentagePrefix, - "promptLength", len(userInput)) - } + inputTokens, err := getUserInputLenInTokens(request) + if err != nil { + log.FromContext(ctx).V(logutil.DEBUG).Error(err, "Failed to get user input") + return nil + } - if (1.0-hitPercentagePrefix)*float64(len(userInput)) < float64(h.pdThreshold) { - log.FromContext(ctx).Info("Non-cached suffix is smaller than threshold, using decode profile only", "hitPercentage", hitPercentagePrefix) - metrics.RecordPDDecision(request.TargetModel, metrics.DecisionTypeDecodeOnly) - return map[string]*framework.SchedulerProfile{} // do not run prefill + if h.decider != nil && h.decider.disaggregate(ctx, inputTokens, profileResults[h.decodeProfile].TargetEndpoints[0]) { + metrics.RecordPDDecision(request.TargetModel, metrics.DecisionTypePrefillDecode) + // run the prefill profile + return map[string]scheduling.SchedulerProfile{ + h.prefillProfile: profiles[h.prefillProfile], } } - metrics.RecordPDDecision(request.TargetModel, metrics.DecisionTypePrefillDecode) - // run the prefill profile - return map[string]*framework.SchedulerProfile{ - h.prefillProfile: profiles[h.prefillProfile], - } + metrics.RecordPDDecision(request.TargetModel, metrics.DecisionTypeDecodeOnly) + return map[string]scheduling.SchedulerProfile{} // do not 
run prefill } // ProcessResults handles the outcome of the profile runs after the selected profiles ran. // In case of an error in any of the profiles, the matching entry in the profileResults will contain nil, to indicate there was // an error while running the profile. -func (h *PdProfileHandler) ProcessResults(_ context.Context, _ *types.CycleState, request *types.LLMRequest, - profileResults map[string]*types.ProfileRunResult) (*types.SchedulingResult, error) { +func (h *PdProfileHandler) ProcessResults(_ context.Context, _ *scheduling.CycleState, request *scheduling.LLMRequest, + profileResults map[string]*scheduling.ProfileRunResult) (*scheduling.SchedulingResult, error) { decodeRunResults := profileResults[h.decodeProfile] if decodeRunResults == nil { // if decode profile failed to run, we should fail return nil, errors.New("failed to find available decode workers") } // otherwise, decode ran successfully - updatedResults := map[string]*types.ProfileRunResult{} + updatedResults := map[string]*scheduling.ProfileRunResult{} // Add decode profile to result if h.primaryPort != "" { // Data Parallel is active - targetPod := decodeRunResults.TargetPods[0].GetPod() - request.Headers[common.DataParallelPodHeader] = net.JoinHostPort(targetPod.Address, targetPod.Port) + targetEndpoint := decodeRunResults.TargetEndpoints[0].GetMetadata() + request.Headers[common.DataParallelPodHeader] = net.JoinHostPort(targetEndpoint.Address, targetEndpoint.Port) - updatedResult := types.ProfileRunResult{ - TargetPods: []types.Pod{}, + updatedResult := scheduling.ProfileRunResult{ + TargetEndpoints: []scheduling.Endpoint{}, } - for _, target := range decodeRunResults.TargetPods { - updatedPodInfo := target.GetPod().Clone() - updatedPodInfo.Port = h.primaryPort - targetPod := &types.PodMetrics{Pod: updatedPodInfo, MetricsState: target.GetMetrics().Clone()} - updatedResult.TargetPods = append(updatedResult.TargetPods, targetPod) + for _, target := range decodeRunResults.TargetEndpoints 
{ + updatedEndpointInfo := target.GetMetadata().Clone() + updatedEndpointInfo.Port = h.primaryPort + targetEndpoint := scheduling.NewEndpoint(updatedEndpointInfo, target.GetMetrics().Clone(), nil) + updatedResult.TargetEndpoints = append(updatedResult.TargetEndpoints, targetEndpoint) } updatedResults[h.decodeProfile] = &updatedResult } else { @@ -216,17 +216,24 @@ func (h *PdProfileHandler) ProcessResults(_ context.Context, _ *types.CycleState updatedResults[h.prefillProfile] = prefillRunResult } - return &types.SchedulingResult{ + return &scheduling.SchedulingResult{ PrimaryProfileName: h.decodeProfile, ProfileResults: updatedResults, }, nil } -func getUserInputBytes(request *types.LLMRequest) ([]byte, error) { +// returns length of user input in tokens +func getUserInputLenInTokens(request *scheduling.LLMRequest) (int, error) { if request.Body.Completions != nil { // assumed to be valid if not nil - return []byte(request.Body.Completions.Prompt), nil + return len([]byte(request.Body.Completions.Prompt)) / AverageCharactersPerToken, nil } // must be chat-completions request at this point, return bytes of entire messages - return json.Marshal(request.Body.ChatCompletions.Messages) + prompt, err := json.Marshal(request.Body.ChatCompletions.Messages) + + if err != nil { + return 0, err + } + + return len(prompt) / AverageCharactersPerToken, nil } diff --git a/pkg/plugins/profile/pd_profile_handler_test.go b/pkg/plugins/profile/pd_profile_handler_test.go index 09068104b..c932c8064 100644 --- a/pkg/plugins/profile/pd_profile_handler_test.go +++ b/pkg/plugins/profile/pd_profile_handler_test.go @@ -8,130 +8,120 @@ import ( "github.com/stretchr/testify/assert" k8stypes "k8s.io/apimachinery/pkg/types" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" - backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" - 
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/multi/prefix" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer/plugins/approximateprefix" + fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/plugins/scheduling/scorer/prefix" "github.com/llm-d/llm-d-inference-scheduler/pkg/common" "github.com/llm-d/llm-d-inference-scheduler/test/utils" ) func TestPdProfileHandlerFactory(t *testing.T) { + ctx := utils.NewTestContext(t) tests := []struct { name string pluginName string - jsonParams string + params map[string]any expectErr bool }{ { name: "valid configuration with all defaults", pluginName: "default-handler", - jsonParams: "{}", + params: map[string]any{}, expectErr: false, }, { name: "valid configuration with custom values", pluginName: "custom-handler", - jsonParams: `{ - "threshold": 100, - "decodeProfile": "my-decode", - "prefillProfile": "my-prefill", - "prefixPluginName": "my-prefix-cache", - "hashBlockSize": 32, - "primaryPort": 8080 - }`, + params: map[string]any{ + "decodeProfile": "my-decode", + "prefillProfile": "my-prefill", + "prefixPluginName": "my-prefix-cache", + "primaryPort": 8080, + "deciderPluginName": PrefixBasedPDDeciderPluginType, + }, expectErr: false, }, { name: "zero primaryPort is allowed", pluginName: "zero-port", - jsonParams: `{"primaryPort": 0}`, - expectErr: false, - }, - { - name: "threshold = 0 is allowed", - pluginName: "zero-threshold", - jsonParams: `{"threshold": 0}`, - expectErr: false, - }, - { - name: "negative threshold should error", - pluginName: 
"neg-threshold", - jsonParams: `{"threshold": -1}`, - expectErr: true, - }, - { - name: "hashBlockSize = 0 should error", - pluginName: "zero-block-size", - jsonParams: `{"hashBlockSize": 0}`, - expectErr: true, + params: map[string]any{ + "primaryPort": 0, + }, + expectErr: false, }, { - name: "negative hashBlockSize should error", - pluginName: "neg-block-size", - jsonParams: `{"hashBlockSize": -5}`, - expectErr: true, + name: "nonCachedTokens = 0 is allowed", + pluginName: "zero-non-cached-tokens", + params: map[string]any{ + "deciderPluginName": PrefixBasedPDDeciderPluginType, + }, + expectErr: false, }, { name: "primaryPort below range should error", pluginName: "port-too-low", - jsonParams: `{"primaryPort": 0}`, // OK + params: map[string]any{"primaryPort": 0}, // OK expectErr: false, }, { name: "primaryPort = 1 is valid", pluginName: "port-min", - jsonParams: `{"primaryPort": 1}`, + params: map[string]any{"primaryPort": 1}, expectErr: false, }, { name: "primaryPort = 65535 is valid", pluginName: "port-max", - jsonParams: `{"primaryPort": 65535}`, + params: map[string]any{"primaryPort": 65535}, expectErr: false, }, { name: "empty decodeProfile is valid", pluginName: "empty-decode", - jsonParams: `{"decodeProfile": ""}`, + params: map[string]any{"decodeProfile": ""}, expectErr: false, }, { name: "empty prefillProfile is valid", pluginName: "empty-prefill", - jsonParams: `{"prefillProfile": ""}`, + params: map[string]any{"prefillProfile": ""}, expectErr: false, }, { name: "empty prefixPluginName is valid", pluginName: "empty-prefix-plugin", - jsonParams: `{"prefixPluginName": ""}`, + params: map[string]any{"prefixPluginName": ""}, expectErr: false, }, { name: "primaryPort = 65536 should error", pluginName: "port-too-high", - jsonParams: `{"primaryPort": 65536}`, + params: map[string]any{"primaryPort": 65536}, expectErr: true, }, { name: "primaryPort = -10 should error", pluginName: "port-negative", - jsonParams: `{"primaryPort": -10}`, + params: 
map[string]any{"primaryPort": -10}, expectErr: true, }, } + handle, err := createHandleWithDeciderPlugins(ctx) + assert.NoError(t, err) + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var rawParams json.RawMessage - if tt.jsonParams != "" { - rawParams = json.RawMessage(tt.jsonParams) + if tt.params != nil { + bytes, err := json.Marshal(tt.params) + assert.NoError(t, err) + rawParams = json.RawMessage(bytes) } - plugin, err := PdProfileHandlerFactory(tt.pluginName, rawParams, nil) + plugin, err := PdProfileHandlerFactory(tt.pluginName, rawParams, handle) if tt.expectErr { assert.Error(t, err) @@ -145,21 +135,19 @@ func TestPdProfileHandlerFactory(t *testing.T) { } func TestPdProfileHandlerFactoryInvalidJSON(t *testing.T) { + ctx := utils.NewTestContext(t) + invalidTests := []struct { name string jsonParams string }{ { name: "malformed JSON", - jsonParams: `{"threshold": 100, "hashBlockSize":`, // incomplete - }, - { - name: "threshold as string instead of int", - jsonParams: `{"threshold": "100"}`, + jsonParams: `{"deciderPluginName": `, // incomplete }, { - name: "hashBlockSize as boolean", - jsonParams: `{"hashBlockSize": true}`, + name: "invalid decider plugin type", + jsonParams: `{"deciderPluginName": "INVALID"}`, }, { name: "primaryPort as float", @@ -167,10 +155,13 @@ func TestPdProfileHandlerFactoryInvalidJSON(t *testing.T) { }, } + handle, err := createHandleWithDeciderPlugins(ctx) + assert.NoError(t, err) + for _, tt := range invalidTests { t.Run(tt.name, func(t *testing.T) { rawParams := json.RawMessage(tt.jsonParams) - plugin, err := PdProfileHandlerFactory("test", rawParams, nil) + plugin, err := PdProfileHandlerFactory("test", rawParams, handle) assert.Error(t, err) assert.Nil(t, plugin) @@ -180,164 +171,280 @@ func TestPdProfileHandlerFactoryInvalidJSON(t *testing.T) { const DefaultTestPodPort = "8000" -// createPod creates a mock Pod with customizable IP and port. 
-func createPod(nsn k8stypes.NamespacedName, ipaddr, port string, labels map[string]string) types.Pod { - return &types.PodMetrics{ - Pod: &backend.Pod{ +// createEndpoint creates a mock Endpoint with customizable IP and port. +func createEndpoint(nsn k8stypes.NamespacedName, ipaddr, port string, labels map[string]string) scheduling.Endpoint { + return scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{ NamespacedName: nsn, Address: ipaddr, Port: port, Labels: labels, }, - MetricsState: &backendmetrics.MetricsState{}, - } + nil, + fwkdl.NewAttributes(), + ) } // newMockProfileRunResult creates a ProfileRunResult with Pods using the given port. -func newMockProfileRunResult(port string, podNames ...string) *types.ProfileRunResult { - pods := make([]types.Pod, 0, len(podNames)) - for i, name := range podNames { +func newMockProfileRunResult(port string, endpointNames ...string) *scheduling.ProfileRunResult { + endpoints := make([]scheduling.Endpoint, 0, len(endpointNames)) + for i, name := range endpointNames { ip := fmt.Sprintf("10.0.0.%d", i+1) - pods = append(pods, createPod( + endpoints = append(endpoints, createEndpoint( k8stypes.NamespacedName{Namespace: "default", Name: name}, ip, port, map[string]string{}, )) } - return &types.ProfileRunResult{ - TargetPods: pods, + return &scheduling.ProfileRunResult{ + TargetEndpoints: endpoints, } } -func newMockSchedulerProfile() *framework.SchedulerProfile { - return &framework.SchedulerProfile{} +func newMockSchedulerProfile() scheduling.SchedulerProfile { + return &mockSchedulerProfile{} } -func TestPdProfileHandler_Pick(t *testing.T) { - ctx := utils.NewTestContext(t) - request := &types.LLMRequest{ - Body: &types.LLMRequestBody{ - Completions: &types.CompletionsRequest{ - Prompt: "hello world", +type mockSchedulerProfile struct{} + +func (p *mockSchedulerProfile) Run(_ context.Context, _ *scheduling.LLMRequest, _ *scheduling.CycleState, _ []scheduling.Endpoint) (*scheduling.ProfileRunResult, error) { + return 
&scheduling.ProfileRunResult{}, nil +} + +// creates and returns llm completion request for the given prompt +func createRequest(prompt string) *scheduling.LLMRequest { + return &scheduling.LLMRequest{ + Body: &scheduling.LLMRequestBody{ + Completions: &scheduling.CompletionsRequest{ + Prompt: prompt, }, }, } +} - profiles := map[string]*framework.SchedulerProfile{ +// returns array of profile names in the given profile pick result +func getProfilesFromResult(result map[string]scheduling.SchedulerProfile) []string { + profiles := make([]string, len(result)) + index := 0 + + for name := range result { + profiles[index] = name + index++ + } + + return profiles +} + +func TestPdProfileHandler_Pick(t *testing.T) { + ctx := utils.NewTestContext(t) + request := createRequest("hello world hello world hello world") + + profiles := map[string]scheduling.SchedulerProfile{ "decode": newMockSchedulerProfile(), "prefill": newMockSchedulerProfile(), } tests := []struct { - name string - pdThreshold int - hashBlockSize int - prefixPluginType string - prefixPluginName string - setupPrefixState func(*types.CycleState) - profileResults map[string]*types.ProfileRunResult - expectedProfiles []string + name string + nonCachedTokensLimit int + prefixPluginType string + prefixPluginName string + cachedTokens int + profileResults map[string]*scheduling.ProfileRunResult + expectedProfiles []string }{ { - name: "decode not executed yet β†’ run decode", - pdThreshold: 100, - hashBlockSize: 16, - prefixPluginType: prefix.PrefixCachePluginType, - prefixPluginName: prefix.PrefixCachePluginType, - profileResults: map[string]*types.ProfileRunResult{}, - expectedProfiles: []string{"decode"}, + name: "decode not executed yet β†’ run decode", + nonCachedTokensLimit: 10, + prefixPluginType: prefix.PrefixCachePluginType, + prefixPluginName: prefix.PrefixCachePluginType, + profileResults: map[string]*scheduling.ProfileRunResult{}, + expectedProfiles: []string{defaultDecodeProfile}, }, { - name: "decode 
failed (nil result) β†’ run nothing", - pdThreshold: 100, - hashBlockSize: 16, - prefixPluginType: prefix.PrefixCachePluginType, - prefixPluginName: prefix.PrefixCachePluginType, - profileResults: map[string]*types.ProfileRunResult{ - "decode": nil, + name: "decode failed (nil result) β†’ run nothing", + nonCachedTokensLimit: 10, + prefixPluginType: prefix.PrefixCachePluginType, + prefixPluginName: prefix.PrefixCachePluginType, + profileResults: map[string]*scheduling.ProfileRunResult{ + defaultDecodeProfile: nil, }, expectedProfiles: []string{}, }, { - name: "all profiles already executed β†’ run nothing", - pdThreshold: 100, - hashBlockSize: 16, - prefixPluginType: prefix.PrefixCachePluginType, - prefixPluginName: prefix.PrefixCachePluginType, - profileResults: map[string]*types.ProfileRunResult{ - "decode": newMockProfileRunResult(DefaultTestPodPort, "pod1"), - "prefill": newMockProfileRunResult(DefaultTestPodPort, "pod2"), + name: "all profiles already executed β†’ run nothing", + nonCachedTokensLimit: 10, + prefixPluginType: prefix.PrefixCachePluginType, + prefixPluginName: prefix.PrefixCachePluginType, + profileResults: map[string]*scheduling.ProfileRunResult{ + defaultDecodeProfile: newMockProfileRunResult(DefaultTestPodPort, "pod1"), + defaultPrefillProfile: newMockProfileRunResult(DefaultTestPodPort, "pod2"), }, expectedProfiles: []string{}, }, { - name: "pd threshold NOT triggered β†’ run prefill", - pdThreshold: 5, - hashBlockSize: 16, - prefixPluginType: prefix.PrefixCachePluginType, - prefixPluginName: prefix.PrefixCachePluginType, - setupPrefixState: func(cs *types.CycleState) { - state := &prefix.SchedulingContextState{ - PrefixCacheServers: map[prefix.ServerID]int{ - prefix.ServerID(k8stypes.NamespacedName{Name: "pod1", Namespace: "default"}): 1, - }, - } - key := plugins.StateKey(fmt.Sprintf("%s/%s", prefix.PrefixCachePluginType, prefix.PrefixCachePluginType)) - cs.Write(key, state) + name: "has enough not-cached tokens β†’ run prefill", + // Need 
at least 4 non-cached tokens (16+ chars) to trigger disaggregated prefill + // In this case: prompt length is 35 chars (8 tokens), cached length is 2 tokens -> disaggregated prefill should trigger + nonCachedTokensLimit: 4, + cachedTokens: 2, + prefixPluginType: prefix.PrefixCachePluginType, + prefixPluginName: prefix.PrefixCachePluginType, + profileResults: map[string]*scheduling.ProfileRunResult{ + defaultDecodeProfile: newMockProfileRunResult(DefaultTestPodPort, "pod1"), }, - profileResults: map[string]*types.ProfileRunResult{ - "decode": newMockProfileRunResult(DefaultTestPodPort, "pod1"), - }, - expectedProfiles: []string{"prefill"}, + expectedProfiles: []string{defaultPrefillProfile}, }, { - name: "pd threshold triggered (short non-cached suffix) β†’ skip prefill", - pdThreshold: 100, - hashBlockSize: 16, - prefixPluginType: prefix.PrefixCachePluginType, - prefixPluginName: prefix.PrefixCachePluginType, - setupPrefixState: func(cs *types.CycleState) { - state := &prefix.SchedulingContextState{ - PrefixCacheServers: map[prefix.ServerID]int{ - prefix.ServerID(k8stypes.NamespacedName{Name: "pod1", Namespace: "default"}): 5, - }, - } - key := plugins.StateKey(fmt.Sprintf("%s/%s", prefix.PrefixCachePluginType, prefix.PrefixCachePluginType)) - cs.Write(key, state) - }, - profileResults: map[string]*types.ProfileRunResult{ - "decode": newMockProfileRunResult(DefaultTestPodPort, "pod1"), + name: "short non-cached suffix β†’ skip prefill", + // Need at least 4 non-cached tokens (16+ chars) to trigger disaggregated prefill + // In this case: prompt length is 35 chars (8 tokens), cached length is 5 tokens -> skip prefill + nonCachedTokensLimit: 4, + cachedTokens: 5, + prefixPluginType: prefix.PrefixCachePluginType, + prefixPluginName: prefix.PrefixCachePluginType, + profileResults: map[string]*scheduling.ProfileRunResult{ + defaultDecodeProfile: newMockProfileRunResult(DefaultTestPodPort, "pod1"), }, expectedProfiles: []string{}, }, } for _, tt := range tests { + 
deciderPlugin, err := NewPrefixBasedPDDecider(PrefixBasedPDDeciderConfig{NonCachedTokens: tt.nonCachedTokensLimit}) + assert.NoError(t, err) + t.Run(tt.name, func(t *testing.T) { - handler := NewPdProfileHandler( - "prefill", - "decode", + handler, err := NewPdProfileHandler( + defaultPrefillProfile, + defaultDecodeProfile, tt.prefixPluginType, tt.prefixPluginName, - tt.pdThreshold, - tt.hashBlockSize, 0, - ).WithName("test-handler") + deciderPlugin, + ) + assert.NoError(t, err) + + // set prefix to the given cached tokens number for pod "pod1" in decode profile results + inputTokens := len(request.Body.Completions.Prompt) / AverageCharactersPerToken - cs := &types.CycleState{} - if tt.setupPrefixState != nil { - tt.setupPrefixState(cs) + for profileName, profileRes := range tt.profileResults { + if profileName == defaultDecodeProfile && profileRes != nil { + for _, pod := range profileRes.TargetEndpoints { + pod.Put(approximateprefix.PrefixCacheMatchInfoKey, + approximateprefix.NewPrefixCacheMatchInfo(tt.cachedTokens, inputTokens, 1)) + } + } } + result := handler.Pick(ctx, nil, request, profiles, tt.profileResults) + assert.ElementsMatch(t, tt.expectedProfiles, getProfilesFromResult(result)) + }) + } +} - result := handler.Pick(ctx, cs, request, profiles, tt.profileResults) +func TestPdProfileHandler_PickSeries(t *testing.T) { + ctx := context.Background() + prompt := "hello world, hello world, hello world, hello world, hello world, hello world, hello world!" 
+ request := createRequest(prompt) + longerRequest := createRequest(prompt + "123") + longRequest := createRequest(prompt + prompt) - var actual []string - for name := range result { - actual = append(actual, name) - } + profiles := map[string]scheduling.SchedulerProfile{ + defaultDecodeProfile: newMockSchedulerProfile(), + defaultPrefillProfile: newMockSchedulerProfile(), + } + profileResults := map[string]*scheduling.ProfileRunResult{ + defaultDecodeProfile: newMockProfileRunResult(DefaultTestPodPort, "pod1"), + } + + type testData struct { + request *scheduling.LLMRequest + cachedTokens int + expectedProfiles []string + } + tests := []struct { + name string + nonCachedTokensLimit int + tests []testData + }{ + { + name: "same request twice", + nonCachedTokensLimit: 2, + tests: []testData{{ + request: request, + cachedTokens: 0, + expectedProfiles: []string{defaultPrefillProfile}, + }, { + request: request, + cachedTokens: len(request.Body.Completions.Prompt) / AverageCharactersPerToken, + expectedProfiles: []string{}, + }}, + }, { + name: "short request and a little bit longer after it", + // Need at least 2 non-cached tokens (8+ chars) to trigger disaggregated prefill + // In this case: longer request is longer in 4 chars than the request -> no disaggregated prefill + nonCachedTokensLimit: 2, + tests: []testData{{ + request: request, + cachedTokens: 0, + expectedProfiles: []string{defaultPrefillProfile}, + }, { + request: longerRequest, + cachedTokens: len(request.Body.Completions.Prompt) / AverageCharactersPerToken, + expectedProfiles: []string{}, + }}, + }, { + name: "short request and a long one after it", + // Need at least 2 non-cached tokens (8+ chars) to trigger disaggregated prefill + // In this case: long request is longer enough than the request -> should have disaggregated prefill + nonCachedTokensLimit: 2, + tests: []testData{{ + request: request, + cachedTokens: 0, + expectedProfiles: []string{defaultPrefillProfile}, + }, { + request: longRequest, + 
cachedTokens: len(request.Body.Completions.Prompt) / AverageCharactersPerToken, + expectedProfiles: []string{defaultPrefillProfile}, + }}, + }, + } - assert.ElementsMatch(t, tt.expectedProfiles, actual) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + deciderPlugin, err := NewPrefixBasedPDDecider(PrefixBasedPDDeciderConfig{NonCachedTokens: tt.nonCachedTokensLimit}) + assert.NoError(t, err) + + handler, err := NewPdProfileHandler( + defaultPrefillProfile, + defaultDecodeProfile, + prefix.PrefixCachePluginType, + prefix.PrefixCachePluginType, + 0, + deciderPlugin, + ) + assert.NoError(t, err) + + // run sequences of request + for _, innerTest := range tt.tests { + cs := &scheduling.CycleState{} + + // set prefix to the given cached tokens number for pod "pod1" in decode profile results + inputTokens := len(innerTest.request.Body.Completions.Prompt) / AverageCharactersPerToken + + for profileName, profileRes := range profileResults { + if profileName == defaultDecodeProfile && profileRes != nil { + for _, endpoint := range profileRes.TargetEndpoints { + endpoint.Put(approximateprefix.PrefixCacheMatchInfoKey, + approximateprefix.NewPrefixCacheMatchInfo(innerTest.cachedTokens, inputTokens, 1)) + } + } + } + + result := handler.Pick(ctx, cs, innerTest.request, profiles, profileResults) + assert.ElementsMatch(t, innerTest.expectedProfiles, getProfilesFromResult(result)) + } }) } } @@ -346,57 +453,57 @@ func TestPdProfileHandler_ProcessResults(t *testing.T) { tests := []struct { name string primaryPort int - profileResults map[string]*types.ProfileRunResult + profileResults map[string]*scheduling.ProfileRunResult expectError bool - checkResult func(*testing.T, *types.SchedulingResult, map[string]string) + checkResult func(*testing.T, *scheduling.SchedulingResult, map[string]string) }{ { name: "decode failed β†’ error", - profileResults: map[string]*types.ProfileRunResult{ - "decode": nil, + profileResults: map[string]*scheduling.ProfileRunResult{ + 
defaultDecodeProfile: nil, }, expectError: true, }, { name: "decode success, no prefill, no primaryPort", primaryPort: 0, - profileResults: map[string]*types.ProfileRunResult{ - "decode": newMockProfileRunResult(DefaultTestPodPort, "pod1"), + profileResults: map[string]*scheduling.ProfileRunResult{ + defaultDecodeProfile: newMockProfileRunResult(DefaultTestPodPort, "pod1"), }, expectError: false, - checkResult: func(t *testing.T, res *types.SchedulingResult, headers map[string]string) { - assert.Equal(t, "decode", res.PrimaryProfileName) - assert.Contains(t, res.ProfileResults, "decode") - assert.NotContains(t, res.ProfileResults, "prefill") - pod := res.ProfileResults["decode"].TargetPods[0].GetPod() - assert.Equal(t, DefaultTestPodPort, pod.Port) + checkResult: func(t *testing.T, res *scheduling.SchedulingResult, headers map[string]string) { + assert.Equal(t, defaultDecodeProfile, res.PrimaryProfileName) + assert.Contains(t, res.ProfileResults, defaultDecodeProfile) + assert.NotContains(t, res.ProfileResults, defaultPrefillProfile) + metadata := res.ProfileResults[defaultDecodeProfile].TargetEndpoints[0].GetMetadata() + assert.Equal(t, DefaultTestPodPort, metadata.Port) assert.Empty(t, headers[common.DataParallelPodHeader]) }, }, { name: "decode success, with prefill", primaryPort: 0, - profileResults: map[string]*types.ProfileRunResult{ - "decode": newMockProfileRunResult(DefaultTestPodPort, "pod1"), - "prefill": newMockProfileRunResult(DefaultTestPodPort, "pod2"), + profileResults: map[string]*scheduling.ProfileRunResult{ + defaultDecodeProfile: newMockProfileRunResult(DefaultTestPodPort, "pod1"), + defaultPrefillProfile: newMockProfileRunResult(DefaultTestPodPort, "pod2"), }, expectError: false, - checkResult: func(t *testing.T, res *types.SchedulingResult, _ map[string]string) { - assert.Equal(t, "decode", res.PrimaryProfileName) - assert.Contains(t, res.ProfileResults, "decode") - assert.Contains(t, res.ProfileResults, "prefill") + checkResult: func(t 
*testing.T, res *scheduling.SchedulingResult, _ map[string]string) { + assert.Equal(t, defaultDecodeProfile, res.PrimaryProfileName) + assert.Contains(t, res.ProfileResults, defaultDecodeProfile) + assert.Contains(t, res.ProfileResults, defaultPrefillProfile) }, }, { name: "with primaryPort β†’ port updated and header set", primaryPort: 9000, - profileResults: map[string]*types.ProfileRunResult{ - "decode": newMockProfileRunResult(DefaultTestPodPort, "pod1"), + profileResults: map[string]*scheduling.ProfileRunResult{ + defaultDecodeProfile: newMockProfileRunResult(DefaultTestPodPort, "pod1"), }, expectError: false, - checkResult: func(t *testing.T, res *types.SchedulingResult, headers map[string]string) { - pod := res.ProfileResults["decode"].TargetPods[0].GetPod() - assert.Equal(t, "9000", pod.Port) + checkResult: func(t *testing.T, res *scheduling.SchedulingResult, headers map[string]string) { + metadata := res.ProfileResults[defaultDecodeProfile].TargetEndpoints[0].GetMetadata() + assert.Equal(t, "9000", metadata.Port) hostPort := headers[common.DataParallelPodHeader] assert.Equal(t, "10.0.0.1:8000", hostPort) @@ -405,22 +512,25 @@ func TestPdProfileHandler_ProcessResults(t *testing.T) { } for _, tt := range tests { + deciderPlugin, err := NewPrefixBasedPDDecider(PrefixBasedPDDeciderConfig{NonCachedTokens: 0}) + assert.NoError(t, err) + t.Run(tt.name, func(t *testing.T) { - handler := NewPdProfileHandler( - "prefill", - "decode", + handler, err := NewPdProfileHandler( + defaultPrefillProfile, + defaultDecodeProfile, prefix.PrefixCachePluginType, prefix.PrefixCachePluginType, - 0, - prefix.DefaultBlockSize, tt.primaryPort, - ).WithName("test-handler") + deciderPlugin, + ) + assert.NoError(t, err) headers := make(map[string]string) - req := &types.LLMRequest{ + req := &scheduling.LLMRequest{ Headers: headers, } - result, err := handler.ProcessResults(context.Background(), &types.CycleState{}, req, tt.profileResults) + result, err := 
handler.ProcessResults(context.Background(), &scheduling.CycleState{}, req, tt.profileResults) if tt.expectError { assert.Error(t, err) @@ -433,3 +543,16 @@ func TestPdProfileHandler_ProcessResults(t *testing.T) { }) } } + +func createHandleWithDeciderPlugins(ctx context.Context) (plugin.Handle, error) { + handle := plugin.NewEppHandle(ctx, nil) + plugin1, err := NewPrefixBasedPDDecider(PrefixBasedPDDeciderConfig{NonCachedTokens: 4}) + if err != nil { + return nil, err + } + handle.AddPlugin(PrefixBasedPDDeciderPluginType, plugin1) + plugin2 := newAlwaysDisaggPDDecider() + handle.AddPlugin(AlwaysDisaggDeciderPluginType, plugin2) + + return handle, nil +} diff --git a/pkg/plugins/profile/prefix_based_pd_decider.go b/pkg/plugins/profile/prefix_based_pd_decider.go new file mode 100644 index 000000000..8948b4d37 --- /dev/null +++ b/pkg/plugins/profile/prefix_based_pd_decider.go @@ -0,0 +1,137 @@ +package profile + +import ( + "context" + "encoding/json" + "errors" + "fmt" + + "sigs.k8s.io/controller-runtime/pkg/log" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/common/util/logging" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer/plugins/approximateprefix" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling" +) + +const ( + // PrefixBasedPDDeciderPluginType is the type-name of the prefixBasedPDDecider plugin. + PrefixBasedPDDeciderPluginType = "prefix-based-pd-decider" +) + +// PrefixBasedPDDeciderConfig holds the configuration for the prefixBasedPDDecider plugin. 
+type PrefixBasedPDDeciderConfig struct { + // NonCachedTokens is the minimum number of non-cached tokens that triggers disaggregated PD + NonCachedTokens int `json:"nonCachedTokens"` +} + +func (p PrefixBasedPDDeciderConfig) validate() error { + if p.NonCachedTokens < 0 { + return errors.New("nonCachedTokens parameter of prefix disaggregation decider cannot be negative") + } + + return nil +} + +// compile-time type assertion +var _ pdDeciderPlugin = &PrefixBasedPDDecider{} + +// PrefixBasedPDDecider is a PD decider plugin whose decision is prefix-aware +type PrefixBasedPDDecider struct { + typedName plugin.TypedName + config PrefixBasedPDDeciderConfig +} + +// PrefixBasedPDDeciderPluginFactory defines the factory function for creating +// a new instance of the prefixBasedPDDecider. +func PrefixBasedPDDeciderPluginFactory(name string, rawParameters json.RawMessage, + handle plugin.Handle) (plugin.Plugin, error) { + config := PrefixBasedPDDeciderConfig{ + NonCachedTokens: 0, + } + + if rawParameters != nil { + if err := json.Unmarshal(rawParameters, &config); err != nil { + return nil, fmt.Errorf("failed to parse %s plugin config: %w", PrefixBasedPDDeciderPluginType, err) + } + } + + decider, err := NewPrefixBasedPDDecider(config) + if err != nil { + return nil, fmt.Errorf("failed to create %s plugin: %w", PrefixBasedPDDeciderPluginType, err) + } + + return decider.WithName(name), nil +} + +// NewPrefixBasedPDDecider initializes a prefix-based PD decider plugin and returns its pointer. +// If the configuration is invalid, an error is returned. +func NewPrefixBasedPDDecider(config PrefixBasedPDDeciderConfig) (*PrefixBasedPDDecider, error) { + if err := config.validate(); err != nil { + return nil, err + } + + return &PrefixBasedPDDecider{ + config: config, + }, nil +} + +// TypedName returns the typed name of the plugin. +func (d *PrefixBasedPDDecider) TypedName() plugin.TypedName { + return d.typedName +} + +// WithName sets the name of the plugin. 
+func (d *PrefixBasedPDDecider) WithName(name string) *PrefixBasedPDDecider { + d.typedName.Name = name + return d +} + +func (d *PrefixBasedPDDecider) disaggregate(ctx context.Context, inputTokens int, endpoint scheduling.Endpoint) bool { + logger := log.FromContext(ctx) + debugLogger := log.FromContext(ctx).V(logutil.DEBUG) + + if d.config.NonCachedTokens <= 0 { // always use disaggregation in case of non cached tokens number is 0 + return true + } + if endpoint == nil { + logger.Error(nil, "prefix decider: endpoint is nil") + return false + } + if inputTokens < d.config.NonCachedTokens { + debugLogger.Info("Input is shorter than the nonCachedToken, no disaggregated PD") + return false + } + // inspect the decode endpoint to decide if prefill should run or not. + // if the non-cached part is short enough - no disaggregation. + prefixInfoRaw, ok := endpoint.Get(approximateprefix.PrefixCacheMatchInfoKey) + if !ok || prefixInfoRaw == nil { + logger.Error(nil, "unable to read prefix cache state") + return false + } + prefixCacheMatchInfo, ok := prefixInfoRaw.(*approximateprefix.PrefixCacheMatchInfo) + if !ok { + logger.Error(nil, "wrong type of prefix cache match info") + return false + } + + // number of cached tokens + hitPrefixTokens := prefixCacheMatchInfo.MatchBlocks() * prefixCacheMatchInfo.BlockSizeTokens() + // length of non-cached suffix in tokens + nonCachedTokens := inputTokens - hitPrefixTokens + + debugLogger.Info("Computed hit percentage for prefix cache", + "absolute hit prefix len (tokens)", hitPrefixTokens, + "prompt length (token)", inputTokens) + + if nonCachedTokens < d.config.NonCachedTokens { + debugLogger.Info("Non-cached suffix is smaller than threshold, using decode profile only") + return false // do not run prefill + } + + return true +} + +// Consumes defines data types consumed by this plugin +func (*PrefixBasedPDDecider) Consumes() map[string]any { + return map[string]any{approximateprefix.PrefixCacheMatchInfoKey: 
approximateprefix.PrefixCacheMatchInfo{}} +} diff --git a/pkg/plugins/register.go b/pkg/plugins/register.go index f78c4afc7..d3c50fd31 100644 --- a/pkg/plugins/register.go +++ b/pkg/plugins/register.go @@ -1,25 +1,31 @@ package plugins import ( + "github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/datalayer/models" "github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/filter" prerequest "github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/pre-request" "github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/profile" "github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/scorer" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin" ) // RegisterAllPlugins registers the factory functions of all plugins in this repository. func RegisterAllPlugins() { - plugins.Register(filter.ByLabelType, filter.ByLabelFactory) - plugins.Register(filter.ByLabelSelectorType, filter.ByLabelSelectorFactory) - plugins.Register(filter.DecodeRoleType, filter.DecodeRoleFactory) - plugins.Register(filter.PrefillRoleType, filter.PrefillRoleFactory) - plugins.Register(prerequest.PrefillHeaderHandlerType, prerequest.PrefillHeaderHandlerFactory) - plugins.Register(profile.DataParallelProfileHandlerType, profile.DataParallelProfileHandlerFactory) - plugins.Register(profile.PdProfileHandlerType, profile.PdProfileHandlerFactory) - plugins.Register(scorer.PrecisePrefixCachePluginType, scorer.PrecisePrefixCachePluginFactory) - plugins.Register(scorer.LoadAwareType, scorer.LoadAwareFactory) - plugins.Register(scorer.SessionAffinityType, scorer.SessionAffinityFactory) - plugins.Register(scorer.ActiveRequestType, scorer.ActiveRequestFactory) - plugins.Register(scorer.NoHitLRUType, scorer.NoHitLRUFactory) + plugin.Register(filter.ByLabelType, filter.ByLabelFactory) + plugin.Register(filter.ByLabelSelectorType, filter.ByLabelSelectorFactory) + plugin.Register(filter.DecodeRoleType, 
filter.DecodeRoleFactory) + plugin.Register(filter.PrefillRoleType, filter.PrefillRoleFactory) + plugin.Register(prerequest.PrefillHeaderHandlerType, prerequest.PrefillHeaderHandlerFactory) + plugin.Register(profile.DataParallelProfileHandlerType, profile.DataParallelProfileHandlerFactory) + plugin.Register(profile.PdProfileHandlerType, profile.PdProfileHandlerFactory) + plugin.Register(scorer.PrecisePrefixCachePluginType, scorer.PrecisePrefixCachePluginFactory) + plugin.Register(scorer.LoadAwareType, scorer.LoadAwareFactory) + plugin.Register(scorer.SessionAffinityType, scorer.SessionAffinityFactory) + plugin.Register(scorer.ActiveRequestType, scorer.ActiveRequestFactory) + plugin.Register(scorer.NoHitLRUType, scorer.NoHitLRUFactory) + plugin.Register(models.ModelsDataSourceType, models.ModelDataSourceFactory) + plugin.Register(models.ModelsExtractorType, models.ModelServerExtractorFactory) + // pd decider plugins + plugin.Register(profile.PrefixBasedPDDeciderPluginType, profile.PrefixBasedPDDeciderPluginFactory) + plugin.Register(profile.AlwaysDisaggDeciderPluginType, profile.AlwaysDisaggPDDeciderPluginFactory) } diff --git a/pkg/plugins/scorer/active_request.go b/pkg/plugins/scorer/active_request.go index 14e2e4169..a35791b56 100644 --- a/pkg/plugins/scorer/active_request.go +++ b/pkg/plugins/scorer/active_request.go @@ -4,17 +4,17 @@ import ( "context" "encoding/json" "fmt" + "strings" "sync" "time" "github.com/jellydator/ttlcache/v3" "sigs.k8s.io/controller-runtime/pkg/log" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" - logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" + logutil 
"sigs.k8s.io/gateway-api-inference-extension/pkg/common/util/logging" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/requestcontrol" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling" ) const ( @@ -38,22 +38,34 @@ type ActiveRequestParameters struct { // requestEntry represents a single request in the cache type requestEntry struct { - PodName string + PodNames []string RequestID string } // String returns a string representation of the request entry. -func (r *requestEntry) String() string { - return fmt.Sprintf("%s.%s", r.PodName, r.RequestID) +func (r requestEntry) String() string { + return fmt.Sprintf("%s:%s", r.RequestID, strings.Join(r.PodNames, ".")) +} + +// endpointScores implements logr.Marshaler to lazily convert endpoint keys +// to strings only when the log line is actually written. +type endpointScores map[scheduling.Endpoint]float64 + +func (s endpointScores) MarshalLog() interface{} { + result := make(map[string]float64, len(s)) + for ep, score := range s { + result[ep.GetMetadata().NamespacedName.String()] = score + } + return result } // compile-time type assertion -var _ framework.Scorer = &ActiveRequest{} +var _ scheduling.Scorer = &ActiveRequest{} var _ requestcontrol.PreRequest = &ActiveRequest{} var _ requestcontrol.ResponseComplete = &ActiveRequest{} // ActiveRequestFactory defines the factory function for the ActiveRequest scorer. 
-func ActiveRequestFactory(name string, rawParameters json.RawMessage, handle plugins.Handle) (plugins.Plugin, error) { +func ActiveRequestFactory(name string, rawParameters json.RawMessage, handle plugin.Handle) (plugin.Plugin, error) { parameters := ActiveRequestParameters{} if rawParameters != nil { if err := json.Unmarshal(rawParameters, ¶meters); err != nil { @@ -86,18 +98,20 @@ func NewActiveRequest(ctx context.Context, params *ActiveRequestParameters) *Act ) scorer := &ActiveRequest{ - typedName: plugins.TypedName{Type: ActiveRequestType}, - requestCache: requestCache, - podCounts: make(map[string]int), - mutex: &sync.RWMutex{}, + typedName: plugin.TypedName{Type: ActiveRequestType}, + requestCache: requestCache, + endpointCounts: make(map[string]int), + mutex: &sync.RWMutex{}, } // callback to decrement count when requests expire // most requests will be removed in ResponseComplete, but this ensures - // that we don't leak pod counts if ResponseComplete is not called + // that we don't leak endpoint counts if ResponseComplete is not called requestCache.OnEviction(func(_ context.Context, reason ttlcache.EvictionReason, item *ttlcache.Item[string, *requestEntry]) { if reason == ttlcache.EvictionReasonExpired { - scorer.decrementPodCount(item.Value().PodName) + for _, endpointName := range item.Value().PodNames { + scorer.decrementPodCount(endpointName) + } } }) @@ -107,20 +121,20 @@ func NewActiveRequest(ctx context.Context, params *ActiveRequestParameters) *Act } // ActiveRequest keeps track of individual requests being served -// per pod. +// per endpoint. 
type ActiveRequest struct { - typedName plugins.TypedName + typedName plugin.TypedName - // requestCache stores individual request entries with unique composite keys (podName.requestID) + // requestCache stores individual request entries with unique composite keys (endpointName.requestID) requestCache *ttlcache.Cache[string, *requestEntry] - // podCounts maintains fast lookup for request counts per pod - podCounts map[string]int - mutex *sync.RWMutex + // endpointCounts maintains fast lookup for request counts per endpoint + endpointCounts map[string]int + mutex *sync.RWMutex } // TypedName returns the typed name of the plugin. -func (s *ActiveRequest) TypedName() plugins.TypedName { +func (s *ActiveRequest) TypedName() plugin.TypedName { return s.typedName } @@ -130,110 +144,131 @@ func (s *ActiveRequest) WithName(name string) *ActiveRequest { return s } -// Score scores the given pods based on the number of active requests -// being served by each pod. The score is normalized to a range of 0-1. -func (s *ActiveRequest) Score(ctx context.Context, _ *types.CycleState, _ *types.LLMRequest, - pods []types.Pod) map[types.Pod]float64 { - scoredPods := make(map[string]int) +// Category returns the preference the scorer applies when scoring candidate endpoints. +func (s *ActiveRequest) Category() scheduling.ScorerCategory { + return scheduling.Distribution +} + +// Score scores the given endpoints based on the number of active requests +// being served by each endpoint. The score is normalized to a range of 0-1. 
+func (s *ActiveRequest) Score(ctx context.Context, _ *scheduling.CycleState, _ *scheduling.LLMRequest, + endpoints []scheduling.Endpoint) map[scheduling.Endpoint]float64 { + scoredEndpoints := make(map[string]int) maxCount := 0 s.mutex.RLock() - for podName, count := range s.podCounts { - scoredPods[podName] = count + for endpointName, count := range s.endpointCounts { + scoredEndpoints[endpointName] = count if count >= maxCount { maxCount = count } } s.mutex.RUnlock() - scoredPodsMap := make(map[types.Pod]float64, len(pods)) - for _, pod := range pods { - podName := pod.GetPod().NamespacedName.String() - if count, exists := scoredPods[podName]; exists { + log.FromContext(ctx).V(logutil.DEBUG).Info("Active request counts", "endpointCounts", scoredEndpoints, "maxCount", maxCount) + + scoredEndpointsMap := make(map[scheduling.Endpoint]float64, len(endpoints)) + for _, endpoint := range endpoints { + endpointName := endpoint.GetMetadata().NamespacedName.String() + if count, exists := scoredEndpoints[endpointName]; exists { if count == 0 || maxCount == 0 { - scoredPodsMap[pod] = 1.0 // no requests means highest score + scoredEndpointsMap[endpoint] = 1.0 // no requests means highest score } else { - scoredPodsMap[pod] = float64(maxCount-count) / float64(maxCount) + scoredEndpointsMap[endpoint] = float64(maxCount-count) / float64(maxCount) } } else { - scoredPodsMap[pod] = 1.0 + scoredEndpointsMap[endpoint] = 1.0 } } - log.FromContext(ctx).V(logutil.DEBUG).Info("Scored pods", "scores", scoredPodsMap) - return scoredPodsMap + log.FromContext(ctx).V(logutil.DEBUG).Info("Scored endpoints", "scores", endpointScores(scoredEndpointsMap)) + return scoredEndpointsMap } -// PreRequest is called before a request is sent to the target pod. +// PreRequest is called before a request is sent to the target endpoint. // It creates a new request entry in the cache with its own TTL and -// increments the pod count for fast lookup. 
-func (s *ActiveRequest) PreRequest(ctx context.Context, request *types.LLMRequest, - schedulingResult *types.SchedulingResult) { +// increments the endpoint count for fast lookup. +func (s *ActiveRequest) PreRequest( + ctx context.Context, + request *scheduling.LLMRequest, + schedulingResult *scheduling.SchedulingResult, +) { debugLogger := log.FromContext(ctx).V(logutil.DEBUG) - for _, profileResult := range schedulingResult.ProfileResults { // schedulingResult guaranteed not to be nil - if profileResult == nil || profileResult.TargetPods == nil || len(profileResult.TargetPods) == 0 { + endpointNames := make([]string, 0, len(schedulingResult.ProfileResults)) + for profileName, profileResult := range schedulingResult.ProfileResults { + if profileResult == nil || len(profileResult.TargetEndpoints) == 0 { continue } - // create request entry for first pod only. TODO: support fallback pods - entry := &requestEntry{ - PodName: profileResult.TargetPods[0].GetPod().NamespacedName.String(), - RequestID: request.RequestId, - } - - // add to request cache with TTL - s.requestCache.Set(entry.String(), entry, 0) // Use default TTL - s.incrementPodCount(entry.PodName) - - debugLogger.Info("Added request to cache", "requestEntry", entry.String()) + endpointName := profileResult.TargetEndpoints[0].GetMetadata().NamespacedName.String() + endpointNames = append(endpointNames, endpointName) + s.incrementPodCount(endpointName) + debugLogger.Info( + "Added request to cache", + "requestId", request.RequestId, + "endpointName", endpointName, + "profileName", profileName, + ) } + + // add to request cache + s.requestCache.Set(request.RequestId, &requestEntry{PodNames: endpointNames, RequestID: request.RequestId}, 0) // Use default TTL } // ResponseComplete is called after a response is sent to the client. // It removes the specific request entry from the cache and decrements -// the pod count. 
-func (s *ActiveRequest) ResponseComplete(ctx context.Context, request *types.LLMRequest, - _ *requestcontrol.Response, targetPod *backend.Pod) { +// the endpoint count. +func (s *ActiveRequest) ResponseComplete( + ctx context.Context, + request *scheduling.LLMRequest, + _ *requestcontrol.Response, + targetPod *datalayer.EndpointMetadata, +) { debugLogger := log.FromContext(ctx).V(logutil.DEBUG).WithName("ActiveRequest.ResponseComplete") if targetPod == nil { debugLogger.Info("Skipping ResponseComplete because targetPod is nil") return } - entry := requestEntry{targetPod.NamespacedName.String(), request.RequestId} - - if _, found := s.requestCache.GetAndDelete(entry.String()); found { - s.decrementPodCount(entry.PodName) - debugLogger.Info("Removed request from cache", "requestEntry", entry.String()) + if item, found := s.requestCache.GetAndDelete(request.RequestId); found { + entry := item.Value() + if entry != nil { + for _, endpointName := range entry.PodNames { + s.decrementPodCount(endpointName) + } + debugLogger.Info("Removed request from cache", "requestEntry", entry.String()) + } else { + debugLogger.Info("Request entry value is nil", "requestId", request.RequestId) + } } else { - debugLogger.Info("Request not found in cache", "requestEntry", entry.String()) + debugLogger.Info("Request not found in cache", "requestId", request.RequestId) } } -// incrementPodCount increments the request count for a pod. -func (s *ActiveRequest) incrementPodCount(podName string) { +// incrementPodCount increments the request count for a endpoint. +func (s *ActiveRequest) incrementPodCount(endpointName string) { s.mutex.Lock() defer s.mutex.Unlock() - s.podCounts[podName]++ + s.endpointCounts[endpointName]++ } -// decrementPodCount decrements the request count for a pod and removes +// decrementPodCount decrements the request count for a endpoint and removes // the entry if count reaches zero. 
-func (s *ActiveRequest) decrementPodCount(podName string) { +func (s *ActiveRequest) decrementPodCount(endpointName string) { s.mutex.Lock() defer s.mutex.Unlock() - if count, exists := s.podCounts[podName]; exists { + if count, exists := s.endpointCounts[endpointName]; exists { if count <= 1 { - delete(s.podCounts, podName) + delete(s.endpointCounts, endpointName) } else { - s.podCounts[podName] = count - 1 + s.endpointCounts[endpointName] = count - 1 } } } -func cleanCachePeriodically(ctx context.Context, cache *ttlcache.Cache[string, *requestEntry], requestTimeout time.Duration) { +func cleanCachePeriodically[K comparable, V any](ctx context.Context, cache *ttlcache.Cache[K, V], requestTimeout time.Duration) { ticker := time.NewTicker(requestTimeout) defer ticker.Stop() diff --git a/pkg/plugins/scorer/active_request_test.go b/pkg/plugins/scorer/active_request_test.go index e7215ce1a..12ab609aa 100644 --- a/pkg/plugins/scorer/active_request_test.go +++ b/pkg/plugins/scorer/active_request_test.go @@ -4,84 +4,112 @@ import ( "testing" "time" - "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" k8stypes "k8s.io/apimachinery/pkg/types" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" - backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/requestcontrol" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling" "github.com/llm-d/llm-d-inference-scheduler/test/utils" ) -func TestActiveRequestScorer_Score(t *testing.T) { - podA := &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-a", 
Namespace: "default"}}, - MetricsState: &backendmetrics.MetricsState{ - WaitingQueueSize: 2, +// Test helper functions + +func newTestEndpoint(name string, queueSize int) scheduling.Endpoint { + return scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: name, Namespace: "default"}}, + &fwkdl.Metrics{ + WaitingQueueSize: queueSize, }, + nil, + ) +} + +func newTestRequest(id string) *scheduling.LLMRequest { + return &scheduling.LLMRequest{ + RequestId: id, } - podB := &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-b", Namespace: "default"}}, - MetricsState: &backendmetrics.MetricsState{ - WaitingQueueSize: 0, - }, +} + +func newTestSchedulingResult(profileEndpoints map[string]scheduling.Endpoint) *scheduling.SchedulingResult { + profileResults := make(map[string]*scheduling.ProfileRunResult) + for profile, endpoint := range profileEndpoints { + profileResults[profile] = &scheduling.ProfileRunResult{ + TargetEndpoints: []scheduling.Endpoint{endpoint}, + } } - podC := &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-c", Namespace: "default"}}, - MetricsState: &backendmetrics.MetricsState{ - WaitingQueueSize: 15, - }, + return &scheduling.SchedulingResult{ + ProfileResults: profileResults, } +} + +func (s *ActiveRequest) getPodCount(endpointName string) int { + s.mutex.RLock() + defer s.mutex.RUnlock() + return s.endpointCounts[endpointName] +} + +func (s *ActiveRequest) hasPodCount(endpointName string) bool { + s.mutex.RLock() + defer s.mutex.RUnlock() + _, exists := s.endpointCounts[endpointName] + return exists +} + +func TestActiveRequestScorer_Score(t *testing.T) { + endpointA := newTestEndpoint("pod-a", 2) + endpointB := newTestEndpoint("pod-b", 0) + endpointC := newTestEndpoint("pod-c", 15) tests := []struct { name string setupCache func(*ActiveRequest) - input []types.Pod - wantScores map[types.Pod]float64 + input []scheduling.Endpoint 
+ wantScores map[scheduling.Endpoint]float64 }{ { - name: "no pods in cache", + name: "no endpoints in cache", setupCache: func(_ *ActiveRequest) { // Cache is empty }, - input: []types.Pod{podA, podB, podC}, - wantScores: map[types.Pod]float64{ - podA: 1, - podB: 1, - podC: 1, + input: []scheduling.Endpoint{endpointA, endpointB, endpointC}, + wantScores: map[scheduling.Endpoint]float64{ + endpointA: 1, + endpointB: 1, + endpointC: 1, }, }, { - name: "all pods in cache with different request counts", + name: "all endpoints in cache with different request counts", setupCache: func(s *ActiveRequest) { s.mutex.Lock() - s.podCounts["default/pod-a"] = 3 - s.podCounts["default/pod-b"] = 0 - s.podCounts["default/pod-c"] = 6 + s.endpointCounts["default/pod-a"] = 3 + s.endpointCounts["default/pod-b"] = 0 + s.endpointCounts["default/pod-c"] = 6 s.mutex.Unlock() }, - input: []types.Pod{podA, podB, podC}, - wantScores: map[types.Pod]float64{ - podA: 0.5, - podB: 1.0, - podC: 0.0, + input: []scheduling.Endpoint{endpointA, endpointB, endpointC}, + wantScores: map[scheduling.Endpoint]float64{ + endpointA: 0.5, + endpointB: 1.0, + endpointC: 0.0, }, }, { - name: "some pods in cache", + name: "some endpoints in cache", setupCache: func(s *ActiveRequest) { s.mutex.Lock() - s.podCounts["default/pod-a"] = 4 - s.podCounts["default/pod-c"] = 1 + s.endpointCounts["default/pod-a"] = 4 + s.endpointCounts["default/pod-c"] = 1 // pod-b not in cache s.mutex.Unlock() }, - input: []types.Pod{podA, podB, podC}, - wantScores: map[types.Pod]float64{ - podA: 0.0, - podB: 1.0, - podC: 0.75, + input: []scheduling.Endpoint{endpointA, endpointB, endpointC}, + wantScores: map[scheduling.Endpoint]float64{ + endpointA: 0.0, + endpointB: 1.0, + endpointC: 0.75, }, }, } @@ -95,9 +123,7 @@ func TestActiveRequestScorer_Score(t *testing.T) { got := scorer.Score(ctx, nil, nil, test.input) - if diff := cmp.Diff(test.wantScores, got); diff != "" { - t.Errorf("Unexpected output (-want +got): %v", diff) - } + 
assert.Equal(t, test.wantScores, got) }) } } @@ -106,124 +132,57 @@ func TestActiveRequestScorer_PreRequest(t *testing.T) { ctx := utils.NewTestContext(t) scorer := NewActiveRequest(ctx, nil) - podA := &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-a", Namespace: "default"}}, - MetricsState: &backendmetrics.MetricsState{ - WaitingQueueSize: 2, - }, - } - - request := &types.LLMRequest{ - RequestId: "test-request-1", - } - - schedulingResult := &types.SchedulingResult{ - ProfileResults: map[string]*types.ProfileRunResult{ - "test-profile": { - TargetPods: []types.Pod{podA}, - }, - }, - } + endpointA := newTestEndpoint("pod-a", 2) + endpointB := newTestEndpoint("pod-b", 0) - // First request - scorer.PreRequest(ctx, request, schedulingResult) + testProfile := "test-profile" - // Check cache and pod counts - compositeKey := "default/pod-a.test-request-1" - if !scorer.requestCache.Has(compositeKey) { - t.Errorf("Expected request to be in cache with key %s", compositeKey) - } + t.Run("First request", func(t *testing.T) { + request := newTestRequest("test-request-1") + schedulingResult := newTestSchedulingResult(map[string]scheduling.Endpoint{ + testProfile: endpointA, + }) - scorer.mutex.RLock() - count := scorer.podCounts["default/pod-a"] - scorer.mutex.RUnlock() - if count != 1 { - t.Errorf("Expected pod-a count to be 1, got %d", count) - } + scorer.PreRequest(ctx, request, schedulingResult) - // Second request with different ID to same pod - request2 := &types.LLMRequest{ - RequestId: "test-request-2", - } - schedulingResult2 := &types.SchedulingResult{ - ProfileResults: map[string]*types.ProfileRunResult{ - "test-profile": { - TargetPods: []types.Pod{podA}, - }, - }, - } + assert.True(t, scorer.requestCache.Has(request.RequestId), "Expected request to be in cache") + assert.Equal(t, 1, scorer.getPodCount(endpointA.GetMetadata().NamespacedName.String())) + }) - scorer.PreRequest(ctx, request2, schedulingResult2) + 
t.Run("Second request to multiple endpoints", func(t *testing.T) { + request := newTestRequest("test-request-2") + schedulingResult := newTestSchedulingResult(map[string]scheduling.Endpoint{ + testProfile: endpointA, + "prefill": endpointB, + }) - // Check incremented count - scorer.mutex.RLock() - count = scorer.podCounts["default/pod-a"] - scorer.mutex.RUnlock() - if count != 2 { - t.Errorf("Expected pod-a count to be 2, got %d", count) - } + scorer.PreRequest(ctx, request, schedulingResult) - // Check both requests are in cache - compositeKey2 := "default/pod-a.test-request-2" - if !scorer.requestCache.Has(compositeKey2) { - t.Errorf("Expected second request to be in cache with key %s", compositeKey2) - } + assert.True(t, scorer.requestCache.Has(request.RequestId), "Expected request to be in cache") + assert.Equal(t, 2, scorer.getPodCount(endpointA.GetMetadata().NamespacedName.String())) + assert.Equal(t, 1, scorer.getPodCount(endpointB.GetMetadata().NamespacedName.String())) + }) } func TestActiveRequestScorer_ResponseComplete(t *testing.T) { ctx := utils.NewTestContext(t) - scorer := NewActiveRequest(ctx, nil) - request := &types.LLMRequest{ - RequestId: "test-request-1", - } + endpointA := newTestEndpoint("pod-a", 2) + request := newTestRequest("test-request-1") - podA := &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-a", Namespace: "default"}}, - MetricsState: &backendmetrics.MetricsState{ - WaitingQueueSize: 2, - }, - } // Setup initial state: add request through PreRequest - schedulingResult := &types.SchedulingResult{ - ProfileResults: map[string]*types.ProfileRunResult{ - "test-profile": { - TargetPods: []types.Pod{podA}, - }, - }, - } - + schedulingResult := newTestSchedulingResult(map[string]scheduling.Endpoint{ + "test-profile": endpointA, + }) scorer.PreRequest(ctx, request, schedulingResult) - // Verify initial state - compositeKey := "default/pod-a.test-request-1" - if 
!scorer.requestCache.Has(compositeKey) { - t.Fatal("Request should be in cache before ResponseComplete") - } - - scorer.mutex.RLock() - initialCount := scorer.podCounts["default/pod-a"] - scorer.mutex.RUnlock() - if initialCount != 1 { - t.Fatalf("Expected initial count to be 1, got %d", initialCount) - } - - // Call PostResponse - scorer.ResponseComplete(ctx, request, &requestcontrol.Response{}, podA.GetPod()) + // Call ResponseComplete + scorer.ResponseComplete(ctx, request, &requestcontrol.Response{}, endpointA.GetMetadata()) - // Check request is removed from cache - if scorer.requestCache.Has(compositeKey) { - t.Errorf("Request should be removed from cache after ResponseComplete") - } - - // Check pod count is decremented and removed (since it was 1) - scorer.mutex.RLock() - _, exists := scorer.podCounts["default/pod-a"] - scorer.mutex.RUnlock() - if exists { - t.Errorf("Pod should be removed from podCounts when count reaches 0") - } + assert.False(t, scorer.requestCache.Has(request.RequestId)) + assert.False(t, scorer.hasPodCount(endpointA.GetMetadata().NamespacedName.String()), + "Pod count should be removed after decrement to zero") } func TestActiveRequestScorer_TTLExpiration(t *testing.T) { @@ -231,34 +190,19 @@ func TestActiveRequestScorer_TTLExpiration(t *testing.T) { // Use very short timeout for test params := &ActiveRequestParameters{RequestTimeout: "1s"} - scorer := NewActiveRequest(ctx, params) // 1 second timeout - - request := &types.LLMRequest{ - RequestId: "test-request-ttl", - } + scorer := NewActiveRequest(ctx, params) - podA := &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-a", Namespace: "default"}}, - } - - schedulingResult := &types.SchedulingResult{ - ProfileResults: map[string]*types.ProfileRunResult{ - "test-profile": { - TargetPods: []types.Pod{podA}, - }, - }, - } + endpointA := newTestEndpoint("pod-a", 0) + request := newTestRequest("test-request-ttl") + schedulingResult := 
newTestSchedulingResult(map[string]scheduling.Endpoint{ + "test-profile": endpointA, + }) // Add request scorer.PreRequest(ctx, request, schedulingResult) // Verify request is added - scorer.mutex.RLock() - initialCount := scorer.podCounts["default/pod-a"] - scorer.mutex.RUnlock() - if initialCount != 1 { - t.Fatalf("Expected initial count to be 1, got %d", initialCount) - } + require.Equal(t, 1, scorer.getPodCount("default/pod-a"), "Expected initial count to be 1") // Wait for TTL expiration time.Sleep(2 * time.Second) @@ -266,13 +210,9 @@ func TestActiveRequestScorer_TTLExpiration(t *testing.T) { // Trigger cleanup scorer.requestCache.DeleteExpired() - // Check that pod count is decremented due to TTL expiration - scorer.mutex.RLock() - _, exists := scorer.podCounts["default/pod-a"] - scorer.mutex.RUnlock() - if exists { - t.Errorf("Pod should be removed from podCounts after TTL expiration") - } + // Check that endpoint count is decremented due to TTL expiration + assert.False(t, scorer.hasPodCount("default/pod-a"), + "Pod should be removed from endpointCounts after TTL expiration") } func TestNewActiveRequestScorer_InvalidTimeout(t *testing.T) { @@ -282,9 +222,7 @@ func TestNewActiveRequestScorer_InvalidTimeout(t *testing.T) { scorer := NewActiveRequest(ctx, params) // Should use default timeout when invalid value is provided - if scorer == nil { - t.Error("Expected scorer to be created even with invalid timeout") - } + assert.NotNil(t, scorer, "Expected scorer to be created even with invalid timeout") } func TestActiveRequestScorer_TypedName(t *testing.T) { @@ -292,10 +230,7 @@ func TestActiveRequestScorer_TypedName(t *testing.T) { scorer := NewActiveRequest(ctx, nil) - typedName := scorer.TypedName() - if typedName.Type != ActiveRequestType { - t.Errorf("Expected type %s, got %s", ActiveRequestType, typedName.Type) - } + assert.Equal(t, ActiveRequestType, scorer.TypedName().Type) } func TestActiveRequestScorer_WithName(t *testing.T) { @@ -306,7 +241,5 @@ func 
TestActiveRequestScorer_WithName(t *testing.T) { scorer = scorer.WithName(testName) - if scorer.TypedName().Name != testName { - t.Errorf("Expected name %s, got %s", testName, scorer.TypedName().Name) - } + assert.Equal(t, testName, scorer.TypedName().Name) } diff --git a/pkg/plugins/scorer/load_aware.go b/pkg/plugins/scorer/load_aware.go index c4f86d0bb..4f3ef918b 100644 --- a/pkg/plugins/scorer/load_aware.go +++ b/pkg/plugins/scorer/load_aware.go @@ -6,10 +6,9 @@ import ( "fmt" "sigs.k8s.io/controller-runtime/pkg/log" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" - logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/common/util/logging" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling" ) const ( @@ -25,10 +24,10 @@ type loadAwareParameters struct { } // compile-time type assertion -var _ framework.Scorer = &LoadAware{} +var _ scheduling.Scorer = &LoadAware{} // LoadAwareFactory defines the factory function for the LoadAware -func LoadAwareFactory(name string, rawParameters json.RawMessage, handle plugins.Handle) (plugins.Plugin, error) { +func LoadAwareFactory(name string, rawParameters json.RawMessage, handle plugin.Handle) (plugin.Plugin, error) { parameters := loadAwareParameters{Threshold: QueueThresholdDefault} if rawParameters != nil { if err := json.Unmarshal(rawParameters, &parameters); err != nil { @@ -47,19 +46,19 @@ func NewLoadAware(ctx context.Context, queueThreshold int) *LoadAware { } return &LoadAware{ - typedName: plugins.TypedName{Type: LoadAwareType}, + typedName: plugin.TypedName{Type: LoadAwareType}, queueThreshold: float64(queueThreshold), } } // LoadAware scorer that is based on
load type LoadAware struct { - typedName plugins.TypedName + typedName plugin.TypedName queueThreshold float64 } // TypedName returns the typed name of the plugin. -func (s *LoadAware) TypedName() plugins.TypedName { +func (s *LoadAware) TypedName() plugin.TypedName { return s.typedName } @@ -69,6 +68,11 @@ func (s *LoadAware) WithName(name string) *LoadAware { return s } +// Category returns the preference the scorer applies when scoring candidate endpoints. +func (s *LoadAware) Category() scheduling.ScorerCategory { + return scheduling.Distribution +} + // Score scores the given pod in range of 0-1 // Currently metrics contains number of requests waiting in the queue, there is no information about number of requests // that can be processed in the given pod immediately. @@ -76,20 +80,20 @@ func (s *LoadAware) WithName(name string) *LoadAware { // Pod with requests in the queue will get score between 0.5 and 0. // Score 0 will get pod with number of requests in the queue equal to the threshold used in load-based filter // In the future, pods with additional capacity will get score higher than 0.5 -func (s *LoadAware) Score(_ context.Context, _ *types.CycleState, _ *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 { - scoredPods := make(map[types.Pod]float64) +func (s *LoadAware) Score(_ context.Context, _ *scheduling.CycleState, _ *scheduling.LLMRequest, endpoints []scheduling.Endpoint) map[scheduling.Endpoint]float64 { + scoredEndpoints := make(map[scheduling.Endpoint]float64) - for _, pod := range pods { - waitingRequests := float64(pod.GetMetrics().WaitingQueueSize) + for _, endpoint := range endpoints { + waitingRequests := float64(endpoint.GetMetrics().WaitingQueueSize) if waitingRequests == 0 { - scoredPods[pod] = 0.5 + scoredEndpoints[endpoint] = 0.5 } else { if waitingRequests > s.queueThreshold { waitingRequests = s.queueThreshold } - scoredPods[pod] = 0.5 * (1.0 - (waitingRequests / s.queueThreshold)) + scoredEndpoints[endpoint] = 0.5 * (1.0 - 
(waitingRequests / s.queueThreshold)) } } - return scoredPods + return scoredEndpoints } diff --git a/pkg/plugins/scorer/load_aware_test.go b/pkg/plugins/scorer/load_aware_test.go index e693e99b5..c3e43a94b 100644 --- a/pkg/plugins/scorer/load_aware_test.go +++ b/pkg/plugins/scorer/load_aware_test.go @@ -6,56 +6,57 @@ import ( "github.com/google/go-cmp/cmp" - k8stypes "k8s.io/apimachinery/pkg/types" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" - backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" // Import config for thresholds - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + k8stypes "k8s.io/apimachinery/pkg/types" // Import config for thresholds + fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling" "github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/scorer" "github.com/llm-d/llm-d-inference-scheduler/test/utils" ) func TestLoadBasedScorer(t *testing.T) { - podA := &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-a"}}, - MetricsState: &backendmetrics.MetricsState{ + endpointA := scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod-a"}}, + &fwkdl.Metrics{ WaitingQueueSize: 2, }, - } - podB := &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-b"}}, - MetricsState: &backendmetrics.MetricsState{ + nil, + ) + endpointB := scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod-b"}}, + &fwkdl.Metrics{ WaitingQueueSize: 0, }, - } - podC := &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-c"}}, - MetricsState: &backendmetrics.MetricsState{ + nil, + ) + endpointC := 
scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod-c"}}, + &fwkdl.Metrics{ WaitingQueueSize: 15, }, - } + nil, + ) tests := []struct { name string - scorer framework.Scorer - req *types.LLMRequest - input []types.Pod - wantScores map[types.Pod]float64 + scorer scheduling.Scorer + req *scheduling.LLMRequest + input []scheduling.Endpoint + wantScores map[scheduling.Endpoint]float64 }{ { name: "load based scorer", scorer: scorer.NewLoadAware(utils.NewTestContext(t), 10), - req: &types.LLMRequest{ + req: &scheduling.LLMRequest{ TargetModel: "critical", }, - input: []types.Pod{ - podA, podB, podC, + input: []scheduling.Endpoint{ + endpointA, endpointB, endpointC, }, - wantScores: map[types.Pod]float64{ - podA: 0.4, - podB: 0.5, - podC: 0, + wantScores: map[scheduling.Endpoint]float64{ + endpointA: 0.4, + endpointB: 0.5, + endpointC: 0, }, }, } diff --git a/pkg/plugins/scorer/no_hit_lru.go b/pkg/plugins/scorer/no_hit_lru.go index 417cf05a5..65182199a 100644 --- a/pkg/plugins/scorer/no_hit_lru.go +++ b/pkg/plugins/scorer/no_hit_lru.go @@ -7,24 +7,29 @@ import ( lru "github.com/hashicorp/golang-lru/v2" "sigs.k8s.io/controller-runtime/pkg/log" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/multi/prefix" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" - logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/common/util/logging" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/requestcontrol" + 
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/plugins/scheduling/scorer/prefix" ) const ( // NoHitLRUType is the type of the NoHitLRU scorer NoHitLRUType = "no-hit-lru-scorer" - // defaultLRUSize is the maximum number of pods we'll consider in the cache + // defaultLRUSize is the maximum number of endpoints we'll consider in the cache defaultLRUSize = 1024 + + // defaultPrefillProfile is the name of the prefill profile + // + // This is currently hardcoded until we have a defined proper config interface. + // (See also https://github.com/kubernetes-sigs/gateway-api-inference-extension/pull/2104/ ) + defaultPrefillProfile = "prefill" ) // compile-time type assertions -var _ framework.Scorer = &NoHitLRU{} +var _ scheduling.Scorer = &NoHitLRU{} var _ requestcontrol.PreRequest = &NoHitLRU{} // NoHitLRUParameters defines the parameters for the NoHitLRU scorer. @@ -36,7 +41,7 @@ type NoHitLRUParameters struct { // Defaults to "prefix-cache-scorer". PrefixPluginName string `json:"prefixPluginName"` - // LRUSize defines the maximum number of pods to track in the LRU cache. + // LRUSize defines the maximum number of endpoints to track in the LRU cache. 
LRUSize int `json:"lruSize"` } @@ -46,13 +51,13 @@ type coldRequestState struct { isCold bool } -// Clone implements the plugins.StateData interface -func (c *coldRequestState) Clone() plugins.StateData { +// Clone implements the plugin.StateData interface +func (c *coldRequestState) Clone() plugin.StateData { return &coldRequestState{isCold: c.isCold} } // NoHitLRUFactory defines the factory function for the NoHitLRU -func NoHitLRUFactory(name string, rawParameters json.RawMessage, handle plugins.Handle) (plugins.Plugin, error) { +func NoHitLRUFactory(name string, rawParameters json.RawMessage, handle plugin.Handle) (plugin.Plugin, error) { parameters := NoHitLRUParameters{} if rawParameters != nil { if err := json.Unmarshal(rawParameters, &parameters); err != nil { @@ -95,25 +100,25 @@ func NewNoHitLRU(ctx context.Context, params *NoHitLRUParameters) *NoHitLRU { } return &NoHitLRU{ - typedName: plugins.TypedName{Type: NoHitLRUType}, + typedName: plugin.TypedName{Type: NoHitLRUType}, lruCache: lruCache, - prefixPluginTypedName: plugins.TypedName{Type: prefixPluginType, Name: prefixPluginName}, - pluginState: plugins.NewPluginState(ctx), + prefixPluginTypedName: plugin.TypedName{Type: prefixPluginType, Name: prefixPluginName}, + pluginState: plugin.NewPluginState(ctx), } } -// NoHitLRU scorer that favors pods that were least recently used for cold requests. +// NoHitLRU scorer that favors endpoints that were least recently used for cold requests. // This can help evenly distribute cache growth, since cold requests result in more // new KV blocks.
type NoHitLRU struct { - typedName plugins.TypedName - lruCache *lru.Cache[string, struct{}] // pod name -> dummy value (we only care about order) - prefixPluginTypedName plugins.TypedName - pluginState *plugins.PluginState + typedName plugin.TypedName + lruCache *lru.Cache[string, struct{}] // endpoint name -> dummy value (we only care about order) + prefixPluginTypedName plugin.TypedName + pluginState *plugin.PluginState } // TypedName returns the typed name of the plugin. -func (s *NoHitLRU) TypedName() plugins.TypedName { +func (s *NoHitLRU) TypedName() plugin.TypedName { return s.typedName } @@ -123,14 +128,19 @@ func (s *NoHitLRU) WithName(name string) *NoHitLRU { return s } +// Category returns the preference the scorer applies when scoring candidate endpoints. +func (s *NoHitLRU) Category() scheduling.ScorerCategory { + return scheduling.Distribution +} + // isColdRequest determines if a request is cold by reading the prefix cache state. // Returns true if no prefix cache hits were found, or if prefix cache state is unavailable. 
-func (s *NoHitLRU) isColdRequest(ctx context.Context, cycleState *types.CycleState) bool { +func (s *NoHitLRU) isColdRequest(ctx context.Context, cycleState *scheduling.CycleState) bool { logger := log.FromContext(ctx).V(logutil.DEBUG) // Read prefix cache state to determine if this is a cold request // This is treated as an optimization - if the state isn't available, we assume cold request - prefixState, err := types.ReadCycleStateKey[*prefix.SchedulingContextState](cycleState, plugins.StateKey(s.prefixPluginTypedName.String())) + prefixState, err := scheduling.ReadCycleStateKey[*prefix.SchedulingContextState](cycleState, plugin.StateKey(s.prefixPluginTypedName.String())) if err != nil { logger.Info("No prefix cache state found, treating as cold request for LRU optimization", "error", err) @@ -141,17 +151,17 @@ func (s *NoHitLRU) isColdRequest(ctx context.Context, cycleState *types.CycleSta return len(prefixState.PrefixCacheServers) == 0 } -// scoreNeutral returns neutral scores (0.5) for all pods. +// scoreNeutral returns neutral scores (0.5) for all endpoints. // Used when a request has cache hits and LRU optimization should not apply. -func (s *NoHitLRU) scoreNeutral(pods []types.Pod) map[types.Pod]float64 { - scoredPods := make(map[types.Pod]float64, len(pods)) - for _, pod := range pods { - scoredPods[pod] = 0.5 +func (s *NoHitLRU) scoreNeutral(endpoints []scheduling.Endpoint) map[scheduling.Endpoint]float64 { + scoredEndpoints := make(map[scheduling.Endpoint]float64, len(endpoints)) + for _, endpoint := range endpoints { + scoredEndpoints[endpoint] = 0.5 } - return scoredPods + return scoredEndpoints } -// getLRUPositions returns a map of pod names to their LRU position. +// getLRUPositions returns a map of endpoint names to their LRU position. // Position 0 represents the oldest (least recently used) entry. 
func (s *NoHitLRU) getLRUPositions() map[string]int { // Get all keys from LRU cache in order (oldest first) @@ -165,105 +175,105 @@ func (s *NoHitLRU) getLRUPositions() map[string]int { return lruPosition } -// partitionPodsByUsage separates pods into those that have received cold requests +// partitionPodsByUsage separates endpoints into those that have received cold requests // (usedPods) and those that have never received cold requests (neverUsedPods). -func (s *NoHitLRU) partitionPodsByUsage(pods []types.Pod, lruPosition map[string]int) (usedPods, neverUsedPods []types.Pod) { - for _, pod := range pods { - podName := pod.GetPod().NamespacedName.String() - if _, exists := lruPosition[podName]; exists { - usedPods = append(usedPods, pod) +func (s *NoHitLRU) partitionPodsByUsage(endpoints []scheduling.Endpoint, lruPosition map[string]int) (usedEndpoints, neverUsedEndpoints []scheduling.Endpoint) { + for _, endpoint := range endpoints { + endpointName := endpoint.GetMetadata().NamespacedName.String() + if _, exists := lruPosition[endpointName]; exists { + usedEndpoints = append(usedEndpoints, endpoint) } else { - neverUsedPods = append(neverUsedPods, pod) + neverUsedEndpoints = append(neverUsedEndpoints, endpoint) } } - return usedPods, neverUsedPods + return usedEndpoints, neverUsedEndpoints } -// scoreNeverUsedPods assigns scores to pods that have never received a cold request. -// The first never-used pod gets the highest score (1.0), with subsequent pods +// scoreNeverUsedEndpoints assigns scores to endpoints that have never received a cold request. +// The first never-used endpoint gets the highest score (1.0), with subsequent endpoints // receiving progressively lower scores. 
-func (s *NoHitLRU) scoreNeverUsedPods(scoredPods map[types.Pod]float64, neverUsedPods []types.Pod, totalPods int) { +func (s *NoHitLRU) scoreNeverUsedPods(scoredPods map[scheduling.Endpoint]float64, neverUsedPods []scheduling.Endpoint, totalEndpoints int) { // Avoid possibility of dividing by zero. - if totalPods <= 1 { + if totalEndpoints <= 1 { return } - for i, pod := range neverUsedPods { - score := 1.0 - float64(i)/float64(totalPods-1) - scoredPods[pod] = score + for i, endpoint := range neverUsedPods { + score := 1.0 - float64(i)/float64(totalEndpoints-1) + scoredPods[endpoint] = score } } -// scoreUsedPods assigns scores to pods based on their LRU position. +// scoreUsedPods assigns scores to endpoints based on their LRU position. // Pods that were least recently used for cold requests receive higher scores. -func (s *NoHitLRU) scoreUsedPods(scoredPods map[types.Pod]float64, usedPods []types.Pod, lruPosition map[string]int, neverUsedCount, totalPods int) { +func (s *NoHitLRU) scoreUsedPods(scoredEndpoints map[scheduling.Endpoint]float64, usedPods []scheduling.Endpoint, lruPosition map[string]int, neverUsedCount, totalEndpoints int) { // Avoid possibility of dividing by zero. - if totalPods <= 1 { + if totalEndpoints <= 1 { return } - for _, pod := range usedPods { - podName := pod.GetPod().NamespacedName.String() - lruPos := lruPosition[podName] + for _, endpoint := range usedPods { + endpointName := endpoint.GetMetadata().NamespacedName.String() + lruPos := lruPosition[endpointName] // LRU keys are oldest to newest so rank 0 = oldest - // The never used pod count is added to the rank so that - // a never-used pod will always have the highest score. + // The never used endpoint count is added to the rank so that + // a never-used endpoint will always have the highest score. 
rank := neverUsedCount + lruPos - score := 1.0 - float64(rank)/float64(totalPods-1) + score := 1.0 - float64(rank)/float64(totalEndpoints-1) if score < 0 { score = 0 } - scoredPods[pod] = score + scoredEndpoints[endpoint] = score } } -// scoreColdRequestByLRU scores pods based on their LRU position for cold requests. +// scoreColdRequestByLRU scores endpoints based on their LRU position for cold requests. // Pods that have never received a cold request get the highest scores. -// Among previously used pods, least recently used ones get higher scores. -func (s *NoHitLRU) scoreColdRequestByLRU(pods []types.Pod) map[types.Pod]float64 { - scoredPods := make(map[types.Pod]float64, len(pods)) - totalPods := len(pods) +// Among previously used endpoints, least recently used ones get higher scores. +func (s *NoHitLRU) scoreColdRequestByLRU(endpoints []scheduling.Endpoint) map[scheduling.Endpoint]float64 { + scoredEndpoints := make(map[scheduling.Endpoint]float64, len(endpoints)) + totalEndpoints := len(endpoints) // Avoid possibility of dividing by zero. - if totalPods == 1 { - scoredPods[pods[0]] = 1.0 - return scoredPods + if totalEndpoints == 1 { + scoredEndpoints[endpoints[0]] = 1.0 + return scoredEndpoints } lruPosition := s.getLRUPositions() - usedPods, neverUsedPods := s.partitionPodsByUsage(pods, lruPosition) + usedEndpoints, neverUsedEndpoints := s.partitionPodsByUsage(endpoints, lruPosition) - s.scoreNeverUsedPods(scoredPods, neverUsedPods, totalPods) - s.scoreUsedPods(scoredPods, usedPods, lruPosition, len(neverUsedPods), totalPods) + s.scoreNeverUsedPods(scoredEndpoints, neverUsedEndpoints, totalEndpoints) + s.scoreUsedPods(scoredEndpoints, usedEndpoints, lruPosition, len(neverUsedEndpoints), totalEndpoints) - return scoredPods + return scoredEndpoints } -// Score scores the given pods based on LRU for cold requests. -// For cache hits, returns neutral scores (0.5) for all pods. -// For cache misses, ranks pods by their LRU order. 
-// - LRU ordering is with respect to when a pod last received a cold request. -// - Least recently used (or never used) pods get highest score (1.0) -// - Most recently used pods get lowest score (approaching 0.0) -func (s *NoHitLRU) Score(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 { +// Score scores the given endpoints based on LRU for cold requests. +// For cache hits, returns neutral scores (0.5) for all endpoints. +// For cache misses, ranks endpoints by their LRU order. +// - LRU ordering is with respect to when a endpoint last received a cold request. +// - Least recently used (or never used) endpoints get highest score (1.0) +// - Most recently used endpoints get lowest score (approaching 0.0) +func (s *NoHitLRU) Score(ctx context.Context, cycleState *scheduling.CycleState, request *scheduling.LLMRequest, endpoints []scheduling.Endpoint) map[scheduling.Endpoint]float64 { logger := log.FromContext(ctx).V(logutil.DEBUG) isCold := s.isColdRequest(ctx, cycleState) // Store the cold request state in plugin state for PreRequest to use coldState := &coldRequestState{isCold: isCold} - s.pluginState.Write(request.RequestId, plugins.StateKey(s.typedName.String()), coldState) + s.pluginState.Write(request.RequestId, plugin.StateKey(s.typedName.String()), coldState) if !isCold { logger.Info("Cache hit detected, returning neutral scores") - return s.scoreNeutral(pods) + return s.scoreNeutral(endpoints) } - logger.Info("Cold request detected, scoring pods by LRU") - return s.scoreColdRequestByLRU(pods) + logger.Info("Cold request detected, scoring endpoints by LRU") + return s.scoreColdRequestByLRU(endpoints) } -// PreRequest is called before a request is sent to the target pod. -// For cold requests, it updates the LRU cache to track which pods have been used recently. 
-func (s *NoHitLRU) PreRequest(ctx context.Context, request *types.LLMRequest, schedulingResult *types.SchedulingResult) { +// PreRequest is called before a request is sent to the target endpoint. +// For cold requests, it updates the LRU cache to track which endpoints have been used recently. +func (s *NoHitLRU) PreRequest(ctx context.Context, request *scheduling.LLMRequest, schedulingResult *scheduling.SchedulingResult) { logger := log.FromContext(ctx).V(logutil.DEBUG) if schedulingResult == nil || len(schedulingResult.ProfileResults) == 0 { @@ -272,7 +282,7 @@ func (s *NoHitLRU) PreRequest(ctx context.Context, request *types.LLMRequest, sc } // Read the cold request state we stored in Score - coldState, err := plugins.ReadPluginStateKey[*coldRequestState](s.pluginState, request.RequestId, plugins.StateKey(s.typedName.String())) + coldState, err := plugin.ReadPluginStateKey[*coldRequestState](s.pluginState, request.RequestId, plugin.StateKey(s.typedName.String())) // After fetching the cold state, drop it from the plugin state immediately (otherwise it will hang around until it becomes stale). 
s.pluginState.Delete(request.RequestId) @@ -286,19 +296,23 @@ func (s *NoHitLRU) PreRequest(ctx context.Context, request *types.LLMRequest, sc return } - // Get the primary profile's target pod - primaryProfile := schedulingResult.ProfileResults[schedulingResult.PrimaryProfileName] - if primaryProfile == nil || len(primaryProfile.TargetPods) == 0 { - logger.Info("No target pod in primary profile") - return + if targetProfile, ok := schedulingResult.ProfileResults[schedulingResult.PrimaryProfileName]; ok && targetProfile != nil && len(targetProfile.TargetEndpoints) != 0 { + s.moveTargetPodToFront(ctx, request, targetProfile, schedulingResult.PrimaryProfileName) } + if targetProfile, ok := schedulingResult.ProfileResults[defaultPrefillProfile]; ok && targetProfile != nil && len(targetProfile.TargetEndpoints) != 0 { + s.moveTargetPodToFront(ctx, request, targetProfile, defaultPrefillProfile) + } +} + +func (s *NoHitLRU) moveTargetPodToFront(ctx context.Context, request *scheduling.LLMRequest, targetProfile *scheduling.ProfileRunResult, profileName string) { + logger := log.FromContext(ctx).V(logutil.DEBUG) - targetPod := primaryProfile.TargetPods[0] - podName := targetPod.GetPod().NamespacedName.String() + targetPod := targetProfile.TargetEndpoints[0] + endpointName := targetPod.GetMetadata().NamespacedName.String() - // Move the pod to the front of the LRU. + // Move the endpoint to the front of the LRU. 
var present struct{} // dummy value - s.lruCache.Add(podName, present) + s.lruCache.Add(endpointName, present) - logger.Info("Updated LRU cache for cold request", "pod", podName, "requestId", request.RequestId) + logger.Info("Updated LRU cache for cold request", "profile", profileName, "endpoint", endpointName, "requestId", request.RequestId) } diff --git a/pkg/plugins/scorer/no_hit_lru_test.go b/pkg/plugins/scorer/no_hit_lru_test.go index 6890c998c..2af62b021 100644 --- a/pkg/plugins/scorer/no_hit_lru_test.go +++ b/pkg/plugins/scorer/no_hit_lru_test.go @@ -9,49 +9,47 @@ import ( "github.com/google/go-cmp/cmp" k8stypes "k8s.io/apimachinery/pkg/types" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" - backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/multi/prefix" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/plugins/scheduling/scorer/prefix" "github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/scorer" "github.com/llm-d/llm-d-inference-scheduler/test/utils" ) -var _ plugins.Handle = &fakeHandle{} +var _ plugin.Handle = &fakeHandle{} type fakeHandle struct { ctx context.Context - plugins map[string]plugins.Plugin + plugins map[string]plugin.Plugin } func newFakeHandle(ctx context.Context) *fakeHandle { - return &fakeHandle{ctx: ctx, plugins: map[string]plugins.Plugin{}} + return &fakeHandle{ctx: ctx, plugins: map[string]plugin.Plugin{}} } func (h 
*fakeHandle) Context() context.Context { return h.ctx } -func (h *fakeHandle) Plugin(name string) plugins.Plugin { +func (h *fakeHandle) Plugin(name string) plugin.Plugin { return h.plugins[name] } -func (h *fakeHandle) AddPlugin(name string, plugin plugins.Plugin) { +func (h *fakeHandle) AddPlugin(name string, plugin plugin.Plugin) { h.plugins[name] = plugin } -func (h *fakeHandle) GetAllPlugins() []plugins.Plugin { - result := make([]plugins.Plugin, 0, len(h.plugins)) +func (h *fakeHandle) GetAllPlugins() []plugin.Plugin { + result := make([]plugin.Plugin, 0, len(h.plugins)) for _, plugin := range h.plugins { result = append(result, plugin) } return result } -func (h *fakeHandle) GetAllPluginsWithNames() map[string]plugins.Plugin { +func (h *fakeHandle) GetAllPluginsWithNames() map[string]plugin.Plugin { return h.plugins } @@ -60,10 +58,10 @@ func (h *fakeHandle) PodList() []k8stypes.NamespacedName { } type stubPlugin struct { - name plugins.TypedName + name plugin.TypedName } -func (p *stubPlugin) TypedName() plugins.TypedName { +func (p *stubPlugin) TypedName() plugin.TypedName { return p.name } @@ -84,7 +82,7 @@ func TestNoHitLRUFactoryDependencyValidation(t *testing.T) { name: "prefix plugin present - should work", handle: func() *fakeHandle { h := newFakeHandle(utils.NewTestContext(t)) - h.AddPlugin(prefix.PrefixCachePluginType, &stubPlugin{name: plugins.TypedName{Type: prefix.PrefixCachePluginType, Name: prefix.PrefixCachePluginType}}) + h.AddPlugin(prefix.PrefixCachePluginType, &stubPlugin{name: plugin.TypedName{Type: prefix.PrefixCachePluginType, Name: prefix.PrefixCachePluginType}}) return h }(), expectError: false, @@ -123,87 +121,90 @@ func TestNoHitLRUFactoryDependencyValidation(t *testing.T) { } func TestNoHitLRUScorer(t *testing.T) { - podA := &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-a"}}, - MetricsState: &backendmetrics.MetricsState{}, - } - podB := &types.PodMetrics{ - Pod: 
&backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-b"}}, - MetricsState: &backendmetrics.MetricsState{}, - } - podC := &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-c"}}, - MetricsState: &backendmetrics.MetricsState{}, - } + endpointA := scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod-a"}}, + &fwkdl.Metrics{}, + nil, + ) + endpointB := scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod-b"}}, + &fwkdl.Metrics{}, + nil, + ) + endpointC := scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod-c"}}, + &fwkdl.Metrics{}, + nil, + ) tests := []struct { name string - scorer framework.Scorer - req *types.LLMRequest - input []types.Pod + scorer scheduling.Scorer + req *scheduling.LLMRequest + input []scheduling.Endpoint prefixState *prefix.SchedulingContextState - wantScores map[types.Pod]float64 + wantScores map[scheduling.Endpoint]float64 description string }{ { - name: "cold request - all pods never used", + name: "cold request - all endpoints never used", scorer: scorer.NewNoHitLRU(utils.NewTestContext(t), nil), - req: &types.LLMRequest{ + req: &scheduling.LLMRequest{ TargetModel: "test-model", }, - input: []types.Pod{podA, podB, podC}, + input: []scheduling.Endpoint{endpointA, endpointB, endpointC}, prefixState: &prefix.SchedulingContextState{ PrefixCacheServers: make(map[prefix.ServerID]int), // empty = cold request }, - wantScores: map[types.Pod]float64{ - podA: 1.0, // All never-used pods get high scores - podB: 0.5, - podC: 0.0, + wantScores: map[scheduling.Endpoint]float64{ + endpointA: 1.0, // All never-used endpoints get high scores + endpointB: 0.5, + endpointC: 0.0, }, - description: "Never-used pods should get high scores for cold requests", + description: "Never-used endpoints should get high scores for cold requests", }, { name: "cache hit - neutral 
scores", scorer: scorer.NewNoHitLRU(utils.NewTestContext(t), nil), - req: &types.LLMRequest{ + req: &scheduling.LLMRequest{ TargetModel: "test-model", }, - input: []types.Pod{podA, podB, podC}, + input: []scheduling.Endpoint{endpointA, endpointB, endpointC}, prefixState: &prefix.SchedulingContextState{ PrefixCacheServers: map[prefix.ServerID]int{ {Name: "server1", Namespace: "default"}: 5, // non-empty = cache hit }, }, - wantScores: map[types.Pod]float64{ - podA: 0.5, // All pods get neutral scores for cache hits - podB: 0.5, - podC: 0.5, + wantScores: map[scheduling.Endpoint]float64{ + endpointA: 0.5, // All endpoints get neutral scores for cache hits + endpointB: 0.5, + endpointC: 0.5, }, description: "Cache hits should return neutral scores", }, { - name: "single pod - max score", + name: "single endpoint - max score", scorer: scorer.NewNoHitLRU(utils.NewTestContext(t), nil), - req: &types.LLMRequest{ + req: &scheduling.LLMRequest{ TargetModel: "test-model", }, - input: []types.Pod{podA}, + input: []scheduling.Endpoint{endpointA}, prefixState: &prefix.SchedulingContextState{ PrefixCacheServers: make(map[prefix.ServerID]int), // empty = cold request }, - wantScores: map[types.Pod]float64{ - podA: 1.0, // Single pod gets max score + wantScores: map[scheduling.Endpoint]float64{ + endpointA: 1.0, // Single endpoint gets max score }, - description: "Single pod should get maximum score", + description: "Single endpoint should get maximum score", }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { // Create cycle state and set prefix state - cycleState := &types.CycleState{} + cycleState := &scheduling.CycleState{} if test.prefixState != nil { - cycleState.Write(plugins.StateKey(plugins.TypedName{Type: prefix.PrefixCachePluginType, + cycleState.Write(plugin.StateKey(plugin.TypedName{Type: prefix.PrefixCachePluginType, Name: prefix.PrefixCachePluginType}.String()), test.prefixState) } @@ -221,42 +222,44 @@ func TestNoHitLRUBasicFunctionality(t 
*testing.T) { scorer := scorer.NewNoHitLRU(ctx, nil) - podA := &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-a"}}, - MetricsState: &backendmetrics.MetricsState{}, - } - podB := &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-b"}}, - MetricsState: &backendmetrics.MetricsState{}, - } + endpointA := scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod-a"}}, + &fwkdl.Metrics{}, + nil, + ) + endpointB := scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod-b"}}, + &fwkdl.Metrics{}, + nil, + ) - pods := []types.Pod{podA, podB} + endpoints := []scheduling.Endpoint{endpointA, endpointB} // Test basic scoring for cold request (no crashes, returns valid scores) coldPrefixState := &prefix.SchedulingContextState{ PrefixCacheServers: make(map[prefix.ServerID]int), // empty = cold request } - cycleState := &types.CycleState{} - cycleState.Write(plugins.StateKey(plugins.TypedName{Type: prefix.PrefixCachePluginType, + cycleState := &scheduling.CycleState{} + cycleState.Write(plugin.StateKey(plugin.TypedName{Type: prefix.PrefixCachePluginType, Name: prefix.PrefixCachePluginType}.String()), coldPrefixState) - scores := scorer.Score(ctx, cycleState, &types.LLMRequest{}, pods) + scores := scorer.Score(ctx, cycleState, &scheduling.LLMRequest{}, endpoints) - // Should return scores for all pods + // Should return scores for all endpoints if len(scores) != 2 { t.Errorf("Expected 2 scores, got %d", len(scores)) } // All scores should be valid (between 0 and 1) - for pod, score := range scores { + for endpoint, score := range scores { if score < 0 || score > 1 { - t.Errorf("Invalid score %f for pod %s", score, pod.GetPod().NamespacedName.String()) + t.Errorf("Invalid score %f for endpoint %s", score, endpoint.GetMetadata().NamespacedName.String()) } } - // For never-used pods, should have different scores (to 
provide ordering) - if scores[podA] == scores[podB] { - t.Errorf("Expected different scores for different pods, both got %f", scores[podA]) + // For never-used endpoints, should have different scores (to provide ordering) + if scores[endpointA] == scores[endpointB] { + t.Errorf("Expected different scores for different endpoints, both got %f", scores[endpointA]) } } @@ -264,16 +267,17 @@ func TestNoPrefixCacheStateFound(t *testing.T) { ctx := utils.NewTestContext(t) scorer := scorer.NewNoHitLRU(ctx, nil) - podA := &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-a"}}, - MetricsState: &backendmetrics.MetricsState{}, - } - pods := []types.Pod{podA} - cycleState := &types.CycleState{} + endpointA := scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod-a"}}, + &fwkdl.Metrics{}, + nil, + ) + endpoints := []scheduling.Endpoint{endpointA} + cycleState := &scheduling.CycleState{} - scores := scorer.Score(ctx, cycleState, &types.LLMRequest{}, pods) + scores := scorer.Score(ctx, cycleState, &scheduling.LLMRequest{}, endpoints) - if scores[podA] != 1.0 { + if scores[endpointA] != 1.0 { t.Errorf("Failure to find a prefix cache should result in scoring as a cold request.") } } @@ -282,120 +286,123 @@ func TestNoHitLRUPreferLeastRecentlyUsedAfterColdRequests(t *testing.T) { ctx := utils.NewTestContext(t) scorer := scorer.NewNoHitLRU(ctx, nil) - podA := &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-a", Namespace: "default"}}, - MetricsState: &backendmetrics.MetricsState{}, - } - podB := &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-b", Namespace: "default"}}, - MetricsState: &backendmetrics.MetricsState{}, - } - podC := &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-c", Namespace: "default"}}, - MetricsState: &backendmetrics.MetricsState{}, - } - pods := 
[]types.Pod{podA, podB, podC} + endpointA := scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod-a", Namespace: "default"}}, + &fwkdl.Metrics{}, + nil, + ) + endpointB := scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod-b", Namespace: "default"}}, + &fwkdl.Metrics{}, + nil, + ) + endpointC := scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod-c", Namespace: "default"}}, + &fwkdl.Metrics{}, + nil, + ) + endpoints := []scheduling.Endpoint{endpointA, endpointB, endpointC} primaryProfile := "primary-profile" - toPrefixState := func(entries map[prefix.ServerID]int) *types.CycleState { - cycle := &types.CycleState{} - cycle.Write(plugins.StateKey(plugins.TypedName{Type: prefix.PrefixCachePluginType, + toPrefixState := func(entries map[prefix.ServerID]int) *scheduling.CycleState { + cycle := &scheduling.CycleState{} + cycle.Write(plugin.StateKey(plugin.TypedName{Type: prefix.PrefixCachePluginType, Name: prefix.PrefixCachePluginType}.String()), &prefix.SchedulingContextState{PrefixCacheServers: entries}) return cycle } - requestToPod := func(target types.Pod) *types.SchedulingResult { - return &types.SchedulingResult{ + requestToEndpoint := func(target scheduling.Endpoint) *scheduling.SchedulingResult { + return &scheduling.SchedulingResult{ PrimaryProfileName: primaryProfile, - ProfileResults: map[string]*types.ProfileRunResult{ + ProfileResults: map[string]*scheduling.ProfileRunResult{ primaryProfile: { - TargetPods: []types.Pod{target}, + TargetEndpoints: []scheduling.Endpoint{target}, }, }, } } // Test LRU behavior indirectly through scoring rather than internal state - assertHighestScoredPod := func(expectedPod types.Pod, testName string) { + assertHighestScoredPod := func(expectedEndpoint scheduling.Endpoint, testName string) { t.Helper() - coldReq := &types.LLMRequest{RequestId: testName + "-scoring-check"} - scores 
:= scorer.Score(ctx, toPrefixState(make(map[prefix.ServerID]int)), coldReq, pods) + coldReq := &scheduling.LLMRequest{RequestId: testName + "-scoring-check"} + scores := scorer.Score(ctx, toPrefixState(make(map[prefix.ServerID]int)), coldReq, endpoints) highestScore := -1.0 - var highestPod types.Pod - for pod, score := range scores { + var highestEndpoint scheduling.Endpoint + for endpoint, score := range scores { if score > highestScore { highestScore = score - highestPod = pod + highestEndpoint = endpoint } } - if highestPod.GetPod().NamespacedName.String() != expectedPod.GetPod().NamespacedName.String() { + if highestEndpoint.GetMetadata().NamespacedName.String() != expectedEndpoint.GetMetadata().NamespacedName.String() { t.Fatalf("expected %s to have highest score for LRU behavior, but %s had highest score (%f). All scores: %+v", - expectedPod.GetPod().NamespacedName.String(), - highestPod.GetPod().NamespacedName.String(), + expectedEndpoint.GetMetadata().NamespacedName.String(), + highestEndpoint.GetMetadata().NamespacedName.String(), highestScore, scores) } } t.Run("initial cold request seeds cache", func(_ *testing.T) { - coldReqA := &types.LLMRequest{RequestId: "cold-1"} - scorer.Score(ctx, toPrefixState(make(map[prefix.ServerID]int)), coldReqA, pods) - scorer.PreRequest(ctx, coldReqA, requestToPod(podA)) - // After podA handles a cold request, other pods should score higher for new cold requests - assertHighestScoredPod(podB, "after-podA-used") + coldReqA := &scheduling.LLMRequest{RequestId: "cold-1"} + scorer.Score(ctx, toPrefixState(make(map[prefix.ServerID]int)), coldReqA, endpoints) + scorer.PreRequest(ctx, coldReqA, requestToEndpoint(endpointA)) + // After endpointA handles a cold request, other endpoints should score higher for new cold requests + assertHighestScoredPod(endpointB, "after-endpointA-used") }) - t.Run("unused pods rank above existing ones", func(t *testing.T) { - coldReqCheck := &types.LLMRequest{RequestId: "cold-check"} - coldScores 
:= scorer.Score(ctx, toPrefixState(make(map[prefix.ServerID]int)), coldReqCheck, pods) - if coldScores[podB] <= coldScores[podA] { - t.Fatalf("expected pod-b to outrank pod-a after pod-a handled previous cold request, scores=%+v", coldScores) + t.Run("unused endpoints rank above existing ones", func(t *testing.T) { + coldReqCheck := &scheduling.LLMRequest{RequestId: "cold-check"} + coldScores := scorer.Score(ctx, toPrefixState(make(map[prefix.ServerID]int)), coldReqCheck, endpoints) + if coldScores[endpointB] <= coldScores[endpointA] { + t.Fatalf("expected endpoint-b to outrank endpoint-a after endpoint-a handled previous cold request, scores=%+v", coldScores) } - if coldScores[podB] != 1.0 { - t.Fatalf("expected pod-b to score 1.0, scores=%+v", coldScores) + if coldScores[endpointB] != 1.0 { + t.Fatalf("expected endpoint-b to score 1.0, scores=%+v", coldScores) } - if coldScores[podC] != 0.5 { - t.Fatalf("expected pod-c to score 0.5, scores=%+v", coldScores) + if coldScores[endpointC] != 0.5 { + t.Fatalf("expected endpoint-c to score 0.5, scores=%+v", coldScores) } }) t.Run("warm request leaves LRU untouched", func(t *testing.T) { - warmReq := &types.LLMRequest{RequestId: "warm-1"} + warmReq := &scheduling.LLMRequest{RequestId: "warm-1"} warmState := map[prefix.ServerID]int{ {Name: "server1", Namespace: "default"}: 1, } - warmScores := scorer.Score(ctx, toPrefixState(warmState), warmReq, pods) + warmScores := scorer.Score(ctx, toPrefixState(warmState), warmReq, endpoints) for _, score := range warmScores { if score != 0.5 { t.Fatalf("expected neutral score for warm request, got %f", score) } } - scorer.PreRequest(ctx, warmReq, requestToPod(podB)) - postWarmReq := &types.LLMRequest{RequestId: "cold-after-warm"} - postWarmScores := scorer.Score(ctx, toPrefixState(make(map[prefix.ServerID]int)), postWarmReq, pods) - if postWarmScores[podB] <= postWarmScores[podA] { + scorer.PreRequest(ctx, warmReq, requestToEndpoint(endpointB)) + postWarmReq := 
&scheduling.LLMRequest{RequestId: "cold-after-warm"} + postWarmScores := scorer.Score(ctx, toPrefixState(make(map[prefix.ServerID]int)), postWarmReq, endpoints) + if postWarmScores[endpointB] <= postWarmScores[endpointA] { t.Fatalf("expected warm request to leave ordering unchanged, scores=%+v", postWarmScores) } }) - t.Run("second cold request rotates to podB", func(_ *testing.T) { - // Simulate podB handling a cold request - coldReqB := &types.LLMRequest{RequestId: "cold-2"} - scorer.Score(ctx, toPrefixState(make(map[prefix.ServerID]int)), coldReqB, pods) - scorer.PreRequest(ctx, coldReqB, requestToPod(podB)) - // Now podC should score highest since both podA and podB have been used - assertHighestScoredPod(podC, "after-podB-used") + t.Run("second cold request rotates to endpointB", func(_ *testing.T) { + // Simulate endpointB handling a cold request + coldReqB := &scheduling.LLMRequest{RequestId: "cold-2"} + scorer.Score(ctx, toPrefixState(make(map[prefix.ServerID]int)), coldReqB, endpoints) + scorer.PreRequest(ctx, coldReqB, requestToEndpoint(endpointB)) + // Now endpointC should score highest since both endpointA and endpointB have been used + assertHighestScoredPod(endpointC, "after-endpointB-used") }) - t.Run("third cold request rotates back to podA", func(_ *testing.T) { - // Simulate podC handling a cold request - coldReqC := &types.LLMRequest{RequestId: "cold-3"} - scorer.Score(ctx, toPrefixState(make(map[prefix.ServerID]int)), coldReqC, pods) - scorer.PreRequest(ctx, coldReqC, requestToPod(podC)) - // Now podA should score highest again (LRU rotation) - assertHighestScoredPod(podA, "after-podC-used") + t.Run("third cold request rotates back to endpointA", func(_ *testing.T) { + // Simulate endpointC handling a cold request + coldReqC := &scheduling.LLMRequest{RequestId: "cold-3"} + scorer.Score(ctx, toPrefixState(make(map[prefix.ServerID]int)), coldReqC, endpoints) + scorer.PreRequest(ctx, coldReqC, requestToEndpoint(endpointC)) + // Now endpointA should 
score highest again (LRU rotation) + assertHighestScoredPod(endpointA, "after-endpointC-used") }) } @@ -403,55 +410,177 @@ func TestNoHitLRUEdgeCases(t *testing.T) { ctx := utils.NewTestContext(t) scorer := scorer.NewNoHitLRU(ctx, nil) - podA := &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-a"}}, - MetricsState: &backendmetrics.MetricsState{}, - } + endpointA := scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod-a"}}, + &fwkdl.Metrics{}, + nil, + ) - t.Run("empty pods list", func(t *testing.T) { - emptyPods := []types.Pod{} - cycleState := &types.CycleState{} - cycleState.Write(plugins.StateKey(plugins.TypedName{Type: prefix.PrefixCachePluginType, + t.Run("empty endpoints list", func(t *testing.T) { + emptyEndpoints := []scheduling.Endpoint{} + cycleState := &scheduling.CycleState{} + cycleState.Write(plugin.StateKey(plugin.TypedName{Type: prefix.PrefixCachePluginType, Name: prefix.PrefixCachePluginType}.String()), &prefix.SchedulingContextState{ PrefixCacheServers: make(map[prefix.ServerID]int), // cold request }) - scores := scorer.Score(ctx, cycleState, &types.LLMRequest{}, emptyPods) + scores := scorer.Score(ctx, cycleState, &scheduling.LLMRequest{}, emptyEndpoints) if len(scores) != 0 { - t.Errorf("Expected empty scores for empty pods list, got %d scores", len(scores)) + t.Errorf("Expected empty scores for empty endpoints list, got %d scores", len(scores)) } }) - t.Run("nil pods list", func(t *testing.T) { - cycleState := &types.CycleState{} - cycleState.Write(plugins.StateKey(plugins.TypedName{Type: prefix.PrefixCachePluginType, + t.Run("nil endpoints list", func(t *testing.T) { + cycleState := &scheduling.CycleState{} + cycleState.Write(plugin.StateKey(plugin.TypedName{Type: prefix.PrefixCachePluginType, Name: prefix.PrefixCachePluginType}.String()), &prefix.SchedulingContextState{ PrefixCacheServers: make(map[prefix.ServerID]int), // cold request }) - scores := 
scorer.Score(ctx, cycleState, &types.LLMRequest{}, nil) + scores := scorer.Score(ctx, cycleState, &scheduling.LLMRequest{}, nil) if scores == nil { - t.Errorf("Expected non-nil scores map for nil pods list") + t.Errorf("Expected non-nil scores map for nil endpoints list") } if len(scores) != 0 { - t.Errorf("Expected empty scores for nil pods list, got %d scores", len(scores)) + t.Errorf("Expected empty scores for nil endpoints list, got %d scores", len(scores)) } }) - t.Run("single pod returns 1.0", func(t *testing.T) { - pods := []types.Pod{podA} - cycleState := &types.CycleState{} - cycleState.Write(plugins.StateKey(plugins.TypedName{Type: prefix.PrefixCachePluginType, + t.Run("single endpoint returns 1.0", func(t *testing.T) { + endpoints := []scheduling.Endpoint{endpointA} + cycleState := &scheduling.CycleState{} + cycleState.Write(plugin.StateKey(plugin.TypedName{Type: prefix.PrefixCachePluginType, Name: prefix.PrefixCachePluginType}.String()), &prefix.SchedulingContextState{ PrefixCacheServers: make(map[prefix.ServerID]int), // cold request }) - scores := scorer.Score(ctx, cycleState, &types.LLMRequest{}, pods) + scores := scorer.Score(ctx, cycleState, &scheduling.LLMRequest{}, endpoints) + + if scores[endpointA] != 1.0 { + t.Errorf("Expected single endpoint to get score 1.0, got %f", scores[endpointA]) + } + }) +} + +func TestNoHitLRUPrefillDecodeTracking(t *testing.T) { + // Prefill worker endpoints + prefillEndpointA := scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "prefill-a", Namespace: "default"}}, + &fwkdl.Metrics{}, + nil, + ) + prefillEndpointB := scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "prefill-b", Namespace: "default"}}, + &fwkdl.Metrics{}, + nil, + ) + + // Decode worker endpoints + decodeEndpointA := scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "decode-a", Namespace: "default"}}, + 
&fwkdl.Metrics{}, + nil, + ) + decodeEndpointB := scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "decode-b", Namespace: "default"}}, + &fwkdl.Metrics{}, + nil, + ) + + prefillEndpoints := []scheduling.Endpoint{prefillEndpointA, prefillEndpointB} + decodeEndpoints := []scheduling.Endpoint{decodeEndpointA, decodeEndpointB} + + coldPrefixState := &scheduling.CycleState{} + coldPrefixState.Write(plugin.StateKey(prefix.PrefixCachePluginType), &prefix.SchedulingContextState{ + PrefixCacheServers: make(map[prefix.ServerID]int), // empty = cold request + }) + + ctx := context.Background() + + t.Run("P/D scenario - both profiles tracked separately", func(t *testing.T) { + scorer := scorer.NewNoHitLRU(ctx, nil) + + // First cold request with P/D + req1 := &scheduling.LLMRequest{RequestId: "pd-request-1"} + scorer.Score(ctx, coldPrefixState, req1, append(prefillEndpoints, decodeEndpoints...)) + + // Simulate scheduling result with both prefill and decode profiles + pdResult := &scheduling.SchedulingResult{ + PrimaryProfileName: "decode", + ProfileResults: map[string]*scheduling.ProfileRunResult{ + "prefill": { + TargetEndpoints: []scheduling.Endpoint{prefillEndpointA}, + }, + "decode": { + TargetEndpoints: []scheduling.Endpoint{decodeEndpointA}, + }, + }, + } + scorer.PreRequest(ctx, req1, pdResult) + + // Second cold request - both prefillPodB and decodePodB should score higher + // since prefillPodA and decodePodA were just used + req2 := &scheduling.LLMRequest{RequestId: "pd-request-2"} + prefillScores := scorer.Score(ctx, coldPrefixState, req2, prefillEndpoints) + decodeScores := scorer.Score(ctx, coldPrefixState, req2, decodeEndpoints) + + if prefillScores[prefillEndpointB] <= prefillScores[prefillEndpointA] { + t.Errorf("Expected prefill-b to score higher than prefill-a after prefill-a was used: %+v", prefillScores) + } + + if decodeScores[decodeEndpointB] <= decodeScores[decodeEndpointA] { + t.Errorf("Expected decode-b 
to score higher than decode-a after decode-a was used: %+v", decodeScores) + } + }) + + t.Run("non-P/D scenario - only primary profile exists", func(t *testing.T) { + req := &scheduling.LLMRequest{RequestId: "non-pd-request"} + scorer := scorer.NewNoHitLRU(ctx, nil) + scorer.Score(ctx, coldPrefixState, req, decodeEndpoints) + + // Scheduling result with only decode profile (no prefill) + result := &scheduling.SchedulingResult{ + PrimaryProfileName: "decode", + ProfileResults: map[string]*scheduling.ProfileRunResult{ + "decode": { + TargetEndpoints: []scheduling.Endpoint{decodeEndpointA}, + }, + // No "prefill" profile in results + }, + } + // Should not panic when prefill profile doesn't exist + scorer.PreRequest(ctx, req, result) + + // Verify decodePodA was tracked + req2 := &scheduling.LLMRequest{RequestId: "non-pd-request-2"} + scores := scorer.Score(ctx, coldPrefixState, req2, decodeEndpoints) + + if scores[decodeEndpointB] <= scores[decodeEndpointA] { + t.Errorf("Expected decode-b to score higher than decode-a: %+v", scores) + } + }) + + t.Run("nil scheduling result - graceful handling", func(_ *testing.T) { + req := &scheduling.LLMRequest{RequestId: "nil-result"} + scorer := scorer.NewNoHitLRU(ctx, nil) + scorer.Score(ctx, coldPrefixState, req, decodeEndpoints) + + // Should not panic with nil result + scorer.PreRequest(ctx, req, nil) + }) + + t.Run("empty profile results - graceful handling", func(_ *testing.T) { + req := &scheduling.LLMRequest{RequestId: "empty-results"} + scorer := scorer.NewNoHitLRU(ctx, nil) + scorer.Score(ctx, coldPrefixState, req, decodeEndpoints) - if scores[podA] != 1.0 { - t.Errorf("Expected single pod to get score 1.0, got %f", scores[podA]) + result := &scheduling.SchedulingResult{ + PrimaryProfileName: "decode", + ProfileResults: map[string]*scheduling.ProfileRunResult{}, } + // Should not panic with empty profile results + scorer.PreRequest(ctx, req, result) }) } diff --git a/pkg/plugins/scorer/precise_prefix_cache.go 
b/pkg/plugins/scorer/precise_prefix_cache.go index 2ce6551c0..f839625d8 100644 --- a/pkg/plugins/scorer/precise_prefix_cache.go +++ b/pkg/plugins/scorer/precise_prefix_cache.go @@ -6,16 +6,18 @@ import ( "errors" "fmt" "os" + "time" - "github.com/llm-d/llm-d-kv-cache-manager/pkg/kvcache" - "github.com/llm-d/llm-d-kv-cache-manager/pkg/kvcache/kvevents" - preprocessing "github.com/llm-d/llm-d-kv-cache-manager/pkg/preprocessing/chat_completions" + "github.com/jellydator/ttlcache/v3" + "github.com/llm-d/llm-d-kv-cache/pkg/kvcache" + "github.com/llm-d/llm-d-kv-cache/pkg/kvcache/kvblock" + "github.com/llm-d/llm-d-kv-cache/pkg/kvevents" + preprocessing "github.com/llm-d/llm-d-kv-cache/pkg/preprocessing/chat_completions" "sigs.k8s.io/controller-runtime/pkg/log" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/multi/prefix" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" - logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/common/util/logging" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/plugins/scheduling/scorer/prefix" ) const ( @@ -26,8 +28,11 @@ const ( // PrecisePrefixCachePluginConfig holds the configuration for the // PrecisePrefixCacheScorer plugin. type PrecisePrefixCachePluginConfig struct { + // TokenProcessorConfig holds the configuration for the `kvblock.TokenProcessor` which is + // used to process tokens into KV-block keys. 
+ TokenProcessorConfig *kvblock.TokenProcessorConfig `json:"tokenProcessorConfig"` // IndexerConfig holds the configuration for the `kvcache.Indexer` which is - // used to score pods based on the KV-cache index state. + // used to score endpoints based on the KV-cache index state. IndexerConfig *kvcache.Config `json:"indexerConfig"` // KVEventsConfig holds the configuration for the `kvevents.Pool` which is // used to subscribe to KV-cache events and update the internal KV-cache @@ -36,13 +41,12 @@ type PrecisePrefixCachePluginConfig struct { } // compile-time type assertion -var _ framework.Scorer = &PrecisePrefixCacheScorer{} +var _ scheduling.Scorer = &PrecisePrefixCacheScorer{} // PrecisePrefixCachePluginFactory defines the factory function for creating // a new instance of the PrefixCacheTrackingPlugin. func PrecisePrefixCachePluginFactory(name string, rawParameters json.RawMessage, - handle plugins.Handle) (plugins.Plugin, error) { - + handle plugin.Handle) (plugin.Plugin, error) { indexerConfig, err := kvcache.NewDefaultConfig() if err != nil { return nil, fmt.Errorf("failed to initialize indexer config: %w", err) @@ -53,18 +57,24 @@ func PrecisePrefixCachePluginFactory(name string, rawParameters json.RawMessage, KVEventsConfig: kvevents.DefaultConfig(), } - // read hugging face token from environment variable if set + if rawParameters != nil { + if err := json.Unmarshal(rawParameters, ¶meters); err != nil { + return nil, fmt.Errorf("failed to parse %s plugin config: %w", PrecisePrefixCachePluginType, err) + } + } + + // Apply HF token from environment if not already set if token := os.Getenv("HF_TOKEN"); token != "" && parameters.IndexerConfig != nil && parameters.IndexerConfig.TokenizersPoolConfig != nil && - parameters.IndexerConfig.TokenizersPoolConfig.HFTokenizerConfig != nil { + parameters.IndexerConfig.TokenizersPoolConfig.HFTokenizerConfig != nil && + parameters.IndexerConfig.TokenizersPoolConfig.HFTokenizerConfig.HuggingFaceToken == "" { 
parameters.IndexerConfig.TokenizersPoolConfig.HFTokenizerConfig.HuggingFaceToken = token } - if rawParameters != nil { - if err := json.Unmarshal(rawParameters, ¶meters); err != nil { - return nil, fmt.Errorf("failed to parse %s plugin config: %w", PrecisePrefixCachePluginType, err) - } + // Validate model name is set + if parameters.IndexerConfig == nil || parameters.IndexerConfig.TokenizersPoolConfig == nil || parameters.IndexerConfig.TokenizersPoolConfig.ModelName == "" { + return nil, errors.New("modelName is required in indexerConfig.tokenizersPoolConfig") } scorer, err := New(handle.Context(), parameters) @@ -80,13 +90,19 @@ func PrecisePrefixCachePluginFactory(name string, rawParameters json.RawMessage, // based on the provided configuration. The `kvevents.Pool` is started // in a goroutine to listen for KV-cache events and update the internal // KV-cache index state. The `kvcache.Indexer` is also started in a goroutine -// to score pods based on the KV-cache index state. +// to score endpoints based on the KV-cache index state. // // If the configuration is invalid or if the indexer fails to initialize, // an error is returned. 
func New(ctx context.Context, config PrecisePrefixCachePluginConfig) (*PrecisePrefixCacheScorer, error) { + if config.TokenProcessorConfig == nil { + config.TokenProcessorConfig = kvblock.DefaultTokenProcessorConfig() + } + + tokenProcessor := kvblock.NewChunkedTokenDatabase(config.TokenProcessorConfig) + // initialize the indexer - kvCacheIndexer, err := kvcache.NewKVCacheIndexer(ctx, config.IndexerConfig) + kvCacheIndexer, err := kvcache.NewKVCacheIndexer(ctx, config.IndexerConfig, tokenProcessor) if err != nil { return nil, fmt.Errorf("failed to create `kvcache.Indexer`: %w", err) } @@ -94,27 +110,66 @@ func New(ctx context.Context, config PrecisePrefixCachePluginConfig) (*PrecisePr go kvCacheIndexer.Run(ctx) // initialize the KV-events pool - pool := kvevents.NewPool(config.KVEventsConfig, kvCacheIndexer.KVBlockIndex()) + pool := kvevents.NewPool(config.KVEventsConfig, kvCacheIndexer.KVBlockIndex(), tokenProcessor) pool.Start(ctx) + subscribersManager := kvevents.NewSubscriberManager(pool) + var subscribersCache *ttlcache.Cache[string, struct{}] + + // initialize the subscribers cache only if endpoint discovery is enabled + if config.KVEventsConfig.DiscoverPods { + // initialize the subscribers TTL cache + subscriptionTimeout := 10 * time.Minute + subscribersCache = ttlcache.New[string, struct{}]( + ttlcache.WithTTL[string, struct{}](subscriptionTimeout), + ) + subscribersCache.OnEviction(func(ctx context.Context, reason ttlcache.EvictionReason, + item *ttlcache.Item[string, struct{}], + ) { + if reason == ttlcache.EvictionReasonExpired { + subscribersManager.RemoveSubscriber(ctx, item.Key()) + } + }) + go cleanCachePeriodically(ctx, subscribersCache, subscriptionTimeout) + } + if config.KVEventsConfig.ZMQEndpoint != "" { + // setup local subscriber to support global socket mode + if err := subscribersManager.EnsureSubscriber(ctx, "local-subscriber", + config.KVEventsConfig.ZMQEndpoint, config.KVEventsConfig.TopicFilter, false); err != nil { + return nil, 
fmt.Errorf("failed to create local subscriber for global socket mode: %w", err) + } + } + return &PrecisePrefixCacheScorer{ - typedName: plugins.TypedName{Type: PrecisePrefixCachePluginType}, - kvCacheIndexer: kvCacheIndexer, + typedName: plugin.TypedName{Type: PrecisePrefixCachePluginType}, + kvCacheIndexer: kvCacheIndexer, + subscribersCache: subscribersCache, + subscribersManager: subscribersManager, + kvEventsConfig: config.KVEventsConfig, }, nil } // PrecisePrefixCacheScorer implements the framework.Scorer interface. // The scorer implements precise prefix-cache KV-block locality scoring. -// It uses the `kvcache.Indexer` to score pods based on the KV-cache index +// It uses the `kvcache.Indexer` to score endpoints based on the KV-cache index // state, and the `kvevents.Pool` to subscribe to KV-cache events // to keep the internal KV-cache index state up-to-date. type PrecisePrefixCacheScorer struct { - typedName plugins.TypedName + typedName plugin.TypedName kvCacheIndexer *kvcache.Indexer + + // until the IGW data-layer is ready to provide endpoint events, + // we maintain a TTL cache of known endpoints that are discovered through + // the scoring process. If a endpoint is not in the received endpoints list + // during scoring for a certain period, we consider it gone and + // stop its KV events subscription. + subscribersCache *ttlcache.Cache[string, struct{}] + subscribersManager *kvevents.SubscriberManager + kvEventsConfig *kvevents.Config } // TypedName returns the typed name of the plugin. -func (s *PrecisePrefixCacheScorer) TypedName() plugins.TypedName { +func (s *PrecisePrefixCacheScorer) TypedName() plugin.TypedName { return s.typedName } @@ -124,12 +179,37 @@ func (s *PrecisePrefixCacheScorer) WithName(name string) *PrecisePrefixCacheScor return s } -// Score scores the provided pod based on the KVCache index state. +// Category returns the preference the scorer applies when scoring candidate endpoints. 
+func (s *PrecisePrefixCacheScorer) Category() scheduling.ScorerCategory { + return scheduling.Affinity +} + +// Score scores the provided endpoint based on the KVCache index state. // The returned scores are normalized to a range of 0-1. -func (s *PrecisePrefixCacheScorer) Score(ctx context.Context, cycleState *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 { +func (s *PrecisePrefixCacheScorer) Score(ctx context.Context, cycleState *scheduling.CycleState, request *scheduling.LLMRequest, endpoints []scheduling.Endpoint) map[scheduling.Endpoint]float64 { logger := log.FromContext(ctx).WithName(s.typedName.String()) debugLogger := logger.V(logutil.DEBUG) + if s.kvEventsConfig.DiscoverPods { + // update subscribers here temporarily + for _, endpoint := range endpoints { + endpointObj := endpoint.GetMetadata() + if endpointObj == nil { + continue + } + endpointKey := endpointObj.NamespacedName.String() + s.subscribersCache.Set(endpointKey, struct{}{}, 0) // use default TTL + + if err := s.subscribersManager.EnsureSubscriber(context.Background(), endpointKey, // dont use request ctx + fmt.Sprintf("tcp://%s:%d", endpointObj.Address, s.kvEventsConfig.PodDiscoveryConfig.SocketPort), + s.kvEventsConfig.TopicFilter, true); err != nil { + logger.Error(err, "Failed to ensure KV-events subscriber for endpoint", "endpoint", endpointKey, + "endpoint", endpointObj.Address) + continue + } + } + } + if request == nil { debugLogger.Info("Request is nil, skipping scoring") return nil @@ -137,41 +217,41 @@ func (s *PrecisePrefixCacheScorer) Score(ctx context.Context, cycleState *types. 
scores, err := s.getScores(ctx, request) if err != nil { - logger.Error(err, "Failed to get pod scores") + logger.Error(err, "Failed to get endpoint scores") return nil } - debugLogger.Info("Got pod scores", "scores", scores) + debugLogger.Info("Got endpoint scores", "scores", scores) - podToKey := func(pod types.Pod) (string, bool) { - metricsPod := pod.GetPod() - if metricsPod == nil { + endpointToKey := func(endpoint scheduling.Endpoint) (string, bool) { + metadata := endpoint.GetMetadata() + if metadata == nil { return "", false } - return metricsPod.Address, true + return metadata.Address, true } state := &prefix.SchedulingContextState{ PrefixHashes: []prefix.BlockHash{}, PrefixCacheServers: map[prefix.ServerID]int{}, } - for _, pod := range pods { - key, ok := podToKey(pod) + for _, endpoint := range endpoints { + key, ok := endpointToKey(endpoint) if !ok { continue } - state.PrefixCacheServers[prefix.ServerID(pod.GetPod().NamespacedName)] = int(scores[key]) + state.PrefixCacheServers[prefix.ServerID(endpoint.GetMetadata().NamespacedName)] = int(scores[key]) } - cycleState.Write(plugins.StateKey(s.typedName.String()), state) + cycleState.Write(plugin.StateKey(s.typedName.String()), state) - return indexedScoresToNormalizedScoredPods(pods, podToKey, scores) + return indexedScoresToNormalizedScoredPods(endpoints, endpointToKey, scores) } -// getScores retrieves the pod scores from the KV-cache indexer +// getScores retrieves the endpoint scores from the KV-cache indexer // based on the provided LLM request. // If the request contains chat completions, it processes them accordingly. // If the request contains regular completions, it uses the prompt directly. 
-func (s *PrecisePrefixCacheScorer) getScores(ctx context.Context, request *types.LLMRequest) (map[string]float64, error) { +func (s *PrecisePrefixCacheScorer) getScores(ctx context.Context, request *scheduling.LLMRequest) (map[string]float64, error) { logger := log.FromContext(ctx).WithName(s.typedName.String()) traceLogger := logger.V(logutil.TRACE) @@ -186,8 +266,17 @@ func (s *PrecisePrefixCacheScorer) getScores(ctx context.Context, request *types traceLogger.Info("Both chat/completions and completions present; defaulting to chat/completions") } - renderReq := &preprocessing.RenderJinjaTemplateRequest{ - Conversations: make([]preprocessing.ChatMessage, 0), + // Convert messages to conversation format + conversations := make([]preprocessing.Conversation, len(request.Body.ChatCompletions.Messages)) + for i, msg := range request.Body.ChatCompletions.Messages { + conversations[i] = preprocessing.Conversation{ + Role: msg.Role, + Content: msg.Content.Raw, + } + } + + renderReq := &preprocessing.ApplyChatTemplateRequest{ + Conversation: [][]preprocessing.Conversation{conversations}, Tools: request.Body.ChatCompletions.Tools, Documents: request.Body.ChatCompletions.Documents, ChatTemplate: request.Body.ChatCompletions.ChatTemplate, @@ -197,22 +286,14 @@ func (s *PrecisePrefixCacheScorer) getScores(ctx context.Context, request *types ChatTemplateKWArgs: request.Body.ChatCompletions.ChatTemplateKWArgs, } - // Convert messages to the format expected by the renderer - for _, msg := range request.Body.ChatCompletions.Messages { - renderReq.Conversations = append(renderReq.Conversations, preprocessing.ChatMessage{ - Role: msg.Role, - Content: msg.Content.Raw, - }) - } - traceLogger.Info("Processing chat completion request", - "messagesCount", len(renderReq.Conversations), + "messagesCount", len(conversations), "toolsCount", len(renderReq.Tools), "documentsCount", len(renderReq.Documents)) scores, err := s.kvCacheIndexer.GetPodScores(ctx, renderReq, "", request.TargetModel, 
nil) if err != nil { - return nil, fmt.Errorf("failed to get pod scores for chat/completions: %w", err) + return nil, fmt.Errorf("failed to get endpoint scores for chat/completions: %w", err) } return scores, nil } @@ -224,7 +305,7 @@ func (s *PrecisePrefixCacheScorer) getScores(ctx context.Context, request *types scores, err := s.kvCacheIndexer.GetPodScores(ctx, nil, prompt, request.TargetModel, nil) if err != nil { - return nil, fmt.Errorf("failed to get pod scores for completions: %w", err) + return nil, fmt.Errorf("failed to get endpoint scores for completions: %w", err) } return scores, nil } diff --git a/pkg/plugins/scorer/precise_prefix_cache_test.go b/pkg/plugins/scorer/precise_prefix_cache_test.go index eb7284b95..968497798 100644 --- a/pkg/plugins/scorer/precise_prefix_cache_test.go +++ b/pkg/plugins/scorer/precise_prefix_cache_test.go @@ -6,16 +6,15 @@ import ( "testing" "github.com/google/go-cmp/cmp" - "github.com/llm-d/llm-d-kv-cache-manager/pkg/kvcache" - "github.com/llm-d/llm-d-kv-cache-manager/pkg/kvcache/kvblock" - "github.com/llm-d/llm-d-kv-cache-manager/pkg/kvcache/kvevents" - preprocessing "github.com/llm-d/llm-d-kv-cache-manager/pkg/preprocessing/chat_completions" - "github.com/llm-d/llm-d-kv-cache-manager/pkg/tokenization" + "github.com/llm-d/llm-d-kv-cache/pkg/kvcache" + "github.com/llm-d/llm-d-kv-cache/pkg/kvcache/kvblock" + "github.com/llm-d/llm-d-kv-cache/pkg/kvevents" + preprocessing "github.com/llm-d/llm-d-kv-cache/pkg/preprocessing/chat_completions" + "github.com/llm-d/llm-d-kv-cache/pkg/tokenization" "github.com/stretchr/testify/require" k8stypes "k8s.io/apimachinery/pkg/types" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" - backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer" + 
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling" "github.com/llm-d/llm-d-inference-scheduler/test/utils" ) @@ -37,34 +36,38 @@ func TestPrefixCacheTracking_Score(t *testing.T) { testcases := []struct { name string - pods []types.Pod - request *types.LLMRequest - kvBlockData func(req *types.LLMRequestBody, model string) map[kvblock.Key][]kvblock.PodEntry + endpoints []scheduling.Endpoint + request *scheduling.LLMRequest + kvBlockData func(req *scheduling.LLMRequestBody, model string) map[kvblock.BlockHash][]kvblock.PodEntry wantScoresByAddress map[string]float64 }{ { name: "nil request", - pods: []types.Pod{ - &types.PodMetrics{ - Pod: &backend.Pod{ + endpoints: []scheduling.Endpoint{ + scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{ NamespacedName: k8stypes.NamespacedName{Name: "pod-a"}, Address: "10.0.0.1:8080", }, - }, + nil, + nil, + ), }, wantScoresByAddress: map[string]float64{}, // empty map }, { name: "empty request body", - pods: []types.Pod{ - &types.PodMetrics{ - Pod: &backend.Pod{ + endpoints: []scheduling.Endpoint{ + scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{ NamespacedName: k8stypes.NamespacedName{Name: "pod-a"}, Address: "10.0.0.1:8080", }, - }, + nil, + nil, + ), }, - request: &types.LLMRequest{ + request: &scheduling.LLMRequest{ RequestId: "test-request", TargetModel: "test-model", Body: nil, @@ -73,63 +76,66 @@ func TestPrefixCacheTracking_Score(t *testing.T) { }, { name: "longest prefix scorer (default scorer)", - pods: []types.Pod{ - &types.PodMetrics{ - Pod: &backend.Pod{ + endpoints: []scheduling.Endpoint{ + scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{ NamespacedName: k8stypes.NamespacedName{Name: "pod-a"}, Address: "10.0.0.1:8080", }, - MetricsState: &backendmetrics.MetricsState{ + &fwkdl.Metrics{ WaitingQueueSize: 0, }, - }, - &types.PodMetrics{ - Pod: &backend.Pod{ + nil, + ), + scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{ NamespacedName: k8stypes.NamespacedName{Name: "pod-b"}, 
Address: "10.0.0.2:8080", }, - MetricsState: &backendmetrics.MetricsState{ + &fwkdl.Metrics{ WaitingQueueSize: 1, }, - }, - &types.PodMetrics{ - Pod: &backend.Pod{ + nil, + ), + scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{ NamespacedName: k8stypes.NamespacedName{Name: "pod-c"}, Address: "10.0.0.3:8080", }, - MetricsState: &backendmetrics.MetricsState{ + &fwkdl.Metrics{ WaitingQueueSize: 2, }, - }, + nil, + ), }, - request: &types.LLMRequest{ + request: &scheduling.LLMRequest{ RequestId: "test-request", TargetModel: "test-model", - Body: &types.LLMRequestBody{ - Completions: &types.CompletionsRequest{ + Body: &scheduling.LLMRequestBody{ + Completions: &scheduling.CompletionsRequest{ Prompt: prompt, }, }, }, - kvBlockData: func(req *types.LLMRequestBody, model string) map[kvblock.Key][]kvblock.PodEntry { + kvBlockData: func(req *scheduling.LLMRequestBody, model string) map[kvblock.BlockHash][]kvblock.PodEntry { require.NotNil(t, req.Completions, "req expected to use Completions API") prompt := req.Completions.Prompt - testTokenizer, err := tokenization.NewCachedLocalTokenizer(localTokenizerConfig) + testTokenizer, err := tokenization.NewCachedLocalTokenizer(t.Context(), model, localTokenizerConfig) require.NoError(t, err) // use the actual tokenizer on the test prompt - tokens, _, err := testTokenizer.Encode(prompt, model) + tokens, _, err := testTokenizer.Encode(prompt, model, true) require.NoError(t, err) // compute chunk hashes using the default block size tokenProcessor := kvblock.NewChunkedTokenDatabase(kvblock.DefaultTokenProcessorConfig()) - chunkKeys := tokenProcessor.TokensToKVBlockKeys(tokens, model) + chunkKeys := tokenProcessor.TokensToKVBlockKeys(kvblock.EmptyBlockHash, tokens, model) require.GreaterOrEqual(t, len(chunkKeys), 3, "Need at least 3 chunks for test") // populate kvblock.Index to test longest prefix matching: - // - chunk0 (first chunk): all pods have it (common prefix start) + // - chunk0 (first chunk): all endpoints have it (common 
prefix start) // - chunk1: pod-a and pod-b have it (pod-c drops off after chunk0) // - chunk2: only pod-a has it (pod-b drops off after chunk1) // LongestPrefixScorer uses intersection, so: @@ -138,17 +144,17 @@ func TestPrefixCacheTracking_Score(t *testing.T) { // pod-c: 1 chunk (0) -> score 1 // Normalized: (3-1)/(3-1) = 1.0, (2-1)/(3-1) = 0.5, (1-1)/(3-1) = 0.0 - return map[kvblock.Key][]kvblock.PodEntry{ - {ModelName: model, ChunkHash: chunkKeys[0].ChunkHash}: { + return map[kvblock.BlockHash][]kvblock.PodEntry{ + chunkKeys[0]: { {PodIdentifier: "10.0.0.1:8080"}, {PodIdentifier: "10.0.0.2:8080"}, {PodIdentifier: "10.0.0.3:8080"}, }, - {ModelName: model, ChunkHash: chunkKeys[1].ChunkHash}: { + chunkKeys[1]: { {PodIdentifier: "10.0.0.1:8080"}, {PodIdentifier: "10.0.0.2:8080"}, }, - {ModelName: model, ChunkHash: chunkKeys[2].ChunkHash}: { + chunkKeys[2]: { {PodIdentifier: "10.0.0.1:8080"}, }, } @@ -161,90 +167,99 @@ func TestPrefixCacheTracking_Score(t *testing.T) { }, { name: "chat completions request", - pods: []types.Pod{ - &types.PodMetrics{ - Pod: &backend.Pod{ + endpoints: []scheduling.Endpoint{ + scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{ NamespacedName: k8stypes.NamespacedName{Name: "pod-a"}, Address: "10.0.0.1:8080", }, - MetricsState: &backendmetrics.MetricsState{ + &fwkdl.Metrics{ WaitingQueueSize: 0, }, - }, - &types.PodMetrics{ - Pod: &backend.Pod{ + nil, + ), + scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{ NamespacedName: k8stypes.NamespacedName{Name: "pod-b"}, Address: "10.0.0.2:8080", }, - MetricsState: &backendmetrics.MetricsState{ + &fwkdl.Metrics{ WaitingQueueSize: 1, }, - }, + nil, + ), }, - request: &types.LLMRequest{ + request: &scheduling.LLMRequest{ RequestId: "test-request", TargetModel: "test-model", - Body: &types.LLMRequestBody{ - ChatCompletions: &types.ChatCompletionsRequest{ + Body: &scheduling.LLMRequestBody{ + ChatCompletions: &scheduling.ChatCompletionsRequest{ ChatTemplate: `{% for message in messages %}{{ 
message.role }}: {{ message.content }} -{% endfor %}`, - Messages: []types.Message{ + {% endfor %}`, + Messages: []scheduling.Message{ { Role: "user", - Content: types.Content{Raw: "Hello, how are you?"}, + Content: scheduling.Content{Raw: "Hello, how are you?"}, }, { Role: "assistant", - Content: types.Content{Raw: "I'm doing well, thank you for asking!"}, + Content: scheduling.Content{Raw: "I'm doing well, thank you for asking!"}, }, { Role: "user", - Content: types.Content{Raw: "Can you help me with a question about prefix caching in LLM inference?"}, + Content: scheduling.Content{Raw: "Can you help me with a question about prefix caching in LLM inference?"}, }, }, }, }, }, - kvBlockData: func(req *types.LLMRequestBody, model string) map[kvblock.Key][]kvblock.PodEntry { + kvBlockData: func(req *scheduling.LLMRequestBody, model string) map[kvblock.BlockHash][]kvblock.PodEntry { require.NotNil(t, req.ChatCompletions, "req expected to use ChatCompletions API") // convert to preprocessing format - var chatMessages []preprocessing.ChatMessage + var conversations []preprocessing.Conversation for _, msg := range req.ChatCompletions.Messages { - chatMessages = append(chatMessages, preprocessing.ChatMessage{ + conversations = append(conversations, preprocessing.Conversation{ Role: msg.Role, Content: msg.Content.Raw, }) } + processor := preprocessing.NewChatTemplatingProcessor() + tokenizerCacheKey, err := processor.GetOrCreateTokenizerKey(t.Context(), &preprocessing.GetOrCreateTokenizerKeyRequest{ + IsLocal: true, + Model: "testdata/" + model, + }) + require.NoError(t, err) + // render the chat template - renderReq := &preprocessing.RenderJinjaTemplateRequest{ - Conversations: chatMessages, - ChatTemplate: req.ChatCompletions.ChatTemplate, + renderReq := &preprocessing.ApplyChatTemplateRequest{ + Key: tokenizerCacheKey, + Conversation: [][]preprocessing.Conversation{conversations}, + ChatTemplate: req.ChatCompletions.ChatTemplate, } - processor := 
preprocessing.NewChatTemplatingProcessor() - rendered, err := processor.RenderChatTemplate(t.Context(), renderReq) + rendered, err := processor.ApplyChatTemplate(t.Context(), renderReq) require.NoError(t, err) // tokenize rendered prompt - testTokenizer, err := tokenization.NewCachedLocalTokenizer(localTokenizerConfig) + testTokenizer, err := tokenization.NewCachedLocalTokenizer(t.Context(), model, localTokenizerConfig) require.NoError(t, err) - tokens, _, err := testTokenizer.Encode(rendered.RenderedChats[0], model) + tokens, _, err := testTokenizer.Encode(rendered, model, false) require.NoError(t, err) tokenProcessor := kvblock.NewChunkedTokenDatabase(kvblock.DefaultTokenProcessorConfig()) - chunkKeys := tokenProcessor.TokensToKVBlockKeys(tokens, model) + chunkKeys := tokenProcessor.TokensToKVBlockKeys(kvblock.EmptyBlockHash, tokens, model) require.GreaterOrEqual(t, len(chunkKeys), 2, "Need at least 2 chunks for test") // pod-a has both chunks, pod-b has only the first - return map[kvblock.Key][]kvblock.PodEntry{ - {ModelName: model, ChunkHash: chunkKeys[0].ChunkHash}: { + return map[kvblock.BlockHash][]kvblock.PodEntry{ + chunkKeys[0]: { {PodIdentifier: "10.0.0.1:8080"}, {PodIdentifier: "10.0.0.2:8080"}, }, - {ModelName: model, ChunkHash: chunkKeys[1].ChunkHash}: { + chunkKeys[1]: { {PodIdentifier: "10.0.0.1:8080"}, }, } @@ -256,60 +271,63 @@ func TestPrefixCacheTracking_Score(t *testing.T) { }, { name: "partial prefix", - pods: []types.Pod{ - &types.PodMetrics{ - Pod: &backend.Pod{ + endpoints: []scheduling.Endpoint{ + scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{ NamespacedName: k8stypes.NamespacedName{Name: "pod-a"}, Address: "10.0.0.1:8080", }, - MetricsState: &backendmetrics.MetricsState{ + &fwkdl.Metrics{ WaitingQueueSize: 0, }, - }, - &types.PodMetrics{ - Pod: &backend.Pod{ + nil, + ), + scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{ NamespacedName: k8stypes.NamespacedName{Name: "pod-b"}, Address: "10.0.0.2:8080", }, - MetricsState: 
&backendmetrics.MetricsState{ + &fwkdl.Metrics{ WaitingQueueSize: 1, }, - }, - &types.PodMetrics{ - Pod: &backend.Pod{ + nil, + ), + scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{ NamespacedName: k8stypes.NamespacedName{Name: "pod-c"}, Address: "10.0.0.3:8080", }, - MetricsState: &backendmetrics.MetricsState{ + &fwkdl.Metrics{ WaitingQueueSize: 2, }, - }, + nil, + ), }, - request: &types.LLMRequest{ + request: &scheduling.LLMRequest{ RequestId: "test-request", TargetModel: "test-model", - Body: &types.LLMRequestBody{ - Completions: &types.CompletionsRequest{ + Body: &scheduling.LLMRequestBody{ + Completions: &scheduling.CompletionsRequest{ Prompt: prompt, }, }, }, - kvBlockData: func(req *types.LLMRequestBody, model string) map[kvblock.Key][]kvblock.PodEntry { + kvBlockData: func(req *scheduling.LLMRequestBody, model string) map[kvblock.BlockHash][]kvblock.PodEntry { require.NotNil(t, req.Completions, "req expected to use Completions API") - testTokenizer, err := tokenization.NewCachedLocalTokenizer(localTokenizerConfig) + testTokenizer, err := tokenization.NewCachedLocalTokenizer(t.Context(), model, localTokenizerConfig) require.NoError(t, err) - tokens, _, err := testTokenizer.Encode(req.Completions.Prompt, model) + tokens, _, err := testTokenizer.Encode(req.Completions.Prompt, model, true) require.NoError(t, err) tokenProcessor := kvblock.NewChunkedTokenDatabase(kvblock.DefaultTokenProcessorConfig()) - chunkKeys := tokenProcessor.TokensToKVBlockKeys(tokens, model) + chunkKeys := tokenProcessor.TokensToKVBlockKeys(kvblock.EmptyBlockHash, tokens, model) require.GreaterOrEqual(t, len(chunkKeys), 3, "Need at least 3 chunks for test") // Test partial prefix cache scenario: - // - chunk0: all pods (common prefix start) + // - chunk0: all endpoints (common prefix start) // - chunk1: only pod-a (creates a gap for pod-b and pod-c) // - chunk2: pod-a and pod-b (pod-b has this but missing chunk1) // @@ -317,16 +335,16 @@ func TestPrefixCacheTracking_Score(t *testing.T) 
{ // pod-a: has chunks 0,1,2 contiguously -> score 3 // pod-b: has chunks 0,2 (missing 1) -> prefix stops at chunk0 -> score 1 // pod-c: has only chunk 0 -> score 1 - return map[kvblock.Key][]kvblock.PodEntry{ - {ModelName: model, ChunkHash: chunkKeys[0].ChunkHash}: { + return map[kvblock.BlockHash][]kvblock.PodEntry{ + chunkKeys[0]: { {PodIdentifier: "10.0.0.1:8080"}, {PodIdentifier: "10.0.0.2:8080"}, {PodIdentifier: "10.0.0.3:8080"}, }, - {ModelName: model, ChunkHash: chunkKeys[1].ChunkHash}: { + chunkKeys[1]: { {PodIdentifier: "10.0.0.1:8080"}, // only pod-a has chunk1 }, - {ModelName: model, ChunkHash: chunkKeys[2].ChunkHash}: { + chunkKeys[2]: { {PodIdentifier: "10.0.0.1:8080"}, {PodIdentifier: "10.0.0.2:8080"}, // pod-b has chunk2 but missing chunk1 }, @@ -342,204 +360,161 @@ func TestPrefixCacheTracking_Score(t *testing.T) { }, }, { - name: "different model names", - pods: []types.Pod{ - &types.PodMetrics{ - Pod: &backend.Pod{ - NamespacedName: k8stypes.NamespacedName{Name: "pod-a"}, - Address: "10.0.0.1:8080", - }, - }, - &types.PodMetrics{ - Pod: &backend.Pod{ - NamespacedName: k8stypes.NamespacedName{Name: "pod-b"}, - Address: "10.0.0.2:8080", - }, - }, - }, - request: &types.LLMRequest{ - RequestId: "test-request", - TargetModel: "test-model", - Body: &types.LLMRequestBody{ - Completions: &types.CompletionsRequest{ - Prompt: prompt, - }, - }, - }, - kvBlockData: func(req *types.LLMRequestBody, model string) map[kvblock.Key][]kvblock.PodEntry { - require.NotNil(t, req.Completions, "req expected to use Completions API") - - testTokenizer, err := tokenization.NewCachedLocalTokenizer(localTokenizerConfig) - require.NoError(t, err) - - tokens, _, err := testTokenizer.Encode(req.Completions.Prompt, model) - require.NoError(t, err) - - tokenProcessor := kvblock.NewChunkedTokenDatabase(kvblock.DefaultTokenProcessorConfig()) - chunkKeys := tokenProcessor.TokensToKVBlockKeys(tokens, model) - - require.GreaterOrEqual(t, len(chunkKeys), 1, "Need at least 1 chunk for 
test") - - // Populate the index with blocks for model `different-model` - // The request will ask for "test-model" but the cache only has "different-model" - // This should result in no cache hits since models don't share cache - return map[kvblock.Key][]kvblock.PodEntry{ - {ModelName: "different-model", ChunkHash: chunkKeys[0].ChunkHash}: { - {PodIdentifier: "10.0.0.1:8080"}, - {PodIdentifier: "10.0.0.2:8080"}, - }, - } - }, - wantScoresByAddress: map[string]float64{ - // Even though both pods have the chunk cached, it's for a different model - // so there should be no cache hits for the requested model - "10.0.0.1:8080": 0.0, - "10.0.0.2:8080": 0.0, - }, - }, - { - name: "single pod", - pods: []types.Pod{ - &types.PodMetrics{ - Pod: &backend.Pod{ + name: "single endpoint", + endpoints: []scheduling.Endpoint{ + scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{ NamespacedName: k8stypes.NamespacedName{Name: "pod-a"}, Address: "10.0.0.1:8080", }, - MetricsState: &backendmetrics.MetricsState{ + &fwkdl.Metrics{ WaitingQueueSize: 0, }, - }, + nil, + ), }, - request: &types.LLMRequest{ + request: &scheduling.LLMRequest{ RequestId: "test-request", TargetModel: "test-model", - Body: &types.LLMRequestBody{ - Completions: &types.CompletionsRequest{ + Body: &scheduling.LLMRequestBody{ + Completions: &scheduling.CompletionsRequest{ Prompt: prompt, }, }, }, - kvBlockData: func(req *types.LLMRequestBody, model string) map[kvblock.Key][]kvblock.PodEntry { + kvBlockData: func(req *scheduling.LLMRequestBody, model string) map[kvblock.BlockHash][]kvblock.PodEntry { require.NotNil(t, req.Completions, "req expected to use Completions API") - testTokenizer, err := tokenization.NewCachedLocalTokenizer(localTokenizerConfig) + testTokenizer, err := tokenization.NewCachedLocalTokenizer(t.Context(), model, localTokenizerConfig) require.NoError(t, err) - tokens, _, err := testTokenizer.Encode(req.Completions.Prompt, model) + tokens, _, err := testTokenizer.Encode(req.Completions.Prompt, 
model, true) require.NoError(t, err) tokenProcessor := kvblock.NewChunkedTokenDatabase(kvblock.DefaultTokenProcessorConfig()) - chunkKeys := tokenProcessor.TokensToKVBlockKeys(tokens, model) + chunkKeys := tokenProcessor.TokensToKVBlockKeys(kvblock.EmptyBlockHash, tokens, model) require.GreaterOrEqual(t, len(chunkKeys), 2, "Need at least 2 chunks for test") - // Single pod has 2 chunks cached - return map[kvblock.Key][]kvblock.PodEntry{ - {ModelName: model, ChunkHash: chunkKeys[0].ChunkHash}: { + // Single endpoint has 2 chunks cached + return map[kvblock.BlockHash][]kvblock.PodEntry{ + chunkKeys[0]: { {PodIdentifier: "10.0.0.1:8080"}, }, - {ModelName: model, ChunkHash: chunkKeys[1].ChunkHash}: { + chunkKeys[1]: { {PodIdentifier: "10.0.0.1:8080"}, }, } }, wantScoresByAddress: map[string]float64{ - // with only one pod, minScore == maxScore, so normalization returns 1.0 + // with only one endpoint, minScore == maxScore, so normalization returns 1.0 "10.0.0.1:8080": 1.0, }, }, { name: "no cache hits (empty index)", - pods: []types.Pod{ - &types.PodMetrics{ - Pod: &backend.Pod{ + endpoints: []scheduling.Endpoint{ + scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{ NamespacedName: k8stypes.NamespacedName{Name: "pod-a"}, Address: "10.0.0.1:8080", }, - }, - &types.PodMetrics{ - Pod: &backend.Pod{ + nil, + nil, + ), + scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{ NamespacedName: k8stypes.NamespacedName{Name: "pod-b"}, Address: "10.0.0.2:8080", }, - }, - &types.PodMetrics{ - Pod: &backend.Pod{ + nil, + nil, + ), + scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{ NamespacedName: k8stypes.NamespacedName{Name: "pod-c"}, Address: "10.0.0.3:8080", }, - }, + nil, + nil, + ), }, - request: &types.LLMRequest{ + request: &scheduling.LLMRequest{ RequestId: "test-request", TargetModel: "test-model", - Body: &types.LLMRequestBody{ - Completions: &types.CompletionsRequest{ - Prompt: "This prompt has never been cached before on any pod.", + Body: &scheduling.LLMRequestBody{ + 
Completions: &scheduling.CompletionsRequest{ + Prompt: "This prompt has never been cached before on any endpoint.", }, }, }, kvBlockData: nil, // no cached data wantScoresByAddress: map[string]float64{ - // when no pods have any cache hits, all should get equal scores (0.0) + // when no endpoints have any cache hits, all should get equal scores (0.0) "10.0.0.1:8080": 0.0, "10.0.0.2:8080": 0.0, "10.0.0.3:8080": 0.0, }, }, { - name: "all pods have equal prefix length", - pods: []types.Pod{ - &types.PodMetrics{ - Pod: &backend.Pod{ + name: "all endpoints have equal prefix length", + endpoints: []scheduling.Endpoint{ + scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{ NamespacedName: k8stypes.NamespacedName{Name: "pod-a"}, Address: "10.0.0.1:8080", }, - }, - &types.PodMetrics{ - Pod: &backend.Pod{ + nil, + nil, + ), + scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{ NamespacedName: k8stypes.NamespacedName{Name: "pod-b"}, Address: "10.0.0.2:8080", }, - }, - &types.PodMetrics{ - Pod: &backend.Pod{ + nil, + nil, + ), + scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{ NamespacedName: k8stypes.NamespacedName{Name: "pod-c"}, Address: "10.0.0.3:8080", }, - }, + nil, + nil, + ), }, - request: &types.LLMRequest{ + request: &scheduling.LLMRequest{ RequestId: "test-request", TargetModel: "test-model", - Body: &types.LLMRequestBody{ - Completions: &types.CompletionsRequest{ + Body: &scheduling.LLMRequestBody{ + Completions: &scheduling.CompletionsRequest{ Prompt: prompt, }, }, }, - kvBlockData: func(req *types.LLMRequestBody, model string) map[kvblock.Key][]kvblock.PodEntry { + kvBlockData: func(req *scheduling.LLMRequestBody, model string) map[kvblock.BlockHash][]kvblock.PodEntry { require.NotNil(t, req.Completions, "req expected to use Completions API") - testTokenizer, err := tokenization.NewCachedLocalTokenizer(localTokenizerConfig) + testTokenizer, err := tokenization.NewCachedLocalTokenizer(t.Context(), model, localTokenizerConfig) require.NoError(t, err) - tokens, _, 
err := testTokenizer.Encode(req.Completions.Prompt, model) + tokens, _, err := testTokenizer.Encode(req.Completions.Prompt, model, true) require.NoError(t, err) tokenProcessor := kvblock.NewChunkedTokenDatabase(kvblock.DefaultTokenProcessorConfig()) - chunkKeys := tokenProcessor.TokensToKVBlockKeys(tokens, model) + chunkKeys := tokenProcessor.TokensToKVBlockKeys(kvblock.EmptyBlockHash, tokens, model) require.GreaterOrEqual(t, len(chunkKeys), 2, "Need at least 2 chunks for test") - // all pods have the same 2 chunks cached - return map[kvblock.Key][]kvblock.PodEntry{ - {ModelName: model, ChunkHash: chunkKeys[0].ChunkHash}: { + // all endpoints have the same 2 chunks cached + return map[kvblock.BlockHash][]kvblock.PodEntry{ + chunkKeys[0]: { {PodIdentifier: "10.0.0.1:8080"}, {PodIdentifier: "10.0.0.2:8080"}, {PodIdentifier: "10.0.0.3:8080"}, }, - {ModelName: model, ChunkHash: chunkKeys[1].ChunkHash}: { + chunkKeys[1]: { {PodIdentifier: "10.0.0.1:8080"}, {PodIdentifier: "10.0.0.2:8080"}, {PodIdentifier: "10.0.0.3:8080"}, @@ -547,8 +522,8 @@ func TestPrefixCacheTracking_Score(t *testing.T) { } }, wantScoresByAddress: map[string]float64{ - // when all pods have equal cache (minScore == maxScore), the implementation - // returns 1.0 for all pods to avoid division by zero + // when all endpoints have equal cache (minScore == maxScore), the implementation + // returns 1.0 for all endpoints to avoid division by zero "10.0.0.1:8080": 1.0, "10.0.0.2:8080": 1.0, "10.0.0.3:8080": 1.0, @@ -562,6 +537,7 @@ func TestPrefixCacheTracking_Score(t *testing.T) { kvcacheConfig, err := kvcache.NewDefaultConfig() kvcacheConfig.TokenizersPoolConfig = &tokenization.Config{ + ModelName: "test-model", WorkersCount: 1, MinPrefixOverlapRatio: 0.8, LocalTokenizerConfig: &localTokenizerConfig, @@ -580,17 +556,17 @@ func TestPrefixCacheTracking_Score(t *testing.T) { kvBlockIndex := prefixCacheScorer.kvCacheIndexer.KVBlockIndex() blockData := tt.kvBlockData(tt.request.Body, tt.request.TargetModel) 
for key, entries := range blockData { - err := kvBlockIndex.Add(ctx, []kvblock.Key{key}, entries) + err := kvBlockIndex.Add(ctx, []kvblock.BlockHash{kvblock.EmptyBlockHash}, []kvblock.BlockHash{key}, entries) require.NoError(t, err) } } - got := prefixCacheScorer.Score(ctx, types.NewCycleState(), tt.request, tt.pods) + got := prefixCacheScorer.Score(ctx, scheduling.NewCycleState(), tt.request, tt.endpoints) gotByAddress := make(map[string]float64) - for pod, score := range got { - if podMetrics, ok := pod.(*types.PodMetrics); ok && podMetrics.GetPod() != nil { - gotByAddress[podMetrics.GetPod().Address] = score + for endpoint, score := range got { + if endpoint.GetMetadata() != nil { + gotByAddress[endpoint.GetMetadata().Address] = score } } diff --git a/pkg/plugins/scorer/session_affinity.go b/pkg/plugins/scorer/session_affinity.go index 3ac9230c6..87e9d2be9 100644 --- a/pkg/plugins/scorer/session_affinity.go +++ b/pkg/plugins/scorer/session_affinity.go @@ -6,12 +6,11 @@ import ( "encoding/json" "sigs.k8s.io/controller-runtime/pkg/log" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" - logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" + logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/common/util/logging" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/requestcontrol" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling" ) const ( @@ -22,18 +21,18 @@ const ( ) // compile-time type assertion -var _ 
framework.Scorer = &SessionAffinity{} +var _ scheduling.Scorer = &SessionAffinity{} var _ requestcontrol.ResponseComplete = &SessionAffinity{} // SessionAffinityFactory defines the factory function for SessionAffinity scorer. -func SessionAffinityFactory(name string, _ json.RawMessage, _ plugins.Handle) (plugins.Plugin, error) { +func SessionAffinityFactory(name string, _ json.RawMessage, _ plugin.Handle) (plugin.Plugin, error) { return NewSessionAffinity().WithName(name), nil } // NewSessionAffinity returns a scorer func NewSessionAffinity() *SessionAffinity { return &SessionAffinity{ - typedName: plugins.TypedName{Type: SessionAffinityType}, + typedName: plugin.TypedName{Type: SessionAffinityType}, } } @@ -42,11 +41,11 @@ func NewSessionAffinity() *SessionAffinity { // session was sent to, by giving that pod the specified weight and assigning // zero score to the rest of the targets type SessionAffinity struct { - typedName plugins.TypedName + typedName plugin.TypedName } // TypedName returns the typed name of the plugin. -func (s *SessionAffinity) TypedName() plugins.TypedName { +func (s *SessionAffinity) TypedName() plugin.TypedName { return s.typedName } @@ -56,9 +55,14 @@ func (s *SessionAffinity) WithName(name string) *SessionAffinity { return s } +// Category returns the preference the scorer applies when scoring candidate endpoints. 
+func (s *SessionAffinity) Category() scheduling.ScorerCategory { + return scheduling.Affinity +} + // Score assign a high score to the pod used in previous requests and zero to others -func (s *SessionAffinity) Score(ctx context.Context, _ *types.CycleState, request *types.LLMRequest, pods []types.Pod) map[types.Pod]float64 { - scoredPods := make(map[types.Pod]float64) +func (s *SessionAffinity) Score(ctx context.Context, _ *scheduling.CycleState, request *scheduling.LLMRequest, endpoints []scheduling.Endpoint) map[scheduling.Endpoint]float64 { + scoredEndpoints := make(map[scheduling.Endpoint]float64) sessionToken := request.Headers[sessionTokenHeader] podName := "" @@ -70,21 +74,21 @@ func (s *SessionAffinity) Score(ctx context.Context, _ *types.CycleState, reques podName = string(decodedBytes) } } - for _, pod := range pods { - scoredPods[pod] = 0.0 // initial value - if pod.GetPod().NamespacedName.String() == podName { - scoredPods[pod] = 1.0 + for _, endpoint := range endpoints { + scoredEndpoints[endpoint] = 0.0 // initial value + if endpoint.GetMetadata().NamespacedName.String() == podName { + scoredEndpoints[endpoint] = 1.0 } } - return scoredPods + return scoredEndpoints } // ResponseComplete sets the session header on the response sent to the client // TODO: this should be using a cookie and ensure not overriding any other // cookie values if present. 
// Tracked in https://github.com/llm-d/llm-d-inference-scheduler/issues/28 -func (s *SessionAffinity) ResponseComplete(ctx context.Context, _ *types.LLMRequest, response *requestcontrol.Response, targetPod *backend.Pod) { +func (s *SessionAffinity) ResponseComplete(ctx context.Context, _ *scheduling.LLMRequest, response *requestcontrol.Response, targetPod *datalayer.EndpointMetadata) { if response == nil || targetPod == nil { reqID := "undefined" if response != nil { diff --git a/pkg/plugins/scorer/session_affinity_test.go b/pkg/plugins/scorer/session_affinity_test.go index 943b06eb4..d7acf3468 100644 --- a/pkg/plugins/scorer/session_affinity_test.go +++ b/pkg/plugins/scorer/session_affinity_test.go @@ -8,79 +8,80 @@ import ( "github.com/google/go-cmp/cmp" k8stypes "k8s.io/apimachinery/pkg/types" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" - backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/requestcontrol" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/requestcontrol" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling" "github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/scorer" "github.com/llm-d/llm-d-inference-scheduler/test/utils" ) func TestSessionAffinity_Score(t *testing.T) { - podA := &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-a"}}, - MetricsState: &backendmetrics.MetricsState{}, - } - podB := &types.PodMetrics{ - Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-b"}}, - MetricsState: &backendmetrics.MetricsState{}, - } - - inputPods := []types.Pod{podA, podB} - - // valid session token for podB - validSessionTokenForPodB := 
base64.StdEncoding.EncodeToString([]byte(podB.GetPod().NamespacedName.String())) + endpointA := scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod-a"}}, + &fwkdl.Metrics{}, + nil, + ) + endpointB := scheduling.NewEndpoint( + &fwkdl.EndpointMetadata{NamespacedName: k8stypes.NamespacedName{Name: "pod-b"}}, + &fwkdl.Metrics{}, + nil, + ) + + inputEndpoints := []scheduling.Endpoint{endpointA, endpointB} + + // valid session token for endpointB + validSessionTokenForEndpointB := base64.StdEncoding.EncodeToString([]byte(endpointB.GetMetadata().NamespacedName.String())) sessionAffinityScorer := scorer.NewSessionAffinity() tests := []struct { name string - req *types.LLMRequest - input []types.Pod - wantScores map[types.Pod]float64 + req *scheduling.LLMRequest + input []scheduling.Endpoint + wantScores map[scheduling.Endpoint]float64 }{ { - name: "selects correct pod : podB", - req: &types.LLMRequest{ - Headers: map[string]string{"x-session-token": validSessionTokenForPodB}, + name: "selects correct endpoint : endpointB", + req: &scheduling.LLMRequest{ + Headers: map[string]string{"x-session-token": validSessionTokenForEndpointB}, }, - input: inputPods, - wantScores: map[types.Pod]float64{ - podA: 0.0, - podB: 1.0, + input: inputEndpoints, + wantScores: map[scheduling.Endpoint]float64{ + endpointA: 0.0, + endpointB: 1.0, }, }, { name: "no session token", - req: &types.LLMRequest{ + req: &scheduling.LLMRequest{ Headers: map[string]string{}, }, - // both pods get score 0.0 - input: inputPods, - wantScores: map[types.Pod]float64{ - podA: 0.0, - podB: 0.0, + // both endpoints get score 0.0 + input: inputEndpoints, + wantScores: map[scheduling.Endpoint]float64{ + endpointA: 0.0, + endpointB: 0.0, }, }, { name: "invalid session token", - req: &types.LLMRequest{ + req: &scheduling.LLMRequest{ Headers: map[string]string{"x-session-token": "garbage-token"}, }, // expect same behavior as no session token - input: inputPods, - 
wantScores: map[types.Pod]float64{ - podA: 0.0, - podB: 0.0, + input: inputEndpoints, + wantScores: map[scheduling.Endpoint]float64{ + endpointA: 0.0, + endpointB: 0.0, }, }, { - name: "no pods available", - req: &types.LLMRequest{}, - input: []types.Pod{}, + name: "no endpoints available", + req: &scheduling.LLMRequest{}, + input: []scheduling.Endpoint{}, // returns empty score map - wantScores: map[types.Pod]float64{}, + wantScores: map[scheduling.Endpoint]float64{}, }, } @@ -97,30 +98,30 @@ func TestSessionAffinity_Score(t *testing.T) { func TestSessionAffinity_ResponseComplete(t *testing.T) { - targetPod := &backend.Pod{ + targetEndpoint := &fwkdl.EndpointMetadata{ NamespacedName: k8stypes.NamespacedName{Name: "pod1"}, Address: "1.2.3.4", } // expected token to be set in response header - wantToken := base64.StdEncoding.EncodeToString([]byte(targetPod.NamespacedName.String())) + wantToken := base64.StdEncoding.EncodeToString([]byte(targetEndpoint.NamespacedName.String())) tests := []struct { name string initialResponse *requestcontrol.Response - targetPod *backend.Pod + targetPod *fwkdl.EndpointMetadata wantHeaders map[string]string }{ { name: "standard case with existing headers map", initialResponse: &requestcontrol.Response{RequestId: "req-1", Headers: make(map[string]string)}, - targetPod: targetPod, + targetPod: targetEndpoint, wantHeaders: map[string]string{"x-session-token": wantToken}, }, { name: "response with nil headers map", initialResponse: &requestcontrol.Response{RequestId: "req-2", Headers: nil}, - targetPod: targetPod, + targetPod: targetEndpoint, wantHeaders: map[string]string{"x-session-token": wantToken}, }, { diff --git a/pkg/plugins/scorer/utils.go b/pkg/plugins/scorer/utils.go index 31a721b71..4d4b3c741 100644 --- a/pkg/plugins/scorer/utils.go +++ b/pkg/plugins/scorer/utils.go @@ -3,41 +3,41 @@ package scorer import ( "math" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" + 
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling" ) -// podToKey is a function type that converts a Pod to a string key. +// endpointToKey is a function type that converts a Pod to a string key. // It returns the key and a boolean indicating success. -type podToKeyFunc func(pod types.Pod) (string, bool) +type endpointToKeyFunc func(endpoint scheduling.Endpoint) (string, bool) // indexedScoresToNormalizedScoredPods converts a map of pod scores to a map of // normalized scores. The function takes a list of pods, a function to convert // a pod to a key, and a map of scores indexed by those keys. It returns a map // of pods to their normalized scores. -func indexedScoresToNormalizedScoredPods(pods []types.Pod, podToKey podToKeyFunc, - scores map[string]float64) map[types.Pod]float64 { - scoredPods := make(map[types.Pod]float64) +func indexedScoresToNormalizedScoredPods(endpoints []scheduling.Endpoint, endpointToKey endpointToKeyFunc, + scores map[string]float64) map[scheduling.Endpoint]float64 { + scoredEndpoints := make(map[scheduling.Endpoint]float64) minScore, maxScore := getMinMax(scores) - for _, pod := range pods { - key, ok := podToKey(pod) + for _, endpoint := range endpoints { + key, ok := endpointToKey(endpoint) if !ok { continue } if score, ok := scores[key]; ok { if minScore == maxScore { - scoredPods[pod] = 1.0 + scoredEndpoints[endpoint] = 1.0 continue } - scoredPods[pod] = (score - minScore) / (maxScore - minScore) + scoredEndpoints[endpoint] = (score - minScore) / (maxScore - minScore) } else { - scoredPods[pod] = 0.0 + scoredEndpoints[endpoint] = 0.0 } } - return scoredPods + return scoredEndpoints } func getMinMax(scores map[string]float64) (float64, float64) { diff --git a/pkg/scheduling/pd/scheduler_test.go b/pkg/scheduling/pd/scheduler_test.go index bd1f6b1c1..e068ade12 100644 --- a/pkg/scheduling/pd/scheduler_test.go +++ b/pkg/scheduling/pd/scheduler_test.go @@ -11,14 +11,13 @@ import ( "github.com/google/uuid" 
"github.com/stretchr/testify/assert" k8stypes "k8s.io/apimachinery/pkg/types" - "sigs.k8s.io/controller-runtime/pkg/log" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend" - backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" // Import config for thresholds + "sigs.k8s.io/controller-runtime/pkg/log" // Import config for thresholds + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/datalayer/plugins/approximateprefix" + fwkdl "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/datalayer" + fwkschd "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/scheduling" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/plugins/scheduling/picker" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/plugins/scheduling/scorer/prefix" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/multi/prefix" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/framework/plugins/picker" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types" "github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/filter" "github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/profile" @@ -32,42 +31,45 @@ const ( // Tests the scheduler expected behavior. 
func TestPDSchedule(t *testing.T) { - pod1 := &types.PodMetrics{ - Pod: &backend.Pod{ - NamespacedName: k8stypes.NamespacedName{Name: "pod1"}, + endpoint1 := fwkschd.NewEndpoint( + &fwkdl.EndpointMetadata{ + NamespacedName: k8stypes.NamespacedName{Name: "endpoint1"}, Address: "1.2.3.4", Labels: map[string]string{filter.RoleLabel: filter.RolePrefill}, }, - MetricsState: &backendmetrics.MetricsState{WaitingQueueSize: 0}, - } - pod2 := &types.PodMetrics{ - Pod: &backend.Pod{ - NamespacedName: k8stypes.NamespacedName{Name: "pod2"}, + &fwkdl.Metrics{WaitingQueueSize: 0}, + fwkdl.NewAttributes(), + ) + endpoint2 := fwkschd.NewEndpoint( + &fwkdl.EndpointMetadata{ + NamespacedName: k8stypes.NamespacedName{Name: "endpoint2"}, Address: "5.6.7.8", Labels: map[string]string{filter.RoleLabel: filter.RoleDecode}, }, - MetricsState: &backendmetrics.MetricsState{WaitingQueueSize: 0}, - } - noRolePod1 := &types.PodMetrics{ - Pod: &backend.Pod{ - NamespacedName: k8stypes.NamespacedName{Name: "noRolePod1"}, + &fwkdl.Metrics{WaitingQueueSize: 0}, + fwkdl.NewAttributes(), + ) + noRoleEndpoint1 := fwkschd.NewEndpoint( + &fwkdl.EndpointMetadata{ + NamespacedName: k8stypes.NamespacedName{Name: "noRoleEndpoint1"}, Address: "1.1.1.1", }, - MetricsState: &backendmetrics.MetricsState{WaitingQueueSize: 2}, - } + &fwkdl.Metrics{WaitingQueueSize: 2}, + fwkdl.NewAttributes(), + ) - prefillDecodeResult := &types.SchedulingResult{ - ProfileResults: map[string]*types.ProfileRunResult{ - decode: {TargetPods: []types.Pod{ - &types.ScoredPod{ - Pod: pod2, + prefillDecodeResult := &fwkschd.SchedulingResult{ + ProfileResults: map[string]*fwkschd.ProfileRunResult{ + decode: {TargetEndpoints: []fwkschd.Endpoint{ + &fwkschd.ScoredEndpoint{ + Endpoint: endpoint2, }, }, }, prefill: { - TargetPods: []types.Pod{ - &types.ScoredPod{ - Pod: pod1, + TargetEndpoints: []fwkschd.Endpoint{ + &fwkschd.ScoredEndpoint{ + Endpoint: endpoint1, }, }, }, @@ -76,12 +78,12 @@ func TestPDSchedule(t *testing.T) { 
PrimaryProfileName: decode, } - decodeResult := &types.SchedulingResult{ - ProfileResults: map[string]*types.ProfileRunResult{ + decodeResult := &fwkschd.SchedulingResult{ + ProfileResults: map[string]*fwkschd.ProfileRunResult{ decode: { - TargetPods: []types.Pod{ - &types.ScoredPod{ - Pod: pod2, + TargetEndpoints: []fwkschd.Endpoint{ + &fwkschd.ScoredEndpoint{ + Endpoint: endpoint2, }, }, }, @@ -91,114 +93,114 @@ func TestPDSchedule(t *testing.T) { tests := []struct { name string - req *types.LLMRequest - input []types.Pod - wantRes *types.SchedulingResult - wantRes2 *types.SchedulingResult // a subsequent call to check prefix cache and how it affects PD + req *fwkschd.LLMRequest + input []fwkschd.Endpoint + wantRes *fwkschd.SchedulingResult + wantRes2 *fwkschd.SchedulingResult // a subsequent call to check prefix cache and how it affects PD err bool }{ { - name: "no candidate pods", - req: &types.LLMRequest{ + name: "no candidate endpoints", + req: &fwkschd.LLMRequest{ RequestId: uuid.NewString(), TargetModel: "any-model", - Body: &types.LLMRequestBody{ - Completions: &types.CompletionsRequest{ + Body: &fwkschd.LLMRequestBody{ + Completions: &fwkschd.CompletionsRequest{ Prompt: "12345678901", }, }, }, - input: []types.Pod{}, + input: []fwkschd.Endpoint{}, err: true, }, { - name: "one decode pod, long prompt", - req: &types.LLMRequest{ + name: "one decode endpoint, long prompt", + req: &fwkschd.LLMRequest{ RequestId: uuid.NewString(), TargetModel: "critical", - Body: &types.LLMRequestBody{ - Completions: &types.CompletionsRequest{ + Body: &fwkschd.LLMRequestBody{ + Completions: &fwkschd.CompletionsRequest{ Prompt: "12345678901", }, }, }, - // pod2 will be picked because it is the only pod with Decode role - input: []types.Pod{pod2}, + // endpoint2 will be picked because it is the only endpoint with Decode role + input: []fwkschd.Endpoint{endpoint2}, wantRes: decodeResult, }, { - name: "one prefill pod, long prompt", - req: &types.LLMRequest{ + name: "one prefill 
endpoint, long prompt", + req: &fwkschd.LLMRequest{ RequestId: uuid.NewString(), TargetModel: "critical", - Body: &types.LLMRequestBody{ - Completions: &types.CompletionsRequest{ + Body: &fwkschd.LLMRequestBody{ + Completions: &fwkschd.CompletionsRequest{ Prompt: "12345678901", }, }, }, - // no Decode pod - input: []types.Pod{pod1}, + // no Decode endpoint + input: []fwkschd.Endpoint{endpoint1}, err: true, }, { name: "1P1D - long prompt", - req: &types.LLMRequest{ + req: &fwkschd.LLMRequest{ RequestId: uuid.NewString(), TargetModel: "critical", - Body: &types.LLMRequestBody{ - Completions: &types.CompletionsRequest{ + Body: &fwkschd.LLMRequestBody{ + Completions: &fwkschd.CompletionsRequest{ Prompt: "12345678906", }, }, }, - // pod2 will be picked in the decode profile result, pod1 will be in the prefill profile result - input: []types.Pod{pod1, pod2}, + // endpoint2 will be picked in the decode profile result, endpoint1 will be in the prefill profile result + input: []fwkschd.Endpoint{endpoint1, endpoint2}, wantRes: prefillDecodeResult, wantRes2: decodeResult, }, { name: "1P1Dshort", - req: &types.LLMRequest{ + req: &fwkschd.LLMRequest{ RequestId: uuid.NewString(), TargetModel: "critical", - Body: &types.LLMRequestBody{ - Completions: &types.CompletionsRequest{ + Body: &fwkschd.LLMRequestBody{ + Completions: &fwkschd.CompletionsRequest{ Prompt: "12345", }, }, }, - // pod2 will be picked because it is the decode pod, pod1 shouldn't be picked, + // endpoint2 will be picked because it is the decode endpoint, endpoint1 shouldn't be picked, // because the prompt is too short - input: []types.Pod{pod1, pod2}, + input: []fwkschd.Endpoint{endpoint1, endpoint2}, wantRes: decodeResult, wantRes2: decodeResult, }, { name: "TestRolesWithNoDecode", - req: &types.LLMRequest{ + req: &fwkschd.LLMRequest{ RequestId: uuid.NewString(), TargetModel: "critical", - Body: &types.LLMRequestBody{ - Completions: &types.CompletionsRequest{ + Body: &fwkschd.LLMRequestBody{ + Completions: 
&fwkschd.CompletionsRequest{ Prompt: "12345678901", }, }, }, - input: []types.Pod{pod1, noRolePod1}, - wantRes: &types.SchedulingResult{ - ProfileResults: map[string]*types.ProfileRunResult{ + input: []fwkschd.Endpoint{endpoint1, noRoleEndpoint1}, + wantRes: &fwkschd.SchedulingResult{ + ProfileResults: map[string]*fwkschd.ProfileRunResult{ decode: { - TargetPods: []types.Pod{ - &types.ScoredPod{ - Pod: noRolePod1, + TargetEndpoints: []fwkschd.Endpoint{ + &fwkschd.ScoredEndpoint{ + Endpoint: noRoleEndpoint1, }, }, }, prefill: { - TargetPods: []types.Pod{ - &types.ScoredPod{ - Pod: pod1, + TargetEndpoints: []fwkschd.Endpoint{ + &fwkschd.ScoredEndpoint{ + Endpoint: endpoint1, }, }, }, @@ -208,18 +210,18 @@ func TestPDSchedule(t *testing.T) { }, { name: "1P2D - long prompt", - req: &types.LLMRequest{ + req: &fwkschd.LLMRequest{ RequestId: uuid.NewString(), TargetModel: "critical", - Body: &types.LLMRequestBody{ - Completions: &types.CompletionsRequest{ - Prompt: "12345678906", + Body: &fwkschd.LLMRequestBody{ + Completions: &fwkschd.CompletionsRequest{ + Prompt: "1234567890123456789012345678901234567890", }, }, }, - // pod2 will be picked in the decode profile result cause it has higher score than noRolePod1 - // pod1 will be in the prefill profile result - input: []types.Pod{pod1, pod2, noRolePod1}, + // endpoint2 will be picked in the decode profile result cause it has higher score than noRoleEndpoint1 + // endpoint1 will be in the prefill profile result + input: []fwkschd.Endpoint{endpoint1, endpoint2, noRoleEndpoint1}, wantRes: prefillDecodeResult, wantRes2: decodeResult, }, @@ -232,49 +234,64 @@ func TestPDSchedule(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { // initialize scheduler with config - prefixScorer := prefix.New(ctx, prefix.Config{BlockSize: 5, MaxPrefixBlocksToMatch: 256, LRUCapacityPerServer: 31250}) + prefixScorer, err := prefix.New(ctx, prefix.Config{AutoTune: false, BlockSizeTokens: 2, MaxPrefixBlocksToMatch: 
256, LRUCapacityPerServer: 31250}) + assert.NoError(t, err, "Prefix plugin creation returned unexpected error") - prefillSchedulerProfile := framework.NewSchedulerProfile(). + prefillSchedulerProfile := scheduling.NewSchedulerProfile(). WithFilters(filter.NewPrefillRole()). WithPicker(picker.NewMaxScorePicker(picker.DefaultMaxNumOfEndpoints)) - err := prefillSchedulerProfile.AddPlugins(framework.NewWeightedScorer(prefixScorer, 50)) + err = prefillSchedulerProfile.AddPlugins(scheduling.NewWeightedScorer(prefixScorer, 50)) assert.NoError(t, err, "SchedulerProfile AddPlugins returned unexpected error") - decodeSchedulerProfile := framework.NewSchedulerProfile(). + decodeSchedulerProfile := scheduling.NewSchedulerProfile(). WithFilters(filter.NewDecodeRole()). - WithScorers(framework.NewWeightedScorer(scorer.NewLoadAware(ctx, scorer.QueueThresholdDefault), 1)). + WithScorers(scheduling.NewWeightedScorer(scorer.NewLoadAware(ctx, scorer.QueueThresholdDefault), 1)). WithPicker(picker.NewMaxScorePicker(picker.DefaultMaxNumOfEndpoints)) - err = decodeSchedulerProfile.AddPlugins(framework.NewWeightedScorer(prefixScorer, 0)) + err = decodeSchedulerProfile.AddPlugins(scheduling.NewWeightedScorer(prefixScorer, 0)) assert.NoError(t, err, "SchedulerProfile AddPlugins returned unexpected error") - profileHandle := profile.NewPdProfileHandler(prefill, decode, prefixScorer.TypedName().Type, prefixScorer.TypedName().Name, 10, 5, 0) + deciderPlugin, err := profile.NewPrefixBasedPDDecider(profile.PrefixBasedPDDeciderConfig{NonCachedTokens: 2}) + assert.NoError(t, err) + + profileHandle, err := profile.NewPdProfileHandler(prefill, decode, prefixScorer.TypedName().Type, prefixScorer.TypedName().Name, + 0, deciderPlugin) + assert.NoError(t, err) - schedulerConfig := scheduling.NewSchedulerConfig(profileHandle, map[string]*framework.SchedulerProfile{ + schedulerConfig := scheduling.NewSchedulerConfig(profileHandle, map[string]fwkschd.SchedulerProfile{ prefill: prefillSchedulerProfile, 
decode: decodeSchedulerProfile, }) scheduler := scheduling.NewSchedulerWithConfig(schedulerConfig) + + inputTokens := len(test.req.Body.Completions.Prompt) / profile.AverageCharactersPerToken + for _, pod := range test.input { + pod.Put(approximateprefix.PrefixCacheMatchInfoKey, approximateprefix.NewPrefixCacheMatchInfo(0, inputTokens, 1)) + } got, err := scheduler.Schedule(ctx, test.req, test.input) if test.err != (err != nil) { t.Errorf("Unexpected error, got %v, want %v", err, test.err) } - if diff := cmp.Diff(test.wantRes, got, cmpopts.IgnoreFields(types.ScoredPod{}, "Score")); diff != "" { + if diff := cmp.Diff(test.wantRes, got, cmpopts.IgnoreUnexported(fwkdl.Attributes{}), cmpopts.IgnoreFields(fwkschd.ScoredEndpoint{}, "Score")); diff != "" { t.Errorf("Unexpected output (-want +got): %v", diff) } - if test.wantRes2 != nil { // Checking the prefix match in the decode pod. // make sure prefix plugin stores the prefix hit in cache, so we can test it in the following schedule call prefixScorer.PreRequest(ctx, test.req, got) time.Sleep(time.Second) + // update number of cached tokens "stored" in the first schedule execution + for _, pod := range test.input { + pod.Put(approximateprefix.PrefixCacheMatchInfoKey, approximateprefix.NewPrefixCacheMatchInfo(inputTokens, inputTokens, 1)) + } + got, err = scheduler.Schedule(ctx, test.req, test.input) if test.err != (err != nil) { t.Errorf("Unexpected error in schedule call, got %v, want %v", err, test.err) } - if diff := cmp.Diff(test.wantRes2, got, cmpopts.IgnoreFields(types.ScoredPod{}, "Score")); diff != "" { + if diff := cmp.Diff(test.wantRes2, got, cmpopts.IgnoreUnexported(fwkdl.Attributes{}), cmpopts.IgnoreFields(fwkschd.ScoredEndpoint{}, "Score")); diff != "" { t.Errorf("Unexpected output in subsequent schedule call (-want +got): %v", diff) } } diff --git a/pkg/sidecar/proxy/connector_lmcache.go b/pkg/sidecar/proxy/connector_lmcache.go deleted file mode 100644 index f19412e83..000000000 --- 
a/pkg/sidecar/proxy/connector_lmcache.go +++ /dev/null @@ -1,91 +0,0 @@ -/* -Copyright 2025 The llm-d Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package proxy - -import ( - "encoding/json" - "io" - "net/http" - "strings" -) - -func (s *Server) runLMCacheProtocol(w http.ResponseWriter, r *http.Request, prefillPodHostPort string) { - s.logger.Info("running LMCache protocol") - - // Read and parse request body - defer r.Body.Close() //nolint:all - original, err := io.ReadAll(r.Body) - if err != nil { - w.WriteHeader(http.StatusBadRequest) // TODO: check FastAPI error code when failing to read body - w.Write([]byte(err.Error())) //nolint:all - return - } - - // Parse completion request - var completionRequest map[string]any - if err := json.Unmarshal(original, &completionRequest); err != nil { - if err := errorJSONInvalid(err, w); err != nil { - s.logger.Error(err, "failed to send error response to client") - } - return - } - - // Create prefiller request. Set max_tokens to 1. 
- - ctx := r.Context() - preq := r.Clone(ctx) - - completionRequest[requestFieldMaxTokens] = 1 - completionRequest[requestFieldMaxCompletionTokens] = 1 - - pbody, err := json.Marshal(completionRequest) - if err != nil { - if err := errorJSONInvalid(err, w); err != nil { - s.logger.Error(err, "failed to send error response to client") - } - return - } - preq.Body = io.NopCloser(strings.NewReader(string(pbody))) - preq.ContentLength = int64(len(pbody)) - - // Forward request to prefiller - - prefillHandler, err := s.prefillerProxyHandler(prefillPodHostPort) - if err != nil { - if err := errorBadGateway(err, w); err != nil { - s.logger.Error(err, "failed to send error response to client") - } - return - } - s.logger.V(4).Info("sending prefill request", "to", prefillPodHostPort) - pw := &bufferedResponseWriter{} - prefillHandler.ServeHTTP(pw, preq) - - if pw.statusCode < 200 || pw.statusCode >= 300 { - s.logger.Error(err, "request failed", "code", pw.statusCode) - w.WriteHeader(pw.statusCode) - return - } - - // Forward original request to local decoder - - r.Body = io.NopCloser(strings.NewReader(string(original))) - if !s.forwardDataParallel || !s.dataParallelHandler(w, r) { - s.logger.V(4).Info("sending request to decoder", "to", s.decoderURL.Host) - s.decoderProxy.ServeHTTP(w, r) - } -} diff --git a/pkg/sidecar/proxy/connector_nixlv2.go b/pkg/sidecar/proxy/connector_nixlv2.go index 265072bbf..f8543a711 100644 --- a/pkg/sidecar/proxy/connector_nixlv2.go +++ b/pkg/sidecar/proxy/connector_nixlv2.go @@ -107,7 +107,7 @@ func (s *Server) runNIXLProtocolV2(w http.ResponseWriter, r *http.Request, prefi pw := &bufferedResponseWriter{} prefillHandler.ServeHTTP(pw, preq) - if pw.statusCode < 200 || pw.statusCode >= 300 { + if isHTTPError(pw.statusCode) { s.logger.Error(err, "request failed", "code", pw.statusCode) w.WriteHeader(pw.statusCode) return diff --git a/pkg/sidecar/proxy/connector_shared_storage.go b/pkg/sidecar/proxy/connector_shared_storage.go new file mode 100644 
index 000000000..5245f2993 --- /dev/null +++ b/pkg/sidecar/proxy/connector_shared_storage.go @@ -0,0 +1,276 @@ +/* +Copyright 2025 The llm-d Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package proxy + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "io" + "maps" + "net/http" + "strings" +) + +func (s *Server) runSharedStorageProtocol(w http.ResponseWriter, r *http.Request, prefillPodHostPort string) { + s.logger.V(4).Info("running Shared Storage protocol", "url", prefillPodHostPort) + + // Read and parse request body + defer r.Body.Close() //nolint:all + original, err := io.ReadAll(r.Body) + if err != nil { + w.WriteHeader(http.StatusBadRequest) // TODO: check FastAPI error code when failing to read body + w.Write([]byte(err.Error())) //nolint:all + return + } + + // Parse completion request + var completionRequest map[string]any + if err := json.Unmarshal(original, &completionRequest); err != nil { + if err := errorJSONInvalid(err, w); err != nil { + s.logger.Error(err, "failed to send Invalid JSON error response to client") + } + return + } + + // If "cache_hit_threshold" is present in the request, we try to decode first. The decode node must meet the cache hit threshold in order to execute. + // If the decode node is below the threshold, it won't process the request and return a "cache_threshold" finish reason. In that case, + // we fall back to P/D disaggregation: perform prefill and then decode. 
+ // For more information refer to the RFC https://github.com/vllm-project/vllm/issues/24256 + if cacheHitThreshold, hasCacheHitThreshold := completionRequest[requestFieldCacheHitThreshold]; hasCacheHitThreshold { + s.logger.V(4).Info("cache_hit_threshold field found in the request, trying to decode first", requestFieldCacheHitThreshold, cacheHitThreshold) + decodeReq := cloneRequestWithBody(r, original) + needsPrefill, err := s.tryDecode(w, decodeReq, completionRequest) + if err != nil { + return + } + if !needsPrefill { + s.logger.V(4).Info("decode succeeded without prefill") + return + } + s.logger.V(4).Info("decode failed due to failing to meet the cache hit threshold", requestFieldCacheHitThreshold, cacheHitThreshold) + } + + // we clone the completion request to avoid modifying the original request + prefillRequest := maps.Clone(completionRequest) + if err := s.prefill(w, r, prefillPodHostPort, prefillRequest); err != nil { + s.logger.Error(err, "prefill failed") + return + } + + s.logger.V(4).Info("forwarding to decoder after prefill") + completionRequest[requestFieldCacheHitThreshold] = 0 + decodeRequestBody, err := json.Marshal(completionRequest) + if err != nil { + if err := errorJSONInvalid(err, w); err != nil { + s.logger.Error(err, "failed to send Invalid JSON error response to client") + } + return + } + + decodeReq := cloneRequestWithBody(r, decodeRequestBody) + s.decoderProxy.ServeHTTP(w, decodeReq) +} + +// tryDecode attempts to decode and returns whether prefill is needed. +func (s *Server) tryDecode(w http.ResponseWriter, r *http.Request, completionRequest map[string]any) (bool, error) { + if isStreaming, _ := completionRequest[requestFieldStream].(bool); isStreaming { + if flusher, ok := w.(flushableResponseWriter); ok { + bw := newResponseWriterWithBuffer(flusher) + return s.tryDecodeStreaming(bw, r) + } + } + return s.tryDecodeBuffered(w, r) +} + +// tryDecodeBuffered handles non-streaming decode attempts. 
+// It buffers the entire response before inspecting it. +func (s *Server) tryDecodeBuffered(w http.ResponseWriter, r *http.Request) (bool, error) { + dw := &bufferedResponseWriter{} + s.decoderProxy.ServeHTTP(dw, r) + + if isHTTPError(dw.statusCode) { + + w.WriteHeader(dw.statusCode) + if dw.buffer.Len() > 0 { + w.Write([]byte(dw.buffer.String())) //nolint:all + } + + err := errors.New("decode request failed") + s.logger.Error(err, "unexpected status code", "code", dw.statusCode) + + return false, err + } + + // Parse response to check finish_reason + var response map[string]any + if err := json.Unmarshal([]byte(dw.buffer.String()), &response); err != nil { + s.logger.Error(err, "failed to unmarshal decode response", "response", dw.buffer.String()) + + if err := errorInternalServerError(err, w); err != nil { + s.logger.Error(err, "failed to send error response to client") + } + return false, err + } + + // Check for cache_threshold finish reason + if s.hasCacheThresholdFinishReason(response) { + return true, nil + } + + // Decode succeeded, write response to client + maps.Copy(w.Header(), dw.headers) + w.Write([]byte(dw.buffer.String())) //nolint:all + + return false, nil +} + +// tryDecodeStreaming handles streaming decode attempts. +// It buffers the initial response to check for cache_threshold, then switches +// to direct streaming mode if decode succeeds. +func (s *Server) tryDecodeStreaming(w *responseWriterWithBuffer, r *http.Request) (bool, error) { + // Run ServeHTTP in a goroutine so we can inspect the initial choice to determine if we need to prefill. 
+ done := make(chan struct{}) + go func() { + defer close(done) + s.decoderProxy.ServeHTTP(w, r) + }() + + // Wait for either: + // - firstChunkReady(): first body data is available in buffer + // - done: request completed (possibly with no body, e.g., error response) + select { + case <-w.firstChunkReady(): + case <-done: + s.logger.V(4).Info("request completed without body data") + } + + statusCode := w.getStatusCode() + if isHTTPError(statusCode) { + if err := w.flushBufferAndGoDirect(); err != nil { + s.logger.Error(err, "failed to flush buffer to client") + return false, err + } + return false, fmt.Errorf("decode request failed with status code: %d", statusCode) + } + + // Check buffered SSE content for cache_threshold finish reason. + if s.checkBufferedResponseForCacheThreshold(w.buffered()) { + s.logger.V(4).Info("finish reason cache_threshold detected, needs prefill") + return true, nil + } + + // No cache_threshold finish reason found, flush buffer and switch to direct mode + // to let the rest of the response stream through. + s.logger.V(4).Info("first response for request shows success without cache_threshold finish reason") + if err := w.flushBufferAndGoDirect(); err != nil { + s.logger.Error(err, "failed to flush buffer to client and switch to direct mode") + return false, err + } + <-done + return false, nil +} + +// hasCacheThresholdFinishReason checks if a parsed response contains cache_threshold finish reason. +func (s *Server) hasCacheThresholdFinishReason(response map[string]any) bool { + choices, ok := response[responseFieldChoices].([]any) + if !ok || len(choices) == 0 { + return false + } + + choice, ok := choices[0].(map[string]any) + if !ok { + return false + } + + finishReason, ok := choice[responseFieldFinishReason].(string) + return ok && finishReason == finishReasonCacheThreshold +} + +// checkBufferedResponseForCacheThreshold checks the buffered SSE response for cache_threshold finish reason. 
+// This is only called for streaming responses, so data is always in SSE format. +func (s *Server) checkBufferedResponseForCacheThreshold(data string) bool { + // Parse SSE format: "data: {...json...}\n\ndata: {...json...}\n\n" + for _, line := range strings.Split(data, "\n") { + line = strings.TrimSpace(line) + if line == "" || line == "data: [DONE]" || !strings.HasPrefix(line, "data: ") { + continue + } + + jsonData := strings.TrimPrefix(line, "data: ") + var response map[string]any + if err := json.Unmarshal([]byte(jsonData), &response); err != nil { + s.logger.V(4).Info("skipping malformed SSE chunk", "chunk", jsonData) + continue + } + + if s.hasCacheThresholdFinishReason(response) { + return true + } + } + return false +} + +// prefill routes a request to a prefill node +func (s *Server) prefill(w http.ResponseWriter, r *http.Request, prefillPodHostPort string, completionRequest map[string]any) error { + // Prepare prefill request + completionRequest[requestFieldMaxTokens] = 1 + completionRequest[requestFieldMaxCompletionTokens] = 1 + completionRequest[requestFieldCacheHitThreshold] = 0 + + pbody, err := json.Marshal(completionRequest) + if err != nil { + if err := errorJSONInvalid(err, w); err != nil { + s.logger.Error(err, "failed to send Invalid JSON error response to client") + } + return err + } + preq := cloneRequestWithBody(r, pbody) + + prefillHandler, err := s.prefillerProxyHandler(prefillPodHostPort) + if err != nil { + if err := errorBadGateway(err, w); err != nil { + s.logger.Error(err, "failed to send Bad Gateway error response to client") + } + return err + } + + // send prefill request + s.logger.V(4).Info("sending prefill request", "to", prefillPodHostPort) + pw := &bufferedResponseWriter{} + prefillHandler.ServeHTTP(pw, preq) + + if isHTTPError(pw.statusCode) { + s.logger.Error(nil, "prefill request failed", "code", pw.statusCode) + w.WriteHeader(pw.statusCode) + if pw.buffer.Len() > 0 { + w.Write([]byte(pw.buffer.String())) //nolint:all + } 
+ return fmt.Errorf("prefill request failed with status code: %d", pw.statusCode) + } + + s.logger.V(4).Info("prefill completed successfully") + return nil +} + +func cloneRequestWithBody(r *http.Request, body []byte) *http.Request { + cloned := r.Clone(r.Context()) + cloned.Body = io.NopCloser(bytes.NewReader(body)) + cloned.ContentLength = int64(len(body)) + return cloned +} diff --git a/pkg/sidecar/proxy/connector_test.go b/pkg/sidecar/proxy/connector_test.go index 64ba26879..b2fa407f9 100644 --- a/pkg/sidecar/proxy/connector_test.go +++ b/pkg/sidecar/proxy/connector_test.go @@ -44,7 +44,7 @@ type sidecarTestInfo struct { proxy *Server } -var connectors = []string{ConnectorLMCache, ConnectorNIXLV2} +var connectors = []string{ConnectorSharedStorage, ConnectorNIXLV2} var _ = Describe("Common Connector tests", func() { diff --git a/pkg/sidecar/proxy/errors.go b/pkg/sidecar/proxy/errors.go index 0ff35ae75..394e30072 100644 --- a/pkg/sidecar/proxy/errors.go +++ b/pkg/sidecar/proxy/errors.go @@ -50,6 +50,10 @@ func errorBadGateway(err error, w http.ResponseWriter) error { return sendError(err, "BadGateway", http.StatusBadGateway, w) } +func errorInternalServerError(err error, w http.ResponseWriter) error { + return sendError(err, "InternalServerError", http.StatusInternalServerError, w) +} + // sendError simulates vLLM errors // // Example: diff --git a/pkg/sidecar/proxy/proxy.go b/pkg/sidecar/proxy/proxy.go index bd90dd3a9..8e3b8e733 100644 --- a/pkg/sidecar/proxy/proxy.go +++ b/pkg/sidecar/proxy/proxy.go @@ -46,6 +46,12 @@ const ( requestFieldRemotePort = "remote_port" requestFieldStream = "stream" requestFieldStreamOptions = "stream_options" + requestFieldCacheHitThreshold = "cache_hit_threshold" + + responseFieldChoices = "choices" + responseFieldFinishReason = "finish_reason" + + finishReasonCacheThreshold = "cache_threshold" // SGLang bootstrap fields requestFieldBootstrapHost = "bootstrap_host" @@ -55,8 +61,8 @@ const ( // ConnectorNIXLV2 enables the P/D NIXL 
v2 protocol ConnectorNIXLV2 = "nixlv2" - // ConnectorLMCache enables (now deprecated) P/D LMCache protocol - ConnectorLMCache = "lmcache" + // ConnectorSharedStorage enables (now deprecated) P/D Shared Storage protocol + ConnectorSharedStorage = "shared-storage" // ConnectorSGLang enables SGLang P/D disaggregation protocol ConnectorSGLang = "sglang" @@ -177,8 +183,8 @@ func (s *Server) Clone() *Server { func (s *Server) setConnector() { switch s.config.Connector { - case ConnectorLMCache: - s.runConnectorProtocol = s.runLMCacheProtocol + case ConnectorSharedStorage: + s.runConnectorProtocol = s.runSharedStorageProtocol case ConnectorSGLang: s.runConnectorProtocol = s.runSGLangProtocol case ConnectorNIXLV2: diff --git a/pkg/sidecar/proxy/proxy_helpers.go b/pkg/sidecar/proxy/proxy_helpers.go index 5f68bb4ea..30ba76494 100644 --- a/pkg/sidecar/proxy/proxy_helpers.go +++ b/pkg/sidecar/proxy/proxy_helpers.go @@ -124,3 +124,8 @@ func (s *Server) createDecoderProxyHandler(decoderURL *url.URL, decoderInsecureS } return decoderProxy } + +// isHTTPError returns true if the status code indicates an error (not in the 2xx range). 
+func isHTTPError(statusCode int) bool { + return statusCode < http.StatusOK || statusCode >= http.StatusMultipleChoices +} diff --git a/pkg/sidecar/proxy/status_response_writer.go b/pkg/sidecar/proxy/status_response_writer.go index 721e898e6..24d96d263 100644 --- a/pkg/sidecar/proxy/status_response_writer.go +++ b/pkg/sidecar/proxy/status_response_writer.go @@ -19,8 +19,12 @@ package proxy import ( "net/http" "strings" + "sync" + "sync/atomic" ) +const sseEventDelimiter = "\n\n" + // bufferedResponseWriter receives responses from prefillers type bufferedResponseWriter struct { headers http.Header @@ -45,3 +49,185 @@ func (w *bufferedResponseWriter) Write(b []byte) (int, error) { func (w *bufferedResponseWriter) WriteHeader(statusCode int) { w.statusCode = statusCode } + +type flushableResponseWriter interface { + http.ResponseWriter + http.Flusher +} + +// responseWriterWithBuffer wraps an http.ResponseWriter to buffer initial writes. +// Start in buffer mode to inspect the first chunk, then call flushBufferAndGoDirect() +// to write buffered content and switch to direct pass-through mode. +type responseWriterWithBuffer struct { + writerFlusher flushableResponseWriter + + // buffering is checked atomically to allow lock-free fast paths + // in direct mode (Write and Flush). + buffering atomic.Bool + + // mu protects buffer, statusCode, and wroteHeader during buffering mode + // and during the transition to direct mode. + mu sync.Mutex + buffer strings.Builder + statusCode int + wroteHeader bool + + // ready receives an error (or nil) when the first Write happens, + // signaling that there's data available for inspection or an error occurred. + ready chan struct{} + readyOnce sync.Once +} + +// newResponseWriterWithBuffer creates a new writer starting in buffer mode. 
+func newResponseWriterWithBuffer(w flushableResponseWriter) *responseWriterWithBuffer { + rw := &responseWriterWithBuffer{ + writerFlusher: w, + ready: make(chan struct{}, 1), // buffered to avoid blocking sender + } + rw.buffering.Store(true) + return rw +} + +func (w *responseWriterWithBuffer) Header() http.Header { + return w.writerFlusher.Header() +} + +func (w *responseWriterWithBuffer) Write(b []byte) (int, error) { + if !w.buffering.Load() { + return w.writerFlusher.Write(b) + } + + // Buffering mode, need lock to protect buffer + w.mu.Lock() + defer w.mu.Unlock() + + // Double-check after acquiring lock (may have transitioned to direct mode) + if !w.buffering.Load() { + return w.writerFlusher.Write(b) + } + + if w.statusCode == 0 { + w.statusCode = http.StatusOK + } + + // Write() always returns a nil error + n, _ := w.buffer.Write(b) + + // Signal ready when buffer contains at least 2 complete SSE events. + // For SSE streaming, the first chunk is just the role announcement with + // finish_reason:null. We need the second chunk to see if cache_threshold + // was returned (early abort) or if decode is proceeding normally. + if shouldSignal(w.buffer.String()) { + w.signalReady() + } + + return n, nil +} + +func (w *responseWriterWithBuffer) WriteHeader(statusCode int) { + if !w.buffering.Load() { + w.writerFlusher.WriteHeader(statusCode) + return + } + + w.mu.Lock() + defer w.mu.Unlock() + + if !w.buffering.Load() { + w.writerFlusher.WriteHeader(statusCode) + return + } + + if w.statusCode == 0 { + w.statusCode = statusCode + } +} + +func (w *responseWriterWithBuffer) Flush() { + if w.buffering.Load() { + // Apply same logic as Write(): only signal when we have at least 2 SSE events. 
+ w.mu.Lock() + shouldSignal := shouldSignal(w.buffer.String()) + w.mu.Unlock() + if shouldSignal { + w.signalReady() + } + return + } + w.writerFlusher.Flush() +} + +// firstChunkReady returns a channel that receives nil when the first complete +// chunk of body data is available in the buffer. For SSE streaming responses, +// this is signaled when the buffer contains "\n\n" (complete SSE event). +// As a fallback, Flush() also signals readiness. +func (w *responseWriterWithBuffer) firstChunkReady() <-chan struct{} { + return w.ready +} + +func (w *responseWriterWithBuffer) signalReady() { + w.readyOnce.Do(func() { + w.ready <- struct{}{} + close(w.ready) + }) +} + +func (w *responseWriterWithBuffer) writeHeaderOnce() { + if w.wroteHeader { + return + } + w.wroteHeader = true + if w.statusCode != 0 { + w.writerFlusher.WriteHeader(w.statusCode) + } +} + +// buffered returns the currently buffered content for inspection. +func (w *responseWriterWithBuffer) buffered() string { + w.mu.Lock() + defer w.mu.Unlock() + return w.buffer.String() +} + +// getStatusCode returns the status code that was set (0 if not set). +func (w *responseWriterWithBuffer) getStatusCode() int { + w.mu.Lock() + defer w.mu.Unlock() + return w.statusCode +} + +// flushBufferAndGoDirect writes any buffered content to the underlying writer +// and switches to direct mode for all subsequent writes. +func (w *responseWriterWithBuffer) flushBufferAndGoDirect() error { + w.mu.Lock() + defer w.mu.Unlock() + + if !w.buffering.Load() { + return nil // already in direct mode + } + + w.writeHeaderOnce() + + // Write buffered content to underlying writer + if w.buffer.Len() > 0 { + _, err := w.writerFlusher.Write([]byte(w.buffer.String())) + if err != nil { + return err + } + } + + // Flush BEFORE switching to direct mode. + // This prevents concurrent Flush() calls on the underlying writer, + // since the proxy goroutine's Flush() will no-op while buffering=true. 
+ w.writerFlusher.Flush() + + // Switch to direct mode. After this, the proxy goroutine handles + // all writes and flushes directly (no concurrency with us). + w.buffering.Store(false) + + return nil +} + +func shouldSignal(data string) bool { + return strings.Count(data, sseEventDelimiter) >= 2 +} diff --git a/scripts/fetch-python-wrapper.sh b/scripts/fetch-python-wrapper.sh deleted file mode 100755 index 0c8bc1d68..000000000 --- a/scripts/fetch-python-wrapper.sh +++ /dev/null @@ -1,47 +0,0 @@ -#!/usr/bin/env bash -# fetch-python-wrapper.sh -# Fetches the Python wrapper file (render_jinja_template_wrapper.py) from llm-d-kv-cache-manager -# for use in Docker builds and local development. -# Version can be provided as CLI arg or via KVCACHE_MANAGER_VERSION env var (default v0.3.2). -# -# This script replicates the original Dockerfile logic: -# 1. Creates a temporary directory -# 2. Clones the repo into that directory -# 3. Creates the output directory structure -# 4. Copies the wrapper file to the output location -# 5. Cleans up the temporary directory - -set -euo pipefail - -VERSION="${1:-${KVCACHE_MANAGER_VERSION:-v0.3.2}}" -OUTPUT_DIR="${2:-llm-d-kv-cache-manager/pkg/preprocessing/chat_completions}" - -REPO_URL="https://github.com/llm-d/llm-d-kv-cache-manager.git" -WRAPPER_FILE="pkg/preprocessing/chat_completions/render_jinja_template_wrapper.py" - -# Create temporary directory (equivalent to: mkdir -p /tmp/kv-cache-manager) -# TEMP_DIR will be an absolute path like /tmp/tmp.XXXXXX -TEMP_DIR=$(mktemp -d) -trap "rm -rf ${TEMP_DIR}" EXIT - -echo "Fetching Python wrapper from llm-d-kv-cache-manager@${VERSION}..." - -# Equivalent to: cd /tmp/kv-cache-manager && git clone ... . 
-# (clones repo contents directly into TEMP_DIR - using absolute path, no need to cd) -git clone --depth 1 --branch "${VERSION}" "${REPO_URL}" "${TEMP_DIR}" - -# Create output directory if it doesn't exist -# (equivalent to: mkdir -p /workspace/llm-d-kv-cache-manager/pkg/preprocessing/chat_completions) -# OUTPUT_DIR is relative to current working directory (relative path, same as original) -mkdir -p "${OUTPUT_DIR}" - -# Copy wrapper file -# Source: absolute path ${TEMP_DIR}/${WRAPPER_FILE} (e.g., /tmp/tmp.XXXXXX/pkg/.../wrapper.py) -# Destination: relative path ${OUTPUT_DIR}/ (e.g., llm-d-kv-cache-manager/pkg/.../) -# (equivalent to original: cp pkg/.../wrapper.py /workspace/... from within temp dir) -cp "${TEMP_DIR}/${WRAPPER_FILE}" "${OUTPUT_DIR}/" - -# Cleanup happens automatically via trap (equivalent to: rm -rf /tmp/kv-cache-manager) - -echo "Successfully fetched render_jinja_template_wrapper.py to ${OUTPUT_DIR}/" - diff --git a/scripts/kind-dev-env.sh b/scripts/kind-dev-env.sh index 5ef91726d..6b0d22a04 100755 --- a/scripts/kind-dev-env.sh +++ b/scripts/kind-dev-env.sh @@ -37,7 +37,7 @@ EPP_IMAGE="${EPP_IMAGE:-${IMAGE_REGISTRY}/llm-d-inference-scheduler:${EPP_TAG}}" export EPP_IMAGE # Set the model name to deploy -export MODEL_NAME="${MODEL_NAME:-food-review}" +export MODEL_NAME="${MODEL_NAME:-TinyLlama/TinyLlama-1.1B-Chat-v1.0}" # Extract model family (e.g., "meta-llama" from "meta-llama/Llama-3.1-8B-Instruct") export MODEL_FAMILY="${MODEL_NAME%%/*}" # Extract model ID (e.g., "Llama-3.1-8B-Instruct") @@ -74,32 +74,41 @@ export VLLM_REPLICA_COUNT_D="${VLLM_REPLICA_COUNT_D:-2}" # Data Parallel size export VLLM_DATA_PARALLEL_SIZE="${VLLM_DATA_PARALLEL_SIZE:-1}" -PRIMARY_PORT="0" -if [ "${PD_ENABLED}" != "\"true\"" ] && [ ${VLLM_DATA_PARALLEL_SIZE} -eq 1 ]; then - if [ "${KV_CACHE_ENABLED}" != "true" ]; then - DEFAULT_EPP_CONFIG="deploy/config/sim-epp-config.yaml" - else - DEFAULT_EPP_CONFIG="deploy/config/sim-epp-kvcache-config.yaml" - fi -else - if [ 
"${KV_CACHE_ENABLED}" != "true" ]; then - if [ "${PD_ENABLED}" == "\"true\"" ]; then - DEFAULT_EPP_CONFIG="deploy/config/sim-pd-epp-config.yaml" - if [ ${VLLM_DATA_PARALLEL_SIZE} -ne 1 ]; then - PRIMARY_PORT="8000" - fi - else - DEFAULT_EPP_CONFIG="deploy/config/dp-epp-config.yaml" - fi - else +# Validate configuration constraints +if [ "${KV_CACHE_ENABLED}" == "true" ]; then + # KV cache requires simple mode: no PD and DP size must be 1 + if [ "${PD_ENABLED}" == "\"true\"" ] || [ ${VLLM_DATA_PARALLEL_SIZE} -ne 1 ]; then echo "Invalid configuration: PD_ENABLED=true and KV_CACHE_ENABLED=true is not supported" exit 1 fi fi -export EPP_CONFIG="${EPP_CONFIG:-${DEFAULT_EPP_CONFIG}}" +# Set PRIMARY_PORT based on PD mode with data parallelism +if [ "${PD_ENABLED}" == "\"true\"" ] && [ ${VLLM_DATA_PARALLEL_SIZE} -ne 1 ]; then + PRIMARY_PORT="8000" +else + PRIMARY_PORT="0" +fi export PRIMARY_PORT +# Determine EPP config file based on feature flags +if [ "${KV_CACHE_ENABLED}" == "true" ]; then + # KV cache mode (simple mode only) + DEFAULT_EPP_CONFIG="deploy/config/sim-epp-kvcache-config.yaml" +elif [ "${PD_ENABLED}" == "\"true\"" ]; then + # Prefill-Decode mode + DEFAULT_EPP_CONFIG="deploy/config/sim-pd-epp-config.yaml" +elif [ ${VLLM_DATA_PARALLEL_SIZE} -ne 1 ]; then + # Data Parallel mode (only needed for Istio pre-1.28.1) + # Not really called in kind(docker.io/istio/pilot:1.28.1) by "make env-dev-kind" + DEFAULT_EPP_CONFIG="deploy/config/dp-epp-config.yaml" +else + # Simple mode + DEFAULT_EPP_CONFIG="deploy/config/sim-epp-config.yaml" +fi + +export EPP_CONFIG="${EPP_CONFIG:-${DEFAULT_EPP_CONFIG}}" + # ------------------------------------------------------------------------------ # Setup & Requirement Checks # ------------------------------------------------------------------------------ diff --git a/scripts/kubernetes-dev-env.sh b/scripts/kubernetes-dev-env.sh index 6cd8c4456..215fe86a4 100755 --- a/scripts/kubernetes-dev-env.sh +++ b/scripts/kubernetes-dev-env.sh @@ 
-24,7 +24,7 @@ if [[ -z "${HF_TOKEN:-}" ]]; then exit 1 fi -export VLLM_CHART_DIR="${VLLM_CHART_DIR:-../llm-d-kv-cache-manager/vllm-setup-helm}" +export VLLM_CHART_DIR="${VLLM_CHART_DIR:-../llm-d-kv-cache/vllm-setup-helm}" # Check that Chart.yaml exists if [[ ! -f "$VLLM_CHART_DIR/Chart.yaml" ]]; then echo "Chart.yaml not found in $VLLM_CHART_DIR" diff --git a/scripts/pull_images.sh b/scripts/pull_images.sh index 3acf439c1..e82337f32 100755 --- a/scripts/pull_images.sh +++ b/scripts/pull_images.sh @@ -7,7 +7,7 @@ echo "Using container tool: ${CONTAINER_RUNTIME}" # Set a default EPP_TAG if not provided EPP_TAG="${EPP_TAG:-dev}" # Set a default VLLM_SIMULATOR_TAG if not provided -VLLM_SIMULATOR_TAG="${VLLM_SIMULATOR_TAG:-v0.6.1}" +VLLM_SIMULATOR_TAG="${VLLM_SIMULATOR_TAG:-latest}" # Set the default routing side car image tag SIDECAR_TAG="${SIDECAR_TAG:-dev}" diff --git a/test-blocking-labels-1771128456.txt b/test-blocking-labels-1771128456.txt new file mode 100644 index 000000000..e5dba9286 --- /dev/null +++ b/test-blocking-labels-1771128456.txt @@ -0,0 +1 @@ +Test file for: blocking-labels - Sun Feb 15 06:07:39 IST 2026 diff --git a/test-blocking-labels-1771128602.txt b/test-blocking-labels-1771128602.txt new file mode 100644 index 000000000..b546b6479 --- /dev/null +++ b/test-blocking-labels-1771128602.txt @@ -0,0 +1 @@ +Test file for: blocking-labels - Sun Feb 15 06:10:06 IST 2026 diff --git a/test-lgtm-1770898248.txt b/test-lgtm-1770898248.txt new file mode 100644 index 000000000..b0ce6a34d --- /dev/null +++ b/test-lgtm-1770898248.txt @@ -0,0 +1 @@ +This is a test file for LGTM workflow automation - Thu Feb 12 14:10:50 IST 2026 diff --git a/test-lgtm-1770899120.txt b/test-lgtm-1770899120.txt new file mode 100644 index 000000000..086324f21 --- /dev/null +++ b/test-lgtm-1770899120.txt @@ -0,0 +1 @@ +This is a test file for LGTM workflow automation - Thu Feb 12 14:25:22 IST 2026 diff --git a/test-lgtm-1770899284.txt b/test-lgtm-1770899284.txt new file mode 100644 
index 000000000..034733efc --- /dev/null +++ b/test-lgtm-1770899284.txt @@ -0,0 +1 @@ +This is a test file for LGTM workflow automation - Thu Feb 12 14:28:06 IST 2026 diff --git a/test-lgtm-1770900845.txt b/test-lgtm-1770900845.txt new file mode 100644 index 000000000..d6ca35750 --- /dev/null +++ b/test-lgtm-1770900845.txt @@ -0,0 +1 @@ +This is a test file for LGTM workflow automation - Thu Feb 12 14:54:07 IST 2026 diff --git a/test-lgtm-1770901687.txt b/test-lgtm-1770901687.txt new file mode 100644 index 000000000..1b0ff35e7 --- /dev/null +++ b/test-lgtm-1770901687.txt @@ -0,0 +1 @@ +This is a test file for LGTM workflow automation - Thu Feb 12 15:08:08 IST 2026 diff --git a/test-lgtm-1770902441.txt b/test-lgtm-1770902441.txt new file mode 100644 index 000000000..50c087c16 --- /dev/null +++ b/test-lgtm-1770902441.txt @@ -0,0 +1 @@ +This is a test file for LGTM workflow automation - Thu Feb 12 15:20:43 IST 2026 diff --git a/test-lgtm-1771123253.txt b/test-lgtm-1771123253.txt new file mode 100644 index 000000000..5891c2bf6 --- /dev/null +++ b/test-lgtm-1771123253.txt @@ -0,0 +1 @@ +This is a test file for LGTM workflow automation - Sun Feb 15 04:40:55 IST 2026 diff --git a/test-lgtm-1771123487.txt b/test-lgtm-1771123487.txt new file mode 100644 index 000000000..b52683765 --- /dev/null +++ b/test-lgtm-1771123487.txt @@ -0,0 +1 @@ +This is a test file for LGTM workflow automation - Sun Feb 15 04:44:49 IST 2026 diff --git a/test-lgtm-1771123805.txt b/test-lgtm-1771123805.txt new file mode 100644 index 000000000..d6e5201db --- /dev/null +++ b/test-lgtm-1771123805.txt @@ -0,0 +1 @@ +This is a test file for LGTM workflow automation - Sun Feb 15 04:50:07 IST 2026 diff --git a/test-lgtm-1771123942.txt b/test-lgtm-1771123942.txt new file mode 100644 index 000000000..a880ada61 --- /dev/null +++ b/test-lgtm-1771123942.txt @@ -0,0 +1 @@ +This is a test file for LGTM workflow automation - Sun Feb 15 04:52:23 IST 2026 diff --git a/test-lgtm-1771124326.txt 
b/test-lgtm-1771124326.txt new file mode 100644 index 000000000..b84903b8b --- /dev/null +++ b/test-lgtm-1771124326.txt @@ -0,0 +1 @@ +This is a test file for LGTM workflow automation - Sun Feb 15 04:58:48 IST 2026 diff --git a/test-lgtm-1771125160.txt b/test-lgtm-1771125160.txt new file mode 100644 index 000000000..ac1a8cf75 --- /dev/null +++ b/test-lgtm-1771125160.txt @@ -0,0 +1 @@ +This is a test file for LGTM workflow automation - Sun Feb 15 05:12:42 IST 2026 diff --git a/test-lgtm-1771125246.txt b/test-lgtm-1771125246.txt new file mode 100644 index 000000000..15e6d5010 --- /dev/null +++ b/test-lgtm-1771125246.txt @@ -0,0 +1 @@ +This is a test file for LGTM workflow automation - Sun Feb 15 05:14:08 IST 2026 diff --git a/test-lgtm-1771125307.txt b/test-lgtm-1771125307.txt new file mode 100644 index 000000000..7676ff8dd --- /dev/null +++ b/test-lgtm-1771125307.txt @@ -0,0 +1 @@ +This is a test file for LGTM workflow automation - Sun Feb 15 05:15:08 IST 2026 diff --git a/test-lgtm-1771126707.txt b/test-lgtm-1771126707.txt new file mode 100644 index 000000000..262706925 --- /dev/null +++ b/test-lgtm-1771126707.txt @@ -0,0 +1 @@ +This is a test file for LGTM workflow automation - Sun Feb 15 05:38:29 IST 2026 diff --git a/test-lgtm-1771127615.txt b/test-lgtm-1771127615.txt new file mode 100644 index 000000000..fe1386765 --- /dev/null +++ b/test-lgtm-1771127615.txt @@ -0,0 +1 @@ +This is a test file for LGTM workflow automation - Sun Feb 15 05:53:36 IST 2026 diff --git a/test-open-pr-1771133038.txt b/test-open-pr-1771133038.txt new file mode 100644 index 000000000..58d095262 --- /dev/null +++ b/test-open-pr-1771133038.txt @@ -0,0 +1 @@ +Test file for: open-pr - Sun Feb 15 07:24:04 IST 2026 diff --git a/test-open-pr-1771133183.txt b/test-open-pr-1771133183.txt new file mode 100644 index 000000000..8c635fe9c --- /dev/null +++ b/test-open-pr-1771133183.txt @@ -0,0 +1 @@ +Test file for: open-pr - Sun Feb 15 07:26:29 IST 2026 diff --git a/test-open-pr-1771135205.txt 
b/test-open-pr-1771135205.txt new file mode 100644 index 000000000..153db7286 --- /dev/null +++ b/test-open-pr-1771135205.txt @@ -0,0 +1 @@ +Test file for: open-pr - Sun Feb 15 08:00:10 IST 2026 diff --git a/test-open-pr-1771137224.txt b/test-open-pr-1771137224.txt new file mode 100644 index 000000000..40e1c3735 --- /dev/null +++ b/test-open-pr-1771137224.txt @@ -0,0 +1 @@ +Test file for: open-pr - Sun Feb 15 08:33:50 IST 2026 diff --git a/test-success-path-1771126029.txt b/test-success-path-1771126029.txt new file mode 100644 index 000000000..bd613da8c --- /dev/null +++ b/test-success-path-1771126029.txt @@ -0,0 +1 @@ +Test file for: success-path - Sun Feb 15 05:27:12 IST 2026 diff --git a/test-success-path-1771126298.txt b/test-success-path-1771126298.txt new file mode 100644 index 000000000..59aa2a331 --- /dev/null +++ b/test-success-path-1771126298.txt @@ -0,0 +1 @@ +Test file for: success-path - Sun Feb 15 05:31:42 IST 2026 diff --git a/test/config/prefix_cache_mode_test.go b/test/config/prefix_cache_mode_test.go index 5a13d0087..3965f9246 100644 --- a/test/config/prefix_cache_mode_test.go +++ b/test/config/prefix_cache_mode_test.go @@ -7,7 +7,7 @@ import ( "github.com/go-logr/logr" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/config/loader" - giePlugins "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/plugins" + giePlugins "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/framework/interface/plugin" "sigs.k8s.io/gateway-api-inference-extension/test/utils" "github.com/llm-d/llm-d-inference-scheduler/pkg/plugins" diff --git a/test/e2e/e2e_suite_test.go b/test/e2e/e2e_suite_test.go index f4f5c9247..ff400872e 100644 --- a/test/e2e/e2e_suite_test.go +++ b/test/e2e/e2e_suite_test.go @@ -27,6 +27,8 @@ import ( ) const ( + // kindClusterName is the name of the Kind cluster created for e2e tests. + kindClusterName = "e2e-tests" // defaultReadyTimeout is the default timeout for a resource to report a ready state. 
defaultReadyTimeout = 3 * time.Minute // defaultInterval is the default interval to check if a resource exists or ready conditions. @@ -61,9 +63,8 @@ var ( containerRuntime = env.GetEnvString("CONTAINER_RUNTIME", "docker", ginkgo.GinkgoLogr) eppImage = env.GetEnvString("EPP_IMAGE", "ghcr.io/llm-d/llm-d-inference-scheduler:dev", ginkgo.GinkgoLogr) - vllmSimImage = env.GetEnvString("VLLM_SIMULATOR_IMAGE", "ghcr.io/llm-d/llm-d-inference-sim:dev", ginkgo.GinkgoLogr) + vllmSimImage = env.GetEnvString("VLLM_SIMULATOR_IMAGE", "ghcr.io/llm-d/llm-d-inference-sim:latest", ginkgo.GinkgoLogr) sideCarImage = env.GetEnvString("SIDECAR_IMAGE", "ghcr.io/llm-d/llm-d-routing-sidecar:dev", ginkgo.GinkgoLogr) - // nsName is the namespace in which the K8S objects will be created nsName = env.GetEnvString("NAMESPACE", "default", ginkgo.GinkgoLogr) @@ -110,8 +111,18 @@ var _ = ginkgo.BeforeSuite(func() { }) var _ = ginkgo.AfterSuite(func() { - if k8sContext != "" { - // Used an existing Kubernetes context + if k8sContext == "" { + // delete kind cluster we created + ginkgo.By("Deleting kind cluster " + kindClusterName) + command := exec.Command("kind", "delete", "cluster", "--name", kindClusterName) + session, err := gexec.Start(command, ginkgo.GinkgoWriter, ginkgo.GinkgoWriter) + if err != nil { + ginkgo.GinkgoLogr.Error(err, "Failed to delete kind cluster") + } else { + gomega.Eventually(session).WithTimeout(60 * time.Second).Should(gexec.Exit()) + } + } else { + // Used an existing Kubernetes context, clean up created resources // Stop port-forward if portForwardSession != nil { portForwardSession.Terminate() @@ -135,18 +146,12 @@ var _ = ginkgo.AfterSuite(func() { err := testConfig.KubeCli.CoreV1().Namespaces().Delete(testConfig.Context, nsName, metav1.DeleteOptions{}) gomega.Expect(err).ShouldNot(gomega.HaveOccurred()) } - return } - - command := exec.Command("kind", "delete", "cluster", "--name", "e2e-tests") - session, err := gexec.Start(command, ginkgo.GinkgoWriter, 
ginkgo.GinkgoWriter) - gomega.Expect(err).ShouldNot(gomega.HaveOccurred()) - gomega.Eventually(session).WithTimeout(600 * time.Second).Should(gexec.Exit(0)) }) // Create the Kubernetes cluster for the E2E tests and load the local images func setupK8sCluster() { - command := exec.Command("kind", "create", "cluster", "--name", "e2e-tests", "--config", "-") + command := exec.Command("kind", "create", "cluster", "--name", kindClusterName, "--config", "-") stdin, err := command.StdinPipe() gomega.Expect(err).ShouldNot(gomega.HaveOccurred()) go func() { @@ -166,17 +171,27 @@ func setupK8sCluster() { kindLoadImage(vllmSimImage) kindLoadImage(eppImage) kindLoadImage(sideCarImage) + kindLoadImage(vllmSimImage) } func kindLoadImage(image string) { tempDir := ginkgo.GinkgoT().TempDir() target := tempDir + "/container.tar" - ginkgo.By(fmt.Sprintf("Loading %s into the cluster e2e-tests using %s", image, containerRuntime)) + ginkgo.By(fmt.Sprintf("Loading %s into the cluster %s using %s", image, kindClusterName, containerRuntime)) _, err := exec.LookPath(containerRuntime) gomega.Expect(err).ShouldNot(gomega.HaveOccurred(), "Could not find %s in PATH", containerRuntime) + // Pull the image first to ensure it's available locally + ginkgo.By(fmt.Sprintf("Pulling image %s if not available locally", image)) + pullCommand := exec.Command(containerRuntime, "pull", image) + pullSession, pullErr := gexec.Start(pullCommand, ginkgo.GinkgoWriter, ginkgo.GinkgoWriter) + if pullErr == nil { + // Wait for pull to complete, but don't fail if image already exists or can't be pulled + gomega.Eventually(pullSession).WithTimeout(600 * time.Second).Should(gexec.Exit()) + } + saveArgs := []string{"save", "--output", target} if containerRuntime == "docker" { // The platform flag is required for docker save to work but it is an unsupported flag for podman @@ -189,7 +204,7 @@ func kindLoadImage(image string) { gomega.Expect(err).ShouldNot(gomega.HaveOccurred()) gomega.Eventually(session).WithTimeout(600 
* time.Second).Should(gexec.Exit(0)) - command = exec.Command("kind", "--name", "e2e-tests", "load", "image-archive", target) + command = exec.Command("kind", "--name", kindClusterName, "load", "image-archive", target) session, err = gexec.Start(command, ginkgo.GinkgoWriter, ginkgo.GinkgoWriter) gomega.Expect(err).ShouldNot(gomega.HaveOccurred()) gomega.Eventually(session).WithTimeout(600 * time.Second).Should(gexec.Exit(0)) @@ -278,10 +293,11 @@ func createInferencePool(numTargetPorts int, toDelete bool) []string { } infPoolYaml := testutils.ReadYaml(inferExtManifest) - targetPorts := "" + var b strings.Builder for idx := range numTargetPorts { - targetPorts += fmt.Sprintf("\n - number: %d", 8000+idx) + fmt.Fprintf(&b, "\n - number: %d", 8000+idx) } + targetPorts := b.String() infPoolYaml = substituteMany(infPoolYaml, map[string]string{ "${POOL_NAME}": poolName, diff --git a/test/e2e/e2e_test.go b/test/e2e/e2e_test.go index d23b20a58..44671ed71 100644 --- a/test/e2e/e2e_test.go +++ b/test/e2e/e2e_test.go @@ -25,7 +25,7 @@ const ( // running the vLLM simulator without PD simDeployment = "./yaml/vllm-sim.yaml" // simPDDeployment references the YAML file for the deployment - // running the vLLM simulator with PD + // running the vLLM simulator with PD (connector type is configurable via ${CONNECTOR_TYPE}) simPDDeployment = "./yaml/vllm-sim-pd.yaml" // simDPDeployment references the YAML file for the deployment // running the vLLM simulator with Data Parallel @@ -126,8 +126,178 @@ var _ = ginkgo.Describe("Run end to end tests", ginkgo.Ordered, func() { labelFilter2 := fmt.Sprintf(`decision_type="decode-only",model_name="%s"`, modelName) decodeOnlyCount := getCounterMetric(metricsURL, "llm_d_inference_scheduler_pd_decision_total", labelFilter2) - gomega.Expect(prefillDecodeCount).Should(gomega.Equal(6)) - gomega.Expect(decodeOnlyCount).Should(gomega.Equal(0)) + gomega.Expect(prefillDecodeCount).Should(gomega.Equal(4)) + 
gomega.Expect(decodeOnlyCount).Should(gomega.Equal(2)) + + testutils.DeleteObjects(testConfig, epp) + testutils.DeleteObjects(testConfig, modelServers) + }) + }) + + ginkgo.When("Running a PD configuration with shared-storage connector", func() { + ginkgo.It("should run regular (non-streaming) requests successfully", func() { + infPoolObjects = createInferencePool(1, true) + + prefillReplicas := 1 + decodeReplicas := 2 + modelServers := createModelServersWithConnector(true, false, false, 0, prefillReplicas, decodeReplicas, "shared-storage") + + epp := createEndPointPicker(pdConfig) + + prefillPods, decodePods := getModelServerPods(podSelector, prefillSelector, decodeSelector) + gomega.Expect(prefillPods).Should(gomega.HaveLen(prefillReplicas)) + gomega.Expect(decodePods).Should(gomega.HaveLen(decodeReplicas)) + + // Test regular completion request + nsHdr, podHdrCompletion, _ := runCompletion(simplePrompt, modelName) + gomega.Expect(nsHdr).Should(gomega.Equal(nsName)) + gomega.Expect(podHdrCompletion).Should(gomega.BeElementOf(decodePods)) + + // Test regular chat completion request + nsHdr, podHdrChat, _ := runChatCompletion(simplePrompt) + gomega.Expect(nsHdr).Should(gomega.Equal(nsName)) + gomega.Expect(podHdrChat).Should(gomega.BeElementOf(decodePods)) + + // Run completion with a different prompt + nsHdr, podHdr, _ := runCompletion(extraPrompt, modelName) + gomega.Expect(nsHdr).Should(gomega.Equal(nsName)) + gomega.Expect(podHdr).Should(gomega.BeElementOf(decodePods)) + + // Run completion with original prompt (should go to same pod due to prefix cache) + nsHdr, podHdr, _ = runCompletion(simplePrompt, modelName) + gomega.Expect(nsHdr).Should(gomega.Equal(nsName)) + gomega.Expect(podHdr).Should(gomega.BeElementOf(decodePods)) + gomega.Expect(podHdr).Should(gomega.Equal(podHdrCompletion)) + + testutils.DeleteObjects(testConfig, epp) + testutils.DeleteObjects(testConfig, modelServers) + }) + + ginkgo.It("should run streaming requests successfully", func() { + 
infPoolObjects = createInferencePool(1, true) + + prefillReplicas := 1 + decodeReplicas := 2 + modelServers := createModelServersWithConnector(true, false, false, 0, prefillReplicas, decodeReplicas, "shared-storage") + + epp := createEndPointPicker(pdConfig) + + prefillPods, decodePods := getModelServerPods(podSelector, prefillSelector, decodeSelector) + gomega.Expect(prefillPods).Should(gomega.HaveLen(prefillReplicas)) + gomega.Expect(decodePods).Should(gomega.HaveLen(decodeReplicas)) + + // Test streaming completion request + nsHdr, podHdr := runStreamingCompletion(simplePrompt, modelName) + gomega.Expect(nsHdr).Should(gomega.Equal(nsName)) + gomega.Expect(podHdr).Should(gomega.BeElementOf(decodePods)) + + // Test streaming chat completion request + nsHdr, podHdr = runStreamingChatCompletion(simplePrompt) + gomega.Expect(nsHdr).Should(gomega.Equal(nsName)) + gomega.Expect(podHdr).Should(gomega.BeElementOf(decodePods)) + + // Run streaming completion with a different prompt + nsHdr, podHdr = runStreamingCompletion(extraPrompt, modelName) + gomega.Expect(nsHdr).Should(gomega.Equal(nsName)) + gomega.Expect(podHdr).Should(gomega.BeElementOf(decodePods)) + + testutils.DeleteObjects(testConfig, epp) + testutils.DeleteObjects(testConfig, modelServers) + }) + + ginkgo.It("should handle decode-first success scenario with cache_hit_threshold", func() { + // This test verifies the decode-first optimization: + // When cache_hit_threshold is set and the decode succeeds (cache hit), + // the request should complete without falling back to P/D. + // IMPORTANT: The prefill pod should NOT process any requests in this scenario. 
+ infPoolObjects = createInferencePool(1, true) + + prefillReplicas := 1 + decodeReplicas := 2 + modelServers := createModelServersWithConnector(true, false, false, 0, prefillReplicas, decodeReplicas, "shared-storage") + + epp := createEndPointPicker(pdConfig) + + prefillPods, decodePods := getModelServerPods(podSelector, prefillSelector, decodeSelector) + gomega.Expect(prefillPods).Should(gomega.HaveLen(prefillReplicas)) + gomega.Expect(decodePods).Should(gomega.HaveLen(decodeReplicas)) + + // Get prefill request count BEFORE the test + prefillCountBefore := getPrefillRequestCount(prefillPods[0]) + ginkgo.By(fmt.Sprintf("Prefill request count before decode-first test: %d", prefillCountBefore)) + + // Test decode-first success: cache_hit_threshold is set, but simulator returns "stop" + // (without X-Cache-Threshold header), meaning decode succeeded without prefill + nsHdr, podHdr, finishReason := runCompletionWithCacheThreshold(simplePrompt, 0.5, false) + gomega.Expect(nsHdr).Should(gomega.Equal(nsName)) + gomega.Expect(podHdr).Should(gomega.BeElementOf(decodePods)) + gomega.Expect(finishReason).ShouldNot(gomega.Equal("cache_threshold")) + + // Test streaming decode-first success + nsHdr, podHdr, finishReason = runStreamingCompletionWithCacheThreshold(simplePrompt, 0.5, false) + gomega.Expect(nsHdr).Should(gomega.Equal(nsName)) + gomega.Expect(podHdr).Should(gomega.BeElementOf(decodePods)) + gomega.Expect(finishReason).ShouldNot(gomega.Equal("cache_threshold")) + + // Get prefill request count AFTER the test + prefillCountAfter := getPrefillRequestCount(prefillPods[0]) + ginkgo.By(fmt.Sprintf("Prefill request count after decode-first test: %d", prefillCountAfter)) + + // VERIFY: Prefill pod should NOT have processed any new requests + // (decode-first succeeded, so no P/D fallback occurred) + gomega.Expect(prefillCountAfter).Should(gomega.Equal(prefillCountBefore), + "Prefill pod should NOT process requests when cache threshold is met (decode-first success)") + + 
testutils.DeleteObjects(testConfig, epp) + testutils.DeleteObjects(testConfig, modelServers) + }) + + ginkgo.It("should handle decode-first fallback to P/D when cache threshold not met", func() { + // This test verifies the decode-first fallback scenario: + // When cache_hit_threshold is set and the decode returns cache_threshold finish_reason, + // the sidecar should fall back to P/D disaggregation. + // IMPORTANT: The prefill pod SHOULD process requests in this scenario. + infPoolObjects = createInferencePool(1, true) + + prefillReplicas := 1 + decodeReplicas := 2 + modelServers := createModelServersWithConnector(true, false, false, 0, prefillReplicas, decodeReplicas, "shared-storage") + + epp := createEndPointPicker(pdConfig) + + prefillPods, decodePods := getModelServerPods(podSelector, prefillSelector, decodeSelector) + gomega.Expect(prefillPods).Should(gomega.HaveLen(prefillReplicas)) + gomega.Expect(decodePods).Should(gomega.HaveLen(decodeReplicas)) + + // Get prefill request count BEFORE the test + prefillCountBefore := getPrefillRequestCount(prefillPods[0]) + ginkgo.By(fmt.Sprintf("Prefill request count before P/D fallback test: %d", prefillCountBefore)) + + // Test decode-first fallback: cache_hit_threshold is set AND X-Cache-Threshold header + // forces simulator to return "cache_threshold" finish_reason, triggering P/D fallback + nsHdr, podHdr, finishReason := runCompletionWithCacheThreshold(simplePrompt, 0.5, true) + gomega.Expect(nsHdr).Should(gomega.Equal(nsName)) + gomega.Expect(podHdr).Should(gomega.BeElementOf(decodePods)) + // The sidecar completes the P/D flow but returns cache_threshold as the finish_reason + // from the initial decode attempt (which triggered the fallback) + gomega.Expect(finishReason).Should(gomega.Equal("cache_threshold")) + + // Test streaming decode-first fallback + nsHdr, podHdr, finishReason = runStreamingCompletionWithCacheThreshold(extraPrompt, 0.5, true) + gomega.Expect(nsHdr).Should(gomega.Equal(nsName)) + 
gomega.Expect(podHdr).Should(gomega.BeElementOf(decodePods)) + gomega.Expect(finishReason).Should(gomega.Equal("cache_threshold")) + + // Get prefill request count AFTER the test + prefillCountAfter := getPrefillRequestCount(prefillPods[0]) + ginkgo.By(fmt.Sprintf("Prefill request count after P/D fallback test: %d", prefillCountAfter)) + + // VERIFY: Prefill pod SHOULD have processed 2 new requests (1 regular + 1 streaming) + // (decode-first failed, so P/D fallback occurred and prefill was invoked) + gomega.Expect(prefillCountAfter).Should(gomega.BeNumerically(">", prefillCountBefore), + "Prefill pod SHOULD process requests when cache threshold is NOT met (P/D fallback)") + gomega.Expect(prefillCountAfter-prefillCountBefore).Should(gomega.Equal(2), + "Prefill pod should have processed exactly 2 requests (1 regular + 1 streaming)") testutils.DeleteObjects(testConfig, epp) testutils.DeleteObjects(testConfig, modelServers) @@ -265,7 +435,13 @@ var _ = ginkgo.Describe("Run end to end tests", ginkgo.Ordered, func() { }) // createModelServers creates the model server resources used for testing from the given filePaths. +// Uses the default connector (nixlv2) for P/D deployments. func createModelServers(withPD, withKV, withDP bool, vllmReplicas, prefillReplicas, decodeReplicas int) []string { + return createModelServersWithConnector(withPD, withKV, withDP, vllmReplicas, prefillReplicas, decodeReplicas, "nixlv2") +} + +// createModelServersWithConnector creates model server resources with a specific connector type. 
+func createModelServersWithConnector(withPD, withKV, withDP bool, vllmReplicas, prefillReplicas, decodeReplicas int, connector string) []string { theModelName := modelName theSafeModelName := modelName if withKV { @@ -286,6 +462,7 @@ func createModelServers(withPD, withKV, withDP bool, vllmReplicas, prefillReplic "${MODEL_NAME_SAFE}": theSafeModelName, "${POOL_NAME}": poolName, "${KV_CACHE_ENABLED}": strconv.FormatBool(withKV), + "${CONNECTOR_TYPE}": connector, "${SIDECAR_IMAGE}": sideCarImage, "${VLLM_REPLICA_COUNT}": strconv.Itoa(vllmReplicas), "${VLLM_REPLICA_COUNT_D}": strconv.Itoa(decodeReplicas), @@ -429,6 +606,219 @@ func getCounterMetric(metricsURL, metricName, labelMatch string) int { return 0 } +func runStreamingCompletion(prompt string, theModel openai.CompletionNewParamsModel) (string, string) { + ginkgo.By(fmt.Sprintf("Sending Streaming Completion Request: (port %s) model=%s", port, theModel)) + + // Use raw HTTP for streaming to capture headers + body := fmt.Sprintf(`{"model":"%s","prompt":"%s","max_tokens":50,"stream":true}`, theModel, prompt) + req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("http://localhost:%s/v1/completions", port), strings.NewReader(body)) + gomega.Expect(err).ShouldNot(gomega.HaveOccurred()) + req.Header.Set("Content-Type", "application/json") + + resp, err := http.DefaultClient.Do(req) + gomega.Expect(err).ShouldNot(gomega.HaveOccurred()) + defer func() { + err := resp.Body.Close() + gomega.Expect(err).ToNot(gomega.HaveOccurred()) + }() + + gomega.Expect(resp.StatusCode).Should(gomega.Equal(http.StatusOK)) + + namespaceHeader := resp.Header.Get("x-inference-namespace") + podHeader := resp.Header.Get("x-inference-pod") + + // Read and verify the streaming response + respBody, err := io.ReadAll(resp.Body) + gomega.Expect(err).ShouldNot(gomega.HaveOccurred()) + + ginkgo.By(fmt.Sprintf("Streaming Completion received response length: %d bytes", len(respBody))) + + return namespaceHeader, podHeader +} + +func 
runStreamingChatCompletion(prompt string) (string, string) { + ginkgo.By(fmt.Sprintf("Sending Streaming Chat Completion Request: (port %s)", port)) + + // Use raw HTTP for streaming to capture headers + body := fmt.Sprintf(`{"model":"%s","messages":[{"role":"user","content":"%s"}],"stream":true}`, modelName, prompt) + req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("http://localhost:%s/v1/chat/completions", port), strings.NewReader(body)) + gomega.Expect(err).ShouldNot(gomega.HaveOccurred()) + req.Header.Set("Content-Type", "application/json") + + resp, err := http.DefaultClient.Do(req) + gomega.Expect(err).ShouldNot(gomega.HaveOccurred()) + defer func() { + err := resp.Body.Close() + gomega.Expect(err).ToNot(gomega.HaveOccurred()) + }() + + gomega.Expect(resp.StatusCode).Should(gomega.Equal(http.StatusOK)) + + namespaceHeader := resp.Header.Get("x-inference-namespace") + podHeader := resp.Header.Get("x-inference-pod") + + // Read and verify the streaming response + respBody, err := io.ReadAll(resp.Body) + gomega.Expect(err).ShouldNot(gomega.HaveOccurred()) + + ginkgo.By(fmt.Sprintf("Streaming Chat Completion received response length: %d bytes", len(respBody))) + + return namespaceHeader, podHeader +} + +// runCompletionWithCacheThreshold sends a completion request with cache_hit_threshold parameter. +// This triggers the decode-first optimization in the shared-storage connector. +// Returns namespace header, pod header, and the finish reason from the response. 
+func runCompletionWithCacheThreshold(prompt string, cacheHitThreshold float64, forceCacheThresholdFinishReason bool) (string, string, string) { + ginkgo.By(fmt.Sprintf("Sending Completion Request with cache_hit_threshold=%v, forceCacheThreshold=%v", cacheHitThreshold, forceCacheThresholdFinishReason)) + + body := fmt.Sprintf(`{"model":"%s","prompt":"%s","max_tokens":10,"cache_hit_threshold":%v}`, modelName, prompt, cacheHitThreshold) + req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("http://localhost:%s/v1/completions", port), strings.NewReader(body)) + gomega.Expect(err).ShouldNot(gomega.HaveOccurred()) + req.Header.Set("Content-Type", "application/json") + + // Add X-Cache-Threshold header to force the simulator to return cache_threshold finish_reason + if forceCacheThresholdFinishReason { + req.Header.Set("X-Cache-Threshold-Finish-Reason", "true") + } + + resp, err := http.DefaultClient.Do(req) + gomega.Expect(err).ShouldNot(gomega.HaveOccurred()) + defer func() { + err := resp.Body.Close() + gomega.Expect(err).ToNot(gomega.HaveOccurred()) + }() + + gomega.Expect(resp.StatusCode).Should(gomega.Equal(http.StatusOK)) + + namespaceHeader := resp.Header.Get("x-inference-namespace") + podHeader := resp.Header.Get("x-inference-pod") + + // Parse response to get finish_reason + respBody, err := io.ReadAll(resp.Body) + gomega.Expect(err).ShouldNot(gomega.HaveOccurred()) + + // Extract finish_reason from JSON response + finishReason := extractFinishReason(string(respBody)) + + ginkgo.By(fmt.Sprintf("Completion Response: ns=%s, pod=%s, finish_reason=%s", namespaceHeader, podHeader, finishReason)) + + return namespaceHeader, podHeader, finishReason +} + +// runStreamingCompletionWithCacheThreshold sends a streaming completion request with cache_hit_threshold. 
+func runStreamingCompletionWithCacheThreshold(prompt string, cacheHitThreshold float64, forceCacheThresholdFinishReason bool) (string, string, string) { + ginkgo.By(fmt.Sprintf("Sending Streaming Completion Request with cache_hit_threshold=%v, forceCacheThreshold=%v", cacheHitThreshold, forceCacheThresholdFinishReason)) + + body := fmt.Sprintf(`{"model":"%s","prompt":"%s","max_tokens":10,"stream":true,"cache_hit_threshold":%v}`, modelName, prompt, cacheHitThreshold) + req, err := http.NewRequest(http.MethodPost, fmt.Sprintf("http://localhost:%s/v1/completions", port), strings.NewReader(body)) + gomega.Expect(err).ShouldNot(gomega.HaveOccurred()) + req.Header.Set("Content-Type", "application/json") + + if forceCacheThresholdFinishReason { + req.Header.Set("X-Cache-Threshold-Finish-Reason", "true") + } + + resp, err := http.DefaultClient.Do(req) + gomega.Expect(err).ShouldNot(gomega.HaveOccurred()) + defer func() { + err := resp.Body.Close() + gomega.Expect(err).ToNot(gomega.HaveOccurred()) + }() + + gomega.Expect(resp.StatusCode).Should(gomega.Equal(http.StatusOK)) + + namespaceHeader := resp.Header.Get("x-inference-namespace") + podHeader := resp.Header.Get("x-inference-pod") + + // Read streaming response and extract finish_reason from the last data chunk + respBody, err := io.ReadAll(resp.Body) + gomega.Expect(err).ShouldNot(gomega.HaveOccurred()) + + finishReason := extractFinishReasonFromStreaming(string(respBody)) + + ginkgo.By(fmt.Sprintf("Streaming Completion Response: ns=%s, pod=%s, finish_reason=%s", namespaceHeader, podHeader, finishReason)) + + return namespaceHeader, podHeader, finishReason +} + +// extractFinishReason extracts the finish_reason field from a JSON response string. 
+func extractFinishReason(jsonStr string) string { + // Simple extraction - look for "finish_reason":"value" pattern + idx := strings.Index(jsonStr, `"finish_reason":"`) + if idx == -1 { + // Try with null value + if strings.Contains(jsonStr, `"finish_reason":null`) { + return "null" + } + return "" + } + start := idx + len(`"finish_reason":"`) + end := strings.Index(jsonStr[start:], `"`) + if end == -1 { + return "" + } + return jsonStr[start : start+end] +} + +// extractFinishReasonFromStreaming extracts the finish_reason from the last SSE data chunk. +func extractFinishReasonFromStreaming(sseData string) string { + // Find the last "finish_reason" that is not null + lines := strings.Split(sseData, "\n") + lastFinishReason := "" + for _, line := range lines { + if strings.HasPrefix(line, "data: ") && !strings.Contains(line, "[DONE]") { + fr := extractFinishReason(line) + if fr != "" && fr != "null" { + lastFinishReason = fr + } + } + } + return lastFinishReason +} + +// getPrefillRequestCount gets the total request count from a prefill pod's metrics endpoint. +// This is used to verify whether a request was processed by the prefill pod. +func getPrefillRequestCount(prefillPodName string) int { + ginkgo.By("Getting request count from prefill pod: " + prefillPodName) + + // Use Kubernetes API proxy to access the metrics endpoint + output, err := testConfig.KubeCli.CoreV1().RESTClient(). + Get(). + Namespace(nsName). + Resource("pods"). + Name(prefillPodName + ":8000"). + SubResource("proxy"). + Suffix("metrics"). + DoRaw(testConfig.Context) + if err != nil { + ginkgo.By(fmt.Sprintf("Warning: Could not get metrics from prefill pod %s: %v", prefillPodName, err)) + return -1 + } + + return parseRequestCountFromMetrics(string(output)) +} + +// parseRequestCountFromMetrics extracts the request count from Prometheus metrics output. 
+func parseRequestCountFromMetrics(metricsOutput string) int { + // Look for vllm:e2e_request_latency_seconds_count{model_name="food-review"} + lines := strings.Split(metricsOutput, "\n") + for _, line := range lines { + if strings.Contains(line, "vllm:e2e_request_latency_seconds_count") && + strings.Contains(line, "food-review") { + // Extract the count value after the last space + parts := strings.Fields(line) + if len(parts) >= 2 { + count, err := strconv.Atoi(parts[len(parts)-1]) + if err == nil { + return count + } + } + } + } + return 0 +} + // Simple EPP configuration for running without P/D const simpleConfig = `apiVersion: inference.networking.x-k8s.io/v1alpha1 kind: EndpointPickerConfig @@ -453,19 +843,24 @@ schedulingProfiles: // EPP configuration for running with P/D const pdConfig = `apiVersion: inference.networking.x-k8s.io/v1alpha1 kind: EndpointPickerConfig +featureGates: +- prepareDataPlugins plugins: - type: prefill-header-handler - type: prefix-cache-scorer parameters: - hashBlockSize: 10 + blockSizeTokens: 16 maxPrefixBlocksToMatch: 256 lruCapacityPerServer: 256 - type: prefill-filter - type: decode-filter - type: max-score-picker +- type: prefix-based-pd-decider + parameters: + nonCachedTokens: 16 - type: pd-profile-handler parameters: - threshold: 10 + deciderPluginName: prefix-based-pd-decider schedulingProfiles: - name: prefill plugins: @@ -487,15 +882,16 @@ kind: EndpointPickerConfig plugins: - type: precise-prefix-cache-scorer parameters: + tokenProcessorConfig: + blockSize: 16 + hashSeed: "42" kvEventsConfig: zmqEndpoint: tcp://0.0.0.0:5557 indexerConfig: prefixStoreConfig: blockSize: 16 - tokenProcessorConfig: - blockSize: 16 # must match vLLM block size if not default (16) - hashSeed: "42" # must match PYTHONHASHSEED in vLLM pods tokenizersPoolConfig: + modelName: Qwen/Qwen2.5-1.5B-Instruct hf: tokenizersCacheDir: "/cache/tokenizers" kvBlockIndexConfig: diff --git a/test/e2e/yaml/vllm-sim-pd.yaml b/test/e2e/yaml/vllm-sim-pd.yaml index 
0f757c2f2..7a8abf14c 100644 --- a/test/e2e/yaml/vllm-sim-pd.yaml +++ b/test/e2e/yaml/vllm-sim-pd.yaml @@ -65,7 +65,7 @@ spec: args: - "--port=8000" - "--vllm-port=8200" - - "--connector=nixlv2" + - "--connector=${CONNECTOR_TYPE}" - "--secure-proxy=false" - "--decoder-use-tls=false" ports: diff --git a/test/e2e/yaml/vllm-sim.yaml b/test/e2e/yaml/vllm-sim.yaml index 36036996c..af6ad5d89 100644 --- a/test/e2e/yaml/vllm-sim.yaml +++ b/test/e2e/yaml/vllm-sim.yaml @@ -15,41 +15,45 @@ spec: app: ${POOL_NAME} spec: containers: - - name: vllm - image: ${VLLM_SIMULATOR_IMAGE} - imagePullPolicy: IfNotPresent - args: - - "--mode=echo" - - "--enable-kvcache=${KV_CACHE_ENABLED}" - - "--port=8000" - - "--model=${MODEL_NAME}" - - "--kv-cache-size=1024" - - "--block-size=16" - - "--zmq-endpoint=tcp://e2e-epp.default.svc.cluster.local:5557" - - "--event-batch-size=16" - - "--tokenizers-cache-dir=/tokenizer-cache" - ports: - - name: http - containerPort: 8000 - protocol: TCP - env: - - name: PORT - value: "8000" - - name: PYTHONHASHSEED - value: "42" - - name: POD_NAME - valueFrom: - fieldRef: - apiVersion: v1 - fieldPath: metadata.name - - name: POD_NAMESPACE - valueFrom: - fieldRef: - apiVersion: v1 - fieldPath: metadata.namespace - volumeMounts: - - name: tokenizer-cache - mountPath: /tokenizer-cache + - name: vllm + image: ${VLLM_SIMULATOR_IMAGE} + imagePullPolicy: IfNotPresent + args: + - "--mode=echo" + - "--enable-kvcache=${KV_CACHE_ENABLED}" + - "--port=8000" + - "--model=${MODEL_NAME}" + - "--kv-cache-size=1024" + - "--block-size=16" + - "--zmq-endpoint=tcp://e2e-epp.default.svc.cluster.local:5557" + - "--event-batch-size=16" + - "--tokenizers-cache-dir=/tokenizer-cache" + ports: + - name: http + containerPort: 8000 + protocol: TCP + env: + - name: POD_IP + valueFrom: + fieldRef: + fieldPath: status.podIP + - name: PORT + value: "8000" + - name: PYTHONHASHSEED + value: "42" + - name: POD_NAME + valueFrom: + fieldRef: + apiVersion: v1 + fieldPath: metadata.name + - name: 
POD_NAMESPACE + valueFrom: + fieldRef: + apiVersion: v1 + fieldPath: metadata.namespace + volumeMounts: + - name: tokenizer-cache + mountPath: /tokenizer-cache volumes: - - name: tokenizer-cache - emptyDir: {} + - name: tokenizer-cache + emptyDir: {} diff --git a/test/scripts/run_e2e.sh b/test/scripts/run_e2e.sh index 5d81e4ce6..a51fe0f4b 100755 --- a/test/scripts/run_e2e.sh +++ b/test/scripts/run_e2e.sh @@ -1,5 +1,17 @@ #!/bin/bash +set -euo pipefail + +cleanup() { + echo "Interrupted! Cleaning up kind cluster..." + kind delete cluster --name e2e-tests 2>/dev/null || true + exit 130 # SIGINT (Ctrl+C) +} + +# Set trap only for interruption signals +# Normally kind cluster cleanup is done by AfterSuite +trap cleanup INT TERM + echo "Running end to end tests" DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" diff --git a/test/sidecar/config/nixl/qwen-decoder-pod.yaml b/test/sidecar/config/nixl/qwen-decoder-pod.yaml index aaf7e1a80..6e24e5a85 100644 --- a/test/sidecar/config/nixl/qwen-decoder-pod.yaml +++ b/test/sidecar/config/nixl/qwen-decoder-pod.yaml @@ -58,13 +58,13 @@ spec: fieldRef: fieldPath: status.podIP - name: VLLM_NIXL_SIDE_CHANNEL_PORT - value: "5557" + value: "5600" - name: HF_HUB_CACHE value: /vllm-workspace/models - name: VLLM_LOGGING_LEVEL value: DEBUG ports: - - containerPort: 5557 + - containerPort: 5600 protocol: TCP volumeMounts: - name: model-cache diff --git a/test/sidecar/config/nixl/qwen-prefiller-pod.yaml b/test/sidecar/config/nixl/qwen-prefiller-pod.yaml index 1792a1c85..d5bb900e5 100644 --- a/test/sidecar/config/nixl/qwen-prefiller-pod.yaml +++ b/test/sidecar/config/nixl/qwen-prefiller-pod.yaml @@ -38,7 +38,7 @@ spec: - name: UCX_TLS value: "cuda_ipc,cuda_copy,tcp" - name: VLLM_NIXL_SIDE_CHANNEL_PORT - value: "5557" + value: "5600" - name: VLLM_NIXL_SIDE_CHANNEL_HOST valueFrom: fieldRef: @@ -53,7 +53,7 @@ spec: ports: - containerPort: 8000 protocol: TCP - - containerPort: 5557 + - containerPort: 5600 protocol: TCP resources: limits: 
diff --git a/test/sidecar/mock/chat_completions_handler.go b/test/sidecar/mock/chat_completions_handler.go index ddc94f917..6d5b8cd35 100644 --- a/test/sidecar/mock/chat_completions_handler.go +++ b/test/sidecar/mock/chat_completions_handler.go @@ -132,8 +132,8 @@ func (cc *ChatCompletionHandler) ServeHTTP(w http.ResponseWriter, r *http.Reques } - case "lmcache": - // LMCache protocol just returns empty response + case "shared-storage": + // Shared Storage protocol just returns empty response rawResponse = `{}` default: