Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 35 additions & 2 deletions rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
use once_cell::sync::OnceCell;
use std::ffi::{CStr, CString};
use std::mem;
use std::sync::atomic::{AtomicUsize, Ordering::SeqCst};
use std::sync::atomic::{AtomicUsize, Ordering::Relaxed};

/// Equivalent of the `COZ_PROGRESS` and `COZ_PROGRESS_NAMED` macros
///
Expand Down Expand Up @@ -158,7 +158,8 @@ impl Counter {
mem::size_of_val(&counter.count),
mem::size_of::<libc::size_t>()
);
counter.count.fetch_add(1, SeqCst);
counter.count.fetch_add(1, Relaxed);
coz_add_delays();
}
}

Expand Down Expand Up @@ -203,6 +204,11 @@ struct coz_counter_t {
/// `typedef coz_counter_t* (*coz_get_counter_t)(int, const char*);`
type GetCounterFn = unsafe extern "C" fn(libc::c_int, *const libc::c_char) -> *mut coz_counter_t;

/// The type of `_coz_add_delays` as defined in `include/coz.h`
///
/// `typedef void (*coz_add_delays_t)(void);`
type AddDelaysFn = unsafe extern "C" fn();

#[cfg(target_os = "linux")]
fn coz_get_counter(ty: libc::c_int, name: &CStr) -> Option<*mut coz_counter_t> {
static GET_COUNTER: OnceCell<Option<GetCounterFn>> = OnceCell::new();
Expand All @@ -226,7 +232,34 @@ fn coz_get_counter(ty: libc::c_int, name: &CStr) -> Option<*mut coz_counter_t> {
func.map(|f| unsafe { f(ty, name.as_ptr()) })
}

/// Calls `_coz_add_delays()` from libcoz.
///
/// This must be called after every counter increment to allow the profiler to
/// inject virtual delays for causal profiling experiments. Without this call,
/// the profiler cannot detect progress points and will report 0 experiments.
#[cfg(target_os = "linux")]
fn coz_add_delays() {
static ADD_DELAYS: OnceCell<Option<AddDelaysFn>> = OnceCell::new();
let func = ADD_DELAYS.get_or_init(|| {
let name = CStr::from_bytes_with_nul(b"_coz_add_delays\0").unwrap();
let func = unsafe { libc::dlsym(libc::RTLD_DEFAULT, name.as_ptr()) };
if func.is_null() {
None
} else {
Some(unsafe { mem::transmute(func) })
}
});

if let Some(f) = func {
// SAFETY: _coz_add_delays is a void->void function with no invariants.
unsafe { f() };
}
}

#[cfg(not(target_os = "linux"))]
fn coz_get_counter(_ty: libc::c_int, _name: &CStr) -> Option<*mut coz_counter_t> {
None
}

#[cfg(not(target_os = "linux"))]
fn coz_add_delays() {}