diff --git a/starlark/src/values/types/set/methods.rs b/starlark/src/values/types/set/methods.rs index 278df588..25afbd21 100644 --- a/starlark/src/values/types/set/methods.rs +++ b/starlark/src/values/types/set/methods.rs @@ -18,6 +18,7 @@ //! Methods for the `set` type. use starlark_derive::starlark_module; +use starlark_map::small_set::SmallSet; use crate as starlark; use crate::environment::MethodsBuilder; @@ -61,6 +62,39 @@ pub(crate) fn set_methods(builder: &mut MethodsBuilder) { } Ok(SetData { content: data }) } + + /// Return a new set with elements common to the set and all others. + /// Unlike Python does not support variable number of arguments. + /// ``` + /// # starlark::assert::is_true(r#" + /// x = set([1, 2, 3]) + /// y = [3, 4, 5] + /// x.intersection(y) == set([3]) + /// # "#); + /// ``` + fn intersection<'v>( + this: SetRef<'v>, + #[starlark(require=pos)] other: ValueOfUnchecked<'v, StarlarkIter>>, + heap: &'v Heap, + ) -> starlark::Result> { + //TODO(romanp) check if other is set + let other_it = other.get().iterate(heap)?; + let mut other_set = SmallSet::default(); + for elem in other_it { + other_set.insert_hashed(elem.get_hashed()?); + } + let mut data = SetData::default(); + if other_set.is_empty() { + return Ok(data); + } + + for hashed in this.content.iter_hashed() { + if other_set.contains_hashed(hashed) { + data.content.insert_hashed_unique_unchecked(hashed.copied()); + } + } + Ok(data) + } } #[cfg(test)] mod tests { @@ -128,4 +162,24 @@ mod tests { fn test_union_ordering_mixed() { assert::eq("list(set([1, 3, 5]).union(set([4, 3])))", "[1, 3, 5, 4]"); } + + #[test] + fn test_intersection() { + assert::eq("set([1, 2, 3]).intersection(set([3, 4, 5]))", "set([3])") + } + + #[test] + fn test_intersection_empty() { + assert::eq("set([1, 2, 3]).intersection(set([]))", "set([])") + } + + #[test] + fn test_intersection_iter() { + assert::eq("set([1, 2, 3]).intersection([3, 4])", "set([3])") + } + + #[test] + fn test_intersection_order() { + assert::eq("list(set([1, 2, 3]).intersection([4, 3, 1]))", "[1, 3]") + } } diff --git a/starlark/testcases/eval/go/set.star b/starlark/testcases/eval/go/set.star index 0945b643..ccd2b16d 100644 --- a/starlark/testcases/eval/go/set.star +++ b/starlark/testcases/eval/go/set.star @@ -70,8 +70,8 @@ asserts.fails(lambda : x.union([1, 2, {}]), "unhashable type: dict") # intersection, set & set or set.intersection(iterable) asserts.eq(list(set("a".elems()) & set("b".elems())), []) asserts.eq(list(set("ab".elems()) & set("bc".elems())), ["b"]) -# asserts.eq(list(set("a".elems()).intersection("b".elems())), []) -# asserts.eq(list(set("ab".elems()).intersection("bc".elems())), ["b"]) +asserts.eq(list(set("a".elems()).intersection("b".elems())), []) +asserts.eq(list(set("ab".elems()).intersection("bc".elems())), ["b"]) # # symmetric difference, set ^ set or set.symmetric_difference(iterable) # asserts.eq(set([1, 2, 3]) ^ set([4, 5, 3]), set([1, 2, 4, 5]))