Skip to content

Commit e95842c

Browse files
committed
openvla policy intergration pull request
1 parent 8c0a23d commit e95842c

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

simpler_env/main_inference.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from simpler_env.evaluation.maniskill2_evaluator import maniskill2_evaluator
88
from simpler_env.policies.octo.octo_server_model import OctoServerInference
99
from simpler_env.policies.rt1.rt1_model import RT1Inference
10-
from simpler_env.policies.openvla.openvla_model import OpenVALInference
10+
from simpler_env.policies.openvla.openvla_model import OpenVLAInference
1111

1212
try:
1313
from simpler_env.policies.octo.octo_model import OctoInference
@@ -56,7 +56,7 @@
5656
)
5757
elif args.policy_model == "openvla":
5858
assert args.ckpt_path is not None
59-
model = OpenVALInference(
59+
model = OpenVLAInference(
6060
saved_model_path=args.ckpt_path,
6161
policy_setup=args.policy_setup,
6262
action_scale=args.action_scale,

simpler_env/policies/openvla/openvla_model.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import cv2 as cv
1212

1313

14-
class OpenVALInference:
14+
class OpenVLAInference:
1515
def __init__(
1616
self,
1717
saved_model_path: str = "openvla/openvla-7b",
@@ -144,7 +144,7 @@ def step(
144144
relative_gripper_action = self.previous_gripper_action - current_gripper_action
145145
self.previous_gripper_action = current_gripper_action
146146

147-
if np.abs(relative_gripper_action) > 0.5 and self.sticky_action_is_on is False:
147+
if np.abs(relative_gripper_action) > 0.5 and (not self.sticky_action_is_on):
148148
self.sticky_action_is_on = True
149149
self.sticky_gripper_action = relative_gripper_action
150150

0 commit comments

Comments
 (0)