From a9907d0bf1456cdb0c449ebb23d0f1688f852be4 Mon Sep 17 00:00:00 2001 From: Alexander Goscinski Date: Sun, 11 Feb 2024 19:19:30 +0100 Subject: [PATCH] checkpoint10: all-deps test pass and made progress on TorchScriptabilty, torch-test test_torch_script_correlate_density_angular_selection passes - made `like` parameter in _parse_selected_keys --- .../rascaline/utils/clebsch_gordan/_clebsch_gordan.py | 4 ++-- .../rascaline/utils/clebsch_gordan/correlate_density.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/rascaline/rascaline/utils/clebsch_gordan/_clebsch_gordan.py b/python/rascaline/rascaline/utils/clebsch_gordan/_clebsch_gordan.py index 1d152a25b..497352482 100644 --- a/python/rascaline/rascaline/utils/clebsch_gordan/_clebsch_gordan.py +++ b/python/rascaline/rascaline/utils/clebsch_gordan/_clebsch_gordan.py @@ -6,7 +6,7 @@ from typing import List, Optional, Tuple, Union from . import _cg_cache, _dispatch -from ._classes import Labels, TensorBlock, TensorMap, LabelsEntry, torch_jit_is_scripting, torch_jit_annotate, is_labels +from ._classes import Labels, TensorBlock, TensorMap, LabelsEntry, torch_jit_is_scripting, torch_jit_annotate, is_labels, Array # ================================================================== @@ -45,9 +45,9 @@ def _standardize_keys(tensor: TensorMap) -> TensorMap: def _parse_selected_keys( n_iterations: int, + like: Array, angular_cutoff: Optional[int] = None, selected_keys: Optional[Union[Labels, List[Union[Labels, None]]]] = None, - like=None, ) -> List[Union[None, Labels]]: """ Parses the `selected_keys` argument passed to public functions. Checks the diff --git a/python/rascaline/rascaline/utils/clebsch_gordan/correlate_density.py b/python/rascaline/rascaline/utils/clebsch_gordan/correlate_density.py index a40a3450c..edb408f3b 100644 --- a/python/rascaline/rascaline/utils/clebsch_gordan/correlate_density.py +++ b/python/rascaline/rascaline/utils/clebsch_gordan/correlate_density.py @@ -173,9 +173,9 @@ def _correlate_density( # Parse the selected keys selected_keys_: List[Union[Labels, None]] = _clebsch_gordan._parse_selected_keys( n_iterations=n_iterations, + like=density.keys.values, angular_cutoff=angular_cutoff, selected_keys=selected_keys, - like=density.keys.values, ) # Parse the bool flags that control skipping of redundant CG combinations # and TensorMap output from each iteration