ComputeShare is a lightweight federated machine learning training system built with PyTorch. It lets users distribute training tasks across multiple machines over the internet, with each machine using its native hardware acceleration: CUDA on NVIDIA GPUs (Windows/Linux), MPS (Metal Performance Shaders) on Apple Silicon Macs, or the CPU as a universal fallback. Splitting the workload this way can substantially reduce training time.
Teaching Computers to Train Together: Building a Distributed Training Platform Across Multiple GPUs
The system consists of two primary components that operate in a Bulk Synchronous Parallel (BSP) arrangement:
- Parameter Server (`server.py`): The central node that holds the global PyTorch model. It waits asynchronously for workers to submit their computed gradients, decompresses each payload, averages the gradients, and applies an SGD step to update the global weights. The server implements the Linear Scaling Rule, multiplying the learning rate by the number of workers (`WORLD_SIZE`) to compensate for the larger effective global batch (`WORLD_SIZE * BATCH_SIZE` samples per update) in multi-node clusters.
- Workers (`worker.py`): Distributed clients that pull the latest model from the server, process their own unique (non-overlapping) shard of the training dataset on local hardware (CUDA / MPS / CPU), and send the computed gradients back to the server.
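A minimal sketch of one aggregation round on the server, assuming each worker's gradients arrive as a list holding one tensor per model parameter (in `model.parameters()` order). The names `make_server_optimizer` and `apply_averaged_gradients` are hypothetical and not taken from `server.py`.

```python
from __future__ import annotations

import torch
import torch.nn as nn


def make_server_optimizer(model: nn.Module, base_lr: float, world_size: int) -> torch.optim.SGD:
    """Build the server's SGD optimizer with the Linear Scaling Rule applied."""
    # Averaging gradients from world_size workers behaves like one step on a
    # batch world_size times larger, so the learning rate grows by that factor.
    return torch.optim.SGD(model.parameters(), lr=base_lr * world_size)


def apply_averaged_gradients(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    worker_grads: list[list[torch.Tensor]],
) -> None:
    """Average one gradient list per worker, then take a single SGD step.

    Assumes (hypothetically) that each inner list holds one tensor per
    parameter, in the same order as model.parameters().
    """
    for i, param in enumerate(model.parameters()):
        stacked = torch.stack([grads[i] for grads in worker_grads])
        param.grad = stacked.mean(dim=0)  # element-wise average across workers
    optimizer.step()
```

Because `param.grad` is overwritten rather than accumulated each round, this sketch needs no `zero_grad()` call between rounds.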
- Stale Gradient Rejection: To keep a slow worker from polluting the global weights, workers attach an `X-Worker-Version` HTTP header to their gradients. If the server has already advanced to a newer global version, it immediately rejects the slow worker's payload with HTTP 409 Conflict, and the worker must re-pull the new weights and recompute.
- Connection Continuity: Workers use extended `requests` timeouts (15s/30s) so they survive latency spikes on LocalTunnel connections instead of exiting on a timeout.
- Compression: To reduce network overhead (and, for example, avoid being blocked by Ngrok), gradients are not serialized to JSON. Workers serialize their PyTorch tensors into an `io.BytesIO()` buffer, compress them locally with `gzip`, and transmit the binary payload as `application/octet-stream`.
- Authentication: All external HTTP requests must include a mandatory 4-digit PIN header.
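The submission path described above can be sketched as follows. This is an illustration under assumptions: the helper names are hypothetical, the version is assumed to be an integer the server increments after each averaging round, and the PIN header is left out because its name is not documented here.

```python
from __future__ import annotations

import gzip
import io

import torch

HTTP_OK = 200
HTTP_CONFLICT = 409  # returned for stale gradients


def encode_gradients(grads: list[torch.Tensor]) -> bytes:
    """Worker side: serialize tensors into an in-memory buffer, then gzip them."""
    buffer = io.BytesIO()
    torch.save(grads, buffer)
    return gzip.compress(buffer.getvalue())


def decode_gradients(payload: bytes) -> list[torch.Tensor]:
    """Server side: gunzip the body and load the tensors onto the CPU."""
    raw = gzip.decompress(payload)
    return torch.load(io.BytesIO(raw), map_location="cpu")


def submission_headers(model_version: int) -> dict[str, str]:
    """Headers a worker could send with its gradients (PIN header omitted)."""
    return {
        "Content-Type": "application/octet-stream",
        "X-Worker-Version": str(model_version),
    }


def check_version(worker_version: int, global_version: int) -> int:
    """Accept gradients computed on the current weights; reject stale ones."""
    return HTTP_OK if worker_version == global_version else HTTP_CONFLICT
```

A worker could then post the payload with `requests.post(url, data=encode_gradients(grads), headers=submission_headers(version), timeout=(15, 30))` and treat a 409 response as a signal to re-pull the weights. Reading the documented 15s/30s timeouts as a `(connect, read)` tuple is an assumption.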
Workers automatically detect and use the best available hardware at runtime, so no configuration is needed.
| Platform | Hardware | Backend |
|---|---|---|
| Windows / Linux | NVIDIA GPU | CUDA |
| macOS (Apple Silicon) | M1 / M2 / M3 GPU | MPS |
| Any machine | CPU only | CPU (fallback) |
Priority order: CUDA → MPS → CPU
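The CUDA → MPS → CPU priority can be expressed in a few lines of PyTorch. This is a sketch of the selection logic, not necessarily how `worker.py` implements it; `pick_device` is a hypothetical name.

```python
import torch


def pick_device() -> torch.device:
    """Return the best available backend in the order CUDA -> MPS -> CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    # torch.backends.mps exists in PyTorch 1.12+; is_available() is False off Apple Silicon.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")
```

The model and each batch would then be moved with `.to(pick_device())` before the forward pass.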
Before running the system, create a Python virtual environment and install the required dependencies:

```bash
python3 -m venv .venv
source .venv/bin/activate   # on Windows: .venv\Scripts\activate
pip install -r requirements.txt
```

Designate one machine to act as the parameter server. The server can be started interactively or with inline arguments that bypass the prompts.
Interactive initialization:

```bash
python server.py --dataset MNIST --pin 1234
```

Inline initialization:

```bash
python server.py --dataset MNIST --pinSizEpo <PIN> <WORLD_SIZE> <TOTAL_GLOBAL_BATCHES>
# Example: python server.py --dataset FashionMNIST --pinSizEpo 1234 2 50
```

In interactive mode, you will be prompted to define:

- `WORLD_SIZE`: Total number of distributed workers.
- `TOTAL_GLOBAL_BATCHES`: The number of gradient-averaging rounds before the model is saved.

The `--dataset` flag selects the dataset to use (e.g., `MNIST`, `FashionMNIST`, `CIFAR10`). The default is `MNIST`.
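Here is a hypothetical sketch of how the server's two launch modes could be parsed with `argparse`. It uses `nargs=3` so that `--pinSizEpo` can carry all three values in a single flag. `ServerConfig` and `parse_server_args` are illustrative names. The 4-digit check only mirrors the PIN requirement stated above, and the real `server.py` may handle both modes differently.

```python
from __future__ import annotations

import argparse
from dataclasses import dataclass


@dataclass
class ServerConfig:
    dataset: str
    pin: str | None
    world_size: int | None  # None means "ask interactively"
    total_global_batches: int | None


def parse_server_args(argv: list[str] | None = None) -> ServerConfig:
    parser = argparse.ArgumentParser(description="Parameter server (illustrative sketch)")
    parser.add_argument("--dataset", default="MNIST")
    parser.add_argument("--pin")
    parser.add_argument(
        "--pinSizEpo",
        nargs=3,
        metavar=("PIN", "WORLD_SIZE", "TOTAL_GLOBAL_BATCHES"),
    )
    args = parser.parse_args(argv)

    if args.pinSizEpo is not None:
        # Inline mode: all three values arrive in one flag, so no prompts are needed.
        pin, world_size, total = args.pinSizEpo
        config = ServerConfig(args.dataset, pin, int(world_size), int(total))
    else:
        # Interactive mode: WORLD_SIZE and TOTAL_GLOBAL_BATCHES would be prompted for.
        config = ServerConfig(args.dataset, args.pin, None, None)

    # Keep the PIN as a string so leading zeros such as "0042" survive.
    if config.pin is not None and not (len(config.pin) == 4 and config.pin.isdigit()):
        parser.error("PIN must be exactly 4 digits")
    return config
```

The worker's `--pinSizRanBatEpo` flag could follow the same pattern with `nargs=5`.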
Note: For external connections over the internet, we highly recommend exposing port 8000 through a free Cloudflare Tunnel for speed and stability:

```bash
npx cloudflared tunnel --url http://localhost:8000
```

(Alternatively, if all machines are on the same Wi-Fi or local network, simply use the server's local IP address, `http://192.168.x.x:8000`, directly without a tunnel for the lowest-latency setup.)
On each machine participating in the training, run the worker script.

Interactive initialization:

```bash
python worker.py --dataset MNIST --pin 1234
```

Inline initialization:

```bash
python worker.py --dataset MNIST --pinSizRanBatEpo <PIN> <WORLD_SIZE> <RANK> <BATCH_SIZE> <TOTAL_GLOBAL_BATCHES>
# Example: python worker.py --dataset FashionMNIST --pinSizRanBatEpo 1234 2 0 32 50
```

Regardless of initialization method, workers require:

- `WORLD_SIZE`: Must match the server configuration.
- `RANK`: The worker's unique ID (0 to `WORLD_SIZE - 1`).
- `BATCH_SIZE`: Number of images to process per forward pass.
- `TOTAL_GLOBAL_BATCHES`: Must match the server configuration.
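One simple way to give each worker a unique shard is to stride through the dataset by `RANK`. The sketch below assumes that scheme. The actual split in `worker.py` is not documented here, and `shard_for_rank` and `make_worker_loader` are hypothetical names.

```python
import torch
from torch.utils.data import DataLoader, Dataset, Subset, TensorDataset


def shard_for_rank(dataset: Dataset, rank: int, world_size: int) -> Subset:
    """Give worker `rank` every world_size-th sample, starting at index `rank`."""
    if not 0 <= rank < world_size:
        raise ValueError(f"RANK must be in [0, {world_size - 1}], got {rank}")
    indices = list(range(rank, len(dataset), world_size))
    return Subset(dataset, indices)


def make_worker_loader(dataset: Dataset, rank: int, world_size: int, batch_size: int) -> DataLoader:
    """Shuffle only within this worker's shard, so shards never overlap."""
    return DataLoader(shard_for_rank(dataset, rank, world_size), batch_size=batch_size, shuffle=True)


if __name__ == "__main__":
    toy = TensorDataset(torch.arange(10))
    for r in range(3):
        print(r, shard_for_rank(toy, r, world_size=3).indices)
    # 0 [0, 3, 6, 9]
    # 1 [1, 4, 7]
    # 2 [2, 5, 8]
```

PyTorch's `torch.utils.data.distributed.DistributedSampler` performs a similar rank-strided split, but by default it pads the shards to equal length.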
Note: If connecting via LocalTunnel, edit `SERVER_URL` on line 30 of `worker.py`.
Once the server reaches the target number of global batches, it saves the final weights to `trained_model.pth`. Run the testing script to evaluate the model's accuracy on the held-out test split (10,000 images for MNIST):

```bash
python test.py --dataset MNIST
```

Any recognized single-label torchvision dataset works without code changes, thanks to the Universal Dataset Factory in `utils.py`. The Factory resolves torchvision's inconsistent constructor arguments (`train=True` versus `split='train'`), determines each dataset's number of classes so the model's final layer is sized to match, and rejects multi-label datasets such as CelebA before they are initialized.
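The `train=True` versus `split='train'` mismatch can be resolved by inspecting each constructor's signature. This is a minimal sketch of that idea, not the Factory's actual code; `split_kwargs` is a hypothetical helper.

```python
from __future__ import annotations

import inspect
from typing import Any


def split_kwargs(dataset_cls: type, train: bool) -> dict[str, Any]:
    """Pick the train/test keyword that a dataset constructor actually accepts."""
    params = inspect.signature(dataset_cls).parameters
    if "train" in params:  # e.g. MNIST, FashionMNIST, CIFAR10
        return {"train": train}
    if "split" in params:  # e.g. SVHN, STL10
        return {"split": "train" if train else "test"}
    return {}  # no split argument at all; the caller must split the data itself
```

For example, `torchvision.datasets.SVHN(root="data", download=True, **split_kwargs(torchvision.datasets.SVHN, train=True))` would receive `split="train"`. Some datasets use `split` for something else and would need special-casing. EMNIST is one example, because its `split` selects a subset such as `"digits"`.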
The Factory also includes a Dynamic Auto-Installer. When a dataset fails to load because an underlying dependency is missing (e.g., PCAM needs `h5py` or `gdown`), it intercepts the failure and installs each missing package on the fly with `pip` in the background. `test.py` also no longer hardcodes "10,000"; it computes accuracy over the actual size of whichever test split is evaluated.
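Here is a simplified sketch of an auto-installer. It does not intercept the crash, which is how the Factory is described as working. Instead it swaps in a simpler approach: it checks a hard-coded, hypothetical table of extra dependencies up front and installs anything missing into the running interpreter with `python -m pip install`. The table entries come from the PCAM example above, and the function names are illustrative.

```python
from __future__ import annotations

import importlib
import subprocess
import sys
from types import ModuleType

# Hypothetical map of extra packages some torchvision datasets need at load time.
EXTRA_DEPENDENCIES: dict[str, tuple[str, ...]] = {
    "PCAM": ("h5py", "gdown"),
}


def import_or_install(module_name: str, pip_name: str | None = None) -> ModuleType:
    """Import a module, installing it with pip into this interpreter if missing."""
    try:
        return importlib.import_module(module_name)
    except ModuleNotFoundError:
        subprocess.check_call(
            [sys.executable, "-m", "pip", "install", pip_name or module_name]
        )
        importlib.invalidate_caches()  # let the import system see the new package
        return importlib.import_module(module_name)


def ensure_dataset_dependencies(dataset_name: str) -> list[str]:
    """Make sure every extra package a dataset needs is importable."""
    needed = EXTRA_DEPENDENCIES.get(dataset_name, ())
    for module_name in needed:
        import_or_install(module_name)
    return list(needed)
```

Calling `ensure_dataset_dependencies("PCAM")` before constructing `torchvision.datasets.PCAM` would install `h5py` and `gdown` only if they are not already importable. For packages whose import and pip names differ (such as `PIL` and `pillow`), pass `pip_name` explicitly.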
Ranked Performance Index: Keep in mind that all input is shrunk to 28x28 grayscale tensors to keep bandwidth extremely low (< 50 KB/s).
1. High Performance Tier (95%+ accuracy; simple contours survive the downscaling):
   - `MNIST` (handwritten digits)
   - `FashionMNIST` (articles of clothing)
   - `KMNIST` (Kuzushiji characters)
   - `QMNIST` (extended handwritten digits)
   - `EMNIST` (extended digits/letters)
   - `USPS` (postal digits)
2. Medium Performance Tier (60%-80% accuracy; object silhouettes survive the downscaling):
   - `CIFAR10` (10 classes of objects)
   - `SVHN` (Street View House Numbers)
   - `STL10` (higher-resolution objects)
3. Low Performance Tier (runs reliably over the network, but loses substantial accuracy at 28x28 grayscale because these datasets rely heavily on high-resolution color detail):
   - `CIFAR100` (100 classes of objects)
   - `StanfordCars` (car models, 196 classes)
   - `PCAM` (medical cancer scans)
   - `EuroSAT` (satellite imagery, 10 classes)
   - `Flowers102` (flower species, 102 classes)
   - `OxfordIIITPet` (pets, 37 breeds)
   - `Places365` (scenes/places, 365 classes)
   - `Food101` (food dishes, 101 classes)
   - `GTSRB` (traffic signs, 43 classes)
   - `DTD` (textures, 47 classes)
   - `FGVCAircraft` (aircraft models, 100 classes)
   - `Country211` (photos by country, 211 classes)
   - `Caltech101` / `Caltech256` (general objects)
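The 28x28 grayscale reduction can be sketched with plain tensor operations, assuming images arrive as float `(C, H, W)` tensors. The project may instead use torchvision transforms, and `to_28x28_gray` is a hypothetical helper. At 28x28x1, each image holds 784 values, compared with 3,072 for a 32x32 RGB CIFAR image. That gap shows how much color and resolution the lower tiers give up.

```python
import torch
import torch.nn.functional as F

# ITU-R BT.601 luma weights for collapsing RGB into a single gray channel.
LUMA_WEIGHTS = torch.tensor([0.299, 0.587, 0.114])


def to_28x28_gray(image: torch.Tensor) -> torch.Tensor:
    """Convert a (C, H, W) image with C = 1 or 3 into a float (1, 28, 28) tensor."""
    if image.dim() != 3 or image.shape[0] not in (1, 3):
        raise ValueError(f"expected a (1|3, H, W) tensor, got {tuple(image.shape)}")
    image = image.float()
    if image.shape[0] == 3:
        # Weighted sum over the channel dimension; keepdim preserves the (1, H, W) shape.
        image = (image * LUMA_WEIGHTS.view(3, 1, 1)).sum(dim=0, keepdim=True)
    # interpolate expects a batch dimension: (N, C, H, W).
    resized = F.interpolate(image.unsqueeze(0), size=(28, 28), mode="bilinear", align_corners=False)
    return resized.squeeze(0)
```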