Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions retinal_rl/rl/sample_factory/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

"""

from typing import Optional

from sample_factory.utils.utils import str2bool
from sf_examples.vizdoom.doom.doom_params import (
add_doom_env_args,
Expand Down Expand Up @@ -79,6 +81,25 @@ def add_retinal_env_args(parser):
help="Only perform a dry run of the config and network analysis, without training or evaluation",
)

parser.add_argument(
"--warp_exp",
type=Optional[float],
default=None,
help="If and how much to warp the input image. If None, no warping is applied. The higher the closer to a center crop the warping is.",
)
parser.add_argument(
"--warp_h",
type=int,
default=60,
help="height after warping",
)
parser.add_argument(
"--warp_w",
type=int,
default=80,
help="width after warping",
)


def add_retinal_env_eval_args(parser):
"""
Expand Down
133 changes: 130 additions & 3 deletions retinal_rl/rl/sample_factory/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import gymnasium as gym
import numpy as np
from gymnasium import spaces

# import gym
from gymnasium.spaces import Discrete
Expand Down Expand Up @@ -40,6 +41,108 @@ def doom_action_space_no_backwards():
### Wrappers ###


class WarpResizeWrapper(gym.core.Wrapper):
"""Resize observation frames to specified (w,h) and convert to grayscale."""

def __init__(self, env, h: int, w: int, warp_exp: float = 2.0):
super().__init__(env)

self.w = w
self.h = h
self.warp_exp = warp_exp

if isinstance(env.observation_space, spaces.Dict):
new_spaces = {}
for key, space in env.observation_space.spaces.items():
new_spaces[key] = self._calc_new_obs_space(space)
self.observation_space = spaces.Dict(new_spaces)
else:
self.observation_space = self._calc_new_obs_space(env.observation_space)

def _calc_new_obs_space(self, old_space):
low, high = old_space.low.flat[0], old_space.high.flat[0]

assert (
len(old_space.shape) == 3
), "Expected observation space to have shape (h, w, channels)"

channel_last = len(old_space.shape) < 3 or np.argmin(old_space.shape) == 2
channels = old_space.shape[-1 if channel_last else 0]
new_shape = (
[self.h, self.w, channels] if channel_last else [channels, self.h, self.w]
)

return spaces.Box(low, high, shape=new_shape, dtype=old_space.dtype)

@staticmethod
def center_warp_image(image, out_shape: tuple[int, int] = (60, 80), exp: float = 2):
"""
Center-warp the image to a specified output size and scale.
"""
channel_last = len(image.shape) < 3 or np.argmin(image.shape) == 2
if channel_last:
h, w = image.shape[0], image.shape[1]
else:
h, w = image.shape[1], image.shape[2]
center = (h // 2, w // 2) # (height, width)

out_shape_half = (out_shape[0] // 2, out_shape[1] // 2)
row_even = out_shape[0] % 2
col_even = out_shape[1] % 2
row_idx = np.round(
(np.arange(0, out_shape_half[0]) / (out_shape_half[0] - 1)) ** exp
* (center[0] - (1 + row_even))
).astype(int) # Generate indices for rows and columns
col_idx = np.round(
(np.arange(0, out_shape_half[1]) / (out_shape_half[1] - 1)) ** exp
* (center[1] - (1 + col_even))
).astype(int)

# ensure difference is at least 1 pixel
row_inc = np.arange(len(row_idx))
col_inc = np.arange(len(col_idx))
row_idx[row_idx < row_inc] = row_inc[row_idx < row_inc]
col_idx[col_idx < col_inc] = col_inc[col_idx < col_inc]

h = center[0] - row_idx[::-1] - 1 - row_even
w = center[1] - col_idx[::-1] - 1 - col_even
if row_even:
h = np.hstack([h, center[0] - 1])
if col_even:
w = np.hstack([w, center[1] - 1])
h = np.hstack([h, row_idx + center[0]])
w = np.hstack([w, col_idx + center[1]])
if channel_last:
out = image[h[:, np.newaxis], w[np.newaxis, :]]
else:
out = image[:, h[:, np.newaxis], w[np.newaxis, :]]
return out

def _convert_obs(self, obs):
if obs is None:
return obs

return self.center_warp_image(
obs, out_shape=(self.h, self.w), exp=self.warp_exp
)

def _observation(self, obs):
if isinstance(obs, dict):
new_obs = {}
for key, value in obs.items():
new_obs[key] = self._convert_obs(value)
return new_obs
return self._convert_obs(obs)

def reset(self, **kwargs):
obs, info = self.env.reset(**kwargs)
return self._observation(obs), info

def step(self, action):
obs, reward, terminated, truncated, info = self.env.step(action)
return self._observation(obs), reward, terminated, truncated, info


class SatietyInput(gym.Wrapper):
"""Add game variables to the observation space + reward shaping."""

Expand Down Expand Up @@ -96,12 +199,22 @@ def step(self, action):


def retinal_doomspec(
scene_name: str, cfg_path: str, sat_in: bool, allow_backwards: bool
scene_name: str,
cfg_path: str,
sat_in: bool,
allow_backwards: bool,
warp_exp: Optional[float] = None,
warp_w: int = 80,
warp_h: int = 60,
):
ewraps = []

if sat_in:
ewraps = [(SatietyInput, {})]
if warp_exp is not None:
ewraps.append(
(WarpResizeWrapper, {"h": warp_h, "w": warp_w, "warp_exp": warp_exp})
)

action_space = (
doom_action_space_basic()
Expand Down Expand Up @@ -134,14 +247,28 @@ def make_retinal_env_from_spec(


def register_retinal_env(
scene_name: str, cache_dir: str, input_satiety: bool, allow_backwards: bool = True
scene_name: str,
cache_dir: str,
input_satiety: bool,
allow_backwards: bool = True,
warp_exp: Optional[float] = None,
warp_h: int = 60,
warp_w: int = 80,
):
if not os.path.isabs(cache_dir):
# make path absolute by making it relative to the path of this file
# TODO: Discuss whether this is desired behaviour...
cache_dir = os.path.join(os.path.dirname(__file__), "..", "..", "..", cache_dir)
cfg_path = os.path.join(cache_dir, "scenarios", scene_name + ".cfg")

env_spec = retinal_doomspec(scene_name, cfg_path, input_satiety, allow_backwards)
env_spec = retinal_doomspec(
scene_name,
cfg_path,
input_satiety,
allow_backwards,
warp_exp=warp_exp,
warp_h=warp_h,
warp_w=warp_w,
)
make_env_func = functools.partial(make_retinal_env_from_spec, env_spec)
register_env(env_spec.name, make_env_func)
18 changes: 16 additions & 2 deletions runner/frameworks/rl/sf_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,9 @@ def __init__(self, cfg: DictConfig, data_root: str):
self.data_root,
self.sf_cfg.input_satiety,
self.sf_cfg.allow_backwards,
warp_exp=self.sf_cfg.warp_exp,
warp_h=self.sf_cfg.warp_h,
warp_w=self.sf_cfg.warp_w,
)

if hasattr(cfg.brain.circuits, "actor"):
Expand Down Expand Up @@ -162,8 +165,19 @@ def to_sf_cfg(cfg: DictConfig) -> Config:
)
# Using this function is necessary to make sure that the parameters are not overwritten when sample_factory loads a checkpoint

SFFramework._set_cfg_cli_argument(sf_cfg, "res_h", cfg.dataset.vision_height)
SFFramework._set_cfg_cli_argument(sf_cfg, "res_w", cfg.dataset.vision_width)
if hasattr(cfg.dataset, "warp_exp") and cfg.dataset.warp_exp is not None:
SFFramework._set_cfg_cli_argument(sf_cfg, "warp_exp", cfg.dataset.warp_exp)
SFFramework._set_cfg_cli_argument(
sf_cfg, "warp_h", cfg.dataset.vision_height
)
SFFramework._set_cfg_cli_argument(
sf_cfg, "warp_w", cfg.dataset.vision_width
)
else:
SFFramework._set_cfg_cli_argument(
sf_cfg, "res_h", cfg.dataset.vision_height
)
SFFramework._set_cfg_cli_argument(sf_cfg, "res_w", cfg.dataset.vision_width)
SFFramework._set_cfg_cli_argument(sf_cfg, "env", cfg.dataset.env_name)
SFFramework._set_cfg_cli_argument(
sf_cfg, "input_satiety", cfg.dataset.input_satiety
Expand Down