Skip to content

Commit 6927565

Browse files
fix(typing): lots of typing, some other improvements
1 parent 2e6f40d commit 6927565

34 files changed

+1030
-263
lines changed

src/ynamazon/amazon_transactions.py

Lines changed: 68 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,20 @@
33
import tempfile
44
from datetime import date
55
from decimal import Decimal
6-
from typing import Annotated, Union # , Self # not available python <3.11
6+
from typing import Annotated, Union, cast # , Self # not available python <3.11
77

88
from amazonorders.entity.order import Order
99
from amazonorders.entity.transaction import Transaction
10-
from amazonorders.orders import AmazonOrders
1110
from amazonorders.session import AmazonSession
12-
from amazonorders.transactions import AmazonTransactions
1311
from cache_decorator import Cache
1412
from loguru import logger
1513
from pydantic import AnyUrl, BaseModel, EmailStr, Field, SecretStr, field_validator
1614
from rich import print as rprint
1715
from rich.table import Table
1816

17+
from ynamazon.order_models import AmazonOrderModels
18+
from ynamazon.transaction_models import AmazonTransactionModels
19+
1920
from .settings import settings
2021
from .types_pydantic import AmazonItemType
2122

@@ -42,16 +43,16 @@ def invert_value(cls, value: Decimal) -> Decimal:
4243
@classmethod
4344
def from_transaction_and_orders(cls, orders_dict: "dict[str, Order]", transaction: Transaction):
4445
"""Creates an instance from an order and transactions."""
45-
order = orders_dict.get(transaction.order_number) # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue]
46+
order = orders_dict.get(transaction.order_number)
4647
if order is None:
47-
raise ValueError(f"Order with number {transaction.order_number} not found.") # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue]
48+
raise ValueError(f"Order with number {transaction.order_number} not found.")
4849
return cls(
49-
completed_date=transaction.completed_date, # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue]
50-
transaction_total=transaction.grand_total, # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue]
51-
order_total=order.grand_total, # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue]
52-
order_number=order.order_number, # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue]
53-
order_link=order.order_details_link, # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue]
54-
items=order.items, # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue]
50+
completed_date=transaction.completed_date,
51+
transaction_total=transaction.grand_total, # pyright: ignore[reportArgumentType]
52+
order_total=order.grand_total, # pyright: ignore[reportArgumentType]
53+
order_number=order.order_number,
54+
order_link=order.order_details_link, # pyright: ignore[reportArgumentType]
55+
items=order.items,
5556
)
5657

5758

@@ -77,11 +78,19 @@ def amazon_session(self) -> AmazonSession:
7778
)
7879

7980

81+
def validate_year(year: str | int) -> int:
82+
"""Validates that a year has been entered with 4 digits."""
83+
year_str = str(year)
84+
if len(year_str) == 4 and year_str.isdigit():
85+
return int(year_str)
86+
raise ValueError("Year must be a 4-digit number.")
87+
88+
8089
class AmazonTransactionRetriever:
8190
def __init__(
8291
self,
8392
amazon_config: AmazonConfig,
84-
order_years: list[str] | None = None,
93+
order_years: list[str | int] | None = None,
8594
transaction_days: int = 31,
8695
force_refresh_amazon: bool = False,
8796
):
@@ -92,13 +101,15 @@ def __init__(
92101
transaction_days (int): Number of days to fetch transactions for. Defaults to 31.
93102
force_refresh_amazon (bool): Refresh cache by fetching transactions directly from Amazon.
94103
"""
95-
self.amazon_config = amazon_config
96-
self.order_years = self.__class__._normalized_years(order_years)
97-
self.transaction_days = transaction_days
98-
self.force_refresh_amazon = force_refresh_amazon
104+
self.amazon_config: AmazonConfig = amazon_config
105+
self.order_years: list[int] = self.normalize_years(order_years)
106+
self.transaction_days: int = transaction_days
107+
self.force_refresh_amazon: bool = force_refresh_amazon
99108

100109
# for memoizing the results of method calls
101-
self._memo = {}
110+
self._session: AmazonSession | None = None
111+
self._amazon_orders: AmazonOrderModels | None = None
112+
self._amazon_transactions: AmazonTransactionModels | None = None
102113

103114
def get_amazon_transactions(self) -> list[AmazonTransactionWithOrderInfo]:
104115
"""Get Amazon transactions linked to orders.
@@ -108,13 +119,17 @@ def get_amazon_transactions(self) -> list[AmazonTransactionWithOrderInfo]:
108119
Returns:
109120
list[TransactionWithOrderInfo]: A list of transactions with order info
110121
"""
111-
return self._get_amazon_transactions(
112-
order_years=self.order_years,
113-
transaction_days=self.transaction_days,
114-
amazon_config=self.amazon_config,
115-
use_cache=not self.force_refresh_amazon,
122+
return cast(
123+
"list[AmazonTransactionWithOrderInfo]",
124+
self._get_amazon_transactions(
125+
order_years=self.order_years,
126+
transaction_days=self.transaction_days,
127+
amazon_config=self.amazon_config,
128+
use_cache=not self.force_refresh_amazon,
129+
),
116130
)
117131

132+
# HACK: unused parameters are needed for caching to work correctly
118133
@Cache(
119134
validity_duration="2h",
120135
enable_cache_arg_name="use_cache",
@@ -126,13 +141,14 @@ def get_amazon_transactions(self) -> list[AmazonTransactionWithOrderInfo]:
126141
)
127142
def _get_amazon_transactions(
128143
self,
129-
order_years: list[str],
130-
transaction_days: int,
131-
amazon_config: AmazonConfig,
144+
order_years: list[str], # pyright: ignore[reportUnusedParameter]
145+
transaction_days: int, # pyright: ignore[reportUnusedParameter]
146+
amazon_config: AmazonConfig, # pyright: ignore[reportUnusedParameter]
147+
use_cache: bool = True, # pyright: ignore[reportUnusedParameter]
132148
) -> list[AmazonTransactionWithOrderInfo]:
133-
orders_dict = {order.order_number: order for order in self._amazon_orders()}
149+
orders_dict = {order.order_number: order for order in self.fetch_amazon_orders()}
134150

135-
amazon_transactions = self._amazon_transactions()
151+
amazon_transactions = self.fetch_amazon_transactions()
136152

137153
amazon_transaction_with_order_details: list[AmazonTransactionWithOrderInfo] = []
138154
for transaction in amazon_transactions:
@@ -150,7 +166,7 @@ def _get_amazon_transactions(
150166

151167
return amazon_transaction_with_order_details
152168

153-
def _amazon_orders(self) -> list[Order]:
169+
def fetch_amazon_orders(self) -> AmazonOrderModels:
154170
"""Returns a list of Amazon orders.
155171
156172
Args:
@@ -159,60 +175,48 @@ def _amazon_orders(self) -> list[Order]:
159175
Returns:
160176
list[Order]: A list of Amazon orders.
161177
"""
162-
if "amazon_orders" in self._memo:
163-
return self._memo["amazon_orders"]
178+
if self._amazon_orders is not None:
179+
return self._amazon_orders
164180

165-
amazon_orders = AmazonOrders(self._session())
181+
orders = AmazonOrderModels.get_order_history(self.get_session(), self.order_years)
182+
orders.sort_by_order_placed_date()
166183

167-
all_orders: list[Order] = []
168-
for year in self.order_years:
169-
all_orders.extend(amazon_orders.get_order_history(year=year))
170-
all_orders.sort(key=lambda order: order.order_placed_date)
184+
self._amazon_orders = orders
171185

172-
self._memo["amazon_orders"] = all_orders
186+
return self._amazon_orders
173187

174-
return self._memo["amazon_orders"]
175-
176-
def _amazon_transactions(self) -> list[Transaction]:
188+
def fetch_amazon_transactions(self) -> AmazonTransactionModels:
177189
"""Fetches and sorts Amazon transactions."""
178-
if "amazon_transactions" in self._memo:
179-
return self._memo["amazon_transactions"]
190+
if self._amazon_transactions is not None:
191+
return self._amazon_transactions
180192

181-
self._memo["amazon_transactions"] = AmazonTransactions(
182-
amazon_session=self._session()
183-
).get_transactions(days=self.transaction_days)
193+
transactions = AmazonTransactionModels.get_transactions(
194+
self.get_session(), self.transaction_days
195+
)
196+
transactions.sort_by_completed_date()
184197

185-
self._memo["amazon_transactions"].sort(key=lambda trans: trans.completed_date)
198+
self._amazon_transactions = transactions
186199

187-
return self._memo["amazon_transactions"]
200+
return self._amazon_transactions
188201

189-
def _session(self) -> AmazonSession:
190-
if "session" in self._memo:
191-
return self._memo["session"]
202+
def get_session(self) -> AmazonSession:
203+
if self._session is not None:
204+
return self._session
192205

193206
amazon_session = self.amazon_config.amazon_session()
194207
amazon_session.login()
195208

196209
if amazon_session.is_authenticated:
197-
self._memo["session"] = amazon_session
198-
return self._memo["session"]
210+
self._session = amazon_session
211+
return self._session
212+
raise ValueError("Failed to authenticate with Amazon.")
199213

200214
@classmethod
201-
def _normalized_years(cls, years: list[str] | None = None) -> list[str]:
215+
def normalize_years(cls, years: list[str | int] | None = None) -> list[int]:
202216
if years is None:
203217
return [date.today().year]
204218

205-
result: list[str] = []
206-
207-
for year in years:
208-
if len(year) == 2:
209-
result.append("20" + year)
210-
elif len(year) == 4:
211-
result.append(year)
212-
else:
213-
raise ValueError("Year must be specified as 2 or 4 digits (e.g. 21 or 2021)")
214-
215-
return result
219+
return [validate_year(year) for year in years]
216220

217221

218222
def print_amazon_transactions(
@@ -239,7 +243,7 @@ def print_amazon_transactions(
239243
f"${transaction.order_total:.2f}",
240244
transaction.order_number,
241245
str(transaction.order_link),
242-
" | ".join(_truncate_title(item.title) for item in transaction.items), # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue]
246+
" | ".join(_truncate_title(item.title) for item in transaction.items),
243247
)
244248

245249
rprint(table)

src/ynamazon/base.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
from collections.abc import Callable, Iterable, Mapping
22
from typing import Generic, Self, SupportsIndex, TypeVar, cast, overload, override
33

4-
from pydantic import BaseModel, RootModel
4+
from pydantic import BaseModel, Field, RootModel
55
from pydantic_core.core_schema import ListSchema, ModelSchema
66

7-
_ModelT = TypeVar("_ModelT", bound=BaseModel)
7+
_ModelT = TypeVar("_ModelT")
88
_T = TypeVar("_T")
99

1010

@@ -92,3 +92,18 @@ def empty(cls) -> "DictRootModel[_KV, _VT]":
9292
def __len__(self) -> int:
9393
"""Returns the number of items in the dict."""
9494
return len(self.root)
95+
96+
97+
class MultiLineText(BaseModel):
98+
"""A class to handle multi-line text."""
99+
100+
lines: list[str] = Field(default_factory=list)
101+
102+
@override
103+
def __str__(self) -> str:
104+
"""Returns the string representation of the object."""
105+
return "\n".join(self.lines)
106+
107+
def append(self, line: str) -> None:
108+
"""Appends a line to the text."""
109+
self.lines.append(line)

src/ynamazon/main.py

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
from typing import TYPE_CHECKING, override
1+
from typing import TYPE_CHECKING
22

33
from loguru import logger
4-
from pydantic import BaseModel, Field
54
from rich.console import Console
65
from rich.prompt import Confirm
76

@@ -10,6 +9,7 @@
109
AmazonTransactionRetriever,
1110
locate_amazon_transaction_by_amount,
1211
)
12+
from ynamazon.base import MultiLineText
1313
from ynamazon.exceptions import YnabSetupError
1414
from ynamazon.settings import settings
1515
from ynamazon.ynab_memo import process_memo
@@ -30,21 +30,6 @@
3030
from ynab.configuration import Configuration
3131

3232

33-
class MultiLineText(BaseModel):
34-
"""A class to handle multi-line text."""
35-
36-
lines: list[str] = Field(default_factory=list)
37-
38-
@override
39-
def __str__(self) -> str:
40-
"""Returns the string representation of the object."""
41-
return "\n".join(self.lines)
42-
43-
def append(self, line: str) -> None:
44-
"""Appends a line to the text."""
45-
self.lines.append(line)
46-
47-
4833
# TODO: reduce complexity of this function
4934
def process_transactions( # noqa: C901
5035
amazon_config: AmazonConfig | None = None,

0 commit comments

Comments
 (0)