Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion core-relations/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ pub use table_spec::{
ColumnId, Constraint, MutationBuffer, Offset, Rebuilder, Row, Table, TableChange, TableSpec,
TableVersion, WrappedTable,
};
pub use uf::{DisplacedTable, DisplacedTableWithProvenance, ProofReason, ProofStep};
pub use uf::{DisplacedTable, DisplacedTableWithProvenance, LeaderChange, ProofReason, ProofStep};

use egglog_numeric_id as numeric_id;
use egglog_union_find as union_find;
83 changes: 76 additions & 7 deletions core-relations/src/uf/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,31 @@ mod tests;

type UnionFind = crate::union_find::UnionFind<Value>;

/// A callback that runs every time a leader change takes effect. See the documentation for
/// [`LeaderChange`] for the information that is provided.
type LeaderChangeCallback = Arc<dyn Fn(&mut ExecutionState, LeaderChange) + Send + Sync>;

/// Details for a leader change caused by a union.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct LeaderChange {
/// The lhs value provided to the write.
pub write_lhs: Value,
/// The leader of the lhs equivalence class before the union.
pub lhs_leader: Value,
/// The rhs value provided to the write.
pub write_rhs: Value,
/// The leader of the rhs equivalence class before the union.
pub rhs_leader: Value,
/// The timestamp associated with the write that triggered the union.
pub ts: Value,
}

impl LeaderChange {
pub fn new_leader(&self) -> Value {
std::cmp::min(self.lhs_leader, self.rhs_leader)
}
}

/// A special table backed by a union-find used to efficiently implement
/// egglog-style canonicaliztion.
///
Expand Down Expand Up @@ -59,6 +84,8 @@ pub struct DisplacedTable {
changed: bool,
lookup_table: HashMap<Value, RowId>,
buffered_writes: Arc<SegQueue<RowBuffer>>,
/// Stored as Arc so DisplacedTable cloning preserves the callback.
on_leader_change: Option<LeaderChangeCallback>,
}

struct Canonicalizer<'a> {
Expand Down Expand Up @@ -205,6 +232,7 @@ impl Default for DisplacedTable {
changed: false,
lookup_table: HashMap::default(),
buffered_writes: Arc::new(SegQueue::new()),
on_leader_change: None,
}
}
}
Expand All @@ -217,6 +245,7 @@ impl Clone for DisplacedTable {
changed: self.changed,
lookup_table: self.lookup_table.clone(),
buffered_writes: Default::default(),
on_leader_change: self.on_leader_change.clone(),
}
}
}
Expand Down Expand Up @@ -445,10 +474,10 @@ impl Table for DisplacedTable {
})
}

fn merge(&mut self, _: &mut ExecutionState) -> TableChange {
fn merge(&mut self, exec_state: &mut ExecutionState) -> TableChange {
while let Some(rowbuf) = self.buffered_writes.pop() {
for row in rowbuf.iter() {
self.changed |= self.insert_impl(row).is_some();
self.changed |= self.insert_impl(row, exec_state).is_some();
}
}
let changed = mem::take(&mut self.changed);
Expand All @@ -462,6 +491,28 @@ impl Table for DisplacedTable {
}

impl DisplacedTable {
/// Construct with a leader-change callback.
pub fn with_leader_change_callback<F>(callback: F) -> Self
where
F: Fn(&mut ExecutionState, LeaderChange) + Send + Sync + 'static,
{
Self {
on_leader_change: Some(Arc::new(callback)),
..Self::default()
}
}

pub fn set_leader_change_callback<F>(&mut self, callback: F)
where
F: Fn(&mut ExecutionState, LeaderChange) + Send + Sync + 'static,
{
self.on_leader_change = Some(Arc::new(callback));
}

pub fn clear_leader_change_callback(&mut self) {
self.on_leader_change = None;
}

pub fn underlying_uf(&self) -> &UnionFind {
&self.uf
}
Expand All @@ -488,9 +539,15 @@ impl DisplacedTable {
let vals = self.expand(row);
eval_constraint(&vals, constraint)
}
fn insert_impl(&mut self, row: &[Value]) -> Option<(Value, Value)> {
fn insert_impl(
&mut self,
row: &[Value],
exec_state: &mut ExecutionState,
) -> Option<(Value, Value)> {
assert_eq!(row.len(), 3, "attempt to insert a row with the wrong arity");
if self.uf.find(row[0]) == self.uf.find(row[1]) {
let lhs_leader = self.uf.find(row[0]);
let rhs_leader = self.uf.find(row[1]);
if lhs_leader == rhs_leader {
return None;
}
let (parent, child) = self.uf.union(row[0], row[1]);
Expand All @@ -499,6 +556,18 @@ impl DisplacedTable {
let _ = self.uf.find(parent);
let _ = self.uf.find(child);
let ts = row[2];
if let Some(callback) = &self.on_leader_change {
callback(
exec_state,
LeaderChange {
write_lhs: row[0],
lhs_leader,
write_rhs: row[1],
rhs_leader,
ts,
},
);
}
if let Some((_, highest)) = self.displaced.last() {
assert!(
*highest <= ts,
Expand Down Expand Up @@ -723,11 +792,11 @@ impl DisplacedTableWithProvenance {
.or_insert_with(|| self.proof_graph.add_node(val))
}

fn insert_impl(&mut self, row: &[Value]) {
fn insert_impl(&mut self, row: &[Value], exec_state: &mut ExecutionState) {
let [a, b, ts, reason] = row else {
panic!("attempt to insert a row with the wrong arity ({row:?})");
};
match self.base.insert_impl(&[*a, *b, *ts]) {
match self.base.insert_impl(&[*a, *b, *ts], exec_state) {
Some((parent, child)) => {
self.displaced.push((child, parent));
self.context
Expand Down Expand Up @@ -810,7 +879,7 @@ impl Table for DisplacedTableWithProvenance {
fn merge(&mut self, exec_state: &mut ExecutionState) -> TableChange {
while let Some(rowbuf) = self.buffered_writes.pop() {
for row in rowbuf.iter() {
self.insert_impl(row);
self.insert_impl(row, exec_state);
}
}

Expand Down
48 changes: 47 additions & 1 deletion core-relations/src/uf/tests.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::sync::{Arc, Mutex};

use crate::numeric_id::NumericId;

use crate::{
Expand All @@ -7,7 +9,7 @@ use crate::{
uf::ProofReason,
};

use super::DisplacedTable;
use super::{DisplacedTable, LeaderChange};

fn v(x: usize) -> Value {
Value::from_usize(x)
Expand Down Expand Up @@ -98,3 +100,47 @@ fn displaced_proof() {
]
)
}

#[test]
fn displaced_leader_change_callback() {
empty_execution_state!(e);
let changes: Arc<Mutex<Vec<LeaderChange>>> = Arc::new(Mutex::new(Vec::new()));
let changes_ref = Arc::clone(&changes);
let mut d = DisplacedTable::with_leader_change_callback(move |_, change| {
changes_ref.lock().unwrap().push(change);
});
{
let mut buf = d.new_buffer();
buf.stage_insert(&[v(5), v(3), v(0)]);
buf.stage_insert(&[v(5), v(3), v(1)]);
}
d.merge(&mut e);

{
let changes = changes.lock().unwrap();
assert_eq!(changes.len(), 1);
let change = changes[0];
assert_eq!(change.write_lhs, v(5));
assert_eq!(change.lhs_leader, v(5));
assert_eq!(change.write_rhs, v(3));
assert_eq!(change.rhs_leader, v(3));
assert_eq!(change.ts, v(0));
assert_eq!(change.new_leader(), v(3));
}

{
let mut buf = d.new_buffer();
buf.stage_insert(&[v(5), v(2), v(2)]);
}
d.merge(&mut e);

let changes = changes.lock().unwrap();
assert_eq!(changes.len(), 2);
let change = changes[1];
assert_eq!(change.write_lhs, v(5));
assert_eq!(change.lhs_leader, v(3));
assert_eq!(change.write_rhs, v(2));
assert_eq!(change.rhs_leader, v(2));
assert_eq!(change.ts, v(2));
assert_eq!(change.new_leader(), v(2));
}
42 changes: 27 additions & 15 deletions egglog-bridge/examples/ac.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
use egglog_bridge::{ColumnTy, DefaultVal, EGraph, FunctionConfig, MergeFn, define_rule};
use egglog_bridge::{
ColumnTy, DefaultVal, EGraph, FunctionConfig, FunctionId, MergeFn, define_rule,
};

use mimalloc::MiMalloc;

#[global_allocator]
static GLOBAL: MiMalloc = MiMalloc;

fn add_table(egraph: &mut EGraph, config: FunctionConfig) -> FunctionId {
egraph.add_table(config).unwrap()
}

#[allow(clippy::disallowed_macros)]
fn main() {
const N: usize = 13;
Expand All @@ -14,20 +20,26 @@ fn main() {
let start = web_time::Instant::now();
let mut egraph = EGraph::default();
let int_base = egraph.base_values_mut().register_type::<i64>();
let num_table = egraph.add_table(FunctionConfig {
schema: vec![ColumnTy::Base(int_base), ColumnTy::Id],
default: DefaultVal::FreshId,
merge: MergeFn::UnionId,
name: "num".into(),
can_subsume: false,
});
let add_table = egraph.add_table(FunctionConfig {
schema: vec![ColumnTy::Id; 3],
default: DefaultVal::FreshId,
merge: MergeFn::UnionId,
name: "add".into(),
can_subsume: false,
});
let num_table = add_table(
&mut egraph,
FunctionConfig {
schema: vec![ColumnTy::Base(int_base), ColumnTy::Id],
default: DefaultVal::FreshId,
merge: MergeFn::UnionId,
name: "num".into(),
can_subsume: false,
},
);
let add_table = add_table(
&mut egraph,
FunctionConfig {
schema: vec![ColumnTy::Id; 3],
default: DefaultVal::FreshId,
merge: MergeFn::UnionId,
name: "add".into(),
can_subsume: false,
},
);

let add_comm = define_rule! {
[egraph] ((-> (add_table x y) id))
Expand Down
42 changes: 27 additions & 15 deletions egglog-bridge/examples/ac_tracing.rs
Original file line number Diff line number Diff line change
@@ -1,26 +1,38 @@
use std::mem;

use egglog_bridge::{ColumnTy, DefaultVal, EGraph, FunctionConfig, MergeFn, define_rule};
use egglog_bridge::{
ColumnTy, DefaultVal, EGraph, FunctionConfig, FunctionId, MergeFn, define_rule,
};

fn add_table(egraph: &mut EGraph, config: FunctionConfig) -> FunctionId {
egraph.add_table(config).unwrap()
}

fn main() {
const N: usize = 12;
env_logger::init();
let mut egraph = EGraph::with_tracing();
let int_base = egraph.base_values_mut().register_type::<i64>();
let num_table = egraph.add_table(FunctionConfig {
schema: vec![ColumnTy::Base(int_base), ColumnTy::Id],
default: DefaultVal::FreshId,
merge: MergeFn::UnionId,
name: "num".into(),
can_subsume: false,
});
let add_table = egraph.add_table(FunctionConfig {
schema: vec![ColumnTy::Id; 3],
default: DefaultVal::FreshId,
merge: MergeFn::UnionId,
name: "add".into(),
can_subsume: false,
});
let num_table = add_table(
&mut egraph,
FunctionConfig {
schema: vec![ColumnTy::Base(int_base), ColumnTy::Id],
default: DefaultVal::FreshId,
merge: MergeFn::UnionId,
name: "num".into(),
can_subsume: false,
},
);
let add_table = add_table(
&mut egraph,
FunctionConfig {
schema: vec![ColumnTy::Id; 3],
default: DefaultVal::FreshId,
merge: MergeFn::UnionId,
name: "add".into(),
can_subsume: false,
},
);

let add_comm = define_rule! {
[egraph] ((-> (add_table x y) id))
Expand Down
Loading
Loading