From 95a3ebe3781586329884de3d4b469ffc7754831b Mon Sep 17 00:00:00 2001 From: YuhanLiin Date: Sun, 7 Aug 2022 00:08:04 -0400 Subject: [PATCH 1/2] Add float param to PcaParams --- algorithms/linfa-reduction/src/pca.rs | 91 ++++++++++++++------------- 1 file changed, 49 insertions(+), 42 deletions(-) diff --git a/algorithms/linfa-reduction/src/pca.rs b/algorithms/linfa-reduction/src/pca.rs index c2e31213b..548681a83 100644 --- a/algorithms/linfa-reduction/src/pca.rs +++ b/algorithms/linfa-reduction/src/pca.rs @@ -21,6 +21,8 @@ //! let dataset = embedding.predict(dataset); //! ``` //! +use std::marker::PhantomData; + use crate::error::{ReductionError, Result}; #[cfg(not(feature = "blas"))] use linfa_linalg::{lobpcg::TruncatedSvd, Order}; @@ -44,12 +46,13 @@ use linfa::{ serde(crate = "serde_crate") )] #[derive(Debug, Clone, PartialEq, Eq)] -pub struct PcaParams { +pub struct PcaParams { embedding_size: usize, apply_whitening: bool, + _float: PhantomData, } -impl PcaParams { +impl PcaParams { /// Apply whitening to the embedding vector /// /// Whitening will scale the eigenvalues of the transformation such that the covariance will be @@ -61,6 +64,37 @@ impl PcaParams { } } +/// Fitted Principal Component Analysis model +/// +/// The model contains the mean and hyperplane for the projection of data. +/// +/// # Example +/// +/// ``` +/// use linfa::traits::{Fit, Predict}; +/// use linfa_reduction::Pca; +/// +/// let dataset = linfa_datasets::iris(); +/// +/// // apply PCA projection along a line which maximizes the spread of the data +/// let embedding = Pca::params(1) +/// .fit(&dataset).unwrap(); +/// +/// // reduce dimensionality of the dataset +/// let dataset = embedding.predict(dataset); +/// ``` +#[cfg_attr( + feature = "serde", + derive(Serialize, Deserialize), + serde(crate = "serde_crate") +)] +#[derive(Debug, Clone, PartialEq)] +pub struct Pca { + embedding: Array2, + sigma: Array1, + mean: Array1, +} + /// Fit a PCA model given a dataset /// /// The Principal Component Analysis takes the records of a dataset and tries to find the best @@ -73,7 +107,7 @@ impl PcaParams { /// # Returns /// /// A fitted PCA model with origin and hyperplane -impl> Fit, T, ReductionError> for PcaParams { +impl> Fit, T, ReductionError> for PcaParams { type Object = Pca; fn fit(&self, dataset: &DatasetBase, T>) -> Result> { @@ -119,65 +153,38 @@ impl> Fit, T, ReductionError> for PcaPa } } -/// Fitted Principal Component Analysis model -/// -/// The model contains the mean and hyperplane for the projection of data. -/// -/// # Example -/// -/// ``` -/// use linfa::traits::{Fit, Predict}; -/// use linfa_reduction::Pca; -/// -/// let dataset = linfa_datasets::iris(); -/// -/// // apply PCA projection along a line which maximizes the spread of the data -/// let embedding = Pca::params(1) -/// .fit(&dataset).unwrap(); -/// -/// // reduce dimensionality of the dataset -/// let dataset = embedding.predict(dataset); -/// ``` -#[cfg_attr( - feature = "serde", - derive(Serialize, Deserialize), - serde(crate = "serde_crate") -)] -#[derive(Debug, Clone, PartialEq)] -pub struct Pca { - embedding: Array2, - sigma: Array1, - mean: Array1, -} - -impl Pca { +impl Pca { /// Create default parameter set /// /// # Parameters /// /// * `embedding_size`: the target dimensionality - pub fn params(embedding_size: usize) -> PcaParams { + pub fn params(embedding_size: usize) -> PcaParams { PcaParams { embedding_size, apply_whitening: false, + _float: PhantomData, } } /// Return the amount of explained variance per element - pub fn explained_variance(&self) -> Array1 { - self.sigma.mapv(|x| x * x / (self.sigma.len() as f64 - 1.0)) + pub fn explained_variance(&self) -> Array1 { + self.sigma + .mapv(|x| x * x / F::from(self.sigma.len() - 1).unwrap()) } /// Return the normalized amount of explained variance per element - pub fn explained_variance_ratio(&self) -> Array1 { - let ex_var = self.sigma.mapv(|x| x * x / (self.sigma.len() as f64 - 1.0)); + pub fn explained_variance_ratio(&self) -> Array1 { + let ex_var = self + .sigma + .mapv(|x| x * x / F::from(self.sigma.len() - 1).unwrap()); let sum_ex_var = ex_var.sum(); ex_var / sum_ex_var } /// Return the singular values - pub fn singular_values(&self) -> &Array1 { + pub fn singular_values(&self) -> &Array1 { &self.sigma } } @@ -234,7 +241,7 @@ mod tests { has_autotraits::(); has_autotraits::(); has_autotraits::(); - has_autotraits::(); + has_autotraits::>(); has_autotraits::>(); } From c3b8006970d661dd8c0a6bb132507e18a00d10c3 Mon Sep 17 00:00:00 2001 From: YuhanLiin Date: Sun, 7 Aug 2022 00:25:29 -0400 Subject: [PATCH 2/2] Generate f32 PCA impl using macros --- algorithms/linfa-reduction/src/pca.rs | 124 ++++++++++++++------------ 1 file changed, 66 insertions(+), 58 deletions(-) diff --git a/algorithms/linfa-reduction/src/pca.rs b/algorithms/linfa-reduction/src/pca.rs index 548681a83..3ec403443 100644 --- a/algorithms/linfa-reduction/src/pca.rs +++ b/algorithms/linfa-reduction/src/pca.rs @@ -95,64 +95,72 @@ pub struct Pca { mean: Array1, } -/// Fit a PCA model given a dataset -/// -/// The Principal Component Analysis takes the records of a dataset and tries to find the best -/// fit in a lower dimensional space such that the maximal variance is retained. -/// -/// # Parameters -/// -/// * `dataset`: A dataset with records in N dimensions -/// -/// # Returns -/// -/// A fitted PCA model with origin and hyperplane -impl> Fit, T, ReductionError> for PcaParams { - type Object = Pca; - - fn fit(&self, dataset: &DatasetBase, T>) -> Result> { - if dataset.nsamples() == 0 { - return Err(ReductionError::NotEnoughSamples); - } else if dataset.nfeatures() < self.embedding_size || self.embedding_size == 0 { - return Err(ReductionError::EmbeddingTooSmall(self.embedding_size)); - } +macro_rules! impl_pca { + ($F:ty) => { + /// Fit a PCA model given a dataset + /// + /// The Principal Component Analysis takes the records of a dataset and tries to find the best + /// fit in a lower dimensional space such that the maximal variance is retained. + /// + /// # Parameters + /// + /// * `dataset`: A dataset with records in N dimensions + /// + /// # Returns + /// + /// A fitted PCA model with origin and hyperplane + impl> Fit, T, ReductionError> for PcaParams<$F> { + type Object = Pca<$F>; + + fn fit(&self, dataset: &DatasetBase, T>) -> Result> { + if dataset.nsamples() == 0 { + return Err(ReductionError::NotEnoughSamples); + } else if dataset.nfeatures() < self.embedding_size || self.embedding_size == 0 { + return Err(ReductionError::EmbeddingTooSmall(self.embedding_size)); + } - let x = dataset.records(); - // calculate mean of data and subtract it - // safe because of above 0 samples check - let mean = x.mean_axis(Axis(0)).unwrap(); - let x = x - &mean; - - // estimate Singular Value Decomposition - #[cfg(feature = "blas")] - let result = - TruncatedSvd::new(x, TruncatedOrder::Largest).decompose(self.embedding_size)?; - #[cfg(not(feature = "blas"))] - let result = TruncatedSvd::new_with_rng(x, Order::Largest, SmallRng::seed_from_u64(42)) - .decompose(self.embedding_size)?; - // explained variance is the spectral distribution of the eigenvalues - let (_, sigma, mut v_t) = result.values_vectors(); - - // cut singular values to avoid numerical problems - let sigma = sigma.mapv(|x| x.max(1e-8)); - - // scale the embedding with the square root of the dimensionality and eigenvalue such that - // the product of the resulting matrix gives the unit covariance. - if self.apply_whitening { - let cov_scale = (dataset.nsamples() as f64 - 1.).sqrt(); - for (mut v_t, sigma) in v_t.axis_iter_mut(Axis(0)).zip(sigma.iter()) { - v_t *= cov_scale / *sigma; + let x = dataset.records(); + // calculate mean of data and subtract it + // safe because of above 0 samples check + let mean = x.mean_axis(Axis(0)).unwrap(); + let x = x - &mean; + + // estimate Singular Value Decomposition + #[cfg(feature = "blas")] + let result = + TruncatedSvd::new(x, TruncatedOrder::Largest).decompose(self.embedding_size)?; + #[cfg(not(feature = "blas"))] + let result = + TruncatedSvd::new_with_rng(x, Order::Largest, SmallRng::seed_from_u64(42)) + .decompose(self.embedding_size)?; + // explained variance is the spectral distribution of the eigenvalues + let (_, sigma, mut v_t) = result.values_vectors(); + + // cut singular values to avoid numerical problems + let sigma = sigma.mapv(|x| x.max(1e-8)); + + // scale the embedding with the square root of the dimensionality and eigenvalue such that + // the product of the resulting matrix gives the unit covariance. + if self.apply_whitening { + let cov_scale = (dataset.nsamples() as $F - 1.).sqrt(); + for (mut v_t, sigma) in v_t.axis_iter_mut(Axis(0)).zip(sigma.iter()) { + v_t *= cov_scale / *sigma; + } + } + + Ok(Pca { + embedding: v_t, + sigma, + mean, + }) } } - - Ok(Pca { - embedding: v_t, - sigma, - mean, - }) - } + }; } +impl_pca!(f64); +impl_pca!(f32); + impl Pca { /// Create default parameter set /// @@ -255,7 +263,7 @@ mod tests { let mut rng = SmallRng::seed_from_u64(42); // rotate data by 45° - let tmp = Array2::random_using((300, 2), Uniform::new(-1.0f64, 1.), &mut rng); + let tmp = Array2::random_using((300, 2), Uniform::new(-1.0f32, 1.), &mut rng); let q = array![[1., 1.], [-1., 1.]]; let dataset = Dataset::from(tmp.dot(&q)); @@ -265,7 +273,7 @@ mod tests { // check that the covariance is unit diagonal let cov = proj.t().dot(&proj); - assert_abs_diff_eq!(cov / (300. - 1.), Array2::eye(2), epsilon = 1e-5); + assert_abs_diff_eq!(cov / (300. - 1.), Array2::eye(2), epsilon = 1e-3); } /// Random number whitening test @@ -303,7 +311,7 @@ mod tests { let mut rng = SmallRng::seed_from_u64(3); // generate normal distribution random data with N >> p - let data = Array2::random_using((1000, 500), StandardNormal, &mut rng); + let data = Array2::::random_using((1000, 500), StandardNormal, &mut rng); let dataset = Dataset::from(data / 1000f64.sqrt()); let model = Pca::params(500).fit(&dataset).unwrap(); @@ -377,13 +385,13 @@ mod tests { #[test] fn test_explained_variance_diag() { - let dataset = Dataset::from(Array2::from_diag(&array![1., 1., 1., 1.])); + let dataset = Dataset::from(Array2::from_diag(&array![1.0f32, 1., 1., 1.])); let model = Pca::params(3).fit(&dataset).unwrap(); assert_abs_diff_eq!( model.explained_variance_ratio(), array![1. / 3., 1. / 3., 1. / 3.], - epsilon = 1e-6 + epsilon = 1e-3 ); } }