-
Notifications
You must be signed in to change notification settings - Fork 26
Expand file tree
/
Copy pathbuild.rs
More file actions
346 lines (307 loc) · 10.4 KB
/
Copy pathbuild.rs
File metadata and controls
346 lines (307 loc) · 10.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
/// Build-script entry point.
///
/// Steps, in order:
/// 1. Read `sidecar_abi_version.txt`, check that it parses as a `u32`, and
///    export it to the crate as `KRASIS_SIDECAR_ABI_VERSION`.
/// 2. Declare the custom `cfg` names this script may emit and the environment
///    variables it reads.
/// 3. Probe for libnuma and either link it or set `cfg(no_numa)`.
/// 4. Compile the decode, prefill and HQQ search CUDA kernels to PTX.
///
/// Every phase is timed and reported through `cargo:warning` lines.
///
/// # Panics
///
/// Panics if `sidecar_abi_version.txt` cannot be read or does not contain a
/// `u32`.
fn main() {
    let total_timer = BuildTimer::start("build.rs total");

    let sidecar_abi = std::fs::read_to_string("sidecar_abi_version.txt")
        .expect("sidecar_abi_version.txt is required")
        .trim()
        .to_string();
    // Validation only: the parsed number is discarded, the text is exported.
    sidecar_abi
        .parse::<u32>()
        .expect("sidecar_abi_version.txt must contain a u32 ABI version");
    println!("cargo:rerun-if-changed=sidecar_abi_version.txt");
    println!("cargo:rustc-env=KRASIS_SIDECAR_ABI_VERSION={sidecar_abi}");

    println!("cargo::rustc-check-cfg=cfg(no_numa)");
    println!("cargo::rustc-check-cfg=cfg(has_decode_kernels)");
    println!("cargo::rustc-check-cfg=cfg(has_prefill_kernels)");
    println!("cargo::rustc-check-cfg=cfg(has_hqq_search_kernels)");

    // Force rerun when env changes (e.g. CUDA_HOME)
    println!("cargo:rerun-if-env-changed=CUDA_HOME");
    println!("cargo:rerun-if-env-changed=CUDA_PATH");
    // The host-compiler override is read when nvcc is invoked. Without this
    // line, setting it after a failed nvcc run would not re-run the script.
    println!("cargo:rerun-if-env-changed=KRASIS_NVCC_CCBIN");

    // Probe for libnuma — link only if the library is found.
    // The runtime code (numa.rs) checks numa_available() and falls back
    // gracefully, but the linker needs -lnuma at build time if we use
    // extern "C" FFI declarations.
    //
    // When libnuma is NOT found (e.g. CI manylinux containers), we set
    // cfg(no_numa) so numa.rs can stub out the FFI calls.
    let has_numa = timed_value("probe libnuma", || probe_lib("numa"));
    if has_numa {
        println!("cargo:rustc-link-lib=numa");
    } else {
        println!("cargo:rustc-cfg=no_numa");
        println!("cargo:warning=libnuma not found — NUMA support disabled (will use fallback)");
    }

    // Compile CUDA decode kernels to PTX if nvcc is available.
    // The PTX is embedded as a string constant via include_str!.
    timed_phase("decode PTX", compile_cuda_kernels);

    // Compile CUDA prefill kernels to PTX (Rust prefill path).
    timed_phase("prefill PTX", compile_prefill_kernels);

    // Compile diagnostic HQQ search kernels to PTX.
    timed_phase("HQQ search PTX", compile_hqq_search_kernels);

    total_timer.finish();
}
/// Stopwatch for one named build phase.
///
/// Created with [`BuildTimer::start`]; the elapsed time is reported when
/// [`BuildTimer::finish`] is called. Dropping the timer without calling
/// `finish` reports nothing.
struct BuildTimer {
    /// Phase name used in the timing report.
    label: &'static str,
    /// Instant at which the phase began.
    start: std::time::Instant,
}

impl BuildTimer {
    /// Starts timing the phase called `label`.
    fn start(label: &'static str) -> Self {
        let start = std::time::Instant::now();
        BuildTimer { label, start }
    }

    /// Stops the timer and logs the elapsed time via `log_build_timing`.
    fn finish(self) {
        let BuildTimer { label, start } = self;
        log_build_timing(label, start.elapsed());
    }
}
/// Runs `f` once and reports how long it took under `label`.
///
/// If `f` panics, no timing is reported.
fn timed_phase<F: FnOnce()>(label: &'static str, f: F) {
    let stopwatch = BuildTimer::start(label);
    f();
    stopwatch.finish()
}
/// Runs `f` once, reports how long it took under `label`, and returns the
/// value `f` produced.
///
/// If `f` panics, no timing is reported.
fn timed_value<T, F: FnOnce() -> T>(label: &'static str, f: F) -> T {
    let stopwatch = BuildTimer::start(label);
    let produced = f();
    stopwatch.finish();
    produced
}
/// Emits one `cargo:warning` line recording how long a build phase took.
///
/// The line carries the phase label, the duration in whole milliseconds, and
/// the duration in seconds with three decimals. Double quotes in `label` are
/// replaced with single quotes so the quoted `phase="…"` field stays intact.
fn log_build_timing(label: &str, elapsed: std::time::Duration) {
    let phase: String = label
        .chars()
        .map(|ch| if ch == '"' { '\'' } else { ch })
        .collect();
    let millis = elapsed.as_millis();
    let seconds = elapsed.as_secs_f64();
    println!(
        "cargo:warning=KRASIS_BUILD_TIMING phase=\"{phase}\" duration_ms={millis} duration_s={seconds:.3}"
    );
}
/// Reports whether every output is at least as new as every input.
///
/// Returns `false` when `outputs` is empty or any output path does not exist.
/// Otherwise it compares the newest readable input mtime with the oldest
/// readable output mtime and returns `true` when the output side is not
/// older. Paths whose mtime cannot be read are skipped; a side with no
/// readable mtime at all is treated as the Unix epoch.
fn is_output_fresh(inputs: &[&str], outputs: &[&str]) -> bool {
    if outputs.is_empty() {
        return false;
    }
    for output in outputs {
        if !std::path::Path::new(output).exists() {
            return false;
        }
    }

    let mut newest_input: Option<std::time::SystemTime> = None;
    for input in inputs {
        if let Some(stamp) = file_mtime(input) {
            newest_input = Some(match newest_input {
                Some(current) if current >= stamp => current,
                _ => stamp,
            });
        }
    }

    let mut oldest_output: Option<std::time::SystemTime> = None;
    for output in outputs {
        if let Some(stamp) = file_mtime(output) {
            oldest_output = Some(match oldest_output {
                Some(current) if current <= stamp => current,
                _ => stamp,
            });
        }
    }

    let epoch = std::time::SystemTime::UNIX_EPOCH;
    newest_input.unwrap_or(epoch) <= oldest_output.unwrap_or(epoch)
}
/// Returns the last-modified time of `path`, or `None` if the metadata or the
/// modification time cannot be read.
fn file_mtime(path: &str) -> Option<std::time::SystemTime> {
    match std::fs::metadata(path) {
        Ok(meta) => meta.modified().ok(),
        Err(_) => None,
    }
}
/// Runs `cmd` to completion and logs the wall-clock time it took under
/// `label`.
///
/// The timing is logged whether or not the command could be started. The
/// result of `Command::status` is returned unchanged.
fn run_status_timed(
    mut cmd: std::process::Command,
    label: &str,
) -> Result<std::process::ExitStatus, std::io::Error> {
    let began = std::time::Instant::now();
    let outcome = cmd.status();
    let took = began.elapsed();
    log_build_timing(label, took);
    outcome
}
/// Builds the optional nvcc host-compiler override arguments.
///
/// Reads `KRASIS_NVCC_CCBIN`. When it holds a non-blank value, returns
/// `["-ccbin", <value>]` with surrounding whitespace removed. Returns an
/// empty vector when the variable is unset, not valid Unicode, or blank.
fn nvcc_host_compiler_args() -> Vec<String> {
    let Ok(raw) = std::env::var("KRASIS_NVCC_CCBIN") else {
        return Vec::new();
    };
    // Use the trimmed value for the argument as well as for the emptiness
    // check: stray whitespace or a trailing newline would otherwise reach
    // nvcc as part of the compiler path.
    let ccbin = raw.trim();
    if ccbin.is_empty() {
        Vec::new()
    } else {
        vec!["-ccbin".to_string(), ccbin.to_string()]
    }
}
/// Compiles `src/cuda/decode_kernels.cu` to `$OUT_DIR/decode_kernels.ptx`.
///
/// Sets `cfg(has_decode_kernels)` when a PTX file is available, either
/// freshly compiled or reused because it is not older than the source.
/// A missing source file, a missing nvcc, or a failed nvcc run only emits a
/// `cargo:warning` and leaves the cfg unset; the build itself continues.
///
/// # Panics
///
/// Panics if the `OUT_DIR` environment variable is not set.
fn compile_cuda_kernels() {
    let source = "src/cuda/decode_kernels.cu";
    println!("cargo:rerun-if-changed={source}");

    let source_present = std::path::Path::new(source).exists();
    if !source_present {
        println!("cargo:warning=decode_kernels.cu not found — GPU decode kernels disabled");
        return;
    }

    // Find nvcc
    let nvcc_bin = match find_nvcc() {
        Some(bin) => bin,
        None => {
            println!("cargo:warning=nvcc not found — GPU decode kernels disabled");
            return;
        }
    };

    let ptx_out = format!("{}/decode_kernels.ptx", std::env::var("OUT_DIR").unwrap());

    // Skip the nvcc run when the existing PTX is not older than the source.
    if is_output_fresh(&[source], &[ptx_out.as_str()]) {
        println!("cargo:rustc-cfg=has_decode_kernels");
        println!("cargo:warning=Reusing cached GPU decode kernels at {ptx_out}");
        return;
    }

    // Compile .cu to .ptx targeting sm_80 (works on Ampere, Ada, Hopper)
    let mut nvcc_cmd = std::process::Command::new(&nvcc_bin);
    nvcc_cmd
        .arg("-ptx")
        .arg("-allow-unsupported-compiler")
        .arg("-arch=sm_80")
        .arg("-O3")
        .arg("--use_fast_math")
        .arg("-o")
        .arg(&ptx_out)
        .arg(source)
        .args(nvcc_host_compiler_args());

    match run_status_timed(nvcc_cmd, "nvcc decode PTX compile") {
        Err(err) => {
            println!("cargo:warning=nvcc execution error: {err} — GPU decode kernels disabled");
        }
        Ok(exit) if !exit.success() => {
            println!("cargo:warning=nvcc failed with status {exit} — GPU decode kernels disabled");
        }
        Ok(_) => {
            println!("cargo:rustc-cfg=has_decode_kernels");
            println!("cargo:warning=Compiled GPU decode kernels to PTX ({ptx_out})");
        }
    }
}
/// Compiles `src/cuda/prefill_kernels.cu` to `$OUT_DIR/prefill_kernels.ptx`.
///
/// `src/cuda/prefill_shim.h` is tracked for rebuilds and for the freshness
/// check but is not passed to nvcc directly (presumably the `.cu` file
/// includes it — confirm against the source).
///
/// Sets `cfg(has_prefill_kernels)` when a PTX file is available, either
/// freshly compiled or reused from a previous run. A missing source file, a
/// missing nvcc, or a failed nvcc run only emits a `cargo:warning` and leaves
/// the cfg unset.
///
/// # Panics
///
/// Panics if the `OUT_DIR` environment variable is not set.
fn compile_prefill_kernels() {
    let source = "src/cuda/prefill_kernels.cu";
    let shim = "src/cuda/prefill_shim.h";
    for tracked in [source, shim] {
        println!("cargo:rerun-if-changed={tracked}");
    }

    let source_present = std::path::Path::new(source).exists();
    if !source_present {
        println!("cargo:warning=prefill_kernels.cu not found — GPU prefill kernels disabled");
        return;
    }

    let nvcc_bin = match find_nvcc() {
        Some(bin) => bin,
        None => {
            println!("cargo:warning=nvcc not found — GPU prefill kernels disabled");
            return;
        }
    };

    let ptx_out = format!("{}/prefill_kernels.ptx", std::env::var("OUT_DIR").unwrap());

    // Skip the nvcc run when the existing PTX is not older than both inputs.
    if is_output_fresh(&[source, shim], &[ptx_out.as_str()]) {
        println!("cargo:rustc-cfg=has_prefill_kernels");
        println!("cargo:warning=Reusing cached GPU prefill kernels at {ptx_out}");
        return;
    }

    let mut nvcc_cmd = std::process::Command::new(&nvcc_bin);
    nvcc_cmd
        .arg("-ptx")
        .arg("-allow-unsupported-compiler")
        .arg("-arch=sm_80")
        .arg("-O3")
        .arg("--use_fast_math")
        .arg("-o")
        .arg(&ptx_out)
        .arg(source)
        .args(nvcc_host_compiler_args());

    match run_status_timed(nvcc_cmd, "nvcc prefill PTX compile") {
        Err(err) => {
            println!("cargo:warning=nvcc execution error: {err} — GPU prefill kernels disabled");
        }
        Ok(exit) if !exit.success() => {
            println!("cargo:warning=nvcc failed with status {exit} — GPU prefill kernels disabled");
        }
        Ok(_) => {
            println!("cargo:rustc-cfg=has_prefill_kernels");
            println!("cargo:warning=Compiled GPU prefill kernels to PTX ({ptx_out})");
        }
    }
}
/// Compiles `src/cuda/hqq_search_kernels.cu` to
/// `$OUT_DIR/hqq_search_kernels.ptx`.
///
/// Sets `cfg(has_hqq_search_kernels)` when a PTX file is available, either
/// freshly compiled or reused because it is not older than the source.
/// A missing source file, a missing nvcc, or a failed nvcc run only emits a
/// `cargo:warning` and leaves the cfg unset.
///
/// # Panics
///
/// Panics if the `OUT_DIR` environment variable is not set.
fn compile_hqq_search_kernels() {
    let source = "src/cuda/hqq_search_kernels.cu";
    println!("cargo:rerun-if-changed={source}");

    let source_present = std::path::Path::new(source).exists();
    if !source_present {
        println!("cargo:warning=hqq_search_kernels.cu not found — HQQ CUDA search disabled");
        return;
    }

    let nvcc_bin = match find_nvcc() {
        Some(bin) => bin,
        None => {
            println!("cargo:warning=nvcc not found — HQQ CUDA search disabled");
            return;
        }
    };

    let ptx_out = format!(
        "{}/hqq_search_kernels.ptx",
        std::env::var("OUT_DIR").unwrap()
    );

    // Skip the nvcc run when the existing PTX is not older than the source.
    if is_output_fresh(&[source], &[ptx_out.as_str()]) {
        println!("cargo:rustc-cfg=has_hqq_search_kernels");
        println!("cargo:warning=Reusing cached HQQ CUDA search kernels at {ptx_out}");
        return;
    }

    let mut nvcc_cmd = std::process::Command::new(&nvcc_bin);
    nvcc_cmd
        .arg("-ptx")
        .arg("-allow-unsupported-compiler")
        .arg("-arch=sm_80")
        .arg("-O3")
        .arg("--use_fast_math")
        .arg("-o")
        .arg(&ptx_out)
        .arg(source)
        .args(nvcc_host_compiler_args());

    match run_status_timed(nvcc_cmd, "nvcc HQQ search PTX compile") {
        Err(err) => {
            println!("cargo:warning=nvcc execution error: {err} — HQQ CUDA search disabled");
        }
        Ok(exit) if !exit.success() => {
            println!("cargo:warning=nvcc failed with status {exit} — HQQ CUDA search disabled");
        }
        Ok(_) => {
            println!("cargo:rustc-cfg=has_hqq_search_kernels");
            println!("cargo:warning=Compiled HQQ CUDA search kernels to PTX ({ptx_out})");
        }
    }
}
/// Locates an `nvcc` executable.
///
/// Search order, first hit wins:
/// 1. `$CUDA_HOME/bin/nvcc`, then `$CUDA_PATH/bin/nvcc`, if the file exists.
/// 2. A short list of conventional install paths under `/usr/local`.
/// 3. Plain `nvcc`, if running `nvcc --version` can be spawned at all (its
///    exit status is not inspected).
///
/// Returns the path (or the bare command name) to invoke, or `None` when
/// nothing was found.
fn find_nvcc() -> Option<String> {
    // Check CUDA_HOME / CUDA_PATH
    let from_env = ["CUDA_HOME", "CUDA_PATH"]
        .iter()
        .filter_map(|var| std::env::var(var).ok())
        .map(|cuda_root| format!("{cuda_root}/bin/nvcc"))
        .find(|candidate| std::path::Path::new(candidate).exists());
    if from_env.is_some() {
        return from_env;
    }

    // Check common paths
    let well_known = [
        "/usr/local/cuda/bin/nvcc",
        "/usr/local/cuda-12.6/bin/nvcc",
        "/usr/local/cuda-12/bin/nvcc",
    ];
    let installed = well_known
        .iter()
        .copied()
        .find(|candidate| std::path::Path::new(candidate).exists());
    if let Some(hit) = installed {
        return Some(String::from(hit));
    }

    // Try PATH
    let spawnable = std::process::Command::new("nvcc")
        .arg("--version")
        .output()
        .is_ok();
    if spawnable {
        Some(String::from("nvcc"))
    } else {
        None
    }
}
/// Reports whether the shared library `lib{name}.so` appears to be installed.
///
/// First looks for the file in a few conventional library directories. If it
/// is not there, falls back to `pkg-config --exists {name}`. Nothing is
/// compiled or linked. A `pkg-config` that is missing or cannot be run
/// counts as "not found".
fn probe_lib(name: &str) -> bool {
    // Quick check: see if the lib exists in common paths
    let on_disk = ["/usr/lib", "/usr/lib64", "/usr/lib/x86_64-linux-gnu"]
        .iter()
        .any(|dir| std::path::Path::new(&format!("{dir}/lib{name}.so")).exists());
    if on_disk {
        return true;
    }

    // Try pkg-config as fallback
    match std::process::Command::new("pkg-config")
        .arg("--exists")
        .arg(name)
        .status()
    {
        Ok(status) => status.success(),
        Err(_) => false,
    }
}