33import tempfile
44from datetime import date
55from decimal import Decimal
6- from typing import Annotated , Union # , Self # not available python <3.11
6+ from typing import Annotated , Union , cast # , Self # not available python <3.11
77
88from amazonorders .entity .order import Order
99from amazonorders .entity .transaction import Transaction
10- from amazonorders .orders import AmazonOrders
1110from amazonorders .session import AmazonSession
12- from amazonorders .transactions import AmazonTransactions
1311from cache_decorator import Cache
1412from loguru import logger
1513from pydantic import AnyUrl , BaseModel , EmailStr , Field , SecretStr , field_validator
1614from rich import print as rprint
1715from rich .table import Table
1816
17+ from ynamazon .order_models import AmazonOrderModels
18+ from ynamazon .transaction_models import AmazonTransactionModels
19+
1920from .settings import settings
2021from .types_pydantic import AmazonItemType
2122
@@ -42,16 +43,16 @@ def invert_value(cls, value: Decimal) -> Decimal:
4243 @classmethod
4344 def from_transaction_and_orders (cls , orders_dict : "dict[str, Order]" , transaction : Transaction ):
4445 """Creates an instance from an order and transactions."""
45- order = orders_dict .get (transaction .order_number ) # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue]
46+ order = orders_dict .get (transaction .order_number )
4647 if order is None :
47- raise ValueError (f"Order with number { transaction .order_number } not found." ) # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue]
48+ raise ValueError (f"Order with number { transaction .order_number } not found." )
4849 return cls (
49- completed_date = transaction .completed_date , # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue]
50- transaction_total = transaction .grand_total , # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue ]
51- order_total = order .grand_total , # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue ]
52- order_number = order .order_number , # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue]
53- order_link = order .order_details_link , # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue ]
54- items = order .items , # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue]
50+ completed_date = transaction .completed_date ,
51+ transaction_total = transaction .grand_total , # pyright: ignore[reportArgumentType ]
52+ order_total = order .grand_total , # pyright: ignore[reportArgumentType ]
53+ order_number = order .order_number ,
54+ order_link = order .order_details_link , # pyright: ignore[reportArgumentType ]
55+ items = order .items ,
5556 )
5657
5758
@@ -77,11 +78,19 @@ def amazon_session(self) -> AmazonSession:
7778 )
7879
7980
81+ def validate_year (year : str | int ) -> int :
82+ """Validates that a year has been entered with 4 digits."""
83+ year_str = str (year )
84+ if len (year_str ) == 4 and year_str .isdigit ():
85+ return int (year_str )
86+ raise ValueError ("Year must be a 4-digit number." )
87+
88+
8089class AmazonTransactionRetriever :
8190 def __init__ (
8291 self ,
8392 amazon_config : AmazonConfig ,
84- order_years : list [str ] | None = None ,
93+ order_years : list [str | int ] | None = None ,
8594 transaction_days : int = 31 ,
8695 force_refresh_amazon : bool = False ,
8796 ):
@@ -92,13 +101,15 @@ def __init__(
92101 transaction_days (int): Number of days to fetch transactions for. Defaults to 31.
93102 force_refresh_amazon (bool): Refresh cache by fetching transactions directly from Amazon.
94103 """
95- self .amazon_config = amazon_config
96- self .order_years = self .__class__ . _normalized_years (order_years )
97- self .transaction_days = transaction_days
98- self .force_refresh_amazon = force_refresh_amazon
104+ self .amazon_config : AmazonConfig = amazon_config
105+ self .order_years : list [ int ] = self .normalize_years (order_years )
106+ self .transaction_days : int = transaction_days
107+ self .force_refresh_amazon : bool = force_refresh_amazon
99108
100109 # for memoizing the results of method calls
101- self ._memo = {}
110+ self ._session : AmazonSession | None = None
111+ self ._amazon_orders : AmazonOrderModels | None = None
112+ self ._amazon_transactions : AmazonTransactionModels | None = None
102113
103114 def get_amazon_transactions (self ) -> list [AmazonTransactionWithOrderInfo ]:
104115 """Get Amazon transactions linked to orders.
@@ -108,13 +119,17 @@ def get_amazon_transactions(self) -> list[AmazonTransactionWithOrderInfo]:
108119 Returns:
109120 list[TransactionWithOrderInfo]: A list of transactions with order info
110121 """
111- return self ._get_amazon_transactions (
112- order_years = self .order_years ,
113- transaction_days = self .transaction_days ,
114- amazon_config = self .amazon_config ,
115- use_cache = not self .force_refresh_amazon ,
122+ return cast (
123+ "list[AmazonTransactionWithOrderInfo]" ,
124+ self ._get_amazon_transactions (
125+ order_years = self .order_years ,
126+ transaction_days = self .transaction_days ,
127+ amazon_config = self .amazon_config ,
128+ use_cache = not self .force_refresh_amazon ,
129+ ),
116130 )
117131
132+ # HACK: unused parameters are needed for caching to work correctly
118133 @Cache (
119134 validity_duration = "2h" ,
120135 enable_cache_arg_name = "use_cache" ,
@@ -126,13 +141,14 @@ def get_amazon_transactions(self) -> list[AmazonTransactionWithOrderInfo]:
126141 )
127142 def _get_amazon_transactions (
128143 self ,
129- order_years : list [str ],
130- transaction_days : int ,
131- amazon_config : AmazonConfig ,
144+ order_years : list [str ], # pyright: ignore[reportUnusedParameter]
145+ transaction_days : int , # pyright: ignore[reportUnusedParameter]
146+ amazon_config : AmazonConfig , # pyright: ignore[reportUnusedParameter]
147+ use_cache : bool = True , # pyright: ignore[reportUnusedParameter]
132148 ) -> list [AmazonTransactionWithOrderInfo ]:
133- orders_dict = {order .order_number : order for order in self ._amazon_orders ()}
149+ orders_dict = {order .order_number : order for order in self .fetch_amazon_orders ()}
134150
135- amazon_transactions = self ._amazon_transactions ()
151+ amazon_transactions = self .fetch_amazon_transactions ()
136152
137153 amazon_transaction_with_order_details : list [AmazonTransactionWithOrderInfo ] = []
138154 for transaction in amazon_transactions :
@@ -150,7 +166,7 @@ def _get_amazon_transactions(
150166
151167 return amazon_transaction_with_order_details
152168
153- def _amazon_orders (self ) -> list [ Order ] :
169+ def fetch_amazon_orders (self ) -> AmazonOrderModels :
154170 """Returns a list of Amazon orders.
155171
156172 Args:
@@ -159,60 +175,48 @@ def _amazon_orders(self) -> list[Order]:
159175 Returns:
160176 list[Order]: A list of Amazon orders.
161177 """
162- if "amazon_orders" in self ._memo :
163- return self ._memo [ "amazon_orders" ]
178+ if self ._amazon_orders is not None :
179+ return self ._amazon_orders
164180
165- amazon_orders = AmazonOrders (self ._session ())
181+ orders = AmazonOrderModels .get_order_history (self .get_session (), self .order_years )
182+ orders .sort_by_order_placed_date ()
166183
167- all_orders : list [Order ] = []
168- for year in self .order_years :
169- all_orders .extend (amazon_orders .get_order_history (year = year ))
170- all_orders .sort (key = lambda order : order .order_placed_date )
184+ self ._amazon_orders = orders
171185
172- self ._memo [ "amazon_orders" ] = all_orders
186+ return self ._amazon_orders
173187
174- return self ._memo ["amazon_orders" ]
175-
176- def _amazon_transactions (self ) -> list [Transaction ]:
188+ def fetch_amazon_transactions (self ) -> AmazonTransactionModels :
177189 """Fetches and sorts Amazon transactions."""
178- if "amazon_transactions" in self ._memo :
179- return self ._memo [ "amazon_transactions" ]
190+ if self ._amazon_transactions is not None :
191+ return self ._amazon_transactions
180192
181- self ._memo ["amazon_transactions" ] = AmazonTransactions (
182- amazon_session = self ._session ()
183- ).get_transactions (days = self .transaction_days )
193+ transactions = AmazonTransactionModels .get_transactions (
194+ self .get_session (), self .transaction_days
195+ )
196+ transactions .sort_by_completed_date ()
184197
185- self ._memo [ "amazon_transactions" ]. sort ( key = lambda trans : trans . completed_date )
198+ self ._amazon_transactions = transactions
186199
187- return self ._memo [ "amazon_transactions" ]
200+ return self ._amazon_transactions
188201
189- def _session (self ) -> AmazonSession :
190- if "session" in self ._memo :
191- return self ._memo [ "session" ]
202+ def get_session (self ) -> AmazonSession :
203+ if self ._session is not None :
204+ return self ._session
192205
193206 amazon_session = self .amazon_config .amazon_session ()
194207 amazon_session .login ()
195208
196209 if amazon_session .is_authenticated :
197- self ._memo ["session" ] = amazon_session
198- return self ._memo ["session" ]
210+ self ._session = amazon_session
211+ return self ._session
212+ raise ValueError ("Failed to authenticate with Amazon." )
199213
200214 @classmethod
201- def _normalized_years (cls , years : list [str ] | None = None ) -> list [str ]:
215+ def normalize_years (cls , years : list [str | int ] | None = None ) -> list [int ]:
202216 if years is None :
203217 return [date .today ().year ]
204218
205- result : list [str ] = []
206-
207- for year in years :
208- if len (year ) == 2 :
209- result .append ("20" + year )
210- elif len (year ) == 4 :
211- result .append (year )
212- else :
213- raise ValueError ("Year must be specified as 2 or 4 digits (e.g. 21 or 2021)" )
214-
215- return result
219+ return [validate_year (year ) for year in years ]
216220
217221
218222def print_amazon_transactions (
@@ -239,7 +243,7 @@ def print_amazon_transactions(
239243 f"${ transaction .order_total :.2f} " ,
240244 transaction .order_number ,
241245 str (transaction .order_link ),
242- " | " .join (_truncate_title (item .title ) for item in transaction .items ), # pyright: ignore[reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue]
246+ " | " .join (_truncate_title (item .title ) for item in transaction .items ),
243247 )
244248
245249 rprint (table )
0 commit comments