Skip to content

Commit 9dd32f5

Browse files
authored
ruff: format with 0.9.0 (#1235)
This just landed on PyPI; some - but not all - of these changes are accepted by the older 0.8.0 too.
1 parent 5a74526 commit 9dd32f5

File tree

8 files changed

+47
-46
lines changed

8 files changed

+47
-46
lines changed

.github/container/jax-nccl-test

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,9 @@ if __name__ == "__main__":
7474
)
7575
args = parser.parse_args()
7676

77-
assert (
78-
args.process_id is None or args.distributed
79-
), "--process-id is only relevant with --distributed"
77+
assert args.process_id is None or args.distributed, (
78+
"--process-id is only relevant with --distributed"
79+
)
8080
if args.distributed:
8181
null_args = {
8282
args.coordinator_address is None,
@@ -108,7 +108,7 @@ if __name__ == "__main__":
108108
f"Rank {args.process_id} has local rank {local_process_id} and "
109109
f"devices {local_device_ids} from a total of {visible_devices} "
110110
f"visible on this node, {args.process_count} processes and "
111-
f"{args.process_count*args.gpus_per_process} total devices.",
111+
f"{args.process_count * args.gpus_per_process} total devices.",
112112
flush=True,
113113
)
114114
jax.distributed.initialize(
@@ -209,7 +209,7 @@ if __name__ == "__main__":
209209
if host_timer:
210210
result.block_until_ready()
211211
if jax.process_index() == 0:
212-
print(f"First {op} duration {time.time()-start:.2f}s")
212+
print(f"First {op} duration {time.time() - start:.2f}s")
213213
return result
214214

215215
def device_put_local(x: jax.Array):

.github/container/nsys_jax/nsys_jax/analyses/Analysis.ipynb

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -299,11 +299,11 @@
299299
"# Print out the largest entries adding up to at least this fraction of the total\n",
300300
"threshold = 0.97\n",
301301
"compile_summary[\"FracNonChild\"] = compile_summary[\"DurNonChildMs\"] / total_compile_time\n",
302-
"print(f\"Top {threshold:.0%}+ of {total_compile_time*1e-9:.2f}s compilation time\")\n",
302+
"print(f\"Top {threshold:.0%}+ of {total_compile_time * 1e-9:.2f}s compilation time\")\n",
303303
"for row in compile_summary[\n",
304304
" compile_summary[\"FracNonChild\"].cumsum() <= threshold\n",
305305
"].itertuples():\n",
306-
" print(f\"{row.FracNonChild:6.2%} {row.DurNonChildMs*1e-3:.2f}s {row.Index}\")"
306+
" print(f\"{row.FracNonChild:6.2%} {row.DurNonChildMs * 1e-3:.2f}s {row.Index}\")"
307307
]
308308
},
309309
{
@@ -585,9 +585,9 @@
585585
"detailed_mask = (compute_duration_rel_stds > var_threshold) & (\n",
586586
" compute_duration_means > mean_threshold\n",
587587
")\n",
588-
"assert (\n",
589-
" detailed_mask.sum() <= detailed_limit\n",
590-
"), f\"Aimed for {detailed_limit} and got {detailed_mask.sum()}\"\n",
588+
"assert detailed_mask.sum() <= detailed_limit, (\n",
589+
" f\"Aimed for {detailed_limit} and got {detailed_mask.sum()}\"\n",
590+
")\n",
591591
"\n",
592592
"fig, axs = plt.subplots(\n",
593593
" ncols=2, width_ratios=[1, 2], figsize=[15, 5], tight_layout=True\n",

.github/container/nsys_jax/nsys_jax/analysis.py

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@ def align_profiler_data_timestamps(
2828
# Error if the communication frame doesn't exist at all, but not if it is empty.
2929
# Calling this on a profile that does not contain any communication should
3030
# gracefully yield empty results.
31-
assert (
32-
frames.communication is not None
33-
), "align_profiler_data_timestamps requires a communication frame"
31+
assert frames.communication is not None, (
32+
"align_profiler_data_timestamps requires a communication frame"
33+
)
3434
if not len(frames.communication):
3535
# Nothing to be done, return an empty result
3636
return frames, {}
@@ -43,9 +43,9 @@ def align_profiler_data_timestamps(
4343
f"WARNING: cannot align {num_profiled_devices} devices because max collective size is 1"
4444
)
4545
return frames, {}
46-
assert (
47-
num_profiled_devices == max_collective_size
48-
), f"Aligning {num_profiled_devices} using collectives of size {max_collective_size} is not implemented"
46+
assert num_profiled_devices == max_collective_size, (
47+
f"Aligning {num_profiled_devices} using collectives of size {max_collective_size} is not implemented"
48+
)
4949
# Find the collectives that will be used
5050
align_df = comm_df[comm_df["CollectiveSize"] == max_collective_size]
5151
# Calculate the collectives' end times
@@ -190,19 +190,18 @@ def _get_message_size(
190190
) -> tuple[int, str, int, float, float]:
191191
_, inst = module_proto.find_instruction(instruction)
192192
comm_inst = inst.communication_proto()
193-
assert (
194-
comm_inst.opcode
195-
in {
196-
"all-gather-start",
197-
"all-reduce-start",
198-
"all-to-all",
199-
"collective-broadcast",
200-
"collective-permute-start",
201-
"dynamic-slice",
202-
"dynamic-update-slice",
203-
"reduce-scatter",
204-
}
205-
), f"{instruction}: message size calculation for {comm_inst.opcode} has not yet been validated"
193+
assert comm_inst.opcode in {
194+
"all-gather-start",
195+
"all-reduce-start",
196+
"all-to-all",
197+
"collective-broadcast",
198+
"collective-permute-start",
199+
"dynamic-slice",
200+
"dynamic-update-slice",
201+
"reduce-scatter",
202+
}, (
203+
f"{instruction}: message size calculation for {comm_inst.opcode} has not yet been validated"
204+
)
206205

207206
def _byte_size(inst) -> int:
208207
size_bits = math.prod(
@@ -256,9 +255,9 @@ def _byte_size(inst) -> int:
256255
collective_size = iota_group_list.num_devices_per_group
257256
else:
258257
collective_sizes = set(len(group.replica_ids) for group in replica_groups)
259-
assert (
260-
len(collective_sizes) == 1
261-
), f"Heterogeneous collective {comm_inst} could not be interpreted"
258+
assert len(collective_sizes) == 1, (
259+
f"Heterogeneous collective {comm_inst} could not be interpreted"
260+
)
262261
collective_size = next(iter(collective_sizes))
263262
total_msg_size = 0
264263
for operand_id in comm_inst.operand_ids:

.github/container/nsys_jax/nsys_jax/data_loaders.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,9 @@ def _calculate_overlap(thunk_df: pd.DataFrame) -> pd.DataFrame:
5656
serial_mask = (
5757
compute_df["ProjStartMs"].array[1:] >= compute_df["ProjEndMs"].array[:-1]
5858
)
59-
assert serial_mask.all(), f"Only {serial_mask.sum()}/{len(serial_mask)} compute kernel pairs failed to overlap on device {device} and program #{program_id}"
59+
assert serial_mask.all(), (
60+
f"Only {serial_mask.sum()}/{len(serial_mask)} compute kernel pairs failed to overlap on device {device} and program #{program_id}"
61+
)
6062
# Update the projected duration of each communication kernel to only
6163
# include the non-overlapped time
6264
for comm_thunk in comm_df.itertuples():

.github/container/nsys_jax/nsys_jax/protobuf.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,9 +110,9 @@ def _visit_computation(computation_id):
110110
if called_inst.opcode in comm_opcodes or _is_offloading_instruction(
111111
called_inst
112112
):
113-
assert (
114-
self._comm_proto is None
115-
), f"Found {called_inst.opcode} child having already found {self._comm_proto.opcode}"
113+
assert self._comm_proto is None, (
114+
f"Found {called_inst.opcode} child having already found {self._comm_proto.opcode}"
115+
)
116116
self._comm_proto = called_inst
117117

118118
for called_id in self._proto.called_computation_ids:

.github/container/nsys_jax/nsys_jax/scripts/nsys_jax.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,7 @@ def copy_proto_files_to_tmp(
341341
if not osp.isdir(dst_dir):
342342
os.makedirs(dst_dir)
343343
shutil.copy(osp.join(root, proto), osp.join(proto_dir, proto))
344-
print(f"{archive_name}: gathered .proto files in {time.time()-start:.2f}s")
344+
print(f"{archive_name}: gathered .proto files in {time.time() - start:.2f}s")
345345
return proto_dir, proto_files
346346

347347
def run_nsys_recipe(recipe, report_file, tmp_dir, output_queue):
@@ -369,7 +369,7 @@ def run_nsys_recipe(recipe, report_file, tmp_dir, output_queue):
369369
if osp.isdir(full_path) or not osp.exists(full_path):
370370
continue
371371
output_queue.put((ofile, full_path, COMPRESS_NONE))
372-
print(f"{archive_name}: post-processing finished in {time.time()-start:.2f}s")
372+
print(f"{archive_name}: post-processing finished in {time.time() - start:.2f}s")
373373

374374
def compress_and_archive(prefix, file, output_queue):
375375
"""
@@ -403,7 +403,7 @@ def run_nsys_stats_report(report, report_file, tmp_dir, output_queue):
403403
)
404404
for ofile in iglob("report_" + report + ".csv", root_dir=tmp_dir):
405405
compress_and_archive(tmp_dir, ofile, output_queue)
406-
print(f"{archive_name}: post-processing finished in {time.time()-start:.2f}s")
406+
print(f"{archive_name}: post-processing finished in {time.time() - start:.2f}s")
407407

408408
def save_device_stream_thread_names(tmp_dir, report, output_queue):
409409
"""
@@ -501,7 +501,7 @@ def table_columns(table_name):
501501
else:
502502
print("WARNING: NOT writing device metadata, no device activity profiled?")
503503
print(
504-
f"{archive_name}: extracted device/thread names in {time.time()-start:.2f}s"
504+
f"{archive_name}: extracted device/thread names in {time.time() - start:.2f}s"
505505
)
506506

507507
def find_pb_files_in_tmp(tmp_dir):
@@ -553,7 +553,7 @@ def gather_source_files(
553553
continue
554554
assert osp.isabs(src_file), f"{src_file} is not absolute"
555555
output_queue.put(("sources" + src_file, src_file, COMPRESS_DEFLATE))
556-
print(f"{archive_name}: gathered source code in {time.time()-start:.2f}s")
556+
print(f"{archive_name}: gathered source code in {time.time() - start:.2f}s")
557557

558558
def execute_analysis_scripts(mirror_dir, analysis_scripts):
559559
"""
@@ -631,7 +631,7 @@ def write_output_file(to_process, mirror_dir, analysis_scripts):
631631
for path_in_archive, local_path in analysis_outputs:
632632
archive.write(filename=local_path, arcname=path_in_archive)
633633
os.chmod(archive_name, 0o644)
634-
print(f"{archive_name}: wrote in {time.time()-start:.2f}s")
634+
print(f"{archive_name}: wrote in {time.time() - start:.2f}s")
635635
if exit_code != 0:
636636
print("Exiting due to analysis script errors")
637637
sys.exit(exit_code)

.github/container/nsys_jax/nsys_jax/scripts/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,9 @@ def analysis_recipe_path(script):
4040
)
4141
if script_file.is_file():
4242
return script_file
43-
assert os.path.exists(
44-
script
45-
), f"{script} does not exist and is not the name of a built-in analysis script"
43+
assert os.path.exists(script), (
44+
f"{script} does not exist and is not the name of a built-in analysis script"
45+
)
4646
return contextlib.nullcontext(pathlib.Path(script))
4747

4848

.github/triage/jax_toolbox_triage/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def main():
2323
logger = get_logger(args.output_prefix)
2424
logger.info(
2525
"Verbose output, including stdout/err of triage commands, will be written to "
26-
f'{(args.output_prefix / "debug.log").resolve()}'
26+
f"{(args.output_prefix / 'debug.log').resolve()}"
2727
)
2828
container_url = functools.partial(container_url_base, container=args.container)
2929
container_exists = functools.partial(

0 commit comments

Comments (0)