Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/bkd/bkd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ impl<T: Copy + PartialOrd + Debug, D: Debug> BKDTree<T, D> {
}
}

/// Returns `true` when `self.len()` is zero.
pub fn is_empty(&self) -> bool {
self.len() == 0
}

pub fn insert(&mut self, point: Point<T, D>)
where
T: num_traits::Float,
Expand Down
8 changes: 8 additions & 0 deletions src/data_structures/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,11 @@ pub mod map;
pub mod ordered_key;
pub mod radix;
pub mod vector_bruteforce;

/// Predicate that decides, per document id, whether a document is kept or dropped.
///
/// Implementors only have to provide [`ShouldInclude::should_include`]; the
/// exclusion check has a provided implementation derived from it. The
/// `Send + Sync` supertraits are required of every implementor.
pub trait ShouldInclude<DocumentId>: Send + Sync {
    /// Returns `true` when `doc_id` should be kept.
    fn should_include(&self, doc_id: &DocumentId) -> bool;

    /// Returns `true` when `doc_id` should be dropped.
    ///
    /// The provided implementation is the logical negation of
    /// [`ShouldInclude::should_include`].
    fn should_exclude(&self, doc_id: &DocumentId) -> bool {
        let keep = self.should_include(doc_id);
        !keep
    }
}
98 changes: 63 additions & 35 deletions src/data_structures/vector_bruteforce.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use std::{cmp::Reverse, fmt::Debug};

use crate::{data_structures::capped_heap::CappedHeap, hnsw2::core::simd_metrics::SIMDOptmized};
use crate::{
data_structures::{ShouldInclude, capped_heap::CappedHeap},
hnsw2::core::simd_metrics::SIMDOptmized,
};
use ordered_float::OrderedFloat;
use serde::{Deserialize, Serialize};

Expand Down Expand Up @@ -65,46 +68,38 @@ impl<DocumentId: Debug + Clone + Copy + Serialize + Ord + Send + Sync>
self.data.push((id, point.into_boxed_slice(), magnitude));
}

pub fn search(&self, target: &[f32], limit: usize) -> Vec<(DocumentId, f32)> {
pub fn search(
&self,
target: &[f32],
limit: usize,
similarity: f32,
should_include: &impl ShouldInclude<DocumentId>,
) -> Vec<(DocumentId, f32)> {
let target_magnitude = f32::real_dot_product(target, target).unwrap();

let half_data_len = self.data.len() / 2;
let (first_half, second_half) = self.data.split_at(half_data_len);

let (capped_head_one, capped_head_two) = rayon::join(
|| {
let mut capped_head_one = CappedHeap::new(limit);

for (id, vec, magnitude) in first_half {
// The cosine similarity isn't a distance in the math sense
// https://en.wikipedia.org/wiki/Distance#Mathematical_formalization
// Anyway, it is good for ranking purposes
// 1 means the vectors are equal
// 0 means the vectors are orthogonal
let score =
real_cosine_similarity((vec, *magnitude), (target, target_magnitude))
.expect("real_cosine_similarity should not return an error");

capped_head_one.insert(*id, OrderedFloat(score));
}
capped_head_one
search_on(
target,
target_magnitude,
first_half,
limit,
similarity,
should_include,
)
},
|| {
let mut capped_head_two = CappedHeap::new(limit);

for (id, vec, magnitude) in second_half {
// The cosine similarity isn't a distance in the math sense
// https://en.wikipedia.org/wiki/Distance#Mathematical_formalization
// Anyway, it is good for ranking purposes
// 1 means the vectors are equal
// 0 means the vectors are orthogonal
let score =
real_cosine_similarity((vec, *magnitude), (target, target_magnitude))
.expect("real_cosine_similarity should not return an error");

capped_head_two.insert(*id, OrderedFloat(score));
}
capped_head_two
search_on(
target,
target_magnitude,
second_half,
limit,
similarity,
should_include,
)
},
);

Expand All @@ -123,6 +118,33 @@ impl<DocumentId: Debug + Clone + Copy + Serialize + Ord + Send + Sync>
}
}

/// Scores the admissible entries of `data` against `target` and collects them in a heap.
///
/// * `target` / `target_magnitude` - query vector and the magnitude value the caller
///   computed for it.
/// * `data` - `(id, vector, magnitude)` entries to scan, in order.
/// * `limit` - capacity passed to `CappedHeap::new` (presumably the maximum number of
///   retained results; see `CappedHeap`).
/// * `similarity` - threshold; entries whose score is strictly below it are skipped.
/// * `should_include` - filter consulted through `should_exclude` before an entry is scored.
///
/// # Panics
///
/// Panics if `real_cosine_similarity` does not produce a score.
fn search_on<DocumentId: Clone + Copy + Ord>(
    target: &[f32],
    target_magnitude: f32,
    data: &[(DocumentId, Box<[f32]>, f32)],
    limit: usize,
    similarity: f32,
    should_include: &impl ShouldInclude<DocumentId>,
) -> CappedHeap<DocumentId, OrderedFloat<f32>> {
    let mut heap = CappedHeap::new(limit);

    // Lazily drop filtered-out ids so no similarity is computed for them.
    let admitted = data
        .iter()
        .filter(|entry| !should_include.should_exclude(&entry.0));

    for (doc_id, vector, magnitude) in admitted {
        let score = real_cosine_similarity((vector, *magnitude), (target, target_magnitude))
            .expect("real_cosine_similarity should not return an error");

        // NOTE(review): a NaN score compares `false` here and is therefore inserted;
        // confirm whether `real_cosine_similarity` can yield NaN and if that is intended.
        if score < similarity {
            continue;
        }

        heap.insert(*doc_id, OrderedFloat(score));
    }

    heap
}

#[cfg(test)]
mod tests {
use std::collections::HashMap;
Expand All @@ -131,6 +153,12 @@ mod tests {

use super::*;

// Test-only filter: the unit type admits every document id, so passing `&()`
// to `search` leaves the results unfiltered.
impl ShouldInclude<usize> for () {
fn should_include(&self, _doc_id: &usize) -> bool {
true
}
}

#[test]
fn test_basic3_index() {
let dim = 3;
Expand All @@ -149,7 +177,7 @@ mod tests {
}

let target = vec![255.0, 0.0, 0.0];
let v = index.search(&target, 10);
let v = index.search(&target, 10, 0.0, &());

let res: HashMap<_, _> = v.into_iter().collect();

Expand Down Expand Up @@ -181,8 +209,8 @@ mod tests {
.map(|_| normal.sample(&mut rand::rng()))
.collect::<Vec<f32>>();

let v1 = index.search(&target, 10);
let v2 = new_index.search(&target, 10);
let v1 = index.search(&target, 10, 0.0, &());
let v2 = new_index.search(&target, 10, 0.0, &());

assert_eq!(v1, v2);
}
Expand Down
Loading