
Add federated learning with DRS from scratch implementation #1

Draft
TheVidz wants to merge 3 commits into main from fl-from-scratch

@TheVidz TheVidz commented Jun 30, 2026

Collaborator

Summary

This PR introduces a complete baseline federated learning implementation built from scratch without relying on the Flower framework.

The implementation provides a simple stateless FedAvg pipeline for ancestry classification and serves as a reference implementation for future FL experimentation and integration.

What's Included

  • Scratch implementation of the FedAvg aggregation algorithm
  • Stateless client training pipeline
  • Global server orchestration
  • Ancestry classification model
  • Automated multi-round simulation runner
  • Performance visualization utilities
  • Client/server metric logging
  • Checkpoint management across federated rounds
  • GA4GH DRS-based dataset resolution for clients
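The FedAvg step listed above weights each client's parameters by its share of the total training samples. A minimal sketch of that rule follows, using plain Python lists in place of tensors. The `fedavg` name and the dictionary layout are illustrative assumptions, not the code in this PR.

```python
from typing import Dict, List, Sequence


def fedavg(
    client_states: Sequence[Dict[str, List[float]]],
    sample_sizes: Sequence[int],
) -> Dict[str, List[float]]:
    """Sample-weighted average of client parameters (hypothetical helper)."""
    if not client_states:
        raise ValueError("at least one client state is required")
    if len(client_states) != len(sample_sizes):
        raise ValueError("one sample size is required per client state")
    total = sum(sample_sizes)
    if total <= 0:
        raise ValueError("total sample count must be positive")

    # Start from zeros shaped like the first client's parameters.
    averaged = {name: [0.0] * len(vals) for name, vals in client_states[0].items()}
    for state, n in zip(client_states, sample_sizes):
        weight = n / total  # this client's share of all samples
        for name in averaged:
            for i, value in enumerate(state[name]):
                averaged[name][i] += weight * value
    return averaged


if __name__ == "__main__":
    states = [{"w": [1.0, 2.0]}, {"w": [3.0, 4.0]}]
    print(fedavg(states, [1, 3]))  # prints {'w': [2.5, 3.5]}
```

The explicit length and zero-total checks are a design choice of this sketch: a silent `zip` truncation or a division by zero would otherwise produce a wrong or missing global model without an obvious error.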

Components

  • client.py

    • Local training
    • Validation
    • Metrics generation
    • DRS dataset loading
    • Model checkpoint generation
  • server.py

    • FedAvg aggregation
    • Global checkpoint generation
    • Weighted metric aggregation
  • model.py

    • Shared neural network architecture
    • Parameter serialization utilities
  • run_simulation.py

    • End-to-end federated training orchestration
  • plot_results.py

    • Training convergence visualization

Validation

Verified by running the complete federated simulation across multiple rounds, including:

  • Client training
  • Global aggregation
  • Checkpoint generation
  • Metrics logging
  • Convergence plot generation

Notes

This PR establishes the baseline scratch implementation. Future work will build on this foundation, in particular towards using the GA4GH Task Execution Service (TES) for FL.

@TheVidz TheVidz marked this pull request as draft June 30, 2026 14:14
@TheVidz TheVidz requested a review from Copilot June 30, 2026 14:14

Copilot AI left a comment

Pull request overview

This PR adds a scratch (non-Flower) baseline federated learning simulation for ancestry classification, including a simple FedAvg server, stateless client training with GA4GH DRS-based dataset resolution, a multi-round runner, and plotting utilities.

Changes:

  • Introduces a FedAvg server that aggregates client checkpoints and logs weighted global metrics per round.
  • Adds a stateless client that resolves datasets via DRS (with fallback), trains locally, emits metrics, and writes per-round checkpoints with embedded metadata.
  • Adds a local simulation orchestrator and a plotting script for convergence visualization, plus updates .gitignore.
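The weighted global metrics mentioned above can be computed in the same spirit as the parameter average: each client's metric counts in proportion to its sample count. A small sketch follows, assuming each client reports a dictionary with a `num_samples` field plus scalar metrics. Those field names and the `aggregate_metrics` helper are hypothetical and do not appear in the PR.

```python
from typing import Dict, List


def aggregate_metrics(client_metrics: List[Dict[str, float]]) -> Dict[str, float]:
    """Sample-weighted mean of every scalar metric reported by the clients."""
    if not client_metrics:
        raise ValueError("no client metrics to aggregate")
    total = sum(m["num_samples"] for m in client_metrics)
    if total <= 0:
        raise ValueError("total sample count must be positive")

    # Every key except the sample count is treated as a metric to average.
    metric_names = [k for k in client_metrics[0] if k != "num_samples"]
    result = {
        name: sum(m[name] * m["num_samples"] for m in client_metrics) / total
        for name in metric_names
    }
    result["num_samples"] = total
    return result


if __name__ == "__main__":
    round_metrics = [
        {"num_samples": 100, "accuracy": 0.5, "loss": 1.0},
        {"num_samples": 300, "accuracy": 0.9, "loss": 0.2},
    ]
    print(aggregate_metrics(round_metrics))
```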

Reviewed changes

Copilot reviewed 8 out of 10 changed files in this pull request and generated 8 comments.

File                          Description
scratch-fl/server.py          Implements server-side FedAvg aggregation and server metrics logging/checkpointing.
scratch-fl/client.py          Implements client-side DRS dataset resolution, local training/evaluation, and checkpoint/metrics output.
scratch-fl/model.py           Defines the shared PyTorch model architecture and parameter helpers.
scratch-fl/run_simulation.py  Orchestrates multi-round local simulation by launching server/client steps.
scratch-fl/plot_results.py    Generates convergence plots from server metrics CSV.
scratch-fl/data/site_2.tsv    Adds a local TSV dataset shard used as sample input data.
.gitignore                    Updates ignored FL-related artifact directories.


Comment thread scratch-fl/server.py
Comment on lines +22 to +35
    global_state = global_model.state_dict()
    aggregated_state = {k: torch.zeros_like(v, dtype=torch.float32) for k, v in global_state.items()}

    logger.info(f"Aggregating {len(client_checkpoints)} client state dicts across {total_samples} cumulative samples.")

    for path, sample_size in zip(client_checkpoints, client_sample_sizes):
        client_state = torch.load(path, map_location="cpu")
        weight = sample_size / total_samples
        for k in aggregated_state.keys():
            # Avoid aggregating embedded metadata wrappers
            if k == "metadata":
                continue
            aggregated_state[k] += client_state[k].float() * weight
Comment thread scratch-fl/server.py
Comment on lines +27 to +29
    for path, sample_size in zip(client_checkpoints, client_sample_sizes):
        client_state = torch.load(path, map_location="cpu")
        weight = sample_size / total_samples

Comment thread scratch-fl/server.py
Comment on lines +79 to +82
    # Load file safely to parse data total allocations
    payload = torch.load(client_out_path, map_location="cpu")
    meta = payload.get("metadata", {})

Comment thread scratch-fl/client.py
Comment on lines +55 to +66
    except Exception as e:
        print(f"[DRS Error] Starter kit connection failed at {meta_url}. Falling back to clean mock local files.")
        fallback_path = f"./data_site_{args.client_id}_unified.tsv"
        if not os.path.exists(fallback_path):
            # Synthesize mockup data automatically if both DRS server and data rows are absent
            mock_data = []
            for i in range(250):
                row = {f"PC{idx}_AVG": torch.randn(1).item() for idx in range(1, 11)}
                row["super_pop"] = SUPERPOPS[i % 5]
                mock_data.append(row)
            pd.DataFrame(mock_data).to_csv(fallback_path, sep="\t", index=False)
        return fallback_path
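The fallback above synthesizes 250 rows with ten `PC*_AVG` columns and a round-robin `super_pop` label. The same idea is sketched below with only the standard library and a seeded generator, so repeated runs give identical mock data. The contents of `SUPERPOPS` are not visible in the quoted diff. The five labels used here are an assumption based on the 1000 Genomes super-population codes, and the helper names are hypothetical.

```python
import csv
import io
import random
from typing import Dict, List, Union

# Assumption: the diff only shows SUPERPOPS being indexed with i % 5.
SUPERPOPS = ["AFR", "AMR", "EAS", "EUR", "SAS"]
PC_COLUMNS = [f"PC{i}_AVG" for i in range(1, 11)]

Row = Dict[str, Union[float, str]]


def make_mock_rows(n_rows: int = 250, seed: int = 0) -> List[Row]:
    """Build reproducible mock rows: Gaussian PCs plus a cycling label."""
    rng = random.Random(seed)  # local generator, leaves global state untouched
    rows: List[Row] = []
    for i in range(n_rows):
        row: Row = {col: rng.gauss(0.0, 1.0) for col in PC_COLUMNS}
        row["super_pop"] = SUPERPOPS[i % len(SUPERPOPS)]
        rows.append(row)
    return rows


def rows_to_tsv(rows: List[Row]) -> str:
    """Serialize rows as tab-separated text with a header line."""
    buf = io.StringIO()
    writer = csv.DictWriter(
        buf, fieldnames=PC_COLUMNS + ["super_pop"], delimiter="\t", lineterminator="\n"
    )
    writer.writeheader()
    writer.writerows(rows)
    return buf.getvalue()


if __name__ == "__main__":
    print(rows_to_tsv(make_mock_rows(n_rows=3, seed=42)))
```

Seeding matters in a simulation like this one because unseeded mock data would make convergence plots differ between runs for reasons unrelated to the training code.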
Comment thread scratch-fl/client.py
Comment on lines +44 to +54
    meta_url = f"http://localhost:4500/ga4gh/drs/v1/objects/{object_id}"
    try:
        meta_resp = requests.get(meta_url, timeout=5).json()
        access_id = meta_resp["access_methods"][0]["access_id"]

        access_url = f"http://localhost:4500/ga4gh/drs/v1/objects/{object_id}/access/{access_id}"
        stream_url = requests.get(access_url, timeout=5).json()["url"]

        if stream_url.startswith("file://"):
            stream_url = stream_url.replace("file://", "", 1)
        return stream_url
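The resolution above assumes the first access method always carries an `access_id`. In DRS, an access method may instead provide a direct `access_url`, so a more defensive selection could look like the sketch below. It works on an already-decoded DRS object and makes no network calls. The function names are hypothetical and this is not the PR's code.

```python
from typing import Optional, Tuple
from urllib.parse import unquote, urlparse


def choose_access(drs_object: dict) -> Tuple[Optional[str], Optional[str]]:
    """Return (direct_url, access_id) from the first usable access method.

    Exactly one element of the tuple is set: a direct URL needs no further
    request, while an access_id still has to be exchanged for a URL.
    """
    for method in drs_object.get("access_methods") or []:
        access_url = method.get("access_url") or {}
        if access_url.get("url"):
            return access_url["url"], None
        if method.get("access_id"):
            return None, method["access_id"]
    raise ValueError("DRS object has no usable access method")


def to_local_path(url: str) -> str:
    """Turn a file:// URL into a filesystem path; leave other URLs unchanged."""
    parsed = urlparse(url)
    if parsed.scheme == "file":
        return unquote(parsed.path)  # also decodes %20 and similar escapes
    return url


if __name__ == "__main__":
    obj = {"access_methods": [{"type": "file", "access_id": "abc"}]}
    print(choose_access(obj))
    print(to_local_path("file:///data/site%201.tsv"))
```

Parsing the URL rather than stripping the `file://` prefix also handles percent-encoded characters in the path, which a plain string replace would leave encoded.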
Comment thread scratch-fl/run_simulation.py
Comment on lines +1 to +18
    # run_simulation.py
    import os
    import subprocess
    import sys

    ROUNDS = 15
    NUM_CLIENTS = 4
    ARTIFACTS_DIR = "./checkpoints"

    def run_cmd(cmd):
        """Executes a standard batch container shell script command synchronously."""
        process = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True)
        for line in process.stdout:
            print(line, end="")
        process.wait()
        if process.returncode != 0:
            print(f"\n[Execution Failure] Command failed with exit code: {process.returncode}")
            sys.exit(process.returncode)
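One possible variant of `run_cmd` takes an argument list instead of a shell string, which avoids shell parsing of interpolated values such as client IDs or paths. The sketch below keeps the line-by-line output streaming and returns the exit code so the caller decides how to react. It illustrates the pattern and is not a change proposed in this PR.

```python
import subprocess
import sys
from typing import List


def run_cmd(args: List[str]) -> int:
    """Run a command without a shell, echoing combined stdout/stderr as it arrives."""
    with subprocess.Popen(
        args,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,  # merge stderr into the same stream
        text=True,
    ) as process:
        assert process.stdout is not None  # guaranteed by stdout=PIPE
        for line in process.stdout:
            print(line, end="")
    # Leaving the with-block waits for the child, so returncode is set here.
    return process.returncode


if __name__ == "__main__":
    code = run_cmd([sys.executable, "-c", "print('round 1 done')"])
    if code != 0:
        print(f"[Execution Failure] Command failed with exit code: {code}")
        sys.exit(code)
```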
Comment on lines +20 to +25
    def main():
        print("=========================================================================")
        print(" Starting Automated Trusted Federated AI Production-Grade Sandbox Flow")
        print("=========================================================================")

        os.makedirs(ARTIFACTS_DIR, exist_ok=True)

Comment on lines +1 to +2
    IID ALLELE_CT NAMED_ALLELE_DOSAGE_SUM PC1_AVG PC2_AVG PC3_AVG PC4_AVG PC5_AVG PC6_AVG PC7_AVG PC8_AVG PC9_AVG PC10_AVG super_pop
    HG00117 1734952 1734952 0.0944944 -0.13021 -0.0262812 -0.0377915 -0.000188891 0.00293258 0.00503762 0.00867724 0.00301176 0.003309 EUR
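Rows of this shape can be split into model inputs and labels by reading the ten `PC*_AVG` columns as floats and `super_pop` as the class label. The sketch below uses the standard `csv` module and assumes the file is tab-separated, as the `.tsv` extension suggests. The `load_features` helper is hypothetical and may not match how `client.py` loads data.

```python
import csv
import io
from typing import List, Tuple

PC_COLUMNS = [f"PC{i}_AVG" for i in range(1, 11)]


def load_features(tsv_text: str) -> Tuple[List[List[float]], List[str]]:
    """Parse TSV text into (feature rows, labels); extra columns are ignored."""
    reader = csv.DictReader(io.StringIO(tsv_text), delimiter="\t")
    features: List[List[float]] = []
    labels: List[str] = []
    for row in reader:
        features.append([float(row[col]) for col in PC_COLUMNS])
        labels.append(row["super_pop"])
    return features, labels


if __name__ == "__main__":
    header = ["IID", "ALLELE_CT", "NAMED_ALLELE_DOSAGE_SUM", *PC_COLUMNS, "super_pop"]
    sample = [
        "HG00117", "1734952", "1734952", "0.0944944", "-0.13021", "-0.0262812",
        "-0.0377915", "-0.000188891", "0.00293258", "0.00503762", "0.00867724",
        "0.00301176", "0.003309", "EUR",
    ]
    text = "\t".join(header) + "\n" + "\t".join(sample) + "\n"
    feats, labels = load_features(text)
    print(len(feats[0]), labels)  # prints 10 ['EUR']
```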