From 908848932d87d537941b26fd3ff7dfd0bae903a3 Mon Sep 17 00:00:00 2001 From: Chin-Yun Yu Date: Fri, 5 Apr 2024 08:38:02 +0000 Subject: [PATCH] Refactor weight calculation in puma function --- kamui/core.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/kamui/core.py b/kamui/core.py index 5118527..7456ab8 100644 --- a/kamui/core.py +++ b/kamui/core.py @@ -200,16 +200,21 @@ def cal_Ek(K, psi, i, j): G.add_edges(edges[:, 0], edges[:, 1], weight, np.zeros_like(weight)) + a = e10 - e00 + flip_mask = a < 0 tmp_st_weight = np.zeros((2, total_nodes)) - - for i, a in enumerate(e10 - e00): - u, v = edges[i] - if a > 0: - tmp_st_weight[0, u] += a - tmp_st_weight[1, v] += a - else: - tmp_st_weight[1, u] -= a - tmp_st_weight[0, v] -= a + flip_index = np.stack((flip_mask.astype(np.int), 1 - flip_mask.astype(np.int)), axis=1) + positive_a = np.where(flip_mask, -a, a) + np.add.at(tmp_st_weight, (flip_index.ravel(), edges.ravel()), positive_a.repeat(2)) + + # for i, a in enumerate(e10 - e00): + # u, v = edges[i] + # if a > 0: + # tmp_st_weight[0, u] += a + # tmp_st_weight[1, v] += a + # else: + # tmp_st_weight[1, u] -= a + # tmp_st_weight[0, v] -= a for i in range(total_nodes): G.add_tedge(i, tmp_st_weight[0, i], tmp_st_weight[1, i])