Skip to content

Commit 19f1202

Browse files
authored
Add benchmark for WavDecoder (#1474)
1 parent 5883447 commit 19f1202

1 file changed

Lines changed: 226 additions & 0 deletions

File tree

Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
import io
2+
import os
3+
import subprocess
4+
from time import perf_counter_ns
5+
6+
import soundfile as sf
7+
import torch
8+
9+
from torchcodec.decoders import AudioDecoder, WavDecoder
10+
11+
12+
def bench(f, *args, num_exp=100, warmup=0, **kwargs):
    """Time repeated calls of ``f(*args, **kwargs)``.

    Args:
        f: Callable to measure.
        *args: Positional arguments forwarded to ``f`` on every call.
        num_exp: Number of timed calls.
        warmup: Number of untimed calls made before measuring starts.
        **kwargs: Keyword arguments forwarded to ``f`` on every call.

    Returns:
        A 1-D float tensor with one elapsed time per timed call, in
        nanoseconds (``perf_counter_ns`` differences).
    """

    def timed_call():
        # Only the call itself sits between the two clock reads.
        t0 = perf_counter_ns()
        f(*args, **kwargs)
        return perf_counter_ns() - t0

    # Untimed runs first, so one-off setup costs do not skew the samples.
    for _ in range(warmup):
        f(*args, **kwargs)

    elapsed_ns = [timed_call() for _ in range(num_exp)]
    return torch.tensor(elapsed_ns).float()
24+
25+
26+
def report_stats(times, unit="ms"):
    """Print the median and standard deviation of ``times``; return the median.

    Args:
        times: Tensor of durations expressed in nanoseconds.
        unit: Unit to report in; one of ``"ns"``, ``"µs"``, ``"ms"``, ``"s"``.

    Returns:
        The median duration, converted to ``unit``, as a Python float.

    Raises:
        KeyError: If ``unit`` is not one of the supported units.
    """
    # Multipliers that convert nanoseconds into the requested unit.
    ns_to_unit = {"ns": 1, "µs": 1e-3, "ms": 1e-6, "s": 1e-9}
    scaled = times * ns_to_unit[unit]

    # NOTE: the name `med` is part of the output, because the `=` specifier
    # in the f-string below prints the expression text ("med = ...").
    med = scaled.median().item()
    spread = scaled.std().item()
    print(f"{med = :.2f}{unit} +- {spread:.2f}")
    return med
38+
39+
40+
# Directory in which the generated benchmark WAV files are stored.
WAV_DIR = "/tmp/wav_files"

# One row per benchmarked sample format:
#   (short name used in the file names,
#    ffmpeg audio codec passed to `-c:a`,
#    dtype passed to soundfile for the "native" read)
_FORMAT_ROWS = (
    ("u8", "pcm_u8", "int16"),
    ("s16", "pcm_s16le", "int16"),
    ("s24", "pcm_s24le", "int32"),
    ("s32", "pcm_s32le", "int32"),
    ("float32", "pcm_f32le", "float32"),
    ("float64", "pcm_f64le", "float64"),
)
FORMATS = {name: (codec, dtype) for name, codec, dtype in _FORMAT_ROWS}

# Clip label used in the file names -> clip length in seconds.
DURATIONS = {"10s": 10, "5min": 5 * 60}
55+
56+
57+
def generate_wav_files():
    """Create any missing benchmark WAV files under ``WAV_DIR`` using ffmpeg.

    One file is produced for every (format, duration) pair in ``FORMATS`` x
    ``DURATIONS``, each a 440 Hz sine tone encoded with that format's codec.
    Files that already exist are left untouched.

    Each file is first written to a temporary sibling and only moved to its
    final name once ffmpeg has succeeded. A failed or interrupted run
    therefore never leaves a truncated file that later runs would reuse.

    Raises:
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero
            status. Its captured stderr is printed before re-raising.
    """
    os.makedirs(WAV_DIR, exist_ok=True)
    for fmt_name, (codec, _) in FORMATS.items():
        for dur_name, dur_seconds in DURATIONS.items():
            path = os.path.join(WAV_DIR, f"{fmt_name}_{dur_name}.wav")
            if os.path.exists(path):
                continue
            print(f"Generating {path} ...")

            # Keep the ".wav" suffix so the ffmpeg invocation is otherwise
            # unchanged and the output container is still picked from the
            # file name.
            tmp_path = os.path.join(WAV_DIR, f"{fmt_name}_{dur_name}.tmp.wav")
            try:
                subprocess.run(
                    [
                        "ffmpeg",
                        "-y",
                        "-f",
                        "lavfi",
                        "-i",
                        f"sine=frequency=440:duration={dur_seconds}",
                        "-c:a",
                        codec,
                        tmp_path,
                    ],
                    check=True,
                    capture_output=True,
                )
                # Publish the finished file under its final name.
                os.replace(tmp_path, path)
            except subprocess.CalledProcessError as err:
                # stderr was captured, so show it; otherwise the reason for
                # the failure is invisible in the traceback.
                print(
                    f"ffmpeg failed for {path}:\n"
                    f"{err.stderr.decode(errors='replace')}"
                )
                raise
            finally:
                # Remove a leftover partial output (no-op after a successful
                # os.replace, since the temporary name no longer exists).
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
80+
81+
82+
def decode_torchcodec(raw_bytes):
    """Decode the WAV payload ``raw_bytes`` with ``WavDecoder``.

    Returns whatever ``WavDecoder.get_all_samples()`` returns for the input.
    """
    return WavDecoder(raw_bytes).get_all_samples()
85+
86+
87+
def decode_audio_decoder(raw_bytes):
    """Decode the payload ``raw_bytes`` with ``AudioDecoder``.

    Returns the ``data`` attribute of the object produced by
    ``AudioDecoder.get_all_samples()``.
    """
    all_samples = AudioDecoder(raw_bytes).get_all_samples()
    return all_samples.data
90+
91+
92+
def decode_soundfile(raw_bytes):
    """Decode the WAV payload ``raw_bytes`` with soundfile, requesting float32.

    Only the sample data is returned; the second value returned by
    ``sf.read`` is discarded.
    """
    stream = io.BytesIO(raw_bytes)
    samples, _ = sf.read(stream, dtype="float32")
    return samples
95+
96+
97+
def decode_soundfile_native(raw_bytes, native_dtype):
    """Decode the WAV payload ``raw_bytes`` with soundfile as ``native_dtype``.

    Args:
        raw_bytes: Full contents of a WAV file.
        native_dtype: dtype name forwarded to ``sf.read``.

    Only the sample data is returned; the second value returned by
    ``sf.read`` is discarded.
    """
    stream = io.BytesIO(raw_bytes)
    samples, _ = sf.read(stream, dtype=native_dtype)
    return samples
100+
101+
102+
def validate_results():
    """Check that WavDecoder and AudioDecoder agree on every benchmark file.

    Each (format, duration) file under ``WAV_DIR`` is decoded with both
    decoders and the results are compared with
    ``torch.testing.assert_close``, which raises on the first mismatch.
    """
    print("Validating WavDecoder vs AudioDecoder outputs...")

    cases = [(fmt, dur) for fmt in FORMATS for dur in DURATIONS]
    for fmt_name, dur_name in cases:
        wav_path = os.path.join(WAV_DIR, f"{fmt_name}_{dur_name}.wav")
        with open(wav_path, "rb") as wav_file:
            payload = wav_file.read()

        from_wav_decoder = decode_torchcodec(payload)
        from_audio_decoder = decode_audio_decoder(payload)

        # Zero absolute and relative tolerance: no numeric slack is allowed.
        torch.testing.assert_close(
            from_wav_decoder.data, from_audio_decoder.data, atol=0, rtol=0
        )
        print(f" {fmt_name}/{dur_name}: OK")

    print("All validations passed!\n")
114+
115+
116+
def main():
    """Benchmark the four decoding paths on every file and print a summary.

    Generates any missing WAV files and validates the decoders against each
    other. Then, for each (format, duration) file, it times WavDecoder,
    AudioDecoder, soundfile with float32 output and soundfile with the
    format's native dtype. Per-file medians are printed as they are measured,
    followed by a table that also shows each median relative to WavDecoder.
    """
    generate_wav_files()
    validate_results()

    results = []

    for fmt_name, (_, native_dtype) in FORMATS.items():
        for dur_name in DURATIONS:
            path = os.path.join(WAV_DIR, f"{fmt_name}_{dur_name}.wav")
            with open(path, "rb") as f:
                raw_bytes = f.read()
            # The long clips get fewer timed repetitions.
            num_exp = 10 if dur_name == "5min" else 100

            print(f"\n=== {fmt_name} / {dur_name} ({path}) ===")

            # (label printed before the stats, decode function, extra args)
            contenders = (
                (" WavDecoder: ", decode_torchcodec, ()),
                (" AudioDecoder: ", decode_audio_decoder, ()),
                (" soundfile (dtype=float32): ", decode_soundfile, ()),
                (
                    f" soundfile (dtype={native_dtype}): ",
                    decode_soundfile_native,
                    (native_dtype,),
                ),
            )

            medians = []
            for label, decode, extra_args in contenders:
                print(label, end="")
                timings = bench(
                    decode,
                    raw_bytes,
                    *extra_args,
                    num_exp=num_exp,
                    warmup=2,
                )
                medians.append(report_stats(timings))

            results.append((fmt_name, dur_name, *medians))

    rule = "=" * 155
    print("\n" + rule)
    print("SUMMARY")
    print(rule)
    print(
        f"{'format':<10} {'duration':<10} "
        f"{'WavDec (ms)':>12} {'AudioDec (ms)':>14} {'sndfile f32 (ms)':>17} {'sndfile native (ms)':>20} "
        f"{'AudioDec/WavDec':>16} {'sndfile f32/WavDec':>19} {'sndfile nat/WavDec':>19}"
    )
    print("-" * 155)

    def relative_to(baseline, value):
        # Report infinity rather than dividing by a non-positive baseline.
        return value / baseline if baseline > 0 else float("inf")

    for fmt_name, dur_name, wav_ms, audio_ms, sf_f32_ms, sf_native_ms in results:
        audio_ratio = relative_to(wav_ms, audio_ms)
        sf_f32_ratio = relative_to(wav_ms, sf_f32_ms)
        sf_native_ratio = relative_to(wav_ms, sf_native_ms)
        print(
            f"{fmt_name:<10} {dur_name:<10} "
            f"{wav_ms:>12.2f} {audio_ms:>14.2f} {sf_f32_ms:>17.2f} {sf_native_ms:>20.2f} "
            f"{audio_ratio:>15.2f}x {sf_f32_ratio:>18.2f}x {sf_native_ratio:>18.2f}x"
        )
173+
174+
175+
def bench_input_types():
    """Compare WavDecoder speed for path, file-object and bytes inputs.

    For every (format, duration) file under ``WAV_DIR``, times a full decode
    when ``WavDecoder`` is given the file path, an open binary file object,
    or the file contents as bytes. Prints one table row per file with the
    median time of each variant in milliseconds and the file-object and
    bytes medians relative to the path median.
    """

    def median_ms(fn, num_exp):
        # bench() yields nanoseconds; 1e-6 converts the median to ms.
        return bench(fn, num_exp=num_exp, warmup=2).median().item() * 1e-6

    print("\n" + "=" * 100)
    print("FILE vs FILE-LIKE vs BYTES")
    print("=" * 100)
    print(
        f"{'format':<10} {'duration':<10} "
        f"{'file (ms)':>12} {'file-like (ms)':>16} {'bytes (ms)':>12} "
        f"{'flike/file':>12} {'bytes/file':>12}"
    )
    print("-" * 100)

    for fmt_name in FORMATS:
        for dur_name in DURATIONS:
            path = os.path.join(WAV_DIR, f"{fmt_name}_{dur_name}.wav")
            # The long clips get fewer timed repetitions.
            num_exp = 10 if dur_name == "5min" else 100

            with open(path, "rb") as src:
                raw_bytes = src.read()

            # The closures below are called only within this iteration, so
            # they always see the current `path` and `raw_bytes`.
            def from_path():
                return WavDecoder(path).get_all_samples()

            def from_file_object():
                # Opening the file is part of what gets timed.
                with open(path, "rb") as handle:
                    return WavDecoder(handle).get_all_samples()

            def from_bytes():
                return WavDecoder(raw_bytes).get_all_samples()

            file_med, flike_med, bytes_med = [
                median_ms(variant, num_exp)
                for variant in (from_path, from_file_object, from_bytes)
            ]

            if file_med > 0:
                flike_ratio = flike_med / file_med
                bytes_ratio = bytes_med / file_med
            else:
                flike_ratio = float("inf")
                bytes_ratio = float("inf")

            print(
                f"{fmt_name:<10} {dur_name:<10} "
                f"{file_med:>12.2f} {flike_med:>16.2f} {bytes_med:>12.2f} "
                f"{flike_ratio:>11.2f}x {bytes_ratio:>11.2f}x"
            )
222+
223+
224+
# Script entry point: run the decoder comparison first, then the
# path / file-object / bytes input comparison.
if __name__ == "__main__":
    main()
    bench_input_types()

0 commit comments

Comments
 (0)