entropic_partial_wasserstein not stable #723

Open
wzm2256 opened this issue Mar 12, 2025 · 1 comment

wzm2256 commented Mar 12, 2025

Describe the bug

The `entropic_partial_wasserstein` function returns NaN when the entropic regularization `reg` (epsilon) is small.

To Reproduce

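import ot
import torch
import numpy as np

def compute_OT(M, alpha, beta, epsilon):
    s1, s2 = M.shape[0], M.shape[1]
    assert s1 == s2
    unif_vec = ot.unif(s1)
    # Source marginal has total mass 1/beta, target marginal has total mass 1
    a, b = unif_vec/beta, unif_vec
    # Transport exactly m=alpha of mass with entropic regularization reg=epsilon
    pi_1_np = ot.partial.entropic_partial_wasserstein(a, b, M, m=alpha, reg=epsilon)
    print(f"Original: sum(pi) = {pi_1_np.sum():.4f}, alpha = {alpha:.4f}")


beta = 0.35
alpha = 0.01


M_1 = torch.load('M_1.pt')  # square cost matrix loaded from disk
print(f"M_1 norm = {np.linalg.norm(M_1):.2f}\n")

# Large regularization: the plan has total mass alpha as expected
epsilon = 10.
compute_OT(M_1, alpha, beta, epsilon)

# Small regularization: the plan is NaN
epsilon = 0.1
compute_OT(M_1, alpha, beta, epsilon)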

Output

Original: sum(pi) = 0.0100, alpha = 0.0100
G:\Mycode\POT\ot\partial.py:698: RuntimeWarning: divide by zero encountered in divide 
  np.multiply(K, m / np.sum(K), out=K)
G:\Mycode\POT\ot\partial.py:698: RuntimeWarning: invalid value encountered in multiply
  np.multiply(K, m / np.sum(K), out=K)
Warning: numerical errors at iteration 0
Original: sum(pi) = nan, alpha = 0.0100

With epsilon = 10 the plan transports exactly alpha = 0.01 of mass, but with epsilon = 0.1 the result is NaN.
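The warnings point at `np.multiply(K, m / np.sum(K), out=K)` in `ot/partial.py`. A plausible reading is that when `reg` is small relative to the entries of `M`, every entry of `exp(-M / reg)` underflows to 0, so `np.sum(K)` is 0, `m / 0` is `inf` (the divide-by-zero warning), and `0 * inf` is `nan` (the invalid-value warning). The sketch below reproduces that mechanism with a made-up cost matrix (entries of order 100 are an assumption, since the values in `M_1.pt` are not shown) and compares it with the same normalization computed in the log domain using `scipy.special.logsumexp`. The helpers `naive_init` and `logdomain_init` are hypothetical names, and the sketch only covers this initialization step, not POT's full iteration.

```python
import numpy as np
from scipy.special import logsumexp

# Hypothetical cost matrix: entries in [100, 150], assumed large enough
# relative to reg=0.1 to make exp(-M / reg) underflow everywhere.
rng = np.random.default_rng(0)
M = 100.0 + 50.0 * rng.random((5, 5))
m = 0.01  # mass to transport, as alpha in the report


def naive_init(M, reg, m):
    # Mirrors the failing step: K = exp(-M / reg), then rescale K to total mass m.
    with np.errstate(divide="ignore", invalid="ignore", under="ignore"):
        K = np.exp(-M / reg)
        return K * (m / np.sum(K))


def logdomain_init(M, reg, m):
    # Same quantity computed in the log domain:
    # log K = -M/reg - logsumexp(-M/reg) + log(m), so sum(K) equals m up to rounding.
    logK = -M / reg
    logK = logK - logsumexp(logK) + np.log(m)
    return np.exp(logK)


for reg in (10.0, 0.1):
    K_naive = naive_init(M, reg, m)
    K_log = logdomain_init(M, reg, m)
    print(f"reg={reg}: naive sum={K_naive.sum()}, log-domain sum={K_log.sum():.4f}")
```

At `reg=10.0` both versions agree; at `reg=0.1` the naive version is NaN everywhere while the log-domain version still sums to `m`. A complete fix would presumably also need the subsequent projection updates to be stabilized, since entries that underflow to zero could still produce `0 / 0` later on.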

Expected behavior

Environment (please complete the following information):

  • OS (e.g. MacOS, Windows, Linux): Windows
  • Python version: 3.10
  • How was POT installed (source, pip, conda): pip
  • Build command you used (if compiling from source):
  • Only for GPU related bugs:
    • CUDA version:
    • GPU models and configuration:
    • Any other relevant information:

Output of the following code snippet:

import platform; print(platform.platform())
import sys; print("Python", sys.version)
import numpy; print("NumPy", numpy.__version__)
import scipy; print("SciPy", scipy.__version__)
import ot; print("POT", ot.__version__)

Additional context

wzm2256 commented Mar 12, 2025

I will create a pull request soon.
