Skip to content

Commit e31cdb4

Browse files
authored
Merge pull request #34 from Vance-Raiti/main
Remove depreciated torchmetrics __init__ argument and fix str_to_one_hot for CUDA devices
2 parents 243151b + a46cf5a commit e31cdb4

2 files changed

Lines changed: 4 additions & 4 deletions

File tree

enformer_pytorch/metrics.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,10 @@
55

66
class MeanPearsonCorrCoefPerChannel(Metric):
77
is_differentiable: Optional[bool] = False
8-
full_state_update:bool = False
98
higher_is_better: Optional[bool] = True
109
def __init__(self, n_channels:int, dist_sync_on_step=False):
1110
"""Calculates the mean pearson correlation across channels aggregated over regions"""
12-
super().__init__(dist_sync_on_step=dist_sync_on_step, full_state_update=False)
11+
super().__init__(dist_sync_on_step=dist_sync_on_step)
1312
self.reduce_dims=(0, 1)
1413
self.add_state("product", default=torch.zeros(n_channels, dtype=torch.float32), dist_reduce_fx="sum", )
1514
self.add_state("true", default=torch.zeros(n_channels, dtype=torch.float32), dist_reduce_fx="sum", )
@@ -41,4 +40,4 @@ def compute(self):
4140
pred_var = self.pred_squared - self.count * torch.square(pred_mean)
4241
tp_var = torch.sqrt(true_var) * torch.sqrt(pred_var)
4342
correlation = covariance / tp_var
44-
return correlation
43+
return correlation

enformer_pytorch/modeling_enformer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -446,8 +446,9 @@ def forward(
446446
if isinstance(x, list):
447447
x = str_to_one_hot(x)
448448

449-
elif x.dtype == torch.long:
449+
elif type(x) == torch.Tensor and x.dtype == torch.long:
450450
x = seq_indices_to_one_hot(x)
451+
x.to(self.device)
451452

452453
no_batch = x.ndim == 2
453454

0 commit comments

Comments
 (0)