-
Notifications
You must be signed in to change notification settings - Fork 97
sgd working #551
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+373
−54
Merged
sgd working #551
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
8199420
sgd working
duvylfyluksq d002a34
removed test code
duvylfyluksq 3ad9fe2
removed print statements
duvylfyluksq 3562813
removed hardcoded paths
duvylfyluksq f90ab3b
Merge branch 'WAYANG-FL' into WAYANG-FL
duvylfyluksq 6d91e66
README.md
duvylfyluksq File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
# Project Outline | ||
|
||
This project aims to develop a user-friendly, platform-agnostic software framework for Federated Learning (FL) setups using Apache WAYANG. It abstracts away the underlying execution platforms, allowing seamless deployment across heterogeneous environments. The software has been tested with a basic Stochastic Gradient Descent (SGD) setup to validate its functionality. | ||
|
||
# Class Overview | ||
|
||
## Client Package | ||
|
||
### `Client.java` | ||
**Location:** `src/main/java/org/client/` | ||
**Purpose:** Represents a federated learning client with identifying information. | ||
Stores the client’s unique URL and name, which are used for communication and identification in FL workflows. | ||
|
||
### `FLClient.java` | ||
**Location:** `src/main/java/org/client/` | ||
**Purpose:** Implements a federated learning client using the Pekko actor model. | ||
Handles handshake, plan initialization, and computation requests from the server. | ||
Uses Apache WAYANG to build and execute logical dataflow plans on a specified platform (Java or Spark), enabling flexible backend execution. | ||
|
||
|
||
### `FLClientApp.java` | ||
**Location:** `src/main/java/org/client/` | ||
**Purpose:** Entry point for launching a federated learning client as a Pekko actor. | ||
Initializes the actor system with configuration and starts the `FLClient` with its associated metadata (URL, ID, platform, and input data). | ||
|
||
## Components package | ||
|
||
### `FLJob.java` | ||
**Location:** `src/main/java/org/components/` | ||
**Purpose:** Orchestrates the federated learning job by initializing the server actor, managing client connections, and running iterative training. | ||
Encapsulates job configuration including the aggregation logic, stopping criteria, plan function, hyperparameters, and update rules. | ||
Handles key stages: handshake, distributing plans/hyperparameters, checking stopping criterion, running iterations, and updating global state. | ||
|
||
### `FLSystem.java` | ||
**Location:** `src/main/java/org/components/` | ||
**Purpose:** Coordinates multiple federated learning jobs by managing clients, instantiating servers, and running registered jobs. | ||
Provides interfaces for job registration, execution, and retrieval of final results. Acts as the main system-level controller for managing FL lifecycles. | ||
|
||
|
||
### `Aggregator.java` | ||
**Location:** `src/main/java/org/components/aggregator/` | ||
**Purpose:** Defines an aggregator responsible for combining the responses from multiple clients using a specified aggregation function. | ||
This class leverages an `AggregatorFunction` to apply the defined aggregation logic on client responses, which are passed along with server hyperparameters. | ||
|
||
### `Criterion.java` | ||
**Location:** `src/main/java/org/components/criterion/` | ||
**Purpose:** Defines a criterion to determine whether the federated learning job should stop or continue. | ||
This class uses a user-defined function `shouldStop` that checks whether the job should halt based on the current values of the system, such as the progress or condition of the job. | ||
|
||
|
||
### `Hyperparameters.java` | ||
**Location:** `src/main/java/org/components/hyperparameters/` | ||
**Purpose:** Manages the hyperparameters for both the server and client in a federated learning system. | ||
This class provides methods to update and retrieve hyperparameters for the server and client. It maintains separate maps for server-specific and client-specific hyperparameters. | ||
|
||
|
||
### `AggregatorFunction.java` | ||
**Location:** `src/main/java/org/functions/` | ||
**Purpose:** Defines a functional interface for aggregating client responses in federated learning. | ||
The interface contains a single method, `apply`, that takes a list of client responses and a map of server hyperparameters, returning an aggregated result. | ||
This is used by the `Aggregator` class to apply custom aggregation logic. | ||
|
||
### `PlanFunction.java` | ||
**Location:** `src/main/java/org/functions/` | ||
**Purpose:** Defines a functional interface for applying custom operations to a Wayang `JavaPlanBuilder`. | ||
The interface contains a single method, `apply`, which takes three arguments: | ||
- `Object a`: A custom object input. | ||
- `JavaPlanBuilder b`: A `JavaPlanBuilder` instance for constructing the execution plan. | ||
- `Map<String, Object> c`: A map containing additional parameters. | ||
|
||
The method returns a `Operator`, which represents an operation in the Wayang execution plan. | ||
|
||
|
||
## Messages Package | ||
**Location:** `src/main/java/org/messages/` | ||
**Purpose:** Contains the various kinds of messages that are being used in the FL (Federated Learning) setup by the actor model. | ||
This package defines different message types exchanged between actors in the federated learning system, facilitating communication and synchronization between the server and clients during the training process. | ||
|
||
|
||
## Server Package | ||
|
||
### `Server.java` | ||
**Location:** `src/main/java/org/server/` | ||
**Purpose:** Represents a server in the Federated Learning (FL) system with a URL and name. | ||
|
||
### `FLServer.java` | ||
|
||
**Location:** `src/main/java/org/server/` | ||
**Purpose:** Represents a Federated Learning (FL) server in the system. This class handles client-server communication, hyperparameter synchronization, model updates, and iterative learning processes using the actor model. | ||
|
||
|
||
|
||
|
||
# SGD Testing | ||
We tested the SGD algorithm using our FL setup, with 3 clients and a server. The relevant code can be found in `src/test`. | ||
|
||
## Test Execution | ||
|
||
To start the testing process, you need to run the client tests. These simulate the behavior of the clients in the Federated Learning setup: | ||
|
||
- `FLClientTest1.java` | ||
- `FLClientTest2.java` | ||
- `FLClientTest3.java` | ||
|
||
Once the clients start running, you can run the `FLIntegrationTest.java` file. This file starts the server and coordinates the SGD training process for 5 epochs. The server receives updates from the clients, aggregates them, and runs the SGD setup. | ||
|
||
The current example using the `wayang-ml4all` package to run SGD. For a different algorithm, tha plan, hyperparameters, criterion and aggregator can be specified in the `FLIntegrationTest.java` class. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
package org.temp; | ||
|
||
import org.apache.wayang.ml4all.abstraction.api.Transform; | ||
import org.apache.wayang.ml4all.utils.StringUtil; | ||
|
||
import java.util.List; | ||
|
||
public class LibSVMTransform extends Transform<double[], String> { | ||
|
||
int features; | ||
|
||
public LibSVMTransform (int features) { | ||
this.features = features; | ||
} | ||
|
||
@Override | ||
public double[] transform(String line) { | ||
List<String> pointStr = StringUtil.split(line, ' '); | ||
double[] point = new double[features+1]; | ||
point[0] = Double.parseDouble(pointStr.get(0)); | ||
for (int i = 1; i < pointStr.size(); i++) { | ||
if (pointStr.get(i).equals("")) { | ||
continue; | ||
} | ||
// String kv[] = pointStr.get(i).split(":", 2); | ||
// point[Integer.parseInt(kv[0])] = Double.parseDouble(kv[1]); | ||
point[i] = Double.parseDouble(pointStr.get(i)); | ||
} | ||
return point; | ||
} | ||
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
package org.temp; | ||
|
||
import org.apache.wayang.api.DataQuantaBuilder; | ||
import org.apache.wayang.api.JavaPlanBuilder; | ||
import org.apache.wayang.core.api.Configuration; | ||
import org.apache.wayang.core.api.WayangContext; | ||
import org.apache.wayang.java.Java; | ||
import org.apache.wayang.ml4all.abstraction.api.Compute; | ||
import org.apache.wayang.ml4all.abstraction.api.Sample; | ||
import org.apache.wayang.ml4all.abstraction.api.Transform; | ||
import org.apache.wayang.ml4all.abstraction.plan.ML4allModel; | ||
import org.apache.wayang.ml4all.abstraction.plan.wrappers.ComputeWrapper; | ||
import org.apache.wayang.ml4all.abstraction.plan.wrappers.TransformPerPartitionWrapper; | ||
import org.apache.wayang.ml4all.algorithms.sgd.ComputeLogisticGradient; | ||
import org.apache.wayang.ml4all.algorithms.sgd.SGDSample; | ||
import org.client.FLClient; | ||
|
||
import java.util.ArrayList; | ||
import java.util.Collections; | ||
import java.util.List; | ||
|
||
public class test { | ||
public static void main(String args[]){ | ||
|
||
WayangContext wayangContext = new WayangContext(new Configuration()).withPlugin(Java.basicPlugin()); | ||
JavaPlanBuilder pb = new JavaPlanBuilder(wayangContext) | ||
.withJobName("test-client"+"-job") | ||
.withUdfJarOf(FLClient.class); | ||
// List<Double> weights = new ArrayList<>(Collections.nCopies(29, 0.0));; | ||
double[] weights = new double[29]; | ||
String inputFileUrl = "file:/Users/vedantaneogi/Downloads/higgs_part1.txt"; | ||
int datasetSize = 29; | ||
ML4allModel model = new ML4allModel(); | ||
model.put("weights", weights); | ||
ArrayList<ML4allModel> broadcastModel = new ArrayList<>(1); | ||
broadcastModel.add(model); | ||
// Step 1: Define ML operators | ||
Sample sampleOp = new SGDSample(); | ||
Transform transformOp = new LibSVMTransform(29); | ||
Compute computeOp = new ComputeLogisticGradient(); | ||
|
||
// Step 2: Create weight DataQuanta | ||
var weightsBuilder = pb | ||
.loadCollection(broadcastModel) | ||
.withName("model"); | ||
|
||
// Step 3: Load dataset and apply transform | ||
DataQuantaBuilder transformBuilder = (DataQuantaBuilder) pb | ||
.readTextFile(inputFileUrl) | ||
.withName("source") | ||
.mapPartitions(new TransformPerPartitionWrapper(transformOp)) | ||
.withName("transform"); | ||
|
||
|
||
// Collection<?> parsedData = transformBuilder.collect(); | ||
// for (Object row : parsedData) { | ||
// System.out.println(row); | ||
// } | ||
|
||
// Step 4: Sample, compute gradient, and broadcast weights | ||
DataQuantaBuilder result = (DataQuantaBuilder) transformBuilder | ||
.sample(sampleOp.sampleSize()) | ||
.withSampleMethod(sampleOp.sampleMethod()) | ||
.withDatasetSize(datasetSize) | ||
.map(new ComputeWrapper<>(computeOp)) | ||
.withBroadcast(weightsBuilder, "model"); | ||
// | ||
// System.out.println(result.collect()); | ||
// Collection<?> output = result.collect(); | ||
// for (Object o : result.collect()) { | ||
// System.out.println("Type: " + o.getClass().getName()); | ||
// System.out.println("Value: " + o); | ||
// for (Object idk : (double[]) o){ | ||
// System.out.println(idk); | ||
// } | ||
// } | ||
} | ||
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -34,7 +34,7 @@ public static void setup() { | |
|
||
|
||
public static void main(String args[]){ | ||
FLClientApp client_app = new FLClientApp("pekko://[email protected]:2552/user/client1", "client1", "java"); | ||
FLClientApp client_app = new FLClientApp("pekko://[email protected]:2552/user/client1", "client1", "java", new String[] {"file:/Users/vedantaneogi/Downloads/higgs_part1.txt"}); | ||
Config config = ConfigFactory.load("client1-application"); | ||
client_app.startFLClient(config); | ||
} | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -33,7 +33,7 @@ public static void setup() { | |
} | ||
|
||
public static void main(String args[]) throws Exception { | ||
FLClientApp client_app = new FLClientApp("pekko://[email protected]:2553/user/client2", "client2", "java"); | ||
FLClientApp client_app = new FLClientApp("pekko://[email protected]:2553/user/client2", "client2", "java", new String[] {"file:/Users/vedantaneogi/Downloads/higgs_part2.txt"}); | ||
Config config = ConfigFactory.load("client2-application"); | ||
client_app.startFLClient(config); | ||
} | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -33,7 +33,7 @@ public static void setup() { | |
} | ||
|
||
public static void main(String args[]) throws Exception { | ||
FLClientApp client_app = new FLClientApp("pekko://[email protected]:2554/user/client3", "client3", "java"); | ||
FLClientApp client_app = new FLClientApp("pekko://[email protected]:2554/user/client3", "client3", "java", new String[] {"file:/Users/vedantaneogi/Downloads/higgs_part3.txt"}); | ||
Config config = ConfigFactory.load("client3-application"); | ||
client_app.startFLClient(config); | ||
} | ||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove these systems.out messages and use logger