11from functools import partial
2- import threading
3- import time
42
53import holoviews as hv
64from holoviews .streams import Pipe
75import numpy as np
8- import pandas as pd
96import panel as pn
10- import param
117import plangym
128from plangym .utils import process_frame
139
10+ from fragile .core import FaiRunner
1411from fragile .shaolin .stream_plots import Image , RGB
1512from fragile .videogames import aggregate_visits , MontezumaTree
1613
1916pn .extension ("tabulator" , theme = "dark" )
2017
2118
22- class FaiRunner (param .Parameterized ):
23- is_running = param .Boolean (default = False )
24-
25- def __init__ (self , fai , n_steps , plot = None , report_interval = 100 ):
26- super ().__init__ ()
27- self .reset_btn = pn .widgets .Button (icon = "restore" , button_type = "primary" )
28- self .play_btn = pn .widgets .Button (icon = "player-play" , button_type = "primary" )
29- self .pause_btn = pn .widgets .Button (icon = "player-pause" , button_type = "primary" )
30- self .step_btn = pn .widgets .Button (name = "Step" , button_type = "primary" )
31- self .progress = pn .indicators .Progress (
32- name = "Progress" , value = 0 , width = 600 , max = n_steps , bar_color = "primary"
33- )
34- self .sleep_val = pn .widgets .FloatInput (value = 0.0 , width = 60 )
35- self .report_interval = pn .widgets .IntInput (value = report_interval )
36- self .table = pn .widgets .Tabulator ()
37- self .fai = fai
38- self .n_steps = n_steps
39- self .curr_step = 0
40- self .plot = plot
41- self .thread = None
42- self .erase_coef_val = pn .widgets .FloatInput (value = 0.05 , width = 60 , name = "erase" )
43-
44- @param .depends ("erase_coef_val.value" )
45- def update_erase_coef (self ):
46- self .fai .erase_coef = self .erase_coef_val .value
47-
48- @param .depends ("reset_btn.value" )
49- def on_reset_click (self ):
50- self .fai .reset ()
51- self .curr_step = 0
52- self .progress .value = 1
53- self .curr_step = 0
54- self .play_btn .disabled = False
55- self .pause_btn .disabled = True
56- self .step_btn .disabled = False
57- self .is_running = False
58- self .progress .bar_color = "primary"
59- summary = pd .DataFrame (self .fai .summary (), index = [0 ])
60- self .table .value = summary
61- if self .plot is not None :
62- self .plot .reset (self .fai )
63- self .plot .send (self .fai )
64-
65- @param .depends ("play_btn.value" )
66- def on_play_click (self ):
67- self .play_btn .disabled = True
68- self .pause_btn .disabled = False
69- self .is_running = True
70- if self .thread is None or not self .thread .is_alive ():
71- self .thread = threading .Thread (target = self .run )
72- self .thread .start ()
73-
74- @param .depends ("pause_btn.clicks" )
75- def on_pause_click (self ):
76- self .is_running = False
77- self .play_btn .disabled = False
78- self .pause_btn .disabled = True
79- if self .thread is not None :
80- self .thread .join ()
81-
82- @param .depends ("step_btn.value" )
83- def on_step_click (self ):
84- self .take_single_step ()
85-
86- def take_single_step (self ):
87- self .fai .step_tree ()
88- self .curr_step += 1
89- self .progress .value = self .curr_step
90- if self .curr_step >= self .n_steps :
91- self .is_running = False
92- self .progress .bar_color = "success"
93- self .step_btn .disabled = True
94- self .play_btn .disabled = True
95- self .pause_btn .disabled = True
96-
97- if self .fai .oobs .sum ().cpu ().item () == self .fai .n_walkers - 1 :
98- self .is_running = False
99- self .progress .bar_color = "danger"
100-
101- if self .fai .iteration % self .report_interval .value == 0 :
102- summary = pd .DataFrame (self .fai .summary (), index = [0 ])
103- self .table .value = summary
104- if self .plot is not None :
105- self .plot .send (self .fai )
106-
107- def run (self ):
108- while self .is_running :
109- self .take_single_step ()
110- time .sleep (self .sleep_val .value )
111-
112- def __panel__ (self ):
113- # pn.state.add_periodic_callback(self.run, period=20)
114-
115- return pn .Column (
116- self .table ,
117- self .progress ,
118- pn .Row (
119- self .play_btn ,
120- self .pause_btn ,
121- self .reset_btn ,
122- self .step_btn ,
123- pn .pane .Markdown ("**Sleep**" ),
124- self .sleep_val ,
125- self .report_interval ,
126- self .erase_coef_val ,
127- ),
128- self .on_play_click ,
129- self .on_pause_click ,
130- self .on_reset_click ,
131- self .on_step_click ,
132- self .update_erase_coef ,
133- # self.run,
134- )
135-
136-
13719PYRAMID = [
13820 [- 1 , - 1 , - 1 , 0 , 1 , 2 , - 1 , - 1 , - 1 ],
13921 [- 1 , - 1 , 3 , 4 , 5 , 6 , 7 , - 1 , - 1 ],
@@ -324,7 +206,9 @@ def main():
324206 frameskip = 3 ,
325207 check_death = True ,
326208 episodic_life = False ,
327- ) # , n_workers=10, ray=True)
209+ n_workers = 10 ,
210+ ray = True ,
211+ )
328212
329213 n_walkers = 10000
330214 plot = MontezumaDisplay ()
@@ -333,6 +217,3 @@ def main():
333217 )
334218 runner = FaiRunner (fai , 1000000 , plot = plot )
335219 pn .panel (pn .Column (runner , plot )).servable ()
336-
337-
338- main ()
0 commit comments