
Commit 9f15e77

chesterxgchen and claude authored and committed
[2.7] Pass-Through: Zero Tensor Copy at CJ for Large-Model Federated Training (#4210)
This PR introduces the **pass-through architecture** for `ClientAPILauncherExecutor`, eliminating tensor materialisation at the CJ (Client Job) process when large models are exchanged between the FL server and a subprocess agent. In large-model federated learning (e.g., 7B–70B LLM fine-tuning), the CJ process today acts as a blind relay that fully deserializes and re-serializes every tensor it receives from the FL server before forwarding it to the subprocess. For a 70B float16 model, this consumes ~140 GB of CJ memory and requires two complete network transfers. B1 pass-through removes both costs.

---

NVFlare's multi-hop execution path for `launch_external_process=True` looks like:

```
FL Server ──serialize──▶ CJ process ──re-serialize──▶ Subprocess agent
```

Each tensor in the global model is handled as follows at CJ:

1. **Server** serializes the model and creates a download transaction (tensor data lives on the server).
2. **CJ** fully *downloads* every tensor from the server into its own heap, materialising the complete model in CJ memory.
3. **CJ** re-serializes the model for the subprocess, creating a *new* download transaction — the subprocess then downloads from CJ.

For large models, this means:

- **CJ peak memory = full model size** (potentially 100s of GB).
- **Two full network transfers**: server → CJ, then CJ → subprocess.
- **CJ becomes a throughput bottleneck** and an OOM risk for any model that doesn't fit in the CJ process's memory.

This is why such workflows become infeasible once the model exceeds what the CJ machine can hold.

---

With `FOBSContextKey.PASS_THROUGH` enabled on CJ's cell FOBS context, the data path becomes:

```
FL Server ──stream──▶ CJ (LazyDownloadRef only, no tensor data)
                        └──forward ref──▶ Subprocess
                                            └──download──▶ FL Server
```

CJ holds **only lightweight placeholders** (< 100 bytes per tensor). The subprocess downloads each tensor directly from the FL server — CJ is never involved in the tensor data path.
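The before/after data paths can be sketched with a toy relay in plain Python. None of this is the real NVFlare API; "tensors" are byte blobs and the server's download transaction is modeled as a dict keyed by (ref_id, item_id):

```python
# Toy model of the two relay strategies described above.
SERVER_STORE = {("batch-1", "T0"): b"\x00" * 1024, ("batch-1", "T1"): b"\x01" * 2048}

def cj_materializing_relay(refs):
    """Old path: CJ downloads every tensor, then re-serves it itself."""
    downloaded = {k: SERVER_STORE[k] for k in refs}       # full copy at CJ
    cj_bytes = sum(len(v) for v in downloaded.values())
    return downloaded, cj_bytes

def cj_pass_through_relay(refs):
    """B1 path: CJ keeps only (fqcn, ref_id, item_id) placeholders."""
    placeholders = [("server", ref_id, item_id) for ref_id, item_id in refs]
    cj_bytes = 0                                          # no tensor data at CJ
    return placeholders, cj_bytes

def subprocess_download(placeholders):
    """Subprocess resolves each placeholder directly against the server."""
    return {(r, i): SERVER_STORE[(r, i)] for (_fqcn, r, i) in placeholders}

refs = [("batch-1", "T0"), ("batch-1", "T1")]
_, old_cj = cj_materializing_relay(refs)      # full model held at CJ
phs, new_cj = cj_pass_through_relay(refs)     # zero tensor bytes held at CJ
assert subprocess_download(phs) == SERVER_STORE
```

The point of the sketch is the asymmetry: the old relay's memory cost scales with model size, while the pass-through relay's cost scales only with the number of placeholders.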
**`FOBSContextKey.PASS_THROUGH`** (`nvflare/fuel/utils/fobs/__init__.py`) — a new context key that signals `ViaDownloaderDecomposer` to skip the download step and create lazy placeholders instead.

**`LazyDownloadRef`** — a small sentinel object (four fields: `fqcn`, `ref_id`, `item_id`, `dot`) created by `recompose()` in PASS_THROUGH mode. It carries the original FL server's FQCN, batch ref_id, intra-batch item ID, and Datum Object Type — everything the subprocess needs to download the tensor directly.

**`_LazyBatchInfo`** — a named sentinel stored in `fobs_ctx[items_key]` during a PASS_THROUGH receive. Using a typed class (rather than a plain tuple) makes the PASS_THROUGH branch unambiguous and immune to accidental type collisions with real item dicts.

**`LazyDownloadRefDecomposer`** — a new auto-registered FOBS decomposer for `LazyDownloadRef`. When CJ re-serializes a task containing `LazyDownloadRef` objects:

- **`decompose()`** delegates to `get_dot_handler(lazy.dot)` — the original `ViaDownloaderDecomposer` subclass (e.g., `TensorDecomposer`, `NumpyArrayDecomposer`). That handler's `_finalize_lazy_batch` post-callback re-emits the *original* server datum (fqcn + ref_id + DOT) so the subprocess knows exactly where to download from. `lazy_dot` is appended to the encoding dict for routing on the receive side.
- **`recompose()`** uses `lazy_dot` to look up the handler and delegates to `handler.recompose()`, which retrieves the real tensor from `fobs_ctx[handler.items_key]` (populated by `process_datum()` when the subprocess received the forwarded datum).

The `dot` (Datum Object Type) field on both `LazyDownloadRef` and `_LazyBatchInfo` ensures that numpy arrays stay with `NumpyArrayDecomposer` and PyTorch tensors stay with `TensorDecomposer`, preserving type safety through the full pass-through hop.
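The dot-based routing can be reduced to a standalone sketch. All names and DOT values below are invented for illustration; this is a toy registry, not the real FOBS API:

```python
# Toy DOT registry: each ViaDownloader-style handler is looked up by its
# Datum Object Type, so routing needs no if/elif on tensor types.
NUMPY_DOWNLOAD = 7   # invented value for illustration
_DOT_HANDLERS = {}

class LazyRef:
    """Toy stand-in for LazyDownloadRef: where one tensor can be fetched."""
    __slots__ = ("fqcn", "ref_id", "item_id", "dot")
    def __init__(self, fqcn, ref_id, item_id, dot):
        self.fqcn, self.ref_id, self.item_id, self.dot = fqcn, ref_id, item_id, dot

class NumpyHandler:
    dot = NUMPY_DOWNLOAD
    def decompose(self, lazy):
        # Emit a tiny per-item encoding; in the real design a post-callback
        # adds the shared (fqcn, ref_id) datum once per batch.
        return {"type": "ref", "data": lazy.item_id}
    def recompose(self, enc, items):
        return items[enc["data"]]            # real array, already downloaded

def get_dot_handler(dot):
    return _DOT_HANDLERS[dot]

_DOT_HANDLERS[NUMPY_DOWNLOAD] = NumpyHandler()

# CJ side: serialise a LazyRef by delegating to the handler its dot names.
lazy = LazyRef("server.site-1", "ref-42", "T0", NUMPY_DOWNLOAD)
enc = get_dot_handler(lazy.dot).decompose(lazy)
enc["lazy_dot"] = lazy.dot                   # routing info for the receive side

# Subprocess side: route back by lazy_dot; items were downloaded beforehand.
downloaded = {"T0": [1.0, 2.0, 3.0]}
value = get_dot_handler(enc["lazy_dot"]).recompose(enc, downloaded)
assert value == [1.0, 2.0, 3.0]
```

The dict lookup is the whole trick: because `dot` travels with the reference, the same two-line dispatch works for any number of tensor types.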
(`client_api_launcher_executor.py`) On startup, the executor enables PASS_THROUGH on the engine cell's FOBS context:

```python
cell.core_cell.update_fobs_context({FOBSContextKey.PASS_THROUGH: True})
```

This single line activates the full B1 architecture for every job that uses `launch_external_process=True` — including `llm_hf` and any recipe that calls `ScriptRunner(launch_external_process=True)`.

---

The pipe (CellPipe) operates on already-serialized bytes. Intercepting at the pipe level would require parsing the FOBS binary format, rewriting datum references, and reassembling the byte stream — fragile and tightly coupled to the wire format. Intercepting at the FOBS decomposer level is the natural extension point: decomposers already control exactly when and how data is materialised. PASS_THROUGH simply adds a "don't materialise" branch to that existing mechanism.

The subprocess must know *which* `ViaDownloaderDecomposer` subclass owns the downloaded data so it can store it in the correct `fobs_ctx[items_key]` and route `recompose()` correctly. The `dot` field, set when the server originally serialized the tensor, carries this type information through the pass-through hop without any type-switching logic.

| Stage | Before (tensor materialised) | After (B1 pass-through) |
|-------|------------------------------|-------------------------|
| CJ receive | Full model size (e.g., 140 GB) | ~100 bytes per tensor |
| CJ forward | Creates new download tx | Re-emits original server datum |
| Subprocess receive | Downloads from CJ | Downloads directly from FL server |

---

1. **Zero tensor copy at CJ** — CJ memory footprint is independent of model size.
2. **One network transfer** instead of two — tensors travel server → subprocess directly.
3. **No CJ OOM risk** for large models regardless of CJ machine memory capacity.
4. **Transparent to job authors** — no changes to job configs, training scripts, or recipe APIs; `launch_external_process=True` automatically activates B1.
5. **Type-safe** — `dot` propagation preserves tensor type (numpy / pytorch) through the hop without any if/elif type switching.

---

- All existing jobs using `launch_external_process=True` automatically benefit. No config or script changes required.
- Jobs using `launch_external_process=False` (in-process executor) are completely unaffected — `ClientAPILauncherExecutor.initialize()` is not called.
- For models smaller than the ViaDownloaderDecomposer streaming threshold (2 MB per array), FOBS uses native (inline) serialization regardless of `PASS_THROUGH` — behaviour is identical to before.
- `LazyDownloadRefDecomposer` is auto-registered via the existing `register_folder` mechanism; no explicit registration call is needed by any caller.

---

| File | Change |
|------|--------|
| `nvflare/fuel/utils/fobs/__init__.py` | Add `FOBSContextKey.PASS_THROUGH` |
| `nvflare/fuel/utils/fobs/decomposers/via_downloader.py` | Add `LazyDownloadRef`, `_LazyBatchInfo`, PASS_THROUGH branches in `process_datum()` / `recompose()`, `LazyDownloadRefDecomposer`, `_finalize_lazy_batch` post-callback |
| `nvflare/app_common/executors/client_api_launcher_executor.py` | `initialize()` enables PASS_THROUGH on engine cell |

---

(22 tests)

| Test class | What is verified |
|------------|------------------|
| `TestLazyDownloadRef` | Construction, `__slots__`, per-item distinctness |
| `TestLazyBatchInfo` | Construction, `__slots__`, `isinstance` reliability vs plain tuple |
| `TestProcessDatumPassThrough` | PASS_THROUGH stores `_LazyBatchInfo`, never calls `_download_from_remote_cell`; normal mode calls download |
| `TestRecomposePassThrough` | Returns `LazyDownloadRef` with correct `fqcn`, `ref_id`, `item_id` from `_LazyBatchInfo` |
| `TestDecomposeWithLazyDownloadRef` | Returns REF encoding; `_finalize_lazy_batch` post-CB registered once per batch regardless of item count; emitted datum has correct fqcn/ref_id/DOT |
| `TestNoMemoryAccumulation` | `_CtxKey.OBJECTS` absent after PASS_THROUGH (no download transaction opened); `DownloadService._tx_table` unchanged; 50-cycle repeat produces no state bleed |

`tests/unit_test/fuel/f3/streaming/test_pass_through_e2e.py` (5 tests, real TCP Cells)

| Test | What is verified |
|------|------------------|
| `test_arrays_survive_pass_through_hop` | Full round trip: server → CJ (PASS_THROUGH) → subprocess; arrays arrive bit-exact |
| `test_cj_holds_only_lazy_refs_not_tensor_data` | After PASS_THROUGH deserialization, CJ holds only `LazyDownloadRef`, never `np.ndarray` |
| `test_cj_creates_no_download_transaction` | `DownloadService._tx_table` is unchanged during PASS_THROUGH + re-serialization |
| `test_forwarded_payload_carries_original_server_ref` | Forwarded datum contains the original server `fqcn` and `ref_id` — the subprocess downloads from the server, not CJ |
| `test_multiple_array_roundtrip` | All arrays in an 8-array batch survive with bit-exact values |

`tests/integration_test/data/jobs/pt_large_model_pass_through/`

Full-stack integration test using `PTClientAPILauncherExecutor` with `launch_once=True` (the pattern used by `llm_hf`):

- **Model**: `LargeNet` — a 3-layer MLP with ~8 MB of float32 parameters, well above the 2 MB ViaDownloaderDecomposer streaming threshold. This forces the real B1 code path (streaming + PASS_THROUGH) rather than the native inline path used by small models.
- **Client script**: mirrors the `llm_hf/client.py` structure (`while flare.is_running()` loop, receive / train / send). Uses CPU-only synthetic data — no dataset download required in CI.
- **Added to** `client_api.yml` as `"run pt-large-model-pass-through"`.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent aaf5b13 commit 9f15e77

File tree

14 files changed (+1387 / -6 lines)


.gitignore

Lines changed: 3 additions & 0 deletions
@@ -183,3 +183,6 @@ CLAUDE.local.md
 
 # local Codex artifacts
 .codex/
+
+# memory profiler output
+tests/memory_profile/**/*.dat

docs/programming_guide/memory_management.rst

Lines changed: 45 additions & 5 deletions
@@ -142,11 +142,51 @@ measured and confirmed RSS is stable without cleanup.
 Recommended Settings
 ====================
 
-+--------+-------------------------------+----------------------+
-| Role   | ``server_memory_gc_rounds``   | ``MALLOC_ARENA_MAX`` |
-+========+===============================+======================+
-| Server | 5                             | 4                    |
-+--------+-------------------------------+----------------------+
++--------+-----------------------------+-----------------------------+----------------------+----------------------+
+| Role   | ``server_memory_gc_rounds`` | ``client_memory_gc_rounds`` | ``MALLOC_ARENA_MAX`` | ``cuda_empty_cache`` |
++========+=============================+=============================+======================+======================+
+| Server | 5                           | N/A                         | 4                    | N/A                  |
++--------+-----------------------------+-----------------------------+----------------------+----------------------+
+| Client | N/A                         | 1                           | 2                    | True (for GPU)       |
++--------+-----------------------------+-----------------------------+----------------------+----------------------+
+
+Using jemalloc
+==============
+
+For PyTorch workloads, jemalloc is recommended over glibc malloc. NVFlare startup
+scripts preload jemalloc only when explicitly enabled via
+``NVFLARE_ENABLE_JEMALLOC_PRELOAD=true`` and jemalloc is available.
+
+Startup Script
+--------------
+
+The generated ``sub_start.sh`` script includes opt-in jemalloc preload:
+
+.. code-block:: bash
+
+   # Enable jemalloc preload only when opted in
+   if [ "${NVFLARE_ENABLE_JEMALLOC_PRELOAD:-false}" = "true" ]; then
+       for JEMALLOC in /usr/lib/x86_64-linux-gnu/libjemalloc.so.2 \
+                       /usr/lib64/libjemalloc.so.2 \
+                       /usr/local/lib/libjemalloc.so; do
+           if [ -f "$JEMALLOC" ]; then
+               export LD_PRELOAD="${LD_PRELOAD:+$LD_PRELOAD:}$JEMALLOC"
+               export MALLOC_CONF="${MALLOC_CONF:-dirty_decay_ms:5000,muzzy_decay_ms:5000}"
+               break
+           fi
+       done
+   fi
+
+Installing jemalloc
+-------------------
+
+.. code-block:: bash
+
+   # Ubuntu/Debian
+   apt-get install libjemalloc2
+
+   # RHEL/CentOS
+   yum install jemalloc
 
 API Reference
 =============
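The preload loop above picks the first jemalloc path that exists on the host. The same first-match logic can be exercised in isolation against temporary files (the paths here are stand-ins for the real libjemalloc locations):

```shell
# Sketch of the first-match selection loop, run against throwaway files.
tmpdir=$(mktemp -d)
touch "$tmpdir/libjemalloc.so.2"

FOUND=""
for JEMALLOC in "$tmpdir/missing.so" "$tmpdir/libjemalloc.so.2" "$tmpdir/other.so"; do
    if [ -f "$JEMALLOC" ]; then
        FOUND="$JEMALLOC"   # first existing candidate wins
        break
    fi
done

echo "$FOUND"
rm -rf "$tmpdir"
```

Because the loop breaks on the first hit, candidate order encodes preference; the distro-specific paths should be listed most-likely first.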

nvflare/app_common/executors/client_api_launcher_executor.py

Lines changed: 47 additions & 1 deletion
@@ -22,6 +22,7 @@
 from nvflare.client.config import ConfigKey, ExchangeFormat, TransferType, write_config_to_file
 from nvflare.client.constants import CLIENT_API_CONFIG, EXTERNAL_PRE_INIT_TIMEOUT
 from nvflare.fuel.utils.attributes_exportable import ExportMode
+from nvflare.fuel.utils.fobs import FOBSContextKey
 from nvflare.utils.configs import get_client_config_value
 
 
@@ -107,10 +108,38 @@ def __init__(
         self._params_exchange_format = params_exchange_format
         self._params_transfer_type = params_transfer_type
         self._config_file_name = config_file_name
+        self._memory_gc_rounds = memory_gc_rounds
+        self._cuda_empty_cache = cuda_empty_cache
+        self._cell_with_pass_through = None
+        self._prev_pass_through = None
 
     def initialize(self, fl_ctx: FLContext) -> None:
         self.prepare_config_for_launch(fl_ctx)
-        super().initialize(fl_ctx)
+        # Enable PASS_THROUGH mode on the engine's communication cell so that
+        # large tensors arriving from the FL server are NOT downloaded here at
+        # the CJ. ViaDownloaderDecomposer will instead create LazyDownloadRef
+        # placeholders that carry the original server FQCN and ref_id. When CJ
+        # forwards the task to the subprocess agent via the task pipe, those
+        # placeholders are re-emitted as-is, causing the subprocess to download
+        # each tensor directly from the server — one tensor at a time, with no
+        # size limit and no tensor copy at CJ.
+        engine = fl_ctx.get_engine()
+        cell = engine.get_cell()
+        if cell is not None:
+            self._cell_with_pass_through = cell
+            prev_ctx = cell.core_cell.get_fobs_context()
+            self._prev_pass_through = prev_ctx.get(FOBSContextKey.PASS_THROUGH, None)
+            cell.core_cell.update_fobs_context({FOBSContextKey.PASS_THROUGH: True})
+            self.log_info(
+                fl_ctx,
+                "PASS_THROUGH enabled: task tensors will be downloaded by the subprocess "
+                "agent directly from the source, bypassing CJ memory.",
+            )
+        try:
+            super().initialize(fl_ctx)
+        except Exception:
+            self._restore_pass_through(fl_ctx)
+            raise
 
         # Check for top-level config override for external_pre_init_timeout
         # This allows jobs to configure timeout via add_client_config()
@@ -126,6 +155,23 @@ def initialize(self, fl_ctx: FLContext) -> None:
             )
         self._external_pre_init_timeout = timeout_value
 
+    def finalize(self, fl_ctx: FLContext) -> None:
+        try:
+            super().finalize(fl_ctx)
+        finally:
+            self._restore_pass_through(fl_ctx)
+
+    def _restore_pass_through(self, fl_ctx: FLContext):
+        if self._cell_with_pass_through is None:
+            return
+
+        self._cell_with_pass_through.core_cell.update_fobs_context(
+            {FOBSContextKey.PASS_THROUGH: self._prev_pass_through}
+        )
+        self.log_info(fl_ctx, f"PASS_THROUGH restored to {self._prev_pass_through}.")
+        self._cell_with_pass_through = None
+        self._prev_pass_through = None
+
     def prepare_config_for_launch(self, fl_ctx: FLContext):
         pipe_export_class, pipe_export_args = self.pipe.export(ExportMode.PEER)
         task_exchange_attributes = {
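The `initialize()`/`finalize()` changes in this file follow a generic save-and-restore discipline for a shared flag: remember the previous value before enabling, and restore it on teardown or on a failed init. A standalone sketch (toy context dict, not the NVFlare cell API):

```python
class PassThroughGuard:
    """Enable a flag on a shared context dict, remembering the previous
    value so teardown (or a failed startup) can restore it exactly."""

    def __init__(self, ctx: dict, key: str = "pass_through"):
        self._ctx = ctx
        self._key = key
        self._prev = None
        self._active = False

    def enable(self):
        self._prev = self._ctx.get(self._key)   # may be None, True, or False
        self._ctx[self._key] = True
        self._active = True

    def restore(self):
        if not self._active:                    # idempotent: safe to call twice
            return
        self._ctx[self._key] = self._prev
        self._active = False

ctx = {"pass_through": False}
guard = PassThroughGuard(ctx)
guard.enable()
assert ctx["pass_through"] is True
guard.restore()
assert ctx["pass_through"] is False   # previous value restored, not just cleared
```

Restoring the remembered value (rather than unconditionally clearing the flag) matters when another component may have set it first; clearing would silently break that component.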

nvflare/fuel/utils/fobs/__init__.py

Lines changed: 8 additions & 0 deletions
@@ -52,3 +52,11 @@ class FOBSContextKey:
     DOWNLOAD_REQ_TIMEOUT = "download_req_timeout"
     SEC_CREDS = "sec_creds"
     NUM_RECEIVERS = "num_receivers"
+    # When True, ViaDownloaderDecomposer will NOT download tensors at this hop.
+    # Instead it creates LazyDownloadRef placeholders that preserve the original
+    # source FQCN/ref_id so the reference can be forwarded verbatim to the next
+    # hop (e.g. a subprocess agent), which then downloads directly from the
+    # originating source. This eliminates intermediate tensor copies at the
+    # forwarding node (the CJ) and is the foundation of the B1 pass-through
+    # architecture.
+    PASS_THROUGH = "pass_through"

nvflare/fuel/utils/fobs/decomposers/via_downloader.py

Lines changed: 178 additions & 0 deletions
@@ -31,6 +31,65 @@
 _MIN_DOWNLOAD_TIMEOUT = 60  # allow at least 1 minute gap between download activities
 
 
+class LazyDownloadRef:
+    """Placeholder created in PASS_THROUGH mode instead of downloading a tensor.
+
+    When a cell is configured as a pure forwarder (``FOBSContextKey.PASS_THROUGH``
+    is set in its FOBS context), incoming download references from the source are
+    not resolved. Instead a ``LazyDownloadRef`` is created for each tensor item
+    in the received batch so that the original source FQCN and batch ref_id are
+    preserved.
+
+    When the forwarding node (CJ) later serialises the task for its subprocess,
+    ``LazyDownloadRefDecomposer.decompose()`` detects ``LazyDownloadRef`` targets
+    and re-emits the *original* download datum (pointing back to the server)
+    instead of creating a new datum that would point to the CJ. The subprocess
+    agent then resolves the references directly from the originating source,
+    downloading each tensor individually without any copy passing through the CJ.
+
+    Attributes:
+        fqcn: FQCN of the originating cell that owns the download transaction.
+        ref_id: UUID of the batch download transaction on that cell.
+        item_id: Intra-batch item placeholder (e.g. ``"T0"``, ``"T1"``).
+        dot: Datum Object Type of the original download datum. Identifies
+            which ``ViaDownloaderDecomposer`` subclass owns this ref (e.g.
+            ``NUMPY_DOWNLOAD`` or ``TENSOR_DOWNLOAD``). Required by
+            ``LazyDownloadRefDecomposer`` to route serialisation and
+            deserialisation to the correct handler.
+    """
+
+    __slots__ = ("fqcn", "ref_id", "item_id", "dot")
+
+    def __init__(self, fqcn: str, ref_id: str, item_id: str, dot: int = 0):
+        self.fqcn = fqcn
+        self.ref_id = ref_id
+        self.item_id = item_id
+        self.dot = dot
+
+
+class _LazyBatchInfo:
+    """Sentinel stored in fobs_ctx[items_key] during PASS_THROUGH mode.
+
+    Carries the (fqcn, ref_id, dot) of the *original* download batch so that
+    ``recompose()`` can build a ``LazyDownloadRef`` for each item_id it
+    encounters. Using a named sentinel class (rather than a plain tuple)
+    makes the PASS_THROUGH path unambiguous and robust against accidental
+    type collisions.
+    """
+
+    __slots__ = ("fqcn", "ref_id", "dot")
+
+    def __init__(self, fqcn: str, ref_id: str, dot: int = 0):
+        self.fqcn = fqcn
+        self.ref_id = ref_id
+        self.dot = dot
+
+
+# fobs_ctx key used to carry the fqcn/ref_id batch info in PASS_THROUGH mode
+# so that recompose() can build per-item LazyDownloadRefs from a single datum.
+_LAZY_BATCH_CTX_SUFFIX = "_lazy_batch"
+
+
 class EncKey:
     TYPE = "type"
     DATA = "data"
@@ -178,6 +237,28 @@ def decompose(self, target: Any, manager: DatumManager = None) -> Any:
             # this should never happen
             raise RuntimeError("FOBS System Error: missing DatumManager")
 
+        # ── LazyDownloadRef: re-emit the original server datum verbatim ────────
+        # A LazyDownloadRef was created in PASS_THROUGH mode when CJ received the
+        # task from the server. Instead of creating a *new* download transaction
+        # on *this* cell (which would make the subprocess download from CJ), we
+        # re-emit the exact datum that the server originally sent. The subprocess
+        # agent therefore downloads each tensor directly from the server, with no
+        # tensor data ever materialised on CJ.
+        if isinstance(target, LazyDownloadRef):
+            fobs_ctx = manager.fobs_ctx
+            lazy_batch_key = f"{self.prefix}{_LAZY_BATCH_CTX_SUFFIX}"
+            if lazy_batch_key not in fobs_ctx:
+                # First LazyDownloadRef of this batch: register a post-callback
+                # that will add the single shared datum (fqcn + ref_id) after all
+                # items have been serialised.
+                fobs_ctx[lazy_batch_key] = {"fqcn": target.fqcn, "ref_id": target.ref_id}
+                manager.register_post_cb(self._finalize_lazy_batch)
+
+            self.logger.debug(
+                f"ViaDownloader: re-emitting LazyDownloadRef {target.item_id=} " f"{target.fqcn=} {target.ref_id=}"
+            )
+            return {EncKey.TYPE: EncType.REF, EncKey.DATA: target.item_id}
+
         max_chunk_size = acu.get_int_var(
             self._config_var_name(ConfigVarName.DOWNLOAD_CHUNK_SIZE),
             self.max_chunk_size,
@@ -320,6 +401,26 @@ def _delete_download_tx_on_msg_root(self, msg_root_id: str, downloader: ObjectDo
         self.logger.debug(f"ViaDownloader: deleting download transaction associated with {msg_root_id=}")
         downloader.delete_transaction()
 
+    def _finalize_lazy_batch(self, mgr: DatumManager):
+        """Post-callback used when re-emitting a LazyDownloadRef batch.
+
+        Adds a single datum containing the *original* source FQCN and ref_id so
+        that the downstream consumer (subprocess agent) can download the tensors
+        directly from the originating cell (typically the FL server) without
+        involving the CJ at all.
+        """
+        fobs_ctx = mgr.fobs_ctx
+        lazy_batch_key = f"{self.prefix}{_LAZY_BATCH_CTX_SUFFIX}"
+        lazy_batch = fobs_ctx.get(lazy_batch_key)
+        if not lazy_batch:
+            return
+        ref = {_RefKey.FQCN: lazy_batch["fqcn"], _RefKey.REF_ID: lazy_batch["ref_id"]}
+        datum = Datum(datum_type=DatumType.TEXT, value=json.dumps(ref), dot=self.get_download_dot())
+        self.logger.debug(
+            f"ViaDownloader: finalized lazy batch datum for {lazy_batch['fqcn']=} {lazy_batch['ref_id']=}"
+        )
+        mgr.add_datum(datum)
+
     def process_datum(self, datum: Datum, manager: DatumManager):
         """This is called by the manager to process a datum that has a DOT.
         This happens before the recompose processing.
@@ -340,6 +441,17 @@ def process_datum(self, datum: Datum, manager: DatumManager):
         self.logger.debug(f"ViaDownloader: pre-processing datum {datum.dot=} before recompose")
         fobs_ctx = manager.fobs_ctx
 
+        if fobs_ctx.get(fobs.FOBSContextKey.PASS_THROUGH):
+            # PASS_THROUGH mode: do NOT download tensors at this intermediate hop.
+            # Store the batch (fqcn, ref_id) so that recompose() can build a
+            # LazyDownloadRef for each item_id it encounters. The downstream
+            # consumer (subprocess agent) will resolve the references directly
+            # from the originating source cell.
+            ref = json.loads(datum.value)
+            self.logger.debug(f"ViaDownloader PASS_THROUGH: preserving lazy ref {ref} instead of downloading")
+            fobs_ctx[self.items_key] = _LazyBatchInfo(ref[_RefKey.FQCN], ref[_RefKey.REF_ID], datum.dot)
+            return
+
         # data is to be downloaded
         ref = json.loads(datum.value)
         items = self._download_from_remote_cell(manager.fobs_ctx, ref)
@@ -377,6 +489,19 @@ def recompose(self, data: Any, manager: DatumManager = None) -> Any:
         item_id = data
         fobs_ctx = manager.fobs_ctx
         items = fobs_ctx.get(self.items_key)
+
+        # PASS_THROUGH mode: items_key holds a _LazyBatchInfo sentinel, not a dict.
+        # Build a LazyDownloadRef so the reference can be forwarded verbatim.
+        # Carry items.dot so that LazyDownloadRefDecomposer can route back to the
+        # correct ViaDownloaderDecomposer subclass during subprocess recompose().
+        if isinstance(items, _LazyBatchInfo):
+            lazy = LazyDownloadRef(fqcn=items.fqcn, ref_id=items.ref_id, item_id=item_id, dot=items.dot)
+            self.logger.debug(
+                f"ViaDownloader PASS_THROUGH: created LazyDownloadRef {item_id=} "
+                f"{items.fqcn=} {items.ref_id=} {items.dot=}"
+            )
+            return lazy
+
         self.logger.debug(f"trying to get item for {item_id=} from {type(items)=}")
 
         make_lazy_ref_fn = getattr(items, "make_lazy_ref", None)
@@ -431,3 +556,56 @@ def _download_from_remote_cell(self, fobs_ctx: dict, ref: dict):
         else:
             self.logger.debug(f"downloaded {len(items)} items successfully")
             return items
+
+
+class LazyDownloadRefDecomposer(fobs.Decomposer):
+    """Decomposer that serialises and deserialises :class:`LazyDownloadRef` objects.
+
+    ``LazyDownloadRef`` objects are created at a forwarding hop (e.g. the CJ
+    process) when ``FOBSContextKey.PASS_THROUGH`` is set. Instead of
+    downloading tensors from the FL server, each tensor is represented as a
+    lightweight placeholder that carries the original server FQCN, batch
+    ref_id, item_id, and the Datum Object Type (``dot``) of the originating
+    ``ViaDownloaderDecomposer`` subclass.
+
+    When the forwarding node re-serialises the task for the subprocess agent,
+    FOBS routes each ``LazyDownloadRef`` to this decomposer.
+
+    **decompose()**
+        Delegates to the ``ViaDownloaderDecomposer`` subclass identified by
+        ``lazy.dot``. That handler's ``decompose()`` re-emits the original
+        server batch datum (fqcn / ref_id) via a post-callback so the
+        subprocess knows exactly where to download from. ``lazy_dot`` is
+        appended to the returned encoding dict so ``recompose()`` can route
+        back to the same handler.
+
+    **recompose()**
+        Uses ``lazy_dot`` to look up the original handler and delegates to
+        ``handler.recompose()``. At the subprocess, ``process_datum()`` has
+        already populated ``fobs_ctx[handler.items_key]`` with the downloaded
+        tensors, so the call returns the real tensor value directly.
+    """
+
+    def supported_type(self):
+        return LazyDownloadRef
+
+    def decompose(self, lazy: LazyDownloadRef, manager: DatumManager = None) -> dict:
+        handler = fobs.get_dot_handler(lazy.dot)
+        if not handler:
+            raise RuntimeError(
+                f"LazyDownloadRefDecomposer: no DOT handler registered for dot={lazy.dot!r}. "
+                "Ensure the original ViaDownloaderDecomposer subclass (e.g. NumpyArrayDecomposer) "
+                "is registered before serialising LazyDownloadRef objects."
+            )
+        result = handler.decompose(lazy, manager)
+        result["lazy_dot"] = lazy.dot
+        return result
+
+    def recompose(self, data: dict, manager: DatumManager = None) -> Any:
+        lazy_dot = data.get("lazy_dot")
+        if lazy_dot is None:
+            raise RuntimeError("LazyDownloadRefDecomposer: missing 'lazy_dot' in encoded data")
+        handler = fobs.get_dot_handler(lazy_dot)
+        if not handler:
+            raise RuntimeError(f"LazyDownloadRefDecomposer: no DOT handler registered for lazy_dot={lazy_dot!r}")
+        return handler.recompose(data, manager)
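The "register the post-callback once per batch" trick in `decompose()` above reduces to a small standalone pattern: the first item of a batch stashes the shared info and registers a finalizer; every item returns a tiny per-item encoding; the finalizer runs once at the end. A sketch with a toy manager (illustrative names only, not the real DatumManager):

```python
class ToyDatumManager:
    """Minimal stand-in: collects datums, runs post-callbacks once at the end."""
    def __init__(self):
        self.ctx = {}
        self.datums = []
        self._post_cbs = []

    def register_post_cb(self, cb):
        self._post_cbs.append(cb)

    def finish(self):
        for cb in self._post_cbs:
            cb(self)

def decompose_lazy(mgr, fqcn, ref_id, item_id, batch_key="np_lazy_batch"):
    if batch_key not in mgr.ctx:
        # First item of the batch: stash shared info, register finalizer once.
        mgr.ctx[batch_key] = {"fqcn": fqcn, "ref_id": ref_id}
        mgr.register_post_cb(lambda m: m.datums.append(dict(m.ctx[batch_key])))
    return {"type": "ref", "data": item_id}    # per-item encoding stays tiny

mgr = ToyDatumManager()
for i in range(8):
    decompose_lazy(mgr, "server.site-1", "ref-42", f"T{i}")
mgr.finish()
assert len(mgr.datums) == 1                    # one shared datum per batch
assert mgr.datums[0] == {"fqcn": "server.site-1", "ref_id": "ref-42"}
```

The guard on `batch_key` is what keeps the payload small: eight items produce eight tiny per-item refs but only one shared (fqcn, ref_id) datum, matching the "post-CB registered once per batch regardless of item count" behavior the unit tests verify.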
