Skip to content

Commit

Permalink
Refactor nbs according to dict io
Browse files Browse the repository at this point in the history
  • Loading branch information
GiovanniBordiga committed Jul 24, 2024
1 parent e76668b commit 9d51b7b
Show file tree
Hide file tree
Showing 10 changed files with 186 additions and 287 deletions.
20 changes: 14 additions & 6 deletions notebooks/kagome_focusing_3dp_pla_shims.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,17 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
]
}
],
"source": [
"import jax.numpy as jnp\n",
"from jax.config import config\n",
Expand All @@ -28,7 +36,7 @@
"from problems.kagome_focusing import ForwardProblem, OptimizationProblem\n",
"from difflexmm.geometry import compute_inertia\n",
"from difflexmm.plotting import generate_animation, plot_geometry\n",
"from difflexmm.utils import save_data, load_data\n",
"from difflexmm.utils import save_data, load_data, SolutionData\n",
"from pathlib import Path\n",
"from typing import Optional, Any\n",
"import matplotlib\n",
Expand Down Expand Up @@ -210,7 +218,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -304,7 +312,7 @@
"metadata": {},
"outputs": [],
"source": [
"optimization = OptimizationProblem.from_data(\n",
"optimization = OptimizationProblem.from_dict(\n",
" load_data(\n",
" f\"../data/{optimization.name}/{optimization_filename}.pkl\",\n",
" )\n",
Expand Down Expand Up @@ -338,7 +346,7 @@
"\n",
"save_data(\n",
" f\"../data/{optimization.name}/{optimization_filename}.pkl\",\n",
" optimization.to_data() # Optimization problem\n",
" optimization.to_dict() # Optimization problem\n",
")\n"
]
},
Expand Down
52 changes: 28 additions & 24 deletions notebooks/quads_energy_splitting_3dp_pla_shims.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -17,29 +17,33 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
]
}
],
"source": [
"import os\n",
"os.environ[\"XLA_FLAGS\"] = \"--xla_force_host_platform_device_count=8\" # Use 8 CPU cores for JAX pmap \n",
"\n",
"from difflexmm.utils import SolutionType, SolutionData, EigenmodeData, save_data, load_data, ControlParams, GeometricalParams, MechanicalParams, LigamentParams, ContactParams\n",
"from difflexmm.geometry import QuadGeometry, compute_inertia, rotation_matrix, compute_edge_angles, compute_edge_lengths, compute_xy_limits\n",
"from difflexmm.energy import strain_energy_bond, build_strain_energy, kinetic_energy, ligament_energy, ligament_energy_linearized, build_contact_energy, combine_block_energies, ligament_strains\n",
"from difflexmm.utils import save_data, load_data\n",
"from difflexmm.geometry import QuadGeometry, compute_inertia, compute_xy_limits\n",
"from difflexmm.energy import ligament_strains\n",
"from difflexmm.kinematics import block_to_node_kinematics\n",
"from difflexmm.dynamics import setup_dynamic_solver\n",
"from difflexmm.plotting import generate_animation, generate_frames, plot_geometry, generate_patch_collection\n",
"from difflexmm.plotting import generate_animation, plot_geometry\n",
"from problems.quads_focusing import ForwardProblem\n",
"from problems.quads_energy_splitting import OptimizationProblem\n",
"import matplotlib\n",
"import matplotlib.pyplot as plt\n",
"from matplotlib import animation\n",
"from matplotlib import cm\n",
"from pathlib import Path\n",
"from typing import NamedTuple, Any, Optional, List, Union, Tuple, Dict\n",
"import dataclasses\n",
"from dataclasses import dataclass\n",
"from typing import Optional, List\n",
"\n",
"import jax.numpy as jnp\n",
"from jax.config import config\n",
Expand All @@ -60,7 +64,7 @@
},
{
"cell_type": "code",
"execution_count": 120,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -572,7 +576,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -665,11 +669,11 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"optimization = OptimizationProblem.from_data(\n",
"optimization = OptimizationProblem.from_dict(\n",
" load_data(\n",
" f\"../data/{optimization.name}/{optimization_filename}.pkl\",\n",
" )\n",
Expand Down Expand Up @@ -704,7 +708,7 @@
"\n",
"save_data(\n",
" f\"../data/{optimization.name}/{optimization_filename}.pkl\",\n",
" optimization.to_data() # Optimization problem\n",
" optimization.to_dict() # Optimization problem\n",
")\n"
]
},
Expand All @@ -722,7 +726,7 @@
"metadata": {},
"outputs": [],
"source": [
"optimization = OptimizationProblem.from_data(\n",
"optimization = OptimizationProblem.from_dict(\n",
" load_data(\n",
" f\"../data/{optimization.name}/{optimization_filename}.pkl\",\n",
" )\n",
Expand Down Expand Up @@ -791,7 +795,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -807,14 +811,14 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Import optimization data files from pareto folder\n",
"pareto_objectives_data = []\n",
"for path in Path(f\"../data/{optimization.name}/pareto/{pareto_folder}/\").glob(\"*.pkl\"):\n",
" optimization = OptimizationProblem.from_data(load_data(path))\n",
" optimization = OptimizationProblem.from_dict(load_data(path))\n",
" pareto_objectives_data.append(\n",
" jnp.array(optimization.objective_values_individual) /\n",
" optimization.forward_problem.n_timepoints\n",
Expand All @@ -823,13 +827,13 @@
},
{
"cell_type": "code",
"execution_count": 125,
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e8f87edf69bf44f3a201426cfb7dae2f",
"model_id": "49f797cd2c114c81bbd0421aaceb2911",
"version_major": 2,
"version_minor": 0
},
Expand Down Expand Up @@ -907,7 +911,7 @@
],
"source": [
"pareto_data_filename = f\"paretoSample_weights_0.599_0.401_iniAngle_35.0\"\n",
"optimization = OptimizationProblem.from_data(\n",
"optimization = OptimizationProblem.from_dict(\n",
" load_data(\n",
" f\"../data/{optimization.name}/pareto/{pareto_folder}/{pareto_data_filename}.pkl\",\n",
" )\n",
Expand Down Expand Up @@ -1147,7 +1151,7 @@
],
"source": [
"pareto_data_filename = \"paretoSample_weights_0.300_0.700_iniAngle_15.0\"\n",
"optimization = OptimizationProblem.from_data(\n",
"optimization = OptimizationProblem.from_dict(\n",
" load_data(\n",
" f\"../data/{optimization.name}/pareto/{pareto_folder}/{pareto_data_filename}.pkl\",\n",
" )\n",
Expand Down
143 changes: 23 additions & 120 deletions notebooks/quads_focusing_3dp_pla_shims.ipynb

Large diffs are not rendered by default.

47 changes: 26 additions & 21 deletions notebooks/quads_focusing_3dp_pla_shims_random_initial_guess.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,29 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
]
}
],
"source": [
"from difflexmm.utils import SolutionType, SolutionData, EigenmodeData, save_data, load_data, ControlParams, GeometricalParams, MechanicalParams, LigamentParams, ContactParams\n",
"from difflexmm.geometry import QuadGeometry, compute_inertia, rotation_matrix, compute_edge_angles, compute_edge_lengths\n",
"from difflexmm.energy import strain_energy_bond, build_strain_energy, kinetic_energy, ligament_energy, ligament_energy_linearized, build_contact_energy, combine_block_energies\n",
"from difflexmm.dynamics import setup_dynamic_solver\n",
"from difflexmm.plotting import generate_animation, generate_frames, plot_geometry, generate_patch_collection\n",
"from difflexmm.utils import SolutionData, save_data, load_data\n",
"from difflexmm.geometry import QuadGeometry, compute_inertia\n",
"from difflexmm.energy import kinetic_energy\n",
"from difflexmm.plotting import generate_animation, plot_geometry\n",
"from problems.quads_focusing import ForwardProblem, OptimizationProblem\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib\n",
"from matplotlib.colors import to_rgba\n",
"from matplotlib import patches\n",
"from pathlib import Path\n",
"from typing import NamedTuple, Any, Optional, List, Union, Tuple, Dict\n",
"import dataclasses\n",
"from dataclasses import dataclass\n",
"from typing import Any, Optional\n",
"\n",
"import jax.numpy as jnp\n",
"from jax import random\n",
Expand Down Expand Up @@ -550,7 +555,7 @@
"metadata": {},
"outputs": [],
"source": [
"optimization = OptimizationProblem.from_data(\n",
"optimization = OptimizationProblem.from_dict(\n",
" load_data(\n",
" f\"../data/{optimization.name}/{optimization_filename}.pkl\",\n",
" )\n",
Expand Down Expand Up @@ -587,7 +592,7 @@
"\n",
"save_data(\n",
" f\"../data/{optimization.name}/{optimization_filename}.pkl\",\n",
" optimization.to_data() # Optimization problem\n",
" optimization.to_dict() # Optimization problem\n",
")"
]
},
Expand Down Expand Up @@ -942,13 +947,13 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Regular initial design\n",
"optimization_filename_regular = \"opt_with_angle_30_and_length_3_constraints_quads_24x16_excited_blocks_2_amplitude_7.50_loading_rate_30.00_input_shift_0_initial_angle_25.0_target_size_2x2_target_shift_4x5\"\n",
"optimization_regular = OptimizationProblem.from_data(\n",
"optimization_regular = OptimizationProblem.from_dict(\n",
" load_data(\n",
" f\"../data/quads_focusing_3dp_pla_shims/{optimization_filename_regular}.pkl\",\n",
" )\n",
Expand All @@ -959,12 +964,12 @@
"noise_amplitude_sweep = [0.1, 0.15, 0.2]\n",
"optimizations_random_initial_design = [\n",
" [\n",
" OptimizationProblem.from_data(\n",
" load_data(\n",
" f\"../data/quads_focusing_3dp_pla_shims/opt_with_angle_30_and_length_3_constraints_quads_24x16_excited_blocks_2_amplitude_7.50_loading_rate_30.00_input_shift_0_random_initial_design_noise_{noise_amplitude:.2f}_keys_{key}_{key+1}_target_size_2x2_target_shift_4x5.pkl\",\n",
" OptimizationProblem.from_dict(\n",
" load_data(\n",
" f\"../data/quads_focusing_3dp_pla_shims/opt_with_angle_30_and_length_3_constraints_quads_24x16_excited_blocks_2_amplitude_7.50_loading_rate_30.00_input_shift_0_random_initial_design_noise_{noise_amplitude:.2f}_keys_{key}_{key+1}_target_size_2x2_target_shift_4x5.pkl\",\n",
" )\n",
" )\n",
" )\n",
" for key in keys_sweep\n",
" for key in keys_sweep\n",
" ]\n",
" for noise_amplitude in noise_amplitude_sweep\n",
"]"
Expand All @@ -979,13 +984,13 @@
},
{
"cell_type": "code",
"execution_count": 50,
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "96f9734da6e342229561f4993c98ba48",
"model_id": "d137249b426c4cc3931c03cea0ec0075",
"version_major": 2,
"version_minor": 0
},
Expand Down
15 changes: 6 additions & 9 deletions notebooks/quads_focusing_3dp_pla_shims_restricted_space.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,17 @@
"metadata": {},
"outputs": [],
"source": [
"from difflexmm.utils import SolutionType, SolutionData, EigenmodeData, save_data, load_data, ControlParams, GeometricalParams, MechanicalParams, LigamentParams, ContactParams\n",
"from difflexmm.geometry import QuadGeometry, compute_inertia, rotation_matrix, compute_edge_angles, compute_edge_lengths\n",
"from difflexmm.energy import strain_energy_bond, build_strain_energy, kinetic_energy, ligament_energy, ligament_energy_linearized, build_contact_energy, combine_block_energies\n",
"from difflexmm.dynamics import setup_dynamic_solver\n",
"from difflexmm.plotting import generate_animation, generate_frames, plot_geometry, generate_patch_collection\n",
"from difflexmm.utils import SolutionData, save_data, load_data\n",
"from difflexmm.geometry import QuadGeometry, compute_inertia\n",
"from difflexmm.energy import kinetic_energy\n",
"from difflexmm.plotting import generate_animation, plot_geometry\n",
"from problems.quads_focusing_restricted_space import ForwardProblem, OptimizationProblem\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib\n",
"from matplotlib.colors import to_rgba\n",
"from matplotlib import patches\n",
"from pathlib import Path\n",
"from typing import NamedTuple, Any, Optional, List, Union, Tuple, Dict\n",
"import dataclasses\n",
"from dataclasses import dataclass\n",
"from typing import Any, Optional\n",
"\n",
"import jax.numpy as jnp\n",
"from jax.config import config\n",
Expand Down Expand Up @@ -546,7 +543,7 @@
"metadata": {},
"outputs": [],
"source": [
"optimization = OptimizationProblem.from_data(\n",
"optimization = OptimizationProblem.from_dict(\n",
" load_data(\n",
" f\"../data/{optimization.name}/{optimization_filename}.pkl\",\n",
" )\n",
Expand Down
Loading

0 comments on commit 9d51b7b

Please sign in to comment.