diff --git a/test/test_batch_timeframe.py b/test/test_batch_timeframe.py index 8886c59..b6cb815 100644 --- a/test/test_batch_timeframe.py +++ b/test/test_batch_timeframe.py @@ -437,6 +437,41 @@ def test_batch_timeframe_subtract_two_instances_successfully(): assert btf2.duration == btf1.duration +@pytest.mark.skip(reason="Not implemented yet") +def test_batch_timeframe_subtract_superset_from_subset(): + tf_long = TimeFrame(datetime(2022, 3, 1, 12), datetime(2022, 3, 1, 13)) + tf_a = BatchTimeFrame([tf_long]) + + tf_sub_a = TimeFrame(datetime(2022, 3, 1, 12, 30), datetime(2022, 3, 1, 12, 45)) + tf_b = BatchTimeFrame([tf_sub_a]) + + result = tf_a - tf_b + + assert result == BatchTimeFrame( + [ + TimeFrame(datetime(2022, 3, 1, 12), datetime(2022, 3, 1, 12, 30)) + - timedelta(microseconds=1), + TimeFrame( + datetime(2022, 3, 1, 12, 45) + timedelta(microseconds=1), + datetime(2022, 3, 1, 13), + ), + ] + ) + + +def test_batch_timeframe_subtracted_from_batch_timeframe_raise_not_implemented_error(): + tf1 = TimeFrame(datetime(2021, 1, 18, 10), datetime(2021, 1, 18, 11)) + tf2 = TimeFrame(datetime(2021, 1, 18, 12), datetime(2021, 1, 18, 14)) + tf3 = TimeFrame(datetime(2021, 1, 18, 18), datetime(2021, 1, 18, 20)) + + tf_list1 = [tf1, tf2, tf3] + + btf1 = BatchTimeFrame(tf_list1) + + with pytest.raises(NotImplementedError): + btf1 - btf1 + + def test_batch_timeframe_subtract_with_timeframe_successfully(): tf1 = TimeFrame(datetime(2021, 1, 18, 10), datetime(2021, 1, 18, 11)) tf2 = TimeFrame(datetime(2021, 1, 18, 12), datetime(2021, 1, 18, 14)) diff --git a/timeframe/timeframe.py b/timeframe/timeframe.py index 627ca5e..64a6407 100644 --- a/timeframe/timeframe.py +++ b/timeframe/timeframe.py @@ -190,19 +190,15 @@ def __sub__(self, tf: BaseTimeFrame) -> "BatchTimeFrame": if not isinstance(tf, BaseTimeFrame): raise TypeError(f"{tf} should be a BaseTimeFrame") - if not isinstance(tf, BatchTimeFrame): - candidates = [tf] - else: - candidates = tf - - result = list(self) + if isinstance(tf, TimeFrame): + return BatchTimeFrame( + [current_timeframe - tf for current_timeframe in self] + ) - for candidate in candidates: - for index, current_timeframe in enumerate(result): - if candidate._has_common_ground(current_timeframe): - result[index] = current_timeframe - candidate + if isinstance(tf, _Empty): + return self - return BatchTimeFrame(result) + raise NotImplementedError("BatchTimeFrame is not supported for subtraction!") def __repr__(self) -> str: return "\n".join(str(tf) for tf in list(self))