Skip to content

Commit 5736efb

Browse files
generatedunixname89002005287564facebook-github-bot
authored andcommitted
torcheval
Reviewed By: jermenkoo Differential Revision: D68317779 fbshipit-source-id: 36e00f1766bbe0baf7138320f3092bfafe65ae2d
1 parent 2c7dfb3 commit 5736efb

File tree

128 files changed

+745
-799
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

128 files changed

+745
-799
lines changed

docs/source/ext/fbcode.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#!/usr/bin/env python3
2-
# -*- coding: utf-8 -*-
32
# Copyright (c) Meta Platforms, Inc. and affiliates.
43
# All rights reserved.
54
#

docs/update_docs.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,9 @@
88
import ast
99
import os
1010
from pathlib import Path
11-
from typing import List, Tuple
1211

1312

14-
def get_submodule_vars(filename) -> Tuple[str, List[str]]:
13+
def get_submodule_vars(filename) -> tuple[str, list[str]]:
1514
"""This function reads the init files which are inside of the metrics/ and functional/ subdirectories.
1615
Each subdirectory within these two folders are associated with a domain of metrics, e.g. classification or regression.
1716
This function assumes that each of the init files has two variables defined.
@@ -27,7 +26,7 @@ def get_submodule_vars(filename) -> Tuple[str, List[str]]:
2726
doc_name (str): defined above
2827
all_modules (List[str]): defined above
2928
"""
30-
with open(filename, "r") as file:
29+
with open(filename) as file:
3130
tree = ast.parse(file.read())
3231

3332
doc_name = None

setup.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@
1010

1111
from datetime import date
1212

13-
from typing import List
14-
1513
from setuptools import find_packages, setup
1614
from torcheval import __version__
1715

@@ -20,7 +18,7 @@ def current_path(file_name: str) -> str:
2018
return os.path.abspath(os.path.join(__file__, os.path.pardir, file_name))
2119

2220

23-
def read_requirements(file_name: str) -> List[str]:
21+
def read_requirements(file_name: str) -> list[str]:
2422
with open(current_path(file_name), encoding="utf8") as f:
2523
return f.read().strip().split()
2624

tests/metrics/aggregation/test_cat.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
# pyre-strict
88

9-
from typing import List
109

1110
import torch
1211
from torcheval.metrics import Cat
@@ -19,7 +18,7 @@
1918

2019
class TestCat(MetricClassTester):
2120
def _test_cat_class_with_input(
22-
self, input_val_tensors: List[torch.Tensor], dim: int = 0
21+
self, input_val_tensors: list[torch.Tensor], dim: int = 0
2322
) -> None:
2423
self.run_class_implementation_tests(
2524
metric=Cat(),

tests/metrics/aggregation/test_cov.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,14 @@
66

77
# pyre-strict
88

9-
from typing import List
109

1110
import torch
1211
from torcheval.metrics import Covariance
1312
from torcheval.utils.test_utils.metric_class_tester import MetricClassTester
1413

1514

1615
class TestCovariance(MetricClassTester):
17-
def _test_covariance_with_input(self, batching: List[int]) -> None:
16+
def _test_covariance_with_input(self, batching: list[int]) -> None:
1817
gen = torch.Generator()
1918
gen.manual_seed(3)
2019
X = torch.randn(sum(batching), 4, generator=gen)

tests/metrics/aggregation/test_mean.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
# pyre-strict
88

9-
from typing import List, Union
109

1110
import numpy as np
1211
import torch
@@ -77,8 +76,8 @@ def test_mean_class_update_input_valid_weight(self) -> None:
7776
]
7877

7978
def _compute_result(
80-
update_value: List[torch.Tensor],
81-
update_weight: List[Union[float, int, torch.Tensor]],
79+
update_value: list[torch.Tensor],
80+
update_weight: list[float | int | torch.Tensor],
8281
) -> torch.Tensor:
8382
weighted_sum = 0.0
8483
weights = 0.0

tests/metrics/aggregation/test_sum.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
# pyre-strict
88

9-
from typing import List, Union
109

1110
import torch
1211
from torcheval.metrics import Sum
@@ -68,8 +67,8 @@ def test_sum_class_update_input_valid_weight(self) -> None:
6867
]
6968

7069
def _compute_result(
71-
update_inputs: List[torch.Tensor],
72-
update_weights: List[Union[float, torch.Tensor]],
70+
update_inputs: list[torch.Tensor],
71+
update_weights: list[float | torch.Tensor],
7372
) -> torch.Tensor:
7473
weighted_sum = torch.tensor(0.0, dtype=torch.float64)
7574
for v, w in zip(update_inputs, update_weights):

tests/metrics/aggregation/test_throughput.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
# pyre-strict
88

99
import random
10-
from typing import List
1110

1211
from torcheval.metrics import Throughput
1312
from torcheval.utils.test_utils.metric_class_tester import (
@@ -20,8 +19,8 @@
2019
class TestThroughput(MetricClassTester):
2120
def _test_throughput_class_with_input(
2221
self,
23-
num_processed: List[int],
24-
elapsed_time_sec: List[float],
22+
num_processed: list[int],
23+
elapsed_time_sec: list[float],
2524
) -> None:
2625
num_individual_update = NUM_TOTAL_UPDATES // NUM_PROCESSES
2726
expected_num_total = sum(num_processed)

tests/metrics/audio/test_fad.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
# pyre-strict
88

99
import unittest
10-
from typing import Optional
1110

1211
import numpy as np
1312
import torch
@@ -39,7 +38,7 @@ def gen_sine_wave(
3938
freq: float = 600,
4039
length_seconds: float = 6,
4140
sample_rate: int = 16_000,
42-
std_dev: Optional[float] = None,
41+
std_dev: float | None = None,
4342
) -> torch.Tensor:
4443
"""Creates sine wave of the specified frequency, sample_rate and length."""
4544
t = np.linspace(0, length_seconds, int(length_seconds * sample_rate))
@@ -53,7 +52,7 @@ def gen_sine_wave(
5352
return torch.from_numpy(np.asarray(2**15 * samples, dtype=np.int16)).float()
5453

5554

56-
def gen_fad_test_batch(num_files: int, std_dev: Optional[float]) -> torch.Tensor:
55+
def gen_fad_test_batch(num_files: int, std_dev: float | None) -> torch.Tensor:
5756
"""Creates a tensor representing a batch of sine waves with optional
5857
gaussian noise added in."""
5958
frequencies = np.linspace(100, 1000, num_files).tolist()

tests/metrics/classification/test_auprc.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
# pyre-ignore-all-errors[56]: Pyre was not able to infer the type of argument
1010

11-
from typing import Optional, Tuple
1211

1312
import numpy as np
1413

@@ -427,7 +426,7 @@ def _get_sklearn_equivalent(
427426
self,
428427
input: torch.Tensor,
429428
target: torch.Tensor,
430-
average: Optional[str] = "macro",
429+
average: str | None = "macro",
431430
device: str = "cpu",
432431
) -> torch.Tensor:
433432
# Convert input/target to sklearn style inputs
@@ -445,7 +444,7 @@ def _get_sklearn_equivalent(
445444

446445
def _get_rand_inputs_multilabel(
447446
self, num_updates: int, num_labels: int, batch_size: int
448-
) -> Tuple[torch.Tensor, torch.Tensor]:
447+
) -> tuple[torch.Tensor, torch.Tensor]:
449448
input = torch.rand(size=[num_updates, batch_size, num_labels])
450449
targets = torch.randint(
451450
low=0, high=2, size=[num_updates, batch_size, num_labels]
@@ -457,7 +456,7 @@ def _check_against_sklearn(
457456
input: torch.Tensor,
458457
target: torch.Tensor,
459458
num_labels: int,
460-
average: Optional[str] = "macro",
459+
average: str | None = "macro",
461460
num_updates: int = NUM_TOTAL_UPDATES,
462461
) -> None:
463462
device = "cuda" if torch.cuda.is_available() else "cpu"

0 commit comments

Comments
 (0)