Skip to content

Commit

Permalink
Refactor weight calculation in puma function
Browse files Browse the repository at this point in the history
  • Loading branch information
yoyolicoris committed Apr 5, 2024
1 parent 1119bbc commit 9088489
Showing 1 changed file with 14 additions and 9 deletions.
23 changes: 14 additions & 9 deletions kamui/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,16 +200,21 @@ def cal_Ek(K, psi, i, j):

G.add_edges(edges[:, 0], edges[:, 1], weight, np.zeros_like(weight))

a = e10 - e00
flip_mask = a < 0
tmp_st_weight = np.zeros((2, total_nodes))

for i, a in enumerate(e10 - e00):
u, v = edges[i]
if a > 0:
tmp_st_weight[0, u] += a
tmp_st_weight[1, v] += a
else:
tmp_st_weight[1, u] -= a
tmp_st_weight[0, v] -= a
flip_index = np.stack((flip_mask.astype(np.int), 1 - flip_mask.astype(np.int)), axis=1)
positive_a = np.where(flip_mask, -a, a)
np.add.at(tmp_st_weight, (flip_index.ravel(), edges.ravel()), positive_a.repeat(2))

# for i, a in enumerate(e10 - e00):
# u, v = edges[i]
# if a > 0:
# tmp_st_weight[0, u] += a
# tmp_st_weight[1, v] += a
# else:
# tmp_st_weight[1, u] -= a
# tmp_st_weight[0, v] -= a

for i in range(total_nodes):
G.add_tedge(i, tmp_st_weight[0, i], tmp_st_weight[1, i])
Expand Down

0 comments on commit 9088489

Please sign in to comment.