Skip to content

Commit f15f089

Browse files
committed
Reformatting due to switch to ruff
1 parent c33d19f commit f15f089

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

57 files changed

+177
-189
lines changed

docs/nbstripout.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Implements a platform-independent way of calling nbstripout (used in pyproject.toml)."""
2+
23
import glob
34
import os
45
from pathlib import Path

examples/inverse/irl_gail.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ def dist(loc_scale: tuple[torch.Tensor, torch.Tensor]) -> Distribution:
247247
test_collector = Collector[CollectStats](algorithm, test_envs)
248248
# log
249249
t0 = datetime.datetime.now().strftime("%m%d_%H%M%S")
250-
log_file = f'seed_{args.seed}_{t0}-{args.task.replace("-", "_")}_gail'
250+
log_file = f"seed_{args.seed}_{t0}-{args.task.replace('-', '_')}_gail"
251251
log_path = os.path.join(args.logdir, args.task, "gail", log_file)
252252
writer = SummaryWriter(log_path)
253253
writer.add_text("args", str(args))

examples/offline/convert_rl_unplugged_atari.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
* episode_return: Total episode return computed using per-step [-1, 1]
2828
clipping.
2929
"""
30+
3031
import os
3132
from argparse import ArgumentParser, Namespace
3233

@@ -179,13 +180,16 @@ def download(url: str, fname: str, chunk_size: int | None = 1024) -> None:
179180
if os.path.exists(fname):
180181
print(f"Found cached file at {fname}.")
181182
return
182-
with open(fname, "wb") as ofile, tqdm(
183-
desc=fname,
184-
total=total,
185-
unit="iB",
186-
unit_scale=True,
187-
unit_divisor=1024,
188-
) as bar:
183+
with (
184+
open(fname, "wb") as ofile,
185+
tqdm(
186+
desc=fname,
187+
total=total,
188+
unit="iB",
189+
unit_scale=True,
190+
unit_divisor=1024,
191+
) as bar,
192+
):
189193
for data in resp.iter_content(chunk_size=chunk_size):
190194
size = ofile.write(data)
191195
bar.update(size)

test/base/env.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,9 @@ def __init__(
3434
random_sleep: bool = False,
3535
array_state: bool = False,
3636
) -> None:
37-
assert (
38-
dict_state + recurse_state + array_state <= 1
39-
), "dict_state / recurse_state / array_state can be only one true"
37+
assert dict_state + recurse_state + array_state <= 1, (
38+
"dict_state / recurse_state / array_state can be only one true"
39+
)
4040
self.size = size
4141
self.sleep = sleep
4242
self.random_sleep = random_sleep
@@ -208,9 +208,9 @@ def step(
208208

209209
class MyGoalEnv(MoveToRightEnv):
210210
def __init__(self, *args: Any, **kwargs: Any) -> None:
211-
assert (
212-
kwargs.get("dict_state", 0) + kwargs.get("recurse_state", 0) == 0
213-
), "dict_state / recurse_state not supported"
211+
assert kwargs.get("dict_state", 0) + kwargs.get("recurse_state", 0) == 0, (
212+
"dict_state / recurse_state not supported"
213+
)
214214
super().__init__(*args, **kwargs)
215215
super().reset(options={"state": 0})
216216

test/base/test_buffer.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import os
22
import pickle
33
import tempfile
4-
from test.base.env import MoveToRightEnv, MyGoalEnv
54
from typing import cast
65

76
import h5py
@@ -10,6 +9,7 @@
109
import pytest
1110
import torch
1211

12+
from test.base.env import MoveToRightEnv, MyGoalEnv
1313
from tianshou.data import (
1414
Batch,
1515
CachedReplayBuffer,
@@ -1451,10 +1451,7 @@ def test_custom_key() -> None:
14511451
# Check if they have the same keys
14521452
assert set(batch.get_keys()) == set(
14531453
sampled_batch.get_keys(),
1454-
), "Batches have different keys: {} and {}".format(
1455-
set(batch.get_keys()),
1456-
set(sampled_batch.get_keys()),
1457-
)
1454+
), f"Batches have different keys: {set(batch.get_keys())} and {set(sampled_batch.get_keys())}"
14581455
# Compare the values for each key
14591456
for key in batch.get_keys():
14601457
if isinstance(batch.__dict__[key], np.ndarray) and isinstance(

test/base/test_collector.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from collections.abc import Callable, Sequence
2-
from test.base.env import MoveToRightEnv, NXEnv
32
from typing import Any
43

54
import gymnasium as gym
65
import numpy as np
76
import pytest
87
import tqdm
98

9+
from test.base.env import MoveToRightEnv, NXEnv
1010
from tianshou.algorithm.algorithm_base import Policy, episode_mc_return_to_go
1111
from tianshou.data import (
1212
AsyncCollector,
@@ -410,11 +410,8 @@ def test_collector_with_dict_state() -> None:
410410
result = c1.collect(n_episode=8)
411411
assert result.n_collected_episodes == 8
412412
lens = np.bincount(result.lens)
413-
assert (
414-
result.n_collected_steps == 21
415-
and np.all(lens == [0, 0, 2, 2, 2, 2])
416-
or result.n_collected_steps == 20
417-
and np.all(lens == [0, 0, 3, 1, 2, 2])
413+
assert (result.n_collected_steps == 21 and np.all(lens == [0, 0, 2, 2, 2, 2])) or (
414+
result.n_collected_steps == 20 and np.all(lens == [0, 0, 3, 1, 2, 2])
418415
)
419416
batch, _ = c1.buffer.sample(10)
420417
c0.buffer.update(c1.buffer)

test/base/test_env.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
import sys
22
import time
33
from collections.abc import Callable
4-
from test.base.env import MoveToRightEnv, NXEnv
54
from typing import Any, Literal
65

76
import gymnasium as gym
87
import numpy as np
98
import pytest
109
from gymnasium.spaces.discrete import Discrete
1110

11+
from test.base.env import MoveToRightEnv, NXEnv
1212
from tianshou.data import Batch
1313
from tianshou.env import (
1414
ContinuousToDiscrete,

test/continuous/test_ddpg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
import argparse
22
import os
3-
from test.determinism_test import AlgorithmDeterminismTest
43

54
import gymnasium as gym
65
import numpy as np
76
import torch
87
from torch.utils.tensorboard import SummaryWriter
98

9+
from test.determinism_test import AlgorithmDeterminismTest
1010
from tianshou.algorithm import DDPG
1111
from tianshou.algorithm.algorithm_base import Algorithm
1212
from tianshou.algorithm.modelfree.ddpg import ContinuousDeterministicPolicy

test/continuous/test_npg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import argparse
22
import os
3-
from test.determinism_test import AlgorithmDeterminismTest
43

54
import gymnasium as gym
65
import numpy as np
@@ -9,6 +8,7 @@
98
from torch.distributions import Distribution, Independent, Normal
109
from torch.utils.tensorboard import SummaryWriter
1110

11+
from test.determinism_test import AlgorithmDeterminismTest
1212
from tianshou.algorithm import NPG
1313
from tianshou.algorithm.algorithm_base import Algorithm
1414
from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy

test/continuous/test_ppo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
import argparse
22
import os
3-
from test.determinism_test import AlgorithmDeterminismTest
43

54
import gymnasium as gym
65
import numpy as np
76
import torch
87
from torch.distributions import Distribution, Independent, Normal
98
from torch.utils.tensorboard import SummaryWriter
109

10+
from test.determinism_test import AlgorithmDeterminismTest
1111
from tianshou.algorithm import PPO
1212
from tianshou.algorithm.algorithm_base import Algorithm
1313
from tianshou.algorithm.modelfree.reinforce import ProbabilisticActorPolicy

0 commit comments

Comments
 (0)