use core::borrow::Borrow;
use core::hash::Hash;
use core::iter::{IntoIterator, Iterator};
use core::num::NonZeroUsize;
use std::collections::HashMap;
use std::collections::hash_map::{Entry as HashMapEntry};
use uuid::Uuid;
use crate::error::Error;
use crate::kmeans::{ClusterEvent, Codebook, Scalar, cluster_with_events};
use crate::linalg::{dot, subtract_in};
use crate::partitions::{Partitioning, Partitions};
use crate::slice::AsSlice;
use crate::vector::{BlockVectorSet, VectorSet, divide_vector_set};
use super::{Attributes, AttributeValue};
pub mod proto;
/// Builds a [`Database`] from a vector set, configuring how it is
/// partitioned and quantized.
///
/// Defaults (see [`DatabaseBuilder::new`]): 10 partitions, 8 subvector
/// divisions, 16 clusters per division.
pub struct DatabaseBuilder<T, VS>
where
VS: VectorSet<T>,
{
// Owns the scalar type parameter `T` without storing a value of it.
_t: core::marker::PhantomData<T>,
// Input vectors to be indexed.
vs: VS,
// Number of coarse partitions to split the vector set into.
num_partitions: usize,
// Number of subvector divisions used for quantization.
num_divisions: usize,
// Number of clusters per subvector codebook.
num_clusters: usize,
}
impl<T, VS> DatabaseBuilder<T, VS>
where
    T: Scalar,
    VS: VectorSet<T> + Partitioning<T, VS>,
{
    /// Creates a builder over `vs` with the default configuration:
    /// 10 partitions, 8 subvector divisions, 16 clusters per division.
    pub fn new(vs: VS) -> Self {
        Self {
            _t: core::marker::PhantomData,
            vs,
            num_partitions: 10,
            num_divisions: 8,
            num_clusters: 16,
        }
    }

    /// Overrides the number of partitions.
    pub fn with_partitions(mut self, num_partitions: NonZeroUsize) -> Self {
        self.num_partitions = num_partitions.get();
        self
    }

    /// Overrides the number of subvector divisions.
    pub fn with_divisions(mut self, num_divisions: NonZeroUsize) -> Self {
        self.num_divisions = num_divisions.get();
        self
    }

    /// Overrides the number of clusters per subvector codebook.
    pub fn with_clusters(mut self, num_clusters: NonZeroUsize) -> Self {
        self.num_clusters = num_clusters.get();
        self
    }

    /// Builds the database, discarding progress events.
    ///
    /// # Errors
    ///
    /// Propagates any error from partitioning, subvector division, or
    /// clustering.
    pub fn build(self) -> Result<Database<T, VS>, Error> {
        self.build_with_events(|_| {})
    }

    /// Builds the database, reporting each build phase through `event`.
    ///
    /// Phases, in order: vector-ID assignment, coarse partitioning,
    /// subvector division of the partition residues, then one clustering
    /// (codebook) pass per division.
    ///
    /// # Errors
    ///
    /// Propagates any error from partitioning, subvector division, or
    /// clustering.
    pub fn build_with_events<EventHandler>(
        self,
        mut event: EventHandler,
    ) -> Result<Database<T, VS>, Error>
    where
        EventHandler: FnMut(BuildEvent<'_, T>),
    {
        event(BuildEvent::StartingIdAssignment);
        // One fresh v4 UUID per input vector; position i corresponds to
        // vector i in the input set.
        let vector_ids: Vec<Uuid> =
            (0..self.vs.len()).map(|_| Uuid::new_v4()).collect();
        event(BuildEvent::FinishedIdAssignment);
        event(BuildEvent::StartingPartitioning);
        // The counts come from `NonZeroUsize` setters (or nonzero
        // defaults), so the conversions back to nonzero cannot fail.
        let partitions = self.vs.partition_with_events(
            self.num_partitions.try_into().unwrap(),
            |e| event(BuildEvent::ClusterEvent(e)),
        )?;
        event(BuildEvent::FinishedPartitioning);
        event(BuildEvent::StartingSubvectorDivision);
        let divided = divide_vector_set(
            &partitions.residues,
            self.num_divisions.try_into().unwrap(),
        )?;
        event(BuildEvent::FinishedSubvectorDivision);
        // One codebook per subvector division. `with_capacity` takes a
        // plain `usize`; no conversion needed.
        let mut codebooks: Vec<Codebook<T>> =
            Vec::with_capacity(self.num_divisions);
        for (i, subvs) in divided.iter().enumerate() {
            event(BuildEvent::StartingQuantization(i));
            codebooks.push(cluster_with_events(
                subvs,
                self.num_clusters.try_into().unwrap(),
                |e| event(BuildEvent::ClusterEvent(e)),
            )?);
            event(BuildEvent::FinishedQuantization(i));
        }
        Ok(Database {
            vector_size: partitions.residues.vector_size(),
            num_partitions: self.num_partitions,
            num_divisions: self.num_divisions,
            num_clusters: self.num_clusters,
            vector_ids,
            partitions,
            codebooks,
            attribute_table: HashMap::new(),
        })
    }
}
/// Progress events emitted while building a database; see
/// [`DatabaseBuilder::build_with_events`].
#[derive(Debug)]
pub enum BuildEvent<'a, T> {
/// Vector-ID assignment is about to start.
StartingIdAssignment,
/// Vector-ID assignment has finished.
FinishedIdAssignment,
/// Coarse partitioning is about to start.
StartingPartitioning,
/// Coarse partitioning has finished.
FinishedPartitioning,
/// Division of partition residues into subvectors is about to start.
StartingSubvectorDivision,
/// Division of partition residues into subvectors has finished.
FinishedSubvectorDivision,
/// Clustering of the given subvector division is about to start.
StartingQuantization(usize),
/// Clustering of the given subvector division has finished.
FinishedQuantization(usize),
/// An event forwarded from the underlying clustering routine.
ClusterEvent(ClusterEvent<'a, T>),
}
/// Vector database: coarse partitions plus one codebook per subvector
/// division, supporting approximate nearest-neighbor queries.
pub struct Database<T, VS>
where
VS: VectorSet<T>,
{
// Dimensionality of each indexed vector.
vector_size: usize,
// Number of coarse partitions.
num_partitions: usize,
// Number of subvector divisions.
num_divisions: usize,
// Number of clusters per subvector codebook.
num_clusters: usize,
// Vector IDs, indexed by vector position in the input set.
vector_ids: Vec<Uuid>,
// Coarse partitioning (centroids, assignments, and residues).
partitions: Partitions<T, VS>,
// One codebook per subvector division.
codebooks: Vec<Codebook<T>>,
// Per-vector attributes, keyed by vector ID.
attribute_table: HashMap<Uuid, Attributes>,
}
impl<T, VS> Database<T, VS>
where
VS: VectorSet<T>,
{
pub fn num_vectors(&self) -> usize {
self.vector_ids.len()
}
pub const fn vector_size(&self) -> usize {
self.vector_size
}
pub const fn num_partitions(&self) -> usize {
self.num_partitions
}
pub const fn num_divisions(&self) -> usize {
self.num_divisions
}
pub fn subvector_size(&self) -> usize {
self.vector_size / self.num_divisions
}
pub const fn num_clusters(&self) -> usize {
self.num_clusters
}
pub fn vector_ids(&self) -> impl Iterator<Item = &Uuid> {
self.vector_ids.iter()
}
pub fn partitions(&self) -> PartitionIter<'_, T, VS> {
PartitionIter {
database: self,
next_index: 0,
}
}
pub fn get_attribute<K>(
&self,
id: &Uuid,
key: &K,
) -> Result<Option<&AttributeValue>, Error>
where
String: Borrow<K>,
K: Hash + Eq + ?Sized,
{
Ok(
self.attribute_table
.get(id)
.ok_or(Error::InvalidArgs(
format!("no such vector ID: {}", id),
))?
.get(key),
)
}
pub fn set_attribute_at<KV, KEY, VAL>(
&mut self,
i: usize,
attribute: KV,
) -> Result<(), Error>
where
KV: Into<(KEY, VAL)>,
KEY: Into<String>,
VAL: Into<AttributeValue>,
{
let id = self.vector_ids.get(i)
.ok_or(Error::InvalidArgs(
format!("vector index out of bounds: {}", i),
))?;
let (key, value) = attribute.into();
let key = key.into();
let value = value.into();
if let Some(attributes) = self.attribute_table.get_mut(id) {
match attributes.entry(key.into()) {
HashMapEntry::Occupied(entry) => {
*entry.into_mut() = value.into();
},
HashMapEntry::Vacant(entry) => {
entry.insert(value.into());
},
};
} else {
self.attribute_table.insert(
id.clone(),
Attributes::from([(key, value)]),
);
}
Ok(())
}
}
impl<T, VS> Database<T, VS>
where
    T: Scalar,
    VS: VectorSet<T>,
{
    /// Returns the `k` nearest vectors to `v`, probing the `nprobe`
    /// partitions whose centroids are closest to `v`. Progress events
    /// are discarded.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidArgs`] if `nprobe` exceeds the number of
    /// partitions, and propagates any partition-query error.
    pub fn query<V>(
        &self,
        v: &V,
        k: NonZeroUsize,
        nprobe: NonZeroUsize,
    ) -> Result<Vec<QueryResult<T>>, Error>
    where
        V: AsSlice<T> + ?Sized,
    {
        self.query_with_events(v, k, nprobe, |_| {})
    }

    /// Like [`Database::query`], but reports each query phase through
    /// `event`.
    ///
    /// Results are ordered by ascending squared distance and truncated
    /// to at most `k` entries.
    ///
    /// # Errors
    ///
    /// See [`Database::query`].
    pub fn query_with_events<V, EventHandler>(
        &self,
        v: &V,
        k: NonZeroUsize,
        nprobe: NonZeroUsize,
        mut event: EventHandler,
    ) -> Result<Vec<QueryResult<T>>, Error>
    where
        V: AsSlice<T> + ?Sized,
        EventHandler: FnMut(QueryEvent),
    {
        event(QueryEvent::StartingPartitionSelection);
        let v = v.as_slice();
        let queries = self.query_partitions(v, nprobe)?;
        event(QueryEvent::FinishedPartitionSelection);
        let mut all_results: Vec<QueryResult<T>> = Vec::new();
        for query in &queries {
            event(QueryEvent::StartingPartitionQuery(query.partition_index));
            all_results.extend(query.execute()?);
            event(QueryEvent::FinishedPartitionQuery(query.partition_index));
        }
        event(QueryEvent::StartingResultSelection);
        // NOTE(review): `partial_cmp(..).unwrap()` panics if a distance
        // is NaN (only possible with NaN/infinite input components).
        all_results.sort_by(|lhs, rhs| {
            lhs.squared_distance
                .partial_cmp(&rhs.squared_distance)
                .unwrap()
        });
        all_results.truncate(k.get());
        event(QueryEvent::FinishedResultSelection);
        Ok(all_results)
    }

    /// Selects the `nprobe` partitions whose centroids are closest to
    /// `v` and prepares one [`PartitionQuery`] per selected partition,
    /// each holding `v` localized to (offset by) that partition's
    /// centroid.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidArgs`] if `nprobe` exceeds the number of
    /// partitions.
    fn query_partitions<'a>(
        &'a self,
        v: &[T],
        nprobe: NonZeroUsize,
    ) -> Result<Vec<PartitionQuery<'a, T, VS>>, Error> {
        let nprobe = nprobe.get();
        if nprobe > self.num_partitions {
            return Err(Error::InvalidArgs(format!(
                "nprobe {} exceeds the number of partitions {}",
                nprobe, self.num_partitions,
            )));
        }
        // (partition index, v - centroid, squared distance to centroid).
        let mut local_vectors: Vec<(usize, Vec<T>, T)> =
            Vec::with_capacity(self.num_partitions);
        for pi in 0..self.num_partitions {
            // `to_vec` copies `v` with an exact-size allocation.
            let mut localized = v.to_vec();
            let centroid = self.partitions.codebook.centroids.get(pi);
            subtract_in(&mut localized[..], centroid.as_slice());
            // dot(x, x) is the squared Euclidean norm of the residual.
            let distance = dot(&localized[..], &localized[..]);
            local_vectors.push((pi, localized, distance));
        }
        // Keep only the `nprobe` nearest partitions.
        local_vectors.sort_by(|lhs, rhs| lhs.2.partial_cmp(&rhs.2).unwrap());
        local_vectors.truncate(nprobe);
        let queries = local_vectors
            .into_iter()
            .map(|(partition_index, localized, _)| PartitionQuery {
                db: self,
                partition_index,
                localized,
            })
            .collect();
        Ok(queries)
    }
}
/// Iterator over a database's partitions; see [`Database::partitions`].
pub struct PartitionIter<'a, T, VS>
where
VS: VectorSet<T>,
{
// Database whose partitions are being iterated.
database: &'a Database<T, VS>,
// Index of the next partition to yield.
next_index: usize,
}
impl<'a, T, VS> Iterator for PartitionIter<'a, T, VS>
where
    T: Clone,
    VS: VectorSet<T>,
{
    type Item = Partition<T>;

    /// Materializes partitions in index order, returning `None` once
    /// every partition in the database has been yielded.
    fn next(&mut self) -> Option<Self::Item> {
        let index = self.next_index;
        if index >= self.database.num_partitions {
            return None;
        }
        self.next_index = index + 1;
        Some(Partition::new(self.database, index))
    }
}
/// A single partition materialized from a [`Database`]: its centroid
/// plus the encoded (quantized) vectors assigned to it.
pub struct Partition<T> {
// Centroid of this partition.
centroid: Vec<T>,
// Per-vector cluster indices, one `u32` per subvector division.
encoded_vectors: BlockVectorSet<u32>,
// IDs of the vectors assigned to this partition, in encoding order.
vector_ids: Vec<Uuid>,
}
impl<T> Partition<T> {
/// Returns the dimensionality of the (unencoded) vectors.
pub fn vector_size(&self) -> usize {
self.centroid.len()
}
/// Returns the number of subvector divisions per encoded vector.
pub fn num_divisions(&self) -> usize {
self.encoded_vectors.vector_size()
}
/// Returns the number of vectors assigned to this partition.
pub fn num_vectors(&self) -> usize {
self.encoded_vectors.len()
}
}
impl<T> Partition<T>
where
T: Clone,
{
/// Materializes the `index`-th partition of `db`: copies the partition
/// centroid and gathers, for every vector assigned to that partition,
/// its per-division cluster indices and its ID.
fn new<VS>(db: &Database<T, VS>, index: usize) -> Self
where
VS: VectorSet<T>,
{
let mut centroid: Vec<T> = Vec::with_capacity(db.vector_size());
centroid.extend_from_slice(
db.partitions.codebook.centroids.get(index),
);
let num_divisions = db.num_divisions();
// First pass: count the vectors assigned to this partition so the
// buffers below can be allocated exactly once.
let num_vectors = db.partitions.codebook.indices
.iter()
.filter(|&&pi| pi == index)
.count();
let mut encoded_vectors: Vec<u32> =
Vec::with_capacity(num_vectors * num_divisions);
let mut vector_ids: Vec<Uuid> = Vec::with_capacity(num_vectors);
// Second pass: for each vector in this partition (global index `vi`),
// flatten its per-division cluster indices into `encoded_vectors`.
for (vi, _) in db.partitions.codebook.indices
.iter()
.enumerate()
.filter(|(_, &pi)| pi == index)
{
for di in 0..num_divisions {
// Narrowing to `u32`; panics if a cluster index ever exceeds
// u32::MAX — presumably impossible for realistic cluster counts.
encoded_vectors.push(
db.codebooks[di].indices[vi].try_into().unwrap(),
);
}
vector_ids.push(db.vector_ids[vi]);
}
Partition {
centroid,
// Chunk the flat index buffer into rows of `num_divisions` codes.
encoded_vectors: BlockVectorSet::chunk(
encoded_vectors,
num_divisions.try_into().unwrap(),
).unwrap(),
vector_ids,
}
}
}
/// Progress events emitted while querying; see
/// [`Database::query_with_events`].
#[derive(Debug)]
pub enum QueryEvent {
/// Selection of the partitions to probe is about to start.
StartingPartitionSelection,
/// Selection of the partitions to probe has finished.
FinishedPartitionSelection,
/// The scan of the given partition is about to start.
StartingPartitionQuery(usize),
/// The scan of the given partition has finished.
FinishedPartitionQuery(usize),
/// Final result sorting and truncation is about to start.
StartingResultSelection,
/// Final result sorting and truncation has finished.
FinishedResultSelection,
}
/// A pending query against a single partition, holding the query vector
/// localized to that partition's centroid.
pub struct PartitionQuery<'a, T, VS>
where
VS: VectorSet<T>,
{
// Database being queried.
db: &'a Database<T, VS>,
// Index of the partition to scan.
partition_index: usize,
// Query vector minus this partition's centroid.
localized: Vec<T>,
}
impl<'a, T, VS> PartitionQuery<'a, T, VS>
where
    T: Scalar,
    VS: VectorSet<T>,
{
    /// Scans every encoded vector in this partition and returns its
    /// approximate squared distance to the localized query vector.
    ///
    /// A table of squared distances from each query subvector to every
    /// centroid of the matching codebook is built once; each vector's
    /// distance is then the sum of `num_divisions` table lookups.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok`; the `Result` mirrors the other
    /// query APIs.
    pub fn execute(&self) -> Result<Vec<QueryResult<T>>, Error> {
        let num_divisions = self.db.num_divisions();
        let num_clusters = self.db.num_clusters();
        let md = self.db.subvector_size();
        // distance_table[di * num_clusters + ci] = squared distance from
        // the di-th query subvector to the ci-th centroid of codebook di.
        let mut distance_table: Vec<T> =
            Vec::with_capacity(num_divisions * num_clusters);
        // Scratch buffer reused across all subvector/centroid pairs.
        let mut vector_buf = vec![T::zero(); md];
        for di in 0..num_divisions {
            let from = di * md;
            let subv = &self.localized[from..from + md];
            for ci in 0..num_clusters {
                let centroid = self.db.codebooks[di].centroids.get(ci);
                let d = &mut vector_buf[..];
                d.copy_from_slice(subv);
                subtract_in(d, centroid.as_slice());
                distance_table.push(dot(d, d));
            }
        }
        let mut results: Vec<QueryResult<T>> =
            Vec::with_capacity(self.partition_size());
        // `pvi` counts vectors within this partition; `vi` is the
        // global vector index.
        for (pvi, (vi, _)) in self.db.partitions.codebook.indices
            .iter()
            .enumerate()
            .filter(|(_, &pi)| pi == self.partition_index)
            .enumerate()
        {
            let mut distance = T::zero();
            for di in 0..num_divisions {
                let ci = self.db.codebooks[di].indices[vi];
                distance += distance_table[di * num_clusters + ci];
            }
            results.push(QueryResult {
                partition_index: self.partition_index,
                // `Uuid` is `Copy`; the former `.clone()` was redundant.
                vector_id: self.db.vector_ids[vi],
                vector_index: pvi,
                squared_distance: distance,
            });
        }
        Ok(results)
    }

    /// Counts the vectors assigned to this query's partition.
    fn partition_size(&self) -> usize {
        self.db.partitions.codebook.indices
            .iter()
            .filter(|&&pi| pi == self.partition_index)
            .count()
    }
}
/// A single approximate nearest-neighbor match.
#[derive(Clone, Debug)]
pub struct QueryResult<T> {
/// Index of the partition the matched vector belongs to.
pub partition_index: usize,
/// ID of the matched vector.
pub vector_id: Uuid,
/// Index of the matched vector within its partition.
pub vector_index: usize,
/// Approximate squared distance between the query and the match.
pub squared_distance: T,
}