diff --git a/include/xtensor/xindex_view.hpp b/include/xtensor/xindex_view.hpp index c01e0ce03..11af54ec0 100644 --- a/include/xtensor/xindex_view.hpp +++ b/include/xtensor/xindex_view.hpp @@ -780,6 +780,29 @@ namespace xt return view_type(std::forward(e), std::move(indices)); } + /** + * @brief creates a view into \a e filtered by \a condition. + * + * Returns an xt::view (2D view) with the rows selected where \a condition evaluates to \em true. + * + * @param e the underlying xexpression + * @param condition xexpression with shape of \a e which selects rows + * + * \code{.cpp} + * xarray a = { { 0, 1, 2, 3 }, { 4, 5, 6, 7 }, { 8, 9, 10, 11 } }; + * auto filter = xt::filter_rows(a, xt::view(a, xt::all(), 0) > 4); + * std::cout << filter << std:endl; // { {8, 9, 10, 11} } + * \endcode + */ + template + inline auto filter_rows(E&& e, O&& condition) noexcept + { + auto row_indices = rowwhere(std::forward(condition)); + return view(e, keep(row_indices), all()); + } + + + /** * @brief creates a filtration of \c e filtered by \a condition. * diff --git a/include/xtensor/xoperation.hpp b/include/xtensor/xoperation.hpp index 0c8ded582..731db5031 100644 --- a/include/xtensor/xoperation.hpp +++ b/include/xtensor/xoperation.hpp @@ -890,6 +890,37 @@ namespace xt return indices; } + /** + * @ingroup logical_operators + * @brief return vector of row indices where arr is not zero + * + * @param arr input array + * @return vector of size_types where arr is not equal to zero + */ + template + inline auto rowwhere(const T& arr) + { + auto shape = arr.shape(); + using index_type = xindex_type_t; + using size_type = typename T::size_type; + + auto idx = xtl::make_sequence(arr.dimension(), 0); + std::vector indices; + + size_type total_size = compute_size(shape); + for (size_type i = 0; i < total_size; i++, detail::next_idx(shape, idx)) + { + if (arr.element(std::begin(idx), std::end(idx))) + { + indices.push_back(i); + } + } + + return indices; + } + + + /** * @ingroup logical_operators * @brief Any diff --git a/test/test_xindex_view.cpp b/test/test_xindex_view.cpp index b55e8faf1..28d5a3b9d 100644 --- a/test/test_xindex_view.cpp +++ b/test/test_xindex_view.cpp @@ -153,6 +153,15 @@ namespace xt EXPECT_EQ(expected, a); } + TEST(xindex_view, filter_row) + { + xarray a = { { 0, 1, 2, 3 },{ 4, 5, 6, 7 },{ 8, 9, 10, 11 } }; + xarray res = { {8, 9, 10, 11} }; + auto filter = filter_rows(a, view(a, all(), 0) > 4); + + EXPECT_EQ(filter, res); + } + TEST(xindex_view, const_adapt_filter) { const std::vector av({1,2,3,4,5,6});