A set of utilities that I frequently reuse across projects, gathered into a dedicated repository so they are easier to manage.
```sh
uv add "xtils[jitpp] @ git+https://github.com/jessefarebro/xtils"
```

`jitpp` is a wrapper around `jax.jit` that adds three features over the regular `jit`:

- Proper type hints for jitted functions, so tools like Pyright can show autocompletions for them.
- Type annotations for marking static and donated arguments, via the `Static[]` and `Donate[]` annotation types.
- A somewhat opinionated way to bind attributes of a class, so that you can jit static class methods more easily while still retaining the modularity of classes.
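With plain `jax.jit`, static and donated arguments are passed by name through `static_argnames` and `donate_argnames`. The sketch below shows one way annotation markers could be turned into those keyword arguments using `typing.Annotated`. The `Static`/`Donate` aliases, the marker classes and `jit_kwargs_from_annotations` are hypothetical stand-ins, not xtils' actual definitions, and no JAX is needed to run it.

```python
import inspect
from typing import Annotated, Any, Callable, TypeVar, get_args, get_origin, get_type_hints

T = TypeVar("T")


class _StaticMarker:
    """Metadata marker for static arguments (hypothetical)."""


class _DonateMarker:
    """Metadata marker for donated arguments (hypothetical)."""


# Static[int] evaluates to Annotated[int, _StaticMarker], so type checkers see a plain int.
Static = Annotated[T, _StaticMarker]
Donate = Annotated[T, _DonateMarker]


def jit_kwargs_from_annotations(fn: Callable[..., Any]) -> dict[str, tuple[str, ...]]:
    """Collect the parameter names that jax.jit would need as static/donated."""
    hints = get_type_hints(fn, include_extras=True)  # keep the Annotated metadata
    static: list[str] = []
    donate: list[str] = []
    for name in inspect.signature(fn).parameters:
        hint = hints.get(name)
        if get_origin(hint) is not Annotated:
            continue
        metadata = get_args(hint)[1:]  # everything after the underlying type
        if _StaticMarker in metadata:
            static.append(name)
        if _DonateMarker in metadata:
            donate.append(name)
    return {"static_argnames": tuple(static), "donate_argnames": tuple(donate)}


def f(x: Donate[int], sign: Static[int]) -> int:
    return x * sign


kwargs = jit_kwargs_from_annotations(f)
print(kwargs)  # {'static_argnames': ('sign',), 'donate_argnames': ('x',)}
# With JAX installed, these could be forwarded as jax.jit(f, **kwargs).
```

Because the markers live in `Annotated` metadata, any decorator that drops `__annotations__` or reorders parameters would hide them, which is the failure mode the caution further down describes.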
Example using the type annotations on a plain function:
```python
from xtils import jitpp
from xtils.jitpp import Static, Donate


@jitpp.jit
def f(x: Donate[int], sign: Static[int]) -> int:
    return x * sign


f(1, -1)
f(1, 1)  # Re-traced, as `sign` is annotated static.
```

Caution: if you place other decorators between `@jitpp.jit` and your function, problems can arise if those decorators strip the type annotations or permute the arguments.
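The attribute binding used in the class-based example can be pictured as a descriptor that, on attribute access, fills keyword-only arguments from same-named instance attributes. This is a hypothetical plain-Python sketch without JAX: `bind_attrs` and `_Bound` are not part of xtils, and jitpp itself drives binding from `Bind[...]` annotations, so its mechanics may differ.

```python
import dataclasses
import functools
from typing import Any, Callable


class _Bound:
    """Descriptor that fills keyword-only arguments from instance attributes."""

    def __init__(self, fn: Callable[..., Any], names: tuple[str, ...]) -> None:
        self.fn = fn
        self.names = names

    def __get__(self, obj: Any, objtype: Any = None) -> Callable[..., Any]:
        if obj is None:
            # Accessed on the class: return the plain function, nothing bound.
            return self.fn
        # Read the attributes at access time, so later mutations are picked up.
        bound = {name: getattr(obj, name) for name in self.names}
        return functools.partial(self.fn, **bound)


def bind_attrs(*names: str) -> Callable[[Callable[..., Any]], _Bound]:
    """Hypothetical decorator: bind the given attribute names to keyword-only args."""

    def decorator(fn: Callable[..., Any]) -> _Bound:
        return _Bound(fn, names)

    return decorator


@dataclasses.dataclass
class MyClass:
    sign: int

    @bind_attrs("sign")
    def f(x: int, *, sign: int) -> int:
        # No `self`: like a static method, `sign` is supplied from the instance.
        return x * sign


obj = MyClass(sign=-1)
print(obj.f(3))  # -3
obj.sign = 1
print(obj.f(3))  # 3
```

Reading the attributes at access time is what lets a change such as `obj.sign = 1` reach the function; when the bound argument is also static, that new value is what would trigger a re-trace under jit.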
Class-based staticmethod example:
```python
import dataclasses

from xtils import jitpp
from xtils.jitpp import Static, Donate, Bind


@dataclasses.dataclass
class MyClass:
    sign: int

    @jitpp.jit
    @staticmethod
    def f(x: Donate[int], *, sign: Bind[Static[int]]) -> int:
        return x * sign


obj = MyClass(sign=-1)
obj.f(1)  # NOTE: `sign` doesn't need to be provided, as it's bound to `obj.sign`.
obj.sign = 1
obj.f(1)  # Re-traced, as `sign` is annotated static.
```

```sh
uv add "xtils[jdbpp] @ git+https://github.com/jessefarebro/xtils"
```
An improved version of the built-in JAX debugger. Its features include:

- An improved UI (e.g., code highlighting, pretty backtraces, pretty-printing)
- The ability to run arbitrary commands, as in pdb (e.g., importing libraries)
- An IPython shell
- Persistent command history, so you can press Up to recall commands from prior sessions
Additional commands include:

- `interact` or `i` to drop into an IPython shell
- `ll` (long list) to list the entire file
- `shape` or `s` to get the shapes of all the variables in the current context
- `pdef`, `pdoc`, `pfile`, `pinfo`, `psource`, and magic commands from IPython
- `v` to get a table of the variables in scope
To use jdbpp, simply import `xtils.jdbpp`; it registers itself with JAX.
```sh
uv add "xtils[fiddle] @ git+https://github.com/jessefarebro/xtils"
```
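`auto_sweep.make_trials_from_sweeps` takes a base config and a list of sweep functions, each mapping a config to an iterator of variants. The comment in the following example about `(lr, seed)` pairs suggests the sweeps combine as a Cartesian product; assuming that, the idea can be sketched in plain Python with dicts standing in for `fdl.Buildable`. `make_trials` is a hypothetical helper, not the xtils implementation.

```python
from typing import Any, Callable, Dict, Iterator, List

Config = Dict[str, Any]
Sweep = Callable[[Config], Iterator[Config]]


def make_trials(base: Config, sweeps: List[Sweep]) -> List[Config]:
    """Hypothetical helper: expand `base` by every sweep, one after another."""
    trials = [base]
    for sweep in sweeps:
        # Each existing trial is expanded by every variant this sweep yields,
        # so the final list is the Cartesian product of all sweeps.
        trials = [variant for trial in trials for variant in sweep(trial)]
    return trials


def sweep_lr(base: Config) -> Iterator[Config]:
    for lr in (1e-4, 1e-5, 5e-5):
        # Shallow copy-and-override; fdl.deepcopy_with makes a deep copy.
        yield {**base, "lr": lr}


def sweep_seed(base: Config) -> Iterator[Config]:
    for seed in range(10):
        yield {**base, "seed": seed}


trials = make_trials({"lr": 1e-3, "seed": 0, "batch_size": 32}, [sweep_lr, sweep_seed])
print(len(trials))  # 30
```

Expanding sweep by sweep keeps each sweep function independent of the others: adding a third sweep multiplies the trial count without touching the existing ones.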
```python
from typing import Iterator, TypeVar

import fiddle as fdl
from fiddle import selectors
from xtils.fiddle import auto_sweep

T = TypeVar("T")


def sweep_lr(base: fdl.Buildable[T]) -> Iterator[fdl.Buildable[T]]:
    for lr in (1e-4, 1e-5, 5e-5):
        yield fdl.deepcopy_with(base, lr=lr)


def sweep_seed(base: fdl.Buildable[T]) -> Iterator[fdl.Buildable[T]]:
    for seed in range(10):
        yield fdl.deepcopy_with(base, seed=seed)


# Trials will contain fdl.Buildables that have (lr, seed)
# pairs mutated from the above configs.
# `my_base_config()` stands for your own function returning the base config.
trials = auto_sweep.make_trials_from_sweeps(
    my_base_config(),
    [sweep_lr, sweep_seed],
)
```

Convert a config into a dictionary with `printing.as_dict`:

```python
from xtils.fiddle import printing

# `cfg` is the fdl.Buildable to convert.
my_config_params = printing.as_dict(
    cfg,
    include_buildable_fn_or_cls=True,
    include_defaults=False,
    buildable_fn_or_cls_key="__fn_or_cls__",
    flatten_tree=False,
)
```

```sh
uv add "xtils[clu] @ git+https://github.com/jessefarebro/xtils"
```
```python
from xtils.clu import metric_writers

writer = metric_writers.create_default_writer(
    just_logging=False,
    asynchronous=False,
)
```

You can customize the metric writer with the following flags:
```sh
--metric_writer=aim|wandb|tensorboard

# Aim
--aim.repo              # Repository directory.
--aim.experiment        # Experiment name.
--aim.run_hash          # Run hash.
--aim.log_system_params # Log system parameters.

# Wandb
--wandb.save_code # Save code.
--wandb.id        # Run ID.
--wandb.tags      # Tags.
--wandb.name      # Name.
--wandb.group     # Group.
--wandb.mode      # Mode: online|offline|disabled.

# Tensorboard
--tensorboard.logdir # Log directory.
```

```sh
uv add "xtils[domains] @ git+https://github.com/jessefarebro/xtils"
```
```python
from xtils.domains import atari as dm_ale

env = dm_ale.AtariEnvironment(
    game,
    mode=None,
    difficulty=None,
    seed=None,
    repeat_action_probability=0.25,
    frameskip=1,
    max_episode_frames=108_000,
    render_mode=None,
    frame_processing=None,
    action_set=ActionSet.Minimal,
    observation_type=(
        dm_ale.ObservationType.ImageRGB,
        dm_ale.ObservationType.Lives,
    ),
)
```

```sh
uv add "xtils[plotting] @ git+https://github.com/jessefarebro/xtils"
```
Get a theme for use with the Seaborn objects interface's `.theme(...)`:

```python
import seaborn.objects as so
from xtils.plotting import THEME

(
    so.Plot(...)  # Your data and mappings go here.
    .theme(THEME)
)
```

Fetch baseline dataframes for Dopamine and DQN Zoo:

```python
from xtils.plotting import baselines

dopamine = baselines.dopamine()
zoo = baselines.zoo()
```

`xtils.plotting.objects` provides a `Rolling` move transform, which mirrors `pd.DataFrame.rolling`, and a `LineLabel` mark.
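Since `Rolling` mirrors `pd.DataFrame.rolling`, the smoothing it applies should resemble a rolling mean computed separately for each curve. The plain-pandas sketch below shows that computation on made-up data; the window size and `min_periods` are illustrative choices, and which arguments `Rolling` itself accepts is an assumption not covered here.

```python
import pandas as pd

# Made-up learning curves for two agents, for illustration only.
df = pd.DataFrame(
    {
        "agent": ["A"] * 5 + ["B"] * 5,
        "step": list(range(5)) * 2,
        "return": [0.0, 1.0, 2.0, 3.0, 4.0, 10.0, 10.0, 20.0, 20.0, 30.0],
    }
)

# Rolling mean over a window of 3 points, computed separately per agent so
# that one agent's curve never leaks into another's. min_periods=1 keeps the
# first points (fewer than 3 values available) instead of producing NaN.
df["smoothed"] = df.groupby("agent")["return"].transform(
    lambda s: s.rolling(window=3, min_periods=1).mean()
)
print(df)
```

Grouping before rolling matters: a single rolling window over the stacked frame would average the tail of agent A into the start of agent B.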
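Using `Rolling`:

```python
import seaborn.objects as so
from xtils.plotting import objects as xso

(
    so.Plot(...)
    .add(so.Line(), so.Agg(), xso.Rolling())
)
```

Using `LineLabel`:

```python
import seaborn.objects as so
from xtils.plotting import objects as xso

(
    so.Plot(...)
    .add(so.Line() + xso.LineLabel(), so.Agg(), text="Agent")
)
```

```sh
uv add "xtils[rl] @ git+https://github.com/jessefarebro/xtils"
```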
- Dynamic programming utilities: `transition_matrix`, `reward_vector`, `successor_representation`, `uniform_random_policy`, `proto_value_functions`, `bellman_optimality_op`, `policy_improvement_op`, `policy_evaluation_op`, `value_iteration`, `policy_iteration`
- MDPs: `GridWorld`, `FourRoomsGrid`, `MiddleWallGrid`, `WindingGrid`, `DayanGrid`, `TMaze`, `SimpleT`
- Spectral: `policy_mixing_distributions`, `compute_entropy`
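For orientation, here is a minimal NumPy sketch of two of the tabular computations the dynamic programming utilities are named after: value iteration (repeatedly applying the Bellman optimality operator) and the successor representation of a policy, M = (I - γ P_π)^-1. The array layouts (`P[s, a, s']`, `R[s, a]`, `pi[s, a]`) and these function signatures are assumptions for illustration; the xtils functions with the same names may differ.

```python
import numpy as np


def value_iteration(P: np.ndarray, R: np.ndarray, gamma: float, tol: float = 1e-8) -> np.ndarray:
    """Optimal state values for transitions P[s, a, s'] and rewards R[s, a]."""
    V = np.zeros(P.shape[0])
    while True:
        # Bellman optimality operator: max_a [R(s, a) + gamma * sum_s' P(s'|s, a) V(s')].
        Q = R + gamma * P @ V  # (S, A, S) @ (S,) -> (S, A)
        V_new = Q.max(axis=1)
        if np.max(np.abs(V_new - V)) < tol:
            return V_new
        V = V_new


def successor_representation(P: np.ndarray, pi: np.ndarray, gamma: float) -> np.ndarray:
    """SR of policy pi[s, a]: M = (I - gamma * P_pi)^-1."""
    # P_pi[s, s'] = sum_a pi(a|s) P(s'|s, a): the state-to-state transitions under pi.
    P_pi = np.einsum("sa,sat->st", pi, P)
    return np.linalg.inv(np.eye(P.shape[0]) - gamma * P_pi)


# Two-state chain: action 0 stays put, action 1 switches state.
P = np.zeros((2, 2, 2))
P[0, 0, 0] = P[0, 1, 1] = 1.0
P[1, 0, 1] = P[1, 1, 0] = 1.0
R = np.array([[0.0, 0.0], [1.0, 0.0]])  # Reward 1 only for staying in state 1.

V = value_iteration(P, R, gamma=0.9)
print(V)  # approximately [9, 10]

pi = np.full((2, 2), 0.5)  # Uniform random policy.
M = successor_representation(P, pi, gamma=0.9)
print(M)  # [[5.5, 4.5], [4.5, 5.5]] up to floating-point error
```

Each row of the SR sums to 1 / (1 - γ) = 10, and multiplying it by the policy's expected reward per state, `(pi * R).sum(axis=1)`, gives that policy's value function.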