Skip to content

Commit b8bd3a7

Browse files
authored
Merge pull request #31 from yassun7010/add_connection_classmethod
wip: add connection classmethods.
2 parents 95e6306 + a8204a4 commit b8bd3a7

File tree

24 files changed

+511
-399
lines changed

24 files changed

+511
-399
lines changed

turu-bigquery/src/turu/bigquery/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import importlib.metadata
22

3-
from .connection import Connection, connect
3+
from .connection import Connection
44
from .cursor import Cursor
55
from .mock_connection import MockConnection
66
from .mock_cursor import MockCursor
@@ -14,3 +14,5 @@
1414
"MockConnection",
1515
"MockCursor",
1616
]
17+
18+
connect = Connection.connect

turu-bigquery/src/turu/bigquery/connection.py

Lines changed: 32 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Optional
1+
from typing import Any, Optional
22

33
import google.api_core.client_info
44
import google.api_core.client_options
@@ -9,15 +9,45 @@
99
import turu.bigquery.cursor
1010
import turu.core.connection
1111
import turu.core.mock
12-
from typing_extensions import Never, deprecated, override
12+
from typing_extensions import Never, Self, deprecated, override
1313

1414
from .cursor import Cursor
1515

16+
try:
17+
from google.cloud.bigquery_storage import BigQueryReadClient # type: ignore
18+
19+
except ImportError:
20+
21+
class BigQueryReadClient:
22+
pass
23+
1624

1725
class Connection(turu.core.connection.Connection):
1826
def __init__(self, connection: google.cloud.bigquery.dbapi.Connection):
1927
self._raw_connection = connection
2028

29+
@override
30+
@classmethod
31+
def connect( # type: ignore[override]
32+
cls,
33+
client: Optional[google.cloud.bigquery.Client] = None,
34+
bqstorage_client: Optional[BigQueryReadClient] = None,
35+
) -> Self:
36+
import google.cloud.bigquery
37+
import google.cloud.bigquery.dbapi
38+
39+
return cls(
40+
google.cloud.bigquery.dbapi.connect(
41+
client=client,
42+
bqstorage_client=bqstorage_client,
43+
),
44+
)
45+
46+
@classmethod
47+
@override
48+
def connect_from_env(cls, *args: Any, **kwargs: Any) -> Self:
49+
return cls.connect(*args, **kwargs)
50+
2151
@override
2252
def close(self) -> None:
2353
"""Close the connection and any cursors created from it."""
@@ -39,28 +69,3 @@ def cursor(self) -> Cursor[Never]:
3969
"""Return a new cursor object."""
4070

4171
return Cursor(self._raw_connection.cursor())
42-
43-
44-
try:
45-
from google.cloud.bigquery_storage import BigQueryReadClient # type: ignore
46-
47-
48-
except ImportError:
49-
50-
class BigQueryReadClient:
51-
pass
52-
53-
54-
def connect(
55-
client: Optional[google.cloud.bigquery.Client] = None,
56-
bqstorage_client: Optional[BigQueryReadClient] = None,
57-
) -> Connection:
58-
import google.cloud.bigquery
59-
import google.cloud.bigquery.dbapi
60-
61-
return Connection(
62-
google.cloud.bigquery.dbapi.connect(
63-
client=client,
64-
bqstorage_client=bqstorage_client,
65-
),
66-
)

turu-bigquery/src/turu/bigquery/mock_connection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from .mock_cursor import MockCursor
88

99

10-
class MockConnection(Connection, turu.core.mock.MockConnection):
10+
class MockConnection(turu.core.mock.MockConnection, Connection):
1111
def __init__(self, *args, **kwargs):
1212
turu.core.mock.MockConnection.__init__(self)
1313

turu-core/src/turu/core/async_connection.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,22 @@
44
import turu.core.async_cursor
55
from turu.core.protocols.async_connection import AsyncConnectionProtocol
66
from turu.core.protocols.async_cursor import Parameters
7-
from typing_extensions import Never
7+
from typing_extensions import Never, Self
88

99

1010
class AsyncConnection(AsyncConnectionProtocol):
11+
@classmethod
12+
@abstractmethod
13+
async def connect(cls, *args: Any, **kwargs: Any) -> Self:
14+
"""Connect to a database."""
15+
...
16+
17+
@classmethod
18+
@abstractmethod
19+
async def connect_from_env(cls, *args: Any, **kwargs: Any) -> Self:
20+
"""Connect to a database using environment variables."""
21+
...
22+
1123
@abstractmethod
1224
async def close(self) -> None:
1325
...

turu-core/src/turu/core/connection.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,22 @@
44
import turu.core.cursor
55
from turu.core.protocols.connection import ConnectionProtocol
66
from turu.core.protocols.cursor import Parameters
7-
from typing_extensions import Never, override
7+
from typing_extensions import Never, Self, override
88

99

1010
class Connection(ConnectionProtocol):
11+
@classmethod
12+
@abstractmethod
13+
def connect(cls, *args: Any, **kwargs: Any) -> Self:
14+
"""Connect to a database."""
15+
...
16+
17+
@classmethod
18+
@abstractmethod
19+
def connect_from_env(cls, *args: Any, **kwargs: Any) -> Self:
20+
"""Connect to a database using environment variables."""
21+
...
22+
1123
@abstractmethod
1224
@override
1325
def close(self) -> None:

turu-core/src/turu/core/mock/async_connection.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,14 @@ class MockAsyncConnection(turu.core.async_connection.AsyncConnection):
2323
def __init__(self, store: Optional[TuruMockStore] = None):
2424
self._turu_mock_store = store or TuruMockStore()
2525

26+
@classmethod
27+
async def connect(cls, *args: Any, **kwargs: Any) -> Self:
28+
return cls()
29+
30+
@classmethod
31+
async def connect_from_env(cls, *args: Any, **kwargs: Any) -> Self:
32+
return cls()
33+
2634
def chain(self) -> Self:
2735
"""this method is just for code formatting by black."""
2836

turu-core/src/turu/core/mock/connection.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,14 @@ class MockConnection(turu.core.connection.Connection):
2727
def __init__(self, store: Optional[TuruMockStore] = None):
2828
self._turu_mock_store = store or TuruMockStore()
2929

30+
@classmethod
31+
def connect(cls, *args: Any, **kwargs: Any) -> Self:
32+
return cls()
33+
34+
@classmethod
35+
def connect_from_env(cls, *args: Any, **kwargs: Any) -> Self:
36+
return cls()
37+
3038
def chain(self) -> Self:
3139
"""this method is just for code formatting by black."""
3240

turu-core/tests/turu/core/test_record.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from pydantic import BaseModel
99
from turu.core.exception import TuruRowTypeNotSupportedError
1010
from turu.core.record import RecordCursor, record_to_csv
11-
from typing_extensions import Never
11+
from typing_extensions import Never, Self
1212

1313

1414
class RowPydantic(BaseModel):
@@ -197,6 +197,14 @@ def custom_method(self, value: int) -> None:
197197
pass
198198

199199
class CustomConnection(turu.core.mock.MockConnection):
200+
@classmethod
201+
def connect(cls, *args, **kwargs) -> Self:
202+
return cls()
203+
204+
@classmethod
205+
def connect_from_env(cls, *args, **kwargs) -> Self:
206+
return cls.connect()
207+
200208
def cursor(self) -> CustomCursor:
201209
return CustomCursor(self._turu_mock_store)
202210

turu-mysql/src/turu/mysql/__init__.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,10 @@
22

33
from .async_connection import (
44
AsyncConnection,
5-
connect_async,
6-
connect_async_from_env,
75
)
86
from .async_cursor import AsyncCursor
97
from .connection import (
108
Connection,
11-
connect,
12-
connect_from_env,
139
)
1410
from .cursor import Cursor
1511
from .mock_async_connection import MockAsyncConnection
@@ -19,7 +15,6 @@
1915

2016
__version__ = importlib.metadata.version("turu-mysql")
2117

22-
2318
__all__ = [
2419
"AsyncConnection",
2520
"AsyncCursor",
@@ -34,3 +29,8 @@
3429
"MockConnection",
3530
"MockCursor",
3631
]
32+
33+
connect = Connection.connect
34+
connect_from_env = Connection.connect_from_env
35+
connect_async = AsyncConnection.connect
36+
connect_async_from_env = AsyncConnection.connect_from_env

turu-mysql/src/turu/mysql/async_connection.py

Lines changed: 43 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import turu.core.mock
1010
import turu.mysql.async_cursor
1111
import turu.mysql.cursor
12-
from typing_extensions import Never, Unpack, override
12+
from typing_extensions import Never, Self, Unpack, override
1313

1414
from .async_cursor import AsyncCursor
1515

@@ -18,6 +18,48 @@ class AsyncConnection(turu.core.async_connection.AsyncConnection):
1818
def __init__(self, connection: aiomysql.Connection):
1919
self._raw_connection = connection
2020

21+
@override
22+
@classmethod
23+
async def connect( # type: ignore[override]
24+
cls,
25+
user: Optional[str] = None,
26+
password: str = "",
27+
host: str = "localhost",
28+
database: Optional[str] = None,
29+
port: int = 0,
30+
**kwargs: Unpack["_ConnectParams"],
31+
) -> Self:
32+
return cls(
33+
await aiomysql.connection._connect(
34+
user=user,
35+
password=password,
36+
host=host,
37+
db=database,
38+
port=port,
39+
**kwargs,
40+
)
41+
)
42+
43+
@override
44+
@classmethod
45+
async def connect_from_env( # type: ignore[override]
46+
cls,
47+
user_envname: str = "MYSQL_USER",
48+
password_envname: str = "MYSQL_PASSWORD",
49+
host_envname: str = "MYSQL_HOST",
50+
database_envname: str = "MYSQL_DATABASE",
51+
port_envname: str = "MYSQL_PORT",
52+
**kwargs: Unpack["_ConnectParams"],
53+
) -> Self:
54+
return await cls.connect(
55+
user=os.environ.get(user_envname),
56+
password=os.environ.get(password_envname, ""),
57+
host=os.environ.get(host_envname, "localhost"),
58+
database=os.environ.get(database_envname),
59+
port=int(os.environ.get(port_envname, 0)),
60+
**kwargs,
61+
)
62+
2163
@override
2264
async def close(self) -> None:
2365
await self._raw_connection.ensure_closed()
@@ -60,41 +102,3 @@ class _ConnectParams(TypedDict, total=False):
60102
auth_plugin: str
61103
program_name: str
62104
server_public_key: Optional[Any]
63-
64-
65-
async def connect_async(
66-
user: Optional[str] = None,
67-
password: str = "",
68-
host: str = "localhost",
69-
database: Optional[str] = None,
70-
port: int = 0,
71-
**kwargs: Unpack[_ConnectParams],
72-
) -> AsyncConnection:
73-
return AsyncConnection(
74-
await aiomysql.connection._connect(
75-
user=user,
76-
password=password,
77-
host=host,
78-
db=database,
79-
port=port,
80-
**kwargs,
81-
)
82-
)
83-
84-
85-
async def connect_async_from_env(
86-
user_envname: str = "MYSQL_USER",
87-
password_envname: str = "MYSQL_PASSWORD",
88-
host_envname: str = "MYSQL_HOST",
89-
database_envname: str = "MYSQL_DATABASE",
90-
port_envname: str = "MYSQL_PORT",
91-
**kwargs: Unpack[_ConnectParams],
92-
) -> AsyncConnection:
93-
return await connect_async(
94-
user=os.environ.get(user_envname),
95-
password=os.environ.get(password_envname, ""),
96-
host=os.environ.get(host_envname, "localhost"),
97-
database=os.environ.get(database_envname),
98-
port=int(os.environ.get(port_envname, 0)),
99-
**kwargs,
100-
)

0 commit comments

Comments
 (0)