Skip to content

Commit da64d7e

Browse files
committed
support ray job submit to run p2p agent
1 parent 16ecee4 commit da64d7e

File tree

2 files changed

+18
-10
lines changed

2 files changed

+18
-10
lines changed

matrix/agents/examples/ts_interaction.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ async def init(
5656
logger: logging.Logger,
5757
) -> None:
5858
task = metadata["task"]
59-
self._id = (task["question_id"],)
59+
self._id = task["question_id"]
6060
await super().init(
6161
simulation_id, first_agent, sink, metadata, resources, logger
6262
)

matrix/agents/p2p_agents.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,7 @@ def done(self):
453453
return result
454454

455455

456-
@ray.remote
456+
# @ray.remote
457457
class Sink(AgentActor):
458458
def __init__(
459459
self,
@@ -782,10 +782,14 @@ async def run_simulation(self):
782782
logger.info(f"Configuration:\n{OmegaConf.to_yaml(self.cfg, resolve=True)}")
783783
cli = Cli(**self.cfg.matrix)
784784
if not ray.is_initialized():
785-
ray.init(
786-
address=get_ray_address(cli.cluster.cluster_info()), # type: ignore[arg-type]
787-
log_to_driver=True,
788-
)
785+
if os.environ.get("RAY_ADDRESS"):
786+
# already inside ray
787+
ray.init()
788+
else:
789+
ray.init(
790+
address=get_ray_address(cli.cluster.cluster_info()), # type: ignore[arg-type]
791+
log_to_driver=True,
792+
)
789793

790794
# Load tasks
791795
self.data_loader = instantiate(self.cfg.dataset)
@@ -862,10 +866,14 @@ def main(cfg: DictConfig):
862866
setup_logging(logger, cfg.get("debug", False))
863867
cli = Cli(**cfg.matrix)
864868
if not ray.is_initialized():
865-
ray.init(
866-
address=get_ray_address(cli.cluster.cluster_info()), # type: ignore[arg-type]
867-
log_to_driver=True,
868-
)
869+
if os.environ.get("RAY_ADDRESS"):
870+
# already inside ray
871+
ray.init()
872+
else:
873+
ray.init(
874+
address=get_ray_address(cli.cluster.cluster_info()), # type: ignore[arg-type]
875+
log_to_driver=True,
876+
)
869877

870878
logger.info(f"Launching {num_tasks} Ray actors for parallel processing")
871879

0 commit comments

Comments
 (0)