From 69257c7346731aad55f9126428e28761c1c58c44 Mon Sep 17 00:00:00 2001 From: Daniel Hollarek Date: Tue, 28 May 2024 04:44:03 +0200 Subject: [PATCH] core: Re-merged labels and label tensor --- .../core/{labels/label.py => labels.py} | 43 ++++++++++++++++- xrdpattern/core/labels/__init__.py | 2 - xrdpattern/core/labels/label_tensor.py | 48 ------------------- 3 files changed, 42 insertions(+), 51 deletions(-) rename xrdpattern/core/{labels/label.py => labels.py} (75%) delete mode 100644 xrdpattern/core/labels/__init__.py delete mode 100644 xrdpattern/core/labels/label_tensor.py diff --git a/xrdpattern/core/labels/label.py b/xrdpattern/core/labels.py similarity index 75% rename from xrdpattern/core/labels/label.py rename to xrdpattern/core/labels.py index 2b7b0aa..dfcab7e 100644 --- a/xrdpattern/core/labels/label.py +++ b/xrdpattern/core/labels.py @@ -3,10 +3,10 @@ from dataclasses import dataclass import torch +from torch import Tensor from xrdpattern.core.structure import Angles, Lengths from xrdpattern.core.structure import CrystalStructure, CrystalBase, AtomicSite - # --------------------------------------------------------- NUM_SPACEGROUPS = 230 @@ -151,3 +151,44 @@ class PowderProperties: crystallite_size: float = 500 temp_in_kelvin : int = 293 + +class LabelTensor(Tensor): + example_powder_experiment: PatternLabel = PatternLabel.make_empty() + + def new_empty(self, *sizes, dtype=None, device=None, requires_grad=False): + dtype = dtype if dtype is not None else self.dtype + device = device if device is not None else self.device + return LabelTensor(torch.empty(*sizes, dtype=dtype, device=device, requires_grad=requires_grad)) + + @staticmethod + def __new__(cls, tensor) -> LabelTensor: + return torch.Tensor.as_subclass(tensor, cls) + + #noinspection PyTypeChecker + def get_lattice_params(self) -> LabelTensor: + region = self.example_powder_experiment.lattice_param_region + return self[..., region.start:region.end] + + # noinspection PyTypeChecker + def get_atomic_site(self, index: int) -> LabelTensor: + region = self.example_powder_experiment.atomic_site_regions[index] + return self[..., region.start:region.end] + + # noinspection PyTypeChecker + def get_spacegroups(self) -> LabelTensor: + region = self.example_powder_experiment.spacegroup_region + return self[..., region.start:region.end] + + # noinspection PyTypeChecker + def get_artifacts(self) -> LabelTensor: + region = self.example_powder_experiment.artifacts_region + return self[..., region.start:region.end] + + # noinspection PyTypeChecker + def get_domain(self) -> LabelTensor: + region = self.example_powder_experiment.domain_region + return self[..., region.start:region.end] + + # noinspection PyTypeChecker + def to_sample(self) -> PatternLabel: + raise NotImplementedError diff --git a/xrdpattern/core/labels/__init__.py b/xrdpattern/core/labels/__init__.py deleted file mode 100644 index 4ecab1e..0000000 --- a/xrdpattern/core/labels/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .label import PatternLabel, Artifacts, PowderProperties -from .label_tensor import LabelTensor diff --git a/xrdpattern/core/labels/label_tensor.py b/xrdpattern/core/labels/label_tensor.py deleted file mode 100644 index b95d881..0000000 --- a/xrdpattern/core/labels/label_tensor.py +++ /dev/null @@ -1,48 +0,0 @@ -from __future__ import annotations - -import torch -from torch import Tensor - -from xrdpattern.core.labels import PatternLabel - - -class LabelTensor(Tensor): - example_powder_experiment: PatternLabel = PatternLabel.make_empty() - - def new_empty(self, *sizes, dtype=None, device=None, requires_grad=False): - dtype = dtype if dtype is not None else self.dtype - device = device if device is not None else self.device - return LabelTensor(torch.empty(*sizes, dtype=dtype, device=device, requires_grad=requires_grad)) - - @staticmethod - def __new__(cls, tensor) -> LabelTensor: - return torch.Tensor.as_subclass(tensor, cls) - - #noinspection PyTypeChecker - def get_lattice_params(self) -> LabelTensor: - region = self.example_powder_experiment.lattice_param_region - return self[..., region.start:region.end] - - # noinspection PyTypeChecker - def get_atomic_site(self, index: int) -> LabelTensor: - region = self.example_powder_experiment.atomic_site_regions[index] - return self[..., region.start:region.end] - - # noinspection PyTypeChecker - def get_spacegroups(self) -> LabelTensor: - region = self.example_powder_experiment.spacegroup_region - return self[..., region.start:region.end] - - # noinspection PyTypeChecker - def get_artifacts(self) -> LabelTensor: - region = self.example_powder_experiment.artifacts_region - return self[..., region.start:region.end] - - # noinspection PyTypeChecker - def get_domain(self) -> LabelTensor: - region = self.example_powder_experiment.domain_region - return self[..., region.start:region.end] - - # noinspection PyTypeChecker - def to_sample(self) -> PatternLabel: - raise NotImplementedError