Skip to content

Commit cce939a

Browse files
committed
Fix: style
1 parent b2eaa60 commit cce939a

1 file changed

Lines changed: 48 additions & 28 deletions

File tree

src/libkernelbot/run_eval.py

Lines changed: 48 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -130,16 +130,16 @@ def _create_files(files: Optional[dict[str, str]]):
130130
def _directory_to_zip_bytes(directory_path) -> str:
131131
"""Create a zip archive and return as base64 encoded bytes."""
132132
with tempfile.TemporaryDirectory() as temp_dir:
133-
archive_path = os.path.join(temp_dir, 'archive')
134-
shutil.make_archive(archive_path, 'zip', directory_path)
133+
archive_path = os.path.join(temp_dir, "archive")
134+
shutil.make_archive(archive_path, "zip", directory_path)
135135

136-
with open(archive_path + '.zip', 'rb') as f:
136+
with open(archive_path + ".zip", "rb") as f:
137137
data = f.read()
138138

139-
return base64.b64encode(data).decode('utf-8')
139+
return base64.b64encode(data).decode("utf-8")
140140

141141

142-
def _filter_ncu_report(report: str, tables: list):
142+
def _filter_ncu_report(report: str, tables: list): # noqa: C901
143143
"""
144144
Extract the Speed-of-light section from the full ncu terminal report.
145145
@@ -154,19 +154,19 @@ def _filter_ncu_report(report: str, tables: list):
154154
collect = False
155155
length = 0
156156
for line in report.splitlines():
157-
if len(line) >= 3 and line[1] == ' ' and line[2] != ' ':
157+
if len(line) >= 3 and line[1] == " " and line[2] != " ":
158158
if n_kernels != 0:
159159
result += "\n"
160160
n_kernels += 1
161161
if n_kernels == 3:
162-
result += "\nAdditional kernel launches follow. Please check the .ncu-rep file for more details.\n"
162+
result += "\nAdditional kernel launches follow. Please check the .ncu-rep file for more details.\n" # noqa: E501
163163
result += line + "\n"
164164

165165
if n_kernels > 2:
166166
continue
167167

168168
if "Table Name : " in line:
169-
table = line[line.find("Table Name :") + len("Table Name :"):].strip()
169+
table = line[line.find("Table Name :") + len("Table Name :") :].strip()
170170
if table in tables:
171171
result += "\n"
172172
collect = True
@@ -181,7 +181,7 @@ def _filter_ncu_report(report: str, tables: list):
181181
length += 1
182182
# just as a precaution, also limit lines directly
183183
if length > 100:
184-
result += "\n[...]\nReport has been truncated. Please check the .ncu-rep file for more details.\n"
184+
result += "\n[...]\nReport has been truncated. Please check the .ncu-rep file for more details.\n" # noqa: E501
185185
break
186186
return result
187187

@@ -406,10 +406,15 @@ def profile_program_roc(
406406
"--",
407407
] + call
408408

409-
run_result = run_program(call, seed=seed, timeout=timeout, multi_gpu=multi_gpu, extra_env={
410-
"GPU_DUMP_CODE_OBJECT": "1",
411-
},
412-
)
409+
run_result = run_program(
410+
call,
411+
seed=seed,
412+
timeout=timeout,
413+
multi_gpu=multi_gpu,
414+
extra_env={
415+
"GPU_DUMP_CODE_OBJECT": "1",
416+
},
417+
)
413418

414419
profile_result = None
415420

@@ -453,32 +458,49 @@ def profile_program_ncu(
453458
# Wrap program in ncu
454459
call = [
455460
"ncu",
456-
"--set", "full",
461+
"--set",
462+
"full",
457463
"--nvtx",
458-
"--nvtx-include", "custom_kernel/",
459-
"--import-source", "1",
460-
"-c", "10",
461-
"-o", f"{str(output_dir / 'profile.ncu-rep')}",
464+
"--nvtx-include",
465+
"custom_kernel/",
466+
"--import-source",
467+
"1",
468+
"-c",
469+
"10",
470+
"-o",
471+
f"{str(output_dir / 'profile.ncu-rep')}",
462472
"--",
463473
] + call
464474

465-
run_result = run_program(call, seed=seed, timeout=timeout, multi_gpu=multi_gpu, extra_env={
466-
"POPCORN_NCU": "1"
467-
})
475+
run_result = run_program(
476+
call, seed=seed, timeout=timeout, multi_gpu=multi_gpu, extra_env={"POPCORN_NCU": "1"}
477+
)
468478
profile_result = None
469479

470480
try:
471-
get_tables = ["GPU Throughput", "Pipe Utilization (% of active cycles)", "Warp State (All Cycles)"]
472-
ncu_cmd = ["ncu", "--import", f"{str(output_dir / 'profile.ncu-rep')}", "--print-details", "body"]
481+
get_tables = [
482+
"GPU Throughput",
483+
"Pipe Utilization (% of active cycles)",
484+
"Warp State (All Cycles)",
485+
]
486+
ncu_cmd = [
487+
"ncu",
488+
"--import",
489+
f"{str(output_dir / 'profile.ncu-rep')}",
490+
"--print-details",
491+
"body",
492+
]
473493
report = subprocess.check_output(ncu_cmd, text=True)
474494
report = _filter_ncu_report(report, get_tables)
475-
run_result.result["benchmark.0.report"] = base64.b64encode(report.encode("utf-8")).decode("utf-8")
495+
run_result.result["benchmark.0.report"] = base64.b64encode(report.encode("utf-8")).decode(
496+
"utf-8"
497+
)
476498
except subprocess.CalledProcessError:
477499
pass
478500

479501
if run_result.success:
480502
profile_result = ProfileResult(
481-
profiler='Nsight-Compute',
503+
profiler="Nsight-Compute",
482504
trace=_directory_to_zip_bytes(output_dir),
483505
download_url=None,
484506
)
@@ -822,9 +844,7 @@ def run_config(config: dict):
822844
}
823845
if config["lang"] == "py":
824846
runner = functools.partial(
825-
run_pytorch_script,
826-
sources=config["sources"],
827-
main=config["main"]
847+
run_pytorch_script, sources=config["sources"], main=config["main"]
828848
)
829849
elif config["lang"] == "cu":
830850
runner = functools.partial(

0 commit comments

Comments
 (0)