Skip to content

Commit

Permalink
Fix lint errors in parts.py and processors.py.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 589085615
  • Loading branch information
jqdm committed Dec 10, 2023
1 parent 8f293ae commit 45061f4
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 13 deletions.
9 changes: 4 additions & 5 deletions dqn_zoo/parts.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import csv
import os
import timeit
from typing import Any, Iterable, Mapping, Optional, Tuple, Union
from typing import Any, Iterable, Mapping, Optional, Sequence, Tuple, Union

import distrax
import dm_env
Expand Down Expand Up @@ -123,7 +123,7 @@ def run_loop(


def generate_statistics(
trackers: Iterable[Any],
trackers: Sequence[Any],
timestep_action_sequence: Iterable[
Tuple[
dm_env.Environment,
Expand Down Expand Up @@ -329,7 +329,7 @@ def get(self) -> Mapping[str, float]:
return self._statistics


def make_default_trackers(initial_agent: Agent):
def make_default_trackers(initial_agent: Agent) -> Sequence[Any]:
return [
EpisodeTracker(),
StepRateTracker(),
Expand Down Expand Up @@ -461,7 +461,7 @@ def __init__(self, fname: str):
self._header_written = False
self._fieldnames = None

def write(self, values: collections.OrderedDict) -> None:
def write(self, values: collections.OrderedDict[str, float]) -> None:
"""Appends given values as new row to CSV file."""
if self._fieldnames is None:
self._fieldnames = values.keys()
Expand All @@ -479,7 +479,6 @@ def write(self, values: collections.OrderedDict) -> None:

def close(self) -> None:
"""Closes the `CsvWriter`."""
pass

def get_state(self) -> Mapping[str, Any]:
"""Retrieves `CsvWriter` state as a `dict` (e.g. for serialization)."""
Expand Down
23 changes: 15 additions & 8 deletions dqn_zoo/processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,15 +94,17 @@ def named_tuple_sequence_stack(values: Sequence[NamedTuple]) -> NamedTuple:
class Deque:
"""Double ended queue with a maximum length and initial values."""

def __init__(self, max_length: int, initial_values=None):
def __init__(
self, max_length: int, initial_values: Optional[Iterable[Any]] = None
):
self._deque = collections.deque(maxlen=max_length)
self._initial_values = initial_values or []

def reset(self) -> None:
self._deque.clear()
self._deque.extend(self._initial_values)

def __call__(self, value: Any) -> collections.deque:
def __call__(self, value: Any) -> collections.deque[Any]:
self._deque.append(value)
return self._deque

Expand Down Expand Up @@ -297,19 +299,20 @@ def reduce_step_type(
"""Outputs a representative step type from an array of step types."""
# Zero padding will appear to be FIRST. Padding should only be seen before the
# FIRST (e.g. 000F) or after LAST (e.g. ML00).
if debug:
np_step_types = np.array(step_types)

output_step_type = StepType.MID
for i, step_type in enumerate(step_types):
if step_type == 0: # step_type not actually FIRST, but we do expect 000F.
if debug and not (np_step_types == 0).all():
raise ValueError('Expected zero padding followed by FIRST.')
if debug:
if not (np.array(step_types) == 0).all():
raise ValueError('Expected zero padding followed by FIRST.')
output_step_type = StepType.FIRST
break
elif step_type == StepType.LAST:
output_step_type = StepType.LAST
if debug and not (np_step_types[i + 1 :] == 0).all():
raise ValueError('Expected LAST to be followed by zero padding.')
if debug:
if not (np.array(step_types)[i + 1 :] == 0).all():
raise ValueError('Expected LAST to be followed by zero padding.')
break
else:
if step_type != StepType.MID:
Expand Down Expand Up @@ -343,8 +346,12 @@ def aggregate_discounts(
raise ValueError(
'All discounts should be 0 or 1, got: %s.' % np_discounts
)
else:
np_discounts = None

if None in discounts:
if debug:
assert isinstance(np_discounts, np.ndarray)
if not (np_discounts[-1] is None and (np_discounts[:-1] == 0).all()):
# Should have [0, 0, 0, None] due to zero padding.
raise ValueError('Should only have a None discount for FIRST.')
Expand Down

0 comments on commit 45061f4

Please sign in to comment.