sgd working #551

Merged
merged 6 commits into from
May 10, 2025
107 changes: 107 additions & 0 deletions wayang-fl/README.md
@@ -0,0 +1,107 @@
# Project Outline

This project aims to develop a user-friendly, platform-agnostic software framework for Federated Learning (FL) setups using Apache Wayang. It abstracts away the underlying execution platforms, allowing seamless deployment across heterogeneous environments. The software has been tested with a basic Stochastic Gradient Descent (SGD) setup to validate its functionality.

# Class Overview

## Client Package

### `Client.java`
**Location:** `src/main/java/org/client/`
**Purpose:** Represents a federated learning client with identifying information.
Stores the client’s unique URL and name, which are used for communication and identification in FL workflows.

### `FLClient.java`
**Location:** `src/main/java/org/client/`
**Purpose:** Implements a federated learning client using the Pekko actor model.
Handles handshake, plan initialization, and computation requests from the server.
Uses Apache Wayang to build and execute logical dataflow plans on a specified platform (Java or Spark), enabling flexible backend execution.


### `FLClientApp.java`
**Location:** `src/main/java/org/client/`
**Purpose:** Entry point for launching a federated learning client as a Pekko actor.
Initializes the actor system with configuration and starts the `FLClient` with its associated metadata (URL, ID, platform, and input data).

## Components Package

### `FLJob.java`
**Location:** `src/main/java/org/components/`
**Purpose:** Orchestrates the federated learning job by initializing the server actor, managing client connections, and running iterative training.
Encapsulates job configuration including the aggregation logic, stopping criteria, plan function, hyperparameters, and update rules.
Handles key stages: handshake, distributing plans/hyperparameters, checking stopping criterion, running iterations, and updating global state.

### `FLSystem.java`
**Location:** `src/main/java/org/components/`
**Purpose:** Coordinates multiple federated learning jobs by managing clients, instantiating servers, and running registered jobs.
Provides interfaces for job registration, execution, and retrieval of final results. Acts as the main system-level controller for managing FL lifecycles.


### `Aggregator.java`
**Location:** `src/main/java/org/components/aggregator/`
**Purpose:** Defines an aggregator responsible for combining the responses from multiple clients using a specified aggregation function.
This class leverages an `AggregatorFunction` to apply the defined aggregation logic on client responses, which are passed along with server hyperparameters.

### `Criterion.java`
**Location:** `src/main/java/org/components/criterion/`
**Purpose:** Defines a criterion to determine whether the federated learning job should stop or continue.
This class uses a user-defined function `shouldStop` that checks whether the job should halt based on the current values of the system, such as the progress or condition of the job.
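
As an illustration of how such a stopping rule could look, here is a minimal sketch. The class `CriterionSketch`, the `Predicate`-based signature, and the `"epoch"` key are assumptions made for this example and are not taken from `Criterion.java`.

```java
import java.util.Map;
import java.util.function.Predicate;

// Hypothetical stand-in for the project's Criterion class; names and types are assumptions.
final class CriterionSketch {
    private final Predicate<Map<String, Object>> shouldStop;

    CriterionSketch(Predicate<Map<String, Object>> shouldStop) {
        this.shouldStop = shouldStop;
    }

    // Returns true when the job should halt for the given snapshot of values.
    boolean shouldStop(Map<String, Object> currentValues) {
        return shouldStop.test(currentValues);
    }

    // Example rule: stop once the (assumed) "epoch" counter reaches maxEpochs.
    static CriterionSketch maxEpochs(int maxEpochs) {
        return new CriterionSketch(values ->
                ((Number) values.getOrDefault("epoch", 0)).intValue() >= maxEpochs);
    }
}
```

With `CriterionSketch.maxEpochs(5)`, a job that tracks an `"epoch"` value would keep iterating while the counter is below 5 and stop once it reaches 5.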


### `Hyperparameters.java`
**Location:** `src/main/java/org/components/hyperparameters/`
**Purpose:** Manages the hyperparameters for both the server and client in a federated learning system.
This class provides methods to update and retrieve hyperparameters for the server and client. It maintains separate maps for server-specific and client-specific hyperparameters.
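
The two-map layout described above can be sketched roughly as follows. `HyperparametersSketch` and its method names are hypothetical and may not match the real class; the sketch only shows why keeping server and client values apart is convenient.

```java
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

// Hypothetical sketch of separate server-side and client-side hyperparameter maps.
final class HyperparametersSketch {
    private final Map<String, Object> server = new HashMap<>();
    private final Map<String, Object> client = new HashMap<>();

    void updateServer(String key, Object value) { server.put(key, value); }

    void updateClient(String key, Object value) { client.put(key, value); }

    // Read-only view for server-side consumers such as an aggregator.
    Map<String, Object> serverView() { return Collections.unmodifiableMap(server); }

    // Defensive copy for clients, which may add local entries of their own.
    Map<String, Object> clientCopy() { return new HashMap<>(client); }
}
```

Handing clients a copy is a design choice in this sketch: `FLClient` adds an `inputFiles` entry to the map it receives, and a copy keeps such local additions out of the shared state.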


### `AggregatorFunction.java`
**Location:** `src/main/java/org/functions/`
**Purpose:** Defines a functional interface for aggregating client responses in federated learning.
The interface contains a single method, `apply`, that takes a list of client responses and a map of server hyperparameters, returning an aggregated result.
This is used by the `Aggregator` class to apply custom aggregation logic.
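
A minimal sketch of one possible aggregation function is an element-wise average. The interface below is a local stand-in with the same `apply(List<Object>, Map<String, Object>)` shape that `Aggregator.aggregate` calls; it is redeclared only so the example compiles on its own. The assumption that every client response is a `double[]` of equal length is made for illustration and is not confirmed by the source.

```java
import java.util.List;
import java.util.Map;

// Local stand-in for the project's AggregatorFunction, redeclared so the sketch is self-contained.
@FunctionalInterface
interface AggregatorFunctionSketch {
    Object apply(List<Object> clientResponses, Map<String, Object> serverHyperparams);
}

final class AveragingAggregatorSketch {
    // Assumption: each client response is a double[] (e.g. a gradient) of the same length.
    static final AggregatorFunctionSketch AVERAGE = (responses, hyperparams) -> {
        if (responses.isEmpty()) {
            throw new IllegalArgumentException("no client responses to aggregate");
        }
        int dim = ((double[]) responses.get(0)).length;
        double[] mean = new double[dim];
        for (Object response : responses) {
            double[] vector = (double[]) response;
            if (vector.length != dim) {
                throw new IllegalArgumentException("client responses differ in length");
            }
            for (int i = 0; i < dim; i++) {
                mean[i] += vector[i];
            }
        }
        for (int i = 0; i < dim; i++) {
            mean[i] /= responses.size();
        }
        return mean;
    };
}
```

The hyperparameter map is unused here; a weighted average could read per-client weights from it instead.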

### `PlanFunction.java`
**Location:** `src/main/java/org/functions/`
**Purpose:** Defines a functional interface for applying custom operations to a Wayang `JavaPlanBuilder`.
The interface contains a single method, `apply`, which takes three arguments:
- `Object a`: A custom object input.
- `JavaPlanBuilder b`: A `JavaPlanBuilder` instance for constructing the execution plan.
- `Map<String, Object> c`: A map containing additional parameters.

The method returns an `Operator`, which represents an operation in the Wayang execution plan.


## Messages Package
**Location:** `src/main/java/org/messages/`
**Purpose:** Defines the message types exchanged between the server and client actors in the FL setup.
These messages carry the communication and synchronization steps of training, such as the handshake, plan and hyperparameter distribution, and client update requests and responses.


## Server Package

### `Server.java`
**Location:** `src/main/java/org/server/`
**Purpose:** Represents a server in the Federated Learning (FL) system with a URL and name.

### `FLServer.java`

**Location:** `src/main/java/org/server/`
**Purpose:** Represents a Federated Learning (FL) server in the system. This class handles client-server communication, hyperparameter synchronization, model updates, and iterative learning processes using the actor model.




# SGD Testing
We tested the SGD algorithm using our FL setup, with 3 clients and a server. The relevant code can be found in `src/test`.

## Test Execution

To start the testing process, you need to run the client tests. These simulate the behavior of the clients in the Federated Learning setup:

- `FLClientTest1.java`
- `FLClientTest2.java`
- `FLClientTest3.java`

Once the clients are running, run `FLIntegrationTest.java`. This file starts the server and coordinates the SGD training process for 5 epochs. The server receives updates from the clients, aggregates them, and runs the SGD setup.

The current example uses the `wayang-ml4all` package to run SGD. For a different algorithm, the plan, hyperparameters, criterion, and aggregator can be specified in the `FLIntegrationTest.java` class.
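
To make the aggregate-then-update loop concrete, here is a minimal sketch of one plain SGD step applied to an already aggregated gradient. `SgdStepSketch`, `step`, and the `learningRate` parameter are hypothetical names; the update rule actually used by the test is configured in `FLIntegrationTest.java` and may differ.

```java
// Hypothetical sketch of a single SGD update: w_next = w - learningRate * gradient.
final class SgdStepSketch {
    static double[] step(double[] weights, double[] aggregatedGradient, double learningRate) {
        if (weights.length != aggregatedGradient.length) {
            throw new IllegalArgumentException("weights and gradient differ in length");
        }
        double[] next = new double[weights.length];
        for (int i = 0; i < weights.length; i++) {
            next[i] = weights[i] - learningRate * aggregatedGradient[i];
        }
        return next;
    }
}
```

In a federated round, the server would call something like `step` once per epoch with the aggregator's output, then send the new weights back to the clients as the next operand.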
10 changes: 8 additions & 2 deletions wayang-fl/pom.xml
@@ -116,8 +116,14 @@
</dependency>
<dependency>
<groupId>org.apache.wayang</groupId>
- <artifactId>wayang-spark</artifactId>
- <version>1.0.1-SNAPSHOT</version>
+
+ <artifactId>wayang-ml4all</artifactId>
+ <version>0.7.1</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.wayang</groupId>
+ <artifactId>wayang-spark_2.12</artifactId>
+ <version>0.7.1</version>
</dependency>
<dependency>
<groupId>org.apache.wayang</groupId>
22 changes: 11 additions & 11 deletions wayang-fl/src/main/java/org/client/FLClient.java
@@ -50,24 +50,26 @@ public class FLClient extends AbstractActor {
private JavaPlanBuilder planBuilder;
private PlanFunction planFunction;
private Map<String, Object> hyperparams;
+ private String[] inputFiles;

private Plugin getPlugin(String platformType){
if(platformType.equals("java")) return Java.basicPlugin();
else if(platformType.equals("spark")) return Spark.basicPlugin();
else return null;
}

- public Props props(Client client, String platformType) {
- return Props.create(FLClient.class, () -> new FLClient(client, platformType));
+ public Props props(Client client, String platformType, String[] inputFiles) {
+ return Props.create(FLClient.class, () -> new FLClient(client, platformType, inputFiles));
}

- public FLClient(Client client, String platformType) {
+ public FLClient(Client client, String platformType, String[] inputFiles) {
this.client = client;
this.wayangContext = new WayangContext(new Configuration()).withPlugin(getPlugin(platformType));
this.planBuilder = new JavaPlanBuilder(wayangContext)
.withJobName(client.getName()+"-job")
.withUdfJarOf(FLClient.class);
this.collector = new LinkedList<>();
+ this.inputFiles = inputFiles;
}

@Override
@@ -88,33 +90,31 @@ private void handlePlanHyperparametersMessage(PlanHyperparametersMessage msg) {
System.out.println(client.getName() + " receiving plan");
planFunction = msg.getSerializedplan();
hyperparams = msg.getHyperparams();
hyperparams.put("inputFiles", inputFiles);
getSender().tell(new PlanHyperparametersAckMessage(), getSelf());
System.out.println(client.getName() + " initialised plan function");
}

private void buildPlan(Object operand){

- Operator op = planFunction.apply(operand, planBuilder, hyperparams);
- // System.out.println(op);
+ JavaPlanBuilder newPlanBuilder = new JavaPlanBuilder(wayangContext)
+ .withJobName(client.getName()+"-job")
+ .withUdfJarOf(FLClient.class);
+ Operator op = planFunction.apply(operand, newPlanBuilder, hyperparams);
Class classType = op.getOutput(0).getType().getDataUnitType().getTypeClass();
LocalCallbackSink<?> sink = LocalCallbackSink.createCollectingSink(collector, classType);
op.connectTo(0, sink, 0);
plan = new WayangPlan(sink);
// System.out.println(plan);
Review comment (Contributor): Remove these `System.out` messages and use a logger.

}



private void handleClientUpdateRequestMessage(ClientUpdateRequestMessage msg) {
System.out.println(client.getName() + " Received compute request");
// System.out.println(planFunction);
// System.out.println(client.getName() + " Printed planFunction");
Object operand = msg.getValue();
buildPlan(operand);
wayangContext.execute(client.getName() + "-job", plan);
System.out.println(client.getName() + " executed plan successfully");
getSender().tell(new ClientUpdateResponseMessage(new LinkedList<>(collector)), getSelf());
// System.out.println(client.getName());
// System.out.println(collector);
collector.clear();
}

6 changes: 4 additions & 2 deletions wayang-fl/src/main/java/org/client/FLClientApp.java
@@ -26,17 +26,19 @@
public class FLClientApp {
private final Client client;
private final String platform_type;
+ private final String[] inputFiles;

- public FLClientApp(String client_url, String client_id, String platform_type){
+ public FLClientApp(String client_url, String client_id, String platform_type, String[] inputFiles){
this.client = new Client(client_url, client_id);
this.platform_type = platform_type;
+ this.inputFiles = inputFiles;
}

public void startFLClient(Config config){
ActorSystem system = ActorSystem.create(client.getName() + "-system", config);
ActorRef FLClientActor = system.actorOf(
Props.create(FLClient.class, () -> new FLClient(
- client, platform_type
+ client, platform_type, inputFiles
)),
client.getName()
);
Original file line number Diff line number Diff line change
@@ -20,6 +20,7 @@

import org.functions.AggregatorFunction;

import java.util.Collection;
import java.util.List;
import java.util.Map;

@@ -31,6 +32,13 @@ public Aggregator(AggregatorFunction aggregator){
}

public Object aggregate(List<Object> ClientResponses, Map<String, Object> server_hyperparams){
// for(Object response : ClientResponses){
// System.out.println("printing client response");
// System.out.println(response);
// for (double o : (double[]) response) {
// System.out.println(" Element: " + o);
// }
// }
return aggregator.apply(ClientResponses, server_hyperparams);
}
}
2 changes: 2 additions & 0 deletions wayang-fl/src/main/java/org/server/FLServer.java
@@ -84,6 +84,7 @@ public void handleSendPlanHyperParametersMessage(SendPlanHyperparametersMessage
for(ActorRef client : active_clients.keySet()){
if(!active_clients.get(client)) continue;
active_client_count++;
// remove this line later
client.tell(new PlanHyperparametersMessage(plan, client_hyperparams), getSelf());
}
// while(client_acks < active_client_count){}
@@ -105,6 +106,7 @@ public void handleRunIterationMessage(RunIterationMessage msg){
}

public void handleAggregateResponsesMessage(AggregateResponsesMessage msg){
System.out.println("Iteration Over, Aggregating Responses");
Object aggregatedResult = aggregator.aggregate(client_responses, hyperparams);
getSender().tell(aggregatedResult, getSelf());
}
31 changes: 31 additions & 0 deletions wayang-fl/src/main/java/org/temp/LibSVMTransform.java
@@ -0,0 +1,31 @@
package org.temp;

import org.apache.wayang.ml4all.abstraction.api.Transform;
import org.apache.wayang.ml4all.utils.StringUtil;

import java.util.List;

public class LibSVMTransform extends Transform<double[], String> {

int features;

public LibSVMTransform (int features) {
this.features = features;
}

@Override
public double[] transform(String line) {
List<String> pointStr = StringUtil.split(line, ' ');
double[] point = new double[features+1];
point[0] = Double.parseDouble(pointStr.get(0));
for (int i = 1; i < pointStr.size(); i++) {
if (pointStr.get(i).equals("")) {
continue;
}
// String kv[] = pointStr.get(i).split(":", 2);
// point[Integer.parseInt(kv[0])] = Double.parseDouble(kv[1]);
point[i] = Double.parseDouble(pointStr.get(i));
}
return point;
}
}
78 changes: 78 additions & 0 deletions wayang-fl/src/main/java/org/temp/test.java
@@ -0,0 +1,78 @@
package org.temp;

import org.apache.wayang.api.DataQuantaBuilder;
import org.apache.wayang.api.JavaPlanBuilder;
import org.apache.wayang.core.api.Configuration;
import org.apache.wayang.core.api.WayangContext;
import org.apache.wayang.java.Java;
import org.apache.wayang.ml4all.abstraction.api.Compute;
import org.apache.wayang.ml4all.abstraction.api.Sample;
import org.apache.wayang.ml4all.abstraction.api.Transform;
import org.apache.wayang.ml4all.abstraction.plan.ML4allModel;
import org.apache.wayang.ml4all.abstraction.plan.wrappers.ComputeWrapper;
import org.apache.wayang.ml4all.abstraction.plan.wrappers.TransformPerPartitionWrapper;
import org.apache.wayang.ml4all.algorithms.sgd.ComputeLogisticGradient;
import org.apache.wayang.ml4all.algorithms.sgd.SGDSample;
import org.client.FLClient;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

public class test {
public static void main(String args[]){

WayangContext wayangContext = new WayangContext(new Configuration()).withPlugin(Java.basicPlugin());
JavaPlanBuilder pb = new JavaPlanBuilder(wayangContext)
.withJobName("test-client"+"-job")
.withUdfJarOf(FLClient.class);
// List<Double> weights = new ArrayList<>(Collections.nCopies(29, 0.0));;
double[] weights = new double[29];
String inputFileUrl = "file:/Users/vedantaneogi/Downloads/higgs_part1.txt";
int datasetSize = 29;
ML4allModel model = new ML4allModel();
model.put("weights", weights);
ArrayList<ML4allModel> broadcastModel = new ArrayList<>(1);
broadcastModel.add(model);
// Step 1: Define ML operators
Sample sampleOp = new SGDSample();
Transform transformOp = new LibSVMTransform(29);
Compute computeOp = new ComputeLogisticGradient();

// Step 2: Create weight DataQuanta
var weightsBuilder = pb
.loadCollection(broadcastModel)
.withName("model");

// Step 3: Load dataset and apply transform
DataQuantaBuilder transformBuilder = (DataQuantaBuilder) pb
.readTextFile(inputFileUrl)
.withName("source")
.mapPartitions(new TransformPerPartitionWrapper(transformOp))
.withName("transform");


// Collection<?> parsedData = transformBuilder.collect();
// for (Object row : parsedData) {
// System.out.println(row);
// }

// Step 4: Sample, compute gradient, and broadcast weights
DataQuantaBuilder result = (DataQuantaBuilder) transformBuilder
.sample(sampleOp.sampleSize())
.withSampleMethod(sampleOp.sampleMethod())
.withDatasetSize(datasetSize)
.map(new ComputeWrapper<>(computeOp))
.withBroadcast(weightsBuilder, "model");
//
// System.out.println(result.collect());
// Collection<?> output = result.collect();
// for (Object o : result.collect()) {
// System.out.println("Type: " + o.getClass().getName());
// System.out.println("Value: " + o);
// for (Object idk : (double[]) o){
// System.out.println(idk);
// }
// }
}
}
2 changes: 1 addition & 1 deletion wayang-fl/src/test/java/org/test/FLClientTest1.java
@@ -34,7 +34,7 @@ public static void setup() {


public static void main(String args[]){
- FLClientApp client_app = new FLClientApp("pekko://[email protected]:2552/user/client1", "client1", "java");
+ FLClientApp client_app = new FLClientApp("pekko://[email protected]:2552/user/client1", "client1", "java", new String[] {"file:/Users/vedantaneogi/Downloads/higgs_part1.txt"});
Config config = ConfigFactory.load("client1-application");
client_app.startFLClient(config);
}
2 changes: 1 addition & 1 deletion wayang-fl/src/test/java/org/test/FLClientTest2.java
@@ -33,7 +33,7 @@ public static void setup() {
}

public static void main(String args[]) throws Exception {
- FLClientApp client_app = new FLClientApp("pekko://[email protected]:2553/user/client2", "client2", "java");
+ FLClientApp client_app = new FLClientApp("pekko://[email protected]:2553/user/client2", "client2", "java", new String[] {"file:/Users/vedantaneogi/Downloads/higgs_part2.txt"});
Config config = ConfigFactory.load("client2-application");
client_app.startFLClient(config);
}
2 changes: 1 addition & 1 deletion wayang-fl/src/test/java/org/test/FLClientTest3.java
@@ -33,7 +33,7 @@ public static void setup() {
}

public static void main(String args[]) throws Exception {
- FLClientApp client_app = new FLClientApp("pekko://[email protected]:2554/user/client3", "client3", "java");
+ FLClientApp client_app = new FLClientApp("pekko://[email protected]:2554/user/client3", "client3", "java", new String[] {"file:/Users/vedantaneogi/Downloads/higgs_part3.txt"});
Config config = ConfigFactory.load("client3-application");
client_app.startFLClient(config);
}