diff --git a/examples/fitting_a_hh_neuron/main.py b/examples/fitting_a_hh_neuron/main.py index 799b65f..220723f 100644 --- a/examples/fitting_a_hh_neuron/main.py +++ b/examples/fitting_a_hh_neuron/main.py @@ -19,12 +19,13 @@ import brainstate as bst import braintools as bts import brainunit as u -import dendritex as dx import jax import matplotlib.pyplot as plt import numpy as np import pandas as pd +import dendritex as dx + bst.environ.set(dt=0.01 * u.ms) # Load Input and Output Data @@ -220,13 +221,13 @@ def visualize_hh_input_and_output(): bounds = { 'gl': [1e0, 1e2] * u.nS, - 'g_na': [1e1, 2e2] * u.uS, - 'g_kd': [1e1, 1e2] * u.uS, + 'g_na': [1e0, 2e2] * u.uS, + 'g_kd': [1e0, 1e2] * u.uS, 'C': [0.1, 2] * u.uF * u.cm ** -2 * area, } -def fitting_by_others(method='DE', n_sample=200): +def fitting_by_others(method='DE', n_sample=200, n_iter=20): print(f"Method: {method}, n_sample: {n_sample}") @jax.jit @@ -242,7 +243,7 @@ def loss_with_multiple_run(**params): method=method, ) opt.initialize() - param = opt.minimize(10) + param = opt.minimize(n_iter) loss = compare_potentials(param) print(param) print(loss)