Skip to content

Commit 1189630

Browse files
authored
Merge pull request #20 from pablormier/dev
dev7 version
2 parents 288e7f9 + 4083c15 commit 1189630

10 files changed

Lines changed: 1831 additions & 74 deletions

File tree

.github/workflows/deploy-docs.yml

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,23 @@ jobs:
6969
esac
7070
echo "doc_version=${version}" >> $GITHUB_OUTPUT
7171
72+
- name: Get GitHub Pages URL via GH CLI
73+
id: pages
74+
env:
75+
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
76+
run: |
77+
echo "GH CLI version: $(gh --version)"
78+
echo "Fetching Pages URL for ${{ github.repository }}…"
79+
page_url=$(gh api \
80+
-H "Accept: application/vnd.github.v3+json" \
81+
/repos/${{ github.repository }}/pages \
82+
--jq .html_url \
83+
)
84+
echo "Retrieved page_url: $page_url"
85+
# Export it so later steps see it:
86+
echo "PAGE_URL=${page_url:-}" >> $GITHUB_ENV
87+
echo "Exported PAGE_URL=${page_url:-}"
88+
7289
- name: Generate switcher.json
7390
run: |
7491
set -e

corneto/__init__.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from corneto import _plotting as pl
66
from corneto._constants import *
7-
from corneto._data import Data, Feature, Sample
7+
from corneto._data import Data, Feature, GraphData, Sample
88
from corneto._util import info, suppress_output
99
from corneto.backend import DEFAULT_BACKEND, DEFAULT_SOLVER, available_backends
1010

@@ -17,6 +17,7 @@
1717
#from corneto._graph import Attr, Attributes, EdgeType, Graph
1818
from corneto.graph import Attr, Attributes, EdgeType, Graph
1919
from corneto.io import load_graph_from_sif
20+
2021
# from corneto._core import GReNet as Graph
2122
from corneto.methods import (
2223
create_flow_graph,
@@ -26,8 +27,10 @@
2627
)
2728
from corneto.utils import Attr, Attributes
2829

29-
logger = logging.getLogger(__name__)
30-
logger.addHandler(logging.NullHandler())
30+
from corneto._logging import disable_logging, enable_logging, set_verbosity
31+
32+
#logger = logging.getLogger(__name__)
33+
#logger.addHandler(logging.NullHandler())
3134

3235

3336
def get_version():
@@ -73,7 +76,8 @@ def __getattr__(self, attr):
7376
"EdgeType",
7477
"Feature",
7578
"Graph",
76-
"K",
79+
"GraphData",
80+
"K", # deprecated
7781
"Sample",
7882
"available_backends",
7983
"info",

corneto/_logging.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import logging
2+
3+
4+
def enable_logging(level="info", stream=None):
5+
"""Enable logging output for the mypackage package.
6+
7+
Args:
8+
level (str or int, optional): Logging level (e.g., "info", "debug", "warning", logging.INFO, etc.).
9+
String (case-insensitive) or int accepted. Defaults to "info".
10+
stream (file-like, optional): Stream for logging output. Defaults to sys.stderr.
11+
12+
Example:
13+
import corneto as cn
14+
cn.enable_logging("debug")
15+
"""
16+
# Convert string level to logging constant if needed
17+
if isinstance(level, str):
18+
level_name = level.strip().upper()
19+
level = getattr(logging, level_name, logging.INFO)
20+
21+
logger = logging.getLogger(__package__ or "corneto") # fallback for script use
22+
23+
# Avoid adding multiple handlers if already present
24+
if not any(isinstance(h, logging.StreamHandler) for h in logger.handlers):
25+
handler = logging.StreamHandler(stream)
26+
formatter = logging.Formatter("[%(levelname)s] %(name)s: %(message)s")
27+
handler.setFormatter(formatter)
28+
logger.addHandler(handler)
29+
logger.setLevel(level)
30+
logger.propagate = False
31+
32+
33+
def disable_logging():
34+
"""Disable all logging output from the mypackage package.
35+
Removes all handlers and sets level to WARNING.
36+
"""
37+
logger = logging.getLogger(__package__ or "mypackage")
38+
logger.handlers.clear()
39+
logger.setLevel(logging.WARNING)
40+
41+
42+
# Alias for discoverability
43+
set_verbosity = enable_logging

corneto/methods/sampler.py

Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
import logging
2+
import numpy as np
3+
from collections.abc import Sequence
4+
from typing import Dict, List, Union
5+
6+
# Configure module-level logger; users can configure handlers in their application
7+
logger = logging.getLogger(__name__)
8+
9+
10+
def sample_alternative_solutions(
11+
problem,
12+
variable_name: str,
13+
*,
14+
percentage: float = 0.10,
15+
scale: float = 0.03,
16+
rel_opt_tol: float = 0.05,
17+
max_samples: int = 30,
18+
perturbation_name: str = "perturbation",
19+
solver_kwargs: dict | None = None,
20+
rng: np.random.Generator | int | None = None,
21+
collect_vars: Sequence[str] | None = None,
22+
verbose: int = 1, # 0 = silent, 1 = summary, 2 = full detail
23+
) -> Dict[str, np.ndarray]:
24+
"""Sample alternative solutions by perturbing a chosen decision variable.
25+
26+
This routine takes an optimization problem (with attributes .expr,
27+
.solve(), and .objectives), identifies a target variable within it,
28+
and generates up to max_samples new feasible solutions by randomly
29+
perturbing a fraction of that variable's entries. Only those perturbations
30+
that keep all original objectives within a relative tolerance of the
31+
baseline are accepted.
32+
33+
Args:
34+
problem: An optimization problem instance exposing
35+
- expr: a mapping of variable names to variable objects,
36+
- solve(...): method to solve the problem,
37+
- objectives: list of objective objects (each with .name and .value).
38+
variable_name (str): Name of the variable in problem.expr to perturb.
39+
percentage (float, optional): Fraction of the variable's entries to
40+
perturb in each trial (default 0.10).
41+
scale (float, optional): Standard deviation of the normal random noise
42+
(default 0.03).
43+
rel_opt_tol (float, optional): Maximum allowed relative deviation of
44+
any original objective from its baseline value (default 0.05).
45+
max_samples (int, optional): Maximum number of perturbation trials
46+
to attempt (default 30).
47+
perturbation_name (str, optional): Name to assign to the added
48+
perturbation objective (default `"perturbation").
49+
solver_kwargs (dict or None, optional): Extra keyword arguments passed
50+
to problem.solve() (default None).
51+
rng (np.random.Generator or int or None, optional): Random number
52+
generator or seed for reproducibility (default None).
53+
collect_vars (Sequence[str] or None, optional):
54+
Names of the variables whose values you want back.
55+
56+
- `None (default) – collect **every** variable in problem.expr
57+
- `[] – collect **none** (method solves but returns an empty dict)
58+
- `["x", "y"] – collect only those named variables
59+
verbose (int, optional): Verbosity level:
60+
0 = silent, 1 = summary, 2 = full detail (default 1).
61+
62+
Returns:
63+
dict:
64+
A dictionary that maps each collected variable name to a NumPy
65+
array with shape `(n_samples, *variable.shape) where
66+
67+
* `n_samples ≥ 1 – it counts the incumbent plus every accepted
68+
perturbation;
69+
* the remaining dimensions match the variable’s own shape.
70+
71+
Example::
72+
73+
out = sample_alternative_solutions(problem, "x", collect_vars=["x", "y"])
74+
x_stack = out["x"] # shape (n_samples, *x.shape)
75+
incumbent_x = x_stack[0] # first slice is always the baseline
76+
77+
Raises:
78+
KeyError:
79+
If `variable_name is not in problem.expr **or** if any name
80+
inside `collect_vars is missing from problem.expr.
81+
"""
82+
# Map verbosity to logging levels
83+
if verbose >= 2:
84+
log_level = logging.DEBUG
85+
elif verbose == 1:
86+
log_level = logging.INFO
87+
else:
88+
log_level = logging.WARNING
89+
logger.setLevel(log_level)
90+
91+
if solver_kwargs is None:
92+
solver_kwargs = {}
93+
rng = rng if isinstance(rng, np.random.Generator) else np.random.default_rng(rng)
94+
95+
# ------------------ sanity checks ------------------
96+
if variable_name not in problem.expr:
97+
raise KeyError(f"Variable '{variable_name}' not found in problem.expr")
98+
99+
if collect_vars is None:
100+
collect_vars = list(problem.expr.keys())
101+
else:
102+
missing = [v for v in collect_vars if v not in problem.expr]
103+
if missing:
104+
raise KeyError(f"Variables not found in problem.expr: {missing}")
105+
106+
collected: Dict[str, List[np.ndarray]] = {v: [] for v in collect_vars}
107+
target_var = problem.expr[variable_name]
108+
109+
# 1) original solve ---------------------------------
110+
logger.debug("Solving original model …")
111+
problem.solve(**solver_kwargs, verbosity=0)
112+
baseline_obj = {o.name: float(o.value) for o in problem.objectives}
113+
logger.debug(
114+
"Baseline objectives: "
115+
+ ", ".join(f"{k}={v:.6g}" for k, v in baseline_obj.items())
116+
)
117+
118+
for v in collect_vars:
119+
collected[v].append(np.asarray(problem.expr[v].value).copy())
120+
121+
# 2) build perturbation parameter -------------------
122+
var_shape = tuple(int(s) for s in target_var.shape)
123+
total_elems = int(np.prod(var_shape))
124+
n_perturb = max(1, int(total_elems * percentage))
125+
126+
noise_buf = np.zeros(var_shape, dtype=float)
127+
pert = problem.backend.Parameter(
128+
name=f"{perturbation_name}_param", shape=var_shape, value=noise_buf
129+
)
130+
problem.add_objective(
131+
(target_var.multiply(pert))
132+
.sum()
133+
.reshape(
134+
1,
135+
),
136+
name=perturbation_name,
137+
)
138+
139+
flat_buf = noise_buf.reshape(-1)
140+
n_accept = n_reject = 0
141+
142+
# 3) sampling loop
143+
for trial in range(1, max_samples + 1):
144+
# 3a) new perturbation
145+
flat_buf.fill(0.0)
146+
idx = rng.choice(total_elems, n_perturb, replace=False)
147+
flat_buf[idx] = rng.normal(0.0, scale, n_perturb)
148+
pert.value = noise_buf
149+
150+
# 3b) solve
151+
problem.solve(warm_start=True, **solver_kwargs, verbosity=0)
152+
153+
# 3c) compute relative errors for each objective
154+
relerrs = {}
155+
current_vals = {}
156+
for o in problem.objectives:
157+
if o.name == perturbation_name:
158+
continue
159+
val = float(o.value)
160+
current_vals[o.name] = val
161+
denom = max(abs(baseline_obj[o.name]), 1e-9)
162+
relerrs[o.name] = abs(val - baseline_obj[o.name]) / denom
163+
164+
# check tolerance
165+
violated = next(
166+
((name, err) for name, err in relerrs.items() if err > rel_opt_tol), None
167+
)
168+
169+
# log objective values and errors
170+
detail_msg = ", ".join(
171+
f"{name}: val={current_vals[name]:.6g}, rel.err={relerrs[name]:.4f}"
172+
for name in current_vals
173+
)
174+
175+
if violated is None:
176+
for v in collect_vars:
177+
collected[v].append(np.asarray(problem.expr[v].value).copy())
178+
n_accept += 1
179+
logger.info(
180+
f"[{trial}/{max_samples}] accepted (total accepted={n_accept}) -> {detail_msg}"
181+
)
182+
else:
183+
n_reject += 1
184+
logger.info(
185+
f"[{trial}/{max_samples}] rejected (tol={rel_opt_tol}) -> {detail_msg}"
186+
)
187+
188+
# 4) stack lists into arrays ------------------------
189+
out: Dict[str, np.ndarray] = {
190+
v: np.stack(values, axis=0) for v, values in collected.items()
191+
}
192+
193+
logger.info(
194+
f"Done. accepted={n_accept}, rejected={n_reject}, solutions returned="
195+
f"{out[next(iter(out))].shape[0]}"
196+
)
197+
return out

docs/conf.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,8 @@
164164
"navbar_align": "left",
165165
"switcher": {
166166
# The switcher.json file is now available at the root.
167-
"json_url": f"{html_baseurl}/switcher.json",
167+
#"json_url": f"{html_baseurl}/switcher.json",
168+
"json_url": "switcher.json",
168169
"version_match": corneto.__version__,
169170
},
170171
"navbar_start": ["navbar-logo", "version-switcher"],

docs/generate_switcher.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,26 @@ def get_version_from_branch(branch):
5959
return branch
6060

6161

62+
def get_base_url():
63+
# 1) try the real published URL
64+
url = os.environ.get("PAGE_URL")
65+
if url:
66+
print(f"Using PAGE_URL: {url}")
67+
return url.rstrip("/")
68+
# 2) fallback to github.io
69+
repo = os.environ.get("GITHUB_REPOSITORY", "username/corneto")
70+
user, project = repo.split("/", 1)
71+
print(f"Using {user}.github.io/{project} as fallback for the switcher URL.")
72+
return f"https://{user}.github.io/{project}"
73+
74+
6275
def main():
6376
# Derive GitHub username from the environment variable.
64-
repo = os.environ.get("GITHUB_REPOSITORY", "username/corneto")
65-
username = repo.split("/")[0]
66-
base_url = f"https://{username}.github.io/corneto"
77+
#repo = os.environ.get("GITHUB_REPOSITORY", "username/corneto")
78+
#username = repo.split("/")[0]
79+
#base_url = f"https://{username}.github.io/corneto"
80+
base_url = get_base_url()
81+
root_path = os.environ.get("SITE_ROOT", "").rstrip("/")
6782

6883
switcher = []
6984

docs/tutorials/index.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,6 @@
1212
context-specific-metabolic-omics.ipynb
1313
single-sample-carnival-transcriptomics.ipynb
1414
network-sampler.ipynb
15+
network-sampler-example.ipynb
1516
kpnn-with-sc.ipynb
1617
```

docs/tutorials/network-sampler-example.ipynb

Lines changed: 1479 additions & 0 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)