@@ -157,7 +157,7 @@ async def _stub_stream_fn(request: PricingRequest) -> AsyncGenerator[dict[str, A
157157# ---------------------------------------------------------------------------
158158
159159
160- def create_app (model_path : str | None = None ) -> FastAPI :
160+ def create_app (model_path : str | None = None ) -> FastAPI : # noqa: C901
161161 """Build and return the FastAPI application.
162162
163163 Parameters
@@ -169,7 +169,7 @@ def create_app(model_path: str | None = None) -> FastAPI:
169169 """
170170
171171 @asynccontextmanager
172- async def lifespan (app : FastAPI ) -> AsyncGenerator [None , None ]:
172+ async def lifespan (app : FastAPI ) -> AsyncGenerator [None , None ]: # noqa: C901
173173 # --- startup ---
174174 model = None
175175 actor_critic = None
@@ -186,14 +186,14 @@ def _real_batch_fn(requests: list[PricingRequest]) -> list[PricingResponse]:
186186 results : list [PricingResponse ] = []
187187 for req in requests :
188188 with torch .no_grad ():
189- x = _build_observation (
190- req .current_prices , obs_dim , device
191- )
189+ x = _build_observation (req .current_prices , obs_dim , device )
192190 z_t , _ = model .rssm .encode_obs (x )
193191 model .reset_state (batch_size = 1 )
194192 h_t = torch .zeros (
195- 1 , model .rssm .d_model ,
196- device = device , dtype = z_t .dtype ,
193+ 1 ,
194+ model .rssm .d_model ,
195+ device = device ,
196+ dtype = z_t .dtype ,
197197 )
198198 state = torch .cat ([h_t , z_t ], dim = - 1 )
199199 actions , _ , _ = actor_critic .act (state , deterministic = True )
@@ -230,16 +230,10 @@ def _real_batch_fn(requests: list[PricingRequest]) -> list[PricingResponse]:
230230 mean_r_mean = total_profit / max (H , 1 )
231231 r_std_rel = mean_r_std / (abs (mean_r_mean ) + 1e-6 )
232232 k = min (0.1 , float (r_std_rel ))
233- uncertainty_bounds = [
234- (p * (1 - k ), p * (1 + k ))
235- for p in rec_prices
236- ]
233+ uncertainty_bounds = [(p * (1 - k ), p * (1 + k )) for p in rec_prices ]
237234 n_skus = len (req .current_prices )
238235 avg_price = sum (req .current_prices ) / max (n_skus , 1 )
239- est_units = (
240- total_profit / (avg_price * 0.2 + 1e-6 )
241- / max (n_skus , 1 )
242- )
236+ est_units = total_profit / (avg_price * 0.2 + 1e-6 ) / max (n_skus , 1 )
243237 expected_units = [est_units ] * n_skus
244238
245239 results .append (
@@ -264,35 +258,30 @@ async def _real_stream_fn(
264258 request : PricingRequest ,
265259 ) -> AsyncGenerator [dict [str , Any ], None ]:
266260 with torch .no_grad ():
267- x = _build_observation (
268- request .current_prices , obs_dim , device
269- )
261+ x = _build_observation (request .current_prices , obs_dim , device )
270262 z_t , _ = model .rssm .encode_obs (x )
271263 model .reset_state (batch_size = 1 )
272264 h_t = torch .zeros (
273- 1 , model .rssm .d_model ,
274- device = device , dtype = z_t .dtype ,
265+ 1 ,
266+ model .rssm .d_model ,
267+ device = device ,
268+ dtype = z_t .dtype ,
275269 )
276270 n = len (request .current_prices )
277271 H = min (request .horizon , 13 )
278272 prices = list (request .current_prices )
279273 for step in range (H ):
280274 state = torch .cat ([h_t , z_t ], dim = - 1 )
281- actions , _ , _ = actor_critic .act (
282- state , deterministic = True
283- )
275+ actions , _ , _ = actor_critic .act (state , deterministic = True )
284276 mult = _discrete_actions_to_multipliers (actions )
285277 step_out = model .imagine_step (z_t , mult )
286278 h_t = step_out ["h" ]
287279 z_t = step_out ["z" ]
288280 rec_prices = [
289- prices [i ] * mult [0 , i ].item ()
290- for i in range (min (n , mult .shape [1 ]))
281+ prices [i ] * mult [0 , i ].item () for i in range (min (n , mult .shape [1 ]))
291282 ]
292283 if len (rec_prices ) < n :
293- rec_prices .extend (
294- [rec_prices [- 1 ]] * (n - len (rec_prices ))
295- )
284+ rec_prices .extend ([rec_prices [- 1 ]] * (n - len (rec_prices )))
296285 prices = rec_prices
297286 yield {
298287 "step" : step ,
0 commit comments