Skip to content

Commit

Permalink
core: Re-merged labels and label tensor
Browse files Browse the repository at this point in the history
  • Loading branch information
Somerandomguy10111 committed May 28, 2024
1 parent 723822c commit 69257c7
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 51 deletions.
43 changes: 42 additions & 1 deletion xrdpattern/core/labels/label.py → xrdpattern/core/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from dataclasses import dataclass

import torch
from torch import Tensor

from xrdpattern.core.structure import Angles, Lengths
from xrdpattern.core.structure import CrystalStructure, CrystalBase, AtomicSite

# ---------------------------------------------------------

NUM_SPACEGROUPS = 230
Expand Down Expand Up @@ -151,3 +151,44 @@ class PowderProperties:
crystallite_size: float = 500
temp_in_kelvin : int = 293


class LabelTensor(Tensor):
example_powder_experiment: PatternLabel = PatternLabel.make_empty()

def new_empty(self, *sizes, dtype=None, device=None, requires_grad=False):
dtype = dtype if dtype is not None else self.dtype
device = device if device is not None else self.device
return LabelTensor(torch.empty(*sizes, dtype=dtype, device=device, requires_grad=requires_grad))

@staticmethod
def __new__(cls, tensor) -> LabelTensor:
return torch.Tensor.as_subclass(tensor, cls)

#noinspection PyTypeChecker
def get_lattice_params(self) -> LabelTensor:
region = self.example_powder_experiment.lattice_param_region
return self[..., region.start:region.end]

# noinspection PyTypeChecker
def get_atomic_site(self, index: int) -> LabelTensor:
region = self.example_powder_experiment.atomic_site_regions[index]
return self[..., region.start:region.end]

# noinspection PyTypeChecker
def get_spacegroups(self) -> LabelTensor:
region = self.example_powder_experiment.spacegroup_region
return self[..., region.start:region.end]

# noinspection PyTypeChecker
def get_artifacts(self) -> LabelTensor:
region = self.example_powder_experiment.artifacts_region
return self[..., region.start:region.end]

# noinspection PyTypeChecker
def get_domain(self) -> LabelTensor:
region = self.example_powder_experiment.domain_region
return self[..., region.start:region.end]

# noinspection PyTypeChecker
def to_sample(self) -> PatternLabel:
raise NotImplementedError
2 changes: 0 additions & 2 deletions xrdpattern/core/labels/__init__.py

This file was deleted.

48 changes: 0 additions & 48 deletions xrdpattern/core/labels/label_tensor.py

This file was deleted.

0 comments on commit 69257c7

Please sign in to comment.