Add federated learning with DRS from scratch implementation#1
Draft
TheVidz wants to merge 3 commits into
Draft
Conversation
There was a problem hiding this comment.
Pull request overview
This PR adds a scratch (non-Flower) baseline federated learning simulation for ancestry classification, including a simple FedAvg server, stateless client training with GA4GH DRS-based dataset resolution, a multi-round runner, and plotting utilities.
Changes:
- Introduces a FedAvg server that aggregates client checkpoints and logs weighted global metrics per round.
- Adds a stateless client that resolves datasets via DRS (with fallback), trains locally, emits metrics, and writes per-round checkpoints with embedded metadata.
- Adds a local simulation orchestrator and a plotting script for convergence visualization, plus updates
.gitignore.
Reviewed changes
Copilot reviewed 8 out of 10 changed files in this pull request and generated 8 comments.
Show a summary per file
| File | Description |
|---|---|
| scratch-fl/server.py | Implements server-side FedAvg aggregation and server metrics logging/checkpointing. |
| scratch-fl/client.py | Implements client-side DRS dataset resolution, local training/evaluation, and checkpoint/metrics output. |
| scratch-fl/model.py | Defines the shared PyTorch model architecture and parameter helpers. |
| scratch-fl/run_simulation.py | Orchestrates multi-round local simulation by launching server/client steps. |
| scratch-fl/plot_results.py | Generates convergence plots from server metrics CSV. |
| scratch-fl/data/site_2.tsv | Adds a local TSV dataset shard used as sample input data. |
| .gitignore | Updates ignored FL-related artifact directories. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Comment on lines
+22
to
+35
| global_state = global_model.state_dict() | ||
| aggregated_state = {k: torch.zeros_like(v, dtype=torch.float32) for k, v in global_state.items()} | ||
|
|
||
| logger.info(f"Aggregating {len(client_checkpoints)} client state dicts across {total_samples} cumulative samples.") | ||
|
|
||
| for path, sample_size in zip(client_checkpoints, client_sample_sizes): | ||
| client_state = torch.load(path, map_location="cpu") | ||
| weight = sample_size / total_samples | ||
| for k in aggregated_state.keys(): | ||
| # Avoid aggregating embedded metadata wrappers | ||
| if k == "metadata": | ||
| continue | ||
| aggregated_state[k] += client_state[k].float() * weight | ||
|
|
Comment on lines
+27
to
+29
| for path, sample_size in zip(client_checkpoints, client_sample_sizes): | ||
| client_state = torch.load(path, map_location="cpu") | ||
| weight = sample_size / total_samples |
Comment on lines
+79
to
+82
| # Load file safely to parse data total allocations | ||
| payload = torch.load(client_out_path, map_location="cpu") | ||
| meta = payload.get("metadata", {}) | ||
|
|
Comment on lines
+55
to
+66
| except Exception as e: | ||
| print(f"[DRS Error] Starter kit connection failed at {meta_url}. Falling back to clean mock local files.") | ||
| fallback_path = f"./data_site_{args.client_id}_unified.tsv" | ||
| if not os.path.exists(fallback_path): | ||
| # Synthesize mockup data automatically if both DRS server and data rows are absent | ||
| mock_data = [] | ||
| for i in range(250): | ||
| row = {f"PC{idx}_AVG": torch.randn(1).item() for idx in range(1, 11)} | ||
| row["super_pop"] = SUPERPOPS[i % 5] | ||
| mock_data.append(row) | ||
| pd.DataFrame(mock_data).to_csv(fallback_path, sep="\t", index=False) | ||
| return fallback_path |
Comment on lines
+44
to
+54
| meta_url = f"http://localhost:4500/ga4gh/drs/v1/objects/{object_id}" | ||
| try: | ||
| meta_resp = requests.get(meta_url, timeout=5).json() | ||
| access_id = meta_resp["access_methods"][0]["access_id"] | ||
|
|
||
| access_url = f"http://localhost:4500/ga4gh/drs/v1/objects/{object_id}/access/{access_id}" | ||
| stream_url = requests.get(access_url, timeout=5).json()["url"] | ||
|
|
||
| if stream_url.startswith("file://"): | ||
| stream_url = stream_url.replace("file://", "", 1) | ||
| return stream_url |
Comment on lines
+1
to
+18
| # run_simulation.py | ||
| import os | ||
| import subprocess | ||
| import sys | ||
|
|
||
| ROUNDS = 15 | ||
| NUM_CLIENTS = 4 | ||
| ARTIFACTS_DIR = "./checkpoints" | ||
|
|
||
| def run_cmd(cmd): | ||
| """Executes a standard batch container shell script command synchronously.""" | ||
| process = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True) | ||
| for line in process.stdout: | ||
| print(line, end="") | ||
| process.wait() | ||
| if process.returncode != 0: | ||
| print(f"\n[Execution Failure] Command failed with exit code: {process.returncode}") | ||
| sys.exit(process.returncode) |
Comment on lines
+20
to
+25
| def main(): | ||
| print("=========================================================================") | ||
| print(" Starting Automated Trusted Federated AI Production-Grade Sandbox Flow") | ||
| print("=========================================================================") | ||
|
|
||
| os.makedirs(ARTIFACTS_DIR, exist_ok=True) |
Comment on lines
+1
to
+2
| IID ALLELE_CT NAMED_ALLELE_DOSAGE_SUM PC1_AVG PC2_AVG PC3_AVG PC4_AVG PC5_AVG PC6_AVG PC7_AVG PC8_AVG PC9_AVG PC10_AVG super_pop | ||
| HG00117 1734952 1734952 0.0944944 -0.13021 -0.0262812 -0.0377915 -0.000188891 0.00293258 0.00503762 0.00867724 0.00301176 0.003309 EUR |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
This PR introduces a complete baseline federated learning implementation built from scratch without relying on the Flower framework.
The implementation provides a simple stateless FedAvg pipeline for ancestry classification and serves as a reference implementation for future FL experimentation and integration.
What's Included
Components
client.pyserver.pymodel.pyrun_simulation.pyplot_results.pyValidation
Verified by running the complete federated simulation across multiple rounds, including:
Notes
This PR establishes the baseline scratch implementation. Future work will build additional features (specifically work towards TES utilisation for FL) on top of this foundation.