Skip to content

Commit e48dba9

Browse files
authored
[RLlib] Add LSTM option to run_connector_policy example (old API stack; w/ manual Connector.reset() call). (ray-project#45829)
1 parent 8b224dc commit e48dba9

File tree

3 files changed

+39
-10
lines changed

3 files changed

+39
-10
lines changed

rllib/BUILD

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2052,6 +2052,15 @@ py_test(
20522052
srcs = ["examples/_old_api_stack/connectors/run_connector_policy.py"],
20532053
)
20542054

2055+
py_test(
2056+
name = "examples/_old_api_stack/connectors/run_connector_policy_w_lstm",
2057+
main = "examples/_old_api_stack/connectors/run_connector_policy.py",
2058+
tags = ["team:rllib", "exclusive", "examples", "old_api_stack"],
2059+
size = "small",
2060+
srcs = ["examples/_old_api_stack/connectors/run_connector_policy.py"],
2061+
args = ["--use-lstm"],
2062+
)
2063+
20552064
py_test(
20562065
name = "examples/_old_api_stack/connectors/adapt_connector_policy",
20572066
main = "examples/_old_api_stack/connectors/adapt_connector_policy.py",

rllib/examples/_old_api_stack/connectors/prepare_checkpoint.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,14 @@
44
from ray.rllib.algorithms.sac import SACConfig
55

66

7-
def create_appo_cartpole_checkpoint(output_dir):
7+
def create_appo_cartpole_checkpoint(output_dir, use_lstm=False):
88
# enable_connectors defaults to True. Just trying to be explicit here.
9-
config = APPOConfig().environment("CartPole-v1").env_runners(enable_connectors=True)
9+
config = (
10+
APPOConfig()
11+
.environment("CartPole-v1")
12+
.env_runners(enable_connectors=True)
13+
.training(model={"use_lstm": use_lstm})
14+
)
1015
# Build algorithm object.
1116
algo = config.build()
1217
algo.save(checkpoint_dir=output_dir)

rllib/examples/_old_api_stack/connectors/run_connector_policy.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
and use it in a serving/inference setting.
33
"""
44

5+
import argparse
56
import gymnasium as gym
67
import os
78
import tempfile
@@ -13,6 +14,9 @@
1314
from ray.rllib.policy.policy import Policy
1415
from ray.rllib.utils.policy import local_policy_inference
1516

17+
parser = argparse.ArgumentParser()
18+
parser.add_argument("--use-lstm", action="store_true", help="Add LSTM to the setup.")
19+
1620

1721
def run(checkpoint_path, policy_id):
1822
# __sphinx_doc_begin__
@@ -24,34 +28,45 @@ def run(checkpoint_path, policy_id):
2428

2529
# Run CartPole.
2630
env = gym.make("CartPole-v1")
31+
env_id = "env_1"
2732
obs, info = env.reset()
28-
terminated = truncated = False
29-
step = 0
30-
while not terminated and not truncated:
31-
step += 1
32-
33+
# Run for 2 episodes.
34+
episodes = step = 0
35+
while episodes < 2:
3336
# Use local_policy_inference() to run inference, so we do not have to
3437
# provide policy states or extra fetch dictionaries.
3538
# "env_1" and "agent_1" are dummy env and agent IDs to run connectors with.
3639
policy_outputs = local_policy_inference(
37-
policy, "env_1", "agent_1", obs, explore=False
40+
policy, env_id, "agent_1", obs, explore=False
3841
)
3942
assert len(policy_outputs) == 1
4043
action, _, _ = policy_outputs[0]
41-
print(f"step {step}", obs, action)
44+
print(f"episode {episodes} step {step}", obs, action)
4245

4346
# Step environment forward one more step.
4447
obs, _, terminated, truncated, _ = env.step(action)
48+
step += 1
49+
50+
# If the episode is done, reset the env and our connectors and start a new
51+
# episode.
52+
if terminated or truncated:
53+
episodes += 1
54+
step = 0
55+
obs, info = env.reset()
56+
policy.agent_connectors.reset(env_id)
57+
4558
# __sphinx_doc_end__
4659

4760

4861
if __name__ == "__main__":
62+
args = parser.parse_args()
63+
4964
with tempfile.TemporaryDirectory() as tmpdir:
5065
policy_id = "default_policy"
5166

5267
# Note: this is just for demo purposes.
5368
# Normally, you would use a policy checkpoint from a real training run.
54-
create_appo_cartpole_checkpoint(tmpdir)
69+
create_appo_cartpole_checkpoint(tmpdir, args.use_lstm)
5570
policy_checkpoint_path = os.path.join(
5671
tmpdir,
5772
"policies",

0 commit comments

Comments
 (0)