Skip to content

Commit ddb37e9

Browse files
committed
clean up, WIP
1 parent a89329e commit ddb37e9

File tree

2 files changed

+3
-15
lines changed

2 files changed

+3
-15
lines changed

methods/catalog/rbr/library/rbr_loss.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# methods/catalog/rbr/library.py
22
import math
3-
from typing import Callable, Dict, Optional, Sequence, Tuple, Any
3+
from typing import Optional, Sequence, Any
44

55
import numpy as np
66
import torch
@@ -238,15 +238,7 @@ def robust_bayesian_recourse(
238238
random_state: Optional[int] = None,
239239
verbose: bool = False,
240240
) -> np.ndarray:
241-
"""
242-
High-level function that matches the CARLA library-call pattern.
243-
Parameters largely mirror the original code you provided.
244-
- raw_model: object with .predict(np.ndarray) -> labels/probs
245-
- x0: 1D numpy array (a single factual)
246-
- cat_features_indices: indices of encoded categorical features to clamp/round
247-
- train_data: numpy array (N, d) required (used to find boundary and feasible set)
248-
Returns counterfactual as numpy array same shape as x0.
249-
"""
241+
250242
# helper to call raw_model.predict consistently
251243
def predict_fn_np(arr: np.ndarray) -> np.ndarray:
252244
# raw_model might accept (n,d) and return probs or labels
@@ -425,6 +417,4 @@ def projection(x, delta):
425417

426418
# ----------------------------- end of optimize() -----------------------
427419

428-
# final clamping for feature valid ranges [0,1] if raw_model expects that (user may want different behaviour)
429-
# NOTE: the CARLA wrapper can do final "check_counterfactuals" conversions; here we return raw vector
430420
return cf

methods/catalog/rbr/model.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
1-
from typing import Any, Callable, Dict, Optional, Tuple
1+
from typing import Dict, Optional
22

3-
import numpy as np
43
import pandas as pd
5-
import torch
64

75
from methods.catalog.rbr.library.rbr_loss import robust_bayesian_recourse
86
from methods.processing.counterfactuals import merge_default_parameters

0 commit comments

Comments
 (0)