Skip to content

Commit f2b5fcb

Browse files
committed
Refs #22737: Add logic for reiterate tasks
Signed-off-by: Javier Gil Aviles <[email protected]>
1 parent 47d98f5 commit f2b5fcb

File tree

3 files changed

+82
-1
lines changed

3 files changed

+82
-1
lines changed

sustainml_cpp/src/cpp/orchestrator/OrchestratorNode.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,7 @@ RetCode_t OrchestratorNode::get_task_data(
461461
ret = RetCode_t::RETCODE_OK;
462462
}
463463
}
464+
break;
464465
}
465466
default:
466467
{

sustainml_modules/sustainml_modules/sustainml-wp5/orchestrator_node/orchestrator_node.py

+67-1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from sustainml_swig import NodeStatus
2222
import sustainml_swig
2323
import threading
24+
import json
2425

2526
class OrchestratorNodeHandle(cpp_OrchestratorNodeHandle):
2627

@@ -243,6 +244,50 @@ def get_carbontracker(self, task_id):
243244
'carbon_intensity': carbon_intensity}
244245
return json_output
245246

247+
def get_orchestrator(self, task_id):
248+
249+
# retrieve node data
250+
node_data = sustainml_swig.get_orchestrator(self.node_, task_id)
251+
if node_data is None:
252+
return {'Error': f"Failed to get {utils.string_node(utils.node_id.ORCHESTRATOR.value)} data for task {utils.string_task(task_id)}"}
253+
254+
# Parse data into json
255+
task_json = {'problem_id': task_id.problem_id(), 'iteration_id': task_id.iteration_id()}
256+
modality = node_data.modality()
257+
problem_short_description = node_data.problem_short_description()
258+
problem_definition = node_data.problem_definition()
259+
inputs = node_data.inputs()
260+
outputs = node_data.outputs()
261+
minimum_samples = node_data.minimum_samples()
262+
maximum_samples = node_data.maximum_samples()
263+
optimize_carbon_footprint_manual = node_data.optimize_carbon_footprint_manual()
264+
previous_iteration = node_data.previous_iteration()
265+
optimize_carbon_footprint_auto = node_data.optimize_carbon_footprint_auto()
266+
desired_carbon_footprint = node_data.desired_carbon_footprint()
267+
geo_location_continent = node_data.geo_location_continent()
268+
geo_location_region = node_data.geo_location_region()
269+
extra_data_vector = node_data.extra_data()
270+
extra_data_list = [s for s in extra_data_vector]
271+
extra_data_bytes = bytes(extra_data_list)
272+
extra_data_str = extra_data_bytes.decode('utf-8')
273+
extra_data = json.loads(extra_data_str)
274+
json_output = {'task_id': task_json,
275+
'modality': modality,
276+
'problem_short_description': problem_short_description,
277+
'problem_definition': problem_definition,
278+
'inputs': " ".join(inputs),
279+
'outputs': " ".join(outputs),
280+
'minimum_samples': minimum_samples,
281+
'maximum_samples': maximum_samples,
282+
'optimize_carbon_footprint_manual': optimize_carbon_footprint_manual,
283+
'previous_iteration': previous_iteration,
284+
'optimize_carbon_footprint_auto': optimize_carbon_footprint_auto,
285+
'desired_carbon_footprint': desired_carbon_footprint,
286+
'geo_location_continent': geo_location_continent,
287+
'geo_location_region': geo_location_region,
288+
'extra_data': extra_data}
289+
return json_output
290+
246291
def get_results(self, node_id, task_id):
247292
if task_id is None:
248293
task_id = self.get_last_task_id()
@@ -259,13 +304,34 @@ def get_results(self, node_id, task_id):
259304
return self.get_hw_provider(task_id)
260305
elif node_id == utils.node_id.CARBONTRACKER.value:
261306
return self.get_carbontracker(task_id)
307+
elif node_id == utils.node_id.ORCHESTRATOR.value:
308+
return self.get_orchestrator(task_id)
262309
else:
263310
message = utils.string_node(node_id) + " node does not have any results to show."
264311
return {'message': message, 'task_id': utils.task_json(task_id)}
265312

266313
def send_user_input(self, json_data):
267-
pair = self.node_.prepare_new_task()
314+
if json_data.get('previous_iteration') == 0:
315+
pair = self.node_.prepare_new_task()
316+
else:
317+
previous_task = sustainml_swig.TaskId()
318+
previous_task.iteration_id(json_data.get('previous_iteration'))
319+
extra = json_data.get('extra_data', {})
320+
previous_task.problem_id(extra.get('previous_problem_id'))
321+
# Verify the last doesn't exist yet
322+
existing_task = self.get_last_task_id()
323+
if existing_task is not None:
324+
if (previous_task.problem_id() == existing_task.problem_id() and
325+
previous_task.iteration_id() + 1 == existing_task.iteration_id()):
326+
print("Task already taken. Using :", utils.string_task(existing_task))
327+
return None
328+
pair = self.node_.prepare_new_iteration(previous_task)
268329
task_id = pair[0]
330+
331+
print("Task:", utils.string_task(task_id))
332+
print("Problem ID:", task_id.problem_id()) # Debugging
333+
print("Iteration ID:", task_id.iteration_id())
334+
269335
user_input = pair[1]
270336
self.handler_.register_task(task_id)
271337

sustainml_swig/src/swig/sustainml_swig/nodes/OrchestratorNode.i

+14
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,20 @@
126126
return node;
127127
}
128128

129+
types::UserInput* get_orchestrator(
130+
sustainml::orchestrator::OrchestratorNode* orchestrator,
131+
const types::TaskId& task_id)
132+
{
133+
void* data = nullptr;
134+
types::UserInput* node = nullptr;
135+
if (sustainml::RetCode_t::RETCODE_OK == orchestrator->get_task_data(
136+
task_id, sustainml::NodeID::ID_ORCHESTRATOR, data))
137+
{
138+
node = static_cast<types::UserInput*>(data);
139+
}
140+
return node;
141+
}
142+
129143
types::TaskId* get_task_id(
130144
const sustainml::NodeID node_id,
131145
void* data)

0 commit comments

Comments
 (0)