Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: new container risingwave #731

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ We have an [issue template](.github/ISSUE_TEMPLATE/new-container.md) for adding
Once you've talked to the maintainers (we do our best to reply!) then you can proceed with contributing the new container.

> [!WARNING]
> PLease raise an issue before you try to contribute a new container! It helps maintainers understand your use-case and motivation.
> Please raise an issue before you try to contribute a new container! It helps maintainers understand your use-case and motivation.
> This way we can keep pull requests foruced on the "how", not the "why"! :pray:
> It also gives maintainers a chance to give you last-minute guidance on caveats or expectations, particularly with
> new extra dependencies and how to manage them.
Expand Down
2 changes: 2 additions & 0 deletions modules/risingwave/README.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
.. autoclass:: testcontainers.risingwave.RisingWaveContainer
.. title:: testcontainers.risingwave.RisingWaveContainer
87 changes: 87 additions & 0 deletions modules/risingwave/testcontainers/risingwave/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from risingwave.core import RisingWave, RisingWaveConnOptions
from testcontainers.core.container import DockerContainer
from testcontainers.core.waiting_utils import (
wait_container_is_ready,
wait_for_logs,
)


class RisingWaveContainer(DockerContainer):
"""
RisingWave database container.

Example:

The example spins up a RisingWave database and connects to it using
the :code:`risingwave-py` library.

.. doctest::

>>> from testcontainers.risingwave import RisingWaveContainer

>>> with RisingWaveContainer("risingwavelabs/risingwave:v2.0.2") as rw:
... client = rw.get_client()
... version = client.fetchone("select version()")
>>> version
'PostgreSQL 13.14.0-RisingWave-2.0.2...'
"""

_DEFAULT_PORT = 4566

def __init__(
self,
image: str = "risingwavelabs/risingwave:latest",
port: int = _DEFAULT_PORT, # external port
internal_port: int = _DEFAULT_PORT,
username: str = "root",
password: str = "",
dbname: str = "dev",
**kwargs,
) -> None:
super().__init__(image=image, **kwargs)
self.username: str = username
self.password: str = password
self.dbname: str = dbname
# support concurrent testing by using different ports
self.internal_port: int = internal_port
self.port: int = port

self.with_exposed_ports(self.internal_port)
self.with_bind_ports(self.internal_port, self.port)

def _configure(self) -> None:
self.with_command("single_node")

@wait_container_is_ready()
def _connect(self) -> None:
wait_for_logs(self, predicate="RisingWave standalone mode is ready")

def get_client(self, **kwargs) -> RisingWave:
conn = RisingWaveConnOptions.from_connection_info(
host=kwargs.get("host", self.get_container_host_ip()),
user=kwargs.get("user", self.username),
password=kwargs.get("password", self.password),
port=kwargs.get("port", self.port),
database=kwargs.get("database", self.dbname),
)
return RisingWave(conn)

def start(self) -> RisingWaveContainer:
super().start()
self._connect()
return self
122 changes: 122 additions & 0 deletions modules/risingwave/tests/test_risingwave.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import threading

import pytest
from risingwave.core import OutputFormat
from sqlalchemy import Row
from testcontainers.risingwave import RisingWaveContainer


@pytest.mark.inside_docker_check
@pytest.mark.parametrize("version", ["v1.10.1", "v1.10.2", "v2.0.2", "latest"])
def test_docker_run_risingwave_versions(version: str):
with RisingWaveContainer(f"risingwavelabs/risingwave:{version}") as rw:
client = rw.get_client()
try:
result = client.fetchone("select version();")
assert isinstance(result, Row), "Result is not of type sqlalchemy.Row"
got_version = result[0]
assert got_version.startswith("PostgreSQL 13.14.0-RisingWave")
if version != "latest":
assert version in version
finally:
# close to suppress warnings from OperationalError in between runs.
client.conn.close()


@pytest.mark.inside_docker_check
def test_docker_run_risingwave_in_parallel():
PORTS = [4500, 4566]

def run_risingwave(port):
with RisingWaveContainer(image="risingwavelabs/risingwave:v2.0.2", port=port) as rw:
client = rw.get_client()
try:
assert rw.port == port
assert rw.internal_port == 4566
assert client.conn.closed is False
finally:
# close to suppress warnings from OperationalError in between runs.
client.conn.close()

for port in PORTS:
threading.Thread(target=run_risingwave, args=(port,)).start()


@pytest.mark.inside_docker_check
def test_docker_run_risingwave_create_materialized_view():
import time

import pandas as pd

# ARRANGE
SCHEMA = "testcontainer"
raw_data = [
{"id": 1, "name": "Alice", "age": 18},
{"id": 2, "name": "Bob", "age": 19},
{"id": 3, "name": "Charlie", "age": 20},
{"id": 4, "name": "David", "age": 21},
{"id": 5, "name": "Alice", "age": 22},
]
want_data = pd.DataFrame(
{
"name": ["Alice", "Bob", "Charlie", "David"],
"total_age": [40, 19, 20, 21],
"total_count": [2, 1, 1, 1],
}
)

# ACT
with RisingWaveContainer() as rw:
client = rw.get_client()

def generate_data():
client.execute(f"CREATE SCHEMA IF NOT EXISTS {SCHEMA}")
client.execute(
f"""
CREATE TABLE IF NOT EXISTS {SCHEMA}.example_table
(id INT, name TEXT, age INT)
"""
)
for rd in raw_data:
client.insert_row(
table_name="example_table",
schema_name=SCHEMA,
force_flush=True,
**rd,
)
time.sleep(1)

def create_mv():
while not client.check_exist(name="example_table", schema_name=SCHEMA):
time.sleep(1)
continue
return client.mv(
name="example_mv",
schema_name=SCHEMA,
stmt=f"""
SELECT name, SUM(age) AS total_age, COUNT(*) AS total_count
FROM {SCHEMA}.example_table
GROUP BY 1
""",
)

try:
for fn in [create_mv, generate_data]:
threading.Thread(target=fn).start()

time.sleep(len(raw_data) + 2)
got_data = client.fetch(
f"select * from {SCHEMA}.example_mv",
OutputFormat.DATAFRAME,
)

# ASSERT
assert isinstance(got_data, pd.DataFrame)
pd.testing.assert_frame_equal(
got_data.sort_values("name").reset_index(drop=True),
want_data.sort_values("name").reset_index(drop=True),
)

finally:
# close to suppress warnings from OperationalError in between runs.
client.conn.close()
Loading