@@ -120,7 +120,7 @@ def reset_state(self, V, C, E, batch_size=None):
120120 alpha = self .f_p_alpha (V )
121121 beta = self .f_p_beta (V )
122122 self .p .value = alpha / (alpha + beta )
123- if batch_size is not None :
123+ if isinstance ( batch_size , int ) :
124124 assert self .p .shape [0 ] == batch_size
125125
126126 def f_p_alpha (self , V ):
@@ -434,7 +434,7 @@ def current(self, V, C, E):
434434 def reset_state (self , V , C , E , batch_size = None ):
435435 self .p .value = self .f_p_inf (V )
436436 self .q .value = self .f_q_inf (V )
437- if batch_size is not None :
437+ if isinstance ( batch_size , int ) :
438438 assert self .p .shape [0 ] == batch_size
439439 assert self .q .shape [0 ] == batch_size
440440
@@ -722,7 +722,7 @@ def current(self, V, C, E):
722722 def reset_state (self , V , C , E , batch_size = None ):
723723 self .p .value = self .f_p_inf (V )
724724 self .q .value = self .f_q_inf (V )
725- if batch_size is not None :
725+ if isinstance ( batch_size , int ) :
726726 assert self .p .shape [0 ] == batch_size
727727 assert self .q .shape [0 ] == batch_size
728728
@@ -1001,7 +1001,7 @@ def current(self, V, C, E):
10011001
10021002 def reset_state (self , V , C , E , batch_size = None ):
10031003 self .p .value = self .f_p_inf (V )
1004- if batch_size is not None :
1004+ if isinstance ( batch_size , int ) :
10051005 assert self .p .shape [0 ] == batch_size
10061006
10071007 def f_p_inf (self , V ):
@@ -1087,7 +1087,7 @@ def reset_state(self, V, batch_size=None):
10871087 alpha = self .f_p_alpha (V )
10881088 beta = self .f_p_beta (V )
10891089 self .p .value = alpha / (alpha + beta )
1090- if batch_size is not None :
1090+ if isinstance ( batch_size , int ) :
10911091 assert self .p .shape [0 ] == batch_size
10921092
10931093 def f_p_alpha (self , V ):
@@ -1410,7 +1410,7 @@ def current(self, V):
14101410 def reset_state (self , V , batch_size = None ):
14111411 self .p .value = self .f_p_inf (V )
14121412 self .q .value = self .f_q_inf (V )
1413- if batch_size is not None :
1413+ if isinstance ( batch_size , int ) :
14141414 assert self .p .shape [0 ] == batch_size
14151415 assert self .q .shape [0 ] == batch_size
14161416
@@ -1705,7 +1705,7 @@ def current(self, V):
17051705 def reset_state (self , V , batch_size = None ):
17061706 self .p .value = self .f_p_inf (V )
17071707 self .q .value = self .f_q_inf (V )
1708- if batch_size is not None :
1708+ if isinstance ( batch_size , int ) :
17091709 assert self .p .shape [0 ] == batch_size
17101710 assert self .q .shape [0 ] == batch_size
17111711
@@ -1991,7 +1991,7 @@ def current(self, V):
19911991
19921992 def reset_state (self , V , batch_size = None ):
19931993 self .p .value = self .f_p_inf (V )
1994- if batch_size is not None :
1994+ if isinstance ( batch_size , int ) :
19951995 assert self .p .shape [0 ] == batch_size
19961996
19971997 def f_p_inf (self , V ):
0 commit comments