Skip to content

Commit 45061f4

Browse files
committed
Fix lint errors in parts.py and processors.py.
PiperOrigin-RevId: 589085615
1 parent 8f293ae commit 45061f4

File tree

2 files changed

+19
-13
lines changed

2 files changed

+19
-13
lines changed

dqn_zoo/parts.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import csv
2323
import os
2424
import timeit
25-
from typing import Any, Iterable, Mapping, Optional, Tuple, Union
25+
from typing import Any, Iterable, Mapping, Optional, Sequence, Tuple, Union
2626

2727
import distrax
2828
import dm_env
@@ -123,7 +123,7 @@ def run_loop(
123123

124124

125125
def generate_statistics(
126-
trackers: Iterable[Any],
126+
trackers: Sequence[Any],
127127
timestep_action_sequence: Iterable[
128128
Tuple[
129129
dm_env.Environment,
@@ -329,7 +329,7 @@ def get(self) -> Mapping[str, float]:
329329
return self._statistics
330330

331331

332-
def make_default_trackers(initial_agent: Agent):
332+
def make_default_trackers(initial_agent: Agent) -> Sequence[Any]:
333333
return [
334334
EpisodeTracker(),
335335
StepRateTracker(),
@@ -461,7 +461,7 @@ def __init__(self, fname: str):
461461
self._header_written = False
462462
self._fieldnames = None
463463

464-
def write(self, values: collections.OrderedDict) -> None:
464+
def write(self, values: collections.OrderedDict[str, float]) -> None:
465465
"""Appends given values as new row to CSV file."""
466466
if self._fieldnames is None:
467467
self._fieldnames = values.keys()
@@ -479,7 +479,6 @@ def write(self, values: collections.OrderedDict) -> None:
479479

480480
def close(self) -> None:
481481
"""Closes the `CsvWriter`."""
482-
pass
483482

484483
def get_state(self) -> Mapping[str, Any]:
485484
"""Retrieves `CsvWriter` state as a `dict` (e.g. for serialization)."""

dqn_zoo/processors.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -94,15 +94,17 @@ def named_tuple_sequence_stack(values: Sequence[NamedTuple]) -> NamedTuple:
9494
class Deque:
9595
"""Double ended queue with a maximum length and initial values."""
9696

97-
def __init__(self, max_length: int, initial_values=None):
97+
def __init__(
98+
self, max_length: int, initial_values: Optional[Iterable[Any]] = None
99+
):
98100
self._deque = collections.deque(maxlen=max_length)
99101
self._initial_values = initial_values or []
100102

101103
def reset(self) -> None:
102104
self._deque.clear()
103105
self._deque.extend(self._initial_values)
104106

105-
def __call__(self, value: Any) -> collections.deque:
107+
def __call__(self, value: Any) -> collections.deque[Any]:
106108
self._deque.append(value)
107109
return self._deque
108110

@@ -297,19 +299,20 @@ def reduce_step_type(
297299
"""Outputs a representative step type from an array of step types."""
298300
# Zero padding will appear to be FIRST. Padding should only be seen before the
299301
# FIRST (e.g. 000F) or after LAST (e.g. ML00).
300-
if debug:
301-
np_step_types = np.array(step_types)
302+
302303
output_step_type = StepType.MID
303304
for i, step_type in enumerate(step_types):
304305
if step_type == 0: # step_type not actually FIRST, but we do expect 000F.
305-
if debug and not (np_step_types == 0).all():
306-
raise ValueError('Expected zero padding followed by FIRST.')
306+
if debug:
307+
if not (np.array(step_types) == 0).all():
308+
raise ValueError('Expected zero padding followed by FIRST.')
307309
output_step_type = StepType.FIRST
308310
break
309311
elif step_type == StepType.LAST:
310312
output_step_type = StepType.LAST
311-
if debug and not (np_step_types[i + 1 :] == 0).all():
312-
raise ValueError('Expected LAST to be followed by zero padding.')
313+
if debug:
314+
if not (np.array(step_types)[i + 1 :] == 0).all():
315+
raise ValueError('Expected LAST to be followed by zero padding.')
313316
break
314317
else:
315318
if step_type != StepType.MID:
@@ -343,8 +346,12 @@ def aggregate_discounts(
343346
raise ValueError(
344347
'All discounts should be 0 or 1, got: %s.' % np_discounts
345348
)
349+
else:
350+
np_discounts = None
351+
346352
if None in discounts:
347353
if debug:
354+
assert isinstance(np_discounts, np.ndarray)
348355
if not (np_discounts[-1] is None and (np_discounts[:-1] == 0).all()):
349356
# Should have [0, 0, 0, None] due to zero padding.
350357
raise ValueError('Should only have a None discount for FIRST.')

0 commit comments

Comments
 (0)