Skip to content

Commit dff7000

Browse files
authored
Move Field generation responsability to Resource (#66)
1 parent 6107892 commit dff7000

File tree

15 files changed

+180
-103
lines changed

15 files changed

+180
-103
lines changed

local_data_api/models.py

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
from base64 import b64encode
43
from datetime import date, datetime, time
54
from decimal import Decimal
65
from enum import Enum
@@ -25,33 +24,6 @@ class Field(BaseModel):
2524
longValue: Optional[int]
2625
stringValue: Optional[str]
2726

28-
@classmethod
29-
def from_value(cls, value: Any) -> Field:
30-
if isinstance(value, bool):
31-
return cls(booleanValue=value)
32-
elif isinstance(value, str):
33-
return cls(stringValue=value)
34-
elif isinstance(value, datetime):
35-
return cls(stringValue=str(value))
36-
elif isinstance(value, int):
37-
return cls(longValue=value)
38-
elif isinstance(value, float):
39-
return cls(doubleValue=value)
40-
elif isinstance(value, bytes):
41-
return cls(blobValue=b64encode(value))
42-
elif value is None:
43-
return cls(isNull=True)
44-
elif type(value).__name__.endswith('UUID'):
45-
return cls(stringValue=str(value))
46-
elif type(value).__name__.endswith('PGobject'):
47-
return cls(stringValue=str(value))
48-
elif type(value).__name__.endswith('BigInteger'):
49-
return cls(longValue=int(str(value)))
50-
elif type(value).__name__.endswith('PgArray'):
51-
return cls(stringValue=str(value))
52-
else:
53-
raise Exception(f'unsupported type {type(value)}: {value} ')
54-
5527

5628
class SqlParameter(BaseModel):
5729
name: str

local_data_api/resources/jdbc/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ def execute(
115115
response = ExecuteStatementResponse(
116116
numberOfRecordsUpdated=0,
117117
records=[
118-
[Field.from_value(column) for column in row]
118+
[self.get_field_from_value(column) for column in row]
119119
for row in cursor.fetchall()
120120
],
121121
)
@@ -129,7 +129,9 @@ def execute(
129129
last_generated_id: int = self.last_generated_id(cursor)
130130
generated_fields: List[Field] = []
131131
if last_generated_id > 0:
132-
generated_fields.append(Field.from_value(last_generated_id))
132+
generated_fields.append(
133+
self.get_field_from_value(last_generated_id)
134+
)
133135
return ExecuteStatementResponse(
134136
numberOfRecordsUpdated=rowcount,
135137
generatedFields=generated_fields,

local_data_api/resources/jdbc/mysql.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1+
from typing import Any
2+
13
import jaydebeapi
24
from sqlalchemy.dialects import mysql
35
from sqlalchemy.engine import Dialect
46

7+
from local_data_api.models import Field
58
from local_data_api.resources.jdbc import JDBC
69
from local_data_api.resources.resource import register_resource_type
710

@@ -20,3 +23,9 @@ def reset_generated_id(cursor: jaydebeapi.Cursor) -> None:
2023
def last_generated_id(cursor: jaydebeapi.Cursor) -> int:
2124
cursor.execute("SELECT LAST_INSERT_ID()")
2225
return int(str(cursor.fetchone()[0]))
26+
27+
def get_field_from_value(self, value: Any) -> Field:
28+
if type(value).__name__.endswith('BigInteger'):
29+
return Field(longValue=int(str(value)))
30+
else:
31+
return super().get_field_from_value(value)

local_data_api/resources/jdbc/postgres.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1+
from typing import Any
2+
13
import jaydebeapi
24
from sqlalchemy.dialects import postgresql
35
from sqlalchemy.engine import Dialect
46

7+
from local_data_api.models import Field
58
from local_data_api.resources.jdbc import JDBC
69
from local_data_api.resources.resource import register_resource_type
710

@@ -19,3 +22,13 @@ def reset_generated_id(cursor: jaydebeapi.Cursor) -> None:
1922
@staticmethod
2023
def last_generated_id(cursor: jaydebeapi.Cursor) -> int:
2124
return 0
25+
26+
def get_field_from_value(self, value: Any) -> Field:
27+
if type(value).__name__.endswith('UUID'):
28+
return Field(stringValue=str(value))
29+
elif type(value).__name__.endswith('PGobject'):
30+
return Field(stringValue=str(value))
31+
elif type(value).__name__.endswith('PgArray'):
32+
return Field(stringValue=str(value))
33+
else:
34+
return super().get_field_from_value(value)

local_data_api/resources/mysql.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from pymysql.protocol import FieldDescriptorPacket
88
from sqlalchemy.dialects import mysql
99

10-
from local_data_api.models import ColumnMetadata
10+
from local_data_api.models import ColumnMetadata, Field
1111
from local_data_api.resources.resource import Resource, register_resource_type
1212

1313
if TYPE_CHECKING: # pragma: no cover
@@ -73,3 +73,6 @@ def connect(database: Optional[str] = None): # type: ignore
7373
return pymysql.connect(**kwargs)
7474

7575
return connect
76+
77+
def get_field_from_value(self, value: Any) -> Field:
78+
return super().get_field_from_value(value)

local_data_api/resources/postgres.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
3+
from typing import TYPE_CHECKING, Any, Dict, List, Optional
44

55
import psycopg2
66
from psycopg2._psycopg import Column
77
from sqlalchemy.dialects import postgresql
88

9-
from local_data_api.models import ColumnMetadata
9+
from local_data_api.models import ColumnMetadata, Field
1010
from local_data_api.resources.resource import Resource, register_resource_type
1111

1212
if TYPE_CHECKING: # pragma: no cover
@@ -68,3 +68,6 @@ def connect(database: Optional[str] = None): # type: ignore
6868
return psycopg2.connect(**kwargs)
6969

7070
return connect
71+
72+
def get_field_from_value(self, value: Any) -> Field:
73+
return super().get_field_from_value(value)

local_data_api/resources/resource.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44
import re
55
import string
66
from abc import ABC, abstractmethod
7+
from base64 import b64encode
78
from dataclasses import dataclass
89
from enum import Enum
910
from hashlib import sha1
10-
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Pattern, Tuple, Type, Union
11+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
1112

12-
from pydantic.main import BaseModel
1313
from sqlalchemy import text
1414
from sqlalchemy.engine import Dialect
1515
from sqlalchemy.exc import ArgumentError, CompileError
@@ -328,6 +328,25 @@ def create_transaction_id() -> str:
328328
for _ in range(TRANSACTION_ID_LENGTH)
329329
)
330330

331+
@abstractmethod
332+
def get_field_from_value(self, value: Any) -> Field:
333+
if isinstance(value, bool):
334+
return Field(booleanValue=value)
335+
elif isinstance(value, str):
336+
return Field(stringValue=value)
337+
elif type(value).__name__ == 'datetime':
338+
return Field(stringValue=str(value))
339+
elif isinstance(value, int):
340+
return Field(longValue=value)
341+
elif isinstance(value, float):
342+
return Field(doubleValue=value)
343+
elif isinstance(value, bytes):
344+
return Field(blobValue=b64encode(value))
345+
elif value is None:
346+
return Field(isNull=True)
347+
else:
348+
raise Exception(f'unsupported type {type(value)}: {value} ')
349+
331350
@abstractmethod
332351
def create_column_metadata_set(self, cursor: Cursor) -> List[ColumnMetadata]:
333352
raise NotImplementedError
@@ -369,7 +388,7 @@ def execute(
369388
response: ExecuteStatementResponse = ExecuteStatementResponse(
370389
numberOfRecordsUpdated=0,
371390
records=[
372-
[Field.from_value(column) for column in row]
391+
[self.get_field_from_value(column) for column in row]
373392
for row in cursor.fetchall()
374393
],
375394
)
@@ -383,7 +402,9 @@ def execute(
383402
last_generated_id: int = cursor.lastrowid
384403
generated_fields: List[Field] = []
385404
if last_generated_id > 0:
386-
generated_fields.append(Field.from_value(last_generated_id))
405+
generated_fields.append(
406+
self.get_field_from_value(last_generated_id)
407+
)
387408
return ExecuteStatementResponse(
388409
numberOfRecordsUpdated=rowcount,
389410
generatedFields=generated_fields,

local_data_api/resources/sqlite.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from sqlalchemy.dialects import sqlite
77

8-
from local_data_api.models import ColumnMetadata
8+
from local_data_api.models import ColumnMetadata, Field
99
from local_data_api.resources.resource import Resource, register_resource_type
1010

1111
if TYPE_CHECKING: # pragma: no cover
@@ -32,3 +32,6 @@ def connect(_: Optional[str] = None): # type: ignore
3232
return sqlite3.connect(':memory:')
3333

3434
return connect
35+
36+
def get_field_from_value(self, value: Any) -> Field:
37+
return super().get_field_from_value(value)

tests/test_models.py

Lines changed: 0 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -42,61 +42,3 @@ def test_valid_field() -> None:
4242
).valid_value
4343
== '2020-02-27'
4444
)
45-
46-
47-
def test_from_value() -> None:
48-
assert Field.from_value('str') == Field(stringValue='str')
49-
assert Field.from_value(123) == Field(longValue=123)
50-
assert Field.from_value(1.23) == Field(doubleValue=1.23)
51-
assert Field.from_value(True) == Field(booleanValue=True)
52-
assert Field.from_value(False) == Field(booleanValue=False)
53-
assert Field.from_value(b'bytes') == Field(blobValue=b64encode(b'bytes'))
54-
assert Field.from_value(None) == Field(isNull=True)
55-
assert Field.from_value(datetime(2019, 5, 18, 15, 17, 8)) == Field(
56-
stringValue='2019-05-18 15:17:08'
57-
)
58-
59-
class JavaUUID:
60-
def __init__(self, val: str):
61-
self._val: str = val
62-
63-
def __str__(self) -> str:
64-
return self._val
65-
66-
uuid = 'e9e1df6b-c6d3-4a34-9227-c27056d596c6'
67-
assert Field.from_value(JavaUUID(uuid)) == Field(stringValue=uuid)
68-
69-
class PGobject:
70-
def __init__(self, val: str):
71-
self._val: str = val
72-
73-
def __str__(self) -> str:
74-
return self._val
75-
76-
assert Field.from_value(PGobject("{}")) == Field(stringValue="{}")
77-
78-
class BigInteger:
79-
def __init__(self, val: int):
80-
self._val: int = val
81-
82-
def __str__(self) -> int:
83-
return self._val
84-
85-
assert Field.from_value(BigInteger("55")) == Field(longValue=55)
86-
87-
class PgArray:
88-
def __init__(self, val: str):
89-
self._val: str = val
90-
91-
def __str__(self) -> str:
92-
return self._val
93-
94-
assert Field.from_value(PgArray("{ITEM1,ITEM2}")) == Field(
95-
stringValue="{ITEM1,ITEM2}"
96-
)
97-
98-
class Dummy:
99-
pass
100-
101-
with pytest.raises(Exception):
102-
Field.from_value(Dummy())

tests/test_resource/test_jdbc/test_jdbc.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from __future__ import annotations
22

3+
from typing import Any
4+
35
import jaydebeapi
46
import pytest
57

6-
from local_data_api.models import ColumnMetadata
8+
from local_data_api.models import ColumnMetadata, Field
79
from local_data_api.resources.jdbc import JDBC, attach_thread_to_jvm, connection_maker
810

911

@@ -23,6 +25,9 @@ def last_generated_id(cursor: jaydebeapi.Cursor) -> int:
2325
cursor.execute("SELECT LAST_INSERT_ID()")
2426
return int(str(cursor.fetchone()[0]))
2527

28+
def get_field_from_value(self, value: Any) -> Field:
29+
return super().get_field_from_value(value)
30+
2631

2732
def test_attach_thread_to_jvm(mocker):
2833
mock_jpype = mocker.Mock()

0 commit comments

Comments
 (0)