Skip to content

Commit e46b81e

Browse files
committed
add optimizationrecord best()
1 parent d3c9b04 commit e46b81e

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

src/torchlensmaker/optimize.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,13 @@ class OptimizationRecord:
5353
def plot(self) -> None:
5454
plot_optimization_record(self)
5555

56+
def best(self) -> None:
57+
best_loss, idx = self.loss.min(dim=0)
58+
print(f"Best loss {best_loss.item()} at iteration {idx + 1} / {self.num_iter}")
59+
for n, p in self.parameters.items():
60+
print(" ", n, p[idx])
61+
print()
62+
5663

5764
def optimize(
5865
optics: nn.Module,
@@ -126,7 +133,6 @@ def plot_optimization_record(record: OptimizationRecord) -> None:
126133
parameters = record.parameters
127134
loss = record.loss
128135

129-
130136
# Plot parameters and loss
131137
fig, (ax1, ax2) = plt.subplots(2, 1)
132138
epoch_range = torch.arange(0, record.num_iter)

0 commit comments

Comments
 (0)