use core::borrow::Borrow;
use core::cell::{OnceCell, Ref, RefCell, RefMut};
use core::hash::Hash;
use core::num::NonZeroUsize;
use std::collections::hash_map::{Entry as HashMapEntry};
use uuid::Uuid;
use crate::error::Error;
use crate::io::{FileSystem, HashedFileIn};
use crate::kmeans::Scalar;
use crate::linalg::{dot, subtract};
use crate::nbest::{NBestByKey, TakeNBestByKey};
use crate::protos::database::{
AttributesLog as ProtosAttributesLog,
Database as ProtosDatabase,
Partition as ProtosPartition,
VectorSet as ProtosVectorSet,
};
use crate::protos::{Deserialize, read_message};
use crate::slice::AsSlice;
use crate::vector::BlockVectorSet;
use super::{AttributeTable, AttributeValue, Attributes};
/// File extension of serialized Protocol Buffers (binary) files.
pub const PROTOBUF_EXTENSION: &str = "binpb";
/// Loads a [`Database`] with vector element type `T` from a file system `FS`.
pub trait LoadDatabase<T, FS> {
    /// Loads the database stored at `path` in `fs`.
    ///
    /// Fails with [`Error`] if the file cannot be read or contains
    /// inconsistent data.
    fn load_database<P>(fs: FS, path: P) -> Result<Database<T, FS>, Error>
    where
        P: AsRef<str>;
}
/// Vector database for approximate nearest-neighbor queries.
///
/// Partitions, codebooks, partition centroids, and attribute logs are loaded
/// lazily from the file system `FS` and cached in interior-mutable cells, so
/// queries can be issued through a shared reference.
pub struct Database<T, FS> {
    // File system the database contents are loaded from.
    fs: FS,
    // Number of elements in a single vector.
    vector_size: usize,
    // Number of partitions the vectors are divided into.
    num_partitions: usize,
    // Number of subvector divisions used for product quantization.
    num_divisions: usize,
    // Number of code vectors per codebook.
    num_codes: usize,
    // IDs of the partition files; indexed by partition.
    partition_ids: Vec<String>,
    // Lazily loaded partitions; `None` until first access.
    partitions: RefCell<Vec<Option<Partition<T>>>>,
    // ID of the partition centroids file.
    partition_centroids_id: String,
    // Lazily loaded partition centroids.
    partition_centroids: OnceCell<BlockVectorSet<T>>,
    // IDs of the codebook files; presumably indexed by division — confirmed
    // by the length check in `load_database`.
    codebook_ids: Vec<String>,
    // Lazily loaded codebooks; `None` until the first query.
    codebooks: RefCell<Option<Vec<BlockVectorSet<T>>>>,
    // IDs of the attributes log files; indexed by partition.
    attributes_log_ids: Vec<String>,
    // Whether the attributes log of each partition has been loaded.
    attributes_log_load_flags: RefCell<Vec<bool>>,
    // Attribute names referenced by index from the attributes logs.
    attribute_names: Vec<String>,
    // Lazily built attribute table; `None` until the first attribute access.
    attribute_table: RefCell<Option<AttributeTable>>,
}
impl<T, FS> Database<T, FS>
where
    FS: FileSystem,
{
    /// Number of elements in a single vector.
    pub fn vector_size(&self) -> usize {
        self.vector_size
    }

    /// Number of partitions.
    pub fn num_partitions(&self) -> usize {
        self.num_partitions
    }

    /// Number of subvector divisions.
    pub fn num_divisions(&self) -> usize {
        self.num_divisions
    }

    /// Number of code vectors per codebook.
    pub fn num_codes(&self) -> usize {
        self.num_codes
    }

    /// Number of elements in a subvector: `vector_size / num_divisions`.
    pub fn subvector_size(&self) -> usize {
        self.vector_size / self.num_divisions
    }

    /// Returns the ID of the partition at `index`, or `None` if `index` is
    /// out of bounds.
    pub fn get_partition_id(&self, index: usize) -> Option<&String> {
        self.partition_ids.get(index)
    }

    /// Returns the ID of the codebook at `index`, or `None` if `index` is
    /// out of bounds.
    pub fn get_codebook_id(&self, index: usize) -> Option<&String> {
        self.codebook_ids.get(index)
    }
}
impl<T, FS> Database<T, FS>
where
FS: FileSystem,
Self: LoadPartition<T>,
{
pub fn get_attribute<K>(
&self,
vector_id: &Uuid,
key: &K,
) -> Result<Option<AttributeValueRef>, Error>
where
String: Borrow<K>,
K: Hash + Eq + ?Sized,
{
if self.attribute_table.borrow().is_none() {
self.load_attribute_table()?;
}
self.get_attribute_internal(vector_id, key)
}
fn get_attribute_in_partition<K>(
&self,
partition_index: usize,
vector_id: &Uuid,
key: &K,
) -> Result<Option<AttributeValueRef>, Error>
where
String: Borrow<K>,
K: Hash + Eq + ?Sized,
{
self.load_attributes_log(partition_index)?;
self.get_attribute_internal(vector_id, key)
}
fn get_attribute_internal<K>(
&self,
vector_id: &Uuid,
key: &K,
) -> Result<Option<AttributeValueRef>, Error>
where
String: Borrow<K>,
K: Hash + Eq + ?Sized,
{
let attribute_table = Ref::filter_map(
self.attribute_table.borrow(),
|tbl| tbl.as_ref(),
).expect("attribute table must be loaded");
let attributes = Ref::filter_map(
attribute_table,
|tbl| tbl.get(vector_id),
).or(Err(Error::InvalidArgs(
format!("no such vector ID: {}", vector_id),
)))?;
match Ref::filter_map(attributes, |attrs| attrs.get(key)) {
Ok(value) => Ok(Some(value)),
Err(_) => Ok(None),
}
}
fn load_attribute_table(&self) -> Result<(), Error> {
for pi in 0..self.num_partitions() {
self.load_attributes_log(pi)?;
}
Ok(())
}
fn load_attributes_log(&self, partition_index: usize) -> Result<(), Error> {
if self.attributes_log_load_flags.borrow()[partition_index] {
return Ok(());
}
let partition = self.get_partition(partition_index)?;
let mut f = self.fs.open_compressed_hashed_file(format!(
"attributes/{}.{}",
self.attributes_log_ids[partition_index],
PROTOBUF_EXTENSION,
))?;
let attributes_log: ProtosAttributesLog = read_message(&mut f)?;
if attributes_log.partition_id != self.partition_ids[partition_index] {
return Err(Error::InvalidData(format!(
"inconsistent partition IDs: {} vs {}",
attributes_log.partition_id,
self.partition_ids[partition_index],
)));
}
if self.attribute_table.borrow().is_none() {
self.attribute_table.replace(Some(AttributeTable::new()));
}
let mut attribute_table = RefMut::filter_map(
self.attribute_table.borrow_mut(),
|tbl| tbl.as_mut(),
).expect("attribute table must exist");
for (i, entry) in attributes_log.entries.into_iter().enumerate() {
let attribute_name = self.attribute_names
.get(entry.name_index as usize)
.ok_or(Error::InvalidData(format!(
"attribute name index out of bounds: {}",
entry.name_index,
)))?;
let vector_id = entry.vector_id
.into_option()
.ok_or(Error::InvalidData(format!(
"attributes log[{}, {}]: missing vector ID",
partition_index,
i,
)))?
.deserialize()?;
let value = entry.value
.into_option()
.ok_or(Error::InvalidData(format!(
"attributes log[{}, {}]: missing value",
partition_index,
i,
)))?
.deserialize()?;
match attribute_table.entry(vector_id) {
HashMapEntry::Occupied(slot) => {
match slot.into_mut().entry(attribute_name.clone()) {
HashMapEntry::Occupied(slot) => {
*slot.into_mut() = value;
},
HashMapEntry::Vacant(slot) => {
slot.insert(value);
},
};
},
HashMapEntry::Vacant(slot) => {
slot.insert(Attributes::from([
(attribute_name.clone(), value),
]));
},
};
}
for vector_id in partition.vector_ids.iter() {
attribute_table
.entry(vector_id.clone())
.or_insert_with(Attributes::new);
}
self.attributes_log_load_flags.borrow_mut()[partition_index] = true;
Ok(())
}
fn get_partition(
&self,
index: usize,
) -> Result<PartitionRef<'_, T>, Error> {
if index >= self.num_partitions() {
return Err(Error::InvalidArgs(format!(
"partition index out of bounds: {}",
index,
)));
}
if self.partitions.borrow()[index].is_none() {
self.partitions.borrow_mut()[index] =
Some(self.load_partition(index)?);
}
let partition =
Ref:: filter_map(
self.partitions.borrow(),
|partitions| partitions[index].as_ref(),
)
.or(Err(Error::InvalidData(
"partition must be loaded".to_string(),
)))
.unwrap();
Ok(partition)
}
}
/// Shared borrow of a lazily loaded partition.
type PartitionRef<'a, T> = Ref<'a, Partition<T>>;

/// Shared borrow of an attribute value in the attribute table.
pub type AttributeValueRef<'a> = Ref<'a, AttributeValue>;
impl<T, FS> Database<T, FS>
where
    T: Scalar,
    FS: FileSystem,
    Self: LoadPartition<T> + LoadCodebook<T> + LoadPartitionCentroids<T>,
{
    /// Queries the `k` approximate nearest neighbors of the vector `v`.
    ///
    /// Probes the `nprobe` partitions whose centroids are closest to `v`.
    pub fn query<'a, V>(
        &'a self,
        v: &V,
        k: NonZeroUsize,
        nprobe: NonZeroUsize,
    ) -> Result<Vec<QueryResult<'a, T, FS>>, Error>
    where
        V: AsSlice<T> + ?Sized,
    {
        self.query_with_events(v, k, nprobe, |_| {})
    }

    /// Same as [`Database::query`] but reports [`QueryEvent`]s to `event`
    /// as the query progresses.
    pub fn query_with_events<'a, V, EventHandler>(
        &'a self,
        v: &V,
        k: NonZeroUsize,
        nprobe: NonZeroUsize,
        mut event: EventHandler,
    ) -> Result<Vec<QueryResult<'a, T, FS>>, Error>
    where
        V: AsSlice<T> + ?Sized,
        EventHandler: FnMut(QueryEvent),
    {
        event(QueryEvent::StartingQueryInitialization);
        // Lazily load the partition centroids and the codebooks.
        if self.partition_centroids.get().is_none() {
            let partition_centroids = self.load_partition_centroids()?;
            // Cannot fail: the cell was just checked to be empty and
            // `OnceCell` is single-threaded.
            let _ = self.partition_centroids.set(partition_centroids);
        }
        if self.codebooks.borrow().is_none() {
            let mut codebooks: Vec<BlockVectorSet<T>> =
                Vec::with_capacity(self.num_divisions());
            for di in 0..self.num_divisions() {
                codebooks.push(self.load_codebook(di)?);
            }
            self.codebooks.replace(Some(codebooks));
        }
        event(QueryEvent::FinishedQueryInitialization);
        event(QueryEvent::StartingPartitionSelection);
        let v = v.as_slice();
        let queries = self.query_partitions(v, k, nprobe)?;
        event(QueryEvent::FinishedPartitionSelection);
        let all_results: Vec<Vec<QueryResult<'a, T, FS>>> = queries
            .into_iter()
            .map(|query| {
                event(QueryEvent::StartingPartitionQuery(
                    query.partition_index,
                ));
                let results = query.execute();
                if results.is_ok() {
                    event(QueryEvent::FinishedPartitionQuery(
                        query.partition_index,
                    ));
                }
                results
            })
            .collect::<Result<Vec<_>, Error>>()?;
        event(QueryEvent::StartingResultSelection);
        // Keep the k best results over all probed partitions and sort them
        // by ascending squared distance.
        let mut all_results: Vec<QueryResult<'a, T, FS>> = all_results
            .into_iter()
            .flatten()
            .n_best_by_key(k.get(), |r| r.squared_distance)
            .into();
        all_results.sort_by(|lhs, rhs| {
            lhs.squared_distance.partial_cmp(&rhs.squared_distance).unwrap()
        });
        event(QueryEvent::FinishedResultSelection);
        Ok(all_results)
    }

    /// Selects the `nprobe` partitions whose centroids are closest to `v`
    /// and prepares a query for each of them, ordered by ascending centroid
    /// distance.
    ///
    /// The partition centroids must have been loaded before calling this.
    fn query_partitions<'a>(
        &'a self,
        v: &[T],
        k: NonZeroUsize,
        nprobe: NonZeroUsize,
    ) -> Result<Vec<PartitionQuery<'a, T, FS>>, Error> {
        let nprobe = nprobe.get();
        let k = k.get();
        let num_partitions = self.num_partitions();
        if nprobe > num_partitions {
            return Err(Error::InvalidArgs(format!(
                "nprobe {} exceeds the number of partitions {}",
                nprobe,
                num_partitions,
            )));
        }
        let partition_centroids = self.partition_centroids.get()
            .expect("partition centroids must be loaded");
        let mut distances: NBestByKey<(usize, Vec<T>, T), T, _> =
            NBestByKey::new(nprobe, |(_, _, distance)| *distance);
        for pi in 0..num_partitions {
            // Zero-initialize instead of `set_len` on uninitialized memory,
            // which is undefined behavior; `subtract` overwrites every
            // element anyway.
            let mut localized: Vec<T> = vec![T::zero(); self.vector_size()];
            let centroid = partition_centroids.get(pi);
            subtract(v, centroid, &mut localized[..]);
            let distance = dot(&localized[..], &localized[..]);
            distances.push((pi, localized, distance));
        }
        distances.sort_by(|lhs, rhs| lhs.2.partial_cmp(&rhs.2).unwrap());
        let queries = distances
            .into_iter()
            .map(|(pi, localized, _)| PartitionQuery {
                db: self,
                // Each query keeps a shared borrow of the codebook cache.
                codebooks: Ref::map(
                    self.codebooks.borrow(),
                    |cb| cb.as_ref().expect("codebooks must be loaded"),
                ),
                partition_index: pi,
                localized,
                k,
            })
            .collect();
        Ok(queries)
    }
}
/// A partition of product-quantized vectors.
#[derive(Clone)]
pub struct Partition<T> {
    // Ties the partition to the vector element type `T` even though only
    // encoded (u32) vectors are stored.
    _t: std::marker::PhantomData<T>,
    // Product-quantization codes; one code per division per vector.
    encoded_vectors: BlockVectorSet<u32>,
    // Vector IDs; indexed in parallel with `encoded_vectors`.
    vector_ids: Vec<Uuid>,
}
impl<T> Partition<T> {
    /// Number of vectors stored in this partition.
    pub fn num_vectors(&self) -> usize {
        self.encoded_vectors.len()
    }

    /// Returns the encoded vector at `index`, or `None` if `index` is out
    /// of bounds.
    pub fn get_encoded_vector(&self, index: usize) -> Option<&[u32]> {
        let in_bounds = index < self.encoded_vectors.len();
        in_bounds.then(|| self.encoded_vectors.get(index))
    }

    /// Returns the ID of the vector at `index`, or `None` if `index` is out
    /// of bounds.
    pub fn get_vector_id(&self, index: usize) -> Option<&Uuid> {
        self.vector_ids.get(index)
    }
}
/// Loads a [`Partition`] of vector element type `T`.
pub trait LoadPartition<T> {
    /// Loads the partition at `index`.
    fn load_partition(&self, index: usize) -> Result<Partition<T>, Error>;
}
/// Loads a codebook of vector element type `T`.
pub trait LoadCodebook<T> {
    /// Loads the codebook of the division at `index`.
    fn load_codebook(&self, index: usize) -> Result<BlockVectorSet<T>, Error>;
}
/// Loads the partition centroids of vector element type `T`.
pub trait LoadPartitionCentroids<T> {
    /// Loads all partition centroids as a single vector set.
    fn load_partition_centroids(&self) -> Result<BlockVectorSet<T>, Error>;
}
/// Progress events reported by [`Database::query_with_events`].
#[derive(Debug)]
pub enum QueryEvent {
    /// Lazy loading of centroids and codebooks is about to start.
    StartingQueryInitialization,
    /// Centroids and codebooks are ready.
    FinishedQueryInitialization,
    /// Selection of the partitions to probe is about to start.
    StartingPartitionSelection,
    /// Partitions to probe have been selected.
    FinishedPartitionSelection,
    /// The partition at the given index is about to be queried.
    StartingPartitionQuery(usize),
    /// The partition at the given index has been queried.
    FinishedPartitionQuery(usize),
    /// Selection of the overall best results is about to start.
    StartingResultSelection,
    /// The overall best results have been selected.
    FinishedResultSelection,
}
/// A pending nearest-neighbor search within a single partition.
struct PartitionQuery<'a, T, FS> {
    // Database the partition belongs to.
    db: &'a Database<T, FS>,
    // Shared borrow of the database's codebook cache.
    codebooks: Ref<'a, Vec<BlockVectorSet<T>>>,
    // Index of the partition to search.
    partition_index: usize,
    // Query vector localized to the partition centroid (v - centroid).
    localized: Vec<T>,
    // Number of best results to keep.
    k: usize,
}
impl<'a, T, FS> PartitionQuery<'a, T, FS>
where
    T: Scalar,
    FS: FileSystem,
    Database<T, FS>: LoadPartition<T> + LoadCodebook<T>,
{
    /// Searches the partition and returns its `k` best results.
    ///
    /// Builds a distance table from the localized query subvectors to every
    /// code vector, then approximates each stored vector's squared distance
    /// by summing the table entries selected by its codes.
    fn execute(&self) -> Result<Vec<QueryResult<'a, T, FS>>, Error> {
        let num_divisions = self.db.num_divisions();
        let num_codes = self.db.num_codes();
        let subvector_size = self.db.subvector_size();
        let partition = self.db.get_partition(self.partition_index)?;
        // Squared distance from each query subvector to every code vector,
        // laid out as [division][code].
        let mut distance_table: Vec<T> =
            Vec::with_capacity(num_divisions * num_codes);
        // Scratch buffer reused for every subtraction. Zero-initialized
        // instead of `set_len` on uninitialized memory, which is undefined
        // behavior; `subtract` overwrites every element anyway.
        let mut vector_buf: Vec<T> = vec![T::zero(); subvector_size];
        for di in 0..num_divisions {
            let from = di * subvector_size;
            let to = from + subvector_size;
            let subv = &self.localized[from..to];
            let codebook = &self.codebooks[di];
            for ci in 0..num_codes {
                let code_vector = codebook.get(ci);
                let d = &mut vector_buf[..];
                subtract(subv, code_vector, d);
                distance_table.push(dot(d, d));
            }
        }
        let num_vectors = partition.num_vectors();
        let mut results: NBestByKey<QueryResult<'a, T, FS>, T, _> =
            NBestByKey::new(
                self.k,
                |i: &QueryResult<'a, T, FS>| i.squared_distance,
            );
        for vi in 0..num_vectors {
            // `vi < num_vectors`, so the index cannot be out of bounds.
            let encoded_vector = partition.get_encoded_vector(vi).unwrap();
            let mut distance = T::zero();
            for di in 0..num_divisions {
                let ci = encoded_vector[di] as usize;
                distance += distance_table[di * num_codes + ci];
            }
            results.push(QueryResult {
                db: self.db,
                partition_index: self.partition_index,
                vector_id: partition.get_vector_id(vi).unwrap().clone(),
                vector_index: vi,
                squared_distance: distance,
            });
        }
        Ok(results.into())
    }
}
/// A single result of a nearest-neighbor query.
#[derive(Clone)]
pub struct QueryResult<'a, T, FS> {
    // Database the result came from; used for lazy attribute lookups.
    db: &'a Database<T, FS>,
    /// Index of the partition containing the vector.
    pub partition_index: usize,
    /// ID of the vector.
    pub vector_id: Uuid,
    /// Index of the vector within its partition.
    pub vector_index: usize,
    /// Approximate squared distance from the query vector.
    pub squared_distance: T,
}
impl<'a, T, FS> QueryResult<'a, T, FS>
where
    T: Scalar,
    FS: FileSystem,
    Database<T, FS>:
        LoadPartition<T> + LoadCodebook<T> + LoadPartitionCentroids<T>,
{
    /// Returns the value of the attribute `key` of the result vector.
    ///
    /// Only the attributes log of the result's partition is loaded, so this
    /// is cheaper than [`Database::get_attribute`] when the table has not
    /// been fully loaded yet.
    ///
    /// Returns `Ok(None)` if the vector has no attribute named `key`.
    pub fn get_attribute<K>(
        &self,
        key: &K,
    ) -> Result<Option<AttributeValueRef>, Error>
    where
        String: Borrow<K>,
        K: Hash + Eq + ?Sized,
    {
        self.db.get_attribute_in_partition(
            self.partition_index,
            &self.vector_id,
            key,
        )
    }
}
mod f32impl {
use super::*;
impl<FS> LoadDatabase<f32, FS> for Database<f32, FS>
where
FS: FileSystem,
{
fn load_database<P>(fs: FS, path: P) -> Result<Database<f32, FS>, Error>
where
P: AsRef<str>,
{
let mut f = fs.open_compressed_hashed_file(path)?;
let db: ProtosDatabase = read_message(&mut f)?;
f.verify()?;
let vector_size = db.vector_size as usize;
let num_partitions = db.num_partitions as usize;
let num_divisions = db.num_divisions as usize;
let num_codes = db.num_codes as usize;
if vector_size == 0 {
return Err(Error::InvalidData(format!("vector_size is zero")));
}
if num_divisions == 0 {
return Err(Error::InvalidData(format!("num_divisions is zero")));
}
if num_partitions == 0 {
return Err(Error::InvalidData(format!("num_partitions is zero")));
}
if num_codes == 0 {
return Err(Error::InvalidData(format!("num_codes is zero")));
}
if vector_size % num_divisions != 0 {
return Err(Error::InvalidData(format!(
"vector_size {} is not multiple of num_divisions {}",
vector_size,
num_divisions,
)));
}
if num_partitions != db.partition_ids.len() {
return Err(Error::InvalidData(format!(
"num_partitions {} and partition_ids.len() {} do not match",
db.num_partitions,
db.partition_ids.len(),
)));
}
if num_divisions != db.codebook_ids.len() {
return Err(Error::InvalidData(format!(
"num_divisions {} and codebook_ids.len() {} do not match",
db.num_divisions,
db.codebook_ids.len(),
)));
}
let db = Database {
fs,
vector_size,
num_partitions,
num_divisions,
num_codes,
partition_ids: db.partition_ids,
partitions: RefCell::new(vec![None; num_partitions]),
partition_centroids_id: db.partition_centroids_id,
partition_centroids: OnceCell::new(),
codebook_ids: db.codebook_ids,
codebooks: RefCell::new(None),
attributes_log_ids: db.attributes_log_ids,
attributes_log_load_flags:
RefCell::new(vec![false; num_partitions]),
attribute_names: db.attribute_names,
attribute_table: RefCell::new(None),
};
Ok(db)
}
}
impl<FS> LoadPartitionCentroids<f32> for Database<f32, FS>
where
FS: FileSystem,
{
fn load_partition_centroids(
&self,
) -> Result<BlockVectorSet<f32>, Error> {
let mut f = self.fs.open_hashed_file(format!(
"partitions/{}.{}",
self.partition_centroids_id,
PROTOBUF_EXTENSION,
))?;
let partition_centroids: ProtosVectorSet = read_message(&mut f)?;
let partition_centroids: BlockVectorSet<f32> =
partition_centroids.deserialize()?;
if partition_centroids.vector_size() != self.vector_size() {
return Err(Error::InvalidData(format!(
"partition centroids vector size mismatch: expected {}, got {}",
self.vector_size(),
partition_centroids.vector_size(),
)));
}
if partition_centroids.len() != self.num_partitions() {
return Err(Error::InvalidData(format!(
"partition centroids data length mismatch: expected {}, got {}",
self.num_partitions(),
partition_centroids.len(),
)));
}
Ok(partition_centroids)
}
}
impl<FS> LoadCodebook<f32> for Database<f32, FS>
where
FS: FileSystem,
{
fn load_codebook(
&self,
index: usize,
) -> Result<BlockVectorSet<f32>, Error>
where
FS: FileSystem,
{
if index >= self.num_divisions() {
return Err(Error::InvalidArgs(format!(
"index {} exceeds the number of codebooks {}",
index,
self.num_divisions(),
)));
}
let mut f = self.fs.open_hashed_file(format!(
"codebooks/{}.{}",
self.get_codebook_id(index).unwrap(),
PROTOBUF_EXTENSION,
))?;
let codebook: ProtosVectorSet = read_message(&mut f)?;
f.verify()?;
let codebook: BlockVectorSet<f32> = codebook.deserialize()?;
if codebook.vector_size() != self.subvector_size() {
return Err(Error::InvalidData(format!(
"vector_size is inconsistent: expected {} but got {}",
self.subvector_size(),
codebook.vector_size(),
)));
}
if codebook.len() != self.num_codes() {
return Err(Error::InvalidData(format!(
"number of codes is inconsistent: expected {} but got {}",
self.num_codes(),
codebook.len(),
)));
}
Ok(codebook)
}
}
impl<FS> LoadPartition<f32> for Database<f32, FS>
where
FS: FileSystem,
{
fn load_partition(
&self,
index: usize,
) -> Result<Partition<f32>, Error> {
if index >= self.num_partitions {
return Err(Error::InvalidArgs(format!(
"index {} exceeds the number of partitions {}",
index,
self.num_partitions,
)));
}
let mut f = self.fs.open_compressed_hashed_file(format!(
"partitions/{}.{}",
self.get_partition_id(index).unwrap(),
PROTOBUF_EXTENSION,
))?;
let partition: ProtosPartition = read_message(&mut f)?;
f.verify()?;
let vector_size = partition.vector_size as usize;
let num_divisions = partition.num_divisions as usize;
let encoded_vectors: BlockVectorSet<u32> = partition.encoded_vectors
.into_option()
.ok_or(Error::InvalidData(
"missing encoded vectors".to_string(),
))?
.deserialize()?;
if vector_size != self.vector_size() {
return Err(Error::InvalidData(format!(
"vector_size {} and partition.vector_size {} do not match",
self.vector_size(),
vector_size,
)));
}
if num_divisions != self.num_divisions() {
return Err(Error::InvalidData(format!(
"num_divisions {} and partition.num_divisions {} do not match",
self.num_divisions(),
num_divisions,
)));
}
if encoded_vectors.len() != partition.vector_ids.len() {
return Err(Error::InvalidData(format!(
"number of vector IDs is inconsistent: exptected {} but got {}",
encoded_vectors.len(),
partition.vector_ids.len(),
)));
}
let vector_ids: Vec<Uuid> = partition.vector_ids
.into_iter()
.map(|id| id.deserialize().unwrap())
.collect();
Ok(Partition {
_t: std::marker::PhantomData,
encoded_vectors,
vector_ids,
})
}
}
}