diff --git a/.github/workflows/pypi.yml b/.github/workflows/pypi.yml
index 02bed3e0..09157371 100644
--- a/.github/workflows/pypi.yml
+++ b/.github/workflows/pypi.yml
@@ -4,6 +4,9 @@ on:
push:
branches:
- main
+ - alpha
+ - beta
+ - rc
paths:
- 'version.py'
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 0624e37c..06a6cebc 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -27,7 +27,7 @@ is relatively close to
Applying an operation simply performs argument validation and returns a new instance of a `Stream` sub-class corresponding to the operation. Defining a stream is basically constructing a composite structure where each node is a `Stream` instance: The above `events` variable holds a `CatchStream[int]` instance, whose `.upstream` attribute points to a `TruncateStream[int]` instance, whose `.upstream` attribute points to a `ForeachStream[int]` instance, whose `.upstream` points to a `Stream[int]` instance. Each node's `.source` attribute points to the same `range(10)`.
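+
+For illustration, an `events` stream matching this chain could be defined as follows (a sketch consistent with the description above, not necessarily the exact snippet it references):
+
+```python
+from streamable import Stream
+
+events: Stream[int] = (
+    Stream(range(10))   # Stream[int] whose .source is range(10)
+    .foreach(print)     # ForeachStream[int]
+    .truncate(10)       # TruncateStream[int]
+    .catch(ValueError)  # CatchStream[int]
+)
+```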
## Visitor Pattern
-Each node in this composite structure exposes an `.accept` method enabling traversal by a visitor. Both `Stream.__iter__` and `Stream.__repr__` rely on visitor classes defined in the `streamable.visitors` package.
+Each node in this composite structure exposes an `.accept` method enabling traversal by a visitor. `Stream.__iter__`/`Stream.__aiter__`/`Stream.__repr__` rely on visitor classes defined in the `streamable.visitors` package.
## Decorator Pattern
A `Stream[T]` both inherits from `Iterable[T]` and holds an `Iterable[T]` as its `.source`: when you instantiate a stream from an iterable you decorate it with a fluent interface.
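+
+A minimal sketch of this duality:
+
+```python
+from typing import Iterable
+
+from streamable import Stream
+
+integers = Stream(range(10))
+
+assert isinstance(integers, Iterable)         # a Stream is an Iterable...
+assert isinstance(integers.source, Iterable)  # ...decorating its .source
+```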
diff --git a/README.md b/README.md
index 4bf64c0b..cccbf145 100644
--- a/README.md
+++ b/README.md
@@ -1,31 +1,26 @@
[](https://codecov.io/gh/ebonnal/streamable)
-[](https://github.com/ebonnal/streamable/actions)
-[](https://github.com/ebonnal/streamable/actions)
-[](https://github.com/ebonnal/streamable/actions)
[](https://pypi.org/project/streamable)
[](https://anaconda.org/conda-forge/streamable)
# ༄ `streamable`
-### *Pythonic Stream-like manipulation of iterables*
+### *Pythonic Stream-like manipulation of (async) iterables*
- π ***Fluent*** chainable lazy operations
-- π ***Concurrent*** via *threads*/*processes*/`asyncio`
-- πΉ ***Typed***, fully annotated, `Stream[T]` is an `Iterable[T]`
-- π‘οΈ ***Tested*** extensively with **Python 3.7 to 3.14**
-- πͺΆ ***Light***, no dependencies
-
+- π ***Concurrent*** via *threads*/*processes*/`async`
+- πΉ Fully ***Typed***, `Stream[T]` is an `Iterable[T]` (and an `AsyncIterable[T]`)
+- 🛡️ ***Battle-tested*** in production, extensively tested with **Python 3.7 to 3.14**
## 1. install
-```bash
-pip install streamable
-```
-*or*
-```bash
-conda install conda-forge::streamable
-```
+`pip install streamable`
+
+or
+
+`conda install conda-forge::streamable`
+
+No dependencies.
## 2. import
@@ -35,7 +30,7 @@ from streamable import Stream
## 3. init
-Create a `Stream[T]` *decorating* an `Iterable[T]`:
+Create a `Stream[T]` *decorating* an `Iterable[T]` (or an `AsyncIterable[T]`):
```python
integers: Stream[int] = Stream(range(10))
@@ -55,9 +50,14 @@ inverses: Stream[float] = (
## 5. iterate
-Iterate over a `Stream[T]` just as you would over any other `Iterable[T]`, elements are processed *on-the-fly*:
+Iterate over a `Stream[T]` just as you would over any other `Iterable[T]` (or `AsyncIterable[T]`); elements are processed *on-the-fly*:
-- **collect**
+
+### as an `Iterable[T]`
+
+π show snippets
+
+- **into a data structure**
```python
>>> list(inverses)
[1.0, 0.5, 0.33, 0.25, 0.2, 0.17, 0.14, 0.12, 0.11]
@@ -65,7 +65,13 @@ Iterate over a `Stream[T]` just as you would over any other `Iterable[T]`, eleme
{0.5, 1.0, 0.2, 0.33, 0.25, 0.17, 0.14, 0.12, 0.11}
```
-- **reduce**
+- **`for`**
+```python
+>>> [inverse for inverse in inverses]
+[1.0, 0.5, 0.33, 0.25, 0.2, 0.17, 0.14, 0.12, 0.11]
+```
+
+- **`reduce`**
```python
>>> sum(inverses)
2.82
@@ -73,27 +79,153 @@ Iterate over a `Stream[T]` just as you would over any other `Iterable[T]`, eleme
>>> reduce(..., inverses)
```
-- **loop**
+- **`iter`/`next`**
```python
->>> for inverse in inverses:
->>> ...
+>>> next(iter(inverses))
+1.0
```
-- **next**
+
+
+### as an `AsyncIterable[T]`
+
+π show snippets
+
+- **`async for`**
```python
->>> next(iter(inverses))
+>>> async def main() -> List[float]:
+>>> return [inverse async for inverse in inverses]
+
+>>> asyncio.run(main())
+[1.0, 0.5, 0.33, 0.25, 0.2, 0.17, 0.14, 0.12, 0.11]
+```
+
+- **`aiter`/`anext`**
+```python
+>>> asyncio.run(anext(aiter(inverses))) # before 3.10: inverses.__aiter__().__anext__()
1.0
```
+
+
+
+# β Showcase: Extract-Transform-Load
+
+Let's take an example showcasing most of the `Stream`'s operations.
+
+This script extracts the 67 quadruped Pokémon from the first three generations using [PokéAPI](https://pokeapi.co/) and loads them into a CSV:
+
+```python
+import csv
+from datetime import timedelta
+import itertools
+import requests
+from streamable import Stream
+
+with open("./quadruped_pokemons.csv", mode="w") as file:
+ fields = ["id", "name", "is_legendary", "base_happiness", "capture_rate"]
+ writer = csv.DictWriter(file, fields, extrasaction='ignore')
+ writer.writeheader()
+
+ pipeline: Stream = (
+        # Infinite Stream[int] of Pokemon ids starting from Pokémon #1: Bulbasaur
+ Stream(itertools.count(1))
+        # Limits to 16 requests per second to be friendly to our fellow PokéAPI devs
+ .throttle(16, per=timedelta(seconds=1))
+ # GETs pokemons concurrently using a pool of 8 threads
+ .map(lambda poke_id: f"https://pokeapi.co/api/v2/pokemon-species/{poke_id}")
+ .map(requests.get, concurrency=8)
+ .foreach(requests.Response.raise_for_status)
+ .map(requests.Response.json)
+ # Stops the iteration when reaching the 1st pokemon of the 4th generation
+ .truncate(when=lambda poke: poke["generation"]["name"] == "generation-iv")
+ .observe("pokemons")
+ # Keeps only quadruped Pokemons
+ .filter(lambda poke: poke["shape"]["name"] == "quadruped")
+ .observe("quadruped pokemons")
+ # Catches errors due to None "generation" or "shape"
+ .catch(
+ TypeError,
+ when=lambda error: str(error) == "'NoneType' object is not subscriptable"
+ )
+ # Writes a batch of pokemons every 5 seconds to the CSV file
+ .group(interval=timedelta(seconds=5))
+ .foreach(writer.writerows)
+ .flatten()
+ .observe("written pokemons")
+ # Catches exceptions and raises the 1st one at the end of the iteration
+ .catch(Exception, finally_raise=True)
+ )
+
+ pipeline()
+```
+
+## or the `async` way
+
+Use the `.amap` operation and `await` the `Stream`:
+
+```python
+import asyncio
+import csv
+from datetime import timedelta
+import itertools
+import httpx
+from streamable import Stream
+
+async def main() -> None:
+ with open("./quadruped_pokemons.csv", mode="w") as file:
+ fields = ["id", "name", "is_legendary", "base_happiness", "capture_rate"]
+ writer = csv.DictWriter(file, fields, extrasaction='ignore')
+ writer.writeheader()
+
+ async with httpx.AsyncClient() as http_async_client:
+ pipeline: Stream = (
+                # Infinite Stream[int] of Pokemon ids starting from Pokémon #1: Bulbasaur
+ Stream(itertools.count(1))
+                # Limits to 16 requests per second to be friendly to our fellow PokéAPI devs
+ .throttle(16, per=timedelta(seconds=1))
+ # GETs pokemons via 8 concurrent coroutines
+ .map(lambda poke_id: f"https://pokeapi.co/api/v2/pokemon-species/{poke_id}")
+ .amap(http_async_client.get, concurrency=8)
+ .foreach(httpx.Response.raise_for_status)
+ .map(httpx.Response.json)
+ # Stops the iteration when reaching the 1st pokemon of the 4th generation
+ .truncate(when=lambda poke: poke["generation"]["name"] == "generation-iv")
+ .observe("pokemons")
+ # Keeps only quadruped Pokemons
+ .filter(lambda poke: poke["shape"]["name"] == "quadruped")
+ .observe("quadruped pokemons")
+ # Catches errors due to None "generation" or "shape"
+ .catch(
+ TypeError,
+ when=lambda error: str(error) == "'NoneType' object is not subscriptable"
+ )
+ # Writes a batch of pokemons every 5 seconds to the CSV file
+ .group(interval=timedelta(seconds=5))
+ .foreach(writer.writerows)
+ .flatten()
+ .observe("written pokemons")
+ # Catches exceptions and raises the 1st one at the end of the iteration
+ .catch(Exception, finally_raise=True)
+ )
+
+ await pipeline
+
+asyncio.run(main())
+```
+
# π ***Operations***
-*A dozen expressive lazy operations and that's it!*
+A dozen expressive lazy operations and that's it.
-# `.map`
+> [!NOTE]
+> **`async` twin operations:** Each operation that takes a function also has an async version (same name with an "a" prefix) that accepts `async` functions. You can mix both types of operations on the same `Stream`, which can be used as either an `Iterable` or an `AsyncIterable`.
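+
+For instance, here is a minimal sketch mixing a sync and an `async` operation on the `integers` stream (`double` is an illustrative `async` function, not part of the library):
+
+```python
+import asyncio
+
+async def double(n: int) -> int:
+    return 2 * n
+
+doubles: Stream[int] = integers.amap(double).filter(lambda n: n % 4 == 0)
+
+# consume it as an Iterable...
+assert list(doubles) == [0, 4, 8, 12, 16]
+# ...or as an AsyncIterable
+assert asyncio.run(doubles.acount()) == 5
+```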
+
+## `.map` / `.amap`
> Applies a transformation on elements:
-π show example
+π show snippet
```python
integer_strings: Stream[str] = integers.map(str)
@@ -102,16 +234,11 @@ assert list(integer_strings) == ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9
```
-## concurrency
-
-> [!NOTE]
-> By default, all the concurrency modes presented below yield results in the upstream order (FIFO). Set the parameter `ordered=False` to yield results as they become available (***First Done, First Out***).
-
### thread-based concurrency
-> Applies the transformation via `concurrency` threads:
+> Applies the transformation via `concurrency` threads, yielding results in the upstream order (FIFO); set the parameter `ordered=False` to yield results as they become available (*First Done, First Out*).
-π show example
+π show snippet
```python
import requests
@@ -128,16 +255,13 @@ assert list(pokemon_names) == ['bulbasaur', 'ivysaur', 'venusaur']
> [!NOTE]
-> `concurrency` is also the size of the buffer containing not-yet-yielded results. **If the buffer is full, the iteration over the upstream is paused** until a result is yielded from the buffer.
-
-> [!TIP]
-> The performance of thread-based concurrency in a CPU-bound script can be drastically improved by using a [Python 3.13+ free-threading build](https://docs.python.org/3/using/configure.html#cmdoption-disable-gil).
+> **Memory-efficient**: Only `concurrency` upstream elements are pulled for processing; the next upstream element is pulled only when a result is yielded downstream.
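+
+> A sketch of this behavior (the `pulled` list only serves the illustration):
+
+```python
+pulled: List[int] = []
+
+first_result: int = next(iter(
+    Stream(range(1_000))
+    .foreach(pulled.append)
+    .map(lambda n: n ** 2, concurrency=4)
+))
+
+assert first_result == 0
+# only about `concurrency` source elements have been pulled, not 1_000:
+assert len(pulled) <= 5
+```
+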
### process-based concurrency
> Set `via="process"`:
-π show example
+π show snippet
```python
if __name__ == "__main__":
@@ -149,15 +273,17 @@ if __name__ == "__main__":
```
-### `asyncio`-based concurrency
+### `async`-based concurrency: `.amap`
+
+> `.amap` can apply an `async` transformation concurrently.
-> The sibling operation `.amap` applies an async function:
+π show snippet
-π show example
+- consumed as an `Iterable[T]`:
```python
-import httpx
import asyncio
+import httpx
http_async_client = httpx.AsyncClient()
@@ -170,15 +296,35 @@ pokemon_names: Stream[str] = (
)
assert list(pokemon_names) == ['bulbasaur', 'ivysaur', 'venusaur']
-asyncio.get_event_loop().run_until_complete(http_async_client.aclose())
+asyncio.run(http_async_client.aclose())
+```
+
+- consumed as an `AsyncIterable[T]`:
+
+```python
+import asyncio
+import httpx
+
+async def main() -> None:
+ async with httpx.AsyncClient() as http_async_client:
+ pokemon_names: Stream[str] = (
+ Stream(range(1, 4))
+ .map(lambda i: f"https://pokeapi.co/api/v2/pokemon-species/{i}")
+ .amap(http_async_client.get, concurrency=3)
+ .map(httpx.Response.json)
+ .map(lambda poke: poke["name"])
+ )
+ assert [name async for name in pokemon_names] == ['bulbasaur', 'ivysaur', 'venusaur']
+
+asyncio.run(main())
```
-## "starmap"
+### "starmap"
> The `star` function decorator transforms a function that takes several positional arguments into a function that takes a tuple:
-π show example
+π show snippet
```python
from streamable import star
@@ -193,14 +339,11 @@ assert list(zeros) == [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
-
-# `.foreach`
-
-
+## `.foreach` / `.aforeach`
> Applies a side effect on elements:
-π show example
+π show snippet
```python
state: List[int] = []
@@ -211,19 +354,20 @@ assert state == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```
-## concurrency
+### concurrency
> Similar to `.map`:
> - set the `concurrency` parameter for **thread-based concurrency**
> - set `via="process"` for **process-based concurrency**
-> - use the sibling `.aforeach` operation for **`asyncio`-based concurrency**
> - set `ordered=False` for ***First Done First Out***
+> - use `.aforeach` to apply an `async` effect concurrently
-# `.group`
+## `.group` / `.agroup`
-> Groups elements into `List`s:
+> Groups elements into `List`s
-π show example
+> ... up to a given group `size`:
+π show snippet
```python
integers_by_5: Stream[List[int]] = integers.group(size=5)
@@ -231,7 +375,10 @@ integers_by_5: Stream[List[int]] = integers.group(size=5)
assert list(integers_by_5) == [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]
```
-π show example
+
+> ... and/or co-groups `by` a given key:
+
+π show snippet
```python
integers_by_parity: Stream[List[int]] = integers.group(by=lambda n: n % 2)
@@ -239,7 +386,10 @@ integers_by_parity: Stream[List[int]] = integers.group(by=lambda n: n % 2)
assert list(integers_by_parity) == [[0, 2, 4, 6, 8], [1, 3, 5, 7, 9]]
```
-π show example
+
+> ... and/or groups the elements yielded by the upstream within a given time `interval`:
+
+π show snippet
```python
from datetime import timedelta
@@ -254,8 +404,10 @@ assert list(integers_within_1_sec) == [[0, 1, 2], [3, 4], [5, 6], [7, 8], [9]]
```
-> Mix the `size`/`by`/`interval` parameters:
-π show example
+> [!TIP]
+> Combine the `size`/`by`/`interval` parameters:
+
+π show snippet
```python
integers_by_parity_by_2: Stream[List[int]] = (
@@ -267,11 +419,10 @@ assert list(integers_by_parity_by_2) == [[0, 2], [1, 3], [4, 6], [5, 7], [8], [9
```
-
-## `.groupby`
+## `.groupby` / `.agroupby`
> Like `.group`, but groups into `(key, elements)` tuples:
-π show example
+π show snippet
```python
integers_by_parity: Stream[Tuple[str, List[int]]] = (
@@ -286,7 +437,7 @@ assert list(integers_by_parity) == [("even", [0, 2, 4, 6, 8]), ("odd", [1, 3, 5,
> [!TIP]
> Then *"starmap"* over the tuples:
-π show example
+π show snippet
```python
from streamable import star
@@ -300,11 +451,11 @@ assert list(counts_by_parity) == [("even", 5), ("odd", 5)]
```
-# `.flatten`
+## `.flatten` / `.aflatten`
-> Ungroups elements assuming that they are `Iterable`s:
+> Ungroups elements assuming that they are `Iterable`s (or `AsyncIterable`s for `.aflatten`):
-π show example
+π show snippet
```python
even_then_odd_integers: Stream[int] = integers_by_parity.flatten()
@@ -315,9 +466,9 @@ assert list(even_then_odd_integers) == [0, 2, 4, 6, 8, 1, 3, 5, 7, 9]
### thread-based concurrency
-> Flattens `concurrency` iterables concurrently:
+> Concurrently flattens `concurrency` iterables via threads (or via coroutines for `.aflatten`):
-π show example
+π show snippet
```python
mixed_ones_and_zeros: Stream[int] = (
@@ -328,11 +479,11 @@ assert list(mixed_ones_and_zeros) == [0, 1, 0, 1, 0, 1, 0, 1]
```
-# `.filter`
+## `.filter` / `.afilter`
> Keeps only the elements that satisfy a condition:
-π show example
+π show snippet
```python
even_integers: Stream[int] = integers.filter(lambda n: n % 2 == 0)
@@ -341,11 +492,11 @@ assert list(even_integers) == [0, 2, 4, 6, 8]
```
-# `.distinct`
+## `.distinct` / `.adistinct`
> Removes duplicates:
-π show example
+π show snippet
```python
distinct_chars: Stream[str] = Stream("foobarfooo").distinct()
@@ -356,7 +507,7 @@ assert list(distinct_chars) == ["f", "o", "b", "a", "r"]
> specifying a deduplication `key`:
-π show example
+π show snippet
```python
strings_of_distinct_lengths: Stream[str] = (
@@ -371,7 +522,7 @@ assert list(strings_of_distinct_lengths) == ["a", "foo"]
> [!WARNING]
> During iteration, all distinct elements that are yielded are retained in memory to perform deduplication. However, you can remove only consecutive duplicates without a memory footprint by setting `consecutive_only=True`:
-π show example
+π show snippet
```python
consecutively_distinct_chars: Stream[str] = (
@@ -383,11 +534,11 @@ assert list(consecutively_distinct_chars) == ["f", "o", "b", "a", "r", "f", "o"]
```
-# `.truncate`
+## `.truncate` / `.atruncate`
> Ends iteration once a given number of elements have been yielded:
-π show example
+π show snippet
```python
five_first_integers: Stream[int] = integers.truncate(5)
@@ -398,7 +549,7 @@ assert list(five_first_integers) == [0, 1, 2, 3, 4]
> or `when` a condition is satisfied:
-π show example
+π show snippet
```python
five_first_integers: Stream[int] = integers.truncate(when=lambda n: n == 5)
@@ -409,11 +560,11 @@ assert list(five_first_integers) == [0, 1, 2, 3, 4]
> If both `count` and `when` are set, truncation occurs as soon as either condition is met.
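+
+> For example, a sketch combining both (`when` is met first here):
+
+```python
+integers_before_3: Stream[int] = integers.truncate(5, when=lambda n: n == 3)
+
+assert list(integers_before_3) == [0, 1, 2]
+```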
-# `.skip`
+## `.skip` / `.askip`
> Skips the first specified number of elements:
-π show example
+π show snippet
```python
integers_after_five: Stream[int] = integers.skip(5)
@@ -424,7 +575,7 @@ assert list(integers_after_five) == [5, 6, 7, 8, 9]
> or skips elements `until` a predicate is satisfied:
-π show example
+π show snippet
```python
integers_after_five: Stream[int] = integers.skip(until=lambda n: n >= 5)
@@ -435,11 +586,11 @@ assert list(integers_after_five) == [5, 6, 7, 8, 9]
> If both `count` and `until` are set, skipping stops as soon as either condition is met.
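+
+> For example, a sketch combining both (`count` is met first here):
+
+```python
+integers_after_two: Stream[int] = integers.skip(2, until=lambda n: n >= 5)
+
+assert list(integers_after_two) == [2, 3, 4, 5, 6, 7, 8, 9]
+```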
-# `.catch`
+## `.catch` / `.acatch`
> Catches a given type of exception, and optionally yields a `replacement` value:
-π show example
+π show snippet
```python
inverses: Stream[float] = (
@@ -453,7 +604,7 @@ assert list(inverses) == [float("inf"), 1.0, 0.5, 0.33, 0.25, 0.2, 0.17, 0.14, 0
> You can specify an additional `when` condition for the catch:
-π show example
+π show snippet
```python
import requests
@@ -473,9 +624,9 @@ assert list(status_codes_ignoring_resolution_errors) == [200, 404]
> It has an optional `finally_raise: bool` parameter to raise the first exception caught (if any) when the iteration terminates.
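+
+> A sketch of `finally_raise`, parsing the characters of `"01x2"` as ints:
+
+```python
+ints: Stream[int] = Stream("01x2").map(int).catch(ValueError, finally_raise=True)
+
+iterator = iter(ints)
+assert next(iterator) == 0
+assert next(iterator) == 1
+assert next(iterator) == 2  # the error from "x" is caught, iteration continues
+# ...and the first caught error is raised once the iteration terminates:
+try:
+    next(iterator)
+except ValueError:
+    pass
+```
+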
> [!TIP]
-> Apply side effects when catching an exception by integrating them into `when`:
+> Leverage `when` to apply side effects on catch:
-π show example
+π show snippet
```python
errors: List[Exception] = []
@@ -495,12 +646,11 @@ assert len(errors) == len("foo")
```
-
-# `.throttle`
+## `.throttle`
> Limits the number of yields `per` time interval:
-π show example
+π show snippet
```python
from datetime import timedelta
@@ -513,10 +663,10 @@ assert list(three_integers_per_second) == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
-# `.observe`
+## `.observe`
> Logs the progress of iterations:
-π show example
+π show snippet
```python
>>> assert list(integers.throttle(2, per=timedelta(seconds=1)).observe("integers")) == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
@@ -537,7 +687,7 @@ INFO: [duration=0:00:04.003852 errors=0] 10 integers yielded
> [!TIP]
> To mute these logs, set the logging level above `INFO`:
-π show example
+π show snippet
```python
import logging
@@ -545,11 +695,11 @@ logging.getLogger("streamable").setLevel(logging.WARNING)
```
-# `+`
+## `+`
> Concatenates streams:
-π show example
+π show snippet
```python
assert list(integers + integers) == [0, 1, 2, 3 ,4, 5, 6, 7, 8, 9, 0, 1, 2, 3 ,4, 5, 6, 7, 8, 9]
@@ -557,12 +707,11 @@ assert list(integers + integers) == [0, 1, 2, 3 ,4, 5, 6, 7, 8, 9, 0, 1, 2, 3 ,4
-# `zip`
+## `zip`
-> [!TIP]
> Use the standard `zip` function:
-π show example
+π show snippet
```python
from streamable import star
@@ -578,29 +727,34 @@ assert list(cubes) == [0, 1, 8, 27, 64, 125, 216, 343, 512, 729]
## Shorthands for consuming the stream
-> [!NOTE]
-> Although consuming the stream is beyond the scope of this library, it provides two basic shorthands to trigger an iteration:
-## `.count`
+Although consuming the stream is beyond the scope of this library, it provides two basic shorthands to trigger an iteration:
+## `.count` / `.acount`
+> `.count` iterates over the stream until exhaustion and returns the number of elements yielded:
-> Iterates over the stream until exhaustion and returns the number of elements yielded:
-
-π show example
+π show snippet
```python
assert integers.count() == 10
```
+> `.acount` (its `async` counterpart) iterates over the stream as an `AsyncIterable` until exhaustion and returns the number of elements yielded:
+
+π show snippet
-## `()`
+```python
+assert asyncio.run(integers.acount()) == 10
+```
+
+## `()` / `await`
-> *Calling* the stream iterates over it until exhaustion and returns it:
-π show example
+> *Calling* the stream iterates over it until exhaustion, and returns it:
+π show snippet
```python
state: List[int] = []
@@ -610,12 +764,25 @@ assert state == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```
+> *Awaiting* the stream iterates over it as an `AsyncIterable` until exhaustion, and returns it:
-# `.pipe`
+π show snippet
-> Calls a function, passing the stream as first argument, followed by `*args/**kwargs` if any:
+```python
+async def test_await() -> None:
+ state: List[int] = []
+ appending_integers: Stream[int] = integers.foreach(state.append)
+    assert appending_integers is await appending_integers
+    assert state == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
+
+asyncio.run(test_await())
+```
+
-π show example
+## `.pipe`
+
+> Calls a function, passing the stream as first argument, followed by `*args/**kwargs` if any (inspired by the `.pipe` from [pandas](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.pipe.html) or [polars](https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.pipe.html)):
+
+π show snippet
```python
import pandas as pd
@@ -629,9 +796,6 @@ import pandas as pd
```
-> Inspired by the `.pipe` from [pandas](https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.pipe.html) or [polars](https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.pipe.html).
-
-
# 💡 Notes
## Exceptions are not terminating the iteration
@@ -639,7 +803,7 @@ import pandas as pd
> [!TIP]
> If any of the operations raises an exception, you can resume the iteration after handling it:
-π show example
+π show snippet
```python
from contextlib import suppress
@@ -662,63 +826,11 @@ assert collected == [0, 1, 2, 3, 5, 6, 7, 8, 9]
-## Extract-Transform-Load
-> [!TIP]
-> **Custom ETL scripts** can benefit from the expressiveness of this library. Below is a pipeline that extracts the 67 quadruped PokΓ©mon from the first three generations using [PokΓ©API](https://pokeapi.co/) and loads them into a CSV:
-
-π show example
-
-```python
-import csv
-from datetime import timedelta
-import itertools
-import requests
-from streamable import Stream
-
-with open("./quadruped_pokemons.csv", mode="w") as file:
- fields = ["id", "name", "is_legendary", "base_happiness", "capture_rate"]
- writer = csv.DictWriter(file, fields, extrasaction='ignore')
- writer.writeheader()
-
- pipeline: Stream = (
- # Infinite Stream[int] of Pokemon ids starting from PokΓ©mon #1: Bulbasaur
- Stream(itertools.count(1))
- # Limits to 16 requests per second to be friendly to our fellow PokΓ©API devs
- .throttle(16, per=timedelta(seconds=1))
- # GETs pokemons concurrently using a pool of 8 threads
- .map(lambda poke_id: f"https://pokeapi.co/api/v2/pokemon-species/{poke_id}")
- .map(requests.get, concurrency=8)
- .foreach(requests.Response.raise_for_status)
- .map(requests.Response.json)
- # Stops the iteration when reaching the 1st pokemon of the 4th generation
- .truncate(when=lambda poke: poke["generation"]["name"] == "generation-iv")
- .observe("pokemons")
- # Keeps only quadruped Pokemons
- .filter(lambda poke: poke["shape"]["name"] == "quadruped")
- .observe("quadruped pokemons")
- # Catches errors due to None "generation" or "shape"
- .catch(
- TypeError,
- when=lambda error: str(error) == "'NoneType' object is not subscriptable"
- )
- # Writes a batch of pokemons every 5 seconds to the CSV file
- .group(interval=timedelta(seconds=5))
- .foreach(writer.writerows)
- .flatten()
- .observe("written pokemons")
- # Catches exceptions and raises the 1st one at the end of the iteration
- .catch(Exception, finally_raise=True)
- )
-
- pipeline()
-```
-
-
## Visitor Pattern
> [!TIP]
> A `Stream` can be visited via its `.accept` method: implement a custom [***visitor***](https://en.wikipedia.org/wiki/Visitor_pattern) by extending the abstract class `streamable.visitors.Visitor`:
-π show example
+π show snippet
```python
from streamable.visitors import Visitor
@@ -739,7 +851,7 @@ assert depth(Stream(range(10)).map(str).foreach(print)) == 3
## Functions
> [!TIP]
> The `Stream`'s methods are also exposed as functions:
-π show example
+π show snippet
```python
from streamable.functions import catch
diff --git a/setup.py b/setup.py
index 4e6561a5..1ae77f09 100644
--- a/setup.py
+++ b/setup.py
@@ -1,4 +1,5 @@
from setuptools import find_packages, setup # type: ignore
+
from version import __version__
setup(
@@ -10,7 +11,7 @@
license="Apache 2.",
author="ebonnal",
author_email="bonnal.enzo.dev@gmail.com",
- description="Pythonic Stream-like manipulation of iterables",
+ description="Pythonic Stream-like manipulation of (async) iterables",
long_description=open("README.md").read(),
long_description_content_type="text/markdown",
)
diff --git a/streamable/__init__.py b/streamable/__init__.py
index a51ad250..17cf16ae 100644
--- a/streamable/__init__.py
+++ b/streamable/__init__.py
@@ -1,2 +1,4 @@
from streamable.stream import Stream
from streamable.util.functiontools import star
+
+__all__ = ["Stream", "star"]
diff --git a/streamable/afunctions.py b/streamable/afunctions.py
new file mode 100644
index 00000000..5532de3c
--- /dev/null
+++ b/streamable/afunctions.py
@@ -0,0 +1,353 @@
+import builtins
+import datetime
+from contextlib import suppress
+from operator import itemgetter
+from typing import (
+ Any,
+ AsyncIterable,
+ AsyncIterator,
+ Callable,
+ Coroutine,
+ Iterable,
+ List,
+ Optional,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+)
+
+from streamable.aiterators import (
+ ACatchAsyncIterator,
+ ADistinctAsyncIterator,
+ AFilterAsyncIterator,
+ AFlattenAsyncIterator,
+ AGroupbyAsyncIterator,
+ AMapAsyncIterator,
+ ConcurrentAFlattenAsyncIterator,
+ ConcurrentAMapAsyncIterator,
+ ConcurrentFlattenAsyncIterator,
+ ConcurrentMapAsyncIterator,
+ ConsecutiveADistinctAsyncIterator,
+ CountAndPredicateASkipAsyncIterator,
+ CountSkipAsyncIterator,
+ CountTruncateAsyncIterator,
+ FlattenAsyncIterator,
+ GroupAsyncIterator,
+ ObserveAsyncIterator,
+ PredicateASkipAsyncIterator,
+ PredicateATruncateAsyncIterator,
+ YieldsPerPeriodThrottleAsyncIterator,
+)
+from streamable.util.constants import NO_REPLACEMENT
+from streamable.util.functiontools import asyncify
+from streamable.util.validationtools import (
+ validate_aiterator,
+ validate_concurrency,
+ validate_errors,
+ validate_group_size,
+ # validate_not_none,
+ validate_optional_count,
+ validate_optional_positive_count,
+ validate_optional_positive_interval,
+ validate_via,
+)
+
+with suppress(ImportError):
+ from typing import Literal
+
+T = TypeVar("T")
+U = TypeVar("U")
+
+
+def catch(
+ aiterator: AsyncIterator[T],
+ errors: Union[
+ Optional[Type[Exception]], Iterable[Optional[Type[Exception]]]
+ ] = Exception,
+ *,
+ when: Optional[Callable[[Exception], Any]] = None,
+ replacement: T = NO_REPLACEMENT, # type: ignore
+ finally_raise: bool = False,
+) -> AsyncIterator[T]:
+ return acatch(
+ aiterator,
+ errors,
+ when=asyncify(when) if when else None,
+ replacement=replacement,
+ finally_raise=finally_raise,
+ )
+
+
+def acatch(
+ aiterator: AsyncIterator[T],
+ errors: Union[
+ Optional[Type[Exception]], Iterable[Optional[Type[Exception]]]
+ ] = Exception,
+ *,
+ when: Optional[Callable[[Exception], Coroutine[Any, Any, Any]]] = None,
+ replacement: T = NO_REPLACEMENT, # type: ignore
+ finally_raise: bool = False,
+) -> AsyncIterator[T]:
+ validate_aiterator(aiterator)
+ validate_errors(errors)
+ # validate_not_none(finally_raise, "finally_raise")
+ if errors is None:
+ return aiterator
+ return ACatchAsyncIterator(
+ aiterator,
+ tuple(builtins.filter(None, errors))
+ if isinstance(errors, Iterable)
+ else errors,
+ when=when,
+ replacement=replacement,
+ finally_raise=finally_raise,
+ )
+
+
+def distinct(
+ aiterator: AsyncIterator[T],
+ key: Optional[Callable[[T], Any]] = None,
+ *,
+ consecutive_only: bool = False,
+) -> AsyncIterator[T]:
+ return adistinct(
+ aiterator,
+ asyncify(key) if key else None,
+ consecutive_only=consecutive_only,
+ )
+
+
+def adistinct(
+ aiterator: AsyncIterator[T],
+ key: Optional[Callable[[T], Coroutine[Any, Any, Any]]] = None,
+ *,
+ consecutive_only: bool = False,
+) -> AsyncIterator[T]:
+ validate_aiterator(aiterator)
+ # validate_not_none(consecutive_only, "consecutive_only")
+ if consecutive_only:
+ return ConsecutiveADistinctAsyncIterator(aiterator, key)
+ return ADistinctAsyncIterator(aiterator, key)
+
+
+def filter(
+ aiterator: AsyncIterator[T],
+ when: Callable[[T], Any],
+) -> AsyncIterator[T]:
+ return afilter(aiterator, asyncify(when))
+
+
+def afilter(
+ aiterator: AsyncIterator[T],
+ when: Callable[[T], Any],
+) -> AsyncIterator[T]:
+ validate_aiterator(aiterator)
+ return AFilterAsyncIterator(aiterator, when)
+
+
+def flatten(
+ aiterator: AsyncIterator[Iterable[T]], *, concurrency: int = 1
+) -> AsyncIterator[T]:
+ validate_aiterator(aiterator)
+ validate_concurrency(concurrency)
+ if concurrency == 1:
+ return FlattenAsyncIterator(aiterator)
+ else:
+ return ConcurrentFlattenAsyncIterator(
+ aiterator,
+ concurrency=concurrency,
+ buffersize=concurrency,
+ )
+
+
+def aflatten(
+ aiterator: AsyncIterator[AsyncIterable[T]], *, concurrency: int = 1
+) -> AsyncIterator[T]:
+ validate_aiterator(aiterator)
+ validate_concurrency(concurrency)
+ if concurrency == 1:
+ return AFlattenAsyncIterator(aiterator)
+ else:
+ return ConcurrentAFlattenAsyncIterator(
+ aiterator,
+ concurrency=concurrency,
+ buffersize=concurrency,
+ )
+
+
+def group(
+ aiterator: AsyncIterator[T],
+ size: Optional[int] = None,
+ *,
+ interval: Optional[datetime.timedelta] = None,
+ by: Optional[Callable[[T], Any]] = None,
+) -> AsyncIterator[List[T]]:
+ return agroup(
+ aiterator,
+ size,
+ interval=interval,
+ by=asyncify(by) if by else None,
+ )
+
+
+def agroup(
+ aiterator: AsyncIterator[T],
+ size: Optional[int] = None,
+ *,
+ interval: Optional[datetime.timedelta] = None,
+ by: Optional[Callable[[T], Coroutine[Any, Any, Any]]] = None,
+) -> AsyncIterator[List[T]]:
+ validate_aiterator(aiterator)
+ validate_group_size(size)
+ validate_optional_positive_interval(interval)
+ if by is None:
+ return GroupAsyncIterator(aiterator, size, interval)
+ return map(itemgetter(1), AGroupbyAsyncIterator(aiterator, by, size, interval))
+
+
+def groupby(
+ aiterator: AsyncIterator[T],
+ key: Callable[[T], U],
+ *,
+ size: Optional[int] = None,
+ interval: Optional[datetime.timedelta] = None,
+) -> AsyncIterator[Tuple[U, List[T]]]:
+ return agroupby(aiterator, asyncify(key), size=size, interval=interval)
+
+
+def agroupby(
+ aiterator: AsyncIterator[T],
+ key: Callable[[T], Coroutine[Any, Any, U]],
+ *,
+ size: Optional[int] = None,
+ interval: Optional[datetime.timedelta] = None,
+) -> AsyncIterator[Tuple[U, List[T]]]:
+ validate_aiterator(aiterator)
+ validate_group_size(size)
+ validate_optional_positive_interval(interval)
+ return AGroupbyAsyncIterator(aiterator, key, size, interval)
+
+
+def map(
+ transformation: Callable[[T], U],
+ aiterator: AsyncIterator[T],
+ *,
+ concurrency: int = 1,
+ ordered: bool = True,
+ via: "Literal['thread', 'process']" = "thread",
+) -> AsyncIterator[U]:
+ validate_aiterator(aiterator)
+ # validate_not_none(transformation, "transformation")
+ # validate_not_none(ordered, "ordered")
+ validate_concurrency(concurrency)
+ validate_via(via)
+ if concurrency == 1:
+ return amap(asyncify(transformation), aiterator)
+ else:
+ return ConcurrentMapAsyncIterator(
+ aiterator,
+ transformation,
+ concurrency=concurrency,
+ buffersize=concurrency,
+ ordered=ordered,
+ via=via,
+ )
+
+
+def amap(
+ transformation: Callable[[T], Coroutine[Any, Any, U]],
+ aiterator: AsyncIterator[T],
+ *,
+ concurrency: int = 1,
+ ordered: bool = True,
+) -> AsyncIterator[U]:
+ validate_aiterator(aiterator)
+ # validate_not_none(transformation, "transformation")
+ # validate_not_none(ordered, "ordered")
+ validate_concurrency(concurrency)
+ if concurrency == 1:
+ return AMapAsyncIterator(aiterator, transformation)
+ return ConcurrentAMapAsyncIterator(
+ aiterator,
+ transformation,
+ buffersize=concurrency,
+ ordered=ordered,
+ )
+
+
+def observe(aiterator: AsyncIterator[T], what: str) -> AsyncIterator[T]:
+ validate_aiterator(aiterator)
+ # validate_not_none(what, "what")
+ return ObserveAsyncIterator(aiterator, what)
+
+
+def skip(
+ aiterator: AsyncIterator[T],
+ count: Optional[int] = None,
+ *,
+ until: Optional[Callable[[T], Any]] = None,
+) -> AsyncIterator[T]:
+ return askip(
+ aiterator,
+ count,
+ until=asyncify(until) if until else None,
+ )
+
+
+def askip(
+ aiterator: AsyncIterator[T],
+ count: Optional[int] = None,
+ *,
+ until: Optional[Callable[[T], Coroutine[Any, Any, Any]]] = None,
+) -> AsyncIterator[T]:
+ validate_aiterator(aiterator)
+ validate_optional_count(count)
+ if until is not None:
+ if count is not None:
+ return CountAndPredicateASkipAsyncIterator(aiterator, count, until)
+ return PredicateASkipAsyncIterator(aiterator, until)
+ if count is not None:
+ return CountSkipAsyncIterator(aiterator, count)
+ return aiterator
+
+
+def throttle(
+ aiterator: AsyncIterator[T],
+ count: Optional[int],
+ *,
+ per: Optional[datetime.timedelta] = None,
+) -> AsyncIterator[T]:
+ validate_optional_positive_count(count)
+ validate_optional_positive_interval(per, name="per")
+ if count and per:
+ aiterator = YieldsPerPeriodThrottleAsyncIterator(aiterator, count, per)
+ return aiterator
+
+
+def truncate(
+ aiterator: AsyncIterator[T],
+ count: Optional[int] = None,
+ *,
+ when: Optional[Callable[[T], Any]] = None,
+) -> AsyncIterator[T]:
+ return atruncate(
+ aiterator,
+ count,
+ when=asyncify(when) if when else None,
+ )
+
+
+def atruncate(
+ aiterator: AsyncIterator[T],
+ count: Optional[int] = None,
+ *,
+ when: Optional[Callable[[T], Coroutine[Any, Any, Any]]] = None,
+) -> AsyncIterator[T]:
+ validate_aiterator(aiterator)
+ validate_optional_count(count)
+ if count is not None:
+ aiterator = CountTruncateAsyncIterator(aiterator, count)
+ if when is not None:
+ aiterator = PredicateATruncateAsyncIterator(aiterator, when)
+ return aiterator
diff --git a/streamable/aiterators.py b/streamable/aiterators.py
new file mode 100644
index 00000000..72fb85d8
--- /dev/null
+++ b/streamable/aiterators.py
@@ -0,0 +1,924 @@
+import asyncio
+import datetime
+import multiprocessing
+import queue
+import time
+from abc import ABC, abstractmethod
+from collections import defaultdict, deque
+from concurrent.futures import Executor, Future, ProcessPoolExecutor, ThreadPoolExecutor
+from contextlib import contextmanager, suppress
+from math import ceil
+from typing import (
+ Any,
+ AsyncIterable,
+ AsyncIterator,
+ Awaitable,
+ Callable,
+ ContextManager,
+ Coroutine,
+ DefaultDict,
+ Deque,
+ Generic,
+ Iterable,
+ Iterator,
+ List,
+ NamedTuple,
+ Optional,
+ Set,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+ cast,
+)
+
+from streamable.util.asynctools import (
+ GetEventLoopMixin,
+ awaitable_to_coroutine,
+ empty_aiter,
+)
+from streamable.util.functiontools import (
+ aiter_wo_stopasynciteration,
+ awrap_error,
+ iter_wo_stopasynciteration,
+ wrap_error,
+)
+from streamable.util.loggertools import get_logger
+from streamable.util.validationtools import (
+ validate_aiterator,
+ validate_base,
+ validate_buffersize,
+ validate_concurrency,
+ validate_count,
+ validate_errors,
+ validate_group_size,
+ validate_optional_positive_interval,
+)
+
+T = TypeVar("T")
+U = TypeVar("U")
+
+from streamable.util.constants import NO_REPLACEMENT
+from streamable.util.futuretools import (
+ FDFOAsyncFutureResultCollection,
+ FDFOOSFutureResultCollection,
+ FIFOAsyncFutureResultCollection,
+ FIFOOSFutureResultCollection,
+ FutureResultCollection,
+)
+
+with suppress(ImportError):
+ from typing import Literal
+
+
+class ACatchAsyncIterator(AsyncIterator[T]):
+ def __init__(
+ self,
+ iterator: AsyncIterator[T],
+ errors: Union[Type[Exception], Tuple[Type[Exception], ...]],
+ when: Optional[Callable[[Exception], Coroutine[Any, Any, Any]]],
+ replacement: T,
+ finally_raise: bool,
+ ) -> None:
+ validate_aiterator(iterator)
+ validate_errors(errors)
+ self.iterator = iterator
+ self.errors = errors
+ self.when = awrap_error(when, StopAsyncIteration) if when else None
+ self.replacement = replacement
+ self.finally_raise = finally_raise
+ self._to_be_finally_raised: Optional[Exception] = None
+
+ async def __anext__(self) -> T:
+ while True:
+ try:
+ return await self.iterator.__anext__()
+ except StopAsyncIteration:
+ if self._to_be_finally_raised:
+ try:
+ raise self._to_be_finally_raised
+ finally:
+ self._to_be_finally_raised = None
+ raise
+ except self.errors as e:
+ if not self.when or await self.when(e):
+ if self.finally_raise and not self._to_be_finally_raised:
+ self._to_be_finally_raised = e
+ if self.replacement is not NO_REPLACEMENT:
+ return self.replacement
+ continue
+ raise
+
+
+class ADistinctAsyncIterator(AsyncIterator[T]):
+ def __init__(
+ self,
+ iterator: AsyncIterator[T],
+ key: Optional[Callable[[T], Coroutine[Any, Any, Any]]],
+ ) -> None:
+ validate_aiterator(iterator)
+ self.iterator = iterator
+ self.key = awrap_error(key, StopAsyncIteration) if key else None
+ self._already_seen: Set[Any] = set()
+
+ async def __anext__(self) -> T:
+ while True:
+ elem = await self.iterator.__anext__()
+ key = await self.key(elem) if self.key else elem
+ if key not in self._already_seen:
+ break
+ self._already_seen.add(key)
+ return elem
+
+
+class ConsecutiveADistinctAsyncIterator(AsyncIterator[T]):
+ def __init__(
+ self,
+ iterator: AsyncIterator[T],
+ key: Optional[Callable[[T], Coroutine[Any, Any, Any]]],
+ ) -> None:
+ validate_aiterator(iterator)
+ self.iterator = iterator
+ self.key = awrap_error(key, StopAsyncIteration) if key else None
+ self._last_key: Any = object()
+
+ async def __anext__(self) -> T:
+ while True:
+ elem = await self.iterator.__anext__()
+ key = await self.key(elem) if self.key else elem
+ if key != self._last_key:
+ break
+ self._last_key = key
+ return elem
+
+
+class FlattenAsyncIterator(AsyncIterator[U]):
+ def __init__(self, iterator: AsyncIterator[Iterable[U]]) -> None:
+ validate_aiterator(iterator)
+ self.iterator = iterator
+ self._current_iterator_elem: Iterator[U] = tuple().__iter__()
+
+ async def __anext__(self) -> U:
+ while True:
+ try:
+ return self._current_iterator_elem.__next__()
+ except StopIteration:
+ self._current_iterator_elem = iter_wo_stopasynciteration(
+ await self.iterator.__anext__()
+ )
+
+
+class AFlattenAsyncIterator(AsyncIterator[U]):
+ def __init__(self, iterator: AsyncIterator[AsyncIterable[U]]) -> None:
+ validate_aiterator(iterator)
+ self.iterator = iterator
+ self._current_iterator_elem: AsyncIterator[U] = empty_aiter()
+
+ async def __anext__(self) -> U:
+ while True:
+ try:
+ return await self._current_iterator_elem.__anext__()
+ except StopAsyncIteration:
+ self._current_iterator_elem = aiter_wo_stopasynciteration(
+ await self.iterator.__anext__()
+ )
+
+
+class _GroupAsyncIteratorMixin(Generic[T]):
+ def __init__(
+ self,
+ iterator: AsyncIterator[T],
+ size: Optional[int],
+ interval: Optional[datetime.timedelta],
+ ) -> None:
+ validate_aiterator(iterator)
+ validate_group_size(size)
+ validate_optional_positive_interval(interval)
+ self.iterator = iterator
+ self.size = size or cast(int, float("inf"))
+ self.interval = interval
+ self._interval_seconds = interval.total_seconds() if interval else float("inf")
+ self._to_be_raised: Optional[Exception] = None
+ self._last_group_yielded_at: float = 0
+
+ def _interval_seconds_have_elapsed(self) -> bool:
+ if not self.interval:
+ return False
+ return (
+ time.perf_counter() - self._last_group_yielded_at
+ ) >= self._interval_seconds
+
+ def _remember_group_time(self) -> None:
+ if self.interval:
+ self._last_group_yielded_at = time.perf_counter()
+
+ def _init_last_group_time(self) -> None:
+ if self.interval and not self._last_group_yielded_at:
+ self._last_group_yielded_at = time.perf_counter()
+
+
+class GroupAsyncIterator(_GroupAsyncIteratorMixin[T], AsyncIterator[List[T]]):
+ def __init__(
+ self,
+ iterator: AsyncIterator[T],
+ size: Optional[int],
+ interval: Optional[datetime.timedelta],
+ ) -> None:
+ super().__init__(iterator, size, interval)
+ self._current_group: List[T] = []
+
+ async def __anext__(self) -> List[T]:
+ self._init_last_group_time()
+ if self._to_be_raised:
+ try:
+ raise self._to_be_raised
+ finally:
+ self._to_be_raised = None
+ try:
+ while len(self._current_group) < self.size and (
+ not self._interval_seconds_have_elapsed() or not self._current_group
+ ):
+ self._current_group.append(await self.iterator.__anext__())
+ except Exception as e:
+ if not self._current_group:
+ raise
+ self._to_be_raised = e
+
+ group, self._current_group = self._current_group, []
+ self._remember_group_time()
+ return group
+
+
+class AGroupbyAsyncIterator(
+ _GroupAsyncIteratorMixin[T], AsyncIterator[Tuple[U, List[T]]]
+):
+ def __init__(
+ self,
+ iterator: AsyncIterator[T],
+ key: Callable[[T], Coroutine[Any, Any, U]],
+ size: Optional[int],
+ interval: Optional[datetime.timedelta],
+ ) -> None:
+ super().__init__(iterator, size, interval)
+ self.key = awrap_error(key, StopAsyncIteration)
+ self._is_exhausted = False
+ self._groups_by: DefaultDict[U, List[T]] = defaultdict(list)
+
+ async def _group_next_elem(self) -> None:
+ elem = await self.iterator.__anext__()
+ self._groups_by[await self.key(elem)].append(elem)
+
+ def _pop_full_group(self) -> Optional[Tuple[U, List[T]]]:
+ for key, group in self._groups_by.items():
+ if len(group) >= self.size:
+ return key, self._groups_by.pop(key)
+ return None
+
+ def _pop_first_group(self) -> Tuple[U, List[T]]:
+ first_key: U = self._groups_by.__iter__().__next__()
+ return first_key, self._groups_by.pop(first_key)
+
+ def _pop_largest_group(self) -> Tuple[U, List[T]]:
+ largest_group_key: Any = self._groups_by.__iter__().__next__()
+
+ for key, group in self._groups_by.items():
+ if len(group) > len(self._groups_by[largest_group_key]):
+ largest_group_key = key
+
+ return largest_group_key, self._groups_by.pop(largest_group_key)
+
+ async def __anext__(self) -> Tuple[U, List[T]]:
+ self._init_last_group_time()
+ if self._is_exhausted:
+ if self._groups_by:
+ return self._pop_first_group()
+ raise StopAsyncIteration
+
+ if self._to_be_raised:
+ if self._groups_by:
+ self._remember_group_time()
+ return self._pop_first_group()
+ try:
+ raise self._to_be_raised
+ finally:
+ self._to_be_raised = None
+
+ try:
+ await self._group_next_elem()
+
+ full_group: Optional[Tuple[U, List[T]]] = self._pop_full_group()
+ while not full_group and not self._interval_seconds_have_elapsed():
+ await self._group_next_elem()
+ full_group = self._pop_full_group()
+
+ self._remember_group_time()
+ return full_group or self._pop_largest_group()
+
+ except StopAsyncIteration:
+ self._is_exhausted = True
+ return await self.__anext__()
+
+ except Exception as e:
+ self._to_be_raised = e
+ return await self.__anext__()
+
+
+class CountSkipAsyncIterator(AsyncIterator[T]):
+ def __init__(self, iterator: AsyncIterator[T], count: int) -> None:
+ validate_aiterator(iterator)
+ validate_count(count)
+ self.iterator = iterator
+ self.count = count
+ self._n_skipped = 0
+ self._done_skipping = False
+
+ async def __anext__(self) -> T:
+ if not self._done_skipping:
+ while self._n_skipped < self.count:
+ await self.iterator.__anext__()
+ # do not count exceptions as skipped elements
+ self._n_skipped += 1
+ self._done_skipping = True
+ return await self.iterator.__anext__()
+
+
+class PredicateASkipAsyncIterator(AsyncIterator[T]):
+ def __init__(
+ self, iterator: AsyncIterator[T], until: Callable[[T], Coroutine[Any, Any, Any]]
+ ) -> None:
+ validate_aiterator(iterator)
+ self.iterator = iterator
+ self.until = awrap_error(until, StopAsyncIteration)
+ self._done_skipping = False
+
+ async def __anext__(self) -> T:
+ elem = await self.iterator.__anext__()
+ if not self._done_skipping:
+ while not await self.until(elem):
+ elem = await self.iterator.__anext__()
+ self._done_skipping = True
+ return elem
+
+
+class CountAndPredicateASkipAsyncIterator(AsyncIterator[T]):
+ def __init__(
+ self,
+ iterator: AsyncIterator[T],
+ count: int,
+ until: Callable[[T], Coroutine[Any, Any, Any]],
+ ) -> None:
+ validate_aiterator(iterator)
+ validate_count(count)
+ self.iterator = iterator
+ self.count = count
+ self.until = awrap_error(until, StopAsyncIteration)
+ self._n_skipped = 0
+ self._done_skipping = False
+
+ async def __anext__(self) -> T:
+ elem = await self.iterator.__anext__()
+ if not self._done_skipping:
+ while self._n_skipped < self.count and not await self.until(elem):
+ elem = await self.iterator.__anext__()
+ # do not count exceptions as skipped elements
+ self._n_skipped += 1
+ self._done_skipping = True
+ return elem
+
+
+class CountTruncateAsyncIterator(AsyncIterator[T]):
+ def __init__(self, iterator: AsyncIterator[T], count: int) -> None:
+ validate_aiterator(iterator)
+ validate_count(count)
+ self.iterator = iterator
+ self.count = count
+ self._current_count = 0
+
+ async def __anext__(self) -> T:
+ if self._current_count == self.count:
+ raise StopAsyncIteration()
+ elem = await self.iterator.__anext__()
+ self._current_count += 1
+ return elem
+
+
+class PredicateATruncateAsyncIterator(AsyncIterator[T]):
+ def __init__(
+ self, iterator: AsyncIterator[T], when: Callable[[T], Coroutine[Any, Any, Any]]
+ ) -> None:
+ validate_aiterator(iterator)
+ self.iterator = iterator
+ self.when = awrap_error(when, StopAsyncIteration)
+ self._satisfied = False
+
+ async def __anext__(self) -> T:
+ if self._satisfied:
+ raise StopAsyncIteration()
+ elem = await self.iterator.__anext__()
+ if await self.when(elem):
+ self._satisfied = True
+ raise StopAsyncIteration()
+ return elem
+
+
+class AMapAsyncIterator(AsyncIterator[U]):
+ def __init__(
+ self,
+ iterator: AsyncIterator[T],
+ transformation: Callable[[T], Coroutine[Any, Any, U]],
+ ) -> None:
+ validate_aiterator(iterator)
+
+ self.iterator = iterator
+ self.transformation = awrap_error(transformation, StopAsyncIteration)
+
+ async def __anext__(self) -> U:
+ return await self.transformation(await self.iterator.__anext__())
+
+
+class AFilterAsyncIterator(AsyncIterator[T]):
+ def __init__(
+ self,
+ iterator: AsyncIterator[T],
+ when: Callable[[T], Coroutine[Any, Any, Any]],
+ ) -> None:
+ validate_aiterator(iterator)
+
+ self.iterator = iterator
+ self.when = awrap_error(when, StopAsyncIteration)
+
+ async def __anext__(self) -> T:
+ while True:
+ elem = await self.iterator.__anext__()
+ if await self.when(elem):
+ return elem
+
+
+class ObserveAsyncIterator(AsyncIterator[T]):
+ def __init__(self, iterator: AsyncIterator[T], what: str, base: int = 2) -> None:
+ validate_aiterator(iterator)
+ validate_base(base)
+
+ self.iterator = iterator
+ self.what = what
+ self.base = base
+
+ self._n_yields = 0
+ self._n_errors = 0
+ self._n_nexts = 0
+ self._logged_n_nexts = 0
+ self._next_threshold = 0
+
+ self._start_time = time.perf_counter()
+
+ def _log(self) -> None:
+ get_logger().info(
+ "[%s %s] %s",
+ f"duration={datetime.datetime.fromtimestamp(time.perf_counter()) - datetime.datetime.fromtimestamp(self._start_time)}",
+ f"errors={self._n_errors}",
+ f"{self._n_yields} {self.what} yielded",
+ )
+ self._logged_n_nexts = self._n_nexts
+ self._next_threshold = self.base * self._logged_n_nexts
+
+ async def __anext__(self) -> T:
+ try:
+ elem = await self.iterator.__anext__()
+ self._n_nexts += 1
+ self._n_yields += 1
+ return elem
+ except StopAsyncIteration:
+ if self._n_nexts != self._logged_n_nexts:
+ self._log()
+ raise
+ except Exception:
+ self._n_nexts += 1
+ self._n_errors += 1
+ raise
+ finally:
+ if self._n_nexts >= self._next_threshold:
+ self._log()
+
+
+class YieldsPerPeriodThrottleAsyncIterator(AsyncIterator[T]):
+ def __init__(
+ self,
+ iterator: AsyncIterator[T],
+ max_yields: int,
+ period: datetime.timedelta,
+ ) -> None:
+ validate_aiterator(iterator)
+ self.iterator = iterator
+ self.max_yields = max_yields
+ self._period_seconds = period.total_seconds()
+
+ self._period_index: int = -1
+ self._yields_in_period = 0
+
+ self._offset: Optional[float] = None
+
+ async def safe_next(self) -> Tuple[Optional[T], Optional[Exception]]:
+ try:
+ return await self.iterator.__anext__(), None
+ except StopAsyncIteration:
+ raise
+ except Exception as e:
+ return None, e
+
+ async def __anext__(self) -> T:
+ elem, caught_error = await self.safe_next()
+
+ now = time.perf_counter()
+ if not self._offset:
+ self._offset = now
+ now -= self._offset
+
+ num_periods = now / self._period_seconds
+ period_index = int(num_periods)
+
+ if self._period_index != period_index:
+ self._period_index = period_index
+ self._yields_in_period = max(0, self._yields_in_period - self.max_yields)
+
+ if self._yields_in_period >= self.max_yields:
+ await asyncio.sleep(
+ (ceil(num_periods) - num_periods) * self._period_seconds
+ )
+ self._yields_in_period += 1
+
+ if caught_error:
+ raise caught_error
+ return cast(T, elem)
+
+
+class _RaisingAsyncIterator(AsyncIterator[T]):
+ class ExceptionContainer(NamedTuple):
+ exception: Exception
+
+ def __init__(
+ self,
+ iterator: AsyncIterator[Union[T, ExceptionContainer]],
+ ) -> None:
+ self.iterator = iterator
+
+ async def __anext__(self) -> T:
+ elem = await self.iterator.__anext__()
+ if isinstance(elem, self.ExceptionContainer):
+ raise elem.exception
+ return elem
+
+
+class _ConcurrentMapAsyncIterableMixin(
+ Generic[T, U],
+ ABC,
+ AsyncIterable[Union[U, _RaisingAsyncIterator.ExceptionContainer]],
+):
+ """
+ Template Method Pattern:
+    This abstract class's `__aiter__` is a skeleton for a queue-based concurrent mapping algorithm
+    that relies on helper methods (`_context_manager`, `_launch_task`, `_future_result_collection`)
+    that concrete subclasses implement or override.
+ """
+
+ def __init__(
+ self,
+ iterator: AsyncIterator[T],
+ buffersize: int,
+ ordered: bool,
+ ) -> None:
+ validate_aiterator(iterator)
+ validate_buffersize(buffersize)
+ self.iterator = iterator
+ self.buffersize = buffersize
+ self.ordered = ordered
+
+ def _context_manager(self) -> ContextManager:
+ @contextmanager
+ def dummy_context_manager_generator():
+ yield
+
+ return dummy_context_manager_generator()
+
+ @abstractmethod
+ def _launch_task(
+ self, elem: T
+ ) -> "Future[Union[U, _RaisingAsyncIterator.ExceptionContainer]]": ...
+
+ # factory method
+ @abstractmethod
+ def _future_result_collection(
+ self,
+ ) -> FutureResultCollection[Union[U, _RaisingAsyncIterator.ExceptionContainer]]: ...
+
+ async def __aiter__(
+ self,
+ ) -> AsyncIterator[Union[U, _RaisingAsyncIterator.ExceptionContainer]]:
+ with self._context_manager():
+ future_results = self._future_result_collection()
+
+ # queue tasks up to buffersize
+ with suppress(StopAsyncIteration):
+ while len(future_results) < self.buffersize:
+ future_results.add_future(
+ self._launch_task(await self.iterator.__anext__())
+ )
+
+ # wait, queue, yield
+ while future_results:
+ result = await future_results.__anext__()
+ with suppress(StopAsyncIteration):
+ future_results.add_future(
+ self._launch_task(await self.iterator.__anext__())
+ )
+ yield result
+
+
+class _ConcurrentMapAsyncIterable(_ConcurrentMapAsyncIterableMixin[T, U]):
+ def __init__(
+ self,
+ iterator: AsyncIterator[T],
+ transformation: Callable[[T], U],
+ concurrency: int,
+ buffersize: int,
+ ordered: bool,
+ via: "Literal['thread', 'process']",
+ ) -> None:
+ super().__init__(iterator, buffersize, ordered)
+ validate_concurrency(concurrency)
+ self.transformation = wrap_error(transformation, StopAsyncIteration)
+ self.concurrency = concurrency
+ self.executor: Executor
+ self.via = via
+
+ def _context_manager(self) -> ContextManager:
+ if self.via == "thread":
+ self.executor = ThreadPoolExecutor(max_workers=self.concurrency)
+ if self.via == "process":
+ self.executor = ProcessPoolExecutor(max_workers=self.concurrency)
+ return self.executor
+
+ # picklable
+ @staticmethod
+ def _safe_transformation(
+ transformation: Callable[[T], U], elem: T
+ ) -> Union[U, _RaisingAsyncIterator.ExceptionContainer]:
+ try:
+ return transformation(elem)
+ except Exception as e:
+ return _RaisingAsyncIterator.ExceptionContainer(e)
+
+ def _launch_task(
+ self, elem: T
+ ) -> "Future[Union[U, _RaisingAsyncIterator.ExceptionContainer]]":
+ return self.executor.submit(
+ self._safe_transformation, self.transformation, elem
+ )
+
+ def _future_result_collection(
+ self,
+ ) -> FutureResultCollection[Union[U, _RaisingAsyncIterator.ExceptionContainer]]:
+ if self.ordered:
+ return FIFOOSFutureResultCollection()
+ return FDFOOSFutureResultCollection(
+ multiprocessing.Queue if self.via == "process" else queue.Queue
+ )
+
+
+class ConcurrentMapAsyncIterator(_RaisingAsyncIterator[U]):
+ def __init__(
+ self,
+ iterator: AsyncIterator[T],
+ transformation: Callable[[T], U],
+ concurrency: int,
+ buffersize: int,
+ ordered: bool,
+ via: "Literal['thread', 'process']",
+ ) -> None:
+ super().__init__(
+ _ConcurrentMapAsyncIterable(
+ iterator,
+ transformation,
+ concurrency,
+ buffersize,
+ ordered,
+ via,
+ ).__aiter__()
+ )
+
+
+class _ConcurrentAMapAsyncIterable(
+ _ConcurrentMapAsyncIterableMixin[T, U], GetEventLoopMixin
+):
+ def __init__(
+ self,
+ iterator: AsyncIterator[T],
+ transformation: Callable[[T], Coroutine[Any, Any, U]],
+ buffersize: int,
+ ordered: bool,
+ ) -> None:
+ super().__init__(iterator, buffersize, ordered)
+ self.transformation = awrap_error(transformation, StopAsyncIteration)
+
+ async def _safe_transformation(
+ self, elem: T
+ ) -> Union[U, _RaisingAsyncIterator.ExceptionContainer]:
+ try:
+ coroutine = self.transformation(elem)
+ if not isinstance(coroutine, Coroutine):
+ raise TypeError(
+ f"`transformation` must be an async function i.e. a function returning a Coroutine but it returned a {type(coroutine)}",
+ )
+ return await coroutine
+ except Exception as e:
+ return _RaisingAsyncIterator.ExceptionContainer(e)
+
+ def _launch_task(
+ self, elem: T
+ ) -> "Future[Union[U, _RaisingAsyncIterator.ExceptionContainer]]":
+ return cast(
+ "Future[Union[U, _RaisingAsyncIterator.ExceptionContainer]]",
+ self.get_event_loop().create_task(self._safe_transformation(elem)),
+ )
+
+ def _future_result_collection(
+ self,
+ ) -> FutureResultCollection[Union[U, _RaisingAsyncIterator.ExceptionContainer]]:
+ if self.ordered:
+ return FIFOAsyncFutureResultCollection(self.get_event_loop())
+ else:
+ return FDFOAsyncFutureResultCollection(self.get_event_loop())
+
+
+class ConcurrentAMapAsyncIterator(_RaisingAsyncIterator[U]):
+ def __init__(
+ self,
+ iterator: AsyncIterator[T],
+ transformation: Callable[[T], Coroutine[Any, Any, U]],
+ buffersize: int,
+ ordered: bool,
+ ) -> None:
+ super().__init__(
+ _ConcurrentAMapAsyncIterable(
+ iterator,
+ transformation,
+ buffersize,
+ ordered,
+ ).__aiter__()
+ )
+
+
+class _ConcurrentFlattenAsyncIterable(
+ AsyncIterable[Union[T, _RaisingAsyncIterator.ExceptionContainer]]
+):
+ def __init__(
+ self,
+ iterables_iterator: AsyncIterator[Iterable[T]],
+ concurrency: int,
+ buffersize: int,
+ ) -> None:
+ validate_aiterator(iterables_iterator)
+ validate_concurrency(concurrency)
+ self.iterables_iterator = iterables_iterator
+ self.concurrency = concurrency
+ self.buffersize = buffersize
+
+ async def __aiter__(
+ self,
+ ) -> AsyncIterator[Union[T, _RaisingAsyncIterator.ExceptionContainer]]:
+ with ThreadPoolExecutor(max_workers=self.concurrency) as executor:
+ iterator_and_future_pairs: Deque[Tuple[Iterator[T], Future]] = deque()
+ element_to_yield: Deque[
+ Union[T, _RaisingAsyncIterator.ExceptionContainer]
+ ] = deque(maxlen=1)
+ iterator_to_queue: Optional[Iterator[T]] = None
+ # wait, queue, yield (FIFO)
+ while True:
+ if iterator_and_future_pairs:
+ iterator, future = iterator_and_future_pairs.popleft()
+ try:
+ element_to_yield.append(future.result())
+ iterator_to_queue = iterator
+ except StopIteration:
+ pass
+ except Exception as e:
+ element_to_yield.append(
+ _RaisingAsyncIterator.ExceptionContainer(e)
+ )
+ iterator_to_queue = iterator
+
+ # queue tasks up to buffersize
+ while len(iterator_and_future_pairs) < self.buffersize:
+ if not iterator_to_queue:
+ try:
+ iterable = await self.iterables_iterator.__anext__()
+ except StopAsyncIteration:
+ break
+ try:
+ iterator_to_queue = iter_wo_stopasynciteration(iterable)
+ except Exception as e:
+ yield _RaisingAsyncIterator.ExceptionContainer(e)
+ continue
+ future = executor.submit(next, iterator_to_queue)
+ iterator_and_future_pairs.append((iterator_to_queue, future))
+ iterator_to_queue = None
+ if element_to_yield:
+ yield element_to_yield.pop()
+ if not iterator_and_future_pairs:
+ break
+
+
+class ConcurrentFlattenAsyncIterator(_RaisingAsyncIterator[T]):
+ def __init__(
+ self,
+ iterables_iterator: AsyncIterator[Iterable[T]],
+ concurrency: int,
+ buffersize: int,
+ ) -> None:
+ super().__init__(
+ _ConcurrentFlattenAsyncIterable(
+ iterables_iterator,
+ concurrency,
+ buffersize,
+ ).__aiter__()
+ )
+
+
+class _ConcurrentAFlattenAsyncIterable(
+ AsyncIterable[Union[T, _RaisingAsyncIterator.ExceptionContainer]], GetEventLoopMixin
+):
+ def __init__(
+ self,
+ iterables_iterator: AsyncIterator[AsyncIterable[T]],
+ concurrency: int,
+ buffersize: int,
+ ) -> None:
+ validate_aiterator(iterables_iterator)
+ validate_concurrency(concurrency)
+ self.iterables_iterator = iterables_iterator
+ self.concurrency = concurrency
+ self.buffersize = buffersize
+
+ async def __aiter__(
+ self,
+ ) -> AsyncIterator[Union[T, _RaisingAsyncIterator.ExceptionContainer]]:
+ iterator_and_future_pairs: Deque[Tuple[AsyncIterator[T], Awaitable[T]]] = (
+ deque()
+ )
+ element_to_yield: Deque[Union[T, _RaisingAsyncIterator.ExceptionContainer]] = (
+ deque(maxlen=1)
+ )
+ iterator_to_queue: Optional[AsyncIterator[T]] = None
+ # wait, queue, yield (FIFO)
+ while True:
+ if iterator_and_future_pairs:
+ iterator, future = iterator_and_future_pairs.popleft()
+ try:
+ element_to_yield.append(await future)
+ iterator_to_queue = iterator
+ except StopAsyncIteration:
+ pass
+ except Exception as e:
+ element_to_yield.append(_RaisingAsyncIterator.ExceptionContainer(e))
+ iterator_to_queue = iterator
+
+ # queue tasks up to buffersize
+ while len(iterator_and_future_pairs) < self.buffersize:
+ if not iterator_to_queue:
+ try:
+ iterable = await self.iterables_iterator.__anext__()
+ except StopAsyncIteration:
+ break
+ try:
+ iterator_to_queue = aiter_wo_stopasynciteration(iterable)
+ except Exception as e:
+ yield _RaisingAsyncIterator.ExceptionContainer(e)
+ continue
+ future = self.get_event_loop().create_task(
+ awaitable_to_coroutine(
+ cast(AsyncIterator, iterator_to_queue).__anext__()
+ )
+ )
+ iterator_and_future_pairs.append(
+ (cast(AsyncIterator, iterator_to_queue), future)
+ )
+ iterator_to_queue = None
+ if element_to_yield:
+ yield element_to_yield.pop()
+ if not iterator_and_future_pairs:
+ break
+
+
+class ConcurrentAFlattenAsyncIterator(_RaisingAsyncIterator[T]):
+ def __init__(
+ self,
+ iterables_iterator: AsyncIterator[AsyncIterable[T]],
+ concurrency: int,
+ buffersize: int,
+ ) -> None:
+ super().__init__(
+ _ConcurrentAFlattenAsyncIterable(
+ iterables_iterator,
+ concurrency,
+ buffersize,
+ ).__aiter__()
+ )
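
The two concurrent flatten iterables above share one algorithm: keep at most `buffersize` pending element-fetching tasks in a FIFO deque, await the oldest, re-queue its iterator unless it is exhausted, and top the buffer up from upstream. A minimal self-contained sketch of that buffered round-robin idea (illustrative names only, not the library's classes):

```python
import asyncio
from collections import deque
from typing import AsyncIterator, Deque, List, Tuple

async def arange(n: int) -> AsyncIterator[int]:
    for i in range(n):
        yield i

async def flatten_concurrently(
    sources: List[AsyncIterator[int]], buffersize: int = 2
) -> List[int]:
    # FIFO deque of (iterator, pending __anext__ task) pairs, capped at buffersize
    pending: Deque[Tuple[AsyncIterator[int], "asyncio.Task[int]"]] = deque()
    remaining = deque(sources)
    out: List[int] = []
    while pending or remaining:
        # queue tasks up to buffersize
        while len(pending) < buffersize and remaining:
            it = remaining.popleft()
            pending.append((it, asyncio.ensure_future(it.__anext__())))
        # wait on the oldest task, then re-queue its iterator if not exhausted
        it, task = pending.popleft()
        try:
            out.append(await task)
            pending.append((it, asyncio.ensure_future(it.__anext__())))
        except StopAsyncIteration:
            pass  # drop the exhausted iterator
    return out

print(asyncio.run(flatten_concurrently([arange(3), arange(2)])))  # [0, 0, 1, 1, 2]
```
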
diff --git a/streamable/functions.py b/streamable/functions.py
index aa369fd0..601831f2 100644
--- a/streamable/functions.py
+++ b/streamable/functions.py
@@ -4,6 +4,7 @@
from operator import itemgetter
from typing import (
Any,
+ AsyncIterable,
Callable,
Coroutine,
Iterable,
@@ -17,9 +18,12 @@
)
from streamable.iterators import (
- AsyncConcurrentMapIterator,
+ AFlattenIterator,
CatchIterator,
+ ConcurrentAFlattenIterator,
+ ConcurrentAMapIterator,
ConcurrentFlattenIterator,
+ ConcurrentMapIterator,
ConsecutiveDistinctIterator,
CountAndPredicateSkipIterator,
CountSkipIterator,
@@ -29,13 +33,12 @@
GroupbyIterator,
GroupIterator,
ObserveIterator,
- OSConcurrentMapIterator,
PredicateSkipIterator,
PredicateTruncateIterator,
YieldsPerPeriodThrottleIterator,
)
from streamable.util.constants import NO_REPLACEMENT
-from streamable.util.functiontools import wrap_error
+from streamable.util.functiontools import syncify, wrap_error
from streamable.util.validationtools import (
validate_concurrency,
validate_errors,
@@ -79,6 +82,25 @@ def catch(
)
+def acatch(
+ iterator: Iterator[T],
+ errors: Union[
+ Optional[Type[Exception]], Iterable[Optional[Type[Exception]]]
+ ] = Exception,
+ *,
+ when: Optional[Callable[[Exception], Coroutine[Any, Any, Any]]] = None,
+ replacement: T = NO_REPLACEMENT, # type: ignore
+ finally_raise: bool = False,
+) -> Iterator[T]:
+ return catch(
+ iterator,
+ errors,
+ when=syncify(when) if when else None,
+ replacement=replacement,
+ finally_raise=finally_raise,
+ )
+
+
def distinct(
iterator: Iterator[T],
key: Optional[Callable[[T], Any]] = None,
@@ -92,6 +114,19 @@ def distinct(
return DistinctIterator(iterator, key)
+def adistinct(
+ iterator: Iterator[T],
+ key: Optional[Callable[[T], Any]] = None,
+ *,
+ consecutive_only: bool = False,
+) -> Iterator[T]:
+ return distinct(
+ iterator,
+ syncify(key) if key else None,
+ consecutive_only=consecutive_only,
+ )
+
+
def flatten(iterator: Iterator[Iterable[T]], *, concurrency: int = 1) -> Iterator[T]:
validate_iterator(iterator)
validate_concurrency(concurrency)
@@ -105,6 +140,21 @@ def flatten(iterator: Iterator[Iterable[T]], *, concurrency: int = 1) -> Iterato
)
+def aflatten(
+ iterator: Iterator[AsyncIterable[T]], *, concurrency: int = 1
+) -> Iterator[T]:
+ validate_iterator(iterator)
+ validate_concurrency(concurrency)
+ if concurrency == 1:
+ return AFlattenIterator(iterator)
+ else:
+ return ConcurrentAFlattenIterator(
+ iterator,
+ concurrency=concurrency,
+ buffersize=concurrency,
+ )
+
+
def group(
iterator: Iterator[T],
size: Optional[int] = None,
@@ -120,6 +170,21 @@ def group(
return map(itemgetter(1), GroupbyIterator(iterator, by, size, interval))
+def agroup(
+ iterator: Iterator[T],
+ size: Optional[int] = None,
+ *,
+ interval: Optional[datetime.timedelta] = None,
+ by: Optional[Callable[[T], Coroutine[Any, Any, Any]]] = None,
+) -> Iterator[List[T]]:
+ return group(
+ iterator,
+ size,
+ interval=interval,
+ by=syncify(by) if by else None,
+ )
+
+
def groupby(
iterator: Iterator[T],
key: Callable[[T], U],
@@ -133,6 +198,21 @@ def groupby(
return GroupbyIterator(iterator, key, size, interval)
+def agroupby(
+ iterator: Iterator[T],
+ key: Callable[[T], Coroutine[Any, Any, U]],
+ *,
+ size: Optional[int] = None,
+ interval: Optional[datetime.timedelta] = None,
+) -> Iterator[Tuple[U, List[T]]]:
+ return groupby(
+ iterator,
+ syncify(key),
+ size=size,
+ interval=interval,
+ )
+
+
def map(
transformation: Callable[[T], U],
iterator: Iterator[T],
@@ -149,7 +229,7 @@ def map(
if concurrency == 1:
return builtins.map(wrap_error(transformation, StopIteration), iterator)
else:
- return OSConcurrentMapIterator(
+ return ConcurrentMapIterator(
iterator,
transformation,
concurrency=concurrency,
@@ -170,7 +250,9 @@ def amap(
# validate_not_none(transformation, "transformation")
# validate_not_none(ordered, "ordered")
validate_concurrency(concurrency)
- return AsyncConcurrentMapIterator(
+ if concurrency == 1:
+ return map(syncify(transformation), iterator)
+ return ConcurrentAMapIterator(
iterator,
transformation,
buffersize=concurrency,
@@ -201,6 +283,19 @@ def skip(
return iterator
+def askip(
+ iterator: Iterator[T],
+ count: Optional[int] = None,
+ *,
+ until: Optional[Callable[[T], Coroutine[Any, Any, Any]]] = None,
+) -> Iterator[T]:
+ return skip(
+ iterator,
+ count,
+ until=syncify(until) if until else None,
+ )
+
+
def throttle(
iterator: Iterator[T],
count: Optional[int],
@@ -227,3 +322,16 @@ def truncate(
if when is not None:
iterator = PredicateTruncateIterator(iterator, when)
return iterator
+
+
+def atruncate(
+ iterator: Iterator[T],
+ count: Optional[int] = None,
+ *,
+ when: Optional[Callable[[T], Coroutine[Any, Any, Any]]] = None,
+) -> Iterator[T]:
+ return truncate(
+ iterator,
+ count,
+ when=syncify(when) if when else None,
+ )
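
These `a*` helpers reimplement nothing: each adapts its async callback into a sync one via `syncify` and delegates to the existing sync operator. A simplified stand-in for `syncify` (the real one reuses a singleton loop via `GetEventLoopMixin`) makes the pattern concrete:

```python
import asyncio
from typing import Any, Callable, Coroutine, TypeVar

T = TypeVar("T")
R = TypeVar("R")

def syncify(async_func: Callable[[T], Coroutine[Any, Any, R]]) -> Callable[[T], R]:
    # simplified stand-in: run each coroutine to completion on a dedicated loop
    loop = asyncio.new_event_loop()
    return lambda elem: loop.run_until_complete(async_func(elem))

async def is_even(n: int) -> bool:
    return n % 2 == 0

# an async predicate becomes usable by the plain sync `filter`
print(list(filter(syncify(is_even), range(6))))  # [0, 2, 4]
```
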
diff --git a/streamable/iterators.py b/streamable/iterators.py
index d1713a52..30385944 100644
--- a/streamable/iterators.py
+++ b/streamable/iterators.py
@@ -1,4 +1,3 @@
-import asyncio
import datetime
import multiprocessing
import queue
@@ -10,6 +9,9 @@
from math import ceil
from typing import (
Any,
+ AsyncIterable,
+ AsyncIterator,
+ Awaitable,
Callable,
ContextManager,
Coroutine,
@@ -29,7 +31,16 @@
cast,
)
-from streamable.util.functiontools import iter_wo_stopiteration, wrap_error
+from streamable.util.asynctools import (
+ GetEventLoopMixin,
+ awaitable_to_coroutine,
+ empty_aiter,
+)
+from streamable.util.functiontools import (
+ aiter_wo_stopiteration,
+ iter_wo_stopiteration,
+ wrap_error,
+)
from streamable.util.loggertools import get_logger
from streamable.util.validationtools import (
validate_base,
@@ -79,7 +90,7 @@ def __init__(
def __next__(self) -> T:
while True:
try:
- return next(self.iterator)
+ return self.iterator.__next__()
except StopIteration:
if self._to_be_finally_raised:
try:
@@ -108,7 +119,7 @@ def __init__(
def __next__(self) -> T:
while True:
- elem = next(self.iterator)
+ elem = self.iterator.__next__()
key = self.key(elem) if self.key else elem
if key not in self._already_seen:
break
@@ -127,7 +138,7 @@ def __init__(
def __next__(self) -> T:
while True:
- elem = next(self.iterator)
+ elem = self.iterator.__next__()
key = self.key(elem) if self.key else elem
if key != self._last_key:
break
@@ -139,14 +150,35 @@ class FlattenIterator(Iterator[U]):
def __init__(self, iterator: Iterator[Iterable[U]]) -> None:
validate_iterator(iterator)
self.iterator = iterator
- self._current_iterator_elem: Iterator[U] = iter(tuple())
+ self._current_iterator_elem: Iterator[U] = tuple().__iter__()
def __next__(self) -> U:
while True:
try:
- return next(self._current_iterator_elem)
+ return self._current_iterator_elem.__next__()
except StopIteration:
- self._current_iterator_elem = iter_wo_stopiteration(next(self.iterator))
+ self._current_iterator_elem = iter_wo_stopiteration(
+ self.iterator.__next__()
+ )
+
+
+class AFlattenIterator(Iterator[U], GetEventLoopMixin):
+ def __init__(self, iterator: Iterator[AsyncIterable[U]]) -> None:
+ validate_iterator(iterator)
+ self.iterator = iterator
+
+ self._current_iterator_elem: AsyncIterator[U] = empty_aiter()
+
+ def __next__(self) -> U:
+ while True:
+ try:
+ return self.get_event_loop().run_until_complete(
+ self._current_iterator_elem.__anext__()
+ )
+ except StopAsyncIteration:
+ self._current_iterator_elem = aiter_wo_stopiteration(
+ self.iterator.__next__()
+ )
class _GroupIteratorMixin(Generic[T]):
@@ -203,7 +235,7 @@ def __next__(self) -> List[T]:
while len(self._current_group) < self.size and (
not self._interval_seconds_have_elapsed() or not self._current_group
):
- self._current_group.append(next(self.iterator))
+ self._current_group.append(self.iterator.__next__())
except Exception as e:
if not self._current_group:
raise
@@ -228,7 +260,7 @@ def __init__(
self._groups_by: DefaultDict[U, List[T]] = defaultdict(list)
def _group_next_elem(self) -> None:
- elem = next(self.iterator)
+ elem = self.iterator.__next__()
self._groups_by[self.key(elem)].append(elem)
def _pop_full_group(self) -> Optional[Tuple[U, List[T]]]:
@@ -238,11 +270,11 @@ def _pop_full_group(self) -> Optional[Tuple[U, List[T]]]:
return None
def _pop_first_group(self) -> Tuple[U, List[T]]:
- first_key: U = next(iter(self._groups_by), cast(U, ...))
+ first_key: U = self._groups_by.__iter__().__next__()
return first_key, self._groups_by.pop(first_key)
def _pop_largest_group(self) -> Tuple[U, List[T]]:
- largest_group_key: Any = next(iter(self._groups_by), ...)
+ largest_group_key: Any = self._groups_by.__iter__().__next__()
for key, group in self._groups_by.items():
if len(group) > len(self._groups_by[largest_group_key]):
@@ -279,11 +311,11 @@ def __next__(self) -> Tuple[U, List[T]]:
except StopIteration:
self._is_exhausted = True
- return next(self)
+ return self.__next__()
except Exception as e:
self._to_be_raised = e
- return next(self)
+ return self.__next__()
class CountSkipIterator(Iterator[T]):
@@ -298,11 +330,11 @@ def __init__(self, iterator: Iterator[T], count: int) -> None:
def __next__(self) -> T:
if not self._done_skipping:
while self._n_skipped < self.count:
- next(self.iterator)
+ self.iterator.__next__()
# do not count exceptions as skipped elements
self._n_skipped += 1
self._done_skipping = True
- return next(self.iterator)
+ return self.iterator.__next__()
class PredicateSkipIterator(Iterator[T]):
@@ -313,10 +345,10 @@ def __init__(self, iterator: Iterator[T], until: Callable[[T], Any]) -> None:
self._done_skipping = False
def __next__(self) -> T:
- elem = next(self.iterator)
+ elem = self.iterator.__next__()
if not self._done_skipping:
while not self.until(elem):
- elem = next(self.iterator)
+ elem = self.iterator.__next__()
self._done_skipping = True
return elem
@@ -334,10 +366,10 @@ def __init__(
self._done_skipping = False
def __next__(self) -> T:
- elem = next(self.iterator)
+ elem = self.iterator.__next__()
if not self._done_skipping:
while self._n_skipped < self.count and not self.until(elem):
- elem = next(self.iterator)
+ elem = self.iterator.__next__()
# do not count exceptions as skipped elements
self._n_skipped += 1
self._done_skipping = True
@@ -355,7 +387,7 @@ def __init__(self, iterator: Iterator[T], count: int) -> None:
def __next__(self) -> T:
if self._current_count == self.count:
raise StopIteration()
- elem = next(self.iterator)
+ elem = self.iterator.__next__()
self._current_count += 1
return elem
@@ -370,7 +402,7 @@ def __init__(self, iterator: Iterator[T], when: Callable[[T], Any]) -> None:
def __next__(self) -> T:
if self._satisfied:
raise StopIteration()
- elem = next(self.iterator)
+ elem = self.iterator.__next__()
if self.when(elem):
self._satisfied = True
raise StopIteration()
@@ -406,7 +438,7 @@ def _log(self) -> None:
def __next__(self) -> T:
try:
- elem = next(self.iterator)
+ elem = self.iterator.__next__()
self._n_nexts += 1
self._n_yields += 1
return elem
@@ -442,7 +474,7 @@ def __init__(
def safe_next(self) -> Tuple[Optional[T], Optional[Exception]]:
try:
- return next(self.iterator), None
+ return self.iterator.__next__(), None
except StopIteration:
raise
except Exception as e:
@@ -483,13 +515,13 @@ def __init__(
self.iterator = iterator
def __next__(self) -> T:
- elem = next(self.iterator)
+ elem = self.iterator.__next__()
if isinstance(elem, self.ExceptionContainer):
raise elem.exception
return elem
-class _ConcurrentMapIterable(
+class _ConcurrentMapIterableMixin(
Generic[T, U], ABC, Iterable[Union[U, _RaisingIterator.ExceptionContainer]]
):
"""
@@ -536,17 +568,21 @@ def __iter__(self) -> Iterator[Union[U, _RaisingIterator.ExceptionContainer]]:
# queue tasks up to buffersize
with suppress(StopIteration):
while len(future_results) < self.buffersize:
- future_results.add_future(self._launch_task(next(self.iterator)))
+ future_results.add_future(
+ self._launch_task(self.iterator.__next__())
+ )
# wait, queue, yield
while future_results:
- result = next(future_results)
+ result = future_results.__next__()
with suppress(StopIteration):
- future_results.add_future(self._launch_task(next(self.iterator)))
+ future_results.add_future(
+ self._launch_task(self.iterator.__next__())
+ )
yield result
-class _OSConcurrentMapIterable(_ConcurrentMapIterable[T, U]):
+class _ConcurrentMapIterable(_ConcurrentMapIterableMixin[T, U]):
def __init__(
self,
iterator: Iterator[T],
@@ -597,7 +633,7 @@ def _future_result_collection(
)
-class OSConcurrentMapIterator(_RaisingIterator[U]):
+class ConcurrentMapIterator(_RaisingIterator[U]):
def __init__(
self,
iterator: Iterator[T],
@@ -608,20 +644,18 @@ def __init__(
via: "Literal['thread', 'process']",
) -> None:
super().__init__(
- iter(
- _OSConcurrentMapIterable(
- iterator,
- transformation,
- concurrency,
- buffersize,
- ordered,
- via,
- )
- )
+ _ConcurrentMapIterable(
+ iterator,
+ transformation,
+ concurrency,
+ buffersize,
+ ordered,
+ via,
+ ).__iter__()
)
-class _AsyncConcurrentMapIterable(_ConcurrentMapIterable[T, U]):
+class _ConcurrentAMapIterable(_ConcurrentMapIterableMixin[T, U], GetEventLoopMixin):
def __init__(
self,
iterator: Iterator[T],
@@ -631,12 +665,6 @@ def __init__(
) -> None:
super().__init__(iterator, buffersize, ordered)
self.transformation = wrap_error(transformation, StopIteration)
- self.event_loop: asyncio.AbstractEventLoop
- try:
- self.event_loop = asyncio.get_event_loop()
- except RuntimeError:
- self.event_loop = asyncio.new_event_loop()
- asyncio.set_event_loop(self.event_loop)
async def _safe_transformation(
self, elem: T
@@ -656,19 +684,19 @@ def _launch_task(
) -> "Future[Union[U, _RaisingIterator.ExceptionContainer]]":
return cast(
"Future[Union[U, _RaisingIterator.ExceptionContainer]]",
- self.event_loop.create_task(self._safe_transformation(elem)),
+ self.get_event_loop().create_task(self._safe_transformation(elem)),
)
def _future_result_collection(
self,
) -> FutureResultCollection[Union[U, _RaisingIterator.ExceptionContainer]]:
if self.ordered:
- return FIFOAsyncFutureResultCollection(self.event_loop)
+ return FIFOAsyncFutureResultCollection(self.get_event_loop())
else:
- return FDFOAsyncFutureResultCollection(self.event_loop)
+ return FDFOAsyncFutureResultCollection(self.get_event_loop())
-class AsyncConcurrentMapIterator(_RaisingIterator[U]):
+class ConcurrentAMapIterator(_RaisingIterator[U]):
def __init__(
self,
iterator: Iterator[T],
@@ -677,14 +705,12 @@ def __init__(
ordered: bool,
) -> None:
super().__init__(
- iter(
- _AsyncConcurrentMapIterable(
- iterator,
- transformation,
- buffersize,
- ordered,
- )
- )
+ _ConcurrentAMapIterable(
+ iterator,
+ transformation,
+ buffersize,
+ ordered,
+ ).__iter__()
)
@@ -727,7 +753,7 @@ def __iter__(self) -> Iterator[Union[T, _RaisingIterator.ExceptionContainer]]:
while len(iterator_and_future_pairs) < self.buffersize:
if not iterator_to_queue:
try:
- iterable = next(self.iterables_iterator)
+ iterable = self.iterables_iterator.__next__()
except StopIteration:
break
try:
@@ -752,11 +778,90 @@ def __init__(
buffersize: int,
) -> None:
super().__init__(
- iter(
- _ConcurrentFlattenIterable(
- iterables_iterator,
- concurrency,
- buffersize,
+ _ConcurrentFlattenIterable(
+ iterables_iterator,
+ concurrency,
+ buffersize,
+ ).__iter__()
+ )
+
+
+class _ConcurrentAFlattenIterable(
+ Iterable[Union[T, _RaisingIterator.ExceptionContainer]], GetEventLoopMixin
+):
+ def __init__(
+ self,
+ iterables_iterator: Iterator[AsyncIterable[T]],
+ concurrency: int,
+ buffersize: int,
+ ) -> None:
+ validate_iterator(iterables_iterator)
+ validate_concurrency(concurrency)
+ self.iterables_iterator = iterables_iterator
+ self.concurrency = concurrency
+ self.buffersize = buffersize
+
+ def __iter__(self) -> Iterator[Union[T, _RaisingIterator.ExceptionContainer]]:
+ iterator_and_future_pairs: Deque[Tuple[AsyncIterator[T], Awaitable[T]]] = (
+ deque()
+ )
+ element_to_yield: Deque[Union[T, _RaisingIterator.ExceptionContainer]] = deque(
+ maxlen=1
+ )
+ iterator_to_queue: Optional[AsyncIterator[T]] = None
+ # wait, queue, yield (FIFO)
+ while True:
+ if iterator_and_future_pairs:
+ iterator, future = iterator_and_future_pairs.popleft()
+ try:
+ element_to_yield.append(
+ self.get_event_loop().run_until_complete(future)
+ )
+ iterator_to_queue = iterator
+ except StopAsyncIteration:
+ pass
+ except Exception as e:
+ element_to_yield.append(_RaisingIterator.ExceptionContainer(e))
+ iterator_to_queue = iterator
+
+ # queue tasks up to buffersize
+ while len(iterator_and_future_pairs) < self.buffersize:
+ if not iterator_to_queue:
+ try:
+ iterable = self.iterables_iterator.__next__()
+ except StopIteration:
+ break
+ try:
+ iterator_to_queue = aiter_wo_stopiteration(iterable)
+ except Exception as e:
+ yield _RaisingIterator.ExceptionContainer(e)
+ continue
+ future = self.get_event_loop().create_task(
+ awaitable_to_coroutine(
+ cast(AsyncIterator, iterator_to_queue).__anext__()
+ )
)
- )
+ iterator_and_future_pairs.append(
+ (cast(AsyncIterator, iterator_to_queue), future)
+ )
+ iterator_to_queue = None
+ if element_to_yield:
+ yield element_to_yield.pop()
+ if not iterator_and_future_pairs:
+ break
+
+
+class ConcurrentAFlattenIterator(_RaisingIterator[T]):
+ def __init__(
+ self,
+ iterables_iterator: Iterator[AsyncIterable[T]],
+ concurrency: int,
+ buffersize: int,
+ ) -> None:
+ super().__init__(
+ _ConcurrentAFlattenIterable(
+ iterables_iterator,
+ concurrency,
+ buffersize,
+ ).__iter__()
)
diff --git a/streamable/stream.py b/streamable/stream.py
index d25693c2..de8f2444 100644
--- a/streamable/stream.py
+++ b/streamable/stream.py
@@ -5,10 +5,14 @@
from typing import (
TYPE_CHECKING,
Any,
+ AsyncIterable,
+ AsyncIterator,
+ Awaitable,
Callable,
Collection,
Coroutine,
Dict,
+ Generator,
Generic,
Iterable,
Iterator,
@@ -25,6 +29,7 @@
)
from streamable.util.constants import NO_REPLACEMENT
+from streamable.util.functiontools import asyncify
from streamable.util.loggertools import get_logger
from streamable.util.validationtools import (
validate_concurrency,
@@ -54,17 +59,29 @@
V = TypeVar("V")
-class Stream(Iterable[T]):
+class Stream(Iterable[T], AsyncIterable[T], Awaitable["Stream[T]"]):
__slots__ = ("_source", "_upstream")
# fmt: off
@overload
def __init__(self, source: Iterable[T]) -> None: ...
@overload
+ def __init__(self, source: AsyncIterable[T]) -> None: ...
+ @overload
def __init__(self, source: Callable[[], Iterable[T]]) -> None: ...
+ @overload
+ def __init__(self, source: Callable[[], AsyncIterable[T]]) -> None: ...
# fmt: on
- def __init__(self, source: Union[Iterable[T], Callable[[], Iterable[T]]]) -> None:
+ def __init__(
+ self,
+ source: Union[
+ Iterable[T],
+ Callable[[], Iterable[T]],
+ AsyncIterable[T],
+ Callable[[], AsyncIterable[T]],
+ ],
+ ) -> None:
"""
A `Stream[T]` decorates an `Iterable[T]` with a **fluent interface** enabling the chaining of lazy operations.
@@ -84,7 +101,11 @@ def upstream(self) -> "Optional[Stream]":
return self._upstream
@property
- def source(self) -> Union[Iterable, Callable[[], Iterable]]:
+ def source(
+ self,
+ ) -> Union[
+ Iterable, Callable[[], Iterable], AsyncIterable, Callable[[], AsyncIterable]
+ ]:
"""
Returns:
Callable[[], Iterable]: Function called at iteration time (i.e. by `__iter__`) to get a fresh source iterable.
@@ -103,6 +124,11 @@ def __iter__(self) -> Iterator[T]:
return self.accept(IteratorVisitor[T]())
+ def __aiter__(self) -> AsyncIterator[T]:
+ from streamable.visitors.aiterator import AsyncIteratorVisitor
+
+ return self.accept(AsyncIteratorVisitor[T]())
+
def __repr__(self) -> str:
from streamable.visitors.representation import ReprVisitor
@@ -134,6 +160,16 @@ def __call__(self) -> "Stream[T]":
self.count()
return self
+ def __await__(self) -> Generator[int, None, "Stream[T]"]:
+ """
+ Iterates over this stream until exhaustion.
+
+ Returns:
+ Stream[T]: self.
+ """
+ yield from self.acount().__await__()
+ return self
+
def accept(self, visitor: "Visitor[V]") -> V:
"""
Entry point to visit this stream (en.wikipedia.org/wiki/Visitor_pattern).
@@ -175,6 +211,40 @@ def catch(
finally_raise=finally_raise,
)
+ def acatch(
+ self,
+ errors: Union[
+ Optional[Type[Exception]], Iterable[Optional[Type[Exception]]]
+ ] = Exception,
+ *,
+ when: Optional[Callable[[Exception], Coroutine[Any, Any, Any]]] = None,
+ replacement: T = NO_REPLACEMENT, # type: ignore
+ finally_raise: bool = False,
+ ) -> "Stream[T]":
+ """
+ Catches the upstream exceptions if they are instances of `errors` type and they satisfy the `when` predicate.
+ Optionally yields a `replacement` value.
+ If any exception was caught during the iteration and `finally_raise=True`, the first caught exception will be raised when the iteration finishes.
+
+ Args:
+ errors (Optional[Type[Exception]], Iterable[Optional[Type[Exception]]], optional): The exception type to catch, or an iterable of exception types to catch (default: catches all `Exception`s)
+ when (Optional[Callable[[Exception], Coroutine[Any, Any, Any]]], optional): An additional condition that must be satisfied to catch the exception, i.e. `when(exception)` must be truthy. (default: no additional condition)
+ replacement (T, optional): The value to yield when an exception is caught. (default: do not yield any replacement value)
+ finally_raise (bool, optional): If True the first exception caught is raised when upstream's iteration ends. (default: iteration ends without raising)
+
+ Returns:
+ Stream[T]: A stream of upstream elements catching the eligible exceptions.
+ """
+ validate_errors(errors)
+ # validate_not_none(finally_raise, "finally_raise")
+ return ACatchStream(
+ self,
+ errors,
+ when=when,
+ replacement=replacement,
+ finally_raise=finally_raise,
+ )
+
def count(self) -> int:
"""
Iterates over this stream until exhaustion and returns the count of elements.
@@ -185,6 +255,18 @@ def count(self) -> int:
return sum(1 for _ in self)
+ async def acount(self) -> int:
+ """
+ Iterates over this stream until exhaustion and returns the count of elements.
+
+ Returns:
+ int: Number of elements yielded during an entire iteration over this stream.
+ """
+ count = 0
+ async for _ in self:
+ count += 1
+ return count
+
def display(self, level: int = logging.INFO) -> "Stream[T]":
"""
Logs (INFO level) a representation of the stream.
@@ -226,6 +308,33 @@ def distinct(
# validate_not_none(consecutive_only, "consecutive_only")
return DistinctStream(self, key, consecutive_only)
+ def adistinct(
+ self,
+ key: Optional[Callable[[T], Coroutine[Any, Any, Any]]] = None,
+ *,
+ consecutive_only: bool = False,
+ ) -> "Stream[T]":
+ """
+ Filters the stream to yield only distinct elements.
+ If a deduplication `key` is specified, `foo` and `bar` are treated as duplicates when `key(foo) == key(bar)`.
+
+ Among duplicates, the first encountered occurrence in upstream order is yielded.
+
+ Warning:
+ During iteration, the distinct elements yielded are retained in memory to perform deduplication.
+ Alternatively, remove only consecutive duplicates without memory footprint by setting `consecutive_only=True`.
+
+ Args:
+ key (Callable[[T], Coroutine[Any, Any, Any]], optional): Elements are deduplicated based on `key(elem)`. (default: the deduplication is performed on the elements themselves)
+ consecutive_only (bool, optional): Whether to deduplicate only consecutive duplicates, or globally. (default: the deduplication is global)
+
+ Returns:
+ Stream: A stream containing only unique upstream elements.
+ """
+ # validate_not_none(consecutive_only, "consecutive_only")
+ return ADistinctStream(self, key, consecutive_only)
+
def filter(self, when: Callable[[T], Any] = bool) -> "Stream[T]":
"""
Filters the stream to yield only elements satisfying the `when` predicate.
@@ -240,6 +349,24 @@ def filter(self, when: Callable[[T], Any] = bool) -> "Stream[T]":
# validate_not_none(when, "when")
return FilterStream(self, cast(Optional[Callable[[T], Any]], when) or bool)
+ def afilter(self, when: Callable[[T], Coroutine[Any, Any, Any]]) -> "Stream[T]":
+ """
+ Filters the stream to yield only elements satisfying the `when` predicate.
+
+ Args:
+ when (Callable[[T], Coroutine[Any, Any, Any]], optional): An element is kept if `when(elem)` is truthy. (default: keeps truthy elements)
+
+ Returns:
+ Stream[T]: A stream of upstream elements satisfying the `when` predicate.
+ """
+ # Unofficially accept `stream.afilter(None)`, behaving as builtin `filter(None, iter)`
+ # validate_not_none(when, "when")
+ return AFilterStream(
+ self,
+ cast(Optional[Callable[[T], Coroutine[Any, Any, Any]]], when)
+ or asyncify(bool),
+ )
+
# fmt: off
@overload
def flatten(
@@ -317,13 +444,43 @@ def flatten(self: "Stream[Iterable[U]]", *, concurrency: int = 1) -> "Stream[U]"
Iterates over upstream elements assumed to be iterables, and individually yields their items.
Args:
- concurrency (int, optional): Represents both the number of threads used to concurrently flatten the upstream iterables and the number of iterables buffered. (default: no concurrency)
+ concurrency (int, optional): Number of upstream iterables concurrently flattened via threads. (default: no concurrency)
Returns:
Stream[R]: A stream of flattened elements from upstream iterables.
"""
validate_concurrency(concurrency)
return FlattenStream(self, concurrency)
+ # fmt: off
+ @overload
+ def aflatten(
+ self: "Stream[AsyncIterator[U]]",
+ *,
+ concurrency: int = 1,
+ ) -> "Stream[U]": ...
+
+ @overload
+ def aflatten(
+ self: "Stream[AsyncIterable[U]]",
+ *,
+ concurrency: int = 1,
+ ) -> "Stream[U]": ...
+ # fmt: on
+
+ def aflatten(
+ self: "Stream[AsyncIterable[U]]", *, concurrency: int = 1
+ ) -> "Stream[U]":
+ """
+ Iterates over upstream elements assumed to be async iterables, and individually yields their items.
+
+ Args:
+ concurrency (int, optional): Number of upstream async iterables concurrently flattened. (default: no concurrency)
+ Returns:
+ Stream[U]: A stream of flattened elements from upstream async iterables.
+ """
+ validate_concurrency(concurrency)
+ return AFlattenStream(self, concurrency)
+
def foreach(
self,
effect: Callable[[T], Any],
@@ -352,7 +509,7 @@ def foreach(
def aforeach(
self,
- effect: Callable[[T], Coroutine],
+ effect: Callable[[T], Coroutine[Any, Any, Any]],
*,
concurrency: int = 1,
ordered: bool = True,
@@ -362,7 +519,7 @@ def aforeach(
If the `effect(elem)` coroutine throws an exception then it will be thrown and `elem` will not be yielded.
Args:
- effect (Callable[[T], Any]): The asynchronous function to be applied to each element as a side effect.
+ effect (Callable[[T], Coroutine[Any, Any, Any]]): The asynchronous function to be applied to each element as a side effect.
concurrency (int, optional): Represents both the number of async tasks concurrently applying the `effect` and the size of the buffer containing not-yet-yielded elements. If the buffer is full, the iteration over the upstream is paused until an element is yielded from the buffer. (default: no concurrency)
ordered (bool, optional): If `concurrency` > 1, whether to preserve the order of upstream elements or to yield them as soon as they are processed. (default: preserves upstream order)
Returns:
@@ -401,6 +558,34 @@ def group(
validate_optional_positive_interval(interval)
return GroupStream(self, size, interval, by)
+ def agroup(
+ self,
+ size: Optional[int] = None,
+ *,
+ interval: Optional[datetime.timedelta] = None,
+ by: Optional[Callable[[T], Coroutine[Any, Any, Any]]] = None,
+ ) -> "Stream[List[T]]":
+ """
+ Groups upstream elements into lists.
+
+ A group is yielded when any of the following conditions is met:
+ - The group reaches `size` elements.
+ - `interval` seconds have passed since the last group was yielded.
+ - The upstream source is exhausted.
+
+ If `by` is specified, groups will only contain elements sharing the same `by(elem)` value (see `.agroupby` for `(key, elements)` pairs).
+
+ Args:
+ size (Optional[int], optional): The maximum number of elements per group. (default: no size limit)
+ interval (Optional[datetime.timedelta], optional): Yields a group if `interval` seconds have passed since the last group was yielded. (default: no interval limit)
+ by (Optional[Callable[[T], Coroutine[Any, Any, Any]]], optional): If specified, groups will only contain elements sharing the same `by(elem)` value. (default: does not co-group elements)
+ Returns:
+ Stream[List[T]]: A stream of upstream elements grouped into lists.
+ """
+ validate_group_size(size)
+ validate_optional_positive_interval(interval)
+ return AGroupStream(self, size, interval, by)
+
def groupby(
self,
key: Callable[[T], U],
@@ -427,6 +612,32 @@ def groupby(
# validate_not_none(key, "key")
return GroupbyStream(self, key, size, interval)
+ def agroupby(
+ self,
+ key: Callable[[T], Coroutine[Any, Any, U]],
+ *,
+ size: Optional[int] = None,
+ interval: Optional[datetime.timedelta] = None,
+ ) -> "Stream[Tuple[U, List[T]]]":
+ """
+ Groups upstream elements into `(key, elements)` tuples.
+
+ A group is yielded when any of the following conditions is met:
+ - A group reaches `size` elements.
+ - `interval` seconds have passed since the last group was yielded.
+ - The upstream source is exhausted.
+
+ Args:
+ key (Callable[[T], Coroutine[Any, Any, U]]): An async function that returns the group key for an element.
+ size (Optional[int], optional): The maximum number of elements per group. (default: no size limit)
+ interval (Optional[datetime.timedelta], optional): If specified, yields a group if `interval` seconds have passed since the last group was yielded. (default: no interval limit)
+
+ Returns:
+ Stream[Tuple[U, List[T]]]: A stream of upstream elements grouped by key, as `(key, elements)` tuples.
+ """
+ # validate_not_none(key, "key")
+ return AGroupbyStream(self, key, size, interval)
+
def map(
self,
transformation: Callable[[T], U],
@@ -531,6 +742,26 @@ def skip(
validate_optional_count(count)
return SkipStream(self, count, until)
+ def askip(
+ self,
+ count: Optional[int] = None,
+ *,
+ until: Optional[Callable[[T], Coroutine[Any, Any, Any]]] = None,
+ ) -> "Stream[T]":
+ """
+ Skips elements until `until(elem)` is truthy, or `count` elements have been skipped.
+ If both `count` and `until` are set, skipping stops as soon as either condition is met.
+
+ Args:
+ count (Optional[int], optional): The maximum number of elements to skip. (default: no count-based skipping)
+ until (Optional[Callable[[T], Coroutine[Any, Any, Any]]], optional): Elements are skipped until the first one for which `until(elem)` is truthy. This element and all the subsequent ones will be yielded. (default: no predicate-based skipping)
+
+ Returns:
+ Stream: A stream of the upstream elements remaining after skipping.
+ """
+ validate_optional_count(count)
+ return ASkipStream(self, count, until)
+
def throttle(
self,
count: Optional[int] = None,
@@ -604,6 +835,26 @@ def truncate(
validate_optional_count(count)
return TruncateStream(self, count, when)
+ def atruncate(
+ self,
+ count: Optional[int] = None,
+ *,
+ when: Optional[Callable[[T], Coroutine[Any, Any, Any]]] = None,
+ ) -> "Stream[T]":
+ """
+ Stops an iteration as soon as `when(elem)` is truthy, or `count` elements have been yielded.
+ If both `count` and `when` are set, truncation occurs as soon as either condition is met.
+
+ Args:
+ count (int, optional): The maximum number of elements to yield. (default: no count-based truncation)
+ when (Optional[Callable[[T], Coroutine[Any, Any, Any]]], optional): An async predicate function that determines when to stop the iteration. Iteration stops immediately after encountering the first element for which `when(elem)` is truthy, and that element will not be yielded. (default: no predicate-based truncation)
+
+ Returns:
+ Stream[T]: A stream of at most `count` upstream elements not satisfying the `when` predicate.
+ """
+ validate_optional_count(count)
+ return ATruncateStream(self, count, when)
+
class DownStream(Stream[U], Generic[T, U]):
"""
@@ -621,7 +872,11 @@ def __deepcopy__(self, memo: Dict[int, Any]) -> "DownStream[T, U]":
return new
@property
- def source(self) -> Union[Iterable, Callable[[], Iterable]]:
+ def source(
+ self,
+ ) -> Union[
+ Iterable, Callable[[], Iterable], AsyncIterable, Callable[[], AsyncIterable]
+ ]:
return self._upstream.source
@property
@@ -654,6 +909,27 @@ def accept(self, visitor: "Visitor[V]") -> V:
return visitor.visit_catch_stream(self)
+class ACatchStream(DownStream[T, T]):
+ __slots__ = ("_upstream", "_errors", "_when", "_replacement", "_finally_raise")
+
+ def __init__(
+ self,
+ upstream: Stream[T],
+ errors: Union[Optional[Type[Exception]], Iterable[Optional[Type[Exception]]]],
+ when: Optional[Callable[[Exception], Coroutine[Any, Any, Any]]],
+ replacement: T,
+ finally_raise: bool,
+ ) -> None:
+ super().__init__(upstream)
+ self._errors = errors
+ self._when = when
+ self._replacement = replacement
+ self._finally_raise = finally_raise
+
+ def accept(self, visitor: "Visitor[V]") -> V:
+ return visitor.visit_acatch_stream(self)
+
+
class DistinctStream(DownStream[T, T]):
__slots__ = ("_upstream", "_key", "_consecutive_only")
@@ -671,6 +947,23 @@ def accept(self, visitor: "Visitor[V]") -> V:
return visitor.visit_distinct_stream(self)
+class ADistinctStream(DownStream[T, T]):
+ __slots__ = ("_upstream", "_key", "_consecutive_only")
+
+ def __init__(
+ self,
+ upstream: Stream[T],
+ key: Optional[Callable[[T], Coroutine[Any, Any, Any]]],
+ consecutive_only: bool,
+ ) -> None:
+ super().__init__(upstream)
+ self._key = key
+ self._consecutive_only = consecutive_only
+
+ def accept(self, visitor: "Visitor[V]") -> V:
+ return visitor.visit_adistinct_stream(self)
+
+
class FilterStream(DownStream[T, T]):
__slots__ = ("_upstream", "_when")
@@ -682,6 +975,19 @@ def accept(self, visitor: "Visitor[V]") -> V:
return visitor.visit_filter_stream(self)
+class AFilterStream(DownStream[T, T]):
+ __slots__ = ("_upstream", "_when")
+
+ def __init__(
+ self, upstream: Stream[T], when: Callable[[T], Coroutine[Any, Any, Any]]
+ ) -> None:
+ super().__init__(upstream)
+ self._when = when
+
+ def accept(self, visitor: "Visitor[V]") -> V:
+ return visitor.visit_afilter_stream(self)
+
+
class FlattenStream(DownStream[Iterable[T], T]):
__slots__ = ("_upstream", "_concurrency")
@@ -693,6 +999,17 @@ def accept(self, visitor: "Visitor[V]") -> V:
return visitor.visit_flatten_stream(self)
+class AFlattenStream(DownStream[AsyncIterable[T], T]):
+ __slots__ = ("_upstream", "_concurrency")
+
+ def __init__(self, upstream: Stream[AsyncIterable[T]], concurrency: int) -> None:
+ super().__init__(upstream)
+ self._concurrency = concurrency
+
+ def accept(self, visitor: "Visitor[V]") -> V:
+ return visitor.visit_aflatten_stream(self)
+
+
class ForeachStream(DownStream[T, T]):
__slots__ = ("_upstream", "_effect", "_concurrency", "_ordered", "_via")
@@ -752,6 +1069,25 @@ def accept(self, visitor: "Visitor[V]") -> V:
return visitor.visit_group_stream(self)
+class AGroupStream(DownStream[T, List[T]]):
+ __slots__ = ("_upstream", "_size", "_interval", "_by")
+
+ def __init__(
+ self,
+ upstream: Stream[T],
+ size: Optional[int],
+ interval: Optional[datetime.timedelta],
+ by: Optional[Callable[[T], Coroutine[Any, Any, Any]]],
+ ) -> None:
+ super().__init__(upstream)
+ self._size = size
+ self._interval = interval
+ self._by = by
+
+ def accept(self, visitor: "Visitor[V]") -> V:
+ return visitor.visit_agroup_stream(self)
+
+
class GroupbyStream(DownStream[T, Tuple[U, List[T]]]):
__slots__ = ("_upstream", "_key", "_size", "_interval")
@@ -771,6 +1107,25 @@ def accept(self, visitor: "Visitor[V]") -> V:
return visitor.visit_groupby_stream(self)
+class AGroupbyStream(DownStream[T, Tuple[U, List[T]]]):
+ __slots__ = ("_upstream", "_key", "_size", "_interval")
+
+ def __init__(
+ self,
+ upstream: Stream[T],
+ key: Callable[[T], Coroutine[Any, Any, U]],
+ size: Optional[int],
+ interval: Optional[datetime.timedelta],
+ ) -> None:
+ super().__init__(upstream)
+ self._key = key
+ self._size = size
+ self._interval = interval
+
+ def accept(self, visitor: "Visitor[V]") -> V:
+ return visitor.visit_agroupby_stream(self)
+
+
class MapStream(DownStream[T, U]):
__slots__ = ("_upstream", "_transformation", "_concurrency", "_ordered", "_via")
@@ -839,6 +1194,23 @@ def accept(self, visitor: "Visitor[V]") -> V:
return visitor.visit_skip_stream(self)
+class ASkipStream(DownStream[T, T]):
+ __slots__ = ("_upstream", "_count", "_until")
+
+ def __init__(
+ self,
+ upstream: Stream[T],
+ count: Optional[int],
+ until: Optional[Callable[[T], Coroutine[Any, Any, Any]]],
+ ) -> None:
+ super().__init__(upstream)
+ self._count = count
+ self._until = until
+
+ def accept(self, visitor: "Visitor[V]") -> V:
+ return visitor.visit_askip_stream(self)
+
+
class ThrottleStream(DownStream[T, T]):
__slots__ = ("_upstream", "_count", "_per")
@@ -871,3 +1243,20 @@ def __init__(
def accept(self, visitor: "Visitor[V]") -> V:
return visitor.visit_truncate_stream(self)
+
+
+class ATruncateStream(DownStream[T, T]):
+ __slots__ = ("_upstream", "_count", "_when")
+
+ def __init__(
+ self,
+ upstream: Stream[T],
+ count: Optional[int] = None,
+ when: Optional[Callable[[T], Coroutine[Any, Any, Any]]] = None,
+ ) -> None:
+ super().__init__(upstream)
+ self._count = count
+ self._when = when
+
+ def accept(self, visitor: "Visitor[V]") -> V:
+ return visitor.visit_atruncate_stream(self)
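
With these additions a `Stream` can be consumed with `async for` (via `__aiter__`) and directly awaited (`__await__` exhausts it through `acount()` and returns the stream itself). A usage sketch, assuming this change is installed:

```python
import asyncio
from streamable import Stream

async def double(n: int) -> int:
    return n * 2

async def main() -> None:
    doubled: Stream[int] = Stream(range(3)).amap(double)
    # async iteration over the stream
    async for elem in doubled:
        print(elem)  # 0, 2, 4
    # awaiting a stream iterates it until exhaustion and returns it
    assert (await doubled) is doubled

asyncio.run(main())
```
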
diff --git a/streamable/util/asynctools.py b/streamable/util/asynctools.py
new file mode 100644
index 00000000..439445e5
--- /dev/null
+++ b/streamable/util/asynctools.py
@@ -0,0 +1,26 @@
+import asyncio
+from typing import AsyncIterator, Awaitable, Optional, TypeVar
+
+T = TypeVar("T")
+
+
+async def awaitable_to_coroutine(aw: Awaitable[T]) -> T:
+ return await aw
+
+
+async def empty_aiter() -> AsyncIterator:
+ # an empty async generator: returning before the first yield yields nothing
+ return
+ yield
+
+
+class GetEventLoopMixin:
+ _EVENT_LOOP_SINGLETON: Optional[asyncio.AbstractEventLoop] = None
+
+ @classmethod
+ def get_event_loop(cls) -> asyncio.AbstractEventLoop:
+ try:
+ return asyncio.get_running_loop()
+ except RuntimeError:
+ if not cls._EVENT_LOOP_SINGLETON:
+ cls._EVENT_LOOP_SINGLETON = asyncio.new_event_loop()
+ return cls._EVENT_LOOP_SINGLETON
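
`get_event_loop` prefers the currently running loop and otherwise lazily creates a single process-wide loop, so the sync entry points (`syncify`, `AsyncToSyncIterator`, `AFlattenIterator`, ...) all drive their coroutines on the same loop. A minimal demonstration with a hypothetical subclass:

```python
from streamable.util.asynctools import GetEventLoopMixin

class NeedsLoop(GetEventLoopMixin):
    pass

# called from sync code (no running loop): the same singleton loop is returned
assert NeedsLoop.get_event_loop() is NeedsLoop.get_event_loop()
```
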
diff --git a/streamable/util/errors.py b/streamable/util/errors.py
new file mode 100644
index 00000000..3f5bcf54
--- /dev/null
+++ b/streamable/util/errors.py
@@ -0,0 +1,4 @@
+class WrappedError(Exception):
+ def __init__(self, error: Exception):
+ super().__init__(repr(error))
+ self.error = error
diff --git a/streamable/util/functiontools.py b/streamable/util/functiontools.py
index 3925d84e..ef8ca70e 100644
--- a/streamable/util/functiontools.py
+++ b/streamable/util/functiontools.py
@@ -1,15 +1,23 @@
-from typing import Any, Callable, Coroutine, Generic, Tuple, Type, TypeVar, overload
+from functools import partial
+from typing import (
+ Any,
+ AsyncIterable,
+ Callable,
+ Coroutine,
+ Generic,
+ Tuple,
+ Type,
+ TypeVar,
+ overload,
+)
+
+from streamable.util.asynctools import GetEventLoopMixin
+from streamable.util.errors import WrappedError
T = TypeVar("T")
R = TypeVar("R")
-class WrappedError(Exception):
- def __init__(self, error: Exception):
- super().__init__(repr(error))
- self.error = error
-
-
class _ErrorWrappingDecorator(Generic[T, R]):
def __init__(self, func: Callable[[T], R], error_type: Type[Exception]) -> None:
self.func = func
@@ -26,7 +34,33 @@ def wrap_error(func: Callable[[T], R], error_type: Type[Exception]) -> Callable[
return _ErrorWrappingDecorator(func, error_type)
+def awrap_error(
+ async_func: Callable[[T], Coroutine[Any, Any, R]], error_type: Type[Exception]
+) -> Callable[[T], Coroutine[Any, Any, R]]:
+ async def wrap(elem: T) -> R:
+ try:
+ coroutine = async_func(elem)
+ if not isinstance(coroutine, Coroutine):
+ raise TypeError(
+ f"must be an async function i.e. a function returning a Coroutine but it returned a {type(coroutine)}"
+ )
+ return await coroutine
+ except error_type as e:
+ raise WrappedError(e) from e
+
+ return wrap
+
+
iter_wo_stopiteration = wrap_error(iter, StopIteration)
+iter_wo_stopasynciteration = wrap_error(iter, StopAsyncIteration)
+
+
+def _aiter(aiterable: AsyncIterable):
+ return aiterable.__aiter__()
+
+
+aiter_wo_stopasynciteration = wrap_error(_aiter, StopAsyncIteration)
+aiter_wo_stopiteration = wrap_error(_aiter, StopIteration)
class _Sidify(Generic[T]):
@@ -115,3 +149,28 @@ def add(a: int, b: int) -> int:
```
"""
return _Star(func)
+
+
+class _Syncify(Generic[R], GetEventLoopMixin):
+ def __init__(self, async_func: Callable[[T], Coroutine[Any, Any, R]]) -> None:
+ self.async_func = async_func
+
+ def __call__(self, elem: T) -> R:
+ coroutine = self.async_func(elem)
+ if not isinstance(coroutine, Coroutine):
+ raise TypeError(
+ f"must be an async function i.e. a function returning a Coroutine but it returned a {type(coroutine)}"
+ )
+ return self.get_event_loop().run_until_complete(coroutine)
+
+
+def syncify(async_func: Callable[[T], Coroutine[Any, Any, R]]) -> Callable[[T], R]:
+ return _Syncify(async_func)
+
+
+async def _async_call(func: Callable[[T], R], o: T) -> R:
+ return func(o)
+
+
+def asyncify(func: Callable[[T], R]) -> Callable[[T], Coroutine[Any, Any, R]]:
+ return partial(_async_call, func)
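
`asyncify` is the inverse bridge of `syncify`: it lifts a plain function into a coroutine function, which is how `Stream.afilter` supplies `asyncify(bool)` as its default predicate. For instance:

```python
import asyncio
from streamable.util.functiontools import asyncify

# lift the builtin `len` into an async function
async_len = asyncify(len)
print(asyncio.run(async_len("abc")))  # 3
```
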
diff --git a/streamable/util/futuretools.py b/streamable/util/futuretools.py
index ec551973..9b7af7d8 100644
--- a/streamable/util/futuretools.py
+++ b/streamable/util/futuretools.py
@@ -3,7 +3,7 @@
from collections import deque
from concurrent.futures import Future
from contextlib import suppress
-from typing import Awaitable, Deque, Iterator, Sized, Type, TypeVar, cast
+from typing import AsyncIterator, Awaitable, Deque, Iterator, Sized, Type, TypeVar, cast
with suppress(ImportError):
from streamable.util.protocols import Queue
@@ -11,7 +11,7 @@
T = TypeVar("T")
-class FutureResultCollection(Iterator[T], Sized, ABC):
+class FutureResultCollection(Iterator[T], AsyncIterator[T], Sized, ABC):
"""
Iterator over added futures' results. Supports adding new futures after iteration started.
"""
@@ -19,6 +19,9 @@ class FutureResultCollection(Iterator[T], Sized, ABC):
@abstractmethod
def add_future(self, future: "Future[T]") -> None: ...
+ async def __anext__(self) -> T:
+ return self.__next__()
+
class DequeFutureResultCollection(FutureResultCollection[T]):
def __init__(self) -> None:
@@ -87,6 +90,9 @@ def __next__(self) -> T:
cast(Awaitable[T], self._futures.popleft())
)
+ async def __anext__(self) -> T:
+ return await cast(Awaitable[T], self._futures.popleft())
+
class FDFOAsyncFutureResultCollection(CallbackFutureResultCollection[T]):
"""
@@ -106,3 +112,9 @@ def __next__(self) -> T:
self._n_futures -= 1
self._waiter = self.event_loop.create_future()
return result
+
+ async def __anext__(self) -> T:
+ result = await self._waiter
+ self._n_futures -= 1
+ self._waiter = self.event_loop.create_future()
+ return result
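
The two async collections differ only in ordering: `FIFOAsyncFutureResultCollection` resolves futures in submission order, while `FDFOAsyncFutureResultCollection` resolves them first-done-first-out. The same distinction in plain `asyncio` terms (a sketch, not the library's classes):

```python
import asyncio

async def work(label: str, delay: float) -> str:
    await asyncio.sleep(delay)
    return label

async def main() -> None:
    # FIFO: results in submission order, like gather
    print(await asyncio.gather(work("slow", 0.2), work("fast", 0.1)))
    # -> ['slow', 'fast']
    # FDFO: results in completion order, like as_completed
    print([await t for t in asyncio.as_completed(
        [work("slow", 0.2), work("fast", 0.1)])])
    # -> ['fast', 'slow']

asyncio.run(main())
```
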
diff --git a/streamable/util/iterabletools.py b/streamable/util/iterabletools.py
new file mode 100644
index 00000000..6d46ef66
--- /dev/null
+++ b/streamable/util/iterabletools.py
@@ -0,0 +1,62 @@
+from typing import (
+ AsyncIterable,
+ AsyncIterator,
+ Callable,
+ Iterable,
+ Iterator,
+ TypeVar,
+)
+
+from streamable.util.asynctools import GetEventLoopMixin
+
+T = TypeVar("T")
+
+
+class BiIterable(Iterable[T], AsyncIterable[T]):
+ pass
+
+
+class BiIterator(Iterator[T], AsyncIterator[T]):
+ pass
+
+
+class SyncToBiIterable(BiIterable[T]):
+ def __init__(self, iterable: Iterable[T]):
+ self.iterable = iterable
+
+ def __iter__(self) -> Iterator[T]:
+ return self.iterable.__iter__()
+
+ def __aiter__(self) -> AsyncIterator[T]:
+ return SyncToAsyncIterator(self.iterable)
+
+
+sync_to_bi_iterable: Callable[[Iterable[T]], BiIterable[T]] = SyncToBiIterable
+
+
+class SyncToAsyncIterator(AsyncIterator[T]):
+ def __init__(self, iterator: Iterable[T]):
+ self.iterator: Iterator[T] = iterator.__iter__()
+
+ async def __anext__(self) -> T:
+ try:
+ return self.iterator.__next__()
+ except StopIteration as e:
+ raise StopAsyncIteration() from e
+
+
+sync_to_async_iter: Callable[[Iterable[T]], AsyncIterator[T]] = SyncToAsyncIterator
+
+
+class AsyncToSyncIterator(Iterator[T], GetEventLoopMixin):
+ def __init__(self, iterator: AsyncIterable[T]):
+ self.iterator: AsyncIterator[T] = iterator.__aiter__()
+
+ def __next__(self) -> T:
+ try:
+ return self.get_event_loop().run_until_complete(self.iterator.__anext__())
+ except StopAsyncIteration as e:
+ raise StopIteration() from e
+
+
+async_to_sync_iter: Callable[[AsyncIterable[T]], Iterator[T]] = AsyncToSyncIterator
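
These adapters are what let the operators run against either world: `sync_to_async_iter` wraps a plain iterable for consumption with `async for`, and `async_to_sync_iter` drives an async iterator from sync code on the mixin's loop. For example:

```python
import asyncio
from typing import AsyncIterator
from streamable.util.iterabletools import async_to_sync_iter, sync_to_async_iter

async def consume() -> None:
    # a plain list consumed asynchronously
    async for x in sync_to_async_iter([1, 2, 3]):
        print(x)  # 1, 2, 3

asyncio.run(consume())

async def letters() -> AsyncIterator[str]:
    yield "a"
    yield "b"

# an async generator consumed synchronously
print(list(async_to_sync_iter(letters())))  # ['a', 'b']
```
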
diff --git a/streamable/util/validationtools.py b/streamable/util/validationtools.py
index 15fc399b..375f0e12 100644
--- a/streamable/util/validationtools.py
+++ b/streamable/util/validationtools.py
@@ -1,6 +1,14 @@
import datetime
from contextlib import suppress
-from typing import Any, Iterable, Iterator, Optional, Type, TypeVar, Union
+from typing import (
+ AsyncIterator,
+ Iterable,
+ Iterator,
+ Optional,
+ Type,
+ TypeVar,
+ Union,
+)
with suppress(ImportError):
from typing import Literal
@@ -13,6 +21,13 @@ def validate_iterator(iterator: Iterator):
raise TypeError(f"`iterator` must be an Iterator but got a {type(iterator)}")
+def validate_aiterator(iterator: AsyncIterator):
+ if not isinstance(iterator, AsyncIterator):
+ raise TypeError(
+ f"`iterator` must be an AsyncIterator but got a {type(iterator)}"
+ )
+
+
def validate_base(base: int):
if base <= 0:
raise ValueError(f"`base` must be > 0 but got {base}")
@@ -69,7 +84,7 @@ def validate_optional_positive_count(count: Optional[int]):
# def validate_not_none(value: Any, name: str) -> None:
# if value is None:
-# raise TypeError(f"`{name}` cannot be None")
+# raise TypeError(f"`{name}` must not be None")
def validate_errors(
diff --git a/streamable/visitors/__init__.py b/streamable/visitors/__init__.py
index 1234b6b4..ba67ca01 100644
--- a/streamable/visitors/__init__.py
+++ b/streamable/visitors/__init__.py
@@ -1 +1,3 @@
from streamable.visitors.base import Visitor
+
+__all__ = ["Visitor"]
diff --git a/streamable/visitors/aiterator.py b/streamable/visitors/aiterator.py
new file mode 100644
index 00000000..4cfe959e
--- /dev/null
+++ b/streamable/visitors/aiterator.py
@@ -0,0 +1,227 @@
+from typing import AsyncIterable, AsyncIterator, Iterable, TypeVar, cast
+
+from streamable import afunctions
+from streamable.stream import (
+ ACatchStream,
+ ADistinctStream,
+ AFilterStream,
+ AFlattenStream,
+ AForeachStream,
+ AGroupbyStream,
+ AGroupStream,
+ AMapStream,
+ ASkipStream,
+ ATruncateStream,
+ CatchStream,
+ DistinctStream,
+ FilterStream,
+ FlattenStream,
+ ForeachStream,
+ GroupbyStream,
+ GroupStream,
+ MapStream,
+ ObserveStream,
+ SkipStream,
+ Stream,
+ ThrottleStream,
+ TruncateStream,
+)
+from streamable.util.functiontools import async_sidify, sidify
+from streamable.util.iterabletools import sync_to_async_iter
+from streamable.visitors import Visitor
+
+T = TypeVar("T")
+U = TypeVar("U")
+
+
+class AsyncIteratorVisitor(Visitor[AsyncIterator[T]]):
+ def visit_catch_stream(self, stream: CatchStream[T]) -> AsyncIterator[T]:
+ return afunctions.catch(
+ stream.upstream.accept(self),
+ stream._errors,
+ when=stream._when,
+ replacement=stream._replacement,
+ finally_raise=stream._finally_raise,
+ )
+
+ def visit_acatch_stream(self, stream: ACatchStream[T]) -> AsyncIterator[T]:
+ return afunctions.acatch(
+ stream.upstream.accept(self),
+ stream._errors,
+ when=stream._when,
+ replacement=stream._replacement,
+ finally_raise=stream._finally_raise,
+ )
+
+ def visit_distinct_stream(self, stream: DistinctStream[T]) -> AsyncIterator[T]:
+ return afunctions.distinct(
+ stream.upstream.accept(self),
+ stream._key,
+ consecutive_only=stream._consecutive_only,
+ )
+
+ def visit_adistinct_stream(self, stream: ADistinctStream[T]) -> AsyncIterator[T]:
+ return afunctions.adistinct(
+ stream.upstream.accept(self),
+ stream._key,
+ consecutive_only=stream._consecutive_only,
+ )
+
+ def visit_filter_stream(self, stream: FilterStream[T]) -> AsyncIterator[T]:
+ return afunctions.filter(stream.upstream.accept(self), stream._when)
+
+ def visit_afilter_stream(self, stream: AFilterStream[T]) -> AsyncIterator[T]:
+ return afunctions.afilter(stream.upstream.accept(self), stream._when)
+
+ def visit_flatten_stream(self, stream: FlattenStream[T]) -> AsyncIterator[T]:
+ return afunctions.flatten(
+ stream.upstream.accept(AsyncIteratorVisitor[Iterable]()),
+ concurrency=stream._concurrency,
+ )
+
+ def visit_aflatten_stream(self, stream: AFlattenStream[T]) -> AsyncIterator[T]:
+ return afunctions.aflatten(
+ stream.upstream.accept(AsyncIteratorVisitor[AsyncIterable]()),
+ concurrency=stream._concurrency,
+ )
+
+ def visit_foreach_stream(self, stream: ForeachStream[T]) -> AsyncIterator[T]:
+ return self.visit_map_stream(
+ MapStream(
+ stream.upstream,
+ sidify(stream._effect),
+ stream._concurrency,
+ stream._ordered,
+ stream._via,
+ )
+ )
+
+ def visit_aforeach_stream(self, stream: AForeachStream[T]) -> AsyncIterator[T]:
+ return self.visit_amap_stream(
+ AMapStream(
+ stream.upstream,
+ async_sidify(stream._effect),
+ stream._concurrency,
+ stream._ordered,
+ )
+ )
+
+ def visit_group_stream(self, stream: GroupStream[U]) -> AsyncIterator[T]:
+ return cast(
+ AsyncIterator[T],
+ afunctions.group(
+ stream.upstream.accept(AsyncIteratorVisitor[U]()),
+ stream._size,
+ interval=stream._interval,
+ by=stream._by,
+ ),
+ )
+
+ def visit_agroup_stream(self, stream: AGroupStream[U]) -> AsyncIterator[T]:
+ return cast(
+ AsyncIterator[T],
+ afunctions.agroup(
+ stream.upstream.accept(AsyncIteratorVisitor[U]()),
+ stream._size,
+ interval=stream._interval,
+ by=stream._by,
+ ),
+ )
+
+ def visit_groupby_stream(self, stream: GroupbyStream[U, T]) -> AsyncIterator[T]:
+ return cast(
+ AsyncIterator[T],
+ afunctions.groupby(
+ stream.upstream.accept(AsyncIteratorVisitor[U]()),
+ stream._key,
+ size=stream._size,
+ interval=stream._interval,
+ ),
+ )
+
+ def visit_agroupby_stream(self, stream: AGroupbyStream[U, T]) -> AsyncIterator[T]:
+ return cast(
+ AsyncIterator[T],
+ afunctions.agroupby(
+ stream.upstream.accept(AsyncIteratorVisitor[U]()),
+ stream._key,
+ size=stream._size,
+ interval=stream._interval,
+ ),
+ )
+
+ def visit_map_stream(self, stream: MapStream[U, T]) -> AsyncIterator[T]:
+ return afunctions.map(
+ stream._transformation,
+ stream.upstream.accept(AsyncIteratorVisitor[U]()),
+ concurrency=stream._concurrency,
+ ordered=stream._ordered,
+ via=stream._via,
+ )
+
+ def visit_amap_stream(self, stream: AMapStream[U, T]) -> AsyncIterator[T]:
+ return afunctions.amap(
+ stream._transformation,
+ stream.upstream.accept(AsyncIteratorVisitor[U]()),
+ concurrency=stream._concurrency,
+ ordered=stream._ordered,
+ )
+
+ def visit_observe_stream(self, stream: ObserveStream[T]) -> AsyncIterator[T]:
+ return afunctions.observe(
+ stream.upstream.accept(self),
+ stream._what,
+ )
+
+ def visit_skip_stream(self, stream: SkipStream[T]) -> AsyncIterator[T]:
+ return afunctions.skip(
+ stream.upstream.accept(self),
+ stream._count,
+ until=stream._until,
+ )
+
+ def visit_askip_stream(self, stream: ASkipStream[T]) -> AsyncIterator[T]:
+ return afunctions.askip(
+ stream.upstream.accept(self),
+ stream._count,
+ until=stream._until,
+ )
+
+ def visit_throttle_stream(self, stream: ThrottleStream[T]) -> AsyncIterator[T]:
+ return afunctions.throttle(
+ stream.upstream.accept(self),
+ stream._count,
+ per=stream._per,
+ )
+
+ def visit_truncate_stream(self, stream: TruncateStream[T]) -> AsyncIterator[T]:
+ return afunctions.truncate(
+ stream.upstream.accept(self),
+ stream._count,
+ when=stream._when,
+ )
+
+ def visit_atruncate_stream(self, stream: ATruncateStream[T]) -> AsyncIterator[T]:
+ return afunctions.atruncate(
+ stream.upstream.accept(self),
+ stream._count,
+ when=stream._when,
+ )
+
+ def visit_stream(self, stream: Stream[T]) -> AsyncIterator[T]:
+ if isinstance(stream.source, Iterable):
+ return sync_to_async_iter(stream.source)
+ if isinstance(stream.source, AsyncIterable):
+ return stream.source.__aiter__()
+ if callable(stream.source):
+ iterable = stream.source()
+ if isinstance(iterable, Iterable):
+ return sync_to_async_iter(iterable)
+ if isinstance(iterable, AsyncIterable):
+ return iterable.__aiter__()
+ raise TypeError(
+ f"`source` must be an Iterable/AsyncIterable or a Callable[[], Iterable/AsyncIterable] but got a Callable[[], {type(iterable)}]"
+ )
+ raise TypeError(
+ f"`source` must be an Iterable/AsyncIterable or a Callable[[], Iterable/AsyncIterable] but got a {type(stream.source)}"
+ )
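
`visit_stream` adapts sync sources on the fly, so a fully synchronous pipeline can still be consumed asynchronously. This visitor is exactly what `Stream.__aiter__` dispatches to:

```python
from streamable import Stream
from streamable.visitors.aiterator import AsyncIteratorVisitor

stream = Stream(range(3)).map(str)
# equivalent to stream.__aiter__(): walk the node chain, build an AsyncIterator
aiterator = stream.accept(AsyncIteratorVisitor[str]())
```
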
diff --git a/streamable/visitors/base.py b/streamable/visitors/base.py
index d8f5d409..3cf28bfd 100644
--- a/streamable/visitors/base.py
+++ b/streamable/visitors/base.py
@@ -15,15 +15,27 @@ def visit_stream(self, stream: stream.Stream) -> V: ...
def visit_catch_stream(self, stream: stream.CatchStream) -> V:
return self.visit_stream(stream)
+ def visit_acatch_stream(self, stream: stream.ACatchStream) -> V:
+ return self.visit_stream(stream)
+
def visit_distinct_stream(self, stream: stream.DistinctStream) -> V:
return self.visit_stream(stream)
+ def visit_adistinct_stream(self, stream: stream.ADistinctStream) -> V:
+ return self.visit_stream(stream)
+
def visit_filter_stream(self, stream: stream.FilterStream) -> V:
return self.visit_stream(stream)
+ def visit_afilter_stream(self, stream: stream.AFilterStream) -> V:
+ return self.visit_stream(stream)
+
def visit_flatten_stream(self, stream: stream.FlattenStream) -> V:
return self.visit_stream(stream)
+ def visit_aflatten_stream(self, stream: stream.AFlattenStream) -> V:
+ return self.visit_stream(stream)
+
def visit_foreach_stream(self, stream: stream.ForeachStream) -> V:
return self.visit_stream(stream)
@@ -33,9 +45,15 @@ def visit_aforeach_stream(self, stream: stream.AForeachStream) -> V:
def visit_group_stream(self, stream: stream.GroupStream) -> V:
return self.visit_stream(stream)
+ def visit_agroup_stream(self, stream: stream.AGroupStream) -> V:
+ return self.visit_stream(stream)
+
def visit_groupby_stream(self, stream: stream.GroupbyStream) -> V:
return self.visit_stream(stream)
+ def visit_agroupby_stream(self, stream: stream.AGroupbyStream) -> V:
+ return self.visit_stream(stream)
+
def visit_observe_stream(self, stream: stream.ObserveStream) -> V:
return self.visit_stream(stream)
@@ -48,8 +66,14 @@ def visit_amap_stream(self, stream: stream.AMapStream) -> V:
def visit_skip_stream(self, stream: stream.SkipStream) -> V:
return self.visit_stream(stream)
+ def visit_askip_stream(self, stream: stream.ASkipStream) -> V:
+ return self.visit_stream(stream)
+
def visit_throttle_stream(self, stream: stream.ThrottleStream) -> V:
return self.visit_stream(stream)
def visit_truncate_stream(self, stream: stream.TruncateStream) -> V:
return self.visit_stream(stream)
+
+ def visit_atruncate_stream(self, stream: stream.ATruncateStream) -> V:
+ return self.visit_stream(stream)
diff --git a/streamable/visitors/equality.py b/streamable/visitors/equality.py
index 3af9f2d0..d03bc09e 100644
--- a/streamable/visitors/equality.py
+++ b/streamable/visitors/equality.py
@@ -1,8 +1,16 @@
-from typing import Any, TypeVar
+from typing import Any, Union
from streamable.stream import (
+ ACatchStream,
+ ADistinctStream,
+ AFilterStream,
+ AFlattenStream,
AForeachStream,
+ AGroupbyStream,
+ AGroupStream,
AMapStream,
+ ASkipStream,
+ ATruncateStream,
CatchStream,
DistinctStream,
FilterStream,
@@ -19,17 +27,17 @@
)
from streamable.visitors import Visitor
-T = TypeVar("T")
-U = TypeVar("U")
-
class EqualityVisitor(Visitor[bool]):
def __init__(self, other: Any):
self.other: Any = other
- def visit_catch_stream(self, stream: CatchStream[T]) -> bool:
+ def type_eq(self, stream: Stream) -> bool:
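+ # Exact type match: a sync node never equals its "a"-prefixed async variant.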
+ return type(stream) == type(self.other)
+
+ def catch_eq(self, stream: Union[CatchStream, ACatchStream]) -> bool:
return (
- isinstance(self.other, CatchStream)
+ self.type_eq(stream)
and stream.upstream.accept(EqualityVisitor(self.other.upstream))
and stream._errors == self.other._errors
and stream._when == self.other._when
@@ -37,114 +45,154 @@ def visit_catch_stream(self, stream: CatchStream[T]) -> bool:
and stream._finally_raise == self.other._finally_raise
)
- def visit_distinct_stream(self, stream: DistinctStream[T]) -> bool:
+ def visit_catch_stream(self, stream: CatchStream) -> bool:
+ return self.catch_eq(stream)
+
+ def visit_acatch_stream(self, stream: ACatchStream) -> bool:
+ return self.catch_eq(stream)
+
+ def distinct_eq(self, stream: Union[DistinctStream, ADistinctStream]) -> bool:
return (
- isinstance(self.other, DistinctStream)
+ self.type_eq(stream)
and stream.upstream.accept(EqualityVisitor(self.other.upstream))
and stream._key == self.other._key
and stream._consecutive_only == self.other._consecutive_only
)
- def visit_filter_stream(self, stream: FilterStream[T]) -> bool:
+ def visit_distinct_stream(self, stream: DistinctStream) -> bool:
+ return self.distinct_eq(stream)
+
+ def visit_adistinct_stream(self, stream: ADistinctStream) -> bool:
+ return self.distinct_eq(stream)
+
+ def filter_eq(self, stream: Union[FilterStream, AFilterStream]) -> bool:
return (
- isinstance(self.other, FilterStream)
+ self.type_eq(stream)
and stream.upstream.accept(EqualityVisitor(self.other.upstream))
and stream._when == self.other._when
)
- def visit_flatten_stream(self, stream: FlattenStream[T]) -> bool:
- return (
- isinstance(self.other, FlattenStream)
- and stream.upstream.accept(EqualityVisitor(self.other.upstream))
- and stream._concurrency == self.other._concurrency
- )
+ def visit_filter_stream(self, stream: FilterStream) -> bool:
+ return self.filter_eq(stream)
- def visit_foreach_stream(self, stream: ForeachStream[T]) -> bool:
+ def visit_afilter_stream(self, stream: AFilterStream) -> bool:
+ return self.filter_eq(stream)
+
+ def flatten_eq(self, stream: Union[FlattenStream, AFlattenStream]) -> bool:
return (
- isinstance(self.other, ForeachStream)
+ self.type_eq(stream)
and stream.upstream.accept(EqualityVisitor(self.other.upstream))
and stream._concurrency == self.other._concurrency
- and stream._effect == self.other._effect
- and stream._ordered == self.other._ordered
- and stream._via == self.other._via
)
- def visit_aforeach_stream(self, stream: AForeachStream[T]) -> bool:
+ def visit_flatten_stream(self, stream: FlattenStream) -> bool:
+ return self.flatten_eq(stream)
+
+ def visit_aflatten_stream(self, stream: AFlattenStream) -> bool:
+ return self.flatten_eq(stream)
+
+ def foreach_eq(self, stream: Union[ForeachStream, AForeachStream]) -> bool:
return (
- isinstance(self.other, AForeachStream)
+ self.type_eq(stream)
and stream.upstream.accept(EqualityVisitor(self.other.upstream))
and stream._concurrency == self.other._concurrency
and stream._effect == self.other._effect
and stream._ordered == self.other._ordered
)
- def visit_group_stream(self, stream: GroupStream[U]) -> bool:
+ def visit_foreach_stream(self, stream: ForeachStream) -> bool:
+ return self.foreach_eq(stream) and stream._via == self.other._via
+
+ def visit_aforeach_stream(self, stream: AForeachStream) -> bool:
+ return self.foreach_eq(stream)
+
+ def group_eq(self, stream: Union[GroupStream, AGroupStream]) -> bool:
return (
- isinstance(self.other, GroupStream)
+ self.type_eq(stream)
and stream.upstream.accept(EqualityVisitor(self.other.upstream))
and stream._by == self.other._by
and stream._size == self.other._size
and stream._interval == self.other._interval
)
- def visit_groupby_stream(self, stream: GroupbyStream[U, T]) -> bool:
+ def visit_group_stream(self, stream: GroupStream) -> bool:
+ return self.group_eq(stream)
+
+ def visit_agroup_stream(self, stream: AGroupStream) -> bool:
+ return self.group_eq(stream)
+
+ def groupby_eq(self, stream: Union[GroupbyStream, AGroupbyStream]) -> bool:
return (
- isinstance(self.other, GroupbyStream)
+ self.type_eq(stream)
and stream.upstream.accept(EqualityVisitor(self.other.upstream))
and stream._key == self.other._key
and stream._size == self.other._size
and stream._interval == self.other._interval
)
- def visit_map_stream(self, stream: MapStream[U, T]) -> bool:
- return (
- isinstance(self.other, MapStream)
- and stream.upstream.accept(EqualityVisitor(self.other.upstream))
- and stream._concurrency == self.other._concurrency
- and stream._transformation == self.other._transformation
- and stream._ordered == self.other._ordered
- and stream._via == self.other._via
- )
+ def visit_groupby_stream(self, stream: GroupbyStream) -> bool:
+ return self.groupby_eq(stream)
+
+ def visit_agroupby_stream(self, stream: AGroupbyStream) -> bool:
+ return self.groupby_eq(stream)
- def visit_amap_stream(self, stream: AMapStream[U, T]) -> bool:
+ def map_eq(self, stream: Union[MapStream, AMapStream]) -> bool:
return (
- isinstance(self.other, AMapStream)
+ self.type_eq(stream)
and stream.upstream.accept(EqualityVisitor(self.other.upstream))
and stream._concurrency == self.other._concurrency
and stream._transformation == self.other._transformation
and stream._ordered == self.other._ordered
)
- def visit_observe_stream(self, stream: ObserveStream[T]) -> bool:
+ def visit_map_stream(self, stream: MapStream) -> bool:
+ return self.map_eq(stream) and stream._via == self.other._via
+
+ def visit_amap_stream(self, stream: AMapStream) -> bool:
+ return self.map_eq(stream)
+
+ def visit_observe_stream(self, stream: ObserveStream) -> bool:
return (
- isinstance(self.other, ObserveStream)
+ self.type_eq(stream)
and stream.upstream.accept(EqualityVisitor(self.other.upstream))
and stream._what == self.other._what
)
- def visit_skip_stream(self, stream: SkipStream[T]) -> bool:
+ def skip_eq(self, stream: Union[SkipStream, ASkipStream]) -> bool:
return (
- isinstance(self.other, SkipStream)
+ self.type_eq(stream)
and stream.upstream.accept(EqualityVisitor(self.other.upstream))
and stream._count == self.other._count
and stream._until == self.other._until
)
- def visit_throttle_stream(self, stream: ThrottleStream[T]) -> bool:
+ def visit_skip_stream(self, stream: SkipStream) -> bool:
+ return self.skip_eq(stream)
+
+ def visit_askip_stream(self, stream: ASkipStream) -> bool:
+ return self.skip_eq(stream)
+
+ def visit_throttle_stream(self, stream: ThrottleStream) -> bool:
return (
- isinstance(self.other, ThrottleStream)
+ self.type_eq(stream)
and stream.upstream.accept(EqualityVisitor(self.other.upstream))
and stream._count == self.other._count
and stream._per == self.other._per
)
- def visit_truncate_stream(self, stream: TruncateStream[T]) -> bool:
+ def truncate_eq(self, stream: Union[TruncateStream, ATruncateStream]) -> bool:
return (
- isinstance(self.other, TruncateStream)
+ self.type_eq(stream)
and stream.upstream.accept(EqualityVisitor(self.other.upstream))
and stream._count == self.other._count
and stream._when == self.other._when
)
- def visit_stream(self, stream: Stream[T]) -> bool:
- return isinstance(self.other, Stream) and stream.source == self.other.source
+ def visit_truncate_stream(self, stream: TruncateStream) -> bool:
+ return self.truncate_eq(stream)
+
+ def visit_atruncate_stream(self, stream: ATruncateStream) -> bool:
+ return self.truncate_eq(stream)
+
+ def visit_stream(self, stream: Stream) -> bool:
+ return self.type_eq(stream) and stream.source == self.other.source
diff --git a/streamable/visitors/iterator.py b/streamable/visitors/iterator.py
index 417aa27b..bb63bdb0 100644
--- a/streamable/visitors/iterator.py
+++ b/streamable/visitors/iterator.py
@@ -1,9 +1,17 @@
-from typing import Iterable, Iterator, TypeVar, cast
+from typing import AsyncIterable, Iterable, Iterator, TypeVar, cast
from streamable import functions
from streamable.stream import (
+ ACatchStream,
+ ADistinctStream,
+ AFilterStream,
+ AFlattenStream,
AForeachStream,
+ AGroupbyStream,
+ AGroupStream,
AMapStream,
+ ASkipStream,
+ ATruncateStream,
CatchStream,
DistinctStream,
FilterStream,
@@ -18,7 +26,8 @@
ThrottleStream,
TruncateStream,
)
-from streamable.util.functiontools import async_sidify, sidify, wrap_error
+from streamable.util.functiontools import async_sidify, sidify, syncify, wrap_error
+from streamable.util.iterabletools import async_to_sync_iter
from streamable.visitors import Visitor
T = TypeVar("T")
@@ -35,6 +44,15 @@ def visit_catch_stream(self, stream: CatchStream[T]) -> Iterator[T]:
finally_raise=stream._finally_raise,
)
+ def visit_acatch_stream(self, stream: ACatchStream[T]) -> Iterator[T]:
+ return functions.acatch(
+ stream.upstream.accept(self),
+ stream._errors,
+ when=stream._when,
+ replacement=stream._replacement,
+ finally_raise=stream._finally_raise,
+ )
+
def visit_distinct_stream(self, stream: DistinctStream[T]) -> Iterator[T]:
return functions.distinct(
stream.upstream.accept(self),
@@ -42,18 +60,37 @@ def visit_distinct_stream(self, stream: DistinctStream[T]) -> Iterator[T]:
consecutive_only=stream._consecutive_only,
)
+ def visit_adistinct_stream(self, stream: ADistinctStream[T]) -> Iterator[T]:
+ return functions.adistinct(
+ stream.upstream.accept(self),
+ stream._key,
+ consecutive_only=stream._consecutive_only,
+ )
+
def visit_filter_stream(self, stream: FilterStream[T]) -> Iterator[T]:
return filter(
wrap_error(stream._when, StopIteration),
cast(Iterable[T], stream.upstream.accept(self)),
)
+ def visit_afilter_stream(self, stream: AFilterStream[T]) -> Iterator[T]:
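+ # syncify adapts the async predicate into a blocking callable so the builtin filter can evaluate it.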
+ return filter(
+ wrap_error(syncify(stream._when), StopIteration),
+ cast(Iterable[T], stream.upstream.accept(self)),
+ )
+
def visit_flatten_stream(self, stream: FlattenStream[T]) -> Iterator[T]:
return functions.flatten(
stream.upstream.accept(IteratorVisitor[Iterable]()),
concurrency=stream._concurrency,
)
+ def visit_aflatten_stream(self, stream: AFlattenStream[T]) -> Iterator[T]:
+ return functions.aflatten(
+ stream.upstream.accept(IteratorVisitor[AsyncIterable]()),
+ concurrency=stream._concurrency,
+ )
+
def visit_foreach_stream(self, stream: ForeachStream[T]) -> Iterator[T]:
return self.visit_map_stream(
MapStream(
@@ -86,6 +123,17 @@ def visit_group_stream(self, stream: GroupStream[U]) -> Iterator[T]:
),
)
+ def visit_agroup_stream(self, stream: AGroupStream[U]) -> Iterator[T]:
+ return cast(
+ Iterator[T],
+ functions.agroup(
+ stream.upstream.accept(IteratorVisitor[U]()),
+ stream._size,
+ interval=stream._interval,
+ by=stream._by,
+ ),
+ )
+
def visit_groupby_stream(self, stream: GroupbyStream[U, T]) -> Iterator[T]:
return cast(
Iterator[T],
@@ -97,6 +145,17 @@ def visit_groupby_stream(self, stream: GroupbyStream[U, T]) -> Iterator[T]:
),
)
+ def visit_agroupby_stream(self, stream: AGroupbyStream[U, T]) -> Iterator[T]:
+ return cast(
+ Iterator[T],
+ functions.agroupby(
+ stream.upstream.accept(IteratorVisitor[U]()),
+ stream._key,
+ size=stream._size,
+ interval=stream._interval,
+ ),
+ )
+
def visit_map_stream(self, stream: MapStream[U, T]) -> Iterator[T]:
return functions.map(
stream._transformation,
@@ -127,6 +186,13 @@ def visit_skip_stream(self, stream: SkipStream[T]) -> Iterator[T]:
until=stream._until,
)
+ def visit_askip_stream(self, stream: ASkipStream[T]) -> Iterator[T]:
+ return functions.askip(
+ stream.upstream.accept(self),
+ stream._count,
+ until=stream._until,
+ )
+
def visit_throttle_stream(self, stream: ThrottleStream[T]) -> Iterator[T]:
return functions.throttle(
stream.upstream.accept(self),
@@ -141,17 +207,27 @@ def visit_truncate_stream(self, stream: TruncateStream[T]) -> Iterator[T]:
when=stream._when,
)
+ def visit_atruncate_stream(self, stream: ATruncateStream[T]) -> Iterator[T]:
+ return functions.atruncate(
+ stream.upstream.accept(self),
+ stream._count,
+ when=stream._when,
+ )
+
def visit_stream(self, stream: Stream[T]) -> Iterator[T]:
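+ # Same source dispatch as the async visitor, bridging AsyncIterables into sync iteration.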
if isinstance(stream.source, Iterable):
- iterable = stream.source
- elif callable(stream.source):
+ return stream.source.__iter__()
+ if isinstance(stream.source, AsyncIterable):
+ return async_to_sync_iter(stream.source)
+ if callable(stream.source):
iterable = stream.source()
- if not isinstance(iterable, Iterable):
- raise TypeError(
- f"`source` must be an Iterable or a Callable[[], Iterable] but got a Callable[[], {type(iterable)}]"
- )
- else:
+ if isinstance(iterable, Iterable):
+ return iterable.__iter__()
+ if isinstance(iterable, AsyncIterable):
+ return async_to_sync_iter(iterable)
raise TypeError(
- f"`source` must be an Iterable or a Callable[[], Iterable] but got a {type(stream.source)}"
+ f"`source` must be an Iterable/AsyncIterable or a Callable[[], Iterable/AsyncIterable] but got a Callable[[], {type(iterable)}]"
)
- return iter(iterable)
+ raise TypeError(
+ f"`source` must be an Iterable/AsyncIterable or a Callable[[], Iterable/AsyncIterable] but got a {type(stream.source)}"
+ )
diff --git a/streamable/visitors/representation.py b/streamable/visitors/representation.py
index 434b3ffc..0e36f20b 100644
--- a/streamable/visitors/representation.py
+++ b/streamable/visitors/representation.py
@@ -1,9 +1,17 @@
from abc import ABC, abstractmethod
-from typing import Any, Iterable, List, Type, TypeVar
+from typing import Any, Iterable, List
from streamable.stream import (
+ ACatchStream,
+ ADistinctStream,
+ AFilterStream,
+ AFlattenStream,
AForeachStream,
+ AGroupbyStream,
+ AGroupStream,
AMapStream,
+ ASkipStream,
+ ATruncateStream,
CatchStream,
DistinctStream,
FilterStream,
@@ -22,19 +30,17 @@
from streamable.util.functiontools import _Star
from streamable.visitors import Visitor
-T = TypeVar("T")
-U = TypeVar("U")
-
class ToStringVisitor(Visitor[str], ABC):
- def __init__(self) -> None:
+ def __init__(self, one_liner_max_depth: int = 3) -> None:
self.methods_reprs: List[str] = []
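+ # reprs whose depth (source + chained operations) is at most this stay on a single line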
+ self.one_liner_max_depth = one_liner_max_depth
@staticmethod
@abstractmethod
def to_string(o: object) -> str: ...
- def visit_catch_stream(self, stream: CatchStream[T]) -> str:
+ def visit_catch_stream(self, stream: CatchStream) -> str:
replacement = ""
if stream._replacement is not NO_REPLACEMENT:
replacement = f", replacement={self.to_string(stream._replacement)}"
@@ -47,87 +53,146 @@ def visit_catch_stream(self, stream: CatchStream[T]) -> str:
)
return stream.upstream.accept(self)
- def visit_distinct_stream(self, stream: DistinctStream[T]) -> str:
+ def visit_acatch_stream(self, stream: ACatchStream) -> str:
+ replacement = ""
+ if stream._replacement is not NO_REPLACEMENT:
+ replacement = f", replacement={self.to_string(stream._replacement)}"
+ if isinstance(stream._errors, Iterable):
+ errors = f"({', '.join(map(self.to_string, stream._errors))})"
+ else:
+ errors = self.to_string(stream._errors)
+ self.methods_reprs.append(
+ f"acatch({errors}, when={self.to_string(stream._when)}{replacement}, finally_raise={self.to_string(stream._finally_raise)})"
+ )
+ return stream.upstream.accept(self)
+
+ def visit_distinct_stream(self, stream: DistinctStream) -> str:
self.methods_reprs.append(
f"distinct({self.to_string(stream._key)}, consecutive_only={self.to_string(stream._consecutive_only)})"
)
return stream.upstream.accept(self)
- def visit_filter_stream(self, stream: FilterStream[T]) -> str:
+ def visit_adistinct_stream(self, stream: ADistinctStream) -> str:
+ self.methods_reprs.append(
+ f"adistinct({self.to_string(stream._key)}, consecutive_only={self.to_string(stream._consecutive_only)})"
+ )
+ return stream.upstream.accept(self)
+
+ def visit_filter_stream(self, stream: FilterStream) -> str:
self.methods_reprs.append(f"filter({self.to_string(stream._when)})")
return stream.upstream.accept(self)
- def visit_flatten_stream(self, stream: FlattenStream[T]) -> str:
+ def visit_afilter_stream(self, stream: AFilterStream) -> str:
+ self.methods_reprs.append(f"afilter({self.to_string(stream._when)})")
+ return stream.upstream.accept(self)
+
+ def visit_flatten_stream(self, stream: FlattenStream) -> str:
self.methods_reprs.append(
f"flatten(concurrency={self.to_string(stream._concurrency)})"
)
return stream.upstream.accept(self)
- def visit_foreach_stream(self, stream: ForeachStream[T]) -> str:
+ def visit_aflatten_stream(self, stream: AFlattenStream) -> str:
+ self.methods_reprs.append(
+ f"aflatten(concurrency={self.to_string(stream._concurrency)})"
+ )
+ return stream.upstream.accept(self)
+
+ def visit_foreach_stream(self, stream: ForeachStream) -> str:
via = f", via={self.to_string(stream._via)}" if stream._concurrency > 1 else ""
self.methods_reprs.append(
f"foreach({self.to_string(stream._effect)}, concurrency={self.to_string(stream._concurrency)}, ordered={self.to_string(stream._ordered)}{via})"
)
return stream.upstream.accept(self)
- def visit_aforeach_stream(self, stream: AForeachStream[T]) -> str:
+ def visit_aforeach_stream(self, stream: AForeachStream) -> str:
self.methods_reprs.append(
f"aforeach({self.to_string(stream._effect)}, concurrency={self.to_string(stream._concurrency)}, ordered={self.to_string(stream._ordered)})"
)
return stream.upstream.accept(self)
- def visit_group_stream(self, stream: GroupStream[U]) -> str:
+ def visit_group_stream(self, stream: GroupStream) -> str:
self.methods_reprs.append(
f"group(size={self.to_string(stream._size)}, by={self.to_string(stream._by)}, interval={self.to_string(stream._interval)})"
)
return stream.upstream.accept(self)
- def visit_groupby_stream(self, stream: GroupbyStream[U, T]) -> str:
+ def visit_agroup_stream(self, stream: AGroupStream) -> str:
+ self.methods_reprs.append(
+ f"agroup(size={self.to_string(stream._size)}, by={self.to_string(stream._by)}, interval={self.to_string(stream._interval)})"
+ )
+ return stream.upstream.accept(self)
+
+ def visit_groupby_stream(self, stream: GroupbyStream) -> str:
self.methods_reprs.append(
f"groupby({self.to_string(stream._key)}, size={self.to_string(stream._size)}, interval={self.to_string(stream._interval)})"
)
return stream.upstream.accept(self)
- def visit_map_stream(self, stream: MapStream[U, T]) -> str:
+ def visit_agroupby_stream(self, stream: AGroupbyStream) -> str:
+ self.methods_reprs.append(
+ f"agroupby({self.to_string(stream._key)}, size={self.to_string(stream._size)}, interval={self.to_string(stream._interval)})"
+ )
+ return stream.upstream.accept(self)
+
+ def visit_map_stream(self, stream: MapStream) -> str:
via = f", via={self.to_string(stream._via)}" if stream._concurrency > 1 else ""
self.methods_reprs.append(
f"map({self.to_string(stream._transformation)}, concurrency={self.to_string(stream._concurrency)}, ordered={self.to_string(stream._ordered)}{via})"
)
return stream.upstream.accept(self)
- def visit_amap_stream(self, stream: AMapStream[U, T]) -> str:
+ def visit_amap_stream(self, stream: AMapStream) -> str:
self.methods_reprs.append(
f"amap({self.to_string(stream._transformation)}, concurrency={self.to_string(stream._concurrency)}, ordered={self.to_string(stream._ordered)})"
)
return stream.upstream.accept(self)
- def visit_observe_stream(self, stream: ObserveStream[T]) -> str:
+ def visit_observe_stream(self, stream: ObserveStream) -> str:
self.methods_reprs.append(f"""observe({self.to_string(stream._what)})""")
return stream.upstream.accept(self)
- def visit_skip_stream(self, stream: SkipStream[T]) -> str:
+ def visit_skip_stream(self, stream: SkipStream) -> str:
self.methods_reprs.append(
f"skip({self.to_string(stream._count)}, until={self.to_string(stream._until)})"
)
return stream.upstream.accept(self)
- def visit_throttle_stream(self, stream: ThrottleStream[T]) -> str:
+ def visit_askip_stream(self, stream: ASkipStream) -> str:
+ self.methods_reprs.append(
+ f"askip({self.to_string(stream._count)}, until={self.to_string(stream._until)})"
+ )
+ return stream.upstream.accept(self)
+
+ def visit_throttle_stream(self, stream: ThrottleStream) -> str:
self.methods_reprs.append(
f"throttle({self.to_string(stream._count)}, per={self.to_string(stream._per)})"
)
return stream.upstream.accept(self)
- def visit_truncate_stream(self, stream: TruncateStream[T]) -> str:
+ def visit_truncate_stream(self, stream: TruncateStream) -> str:
self.methods_reprs.append(
f"truncate(count={self.to_string(stream._count)}, when={self.to_string(stream._when)})"
)
return stream.upstream.accept(self)
- def visit_stream(self, stream: Stream[T]) -> str:
+ def visit_atruncate_stream(self, stream: ATruncateStream) -> str:
+ self.methods_reprs.append(
+ f"atruncate(count={self.to_string(stream._count)}, when={self.to_string(stream._when)})"
+ )
+ return stream.upstream.accept(self)
+
+ def visit_stream(self, stream: Stream) -> str:
+ source_stream = f"Stream({self.to_string(stream.source)})"
+ depth = len(self.methods_reprs) + 1
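+ # the +1 accounts for the source node itself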
+ if depth == 1:
+ return source_stream
+ if depth <= self.one_liner_max_depth:
+ return f"{source_stream}.{'.'.join(reversed(self.methods_reprs))}"
methods_block = "".join(
map(lambda r: f" .{r}\n", reversed(self.methods_reprs))
)
- return f"(\n Stream({self.to_string(stream.source)})\n{methods_block})"
+ return f"(\n {source_stream}\n{methods_block})"
class ReprVisitor(ToStringVisitor):
diff --git a/tests/test_iterators.py b/tests/test_iterators.py
index b9a74de5..2bb30d37 100644
--- a/tests/test_iterators.py
+++ b/tests/test_iterators.py
@@ -1,6 +1,15 @@
+import asyncio
import unittest
+from typing import AsyncIterator
-from streamable.iterators import ObserveIterator, _OSConcurrentMapIterable
+from streamable.aiterators import (
+ _ConcurrentAMapAsyncIterable,
+ _RaisingAsyncIterator,
+)
+from streamable.iterators import ObserveIterator, _ConcurrentMapIterable
+from streamable.util.asynctools import awaitable_to_coroutine
+from streamable.util.iterabletools import sync_to_async_iter
+from tests.utils import async_identity, identity, src
class TestIterators(unittest.TestCase):
@@ -8,9 +17,9 @@ def test_validation(self):
with self.assertRaisesRegex(
ValueError,
"`buffersize` must be >= 1 but got 0",
- msg="`_OSConcurrentMapIterable` constructor should raise for non-positive buffersize",
+ msg="`_ConcurrentMapIterable` constructor should raise for non-positive buffersize",
):
- _OSConcurrentMapIterable(
+ _ConcurrentMapIterable(
iterator=iter([]),
transformation=str,
concurrency=1,
@@ -29,3 +38,28 @@ def test_validation(self):
what="",
base=0,
)
+
+ def test_ConcurrentAMapAsyncIterable(self) -> None:
+ with self.assertRaisesRegex(
+ TypeError,
+ r"must be an async function i\.e\. a function returning a Coroutine but it returned a ",
+ msg="`amap` should raise a TypeError if a non async function is passed to it.",
+ ):
+ concurrent_amap_async_iterable: _ConcurrentAMapAsyncIterable[int, int] = (
+ _ConcurrentAMapAsyncIterable(
+ sync_to_async_iter(src),
+ async_identity,
+ buffersize=2,
+ ordered=True,
+ )
+ )
+
+ # remove error wrapping
+ concurrent_amap_async_iterable.transformation = identity # type: ignore
+
+ aiterator: AsyncIterator[int] = _RaisingAsyncIterator(
+ concurrent_amap_async_iterable.__aiter__()
+ )
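+ # consuming the first element triggers the expected TypeError from the non-async transformation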
+ print(
+ asyncio.run(awaitable_to_coroutine(aiterator.__aiter__().__anext__()))
+ )
diff --git a/tests/test_readme.py b/tests/test_readme.py
index a0f1a89e..cf748176 100644
--- a/tests/test_readme.py
+++ b/tests/test_readme.py
@@ -1,7 +1,8 @@
+import asyncio
import time
import unittest
from datetime import timedelta
-from typing import Iterator, List, Tuple
+from typing import AsyncIterable, Iterator, List, Tuple, TypeVar
from streamable.stream import Stream
@@ -15,10 +16,12 @@
three_integers_per_second: Stream[int] = integers.throttle(5, per=timedelta(seconds=1))
+T = TypeVar("T")
+
# fmt: off
class TestReadme(unittest.TestCase):
- def test_collect_it(self) -> None:
+ def test_iterate(self) -> None:
self.assertListEqual(
list(inverses),
[1.0, 0.5, 0.33, 0.25, 0.2, 0.17, 0.14, 0.12, 0.11],
@@ -34,6 +37,11 @@ def test_collect_it(self) -> None:
self.assertEqual(next(inverses_iter), 1.0)
self.assertEqual(next(inverses_iter), 0.5)
+ async def main() -> List[float]:
+ return [inverse async for inverse in inverses]
+
+ assert asyncio.run(main()) == [1.0, 0.5, 0.33, 0.25, 0.2, 0.17, 0.14, 0.12, 0.11]
+
def test_map_example(self) -> None:
integer_strings: Stream[str] = integers.map(str)
@@ -59,7 +67,7 @@ def test_process_concurrent_map_example(self) -> None:
# but the `state` of the main process is not mutated
assert state == []
- def test_async_concurrent_map_example(self) -> None:
+ def test_amap_example(self) -> None:
import asyncio
import httpx
@@ -75,7 +83,25 @@ def test_async_concurrent_map_example(self) -> None:
)
assert list(pokemon_names) == ['bulbasaur', 'ivysaur', 'venusaur']
- asyncio.get_event_loop().run_until_complete(http_async_client.aclose())
+ asyncio.run(http_async_client.aclose())
+
+ def test_async_amap_example(self) -> None:
+ import asyncio
+
+ import httpx
+
+ async def main() -> None:
+ async with httpx.AsyncClient() as http_async_client:
+ pokemon_names: Stream[str] = (
+ Stream(range(1, 4))
+ .map(lambda i: f"https://pokeapi.co/api/v2/pokemon-species/{i}")
+ .amap(http_async_client.get, concurrency=3)
+ .map(httpx.Response.json)
+ .map(lambda poke: poke["name"])
+ )
+ assert [name async for name in pokemon_names] == ['bulbasaur', 'ivysaur', 'venusaur']
+
+ asyncio.run(main())
def test_starmap_example(self) -> None:
from streamable import star
@@ -267,12 +293,23 @@ def test_zip_example(self) -> None:
def test_count_example(self) -> None:
assert integers.count() == 10
+ def test_acount_example(self) -> None:
+ assert asyncio.run(integers.acount()) == 10
+
def test_call_example(self) -> None:
state: List[int] = []
appending_integers: Stream[int] = integers.foreach(state.append)
assert appending_integers() is appending_integers
assert state == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
+ def test_await_example(self) -> None:
+ async def test() -> None:
+ state: List[int] = []
+ appending_integers: Stream[int] = integers.foreach(state.append)
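+ # awaiting the stream exhausts it and returns the same Stream instance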
+ assert appending_integers is await appending_integers
+ assert state == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
+ asyncio.run(test())
+
def test_non_stopping_exceptions_example(self) -> None:
from contextlib import suppress
@@ -286,6 +323,59 @@ def test_non_stopping_exceptions_example(self) -> None:
collected_casted_ints.extend(casted_ints)
assert collected_casted_ints == [0, 1, 2, 3, 5, 6, 7, 8, 9]
+ def test_async_etl_example(self) -> None: # pragma: no cover
+ # for mypy typing check only
+ if not self:
+ import asyncio
+ import csv
+ import itertools
+ from datetime import timedelta
+
+ import httpx
+
+ from streamable import Stream
+
+ async def main() -> None:
+ with open("./quadruped_pokemons.csv", mode="w") as file:
+ fields = ["id", "name", "is_legendary", "base_happiness", "capture_rate"]
+ writer = csv.DictWriter(file, fields, extrasaction='ignore')
+ writer.writeheader()
+
+ async with httpx.AsyncClient() as http_async_client:
+ pipeline: Stream = (
+ # Infinite Stream[int] of Pokemon ids starting from Pokémon #1: Bulbasaur
+ Stream(itertools.count(1))
+ # Limits to 16 requests per second to be friendly to our fellow PokéAPI devs
+ .throttle(16, per=timedelta(seconds=1))
+ # GETs pokemons via 8 concurrent coroutines
+ .map(lambda poke_id: f"https://pokeapi.co/api/v2/pokemon-species/{poke_id}")
+ .amap(http_async_client.get, concurrency=8)
+ .foreach(httpx.Response.raise_for_status)
+ .map(httpx.Response.json)
+ # Stops the iteration when reaching the 1st pokemon of the 4th generation
+ .truncate(when=lambda poke: poke["generation"]["name"] == "generation-iv")
+ .observe("pokemons")
+ # Keeps only quadruped Pokemons
+ .filter(lambda poke: poke["shape"]["name"] == "quadruped")
+ .observe("quadruped pokemons")
+ # Catches errors due to None "generation" or "shape"
+ .catch(
+ TypeError,
+ when=lambda error: str(error) == "'NoneType' object is not subscriptable"
+ )
+ # Writes a batch of pokemons every 5 seconds to the CSV file
+ .group(interval=timedelta(seconds=5))
+ .foreach(writer.writerows)
+ .flatten()
+ .observe("written pokemons")
+ # Catches exceptions and raises the 1st one at the end of the iteration
+ .catch(Exception, finally_raise=True)
+ )
+
+ await pipeline
+
+ asyncio.run(main())
+
def test_etl_example(self) -> None: # pragma: no cover
# for mypy typing check only
if not self:
diff --git a/tests/test_stream.py b/tests/test_stream.py
index 8bda6a9e..cad827b8 100644
--- a/tests/test_stream.py
+++ b/tests/test_stream.py
@@ -3,22 +3,21 @@
import datetime
import logging
import math
-from operator import itemgetter
-from pickle import PickleError
import queue
-import random
import sys
import threading
import time
-import timeit
import traceback
-from types import FrameType
import unittest
from collections import Counter
+from functools import partial
+from pickle import PickleError
+from types import TracebackType
from typing import (
Any,
+ AsyncIterable,
+ AsyncIterator,
Callable,
- Coroutine,
Dict,
Iterable,
Iterator,
@@ -27,139 +26,52 @@
Set,
Tuple,
Type,
- TypeVar,
+ Union,
cast,
)
from parameterized import parameterized # type: ignore
from streamable import Stream
-from streamable.util.functiontools import WrappedError, star
-
-T = TypeVar("T")
-R = TypeVar("R")
-
-
-def timestream(stream: Stream[T], times: int = 1) -> Tuple[float, List[T]]:
- res: List[T] = []
-
- def iterate():
- nonlocal res
- res = list(stream)
-
- return timeit.timeit(iterate, number=times) / times, res
-
-
-def identity_sleep(seconds: float) -> float:
- time.sleep(seconds)
- return seconds
-
-
-async def async_identity_sleep(seconds: float) -> float:
- await asyncio.sleep(seconds)
- return seconds
-
-
-# simulates an I/0 bound function
-slow_identity_duration = 0.05
-
-
-def slow_identity(x: T) -> T:
- time.sleep(slow_identity_duration)
- return x
-
-
-async def async_slow_identity(x: T) -> T:
- await asyncio.sleep(slow_identity_duration)
- return x
-
-
-def identity(x: T) -> T:
- return x
-
-
-# fmt: off
-async def async_identity(x: T) -> T: return x
-# fmt: on
-
-
-def square(x):
- return x**2
-
-
-async def async_square(x):
- return x**2
-
-
-def throw(exc: Type[Exception]):
- raise exc()
-
-
-def throw_func(exc: Type[Exception]) -> Callable[[T], T]:
- return lambda _: throw(exc)
-
-
-def async_throw_func(exc: Type[Exception]) -> Callable[[T], Coroutine[Any, Any, T]]:
- async def f(_: T) -> T:
- raise exc
-
- return f
-
-
-def throw_for_odd_func(exc):
- return lambda i: throw(exc) if i % 2 == 1 else i
-
-
-def async_throw_for_odd_func(exc):
- async def f(i):
- return throw(exc) if i % 2 == 1 else i
-
- return f
-
-
-class TestError(Exception):
- pass
-
-
-DELTA_RATE = 0.4
-# size of the test collections
-N = 256
-
-src = range(N)
-
-even_src = range(0, N, 2)
-
-
-def randomly_slowed(
- func: Callable[[T], R], min_sleep: float = 0.001, max_sleep: float = 0.05
-) -> Callable[[T], R]:
- def wrap(x: T) -> R:
- time.sleep(min_sleep + random.random() * (max_sleep - min_sleep))
- return func(x)
-
- return wrap
-
-
-def async_randomly_slowed(
- async_func: Callable[[T], Coroutine[Any, Any, R]],
- min_sleep: float = 0.001,
- max_sleep: float = 0.05,
-) -> Callable[[T], Coroutine[Any, Any, R]]:
- async def wrap(x: T) -> R:
- await asyncio.sleep(min_sleep + random.random() * (max_sleep - min_sleep))
- return await async_func(x)
-
- return wrap
-
-
-def range_raising_at_exhaustion(
- start: int, end: int, step: int, exception: Exception
-) -> Iterator[int]:
- yield from range(start, end, step)
- raise exception
-
-
-src_raising_at_exhaustion = lambda: range_raising_at_exhaustion(0, N, 1, TestError())
+from streamable.util.asynctools import awaitable_to_coroutine
+from streamable.util.functiontools import WrappedError, asyncify, star
+from streamable.util.iterabletools import (
+ sync_to_async_iter,
+ sync_to_bi_iterable,
+)
+from tests.utils import (
+ DELTA_RATE,
+ ITERABLE_TYPES,
+ IterableType,
+ N,
+ TestError,
+ alist_or_list,
+ anext_or_next,
+ async_identity,
+ async_identity_sleep,
+ async_randomly_slowed,
+ async_slow_identity,
+ async_square,
+ async_throw_for_odd_func,
+ async_throw_func,
+ bi_iterable_to_iter,
+ even_src,
+ identity,
+ identity_sleep,
+ stopiteration_for_iter_type,
+ randomly_slowed,
+ slow_identity,
+ slow_identity_duration,
+ square,
+ src,
+ src_raising_at_exhaustion,
+ throw,
+ throw_for_odd_func,
+ throw_func,
+ timestream,
+ to_list,
+ to_set,
+)
class TestStream(unittest.TestCase):
@@ -204,6 +116,19 @@ def test_init(self) -> None:
):
Stream(src).upstream = Stream(src) # type: ignore
+ @parameterized.expand(ITERABLE_TYPES)
+ def test_async_src(self, itype: IterableType) -> None:
+ self.assertEqual(
+ to_list(Stream(sync_to_async_iter(src)), itype),
+ list(src),
+ msg="a stream with an async source must be collectable as an Iterable or as AsyncIterable",
+ )
+ self.assertEqual(
+ to_list(Stream(sync_to_async_iter(src).__aiter__), itype),
+ list(src),
+ msg="a stream with an async source must be collectable as an Iterable or as AsyncIterable",
+ )
+
def test_repr_and_display(self) -> None:
class CustomCallable:
pass
@@ -211,31 +136,44 @@ class CustomCallable:
complex_stream: Stream[int] = (
Stream(src)
.truncate(1024, when=lambda _: False)
+ .atruncate(1024, when=async_identity)
.skip(10)
+ .askip(10)
.skip(until=lambda _: True)
+ .askip(until=async_identity)
.distinct(lambda _: _)
+ .adistinct(async_identity)
.filter()
.map(lambda i: (i,))
.map(lambda i: (i,), concurrency=2)
.filter(star(bool))
+ .afilter(star(async_identity))
.foreach(lambda _: _)
.foreach(lambda _: _, concurrency=2)
.aforeach(async_identity)
.map(cast(Callable[[Any], Any], CustomCallable()))
.amap(async_identity)
.group(100)
+ .agroup(100)
.groupby(len)
+ .agroupby(async_identity)
.map(star(lambda key, group: group))
.observe("groups")
.flatten(concurrency=4)
+ .map(sync_to_async_iter)
+ .aflatten(concurrency=4)
+ .map(lambda _: 0)
.throttle(
64,
per=datetime.timedelta(seconds=1),
)
.observe("foos")
- .catch(finally_raise=True)
+ .catch(finally_raise=True, when=identity)
+ .acatch(finally_raise=True, when=async_identity)
.catch((TypeError, ValueError, None, ZeroDivisionError))
- .catch(TypeError, replacement=None, finally_raise=True)
+ .acatch((TypeError, ValueError, None, ZeroDivisionError))
+ .catch(TypeError, replacement=1, finally_raise=True)
+ .acatch(TypeError, replacement=1, finally_raise=True)
)
print(repr(complex_stream))
@@ -255,73 +193,98 @@ class CustomCallable:
complex_stream.display(logging.ERROR)
self.assertEqual(
- """(
- Stream(range(0, 256))
-)""",
str(Stream(src)),
+ "Stream(range(0, 256))",
msg="`repr` should work as expected on a stream without operation",
)
self.assertEqual(
+ str(Stream(src).skip(10)),
+ "Stream(range(0, 256)).skip(10, until=None)",
+ msg="`repr` should return a one-liner for a stream with 1 operations",
+ )
+ self.assertEqual(
+ str(Stream(src).skip(10).skip(10)),
+ "Stream(range(0, 256)).skip(10, until=None).skip(10, until=None)",
+ msg="`repr` should return a one-liner for a stream with 2 operations",
+ )
+ self.assertEqual(
+ str(Stream(src).skip(10).skip(10).skip(10)),
"""(
Stream(range(0, 256))
- .map(<lambda>, concurrency=2, ordered=True, via='process')
+ .skip(10, until=None)
+ .skip(10, until=None)
+ .skip(10, until=None)
)""",
- str(Stream(src).map(lambda _: _, concurrency=2, via="process")),
- msg="`repr` should work as expected on a stream with 1 operation",
+ msg="`repr` should go to line for a stream with 3 operations",
)
self.assertEqual(
str(complex_stream),
"""(
Stream(range(0, 256))
.truncate(count=1024, when=<lambda>)
+ .atruncate(count=1024, when=async_identity)
.skip(10, until=None)
+ .askip(10, until=None)
.skip(None, until=<lambda>)
+ .askip(None, until=async_identity)
.distinct(<lambda>, consecutive_only=False)
+ .adistinct(async_identity, consecutive_only=False)
.filter(bool)
.map(<lambda>, concurrency=1, ordered=True)
.map(<lambda>, concurrency=2, ordered=True, via='thread')
.filter(star(bool))
+ .afilter(star(async_identity))
.foreach(<lambda>, concurrency=1, ordered=True)
.foreach(<lambda>, concurrency=2, ordered=True, via='thread')
.aforeach(async_identity, concurrency=1, ordered=True)
.map(CustomCallable(...), concurrency=1, ordered=True)
.amap(async_identity, concurrency=1, ordered=True)
.group(size=100, by=None, interval=None)
+ .agroup(size=100, by=None, interval=None)
.groupby(len, size=None, interval=None)
+ .agroupby(async_identity, size=None, interval=None)
.map(star(<lambda>), concurrency=1, ordered=True)
.observe('groups')
.flatten(concurrency=4)
+ .map(SyncToAsyncIterator, concurrency=1, ordered=True)
+ .aflatten(concurrency=4)
+ .map(<lambda>, concurrency=1, ordered=True)
.throttle(64, per=datetime.timedelta(seconds=1))
.observe('foos')
- .catch(Exception, when=None, finally_raise=True)
+ .catch(Exception, when=identity, finally_raise=True)
+ .acatch(Exception, when=async_identity, finally_raise=True)
.catch((TypeError, ValueError, None, ZeroDivisionError), when=None, finally_raise=False)
- .catch(TypeError, when=None, replacement=None, finally_raise=True)
+ .acatch((TypeError, ValueError, None, ZeroDivisionError), when=None, finally_raise=False)
+ .catch(TypeError, when=None, replacement=1, finally_raise=True)
+ .acatch(TypeError, when=None, replacement=1, finally_raise=True)
)""",
msg="`repr` should work as expected on a stream with many operation",
)
- def test_iter(self) -> None:
+ @parameterized.expand(ITERABLE_TYPES)
+ def test_iter(self, itype: IterableType) -> None:
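+ # itype parametrizes consumption as a sync Iterator or an AsyncIterator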
self.assertIsInstance(
- iter(Stream(src)),
- Iterator,
+ bi_iterable_to_iter(Stream(src), itype=itype),
+ itype,
msg="iter(stream) must return an Iterator.",
)
with self.assertRaisesRegex(
TypeError,
- r"`source` must be an Iterable or a Callable\[\[\], Iterable\] but got a ",
+ r"`source` must be an Iterable/AsyncIterable or a Callable\[\[\], Iterable/AsyncIterable\] but got a ",
msg="Getting an Iterator from a Stream with a source not being a Union[Callable[[], Iterator], ITerable] must raise TypeError.",
):
- iter(Stream(1)) # type: ignore
+ bi_iterable_to_iter(Stream(1), itype=itype) # type: ignore
with self.assertRaisesRegex(
TypeError,
- r"`source` must be an Iterable or a Callable\[\[\], Iterable\] but got a Callable\[\[\], \]",
+ r"`source` must be an Iterable/AsyncIterable or a Callable\[\[\], Iterable/AsyncIterable\] but got a Callable\[\[\], \]",
msg="Getting an Iterator from a Stream with a source not being a Union[Callable[[], Iterator], ITerable] must raise TypeError.",
):
- iter(Stream(lambda: 1)) # type: ignore
+ bi_iterable_to_iter(Stream(lambda: 1), itype=itype) # type: ignore
- def test_add(self) -> None:
+ @parameterized.expand(ITERABLE_TYPES)
+ def test_add(self, itype: IterableType) -> None:
from streamable.stream import FlattenStream
stream = Stream(src)
@@ -335,7 +298,7 @@ def test_add(self) -> None:
stream_b = Stream(range(10, 20))
stream_c = Stream(range(20, 30))
self.assertListEqual(
- list(stream_a + stream_b + stream_c),
+ to_list(stream_a + stream_b + stream_c, itype=itype),
list(range(30)),
msg="`chain` must yield the elements of the first stream the move on with the elements of the next ones and so on.",
)
@@ -345,7 +308,9 @@ def test_add(self) -> None:
[Stream.map, [identity]],
[Stream.amap, [async_identity]],
[Stream.foreach, [identity]],
+ [Stream.aforeach, [identity]],
[Stream.flatten, []],
+ [Stream.aflatten, []],
]
)
def test_sanitize_concurrency(self, method, args) -> None:
@@ -384,26 +349,30 @@ def test_sanitize_via(self, method) -> None:
method(Stream(src), identity, via="foo")
@parameterized.expand(
- [
- [1],
- [2],
- ]
+ [(concurrency, itype) for concurrency in (1, 2) for itype in ITERABLE_TYPES]
)
- def test_map(self, concurrency) -> None:
+ def test_map(self, concurrency, itype) -> None:
self.assertListEqual(
- list(Stream(src).map(randomly_slowed(square), concurrency=concurrency)),
+ to_list(
+ Stream(src).map(randomly_slowed(square), concurrency=concurrency),
+ itype=itype,
+ ),
list(map(square, src)),
msg="At any concurrency the `map` method should act as the builtin map function, transforming elements while preserving input elements order.",
)
@parameterized.expand(
[
- [True, identity],
- [False, sorted],
+ (ordered, order_mutation, itype)
+ for itype in ITERABLE_TYPES
+ for ordered, order_mutation in [
+ (True, identity),
+ (False, sorted),
+ ]
]
)
def test_process_concurrency(
- self, ordered, order_mutation
+ self, ordered, order_mutation, itype
) -> None: # pragma: no cover
# 3.7 and 3.8 are passing the test but hang forever after
if sys.version_info.minor < 9:
@@ -420,7 +389,7 @@ def local_identity(x):
"",
msg="process-based concurrency should not be able to serialize a lambda or a local func",
):
- list(Stream(src).map(f, concurrency=2, via="process"))
+ to_list(Stream(src).map(f, concurrency=2, via="process"), itype=itype)
sleeps = [0.01, 1, 0.01]
state: List[str] = []
@@ -433,7 +402,7 @@ def local_identity(x):
.foreach(lambda _: state.append(""), concurrency=1, ordered=True)
)
self.assertListEqual(
- list(stream),
+ to_list(stream, itype=itype),
expected_result_list,
msg="process-based concurrency must correctly transform elements, respecting `ordered`...",
)
@@ -444,38 +413,38 @@ def local_identity(x):
)
# test partial iteration:
self.assertEqual(
- next(iter(stream)),
+ anext_or_next(bi_iterable_to_iter(stream, itype=itype)),
expected_result_list[0],
msg="process-based concurrency must behave ok with partial iteration",
)
@parameterized.expand(
[
- [16, 0],
- [1, 0],
- [16, 1],
- [16, 15],
- [16, 16],
+ (concurrency, n_elems, itype)
+ for concurrency, n_elems in [
+ [16, 0],
+ [1, 0],
+ [16, 1],
+ [16, 15],
+ [16, 16],
+ ]
+ for itype in ITERABLE_TYPES
]
)
def test_map_with_more_concurrency_than_elements(
- self, concurrency, n_elems
+ self, concurrency, n_elems, itype
) -> None:
self.assertListEqual(
- list(Stream(range(n_elems)).map(str, concurrency=concurrency)),
+ to_list(
+ Stream(range(n_elems)).map(str, concurrency=concurrency), itype=itype
+ ),
list(map(str, range(n_elems))),
msg="`map` method should act correctly when concurrency > number of elements.",
)
@parameterized.expand(
[
- [
- ordered,
- order_mutation,
- expected_duration,
- operation,
- func,
- ]
+ [ordered, order_mutation, expected_duration, operation, func, itype]
for ordered, order_mutation, expected_duration in [
(True, identity, 0.3),
(False, sorted, 0.21),
@@ -486,6 +455,7 @@ def test_map_with_more_concurrency_than_elements(
(Stream.aforeach, asyncio.sleep),
(Stream.amap, async_identity_sleep),
]
+ for itype in ITERABLE_TYPES
]
)
def test_mapping_ordering(
@@ -495,11 +465,13 @@ def test_mapping_ordering(
expected_duration: float,
operation,
func,
+ itype,
) -> None:
seconds = [0.1, 0.01, 0.2]
duration, res = timestream(
operation(Stream(seconds), func, ordered=ordered, concurrency=2),
5,
+ itype=itype,
)
self.assertListEqual(
res,
@@ -515,23 +487,21 @@ def test_mapping_ordering(
)
@parameterized.expand(
- [
- [1],
- [2],
- ]
+ [(concurrency, itype) for concurrency in (1, 2) for itype in ITERABLE_TYPES]
)
- def test_foreach(self, concurrency) -> None:
+ def test_foreach(self, concurrency, itype) -> None:
side_collection: Set[int] = set()
def side_effect(x: int, func: Callable[[int], int]):
nonlocal side_collection
side_collection.add(func(x))
- res = list(
+ res = to_list(
Stream(src).foreach(
lambda i: randomly_slowed(side_effect(i, square)),
concurrency=concurrency,
- )
+ ),
+ itype=itype,
)
self.assertListEqual(
@@ -554,6 +524,7 @@ def side_effect(x: int, func: Callable[[int], int]):
method,
throw_func_,
throw_for_odd_func_,
+ itype,
]
for raised_exc, caught_exc in [
(TestError, (TestError,)),
@@ -562,9 +533,11 @@ def side_effect(x: int, func: Callable[[int], int]):
for concurrency in [1, 2]
for method, throw_func_, throw_for_odd_func_ in [
(Stream.foreach, throw_func, throw_for_odd_func),
+ (Stream.aforeach, async_throw_func, async_throw_for_odd_func),
(Stream.map, throw_func, throw_for_odd_func),
(Stream.amap, async_throw_func, async_throw_for_odd_func),
]
+ for itype in ITERABLE_TYPES
]
)
def test_map_or_foreach_with_exception(
@@ -575,20 +548,25 @@ def test_map_or_foreach_with_exception(
method: Callable[[Stream, Callable[[Any], int], int], Stream],
throw_func: Callable[[Exception], Callable[[Any], int]],
throw_for_odd_func: Callable[[Type[Exception]], Callable[[Any], int]],
+ itype: IterableType,
) -> None:
with self.assertRaises(
caught_exc,
msg="At any concurrency, `map` and `foreach` and `amap` must raise",
):
- list(method(Stream(src), throw_func(raised_exc), concurrency=concurrency)) # type: ignore
+ to_list(
+ method(Stream(src), throw_func(raised_exc), concurrency=concurrency), # type: ignore
+ itype=itype,
+ )
self.assertListEqual(
- list(
+ to_list(
method(
Stream(src),
throw_for_odd_func(raised_exc),
concurrency=concurrency, # type: ignore
- ).catch(caught_exc)
+ ).catch(caught_exc),
+ itype=itype,
),
list(even_src),
msg="At any concurrency, `map` and `foreach` and `amap` must not stop after one exception occured.",
@@ -596,18 +574,24 @@ def test_map_or_foreach_with_exception(
@parameterized.expand(
[
- [method, func, concurrency]
+ [method, func, concurrency, itype]
for method, func in [
(Stream.foreach, slow_identity),
+ (Stream.aforeach, async_slow_identity),
(Stream.map, slow_identity),
(Stream.amap, async_slow_identity),
]
for concurrency in [1, 2, 4]
+ for itype in ITERABLE_TYPES
]
)
- def test_map_and_foreach_concurrency(self, method, func, concurrency) -> None:
+ def test_map_and_foreach_concurrency(
+ self, method, func, concurrency, itype
+ ) -> None:
expected_iteration_duration = N * slow_identity_duration / concurrency
- duration, res = timestream(method(Stream(src), func, concurrency=concurrency))
+ duration, res = timestream(
+ method(Stream(src), func, concurrency=concurrency), itype=itype
+ )
self.assertListEqual(res, list(src))
self.assertAlmostEqual(
duration,
@@ -618,50 +602,80 @@ def test_map_and_foreach_concurrency(self, method, func, concurrency) -> None:
@parameterized.expand(
[
- [1],
- [2],
+ (concurrency, itype, flatten)
+ for concurrency in (1, 2)
+ for itype in ITERABLE_TYPES
+ for flatten in (Stream.flatten, Stream.aflatten)
]
)
- def test_flatten(self, concurrency) -> None:
+ def test_flatten(self, concurrency, itype, flatten) -> None:
n_iterables = 32
it = list(range(N // n_iterables))
double_it = it + it
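+ # sync_to_bi_iterable wraps each inner iterable so both flatten and aflatten can consume it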
iterables_stream = Stream(
- lambda: map(slow_identity, [double_it] + [it for _ in range(n_iterables)])
- )
+ [sync_to_bi_iterable(double_it)]
+ + [sync_to_bi_iterable(it) for _ in range(n_iterables)]
+ ).map(slow_identity)
self.assertCountEqual(
- list(iterables_stream.flatten(concurrency=concurrency)),
+ to_list(flatten(iterables_stream, concurrency=concurrency), itype=itype),
list(it) * n_iterables + double_it,
msg="At any concurrency the `flatten` method should yield all the upstream iterables' elements.",
)
self.assertListEqual(
- list(
- Stream([iter([]) for _ in range(2000)]).flatten(concurrency=concurrency)
+ to_list(
+ flatten(
+ Stream([sync_to_bi_iterable(iter([])) for _ in range(2000)]),
+ concurrency=concurrency,
+ ),
+ itype=itype,
),
[],
msg="`flatten` should not yield any element if upstream elements are empty iterables, and be resilient to recursion issue in case of successive empty upstream iterables.",
)
with self.assertRaises(
- TypeError,
+ (TypeError, AttributeError),
msg="`flatten` should raise if an upstream element is not iterable.",
):
- next(iter(Stream(cast(Iterable, src)).flatten()))
+ anext_or_next(
+ bi_iterable_to_iter(
+ flatten(Stream(cast(Union[Iterable, AsyncIterable], src))),
+ itype=itype,
+ )
+ )
# test typing with ranges
_: Stream[int] = Stream((src, src)).flatten()
- def test_flatten_concurrency(self) -> None:
+ @parameterized.expand(
+ [
+ (flatten, itype, slow)
+ for flatten, slow in (
+ (Stream.flatten, partial(Stream.map, transformation=slow_identity)),
+ (
+ Stream.aflatten,
+ partial(Stream.amap, transformation=async_slow_identity),
+ ),
+ )
+ for itype in ITERABLE_TYPES
+ ]
+ )
+ def test_flatten_concurrency(self, flatten, itype, slow) -> None:
+ concurrency = 2
iterable_size = 5
runtime, res = timestream(
- Stream(
- lambda: [
- Stream(map(slow_identity, ["a"] * iterable_size)),
- Stream(map(slow_identity, ["b"] * iterable_size)),
- Stream(map(slow_identity, ["c"] * iterable_size)),
- ]
- ).flatten(concurrency=2),
+ flatten(
+ Stream(
+ lambda: [
+ slow(Stream(["a"] * iterable_size)),
+ slow(Stream(["b"] * iterable_size)),
+ slow(Stream(["c"] * iterable_size)),
+ ]
+ ),
+ concurrency=concurrency,
+ ),
times=3,
+ itype=itype,
)
self.assertListEqual(
res,
@@ -669,7 +683,7 @@ def test_flatten_concurrency(self) -> None:
msg="`flatten` should process 'a's and 'b's concurrently and then 'c's",
)
a_runtime = b_runtime = c_runtime = iterable_size * slow_identity_duration
- expected_runtime = (a_runtime + b_runtime) / 2 + c_runtime
+ expected_runtime = (a_runtime + b_runtime) / concurrency + c_runtime
self.assertAlmostEqual(
runtime,
expected_runtime,
@@ -688,21 +702,29 @@ def test_flatten_typing(self) -> None:
Stream("abc").map(lambda char: filter(lambda _: True, char)).flatten()
)
+ flattened_asynciter_stream: Stream[str] = (
+ Stream("abc").map(sync_to_async_iter).aflatten()
+ )
+
@parameterized.expand(
[
- [exception_type, mapped_exception_type, concurrency]
+ [exception_type, mapped_exception_type, concurrency, itype, flatten]
+ for concurrency in [1, 2]
+ for itype in ITERABLE_TYPES
for exception_type, mapped_exception_type in [
(TestError, TestError),
- (StopIteration, WrappedError),
+ (stopiteration_for_iter_type(itype), (WrappedError, RuntimeError)),
]
- for concurrency in [1, 2]
+ for flatten in (Stream.flatten, Stream.aflatten)
]
)
- def test_flatten_with_exception(
+ def test_flatten_with_exception_in_iter(
self,
exception_type: Type[Exception],
mapped_exception_type: Type[Exception],
concurrency: int,
+ itype: IterableType,
+ flatten: Callable,
) -> None:
n_iterables = 5
@@ -710,60 +732,79 @@ class IterableRaisingInIter(Iterable[int]):
def __iter__(self) -> Iterator[int]:
raise exception_type
- self.assertSetEqual(
- set(
+ res: Set[int] = to_set(
+ flatten(
Stream(
map(
- lambda i: (
- IterableRaisingInIter()
- if i % 2
- else cast(Iterable[int], range(i, i + 1))
+ lambda i: sync_to_bi_iterable(
+ IterableRaisingInIter() if i % 2 else range(i, i + 1)
),
range(n_iterables),
)
- )
- .flatten(concurrency=concurrency)
- .catch(mapped_exception_type)
- ),
+ ),
+ concurrency=concurrency,
+ ).catch(mapped_exception_type),
+ itype=itype,
+ )
+ self.assertSetEqual(
+ res,
set(range(0, n_iterables, 2)),
- msg="At any concurrency the `flatten` method should be resilient to exceptions thrown by iterators, especially it should remap StopIteration one to PacifiedStopIteration.",
+ msg="At any concurrency the `flatten` method should be resilient to exceptions thrown by iterators, especially it should wrap Stop(Async)Iteration.",
)
+ @parameterized.expand(
+ [
+ [concurrency, itype, flatten]
+ for concurrency in [1, 2]
+ for itype in ITERABLE_TYPES
+ for flatten in (Stream.flatten, Stream.aflatten)
+ ]
+ )
+ def test_flatten_with_exception_in_next(
+ self,
+ concurrency: int,
+ itype: IterableType,
+ flatten: Callable,
+ ) -> None:
+ n_iterables = 5
+
class IteratorRaisingInNext(Iterator[int]):
def __init__(self) -> None:
self.first_next = True
- def __iter__(self) -> Iterator[int]:
- return self
-
def __next__(self) -> int:
if not self.first_next:
- raise StopIteration
+ raise StopIteration()
self.first_next = False
- raise exception_type
+ raise TestError
- self.assertSetEqual(
- set(
+ res = to_set(
+ flatten(
Stream(
map(
lambda i: (
- IteratorRaisingInNext()
- if i % 2
- else cast(Iterable[int], range(i, i + 1))
+ sync_to_bi_iterable(
+ IteratorRaisingInNext() if i % 2 else range(i, i + 1)
+ )
),
range(n_iterables),
)
- )
- .flatten(concurrency=concurrency)
- .catch(mapped_exception_type)
- ),
+ ),
+ concurrency=concurrency,
+ ).catch(TestError),
+ itype=itype,
+ )
+ self.assertSetEqual(
+ res,
set(range(0, n_iterables, 2)),
- msg="At any concurrency the `flatten` method should be resilient to exceptions thrown by iterators, especially it should remap StopIteration one to PacifiedStopIteration.",
+ msg="At any concurrency the `flatten` method should be resilient to exceptions thrown by iterators, especially it should wrap Stop(Async)Iteration.",
)
- @parameterized.expand([[concurrency] for concurrency in [2, 4]])
+ @parameterized.expand(
+ [(concurrency, itype) for concurrency in [2, 4] for itype in ITERABLE_TYPES]
+ )
def test_partial_iteration_on_streams_using_concurrency(
- self, concurrency: int
+ self, concurrency: int, itype: IterableType
) -> None:
yielded_elems = []
@@ -798,14 +839,14 @@ def remembering_src() -> Iterator[int]:
),
]:
yielded_elems = []
- iterator = iter(stream)
+ iterator = bi_iterable_to_iter(stream, itype=itype)
time.sleep(0.5)
self.assertEqual(
len(yielded_elems),
0,
msg=f"before the first call to `next` a concurrent {type(stream)} should have pulled 0 upstream elements.",
)
- next(iterator)
+ anext_or_next(iterator)
time.sleep(0.5)
self.assertEqual(
len(yielded_elems),
@@ -813,40 +854,82 @@ def remembering_src() -> Iterator[int]:
msg=f"`after the first call to `next` a concurrent {type(stream)} with concurrency={concurrency} should have pulled only {n_pulls_after_first_next} upstream elements.",
)
- def test_filter(self) -> None:
- def keep(x) -> Any:
+ @parameterized.expand(ITERABLE_TYPES)
+ def test_filter(self, itype: IterableType) -> None:
+ def keep(x) -> int:
return x % 2
self.assertListEqual(
- list(Stream(src).filter(keep)),
+ to_list(Stream(src).filter(keep), itype=itype),
list(filter(keep, src)),
msg="`filter` must act like builtin filter",
)
self.assertListEqual(
- list(Stream(src).filter()),
+ to_list(Stream(src).filter(bool), itype=itype),
list(filter(None, src)),
msg="`filter` with `bool` as predicate must act like builtin filter with None predicate.",
)
self.assertListEqual(
- list(Stream(src).filter()),
+ to_list(Stream(src).filter(), itype=itype),
list(filter(None, src)),
msg="`filter` without predicate must act like builtin filter with None predicate.",
)
self.assertListEqual(
- list(Stream(src).filter(None)), # type: ignore
+ to_list(Stream(src).filter(None), itype=itype), # type: ignore
list(filter(None, src)),
msg="`filter` with None predicate must act unofficially like builtin filter with None predicate.",
)
- # Unofficially accept `stream.filter(None)`, behaving as builtin `filter(None, iter)`
+ self.assertEqual(
+ to_list(Stream(src).filter(None), itype=itype), # type: ignore
+ list(filter(None, src)),
+ msg="Unofficially accept `stream.filter(None)`, behaving as builtin `filter(None, iter)`",
+ )
# with self.assertRaisesRegex(
# TypeError,
# "`when` cannot be None",
# msg="`filter` does not accept a None predicate",
# ):
- # list(Stream(src).filter(None)) # type: ignore
+ # to_list(Stream(src).filter(None), itype=itype) # type: ignore
+
+ @parameterized.expand(ITERABLE_TYPES)
+ def test_afilter(self, itype: IterableType) -> None:
+ def keep(x) -> int:
+ return x % 2
+
+ async def async_keep(x) -> int:
+ return keep(x)
+
+ self.assertListEqual(
+ to_list(Stream(src).afilter(async_keep), itype=itype),
+ list(filter(keep, src)),
+ msg="`afilter` must act like builtin filter",
+ )
+ self.assertListEqual(
+ to_list(Stream(src).afilter(asyncify(bool)), itype=itype),
+ list(filter(None, src)),
+ msg="`afilter` with `bool` as predicate must act like builtin filter with None predicate.",
+ )
+ self.assertListEqual(
+ to_list(Stream(src).afilter(None), itype=itype), # type: ignore
+ list(filter(None, src)),
+ msg="`afilter` with None predicate must act unofficially like builtin filter with None predicate.",
+ )
+
+ self.assertEqual(
+ to_list(Stream(src).afilter(None), itype=itype), # type: ignore
+ list(filter(None, src)),
+ msg="Unofficially accept `stream.afilter(None)`, behaving as builtin `filter(None, iter)`",
+ )
+ # with self.assertRaisesRegex(
+ # TypeError,
+ # "`when` cannot be None",
+ # msg="`afilter` does not accept a None predicate",
+ # ):
+ # to_list(Stream(src).afilter(None), itype=itype) # type: ignore
- def test_skip(self) -> None:
+ @parameterized.expand(ITERABLE_TYPES)
+ def test_skip(self, itype: IterableType) -> None:
with self.assertRaisesRegex(
ValueError,
"`count` must be >= 0 but got -1",
@@ -855,89 +938,164 @@ def test_skip(self) -> None:
Stream(src).skip(-1)
self.assertListEqual(
- list(Stream(src).skip()),
+ to_list(Stream(src).skip(), itype=itype),
list(src),
msg="`skip` must be no-op if both `count` and `until` are None",
)
self.assertListEqual(
- list(Stream(src).skip(None)),
+ to_list(Stream(src).skip(None), itype=itype),
list(src),
msg="`skip` must be no-op if both `count` and `until` are None",
)
for count in [0, 1, 3]:
self.assertListEqual(
- list(Stream(src).skip(count)),
+ to_list(Stream(src).skip(count), itype=itype),
list(src)[count:],
msg="`skip` must skip `count` elements",
)
self.assertListEqual(
- list(
+ to_list(
Stream(map(throw_for_odd_func(TestError), src))
.skip(count)
- .catch(TestError)
+ .catch(TestError),
+ itype=itype,
),
list(filter(lambda i: i % 2 == 0, src))[count:],
msg="`skip` must not count exceptions as skipped elements",
)
self.assertListEqual(
- list(Stream(src).skip(until=lambda n: n >= count)),
+ to_list(Stream(src).skip(until=lambda n: n >= count), itype=itype),
list(src)[count:],
msg="`skip` must yield starting from the first element satisfying `until`",
)
self.assertListEqual(
- list(Stream(src).skip(count, until=lambda n: False)),
+ to_list(Stream(src).skip(count, until=lambda n: False), itype=itype),
list(src)[count:],
msg="`skip` must ignore `count` elements if `until` is never satisfied",
)
self.assertListEqual(
- list(Stream(src).skip(count * 2, until=lambda n: n >= count)),
+ to_list(
+ Stream(src).skip(count * 2, until=lambda n: n >= count), itype=itype
+ ),
list(src)[count:],
msg="`skip` must ignore less than `count` elements if `until` is satisfied first",
)
self.assertListEqual(
- list(Stream(src).skip(until=lambda n: False)),
+ to_list(Stream(src).skip(until=lambda n: False), itype=itype),
[],
msg="`skip` must not yield any element if `until` is never satisfied",
)
- def test_truncate(self) -> None:
+ @parameterized.expand(ITERABLE_TYPES)
+ def test_askip(self, itype: IterableType) -> None:
+ with self.assertRaisesRegex(
+ ValueError,
+ "`count` must be >= 0 but got -1",
+ msg="`askip` must raise ValueError if `count` is negative",
+ ):
+ Stream(src).askip(-1)
+
+ self.assertListEqual(
+ to_list(Stream(src).askip(), itype=itype),
+ list(src),
+ msg="`askip` must be no-op if both `count` and `until` are None",
+ )
+
+ self.assertListEqual(
+ to_list(Stream(src).askip(None), itype=itype),
+ list(src),
+ msg="`askip` must be no-op if both `count` and `until` are None",
+ )
+
+ for count in [0, 1, 3]:
+ self.assertListEqual(
+ to_list(Stream(src).askip(count), itype=itype),
+ list(src)[count:],
+ msg="`askip` must skip `count` elements",
+ )
+
+ self.assertListEqual(
+ to_list(
+ Stream(map(throw_for_odd_func(TestError), src))
+ .askip(count)
+ .catch(TestError),
+ itype=itype,
+ ),
+ list(filter(lambda i: i % 2 == 0, src))[count:],
+ msg="`askip` must not count exceptions as skipped elements",
+ )
+
+ self.assertListEqual(
+ to_list(
+ Stream(src).askip(until=asyncify(lambda n: n >= count)), itype=itype
+ ),
+ list(src)[count:],
+ msg="`askip` must yield starting from the first element satisfying `until`",
+ )
+
+ self.assertListEqual(
+ to_list(
+ Stream(src).askip(count, until=asyncify(lambda n: False)),
+ itype=itype,
+ ),
+ list(src)[count:],
+ msg="`askip` must ignore `count` elements if `until` is never satisfied",
+ )
+
+ self.assertListEqual(
+ to_list(
+ Stream(src).askip(count * 2, until=asyncify(lambda n: n >= count)),
+ itype=itype,
+ ),
+ list(src)[count:],
+ msg="`askip` must ignore less than `count` elements if `until` is satisfied first",
+ )
+
+ self.assertListEqual(
+ to_list(Stream(src).askip(until=asyncify(lambda n: False)), itype=itype),
+ [],
+ msg="`askip` must not yield any element if `until` is never satisfied",
+ )
+
+ @parameterized.expand(ITERABLE_TYPES)
+ def test_truncate(self, itype: IterableType) -> None:
self.assertListEqual(
- list(Stream(src).truncate(N * 2)),
+ to_list(Stream(src).truncate(N * 2), itype=itype),
list(src),
msg="`truncate` must be ok with count >= stream length",
)
self.assertListEqual(
- list(Stream(src).truncate()),
+ to_list(Stream(src).truncate(), itype=itype),
list(src),
msg="`truncate must be no-op if both `count` and `when` are None",
)
self.assertListEqual(
- list(Stream(src).truncate(None)),
+ to_list(Stream(src).truncate(None), itype=itype),
list(src),
msg="`truncate must be no-op if both `count` and `when` are None",
)
self.assertListEqual(
- list(Stream(src).truncate(2)),
+ to_list(Stream(src).truncate(2), itype=itype),
[0, 1],
msg="`truncate` must be ok with count >= 1",
)
self.assertListEqual(
- list(Stream(src).truncate(1)),
+ to_list(Stream(src).truncate(1), itype=itype),
[0],
msg="`truncate` must be ok with count == 1",
)
self.assertListEqual(
- list(Stream(src).truncate(0)),
+ to_list(Stream(src).truncate(0), itype=itype),
[],
msg="`truncate` must be ok with count == 0",
)
@@ -950,106 +1108,222 @@ def test_truncate(self) -> None:
Stream(src).truncate(-1)
self.assertListEqual(
- list(Stream(src).truncate(cast(int, float("inf")))),
+ to_list(Stream(src).truncate(cast(int, float("inf"))), itype=itype),
list(src),
msg="`truncate` must be no-op if `count` is inf",
)
count = N // 2
- raising_stream_iterator = iter(
- Stream(lambda: map(lambda x: round((1 / x) * x**2), src)).truncate(count)
+ raising_stream_iterator = bi_iterable_to_iter(
+ Stream(lambda: map(lambda x: round((1 / x) * x**2), src)).truncate(count),
+ itype=itype,
)
with self.assertRaises(
ZeroDivisionError,
msg="`truncate` must not stop iteration when encountering exceptions and raise them without counting them...",
):
- next(raising_stream_iterator)
+ anext_or_next(raising_stream_iterator)
- self.assertListEqual(list(raising_stream_iterator), list(range(1, count + 1)))
+ self.assertListEqual(
+ alist_or_list(raising_stream_iterator), list(range(1, count + 1))
+ )
with self.assertRaises(
- StopIteration,
+ stopiteration_for_iter_type(type(raising_stream_iterator)),
msg="... and after reaching the limit it still continues to raise StopIteration on calls to next",
):
- next(raising_stream_iterator)
+ anext_or_next(raising_stream_iterator)
- iter_truncated_on_predicate = iter(Stream(src).truncate(when=lambda n: n == 5))
+ iter_truncated_on_predicate = bi_iterable_to_iter(
+ Stream(src).truncate(when=lambda n: n == 5), itype=itype
+ )
self.assertListEqual(
- list(iter_truncated_on_predicate),
- list(Stream(src).truncate(5)),
+ alist_or_list(iter_truncated_on_predicate),
+ to_list(Stream(src).truncate(5), itype=itype),
msg="`when` n == 5 must be equivalent to `count` = 5",
)
with self.assertRaises(
- StopIteration,
+ stopiteration_for_iter_type(type(iter_truncated_on_predicate)),
msg="After exhaustion a call to __next__ on a truncated iterator must raise StopIteration",
):
- next(iter_truncated_on_predicate)
+ anext_or_next(iter_truncated_on_predicate)
with self.assertRaises(
ZeroDivisionError,
msg="an exception raised by `when` must be raised",
):
- list(Stream(src).truncate(when=lambda _: 1 / 0))
+ to_list(Stream(src).truncate(when=lambda _: 1 / 0), itype=itype)
self.assertListEqual(
- list(Stream(src).truncate(6, when=lambda n: n == 5)),
+ to_list(Stream(src).truncate(6, when=lambda n: n == 5), itype=itype),
list(range(5)),
msg="`when` and `count` argument can be set at the same time, and the truncation should happen as soon as one or the other is satisfied.",
)
self.assertListEqual(
- list(Stream(src).truncate(5, when=lambda n: n == 6)),
+ to_list(Stream(src).truncate(5, when=lambda n: n == 6), itype=itype),
list(range(5)),
msg="`when` and `count` argument can be set at the same time, and the truncation should happen as soon as one or the other is satisfied.",
)
- def test_group(self) -> None:
- # behavior with invalid arguments
- for seconds in [-1, 0]:
- with self.assertRaises(
- ValueError,
- msg="`group` should raise error when called with `seconds` <= 0.",
- ):
- list(
- Stream([1]).group(
- size=100, interval=datetime.timedelta(seconds=seconds)
- )
- )
+ @parameterized.expand(ITERABLE_TYPES)
+ def test_atruncate(self, itype: IterableType) -> None:
+ self.assertListEqual(
+ to_list(Stream(src).atruncate(N * 2), itype=itype),
+ list(src),
+ msg="`atruncate` must be ok with count >= stream length",
+ )
- for size in [-1, 0]:
- with self.assertRaises(
- ValueError,
- msg="`group` should raise error when called with `size` < 1.",
- ):
- list(Stream([1]).group(size=size))
+ self.assertListEqual(
+ to_list(Stream(src).atruncate(), itype=itype),
+ list(src),
+ msg="`atruncate` must be no-op if both `count` and `when` are None",
+ )
- # group size
self.assertListEqual(
- list(Stream(range(6)).group(size=4)),
- [[0, 1, 2, 3], [4, 5]],
- msg="",
+ to_list(Stream(src).atruncate(None), itype=itype),
+ list(src),
+ msg="`atruncate` must be no-op if both `count` and `when` are None",
)
+
self.assertListEqual(
- list(Stream(range(6)).group(size=2)),
- [[0, 1], [2, 3], [4, 5]],
- msg="",
+ to_list(Stream(src).atruncate(2), itype=itype),
+ [0, 1],
+ msg="`atruncate` must be ok with count >= 1",
)
self.assertListEqual(
- list(Stream([]).group(size=2)),
+ to_list(Stream(src).atruncate(1), itype=itype),
+ [0],
+ msg="`atruncate` must be ok with count == 1",
+ )
+ self.assertListEqual(
+ to_list(Stream(src).atruncate(0), itype=itype),
[],
- msg="",
+ msg="`atruncate` must be ok with count == 0",
)
- # behavior with exceptions
- def f(i):
- return i / (110 - i)
+ with self.assertRaisesRegex(
+ ValueError,
+ "`count` must be >= 0 but got -1",
+ msg="`atruncate` must raise ValueError if `count` is negative",
+ ):
+ Stream(src).atruncate(-1)
- stream_iterator = iter(Stream(lambda: map(f, src)).group(100))
- next(stream_iterator)
self.assertListEqual(
- next(stream_iterator),
- list(map(f, range(100, 110))),
+ to_list(Stream(src).atruncate(cast(int, float("inf"))), itype=itype),
+ list(src),
+ msg="`atruncate` must be no-op if `count` is inf",
+ )
+
+ count = N // 2
+ raising_stream_iterator = bi_iterable_to_iter(
+ Stream(lambda: map(lambda x: round((1 / x) * x**2), src)).atruncate(count),
+ itype=itype,
+ )
+
+ with self.assertRaises(
+ ZeroDivisionError,
+ msg="`atruncate` must not stop iteration when encountering exceptions and raise them without counting them...",
+ ):
+ anext_or_next(raising_stream_iterator)
+
+ self.assertListEqual(
+ alist_or_list(raising_stream_iterator), list(range(1, count + 1))
+ )
+
+ with self.assertRaises(
+ stopiteration_for_iter_type(type(raising_stream_iterator)),
+ msg="... and after reaching the limit it still continues to raise StopIteration on calls to next",
+ ):
+ anext_or_next(raising_stream_iterator)
+
+ iter_truncated_on_predicate = bi_iterable_to_iter(
+ Stream(src).atruncate(when=asyncify(lambda n: n == 5)), itype=itype
+ )
+ self.assertListEqual(
+ alist_or_list(iter_truncated_on_predicate),
+ to_list(Stream(src).atruncate(5), itype=itype),
+ msg="`when` n == 5 must be equivalent to `count` = 5",
+ )
+ with self.assertRaises(
+ stopiteration_for_iter_type(type(iter_truncated_on_predicate)),
+ msg="After exhaustion a call to __next__ on a truncated iterator must raise StopIteration",
+ ):
+ anext_or_next(iter_truncated_on_predicate)
+
+ with self.assertRaises(
+ ZeroDivisionError,
+ msg="an exception raised by `when` must be raised",
+ ):
+ to_list(Stream(src).atruncate(when=asyncify(lambda _: 1 / 0)), itype=itype)
+
+ self.assertListEqual(
+ to_list(
+ Stream(src).atruncate(6, when=asyncify(lambda n: n == 5)), itype=itype
+ ),
+ list(range(5)),
+ msg="`when` and `count` argument can be set at the same time, and the truncation should happen as soon as one or the other is satisfied.",
+ )
+
+ self.assertListEqual(
+ to_list(
+ Stream(src).atruncate(5, when=asyncify(lambda n: n == 6)), itype=itype
+ ),
+ list(range(5)),
+ msg="`when` and `count` argument can be set at the same time, and the truncation should happen as soon as one or the other is satisfied.",
+ )
+
+ @parameterized.expand(ITERABLE_TYPES)
+ def test_group(self, itype: IterableType) -> None:
+ # behavior with invalid arguments
+ for seconds in [-1, 0]:
+ with self.assertRaises(
+ ValueError,
+ msg="`group` should raise error when called with `seconds` <= 0.",
+ ):
+ to_list(
+ Stream([1]).group(
+ size=100, interval=datetime.timedelta(seconds=seconds)
+ ),
+ itype=itype,
+ )
+
+ for size in [-1, 0]:
+ with self.assertRaises(
+ ValueError,
+ msg="`group` should raise error when called with `size` < 1.",
+ ):
+ to_list(Stream([1]).group(size=size), itype=itype)
+
+ # group size
+ self.assertListEqual(
+ to_list(Stream(range(6)).group(size=4), itype=itype),
+ [[0, 1, 2, 3], [4, 5]],
+ msg="",
+ )
+ self.assertListEqual(
+ to_list(Stream(range(6)).group(size=2), itype=itype),
+ [[0, 1], [2, 3], [4, 5]],
+ msg="",
+ )
+ self.assertListEqual(
+ to_list(Stream([]).group(size=2), itype=itype),
+ [],
+ msg="",
+ )
+
+ # behavior with exceptions
+ def f(i):
+ return i / (110 - i)
+
+ stream_iterator = bi_iterable_to_iter(
+ Stream(lambda: map(f, src)).group(100), itype=itype
+ )
+ anext_or_next(stream_iterator)
+ self.assertListEqual(
+ anext_or_next(stream_iterator),
+ list(map(f, range(100, 110))),
msg="when encountering upstream exception, `group` should yield the current accumulated group...",
)
@@ -1057,90 +1331,110 @@ def f(i):
ZeroDivisionError,
msg="... and raise the upstream exception during the next call to `next`...",
):
- next(stream_iterator)
+ anext_or_next(stream_iterator)
self.assertListEqual(
- next(stream_iterator),
+ anext_or_next(stream_iterator),
list(map(f, range(111, 211))),
msg="... and restarting a fresh group to yield after that.",
)
# behavior of the `seconds` parameter
self.assertListEqual(
- list(
+ to_list(
Stream(lambda: map(slow_identity, src)).group(
size=100,
interval=datetime.timedelta(seconds=slow_identity_duration / 1000),
- )
+ ),
+ itype=itype,
),
list(map(lambda e: [e], src)),
msg="`group` should not yield empty groups even though `interval` if smaller than upstream's frequency",
)
self.assertListEqual(
- list(
+ to_list(
Stream(lambda: map(slow_identity, src)).group(
size=100,
interval=datetime.timedelta(seconds=slow_identity_duration / 1000),
by=lambda _: None,
- )
+ ),
+ itype=itype,
),
list(map(lambda e: [e], src)),
msg="`group` with `by` argument should not yield empty groups even though `interval` if smaller than upstream's frequency",
)
self.assertListEqual(
- list(
+ to_list(
Stream(lambda: map(slow_identity, src)).group(
size=100,
interval=datetime.timedelta(
seconds=2 * slow_identity_duration * 0.99
),
- )
+ ),
+ itype=itype,
),
list(map(lambda e: [e, e + 1], even_src)),
msg="`group` should yield upstream elements in a two-element group if `interval` inferior to twice the upstream yield period",
)
self.assertListEqual(
- next(iter(Stream(src).group())),
+ anext_or_next(bi_iterable_to_iter(Stream(src).group(), itype=itype)),
list(src),
msg="`group` without arguments should group the elements all together",
)
+ # test groupby
+ groupby_stream_iter: Union[
+ Iterator[Tuple[int, List[int]]], AsyncIterator[Tuple[int, List[int]]]
+ ] = bi_iterable_to_iter(
+ Stream(src).groupby(lambda n: n % 2, size=2), itype=itype
+ )
+ self.assertListEqual(
+ [anext_or_next(groupby_stream_iter), anext_or_next(groupby_stream_iter)],
+ [(0, [0, 2]), (1, [1, 3])],
+ msg="`groupby` must cogroup elements.",
+ )
+
# test by
- stream_iter = iter(Stream(src).group(size=2, by=lambda n: n % 2))
+ stream_iter = bi_iterable_to_iter(
+ Stream(src).group(size=2, by=lambda n: n % 2), itype=itype
+ )
self.assertListEqual(
- [next(stream_iter), next(stream_iter)],
+ [anext_or_next(stream_iter), anext_or_next(stream_iter)],
[[0, 2], [1, 3]],
msg="`group` called with a `by` function must cogroup elements.",
)
self.assertListEqual(
- next(
- iter(
+ anext_or_next(
+ bi_iterable_to_iter(
Stream(src_raising_at_exhaustion).group(
size=10, by=lambda n: n % 4 != 0
- )
- )
+ ),
+ itype=itype,
+ ),
),
[1, 2, 3, 5, 6, 7, 9, 10, 11, 13],
msg="`group` called with a `by` function and a `size` should yield the first batch becoming full.",
)
self.assertListEqual(
- list(Stream(src).group(by=lambda n: n % 2)),
+ to_list(Stream(src).group(by=lambda n: n % 2), itype=itype),
[list(range(0, N, 2)), list(range(1, N, 2))],
msg="`group` called with a `by` function and an infinite size must cogroup elements and yield groups starting with the group containing the oldest element.",
)
self.assertListEqual(
- list(Stream(range(10)).group(by=lambda n: n % 4 == 0)),
+ to_list(Stream(range(10)).group(by=lambda n: n % 4 == 0), itype=itype),
[[0, 4, 8], [1, 2, 3, 5, 6, 7, 9]],
msg="`group` called with a `by` function and reaching exhaustion must cogroup elements and yield uncomplete groups starting with the group containing the oldest element, even though it's not the largest.",
)
- stream_iter = iter(Stream(src_raising_at_exhaustion).group(by=lambda n: n % 2))
+ stream_iter = bi_iterable_to_iter(
+ Stream(src_raising_at_exhaustion).group(by=lambda n: n % 2), itype=itype
+ )
self.assertListEqual(
- [next(stream_iter), next(stream_iter)],
+ [anext_or_next(stream_iter), anext_or_next(stream_iter)],
[list(range(0, N, 2)), list(range(1, N, 2))],
msg="`group` called with a `by` function and encountering an exception must cogroup elements and yield uncomplete groups starting with the group containing the oldest element.",
)
@@ -1148,56 +1442,277 @@ def f(i):
TestError,
msg="`group` called with a `by` function and encountering an exception must raise it after all groups have been yielded",
):
- next(stream_iter)
+ anext_or_next(stream_iter)
# test seconds + by
self.assertListEqual(
- list(
+ to_list(
Stream(lambda: map(slow_identity, range(10))).group(
interval=datetime.timedelta(seconds=slow_identity_duration * 2.9),
by=lambda n: n % 4 == 0,
- )
+ ),
+ itype=itype,
),
[[1, 2], [0, 4], [3, 5, 6, 7], [8], [9]],
msg="`group` called with a `by` function must cogroup elements and yield the largest groups when `seconds` is reached event though it's not the oldest.",
)
- stream_iter = iter(
+ stream_iter = bi_iterable_to_iter(
Stream(src).group(
- size=3, by=lambda n: throw(StopIteration) if n == 2 else n
- )
+ size=3,
+ by=lambda n: throw(stopiteration_for_iter_type(itype)) if n == 2 else n,
+ ),
+ itype=itype,
)
self.assertListEqual(
- [next(stream_iter), next(stream_iter)],
+ [anext_or_next(stream_iter), anext_or_next(stream_iter)],
[[0], [1]],
msg="`group` should yield incomplete groups when `by` raises",
)
with self.assertRaisesRegex(
- WrappedError,
- "StopIteration()",
+ (WrappedError, RuntimeError),
+ stopiteration_for_iter_type(itype).__name__,
msg="`group` should raise and skip `elem` if `by(elem)` raises",
):
- next(stream_iter)
+ anext_or_next(stream_iter)
self.assertListEqual(
- next(stream_iter),
+ anext_or_next(stream_iter),
[3],
msg="`group` should continue yielding after `by`'s exception has been raised.",
)
- def test_throttle(self) -> None:
+ @parameterized.expand(ITERABLE_TYPES)
+ def test_agroup(self, itype: IterableType) -> None:
+ # behavior with invalid arguments
+ for seconds in [-1, 0]:
+ with self.assertRaises(
+ ValueError,
+ msg="`agroup` should raise error when called with `seconds` <= 0.",
+ ):
+ to_list(
+ Stream([1]).agroup(
+ size=100, interval=datetime.timedelta(seconds=seconds)
+ ),
+ itype=itype,
+ )
+
+ for size in [-1, 0]:
+ with self.assertRaises(
+ ValueError,
+ msg="`agroup` should raise error when called with `size` < 1.",
+ ):
+ to_list(Stream([1]).agroup(size=size), itype=itype)
+
+ # group size
+ self.assertListEqual(
+ to_list(Stream(range(6)).agroup(size=4), itype=itype),
+ [[0, 1, 2, 3], [4, 5]],
+ msg="",
+ )
+ self.assertListEqual(
+ to_list(Stream(range(6)).agroup(size=2), itype=itype),
+ [[0, 1], [2, 3], [4, 5]],
+ msg="",
+ )
+ self.assertListEqual(
+ to_list(Stream([]).agroup(size=2), itype=itype),
+ [],
+ msg="",
+ )
+
+ # behavior with exceptions
+ def f(i):
+ return i / (110 - i)
+
+ stream_iterator = bi_iterable_to_iter(
+ Stream(lambda: map(f, src)).agroup(100), itype=itype
+ )
+ anext_or_next(stream_iterator)
+ self.assertListEqual(
+ anext_or_next(stream_iterator),
+ list(map(f, range(100, 110))),
+ msg="when encountering upstream exception, `agroup` should yield the current accumulated group...",
+ )
+
+ with self.assertRaises(
+ ZeroDivisionError,
+ msg="... and raise the upstream exception during the next call to `next`...",
+ ):
+ anext_or_next(stream_iterator)
+
+ self.assertListEqual(
+ anext_or_next(stream_iterator),
+ list(map(f, range(111, 211))),
+ msg="... and restarting a fresh group to yield after that.",
+ )
+
+ # behavior of the `seconds` parameter
+ self.assertListEqual(
+ to_list(
+ Stream(lambda: map(slow_identity, src)).agroup(
+ size=100,
+ interval=datetime.timedelta(seconds=slow_identity_duration / 1000),
+ ),
+ itype=itype,
+ ),
+ list(map(lambda e: [e], src)),
+ msg="`agroup` should not yield empty groups even though `interval` if smaller than upstream's frequency",
+ )
+ self.assertListEqual(
+ to_list(
+ Stream(lambda: map(slow_identity, src)).agroup(
+ size=100,
+ interval=datetime.timedelta(seconds=slow_identity_duration / 1000),
+ by=asyncify(lambda _: None),
+ ),
+ itype=itype,
+ ),
+ list(map(lambda e: [e], src)),
+ msg="`agroup` with `by` argument should not yield empty groups even though `interval` if smaller than upstream's frequency",
+ )
+ self.assertListEqual(
+ to_list(
+ Stream(lambda: map(slow_identity, src)).agroup(
+ size=100,
+ interval=datetime.timedelta(
+ seconds=2 * slow_identity_duration * 0.99
+ ),
+ ),
+ itype=itype,
+ ),
+ list(map(lambda e: [e, e + 1], even_src)),
+ msg="`agroup` should yield upstream elements in a two-element group if `interval` inferior to twice the upstream yield period",
+ )
+
+ self.assertListEqual(
+ anext_or_next(bi_iterable_to_iter(Stream(src).agroup(), itype=itype)),
+ list(src),
+ msg="`agroup` without arguments should group the elements all together",
+ )
+
+ # test agroupby
+ groupby_stream_iter: Union[
+ Iterator[Tuple[int, List[int]]], AsyncIterator[Tuple[int, List[int]]]
+ ] = bi_iterable_to_iter(
+ Stream(src).agroupby(asyncify(lambda n: n % 2), size=2), itype=itype
+ )
+ self.assertListEqual(
+ [anext_or_next(groupby_stream_iter), anext_or_next(groupby_stream_iter)],
+ [(0, [0, 2]), (1, [1, 3])],
+ msg="`agroupby` must cogroup elements.",
+ )
+
+ # test by
+ stream_iter = bi_iterable_to_iter(
+ Stream(src).agroup(size=2, by=asyncify(lambda n: n % 2)), itype=itype
+ )
+ self.assertListEqual(
+ [anext_or_next(stream_iter), anext_or_next(stream_iter)],
+ [[0, 2], [1, 3]],
+ msg="`agroup` called with a `by` function must cogroup elements.",
+ )
+
+ self.assertListEqual(
+ anext_or_next(
+ bi_iterable_to_iter(
+ Stream(src_raising_at_exhaustion).agroup(
+ size=10, by=asyncify(lambda n: n % 4 != 0)
+ ),
+ itype=itype,
+ ),
+ ),
+ [1, 2, 3, 5, 6, 7, 9, 10, 11, 13],
+ msg="`agroup` called with a `by` function and a `size` should yield the first batch becoming full.",
+ )
+
+ self.assertListEqual(
+ to_list(Stream(src).agroup(by=asyncify(lambda n: n % 2)), itype=itype),
+ [list(range(0, N, 2)), list(range(1, N, 2))],
+ msg="`agroup` called with a `by` function and an infinite size must cogroup elements and yield groups starting with the group containing the oldest element.",
+ )
+
+ self.assertListEqual(
+ to_list(
+ Stream(range(10)).agroup(by=asyncify(lambda n: n % 4 == 0)), itype=itype
+ ),
+ [[0, 4, 8], [1, 2, 3, 5, 6, 7, 9]],
+ msg="`agroup` called with a `by` function and reaching exhaustion must cogroup elements and yield uncomplete groups starting with the group containing the oldest element, even though it's not the largest.",
+ )
+
+ stream_iter = bi_iterable_to_iter(
+ Stream(src_raising_at_exhaustion).agroup(by=asyncify(lambda n: n % 2)),
+ itype=itype,
+ )
+ self.assertListEqual(
+ [anext_or_next(stream_iter), anext_or_next(stream_iter)],
+ [list(range(0, N, 2)), list(range(1, N, 2))],
+ msg="`agroup` called with a `by` function and encountering an exception must cogroup elements and yield uncomplete groups starting with the group containing the oldest element.",
+ )
+ with self.assertRaises(
+ TestError,
+ msg="`agroup` called with a `by` function and encountering an exception must raise it after all groups have been yielded",
+ ):
+ anext_or_next(stream_iter)
+
+ # test seconds + by
+ self.assertListEqual(
+ to_list(
+ Stream(lambda: map(slow_identity, range(10))).agroup(
+ interval=datetime.timedelta(seconds=slow_identity_duration * 2.9),
+ by=asyncify(lambda n: n % 4 == 0),
+ ),
+ itype=itype,
+ ),
+ [[1, 2], [0, 4], [3, 5, 6, 7], [8], [9]],
+ msg="`agroup` called with a `by` function must cogroup elements and yield the largest groups when `seconds` is reached event though it's not the oldest.",
+ )
+
+ stream_iter = bi_iterable_to_iter(
+ Stream(src).agroup(
+ size=3,
+ by=asyncify(
+ lambda n: throw(stopiteration_for_iter_type(itype)) if n == 2 else n
+ ),
+ ),
+ itype=itype,
+ )
+ self.assertListEqual(
+ [anext_or_next(stream_iter), anext_or_next(stream_iter)],
+ [[0], [1]],
+ msg="`agroup` should yield incomplete groups when `by` raises",
+ )
+ with self.assertRaisesRegex(
+ (WrappedError, RuntimeError),
+ stopiteration_for_iter_type(itype).__name__,
+ msg="`agroup` should raise and skip `elem` if `by(elem)` raises",
+ ):
+ anext_or_next(stream_iter)
+ self.assertListEqual(
+ anext_or_next(stream_iter),
+ [3],
+ msg="`agroup` should continue yielding after `by`'s exception has been raised.",
+ )
+
+ @parameterized.expand(ITERABLE_TYPES)
+ def test_throttle(self, itype: IterableType) -> None:
# behavior with invalid arguments
with self.assertRaisesRegex(
ValueError,
r"`per` must be None or a positive timedelta but got datetime\.timedelta\(0\)",
msg="`throttle` should raise error when called with negative `per`.",
):
- list(Stream([1]).throttle(1, per=datetime.timedelta(microseconds=0)))
+ to_list(
+ Stream([1]).throttle(1, per=datetime.timedelta(microseconds=0)),
+ itype=itype,
+ )
with self.assertRaisesRegex(
ValueError,
"`count` must be >= 1 but got 0",
msg="`throttle` should raise error when called with `count` < 1.",
):
- list(Stream([1]).throttle(0, per=datetime.timedelta(seconds=1)))
+ to_list(
+ Stream([1]).throttle(0, per=datetime.timedelta(seconds=1)), itype=itype
+ )
# test interval
interval_seconds = 0.3
@@ -1228,8 +1743,7 @@ def slow_first_elem(elem: int):
],
):
with self.subTest(stream=stream):
- duration, res = timestream(stream)
-
+ duration, res = timestream(stream, itype=itype)
self.assertListEqual(
res,
expected_elems,
@@ -1246,11 +1760,12 @@ def slow_first_elem(elem: int):
)
self.assertEqual(
- next(
- iter(
+ anext_or_next(
+ bi_iterable_to_iter(
Stream(src)
.throttle(1, per=datetime.timedelta(seconds=0.2))
- .throttle(1, per=datetime.timedelta(seconds=0.1))
+ .throttle(1, per=datetime.timedelta(seconds=0.1)),
+ itype=itype,
)
),
0,
@@ -1280,7 +1795,7 @@ def slow_first_elem(elem: int):
],
):
with self.subTest(N=N, stream=stream):
- duration, res = timestream(stream)
+ duration, res = timestream(stream, itype=itype)
self.assertListEqual(
res,
expected_elems,
@@ -1306,7 +1821,7 @@ def slow_first_elem(elem: int):
.throttle(1, per=datetime.timedelta(seconds=0.2)),
]:
with self.subTest(stream=stream):
- duration, _ = timestream(stream)
+ duration, _ = timestream(stream, itype=itype)
self.assertAlmostEqual(
duration,
expected_duration,
@@ -1368,29 +1883,35 @@ def slow_first_elem(elem: int):
msg="`throttle` must support legacy kwargs",
)
- def test_distinct(self) -> None:
+ @parameterized.expand(ITERABLE_TYPES)
+ def test_distinct(self, itype: IterableType) -> None:
self.assertListEqual(
- list(Stream("abbcaabcccddd").distinct()),
+ to_list(Stream("abbcaabcccddd").distinct(), itype=itype),
list("abcd"),
msg="`distinct` should yield distinct elements",
)
self.assertListEqual(
- list(Stream("aabbcccaabbcccc").distinct(consecutive_only=True)),
+ to_list(
+ Stream("aabbcccaabbcccc").distinct(consecutive_only=True), itype=itype
+ ),
list("abcabc"),
msg="`distinct` should only remove the duplicates that are consecutive if `consecutive_only=True`",
)
for consecutive_only in [True, False]:
self.assertListEqual(
- list(
+ to_list(
Stream(["foo", "bar", "a", "b"]).distinct(
len, consecutive_only=consecutive_only
- )
+ ),
+ itype=itype,
),
["foo", "a"],
msg="`distinct` should yield the first encountered elem among duplicates",
)
self.assertListEqual(
- list(Stream([]).distinct(consecutive_only=consecutive_only)),
+ to_list(
+ Stream([]).distinct(consecutive_only=consecutive_only), itype=itype
+ ),
[],
msg="`distinct` should yield zero elements on empty stream",
)
@@ -1399,11 +1920,51 @@ def test_distinct(self) -> None:
"unhashable type: 'list'",
msg="`distinct` should raise for non-hashable elements",
):
- list(Stream([[1]]).distinct())
+ to_list(Stream([[1]]).distinct(), itype=itype)
- def test_catch(self) -> None:
+ @parameterized.expand(ITERABLE_TYPES)
+ def test_adistinct(self, itype: IterableType) -> None:
+ self.assertListEqual(
+ to_list(Stream("abbcaabcccddd").adistinct(), itype=itype),
+ list("abcd"),
+ msg="`adistinct` should yield distinct elements",
+ )
self.assertListEqual(
- list(Stream(src).catch(finally_raise=True)),
+ to_list(
+ Stream("aabbcccaabbcccc").adistinct(consecutive_only=True), itype=itype
+ ),
+ list("abcabc"),
+ msg="`adistinct` should only remove the duplicates that are consecutive if `consecutive_only=True`",
+ )
+ for consecutive_only in [True, False]:
+ self.assertListEqual(
+ to_list(
+ Stream(["foo", "bar", "a", "b"]).adistinct(
+ asyncify(len), consecutive_only=consecutive_only
+ ),
+ itype=itype,
+ ),
+ ["foo", "a"],
+ msg="`adistinct` should yield the first encountered elem among duplicates",
+ )
+ self.assertListEqual(
+ to_list(
+ Stream([]).adistinct(consecutive_only=consecutive_only), itype=itype
+ ),
+ [],
+ msg="`adistinct` should yield zero elements on empty stream",
+ )
+ with self.assertRaisesRegex(
+ TypeError,
+ "unhashable type: 'list'",
+ msg="`adistinct` should raise for non-hashable elements",
+ ):
+ to_list(Stream([[1]]).adistinct(), itype=itype)
+
+ @parameterized.expand(ITERABLE_TYPES)
+ def test_catch(self, itype: IterableType) -> None:
+ self.assertListEqual(
+ to_list(Stream(src).catch(finally_raise=True), itype=itype),
list(src),
msg="`catch` should yield elements in exception-less scenarios",
)
@@ -1431,12 +1992,12 @@ def f(i):
safe_src = list(src)
del safe_src[3]
self.assertListEqual(
- list(stream.catch(ZeroDivisionError)),
+ to_list(stream.catch(ZeroDivisionError), itype=itype),
list(map(f, safe_src)),
msg="If the exception type matches the `error_type`, then the impacted element should be ignored.",
)
self.assertListEqual(
- list(stream.catch()),
+ to_list(stream.catch(), itype=itype),
list(map(f, safe_src)),
msg="If the predicate is not specified, then all exceptions should be caught.",
)
@@ -1445,7 +2006,7 @@ def f(i):
ZeroDivisionError,
msg="If a non caught exception type occurs, then it should be raised.",
):
- list(stream.catch(TestError))
+ to_list(stream.catch(TestError), itype=itype)
first_value = 1
second_value = 2
@@ -1465,9 +2026,11 @@ def f(i):
erroring_stream.catch(finally_raise=True),
erroring_stream.catch(finally_raise=True),
]:
- erroring_stream_iterator = iter(caught_erroring_stream)
+ erroring_stream_iterator = bi_iterable_to_iter(
+ caught_erroring_stream, itype=itype
+ )
self.assertEqual(
- next(erroring_stream_iterator),
+ anext_or_next(erroring_stream_iterator),
first_value,
msg="`catch` should yield the first non exception throwing element.",
)
@@ -1476,13 +2039,14 @@ def f(i):
TestError,
msg="`catch` should raise the first error encountered when `finally_raise` is True.",
):
- for _ in erroring_stream_iterator:
+ while True:
+ anext_or_next(erroring_stream_iterator)
n_yields += 1
with self.assertRaises(
- StopIteration,
+ stopiteration_for_iter_type(type(erroring_stream_iterator)),
msg="`catch` with `finally_raise`=True should finally raise StopIteration to avoid infinite recursion if there is another catch downstream.",
):
- next(erroring_stream_iterator)
+ anext_or_next(erroring_stream_iterator)
self.assertEqual(
n_yields,
3,
@@ -1493,61 +2057,65 @@ def f(i):
map(lambda _: throw(TestError), range(2000))
).catch(TestError)
self.assertListEqual(
- list(only_caught_errors_stream),
+ to_list(only_caught_errors_stream, itype=itype),
[],
msg="When upstream raise exceptions without yielding any element, listing the stream must return empty list, without recursion issue.",
)
with self.assertRaises(
- StopIteration,
+ stopiteration_for_iter_type(itype),
msg="When upstream raise exceptions without yielding any element, then the first call to `next` on a stream catching all errors should raise StopIteration.",
):
- next(iter(only_caught_errors_stream))
+ anext_or_next(bi_iterable_to_iter(only_caught_errors_stream, itype=itype))
- iterator = iter(
+ iterator = bi_iterable_to_iter(
Stream(map(throw, [TestError, ValueError]))
.catch(ValueError, finally_raise=True)
- .catch(TestError, finally_raise=True)
+ .catch(TestError, finally_raise=True),
+ itype=itype,
)
with self.assertRaises(
ValueError,
msg="With 2 chained `catch`s with `finally_raise=True`, the error caught by the first `catch` is finally raised first (even though it was raised second)...",
):
- next(iterator)
+ anext_or_next(iterator)
with self.assertRaises(
TestError,
msg="... and then the error caught by the second `catch` is raised...",
):
- next(iterator)
+ anext_or_next(iterator)
with self.assertRaises(
- StopIteration,
+ stopiteration_for_iter_type(type(iterator)),
msg="... and a StopIteration is raised next.",
):
- next(iterator)
+ anext_or_next(iterator)
with self.assertRaises(
TypeError,
msg="`catch` does not catch if `when` not satisfied",
):
- list(
+ to_list(
Stream(map(throw, [ValueError, TypeError])).catch(
when=lambda exception: "ValueError" in repr(exception)
- )
+ ),
+ itype=itype,
)
self.assertListEqual(
- list(
+ to_list(
Stream(map(lambda n: 1 / n, [0, 1, 2, 4])).catch(
ZeroDivisionError, replacement=float("inf")
- )
+ ),
+ itype=itype,
),
[float("inf"), 1, 0.5, 0.25],
msg="`catch` should be able to yield a non-None replacement",
)
self.assertListEqual(
- list(
+ to_list(
Stream(map(lambda n: 1 / n, [0, 1, 2, 4])).catch(
ZeroDivisionError, replacement=cast(float, None)
- )
+ ),
+ itype=itype,
),
[None, 1, 0.5, 0.25],
msg="`catch` should be able to yield a None replacement",
@@ -1556,7 +2124,7 @@ def f(i):
errors_counter: Counter[Type[Exception]] = Counter()
self.assertListEqual(
- list(
+ to_list(
Stream(
map(
lambda n: 1 / n, # potential ZeroDivisionError
@@ -1571,7 +2139,8 @@ def f(i):
).catch(
(ValueError, TestError, ZeroDivisionError),
when=lambda err: errors_counter.update([type(err)]) is None,
- )
+ ),
+ itype=itype,
),
list(map(lambda n: 1 / n, range(2, 10, 2))),
msg="`catch` should accept multiple types",
@@ -1589,15 +2158,16 @@ def f(i):
# Stream(src).catch() # type: ignore
self.assertEqual(
- list(Stream(map(int, "foo")).catch(replacement=0)),
+ to_list(Stream(map(int, "foo")).catch(replacement=0), itype=itype),
[0] * len("foo"),
msg="`catch` must catch all errors when no error type provided",
)
self.assertEqual(
- list(
+ to_list(
Stream(map(int, "foo")).catch(
(None, None, ValueError, None), replacement=0
- )
+ ),
+ itype=itype,
),
[0] * len("foo"),
msg="`catch` must catch the provided non-None error types",
@@ -1606,14 +2176,239 @@ def f(i):
ValueError,
msg="`catch` must be noop if error type is None",
):
- list(Stream(map(int, "foo")).catch(None))
+ to_list(Stream(map(int, "foo")).catch(None), itype=itype)
with self.assertRaises(
ValueError,
msg="`catch` must be noop if error types are None",
):
- list(Stream(map(int, "foo")).catch((None, None, None)))
+ to_list(Stream(map(int, "foo")).catch((None, None, None)), itype=itype)
+
+ @parameterized.expand(ITERABLE_TYPES)
+ def test_acatch(self, itype: IterableType) -> None:
+ self.assertListEqual(
+ to_list(Stream(src).acatch(finally_raise=True), itype=itype),
+ list(src),
+ msg="`acatch` should yield elements in exception-less scenarios",
+ )
+
+ with self.assertRaisesRegex(
+ TypeError,
+ "`iterator` must be an AsyncIterator but got a ",
+ msg="`afunctions.acatch` function should raise TypeError when first argument is not an AsyncIterator",
+ ):
+ from streamable import afunctions
+
+ afunctions.acatch(cast(AsyncIterator, [3, 4]), Exception)
+
+ with self.assertRaisesRegex(
+ TypeError,
+ "`errors` must be None, or a subclass of `Exception`, or an iterable of optional subclasses of `Exception`, but got ",
+ msg="`acatch` should raise TypeError when first argument is not None or Type[Exception], or Iterable[Optional[Type[Exception]]]",
+ ):
+ Stream(src).acatch(1) # type: ignore
+
+ def f(i):
+ return i / (3 - i)
+
+ stream = Stream(lambda: map(f, src))
+ safe_src = list(src)
+ del safe_src[3]
+ self.assertListEqual(
+ to_list(stream.acatch(ZeroDivisionError), itype=itype),
+ list(map(f, safe_src)),
+ msg="If the exception type matches the `error_type`, then the impacted element should be ignored.",
+ )
+ self.assertListEqual(
+ to_list(stream.acatch(), itype=itype),
+ list(map(f, safe_src)),
+ msg="If the predicate is not specified, then all exceptions should be caught.",
+ )
+
+ with self.assertRaises(
+ ZeroDivisionError,
+ msg="If a non caught exception type occurs, then it should be raised.",
+ ):
+ to_list(stream.acatch(TestError), itype=itype)
+
+ first_value = 1
+ second_value = 2
+ third_value = 3
+ functions = [
+ lambda: throw(TestError),
+ lambda: throw(TypeError),
+ lambda: first_value,
+ lambda: second_value,
+ lambda: throw(ValueError),
+ lambda: third_value,
+ lambda: throw(ZeroDivisionError),
+ ]
+
+ erroring_stream: Stream[int] = Stream(lambda: map(lambda f: f(), functions))
+ for caught_erroring_stream in [
+ erroring_stream.acatch(finally_raise=True),
+ erroring_stream.acatch(finally_raise=True),
+ ]:
+ erroring_stream_iterator = bi_iterable_to_iter(
+ caught_erroring_stream, itype=itype
+ )
+ self.assertEqual(
+ anext_or_next(erroring_stream_iterator),
+ first_value,
+ msg="`acatch` should yield the first non exception throwing element.",
+ )
+ n_yields = 1
+ with self.assertRaises(
+ TestError,
+ msg="`acatch` should raise the first error encountered when `finally_raise` is True.",
+ ):
+ while True:
+ anext_or_next(erroring_stream_iterator)
+ n_yields += 1
+ with self.assertRaises(
+ stopiteration_for_iter_type(type(erroring_stream_iterator)),
+ msg="`acatch` with `finally_raise`=True should finally raise StopIteration to avoid infinite recursion if there is another catch downstream.",
+ ):
+ anext_or_next(erroring_stream_iterator)
+ self.assertEqual(
+ n_yields,
+ 3,
+ msg="3 elements should have passed been yielded between caught exceptions.",
+ )
+
+ only_caught_errors_stream = Stream(
+ map(lambda _: throw(TestError), range(2000))
+ ).acatch(TestError)
+ self.assertListEqual(
+ to_list(only_caught_errors_stream, itype=itype),
+ [],
+ msg="When upstream raise exceptions without yielding any element, listing the stream must return empty list, without recursion issue.",
+ )
+ with self.assertRaises(
+ stopiteration_for_iter_type(itype),
+ msg="When upstream raise exceptions without yielding any element, then the first call to `next` on a stream catching all errors should raise StopIteration.",
+ ):
+ anext_or_next(bi_iterable_to_iter(only_caught_errors_stream, itype=itype))
- def test_observe(self) -> None:
+ iterator = bi_iterable_to_iter(
+ Stream(map(throw, [TestError, ValueError]))
+ .acatch(ValueError, finally_raise=True)
+ .acatch(TestError, finally_raise=True),
+ itype=itype,
+ )
+ with self.assertRaises(
+ ValueError,
+ msg="With 2 chained `acatch`s with `finally_raise=True`, the error caught by the first `acatch` is finally raised first (even though it was raised second)...",
+ ):
+ anext_or_next(iterator)
+ with self.assertRaises(
+ TestError,
+ msg="... and then the error caught by the second `acatch` is raised...",
+ ):
+ anext_or_next(iterator)
+ with self.assertRaises(
+ stopiteration_for_iter_type(type(iterator)),
+ msg="... and a StopIteration is raised next.",
+ ):
+ anext_or_next(iterator)
+
+ with self.assertRaises(
+ TypeError,
+ msg="`acatch` does not catch if `when` not satisfied",
+ ):
+ to_list(
+ Stream(map(throw, [ValueError, TypeError])).acatch(
+ when=asyncify(lambda exception: "ValueError" in repr(exception))
+ ),
+ itype=itype,
+ )
+
+ self.assertListEqual(
+ to_list(
+ Stream(map(lambda n: 1 / n, [0, 1, 2, 4])).acatch(
+ ZeroDivisionError, replacement=float("inf")
+ ),
+ itype=itype,
+ ),
+ [float("inf"), 1, 0.5, 0.25],
+ msg="`acatch` should be able to yield a non-None replacement",
+ )
+ self.assertListEqual(
+ to_list(
+ Stream(map(lambda n: 1 / n, [0, 1, 2, 4])).acatch(
+ ZeroDivisionError, replacement=cast(float, None)
+ ),
+ itype=itype,
+ ),
+ [None, 1, 0.5, 0.25],
+ msg="`acatch` should be able to yield a None replacement",
+ )
+
+ errors_counter: Counter[Type[Exception]] = Counter()
+
+ self.assertListEqual(
+ to_list(
+ Stream(
+ map(
+ lambda n: 1 / n, # potential ZeroDivisionError
+ map(
+ throw_for_odd_func(TestError), # potential TestError
+ map(
+ int, # potential ValueError
+ "01234foo56789",
+ ),
+ ),
+ )
+ ).acatch(
+ (ValueError, TestError, ZeroDivisionError),
+ when=asyncify(
+ lambda err: errors_counter.update([type(err)]) is None
+ ),
+ ),
+ itype=itype,
+ ),
+ list(map(lambda n: 1 / n, range(2, 10, 2))),
+ msg="`acatch` should accept multiple types",
+ )
+ self.assertDictEqual(
+ errors_counter,
+ {TestError: 5, ValueError: 3, ZeroDivisionError: 1},
+ msg="`acatch` should accept multiple types",
+ )
+
+ # with self.assertRaises(
+ # TypeError,
+ # msg="`acatch` without any error type must raise",
+ # ):
+ # Stream(src).acatch() # type: ignore
+
+ self.assertEqual(
+ to_list(Stream(map(int, "foo")).acatch(replacement=0), itype=itype),
+ [0] * len("foo"),
+ msg="`acatch` must catch all errors when no error type provided",
+ )
+ self.assertEqual(
+ to_list(
+ Stream(map(int, "foo")).acatch(
+ (None, None, ValueError, None), replacement=0
+ ),
+ itype=itype,
+ ),
+ [0] * len("foo"),
+ msg="`acatch` must catch the provided non-None error types",
+ )
+ with self.assertRaises(
+ ValueError,
+ msg="`acatch` must be noop if error type is None",
+ ):
+ to_list(Stream(map(int, "foo")).acatch(None), itype=itype)
+ with self.assertRaises(
+ ValueError,
+ msg="`acatch` must be noop if error types are None",
+ ):
+ to_list(Stream(map(int, "foo")).acatch((None, None, None)), itype=itype)
+
+ @parameterized.expand(ITERABLE_TYPES)
+ def test_observe(self, itype: IterableType) -> None:
value_error_raising_stream: Stream[List[int]] = (
Stream("123--678")
.throttle(10, per=datetime.timedelta(seconds=1))
@@ -1625,7 +2420,7 @@ def test_observe(self) -> None:
)
self.assertListEqual(
- list(value_error_raising_stream.catch(ValueError)),
+ to_list(value_error_raising_stream.catch(ValueError), itype=itype),
[[1, 2], [3], [6, 7], [8]],
msg="This can break due to `group`/`map`/`catch`, check other breaking tests to determine quickly if it's an issue with `observe`.",
)
@@ -1634,10 +2429,11 @@ def test_observe(self) -> None:
ValueError,
msg="`observe` should forward-raise exceptions",
):
- list(value_error_raising_stream)
+ to_list(value_error_raising_stream, itype=itype)
def test_is_iterable(self) -> None:
self.assertIsInstance(Stream(src), Iterable)
+ self.assertIsInstance(Stream(src), AsyncIterable)
def test_count(self) -> None:
l: List[int] = []
@@ -1656,6 +2452,23 @@ def effect(x: int) -> None:
l, list(src), msg="`count` should iterate over the entire stream."
)
+ def test_acount(self) -> None:
+ l: List[int] = []
+
+ def effect(x: int) -> None:
+ nonlocal l
+ l.append(x)
+
+ stream = Stream(lambda: map(effect, src))
+ self.assertEqual(
+ asyncio.run(stream.acount()),
+ N,
+ msg="`count` should return the count of elements.",
+ )
+ self.assertListEqual(
+ l, list(src), msg="`count` should iterate over the entire stream."
+ )
+
def test_call(self) -> None:
l: List[int] = []
stream = Stream(src).map(l.append)
@@ -1670,51 +2483,62 @@ def test_call(self) -> None:
msg="`__call__` should exhaust the stream.",
)
- def test_multiple_iterations(self) -> None:
+ def test_await(self) -> None:
+ l: List[int] = []
+ stream = Stream(src).map(l.append)
+ self.assertIs(
+ asyncio.run(awaitable_to_coroutine(stream)),
+ stream,
+ msg="`__call__` should return the stream.",
+ )
+ self.assertListEqual(
+ l,
+ list(src),
+ msg="`__call__` should exhaust the stream.",
+ )
+
+ @parameterized.expand(ITERABLE_TYPES)
+ def test_multiple_iterations(self, itype: IterableType) -> None:
stream = Stream(src)
for _ in range(3):
self.assertListEqual(
- list(stream),
+ to_list(stream, itype=itype),
list(src),
msg="The first iteration over a stream should yield the same elements as any subsequent iteration on the same stream, even if it is based on a `source` returning an iterator that only support 1 iteration.",
)
@parameterized.expand(
- [
- [1],
- [100],
- ]
+ [(concurrency, itype) for concurrency in (1, 100) for itype in ITERABLE_TYPES]
)
- def test_amap(self, concurrency) -> None:
+ def test_amap(self, concurrency, itype) -> None:
self.assertListEqual(
- list(
+ to_list(
Stream(src).amap(
async_randomly_slowed(async_square), concurrency=concurrency
- )
+ ),
+ itype=itype,
),
list(map(square, src)),
msg="At any concurrency the `amap` method should act as the builtin map function, transforming elements while preserving input elements order.",
)
- stream = Stream(src).amap(identity) # type: ignore
+ stream = Stream(src).amap(identity, concurrency=concurrency) # type: ignore
with self.assertRaisesRegex(
TypeError,
- r"`transformation` must be an async function i\.e\. a function returning a Coroutine but it returned a ",
+ r"must be an async function i\.e\. a function returning a Coroutine but it returned a ",
msg="`amap` should raise a TypeError if a non async function is passed to it.",
):
- next(iter(stream))
+ anext_or_next(bi_iterable_to_iter(stream, itype=itype))
@parameterized.expand(
- [
- [1],
- [100],
- ]
+ [(concurrency, itype) for concurrency in (1, 100) for itype in ITERABLE_TYPES]
)
- def test_aforeach(self, concurrency) -> None:
+ def test_aforeach(self, concurrency, itype) -> None:
self.assertListEqual(
- list(
+ to_list(
Stream(src).aforeach(
async_randomly_slowed(async_square), concurrency=concurrency
- )
+ ),
+ itype=itype,
),
list(src),
msg="At any concurrency the `foreach` method must preserve input elements order.",
@@ -1725,9 +2549,10 @@ def test_aforeach(self, concurrency) -> None:
r"`transformation` must be an async function i\.e\. a function returning a Coroutine but it returned a ",
msg="`aforeach` should raise a TypeError if a non async function is passed to it.",
):
- next(iter(stream))
+ anext_or_next(bi_iterable_to_iter(stream, itype=itype))
- def test_pipe(self) -> None:
+ @parameterized.expand(ITERABLE_TYPES)
+ def test_pipe(self, itype: IterableType) -> None:
def func(
stream: Stream, *ints: int, **strings: str
) -> Tuple[Stream, Tuple[int, ...], Dict[str, str]]:
@@ -1744,8 +2569,8 @@ def func(
)
self.assertListEqual(
- stream.pipe(list),
- list(stream),
+ stream.pipe(to_list, itype=itype),
+ to_list(stream, itype=itype),
msg="`pipe` should be ok without args and kwargs.",
)
@@ -1753,18 +2578,27 @@ def test_eq(self) -> None:
stream = (
Stream(src)
.catch((TypeError, ValueError), replacement=2, when=identity)
+ .acatch((TypeError, ValueError), replacement=2, when=async_identity)
.distinct(key=identity)
+ .adistinct(key=async_identity)
.filter(identity)
+ .afilter(async_identity)
.foreach(identity, concurrency=3)
.aforeach(async_identity, concurrency=3)
.group(3, by=bool)
.flatten(concurrency=3)
+ .agroup(3, by=async_identity)
+ .map(sync_to_async_iter)
+ .aflatten(concurrency=3)
.groupby(bool)
+ .agroupby(async_identity)
.map(identity, via="process")
.amap(async_identity)
.observe("foo")
.skip(3)
+ .askip(3)
.truncate(4)
+ .atruncate(4)
.throttle(1, per=datetime.timedelta(seconds=1))
)
@@ -1772,64 +2606,100 @@ def test_eq(self) -> None:
stream,
Stream(src)
.catch((TypeError, ValueError), replacement=2, when=identity)
+ .acatch((TypeError, ValueError), replacement=2, when=async_identity)
.distinct(key=identity)
+ .adistinct(key=async_identity)
.filter(identity)
+ .afilter(async_identity)
.foreach(identity, concurrency=3)
.aforeach(async_identity, concurrency=3)
.group(3, by=bool)
.flatten(concurrency=3)
+ .agroup(3, by=async_identity)
+ .map(sync_to_async_iter)
+ .aflatten(concurrency=3)
.groupby(bool)
+ .agroupby(async_identity)
.map(identity, via="process")
.amap(async_identity)
.observe("foo")
.skip(3)
+ .askip(3)
.truncate(4)
+ .atruncate(4)
.throttle(1, per=datetime.timedelta(seconds=1)),
)
self.assertNotEqual(
stream,
Stream(list(src)) # not same source
.catch((TypeError, ValueError), replacement=2, when=identity)
+ .acatch((TypeError, ValueError), replacement=2, when=async_identity)
.distinct(key=identity)
+ .adistinct(key=async_identity)
.filter(identity)
+ .afilter(async_identity)
.foreach(identity, concurrency=3)
.aforeach(async_identity, concurrency=3)
.group(3, by=bool)
.flatten(concurrency=3)
+ .agroup(3, by=async_identity)
+ .map(sync_to_async_iter)
+ .aflatten(concurrency=3)
.groupby(bool)
+ .agroupby(async_identity)
.map(identity, via="process")
.amap(async_identity)
.observe("foo")
.skip(3)
+ .askip(3)
.truncate(4)
+ .atruncate(4)
.throttle(1, per=datetime.timedelta(seconds=1)),
)
self.assertNotEqual(
stream,
Stream(src)
.catch((TypeError, ValueError), replacement=2, when=identity)
+ .acatch((TypeError, ValueError), replacement=2, when=async_identity)
.distinct(key=identity)
+ .adistinct(key=async_identity)
.filter(identity)
+ .afilter(async_identity)
.foreach(identity, concurrency=3)
.aforeach(async_identity, concurrency=3)
.group(3, by=bool)
.flatten(concurrency=3)
+ .agroup(3, by=async_identity)
+ .map(sync_to_async_iter)
+ .aflatten(concurrency=3)
.groupby(bool)
+ .agroupby(async_identity)
.map(identity, via="process")
.amap(async_identity)
.observe("foo")
.skip(3)
+ .askip(3)
.truncate(4)
+ .atruncate(4)
.throttle(1, per=datetime.timedelta(seconds=2)), # not the same interval
)
- def test_ref_cycles(self) -> None:
+ @parameterized.expand(ITERABLE_TYPES)
+ def test_ref_cycles(self, itype: IterableType) -> None:
+ async def async_int(o: Any) -> int:
+ return int(o)
+
stream = (
- Stream(map(int, "123_5")).group(1).groupby(len).catch(finally_raise=True)
+ Stream("123_5")
+ .amap(async_int)
+ .map(str)
+ .group(1)
+ .groupby(len)
+ .catch(finally_raise=True)
)
exception: Exception
try:
- list(stream)
+ to_list(stream, itype=itype)
except ValueError as e:
exception = e
self.assertIsInstance(
@@ -1837,19 +2707,21 @@ def test_ref_cycles(self) -> None:
ValueError,
msg="`finally_raise` must be respected",
)
- frames: Iterator[FrameType] = map(
- itemgetter(0), traceback.walk_tb(exception.__traceback__)
- )
- next(frames)
self.assertFalse(
[
(var, val)
- for frame in frames
+ # go through the frames of the exception's traceback
+ for frame, _ in traceback.walk_tb(exception.__traceback__)
+ # skipping the current frame
+ if frame is not cast(TracebackType, exception.__traceback__).tb_frame
+ # go through the locals captured in that frame
for var, val in frame.f_locals.items()
+ # check if one of them is an exception
if isinstance(val, Exception)
- and frame in map(itemgetter(0), traceback.walk_tb(val.__traceback__))
+ # check if it is captured in its own traceback
+ and frame is cast(TracebackType, val.__traceback__).tb_frame
],
- msg=f"the exception's traceback should not contain an exception that captures itself in its own traceback",
+ msg=f"the exception's traceback should not contain an exception captured in its own traceback",
)
def test_on_queue_in_thread(self) -> None:
diff --git a/tests/test_visitor.py b/tests/test_visitor.py
index d13301df..ac5cce7c 100644
--- a/tests/test_visitor.py
+++ b/tests/test_visitor.py
@@ -2,8 +2,16 @@
from typing import cast
from streamable.stream import (
+ ACatchStream,
+ ADistinctStream,
+ AFilterStream,
+ AFlattenStream,
AForeachStream,
+ AGroupbyStream,
+ AGroupStream,
AMapStream,
+ ASkipStream,
+ ATruncateStream,
CatchStream,
DistinctStream,
FilterStream,
@@ -29,19 +37,27 @@ def visit_stream(self, stream: Stream) -> None:
visitor = ConcreteVisitor()
visitor.visit_catch_stream(cast(CatchStream, ...))
+ visitor.visit_acatch_stream(cast(ACatchStream, ...))
visitor.visit_distinct_stream(cast(DistinctStream, ...))
+ visitor.visit_adistinct_stream(cast(ADistinctStream, ...))
visitor.visit_filter_stream(cast(FilterStream, ...))
+ visitor.visit_afilter_stream(cast(AFilterStream, ...))
visitor.visit_flatten_stream(cast(FlattenStream, ...))
+ visitor.visit_aflatten_stream(cast(AFlattenStream, ...))
visitor.visit_foreach_stream(cast(ForeachStream, ...))
visitor.visit_aforeach_stream(cast(AForeachStream, ...))
visitor.visit_group_stream(cast(GroupStream, ...))
+ visitor.visit_agroup_stream(cast(AGroupStream, ...))
visitor.visit_groupby_stream(cast(GroupbyStream, ...))
+ visitor.visit_agroupby_stream(cast(AGroupbyStream, ...))
visitor.visit_map_stream(cast(MapStream, ...))
visitor.visit_amap_stream(cast(AMapStream, ...))
visitor.visit_observe_stream(cast(ObserveStream, ...))
visitor.visit_skip_stream(cast(SkipStream, ...))
+ visitor.visit_askip_stream(cast(ASkipStream, ...))
visitor.visit_throttle_stream(cast(ThrottleStream, ...))
visitor.visit_truncate_stream(cast(TruncateStream, ...))
+ visitor.visit_atruncate_stream(cast(ATruncateStream, ...))
visitor.visit_stream(cast(Stream, ...))
def test_depth_visitor_example(self):
diff --git a/tests/utils.py b/tests/utils.py
new file mode 100644
index 00000000..94d6efa3
--- /dev/null
+++ b/tests/utils.py
@@ -0,0 +1,215 @@
+import asyncio
+import random
+import time
+import timeit
+from typing import (
+ Any,
+ AsyncIterable,
+ AsyncIterator,
+ Callable,
+ Coroutine,
+ Iterable,
+ Iterator,
+ List,
+ Set,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+)
+
+from streamable.stream import Stream
+from streamable.util.iterabletools import BiIterable
+
+T = TypeVar("T")
+R = TypeVar("R")
+
+IterableType = Union[Type[Iterable], Type[AsyncIterable]]
+ITERABLE_TYPES: Tuple[IterableType, ...] = (Iterable, AsyncIterable)
+
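+# A single event loop shared by all async helpers below (`aiterable_to_list`,
+# `anext_or_next`, ...) so that every awaited step of a given AsyncIterator
+# runs on the same loop across successive calls.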
+TEST_EVENT_LOOP = asyncio.new_event_loop()
+
+
+async def _aiter_to_list(aiterable: AsyncIterable[T]) -> List[T]:
+ return [elem async for elem in aiterable]
+
+
+def aiterable_to_list(aiterable: AsyncIterable[T]) -> List[T]:
+ return TEST_EVENT_LOOP.run_until_complete(_aiter_to_list(aiterable))
+
+
+async def _aiter_to_set(aiterable: AsyncIterable[T]) -> Set[T]:
+ return {elem async for elem in aiterable}
+
+
+def aiterable_to_set(aiterable: AsyncIterable[T]) -> Set[T]:
+ return TEST_EVENT_LOOP.run_until_complete(_aiter_to_set(aiterable))
+
+
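+# Maps an iterable flavor to its exhaustion signal: the sync flavor exhausts
+# with StopIteration, the async one with StopAsyncIteration, e.g. (sketch):
+#   stopiteration_for_iter_type(Iterable) is StopIteration
+#   stopiteration_for_iter_type(AsyncIterable) is StopAsyncIteration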
+def stopiteration_for_iter_type(itype: IterableType) -> Type[Exception]:
+ if issubclass(itype, AsyncIterable):
+ return StopAsyncIteration
+ return StopIteration
+
+
+def to_list(stream: Stream[T], itype: IterableType) -> List[T]:
+ assert isinstance(stream, Stream)
+ if itype is AsyncIterable:
+ return aiterable_to_list(stream)
+ else:
+ return list(stream)
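+
+# Illustrative sketch: both flavors must materialize the same elements,
+#   to_list(Stream(range(3)), itype=Iterable)      -> [0, 1, 2]  # via list()
+#   to_list(Stream(range(3)), itype=AsyncIterable) -> [0, 1, 2]  # via __aiter__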
+
+
+def to_set(stream: Stream[T], itype: IterableType) -> Set[T]:
+ assert isinstance(stream, Stream)
+ if itype is AsyncIterable:
+ return aiterable_to_set(stream)
+ else:
+ return set(stream)
+
+
+def bi_iterable_to_iter(
+ iterable: Union[BiIterable[T], Stream[T]], itype: IterableType
+) -> Union[Iterator[T], AsyncIterator[T]]:
+ if itype is AsyncIterable:
+ return iterable.__aiter__()
+ else:
+ return iter(iterable)
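+
+# Returns an Iterator for the sync flavor and an AsyncIterator for the async
+# one; pair it with `anext_or_next` below to step through either uniformly.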
+
+
+def anext_or_next(it: Union[Iterator[T], AsyncIterator[T]]) -> T:
+ if isinstance(it, AsyncIterator):
+ return TEST_EVENT_LOOP.run_until_complete(it.__anext__())
+ else:
+ return next(it)
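+
+# Usage sketch: anext_or_next(bi_iterable_to_iter(stream, itype=AsyncIterable))
+# drives `__anext__` on TEST_EVENT_LOOP, while the sync flavor is a plain
+# `next(it)`; exhaustion raises StopAsyncIteration or StopIteration
+# respectively (see `stopiteration_for_iter_type`).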
+
+
+def alist_or_list(iterable: Union[Iterable[T], AsyncIterable[T]]) -> List[T]:
+ if isinstance(iterable, AsyncIterable):
+ return aiterable_to_list(iterable)
+ else:
+ return list(iterable)
+
+
+def timestream(
+ stream: Stream[T], times: int = 1, itype: IterableType = Iterable
+) -> Tuple[float, List[T]]:
+ res: List[T] = []
+
+ def iterate():
+ nonlocal res
+ res = to_list(stream, itype=itype)
+
+ return timeit.timeit(iterate, number=times) / times, res
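+
+# Returns the mean wall-clock duration over `times` runs, together with the
+# elements materialized by the last run.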
+
+
+def identity_sleep(seconds: float) -> float:
+ time.sleep(seconds)
+ return seconds
+
+
+async def async_identity_sleep(seconds: float) -> float:
+ await asyncio.sleep(seconds)
+ return seconds
+
+
+# simulates an I/O-bound function
+slow_identity_duration = 0.05
+
+
+def slow_identity(x: T) -> T:
+ time.sleep(slow_identity_duration)
+ return x
+
+
+async def async_slow_identity(x: T) -> T:
+ await asyncio.sleep(slow_identity_duration)
+ return x
+
+
+def identity(x: T) -> T:
+ return x
+
+
+# fmt: off
+async def async_identity(x: T) -> T: return x
+# fmt: on
+
+
+def square(x):
+ return x**2
+
+
+async def async_square(x):
+ return x**2
+
+
+def throw(exc: Type[Exception]):
+ raise exc()
+
+
+def throw_func(exc: Type[Exception]) -> Callable[[T], T]:
+ return lambda _: throw(exc)
+
+
+def async_throw_func(exc: Type[Exception]) -> Callable[[T], Coroutine[Any, Any, T]]:
+ async def f(_: T) -> T:
+ raise exc
+
+ return f
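+
+# `throw` raises immediately; `throw_func`/`async_throw_func` wrap it into
+# one-argument callables usable in `map`/`amap` pipelines.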
+
+
+def throw_for_odd_func(exc):
+ return lambda i: throw(exc) if i % 2 == 1 else i
+
+
+def async_throw_for_odd_func(exc):
+ async def f(i):
+ return throw(exc) if i % 2 == 1 else i
+
+ return f
+
+
+class TestError(Exception):
+ pass
+
+
+DELTA_RATE = 0.4
+# size of the test collections
+N = 256
+
+src = range(N)
+
+even_src = range(0, N, 2)
+
+
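+# Random per-element delays: presumably here to make out-of-order completion
+# likely in the concurrency tests, so order-preservation assertions on
+# `map`/`amap` are meaningful while total runtime stays bounded.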
+def randomly_slowed(
+ func: Callable[[T], R], min_sleep: float = 0.001, max_sleep: float = 0.05
+) -> Callable[[T], R]:
+ def wrap(x: T) -> R:
+ time.sleep(min_sleep + random.random() * (max_sleep - min_sleep))
+ return func(x)
+
+ return wrap
+
+
+def async_randomly_slowed(
+ async_func: Callable[[T], Coroutine[Any, Any, R]],
+ min_sleep: float = 0.001,
+ max_sleep: float = 0.05,
+) -> Callable[[T], Coroutine[Any, Any, R]]:
+ async def wrap(x: T) -> R:
+ await asyncio.sleep(min_sleep + random.random() * (max_sleep - min_sleep))
+ return await async_func(x)
+
+ return wrap
+
+
+def range_raising_at_exhaustion(
+ start: int, end: int, step: int, exception: Exception
+) -> Iterator[int]:
+ yield from range(start, end, step)
+ raise exception
+
+
+src_raising_at_exhaustion = lambda: range_raising_at_exhaustion(0, N, 1, TestError())
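+
+# A factory (hence the lambda): each call builds a fresh iterator yielding
+# 0..N-1 and then raising TestError at exhaustion, so each test re-consumes
+# a clean raising source.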
diff --git a/version.py b/version.py
index 7117f661..58329072 100644
--- a/version.py
+++ b/version.py
@@ -1,2 +1,2 @@
-# print CHANGELOG: git log --oneline -- version.py | grep -v '\-rc'
-__version__ = "1.5.1"
+# print CHANGELOG: git log --oneline -- version.py
+__version__ = "1.6.0a2"