Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
/target
*.html
file_*.hsnw
13 changes: 13 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,9 @@ tempfile = { version = "3", optional = true }
[dev-dependencies]
tempfile = { version = "3" }
criterion = { version = "0.5.1", features = ["async_tokio"] }
rallo = { version = "0.5" }

[features]
no_thread = []
generate_new_path = ["tempfile"]
track_allocations = []
52 changes: 52 additions & 0 deletions src/bin/hnsw_perf.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
use std::{io::Write, time::Instant};

use oramacore_lib::data_structures::hnsw2::HNSW2Index;
use rand::distr::{Distribution, Uniform};

fn main() {
let n = 50_000;
let dimension = 64;

let normal = Uniform::new(0.0, 10.0).unwrap();
let samples = (0..n)
.map(|_| {
(0..dimension)
.map(|_| normal.sample(&mut rand::rng()))
.collect::<Vec<f32>>()
})
.collect::<Vec<Vec<f32>>>();
let mut index = HNSW2Index::new(dimension);

let now = Instant::now();
for (i, sample) in samples.into_iter().enumerate() {
index.add_owned(sample, i).unwrap();
}
println!("adding {} points took: {:.2?}", n, now.elapsed());
index.build().unwrap();
println!("building index took: {:.2?}", now.elapsed());

for j in [
8 * 1024,
16 * 1024,
32 * 1024,
64 * 1024,
128 * 1024,
256 * 1024,
512 * 1024,
1024 * 1024,
]
.iter()
{
let f = format!("file_{j}.hsnw");
let mut f = std::fs::File::create(&f).unwrap();
let buf = std::io::BufWriter::with_capacity(*j, &mut f);

let before = Instant::now();
bincode::serialize_into(buf, &index).unwrap();
f.flush().unwrap();
f.sync_data().unwrap();
let elapsed = before.elapsed();

println!("serialize to {j} bytes buffer took: {elapsed:.2?}");
}
}
6 changes: 6 additions & 0 deletions src/data_structures/hnsw2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,12 @@ impl<
.map_err(|e| anyhow::anyhow!(e))
}

pub fn add_owned(&mut self, point: Vec<f32>, id: DocumentId) -> Result<()> {
self.inner
.add_owned(point, DocumentIdWrapper(id))
.map_err(|e| anyhow::anyhow!(e))
}

pub fn build(&mut self) -> Result<()> {
self.inner
.build(Metric::Euclidean)
Expand Down
7 changes: 7 additions & 0 deletions src/hnsw2/core/ann_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,13 @@ pub trait ANNIndex<E: node::FloatElement, T: node::IdxType>: Send + Sync {
self.add_node(&node::Node::new_with_idx(vs, idx))
}

/// add node with owned data
///
/// call `add_node()` internal
fn add_owned(&mut self, vs: Vec<E>, idx: T) -> Result<(), &'static str> {
self.add_node(&node::Node::new_with_idx_owned(vs, idx))
}

/// add multiple node one time
///
/// return `Err(&'static str)` if there is something wrong with the adding process, and the `static str` is the debug reason
Expand Down
20 changes: 20 additions & 0 deletions src/hnsw2/core/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,17 @@ impl<E: FloatElement, T: IdxType> Node<E, T> {
}
}

/// new without idx
///
/// new a point without a idx
pub fn new_owned(vectors: Vec<E>) -> Node<E, T> {
Node::<E, T>::valid_elements(&vectors);
Node {
vectors,
idx: Option::None,
}
}

/// new with idx
///
/// new a point with a idx
Expand All @@ -117,6 +128,15 @@ impl<E: FloatElement, T: IdxType> Node<E, T> {
n
}

/// new with idx
///
/// new a point with a idx
pub fn new_with_idx_owned(vectors: Vec<E>, id: T) -> Node<E, T> {
let mut n = Node::new_owned(vectors);
n.set_idx(id);
n
}

/// calculate the point distance
pub fn metric(&self, other: &Node<E, T>, t: metrics::Metric) -> Result<E, &'static str> {
metrics::metric(&self.vectors, &other.vectors, t)
Expand Down
Binary file added src/hnsw2/dump.hsnw
Binary file not shown.
42 changes: 32 additions & 10 deletions src/hnsw2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use rand::prelude::*;
use rayon::prelude::*;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use std::borrow::Cow;
use std::collections::BinaryHeap;

use std::collections::HashMap;
Expand Down Expand Up @@ -280,7 +281,7 @@ impl<E: node::FloatElement, T: node::IdxType> HNSWIndex<E, T> {
has_deletion: bool,
) -> BinaryHeap<Neighbor<E, usize>> {
let mut candidates: BinaryHeap<Neighbor<E, usize>> = BinaryHeap::new();
let mut top_candidates: BinaryHeap<Neighbor<E, usize>> = BinaryHeap::new();
let mut top_candidates: BinaryHeap<Neighbor<E, usize>> = BinaryHeap::with_capacity(ef);
for neighbor in sorted_candidates.iter() {
let root = neighbor.idx();
if !has_deletion || !self.is_deleted(root) {
Expand Down Expand Up @@ -331,6 +332,8 @@ impl<E: node::FloatElement, T: node::IdxType> HNSWIndex<E, T> {
});
}

// println!("top_candidates {}. {}", top_candidates.len(), ef);

top_candidates
}
//find ef nearist nodes to search data from root at level
Expand Down Expand Up @@ -647,7 +650,7 @@ impl<E: node::FloatElement, T: node::IdxType> ann_index::ANNIndex<E, T> for HNSW
}

#[derive(Default, Debug, Serialize, Deserialize)]
pub struct HNSWIndexDump<E: node::FloatElement, T: node::IdxType> {
pub struct HNSWIndexDump<'hsnw, E: node::FloatElement, T: node::IdxType> {
_dimension: usize, // dimension
_n_items: usize, // next item count
_n_constructed_items: usize,
Expand All @@ -666,8 +669,8 @@ pub struct HNSWIndexDump<E: node::FloatElement, T: node::IdxType> {
// use for serde
_id2neighbor_tmp: Vec<Vec<Vec<usize>>>,
_id2neighbor0_tmp: Vec<Vec<usize>>,
_nodes_tmp: Vec<node::Node<E, T>>,
_item2id_tmp: Vec<(T, usize)>,
_nodes_tmp: Vec<Cow<'hsnw, Box<node::Node<E, T>>>>,
_item2id_tmp: Vec<(Cow<'hsnw, T>, usize)>,
_delete_ids_tmp: Vec<usize>,
}

Expand All @@ -689,8 +692,12 @@ impl<E: node::FloatElement, T: node::IdxType> Serialize for HNSWIndex<E, T> {
.map(|x| x.read().unwrap().clone())
.collect();

let _nodes_tmp = self._nodes.iter().map(|x| *x.clone()).collect();
let _item2id_tmp = self._item2id.iter().map(|(k, v)| (k.clone(), *v)).collect();
let _nodes_tmp = self._nodes.iter().map(Cow::Borrowed).collect();
let _item2id_tmp = self
._item2id
.iter()
.map(|(k, v)| (Cow::Borrowed(k), *v))
.collect();
let _delete_ids_tmp = self._delete_ids.iter().copied().collect();

let dump = HNSWIndexDump {
Expand Down Expand Up @@ -730,8 +737,8 @@ impl<'de, E: node::FloatElement + DeserializeOwned, T: node::IdxType + Deseriali

let _nodes: Vec<_> = dump
._nodes_tmp
.iter()
.map(|x| Box::new(x.clone()))
.into_iter()
.map(|x| x.into_owned())
.collect();

let _id2neighbor = dump
Expand All @@ -746,7 +753,15 @@ impl<'de, E: node::FloatElement + DeserializeOwned, T: node::IdxType + Deseriali
.map(RwLock::new)
.collect::<Vec<_>>();

let _item2id = dump._item2id_tmp.into_iter().collect::<HashMap<_, _>>();
let _item2id = dump
._item2id_tmp
.into_iter()
.map(|(k, v)| {
// K is always owned here
// serde allocates it by itself
(k.into_owned(), v)
})
.collect::<HashMap<_, _>>();
let _delete_ids = dump._delete_ids_tmp.into_iter().collect::<HashSet<_>>();

Ok(Self {
Expand Down Expand Up @@ -774,7 +789,7 @@ impl<'de, E: node::FloatElement + DeserializeOwned, T: node::IdxType + Deseriali
}

#[cfg(test)]
mod tests {
mod hsnw_tests {
use super::core::{ann_index::ANNIndex, metrics::Metric};

use super::*;
Expand Down Expand Up @@ -812,4 +827,11 @@ mod tests {

assert_eq!(v1, v2);
}

#[test]
fn test_serde_backcompatibility() {
let b = include_bytes!("./dump.hsnw");
let new_index: HNSWIndex<f32, usize> = bincode::deserialize(b).unwrap();
assert_eq!(new_index.len(), 100);
}
}
50 changes: 50 additions & 0 deletions tests/hsnw.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
use rallo::RalloAllocator;

// This is the maximum length of a frame
const MAX_FRAME_LENGTH: usize = 512;
// Maximum number of allocations to keep
const MAX_LOG_COUNT: usize = 1_024 * 256;
#[global_allocator]
static ALLOCATOR: RalloAllocator<MAX_FRAME_LENGTH, MAX_LOG_COUNT> = RalloAllocator::new();

#[cfg(feature = "track_allocations")]
#[test]
fn test_allocation() {
use oramacore_lib::data_structures::hnsw2::HNSW2Index;
use rand::distr::{Distribution, Uniform};
use std::time::Instant;

let n = 1_000;
let dimension = 64;

let normal = Uniform::new(0.0, 10.0).unwrap();
let samples = (0..n)
.map(|_| {
(0..dimension)
.map(|_| normal.sample(&mut rand::rng()))
.collect::<Vec<f32>>()
})
.collect::<Vec<Vec<f32>>>();
let mut index = HNSW2Index::new(dimension);

let now = Instant::now();
for (i, sample) in samples.into_iter().enumerate() {
index.add_owned(sample, i).unwrap();
}

println!("Time to add: {:?}", now.elapsed());
unsafe { ALLOCATOR.start_track() };
println!("Tracking: {:?}", now.elapsed());
index.build().unwrap();
ALLOCATOR.stop_track();

// Safety: it is called after `stop_track`
let stats = unsafe { ALLOCATOR.calculate_stats() };
let tree = stats.into_tree().unwrap();

let file_name = "simple-memory-flamegraph.html";
let path = std::env::current_dir().unwrap().join(file_name);
tree.print_flamegraph(&path);

println!("Flamegraph saved to {}", path.display());
}
Loading