diff --git a/.github/workflows/pypi.yml b/.github/workflows/pypi.yml index 02bed3e0..09157371 100644 --- a/.github/workflows/pypi.yml +++ b/.github/workflows/pypi.yml @@ -4,6 +4,9 @@ on: push: branches: - main + - alpha + - beta + - rc paths: - 'version.py' diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 0624e37c..06a6cebc 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -27,7 +27,7 @@ is relatively close to Applying an operation simply performs argument validation and returns a new instance of a `Stream` sub-class corresponding to the operation. Defining a stream is basically constructing a composite structure where each node is a `Stream` instance: The above `events` variable holds a `CatchStream[int]` instance, whose `.upstream` attribute points to a `TruncateStream[int]` instance, whose `.upstream` attribute points to a `ForeachStream[int]` instance, whose `.upstream` points to a `Stream[int]` instance. Each node's `.source` attribute points to the same `range(10)`. ## Visitor Pattern -Each node in this composite structure exposes an `.accept` method enabling traversal by a visitor. Both `Stream.__iter__` and `Stream.__repr__` rely on visitor classes defined in the `streamable.visitors` package. +Each node in this composite structure exposes an `.accept` method enabling traversal by a visitor. `Stream.__iter__`/`Stream.__aiter__`/`Stream.__repr__` rely on visitor classes defined in the `streamable.visitors` package. ## Decorator Pattern A `Stream[T]` both inherits from `Iterable[T]` and holds an `Iterable[T]` as its `.source`: when you instantiate a stream from an iterable you decorate it with a fluent interface. diff --git a/README.md b/README.md index 4bf64c0b..cccbf145 100644 --- a/README.md +++ b/README.md @@ -1,31 +1,26 @@ [![coverage](https://codecov.io/gh/ebonnal/streamable/graph/badge.svg?token=S62T0JQK9N)](https://codecov.io/gh/ebonnal/streamable) -[![testing](https://github.com/ebonnal/streamable/actions/workflows/testing.yml/badge.svg?branch=main)](https://github.com/ebonnal/streamable/actions) -[![typing](https://github.com/ebonnal/streamable/actions/workflows/typing.yml/badge.svg?branch=main)](https://github.com/ebonnal/streamable/actions) -[![formatting](https://github.com/ebonnal/streamable/actions/workflows/formatting.yml/badge.svg?branch=main)](https://github.com/ebonnal/streamable/actions) [![PyPI](https://github.com/ebonnal/streamable/actions/workflows/pypi.yml/badge.svg?branch=main)](https://pypi.org/project/streamable) [![Anaconda-Server Badge](https://anaconda.org/conda-forge/streamable/badges/version.svg)](https://anaconda.org/conda-forge/streamable) # ΰΌ„ `streamable` -### *Pythonic Stream-like manipulation of iterables* +### *Pythonic Stream-like manipulation of (async) iterables* - πŸ”— ***Fluent*** chainable lazy operations -- πŸ”€ ***Concurrent*** via *threads*/*processes*/`asyncio` -- πŸ‡Ή ***Typed***, fully annotated, `Stream[T]` is an `Iterable[T]` -- πŸ›‘οΈ ***Tested*** extensively with **Python 3.7 to 3.14** -- πŸͺΆ ***Light***, no dependencies - +- πŸ”€ ***Concurrent*** via *threads*/*processes*/`async` +- πŸ‡Ή Fully ***Typed***, `Stream[T]` is an `Iterable[T]` (and an `AsyncIterable[T]`) +- πŸ›‘οΈ ***Battle-tested*** for prod, extensively tested with **Python 3.7 to 3.14**. ## 1. install -```bash -pip install streamable -``` -*or* -```bash -conda install conda-forge::streamable -``` +`pip install streamable` + +or + +`conda install conda-forge::streamable` + +No dependencies. ## 2. 
import @@ -35,7 +30,7 @@ from streamable import Stream ## 3. init -Create a `Stream[T]` *decorating* an `Iterable[T]`: +Create a `Stream[T]` *decorating* an `Iterable[T]` (or an `AsyncIterable[T]`): ```python integers: Stream[int] = Stream(range(10)) @@ -55,9 +50,14 @@ inverses: Stream[float] = ( ## 5. iterate -Iterate over a `Stream[T]` just as you would over any other `Iterable[T]`, elements are processed *on-the-fly*: +Iterate over a `Stream[T]` just as you would over any other `Iterable[T]` (or `AsyncIterable[T]`); elements are processed *on-the-fly*: -- **collect** + +### as an `Iterable[T]` + +
πŸ‘€ show snippets
+ +- **into data structure** ```python >>> list(inverses) [1.0, 0.5, 0.33, 0.25, 0.2, 0.17, 0.14, 0.12, 0.11] @@ -65,7 +65,13 @@ Iterate over a `Stream[T]` just as you would over any other `Iterable[T]`, eleme {0.5, 1.0, 0.2, 0.33, 0.25, 0.17, 0.14, 0.12, 0.11} ``` -- **reduce** +- **`for`** +```python +>>> [inverse for inverse in inverses] +[1.0, 0.5, 0.33, 0.25, 0.2, 0.17, 0.14, 0.12, 0.11] +``` + +- **`reduce`** ```python >>> sum(inverses) 2.82 @@ -73,27 +79,153 @@ Iterate over a `Stream[T]` just as you would over any other `Iterable[T]`, eleme >>> reduce(..., inverses) ``` -- **loop** +- **`iter`/`next`** ```python ->>> for inverse in inverses: ->>> ... +>>> next(iter(inverses)) +1.0 ``` -- **next** +
+ +### as an `AsyncIterable[T]` + +
πŸ‘€ show snippets
+ +- **`async for`** ```python ->>> next(iter(inverses)) +>>> async def main() -> List[float]: +>>> return [inverse async for inverse in inverses] + +>>> asyncio.run(main()) +[1.0, 0.5, 0.33, 0.25, 0.2, 0.17, 0.14, 0.12, 0.11] +``` + +- **`aiter`/`anext`** +```python +>>> asyncio.run(anext(aiter(inverses))) # before 3.10: inverses.__aiter__().__anext__() 1.0 ``` +
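+A minimal sketch of the *on-the-fly* processing (toy `side_effects` list for illustration): defining a stream runs nothing; an element is only processed when the iteration pulls it:
+
+```python
+side_effects: List[int] = []
+
+lazy_integers: Stream[int] = Stream(range(3)).foreach(side_effects.append)
+
+# defining the stream above processed nothing:
+assert side_effects == []
+
+# pulling a single element processes only that element:
+assert next(iter(lazy_integers)) == 0
+assert side_effects == [0]
+```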
+ + +# ↔ Showcase: Extract-Transform-Load + +Let's take an example showcasing most of the `Stream`'s operations: + +This script extracts the 67 quadruped PokΓ©mon from the first three generations using [PokΓ©API](https://pokeapi.co/) and loads them into a CSV: + +```python +import csv +from datetime import timedelta +import itertools +import requests +from streamable import Stream + +with open("./quadruped_pokemons.csv", mode="w") as file: + fields = ["id", "name", "is_legendary", "base_happiness", "capture_rate"] + writer = csv.DictWriter(file, fields, extrasaction='ignore') + writer.writeheader() + + pipeline: Stream = ( + # Infinite Stream[int] of Pokemon ids starting from PokΓ©mon #1: Bulbasaur + Stream(itertools.count(1)) + # Limits to 16 requests per second to be friendly to our fellow PokΓ©API devs + .throttle(16, per=timedelta(seconds=1)) + # GETs pokemons concurrently using a pool of 8 threads + .map(lambda poke_id: f"https://pokeapi.co/api/v2/pokemon-species/{poke_id}") + .map(requests.get, concurrency=8) + .foreach(requests.Response.raise_for_status) + .map(requests.Response.json) + # Stops the iteration when reaching the 1st pokemon of the 4th generation + .truncate(when=lambda poke: poke["generation"]["name"] == "generation-iv") + .observe("pokemons") + # Keeps only quadruped Pokemons + .filter(lambda poke: poke["shape"]["name"] == "quadruped") + .observe("quadruped pokemons") + # Catches errors due to None "generation" or "shape" + .catch( + TypeError, + when=lambda error: str(error) == "'NoneType' object is not subscriptable" + ) + # Writes a batch of pokemons every 5 seconds to the CSV file + .group(interval=timedelta(seconds=5)) + .foreach(writer.writerows) + .flatten() + .observe("written pokemons") + # Catches exceptions and raises the 1st one at the end of the iteration + .catch(Exception, finally_raise=True) + ) + + pipeline() +``` + +## or the `async` way + +Use the `.amap` operation and `await` the `Stream`: + +```python +import asyncio +import csv +from datetime import timedelta +import itertools +import httpx +from streamable import Stream + +async def main() -> None: + with open("./quadruped_pokemons.csv", mode="w") as file: + fields = ["id", "name", "is_legendary", "base_happiness", "capture_rate"] + writer = csv.DictWriter(file, fields, extrasaction='ignore') + writer.writeheader() + + async with httpx.AsyncClient() as http_async_client: + pipeline: Stream = ( + # Infinite Stream[int] of Pokemon ids starting from PokΓ©mon #1: Bulbasaur + Stream(itertools.count(1)) + # Limits to 16 requests per second to be friendly to our fellow PokΓ©API devs + .throttle(16, per=timedelta(seconds=1)) + # GETs pokemons via 8 concurrent coroutines + .map(lambda poke_id: f"https://pokeapi.co/api/v2/pokemon-species/{poke_id}") + .amap(http_async_client.get, concurrency=8) + .foreach(httpx.Response.raise_for_status) + .map(httpx.Response.json) + # Stops the iteration when reaching the 1st pokemon of the 4th generation + .truncate(when=lambda poke: poke["generation"]["name"] == "generation-iv") + .observe("pokemons") + # Keeps only quadruped Pokemons + .filter(lambda poke: poke["shape"]["name"] == "quadruped") + .observe("quadruped pokemons") + # Catches errors due to None "generation" or "shape" + .catch( + TypeError, + when=lambda error: str(error) == "'NoneType' object is not subscriptable" + ) + # Writes a batch of pokemons every 5 seconds to the CSV file + .group(interval=timedelta(seconds=5)) + .foreach(writer.writerows) + .flatten() + .observe("written pokemons") + # Catches 
exceptions and raises the 1st one at the end of the iteration + .catch(Exception, finally_raise=True) + ) + + await pipeline + +asyncio.run(main()) +``` + # πŸ“’ ***Operations*** -*A dozen expressive lazy operations and that’s it!* +A dozen expressive lazy operations and that's it. -# `.map` +> [!NOTE] +> **`async` twin operations:** Each operation that takes a function also has an async version (same name with an β€œ`a`” prefix) that accepts `async` functions. You can mix both types of operations on the same `Stream`, which can be used as either an `Iterable` or an `AsyncIterable`. + +## `.map`/`.amap` > Applies a transformation on elements: -
πŸ‘€ show example
+
πŸ‘€ show snippet
```python integer_strings: Stream[str] = integers.map(str) @@ -102,16 +234,11 @@ assert list(integer_strings) == ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9 ```
-## concurrency - -> [!NOTE] -> By default, all the concurrency modes presented below yield results in the upstream order (FIFO). Set the parameter `ordered=False` to yield results as they become available (***First Done, First Out***). - ### thread-based concurrency -> Applies the transformation via `concurrency` threads: +> Applies the transformation via `concurrency` threads, yielding results in the upstream order (FIFO); set the parameter `ordered=False` to yield results as they become available (*First Done, First Out*). -
πŸ‘€ show example
+
πŸ‘€ show snippet
```python import requests @@ -128,16 +255,13 @@ assert list(pokemon_names) == ['bulbasaur', 'ivysaur', 'venusaur']
> [!NOTE] -> `concurrency` is also the size of the buffer containing not-yet-yielded results. **If the buffer is full, the iteration over the upstream is paused** until a result is yielded from the buffer. - -> [!TIP] -> The performance of thread-based concurrency in a CPU-bound script can be drastically improved by using a [Python 3.13+ free-threading build](https://docs.python.org/3/using/configure.html#cmdoption-disable-gil). +> **Memory-efficient**: Only `concurrency` upstream elements are pulled for processing; the next upstream element is pulled only when a result is yielded downstream. ### process-based concurrency > Set `via="process"`: -
πŸ‘€ show example
+
πŸ‘€ show snippet
```python if __name__ == "__main__": @@ -149,15 +273,17 @@ if __name__ == "__main__": ```
-### `asyncio`-based concurrency +### `async`-based concurrency: `.amap` + +> `.amap` can apply an `async` transformation concurrently. -> The sibling operation `.amap` applies an async function: +
πŸ‘€ show snippet
-
πŸ‘€ show example
+- consumed as an `Iterable[T]`: ```python -import httpx import asyncio +import httpx http_async_client = httpx.AsyncClient() @@ -170,15 +296,35 @@ pokemon_names: Stream[str] = ( ) assert list(pokemon_names) == ['bulbasaur', 'ivysaur', 'venusaur'] -asyncio.get_event_loop().run_until_complete(http_async_client.aclose()) +asyncio.run(http_async_client.aclose()) +``` + +- consumed as an `AsyncIterable[T]`: + +```python +import asyncio +import httpx + +async def main() -> None: + async with httpx.AsyncClient() as http_async_client: + pokemon_names: Stream[str] = ( + Stream(range(1, 4)) + .map(lambda i: f"https://pokeapi.co/api/v2/pokemon-species/{i}") + .amap(http_async_client.get, concurrency=3) + .map(httpx.Response.json) + .map(lambda poke: poke["name"]) + ) + assert [name async for name in pokemon_names] == ['bulbasaur', 'ivysaur', 'venusaur'] + +asyncio.run(main()) ```
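+> As noted in the *Operations* intro, sync and `async` operations can be mixed on the same stream. A minimal sketch, using a toy `async` function instead of a real HTTP call:
+
+```python
+import asyncio
+
+async def double(n: int) -> int:
+    await asyncio.sleep(0)  # stands in for real async work
+    return 2 * n
+
+# chains an async .amap with a sync .map...
+doubled_strings: Stream[str] = Stream(range(3)).amap(double, concurrency=2).map(str)
+
+# ...and consumes the stream as a plain Iterable
+assert list(doubled_strings) == ['0', '2', '4']
+```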
-## "starmap" +### "starmap" > The `star` function decorator transforms a function that takes several positional arguments into a function that takes a tuple: -
πŸ‘€ show example
+
πŸ‘€ show snippet
```python from streamable import star @@ -193,14 +339,11 @@ assert list(zeros) == [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
- -# `.foreach` - - +## `.foreach` / `.aforeach` > Applies a side effect on elements: -
πŸ‘€ show example
+
πŸ‘€ show snippet
```python state: List[int] = [] @@ -211,19 +354,20 @@ assert state == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] ```
-## concurrency +### concurrency > Similar to `.map`: > - set the `concurrency` parameter for **thread-based concurrency** > - set `via="process"` for **process-based concurrency** -> - use the sibling `.aforeach` operation for **`asyncio`-based concurrency** > - set `ordered=False` for ***First Done First Out*** +> - The `.aforeach` operation can apply an `async` effect concurrently. -# `.group` +## `.group` / `.agroup` -> Groups elements into `List`s: +> Groups into `List`s -
πŸ‘€ show example
+> ... up to a given group `size`: +
πŸ‘€ show snippet
```python integers_by_5: Stream[List[int]] = integers.group(size=5) @@ -231,7 +375,10 @@ integers_by_5: Stream[List[int]] = integers.group(size=5) assert list(integers_by_5) == [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]] ```
-
πŸ‘€ show example
+ +> ... and/or co-groups `by` a given key: + +
πŸ‘€ show snippet
```python integers_by_parity: Stream[List[int]] = integers.group(by=lambda n: n % 2) @@ -239,7 +386,10 @@ integers_by_parity: Stream[List[int]] = integers.group(by=lambda n: n % 2) assert list(integers_by_parity) == [[0, 2, 4, 6, 8], [1, 3, 5, 7, 9]] ```
-
πŸ‘€ show example
+ +> ... and/or co-groups the elements yielded by the upstream within a given time `interval`: + +
πŸ‘€ show snippet
```python from datetime import timedelta @@ -254,8 +404,10 @@ assert list(integers_within_1_sec) == [[0, 1, 2], [3, 4], [5, 6], [7, 8], [9]] ```
-> Mix the `size`/`by`/`interval` parameters: -
πŸ‘€ show example
+> [!TIP] +> Combine the `size`/`by`/`interval` parameters: + +
πŸ‘€ show snippet
```python integers_by_parity_by_2: Stream[List[int]] = ( @@ -267,11 +419,10 @@ assert list(integers_by_parity_by_2) == [[0, 2], [1, 3], [4, 6], [5, 7], [8], [9 ```
- -## `.groupby` +## `.groupby` / `.agroupby` > Like `.group`, but groups into `(key, elements)` tuples: -
πŸ‘€ show example
+
πŸ‘€ show snippet
```python integers_by_parity: Stream[Tuple[str, List[int]]] = ( @@ -286,7 +437,7 @@ assert list(integers_by_parity) == [("even", [0, 2, 4, 6, 8]), ("odd", [1, 3, 5, > [!TIP] > Then *"starmap"* over the tuples: -
πŸ‘€ show example
+
πŸ‘€ show snippet
```python from streamable import star @@ -300,11 +451,11 @@ assert list(counts_by_parity) == [("even", 5), ("odd", 5)] ```
-# `.flatten` +## `.flatten` / `.aflatten` -> Ungroups elements assuming that they are `Iterable`s: +> Ungroups elements assuming that they are `Iterable`s (or `AsyncIterable`s for `.aflatten`): -
πŸ‘€ show example
+
πŸ‘€ show snippet
```python even_then_odd_integers: Stream[int] = integers_by_parity.flatten() @@ -315,9 +466,9 @@ assert list(even_then_odd_integers) == [0, 2, 4, 6, 8, 1, 3, 5, 7, 9] ### thread-based concurrency -> Flattens `concurrency` iterables concurrently: +> Concurrently flattens `concurrency` iterables via threads (or via coroutines for `.aflatten`): -
πŸ‘€ show example
+
πŸ‘€ show snippet
```python mixed_ones_and_zeros: Stream[int] = ( @@ -328,11 +479,11 @@ assert list(mixed_ones_and_zeros) == [0, 1, 0, 1, 0, 1, 0, 1] ```
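+> And a minimal `.aflatten` sketch, assuming it behaves like `.flatten` above but flattens `AsyncIterable` elements via coroutines:
+
+```python
+import asyncio
+from typing import AsyncIterator
+
+async def repeat(value: int, times: int) -> AsyncIterator[int]:
+    for _ in range(times):
+        yield value
+
+async def main() -> None:
+    mixed_ones_and_zeros: Stream[int] = (
+        Stream([repeat(0, 4), repeat(1, 4)])
+        .aflatten(concurrency=2)
+    )
+    assert [n async for n in mixed_ones_and_zeros] == [0, 1, 0, 1, 0, 1, 0, 1]
+
+asyncio.run(main())
+```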
-# `.filter` +## `.filter` / `.afilter` > Keeps only the elements that satisfy a condition: -
πŸ‘€ show example
+
πŸ‘€ show snippet
```python even_integers: Stream[int] = integers.filter(lambda n: n % 2 == 0) @@ -341,11 +492,11 @@ assert list(even_integers) == [0, 2, 4, 6, 8] ```
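+> `.afilter` accepts an `async` predicate (a minimal sketch, assuming it mirrors `.filter`):
+
+```python
+import asyncio
+
+async def is_even(n: int) -> bool:
+    await asyncio.sleep(0)  # stands in for an async check, e.g. a remote lookup
+    return n % 2 == 0
+
+async def main() -> None:
+    even_integers: Stream[int] = integers.afilter(is_even)
+    assert [n async for n in even_integers] == [0, 2, 4, 6, 8]
+
+asyncio.run(main())
+```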
-# `.distinct` +## `.distinct` / `.adistinct` > Removes duplicates: -
πŸ‘€ show example
+
πŸ‘€ show snippet
```python distinct_chars: Stream[str] = Stream("foobarfooo").distinct() @@ -356,7 +507,7 @@ assert list(distinct_chars) == ["f", "o", "b", "a", "r"] > specifying a deduplication `key`: -
πŸ‘€ show example
+
πŸ‘€ show snippet
```python strings_of_distinct_lengths: Stream[str] = ( @@ -371,7 +522,7 @@ assert list(strings_of_distinct_lengths) == ["a", "foo"] > [!WARNING] > During iteration, all distinct elements that are yielded are retained in memory to perform deduplication. However, you can remove only consecutive duplicates without a memory footprint by setting `consecutive_only=True`: -
πŸ‘€ show example
+
πŸ‘€ show snippet
```python consecutively_distinct_chars: Stream[str] = ( @@ -383,11 +534,11 @@ assert list(consecutively_distinct_chars) == ["f", "o", "b", "a", "r", "f", "o"] ```
-# `.truncate` +## `.truncate` / `.atruncate` > Ends iteration once a given number of elements have been yielded: -
πŸ‘€ show example
+
πŸ‘€ show snippet
```python five_first_integers: Stream[int] = integers.truncate(5) @@ -398,7 +549,7 @@ assert list(five_first_integers) == [0, 1, 2, 3, 4] > or `when` a condition is satisfied: -
πŸ‘€ show example
+
πŸ‘€ show snippet
```python five_first_integers: Stream[int] = integers.truncate(when=lambda n: n == 5) @@ -409,11 +560,11 @@ assert list(five_first_integers) == [0, 1, 2, 3, 4] > If both `count` and `when` are set, truncation occurs as soon as either condition is met. -# `.skip` +## `.skip` / `.askip` > Skips the first specified number of elements: -
πŸ‘€ show example
+
πŸ‘€ show snippet
```python integers_after_five: Stream[int] = integers.skip(5) @@ -424,7 +575,7 @@ assert list(integers_after_five) == [5, 6, 7, 8, 9] > or skips elements `until` a predicate is satisfied: -
πŸ‘€ show example
+
πŸ‘€ show snippet
```python integers_after_five: Stream[int] = integers.skip(until=lambda n: n >= 5) @@ -435,11 +586,11 @@ assert list(integers_after_five) == [5, 6, 7, 8, 9] > If both `count` and `until` are set, skipping stops as soon as either condition is met. -# `.catch` +## `.catch` / `.acatch` > Catches a given type of exception, and optionally yields a `replacement` value: -
πŸ‘€ show example
+
πŸ‘€ show snippet
```python inverses: Stream[float] = ( @@ -453,7 +604,7 @@ assert list(inverses) == [float("inf"), 1.0, 0.5, 0.33, 0.25, 0.2, 0.17, 0.14, 0
> You can specify an additional `when` condition for the catch: -
πŸ‘€ show example
+
πŸ‘€ show snippet
```python import requests @@ -473,9 +624,9 @@ assert list(status_codes_ignoring_resolution_errors) == [200, 404] > It has an optional `finally_raise: bool` parameter to raise the first exception caught (if any) when the iteration terminates. > [!TIP] -> Apply side effects when catching an exception by integrating them into `when`: +> Leverage `when` to apply side effects on catch: -
πŸ‘€ show example
+
πŸ‘€ show snippet
```python errors: List[Exception] = [] @@ -495,12 +646,11 @@ assert len(errors) == len("foo") ```
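+> `.acatch` takes an `async` `when` predicate (a minimal sketch, assuming it otherwise mirrors `.catch`):
+
+```python
+import asyncio
+
+async def report_error(error: Exception) -> bool:
+    await asyncio.sleep(0)  # stands in for an async side effect, e.g. reporting the error
+    return True  # a truthy result means: catch it
+
+async def main() -> None:
+    inverses: Stream[float] = (
+        Stream([0, 1, 2])
+        .map(lambda n: round(1 / n, 2))
+        .acatch(ZeroDivisionError, when=report_error, replacement=float("inf"))
+    )
+    assert [inverse async for inverse in inverses] == [float("inf"), 1.0, 0.5]
+
+asyncio.run(main())
+```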
- -# `.throttle` +## `.throttle` > Limits the number of yields `per` time interval: -
πŸ‘€ show example
+
πŸ‘€ show snippet
```python from datetime import timedelta @@ -513,10 +663,10 @@ assert list(three_integers_per_second) == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
-# `.observe` +## `.observe` > Logs the progress of iterations: -
πŸ‘€ show example
+
πŸ‘€ show snippet
```python >>> assert list(integers.throttle(2, per=timedelta(seconds=1)).observe("integers")) == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] @@ -537,7 +687,7 @@ INFO: [duration=0:00:04.003852 errors=0] 10 integers yielded > [!TIP] > To mute these logs, set the logging level above `INFO`: -
πŸ‘€ show example
+
πŸ‘€ show snippet
```python import logging @@ -545,11 +695,11 @@ logging.getLogger("streamable").setLevel(logging.WARNING) ```
-# `+` +## `+` > Concatenates streams: -
πŸ‘€ show example
+
πŸ‘€ show snippet
```python assert list(integers + integers) == [0, 1, 2, 3 ,4, 5, 6, 7, 8, 9, 0, 1, 2, 3 ,4, 5, 6, 7, 8, 9] @@ -557,12 +707,11 @@ assert list(integers + integers) == [0, 1, 2, 3 ,4, 5, 6, 7, 8, 9, 0, 1, 2, 3 ,4
-# `zip` +## `zip` -> [!TIP] > Use the standard `zip` function: -
πŸ‘€ show example
+
πŸ‘€ show snippet
```python from streamable import star @@ -578,29 +727,34 @@ assert list(cubes) == [0, 1, 8, 27, 64, 125, 216, 343, 512, 729] ## Shorthands for consuming the stream -> [!NOTE] -> Although consuming the stream is beyond the scope of this library, it provides two basic shorthands to trigger an iteration: -## `.count` +Although consuming the stream is beyond the scope of this library, it provides two basic shorthands to trigger an iteration: +## `.count` / `.acount` +> `.count` iterates over the stream until exhaustion and returns the number of elements yielded: -> Iterates over the stream until exhaustion and returns the number of elements yielded: - -
πŸ‘€ show example
+
πŸ‘€ show snippet
```python assert integers.count() == 10 ```
+> `.acount` (an `async` method) iterates over the stream as an `AsyncIterable` until exhaustion and returns the number of elements yielded: + +
πŸ‘€ show snippet
-## `()` +```python +assert asyncio.run(integers.acount()) == 10 +``` +
+## `()` / `await` -> *Calling* the stream iterates over it until exhaustion and returns it: -
πŸ‘€ show example
+> *Calling* the stream iterates over it until exhaustion, and returns it: +
πŸ‘€ show snippet
```python state: List[int] = [] @@ -610,12 +764,25 @@ assert state == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] ```
+> *Awaiting* the stream iterates over it as an `AsyncIterable` until exhaustion, and returns it: -# `.pipe` +
πŸ‘€ show snippet
-> Calls a function, passing the stream as first argument, followed by `*args/**kwargs` if any: +```python +async def test_await() -> None: + state: List[int] = [] + appending_integers: Stream[int] = integers.foreach(state.append) + assert appending_integers is await appending_integers + assert state == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] +asyncio.run(test_await()) +``` +
-
πŸ‘€ show example
+## `.pipe` + +> Calls a function, passing the stream as first argument, followed by `*args/**kwargs` if any (inspired by the `.pipe` from [pandas](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.pipe.html) or [polars](https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.pipe.html)): + +
πŸ‘€ show snippet
```python import pandas as pd @@ -629,9 +796,6 @@ import pandas as pd ```
-> Inspired by the `.pipe` from [pandas](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.pipe.html) or [polars](https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.pipe.html). - - # πŸ’‘ Notes ## Exceptions are not terminating the iteration @@ -639,7 +803,7 @@ import pandas as pd > [!TIP] > If any of the operations raises an exception, you can resume the iteration after handling it: -
πŸ‘€ show example
+
πŸ‘€ show snippet
```python from contextlib import suppress @@ -662,63 +826,11 @@ assert collected == [0, 1, 2, 3, 5, 6, 7, 8, 9]
-## Extract-Transform-Load -> [!TIP] -> **Custom ETL scripts** can benefit from the expressiveness of this library. Below is a pipeline that extracts the 67 quadruped PokΓ©mon from the first three generations using [PokΓ©API](https://pokeapi.co/) and loads them into a CSV: - -
πŸ‘€ show example
- -```python -import csv -from datetime import timedelta -import itertools -import requests -from streamable import Stream - -with open("./quadruped_pokemons.csv", mode="w") as file: - fields = ["id", "name", "is_legendary", "base_happiness", "capture_rate"] - writer = csv.DictWriter(file, fields, extrasaction='ignore') - writer.writeheader() - - pipeline: Stream = ( - # Infinite Stream[int] of Pokemon ids starting from PokΓ©mon #1: Bulbasaur - Stream(itertools.count(1)) - # Limits to 16 requests per second to be friendly to our fellow PokΓ©API devs - .throttle(16, per=timedelta(seconds=1)) - # GETs pokemons concurrently using a pool of 8 threads - .map(lambda poke_id: f"https://pokeapi.co/api/v2/pokemon-species/{poke_id}") - .map(requests.get, concurrency=8) - .foreach(requests.Response.raise_for_status) - .map(requests.Response.json) - # Stops the iteration when reaching the 1st pokemon of the 4th generation - .truncate(when=lambda poke: poke["generation"]["name"] == "generation-iv") - .observe("pokemons") - # Keeps only quadruped Pokemons - .filter(lambda poke: poke["shape"]["name"] == "quadruped") - .observe("quadruped pokemons") - # Catches errors due to None "generation" or "shape" - .catch( - TypeError, - when=lambda error: str(error) == "'NoneType' object is not subscriptable" - ) - # Writes a batch of pokemons every 5 seconds to the CSV file - .group(interval=timedelta(seconds=5)) - .foreach(writer.writerows) - .flatten() - .observe("written pokemons") - # Catches exceptions and raises the 1st one at the end of the iteration - .catch(Exception, finally_raise=True) - ) - - pipeline() -``` -
- ## Visitor Pattern > [!TIP] > A `Stream` can be visited via its `.accept` method: implement a custom [***visitor***](https://en.wikipedia.org/wiki/Visitor_pattern) by extending the abstract class `streamable.visitors.Visitor`: -
πŸ‘€ show example
+
πŸ‘€ show snippet
```python from streamable.visitors import Visitor @@ -739,7 +851,7 @@ assert depth(Stream(range(10)).map(str).foreach(print)) == 3 ## Functions > [!TIP] > The `Stream`'s methods are also exposed as functions: -
πŸ‘€ show example
+
πŸ‘€ show snippet
```python from streamable.functions import catch diff --git a/setup.py b/setup.py index 4e6561a5..1ae77f09 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,5 @@ from setuptools import find_packages, setup # type: ignore + from version import __version__ setup( @@ -10,7 +11,7 @@ license="Apache 2.", author="ebonnal", author_email="bonnal.enzo.dev@gmail.com", - description="Pythonic Stream-like manipulation of iterables", + description="Pythonic Stream-like manipulation of (async) iterables", long_description=open("README.md").read(), long_description_content_type="text/markdown", ) diff --git a/streamable/__init__.py b/streamable/__init__.py index a51ad250..17cf16ae 100644 --- a/streamable/__init__.py +++ b/streamable/__init__.py @@ -1,2 +1,4 @@ from streamable.stream import Stream from streamable.util.functiontools import star + +__all__ = ["Stream", "star"] diff --git a/streamable/afunctions.py b/streamable/afunctions.py new file mode 100644 index 00000000..5532de3c --- /dev/null +++ b/streamable/afunctions.py @@ -0,0 +1,353 @@ +import builtins +import datetime +from contextlib import suppress +from operator import itemgetter +from typing import ( + Any, + AsyncIterable, + AsyncIterator, + Callable, + Coroutine, + Iterable, + List, + Optional, + Tuple, + Type, + TypeVar, + Union, +) + +from streamable.aiterators import ( + ACatchAsyncIterator, + ADistinctAsyncIterator, + AFilterAsyncIterator, + AFlattenAsyncIterator, + AGroupbyAsyncIterator, + AMapAsyncIterator, + ConcurrentAFlattenAsyncIterator, + ConcurrentAMapAsyncIterator, + ConcurrentFlattenAsyncIterator, + ConcurrentMapAsyncIterator, + ConsecutiveADistinctAsyncIterator, + CountAndPredicateASkipAsyncIterator, + CountSkipAsyncIterator, + CountTruncateAsyncIterator, + FlattenAsyncIterator, + GroupAsyncIterator, + ObserveAsyncIterator, + PredicateASkipAsyncIterator, + PredicateATruncateAsyncIterator, + YieldsPerPeriodThrottleAsyncIterator, +) +from streamable.util.constants import NO_REPLACEMENT +from streamable.util.functiontools import asyncify +from streamable.util.validationtools import ( + validate_aiterator, + validate_concurrency, + validate_errors, + validate_group_size, + # validate_not_none, + validate_optional_count, + validate_optional_positive_count, + validate_optional_positive_interval, + validate_via, +) + +with suppress(ImportError): + from typing import Literal + +T = TypeVar("T") +U = TypeVar("U") + + +def catch( + aiterator: AsyncIterator[T], + errors: Union[ + Optional[Type[Exception]], Iterable[Optional[Type[Exception]]] + ] = Exception, + *, + when: Optional[Callable[[Exception], Any]] = None, + replacement: T = NO_REPLACEMENT, # type: ignore + finally_raise: bool = False, +) -> AsyncIterator[T]: + return acatch( + aiterator, + errors, + when=asyncify(when) if when else None, + replacement=replacement, + finally_raise=finally_raise, + ) + + +def acatch( + aiterator: AsyncIterator[T], + errors: Union[ + Optional[Type[Exception]], Iterable[Optional[Type[Exception]]] + ] = Exception, + *, + when: Optional[Callable[[Exception], Coroutine[Any, Any, Any]]] = None, + replacement: T = NO_REPLACEMENT, # type: ignore + finally_raise: bool = False, +) -> AsyncIterator[T]: + validate_aiterator(aiterator) + validate_errors(errors) + # validate_not_none(finally_raise, "finally_raise") + if errors is None: + return aiterator + return ACatchAsyncIterator( + aiterator, + tuple(builtins.filter(None, errors)) + if isinstance(errors, Iterable) + else errors, + when=when, + replacement=replacement, + finally_raise=finally_raise, + ) + 
+ +def distinct( + aiterator: AsyncIterator[T], + key: Optional[Callable[[T], Any]] = None, + *, + consecutive_only: bool = False, +) -> AsyncIterator[T]: + return adistinct( + aiterator, + asyncify(key) if key else None, + consecutive_only=consecutive_only, + ) + + +def adistinct( + aiterator: AsyncIterator[T], + key: Optional[Callable[[T], Coroutine[Any, Any, Any]]] = None, + *, + consecutive_only: bool = False, +) -> AsyncIterator[T]: + validate_aiterator(aiterator) + # validate_not_none(consecutive_only, "consecutive_only") + if consecutive_only: + return ConsecutiveADistinctAsyncIterator(aiterator, key) + return ADistinctAsyncIterator(aiterator, key) + + +def filter( + aiterator: AsyncIterator[T], + when: Callable[[T], Any], +) -> AsyncIterator[T]: + return afilter(aiterator, asyncify(when)) + + +def afilter( + aiterator: AsyncIterator[T], + when: Callable[[T], Any], +) -> AsyncIterator[T]: + validate_aiterator(aiterator) + return AFilterAsyncIterator(aiterator, when) + + +def flatten( + aiterator: AsyncIterator[Iterable[T]], *, concurrency: int = 1 +) -> AsyncIterator[T]: + validate_aiterator(aiterator) + validate_concurrency(concurrency) + if concurrency == 1: + return FlattenAsyncIterator(aiterator) + else: + return ConcurrentFlattenAsyncIterator( + aiterator, + concurrency=concurrency, + buffersize=concurrency, + ) + + +def aflatten( + aiterator: AsyncIterator[AsyncIterable[T]], *, concurrency: int = 1 +) -> AsyncIterator[T]: + validate_aiterator(aiterator) + validate_concurrency(concurrency) + if concurrency == 1: + return AFlattenAsyncIterator(aiterator) + else: + return ConcurrentAFlattenAsyncIterator( + aiterator, + concurrency=concurrency, + buffersize=concurrency, + ) + + +def group( + aiterator: AsyncIterator[T], + size: Optional[int] = None, + *, + interval: Optional[datetime.timedelta] = None, + by: Optional[Callable[[T], Any]] = None, +) -> AsyncIterator[List[T]]: + return agroup( + aiterator, + size, + interval=interval, + by=asyncify(by) if by else None, + ) + + +def agroup( + aiterator: AsyncIterator[T], + size: Optional[int] = None, + *, + interval: Optional[datetime.timedelta] = None, + by: Optional[Callable[[T], Coroutine[Any, Any, Any]]] = None, +) -> AsyncIterator[List[T]]: + validate_aiterator(aiterator) + validate_group_size(size) + validate_optional_positive_interval(interval) + if by is None: + return GroupAsyncIterator(aiterator, size, interval) + return map(itemgetter(1), AGroupbyAsyncIterator(aiterator, by, size, interval)) + + +def groupby( + aiterator: AsyncIterator[T], + key: Callable[[T], U], + *, + size: Optional[int] = None, + interval: Optional[datetime.timedelta] = None, +) -> AsyncIterator[Tuple[U, List[T]]]: + return agroupby(aiterator, asyncify(key), size=size, interval=interval) + + +def agroupby( + aiterator: AsyncIterator[T], + key: Callable[[T], Coroutine[Any, Any, U]], + *, + size: Optional[int] = None, + interval: Optional[datetime.timedelta] = None, +) -> AsyncIterator[Tuple[U, List[T]]]: + validate_aiterator(aiterator) + validate_group_size(size) + validate_optional_positive_interval(interval) + return AGroupbyAsyncIterator(aiterator, key, size, interval) + + +def map( + transformation: Callable[[T], U], + aiterator: AsyncIterator[T], + *, + concurrency: int = 1, + ordered: bool = True, + via: "Literal['thread', 'process']" = "thread", +) -> AsyncIterator[U]: + validate_aiterator(aiterator) + # validate_not_none(transformation, "transformation") + # validate_not_none(ordered, "ordered") + validate_concurrency(concurrency) + 
validate_via(via) + if concurrency == 1: + return amap(asyncify(transformation), aiterator) + else: + return ConcurrentMapAsyncIterator( + aiterator, + transformation, + concurrency=concurrency, + buffersize=concurrency, + ordered=ordered, + via=via, + ) + + +def amap( + transformation: Callable[[T], Coroutine[Any, Any, U]], + aiterator: AsyncIterator[T], + *, + concurrency: int = 1, + ordered: bool = True, +) -> AsyncIterator[U]: + validate_aiterator(aiterator) + # validate_not_none(transformation, "transformation") + # validate_not_none(ordered, "ordered") + validate_concurrency(concurrency) + if concurrency == 1: + return AMapAsyncIterator(aiterator, transformation) + return ConcurrentAMapAsyncIterator( + aiterator, + transformation, + buffersize=concurrency, + ordered=ordered, + ) + + +def observe(aiterator: AsyncIterator[T], what: str) -> AsyncIterator[T]: + validate_aiterator(aiterator) + # validate_not_none(what, "what") + return ObserveAsyncIterator(aiterator, what) + + +def skip( + aiterator: AsyncIterator[T], + count: Optional[int] = None, + *, + until: Optional[Callable[[T], Any]] = None, +) -> AsyncIterator[T]: + return askip( + aiterator, + count, + until=asyncify(until) if until else None, + ) + + +def askip( + aiterator: AsyncIterator[T], + count: Optional[int] = None, + *, + until: Optional[Callable[[T], Coroutine[Any, Any, Any]]] = None, +) -> AsyncIterator[T]: + validate_aiterator(aiterator) + validate_optional_count(count) + if until is not None: + if count is not None: + return CountAndPredicateASkipAsyncIterator(aiterator, count, until) + return PredicateASkipAsyncIterator(aiterator, until) + if count is not None: + return CountSkipAsyncIterator(aiterator, count) + return aiterator + + +def throttle( + aiterator: AsyncIterator[T], + count: Optional[int], + *, + per: Optional[datetime.timedelta] = None, +) -> AsyncIterator[T]: + validate_optional_positive_count(count) + validate_optional_positive_interval(per, name="per") + if count and per: + aiterator = YieldsPerPeriodThrottleAsyncIterator(aiterator, count, per) + return aiterator + + +def truncate( + aiterator: AsyncIterator[T], + count: Optional[int] = None, + *, + when: Optional[Callable[[T], Any]] = None, +) -> AsyncIterator[T]: + return atruncate( + aiterator, + count, + when=asyncify(when) if when else None, + ) + + +def atruncate( + aiterator: AsyncIterator[T], + count: Optional[int] = None, + *, + when: Optional[Callable[[T], Coroutine[Any, Any, Any]]] = None, +) -> AsyncIterator[T]: + validate_aiterator(aiterator) + validate_optional_count(count) + if count is not None: + aiterator = CountTruncateAsyncIterator(aiterator, count) + if when is not None: + aiterator = PredicateATruncateAsyncIterator(aiterator, when) + return aiterator diff --git a/streamable/aiterators.py b/streamable/aiterators.py new file mode 100644 index 00000000..72fb85d8 --- /dev/null +++ b/streamable/aiterators.py @@ -0,0 +1,924 @@ +import asyncio +import datetime +import multiprocessing +import queue +import time +from abc import ABC, abstractmethod +from collections import defaultdict, deque +from concurrent.futures import Executor, Future, ProcessPoolExecutor, ThreadPoolExecutor +from contextlib import contextmanager, suppress +from math import ceil +from typing import ( + Any, + AsyncIterable, + AsyncIterator, + Awaitable, + Callable, + ContextManager, + Coroutine, + DefaultDict, + Deque, + Generic, + Iterable, + Iterator, + List, + NamedTuple, + Optional, + Set, + Tuple, + Type, + TypeVar, + Union, + cast, +) + +from 
streamable.util.asynctools import ( + GetEventLoopMixin, + awaitable_to_coroutine, + empty_aiter, +) +from streamable.util.functiontools import ( + aiter_wo_stopasynciteration, + awrap_error, + iter_wo_stopasynciteration, + wrap_error, +) +from streamable.util.loggertools import get_logger +from streamable.util.validationtools import ( + validate_aiterator, + validate_base, + validate_buffersize, + validate_concurrency, + validate_count, + validate_errors, + validate_group_size, + validate_optional_positive_interval, +) + +T = TypeVar("T") +U = TypeVar("U") + +from streamable.util.constants import NO_REPLACEMENT +from streamable.util.futuretools import ( + FDFOAsyncFutureResultCollection, + FDFOOSFutureResultCollection, + FIFOAsyncFutureResultCollection, + FIFOOSFutureResultCollection, + FutureResultCollection, +) + +with suppress(ImportError): + from typing import Literal + + +class ACatchAsyncIterator(AsyncIterator[T]): + def __init__( + self, + iterator: AsyncIterator[T], + errors: Union[Type[Exception], Tuple[Type[Exception], ...]], + when: Optional[Callable[[Exception], Coroutine[Any, Any, Any]]], + replacement: T, + finally_raise: bool, + ) -> None: + validate_aiterator(iterator) + validate_errors(errors) + self.iterator = iterator + self.errors = errors + self.when = awrap_error(when, StopAsyncIteration) if when else None + self.replacement = replacement + self.finally_raise = finally_raise + self._to_be_finally_raised: Optional[Exception] = None + + async def __anext__(self) -> T: + while True: + try: + return await self.iterator.__anext__() + except StopAsyncIteration: + if self._to_be_finally_raised: + try: + raise self._to_be_finally_raised + finally: + self._to_be_finally_raised = None + raise + except self.errors as e: + if not self.when or await self.when(e): + if self.finally_raise and not self._to_be_finally_raised: + self._to_be_finally_raised = e + if self.replacement is not NO_REPLACEMENT: + return self.replacement + continue + raise + + +class ADistinctAsyncIterator(AsyncIterator[T]): + def __init__( + self, + iterator: AsyncIterator[T], + key: Optional[Callable[[T], Coroutine[Any, Any, Any]]], + ) -> None: + validate_aiterator(iterator) + self.iterator = iterator + self.key = awrap_error(key, StopAsyncIteration) if key else None + self._already_seen: Set[Any] = set() + + async def __anext__(self) -> T: + while True: + elem = await self.iterator.__anext__() + key = await self.key(elem) if self.key else elem + if key not in self._already_seen: + break + self._already_seen.add(key) + return elem + + +class ConsecutiveADistinctAsyncIterator(AsyncIterator[T]): + def __init__( + self, + iterator: AsyncIterator[T], + key: Optional[Callable[[T], Coroutine[Any, Any, Any]]], + ) -> None: + validate_aiterator(iterator) + self.iterator = iterator + self.key = awrap_error(key, StopAsyncIteration) if key else None + self._last_key: Any = object() + + async def __anext__(self) -> T: + while True: + elem = await self.iterator.__anext__() + key = await self.key(elem) if self.key else elem + if key != self._last_key: + break + self._last_key = key + return elem + + +class FlattenAsyncIterator(AsyncIterator[U]): + def __init__(self, iterator: AsyncIterator[Iterable[U]]) -> None: + validate_aiterator(iterator) + self.iterator = iterator + self._current_iterator_elem: Iterator[U] = tuple().__iter__() + + async def __anext__(self) -> U: + while True: + try: + return self._current_iterator_elem.__next__() + except StopIteration: + self._current_iterator_elem = iter_wo_stopasynciteration( + 
await self.iterator.__anext__() + ) + + +class AFlattenAsyncIterator(AsyncIterator[U]): + def __init__(self, iterator: AsyncIterator[AsyncIterable[U]]) -> None: + validate_aiterator(iterator) + self.iterator = iterator + self._current_iterator_elem: AsyncIterator[U] = empty_aiter() + + async def __anext__(self) -> U: + while True: + try: + return await self._current_iterator_elem.__anext__() + except StopAsyncIteration: + self._current_iterator_elem = aiter_wo_stopasynciteration( + await self.iterator.__anext__() + ) + + +class _GroupAsyncIteratorMixin(Generic[T]): + def __init__( + self, + iterator: AsyncIterator[T], + size: Optional[int], + interval: Optional[datetime.timedelta], + ) -> None: + validate_aiterator(iterator) + validate_group_size(size) + validate_optional_positive_interval(interval) + self.iterator = iterator + self.size = size or cast(int, float("inf")) + self.interval = interval + self._interval_seconds = interval.total_seconds() if interval else float("inf") + self._to_be_raised: Optional[Exception] = None + self._last_group_yielded_at: float = 0 + + def _interval_seconds_have_elapsed(self) -> bool: + if not self.interval: + return False + return ( + time.perf_counter() - self._last_group_yielded_at + ) >= self._interval_seconds + + def _remember_group_time(self) -> None: + if self.interval: + self._last_group_yielded_at = time.perf_counter() + + def _init_last_group_time(self) -> None: + if self.interval and not self._last_group_yielded_at: + self._last_group_yielded_at = time.perf_counter() + + +class GroupAsyncIterator(_GroupAsyncIteratorMixin[T], AsyncIterator[List[T]]): + def __init__( + self, + iterator: AsyncIterator[T], + size: Optional[int], + interval: Optional[datetime.timedelta], + ) -> None: + super().__init__(iterator, size, interval) + self._current_group: List[T] = [] + + async def __anext__(self) -> List[T]: + self._init_last_group_time() + if self._to_be_raised: + try: + raise self._to_be_raised + finally: + self._to_be_raised = None + try: + while len(self._current_group) < self.size and ( + not self._interval_seconds_have_elapsed() or not self._current_group + ): + self._current_group.append(await self.iterator.__anext__()) + except Exception as e: + if not self._current_group: + raise + self._to_be_raised = e + + group, self._current_group = self._current_group, [] + self._remember_group_time() + return group + + +class AGroupbyAsyncIterator( + _GroupAsyncIteratorMixin[T], AsyncIterator[Tuple[U, List[T]]] +): + def __init__( + self, + iterator: AsyncIterator[T], + key: Callable[[T], Coroutine[Any, Any, U]], + size: Optional[int], + interval: Optional[datetime.timedelta], + ) -> None: + super().__init__(iterator, size, interval) + self.key = awrap_error(key, StopAsyncIteration) + self._is_exhausted = False + self._groups_by: DefaultDict[U, List[T]] = defaultdict(list) + + async def _group_next_elem(self) -> None: + elem = await self.iterator.__anext__() + self._groups_by[await self.key(elem)].append(elem) + + def _pop_full_group(self) -> Optional[Tuple[U, List[T]]]: + for key, group in self._groups_by.items(): + if len(group) >= self.size: + return key, self._groups_by.pop(key) + return None + + def _pop_first_group(self) -> Tuple[U, List[T]]: + first_key: U = self._groups_by.__iter__().__next__() + return first_key, self._groups_by.pop(first_key) + + def _pop_largest_group(self) -> Tuple[U, List[T]]: + largest_group_key: Any = self._groups_by.__iter__().__next__() + + for key, group in self._groups_by.items(): + if len(group) > 
len(self._groups_by[largest_group_key]): + largest_group_key = key + + return largest_group_key, self._groups_by.pop(largest_group_key) + + async def __anext__(self) -> Tuple[U, List[T]]: + self._init_last_group_time() + if self._is_exhausted: + if self._groups_by: + return self._pop_first_group() + raise StopAsyncIteration + + if self._to_be_raised: + if self._groups_by: + self._remember_group_time() + return self._pop_first_group() + try: + raise self._to_be_raised + finally: + self._to_be_raised = None + + try: + await self._group_next_elem() + + full_group: Optional[Tuple[U, List[T]]] = self._pop_full_group() + while not full_group and not self._interval_seconds_have_elapsed(): + await self._group_next_elem() + full_group = self._pop_full_group() + + self._remember_group_time() + return full_group or self._pop_largest_group() + + except StopAsyncIteration: + self._is_exhausted = True + return await self.__anext__() + + except Exception as e: + self._to_be_raised = e + return await self.__anext__() + + +class CountSkipAsyncIterator(AsyncIterator[T]): + def __init__(self, iterator: AsyncIterator[T], count: int) -> None: + validate_aiterator(iterator) + validate_count(count) + self.iterator = iterator + self.count = count + self._n_skipped = 0 + self._done_skipping = False + + async def __anext__(self) -> T: + if not self._done_skipping: + while self._n_skipped < self.count: + await self.iterator.__anext__() + # do not count exceptions as skipped elements + self._n_skipped += 1 + self._done_skipping = True + return await self.iterator.__anext__() + + +class PredicateASkipAsyncIterator(AsyncIterator[T]): + def __init__( + self, iterator: AsyncIterator[T], until: Callable[[T], Coroutine[Any, Any, Any]] + ) -> None: + validate_aiterator(iterator) + self.iterator = iterator + self.until = awrap_error(until, StopAsyncIteration) + self._done_skipping = False + + async def __anext__(self) -> T: + elem = await self.iterator.__anext__() + if not self._done_skipping: + while not await self.until(elem): + elem = await self.iterator.__anext__() + self._done_skipping = True + return elem + + +class CountAndPredicateASkipAsyncIterator(AsyncIterator[T]): + def __init__( + self, + iterator: AsyncIterator[T], + count: int, + until: Callable[[T], Coroutine[Any, Any, Any]], + ) -> None: + validate_aiterator(iterator) + validate_count(count) + self.iterator = iterator + self.count = count + self.until = awrap_error(until, StopAsyncIteration) + self._n_skipped = 0 + self._done_skipping = False + + async def __anext__(self) -> T: + elem = await self.iterator.__anext__() + if not self._done_skipping: + while self._n_skipped < self.count and not await self.until(elem): + elem = await self.iterator.__anext__() + # do not count exceptions as skipped elements + self._n_skipped += 1 + self._done_skipping = True + return elem + + +class CountTruncateAsyncIterator(AsyncIterator[T]): + def __init__(self, iterator: AsyncIterator[T], count: int) -> None: + validate_aiterator(iterator) + validate_count(count) + self.iterator = iterator + self.count = count + self._current_count = 0 + + async def __anext__(self) -> T: + if self._current_count == self.count: + raise StopAsyncIteration() + elem = await self.iterator.__anext__() + self._current_count += 1 + return elem + + +class PredicateATruncateAsyncIterator(AsyncIterator[T]): + def __init__( + self, iterator: AsyncIterator[T], when: Callable[[T], Coroutine[Any, Any, Any]] + ) -> None: + validate_aiterator(iterator) + self.iterator = iterator + self.when = 
awrap_error(when, StopAsyncIteration) + self._satisfied = False + + async def __anext__(self) -> T: + if self._satisfied: + raise StopAsyncIteration() + elem = await self.iterator.__anext__() + if await self.when(elem): + self._satisfied = True + raise StopAsyncIteration() + return elem + + +class AMapAsyncIterator(AsyncIterator[U]): + def __init__( + self, + iterator: AsyncIterator[T], + transformation: Callable[[T], Coroutine[Any, Any, U]], + ) -> None: + validate_aiterator(iterator) + + self.iterator = iterator + self.transformation = awrap_error(transformation, StopAsyncIteration) + + async def __anext__(self) -> U: + return await self.transformation(await self.iterator.__anext__()) + + +class AFilterAsyncIterator(AsyncIterator[T]): + def __init__( + self, + iterator: AsyncIterator[T], + when: Callable[[T], Coroutine[Any, Any, Any]], + ) -> None: + validate_aiterator(iterator) + + self.iterator = iterator + self.when = awrap_error(when, StopAsyncIteration) + + async def __anext__(self) -> T: + while True: + elem = await self.iterator.__anext__() + if await self.when(elem): + return elem + + +class ObserveAsyncIterator(AsyncIterator[T]): + def __init__(self, iterator: AsyncIterator[T], what: str, base: int = 2) -> None: + validate_aiterator(iterator) + validate_base(base) + + self.iterator = iterator + self.what = what + self.base = base + + self._n_yields = 0 + self._n_errors = 0 + self._n_nexts = 0 + self._logged_n_nexts = 0 + self._next_threshold = 0 + + self._start_time = time.perf_counter() + + def _log(self) -> None: + get_logger().info( + "[%s %s] %s", + f"duration={datetime.datetime.fromtimestamp(time.perf_counter()) - datetime.datetime.fromtimestamp(self._start_time)}", + f"errors={self._n_errors}", + f"{self._n_yields} {self.what} yielded", + ) + self._logged_n_nexts = self._n_nexts + self._next_threshold = self.base * self._logged_n_nexts + + async def __anext__(self) -> T: + try: + elem = await self.iterator.__anext__() + self._n_nexts += 1 + self._n_yields += 1 + return elem + except StopAsyncIteration: + if self._n_nexts != self._logged_n_nexts: + self._log() + raise + except Exception: + self._n_nexts += 1 + self._n_errors += 1 + raise + finally: + if self._n_nexts >= self._next_threshold: + self._log() + + +class YieldsPerPeriodThrottleAsyncIterator(AsyncIterator[T]): + def __init__( + self, + iterator: AsyncIterator[T], + max_yields: int, + period: datetime.timedelta, + ) -> None: + validate_aiterator(iterator) + self.iterator = iterator + self.max_yields = max_yields + self._period_seconds = period.total_seconds() + + self._period_index: int = -1 + self._yields_in_period = 0 + + self._offset: Optional[float] = None + + async def safe_next(self) -> Tuple[Optional[T], Optional[Exception]]: + try: + return await self.iterator.__anext__(), None + except StopAsyncIteration: + raise + except Exception as e: + return None, e + + async def __anext__(self) -> T: + elem, caught_error = await self.safe_next() + + now = time.perf_counter() + if not self._offset: + self._offset = now + now -= self._offset + + num_periods = now / self._period_seconds + period_index = int(num_periods) + + if self._period_index != period_index: + self._period_index = period_index + self._yields_in_period = max(0, self._yields_in_period - self.max_yields) + + if self._yields_in_period >= self.max_yields: + await asyncio.sleep( + (ceil(num_periods) - num_periods) * self._period_seconds + ) + self._yields_in_period += 1 + + if caught_error: + raise caught_error + return cast(T, elem) + + +class 
_RaisingAsyncIterator(AsyncIterator[T]): + class ExceptionContainer(NamedTuple): + exception: Exception + + def __init__( + self, + iterator: AsyncIterator[Union[T, ExceptionContainer]], + ) -> None: + self.iterator = iterator + + async def __anext__(self) -> T: + elem = await self.iterator.__anext__() + if isinstance(elem, self.ExceptionContainer): + raise elem.exception + return elem + + +class _ConcurrentMapAsyncIterableMixin( + Generic[T, U], + ABC, + AsyncIterable[Union[U, _RaisingAsyncIterator.ExceptionContainer]], +): + """ + Template Method Pattern: + This abstract class's `__iter__` is a skeleton for a queue-based concurrent mapping algorithm + that relies on abstract helper methods (`_context_manager`, `_create_future`, `_future_result_collection`) + that must be implemented by concrete subclasses. + """ + + def __init__( + self, + iterator: AsyncIterator[T], + buffersize: int, + ordered: bool, + ) -> None: + validate_aiterator(iterator) + validate_buffersize(buffersize) + self.iterator = iterator + self.buffersize = buffersize + self.ordered = ordered + + def _context_manager(self) -> ContextManager: + @contextmanager + def dummy_context_manager_generator(): + yield + + return dummy_context_manager_generator() + + @abstractmethod + def _launch_task( + self, elem: T + ) -> "Future[Union[U, _RaisingAsyncIterator.ExceptionContainer]]": ... + + # factory method + @abstractmethod + def _future_result_collection( + self, + ) -> FutureResultCollection[Union[U, _RaisingAsyncIterator.ExceptionContainer]]: ... + + async def __aiter__( + self, + ) -> AsyncIterator[Union[U, _RaisingAsyncIterator.ExceptionContainer]]: + with self._context_manager(): + future_results = self._future_result_collection() + + # queue tasks up to buffersize + with suppress(StopAsyncIteration): + while len(future_results) < self.buffersize: + future_results.add_future( + self._launch_task(await self.iterator.__anext__()) + ) + + # wait, queue, yield + while future_results: + result = await future_results.__anext__() + with suppress(StopAsyncIteration): + future_results.add_future( + self._launch_task(await self.iterator.__anext__()) + ) + yield result + + +class _ConcurrentMapAsyncIterable(_ConcurrentMapAsyncIterableMixin[T, U]): + def __init__( + self, + iterator: AsyncIterator[T], + transformation: Callable[[T], U], + concurrency: int, + buffersize: int, + ordered: bool, + via: "Literal['thread', 'process']", + ) -> None: + super().__init__(iterator, buffersize, ordered) + validate_concurrency(concurrency) + self.transformation = wrap_error(transformation, StopAsyncIteration) + self.concurrency = concurrency + self.executor: Executor + self.via = via + + def _context_manager(self) -> ContextManager: + if self.via == "thread": + self.executor = ThreadPoolExecutor(max_workers=self.concurrency) + if self.via == "process": + self.executor = ProcessPoolExecutor(max_workers=self.concurrency) + return self.executor + + # picklable + @staticmethod + def _safe_transformation( + transformation: Callable[[T], U], elem: T + ) -> Union[U, _RaisingAsyncIterator.ExceptionContainer]: + try: + return transformation(elem) + except Exception as e: + return _RaisingAsyncIterator.ExceptionContainer(e) + + def _launch_task( + self, elem: T + ) -> "Future[Union[U, _RaisingAsyncIterator.ExceptionContainer]]": + return self.executor.submit( + self._safe_transformation, self.transformation, elem + ) + + def _future_result_collection( + self, + ) -> FutureResultCollection[Union[U, _RaisingAsyncIterator.ExceptionContainer]]: + if 
self.ordered: + return FIFOOSFutureResultCollection() + return FDFOOSFutureResultCollection( + multiprocessing.Queue if self.via == "process" else queue.Queue + ) + + +class ConcurrentMapAsyncIterator(_RaisingAsyncIterator[U]): + def __init__( + self, + iterator: AsyncIterator[T], + transformation: Callable[[T], U], + concurrency: int, + buffersize: int, + ordered: bool, + via: "Literal['thread', 'process']", + ) -> None: + super().__init__( + _ConcurrentMapAsyncIterable( + iterator, + transformation, + concurrency, + buffersize, + ordered, + via, + ).__aiter__() + ) + + +class _ConcurrentAMapAsyncIterable( + _ConcurrentMapAsyncIterableMixin[T, U], GetEventLoopMixin +): + def __init__( + self, + iterator: AsyncIterator[T], + transformation: Callable[[T], Coroutine[Any, Any, U]], + buffersize: int, + ordered: bool, + ) -> None: + super().__init__(iterator, buffersize, ordered) + self.transformation = awrap_error(transformation, StopAsyncIteration) + + async def _safe_transformation( + self, elem: T + ) -> Union[U, _RaisingAsyncIterator.ExceptionContainer]: + try: + coroutine = self.transformation(elem) + if not isinstance(coroutine, Coroutine): + raise TypeError( + f"`transformation` must be an async function i.e. a function returning a Coroutine but it returned a {type(coroutine)}", + ) + return await coroutine + except Exception as e: + return _RaisingAsyncIterator.ExceptionContainer(e) + + def _launch_task( + self, elem: T + ) -> "Future[Union[U, _RaisingAsyncIterator.ExceptionContainer]]": + return cast( + "Future[Union[U, _RaisingAsyncIterator.ExceptionContainer]]", + self.get_event_loop().create_task(self._safe_transformation(elem)), + ) + + def _future_result_collection( + self, + ) -> FutureResultCollection[Union[U, _RaisingAsyncIterator.ExceptionContainer]]: + if self.ordered: + return FIFOAsyncFutureResultCollection(self.get_event_loop()) + else: + return FDFOAsyncFutureResultCollection(self.get_event_loop()) + + +class ConcurrentAMapAsyncIterator(_RaisingAsyncIterator[U]): + def __init__( + self, + iterator: AsyncIterator[T], + transformation: Callable[[T], Coroutine[Any, Any, U]], + buffersize: int, + ordered: bool, + ) -> None: + super().__init__( + _ConcurrentAMapAsyncIterable( + iterator, + transformation, + buffersize, + ordered, + ).__aiter__() + ) + + +class _ConcurrentFlattenAsyncIterable( + AsyncIterable[Union[T, _RaisingAsyncIterator.ExceptionContainer]] +): + def __init__( + self, + iterables_iterator: AsyncIterator[Iterable[T]], + concurrency: int, + buffersize: int, + ) -> None: + validate_aiterator(iterables_iterator) + validate_concurrency(concurrency) + self.iterables_iterator = iterables_iterator + self.concurrency = concurrency + self.buffersize = buffersize + + async def __aiter__( + self, + ) -> AsyncIterator[Union[T, _RaisingAsyncIterator.ExceptionContainer]]: + with ThreadPoolExecutor(max_workers=self.concurrency) as executor: + iterator_and_future_pairs: Deque[Tuple[Iterator[T], Future]] = deque() + element_to_yield: Deque[ + Union[T, _RaisingAsyncIterator.ExceptionContainer] + ] = deque(maxlen=1) + iterator_to_queue: Optional[Iterator[T]] = None + # wait, queue, yield (FIFO) + while True: + if iterator_and_future_pairs: + iterator, future = iterator_and_future_pairs.popleft() + try: + element_to_yield.append(future.result()) + iterator_to_queue = iterator + except StopIteration: + pass + except Exception as e: + element_to_yield.append( + _RaisingAsyncIterator.ExceptionContainer(e) + ) + iterator_to_queue = iterator + + # queue tasks up to buffersize + 
while len(iterator_and_future_pairs) < self.buffersize: + if not iterator_to_queue: + try: + iterable = await self.iterables_iterator.__anext__() + except StopAsyncIteration: + break + try: + iterator_to_queue = iter_wo_stopasynciteration(iterable) + except Exception as e: + yield _RaisingAsyncIterator.ExceptionContainer(e) + continue + future = executor.submit(next, iterator_to_queue) + iterator_and_future_pairs.append((iterator_to_queue, future)) + iterator_to_queue = None + if element_to_yield: + yield element_to_yield.pop() + if not iterator_and_future_pairs: + break + + +class ConcurrentFlattenAsyncIterator(_RaisingAsyncIterator[T]): + def __init__( + self, + iterables_iterator: AsyncIterator[Iterable[T]], + concurrency: int, + buffersize: int, + ) -> None: + super().__init__( + _ConcurrentFlattenAsyncIterable( + iterables_iterator, + concurrency, + buffersize, + ).__aiter__() + ) + + +class _ConcurrentAFlattenAsyncIterable( + AsyncIterable[Union[T, _RaisingAsyncIterator.ExceptionContainer]], GetEventLoopMixin +): + def __init__( + self, + iterables_iterator: AsyncIterator[AsyncIterable[T]], + concurrency: int, + buffersize: int, + ) -> None: + validate_aiterator(iterables_iterator) + validate_concurrency(concurrency) + self.iterables_iterator = iterables_iterator + self.concurrency = concurrency + self.buffersize = buffersize + + async def __aiter__( + self, + ) -> AsyncIterator[Union[T, _RaisingAsyncIterator.ExceptionContainer]]: + iterator_and_future_pairs: Deque[Tuple[AsyncIterator[T], Awaitable[T]]] = ( + deque() + ) + element_to_yield: Deque[Union[T, _RaisingAsyncIterator.ExceptionContainer]] = ( + deque(maxlen=1) + ) + iterator_to_queue: Optional[AsyncIterator[T]] = None + # wait, queue, yield (FIFO) + while True: + if iterator_and_future_pairs: + iterator, future = iterator_and_future_pairs.popleft() + try: + element_to_yield.append(await future) + iterator_to_queue = iterator + except StopAsyncIteration: + pass + except Exception as e: + element_to_yield.append(_RaisingAsyncIterator.ExceptionContainer(e)) + iterator_to_queue = iterator + + # queue tasks up to buffersize + while len(iterator_and_future_pairs) < self.buffersize: + if not iterator_to_queue: + try: + iterable = await self.iterables_iterator.__anext__() + except StopAsyncIteration: + break + try: + iterator_to_queue = aiter_wo_stopasynciteration(iterable) + except Exception as e: + yield _RaisingAsyncIterator.ExceptionContainer(e) + continue + future = self.get_event_loop().create_task( + awaitable_to_coroutine( + cast(AsyncIterator, iterator_to_queue).__anext__() + ) + ) + iterator_and_future_pairs.append( + (cast(AsyncIterator, iterator_to_queue), future) + ) + iterator_to_queue = None + if element_to_yield: + yield element_to_yield.pop() + if not iterator_and_future_pairs: + break + + +class ConcurrentAFlattenAsyncIterator(_RaisingAsyncIterator[T]): + def __init__( + self, + iterables_iterator: AsyncIterator[AsyncIterable[T]], + concurrency: int, + buffersize: int, + ) -> None: + super().__init__( + _ConcurrentAFlattenAsyncIterable( + iterables_iterator, + concurrency, + buffersize, + ).__aiter__() + ) diff --git a/streamable/functions.py b/streamable/functions.py index aa369fd0..601831f2 100644 --- a/streamable/functions.py +++ b/streamable/functions.py @@ -4,6 +4,7 @@ from operator import itemgetter from typing import ( Any, + AsyncIterable, Callable, Coroutine, Iterable, @@ -17,9 +18,12 @@ ) from streamable.iterators import ( - AsyncConcurrentMapIterator, + AFlattenIterator, CatchIterator, + 
ConcurrentAFlattenIterator, + ConcurrentAMapIterator, ConcurrentFlattenIterator, + ConcurrentMapIterator, ConsecutiveDistinctIterator, CountAndPredicateSkipIterator, CountSkipIterator, @@ -29,13 +33,12 @@ GroupbyIterator, GroupIterator, ObserveIterator, - OSConcurrentMapIterator, PredicateSkipIterator, PredicateTruncateIterator, YieldsPerPeriodThrottleIterator, ) from streamable.util.constants import NO_REPLACEMENT -from streamable.util.functiontools import wrap_error +from streamable.util.functiontools import syncify, wrap_error from streamable.util.validationtools import ( validate_concurrency, validate_errors, @@ -79,6 +82,25 @@ def catch( ) +def acatch( + iterator: Iterator[T], + errors: Union[ + Optional[Type[Exception]], Iterable[Optional[Type[Exception]]] + ] = Exception, + *, + when: Optional[Callable[[Exception], Coroutine[Any, Any, Any]]] = None, + replacement: T = NO_REPLACEMENT, # type: ignore + finally_raise: bool = False, +) -> Iterator[T]: + return catch( + iterator, + errors, + when=syncify(when) if when else None, + replacement=replacement, + finally_raise=finally_raise, + ) + + def distinct( iterator: Iterator[T], key: Optional[Callable[[T], Any]] = None, @@ -92,6 +114,19 @@ def distinct( return DistinctIterator(iterator, key) +def adistinct( + iterator: Iterator[T], + key: Optional[Callable[[T], Any]] = None, + *, + consecutive_only: bool = False, +) -> Iterator[T]: + return distinct( + iterator, + syncify(key) if key else None, + consecutive_only=consecutive_only, + ) + + def flatten(iterator: Iterator[Iterable[T]], *, concurrency: int = 1) -> Iterator[T]: validate_iterator(iterator) validate_concurrency(concurrency) @@ -105,6 +140,21 @@ def flatten(iterator: Iterator[Iterable[T]], *, concurrency: int = 1) -> Iterato ) +def aflatten( + iterator: Iterator[AsyncIterable[T]], *, concurrency: int = 1 +) -> Iterator[T]: + validate_iterator(iterator) + validate_concurrency(concurrency) + if concurrency == 1: + return AFlattenIterator(iterator) + else: + return ConcurrentAFlattenIterator( + iterator, + concurrency=concurrency, + buffersize=concurrency, + ) + + def group( iterator: Iterator[T], size: Optional[int] = None, @@ -120,6 +170,21 @@ def group( return map(itemgetter(1), GroupbyIterator(iterator, by, size, interval)) +def agroup( + iterator: Iterator[T], + size: Optional[int] = None, + *, + interval: Optional[datetime.timedelta] = None, + by: Optional[Callable[[T], Coroutine[Any, Any, Any]]] = None, +) -> Iterator[List[T]]: + return group( + iterator, + size, + interval=interval, + by=syncify(by) if by else None, + ) + + def groupby( iterator: Iterator[T], key: Callable[[T], U], @@ -133,6 +198,21 @@ def groupby( return GroupbyIterator(iterator, key, size, interval) +def agroupby( + iterator: Iterator[T], + key: Callable[[T], Coroutine[Any, Any, U]], + *, + size: Optional[int] = None, + interval: Optional[datetime.timedelta] = None, +) -> Iterator[Tuple[U, List[T]]]: + return groupby( + iterator, + syncify(key), + size=size, + interval=interval, + ) + + def map( transformation: Callable[[T], U], iterator: Iterator[T], @@ -149,7 +229,7 @@ def map( if concurrency == 1: return builtins.map(wrap_error(transformation, StopIteration), iterator) else: - return OSConcurrentMapIterator( + return ConcurrentMapIterator( iterator, transformation, concurrency=concurrency, @@ -170,7 +250,9 @@ def amap( # validate_not_none(transformation, "transformation") # validate_not_none(ordered, "ordered") validate_concurrency(concurrency) - return AsyncConcurrentMapIterator( + if concurrency == 
1: + return map(syncify(transformation), iterator) + return ConcurrentAMapIterator( iterator, transformation, buffersize=concurrency, @@ -201,6 +283,19 @@ def skip( return iterator +def askip( + iterator: Iterator[T], + count: Optional[int] = None, + *, + until: Optional[Callable[[T], Coroutine[Any, Any, Any]]] = None, +) -> Iterator[T]: + return skip( + iterator, + count, + until=syncify(until) if until else None, + ) + + def throttle( iterator: Iterator[T], count: Optional[int], @@ -227,3 +322,16 @@ def truncate( if when is not None: iterator = PredicateTruncateIterator(iterator, when) return iterator + + +def atruncate( + iterator: Iterator[T], + count: Optional[int] = None, + *, + when: Optional[Callable[[T], Coroutine[Any, Any, Any]]] = None, +) -> Iterator[T]: + return truncate( + iterator, + count, + when=syncify(when) if when else None, + ) diff --git a/streamable/iterators.py b/streamable/iterators.py index d1713a52..30385944 100644 --- a/streamable/iterators.py +++ b/streamable/iterators.py @@ -1,4 +1,3 @@ -import asyncio import datetime import multiprocessing import queue @@ -10,6 +9,9 @@ from math import ceil from typing import ( Any, + AsyncIterable, + AsyncIterator, + Awaitable, Callable, ContextManager, Coroutine, @@ -29,7 +31,16 @@ cast, ) -from streamable.util.functiontools import iter_wo_stopiteration, wrap_error +from streamable.util.asynctools import ( + GetEventLoopMixin, + awaitable_to_coroutine, + empty_aiter, +) +from streamable.util.functiontools import ( + aiter_wo_stopiteration, + iter_wo_stopiteration, + wrap_error, +) from streamable.util.loggertools import get_logger from streamable.util.validationtools import ( validate_base, @@ -79,7 +90,7 @@ def __init__( def __next__(self) -> T: while True: try: - return next(self.iterator) + return self.iterator.__next__() except StopIteration: if self._to_be_finally_raised: try: @@ -108,7 +119,7 @@ def __init__( def __next__(self) -> T: while True: - elem = next(self.iterator) + elem = self.iterator.__next__() key = self.key(elem) if self.key else elem if key not in self._already_seen: break @@ -127,7 +138,7 @@ def __init__( def __next__(self) -> T: while True: - elem = next(self.iterator) + elem = self.iterator.__next__() key = self.key(elem) if self.key else elem if key != self._last_key: break @@ -139,14 +150,35 @@ class FlattenIterator(Iterator[U]): def __init__(self, iterator: Iterator[Iterable[U]]) -> None: validate_iterator(iterator) self.iterator = iterator - self._current_iterator_elem: Iterator[U] = iter(tuple()) + self._current_iterator_elem: Iterator[U] = tuple().__iter__() def __next__(self) -> U: while True: try: - return next(self._current_iterator_elem) + return self._current_iterator_elem.__next__() except StopIteration: - self._current_iterator_elem = iter_wo_stopiteration(next(self.iterator)) + self._current_iterator_elem = iter_wo_stopiteration( + self.iterator.__next__() + ) + + +class AFlattenIterator(Iterator[U], GetEventLoopMixin): + def __init__(self, iterator: Iterator[AsyncIterable[U]]) -> None: + validate_iterator(iterator) + self.iterator = iterator + + self._current_iterator_elem: AsyncIterator[U] = empty_aiter() + + def __next__(self) -> U: + while True: + try: + return self.get_event_loop().run_until_complete( + self._current_iterator_elem.__anext__() + ) + except StopAsyncIteration: + self._current_iterator_elem = aiter_wo_stopiteration( + self.iterator.__next__() + ) class _GroupIteratorMixin(Generic[T]): @@ -203,7 +235,7 @@ def __next__(self) -> List[T]: while len(self._current_group) < 
self.size and ( not self._interval_seconds_have_elapsed() or not self._current_group ): - self._current_group.append(next(self.iterator)) + self._current_group.append(self.iterator.__next__()) except Exception as e: if not self._current_group: raise @@ -228,7 +260,7 @@ def __init__( self._groups_by: DefaultDict[U, List[T]] = defaultdict(list) def _group_next_elem(self) -> None: - elem = next(self.iterator) + elem = self.iterator.__next__() self._groups_by[self.key(elem)].append(elem) def _pop_full_group(self) -> Optional[Tuple[U, List[T]]]: @@ -238,11 +270,11 @@ def _pop_full_group(self) -> Optional[Tuple[U, List[T]]]: return None def _pop_first_group(self) -> Tuple[U, List[T]]: - first_key: U = next(iter(self._groups_by), cast(U, ...)) + first_key: U = self._groups_by.__iter__().__next__() return first_key, self._groups_by.pop(first_key) def _pop_largest_group(self) -> Tuple[U, List[T]]: - largest_group_key: Any = next(iter(self._groups_by), ...) + largest_group_key: Any = self._groups_by.__iter__().__next__() for key, group in self._groups_by.items(): if len(group) > len(self._groups_by[largest_group_key]): @@ -279,11 +311,11 @@ def __next__(self) -> Tuple[U, List[T]]: except StopIteration: self._is_exhausted = True - return next(self) + return self.__next__() except Exception as e: self._to_be_raised = e - return next(self) + return self.__next__() class CountSkipIterator(Iterator[T]): @@ -298,11 +330,11 @@ def __init__(self, iterator: Iterator[T], count: int) -> None: def __next__(self) -> T: if not self._done_skipping: while self._n_skipped < self.count: - next(self.iterator) + self.iterator.__next__() # do not count exceptions as skipped elements self._n_skipped += 1 self._done_skipping = True - return next(self.iterator) + return self.iterator.__next__() class PredicateSkipIterator(Iterator[T]): @@ -313,10 +345,10 @@ def __init__(self, iterator: Iterator[T], until: Callable[[T], Any]) -> None: self._done_skipping = False def __next__(self) -> T: - elem = next(self.iterator) + elem = self.iterator.__next__() if not self._done_skipping: while not self.until(elem): - elem = next(self.iterator) + elem = self.iterator.__next__() self._done_skipping = True return elem @@ -334,10 +366,10 @@ def __init__( self._done_skipping = False def __next__(self) -> T: - elem = next(self.iterator) + elem = self.iterator.__next__() if not self._done_skipping: while self._n_skipped < self.count and not self.until(elem): - elem = next(self.iterator) + elem = self.iterator.__next__() # do not count exceptions as skipped elements self._n_skipped += 1 self._done_skipping = True @@ -355,7 +387,7 @@ def __init__(self, iterator: Iterator[T], count: int) -> None: def __next__(self) -> T: if self._current_count == self.count: raise StopIteration() - elem = next(self.iterator) + elem = self.iterator.__next__() self._current_count += 1 return elem @@ -370,7 +402,7 @@ def __init__(self, iterator: Iterator[T], when: Callable[[T], Any]) -> None: def __next__(self) -> T: if self._satisfied: raise StopIteration() - elem = next(self.iterator) + elem = self.iterator.__next__() if self.when(elem): self._satisfied = True raise StopIteration() @@ -406,7 +438,7 @@ def _log(self) -> None: def __next__(self) -> T: try: - elem = next(self.iterator) + elem = self.iterator.__next__() self._n_nexts += 1 self._n_yields += 1 return elem @@ -442,7 +474,7 @@ def __init__( def safe_next(self) -> Tuple[Optional[T], Optional[Exception]]: try: - return next(self.iterator), None + return self.iterator.__next__(), None except StopIteration: 
raise except Exception as e: @@ -483,13 +515,13 @@ def __init__( self.iterator = iterator def __next__(self) -> T: - elem = next(self.iterator) + elem = self.iterator.__next__() if isinstance(elem, self.ExceptionContainer): raise elem.exception return elem -class _ConcurrentMapIterable( +class _ConcurrentMapIterableMixin( Generic[T, U], ABC, Iterable[Union[U, _RaisingIterator.ExceptionContainer]] ): """ @@ -536,17 +568,21 @@ def __iter__(self) -> Iterator[Union[U, _RaisingIterator.ExceptionContainer]]: # queue tasks up to buffersize with suppress(StopIteration): while len(future_results) < self.buffersize: - future_results.add_future(self._launch_task(next(self.iterator))) + future_results.add_future( + self._launch_task(self.iterator.__next__()) + ) # wait, queue, yield while future_results: - result = next(future_results) + result = future_results.__next__() with suppress(StopIteration): - future_results.add_future(self._launch_task(next(self.iterator))) + future_results.add_future( + self._launch_task(self.iterator.__next__()) + ) yield result -class _OSConcurrentMapIterable(_ConcurrentMapIterable[T, U]): +class _ConcurrentMapIterable(_ConcurrentMapIterableMixin[T, U]): def __init__( self, iterator: Iterator[T], @@ -597,7 +633,7 @@ def _future_result_collection( ) -class OSConcurrentMapIterator(_RaisingIterator[U]): +class ConcurrentMapIterator(_RaisingIterator[U]): def __init__( self, iterator: Iterator[T], @@ -608,20 +644,18 @@ def __init__( via: "Literal['thread', 'process']", ) -> None: super().__init__( - iter( - _OSConcurrentMapIterable( - iterator, - transformation, - concurrency, - buffersize, - ordered, - via, - ) - ) + _ConcurrentMapIterable( + iterator, + transformation, + concurrency, + buffersize, + ordered, + via, + ).__iter__() ) -class _AsyncConcurrentMapIterable(_ConcurrentMapIterable[T, U]): +class _ConcurrentAMapIterable(_ConcurrentMapIterableMixin[T, U], GetEventLoopMixin): def __init__( self, iterator: Iterator[T], @@ -631,12 +665,6 @@ def __init__( ) -> None: super().__init__(iterator, buffersize, ordered) self.transformation = wrap_error(transformation, StopIteration) - self.event_loop: asyncio.AbstractEventLoop - try: - self.event_loop = asyncio.get_event_loop() - except RuntimeError: - self.event_loop = asyncio.new_event_loop() - asyncio.set_event_loop(self.event_loop) async def _safe_transformation( self, elem: T @@ -656,19 +684,19 @@ def _launch_task( ) -> "Future[Union[U, _RaisingIterator.ExceptionContainer]]": return cast( "Future[Union[U, _RaisingIterator.ExceptionContainer]]", - self.event_loop.create_task(self._safe_transformation(elem)), + self.get_event_loop().create_task(self._safe_transformation(elem)), ) def _future_result_collection( self, ) -> FutureResultCollection[Union[U, _RaisingIterator.ExceptionContainer]]: if self.ordered: - return FIFOAsyncFutureResultCollection(self.event_loop) + return FIFOAsyncFutureResultCollection(self.get_event_loop()) else: - return FDFOAsyncFutureResultCollection(self.event_loop) + return FDFOAsyncFutureResultCollection(self.get_event_loop()) -class AsyncConcurrentMapIterator(_RaisingIterator[U]): +class ConcurrentAMapIterator(_RaisingIterator[U]): def __init__( self, iterator: Iterator[T], @@ -677,14 +705,12 @@ def __init__( ordered: bool, ) -> None: super().__init__( - iter( - _AsyncConcurrentMapIterable( - iterator, - transformation, - buffersize, - ordered, - ) - ) + _ConcurrentAMapIterable( + iterator, + transformation, + buffersize, + ordered, + ).__iter__() ) @@ -727,7 +753,7 @@ def __iter__(self) -> 
Iterator[Union[T, _RaisingIterator.ExceptionContainer]]: while len(iterator_and_future_pairs) < self.buffersize: if not iterator_to_queue: try: - iterable = next(self.iterables_iterator) + iterable = self.iterables_iterator.__next__() except StopIteration: break try: @@ -752,11 +778,90 @@ def __init__( buffersize: int, ) -> None: super().__init__( - iter( - _ConcurrentFlattenIterable( - iterables_iterator, - concurrency, - buffersize, + _ConcurrentFlattenIterable( + iterables_iterator, + concurrency, + buffersize, + ).__iter__() + ) + + +class _ConcurrentAFlattenIterable( + Iterable[Union[T, _RaisingIterator.ExceptionContainer]], GetEventLoopMixin +): + def __init__( + self, + iterables_iterator: Iterator[AsyncIterable[T]], + concurrency: int, + buffersize: int, + ) -> None: + validate_iterator(iterables_iterator) + validate_concurrency(concurrency) + self.iterables_iterator = iterables_iterator + self.concurrency = concurrency + self.buffersize = buffersize + + def __iter__(self) -> Iterator[Union[T, _RaisingIterator.ExceptionContainer]]: + iterator_and_future_pairs: Deque[Tuple[AsyncIterator[T], Awaitable[T]]] = ( + deque() + ) + element_to_yield: Deque[Union[T, _RaisingIterator.ExceptionContainer]] = deque( + maxlen=1 + ) + iterator_to_queue: Optional[AsyncIterator[T]] = None + # wait, queue, yield (FIFO) + while True: + if iterator_and_future_pairs: + iterator, future = iterator_and_future_pairs.popleft() + try: + element_to_yield.append( + self.get_event_loop().run_until_complete(future) + ) + iterator_to_queue = iterator + except StopAsyncIteration: + pass + except Exception as e: + element_to_yield.append(_RaisingIterator.ExceptionContainer(e)) + iterator_to_queue = iterator + + # queue tasks up to buffersize + while len(iterator_and_future_pairs) < self.buffersize: + if not iterator_to_queue: + try: + iterable = self.iterables_iterator.__next__() + except StopIteration: + break + try: + iterator_to_queue = aiter_wo_stopiteration(iterable) + except Exception as e: + yield _RaisingIterator.ExceptionContainer(e) + continue + future = self.get_event_loop().create_task( + awaitable_to_coroutine( + cast(AsyncIterator, iterator_to_queue).__anext__() + ) ) - ) + iterator_and_future_pairs.append( + (cast(AsyncIterator, iterator_to_queue), future) + ) + iterator_to_queue = None + if element_to_yield: + yield element_to_yield.pop() + if not iterator_and_future_pairs: + break + + +class ConcurrentAFlattenIterator(_RaisingIterator[T]): + def __init__( + self, + iterables_iterator: Iterator[AsyncIterable[T]], + concurrency: int, + buffersize: int, + ) -> None: + super().__init__( + _ConcurrentAFlattenIterable( + iterables_iterator, + concurrency, + buffersize, + ).__iter__() ) diff --git a/streamable/stream.py b/streamable/stream.py index d25693c2..de8f2444 100644 --- a/streamable/stream.py +++ b/streamable/stream.py @@ -5,10 +5,14 @@ from typing import ( TYPE_CHECKING, Any, + AsyncIterable, + AsyncIterator, + Awaitable, Callable, Collection, Coroutine, Dict, + Generator, Generic, Iterable, Iterator, @@ -25,6 +29,7 @@ ) from streamable.util.constants import NO_REPLACEMENT +from streamable.util.functiontools import asyncify from streamable.util.loggertools import get_logger from streamable.util.validationtools import ( validate_concurrency, @@ -54,17 +59,29 @@ V = TypeVar("V") -class Stream(Iterable[T]): +class Stream(Iterable[T], AsyncIterable[T], Awaitable["Stream[T]"]): __slots__ = ("_source", "_upstream") # fmt: off @overload def __init__(self, source: Iterable[T]) -> None: ... 
@overload + def __init__(self, source: AsyncIterable[T]) -> None: ... + @overload def __init__(self, source: Callable[[], Iterable[T]]) -> None: ... + @overload + def __init__(self, source: Callable[[], AsyncIterable[T]]) -> None: ... # fmt: on - def __init__(self, source: Union[Iterable[T], Callable[[], Iterable[T]]]) -> None: + def __init__( + self, + source: Union[ + Iterable[T], + Callable[[], Iterable[T]], + AsyncIterable[T], + Callable[[], AsyncIterable[T]], + ], + ) -> None: """ A `Stream[T]` decorates an `Iterable[T]` with a **fluent interface** enabling the chaining of lazy operations. @@ -84,7 +101,11 @@ def upstream(self) -> "Optional[Stream]": return self._upstream @property - def source(self) -> Union[Iterable, Callable[[], Iterable]]: + def source( + self, + ) -> Union[ + Iterable, Callable[[], Iterable], AsyncIterable, Callable[[], AsyncIterable] + ]: """ Returns: Callable[[], Iterable]: Function called at iteration time (i.e. by `__iter__`) to get a fresh source iterable. @@ -103,6 +124,11 @@ def __iter__(self) -> Iterator[T]: return self.accept(IteratorVisitor[T]()) + def __aiter__(self) -> AsyncIterator[T]: + from streamable.visitors.aiterator import AsyncIteratorVisitor + + return self.accept(AsyncIteratorVisitor[T]()) + def __repr__(self) -> str: from streamable.visitors.representation import ReprVisitor @@ -134,6 +160,16 @@ def __call__(self) -> "Stream[T]": self.count() return self + def __await__(self) -> Generator[int, None, "Stream[T]"]: + """ + Iterates over this stream until exhaustion. + + Returns: + Stream[T]: self. + """ + yield from (self.acount().__await__()) + return self + def accept(self, visitor: "Visitor[V]") -> V: """ Entry point to visit this stream (en.wikipedia.org/wiki/Visitor_pattern). @@ -175,6 +211,40 @@ def catch( finally_raise=finally_raise, ) + def acatch( + self, + errors: Union[ + Optional[Type[Exception]], Iterable[Optional[Type[Exception]]] + ] = Exception, + *, + when: Optional[Callable[[Exception], Coroutine[Any, Any, Any]]] = None, + replacement: T = NO_REPLACEMENT, # type: ignore + finally_raise: bool = False, + ) -> "Stream[T]": + """ + Catches the upstream exceptions if they are instances of `errors` type and they satisfy the `when` predicate. + Optionally yields a `replacement` value. + If any exception was caught during the iteration and `finally_raise=True`, the first caught exception will be raised when the iteration finishes. + + Args: + errors (Optional[Type[Exception]], Iterable[Optional[Type[Exception]]], optional): The exception type to catch, or an iterable of exception types to catch (default: catches all `Exception`s) + when (Optional[Callable[[Exception], Coroutine[Any, Any, Any]]], optional): An additional condition that must be satisfied to catch the exception, i.e. `when(exception)` must be truthy. (default: no additional condition) + replacement (T, optional): The value to yield when an exception is caught. (default: do not yield any replacement value) + finally_raise (bool, optional): If True the first exception caught is raised when upstream's iteration ends. (default: iteration ends without raising) + + Returns: + Stream[T]: A stream of upstream elements catching the eligible exceptions. + """ + validate_errors(errors) + # validate_not_none(finally_raise, "finally_raise") + return ACatchStream( + self, + errors, + when=when, + replacement=replacement, + finally_raise=finally_raise, + ) + def count(self) -> int: """ Iterates over this stream until exhaustion and returns the count of elements. 
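> Note (editorial, not part of the patch): the `__await__`, `acount`, and `acatch` additions in the hunks above make a `Stream` directly awaitable. A minimal usage sketch; the pipeline and its values are hypothetical and only illustrate the new API:

```python
import asyncio

from streamable import Stream


async def main() -> None:
    # hypothetical pipeline: "oops" makes int() raise a ValueError,
    # which acatch replaces with 0
    parsed: Stream[int] = (
        Stream(["1", "2", "oops", "4"])
        .map(int)
        .acatch(ValueError, replacement=0)
    )
    # a Stream is Awaitable: awaiting it exhausts it (via acount) and returns self
    result = await parsed
    assert result is parsed
    # a Stream is an AsyncIterable too
    assert [n async for n in parsed] == [1, 2, 0, 4]


asyncio.run(main())
```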
@@ -185,6 +255,18 @@ def count(self) -> int: return sum(1 for _ in self) + async def acount(self) -> int: + """ + Iterates over this stream until exhaustion and returns the count of elements. + + Returns: + int: Number of elements yielded during an entire iteration over this stream. + """ + count = 0 + async for _ in self: + count += 1 + return count + def display(self, level: int = logging.INFO) -> "Stream[T]": """ Logs (INFO level) a representation of the stream. @@ -226,6 +308,33 @@ def distinct( # validate_not_none(consecutive_only, "consecutive_only") return DistinctStream(self, key, consecutive_only) + def adistinct( + self, + key: Optional[Callable[[T], Coroutine[Any, Any, Any]]] = None, + *, + consecutive_only: bool = False, + ) -> "Stream[T]": + """ + Filters the stream to yield only distinct elements. + If a deduplication `key` is specified, `foo` and `bar` are treated as duplicates when `key(foo) == key(bar)`. + + + Among duplicates, the first encountered occurence in upstream order is yielded. + + Warning: + During iteration, the distinct elements yielded are retained in memory to perform deduplication. + Alternatively, remove only consecutive duplicates without memory footprint by setting `consecutive_only=True`. + + Args: + key (Callable[[T], Coroutine[Any, Any, Any]], optional): Elements are deduplicated based on `key(elem)`. (default: the deduplication is performed on the elements themselves) + consecutive_only (bool, optional): Whether to deduplicate only consecutive duplicates, or globally. (default: the deduplication is global) + + Returns: + Stream: A stream containing only unique upstream elements. + """ + # validate_not_none(consecutive_only, "consecutive_only") + return ADistinctStream(self, key, consecutive_only) + def filter(self, when: Callable[[T], Any] = bool) -> "Stream[T]": """ Filters the stream to yield only elements satisfying the `when` predicate. @@ -240,6 +349,24 @@ def filter(self, when: Callable[[T], Any] = bool) -> "Stream[T]": # validate_not_none(when, "when") return FilterStream(self, cast(Optional[Callable[[T], Any]], when) or bool) + def afilter(self, when: Callable[[T], Coroutine[Any, Any, Any]]) -> "Stream[T]": + """ + Filters the stream to yield only elements satisfying the `when` predicate. + + Args: + when (Callable[[T], Coroutine[Any, Any, Any]], optional): An element is kept if `when(elem)` is truthy. (default: keeps truthy elements) + + Returns: + Stream[T]: A stream of upstream elements satisfying the `when` predicate. + """ + # Unofficially accept `stream.afilter(None)`, behaving as builtin `filter(None, iter)` + # validate_not_none(when, "when") + return AFilterStream( + self, + cast(Optional[Callable[[T], Coroutine[Any, Any, Any]]], when) + or asyncify(bool), + ) + # fmt: off @overload def flatten( @@ -317,13 +444,43 @@ def flatten(self: "Stream[Iterable[U]]", *, concurrency: int = 1) -> "Stream[U]" Iterates over upstream elements assumed to be iterables, and individually yields their items. Args: - concurrency (int, optional): Represents both the number of threads used to concurrently flatten the upstream iterables and the number of iterables buffered. (default: no concurrency) + concurrency (int, optional): Number of upstream iterables concurrently flattened via threads. (default: no concurrency) Returns: Stream[R]: A stream of flattened elements from upstream iterables. 
""" validate_concurrency(concurrency) return FlattenStream(self, concurrency) + # fmt: off + @overload + def aflatten( + self: "Stream[AsyncIterator[U]]", + *, + concurrency: int = 1, + ) -> "Stream[U]": ... + + @overload + def aflatten( + self: "Stream[AsyncIterable[U]]", + *, + concurrency: int = 1, + ) -> "Stream[U]": ... + # fmt: on + + def aflatten( + self: "Stream[AsyncIterable[U]]", *, concurrency: int = 1 + ) -> "Stream[U]": + """ + Iterates over upstream elements assumed to be async iterables, and individually yields their items. + + Args: + concurrency (int, optional): Number of upstream async iterables concurrently flattened. (default: no concurrency) + Returns: + Stream[R]: A stream of flattened elements from upstream async iterables. + """ + validate_concurrency(concurrency) + return AFlattenStream(self, concurrency) + def foreach( self, effect: Callable[[T], Any], @@ -352,7 +509,7 @@ def foreach( def aforeach( self, - effect: Callable[[T], Coroutine], + effect: Callable[[T], Coroutine[Any, Any, Any]], *, concurrency: int = 1, ordered: bool = True, @@ -362,7 +519,7 @@ def aforeach( If the `effect(elem)` coroutine throws an exception then it will be thrown and `elem` will not be yielded. Args: - effect (Callable[[T], Any]): The asynchronous function to be applied to each element as a side effect. + effect (Callable[[T], Coroutine[Any, Any, Any]]): The asynchronous function to be applied to each element as a side effect. concurrency (int, optional): Represents both the number of async tasks concurrently applying the `effect` and the size of the buffer containing not-yet-yielded elements. If the buffer is full, the iteration over the upstream is paused until an element is yielded from the buffer. (default: no concurrency) ordered (bool, optional): If `concurrency` > 1, whether to preserve the order of upstream elements or to yield them as soon as they are processed. (default: preserves upstream order) Returns: @@ -401,6 +558,34 @@ def group( validate_optional_positive_interval(interval) return GroupStream(self, size, interval, by) + def agroup( + self, + size: Optional[int] = None, + *, + interval: Optional[datetime.timedelta] = None, + by: Optional[Callable[[T], Coroutine[Any, Any, Any]]] = None, + ) -> "Stream[List[T]]": + """ + Groups upstream elements into lists. + + A group is yielded when any of the following conditions is met: + - The group reaches `size` elements. + - `interval` seconds have passed since the last group was yielded. + - The upstream source is exhausted. + + If `by` is specified, groups will only contain elements sharing the same `by(elem)` value (see `.agroupby` for `(key, elements)` pairs). + + Args: + size (Optional[int], optional): The maximum number of elements per group. (default: no size limit) + interval (float, optional): Yields a group if `interval` seconds have passed since the last group was yielded. (default: no interval limit) + by (Optional[Callable[[T], Coroutine[Any, Any, Any]]], optional): If specified, groups will only contain elements sharing the same `by(elem)` value. (default: does not co-group elements) + Returns: + Stream[List[T]]: A stream of upstream elements grouped into lists. 
+ """ + validate_group_size(size) + validate_optional_positive_interval(interval) + return AGroupStream(self, size, interval, by) + def groupby( self, key: Callable[[T], U], @@ -427,6 +612,32 @@ def groupby( # validate_not_none(key, "key") return GroupbyStream(self, key, size, interval) + def agroupby( + self, + key: Callable[[T], Coroutine[Any, Any, U]], + *, + size: Optional[int] = None, + interval: Optional[datetime.timedelta] = None, + ) -> "Stream[Tuple[U, List[T]]]": + """ + Groups upstream elements into `(key, elements)` tuples. + + A group is yielded when any of the following conditions is met: + - A group reaches `size` elements. + - `interval` seconds have passed since the last group was yielded. + - The upstream source is exhausted. + + Args: + key (Callable[[T], Coroutine[Any, Any, U]]): An async function that returns the group key for an element. + size (Optional[int], optional): The maximum number of elements per group. (default: no size limit) + interval (Optional[datetime.timedelta], optional): If specified, yields a group if `interval` seconds have passed since the last group was yielded. (default: no interval limit) + + Returns: + Stream[Tuple[U, List[T]]]: A stream of upstream elements grouped by key, as `(key, elements)` tuples. + """ + # validate_not_none(key, "key") + return AGroupbyStream(self, key, size, interval) + def map( self, transformation: Callable[[T], U], @@ -531,6 +742,26 @@ def skip( validate_optional_count(count) return SkipStream(self, count, until) + def askip( + self, + count: Optional[int] = None, + *, + until: Optional[Callable[[T], Coroutine[Any, Any, Any]]] = None, + ) -> "Stream[T]": + """ + Skips elements until `until(elem)` is truthy, or `count` elements have been skipped. + If both `count` and `until` are set, skipping stops as soon as either condition is met. + + Args: + count (Optional[int], optional): The maximum number of elements to skip. (default: no count-based skipping) + until (Optional[Callable[[T], Coroutine[Any, Any, Any]]], optional): Elements are skipped until the first one for which `until(elem)` is truthy. This element and all the subsequent ones will be yielded. (default: no predicate-based skipping) + + Returns: + Stream: A stream of the upstream elements remaining after skipping. + """ + validate_optional_count(count) + return ASkipStream(self, count, until) + def throttle( self, count: Optional[int] = None, @@ -604,6 +835,26 @@ def truncate( validate_optional_count(count) return TruncateStream(self, count, when) + def atruncate( + self, + count: Optional[int] = None, + *, + when: Optional[Callable[[T], Coroutine[Any, Any, Any]]] = None, + ) -> "Stream[T]": + """ + Stops an iteration as soon as `when(elem)` is truthy, or `count` elements have been yielded. + If both `count` and `when` are set, truncation occurs as soon as either condition is met. + + Args: + count (int, optional): The maximum number of elements to yield. (default: no count-based truncation) + when (Optional[Callable[[T], Coroutine[Any, Any, Any]]], optional): An async predicate function that determines when to stop the iteration. Iteration stops immediately after encountering the first element for which `when(elem)` is truthy, and that element will not be yielded. (default: no predicate-based truncation) + + Returns: + Stream[T]: A stream of at most `count` upstream elements not satisfying the `when` predicate. 
+ """ + validate_optional_count(count) + return ATruncateStream(self, count, when) + class DownStream(Stream[U], Generic[T, U]): """ @@ -621,7 +872,11 @@ def __deepcopy__(self, memo: Dict[int, Any]) -> "DownStream[T, U]": return new @property - def source(self) -> Union[Iterable, Callable[[], Iterable]]: + def source( + self, + ) -> Union[ + Iterable, Callable[[], Iterable], AsyncIterable, Callable[[], AsyncIterable] + ]: return self._upstream.source @property @@ -654,6 +909,27 @@ def accept(self, visitor: "Visitor[V]") -> V: return visitor.visit_catch_stream(self) +class ACatchStream(DownStream[T, T]): + __slots__ = ("_upstream", "_errors", "_when", "_replacement", "_finally_raise") + + def __init__( + self, + upstream: Stream[T], + errors: Union[Optional[Type[Exception]], Iterable[Optional[Type[Exception]]]], + when: Optional[Callable[[Exception], Coroutine[Any, Any, Any]]], + replacement: T, + finally_raise: bool, + ) -> None: + super().__init__(upstream) + self._errors = errors + self._when = when + self._replacement = replacement + self._finally_raise = finally_raise + + def accept(self, visitor: "Visitor[V]") -> V: + return visitor.visit_acatch_stream(self) + + class DistinctStream(DownStream[T, T]): __slots__ = ("_upstream", "_key", "_consecutive_only") @@ -671,6 +947,23 @@ def accept(self, visitor: "Visitor[V]") -> V: return visitor.visit_distinct_stream(self) +class ADistinctStream(DownStream[T, T]): + __slots__ = ("_upstream", "_key", "_consecutive_only") + + def __init__( + self, + upstream: Stream[T], + key: Optional[Callable[[T], Coroutine[Any, Any, Any]]], + consecutive_only: bool, + ) -> None: + super().__init__(upstream) + self._key = key + self._consecutive_only = consecutive_only + + def accept(self, visitor: "Visitor[V]") -> V: + return visitor.visit_adistinct_stream(self) + + class FilterStream(DownStream[T, T]): __slots__ = ("_upstream", "_when") @@ -682,6 +975,19 @@ def accept(self, visitor: "Visitor[V]") -> V: return visitor.visit_filter_stream(self) +class AFilterStream(DownStream[T, T]): + __slots__ = ("_upstream", "_when") + + def __init__( + self, upstream: Stream[T], when: Callable[[T], Coroutine[Any, Any, Any]] + ) -> None: + super().__init__(upstream) + self._when = when + + def accept(self, visitor: "Visitor[V]") -> V: + return visitor.visit_afilter_stream(self) + + class FlattenStream(DownStream[Iterable[T], T]): __slots__ = ("_upstream", "_concurrency") @@ -693,6 +999,17 @@ def accept(self, visitor: "Visitor[V]") -> V: return visitor.visit_flatten_stream(self) +class AFlattenStream(DownStream[AsyncIterable[T], T]): + __slots__ = ("_upstream", "_concurrency") + + def __init__(self, upstream: Stream[AsyncIterable[T]], concurrency: int) -> None: + super().__init__(upstream) + self._concurrency = concurrency + + def accept(self, visitor: "Visitor[V]") -> V: + return visitor.visit_aflatten_stream(self) + + class ForeachStream(DownStream[T, T]): __slots__ = ("_upstream", "_effect", "_concurrency", "_ordered", "_via") @@ -752,6 +1069,25 @@ def accept(self, visitor: "Visitor[V]") -> V: return visitor.visit_group_stream(self) +class AGroupStream(DownStream[T, List[T]]): + __slots__ = ("_upstream", "_size", "_interval", "_by") + + def __init__( + self, + upstream: Stream[T], + size: Optional[int], + interval: Optional[datetime.timedelta], + by: Optional[Callable[[T], Coroutine[Any, Any, Any]]], + ) -> None: + super().__init__(upstream) + self._size = size + self._interval = interval + self._by = by + + def accept(self, visitor: "Visitor[V]") -> V: + return 
visitor.visit_agroup_stream(self) + + class GroupbyStream(DownStream[T, Tuple[U, List[T]]]): __slots__ = ("_upstream", "_key", "_size", "_interval") @@ -771,6 +1107,25 @@ def accept(self, visitor: "Visitor[V]") -> V: return visitor.visit_groupby_stream(self) +class AGroupbyStream(DownStream[T, Tuple[U, List[T]]]): + __slots__ = ("_upstream", "_key", "_size", "_interval") + + def __init__( + self, + upstream: Stream[T], + key: Callable[[T], Coroutine[Any, Any, U]], + size: Optional[int], + interval: Optional[datetime.timedelta], + ) -> None: + super().__init__(upstream) + self._key = key + self._size = size + self._interval = interval + + def accept(self, visitor: "Visitor[V]") -> V: + return visitor.visit_agroupby_stream(self) + + class MapStream(DownStream[T, U]): __slots__ = ("_upstream", "_transformation", "_concurrency", "_ordered", "_via") @@ -839,6 +1194,23 @@ def accept(self, visitor: "Visitor[V]") -> V: return visitor.visit_skip_stream(self) +class ASkipStream(DownStream[T, T]): + __slots__ = ("_upstream", "_count", "_until") + + def __init__( + self, + upstream: Stream[T], + count: Optional[int], + until: Optional[Callable[[T], Coroutine[Any, Any, Any]]], + ) -> None: + super().__init__(upstream) + self._count = count + self._until = until + + def accept(self, visitor: "Visitor[V]") -> V: + return visitor.visit_askip_stream(self) + + class ThrottleStream(DownStream[T, T]): __slots__ = ("_upstream", "_count", "_per") @@ -871,3 +1243,20 @@ def __init__( def accept(self, visitor: "Visitor[V]") -> V: return visitor.visit_truncate_stream(self) + + +class ATruncateStream(DownStream[T, T]): + __slots__ = ("_upstream", "_count", "_when") + + def __init__( + self, + upstream: Stream[T], + count: Optional[int] = None, + when: Optional[Callable[[T], Coroutine[Any, Any, Any]]] = None, + ) -> None: + super().__init__(upstream) + self._count = count + self._when = when + + def accept(self, visitor: "Visitor[V]") -> V: + return visitor.visit_atruncate_stream(self) diff --git a/streamable/util/asynctools.py b/streamable/util/asynctools.py new file mode 100644 index 00000000..439445e5 --- /dev/null +++ b/streamable/util/asynctools.py @@ -0,0 +1,26 @@ +import asyncio +from typing import AsyncIterator, Awaitable, Optional, TypeVar + +T = TypeVar("T") + + +async def awaitable_to_coroutine(aw: Awaitable[T]) -> T: + return await aw + + +async def empty_aiter() -> AsyncIterator: + return + yield + + +class GetEventLoopMixin: + _EVENT_LOOP_SINGLETON: Optional[asyncio.AbstractEventLoop] = None + + @classmethod + def get_event_loop(cls) -> asyncio.AbstractEventLoop: + try: + return asyncio.get_running_loop() + except RuntimeError: + if not cls._EVENT_LOOP_SINGLETON: + cls._EVENT_LOOP_SINGLETON = asyncio.new_event_loop() + return cls._EVENT_LOOP_SINGLETON diff --git a/streamable/util/errors.py b/streamable/util/errors.py new file mode 100644 index 00000000..3f5bcf54 --- /dev/null +++ b/streamable/util/errors.py @@ -0,0 +1,4 @@ +class WrappedError(Exception): + def __init__(self, error: Exception): + super().__init__(repr(error)) + self.error = error diff --git a/streamable/util/functiontools.py b/streamable/util/functiontools.py index 3925d84e..ef8ca70e 100644 --- a/streamable/util/functiontools.py +++ b/streamable/util/functiontools.py @@ -1,15 +1,23 @@ -from typing import Any, Callable, Coroutine, Generic, Tuple, Type, TypeVar, overload +from functools import partial +from typing import ( + Any, + AsyncIterable, + Callable, + Coroutine, + Generic, + Tuple, + Type, + TypeVar, + overload, +) + +from 
streamable.util.asynctools import GetEventLoopMixin
+from streamable.util.errors import WrappedError
 
 T = TypeVar("T")
 R = TypeVar("R")
 
 
-class WrappedError(Exception):
-    def __init__(self, error: Exception):
-        super().__init__(repr(error))
-        self.error = error
-
-
 class _ErrorWrappingDecorator(Generic[T, R]):
     def __init__(self, func: Callable[[T], R], error_type: Type[Exception]) -> None:
         self.func = func
@@ -26,7 +34,33 @@ def wrap_error(func: Callable[[T], R], error_type: Type[Exception]) -> Callable[
     return _ErrorWrappingDecorator(func, error_type)
 
 
+def awrap_error(
+    async_func: Callable[[T], Coroutine[Any, Any, R]], error_type: Type[Exception]
+) -> Callable[[T], Coroutine[Any, Any, R]]:
+    async def wrap(elem: T) -> R:
+        try:
+            coroutine = async_func(elem)
+            if not isinstance(coroutine, Coroutine):
+                raise TypeError(
+                    f"must be an async function i.e. a function returning a Coroutine but it returned a {type(coroutine)}"
+                )
+            return await coroutine
+        except error_type as e:
+            raise WrappedError(e) from e
+
+    return wrap
+
+
 iter_wo_stopiteration = wrap_error(iter, StopIteration)
+iter_wo_stopasynciteration = wrap_error(iter, StopAsyncIteration)
+
+
+def _aiter(aiterable: AsyncIterable):
+    return aiterable.__aiter__()
+
+
+aiter_wo_stopasynciteration = wrap_error(_aiter, StopAsyncIteration)
+aiter_wo_stopiteration = wrap_error(_aiter, StopIteration)
 
 
 class _Sidify(Generic[T]):
@@ -115,3 +149,28 @@ def add(a: int, b: int) -> int:
     ```
     """
     return _Star(func)
+
+
+class _Syncify(Generic[R], GetEventLoopMixin):
+    def __init__(self, async_func: Callable[[T], Coroutine[Any, Any, R]]) -> None:
+        self.async_func = async_func
+
+    def __call__(self, elem: T) -> R:
+        coroutine = self.async_func(elem)
+        if not isinstance(coroutine, Coroutine):
+            raise TypeError(
+                f"must be an async function i.e. a function returning a Coroutine but it returned a {type(coroutine)}"
+            )
+        return self.get_event_loop().run_until_complete(coroutine)
+
+
+def syncify(async_func: Callable[[T], Coroutine[Any, Any, R]]) -> Callable[[T], R]:
+    return _Syncify(async_func)
+
+
+async def _async_call(func: Callable[[T], R], o: T) -> R:
+    return func(o)
+
+
+def asyncify(func: Callable[[T], R]) -> Callable[[T], Coroutine[Any, Any, R]]:
+    return partial(_async_call, func)
diff --git a/streamable/util/futuretools.py b/streamable/util/futuretools.py
index ec551973..9b7af7d8 100644
--- a/streamable/util/futuretools.py
+++ b/streamable/util/futuretools.py
@@ -3,7 +3,7 @@
 from collections import deque
 from concurrent.futures import Future
 from contextlib import suppress
-from typing import Awaitable, Deque, Iterator, Sized, Type, TypeVar, cast
+from typing import AsyncIterator, Awaitable, Deque, Iterator, Sized, Type, TypeVar, cast
 
 with suppress(ImportError):
     from streamable.util.protocols import Queue
@@ -11,7 +11,7 @@
 T = TypeVar("T")
 
 
-class FutureResultCollection(Iterator[T], Sized, ABC):
+class FutureResultCollection(Iterator[T], AsyncIterator[T], Sized, ABC):
     """
     Iterator over added futures' results. Supports adding new futures after iteration started.
     """
 
     @abstractmethod
     def add_future(self, future: "Future[T]") -> None: ...
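> Note (editorial, not part of the patch): the `syncify`/`asyncify` helpers added to `functiontools.py` above bridge the sync and async worlds. A small standalone sketch, assuming no event loop is already running when the syncified callable is invoked:

```python
import asyncio

from streamable.util.functiontools import asyncify, syncify


async def double(n: int) -> int:
    return 2 * n


# syncify turns an async function into a plain callable by running the
# returned coroutine to completion on the mixin-provided event loop
sync_double = syncify(double)
assert sync_double(21) == 42

# asyncify wraps a plain callable so that calling it returns a coroutine
async_len = asyncify(len)
assert asyncio.run(async_len("abc")) == 3
```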
+ async def __anext__(self) -> T: + return self.__next__() + class DequeFutureResultCollection(FutureResultCollection[T]): def __init__(self) -> None: @@ -87,6 +90,9 @@ def __next__(self) -> T: cast(Awaitable[T], self._futures.popleft()) ) + async def __anext__(self) -> T: + return await cast(Awaitable[T], self._futures.popleft()) + class FDFOAsyncFutureResultCollection(CallbackFutureResultCollection[T]): """ @@ -106,3 +112,9 @@ def __next__(self) -> T: self._n_futures -= 1 self._waiter = self.event_loop.create_future() return result + + async def __anext__(self) -> T: + result = await self._waiter + self._n_futures -= 1 + self._waiter = self.event_loop.create_future() + return result diff --git a/streamable/util/iterabletools.py b/streamable/util/iterabletools.py new file mode 100644 index 00000000..6d46ef66 --- /dev/null +++ b/streamable/util/iterabletools.py @@ -0,0 +1,62 @@ +from typing import ( + AsyncIterable, + AsyncIterator, + Callable, + Iterable, + Iterator, + TypeVar, +) + +from streamable.util.asynctools import GetEventLoopMixin + +T = TypeVar("T") + + +class BiIterable(Iterable[T], AsyncIterable[T]): + pass + + +class BiIterator(Iterator[T], AsyncIterator[T]): + pass + + +class SyncToBiIterable(BiIterable[T]): + def __init__(self, iterable: Iterable[T]): + self.iterable = iterable + + def __iter__(self) -> Iterator[T]: + return self.iterable.__iter__() + + def __aiter__(self) -> AsyncIterator[T]: + return SyncToAsyncIterator(self.iterable) + + +sync_to_bi_iterable: Callable[[Iterable[T]], BiIterable[T]] = SyncToBiIterable + + +class SyncToAsyncIterator(AsyncIterator[T]): + def __init__(self, iterator: Iterable[T]): + self.iterator: Iterator[T] = iterator.__iter__() + + async def __anext__(self) -> T: + try: + return self.iterator.__next__() + except StopIteration as e: + raise StopAsyncIteration() from e + + +sync_to_async_iter: Callable[[Iterable[T]], AsyncIterator[T]] = SyncToAsyncIterator + + +class AsyncToSyncIterator(Iterator[T], GetEventLoopMixin): + def __init__(self, iterator: AsyncIterable[T]): + self.iterator: AsyncIterator[T] = iterator.__aiter__() + + def __next__(self) -> T: + try: + return self.get_event_loop().run_until_complete(self.iterator.__anext__()) + except StopAsyncIteration as e: + raise StopIteration() from e + + +async_to_sync_iter: Callable[[AsyncIterable[T]], Iterator[T]] = AsyncToSyncIterator diff --git a/streamable/util/validationtools.py b/streamable/util/validationtools.py index 15fc399b..375f0e12 100644 --- a/streamable/util/validationtools.py +++ b/streamable/util/validationtools.py @@ -1,6 +1,14 @@ import datetime from contextlib import suppress -from typing import Any, Iterable, Iterator, Optional, Type, TypeVar, Union +from typing import ( + AsyncIterator, + Iterable, + Iterator, + Optional, + Type, + TypeVar, + Union, +) with suppress(ImportError): from typing import Literal @@ -13,6 +21,13 @@ def validate_iterator(iterator: Iterator): raise TypeError(f"`iterator` must be an Iterator but got a {type(iterator)}") +def validate_aiterator(iterator: AsyncIterator): + if not isinstance(iterator, AsyncIterator): + raise TypeError( + f"`iterator` must be an AsyncIterator but got a {type(iterator)}" + ) + + def validate_base(base: int): if base <= 0: raise ValueError(f"`base` must be > 0 but got {base}") @@ -69,7 +84,7 @@ def validate_optional_positive_count(count: Optional[int]): # def validate_not_none(value: Any, name: str) -> None: # if value is None: -# raise TypeError(f"`{name}` cannot be None") +# raise TypeError(f"`{name}` must not be 
None") def validate_errors( diff --git a/streamable/visitors/__init__.py b/streamable/visitors/__init__.py index 1234b6b4..ba67ca01 100644 --- a/streamable/visitors/__init__.py +++ b/streamable/visitors/__init__.py @@ -1 +1,3 @@ from streamable.visitors.base import Visitor + +__all__ = ["Visitor"] diff --git a/streamable/visitors/aiterator.py b/streamable/visitors/aiterator.py new file mode 100644 index 00000000..4cfe959e --- /dev/null +++ b/streamable/visitors/aiterator.py @@ -0,0 +1,227 @@ +from typing import AsyncIterable, AsyncIterator, Iterable, TypeVar, cast + +from streamable import afunctions +from streamable.stream import ( + ACatchStream, + ADistinctStream, + AFilterStream, + AFlattenStream, + AForeachStream, + AGroupbyStream, + AGroupStream, + AMapStream, + ASkipStream, + ATruncateStream, + CatchStream, + DistinctStream, + FilterStream, + FlattenStream, + ForeachStream, + GroupbyStream, + GroupStream, + MapStream, + ObserveStream, + SkipStream, + Stream, + ThrottleStream, + TruncateStream, +) +from streamable.util.functiontools import async_sidify, sidify +from streamable.util.iterabletools import sync_to_async_iter +from streamable.visitors import Visitor + +T = TypeVar("T") +U = TypeVar("U") + + +class AsyncIteratorVisitor(Visitor[AsyncIterator[T]]): + def visit_catch_stream(self, stream: CatchStream[T]) -> AsyncIterator[T]: + return afunctions.catch( + stream.upstream.accept(self), + stream._errors, + when=stream._when, + replacement=stream._replacement, + finally_raise=stream._finally_raise, + ) + + def visit_acatch_stream(self, stream: ACatchStream[T]) -> AsyncIterator[T]: + return afunctions.acatch( + stream.upstream.accept(self), + stream._errors, + when=stream._when, + replacement=stream._replacement, + finally_raise=stream._finally_raise, + ) + + def visit_distinct_stream(self, stream: DistinctStream[T]) -> AsyncIterator[T]: + return afunctions.distinct( + stream.upstream.accept(self), + stream._key, + consecutive_only=stream._consecutive_only, + ) + + def visit_adistinct_stream(self, stream: ADistinctStream[T]) -> AsyncIterator[T]: + return afunctions.adistinct( + stream.upstream.accept(self), + stream._key, + consecutive_only=stream._consecutive_only, + ) + + def visit_filter_stream(self, stream: FilterStream[T]) -> AsyncIterator[T]: + return afunctions.filter(stream.upstream.accept(self), stream._when) + + def visit_afilter_stream(self, stream: AFilterStream[T]) -> AsyncIterator[T]: + return afunctions.afilter(stream.upstream.accept(self), stream._when) + + def visit_flatten_stream(self, stream: FlattenStream[T]) -> AsyncIterator[T]: + return afunctions.flatten( + stream.upstream.accept(AsyncIteratorVisitor[Iterable]()), + concurrency=stream._concurrency, + ) + + def visit_aflatten_stream(self, stream: AFlattenStream[T]) -> AsyncIterator[T]: + return afunctions.aflatten( + stream.upstream.accept(AsyncIteratorVisitor[AsyncIterable]()), + concurrency=stream._concurrency, + ) + + def visit_foreach_stream(self, stream: ForeachStream[T]) -> AsyncIterator[T]: + return self.visit_map_stream( + MapStream( + stream.upstream, + sidify(stream._effect), + stream._concurrency, + stream._ordered, + stream._via, + ) + ) + + def visit_aforeach_stream(self, stream: AForeachStream[T]) -> AsyncIterator[T]: + return self.visit_amap_stream( + AMapStream( + stream.upstream, + async_sidify(stream._effect), + stream._concurrency, + stream._ordered, + ) + ) + + def visit_group_stream(self, stream: GroupStream[U]) -> AsyncIterator[T]: + return cast( + AsyncIterator[T], + afunctions.group( + 
stream.upstream.accept(AsyncIteratorVisitor[U]()), + stream._size, + interval=stream._interval, + by=stream._by, + ), + ) + + def visit_agroup_stream(self, stream: AGroupStream[U]) -> AsyncIterator[T]: + return cast( + AsyncIterator[T], + afunctions.agroup( + stream.upstream.accept(AsyncIteratorVisitor[U]()), + stream._size, + interval=stream._interval, + by=stream._by, + ), + ) + + def visit_groupby_stream(self, stream: GroupbyStream[U, T]) -> AsyncIterator[T]: + return cast( + AsyncIterator[T], + afunctions.groupby( + stream.upstream.accept(AsyncIteratorVisitor[U]()), + stream._key, + size=stream._size, + interval=stream._interval, + ), + ) + + def visit_agroupby_stream(self, stream: AGroupbyStream[U, T]) -> AsyncIterator[T]: + return cast( + AsyncIterator[T], + afunctions.agroupby( + stream.upstream.accept(AsyncIteratorVisitor[U]()), + stream._key, + size=stream._size, + interval=stream._interval, + ), + ) + + def visit_map_stream(self, stream: MapStream[U, T]) -> AsyncIterator[T]: + return afunctions.map( + stream._transformation, + stream.upstream.accept(AsyncIteratorVisitor[U]()), + concurrency=stream._concurrency, + ordered=stream._ordered, + via=stream._via, + ) + + def visit_amap_stream(self, stream: AMapStream[U, T]) -> AsyncIterator[T]: + return afunctions.amap( + stream._transformation, + stream.upstream.accept(AsyncIteratorVisitor[U]()), + concurrency=stream._concurrency, + ordered=stream._ordered, + ) + + def visit_observe_stream(self, stream: ObserveStream[T]) -> AsyncIterator[T]: + return afunctions.observe( + stream.upstream.accept(self), + stream._what, + ) + + def visit_skip_stream(self, stream: SkipStream[T]) -> AsyncIterator[T]: + return afunctions.skip( + stream.upstream.accept(self), + stream._count, + until=stream._until, + ) + + def visit_askip_stream(self, stream: ASkipStream[T]) -> AsyncIterator[T]: + return afunctions.askip( + stream.upstream.accept(self), + stream._count, + until=stream._until, + ) + + def visit_throttle_stream(self, stream: ThrottleStream[T]) -> AsyncIterator[T]: + return afunctions.throttle( + stream.upstream.accept(self), + stream._count, + per=stream._per, + ) + + def visit_truncate_stream(self, stream: TruncateStream[T]) -> AsyncIterator[T]: + return afunctions.truncate( + stream.upstream.accept(self), + stream._count, + when=stream._when, + ) + + def visit_atruncate_stream(self, stream: ATruncateStream[T]) -> AsyncIterator[T]: + return afunctions.atruncate( + stream.upstream.accept(self), + stream._count, + when=stream._when, + ) + + def visit_stream(self, stream: Stream[T]) -> AsyncIterator[T]: + if isinstance(stream.source, Iterable): + return sync_to_async_iter(stream.source) + if isinstance(stream.source, AsyncIterable): + return stream.source.__aiter__() + if callable(stream.source): + iterable = stream.source() + if isinstance(iterable, Iterable): + return sync_to_async_iter(iterable) + if isinstance(iterable, AsyncIterable): + return iterable.__aiter__() + raise TypeError( + f"`source` must be an Iterable/AsyncIterable or a Callable[[], Iterable/AsyncIterable] but got a Callable[[], {type(iterable)}]" + ) + raise TypeError( + f"`source` must be an Iterable/AsyncIterable or a Callable[[], Iterable/AsyncIterable] but got a {type(stream.source)}" + ) diff --git a/streamable/visitors/base.py b/streamable/visitors/base.py index d8f5d409..3cf28bfd 100644 --- a/streamable/visitors/base.py +++ b/streamable/visitors/base.py @@ -15,15 +15,27 @@ def visit_stream(self, stream: stream.Stream) -> V: ... 
def visit_catch_stream(self, stream: stream.CatchStream) -> V: return self.visit_stream(stream) + def visit_acatch_stream(self, stream: stream.ACatchStream) -> V: + return self.visit_stream(stream) + def visit_distinct_stream(self, stream: stream.DistinctStream) -> V: return self.visit_stream(stream) + def visit_adistinct_stream(self, stream: stream.ADistinctStream) -> V: + return self.visit_stream(stream) + def visit_filter_stream(self, stream: stream.FilterStream) -> V: return self.visit_stream(stream) + def visit_afilter_stream(self, stream: stream.AFilterStream) -> V: + return self.visit_stream(stream) + def visit_flatten_stream(self, stream: stream.FlattenStream) -> V: return self.visit_stream(stream) + def visit_aflatten_stream(self, stream: stream.AFlattenStream) -> V: + return self.visit_stream(stream) + def visit_foreach_stream(self, stream: stream.ForeachStream) -> V: return self.visit_stream(stream) @@ -33,9 +45,15 @@ def visit_aforeach_stream(self, stream: stream.AForeachStream) -> V: def visit_group_stream(self, stream: stream.GroupStream) -> V: return self.visit_stream(stream) + def visit_agroup_stream(self, stream: stream.AGroupStream) -> V: + return self.visit_stream(stream) + def visit_groupby_stream(self, stream: stream.GroupbyStream) -> V: return self.visit_stream(stream) + def visit_agroupby_stream(self, stream: stream.AGroupbyStream) -> V: + return self.visit_stream(stream) + def visit_observe_stream(self, stream: stream.ObserveStream) -> V: return self.visit_stream(stream) @@ -48,8 +66,14 @@ def visit_amap_stream(self, stream: stream.AMapStream) -> V: def visit_skip_stream(self, stream: stream.SkipStream) -> V: return self.visit_stream(stream) + def visit_askip_stream(self, stream: stream.ASkipStream) -> V: + return self.visit_stream(stream) + def visit_throttle_stream(self, stream: stream.ThrottleStream) -> V: return self.visit_stream(stream) def visit_truncate_stream(self, stream: stream.TruncateStream) -> V: return self.visit_stream(stream) + + def visit_atruncate_stream(self, stream: stream.ATruncateStream) -> V: + return self.visit_stream(stream) diff --git a/streamable/visitors/equality.py b/streamable/visitors/equality.py index 3af9f2d0..d03bc09e 100644 --- a/streamable/visitors/equality.py +++ b/streamable/visitors/equality.py @@ -1,8 +1,16 @@ -from typing import Any, TypeVar +from typing import Any, Union from streamable.stream import ( + ACatchStream, + ADistinctStream, + AFilterStream, + AFlattenStream, AForeachStream, + AGroupbyStream, + AGroupStream, AMapStream, + ASkipStream, + ATruncateStream, CatchStream, DistinctStream, FilterStream, @@ -19,17 +27,17 @@ ) from streamable.visitors import Visitor -T = TypeVar("T") -U = TypeVar("U") - class EqualityVisitor(Visitor[bool]): def __init__(self, other: Any): self.other: Any = other - def visit_catch_stream(self, stream: CatchStream[T]) -> bool: + def type_eq(self, stream: Stream) -> bool: + return type(stream) == type(self.other) + + def catch_eq(self, stream: Union[CatchStream, ACatchStream]) -> bool: return ( - isinstance(self.other, CatchStream) + self.type_eq(stream) and stream.upstream.accept(EqualityVisitor(self.other.upstream)) and stream._errors == self.other._errors and stream._when == self.other._when @@ -37,114 +45,154 @@ def visit_catch_stream(self, stream: CatchStream[T]) -> bool: and stream._finally_raise == self.other._finally_raise ) - def visit_distinct_stream(self, stream: DistinctStream[T]) -> bool: + def visit_catch_stream(self, stream: CatchStream) -> bool: + return self.catch_eq(stream) 
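> Note (editorial, not part of the patch): the `type_eq` refactor above makes equality require the exact same node type rather than an `isinstance` match, so a sync operation and its async twin never compare equal even when all their parameters match. A sketch, assuming `Stream.__eq__` dispatches to `EqualityVisitor` as elsewhere in the codebase:

```python
from streamable import Stream

ints = Stream(range(10))

# same node type, same parameters, same upstream: equal
assert ints.distinct() == ints.distinct()

# sync node vs its async twin: identical attributes but different
# node types, so type_eq makes the comparison fail
assert ints.distinct() != ints.adistinct()
```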
+ + def visit_acatch_stream(self, stream: ACatchStream) -> bool: + return self.catch_eq(stream) + + def distinct_eq(self, stream: Union[DistinctStream, ADistinctStream]) -> bool: return ( - isinstance(self.other, DistinctStream) + self.type_eq(stream) and stream.upstream.accept(EqualityVisitor(self.other.upstream)) and stream._key == self.other._key and stream._consecutive_only == self.other._consecutive_only ) - def visit_filter_stream(self, stream: FilterStream[T]) -> bool: + def visit_distinct_stream(self, stream: DistinctStream) -> bool: + return self.distinct_eq(stream) + + def visit_adistinct_stream(self, stream: ADistinctStream) -> bool: + return self.distinct_eq(stream) + + def filter_eq(self, stream: Union[FilterStream, AFilterStream]) -> bool: return ( - isinstance(self.other, FilterStream) + self.type_eq(stream) and stream.upstream.accept(EqualityVisitor(self.other.upstream)) and stream._when == self.other._when ) - def visit_flatten_stream(self, stream: FlattenStream[T]) -> bool: - return ( - isinstance(self.other, FlattenStream) - and stream.upstream.accept(EqualityVisitor(self.other.upstream)) - and stream._concurrency == self.other._concurrency - ) + def visit_filter_stream(self, stream: FilterStream) -> bool: + return self.filter_eq(stream) - def visit_foreach_stream(self, stream: ForeachStream[T]) -> bool: + def visit_afilter_stream(self, stream: AFilterStream) -> bool: + return self.filter_eq(stream) + + def flatten_eq(self, stream: Union[FlattenStream, AFlattenStream]) -> bool: return ( - isinstance(self.other, ForeachStream) + self.type_eq(stream) and stream.upstream.accept(EqualityVisitor(self.other.upstream)) and stream._concurrency == self.other._concurrency - and stream._effect == self.other._effect - and stream._ordered == self.other._ordered - and stream._via == self.other._via ) - def visit_aforeach_stream(self, stream: AForeachStream[T]) -> bool: + def visit_flatten_stream(self, stream: FlattenStream) -> bool: + return self.flatten_eq(stream) + + def visit_aflatten_stream(self, stream: AFlattenStream) -> bool: + return self.flatten_eq(stream) + + def foreach_eq(self, stream: Union[ForeachStream, AForeachStream]) -> bool: return ( - isinstance(self.other, AForeachStream) + self.type_eq(stream) and stream.upstream.accept(EqualityVisitor(self.other.upstream)) and stream._concurrency == self.other._concurrency and stream._effect == self.other._effect and stream._ordered == self.other._ordered ) - def visit_group_stream(self, stream: GroupStream[U]) -> bool: + def visit_foreach_stream(self, stream: ForeachStream) -> bool: + return self.foreach_eq(stream) and stream._via == self.other._via + + def visit_aforeach_stream(self, stream: AForeachStream) -> bool: + return self.foreach_eq(stream) + + def group_eq(self, stream: Union[GroupStream, AGroupStream]) -> bool: return ( - isinstance(self.other, GroupStream) + self.type_eq(stream) and stream.upstream.accept(EqualityVisitor(self.other.upstream)) and stream._by == self.other._by and stream._size == self.other._size and stream._interval == self.other._interval ) - def visit_groupby_stream(self, stream: GroupbyStream[U, T]) -> bool: + def visit_group_stream(self, stream: GroupStream) -> bool: + return self.group_eq(stream) + + def visit_agroup_stream(self, stream: AGroupStream) -> bool: + return self.group_eq(stream) + + def groupby_eq(self, stream: Union[GroupbyStream, AGroupbyStream]) -> bool: return ( - isinstance(self.other, GroupbyStream) + self.type_eq(stream) and 
stream.upstream.accept(EqualityVisitor(self.other.upstream)) and stream._key == self.other._key and stream._size == self.other._size and stream._interval == self.other._interval ) - def visit_map_stream(self, stream: MapStream[U, T]) -> bool: - return ( - isinstance(self.other, MapStream) - and stream.upstream.accept(EqualityVisitor(self.other.upstream)) - and stream._concurrency == self.other._concurrency - and stream._transformation == self.other._transformation - and stream._ordered == self.other._ordered - and stream._via == self.other._via - ) + def visit_groupby_stream(self, stream: GroupbyStream) -> bool: + return self.groupby_eq(stream) + + def visit_agroupby_stream(self, stream: AGroupbyStream) -> bool: + return self.groupby_eq(stream) - def visit_amap_stream(self, stream: AMapStream[U, T]) -> bool: + def map_eq(self, stream: Union[MapStream, AMapStream]) -> bool: return ( - isinstance(self.other, AMapStream) + self.type_eq(stream) and stream.upstream.accept(EqualityVisitor(self.other.upstream)) and stream._concurrency == self.other._concurrency and stream._transformation == self.other._transformation and stream._ordered == self.other._ordered ) - def visit_observe_stream(self, stream: ObserveStream[T]) -> bool: + def visit_map_stream(self, stream: MapStream) -> bool: + return self.map_eq(stream) and stream._via == self.other._via + + def visit_amap_stream(self, stream: AMapStream) -> bool: + return self.map_eq(stream) + + def visit_observe_stream(self, stream: ObserveStream) -> bool: return ( - isinstance(self.other, ObserveStream) + self.type_eq(stream) and stream.upstream.accept(EqualityVisitor(self.other.upstream)) and stream._what == self.other._what ) - def visit_skip_stream(self, stream: SkipStream[T]) -> bool: + def skip_eq(self, stream: Union[SkipStream, ASkipStream]) -> bool: return ( - isinstance(self.other, SkipStream) + self.type_eq(stream) and stream.upstream.accept(EqualityVisitor(self.other.upstream)) and stream._count == self.other._count and stream._until == self.other._until ) - def visit_throttle_stream(self, stream: ThrottleStream[T]) -> bool: + def visit_skip_stream(self, stream: SkipStream) -> bool: + return self.skip_eq(stream) + + def visit_askip_stream(self, stream: ASkipStream) -> bool: + return self.skip_eq(stream) + + def visit_throttle_stream(self, stream: ThrottleStream) -> bool: return ( - isinstance(self.other, ThrottleStream) + self.type_eq(stream) and stream.upstream.accept(EqualityVisitor(self.other.upstream)) and stream._count == self.other._count and stream._per == self.other._per ) - def visit_truncate_stream(self, stream: TruncateStream[T]) -> bool: + def truncate_eq(self, stream: Union[TruncateStream, ATruncateStream]) -> bool: return ( - isinstance(self.other, TruncateStream) + self.type_eq(stream) and stream.upstream.accept(EqualityVisitor(self.other.upstream)) and stream._count == self.other._count and stream._when == self.other._when ) - def visit_stream(self, stream: Stream[T]) -> bool: - return isinstance(self.other, Stream) and stream.source == self.other.source + def visit_truncate_stream(self, stream: TruncateStream) -> bool: + return self.truncate_eq(stream) + + def visit_atruncate_stream(self, stream: ATruncateStream) -> bool: + return self.truncate_eq(stream) + + def visit_stream(self, stream: Stream) -> bool: + return self.type_eq(stream) and stream.source == self.other.source diff --git a/streamable/visitors/iterator.py b/streamable/visitors/iterator.py index 417aa27b..bb63bdb0 100644 --- a/streamable/visitors/iterator.py +++ 
b/streamable/visitors/iterator.py @@ -1,9 +1,17 @@ -from typing import Iterable, Iterator, TypeVar, cast +from typing import AsyncIterable, Iterable, Iterator, TypeVar, cast from streamable import functions from streamable.stream import ( + ACatchStream, + ADistinctStream, + AFilterStream, + AFlattenStream, AForeachStream, + AGroupbyStream, + AGroupStream, AMapStream, + ASkipStream, + ATruncateStream, CatchStream, DistinctStream, FilterStream, @@ -18,7 +26,8 @@ ThrottleStream, TruncateStream, ) -from streamable.util.functiontools import async_sidify, sidify, wrap_error +from streamable.util.functiontools import async_sidify, sidify, syncify, wrap_error +from streamable.util.iterabletools import async_to_sync_iter from streamable.visitors import Visitor T = TypeVar("T") @@ -35,6 +44,15 @@ def visit_catch_stream(self, stream: CatchStream[T]) -> Iterator[T]: finally_raise=stream._finally_raise, ) + def visit_acatch_stream(self, stream: ACatchStream[T]) -> Iterator[T]: + return functions.acatch( + stream.upstream.accept(self), + stream._errors, + when=stream._when, + replacement=stream._replacement, + finally_raise=stream._finally_raise, + ) + def visit_distinct_stream(self, stream: DistinctStream[T]) -> Iterator[T]: return functions.distinct( stream.upstream.accept(self), @@ -42,18 +60,37 @@ def visit_distinct_stream(self, stream: DistinctStream[T]) -> Iterator[T]: consecutive_only=stream._consecutive_only, ) + def visit_adistinct_stream(self, stream: ADistinctStream[T]) -> Iterator[T]: + return functions.adistinct( + stream.upstream.accept(self), + stream._key, + consecutive_only=stream._consecutive_only, + ) + def visit_filter_stream(self, stream: FilterStream[T]) -> Iterator[T]: return filter( wrap_error(stream._when, StopIteration), cast(Iterable[T], stream.upstream.accept(self)), ) + def visit_afilter_stream(self, stream: AFilterStream[T]) -> Iterator[T]: + return filter( + wrap_error(syncify(stream._when), StopIteration), + cast(Iterable[T], stream.upstream.accept(self)), + ) + def visit_flatten_stream(self, stream: FlattenStream[T]) -> Iterator[T]: return functions.flatten( stream.upstream.accept(IteratorVisitor[Iterable]()), concurrency=stream._concurrency, ) + def visit_aflatten_stream(self, stream: AFlattenStream[T]) -> Iterator[T]: + return functions.aflatten( + stream.upstream.accept(IteratorVisitor[AsyncIterable]()), + concurrency=stream._concurrency, + ) + def visit_foreach_stream(self, stream: ForeachStream[T]) -> Iterator[T]: return self.visit_map_stream( MapStream( @@ -86,6 +123,17 @@ def visit_group_stream(self, stream: GroupStream[U]) -> Iterator[T]: ), ) + def visit_agroup_stream(self, stream: AGroupStream[U]) -> Iterator[T]: + return cast( + Iterator[T], + functions.agroup( + stream.upstream.accept(IteratorVisitor[U]()), + stream._size, + interval=stream._interval, + by=stream._by, + ), + ) + def visit_groupby_stream(self, stream: GroupbyStream[U, T]) -> Iterator[T]: return cast( Iterator[T], @@ -97,6 +145,17 @@ def visit_groupby_stream(self, stream: GroupbyStream[U, T]) -> Iterator[T]: ), ) + def visit_agroupby_stream(self, stream: AGroupbyStream[U, T]) -> Iterator[T]: + return cast( + Iterator[T], + functions.agroupby( + stream.upstream.accept(IteratorVisitor[U]()), + stream._key, + size=stream._size, + interval=stream._interval, + ), + ) + def visit_map_stream(self, stream: MapStream[U, T]) -> Iterator[T]: return functions.map( stream._transformation, @@ -127,6 +186,13 @@ def visit_skip_stream(self, stream: SkipStream[T]) -> Iterator[T]: until=stream._until, ) + def 
visit_askip_stream(self, stream: ASkipStream[T]) -> Iterator[T]: + return functions.askip( + stream.upstream.accept(self), + stream._count, + until=stream._until, + ) + def visit_throttle_stream(self, stream: ThrottleStream[T]) -> Iterator[T]: return functions.throttle( stream.upstream.accept(self), @@ -141,17 +207,27 @@ def visit_truncate_stream(self, stream: TruncateStream[T]) -> Iterator[T]: when=stream._when, ) + def visit_atruncate_stream(self, stream: ATruncateStream[T]) -> Iterator[T]: + return functions.atruncate( + stream.upstream.accept(self), + stream._count, + when=stream._when, + ) + def visit_stream(self, stream: Stream[T]) -> Iterator[T]: if isinstance(stream.source, Iterable): - iterable = stream.source - elif callable(stream.source): + return stream.source.__iter__() + if isinstance(stream.source, AsyncIterable): + return async_to_sync_iter(stream.source) + if callable(stream.source): iterable = stream.source() - if not isinstance(iterable, Iterable): - raise TypeError( - f"`source` must be an Iterable or a Callable[[], Iterable] but got a Callable[[], {type(iterable)}]" - ) - else: + if isinstance(iterable, Iterable): + return iterable.__iter__() + if isinstance(iterable, AsyncIterable): + return async_to_sync_iter(iterable) raise TypeError( - f"`source` must be an Iterable or a Callable[[], Iterable] but got a {type(stream.source)}" + f"`source` must be an Iterable/AsyncIterable or a Callable[[], Iterable/AsyncIterable] but got a Callable[[], {type(iterable)}]" ) - return iter(iterable) + raise TypeError( + f"`source` must be an Iterable/AsyncIterable or a Callable[[], Iterable/AsyncIterable] but got a {type(stream.source)}" + ) diff --git a/streamable/visitors/representation.py b/streamable/visitors/representation.py index 434b3ffc..0e36f20b 100644 --- a/streamable/visitors/representation.py +++ b/streamable/visitors/representation.py @@ -1,9 +1,17 @@ from abc import ABC, abstractmethod -from typing import Any, Iterable, List, Type, TypeVar +from typing import Any, Iterable, List from streamable.stream import ( + ACatchStream, + ADistinctStream, + AFilterStream, + AFlattenStream, AForeachStream, + AGroupbyStream, + AGroupStream, AMapStream, + ASkipStream, + ATruncateStream, CatchStream, DistinctStream, FilterStream, @@ -22,19 +30,17 @@ from streamable.util.functiontools import _Star from streamable.visitors import Visitor -T = TypeVar("T") -U = TypeVar("U") - class ToStringVisitor(Visitor[str], ABC): - def __init__(self) -> None: + def __init__(self, one_liner_max_depth: int = 3) -> None: self.methods_reprs: List[str] = [] + self.one_liner_max_depth = one_liner_max_depth @staticmethod @abstractmethod def to_string(o: object) -> str: ... 
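For context, a small sketch of the `repr` behavior that the new `one_liner_max_depth` parameter (default 3, i.e. a source plus two operations) drives; it mirrors the expectations updated in `tests/test_stream.py` further down, assuming the default depth applies:

```python
from streamable import Stream

# Depth <= 3: the chain stays on one line.
print(repr(Stream(range(10)).skip(2)))
# Stream(range(0, 10)).skip(2, until=None)

# Depth > 3: the chain goes multiline.
print(repr(Stream(range(10)).skip(2).skip(2).skip(2)))
# (
#     Stream(range(0, 10))
#     .skip(2, until=None)
#     .skip(2, until=None)
#     .skip(2, until=None)
# )
```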
- def visit_catch_stream(self, stream: CatchStream[T]) -> str: + def visit_catch_stream(self, stream: CatchStream) -> str: replacement = "" if stream._replacement is not NO_REPLACEMENT: replacement = f", replacement={self.to_string(stream._replacement)}" @@ -47,87 +53,146 @@ def visit_catch_stream(self, stream: CatchStream[T]) -> str: ) return stream.upstream.accept(self) - def visit_distinct_stream(self, stream: DistinctStream[T]) -> str: + def visit_acatch_stream(self, stream: ACatchStream) -> str: + replacement = "" + if stream._replacement is not NO_REPLACEMENT: + replacement = f", replacement={self.to_string(stream._replacement)}" + if isinstance(stream._errors, Iterable): + errors = f"({', '.join(map(self.to_string, stream._errors))})" + else: + errors = self.to_string(stream._errors) + self.methods_reprs.append( + f"acatch({errors}, when={self.to_string(stream._when)}{replacement}, finally_raise={self.to_string(stream._finally_raise)})" + ) + return stream.upstream.accept(self) + + def visit_distinct_stream(self, stream: DistinctStream) -> str: self.methods_reprs.append( f"distinct({self.to_string(stream._key)}, consecutive_only={self.to_string(stream._consecutive_only)})" ) return stream.upstream.accept(self) - def visit_filter_stream(self, stream: FilterStream[T]) -> str: + def visit_adistinct_stream(self, stream: ADistinctStream) -> str: + self.methods_reprs.append( + f"adistinct({self.to_string(stream._key)}, consecutive_only={self.to_string(stream._consecutive_only)})" + ) + return stream.upstream.accept(self) + + def visit_filter_stream(self, stream: FilterStream) -> str: self.methods_reprs.append(f"filter({self.to_string(stream._when)})") return stream.upstream.accept(self) - def visit_flatten_stream(self, stream: FlattenStream[T]) -> str: + def visit_afilter_stream(self, stream: AFilterStream) -> str: + self.methods_reprs.append(f"afilter({self.to_string(stream._when)})") + return stream.upstream.accept(self) + + def visit_flatten_stream(self, stream: FlattenStream) -> str: self.methods_reprs.append( f"flatten(concurrency={self.to_string(stream._concurrency)})" ) return stream.upstream.accept(self) - def visit_foreach_stream(self, stream: ForeachStream[T]) -> str: + def visit_aflatten_stream(self, stream: AFlattenStream) -> str: + self.methods_reprs.append( + f"aflatten(concurrency={self.to_string(stream._concurrency)})" + ) + return stream.upstream.accept(self) + + def visit_foreach_stream(self, stream: ForeachStream) -> str: via = f", via={self.to_string(stream._via)}" if stream._concurrency > 1 else "" self.methods_reprs.append( f"foreach({self.to_string(stream._effect)}, concurrency={self.to_string(stream._concurrency)}, ordered={self.to_string(stream._ordered)}{via})" ) return stream.upstream.accept(self) - def visit_aforeach_stream(self, stream: AForeachStream[T]) -> str: + def visit_aforeach_stream(self, stream: AForeachStream) -> str: self.methods_reprs.append( f"aforeach({self.to_string(stream._effect)}, concurrency={self.to_string(stream._concurrency)}, ordered={self.to_string(stream._ordered)})" ) return stream.upstream.accept(self) - def visit_group_stream(self, stream: GroupStream[U]) -> str: + def visit_group_stream(self, stream: GroupStream) -> str: self.methods_reprs.append( f"group(size={self.to_string(stream._size)}, by={self.to_string(stream._by)}, interval={self.to_string(stream._interval)})" ) return stream.upstream.accept(self) - def visit_groupby_stream(self, stream: GroupbyStream[U, T]) -> str: + def visit_agroup_stream(self, stream: AGroupStream) -> 
str: + self.methods_reprs.append( + f"agroup(size={self.to_string(stream._size)}, by={self.to_string(stream._by)}, interval={self.to_string(stream._interval)})" + ) + return stream.upstream.accept(self) + + def visit_groupby_stream(self, stream: GroupbyStream) -> str: self.methods_reprs.append( f"groupby({self.to_string(stream._key)}, size={self.to_string(stream._size)}, interval={self.to_string(stream._interval)})" ) return stream.upstream.accept(self) - def visit_map_stream(self, stream: MapStream[U, T]) -> str: + def visit_agroupby_stream(self, stream: AGroupbyStream) -> str: + self.methods_reprs.append( + f"agroupby({self.to_string(stream._key)}, size={self.to_string(stream._size)}, interval={self.to_string(stream._interval)})" + ) + return stream.upstream.accept(self) + + def visit_map_stream(self, stream: MapStream) -> str: via = f", via={self.to_string(stream._via)}" if stream._concurrency > 1 else "" self.methods_reprs.append( f"map({self.to_string(stream._transformation)}, concurrency={self.to_string(stream._concurrency)}, ordered={self.to_string(stream._ordered)}{via})" ) return stream.upstream.accept(self) - def visit_amap_stream(self, stream: AMapStream[U, T]) -> str: + def visit_amap_stream(self, stream: AMapStream) -> str: self.methods_reprs.append( f"amap({self.to_string(stream._transformation)}, concurrency={self.to_string(stream._concurrency)}, ordered={self.to_string(stream._ordered)})" ) return stream.upstream.accept(self) - def visit_observe_stream(self, stream: ObserveStream[T]) -> str: + def visit_observe_stream(self, stream: ObserveStream) -> str: self.methods_reprs.append(f"""observe({self.to_string(stream._what)})""") return stream.upstream.accept(self) - def visit_skip_stream(self, stream: SkipStream[T]) -> str: + def visit_skip_stream(self, stream: SkipStream) -> str: self.methods_reprs.append( f"skip({self.to_string(stream._count)}, until={self.to_string(stream._until)})" ) return stream.upstream.accept(self) - def visit_throttle_stream(self, stream: ThrottleStream[T]) -> str: + def visit_askip_stream(self, stream: ASkipStream) -> str: + self.methods_reprs.append( + f"askip({self.to_string(stream._count)}, until={self.to_string(stream._until)})" + ) + return stream.upstream.accept(self) + + def visit_throttle_stream(self, stream: ThrottleStream) -> str: self.methods_reprs.append( f"throttle({self.to_string(stream._count)}, per={self.to_string(stream._per)})" ) return stream.upstream.accept(self) - def visit_truncate_stream(self, stream: TruncateStream[T]) -> str: + def visit_truncate_stream(self, stream: TruncateStream) -> str: self.methods_reprs.append( f"truncate(count={self.to_string(stream._count)}, when={self.to_string(stream._when)})" ) return stream.upstream.accept(self) - def visit_stream(self, stream: Stream[T]) -> str: + def visit_atruncate_stream(self, stream: ATruncateStream) -> str: + self.methods_reprs.append( + f"atruncate(count={self.to_string(stream._count)}, when={self.to_string(stream._when)})" + ) + return stream.upstream.accept(self) + + def visit_stream(self, stream: Stream) -> str: + source_stream = f"Stream({self.to_string(stream.source)})" + depth = len(self.methods_reprs) + 1 + if depth == 1: + return source_stream + if depth <= self.one_liner_max_depth: + return f"{source_stream}.{'.'.join(reversed(self.methods_reprs))}" methods_block = "".join( map(lambda r: f" .{r}\n", reversed(self.methods_reprs)) ) - return f"(\n Stream({self.to_string(stream.source)})\n{methods_block})" + return f"(\n {source_stream}\n{methods_block})" class 
ReprVisitor(ToStringVisitor):
diff --git a/tests/test_iterators.py b/tests/test_iterators.py
index b9a74de5..2bb30d37 100644
--- a/tests/test_iterators.py
+++ b/tests/test_iterators.py
@@ -1,6 +1,15 @@
+import asyncio
 import unittest
+from typing import AsyncIterator
 
-from streamable.iterators import ObserveIterator, _OSConcurrentMapIterable
+from streamable.aiterators import (
+    _ConcurrentAMapAsyncIterable,
+    _RaisingAsyncIterator,
+)
+from streamable.iterators import ObserveIterator, _ConcurrentMapIterable
+from streamable.util.asynctools import awaitable_to_coroutine
+from streamable.util.iterabletools import sync_to_async_iter
+from tests.utils import async_identity, identity, src
 
 
 class TestIterators(unittest.TestCase):
@@ -8,9 +17,9 @@ def test_validation(self):
         with self.assertRaisesRegex(
             ValueError,
             "`buffersize` must be >= 1 but got 0",
-            msg="`_OSConcurrentMapIterable` constructor should raise for non-positive buffersize",
+            msg="`_ConcurrentMapIterable` constructor should raise for non-positive buffersize",
         ):
-            _OSConcurrentMapIterable(
+            _ConcurrentMapIterable(
                 iterator=iter([]),
                 transformation=str,
                 concurrency=1,
@@ -29,3 +38,28 @@ def test_validation(self):
                 what="",
                 base=0,
             )
+
+    def test_ConcurrentAMapAsyncIterable(self) -> None:
+        with self.assertRaisesRegex(
+            TypeError,
+            r"must be an async function i\.e\. a function returning a Coroutine but it returned a <class 'int'>",
+            msg="`amap` should raise a TypeError if a non-async function is passed to it.",
+        ):
+            concurrent_amap_async_iterable: _ConcurrentAMapAsyncIterable[int, int] = (
+                _ConcurrentAMapAsyncIterable(
+                    sync_to_async_iter(src),
+                    async_identity,
+                    buffersize=2,
+                    ordered=True,
+                )
+            )
+
+            # remove error wrapping
+            concurrent_amap_async_iterable.transformation = identity  # type: ignore
+
+            aiterator: AsyncIterator[int] = _RaisingAsyncIterator(
+                concurrent_amap_async_iterable.__aiter__()
+            )
+            print(
+                asyncio.run(awaitable_to_coroutine(aiterator.__aiter__().__anext__()))
+            )
diff --git a/tests/test_readme.py b/tests/test_readme.py
index a0f1a89e..cf748176 100644
--- a/tests/test_readme.py
+++ b/tests/test_readme.py
@@ -1,7 +1,8 @@
+import asyncio
 import time
 import unittest
 from datetime import timedelta
-from typing import Iterator, List, Tuple
+from typing import AsyncIterable, Iterator, List, Tuple, TypeVar
 
 from streamable.stream import Stream
 
@@ -15,10 +16,12 @@
 three_integers_per_second: Stream[int] = integers.throttle(5, per=timedelta(seconds=1))
 
+T = TypeVar("T")
+
 # fmt: off
 class TestReadme(unittest.TestCase):
-    def test_collect_it(self) -> None:
+    def test_iterate(self) -> None:
         self.assertListEqual(
             list(inverses),
             [1.0, 0.5, 0.33, 0.25, 0.2, 0.17, 0.14, 0.12, 0.11],
@@ -34,6 +37,11 @@ def test_collect_it(self) -> None:
         self.assertEqual(next(inverses_iter), 1.0)
         self.assertEqual(next(inverses_iter), 0.5)
 
+        async def main() -> List[float]:
+            return [inverse async for inverse in inverses]
+
+        assert asyncio.run(main()) == [1.0, 0.5, 0.33, 0.25, 0.2, 0.17, 0.14, 0.12, 0.11]
+
     def test_map_example(self) -> None:
         integer_strings: Stream[str] = integers.map(str)
 
@@ -59,7 +67,7 @@ def test_process_concurrent_map_example(self) -> None:
         # but the `state` of the main process is not mutated
         assert state == []
 
-    def test_async_concurrent_map_example(self) -> None:
+    def test_amap_example(self) -> None:
         import asyncio
 
         import httpx
 
@@ -75,7 +83,25 @@ def test_async_concurrent_map_example(self) -> None:
         )
         assert list(pokemon_names) == ['bulbasaur', 'ivysaur', 'venusaur']
 
-        asyncio.get_event_loop().run_until_complete(http_async_client.aclose())
+        asyncio.run(http_async_client.aclose())
+
+    def test_async_amap_example(self) -> None:
+        import asyncio
+
+        import httpx
+
+        async def main() -> None:
+            async with httpx.AsyncClient() as http_async_client:
+                pokemon_names: Stream[str] = (
+                    Stream(range(1, 4))
+                    .map(lambda i: f"https://pokeapi.co/api/v2/pokemon-species/{i}")
+                    .amap(http_async_client.get, concurrency=3)
+                    .map(httpx.Response.json)
+                    .map(lambda poke: poke["name"])
+                )
+                assert [name async for name in pokemon_names] == ['bulbasaur', 'ivysaur', 'venusaur']
+
+        asyncio.run(main())
 
     def test_starmap_example(self) -> None:
         from streamable import star
@@ -267,12 +293,23 @@ def test_zip_example(self) -> None:
 
     def test_count_example(self) -> None:
         assert integers.count() == 10
 
+    def test_acount_example(self) -> None:
+        assert asyncio.run(integers.acount()) == 10
+
     def test_call_example(self) -> None:
         state: List[int] = []
         appending_integers: Stream[int] = integers.foreach(state.append)
         assert appending_integers() is appending_integers
         assert state == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
 
+    def test_await_example(self) -> None:
+        async def test() -> None:
+            state: List[int] = []
+            appending_integers: Stream[int] = integers.foreach(state.append)
+            assert appending_integers is await appending_integers
+            assert state == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
+        asyncio.run(test())
+
     def test_non_stopping_exceptions_example(self) -> None:
         from contextlib import suppress
@@ -286,6 +323,59 @@ def test_non_stopping_exceptions_example(self) -> None:
         collected_casted_ints.extend(casted_ints)
         assert collected_casted_ints == [0, 1, 2, 3, 5, 6, 7, 8, 9]
 
+    def test_async_etl_example(self) -> None:  # pragma: no cover
+        # for mypy typing check only
+        if not self:
+            import asyncio
+            import csv
+            import itertools
+            from datetime import timedelta
+
+            import httpx
+
+            from streamable import Stream
+
+            async def main() -> None:
+                with open("./quadruped_pokemons.csv", mode="w") as file:
+                    fields = ["id", "name", "is_legendary", "base_happiness", "capture_rate"]
+                    writer = csv.DictWriter(file, fields, extrasaction='ignore')
+                    writer.writeheader()
+
+                    async with httpx.AsyncClient() as http_async_client:
+                        pipeline: Stream = (
+                            # Infinite Stream[int] of Pokemon ids starting from Pokémon #1: Bulbasaur
+                            Stream(itertools.count(1))
+                            # Limits to 16 requests per second to be friendly to our fellow PokéAPI devs
+                            .throttle(16, per=timedelta(seconds=1))
+                            # GETs pokemons via 8 concurrent coroutines
+                            .map(lambda poke_id: f"https://pokeapi.co/api/v2/pokemon-species/{poke_id}")
+                            .amap(http_async_client.get, concurrency=8)
+                            .foreach(httpx.Response.raise_for_status)
+                            .map(httpx.Response.json)
+                            # Stops the iteration when reaching the 1st pokemon of the 4th generation
+                            .truncate(when=lambda poke: poke["generation"]["name"] == "generation-iv")
+                            .observe("pokemons")
+                            # Keeps only quadruped Pokemons
+                            .filter(lambda poke: poke["shape"]["name"] == "quadruped")
+                            .observe("quadruped pokemons")
+                            # Catches errors due to None "generation" or "shape"
+                            .catch(
+                                TypeError,
+                                when=lambda error: str(error) == "'NoneType' object is not subscriptable"
+                            )
+                            # Writes a batch of pokemons every 5 seconds to the CSV file
+                            .group(interval=timedelta(seconds=5))
+                            .foreach(writer.writerows)
+                            .flatten()
+                            .observe("written pokemons")
+                            # Catches exceptions and raises the 1st one at the end of the iteration
+                            .catch(Exception, finally_raise=True)
+                        )
+
+                        await pipeline
+
+            asyncio.run(main())
+
    def test_etl_example(self) -> 
None: # pragma: no cover # for mypy typing check only if not self: diff --git a/tests/test_stream.py b/tests/test_stream.py index 8bda6a9e..cad827b8 100644 --- a/tests/test_stream.py +++ b/tests/test_stream.py @@ -3,22 +3,21 @@ import datetime import logging import math -from operator import itemgetter -from pickle import PickleError import queue -import random import sys import threading import time -import timeit import traceback -from types import FrameType import unittest from collections import Counter +from functools import partial +from pickle import PickleError +from types import TracebackType from typing import ( Any, + AsyncIterable, + AsyncIterator, Callable, - Coroutine, Dict, Iterable, Iterator, @@ -27,139 +26,52 @@ Set, Tuple, Type, - TypeVar, + Union, cast, ) from parameterized import parameterized # type: ignore from streamable import Stream -from streamable.util.functiontools import WrappedError, star - -T = TypeVar("T") -R = TypeVar("R") - - -def timestream(stream: Stream[T], times: int = 1) -> Tuple[float, List[T]]: - res: List[T] = [] - - def iterate(): - nonlocal res - res = list(stream) - - return timeit.timeit(iterate, number=times) / times, res - - -def identity_sleep(seconds: float) -> float: - time.sleep(seconds) - return seconds - - -async def async_identity_sleep(seconds: float) -> float: - await asyncio.sleep(seconds) - return seconds - - -# simulates an I/0 bound function -slow_identity_duration = 0.05 - - -def slow_identity(x: T) -> T: - time.sleep(slow_identity_duration) - return x - - -async def async_slow_identity(x: T) -> T: - await asyncio.sleep(slow_identity_duration) - return x - - -def identity(x: T) -> T: - return x - - -# fmt: off -async def async_identity(x: T) -> T: return x -# fmt: on - - -def square(x): - return x**2 - - -async def async_square(x): - return x**2 - - -def throw(exc: Type[Exception]): - raise exc() - - -def throw_func(exc: Type[Exception]) -> Callable[[T], T]: - return lambda _: throw(exc) - - -def async_throw_func(exc: Type[Exception]) -> Callable[[T], Coroutine[Any, Any, T]]: - async def f(_: T) -> T: - raise exc - - return f - - -def throw_for_odd_func(exc): - return lambda i: throw(exc) if i % 2 == 1 else i - - -def async_throw_for_odd_func(exc): - async def f(i): - return throw(exc) if i % 2 == 1 else i - - return f - - -class TestError(Exception): - pass - - -DELTA_RATE = 0.4 -# size of the test collections -N = 256 - -src = range(N) - -even_src = range(0, N, 2) - - -def randomly_slowed( - func: Callable[[T], R], min_sleep: float = 0.001, max_sleep: float = 0.05 -) -> Callable[[T], R]: - def wrap(x: T) -> R: - time.sleep(min_sleep + random.random() * (max_sleep - min_sleep)) - return func(x) - - return wrap - - -def async_randomly_slowed( - async_func: Callable[[T], Coroutine[Any, Any, R]], - min_sleep: float = 0.001, - max_sleep: float = 0.05, -) -> Callable[[T], Coroutine[Any, Any, R]]: - async def wrap(x: T) -> R: - await asyncio.sleep(min_sleep + random.random() * (max_sleep - min_sleep)) - return await async_func(x) - - return wrap - - -def range_raising_at_exhaustion( - start: int, end: int, step: int, exception: Exception -) -> Iterator[int]: - yield from range(start, end, step) - raise exception - - -src_raising_at_exhaustion = lambda: range_raising_at_exhaustion(0, N, 1, TestError()) +from streamable.util.asynctools import awaitable_to_coroutine +from streamable.util.functiontools import WrappedError, asyncify, star +from streamable.util.iterabletools import ( + sync_to_async_iter, + sync_to_bi_iterable, +) +from 
tests.utils import (
+    DELTA_RATE,
+    ITERABLE_TYPES,
+    IterableType,
+    N,
+    TestError,
+    alist_or_list,
+    anext_or_next,
+    async_identity,
+    async_identity_sleep,
+    async_randomly_slowed,
+    async_slow_identity,
+    async_square,
+    async_throw_for_odd_func,
+    async_throw_func,
+    bi_iterable_to_iter,
+    even_src,
+    identity,
+    identity_sleep,
+    stopiteration_for_iter_type,
+    randomly_slowed,
+    slow_identity,
+    slow_identity_duration,
+    square,
+    src,
+    src_raising_at_exhaustion,
+    throw,
+    throw_for_odd_func,
+    throw_func,
+    timestream,
+    to_list,
+    to_set,
+)
 
 
 class TestStream(unittest.TestCase):
@@ -204,6 +116,19 @@ def test_init(self) -> None:
         ):
             Stream(src).upstream = Stream(src)  # type: ignore
 
+    @parameterized.expand(ITERABLE_TYPES)
+    def test_async_src(self, itype) -> None:
+        self.assertEqual(
+            to_list(Stream(sync_to_async_iter(src)), itype),
+            list(src),
+            msg="a stream with an async source must be collectable as an Iterable or as an AsyncIterable",
+        )
+        self.assertEqual(
+            to_list(Stream(sync_to_async_iter(src).__aiter__), itype),
+            list(src),
+            msg="a stream with an async source must be collectable as an Iterable or as an AsyncIterable",
+        )
+
     def test_repr_and_display(self) -> None:
         class CustomCallable:
             pass
@@ -211,31 +136,44 @@ class CustomCallable:
         complex_stream: Stream[int] = (
             Stream(src)
             .truncate(1024, when=lambda _: False)
+            .atruncate(1024, when=async_identity)
             .skip(10)
+            .askip(10)
             .skip(until=lambda _: True)
+            .askip(until=async_identity)
             .distinct(lambda _: _)
+            .adistinct(async_identity)
             .filter()
             .map(lambda i: (i,))
             .map(lambda i: (i,), concurrency=2)
             .filter(star(bool))
+            .afilter(star(async_identity))
             .foreach(lambda _: _)
             .foreach(lambda _: _, concurrency=2)
             .aforeach(async_identity)
             .map(cast(Callable[[Any], Any], CustomCallable()))
             .amap(async_identity)
             .group(100)
+            .agroup(100)
             .groupby(len)
+            .agroupby(async_identity)
             .map(star(lambda key, group: group))
             .observe("groups")
             .flatten(concurrency=4)
+            .map(sync_to_async_iter)
+            .aflatten(concurrency=4)
+            .map(lambda _: 0)
             .throttle(
                 64,
                 per=datetime.timedelta(seconds=1),
             )
             .observe("foos")
-            .catch(finally_raise=True)
+            .catch(finally_raise=True, when=identity)
+            .acatch(finally_raise=True, when=async_identity)
             .catch((TypeError, ValueError, None, ZeroDivisionError))
-            .catch(TypeError, replacement=None, finally_raise=True)
+            .acatch((TypeError, ValueError, None, ZeroDivisionError))
+            .catch(TypeError, replacement=1, finally_raise=True)
+            .acatch(TypeError, replacement=1, finally_raise=True)
         )
 
         print(repr(complex_stream))
@@ -255,73 +193,98 @@ class CustomCallable:
         complex_stream.display(logging.ERROR)
 
         self.assertEqual(
-            """(
-    Stream(range(0, 256))
-)""",
             str(Stream(src)),
+            "Stream(range(0, 256))",
             msg="`repr` should work as expected on a stream without operation",
         )
         self.assertEqual(
+            str(Stream(src).skip(10)),
+            "Stream(range(0, 256)).skip(10, until=None)",
+            msg="`repr` should return a one-liner for a stream with 1 operation",
+        )
+        self.assertEqual(
+            str(Stream(src).skip(10).skip(10)),
+            "Stream(range(0, 256)).skip(10, until=None).skip(10, until=None)",
+            msg="`repr` should return a one-liner for a stream with 2 operations",
+        )
+        self.assertEqual(
+            str(Stream(src).skip(10).skip(10).skip(10)),
             """(
     Stream(range(0, 256))
-    .map(<lambda>, concurrency=2, ordered=True, via='process')
+    .skip(10, until=None)
+    .skip(10, until=None)
+    .skip(10, until=None)
 )""",
-            str(Stream(src).map(lambda _: _, concurrency=2, via="process")),
-            msg="`repr` should work as expected on a stream with 1 operation",
+            msg="`repr` should go multiline for a stream with 3 operations",
         )
 
         self.assertEqual(
             str(complex_stream),
             """(
     Stream(range(0, 256))
     .truncate(count=1024, when=<lambda>)
+    .atruncate(count=1024, when=async_identity)
     .skip(10, until=None)
+    .askip(10, until=None)
     .skip(None, until=<lambda>)
+    .askip(None, until=async_identity)
     .distinct(<lambda>, consecutive_only=False)
+    .adistinct(async_identity, consecutive_only=False)
     .filter(bool)
     .map(<lambda>, concurrency=1, ordered=True)
     .map(<lambda>, concurrency=2, ordered=True, via='thread')
     .filter(star(bool))
+    .afilter(star(async_identity))
     .foreach(<lambda>, concurrency=1, ordered=True)
     .foreach(<lambda>, concurrency=2, ordered=True, via='thread')
     .aforeach(async_identity, concurrency=1, ordered=True)
     .map(CustomCallable(...), concurrency=1, ordered=True)
     .amap(async_identity, concurrency=1, ordered=True)
     .group(size=100, by=None, interval=None)
+    .agroup(size=100, by=None, interval=None)
     .groupby(len, size=None, interval=None)
+    .agroupby(async_identity, size=None, interval=None)
     .map(star(<lambda>), concurrency=1, ordered=True)
     .observe('groups')
     .flatten(concurrency=4)
+    .map(SyncToAsyncIterator, concurrency=1, ordered=True)
+    .aflatten(concurrency=4)
+    .map(<lambda>, concurrency=1, ordered=True)
     .throttle(64, per=datetime.timedelta(seconds=1))
     .observe('foos')
-    .catch(Exception, when=None, finally_raise=True)
+    .catch(Exception, when=identity, finally_raise=True)
+    .acatch(Exception, when=async_identity, finally_raise=True)
     .catch((TypeError, ValueError, None, ZeroDivisionError), when=None, finally_raise=False)
-    .catch(TypeError, when=None, replacement=None, finally_raise=True)
+    .acatch((TypeError, ValueError, None, ZeroDivisionError), when=None, finally_raise=False)
+    .catch(TypeError, when=None, replacement=1, finally_raise=True)
+    .acatch(TypeError, when=None, replacement=1, finally_raise=True)
 )""",
             msg="`repr` should work as expected on a stream with many operations",
        )
 
-    def test_iter(self) -> None:
+    @parameterized.expand(ITERABLE_TYPES)
+    def test_iter(self, itype: IterableType) -> None:
         self.assertIsInstance(
-            iter(Stream(src)),
-            Iterator,
+            bi_iterable_to_iter(Stream(src), itype=itype),
+            itype,
             msg="iter(stream) must return an Iterator.",
         )
 
         with self.assertRaisesRegex(
             TypeError,
-            r"`source` must be an Iterable or a Callable\[\[\], Iterable\] but got a <class 'int'>",
+            r"`source` must be an Iterable/AsyncIterable or a Callable\[\[\], Iterable/AsyncIterable\] but got a <class 'int'>",
             msg="Getting an Iterator from a Stream with a source not being a Union[Callable[[], Iterator], Iterable] must raise TypeError.",
         ):
-            iter(Stream(1))  # type: ignore
+            bi_iterable_to_iter(Stream(1), itype=itype)  # type: ignore
 
         with self.assertRaisesRegex(
             TypeError,
-            r"`source` must be an Iterable or a Callable\[\[\], Iterable\] but got a Callable\[\[\], <class 'int'>\]",
+            r"`source` must be an Iterable/AsyncIterable or a Callable\[\[\], Iterable/AsyncIterable\] but got a Callable\[\[\], <class 'int'>\]",
            msg="Getting an Iterator from a Stream with a source not being a Union[Callable[[], Iterator], Iterable] must raise TypeError.",
         ):
-            iter(Stream(lambda: 1))  # type: ignore
+            bi_iterable_to_iter(Stream(lambda: 1), itype=itype)  # type: ignore
 
-    def test_add(self) -> None:
+    @parameterized.expand(ITERABLE_TYPES)
+    def test_add(self, itype: IterableType) -> None:
         from streamable.stream import FlattenStream
 
         stream = Stream(src)
@@ -335,7 +298,7 @@ def test_add(self) -> None:
         stream_b = Stream(range(10, 20))
         stream_c = Stream(range(20, 30))
         self.assertListEqual(
-            list(stream_a + stream_b + stream_c),
+            to_list(stream_a + stream_b + stream_c, itype=itype),
             list(range(30)),
             msg="`chain` must yield the elements of the first stream, then move on to the elements of the next ones, and so on.",
         )
 
@@ -345,7 +308,9 @@
             [Stream.map, [identity]],
             [Stream.amap, [async_identity]],
             [Stream.foreach, [identity]],
+            [Stream.aforeach, [identity]],
             [Stream.flatten, []],
+            [Stream.aflatten, []],
         ]
     )
     def test_sanitize_concurrency(self, method, args) -> None:
@@ -384,26 +349,30 @@ def test_sanitize_via(self, method) -> None:
             method(Stream(src), identity, via="foo")
 
     @parameterized.expand(
-        [
-            [1],
-            [2],
-        ]
+        [(concurrency, itype) for concurrency in (1, 2) for itype in ITERABLE_TYPES]
     )
-    def test_map(self, concurrency) -> None:
+    def test_map(self, concurrency, itype) -> None:
         self.assertListEqual(
-            list(Stream(src).map(randomly_slowed(square), concurrency=concurrency)),
+            to_list(
+                Stream(src).map(randomly_slowed(square), concurrency=concurrency),
+                itype=itype,
+            ),
             list(map(square, src)),
             msg="At any concurrency the `map` method should act as the builtin map function, transforming elements while preserving input elements order.",
         )
 
     @parameterized.expand(
         [
-            [True, identity],
-            [False, sorted],
+            (ordered, order_mutation, itype)
+            for itype in ITERABLE_TYPES
+            for ordered, order_mutation in [
+                (True, identity),
+                (False, sorted),
+            ]
         ]
     )
     def test_process_concurrency(
-        self, ordered, order_mutation
+        self, ordered, order_mutation, itype
     ) -> None:  # pragma: no cover
         # 3.7 and 3.8 are passing the test but hang forever after
         if sys.version_info.minor < 9:
@@ -420,7 +389,7 @@ def local_identity(x):
             "",
             msg="process-based concurrency should not be able to serialize a lambda or a local func",
         ):
-            list(Stream(src).map(f, concurrency=2, via="process"))
+            to_list(Stream(src).map(f, concurrency=2, via="process"), itype=itype)
 
         sleeps = [0.01, 1, 0.01]
         state: List[str] = []
@@ -433,7 +402,7 @@ def local_identity(x):
             .foreach(lambda _: state.append(""), concurrency=1, ordered=True)
         )
         self.assertListEqual(
-            list(stream),
+            to_list(stream, itype=itype),
             expected_result_list,
             msg="process-based concurrency must correctly transform elements, respecting `ordered`...",
         )
@@ -444,38 +413,38 @@ def local_identity(x):
         )
         # test partial iteration:
         self.assertEqual(
-            next(iter(stream)),
+            anext_or_next(bi_iterable_to_iter(stream, itype=itype)),
             expected_result_list[0],
             msg="process-based concurrency must behave ok with partial iteration",
         )
 
     @parameterized.expand(
         [
-            [16, 0],
-            [1, 0],
-            [16, 1],
-            [16, 15],
-            [16, 16],
+            (concurrency, n_elems, itype)
+            for concurrency, n_elems in [
+                [16, 0],
+                [1, 0],
+                [16, 1],
+                [16, 15],
+                [16, 16],
+            ]
+            for itype in ITERABLE_TYPES
         ]
     )
     def test_map_with_more_concurrency_than_elements(
-        self, concurrency, n_elems
+        self, concurrency, n_elems, itype
     ) -> None:
         self.assertListEqual(
-            list(Stream(range(n_elems)).map(str, concurrency=concurrency)),
+            to_list(
+                Stream(range(n_elems)).map(str, concurrency=concurrency), itype=itype
+            ),
             list(map(str, range(n_elems))),
             msg="`map` method should act correctly when concurrency > number of elements.",
         )
 
     @parameterized.expand(
         [
-            [
-                ordered,
-                order_mutation,
-                expected_duration,
-                operation,
-                func,
-            ]
+            [ordered, order_mutation, expected_duration, operation, func, itype]
             for ordered, order_mutation, expected_duration in [
                 (True, identity, 0.3),
                 (False, sorted, 0.21),
@@ -486,6 +455,7 @@ def test_map_with_more_concurrency_than_elements(
                 (Stream.aforeach, asyncio.sleep),
                 (Stream.amap, async_identity_sleep),
             ]
+            for itype in ITERABLE_TYPES
         ]
     )
     def test_mapping_ordering(
@@ -495,11 +465,13 @@ def test_mapping_ordering(
         self,
         ordered: bool,
         order_mutation: Callable[[Iterable[float]], Iterable[float]],
         expected_duration: float,
         operation,
         func,
+        itype,
     ) -> None:
         seconds = [0.1, 0.01, 0.2]
         duration, res = timestream(
             operation(Stream(seconds), func, ordered=ordered, concurrency=2),
             5,
+            itype=itype,
         )
         self.assertListEqual(
             res,
@@ -515,23 +487,21 @@ def test_mapping_ordering(
         )
 
     @parameterized.expand(
-        [
-            [1],
-            [2],
-        ]
+        [(concurrency, itype) for concurrency in (1, 2) for itype in ITERABLE_TYPES]
     )
-    def test_foreach(self, concurrency) -> None:
+    def test_foreach(self, concurrency, itype) -> None:
         side_collection: Set[int] = set()
 
         def side_effect(x: int, func: Callable[[int], int]):
             nonlocal side_collection
             side_collection.add(func(x))
 
-        res = list(
+        res = to_list(
             Stream(src).foreach(
                 lambda i: randomly_slowed(side_effect(i, square)),
                 concurrency=concurrency,
-            )
+            ),
+            itype=itype,
         )
 
         self.assertListEqual(
@@ -554,6 +524,7 @@ def side_effect(x: int, func: Callable[[int], int]):
                 method,
                 throw_func_,
                 throw_for_odd_func_,
+                itype,
             ]
             for raised_exc, caught_exc in [
                 (TestError, (TestError,)),
@@ -562,9 +533,11 @@ def side_effect(x: int, func: Callable[[int], int]):
             for concurrency in [1, 2]
             for method, throw_func_, throw_for_odd_func_ in [
                 (Stream.foreach, throw_func, throw_for_odd_func),
+                (Stream.aforeach, async_throw_func, async_throw_for_odd_func),
                 (Stream.map, throw_func, throw_for_odd_func),
                 (Stream.amap, async_throw_func, async_throw_for_odd_func),
             ]
+            for itype in ITERABLE_TYPES
         ]
     )
     def test_map_or_foreach_with_exception(
@@ -575,20 +548,25 @@ def test_map_or_foreach_with_exception(
         method: Callable[[Stream, Callable[[Any], int], int], Stream],
         throw_func: Callable[[Exception], Callable[[Any], int]],
         throw_for_odd_func: Callable[[Type[Exception]], Callable[[Any], int]],
+        itype: IterableType,
     ) -> None:
         with self.assertRaises(
             caught_exc,
             msg="At any concurrency, `map` and `foreach` and `amap` must raise",
         ):
-            list(method(Stream(src), throw_func(raised_exc), concurrency=concurrency))  # type: ignore
+            to_list(
+                method(Stream(src), throw_func(raised_exc), concurrency=concurrency),  # type: ignore
+                itype=itype,
+            )
 
         self.assertListEqual(
-            list(
+            to_list(
                 method(
                     Stream(src),
                     throw_for_odd_func(raised_exc),
                     concurrency=concurrency,  # type: ignore
-                ).catch(caught_exc)
+                ).catch(caught_exc),
+                itype=itype,
             ),
             list(even_src),
             msg="At any concurrency, `map` and `foreach` and `amap` must not stop after one exception occurred.",
@@ -596,18 +574,24 @@ def test_map_or_foreach_with_exception(
 
     @parameterized.expand(
         [
-            [method, func, concurrency]
+            [method, func, concurrency, itype]
             for method, func in [
                 (Stream.foreach, slow_identity),
+                (Stream.aforeach, async_slow_identity),
                 (Stream.map, slow_identity),
                 (Stream.amap, async_slow_identity),
             ]
             for concurrency in [1, 2, 4]
+            for itype in ITERABLE_TYPES
         ]
     )
-    def test_map_and_foreach_concurrency(self, method, func, concurrency) -> None:
+    def test_map_and_foreach_concurrency(
+        self, method, func, concurrency, itype
+    ) -> None:
         expected_iteration_duration = N * slow_identity_duration / concurrency
-        duration, res = timestream(method(Stream(src), func, concurrency=concurrency))
+        duration, res = timestream(
+            method(Stream(src), func, concurrency=concurrency), itype=itype
+        )
         self.assertListEqual(res, list(src))
         self.assertAlmostEqual(
             duration,
@@ -618,50 +602,80 @@ def test_map_and_foreach_concurrency(self, method, func, concurrency) -> None:
 
     @parameterized.expand(
         [
-            [1],
-            [2],
+            (concurrency, itype, flatten)
+            for concurrency in (1, 2)
+            for itype in ITERABLE_TYPES
+            for flatten in (Stream.flatten, 
Stream.aflatten) ] ) - def test_flatten(self, concurrency) -> None: + def test_flatten(self, concurrency, itype, flatten) -> None: n_iterables = 32 it = list(range(N // n_iterables)) double_it = it + it iterables_stream = Stream( - lambda: map(slow_identity, [double_it] + [it for _ in range(n_iterables)]) - ) + [sync_to_bi_iterable(double_it)] + + [sync_to_bi_iterable(it) for _ in range(n_iterables)] + ).map(slow_identity) self.assertCountEqual( - list(iterables_stream.flatten(concurrency=concurrency)), + to_list(flatten(iterables_stream, concurrency=concurrency), itype=itype), list(it) * n_iterables + double_it, msg="At any concurrency the `flatten` method should yield all the upstream iterables' elements.", ) self.assertListEqual( - list( - Stream([iter([]) for _ in range(2000)]).flatten(concurrency=concurrency) + to_list( + flatten( + Stream([sync_to_bi_iterable(iter([])) for _ in range(2000)]), + concurrency=concurrency, + ), + itype=itype, ), [], msg="`flatten` should not yield any element if upstream elements are empty iterables, and be resilient to recursion issue in case of successive empty upstream iterables.", ) with self.assertRaises( - TypeError, + (TypeError, AttributeError), msg="`flatten` should raise if an upstream element is not iterable.", ): - next(iter(Stream(cast(Iterable, src)).flatten())) + anext_or_next( + bi_iterable_to_iter( + flatten(Stream(cast(Union[Iterable, AsyncIterable], src))), + itype=itype, + ) + ) # test typing with ranges _: Stream[int] = Stream((src, src)).flatten() - def test_flatten_concurrency(self) -> None: + @parameterized.expand( + [ + (flatten, itype, slow) + for flatten, slow in ( + (Stream.flatten, partial(Stream.map, transformation=slow_identity)), + ( + Stream.aflatten, + partial(Stream.amap, transformation=async_slow_identity), + ), + ) + for itype in ITERABLE_TYPES + ] + ) + def test_flatten_concurrency(self, flatten, itype, slow) -> None: + concurrency = 2 iterable_size = 5 runtime, res = timestream( - Stream( - lambda: [ - Stream(map(slow_identity, ["a"] * iterable_size)), - Stream(map(slow_identity, ["b"] * iterable_size)), - Stream(map(slow_identity, ["c"] * iterable_size)), - ] - ).flatten(concurrency=2), + flatten( + Stream( + lambda: [ + slow(Stream(["a"] * iterable_size)), + slow(Stream(["b"] * iterable_size)), + slow(Stream(["c"] * iterable_size)), + ] + ), + concurrency=concurrency, + ), times=3, + itype=itype, ) self.assertListEqual( res, @@ -669,7 +683,7 @@ def test_flatten_concurrency(self) -> None: msg="`flatten` should process 'a's and 'b's concurrently and then 'c's", ) a_runtime = b_runtime = c_runtime = iterable_size * slow_identity_duration - expected_runtime = (a_runtime + b_runtime) / 2 + c_runtime + expected_runtime = (a_runtime + b_runtime) / concurrency + c_runtime self.assertAlmostEqual( runtime, expected_runtime, @@ -688,21 +702,29 @@ def test_flatten_typing(self) -> None: Stream("abc").map(lambda char: filter(lambda _: True, char)).flatten() ) + flattened_asynciter_stream: Stream[str] = ( + Stream("abc").map(sync_to_async_iter).aflatten() + ) + @parameterized.expand( [ - [exception_type, mapped_exception_type, concurrency] + [exception_type, mapped_exception_type, concurrency, itype, flatten] + for concurrency in [1, 2] + for itype in ITERABLE_TYPES for exception_type, mapped_exception_type in [ (TestError, TestError), - (StopIteration, WrappedError), + (stopiteration_for_iter_type(itype), (WrappedError, RuntimeError)), ] - for concurrency in [1, 2] + for flatten in (Stream.flatten, Stream.aflatten) ] ) - def 
test_flatten_with_exception( + def test_flatten_with_exception_in_iter( self, exception_type: Type[Exception], mapped_exception_type: Type[Exception], concurrency: int, + itype: IterableType, + flatten: Callable, ) -> None: n_iterables = 5 @@ -710,60 +732,79 @@ class IterableRaisingInIter(Iterable[int]): def __iter__(self) -> Iterator[int]: raise exception_type - self.assertSetEqual( - set( + res: Set[int] = to_set( + flatten( Stream( map( - lambda i: ( - IterableRaisingInIter() - if i % 2 - else cast(Iterable[int], range(i, i + 1)) + lambda i: sync_to_bi_iterable( + IterableRaisingInIter() if i % 2 else range(i, i + 1) ), range(n_iterables), ) - ) - .flatten(concurrency=concurrency) - .catch(mapped_exception_type) - ), + ), + concurrency=concurrency, + ).catch(mapped_exception_type), + itype=itype, + ) + self.assertSetEqual( + res, set(range(0, n_iterables, 2)), - msg="At any concurrency the `flatten` method should be resilient to exceptions thrown by iterators, especially it should remap StopIteration one to PacifiedStopIteration.", + msg="At any concurrency the `flatten` method should be resilient to exceptions thrown by iterators, especially it should wrap Stop(Async)Iteration.", ) + @parameterized.expand( + [ + [concurrency, itype, flatten] + for concurrency in [1, 2] + for itype in ITERABLE_TYPES + for flatten in (Stream.flatten, Stream.aflatten) + ] + ) + def test_flatten_with_exception_in_next( + self, + concurrency: int, + itype: IterableType, + flatten: Callable, + ) -> None: + n_iterables = 5 + class IteratorRaisingInNext(Iterator[int]): def __init__(self) -> None: self.first_next = True - def __iter__(self) -> Iterator[int]: - return self - def __next__(self) -> int: if not self.first_next: - raise StopIteration + raise StopIteration() self.first_next = False - raise exception_type + raise TestError - self.assertSetEqual( - set( + res = to_set( + flatten( Stream( map( lambda i: ( - IteratorRaisingInNext() - if i % 2 - else cast(Iterable[int], range(i, i + 1)) + sync_to_bi_iterable( + IteratorRaisingInNext() if i % 2 else range(i, i + 1) + ) ), range(n_iterables), ) - ) - .flatten(concurrency=concurrency) - .catch(mapped_exception_type) - ), + ), + concurrency=concurrency, + ).catch(TestError), + itype=itype, + ) + self.assertSetEqual( + res, set(range(0, n_iterables, 2)), - msg="At any concurrency the `flatten` method should be resilient to exceptions thrown by iterators, especially it should remap StopIteration one to PacifiedStopIteration.", + msg="At any concurrency the `flatten` method should be resilient to exceptions thrown by iterators, especially it should wrap Stop(Async)Iteration.", ) - @parameterized.expand([[concurrency] for concurrency in [2, 4]]) + @parameterized.expand( + [(concurrency, itype) for concurrency in [2, 4] for itype in ITERABLE_TYPES] + ) def test_partial_iteration_on_streams_using_concurrency( - self, concurrency: int + self, concurrency: int, itype: IterableType ) -> None: yielded_elems = [] @@ -798,14 +839,14 @@ def remembering_src() -> Iterator[int]: ), ]: yielded_elems = [] - iterator = iter(stream) + iterator = bi_iterable_to_iter(stream, itype=itype) time.sleep(0.5) self.assertEqual( len(yielded_elems), 0, msg=f"before the first call to `next` a concurrent {type(stream)} should have pulled 0 upstream elements.", ) - next(iterator) + anext_or_next(iterator) time.sleep(0.5) self.assertEqual( len(yielded_elems), @@ -813,40 +854,82 @@ def remembering_src() -> Iterator[int]: msg=f"`after the first call to `next` a concurrent {type(stream)} with 
concurrency={concurrency} should have pulled only {n_pulls_after_first_next} upstream elements.", ) - def test_filter(self) -> None: - def keep(x) -> Any: + @parameterized.expand(ITERABLE_TYPES) + def test_filter(self, itype: IterableType) -> None: + def keep(x) -> int: return x % 2 self.assertListEqual( - list(Stream(src).filter(keep)), + to_list(Stream(src).filter(keep), itype=itype), list(filter(keep, src)), msg="`filter` must act like builtin filter", ) self.assertListEqual( - list(Stream(src).filter()), + to_list(Stream(src).filter(bool), itype=itype), list(filter(None, src)), msg="`filter` with `bool` as predicate must act like builtin filter with None predicate.", ) self.assertListEqual( - list(Stream(src).filter()), + to_list(Stream(src).filter(), itype=itype), list(filter(None, src)), msg="`filter` without predicate must act like builtin filter with None predicate.", ) self.assertListEqual( - list(Stream(src).filter(None)), # type: ignore + to_list(Stream(src).filter(None), itype=itype), # type: ignore list(filter(None, src)), msg="`filter` with None predicate must act unofficially like builtin filter with None predicate.", ) - # Unofficially accept `stream.filter(None)`, behaving as builtin `filter(None, iter)` + self.assertEqual( + to_list(Stream(src).filter(None), itype=itype), # type: ignore + list(filter(None, src)), + msg="Unofficially accept `stream.filter(None)`, behaving as builtin `filter(None, iter)`", + ) # with self.assertRaisesRegex( # TypeError, # "`when` cannot be None", # msg="`filter` does not accept a None predicate", # ): - # list(Stream(src).filter(None)) # type: ignore + # to_list(Stream(src).filter(None), itype=itype) # type: ignore + + @parameterized.expand(ITERABLE_TYPES) + def test_afilter(self, itype: IterableType) -> None: + def keep(x) -> int: + return x % 2 + + async def async_keep(x) -> int: + return keep(x) + + self.assertListEqual( + to_list(Stream(src).afilter(async_keep), itype=itype), + list(filter(keep, src)), + msg="`afilter` must act like builtin filter", + ) + self.assertListEqual( + to_list(Stream(src).afilter(asyncify(bool)), itype=itype), + list(filter(None, src)), + msg="`afilter` with `bool` as predicate must act like builtin filter with None predicate.", + ) + self.assertListEqual( + to_list(Stream(src).afilter(None), itype=itype), # type: ignore + list(filter(None, src)), + msg="`afilter` with None predicate must act unofficially like builtin filter with None predicate.", + ) + + self.assertEqual( + to_list(Stream(src).afilter(None), itype=itype), # type: ignore + list(filter(None, src)), + msg="Unofficially accept `stream.afilter(None)`, behaving as builtin `filter(None, iter)`", + ) + # with self.assertRaisesRegex( + # TypeError, + # "`when` cannot be None", + # msg="`afilter` does not accept a None predicate", + # ): + # to_list(Stream(src).afilter(None), itype=itype) # type: ignore - def test_skip(self) -> None: + @parameterized.expand(ITERABLE_TYPES) + def test_skip(self, itype: IterableType) -> None: with self.assertRaisesRegex( ValueError, "`count` must be >= 0 but got -1", @@ -855,89 +938,164 @@ def test_skip(self) -> None: Stream(src).skip(-1) self.assertListEqual( - list(Stream(src).skip()), + to_list(Stream(src).skip(), itype=itype), list(src), msg="`skip` must be no-op if both `count` and `until` are None", ) self.assertListEqual( - list(Stream(src).skip(None)), + to_list(Stream(src).skip(None), itype=itype), list(src), msg="`skip` must be no-op if both `count` and `until` are None", ) for count in [0, 1, 3]: 
self.assertListEqual( - list(Stream(src).skip(count)), + to_list(Stream(src).skip(count), itype=itype), list(src)[count:], msg="`skip` must skip `count` elements", ) self.assertListEqual( - list( + to_list( Stream(map(throw_for_odd_func(TestError), src)) .skip(count) - .catch(TestError) + .catch(TestError), + itype=itype, ), list(filter(lambda i: i % 2 == 0, src))[count:], msg="`skip` must not count exceptions as skipped elements", ) self.assertListEqual( - list(Stream(src).skip(until=lambda n: n >= count)), + to_list(Stream(src).skip(until=lambda n: n >= count), itype=itype), list(src)[count:], msg="`skip` must yield starting from the first element satisfying `until`", ) self.assertListEqual( - list(Stream(src).skip(count, until=lambda n: False)), + to_list(Stream(src).skip(count, until=lambda n: False), itype=itype), list(src)[count:], msg="`skip` must ignore `count` elements if `until` is never satisfied", ) self.assertListEqual( - list(Stream(src).skip(count * 2, until=lambda n: n >= count)), + to_list( + Stream(src).skip(count * 2, until=lambda n: n >= count), itype=itype + ), list(src)[count:], msg="`skip` must ignore less than `count` elements if `until` is satisfied first", ) self.assertListEqual( - list(Stream(src).skip(until=lambda n: False)), + to_list(Stream(src).skip(until=lambda n: False), itype=itype), [], msg="`skip` must not yield any element if `until` is never satisfied", ) - def test_truncate(self) -> None: + @parameterized.expand(ITERABLE_TYPES) + def test_askip(self, itype: IterableType) -> None: + with self.assertRaisesRegex( + ValueError, + "`count` must be >= 0 but got -1", + msg="`askip` must raise ValueError if `count` is negative", + ): + Stream(src).askip(-1) + + self.assertListEqual( + to_list(Stream(src).askip(), itype=itype), + list(src), + msg="`askip` must be no-op if both `count` and `until` are None", + ) + + self.assertListEqual( + to_list(Stream(src).askip(None), itype=itype), + list(src), + msg="`askip` must be no-op if both `count` and `until` are None", + ) + + for count in [0, 1, 3]: + self.assertListEqual( + to_list(Stream(src).askip(count), itype=itype), + list(src)[count:], + msg="`askip` must skip `count` elements", + ) + + self.assertListEqual( + to_list( + Stream(map(throw_for_odd_func(TestError), src)) + .askip(count) + .catch(TestError), + itype=itype, + ), + list(filter(lambda i: i % 2 == 0, src))[count:], + msg="`askip` must not count exceptions as skipped elements", + ) + + self.assertListEqual( + to_list( + Stream(src).askip(until=asyncify(lambda n: n >= count)), itype=itype + ), + list(src)[count:], + msg="`askip` must yield starting from the first element satisfying `until`", + ) + + self.assertListEqual( + to_list( + Stream(src).askip(count, until=asyncify(lambda n: False)), + itype=itype, + ), + list(src)[count:], + msg="`askip` must ignore `count` elements if `until` is never satisfied", + ) + + self.assertListEqual( + to_list( + Stream(src).askip(count * 2, until=asyncify(lambda n: n >= count)), + itype=itype, + ), + list(src)[count:], + msg="`askip` must ignore less than `count` elements if `until` is satisfied first", + ) + + self.assertListEqual( + to_list(Stream(src).askip(until=asyncify(lambda n: False)), itype=itype), + [], + msg="`askip` must not yield any element if `until` is never satisfied", + ) + + @parameterized.expand(ITERABLE_TYPES) + def test_truncate(self, itype: IterableType) -> None: self.assertListEqual( - list(Stream(src).truncate(N * 2)), + to_list(Stream(src).truncate(N * 2), itype=itype), list(src), 
msg="`truncate` must be ok with count >= stream length", ) self.assertListEqual( - list(Stream(src).truncate()), + to_list(Stream(src).truncate(), itype=itype), list(src), msg="`truncate must be no-op if both `count` and `when` are None", ) self.assertListEqual( - list(Stream(src).truncate(None)), + to_list(Stream(src).truncate(None), itype=itype), list(src), msg="`truncate must be no-op if both `count` and `when` are None", ) self.assertListEqual( - list(Stream(src).truncate(2)), + to_list(Stream(src).truncate(2), itype=itype), [0, 1], msg="`truncate` must be ok with count >= 1", ) self.assertListEqual( - list(Stream(src).truncate(1)), + to_list(Stream(src).truncate(1), itype=itype), [0], msg="`truncate` must be ok with count == 1", ) self.assertListEqual( - list(Stream(src).truncate(0)), + to_list(Stream(src).truncate(0), itype=itype), [], msg="`truncate` must be ok with count == 0", ) @@ -950,106 +1108,222 @@ def test_truncate(self) -> None: Stream(src).truncate(-1) self.assertListEqual( - list(Stream(src).truncate(cast(int, float("inf")))), + to_list(Stream(src).truncate(cast(int, float("inf"))), itype=itype), list(src), msg="`truncate` must be no-op if `count` is inf", ) count = N // 2 - raising_stream_iterator = iter( - Stream(lambda: map(lambda x: round((1 / x) * x**2), src)).truncate(count) + raising_stream_iterator = bi_iterable_to_iter( + Stream(lambda: map(lambda x: round((1 / x) * x**2), src)).truncate(count), + itype=itype, ) with self.assertRaises( ZeroDivisionError, msg="`truncate` must not stop iteration when encountering exceptions and raise them without counting them...", ): - next(raising_stream_iterator) + anext_or_next(raising_stream_iterator) - self.assertListEqual(list(raising_stream_iterator), list(range(1, count + 1))) + self.assertListEqual( + alist_or_list(raising_stream_iterator), list(range(1, count + 1)) + ) with self.assertRaises( - StopIteration, + stopiteration_for_iter_type(type(raising_stream_iterator)), msg="... 
and after reaching the limit it still continues to raise StopIteration on calls to next", ): - next(raising_stream_iterator) + anext_or_next(raising_stream_iterator) - iter_truncated_on_predicate = iter(Stream(src).truncate(when=lambda n: n == 5)) + iter_truncated_on_predicate = bi_iterable_to_iter( + Stream(src).truncate(when=lambda n: n == 5), itype=itype + ) self.assertListEqual( - list(iter_truncated_on_predicate), - list(Stream(src).truncate(5)), + alist_or_list(iter_truncated_on_predicate), + to_list(Stream(src).truncate(5), itype=itype), msg="`when` n == 5 must be equivalent to `count` = 5", ) with self.assertRaises( - StopIteration, + stopiteration_for_iter_type(type(iter_truncated_on_predicate)), msg="After exhaustion a call to __next__ on a truncated iterator must raise StopIteration", ): - next(iter_truncated_on_predicate) + anext_or_next(iter_truncated_on_predicate) with self.assertRaises( ZeroDivisionError, msg="an exception raised by `when` must be raised", ): - list(Stream(src).truncate(when=lambda _: 1 / 0)) + to_list(Stream(src).truncate(when=lambda _: 1 / 0), itype=itype) self.assertListEqual( - list(Stream(src).truncate(6, when=lambda n: n == 5)), + to_list(Stream(src).truncate(6, when=lambda n: n == 5), itype=itype), list(range(5)), msg="`when` and `count` arguments can be set at the same time, and the truncation should happen as soon as one or the other is satisfied.", ) self.assertListEqual( - list(Stream(src).truncate(5, when=lambda n: n == 6)), + to_list(Stream(src).truncate(5, when=lambda n: n == 6), itype=itype), list(range(5)), msg="`when` and `count` arguments can be set at the same time, and the truncation should happen as soon as one or the other is satisfied.", ) - def test_group(self) -> None: - # behavior with invalid arguments - for seconds in [-1, 0]: - with self.assertRaises( - ValueError, - msg="`group` should raise error when called with `seconds` <= 0.", - ): - list( - Stream([1]).group( - size=100, interval=datetime.timedelta(seconds=seconds) - ) - ) + @parameterized.expand(ITERABLE_TYPES) + def test_atruncate(self, itype: IterableType) -> None: + self.assertListEqual( + to_list(Stream(src).atruncate(N * 2), itype=itype), + list(src), + msg="`atruncate` must be ok with count >= stream length", + ) - for size in [-1, 0]: - with self.assertRaises( - ValueError, - msg="`group` should raise error when called with `size` < 1.", - ): - list(Stream([1]).group(size=size)) + self.assertListEqual( + to_list(Stream(src).atruncate(), itype=itype), + list(src), + msg="`atruncate` must be no-op if both `count` and `when` are None", + ) - # group size self.assertListEqual( - list(Stream(range(6)).group(size=4)), - [[0, 1, 2, 3], [4, 5]], - msg="", + to_list(Stream(src).atruncate(None), itype=itype), + list(src), + msg="`atruncate` must be no-op if both `count` and `when` are None", ) + self.assertListEqual( - list(Stream(range(6)).group(size=2)), - [[0, 1], [2, 3], [4, 5]], - msg="", + to_list(Stream(src).atruncate(2), itype=itype), + [0, 1], + msg="`atruncate` must be ok with count >= 1", ) self.assertListEqual( - list(Stream([]).group(size=2)), + to_list(Stream(src).atruncate(1), itype=itype), + [0], + msg="`atruncate` must be ok with count == 1", + ) + self.assertListEqual( + to_list(Stream(src).atruncate(0), itype=itype), [], - msg="", + msg="`atruncate` must be ok with count == 0", ) - # behavior with exceptions - def f(i): - return i / (110 - i) + with self.assertRaisesRegex( + ValueError, + "`count` must be >= 0 but got -1", + msg="`atruncate` must raise 
ValueError if `count` is negative", + ): + Stream(src).atruncate(-1) - stream_iterator = iter(Stream(lambda: map(f, src)).group(100)) - next(stream_iterator) self.assertListEqual( - next(stream_iterator), - list(map(f, range(100, 110))), + to_list(Stream(src).atruncate(cast(int, float("inf"))), itype=itype), + list(src), + msg="`atruncate` must be no-op if `count` is inf", + ) + + count = N // 2 + raising_stream_iterator = bi_iterable_to_iter( + Stream(lambda: map(lambda x: round((1 / x) * x**2), src)).atruncate(count), + itype=itype, + ) + + with self.assertRaises( + ZeroDivisionError, + msg="`atruncate` must not stop iteration when encountering exceptions and raise them without counting them...", + ): + anext_or_next(raising_stream_iterator) + + self.assertListEqual( + alist_or_list(raising_stream_iterator), list(range(1, count + 1)) + ) + + with self.assertRaises( + stopiteration_for_iter_type(type(raising_stream_iterator)), + msg="... and after reaching the limit it still continues to raise StopIteration on calls to next", + ): + anext_or_next(raising_stream_iterator) + + iter_truncated_on_predicate = bi_iterable_to_iter( + Stream(src).atruncate(when=asyncify(lambda n: n == 5)), itype=itype + ) + self.assertListEqual( + alist_or_list(iter_truncated_on_predicate), + to_list(Stream(src).atruncate(5), itype=itype), + msg="`when` n == 5 must be equivalent to `count` = 5", + ) + with self.assertRaises( + stopiteration_for_iter_type(type(iter_truncated_on_predicate)), + msg="After exhaustion a call to __next__ on a truncated iterator must raise StopIteration", + ): + anext_or_next(iter_truncated_on_predicate) + + with self.assertRaises( + ZeroDivisionError, + msg="an exception raised by `when` must be raised", + ): + to_list(Stream(src).atruncate(when=asyncify(lambda _: 1 / 0)), itype=itype) + + self.assertListEqual( + to_list( + Stream(src).atruncate(6, when=asyncify(lambda n: n == 5)), itype=itype + ), + list(range(5)), + msg="`when` and `count` arguments can be set at the same time, and the truncation should happen as soon as one or the other is satisfied.", + ) + + self.assertListEqual( + to_list( + Stream(src).atruncate(5, when=asyncify(lambda n: n == 6)), itype=itype + ), + list(range(5)), + msg="`when` and `count` arguments can be set at the same time, and the truncation should happen as soon as one or the other is satisfied.", + ) + + @parameterized.expand(ITERABLE_TYPES) + def test_group(self, itype: IterableType) -> None: + # behavior with invalid arguments + for seconds in [-1, 0]: + with self.assertRaises( + ValueError, + msg="`group` should raise error when called with `seconds` <= 0.", + ): + to_list( + Stream([1]).group( + size=100, interval=datetime.timedelta(seconds=seconds) + ), + itype=itype, + ) + + for size in [-1, 0]: + with self.assertRaises( + ValueError, + msg="`group` should raise error when called with `size` < 1.", + ): + to_list(Stream([1]).group(size=size), itype=itype) + + # group size + self.assertListEqual( + to_list(Stream(range(6)).group(size=4), itype=itype), + [[0, 1, 2, 3], [4, 5]], + msg="", + ) + self.assertListEqual( + to_list(Stream(range(6)).group(size=2), itype=itype), + [[0, 1], [2, 3], [4, 5]], + msg="", + ) + self.assertListEqual( + to_list(Stream([]).group(size=2), itype=itype), + [], + msg="", + ) + + # behavior with exceptions + def f(i): + return i / (110 - i) + + stream_iterator = bi_iterable_to_iter( + Stream(lambda: map(f, src)).group(100), itype=itype + ) + anext_or_next(stream_iterator) + self.assertListEqual( + 
anext_or_next(stream_iterator), + list(map(f, range(100, 110))), msg="when encountering upstream exception, `group` should yield the current accumulated group...", ) @@ -1057,90 +1331,110 @@ def f(i): ZeroDivisionError, msg="... and raise the upstream exception during the next call to `next`...", ): - next(stream_iterator) + anext_or_next(stream_iterator) self.assertListEqual( - next(stream_iterator), + anext_or_next(stream_iterator), list(map(f, range(111, 211))), msg="... and restarting a fresh group to yield after that.", ) # behavior of the `seconds` parameter self.assertListEqual( - list( + to_list( Stream(lambda: map(slow_identity, src)).group( size=100, interval=datetime.timedelta(seconds=slow_identity_duration / 1000), - ) + ), + itype=itype, ), list(map(lambda e: [e], src)), msg="`group` should not yield empty groups even though `interval` is smaller than upstream's frequency", ) self.assertListEqual( - list( + to_list( Stream(lambda: map(slow_identity, src)).group( size=100, interval=datetime.timedelta(seconds=slow_identity_duration / 1000), by=lambda _: None, - ) + ), + itype=itype, ), list(map(lambda e: [e], src)), msg="`group` with `by` argument should not yield empty groups even though `interval` is smaller than upstream's frequency", ) self.assertListEqual( - list( + to_list( Stream(lambda: map(slow_identity, src)).group( size=100, interval=datetime.timedelta( seconds=2 * slow_identity_duration * 0.99 ), - ) + ), + itype=itype, ), list(map(lambda e: [e, e + 1], even_src)), msg="`group` should yield upstream elements in a two-element group if `interval` is less than twice the upstream yield period", ) self.assertListEqual( - next(iter(Stream(src).group())), + anext_or_next(bi_iterable_to_iter(Stream(src).group(), itype=itype)), list(src), msg="`group` without arguments should group the elements all together", ) + # test groupby + groupby_stream_iter: Union[ + Iterator[Tuple[int, List[int]]], AsyncIterator[Tuple[int, List[int]]] + ] = bi_iterable_to_iter( + Stream(src).groupby(lambda n: n % 2, size=2), itype=itype + ) + self.assertListEqual( + [anext_or_next(groupby_stream_iter), anext_or_next(groupby_stream_iter)], + [(0, [0, 2]), (1, [1, 3])], + msg="`groupby` must cogroup elements.", + ) + # test by - stream_iter = iter(Stream(src).group(size=2, by=lambda n: n % 2)) + stream_iter = bi_iterable_to_iter( + Stream(src).group(size=2, by=lambda n: n % 2), itype=itype + ) self.assertListEqual( - [next(stream_iter), next(stream_iter)], + [anext_or_next(stream_iter), anext_or_next(stream_iter)], [[0, 2], [1, 3]], msg="`group` called with a `by` function must cogroup elements.", ) self.assertListEqual( - next( - iter( + anext_or_next( + bi_iterable_to_iter( Stream(src_raising_at_exhaustion).group( size=10, by=lambda n: n % 4 != 0 - ) - ) + ), + itype=itype, + ), ), [1, 2, 3, 5, 6, 7, 9, 10, 11, 13], msg="`group` called with a `by` function and a `size` should yield the first batch becoming full.", ) self.assertListEqual( - list(Stream(src).group(by=lambda n: n % 2)), + to_list(Stream(src).group(by=lambda n: n % 2), itype=itype), [list(range(0, N, 2)), list(range(1, N, 2))], msg="`group` called with a `by` function and an infinite size must cogroup elements and yield groups starting with the group containing the oldest element.", ) self.assertListEqual( - list(Stream(range(10)).group(by=lambda n: n % 4 == 0)), + to_list(Stream(range(10)).group(by=lambda n: n % 4 == 0), itype=itype), [[0, 4, 8], [1, 2, 3, 5, 6, 7, 9]], msg="`group` called with a `by` function and reaching exhaustion 
must cogroup elements and yield incomplete groups starting with the group containing the oldest element, even though it's not the largest.", ) - stream_iter = iter(Stream(src_raising_at_exhaustion).group(by=lambda n: n % 2)) + stream_iter = bi_iterable_to_iter( + Stream(src_raising_at_exhaustion).group(by=lambda n: n % 2), itype=itype + ) self.assertListEqual( - [next(stream_iter), next(stream_iter)], + [anext_or_next(stream_iter), anext_or_next(stream_iter)], [list(range(0, N, 2)), list(range(1, N, 2))], msg="`group` called with a `by` function and encountering an exception must cogroup elements and yield incomplete groups starting with the group containing the oldest element.", ) @@ -1148,56 +1442,277 @@ def f(i): TestError, msg="`group` called with a `by` function and encountering an exception must raise it after all groups have been yielded", ): - next(stream_iter) + anext_or_next(stream_iter) # test seconds + by self.assertListEqual( - list( + to_list( Stream(lambda: map(slow_identity, range(10))).group( interval=datetime.timedelta(seconds=slow_identity_duration * 2.9), by=lambda n: n % 4 == 0, - ) + ), + itype=itype, ), [[1, 2], [0, 4], [3, 5, 6, 7], [8], [9]], msg="`group` called with a `by` function must cogroup elements and yield the largest groups when `seconds` is reached even though it's not the oldest.", ) - stream_iter = iter( + stream_iter = bi_iterable_to_iter( Stream(src).group( - size=3, by=lambda n: throw(StopIteration) if n == 2 else n - ) + size=3, + by=lambda n: throw(stopiteration_for_iter_type(itype)) if n == 2 else n, + ), + itype=itype, ) self.assertListEqual( - [next(stream_iter), next(stream_iter)], + [anext_or_next(stream_iter), anext_or_next(stream_iter)], [[0], [1]], msg="`group` should yield incomplete groups when `by` raises", ) with self.assertRaisesRegex( - WrappedError, - "StopIteration()", + (WrappedError, RuntimeError), + stopiteration_for_iter_type(itype).__name__, msg="`group` should raise and skip `elem` if `by(elem)` raises", ): - next(stream_iter) + anext_or_next(stream_iter) self.assertListEqual( - next(stream_iter), + anext_or_next(stream_iter), [3], msg="`group` should continue yielding after `by`'s exception has been raised.", ) - def test_throttle(self) -> None: + @parameterized.expand(ITERABLE_TYPES) + def test_agroup(self, itype: IterableType) -> None: + # behavior with invalid arguments + for seconds in [-1, 0]: + with self.assertRaises( + ValueError, + msg="`agroup` should raise error when called with `seconds` <= 0.", + ): + to_list( + Stream([1]).agroup( + size=100, interval=datetime.timedelta(seconds=seconds) + ), + itype=itype, + ) + + for size in [-1, 0]: + with self.assertRaises( + ValueError, + msg="`agroup` should raise error when called with `size` < 1.", + ): + to_list(Stream([1]).agroup(size=size), itype=itype) + + # group size + self.assertListEqual( + to_list(Stream(range(6)).agroup(size=4), itype=itype), + [[0, 1, 2, 3], [4, 5]], + msg="", + ) + self.assertListEqual( + to_list(Stream(range(6)).agroup(size=2), itype=itype), + [[0, 1], [2, 3], [4, 5]], + msg="", + ) + self.assertListEqual( + to_list(Stream([]).agroup(size=2), itype=itype), + [], + msg="", + ) + + # behavior with exceptions + def f(i): + return i / (110 - i) + + stream_iterator = bi_iterable_to_iter( + Stream(lambda: map(f, src)).agroup(100), itype=itype + ) + anext_or_next(stream_iterator) + self.assertListEqual( + anext_or_next(stream_iterator), + list(map(f, range(100, 110))), + msg="when encountering upstream exception, `agroup` should yield the current 
accumulated group...", ) + + with self.assertRaises( + ZeroDivisionError, + msg="... and raise the upstream exception during the next call to `next`...", + ): + anext_or_next(stream_iterator) + + self.assertListEqual( + anext_or_next(stream_iterator), + list(map(f, range(111, 211))), + msg="... and restarting a fresh group to yield after that.", + ) + + # behavior of the `seconds` parameter + self.assertListEqual( + to_list( + Stream(lambda: map(slow_identity, src)).agroup( + size=100, + interval=datetime.timedelta(seconds=slow_identity_duration / 1000), + ), + itype=itype, + ), + list(map(lambda e: [e], src)), + msg="`agroup` should not yield empty groups even though `interval` is smaller than upstream's frequency", + ) + self.assertListEqual( + to_list( + Stream(lambda: map(slow_identity, src)).agroup( + size=100, + interval=datetime.timedelta(seconds=slow_identity_duration / 1000), + by=asyncify(lambda _: None), + ), + itype=itype, + ), + list(map(lambda e: [e], src)), + msg="`agroup` with `by` argument should not yield empty groups even though `interval` is smaller than upstream's frequency", + ) + self.assertListEqual( + to_list( + Stream(lambda: map(slow_identity, src)).agroup( + size=100, + interval=datetime.timedelta( + seconds=2 * slow_identity_duration * 0.99 + ), + ), + itype=itype, + ), + list(map(lambda e: [e, e + 1], even_src)), + msg="`agroup` should yield upstream elements in a two-element group if `interval` is less than twice the upstream yield period", + ) + + self.assertListEqual( + anext_or_next(bi_iterable_to_iter(Stream(src).agroup(), itype=itype)), + list(src), + msg="`agroup` without arguments should group the elements all together", + ) + + # test agroupby + groupby_stream_iter: Union[ + Iterator[Tuple[int, List[int]]], AsyncIterator[Tuple[int, List[int]]] + ] = bi_iterable_to_iter( + Stream(src).agroupby(asyncify(lambda n: n % 2), size=2), itype=itype + ) + self.assertListEqual( + [anext_or_next(groupby_stream_iter), anext_or_next(groupby_stream_iter)], + [(0, [0, 2]), (1, [1, 3])], + msg="`agroupby` must cogroup elements.", + ) + + # test by + stream_iter = bi_iterable_to_iter( + Stream(src).agroup(size=2, by=asyncify(lambda n: n % 2)), itype=itype + ) + self.assertListEqual( + [anext_or_next(stream_iter), anext_or_next(stream_iter)], + [[0, 2], [1, 3]], + msg="`agroup` called with a `by` function must cogroup elements.", + ) + + self.assertListEqual( + anext_or_next( + bi_iterable_to_iter( + Stream(src_raising_at_exhaustion).agroup( + size=10, by=asyncify(lambda n: n % 4 != 0) + ), + itype=itype, + ), + ), + [1, 2, 3, 5, 6, 7, 9, 10, 11, 13], + msg="`agroup` called with a `by` function and a `size` should yield the first batch becoming full.", + ) + + self.assertListEqual( + to_list(Stream(src).agroup(by=asyncify(lambda n: n % 2)), itype=itype), + [list(range(0, N, 2)), list(range(1, N, 2))], + msg="`agroup` called with a `by` function and an infinite size must cogroup elements and yield groups starting with the group containing the oldest element.", + ) + + self.assertListEqual( + to_list( + Stream(range(10)).agroup(by=asyncify(lambda n: n % 4 == 0)), itype=itype + ), + [[0, 4, 8], [1, 2, 3, 5, 6, 7, 9]], + msg="`agroup` called with a `by` function and reaching exhaustion must cogroup elements and yield incomplete groups starting with the group containing the oldest element, even though it's not the largest.", + ) + + stream_iter = bi_iterable_to_iter( + Stream(src_raising_at_exhaustion).agroup(by=asyncify(lambda n: n % 2)), + itype=itype, + ) + 
self.assertListEqual( + [anext_or_next(stream_iter), anext_or_next(stream_iter)], + [list(range(0, N, 2)), list(range(1, N, 2))], + msg="`agroup` called with a `by` function and encountering an exception must cogroup elements and yield incomplete groups starting with the group containing the oldest element.", + ) + with self.assertRaises( + TestError, + msg="`agroup` called with a `by` function and encountering an exception must raise it after all groups have been yielded", + ): + anext_or_next(stream_iter) + + # test seconds + by + self.assertListEqual( + to_list( + Stream(lambda: map(slow_identity, range(10))).agroup( + interval=datetime.timedelta(seconds=slow_identity_duration * 2.9), + by=asyncify(lambda n: n % 4 == 0), + ), + itype=itype, + ), + [[1, 2], [0, 4], [3, 5, 6, 7], [8], [9]], + msg="`agroup` called with a `by` function must cogroup elements and yield the largest groups when `seconds` is reached even though it's not the oldest.", + ) + + stream_iter = bi_iterable_to_iter( + Stream(src).agroup( + size=3, + by=asyncify( + lambda n: throw(stopiteration_for_iter_type(itype)) if n == 2 else n + ), + ), + itype=itype, + ) + self.assertListEqual( + [anext_or_next(stream_iter), anext_or_next(stream_iter)], + [[0], [1]], + msg="`agroup` should yield incomplete groups when `by` raises", + ) + with self.assertRaisesRegex( + (WrappedError, RuntimeError), + stopiteration_for_iter_type(itype).__name__, + msg="`agroup` should raise and skip `elem` if `by(elem)` raises", + ): + anext_or_next(stream_iter) + self.assertListEqual( + anext_or_next(stream_iter), + [3], + msg="`agroup` should continue yielding after `by`'s exception has been raised.", + ) + + @parameterized.expand(ITERABLE_TYPES) + def test_throttle(self, itype: IterableType) -> None: # behavior with invalid arguments with self.assertRaisesRegex( ValueError, r"`per` must be None or a positive timedelta but got datetime\.timedelta\(0\)", msg="`throttle` should raise error when called with non-positive `per`.", ): - list(Stream([1]).throttle(1, per=datetime.timedelta(microseconds=0))) + to_list( + Stream([1]).throttle(1, per=datetime.timedelta(microseconds=0)), + itype=itype, + ) with self.assertRaisesRegex( ValueError, "`count` must be >= 1 but got 0", msg="`throttle` should raise error when called with `count` < 1.", ): - list(Stream([1]).throttle(0, per=datetime.timedelta(seconds=1))) + to_list( + Stream([1]).throttle(0, per=datetime.timedelta(seconds=1)), itype=itype + ) # test interval interval_seconds = 0.3 @@ -1228,8 +1743,7 @@ def slow_first_elem(elem: int): ], ): with self.subTest(stream=stream): - duration, res = timestream(stream) - + duration, res = timestream(stream, itype=itype) self.assertListEqual( res, expected_elems, @@ -1246,11 +1760,12 @@ def slow_first_elem(elem: int): ) self.assertEqual( - next( - iter( + anext_or_next( + bi_iterable_to_iter( Stream(src) .throttle(1, per=datetime.timedelta(seconds=0.2)) - .throttle(1, per=datetime.timedelta(seconds=0.1)) + .throttle(1, per=datetime.timedelta(seconds=0.1)), + itype=itype, ) ), 0, @@ -1280,7 +1795,7 @@ def slow_first_elem(elem: int): ], ): with self.subTest(N=N, stream=stream): - duration, res = timestream(stream) + duration, res = timestream(stream, itype=itype) self.assertListEqual( res, expected_elems, @@ -1306,7 +1821,7 @@ def slow_first_elem(elem: int): .throttle(1, per=datetime.timedelta(seconds=0.2)), ]: with self.subTest(stream=stream): - duration, _ = timestream(stream) + duration, _ = timestream(stream, itype=itype) self.assertAlmostEqual( duration, 
expected_duration, @@ -1368,29 +1883,35 @@ def slow_first_elem(elem: int): msg="`throttle` must support legacy kwargs", ) - def test_distinct(self) -> None: + @parameterized.expand(ITERABLE_TYPES) + def test_distinct(self, itype: IterableType) -> None: self.assertListEqual( - list(Stream("abbcaabcccddd").distinct()), + to_list(Stream("abbcaabcccddd").distinct(), itype=itype), list("abcd"), msg="`distinct` should yield distinct elements", ) self.assertListEqual( - list(Stream("aabbcccaabbcccc").distinct(consecutive_only=True)), + to_list( + Stream("aabbcccaabbcccc").distinct(consecutive_only=True), itype=itype + ), list("abcabc"), msg="`distinct` should only remove the duplicates that are consecutive if `consecutive_only=True`", ) for consecutive_only in [True, False]: self.assertListEqual( - list( + to_list( Stream(["foo", "bar", "a", "b"]).distinct( len, consecutive_only=consecutive_only - ) + ), + itype=itype, ), ["foo", "a"], msg="`distinct` should yield the first encountered elem among duplicates", ) self.assertListEqual( - list(Stream([]).distinct(consecutive_only=consecutive_only)), + to_list( + Stream([]).distinct(consecutive_only=consecutive_only), itype=itype + ), [], msg="`distinct` should yield zero elements on empty stream", ) @@ -1399,11 +1920,51 @@ def test_distinct(self) -> None: "unhashable type: 'list'", msg="`distinct` should raise for non-hashable elements", ): - list(Stream([[1]]).distinct()) + to_list(Stream([[1]]).distinct(), itype=itype) - def test_catch(self) -> None: + @parameterized.expand(ITERABLE_TYPES) + def test_adistinct(self, itype: IterableType) -> None: + self.assertListEqual( + to_list(Stream("abbcaabcccddd").adistinct(), itype=itype), + list("abcd"), + msg="`adistinct` should yield distinct elements", + ) self.assertListEqual( - list(Stream(src).catch(finally_raise=True)), + to_list( + Stream("aabbcccaabbcccc").adistinct(consecutive_only=True), itype=itype + ), + list("abcabc"), + msg="`adistinct` should only remove the duplicates that are consecutive if `consecutive_only=True`", + ) + for consecutive_only in [True, False]: + self.assertListEqual( + to_list( + Stream(["foo", "bar", "a", "b"]).adistinct( + asyncify(len), consecutive_only=consecutive_only + ), + itype=itype, + ), + ["foo", "a"], + msg="`adistinct` should yield the first encountered elem among duplicates", + ) + self.assertListEqual( + to_list( + Stream([]).adistinct(consecutive_only=consecutive_only), itype=itype + ), + [], + msg="`adistinct` should yield zero elements on empty stream", + ) + with self.assertRaisesRegex( + TypeError, + "unhashable type: 'list'", + msg="`adistinct` should raise for non-hashable elements", + ): + to_list(Stream([[1]]).adistinct(), itype=itype) + + @parameterized.expand(ITERABLE_TYPES) + def test_catch(self, itype: IterableType) -> None: + self.assertListEqual( + to_list(Stream(src).catch(finally_raise=True), itype=itype), list(src), msg="`catch` should yield elements in exception-less scenarios", ) @@ -1431,12 +1992,12 @@ def f(i): safe_src = list(src) del safe_src[3] self.assertListEqual( - list(stream.catch(ZeroDivisionError)), + to_list(stream.catch(ZeroDivisionError), itype=itype), list(map(f, safe_src)), msg="If the exception type matches the `error_type`, then the impacted element should be ignored.", ) self.assertListEqual( - list(stream.catch()), + to_list(stream.catch(), itype=itype), list(map(f, safe_src)), msg="If the predicate is not specified, then all exceptions should be caught.", ) @@ -1445,7 +2006,7 @@ def f(i): ZeroDivisionError, msg="If a 
non-caught exception type occurs, then it should be raised.", ): - list(stream.catch(TestError)) + to_list(stream.catch(TestError), itype=itype) first_value = 1 second_value = 2 @@ -1465,9 +2026,11 @@ def f(i): erroring_stream.catch(finally_raise=True), erroring_stream.catch(finally_raise=True), ]: - erroring_stream_iterator = iter(caught_erroring_stream) + erroring_stream_iterator = bi_iterable_to_iter( + caught_erroring_stream, itype=itype + ) self.assertEqual( - next(erroring_stream_iterator), + anext_or_next(erroring_stream_iterator), first_value, msg="`catch` should yield the first non-exception-throwing element.", ) @@ -1476,13 +2039,14 @@ def f(i): TestError, msg="`catch` should raise the first error encountered when `finally_raise` is True.", ): - for _ in erroring_stream_iterator: + while True: + anext_or_next(erroring_stream_iterator) n_yields += 1 with self.assertRaises( - StopIteration, + stopiteration_for_iter_type(type(erroring_stream_iterator)), msg="`catch` with `finally_raise`=True should finally raise StopIteration to avoid infinite recursion if there is another catch downstream.", ): - next(erroring_stream_iterator) + anext_or_next(erroring_stream_iterator) self.assertEqual( n_yields, 3, @@ -1493,61 +2057,65 @@ def f(i): map(lambda _: throw(TestError), range(2000)) ).catch(TestError) self.assertListEqual( - list(only_caught_errors_stream), + to_list(only_caught_errors_stream, itype=itype), [], msg="When upstream raises exceptions without yielding any element, listing the stream must return an empty list, without recursion issue.", ) with self.assertRaises( - StopIteration, + stopiteration_for_iter_type(itype), msg="When upstream raises exceptions without yielding any element, then the first call to `next` on a stream catching all errors should raise StopIteration.", ): - next(iter(only_caught_errors_stream)) + anext_or_next(bi_iterable_to_iter(only_caught_errors_stream, itype=itype)) - iterator = iter( + iterator = bi_iterable_to_iter( Stream(map(throw, [TestError, ValueError])) .catch(ValueError, finally_raise=True) - .catch(TestError, finally_raise=True) + .catch(TestError, finally_raise=True), + itype=itype, ) with self.assertRaises( ValueError, msg="With 2 chained `catch`s with `finally_raise=True`, the error caught by the first `catch` is finally raised first (even though it was raised second)...", ): - next(iterator) + anext_or_next(iterator) with self.assertRaises( TestError, msg="... and then the error caught by the second `catch` is raised...", ): - next(iterator) + anext_or_next(iterator) with self.assertRaises( - StopIteration, + stopiteration_for_iter_type(type(iterator)), msg="... 
and a StopIteration is raised next.", ): - next(iterator) + anext_or_next(iterator) with self.assertRaises( TypeError, msg="`catch` does not catch if `when` not satisfied", ): - list( + to_list( Stream(map(throw, [ValueError, TypeError])).catch( when=lambda exception: "ValueError" in repr(exception) - ) + ), + itype=itype, ) self.assertListEqual( - list( + to_list( Stream(map(lambda n: 1 / n, [0, 1, 2, 4])).catch( ZeroDivisionError, replacement=float("inf") - ) + ), + itype=itype, ), [float("inf"), 1, 0.5, 0.25], msg="`catch` should be able to yield a non-None replacement", ) self.assertListEqual( - list( + to_list( Stream(map(lambda n: 1 / n, [0, 1, 2, 4])).catch( ZeroDivisionError, replacement=cast(float, None) - ) + ), + itype=itype, ), [None, 1, 0.5, 0.25], msg="`catch` should be able to yield a None replacement", @@ -1556,7 +2124,7 @@ def f(i): errors_counter: Counter[Type[Exception]] = Counter() self.assertListEqual( - list( + to_list( Stream( map( lambda n: 1 / n, # potential ZeroDivisionError @@ -1571,7 +2139,8 @@ def f(i): ).catch( (ValueError, TestError, ZeroDivisionError), when=lambda err: errors_counter.update([type(err)]) is None, - ) + ), + itype=itype, ), list(map(lambda n: 1 / n, range(2, 10, 2))), msg="`catch` should accept multiple types", @@ -1589,15 +2158,16 @@ def f(i): # Stream(src).catch() # type: ignore self.assertEqual( - list(Stream(map(int, "foo")).catch(replacement=0)), + to_list(Stream(map(int, "foo")).catch(replacement=0), itype=itype), [0] * len("foo"), msg="`catch` must catch all errors when no error type provided", ) self.assertEqual( - list( + to_list( Stream(map(int, "foo")).catch( (None, None, ValueError, None), replacement=0 - ) + ), + itype=itype, ), [0] * len("foo"), msg="`catch` must catch the provided non-None error types", @@ -1606,14 +2176,239 @@ def f(i): ValueError, msg="`catch` must be noop if error type is None", ): - list(Stream(map(int, "foo")).catch(None)) + to_list(Stream(map(int, "foo")).catch(None), itype=itype) with self.assertRaises( ValueError, msg="`catch` must be noop if error types are None", ): - list(Stream(map(int, "foo")).catch((None, None, None))) + to_list(Stream(map(int, "foo")).catch((None, None, None)), itype=itype) + + @parameterized.expand(ITERABLE_TYPES) + def test_acatch(self, itype: IterableType) -> None: + self.assertListEqual( + to_list(Stream(src).acatch(finally_raise=True), itype=itype), + list(src), + msg="`acatch` should yield elements in exception-less scenarios", + ) + + with self.assertRaisesRegex( + TypeError, + "`iterator` must be an AsyncIterator but got a ", + msg="`afunctions.acatch` function should raise TypeError when first argument is not an AsyncIterator", + ): + from streamable import afunctions + + afunctions.acatch(cast(AsyncIterator, [3, 4]), Exception) + + with self.assertRaisesRegex( + TypeError, + "`errors` must be None, or a subclass of `Exception`, or an iterable of optional subclasses of `Exception`, but got ", + msg="`acatch` should raise TypeError when first argument is not None or Type[Exception], or Iterable[Optional[Type[Exception]]]", + ): + Stream(src).acatch(1) # type: ignore + + def f(i): + return i / (3 - i) + + stream = Stream(lambda: map(f, src)) + safe_src = list(src) + del safe_src[3] + self.assertListEqual( + to_list(stream.acatch(ZeroDivisionError), itype=itype), + list(map(f, safe_src)), + msg="If the exception type matches the `error_type`, then the impacted element should be ignored.", + ) + self.assertListEqual( + to_list(stream.acatch(), itype=itype), + list(map(f, 
safe_src)), + msg="If the predicate is not specified, then all exceptions should be caught.", + ) + + with self.assertRaises( + ZeroDivisionError, + msg="If a non-caught exception type occurs, then it should be raised.", + ): + to_list(stream.acatch(TestError), itype=itype) + + first_value = 1 + second_value = 2 + third_value = 3 + functions = [ + lambda: throw(TestError), + lambda: throw(TypeError), + lambda: first_value, + lambda: second_value, + lambda: throw(ValueError), + lambda: third_value, + lambda: throw(ZeroDivisionError), + ] + + erroring_stream: Stream[int] = Stream(lambda: map(lambda f: f(), functions)) + for caught_erroring_stream in [ + erroring_stream.acatch(finally_raise=True), + erroring_stream.acatch(finally_raise=True), + ]: + erroring_stream_iterator = bi_iterable_to_iter( + caught_erroring_stream, itype=itype + ) + self.assertEqual( + anext_or_next(erroring_stream_iterator), + first_value, + msg="`acatch` should yield the first non-exception-throwing element.", + ) + n_yields = 1 + with self.assertRaises( + TestError, + msg="`acatch` should raise the first error encountered when `finally_raise` is True.", + ): + while True: + anext_or_next(erroring_stream_iterator) + n_yields += 1 + with self.assertRaises( + stopiteration_for_iter_type(type(erroring_stream_iterator)), + msg="`acatch` with `finally_raise`=True should finally raise StopIteration to avoid infinite recursion if there is another catch downstream.", + ): + anext_or_next(erroring_stream_iterator) + self.assertEqual( + n_yields, + 3, + msg="3 elements should have been yielded between caught exceptions.", + ) + + only_caught_errors_stream = Stream( + map(lambda _: throw(TestError), range(2000)) + ).acatch(TestError) + self.assertListEqual( + to_list(only_caught_errors_stream, itype=itype), + [], + msg="When upstream raises exceptions without yielding any element, listing the stream must return an empty list, without recursion issue.", + ) + with self.assertRaises( + stopiteration_for_iter_type(itype), + msg="When upstream raises exceptions without yielding any element, then the first call to `next` on a stream catching all errors should raise StopIteration.", + ): + anext_or_next(bi_iterable_to_iter(only_caught_errors_stream, itype=itype)) - def test_observe(self) -> None: + iterator = bi_iterable_to_iter( + Stream(map(throw, [TestError, ValueError])) + .acatch(ValueError, finally_raise=True) + .acatch(TestError, finally_raise=True), + itype=itype, + ) + with self.assertRaises( + ValueError, + msg="With 2 chained `acatch`s with `finally_raise=True`, the error caught by the first `acatch` is finally raised first (even though it was raised second)...", + ): + anext_or_next(iterator) + with self.assertRaises( + TestError, + msg="... and then the error caught by the second `acatch` is raised...", + ): + anext_or_next(iterator) + with self.assertRaises( + stopiteration_for_iter_type(type(iterator)), + msg="... 
and a StopIteration is raised next.", + ): + anext_or_next(iterator) + + with self.assertRaises( + TypeError, + msg="`acatch` does not catch if `when` not satisfied", + ): + to_list( + Stream(map(throw, [ValueError, TypeError])).acatch( + when=asyncify(lambda exception: "ValueError" in repr(exception)) + ), + itype=itype, + ) + + self.assertListEqual( + to_list( + Stream(map(lambda n: 1 / n, [0, 1, 2, 4])).acatch( + ZeroDivisionError, replacement=float("inf") + ), + itype=itype, + ), + [float("inf"), 1, 0.5, 0.25], + msg="`acatch` should be able to yield a non-None replacement", + ) + self.assertListEqual( + to_list( + Stream(map(lambda n: 1 / n, [0, 1, 2, 4])).acatch( + ZeroDivisionError, replacement=cast(float, None) + ), + itype=itype, + ), + [None, 1, 0.5, 0.25], + msg="`acatch` should be able to yield a None replacement", + ) + + errors_counter: Counter[Type[Exception]] = Counter() + + self.assertListEqual( + to_list( + Stream( + map( + lambda n: 1 / n, # potential ZeroDivisionError + map( + throw_for_odd_func(TestError), # potential TestError + map( + int, # potential ValueError + "01234foo56789", + ), + ), + ) + ).acatch( + (ValueError, TestError, ZeroDivisionError), + when=asyncify( + lambda err: errors_counter.update([type(err)]) is None + ), + ), + itype=itype, + ), + list(map(lambda n: 1 / n, range(2, 10, 2))), + msg="`acatch` should accept multiple types", + ) + self.assertDictEqual( + errors_counter, + {TestError: 5, ValueError: 3, ZeroDivisionError: 1}, + msg="`acatch` should accept multiple types", + ) + + # with self.assertRaises( + # TypeError, + # msg="`acatch` without any error type must raise", + # ): + # Stream(src).acatch() # type: ignore + + self.assertEqual( + to_list(Stream(map(int, "foo")).acatch(replacement=0), itype=itype), + [0] * len("foo"), + msg="`acatch` must catch all errors when no error type provided", + ) + self.assertEqual( + to_list( + Stream(map(int, "foo")).acatch( + (None, None, ValueError, None), replacement=0 + ), + itype=itype, + ), + [0] * len("foo"), + msg="`acatch` must catch the provided non-None error types", + ) + with self.assertRaises( + ValueError, + msg="`acatch` must be noop if error type is None", + ): + to_list(Stream(map(int, "foo")).acatch(None), itype=itype) + with self.assertRaises( + ValueError, + msg="`acatch` must be noop if error types are None", + ): + to_list(Stream(map(int, "foo")).acatch((None, None, None)), itype=itype) + + @parameterized.expand(ITERABLE_TYPES) + def test_observe(self, itype: IterableType) -> None: value_error_raising_stream: Stream[List[int]] = ( Stream("123--678") .throttle(10, per=datetime.timedelta(seconds=1)) @@ -1625,7 +2420,7 @@ def test_observe(self) -> None: ) self.assertListEqual( - list(value_error_raising_stream.catch(ValueError)), + to_list(value_error_raising_stream.catch(ValueError), itype=itype), [[1, 2], [3], [6, 7], [8]], msg="This can break due to `group`/`map`/`catch`, check other breaking tests to determine quickly if it's an issue with `observe`.", ) @@ -1634,10 +2429,11 @@ def test_observe(self) -> None: ValueError, msg="`observe` should forward-raise exceptions", ): - list(value_error_raising_stream) + to_list(value_error_raising_stream, itype=itype) def test_is_iterable(self) -> None: self.assertIsInstance(Stream(src), Iterable) + self.assertIsInstance(Stream(src), AsyncIterable) def test_count(self) -> None: l: List[int] = [] @@ -1656,6 +2452,23 @@ def effect(x: int) -> None: l, list(src), msg="`count` should iterate over the entire stream." 
) + def test_acount(self) -> None: + l: List[int] = [] + + def effect(x: int) -> None: + nonlocal l + l.append(x) + + stream = Stream(lambda: map(effect, src)) + self.assertEqual( + asyncio.run(stream.acount()), + N, + msg="`acount` should return the count of elements.", + ) + self.assertListEqual( + l, list(src), msg="`acount` should iterate over the entire stream." + ) + def test_call(self) -> None: l: List[int] = [] stream = Stream(src).map(l.append) @@ -1670,51 +2483,62 @@ def test_call(self) -> None: msg="`__call__` should exhaust the stream.", ) - def test_multiple_iterations(self) -> None: + def test_await(self) -> None: + l: List[int] = [] + stream = Stream(src).map(l.append) + self.assertIs( + asyncio.run(awaitable_to_coroutine(stream)), + stream, + msg="`await` should return the stream.", + ) + self.assertListEqual( + l, + list(src), + msg="`await` should exhaust the stream.", + ) + + @parameterized.expand(ITERABLE_TYPES) + def test_multiple_iterations(self, itype: IterableType) -> None: stream = Stream(src) for _ in range(3): self.assertListEqual( - list(stream), + to_list(stream, itype=itype), list(src), msg="The first iteration over a stream should yield the same elements as any subsequent iteration on the same stream, even if it is based on a `source` returning an iterator that only supports 1 iteration.", ) @parameterized.expand( - [ - [1], - [100], - ] + [(concurrency, itype) for concurrency in (1, 100) for itype in ITERABLE_TYPES] ) - def test_amap(self, concurrency) -> None: + def test_amap(self, concurrency, itype) -> None: self.assertListEqual( - list( + to_list( Stream(src).amap( async_randomly_slowed(async_square), concurrency=concurrency - ) + ), + itype=itype, ), list(map(square, src)), msg="At any concurrency the `amap` method should act as the builtin map function, transforming elements while preserving input elements order.", ) - stream = Stream(src).amap(identity) # type: ignore + stream = Stream(src).amap(identity, concurrency=concurrency) # type: ignore with self.assertRaisesRegex( TypeError, - r"`transformation` must be an async function i\.e\. a function returning a Coroutine but it returned a ", + r"must be an async function i\.e\. a function returning a Coroutine but it returned a ", msg="`amap` should raise a TypeError if a non async function is passed to it.", ): - next(iter(stream)) + anext_or_next(bi_iterable_to_iter(stream, itype=itype)) @parameterized.expand( - [ - [1], - [100], - ] + [(concurrency, itype) for concurrency in (1, 100) for itype in ITERABLE_TYPES] ) - def test_aforeach(self, concurrency) -> None: + def test_aforeach(self, concurrency, itype) -> None: self.assertListEqual( - list( + to_list( Stream(src).aforeach( async_randomly_slowed(async_square), concurrency=concurrency - ) + ), + itype=itype, ), list(src), msg="At any concurrency the `aforeach` method must preserve input elements order.", ) @@ -1725,9 +2549,10 @@ def test_aforeach(self, concurrency) -> None: r"`transformation` must be an async function i\.e\. 
a function returning a Coroutine but it returned a ", msg="`aforeach` should raise a TypeError if a non async function is passed to it.", ): - next(iter(stream)) + anext_or_next(bi_iterable_to_iter(stream, itype=itype)) - def test_pipe(self) -> None: + @parameterized.expand(ITERABLE_TYPES) + def test_pipe(self, itype: IterableType) -> None: def func( stream: Stream, *ints: int, **strings: str ) -> Tuple[Stream, Tuple[int, ...], Dict[str, str]]: @@ -1744,8 +2569,8 @@ def func( ) self.assertListEqual( - stream.pipe(list), - list(stream), + stream.pipe(to_list, itype=itype), + to_list(stream, itype=itype), msg="`pipe` should be ok without args and kwargs.", ) @@ -1753,18 +2578,27 @@ def test_eq(self) -> None: stream = ( Stream(src) .catch((TypeError, ValueError), replacement=2, when=identity) + .acatch((TypeError, ValueError), replacement=2, when=async_identity) .distinct(key=identity) + .adistinct(key=async_identity) .filter(identity) + .afilter(async_identity) .foreach(identity, concurrency=3) .aforeach(async_identity, concurrency=3) .group(3, by=bool) .flatten(concurrency=3) + .agroup(3, by=async_identity) + .map(sync_to_async_iter) + .aflatten(concurrency=3) .groupby(bool) + .agroupby(async_identity) .map(identity, via="process") .amap(async_identity) .observe("foo") .skip(3) + .askip(3) .truncate(4) + .atruncate(4) .throttle(1, per=datetime.timedelta(seconds=1)) ) @@ -1772,64 +2606,100 @@ def test_eq(self) -> None: stream, Stream(src) .catch((TypeError, ValueError), replacement=2, when=identity) + .acatch((TypeError, ValueError), replacement=2, when=async_identity) .distinct(key=identity) + .adistinct(key=async_identity) .filter(identity) + .afilter(async_identity) .foreach(identity, concurrency=3) .aforeach(async_identity, concurrency=3) .group(3, by=bool) .flatten(concurrency=3) + .agroup(3, by=async_identity) + .map(sync_to_async_iter) + .aflatten(concurrency=3) .groupby(bool) + .agroupby(async_identity) .map(identity, via="process") .amap(async_identity) .observe("foo") .skip(3) + .askip(3) .truncate(4) + .atruncate(4) .throttle(1, per=datetime.timedelta(seconds=1)), ) self.assertNotEqual( stream, Stream(list(src)) # not same source .catch((TypeError, ValueError), replacement=2, when=identity) + .acatch((TypeError, ValueError), replacement=2, when=async_identity) .distinct(key=identity) + .adistinct(key=async_identity) .filter(identity) + .afilter(async_identity) .foreach(identity, concurrency=3) .aforeach(async_identity, concurrency=3) .group(3, by=bool) .flatten(concurrency=3) + .agroup(3, by=async_identity) + .map(sync_to_async_iter) + .aflatten(concurrency=3) .groupby(bool) + .agroupby(async_identity) .map(identity, via="process") .amap(async_identity) .observe("foo") .skip(3) + .askip(3) .truncate(4) + .atruncate(4) .throttle(1, per=datetime.timedelta(seconds=1)), ) self.assertNotEqual( stream, Stream(src) .catch((TypeError, ValueError), replacement=2, when=identity) + .acatch((TypeError, ValueError), replacement=2, when=async_identity) .distinct(key=identity) + .adistinct(key=async_identity) .filter(identity) + .afilter(async_identity) .foreach(identity, concurrency=3) .aforeach(async_identity, concurrency=3) .group(3, by=bool) .flatten(concurrency=3) + .agroup(3, by=async_identity) + .map(sync_to_async_iter) + .aflatten(concurrency=3) .groupby(bool) + .agroupby(async_identity) .map(identity, via="process") .amap(async_identity) .observe("foo") .skip(3) + .askip(3) .truncate(4) + .atruncate(4) .throttle(1, per=datetime.timedelta(seconds=2)), # not the same interval ) - def 
test_ref_cycles(self) -> None: + @parameterized.expand(ITERABLE_TYPES) + def test_ref_cycles(self, itype: IterableType) -> None: + async def async_int(o: Any) -> int: + return int(o) + stream = ( - Stream(map(int, "123_5")).group(1).groupby(len).catch(finally_raise=True) + Stream("123_5") + .amap(async_int) + .map(str) + .group(1) + .groupby(len) + .catch(finally_raise=True) ) exception: Exception try: - list(stream) + to_list(stream, itype=itype) except ValueError as e: exception = e self.assertIsInstance( @@ -1837,19 +2707,21 @@ def test_ref_cycles(self) -> None: ValueError, msg="`finally_raise` must be respected", ) - frames: Iterator[FrameType] = map( - itemgetter(0), traceback.walk_tb(exception.__traceback__) - ) - next(frames) self.assertFalse( [ (var, val) - for frame in frames + # go through the frames of the exception's traceback + for frame, _ in traceback.walk_tb(exception.__traceback__) + # skipping the current frame + if frame is not cast(TracebackType, exception.__traceback__).tb_frame + # go through the locals captured in that frame for var, val in frame.f_locals.items() + # check if one of them is an exception if isinstance(val, Exception) - and frame in map(itemgetter(0), traceback.walk_tb(val.__traceback__)) + # check if it is captured in its own traceback + and frame is cast(TracebackType, val.__traceback__).tb_frame ], - msg=f"the exception's traceback should not contain an exception that captures itself in its own traceback", + msg=f"the exception's traceback should not contain an exception captured in its own traceback", ) def test_on_queue_in_thread(self) -> None: diff --git a/tests/test_visitor.py b/tests/test_visitor.py index d13301df..ac5cce7c 100644 --- a/tests/test_visitor.py +++ b/tests/test_visitor.py @@ -2,8 +2,16 @@ from typing import cast from streamable.stream import ( + ACatchStream, + ADistinctStream, + AFilterStream, + AFlattenStream, AForeachStream, + AGroupbyStream, + AGroupStream, AMapStream, + ASkipStream, + ATruncateStream, CatchStream, DistinctStream, FilterStream, @@ -29,19 +37,27 @@ def visit_stream(self, stream: Stream) -> None: visitor = ConcreteVisitor() visitor.visit_catch_stream(cast(CatchStream, ...)) + visitor.visit_acatch_stream(cast(ACatchStream, ...)) visitor.visit_distinct_stream(cast(DistinctStream, ...)) + visitor.visit_adistinct_stream(cast(ADistinctStream, ...)) visitor.visit_filter_stream(cast(FilterStream, ...)) + visitor.visit_afilter_stream(cast(AFilterStream, ...)) visitor.visit_flatten_stream(cast(FlattenStream, ...)) + visitor.visit_aflatten_stream(cast(AFlattenStream, ...)) visitor.visit_foreach_stream(cast(ForeachStream, ...)) visitor.visit_aforeach_stream(cast(AForeachStream, ...)) visitor.visit_group_stream(cast(GroupStream, ...)) + visitor.visit_agroup_stream(cast(AGroupStream, ...)) visitor.visit_groupby_stream(cast(GroupbyStream, ...)) + visitor.visit_agroupby_stream(cast(AGroupbyStream, ...)) visitor.visit_map_stream(cast(MapStream, ...)) visitor.visit_amap_stream(cast(AMapStream, ...)) visitor.visit_observe_stream(cast(ObserveStream, ...)) visitor.visit_skip_stream(cast(SkipStream, ...)) + visitor.visit_askip_stream(cast(ASkipStream, ...)) visitor.visit_throttle_stream(cast(ThrottleStream, ...)) visitor.visit_truncate_stream(cast(TruncateStream, ...)) + visitor.visit_atruncate_stream(cast(ATruncateStream, ...)) visitor.visit_stream(cast(Stream, ...)) def test_depth_visitor_example(self): diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 00000000..94d6efa3 --- /dev/null +++ b/tests/utils.py @@ 
-0,0 +1,215 @@ +import asyncio +import random +import time +import timeit +from typing import ( + Any, + AsyncIterable, + AsyncIterator, + Callable, + Coroutine, + Iterable, + Iterator, + List, + Set, + Tuple, + Type, + TypeVar, + Union, +) + +from streamable.stream import Stream +from streamable.util.iterabletools import BiIterable + +T = TypeVar("T") +R = TypeVar("R") + +IterableType = Union[Type[Iterable], Type[AsyncIterable]] +ITERABLE_TYPES: Tuple[IterableType, ...] = (Iterable, AsyncIterable) + +TEST_EVENT_LOOP = asyncio.new_event_loop() + + +async def _aiter_to_list(aiterable: AsyncIterable[T]) -> List[T]: + return [elem async for elem in aiterable] + + +def aiterable_to_list(aiterable: AsyncIterable[T]) -> List[T]: + return TEST_EVENT_LOOP.run_until_complete(_aiter_to_list(aiterable)) + + +async def _aiter_to_set(aiterable: AsyncIterable[T]) -> Set[T]: + return {elem async for elem in aiterable} + + +def aiterable_to_set(aiterable: AsyncIterable[T]) -> Set[T]: + return TEST_EVENT_LOOP.run_until_complete(_aiter_to_set(aiterable)) + + +def stopiteration_for_iter_type(itype: IterableType) -> Type[Exception]: + if issubclass(itype, AsyncIterable): + return StopAsyncIteration + return StopIteration + + +def to_list(stream: Stream[T], itype: IterableType) -> List[T]: + assert isinstance(stream, Stream) + if itype is AsyncIterable: + return aiterable_to_list(stream) + else: + return list(stream) + + +def to_set(stream: Stream[T], itype: IterableType) -> Set[T]: + assert isinstance(stream, Stream) + if itype is AsyncIterable: + return aiterable_to_set(stream) + else: + return set(stream) + + +def bi_iterable_to_iter( + iterable: Union[BiIterable[T], Stream[T]], itype: IterableType +) -> Union[Iterator[T], AsyncIterator[T]]: + if itype is AsyncIterable: + return iterable.__aiter__() + else: + return iter(iterable) + + +def anext_or_next(it: Union[Iterator[T], AsyncIterator[T]]) -> T: + if isinstance(it, AsyncIterator): + return TEST_EVENT_LOOP.run_until_complete(it.__anext__()) + else: + return next(it) + + +def alist_or_list(iterable: Union[Iterable[T], AsyncIterable[T]]) -> List[T]: + if isinstance(iterable, AsyncIterable): + return aiterable_to_list(iterable) + else: + return list(iterable) + + +def timestream( + stream: Stream[T], times: int = 1, itype: IterableType = Iterable +) -> Tuple[float, List[T]]: + res: List[T] = [] + + def iterate(): + nonlocal res + res = to_list(stream, itype=itype) + + return timeit.timeit(iterate, number=times) / times, res + + +def identity_sleep(seconds: float) -> float: + time.sleep(seconds) + return seconds + + +async def async_identity_sleep(seconds: float) -> float: + await asyncio.sleep(seconds) + return seconds + + +# simulates an I/O-bound function +slow_identity_duration = 0.05 + + +def slow_identity(x: T) -> T: + time.sleep(slow_identity_duration) + return x + + +async def async_slow_identity(x: T) -> T: + await asyncio.sleep(slow_identity_duration) + return x + + +def identity(x: T) -> T: + return x + + +# fmt: off +async def async_identity(x: T) -> T: return x +# fmt: on + + +def square(x): + return x**2 + + +async def async_square(x): + return x**2 + + +def throw(exc: Type[Exception]): + raise exc() + + +def throw_func(exc: Type[Exception]) -> Callable[[T], T]: + return lambda _: throw(exc) + + +def async_throw_func(exc: Type[Exception]) -> Callable[[T], Coroutine[Any, Any, T]]: + async def f(_: T) -> T: + raise exc + + return f + + +def throw_for_odd_func(exc): + return lambda i: throw(exc) if i % 2 == 1 else i + + +def 
async_throw_for_odd_func(exc): + async def f(i): + return throw(exc) if i % 2 == 1 else i + + return f + + +class TestError(Exception): + pass + + +DELTA_RATE = 0.4 +# size of the test collections +N = 256 + +src = range(N) + +even_src = range(0, N, 2) + + +def randomly_slowed( + func: Callable[[T], R], min_sleep: float = 0.001, max_sleep: float = 0.05 +) -> Callable[[T], R]: + def wrap(x: T) -> R: + time.sleep(min_sleep + random.random() * (max_sleep - min_sleep)) + return func(x) + + return wrap + + +def async_randomly_slowed( + async_func: Callable[[T], Coroutine[Any, Any, R]], + min_sleep: float = 0.001, + max_sleep: float = 0.05, +) -> Callable[[T], Coroutine[Any, Any, R]]: + async def wrap(x: T) -> R: + await asyncio.sleep(min_sleep + random.random() * (max_sleep - min_sleep)) + return await async_func(x) + + return wrap + + +def range_raising_at_exhaustion( + start: int, end: int, step: int, exception: Exception +) -> Iterator[int]: + yield from range(start, end, step) + raise exception + + +src_raising_at_exhaustion = lambda: range_raising_at_exhaustion(0, N, 1, TestError()) diff --git a/version.py b/version.py index 7117f661..58329072 100644 --- a/version.py +++ b/version.py @@ -1,2 +1,2 @@ -# print CHANGELOG: git log --oneline -- version.py | grep -v '\-rc' -__version__ = "1.5.1" +# print CHANGELOG: git log --oneline -- version.py +__version__ = "1.6.0a2"