diff --git a/pkg/datastore/postgres/repositories/address_repository.go b/pkg/datastore/postgres/repositories/address_repository.go index 3976fa280..7cf709d2a 100644 --- a/pkg/datastore/postgres/repositories/address_repository.go +++ b/pkg/datastore/postgres/repositories/address_repository.go @@ -27,11 +27,11 @@ import ( type AddressRepository struct{} -func (AddressRepository) GetOrCreateAddress(db *postgres.DB, address string) (int, error) { +func (AddressRepository) GetOrCreateAddress(db *postgres.DB, address string) (int64, error) { stringAddressToCommonAddress := common.HexToAddress(address) hexAddress := stringAddressToCommonAddress.Hex() - var addressId int + var addressId int64 getErr := db.Get(&addressId, `SELECT id FROM public.addresses WHERE address = $1`, hexAddress) if getErr == sql.ErrNoRows { insertErr := db.QueryRow(`INSERT INTO public.addresses (address) VALUES($1) RETURNING id`, hexAddress).Scan(&addressId) @@ -41,11 +41,11 @@ func (AddressRepository) GetOrCreateAddress(db *postgres.DB, address string) (in return addressId, getErr } -func (AddressRepository) GetOrCreateAddressInTransaction(tx *sqlx.Tx, address string) (int, error) { +func (AddressRepository) GetOrCreateAddressInTransaction(tx *sqlx.Tx, address string) (int64, error) { stringAddressToCommonAddress := common.HexToAddress(address) hexAddress := stringAddressToCommonAddress.Hex() - var addressId int + var addressId int64 getErr := tx.Get(&addressId, `SELECT id FROM public.addresses WHERE address = $1`, hexAddress) if getErr == sql.ErrNoRows { insertErr := tx.QueryRow(`INSERT INTO public.addresses (address) VALUES($1) RETURNING id`, hexAddress).Scan(&addressId) @@ -55,7 +55,7 @@ func (AddressRepository) GetOrCreateAddressInTransaction(tx *sqlx.Tx, address st return addressId, getErr } -func (AddressRepository) GetAddressById(db *postgres.DB, id int) (string, error) { +func (AddressRepository) GetAddressById(db *postgres.DB, id int64) (string, error) { var address string getErr := db.Get(&address, `SELECT address FROM public.addresses WHERE id = $1`, id) return address, getErr diff --git a/pkg/datastore/postgres/repositories/address_repository_test.go b/pkg/datastore/postgres/repositories/address_repository_test.go index 5b67cfd00..1b2b788f5 100644 --- a/pkg/datastore/postgres/repositories/address_repository_test.go +++ b/pkg/datastore/postgres/repositories/address_repository_test.go @@ -41,8 +41,12 @@ var _ = Describe("address lookup", func() { repo = repositories.AddressRepository{} }) + AfterEach(func() { + test_config.CleanTestDB(db) + }) + type dbAddress struct { - Id int + Id int64 Address string } @@ -59,16 +63,17 @@ var _ = Describe("address lookup", func() { }) It("returns the existing record id if the address already exists", func() { - _, createErr := repo.GetOrCreateAddress(db, address) + createId, createErr := repo.GetOrCreateAddress(db, address) Expect(createErr).NotTo(HaveOccurred()) - _, getErr := repo.GetOrCreateAddress(db, address) + getId, getErr := repo.GetOrCreateAddress(db, address) Expect(getErr).NotTo(HaveOccurred()) var addressCount int addressErr := db.Get(&addressCount, `SELECT count(*) FROM public.addresses`) Expect(addressErr).NotTo(HaveOccurred()) Expect(addressCount).To(Equal(1)) + Expect(createId).To(Equal(getId)) }) It("gets upper-cased addresses", func() { @@ -102,10 +107,15 @@ var _ = Describe("address lookup", func() { Expect(txErr).NotTo(HaveOccurred()) }) + AfterEach(func() { + tx.Rollback() + }) + It("creates an address record", func() { addressId, createErr := repo.GetOrCreateAddressInTransaction(tx, address) Expect(createErr).NotTo(HaveOccurred()) - tx.Commit() + commitErr := tx.Commit() + Expect(commitErr).NotTo(HaveOccurred()) var actualAddress dbAddress getErr := db.Get(&actualAddress, `SELECT id, address FROM public.addresses LIMIT 1`)