Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions .github/scripts/microbench_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import argparse
import bisect
from pathlib import Path
from typing import Dict, List

def main():
parser = argparse.ArgumentParser(
Expand Down Expand Up @@ -214,7 +213,7 @@ def get_op_pattern(base_op_name: str, get_backward: bool) -> tuple:
else:
return (base_op_name, f"{base_op_name} ")

def process_l1_loss(content: str, case_name: str, data: List, columns: List):
def process_l1_loss(content: str, case_name: str, data: list, columns: list):
shape_matches = list(re.finditer(r"(shape\s*[:=].*?)(?=\n\S|$)", content))
shape_lines = [match.group(0) for match in shape_matches]
shape_positions = [match.start() for match in shape_matches]
Expand Down Expand Up @@ -281,7 +280,7 @@ def process_l1_loss(content: str, case_name: str, data: List, columns: List):

data.append([record.get(col, "") for col in columns])

def extract_times(content: str, pattern: str, get_backward: bool) -> List:
def extract_times(content: str, pattern: str, get_backward: bool) -> list:
lines = content.split('\n')
results = []
for line in lines:
Expand All @@ -297,8 +296,8 @@ def extract_times(content: str, pattern: str, get_backward: bool) -> List:

return results

def create_record(params: Dict, case_name: str, op_name: str,
backward: str, time_us: float) -> Dict:
def create_record(params: dict, case_name: str, op_name: str,
backward: str, time_us: float) -> dict:
return {
"P": params.get("p", ""),
**params,
Expand All @@ -316,7 +315,7 @@ def convert_to_us(value: float, unit: str) -> float:
return value * 1_000_000
return value

def extract_params(text: str) -> Dict:
def extract_params(text: str) -> dict:
params = {}
pairs = re.split(r'[;]', text.strip())

Expand Down
2 changes: 1 addition & 1 deletion .github/scripts/op_perf_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def preprocess_row(row):
def display_row(record):
formatted = {}
for key, value in record.items():
if isinstance(value, (list, tuple, dict)):
if isinstance(value, list | tuple | dict):
formatted[key] = str(value)
elif value == "NULL":
formatted[key] = "NULL"
Expand Down
6 changes: 3 additions & 3 deletions .lintrunner.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ init_command = [
'mccabe==0.7.0',
'pycodestyle==2.11.1',
'pyflakes==3.1.0',
'torchfix==0.4.0 ; python_version >= "3.9" and python_version < "3.13"',
'torchfix==0.4.0 ; python_version < "3.13"',
]


Expand Down Expand Up @@ -83,11 +83,11 @@ init_command = [
'python3',
'tools/linter/adapters/pip_init.py',
'--dry-run={{DRYRUN}}',
'numpy==1.26.4 ; python_version >= "3.9" and python_version <= "3.11"',
'numpy==1.26.4 ; python_version <= "3.11"',
'numpy==2.1.0 ; python_version >= "3.12"',
'expecttest==0.3.0',
'mypy==1.13.0',
'sympy==1.13.0 ; python_version >= "3.9"',
'sympy==1.13.0',
'types-requests==2.27.25',
'types-PyYAML==6.0.7',
'types-tabulate==0.8.8',
Expand Down
2 changes: 1 addition & 1 deletion mypy-strict.ini
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# files.

[mypy]
python_version = 3.8
python_version = 3.10
plugins = mypy_plugins/check_mypy_version.py, numpy.typing.mypy_plugin

cache_dir = .mypy_cache/strict
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ ignore = [
"B023",
"B028", # No explicit `stacklevel` keyword argument found
"B904", # Migrate from TRY200
"B905", # Python>=3.10 specific suppressions
"E402",
"C408", # C408 ignored because we like the dict keyword argument syntax
"E501", # E501 is not flexible enough, we're using B950 instead
Expand Down
3 changes: 2 additions & 1 deletion test/xpu/distributed/test_c10d_ops_xccl.py
Original file line number Diff line number Diff line change
Expand Up @@ -918,7 +918,8 @@ def test_all_to_all_single_none(self):
out = torch.zeros(self.world_size, 2, dtype=send.dtype).to(device)
dist.all_to_all_single(out, send)
self.assertEqual(
out.tolist(), list(zip(range(self.world_size), range(self.world_size)))
out.tolist(),
list(zip(range(self.world_size), range(self.world_size))),
)


Expand Down
2 changes: 1 addition & 1 deletion test/xpu/test_modules_xpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def _test_multiple_device_transfer(self, device, dtype, module_info, training):
if torch.cuda.device_count() >= 2:
# === test cross-GPU transfer works
def _to_device1(objs):
if isinstance(objs, (tuple, list)):
if isinstance(objs, tuple | list):
return type(objs)(_to_device1(item) for item in objs)
elif isinstance(objs, dict):
return {name: _to_device1(item) for name, item in objs.items()}
Expand Down
8 changes: 4 additions & 4 deletions test/xpu/test_tensor_creation_ops_xpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import unittest
import warnings
from itertools import combinations, combinations_with_replacement, permutations, product
from typing import Any, Dict, List, Tuple
from typing import Any

import numpy as np
import torch
Expand Down Expand Up @@ -4852,11 +4852,11 @@ def test_astensor_consistency(self, device):
True,
42,
1.0,
# Homogeneous Lists
# Homogeneous lists
[True, True, False],
[1, 2, 3, 42],
[0.0, 1.0, 2.0, 3.0],
# Mixed Lists
# Mixed lists
[True, False, 0],
[0.0, True, False],
[0, 1.0, 42],
Expand Down Expand Up @@ -4894,7 +4894,7 @@ def test_numpy_scalars(self, device):
def test_default_device(self, device):
original = torch.arange(5)

examples: List[Tuple[Any, Dict]] = [
examples: list[tuple[Any, dict]] = [
(3, {}),
(original, {}),
(to_numpy(original), {}),
Expand Down
9 changes: 3 additions & 6 deletions test/xpu/test_torch_xpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from functools import partial
from itertools import chain, combinations, permutations, product
from multiprocessing.reduction import ForkingPickler
from typing import Tuple

import numpy as np
import torch
Expand Down Expand Up @@ -120,10 +119,8 @@
torch._C._get_privateuse1_backend_name(),
)

from typing import List


def my_get_all_device_types_xpu() -> List[str]:
def my_get_all_device_types_xpu() -> list[str]:
devices = [
"cpu",
]
Expand Down Expand Up @@ -487,7 +484,7 @@ def test_set_storage(self, device, dtype):

def _check_storage_meta(self, s, s_check):
self.assertTrue(
isinstance(s, (torch.UntypedStorage, torch.TypedStorage))
isinstance(s, torch.UntypedStorage | torch.TypedStorage)
and isinstance(s_check, type(s)),
(
"s and s_check must both be one of UntypedStorage or "
Expand Down Expand Up @@ -4138,7 +4135,7 @@ def test_errors_index_copy(self, device):

def _prepare_data_for_index_copy_and_add_deterministic(
self, dim: int, device: torch.device
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
assert dim >= 0 and dim < 3
a = [5, 4, 3]
a[dim] = 2000
Expand Down
3 changes: 2 additions & 1 deletion tools/linter/adapters/_linter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from functools import cached_property
from pathlib import Path
from tokenize import generate_tokens, TokenInfo
from typing import Any, Iterator, Sequence
from typing import Any
from collections.abc import Iterator, Sequence
from typing_extensions import Never


Expand Down
3 changes: 2 additions & 1 deletion tools/linter/adapters/docstring_linter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import token
from functools import cached_property
from pathlib import Path
from typing import Iterator, Sequence, TYPE_CHECKING
from typing import TYPE_CHECKING
from collections.abc import Iterator, Sequence


_PARENT = Path(__file__).parent.absolute()
Expand Down
Loading
Loading