@@ -312,17 +312,29 @@ def pin_layer(
312312 self ,
313313 layer_idx : int ,
314314 ) -> dict [str , torch .Tensor ]:
315+ """Return a layer's streamed tensors as pinned-host tensors.
316+
317+ Cached: on the second call for the same layer we hand back the
318+ already-pinned tensors and skip the safetensors read entirely. The
319+ streaming scheduler calls this once per layer per token, so missing
320+ the cache pays full disk I/O on every decode step — that's a 3+
321+ second per-token regression on a 32B model.
322+ """
315323
316324 assert self .index is not None
317325
318326 layer = self .index .layers [layer_idx ]
327+ needed = [s for s in layer .tensors .values () if not s .is_resident ]
319328
320- grouped : dict [Path , list [TensorSpec ]] = {}
321-
322- for spec in layer .tensors .values ():
323- if spec .is_resident :
324- continue
329+ # Cache hit: every needed tensor is already in _pinned_tensors.
330+ # Strict all-or-nothing — partial hits force a re-read so we don't
331+ # mix tensors from different load passes (defensive; in practice
332+ # we either pinned the whole layer or none of it).
333+ if needed and all (s .name in self ._pinned_tensors for s in needed ):
334+ return {s .name : self ._pinned_tensors [s .name ] for s in needed }
325335
336+ grouped : dict [Path , list [TensorSpec ]] = {}
337+ for spec in needed :
326338 grouped .setdefault (spec .path , []).append (spec )
327339
328340 out : dict [str , torch .Tensor ] = {}
0 commit comments