
Cherry-pick [2.7] Pass-Through: Zero Tensor Copy at CJ for Large-Model Federated Training (#4210) #4289

Open
YuanTingHsieh wants to merge 1 commit into NVIDIA:main from YuanTingHsieh:cherry-pick-4210

Conversation

@YuanTingHsieh
Collaborator

This PR introduces the pass-through architecture for ClientAPILauncherExecutor, eliminating tensor materialisation at the CJ (Client Job) process when large models are exchanged between the FL server and a subprocess agent.

In large-model federated learning (e.g., 7B–70B LLM fine-tuning), the CJ process today acts as a blind relay that fully deserializes and re-serializes every tensor it receives from the FL server before forwarding to the subprocess. For a 70B float16 model, this consumes ~140 GB of CJ memory and requires two complete network transfers. B1 pass-through removes both costs.
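The ~140 GB figure follows directly from parameter count and dtype width; a quick sanity check:

```python
# Back-of-envelope memory cost of materialising a 70B-parameter
# float16 model in the CJ process (2 bytes per parameter).
params = 70_000_000_000
bytes_per_param = 2  # float16
total_bytes = params * bytes_per_param
print(total_bytes / 1e9)  # 140.0 (GB, decimal)
```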


NVFlare's multi-hop execution path for launch_external_process=True looks like:

FL Server ──serialize──▶ CJ process ──re-serialize──▶ Subprocess agent

Each tensor in the global model is handled as follows at CJ:

  1. Server serializes the model and creates a download transaction (tensor data lives on the server).
  2. CJ fully downloads every tensor from the server into its own heap, materialising the complete model in CJ memory.
  3. CJ re-serializes the model for the subprocess, creating a new download transaction — the subprocess then downloads from CJ.

For large models, this means:

  • CJ peak memory = full model size (potentially 100s of GB).
  • Two full network transfers: server → CJ, then CJ → subprocess.
  • CJ becomes a throughput bottleneck and an OOM risk for any model that doesn't fit in the CJ process's memory.

This is why such workflows are infeasible for any model larger than what the CJ machine can hold.


With FOBSContextKey.PASS_THROUGH enabled on CJ's cell FOBS context, the data path becomes:

FL Server ──stream──▶ CJ (LazyDownloadRef only, no tensor data)
                          └──forward ref──▶ Subprocess
                                               └──download──▶ FL Server

CJ holds only lightweight placeholders (< 100 bytes per tensor). The subprocess downloads each tensor directly from the FL server — CJ is never involved in the tensor data path.

FOBSContextKey.PASS_THROUGH (nvflare/fuel/utils/fobs/__init__.py)
A new context key that signals ViaDownloaderDecomposer to skip the download step and create lazy placeholders instead.

LazyDownloadRef — a small sentinel object (four fields: fqcn, ref_id, item_id, dot) created by recompose() in PASS_THROUGH mode. It carries the original FL server's FQCN, batch ref_id, intra-batch item ID, and Datum Object Type — everything the subprocess needs to download the tensor directly.
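A minimal sketch of such a placeholder, assuming only the four fields named above (the real class in via_downloader.py may differ):

```python
import sys

# Sketch of the LazyDownloadRef sentinel; the field names come from the
# PR description, the body is illustrative.
class LazyDownloadRef:
    __slots__ = ("fqcn", "ref_id", "item_id", "dot")

    def __init__(self, fqcn: str, ref_id: str, item_id: int, dot: int):
        self.fqcn = fqcn        # FQCN of the originating FL server cell
        self.ref_id = ref_id    # server-side download-batch reference
        self.item_id = item_id  # index of this tensor within the batch
        self.dot = dot          # Datum Object Type, selects the handler

ref = LazyDownloadRef("server", "ref-1", 0, 1)
print(sys.getsizeof(ref))  # a few dozen bytes on CPython, vs. GBs of tensor data
```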

_LazyBatchInfo — a named sentinel stored in fobs_ctx[items_key] during a PASS_THROUGH receive. Using a typed class (rather than a plain tuple) makes the PASS_THROUGH branch unambiguous and immune to accidental type collisions with real item dicts.

LazyDownloadRefDecomposer — a new auto-registered FOBS decomposer for LazyDownloadRef. When CJ re-serializes a task containing LazyDownloadRef objects:

  • decompose() delegates to get_dot_handler(lazy.dot) — the original ViaDownloaderDecomposer subclass (e.g., TensorDecomposer, NumpyArrayDecomposer). That handler's _finalize_lazy_batch post-callback re-emits the original server datum (fqcn + ref_id + DOT) so the subprocess knows exactly where to download from. lazy_dot is appended to the encoding dict for routing on the receive side.
  • recompose() uses lazy_dot to look up the handler and delegates to handler.recompose(), which retrieves the real tensor from fobs_ctx[handler.items_key] (populated by process_datum() when the subprocess received the forwarded datum).

The dot (Datum Object Type) field on both LazyDownloadRef and _LazyBatchInfo ensures that numpy arrays stay with NumpyArrayDecomposer and PyTorch tensors stay with TensorDecomposer, preserving type safety through the full pass-through hop.
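A sketch of the dot-based routing described above; get_dot_handler appears in the PR, but this registry body and the DOT values are illustrative stand-ins:

```python
# Hypothetical DOT -> handler registry. String placeholders stand in for
# the real ViaDownloaderDecomposer subclasses.
_DOT_HANDLERS: dict = {}

def register_dot_handler(dot, handler) -> None:
    _DOT_HANDLERS[dot] = handler

def get_dot_handler(dot):
    try:
        return _DOT_HANDLERS[dot]
    except KeyError:
        raise ValueError(f"no handler registered for DOT {dot!r}") from None

register_dot_handler("NUMPY", "NumpyArrayDecomposer")
register_dot_handler("TENSOR", "TensorDecomposer")

# A LazyDownloadRef carrying dot="NUMPY" always routes back to the numpy
# handler -- no if/elif type switching anywhere.
assert get_dot_handler("NUMPY") == "NumpyArrayDecomposer"
```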

(client_api_launcher_executor.py)
On startup, the executor enables PASS_THROUGH on the engine cell's FOBS context:

cell.core_cell.update_fobs_context({FOBSContextKey.PASS_THROUGH: True})

This single line activates the full B1 architecture for every job that uses launch_external_process=True — including llm_hf and any recipe that calls ScriptRunner(launch_external_process=True).
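The enable-then-restore lifecycle around that line can be pictured with a plain dict standing in for the cell's FOBS context (names and helper functions here are illustrative, not the executor's actual code):

```python
# Enable PASS_THROUGH at initialize(), remember the previous value, and
# restore it at finalize(). A dict stands in for the cell FOBS context.
fobs_ctx: dict = {}

def enable_pass_through(ctx: dict):
    prev = ctx.get("PASS_THROUGH")
    ctx["PASS_THROUGH"] = True
    return prev

def restore_pass_through(ctx: dict, prev) -> None:
    if prev is None:
        ctx.pop("PASS_THROUGH", None)  # key did not exist before
    else:
        ctx["PASS_THROUGH"] = prev

prev = enable_pass_through(fobs_ctx)
assert fobs_ctx["PASS_THROUGH"] is True
restore_pass_through(fobs_ctx, prev)
assert "PASS_THROUGH" not in fobs_ctx
```

The later review comments about the partial-initialize restore gap concern exactly this pattern: the restore call must run on every exit path, not only on a clean finalize.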


The pipe (CellPipe) operates on already-serialized bytes. Intercepting at the pipe level would require parsing FOBS binary format, re-writing datum references, and re-assembling the byte stream — fragile and tightly coupled to the wire format.

Intercepting at the FOBS decomposer level is the natural extension point: decomposers already control exactly when and how data is materialised. PASS_THROUGH simply adds a "don't materialise" branch to that existing mechanism.

The subprocess must know which ViaDownloaderDecomposer subclass owns the downloaded data so it can store it in the correct fobs_ctx[items_key] and route recompose() correctly. The dot field, set when the server originally serialized the tensor, carries this type information through the pass-through hop without any type-switching logic.

| Stage | Before (tensor materialised) | After (B1 pass-through) |
|-------|------------------------------|-------------------------|
| CJ receive | Full model size (e.g., 140 GB) | ~100 bytes per tensor |
| CJ forward | Creates new download tx | Re-emits original server datum |
| Subprocess receive | Downloads from CJ | Downloads directly from FL server |


  1. Zero tensor copy at CJ — CJ memory footprint is independent of model size.
  2. One network transfer instead of two — tensors travel server → subprocess directly.
  3. No CJ OOM risk for large models regardless of CJ machine memory capacity.
  4. Transparent to job authors — no changes to job configs, training scripts, or recipe APIs; launch_external_process=True automatically activates B1.
  5. Type-safe — dot propagation preserves tensor type (numpy / pytorch) through the hop without any if/elif type switching.

  • All existing jobs using launch_external_process=True automatically benefit. No config or script changes required.
  • Jobs using launch_external_process=False (in-process executor) are completely unaffected — ClientAPILauncherExecutor.initialize() is not called.
  • For models smaller than the ViaDownloaderDecomposer streaming threshold (2 MB per array), FOBS uses native (inline) serialization regardless of PASS_THROUGH — behaviour is identical to before.
  • LazyDownloadRefDecomposer is auto-registered via the existing register_folder mechanism; no explicit registration call is needed by any caller.
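The 2 MB threshold mentioned above can be sketched as a simple size check (the exact boundary semantics — strict versus inclusive — are an assumption here):

```python
# Streaming threshold from the PR description: arrays above ~2 MB take
# the streaming/download path, smaller ones are serialized inline.
STREAM_THRESHOLD = 2 * 1024 * 1024  # 2 MB

def uses_streaming(nbytes: int) -> bool:
    """True if an array of this byte size would take the streaming path."""
    return nbytes > STREAM_THRESHOLD

assert uses_streaming(1000 * 1000 * 4)    # 1000x1000 float32 (~4 MB): streams
assert not uses_streaming(100 * 100 * 4)  # 100x100 float32 (40 KB): inline
```

PASS_THROUGH only matters on the streaming path; inline payloads never create a download transaction in the first place.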

| File | Change |
|------|--------|
| nvflare/fuel/utils/fobs/__init__.py | Add FOBSContextKey.PASS_THROUGH |
| nvflare/fuel/utils/fobs/decomposers/via_downloader.py | Add LazyDownloadRef, _LazyBatchInfo, PASS_THROUGH branches in process_datum() / recompose(), LazyDownloadRefDecomposer, _finalize_lazy_batch post-callback |
| nvflare/app_common/executors/client_api_launcher_executor.py | initialize() enables PASS_THROUGH on engine cell |

tests/unit_test/fuel/utils/fobs/test_pass_through.py (22 tests)

| Test class | What is verified |
|------------|------------------|
| TestLazyDownloadRef | Construction, __slots__, per-item distinctness |
| TestLazyBatchInfo | Construction, __slots__, isinstance reliability vs plain tuple |
| TestProcessDatumPassThrough | PASS_THROUGH stores _LazyBatchInfo, never calls _download_from_remote_cell; normal mode calls download |
| TestRecomposePassThrough | Returns LazyDownloadRef with correct fqcn, ref_id, item_id from _LazyBatchInfo |
| TestDecomposeWithLazyDownloadRef | Returns REF encoding; _finalize_lazy_batch post-CB registered once per batch regardless of item count; emitted datum has correct fqcn/ref_id/DOT |
| TestNoMemoryAccumulation | _CtxKey.OBJECTS absent after PASS_THROUGH (no download transaction opened); DownloadService._tx_table unchanged; 50-cycle repeat produces no state bleed |

tests/unit_test/fuel/f3/streaming/test_pass_through_e2e.py (5 tests, real TCP Cells)

| Test | What is verified |
|------|------------------|
| test_arrays_survive_pass_through_hop | Full round-trip: server → CJ (PASS_THROUGH) → subprocess, arrays arrive bit-exact |
| test_cj_holds_only_lazy_refs_not_tensor_data | After PASS_THROUGH deserialization, CJ holds only LazyDownloadRef, never np.ndarray |
| test_cj_creates_no_download_transaction | DownloadService._tx_table is unchanged during PASS_THROUGH + re-serialization |
| test_forwarded_payload_carries_original_server_ref | Forwarded datum contains original server fqcn and ref_id — subprocess downloads from server, not CJ |
| test_multiple_array_roundtrip | 8-array batch all survive with bit-exact values |

tests/integration_test/data/jobs/pt_large_model_pass_through/

Full-stack integration test using PTClientAPILauncherExecutor with launch_once=True (the pattern used by llm_hf):

  • Model: LargeNet — 3-layer MLP with ~8 MB of float32 parameters, well above the 2 MB ViaDownloaderDecomposer streaming threshold. This forces the real B1 code path (streaming + PASS_THROUGH) rather than the native inline path used by small models.
  • Client script: Mirrors llm_hf/client.py structure (while flare.is_running() loop, receive / train / send). Uses CPU-only synthetic data — no dataset download required in CI.
  • Added to client_api.yml as "run pt-large-model-pass-through".
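A quick check that a 3-layer MLP can reach ~8 MB of float32 parameters (the layer dimensions below are hypothetical; the PR states only the total):

```python
# Parameter count for a hypothetical 3-layer MLP: weights (fan_in x
# fan_out) plus a bias per output unit, at 4 bytes per float32 value.
layers = [(1000, 1000), (1000, 1000), (1000, 10)]
n_params = sum(fan_in * fan_out + fan_out for fan_in, fan_out in layers)
n_bytes = n_params * 4  # float32
print(f"{n_params} params, {n_bytes / 1e6:.2f} MB")  # 2012010 params, 8.05 MB
assert n_bytes > 2 * 1024 * 1024  # comfortably above the 2 MB streaming threshold
```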

🤖 Generated with Claude Code


Fixes # .

Description

A few sentences describing the changes proposed in this pull request.

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • Breaking change (fix or new feature that would cause existing functionality to change).
  • New tests added to cover the changes.
  • Quick tests passed locally by running ./runtest.sh.
  • In-line docstrings updated.
  • Documentation updated.

Copilot AI review requested due to automatic review settings March 11, 2026 01:19
@greptile-apps
Contributor

greptile-apps bot commented Mar 11, 2026

Greptile Summary

This PR introduces the B1 pass-through architecture for ClientAPILauncherExecutor, eliminating tensor materialisation at the CJ (Client Job) process when large models are exchanged between the FL server and a subprocess agent. With FOBSContextKey.PASS_THROUGH enabled, ViaDownloaderDecomposer skips the download step on the CJ hop and instead stores lightweight LazyDownloadRef placeholders; when CJ re-serialises the payload for the subprocess, LazyDownloadRefDecomposer re-emits the original server datum so the subprocess downloads tensors directly from the FL server — resulting in zero tensor copy at CJ and a single network transfer.

Key observations:

  • Correctness of the core path: The LazyDownloadRef / _LazyBatchInfo design is sound. dot propagation correctly routes numpy vs. PyTorch tensors through the full hop. Unit and E2E tests are comprehensive and well-structured.
  • initialize() partial-failure restore gap: PASS_THROUGH is set on the shared engine cell before super().initialize(), which is properly restored by the guarded except block. However, if super().initialize() succeeds and then ValueError is raised for an invalid EXTERNAL_PRE_INIT_TIMEOUT (lines 149-151), that exception is outside the try/except and _restore_pass_through is never called in that code path. Recovery depends on finalize() being invoked by the framework.
  • Cherry-pick artifact in __init__: Lines 111-112 reference memory_gc_rounds and cuda_empty_cache, which are not parameters of __init__ (see previous review thread), causing NameError on every instantiation.
  • lazy_batch_key silently uses first-item ref for all items: No assertion validates that all LazyDownloadRef objects in a single serialisation call share the same ref_id. This invariant holds today (one _DecomposeCtx per type per message), but adding an explicit assertion would protect against future regressions.
  • Type-routing via dot: The design choice of passing LazyDownloadRef into ViaDownloaderDecomposer.decompose() (a decomposer whose supported_type() is not LazyDownloadRef) is unconventional but intentional — it allows the original handler's _finalize_lazy_batch to emit a datum with the correct DOT.

Confidence Score: 3/5

  • Not safe to merge until the cherry-pick NameError on lines 111-112 of client_api_launcher_executor.py is resolved; the PASS_THROUGH partial-restore gap is an additional robustness concern.
  • The core B1 architecture (via_downloader.py, fobs/__init__.py) is well-designed with thorough test coverage. However, the cherry-picked init body references two undeclared variables (memory_gc_rounds, cuda_empty_cache) that cause NameError on every instantiation of ClientAPILauncherExecutor — a hard blocker. Additionally, the PASS_THROUGH restore is not guaranteed when ValueError is raised after super().initialize() succeeds, which is a secondary robustness issue.
  • nvflare/app_common/executors/client_api_launcher_executor.py requires attention for the undeclared variable NameError (lines 111-112) and the incomplete PASS_THROUGH restore on partial initialize() failure (lines 138-156).

Important Files Changed

Filename Overview
nvflare/fuel/utils/fobs/decomposers/via_downloader.py Core of the B1 pass-through architecture: adds LazyDownloadRef, _LazyBatchInfo, PASS_THROUGH branches in process_datum/recompose, LazyDownloadRefDecomposer, and _finalize_lazy_batch; design is sound but the lazy_batch_key check-then-set pattern silently uses the first batch's ref_id for all items if two batches share a serialisation call.
nvflare/app_common/executors/client_api_launcher_executor.py Enables PASS_THROUGH on the engine cell's FOBS context in initialize(); restore is correctly placed in finalize() via finally, but the post-super() ValueError path (invalid EXTERNAL_PRE_INIT_TIMEOUT) lacks a restore call, and two undeclared variables (memory_gc_rounds, cuda_empty_cache) on lines 111-112 cause NameError on every instantiation.
nvflare/fuel/utils/fobs/__init__.py Adds FOBSContextKey.PASS_THROUGH constant with clear docstring; change is minimal and non-breaking.
tests/unit_test/fuel/utils/fobs/test_pass_through.py Comprehensive unit tests covering construction, PASS_THROUGH process_datum/recompose paths, post-callback registration, datum correctness, and no-memory-accumulation invariants; well-structured and thorough.
tests/unit_test/fuel/f3/streaming/test_pass_through_e2e.py End-to-end tests using real TCP Cell objects; validates full server→CJ(PASS_THROUGH)→subprocess round-trip, lazy ref invariants, no download transactions at CJ, and correct server ref forwarding.
tests/unit_test/app_common/executors/client_api_launcher_executor_test.py Tests PASS_THROUGH restore on finalize and exception paths; does not catch the NameError on lines 111-112 because the fixture monkeypatches init dependencies away.
tests/integration_test/data/jobs/pt_large_model_pass_through/app/custom/large_model_train.py Integration test client script mirroring llm_hf/client.py; uses CPU-only synthetic data with an ~8 MB LargeNet that exceeds the 2 MB streaming threshold, correctly exercising the B1 code path.

Sequence Diagram

sequenceDiagram
    participant Server as FL Server
    participant CJ as CJ Process<br/>(PASS_THROUGH=True)
    participant Sub as Subprocess Agent

    Note over Server: dump_to_bytes(model)<br/>creates ObjectDownloader tx<br/>ref_id = UUID, fqcn = server

    Server->>CJ: serialized bytes<br/>(datum: {fqcn, ref_id})

    Note over CJ: process_datum() detects PASS_THROUGH<br/>→ stores _LazyBatchInfo(fqcn, ref_id)<br/>→ NO download call

    Note over CJ: recompose() returns<br/>LazyDownloadRef(fqcn, ref_id, item_id)<br/>for each tensor item

    Note over CJ: LazyDownloadRefDecomposer.decompose()<br/>→ re-emits original server datum<br/>via _finalize_lazy_batch post-CB<br/>→ NO new ObjectDownloader tx

    CJ->>Sub: forwarded bytes<br/>(datum: original server {fqcn, ref_id})

    Note over Sub: process_datum() NOT in PASS_THROUGH<br/>→ calls _download_from_remote_cell

    Sub->>Server: download tensors directly<br/>(ref_id from original datum)
    Server-->>Sub: tensor data

    Note over Sub: recompose() returns real tensors<br/>from downloaded items dict

Last reviewed commit: bbb9526

Comment on lines +111 to +112
self._memory_gc_rounds = memory_gc_rounds
self._cuda_empty_cache = cuda_empty_cache
Contributor

Undefined names cause NameError at instantiation

memory_gc_rounds and cuda_empty_cache are referenced in the __init__ body but are not declared as parameters in the method signature. This will raise a NameError every time ClientAPILauncherExecutor(...) is instantiated, completely breaking the class.

Looking at the constructor signature (lines 30-54), neither memory_gc_rounds nor cuda_empty_cache appear, and they are also not assigned anywhere earlier in __init__. The LauncherExecutor parent class does not define them either.

This looks like a cherry-pick artifact: the 2.7 PR (#4210) likely added these parameters to the constructor, but the cherry-pick to main only included the body assignments, not the signature additions. Since these instance variables (self._memory_gc_rounds, self._cuda_empty_cache) are also never referenced anywhere else in this class, the simplest fix is to remove both lines:

Suggested change
self._memory_gc_rounds = memory_gc_rounds
self._cuda_empty_cache = cuda_empty_cache
self._cell_with_pass_through = None
self._prev_pass_through = None

The unit test fixture in client_api_launcher_executor_test.py that calls ClientAPILauncherExecutor(pipe_id="test_pipe") will also fail with this NameError before any test logic even runs.

Contributor

Copilot AI left a comment

Pull request overview

This PR backports a “pass-through” serialization mode for external-process execution so the Client Job (CJ) process forwards large model tensors without materializing them in CJ memory, enabling the subprocess agent to download tensor payloads directly from the originating cell (typically the FL server).

Changes:

  • Add FOBSContextKey.PASS_THROUGH and implement LazyDownloadRef / pass-through branches in ViaDownloaderDecomposer plus an auto-registered LazyDownloadRefDecomposer.
  • Enable/restore PASS_THROUGH on the engine cell in ClientAPILauncherExecutor lifecycle.
  • Add unit tests, TCP cell E2E tests, and a full integration job exercising the streaming threshold path; update memory-management docs and .gitignore.

Reviewed changes

Copilot reviewed 13 out of 14 changed files in this pull request and generated 2 comments.

Show a summary per file
File Description
nvflare/fuel/utils/fobs/__init__.py Adds FOBSContextKey.PASS_THROUGH flag to control pass-through behavior.
nvflare/fuel/utils/fobs/decomposers/via_downloader.py Implements LazyDownloadRef, pass-through processing, and a decomposer to forward original download refs.
nvflare/app_common/executors/client_api_launcher_executor.py Turns on PASS_THROUGH for the engine cell during executor init and restores it on finalize/error.
tests/unit_test/fuel/utils/fobs/test_pass_through.py Unit tests for pass-through logic, lazy refs, and “no download tx created” invariants.
tests/unit_test/fuel/f3/streaming/test_pass_through_e2e.py E2E tests using real TCP Cells to validate the server→CJ→subprocess pass-through hop.
tests/unit_test/app_common/executors/client_api_launcher_executor_test.py Tests that PASS_THROUGH is restored correctly on finalize and init failure.
tests/integration_test/data/test_configs/standalone_job/client_api.yml Adds a standalone integration test entry to run the new pass-through job.
tests/integration_test/data/jobs/pt_large_model_pass_through/meta.conf Defines the integration test job metadata.
tests/integration_test/data/jobs/pt_large_model_pass_through/app/custom/large_model_train.py Client-side training script that exercises external-process pass-through.
tests/integration_test/data/jobs/pt_large_model_pass_through/app/custom/large_model_net.py Defines an ~8MB model to force the streaming/download path.
tests/integration_test/data/jobs/pt_large_model_pass_through/app/config/config_fed_server.conf Server config for the integration job.
tests/integration_test/data/jobs/pt_large_model_pass_through/app/config/config_fed_client.conf Client config using PTClientAPILauncherExecutor with launch_once=True.
docs/programming_guide/memory_management.rst Updates recommended memory settings table and adds jemalloc preload guidance.
.gitignore Ignores memory profiler .dat outputs under tests/memory_profile/.


Comment on lines +111 to +112
self._memory_gc_rounds = memory_gc_rounds
self._cuda_empty_cache = cuda_empty_cache
Copilot AI Mar 11, 2026

__init__ assigns self._memory_gc_rounds = memory_gc_rounds and self._cuda_empty_cache = cuda_empty_cache, but neither memory_gc_rounds nor cuda_empty_cache is defined in this scope (they are not parameters and not module globals). This will raise NameError when instantiating ClientAPILauncherExecutor. Add these as explicit __init__ parameters with defaults (and document/use them), or remove the assignments if they’re not intended for this branch.

Suggested change
self._memory_gc_rounds = memory_gc_rounds
self._cuda_empty_cache = cuda_empty_cache

}


def _simulate_cj_pass_through(server_bytes: bytes) -> bytes:
Copilot AI Mar 11, 2026
_simulate_cj_pass_through is annotated/documented as returning bytes, but it actually returns a 2-tuple (cj_result, forwarded_bytes) and all call sites unpack it as such. Update the return type annotation (and docstring) to match the actual return value.

Comment on lines 138 to 156
@@ -126,6 +155,23 @@ def initialize(self, fl_ctx: FLContext) -> None:
)
self._external_pre_init_timeout = timeout_value
Contributor

PASS_THROUGH not restored on partial initialize() failure

If super().initialize(fl_ctx) succeeds but the subsequent ValueError is raised at line 151 (invalid EXTERNAL_PRE_INIT_TIMEOUT), _restore_pass_through is never called in that code path. The try/except on lines 138-142 only guards super().initialize(), leaving the exception raised at line 151 without a restore call.

At that point, self._cell_with_pass_through and self._prev_pass_through are already set, and PASS_THROUGH=True is live on the cell. Recovery depends entirely on the NVFlare framework guaranteeing a finalize() call even after a partial initialize() failure — which is not enforced here.

The minimal fix is to extend the guarded block to cover the full post-setup phase:

try:
    super().initialize(fl_ctx)
    # Check for top-level config override for external_pre_init_timeout
    config_timeout = get_client_config_value(fl_ctx, EXTERNAL_PRE_INIT_TIMEOUT)
    if config_timeout is not None:
        timeout_value = float(config_timeout)
        if timeout_value <= 0:
            self.log_error(fl_ctx, f"Invalid EXTERNAL_PRE_INIT_TIMEOUT: {timeout_value}s (must be positive)")
            raise ValueError(f"EXTERNAL_PRE_INIT_TIMEOUT must be positive, got {timeout_value}")
        self.log_info(
            fl_ctx,
            f"Overriding external_pre_init_timeout from config: {self._external_pre_init_timeout}s -> {timeout_value}s",
        )
        self._external_pre_init_timeout = timeout_value
except Exception:
    self._restore_pass_through(fl_ctx)
    raise

Comment on lines +247 to +260
if isinstance(target, LazyDownloadRef):
fobs_ctx = manager.fobs_ctx
lazy_batch_key = f"{self.prefix}{_LAZY_BATCH_CTX_SUFFIX}"
if lazy_batch_key not in fobs_ctx:
# First LazyDownloadRef of this batch: register a post-callback
# that will add the single shared datum (fqcn + ref_id) after all
# items have been serialised.
fobs_ctx[lazy_batch_key] = {"fqcn": target.fqcn, "ref_id": target.ref_id}
manager.register_post_cb(self._finalize_lazy_batch)

self.logger.debug(
f"ViaDownloader: re-emitting LazyDownloadRef {target.item_id=} " f"{target.fqcn=} {target.ref_id=}"
)
return {EncKey.TYPE: EncType.REF, EncKey.DATA: target.item_id}
Contributor

Silent data corruption if two batches with different ref_ids are serialised in the same message

The check-then-set on lazy_batch_key (lines 250-255) records only the first LazyDownloadRef's fqcn and ref_id:

if lazy_batch_key not in fobs_ctx:
    fobs_ctx[lazy_batch_key] = {"fqcn": target.fqcn, "ref_id": target.ref_id}
    manager.register_post_cb(self._finalize_lazy_batch)

If a later LazyDownloadRef in the same serialisation call carries a different ref_id (e.g., two independent server batches merged into one forwarded payload), its item_id is emitted with a REF encoding (line 260) but the single datum added by _finalize_lazy_batch still points to the first batch's ref_id. The subprocess would then attempt to resolve every item against that one ref, silently returning wrong tensors or a download error for items from the second batch.

The current architecture guarantees a single batch per type per message (one _DecomposeCtx per decomposer type), so this is not a live bug today. However, adding an assertion protects against future regressions and documents the invariant explicitly:

if lazy_batch_key not in fobs_ctx:
    fobs_ctx[lazy_batch_key] = {"fqcn": target.fqcn, "ref_id": target.ref_id}
    manager.register_post_cb(self._finalize_lazy_batch)
else:
    # All LazyDownloadRefs in one message must belong to the same server batch.
    existing = fobs_ctx[lazy_batch_key]
    assert existing["ref_id"] == target.ref_id, (
        f"LazyDownloadRef ref_id mismatch: expected {existing['ref_id']!r}, "
        f"got {target.ref_id!r}. Multiple server batches in one message are not supported."
    )

@YuanTingHsieh
Collaborator Author

This PR depends on cherry-pick of 4211
