diff --git a/tomotwin/modules/tools/umap.py b/tomotwin/modules/tools/umap.py index 5ee9e89..0e83252 100644 --- a/tomotwin/modules/tools/umap.py +++ b/tomotwin/modules/tools/umap.py @@ -9,7 +9,6 @@ except ImportError: print("cuml can't be loaded") -import mrcfile import numpy as np import pandas as pd from numpy.typing import ArrayLike @@ -98,28 +97,6 @@ def calcuate_umap( return embedding, reducer - def create_embedding_mask(self, embeddings: pd.DataFrame): - """ - Creates mask where each individual subvolume of the running windows gets an individual ID - """ - print("Create embedding mask") - Z = embeddings.attrs["tomogram_input_shape"][0] - Y = embeddings.attrs["tomogram_input_shape"][1] - X = embeddings.attrs["tomogram_input_shape"][2] - stride = embeddings.attrs["stride"][0] - segmentation_array = np.zeros(shape=(Z, Y, X), dtype=np.float32) - z = np.array(embeddings["Z"], dtype=int) - y = np.array(embeddings["Y"], dtype=int) - x = np.array(embeddings["X"], dtype=int) - - values = np.array(range(1, len(x) + 1)) - for stride_x in tqdm(list(range(stride))): - for stride_y in range(stride): - for stride_z in range(stride): - index = (z + stride_z, y + stride_y, x + stride_x) - segmentation_array[index] = values - - return segmentation_array def run(self, args): print("Read data") @@ -144,23 +121,18 @@ def run(self, args): os.makedirs(out_pth,exist_ok=True) fname = os.path.splitext(os.path.basename(args.input))[0] df_embeddings = pd.DataFrame(umap_embeddings) + df_embeddings.reset_index(drop=True, inplace=True) + embeddings.reset_index(drop=True, inplace=True) print("Write embeedings to disk") df_embeddings.columns = [f"umap_{i}" for i in range(umap_embeddings.shape[1])] + df_embeddings = pd.concat([embeddings[['X', 'Y', 'Z']], df_embeddings], axis=1) + df_embeddings.attrs['embeddings_attrs'] = embeddings.attrs + df_embeddings.attrs['embeddings_path'] = os.path.realpath(args.input) + df_embeddings.to_pickle(os.path.join(out_pth,fname+".tumap")) print("Write umap model to disk") pickle.dump(fitted_umap, open(os.path.join(out_pth, fname + "_umap_model.pkl"), "wb")) - print("Calculate label mask and write it to disk") - embedding_mask = self.create_embedding_mask(embeddings) - with mrcfile.new( - os.path.join( - args.output, - fname + "_label_mask.mrci", - ), - overwrite=True, - ) as mrc: - mrc.set_data(embedding_mask) - print("Done")