From 6a2adabef89da83534acc37fb4ab60805db77767 Mon Sep 17 00:00:00 2001 From: Christopher Tarry Date: Fri, 5 Jan 2024 13:19:47 -0500 Subject: [PATCH] fixes for nate's review --- persist/sqlite/query.go | 1 + persist/sqlite/txn.go | 42 ++++++++++++++++++++++++++--------------- 2 files changed, 28 insertions(+), 15 deletions(-) diff --git a/persist/sqlite/query.go b/persist/sqlite/query.go index a94d25eb..017d7d66 100644 --- a/persist/sqlite/query.go +++ b/persist/sqlite/query.go @@ -54,6 +54,7 @@ func (s *Store) Block(id types.BlockID) (result types.Block, err error) { if rows, err = s.query("SELECT address, value FROM MinerPayouts WHERE block_id = ? ORDER BY block_order", encode(id)); err != nil { return } + defer rows.Close() var address, value []byte for rows.Next() { diff --git a/persist/sqlite/txn.go b/persist/sqlite/txn.go index f091a6c8..0924890a 100644 --- a/persist/sqlite/txn.go +++ b/persist/sqlite/txn.go @@ -2,6 +2,7 @@ package sqlite import ( "bytes" + "fmt" "go.sia.tech/core/chain" "go.sia.tech/core/types" @@ -23,36 +24,47 @@ func encodeUint64(x uint64) []byte { return buf.Bytes() } -func (s *Store) addBlock(b types.Block, height uint64) error { +func (s *Store) addBlock(tx txn, b types.Block, height uint64) error { // nonce is encoded because database/sql doesn't support uint64 with high bit set - _, err := s.exec("INSERT INTO Blocks(id, height, parent_id, nonce, timestamp) VALUES (?, ?, ?, ?, ?);", encode(b.ID()), height, encode(b.ParentID), encodeUint64(b.Nonce), b.Timestamp.Unix()) + _, err := tx.Exec("INSERT INTO Blocks(id, height, parent_id, nonce, timestamp) VALUES (?, ?, ?, ?, ?);", encode(b.ID()), height, encode(b.ParentID), encodeUint64(b.Nonce), b.Timestamp.Unix()) return err } -func (s *Store) addMinerPayouts(bid types.BlockID, scos []types.SiacoinOutput) error { +func (s *Store) addMinerPayouts(tx txn, bid types.BlockID, scos []types.SiacoinOutput) error { for i, sco := range scos { - if _, err := s.exec("INSERT INTO MinerPayouts(block_id, block_order, address, value) VALUES (?, ?, ?, ?);", encode(bid), i, encode(sco.Address), encode(sco.Value)); err != nil { + if _, err := tx.Exec("INSERT INTO MinerPayouts(block_id, block_order, address, value) VALUES (?, ?, ?, ?);", encode(bid), i, encode(sco.Address), encode(sco.Value)); err != nil { return err } } return nil } -func (s *Store) deleteBlock(bid types.BlockID) error { - _, err := s.exec("DELETE FROM Blocks WHERE id = ?", encode(bid)) +func (s *Store) deleteBlock(tx txn, bid types.BlockID) error { + _, err := tx.Exec("DELETE FROM Blocks WHERE id = ?", encode(bid)) return err } func (s *Store) applyUpdates() error { - for _, update := range s.pendingUpdates { - if err := s.addBlock(update.Block, update.State.Index.Height); err != nil { - return err - } else if err := s.addMinerPayouts(update.Block.ID(), update.Block.MinerPayouts); err != nil { - return err + return s.transaction(func(tx txn) error { + for _, update := range s.pendingUpdates { + if err := s.addBlock(tx, update.Block, update.State.Index.Height); err != nil { + return fmt.Errorf("applyUpdates: failed to add block: %v", err) + } else if err := s.addMinerPayouts(tx, update.Block.ID(), update.Block.MinerPayouts); err != nil { + return fmt.Errorf("applyUpdates: failed to add miner payouts: %v", err) + } } - } - s.pendingUpdates = s.pendingUpdates[:0] - return nil + s.pendingUpdates = s.pendingUpdates[:0] + return nil + }) +} + +func (s *Store) revertUpdate(cru *chain.RevertUpdate) error { + return s.transaction(func(tx txn) error { + if err := s.deleteBlock(tx, cru.Block.ID()); err != nil { + return fmt.Errorf("revertUpdate: failed to delete block: %v", err) + } + return nil + }) } // ProcessChainApplyUpdate implements chain.Subscriber. @@ -75,5 +87,5 @@ func (s *Store) ProcessChainRevertUpdate(cru *chain.RevertUpdate) error { if err := s.applyUpdates(); err != nil { return err } - return s.deleteBlock(cru.Block.ID()) + return s.revertUpdate(cru) }