Skip to content

Commit 88834df

Browse files
Merge pull request #741 from douglasjacobsen/py-nemo-cache-create
Add nemo executable to create transformer cache
2 parents f39fd59 + 376b7b8 commit 88834df

File tree

1 file changed

+41
-31
lines changed

1 file changed

+41
-31
lines changed

var/ramble/repos/builtin/applications/py-nemo/application.py

Lines changed: 41 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,15 @@ class PyNemo(ExecutableApplication):
2727

2828
tags("ml-framework", "machine-learning")
2929

30+
executable(
31+
"setup_transformer_cache",
32+
'bash -c "python3 -c \'from transformers import AutoTokenizer; AutoTokenizer.from_pretrained(\\"gpt2\\")\'"',
33+
use_mpi=True,
34+
)
35+
3036
executable(
3137
"pretraining_exec",
32-
'bash -c "cd /opt/NeMo; git rev-parse HEAD; export PYTHONPATH=/opt/NeMo:\${PYTHONPATH}; '
33-
"CUDA_VISIBLE_DEVICES={cuda_visible_devices} "
38+
'bash -c "cd /opt/NeMo; git rev-parse HEAD; '
3439
"python3 -u /opt/NeMo/examples/nlp/language_modeling/megatron_gpt_pretraining.py "
3540
'--config-path={nemo_generated_config_path} --config-name={nemo_generated_config_name}"',
3641
use_mpi=True,
@@ -50,7 +55,11 @@ class PyNemo(ExecutableApplication):
5055

5156
workload(
5257
"pretraining",
53-
executables=["create_logs", "pretraining_exec"],
58+
executables=[
59+
"create_logs",
60+
"setup_transformer_cache",
61+
"pretraining_exec",
62+
],
5463
inputs=["nemo_fetched_config"],
5564
)
5665

@@ -1361,38 +1370,39 @@ def _preprocess_log(self, workspace, app_inst):
13611370

13621371
final_regex = re.compile(self.final_epoch_regex)
13631372

1364-
with open(log_file, "r", encoding="ISO-8859-1") as f:
1365-
data = f.read()
1366-
1367-
with open(log_file, "r", encoding="ISO-8859-1") as f:
1368-
for line in f.readlines():
1369-
m = final_regex.match(line)
1373+
if os.path.exists(log_file):
1374+
with open(log_file, "r", encoding="ISO-8859-1") as f:
1375+
data = f.read()
13701376

1371-
if m:
1372-
timestamp = m.group("elapsed_time")
1377+
with open(log_file, "r", encoding="ISO-8859-1") as f:
1378+
for line in f.readlines():
1379+
m = final_regex.match(line)
13731380

1374-
time_parts = timestamp.split(":")
1381+
if m:
1382+
timestamp = m.group("elapsed_time")
13751383

1376-
part_s = 0
1377-
mult = 1
1378-
for part in reversed(time_parts):
1379-
part_s += int(part) * mult
1380-
mult = mult * 60
1381-
elapsed_s += part_s
1384+
time_parts = timestamp.split(":")
13821385

1383-
processed_log = self.expander.expand_var(
1384-
"{experiment_run_dir}/processed_{experiment_name}.out"
1385-
)
1386+
part_s = 0
1387+
mult = 1
1388+
for part in reversed(time_parts):
1389+
part_s += int(part) * mult
1390+
mult = mult * 60
1391+
elapsed_s += part_s
13861392

1387-
with open(processed_log, "w+") as f:
1388-
f.write(
1389-
data.replace("\x13", "\n")
1390-
.replace("\x96\x88", "")
1391-
.replace("â", "")
1393+
processed_log = self.expander.expand_var(
1394+
"{experiment_run_dir}/processed_{experiment_name}.out"
13921395
)
13931396

1394-
sec_file_path = self.expander.expand_var(
1395-
"{experiment_run_dir}/elapsed_seconds"
1396-
)
1397-
with open(sec_file_path, "w+") as f:
1398-
f.write(f"Elapsed seconds: {elapsed_s}")
1397+
with open(processed_log, "w+") as f:
1398+
f.write(
1399+
data.replace("\x13", "\n")
1400+
.replace("\x96\x88", "")
1401+
.replace("â", "")
1402+
)
1403+
1404+
sec_file_path = self.expander.expand_var(
1405+
"{experiment_run_dir}/elapsed_seconds"
1406+
)
1407+
with open(sec_file_path, "w+") as f:
1408+
f.write(f"Elapsed seconds: {elapsed_s}")

0 commit comments

Comments
 (0)