Skip to content

Commit 3ba5773

Browse files
authored
Merge pull request #3 from settylab/dev
v0.6.2
2 parents 2ca0e3d + 596e57e commit 3ba5773

41 files changed

Lines changed: 5844 additions & 202 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.github/workflows/tests.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@ name: Tests
22

33
on:
44
push:
5-
branches: [ main ]
5+
branches: [ main, dev ]
66
pull_request:
7-
branches: [ main ]
7+
branches: [ main, dev ]
88

99
jobs:
1010
test:

CHANGELOG.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,13 @@
22

33
All notable changes to this project will be documented in this file.
44

5+
## [0.6.2]
6+
7+
- fix differential expression analysis using `groups`
8+
- increase testing coverage
9+
- thread and GPU-usage control in CLI
10+
- fix `volcano_de` plot when the layer is `None`
11+
512
## [0.6.1]
613

714
- table output for CLI

docs/source/cli.rst

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,14 @@ Boolean Flags
245245
--store-additional-stats # Store extra statistics
246246
--overwrite # Overwrite without warning
247247
248+
Compute Options
249+
^^^^^^^^^^^^^^^
250+
251+
.. code-block:: text
252+
253+
--use-gpu # Use GPU acceleration (requires CUDA-enabled JAX)
254+
--threads N # Number of threads for JAX/NumPy/Dask (default: all cores)
255+
248256
Advanced Options
249257
^^^^^^^^^^^^^^^^
250258

@@ -325,6 +333,14 @@ Boolean Flags
325333
--store-landmarks # Store landmarks for reuse
326334
--overwrite # Overwrite without warning
327335
336+
Compute Options
337+
^^^^^^^^^^^^^^^
338+
339+
.. code-block:: text
340+
341+
--use-gpu # Use GPU acceleration (requires CUDA-enabled JAX)
342+
--threads N # Number of threads for JAX/NumPy/Dask (default: all cores)
343+
328344
Example: Complete Analysis
329345
^^^^^^^^^^^^^^^^^^^^^^^^^^
330346

examples/01_getting_started.ipynb

Lines changed: 72 additions & 4 deletions
Large diffs are not rendered by default.

kompot/anndata/differential_expression.py

Lines changed: 106 additions & 111 deletions
Large diffs are not rendered by default.

kompot/cli/compute_config.py

Lines changed: 251 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,251 @@
1+
"""
2+
Compute configuration for JAX, NumPy, and Dask.
3+
4+
This module handles GPU/CPU configuration and thread limiting for computational backends.
5+
6+
IMPORTANT NOTES:
7+
1. NumPy thread limits: Set early in main() via environment variables BEFORE NumPy import.
8+
The _configure_thread_limits() function here is called later but only affects subsequently
9+
loaded modules (like Dask), not NumPy which is already initialized.
10+
11+
2. JAX configuration: Must be called AFTER mellon import, as mellon configures JAX on import.
12+
The _configure_jax() function can override mellon's settings.
13+
14+
3. Dask configuration: Can be set at any time via dask.config.
15+
"""
16+
17+
import os
18+
import logging
19+
20+
logger = logging.getLogger("kompot.cli")
21+
22+
23+
def configure_compute(use_gpu: bool = False, n_threads: int = None):
24+
"""
25+
Configure computational backends (JAX, NumPy, Dask) for thread control and GPU usage.
26+
27+
This function must be called AFTER importing mellon, as mellon configures JAX
28+
to use CPU on import. This function can override that configuration.
29+
30+
Parameters
31+
----------
32+
use_gpu : bool, default=False
33+
If True, configure JAX to use GPU. If False, force CPU usage.
34+
n_threads : int, optional
35+
Number of threads to use. If specified, limits threads for:
36+
- JAX (XLA)
37+
- NumPy (OpenBLAS/MKL)
38+
- Dask
39+
40+
Notes
41+
-----
42+
Thread limiting affects:
43+
- JAX: Set via XLA_FLAGS environment variable
44+
- NumPy: Set via OMP_NUM_THREADS, MKL_NUM_THREADS, OPENBLAS_NUM_THREADS
45+
- Dask: Set via dask.config
46+
47+
Examples
48+
--------
49+
>>> # CPU-only with 4 threads
50+
>>> configure_compute(use_gpu=False, n_threads=4)
51+
52+
>>> # GPU with thread limiting
53+
>>> configure_compute(use_gpu=True, n_threads=8)
54+
"""
55+
logger.info("=" * 60)
56+
logger.info("Configuring computational backends")
57+
logger.info("=" * 60)
58+
59+
# Configure thread limits BEFORE JAX initialization
60+
if n_threads is not None:
61+
logger.info(f"Setting thread limit: {n_threads} threads")
62+
_configure_thread_limits(n_threads)
63+
else:
64+
logger.info("No thread limit specified (using system defaults)")
65+
66+
# Configure JAX (must be done AFTER mellon import)
67+
_configure_jax(use_gpu, n_threads)
68+
69+
# Configure Dask if available
70+
try:
71+
_configure_dask(n_threads)
72+
except ImportError:
73+
logger.debug("Dask not available, skipping dask configuration")
74+
75+
logger.info("=" * 60)
76+
77+
78+
def _configure_thread_limits(n_threads: int):
79+
"""
80+
Set environment variables to limit threads for NumPy and related libraries.
81+
82+
Parameters
83+
----------
84+
n_threads : int
85+
Number of threads to use
86+
"""
87+
n_threads_str = str(n_threads)
88+
89+
# OpenMP (used by NumPy, SciPy, etc.)
90+
os.environ['OMP_NUM_THREADS'] = n_threads_str
91+
logger.debug(f" Set OMP_NUM_THREADS={n_threads_str}")
92+
93+
# Intel MKL (if NumPy is built with MKL)
94+
os.environ['MKL_NUM_THREADS'] = n_threads_str
95+
logger.debug(f" Set MKL_NUM_THREADS={n_threads_str}")
96+
97+
# OpenBLAS (if NumPy is built with OpenBLAS)
98+
os.environ['OPENBLAS_NUM_THREADS'] = n_threads_str
99+
logger.debug(f" Set OPENBLAS_NUM_THREADS={n_threads_str}")
100+
101+
# BLAS (general)
102+
os.environ['BLAS_NUM_THREADS'] = n_threads_str
103+
logger.debug(f" Set BLAS_NUM_THREADS={n_threads_str}")
104+
105+
logger.info(f" NumPy/BLAS thread limit: {n_threads} threads")
106+
107+
108+
def _configure_jax(use_gpu: bool, n_threads: int = None):
109+
"""
110+
Configure JAX for GPU/CPU usage and thread limiting.
111+
112+
Must be called AFTER mellon import, as mellon sets JAX to CPU mode on import.
113+
114+
Parameters
115+
----------
116+
use_gpu : bool
117+
Whether to use GPU
118+
n_threads : int, optional
119+
Number of threads for CPU execution
120+
"""
121+
import jax
122+
123+
if use_gpu:
124+
# Check if GPU is available
125+
try:
126+
devices = jax.devices('gpu')
127+
if len(devices) > 0:
128+
logger.info(f" JAX: GPU mode enabled")
129+
logger.info(f" Available GPU devices: {len(devices)}")
130+
for i, device in enumerate(devices):
131+
logger.info(f" Device {i}: {device}")
132+
133+
# Set default device to GPU
134+
# Note: mellon may have set it to CPU, we override here
135+
jax.config.update('jax_platform_name', 'gpu')
136+
else:
137+
logger.warning(" JAX: GPU requested but no GPU devices found, falling back to CPU")
138+
jax.config.update('jax_platform_name', 'cpu')
139+
use_gpu = False
140+
except RuntimeError as e:
141+
logger.warning(f" JAX: GPU not available ({e}), using CPU")
142+
jax.config.update('jax_platform_name', 'cpu')
143+
use_gpu = False
144+
else:
145+
logger.info(" JAX: CPU mode (GPU disabled)")
146+
jax.config.update('jax_platform_name', 'cpu')
147+
148+
# Configure thread limits for JAX/XLA
149+
if not use_gpu and n_threads is not None:
150+
# Set intra-op parallelism for CPU
151+
xla_flags = os.environ.get('XLA_FLAGS', '')
152+
153+
# Add thread limit to XLA_FLAGS
154+
thread_flag = f'--xla_cpu_multi_thread_eigen=true intra_op_parallelism_threads={n_threads}'
155+
156+
if 'intra_op_parallelism_threads' not in xla_flags:
157+
if xla_flags:
158+
xla_flags = f'{xla_flags} {thread_flag}'
159+
else:
160+
xla_flags = thread_flag
161+
162+
os.environ['XLA_FLAGS'] = xla_flags
163+
logger.info(f" JAX/XLA thread limit: {n_threads} threads")
164+
logger.debug(f" XLA_FLAGS={xla_flags}")
165+
else:
166+
logger.debug(" XLA thread limit already configured")
167+
168+
169+
def _configure_dask(n_threads: int = None):
170+
"""
171+
Configure Dask thread limits.
172+
173+
Parameters
174+
----------
175+
n_threads : int, optional
176+
Number of threads for Dask
177+
"""
178+
try:
179+
import dask
180+
import dask.config
181+
182+
if n_threads is not None:
183+
# Configure Dask to use specified number of threads
184+
dask.config.set(scheduler='threads', num_workers=n_threads)
185+
logger.info(f" Dask: thread limit set to {n_threads} threads")
186+
logger.debug(f" Dask scheduler: threads, num_workers={n_threads}")
187+
else:
188+
logger.debug(" Dask: using default configuration")
189+
190+
except ImportError:
191+
# Dask not installed, skip
192+
pass
193+
194+
195+
def get_device_info():
196+
"""
197+
Get information about available compute devices.
198+
199+
Returns
200+
-------
201+
dict
202+
Dictionary with device information including:
203+
- gpu_available: bool
204+
- gpu_devices: list of device descriptions
205+
- cpu_count: int (logical cores)
206+
- jax_platform: str (current JAX platform)
207+
"""
208+
info = {
209+
'gpu_available': False,
210+
'gpu_devices': [],
211+
'cpu_count': os.cpu_count(),
212+
'jax_platform': None
213+
}
214+
215+
try:
216+
import jax
217+
218+
# Check current JAX platform
219+
try:
220+
current_backend = jax.devices()[0].platform
221+
info['jax_platform'] = current_backend
222+
except Exception:
223+
info['jax_platform'] = 'unknown'
224+
225+
# Check for GPU devices
226+
try:
227+
gpu_devices = jax.devices('gpu')
228+
if len(gpu_devices) > 0:
229+
info['gpu_available'] = True
230+
info['gpu_devices'] = [str(d) for d in gpu_devices]
231+
except RuntimeError:
232+
pass
233+
234+
except ImportError:
235+
pass
236+
237+
return info
238+
239+
240+
def log_compute_environment():
241+
"""Log information about the current compute environment."""
242+
info = get_device_info()
243+
244+
logger.info("Compute Environment:")
245+
logger.info(f" CPU cores: {info['cpu_count']}")
246+
logger.info(f" JAX platform: {info['jax_platform']}")
247+
logger.info(f" GPU available: {info['gpu_available']}")
248+
if info['gpu_available']:
249+
logger.info(f" GPU devices: {len(info['gpu_devices'])}")
250+
for i, device in enumerate(info['gpu_devices']):
251+
logger.info(f" {i}: {device}")

kompot/cli/da.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from ..anndata import compute_differential_abundance
1010
from .utils import load_config, merge_args_with_config, validate_anndata_path
11+
from .compute_config import configure_compute
1112

1213

1314
logger = logging.getLogger("kompot.cli")
@@ -147,6 +148,19 @@ def add_da_parser(subparsers) -> argparse.ArgumentParser:
147148
help='Overwrite existing results without warning'
148149
)
149150

151+
# Compute configuration
152+
parser.add_argument(
153+
'--use-gpu',
154+
action='store_true',
155+
help='Use GPU for computation (requires CUDA-enabled JAX)'
156+
)
157+
158+
parser.add_argument(
159+
'--threads',
160+
type=int,
161+
help='Number of threads to use for JAX, NumPy, and Dask (default: all available cores)'
162+
)
163+
150164
parser.set_defaults(func=run_da)
151165

152166
return parser
@@ -179,10 +193,25 @@ def run_da(args):
179193
logger.info(f"Loading configuration from {args.config}")
180194
config = load_config(args.config)
181195

196+
# Configure compute resources (must be done AFTER mellon import in compute_differential_abundance)
197+
# Extract compute config before other processing
198+
use_gpu = getattr(args, 'use_gpu', False)
199+
n_threads = getattr(args, 'threads', None)
200+
201+
# Log configuration before compute setup
202+
if use_gpu:
203+
logger.info("GPU acceleration: ENABLED")
204+
else:
205+
logger.info("GPU acceleration: DISABLED (using CPU)")
206+
if n_threads:
207+
logger.info(f"Thread limit: {n_threads}")
208+
else:
209+
logger.info("Thread limit: NONE (using all available cores)")
210+
182211
# Convert args to dict, removing None values and CLI-specific args
183212
args_dict = {
184213
k: v for k, v in vars(args).items()
185-
if v is not None and k not in ['input', 'output', 'table_output', 'config', 'func', 'verbose', 'command']
214+
if v is not None and k not in ['input', 'output', 'table_output', 'config', 'func', 'verbose', 'command', 'use_gpu', 'threads']
186215
}
187216

188217
# Rename CLI args to match function parameters
@@ -222,6 +251,21 @@ def run_da(args):
222251
logger.info(f" Condition 2: {params['condition2']}")
223252
logger.info(f" ObsM key: {params.get('obsm_key', 'X_pca')}")
224253

254+
# Configure computational backend
255+
# This must be called AFTER mellon import (which happens in compute_differential_abundance)
256+
# So we do a "lazy" import here to trigger mellon import, then configure
257+
logger.info("")
258+
logger.info("Configuring computational backend...")
259+
try:
260+
# Import mellon to trigger its JAX configuration
261+
import mellon
262+
# Now configure our settings (will override mellon's CPU-only default if needed)
263+
configure_compute(use_gpu=use_gpu, n_threads=n_threads)
264+
except Exception as e:
265+
logger.warning(f"Could not configure compute backend: {e}")
266+
logger.warning("Proceeding with default configuration")
267+
logger.info("")
268+
225269
# Run analysis - use return_full_results if table output is requested
226270
try:
227271
if args.table_output:

0 commit comments

Comments
 (0)