Skip to content

Commit 3c33249

Browse files
committed
wip: implementing tests for the atomic behavior of delete_rows
1 parent 7065385 commit 3c33249

File tree

2 files changed

+30
-2
lines changed

2 files changed

+30
-2
lines changed

pandas/io/sql.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -2879,8 +2879,8 @@ def drop_table(self, name: str, schema: str | None = None) -> None:
28792879
def delete_rows(self, name: str, schema: str | None = None) -> None:
28802880
delete_sql = f"DELETE FROM {_get_valid_sqlite_name(name)}"
28812881
if self.has_table(name, schema):
2882-
with self.run_transaction():
2883-
self.execute(delete_sql)
2882+
with self.run_transaction() as cur:
2883+
cur.execute(delete_sql)
28842884

28852885
def _create_sql_schema(
28862886
self,

pandas/tests/io/test_sql.py

+28
Original file line numberDiff line numberDiff line change
@@ -2716,6 +2716,34 @@ def test_delete_rows_success(conn, test_frame1, request):
27162716
assert pandasSQL.has_table("temp_frame")
27172717

27182718

2719+
@pytest.mark.parametrize("conn", adbc_connectable)
2720+
def test_delete_rows_is_atomic(conn, request):
2721+
import adbc_driver_manager
2722+
2723+
if "sqlite" in conn:
2724+
pytest.skip("sqlite has no inspection system") # TODO: Change error message
2725+
2726+
table_name = "temp_frame"
2727+
original_df = DataFrame({"a": [1, 2, 3]})
2728+
replacing_df = DataFrame({"a": ["a", "b", "c", "d"]})
2729+
2730+
conn = request.getfixturevalue(conn)
2731+
pandasSQL = pandasSQL_builder(conn)
2732+
2733+
with pandasSQL.run_transaction():
2734+
pandasSQL.to_sql(original_df, table_name, if_exists="fail", index=False)
2735+
2736+
with pytest.raises(adbc_driver_manager.ProgrammingError):
2737+
with pandasSQL.run_transaction():
2738+
pandasSQL.to_sql(
2739+
replacing_df, table_name, if_exists="delete_rows", index=False
2740+
)
2741+
2742+
with pandasSQL.run_transaction():
2743+
unchanged_df = pandasSQL.read_query(f"SELECT * FROM {table_name}")
2744+
tm.assert_frame_equal(unchanged_df, original_df)
2745+
2746+
27192747
@pytest.mark.parametrize("conn", all_connectable)
27202748
def test_roundtrip(conn, request, test_frame1):
27212749
if conn == "sqlite_str":

0 commit comments

Comments
 (0)