[tet.py] Add max batch size limit for ap state.
When observing, wss are calculated in a single batch, but sometimes memory is
not enough, so a max batch size limit should be provided.
hzhangxyz committed Oct 8, 2022
1 parent 5919ea2 commit a20c8f4
Showing 3 changed files with 32 additions and 2 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.org
@@ -6,6 +6,8 @@
+ *tetragono*: Add =ap_hamiltonian= to replace the hamiltonian of the ansatz product state in tetragono shell.
+ *tetragono*: Add =multichain_number= for =ap_run=, which will run multiple chains inside the same MPI process.
+ *wrapper*: Add python package =wrapper_TAT=, a wrapper over torch that provides an interface similar to =TAT.py=.
+ *tetragono*: Add =observe_max_batch_size= option for =ap_run=, which sets the maximum batch size used when
  calculating wss.
*** Changed
*** Deprecated
+ *tetragono*: =gm_data_load= is deprecated and will be removed in the future; please use =gm_hamiltonian= to replace
4 changes: 4 additions & 0 deletions tetragono/tetragono/ansatz_product_state/gradient.py
@@ -132,6 +132,8 @@ def gradient_descent(
sweep_hopping_hamiltonians=None,
# About subspace
restrict_subspace=None,
# About observe
observe_max_batch_size=None,
# About gradient method
use_check_difference=False,
use_line_search=False,
@@ -196,6 +198,7 @@ def restrict(configuration, replacement=None):
enable_natural_gradient=use_natural_gradient,
cache_natural_delta=cache_natural_delta,
restrict_subspace=restrict,
max_batch_size=observe_max_batch_size,
)
if measurement:
measurement_names = measurement.split(",")
@@ -211,6 +214,7 @@ def restrict(configuration, replacement=None):
enable_energy=True,
enable_gradient=use_line_search,
restrict_subspace=restrict,
max_batch_size=observe_max_batch_size,
)

# Main loop
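The gradient.py change is a pure pass-through: gradient_descent gains an observe_max_batch_size keyword (defaulting to None) and forwards it to both Observer constructions as max_batch_size. Below is a minimal self-contained sketch of that forwarding pattern; the Toy* names are illustrative stand-ins, not tetragono's real API.

# Illustrative pass-through, mirroring how gradient.py forwards
# observe_max_batch_size into Observer(max_batch_size=...).
# ToyObserver and toy_gradient_descent are hypothetical stand-ins.
class ToyObserver:

    def __init__(self, max_batch_size=None):
        # None means "no limit": wss are evaluated in one batch.
        self._max_batch_size = max_batch_size

def toy_gradient_descent(observe_max_batch_size=None):
    # The new keyword defaults to None, so existing callers keep the old behavior.
    return ToyObserver(max_batch_size=observe_max_batch_size)

assert toy_gradient_descent()._max_batch_size is None
assert toy_gradient_descent(observe_max_batch_size=64)._max_batch_size == 64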
28 changes: 26 additions & 2 deletions tetragono/tetragono/ansatz_product_state/observer.py
@@ -32,7 +32,7 @@ class Observer:
"owner", "_observer", "_enable_gradient", "_enable_natural", "_cache_natural_delta", "_restrict_subspace",
"_start", "_result", "_result_square", "_result_reweight", "_count", "_total_weight", "_total_weight_square",
"_total_log_ws", "_total_energy", "_total_energy_square", "_total_energy_reweight", "_Delta", "_EDelta",
"_Deltas"
"_Deltas", "_max_batch_size"
]

def __enter__(self):
@@ -120,6 +120,7 @@ def __init__(
enable_natural_gradient=False,
cache_natural_delta=None,
restrict_subspace=None,
max_batch_size=None,
):
"""
Create observer object for the given ansatz product state.
@@ -140,6 +141,8 @@
The folder name to cache deltas used in natural gradient.
restrict_subspace : optional
A function return bool to restrict sampling subspace.
max_batch_size : int, optional
The max batch size when calculating wss.
"""
self.owner = owner

@@ -166,6 +169,8 @@
self._EDelta = None
self._Deltas = None

self._max_batch_size = None

if observer_set is not None:
self._observer = observer_set
if enable_energy:
@@ -176,6 +181,18 @@
self.enable_natural_gradient()
self.cache_natural_delta(cache_natural_delta)
self.restrict_subspace(restrict_subspace)
self.set_max_batch_size(max_batch_size)

def set_max_batch_size(self, max_batch_size):
"""
Set the max batch size when calculating wss.

Parameters
----------
max_batch_size : int, optional
The max batch size when calculating wss.
"""
self._max_batch_size = max_batch_size

def restrict_subspace(self, restrict_subspace):
"""
@@ -303,7 +320,14 @@ def __call__(self, sampling_result):
observer_shrinked)
configuration_list.append(new_configuration)
# measure
wss_list, _ = self.owner.ansatz.weight_and_delta(configuration_list, False)
if self._max_batch_size is None:
wss_list, _ = self.owner.ansatz.weight_and_delta(configuration_list, False)
else:
wss_list = []
for i in range((len(configuration_list) - 1) // self._max_batch_size + 1):
wss_list_this, _ = self.owner.ansatz.weight_and_delta(
configuration_list[i * self._max_batch_size:(i + 1) * self._max_batch_size], False)
wss_list += wss_list_this
for chain, reweight in enumerate(reweights):
for name, configuration_data_name in configuration_data[chain].items():
if name == "energy":
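The batching added to __call__ above is plain ceiling-division chunking: for n configurations and batch size b, (n - 1) // b + 1 slices of at most b items each cover the whole list, so peak memory scales with b rather than n. A self-contained sketch of the same loop follows, using a stand-in for ansatz.weight_and_delta (hypothetical; the real method belongs to the ansatz object).

# Stand-in for ansatz.weight_and_delta: returns one weight per configuration
# and (here) no deltas, since only weights matter for this illustration.
def fake_weight_and_delta(configurations, calculate_delta):
    return [1.0 for _ in configurations], None

def batched_weights(configuration_list, max_batch_size):
    if max_batch_size is None:
        # No limit: evaluate everything in a single batch, as before the commit.
        wss_list, _ = fake_weight_and_delta(configuration_list, False)
        return wss_list
    wss_list = []
    # Ceiling division: (n - 1) // b + 1 batches of at most b items cover n items.
    for i in range((len(configuration_list) - 1) // max_batch_size + 1):
        batch = configuration_list[i * max_batch_size:(i + 1) * max_batch_size]
        wss_list_this, _ = fake_weight_and_delta(batch, False)
        wss_list += wss_list_this
    return wss_list

print(batched_weights(list(range(10)), 3))  # batches of 3, 3, 3, 1 -> ten weights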
