Skip to content

Commit

Permalink
Refine docstrings for TorchPCA and store X_new_ attribute
Browse files Browse the repository at this point in the history
  • Loading branch information
miriambt committed Nov 15, 2024
1 parent f648e4d commit 010fedd
Showing 1 changed file with 59 additions and 40 deletions.
99 changes: 59 additions & 40 deletions snputils/processing/pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,28 +33,29 @@ def _svd_flip(u, v, u_based_decision=True):

class TorchPCA:
"""
A class to perform Principal Component Analysis (PCA) using PyTorch tensors.
A class for GPU-based Principal Component Analysis (PCA) with PyTorch tensors.
This implementation leverages GPU acceleration to achieve significant performance improvements,
being up to 25 times faster than `sklearn.decomposition.PCA` when running on a compatible GPU.
"""
def __init__(self, n_components: int = None, fitting: str = 'reduced'):
def __init__(self, n_components: int = 2, fitting: str = 'reduced'):
"""
Args:
n_components (int, optional):
n_components (int, default=2):
The number of principal components. If None, defaults to the minimum of `n_samples` and `n_snps`.
fitting (str, default='reduced'):
The fitting approach to use for the SVD computation. Options:
- 'full': Full Singular Value Decomposition (SVD).
- 'reduced': Economy SVD.
- 'lowrank': Low-rank approximation, which provides a faster but approximate solution.
Default to 'reduced'.
The fitting approach for the SVD computation. Options:
- `'full'`: Full Singular Value Decomposition (SVD).
- `'reduced'`: Economy SVD.
- `'lowrank'`: Low-rank approximation, which provides a faster but approximate solution.
Default is 'reduced'.
"""
self.__n_components = n_components
self.__fitting = fitting
self.__n_components_ = None
self.__components_ = None
self.__mean_ = None
self.__X_new_ = None # Store transformed SNP data

def __getitem__(self, key):
"""
Expand Down Expand Up @@ -82,7 +83,7 @@ def n_components(self) -> Optional[int]:
Retrieve `n_components`.
Returns:
int: The number of components to keep. If None, defaults to the minimum of `n_samples` and `n_snps`.
**int:** The number of principal components.
"""
return self.__n_components

Expand All @@ -99,7 +100,8 @@ def fitting(self) -> str:
Retrieve `fitting`.
Returns:
str: The fitting approach to use for the SVD computation.
**str:**
The fitting approach for the SVD computation.
"""
return self.__fitting

Expand All @@ -116,10 +118,9 @@ def n_components_(self) -> Optional[int]:
Retrieve `n_components_`.
Returns:
int:
The effective number of components retained after fitting.
It equals the parameter `n_components`, or the lesser value of
`n_snps` and `n_samples` if `n_components` is `None`.
**int:**
The effective number of components retained after fitting,
calculated as `min(self.n_components, min(n_samples, n_snps))`.
"""
return self.__n_components_

Expand All @@ -136,9 +137,8 @@ def components_(self) -> Optional[torch.Tensor]:
Retrieve `components_`.
Returns:
torch.Tensor:
The matrix of principal components obtained after fitting.
Each row represents a principal component vector.
**tensor of shape (n_components_, n_snps):**
Matrix of principal components, where each row is a principal component vector.
"""
return self.__components_

Expand All @@ -155,8 +155,8 @@ def mean_(self) -> Optional[torch.Tensor]:
Retrieve `mean_`.
Returns:
torch.Tensor:
The per-feature mean vector of the input data used for centering.
**tensor of shape (n_snps,):**
Per-feature mean vector of the input data used for centering.
"""
return self.__mean_

Expand All @@ -167,13 +167,31 @@ def mean_(self, x: torch.Tensor) -> None:
"""
self.__mean_ = x

@property
def X_new_(self) -> Optional[Union[torch.Tensor, np.ndarray]]:
"""
Retrieve `X_new_`.
Returns:
**tensor of shape (n_samples, n_components_):**
The transformed SNP data projected onto the `n_components_` principal components.
"""
return self.__X_new_

@X_new_.setter
def X_new_(self, x: torch.Tensor) -> None:
"""
Update `X_new_`.
"""
self.__X_new_ = x

def copy(self) -> 'TorchPCA':
"""
Create and return a copy of the current `TorchPCA` instance.
Create and return a copy of `self`.
Returns:
TorchPCA:
A new instance of the current object.
**TorchPCA:**
A new instance of the current object.
"""
return copy.copy(self)

Expand All @@ -183,8 +201,7 @@ def _fit(self, X: torch.Tensor) -> Tuple:
Args:
X (tensor of shape (n_samples, n_snps)):
The input SNP data used for fitting the model, where `n_samples` is the number of
samples and `n_snps` is the number of features.
Input SNP data used for fitting the model.
Returns:
Tuple: U, S, and Vt matrices from the SVD, where:
Expand Down Expand Up @@ -225,52 +242,54 @@ def _fit(self, X: torch.Tensor) -> Tuple:

def fit(self, X: torch.Tensor) -> 'TorchPCA':
"""
Fit the model with `X`.
Fit the model to the input SNP data.
Args:
X (tensor of shape (n_samples, n_snps)):
The input SNP data used for fitting the model, where `n_samples` is the number of
samples and `n_snps` is the number of features.
The SNP data matrix to fit the model.
Returns:
TorchPCA:
The fitted `TorchPCA` object instance.
**TorchPCA:**
The fitted instance of `self`.
"""
self._fit(X)
return self

def transform(self, X: torch.Tensor) -> torch.Tensor:
"""
Apply dimensionality reduction on `X`.
Apply dimensionality reduction to the input SNP data using the fitted model.
Args:
X (tensor of shape (n_samples, n_snps)):
The data to transform, where `n_samples` is the number of samples and `n_snps` is the number of features.
The SNP data matrix to be transformed.
Returns:
torch.Tensor of shape (n_samples, n_components_):
The transformed SNP data onto the `n_components_` principal components.
**tensor of shape (n_samples, n_components_):**
The transformed SNP data projected onto the `n_components_` principal components,
stored in `self.X_new_`.
"""
if self.components_ is None or self.mean_ is None:
raise ValueError("The PCA model must be fitted before calling `transform`.")

return torch.matmul(X - self.mean_, self.components_.T)
self.X_new_ = torch.matmul(X - self.mean_, self.components_.T)
return self.X_new_

def fit_transform(self, X):
"""
Fit the model with `X` and apply the dimensionality reduction on `X`.
Fit the model to the SNP data and apply dimensionality reduction on the same SNP data.
Args:
X (tensor of shape n_samples, n_snps):
The input SNP data used for fitting the model, where `n_samples` is the number of
samples and `n_snps` is the number of features.
The SNP data matrix used for both fitting and transformation.
Returns:
torch.Tensor of shape (n_samples, n_components_):
The transformed SNP data onto the `n_components_` principal components.
**tensor of shape (n_samples, n_components_):**
The transformed SNP data projected onto the `n_components_` principal components,
stored in `self.X_new_`.
"""
U, S, _ = self._fit(X)
return U * S.unsqueeze(0)
self.X_new_ = U * S.unsqueeze(0)
return self.X_new_


class PCA:
Expand Down

0 comments on commit 010fedd

Please sign in to comment.