From 49059b3d73c49c2df726b645ba49396e0978afd7 Mon Sep 17 00:00:00 2001 From: Kisung Kang Date: Tue, 3 Sep 2024 14:17:50 +0200 Subject: [PATCH] fix stress in split more --- scripts/utils.py | 51 +++++++++++++++++++++++++++--------------------- 1 file changed, 29 insertions(+), 22 deletions(-) diff --git a/scripts/utils.py b/scripts/utils.py index 231bce9..6c1b5ef 100644 --- a/scripts/utils.py +++ b/scripts/utils.py @@ -324,32 +324,39 @@ def split_son(num_split, E_gs, harmonic_F=False): # Save all information into data-test.npz npz_name = 'MODEL/data-test.npz' - np.savez( - npz_name[:-4], - E=np.array(E_test), - F=np.array(F_test), - R=np.array(R_test), - z=np.array(z_test), - CELL=np.array(CELL_test), - PBC=np.array(PBC_test), - sigma=np.array(sigma_test), - stress=np.array(stress_test) - ) + arrays_to_save = { + 'E': np.array(E_test), + 'F': np.array(F_test), + 'R': np.array(R_test), + 'z': np.array(z_test), + 'CELL': np.array(CELL_test), + 'PBC': np.array(PBC_test), + 'sigma': np.array(sigma_test), + } + + if 'stress' in test_item['calculator']: + arrays_to_save['stress'] = np.array(stress_test) + + np.savez(npz_name[:-4], **arrays_to_save) if harmonic_F: # Save all information into data-test.npz npz_name = 'MODEL/data-test_ori.npz' - np.savez( - npz_name[:-4], - E=np.array(E_test_ori), - F=np.array(F_test_ori), - R=np.array(R_test), - z=np.array(z_test), - CELL=np.array(CELL_test), - PBC=np.array(PBC_test), - sigma=np.array(sigma_test), - stress=np.array(stress_test) - ) + + arrays_to_save_ori = { + 'E': np.array(E_test_ori), + 'F': np.array(F_test_ori), + 'R': np.array(R_test), + 'z': np.array(z_test), + 'CELL': np.array(CELL_test), + 'PBC': np.array(PBC_test), + 'sigma': np.array(sigma_test), + } + + if 'stress' in test_item['calculator']: + arrays_to_save_ori['stress'] = np.array(stress_test) + + np.savez(npz_name[:-4], **arrays_to_save_ori) single_print('[split_son]\tFinish the sampling testing data: data-train.npz')