Skip to content

Commit

Permalink
Merge branch 'main' into chore/invalid_status
Browse files Browse the repository at this point in the history
  • Loading branch information
yanksyoon authored Jan 11, 2024
2 parents d0ce202 + 261c441 commit 5db0e8a
Show file tree
Hide file tree
Showing 12 changed files with 75 additions and 64 deletions.
8 changes: 4 additions & 4 deletions ops/charm.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,7 @@ class RelationCreatedEvent(RelationEvent):
can occur before units for those applications have started. All existing
relations should be established before start.
"""
unit: None
unit: None # pyright: ignore[reportIncompatibleVariableOverride]
"""Always ``None``."""


Expand All @@ -481,7 +481,7 @@ class RelationJoinedEvent(RelationEvent):
remote ``private-address`` setting, which is always available when
the relation is created and is by convention not deleted.
"""
unit: model.Unit
unit: model.Unit # pyright: ignore[reportIncompatibleVariableOverride]
"""The remote unit that has triggered this event."""


Expand Down Expand Up @@ -523,7 +523,7 @@ class RelationDepartedEvent(RelationEvent):
Once all callback methods bound to this event have been run for such a
relation, the unit agent will fire the :class:`RelationBrokenEvent`.
"""
unit: model.Unit
unit: model.Unit # pyright: ignore[reportIncompatibleVariableOverride]
"""The remote unit that has triggered this event."""

def __init__(self, handle: 'Handle', relation: 'model.Relation',
Expand Down Expand Up @@ -580,7 +580,7 @@ class RelationBrokenEvent(RelationEvent):
bound to this event is being executed, it is guaranteed that no remote units
are currently known locally.
"""
unit: None
unit: None # pyright: ignore[reportIncompatibleVariableOverride]
"""Always ``None``."""


Expand Down
15 changes: 8 additions & 7 deletions ops/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,17 +277,18 @@ def get_secret(self, *, id: Optional[str] = None, label: Optional[str] = None) -
return Secret(self._backend, id=id, label=label, content=content)


if typing.TYPE_CHECKING:
# (entity type, name): instance.
_WeakCacheType = weakref.WeakValueDictionary[
Tuple['UnitOrApplicationType', str],
Optional[Union['Unit', 'Application']]]


class _ModelCache:
def __init__(self, meta: 'ops.charm.CharmMeta', backend: '_ModelBackend'):
if typing.TYPE_CHECKING:
# (entity type, name): instance.
_weakcachetype = weakref.WeakValueDictionary[
Tuple['UnitOrApplicationType', str],
Optional[Union['Unit', 'Application']]]

self._meta = meta
self._backend = backend
self._weakrefs: _weakcachetype = weakref.WeakValueDictionary()
self._weakrefs: _WeakCacheType = weakref.WeakValueDictionary()

@typing.overload
def get(self, entity_type: Type['Unit'], name: str) -> 'Unit': ... # noqa
Expand Down
8 changes: 4 additions & 4 deletions ops/pebble.py
Original file line number Diff line number Diff line change
Expand Up @@ -1614,7 +1614,7 @@ def _websocket_to_writer(ws: '_WebSocket', writer: '_WebsocketWriter',
break

if encoding is not None:
chunk = chunk.decode(encoding)
chunk = typing.cast(bytes, chunk).decode(encoding)
writer.write(chunk)


Expand Down Expand Up @@ -2019,7 +2019,7 @@ def _wait_change_using_wait(self, change_id: ChangeID, timeout: Optional[float])

def _wait_change(self, change_id: ChangeID, timeout: Optional[float] = None) -> Change:
"""Call the wait-change API endpoint directly."""
query = {}
query: Dict[str, Any] = {}
if timeout is not None:
query['timeout'] = _format_timeout(timeout)

Expand Down Expand Up @@ -2255,7 +2255,7 @@ def _encode_multipart(self, metadata: Dict[str, Any], path: str,
elif isinstance(source, bytes):
source_io: _AnyStrFileLikeIO = io.BytesIO(source)
else:
source_io: _AnyStrFileLikeIO = source
source_io: _AnyStrFileLikeIO = source # type: ignore
boundary = binascii.hexlify(os.urandom(16))
path_escaped = path.replace('"', '\\"').encode('utf-8') # NOQA: test_quote_backslashes
content_type = f"multipart/form-data; boundary=\"{boundary.decode('utf-8')}\"" # NOQA: test_quote_backslashes
Expand Down Expand Up @@ -2736,7 +2736,7 @@ def get_checks(
Returns:
List of :class:`CheckInfo` objects.
"""
query = {}
query: Dict[str, Any] = {}
if level is not None:
query['level'] = level.value
if names:
Expand Down
4 changes: 2 additions & 2 deletions ops/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import subprocess
from datetime import timedelta
from pathlib import Path
from typing import Any, Callable, Generator, List, Optional, Tuple, Union
from typing import Any, Callable, Generator, List, Optional, Tuple, Union, cast

import yaml # pyright: ignore[reportMissingModuleSource]

Expand Down Expand Up @@ -205,7 +205,7 @@ def notices(self, event_path: Optional[str] = None) -> '_NoticeGenerator':
if not rows:
break
for row in rows:
yield tuple(row)
yield cast(_Notice, tuple(row))


class JujuStorage:
Expand Down
5 changes: 3 additions & 2 deletions ops/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import shutil
import signal
import tempfile
import typing
import uuid
import warnings
from contextlib import contextmanager
Expand Down Expand Up @@ -3020,7 +3021,7 @@ def push(
file_path.write_bytes(source)
else:
# If source is binary, open file in binary mode and ignore encoding param
is_binary = isinstance(source.read(0), bytes)
is_binary = isinstance(source.read(0), bytes) # type: ignore
open_mode = 'wb' if is_binary else 'w'
open_encoding = None if is_binary else encoding
with file_path.open(open_mode, encoding=open_encoding) as f:
Expand Down Expand Up @@ -3144,7 +3145,7 @@ def _transform_exec_handler_output(self,
f"exec handler must return bytes if encoding is None,"
f"not {data.__class__.__name__}")
else:
return io.StringIO(data)
return io.StringIO(typing.cast(str, data))

def exec(
self,
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,5 @@ reportMissingModuleSource = false
reportPrivateUsage = false
reportUnnecessaryIsInstance = false
reportUnnecessaryComparison = false
disableBytesTypePromotions = false
stubPath = ""
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ flake8-builtins~=2.1
pyproject-flake8~=6.1
pep8-naming~=0.13
pytest~=7.2
pyright==1.1.317
pyright==1.1.345
pytest-operator~=0.23
coverage[toml]~=7.0
typing_extensions~=4.2
Expand Down
3 changes: 1 addition & 2 deletions test/test_charm.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,7 @@ def test_observe_decorated_method(self):
# way we know of to cleanly decorate charm event observers.
events: typing.List[ops.EventBase] = []

def dec(fn: typing.Callable[['MyCharm', ops.EventBase], None] # noqa: F821
) -> typing.Callable[..., None]:
def dec(fn: typing.Any) -> typing.Callable[..., None]:
# simple decorator that appends to the nonlocal
# `events` list all events it receives
@functools.wraps(fn)
Expand Down
52 changes: 28 additions & 24 deletions test/test_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -887,6 +887,31 @@ def _on_event(self, event: ops.EventBase):
'ObjectWithStorage[obj]/on/event[1]']))


MutableTypesTestCase = typing.Tuple[
typing.Callable[[], typing.Any], # Called to get operand A.
typing.Any, # Operand B.
typing.Any, # Expected result.
typing.Callable[[typing.Any, typing.Any], None], # Operation to perform.
typing.Callable[[typing.Any, typing.Any], typing.Any], # Validation to perform.
]

ComparisonOperationsTestCase = typing.Tuple[
typing.Any, # Operand A.
typing.Any, # Operand B.
typing.Callable[[typing.Any, typing.Any], bool], # Operation to test.
bool, # Result of op(A, B).
bool, # Result of op(B, A).
]

SetOperationsTestCase = typing.Tuple[
typing.Set[str], # A set to test an operation against (other_set).
# An operation to test.
typing.Callable[[typing.Set[str], typing.Set[str]], typing.Set[str]],
typing.Set[str], # The expected result of operation(obj._stored.set, other_set).
typing.Set[str], # The expected result of operation(other_set, obj._stored.set).
]


class TestStoredState(BaseTestCase):

def setUp(self):
Expand Down Expand Up @@ -1116,14 +1141,7 @@ def test_mutable_types(self):
# Test and validation functions in a list of tuples.
# Assignment and keywords like del are not supported in lambdas
# so functions are used instead.
test_case = typing.Tuple[
typing.Callable[[], typing.Any], # Called to get operand A.
typing.Any, # Operand B.
typing.Any, # Expected result.
typing.Callable[[typing.Any, typing.Any], None], # Operation to perform.
typing.Callable[[typing.Any, typing.Any], typing.Any], # Validation to perform.
]
test_operations: typing.List[test_case] = [(
test_operations: typing.List[MutableTypesTestCase] = [(
lambda: {},
None,
{},
Expand Down Expand Up @@ -1336,14 +1354,7 @@ def save_snapshot(self, value: typing.Union[ops.StoredStateData, ops.EventBase])
framework_copy.close()

def test_comparison_operations(self):
test_case = typing.Tuple[
typing.Any, # Operand A.
typing.Any, # Operand B.
typing.Callable[[typing.Any, typing.Any], bool], # Operation to test.
bool, # Result of op(A, B).
bool, # Result of op(B, A).
]
test_operations: typing.List[test_case] = [(
test_operations: typing.List[ComparisonOperationsTestCase] = [(
{"1"},
{"1", "2"},
lambda a, b: a < b,
Expand Down Expand Up @@ -1436,14 +1447,7 @@ class SomeObject(ops.Object):
self.assertEqual(op(b, obj._stored.a), op_ba)

def test_set_operations(self):
test_case = typing.Tuple[
typing.Set[str], # A set to test an operation against (other_set).
# An operation to test.
typing.Callable[[typing.Set[str], typing.Set[str]], typing.Set[str]],
typing.Set[str], # The expected result of operation(obj._stored.set, other_set).
typing.Set[str], # The expected result of operation(other_set, obj._stored.set).
]
test_operations: typing.List[test_case] = [(
test_operations: typing.List[SetOperationsTestCase] = [(
{"1"},
lambda a, b: a | b,
{"1", "a", "b"},
Expand Down
23 changes: 13 additions & 10 deletions test/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2378,7 +2378,14 @@ def test_unresolved_ingress_addresses(self):
self.assertEqual(binding.network.ingress_addresses, ['foo.bar.baz.com'])


_metric_and_label_pair = typing.Tuple[typing.Dict[str, float], typing.Dict[str, str]]
_MetricAndLabelPair = typing.Tuple[typing.Dict[str, float], typing.Dict[str, str]]


_ValidMetricsTestCase = typing.Tuple[
typing.Mapping[str, typing.Union[int, float]],
typing.Mapping[str, str],
typing.List[typing.List[str]],
]


class TestModelBackend(unittest.TestCase):
Expand Down Expand Up @@ -2851,12 +2858,8 @@ def test_juju_log(self):
[['juju-log', '--log-level', 'BAR', '--', 'foo']])

def test_valid_metrics(self):
_caselist = typing.List[typing.Tuple[
typing.Mapping[str, typing.Union[int, float]],
typing.Mapping[str, str],
typing.List[typing.List[str]]]]
fake_script(self, 'add-metric', 'exit 0')
test_cases: _caselist = [(
test_cases: typing.List[_ValidMetricsTestCase] = [(
OrderedDict([('foo', 42), ('b-ar', 4.5), ('ba_-z', 4.5), ('a', 1)]),
OrderedDict([('de', 'ad'), ('be', 'ef_ -')]),
[['add-metric', '--labels', 'de=ad,be=ef_ -',
Expand All @@ -2871,7 +2874,7 @@ def test_valid_metrics(self):
self.assertEqual(fake_script_calls(self, clear=True), expected_calls)

def test_invalid_metric_names(self):
invalid_inputs: typing.List[_metric_and_label_pair] = [
invalid_inputs: typing.List[_MetricAndLabelPair] = [
({'': 4.2}, {}),
({'1': 4.2}, {}),
({'1': -4.2}, {}),
Expand All @@ -2890,7 +2893,7 @@ def test_invalid_metric_names(self):
self.backend.add_metrics(metrics, labels)

def test_invalid_metric_values(self):
invalid_inputs: typing.List[_metric_and_label_pair] = [
invalid_inputs: typing.List[_MetricAndLabelPair] = [
({'a': float('+inf')}, {}),
({'a': float('-inf')}, {}),
({'a': float('nan')}, {}),
Expand All @@ -2902,7 +2905,7 @@ def test_invalid_metric_values(self):
self.backend.add_metrics(metrics, labels)

def test_invalid_metric_labels(self):
invalid_inputs: typing.List[_metric_and_label_pair] = [
invalid_inputs: typing.List[_MetricAndLabelPair] = [
({'foo': 4.2}, {'': 'baz'}),
({'foo': 4.2}, {',bar': 'baz'}),
({'foo': 4.2}, {'b=a=r': 'baz'}),
Expand All @@ -2913,7 +2916,7 @@ def test_invalid_metric_labels(self):
self.backend.add_metrics(metrics, labels)

def test_invalid_metric_label_values(self):
invalid_inputs: typing.List[_metric_and_label_pair] = [
invalid_inputs: typing.List[_MetricAndLabelPair] = [
({'foo': 4.2}, {'bar': ''}),
({'foo': 4.2}, {'bar': 'b,az'}),
({'foo': 4.2}, {'bar': 'b=az'}),
Expand Down
11 changes: 6 additions & 5 deletions test/test_pebble.py
Original file line number Diff line number Diff line change
Expand Up @@ -2425,7 +2425,7 @@ def _parse_write_multipart(self,
for part in message.walk():
name = part.get_param('name', header='Content-Disposition')
if name == 'request':
req = json.loads(part.get_payload())
req = json.loads(typing.cast(str, part.get_payload()))
elif name == 'files':
# decode=True, ironically, avoids decoding bytes to str
content = part.get_payload(decode=True)
Expand Down Expand Up @@ -3092,10 +3092,11 @@ def test_wait_exit_nonzero(self):
process = self.client.exec(['false'])
with self.assertRaises(pebble.ExecError) as cm:
process.wait()
self.assertEqual(cm.exception.command, ['false'])
self.assertEqual(cm.exception.exit_code, 1)
self.assertEqual(cm.exception.stdout, None)
self.assertEqual(cm.exception.stderr, None)
exc = typing.cast(pebble.ExecError[str], cm.exception)
self.assertEqual(exc.command, ['false'])
self.assertEqual(exc.exit_code, 1)
self.assertIsNone(exc.stdout)
self.assertIsNone(exc.stderr)

self.assertEqual(self.client.requests, [
('POST', '/v1/exec', None, self.build_exec_data(['false'])),
Expand Down
7 changes: 4 additions & 3 deletions test/test_real_pebble.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,9 +172,10 @@ def test_exec_wait_output(self):
with self.assertRaises(pebble.ExecError) as cm:
process = self.client.exec(['/bin/sh', '-c', 'echo OUT; echo ERR >&2; exit 42'])
process.wait_output()
self.assertEqual(cm.exception.exit_code, 42)
self.assertEqual(cm.exception.stdout, 'OUT\n')
self.assertEqual(cm.exception.stderr, 'ERR\n')
exc = typing.cast(pebble.ExecError[str], cm.exception)
self.assertEqual(exc.exit_code, 42)
self.assertEqual(exc.stdout, 'OUT\n')
self.assertEqual(exc.stderr, 'ERR\n')

def test_exec_send_stdin(self):
process = self.client.exec(['awk', '{ print toupper($0) }'], stdin='foo\nBar\n')
Expand Down

0 comments on commit 5db0e8a

Please sign in to comment.