Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions candle-examples/build.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
#![allow(unused)]
use anyhow::{Context, Result};
use std::env;
use std::io::Write;
use std::path::PathBuf;
use std::path::{Path, PathBuf};

struct KernelDirectories {
kernel_glob: &'static str,
Expand All @@ -20,11 +21,21 @@ fn main() -> Result<()> {

#[cfg(feature = "cuda")]
{
// Added: Get the safe output directory from the environment.
let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());

for kdir in KERNEL_DIRS.iter() {
let builder = bindgen_cuda::Builder::default().kernel_paths_glob(kdir.kernel_glob);
println!("cargo:info={builder:?}");
let bindings = builder.build_ptx().unwrap();
bindings.write(kdir.rust_target).unwrap()

// Changed: This now writes to a safe path inside $OUT_DIR.
let safe_target = out_dir.join(
Path::new(kdir.rust_target)
.file_name()
.context("Failed to get filename from rust_target")?,
);
bindings.write(safe_target).unwrap()
}
}
Ok(())
Expand Down
4 changes: 3 additions & 1 deletion candle-examples/examples/custom-ops/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ extern crate intel_mkl_src;

#[rustfmt::skip]
#[cfg(feature = "cuda")]
mod cuda_kernels;
mod cuda_kernels {
include!(concat!(env!("OUT_DIR"), "/cuda_kernels.rs"));
}

use clap::Parser;

Expand Down
4 changes: 3 additions & 1 deletion candle-kernels/build.rs

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This file is missing some imports:

use std::env;
use std::path::PathBuf;

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is soo stupid of me :/

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Happens to me more times than I dare to admit :)

Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ fn main() {
println!("cargo:rerun-if-changed=src/cuda_utils.cuh");
println!("cargo:rerun-if-changed=src/binary_op_macros.cuh");

let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap());
let ptx_path = out_dir.join("ptx.rs");
let builder = bindgen_cuda::Builder::default();
println!("cargo:info={builder:?}");
let bindings = builder.build_ptx().unwrap();
bindings.write("src/ptx.rs").unwrap();
bindings.write(ptx_path).unwrap();
}
4 changes: 3 additions & 1 deletion candle-kernels/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
mod ptx;
mod ptx {
include!(concat!(env!("OUT_DIR"), "/ptx.rs"));
}

#[repr(u32)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
Expand Down
Loading