@@ -69,9 +69,9 @@ def __init__(
6969 init_weights = None ,
7070 init_components = None ,
7171 init_stretch = None ,
72- max_iter = 500 ,
72+ max_iter = 100 ,
7373 min_iter = 20 ,
74- tol = 5e-7 ,
74+ tol = 1e-6 ,
7575 n_components = None ,
7676 random_state = None ,
7777 show_plots = False ,
@@ -125,8 +125,8 @@ def __init__(
125125 n_components is not None and init_weights is not None
126126 ):
127127 raise ValueError (
128- "Conflicting source for n_components. Must provide either init_weights or n_components "
129- "directly, but not both."
128+ "Conflicting or missing source for n_components. Must provide either init_weights "
129+ "or n_components directly, but not both."
130130 )
131131
132132 # Initialize weights and determine number of components
@@ -181,6 +181,7 @@ def fit(self, rho=0, eta=0, reset=True):
181181 rho : float Optional Default = 0
182182 The stretching factor that influences the decomposition. Zero corresponds to no
183183 stretching present. Relatively insensitive and typically adjusted in powers of 10.
184+ If equal to zero, program acts like standard NMF.
184185 eta : int Optional Default = 0
185186 The sparsity factor that influences the decomposition. Should be set to zero for
186187 non-sparse data such as PDF. Can be used to improve results for sparse data such
@@ -200,6 +201,10 @@ def fit(self, rho=0, eta=0, reset=True):
200201 self .rho = rho
201202 self .eta = eta
202203
204+ # If rho = 0, set stretching matrix to identity
205+ if self .rho == 0 :
206+ self .stretch_ = np .ones_like (self .stretch_ )
207+
203208 # Set up residual matrix, objective function, and history
204209 self .residuals = self .get_residual_matrix ()
205210 self .objective_function = self .get_objective_function ()
@@ -386,30 +391,31 @@ def outer_loop(self):
386391 ):
387392 break
388393
389- self .update_stretch ()
390- self .residuals = self .get_residual_matrix ()
391- self .objective_function = self .get_objective_function ()
392- print (
393- f"Objective function after update_stretch: { self .objective_function :.5e} "
394- )
395- self ._objective_history .append (self .objective_function )
396- self .objective_difference = (
397- self ._objective_history [- 2 ] - self ._objective_history [- 1 ]
398- )
399- if self .objective_function < self .best_objective :
400- self .best_objective = self .objective_function
401- self .best_matrices = [
402- self .components_ .copy (),
403- self .weights_ .copy (),
404- self .stretch_ .copy (),
405- ]
406- if self .plotter is not None :
407- self .plotter .update (
408- components = self .components_ ,
409- weights = self .weights_ ,
410- stretch = self .stretch_ ,
411- update_tag = "stretch" ,
394+ if not self .rho == 0 : # Don't update stretch if rho = 0
395+ self .update_stretch ()
396+ self .residuals = self .get_residual_matrix ()
397+ self .objective_function = self .get_objective_function ()
398+ print (
399+ f"Objective function after update_stretch: { self .objective_function :.5e} "
400+ )
401+ self ._objective_history .append (self .objective_function )
402+ self .objective_difference = (
403+ self ._objective_history [- 2 ] - self ._objective_history [- 1 ]
412404 )
405+ if self .objective_function < self .best_objective :
406+ self .best_objective = self .objective_function
407+ self .best_matrices = [
408+ self .components_ .copy (),
409+ self .weights_ .copy (),
410+ self .stretch_ .copy (),
411+ ]
412+ if self .plotter is not None :
413+ self .plotter .update (
414+ components = self .components_ ,
415+ weights = self .weights_ ,
416+ stretch = self .stretch_ ,
417+ update_tag = "stretch" ,
418+ )
413419
414420 def get_residual_matrix (self , components = None , weights = None , stretch = None ):
415421 """
0 commit comments