Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 0 additions & 114 deletions produce_rf_plot.py

This file was deleted.

96 changes: 96 additions & 0 deletions retinal_rl/analysis/attribution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import torch
from captum.attr import InputXGradient

from retinal_rl.models.brain import Brain
from retinal_rl.util import rescale_zero_one


def l1_attribution(
brain: Brain,
stimuli: dict[str, torch.Tensor],
target_circuit: torch.Tensor,
target_output_index: int = 0,
) -> dict[str, torch.Tensor]:
input_grads: dict[str, torch.Tensor] = {}
output = brain(stimuli)[target_circuit][target_output_index]
loss = torch.nn.L1Loss()(output, torch.zeros_like(output))
loss.backward()
for key, value in stimuli.items():
input_grads[key] = value.grad.detach().cpu()
return input_grads


def captum_attribution(
brain: Brain,
stimuli: dict[str, torch.Tensor],
target_circuit: torch.Tensor,
target_output_index: int = 0,
) -> dict[str, torch.Tensor]:
input_grads: dict[str, torch.Tensor] = {}

stimuli_keys = list(stimuli.keys()) # create list to preserve order

def _forward(*args: tuple[torch.Tensor]) -> torch.Tensor:
assert len(args) == len(stimuli_keys)
return brain({k: v for k, v in zip(stimuli_keys, args)})[target_circuit][
target_output_index
]

value_grad_calculator = InputXGradient(_forward)
value_grads = value_grad_calculator.attribute(
tuple(stimuli[k] for k in stimuli_keys)
)
for key, value_grad in zip(stimuli_keys, value_grads):
input_grads[key] = value_grad.detach().cpu()
return input_grads


ATTRIBUTION_METHODS = {"l1": l1_attribution, "attribution": captum_attribution}


def analyze(
brain: Brain,
stimuli: dict[str, torch.Tensor],
target_circuit: torch.Tensor,
target_output_index: int = 0,
method: str = "l1",
sum_channels: bool = True,
rescale_per_frame: bool = False,
) -> dict[str, torch.Tensor]:
assert method in ATTRIBUTION_METHODS, f"Unknown attribution method: {method}"

is_training = brain.training
required_grad = next(brain.parameters()).requires_grad
grad_enabled = torch.is_grad_enabled()

# this is required to compute gradients
torch.set_grad_enabled(True)
brain.train()
brain.requires_grad_(False)

for key, value in stimuli.items():
stimuli[key] = value.requires_grad_(True)

input_grads: dict[str, torch.Tensor] = {}
input_grads = ATTRIBUTION_METHODS[method](
brain, stimuli, target_circuit, target_output_index
)

if sum_channels:
for key, grad in input_grads.items():
input_grads[key] = grad.sum(dim=1, keepdim=True)
if rescale_per_frame:
for key, grad in input_grads.items():
for frame in range(grad.shape[0]):
input_grads[key][frame] = rescale_zero_one(input_grads[key][frame])

# restore original state of training / grad_enabled
brain.requires_grad_(required_grad)
brain.train(is_training)
torch.set_grad_enabled(grad_enabled)
return input_grads


def plot(): # -> Figure:
# TODO: Implement plotting logic
raise NotImplementedError
16 changes: 15 additions & 1 deletion retinal_rl/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
import re
from enum import Enum
from math import ceil, floor
from typing import Any, List, Tuple, TypeVar, Union, cast
from typing import Any, List, Optional, Tuple, TypeVar, Union, cast

import numpy as np
import torch
from numpy.typing import NDArray
from torch import nn

Expand Down Expand Up @@ -183,3 +184,16 @@ def _double_up(x: Union[int, Tuple[int, ...]]):
if isinstance(x, int):
return (x, x)
return x


ArrayLike = TypeVar("ArrayLike", np.ndarray, torch.Tensor)


def rescale_zero_one(
x: ArrayLike, min: Optional[float] = None, max: Optional[float] = None
) -> ArrayLike:
if min is None:
min = np.min(x) if isinstance(x, np.ndarray) else torch.min(x).item()
if max is None:
max = np.max(x) if isinstance(x, np.ndarray) else torch.max(x).item()
return (x - min) / (max - min + 1e-8)
17 changes: 11 additions & 6 deletions runner/scripts/produce_rf_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,8 @@ def reshape_images(
):
n, _, w, h = arr.shape
whitespace_pix = np.round(whitespace * max(w, h)).astype(int)
if n_rows is None and n_cols is None:
n_rows = 1
if n_rows is None:
n_rows = (n + n_cols - 1) // n_cols
n_rows = (n + n_cols - 1) // n_cols if n_cols is not None else 1
if n_cols is None:
n_cols = (n + n_rows - 1) // n_rows

Expand Down Expand Up @@ -71,8 +69,11 @@ def init_plot(
rf_dir: Path, cur_file: str, hyper_params: list[str], figwidth: float = 10
):
# Init figure
with open(rf_dir / cur_file) as f:
rf = json.load(f)
if cur_file.endswith(".json"):
with open(rf_dir / cur_file) as f:
rf = json.load(f)
else:
rf = np.load(rf_dir / cur_file, allow_pickle=True)

comp_layer_rfs = []
for i, (layer, layer_rfs) in enumerate(rf.items()):
Expand Down Expand Up @@ -182,7 +183,11 @@ def parse_args(argv: list[str]):


experiments_path, out_dir, anim, fast = parse_args(sys.argv)
for experiment_path in experiments_path.iterdir():
if (experiments_path / "data").exists():
_iter = [experiments_path]
else:
_iter = experiments_path.iterdir()
for experiment_path in _iter:
try:
print(experiment_path)
if anim:
Expand Down
Loading