Skip to content

Commit

Permalink
Add KDTreeMatch struct to return richer data
Browse files Browse the repository at this point in the history
  • Loading branch information
ideoforms committed Jan 20, 2024
1 parent 0932d0c commit d272822
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 37 deletions.
35 changes: 19 additions & 16 deletions source/include/signalflow/core/kdtree.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,22 @@
#include <memory>
#include <vector>

class KDTreeMatch
{
public:
KDTreeMatch(int index,
std::vector<float> coordinate,
float distance)
: index(index), coordinate(coordinate), distance(distance) {}
int index;
std::vector<float> coordinate;
float distance;

int get_index() { return this->index; }
std::vector<float> get_coordinate() { return this->coordinate; }
float get_distance() { return this->distance; }
};

class KDTreeNode
{
public:
Expand All @@ -14,9 +30,8 @@ class KDTreeNode
KDTreeNode *right,
const std::vector<std::vector<float>> &bounding_box);
~KDTreeNode();
std::pair<KDTreeNode *, float> get_nearest(const std::vector<float> &target,
KDTreeNode *current_nearest = nullptr,
float current_nearest_distance = std::numeric_limits<float>::infinity());
KDTreeMatch get_nearest(const std::vector<float> &target,
KDTreeMatch current_nearest);

int get_index();
std::vector<float> get_coordinates();
Expand All @@ -34,7 +49,7 @@ class KDTree
public:
KDTree(std::vector<std::vector<float>> data);
~KDTree();
std::vector<float> get_nearest(const std::vector<float> &target);
KDTreeMatch get_nearest(const std::vector<float> &target);

private:
size_t num_dimensions;
Expand All @@ -44,16 +59,4 @@ class KDTree
std::vector<std::vector<float>> bounding_box);
};

class KDTreeMatch
{
public:
KDTreeMatch(int index,
std::vector<float> coordinate,
float distance)
: index(index), coordinate(coordinate), distance(distance) {}
int index;
std::vector<float> coordinate;
float distance;
};

#endif // SIGNALFLOW_KDTREE_H
32 changes: 14 additions & 18 deletions source/src/core/kdtree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,16 +79,16 @@ KDTreeNode::~KDTreeNode()
}
}

std::pair<KDTreeNode *, float> KDTreeNode::get_nearest(const std::vector<float> &target,
KDTreeNode *current_nearest,
float current_nearest_distance)
KDTreeMatch KDTreeNode::get_nearest(const std::vector<float> &target,
KDTreeMatch current_nearest)
{
float distance = distance_from_point_to_point(target, this->coordinates);

if (distance < current_nearest_distance)
if (distance < current_nearest.distance)
{
current_nearest = this;
current_nearest_distance = distance;
current_nearest.coordinate = this->coordinates;
current_nearest.index = this->index;
current_nearest.distance = distance;
}

/*------------------------------------------------------------------------------
Expand All @@ -99,21 +99,17 @@ std::pair<KDTreeNode *, float> KDTreeNode::get_nearest(const std::vector<float>
* pruning based on whether the target would be to the left or right
* of this node. Need to implement and benchmark.
*------------------------------------------------------------------------------*/
if (this->left && distance_from_point_to_bounding_box(target, left->bounding_box) < current_nearest_distance)
if (this->left && distance_from_point_to_bounding_box(target, left->bounding_box) < current_nearest.distance)
{
auto result = left->get_nearest(target, current_nearest, current_nearest_distance);
current_nearest = result.first;
current_nearest_distance = result.second;
current_nearest = left->get_nearest(target, current_nearest);
}

if (this->right && distance_from_point_to_bounding_box(target, right->bounding_box) < current_nearest_distance)
if (this->right && distance_from_point_to_bounding_box(target, right->bounding_box) < current_nearest.distance)
{
auto result = right->get_nearest(target, current_nearest, current_nearest_distance);
current_nearest = result.first;
current_nearest_distance = result.second;
current_nearest = right->get_nearest(target, current_nearest);
}

return { current_nearest, current_nearest_distance };
return current_nearest;
}

int KDTreeNode::get_index()
Expand Down Expand Up @@ -235,12 +231,12 @@ KDTreeNode *KDTree::construct_subtree(std::vector<std::vector<float>> data,
return node;
}

std::vector<float> KDTree::get_nearest(const std::vector<float> &target)
KDTreeMatch KDTree::get_nearest(const std::vector<float> &target)
{
if (target.size() != this->num_dimensions)
{
throw std::runtime_error("Target has an invalid number of dimensions (expected = " + std::to_string(this->num_dimensions) + ", actual = " + std::to_string(target.size()) + ")");
}
auto result = this->root->get_nearest(target);
return result.first->get_coordinates();
auto result = this->root->get_nearest(target, KDTreeMatch(0, {}, std::numeric_limits<float>::max()));
return result;
}
5 changes: 3 additions & 2 deletions source/src/node/analysis/nearest-neighbour.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,13 @@ void NearestNeighbour::process(Buffer &out, int num_frames)

float target_value = this->target->out[0][0];
std::vector<float> target_value_vector({ target_value });
std::vector<float> output_value_vector = this->kdtree->get_nearest(target_value_vector);
KDTreeMatch match = this->kdtree->get_nearest(target_value_vector);
int index = match.index;
for (auto channel = 0; channel < this->get_num_output_channels(); channel++)
{
for (auto frame = 0; frame < num_frames; frame++)
{
this->out[channel][frame] = output_value_vector[channel];
this->out[channel][frame] = index;
}
}
}
Expand Down
5 changes: 5 additions & 0 deletions source/src/python/util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,9 @@ void init_python_util(py::module &m)
py::class_<KDTree>(m, "KDTree", "A KDTree structure")
.def(py::init<std::vector<std::vector<float>>>(), "data"_a = nullptr)
.def("get_nearest", &KDTree::get_nearest, "target"_a);

py::class_<KDTreeMatch>(m, "KDTreeMatch", "A KDTreeMatch result")
.def_property_readonly("index", &KDTreeMatch::get_index, R"pbdoc(The index)pbdoc")
.def_property_readonly("coordinate", &KDTreeMatch::get_coordinate, R"pbdoc(The coordinate)pbdoc")
.def_property_readonly("distance", &KDTreeMatch::get_distance, R"pbdoc(The distance)pbdoc");
}
12 changes: 11 additions & 1 deletion tests/test_kdtree.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,17 @@ def test_kdtree():
tree = KDTree(corpus)
target = np.array([1, 5])
nearest = tree.get_nearest(target)
assert np.all(np.isclose(target, nearest, atol=1.0))

# check that the nearest value found is sufficiently near to the target
assert np.all(np.isclose(target, nearest.coordinate, atol=1.0))

# check that nearest.index points to the item in the corpus identical to nearest.coordinate
# (subject to rounding error)
assert np.all(np.isclose(corpus[nearest.index], nearest.coordinate, atol=1e-6))

# check that nearest.distance calculates the correct distance
# (subject to rounding error)
assert nearest.distance == pytest.approx(np.linalg.norm(corpus[nearest.index] - target))

def test_kdtree_validation():
corpus = np.random.uniform(0, 10, [1024, 2])
Expand Down

0 comments on commit d272822

Please sign in to comment.