Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benches/hnsw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ pub fn hnsw(c: &mut Criterion) {
c.bench_function("search hnsw2", |b| {
let index = build_hnsw2(DIM, data.clone());
b.iter(|| {
let _ = index.search(black_box(data[0].clone()), 10);
let _ = index.search(black_box(&data[0]), 10);
});
});
}
Expand Down
132 changes: 132 additions & 0 deletions src/data_structures/capped_heap.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
use std::{cmp::Reverse, collections::BinaryHeap, fmt::Debug};

struct Item<K, V> {
key: K,
value: V,
}
impl<K: std::cmp::Ord, V: std::cmp::Ord> std::cmp::PartialEq for Item<K, V> {
fn eq(&self, other: &Self) -> bool {
self.key == other.key && self.value == other.value
}
}
impl<K: std::cmp::Ord, V: std::cmp::Ord> std::cmp::Eq for Item<K, V> {}
impl<K: std::cmp::Ord, V: std::cmp::Ord> std::cmp::PartialOrd for Item<K, V> {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}
impl<K: std::cmp::Ord, V: std::cmp::Ord> std::cmp::Ord for Item<K, V> {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.key
.cmp(&other.key)
.then_with(|| self.value.cmp(&other.value).reverse())
}
}

impl<K: Debug, V: Debug> Debug for Item<K, V> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Item {{ key: {:?}, value: {:?} }}", self.key, self.value)
}
}

pub struct CappedHeap<K, V> {
heap: BinaryHeap<Reverse<Item<K, V>>>,
limit: usize,
}

impl<K: std::cmp::Ord, V: std::cmp::Ord> CappedHeap<K, V> {
pub fn new(limit: usize) -> Self {
Self {
heap: BinaryHeap::new(),
limit,
}
}

pub fn insert(&mut self, key: K, value: V) {
let new_item = Reverse(Item { key, value });
if self.heap.len() < self.limit {
self.heap.push(new_item);
} else if let Some(i) = self.heap.peek() {
if new_item < *i {
self.heap.pop();
self.heap.push(new_item);
}
}
}

pub fn into_top(self) -> impl Iterator<Item = (K, V)> {
let v = self.heap.into_sorted_vec();
v.into_iter()
.map(|Reverse(Item { key, value })| (key, value))
}
}

#[cfg(test)]
mod tests {
use rand::prelude::SliceRandom;
use rand::rng;

use super::*;

#[test]
fn test_cappend_heap_item_order() {
// Different key
{
let a = Item { key: 1, value: 1 };
let b = Item { key: 2, value: 1 };

assert!(b > a);
assert!(a < b);

let c = Item { key: 0, value: 1 };

assert!(a > c);
assert!(c < a);
assert!(b > c);
assert!(c < b);
}

// Same key different value
{
let a = Item { key: 0, value: 1 };
let b = Item { key: 0, value: 2 };

assert!(b < a);
assert!(a > b);
}
}

#[test]
fn test_cappend_heap_foo() {
let mut heap = CappedHeap::new(3);

heap.insert(1, 1);
heap.insert(2, 2);
heap.insert(3, 3);
heap.insert(4, 4);

let top: Vec<_> = heap.into_top().collect();

assert_eq!(top.len(), 3);
assert_eq!(top, vec![(4, 4), (3, 3), (2, 2)]);
}

#[test]
fn test_cappend_heap_order_consistency() {
let mut data = vec![(1, 1), (1, 2), (1, 3), (1, 4)];

for _ in 0..100 {
let mut heap = CappedHeap::new(3);

data.shuffle(&mut rng());
for (key, value) in &data {
heap.insert(*key, *value);
}

let top: Vec<_> = heap.into_top().collect();

assert_eq!(top.len(), 3);
assert_eq!(top, vec![(1, 1), (1, 2), (1, 3)]);
}
}
}
16 changes: 10 additions & 6 deletions src/data_structures/hnsw2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ impl<
self.dim
}

pub fn get_data(&self) -> impl Iterator<Item = (DocumentId, &[f32])> + '_ {
self.inner.data().map(|(v, DocumentIdWrapper(id))| (id, v))
}

pub fn into_data(self) -> impl Iterator<Item = (DocumentId, Vec<f32>)> {
self.inner
.into_data()
Expand All @@ -108,10 +112,10 @@ impl<
.map_err(|e| anyhow::anyhow!(e))
}

pub fn search(&self, target: Vec<f32>, limit: usize) -> Vec<(DocumentId, f32)> {
pub fn search(&self, target: &[f32], limit: usize) -> Vec<(DocumentId, f32)> {
assert_eq!(target.len(), self.dim);

let v = self.inner.node_search_k(&Node::new(&target), limit);
let v = self.inner.node_search_k(&Node::new(target), limit);

let mut result = Vec::new();
for (node, _) in v {
Expand All @@ -122,7 +126,7 @@ impl<
// Anyway, it is good for ranking purposes
// 1 means the vectors are equal
// 0 means the vectors are orthogonal
let score = real_cosine_similarity(n, &target)
let score = real_cosine_similarity(n, target)
.expect("real_cosine_similarity should not return an error");

let id = match node.idx() {
Expand Down Expand Up @@ -159,7 +163,7 @@ mod tests {
index.build().unwrap();

let target = vec![255.0, 0.0, 0.0];
let v = index.search(target, 10);
let v = index.search(&target, 10);

let res: HashMap<_, _> = v.into_iter().collect();

Expand Down Expand Up @@ -192,8 +196,8 @@ mod tests {
.map(|_| normal.sample(&mut rand::rng()))
.collect::<Vec<f32>>();

let v1 = index.search(target.clone(), 10);
let v2 = new_index.search(target.clone(), 10);
let v1 = index.search(&target, 10);
let v2 = new_index.search(&target, 10);

assert_eq!(v1, v2);
}
Expand Down
2 changes: 2 additions & 0 deletions src/data_structures/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
pub mod capped_heap;
pub mod fst;
pub mod hnsw;
pub mod hnsw2;
pub mod map;
pub mod ordered_key;
pub mod radix;
pub mod vector_bruteforce;
189 changes: 189 additions & 0 deletions src/data_structures/vector_bruteforce.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
use std::{cmp::Reverse, fmt::Debug};

use crate::{data_structures::capped_heap::CappedHeap, hnsw2::core::simd_metrics::SIMDOptmized};
use ordered_float::OrderedFloat;
use serde::{Deserialize, Serialize};

#[derive(Debug, Serialize, Deserialize)]
pub struct VectorBruteForce<DocumentId: Debug + Clone + Serialize> {
dim: usize,
data: Vec<(DocumentId, Box<[f32]>, f32)>, // (id, vector, magnitude)
}

fn real_cosine_similarity(
(vec1, magnitude_vec1): (&[f32], f32),
(vec2, magnitude_vec2): (&[f32], f32),
) -> Result<f32, &'static str> {
let a = f32::real_dot_product(vec1, vec2).unwrap();

Ok(a / (magnitude_vec1.sqrt() * magnitude_vec2.sqrt()))
}

impl<DocumentId: Debug + Clone + Copy + Serialize + Ord + Send + Sync>
VectorBruteForce<DocumentId>
{
pub fn new(dim: usize) -> Self {
Self {
dim,
data: Vec::new(),
}
}

pub fn dim(&self) -> usize {
self.dim
}

pub fn len(&self) -> usize {
self.data.len()
}

pub fn is_empty(&self) -> bool {
self.len() == 0
}

pub fn get_data(&self) -> impl Iterator<Item = (DocumentId, &[f32])> + '_ {
self.data
.iter()
.map(|(id, vec_box, _)| (*id, vec_box.as_ref()))
}

pub fn into_data(self) -> impl Iterator<Item = (DocumentId, Vec<f32>)> {
self.data
.into_iter()
.map(|(id, vec_box, _)| (id, vec_box.into_vec()))
}

pub fn set_capacity(&mut self, capacity: usize) {
if self.data.len() >= capacity {
return;
}
self.data.reserve_exact(capacity - self.data.len());
}

pub fn add_owned(&mut self, point: Vec<f32>, id: DocumentId) {
let magnitude = f32::real_dot_product(&point, &point).unwrap();
self.data.push((id, point.into_boxed_slice(), magnitude));
}

pub fn search(&self, target: &[f32], limit: usize) -> Vec<(DocumentId, f32)> {
let target_magnitude = f32::real_dot_product(target, target).unwrap();

let half_data_len = self.data.len() / 2;
let (first_half, second_half) = self.data.split_at(half_data_len);

let (capped_head_one, capped_head_two) = rayon::join(
|| {
let mut capped_head_one = CappedHeap::new(limit);

for (id, vec, magnitude) in first_half {
// The cosine similarity isnt a distance in the math sense
// https://en.wikipedia.org/wiki/Distance#Mathematical_formalization
// Anyway, it is good for ranking purposes
// 1 means the vectors are equal
// 0 means the vectors are orthogonal
let score =
real_cosine_similarity((vec, *magnitude), (target, target_magnitude))
.expect("real_cosine_similarity should not return an error");

capped_head_one.insert(*id, OrderedFloat(score));
}
capped_head_one
},
|| {
let mut capped_head_two = CappedHeap::new(limit);

for (id, vec, magnitude) in second_half {
// The cosine similarity isnt a distance in the math sense
// https://en.wikipedia.org/wiki/Distance#Mathematical_formalization
// Anyway, it is good for ranking purposes
// 1 means the vectors are equal
// 0 means the vectors are orthogonal
let score =
real_cosine_similarity((vec, *magnitude), (target, target_magnitude))
.expect("real_cosine_similarity should not return an error");

capped_head_two.insert(*id, OrderedFloat(score));
}
capped_head_two
},
);

let mut output: Vec<_> = capped_head_one
.into_top()
.map(|(id, OrderedFloat(score))| (id, score))
.chain(
capped_head_two
.into_top()
.map(|(id, OrderedFloat(score))| (id, score)),
)
.collect();

output.sort_by_key(|(_, score)| Reverse(OrderedFloat(*score)));
output
}
}

#[cfg(test)]
mod tests {
use std::collections::HashMap;

use rand::distr::{Distribution, Uniform};

use super::*;

#[test]
fn test_basic3_index() {
let dim = 3;

let mut index = VectorBruteForce::new(dim);

let points = [
vec![255.0, 0.0, 0.0],
vec![0.0, 255.0, 0.0],
vec![0.0, 0.0, 255.0],
];

for (id, point) in points.iter().enumerate() {
let id = id;
index.add_owned(point.clone(), id);
}

let target = vec![255.0, 0.0, 0.0];
let v = index.search(&target, 10);

let res: HashMap<_, _> = v.into_iter().collect();

assert_eq!(res, HashMap::from([(0, 1.0), (1, 0.0), (2, 0.0),]))
}

#[test]
fn test_basic3_serialize_deserialize() {
let n = 10_000;
let dimension = 64;

let normal = Uniform::new(0.0, 10.0).unwrap();
let samples = (0..n)
.map(|_| {
(0..dimension)
.map(|_| normal.sample(&mut rand::rng()))
.collect::<Vec<f32>>()
})
.collect::<Vec<Vec<f32>>>();
let mut index = VectorBruteForce::new(dimension);
for (i, sample) in samples.into_iter().enumerate() {
index.add_owned(sample.clone(), i);
}

let decoded = bincode::serialize(&index).unwrap();
let new_index: VectorBruteForce<usize> = bincode::deserialize(&decoded).unwrap();

let target = (0..dimension)
.map(|_| normal.sample(&mut rand::rng()))
.collect::<Vec<f32>>();

let v1 = index.search(&target, 10);
let v2 = new_index.search(&target, 10);

assert_eq!(v1, v2);
}
}
Loading
Loading