From f3ca2141e3f5826bf8ae1fedc89b503f479242f6 Mon Sep 17 00:00:00 2001 From: Benjamin Killeen Date: Sun, 12 Mar 2023 22:30:04 -0400 Subject: [PATCH] added instruments --- deepdrr/instruments/__init__.py | 7 + deepdrr/instruments/base.py | 328 ++++++++++++++++++++++++++++++++ deepdrr/utils/data_utils.py | 29 ++- deepdrr/utils/mesh_utils.py | 98 +++++++++- 4 files changed, 452 insertions(+), 10 deletions(-) create mode 100644 deepdrr/instruments/__init__.py create mode 100644 deepdrr/instruments/base.py diff --git a/deepdrr/instruments/__init__.py b/deepdrr/instruments/__init__.py new file mode 100644 index 00000000..73771bcf --- /dev/null +++ b/deepdrr/instruments/__init__.py @@ -0,0 +1,7 @@ +"""Instruments are modeled by voxelizing the surface meshes of the instrument components. + +""" + +from .base import Instrument + +__all__ = ["Instrument"] diff --git a/deepdrr/instruments/base.py b/deepdrr/instruments/base.py new file mode 100644 index 00000000..e82eaadf --- /dev/null +++ b/deepdrr/instruments/base.py @@ -0,0 +1,328 @@ +#!/usr/bin/env python3 +import logging +from abc import ABC +from abc import abstractmethod +from pathlib import Path +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple + +import numpy as np +import pyvista as pv +from deepdrr import geo + +from ..vol import Volume +from .. import utils +from ..utils import data_utils + + +log = logging.getLogger(__name__) + + +class Instrument(Volume, ABC): + """A class for representing instruments based on voxelized surface models. + + In the DeepDRR_DATA directory, place an STL file for each material you want to use, inside a + directory determined by the class name. For example, if you have a + class called `MyTool` with steel and plastic components, place the STL files `steel.stl` and + `plastic.stl` in `DeepDRR_DATA/instruments/MyTool`. + + TODO: multiple components of the same material. + + DeepDRR_DATA + └── instruments + ├── ToolClassName + │ ├── material_1.stl + │ │── material_2.stl + │ └── material_3.stl + └── ToolClassName2 + ├── material_1.stl + │── material_2.stl + └── material_3.stl + + """ + + # Every tool should define the tip in anatomical (modeling) coordinates, which is the center point begin inserted into the body along + # the main axis of the tool, and the base, another point on that axis, so that they can be aligned. + base: geo.Point3D + tip: geo.Point3D + radius: float + + # Available materials may be found at: + # https://www.nist.gov/pml/x-ray-mass-attenuation-coefficients + + _material_mapping = { + "ABS Plastic": "polyethylene", + "Ceramic": "concrete", + "Stainless Steel": "iron", + "stainless_steel": "iron", + "steel": "iron", + "cement": "concrete", + "plastic": "polyethylene", + "metal": "iron", + "bone": "bone", + "titanium": "titanium", + } + + _default_densities = { + "polyethylene": 1.05, # polyethyelene is 0.97, but ABS plastic is 1.05 + "concrete": 1.5, + "iron": 7.5, + "titanium": 7, + "bone": 1.5, + } + + NUM_POINTS = 4000 + + def __init__( + self, + density: float = 0.1, + world_from_anatomical: Optional[geo.FrameTransform] = None, + densities: Dict[str, float] = {}, + ): + """Create the tool. + + Args: + density: The spacing of the voxelization for each component of the tool. + world_from_anatomical: Defines the pose of the tool in world. + densities: Optional overrides to the material densities + + """ + self.density = density + self._densities = self._default_densities.copy() + self._densities.update(densities) + + self.instruments_dir = data_utils.deepdrr_data_dir() / "instruments" + self.instruments_dir.mkdir(parents=True, exist_ok=True) + + self.surfaces = {} + bounds = [] + for material_dir, model_paths in self.get_model_paths(): + surface = pv.PolyData() + for p in model_paths: + s = pv.read(p) + if len(s.points) > self.NUM_POINTS: + s = s.decimate(1 - self.NUM_POINTS / len(s.points)) + surface += s + + material_dirname = ( + material_dir.name if isinstance(material_dir, Path) else material_dir + ) + self.surfaces[material_dirname] = surface + bounds.append(surface.bounds) + + bounds = np.array(bounds) + x_min, y_min, z_min = bounds[:, [0, 2, 4]].min(0) + x_max, y_max, z_max = bounds[:, [1, 3, 5]].max(0) + bounds = [x_min, x_max, y_min, y_max, z_min, z_max] + + cache_dir = self.get_cache_dir() + materials_path = cache_dir / "materials.npz".format() + anatomical_from_ijk_path = cache_dir / "anatomical_from_ijk.npy" + if materials_path.exists() and anatomical_from_ijk_path.exists(): + log.debug(f"using cached voxelization: {materials_path.absolute()}") + materials = dict(np.load(materials_path)) + anatomical_from_ijk = geo.FrameTransform(np.load(anatomical_from_ijk_path)) + else: + materials, anatomical_from_ijk = self._get_materials(density, bounds) + np.savez_compressed(materials_path, **materials) + np.save(anatomical_from_ijk_path, geo.get_data(anatomical_from_ijk)) + + # Convert from actual materials to DeepDRR compatible. + materials = dict( + (self._material_mapping[m], seg) for m, seg in materials.items() + ) + + data = np.zeros_like(list(materials.values())[0], dtype=np.float64) + for material, seg in materials.items(): + data += self._densities[material] * seg + + super().__init__( + data, + materials, + anatomical_from_ijk, + world_from_anatomical, + anatomical_coordinate_system=None, + ) + + def get_model_paths(self) -> List[Tuple[Path, List[Path]]]: + """Get the model paths associated with this Tool. + + Returns: + List[Tuple[Path, List[Path]]]: List of tuples containing the material dir and a list of paths with STL files for that material. + """ + stl_dir = self.instruments_dir / self.__class__.__name__ + model_paths = [(p.stem, [p]) for p in stl_dir.glob("*.stl")] + if not model_paths: + raise FileNotFoundError( + f"couldn't find materials for {self.__class__.__name__} in {stl_dir}" + ) + return model_paths + + def get_cache_dir(self) -> Path: + cache_dir = ( + data_utils.deepdrr_data_dir() + / "cache" + / self.__class__.__name__ + / "{}mm".format(str(self.density).replace(".", "-")) + ) + cache_dir.mkdir(exist_ok=True, parents=True) + return cache_dir + + def _get_materials(self, density, bounds): + materials = {} + for material, surface in self.surfaces.items(): + log.info( + f'voxelizing {self.__class__.__name__} "{material}" (may take a while)...' + ) + materials[material], anatomical_from_ijk = utils.mesh_utils.voxelize( + surface, + density=density, + bounds=bounds, + ) + + return materials, anatomical_from_ijk + + @property + def base_in_world(self) -> geo.Point3D: + return self.world_from_anatomical @ self.base + + @property + def tip_in_world(self) -> geo.Point3D: + return self.world_from_anatomical @ self.tip + + @property + def length_in_world(self): + return (self.tip_in_world - self.base_in_world).norm() + + def align( + self, + startpoint: geo.Point3D, + endpoint: geo.Point3D, + progress: float = 1, + distance: Optional[float] = None, + ): + """Place the tool along the line between startpoint and endpoint. + + Args: + startpoint (geo.Point3D): Startpoint in world. + endpoint (geo.Point3D): Point in world toward which the tool points. + progress (float): The fraction between startpoint and endpoint to place the tip of the tool. Defaults to 1. + distance (Optional[float], optional): The distance of the tip along the trajectory. 0 corresponds + to the tip placed at the start point, |startpoint - endpoint| at the end point. + Overrides progress if provided. Defaults to None. + + + """ + # useful: https://math.stackexchange.com/questions/180418/calculate-rotation-matrix-to-align-vector-a-to-vector-b-in-3d + if distance is not None: + progress = distance / self.length_in_world + + # interpolate along the direction of the tool to get the desired points in world. + startpoint = geo.point(startpoint) + endpoint = geo.point(endpoint) + progress = float(progress) + trajectory_vector = endpoint - startpoint + + desired_tip_in_world = startpoint.lerp(endpoint, progress) + desired_base_in_world = ( + desired_tip_in_world - trajectory_vector.hat() * self.length_in_world + ) + + self.world_from_anatomical = geo.FrameTransform.from_line_segments( + desired_tip_in_world, + desired_base_in_world, + self.tip, + self.base, + ) + + def orient( + self, + startpoint: geo.Point3D, + direction: geo.Vector3D, + distance: float = 0, + ): + return self.align( + startpoint, + startpoint + direction.hat(), + distance=distance, + ) + + def twist(self, angle: float, degrees: bool = True): + """Rotate the tool clockwise (when looking down on it) by `angle`. + + Args: + angle (float): The angle. + degrees (bool, optional): Whether `angle` is in degrees. Defaults to True. + """ + rotvec = (self.tip - self.base).hat() + rotvec *= utils.radians(angle, degrees=degrees) + self.world_from_anatomical = self.world_from_anatomical @ geo.frame_transform( + geo.Rotation.from_rotvec(rotvec) + ) + + def get_mesh_in_world(self, full: bool = True, use_cached: bool = True): + mesh = sum(self.surfaces.values(), pv.PolyData()) + mesh.transform(geo.get_data(self.world_from_anatomical), inplace=True) + # meshh += pv.Sphere( + # center=list(self.world_from_ijk @ geo.point(0, 0, 0)), radius=5 + # ) + + x, y, z = np.array(self.shape) - 1 + points = [ + [0, 0, 0], + [0, 0, z], + [0, y, 0], + [0, y, z], + [x, 0, 0], + [x, 0, z], + [x, y, 0], + [x, y, z], + ] + + points = [list(self.world_from_ijk @ geo.point(p)) for p in points] + mesh += pv.Line(points[0], points[1]) + mesh += pv.Line(points[0], points[2]) + mesh += pv.Line(points[3], points[1]) + mesh += pv.Line(points[3], points[2]) + mesh += pv.Line(points[4], points[5]) + mesh += pv.Line(points[4], points[6]) + mesh += pv.Line(points[7], points[5]) + mesh += pv.Line(points[7], points[6]) + mesh += pv.Line(points[0], points[4]) + mesh += pv.Line(points[1], points[5]) + mesh += pv.Line(points[2], points[6]) + mesh += pv.Line(points[3], points[7]) + + return mesh + + @property + def center(self) -> geo.Point3D: + return self.base.lerp(self.tip, 0.5) + + @property + def center_in_world(self) -> geo.Point3D: + return self.world_from_anatomical @ self.center + + @property + def trajectory_in_world(self) -> geo.Ray3D: + return geo.Ray3D.from_pn( + self.tip_in_world, self.tip_in_world - self.base_in_world + ) + + @property + def centerline_in_world(self) -> geo.Line3D: + return geo.line(self.tip_in_world, self.base_in_world) + + def advance(self, distance: float): + """Move the tool forward by the given distance. + + Args: + distance (float): The distance to move the tool forward. + """ + self.align( + self.tip_in_world, + self.tip_in_world + (self.tip_in_world - self.base_in_world), + distance=distance, + ) diff --git a/deepdrr/utils/data_utils.py b/deepdrr/utils/data_utils.py index 331be66e..aaa80bba 100644 --- a/deepdrr/utils/data_utils.py +++ b/deepdrr/utils/data_utils.py @@ -11,6 +11,26 @@ log = logging.getLogger(__name__) +def deepdrr_data_dir() -> Path: + """Get the data directory for DeepDRR. + + The data directory is determined by the environment variable `DEEPDRR_DATA_DIR` if it exists. + Otherwise, it is `~/datasets/DeepDRR`. If the directory does not exist, it is created. + + Returns: + Path: The data directory. + """ + if os.environ.get("DEEPDRR_DATA_DIR") is not None: + root = Path(os.environ.get("DEEPDRR_DATA_DIR")).expanduser() + else: + root = Path.home() / "datasets" / "DeepDRR_DATA" + + if not root.exists(): + root.mkdir(parents=True) + + return root + + def download( url: str, filename: Optional[str] = None, @@ -30,14 +50,7 @@ def download( Returns: Path: The path of the downloaded file, or the extracted directory. """ - if root is None and os.environ.get("DEEPDRR_DATA_DIR") is not None: - root = os.environ["DEEPDRR_DATA_DIR"] - elif root is None: - root = "~/datasets/DeepDRR_Data" - - root = Path(root).expanduser() - if not root.exists(): - root.mkdir(parents=True) + root = deepdrr_data_dir() if filename is None: filename = os.path.basename(url) diff --git a/deepdrr/utils/mesh_utils.py b/deepdrr/utils/mesh_utils.py index a1a93249..ab2afd91 100644 --- a/deepdrr/utils/mesh_utils.py +++ b/deepdrr/utils/mesh_utils.py @@ -1,10 +1,23 @@ +import logging +import os +import shutil +from pathlib import Path +from typing import List from typing import Optional +from typing import Tuple +from typing import Union -import logging +import nibabel as nib import numpy as np +import pyvista as pv +from rich.progress import Progress +from rich.progress import track import vtk from vtk.util import numpy_support as nps -import pyvista as pv + +from .. import geo +from ..utils import listify + log = logging.getLogger(__name__) @@ -99,3 +112,84 @@ def isosurface( log.warning(f"surface is not closed, with {surface.n_open_edges} open edges") return surface + + +def voxelize( + surface: pv.PolyData, + density: float = 0.2, + bounds: Optional[List[float]] = None, +) -> Tuple[np.ndarray, geo.FrameTransform]: + """Voxelize the surface mesh with the given density. + + Args: + surface (pv.PolyData): The surface. + density (Union[float, Tuple[float, float, float]]): Either a single float or a + list of floats giving the size of a voxel in x, y, z. + (This is really a spacing, but it's misnamed in pyvista.) + + Returns: + Tuple[np.ndarray, geo.FrameTransform]: The voxelized segmentation of the surface as np.uint8 and the associated world_from_ijk transform. + """ + density = listify(density, 3) + voxels = pv.voxelize(surface, density=density, check_surface=False) + + spacing = np.array(density) + if bounds is None: + bounds = surface.bounds + + x_min, x_max, y_min, y_max, z_min, z_max = bounds + size = np.array([(x_max - x_min), (y_max - y_min), (z_max - z_min)]) + if np.any(size) < 0: + raise ValueError(f"invalid bounds: {bounds}") + x, y, z = np.ceil(size / spacing).astype(int) + 1 + origin = np.array([x_min, y_min, z_min]) + world_from_ijk = geo.FrameTransform.from_rt(np.diag(spacing), origin) + ijk_from_world = world_from_ijk.inv + + data = np.zeros((x, y, z), dtype=np.uint8) + for p in track(voxels.points, "Rasterizing..."): + p = geo.point(p) + ijk = ijk_from_world @ p + i, j, k = np.array(ijk).astype(int) + data[i, j, k] = 1 + + return data, world_from_ijk + + +def voxelize_file(path: str, output_path: str, **kwargs): + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + surface = pv.read(path) + try: + data, world_from_ijk = voxelize(surface, **kwargs) + except ValueError: + log.warning(f"skipped {path} due to size error") + return + + img = nib.Nifti1Image(data, geo.get_data(geo.RAS_from_LPS @ world_from_ijk)) + nib.save(img, output_path) + + +def voxelize_dir(input_dir: str, output_dir: str, use_cached: bool = True, **kwargs): + input_dir = Path(input_dir) + output_dir = Path(output_dir) + + if output_dir.exists(): + shutil.rmtree(output_dir) + output_dir.mkdir() + + input_len = len(input_dir.parts) + paths: List[Path] = list(input_dir.glob("*/*.stl")) + output_path: Path + with Progress() as progress: + surfaces_voxelized = progress.add_task("Voxelizing surfaces", total=len(paths)) + for path in paths: + log.info(f"voxelizing {path}") + output_path = output_dir / os.path.join(*path.parts[input_len:]) + output_path = output_path.with_suffix(".nii.gz") + if output_path.exists() and use_cached: + progress.advance(surfaces_voxelized) + continue + + voxelize_file(path, output_path, **kwargs) + progress.advance(surfaces_voxelized)