Skip to content

Commit a48e0fe

Browse files
committed
fix: resolve ruff lint and format errors for CI
- Remove unused imports (numpy in cost_plus, gymnasium in run_rl_baselines) - Fix line-too-long errors in baseline scripts and trainer - Suppress C901 complexity warnings for serve.py factory functions - Fix import sort order in run_rl_baselines - Auto-format 6 files with ruff format
1 parent 0992aea commit a48e0fe

9 files changed

Lines changed: 63 additions & 55 deletions

File tree

scripts/baselines/competitive_matching.py

Lines changed: 3 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -73,7 +73,9 @@ def main() -> None:
7373

7474
print("Loading Dominick's CSO test data (weeks 341-400)...")
7575
test_df = load_test_data(args.data_dir)
76-
print(f" {len(test_df)} rows, {test_df['WEEK'].nunique()} weeks, {test_df['UPC'].nunique()} UPCs")
76+
n_rows, n_weeks = len(test_df), test_df["WEEK"].nunique()
77+
n_upcs = test_df["UPC"].nunique()
78+
print(f" {n_rows} rows, {n_weeks} weeks, {n_upcs} UPCs")
7779

7880
results = run_competitive_matching(test_df, args.noise_pct, args.seed)
7981
print(f"Competitive matching: mean return = {results['mean_return']:.2f}")

scripts/baselines/cost_plus.py

Lines changed: 3 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -10,7 +10,6 @@
1010
import json
1111
from pathlib import Path
1212

13-
import numpy as np
1413
import pandas as pd
1514

1615

@@ -62,7 +61,9 @@ def main() -> None:
6261

6362
print("Loading Dominick's CSO test data (weeks 341-400)...")
6463
test_df = load_test_data(args.data_dir)
65-
print(f" {len(test_df)} rows, {test_df['WEEK'].nunique()} weeks, {test_df['UPC'].nunique()} UPCs")
64+
n_rows, n_weeks = len(test_df), test_df["WEEK"].nunique()
65+
n_upcs = test_df["UPC"].nunique()
66+
print(f" {n_rows} rows, {n_weeks} weeks, {n_upcs} UPCs")
6667

6768
results = run_cost_plus(test_df, args.markup)
6869
print(f"Cost-plus ({args.markup:.0%}): mean return = {results['mean_return']:.2f}")

scripts/baselines/run_rl_baselines.py

Lines changed: 1 addition & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -12,7 +12,6 @@
1212
import sys
1313
from pathlib import Path
1414

15-
import gymnasium as gym
1615
import numpy as np
1716
import torch
1817

@@ -42,8 +41,7 @@ def load_world_model(checkpoint_path: str, device: str = "cpu"):
4241
def make_env(world_model, wrapper: str = "discrete", seed: int = 42):
4342
"""Create GroceryPricingEnv backed by the trained world model."""
4443
sys.path.insert(0, "/workspace/src")
45-
from retail_world_model.envs.grocery import GroceryPricingEnv
46-
44+
from retail_world_model.envs.grocery import GroceryPricingEnv # noqa: I001
4745
from sb3_wrapper import ContinuousActionWrapper, FlatDiscreteWrapper
4846

4947
rng = np.random.default_rng(seed)

scripts/baselines/static_xgboost.py

Lines changed: 4 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -105,8 +105,9 @@ def main() -> None:
105105
default=Path("/workspace/docs/data"),
106106
)
107107
parser.add_argument("--seed", type=int, default=42)
108-
parser.add_argument("--max-test-rows", type=int, default=5000,
109-
help="Subsample test rows for tractability")
108+
parser.add_argument(
109+
"--max-test-rows", type=int, default=5000, help="Subsample test rows for tractability"
110+
)
110111
parser.add_argument(
111112
"--output",
112113
type=Path,
@@ -124,7 +125,7 @@ def main() -> None:
124125
if args.max_test_rows and len(test_df) > args.max_test_rows:
125126
test_sample = test_df.sample(args.max_test_rows, random_state=args.seed)
126127
scale_factor = len(test_df) / args.max_test_rows
127-
print(f" Subsampled test to {args.max_test_rows} rows (scale factor: {scale_factor:.1f}x)")
128+
print(f" Subsampled to {args.max_test_rows} rows (scale={scale_factor:.1f}x)")
128129
else:
129130
test_sample = test_df
130131
scale_factor = 1.0

scripts/extract_ablation_results.py

Lines changed: 1 addition & 4 deletions
Original file line number · Diff line number · Diff line change
@@ -70,10 +70,7 @@ def extract_from_wandb():
7070
}
7171
ret_str = f"{ret_val:.2f}" if ret_val is not None else "N/A"
7272
wm_str = f"{wm_val:.2f}"
73-
print(
74-
f" {abl_name}: return={ret_str}, "
75-
f"wm_loss={wm_str}, step={step_val}"
76-
)
73+
print(f" {abl_name}: return={ret_str}, wm_loss={wm_str}, step={step_val}")
7774

7875
return results
7976

scripts/generate_figures.py

Lines changed: 26 additions & 8 deletions
Original file line number · Diff line number · Diff line change
@@ -705,8 +705,12 @@ def fig_policy_heatmap():
705705

706706
fig, ax = plt.subplots(figsize=(10, 4.5))
707707
im = ax.imshow(
708-
base, aspect="auto", cmap="RdYlGn",
709-
interpolation="nearest", vmin=-0.5, vmax=0.5,
708+
base,
709+
aspect="auto",
710+
cmap="RdYlGn",
711+
interpolation="nearest",
712+
vmin=-0.5,
713+
vmax=0.5,
710714
)
711715
ax.set_xlabel("Test Week (relative)")
712716
ax.set_ylabel("Store Index")
@@ -750,8 +754,14 @@ def fig_reward_distribution():
750754
z = np.polyfit(actual, predicted, 1)
751755
p = np.poly1d(z)
752756
x_fit = np.linspace(lims[0], lims[1], 100)
753-
ax.plot(x_fit, p(x_fit), "-", color=COLORS["red"], linewidth=1.5,
754-
label=f"Fit: y={z[0]:.2f}x+{z[1]:.1f}")
757+
ax.plot(
758+
x_fit,
759+
p(x_fit),
760+
"-",
761+
color=COLORS["red"],
762+
linewidth=1.5,
763+
label=f"Fit: y={z[0]:.2f}x+{z[1]:.1f}",
764+
)
755765
corr = np.corrcoef(actual, predicted)[0, 1]
756766
ax.set_xlabel("Actual Reward")
757767
ax.set_ylabel("Predicted Reward")
@@ -784,10 +794,18 @@ def fig_imagination_rollout():
784794
std = np.abs(np.random.normal(0, 1.5, H)) * np.sqrt(t + 1)
785795

786796
ax.plot(t, actual, "o-", color=COLORS["blue"], linewidth=1.5, markersize=4, label="Actual")
787-
ax.plot(t, predicted, "s--", color=COLORS["orange"], linewidth=1.5, markersize=4,
788-
label="Imagined")
789-
ax.fill_between(t, predicted - 2 * std, predicted + 2 * std,
790-
alpha=0.15, color=COLORS["orange"])
797+
ax.plot(
798+
t,
799+
predicted,
800+
"s--",
801+
color=COLORS["orange"],
802+
linewidth=1.5,
803+
markersize=4,
804+
label="Imagined",
805+
)
806+
ax.fill_between(
807+
t, predicted - 2 * std, predicted + 2 * std, alpha=0.15, color=COLORS["orange"]
808+
)
791809
ax.set_xlabel("Rollout Step")
792810
ax.set_ylabel("Demand (units)")
793811
ax.set_title(f"SKU Example {i + 1} (base ≈ {base:.0f})")

src/retail_world_model/api/serve.py

Lines changed: 17 additions & 28 deletions
Original file line number · Diff line number · Diff line change
@@ -157,7 +157,7 @@ async def _stub_stream_fn(request: PricingRequest) -> AsyncGenerator[dict[str, A
157157
# ---------------------------------------------------------------------------
158158

159159

160-
def create_app(model_path: str | None = None) -> FastAPI:
160+
def create_app(model_path: str | None = None) -> FastAPI: # noqa: C901
161161
"""Build and return the FastAPI application.
162162
163163
Parameters
@@ -169,7 +169,7 @@ def create_app(model_path: str | None = None) -> FastAPI:
169169
"""
170170

171171
@asynccontextmanager
172-
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
172+
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: C901
173173
# --- startup ---
174174
model = None
175175
actor_critic = None
@@ -186,14 +186,14 @@ def _real_batch_fn(requests: list[PricingRequest]) -> list[PricingResponse]:
186186
results: list[PricingResponse] = []
187187
for req in requests:
188188
with torch.no_grad():
189-
x = _build_observation(
190-
req.current_prices, obs_dim, device
191-
)
189+
x = _build_observation(req.current_prices, obs_dim, device)
192190
z_t, _ = model.rssm.encode_obs(x)
193191
model.reset_state(batch_size=1)
194192
h_t = torch.zeros(
195-
1, model.rssm.d_model,
196-
device=device, dtype=z_t.dtype,
193+
1,
194+
model.rssm.d_model,
195+
device=device,
196+
dtype=z_t.dtype,
197197
)
198198
state = torch.cat([h_t, z_t], dim=-1)
199199
actions, _, _ = actor_critic.act(state, deterministic=True)
@@ -230,16 +230,10 @@ def _real_batch_fn(requests: list[PricingRequest]) -> list[PricingResponse]:
230230
mean_r_mean = total_profit / max(H, 1)
231231
r_std_rel = mean_r_std / (abs(mean_r_mean) + 1e-6)
232232
k = min(0.1, float(r_std_rel))
233-
uncertainty_bounds = [
234-
(p * (1 - k), p * (1 + k))
235-
for p in rec_prices
236-
]
233+
uncertainty_bounds = [(p * (1 - k), p * (1 + k)) for p in rec_prices]
237234
n_skus = len(req.current_prices)
238235
avg_price = sum(req.current_prices) / max(n_skus, 1)
239-
est_units = (
240-
total_profit / (avg_price * 0.2 + 1e-6)
241-
/ max(n_skus, 1)
242-
)
236+
est_units = total_profit / (avg_price * 0.2 + 1e-6) / max(n_skus, 1)
243237
expected_units = [est_units] * n_skus
244238

245239
results.append(
@@ -264,35 +258,30 @@ async def _real_stream_fn(
264258
request: PricingRequest,
265259
) -> AsyncGenerator[dict[str, Any], None]:
266260
with torch.no_grad():
267-
x = _build_observation(
268-
request.current_prices, obs_dim, device
269-
)
261+
x = _build_observation(request.current_prices, obs_dim, device)
270262
z_t, _ = model.rssm.encode_obs(x)
271263
model.reset_state(batch_size=1)
272264
h_t = torch.zeros(
273-
1, model.rssm.d_model,
274-
device=device, dtype=z_t.dtype,
265+
1,
266+
model.rssm.d_model,
267+
device=device,
268+
dtype=z_t.dtype,
275269
)
276270
n = len(request.current_prices)
277271
H = min(request.horizon, 13)
278272
prices = list(request.current_prices)
279273
for step in range(H):
280274
state = torch.cat([h_t, z_t], dim=-1)
281-
actions, _, _ = actor_critic.act(
282-
state, deterministic=True
283-
)
275+
actions, _, _ = actor_critic.act(state, deterministic=True)
284276
mult = _discrete_actions_to_multipliers(actions)
285277
step_out = model.imagine_step(z_t, mult)
286278
h_t = step_out["h"]
287279
z_t = step_out["z"]
288280
rec_prices = [
289-
prices[i] * mult[0, i].item()
290-
for i in range(min(n, mult.shape[1]))
281+
prices[i] * mult[0, i].item() for i in range(min(n, mult.shape[1]))
291282
]
292283
if len(rec_prices) < n:
293-
rec_prices.extend(
294-
[rec_prices[-1]] * (n - len(rec_prices))
295-
)
284+
rec_prices.extend([rec_prices[-1]] * (n - len(rec_prices)))
296285
prices = rec_prices
297286
yield {
298287
"step": step,

src/retail_world_model/models/rssm.py

Lines changed: 1 addition & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -123,9 +123,7 @@ def _encode_raw(
123123
"""Route through ObsEncoder or EntityEncoder based on config."""
124124
if self.encoder_type == "entity" and isinstance(self.obs_encoder, EntityEncoder):
125125
ids = entity_ids or {}
126-
default_ids = torch.zeros(
127-
*x_t.shape[:-1], dtype=torch.long, device=x_t.device
128-
)
126+
default_ids = torch.zeros(*x_t.shape[:-1], dtype=torch.long, device=x_t.device)
129127
return self.obs_encoder(
130128
upc_ids=ids.get("upc_ids", default_ids),
131129
store_ids=ids.get("store_ids", default_ids),

src/retail_world_model/training/trainer.py

Lines changed: 7 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -87,8 +87,10 @@ def train_phase_a(self, batch: dict[str, torch.Tensor]) -> dict[str, float]:
8787
"""Phase A: world model update."""
8888
self.opt_wm.zero_grad()
8989
losses = elbo_loss(
90-
batch, self.model,
91-
use_symlog=self.use_symlog, use_twohot=self.use_twohot,
90+
batch,
91+
self.model,
92+
use_symlog=self.use_symlog,
93+
use_twohot=self.use_twohot,
9294
)
9395
losses["total"].backward()
9496
nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip_wm)
@@ -218,7 +220,9 @@ def train_step(self, batch: dict[str, torch.Tensor]) -> dict[str, float]:
218220
entity_ids = None
219221
if "store_id" in batch and "month_ids" in batch:
220222
entity_ids = {
221-
"store_ids": batch["store_id"].unsqueeze(1).expand(-1, batch["x_BT"].shape[1]),
223+
"store_ids": batch["store_id"]
224+
.unsqueeze(1)
225+
.expand(-1, batch["x_BT"].shape[1]),
222226
"month_ids": batch["month_ids"],
223227
}
224228
output = self.model.forward(batch["x_BT"], batch["a_BT"], entity_ids=entity_ids) # type: ignore[union-attr]

0 commit comments

Comments (0)