Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "dep-graph"
version = "0.2.0"
version = "0.2.1"
authors = ["Nicolas Moutschen <nicolas.moutschen@gmail.com>"]
edition = "2018"
license = "MIT"
Expand All @@ -12,17 +12,16 @@ description = "Dependency graph resolver library"

[features]
default = ["parallel"]

parallel = ["rayon", "crossbeam-channel"]

[dev-dependencies]
criterion = "0.3"
criterion = "0.8"

[[bench]]
name = "dep_graph"
harness = false

[dependencies]
crossbeam-channel = { version = "0.4", optional = true }
rayon = { version = "1.5", optional = true }
crossbeam-channel = { version = "0.5", optional = true }
rayon = { version = "1.11", optional = true }
num_cpus = "1.13"
7 changes: 4 additions & 3 deletions benches/dep_graph.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use criterion::{criterion_group, criterion_main, Criterion};
use dep_graph::{DepGraph, Node};
#[cfg(feature = "parallel")]
use rayon::prelude::*;
use std::hint::black_box;
use std::thread;
use std::time::Duration;

Expand Down Expand Up @@ -30,12 +31,12 @@ fn add_layer(index: usize, count: usize) -> Vec<Node<String>> {
pub fn parallel_benchmark(c: &mut Criterion) {
const NUM_LAYERS: usize = 20;
#[cfg(feature = "parallel")]
fn par_no_op(nodes: &Vec<Node<String>>) {
fn par_no_op(nodes: &[Node<String>]) {
DepGraph::new(nodes)
.into_par_iter()
.for_each(|_node| thread::sleep(Duration::from_nanos(100)))
}
fn seq_no_op(nodes: &Vec<Node<String>>) {
fn seq_no_op(nodes: &[Node<String>]) {
DepGraph::new(nodes)
.into_iter()
.for_each(|_node| thread::sleep(Duration::from_nanos(100)))
Expand Down
84 changes: 47 additions & 37 deletions src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ pub type InnerDependencyMap<I> = HashMap<I, HashSet<I>>;
pub type DependencyMap<I> = Arc<RwLock<InnerDependencyMap<I>>>;

/// Dependency graph
#[derive(Debug, Default)]
pub struct DepGraph<I>
where
I: Clone + fmt::Debug + Eq + Hash + PartialEq + Send + Sync + 'static,
Expand Down Expand Up @@ -43,16 +44,12 @@ where

if node.deps().is_empty() {
ready_nodes.push(node.id().clone());
}

for node_dep in node.deps() {
if !rdeps.contains_key(node_dep) {
let mut dep_rdeps = HashSet::new();
dep_rdeps.insert(node.id().clone());
rdeps.insert(node_dep.clone(), dep_rdeps.clone());
} else {
let dep_rdeps = rdeps.get_mut(node_dep).unwrap();
dep_rdeps.insert(node.id().clone());
} else {
for node_dep in node.deps() {
rdeps
.entry(node_dep.clone())
.or_default()
.insert(node.id().clone());
}
}
}
Expand All @@ -65,6 +62,20 @@ where
}
}

/// Deep-copies a dependency graph.
///
/// `deps` and `rdeps` are stored behind `Arc<RwLock<..>>`, so cloning the `Arc`
/// handles alone would make both graphs share the same underlying maps. This
/// implementation copies the inner maps into fresh `Arc<RwLock<..>>` wrappers
/// instead, so the clone is fully independent of the original.
///
/// # Panics
///
/// Panics if the `deps` or `rdeps` lock is poisoned.
impl<I> Clone for DepGraph<I>
where
    I: Clone + fmt::Debug + Eq + Hash + PartialEq + Send + Sync + 'static,
{
    fn clone(&self) -> Self {
        Self {
            ready_nodes: self.ready_nodes.clone(),
            // Clone the inner maps (not just the `Arc` handles) so that a new
            // iteration can be started on the copy without sharing state with
            // the original graph.
            deps: Arc::new(RwLock::new(self.deps.read().unwrap().clone())),
            rdeps: Arc::new(RwLock::new(self.rdeps.read().unwrap().clone())),
        }
    }
}

impl<I> IntoIterator for DepGraph<I>
where
I: Clone + fmt::Debug + Eq + Hash + PartialEq + Send + Sync + 'static,
Expand All @@ -73,7 +84,7 @@ where
type IntoIter = DepGraphIter<I>;

fn into_iter(self) -> Self::IntoIter {
DepGraphIter::<I>::new(self.ready_nodes.clone(), self.deps.clone(), self.rdeps)
DepGraphIter::<I>::new(self.ready_nodes, self.deps, self.rdeps)
}
}

Expand Down Expand Up @@ -133,33 +144,32 @@ pub fn remove_node_id<I>(
where
I: Clone + fmt::Debug + Eq + Hash + PartialEq + Send + Sync + 'static,
{
let rdep_ids = {
match rdeps.read().unwrap().get(&id) {
Some(node) => node.clone(),
// If no node depends on a node, it will not appear
// in rdeps.
None => Default::default(),
}
};

let mut deps = deps.write().unwrap();
let next_nodes = rdep_ids
.iter()
.filter_map(|rdep_id| {
let rdep = match deps.get_mut(&rdep_id) {
Some(rdep) => rdep,
None => return None,
};

rdep.remove(&id);

if rdep.is_empty() {
Some(rdep_id.clone())
} else {
None
}
})
.collect();

let next_nodes = if let Some(rdep_ids) = rdeps.read().unwrap().get(&id) {
let next_nodes = rdep_ids
.iter()
.filter_map(|rdep_id| {
let rdep = match deps.get_mut(rdep_id) {
Some(rdep) => rdep,
None => return None,
};

rdep.remove(&id);

if rdep.is_empty() {
Some(rdep_id.clone())
} else {
None
}
})
.collect();

next_nodes
} else {
// If no node depends on a node, it will not appear in rdeps.
vec![]
};

// Remove the current node from the list of dependencies.
deps.remove(&id);
Expand Down
91 changes: 70 additions & 21 deletions src/graph_par.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc, RwLock,
};
use std::thread;
use std::thread::{self, JoinHandle};
use std::time::Duration;

/// Default timeout in milliseconds
Expand Down Expand Up @@ -89,9 +89,12 @@ where
/// This will decrement the processing counter and notify the dispatcher thread.
fn drop(&mut self) {
(*self.counter).fetch_sub(1, Ordering::SeqCst);
self.item_done_tx
.send(self.inner.clone())
.expect("could not send message")
// Ignore send errors - the receiver may have been dropped.
// Drop implementations should not panic per Rust best practices,
// as panicking during unwinding from another panic will abort the program.
// See: https://doc.rust-lang.org/std/ops/trait.Drop.html
// See: https://github.com/nmoutschen/dep-graph/issues/3
let _ = self.item_done_tx.send(self.inner.clone());
}
}

Expand Down Expand Up @@ -151,6 +154,9 @@ where
counter: Arc<AtomicUsize>,
item_ready_rx: Receiver<I>,
item_done_tx: Sender<I>,
dispatcher_thread: JoinHandle<Result<(), Error>>,
/// Total number of nodes in the graph (used for rayon's len() hint)
total_nodes: usize,
}

impl<I> DepGraphParIter<I>
Expand All @@ -165,33 +171,50 @@ where
let timeout = Arc::new(RwLock::new(DEFAULT_TIMEOUT));
let counter = Arc::new(AtomicUsize::new(0));

// Capture total node count before moving deps to the dispatcher thread.
// This is used by IndexedParallelIterator::len() to give rayon an accurate
// hint about parallelism, preventing it from spawning more workers than items.
let total_nodes = deps.read().unwrap().len();

// Create communication channel for processed nodes
let (item_ready_tx, item_ready_rx) = crossbeam_channel::unbounded::<I>();
let (item_done_tx, item_done_rx) = crossbeam_channel::unbounded::<I>();

// Track items in flight: dispatched but not yet completed.
// This is more reliable than checking counter + pending_items because
// there's a race window between recv() and Wrapper::new() where neither
// counter nor pending_items accounts for the item.
let mut in_flight: usize = 0;

// Inject ready nodes
ready_nodes
.iter()
.for_each(|node| item_ready_tx.send(node.clone()).unwrap());
for node in ready_nodes.into_iter() {
item_ready_tx
.send(node)
.unwrap_or_else(|err| panic!("could not send message: {}", err));
in_flight += 1;
}

// Clone Arcs for dispatcher thread
let loop_timeout = timeout.clone();
let loop_counter = counter.clone();

// Start dispatcher thread
thread::spawn(move || {
let dispatcher_thread = thread::spawn(move || {
loop {
crossbeam_channel::select! {
// Grab a processed node ID
recv(item_done_rx) -> id => {
let id = id.unwrap();
in_flight -= 1;

// Remove the node from all reverse dependencies
let next_nodes = remove_node_id::<I>(id, &deps, &rdeps)?;

// Send the next available nodes to the channel.
next_nodes
.iter()
.for_each(|node_id| item_ready_tx.send(node_id.clone()).unwrap());
for node_id in next_nodes.into_iter() {
item_ready_tx.send(node_id)
.unwrap_or_else(|err| panic!("could not send message: {}", err));
in_flight += 1;
}

// If there are no more nodes, leave the loop
if deps.read().unwrap().is_empty() {
Expand All @@ -201,11 +224,12 @@ where
// Timeout
default(*loop_timeout.read().unwrap()) => {
let deps = deps.read().unwrap();
let counter_val = loop_counter.load(Ordering::SeqCst);
if deps.is_empty() {
break;
// There are still some items processing.
} else if counter_val > 0 {
// There are still items in flight (dispatched but not completed).
// This properly handles the race window between recv() and Wrapper::new().
// See: https://github.com/nmoutschen/dep-graph/issues/3
} else if in_flight > 0 {
continue;
} else {
return Err(Error::ResolveGraphError("circular dependency detected"));
Expand All @@ -223,9 +247,10 @@ where
DepGraphParIter {
timeout,
counter,

item_ready_rx,
item_done_tx,
dispatcher_thread,
total_nodes,
}
}

Expand Down Expand Up @@ -254,7 +279,16 @@ where
I: Clone + fmt::Debug + Eq + Hash + PartialEq + Send + Sync + 'static,
{
fn len(&self) -> usize {
num_cpus::get()
// Return the minimum of node count and CPU count.
//
// - If len() > total_nodes: rayon spawns excess workers that block
// forever on recv(), causing deadlocks (seen on ARM with 8+ cores).
// - If len() > num_cpus: rayon's work distribution becomes inefficient,
// causing severe slowdowns (seen on ubuntu-latest with 2 cores).
//
// Using min() handles both cases: it prevents deadlocks while keeping
// rayon's work-stealing scheduler efficient.
std::cmp::min(self.total_nodes, num_cpus::get())
}

fn drive<C>(self, consumer: C) -> C::Result
Expand All @@ -269,9 +303,10 @@ where
CB: ProducerCallback<Self::Item>,
{
callback.callback(DepGraphProducer {
counter: self.counter.clone(),
counter: self.counter,
item_ready_rx: self.item_ready_rx,
item_done_tx: self.item_done_tx,
dispatcher_thread: Some(self.dispatcher_thread),
})
}
}
Expand All @@ -283,6 +318,7 @@ where
counter: Arc<AtomicUsize>,
item_ready_rx: Receiver<I>,
item_done_tx: Sender<I>,
dispatcher_thread: Option<JoinHandle<Result<(), Error>>>,
}

impl<I> Iterator for DepGraphProducer<I>
Expand All @@ -299,7 +335,17 @@ where
self.counter.clone(),
self.item_done_tx.clone(),
)),
Err(_) => None,
Err(_) => {
// Wait for dispatcher thread to finish and report any error
if let Some(thread) = self.dispatcher_thread.take() {
thread
.join()
.unwrap_or_else(|err| panic!("could not join thread: {:?}", err))
.unwrap();
}

None
}
}
}
}
Expand Down Expand Up @@ -327,9 +373,10 @@ where

fn into_iter(self) -> Self::IntoIter {
Self {
counter: self.counter.clone(),
item_ready_rx: self.item_ready_rx.clone(),
counter: self.counter,
item_ready_rx: self.item_ready_rx,
item_done_tx: self.item_done_tx,
dispatcher_thread: self.dispatcher_thread,
}
}

Expand All @@ -339,11 +386,13 @@ where
counter: self.counter.clone(),
item_ready_rx: self.item_ready_rx.clone(),
item_done_tx: self.item_done_tx.clone(),
dispatcher_thread: self.dispatcher_thread,
},
Self {
counter: self.counter.clone(),
item_ready_rx: self.item_ready_rx.clone(),
item_done_tx: self.item_done_tx,
dispatcher_thread: None,
},
)
}
Expand Down
Loading