Skip to content

Commit

Permalink
dev: add schwefel function
Browse files Browse the repository at this point in the history
  • Loading branch information
BillHuang2001 committed Dec 19, 2023
1 parent b8e917f commit 07848cc
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 14 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
========
Schwefel
========

.. autoclass:: evox.problems.numerical.Schwefel
:members:
1 change: 1 addition & 0 deletions src/evox/problems/numerical/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from .zdt import ZDT1, ZDT2, ZDT3, ZDT4, ZDT6
from .griewank import Griewank, griewank_func
from .rosenbrock import Rosenbrock, rosenbrock_func
from .schwefel import Schwefel, schwefel_func
from .dtlz import DTLZ1, DTLZ2, DTLZ3, DTLZ4, DTLZ5, DTLZ6, DTLZ7
from .lsmop import LSMOP1, LSMOP2, LSMOP3, LSMOP4, LSMOP5, LSMOP6, LSMOP7, LSMOP8, LSMOP9
from .maf import MaF1, MaF2, MaF3, MaF4, MaF5, MaF6, MaF7, MaF8, MaF9, MaF10, MaF11, MaF12, MaF13, MaF14, MaF15
Expand Down
23 changes: 23 additions & 0 deletions src/evox/problems/numerical/schwefel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import jax
from jax import jit
import jax.numpy as jnp
from evox import problem, jit_class


@jit
def schwefel_func(x):
_pop_size, dim = x.shape
return 418.9828872724338 * dim - jnp.sum(x * jnp.sin(jnp.sqrt(jnp.abs(x))))


@jit_class
class Schwefel(problem):
"""The Schwefel function
The minimum is x = [420.9687462275036, ...]
"""

def __init__(self):
super().__init__()

def evaluate(self, state, x):
return schwefel_func(x), state
31 changes: 17 additions & 14 deletions tests/test_classic_problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
def test_ackley():
ackley = problems.numerical.Ackley()
key = jax.random.PRNGKey(12345)
keys = jax.random.split(key, 16)
state = ackley.init(keys)
state = ackley.init(key)
X = jnp.zeros((16, 2))
F, state = ackley.evaluate(state, X)
chex.assert_trees_all_close(F, jnp.zeros((16,)), atol=1e-6)
Expand All @@ -18,8 +17,7 @@ def test_ackley():
def test_griewank():
griewank = problems.numerical.Griewank()
key = jax.random.PRNGKey(12345)
keys = jax.random.split(key, 2)
state = griewank.init(keys)
state = griewank.init(key)
X = jnp.zeros((16, 2))
F, state = griewank.evaluate(state, X)
chex.assert_trees_all_close(F, jnp.zeros((16,)), atol=1e-6)
Expand All @@ -28,8 +26,7 @@ def test_griewank():
def test_rastrigin():
rastrigin = problems.numerical.Rastrigin()
key = jax.random.PRNGKey(12345)
keys = jax.random.split(key, 16)
state = rastrigin.init(keys)
state = rastrigin.init(key)
X = jnp.zeros((16, 2))
F, state = rastrigin.evaluate(state, X)
chex.assert_trees_all_close(F, jnp.zeros((16,)), atol=1e-6)
Expand All @@ -38,21 +35,27 @@ def test_rastrigin():
def test_rosenbrock():
rosenbrock = problems.numerical.Rosenbrock()
key = jax.random.PRNGKey(12345)
keys = jax.random.split(key, 16)
state = rosenbrock.init(keys)
state = rosenbrock.init(key)
X = jnp.ones((16, 2))
F, state = rosenbrock.evaluate(state, X)
chex.assert_trees_all_close(F, jnp.zeros((16, )), atol=1e-6)
chex.assert_trees_all_close(F, jnp.zeros((16,)), atol=1e-6)


def test_schwefel():
schwefel = problems.numerical.Schwefel()
key = jax.random.PRNGKey(12345)
state = rosenbrock.init(key)
X = jnp.array([[420.9687462275036, 420.9687462275036]])
F, state = rosenbrock.evaluate(state, X)
chex.assert_trees_all_close(F, jnp.zeros((1,)), atol=1e-6)


def test_dtlz1():
dtlz1 = problems.numerical.DTLZ1(d=None, m=4)
key = jax.random.PRNGKey(12345)
keys = jax.random.split(key, 16)
state = dtlz1.init(keys)
X = jnp.ones((16, 7))*0.5
state = dtlz1.init(key)
X = jnp.ones((16, 7)) * 0.5
F, state = dtlz1.evaluate(state, X)
pf, state = dtlz1.pf(state)
print(pf.shape)
chex.assert_trees_all_close(
jnp.sum(F, axis=1), 0.5*jnp.ones((16, )), atol=1e-6)
chex.assert_trees_all_close(jnp.sum(F, axis=1), 0.5 * jnp.ones((16,)), atol=1e-6)

0 comments on commit 07848cc

Please sign in to comment.