Skip to content

Commit b8d205c

Browse files
3rdCorelebrice
andauthored
Fixed MeasureSamplesPerSecondCallback and changed configs (#138)
* Update wandb.yaml * Update wandb_cluster.yaml * Fix condition to check tensor dimensions * Set `entity` to `null` in config instead of "MyOrganisation" --------- Co-authored-by: Fabrice Normandin <[email protected]>
1 parent d37d4c3 commit b8d205c

File tree

3 files changed

+3
-1
lines changed

3 files changed

+3
-1
lines changed

project/algorithms/callbacks/samples_per_second.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ def get_num_samples(self, batch: BatchType) -> int:
190190
return next(
191191
v.shape[0]
192192
for v in optree.tree_leaves(batch) # type: ignore
193-
if isinstance(v, torch.Tensor) and v.ndim > 1
193+
if isinstance(v, torch.Tensor) and v.ndim >= 1
194194
)
195195
raise NotImplementedError(
196196
f"Don't know how many 'samples' there are in batch of type {type(batch)}"

project/configs/trainer/logger/wandb.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
wandb:
44
_target_: lightning.pytorch.loggers.wandb.WandbLogger
5+
entity: null # Optional. It can be useful to set this explicitly if you use a different organization account for this project.
56
project: "ResearchTemplate"
67
name: ${name}
78
save_dir: "."

project/configs/trainer/logger/wandb_cluster.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
defaults:
22
- wandb
33
wandb:
4+
entity: null # Optional. It can be useful to set this explicitly if you use a different organization account for this project.
45
project: "ResearchTemplate"
56
# TODO: Use the Orion trial name?
67
name: ${oc.env:SLURM_JOB_ID}_${oc.env:SLURM_PROCID}

0 commit comments

Comments
 (0)