Skip to content

Commit 07980bc

Browse files
author
John Halloran
committed
feat: skip update_stretch() when rho = 0, reducing program to regular NMF
1 parent 026538c commit 07980bc

File tree

1 file changed

+33
-27
lines changed

1 file changed

+33
-27
lines changed

src/diffpy/snmf/snmf_class.py

Lines changed: 33 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,9 @@ def __init__(
6969
init_weights=None,
7070
init_components=None,
7171
init_stretch=None,
72-
max_iter=500,
72+
max_iter=100,
7373
min_iter=20,
74-
tol=5e-7,
74+
tol=1e-6,
7575
n_components=None,
7676
random_state=None,
7777
show_plots=False,
@@ -125,8 +125,8 @@ def __init__(
125125
n_components is not None and init_weights is not None
126126
):
127127
raise ValueError(
128-
"Conflicting source for n_components. Must provide either init_weights or n_components "
129-
"directly, but not both."
128+
"Conflicting or missing source for n_components. Must provide either init_weights "
129+
"or n_components directly, but not both."
130130
)
131131

132132
# Initialize weights and determine number of components
@@ -181,6 +181,7 @@ def fit(self, rho=0, eta=0, reset=True):
181181
rho : float Optional Default = 0
182182
The stretching factor that influences the decomposition. Zero corresponds to no
183183
stretching present. Relatively insensitive and typically adjusted in powers of 10.
184+
If equal to zero, program acts like standard NMF.
184185
eta : int Optional Default = 0
185186
The sparsity factor that influences the decomposition. Should be set to zero for
186187
non-sparse data such as PDF. Can be used to improve results for sparse data such
@@ -200,6 +201,10 @@ def fit(self, rho=0, eta=0, reset=True):
200201
self.rho = rho
201202
self.eta = eta
202203

204+
# If rho = 0, set stretching matrix to identity
205+
if self.rho == 0:
206+
self.stretch_ = np.ones_like(self.stretch_)
207+
203208
# Set up residual matrix, objective function, and history
204209
self.residuals = self.get_residual_matrix()
205210
self.objective_function = self.get_objective_function()
@@ -386,30 +391,31 @@ def outer_loop(self):
386391
):
387392
break
388393

389-
self.update_stretch()
390-
self.residuals = self.get_residual_matrix()
391-
self.objective_function = self.get_objective_function()
392-
print(
393-
f"Objective function after update_stretch: {self.objective_function:.5e}"
394-
)
395-
self._objective_history.append(self.objective_function)
396-
self.objective_difference = (
397-
self._objective_history[-2] - self._objective_history[-1]
398-
)
399-
if self.objective_function < self.best_objective:
400-
self.best_objective = self.objective_function
401-
self.best_matrices = [
402-
self.components_.copy(),
403-
self.weights_.copy(),
404-
self.stretch_.copy(),
405-
]
406-
if self.plotter is not None:
407-
self.plotter.update(
408-
components=self.components_,
409-
weights=self.weights_,
410-
stretch=self.stretch_,
411-
update_tag="stretch",
394+
if not self.rho == 0: # Don't update stretch if rho = 0
395+
self.update_stretch()
396+
self.residuals = self.get_residual_matrix()
397+
self.objective_function = self.get_objective_function()
398+
print(
399+
f"Objective function after update_stretch: {self.objective_function:.5e}"
400+
)
401+
self._objective_history.append(self.objective_function)
402+
self.objective_difference = (
403+
self._objective_history[-2] - self._objective_history[-1]
412404
)
405+
if self.objective_function < self.best_objective:
406+
self.best_objective = self.objective_function
407+
self.best_matrices = [
408+
self.components_.copy(),
409+
self.weights_.copy(),
410+
self.stretch_.copy(),
411+
]
412+
if self.plotter is not None:
413+
self.plotter.update(
414+
components=self.components_,
415+
weights=self.weights_,
416+
stretch=self.stretch_,
417+
update_tag="stretch",
418+
)
413419

414420
def get_residual_matrix(self, components=None, weights=None, stretch=None):
415421
"""

0 commit comments

Comments
 (0)