Skip to content

Commit 2022cc8

Browse files
committed
Use typed exceptions
1 parent 38ca71b commit 2022cc8

File tree

3 files changed

+58
-28
lines changed

3 files changed

+58
-28
lines changed

aurora_data_api/__init__.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from collections import namedtuple
88
from collections.abc import Mapping
99
from .exceptions import (Warning, Error, InterfaceError, DatabaseError, DataError, OperationalError, IntegrityError,
10-
InternalError, ProgrammingError, NotSupportedError)
10+
InternalError, ProgrammingError, NotSupportedError, MySQLError, PostgreSQLError)
1111
from .mysql_error_codes import MySQLErrorCodes
1212
from .postgresql_error_codes import PostgreSQLErrorCodes
1313
import boto3
@@ -219,18 +219,20 @@ def _format_parameter_set(self, parameters):
219219
def _get_database_error(self, original_error):
220220
error_msg = getattr(original_error, "response", {}).get("Error", {}).get("Message", "")
221221
try:
222-
if error_msg.startswith("Database error code"): # MySQL error
223-
code, msg = (s.split(": ", 1)[1] for s in error_msg.split(". ", 1))
224-
mysql_error_code = MySQLErrorCodes(int(code))
225-
return DatabaseError(mysql_error_code, msg)
226-
elif error_msg.startswith("ERROR: "): # PostgreSQL error
227-
error_msg = error_msg[len("ERROR: "):]
228-
error_lines = error_msg.splitlines()
229-
if error_lines[-1].startswith(" Position: ") and " SQLState: " in error_lines[-1]:
230-
position, sqlstate = (i.split(":", 1)[1].strip() for i in error_lines[-1].strip().split(";"))
231-
postgres_error_code = PostgreSQLErrorCodes(sqlstate)
232-
return DatabaseError(postgres_error_code, "\n".join(error_lines[:-1]), int(position))
233-
raise Exception("unable to parse postgresql error")
222+
res = re.search(r"Error code: (\d+); SQLState: (\d+)$", error_msg)
223+
if res: # MySQL error
224+
error_code = int(res.group(1))
225+
error_class = MySQLError.from_code(error_code)
226+
error = error_class(error_msg)
227+
error.response = getattr(original_error, "response", {})
228+
return error
229+
res = re.search(r"ERROR: .*\n Position: (\d+); SQLState: (\w+)$", error_msg)
230+
if res: # PostgreSQL error
231+
error_code = res.group(2)
232+
error_class = PostgreSQLError.from_code(error_code)
233+
error = error_class(error_msg)
234+
error.response = getattr(original_error, "response", {})
235+
return error
234236
except Exception:
235237
pass
236238
return DatabaseError(original_error)

aurora_data_api/exceptions.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
from .mysql_error_codes import MySQLErrorCodes
2+
from .postgresql_error_codes import PostgreSQLErrorCodes
3+
4+
15
class Warning(Exception):
26
pass
37

@@ -36,3 +40,25 @@ class ProgrammingError(DatabaseError):
3640

3741
class NotSupportedError(DatabaseError):
3842
pass
43+
44+
45+
class _DatabaseErrorFactory:
46+
def __getattr__(self, a):
47+
err_cls = type(getattr(self.err_index, a).name, (DatabaseError, ), {})
48+
setattr(self, a, err_cls)
49+
return err_cls
50+
51+
def from_code(self, err_code):
52+
return getattr(self, self.err_index(err_code).name)
53+
54+
55+
class _MySQLErrorFactory(_DatabaseErrorFactory):
56+
err_index = MySQLErrorCodes
57+
58+
59+
class _PostgreSQLErrorFactory(_DatabaseErrorFactory):
60+
err_index = PostgreSQLErrorCodes
61+
62+
63+
MySQLError = _MySQLErrorFactory()
64+
PostgreSQLError = _PostgreSQLErrorFactory()

test/test.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -41,19 +41,19 @@ def setUpClass(cls):
4141
"ts": "2020-09-17 13:49:32.780180",
4242
} for i in range(2048)]
4343
)
44-
except aurora_data_api.DatabaseError as e:
45-
if e.args[0] != MySQLErrorCodes.ER_PARSE_ERROR:
46-
raise
44+
except aurora_data_api.MySQLError.ER_PARSE_ERROR:
4745
cls.using_mysql = True
4846
cur.execute("DROP TABLE IF EXISTS aurora_data_api_test")
49-
cur.execute(
50-
"CREATE TABLE aurora_data_api_test (id SERIAL, name TEXT, birthday DATE, num NUMERIC(10, 5))"
51-
)
47+
cur.execute("CREATE TABLE aurora_data_api_test "
48+
"(id SERIAL, name TEXT, birthday DATE, num NUMERIC(10, 5), ts TIMESTAMP)")
5249
cur.executemany(
53-
"INSERT INTO aurora_data_api_test(name, birthday, num) VALUES (:name, :birthday, :num)", [{
50+
("INSERT INTO aurora_data_api_test(name, birthday, num, ts) VALUES "
51+
"(:name, :birthday, :num, CAST(:ts AS DATETIME))"),
52+
[{
5453
"name": "row{}".format(i),
5554
"birthday": "2000-01-01",
56-
"num": decimal.Decimal("%d.%d" % (i, i))
55+
"num": decimal.Decimal("%d.%d" % (i, i)),
56+
"ts": "2020-09-17 13:49:32.780180",
5757
} for i in range(2048)]
5858
)
5959

@@ -64,7 +64,8 @@ def tearDownClass(cls):
6464

6565
def test_invalid_statements(self):
6666
with aurora_data_api.connect(database=self.db_name) as conn, conn.cursor() as cur:
67-
with self.assertRaisesRegex(aurora_data_api.DatabaseError, "syntax"):
67+
with self.assertRaises((aurora_data_api.exceptions.PostgreSQLError.ER_SYNTAX_ERR,
68+
aurora_data_api.MySQLError.ER_PARSE_ERROR)):
6869
cur.execute("selec * from table")
6970

7071
def test_iterators(self):
@@ -83,8 +84,10 @@ def test_iterators(self):
8384
expect_row0 = (1,
8485
'row0',
8586
datetime.date(2000, 1, 1) if self.using_mysql else '{"x": 0, "y": "0", "z": [0, 0, 1]}',
86-
decimal.Decimal(0),
87-
datetime.datetime(2020, 9, 17, 13, 49, 32, 780180))
87+
decimal.Decimal(0.0),
88+
datetime.datetime(2020, 9, 17, 13, 49, 33)
89+
if self.using_mysql
90+
else datetime.datetime(2020, 9, 17, 13, 49, 32, 780180))
8891
i = 0
8992
cursor.execute("select * from aurora_data_api_test")
9093
for f in cursor:
@@ -142,12 +145,11 @@ def test_postgres_exceptions(self):
142145
return
143146
with aurora_data_api.connect(database=self.db_name) as conn, conn.cursor() as cur:
144147
table = "aurora_data_api_nonexistent_test_table"
145-
with self.assertRaises(aurora_data_api.DatabaseError) as e:
148+
with self.assertRaises(aurora_data_api.exceptions.PostgreSQLError.ER_UNDEF_TABLE) as e:
146149
sql = f"select * from {table}"
147150
cur.execute(sql)
148-
self.assertEqual(e.exception.args, (PostgreSQLErrorCodes.ER_UNDEF_TABLE,
149-
f'relation "{table}" does not exist',
150-
15))
151+
self.assertTrue(f'relation "{table}" does not exist' in str(e.exception))
152+
self.assertTrue(isinstance(e.exception.response, dict))
151153

152154
def test_rowcount(self):
153155
with aurora_data_api.connect(database=self.db_name) as conn, conn.cursor() as cur:

0 commit comments

Comments
 (0)