Skip to content

Commit

Permalink
txbatcher: use lock, rename private methods, add type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
ecdsa committed Feb 22, 2025
1 parent cebc5af commit 9f589b0
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 27 deletions.
52 changes: 27 additions & 25 deletions electrum/txbatcher.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import asyncio
import threading
import copy

from typing import Dict, Sequence
from . import util
from .logging import Logger
from .util import log_exceptions
Expand Down Expand Up @@ -74,6 +76,8 @@ def __init__(self, wallet):
self.wallet = wallet
self.config = wallet.config
self.db = wallet.db
self.lock = threading.Lock()
# fixme: not robust to client restart, because we do not persist batch_payments
self.batch_payments = [] # list of payments we need to make
self.batch_inputs = {} # list of inputs we need to sweep

Expand All @@ -96,7 +100,8 @@ def __init__(self, wallet):

def add_batch_payment(self, output: 'PartialTxOutput'):
# todo: maybe we should raise NotEnoughFunds here
self.batch_payments.append(output)
with self.lock:
self.batch_payments.append(output)

def add_sweep_info(self, sweep_info: 'SweepInfo'):
txin = sweep_info.txin
Expand All @@ -123,7 +128,6 @@ def add_sweep_info(self, sweep_info: 'SweepInfo'):
base_txin.witness_script = txin.witness_script
base_txin.script_sig = txin.script_sig


def get_base_tx(self) -> Optional[Transaction]:
if self._base_tx:
return self._base_tx
Expand All @@ -140,7 +144,7 @@ def get_base_tx(self) -> Optional[Transaction]:
base_tx.add_info_from_wallet(self.wallet) # needed for txid
return base_tx

def find_confirmed_base_tx(self) -> Optional[Transaction]:
def _find_confirmed_base_tx(self) -> Optional[Transaction]:
for txid in self.batch_txids:
tx_mined_status = self.wallet.adb.get_tx_height(txid)
if tx_mined_status.conf > 0:
Expand All @@ -149,8 +153,7 @@ def find_confirmed_base_tx(self) -> Optional[Transaction]:
tx.add_info_from_wallet(self.wallet) # needed for txid
return tx

def to_pay_after(self, tx):
# fixme: not robust to client restart, because we do not persist batch_payments
def _to_pay_after(self, tx) -> Sequence[PartialTxOutput]:
if not tx:
return self.batch_payments
to_pay = []
Expand All @@ -162,7 +165,7 @@ def to_pay_after(self, tx):
outputs.remove(x)
return to_pay

def to_sweep_after(self, tx):
def _to_sweep_after(self, tx) -> Dict[str, SweepInfo]:
tx_prevouts = set(txin.prevout for txin in tx.inputs()) if tx else set()
result = []
for k,v in self.batch_inputs.items():
Expand All @@ -179,7 +182,7 @@ def to_sweep_after(self, tx):
result.append((k,v))
return dict(result)

def should_bump_fee(self, base_tx):
def _should_bump_fee(self, base_tx) -> bool:
if base_tx is None:
return False
base_tx_fee = base_tx.get_fee()
Expand All @@ -196,30 +199,30 @@ async def run(self):
password = self.wallet.get_unlocked_password()
if self.wallet.has_keystore_encryption() and not password:
continue
await self.maybe_broadcast_legacy_htlc_txs()
tx = self.find_confirmed_base_tx()
await self._maybe_broadcast_legacy_htlc_txs()
tx = self._find_confirmed_base_tx()
if tx:
self.logger.info(f'base tx confirmed {tx.txid()}')
self.clear_batch_processing(tx)
self.start_new_batch(tx)
self._clear_batch_processing(tx)
self._start_new_batch(tx)
base_tx = self.get_base_tx()
to_pay = self.to_pay_after(base_tx)
to_sweep = self.to_sweep_after(base_tx)
to_pay = self._to_pay_after(base_tx)
to_sweep = self._to_sweep_after(base_tx)
to_sweep_now = {}
for k, v in to_sweep.items():
can_broadcast, wanted_height = self.can_broadcast(v, base_tx)
can_broadcast, wanted_height = self._can_broadcast(v, base_tx)
if can_broadcast:
to_sweep_now[k] = v
else:
self.wallet.add_future_tx(v, wanted_height)
if not to_pay and not to_sweep_now and not self.should_bump_fee(base_tx):
if not to_pay and not to_sweep_now and not self._should_bump_fee(base_tx):
continue
try:
tx = self.create_batch_tx(base_tx, to_sweep_now, to_pay, password)
tx = self._create_batch_tx(base_tx, to_sweep_now, to_pay, password)
except Exception as e:
self.logger.exception(f'Cannot create batch transaction: {repr(e)}')
if base_tx:
self.start_new_batch(base_tx)
self._start_new_batch(base_tx)
continue
await asyncio.sleep(self.RETRY_DELAY)
continue
Expand All @@ -242,7 +245,7 @@ async def run(self):
self.logger.info(f'starting new batch because could not broadcast')
self.start_new_batch(base_tx)

def create_batch_tx(self, base_tx, to_sweep, to_pay, password):
def _create_batch_tx(self, base_tx, to_sweep, to_pay, password):
self.logger.info(f'to_sweep: {list(to_sweep.keys())}')
self.logger.info(f'to_pay: {to_pay}')
inputs = []
Expand All @@ -260,7 +263,6 @@ def create_batch_tx(self, base_tx, to_sweep, to_pay, password):
self.logger.info(f'locktime: {locktime}')
outputs += to_pay
inputs += self.get_change_inputs(self._parent_tx) if self._parent_tx else []

tx = self.wallet.create_transaction(
base_tx=base_tx,
inputs=inputs,
Expand All @@ -275,18 +277,18 @@ def create_batch_tx(self, base_tx, to_sweep, to_pay, password):
assert tx.is_complete()
return tx

def clear_batch_processing(self, tx):
def _clear_batch_processing(self, tx):
# this ensure that we can accept an input again
# if the spending tx is removed from the blockchain
# fixme: what if there are several batches?
for txin in tx.inputs():
if txin.prevout in self.batch_processing:
self.batch_processing.remove(txin.prevout)

def start_new_batch(self, tx):
def _start_new_batch(self, tx):
use_change = tx and tx.has_change() and any([txout in self.batch_payments for txout in tx.outputs()])
self.batch_payments = self.to_pay_after(tx)
self.batch_inputs = self.to_sweep_after(tx)
self.batch_payments = self._to_pay_after(tx)
self.batch_inputs = self._to_sweep_after(tx)
self.batch_txids.clear()
self._base_tx = None
self._parent_tx = tx if use_change else None
Expand All @@ -301,7 +303,7 @@ def get_change_inputs(self, parent_tx):
txin.nsequence = 0xffffffff - 2
return inputs

def can_broadcast(self, sweep_info: 'SweepInfo', base_tx):
def _can_broadcast(self, sweep_info: 'SweepInfo', base_tx):
prevout = sweep_info.txin.prevout.to_str()
name = sweep_info.name
prev_txid, index = prevout.split(':')
Expand Down Expand Up @@ -331,7 +333,7 @@ def can_broadcast(self, sweep_info: 'SweepInfo', base_tx):
wanted_height = prev_height
return can_broadcast, wanted_height

async def maybe_broadcast_legacy_htlc_txs(self):
async def _maybe_broadcast_legacy_htlc_txs(self):
""" pre-anchor htlc txs cannot be batched """
for sweep_info in list(self.batch_inputs.values()):
if sweep_info.name == 'first-stage-htlc':
Expand Down
2 changes: 0 additions & 2 deletions tests/test_txbatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,6 @@ async def test_batch_payments(self, mock_save_db):
assert wallet.adb.get_transaction(tx1.txid()) is not None
assert wallet.adb.get_transaction(tx1_prime.txid()) is None
# txbatcher creates tx2
self.logger.info(f'to pay after {wallet.txbatcher.to_pay_after(tx1)}')
self.logger.info(f'{tx_mined_status}')
await self.network._tx_event.wait()
tx2 = wallet.txbatcher.get_base_tx()
assert output1 in tx1.outputs()
Expand Down

0 comments on commit 9f589b0

Please sign in to comment.