Skip to content

Commit 0edeac0

Browse files
authored
Merge pull request #12 from FragileTech/alpha
Add skeleton for first release
2 parents 9c89d69 + 13e1a57 commit 0edeac0

13 files changed

Lines changed: 752 additions & 311 deletions

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ dependencies = [
5959
"quantlib>=1.35",
6060
"hvplot>=0.11.0",
6161
"streamz>=0.6.4",
62-
"plangym[atari]>=0.1.29",
62+
"plangym[dm_control,atari]>=0.1.29",
6363
"ray>=2.37.0",
6464
"jupyter",
6565
"notebook",

src/fragile/actions.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
import numpy as np
2+
import torch
3+
4+
from fragile.core import BaseDtSampler, BaseFractalTree, BasePolicy, FractalTree
5+
from fragile.fractalai import clone_tensor
6+
7+
8+
class UniformDtSampler(BaseDtSampler):
9+
def __init__(self, min_dt: int = 1, max_dt: int = 5, fractal: BaseFractalTree | None = None):
10+
super().__init__(fractal=fractal)
11+
self.max_dt = max_dt
12+
self.min_dt = min_dt
13+
14+
def get_dt(self, n_walkers: int | None = None, fractal: BaseFractalTree | None = None):
15+
if n_walkers is None:
16+
n_walkers = fractal.n_walkers
17+
return np.random.randint(self.min_dt, self.max_dt, size=n_walkers) # noqa: NPY002
18+
19+
20+
class RandomGaussianPolicy(BasePolicy):
21+
def __init__(
22+
self,
23+
std: float = 1.0,
24+
min: float | None = None,
25+
max: float | None = None,
26+
fractal: BaseFractalTree | None = None,
27+
):
28+
super().__init__(fractal=fractal)
29+
self.std = std
30+
self.min_ = min
31+
self.max_ = max
32+
33+
def act(self, n_walkers: int | None = None, fractal: FractalTree | None = None):
34+
fractal = fractal if fractal is not None else self.fractal
35+
if n_walkers is None:
36+
n_walkers = fractal.n_walkers
37+
return (torch.randn((n_walkers, *fractal.action_shape)) * self.std).clamp(
38+
self.min_, self.max_
39+
)
40+
41+
42+
class GaussianForce(RandomGaussianPolicy):
43+
def __init__(
44+
self,
45+
std: float = 1.0,
46+
min: float | None = None,
47+
max: float | None = None,
48+
fractal: FractalTree | None = None,
49+
):
50+
super().__init__(std=std, fractal=fractal, min=min, max=max)
51+
action_shape = fractal.action_shape if fractal is not None else (1,)
52+
device = fractal.device if fractal is not None else "cpu"
53+
n_walkers = fractal.n_walkers if fractal is not None else 1
54+
self.velocity = torch.zeros((n_walkers, *action_shape), device=device)
55+
56+
def set_fractal(self, fractal: "FractalTree"):
57+
super().set_fractal(fractal)
58+
self.velocity = torch.zeros(
59+
(fractal.n_walkers, *fractal.action_shape), device=fractal.device
60+
)
61+
62+
def act(self, n_walkers: int | None = None, fractal: FractalTree | None = None):
63+
action = super().act(n_walkers=n_walkers, fractal=fractal)
64+
wc = (
65+
self.fractal.will_clone
66+
if self.fractal.will_clone.sum() > 0
67+
else torch.ones_like(self.fractal.will_clone)
68+
)
69+
self.velocity[wc] += action
70+
return self.velocity[wc].clamp(self.min_, self.max_)
71+
72+
def clone(self, will_clone: torch.Tensor, clone_ix: torch.Tensor):
73+
self.velocity = clone_tensor(self.velocity, clone_ix, will_clone)
74+
75+
def add_walkers(self, new_walkers):
76+
new_vel = torch.zeros((new_walkers, *self.velocity.shape[1:]), device=self.velocity.device)
77+
self.velocity = torch.cat((self.velocity, new_vel), dim=0).contiguous()

src/fragile/app/__init__.py

Whitespace-only changes.

src/fragile/app/debug.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import flogging
2+
import holoviews as hv
3+
import panel as pn
4+
5+
from fragile.actions import RandomGaussianPolicy, UniformDtSampler
6+
from fragile.benchmarks import Rastrigin
7+
from fragile.core import FaiRunner
8+
from fragile.functions import FunctionTree
9+
from fragile.shaolin.stream_plots import RGB
10+
from fragile.shaolin.streaming_fai import InteractiveFai
11+
12+
13+
hv.extension("bokeh")
14+
pn.extension("tabulator", theme="dark")
15+
16+
17+
class PlanGymDisplay:
18+
def __init__(
19+
self,
20+
):
21+
self.best_img = RGB()
22+
self._curr_best = -1
23+
24+
def reset(self, fai): # noqa: ARG002
25+
return
26+
27+
def send(self, fai):
28+
best_ix = fai.cum_reward.argmax().cpu().item()
29+
best_img = fai.img[best_ix]
30+
if best_ix != self._curr_best:
31+
self.best_img.send(best_img)
32+
self._curr_best = best_ix
33+
34+
def __panel__(self):
35+
return pn.Column(
36+
pn.Row(
37+
self.best_img.plot,
38+
# self.room_grey.plot * self.tree_best_room,
39+
),
40+
)
41+
42+
43+
def main():
44+
flogging.setup()
45+
env = Rastrigin(2)
46+
47+
n_walkers = 10000
48+
fai = FunctionTree(
49+
max_walkers=n_walkers,
50+
env=env,
51+
dt_sampler=UniformDtSampler(min_dt=1, max_dt=3),
52+
policy=RandomGaussianPolicy(std=0.05, min=-1.0, max=1.0),
53+
device="cpu",
54+
min_leafs=50,
55+
start_walkers=50,
56+
minimize=True,
57+
)
58+
plot = InteractiveFai(fai)
59+
runner = FaiRunner(fai, 1000000, plot=plot)
60+
return pn.panel(pn.Column(runner, plot)).servable()
61+
62+
63+
# if __name__ == "__main__":
64+
main()

src/fragile/app/functions.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import holoviews as hv
2+
import panel as pn
3+
import plangym
4+
5+
from fragile.core import FaiRunner, FractalTree
6+
from fragile.shaolin.stream_plots import RGB
7+
8+
9+
hv.extension("bokeh")
10+
pn.extension("tabulator", theme="dark")
11+
12+
13+
class PlanGymDisplay:
14+
def __init__(
15+
self,
16+
):
17+
self.best_img = RGB()
18+
self._curr_best = -1
19+
20+
def reset(self, fai): # noqa: ARG002
21+
return
22+
23+
def send(self, fai):
24+
best_ix = fai.cum_reward.argmax().cpu().item()
25+
best_img = fai.img[best_ix]
26+
if best_ix != self._curr_best:
27+
self.best_img.send(best_img)
28+
self._curr_best = best_ix
29+
30+
def __panel__(self):
31+
return pn.Column(
32+
pn.Row(
33+
self.best_img.plot,
34+
# self.room_grey.plot * self.tree_best_room,
35+
),
36+
)
37+
38+
39+
def main():
40+
env = plangym.make(
41+
domain_name="walker",
42+
task_name="stand",
43+
obs_type="coords",
44+
return_image=True,
45+
frameskip=1,
46+
# n_workers=10,
47+
# ray=True,
48+
)
49+
50+
n_walkers = 10000
51+
plot = PlanGymDisplay()
52+
fai = FractalTree(
53+
max_walkers=n_walkers, env=env, device="cpu", min_leafs=250, start_walkers=250
54+
)
55+
runner = FaiRunner(fai, 1000000, plot=plot)
56+
pn.panel(pn.Column(runner, plot)).servable()
Lines changed: 4 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,13 @@
11
from functools import partial
2-
import threading
3-
import time
42

53
import holoviews as hv
64
from holoviews.streams import Pipe
75
import numpy as np
8-
import pandas as pd
96
import panel as pn
10-
import param
117
import plangym
128
from plangym.utils import process_frame
139

10+
from fragile.core import FaiRunner
1411
from fragile.shaolin.stream_plots import Image, RGB
1512
from fragile.videogames import aggregate_visits, MontezumaTree
1613

@@ -19,121 +16,6 @@
1916
pn.extension("tabulator", theme="dark")
2017

2118

22-
class FaiRunner(param.Parameterized):
23-
is_running = param.Boolean(default=False)
24-
25-
def __init__(self, fai, n_steps, plot=None, report_interval=100):
26-
super().__init__()
27-
self.reset_btn = pn.widgets.Button(icon="restore", button_type="primary")
28-
self.play_btn = pn.widgets.Button(icon="player-play", button_type="primary")
29-
self.pause_btn = pn.widgets.Button(icon="player-pause", button_type="primary")
30-
self.step_btn = pn.widgets.Button(name="Step", button_type="primary")
31-
self.progress = pn.indicators.Progress(
32-
name="Progress", value=0, width=600, max=n_steps, bar_color="primary"
33-
)
34-
self.sleep_val = pn.widgets.FloatInput(value=0.0, width=60)
35-
self.report_interval = pn.widgets.IntInput(value=report_interval)
36-
self.table = pn.widgets.Tabulator()
37-
self.fai = fai
38-
self.n_steps = n_steps
39-
self.curr_step = 0
40-
self.plot = plot
41-
self.thread = None
42-
self.erase_coef_val = pn.widgets.FloatInput(value=0.05, width=60, name="erase")
43-
44-
@param.depends("erase_coef_val.value")
45-
def update_erase_coef(self):
46-
self.fai.erase_coef = self.erase_coef_val.value
47-
48-
@param.depends("reset_btn.value")
49-
def on_reset_click(self):
50-
self.fai.reset()
51-
self.curr_step = 0
52-
self.progress.value = 1
53-
self.curr_step = 0
54-
self.play_btn.disabled = False
55-
self.pause_btn.disabled = True
56-
self.step_btn.disabled = False
57-
self.is_running = False
58-
self.progress.bar_color = "primary"
59-
summary = pd.DataFrame(self.fai.summary(), index=[0])
60-
self.table.value = summary
61-
if self.plot is not None:
62-
self.plot.reset(self.fai)
63-
self.plot.send(self.fai)
64-
65-
@param.depends("play_btn.value")
66-
def on_play_click(self):
67-
self.play_btn.disabled = True
68-
self.pause_btn.disabled = False
69-
self.is_running = True
70-
if self.thread is None or not self.thread.is_alive():
71-
self.thread = threading.Thread(target=self.run)
72-
self.thread.start()
73-
74-
@param.depends("pause_btn.clicks")
75-
def on_pause_click(self):
76-
self.is_running = False
77-
self.play_btn.disabled = False
78-
self.pause_btn.disabled = True
79-
if self.thread is not None:
80-
self.thread.join()
81-
82-
@param.depends("step_btn.value")
83-
def on_step_click(self):
84-
self.take_single_step()
85-
86-
def take_single_step(self):
87-
self.fai.step_tree()
88-
self.curr_step += 1
89-
self.progress.value = self.curr_step
90-
if self.curr_step >= self.n_steps:
91-
self.is_running = False
92-
self.progress.bar_color = "success"
93-
self.step_btn.disabled = True
94-
self.play_btn.disabled = True
95-
self.pause_btn.disabled = True
96-
97-
if self.fai.oobs.sum().cpu().item() == self.fai.n_walkers - 1:
98-
self.is_running = False
99-
self.progress.bar_color = "danger"
100-
101-
if self.fai.iteration % self.report_interval.value == 0:
102-
summary = pd.DataFrame(self.fai.summary(), index=[0])
103-
self.table.value = summary
104-
if self.plot is not None:
105-
self.plot.send(self.fai)
106-
107-
def run(self):
108-
while self.is_running:
109-
self.take_single_step()
110-
time.sleep(self.sleep_val.value)
111-
112-
def __panel__(self):
113-
# pn.state.add_periodic_callback(self.run, period=20)
114-
115-
return pn.Column(
116-
self.table,
117-
self.progress,
118-
pn.Row(
119-
self.play_btn,
120-
self.pause_btn,
121-
self.reset_btn,
122-
self.step_btn,
123-
pn.pane.Markdown("**Sleep**"),
124-
self.sleep_val,
125-
self.report_interval,
126-
self.erase_coef_val,
127-
),
128-
self.on_play_click,
129-
self.on_pause_click,
130-
self.on_reset_click,
131-
self.on_step_click,
132-
self.update_erase_coef,
133-
# self.run,
134-
)
135-
136-
13719
PYRAMID = [
13820
[-1, -1, -1, 0, 1, 2, -1, -1, -1],
13921
[-1, -1, 3, 4, 5, 6, 7, -1, -1],
@@ -324,7 +206,9 @@ def main():
324206
frameskip=3,
325207
check_death=True,
326208
episodic_life=False,
327-
) # , n_workers=10, ray=True)
209+
n_workers=10,
210+
ray=True,
211+
)
328212

329213
n_walkers = 10000
330214
plot = MontezumaDisplay()
@@ -333,6 +217,3 @@ def main():
333217
)
334218
runner = FaiRunner(fai, 1000000, plot=plot)
335219
pn.panel(pn.Column(runner, plot)).servable()
336-
337-
338-
main()

0 commit comments

Comments
 (0)