Skip to content

Commit

Permalink
intersection
Browse files Browse the repository at this point in the history
Summary: set([1,2]) & set([2,3]) == set([2])

Reviewed By: stepancheg

Differential Revision: D60375710

fbshipit-source-id: b99dde0cfe2c520169d3d830ee8df7fc6f34c02d
  • Loading branch information
perehonchuk authored and facebook-github-bot committed Oct 1, 2024
1 parent 35ea4c9 commit a8d01cd
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 2 deletions.
54 changes: 54 additions & 0 deletions starlark/src/values/types/set/methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
//! Methods for the `set` type.

use starlark_derive::starlark_module;
use starlark_map::small_set::SmallSet;

use crate as starlark;
use crate::environment::MethodsBuilder;
Expand Down Expand Up @@ -61,6 +62,39 @@ pub(crate) fn set_methods(builder: &mut MethodsBuilder) {
}
Ok(SetData { content: data })
}

/// Return a new set with elements common to the set and all others.
/// Unlike Python does not support variable number of arguments.
/// ```
/// # starlark::assert::is_true(r#"
/// x = set([1, 2, 3])
/// y = [3, 4, 5]
/// x.intersection(y) == set([3])
/// # "#);
/// ```
fn intersection<'v>(
this: SetRef<'v>,
#[starlark(require=pos)] other: ValueOfUnchecked<'v, StarlarkIter<Value<'v>>>,
heap: &'v Heap,
) -> starlark::Result<SetData<'v>> {
//TODO(romanp) check if other is set
let other_it = other.get().iterate(heap)?;
let mut other_set = SmallSet::default();
for elem in other_it {
other_set.insert_hashed(elem.get_hashed()?);
}
let mut data = SetData::default();
if other_set.is_empty() {
return Ok(data);
}

for hashed in this.content.iter_hashed() {
if other_set.contains_hashed(hashed) {
data.content.insert_hashed_unique_unchecked(hashed.copied());
}
}
Ok(data)
}
}
#[cfg(test)]
mod tests {
Expand Down Expand Up @@ -128,4 +162,24 @@ mod tests {
fn test_union_ordering_mixed() {
assert::eq("list(set([1, 3, 5]).union(set([4, 3])))", "[1, 3, 5, 4]");
}

#[test]
fn test_intersection() {
assert::eq("set([1, 2, 3]).intersection(set([3, 4, 5]))", "set([3])")
}

#[test]
fn test_intersection_empty() {
assert::eq("set([1, 2, 3]).intersection(set([]))", "set([])")
}

#[test]
fn test_intersection_iter() {
assert::eq("set([1, 2, 3]).intersection([3, 4])", "set([3])")
}

#[test]
fn test_intersection_order() {
assert::eq("list(set([1, 2, 3]).intersection([4, 3, 1]))", "[1, 3]")
}
}
4 changes: 2 additions & 2 deletions starlark/testcases/eval/go/set.star
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ asserts.fails(lambda : x.union([1, 2, {}]), "unhashable type: dict")
# intersection, set & set or set.intersection(iterable)
asserts.eq(list(set("a".elems()) & set("b".elems())), [])
asserts.eq(list(set("ab".elems()) & set("bc".elems())), ["b"])
# asserts.eq(list(set("a".elems()).intersection("b".elems())), [])
# asserts.eq(list(set("ab".elems()).intersection("bc".elems())), ["b"])
asserts.eq(list(set("a".elems()).intersection("b".elems())), [])
asserts.eq(list(set("ab".elems()).intersection("bc".elems())), ["b"])

# # symmetric difference, set ^ set or set.symmetric_difference(iterable)
# asserts.eq(set([1, 2, 3]) ^ set([4, 5, 3]), set([1, 2, 4, 5]))
Expand Down

0 comments on commit a8d01cd

Please sign in to comment.