From 45061f4bbbcfa87d11bbba3cfc2305a650a41c26 Mon Sep 17 00:00:00 2001 From: johnquan Date: Fri, 8 Dec 2023 12:16:59 +0000 Subject: [PATCH] Fix lint errors in `parts.py` and `processors.py`. PiperOrigin-RevId: 589085615 --- dqn_zoo/parts.py | 9 ++++----- dqn_zoo/processors.py | 23 +++++++++++++++-------- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/dqn_zoo/parts.py b/dqn_zoo/parts.py index 14276bc..ea28898 100644 --- a/dqn_zoo/parts.py +++ b/dqn_zoo/parts.py @@ -22,7 +22,7 @@ import csv import os import timeit -from typing import Any, Iterable, Mapping, Optional, Tuple, Union +from typing import Any, Iterable, Mapping, Optional, Sequence, Tuple, Union import distrax import dm_env @@ -123,7 +123,7 @@ def run_loop( def generate_statistics( - trackers: Iterable[Any], + trackers: Sequence[Any], timestep_action_sequence: Iterable[ Tuple[ dm_env.Environment, @@ -329,7 +329,7 @@ def get(self) -> Mapping[str, float]: return self._statistics -def make_default_trackers(initial_agent: Agent): +def make_default_trackers(initial_agent: Agent) -> Sequence[Any]: return [ EpisodeTracker(), StepRateTracker(), @@ -461,7 +461,7 @@ def __init__(self, fname: str): self._header_written = False self._fieldnames = None - def write(self, values: collections.OrderedDict) -> None: + def write(self, values: collections.OrderedDict[str, float]) -> None: """Appends given values as new row to CSV file.""" if self._fieldnames is None: self._fieldnames = values.keys() @@ -479,7 +479,6 @@ def write(self, values: collections.OrderedDict) -> None: def close(self) -> None: """Closes the `CsvWriter`.""" - pass def get_state(self) -> Mapping[str, Any]: """Retrieves `CsvWriter` state as a `dict` (e.g. for serialization).""" diff --git a/dqn_zoo/processors.py b/dqn_zoo/processors.py index 83a0f2d..7774de3 100644 --- a/dqn_zoo/processors.py +++ b/dqn_zoo/processors.py @@ -94,7 +94,9 @@ def named_tuple_sequence_stack(values: Sequence[NamedTuple]) -> NamedTuple: class Deque: """Double ended queue with a maximum length and initial values.""" - def __init__(self, max_length: int, initial_values=None): + def __init__( + self, max_length: int, initial_values: Optional[Iterable[Any]] = None + ): self._deque = collections.deque(maxlen=max_length) self._initial_values = initial_values or [] @@ -102,7 +104,7 @@ def reset(self) -> None: self._deque.clear() self._deque.extend(self._initial_values) - def __call__(self, value: Any) -> collections.deque: + def __call__(self, value: Any) -> collections.deque[Any]: self._deque.append(value) return self._deque @@ -297,19 +299,20 @@ def reduce_step_type( """Outputs a representative step type from an array of step types.""" # Zero padding will appear to be FIRST. Padding should only be seen before the # FIRST (e.g. 000F) or after LAST (e.g. ML00). - if debug: - np_step_types = np.array(step_types) + output_step_type = StepType.MID for i, step_type in enumerate(step_types): if step_type == 0: # step_type not actually FIRST, but we do expect 000F. - if debug and not (np_step_types == 0).all(): - raise ValueError('Expected zero padding followed by FIRST.') + if debug: + if not (np.array(step_types) == 0).all(): + raise ValueError('Expected zero padding followed by FIRST.') output_step_type = StepType.FIRST break elif step_type == StepType.LAST: output_step_type = StepType.LAST - if debug and not (np_step_types[i + 1 :] == 0).all(): - raise ValueError('Expected LAST to be followed by zero padding.') + if debug: + if not (np.array(step_types)[i + 1 :] == 0).all(): + raise ValueError('Expected LAST to be followed by zero padding.') break else: if step_type != StepType.MID: @@ -343,8 +346,12 @@ def aggregate_discounts( raise ValueError( 'All discounts should be 0 or 1, got: %s.' % np_discounts ) + else: + np_discounts = None + if None in discounts: if debug: + assert isinstance(np_discounts, np.ndarray) if not (np_discounts[-1] is None and (np_discounts[:-1] == 0).all()): # Should have [0, 0, 0, None] due to zero padding. raise ValueError('Should only have a None discount for FIRST.')