Skip to content

Commit 735ade5

Browse files
sigeisler and n-gao authored
Optional ssh port forwarding (#134)
* Optional ssh port forwarding * Update database.py * File locking for ssh tunnel * add configuration to ssh forwarding; move defaults to settings.py --------- Co-authored-by: Nicholas Gao <nicholas.gao@tum.de>
1 parent 5f22bf6 commit 735ade5

File tree

7 files changed

+156
-39
lines changed

7 files changed

+156
-39
lines changed

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
repos:
22
- repo: https://github.com/pre-commit/pre-commit-hooks
3-
rev: v4.5.0
3+
rev: v4.6.0
44
hooks:
55
- id: check-case-conflict
66
- id: check-toml
@@ -10,7 +10,7 @@ repos:
1010
- id: trailing-whitespace
1111
- repo: https://github.com/astral-sh/ruff-pre-commit
1212
# Ruff version.
13-
rev: v0.3.0
13+
rev: v0.3.5
1414
hooks:
1515
# Run the linter.
1616
- id: ruff

README.md

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ Keeping track of computational experiments can be annoying and failure to do so
77
While workload scheduling systems such as [`Slurm`](https://slurm.schedmd.com/overview.html) make it easy to run many experiments in parallel on a cluster, it can be hard to keep track of which parameter configurations are running, failed, or completed.
88
[`sacred`](https://github.com/IDSIA/sacred) is a great tool to collect and manage experiments and their results, especially when used with a [`MongoDB`](https://www.mongodb.com/). However, it is lacking integration with workload schedulers.
99

10-
**`SEML`** enables you to
10+
**`SEML`** enables you to
1111
* very easily define hyperparameter search spaces using YAML files,
1212
* run these hyperparameter configurations on a compute cluster using `Slurm`,
1313
* and to track the experimental results using `sacred` and `MongoDB`.
@@ -31,7 +31,7 @@ conda install -c conda-forge seml
3131
```
3232
Then configure your MongoDB via:
3333
```bash
34-
seml configure --mongodb # provide your MongoDB credentials
34+
seml configure
3535
```
3636
For convenience, you may create your first **`SEML`** project via:
3737
```bash
@@ -40,6 +40,17 @@ seml project init -t default new_project
4040
```
4141
in an empty directory. **`SEML`** will automatically create a python package for you.
4242

43+
44+
### SSH Port Forwarding
45+
If your MongoDB is only accessible via an SSH port forward, **`SEML`** allows you to directly configure this as well if you install the `ssh_forward` dependencies via:
46+
```bash
47+
pip install seml[ssh_forward]
48+
```
49+
It remains to configure the SSH settings:
50+
```bash
51+
seml configure --ssh-forward
52+
```
53+
4354
### Development
4455
If you want to develop `seml` please clone the repository and install it via
4556
```bash
@@ -48,7 +59,7 @@ pip install -e .[dev]
4859
and install pre-commit hooks via
4960
```bash
5061
pre-commit install
51-
```
62+
```
5263

5364
## Documentation
5465
Documentation is available in our [docs.md](docs.md) or via the CLI:

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ dependencies = [
3636

3737
[project.optional-dependencies]
3838
dev = ["pytest", "ruff", "pre-commit"]
39+
ssh_forward = ["sshtunnel>=0.4.0", "filelock>=3.13.3"]
3940

4041
[tool.ruff.format]
4142
quote-style = "single"

src/seml/__main__.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -356,27 +356,20 @@ def clean_db_command(ctx: typer.Context, yes: YesAnnotation = False):
356356
@restrict_collection(False)
357357
def configure_command(
358358
ctx: typer.Context,
359-
all: Annotated[
359+
ssh_forward: Annotated[
360360
bool,
361361
typer.Option(
362-
'-a',
363-
'--all',
364-
help='Configure all SEML settings',
362+
'-sf',
363+
'--ssh-forward',
364+
help='Configure SSH forwarding settings for MongoDB.',
365365
is_flag=True,
366366
),
367367
] = False,
368-
mongodb: Annotated[
369-
bool,
370-
typer.Option(
371-
help='Configure MongoDB settings',
372-
is_flag=True,
373-
),
374-
] = True,
375368
):
376369
"""
377370
Configure SEML (database, argument completion, ...).
378371
"""
379-
configure(all=all, mongodb=mongodb)
372+
configure(all=False, mongodb=True, setup_ssh_forward=ssh_forward)
380373

381374

382375
@app.command('start-jupyter')

src/seml/configure.py

Lines changed: 37 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,30 @@
11
import logging
2+
import yaml
23

34
from seml.settings import SETTINGS
45

56

6-
def mongodb_configure():
7+
def prompt_ssh_forward():
8+
"""
9+
Prompt the user for SSH Forward settings. The output format corresponds
10+
to the argument of sshtunnel.SSHTunnelForwarder.
11+
"""
12+
from seml.console import prompt
13+
14+
logging.info('Configuring SSH Forward settings.')
15+
ssh_host = prompt('SSH host')
16+
port = prompt('Port', default=22, type=int)
17+
username = prompt('User name')
18+
ssh_pkey = prompt('Path to SSH private key', default='~/.ssh/id_rsa')
19+
return dict(
20+
ssh_address_or_host=ssh_host,
21+
ssh_port=port,
22+
ssh_username=username,
23+
ssh_pkey=ssh_pkey,
24+
)
25+
26+
27+
def mongodb_configure(setup_ssh_forward: bool = False):
728
from seml.console import prompt
829

930
if SETTINGS.DATABASE.MONGODB_CONFIG_PATH.exists() and not prompt(
@@ -13,31 +34,36 @@ def mongodb_configure():
1334
return
1435
logging.info('Configuring MongoDB. Warning: Password will be stored in plain text.')
1536
host = prompt('MongoDB host')
16-
port = prompt('Port', default='27017')
37+
port = prompt('Port', default=27017, type=int)
1738
database = prompt('Database name')
1839
username = prompt('User name')
1940
password = prompt('Password', hide_input=True)
2041
file_path = SETTINGS.DATABASE.MONGODB_CONFIG_PATH
21-
config_string = (
22-
f'username: {username}\n'
23-
f'password: {password}\n'
24-
f'port: {port}\n'
25-
f'database: {database}\n'
26-
f'host: {host}'
42+
config = dict(
43+
host=host,
44+
port=port,
45+
database=database,
46+
username=username,
47+
password=password,
2748
)
49+
if setup_ssh_forward:
50+
config['ssh_config'] = prompt_ssh_forward()
51+
config_string = yaml.dump(config)
2852
logging.info(
2953
f"Saving the following configuration to {file_path}:\n"
30-
f"{config_string.replace(f'password: {password}', 'password: ********')}"
54+
f"{config_string.replace(f'{password}', '********')}"
3155
)
3256
file_path.parent.mkdir(parents=True, exist_ok=True)
3357
with open(file_path, 'w') as f:
3458
f.write(config_string)
3559

3660

37-
def configure(all: bool = False, mongodb: bool = False):
61+
def configure(
62+
all: bool = False, mongodb: bool = False, setup_ssh_forward: bool = False
63+
):
3864
configured_any = False
3965
if mongodb or all:
40-
mongodb_configure()
66+
mongodb_configure(setup_ssh_forward=setup_ssh_forward)
4167
configured_any = True
4268
if not configured_any:
4369
logging.info('Did not specify any configuration to configure')

src/seml/database.py

Lines changed: 91 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
import logging
2+
import random
3+
import time
24
from typing import List
35

6+
import yaml
7+
48
from seml.errors import MongoDBError
59
from seml.settings import SETTINGS
610
from seml.utils import s_if
@@ -18,7 +22,69 @@ def get_collection(collection_name, mongodb_config=None, suffix=None):
1822
return db[collection_name]
1923

2024

21-
def get_mongo_client(db_name, host, port, username, password, **kwargs):
25+
def retried_and_locked_ssh_port_forward(
26+
retries_max=SETTINGS.SSH_FORWARD.RETRIES_MAX,
27+
retries_delay=SETTINGS.SSH_FORWARD.RETRIES_DELAY,
28+
lock_file=SETTINGS.SSH_FORWARD.LOCK_FILE,
29+
lock_timeout=SETTINGS.SSH_FORWARD.LOCK_TIMEOUT,
30+
**ssh_config,
31+
):
32+
try:
33+
from sshtunnel import (
34+
BaseSSHTunnelForwarderError,
35+
SSHTunnelForwarder,
36+
create_logger,
37+
)
38+
except ImportError:
39+
logging.error(
40+
'Opening ssh tunnel requires `sshtunnel` (e.g. `pip install sshtunnel`)'
41+
)
42+
exit(1)
43+
try:
44+
from filelock import FileLock, Timeout
45+
except ImportError:
46+
logging.error(
47+
'Opening ssh tunnel requires `filelock` (e.g. `pip install filelock`)'
48+
)
49+
exit(1)
50+
51+
delay = retries_delay
52+
error = None
53+
for _ in range(retries_max):
54+
try:
55+
lock = FileLock(lock_file, timeout=lock_timeout)
56+
with lock:
57+
server = SSHTunnelForwarder(
58+
**ssh_config,
59+
logger=create_logger(logging.getLogger(), loglevel=logging.ERROR),
60+
)
61+
server.start()
62+
return server
63+
except Timeout as e:
64+
error = e
65+
logging.warn(f'Failed to aquire lock for ssh tunnel {lock_file}')
66+
except BaseSSHTunnelForwarderError as e:
67+
error = e
68+
logging.warn(f'Retry establishing ssh tunnel in {delay} s')
69+
# Jittered exponential retry
70+
time.sleep(delay)
71+
delay *= 2
72+
delay += random.uniform(0, 1)
73+
74+
if error:
75+
logging.error(f'Failed to establish ssh tunnel: {error}')
76+
exit(1)
77+
78+
79+
def get_mongo_client(
80+
db_name, host, port, username, password, ssh_config=None, **kwargs
81+
):
82+
if ssh_config is not None:
83+
server = retried_and_locked_ssh_port_forward(**ssh_config)
84+
85+
host = server.local_bind_host
86+
port = server.local_bind_port
87+
2288
import pymongo
2389

2490
client = pymongo.MongoClient(
@@ -76,7 +142,8 @@ def get_mongodb_config(path=SETTINGS.DATABASE.MONGODB_CONFIG_PATH):
76142
- database name
77143
- username
78144
- password
79-
- directConnection
145+
- directConnection (Optional)
146+
- ssh_config (Optional)
80147
81148
Default path is $HOME/.config/seml/mongodb.config.
82149
@@ -87,6 +154,15 @@ def get_mongodb_config(path=SETTINGS.DATABASE.MONGODB_CONFIG_PATH):
87154
database: <database_name>
88155
host: <host>
89156
directConnection: <bool> (Optional)
157+
ssh_config: <dict> (Optional)
158+
ssh_address_or_host: <the url of the jump host>
159+
ssh_pkey: <path to the ssh private key>
160+
ssh_username: <username for jump host>
161+
retries_max: <number of retries to establish ssh tunnel, default 6> (Optional)
162+
retries_delay: <initial wait time for exponential retry, default 1> (Optional)
163+
lock_file: <lockfile to avoid establishing ssh tunnels in parallel, default `/tmp/seml_ssh_forward.lock`> (Optional)
164+
lock_timeout: <timeout for acquiring lock, default 30> (Optional)
165+
** further arguments passed to `SSHTunnelForwarder` (see https://github.com/pahaz/sshtunnel)
90166
91167
Returns
92168
-------
@@ -103,14 +179,8 @@ def get_mongodb_config(path=SETTINGS.DATABASE.MONGODB_CONFIG_PATH):
103179
f"MongoDB credentials could not be read at '{path}'.{config_str}"
104180
)
105181

106-
with open(path, 'r') as f:
107-
for line in f.readlines():
108-
# ignore lines that are empty
109-
if len(line.strip()) > 0:
110-
split = line.split(':')
111-
key = split[0].strip()
112-
value = split[1].strip()
113-
access_dict[key] = value
182+
with open(path, 'r') as conf:
183+
access_dict = yaml.safe_load(conf)
114184

115185
required_entries = ['username', 'password', 'port', 'host', 'database']
116186
for entry in required_entries:
@@ -129,7 +199,7 @@ def get_mongodb_config(path=SETTINGS.DATABASE.MONGODB_CONFIG_PATH):
129199
else False
130200
)
131201

132-
return {
202+
cfg = {
133203
'password': db_password,
134204
'username': db_username,
135205
'host': db_host,
@@ -138,6 +208,15 @@ def get_mongodb_config(path=SETTINGS.DATABASE.MONGODB_CONFIG_PATH):
138208
'directConnection': db_direct,
139209
}
140210

211+
if 'ssh_config' not in access_dict:
212+
return cfg
213+
214+
cfg['ssh_config'] = access_dict['ssh_config']
215+
cfg['ssh_config']['remote_bind_address'] = (db_host, db_port)
216+
cfg['directConnection'] = True
217+
218+
return cfg
219+
141220

142221
def build_filter_dict(filter_states, batch_id, filter_dict, sacred_id=None):
143222
"""
@@ -255,6 +334,7 @@ def upload_file(filename, db_collection, batch_id, filetype):
255334

256335
def delete_files(database, file_ids, progress=False):
257336
import gridfs
337+
258338
from seml.console import track
259339

260340
fs = gridfs.GridFS(database)

src/seml/settings.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,12 @@
145145
'AUTOCOMPLETE_CACHE_ALIVE_TIME': 300,
146146
'SETUP_COMMAND': '',
147147
'END_COMMAND': '',
148+
'SSH_FORWARD': {
149+
'LOCK_FILE': '/tmp/seml_ssh_forward.lock',
150+
'RETRIES_MAX': 6,
151+
'RETRIES_DELAY': 1,
152+
'LOCK_TIMEOUT': 30,
153+
},
148154
},
149155
)
150156

0 commit comments

Comments
 (0)