Skip to content

Commit 6753b04

Browse files
committed
update snakemake
1 parent 78ea07e commit 6753b04

File tree

1 file changed

+105
-43
lines changed

1 file changed

+105
-43
lines changed

scripts/_helpers.py

+105-43
Original file line numberDiff line numberDiff line change
@@ -55,65 +55,127 @@ def __dir__(self):
5555
return dict_keys + obj_attrs
5656

5757

58-
def mock_snakemake(rulename, **wildcards):
58+
def mock_snakemake(
59+
rulename,
60+
root_dir=None,
61+
configfiles=None,
62+
submodule_dir="workflow/submodules/pypsa-eur",
63+
**wildcards,
64+
):
5965
"""
6066
This function is expected to be executed from the 'scripts'-directory of '
6167
the snakemake project. It returns a snakemake.script.Snakemake object,
6268
based on the Snakefile.
69+
6370
If a rule has wildcards, you have to specify them in **wildcards.
71+
6472
Parameters
6573
----------
6674
rulename: str
6775
name of the rule for which the snakemake object should be generated
76+
root_dir: str/path-like
77+
path to the root directory of the snakemake project
78+
configfiles: list, str
79+
list of configfiles to be used to update the config
80+
submodule_dir: str, Path
81+
in case PyPSA-Eur is used as a submodule, submodule_dir is
82+
the path of pypsa-eur relative to the project directory.
6883
**wildcards:
6984
keyword arguments fixing the wildcards. Only necessary if wildcards are
7085
needed.
7186
"""
7287
import os
7388

7489
import snakemake as sm
75-
from packaging.version import Version, parse
90+
from pypsa.descriptors import Dict
91+
from snakemake.api import Workflow
92+
from snakemake.common import SNAKEFILE_CHOICES
7693
from snakemake.script import Snakemake
77-
78-
script_dir = Path(__file__).parent.resolve()
79-
assert (
80-
Path.cwd().resolve() == script_dir
81-
), f"mock_snakemake has to be run from the repository scripts directory {script_dir}"
82-
os.chdir(script_dir.parent)
83-
for p in sm.SNAKEFILE_CHOICES:
84-
if os.path.exists(p):
85-
snakefile = p
86-
break
87-
kwargs = dict(rerun_triggers=[]) if parse(sm.__version__) > Version("7.7.0") else {}
88-
workflow = sm.Workflow(snakefile, overwrite_configfiles=[], **kwargs)
89-
workflow.include(snakefile)
90-
workflow.global_resources = {}
91-
rule = workflow.get_rule(rulename)
92-
dag = sm.dag.DAG(workflow, rules=[rule])
93-
wc = Dict(wildcards)
94-
job = sm.jobs.Job(rule, dag, wc)
95-
96-
def make_accessable(*ios):
97-
for io in ios:
98-
for i in range(len(io)):
99-
io[i] = os.path.abspath(io[i])
100-
101-
make_accessable(job.input, job.output, job.log)
102-
snakemake = Snakemake(
103-
job.input,
104-
job.output,
105-
job.params,
106-
job.wildcards,
107-
job.threads,
108-
job.resources,
109-
job.log,
110-
job.dag.workflow.config,
111-
job.rule.name,
112-
None,
94+
from snakemake.settings import (
95+
ConfigSettings,
96+
DAGSettings,
97+
ResourceSettings,
98+
StorageSettings,
99+
WorkflowSettings,
113100
)
114-
# create log and output dir if not existent
115-
for path in list(snakemake.log) + list(snakemake.output):
116-
Path(path).parent.mkdir(parents=True, exist_ok=True)
117101

118-
os.chdir(script_dir)
119-
return snakemake
102+
script_dir = Path(__file__).parent.resolve()
103+
if root_dir is None:
104+
root_dir = script_dir.parent
105+
else:
106+
root_dir = Path(root_dir).resolve()
107+
108+
user_in_script_dir = Path.cwd().resolve() == script_dir
109+
if str(submodule_dir) in __file__:
110+
# the submodule_dir path is only need to locate the project dir
111+
os.chdir(Path(__file__[: __file__.find(str(submodule_dir))]))
112+
elif user_in_script_dir:
113+
os.chdir(root_dir)
114+
elif Path.cwd().resolve() != root_dir:
115+
raise RuntimeError(
116+
"mock_snakemake has to be run from the repository root"
117+
f" {root_dir} or scripts directory {script_dir}"
118+
)
119+
try:
120+
for p in SNAKEFILE_CHOICES:
121+
if os.path.exists(p):
122+
snakefile = p
123+
break
124+
if configfiles is None:
125+
configfiles = []
126+
elif isinstance(configfiles, str):
127+
configfiles = [configfiles]
128+
129+
resource_settings = ResourceSettings()
130+
config_settings = ConfigSettings(configfiles=map(Path, configfiles))
131+
workflow_settings = WorkflowSettings()
132+
storage_settings = StorageSettings()
133+
dag_settings = DAGSettings(rerun_triggers=[])
134+
workflow = Workflow(
135+
config_settings,
136+
resource_settings,
137+
workflow_settings,
138+
storage_settings,
139+
dag_settings,
140+
storage_provider_settings=dict(),
141+
)
142+
workflow.include(snakefile)
143+
144+
if configfiles:
145+
for f in configfiles:
146+
if not os.path.exists(f):
147+
raise FileNotFoundError(f"Config file {f} does not exist.")
148+
workflow.configfile(f)
149+
150+
workflow.global_resources = {}
151+
rule = workflow.get_rule(rulename)
152+
dag = sm.dag.DAG(workflow, rules=[rule])
153+
wc = Dict(wildcards)
154+
job = sm.jobs.Job(rule, dag, wc)
155+
156+
def make_accessable(*ios):
157+
for io in ios:
158+
for i, _ in enumerate(io):
159+
io[i] = os.path.abspath(io[i])
160+
161+
make_accessable(job.input, job.output, job.log)
162+
snakemake = Snakemake(
163+
job.input,
164+
job.output,
165+
job.params,
166+
job.wildcards,
167+
job.threads,
168+
job.resources,
169+
job.log,
170+
job.dag.workflow.config,
171+
job.rule.name,
172+
None,
173+
)
174+
# create log and output dir if not existent
175+
for path in list(snakemake.log) + list(snakemake.output):
176+
Path(path).parent.mkdir(parents=True, exist_ok=True)
177+
178+
finally:
179+
if user_in_script_dir:
180+
os.chdir(script_dir)
181+
return snakemake

0 commit comments

Comments
 (0)