Skip to content

Commit

Permalink
checkpoint10: all-deps test pass and made progress on TorchScriptabil…
Browse files Browse the repository at this point in the history
…ty, torch-test test_torch_script_correlate_density_angular_selection passes

- made `like` parameter in _parse_selected_keys
  • Loading branch information
agoscinski committed Feb 11, 2024
1 parent 9d23bfb commit a9907d0
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import List, Optional, Tuple, Union

from . import _cg_cache, _dispatch
from ._classes import Labels, TensorBlock, TensorMap, LabelsEntry, torch_jit_is_scripting, torch_jit_annotate, is_labels
from ._classes import Labels, TensorBlock, TensorMap, LabelsEntry, torch_jit_is_scripting, torch_jit_annotate, is_labels, Array


# ==================================================================
Expand Down Expand Up @@ -45,9 +45,9 @@ def _standardize_keys(tensor: TensorMap) -> TensorMap:

def _parse_selected_keys(
n_iterations: int,
like: Array,
angular_cutoff: Optional[int] = None,
selected_keys: Optional[Union[Labels, List[Union[Labels, None]]]] = None,
like=None,
) -> List[Union[None, Labels]]:
"""
Parses the `selected_keys` argument passed to public functions. Checks the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,9 @@ def _correlate_density(
# Parse the selected keys
selected_keys_: List[Union[Labels, None]] = _clebsch_gordan._parse_selected_keys(
n_iterations=n_iterations,
like=density.keys.values,
angular_cutoff=angular_cutoff,
selected_keys=selected_keys,
like=density.keys.values,
)
# Parse the bool flags that control skipping of redundant CG combinations
# and TensorMap output from each iteration
Expand Down

0 comments on commit a9907d0

Please sign in to comment.