
Vectorized/GPU Sim Evaluation / ManiSkill 3 Port #38

Draft: wants to merge 10 commits into base: main
5 changes: 5 additions & 0 deletions .gitignore
@@ -27,3 +27,8 @@ dist/
.vscode

imgui.ini

octo/
checkpoints/
ManiSkill2_real2sim/
videos/
239 changes: 72 additions & 167 deletions README.md

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion simpler_env/__init__.py
@@ -1,5 +1,5 @@
import gymnasium as gym
import mani_skill2_real2sim.envs
import mani_skill.envs

ENVIRONMENTS = [
"google_robot_pick_coke_can",
85 changes: 52 additions & 33 deletions simpler_env/policies/octo/octo_model.py
@@ -9,9 +9,24 @@
import tensorflow as tf
from transformers import AutoTokenizer
from transforms3d.euler import euler2axangle

from functools import partial
from simpler_env.utils.action.action_ensemble import ActionEnsembler
from mani_skill.utils.geometry import rotation_conversions
from mani_skill.utils import common
import torch
from torch.utils import dlpack as torch_dlpack

from jax import dlpack as jax_dlpack
import jax.numpy as jnp

def torch2jax(x_torch):
x_torch = x_torch.contiguous() # https://github.com/google/jax/issues/8082
x_jax = jax_dlpack.from_dlpack(torch_dlpack.to_dlpack(x_torch))
return x_jax

def jax2torch(x_jax):
x_torch = torch_dlpack.from_dlpack(jax_dlpack.to_dlpack(x_jax))
return x_torch
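# Illustrative usage sketch for the two DLPack helpers above (assumes torch and
# jax see the same device; works on CPU or GPU without a host copy):
#   x = torch.arange(6, dtype=torch.float32).reshape(2, 3)
#   x_jax = torch2jax(x)               # zero-copy jax.Array
#   y = jax2torch(x_jax * 2.0)         # back to torch
#   assert torch.allclose(y, x * 2.0)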

class OctoInference:
def __init__(
@@ -55,6 +70,8 @@ def __init__(
self.model = OctoModel.load_pretrained(self.model_type)
self.action_mean = self.model.dataset_statistics[dataset_id]["action"]["mean"]
self.action_std = self.model.dataset_statistics[dataset_id]["action"]["std"]
self.action_mean = jnp.array(self.action_mean)
self.action_std = jnp.array(self.action_std)
else:
raise NotImplementedError()

@@ -86,13 +103,9 @@ def __init__(
self.num_image_history = 0

def _resize_image(self, image: np.ndarray) -> np.ndarray:
image = tf.image.resize(
image,
size=(self.image_size, self.image_size),
method="lanczos3",
antialias=True,
)
image = tf.cast(tf.clip_by_value(tf.round(image), 0, 255), tf.uint8).numpy()
"""resize image to a square image of size self.image_size. image should be shape (B, H, W, 3)"""
image = jax.vmap(partial(jax.image.resize, shape=(self.image_size, self.image_size, 3), method="lanczos3", antialias=True))(image)
image = jnp.clip(jnp.round(image), 0, 255).astype(jnp.uint8)
return image
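# Illustrative shapes (assumption): jax.vmap maps the resize over the batch
# dimension, so a (B, 480, 640, 3) uint8 batch comes back as
# (B, self.image_size, self.image_size, 3) uint8.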

def _add_image_to_history(self, image: np.ndarray) -> None:
@@ -105,17 +118,20 @@ def _add_image_to_history(self, image: np.ndarray) -> None:
self.num_image_history = min(self.num_image_history + 1, self.horizon)

def _obtain_image_history_and_mask(self) -> tuple[np.ndarray, np.ndarray]:
images = np.stack(self.image_history, axis=0)
images = jnp.stack(self.image_history, axis=1)
batch_size = images.shape[0]
horizon = len(self.image_history)
pad_mask = np.ones(horizon, dtype=np.float64) # note: this should be of float type, not a bool type
pad_mask[: horizon - min(horizon, self.num_image_history)] = 0
pad_mask = jnp.ones((batch_size, horizon), dtype=jnp.float32) # note: this should be of float type, not a bool type
pad_mask = pad_mask.at[:, : horizon - min(horizon, self.num_image_history)].set(0)
# pad_mask = np.ones(self.horizon, dtype=np.float64) # note: this should be of float type, not a bool type
# pad_mask[:self.horizon - self.num_image_history] = 0
return images, pad_mask
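# Illustrative shapes (assumption): with num_envs B = 4, horizon = 2 and only
# one real frame observed so far,
#   images.shape == (4, 2, self.image_size, self.image_size, 3)
#   pad_mask     == [[0., 1.], [0., 1.], [0., 1.], [0., 1.]]
# i.e. the repeated (padded) oldest frame is masked out.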

def reset(self, task_description: str) -> None:
self.task = self.model.create_tasks(texts=[task_description])
self.task_description = task_description
def reset(self, task_descriptions: str) -> None:
if isinstance(task_descriptions, str):
task_descriptions = [task_descriptions]
self.task = self.model.create_tasks(texts=task_descriptions)
self.task_description = task_descriptions
self.image_history.clear()
if self.action_ensemble:
self.action_ensembler.reset()
@@ -130,7 +146,7 @@ def reset(self, task_description: str) -> None:
def step(self, image: np.ndarray, task_description: Optional[str] = None, *args, **kwargs) -> tuple[dict[str, np.ndarray], dict[str, np.ndarray]]:
"""
Input:
image: np.ndarray of shape (H, W, 3), uint8
image: np.ndarray/torch tensor of shape (B, H, W, 3), uint8
task_description: Optional[str], task description; if different from previous task description, policy state is reset
Output:
raw_action: dict; raw policy action output
@@ -145,45 +161,48 @@ def step(self, image: np.ndarray, task_description: Optional[str] = None, *args,
# task description has changed; reset the policy state
self.reset(task_description)

assert image.dtype == np.uint8
# assert image.dtype == np.uint8
assert len(image.shape) == 4, "image shape should be (batch_size, height, width, 3)"
batch_size = image.shape[0]
image = torch2jax(image)
image = self._resize_image(image)
self._add_image_to_history(image)
images, pad_mask = self._obtain_image_history_and_mask()
images, pad_mask = images[None], pad_mask[None]

# we need use a different rng key for each model forward step; this has a large impact on model performance
self.rng, key = jax.random.split(self.rng) # each shape [2,]
# print("octo local rng", self.rng, key)

input_observation = {"image_primary": images, "pad_mask": pad_mask}
# images.shape (batch, horizon, height, width, 3), pad_mask.shape (batch, horizon)
norm_raw_actions = self.model.sample_actions(
input_observation,
self.task,
rng=key,
)
raw_actions = norm_raw_actions * self.action_std[None] + self.action_mean[None]
raw_actions = raw_actions[0] # remove batch, becoming (action_pred_horizon, action_dim)

assert raw_actions.shape == (self.pred_action_horizon, 7)
assert raw_actions.shape == (batch_size, self.pred_action_horizon, 7)
if self.action_ensemble:
raw_actions = self.action_ensembler.ensemble_action(raw_actions)
raw_actions = raw_actions[None] # [1, 7]

raw_actions = jax2torch(raw_actions)
raw_action = {
"world_vector": np.array(raw_actions[0, :3]),
"rotation_delta": np.array(raw_actions[0, 3:6]),
"open_gripper": np.array(raw_actions[0, 6:7]), # range [0, 1]; 1 = open; 0 = close
"world_vector": raw_actions[:, :3],
"rotation_delta": raw_actions[:, 3:6],
"open_gripper": raw_actions[:, 6:7], # range [0, 1]; 1 = open; 0 = close
}

# process raw_action to obtain the action to be sent to the maniskill2 environment
raw_action = common.to_tensor(raw_action)

# TODO (stao): check if we need torch float 64s.
# process raw_action to obtain the action to be sent to the maniskill environment
action = {}
action["world_vector"] = raw_action["world_vector"] * self.action_scale
action_rotation_delta = np.asarray(raw_action["rotation_delta"], dtype=np.float64)
roll, pitch, yaw = action_rotation_delta
action_rotation_ax, action_rotation_angle = euler2axangle(roll, pitch, yaw)
action_rotation_axangle = action_rotation_ax * action_rotation_angle
action["rot_axangle"] = action_rotation_axangle * self.action_scale

# action_rotation_delta = np.asarray(raw_action["rotation_delta"], dtype=np.float64)
# roll, pitch, yaw = action_rotation_delta
# action_rotation_ax, action_rotation_angle = euler2axangle(roll, pitch, yaw)
# action_rotation_axangle = action_rotation_ax * action_rotation_angle
# action["rot_axangle"] = action_rotation_axangle * self.action_scale
# TODO: is there a better conversion from euler angles to axis angle?
action["rot_axangle"] = rotation_conversions.matrix_to_axis_angle(rotation_conversions.euler_angles_to_matrix(raw_action["rotation_delta"], "XYZ"))
if self.policy_setup == "google_robot":
current_gripper_action = raw_action["open_gripper"]

192 changes: 192 additions & 0 deletions simpler_env/real2sim_eval_maniskill3.py
@@ -0,0 +1,192 @@
from collections import defaultdict
import json
import os
import signal
import time
import numpy as np
from typing import Annotated, Optional

import torch
import tree
from mani_skill.utils import common
from mani_skill.utils import visualization
from mani_skill.utils.visualization.misc import images_to_video
signal.signal(signal.SIGINT, signal.SIG_DFL) # allow ctrl+c
from simpler_env.utils.env.observation_utils import get_image_from_maniskill3_obs_dict

import gymnasium as gym
import numpy as np
from mani_skill.envs.tasks.digital_twins.bridge_dataset_eval import *
from mani_skill.envs.sapien_env import BaseEnv
import tyro
from dataclasses import dataclass
from pathlib import Path

@dataclass
class Args:
"""
This is a script to evaluate policies on real2sim environments. Example command to run:

XLA_PYTHON_CLIENT_PREALLOCATE=false python real2sim_eval_maniskill3.py \
--model="octo-small" -e "PutEggplantInBasketScene-v1" -s 0 --num-episodes 192 --num-envs 64
"""


env_id: Annotated[str, tyro.conf.arg(aliases=["-e"])] = "PutCarrotOnPlateInScene-v1"
"""The environment ID of the task you want to simulate. Can be one of
PutCarrotOnPlateInScene-v1, PutSpoonOnTableClothInScene-v1, StackGreenCubeOnYellowCubeBakedTexInScene-v1, PutEggplantInBasketScene-v1"""

shader: str = "default"

num_envs: int = 1
"""Number of environments to run. With more than 1 environment the environment will use the GPU backend
which runs faster enabling faster large-scale evaluations. Note that the overall behavior of the simulation
will be slightly different between CPU and GPU backends."""

num_episodes: int = 100
"""Number of episodes to run and record evaluation metrics over"""

record_dir: str = "videos"
"""The directory to save videos and results"""

model: Optional[str] = None
"""The model to evaluate on the given environment. Can be one of octo-base, octo-small, rt-1x. If not given, random actions are sampled."""

ckpt_path: str = ""
"""Checkpoint path for models. Only used for RT models"""

seed: Annotated[int, tyro.conf.arg(aliases=["-s"])] = 0
"""Seed the model and environment. Default seed is 0"""

reset_by_episode_id: bool = True
"""Whether to reset by fixed episode ids instead of random sampling initial states."""

info_on_video: bool = False
"""Whether to write info text onto the video"""

save_video: bool = True
"""Whether to save videos"""

debug: bool = False

def main():
args = tyro.cli(Args)
if args.seed is not None:
np.random.seed(args.seed)


sensor_configs = dict()
sensor_configs["shader_pack"] = args.shader
env: BaseEnv = gym.make(
args.env_id,
obs_mode="rgb+segmentation",
num_envs=args.num_envs,
sensor_configs=sensor_configs
)
sim_backend = 'gpu' if env.device.type == 'cuda' else 'cpu'

# Set up the policy inference model
model = None
try:

policy_setup = "widowx_bridge"
if args.model is None:
pass
else:
from simpler_env.policies.rt1.rt1_model import RT1Inference
from simpler_env.policies.octo.octo_model import OctoInference
if args.model == "octo-base" or args.model == "octo-small":
model = OctoInference(model_type=args.model, policy_setup=policy_setup, init_rng=args.seed, action_scale=1)
elif args.model == "rt-1x":
ckpt_path=args.ckpt_path
model = RT1Inference(
saved_model_path=ckpt_path,
policy_setup=policy_setup,
action_scale=1,
)
elif args.model is not None:
raise ValueError(f"Model {args.model} does not exist / is not supported.")
except ImportError as e:
if args.model is not None:
raise RuntimeError("SIMPLER Env Policy Inference is not installed") from e

model_name = args.model if args.model is not None else "random"
if model_name == "random":
print("Using random actions.")
exp_dir = os.path.join(args.record_dir, f"real2sim_eval/{model_name}_{args.env_id}")
Path(exp_dir).mkdir(parents=True, exist_ok=True)

eval_metrics = defaultdict(list)
eps_count = 0

print(f"Running Real2Sim Evaluation of model {args.model} on environment {args.env_id}")
print(f"Using {args.num_envs} environments on the {sim_backend} simulation backend")

timers = {"env.step+inference": 0, "env.step": 0, "inference": 0, "total": 0}
total_start_time = time.time()

while eps_count < args.num_episodes:
seed = args.seed + eps_count
obs, _ = env.reset(seed=seed, options={"episode_id": torch.tensor([seed + i for i in range(args.num_envs)])})
instruction = env.unwrapped.get_language_instruction()
print("instruction:", instruction[0])
if model is not None:
model.reset(instruction)
images = []
predicted_terminated, truncated = False, False
images.append(get_image_from_maniskill3_obs_dict(env, obs))
elapsed_steps = 0
while not (predicted_terminated or truncated):
if model is not None:
start_time = time.time()
raw_action, action = model.step(images[-1], instruction)
action = torch.cat([action["world_vector"], action["rot_axangle"], action["gripper"]], dim=1)
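# Illustrative note (assumption): the concatenated action has shape
# (num_envs, 7): 3 translation + 3 axis-angle rotation + 1 gripper.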
timers["inference"] += time.time() - start_time
else:
action = env.action_space.sample()

if elapsed_steps > 0:
if args.save_video and args.info_on_video:
for i in range(len(images[-1])):
images[-1][i] = visualization.put_info_on_image(images[-1][i], tree.map_structure(lambda x: x[i], info))

start_time = time.time()
obs, reward, terminated, truncated, info = env.step(action)
timers["env.step"] += time.time() - start_time
elapsed_steps += 1
info = common.to_numpy(info)

truncated = bool(truncated.any()) # note that all envs truncate and terminate at the same time.
images.append(get_image_from_maniskill3_obs_dict(env, obs))

for k, v in info.items():
eval_metrics[k].append(v.flatten())
if args.save_video:
for i in range(len(images[-1])):
images_to_video([img[i].cpu().numpy() for img in images], exp_dir, f"{sim_backend}_eval_{seed + i}_success={info['success'][i].item()}", fps=10, verbose=True)
eps_count += args.num_envs
if args.num_envs == 1:
print(f"Evaluated episode {eps_count}. Seed {seed}. Results after {eps_count} episodes:")
else:
print(f"Evaluated {args.num_envs} episodes, seeds {seed} to {eps_count}. Results after {eps_count} episodes:")
for k, v in eval_metrics.items():
print(f"{k}: {np.mean(v)}")
# Print timing information
timers["total"] = time.time() - total_start_time
timers["env.step+inference"] = timers["env.step"] + timers["inference"]
mean_metrics = {k: np.mean(v) for k, v in eval_metrics.items()}
mean_metrics["total_episodes"] = eps_count
mean_metrics["time/episodes_per_second"] = eps_count / timers["total"]
print("Timing Info:")
for key, value in timers.items():
mean_metrics[f"time/{key}"] = value
print(f"{key}: {value:.2f} seconds")
metrics_path = os.path.join(exp_dir, f"{sim_backend}_eval_metrics.json")
if sim_backend == "gpu":
metrics_path = metrics_path.replace("gpu", f"gpu_{args.num_envs}_envs")
with open(metrics_path, "w") as f:
json.dump(mean_metrics, f, indent=4)
print(f"Evaluation complete. Results saved to {exp_dir}. Metrics saved to {metrics_path}")

if __name__ == "__main__":
main()
17 changes: 10 additions & 7 deletions simpler_env/utils/action/action_ensemble.py
@@ -1,7 +1,7 @@
from collections import deque

import numpy as np

import jax.numpy as jnp

class ActionEnsembler:
def __init__(self, pred_action_horizon, action_ensemble_temp=0.0):
@@ -18,13 +18,16 @@ def ensemble_action(self, cur_action):
if cur_action.ndim == 1:
curr_act_preds = np.stack(self.action_history)
else:
curr_act_preds = np.stack(
[pred_actions[i] for (i, pred_actions) in zip(range(num_actions - 1, -1, -1), self.action_history)]
)
curr_act_preds = jnp.stack(
[pred_actions[:, i] for (i, pred_actions) in zip(range(num_actions - 1, -1, -1), self.action_history)]
) # shape (1 to self.pred_action_horizon, batch_size, action_dim)
# more recent predictions get exponentially *less* weight than older predictions
weights = np.exp(-self.action_ensemble_temp * np.arange(num_actions))
weights = jnp.exp(-self.action_ensemble_temp * jnp.arange(num_actions))
weights = weights / weights.sum()
# compute the weighted average across all predictions for this timestep
cur_action = np.sum(weights[:, None] * curr_act_preds, axis=0)

# Expand weights to match batch and action dimensions
weights_expanded = weights[:, None, None]

# Apply weights across all batches and sum
cur_action = jnp.sum(weights_expanded * curr_act_preds, axis=0)
return cur_action
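# Worked example of the weighting above (illustrative): with num_actions = 4 and
# action_ensemble_temp = 0.5,
#   raw  = exp(-0.5 * [0, 1, 2, 3]) ~= [1.000, 0.607, 0.368, 0.223]
#   norm = raw / raw.sum()          ~= [0.455, 0.276, 0.167, 0.102]
# so the oldest prediction for the current timestep gets the largest weight
# (with the default temp of 0.0 the average is uniform).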