diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp
index 4c09d8619d9e..9ff754cf1e99 100755
--- a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp
+++ b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp
@@ -68,7 +68,7 @@ std::optional from_torch(std::optiona
 class PyKvCacheManager : public tbk::BaseKVCacheManager
 {
 public:
-    NB_TRAMPOLINE(tbk::BaseKVCacheManager, 30);
+    NB_TRAMPOLINE(tbk::BaseKVCacheManager, 36);
 
     // using BaseKVCacheManager::BaseKVCacheManager; // Inherit constructors
     void allocatePools(bool useUvm = false) override
@@ -255,6 +255,12 @@ class PyKvCacheManager : public tbk::BaseKVCacheManager
     {
         NB_OVERRIDE_PURE(flushIterationEvents);
    }
+
+    SizeType32 countReusableBlocks(VecUniqueTokens const& uniqueTokens, tb::LlmRequest const& llmRequest,
+        bool onlyAllocated = false) const override
+    {
+        NB_OVERRIDE_PURE(countReusableBlocks, uniqueTokens, llmRequest, onlyAllocated);
+    }
 };
 // TODO: Deduplicate executor bindings
 KvCacheStats
diff --git a/security_scanning/docs/poetry.lock b/security_scanning/docs/poetry.lock
index 0f110d524616..6846958a8fcb 100644
--- a/security_scanning/docs/poetry.lock
+++ b/security_scanning/docs/poetry.lock
@@ -976,14 +976,14 @@ rtd = ["ipython", "myst-nb", "sphinx", "sphinx-book-theme", "sphinx-examples"]
 
 [[package]]
 name = "sphinx-togglebutton"
-version = "0.4.4"
+version = "0.4.5"
 description = "Toggle page content and collapse admonitions in Sphinx."
 optional = false
 python-versions = "*"
 groups = ["main"]
 files = [
-    {file = "sphinx_togglebutton-0.4.4-py3-none-any.whl", hash = "sha256:820658cd4c4c34c2ee7a21105e638b2f65a9e1d43ee991090715eb7fd9683cdf"},
-    {file = "sphinx_togglebutton-0.4.4.tar.gz", hash = "sha256:04c332692fd5f5363ad02a001e693369767d6c1f0e58279770a2aeb571b472a1"},
+    {file = "sphinx_togglebutton-0.4.5-py3-none-any.whl", hash = "sha256:74eac6d2426110c3e1e6f989a98e07d7823141a335df1ad8a9d637bdf6a7af62"},
+    {file = "sphinx_togglebutton-0.4.5.tar.gz", hash = "sha256:c870dfbd3bc6e119b50ff9a37a64f8991902269e856728931c7d89877e8d4b3d"},
 ]
 
 [package.dependencies]
@@ -1239,4 +1239,4 @@ test = ["pytest (>=6.0.0)", "setuptools (>=77)"]
 [metadata]
 lock-version = "2.1"
 python-versions = ">=3.10,<3.13"
-content-hash = "25155b7ceb59522a3d568a3a7f15a11aca6e1b2e7f17bde117f1b1b33be32945"
+content-hash = "0efe4e268bbba908b5b3650561fd66a46f89161835945adb6f58e34005c39019"
diff --git a/security_scanning/docs/pyproject.toml b/security_scanning/docs/pyproject.toml
index da05d31bade5..655c48ba3040 100644
--- a/security_scanning/docs/pyproject.toml
+++ b/security_scanning/docs/pyproject.toml
@@ -14,7 +14,7 @@
 dependencies = [
     "breathe (>=4.36.0,<5.0.0)",
     "sphinx-copybutton (>=0.5.2,<0.6.0)",
     "autodoc-pydantic (>=2.2.0,<3.0.0)",
-    "sphinx-togglebutton (>=0.4.4,<0.5.0)",
+    "sphinx-togglebutton (>=0.4.5,<0.5.0)",
     "sphinxcontrib-mermaid (>=2.0.1,<3.0.0)"
 ]
diff --git a/security_scanning/examples/auto_deploy/poetry.lock b/security_scanning/examples/auto_deploy/poetry.lock
index a8b64a12b0ca..7cdbce6aac32 100644
--- a/security_scanning/examples/auto_deploy/poetry.lock
+++ b/security_scanning/examples/auto_deploy/poetry.lock
@@ -1188,14 +1188,14 @@ all = ["flake8 (>=7.1.1)", "mypy (>=1.11.2)", "pytest (>=8.3.2)", "ruff (>=0.6.2
 
 [[package]]
 name = "ipython"
-version = "8.38.0"
+version = "8.39.0"
 description = "IPython: Productive Interactive Computing"
 optional = false
 python-versions = ">=3.10"
 groups = ["main"]
 files = [
-    {file = "ipython-8.38.0-py3-none-any.whl", hash = "sha256:750162629d800ac65bb3b543a14e7a74b0e88063eac9b92124d4b2aa3f6d8e86"},
-    {file = "ipython-8.38.0.tar.gz", hash = "sha256:9cfea8c903ce0867cc2f23199ed8545eb741f3a69420bfcf3743ad1cec856d39"},
+    {file = "ipython-8.39.0-py3-none-any.whl", hash = "sha256:bb3c51c4fa8148ab1dea07a79584d1c854e234ea44aa1283bcb37bc75054651f"},
+    {file = "ipython-8.39.0.tar.gz", hash = "sha256:4110ae96012c379b8b6db898a07e186c40a2a1ef5d57a7fa83166047d9da7624"},
 ]
 
 [package.dependencies]
diff --git a/security_scanning/examples/models/contrib/hyperclovax/poetry.lock b/security_scanning/examples/models/contrib/hyperclovax/poetry.lock
index 3d88f4da4fb3..560ec6d5b0df 100644
--- a/security_scanning/examples/models/contrib/hyperclovax/poetry.lock
+++ b/security_scanning/examples/models/contrib/hyperclovax/poetry.lock
@@ -257,14 +257,14 @@ files = [
 
 [[package]]
 name = "fsspec"
-version = "2026.2.0"
+version = "2026.3.0"
 description = "File-system specification"
 optional = false
 python-versions = ">=3.10"
 groups = ["main"]
 files = [
-    {file = "fsspec-2026.2.0-py3-none-any.whl", hash = "sha256:98de475b5cb3bd66bedd5c4679e87b4fdfe1a3bf4d707b151b3c07e58c9a2437"},
-    {file = "fsspec-2026.2.0.tar.gz", hash = "sha256:6544e34b16869f5aacd5b90bdf1a71acb37792ea3ddf6125ee69a22a53fb8bff"},
+    {file = "fsspec-2026.3.0-py3-none-any.whl", hash = "sha256:d2ceafaad1b3457968ed14efa28798162f1638dbb5d2a6868a2db002a5ee39a4"},
+    {file = "fsspec-2026.3.0.tar.gz", hash = "sha256:1ee6a0e28677557f8c2f994e3eea77db6392b4de9cd1f5d7a9e87a0ae9d01b41"},
 ]
 
 [package.extras]
diff --git a/security_scanning/examples/models/contrib/mmdit/poetry.lock b/security_scanning/examples/models/contrib/mmdit/poetry.lock
index b52e1703493e..19d81ff964ae 100644
--- a/security_scanning/examples/models/contrib/mmdit/poetry.lock
+++ b/security_scanning/examples/models/contrib/mmdit/poetry.lock
@@ -281,14 +281,14 @@ files = [
 
 [[package]]
 name = "fsspec"
-version = "2026.2.0"
+version = "2026.3.0"
 description = "File-system specification"
 optional = false
 python-versions = ">=3.10"
 groups = ["main"]
 files = [
-    {file = "fsspec-2026.2.0-py3-none-any.whl", hash = "sha256:98de475b5cb3bd66bedd5c4679e87b4fdfe1a3bf4d707b151b3c07e58c9a2437"},
-    {file = "fsspec-2026.2.0.tar.gz", hash = "sha256:6544e34b16869f5aacd5b90bdf1a71acb37792ea3ddf6125ee69a22a53fb8bff"},
+    {file = "fsspec-2026.3.0-py3-none-any.whl", hash = "sha256:d2ceafaad1b3457968ed14efa28798162f1638dbb5d2a6868a2db002a5ee39a4"},
+    {file = "fsspec-2026.3.0.tar.gz", hash = "sha256:1ee6a0e28677557f8c2f994e3eea77db6392b4de9cd1f5d7a9e87a0ae9d01b41"},
 ]
 
 [package.extras]
diff --git a/security_scanning/examples/models/contrib/stdit/poetry.lock b/security_scanning/examples/models/contrib/stdit/poetry.lock
index f329861fc0e3..373565e58d1d 100644
--- a/security_scanning/examples/models/contrib/stdit/poetry.lock
+++ b/security_scanning/examples/models/contrib/stdit/poetry.lock
@@ -782,14 +782,14 @@ files = [
 
 [[package]]
 name = "fsspec"
-version = "2026.2.0"
+version = "2026.3.0"
 description = "File-system specification"
 optional = false
 python-versions = ">=3.10"
 groups = ["main"]
 files = [
-    {file = "fsspec-2026.2.0-py3-none-any.whl", hash = "sha256:98de475b5cb3bd66bedd5c4679e87b4fdfe1a3bf4d707b151b3c07e58c9a2437"},
-    {file = "fsspec-2026.2.0.tar.gz", hash = "sha256:6544e34b16869f5aacd5b90bdf1a71acb37792ea3ddf6125ee69a22a53fb8bff"},
+    {file = "fsspec-2026.3.0-py3-none-any.whl", hash = "sha256:d2ceafaad1b3457968ed14efa28798162f1638dbb5d2a6868a2db002a5ee39a4"},
+    {file = "fsspec-2026.3.0.tar.gz", hash = "sha256:1ee6a0e28677557f8c2f994e3eea77db6392b4de9cd1f5d7a9e87a0ae9d01b41"},
 ]
 
 [package.extras]
diff --git a/security_scanning/examples/models/core/mixtral/poetry.lock b/security_scanning/examples/models/core/mixtral/poetry.lock
index 4185fb2e7512..b29857bb92cb 100644
--- a/security_scanning/examples/models/core/mixtral/poetry.lock
+++ b/security_scanning/examples/models/core/mixtral/poetry.lock
@@ -309,14 +309,14 @@ files = [
 
 [[package]]
 name = "fsspec"
-version = "2026.2.0"
+version = "2026.3.0"
 description = "File-system specification"
 optional = false
 python-versions = ">=3.10"
 groups = ["main"]
 files = [
-    {file = "fsspec-2026.2.0-py3-none-any.whl", hash = "sha256:98de475b5cb3bd66bedd5c4679e87b4fdfe1a3bf4d707b151b3c07e58c9a2437"},
-    {file = "fsspec-2026.2.0.tar.gz", hash = "sha256:6544e34b16869f5aacd5b90bdf1a71acb37792ea3ddf6125ee69a22a53fb8bff"},
+    {file = "fsspec-2026.3.0-py3-none-any.whl", hash = "sha256:d2ceafaad1b3457968ed14efa28798162f1638dbb5d2a6868a2db002a5ee39a4"},
+    {file = "fsspec-2026.3.0.tar.gz", hash = "sha256:1ee6a0e28677557f8c2f994e3eea77db6392b4de9cd1f5d7a9e87a0ae9d01b41"},
 ]
 
 [package.extras]
diff --git a/security_scanning/examples/models/core/mllama/poetry.lock b/security_scanning/examples/models/core/mllama/poetry.lock
index e3c48ca44e15..d2c75dd6eded 100644
--- a/security_scanning/examples/models/core/mllama/poetry.lock
+++ b/security_scanning/examples/models/core/mllama/poetry.lock
@@ -302,14 +302,14 @@ files = [
 
 [[package]]
 name = "fsspec"
-version = "2026.2.0"
+version = "2026.3.0"
 description = "File-system specification"
 optional = false
 python-versions = ">=3.10"
 groups = ["main"]
 files = [
-    {file = "fsspec-2026.2.0-py3-none-any.whl", hash = "sha256:98de475b5cb3bd66bedd5c4679e87b4fdfe1a3bf4d707b151b3c07e58c9a2437"},
-    {file = "fsspec-2026.2.0.tar.gz", hash = "sha256:6544e34b16869f5aacd5b90bdf1a71acb37792ea3ddf6125ee69a22a53fb8bff"},
+    {file = "fsspec-2026.3.0-py3-none-any.whl", hash = "sha256:d2ceafaad1b3457968ed14efa28798162f1638dbb5d2a6868a2db002a5ee39a4"},
+    {file = "fsspec-2026.3.0.tar.gz", hash = "sha256:1ee6a0e28677557f8c2f994e3eea77db6392b4de9cd1f5d7a9e87a0ae9d01b41"},
 ]
 
 [package.extras]
diff --git a/security_scanning/examples/serve/poetry.lock b/security_scanning/examples/serve/poetry.lock
index d3452dc7d489..84ddd7df306e 100644
--- a/security_scanning/examples/serve/poetry.lock
+++ b/security_scanning/examples/serve/poetry.lock
@@ -1189,14 +1189,14 @@ files = [
 
 [[package]]
 name = "fsspec"
-version = "2026.2.0"
+version = "2026.3.0"
 description = "File-system specification"
 optional = false
 python-versions = ">=3.10"
 groups = ["main"]
 files = [
-    {file = "fsspec-2026.2.0-py3-none-any.whl", hash = "sha256:98de475b5cb3bd66bedd5c4679e87b4fdfe1a3bf4d707b151b3c07e58c9a2437"},
-    {file = "fsspec-2026.2.0.tar.gz", hash = "sha256:6544e34b16869f5aacd5b90bdf1a71acb37792ea3ddf6125ee69a22a53fb8bff"},
+    {file = "fsspec-2026.3.0-py3-none-any.whl", hash = "sha256:d2ceafaad1b3457968ed14efa28798162f1638dbb5d2a6868a2db002a5ee39a4"},
+    {file = "fsspec-2026.3.0.tar.gz", hash = "sha256:1ee6a0e28677557f8c2f994e3eea77db6392b4de9cd1f5d7a9e87a0ae9d01b41"},
 ]
 
 [package.extras]
diff --git a/security_scanning/metadata.json b/security_scanning/metadata.json
index 870ac677e5f1..65440783d506 100644
--- a/security_scanning/metadata.json
+++ b/security_scanning/metadata.json
@@ -1,4 +1,4 @@
 {
-    "commit_hash": "d0d12138a352f40cf420795c192f18ecfaea6a81",
-    "timestamp": "2026-03-27T02:47:11Z"
+    "commit_hash": "789494fcfe75d130a9c79cc781d9628426b51835",
+    "timestamp": "2026-03-28T02:47:37Z"
 }
diff --git a/security_scanning/triton_backend/poetry.lock b/security_scanning/triton_backend/poetry.lock
index b3e3bce651b8..c27d6573a5d6 100644
--- a/security_scanning/triton_backend/poetry.lock
+++ b/security_scanning/triton_backend/poetry.lock
@@ -803,14 +803,14 @@ files = [
 
 [[package]]
 name = "fsspec"
-version = "2026.2.0"
+version = "2026.3.0"
 description = "File-system specification"
 optional = false
 python-versions = ">=3.10"
 groups = ["main"]
 files = [
-    {file = "fsspec-2026.2.0-py3-none-any.whl", hash = "sha256:98de475b5cb3bd66bedd5c4679e87b4fdfe1a3bf4d707b151b3c07e58c9a2437"},
-    {file = "fsspec-2026.2.0.tar.gz", hash = "sha256:6544e34b16869f5aacd5b90bdf1a71acb37792ea3ddf6125ee69a22a53fb8bff"},
+    {file = "fsspec-2026.3.0-py3-none-any.whl", hash = "sha256:d2ceafaad1b3457968ed14efa28798162f1638dbb5d2a6868a2db002a5ee39a4"},
+    {file = "fsspec-2026.3.0.tar.gz", hash = "sha256:1ee6a0e28677557f8c2f994e3eea77db6392b4de9cd1f5d7a9e87a0ae9d01b41"},
 ]
 
 [package.extras]
@@ -1617,23 +1617,22 @@ files = [
 
 [[package]]
 name = "protobuf"
-version = "5.29.6"
+version = "6.33.6"
 description = ""
 optional = false
-python-versions = ">=3.8"
+python-versions = ">=3.9"
 groups = ["main"]
 files = [
-    {file = "protobuf-5.29.6-cp310-abi3-win32.whl", hash = "sha256:62e8a3114992c7c647bce37dcc93647575fc52d50e48de30c6fcb28a6a291eb1"},
-    {file = "protobuf-5.29.6-cp310-abi3-win_amd64.whl", hash = "sha256:7e6ad413275be172f67fdee0f43484b6de5a904cc1c3ea9804cb6fe2ff366eda"},
-    {file = "protobuf-5.29.6-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:b5a169e664b4057183a34bdc424540e86eea47560f3c123a0d64de4e137f9269"},
-    {file = "protobuf-5.29.6-cp38-abi3-manylinux2014_aarch64.whl", hash = "sha256:a8866b2cff111f0f863c1b3b9e7572dc7eaea23a7fae27f6fc613304046483e6"},
-    {file = "protobuf-5.29.6-cp38-abi3-manylinux2014_x86_64.whl", hash = "sha256:e3387f44798ac1106af0233c04fb8abf543772ff241169946f698b3a9a3d3ab9"},
-    {file = "protobuf-5.29.6-cp38-cp38-win32.whl", hash = "sha256:36ade6ff88212e91aef4e687a971a11d7d24d6948a66751abc1b3238648f5d05"},
-    {file = "protobuf-5.29.6-cp38-cp38-win_amd64.whl", hash = "sha256:831e2da16b6cc9d8f1654c041dd594eda43391affd3c03a91bea7f7f6da106d6"},
-    {file = "protobuf-5.29.6-cp39-cp39-win32.whl", hash = "sha256:cb4c86de9cd8a7f3a256b9744220d87b847371c6b2f10bde87768918ef33ba49"},
-    {file = "protobuf-5.29.6-cp39-cp39-win_amd64.whl", hash = "sha256:76e07e6567f8baf827137e8d5b8204b6c7b6488bbbff1bf0a72b383f77999c18"},
-    {file = "protobuf-5.29.6-py3-none-any.whl", hash = "sha256:6b9edb641441b2da9fa8f428760fc136a49cf97a52076010cf22a2ff73438a86"},
-    {file = "protobuf-5.29.6.tar.gz", hash = "sha256:da9ee6a5424b6b30fd5e45c5ea663aef540ca95f9ad99d1e887e819cdf9b8723"},
+    {file = "protobuf-6.33.6-cp310-abi3-win32.whl", hash = "sha256:7d29d9b65f8afef196f8334e80d6bc1d5d4adedb449971fefd3723824e6e77d3"},
+    {file = "protobuf-6.33.6-cp310-abi3-win_amd64.whl", hash = "sha256:0cd27b587afca21b7cfa59a74dcbd48a50f0a6400cfb59391340ad729d91d326"},
+    {file = "protobuf-6.33.6-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:9720e6961b251bde64edfdab7d500725a2af5280f3f4c87e57c0208376aa8c3a"},
+    {file = "protobuf-6.33.6-cp39-abi3-manylinux2014_aarch64.whl", hash = "sha256:e2afbae9b8e1825e3529f88d514754e094278bb95eadc0e199751cdd9a2e82a2"},
+    {file = "protobuf-6.33.6-cp39-abi3-manylinux2014_s390x.whl", hash = "sha256:c96c37eec15086b79762ed265d59ab204dabc53056e3443e702d2681f4b39ce3"},
+    {file = "protobuf-6.33.6-cp39-abi3-manylinux2014_x86_64.whl", hash = "sha256:e9db7e292e0ab79dd108d7f1a94fe31601ce1ee3f7b79e0692043423020b0593"},
+    {file = "protobuf-6.33.6-cp39-cp39-win32.whl", hash = "sha256:bd56799fb262994b2c2faa1799693c95cc2e22c62f56fb43af311cae45d26f0e"},
+    {file = "protobuf-6.33.6-cp39-cp39-win_amd64.whl", hash = "sha256:f443a394af5ed23672bc6c486be138628fbe5c651ccbc536873d7da23d1868cf"},
+    {file = "protobuf-6.33.6-py3-none-any.whl", hash = "sha256:77179e006c476e69bf8e8ce866640091ec42e1beb80b213c3900006ecfba6901"},
+    {file = "protobuf-6.33.6.tar.gz", hash = "sha256:a6768d25248312c297558af96a9f9c929e8c4cee0659cb07e780731095f38135"},
 ]
 
 [[package]]
@@ -2193,15 +2192,15 @@ vision = ["Pillow (>=10.0.1,<=15.0)"]
 
 [[package]]
 name = "tritonclient"
-version = "2.66.0"
+version = "2.67.0"
 description = "Python client library and utilities for communicating with Triton Inference Server"
 optional = false
 python-versions = "*"
 groups = ["main"]
 files = [
-    {file = "tritonclient-2.66.0-py3-none-any.whl", hash = "sha256:7e3558a47542528a4edd2da607050e756022c6b14ea2b352a67c78f1484e698f"},
-    {file = "tritonclient-2.66.0-py3-none-manylinux1_x86_64.whl", hash = "sha256:d2c2865239017a0d15a5f34e23f093feac222e608d6182a1444354204a0833ad"},
-    {file = "tritonclient-2.66.0-py3-none-manylinux2014_aarch64.whl", hash = "sha256:ec759cd8186130be9726be2fa31bc14b1dd46daa19add770610df889aa57d0a4"},
+    {file = "tritonclient-2.67.0-py3-none-any.whl", hash = "sha256:5e2d4f2f14dd79faa9110dff9b89a869d52b3e15a146c645850ec276e2d04568"},
+    {file = "tritonclient-2.67.0-py3-none-manylinux1_x86_64.whl", hash = "sha256:e7dcba1810083f852f1cffef4954949896ab7fb0405cf913862b512ae3c487f1"},
+    {file = "tritonclient-2.67.0-py3-none-manylinux2014_aarch64.whl", hash = "sha256:c23f057eb90472fe3051b26840a478595f9aea71ca8dda813bf4f4b208a69fed"},
 ]
 
 [package.dependencies]
@@ -2212,14 +2211,14 @@ grpcio = {version = ">=1.63.0,<1.68", optional = true, markers = "extra == \"all
 numpy = ">=1.19.1"
 packaging = {version = ">=14.1", optional = true, markers = "extra == \"all\""}
 perf-analyzer = {version = "*", optional = true, markers = "extra == \"all\""}
-protobuf = {version = ">=5.26.1,<6.0.dev0", optional = true, markers = "extra == \"all\""}
+protobuf = {version = ">=6.30.0,<7.0", optional = true, markers = "extra == \"all\""}
 python-rapidjson = ">=0.9.1"
 urllib3 = ">=2.0.7"
 
 [package.extras]
-all = ["aiohttp (>=3.8.1,<4.0.0)", "cuda-python", "geventhttpclient (>=2.3.3)", "grpcio (>=1.63.0,<1.68)", "numpy (>=1.19.1)", "packaging (>=14.1)", "perf-analyzer", "protobuf (>=5.26.1,<6.0.dev0)", "python-rapidjson (>=0.9.1)"]
+all = ["aiohttp (>=3.8.1,<4.0.0)", "cuda-python", "geventhttpclient (>=2.3.3)", "grpcio (>=1.63.0,<1.68)", "numpy (>=1.19.1)", "packaging (>=14.1)", "perf-analyzer", "protobuf (>=6.30.0,<7.0)", "python-rapidjson (>=0.9.1)"]
 cuda = ["cuda-python"]
-grpc = ["grpcio (>=1.63.0,<1.68)", "numpy (>=1.19.1)", "packaging (>=14.1)", "protobuf (>=5.26.1,<6.0.dev0)", "python-rapidjson (>=0.9.1)"]
+grpc = ["grpcio (>=1.63.0,<1.68)", "numpy (>=1.19.1)", "packaging (>=14.1)", "protobuf (>=6.30.0,<7.0)", "python-rapidjson (>=0.9.1)"]
 http = ["aiohttp (>=3.8.1,<4.0.0)", "geventhttpclient (>=2.3.3)", "numpy (>=1.19.1)", "python-rapidjson (>=0.9.1)"]
 perf-analyzer = ["perf-analyzer"]
 
@@ -2461,4 +2460,4 @@ testing = ["coverage[toml]", "zope.event", "zope.testing"]
 [metadata]
 lock-version = "2.1"
 python-versions = ">=3.10,<3.13"
-content-hash = "cf300be248b685ba0d45a682bf08209d1793afd81ed57faac2357a8adcc4d304"
+content-hash = "9fbbe64e936222cbe44fe0cdd487be9e9c009a8cd388911b59df10c801ca8378"
diff --git a/security_scanning/triton_backend/pyproject.toml b/security_scanning/triton_backend/pyproject.toml
index 924b626aacfa..8da1d1aedcdc 100644
--- a/security_scanning/triton_backend/pyproject.toml
+++ b/security_scanning/triton_backend/pyproject.toml
@@ -9,7 +9,7 @@ requires-python = ">=3.10,<3.13"
 dependencies = [
     "regex (>=2026.2.28,<2027.0.0)",
     "fire (>=0.7.1,<0.8.0)",
-    "tritonclient[all] (>=2.66.0,<3.0.0)",
+    "tritonclient[all] (>=2.67.0,<3.0.0)",
     "transformers (==4.57.3)",
     "tabulate (>=0.10.0,<0.11.0)",
     "torchao (>=0.14.1)"
diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py
index 9f595a6fb95d..112b7437df0f 100644
--- a/tensorrt_llm/_torch/pyexecutor/_util.py
+++ b/tensorrt_llm/_torch/pyexecutor/_util.py
@@ -1335,6 +1335,7 @@ def create_py_executor_instance(
     waiting_queue_policy = (scheduler_config.waiting_queue_policy
                             if scheduler_config is not None else
                             WaitingQueuePolicy.FCFS)
+
     return PyExecutor(
         resource_manager,
         scheduler,
diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py
index e5dd69c6ff84..130769730689 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -67,7 +67,7 @@
 from .scheduler import (RequestScheduler, ScheduledRequests,
                         SerializableSchedulerOutput, WaitingQueue,
                         create_waiting_queue)
-from .scheduler.adp_router import ADPRouter, DefaultADPRouter
+from .scheduler.adp_router import ADPRouter
 
 # Environment variable to specify iteration ranges for profiling start/stop.
 # Format: "start1-stop1,start2-stop2,..." or single iterations "iter1,iter2,..."
@@ -285,8 +285,7 @@ def __init__(
         virtual_memory_pools: Optional[dict] = None,
         hang_detection_timeout: Optional[int] = None,
         execution_stream: Optional[torch.cuda.Stream] = None,
-        waiting_queue_policy: WaitingQueuePolicy = WaitingQueuePolicy.FCFS,
-        adp_router: Optional[ADPRouter] = None):
+        waiting_queue_policy: WaitingQueuePolicy = WaitingQueuePolicy.FCFS):
        super(PyExecutor, self).__init__()
        self.device_id = torch.cuda.current_device()
        self.global_rank = dist.rank
@@ -313,7 +312,6 @@ def __init__(
        self.model_engine = model_engine
        self.enable_attention_dp = model_engine.enable_attention_dp
        self.dist = dist
-       self.adp_router: ADPRouter = (adp_router or DefaultADPRouter(dist=dist))
        self.sampler = sampler
        self.drafter = drafter
        self.draft_model_engine = getattr(self.drafter, "draft_model_engine",
@@ -387,6 +385,12 @@
            self.enable_kv_cache_reuse
            and self.kv_cache_manager.enable_partial_reuse)
 
+        self.adp_router: ADPRouter = ADPRouter.create(
+            dist=self.dist,
+            kv_cache_manager=self.kv_cache_manager,
+            attention_dp_config=self.llm_args.attention_dp_config,
+        )
+
        self.max_input_len = max_input_len
        # _executor_loop private data
        self.max_num_active_requests = model_engine.get_max_num_sequences()
@@ -2583,6 +2587,9 @@ def _fetch_new_requests(
 
        # 6. Schedule requests across ranks (DP only)
        if self.enable_attention_dp:
+            if self.adp_router.needs_prefix_matches:
+                self.adp_router.gather_prefix_matches(new_requests)
+
            all_ranks_new_requests, self.expected_num_active_requests = \
                self.adp_router.route_requests(
                    all_rank_states, new_requests,
diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py
index 0fefec42a565..43a7503655b6 100644
--- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py
+++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py
@@ -559,6 +559,37 @@ def append_to_kv_heads_per_layer(num_kv_heads_per_layer: List[int],
                                  pin_memory=prefer_pinned(),
                                  device='cpu')
 
+    def probe_prefix_match_length(self, input_tokens, lora_task_id=None):
+        """Probe the KV cache radix tree for the prefix match length.
+
+        Returns the number of prefix tokens already cached on this rank.
+        Used by KVCacheAwareADPRouter for cache-aware routing.
+        """
+        if not self.enable_block_reuse:
+            return 0
+        # is_variable_window is only defined on the concrete KVCacheManager
+        # nanobind class, not on BaseKVCacheManager. Use getattr to avoid
+        # AttributeError on other subclasses or mocks.
+        if getattr(self.impl, 'is_variable_window', False):
+            return 0
+        if not input_tokens:
+            return 0
+        from tensorrt_llm.bindings import SamplingConfig
+        from tensorrt_llm.bindings.internal.batch_manager import BlockKey
+        from tensorrt_llm.bindings.internal.batch_manager import \
+            LlmRequest as CppLlmRequest
+        block_key = BlockKey(tokens=input_tokens, lora_task_id=lora_task_id)
+        unique_tokens = block_key.unique_tokens
+        dummy_req = CppLlmRequest(request_id=0,
+                                  max_new_tokens=0,
+                                  input_tokens=input_tokens,
+                                  sampling_config=SamplingConfig(),
+                                  is_streaming=False,
+                                  lora_task_id=lora_task_id)
+        num_blocks = self.impl.count_reusable_blocks(unique_tokens, dummy_req,
+                                                     False)
+        return num_blocks * self.tokens_per_block
+
     def shutdown(self):
        self.impl.release_pools()
 
diff --git a/tensorrt_llm/_torch/pyexecutor/scheduler/__init__.py b/tensorrt_llm/_torch/pyexecutor/scheduler/__init__.py
index 13638db19658..0c549f89f92b 100644
--- a/tensorrt_llm/_torch/pyexecutor/scheduler/__init__.py
+++ b/tensorrt_llm/_torch/pyexecutor/scheduler/__init__.py
@@ -21,7 +21,7 @@
 """
 
 # Re-export from scheduler.py
-from .adp_router import ADPRouter, DefaultADPRouter, RankState
+from .adp_router import ADPRouter, DefaultADPRouter, KVCacheAwareADPRouter, RankState
 from .scheduler import (
     BindCapacityScheduler,
     BindMicroBatchScheduler,
@@ -66,6 +66,7 @@
     # ADP
     "ADPRouter",
     "DefaultADPRouter",
+    "KVCacheAwareADPRouter",
     "RankState",
     # Waiting queues
     "FCFSWaitingQueue",
diff --git a/tensorrt_llm/_torch/pyexecutor/scheduler/adp_router.py b/tensorrt_llm/_torch/pyexecutor/scheduler/adp_router.py
index 344247ee18ca..bec60692173f 100644
--- a/tensorrt_llm/_torch/pyexecutor/scheduler/adp_router.py
+++ b/tensorrt_llm/_torch/pyexecutor/scheduler/adp_router.py
@@ -8,6 +8,11 @@
 1. Each rank builds its local RankState
 2. All ranks exchange RankState via allgather
 3. ADPRouter.route_requests() distributes new requests
+
+Includes:
+  - DefaultADPRouter: load-balanced min-heap routing
+  - KVCacheAwareADPRouter: cache-aware routing that factors in
+    prefix match length from the KV cache radix tree
 """
 
 from __future__ import annotations
@@ -53,14 +58,55 @@ def deserialize(cls, data: list[int]) -> RankState:
 class ADPRouter(ABC):
     """Abstract interface for distributing new requests across ADP ranks.
 
+    This is an **instance-level** router: it distributes requests across the
+    DP ranks within a single instance (e.g., one mpirun job controlling 8 GPUs
+    that together form a single logical server).
+
+    In disaggregated serving architectures, a separate higher-level router
+    orchestrates traffic across multiple instances (e.g., routing between
+    prefill and decode instances). That cross-instance routing is outside
+    the scope of this class.
+
     Interface:
         Input: list[RankState], list[Request]
         Output: dict[rank, list[Request]]
     """
 
+    needs_prefix_matches: bool = False
+
     def __init__(self, dist: Distributed):
         self.dist = dist
 
+    @classmethod
+    def create(
+        cls, dist: "Distributed", kv_cache_manager=None, attention_dp_config=None
+    ) -> "ADPRouter":
+        """Factory method to create the appropriate ADP router.
+
+        Args:
+            dist: Distributed communicator.
+            kv_cache_manager: KV cache manager instance (may be None).
+            attention_dp_config: AttentionDpConfig instance (may be None).
+
+        Returns:
+            A KVCacheAwareADPRouter if the config requests it and the
+            kv_cache_manager has block reuse enabled; a DefaultADPRouter
+            otherwise.
+        """
+        if (
+            attention_dp_config is not None
+            and attention_dp_config.enable_kv_cache_aware_routing
+            and kv_cache_manager is not None
+            and kv_cache_manager.enable_block_reuse
+        ):
+            return KVCacheAwareADPRouter(
+                dist=dist,
+                kv_cache_manager=kv_cache_manager,
+                load_balance_weight=attention_dp_config.kv_cache_routing_load_balance_weight,
+            )
+
+        return DefaultADPRouter(dist=dist)
+
     @abstractmethod
     def create_rank_state(
         self,
@@ -272,3 +318,226 @@ def _balance_requests_across_ranks(
             all_ranks_new_requests[rank].extend(reqs)
 
         return all_ranks_new_requests
+
+
+class KVCacheAwareADPRouter(ADPRouter):
+    """KV cache-aware request router for attention data parallelism.
+
+    Routes requests considering both load balance and KV cache prefix match
+    length on each rank. When a request's prefix is already cached on a rank,
+    that rank is preferred to avoid redundant prefill computation.
+
+    Scoring: score(rank, request) = effective_tokens + β * normalized_load
+    where:
+        effective_tokens = input_tokens - prefix_match_length
+        normalized_load = rank_active_tokens / max(total_active_tokens, req_tokens) * req_tokens
+    Lower score = better rank.
+
+    The load term is normalized by the total active tokens across eligible
+    ranks (floored at req_tokens) so that both terms remain on the same
+    scale regardless of absolute load levels.
+
+    Requires a KV cache manager with enable_block_reuse=True.
+    Falls back to load-based routing when no cache hits exist.
+    """
+
+    needs_prefix_matches: bool = True
+
+    def __init__(self, dist: "Distributed", kv_cache_manager, load_balance_weight: float = 1.0):
+        super().__init__(dist)
+        self.kv_cache_manager = kv_cache_manager
+        self.load_balance_weight = load_balance_weight
+        self._all_ranks_prefix_matches: List[Dict[int, int]] = []
+
+    def create_rank_state(
+        self,
+        active_requests: list[LlmRequest],
+        new_requests: list[RequestQueueItem],
+    ) -> RankState:
+        if self.dist.has_cp_helix:
+            num_active_tokens = sum(req.total_input_len_cp for req in active_requests)
+        else:
+            num_active_tokens = sum(req.py_orig_prompt_len for req in active_requests)
+        return RankState(
+            rank=self.dist.tp_rank,
+            num_active_requests=len(active_requests),
+            num_active_tokens=num_active_tokens,
+        )
+
+    def gather_prefix_matches(
+        self,
+        new_requests: list[RequestQueueItem],
+    ) -> None:
+        """Probe the local radix tree for each new request, then allgather across ranks.
+
+        Populates self._all_ranks_prefix_matches for use by route_requests.
+        Must be called after new_requests are available and before route_requests.
+        """
+        local_matches: list[int] = []
+        for req_item in new_requests:
+            req = req_item.request
+            if req is None:
+                local_matches.extend([req_item.id, 0])
+                continue
+            input_tokens = getattr(req, "input_token_ids", None) or []
+            probe_tokens = input_tokens[:-1] if len(input_tokens) > 1 else []
+            lora_config = getattr(req, "lora_config", None)
+            lora_task_id = lora_config.task_id if lora_config is not None else None
+            match_len = self.kv_cache_manager.probe_prefix_match_length(probe_tokens, lora_task_id)
+            local_matches.extend([req_item.id, match_len])
+
+        all_data = self.dist.tp_allgather(local_matches)
+
+        self._all_ranks_prefix_matches = []
+        for rank_data in all_data:
+            matches: Dict[int, int] = {}
+            for i in range(0, len(rank_data), 2):
+                req_id = rank_data[i]
+                matches[req_id] = rank_data[i + 1]
+            self._all_ranks_prefix_matches.append(matches)
+
+    def _score_rank(
+        self,
+        req_tokens: int,
+        match_len: int,
+        rank_active_tokens: float,
+        load_denom: float,
+    ) -> float:
+        """Score a candidate rank for a request (lower is better).
+
+        Args:
+            req_tokens: Total input tokens of the request.
+            match_len: Prefix match length on this rank's radix tree.
+            rank_active_tokens: Active tokens currently on this rank.
+            load_denom: Normalization denominator for the load term.
+
+        Returns:
+            Score combining cache miss cost and load penalty.
+        """
+        effective = req_tokens - match_len
+        normalized_load = rank_active_tokens / load_denom * req_tokens
+        return effective + self.load_balance_weight * normalized_load
+
+    @staticmethod
+    def _prefix_fingerprint(token_ids, num_tokens: int = 64) -> tuple:
+        """Return a hashable fingerprint from the first num_tokens tokens.
+
+        Requests sharing the same fingerprint likely belong to the same
+        conversation / prefix group and benefit from being routed to the
+        same rank.
+        """
+        if not token_ids:
+            return ()
+        return tuple(token_ids[:num_tokens])
+
+    def route_requests(
+        self,
+        all_rank_states: list[RankState],
+        new_requests: list[RequestQueueItem],
+        max_num_active_requests: int,
+    ) -> Tuple[Dict[int, List[RequestQueueItem]], int]:
+        tp_size = len(all_rank_states)
+        all_ranks_new_requests: Dict[int, List[RequestQueueItem]] = {
+            s.rank: [] for s in all_rank_states
+        }
+        all_ranks_num_active_requests = [s.num_active_requests for s in all_rank_states]
+        all_ranks_num_active_tokens = [float(s.num_active_tokens) for s in all_rank_states]
+
+        def get_relax_value(req_item):
+            scheduling_params = getattr(req_item.request, "py_scheduling_params", None)
+            if scheduling_params is None:
+                return True
+            return scheduling_params.attention_dp_relax
+
+        sorted_requests = sorted(new_requests, key=get_relax_value)
+
+        remaining_unscheduled = []
+        for req_item in sorted_requests:
+            scheduled = False
+            scheduling_params = getattr(req_item.request, "py_scheduling_params", None)
+            if scheduling_params is not None:
+                target_dp_rank = scheduling_params.attention_dp_rank
+                if (
+                    target_dp_rank is not None
+                    and all_ranks_num_active_requests[target_dp_rank] < max_num_active_requests
+                ):
+                    all_ranks_num_active_requests[target_dp_rank] += 1
+                    scheduled = True
+                    all_ranks_new_requests[target_dp_rank].append(req_item)
+
+            if not scheduled:
+                remaining_unscheduled.append(req_item)
+
+        num_new_requests_all_ranks = len(remaining_unscheduled)
+        total_num_active_requests = sum(all_ranks_num_active_requests)
+        expected_num_active_requests = max(
+            (total_num_active_requests + num_new_requests_all_ranks + tp_size - 1) // tp_size,
+            max(all_ranks_num_active_requests),
+        )
+
+        # --- Prefix-affinity sorting ---
+        # Sort by prefix fingerprint first (grouping related requests),
+        # then by ISL descending within each group, so that when request A
+        # (conv X, turn 5) is routed to rank R, request B (conv X, turn 3)
+        # is scored immediately after, while the load tracker still favours rank R.
+        def _sort_key(req_item):
+            tokens = getattr(req_item.request, "input_token_ids", []) if req_item.request else []
+            fp = self._prefix_fingerprint(tokens)
+            # Negate length so longer requests come first within each group
+            return (fp, -len(tokens))
+
+        remaining_unscheduled = sorted(remaining_unscheduled, key=_sort_key)
+
+        eligible_ranks = [
+            rank
+            for rank in range(tp_size)
+            if all_ranks_num_active_requests[rank] < expected_num_active_requests
+        ]
+
+        prefix_matches = self._all_ranks_prefix_matches
+
+        for req_item in remaining_unscheduled:
+            if not eligible_ranks:
+                break
+
+            req_tokens = (
+                len(getattr(req_item.request, "input_token_ids", [])) if req_item.request else 0
+            )
+            req_id = req_item.id
+
+            best_rank = eligible_ranks[0]
+            best_score = float("inf")
+
+            # --- Normalize load term ---
+            # Normalize each rank's active_tokens by the total load across all
+            # eligible ranks (floored to req_tokens to avoid dividing by ~0).
+            # This makes the load term scale-invariant: a rank carrying 2x the
+            # average load gets a normalized penalty of ~req_tokens regardless
+            # of whether total load is 100 tokens or 100,000 tokens.
+            # Using load_range (max-min) instead caused the penalty to blow up
+            # when all ranks were near-idle (small range, large relative fraction).
+            total_load = sum(all_ranks_num_active_tokens[r] for r in eligible_ranks)
+            load_denom = max(total_load, float(req_tokens))
+
+            for rank in eligible_ranks:
+                match_len = prefix_matches[rank].get(req_id, 0) if rank < len(prefix_matches) else 0
+                score = self._score_rank(
+                    req_tokens, match_len, all_ranks_num_active_tokens[rank], load_denom
+                )
+                if score < best_score:
+                    best_score = score
+                    best_rank = rank
+
+            all_ranks_new_requests[best_rank].append(req_item)
+            all_ranks_num_active_requests[best_rank] += 1
+
+            match_len = (
+                prefix_matches[best_rank].get(req_id, 0) if best_rank < len(prefix_matches) else 0
+            )
+            effective_added = req_tokens - match_len
+            all_ranks_num_active_tokens[best_rank] += effective_added
+
+            if all_ranks_num_active_requests[best_rank] >= expected_num_active_requests:
+                eligible_ranks.remove(best_rank)
+
+        return all_ranks_new_requests, expected_num_active_requests
diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py
index db17c6d8e68e..9b6b525d00e1 100644
--- a/tensorrt_llm/llmapi/llm_args.py
+++ b/tensorrt_llm/llmapi/llm_args.py
@@ -540,6 +540,18 @@ class AttentionDpConfig(StrictBaseModel):
     batching_wait_iters: int = Field(
         default=10,
         description="The number of iterations to wait for batching.")
+    enable_kv_cache_aware_routing: bool = Field(
+        default=False,
+        description="Enable internal KV cache-aware routing for attention DP. "
+        "When enabled, requests are distributed among the ranks of one "
+        "instance's attention DP group, preferring the rank whose KV cache "
+        "holds the longest matching prefix to avoid redundant prefill.")
+    kv_cache_routing_load_balance_weight: float = Field(
+        default=1.0,
+        description=
+        "Weight (beta) for the load-balance term in KV cache-aware routing. "
+        "Higher values prioritize load balance over cache affinity. "
+        "Only used when enable_kv_cache_aware_routing is True.")
 
     @model_validator(mode='after')
     def validate_attention_dp_config(self) -> 'AttentionDpConfig':
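Note: the two `AttentionDpConfig` fields above are the entire user-facing surface of this feature. A minimal sketch of how it would be enabled through the LLM API, mirroring the integration test below (the model path and parallel sizes here are placeholders, not part of this change):

```python
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import AttentionDpConfig, KvCacheConfig

# ADPRouter.create() selects KVCacheAwareADPRouter only when
# enable_kv_cache_aware_routing is set and the KV cache manager has block
# reuse enabled; otherwise it falls back to DefaultADPRouter. The router
# itself is only consulted when enable_attention_dp is on.
llm = LLM(
    "/path/to/model",  # placeholder
    tensor_parallel_size=4,
    enable_attention_dp=True,
    kv_cache_config=KvCacheConfig(enable_block_reuse=True),
    attention_dp_config=AttentionDpConfig(
        enable_kv_cache_aware_routing=True,
        # beta in the router's score; larger values favour load balance
        kv_cache_routing_load_balance_weight=1.0,
    ),
)
```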
" + "Only used when enable_kv_cache_aware_routing is True.") @model_validator(mode='after') def validate_attention_dp_config(self) -> 'AttentionDpConfig': diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index fc272ee5cfaf..2f18a4a4051b 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -54,11 +54,11 @@ def patched_start_mpi_pool(self): # isort: off from tensorrt_llm.llmapi import ( - AutoDecodingConfig, CudaGraphConfig, DeepSeekSparseAttentionConfig, - Eagle3DecodingConfig, KvCacheConfig, MoeConfig, MTPDecodingConfig, - NGramDecodingConfig, PARDDecodingConfig, RocketSparseAttentionConfig, - SADecodingConfig, SamplingParams, SchedulerConfig, - SkipSoftmaxAttentionConfig, TorchCompileConfig) + AttentionDpConfig, AutoDecodingConfig, CudaGraphConfig, + DeepSeekSparseAttentionConfig, Eagle3DecodingConfig, KvCacheConfig, + MoeConfig, MTPDecodingConfig, NGramDecodingConfig, PARDDecodingConfig, + RocketSparseAttentionConfig, SADecodingConfig, SamplingParams, + SchedulerConfig, SkipSoftmaxAttentionConfig, TorchCompileConfig) # isort: on from tensorrt_llm.quantization import QuantAlgo @@ -1693,6 +1693,32 @@ def test_bfloat16_mtp_sa(self): task = GSM8K(self.MODEL_NAME) task.evaluate(llm, extra_acc_spec="use_sa_spec") + @pytest.mark.skip_less_device(4) + @parametrize_with_ids("mtp_nextn", [0, 2]) + def test_bfloat16_4gpus_kv_cache_aware_routing(self, mtp_nextn): + """Accuracy test for attention DP with KV cache-aware routing.""" + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75, + enable_block_reuse=True) + pytorch_config = dict( + disable_overlap_scheduler=False, + cuda_graph_config=CudaGraphConfig(), + ) + mtp_config = None + if mtp_nextn > 0: + mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn) + attention_dp_config = AttentionDpConfig( + enable_kv_cache_aware_routing=True, ) + with LLM(self.MODEL_PATH, + tensor_parallel_size=4, + moe_expert_parallel_size=4, + kv_cache_config=kv_cache_config, + **pytorch_config, + enable_attention_dp=True, + attention_dp_config=attention_dp_config, + speculative_config=mtp_config) as llm: + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm) + @pytest.mark.skip_less_device(4) @parametrize_with_ids("torch_compile", [False, True]) @parametrize_with_ids("attention_dp,cuda_graph,overlap_scheduler", diff --git a/tests/integration/test_lists/qa/llm_function_core.txt b/tests/integration/test_lists/qa/llm_function_core.txt index ec19b2068832..370800c2a93f 100644 --- a/tests/integration/test_lists/qa/llm_function_core.txt +++ b/tests/integration/test_lists/qa/llm_function_core.txt @@ -110,6 +110,8 @@ accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_python accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_python_scheduler[tp4-mtp_nextn=2] accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_python_scheduler[ep4-mtp_nextn=0] accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_python_scheduler[ep4-mtp_nextn=2] +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_kv_cache_aware_routing[mtp_nextn=0] +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_kv_cache_aware_routing[mtp_nextn=2] accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus_static_eplb 
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=0-moe_backend=WIDEEP] accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=2-moe_backend=WIDEEP] diff --git a/tests/integration/test_lists/test-db/l0_dgx_b200.yml b/tests/integration/test_lists/test-db/l0_dgx_b200.yml index de8f4e963d1f..33e2b055655a 100644 --- a/tests/integration/test_lists/test-db/l0_dgx_b200.yml +++ b/tests/integration/test_lists/test-db/l0_dgx_b200.yml @@ -212,6 +212,8 @@ l0_dgx_b200: - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[pp4-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=True] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[pp4-mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[pp4-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_kv_cache_aware_routing[mtp_nextn=0] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_kv_cache_aware_routing[mtp_nextn=2] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-tp4-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-low_precision_combine=False-torch_compile=True] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-tp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-low_precision_combine=False-torch_compile=True] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-tp2pp2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-low_precision_combine=False-torch_compile=True] diff --git a/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml b/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml index 97da746c27fd..0cf908feb485 100644 --- a/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml +++ b/tests/integration/test_lists/test-db/l0_gb200_multi_gpus.yml @@ -25,6 +25,8 @@ l0_gb200_multi_gpus: - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[ep4-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp2pp2-mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=True] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus[tp2pp2-mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_kv_cache_aware_routing[mtp_nextn=0] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_kv_cache_aware_routing[mtp_nextn=2] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_python_scheduler[tp4-mtp_nextn=0] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_python_scheduler[tp4-mtp_nextn=2] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_python_scheduler[ep4-mtp_nextn=0] diff --git a/tests/unittest/_torch/executor/test_kvcache_aware_router.py b/tests/unittest/_torch/executor/test_kvcache_aware_router.py new file mode 
diff --git a/tests/unittest/_torch/executor/test_kvcache_aware_router.py b/tests/unittest/_torch/executor/test_kvcache_aware_router.py
new file mode 100644
index 000000000000..34518faa7549
--- /dev/null
+++ b/tests/unittest/_torch/executor/test_kvcache_aware_router.py
@@ -0,0 +1,439 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for the KV cache-aware ADP router and probe_prefix_match_length.
+
+These tests use mock objects and do NOT require a GPU.
+"""
+
+from unittest.mock import MagicMock, Mock
+
+from tensorrt_llm._torch.pyexecutor.scheduler.adp_router import (
+    ADPRouter,
+    KVCacheAwareADPRouter,
+    RankState,
+)
+
+# ---- Helpers ----
+
+
+def _mock_dist(tp_rank=0, tp_size=1, has_cp_helix=False):
+    """Create a mock Distributed object for testing."""
+    dist = MagicMock()
+    dist.tp_rank = tp_rank
+    dist.tp_size = tp_size
+    dist.has_cp_helix = has_cp_helix
+    return dist
+
+
+def _make_request_item(
+    req_id, num_tokens=10, target_dp_rank=None, attention_dp_relax=True, lora_task_id=None
+):
+    """Create a mock RequestQueueItem for testing."""
+    item = MagicMock()
+    item.id = req_id
+    item.child_req_ids = None
+    scheduling_params = MagicMock()
+    scheduling_params.attention_dp_rank = target_dp_rank
+    scheduling_params.attention_dp_relax = attention_dp_relax
+    item.request = MagicMock()
+    item.request.py_scheduling_params = scheduling_params
+    item.request.input_token_ids = list(range(num_tokens))
+    if lora_task_id is not None:
+        lora_config = MagicMock()
+        lora_config.task_id = lora_task_id
+        item.request.lora_config = lora_config
+    else:
+        item.request.lora_config = None
+    return item
+
+
+def _mock_kv_cache_manager(probe_results=None):
+    """Create a mock KV cache manager with configurable probe results.
+
+    Args:
+        probe_results: dict mapping (tuple(input_tokens), lora_task_id) -> match_length.
+            If None, all probes return 0.
+    """
+    mgr = MagicMock()
+    probe_results = probe_results or {}
+
+    def mock_probe(input_tokens, lora_task_id=None):
+        key = (tuple(input_tokens), lora_task_id)
+        return probe_results.get(key, 0)
+
+    mgr.probe_prefix_match_length = Mock(side_effect=mock_probe)
+    return mgr
+
+
+# ---- Tests for KVCacheAwareADPRouter ----
+
+
+class TestKVCacheAwareADPRouter:
+    """Tests for the KV cache-aware ADP router."""
+
+    def test_is_adp_router(self):
+        dist = _mock_dist()
+        mgr = _mock_kv_cache_manager()
+        router = KVCacheAwareADPRouter(dist=dist, kv_cache_manager=mgr)
+        assert isinstance(router, ADPRouter)
+
+    def test_create_rank_state(self):
+        dist = _mock_dist(tp_rank=0)
+        mgr = _mock_kv_cache_manager()
+        router = KVCacheAwareADPRouter(dist=dist, kv_cache_manager=mgr)
+
+        req1 = Mock(py_orig_prompt_len=100)
+        req2 = Mock(py_orig_prompt_len=200)
+        state = router.create_rank_state([req1, req2], [])
+        assert state.rank == 0
+        assert state.num_active_requests == 2
+        assert state.num_active_tokens == 300
+
+    def test_create_rank_state_cp_helix(self):
+        dist = _mock_dist(tp_rank=1, has_cp_helix=True)
+        mgr = _mock_kv_cache_manager()
+        router = KVCacheAwareADPRouter(dist=dist, kv_cache_manager=mgr)
+
+        req1 = Mock(total_input_len_cp=150)
+        state = router.create_rank_state([req1], [])
+        assert state.rank == 1
+        assert state.num_active_tokens == 150
+
+    # -- gather_prefix_matches tests --
+
+    def test_gather_prefix_matches_single_rank(self):
+        tokens_a = list(range(100))
+        tokens_b = list(range(50))
+        probe_results = {
+            (tuple(tokens_a[:-1]), None): 64,
+            (tuple(tokens_b[:-1]), None): 0,
+        }
+
+        dist = _mock_dist(tp_rank=0, tp_size=1)
+        dist.tp_allgather = Mock(side_effect=lambda x: [x])
+        mgr = _mock_kv_cache_manager(probe_results)
+        router = KVCacheAwareADPRouter(dist=dist, kv_cache_manager=mgr)
+
+        req_a = _make_request_item(1, num_tokens=100)
+        req_b = _make_request_item(2, num_tokens=50)
+
+        router.gather_prefix_matches([req_a, req_b])
+        assert len(router._all_ranks_prefix_matches) == 1
+        assert router._all_ranks_prefix_matches[0] == {1: 64, 2: 0}
+
+    def test_gather_prefix_matches_two_ranks(self):
+        tokens_a = list(range(100))
+        probe_results = {
+            (tuple(tokens_a[:-1]), None): 64,
+        }
+
+        dist = _mock_dist(tp_rank=0, tp_size=2)
+        # Simulate allgather: rank 0 sends [1, 64, 2, 0], rank 1 sends [1, 32, 2, 0]
+        dist.tp_allgather = Mock(return_value=[[1, 64, 2, 0], [1, 32, 2, 0]])
+        mgr = _mock_kv_cache_manager(probe_results)
+        router = KVCacheAwareADPRouter(dist=dist, kv_cache_manager=mgr)
+
+        req_a = _make_request_item(1, num_tokens=100)
+        req_b = _make_request_item(2, num_tokens=50)
+
+        router.gather_prefix_matches([req_a, req_b])
+        assert len(router._all_ranks_prefix_matches) == 2
+        assert router._all_ranks_prefix_matches[0] == {1: 64, 2: 0}
+        assert router._all_ranks_prefix_matches[1] == {1: 32, 2: 0}
+
+    def test_gather_prefix_matches_with_lora(self):
+        tokens_a = list(range(100))
+        probe_results = {
+            (tuple(tokens_a[:-1]), 42): 64,
+        }
+
+        dist = _mock_dist(tp_rank=0, tp_size=1)
+        dist.tp_allgather = Mock(side_effect=lambda x: [x])
+        mgr = _mock_kv_cache_manager(probe_results)
+        router = KVCacheAwareADPRouter(dist=dist, kv_cache_manager=mgr)
+
+        req_a = _make_request_item(1, num_tokens=100, lora_task_id=42)
+        router.gather_prefix_matches([req_a])
+        assert router._all_ranks_prefix_matches[0] == {1: 64}
+
+    def test_gather_prefix_matches_empty(self):
+        dist = _mock_dist(tp_rank=0, tp_size=1)
+        dist.tp_allgather = Mock(side_effect=lambda x: [x])
+        mgr = _mock_kv_cache_manager()
+        router = KVCacheAwareADPRouter(dist=dist, kv_cache_manager=mgr)
+
+        router.gather_prefix_matches([])
+        assert len(router._all_ranks_prefix_matches) == 1
+        assert router._all_ranks_prefix_matches[0] == {}
+
+    # -- route_requests tests --
+
+    def test_route_prefers_cached_rank(self):
+        """Request with cache hit on rank 0 → routed to rank 0."""
+        dist = _mock_dist(tp_rank=0, tp_size=2)
+        mgr = _mock_kv_cache_manager()
+        router = KVCacheAwareADPRouter(dist=dist, kv_cache_manager=mgr)
+
+        router._all_ranks_prefix_matches = [
+            {1: 80},  # rank 0: 80 tokens cached
+            {1: 0},  # rank 1: no cache
+        ]
+
+        states = [
+            RankState(rank=0, num_active_requests=0, num_active_tokens=0),
+            RankState(rank=1, num_active_requests=0, num_active_tokens=0),
+        ]
+        req = _make_request_item(1, num_tokens=100)
+
+        result, _ = router.route_requests(states, [req], max_num_active_requests=10)
+        assert len(result[0]) == 1  # routed to rank 0
+        assert len(result[1]) == 0
+
+    def test_route_degenerates_to_load_balance_no_cache(self):
+        """No cache hits → routes to least loaded rank."""
+        dist = _mock_dist(tp_rank=0, tp_size=2)
+        mgr = _mock_kv_cache_manager()
+        router = KVCacheAwareADPRouter(dist=dist, kv_cache_manager=mgr)
+
+        router._all_ranks_prefix_matches = [{1: 0}, {1: 0}]
+
+        states = [
+            RankState(rank=0, num_active_requests=5, num_active_tokens=500),
+            RankState(rank=1, num_active_requests=1, num_active_tokens=100),
+        ]
+        req = _make_request_item(1, num_tokens=100)
+
+        result, _ = router.route_requests(states, [req], max_num_active_requests=10)
+        assert len(result[1]) == 1  # rank 1 is less loaded
+        assert len(result[0]) == 0
+
+    def test_route_load_overcomes_cache(self):
+        """Heavy load on cached rank → routes to idle rank despite no cache."""
+        dist = _mock_dist(tp_rank=0, tp_size=2)
+        mgr = _mock_kv_cache_manager()
+        router = KVCacheAwareADPRouter(dist=dist, kv_cache_manager=mgr, load_balance_weight=10.0)
+
+        router._all_ranks_prefix_matches = [{1: 80}, {1: 0}]
+
+        states = [
+            RankState(rank=0, num_active_requests=5, num_active_tokens=5000),
+            RankState(rank=1, num_active_requests=0, num_active_tokens=0),
+        ]
+        req = _make_request_item(1, num_tokens=100)
+
+        result, _ = router.route_requests(states, [req], max_num_active_requests=10)
+        # total_load=5000, load_denom=max(5000, 100)=5000
+        # score(rank0) = (100-80) + 10 * (5000/5000 * 100) = 20 + 1000 = 1020
+        # score(rank1) = (100-0) + 10 * (0/5000 * 100) = 100 + 0 = 100
+        assert len(result[1]) == 1
+        assert len(result[0]) == 0
+
+    def test_route_respects_explicit_dp_rank(self):
+        """Non-relaxed request with explicit dp_rank → honored."""
+        dist = _mock_dist(tp_rank=0, tp_size=2)
+        mgr = _mock_kv_cache_manager()
+        router = KVCacheAwareADPRouter(dist=dist, kv_cache_manager=mgr)
+
+        router._all_ranks_prefix_matches = [{1: 80}, {1: 0}]
+
+        states = [
+            RankState(rank=0, num_active_requests=0, num_active_tokens=0),
+            RankState(rank=1, num_active_requests=0, num_active_tokens=0),
+        ]
+        req = _make_request_item(1, num_tokens=100, target_dp_rank=1, attention_dp_relax=False)
+
+        result, _ = router.route_requests(states, [req], max_num_active_requests=10)
+        assert len(result[1]) == 1  # forced to rank 1
+        assert len(result[0]) == 0
+
+    def test_route_multiple_requests_effective_token_update(self):
+        """After routing, rank load increases by effective (not full) tokens.
+
+        Uses 4 requests so expected_num_active_requests=2, allowing 2 cached
+        requests to land on rank 0 before the capacity ceiling kicks in.
+        If load were updated with full tokens (100) instead of effective (20),
+        the 2nd cached request would score 120 on rank 0 vs 100 on rank 1
+        and go to rank 1 instead.
+        """
+        dist = _mock_dist(tp_rank=0, tp_size=2)
+        mgr = _mock_kv_cache_manager()
+        router = KVCacheAwareADPRouter(dist=dist, kv_cache_manager=mgr)
+
+        # Requests 1,2 have cache on rank 0; requests 3,4 have no cache
+        router._all_ranks_prefix_matches = [
+            {1: 80, 2: 80, 3: 0, 4: 0},  # rank 0
+            {1: 0, 2: 0, 3: 0, 4: 0},  # rank 1
+        ]
+
+        states = [
+            RankState(rank=0, num_active_requests=0, num_active_tokens=0),
+            RankState(rank=1, num_active_requests=0, num_active_tokens=0),
+        ]
+        req_a = _make_request_item(1, num_tokens=100)
+        req_b = _make_request_item(2, num_tokens=100)
+        req_c = _make_request_item(3, num_tokens=100)
+        req_d = _make_request_item(4, num_tokens=100)
+
+        result, _ = router.route_requests(
+            states, [req_a, req_b, req_c, req_d], max_num_active_requests=10
+        )
+        # expected_num_active_requests = max((0+4+1)//2, 0) = 2
+        # req_a → rank 0: total_load=0, load_denom=max(0,100)=100
+        #   score(0)=(100-80)+1.0*(0/100*100)=20, score(1)=100+0=100 → rank0
+        #   active_tokens[0] += 20 → [20, 0]
+        # req_b → rank 0: total_load=20, load_denom=max(20,100)=100
+        #   score(0)=(100-80)+1.0*(20/100*100)=40, score(1)=100+0=100 → rank0
+        #   active_tokens[0] += 20 → [40, 0]; rank0 at capacity (2), removed
+        # req_c → rank 1 (only eligible): active_tokens[1] += 100
+        # req_d → rank 1: active_tokens[1] += 100
+        assert len(result[0]) == 2  # cached requests on rank 0
+        assert result[0][0].id == 1
+        assert result[0][1].id == 2
+        assert len(result[1]) == 2  # uncached requests on rank 1
+        assert result[1][0].id == 3
+        assert result[1][1].id == 4
+
+    def test_route_empty_requests(self):
+        dist = _mock_dist(tp_rank=0, tp_size=2)
+        mgr = _mock_kv_cache_manager()
+        router = KVCacheAwareADPRouter(dist=dist, kv_cache_manager=mgr)
+        router._all_ranks_prefix_matches = [{}, {}]
+
+        states = [
+            RankState(rank=0, num_active_requests=0, num_active_tokens=0),
+            RankState(rank=1, num_active_requests=0, num_active_tokens=0),
+        ]
+
+        result, expected = router.route_requests(states, [], max_num_active_requests=10)
+        assert result == {0: [], 1: []}
+        assert expected >= 0
+
+    def test_route_four_ranks_balanced(self):
+        """4 ranks, no cache hits → balanced distribution."""
+        dist = _mock_dist(tp_rank=0, tp_size=4)
+        mgr = _mock_kv_cache_manager()
+        router = KVCacheAwareADPRouter(dist=dist, kv_cache_manager=mgr)
+        router._all_ranks_prefix_matches = [{}, {}, {}, {}]
+
+        states = [RankState(rank=i, num_active_requests=0, num_active_tokens=0) for i in range(4)]
+        reqs = [_make_request_item(i, num_tokens=10) for i in range(8)]
+
+        result, _ = router.route_requests(states, reqs, max_num_active_requests=10)
+        total = sum(len(v) for v in result.values())
+        assert total == 8
+        for rank_reqs in result.values():
+            assert len(rank_reqs) == 2
+
+    def test_route_capacity_limit_respected(self):
+        """Per-rank ceiling of expected_num_active_requests still fits all 3 requests."""
+        dist = _mock_dist(tp_rank=0, tp_size=2)
+        mgr = _mock_kv_cache_manager()
+        router = KVCacheAwareADPRouter(dist=dist, kv_cache_manager=mgr)
+        router._all_ranks_prefix_matches = [{1: 0, 2: 0, 3: 0}, {1: 0, 2: 0, 3: 0}]
+
+        states = [
+            RankState(rank=0, num_active_requests=0, num_active_tokens=0),
+            RankState(rank=1, num_active_requests=0, num_active_tokens=0),
+        ]
+        reqs = [_make_request_item(i, num_tokens=10) for i in range(1, 4)]
+
+        result, _ = router.route_requests(states, reqs, max_num_active_requests=2)
+        total = sum(len(v) for v in result.values())
+        assert total == 3  # all assigned (expected = max(ceil(3/2), 0) = 2 per rank)
+
+    def test_route_mixed_cache_and_no_cache(self):
+        """Some requests have cache hits, some don't → smart routing."""
+        dist = _mock_dist(tp_rank=0, tp_size=2)
+        mgr = _mock_kv_cache_manager()
+        router = KVCacheAwareADPRouter(dist=dist, kv_cache_manager=mgr)
+
+        # req 1 has cache on rank 0, req 2 has no cache anywhere
+        router._all_ranks_prefix_matches = [
+            {1: 80, 2: 0},  # rank 0
+            {1: 0, 2: 0},  # rank 1
+        ]
+
+        states = [
+            RankState(rank=0, num_active_requests=0, num_active_tokens=0),
+            RankState(rank=1, num_active_requests=0, num_active_tokens=0),
+        ]
+        req_a = _make_request_item(1, num_tokens=100)  # cache hit on rank 0
+        req_b = _make_request_item(2, num_tokens=100)  # no cache
+
+        result, _ = router.route_requests(states, [req_a, req_b], max_num_active_requests=10)
+        # req_a → rank 0 (cache hit: effective=20 vs 100)
+        #   active_tokens → [20, 0]
+        # req_b → rank 1 (both ranks have 0 cache; rank 0 carries 20 tokens)
+        #   total_load=20, load_denom=max(20,100)=100
+        #   score(rank0) = 100 + 1.0*(20/100*100) = 120
+        #   score(rank1) = 100 + 1.0*(0/100*100) = 100
+        assert len(result[0]) == 1
+        assert result[0][0].id == 1
+        assert len(result[1]) == 1
+        assert result[1][0].id == 2
+
+
+# ---- Tests for V1 KVCacheManager.probe_prefix_match_length ----
+
+
+class TestProbeOnV1KVCacheManager:
+    """Test probe_prefix_match_length on the v1 (C++) KVCacheManager using mocks."""
+
+    def test_block_reuse_disabled_returns_zero(self):
+        from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
+
+        mgr = Mock(spec=KVCacheManager)
+        mgr.enable_block_reuse = False
+
+        result = KVCacheManager.probe_prefix_match_length(mgr, input_tokens=[1, 2, 3])
+        assert result == 0
+
+    def test_variable_window_returns_zero(self):
+        from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
+
+        mgr = Mock(spec=KVCacheManager)
+        mgr.enable_block_reuse = True
+        mgr.impl = Mock()
+        mgr.impl.is_variable_window = True
+
+        result = KVCacheManager.probe_prefix_match_length(mgr, input_tokens=[1, 2, 3])
+        assert result == 0
+        # count_reusable_blocks should NOT be called (would crash)
+        mgr.impl.count_reusable_blocks.assert_not_called()
+
+    def test_empty_tokens_returns_zero(self):
+        from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
+
+        mgr = Mock(spec=KVCacheManager)
+        mgr.enable_block_reuse = True
+        mgr.impl = Mock()
+        mgr.impl.is_variable_window = False
+
+        result = KVCacheManager.probe_prefix_match_length(mgr, input_tokens=[])
+        assert result == 0
+
+    def test_block_to_token_conversion(self):
+        """Verify the num_blocks * tokens_per_block conversion."""
+        from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
+
+        mgr = Mock(spec=KVCacheManager)
+        mgr.enable_block_reuse = True
+        mgr.impl = Mock()
+        mgr.impl.is_variable_window = False
+        mgr.impl.count_reusable_blocks = Mock(return_value=3)
+        mgr.tokens_per_block = 64
+
+        result = KVCacheManager.probe_prefix_match_length(mgr, input_tokens=list(range(200)))
+        assert result == 192  # 3 blocks * 64 tokens/block
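For intuition on the scoring rule these tests exercise, here is a small self-contained sketch (a re-implementation of the formula from the KVCacheAwareADPRouter docstring, not an import of the router itself) showing how the load term trades off against cache affinity:

```python
def score(req_tokens: int, match_len: int, rank_load: float,
          total_load: float, beta: float = 1.0) -> float:
    """score = effective_tokens + beta * normalized_load; lower is better."""
    load_denom = max(total_load, float(req_tokens))
    return (req_tokens - match_len) + beta * (rank_load / load_denom * req_tokens)

req_tokens, match_on_rank0 = 100, 80  # 80 of 100 prompt tokens cached on rank 0

# Moderately uneven load: the cache hit dominates and rank 0 wins.
loads = [100.0, 60.0]
print(score(req_tokens, match_on_rank0, loads[0], sum(loads)))  # 82.5
print(score(req_tokens, 0, loads[1], sum(loads)))               # 137.5

# Heavily skewed load: the load penalty overcomes the cache hit; rank 1 wins.
loads = [5000.0, 0.0]
print(score(req_tokens, match_on_rank0, loads[0], sum(loads)))  # 120.0
print(score(req_tokens, 0, loads[1], sum(loads)))               # 100.0
```

This matches the arithmetic worked out in test_route_load_overcomes_cache and test_route_mixed_cache_and_no_cache above (up to the beta values used there).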