diff --git a/tensortrade/instruments/quantity.py b/tensortrade/instruments/quantity.py index 101614e80..14634280b 100644 --- a/tensortrade/instruments/quantity.py +++ b/tensortrade/instruments/quantity.py @@ -107,7 +107,6 @@ def _bool_operation(left: Union['Quantity', float, int], right: Union['Quantity', float, int], bool_op: operator) -> bool: left, right = Quantity.validate(left, right) - boolean = bool_op(left.size, right.size) if not isinstance(boolean, bool): @@ -120,7 +119,6 @@ def _math_operation(left: Union['Quantity', float, int], right: Union['Quantity', float, int], op: operator) -> 'Quantity': left, right = Quantity.validate(left, right) - size = op(left._size, right._size) return Quantity(left.instrument, size, left.path_id) diff --git a/tensortrade/wallets/portfolio.py b/tensortrade/wallets/portfolio.py index a5e8db7d0..cc017cb73 100644 --- a/tensortrade/wallets/portfolio.py +++ b/tensortrade/wallets/portfolio.py @@ -226,7 +226,7 @@ def on_next(self, data: dict): performance_data = {k: data[k] for k in self._keys} performance_data['base_symbol'] = self.base_instrument.symbol performance_step = pd.DataFrame(performance_data, index=index) - + net_worth = data['net_worth'] if self._performance is None: @@ -245,3 +245,6 @@ def reset(self): self._initial_net_worth = None self._net_worth = None self._performance = None + + for wallet in self._wallets.values(): + wallet.reset() diff --git a/tensortrade/wallets/wallet.py b/tensortrade/wallets/wallet.py index ea4a7f651..5c6a4c064 100644 --- a/tensortrade/wallets/wallet.py +++ b/tensortrade/wallets/wallet.py @@ -24,6 +24,7 @@ class Wallet(Identifiable): def __init__(self, exchange: 'Exchange', quantity: 'Quantity'): self._exchange = exchange + self._initial_size = quantity.size self._instrument = quantity.instrument self._balance = quantity self._locked = {} @@ -89,6 +90,10 @@ def deallocate(self, path_id: str): if quantity is not None: self += quantity.size * self.instrument + def reset(self): + self._balance = Quantity(self._instrument, self._initial_size) + self._locked = {} + def __iadd__(self, quantity: 'Quantity') -> 'Wallet': if quantity.is_locked: if quantity.path_id not in self.locked.keys():