diff --git a/wayang-fl/README.md b/wayang-fl/README.md new file mode 100644 index 000000000..52d188bd5 --- /dev/null +++ b/wayang-fl/README.md @@ -0,0 +1,107 @@ +# Project Outline + +This project aims to develop a user-friendly, platform-agnostic software framework for Federated Learning (FL) setups using Apache WAYANG. It abstracts away the underlying execution platforms, allowing seamless deployment across heterogeneous environments. The software has been tested with a basic Stochastic Gradient Descent (SGD) setup to validate its functionality. + +# Class Overview + +## Client Package + +### `Client.java` +**Location:** `src/main/java/org/client/` +**Purpose:** Represents a federated learning client with identifying information. +Stores the client’s unique URL and name, which are used for communication and identification in FL workflows. + +### `FLClient.java` +**Location:** `src/main/java/org/client/` +**Purpose:** Implements a federated learning client using the Pekko actor model. +Handles handshake, plan initialization, and computation requests from the server. +Uses Apache WAYANG to build and execute logical dataflow plans on a specified platform (Java or Spark), enabling flexible backend execution. + + +### `FLClientApp.java` +**Location:** `src/main/java/org/client/` +**Purpose:** Entry point for launching a federated learning client as a Pekko actor. +Initializes the actor system with configuration and starts the `FLClient` with its associated metadata (URL, ID, platform, and input data). + +## Components package + +### `FLJob.java` +**Location:** `src/main/java/org/components/` +**Purpose:** Orchestrates the federated learning job by initializing the server actor, managing client connections, and running iterative training. +Encapsulates job configuration including the aggregation logic, stopping criteria, plan function, hyperparameters, and update rules. +Handles key stages: handshake, distributing plans/hyperparameters, checking stopping criterion, running iterations, and updating global state. + +### `FLSystem.java` +**Location:** `src/main/java/org/components/` +**Purpose:** Coordinates multiple federated learning jobs by managing clients, instantiating servers, and running registered jobs. +Provides interfaces for job registration, execution, and retrieval of final results. Acts as the main system-level controller for managing FL lifecycles. + + +### `Aggregator.java` +**Location:** `src/main/java/org/components/aggregator/` +**Purpose:** Defines an aggregator responsible for combining the responses from multiple clients using a specified aggregation function. +This class leverages an `AggregatorFunction` to apply the defined aggregation logic on client responses, which are passed along with server hyperparameters. + +### `Criterion.java` +**Location:** `src/main/java/org/components/criterion/` +**Purpose:** Defines a criterion to determine whether the federated learning job should stop or continue. +This class uses a user-defined function `shouldStop` that checks whether the job should halt based on the current values of the system, such as the progress or condition of the job. + + +### `Hyperparameters.java` +**Location:** `src/main/java/org/components/hyperparameters/` +**Purpose:** Manages the hyperparameters for both the server and client in a federated learning system. +This class provides methods to update and retrieve hyperparameters for the server and client. It maintains separate maps for server-specific and client-specific hyperparameters. + + +### `AggregatorFunction.java` +**Location:** `src/main/java/org/functions/` +**Purpose:** Defines a functional interface for aggregating client responses in federated learning. +The interface contains a single method, `apply`, that takes a list of client responses and a map of server hyperparameters, returning an aggregated result. +This is used by the `Aggregator` class to apply custom aggregation logic. + +### `PlanFunction.java` +**Location:** `src/main/java/org/functions/` +**Purpose:** Defines a functional interface for applying custom operations to a Wayang `JavaPlanBuilder`. +The interface contains a single method, `apply`, which takes three arguments: +- `Object a`: A custom object input. +- `JavaPlanBuilder b`: A `JavaPlanBuilder` instance for constructing the execution plan. +- `Map c`: A map containing additional parameters. + +The method returns a `Operator`, which represents an operation in the Wayang execution plan. + + +## Messages Package +**Location:** `src/main/java/org/messages/` +**Purpose:** Contains the various kinds of messages that are being used in the FL (Federated Learning) setup by the actor model. +This package defines different message types exchanged between actors in the federated learning system, facilitating communication and synchronization between the server and clients during the training process. + + +## Server Package + +### `Server.java` +**Location:** `src/main/java/org/server/` +**Purpose:** Represents a server in the Federated Learning (FL) system with a URL and name. + +### `FLServer.java` + +**Location:** `src/main/java/org/server/` +**Purpose:** Represents a Federated Learning (FL) server in the system. This class handles client-server communication, hyperparameter synchronization, model updates, and iterative learning processes using the actor model. + + + + +# SGD Testing +We tested the SGD algorithm using our FL setup, with 3 clients and a server. The relevant code can be found in `src/test`. + +## Test Execution + +To start the testing process, you need to run the client tests. These simulate the behavior of the clients in the Federated Learning setup: + +- `FLClientTest1.java` +- `FLClientTest2.java` +- `FLClientTest3.java` + +Once the clients start running, you can run the `FLIntegrationTest.java` file. This file starts the server and coordinates the SGD training process for 5 epochs. The server receives updates from the clients, aggregates them, and runs the SGD setup. + +The current example using the `wayang-ml4all` package to run SGD. For a different algorithm, tha plan, hyperparameters, criterion and aggregator can be specified in the `FLIntegrationTest.java` class. \ No newline at end of file diff --git a/wayang-fl/pom.xml b/wayang-fl/pom.xml index efcc363d4..8d46a1ea6 100644 --- a/wayang-fl/pom.xml +++ b/wayang-fl/pom.xml @@ -116,8 +116,14 @@ org.apache.wayang - wayang-spark - 1.0.1-SNAPSHOT + + wayang-ml4all + 0.7.1 + + + org.apache.wayang + wayang-spark_2.12 + 0.7.1 org.apache.wayang diff --git a/wayang-fl/src/main/java/org/client/FLClient.java b/wayang-fl/src/main/java/org/client/FLClient.java index 6cea8d503..6b5136ec8 100644 --- a/wayang-fl/src/main/java/org/client/FLClient.java +++ b/wayang-fl/src/main/java/org/client/FLClient.java @@ -50,6 +50,7 @@ public class FLClient extends AbstractActor { private JavaPlanBuilder planBuilder; private PlanFunction planFunction; private Map hyperparams; + private String[] inputFiles; private Plugin getPlugin(String platformType){ if(platformType.equals("java")) return Java.basicPlugin(); @@ -57,17 +58,18 @@ private Plugin getPlugin(String platformType){ else return null; } - public Props props(Client client, String platformType) { - return Props.create(FLClient.class, () -> new FLClient(client, platformType)); + public Props props(Client client, String platformType, String[] inputFiles) { + return Props.create(FLClient.class, () -> new FLClient(client, platformType, inputFiles)); } - public FLClient(Client client, String platformType) { + public FLClient(Client client, String platformType, String[] inputFiles) { this.client = client; this.wayangContext = new WayangContext(new Configuration()).withPlugin(getPlugin(platformType)); this.planBuilder = new JavaPlanBuilder(wayangContext) .withJobName(client.getName()+"-job") .withUdfJarOf(FLClient.class); this.collector = new LinkedList<>(); + this.inputFiles = inputFiles; } @Override @@ -88,33 +90,31 @@ private void handlePlanHyperparametersMessage(PlanHyperparametersMessage msg) { System.out.println(client.getName() + " receiving plan"); planFunction = msg.getSerializedplan(); hyperparams = msg.getHyperparams(); + hyperparams.put("inputFiles", inputFiles); getSender().tell(new PlanHyperparametersAckMessage(), getSelf()); System.out.println(client.getName() + " initialised plan function"); } private void buildPlan(Object operand){ - - Operator op = planFunction.apply(operand, planBuilder, hyperparams); -// System.out.println(op); + JavaPlanBuilder newPlanBuilder = new JavaPlanBuilder(wayangContext) + .withJobName(client.getName()+"-job") + .withUdfJarOf(FLClient.class); + Operator op = planFunction.apply(operand, newPlanBuilder, hyperparams); Class classType = op.getOutput(0).getType().getDataUnitType().getTypeClass(); LocalCallbackSink sink = LocalCallbackSink.createCollectingSink(collector, classType); op.connectTo(0, sink, 0); plan = new WayangPlan(sink); -// System.out.println(plan); } private void handleClientUpdateRequestMessage(ClientUpdateRequestMessage msg) { System.out.println(client.getName() + " Received compute request"); -// System.out.println(planFunction); -// System.out.println(client.getName() + " Printed planFunction"); Object operand = msg.getValue(); buildPlan(operand); wayangContext.execute(client.getName() + "-job", plan); + System.out.println(client.getName() + " executed plan successfully"); getSender().tell(new ClientUpdateResponseMessage(new LinkedList<>(collector)), getSelf()); -// System.out.println(client.getName()); -// System.out.println(collector); collector.clear(); } diff --git a/wayang-fl/src/main/java/org/client/FLClientApp.java b/wayang-fl/src/main/java/org/client/FLClientApp.java index 5718fe783..cbd61c2b0 100644 --- a/wayang-fl/src/main/java/org/client/FLClientApp.java +++ b/wayang-fl/src/main/java/org/client/FLClientApp.java @@ -26,17 +26,19 @@ public class FLClientApp { private final Client client; private final String platform_type; + private final String[] inputFiles; - public FLClientApp(String client_url, String client_id, String platform_type){ + public FLClientApp(String client_url, String client_id, String platform_type, String[] inputFiles){ this.client = new Client(client_url, client_id); this.platform_type = platform_type; + this.inputFiles = inputFiles; } public void startFLClient(Config config){ ActorSystem system = ActorSystem.create(client.getName() + "-system", config); ActorRef FLClientActor = system.actorOf( Props.create(FLClient.class, () -> new FLClient( - client, platform_type + client, platform_type, inputFiles )), client.getName() ); diff --git a/wayang-fl/src/main/java/org/components/aggregator/Aggregator.java b/wayang-fl/src/main/java/org/components/aggregator/Aggregator.java index 5fb2aa2cb..c1758d592 100644 --- a/wayang-fl/src/main/java/org/components/aggregator/Aggregator.java +++ b/wayang-fl/src/main/java/org/components/aggregator/Aggregator.java @@ -20,6 +20,7 @@ import org.functions.AggregatorFunction; +import java.util.Collection; import java.util.List; import java.util.Map; @@ -31,6 +32,13 @@ public Aggregator(AggregatorFunction aggregator){ } public Object aggregate(List ClientResponses, Map server_hyperparams){ +// for(Object response : ClientResponses){ +// System.out.println("printing client response"); +// System.out.println(response); +// for (double o : (double[]) response) { +// System.out.println(" Element: " + o); +// } +// } return aggregator.apply(ClientResponses, server_hyperparams); } } diff --git a/wayang-fl/src/main/java/org/server/FLServer.java b/wayang-fl/src/main/java/org/server/FLServer.java index bbc4904f1..8ba4320b1 100644 --- a/wayang-fl/src/main/java/org/server/FLServer.java +++ b/wayang-fl/src/main/java/org/server/FLServer.java @@ -84,6 +84,7 @@ public void handleSendPlanHyperParametersMessage(SendPlanHyperparametersMessage for(ActorRef client : active_clients.keySet()){ if(!active_clients.get(client)) continue; active_client_count++; + // remove this line later client.tell(new PlanHyperparametersMessage(plan, client_hyperparams), getSelf()); } // while(client_acks < active_client_count){} @@ -105,6 +106,7 @@ public void handleRunIterationMessage(RunIterationMessage msg){ } public void handleAggregateResponsesMessage(AggregateResponsesMessage msg){ + System.out.println("Iteration Over, Aggregating Responses"); Object aggregatedResult = aggregator.aggregate(client_responses, hyperparams); getSender().tell(aggregatedResult, getSelf()); } diff --git a/wayang-fl/src/main/java/org/temp/LibSVMTransform.java b/wayang-fl/src/main/java/org/temp/LibSVMTransform.java new file mode 100644 index 000000000..a8c500279 --- /dev/null +++ b/wayang-fl/src/main/java/org/temp/LibSVMTransform.java @@ -0,0 +1,31 @@ +package org.temp; + +import org.apache.wayang.ml4all.abstraction.api.Transform; +import org.apache.wayang.ml4all.utils.StringUtil; + +import java.util.List; + +public class LibSVMTransform extends Transform { + + int features; + + public LibSVMTransform (int features) { + this.features = features; + } + + @Override + public double[] transform(String line) { + List pointStr = StringUtil.split(line, ' '); + double[] point = new double[features+1]; + point[0] = Double.parseDouble(pointStr.get(0)); + for (int i = 1; i < pointStr.size(); i++) { + if (pointStr.get(i).equals("")) { + continue; + } +// String kv[] = pointStr.get(i).split(":", 2); +// point[Integer.parseInt(kv[0])] = Double.parseDouble(kv[1]); + point[i] = Double.parseDouble(pointStr.get(i)); + } + return point; + } +} \ No newline at end of file diff --git a/wayang-fl/src/main/java/org/temp/test.java b/wayang-fl/src/main/java/org/temp/test.java new file mode 100644 index 000000000..a3a2c0d74 --- /dev/null +++ b/wayang-fl/src/main/java/org/temp/test.java @@ -0,0 +1,78 @@ +package org.temp; + +import org.apache.wayang.api.DataQuantaBuilder; +import org.apache.wayang.api.JavaPlanBuilder; +import org.apache.wayang.core.api.Configuration; +import org.apache.wayang.core.api.WayangContext; +import org.apache.wayang.java.Java; +import org.apache.wayang.ml4all.abstraction.api.Compute; +import org.apache.wayang.ml4all.abstraction.api.Sample; +import org.apache.wayang.ml4all.abstraction.api.Transform; +import org.apache.wayang.ml4all.abstraction.plan.ML4allModel; +import org.apache.wayang.ml4all.abstraction.plan.wrappers.ComputeWrapper; +import org.apache.wayang.ml4all.abstraction.plan.wrappers.TransformPerPartitionWrapper; +import org.apache.wayang.ml4all.algorithms.sgd.ComputeLogisticGradient; +import org.apache.wayang.ml4all.algorithms.sgd.SGDSample; +import org.client.FLClient; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +public class test { + public static void main(String args[]){ + + WayangContext wayangContext = new WayangContext(new Configuration()).withPlugin(Java.basicPlugin()); + JavaPlanBuilder pb = new JavaPlanBuilder(wayangContext) + .withJobName("test-client"+"-job") + .withUdfJarOf(FLClient.class); +// List weights = new ArrayList<>(Collections.nCopies(29, 0.0));; + double[] weights = new double[29]; + String inputFileUrl = "file:/Users/vedantaneogi/Downloads/higgs_part1.txt"; + int datasetSize = 29; + ML4allModel model = new ML4allModel(); + model.put("weights", weights); + ArrayList broadcastModel = new ArrayList<>(1); + broadcastModel.add(model); + // Step 1: Define ML operators + Sample sampleOp = new SGDSample(); + Transform transformOp = new LibSVMTransform(29); + Compute computeOp = new ComputeLogisticGradient(); + + // Step 2: Create weight DataQuanta + var weightsBuilder = pb + .loadCollection(broadcastModel) + .withName("model"); + + // Step 3: Load dataset and apply transform + DataQuantaBuilder transformBuilder = (DataQuantaBuilder) pb + .readTextFile(inputFileUrl) + .withName("source") + .mapPartitions(new TransformPerPartitionWrapper(transformOp)) + .withName("transform"); + + +// Collection parsedData = transformBuilder.collect(); +// for (Object row : parsedData) { +// System.out.println(row); +// } + + // Step 4: Sample, compute gradient, and broadcast weights + DataQuantaBuilder result = (DataQuantaBuilder) transformBuilder + .sample(sampleOp.sampleSize()) + .withSampleMethod(sampleOp.sampleMethod()) + .withDatasetSize(datasetSize) + .map(new ComputeWrapper<>(computeOp)) + .withBroadcast(weightsBuilder, "model"); +// +// System.out.println(result.collect()); +// Collection output = result.collect(); +// for (Object o : result.collect()) { +// System.out.println("Type: " + o.getClass().getName()); +// System.out.println("Value: " + o); +// for (Object idk : (double[]) o){ +// System.out.println(idk); +// } +// } + } +} diff --git a/wayang-fl/src/test/java/org/test/FLClientTest1.java b/wayang-fl/src/test/java/org/test/FLClientTest1.java index aedae082a..d5d22b400 100644 --- a/wayang-fl/src/test/java/org/test/FLClientTest1.java +++ b/wayang-fl/src/test/java/org/test/FLClientTest1.java @@ -34,7 +34,7 @@ public static void setup() { public static void main(String args[]){ - FLClientApp client_app = new FLClientApp("pekko://client1-system@127.0.0.1:2552/user/client1", "client1", "java"); + FLClientApp client_app = new FLClientApp("pekko://client1-system@127.0.0.1:2552/user/client1", "client1", "java", new String[] {"file:/Users/vedantaneogi/Downloads/higgs_part1.txt"}); Config config = ConfigFactory.load("client1-application"); client_app.startFLClient(config); } diff --git a/wayang-fl/src/test/java/org/test/FLClientTest2.java b/wayang-fl/src/test/java/org/test/FLClientTest2.java index 7c25e4fdb..d5f011da5 100644 --- a/wayang-fl/src/test/java/org/test/FLClientTest2.java +++ b/wayang-fl/src/test/java/org/test/FLClientTest2.java @@ -33,7 +33,7 @@ public static void setup() { } public static void main(String args[]) throws Exception { - FLClientApp client_app = new FLClientApp("pekko://client2-system@127.0.0.1:2553/user/client2", "client2", "java"); + FLClientApp client_app = new FLClientApp("pekko://client2-system@127.0.0.1:2553/user/client2", "client2", "java", new String[] {"file:/Users/vedantaneogi/Downloads/higgs_part2.txt"}); Config config = ConfigFactory.load("client2-application"); client_app.startFLClient(config); } diff --git a/wayang-fl/src/test/java/org/test/FLClientTest3.java b/wayang-fl/src/test/java/org/test/FLClientTest3.java index b523739fb..d10d10f50 100644 --- a/wayang-fl/src/test/java/org/test/FLClientTest3.java +++ b/wayang-fl/src/test/java/org/test/FLClientTest3.java @@ -33,7 +33,7 @@ public static void setup() { } public static void main(String args[]) throws Exception { - FLClientApp client_app = new FLClientApp("pekko://client3-system@127.0.0.1:2554/user/client3", "client3", "java"); + FLClientApp client_app = new FLClientApp("pekko://client3-system@127.0.0.1:2554/user/client3", "client3", "java", new String[] {"file:/Users/vedantaneogi/Downloads/higgs_part3.txt"}); Config config = ConfigFactory.load("client3-application"); client_app.startFLClient(config); } diff --git a/wayang-fl/src/test/java/org/test/FLIntegrationTest.java b/wayang-fl/src/test/java/org/test/FLIntegrationTest.java index ba8ae3b39..9f9e209c6 100644 --- a/wayang-fl/src/test/java/org/test/FLIntegrationTest.java +++ b/wayang-fl/src/test/java/org/test/FLIntegrationTest.java @@ -22,6 +22,13 @@ import com.typesafe.config.ConfigFactory; import org.apache.commons.lang3.tuple.Pair; import org.apache.pekko.actor.ActorSystem; +import org.apache.wayang.api.DataQuantaBuilder; +import org.apache.wayang.basic.operators.SampleOperator; +import org.apache.wayang.ml4all.abstraction.plan.ML4allModel; +import org.apache.wayang.ml4all.algorithms.sgd.ComputeLogisticGradient; +import org.temp.LibSVMTransform; +import org.apache.wayang.ml4all.algorithms.sgd.SGDSample; +import org.apache.wayang.ml4all.utils.StringUtil; import org.client.FLClientApp; import org.components.FLJob; import org.components.FLSystem; @@ -31,10 +38,15 @@ import org.functions.AggregatorFunction; import org.functions.PlanFunction; import org.server.Server; + +import org.apache.wayang.ml4all.abstraction.api.*; +import org.apache.wayang.ml4all.abstraction.plan.wrappers.*; + import org.junit.jupiter.api.AfterAll; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; +import javax.xml.crypto.Data; import java.nio.file.Files; import java.nio.file.Path; import java.util.*; @@ -42,7 +54,13 @@ import static org.junit.jupiter.api.Assertions.assertNotNull; + + + public class FLIntegrationTest { + + + private static Config server_config; private Config client_config; @BeforeAll @@ -75,28 +93,49 @@ public void testFLWorkflow() throws Exception { AggregatorFunction aggregatorFunction = (clientResponses, serverHyperparams) -> { - // Check if there are no responses if (clientResponses == null || clientResponses.isEmpty()) { - return List.of(); + return new double[0]; // empty array } - // Assume each response is a List - // Initialize an aggregated list with zeros using the size of the first response - List firstList = (List) clientResponses.get(0); - int size = firstList.size(); - List aggregated = new ArrayList<>(size); - for (int i = 0; i < size; i++) { - aggregated.add(0.0); - } - // Iterate through each response and add the values element-wise + + double[] aggregated = null; + int clientCount = 0; + for (Object response : clientResponses) { - List values = (List) response; - for (int i = 0; i < values.size(); i++) { - aggregated.set(i, aggregated.get(i) + values.get(i)); + if (!(response instanceof Collection)) continue; + + Collection responseList = (Collection) response; + + for (Object o : responseList) { + if (!(o instanceof double[])) continue; + + double[] values = (double[]) o; + + if (aggregated == null) { + aggregated = new double[values.length]; + } + + for (int i = 0; i < values.length; i++) { + aggregated[i] += values[i]; + } + + clientCount++; } } - return aggregated; + + if (aggregated == null || clientCount == 0) return new double[0]; + + // Take average + for (int i = 0; i < aggregated.length; i++) { + aggregated[i] /= clientCount; + } + + // Discard index 0 + return Arrays.copyOfRange(aggregated, 1, aggregated.length); }; + + + // AggregatorFunction aggregatorFunction = (clientResponses, serverHyperparams) -> { // // Check if there are no responses // if (clientResponses == null || clientResponses.isEmpty()) { @@ -110,13 +149,6 @@ public void testFLWorkflow() throws Exception { // return aggregated; // }; -// for(int i = 0; i < clientNames.size(); i++){ -// FLClientApp client_app = new FLClientApp(clientUrls.get(i), clientNames.get(i), "java"); -// client_config = ConfigFactory.load(clientConfigs.get(i)); -// client_app.startFLClient(client_config); -// } - - // A simple aggregator that just returns the responses as-is. Aggregator aggregator = new Aggregator(aggregatorFunction); @@ -125,29 +157,81 @@ public void testFLWorkflow() throws Exception { // Dummy hyperparameters (assumes a default constructor). Hyperparameters hyperparameters = new Hyperparameters(); + hyperparameters.update_server_hyperparams("eta", 0.01); + hyperparameters.update_client_hyperparams("datasetSize", 333); +// PlanFunction planFunction = (w, pb, m) -> pb +// .loadCollection((List)w).withName("init weights") +// .map(value -> value + 2.0) +// .withName("Square elements") +// .dataQuanta() +// .operator(); + + PlanFunction planFunction = (operand, pb, hyperparams) -> { + // Step 0: Cast operand and extract hyperparams + System.out.println(Arrays.toString((double[])operand)); + double[] weights = (double[]) operand; + String inputFileUrl = ((String[]) hyperparams.get("inputFiles"))[0]; + int datasetSize = (int) hyperparams.get("datasetSize"); + + ML4allModel model = new ML4allModel(); + model.put("weights", weights); + ArrayList broadcastModel = new ArrayList<>(1); + broadcastModel.add(model); + + // Step 1: Define ML operators + Sample sampleOp = new SGDSample(); + Transform transformOp = new LibSVMTransform(29); + Compute computeOp = new ComputeLogisticGradient(); + + // Step 2: Create weight DataQuanta + var weightsBuilder = pb + .loadCollection(broadcastModel) + .withName("model"); - // A dummy plan function that does nothing. - PlanFunction planFunction = (w, pb, m) -> pb - .loadCollection((List)w).withName("init weights") - .map(value -> value + 2.0) - .withName("Square elements") - .dataQuanta() - .operator(); + // Step 3: Load dataset and apply transform + DataQuantaBuilder transformBuilder = (DataQuantaBuilder) pb + .readTextFile(inputFileUrl) + .withName("source") + .mapPartitions(new TransformPerPartitionWrapper(transformOp)) + .withName("transform"); + +// Collection parsedData = transformBuilder.collect(); +// for (Object row : parsedData) { +// System.out.println(row); +// } + + // Step 4: Sample, compute gradient, and broadcast weights + DataQuantaBuilder result = (DataQuantaBuilder) transformBuilder + .sample(sampleOp.sampleSize()) + .withSampleMethod(sampleOp.sampleMethod()) + .withDatasetSize(datasetSize) + .map(new ComputeWrapper<>(computeOp)) + .withBroadcast(weightsBuilder, "model"); + + return result.dataQuanta().operator(); + }; -// PlanFunction planFunction = (w, pb, m) -> (Integer)w + 1; - // Dummy initial values and operand. Map initialValues = new HashMap<>(); initialValues.put("epoch", 0); - Object initialOperand = (Object) new ArrayList<>(Arrays.asList(1.0, 2.0, 3.0, 4.0, 5.0)); -// Object initialOperand = (Object) 1; + Object initialOperand = new double[29]; // Dummy update rules (empty in this case). Map> updateRules = new HashMap<>(); updateRules.put("epoch", epoch -> (int)epoch + 1); // Dummy update operand function that simply returns the first part of the pair. - Function, Object> updateOperand = pair -> pair.getRight(); + Function, Object> updateOperand = pair -> { + double[] left = (double[]) pair.getLeft(); + double[] right = (double[]) pair.getRight(); + double[] updated = new double[left.length]; + + for (int i = 0; i < left.length; i++) { + updated[i] = left[i] - 0.01 * right[i]; + } + + return updated; + }; FLSystem flSystem = new FLSystem(server.getName(), server.getUrl(), clientNames, clientUrls); @@ -168,8 +252,9 @@ public void testFLWorkflow() throws Exception { flSystem.startFLJob(jobId); // Obtain the final result from the job. - Object finalResult = flSystem.getFLJobResult(jobId); - System.out.println("Final result: " + finalResult); + double[] finalResult = (double[]) flSystem.getFLJobResult(jobId); + + System.out.println("Final result: " + Arrays.toString(finalResult)); // Write the final result to a file in the target/test-output folder. Path outputPath = Path.of("target", "test-output.txt");