Skip to content

Commit ded58f2

Browse files
committed
refactor: Fix more easy mypy issues.
1 parent 5680e04 commit ded58f2

6 files changed

Lines changed: 74 additions & 25 deletions

File tree

actual/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -346,8 +346,11 @@ def export_data(self, output_file: str | PathLike[str] | IO[bytes] | None = None
346346
z.write(self.data_dir / "metadata.json", "metadata.json")
347347
content = temp_file.getvalue()
348348
if output_file:
349-
with open(output_file, "wb") as f:
350-
f.write(content)
349+
if isinstance(output_file, (str, PathLike)):
350+
with open(output_file, "wb") as f:
351+
f.write(content)
352+
else:
353+
output_file.write(content)
351354
return content
352355

353356
def encrypt(self, encryption_password: str):

actual/budgets.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from sqlmodel import Session, col, select
77

8-
from actual.database import Categories, CategoryGroups, ReflectBudgets, Transactions, ZeroBudgets
8+
from actual.database import BaseBudgets, Categories, CategoryGroups, ReflectBudgets, Transactions
99
from actual.queries import (
1010
_balance_base_query,
1111
_get_budget_table,
@@ -64,7 +64,7 @@ class BudgetCategory(_HasDatabaseObject):
6464
This reflects the values displayed on the Actual frontend.
6565
"""
6666

67-
budget: ReflectBudgets | ZeroBudgets | None = None
67+
budget: BaseBudgets | None = None
6868
"""
6969
The underlying budget database record, if it exists.
7070

actual/database.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,15 +104,19 @@ def get_attribute_from_reflected_table_name(metadata: MetaData, table_name: str,
104104
return table.columns.get(column_name, None)
105105

106106

107-
def get_class_by_table_name(table_name: str) -> type["BaseModel"] | None:
107+
def get_class_by_table_name(table_name: str) -> type["BaseModel"]:
108108
"""
109109
Returns, based on the defined tables `__tablename__` the corresponding SQLModel object.
110110
111-
If not found, returns `None`.
111+
If not found, raises `ValueError`.
112+
113+
:param table_name: SQL table name.
114+
:return SQLModel: SQLAlchemy object.
115+
:raises ValueError: Raises `ValueError` if the table name is not existing.
112116
"""
113117
entry = __TABLE_COLUMNS_MAP__.get(table_name)
114118
if entry is None:
115-
return None
119+
raise ValueError(f"Could not find table '{table_name}' on the database model.")
116120
return entry["entity"]
117121

118122

actual/queries.py

Lines changed: 46 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,12 @@
1313
from sqlalchemy.orm import joinedload
1414
from sqlalchemy.sql.expression import Select
1515
from sqlmodel import Session, col, select
16+
from sqlmodel.sql.expression import SelectOfScalar
1617

1718
from actual.crypto import is_uuid
1819
from actual.database import (
1920
Accounts,
21+
BaseBudgets,
2022
Categories,
2123
CategoryGroups,
2224
CategoryMapping,
@@ -56,8 +58,8 @@ def _transactions_base_query(
5658
account: Accounts | str | None | None = None,
5759
category: Categories | str | None = None,
5860
include_deleted: bool = False,
59-
) -> Select:
60-
query = (
61+
) -> SelectOfScalar[Transactions]:
62+
query: SelectOfScalar[Transactions] = (
6163
select(Transactions)
6264
.options(
6365
joinedload(Transactions.account),
@@ -232,7 +234,7 @@ def match_transaction(
232234
query = _transactions_base_query(
233235
s, date - datetime.timedelta(days=7), date + datetime.timedelta(days=8), account=account
234236
).where(col(Transactions.amount) == round(amount * 100))
235-
results: list[Transactions] = s.exec(query).all() # noqa
237+
results: list[Transactions] = list(s.exec(query).all())
236238
# filter out the ones that were already matched
237239
if already_matched:
238240
matched = {t.id for t in already_matched}
@@ -535,13 +537,17 @@ def create_split(s: Session, transaction: Transactions, amount: float | decimal.
535537
return split
536538

537539

538-
def _base_query(instance: type[T], name: str | None = None, include_deleted: bool = False) -> Select:
540+
def _base_query(instance: type[T], name: str | None = None, include_deleted: bool = False) -> SelectOfScalar[T]:
539541
"""Internal method to reduce querying complexity on sub-functions."""
540542
query = select(instance)
541543
if not include_deleted:
542-
query = query.where(sqlalchemy.func.coalesce(instance.tombstone, 0) == 0)
544+
tombstone_col = getattr(instance, "tombstone", None)
545+
if tombstone_col is not None:
546+
query = query.where(sqlalchemy.func.coalesce(tombstone_col, 0) == 0)
543547
if name:
544-
query = query.where(instance.name.ilike(f"%{sqlalchemy.text(name).compile()}%"))
548+
name_col = getattr(instance, "name", None)
549+
if name_col is not None:
550+
query = query.where(name_col.ilike(f"%{sqlalchemy.text(name).compile()}%"))
545551
return query
546552

547553

@@ -839,7 +845,7 @@ def get_or_create_account(s: Session, name: str | Accounts) -> Accounts:
839845
return account
840846

841847

842-
def _get_budget_table(s: Session) -> type[ReflectBudgets | ZeroBudgets]:
848+
def _get_budget_table(s: Session) -> type[ZeroBudgets] | type[ReflectBudgets]:
843849
"""
844850
Finds out which type of budget the user uses. The types are:
845851
@@ -858,7 +864,7 @@ def _get_budget_table(s: Session) -> type[ReflectBudgets | ZeroBudgets]:
858864

859865
def get_budgets(
860866
s: Session, month: datetime.date | None = None, category: str | Categories | None = None
861-
) -> typing.Sequence[ZeroBudgets | ReflectBudgets]:
867+
) -> typing.Sequence[BaseBudgets]:
862868
"""
863869
Returns a list of all available budgets.
864870
@@ -889,7 +895,7 @@ def get_budgets(
889895
return s.exec(query).unique().all()
890896

891897

892-
def get_budget(s: Session, month: datetime.date, category: str | Categories) -> ZeroBudgets | ReflectBudgets | None:
898+
def get_budget(s: Session, month: datetime.date, category: str | Categories) -> BaseBudgets | None:
893899
"""
894900
Gets an existing budget by category name, returns `None` if not found.
895901
@@ -910,7 +916,7 @@ def create_budget(
910916
category: str | Categories,
911917
amount: decimal.Decimal | float | int = 0.0,
912918
carryover: bool | None = None,
913-
) -> ZeroBudgets | ReflectBudgets:
919+
) -> BaseBudgets:
914920
"""
915921
Gets an existing budget based on the month and category. If it already exists, the amount will be replaced by
916922
the new amount.
@@ -1140,6 +1146,32 @@ def get_schedules(
11401146
return s.exec(query).all()
11411147

11421148

1149+
@typing.overload
1150+
def create_schedule(
1151+
s: Session,
1152+
date: datetime.date | datetime.datetime | Schedule,
1153+
amount: tuple[decimal.Decimal, decimal.Decimal] | tuple[float, float],
1154+
amount_operation: typing.Literal["isbetween"],
1155+
name: str | None,
1156+
payee: str | Payees | None,
1157+
account: str | Accounts | None,
1158+
posts_transaction: bool,
1159+
) -> Schedules: ...
1160+
1161+
1162+
@typing.overload
1163+
def create_schedule(
1164+
s: Session,
1165+
date: datetime.date | datetime.datetime | Schedule,
1166+
amount: decimal.Decimal | float,
1167+
amount_operation: typing.Literal["is", "isapprox"],
1168+
name: str | None,
1169+
payee: str | Payees | None,
1170+
account: str | Accounts | None,
1171+
posts_transaction: bool,
1172+
) -> Schedules: ...
1173+
1174+
11431175
def create_schedule(
11441176
s: Session,
11451177
date: datetime.date | datetime.datetime | Schedule,
@@ -1171,9 +1203,6 @@ def create_schedule(
11711203
:param posts_transaction: Whether the schedule should auto-post transactions on your behalf. Defaults to false.
11721204
:return: Rule database object created.
11731205
"""
1174-
if amount_operation == "isbetween" and not isinstance(amount, tuple):
1175-
raise ActualError("When using 'isbetween', amount must be a tuple (num1, num2), where num1 < num2.")
1176-
11771206
schedule_id = str(uuid.uuid4())
11781207
conditions = []
11791208
# Handle the payee condition
@@ -1197,6 +1226,8 @@ def create_schedule(
11971226
)
11981227
# Handle the amount condition
11991228
if amount_operation == "isbetween":
1229+
if not isinstance(amount, tuple):
1230+
raise ActualError("When using 'isbetween', amount must be a tuple (num1, num2), where num1 < num2.")
12001231
conditions.append(
12011232
Condition(
12021233
field="amount",
@@ -1205,6 +1236,8 @@ def create_schedule(
12051236
)
12061237
)
12071238
else:
1239+
if isinstance(amount, tuple):
1240+
raise ActualError(f"When using '{amount_operation}', amount must be a single decimal number.")
12081241
conditions.append(Condition(field="amount", op=ConditionType(amount_operation), value=decimal_to_cents(amount)))
12091242

12101243
actions = [

tests/test_database.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -508,10 +508,6 @@ def test_schedule_is_betweeen(session):
508508
payee = get_or_create_payee(session, "Insurance company")
509509
# should always be paid on the first working day of the month
510510
config = create_schedule_config(expected_date, patterns=[Pattern(1, "day")], skip_weekend=True)
511-
# if the amount_operation="isbetween", the schedule needs two amounts
512-
with pytest.raises(ActualError, match="amount must be a tuple"):
513-
create_schedule(session, config, 100.0, "isbetween", "Insurance", payee, account)
514-
515511
schedule = create_schedule(session, config, (100.0, 110.0), "isbetween", "Insurance", payee, account)
516512
assert json.loads(schedule.rule.conditions) == [
517513
{"field": "description", "type": "id", "op": "is", "value": payee.id},
@@ -585,6 +581,18 @@ def test_schedule_populates_next_date_simple_date(session, start_date, expected_
585581
assert rows[0].base_next_date == expected_next_date
586582

587583

584+
def test_schedule_exceptions(session):
585+
expected_date = datetime.date(2025, 10, 11)
586+
account = create_account(session, "Bank")
587+
payee = get_or_create_payee(session, "Insurance company")
588+
# should always be paid on the first working day of the month
589+
config = create_schedule_config(expected_date, patterns=[Pattern(1, "day")], skip_weekend=True)
590+
with pytest.raises(ActualError, match="amount must be a tuple"):
591+
create_schedule(session, config, 100.0, "isbetween", "Insurance", payee, account)
592+
with pytest.raises(ActualError, match="amount must be a single decimal number"):
593+
create_schedule(session, config, (100.0, 110.0), "isapprox", "Insurance", payee, account)
594+
595+
588596
def test_get_transactions_with_cleared_filter(session):
589597
acct = create_account(session, "ClearedTxs")
590598
create_transaction(session, date=today, account=acct, amount=10, cleared=False)

tests/test_models.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616

1717
def test_get_class_by_table_name():
1818
assert get_class_by_table_name("transactions") == Transactions
19-
assert get_class_by_table_name("foo") is None
19+
with pytest.raises(ValueError, match="Could not find table 'foo'"):
20+
get_class_by_table_name("foo")
2021

2122

2223
def test_get_attribute_by_table_name():

0 commit comments

Comments
 (0)