jnp.ldexp overflows when it shouldn't #28040
Comments
Since multiplication with a power of 2 is exact then replacing the naive implementation with
Here's another example case to watch out for while fixing this issue:
>>> np.ldexp(np.float16(0.001), 23)
8390.0
>>> jnp.ldexp(jnp.float16(0.001), 23)
Array(inf, dtype=float16)
I like this idea! I did a bit of an experiment to see if we could do this via bitwise ops (i.e. cast
Here's the most extreme case we'd have to handle:
import numpy as np
dtype = np.dtype('float16')
x = np.finfo(dtype).smallest_subnormal  # x = 6e-08
N = int(np.log2(np.finfo(dtype).max) - np.log2(x))  # N = 40
print(np.ldexp(x, N - 1))  # np.float16(32770.0)
For float16, the extreme values of N that still result in finite values from ldexp are -40 and 39. Consider the following script
import numpy as np
import warnings

warnings.filterwarnings("ignore")

def ldexp_current(m, e):
    # naive implementation: forms 2 ** e in the target dtype first
    return m * dtype(2) ** e

def ldexp1(m, e):
    # peel off a single factor of 2 before scaling
    return (m * dtype(2)) * dtype(2) ** (e - type(e)(1))

def ldexp2(m, e):
    # split the exponent into two halves
    e1 = e // type(e)(2)
    e2 = e - e1
    return (m * (dtype(2) ** e1)) * dtype(2) ** (e2)

def ldexp3(m, e):
    return m * np.exp2(type(m)(e))

def ldexp4(m, e):
    m1, e1 = np.frexp(m)
    return m1 * np.exp2(type(m)(e + e1))

def ldexp5(m, e):
    m1, e1 = np.frexp(m)
    if e + e1 > 15:
        m1 *= type(m)(2)
        e1 -= type(e1)(1)
    return m1 * np.exp2(type(m)(e + e1))

dtype = np.float16
fi = np.finfo(dtype)
min_value = fi.smallest_subnormal
max_value = fi.max
utype = {np.float16: np.uint16, np.float32: np.uint32, np.float64: np.uint64}[dtype]
# enumerate every positive finite value through its bit pattern
start, end = min_value.view(utype), max_value.view(utype)
finite_positive = np.array(range(start, end + 1), dtype=utype).view(dtype)
Nmin = -40
Nmax = 39

for func in [ldexp_current, ldexp1, ldexp2, ldexp3, ldexp4, ldexp5]:
    matches = 0
    mismatches = 0
    for e in range(Nmin, Nmax + 1):
        for f in finite_positive:
            expected = np.ldexp(f, e)
            result = func(f, e)
            if expected == result:
                matches += 1
            else:
                mismatches += 1
    print(f'{func.__name__}: {matches=} {mismatches=}')
that computes ldexp for all positive finite float16 values using alternative algorithms. Here's the output:
that is, the following algorithm reproduces
def ldexp5(m, e):
    m1, e1 = np.frexp(m)
    if e + e1 > 15:  # constant 15 corresponds to float16
        m1 *= type(m)(2)
        e1 -= type(e1)(1)
    return m1 * np.exp2(type(m)(e + e1))
Can we do the
We can, but at the expense of losing some accuracy. With
def ldexp6(m, e):
    m1, e1 = np.frexp(m)
    m1 *= type(m)(2)
    e1 -= type(e1)(1)
    return m1 * np.exp2(type(m)(e + e1))
the matching/mismatching statistics are
where there is only a single pattern of mismatches:
FWIW, the generalization of ldexp5 to other dtypes is
fi = np.finfo(dtype)
e_limit = 2 ** (fi.nexp - 1) - 1
def ldexp5(m, e):
    m1, e1 = np.frexp(m)
    if e + e1 > e_limit:
        m1 *= type(m)(2)
        e1 -= type(e1)(1)
    return m1 * np.exp2(type(m)(e + e1))
and the same comment on ldexp6 above applies to float32 and float64 as well.
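As a rough sanity check of the generalized formula, here is a sketch that ports ldexp5 to plain Python floats (float64) using only the standard library. The constant NEXP = 11 and the helper exp2 are assumptions made for this illustration: exp2 stands in for np.exp2 and maps Python's OverflowError to inf, which is what IEEE arithmetic would return. The sketch only targets nonzero finite m; zero, inf and nan inputs would need separate handling.

```python
import math

NEXP = 11                        # exponent bits of float64 (what np.finfo reports as nexp)
E_LIMIT = 2 ** (NEXP - 1) - 1    # 1023, the float64 counterpart of the constant 15

def exp2(k):
    """Hypothetical stand-in for np.exp2: 2.0 ** k, with overflow mapped to inf."""
    try:
        return 2.0 ** k
    except OverflowError:
        return math.inf

def ldexp5(m, e):
    m1, e1 = math.frexp(m)       # m == m1 * 2 ** e1 with 0.5 <= |m1| < 1
    if e + e1 > E_LIMIT:
        m1 *= 2.0                # exact: moves one factor of 2 into the mantissa
        e1 -= 1
    return m1 * exp2(e + e1)

tiny = 5e-324                    # smallest positive subnormal float64, 2 ** -1074
print(ldexp5(tiny, 2097) == math.ldexp(tiny, 2097))  # True: result is 2 ** 1023
print(ldexp5(1.0, -1074) == tiny)                    # True
print(ldexp5(1.0, 1024))                             # inf, a genuine overflow
```

The first call is the float64 analogue of the extreme case discussed for float16: without the e_limit branch, exp2 would be asked for 2 ** 1024 and return inf.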
I tried this, it fails grad tests because
For example:
This is because jnp.ldexp has a naive implementation which is essentially return x * 2 ** n, which overflows whenever n is too large for 2 ** n to be representable, even when the final product would be finite.
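The overflow mechanism can be reproduced without NumPy or JAX. The sketch below emulates float16 rounding with the standard library's struct module (format character "e") and replays the float16(0.001), 23 case from the comments. The helper names to_f16, naive_ldexp_f16 and exact_ldexp_f16 are made up for this illustration and are not part of any library; the emulation assumes n is small enough for 2.0 ** n to be formed as a Python float.

```python
import math
import struct

def to_f16(v):
    """Round a Python float to the nearest float16 value; overflow becomes inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", v))[0]
    except (OverflowError, struct.error):
        return math.copysign(math.inf, v)

def naive_ldexp_f16(x, n):
    """Emulate x * 2 ** n with the intermediate power rounded to float16."""
    scale = to_f16(2.0 ** n)     # 2 ** 23 exceeds the float16 maximum 65504 -> inf
    return to_f16(to_f16(x) * scale)

def exact_ldexp_f16(x, n):
    """Scale exactly in float64, then round once to float16."""
    return to_f16(math.ldexp(to_f16(x), n))

print(naive_ldexp_f16(0.001, 23))  # inf
print(exact_ldexp_f16(0.001, 23))  # 8392.0
```

The exact result 8392.0 is the float16 value that NumPy displays as 8390.0 in the example from the comments; the naive path returns inf only because the intermediate power of two is out of range.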