jnp.ldexp overflows when it shouldn't #28040


Closed
jakevdp opened this issue Apr 15, 2025 · 8 comments · Fixed by #28158
@jakevdp (Collaborator) commented Apr 15, 2025

For example:

In [1]: import numpy as np

In [2]: np.ldexp(np.float16(0.5), 16)
Out[2]: np.float16(32768.0)

In [3]: import jax.numpy as jnp

In [4]: jnp.ldexp(jnp.float16(0.5), 16)
Out[4]: Array(inf, dtype=float16)

This is because jnp.ldexp has a naive implementation, essentially return x * 2 ** n, which overflows whenever the intermediate 2 ** n exceeds the dtype's maximum, even when the final result is representable.
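The failure mode can be seen directly in NumPy (a sketch; the power of two is deliberately computed in float16, mimicking what the naive implementation does):

```python
import numpy as np

# Naive ldexp: compute 2**n in the input dtype, then multiply.
# In float16, 2.0**16 = 65536 already exceeds the dtype max (65504),
# so the intermediate rounds to inf and poisons the product, even
# though the true answer 0.5 * 2**16 = 32768 is representable.
x = np.float16(0.5)
two_pow_n = np.float16(2) ** np.float16(16)
print(two_pow_n)        # inf
print(x * two_pow_n)    # inf
print(np.ldexp(x, 16))  # 32768.0, the correctly rounded result
```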

@jakevdp jakevdp added the bug Something isn't working label Apr 15, 2025
@pearu (Collaborator) commented Apr 15, 2025

Since multiplication by a power of 2 is exact, replacing the naive implementation with (x * 2 ** (n // 2)) * 2 ** (n - n // 2) (or with (x * 2) * (2 ** (n - 1))) works around this problem (warning: untested for all float16 values).

Here's another example case to watch out while fixing this issue:

>>> np.ldexp(np.float16(0.001), 23)
8392.0
>>> jnp.ldexp(jnp.float16(0.001), 23)
Array(inf, dtype=float16)
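A minimal sketch of the split-factor idea (the wrapper name is my own; as the benchmark later in this thread shows, splitting alone does not fix every case, but it handles the overflow example above):

```python
import numpy as np

# Split 2**n into two smaller powers so no single intermediate
# overflows when the final result is representable.
def ldexp_split(x, n):
    d = x.dtype.type  # keep everything in the input dtype
    return (x * d(2) ** d(n // 2)) * (d(2) ** d(n - n // 2))

print(ldexp_split(np.float16(0.5), 16))  # 32768.0, matches np.ldexp
```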

@jakevdp (Collaborator, Author) commented Apr 15, 2025

Since multiplication by a power of 2 is exact, replacing the naive implementation with (x * 2 ** (n // 2)) * 2 ** (n - n // 2) (or with (x * 2) * (2 ** (n - 1))) works around this problem (warning: untested for all float16 values).

I like this idea! I did a bit of an experiment to see if we could do this via bitwise ops (i.e. cast x1 to uint, extract the exponent bits, add x2, and insert the new exponent back in) but the simple approach here fails for subnormal numbers. That case is complicated enough that I think it would be better to stick with multiplication.
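For illustration, here is roughly what that bitwise experiment looks like for float16 (a sketch with a hypothetical helper name; the exponent field occupies bits 10-14). It works for normal inputs, but subnormals have a zero exponent field and no implicit leading 1, so adjusting their exponent bits does not scale the value correctly:

```python
import numpy as np

def ldexp_bits_f16(x, n):
    # Add n directly to the float16 exponent field (bits 10-14).
    # No handling of subnormals, overflow, or special values.
    bits = np.float16(x).view(np.uint16)
    return (bits + np.uint16(n << 10)).view(np.float16)

print(ldexp_bits_f16(np.float16(0.5), 3))  # 4.0 -- correct for a normal input
tiny = np.finfo(np.float16).smallest_subnormal
print(ldexp_bits_f16(tiny, 3), np.ldexp(tiny, 3))  # disagree for subnormals
```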

@jakevdp (Collaborator, Author) commented Apr 15, 2025

Here's the most extreme case we'd have to handle:

import numpy as np
dtype = np.dtype('float16')
x = np.finfo(dtype).smallest_subnormal  # x = 6e-08
N = int(np.log2(np.finfo(dtype).max) - np.log2(x))  # N = 40
print(np.ldexp(x, N - 1))  # np.float16(32768.0)

@pearu (Collaborator) commented Apr 16, 2025

For float16, the extreme values of N for which ldexp still returns a finite nonzero result are -40 and 39.

Consider the following script

import numpy as np
import warnings
warnings.filterwarnings("ignore")


def ldexp_current(m, e):
    return m * dtype(2) ** e

def ldexp1(m, e):
    return (m * dtype(2)) * dtype(2) ** (e - type(e)(1))

def ldexp2(m, e):
    e1 = e // type(e)(2)
    e2 = e - e1
    return (m * (dtype(2) ** e1)) * dtype(2) ** (e2)

def ldexp3(m, e):
    return m * np.exp2(type(m)(e))

def ldexp4(m, e):
    m1, e1 = np.frexp(m)
    return m1 * np.exp2(type(m)(e + e1))

def ldexp5(m, e):
    m1, e1 = np.frexp(m)
    if e + e1 > 15:
        m1 *= type(m)(2)
        e1 -= type(e1)(1)
    return m1 * np.exp2(type(m)(e + e1))

dtype = np.float16

fi = np.finfo(dtype)

min_value = fi.smallest_subnormal
max_value = fi.max

utype = {np.float16: np.uint16, np.float32: np.uint32, np.float64: np.uint64}[dtype]

start, end = min_value.view(utype), max_value.view(utype)
finite_positive = np.array(range(start, end + 1), dtype=utype).view(dtype)

Nmin = -40
Nmax = 39

for func in [ldexp_current, ldexp1, ldexp2, ldexp3, ldexp4, ldexp5]:
    matches = 0
    mismatches = 0
    for e in range(Nmin, Nmax + 1):
        for f in finite_positive:
            expected = np.ldexp(f, e)
            result = func(f, e)
            if expected == result:
                matches += 1
            else:
                mismatches += 1
    print(f'{func.__name__}: {matches=} {mismatches=}')

that computes ldexp for all positive finite float16 values using alternative algorithms. Here's the output:

ldexp_current: matches=986029 mismatches=1553411
ldexp1: matches=994222 mismatches=1545218
ldexp2: matches=986029 mismatches=1553411
ldexp3: matches=2276314 mismatches=263126
ldexp4: matches=2507697 mismatches=31743
ldexp5: matches=2539440 mismatches=0

that is, the following algorithm reproduces numpy.ldexp exactly for float16:

def ldexp5(m, e):
    m1, e1 = np.frexp(m)
    if e + e1 > 15:  # constant 15 corresponds to float16
        m1 *= type(m)(2)
        e1 -= type(e1)(1)
    return m1 * np.exp2(type(m)(e + e1))
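As a standalone sanity check (restating ldexp5 so the snippet runs on its own), this algorithm recovers the cases from earlier in the thread that overflowed:

```python
import numpy as np

def ldexp5(m, e):
    m1, e1 = np.frexp(m)
    if e + e1 > 15:  # constant 15 corresponds to float16
        m1 *= type(m)(2)
        e1 -= type(e1)(1)
    return m1 * np.exp2(type(m)(e + e1))

print(ldexp5(np.float16(0.5), 16))  # 32768.0, no overflow
tiny = np.finfo(np.float16).smallest_subnormal
print(ldexp5(tiny, 39))  # 32768.0, the extreme subnormal case
```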

@jakevdp (Collaborator, Author) commented Apr 16, 2025

Can we do the if e + e1 > 15 correction in all cases to avoid having to use a cond?

@pearu (Collaborator) commented Apr 16, 2025

Can we do the if e + e1 > 15 correction in all cases to avoid having to use a cond?

We can, but at the expense of losing some accuracy. With

def ldexp6(m, e):
    m1, e1 = np.frexp(m)
    m1 *= type(m)(2)
    e1 -= type(e1)(1)
    return m1 * np.exp2(type(m)(e + e1))

the match/mismatch statistics are

ldexp6: matches=2507737 mismatches=31703

where there is only a single pattern of mismatches: m and e such that numpy.ldexp(m, e) returns the smallest subnormal while ldexp6(m, e) returns 0; that is, the maximal error in this case is 1 ULP. In all other cases, ldexp6 and numpy.ldexp are identical.
If documented, this sounds like a reasonable compromise. What do you think?
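For concreteness, the lone mismatch pattern can be reproduced like this (a sketch with inputs of my own choosing; ldexp6 is restated so the snippet is self-contained):

```python
import numpy as np

def ldexp6(m, e):
    m1, e1 = np.frexp(m)
    m1 *= type(m)(2)
    e1 -= type(e1)(1)
    return m1 * np.exp2(type(m)(e + e1))

# True value 1.5 * 2**-25 rounds up to the smallest subnormal, but
# exp2(float16(-25)) underflows to 0, so ldexp6 flushes to zero:
print(np.ldexp(np.float16(1.5), -25))  # 6e-08, the smallest subnormal
print(ldexp6(np.float16(1.5), -25))    # 0.0 -- off by 1 ULP
```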

@pearu (Collaborator) commented Apr 16, 2025

FWIW, the generalization of ldexp5 to other dtypes is

fi = np.finfo(dtype)
e_limit = 2 ** (fi.nexp - 1) - 1

def ldexp5(m, e):
    m1, e1 = np.frexp(m)
    if e + e1 > e_limit:
        m1 *= type(m)(2)
        e1 -= type(e1)(1)
    return m1 * np.exp2(type(m)(e + e1))

and the same comment on ldexp6 above applies to float32 and float64 as well.
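The e_limit values this formula produces line up with each dtype's exponent bias, i.e. the largest unbiased exponent of a finite value (a quick check):

```python
import numpy as np

# fi.nexp is the number of exponent bits; 2**(nexp - 1) - 1 is the
# exponent bias, which bounds the exponent of finite results.
for dtype in (np.float16, np.float32, np.float64):
    fi = np.finfo(dtype)
    e_limit = 2 ** (fi.nexp - 1) - 1
    print(f'{dtype.__name__}: nexp={fi.nexp}, e_limit={e_limit}')
# float16: nexp=5,  e_limit=15
# float32: nexp=8,  e_limit=127
# float64: nexp=11, e_limit=1023
```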

@jakevdp (Collaborator, Author) commented Apr 17, 2025

I tried this; it fails the grad tests because frexp does not have a proper gradient. I'm addressing that in #28106.
