diff --git a/produce_rf_plot.py b/produce_rf_plot.py deleted file mode 100644 index 065ae11..0000000 --- a/produce_rf_plot.py +++ /dev/null @@ -1,114 +0,0 @@ -import json -import os -import sys -from pathlib import Path -from typing import Any, Optional - -import numpy as np -import numpy.typing as npt -import yaml -from matplotlib import pyplot as plt - - -def rescale(x: npt.NDArray[Any], min: float = 0, max: float = 1) -> npt.NDArray[Any]: - return ((x - x.min()) / (x.max() - x.min())) * (max - min) + min - - -def reshape_images( - arr: npt.NDArray[Any], - n_rows: Optional[int] = None, - n_cols: Optional[int] = None, - whitespace: float = 0.1, - rescale_individ: bool = False, -): - n, _, w, h = arr.shape - whitespace_pix = np.round(whitespace * max(w, h)).astype(int) - if n_rows is None and n_cols is None: - n_rows = 1 - if n_rows is None: - n_rows = (n + n_cols - 1) // n_cols - if n_cols is None: - n_cols = (n + n_rows - 1) // n_rows - - # Calculate the total width and height of the final image - total_width = n_cols * w + (n_cols - 1) * whitespace_pix - total_height = n_rows * h + (n_rows - 1) * whitespace_pix - - # Create a new image with the calculated dimensions - final_image = np.full( - (3, total_height, total_width), - 1 if rescale_individ else arr.max(), - dtype=np.float32, - ) - - # Populate the final image with the individual images - for i in range(n): - row = i // n_cols - col = i % n_cols - x1 = col * (w + whitespace_pix) - y1 = row * (h + whitespace_pix) - x2 = x1 + w - y2 = y1 + h - final_image[:, y1:y2, x1:x2] = rescale(arr[i]) if rescale_individ else arr[i] - - if not rescale_individ: - final_image = rescale(final_image) - - return np.moveaxis(final_image, 0, -1) - - -def produce_image(experiment_path: Path, out_dir: Path, last=True): - rf_dir = experiment_path / "data/analyses" - config_path = experiment_path / "config/config.yaml" - - index = -1 if last else 3 - - with open(config_path) as file: - config = yaml.safe_load(file) - - rf_files = 
# --- retinal_rl/analysis/attribution.py --------------------------------------


def l1_attribution(
    brain: "Brain",
    stimuli: dict[str, torch.Tensor],
    target_circuit: str,
    target_output_index: int = 0,
) -> dict[str, torch.Tensor]:
    """Return the gradient of the target output's L1 norm w.r.t. each stimulus.

    The brain is run once on ``stimuli``; the selected output
    (``brain(stimuli)[target_circuit][target_output_index]``) is compared
    against zeros with ``torch.nn.L1Loss`` and that loss is differentiated
    with respect to every stimulus tensor.

    Args:
        brain: Callable mapping a stimulus dict to a dict of circuit outputs.
        stimuli: Input tensors keyed by the name the brain expects.
        target_circuit: Key of the circuit whose output is attributed.
        target_output_index: Index into that circuit's output sequence.

    Returns:
        One detached CPU gradient per stimulus key. Stimuli that do not
        require grad, or that the target output does not depend on, get an
        all-zero tensor of their own shape.
    """
    output = brain(stimuli)[target_circuit][target_output_index]
    loss = torch.nn.L1Loss()(output, torch.zeros_like(output))

    # torch.autograd.grad (instead of loss.backward() + reading .grad) keeps
    # repeated calls on the same tensors from accumulating into .grad and
    # reports unused inputs as None rather than leaving .grad unset.
    grad_keys = [key for key, value in stimuli.items() if value.requires_grad]
    grads_by_key = {}
    if grad_keys and loss.requires_grad:
        grads = torch.autograd.grad(
            loss, [stimuli[key] for key in grad_keys], allow_unused=True
        )
        grads_by_key = dict(zip(grad_keys, grads))

    input_grads: dict[str, torch.Tensor] = {}
    for key, value in stimuli.items():
        grad = grads_by_key.get(key)
        if grad is None:
            # e.g. an "rnn_state" entry fed to a purely feed-forward brain
            grad = torch.zeros_like(value)
        input_grads[key] = grad.detach().cpu()
    return input_grads


def captum_attribution(
    brain: "Brain",
    stimuli: dict[str, torch.Tensor],
    target_circuit: str,
    target_output_index: int = 0,
) -> dict[str, torch.Tensor]:
    """Attribute the target output to the stimuli with Captum's ``InputXGradient``.

    Args:
        brain: Callable mapping a stimulus dict to a dict of circuit outputs.
        stimuli: Input tensors keyed by the name the brain expects.
        target_circuit: Key of the circuit whose output is attributed.
        target_output_index: Index into that circuit's output sequence.

    Returns:
        One detached CPU attribution tensor per stimulus key.
    """
    # Captum works on positional tensors, so freeze the key order once and use
    # it both to build the input tuple and to map the results back.
    stimuli_keys = list(stimuli)

    def _forward(*args: torch.Tensor) -> torch.Tensor:
        # Rebuild the dict the brain expects from Captum's positional inputs.
        assert len(args) == len(stimuli_keys)
        responses = brain(dict(zip(stimuli_keys, args)))
        return responses[target_circuit][target_output_index]

    attributions = InputXGradient(_forward).attribute(
        tuple(stimuli[key] for key in stimuli_keys)
    )
    return {
        key: attribution.detach().cpu()
        for key, attribution in zip(stimuli_keys, attributions)
    }


# Maps the ``method`` argument of ``analyze`` to its implementation.
# "attribution" is the original name of the Captum-based method;
# "input_x_gradient" is a more descriptive alias for the same function.
ATTRIBUTION_METHODS = {
    "l1": l1_attribution,
    "attribution": captum_attribution,
    "input_x_gradient": captum_attribution,
}


def analyze(
    brain: "Brain",
    stimuli: dict[str, torch.Tensor],
    target_circuit: str,
    target_output_index: int = 0,
    method: str = "l1",
    sum_channels: bool = True,
    rescale_per_frame: bool = False,
) -> dict[str, torch.Tensor]:
    """Compute input attributions for one output of ``brain``.

    The caller's tensors are left untouched: attribution runs on detached
    views, so neither ``requires_grad`` nor ``.grad`` of the passed stimuli
    changes. The brain's training mode and every parameter's individual
    ``requires_grad`` flag are restored afterwards, also when the attribution
    raises. Gradient computation is enabled locally, so this may be called
    from inside a ``torch.no_grad()`` block.

    Args:
        brain: Module mapping a stimulus dict to a dict of circuit outputs.
        stimuli: Input tensors keyed by the name the brain expects.
        target_circuit: Key of the circuit whose output is attributed.
        target_output_index: Index into that circuit's output sequence.
        method: Key into ``ATTRIBUTION_METHODS``.
        sum_channels: Sum every attribution over dim 1 (kept as a size-1 dim).
            Assumes dim 1 is the channel axis -- TODO confirm for non-image
            stimuli such as RNN states.
        rescale_per_frame: Rescale each entry along dim 0 with
            ``rescale_zero_one`` independently.

    Returns:
        One CPU attribution tensor per stimulus key.

    Raises:
        ValueError: If ``method`` is not a key of ``ATTRIBUTION_METHODS``.
    """
    if method not in ATTRIBUTION_METHODS:
        raise ValueError(f"Unknown attribution method: {method}")

    was_training = brain.training
    # Remember each parameter's own flag; restoring a single shared value
    # would silently unfreeze deliberately frozen parameters.
    param_flags = [(param, param.requires_grad) for param in brain.parameters()]

    # Detached views share storage with the originals but carry their own
    # autograd state. Only floating-point tensors can require gradients.
    inputs = {
        key: value.detach().requires_grad_(value.is_floating_point())
        for key, value in stimuli.items()
    }

    try:
        with torch.enable_grad():
            # Train mode is kept from the original implementation, which
            # states it is required to compute gradients (presumably for
            # layers whose backward pass is unavailable in eval mode --
            # TODO confirm).
            brain.train()
            # Gradients are only needed w.r.t. the inputs, not the weights.
            brain.requires_grad_(False)
            input_grads = ATTRIBUTION_METHODS[method](
                brain, inputs, target_circuit, target_output_index
            )
    finally:
        for param, flag in param_flags:
            param.requires_grad_(flag)
        brain.train(was_training)

    if sum_channels:
        input_grads = {
            key: grad.sum(dim=1, keepdim=True) for key, grad in input_grads.items()
        }
    if rescale_per_frame:
        for grad in input_grads.values():
            for frame in range(grad.shape[0]):
                grad[frame] = rescale_zero_one(grad[frame])
    return input_grads


def plot():  # -> Figure:
    """Placeholder for plotting attribution results; not implemented yet."""
    # TODO: Implement plotting logic
    raise NotImplementedError


# --- retinal_rl/util.py (definitions added by this patch) --------------------

ArrayLike = TypeVar("ArrayLike", np.ndarray, torch.Tensor)


def rescale_zero_one(
    x: ArrayLike, min: Optional[float] = None, max: Optional[float] = None
) -> ArrayLike:
    """Linearly rescale ``x`` so that ``min`` maps to 0 and ``max`` to about 1.

    Args:
        x: NumPy array or torch tensor to rescale.
        min: Value mapped to 0; defaults to the minimum of ``x``.
        max: Value mapped to (just under) 1; defaults to the maximum of ``x``.

    Returns:
        ``(x - min) / (max - min + 1e-8)``, of the same container type as
        ``x``. The ``1e-8`` keeps constant inputs from dividing by zero; such
        inputs map to all zeros.
    """
    if min is None:
        min = np.min(x) if isinstance(x, np.ndarray) else torch.min(x).item()
    if max is None:
        max = np.max(x) if isinstance(x, np.ndarray) else torch.max(x).item()
    return (x - min) / (max - min + 1e-8)
with open(rf_dir / cur_file) as f: - rf = json.load(f) + if cur_file.endswith(".json"): + with open(rf_dir / cur_file) as f: + rf = json.load(f) + else: + rf = np.load(rf_dir / cur_file, allow_pickle=True) comp_layer_rfs = [] for i, (layer, layer_rfs) in enumerate(rf.items()): @@ -182,7 +183,11 @@ def parse_args(argv: list[str]): experiments_path, out_dir, anim, fast = parse_args(sys.argv) -for experiment_path in experiments_path.iterdir(): +if (experiments_path / "data").exists(): + _iter = [experiments_path] +else: + _iter = experiments_path.iterdir() +for experiment_path in _iter: try: print(experiment_path) if anim: diff --git a/runner/scripts/produce_video.py b/runner/scripts/produce_video.py index bac66ef..2d6d96f 100644 --- a/runner/scripts/produce_video.py +++ b/runner/scripts/produce_video.py @@ -1,12 +1,13 @@ -import sys +import argparse import time from collections import deque +from enum import Enum from pathlib import Path -from typing import Optional, Tuple import numpy as np import torch from omegaconf import OmegaConf +from sample_factory.algo.learning.learner import Learner from sample_factory.algo.sampling.batched_sampling import preprocess_actions from sample_factory.algo.utils.action_distributions import argmax_actions from sample_factory.algo.utils.env_info import extract_env_info @@ -17,115 +18,208 @@ from sample_factory.enjoy import ( load_state_dict, make_env, - render_frame, - visualize_policy_inputs, ) from sample_factory.huggingface.huggingface_utils import ( - generate_model_card, generate_replay_video, - push_to_hf, ) from sample_factory.model.actor_critic import create_actor_critic from sample_factory.model.model_utils import get_rnn_size from sample_factory.utils.typing import Config, StatusCode -from sample_factory.utils.utils import experiment_dir, log +from sample_factory.utils.utils import log +from retinal_rl.analysis.attribution import analyze as attribution_analyze +from retinal_rl.rl.sample_factory.models import 
SampleFactoryBrain +from retinal_rl.util import rescale_zero_one from runner.frameworks.rl.sf_framework import SFFramework OmegaConf.register_new_resolver("eval", eval) -def create_video(experiment_path: Path): +class VideoType(str, Enum): + RAW = "RAW" + AUGMENTED = "AUGMENTED" + DECODED = "DECODED" + VALUE_MASK = "VALUE_MASK" + + +def video_type(value: str) -> VideoType: + try: + return VideoType(value) + except ValueError: + allowed = ", ".join([e.value for e in VideoType]) + raise argparse.ArgumentTypeError( + f"invalid choice: {value!r} (choose from: {allowed})" + ) + + +def parse_args(argv: list[str] | None = None) -> tuple[Path, list[VideoType], bool]: + parser = argparse.ArgumentParser( + description="Select zero or more VideoTypes and a boolean flag." + ) + + parser.add_argument( + "-e", "--experiment_path", type=Path, help="Path to the experiment directory." + ) + + parser.add_argument( + "-t", + "--type", + metavar="VIDTYPE", + type=video_type, + nargs="+", + default=["RAW"], + help="Zero or more video types. 
Allowed: " + + ", ".join([e.value for e in VideoType]), + ) + + # Single boolean flag: present -> True, absent -> False + parser.add_argument( + "--actor_frame_rate", + action="store_true", + help="Produce videos at the frame rate the actor operates at (will display only the frames the actor actually sees, typically 1/4 of the original frame rate).", + ) + + parser_args = parser.parse_args(argv) + return parser_args.experiment_path, parser_args.type, parser_args.actor_frame_rate + + +def get_checkpoint_name(experiment_cfg) -> str: + policy_id = experiment_cfg.policy_index + name_prefix = dict(latest="checkpoint", best="best")[ + experiment_cfg.load_checkpoint_kind + ] + checkpoints = Learner.get_checkpoints( + Learner.checkpoint_dir(experiment_cfg, policy_id), f"{name_prefix}_*" + ) + return checkpoints[-1] + + +def create_video( + experiment_path: Path, video_types: list[VideoType], actor_frame_rate: bool +): # Load the config file - cfg = OmegaConf.load(experiment_path / "config" / "config.yaml") - cfg.path.run_dir = experiment_path + experiment_cfg = OmegaConf.load(experiment_path / "config" / "config.yaml") + experiment_cfg.path.run_dir = experiment_path + + experiment_cfg.logging.use_wandb = False + + framework = SFFramework(experiment_cfg, "cache") + custom_enjoy(framework.sf_cfg, video_types, actor_frame_rate) - cfg.logging.use_wandb = False - cfg.samplefactory.save_video = True - cfg.samplefactory.no_render = True - framework = SFFramework(cfg, "cache") - custom_enjoy(framework.sf_cfg) +def get_frames( + actor_critic: SampleFactoryBrain, obs, rnn_states +) -> dict[VideoType, torch.Tensor]: + normalized_obs = prepare_and_normalize_obs(actor_critic, obs) + responses = actor_critic.brain( + {"vision": normalized_obs["obs"], "rnn_state": rnn_states} + ) -def _rescale_zero_one(x, min: Optional[float] = None, max: Optional[float] = None): - if min is None: - min = np.min(x) - if max is None: - max = np.max(x) - return (x - min) / (max - min) + # find if decoder 
exists by matching output shape to input shape + # TODO: Use loss definition instead and pass the key to the function + decoder_key = None + for response_key, response in responses.items(): + if response[0].shape == obs["obs"].shape and response_key != "vision": + decoder_key = response_key + break + + cur_frames: dict[VideoType, torch.Tensor] = { + VideoType.RAW: obs["obs"].detach(), + VideoType.AUGMENTED: normalized_obs["obs"].detach(), + VideoType.DECODED: responses[decoder_key][0].detach() + if decoder_key is not None + else None, + VideoType.VALUE_MASK: attribution_analyze( + actor_critic.brain, + {"vision": normalized_obs["obs"], "rnn_state": rnn_states}, + target_circuit="critic", + method="l1", + sum_channels=True, + rescale_per_frame=False, + )["vision"], + } + return cur_frames def custom_enjoy( # noqa: C901 # TODO: Properly implement this anyway - cfg: Config, -) -> Tuple[StatusCode, float]: + experiment_cfg: Config, + video_types: list[VideoType], + actor_frame_rate: bool, +) -> tuple[StatusCode, float]: verbose = False - cfg = load_from_checkpoint(cfg) + experiment_cfg = load_from_checkpoint(experiment_cfg) eval_env_frameskip: int = ( - cfg.env_frameskip if cfg.eval_env_frameskip is None else cfg.eval_env_frameskip + experiment_cfg.env_frameskip + if experiment_cfg.eval_env_frameskip is None + else experiment_cfg.eval_env_frameskip ) assert ( - cfg.env_frameskip % eval_env_frameskip == 0 - ), f"{cfg.env_frameskip=} must be divisible by {eval_env_frameskip=}" - render_action_repeat: int = cfg.env_frameskip // eval_env_frameskip - cfg.env_frameskip = cfg.eval_env_frameskip = eval_env_frameskip + experiment_cfg.env_frameskip % eval_env_frameskip == 0 + ), f"{experiment_cfg.env_frameskip=} must be divisible by {eval_env_frameskip=}" + render_action_repeat: int = experiment_cfg.env_frameskip // eval_env_frameskip + experiment_cfg.env_frameskip = experiment_cfg.eval_env_frameskip = ( + eval_env_frameskip + ) log.debug( - f"Using frameskip {cfg.env_frameskip} 
and {render_action_repeat=} for evaluation" + f"Using frameskip {experiment_cfg.env_frameskip} and {render_action_repeat=} for evaluation" ) - cfg.num_envs = 1 + experiment_cfg.num_envs = 1 - render_mode = "human" - if cfg.save_video: - render_mode = "rgb_array" - elif cfg.no_render: - render_mode = None + render_mode = "rgb_array" - env = make_env(cfg, render_mode=render_mode) - env_info = extract_env_info(env, cfg) + env = make_env(experiment_cfg, render_mode=render_mode) + env_info = extract_env_info(env, experiment_cfg) if hasattr(env.unwrapped, "reset_on_init"): # reset call ruins the demo recording for VizDoom env.unwrapped.reset_on_init = False - actor_critic = create_actor_critic(cfg, env.observation_space, env.action_space) + actor_critic = create_actor_critic( + experiment_cfg, env.observation_space, env.action_space + ) actor_critic.eval() - device = torch.device("cpu" if cfg.device == "cpu" else "cuda") + device = torch.device("cpu" if experiment_cfg.device == "cpu" else "cuda") actor_critic.model_to_device(device) - load_state_dict(cfg, actor_critic, device) + load_state_dict(experiment_cfg, actor_critic, device) episode_rewards = [deque([], maxlen=100) for _ in range(env.num_agents)] true_objectives = [deque([], maxlen=100) for _ in range(env.num_agents)] num_frames = 0 - last_render_start = time.time() - def max_frames_reached(frames: int) -> bool: - return cfg.max_num_frames is not None and frames > cfg.max_num_frames + return ( + experiment_cfg.max_num_frames is not None + and frames > experiment_cfg.max_num_frames + ) reward_list = [] obs, infos = env.reset() action_mask = obs.pop("action_mask").to(device) if "action_mask" in obs else None rnn_states = torch.zeros( - [env.num_agents, get_rnn_size(cfg)], dtype=torch.float32, device=device + [env.num_agents, get_rnn_size(experiment_cfg)], + dtype=torch.float32, + device=device, ) episode_reward = None finished_episode = [False for _ in range(env.num_agents)] - video_frames = [] + video_frames = 
{vid_type: [] for vid_type in video_types} num_episodes = 0 with torch.no_grad(): while not max_frames_reached(num_frames): normalized_obs = prepare_and_normalize_obs(actor_critic, obs) - if not cfg.no_render: - visualize_policy_inputs(normalized_obs) + cur_frames: dict[VideoType, torch.Tensor] = get_frames( + actor_critic, obs, rnn_states + ) policy_outputs = actor_critic( normalized_obs, rnn_states, action_mask=action_mask ) @@ -133,7 +227,7 @@ def max_frames_reached(frames: int) -> bool: # sample actions from the distribution by default actions = policy_outputs["actions"] - if cfg.eval_deterministic: + if experiment_cfg.eval_deterministic: action_distribution = actor_critic.action_distribution() actions = argmax_actions(action_distribution) @@ -144,19 +238,23 @@ def max_frames_reached(frames: int) -> bool: rnn_states = policy_outputs["new_rnn_states"] - for _ in range(render_action_repeat): + for _i_repeat in range(render_action_repeat): obs, rew, terminated, truncated, infos = env.step(actions) need_video_frame = ( - len(video_frames) < cfg.video_frames - or cfg.video_frames < 0 + len(next(iter(video_frames.items()))) < experiment_cfg.video_frames + or experiment_cfg.video_frames < 0 and num_episodes == 0 ) if need_video_frame: # frame = env.render() - normalized_obs = prepare_and_normalize_obs(actor_critic, obs) - frame = normalized_obs["obs"] - video_frames.append(frame[0].movedim(0, -1).cpu().numpy()) + if not actor_frame_rate and _i_repeat > 0: + cur_frames = get_frames(actor_critic, obs, rnn_states) + + for _vid_type in video_types: + video_frames[_vid_type].append( + cur_frames[_vid_type][0].movedim(0, -1).cpu().numpy() + ) action_mask = ( obs.pop("action_mask").to(device) if "action_mask" in obs else None @@ -196,11 +294,13 @@ def max_frames_reached(frames: int) -> bool: true_objectives[agent_i][-1], ) rnn_states[agent_i] = torch.zeros( - [get_rnn_size(cfg)], dtype=torch.float32, device=device + [get_rnn_size(experiment_cfg)], + dtype=torch.float32, + 
device=device, ) episode_reward[agent_i] = 0 - if cfg.use_record_episode_statistics: + if experiment_cfg.use_record_episode_statistics: # we want the scores from the full episode not a single agent death (due to EpisodicLifeEnv wrapper) if "episode" in infos[agent_i]: num_episodes += 1 @@ -211,9 +311,9 @@ def max_frames_reached(frames: int) -> bool: # if episode terminated synchronously for all agents, pause a bit before starting a new one if all(dones): - render_frame( - cfg, env, video_frames, num_episodes, last_render_start - ) + # render_frame( + # experiment_cfg, env, video_frames[VideoType.RAW], num_episodes, last_render_start + # ) # I don't think this is too important - if, find a solution with VideoType.RAW time.sleep(0.05) if all(finished_episode): @@ -247,38 +347,38 @@ def max_frames_reached(frames: int) -> bool: ), ) - if num_episodes >= cfg.max_num_episodes: + if num_episodes >= experiment_cfg.max_num_episodes: break env.close() - if cfg.save_video: - fps = cfg.fps if cfg.fps > 0 else 30 + fps = experiment_cfg.fps if experiment_cfg.fps > 0 else 30 + for _vid_type in video_types: # assert frames are in the right range (0-255) to produce the video - video_frames = (_rescale_zero_one(np.array(video_frames)) * 255).astype( - np.uint8 - ) - generate_replay_video(experiment_dir(cfg=cfg), video_frames, fps, cfg) - - if cfg.push_to_hub: - generate_model_card( - experiment_dir(cfg=cfg), - cfg.algo, - cfg.env, - cfg.hf_repository, - reward_list, - cfg.enjoy_script, - cfg.train_script, + shape = video_frames[_vid_type][0].shape + for i, frame in enumerate(video_frames[_vid_type]): + if frame.shape != shape: + video_frames[_vid_type][i] = np.zeros(shape, dtype=np.uint8) + video_frames[_vid_type] = ( + rescale_zero_one(np.stack(video_frames[_vid_type])) * 255 + ).astype(np.uint8) + vid_path = experiment_path / "data" / "video" + vid_path.mkdir(parents=True, exist_ok=True) + + ckpt_str = Path(get_checkpoint_name(experiment_cfg)).name[:-4] + vid_type_str = 
f"_{_vid_type.value}" + frame_rate_str = "_actor_frame_rate" if actor_frame_rate else "" + experiment_cfg.video_name = ckpt_str + vid_type_str + frame_rate_str + ".mp4" + generate_replay_video( + str(vid_path), video_frames[_vid_type], fps, experiment_cfg ) - push_to_hf(experiment_dir(cfg=cfg), cfg.hf_repository) return ExperimentStatus.SUCCESS, sum( [sum(episode_rewards[i]) for i in range(env.num_agents)] ) / max(1, sum([len(episode_rewards[i]) for i in range(env.num_agents)])) -experiment_path = Path(sys.argv[1]) - if __name__ == "__main__": - create_video(experiment_path) + experiment_path, video_types, actor_frame_rate = parse_args() + create_video(experiment_path, video_types, actor_frame_rate)