@@ -164,9 +164,13 @@ def __init__(self, width=None, grid=3, k=3, mult_arity = 2, noise_scale=0.3, sca
         self.act_fun = []
         self.depth = len(width) - 1

+        #print('haha1', width)
         for i in range(len(width)):
-            if type(width[i]) == int:
+            #print(type(width[i]), type(width[i]) == int)
+            if type(width[i]) == int or type(width[i]) == np.int64:
                 width[i] = [width[i], 0]
+
+        #print('haha2', width)

         self.width = width
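Why the np.int64 check matters: a width entry taken from a numpy array (e.g. from np.arange during a sweep) is np.int64, not a Python int, so it previously fell through the [n_sum, n_mult] normalization. A minimal sketch (the KAN constructor name follows this repo; the sweep itself is a hypothetical usage):

    import numpy as np
    from kan import KAN

    # values produced by np.arange are np.int64, not int
    for hidden in np.arange(2, 6):
        model = KAN(width=[2, hidden, 1])  # each entry is normalized to [w, 0]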
@@ -196,7 +200,18 @@ def __init__(self, width=None, grid=3, k=3, mult_arity = 2, noise_scale=0.3, sca

         for l in range(self.depth):
             # splines
-            sp_batch = KANLayer(in_dim=width_in[l], out_dim=width_out[l+1], num=grid, k=k, noise_scale=noise_scale, scale_base_mu=scale_base_mu, scale_base_sigma=scale_base_sigma, scale_sp=1., base_fun=base_fun, grid_eps=grid_eps, grid_range=grid_range, sp_trainable=sp_trainable, sb_trainable=sb_trainable, sparse_init=sparse_init)
+            # grid and k may be scalars (shared across layers) or lists (one entry per layer)
+            if isinstance(grid, list):
+                grid_l = grid[l]
+            else:
+                grid_l = grid
+
+            if isinstance(k, list):
+                k_l = k[l]
+            else:
+                k_l = k
+
+            sp_batch = KANLayer(in_dim=width_in[l], out_dim=width_out[l+1], num=grid_l, k=k_l, noise_scale=noise_scale, scale_base_mu=scale_base_mu, scale_base_sigma=scale_base_sigma, scale_sp=1., base_fun=base_fun, grid_eps=grid_eps, grid_range=grid_range, sp_trainable=sp_trainable, sb_trainable=sb_trainable, sparse_init=sparse_init)
             self.act_fun.append(sp_batch)

         self.node_bias = []
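With this change, grid and k accept per-layer lists as well as scalars. A hypothetical usage sketch (list lengths must match the number of layers, i.e. len(width) - 1):

    from kan import KAN

    # coarse 3-point grid with cubic splines in layer 0,
    # finer 5-point grid with quadratic splines in layer 1
    model = KAN(width=[2, 5, 1], grid=[3, 5], k=[3, 2])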
@@ -951,14 +966,14 @@ def unfix_symbolic(self, l, i, j, log_history=True):
         if log_history:
             self.log_history('unfix_symbolic')

-    def unfix_symbolic_all(self):
+    def unfix_symbolic_all(self, log_history=True):
         '''
         unfix all activation functions.
         '''
         for l in range(len(self.width) - 1):
-            for i in range(self.width[l]):
-                for j in range(self.width[l + 1]):
-                    self.unfix_symbolic(l, i, j)
+            for i in range(self.width_in[l]):
+                for j in range(self.width_out[l + 1]):
+                    self.unfix_symbolic(l, i, j, log_history)

     def get_range(self, l, i, j, verbose=True):
         '''
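Note the loops now index width_in/width_out rather than self.width: after the constructor normalizes each width entry to a [n_sum, n_mult] pair, self.width[l] is a list and cannot be passed to range(). A brief usage sketch of the new flag:

    # undo every symbolic fix without writing an entry to the model history
    model.unfix_symbolic_all(log_history=False)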
@@ -1522,6 +1537,10 @@ def closure():

             if _ == steps - 1 and old_save_act:
                 self.save_act = True
+
+            if save_fig and _ % save_fig_freq == 0:
+                save_act = self.save_act
+                self.save_act = True

             train_id = np.random.choice(dataset['train_input'].shape[0], batch_size, replace=False)
             test_id = np.random.choice(dataset['test_input'].shape[0], batch_size_test, replace=False)
@@ -1579,6 +1598,7 @@ def closure():
                 self.plot(folder=img_folder, in_vars=in_vars, out_vars=out_vars, title="Step {}".format(_), beta=beta)
                 plt.savefig(img_folder + '/' + str(_) + '.jpg', bbox_inches='tight', dpi=200)
                 plt.close()
+                self.save_act = save_act

         self.log_history('fit')
         # revert back to original state
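Taken together, the two fit() hunks implement a stash-and-restore pattern: on steps that produce a figure, save_act is forced on so that plot() has recorded activations to draw, and the user's original setting is put back afterwards. In isolation the pattern looks like this (model and the elided calls are stand-ins):

    save_act = model.save_act   # stash the user's setting
    model.save_act = True       # record activations for this step
    # ... forward pass and model.plot() ...
    model.save_act = save_act   # restore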
@@ -2160,7 +2180,7 @@ def suggest_symbolic(self, l, i, j, a_range=(-10, 10), b_range=(-10, 10), lib=No

         return best_name, best_fun, best_r2, best_c;

-    def auto_symbolic(self, a_range=(-10, 10), b_range=(-10, 10), lib=None, verbose=1):
+    def auto_symbolic(self, a_range=(-10, 10), b_range=(-10, 10), lib=None, verbose=1, weight_simple=0.8, r2_threshold=0.0):
21642184 '''
21652185 automatic symbolic regression for all edges
21662186
@@ -2174,7 +2194,10 @@ def auto_symbolic(self, a_range=(-10, 10), b_range=(-10, 10), lib=None, verbose=
             library of candidate symbolic functions
         verbose : int
             larger verbosity => more verbosity
-
+        weight_simple : float
+            weight that prioritizes simplicity (low complexity) over performance (high r2); set to 0.0 to ignore complexity
+        r2_threshold : float
+            if r2 is below this threshold, the edge will not be fixed with any symbolic function; set to 0.0 to disable the threshold
         Returns:
         --------
         None
@@ -2191,17 +2214,19 @@ def auto_symbolic(self, a_range=(-10, 10), b_range=(-10, 10), lib=None, verbose=
         for l in range(len(self.width_in) - 1):
             for i in range(self.width_in[l]):
                 for j in range(self.width_out[l + 1]):
-                    #if self.symbolic_fun[l].mask[j, i] > 0. and self.act_fun[l].mask[i][j] == 0.:
                     if self.symbolic_fun[l].mask[j, i] > 0. and self.act_fun[l].mask[i][j] == 0.:
                         print(f'skipping ({l},{i},{j}) since already symbolic')
                     elif self.symbolic_fun[l].mask[j, i] == 0. and self.act_fun[l].mask[i][j] == 0.:
                         self.fix_symbolic(l, i, j, '0', verbose=verbose > 1, log_history=False)
                         print(f'fixing ({l},{i},{j}) with 0')
                     else:
-                        name, fun, r2, c = self.suggest_symbolic(l, i, j, a_range=a_range, b_range=b_range, lib=lib, verbose=False)
-                        self.fix_symbolic(l, i, j, name, verbose=verbose > 1, log_history=False)
-                        if verbose >= 1:
-                            print(f'fixing ({l},{i},{j}) with {name}, r2={r2}, c={c}')
+                        name, fun, r2, c = self.suggest_symbolic(l, i, j, a_range=a_range, b_range=b_range, lib=lib, verbose=False, weight_simple=weight_simple)
+                        if r2 >= r2_threshold:
+                            self.fix_symbolic(l, i, j, name, verbose=verbose > 1, log_history=False)
+                            if verbose >= 1:
+                                print(f'fixing ({l},{i},{j}) with {name}, r2={r2}, c={c}')
+                        else:
+                            print(f'For ({l},{i},{j}) the best fit was {name}, but r^2 = {r2} is below the threshold {r2_threshold}. This edge was left unfixed; keep training or try a different threshold.')

         self.log_history('auto_symbolic')
22072232
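A hypothetical call exercising the two new auto_symbolic knobs (the library entries are standard pykan symbolic names):

    model.auto_symbolic(lib=['x', 'x^2', 'exp', 'sin'],
                        weight_simple=0.0,  # rank candidates purely by r2
                        r2_threshold=0.9)   # leave poorly fit edges unfixed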