From 91ab8cf1d41359ef892d0cbde54c7b9198204243 Mon Sep 17 00:00:00 2001 From: Enzo Bonnal Date: Sun, 11 May 2025 17:46:02 +0100 Subject: [PATCH 1/7] 1.6.0a1: Make `Stream` an `AsyncIterable` decorator and an `Awaitable[Stream]`; add all async twin operations (closes #8 #100 #101) (#102) --- .github/workflows/pypi.yml | 3 + CONTRIBUTING.md | 2 +- README.md | 416 ++++-- setup.py | 3 +- streamable/__init__.py | 2 + streamable/afunctions.py | 353 +++++ streamable/aiterators.py | 924 +++++++++++++ streamable/functions.py | 118 +- streamable/iterators.py | 143 +- streamable/stream.py | 403 +++++- streamable/util/asynctools.py | 26 + streamable/util/errors.py | 4 + streamable/util/functiontools.py | 73 +- streamable/util/futuretools.py | 16 +- streamable/util/iterabletools.py | 62 + streamable/util/validationtools.py | 19 +- streamable/visitors/__init__.py | 2 + streamable/visitors/aiterator.py | 227 ++++ streamable/visitors/base.py | 24 + streamable/visitors/equality.py | 144 ++- streamable/visitors/iterator.py | 98 +- streamable/visitors/representation.py | 96 +- tests/test_iterators.py | 40 +- tests/test_readme.py | 98 +- tests/test_stream.py | 1728 ++++++++++++++++++------- tests/test_visitor.py | 16 + tests/utils.py | 215 +++ version.py | 2 +- 28 files changed, 4578 insertions(+), 679 deletions(-) create mode 100644 streamable/afunctions.py create mode 100644 streamable/aiterators.py create mode 100644 streamable/util/asynctools.py create mode 100644 streamable/util/errors.py create mode 100644 streamable/util/iterabletools.py create mode 100644 streamable/visitors/aiterator.py create mode 100644 tests/utils.py diff --git a/.github/workflows/pypi.yml b/.github/workflows/pypi.yml index 02bed3e0..09157371 100644 --- a/.github/workflows/pypi.yml +++ b/.github/workflows/pypi.yml @@ -4,6 +4,9 @@ on: push: branches: - main + - alpha + - beta + - rc paths: - 'version.py' diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 0624e37c..06a6cebc 100644 --- a/CONTRIBUTING.md +++ 
b/CONTRIBUTING.md @@ -27,7 +27,7 @@ is relatively close to Applying an operation simply performs argument validation and returns a new instance of a `Stream` sub-class corresponding to the operation. Defining a stream is basically constructing a composite structure where each node is a `Stream` instance: The above `events` variable holds a `CatchStream[int]` instance, whose `.upstream` attribute points to a `TruncateStream[int]` instance, whose `.upstream` attribute points to a `ForeachStream[int]` instance, whose `.upstream` points to a `Stream[int]` instance. Each node's `.source` attribute points to the same `range(10)`. ## Visitor Pattern -Each node in this composite structure exposes an `.accept` method enabling traversal by a visitor. Both `Stream.__iter__` and `Stream.__repr__` rely on visitor classes defined in the `streamable.visitors` package. +Each node in this composite structure exposes an `.accept` method enabling traversal by a visitor. `Stream.__iter__`/`Stream.__aiter__`/`Stream.__repr__` rely on visitor classes defined in the `streamable.visitors` package. ## Decorator Pattern A `Stream[T]` both inherits from `Iterable[T]` and holds an `Iterable[T]` as its `.source`: when you instantiate a stream from an iterable you decorate it with a fluent interface. 
diff --git a/README.md b/README.md index 4bf64c0b..1a799397 100644 --- a/README.md +++ b/README.md @@ -7,11 +7,11 @@ # ༄ `streamable` -### *Pythonic Stream-like manipulation of iterables* +### *Pythonic Stream-like manipulation of (async) iterables* - 🔗 ***Fluent*** chainable lazy operations -- 🔀 ***Concurrent*** via *threads*/*processes*/`asyncio` -- 🇹 ***Typed***, fully annotated, `Stream[T]` is an `Iterable[T]` +- 🔀 ***Concurrent*** via *threads*/*processes*/`async` +- 🇹 ***Typed***, fully annotated, `Stream[T]` is both an `Iterable[T]` and an `AsyncIterable[T]` - 🛡️ ***Tested*** extensively with **Python 3.7 to 3.14** - 🪶 ***Light***, no dependencies @@ -35,7 +35,7 @@ from streamable import Stream ## 3. init -Create a `Stream[T]` *decorating* an `Iterable[T]`: +Create a `Stream[T]` *decorating* an `Iterable[T]` (or an `AsyncIterable[T]`): ```python integers: Stream[int] = Stream(range(10)) @@ -55,9 +55,11 @@ inverses: Stream[float] = ( ## 5. iterate -Iterate over a `Stream[T]` just as you would over any other `Iterable[T]`, elements are processed *on-the-fly*: +Iterate over a `Stream[T]` just as you would over any other `Iterable[T]`/`AsyncIterable[T]`, elements are processed *on-the-fly*: -- **collect** + +### as `Iterable[T]` +- **into data structure** ```python >>> list(inverses) [1.0, 0.5, 0.33, 0.25, 0.2, 0.17, 0.14, 0.12, 0.11] @@ -65,7 +67,13 @@ Iterate over a `Stream[T]` just as you would over any other `Iterable[T]`, eleme {0.5, 1.0, 0.2, 0.33, 0.25, 0.17, 0.14, 0.12, 0.11} ``` -- **reduce** +- **`for`** +```python +>>> [inverse for inverse in inverses] +[1.0, 0.5, 0.33, 0.25, 0.2, 0.17, 0.14, 0.12, 0.11] +``` + +- **`reduce`** ```python >>> sum(inverses) 2.82 @@ -73,23 +81,139 @@ Iterate over a `Stream[T]` just as you would over any other `Iterable[T]`, eleme >>> reduce(..., inverses) ``` -- **loop** +- **`iter`/`next`** +```python +>>> next(iter(inverses)) +1.0 +``` + +### as `AsyncIterable[T]` + +- **`async for`** ```python 
->>> for inverse in inverses: ->>> ... +>>> async def main() -> List[float]: +>>> return [inverse async for inverse in inverses] + +>>> asyncio.run(main()) +[1.0, 0.5, 0.33, 0.25, 0.2, 0.17, 0.14, 0.12, 0.11] ``` -- **next** +- **`aiter`/`anext`** ```python ->>> next(iter(inverses)) +>>> asyncio.run(anext(aiter(inverses))) # before 3.10: inverses.__aiter__().__anext__() 1.0 ``` + + +# ↔️ **Extract-Transform-Load** + +**ETL scripts** can benefit from the expressiveness of this library. Below is a pipeline that extracts the 67 quadruped PokΓ©mon from the first three generations using [PokΓ©API](https://pokeapi.co/) and loads them into a CSV: + +```python +import csv +from datetime import timedelta +import itertools +import requests +from streamable import Stream + +with open("./quadruped_pokemons.csv", mode="w") as file: + fields = ["id", "name", "is_legendary", "base_happiness", "capture_rate"] + writer = csv.DictWriter(file, fields, extrasaction='ignore') + writer.writeheader() + + pipeline: Stream = ( + # Infinite Stream[int] of Pokemon ids starting from PokΓ©mon #1: Bulbasaur + Stream(itertools.count(1)) + # Limits to 16 requests per second to be friendly to our fellow PokΓ©API devs + .throttle(16, per=timedelta(seconds=1)) + # GETs pokemons concurrently using a pool of 8 threads + .map(lambda poke_id: f"https://pokeapi.co/api/v2/pokemon-species/{poke_id}") + .map(requests.get, concurrency=8) + .foreach(requests.Response.raise_for_status) + .map(requests.Response.json) + # Stops the iteration when reaching the 1st pokemon of the 4th generation + .truncate(when=lambda poke: poke["generation"]["name"] == "generation-iv") + .observe("pokemons") + # Keeps only quadruped Pokemons + .filter(lambda poke: poke["shape"]["name"] == "quadruped") + .observe("quadruped pokemons") + # Catches errors due to None "generation" or "shape" + .catch( + TypeError, + when=lambda error: str(error) == "'NoneType' object is not subscriptable" + ) + # Writes a batch of pokemons every 5 
seconds to the CSV file + .group(interval=timedelta(seconds=5)) + .foreach(writer.writerows) + .flatten() + .observe("written pokemons") + # Catches exceptions and raises the 1st one at the end of the iteration + .catch(Exception, finally_raise=True) + ) + + pipeline() +``` + +## or the `async` way + +Using `.amap` (the `.map`'s `async` counterpart), and `await`ing the stream to exhaust it as an `AsyncIterable[T]`: + +```python +import asyncio +import csv +from datetime import timedelta +import itertools +import httpx +from streamable import Stream + +async def main() -> None: + with open("./quadruped_pokemons.csv", mode="w") as file: + fields = ["id", "name", "is_legendary", "base_happiness", "capture_rate"] + writer = csv.DictWriter(file, fields, extrasaction='ignore') + writer.writeheader() + + async with httpx.AsyncClient() as http_async_client: + pipeline: Stream = ( + # Infinite Stream[int] of Pokemon ids starting from PokΓ©mon #1: Bulbasaur + Stream(itertools.count(1)) + # Limits to 16 requests per second to be friendly to our fellow PokΓ©API devs + .throttle(16, per=timedelta(seconds=1)) + # GETs pokemons via 8 concurrent asyncio coroutines + .map(lambda poke_id: f"https://pokeapi.co/api/v2/pokemon-species/{poke_id}") + .amap(http_async_client.get, concurrency=8) + .foreach(httpx.Response.raise_for_status) + .map(httpx.Response.json) + # Stops the iteration when reaching the 1st pokemon of the 4th generation + .truncate(when=lambda poke: poke["generation"]["name"] == "generation-iv") + .observe("pokemons") + # Keeps only quadruped Pokemons + .filter(lambda poke: poke["shape"]["name"] == "quadruped") + .observe("quadruped pokemons") + # Catches errors due to None "generation" or "shape" + .catch( + TypeError, + when=lambda error: str(error) == "'NoneType' object is not subscriptable" + ) + # Writes a batch of pokemons every 5 seconds to the CSV file + .group(interval=timedelta(seconds=5)) + .foreach(writer.writerows) + .flatten() + .observe("written 
pokemons") + # Catches exceptions and raises the 1st one at the end of the iteration + .catch(Exception, finally_raise=True) + ) + + await pipeline + +asyncio.run(main()) +``` + # 📒 ***Operations*** *A dozen expressive lazy operations and that’s it!* -# `.map` +## `.map` > Applies a transformation on elements: @@ -102,7 +226,7 @@ assert list(integer_strings) == ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9 ``` -## concurrency +### concurrency > [!NOTE] > By default, all the concurrency modes presented below yield results in the upstream order (FIFO). Set the parameter `ordered=False` to yield results as they become available (***First Done, First Out***). @@ -149,32 +273,11 @@ if __name__ == "__main__": ``` -### `asyncio`-based concurrency - -> The sibling operation `.amap` applies an async function: - -
👀 show example
- -```python -import httpx -import asyncio - -http_async_client = httpx.AsyncClient() - -pokemon_names: Stream[str] = ( - Stream(range(1, 4)) - .map(lambda i: f"https://pokeapi.co/api/v2/pokemon-species/{i}") - .amap(http_async_client.get, concurrency=3) - .map(httpx.Response.json) - .map(lambda poke: poke["name"]) -) +### `async`-based concurrency: [see `.amap`](#amap) -assert list(pokemon_names) == ['bulbasaur', 'ivysaur', 'venusaur'] -asyncio.get_event_loop().run_until_complete(http_async_client.aclose()) -``` -
+> [The `.amap` operation can apply an `async` function concurrently.](#amap) -## "starmap" +### "starmap" > The `star` function decorator transforms a function that takes several positional arguments into a function that takes a tuple: @@ -194,9 +297,7 @@ assert list(zeros) == [0, 0, 0, 0, 0, 0, 0, 0, 0, 0] -# `.foreach` - - +## `.foreach` > Applies a side effect on elements: @@ -211,18 +312,19 @@ assert state == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] ``` -## concurrency +### concurrency > Similar to `.map`: > - set the `concurrency` parameter for **thread-based concurrency** > - set `via="process"` for **process-based concurrency** -> - use the sibling `.aforeach` operation for **`asyncio`-based concurrency** > - set `ordered=False` for ***First Done First Out*** +> - [The `.aforeach` operation can apply an `async` effect concurrently.](#aforeach) -# `.group` +## `.group` -> Groups elements into `List`s: +> Groups into `List`s +> ... up to a given group `size`:
👀 show example
```python @@ -231,6 +333,9 @@ integers_by_5: Stream[List[int]] = integers.group(size=5) assert list(integers_by_5) == [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]] ```
+ +> ... and/or co-groups `by` a given key: +
👀 show example
```python @@ -239,6 +344,9 @@ integers_by_parity: Stream[List[int]] = integers.group(by=lambda n: n % 2) assert list(integers_by_parity) == [[0, 2, 4, 6, 8], [1, 3, 5, 7, 9]] ```
+ +> ... and/or co-groups the elements yielded by the upstream within a given time `interval`: +
👀 show example
```python @@ -254,7 +362,9 @@ assert list(integers_within_1_sec) == [[0, 1, 2], [3, 4], [5, 6], [7, 8], [9]] ```
-> Mix the `size`/`by`/`interval` parameters: +> [!TIP] +> Combine the `size`/`by`/`interval` parameters: +
👀 show example
```python @@ -267,7 +377,6 @@ assert list(integers_by_parity_by_2) == [[0, 2], [1, 3], [4, 6], [5, 7], [8], [9 ```
- ## `.groupby` > Like `.group`, but groups into `(key, elements)` tuples: @@ -300,7 +409,7 @@ assert list(counts_by_parity) == [("even", 5), ("odd", 5)] ``` -# `.flatten` +## `.flatten` > Ungroups elements assuming that they are `Iterable`s: @@ -328,7 +437,7 @@ assert list(mixed_ones_and_zeros) == [0, 1, 0, 1, 0, 1, 0, 1] ``` -# `.filter` +## `.filter` > Keeps only the elements that satisfy a condition: @@ -341,7 +450,7 @@ assert list(even_integers) == [0, 2, 4, 6, 8] ``` -# `.distinct` +## `.distinct` > Removes duplicates: @@ -383,7 +492,7 @@ assert list(consecutively_distinct_chars) == ["f", "o", "b", "a", "r", "f", "o"] ``` -# `.truncate` +## `.truncate` > Ends iteration once a given number of elements have been yielded: @@ -409,7 +518,7 @@ assert list(five_first_integers) == [0, 1, 2, 3, 4] > If both `count` and `when` are set, truncation occurs as soon as either condition is met. -# `.skip` +## `.skip` > Skips the first specified number of elements: @@ -435,7 +544,7 @@ assert list(integers_after_five) == [5, 6, 7, 8, 9] > If both `count` and `until` are set, skipping stops as soon as either condition is met. -# `.catch` +## `.catch` > Catches a given type of exception, and optionally yields a `replacement` value: @@ -495,8 +604,7 @@ assert len(errors) == len("foo") ``` - -# `.throttle` +## `.throttle` > Limits the number of yields `per` time interval: @@ -513,7 +621,7 @@ assert list(three_integers_per_second) == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] -# `.observe` +## `.observe` > Logs the progress of iterations:
👀 show example
@@ -545,7 +653,7 @@ logging.getLogger("streamable").setLevel(logging.WARNING) ```
-# `+` +## `+` > Concatenates streams: @@ -557,7 +665,7 @@ assert list(integers + integers) == [0, 1, 2, 3 ,4, 5, 6, 7, 8, 9, 0, 1, 2, 3 ,4 -# `zip` +## `zip` > [!TIP] > Use the standard `zip` function: @@ -583,8 +691,6 @@ assert list(cubes) == [0, 1, 8, 27, 64, 125, 216, 343, 512, 729] ## `.count` - - > Iterates over the stream until exhaustion and returns the number of elements yielded:
👀 show example
@@ -594,11 +700,8 @@ assert integers.count() == 10 ```
- ## `()` - - > *Calling* the stream iterates over it until exhaustion and returns it:
👀 show example
@@ -610,8 +713,7 @@ assert state == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] ```
- -# `.pipe` +## `.pipe` > Calls a function, passing the stream as first argument, followed by `*args/**kwargs` if any: @@ -631,6 +733,138 @@ import pandas as pd > Inspired by the `.pipe` from [pandas](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.pipe.html) or [polars](https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.pipe.html). +--- +--- + +# πŸ“’ ***`async` Operations*** + +Operations that accept a function as an argument have an `async` counterpart, which has the same signature but accepts `async` functions instead. These `async` operations are named the same as the original ones but with an `a` prefix. + +> [!TIP] +> One can mix regular and `async` operations on the same `Stream`, and then consume it as a regular `Iterable` or as an `AsyncIterable`. + +## `.amap` + +> Applies an `async` transformation on elements: + + +### Consume as `Iterable[T]` + +
👀 show example
+ +```python +import asyncio +import httpx + +http_async_client = httpx.AsyncClient() + +pokemon_names: Stream[str] = ( + Stream(range(1, 4)) + .map(lambda i: f"https://pokeapi.co/api/v2/pokemon-species/{i}") + .amap(http_async_client.get, concurrency=3) + .map(httpx.Response.json) + .map(lambda poke: poke["name"]) +) + +assert list(pokemon_names) == ['bulbasaur', 'ivysaur', 'venusaur'] +asyncio.run(http_async_client.aclose()) +``` +
+ +### Consume as `AsyncIterable[T]` + +
👀 show example
+ +```python +import asyncio +import httpx + +async def main() -> None: + async with httpx.AsyncClient() as http_async_client: + pokemon_names: Stream[str] = ( + Stream(range(1, 4)) + .map(lambda i: f"https://pokeapi.co/api/v2/pokemon-species/{i}") + .amap(http_async_client.get, concurrency=3) + .map(httpx.Response.json) + .map(lambda poke: poke["name"]) + ) + assert [name async for name in pokemon_names] == ['bulbasaur', 'ivysaur', 'venusaur'] + +asyncio.run(main()) +``` +
+ + +## `.aforeach` + +> Applies an `async` side effect on elements. Supports `concurrency` like `.amap`. + +## `.agroup` + +> Groups into `List`s according to an `async` grouping function. + +## `.agroupby` + +> Groups into `(key, elements)` tuples, according to an `async` grouping function. + +## `.aflatten` + +> Ungroups elements assuming that they are `AsyncIterable`s. + +> As with `.flatten`, you can set the `concurrency` parameter. + +## `.afilter` + +> Keeps only the elements that satisfy an `async` condition. + +## `.adistinct` + +> Removes duplicates according to an `async` deduplication `key`. + +## `.atruncate` + +> Ends iteration once a given number of elements have been yielded or `when` an `async` condition is satisfied. + +## `.askip` + +> Skips the specified number of elements or `until` an `async` predicate is satisfied. + +## `.acatch` + +> Catches a given type of exception `when` an `async` condition is satisfied. + +## Shorthands for consuming the stream as an `AsyncIterable[T]` + +## `.acount` + +> Iterates over the stream until exhaustion and returns the number of elements yielded: + +
👀 show example
+ +```python +assert asyncio.run(integers.acount()) == 10 +``` +
+ + +## `await` + +> *Awaiting* the stream iterates over it until exhaustion and returns it: + +
👀 show example
+ +```python +async def test_await() -> None: + state: List[int] = [] + appending_integers: Stream[int] = integers.foreach(state.append) + assert appending_integers is await appending_integers + assert state == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] +asyncio.run(test_await()) +``` +
+ +--- +--- # 💡 Notes @@ -662,58 +896,6 @@ assert collected == [0, 1, 2, 3, 5, 6, 7, 8, 9] -## Extract-Transform-Load -> [!TIP] -> **Custom ETL scripts** can benefit from the expressiveness of this library. Below is a pipeline that extracts the 67 quadruped Pokémon from the first three generations using [PokéAPI](https://pokeapi.co/) and loads them into a CSV: - -
👀 show example
- -```python -import csv -from datetime import timedelta -import itertools -import requests -from streamable import Stream - -with open("./quadruped_pokemons.csv", mode="w") as file: - fields = ["id", "name", "is_legendary", "base_happiness", "capture_rate"] - writer = csv.DictWriter(file, fields, extrasaction='ignore') - writer.writeheader() - - pipeline: Stream = ( - # Infinite Stream[int] of Pokemon ids starting from PokΓ©mon #1: Bulbasaur - Stream(itertools.count(1)) - # Limits to 16 requests per second to be friendly to our fellow PokΓ©API devs - .throttle(16, per=timedelta(seconds=1)) - # GETs pokemons concurrently using a pool of 8 threads - .map(lambda poke_id: f"https://pokeapi.co/api/v2/pokemon-species/{poke_id}") - .map(requests.get, concurrency=8) - .foreach(requests.Response.raise_for_status) - .map(requests.Response.json) - # Stops the iteration when reaching the 1st pokemon of the 4th generation - .truncate(when=lambda poke: poke["generation"]["name"] == "generation-iv") - .observe("pokemons") - # Keeps only quadruped Pokemons - .filter(lambda poke: poke["shape"]["name"] == "quadruped") - .observe("quadruped pokemons") - # Catches errors due to None "generation" or "shape" - .catch( - TypeError, - when=lambda error: str(error) == "'NoneType' object is not subscriptable" - ) - # Writes a batch of pokemons every 5 seconds to the CSV file - .group(interval=timedelta(seconds=5)) - .foreach(writer.writerows) - .flatten() - .observe("written pokemons") - # Catches exceptions and raises the 1st one at the end of the iteration - .catch(Exception, finally_raise=True) - ) - - pipeline() -``` -
- ## Visitor Pattern > [!TIP] > A `Stream` can be visited via its `.accept` method: implement a custom [***visitor***](https://en.wikipedia.org/wiki/Visitor_pattern) by extending the abstract class `streamable.visitors.Visitor`: diff --git a/setup.py b/setup.py index 4e6561a5..1ae77f09 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,5 @@ from setuptools import find_packages, setup # type: ignore + from version import __version__ setup( @@ -10,7 +11,7 @@ license="Apache 2.", author="ebonnal", author_email="bonnal.enzo.dev@gmail.com", - description="Pythonic Stream-like manipulation of iterables", + description="Pythonic Stream-like manipulation of (async) iterables", long_description=open("README.md").read(), long_description_content_type="text/markdown", ) diff --git a/streamable/__init__.py b/streamable/__init__.py index a51ad250..17cf16ae 100644 --- a/streamable/__init__.py +++ b/streamable/__init__.py @@ -1,2 +1,4 @@ from streamable.stream import Stream from streamable.util.functiontools import star + +__all__ = ["Stream", "star"] diff --git a/streamable/afunctions.py b/streamable/afunctions.py new file mode 100644 index 00000000..5532de3c --- /dev/null +++ b/streamable/afunctions.py @@ -0,0 +1,353 @@ +import builtins +import datetime +from contextlib import suppress +from operator import itemgetter +from typing import ( + Any, + AsyncIterable, + AsyncIterator, + Callable, + Coroutine, + Iterable, + List, + Optional, + Tuple, + Type, + TypeVar, + Union, +) + +from streamable.aiterators import ( + ACatchAsyncIterator, + ADistinctAsyncIterator, + AFilterAsyncIterator, + AFlattenAsyncIterator, + AGroupbyAsyncIterator, + AMapAsyncIterator, + ConcurrentAFlattenAsyncIterator, + ConcurrentAMapAsyncIterator, + ConcurrentFlattenAsyncIterator, + ConcurrentMapAsyncIterator, + ConsecutiveADistinctAsyncIterator, + CountAndPredicateASkipAsyncIterator, + CountSkipAsyncIterator, + CountTruncateAsyncIterator, + FlattenAsyncIterator, + GroupAsyncIterator, + ObserveAsyncIterator, 
+ PredicateASkipAsyncIterator, + PredicateATruncateAsyncIterator, + YieldsPerPeriodThrottleAsyncIterator, +) +from streamable.util.constants import NO_REPLACEMENT +from streamable.util.functiontools import asyncify +from streamable.util.validationtools import ( + validate_aiterator, + validate_concurrency, + validate_errors, + validate_group_size, + # validate_not_none, + validate_optional_count, + validate_optional_positive_count, + validate_optional_positive_interval, + validate_via, +) + +with suppress(ImportError): + from typing import Literal + +T = TypeVar("T") +U = TypeVar("U") + + +def catch( + aiterator: AsyncIterator[T], + errors: Union[ + Optional[Type[Exception]], Iterable[Optional[Type[Exception]]] + ] = Exception, + *, + when: Optional[Callable[[Exception], Any]] = None, + replacement: T = NO_REPLACEMENT, # type: ignore + finally_raise: bool = False, +) -> AsyncIterator[T]: + return acatch( + aiterator, + errors, + when=asyncify(when) if when else None, + replacement=replacement, + finally_raise=finally_raise, + ) + + +def acatch( + aiterator: AsyncIterator[T], + errors: Union[ + Optional[Type[Exception]], Iterable[Optional[Type[Exception]]] + ] = Exception, + *, + when: Optional[Callable[[Exception], Coroutine[Any, Any, Any]]] = None, + replacement: T = NO_REPLACEMENT, # type: ignore + finally_raise: bool = False, +) -> AsyncIterator[T]: + validate_aiterator(aiterator) + validate_errors(errors) + # validate_not_none(finally_raise, "finally_raise") + if errors is None: + return aiterator + return ACatchAsyncIterator( + aiterator, + tuple(builtins.filter(None, errors)) + if isinstance(errors, Iterable) + else errors, + when=when, + replacement=replacement, + finally_raise=finally_raise, + ) + + +def distinct( + aiterator: AsyncIterator[T], + key: Optional[Callable[[T], Any]] = None, + *, + consecutive_only: bool = False, +) -> AsyncIterator[T]: + return adistinct( + aiterator, + asyncify(key) if key else None, + consecutive_only=consecutive_only, + ) + 
+ +def adistinct( + aiterator: AsyncIterator[T], + key: Optional[Callable[[T], Coroutine[Any, Any, Any]]] = None, + *, + consecutive_only: bool = False, +) -> AsyncIterator[T]: + validate_aiterator(aiterator) + # validate_not_none(consecutive_only, "consecutive_only") + if consecutive_only: + return ConsecutiveADistinctAsyncIterator(aiterator, key) + return ADistinctAsyncIterator(aiterator, key) + + +def filter( + aiterator: AsyncIterator[T], + when: Callable[[T], Any], +) -> AsyncIterator[T]: + return afilter(aiterator, asyncify(when)) + + +def afilter( + aiterator: AsyncIterator[T], + when: Callable[[T], Any], +) -> AsyncIterator[T]: + validate_aiterator(aiterator) + return AFilterAsyncIterator(aiterator, when) + + +def flatten( + aiterator: AsyncIterator[Iterable[T]], *, concurrency: int = 1 +) -> AsyncIterator[T]: + validate_aiterator(aiterator) + validate_concurrency(concurrency) + if concurrency == 1: + return FlattenAsyncIterator(aiterator) + else: + return ConcurrentFlattenAsyncIterator( + aiterator, + concurrency=concurrency, + buffersize=concurrency, + ) + + +def aflatten( + aiterator: AsyncIterator[AsyncIterable[T]], *, concurrency: int = 1 +) -> AsyncIterator[T]: + validate_aiterator(aiterator) + validate_concurrency(concurrency) + if concurrency == 1: + return AFlattenAsyncIterator(aiterator) + else: + return ConcurrentAFlattenAsyncIterator( + aiterator, + concurrency=concurrency, + buffersize=concurrency, + ) + + +def group( + aiterator: AsyncIterator[T], + size: Optional[int] = None, + *, + interval: Optional[datetime.timedelta] = None, + by: Optional[Callable[[T], Any]] = None, +) -> AsyncIterator[List[T]]: + return agroup( + aiterator, + size, + interval=interval, + by=asyncify(by) if by else None, + ) + + +def agroup( + aiterator: AsyncIterator[T], + size: Optional[int] = None, + *, + interval: Optional[datetime.timedelta] = None, + by: Optional[Callable[[T], Coroutine[Any, Any, Any]]] = None, +) -> AsyncIterator[List[T]]: + 
validate_aiterator(aiterator) + validate_group_size(size) + validate_optional_positive_interval(interval) + if by is None: + return GroupAsyncIterator(aiterator, size, interval) + return map(itemgetter(1), AGroupbyAsyncIterator(aiterator, by, size, interval)) + + +def groupby( + aiterator: AsyncIterator[T], + key: Callable[[T], U], + *, + size: Optional[int] = None, + interval: Optional[datetime.timedelta] = None, +) -> AsyncIterator[Tuple[U, List[T]]]: + return agroupby(aiterator, asyncify(key), size=size, interval=interval) + + +def agroupby( + aiterator: AsyncIterator[T], + key: Callable[[T], Coroutine[Any, Any, U]], + *, + size: Optional[int] = None, + interval: Optional[datetime.timedelta] = None, +) -> AsyncIterator[Tuple[U, List[T]]]: + validate_aiterator(aiterator) + validate_group_size(size) + validate_optional_positive_interval(interval) + return AGroupbyAsyncIterator(aiterator, key, size, interval) + + +def map( + transformation: Callable[[T], U], + aiterator: AsyncIterator[T], + *, + concurrency: int = 1, + ordered: bool = True, + via: "Literal['thread', 'process']" = "thread", +) -> AsyncIterator[U]: + validate_aiterator(aiterator) + # validate_not_none(transformation, "transformation") + # validate_not_none(ordered, "ordered") + validate_concurrency(concurrency) + validate_via(via) + if concurrency == 1: + return amap(asyncify(transformation), aiterator) + else: + return ConcurrentMapAsyncIterator( + aiterator, + transformation, + concurrency=concurrency, + buffersize=concurrency, + ordered=ordered, + via=via, + ) + + +def amap( + transformation: Callable[[T], Coroutine[Any, Any, U]], + aiterator: AsyncIterator[T], + *, + concurrency: int = 1, + ordered: bool = True, +) -> AsyncIterator[U]: + validate_aiterator(aiterator) + # validate_not_none(transformation, "transformation") + # validate_not_none(ordered, "ordered") + validate_concurrency(concurrency) + if concurrency == 1: + return AMapAsyncIterator(aiterator, transformation) + return 
ConcurrentAMapAsyncIterator( + aiterator, + transformation, + buffersize=concurrency, + ordered=ordered, + ) + + +def observe(aiterator: AsyncIterator[T], what: str) -> AsyncIterator[T]: + validate_aiterator(aiterator) + # validate_not_none(what, "what") + return ObserveAsyncIterator(aiterator, what) + + +def skip( + aiterator: AsyncIterator[T], + count: Optional[int] = None, + *, + until: Optional[Callable[[T], Any]] = None, +) -> AsyncIterator[T]: + return askip( + aiterator, + count, + until=asyncify(until) if until else None, + ) + + +def askip( + aiterator: AsyncIterator[T], + count: Optional[int] = None, + *, + until: Optional[Callable[[T], Coroutine[Any, Any, Any]]] = None, +) -> AsyncIterator[T]: + validate_aiterator(aiterator) + validate_optional_count(count) + if until is not None: + if count is not None: + return CountAndPredicateASkipAsyncIterator(aiterator, count, until) + return PredicateASkipAsyncIterator(aiterator, until) + if count is not None: + return CountSkipAsyncIterator(aiterator, count) + return aiterator + + +def throttle( + aiterator: AsyncIterator[T], + count: Optional[int], + *, + per: Optional[datetime.timedelta] = None, +) -> AsyncIterator[T]: + validate_optional_positive_count(count) + validate_optional_positive_interval(per, name="per") + if count and per: + aiterator = YieldsPerPeriodThrottleAsyncIterator(aiterator, count, per) + return aiterator + + +def truncate( + aiterator: AsyncIterator[T], + count: Optional[int] = None, + *, + when: Optional[Callable[[T], Any]] = None, +) -> AsyncIterator[T]: + return atruncate( + aiterator, + count, + when=asyncify(when) if when else None, + ) + + +def atruncate( + aiterator: AsyncIterator[T], + count: Optional[int] = None, + *, + when: Optional[Callable[[T], Coroutine[Any, Any, Any]]] = None, +) -> AsyncIterator[T]: + validate_aiterator(aiterator) + validate_optional_count(count) + if count is not None: + aiterator = CountTruncateAsyncIterator(aiterator, count) + if when is not None: + 
aiterator = PredicateATruncateAsyncIterator(aiterator, when) + return aiterator diff --git a/streamable/aiterators.py b/streamable/aiterators.py new file mode 100644 index 00000000..493c7a8f --- /dev/null +++ b/streamable/aiterators.py @@ -0,0 +1,924 @@ +import asyncio +import datetime +import multiprocessing +import queue +import time +from abc import ABC, abstractmethod +from collections import defaultdict, deque +from concurrent.futures import Executor, Future, ProcessPoolExecutor, ThreadPoolExecutor +from contextlib import contextmanager, suppress +from math import ceil +from typing import ( + Any, + AsyncIterable, + AsyncIterator, + Awaitable, + Callable, + ContextManager, + Coroutine, + DefaultDict, + Deque, + Generic, + Iterable, + Iterator, + List, + NamedTuple, + Optional, + Set, + Tuple, + Type, + TypeVar, + Union, + cast, +) + +from streamable.util.asynctools import ( + GetEventLoopMixin, + awaitable_to_coroutine, + empty_aiter, +) +from streamable.util.functiontools import ( + aiter_wo_stopasynciteration, + awrap_error, + iter_wo_stopasynciteration, + wrap_error, +) +from streamable.util.loggertools import get_logger +from streamable.util.validationtools import ( + validate_aiterator, + validate_base, + validate_buffersize, + validate_concurrency, + validate_count, + validate_errors, + validate_group_size, + validate_optional_positive_interval, +) + +T = TypeVar("T") +U = TypeVar("U") + +from streamable.util.constants import NO_REPLACEMENT +from streamable.util.futuretools import ( + FDFOAsyncFutureResultCollection, + FDFOOSFutureResultCollection, + FIFOAsyncFutureResultCollection, + FIFOOSFutureResultCollection, + FutureResultCollection, +) + +with suppress(ImportError): + from typing import Literal + + +class ACatchAsyncIterator(AsyncIterator[T]): + def __init__( + self, + iterator: AsyncIterator[T], + errors: Union[Type[Exception], Tuple[Type[Exception], ...]], + when: Optional[Callable[[Exception], Coroutine[Any, Any, Any]]], + replacement: T, + 
finally_raise: bool, + ) -> None: + validate_aiterator(iterator) + validate_errors(errors) + self.iterator = iterator + self.errors = errors + self.when = awrap_error(when, StopAsyncIteration) if when else None + self.replacement = replacement + self.finally_raise = finally_raise + self._to_be_finally_raised: Optional[Exception] = None + + async def __anext__(self) -> T: + while True: + try: + return await self.iterator.__anext__() + except StopAsyncIteration: + if self._to_be_finally_raised: + try: + raise self._to_be_finally_raised + finally: + self._to_be_finally_raised = None + raise + except self.errors as e: + if not self.when or await self.when(e): + if self.finally_raise and not self._to_be_finally_raised: + self._to_be_finally_raised = e + if self.replacement is not NO_REPLACEMENT: + return self.replacement + continue + raise + + +class ADistinctAsyncIterator(AsyncIterator[T]): + def __init__( + self, + iterator: AsyncIterator[T], + key: Optional[Callable[[T], Coroutine[Any, Any, Any]]], + ) -> None: + validate_aiterator(iterator) + self.iterator = iterator + self.key = awrap_error(key, StopAsyncIteration) if key else None + self._already_seen: Set[Any] = set() + + async def __anext__(self) -> T: + while True: + elem = await self.iterator.__anext__() + key = await self.key(elem) if self.key else elem + if key not in self._already_seen: + break + self._already_seen.add(key) + return elem + + +class ConsecutiveADistinctAsyncIterator(AsyncIterator[T]): + def __init__( + self, + iterator: AsyncIterator[T], + key: Optional[Callable[[T], Coroutine[Any, Any, Any]]], + ) -> None: + validate_aiterator(iterator) + self.iterator = iterator + self.key = awrap_error(key, StopAsyncIteration) if key else None + self._last_key: Any = object() + + async def __anext__(self) -> T: + while True: + elem = await self.iterator.__anext__() + key = await self.key(elem) if self.key else elem + if key != self._last_key: + break + self._last_key = key + return elem + + +class 
FlattenAsyncIterator(AsyncIterator[U]): + def __init__(self, iterator: AsyncIterator[Iterable[U]]) -> None: + validate_aiterator(iterator) + self.iterator = iterator + self._current_iterator_elem: Iterator[U] = iter(tuple()) + + async def __anext__(self) -> U: + while True: + try: + return next(self._current_iterator_elem) + except StopIteration: + self._current_iterator_elem = iter_wo_stopasynciteration( + await self.iterator.__anext__() + ) + + +class AFlattenAsyncIterator(AsyncIterator[U]): + def __init__(self, iterator: AsyncIterator[AsyncIterable[U]]) -> None: + validate_aiterator(iterator) + self.iterator = iterator + self._current_iterator_elem: AsyncIterator[U] = empty_aiter() + + async def __anext__(self) -> U: + while True: + try: + return await self._current_iterator_elem.__anext__() + except StopAsyncIteration: + self._current_iterator_elem = aiter_wo_stopasynciteration( + await self.iterator.__anext__() + ) + + +class _GroupAsyncIteratorMixin(Generic[T]): + def __init__( + self, + iterator: AsyncIterator[T], + size: Optional[int], + interval: Optional[datetime.timedelta], + ) -> None: + validate_aiterator(iterator) + validate_group_size(size) + validate_optional_positive_interval(interval) + self.iterator = iterator + self.size = size or cast(int, float("inf")) + self.interval = interval + self._interval_seconds = interval.total_seconds() if interval else float("inf") + self._to_be_raised: Optional[Exception] = None + self._last_group_yielded_at: float = 0 + + def _interval_seconds_have_elapsed(self) -> bool: + if not self.interval: + return False + return ( + time.perf_counter() - self._last_group_yielded_at + ) >= self._interval_seconds + + def _remember_group_time(self) -> None: + if self.interval: + self._last_group_yielded_at = time.perf_counter() + + def _init_last_group_time(self) -> None: + if self.interval and not self._last_group_yielded_at: + self._last_group_yielded_at = time.perf_counter() + + +class 
GroupAsyncIterator(_GroupAsyncIteratorMixin[T], AsyncIterator[List[T]]): + def __init__( + self, + iterator: AsyncIterator[T], + size: Optional[int], + interval: Optional[datetime.timedelta], + ) -> None: + super().__init__(iterator, size, interval) + self._current_group: List[T] = [] + + async def __anext__(self) -> List[T]: + self._init_last_group_time() + if self._to_be_raised: + try: + raise self._to_be_raised + finally: + self._to_be_raised = None + try: + while len(self._current_group) < self.size and ( + not self._interval_seconds_have_elapsed() or not self._current_group + ): + self._current_group.append(await self.iterator.__anext__()) + except Exception as e: + if not self._current_group: + raise + self._to_be_raised = e + + group, self._current_group = self._current_group, [] + self._remember_group_time() + return group + + +class AGroupbyAsyncIterator( + _GroupAsyncIteratorMixin[T], AsyncIterator[Tuple[U, List[T]]] +): + def __init__( + self, + iterator: AsyncIterator[T], + key: Callable[[T], Coroutine[Any, Any, U]], + size: Optional[int], + interval: Optional[datetime.timedelta], + ) -> None: + super().__init__(iterator, size, interval) + self.key = awrap_error(key, StopAsyncIteration) + self._is_exhausted = False + self._groups_by: DefaultDict[U, List[T]] = defaultdict(list) + + async def _group_next_elem(self) -> None: + elem = await self.iterator.__anext__() + self._groups_by[await self.key(elem)].append(elem) + + def _pop_full_group(self) -> Optional[Tuple[U, List[T]]]: + for key, group in self._groups_by.items(): + if len(group) >= self.size: + return key, self._groups_by.pop(key) + return None + + def _pop_first_group(self) -> Tuple[U, List[T]]: + first_key: U = next(iter(self._groups_by), cast(U, ...)) + return first_key, self._groups_by.pop(first_key) + + def _pop_largest_group(self) -> Tuple[U, List[T]]: + largest_group_key: Any = next(iter(self._groups_by), ...) 
+ + for key, group in self._groups_by.items(): + if len(group) > len(self._groups_by[largest_group_key]): + largest_group_key = key + + return largest_group_key, self._groups_by.pop(largest_group_key) + + async def __anext__(self) -> Tuple[U, List[T]]: + self._init_last_group_time() + if self._is_exhausted: + if self._groups_by: + return self._pop_first_group() + raise StopAsyncIteration + + if self._to_be_raised: + if self._groups_by: + self._remember_group_time() + return self._pop_first_group() + try: + raise self._to_be_raised + finally: + self._to_be_raised = None + + try: + await self._group_next_elem() + + full_group: Optional[Tuple[U, List[T]]] = self._pop_full_group() + while not full_group and not self._interval_seconds_have_elapsed(): + await self._group_next_elem() + full_group = self._pop_full_group() + + self._remember_group_time() + return full_group or self._pop_largest_group() + + except StopAsyncIteration: + self._is_exhausted = True + return await self.__anext__() + + except Exception as e: + self._to_be_raised = e + return await self.__anext__() + + +class CountSkipAsyncIterator(AsyncIterator[T]): + def __init__(self, iterator: AsyncIterator[T], count: int) -> None: + validate_aiterator(iterator) + validate_count(count) + self.iterator = iterator + self.count = count + self._n_skipped = 0 + self._done_skipping = False + + async def __anext__(self) -> T: + if not self._done_skipping: + while self._n_skipped < self.count: + await self.iterator.__anext__() + # do not count exceptions as skipped elements + self._n_skipped += 1 + self._done_skipping = True + return await self.iterator.__anext__() + + +class PredicateASkipAsyncIterator(AsyncIterator[T]): + def __init__( + self, iterator: AsyncIterator[T], until: Callable[[T], Coroutine[Any, Any, Any]] + ) -> None: + validate_aiterator(iterator) + self.iterator = iterator + self.until = awrap_error(until, StopAsyncIteration) + self._done_skipping = False + + async def __anext__(self) -> T: + elem = 
await self.iterator.__anext__() + if not self._done_skipping: + while not await self.until(elem): + elem = await self.iterator.__anext__() + self._done_skipping = True + return elem + + +class CountAndPredicateASkipAsyncIterator(AsyncIterator[T]): + def __init__( + self, + iterator: AsyncIterator[T], + count: int, + until: Callable[[T], Coroutine[Any, Any, Any]], + ) -> None: + validate_aiterator(iterator) + validate_count(count) + self.iterator = iterator + self.count = count + self.until = awrap_error(until, StopAsyncIteration) + self._n_skipped = 0 + self._done_skipping = False + + async def __anext__(self) -> T: + elem = await self.iterator.__anext__() + if not self._done_skipping: + while self._n_skipped < self.count and not await self.until(elem): + elem = await self.iterator.__anext__() + # do not count exceptions as skipped elements + self._n_skipped += 1 + self._done_skipping = True + return elem + + +class CountTruncateAsyncIterator(AsyncIterator[T]): + def __init__(self, iterator: AsyncIterator[T], count: int) -> None: + validate_aiterator(iterator) + validate_count(count) + self.iterator = iterator + self.count = count + self._current_count = 0 + + async def __anext__(self) -> T: + if self._current_count == self.count: + raise StopAsyncIteration() + elem = await self.iterator.__anext__() + self._current_count += 1 + return elem + + +class PredicateATruncateAsyncIterator(AsyncIterator[T]): + def __init__( + self, iterator: AsyncIterator[T], when: Callable[[T], Coroutine[Any, Any, Any]] + ) -> None: + validate_aiterator(iterator) + self.iterator = iterator + self.when = awrap_error(when, StopAsyncIteration) + self._satisfied = False + + async def __anext__(self) -> T: + if self._satisfied: + raise StopAsyncIteration() + elem = await self.iterator.__anext__() + if await self.when(elem): + self._satisfied = True + raise StopAsyncIteration() + return elem + + +class AMapAsyncIterator(AsyncIterator[U]): + def __init__( + self, + iterator: AsyncIterator[T], + 
transformation: Callable[[T], Coroutine[Any, Any, U]], + ) -> None: + validate_aiterator(iterator) + + self.iterator = iterator + self.transformation = awrap_error(transformation, StopAsyncIteration) + + async def __anext__(self) -> U: + return await self.transformation(await self.iterator.__anext__()) + + +class AFilterAsyncIterator(AsyncIterator[T]): + def __init__( + self, + iterator: AsyncIterator[T], + when: Callable[[T], Coroutine[Any, Any, Any]], + ) -> None: + validate_aiterator(iterator) + + self.iterator = iterator + self.when = awrap_error(when, StopAsyncIteration) + + async def __anext__(self) -> T: + while True: + elem = await self.iterator.__anext__() + if await self.when(elem): + return elem + + +class ObserveAsyncIterator(AsyncIterator[T]): + def __init__(self, iterator: AsyncIterator[T], what: str, base: int = 2) -> None: + validate_aiterator(iterator) + validate_base(base) + + self.iterator = iterator + self.what = what + self.base = base + + self._n_yields = 0 + self._n_errors = 0 + self._n_nexts = 0 + self._logged_n_nexts = 0 + self._next_threshold = 0 + + self._start_time = time.perf_counter() + + def _log(self) -> None: + get_logger().info( + "[%s %s] %s", + f"duration={datetime.datetime.fromtimestamp(time.perf_counter()) - datetime.datetime.fromtimestamp(self._start_time)}", + f"errors={self._n_errors}", + f"{self._n_yields} {self.what} yielded", + ) + self._logged_n_nexts = self._n_nexts + self._next_threshold = self.base * self._logged_n_nexts + + async def __anext__(self) -> T: + try: + elem = await self.iterator.__anext__() + self._n_nexts += 1 + self._n_yields += 1 + return elem + except StopAsyncIteration: + if self._n_nexts != self._logged_n_nexts: + self._log() + raise + except Exception: + self._n_nexts += 1 + self._n_errors += 1 + raise + finally: + if self._n_nexts >= self._next_threshold: + self._log() + + +class YieldsPerPeriodThrottleAsyncIterator(AsyncIterator[T]): + def __init__( + self, + iterator: AsyncIterator[T], + 
max_yields: int, + period: datetime.timedelta, + ) -> None: + validate_aiterator(iterator) + self.iterator = iterator + self.max_yields = max_yields + self._period_seconds = period.total_seconds() + + self._period_index: int = -1 + self._yields_in_period = 0 + + self._offset: Optional[float] = None + + async def safe_next(self) -> Tuple[Optional[T], Optional[Exception]]: + try: + return await self.iterator.__anext__(), None + except StopAsyncIteration: + raise + except Exception as e: + return None, e + + async def __anext__(self) -> T: + elem, caught_error = await self.safe_next() + + now = time.perf_counter() + if not self._offset: + self._offset = now + now -= self._offset + + num_periods = now / self._period_seconds + period_index = int(num_periods) + + if self._period_index != period_index: + self._period_index = period_index + self._yields_in_period = max(0, self._yields_in_period - self.max_yields) + + if self._yields_in_period >= self.max_yields: + await asyncio.sleep( + (ceil(num_periods) - num_periods) * self._period_seconds + ) + self._yields_in_period += 1 + + if caught_error: + raise caught_error + return cast(T, elem) + + +class _RaisingAsyncIterator(AsyncIterator[T]): + class ExceptionContainer(NamedTuple): + exception: Exception + + def __init__( + self, + iterator: AsyncIterator[Union[T, ExceptionContainer]], + ) -> None: + self.iterator = iterator + + async def __anext__(self) -> T: + elem = await self.iterator.__anext__() + if isinstance(elem, self.ExceptionContainer): + raise elem.exception + return elem + + +class _ConcurrentMapAsyncIterableMixin( + Generic[T, U], + ABC, + AsyncIterable[Union[U, _RaisingAsyncIterator.ExceptionContainer]], +): + """ + Template Method Pattern: + This abstract class's `__iter__` is a skeleton for a queue-based concurrent mapping algorithm + that relies on abstract helper methods (`_context_manager`, `_create_future`, `_future_result_collection`) + that must be implemented by concrete subclasses. 
+ """ + + def __init__( + self, + iterator: AsyncIterator[T], + buffersize: int, + ordered: bool, + ) -> None: + validate_aiterator(iterator) + validate_buffersize(buffersize) + self.iterator = iterator + self.buffersize = buffersize + self.ordered = ordered + + def _context_manager(self) -> ContextManager: + @contextmanager + def dummy_context_manager_generator(): + yield + + return dummy_context_manager_generator() + + @abstractmethod + def _launch_task( + self, elem: T + ) -> "Future[Union[U, _RaisingAsyncIterator.ExceptionContainer]]": ... + + # factory method + @abstractmethod + def _future_result_collection( + self, + ) -> FutureResultCollection[Union[U, _RaisingAsyncIterator.ExceptionContainer]]: ... + + async def __aiter__( + self, + ) -> AsyncIterator[Union[U, _RaisingAsyncIterator.ExceptionContainer]]: + with self._context_manager(): + future_results = self._future_result_collection() + + # queue tasks up to buffersize + with suppress(StopAsyncIteration): + while len(future_results) < self.buffersize: + future_results.add_future( + self._launch_task(await self.iterator.__anext__()) + ) + + # wait, queue, yield + while future_results: + result = await future_results.__anext__() + with suppress(StopAsyncIteration): + future_results.add_future( + self._launch_task(await self.iterator.__anext__()) + ) + yield result + + +class _ConcurrentMapAsyncIterable(_ConcurrentMapAsyncIterableMixin[T, U]): + def __init__( + self, + iterator: AsyncIterator[T], + transformation: Callable[[T], U], + concurrency: int, + buffersize: int, + ordered: bool, + via: "Literal['thread', 'process']", + ) -> None: + super().__init__(iterator, buffersize, ordered) + validate_concurrency(concurrency) + self.transformation = wrap_error(transformation, StopAsyncIteration) + self.concurrency = concurrency + self.executor: Executor + self.via = via + + def _context_manager(self) -> ContextManager: + if self.via == "thread": + self.executor = ThreadPoolExecutor(max_workers=self.concurrency) 
+ if self.via == "process": + self.executor = ProcessPoolExecutor(max_workers=self.concurrency) + return self.executor + + # picklable + @staticmethod + def _safe_transformation( + transformation: Callable[[T], U], elem: T + ) -> Union[U, _RaisingAsyncIterator.ExceptionContainer]: + try: + return transformation(elem) + except Exception as e: + return _RaisingAsyncIterator.ExceptionContainer(e) + + def _launch_task( + self, elem: T + ) -> "Future[Union[U, _RaisingAsyncIterator.ExceptionContainer]]": + return self.executor.submit( + self._safe_transformation, self.transformation, elem + ) + + def _future_result_collection( + self, + ) -> FutureResultCollection[Union[U, _RaisingAsyncIterator.ExceptionContainer]]: + if self.ordered: + return FIFOOSFutureResultCollection() + return FDFOOSFutureResultCollection( + multiprocessing.Queue if self.via == "process" else queue.Queue + ) + + +class ConcurrentMapAsyncIterator(_RaisingAsyncIterator[U]): + def __init__( + self, + iterator: AsyncIterator[T], + transformation: Callable[[T], U], + concurrency: int, + buffersize: int, + ordered: bool, + via: "Literal['thread', 'process']", + ) -> None: + super().__init__( + _ConcurrentMapAsyncIterable( + iterator, + transformation, + concurrency, + buffersize, + ordered, + via, + ).__aiter__() + ) + + +class _ConcurrentAMapAsyncIterable( + _ConcurrentMapAsyncIterableMixin[T, U], GetEventLoopMixin +): + def __init__( + self, + iterator: AsyncIterator[T], + transformation: Callable[[T], Coroutine[Any, Any, U]], + buffersize: int, + ordered: bool, + ) -> None: + super().__init__(iterator, buffersize, ordered) + self.transformation = awrap_error(transformation, StopAsyncIteration) + + async def _safe_transformation( + self, elem: T + ) -> Union[U, _RaisingAsyncIterator.ExceptionContainer]: + try: + coroutine = self.transformation(elem) + if not isinstance(coroutine, Coroutine): + raise TypeError( + f"`transformation` must be an async function i.e. 
a function returning a Coroutine but it returned a {type(coroutine)}", + ) + return await coroutine + except Exception as e: + return _RaisingAsyncIterator.ExceptionContainer(e) + + def _launch_task( + self, elem: T + ) -> "Future[Union[U, _RaisingAsyncIterator.ExceptionContainer]]": + return cast( + "Future[Union[U, _RaisingAsyncIterator.ExceptionContainer]]", + self.get_event_loop().create_task(self._safe_transformation(elem)), + ) + + def _future_result_collection( + self, + ) -> FutureResultCollection[Union[U, _RaisingAsyncIterator.ExceptionContainer]]: + if self.ordered: + return FIFOAsyncFutureResultCollection(self.get_event_loop()) + else: + return FDFOAsyncFutureResultCollection(self.get_event_loop()) + + +class ConcurrentAMapAsyncIterator(_RaisingAsyncIterator[U]): + def __init__( + self, + iterator: AsyncIterator[T], + transformation: Callable[[T], Coroutine[Any, Any, U]], + buffersize: int, + ordered: bool, + ) -> None: + super().__init__( + _ConcurrentAMapAsyncIterable( + iterator, + transformation, + buffersize, + ordered, + ).__aiter__() + ) + + +class _ConcurrentFlattenAsyncIterable( + AsyncIterable[Union[T, _RaisingAsyncIterator.ExceptionContainer]] +): + def __init__( + self, + iterables_iterator: AsyncIterator[Iterable[T]], + concurrency: int, + buffersize: int, + ) -> None: + validate_aiterator(iterables_iterator) + validate_concurrency(concurrency) + self.iterables_iterator = iterables_iterator + self.concurrency = concurrency + self.buffersize = buffersize + + async def __aiter__( + self, + ) -> AsyncIterator[Union[T, _RaisingAsyncIterator.ExceptionContainer]]: + with ThreadPoolExecutor(max_workers=self.concurrency) as executor: + iterator_and_future_pairs: Deque[Tuple[Iterator[T], Future]] = deque() + element_to_yield: Deque[ + Union[T, _RaisingAsyncIterator.ExceptionContainer] + ] = deque(maxlen=1) + iterator_to_queue: Optional[Iterator[T]] = None + # wait, queue, yield (FIFO) + while True: + if iterator_and_future_pairs: + iterator, future = 
iterator_and_future_pairs.popleft() + try: + element_to_yield.append(future.result()) + iterator_to_queue = iterator + except StopIteration: + pass + except Exception as e: + element_to_yield.append( + _RaisingAsyncIterator.ExceptionContainer(e) + ) + iterator_to_queue = iterator + + # queue tasks up to buffersize + while len(iterator_and_future_pairs) < self.buffersize: + if not iterator_to_queue: + try: + iterable = await self.iterables_iterator.__anext__() + except StopAsyncIteration: + break + try: + iterator_to_queue = iter_wo_stopasynciteration(iterable) + except Exception as e: + yield _RaisingAsyncIterator.ExceptionContainer(e) + continue + future = executor.submit(next, iterator_to_queue) + iterator_and_future_pairs.append((iterator_to_queue, future)) + iterator_to_queue = None + if element_to_yield: + yield element_to_yield.pop() + if not iterator_and_future_pairs: + break + + +class ConcurrentFlattenAsyncIterator(_RaisingAsyncIterator[T]): + def __init__( + self, + iterables_iterator: AsyncIterator[Iterable[T]], + concurrency: int, + buffersize: int, + ) -> None: + super().__init__( + _ConcurrentFlattenAsyncIterable( + iterables_iterator, + concurrency, + buffersize, + ).__aiter__() + ) + + +class _ConcurrentAFlattenAsyncIterable( + AsyncIterable[Union[T, _RaisingAsyncIterator.ExceptionContainer]], GetEventLoopMixin +): + def __init__( + self, + iterables_iterator: AsyncIterator[AsyncIterable[T]], + concurrency: int, + buffersize: int, + ) -> None: + validate_aiterator(iterables_iterator) + validate_concurrency(concurrency) + self.iterables_iterator = iterables_iterator + self.concurrency = concurrency + self.buffersize = buffersize + + async def __aiter__( + self, + ) -> AsyncIterator[Union[T, _RaisingAsyncIterator.ExceptionContainer]]: + iterator_and_future_pairs: Deque[Tuple[AsyncIterator[T], Awaitable[T]]] = ( + deque() + ) + element_to_yield: Deque[Union[T, _RaisingAsyncIterator.ExceptionContainer]] = ( + deque(maxlen=1) + ) + iterator_to_queue: 
Optional[AsyncIterator[T]] = None + # wait, queue, yield (FIFO) + while True: + if iterator_and_future_pairs: + iterator, future = iterator_and_future_pairs.popleft() + try: + element_to_yield.append(await future) + iterator_to_queue = iterator + except StopAsyncIteration: + pass + except Exception as e: + element_to_yield.append(_RaisingAsyncIterator.ExceptionContainer(e)) + iterator_to_queue = iterator + + # queue tasks up to buffersize + while len(iterator_and_future_pairs) < self.buffersize: + if not iterator_to_queue: + try: + iterable = await self.iterables_iterator.__anext__() + except StopAsyncIteration: + break + try: + iterator_to_queue = aiter_wo_stopasynciteration(iterable) + except Exception as e: + yield _RaisingAsyncIterator.ExceptionContainer(e) + continue + future = self.get_event_loop().create_task( + awaitable_to_coroutine( + cast(AsyncIterator, iterator_to_queue).__anext__() + ) + ) + iterator_and_future_pairs.append( + (cast(AsyncIterator, iterator_to_queue), future) + ) + iterator_to_queue = None + if element_to_yield: + yield element_to_yield.pop() + if not iterator_and_future_pairs: + break + + +class ConcurrentAFlattenAsyncIterator(_RaisingAsyncIterator[T]): + def __init__( + self, + iterables_iterator: AsyncIterator[AsyncIterable[T]], + concurrency: int, + buffersize: int, + ) -> None: + super().__init__( + _ConcurrentAFlattenAsyncIterable( + iterables_iterator, + concurrency, + buffersize, + ).__aiter__() + ) diff --git a/streamable/functions.py b/streamable/functions.py index aa369fd0..601831f2 100644 --- a/streamable/functions.py +++ b/streamable/functions.py @@ -4,6 +4,7 @@ from operator import itemgetter from typing import ( Any, + AsyncIterable, Callable, Coroutine, Iterable, @@ -17,9 +18,12 @@ ) from streamable.iterators import ( - AsyncConcurrentMapIterator, + AFlattenIterator, CatchIterator, + ConcurrentAFlattenIterator, + ConcurrentAMapIterator, ConcurrentFlattenIterator, + ConcurrentMapIterator, ConsecutiveDistinctIterator, 
CountAndPredicateSkipIterator, CountSkipIterator, @@ -29,13 +33,12 @@ GroupbyIterator, GroupIterator, ObserveIterator, - OSConcurrentMapIterator, PredicateSkipIterator, PredicateTruncateIterator, YieldsPerPeriodThrottleIterator, ) from streamable.util.constants import NO_REPLACEMENT -from streamable.util.functiontools import wrap_error +from streamable.util.functiontools import syncify, wrap_error from streamable.util.validationtools import ( validate_concurrency, validate_errors, @@ -79,6 +82,25 @@ def catch( ) +def acatch( + iterator: Iterator[T], + errors: Union[ + Optional[Type[Exception]], Iterable[Optional[Type[Exception]]] + ] = Exception, + *, + when: Optional[Callable[[Exception], Coroutine[Any, Any, Any]]] = None, + replacement: T = NO_REPLACEMENT, # type: ignore + finally_raise: bool = False, +) -> Iterator[T]: + return catch( + iterator, + errors, + when=syncify(when) if when else None, + replacement=replacement, + finally_raise=finally_raise, + ) + + def distinct( iterator: Iterator[T], key: Optional[Callable[[T], Any]] = None, @@ -92,6 +114,19 @@ def distinct( return DistinctIterator(iterator, key) +def adistinct( + iterator: Iterator[T], + key: Optional[Callable[[T], Any]] = None, + *, + consecutive_only: bool = False, +) -> Iterator[T]: + return distinct( + iterator, + syncify(key) if key else None, + consecutive_only=consecutive_only, + ) + + def flatten(iterator: Iterator[Iterable[T]], *, concurrency: int = 1) -> Iterator[T]: validate_iterator(iterator) validate_concurrency(concurrency) @@ -105,6 +140,21 @@ def flatten(iterator: Iterator[Iterable[T]], *, concurrency: int = 1) -> Iterato ) +def aflatten( + iterator: Iterator[AsyncIterable[T]], *, concurrency: int = 1 +) -> Iterator[T]: + validate_iterator(iterator) + validate_concurrency(concurrency) + if concurrency == 1: + return AFlattenIterator(iterator) + else: + return ConcurrentAFlattenIterator( + iterator, + concurrency=concurrency, + buffersize=concurrency, + ) + + def group( iterator: 
Iterator[T], size: Optional[int] = None, @@ -120,6 +170,21 @@ def group( return map(itemgetter(1), GroupbyIterator(iterator, by, size, interval)) +def agroup( + iterator: Iterator[T], + size: Optional[int] = None, + *, + interval: Optional[datetime.timedelta] = None, + by: Optional[Callable[[T], Coroutine[Any, Any, Any]]] = None, +) -> Iterator[List[T]]: + return group( + iterator, + size, + interval=interval, + by=syncify(by) if by else None, + ) + + def groupby( iterator: Iterator[T], key: Callable[[T], U], @@ -133,6 +198,21 @@ def groupby( return GroupbyIterator(iterator, key, size, interval) +def agroupby( + iterator: Iterator[T], + key: Callable[[T], Coroutine[Any, Any, U]], + *, + size: Optional[int] = None, + interval: Optional[datetime.timedelta] = None, +) -> Iterator[Tuple[U, List[T]]]: + return groupby( + iterator, + syncify(key), + size=size, + interval=interval, + ) + + def map( transformation: Callable[[T], U], iterator: Iterator[T], @@ -149,7 +229,7 @@ def map( if concurrency == 1: return builtins.map(wrap_error(transformation, StopIteration), iterator) else: - return OSConcurrentMapIterator( + return ConcurrentMapIterator( iterator, transformation, concurrency=concurrency, @@ -170,7 +250,9 @@ def amap( # validate_not_none(transformation, "transformation") # validate_not_none(ordered, "ordered") validate_concurrency(concurrency) - return AsyncConcurrentMapIterator( + if concurrency == 1: + return map(syncify(transformation), iterator) + return ConcurrentAMapIterator( iterator, transformation, buffersize=concurrency, @@ -201,6 +283,19 @@ def skip( return iterator +def askip( + iterator: Iterator[T], + count: Optional[int] = None, + *, + until: Optional[Callable[[T], Coroutine[Any, Any, Any]]] = None, +) -> Iterator[T]: + return skip( + iterator, + count, + until=syncify(until) if until else None, + ) + + def throttle( iterator: Iterator[T], count: Optional[int], @@ -227,3 +322,16 @@ def truncate( if when is not None: iterator = 
PredicateTruncateIterator(iterator, when) return iterator + + +def atruncate( + iterator: Iterator[T], + count: Optional[int] = None, + *, + when: Optional[Callable[[T], Coroutine[Any, Any, Any]]] = None, +) -> Iterator[T]: + return truncate( + iterator, + count, + when=syncify(when) if when else None, + ) diff --git a/streamable/iterators.py b/streamable/iterators.py index d1713a52..7b684937 100644 --- a/streamable/iterators.py +++ b/streamable/iterators.py @@ -1,4 +1,3 @@ -import asyncio import datetime import multiprocessing import queue @@ -10,6 +9,9 @@ from math import ceil from typing import ( Any, + AsyncIterable, + AsyncIterator, + Awaitable, Callable, ContextManager, Coroutine, @@ -29,7 +31,16 @@ cast, ) -from streamable.util.functiontools import iter_wo_stopiteration, wrap_error +from streamable.util.asynctools import ( + GetEventLoopMixin, + awaitable_to_coroutine, + empty_aiter, +) +from streamable.util.functiontools import ( + aiter_wo_stopiteration, + iter_wo_stopiteration, + wrap_error, +) from streamable.util.loggertools import get_logger from streamable.util.validationtools import ( validate_base, @@ -149,6 +160,25 @@ def __next__(self) -> U: self._current_iterator_elem = iter_wo_stopiteration(next(self.iterator)) +class AFlattenIterator(Iterator[U], GetEventLoopMixin): + def __init__(self, iterator: Iterator[AsyncIterable[U]]) -> None: + validate_iterator(iterator) + self.iterator = iterator + + self._current_iterator_elem: AsyncIterator[U] = empty_aiter() + + def __next__(self) -> U: + while True: + try: + return self.get_event_loop().run_until_complete( + self._current_iterator_elem.__anext__() + ) + except StopAsyncIteration: + self._current_iterator_elem = aiter_wo_stopiteration( + next(self.iterator) + ) + + class _GroupIteratorMixin(Generic[T]): def __init__( self, @@ -489,7 +519,7 @@ def __next__(self) -> T: return elem -class _ConcurrentMapIterable( +class _ConcurrentMapIterableMixin( Generic[T, U], ABC, Iterable[Union[U, 
_RaisingIterator.ExceptionContainer]] ): """ @@ -546,7 +576,7 @@ def __iter__(self) -> Iterator[Union[U, _RaisingIterator.ExceptionContainer]]: yield result -class _OSConcurrentMapIterable(_ConcurrentMapIterable[T, U]): +class _ConcurrentMapIterable(_ConcurrentMapIterableMixin[T, U]): def __init__( self, iterator: Iterator[T], @@ -597,7 +627,7 @@ def _future_result_collection( ) -class OSConcurrentMapIterator(_RaisingIterator[U]): +class ConcurrentMapIterator(_RaisingIterator[U]): def __init__( self, iterator: Iterator[T], @@ -609,7 +639,7 @@ def __init__( ) -> None: super().__init__( iter( - _OSConcurrentMapIterable( + _ConcurrentMapIterable( iterator, transformation, concurrency, @@ -621,7 +651,7 @@ def __init__( ) -class _AsyncConcurrentMapIterable(_ConcurrentMapIterable[T, U]): +class _ConcurrentAMapIterable(_ConcurrentMapIterableMixin[T, U], GetEventLoopMixin): def __init__( self, iterator: Iterator[T], @@ -631,12 +661,6 @@ def __init__( ) -> None: super().__init__(iterator, buffersize, ordered) self.transformation = wrap_error(transformation, StopIteration) - self.event_loop: asyncio.AbstractEventLoop - try: - self.event_loop = asyncio.get_event_loop() - except RuntimeError: - self.event_loop = asyncio.new_event_loop() - asyncio.set_event_loop(self.event_loop) async def _safe_transformation( self, elem: T @@ -656,19 +680,19 @@ def _launch_task( ) -> "Future[Union[U, _RaisingIterator.ExceptionContainer]]": return cast( "Future[Union[U, _RaisingIterator.ExceptionContainer]]", - self.event_loop.create_task(self._safe_transformation(elem)), + self.get_event_loop().create_task(self._safe_transformation(elem)), ) def _future_result_collection( self, ) -> FutureResultCollection[Union[U, _RaisingIterator.ExceptionContainer]]: if self.ordered: - return FIFOAsyncFutureResultCollection(self.event_loop) + return FIFOAsyncFutureResultCollection(self.get_event_loop()) else: - return FDFOAsyncFutureResultCollection(self.event_loop) + return 
FDFOAsyncFutureResultCollection(self.get_event_loop()) -class AsyncConcurrentMapIterator(_RaisingIterator[U]): +class ConcurrentAMapIterator(_RaisingIterator[U]): def __init__( self, iterator: Iterator[T], @@ -678,7 +702,7 @@ def __init__( ) -> None: super().__init__( iter( - _AsyncConcurrentMapIterable( + _ConcurrentAMapIterable( iterator, transformation, buffersize, @@ -760,3 +784,86 @@ def __init__( ) ) ) + + +class _ConcurrentAFlattenIterable( + Iterable[Union[T, _RaisingIterator.ExceptionContainer]], GetEventLoopMixin +): + def __init__( + self, + iterables_iterator: Iterator[AsyncIterable[T]], + concurrency: int, + buffersize: int, + ) -> None: + validate_iterator(iterables_iterator) + validate_concurrency(concurrency) + self.iterables_iterator = iterables_iterator + self.concurrency = concurrency + self.buffersize = buffersize + + def __iter__(self) -> Iterator[Union[T, _RaisingIterator.ExceptionContainer]]: + iterator_and_future_pairs: Deque[Tuple[AsyncIterator[T], Awaitable[T]]] = ( + deque() + ) + element_to_yield: Deque[Union[T, _RaisingIterator.ExceptionContainer]] = deque( + maxlen=1 + ) + iterator_to_queue: Optional[AsyncIterator[T]] = None + # wait, queue, yield (FIFO) + while True: + if iterator_and_future_pairs: + iterator, future = iterator_and_future_pairs.popleft() + try: + element_to_yield.append( + self.get_event_loop().run_until_complete(future) + ) + iterator_to_queue = iterator + except StopAsyncIteration: + pass + except Exception as e: + element_to_yield.append(_RaisingIterator.ExceptionContainer(e)) + iterator_to_queue = iterator + + # queue tasks up to buffersize + while len(iterator_and_future_pairs) < self.buffersize: + if not iterator_to_queue: + try: + iterable = next(self.iterables_iterator) + except StopIteration: + break + try: + iterator_to_queue = aiter_wo_stopiteration(iterable) + except Exception as e: + yield _RaisingIterator.ExceptionContainer(e) + continue + future = self.get_event_loop().create_task( + 
awaitable_to_coroutine( + cast(AsyncIterator, iterator_to_queue).__anext__() + ) + ) + iterator_and_future_pairs.append( + (cast(AsyncIterator, iterator_to_queue), future) + ) + iterator_to_queue = None + if element_to_yield: + yield element_to_yield.pop() + if not iterator_and_future_pairs: + break + + +class ConcurrentAFlattenIterator(_RaisingIterator[T]): + def __init__( + self, + iterables_iterator: Iterator[AsyncIterable[T]], + concurrency: int, + buffersize: int, + ) -> None: + super().__init__( + iter( + _ConcurrentAFlattenIterable( + iterables_iterator, + concurrency, + buffersize, + ) + ) + ) diff --git a/streamable/stream.py b/streamable/stream.py index d25693c2..de8f2444 100644 --- a/streamable/stream.py +++ b/streamable/stream.py @@ -5,10 +5,14 @@ from typing import ( TYPE_CHECKING, Any, + AsyncIterable, + AsyncIterator, + Awaitable, Callable, Collection, Coroutine, Dict, + Generator, Generic, Iterable, Iterator, @@ -25,6 +29,7 @@ ) from streamable.util.constants import NO_REPLACEMENT +from streamable.util.functiontools import asyncify from streamable.util.loggertools import get_logger from streamable.util.validationtools import ( validate_concurrency, @@ -54,17 +59,29 @@ V = TypeVar("V") -class Stream(Iterable[T]): +class Stream(Iterable[T], AsyncIterable[T], Awaitable["Stream[T]"]): __slots__ = ("_source", "_upstream") # fmt: off @overload def __init__(self, source: Iterable[T]) -> None: ... @overload + def __init__(self, source: AsyncIterable[T]) -> None: ... + @overload def __init__(self, source: Callable[[], Iterable[T]]) -> None: ... + @overload + def __init__(self, source: Callable[[], AsyncIterable[T]]) -> None: ... 
# fmt: on - def __init__(self, source: Union[Iterable[T], Callable[[], Iterable[T]]]) -> None: + def __init__( + self, + source: Union[ + Iterable[T], + Callable[[], Iterable[T]], + AsyncIterable[T], + Callable[[], AsyncIterable[T]], + ], + ) -> None: """ A `Stream[T]` decorates an `Iterable[T]` with a **fluent interface** enabling the chaining of lazy operations. @@ -84,7 +101,11 @@ def upstream(self) -> "Optional[Stream]": return self._upstream @property - def source(self) -> Union[Iterable, Callable[[], Iterable]]: + def source( + self, + ) -> Union[ + Iterable, Callable[[], Iterable], AsyncIterable, Callable[[], AsyncIterable] + ]: """ Returns: Callable[[], Iterable]: Function called at iteration time (i.e. by `__iter__`) to get a fresh source iterable. @@ -103,6 +124,11 @@ def __iter__(self) -> Iterator[T]: return self.accept(IteratorVisitor[T]()) + def __aiter__(self) -> AsyncIterator[T]: + from streamable.visitors.aiterator import AsyncIteratorVisitor + + return self.accept(AsyncIteratorVisitor[T]()) + def __repr__(self) -> str: from streamable.visitors.representation import ReprVisitor @@ -134,6 +160,16 @@ def __call__(self) -> "Stream[T]": self.count() return self + def __await__(self) -> Generator[int, None, "Stream[T]"]: + """ + Iterates over this stream until exhaustion. + + Returns: + Stream[T]: self. + """ + yield from (self.acount().__await__()) + return self + def accept(self, visitor: "Visitor[V]") -> V: """ Entry point to visit this stream (en.wikipedia.org/wiki/Visitor_pattern). 
@@ -175,6 +211,40 @@ def catch( finally_raise=finally_raise, ) + def acatch( + self, + errors: Union[ + Optional[Type[Exception]], Iterable[Optional[Type[Exception]]] + ] = Exception, + *, + when: Optional[Callable[[Exception], Coroutine[Any, Any, Any]]] = None, + replacement: T = NO_REPLACEMENT, # type: ignore + finally_raise: bool = False, + ) -> "Stream[T]": + """ + Catches the upstream exceptions if they are instances of `errors` type and they satisfy the `when` predicate. + Optionally yields a `replacement` value. + If any exception was caught during the iteration and `finally_raise=True`, the first caught exception will be raised when the iteration finishes. + + Args: + errors (Optional[Type[Exception]], Iterable[Optional[Type[Exception]]], optional): The exception type to catch, or an iterable of exception types to catch (default: catches all `Exception`s) + when (Optional[Callable[[Exception], Coroutine[Any, Any, Any]]], optional): An additional condition that must be satisfied to catch the exception, i.e. `when(exception)` must be truthy. (default: no additional condition) + replacement (T, optional): The value to yield when an exception is caught. (default: do not yield any replacement value) + finally_raise (bool, optional): If True the first exception caught is raised when upstream's iteration ends. (default: iteration ends without raising) + + Returns: + Stream[T]: A stream of upstream elements catching the eligible exceptions. + """ + validate_errors(errors) + # validate_not_none(finally_raise, "finally_raise") + return ACatchStream( + self, + errors, + when=when, + replacement=replacement, + finally_raise=finally_raise, + ) + def count(self) -> int: """ Iterates over this stream until exhaustion and returns the count of elements. @@ -185,6 +255,18 @@ def count(self) -> int: return sum(1 for _ in self) + async def acount(self) -> int: + """ + Iterates over this stream until exhaustion and returns the count of elements. 
+ + Returns: + int: Number of elements yielded during an entire iteration over this stream. + """ + count = 0 + async for _ in self: + count += 1 + return count + def display(self, level: int = logging.INFO) -> "Stream[T]": """ Logs (INFO level) a representation of the stream. @@ -226,6 +308,33 @@ def distinct( # validate_not_none(consecutive_only, "consecutive_only") return DistinctStream(self, key, consecutive_only) + def adistinct( + self, + key: Optional[Callable[[T], Coroutine[Any, Any, Any]]] = None, + *, + consecutive_only: bool = False, + ) -> "Stream[T]": + """ + Filters the stream to yield only distinct elements. + If a deduplication `key` is specified, `foo` and `bar` are treated as duplicates when `key(foo) == key(bar)`. + + + Among duplicates, the first encountered occurence in upstream order is yielded. + + Warning: + During iteration, the distinct elements yielded are retained in memory to perform deduplication. + Alternatively, remove only consecutive duplicates without memory footprint by setting `consecutive_only=True`. + + Args: + key (Callable[[T], Coroutine[Any, Any, Any]], optional): Elements are deduplicated based on `key(elem)`. (default: the deduplication is performed on the elements themselves) + consecutive_only (bool, optional): Whether to deduplicate only consecutive duplicates, or globally. (default: the deduplication is global) + + Returns: + Stream: A stream containing only unique upstream elements. + """ + # validate_not_none(consecutive_only, "consecutive_only") + return ADistinctStream(self, key, consecutive_only) + def filter(self, when: Callable[[T], Any] = bool) -> "Stream[T]": """ Filters the stream to yield only elements satisfying the `when` predicate. 
@@ -240,6 +349,24 @@ def filter(self, when: Callable[[T], Any] = bool) -> "Stream[T]": # validate_not_none(when, "when") return FilterStream(self, cast(Optional[Callable[[T], Any]], when) or bool) + def afilter(self, when: Callable[[T], Coroutine[Any, Any, Any]]) -> "Stream[T]": + """ + Filters the stream to yield only elements satisfying the `when` predicate. + + Args: + when (Callable[[T], Coroutine[Any, Any, Any]], optional): An element is kept if `when(elem)` is truthy. (default: keeps truthy elements) + + Returns: + Stream[T]: A stream of upstream elements satisfying the `when` predicate. + """ + # Unofficially accept `stream.afilter(None)`, behaving as builtin `filter(None, iter)` + # validate_not_none(when, "when") + return AFilterStream( + self, + cast(Optional[Callable[[T], Coroutine[Any, Any, Any]]], when) + or asyncify(bool), + ) + # fmt: off @overload def flatten( @@ -317,13 +444,43 @@ def flatten(self: "Stream[Iterable[U]]", *, concurrency: int = 1) -> "Stream[U]" Iterates over upstream elements assumed to be iterables, and individually yields their items. Args: - concurrency (int, optional): Represents both the number of threads used to concurrently flatten the upstream iterables and the number of iterables buffered. (default: no concurrency) + concurrency (int, optional): Number of upstream iterables concurrently flattened via threads. (default: no concurrency) Returns: Stream[R]: A stream of flattened elements from upstream iterables. """ validate_concurrency(concurrency) return FlattenStream(self, concurrency) + # fmt: off + @overload + def aflatten( + self: "Stream[AsyncIterator[U]]", + *, + concurrency: int = 1, + ) -> "Stream[U]": ... + + @overload + def aflatten( + self: "Stream[AsyncIterable[U]]", + *, + concurrency: int = 1, + ) -> "Stream[U]": ... 
+ # fmt: on + + def aflatten( + self: "Stream[AsyncIterable[U]]", *, concurrency: int = 1 + ) -> "Stream[U]": + """ + Iterates over upstream elements assumed to be async iterables, and individually yields their items. + + Args: + concurrency (int, optional): Number of upstream async iterables concurrently flattened. (default: no concurrency) + Returns: + Stream[R]: A stream of flattened elements from upstream async iterables. + """ + validate_concurrency(concurrency) + return AFlattenStream(self, concurrency) + def foreach( self, effect: Callable[[T], Any], @@ -352,7 +509,7 @@ def foreach( def aforeach( self, - effect: Callable[[T], Coroutine], + effect: Callable[[T], Coroutine[Any, Any, Any]], *, concurrency: int = 1, ordered: bool = True, @@ -362,7 +519,7 @@ def aforeach( If the `effect(elem)` coroutine throws an exception then it will be thrown and `elem` will not be yielded. Args: - effect (Callable[[T], Any]): The asynchronous function to be applied to each element as a side effect. + effect (Callable[[T], Coroutine[Any, Any, Any]]): The asynchronous function to be applied to each element as a side effect. concurrency (int, optional): Represents both the number of async tasks concurrently applying the `effect` and the size of the buffer containing not-yet-yielded elements. If the buffer is full, the iteration over the upstream is paused until an element is yielded from the buffer. (default: no concurrency) ordered (bool, optional): If `concurrency` > 1, whether to preserve the order of upstream elements or to yield them as soon as they are processed. (default: preserves upstream order) Returns: @@ -401,6 +558,34 @@ def group( validate_optional_positive_interval(interval) return GroupStream(self, size, interval, by) + def agroup( + self, + size: Optional[int] = None, + *, + interval: Optional[datetime.timedelta] = None, + by: Optional[Callable[[T], Coroutine[Any, Any, Any]]] = None, + ) -> "Stream[List[T]]": + """ + Groups upstream elements into lists. 
+ + A group is yielded when any of the following conditions is met: + - The group reaches `size` elements. + - `interval` seconds have passed since the last group was yielded. + - The upstream source is exhausted. + + If `by` is specified, groups will only contain elements sharing the same `by(elem)` value (see `.agroupby` for `(key, elements)` pairs). + + Args: + size (Optional[int], optional): The maximum number of elements per group. (default: no size limit) + interval (float, optional): Yields a group if `interval` seconds have passed since the last group was yielded. (default: no interval limit) + by (Optional[Callable[[T], Coroutine[Any, Any, Any]]], optional): If specified, groups will only contain elements sharing the same `by(elem)` value. (default: does not co-group elements) + Returns: + Stream[List[T]]: A stream of upstream elements grouped into lists. + """ + validate_group_size(size) + validate_optional_positive_interval(interval) + return AGroupStream(self, size, interval, by) + def groupby( self, key: Callable[[T], U], @@ -427,6 +612,32 @@ def groupby( # validate_not_none(key, "key") return GroupbyStream(self, key, size, interval) + def agroupby( + self, + key: Callable[[T], Coroutine[Any, Any, U]], + *, + size: Optional[int] = None, + interval: Optional[datetime.timedelta] = None, + ) -> "Stream[Tuple[U, List[T]]]": + """ + Groups upstream elements into `(key, elements)` tuples. + + A group is yielded when any of the following conditions is met: + - A group reaches `size` elements. + - `interval` seconds have passed since the last group was yielded. + - The upstream source is exhausted. + + Args: + key (Callable[[T], Coroutine[Any, Any, U]]): An async function that returns the group key for an element. + size (Optional[int], optional): The maximum number of elements per group. (default: no size limit) + interval (Optional[datetime.timedelta], optional): If specified, yields a group if `interval` seconds have passed since the last group was yielded. 
(default: no interval limit) + + Returns: + Stream[Tuple[U, List[T]]]: A stream of upstream elements grouped by key, as `(key, elements)` tuples. + """ + # validate_not_none(key, "key") + return AGroupbyStream(self, key, size, interval) + def map( self, transformation: Callable[[T], U], @@ -531,6 +742,26 @@ def skip( validate_optional_count(count) return SkipStream(self, count, until) + def askip( + self, + count: Optional[int] = None, + *, + until: Optional[Callable[[T], Coroutine[Any, Any, Any]]] = None, + ) -> "Stream[T]": + """ + Skips elements until `until(elem)` is truthy, or `count` elements have been skipped. + If both `count` and `until` are set, skipping stops as soon as either condition is met. + + Args: + count (Optional[int], optional): The maximum number of elements to skip. (default: no count-based skipping) + until (Optional[Callable[[T], Coroutine[Any, Any, Any]]], optional): Elements are skipped until the first one for which `until(elem)` is truthy. This element and all the subsequent ones will be yielded. (default: no predicate-based skipping) + + Returns: + Stream: A stream of the upstream elements remaining after skipping. + """ + validate_optional_count(count) + return ASkipStream(self, count, until) + def throttle( self, count: Optional[int] = None, @@ -604,6 +835,26 @@ def truncate( validate_optional_count(count) return TruncateStream(self, count, when) + def atruncate( + self, + count: Optional[int] = None, + *, + when: Optional[Callable[[T], Coroutine[Any, Any, Any]]] = None, + ) -> "Stream[T]": + """ + Stops an iteration as soon as `when(elem)` is truthy, or `count` elements have been yielded. + If both `count` and `when` are set, truncation occurs as soon as either condition is met. + + Args: + count (int, optional): The maximum number of elements to yield. (default: no count-based truncation) + when (Optional[Callable[[T], Coroutine[Any, Any, Any]]], optional): An async predicate function that determines when to stop the iteration. 
Iteration stops immediately after encountering the first element for which `when(elem)` is truthy, and that element will not be yielded. (default: no predicate-based truncation) + + Returns: + Stream[T]: A stream of at most `count` upstream elements not satisfying the `when` predicate. + """ + validate_optional_count(count) + return ATruncateStream(self, count, when) + class DownStream(Stream[U], Generic[T, U]): """ @@ -621,7 +872,11 @@ def __deepcopy__(self, memo: Dict[int, Any]) -> "DownStream[T, U]": return new @property - def source(self) -> Union[Iterable, Callable[[], Iterable]]: + def source( + self, + ) -> Union[ + Iterable, Callable[[], Iterable], AsyncIterable, Callable[[], AsyncIterable] + ]: return self._upstream.source @property @@ -654,6 +909,27 @@ def accept(self, visitor: "Visitor[V]") -> V: return visitor.visit_catch_stream(self) +class ACatchStream(DownStream[T, T]): + __slots__ = ("_upstream", "_errors", "_when", "_replacement", "_finally_raise") + + def __init__( + self, + upstream: Stream[T], + errors: Union[Optional[Type[Exception]], Iterable[Optional[Type[Exception]]]], + when: Optional[Callable[[Exception], Coroutine[Any, Any, Any]]], + replacement: T, + finally_raise: bool, + ) -> None: + super().__init__(upstream) + self._errors = errors + self._when = when + self._replacement = replacement + self._finally_raise = finally_raise + + def accept(self, visitor: "Visitor[V]") -> V: + return visitor.visit_acatch_stream(self) + + class DistinctStream(DownStream[T, T]): __slots__ = ("_upstream", "_key", "_consecutive_only") @@ -671,6 +947,23 @@ def accept(self, visitor: "Visitor[V]") -> V: return visitor.visit_distinct_stream(self) +class ADistinctStream(DownStream[T, T]): + __slots__ = ("_upstream", "_key", "_consecutive_only") + + def __init__( + self, + upstream: Stream[T], + key: Optional[Callable[[T], Coroutine[Any, Any, Any]]], + consecutive_only: bool, + ) -> None: + super().__init__(upstream) + self._key = key + self._consecutive_only = 
consecutive_only + + def accept(self, visitor: "Visitor[V]") -> V: + return visitor.visit_adistinct_stream(self) + + class FilterStream(DownStream[T, T]): __slots__ = ("_upstream", "_when") @@ -682,6 +975,19 @@ def accept(self, visitor: "Visitor[V]") -> V: return visitor.visit_filter_stream(self) +class AFilterStream(DownStream[T, T]): + __slots__ = ("_upstream", "_when") + + def __init__( + self, upstream: Stream[T], when: Callable[[T], Coroutine[Any, Any, Any]] + ) -> None: + super().__init__(upstream) + self._when = when + + def accept(self, visitor: "Visitor[V]") -> V: + return visitor.visit_afilter_stream(self) + + class FlattenStream(DownStream[Iterable[T], T]): __slots__ = ("_upstream", "_concurrency") @@ -693,6 +999,17 @@ def accept(self, visitor: "Visitor[V]") -> V: return visitor.visit_flatten_stream(self) +class AFlattenStream(DownStream[AsyncIterable[T], T]): + __slots__ = ("_upstream", "_concurrency") + + def __init__(self, upstream: Stream[AsyncIterable[T]], concurrency: int) -> None: + super().__init__(upstream) + self._concurrency = concurrency + + def accept(self, visitor: "Visitor[V]") -> V: + return visitor.visit_aflatten_stream(self) + + class ForeachStream(DownStream[T, T]): __slots__ = ("_upstream", "_effect", "_concurrency", "_ordered", "_via") @@ -752,6 +1069,25 @@ def accept(self, visitor: "Visitor[V]") -> V: return visitor.visit_group_stream(self) +class AGroupStream(DownStream[T, List[T]]): + __slots__ = ("_upstream", "_size", "_interval", "_by") + + def __init__( + self, + upstream: Stream[T], + size: Optional[int], + interval: Optional[datetime.timedelta], + by: Optional[Callable[[T], Coroutine[Any, Any, Any]]], + ) -> None: + super().__init__(upstream) + self._size = size + self._interval = interval + self._by = by + + def accept(self, visitor: "Visitor[V]") -> V: + return visitor.visit_agroup_stream(self) + + class GroupbyStream(DownStream[T, Tuple[U, List[T]]]): __slots__ = ("_upstream", "_key", "_size", "_interval") @@ -771,6 
+1107,25 @@ def accept(self, visitor: "Visitor[V]") -> V: return visitor.visit_groupby_stream(self) +class AGroupbyStream(DownStream[T, Tuple[U, List[T]]]): + __slots__ = ("_upstream", "_key", "_size", "_interval") + + def __init__( + self, + upstream: Stream[T], + key: Callable[[T], Coroutine[Any, Any, U]], + size: Optional[int], + interval: Optional[datetime.timedelta], + ) -> None: + super().__init__(upstream) + self._key = key + self._size = size + self._interval = interval + + def accept(self, visitor: "Visitor[V]") -> V: + return visitor.visit_agroupby_stream(self) + + class MapStream(DownStream[T, U]): __slots__ = ("_upstream", "_transformation", "_concurrency", "_ordered", "_via") @@ -839,6 +1194,23 @@ def accept(self, visitor: "Visitor[V]") -> V: return visitor.visit_skip_stream(self) +class ASkipStream(DownStream[T, T]): + __slots__ = ("_upstream", "_count", "_until") + + def __init__( + self, + upstream: Stream[T], + count: Optional[int], + until: Optional[Callable[[T], Coroutine[Any, Any, Any]]], + ) -> None: + super().__init__(upstream) + self._count = count + self._until = until + + def accept(self, visitor: "Visitor[V]") -> V: + return visitor.visit_askip_stream(self) + + class ThrottleStream(DownStream[T, T]): __slots__ = ("_upstream", "_count", "_per") @@ -871,3 +1243,20 @@ def __init__( def accept(self, visitor: "Visitor[V]") -> V: return visitor.visit_truncate_stream(self) + + +class ATruncateStream(DownStream[T, T]): + __slots__ = ("_upstream", "_count", "_when") + + def __init__( + self, + upstream: Stream[T], + count: Optional[int] = None, + when: Optional[Callable[[T], Coroutine[Any, Any, Any]]] = None, + ) -> None: + super().__init__(upstream) + self._count = count + self._when = when + + def accept(self, visitor: "Visitor[V]") -> V: + return visitor.visit_atruncate_stream(self) diff --git a/streamable/util/asynctools.py b/streamable/util/asynctools.py new file mode 100644 index 00000000..439445e5 --- /dev/null +++ 
b/streamable/util/asynctools.py @@ -0,0 +1,26 @@ +import asyncio +from typing import AsyncIterator, Awaitable, Optional, TypeVar + +T = TypeVar("T") + + +async def awaitable_to_coroutine(aw: Awaitable[T]) -> T: + return await aw + + +async def empty_aiter() -> AsyncIterator: + return + yield + + +class GetEventLoopMixin: + _EVENT_LOOP_SINGLETON: Optional[asyncio.AbstractEventLoop] = None + + @classmethod + def get_event_loop(cls) -> asyncio.AbstractEventLoop: + try: + return asyncio.get_running_loop() + except RuntimeError: + if not cls._EVENT_LOOP_SINGLETON: + cls._EVENT_LOOP_SINGLETON = asyncio.new_event_loop() + return cls._EVENT_LOOP_SINGLETON diff --git a/streamable/util/errors.py b/streamable/util/errors.py new file mode 100644 index 00000000..3f5bcf54 --- /dev/null +++ b/streamable/util/errors.py @@ -0,0 +1,4 @@ +class WrappedError(Exception): + def __init__(self, error: Exception): + super().__init__(repr(error)) + self.error = error diff --git a/streamable/util/functiontools.py b/streamable/util/functiontools.py index 3925d84e..ef8ca70e 100644 --- a/streamable/util/functiontools.py +++ b/streamable/util/functiontools.py @@ -1,15 +1,23 @@ -from typing import Any, Callable, Coroutine, Generic, Tuple, Type, TypeVar, overload +from functools import partial +from typing import ( + Any, + AsyncIterable, + Callable, + Coroutine, + Generic, + Tuple, + Type, + TypeVar, + overload, +) + +from streamable.util.asynctools import GetEventLoopMixin +from streamable.util.errors import WrappedError T = TypeVar("T") R = TypeVar("R") -class WrappedError(Exception): - def __init__(self, error: Exception): - super().__init__(repr(error)) - self.error = error - - class _ErrorWrappingDecorator(Generic[T, R]): def __init__(self, func: Callable[[T], R], error_type: Type[Exception]) -> None: self.func = func @@ -26,7 +34,33 @@ def wrap_error(func: Callable[[T], R], error_type: Type[Exception]) -> Callable[ return _ErrorWrappingDecorator(func, error_type) +def awrap_error( + 
async_func: Callable[[T], Coroutine[Any, Any, R]], error_type: Type[Exception] +) -> Callable[[T], Coroutine[Any, Any, R]]: + async def wrap(elem: T) -> R: + try: + coroutine = async_func(elem) + if not isinstance(coroutine, Coroutine): + raise TypeError( + f"must be an async function i.e. a function returning a Coroutine but it returned a {type(coroutine)}" + ) + return await coroutine + except error_type as e: + raise WrappedError(e) from e + + return wrap + + iter_wo_stopiteration = wrap_error(iter, StopIteration) +iter_wo_stopasynciteration = wrap_error(iter, StopAsyncIteration) + + +def _aiter(aiterable: AsyncIterable): + return aiterable.__aiter__() + + +aiter_wo_stopasynciteration = wrap_error(_aiter, StopAsyncIteration) +aiter_wo_stopiteration = wrap_error(_aiter, StopIteration) class _Sidify(Generic[T]): @@ -115,3 +149,28 @@ def add(a: int, b: int) -> int: ``` """ return _Star(func) + + +class _Syncify(Generic[R], GetEventLoopMixin): + def __init__(self, async_func: Callable[[T], Coroutine[Any, Any, R]]) -> None: + self.async_func = async_func + + def __call__(self, T) -> R: + coroutine = self.async_func(T) + if not isinstance(coroutine, Coroutine): + raise TypeError( + f"must be an async function i.e. 
a function returning a Coroutine but it returned a {type(coroutine)}" + ) + return self.get_event_loop().run_until_complete(coroutine) + + +def syncify(async_func: Callable[[T], Coroutine[Any, Any, R]]) -> Callable[[T], R]: + return _Syncify(async_func) + + +async def _async_call(func: Callable[[T], R], o: T) -> R: + return func(o) + + +def asyncify(func: Callable[[T], R]) -> Callable[[T], Coroutine[Any, Any, R]]: + return partial(_async_call, func) diff --git a/streamable/util/futuretools.py b/streamable/util/futuretools.py index ec551973..278ce5fb 100644 --- a/streamable/util/futuretools.py +++ b/streamable/util/futuretools.py @@ -3,7 +3,7 @@ from collections import deque from concurrent.futures import Future from contextlib import suppress -from typing import Awaitable, Deque, Iterator, Sized, Type, TypeVar, cast +from typing import AsyncIterator, Awaitable, Deque, Iterator, Sized, Type, TypeVar, cast with suppress(ImportError): from streamable.util.protocols import Queue @@ -11,7 +11,7 @@ T = TypeVar("T") -class FutureResultCollection(Iterator[T], Sized, ABC): +class FutureResultCollection(Iterator[T], AsyncIterator[T], Sized, ABC): """ Iterator over added futures' results. Supports adding new futures after iteration started. """ @@ -19,6 +19,9 @@ class FutureResultCollection(Iterator[T], Sized, ABC): @abstractmethod def add_future(self, future: "Future[T]") -> None: ... 
+ async def __anext__(self) -> T: + return next(self) + class DequeFutureResultCollection(FutureResultCollection[T]): def __init__(self) -> None: @@ -87,6 +90,9 @@ def __next__(self) -> T: cast(Awaitable[T], self._futures.popleft()) ) + async def __anext__(self) -> T: + return await cast(Awaitable[T], self._futures.popleft()) + class FDFOAsyncFutureResultCollection(CallbackFutureResultCollection[T]): """ @@ -106,3 +112,9 @@ def __next__(self) -> T: self._n_futures -= 1 self._waiter = self.event_loop.create_future() return result + + async def __anext__(self) -> T: + result = await self._waiter + self._n_futures -= 1 + self._waiter = self.event_loop.create_future() + return result diff --git a/streamable/util/iterabletools.py b/streamable/util/iterabletools.py new file mode 100644 index 00000000..cbb1228c --- /dev/null +++ b/streamable/util/iterabletools.py @@ -0,0 +1,62 @@ +from typing import ( + AsyncIterable, + AsyncIterator, + Callable, + Iterable, + Iterator, + TypeVar, +) + +from streamable.util.asynctools import GetEventLoopMixin + +T = TypeVar("T") + + +class BiIterable(Iterable[T], AsyncIterable[T]): + pass + + +class BiIterator(Iterator[T], AsyncIterator[T]): + pass + + +class SyncToBiIterable(BiIterable[T]): + def __init__(self, iterable: Iterable[T]): + self.iterable = iterable + + def __iter__(self) -> Iterator[T]: + return iter(self.iterable) + + def __aiter__(self) -> AsyncIterator[T]: + return SyncToAsyncIterator(self.iterable) + + +sync_to_bi_iterable: Callable[[Iterable[T]], BiIterable[T]] = SyncToBiIterable + + +class SyncToAsyncIterator(AsyncIterator[T]): + def __init__(self, iterator: Iterable[T]): + self.iterator: Iterator[T] = iter(iterator) + + async def __anext__(self) -> T: + try: + return next(self.iterator) + except StopIteration as e: + raise StopAsyncIteration() from e + + +sync_to_async_iter: Callable[[Iterable[T]], AsyncIterator[T]] = SyncToAsyncIterator + + +class AsyncToSyncIterator(Iterator[T], GetEventLoopMixin): + def 
__init__(self, iterator: AsyncIterable[T]): + self.iterator: AsyncIterator[T] = iterator.__aiter__() + + def __next__(self) -> T: + try: + return self.get_event_loop().run_until_complete(self.iterator.__anext__()) + except StopAsyncIteration as e: + raise StopIteration() from e + + +async_to_sync_iter: Callable[[AsyncIterable[T]], Iterator[T]] = AsyncToSyncIterator diff --git a/streamable/util/validationtools.py b/streamable/util/validationtools.py index 15fc399b..375f0e12 100644 --- a/streamable/util/validationtools.py +++ b/streamable/util/validationtools.py @@ -1,6 +1,14 @@ import datetime from contextlib import suppress -from typing import Any, Iterable, Iterator, Optional, Type, TypeVar, Union +from typing import ( + AsyncIterator, + Iterable, + Iterator, + Optional, + Type, + TypeVar, + Union, +) with suppress(ImportError): from typing import Literal @@ -13,6 +21,13 @@ def validate_iterator(iterator: Iterator): raise TypeError(f"`iterator` must be an Iterator but got a {type(iterator)}") +def validate_aiterator(iterator: AsyncIterator): + if not isinstance(iterator, AsyncIterator): + raise TypeError( + f"`iterator` must be an AsyncIterator but got a {type(iterator)}" + ) + + def validate_base(base: int): if base <= 0: raise ValueError(f"`base` must be > 0 but got {base}") @@ -69,7 +84,7 @@ def validate_optional_positive_count(count: Optional[int]): # def validate_not_none(value: Any, name: str) -> None: # if value is None: -# raise TypeError(f"`{name}` cannot be None") +# raise TypeError(f"`{name}` must not be None") def validate_errors( diff --git a/streamable/visitors/__init__.py b/streamable/visitors/__init__.py index 1234b6b4..ba67ca01 100644 --- a/streamable/visitors/__init__.py +++ b/streamable/visitors/__init__.py @@ -1 +1,3 @@ from streamable.visitors.base import Visitor + +__all__ = ["Visitor"] diff --git a/streamable/visitors/aiterator.py b/streamable/visitors/aiterator.py new file mode 100644 index 00000000..4cfe959e --- /dev/null +++ 
b/streamable/visitors/aiterator.py @@ -0,0 +1,227 @@ +from typing import AsyncIterable, AsyncIterator, Iterable, TypeVar, cast + +from streamable import afunctions +from streamable.stream import ( + ACatchStream, + ADistinctStream, + AFilterStream, + AFlattenStream, + AForeachStream, + AGroupbyStream, + AGroupStream, + AMapStream, + ASkipStream, + ATruncateStream, + CatchStream, + DistinctStream, + FilterStream, + FlattenStream, + ForeachStream, + GroupbyStream, + GroupStream, + MapStream, + ObserveStream, + SkipStream, + Stream, + ThrottleStream, + TruncateStream, +) +from streamable.util.functiontools import async_sidify, sidify +from streamable.util.iterabletools import sync_to_async_iter +from streamable.visitors import Visitor + +T = TypeVar("T") +U = TypeVar("U") + + +class AsyncIteratorVisitor(Visitor[AsyncIterator[T]]): + def visit_catch_stream(self, stream: CatchStream[T]) -> AsyncIterator[T]: + return afunctions.catch( + stream.upstream.accept(self), + stream._errors, + when=stream._when, + replacement=stream._replacement, + finally_raise=stream._finally_raise, + ) + + def visit_acatch_stream(self, stream: ACatchStream[T]) -> AsyncIterator[T]: + return afunctions.acatch( + stream.upstream.accept(self), + stream._errors, + when=stream._when, + replacement=stream._replacement, + finally_raise=stream._finally_raise, + ) + + def visit_distinct_stream(self, stream: DistinctStream[T]) -> AsyncIterator[T]: + return afunctions.distinct( + stream.upstream.accept(self), + stream._key, + consecutive_only=stream._consecutive_only, + ) + + def visit_adistinct_stream(self, stream: ADistinctStream[T]) -> AsyncIterator[T]: + return afunctions.adistinct( + stream.upstream.accept(self), + stream._key, + consecutive_only=stream._consecutive_only, + ) + + def visit_filter_stream(self, stream: FilterStream[T]) -> AsyncIterator[T]: + return afunctions.filter(stream.upstream.accept(self), stream._when) + + def visit_afilter_stream(self, stream: AFilterStream[T]) -> 
AsyncIterator[T]: + return afunctions.afilter(stream.upstream.accept(self), stream._when) + + def visit_flatten_stream(self, stream: FlattenStream[T]) -> AsyncIterator[T]: + return afunctions.flatten( + stream.upstream.accept(AsyncIteratorVisitor[Iterable]()), + concurrency=stream._concurrency, + ) + + def visit_aflatten_stream(self, stream: AFlattenStream[T]) -> AsyncIterator[T]: + return afunctions.aflatten( + stream.upstream.accept(AsyncIteratorVisitor[AsyncIterable]()), + concurrency=stream._concurrency, + ) + + def visit_foreach_stream(self, stream: ForeachStream[T]) -> AsyncIterator[T]: + return self.visit_map_stream( + MapStream( + stream.upstream, + sidify(stream._effect), + stream._concurrency, + stream._ordered, + stream._via, + ) + ) + + def visit_aforeach_stream(self, stream: AForeachStream[T]) -> AsyncIterator[T]: + return self.visit_amap_stream( + AMapStream( + stream.upstream, + async_sidify(stream._effect), + stream._concurrency, + stream._ordered, + ) + ) + + def visit_group_stream(self, stream: GroupStream[U]) -> AsyncIterator[T]: + return cast( + AsyncIterator[T], + afunctions.group( + stream.upstream.accept(AsyncIteratorVisitor[U]()), + stream._size, + interval=stream._interval, + by=stream._by, + ), + ) + + def visit_agroup_stream(self, stream: AGroupStream[U]) -> AsyncIterator[T]: + return cast( + AsyncIterator[T], + afunctions.agroup( + stream.upstream.accept(AsyncIteratorVisitor[U]()), + stream._size, + interval=stream._interval, + by=stream._by, + ), + ) + + def visit_groupby_stream(self, stream: GroupbyStream[U, T]) -> AsyncIterator[T]: + return cast( + AsyncIterator[T], + afunctions.groupby( + stream.upstream.accept(AsyncIteratorVisitor[U]()), + stream._key, + size=stream._size, + interval=stream._interval, + ), + ) + + def visit_agroupby_stream(self, stream: AGroupbyStream[U, T]) -> AsyncIterator[T]: + return cast( + AsyncIterator[T], + afunctions.agroupby( + stream.upstream.accept(AsyncIteratorVisitor[U]()), + stream._key, + 
size=stream._size, + interval=stream._interval, + ), + ) + + def visit_map_stream(self, stream: MapStream[U, T]) -> AsyncIterator[T]: + return afunctions.map( + stream._transformation, + stream.upstream.accept(AsyncIteratorVisitor[U]()), + concurrency=stream._concurrency, + ordered=stream._ordered, + via=stream._via, + ) + + def visit_amap_stream(self, stream: AMapStream[U, T]) -> AsyncIterator[T]: + return afunctions.amap( + stream._transformation, + stream.upstream.accept(AsyncIteratorVisitor[U]()), + concurrency=stream._concurrency, + ordered=stream._ordered, + ) + + def visit_observe_stream(self, stream: ObserveStream[T]) -> AsyncIterator[T]: + return afunctions.observe( + stream.upstream.accept(self), + stream._what, + ) + + def visit_skip_stream(self, stream: SkipStream[T]) -> AsyncIterator[T]: + return afunctions.skip( + stream.upstream.accept(self), + stream._count, + until=stream._until, + ) + + def visit_askip_stream(self, stream: ASkipStream[T]) -> AsyncIterator[T]: + return afunctions.askip( + stream.upstream.accept(self), + stream._count, + until=stream._until, + ) + + def visit_throttle_stream(self, stream: ThrottleStream[T]) -> AsyncIterator[T]: + return afunctions.throttle( + stream.upstream.accept(self), + stream._count, + per=stream._per, + ) + + def visit_truncate_stream(self, stream: TruncateStream[T]) -> AsyncIterator[T]: + return afunctions.truncate( + stream.upstream.accept(self), + stream._count, + when=stream._when, + ) + + def visit_atruncate_stream(self, stream: ATruncateStream[T]) -> AsyncIterator[T]: + return afunctions.atruncate( + stream.upstream.accept(self), + stream._count, + when=stream._when, + ) + + def visit_stream(self, stream: Stream[T]) -> AsyncIterator[T]: + if isinstance(stream.source, Iterable): + return sync_to_async_iter(stream.source) + if isinstance(stream.source, AsyncIterable): + return stream.source.__aiter__() + if callable(stream.source): + iterable = stream.source() + if isinstance(iterable, Iterable): + return 
sync_to_async_iter(iterable) + if isinstance(iterable, AsyncIterable): + return iterable.__aiter__() + raise TypeError( + f"`source` must be an Iterable/AsyncIterable or a Callable[[], Iterable/AsyncIterable] but got a Callable[[], {type(iterable)}]" + ) + raise TypeError( + f"`source` must be an Iterable/AsyncIterable or a Callable[[], Iterable/AsyncIterable] but got a {type(stream.source)}" + ) diff --git a/streamable/visitors/base.py b/streamable/visitors/base.py index d8f5d409..3cf28bfd 100644 --- a/streamable/visitors/base.py +++ b/streamable/visitors/base.py @@ -15,15 +15,27 @@ def visit_stream(self, stream: stream.Stream) -> V: ... def visit_catch_stream(self, stream: stream.CatchStream) -> V: return self.visit_stream(stream) + def visit_acatch_stream(self, stream: stream.ACatchStream) -> V: + return self.visit_stream(stream) + def visit_distinct_stream(self, stream: stream.DistinctStream) -> V: return self.visit_stream(stream) + def visit_adistinct_stream(self, stream: stream.ADistinctStream) -> V: + return self.visit_stream(stream) + def visit_filter_stream(self, stream: stream.FilterStream) -> V: return self.visit_stream(stream) + def visit_afilter_stream(self, stream: stream.AFilterStream) -> V: + return self.visit_stream(stream) + def visit_flatten_stream(self, stream: stream.FlattenStream) -> V: return self.visit_stream(stream) + def visit_aflatten_stream(self, stream: stream.AFlattenStream) -> V: + return self.visit_stream(stream) + def visit_foreach_stream(self, stream: stream.ForeachStream) -> V: return self.visit_stream(stream) @@ -33,9 +45,15 @@ def visit_aforeach_stream(self, stream: stream.AForeachStream) -> V: def visit_group_stream(self, stream: stream.GroupStream) -> V: return self.visit_stream(stream) + def visit_agroup_stream(self, stream: stream.AGroupStream) -> V: + return self.visit_stream(stream) + def visit_groupby_stream(self, stream: stream.GroupbyStream) -> V: return self.visit_stream(stream) + def visit_agroupby_stream(self, 
stream: stream.AGroupbyStream) -> V: + return self.visit_stream(stream) + def visit_observe_stream(self, stream: stream.ObserveStream) -> V: return self.visit_stream(stream) @@ -48,8 +66,14 @@ def visit_amap_stream(self, stream: stream.AMapStream) -> V: def visit_skip_stream(self, stream: stream.SkipStream) -> V: return self.visit_stream(stream) + def visit_askip_stream(self, stream: stream.ASkipStream) -> V: + return self.visit_stream(stream) + def visit_throttle_stream(self, stream: stream.ThrottleStream) -> V: return self.visit_stream(stream) def visit_truncate_stream(self, stream: stream.TruncateStream) -> V: return self.visit_stream(stream) + + def visit_atruncate_stream(self, stream: stream.ATruncateStream) -> V: + return self.visit_stream(stream) diff --git a/streamable/visitors/equality.py b/streamable/visitors/equality.py index 3af9f2d0..d03bc09e 100644 --- a/streamable/visitors/equality.py +++ b/streamable/visitors/equality.py @@ -1,8 +1,16 @@ -from typing import Any, TypeVar +from typing import Any, Union from streamable.stream import ( + ACatchStream, + ADistinctStream, + AFilterStream, + AFlattenStream, AForeachStream, + AGroupbyStream, + AGroupStream, AMapStream, + ASkipStream, + ATruncateStream, CatchStream, DistinctStream, FilterStream, @@ -19,17 +27,17 @@ ) from streamable.visitors import Visitor -T = TypeVar("T") -U = TypeVar("U") - class EqualityVisitor(Visitor[bool]): def __init__(self, other: Any): self.other: Any = other - def visit_catch_stream(self, stream: CatchStream[T]) -> bool: + def type_eq(self, stream: Stream) -> bool: + return type(stream) is type(self.other) + + def catch_eq(self, stream: Union[CatchStream, ACatchStream]) -> bool: return ( - isinstance(self.other, CatchStream) + self.type_eq(stream) and stream.upstream.accept(EqualityVisitor(self.other.upstream)) and stream._errors == self.other._errors and stream._when == self.other._when @@ -37,114 +45,154 @@ def visit_catch_stream(self, stream: CatchStream[T]) -> bool: and
stream._finally_raise == self.other._finally_raise ) - def visit_distinct_stream(self, stream: DistinctStream[T]) -> bool: + def visit_catch_stream(self, stream: CatchStream) -> bool: + return self.catch_eq(stream) + + def visit_acatch_stream(self, stream: ACatchStream) -> bool: + return self.catch_eq(stream) + + def distinct_eq(self, stream: Union[DistinctStream, ADistinctStream]) -> bool: return ( - isinstance(self.other, DistinctStream) + self.type_eq(stream) and stream.upstream.accept(EqualityVisitor(self.other.upstream)) and stream._key == self.other._key and stream._consecutive_only == self.other._consecutive_only ) - def visit_filter_stream(self, stream: FilterStream[T]) -> bool: + def visit_distinct_stream(self, stream: DistinctStream) -> bool: + return self.distinct_eq(stream) + + def visit_adistinct_stream(self, stream: ADistinctStream) -> bool: + return self.distinct_eq(stream) + + def filter_eq(self, stream: Union[FilterStream, AFilterStream]) -> bool: return ( - isinstance(self.other, FilterStream) + self.type_eq(stream) and stream.upstream.accept(EqualityVisitor(self.other.upstream)) and stream._when == self.other._when ) - def visit_flatten_stream(self, stream: FlattenStream[T]) -> bool: - return ( - isinstance(self.other, FlattenStream) - and stream.upstream.accept(EqualityVisitor(self.other.upstream)) - and stream._concurrency == self.other._concurrency - ) + def visit_filter_stream(self, stream: FilterStream) -> bool: + return self.filter_eq(stream) - def visit_foreach_stream(self, stream: ForeachStream[T]) -> bool: + def visit_afilter_stream(self, stream: AFilterStream) -> bool: + return self.filter_eq(stream) + + def flatten_eq(self, stream: Union[FlattenStream, AFlattenStream]) -> bool: return ( - isinstance(self.other, ForeachStream) + self.type_eq(stream) and stream.upstream.accept(EqualityVisitor(self.other.upstream)) and stream._concurrency == self.other._concurrency - and stream._effect == self.other._effect - and stream._ordered == 
self.other._ordered - and stream._via == self.other._via ) - def visit_aforeach_stream(self, stream: AForeachStream[T]) -> bool: + def visit_flatten_stream(self, stream: FlattenStream) -> bool: + return self.flatten_eq(stream) + + def visit_aflatten_stream(self, stream: AFlattenStream) -> bool: + return self.flatten_eq(stream) + + def foreach_eq(self, stream: Union[ForeachStream, AForeachStream]) -> bool: return ( - isinstance(self.other, AForeachStream) + self.type_eq(stream) and stream.upstream.accept(EqualityVisitor(self.other.upstream)) and stream._concurrency == self.other._concurrency and stream._effect == self.other._effect and stream._ordered == self.other._ordered ) - def visit_group_stream(self, stream: GroupStream[U]) -> bool: + def visit_foreach_stream(self, stream: ForeachStream) -> bool: + return self.foreach_eq(stream) and stream._via == self.other._via + + def visit_aforeach_stream(self, stream: AForeachStream) -> bool: + return self.foreach_eq(stream) + + def group_eq(self, stream: Union[GroupStream, AGroupStream]) -> bool: return ( - isinstance(self.other, GroupStream) + self.type_eq(stream) and stream.upstream.accept(EqualityVisitor(self.other.upstream)) and stream._by == self.other._by and stream._size == self.other._size and stream._interval == self.other._interval ) - def visit_groupby_stream(self, stream: GroupbyStream[U, T]) -> bool: + def visit_group_stream(self, stream: GroupStream) -> bool: + return self.group_eq(stream) + + def visit_agroup_stream(self, stream: AGroupStream) -> bool: + return self.group_eq(stream) + + def groupby_eq(self, stream: Union[GroupbyStream, AGroupbyStream]) -> bool: return ( - isinstance(self.other, GroupbyStream) + self.type_eq(stream) and stream.upstream.accept(EqualityVisitor(self.other.upstream)) and stream._key == self.other._key and stream._size == self.other._size and stream._interval == self.other._interval ) - def visit_map_stream(self, stream: MapStream[U, T]) -> bool: - return ( - 
isinstance(self.other, MapStream) - and stream.upstream.accept(EqualityVisitor(self.other.upstream)) - and stream._concurrency == self.other._concurrency - and stream._transformation == self.other._transformation - and stream._ordered == self.other._ordered - and stream._via == self.other._via - ) + def visit_groupby_stream(self, stream: GroupbyStream) -> bool: + return self.groupby_eq(stream) + + def visit_agroupby_stream(self, stream: AGroupbyStream) -> bool: + return self.groupby_eq(stream) - def visit_amap_stream(self, stream: AMapStream[U, T]) -> bool: + def map_eq(self, stream: Union[MapStream, AMapStream]) -> bool: return ( - isinstance(self.other, AMapStream) + self.type_eq(stream) and stream.upstream.accept(EqualityVisitor(self.other.upstream)) and stream._concurrency == self.other._concurrency and stream._transformation == self.other._transformation and stream._ordered == self.other._ordered ) - def visit_observe_stream(self, stream: ObserveStream[T]) -> bool: + def visit_map_stream(self, stream: MapStream) -> bool: + return self.map_eq(stream) and stream._via == self.other._via + + def visit_amap_stream(self, stream: AMapStream) -> bool: + return self.map_eq(stream) + + def visit_observe_stream(self, stream: ObserveStream) -> bool: return ( - isinstance(self.other, ObserveStream) + self.type_eq(stream) and stream.upstream.accept(EqualityVisitor(self.other.upstream)) and stream._what == self.other._what ) - def visit_skip_stream(self, stream: SkipStream[T]) -> bool: + def skip_eq(self, stream: Union[SkipStream, ASkipStream]) -> bool: return ( - isinstance(self.other, SkipStream) + self.type_eq(stream) and stream.upstream.accept(EqualityVisitor(self.other.upstream)) and stream._count == self.other._count and stream._until == self.other._until ) - def visit_throttle_stream(self, stream: ThrottleStream[T]) -> bool: + def visit_skip_stream(self, stream: SkipStream) -> bool: + return self.skip_eq(stream) + + def visit_askip_stream(self, stream: ASkipStream) -> 
bool: + return self.skip_eq(stream) + + def visit_throttle_stream(self, stream: ThrottleStream) -> bool: return ( - isinstance(self.other, ThrottleStream) + self.type_eq(stream) and stream.upstream.accept(EqualityVisitor(self.other.upstream)) and stream._count == self.other._count and stream._per == self.other._per ) - def visit_truncate_stream(self, stream: TruncateStream[T]) -> bool: + def truncate_eq(self, stream: Union[TruncateStream, ATruncateStream]) -> bool: return ( - isinstance(self.other, TruncateStream) + self.type_eq(stream) and stream.upstream.accept(EqualityVisitor(self.other.upstream)) and stream._count == self.other._count and stream._when == self.other._when ) - def visit_stream(self, stream: Stream[T]) -> bool: - return isinstance(self.other, Stream) and stream.source == self.other.source + def visit_truncate_stream(self, stream: TruncateStream) -> bool: + return self.truncate_eq(stream) + + def visit_atruncate_stream(self, stream: ATruncateStream) -> bool: + return self.truncate_eq(stream) + + def visit_stream(self, stream: Stream) -> bool: + return self.type_eq(stream) and stream.source == self.other.source diff --git a/streamable/visitors/iterator.py b/streamable/visitors/iterator.py index 417aa27b..34167252 100644 --- a/streamable/visitors/iterator.py +++ b/streamable/visitors/iterator.py @@ -1,9 +1,17 @@ -from typing import Iterable, Iterator, TypeVar, cast +from typing import AsyncIterable, Iterable, Iterator, TypeVar, cast from streamable import functions from streamable.stream import ( + ACatchStream, + ADistinctStream, + AFilterStream, + AFlattenStream, AForeachStream, + AGroupbyStream, + AGroupStream, AMapStream, + ASkipStream, + ATruncateStream, CatchStream, DistinctStream, FilterStream, @@ -18,7 +26,8 @@ ThrottleStream, TruncateStream, ) -from streamable.util.functiontools import async_sidify, sidify, wrap_error +from streamable.util.functiontools import async_sidify, sidify, syncify, wrap_error +from streamable.util.iterabletools 
import async_to_sync_iter from streamable.visitors import Visitor T = TypeVar("T") @@ -35,6 +44,15 @@ def visit_catch_stream(self, stream: CatchStream[T]) -> Iterator[T]: finally_raise=stream._finally_raise, ) + def visit_acatch_stream(self, stream: ACatchStream[T]) -> Iterator[T]: + return functions.acatch( + stream.upstream.accept(self), + stream._errors, + when=stream._when, + replacement=stream._replacement, + finally_raise=stream._finally_raise, + ) + def visit_distinct_stream(self, stream: DistinctStream[T]) -> Iterator[T]: return functions.distinct( stream.upstream.accept(self), @@ -42,18 +60,37 @@ def visit_distinct_stream(self, stream: DistinctStream[T]) -> Iterator[T]: consecutive_only=stream._consecutive_only, ) + def visit_adistinct_stream(self, stream: ADistinctStream[T]) -> Iterator[T]: + return functions.adistinct( + stream.upstream.accept(self), + stream._key, + consecutive_only=stream._consecutive_only, + ) + def visit_filter_stream(self, stream: FilterStream[T]) -> Iterator[T]: return filter( wrap_error(stream._when, StopIteration), cast(Iterable[T], stream.upstream.accept(self)), ) + def visit_afilter_stream(self, stream: AFilterStream[T]) -> Iterator[T]: + return filter( + wrap_error(syncify(stream._when), StopIteration), + cast(Iterable[T], stream.upstream.accept(self)), + ) + def visit_flatten_stream(self, stream: FlattenStream[T]) -> Iterator[T]: return functions.flatten( stream.upstream.accept(IteratorVisitor[Iterable]()), concurrency=stream._concurrency, ) + def visit_aflatten_stream(self, stream: AFlattenStream[T]) -> Iterator[T]: + return functions.aflatten( + stream.upstream.accept(IteratorVisitor[AsyncIterable]()), + concurrency=stream._concurrency, + ) + def visit_foreach_stream(self, stream: ForeachStream[T]) -> Iterator[T]: return self.visit_map_stream( MapStream( @@ -86,6 +123,17 @@ def visit_group_stream(self, stream: GroupStream[U]) -> Iterator[T]: ), ) + def visit_agroup_stream(self, stream: AGroupStream[U]) -> Iterator[T]: + 
return cast( + Iterator[T], + functions.agroup( + stream.upstream.accept(IteratorVisitor[U]()), + stream._size, + interval=stream._interval, + by=stream._by, + ), + ) + def visit_groupby_stream(self, stream: GroupbyStream[U, T]) -> Iterator[T]: return cast( Iterator[T], @@ -97,6 +145,17 @@ def visit_groupby_stream(self, stream: GroupbyStream[U, T]) -> Iterator[T]: ), ) + def visit_agroupby_stream(self, stream: AGroupbyStream[U, T]) -> Iterator[T]: + return cast( + Iterator[T], + functions.agroupby( + stream.upstream.accept(IteratorVisitor[U]()), + stream._key, + size=stream._size, + interval=stream._interval, + ), + ) + def visit_map_stream(self, stream: MapStream[U, T]) -> Iterator[T]: return functions.map( stream._transformation, @@ -127,6 +186,13 @@ def visit_skip_stream(self, stream: SkipStream[T]) -> Iterator[T]: until=stream._until, ) + def visit_askip_stream(self, stream: ASkipStream[T]) -> Iterator[T]: + return functions.askip( + stream.upstream.accept(self), + stream._count, + until=stream._until, + ) + def visit_throttle_stream(self, stream: ThrottleStream[T]) -> Iterator[T]: return functions.throttle( stream.upstream.accept(self), @@ -141,17 +207,27 @@ def visit_truncate_stream(self, stream: TruncateStream[T]) -> Iterator[T]: when=stream._when, ) + def visit_atruncate_stream(self, stream: ATruncateStream[T]) -> Iterator[T]: + return functions.atruncate( + stream.upstream.accept(self), + stream._count, + when=stream._when, + ) + def visit_stream(self, stream: Stream[T]) -> Iterator[T]: if isinstance(stream.source, Iterable): - iterable = stream.source - elif callable(stream.source): + return iter(stream.source) + if isinstance(stream.source, AsyncIterable): + return async_to_sync_iter(stream.source) + if callable(stream.source): iterable = stream.source() - if not isinstance(iterable, Iterable): - raise TypeError( - f"`source` must be an Iterable or a Callable[[], Iterable] but got a Callable[[], {type(iterable)}]" - ) - else: + if isinstance(iterable, 
Iterable): + return iter(iterable) + if isinstance(iterable, AsyncIterable): + return async_to_sync_iter(iterable) raise TypeError( - f"`source` must be an Iterable or a Callable[[], Iterable] but got a {type(stream.source)}" + f"`source` must be an Iterable/AsyncIterable or a Callable[[], Iterable/AsyncIterable] but got a Callable[[], {type(iterable)}]" ) - return iter(iterable) + raise TypeError( + f"`source` must be an Iterable/AsyncIterable or a Callable[[], Iterable/AsyncIterable] but got a {type(stream.source)}" + ) diff --git a/streamable/visitors/representation.py b/streamable/visitors/representation.py index 434b3ffc..926876ce 100644 --- a/streamable/visitors/representation.py +++ b/streamable/visitors/representation.py @@ -1,9 +1,17 @@ from abc import ABC, abstractmethod -from typing import Any, Iterable, List, Type, TypeVar +from typing import Any, Iterable, List from streamable.stream import ( + ACatchStream, + ADistinctStream, + AFilterStream, + AFlattenStream, AForeachStream, + AGroupbyStream, + AGroupStream, AMapStream, + ASkipStream, + ATruncateStream, CatchStream, DistinctStream, FilterStream, @@ -22,9 +30,6 @@ from streamable.util.functiontools import _Star from streamable.visitors import Visitor -T = TypeVar("T") -U = TypeVar("U") - class ToStringVisitor(Visitor[str], ABC): def __init__(self) -> None: @@ -34,7 +39,7 @@ def __init__(self) -> None: @abstractmethod def to_string(o: object) -> str: ... 
- def visit_catch_stream(self, stream: CatchStream[T]) -> str: + def visit_catch_stream(self, stream: CatchStream) -> str: replacement = "" if stream._replacement is not NO_REPLACEMENT: replacement = f", replacement={self.to_string(stream._replacement)}" @@ -47,83 +52,136 @@ def visit_catch_stream(self, stream: CatchStream[T]) -> str: ) return stream.upstream.accept(self) - def visit_distinct_stream(self, stream: DistinctStream[T]) -> str: + def visit_acatch_stream(self, stream: ACatchStream) -> str: + replacement = "" + if stream._replacement is not NO_REPLACEMENT: + replacement = f", replacement={self.to_string(stream._replacement)}" + if isinstance(stream._errors, Iterable): + errors = f"({', '.join(map(self.to_string, stream._errors))})" + else: + errors = self.to_string(stream._errors) + self.methods_reprs.append( + f"acatch({errors}, when={self.to_string(stream._when)}{replacement}, finally_raise={self.to_string(stream._finally_raise)})" + ) + return stream.upstream.accept(self) + + def visit_distinct_stream(self, stream: DistinctStream) -> str: self.methods_reprs.append( f"distinct({self.to_string(stream._key)}, consecutive_only={self.to_string(stream._consecutive_only)})" ) return stream.upstream.accept(self) - def visit_filter_stream(self, stream: FilterStream[T]) -> str: + def visit_adistinct_stream(self, stream: ADistinctStream) -> str: + self.methods_reprs.append( + f"adistinct({self.to_string(stream._key)}, consecutive_only={self.to_string(stream._consecutive_only)})" + ) + return stream.upstream.accept(self) + + def visit_filter_stream(self, stream: FilterStream) -> str: self.methods_reprs.append(f"filter({self.to_string(stream._when)})") return stream.upstream.accept(self) - def visit_flatten_stream(self, stream: FlattenStream[T]) -> str: + def visit_afilter_stream(self, stream: AFilterStream) -> str: + self.methods_reprs.append(f"afilter({self.to_string(stream._when)})") + return stream.upstream.accept(self) + + def visit_flatten_stream(self, 
stream: FlattenStream) -> str: self.methods_reprs.append( f"flatten(concurrency={self.to_string(stream._concurrency)})" ) return stream.upstream.accept(self) - def visit_foreach_stream(self, stream: ForeachStream[T]) -> str: + def visit_aflatten_stream(self, stream: AFlattenStream) -> str: + self.methods_reprs.append( + f"aflatten(concurrency={self.to_string(stream._concurrency)})" + ) + return stream.upstream.accept(self) + + def visit_foreach_stream(self, stream: ForeachStream) -> str: via = f", via={self.to_string(stream._via)}" if stream._concurrency > 1 else "" self.methods_reprs.append( f"foreach({self.to_string(stream._effect)}, concurrency={self.to_string(stream._concurrency)}, ordered={self.to_string(stream._ordered)}{via})" ) return stream.upstream.accept(self) - def visit_aforeach_stream(self, stream: AForeachStream[T]) -> str: + def visit_aforeach_stream(self, stream: AForeachStream) -> str: self.methods_reprs.append( f"aforeach({self.to_string(stream._effect)}, concurrency={self.to_string(stream._concurrency)}, ordered={self.to_string(stream._ordered)})" ) return stream.upstream.accept(self) - def visit_group_stream(self, stream: GroupStream[U]) -> str: + def visit_group_stream(self, stream: GroupStream) -> str: self.methods_reprs.append( f"group(size={self.to_string(stream._size)}, by={self.to_string(stream._by)}, interval={self.to_string(stream._interval)})" ) return stream.upstream.accept(self) - def visit_groupby_stream(self, stream: GroupbyStream[U, T]) -> str: + def visit_agroup_stream(self, stream: AGroupStream) -> str: + self.methods_reprs.append( + f"agroup(size={self.to_string(stream._size)}, by={self.to_string(stream._by)}, interval={self.to_string(stream._interval)})" + ) + return stream.upstream.accept(self) + + def visit_groupby_stream(self, stream: GroupbyStream) -> str: self.methods_reprs.append( f"groupby({self.to_string(stream._key)}, size={self.to_string(stream._size)}, interval={self.to_string(stream._interval)})" ) return 
stream.upstream.accept(self) - def visit_map_stream(self, stream: MapStream[U, T]) -> str: + def visit_agroupby_stream(self, stream: AGroupbyStream) -> str: + self.methods_reprs.append( + f"agroupby({self.to_string(stream._key)}, size={self.to_string(stream._size)}, interval={self.to_string(stream._interval)})" + ) + return stream.upstream.accept(self) + + def visit_map_stream(self, stream: MapStream) -> str: via = f", via={self.to_string(stream._via)}" if stream._concurrency > 1 else "" self.methods_reprs.append( f"map({self.to_string(stream._transformation)}, concurrency={self.to_string(stream._concurrency)}, ordered={self.to_string(stream._ordered)}{via})" ) return stream.upstream.accept(self) - def visit_amap_stream(self, stream: AMapStream[U, T]) -> str: + def visit_amap_stream(self, stream: AMapStream) -> str: self.methods_reprs.append( f"amap({self.to_string(stream._transformation)}, concurrency={self.to_string(stream._concurrency)}, ordered={self.to_string(stream._ordered)})" ) return stream.upstream.accept(self) - def visit_observe_stream(self, stream: ObserveStream[T]) -> str: + def visit_observe_stream(self, stream: ObserveStream) -> str: self.methods_reprs.append(f"""observe({self.to_string(stream._what)})""") return stream.upstream.accept(self) - def visit_skip_stream(self, stream: SkipStream[T]) -> str: + def visit_skip_stream(self, stream: SkipStream) -> str: self.methods_reprs.append( f"skip({self.to_string(stream._count)}, until={self.to_string(stream._until)})" ) return stream.upstream.accept(self) - def visit_throttle_stream(self, stream: ThrottleStream[T]) -> str: + def visit_askip_stream(self, stream: ASkipStream) -> str: + self.methods_reprs.append( + f"askip({self.to_string(stream._count)}, until={self.to_string(stream._until)})" + ) + return stream.upstream.accept(self) + + def visit_throttle_stream(self, stream: ThrottleStream) -> str: self.methods_reprs.append( f"throttle({self.to_string(stream._count)}, per={self.to_string(stream._per)})" 
) return stream.upstream.accept(self) - def visit_truncate_stream(self, stream: TruncateStream[T]) -> str: + def visit_truncate_stream(self, stream: TruncateStream) -> str: self.methods_reprs.append( f"truncate(count={self.to_string(stream._count)}, when={self.to_string(stream._when)})" ) return stream.upstream.accept(self) - def visit_stream(self, stream: Stream[T]) -> str: + def visit_atruncate_stream(self, stream: ATruncateStream) -> str: + self.methods_reprs.append( + f"atruncate(count={self.to_string(stream._count)}, when={self.to_string(stream._when)})" + ) + return stream.upstream.accept(self) + + def visit_stream(self, stream: Stream) -> str: methods_block = "".join( map(lambda r: f" .{r}\n", reversed(self.methods_reprs)) ) diff --git a/tests/test_iterators.py b/tests/test_iterators.py index b9a74de5..2bb30d37 100644 --- a/tests/test_iterators.py +++ b/tests/test_iterators.py @@ -1,6 +1,15 @@ +import asyncio import unittest +from typing import AsyncIterator -from streamable.iterators import ObserveIterator, _OSConcurrentMapIterable +from streamable.aiterators import ( + _ConcurrentAMapAsyncIterable, + _RaisingAsyncIterator, +) +from streamable.iterators import ObserveIterator, _ConcurrentMapIterable +from streamable.util.asynctools import awaitable_to_coroutine +from streamable.util.iterabletools import sync_to_async_iter +from tests.utils import async_identity, identity, src class TestIterators(unittest.TestCase): @@ -8,9 +17,9 @@ def test_validation(self): with self.assertRaisesRegex( ValueError, "`buffersize` must be >= 1 but got 0", - msg="`_OSConcurrentMapIterable` constructor should raise for non-positive buffersize", + msg="`_ConcurrentMapIterable` constructor should raise for non-positive buffersize", ): - _OSConcurrentMapIterable( + _ConcurrentMapIterable( iterator=iter([]), transformation=str, concurrency=1, @@ -29,3 +38,28 @@ def test_validation(self): what="", base=0, ) + + def test_ConcurrentAMapAsyncIterable(self) -> None: + with 
self.assertRaisesRegex( + TypeError, + r"must be an async function i\.e\. a function returning a Coroutine but it returned a ", + msg="`amap` should raise a TypeError if a non async function is passed to it.", + ): + concurrent_amap_async_iterable: _ConcurrentAMapAsyncIterable[int, int] = ( + _ConcurrentAMapAsyncIterable( + sync_to_async_iter(src), + async_identity, + buffersize=2, + ordered=True, + ) + ) + + # remove error wrapping + concurrent_amap_async_iterable.transformation = identity # type: ignore + + aiterator: AsyncIterator[int] = _RaisingAsyncIterator( + concurrent_amap_async_iterable.__aiter__() + ) + print( + asyncio.run(awaitable_to_coroutine(aiterator.__aiter__().__anext__())) + ) diff --git a/tests/test_readme.py b/tests/test_readme.py index a0f1a89e..b5627d41 100644 --- a/tests/test_readme.py +++ b/tests/test_readme.py @@ -1,7 +1,8 @@ +import asyncio import time import unittest from datetime import timedelta -from typing import Iterator, List, Tuple +from typing import AsyncIterable, Iterator, List, Tuple, TypeVar from streamable.stream import Stream @@ -15,10 +16,12 @@ three_integers_per_second: Stream[int] = integers.throttle(5, per=timedelta(seconds=1)) +T = TypeVar("T") + # fmt: off class TestReadme(unittest.TestCase): - def test_collect_it(self) -> None: + def test_iterate(self) -> None: self.assertListEqual( list(inverses), [1.0, 0.5, 0.33, 0.25, 0.2, 0.17, 0.14, 0.12, 0.11], @@ -34,6 +37,11 @@ def test_collect_it(self) -> None: self.assertEqual(next(inverses_iter), 1.0) self.assertEqual(next(inverses_iter), 0.5) + async def main() -> List[float]: + return [inverse async for inverse in inverses] + + assert asyncio.run(main()) == [1.0, 0.5, 0.33, 0.25, 0.2, 0.17, 0.14, 0.12, 0.11] + def test_map_example(self) -> None: integer_strings: Stream[str] = integers.map(str) @@ -59,7 +67,7 @@ def test_process_concurrent_map_example(self) -> None: # but the `state` of the main process is not mutated assert state == [] - def 
test_async_concurrent_map_example(self) -> None: + def test_amap_example(self) -> None: import asyncio import httpx @@ -75,7 +83,25 @@ def test_async_concurrent_map_example(self) -> None: ) assert list(pokemon_names) == ['bulbasaur', 'ivysaur', 'venusaur'] - asyncio.get_event_loop().run_until_complete(http_async_client.aclose()) + asyncio.run(http_async_client.aclose()) + + def test_async_amap_example(self) -> None: + import asyncio + + import httpx + + async def main() -> None: + async with httpx.AsyncClient() as http_async_client: + pokemon_names: Stream[str] = ( + Stream(range(1, 4)) + .map(lambda i: f"https://pokeapi.co/api/v2/pokemon-species/{i}") + .amap(http_async_client.get, concurrency=3) + .map(httpx.Response.json) + .map(lambda poke: poke["name"]) + ) + assert [name async for name in pokemon_names] == ['bulbasaur', 'ivysaur', 'venusaur'] + + asyncio.run(main()) def test_starmap_example(self) -> None: from streamable import star @@ -267,12 +293,23 @@ def test_zip_example(self) -> None: def test_count_example(self) -> None: assert integers.count() == 10 + def test_acount_example(self) -> None: + assert asyncio.run(integers.acount()) == 10 + def test_call_example(self) -> None: state: List[int] = [] appending_integers: Stream[int] = integers.foreach(state.append) assert appending_integers() is appending_integers assert state == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + def test_await_example(self) -> None: + async def test() -> None: + state: List[int] = [] + appending_integers: Stream[int] = integers.foreach(state.append) + assert appending_integers is await appending_integers + assert state == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + asyncio.run(test()) + def test_non_stopping_exceptions_example(self) -> None: from contextlib import suppress @@ -286,6 +323,59 @@ def test_non_stopping_exceptions_example(self) -> None: collected_casted_ints.extend(casted_ints) assert collected_casted_ints == [0, 1, 2, 3, 5, 6, 7, 8, 9] + def test_async_etl_example(self) -> None: # pragma: no cover
+ # for mypy typing check only + if not self: + import asyncio + import csv + import itertools + from datetime import timedelta + + import httpx + + from streamable import Stream + + async def main() -> None: + with open("./quadruped_pokemons.csv", mode="w") as file: + fields = ["id", "name", "is_legendary", "base_happiness", "capture_rate"] + writer = csv.DictWriter(file, fields, extrasaction='ignore') + writer.writeheader() + + async with httpx.AsyncClient() as http_async_client: + pipeline: Stream = ( + # Infinite Stream[int] of Pokemon ids starting from Pokémon #1: Bulbasaur + Stream(itertools.count(1)) + # Limits to 16 requests per second to be friendly to our fellow PokéAPI devs + .throttle(16, per=timedelta(seconds=1)) + # GETs pokemons via 8 concurrent asyncio coroutines + .map(lambda poke_id: f"https://pokeapi.co/api/v2/pokemon-species/{poke_id}") + .amap(http_async_client.get, concurrency=8) + .foreach(httpx.Response.raise_for_status) + .map(httpx.Response.json) + # Stops the iteration when reaching the 1st pokemon of the 4th generation + .truncate(when=lambda poke: poke["generation"]["name"] == "generation-iv") + .observe("pokemons") + # Keeps only quadruped Pokemons + .filter(lambda poke: poke["shape"]["name"] == "quadruped") + .observe("quadruped pokemons") + # Catches errors due to None "generation" or "shape" + .catch( + TypeError, + when=lambda error: str(error) == "'NoneType' object is not subscriptable" + ) + # Writes a batch of pokemons every 5 seconds to the CSV file + .group(interval=timedelta(seconds=5)) + .foreach(writer.writerows) + .flatten() + .observe("written pokemons") + # Catches exceptions and raises the 1st one at the end of the iteration + .catch(Exception, finally_raise=True) + ) + + await pipeline + + asyncio.run(main()) + def test_etl_example(self) -> None: # pragma: no cover # for mypy typing check only if not self: diff --git a/tests/test_stream.py b/tests/test_stream.py index 8bda6a9e..8dfa02d3 100644 ---
a/tests/test_stream.py +++ b/tests/test_stream.py @@ -3,22 +3,21 @@ import datetime import logging import math -from operator import itemgetter -from pickle import PickleError import queue -import random import sys import threading import time -import timeit import traceback -from types import FrameType import unittest from collections import Counter +from functools import partial +from pickle import PickleError +from types import TracebackType from typing import ( Any, + AsyncIterable, + AsyncIterator, Callable, - Coroutine, Dict, Iterable, Iterator, @@ -27,139 +26,52 @@ Set, Tuple, Type, - TypeVar, + Union, cast, ) from parameterized import parameterized # type: ignore from streamable import Stream -from streamable.util.functiontools import WrappedError, star - -T = TypeVar("T") -R = TypeVar("R") - - -def timestream(stream: Stream[T], times: int = 1) -> Tuple[float, List[T]]: - res: List[T] = [] - - def iterate(): - nonlocal res - res = list(stream) - - return timeit.timeit(iterate, number=times) / times, res - - -def identity_sleep(seconds: float) -> float: - time.sleep(seconds) - return seconds - - -async def async_identity_sleep(seconds: float) -> float: - await asyncio.sleep(seconds) - return seconds - - -# simulates an I/0 bound function -slow_identity_duration = 0.05 - - -def slow_identity(x: T) -> T: - time.sleep(slow_identity_duration) - return x - - -async def async_slow_identity(x: T) -> T: - await asyncio.sleep(slow_identity_duration) - return x - - -def identity(x: T) -> T: - return x - - -# fmt: off -async def async_identity(x: T) -> T: return x -# fmt: on - - -def square(x): - return x**2 - - -async def async_square(x): - return x**2 - - -def throw(exc: Type[Exception]): - raise exc() - - -def throw_func(exc: Type[Exception]) -> Callable[[T], T]: - return lambda _: throw(exc) - - -def async_throw_func(exc: Type[Exception]) -> Callable[[T], Coroutine[Any, Any, T]]: - async def f(_: T) -> T: - raise exc - - return f - - -def throw_for_odd_func(exc): - 
return lambda i: throw(exc) if i % 2 == 1 else i - - -def async_throw_for_odd_func(exc): - async def f(i): - return throw(exc) if i % 2 == 1 else i - - return f - - -class TestError(Exception): - pass - - -DELTA_RATE = 0.4 -# size of the test collections -N = 256 - -src = range(N) - -even_src = range(0, N, 2) - - -def randomly_slowed( - func: Callable[[T], R], min_sleep: float = 0.001, max_sleep: float = 0.05 -) -> Callable[[T], R]: - def wrap(x: T) -> R: - time.sleep(min_sleep + random.random() * (max_sleep - min_sleep)) - return func(x) - - return wrap - - -def async_randomly_slowed( - async_func: Callable[[T], Coroutine[Any, Any, R]], - min_sleep: float = 0.001, - max_sleep: float = 0.05, -) -> Callable[[T], Coroutine[Any, Any, R]]: - async def wrap(x: T) -> R: - await asyncio.sleep(min_sleep + random.random() * (max_sleep - min_sleep)) - return await async_func(x) - - return wrap - - -def range_raising_at_exhaustion( - start: int, end: int, step: int, exception: Exception -) -> Iterator[int]: - yield from range(start, end, step) - raise exception - - -src_raising_at_exhaustion = lambda: range_raising_at_exhaustion(0, N, 1, TestError()) +from streamable.util.asynctools import awaitable_to_coroutine +from streamable.util.functiontools import WrappedError, asyncify, star +from streamable.util.iterabletools import ( + sync_to_async_iter, + sync_to_bi_iterable, +) +from tests.utils import ( + DELTA_RATE, + ITERABLE_TYPES, + IterableType, + N, + TestError, + alist_or_list, + anext_or_next, + async_identity, + async_identity_sleep, + async_randomly_slowed, + async_slow_identity, + async_square, + async_throw_for_odd_func, + async_throw_func, + bi_iterable_to_iter, + even_src, + identity, + identity_sleep, + stopiteration_for_iter_type, + randomly_slowed, + slow_identity, + slow_identity_duration, + square, + src, + src_raising_at_exhaustion, + throw, + throw_for_odd_func, + throw_func, + timestream, + to_list, + to_set, +) class TestStream(unittest.TestCase): @@ 
-204,6 +116,19 @@ def test_init(self) -> None: ): Stream(src).upstream = Stream(src) # type: ignore + @parameterized.expand(ITERABLE_TYPES) + def test_async_src(self, itype) -> None: + self.assertEqual( + to_list(Stream(sync_to_async_iter(src)), itype), + list(src), + msg="a stream with an async source must be collectable as an Iterable or as AsyncIterable", + ) + self.assertEqual( + to_list(Stream(sync_to_async_iter(src).__aiter__), itype), + list(src), + msg="a stream with an async source must be collectable as an Iterable or as AsyncIterable", + ) + def test_repr_and_display(self) -> None: class CustomCallable: pass @@ -211,31 +136,44 @@ class CustomCallable: complex_stream: Stream[int] = ( Stream(src) .truncate(1024, when=lambda _: False) + .atruncate(1024, when=async_identity) .skip(10) + .askip(10) .skip(until=lambda _: True) + .askip(until=async_identity) .distinct(lambda _: _) + .adistinct(async_identity) .filter() .map(lambda i: (i,)) .map(lambda i: (i,), concurrency=2) .filter(star(bool)) + .afilter(star(async_identity)) .foreach(lambda _: _) .foreach(lambda _: _, concurrency=2) .aforeach(async_identity) .map(cast(Callable[[Any], Any], CustomCallable())) .amap(async_identity) .group(100) + .agroup(100) .groupby(len) + .agroupby(async_identity) .map(star(lambda key, group: group)) .observe("groups") .flatten(concurrency=4) + .map(sync_to_async_iter) + .aflatten(concurrency=4) + .map(lambda _: 0) .throttle( 64, per=datetime.timedelta(seconds=1), ) .observe("foos") - .catch(finally_raise=True) + .catch(finally_raise=True, when=identity) + .acatch(finally_raise=True, when=async_identity) .catch((TypeError, ValueError, None, ZeroDivisionError)) - .catch(TypeError, replacement=None, finally_raise=True) + .acatch((TypeError, ValueError, None, ZeroDivisionError)) + .catch(TypeError, replacement=1, finally_raise=True) + .acatch(TypeError, replacement=1, finally_raise=True) ) print(repr(complex_stream)) @@ -274,54 +212,69 @@ class CustomCallable: """( 
Stream(range(0, 256)) .truncate(count=1024, when=) + .atruncate(count=1024, when=async_identity) .skip(10, until=None) + .askip(10, until=None) .skip(None, until=) + .askip(None, until=async_identity) .distinct(, consecutive_only=False) + .adistinct(async_identity, consecutive_only=False) .filter(bool) .map(, concurrency=1, ordered=True) .map(, concurrency=2, ordered=True, via='thread') .filter(star(bool)) + .afilter(star(async_identity)) .foreach(, concurrency=1, ordered=True) .foreach(, concurrency=2, ordered=True, via='thread') .aforeach(async_identity, concurrency=1, ordered=True) .map(CustomCallable(...), concurrency=1, ordered=True) .amap(async_identity, concurrency=1, ordered=True) .group(size=100, by=None, interval=None) + .agroup(size=100, by=None, interval=None) .groupby(len, size=None, interval=None) + .agroupby(async_identity, size=None, interval=None) .map(star(), concurrency=1, ordered=True) .observe('groups') .flatten(concurrency=4) + .map(SyncToAsyncIterator, concurrency=1, ordered=True) + .aflatten(concurrency=4) + .map(, concurrency=1, ordered=True) .throttle(64, per=datetime.timedelta(seconds=1)) .observe('foos') - .catch(Exception, when=None, finally_raise=True) + .catch(Exception, when=identity, finally_raise=True) + .acatch(Exception, when=async_identity, finally_raise=True) .catch((TypeError, ValueError, None, ZeroDivisionError), when=None, finally_raise=False) - .catch(TypeError, when=None, replacement=None, finally_raise=True) + .acatch((TypeError, ValueError, None, ZeroDivisionError), when=None, finally_raise=False) + .catch(TypeError, when=None, replacement=1, finally_raise=True) + .acatch(TypeError, when=None, replacement=1, finally_raise=True) )""", msg="`repr` should work as expected on a stream with many operation", ) - def test_iter(self) -> None: + @parameterized.expand(ITERABLE_TYPES) + def test_iter(self, itype: IterableType) -> None: self.assertIsInstance( - iter(Stream(src)), - Iterator, + bi_iterable_to_iter(Stream(src), 
itype=itype), + itype, msg="iter(stream) must return an Iterator.", ) with self.assertRaisesRegex( TypeError, - r"`source` must be an Iterable or a Callable\[\[\], Iterable\] but got a ", + r"`source` must be an Iterable/AsyncIterable or a Callable\[\[\], Iterable/AsyncIterable\] but got a ", msg="Getting an Iterator from a Stream with a source not being a Union[Callable[[], Iterator], ITerable] must raise TypeError.", ): - iter(Stream(1)) # type: ignore + bi_iterable_to_iter(Stream(1), itype=itype) # type: ignore with self.assertRaisesRegex( TypeError, - r"`source` must be an Iterable or a Callable\[\[\], Iterable\] but got a Callable\[\[\], \]", + r"`source` must be an Iterable/AsyncIterable or a Callable\[\[\], Iterable/AsyncIterable\] but got a Callable\[\[\], \]", msg="Getting an Iterator from a Stream with a source not being a Union[Callable[[], Iterator], ITerable] must raise TypeError.", ): - iter(Stream(lambda: 1)) # type: ignore + bi_iterable_to_iter(Stream(lambda: 1), itype=itype) # type: ignore - def test_add(self) -> None: + @parameterized.expand(ITERABLE_TYPES) + def test_add(self, itype: IterableType) -> None: from streamable.stream import FlattenStream stream = Stream(src) @@ -335,7 +288,7 @@ def test_add(self) -> None: stream_b = Stream(range(10, 20)) stream_c = Stream(range(20, 30)) self.assertListEqual( - list(stream_a + stream_b + stream_c), + to_list(stream_a + stream_b + stream_c, itype=itype), list(range(30)), msg="`chain` must yield the elements of the first stream the move on with the elements of the next ones and so on.", ) @@ -345,7 +298,9 @@ def test_add(self) -> None: [Stream.map, [identity]], [Stream.amap, [async_identity]], [Stream.foreach, [identity]], + [Stream.aforeach, [identity]], [Stream.flatten, []], + [Stream.aflatten, []], ] ) def test_sanitize_concurrency(self, method, args) -> None: @@ -384,26 +339,30 @@ def test_sanitize_via(self, method) -> None: method(Stream(src), identity, via="foo") @parameterized.expand( - [ - [1], - 
[2], - ] + [(concurrency, itype) for concurrency in (1, 2) for itype in ITERABLE_TYPES] ) - def test_map(self, concurrency) -> None: + def test_map(self, concurrency, itype) -> None: self.assertListEqual( - list(Stream(src).map(randomly_slowed(square), concurrency=concurrency)), + to_list( + Stream(src).map(randomly_slowed(square), concurrency=concurrency), + itype=itype, + ), list(map(square, src)), msg="At any concurrency the `map` method should act as the builtin map function, transforming elements while preserving input elements order.", ) @parameterized.expand( [ - [True, identity], - [False, sorted], + (ordered, order_mutation, itype) + for itype in ITERABLE_TYPES + for ordered, order_mutation in [ + (True, identity), + (False, sorted), + ] ] ) def test_process_concurrency( - self, ordered, order_mutation + self, ordered, order_mutation, itype ) -> None: # pragma: no cover # 3.7 and 3.8 are passing the test but hang forever after if sys.version_info.minor < 9: @@ -420,7 +379,7 @@ def local_identity(x): "", msg="process-based concurrency should not be able to serialize a lambda or a local func", ): - list(Stream(src).map(f, concurrency=2, via="process")) + to_list(Stream(src).map(f, concurrency=2, via="process"), itype=itype) sleeps = [0.01, 1, 0.01] state: List[str] = [] @@ -433,7 +392,7 @@ def local_identity(x): .foreach(lambda _: state.append(""), concurrency=1, ordered=True) ) self.assertListEqual( - list(stream), + to_list(stream, itype=itype), expected_result_list, msg="process-based concurrency must correctly transform elements, respecting `ordered`...", ) @@ -444,38 +403,38 @@ def local_identity(x): ) # test partial iteration: self.assertEqual( - next(iter(stream)), + anext_or_next(bi_iterable_to_iter(stream, itype=itype)), expected_result_list[0], msg="process-based concurrency must behave ok with partial iteration", ) @parameterized.expand( [ - [16, 0], - [1, 0], - [16, 1], - [16, 15], - [16, 16], + (concurrency, n_elems, itype) + for concurrency, 
n_elems in [ + [16, 0], + [1, 0], + [16, 1], + [16, 15], + [16, 16], + ] + for itype in ITERABLE_TYPES ] ) def test_map_with_more_concurrency_than_elements( - self, concurrency, n_elems + self, concurrency, n_elems, itype ) -> None: self.assertListEqual( - list(Stream(range(n_elems)).map(str, concurrency=concurrency)), + to_list( + Stream(range(n_elems)).map(str, concurrency=concurrency), itype=itype + ), list(map(str, range(n_elems))), msg="`map` method should act correctly when concurrency > number of elements.", ) @parameterized.expand( [ - [ - ordered, - order_mutation, - expected_duration, - operation, - func, - ] + [ordered, order_mutation, expected_duration, operation, func, itype] for ordered, order_mutation, expected_duration in [ (True, identity, 0.3), (False, sorted, 0.21), @@ -486,6 +445,7 @@ def test_map_with_more_concurrency_than_elements( (Stream.aforeach, asyncio.sleep), (Stream.amap, async_identity_sleep), ] + for itype in ITERABLE_TYPES ] ) def test_mapping_ordering( @@ -495,11 +455,13 @@ def test_mapping_ordering( expected_duration: float, operation, func, + itype, ) -> None: seconds = [0.1, 0.01, 0.2] duration, res = timestream( operation(Stream(seconds), func, ordered=ordered, concurrency=2), 5, + itype=itype, ) self.assertListEqual( res, @@ -515,23 +477,21 @@ def test_mapping_ordering( ) @parameterized.expand( - [ - [1], - [2], - ] + [(concurrency, itype) for concurrency in (1, 2) for itype in ITERABLE_TYPES] ) - def test_foreach(self, concurrency) -> None: + def test_foreach(self, concurrency, itype) -> None: side_collection: Set[int] = set() def side_effect(x: int, func: Callable[[int], int]): nonlocal side_collection side_collection.add(func(x)) - res = list( + res = to_list( Stream(src).foreach( lambda i: randomly_slowed(side_effect(i, square)), concurrency=concurrency, - ) + ), + itype=itype, ) self.assertListEqual( @@ -554,6 +514,7 @@ def side_effect(x: int, func: Callable[[int], int]): method, throw_func_, throw_for_odd_func_, + itype, 
] for raised_exc, caught_exc in [ (TestError, (TestError,)), @@ -562,9 +523,11 @@ def side_effect(x: int, func: Callable[[int], int]): for concurrency in [1, 2] for method, throw_func_, throw_for_odd_func_ in [ (Stream.foreach, throw_func, throw_for_odd_func), + (Stream.aforeach, async_throw_func, async_throw_for_odd_func), (Stream.map, throw_func, throw_for_odd_func), (Stream.amap, async_throw_func, async_throw_for_odd_func), ] + for itype in ITERABLE_TYPES ] ) def test_map_or_foreach_with_exception( @@ -575,20 +538,25 @@ def test_map_or_foreach_with_exception( method: Callable[[Stream, Callable[[Any], int], int], Stream], throw_func: Callable[[Exception], Callable[[Any], int]], throw_for_odd_func: Callable[[Type[Exception]], Callable[[Any], int]], + itype: IterableType, ) -> None: with self.assertRaises( caught_exc, msg="At any concurrency, `map` and `foreach` and `amap` must raise", ): - list(method(Stream(src), throw_func(raised_exc), concurrency=concurrency)) # type: ignore + to_list( + method(Stream(src), throw_func(raised_exc), concurrency=concurrency), # type: ignore + itype=itype, + ) self.assertListEqual( - list( + to_list( method( Stream(src), throw_for_odd_func(raised_exc), concurrency=concurrency, # type: ignore - ).catch(caught_exc) + ).catch(caught_exc), + itype=itype, ), list(even_src), msg="At any concurrency, `map` and `foreach` and `amap` must not stop after one exception occured.", @@ -596,18 +564,24 @@ def test_map_or_foreach_with_exception( @parameterized.expand( [ - [method, func, concurrency] + [method, func, concurrency, itype] for method, func in [ (Stream.foreach, slow_identity), + (Stream.aforeach, async_slow_identity), (Stream.map, slow_identity), (Stream.amap, async_slow_identity), ] for concurrency in [1, 2, 4] + for itype in ITERABLE_TYPES ] ) - def test_map_and_foreach_concurrency(self, method, func, concurrency) -> None: + def test_map_and_foreach_concurrency( + self, method, func, concurrency, itype + ) -> None: 
expected_iteration_duration = N * slow_identity_duration / concurrency - duration, res = timestream(method(Stream(src), func, concurrency=concurrency)) + duration, res = timestream( + method(Stream(src), func, concurrency=concurrency), itype=itype + ) self.assertListEqual(res, list(src)) self.assertAlmostEqual( duration, @@ -618,50 +592,80 @@ def test_map_and_foreach_concurrency(self, method, func, concurrency) -> None: @parameterized.expand( [ - [1], - [2], + (concurrency, itype, flatten) + for concurrency in (1, 2) + for itype in ITERABLE_TYPES + for flatten in (Stream.flatten, Stream.aflatten) ] ) - def test_flatten(self, concurrency) -> None: + def test_flatten(self, concurrency, itype, flatten) -> None: n_iterables = 32 it = list(range(N // n_iterables)) double_it = it + it iterables_stream = Stream( - lambda: map(slow_identity, [double_it] + [it for _ in range(n_iterables)]) - ) + [sync_to_bi_iterable(double_it)] + + [sync_to_bi_iterable(it) for _ in range(n_iterables)] + ).map(slow_identity) self.assertCountEqual( - list(iterables_stream.flatten(concurrency=concurrency)), + to_list(flatten(iterables_stream, concurrency=concurrency), itype=itype), list(it) * n_iterables + double_it, msg="At any concurrency the `flatten` method should yield all the upstream iterables' elements.", ) self.assertListEqual( - list( - Stream([iter([]) for _ in range(2000)]).flatten(concurrency=concurrency) + to_list( + flatten( + Stream([sync_to_bi_iterable(iter([])) for _ in range(2000)]), + concurrency=concurrency, + ), + itype=itype, ), [], msg="`flatten` should not yield any element if upstream elements are empty iterables, and be resilient to recursion issue in case of successive empty upstream iterables.", ) with self.assertRaises( - TypeError, + (TypeError, AttributeError), msg="`flatten` should raise if an upstream element is not iterable.", ): - next(iter(Stream(cast(Iterable, src)).flatten())) + anext_or_next( + bi_iterable_to_iter( + flatten(Stream(cast(Union[Iterable, 
AsyncIterable], src))), + itype=itype, + ) + ) # test typing with ranges _: Stream[int] = Stream((src, src)).flatten() - def test_flatten_concurrency(self) -> None: + @parameterized.expand( + [ + (flatten, itype, slow) + for flatten, slow in ( + (Stream.flatten, partial(Stream.map, transformation=slow_identity)), + ( + Stream.aflatten, + partial(Stream.amap, transformation=async_slow_identity), + ), + ) + for itype in ITERABLE_TYPES + ] + ) + def test_flatten_concurrency(self, flatten, itype, slow) -> None: + concurrency = 2 iterable_size = 5 runtime, res = timestream( - Stream( - lambda: [ - Stream(map(slow_identity, ["a"] * iterable_size)), - Stream(map(slow_identity, ["b"] * iterable_size)), - Stream(map(slow_identity, ["c"] * iterable_size)), - ] - ).flatten(concurrency=2), + flatten( + Stream( + lambda: [ + slow(Stream(["a"] * iterable_size)), + slow(Stream(["b"] * iterable_size)), + slow(Stream(["c"] * iterable_size)), + ] + ), + concurrency=concurrency, + ), times=3, + itype=itype, ) self.assertListEqual( res, @@ -669,7 +673,7 @@ def test_flatten_concurrency(self) -> None: msg="`flatten` should process 'a's and 'b's concurrently and then 'c's", ) a_runtime = b_runtime = c_runtime = iterable_size * slow_identity_duration - expected_runtime = (a_runtime + b_runtime) / 2 + c_runtime + expected_runtime = (a_runtime + b_runtime) / concurrency + c_runtime self.assertAlmostEqual( runtime, expected_runtime, @@ -688,21 +692,29 @@ def test_flatten_typing(self) -> None: Stream("abc").map(lambda char: filter(lambda _: True, char)).flatten() ) + flattened_asynciter_stream: Stream[str] = ( + Stream("abc").map(sync_to_async_iter).aflatten() + ) + @parameterized.expand( [ - [exception_type, mapped_exception_type, concurrency] + [exception_type, mapped_exception_type, concurrency, itype, flatten] + for concurrency in [1, 2] + for itype in ITERABLE_TYPES for exception_type, mapped_exception_type in [ (TestError, TestError), - (StopIteration, WrappedError), + 
(stopiteration_for_iter_type(itype), (WrappedError, RuntimeError)), ] - for concurrency in [1, 2] + for flatten in (Stream.flatten, Stream.aflatten) ] ) - def test_flatten_with_exception( + def test_flatten_with_exception_in_iter( self, exception_type: Type[Exception], mapped_exception_type: Type[Exception], concurrency: int, + itype: IterableType, + flatten: Callable, ) -> None: n_iterables = 5 @@ -710,60 +722,79 @@ class IterableRaisingInIter(Iterable[int]): def __iter__(self) -> Iterator[int]: raise exception_type - self.assertSetEqual( - set( + res: Set[int] = to_set( + flatten( Stream( map( - lambda i: ( - IterableRaisingInIter() - if i % 2 - else cast(Iterable[int], range(i, i + 1)) + lambda i: sync_to_bi_iterable( + IterableRaisingInIter() if i % 2 else range(i, i + 1) ), range(n_iterables), ) - ) - .flatten(concurrency=concurrency) - .catch(mapped_exception_type) - ), + ), + concurrency=concurrency, + ).catch(mapped_exception_type), + itype=itype, + ) + self.assertSetEqual( + res, set(range(0, n_iterables, 2)), - msg="At any concurrency the `flatten` method should be resilient to exceptions thrown by iterators, especially it should remap StopIteration one to PacifiedStopIteration.", + msg="At any concurrency the `flatten` method should be resilient to exceptions thrown by iterators, especially it should wrap Stop(Async)Iteration.", ) + @parameterized.expand( + [ + [concurrency, itype, flatten] + for concurrency in [1, 2] + for itype in ITERABLE_TYPES + for flatten in (Stream.flatten, Stream.aflatten) + ] + ) + def test_flatten_with_exception_in_next( + self, + concurrency: int, + itype: IterableType, + flatten: Callable, + ) -> None: + n_iterables = 5 + class IteratorRaisingInNext(Iterator[int]): def __init__(self) -> None: self.first_next = True - def __iter__(self) -> Iterator[int]: - return self - def __next__(self) -> int: if not self.first_next: - raise StopIteration + raise StopIteration() self.first_next = False - raise exception_type + raise 
TestError - self.assertSetEqual( - set( + res = to_set( + flatten( Stream( map( lambda i: ( - IteratorRaisingInNext() - if i % 2 - else cast(Iterable[int], range(i, i + 1)) + sync_to_bi_iterable( + IteratorRaisingInNext() if i % 2 else range(i, i + 1) + ) ), range(n_iterables), ) - ) - .flatten(concurrency=concurrency) - .catch(mapped_exception_type) - ), + ), + concurrency=concurrency, + ).catch(TestError), + itype=itype, + ) + self.assertSetEqual( + res, set(range(0, n_iterables, 2)), - msg="At any concurrency the `flatten` method should be resilient to exceptions thrown by iterators, especially it should remap StopIteration one to PacifiedStopIteration.", + msg="At any concurrency the `flatten` method should be resilient to exceptions thrown by iterators, especially it should wrap Stop(Async)Iteration.", ) - @parameterized.expand([[concurrency] for concurrency in [2, 4]]) + @parameterized.expand( + [(concurrency, itype) for concurrency in [2, 4] for itype in ITERABLE_TYPES] + ) def test_partial_iteration_on_streams_using_concurrency( - self, concurrency: int + self, concurrency: int, itype: IterableType ) -> None: yielded_elems = [] @@ -798,14 +829,14 @@ def remembering_src() -> Iterator[int]: ), ]: yielded_elems = [] - iterator = iter(stream) + iterator = bi_iterable_to_iter(stream, itype=itype) time.sleep(0.5) self.assertEqual( len(yielded_elems), 0, msg=f"before the first call to `next` a concurrent {type(stream)} should have pulled 0 upstream elements.", ) - next(iterator) + anext_or_next(iterator) time.sleep(0.5) self.assertEqual( len(yielded_elems), @@ -813,40 +844,82 @@ def remembering_src() -> Iterator[int]: msg=f"`after the first call to `next` a concurrent {type(stream)} with concurrency={concurrency} should have pulled only {n_pulls_after_first_next} upstream elements.", ) - def test_filter(self) -> None: - def keep(x) -> Any: + @parameterized.expand(ITERABLE_TYPES) + def test_filter(self, itype: IterableType) -> None: + def keep(x) -> int: return x % 
2 self.assertListEqual( - list(Stream(src).filter(keep)), + to_list(Stream(src).filter(keep), itype=itype), list(filter(keep, src)), msg="`filter` must act like builtin filter", ) self.assertListEqual( - list(Stream(src).filter()), + to_list(Stream(src).filter(bool), itype=itype), list(filter(None, src)), msg="`filter` with `bool` as predicate must act like builtin filter with None predicate.", ) self.assertListEqual( - list(Stream(src).filter()), + to_list(Stream(src).filter(), itype=itype), list(filter(None, src)), msg="`filter` without predicate must act like builtin filter with None predicate.", ) self.assertListEqual( - list(Stream(src).filter(None)), # type: ignore + to_list(Stream(src).filter(None), itype=itype), # type: ignore list(filter(None, src)), msg="`filter` with None predicate must act unofficially like builtin filter with None predicate.", ) - # Unofficially accept `stream.filter(None)`, behaving as builtin `filter(None, iter)` + self.assertEqual( + to_list(Stream(src).filter(None), itype=itype), # type: ignore + list(filter(None, src)), + msg="Unofficially accept `stream.filter(None)`, behaving as builtin `filter(None, iter)`", + ) # with self.assertRaisesRegex( # TypeError, # "`when` cannot be None", # msg="`filter` does not accept a None predicate", # ): - # list(Stream(src).filter(None)) # type: ignore + # to_list(Stream(src).filter(None), itype=itype) # type: ignore + + @parameterized.expand(ITERABLE_TYPES) + def test_afilter(self, itype: IterableType) -> None: + def keep(x) -> int: + return x % 2 + + async def async_keep(x) -> int: + return keep(x) + + self.assertListEqual( + to_list(Stream(src).afilter(async_keep), itype=itype), + list(filter(keep, src)), + msg="`afilter` must act like builtin filter", + ) + self.assertListEqual( + to_list(Stream(src).afilter(asyncify(bool)), itype=itype), + list(filter(None, src)), + msg="`afilter` with `bool` as predicate must act like builtin filter with None predicate.", + ) + self.assertListEqual( + 
to_list(Stream(src).afilter(None), itype=itype), # type: ignore + list(filter(None, src)), + msg="`afilter` with None predicate must act unofficially like builtin filter with None predicate.", + ) + + self.assertEqual( + to_list(Stream(src).afilter(None), itype=itype), # type: ignore + list(filter(None, src)), + msg="Unofficially accept `stream.afilter(None)`, behaving as builtin `filter(None, iter)`", + ) + # with self.assertRaisesRegex( + # TypeError, + # "`when` cannot be None", + # msg="`afilter` does not accept a None predicate", + # ): + # to_list(Stream(src).afilter(None), itype=itype) # type: ignore - def test_skip(self) -> None: + @parameterized.expand(ITERABLE_TYPES) + def test_skip(self, itype: IterableType) -> None: with self.assertRaisesRegex( ValueError, "`count` must be >= 0 but got -1", @@ -855,89 +928,164 @@ def test_skip(self) -> None: Stream(src).skip(-1) self.assertListEqual( - list(Stream(src).skip()), + to_list(Stream(src).skip(), itype=itype), list(src), msg="`skip` must be no-op if both `count` and `until` are None", ) self.assertListEqual( - list(Stream(src).skip(None)), + to_list(Stream(src).skip(None), itype=itype), list(src), msg="`skip` must be no-op if both `count` and `until` are None", ) for count in [0, 1, 3]: self.assertListEqual( - list(Stream(src).skip(count)), + to_list(Stream(src).skip(count), itype=itype), list(src)[count:], msg="`skip` must skip `count` elements", ) self.assertListEqual( - list( + to_list( Stream(map(throw_for_odd_func(TestError), src)) .skip(count) - .catch(TestError) + .catch(TestError), + itype=itype, ), list(filter(lambda i: i % 2 == 0, src))[count:], msg="`skip` must not count exceptions as skipped elements", ) self.assertListEqual( - list(Stream(src).skip(until=lambda n: n >= count)), + to_list(Stream(src).skip(until=lambda n: n >= count), itype=itype), list(src)[count:], msg="`skip` must yield starting from the first element satisfying `until`", ) self.assertListEqual( - list(Stream(src).skip(count, 
until=lambda n: False)), + to_list(Stream(src).skip(count, until=lambda n: False), itype=itype), list(src)[count:], msg="`skip` must ignore `count` elements if `until` is never satisfied", ) self.assertListEqual( - list(Stream(src).skip(count * 2, until=lambda n: n >= count)), + to_list( + Stream(src).skip(count * 2, until=lambda n: n >= count), itype=itype + ), list(src)[count:], msg="`skip` must ignore less than `count` elements if `until` is satisfied first", ) self.assertListEqual( - list(Stream(src).skip(until=lambda n: False)), + to_list(Stream(src).skip(until=lambda n: False), itype=itype), [], msg="`skip` must not yield any element if `until` is never satisfied", ) - def test_truncate(self) -> None: + @parameterized.expand(ITERABLE_TYPES) + def test_askip(self, itype: IterableType) -> None: + with self.assertRaisesRegex( + ValueError, + "`count` must be >= 0 but got -1", + msg="`askip` must raise ValueError if `count` is negative", + ): + Stream(src).askip(-1) + + self.assertListEqual( + to_list(Stream(src).askip(), itype=itype), + list(src), + msg="`askip` must be no-op if both `count` and `until` are None", + ) + + self.assertListEqual( + to_list(Stream(src).askip(None), itype=itype), + list(src), + msg="`askip` must be no-op if both `count` and `until` are None", + ) + + for count in [0, 1, 3]: + self.assertListEqual( + to_list(Stream(src).askip(count), itype=itype), + list(src)[count:], + msg="`askip` must skip `count` elements", + ) + + self.assertListEqual( + to_list( + Stream(map(throw_for_odd_func(TestError), src)) + .askip(count) + .catch(TestError), + itype=itype, + ), + list(filter(lambda i: i % 2 == 0, src))[count:], + msg="`askip` must not count exceptions as skipped elements", + ) + + self.assertListEqual( + to_list( + Stream(src).askip(until=asyncify(lambda n: n >= count)), itype=itype + ), + list(src)[count:], + msg="`askip` must yield starting from the first element satisfying `until`", + ) + + self.assertListEqual( + to_list( + 
Stream(src).askip(count, until=asyncify(lambda n: False)), + itype=itype, + ), + list(src)[count:], + msg="`askip` must ignore `count` elements if `until` is never satisfied", + ) + + self.assertListEqual( + to_list( + Stream(src).askip(count * 2, until=asyncify(lambda n: n >= count)), + itype=itype, + ), + list(src)[count:], + msg="`askip` must ignore less than `count` elements if `until` is satisfied first", + ) + + self.assertListEqual( + to_list(Stream(src).askip(until=asyncify(lambda n: False)), itype=itype), + [], + msg="`askip` must not yield any element if `until` is never satisfied", + ) + + @parameterized.expand(ITERABLE_TYPES) + def test_truncate(self, itype: IterableType) -> None: self.assertListEqual( - list(Stream(src).truncate(N * 2)), + to_list(Stream(src).truncate(N * 2), itype=itype), list(src), msg="`truncate` must be ok with count >= stream length", ) self.assertListEqual( - list(Stream(src).truncate()), + to_list(Stream(src).truncate(), itype=itype), list(src), msg="`truncate must be no-op if both `count` and `when` are None", ) self.assertListEqual( - list(Stream(src).truncate(None)), + to_list(Stream(src).truncate(None), itype=itype), list(src), msg="`truncate must be no-op if both `count` and `when` are None", ) self.assertListEqual( - list(Stream(src).truncate(2)), + to_list(Stream(src).truncate(2), itype=itype), [0, 1], msg="`truncate` must be ok with count >= 1", ) self.assertListEqual( - list(Stream(src).truncate(1)), + to_list(Stream(src).truncate(1), itype=itype), [0], msg="`truncate` must be ok with count == 1", ) self.assertListEqual( - list(Stream(src).truncate(0)), + to_list(Stream(src).truncate(0), itype=itype), [], msg="`truncate` must be ok with count == 0", ) @@ -950,197 +1098,333 @@ def test_truncate(self) -> None: Stream(src).truncate(-1) self.assertListEqual( - list(Stream(src).truncate(cast(int, float("inf")))), + to_list(Stream(src).truncate(cast(int, float("inf"))), itype=itype), list(src), msg="`truncate` must be no-op 
if `count` is inf", ) count = N // 2 - raising_stream_iterator = iter( - Stream(lambda: map(lambda x: round((1 / x) * x**2), src)).truncate(count) + raising_stream_iterator = bi_iterable_to_iter( + Stream(lambda: map(lambda x: round((1 / x) * x**2), src)).truncate(count), + itype=itype, ) with self.assertRaises( ZeroDivisionError, msg="`truncate` must not stop iteration when encountering exceptions and raise them without counting them...", ): - next(raising_stream_iterator) + anext_or_next(raising_stream_iterator) - self.assertListEqual(list(raising_stream_iterator), list(range(1, count + 1))) + self.assertListEqual( + alist_or_list(raising_stream_iterator), list(range(1, count + 1)) + ) with self.assertRaises( - StopIteration, + stopiteration_for_iter_type(type(raising_stream_iterator)), msg="... and after reaching the limit it still continues to raise StopIteration on calls to next", ): - next(raising_stream_iterator) + anext_or_next(raising_stream_iterator) - iter_truncated_on_predicate = iter(Stream(src).truncate(when=lambda n: n == 5)) + iter_truncated_on_predicate = bi_iterable_to_iter( + Stream(src).truncate(when=lambda n: n == 5), itype=itype + ) self.assertListEqual( - list(iter_truncated_on_predicate), - list(Stream(src).truncate(5)), + alist_or_list(iter_truncated_on_predicate), + to_list(Stream(src).truncate(5), itype=itype), msg="`when` n == 5 must be equivalent to `count` = 5", ) with self.assertRaises( - StopIteration, + stopiteration_for_iter_type(type(iter_truncated_on_predicate)), msg="After exhaustion a call to __next__ on a truncated iterator must raise StopIteration", ): - next(iter_truncated_on_predicate) + anext_or_next(iter_truncated_on_predicate) with self.assertRaises( ZeroDivisionError, msg="an exception raised by `when` must be raised", ): - list(Stream(src).truncate(when=lambda _: 1 / 0)) + to_list(Stream(src).truncate(when=lambda _: 1 / 0), itype=itype) self.assertListEqual( - list(Stream(src).truncate(6, when=lambda n: n == 5)), + 
to_list(Stream(src).truncate(6, when=lambda n: n == 5), itype=itype), list(range(5)), msg="`when` and `count` argument can be set at the same time, and the truncation should happen as soon as one or the other is satisfied.", ) self.assertListEqual( - list(Stream(src).truncate(5, when=lambda n: n == 6)), + to_list(Stream(src).truncate(5, when=lambda n: n == 6), itype=itype), list(range(5)), msg="`when` and `count` argument can be set at the same time, and the truncation should happen as soon as one or the other is satisfied.", ) - def test_group(self) -> None: - # behavior with invalid arguments - for seconds in [-1, 0]: - with self.assertRaises( - ValueError, - msg="`group` should raise error when called with `seconds` <= 0.", - ): - list( - Stream([1]).group( - size=100, interval=datetime.timedelta(seconds=seconds) - ) - ) + @parameterized.expand(ITERABLE_TYPES) + def test_atruncate(self, itype: IterableType) -> None: + self.assertListEqual( + to_list(Stream(src).atruncate(N * 2), itype=itype), + list(src), + msg="`atruncate` must be ok with count >= stream length", + ) - for size in [-1, 0]: - with self.assertRaises( - ValueError, - msg="`group` should raise error when called with `size` < 1.", - ): - list(Stream([1]).group(size=size)) + self.assertListEqual( + to_list(Stream(src).atruncate(), itype=itype), + list(src), + msg="`atruncate` must be no-op if both `count` and `when` are None", + ) - # group size self.assertListEqual( - list(Stream(range(6)).group(size=4)), - [[0, 1, 2, 3], [4, 5]], - msg="", + to_list(Stream(src).atruncate(None), itype=itype), + list(src), + msg="`atruncate` must be no-op if both `count` and `when` are None", ) + self.assertListEqual( - list(Stream(range(6)).group(size=2)), - [[0, 1], [2, 3], [4, 5]], - msg="", + to_list(Stream(src).atruncate(2), itype=itype), + [0, 1], + msg="`atruncate` must be ok with count >= 1", + ) + self.assertListEqual( + to_list(Stream(src).atruncate(1), itype=itype), + [0], + msg="`atruncate` must be ok 
with count == 1", ) self.assertListEqual( - list(Stream([]).group(size=2)), + to_list(Stream(src).atruncate(0), itype=itype), [], - msg="", + msg="`atruncate` must be ok with count == 0", ) - # behavior with exceptions - def f(i): - return i / (110 - i) + with self.assertRaisesRegex( + ValueError, + "`count` must be >= 0 but got -1", + msg="`atruncate` must raise ValueError if `count` is negative", + ): + Stream(src).atruncate(-1) - stream_iterator = iter(Stream(lambda: map(f, src)).group(100)) - next(stream_iterator) self.assertListEqual( - next(stream_iterator), - list(map(f, range(100, 110))), - msg="when encountering upstream exception, `group` should yield the current accumulated group...", + to_list(Stream(src).atruncate(cast(int, float("inf"))), itype=itype), + list(src), + msg="`atruncate` must be no-op if `count` is inf", + ) + + count = N // 2 + raising_stream_iterator = bi_iterable_to_iter( + Stream(lambda: map(lambda x: round((1 / x) * x**2), src)).atruncate(count), + itype=itype, ) with self.assertRaises( ZeroDivisionError, - msg="... and raise the upstream exception during the next call to `next`...", + msg="`atruncate` must not stop iteration when encountering exceptions and raise them without counting them...", ): - next(stream_iterator) + anext_or_next(raising_stream_iterator) self.assertListEqual( - next(stream_iterator), - list(map(f, range(111, 211))), - msg="... and restarting a fresh group to yield after that.", + alist_or_list(raising_stream_iterator), list(range(1, count + 1)) ) - # behavior of the `seconds` parameter - self.assertListEqual( - list( + with self.assertRaises( + stopiteration_for_iter_type(type(raising_stream_iterator)), + msg="... 
and after reaching the limit it still continues to raise StopIteration on calls to next", + ): + anext_or_next(raising_stream_iterator) + + iter_truncated_on_predicate = bi_iterable_to_iter( + Stream(src).atruncate(when=asyncify(lambda n: n == 5)), itype=itype + ) + self.assertListEqual( + alist_or_list(iter_truncated_on_predicate), + to_list(Stream(src).atruncate(5), itype=itype), + msg="`when` n == 5 must be equivalent to `count` = 5", + ) + with self.assertRaises( + stopiteration_for_iter_type(type(iter_truncated_on_predicate)), + msg="After exhaustion a call to __next__ on a truncated iterator must raise StopIteration", + ): + anext_or_next(iter_truncated_on_predicate) + + with self.assertRaises( + ZeroDivisionError, + msg="an exception raised by `when` must be raised", + ): + to_list(Stream(src).atruncate(when=asyncify(lambda _: 1 / 0)), itype=itype) + + self.assertListEqual( + to_list( + Stream(src).atruncate(6, when=asyncify(lambda n: n == 5)), itype=itype + ), + list(range(5)), + msg="`when` and `count` argument can be set at the same time, and the truncation should happen as soon as one or the other is satisfied.", + ) + + self.assertListEqual( + to_list( + Stream(src).atruncate(5, when=asyncify(lambda n: n == 6)), itype=itype + ), + list(range(5)), + msg="`when` and `count` argument can be set at the same time, and the truncation should happen as soon as one or the other is satisfied.", + ) + + @parameterized.expand(ITERABLE_TYPES) + def test_group(self, itype: IterableType) -> None: + # behavior with invalid arguments + for seconds in [-1, 0]: + with self.assertRaises( + ValueError, + msg="`group` should raise error when called with `seconds` <= 0.", + ): + to_list( + Stream([1]).group( + size=100, interval=datetime.timedelta(seconds=seconds) + ), + itype=itype, + ) + + for size in [-1, 0]: + with self.assertRaises( + ValueError, + msg="`group` should raise error when called with `size` < 1.", + ): + to_list(Stream([1]).group(size=size), itype=itype) + + 
# group size + self.assertListEqual( + to_list(Stream(range(6)).group(size=4), itype=itype), + [[0, 1, 2, 3], [4, 5]], + msg="", + ) + self.assertListEqual( + to_list(Stream(range(6)).group(size=2), itype=itype), + [[0, 1], [2, 3], [4, 5]], + msg="", + ) + self.assertListEqual( + to_list(Stream([]).group(size=2), itype=itype), + [], + msg="", + ) + + # behavior with exceptions + def f(i): + return i / (110 - i) + + stream_iterator = bi_iterable_to_iter( + Stream(lambda: map(f, src)).group(100), itype=itype + ) + anext_or_next(stream_iterator) + self.assertListEqual( + anext_or_next(stream_iterator), + list(map(f, range(100, 110))), + msg="when encountering upstream exception, `group` should yield the current accumulated group...", + ) + + with self.assertRaises( + ZeroDivisionError, + msg="... and raise the upstream exception during the next call to `next`...", + ): + anext_or_next(stream_iterator) + + self.assertListEqual( + anext_or_next(stream_iterator), + list(map(f, range(111, 211))), + msg="... 
and restarting a fresh group to yield after that.", + ) + + # behavior of the `seconds` parameter + self.assertListEqual( + to_list( Stream(lambda: map(slow_identity, src)).group( size=100, interval=datetime.timedelta(seconds=slow_identity_duration / 1000), - ) + ), + itype=itype, ), list(map(lambda e: [e], src)), msg="`group` should not yield empty groups even though `interval` if smaller than upstream's frequency", ) self.assertListEqual( - list( + to_list( Stream(lambda: map(slow_identity, src)).group( size=100, interval=datetime.timedelta(seconds=slow_identity_duration / 1000), by=lambda _: None, - ) + ), + itype=itype, ), list(map(lambda e: [e], src)), msg="`group` with `by` argument should not yield empty groups even though `interval` if smaller than upstream's frequency", ) self.assertListEqual( - list( + to_list( Stream(lambda: map(slow_identity, src)).group( size=100, interval=datetime.timedelta( seconds=2 * slow_identity_duration * 0.99 ), - ) + ), + itype=itype, ), list(map(lambda e: [e, e + 1], even_src)), msg="`group` should yield upstream elements in a two-element group if `interval` inferior to twice the upstream yield period", ) self.assertListEqual( - next(iter(Stream(src).group())), + anext_or_next(bi_iterable_to_iter(Stream(src).group(), itype=itype)), list(src), msg="`group` without arguments should group the elements all together", ) + # test groupby + groupby_stream_iter: Union[ + Iterator[Tuple[int, List[int]]], AsyncIterator[Tuple[int, List[int]]] + ] = bi_iterable_to_iter( + Stream(src).groupby(lambda n: n % 2, size=2), itype=itype + ) + self.assertListEqual( + [anext_or_next(groupby_stream_iter), anext_or_next(groupby_stream_iter)], + [(0, [0, 2]), (1, [1, 3])], + msg="`groupby` must cogroup elements.", + ) + # test by - stream_iter = iter(Stream(src).group(size=2, by=lambda n: n % 2)) + stream_iter = bi_iterable_to_iter( + Stream(src).group(size=2, by=lambda n: n % 2), itype=itype + ) self.assertListEqual( - [next(stream_iter),
next(stream_iter)], + [anext_or_next(stream_iter), anext_or_next(stream_iter)], [[0, 2], [1, 3]], msg="`group` called with a `by` function must cogroup elements.", ) self.assertListEqual( - next( - iter( + anext_or_next( + bi_iterable_to_iter( Stream(src_raising_at_exhaustion).group( size=10, by=lambda n: n % 4 != 0 - ) - ) + ), + itype=itype, + ), ), [1, 2, 3, 5, 6, 7, 9, 10, 11, 13], msg="`group` called with a `by` function and a `size` should yield the first batch becoming full.", ) self.assertListEqual( - list(Stream(src).group(by=lambda n: n % 2)), + to_list(Stream(src).group(by=lambda n: n % 2), itype=itype), [list(range(0, N, 2)), list(range(1, N, 2))], msg="`group` called with a `by` function and an infinite size must cogroup elements and yield groups starting with the group containing the oldest element.", ) self.assertListEqual( - list(Stream(range(10)).group(by=lambda n: n % 4 == 0)), + to_list(Stream(range(10)).group(by=lambda n: n % 4 == 0), itype=itype), [[0, 4, 8], [1, 2, 3, 5, 6, 7, 9]], msg="`group` called with a `by` function and reaching exhaustion must cogroup elements and yield uncomplete groups starting with the group containing the oldest element, even though it's not the largest.", ) - stream_iter = iter(Stream(src_raising_at_exhaustion).group(by=lambda n: n % 2)) + stream_iter = bi_iterable_to_iter( + Stream(src_raising_at_exhaustion).group(by=lambda n: n % 2), itype=itype + ) self.assertListEqual( - [next(stream_iter), next(stream_iter)], + [anext_or_next(stream_iter), anext_or_next(stream_iter)], [list(range(0, N, 2)), list(range(1, N, 2))], msg="`group` called with a `by` function and encountering an exception must cogroup elements and yield uncomplete groups starting with the group containing the oldest element.", ) @@ -1148,56 +1432,277 @@ def f(i): TestError, msg="`group` called with a `by` function and encountering an exception must raise it after all groups have been yielded", ): - next(stream_iter) + anext_or_next(stream_iter) # 
test seconds + by self.assertListEqual( - list( + to_list( Stream(lambda: map(slow_identity, range(10))).group( interval=datetime.timedelta(seconds=slow_identity_duration * 2.9), by=lambda n: n % 4 == 0, - ) + ), + itype=itype, ), [[1, 2], [0, 4], [3, 5, 6, 7], [8], [9]], msg="`group` called with a `by` function must cogroup elements and yield the largest groups when `seconds` is reached event though it's not the oldest.", ) - stream_iter = iter( + stream_iter = bi_iterable_to_iter( Stream(src).group( - size=3, by=lambda n: throw(StopIteration) if n == 2 else n - ) + size=3, + by=lambda n: throw(stopiteration_for_iter_type(itype)) if n == 2 else n, + ), + itype=itype, ) self.assertListEqual( - [next(stream_iter), next(stream_iter)], + [anext_or_next(stream_iter), anext_or_next(stream_iter)], [[0], [1]], msg="`group` should yield incomplete groups when `by` raises", ) with self.assertRaisesRegex( - WrappedError, - "StopIteration()", + (WrappedError, RuntimeError), + stopiteration_for_iter_type(itype).__name__, msg="`group` should raise and skip `elem` if `by(elem)` raises", ): - next(stream_iter) + anext_or_next(stream_iter) self.assertListEqual( - next(stream_iter), + anext_or_next(stream_iter), [3], msg="`group` should continue yielding after `by`'s exception has been raised.", ) - def test_throttle(self) -> None: + @parameterized.expand(ITERABLE_TYPES) + def test_agroup(self, itype: IterableType) -> None: + # behavior with invalid arguments + for seconds in [-1, 0]: + with self.assertRaises( + ValueError, + msg="`agroup` should raise error when called with `seconds` <= 0.", + ): + to_list( + Stream([1]).agroup( + size=100, interval=datetime.timedelta(seconds=seconds) + ), + itype=itype, + ) + + for size in [-1, 0]: + with self.assertRaises( + ValueError, + msg="`agroup` should raise error when called with `size` < 1.", + ): + to_list(Stream([1]).agroup(size=size), itype=itype) + + # group size + self.assertListEqual( + to_list(Stream(range(6)).agroup(size=4), 
itype=itype), + [[0, 1, 2, 3], [4, 5]], + msg="", + ) + self.assertListEqual( + to_list(Stream(range(6)).agroup(size=2), itype=itype), + [[0, 1], [2, 3], [4, 5]], + msg="", + ) + self.assertListEqual( + to_list(Stream([]).agroup(size=2), itype=itype), + [], + msg="", + ) + + # behavior with exceptions + def f(i): + return i / (110 - i) + + stream_iterator = bi_iterable_to_iter( + Stream(lambda: map(f, src)).agroup(100), itype=itype + ) + anext_or_next(stream_iterator) + self.assertListEqual( + anext_or_next(stream_iterator), + list(map(f, range(100, 110))), + msg="when encountering upstream exception, `agroup` should yield the current accumulated group...", + ) + + with self.assertRaises( + ZeroDivisionError, + msg="... and raise the upstream exception during the next call to `next`...", + ): + anext_or_next(stream_iterator) + + self.assertListEqual( + anext_or_next(stream_iterator), + list(map(f, range(111, 211))), + msg="... and restarting a fresh group to yield after that.", + ) + + # behavior of the `seconds` parameter + self.assertListEqual( + to_list( + Stream(lambda: map(slow_identity, src)).agroup( + size=100, + interval=datetime.timedelta(seconds=slow_identity_duration / 1000), + ), + itype=itype, + ), + list(map(lambda e: [e], src)), + msg="`agroup` should not yield empty groups even though `interval` if smaller than upstream's frequency", + ) + self.assertListEqual( + to_list( + Stream(lambda: map(slow_identity, src)).agroup( + size=100, + interval=datetime.timedelta(seconds=slow_identity_duration / 1000), + by=asyncify(lambda _: None), + ), + itype=itype, + ), + list(map(lambda e: [e], src)), + msg="`agroup` with `by` argument should not yield empty groups even though `interval` if smaller than upstream's frequency", + ) + self.assertListEqual( + to_list( + Stream(lambda: map(slow_identity, src)).agroup( + size=100, + interval=datetime.timedelta( + seconds=2 * slow_identity_duration * 0.99 + ), + ), + itype=itype, + ), + list(map(lambda e: [e, e + 1], 
even_src)), + msg="`agroup` should yield upstream elements in a two-element group if `interval` inferior to twice the upstream yield period", + ) + + self.assertListEqual( + anext_or_next(bi_iterable_to_iter(Stream(src).agroup(), itype=itype)), + list(src), + msg="`agroup` without arguments should group the elements all together", + ) + + # test agroupby + groupby_stream_iter: Union[ + Iterator[Tuple[int, List[int]]], AsyncIterator[Tuple[int, List[int]]] + ] = bi_iterable_to_iter( + Stream(src).agroupby(asyncify(lambda n: n % 2), size=2), itype=itype + ) + self.assertListEqual( + [anext_or_next(groupby_stream_iter), anext_or_next(groupby_stream_iter)], + [(0, [0, 2]), (1, [1, 3])], + msg="`agroupby` must cogroup elements.", + ) + + # test by + stream_iter = bi_iterable_to_iter( + Stream(src).agroup(size=2, by=asyncify(lambda n: n % 2)), itype=itype + ) + self.assertListEqual( + [anext_or_next(stream_iter), anext_or_next(stream_iter)], + [[0, 2], [1, 3]], + msg="`agroup` called with a `by` function must cogroup elements.", + ) + + self.assertListEqual( + anext_or_next( + bi_iterable_to_iter( + Stream(src_raising_at_exhaustion).agroup( + size=10, by=asyncify(lambda n: n % 4 != 0) + ), + itype=itype, + ), + ), + [1, 2, 3, 5, 6, 7, 9, 10, 11, 13], + msg="`agroup` called with a `by` function and a `size` should yield the first batch becoming full.", + ) + + self.assertListEqual( + to_list(Stream(src).agroup(by=asyncify(lambda n: n % 2)), itype=itype), + [list(range(0, N, 2)), list(range(1, N, 2))], + msg="`agroup` called with a `by` function and an infinite size must cogroup elements and yield groups starting with the group containing the oldest element.", + ) + + self.assertListEqual( + to_list( + Stream(range(10)).agroup(by=asyncify(lambda n: n % 4 == 0)), itype=itype + ), + [[0, 4, 8], [1, 2, 3, 5, 6, 7, 9]], + msg="`agroup` called with a `by` function and reaching exhaustion must cogroup elements and yield uncomplete groups starting with the group containing the 
oldest element, even though it's not the largest.", + ) + + stream_iter = bi_iterable_to_iter( + Stream(src_raising_at_exhaustion).agroup(by=asyncify(lambda n: n % 2)), + itype=itype, + ) + self.assertListEqual( + [anext_or_next(stream_iter), anext_or_next(stream_iter)], + [list(range(0, N, 2)), list(range(1, N, 2))], + msg="`agroup` called with a `by` function and encountering an exception must cogroup elements and yield uncomplete groups starting with the group containing the oldest element.", + ) + with self.assertRaises( + TestError, + msg="`agroup` called with a `by` function and encountering an exception must raise it after all groups have been yielded", + ): + anext_or_next(stream_iter) + + # test seconds + by + self.assertListEqual( + to_list( + Stream(lambda: map(slow_identity, range(10))).agroup( + interval=datetime.timedelta(seconds=slow_identity_duration * 2.9), + by=asyncify(lambda n: n % 4 == 0), + ), + itype=itype, + ), + [[1, 2], [0, 4], [3, 5, 6, 7], [8], [9]], + msg="`agroup` called with a `by` function must cogroup elements and yield the largest groups when `seconds` is reached event though it's not the oldest.", + ) + + stream_iter = bi_iterable_to_iter( + Stream(src).agroup( + size=3, + by=asyncify( + lambda n: throw(stopiteration_for_iter_type(itype)) if n == 2 else n + ), + ), + itype=itype, + ) + self.assertListEqual( + [anext_or_next(stream_iter), anext_or_next(stream_iter)], + [[0], [1]], + msg="`agroup` should yield incomplete groups when `by` raises", + ) + with self.assertRaisesRegex( + (WrappedError, RuntimeError), + stopiteration_for_iter_type(itype).__name__, + msg="`agroup` should raise and skip `elem` if `by(elem)` raises", + ): + anext_or_next(stream_iter) + self.assertListEqual( + anext_or_next(stream_iter), + [3], + msg="`agroup` should continue yielding after `by`'s exception has been raised.", + ) + + @parameterized.expand(ITERABLE_TYPES) + def test_throttle(self, itype: IterableType) -> None: # behavior with invalid arguments 
with self.assertRaisesRegex( ValueError, r"`per` must be None or a positive timedelta but got datetime\.timedelta\(0\)", msg="`throttle` should raise error when called with negative `per`.", ): - list(Stream([1]).throttle(1, per=datetime.timedelta(microseconds=0))) + to_list( + Stream([1]).throttle(1, per=datetime.timedelta(microseconds=0)), + itype=itype, + ) with self.assertRaisesRegex( ValueError, "`count` must be >= 1 but got 0", msg="`throttle` should raise error when called with `count` < 1.", ): - list(Stream([1]).throttle(0, per=datetime.timedelta(seconds=1))) + to_list( + Stream([1]).throttle(0, per=datetime.timedelta(seconds=1)), itype=itype + ) # test interval interval_seconds = 0.3 @@ -1228,8 +1733,7 @@ def slow_first_elem(elem: int): ], ): with self.subTest(stream=stream): - duration, res = timestream(stream) - + duration, res = timestream(stream, itype=itype) self.assertListEqual( res, expected_elems, @@ -1246,11 +1750,12 @@ def slow_first_elem(elem: int): ) self.assertEqual( - next( - iter( + anext_or_next( + bi_iterable_to_iter( Stream(src) .throttle(1, per=datetime.timedelta(seconds=0.2)) - .throttle(1, per=datetime.timedelta(seconds=0.1)) + .throttle(1, per=datetime.timedelta(seconds=0.1)), + itype=itype, ) ), 0, @@ -1280,7 +1785,7 @@ def slow_first_elem(elem: int): ], ): with self.subTest(N=N, stream=stream): - duration, res = timestream(stream) + duration, res = timestream(stream, itype=itype) self.assertListEqual( res, expected_elems, @@ -1306,7 +1811,7 @@ def slow_first_elem(elem: int): .throttle(1, per=datetime.timedelta(seconds=0.2)), ]: with self.subTest(stream=stream): - duration, _ = timestream(stream) + duration, _ = timestream(stream, itype=itype) self.assertAlmostEqual( duration, expected_duration, @@ -1368,29 +1873,35 @@ def slow_first_elem(elem: int): msg="`throttle` must support legacy kwargs", ) - def test_distinct(self) -> None: + @parameterized.expand(ITERABLE_TYPES) + def test_distinct(self, itype: IterableType) -> None: 
self.assertListEqual( - list(Stream("abbcaabcccddd").distinct()), + to_list(Stream("abbcaabcccddd").distinct(), itype=itype), list("abcd"), msg="`distinct` should yield distinct elements", ) self.assertListEqual( - list(Stream("aabbcccaabbcccc").distinct(consecutive_only=True)), + to_list( + Stream("aabbcccaabbcccc").distinct(consecutive_only=True), itype=itype + ), list("abcabc"), msg="`distinct` should only remove the duplicates that are consecutive if `consecutive_only=True`", ) for consecutive_only in [True, False]: self.assertListEqual( - list( + to_list( Stream(["foo", "bar", "a", "b"]).distinct( len, consecutive_only=consecutive_only - ) + ), + itype=itype, ), ["foo", "a"], msg="`distinct` should yield the first encountered elem among duplicates", ) self.assertListEqual( - list(Stream([]).distinct(consecutive_only=consecutive_only)), + to_list( + Stream([]).distinct(consecutive_only=consecutive_only), itype=itype + ), [], msg="`distinct` should yield zero elements on empty stream", ) @@ -1399,11 +1910,51 @@ def test_distinct(self) -> None: "unhashable type: 'list'", msg="`distinct` should raise for non-hashable elements", ): - list(Stream([[1]]).distinct()) + to_list(Stream([[1]]).distinct(), itype=itype) - def test_catch(self) -> None: + @parameterized.expand(ITERABLE_TYPES) + def test_adistinct(self, itype: IterableType) -> None: self.assertListEqual( - list(Stream(src).catch(finally_raise=True)), + to_list(Stream("abbcaabcccddd").adistinct(), itype=itype), + list("abcd"), + msg="`adistinct` should yield distinct elements", + ) + self.assertListEqual( + to_list( + Stream("aabbcccaabbcccc").adistinct(consecutive_only=True), itype=itype + ), + list("abcabc"), + msg="`adistinct` should only remove the duplicates that are consecutive if `consecutive_only=True`", + ) + for consecutive_only in [True, False]: + self.assertListEqual( + to_list( + Stream(["foo", "bar", "a", "b"]).adistinct( + asyncify(len), consecutive_only=consecutive_only + ), + itype=itype, + ), 
+ ["foo", "a"], + msg="`adistinct` should yield the first encountered elem among duplicates", + ) + self.assertListEqual( + to_list( + Stream([]).adistinct(consecutive_only=consecutive_only), itype=itype + ), + [], + msg="`adistinct` should yield zero elements on empty stream", + ) + with self.assertRaisesRegex( + TypeError, + "unhashable type: 'list'", + msg="`adistinct` should raise for non-hashable elements", + ): + to_list(Stream([[1]]).adistinct(), itype=itype) + + @parameterized.expand(ITERABLE_TYPES) + def test_catch(self, itype: IterableType) -> None: + self.assertListEqual( + to_list(Stream(src).catch(finally_raise=True), itype=itype), list(src), msg="`catch` should yield elements in exception-less scenarios", ) @@ -1431,12 +1982,12 @@ def f(i): safe_src = list(src) del safe_src[3] self.assertListEqual( - list(stream.catch(ZeroDivisionError)), + to_list(stream.catch(ZeroDivisionError), itype=itype), list(map(f, safe_src)), msg="If the exception type matches the `error_type`, then the impacted element should be ignored.", ) self.assertListEqual( - list(stream.catch()), + to_list(stream.catch(), itype=itype), list(map(f, safe_src)), msg="If the predicate is not specified, then all exceptions should be caught.", ) @@ -1445,7 +1996,7 @@ def f(i): ZeroDivisionError, msg="If a non caught exception type occurs, then it should be raised.", ): - list(stream.catch(TestError)) + to_list(stream.catch(TestError), itype=itype) first_value = 1 second_value = 2 @@ -1465,9 +2016,11 @@ def f(i): erroring_stream.catch(finally_raise=True), erroring_stream.catch(finally_raise=True), ]: - erroring_stream_iterator = iter(caught_erroring_stream) + erroring_stream_iterator = bi_iterable_to_iter( + caught_erroring_stream, itype=itype + ) self.assertEqual( - next(erroring_stream_iterator), + anext_or_next(erroring_stream_iterator), first_value, msg="`catch` should yield the first non exception throwing element.", ) @@ -1476,13 +2029,14 @@ def f(i): TestError, msg="`catch` should 
raise the first error encountered when `finally_raise` is True.", ): - for _ in erroring_stream_iterator: + while True: + anext_or_next(erroring_stream_iterator) n_yields += 1 with self.assertRaises( - StopIteration, + stopiteration_for_iter_type(type(erroring_stream_iterator)), msg="`catch` with `finally_raise`=True should finally raise StopIteration to avoid infinite recursion if there is another catch downstream.", ): - next(erroring_stream_iterator) + anext_or_next(erroring_stream_iterator) self.assertEqual( n_yields, 3, @@ -1493,61 +2047,65 @@ def f(i): map(lambda _: throw(TestError), range(2000)) ).catch(TestError) self.assertListEqual( - list(only_caught_errors_stream), + to_list(only_caught_errors_stream, itype=itype), [], msg="When upstream raise exceptions without yielding any element, listing the stream must return empty list, without recursion issue.", ) with self.assertRaises( - StopIteration, + stopiteration_for_iter_type(itype), msg="When upstream raise exceptions without yielding any element, then the first call to `next` on a stream catching all errors should raise StopIteration.", ): - next(iter(only_caught_errors_stream)) + anext_or_next(bi_iterable_to_iter(only_caught_errors_stream, itype=itype)) - iterator = iter( + iterator = bi_iterable_to_iter( Stream(map(throw, [TestError, ValueError])) .catch(ValueError, finally_raise=True) - .catch(TestError, finally_raise=True) + .catch(TestError, finally_raise=True), + itype=itype, ) with self.assertRaises( ValueError, msg="With 2 chained `catch`s with `finally_raise=True`, the error caught by the first `catch` is finally raised first (even though it was raised second)...", ): - next(iterator) + anext_or_next(iterator) with self.assertRaises( TestError, msg="... and then the error caught by the second `catch` is raised...", ): - next(iterator) + anext_or_next(iterator) with self.assertRaises( - StopIteration, + stopiteration_for_iter_type(type(iterator)), msg="... 
and a StopIteration is raised next.", ): - next(iterator) + anext_or_next(iterator) with self.assertRaises( TypeError, msg="`catch` does not catch if `when` not satisfied", ): - list( + to_list( Stream(map(throw, [ValueError, TypeError])).catch( when=lambda exception: "ValueError" in repr(exception) - ) + ), + itype=itype, ) self.assertListEqual( - list( + to_list( Stream(map(lambda n: 1 / n, [0, 1, 2, 4])).catch( ZeroDivisionError, replacement=float("inf") - ) + ), + itype=itype, ), [float("inf"), 1, 0.5, 0.25], msg="`catch` should be able to yield a non-None replacement", ) self.assertListEqual( - list( + to_list( Stream(map(lambda n: 1 / n, [0, 1, 2, 4])).catch( ZeroDivisionError, replacement=cast(float, None) - ) + ), + itype=itype, ), [None, 1, 0.5, 0.25], msg="`catch` should be able to yield a None replacement", @@ -1556,7 +2114,7 @@ def f(i): errors_counter: Counter[Type[Exception]] = Counter() self.assertListEqual( - list( + to_list( Stream( map( lambda n: 1 / n, # potential ZeroDivisionError @@ -1571,7 +2129,8 @@ def f(i): ).catch( (ValueError, TestError, ZeroDivisionError), when=lambda err: errors_counter.update([type(err)]) is None, - ) + ), + itype=itype, ), list(map(lambda n: 1 / n, range(2, 10, 2))), msg="`catch` should accept multiple types", @@ -1589,15 +2148,16 @@ def f(i): # Stream(src).catch() # type: ignore self.assertEqual( - list(Stream(map(int, "foo")).catch(replacement=0)), + to_list(Stream(map(int, "foo")).catch(replacement=0), itype=itype), [0] * len("foo"), msg="`catch` must catch all errors when no error type provided", ) self.assertEqual( - list( + to_list( Stream(map(int, "foo")).catch( (None, None, ValueError, None), replacement=0 - ) + ), + itype=itype, ), [0] * len("foo"), msg="`catch` must catch the provided non-None error types", @@ -1606,14 +2166,239 @@ def f(i): ValueError, msg="`catch` must be noop if error type is None", ): - list(Stream(map(int, "foo")).catch(None)) + to_list(Stream(map(int, "foo")).catch(None), itype=itype) 
with self.assertRaises( ValueError, msg="`catch` must be noop if error types are None", ): - list(Stream(map(int, "foo")).catch((None, None, None))) + to_list(Stream(map(int, "foo")).catch((None, None, None)), itype=itype) + + @parameterized.expand(ITERABLE_TYPES) + def test_acatch(self, itype: IterableType) -> None: + self.assertListEqual( + to_list(Stream(src).acatch(finally_raise=True), itype=itype), + list(src), + msg="`acatch` should yield elements in exception-less scenarios", + ) + + with self.assertRaisesRegex( + TypeError, + "`iterator` must be an AsyncIterator but got a ", + msg="`afunctions.acatch` function should raise TypeError when first argument is not an AsyncIterator", + ): + from streamable import afunctions + + afunctions.acatch(cast(AsyncIterator, [3, 4]), Exception) + + with self.assertRaisesRegex( + TypeError, + "`errors` must be None, or a subclass of `Exception`, or an iterable of optional subclasses of `Exception`, but got ", + msg="`acatch` should raise TypeError when first argument is not None or Type[Exception], or Iterable[Optional[Type[Exception]]]", + ): + Stream(src).acatch(1) # type: ignore + + def f(i): + return i / (3 - i) + + stream = Stream(lambda: map(f, src)) + safe_src = list(src) + del safe_src[3] + self.assertListEqual( + to_list(stream.acatch(ZeroDivisionError), itype=itype), + list(map(f, safe_src)), + msg="If the exception type matches the `error_type`, then the impacted element should be ignored.", + ) + self.assertListEqual( + to_list(stream.acatch(), itype=itype), + list(map(f, safe_src)), + msg="If the predicate is not specified, then all exceptions should be caught.", + ) + + with self.assertRaises( + ZeroDivisionError, + msg="If a non caught exception type occurs, then it should be raised.", + ): + to_list(stream.acatch(TestError), itype=itype) + + first_value = 1 + second_value = 2 + third_value = 3 + functions = [ + lambda: throw(TestError), + lambda: throw(TypeError), + lambda: first_value, + lambda: 
second_value, + lambda: throw(ValueError), + lambda: third_value, + lambda: throw(ZeroDivisionError), + ] + + erroring_stream: Stream[int] = Stream(lambda: map(lambda f: f(), functions)) + for caught_erroring_stream in [ + erroring_stream.acatch(finally_raise=True), + erroring_stream.acatch(finally_raise=True), + ]: + erroring_stream_iterator = bi_iterable_to_iter( + caught_erroring_stream, itype=itype + ) + self.assertEqual( + anext_or_next(erroring_stream_iterator), + first_value, + msg="`acatch` should yield the first non exception throwing element.", + ) + n_yields = 1 + with self.assertRaises( + TestError, + msg="`acatch` should raise the first error encountered when `finally_raise` is True.", + ): + while True: + anext_or_next(erroring_stream_iterator) + n_yields += 1 + with self.assertRaises( + stopiteration_for_iter_type(type(erroring_stream_iterator)), + msg="`acatch` with `finally_raise`=True should finally raise StopIteration to avoid infinite recursion if there is another catch downstream.", + ): + anext_or_next(erroring_stream_iterator) + self.assertEqual( + n_yields, + 3, + msg="3 elements should have passed been yielded between caught exceptions.", + ) + + only_caught_errors_stream = Stream( + map(lambda _: throw(TestError), range(2000)) + ).acatch(TestError) + self.assertListEqual( + to_list(only_caught_errors_stream, itype=itype), + [], + msg="When upstream raise exceptions without yielding any element, listing the stream must return empty list, without recursion issue.", + ) + with self.assertRaises( + stopiteration_for_iter_type(itype), + msg="When upstream raise exceptions without yielding any element, then the first call to `next` on a stream catching all errors should raise StopIteration.", + ): + anext_or_next(bi_iterable_to_iter(only_caught_errors_stream, itype=itype)) - def test_observe(self) -> None: + iterator = bi_iterable_to_iter( + Stream(map(throw, [TestError, ValueError])) + .acatch(ValueError, finally_raise=True) + 
.acatch(TestError, finally_raise=True), + itype=itype, + ) + with self.assertRaises( + ValueError, + msg="With 2 chained `acatch`s with `finally_raise=True`, the error caught by the first `acatch` is finally raised first (even though it was raised second)...", + ): + anext_or_next(iterator) + with self.assertRaises( + TestError, + msg="... and then the error caught by the second `acatch` is raised...", + ): + anext_or_next(iterator) + with self.assertRaises( + stopiteration_for_iter_type(type(iterator)), + msg="... and a StopIteration is raised next.", + ): + anext_or_next(iterator) + + with self.assertRaises( + TypeError, + msg="`acatch` does not catch if `when` not satisfied", + ): + to_list( + Stream(map(throw, [ValueError, TypeError])).acatch( + when=asyncify(lambda exception: "ValueError" in repr(exception)) + ), + itype=itype, + ) + + self.assertListEqual( + to_list( + Stream(map(lambda n: 1 / n, [0, 1, 2, 4])).acatch( + ZeroDivisionError, replacement=float("inf") + ), + itype=itype, + ), + [float("inf"), 1, 0.5, 0.25], + msg="`acatch` should be able to yield a non-None replacement", + ) + self.assertListEqual( + to_list( + Stream(map(lambda n: 1 / n, [0, 1, 2, 4])).acatch( + ZeroDivisionError, replacement=cast(float, None) + ), + itype=itype, + ), + [None, 1, 0.5, 0.25], + msg="`acatch` should be able to yield a None replacement", + ) + + errors_counter: Counter[Type[Exception]] = Counter() + + self.assertListEqual( + to_list( + Stream( + map( + lambda n: 1 / n, # potential ZeroDivisionError + map( + throw_for_odd_func(TestError), # potential TestError + map( + int, # potential ValueError + "01234foo56789", + ), + ), + ) + ).acatch( + (ValueError, TestError, ZeroDivisionError), + when=asyncify( + lambda err: errors_counter.update([type(err)]) is None + ), + ), + itype=itype, + ), + list(map(lambda n: 1 / n, range(2, 10, 2))), + msg="`acatch` should accept multiple types", + ) + self.assertDictEqual( + errors_counter, + {TestError: 5, ValueError: 3, 
ZeroDivisionError: 1}, + msg="`acatch` should accept multiple types", + ) + + # with self.assertRaises( + # TypeError, + # msg="`acatch` without any error type must raise", + # ): + # Stream(src).acatch() # type: ignore + + self.assertEqual( + to_list(Stream(map(int, "foo")).acatch(replacement=0), itype=itype), + [0] * len("foo"), + msg="`acatch` must catch all errors when no error type provided", + ) + self.assertEqual( + to_list( + Stream(map(int, "foo")).acatch( + (None, None, ValueError, None), replacement=0 + ), + itype=itype, + ), + [0] * len("foo"), + msg="`acatch` must catch the provided non-None error types", + ) + with self.assertRaises( + ValueError, + msg="`acatch` must be noop if error type is None", + ): + to_list(Stream(map(int, "foo")).acatch(None), itype=itype) + with self.assertRaises( + ValueError, + msg="`acatch` must be noop if error types are None", + ): + to_list(Stream(map(int, "foo")).acatch((None, None, None)), itype=itype) + + @parameterized.expand(ITERABLE_TYPES) + def test_observe(self, itype: IterableType) -> None: value_error_rainsing_stream: Stream[List[int]] = ( Stream("123--678") .throttle(10, per=datetime.timedelta(seconds=1)) @@ -1625,7 +2410,7 @@ def test_observe(self) -> None: ) self.assertListEqual( - list(value_error_rainsing_stream.catch(ValueError)), + to_list(value_error_rainsing_stream.catch(ValueError), itype=itype), [[1, 2], [3], [6, 7], [8]], msg="This can break due to `group`/`map`/`catch`, check other breaking tests to determine quickly if it's an issue with `observe`.", ) @@ -1634,10 +2419,11 @@ def test_observe(self) -> None: ValueError, msg="`observe` should forward-raise exceptions", ): - list(value_error_rainsing_stream) + to_list(value_error_rainsing_stream, itype=itype) def test_is_iterable(self) -> None: self.assertIsInstance(Stream(src), Iterable) + self.assertIsInstance(Stream(src), AsyncIterable) def test_count(self) -> None: l: List[int] = [] @@ -1656,6 +2442,23 @@ def effect(x: int) -> None: l, 
list(src), msg="`count` should iterate over the entire stream." ) + def test_acount(self) -> None: + l: List[int] = [] + + def effect(x: int) -> None: + nonlocal l + l.append(x) + + stream = Stream(lambda: map(effect, src)) + self.assertEqual( + asyncio.run(stream.acount()), + N, + msg="`count` should return the count of elements.", + ) + self.assertListEqual( + l, list(src), msg="`count` should iterate over the entire stream." + ) + def test_call(self) -> None: l: List[int] = [] stream = Stream(src).map(l.append) @@ -1670,51 +2473,62 @@ def test_call(self) -> None: msg="`__call__` should exhaust the stream.", ) - def test_multiple_iterations(self) -> None: + def test_await(self) -> None: + l: List[int] = [] + stream = Stream(src).map(l.append) + self.assertIs( + asyncio.run(awaitable_to_coroutine(stream)), + stream, + msg="`__call__` should return the stream.", + ) + self.assertListEqual( + l, + list(src), + msg="`__call__` should exhaust the stream.", + ) + + @parameterized.expand(ITERABLE_TYPES) + def test_multiple_iterations(self, itype: IterableType) -> None: stream = Stream(src) for _ in range(3): self.assertListEqual( - list(stream), + to_list(stream, itype=itype), list(src), msg="The first iteration over a stream should yield the same elements as any subsequent iteration on the same stream, even if it is based on a `source` returning an iterator that only support 1 iteration.", ) @parameterized.expand( - [ - [1], - [100], - ] + [(concurrency, itype) for concurrency in (1, 100) for itype in ITERABLE_TYPES] ) - def test_amap(self, concurrency) -> None: + def test_amap(self, concurrency, itype) -> None: self.assertListEqual( - list( + to_list( Stream(src).amap( async_randomly_slowed(async_square), concurrency=concurrency - ) + ), + itype=itype, ), list(map(square, src)), msg="At any concurrency the `amap` method should act as the builtin map function, transforming elements while preserving input elements order.", ) - stream = Stream(src).amap(identity) # type: 
ignore + stream = Stream(src).amap(identity, concurrency=concurrency) # type: ignore with self.assertRaisesRegex( TypeError, - r"`transformation` must be an async function i\.e\. a function returning a Coroutine but it returned a ", + r"must be an async function i\.e\. a function returning a Coroutine but it returned a ", msg="`amap` should raise a TypeError if a non async function is passed to it.", ): - next(iter(stream)) + anext_or_next(bi_iterable_to_iter(stream, itype=itype)) @parameterized.expand( - [ - [1], - [100], - ] + [(concurrency, itype) for concurrency in (1, 100) for itype in ITERABLE_TYPES] ) - def test_aforeach(self, concurrency) -> None: + def test_aforeach(self, concurrency, itype) -> None: self.assertListEqual( - list( + to_list( Stream(src).aforeach( async_randomly_slowed(async_square), concurrency=concurrency - ) + ), + itype=itype, ), list(src), msg="At any concurrency the `foreach` method must preserve input elements order.", @@ -1725,9 +2539,10 @@ def test_aforeach(self, concurrency) -> None: r"`transformation` must be an async function i\.e\. 
a function returning a Coroutine but it returned a ", msg="`aforeach` should raise a TypeError if a non async function is passed to it.", ): - next(iter(stream)) + anext_or_next(bi_iterable_to_iter(stream, itype=itype)) - def test_pipe(self) -> None: + @parameterized.expand(ITERABLE_TYPES) + def test_pipe(self, itype: IterableType) -> None: def func( stream: Stream, *ints: int, **strings: str ) -> Tuple[Stream, Tuple[int, ...], Dict[str, str]]: @@ -1744,8 +2559,8 @@ def func( ) self.assertListEqual( - stream.pipe(list), - list(stream), + stream.pipe(to_list, itype=itype), + to_list(stream, itype=itype), msg="`pipe` should be ok without args and kwargs.", ) @@ -1753,18 +2568,27 @@ def test_eq(self) -> None: stream = ( Stream(src) .catch((TypeError, ValueError), replacement=2, when=identity) + .acatch((TypeError, ValueError), replacement=2, when=async_identity) .distinct(key=identity) + .adistinct(key=async_identity) .filter(identity) + .afilter(async_identity) .foreach(identity, concurrency=3) .aforeach(async_identity, concurrency=3) .group(3, by=bool) .flatten(concurrency=3) + .agroup(3, by=async_identity) + .map(sync_to_async_iter) + .aflatten(concurrency=3) .groupby(bool) + .agroupby(async_identity) .map(identity, via="process") .amap(async_identity) .observe("foo") .skip(3) + .askip(3) .truncate(4) + .atruncate(4) .throttle(1, per=datetime.timedelta(seconds=1)) ) @@ -1772,64 +2596,100 @@ def test_eq(self) -> None: stream, Stream(src) .catch((TypeError, ValueError), replacement=2, when=identity) + .acatch((TypeError, ValueError), replacement=2, when=async_identity) .distinct(key=identity) + .adistinct(key=async_identity) .filter(identity) + .afilter(async_identity) .foreach(identity, concurrency=3) .aforeach(async_identity, concurrency=3) .group(3, by=bool) .flatten(concurrency=3) + .agroup(3, by=async_identity) + .map(sync_to_async_iter) + .aflatten(concurrency=3) .groupby(bool) + .agroupby(async_identity) .map(identity, via="process") .amap(async_identity) 
.observe("foo") .skip(3) + .askip(3) .truncate(4) + .atruncate(4) .throttle(1, per=datetime.timedelta(seconds=1)), ) self.assertNotEqual( stream, Stream(list(src)) # not same source .catch((TypeError, ValueError), replacement=2, when=identity) + .acatch((TypeError, ValueError), replacement=2, when=async_identity) .distinct(key=identity) + .adistinct(key=async_identity) .filter(identity) + .afilter(async_identity) .foreach(identity, concurrency=3) .aforeach(async_identity, concurrency=3) .group(3, by=bool) .flatten(concurrency=3) + .agroup(3, by=async_identity) + .map(sync_to_async_iter) + .aflatten(concurrency=3) .groupby(bool) + .agroupby(async_identity) .map(identity, via="process") .amap(async_identity) .observe("foo") .skip(3) + .askip(3) .truncate(4) + .atruncate(4) .throttle(1, per=datetime.timedelta(seconds=1)), ) self.assertNotEqual( stream, Stream(src) .catch((TypeError, ValueError), replacement=2, when=identity) + .acatch((TypeError, ValueError), replacement=2, when=async_identity) .distinct(key=identity) + .adistinct(key=async_identity) .filter(identity) + .afilter(async_identity) .foreach(identity, concurrency=3) .aforeach(async_identity, concurrency=3) .group(3, by=bool) .flatten(concurrency=3) + .agroup(3, by=async_identity) + .map(sync_to_async_iter) + .aflatten(concurrency=3) .groupby(bool) + .agroupby(async_identity) .map(identity, via="process") .amap(async_identity) .observe("foo") .skip(3) + .askip(3) .truncate(4) + .atruncate(4) .throttle(1, per=datetime.timedelta(seconds=2)), # not the same interval ) - def test_ref_cycles(self) -> None: + @parameterized.expand(ITERABLE_TYPES) + def test_ref_cycles(self, itype: IterableType) -> None: + async def async_int(o: Any) -> int: + return int(o) + stream = ( - Stream(map(int, "123_5")).group(1).groupby(len).catch(finally_raise=True) + Stream("123_5") + .amap(async_int) + .map(str) + .group(1) + .groupby(len) + .catch(finally_raise=True) ) exception: Exception try: - list(stream) + to_list(stream, 
itype=itype) except ValueError as e: exception = e self.assertIsInstance( @@ -1837,19 +2697,21 @@ def test_ref_cycles(self) -> None: ValueError, msg="`finally_raise` must be respected", ) - frames: Iterator[FrameType] = map( - itemgetter(0), traceback.walk_tb(exception.__traceback__) - ) - next(frames) self.assertFalse( [ (var, val) - for frame in frames + # go through the frames of the exception's traceback + for frame, _ in traceback.walk_tb(exception.__traceback__) + # skipping the current frame + if frame is not cast(TracebackType, exception.__traceback__).tb_frame + # go through the locals captured in that frame for var, val in frame.f_locals.items() + # check if one of them is an exception if isinstance(val, Exception) - and frame in map(itemgetter(0), traceback.walk_tb(val.__traceback__)) + # check if it is captured in its own traceback + and frame is cast(TracebackType, val.__traceback__).tb_frame ], - msg=f"the exception's traceback should not contain an exception that captures itself in its own traceback", + msg=f"the exception's traceback should not contain an exception captured in its own traceback", ) def test_on_queue_in_thread(self) -> None: diff --git a/tests/test_visitor.py b/tests/test_visitor.py index d13301df..ac5cce7c 100644 --- a/tests/test_visitor.py +++ b/tests/test_visitor.py @@ -2,8 +2,16 @@ from typing import cast from streamable.stream import ( + ACatchStream, + ADistinctStream, + AFilterStream, + AFlattenStream, AForeachStream, + AGroupbyStream, + AGroupStream, AMapStream, + ASkipStream, + ATruncateStream, CatchStream, DistinctStream, FilterStream, @@ -29,19 +37,27 @@ def visit_stream(self, stream: Stream) -> None: visitor = ConcreteVisitor() visitor.visit_catch_stream(cast(CatchStream, ...)) + visitor.visit_acatch_stream(cast(ACatchStream, ...)) visitor.visit_distinct_stream(cast(DistinctStream, ...)) + visitor.visit_adistinct_stream(cast(ADistinctStream, ...)) visitor.visit_filter_stream(cast(FilterStream, ...)) + 
visitor.visit_afilter_stream(cast(AFilterStream, ...)) visitor.visit_flatten_stream(cast(FlattenStream, ...)) + visitor.visit_aflatten_stream(cast(AFlattenStream, ...)) visitor.visit_foreach_stream(cast(ForeachStream, ...)) visitor.visit_aforeach_stream(cast(AForeachStream, ...)) visitor.visit_group_stream(cast(GroupStream, ...)) + visitor.visit_agroup_stream(cast(AGroupStream, ...)) visitor.visit_groupby_stream(cast(GroupbyStream, ...)) + visitor.visit_agroupby_stream(cast(AGroupbyStream, ...)) visitor.visit_map_stream(cast(MapStream, ...)) visitor.visit_amap_stream(cast(AMapStream, ...)) visitor.visit_observe_stream(cast(ObserveStream, ...)) visitor.visit_skip_stream(cast(SkipStream, ...)) + visitor.visit_askip_stream(cast(ASkipStream, ...)) visitor.visit_throttle_stream(cast(ThrottleStream, ...)) visitor.visit_truncate_stream(cast(TruncateStream, ...)) + visitor.visit_atruncate_stream(cast(ATruncateStream, ...)) visitor.visit_stream(cast(Stream, ...)) def test_depth_visitor_example(self): diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 00000000..94d6efa3 --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,215 @@ +import asyncio +import random +import time +import timeit +from typing import ( + Any, + AsyncIterable, + AsyncIterator, + Callable, + Coroutine, + Iterable, + Iterator, + List, + Set, + Tuple, + Type, + TypeVar, + Union, +) + +from streamable.stream import Stream +from streamable.util.iterabletools import BiIterable + +T = TypeVar("T") +R = TypeVar("R") + +IterableType = Union[Type[Iterable], Type[AsyncIterable]] +ITERABLE_TYPES: Tuple[IterableType, ...] 
= (Iterable, AsyncIterable) + +TEST_EVENT_LOOP = asyncio.new_event_loop() + + +async def _aiter_to_list(aiterable: AsyncIterable[T]) -> List[T]: + return [elem async for elem in aiterable] + + +def aiterable_to_list(aiterable: AsyncIterable[T]) -> List[T]: + return TEST_EVENT_LOOP.run_until_complete(_aiter_to_list(aiterable)) + + +async def _aiter_to_set(aiterable: AsyncIterable[T]) -> Set[T]: + return {elem async for elem in aiterable} + + +def aiterable_to_set(aiterable: AsyncIterable[T]) -> Set[T]: + return TEST_EVENT_LOOP.run_until_complete(_aiter_to_set(aiterable)) + + +def stopiteration_for_iter_type(itype: IterableType) -> Type[Exception]: + if issubclass(itype, AsyncIterable): + return StopAsyncIteration + return StopIteration + + +def to_list(stream: Stream[T], itype: IterableType) -> List[T]: + assert isinstance(stream, Stream) + if itype is AsyncIterable: + return aiterable_to_list(stream) + else: + return list(stream) + + +def to_set(stream: Stream[T], itype: IterableType) -> Set[T]: + assert isinstance(stream, Stream) + if itype is AsyncIterable: + return aiterable_to_set(stream) + else: + return set(stream) + + +def bi_iterable_to_iter( + iterable: Union[BiIterable[T], Stream[T]], itype: IterableType +) -> Union[Iterator[T], AsyncIterator[T]]: + if itype is AsyncIterable: + return iterable.__aiter__() + else: + return iter(iterable) + + +def anext_or_next(it: Union[Iterator[T], AsyncIterator[T]]) -> T: + if isinstance(it, AsyncIterator): + return TEST_EVENT_LOOP.run_until_complete(it.__anext__()) + else: + return next(it) + + +def alist_or_list(iterable: Union[Iterable[T], AsyncIterable[T]]) -> List[T]: + if isinstance(iterable, AsyncIterable): + return aiterable_to_list(iterable) + else: + return list(iterable) + + +def timestream( + stream: Stream[T], times: int = 1, itype: IterableType = Iterable +) -> Tuple[float, List[T]]: + res: List[T] = [] + + def iterate(): + nonlocal res + res = to_list(stream, itype=itype) + + return timeit.timeit(iterate, 
number=times) / times, res + + +def identity_sleep(seconds: float) -> float: + time.sleep(seconds) + return seconds + + +async def async_identity_sleep(seconds: float) -> float: + await asyncio.sleep(seconds) + return seconds + + +# simulates an I/0 bound function +slow_identity_duration = 0.05 + + +def slow_identity(x: T) -> T: + time.sleep(slow_identity_duration) + return x + + +async def async_slow_identity(x: T) -> T: + await asyncio.sleep(slow_identity_duration) + return x + + +def identity(x: T) -> T: + return x + + +# fmt: off +async def async_identity(x: T) -> T: return x +# fmt: on + + +def square(x): + return x**2 + + +async def async_square(x): + return x**2 + + +def throw(exc: Type[Exception]): + raise exc() + + +def throw_func(exc: Type[Exception]) -> Callable[[T], T]: + return lambda _: throw(exc) + + +def async_throw_func(exc: Type[Exception]) -> Callable[[T], Coroutine[Any, Any, T]]: + async def f(_: T) -> T: + raise exc + + return f + + +def throw_for_odd_func(exc): + return lambda i: throw(exc) if i % 2 == 1 else i + + +def async_throw_for_odd_func(exc): + async def f(i): + return throw(exc) if i % 2 == 1 else i + + return f + + +class TestError(Exception): + pass + + +DELTA_RATE = 0.4 +# size of the test collections +N = 256 + +src = range(N) + +even_src = range(0, N, 2) + + +def randomly_slowed( + func: Callable[[T], R], min_sleep: float = 0.001, max_sleep: float = 0.05 +) -> Callable[[T], R]: + def wrap(x: T) -> R: + time.sleep(min_sleep + random.random() * (max_sleep - min_sleep)) + return func(x) + + return wrap + + +def async_randomly_slowed( + async_func: Callable[[T], Coroutine[Any, Any, R]], + min_sleep: float = 0.001, + max_sleep: float = 0.05, +) -> Callable[[T], Coroutine[Any, Any, R]]: + async def wrap(x: T) -> R: + await asyncio.sleep(min_sleep + random.random() * (max_sleep - min_sleep)) + return await async_func(x) + + return wrap + + +def range_raising_at_exhaustion( + start: int, end: int, step: int, exception: Exception +) -> 
Iterator[int]: + yield from range(start, end, step) + raise exception + + +src_raising_at_exhaustion = lambda: range_raising_at_exhaustion(0, N, 1, TestError()) diff --git a/version.py b/version.py index 7117f661..3ca80c9b 100644 --- a/version.py +++ b/version.py @@ -1,2 +1,2 @@ # print CHANGELOG: git log --oneline -- version.py | grep -v '\-rc' -__version__ = "1.5.1" +__version__ = "1.6.0a1" From d1a2a04a03ffc4f5ce86df9f3bf64761071a2bd7 Mon Sep 17 00:00:00 2001 From: ebonnal Date: Sun, 11 May 2025 22:37:48 +0100 Subject: [PATCH 2/7] `repr`: introduce `one_liner_max_depth=3` (closes #104) --- streamable/visitors/representation.py | 11 +++++++++-- tests/test_stream.py | 22 ++++++++++++++++------ 2 files changed, 25 insertions(+), 8 deletions(-) diff --git a/streamable/visitors/representation.py b/streamable/visitors/representation.py index 926876ce..0e36f20b 100644 --- a/streamable/visitors/representation.py +++ b/streamable/visitors/representation.py @@ -32,8 +32,9 @@ class ToStringVisitor(Visitor[str], ABC): - def __init__(self) -> None: + def __init__(self, one_liner_max_depth: int = 3) -> None: self.methods_reprs: List[str] = [] + self.one_liner_max_depth = one_liner_max_depth @staticmethod @abstractmethod @@ -182,10 +183,16 @@ def visit_atruncate_stream(self, stream: ATruncateStream) -> str: return stream.upstream.accept(self) def visit_stream(self, stream: Stream) -> str: + source_stream = f"Stream({self.to_string(stream.source)})" + depth = len(self.methods_reprs) + 1 + if depth == 1: + return source_stream + if depth <= self.one_liner_max_depth: + return f"{source_stream}.{'.'.join(reversed(self.methods_reprs))}" methods_block = "".join( map(lambda r: f" .{r}\n", reversed(self.methods_reprs)) ) - return f"(\n Stream({self.to_string(stream.source)})\n{methods_block})" + return f"(\n {source_stream}\n{methods_block})" class ReprVisitor(ToStringVisitor): diff --git a/tests/test_stream.py b/tests/test_stream.py index 8dfa02d3..cad827b8 100644 --- 
a/tests/test_stream.py +++ b/tests/test_stream.py @@ -193,19 +193,29 @@ class CustomCallable: complex_stream.display(logging.ERROR) self.assertEqual( - """( - Stream(range(0, 256)) -)""", str(Stream(src)), + "Stream(range(0, 256))", msg="`repr` should work as expected on a stream without operation", ) self.assertEqual( + str(Stream(src).skip(10)), + "Stream(range(0, 256)).skip(10, until=None)", + msg="`repr` should return a one-liner for a stream with 1 operations", + ) + self.assertEqual( + str(Stream(src).skip(10).skip(10)), + "Stream(range(0, 256)).skip(10, until=None).skip(10, until=None)", + msg="`repr` should return a one-liner for a stream with 2 operations", + ) + self.assertEqual( + str(Stream(src).skip(10).skip(10).skip(10)), """( Stream(range(0, 256)) - .map(, concurrency=2, ordered=True, via='process') + .skip(10, until=None) + .skip(10, until=None) + .skip(10, until=None) )""", - str(Stream(src).map(lambda _: _, concurrency=2, via="process")), - msg="`repr` should work as expected on a stream with 1 operation", + msg="`repr` should go to line for a stream with 3 operations", ) self.assertEqual( str(complex_stream), From 47b35a06c5a5a603fcdf58a6b9bd212dc521a04a Mon Sep 17 00:00:00 2001 From: ebonnal Date: Mon, 12 May 2025 00:06:51 +0100 Subject: [PATCH 3/7] README: 'snippet'; shorter async operations section; ETL as example --- README.md | 196 +++++++++++++++++++++++------------------------------- 1 file changed, 82 insertions(+), 114 deletions(-) diff --git a/README.md b/README.md index 1a799397..a2b65791 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,4 @@ [![coverage](https://codecov.io/gh/ebonnal/streamable/graph/badge.svg?token=S62T0JQK9N)](https://codecov.io/gh/ebonnal/streamable) -[![testing](https://github.com/ebonnal/streamable/actions/workflows/testing.yml/badge.svg?branch=main)](https://github.com/ebonnal/streamable/actions) 
-[![typing](https://github.com/ebonnal/streamable/actions/workflows/typing.yml/badge.svg?branch=main)](https://github.com/ebonnal/streamable/actions) -[![formatting](https://github.com/ebonnal/streamable/actions/workflows/formatting.yml/badge.svg?branch=main)](https://github.com/ebonnal/streamable/actions) [![PyPI](https://github.com/ebonnal/streamable/actions/workflows/pypi.yml/badge.svg?branch=main)](https://pypi.org/project/streamable) [![Anaconda-Server Badge](https://anaconda.org/conda-forge/streamable/badges/version.svg)](https://anaconda.org/conda-forge/streamable) @@ -11,19 +8,16 @@ - 🔗 ***Fluent*** chainable lazy operations - 🔀 ***Concurrent*** via *threads*/*processes*/`async` -- 🇹 ***Typed***, fully annotated, `Stream[T]` is both an `Iterable[T]` and an `AsyncIterable[T]` -- 🛡️ ***Tested*** extensively with **Python 3.7 to 3.14** -- 🪶 ***Light***, no dependencies - +- 🇹 Fully ***Typed***, `Stream[T]` is both an `Iterable[T]` and an `AsyncIterable[T]` +- 🛡️ ***Battle-tested*** for prod, extensively tested with **Python 3.7 to 3.14**. ## 1. install +> no dependencies ```bash pip install streamable -``` -*or* -```bash +# or conda install conda-forge::streamable ``` @@ -55,10 +49,13 @@ inverses: Stream[float] = ( ## 5. iterate -Iterate over a `Stream[T]` just as you would over any other `Iterable[T]`/`AsyncIterable`, elements are processed *on-the-fly*: +Iterate over a `Stream[T]` just as you would over any other `Iterable[T]` (or `AsyncIterable`), elements are processed *on-the-fly*: ### as `Iterable[T]` + +
👀 show snippets
+ - **into data structure** ```python >>> list(inverses) @@ -87,8 +84,12 @@ Iterate over a `Stream[T]` just as you would over any other `Iterable[T]`/`Async 1.0 ``` +
+ ### as `AsyncIterable[T]` +
👀 show snippets
+ - **`async for`** ```python >>> async def main() -> List[float]: @@ -104,11 +105,12 @@ Iterate over a `Stream[T]` just as you would over any other `Iterable[T]`/`Async 1.0 ``` +
-# ↔️ **Extract-Transform-Load** +# ↔ example: Extract-Transform-Load -**ETL scripts** can benefit from the expressiveness of this library. Below is a pipeline that extracts the 67 quadruped Pokémon from the first three generations using [PokéAPI](https://pokeapi.co/) and loads them into a CSV: +Let's take an example showcasing most of the `Stream`'s operations: Below is a pipeline that extracts the 67 quadruped Pokémon from the first three generations using [PokéAPI](https://pokeapi.co/) and loads them into a CSV: ```python import csv @@ -157,7 +159,8 @@ with open("./quadruped_pokemons.csv", mode="w") as file: ## or the `async` way -Using `.amap` (the `.map`'s `async` counterpart), and `await`ing the stream to exhaust it as an `AsyncIterable[T]`: +- use the `.amap` operation: the `.map`'s `async` counterpart, see [`async` Operations](#-async-operations). +- `await` the `Stream`: runs a full iteration over it as an `AsyncIterable[T]`. ```python import asyncio @@ -211,13 +214,11 @@ asyncio.run(main()) # 📒 ***Operations*** -*A dozen expressive lazy operations and that’s it!* - ## `.map` > Applies a transformation on elements: -
👀 show example
+
👀 show snippet
```python integer_strings: Stream[str] = integers.map(str) @@ -226,16 +227,11 @@ assert list(integer_strings) == ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9 ```
-### concurrency - -> [!NOTE] -> By default, all the concurrency modes presented below yield results in the upstream order (FIFO). Set the parameter `ordered=False` to yield results as they become available (***First Done, First Out***). - ### thread-based concurrency -> Applies the transformation via `concurrency` threads: +> Applies the transformation via `concurrency` threads, yielding results in the upstream order (FIFO), set the parameter `ordered=False` to yield results as they become available (*First Done, First Out*). -
👀 show example
+
👀 show snippet
```python import requests @@ -252,16 +248,13 @@ assert list(pokemon_names) == ['bulbasaur', 'ivysaur', 'venusaur']
> [!NOTE] -> `concurrency` is also the size of the buffer containing not-yet-yielded results. **If the buffer is full, the iteration over the upstream is paused** until a result is yielded from the buffer. - -> [!TIP] -> The performance of thread-based concurrency in a CPU-bound script can be drastically improved by using a [Python 3.13+ free-threading build](https://docs.python.org/3/using/configure.html#cmdoption-disable-gil). +> **Memory-efficient**: Only `concurrent` upstream elements are pulled for processing; the next upstream element is pulled only when a result is yielded downstream. ### process-based concurrency > Set `via="process"`: -
👀 show example
+
👀 show snippet
```python if __name__ == "__main__": @@ -273,15 +266,15 @@ if __name__ == "__main__": ```
-### `async`-based concurrency: [see `.amap`](#amap) +### `async`-based concurrency: [`.amap`](#amap) -> [The `.amap` operation can apply an `async` function concurrently.](#amap) +> The [`.amap`](#amap) operation can apply an `async` function concurrently. ### "starmap" > The `star` function decorator transforms a function that takes several positional arguments into a function that takes a tuple: -
👀 show example
+
👀 show snippet
```python from streamable import star @@ -296,12 +289,11 @@ assert list(zeros) == [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
- ## `.foreach` > Applies a side effect on elements: -
👀 show example
+
👀 show snippet
```python state: List[int] = [] @@ -318,14 +310,14 @@ assert state == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] > - set the `concurrency` parameter for **thread-based concurrency** > - set `via="process"` for **process-based concurrency** > - set `ordered=False` for ***First Done First Out*** -> - [The `.aforeach` operation can apply an `async` effect concurrently.](#aforeach) +> - The [`.aforeach`](#aforeach) operation can apply an `async` effect concurrently. ## `.group` > Groups into `List`s > ... up to a given group `size`: -
👀 show example
+
👀 show snippet
```python integers_by_5: Stream[List[int]] = integers.group(size=5) @@ -336,7 +328,7 @@ assert list(integers_by_5) == [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]] > ... and/or co-groups `by` a given key: -
👀 show example
+
👀 show snippet
```python integers_by_parity: Stream[List[int]] = integers.group(by=lambda n: n % 2) @@ -347,7 +339,7 @@ assert list(integers_by_parity) == [[0, 2, 4, 6, 8], [1, 3, 5, 7, 9]] > ... and/or co-groups the elements yielded by the upstream within a given time `interval`: -
👀 show example
+
👀 show snippet
```python from datetime import timedelta @@ -365,7 +357,7 @@ assert list(integers_within_1_sec) == [[0, 1, 2], [3, 4], [5, 6], [7, 8], [9]] > [!TIP] > Combine the `size`/`by`/`interval` parameters: -
👀 show example
+
👀 show snippet
```python integers_by_parity_by_2: Stream[List[int]] = ( @@ -380,7 +372,7 @@ assert list(integers_by_parity_by_2) == [[0, 2], [1, 3], [4, 6], [5, 7], [8], [9 ## `.groupby` > Like `.group`, but groups into `(key, elements)` tuples: -
👀 show example
+
👀 show snippet
```python integers_by_parity: Stream[Tuple[str, List[int]]] = ( @@ -395,7 +387,7 @@ assert list(integers_by_parity) == [("even", [0, 2, 4, 6, 8]), ("odd", [1, 3, 5, > [!TIP] > Then *"starmap"* over the tuples: -
👀 show example
+
👀 show snippet
```python from streamable import star @@ -413,7 +405,7 @@ assert list(counts_by_parity) == [("even", 5), ("odd", 5)] > Ungroups elements assuming that they are `Iterable`s: -
👀 show example
+
👀 show snippet
```python even_then_odd_integers: Stream[int] = integers_by_parity.flatten() @@ -426,7 +418,7 @@ assert list(even_then_odd_integers) == [0, 2, 4, 6, 8, 1, 3, 5, 7, 9] > Flattens `concurrency` iterables concurrently: -
👀 show example
+
👀 show snippet
```python mixed_ones_and_zeros: Stream[int] = ( @@ -441,7 +433,7 @@ assert list(mixed_ones_and_zeros) == [0, 1, 0, 1, 0, 1, 0, 1] > Keeps only the elements that satisfy a condition: -
👀 show example
+
👀 show snippet
```python even_integers: Stream[int] = integers.filter(lambda n: n % 2 == 0) @@ -454,7 +446,7 @@ assert list(even_integers) == [0, 2, 4, 6, 8] > Removes duplicates: -
👀 show example
+
👀 show snippet
```python distinct_chars: Stream[str] = Stream("foobarfooo").distinct() @@ -465,7 +457,7 @@ assert list(distinct_chars) == ["f", "o", "b", "a", "r"] > specifying a deduplication `key`: -
👀 show example
+
👀 show snippet
```python strings_of_distinct_lengths: Stream[str] = ( @@ -480,7 +472,7 @@ assert list(strings_of_distinct_lengths) == ["a", "foo"] > [!WARNING] > During iteration, all distinct elements that are yielded are retained in memory to perform deduplication. However, you can remove only consecutive duplicates without a memory footprint by setting `consecutive_only=True`: -
👀 show example
+
👀 show snippet
```python consecutively_distinct_chars: Stream[str] = ( @@ -496,7 +488,7 @@ assert list(consecutively_distinct_chars) == ["f", "o", "b", "a", "r", "f", "o"] > Ends iteration once a given number of elements have been yielded: -
👀 show example
+
👀 show snippet
```python five_first_integers: Stream[int] = integers.truncate(5) @@ -507,7 +499,7 @@ assert list(five_first_integers) == [0, 1, 2, 3, 4] > or `when` a condition is satisfied: -
👀 show example
+
👀 show snippet
```python five_first_integers: Stream[int] = integers.truncate(when=lambda n: n == 5) @@ -522,7 +514,7 @@ assert list(five_first_integers) == [0, 1, 2, 3, 4] > Skips the first specified number of elements: -
👀 show example
+
👀 show snippet
```python integers_after_five: Stream[int] = integers.skip(5) @@ -533,7 +525,7 @@ assert list(integers_after_five) == [5, 6, 7, 8, 9] > or skips elements `until` a predicate is satisfied: -
👀 show example
+
👀 show snippet
```python integers_after_five: Stream[int] = integers.skip(until=lambda n: n >= 5) @@ -548,7 +540,7 @@ assert list(integers_after_five) == [5, 6, 7, 8, 9] > Catches a given type of exception, and optionally yields a `replacement` value: -
👀 show example
+
👀 show snippet
```python inverses: Stream[float] = ( @@ -562,7 +554,7 @@ assert list(inverses) == [float("inf"), 1.0, 0.5, 0.33, 0.25, 0.2, 0.17, 0.14, 0
> You can specify an additional `when` condition for the catch: -
👀 show example
+
👀 show snippet
```python import requests @@ -582,9 +574,9 @@ assert list(status_codes_ignoring_resolution_errors) == [200, 404] > It has an optional `finally_raise: bool` parameter to raise the first exception caught (if any) when the iteration terminates. > [!TIP] -> Apply side effects when catching an exception by integrating them into `when`: +> Leverage `when` to apply side effects on catch: -
👀 show example
+
👀 show snippet
```python errors: List[Exception] = [] @@ -608,7 +600,7 @@ assert len(errors) == len("foo") > Limits the number of yields `per` time interval: -
👀 show example
+
👀 show snippet
```python from datetime import timedelta @@ -624,7 +616,7 @@ assert list(three_integers_per_second) == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] ## `.observe` > Logs the progress of iterations: -
👀 show example
+
👀 show snippet
```python >>> assert list(integers.throttle(2, per=timedelta(seconds=1)).observe("integers")) == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] @@ -645,7 +637,7 @@ INFO: [duration=0:00:04.003852 errors=0] 10 integers yielded > [!TIP] > To mute these logs, set the logging level above `INFO`: -
👀 show example
+
👀 show snippet
```python import logging @@ -657,7 +649,7 @@ logging.getLogger("streamable").setLevel(logging.WARNING) > Concatenates streams: -
👀 show example
+
👀 show snippet
```python assert list(integers + integers) == [0, 1, 2, 3 ,4, 5, 6, 7, 8, 9, 0, 1, 2, 3 ,4, 5, 6, 7, 8, 9] @@ -667,10 +659,9 @@ assert list(integers + integers) == [0, 1, 2, 3 ,4, 5, 6, 7, 8, 9, 0, 1, 2, 3 ,4 ## `zip` -> [!TIP] > Use the standard `zip` function: -
👀 show example
+
👀 show snippet
```python from streamable import star @@ -686,14 +677,14 @@ assert list(cubes) == [0, 1, 8, 27, 64, 125, 216, 343, 512, 729] ## Shorthands for consuming the stream -> [!NOTE] -> Although consuming the stream is beyond the scope of this library, it provides two basic shorthands to trigger an iteration: + +Although consuming the stream is beyond the scope of this library, it provides two basic shorthands to trigger an iteration: ## `.count` > Iterates over the stream until exhaustion and returns the number of elements yielded: -
👀 show example
+
👀 show snippet
```python assert integers.count() == 10 @@ -703,7 +694,7 @@ assert integers.count() == 10 ## `()` > *Calling* the stream iterates over it until exhaustion and returns it: -
👀 show example
+
👀 show snippet
```python state: List[int] = [] @@ -715,9 +706,9 @@ assert state == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] ## `.pipe` -> Calls a function, passing the stream as first argument, followed by `*args/**kwargs` if any: +> Calls a function, passing the stream as first argument, followed by `*args/**kwargs` if any (inspired by the `.pipe` from [pandas](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.pipe.html) or [polars](https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.pipe.html)): -
👀 show example
+
👀 show snippet
```python import pandas as pd @@ -731,26 +722,22 @@ import pandas as pd ```
-> Inspired by the `.pipe` from [pandas](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.pipe.html) or [polars](https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.pipe.html). - --- --- # πŸ“’ ***`async` Operations*** -Operations that accept a function as an argument have an `async` counterpart, which has the same signature but accepts `async` functions instead. These `async` operations are named the same as the original ones but with an `a` prefix. +The operations accepting a function as an argument have an `async` counterpart operation, which has the same signature but accepts `async` functions instead. -> [!TIP] -> One can mix regular and `async` operations on the same `Stream`, and then consume it as a regular `Iterable` or as an `AsyncIterable`. +**inter-operability**: Both regular and `async` operations can be mixed on the same `Stream`, and it can then be consumed as regular `Iterable` or as `AsyncIterable`. -## `.amap` +## `.amap` > Applies an `async` transformation on elements: +- consume as `Iterable[T]`: -### Consume as `Iterable[T]` - -
👀 show example
+
👀 show snippet
```python import asyncio @@ -771,9 +758,9 @@ asyncio.run(http_async_client.aclose()) ```
-### Consume as `AsyncIterable[T]` +- consume as `AsyncIterable[T]`: -
👀 show example
+
👀 show snippet
```python import asyncio @@ -794,52 +781,33 @@ asyncio.run(main()) ```
+## Others -## `.aforeach` - -> Applies an `async` side effect on elements. Supports `concurrency` like `.amap`. - -## `.agroup` - -> Groups into `List`s according to an `async` grouping function. - -## `.agroupby` - -> Groups into `(key, elements)` tuples, according to an `async` grouping function. - -## `.aflatten` - -> Ungroups elements assuming that they are `AsyncIterable`s. - -> Like for `.flatten` you can set the `concurrency` parameter. - -## `.afilter` - -> Keeps only the elements that satisfy an `async` condition. +> **`.aforeach`**: Applies an `async` side effect on elements. Supports `concurrency` like `.amap`. -## `.adistinct` +> **`.agroup`**: Groups into `List`s according to an `async` grouping function. -> Removes duplicates according to an `async` deduplication `key`. +> **`.agroupby`**: Groups into `(key, elements)` tuples, according to an `async` grouping function. -## `.atruncate` +> **`.aflatten`**: Ungroups elements assuming that they are `AsyncIterable`s. Like for `.flatten` you can set the `concurrency` parameter. -> Ends iteration once a given number of elements have been yielded or `when` an `async` condition is satisfied. +> **`.afilter`**: Keeps only the elements that satisfy an `async` condition. -## `.askip` +> **`.adistinct`**: Removes duplicates according to an `async` deduplication `key`. -> Skips the specified number of elements or `until` an `async` predicate is satisfied. +> **`.atruncate`**: Ends iteration once a given number of elements have been yielded or `when` an `async` condition is satisfied. -## `.acatch` +> **`.askip`**: Skips the specified number of elements or `until` an `async` predicate is satisfied. -> Catches a given type of exception `when` an `async` condition is satisfied. +> **`.acatch`**: Catches a given type of exception `when` an `async` condition is satisfied. 
## Shorthands for consuming the stream as an `AsyncIterable[T]` -## `.acount` +### `.acount` > Iterates over the stream until exhaustion and returns the number of elements yielded: -
👀 show example
+
👀 show snippet
```python assert asyncio.run(integers.acount()) == 10 @@ -847,11 +815,11 @@ assert asyncio.run(integers.acount()) == 10
-## `await` +### `await` > *Awaiting* the stream iterates over it until exhaustion and returns it: -
👀 show example
+
👀 show snippet
```python async def test_await() -> None: @@ -873,7 +841,7 @@ asyncio.run(test_await()) > [!TIP] > If any of the operations raises an exception, you can resume the iteration after handling it: -
👀 show example
+
👀 show snippet
```python from contextlib import suppress @@ -900,7 +868,7 @@ assert collected == [0, 1, 2, 3, 5, 6, 7, 8, 9] > [!TIP] > A `Stream` can be visited via its `.accept` method: implement a custom [***visitor***](https://en.wikipedia.org/wiki/Visitor_pattern) by extending the abstract class `streamable.visitors.Visitor`: -
👀 show example
+
👀 show snippet
```python from streamable.visitors import Visitor @@ -921,7 +889,7 @@ assert depth(Stream(range(10)).map(str).foreach(print)) == 3 ## Functions > [!TIP] > The `Stream`'s methods are also exposed as functions: -
👀 show example
+
👀 show snippet
```python from streamable.functions import catch From 28a7946f21ef4a32a3dbcc98198f82f707143d8b Mon Sep 17 00:00:00 2001 From: ebonnal Date: Mon, 12 May 2025 17:40:26 +0100 Subject: [PATCH 4/7] prefer `.__next__`/`.__iter__` --- streamable/aiterators.py | 8 +-- streamable/iterators.py | 120 +++++++++++++++---------------- streamable/util/futuretools.py | 2 +- streamable/util/iterabletools.py | 6 +- streamable/visitors/iterator.py | 4 +- 5 files changed, 69 insertions(+), 71 deletions(-) diff --git a/streamable/aiterators.py b/streamable/aiterators.py index 493c7a8f..72fb85d8 100644 --- a/streamable/aiterators.py +++ b/streamable/aiterators.py @@ -156,12 +156,12 @@ class FlattenAsyncIterator(AsyncIterator[U]): def __init__(self, iterator: AsyncIterator[Iterable[U]]) -> None: validate_aiterator(iterator) self.iterator = iterator - self._current_iterator_elem: Iterator[U] = iter(tuple()) + self._current_iterator_elem: Iterator[U] = tuple().__iter__() async def __anext__(self) -> U: while True: try: - return next(self._current_iterator_elem) + return self._current_iterator_elem.__next__() except StopIteration: self._current_iterator_elem = iter_wo_stopasynciteration( await self.iterator.__anext__() @@ -275,11 +275,11 @@ def _pop_full_group(self) -> Optional[Tuple[U, List[T]]]: return None def _pop_first_group(self) -> Tuple[U, List[T]]: - first_key: U = next(iter(self._groups_by), cast(U, ...)) + first_key: U = self._groups_by.__iter__().__next__() return first_key, self._groups_by.pop(first_key) def _pop_largest_group(self) -> Tuple[U, List[T]]: - largest_group_key: Any = next(iter(self._groups_by), ...) 
+ largest_group_key: Any = self._groups_by.__iter__().__next__() for key, group in self._groups_by.items(): if len(group) > len(self._groups_by[largest_group_key]): diff --git a/streamable/iterators.py b/streamable/iterators.py index 7b684937..30385944 100644 --- a/streamable/iterators.py +++ b/streamable/iterators.py @@ -90,7 +90,7 @@ def __init__( def __next__(self) -> T: while True: try: - return next(self.iterator) + return self.iterator.__next__() except StopIteration: if self._to_be_finally_raised: try: @@ -119,7 +119,7 @@ def __init__( def __next__(self) -> T: while True: - elem = next(self.iterator) + elem = self.iterator.__next__() key = self.key(elem) if self.key else elem if key not in self._already_seen: break @@ -138,7 +138,7 @@ def __init__( def __next__(self) -> T: while True: - elem = next(self.iterator) + elem = self.iterator.__next__() key = self.key(elem) if self.key else elem if key != self._last_key: break @@ -150,14 +150,16 @@ class FlattenIterator(Iterator[U]): def __init__(self, iterator: Iterator[Iterable[U]]) -> None: validate_iterator(iterator) self.iterator = iterator - self._current_iterator_elem: Iterator[U] = iter(tuple()) + self._current_iterator_elem: Iterator[U] = tuple().__iter__() def __next__(self) -> U: while True: try: - return next(self._current_iterator_elem) + return self._current_iterator_elem.__next__() except StopIteration: - self._current_iterator_elem = iter_wo_stopiteration(next(self.iterator)) + self._current_iterator_elem = iter_wo_stopiteration( + self.iterator.__next__() + ) class AFlattenIterator(Iterator[U], GetEventLoopMixin): @@ -175,7 +177,7 @@ def __next__(self) -> U: ) except StopAsyncIteration: self._current_iterator_elem = aiter_wo_stopiteration( - next(self.iterator) + self.iterator.__next__() ) @@ -233,7 +235,7 @@ def __next__(self) -> List[T]: while len(self._current_group) < self.size and ( not self._interval_seconds_have_elapsed() or not self._current_group ): - 
self._current_group.append(next(self.iterator)) + self._current_group.append(self.iterator.__next__()) except Exception as e: if not self._current_group: raise @@ -258,7 +260,7 @@ def __init__( self._groups_by: DefaultDict[U, List[T]] = defaultdict(list) def _group_next_elem(self) -> None: - elem = next(self.iterator) + elem = self.iterator.__next__() self._groups_by[self.key(elem)].append(elem) def _pop_full_group(self) -> Optional[Tuple[U, List[T]]]: @@ -268,11 +270,11 @@ def _pop_full_group(self) -> Optional[Tuple[U, List[T]]]: return None def _pop_first_group(self) -> Tuple[U, List[T]]: - first_key: U = next(iter(self._groups_by), cast(U, ...)) + first_key: U = self._groups_by.__iter__().__next__() return first_key, self._groups_by.pop(first_key) def _pop_largest_group(self) -> Tuple[U, List[T]]: - largest_group_key: Any = next(iter(self._groups_by), ...) + largest_group_key: Any = self._groups_by.__iter__().__next__() for key, group in self._groups_by.items(): if len(group) > len(self._groups_by[largest_group_key]): @@ -309,11 +311,11 @@ def __next__(self) -> Tuple[U, List[T]]: except StopIteration: self._is_exhausted = True - return next(self) + return self.__next__() except Exception as e: self._to_be_raised = e - return next(self) + return self.__next__() class CountSkipIterator(Iterator[T]): @@ -328,11 +330,11 @@ def __init__(self, iterator: Iterator[T], count: int) -> None: def __next__(self) -> T: if not self._done_skipping: while self._n_skipped < self.count: - next(self.iterator) + self.iterator.__next__() # do not count exceptions as skipped elements self._n_skipped += 1 self._done_skipping = True - return next(self.iterator) + return self.iterator.__next__() class PredicateSkipIterator(Iterator[T]): @@ -343,10 +345,10 @@ def __init__(self, iterator: Iterator[T], until: Callable[[T], Any]) -> None: self._done_skipping = False def __next__(self) -> T: - elem = next(self.iterator) + elem = self.iterator.__next__() if not self._done_skipping: while not 
self.until(elem): - elem = next(self.iterator) + elem = self.iterator.__next__() self._done_skipping = True return elem @@ -364,10 +366,10 @@ def __init__( self._done_skipping = False def __next__(self) -> T: - elem = next(self.iterator) + elem = self.iterator.__next__() if not self._done_skipping: while self._n_skipped < self.count and not self.until(elem): - elem = next(self.iterator) + elem = self.iterator.__next__() # do not count exceptions as skipped elements self._n_skipped += 1 self._done_skipping = True @@ -385,7 +387,7 @@ def __init__(self, iterator: Iterator[T], count: int) -> None: def __next__(self) -> T: if self._current_count == self.count: raise StopIteration() - elem = next(self.iterator) + elem = self.iterator.__next__() self._current_count += 1 return elem @@ -400,7 +402,7 @@ def __init__(self, iterator: Iterator[T], when: Callable[[T], Any]) -> None: def __next__(self) -> T: if self._satisfied: raise StopIteration() - elem = next(self.iterator) + elem = self.iterator.__next__() if self.when(elem): self._satisfied = True raise StopIteration() @@ -436,7 +438,7 @@ def _log(self) -> None: def __next__(self) -> T: try: - elem = next(self.iterator) + elem = self.iterator.__next__() self._n_nexts += 1 self._n_yields += 1 return elem @@ -472,7 +474,7 @@ def __init__( def safe_next(self) -> Tuple[Optional[T], Optional[Exception]]: try: - return next(self.iterator), None + return self.iterator.__next__(), None except StopIteration: raise except Exception as e: @@ -513,7 +515,7 @@ def __init__( self.iterator = iterator def __next__(self) -> T: - elem = next(self.iterator) + elem = self.iterator.__next__() if isinstance(elem, self.ExceptionContainer): raise elem.exception return elem @@ -566,13 +568,17 @@ def __iter__(self) -> Iterator[Union[U, _RaisingIterator.ExceptionContainer]]: # queue tasks up to buffersize with suppress(StopIteration): while len(future_results) < self.buffersize: - future_results.add_future(self._launch_task(next(self.iterator))) + 
future_results.add_future( + self._launch_task(self.iterator.__next__()) + ) # wait, queue, yield while future_results: - result = next(future_results) + result = future_results.__next__() with suppress(StopIteration): - future_results.add_future(self._launch_task(next(self.iterator))) + future_results.add_future( + self._launch_task(self.iterator.__next__()) + ) yield result @@ -638,16 +644,14 @@ def __init__( via: "Literal['thread', 'process']", ) -> None: super().__init__( - iter( - _ConcurrentMapIterable( - iterator, - transformation, - concurrency, - buffersize, - ordered, - via, - ) - ) + _ConcurrentMapIterable( + iterator, + transformation, + concurrency, + buffersize, + ordered, + via, + ).__iter__() ) @@ -701,14 +705,12 @@ def __init__( ordered: bool, ) -> None: super().__init__( - iter( - _ConcurrentAMapIterable( - iterator, - transformation, - buffersize, - ordered, - ) - ) + _ConcurrentAMapIterable( + iterator, + transformation, + buffersize, + ordered, + ).__iter__() ) @@ -751,7 +753,7 @@ def __iter__(self) -> Iterator[Union[T, _RaisingIterator.ExceptionContainer]]: while len(iterator_and_future_pairs) < self.buffersize: if not iterator_to_queue: try: - iterable = next(self.iterables_iterator) + iterable = self.iterables_iterator.__next__() except StopIteration: break try: @@ -776,13 +778,11 @@ def __init__( buffersize: int, ) -> None: super().__init__( - iter( - _ConcurrentFlattenIterable( - iterables_iterator, - concurrency, - buffersize, - ) - ) + _ConcurrentFlattenIterable( + iterables_iterator, + concurrency, + buffersize, + ).__iter__() ) @@ -828,7 +828,7 @@ def __iter__(self) -> Iterator[Union[T, _RaisingIterator.ExceptionContainer]]: while len(iterator_and_future_pairs) < self.buffersize: if not iterator_to_queue: try: - iterable = next(self.iterables_iterator) + iterable = self.iterables_iterator.__next__() except StopIteration: break try: @@ -859,11 +859,9 @@ def __init__( buffersize: int, ) -> None: super().__init__( - iter( - 
_ConcurrentAFlattenIterable( - iterables_iterator, - concurrency, - buffersize, - ) - ) + _ConcurrentAFlattenIterable( + iterables_iterator, + concurrency, + buffersize, + ).__iter__() ) diff --git a/streamable/util/futuretools.py b/streamable/util/futuretools.py index 278ce5fb..9b7af7d8 100644 --- a/streamable/util/futuretools.py +++ b/streamable/util/futuretools.py @@ -20,7 +20,7 @@ class FutureResultCollection(Iterator[T], AsyncIterator[T], Sized, ABC): def add_future(self, future: "Future[T]") -> None: ... async def __anext__(self) -> T: - return next(self) + return self.__next__() class DequeFutureResultCollection(FutureResultCollection[T]): diff --git a/streamable/util/iterabletools.py b/streamable/util/iterabletools.py index cbb1228c..6d46ef66 100644 --- a/streamable/util/iterabletools.py +++ b/streamable/util/iterabletools.py @@ -25,7 +25,7 @@ def __init__(self, iterable: Iterable[T]): self.iterable = iterable def __iter__(self) -> Iterator[T]: - return iter(self.iterable) + return self.iterable.__iter__() def __aiter__(self) -> AsyncIterator[T]: return SyncToAsyncIterator(self.iterable) @@ -36,11 +36,11 @@ def __aiter__(self) -> AsyncIterator[T]: class SyncToAsyncIterator(AsyncIterator[T]): def __init__(self, iterator: Iterable[T]): - self.iterator: Iterator[T] = iter(iterator) + self.iterator: Iterator[T] = iterator.__iter__() async def __anext__(self) -> T: try: - return next(self.iterator) + return self.iterator.__next__() except StopIteration as e: raise StopAsyncIteration() from e diff --git a/streamable/visitors/iterator.py b/streamable/visitors/iterator.py index 34167252..bb63bdb0 100644 --- a/streamable/visitors/iterator.py +++ b/streamable/visitors/iterator.py @@ -216,13 +216,13 @@ def visit_atruncate_stream(self, stream: ATruncateStream[T]) -> Iterator[T]: def visit_stream(self, stream: Stream[T]) -> Iterator[T]: if isinstance(stream.source, Iterable): - return iter(stream.source) + return stream.source.__iter__() if isinstance(stream.source, 
AsyncIterable): return async_to_sync_iter(stream.source) if callable(stream.source): iterable = stream.source() if isinstance(iterable, Iterable): - return iter(iterable) + return iterable.__iter__() if isinstance(iterable, AsyncIterable): return async_to_sync_iter(iterable) raise TypeError( From 2564f921c8feb57e720b1ffa95626e660042273a Mon Sep 17 00:00:00 2001 From: ebonnal Date: Tue, 13 May 2025 14:31:04 +0100 Subject: [PATCH 5/7] README: shorten, merge the "`async` Operation" section into the regular one --- README.md | 246 ++++++++++++++++++------------------------- tests/test_readme.py | 2 +- 2 files changed, 104 insertions(+), 144 deletions(-) diff --git a/README.md b/README.md index a2b65791..01aea391 100644 --- a/README.md +++ b/README.md @@ -8,18 +8,19 @@ - πŸ”— ***Fluent*** chainable lazy operations - πŸ”€ ***Concurrent*** via *threads*/*processes*/`async` -- πŸ‡Ή Fully ***Typed***, `Stream[T]` is both an `Iterable[T]` and an `AsyncIterable[T]` +- πŸ‡Ή Fully ***Typed***, `Stream[T]` is an `Iterable[T]` (and an `AsyncIterable[T]`) - πŸ›‘οΈ ***Battle-tested*** for prod, extensively tested with **Python 3.7 to 3.14**. ## 1. install -> no dependencies -```bash -pip install streamable -# or -conda install conda-forge::streamable -``` +`pip install streamable` + +or + +`conda install conda-forge::streamable` + +No dependencies. ## 2. import @@ -49,10 +50,10 @@ inverses: Stream[float] = ( ## 5. iterate -Iterate over a `Stream[T]` just as you would over any other `Iterable[T]` (or `AsyncIterable`), elements are processed *on-the-fly*: +Iterate over a `Stream[T]` just as you would over any other `Iterable[T]` (or `AsyncIterable[T]`), elements are processed *on-the-fly*: -### as `Iterable[T]` +### as an `Iterable[T]`
👀 show snippets
@@ -86,7 +87,7 @@ Iterate over a `Stream[T]` just as you would over any other `Iterable[T]` (or `A
-### as `AsyncIterable[T]` +### as an `AsyncIterable[T]`
👀 show snippets
@@ -159,8 +160,7 @@ with open("./quadruped_pokemons.csv", mode="w") as file: ## or the `async` way -- use the `.amap` operation: the `.map`'s `async` counterpart, see [`async` Operations](#-async-operations). -- `await` the `Stream`: runs a full iteration over it as an `AsyncIterable[T]`. +Use the `.amap` operation and `await` the `Stream`: ```python import asyncio @@ -182,7 +182,7 @@ async def main() -> None: Stream(itertools.count(1)) # Limits to 16 requests per second to be friendly to our fellow PokΓ©API devs .throttle(16, per=timedelta(seconds=1)) - # GETs pokemons via 8 concurrent asyncio coroutines + # GETs pokemons via 8 concurrent coroutines .map(lambda poke_id: f"https://pokeapi.co/api/v2/pokemon-species/{poke_id}") .amap(http_async_client.get, concurrency=8) .foreach(httpx.Response.raise_for_status) @@ -214,7 +214,12 @@ asyncio.run(main()) # πŸ“’ ***Operations*** -## `.map` +A dozen expressive lazy operations and that's it. + +> [!NOTE] +> **`async` counterparts:** The operations accepting a function as an argument have an `async` counterpart operation (same name but with an "`a`" prefix), which has the same signature but accepts `async` functions instead. Both regular and `async` operations can be mixed on the same `Stream`, and it can then be consumed as regular `Iterable` or as `AsyncIterable`. + +## `.map`/`.amap` > Applies a transformation on elements: @@ -248,7 +253,7 @@ assert list(pokemon_names) == ['bulbasaur', 'ivysaur', 'venusaur']
> [!NOTE] -> **Memory-efficient**: Only `concurrent` upstream elements are pulled for processing; the next upstream element is pulled only when a result is yielded downstream. +> **Memory-efficient**: Only `concurrency` upstream elements are pulled for processing; the next upstream element is pulled only when a result is yielded downstream. ### process-based concurrency @@ -266,9 +271,52 @@ if __name__ == "__main__": ```
-### `async`-based concurrency: [`.amap`](#amap) +### `async`-based concurrency: `.amap` + +> `.amap` can apply an `async` transformation concurrently. + +
👀 show snippet
+ +- consumed as an `Iterable[T]`: + +```python +import asyncio +import httpx + +http_async_client = httpx.AsyncClient() + +pokemon_names: Stream[str] = ( + Stream(range(1, 4)) + .map(lambda i: f"https://pokeapi.co/api/v2/pokemon-species/{i}") + .amap(http_async_client.get, concurrency=3) + .map(httpx.Response.json) + .map(lambda poke: poke["name"]) +) + +assert list(pokemon_names) == ['bulbasaur', 'ivysaur', 'venusaur'] +asyncio.run(http_async_client.aclose()) +``` + +- consumed as an `AsyncIterable[T]`: + +```python +import asyncio +import httpx + +async def main() -> None: + async with httpx.AsyncClient() as http_async_client: + pokemon_names: Stream[str] = ( + Stream(range(1, 4)) + .map(lambda i: f"https://pokeapi.co/api/v2/pokemon-species/{i}") + .amap(http_async_client.get, concurrency=3) + .map(httpx.Response.json) + .map(lambda poke: poke["name"]) + ) + assert [name async for name in pokemon_names] == ['bulbasaur', 'ivysaur', 'venusaur'] -> The [`.amap`](#amap) operation can apply an `async` function concurrently. +asyncio.run(main()) +``` +
### "starmap" @@ -289,7 +337,7 @@ assert list(zeros) == [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
-## `.foreach` +## `.foreach` / `.aforeach` > Applies a side effect on elements: @@ -310,9 +358,9 @@ assert state == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] > - set the `concurrency` parameter for **thread-based concurrency** > - set `via="process"` for **process-based concurrency** > - set `ordered=False` for ***First Done First Out*** -> - The [`.aforeach`](#aforeach) operation can apply an `async` effect concurrently. +> - The `.aforeach` operation can apply an `async` effect concurrently. -## `.group` +## `.group` / `.agroup` > Groups into `List`s @@ -369,7 +417,7 @@ assert list(integers_by_parity_by_2) == [[0, 2], [1, 3], [4, 6], [5, 7], [8], [9 ```
-## `.groupby` +## `.groupby` / `.agroupby` > Like `.group`, but groups into `(key, elements)` tuples:
👀 show snippet
@@ -401,9 +449,9 @@ assert list(counts_by_parity) == [("even", 5), ("odd", 5)] ```
-## `.flatten` +## `.flatten` / `.aflatten` -> Ungroups elements assuming that they are `Iterable`s: +> Ungroups elements assuming that they are `Iterable`s (or `AsyncIterable`s for `.aflatten`):
👀 show snippet
@@ -416,7 +464,7 @@ assert list(even_then_odd_integers) == [0, 2, 4, 6, 8, 1, 3, 5, 7, 9] ### thread-based concurrency -> Flattens `concurrency` iterables concurrently: +> Concurrently flattens `concurrency` iterables via threads (or via coroutines for `.aflatten`):
👀 show snippet
@@ -429,7 +477,7 @@ assert list(mixed_ones_and_zeros) == [0, 1, 0, 1, 0, 1, 0, 1] ```
-## `.filter` +## `.filter` / `.afilter` > Keeps only the elements that satisfy a condition: @@ -442,7 +490,7 @@ assert list(even_integers) == [0, 2, 4, 6, 8] ```
-## `.distinct` +## `.distinct` / `.adistinct` > Removes duplicates: @@ -484,7 +532,7 @@ assert list(consecutively_distinct_chars) == ["f", "o", "b", "a", "r", "f", "o"] ```
-## `.truncate` +## `.truncate` / `.atruncate` > Ends iteration once a given number of elements have been yielded: @@ -510,7 +558,7 @@ assert list(five_first_integers) == [0, 1, 2, 3, 4] > If both `count` and `when` are set, truncation occurs as soon as either condition is met. -## `.skip` +## `.skip` / `.askip` > Skips the first specified number of elements: @@ -536,7 +584,7 @@ assert list(integers_after_five) == [5, 6, 7, 8, 9] > If both `count` and `until` are set, skipping stops as soon as either condition is met. -## `.catch` +## `.catch` / `.acatch` > Catches a given type of exception, and optionally yields a `replacement` value: @@ -680,9 +728,9 @@ assert list(cubes) == [0, 1, 8, 27, 64, 125, 216, 343, 512, 729] Although consuming the stream is beyond the scope of this library, it provides two basic shorthands to trigger an iteration: -## `.count` +## `.count` / `.acount` -> Iterates over the stream until exhaustion and returns the number of elements yielded: +> `.count` iterates over the stream until exhaustion and returns the number of elements yielded:
👀 show snippet
@@ -691,149 +739,61 @@ assert integers.count() == 10 ```
-## `()` +> The `.acount` (`async` method) iterates over the stream as an `AsyncIterable` until exhaustion and returns the number of elements yielded: -> *Calling* the stream iterates over it until exhaustion and returns it:
👀 show snippet
```python -state: List[int] = [] -appending_integers: Stream[int] = integers.foreach(state.append) -assert appending_integers() is appending_integers -assert state == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] +assert asyncio.run(integers.acount()) == 10 ``` -
- -## `.pipe` - -> Calls a function, passing the stream as first argument, followed by `*args/**kwargs` if any (inspired by the `.pipe` from [pandas](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.pipe.html) or [polars](https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.pipe.html)): - -
👀 show snippet
- -```python -import pandas as pd -( - integers - .observe("ints") - .pipe(pd.DataFrame, columns=["integer"]) - .to_csv("integers.csv", index=False) -) -```
---- ---- - -# πŸ“’ ***`async` Operations*** - -The operations accepting a function as an argument have an `async` counterpart operation, which has the same signature but accepts `async` functions instead. - -**inter-operability**: Both regular and `async` operations can be mixed on the same `Stream`, and it can then be consumed as regular `Iterable` or as `AsyncIterable`. - -## `.amap` - -> Applies an `async` transformation on elements: - -- consume as `Iterable[T]`: +## `()` / `await` +> *Calling* the stream iterates over it until exhaustion, and returns it:
👀 show snippet
```python -import asyncio -import httpx - -http_async_client = httpx.AsyncClient() - -pokemon_names: Stream[str] = ( - Stream(range(1, 4)) - .map(lambda i: f"https://pokeapi.co/api/v2/pokemon-species/{i}") - .amap(http_async_client.get, concurrency=3) - .map(httpx.Response.json) - .map(lambda poke: poke["name"]) -) - -assert list(pokemon_names) == ['bulbasaur', 'ivysaur', 'venusaur'] -asyncio.run(http_async_client.aclose()) +state: List[int] = [] +appending_integers: Stream[int] = integers.foreach(state.append) +assert appending_integers() is appending_integers +assert state == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] ```
-- consume as `AsyncIterable[T]`: +> *Awaiting* the stream iterates over it as an `AsyncIterable` until exhaustion, and returns it:
👀 show snippet
```python -import asyncio -import httpx - -async def main() -> None: - async with httpx.AsyncClient() as http_async_client: - pokemon_names: Stream[str] = ( - Stream(range(1, 4)) - .map(lambda i: f"https://pokeapi.co/api/v2/pokemon-species/{i}") - .amap(http_async_client.get, concurrency=3) - .map(httpx.Response.json) - .map(lambda poke: poke["name"]) - ) - assert [name async for name in pokemon_names] == ['bulbasaur', 'ivysaur', 'venusaur'] - -asyncio.run(main()) +async def test_await() -> None: + state: List[int] = [] + appending_integers: Stream[int] = integers.foreach(state.append) + appending_integers is await appending_integers + assert state == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] +asyncio.run(test_await()) ```
-## Others - -> **`.aforeach`**: Applies an `async` side effect on elements. Supports `concurrency` like `.amap`. - -> **`.agroup`**: Groups into `List`s according to an `async` grouping function. - -> **`.agroupby`**: Groups into `(key, elements)` tuples, according to an `async` grouping function. - -> **`.aflatten`**: Ungroups elements assuming that they are `AsyncIterable`s. Like for `.flatten` you can set the `concurrency` parameter. - -> **`.afilter`**: Keeps only the elements that satisfy an `async` condition. - -> **`.adistinct`**: Removes duplicates according to an `async` deduplication `key`. - -> **`.atruncate`**: Ends iteration once a given number of elements have been yielded or `when` an `async` condition is satisfied. - -> **`.askip`**: Skips the specified number of elements or `until` an `async` predicate is satisfied. - -> **`.acatch`**: Catches a given type of exception `when` an `async` condition is satisfied. - -## Shorthands for consuming the stream as an `AsyncIterable[T]` - -### `.acount` +## `.pipe` -> Iterates over the stream until exhaustion and returns the number of elements yielded: +> Calls a function, passing the stream as first argument, followed by `*args/**kwargs` if any (inspired by the `.pipe` from [pandas](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.pipe.html) or [polars](https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.pipe.html)):
👀 show snippet
```python -assert asyncio.run(integers.acount()) == 10 -``` -
- - -### `await` - -> *Awaiting* the stream iterates over it until exhaustion and returns it: - -
👀 show snippet
+import pandas as pd -```python -async def test_await() -> None: - state: List[int] = [] - appending_integers: Stream[int] = integers.foreach(state.append) - appending_integers is await appending_integers - assert state == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] -asyncio.run(test_await()) +( + integers + .observe("ints") + .pipe(pd.DataFrame, columns=["integer"]) + .to_csv("integers.csv", index=False) +) ```
---- ---- - # 💡 Notes ## Exceptions are not terminating the iteration diff --git a/tests/test_readme.py b/tests/test_readme.py index b5627d41..cf748176 100644 --- a/tests/test_readme.py +++ b/tests/test_readme.py @@ -347,7 +347,7 @@ async def main() -> None: Stream(itertools.count(1)) # Limits to 16 requests per second to be friendly to our fellow PokéAPI devs .throttle(16, per=timedelta(seconds=1)) - # GETs pokemons via 8 concurrent asyncio coroutines + # GETs pokemons via 8 concurrent coroutines .map(lambda poke_id: f"https://pokeapi.co/api/v2/pokemon-species/{poke_id}") .amap(http_async_client.get, concurrency=8) .foreach(httpx.Response.raise_for_status) From 8d968490c0c08d50cee45735b9500ca8e9068124 Mon Sep 17 00:00:00 2001 From: Enzo Bonnal Date: Tue, 13 May 2025 23:11:00 +0100 Subject: [PATCH 6/7] README: clarify Co-authored-by: laurylopes --- README.md | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 01aea391..cccbf145 100644 --- a/README.md +++ b/README.md @@ -109,9 +109,11 @@ Iterate over a `Stream[T]` just as you would over any other `Iterable[T]` (or `A
-# ↔ example: Extract-Transform-Load +# ↔ Showcase: Extract-Transform-Load -Let's take an example showcasing most of the `Stream`'s operations: Below is a pipeline that extracts the 67 quadruped Pokémon from the first three generations using [PokéAPI](https://pokeapi.co/) and loads them into a CSV: +Let's take an example showcasing most of the `Stream`'s operations: + +This script extracts the 67 quadruped Pokémon from the first three generations using [PokéAPI](https://pokeapi.co/) and loads them into a CSV: ```python import csv @@ -217,7 +219,7 @@ asyncio.run(main()) A dozen expressive lazy operations and that's it. > [!NOTE] -> **`async` counterparts:** The operations accepting a function as an argument have an `async` counterpart operation (same name but with an "`a`" prefix), which has the same signature but accepts `async` functions instead. Both regular and `async` operations can be mixed on the same `Stream`, and it can then be consumed as regular `Iterable` or as `AsyncIterable`. +> **`async` twin operations:** Each operation that takes a function also has an async version (same name with an “`a`” prefix) that accepts `async` functions. You can mix both types of operations on the same `Stream`, which can be used as either an `Iterable` or an `AsyncIterable`. ## `.map`/`.amap` From f7cd1b64fccf98642d663a27d16eda132c8e1655 Mon Sep 17 00:00:00 2001 From: ebonnal Date: Thu, 15 May 2025 15:00:14 +0100 Subject: [PATCH 7/7] 1.6.0a2 --- version.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/version.py b/version.py index 3ca80c9b..58329072 100644 --- a/version.py +++ b/version.py @@ -1,2 +1,2 @@ -# print CHANGELOG: git log --oneline -- version.py | grep -v '\-rc' -__version__ = "1.6.0a1" +# print CHANGELOG: git log --oneline -- version.py +__version__ = "1.6.0a2"