use core::num::NonZeroUsize;
use crate::error::Error;
use crate::kmeans::{ClusterEvent, Codebook, Scalar, cluster_with_events};
use crate::linalg::{add_in, subtract_in};
use crate::slice::AsSlice;
use crate::vector::{BlockVectorSet, VectorSet};
/// Vectors grouped into partitions by clustering (see `crate::kmeans`).
///
/// When produced by [`Partitioning::partition`], each stored vector is a
/// residue: the original vector minus the centroid of the partition it was
/// assigned to. [`Partitions::all_vectors`] adds the centroids back to
/// reconstruct the original vectors.
pub struct Partitions<T, VS> {
    /// Centroids of the partitions, plus `indices`, where `indices[j]` is the
    /// index of the centroid assigned to vector `j`.
    pub codebook: Codebook<T>,
    /// Per-vector residues, in the same order as `codebook.indices`.
    pub residues: VS,
}

impl<T, VS> Partitions<T, VS>
where
    T: Scalar,
    VS: VectorSet<T>,
{
    /// Returns the number of partitions, i.e. the number of centroids.
    ///
    /// This counts centroids, not vectors: `codebook.indices` has one entry
    /// per vector, so its length is the vector count.
    pub fn num_partitions(&self) -> usize {
        self.codebook.centroids.len()
    }

    /// Returns an iterator that yields every original vector, in order.
    ///
    /// Each vector is rebuilt as `residue + centroid`.
    pub fn all_vectors(&self) -> AllVectorIterator<'_, T, VS> {
        AllVectorIterator::new(self)
    }
}
/// Iterator that reconstructs every vector of a [`Partitions`] by adding each
/// residue to the centroid of its assigned partition.
///
/// Each vector is yielded as a freshly allocated `Vec<T>`, in residue order.
pub struct AllVectorIterator<'a, T, VS>
where
    VS: VectorSet<T>,
{
    partitions: &'a Partitions<T, VS>,
    /// Index of the next residue to reconstruct.
    next_index: usize,
}

impl<'a, T, VS> AllVectorIterator<'a, T, VS>
where
    T: Scalar,
    VS: VectorSet<T>,
{
    /// Creates an iterator positioned at the first vector of `partitions`.
    pub fn new(partitions: &'a Partitions<T, VS>) -> Self {
        Self {
            partitions,
            next_index: 0,
        }
    }
}

impl<'a, T, VS> Iterator for AllVectorIterator<'a, T, VS>
where
    T: Scalar,
    VS: VectorSet<T>,
{
    type Item = Vec<T>;

    fn next(&mut self) -> Option<Self::Item> {
        // Copy out the `&'a` reference so the borrows below do not go through `self`.
        let partitions = self.partitions;
        let residues = &partitions.residues;
        if self.next_index >= residues.len() {
            return None;
        }

        let mut v: Vec<T> = Vec::with_capacity(residues.vector_size());
        v.extend_from_slice(residues.get(self.next_index).as_slice());

        let codebook = &partitions.codebook;
        let ci = codebook.indices[self.next_index];
        add_in(&mut v[..], codebook.centroids.get(ci).as_slice());

        self.next_index += 1;
        Some(v)
    }

    /// The remaining count is known exactly, which lets `collect` and similar
    /// adaptors preallocate.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self
            .partitions
            .residues
            .len()
            .saturating_sub(self.next_index);
        (remaining, Some(remaining))
    }
}

impl<'a, T, VS> ExactSizeIterator for AllVectorIterator<'a, T, VS>
where
    T: Scalar,
    VS: VectorSet<T>,
{
}
/// Types that can be split into clustered [`Partitions`].
///
/// Partitioning consumes `self`, so implementations can reuse its storage
/// for the resulting residues.
pub trait Partitioning<T, VS>: Sized {
    /// Partitions `self` into `p` clusters and discards all clustering events.
    ///
    /// # Errors
    ///
    /// Returns whatever error [`Partitioning::partition_with_events`] returns.
    fn partition(self, p: NonZeroUsize) -> Result<Partitions<T, VS>, Error> {
        let ignore_events = |_event| {};
        self.partition_with_events(p, ignore_events)
    }

    /// Partitions `self` into `p` clusters and reports progress to
    /// `event_handler` via [`ClusterEvent`] values.
    ///
    /// # Errors
    ///
    /// Implementation-defined; presumably a clustering failure — confirm
    /// against each implementation.
    fn partition_with_events<EV>(
        self,
        p: NonZeroUsize,
        event_handler: EV,
    ) -> Result<Partitions<T, VS>, Error>
    where
        EV: FnMut(ClusterEvent<'_, T>);
}
impl<T> Partitioning<T, Self> for BlockVectorSet<T>
where
    T: Scalar,
{
    /// Clusters the vectors into `p` partitions, then subtracts each vector's
    /// assigned centroid from it in place, turning the set into residues.
    ///
    /// `event_handler` is passed straight through to `cluster_with_events`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `cluster_with_events`.
    fn partition_with_events<EV>(
        mut self,
        p: NonZeroUsize,
        event_handler: EV,
    ) -> Result<Partitions<T, Self>, Error>
    where
        EV: FnMut(ClusterEvent<'_, T>),
    {
        let codebook = cluster_with_events(&self, p, event_handler)?;

        // `indices[j]` is the centroid of vector `j`, so one pass over the
        // assignments is enough. This avoids rescanning every assignment once
        // per centroid, which is O(p * n).
        for (j, &ci) in codebook.indices.iter().enumerate() {
            debug_assert!(ci < p.get(), "centroid index out of range");
            let centroid = codebook.centroids.get(ci);
            let v = self.get_mut(j);
            subtract_in(v, centroid);
        }

        Ok(Partitions {
            codebook,
            residues: self,
        })
    }
}