diff --git a/opt_einsum/contract.py b/opt_einsum/contract.py index 95c0c39..153d060 100644 --- a/opt_einsum/contract.py +++ b/opt_einsum/contract.py @@ -281,7 +281,7 @@ def contract_path( #> 5 defg,hd->efgh efgh->efgh ``` """ - if optimize is True: + if (optimize is True) or (optimize is None): optimize = "auto" # Hidden option, only einsum should call this @@ -341,9 +341,11 @@ def contract_path( naive_cost = helpers.flop_count(indices, inner_product, num_ops, size_dict) # Compute the path - if not isinstance(optimize, (str, paths.PathOptimizer)): + if optimize is False: + path_tuple: PathType = [tuple(range(num_ops))] + elif not isinstance(optimize, (str, paths.PathOptimizer)): # Custom path supplied - path_tuple: PathType = optimize # type: ignore + path_tuple = optimize # type: ignore elif num_ops <= 2: # Nothing to be optimized path_tuple = [tuple(range(num_ops))] @@ -536,11 +538,12 @@ def contract( - `'branch-2'` An even more restricted version of 'branch-all' that only searches the best two options at each step. Scales exponentially with the number of terms in the contraction. - - `'auto'` Choose the best of the above algorithms whilst aiming to + - `'auto', None, True` Choose the best of the above algorithms whilst aiming to keep the path finding time below 1ms. - `'auto-hq'` Aim for a high quality contraction, choosing the best of the above algorithms whilst aiming to keep the path finding time below 1sec. + - `False` will not optimize the contraction. memory_limit:- Give the upper bound of the largest intermediate tensor contract will build. - None or -1 means there is no limit. @@ -571,7 +574,7 @@ def contract( performed optimally. When NumPy is linked to a threaded BLAS, potential speedups are on the order of 20-100 for a six core machine. """ - if optimize is True: + if (optimize is True) or (optimize is None): optimize = "auto" operands_list = [subscripts] + list(operands) diff --git a/opt_einsum/tests/test_contract.py b/opt_einsum/tests/test_contract.py index 313198d..72d90f7 100644 --- a/opt_einsum/tests/test_contract.py +++ b/opt_einsum/tests/test_contract.py @@ -14,6 +14,7 @@ # NumPy is required for the majority of this file np = pytest.importorskip("numpy") + tests = [ # Test scalar-like operations "a,->a", @@ -99,6 +100,18 @@ ] +@pytest.mark.parametrize("optimize", (True, False, None)) +def test_contract_plain_types(optimize: OptimizeKind) -> None: + expr = "ij,jk,kl->il" + ops = [np.random.rand(2, 2), np.random.rand(2, 2), np.random.rand(2, 2)] + + path = contract_path(expr, *ops, optimize=optimize) + assert len(path) == 2 + + result = contract(expr, *ops, optimize=optimize) + assert result.shape == (2, 2) + + @pytest.mark.parametrize("string", tests) @pytest.mark.parametrize("optimize", _PATH_OPTIONS) def test_compare(optimize: OptimizeKind, string: str) -> None: