Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Filter condition is not broadcasted #1638

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions include/xtensor/xindex_view.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -780,6 +780,29 @@ namespace xt
return view_type(std::forward<E>(e), std::move(indices));
}

/**
* @brief creates a view into \a e filtered by \a condition.
*
* Returns an xt::view (2D view) with the rows selected where \a condition evaluates to \em true.
*
* @param e the underlying xexpression
* @param condition xexpression with shape of \a e which selects rows
*
* \code{.cpp}
* xarray<int> a = { { 0, 1, 2, 3 }, { 4, 5, 6, 7 }, { 8, 9, 10, 11 } };
* auto filter = xt::filter_rows(a, xt::view(a, xt::all(), 0) > 4);
* std::cout << filter << std:endl; // { {8, 9, 10, 11} }
* \endcode
*/
template <class E, class O>
inline auto filter_rows(E&& e, O&& condition) noexcept
{
auto row_indices = rowwhere(std::forward<O>(condition));
return view(e, keep(row_indices), all());
}



/**
* @brief creates a filtration of \c e filtered by \a condition.
*
Expand Down
31 changes: 31 additions & 0 deletions include/xtensor/xoperation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -890,6 +890,37 @@ namespace xt
return indices;
}

/**
* @ingroup logical_operators
* @brief return vector of row indices where arr is not zero
*
* @param arr input array
* @return vector of size_types where arr is not equal to zero
*/
template <class T>
inline auto rowwhere(const T& arr)
{
auto shape = arr.shape();
using index_type = xindex_type_t<typename T::shape_type>;
using size_type = typename T::size_type;

auto idx = xtl::make_sequence<index_type>(arr.dimension(), 0);
std::vector<size_type> indices;

size_type total_size = compute_size(shape);
for (size_type i = 0; i < total_size; i++, detail::next_idx(shape, idx))
{
if (arr.element(std::begin(idx), std::end(idx)))
{
indices.push_back(i);
}
}

return indices;
}



/**
* @ingroup logical_operators
* @brief Any
Expand Down
9 changes: 9 additions & 0 deletions test/test_xindex_view.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,15 @@ namespace xt
EXPECT_EQ(expected, a);
}

TEST(xindex_view, filter_row)
{
xarray<int> a = { { 0, 1, 2, 3 },{ 4, 5, 6, 7 },{ 8, 9, 10, 11 } };
xarray<int> res = { {8, 9, 10, 11} };
auto filter = filter_rows(a, view(a, all(), 0) > 4);

EXPECT_EQ(filter, res);
}

TEST(xindex_view, const_adapt_filter)
{
const std::vector<double> av({1,2,3,4,5,6});
Expand Down