diff --git a/backend/db.py b/backend/db.py index fa0c1db..db980c4 100644 --- a/backend/db.py +++ b/backend/db.py @@ -64,9 +64,10 @@ class ExchangeRate(Base): engine = create_engine('postgresql+psycopg2://{0}:{1}@{2}:{3}/{4}'.format(DB_USER, DB_PW, DB_HOST, DB_PORT, DB_NAME)) Session = sessionmaker(bind=engine) -def truncate_offers(): +def truncate_all_tables(): session = Session() session.execute(text("TRUNCATE aircraft_offer")) + session.execute(text("TRUNCATE exchange_rates")) session.commit() session.close() @@ -103,6 +104,7 @@ def offer_url_exists(offer_url): finally: session.close() + def get_exchange_rates_as_dict(session): all_exchange_rates = session.query(ExchangeRate).all() exchange_rates = dict() diff --git a/backend/exchange_rates.py b/backend/exchange_rates.py index 03dbb3c..7f82b51 100644 --- a/backend/exchange_rates.py +++ b/backend/exchange_rates.py @@ -11,9 +11,13 @@ def fetch_exchange_rates_from_ecb(): start_time = time.time() - count = 0 response = requests.get(ECB_FX_RATES_URL).text - et = ET.fromstring(response) + count = update_exchange_rates(response) + logger.info("Updated %d exchange rates in %.2fs", count, (time.time() - start_time)) + +def update_exchange_rates(xmlData): + count = 0 + et = ET.fromstring(xmlData) for element in et.iter(): curr_name = element.get('currency') if curr_name is None: @@ -22,8 +26,8 @@ def fetch_exchange_rates_from_ecb(): logger.debug("Fetched exchange rate: {0} for currency {1}".format(str(curr_rate), str(curr_name))) db.update_exchange_rate(db.ExchangeRate(currency=curr_name, rate=curr_rate)) count += 1 + return count - logger.info("Updated %d exchange rates in %.2fs", count, (time.time() - start_time)) def get_currency_code(o): currency_str = o @@ -36,7 +40,8 @@ def get_currency_code(o): '$': 'USD', 'US$': 'USD', '£': 'GBP', - 'Fr.': 'CHF' + 'Fr.': 'CHF', + 'A$': 'AUD', } if currency_str in iso_code_mapping: return iso_code_mapping[currency_str] diff --git a/backend/tests/eurofxref-daily.xml b/backend/tests/eurofxref-daily.xml new file mode 100644 index 0000000..df3755f --- /dev/null +++ b/backend/tests/eurofxref-daily.xml @@ -0,0 +1,41 @@ + + + Reference rates + + European Central Bank + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/backend/tests/test_db.py b/backend/tests/test_db.py index 9ddc055..49d3e81 100644 --- a/backend/tests/test_db.py +++ b/backend/tests/test_db.py @@ -8,7 +8,7 @@ class DbTest(unittest.TestCase): def setUp(self): - db.truncate_offers() + db.truncate_all_tables() def test_should_store_and_fetch_offer(self): # given diff --git a/backend/tests/test_exchange_rates.py b/backend/tests/test_exchange_rates.py index 87752a2..432d1b5 100644 --- a/backend/tests/test_exchange_rates.py +++ b/backend/tests/test_exchange_rates.py @@ -1,11 +1,33 @@ import unittest +from ddt import ddt from price_parser import Price -from exchange_rates import get_currency_code +import tests.test__testcontainers_setup +import db +import exchange_rates +@ddt class ExchangeRatesTest(unittest.TestCase): + def setUp(self): + db.truncate_all_tables() + def test_get_currency_code(self): - self.assertEqual("EUR", get_currency_code(Price.fromstring("12.000,00 €"))) - self.assertEqual("GBP", get_currency_code(Price.fromstring("16,23 £"))) - self.assertEqual("CHF", get_currency_code(Price.fromstring("12 SFr."))) + self.assertEqual("EUR", exchange_rates.get_currency_code(Price.fromstring("12.000,00 €"))) + self.assertEqual("GBP", exchange_rates.get_currency_code(Price.fromstring("16,23 £"))) + self.assertEqual("CHF", exchange_rates.get_currency_code(Price.fromstring("12 SFr."))) + self.assertEqual("AUD", exchange_rates.get_currency_code(Price.fromstring("12 A$"))) + + + def test_should_update_fx_rates_in_db(self): + # when + exchange_rates.update_exchange_rates(readFile("tests/eurofxref-daily.xml")) + + # then + fxRates = db.get_exchange_rates_as_dict(db.Session()) + self.assertEqual("4.3250", fxRates["PLN"]) + + +def readFile(name: str): + with open(name, 'r') as file: + return file.read() diff --git a/backend/tests/test_pipelines.py b/backend/tests/test_pipelines.py index b7faa0e..7b25e0d 100644 --- a/backend/tests/test_pipelines.py +++ b/backend/tests/test_pipelines.py @@ -13,7 +13,7 @@ class DuplicateDetectionTest(unittest.TestCase): def setUp(self): - db.truncate_offers() + db.truncate_all_tables() self.sample_offer = buildOfferWithUrl("https://offers.com/1") self.detection = pipelines.DuplicateDetection() @@ -111,7 +111,7 @@ def test_regular_offers_are_not_dropped(self, offer_title): class StoragePipelineTest(unittest.TestCase): def setUp(self): - db.truncate_offers() + db.truncate_all_tables() self.storage = pipelines.StoragePipeline() def test_should_store_offer(self):