diff --git a/.gitignore b/.gitignore index bba8abe..5cc0eed 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,5 @@ # Ignore files generated by Python -__pycache__/ +**/__pycache__ *.pyc # Ignore log files @@ -21,4 +21,6 @@ data/*.json # Ignore Airflow files plugins/ logs/ -airflow.cfg \ No newline at end of file +airflow.cfg + +module/ \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..e69de29 diff --git a/assets/marketpipe_logo.png b/assets/marketpipe_logo.png new file mode 100644 index 0000000..fda6fb9 Binary files /dev/null and b/assets/marketpipe_logo.png differ diff --git a/core/crypto_api_client.py b/core/crypto_api_client.py new file mode 100644 index 0000000..a2bd269 --- /dev/null +++ b/core/crypto_api_client.py @@ -0,0 +1,93 @@ +import requests +import logging +from dotenv import load_dotenv +import os +from base_api import BaseApiClient +from utils.market_data_processor_utils import read_json + + +COIN_BASE_URL = "https://pro-api.coinmarketcap.com/v1/cryptocurrency/listings/latest?" + + +load_dotenv() + + +class CryptoApiClient(BaseApiClient): + """ + A client for retrieving cryptocurrency data from the CoinMarketCap API. + + This class inherits from the BaseApiClient abstract base class and implements the get_data method. + + Attributes: + logger (logging.Logger): The logger object for logging messages. + + Methods: + __init__(self, COIN_API_KEY: str, logger: logging.Logger): Initializes a new instance of the CryptoApiClient class. + """ + + def __init__(self, logger: logging.Logger): + """ + Initializes a new instance of the CryptoApiClient class. + + Args: + logger (logging.Logger): The logger object for logging messages. + """ + super().__init__(logger=logger) + self.symbols = read_json("mdp_config.json")['assets']['cryptos']['symbols'] + + def get_data(self) -> dict[str, dict[str, any]]: + """ + Retrieves market data for the given list of symbols from CoinMarketCap API. + + Returns: + Dict[str, Dict[str, str]]: A dictionary containing the retrieved data for each symbol. + """ + parameters = { + "start": "1", + "limit": "100", + "convert": "USD", + "sort": "percent_change_24h", + } + + headers = { + "Accepts": "application/json", + "X-CMC_PRO_API_KEY": os.getenv("COIN_API_KEY"), + } + + crypto_data = {} + + for symbol in self.symbols: + try: + parameters["symbol"] = symbol + + response = requests.get(COIN_BASE_URL, headers=headers, params=parameters) + response.raise_for_status() + + data = response.json() + symbol_data = data.get("data", [])[0] + + symbol_info = { + "name": symbol_data.get("name", ""), + "volume": symbol_data["quote"]["USD"].get("volume_24h", ""), + "price": symbol_data["quote"]["USD"].get("price", ""), + "change_percent": symbol_data["quote"]["USD"].get( + "percent_change_24h", "" + ), + "market_cap": symbol_data["quote"]["USD"].get("market_cap", ""), + } + + crypto_data[symbol] = symbol_info + + self.logger.info(f"Successfully retrieved data for symbol {symbol}.") + except requests.exceptions.RequestException as req_error: + self.logger.error( + f"Error during API request for {symbol}: {req_error}" + ) + raise + except (IndexError, KeyError) as data_error: + self.logger.error( + f"Error processing data for {symbol}: {data_error}" + ) + raise + + return crypto_data diff --git a/core/storage.py b/core/storage.py new file mode 100644 index 0000000..dd882ab --- /dev/null +++ b/core/storage.py @@ -0,0 +1,94 @@ +import psycopg2 +import logging +from dotenv import load_dotenv +import os + +load_dotenv() + + +class Storage: + """ + A class that handles storing data in a database. + + Attributes: + logger (logging.Logger): The logger object for logging messages. + conn: The database connection object. + cur: The database cursor object. + """ + + def __init__(self, logger: logging.Logger): + self.logger = logger + self.conn = None + self.cur = None + + def _connect(self): + try: + self.conn = psycopg2.connect( + host=os.getenv("POSTGRES_HOST"), + port=os.getenv("POSTGRES_PORT"), + database=os.getenv("POSTGRES_DB"), + user=os.getenv("POSTGRES_USER"), + password=os.getenv("POSTGRES_PASSWORD"), + ) + self.cur = self.conn.cursor() + except psycopg2.Error as e: + self.logger.error(f"Error connecting to the database: {e}") + raise + + def _close(self): + try: + if self.cur: + self.cur.close() + if self.conn: + self.conn.close() + except psycopg2.Error as e: + self.logger.error(f"Error closing the database connection: {e}") + + def store_data(self, data: dict[str, dict[str, any]], table: str) -> None: + if not isinstance(table, str): + error_msg = "Table name must be a string" + self.logger.error(error_msg) + raise TypeError(error_msg) + try: + self.logger.info("Storing data in the database.") + + self._connect() + + with self.conn, self.cur: + for symbol, asset_data in data.items(): + name = asset_data["name"] + volume = asset_data["volume"] + price = asset_data["price"] + market_cap = asset_data["market_cap"] + change_percent = asset_data["change_percent"] + + if not all([symbol, name]): + self.logger.error( + f"One or more required fields are missing from the {table} data for symbol: {symbol}, name: {name}" + ) + raise ValueError( + f"One or more required fields are missing from the {table} data" + ) + + self.cur.execute( + f"INSERT INTO {table} (symbol, name, market_cap, volume, price, change_percent) VALUES (%s, %s, %s, %s, %s, %s)", + (symbol, name, market_cap, volume, price, change_percent), + ) + + self.logger.info( + f"Successfully stored data for symbol {symbol} in the {table} table." + ) + + self.conn.commit() + + except psycopg2.Error as error: + self.logger.error( + f"An error occurred while storing data in the database: {error}" + ) + if self.conn: + self.conn.rollback() + finally: + self._close() + + + diff --git a/dags/market_data_dag.py b/dags/market_data_dag.py index 755fd8e..717ebbe 100644 --- a/dags/market_data_dag.py +++ b/dags/market_data_dag.py @@ -1,97 +1,47 @@ -import os -import sys +from airflow import DAG from airflow.operators.python import PythonOperator -from airflow.models import DAG from datetime import datetime -from dotenv import load_dotenv import logging +from utils.market_data_processor_utils import read_json +from core.data_processor import DataProcessor -# Configure logging -logging.basicConfig( - level=logging.WARNING, - format="[%(asctime)s] [%(levelname)s] [%(name)s] - %(message)s", -) +config = read_json("mdp_config.json") -logger = logging.getLogger(__name__) - -# Load environment variables from .env file -load_dotenv() - -# Find the parent directory -parent_dir = os.path.dirname(os.path.abspath(__file__)) -project_root = os.path.dirname(parent_dir) +default_args = { + "owner": config.get("owner", "airflow"), + "depends_on_past": False, + "start_date": datetime.now(), + "email_on_failure": config.get("email_on_failure", False), + "email_on_retry": config.get("email_on_retry", False), + "retries": config.get("retries", 1), +} -# Add the project root to the Python path -sys.path.insert(0, project_root) -from core.market_data_processor import ( - StockApiClient, - CryptoApiClient, - Storage, - MarketDataEngine, -) +def create_market_data_dag(asset_type, dag_id, description): + dag = DAG( + dag_id, + default_args=default_args, + schedule_interval=config["assets"][asset_type]["schedule_interval"], + description=description, + ) -# Define default arguments for the DAGs -default_args_stocks = { - "owner": "airflow", - "depends_on_past": False, - "start_date": datetime(2023, 3, 15), - "email_on_failure": False, - "email_on_retry": False, - "retries": 0, -} + market_processor = DataProcessor(asset_type) -default_args_cryptos = { - "owner": "airflow", - "depends_on_past": False, - "start_date": datetime(2023, 3, 15), - "email_on_failure": False, - "email_on_retry": False, - "retries": 0, -} + with dag: + get_data_task = PythonOperator( + task_id=f"get_{asset_type}_data", + python_callable=market_processor.get_data, + ) -# Create instances of the classes -stock_api_client = StockApiClient( - os.environ["ALPHA_API_KEY"], os.environ["PREP_API_KEY"], logger -) -crypto_api_client = CryptoApiClient(os.environ["COIN_API_KEY"], logger) -db_connector = Storage( - os.getenv("POSTGRES_HOST"), - os.getenv("POSTGRES_PORT"), - os.getenv("POSTGRES_DB"), - os.getenv("POSTGRES_USER"), - os.getenv("POSTGRES_PASSWORD"), - logger, -) -stock_engine = MarketDataEngine(stock_api_client, db_connector, logger) -crypto_engine = MarketDataEngine(crypto_api_client, db_connector, logger) + store_data_task = PythonOperator( + task_id=f"store_{asset_type}_data", + python_callable=market_processor.store_data, + ) -# Create the DAG for stock data collection and storage -dag_stocks = DAG( - "process_stock_data", - default_args=default_args_stocks, - schedule_interval="0 23 * * 1-5", # Schedule to run everyday at 11 PM from Monday to Friday - description="Collect and store stock data", -) + get_data_task >> store_data_task -# Create the DAG for cryptocurrency data collection and storage -dag_cryptos = DAG( - "process_crypto_data", - default_args=default_args_cryptos, - schedule_interval="0 23 * * *", # Schedule to run everyday at 11 PM - description="Collect and store cryptocurrency data", -) + return dag -# Define the task for stock data collection and storage -process_stock_data_task = PythonOperator( - task_id="get_stocks", - python_callable=stock_engine.process_stock_data, - dag=dag_stocks, -) -# Define the tasks for cryptocurrency data collection and storage -process_crypto_data_task = PythonOperator( - task_id="get_crypto", - python_callable=crypto_engine.process_crypto_data, - dag=dag_cryptos, -) +create_market_data_dag("stocks", "process_stock_data", "Collect and store stock data") +create_market_data_dag("cryptos", "process_crypto_data", "Collect and store crypto data") diff --git a/database_setup/init.sql b/database_setup/init.sql new file mode 100644 index 0000000..5c25161 --- /dev/null +++ b/database_setup/init.sql @@ -0,0 +1,27 @@ +-- Create the schema +CREATE SCHEMA IF NOT EXISTS market_data; + + +-- Create a table for stock data +CREATE TABLE IF NOT EXISTS market_data.stocks ( + id SERIAL PRIMARY KEY, + date_collected DATE NOT NULL DEFAULT CURRENT_DATE, + symbol VARCHAR(20) NOT NULL, + name VARCHAR(50) NOT NULL, + market_cap DECIMAL(20,2) NOT NULL, + volume INT NOT NULL, + price DECIMAL(10,2) NOT NULL, + change_percent DECIMAL(15,8) NOT NULL +); + +-- Create a table for cryptocurrency data +CREATE TABLE IF NOT EXISTS market_data.cryptos ( + id SERIAL PRIMARY KEY, + date_collected DATE NOT NULL DEFAULT CURRENT_DATE, + symbol VARCHAR(20) NOT NULL, + name VARCHAR(50) NOT NULL, + market_cap DECIMAL(20,2) NOT NULL, + volume INT NOT NULL, + price DECIMAL(25,15) NOT NULL, + change_percent DECIMAL(50,30) NOT NULL +); diff --git a/init.sql b/init.sql deleted file mode 100644 index 90f2144..0000000 --- a/init.sql +++ /dev/null @@ -1,79 +0,0 @@ --- Create a schema for stock_data -CREATE SCHEMA stock_data; - --- Create a schema for crypto data -CREATE SCHEMA crypto_data; - --- Create tables for stock data --- Create a table to store gainers data -CREATE TABLE stock_data.gainers ( - id SERIAL PRIMARY KEY, - date_collected DATE NOT NULL DEFAULT CURRENT_DATE, - symbol VARCHAR(20) NOT NULL, - name VARCHAR(50) NOT NULL, - market_cap NUMERIC(20,2) NOT NULL, - volume INT NOT NULL, - price NUMERIC(10,2) NOT NULL, - change_percent NUMERIC(15,8) NOT NULL -); - --- Create a table to store losers data -CREATE TABLE stock_data.losers ( - id SERIAL PRIMARY KEY, - date_collected DATE NOT NULL DEFAULT CURRENT_DATE, - symbol VARCHAR(20) NOT NULL, - name VARCHAR(50) NOT NULL, - market_cap NUMERIC(20,2) NOT NULL, - volume INT NOT NULL, - price NUMERIC(10,2) NOT NULL, - change_percent NUMERIC(15,8) NOT NULL -); - --- Create a table to store actives data -CREATE TABLE stock_data.actives ( - id SERIAL PRIMARY KEY, - date_collected DATE NOT NULL DEFAULT CURRENT_DATE, - symbol VARCHAR(20) NOT NULL, - name VARCHAR(50) NOT NULL, - market_cap NUMERIC(20,2) NOT NULL, - volume INT NOT NULL, - price NUMERIC(10,2) NOT NULL, - change_percent NUMERIC(15,8) NOT NULL -); - --- Create tables for crypto data --- Create a table to store gainers data -CREATE TABLE crypto_data.gainers ( - id SERIAL PRIMARY KEY, - date_collected DATE NOT NULL DEFAULT CURRENT_DATE, - symbol VARCHAR(20) NOT NULL, - name VARCHAR(50) NOT NULL, - market_cap NUMERIC(20,2) NOT NULL, - volume INT NOT NULL, - price NUMERIC(25,15) NOT NULL, - change_percent NUMERIC(50,30) NOT NULL -); - --- Create a table to store losers data -CREATE TABLE crypto_data.losers ( - id SERIAL PRIMARY KEY, - date_collected DATE NOT NULL DEFAULT CURRENT_DATE, - symbol VARCHAR(20) NOT NULL, - name VARCHAR(50) NOT NULL, - market_cap NUMERIC(20,2) NOT NULL, - volume INT NOT NULL, - price NUMERIC(25,15) NOT NULL, - change_percent NUMERIC(50,30) NOT NULL -); - --- Create a table to store actives data -CREATE TABLE crypto_data.actives ( - id SERIAL PRIMARY KEY, - date_collected DATE NOT NULL DEFAULT CURRENT_DATE, - symbol VARCHAR(20) NOT NULL, - name VARCHAR(50) NOT NULL, - market_cap NUMERIC(20,2) NOT NULL, - volume INT NOT NULL, - price NUMERIC(25,15) NOT NULL, - change_percent NUMERIC(50,30) NOT NULL -); diff --git a/tests/test_data_processor.py b/tests/test_data_processor.py new file mode 100644 index 0000000..3993d55 --- /dev/null +++ b/tests/test_data_processor.py @@ -0,0 +1,178 @@ +import os +import sys +import unittest +from unittest.mock import patch, MagicMock +import logging +from utils import market_data_processor_utils +from core.market_data_processor import ( + StockApiClient, + CryptoApiClient, + Storage, + MarketDataEngine, +) + +class TestStorage(unittest.TestCase): + """ + Unit tests for the Storage class. + """ + + def setUp(self): + self.logger = MagicMock(spec=logging.Logger) + + self.storage = Storage(logger=self.logger) + + @patch.dict( + "core.market_data_processor.os.environ", + { + "POSTGRES_USER": "user", + "POSTGRES_PASSWORD": "password", + "POSTGRES_DB": "test_db", + "POSTGRES_HOST": "localhost", + "POSTGRES_PORT": "5432", + }, + clear=True, + ) + @patch("core.market_data_processor.psycopg2.connect") + def test_connect(self, mock_connect): + self.storage._connect() + + mock_connect.assert_called_once_with( + host="localhost", + port="5432", + database="test_db", + user="user", + password="password", + ) + + @patch("core.market_data_processor.psycopg2.connect") + def test_close(self, mock_connect): + mock_conn = mock_connect.return_value + mock_cur = MagicMock() + mock_conn.cursor.return_value = mock_cur + + self.storage._connect() + self.storage._close() + + mock_cur.close.assert_called_once() + mock_conn.close.assert_called_once() + + @patch("core.market_data_processor.psycopg2.connect") + def test_store_data_with_valid_data(self, mock_connect): + mock_conn = mock_connect.return_value + mock_cur = MagicMock() + mock_conn.cursor.return_value = mock_cur + mock_execute = mock_cur.execute + mock_commit = mock_conn.commit + + # Test case with valid data + data = { + "ABC": { + "volume": 123456, + "price": 50.25, + "change_percent": 2.5, + "market_cap": "1.2B", + "name": "ABC Company", + } + } + table = "stocks" + + self.storage.store_data(data, table) + + # Assert that execute and commit methods were called + mock_execute.assert_called_once() + mock_execute.assert_called_once_with( + "INSERT INTO stocks (symbol, name, market_cap, volume, price, change_percent) VALUES (%s, %s, %s, %s, %s, %s)", + ("ABC", "ABC Company", "1.2B", 123456, 50.25, 2.5), + ) + mock_commit.assert_called_once() + + @patch("core.market_data_processor.psycopg2.connect") + def test_store_data_with_invalid_data_empty_symbol(self, mock_connect): + # (empty name) + data_invalid = { + "ABC": { + "volume": 1000, + "price": 10.0, + "change_percent": 5.0, + "market_cap": 1000000, + "name": "", + } + } + + table = "stocks" + + # Mock the logger.error method to capture log messages + with patch.object(self.logger, "error"): + with self.assertRaises(ValueError): + self.storage.store_data(data_invalid, table) + + @patch("core.market_data_processor.psycopg2.connect") + def test_store_data_with_invalid_data_type(self, mock_connect): + data = { + "ABC": { + "volume": 1000, + "price": 10.0, + "change_percent": 5.0, + "market_cap": 1000000, + "name": "", + } + } + + table_invalid = 123 # Invalid table(not a string) + + with patch.object(self.logger, "error"): + with self.assertRaises(TypeError): + self.storage.store_data(data, table_invalid) + + +class TestMarketDataEngine(unittest.TestCase): + """ + Unit tests for the MarketDataEngine class. + """ + + def setUp(self): + self.stock_api_client = MagicMock(spec=StockApiClient) + self.crypto_api_client = MagicMock(spec=CryptoApiClient) + self.db_connector = MagicMock(spec=Storage) + self.logger = MagicMock(spec=logging.Logger) + self.stock_engine = MarketDataEngine( + self.stock_api_client, self.db_connector, self.logger + ) + self.crypto_engine = MarketDataEngine( + self.crypto_api_client, self.db_connector, self.logger + ) + + def test_process_stock_data(self): + self.stock_api_client.get_data.return_value = { + "AAPL": {"price": 150.0}, + "GOOG": {"price": 2000.0}, + "MSFT": {"price": 300.0}, + } + + self.stock_engine.process_stock_data() + + self.db_connector.store_data.assert_called_once_with( + { + "AAPL": {"price": 150.0}, + "GOOG": {"price": 2000.0}, + "MSFT": {"price": 300.0}, + }, + "stocks", + ) + + def test_process_crypto_data(self): + self.crypto_api_client.get_data.return_value = { + "BTC": {"price": 50000.0}, + "ETH": {"price": 4000.0}, + } + + self.crypto_engine.process_crypto_data() + + self.crypto_api_client.get_data.assert_called_once() + self.db_connector.store_data.assert_called_once_with( + {"BTC": {"price": 50000.0}, "ETH": {"price": 4000.0}}, "cryptos" + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/tests_market_data_processor.py b/tests/tests_market_data_processor.py deleted file mode 100644 index 89ed0b8..0000000 --- a/tests/tests_market_data_processor.py +++ /dev/null @@ -1,463 +0,0 @@ -import os -import sys -import unittest -from unittest.mock import patch, MagicMock -import logging - -# Find the parent directory -parent_dir = os.path.dirname(os.path.abspath(__file__)) -project_root = os.path.dirname(parent_dir) - -# Add the project root to the Python path -sys.path.insert(0, project_root) - -from core.market_data_processor import ( - StockApiClient, - CryptoApiClient, - Storage, - MarketDataEngine, -) - - -class TestStockApiClient(unittest.TestCase): - """ - Unit tests for the StockApiClient class. - """ - - def setUp(self): - # Mock logger for testing - self.logger = logging.getLogger(__name__) - self.logger.addHandler(logging.StreamHandler()) - self.stock_api_client = StockApiClient( - "alpha_api_key", "prep_api_key", logger=self.logger - ) - - def test_init(self): - self.assertEqual(self.stock_api_client.ALPHA_API_KEY, "alpha_api_key") - self.assertEqual(self.stock_api_client.PREP_API_KEY, "prep_api_key") - - @patch("requests.get") - def test_get_stocks(self, mock_get): - mock_response = MagicMock() - mock_response.json.return_value = [ - {"symbol": "ABC"}, - {"symbol": "DEF"}, - {"symbol": "GHI"}, - {"symbol": "JKL"}, - {"symbol": "MNO"}, - ] - mock_get.return_value = mock_response - - stocks = self.stock_api_client.get_stocks() - - self.assertEqual( - stocks, - { - "gainers": ["ABC", "DEF", "GHI", "JKL", "MNO"], - "losers": ["ABC", "DEF", "GHI", "JKL", "MNO"], - "actives": ["ABC", "DEF", "GHI", "JKL", "MNO"], - }, - ) - - @patch("requests.get") - def test_get_stocks_empty_data(self, mock_get): - # Mock the response from the API with empty data - mock_get.return_value.json.return_value = [] - - # Mock the logger.error method to capture log messages - with patch.object(self.logger, "error"): - # Call the method under test and assert that it raises the expected exception - with self.assertRaises(ValueError): - self.stock_api_client.get_stocks() - - @patch("requests.get") - def test_get_stocks_data_error(self, mock_get): - # Mock the response from the API with invalid data - mock_get.return_value.json.return_value = [{"invalid": "data"}] - - # Mock the logger.error method to capture log messages - with patch.object(self.logger, "error"): - # Call the method under test and assert that it raises the expected exception - with self.assertRaises(KeyError): - self.stock_api_client.get_stocks() - - @patch("requests.get") - def test_get_data(self, mock_get): - # Mock the response for a successful Alpha Vantage API request - alpha_response = MagicMock() - alpha_response.status_code = 200 - alpha_response.json.return_value = { - "Global Quote": { - "06. volume": "1000", - "05. price": "150.0", - "10. change percent": "5.0", - } - } - - # Mock the response for a successful Financial Modeling Prep API request - prep_response = MagicMock() - prep_response.status_code = 200 - prep_response.json.return_value = [ - {"companyName": "Apple Inc.", "mktCap": "2000000"} - ] - - mock_get.side_effect = [alpha_response, prep_response] - - # Call the method under test - symbols = {"gainers": ["AAPL"]} - stock_data = self.stock_api_client.get_data(symbols) - - # Assertions based on the mock data - expected_stock_data = { - "gainers": [ - { - "symbol": "AAPL", - "volume": "1000", - "price": "150.0", - "change_percent": "5.0", - "market_cap": "2000000", - "name": "Apple Inc.", - } - ] - } - - self.assertEqual(stock_data, expected_stock_data) - - @patch("requests.get") - def test_get_alpha_data_invalid_data(self, mock_get): - # Define symbols to be used in the test - symbols = {"gainers": ["AAPL"]} - - # Mock the response for a successful Alpha Vantage API request - alpha_response = MagicMock() - alpha_response.status_code = 200 - alpha_response.json.return_value = { - "Quote": { - "06. volume": "1000", - "05. price": "150.0", - "10. change percent": "5.0", - } - } - - # Mock the response for a successful Financial Modeling Prep API request - prep_response = MagicMock() - prep_response.status_code = 200 - prep_response.json.return_value = [ - {"companyName": "Apple Inc.", "mktCap": "2000000"} - ] - - mock_get.side_effect = [alpha_response, prep_response] - - # Mock the logger.error method to capture log messages - with patch.object(self.logger, "error"): - # Call the method under test and assert that it raises the expected exception - with self.assertRaises(KeyError): - self.stock_api_client.get_data(symbols) - - @patch("requests.get") - def test_get_prep_data_invalid_data(self, mock_get): - # Define symbols to be used in the test - symbols = {"gainers": ["AAPL"]} - - # Mock the response for a successful Alpha Vantage API request - alpha_response = MagicMock() - alpha_response.status_code = 200 - alpha_response.json.return_value = { - "Global Quote": { - "06. volume": "1000", - "05. price": "150.0", - "10. change percent": "5.0", - } - } - - # Mock the response for a successful Financial Modeling Prep API request - prep_response = MagicMock() - prep_response.status_code = 200 - prep_response.json.return_value = [{"company": "Name"}] - - mock_get.side_effect = [alpha_response, prep_response] - - # Mock the logger.error method to capture log messages - with patch.object(self.logger, "error"): - # Call the method under test and assert that it raises the expected exception - with self.assertRaises(KeyError): - self.stock_api_client.get_data(symbols) - - -class TestCryptoApiClient(unittest.TestCase): - """ - Unit tests for the CryptoApiClient class. - """ - - def setUp(self): - self.logger = logging.getLogger(__name__) - self.logger.addHandler(logging.StreamHandler()) - self.crypto_api_client = CryptoApiClient("coin_api_key", logger=self.logger) - - @patch("requests.get") - def test_get_data(self, mock_get): - # Mock the response from the API - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "data": [ - { - "name": "BREPE", - "symbol": "BREPE", - "quote": { - "USD": { - "price": 7, - "volume_24h": 10, - "percent_change_24h": 16, - "market_cap": 0, - } - }, - } - ] - } - mock_get.return_value = mock_response - - # Call the method under test - data = self.crypto_api_client.get_data() - - # Assert that the method returned the expected data - expected_data = { - "gainers": [ - { - "symbol": "BREPE", - "name": "BREPE", - "volume": 10, - "price": 7, - "change_percent": 16, - "market_cap": 0, - } - ], - "losers": [ - { - "symbol": "BREPE", - "name": "BREPE", - "volume": 10, - "price": 7, - "change_percent": 16, - "market_cap": 0, - } - ], - "actives": [ - { - "symbol": "BREPE", - "name": "BREPE", - "volume": 10, - "price": 7, - "change_percent": 16, - "market_cap": 0, - } - ], - } - - self.assertEqual(data, expected_data) - - @patch("requests.get") - def test_get_data_invalid_data(self, mock_get): - # Mock the response from the API with invalid data (missing 'quote' key) - mock_response = MagicMock() - mock_response.status_code = 200 - mock_response.json.return_value = { - "data": [{"name": "BREPE", "symbol": "BREPE"}] - } - mock_get.return_value = mock_response - - # Mock the logger.error method to capture log messages - with patch.object(self.logger, "error"): - # Call the method under test and assert that it raises the expected exception - with self.assertRaises(KeyError): - self.crypto_api_client.get_data() - - -class TestStorage(unittest.TestCase): - """ - Unit tests for the Storage class. - """ - - def setUp(self): - # Mock logger for testing - self.logger = logging.getLogger(__name__) - self.logger.addHandler(logging.StreamHandler()) - - self.storage = Storage( - "localhost", 5432, "test_db", "user", "password", logger=self.logger - ) - - @patch("psycopg2.connect") - def test_connect(self, mock_connect): - # Mock the cursor and its methods - mock_conn = mock_connect.return_value - mock_cur = MagicMock() - mock_conn.cursor.return_value = mock_cur - - self.storage._connect() - - mock_connect.assert_called_once_with( - host="localhost", - port=5432, - database="test_db", - user="user", - password="password", - ) - self.assertEqual(self.storage.conn, mock_conn) - self.assertEqual(self.storage.cur, mock_cur) - - # Ensure the connection object supports context management - with self.storage.conn: - pass # This should not raise an AttributeError - - @patch("psycopg2.connect") - def test_close(self, mock_connect): - mock_conn = mock_connect.return_value - mock_cur = MagicMock() - mock_conn.cursor.return_value = mock_cur - - self.storage._connect() - self.storage._close() - - mock_cur.close.assert_called_once() - mock_conn.close.assert_called_once() - - @patch("psycopg2.connect") - def test_store_data_with_valid_data(self, mock_connect): - # Mock the cursor and its methods - mock_conn = mock_connect.return_value - mock_cur = MagicMock() - mock_conn.cursor.return_value = mock_cur - mock_execute = mock_cur.execute - mock_commit = mock_conn.commit - - # Test case with valid data - data = { - "gainers": [ - { - "symbol": "ABC", - "name": "ABC Corp", - "volume": 1000, - "price": 10.0, - "market_cap": 1000000, - "change_percent": 5.0, - } - ] - } - data_type = "stock" - - self.storage.store_data(data, data_type) - - # Assert that execute and commit methods were called - mock_execute.assert_called_once() - mock_execute.assert_called_once_with( - "INSERT INTO stock_data.gainers (symbol, name, market_cap, volume, price, change_percent) VALUES (%s, %s, %s, %s, %s, %s)", - ("ABC", "ABC Corp", 1000000, 1000, 10.0, 5.0), - ) - mock_commit.assert_called_once() - - @patch("psycopg2.connect") - def test_store_data_with_invalid_data_empty_symbol(self, mock_connect): - # Test case with invalid data (empty symbol) - data_invalid = { - "gainers": [ - { - "symbol": "", - "name": "ABC Corp", - "volume": 1000, - "price": 10.0, - "market_cap": 1000000, - "change_percent": 5.0, - } - ] - } - data_type = "stock" - - # Mock the logger.error method to capture log messages - with patch.object(self.logger, "error"): - with self.assertRaises(ValueError): - self.storage.store_data(data_invalid, data_type) - - @patch("psycopg2.connect") - def test_store_data_with_invalid_data_type(self, mock_connect): - # Test case with invalid data type - data = { - "gainers": [ - { - "symbol": "ABC", - "name": "ABC Corp", - "volume": 1000, - "price": 10.0, - "market_cap": 1000000, - "change_percent": 5.0, - } - ] - } - data_type_invalid = 123 # Invalid data type (not a string) - - # Mock the logger.error method to capture log messages - with patch.object(self.logger, "error"): - with self.assertRaises(TypeError): - self.storage.store_data(data, data_type_invalid) - - -class TestMarketDataEngine(unittest.TestCase): - """ - Unit tests for the MarketDataEngine class. - """ - - def setUp(self): - self.stock_api_client = MagicMock(spec=StockApiClient) - self.crypto_api_client = MagicMock(spec=CryptoApiClient) - self.db_connector = MagicMock(spec=Storage) - self.logger = MagicMock(spec=logging.Logger) - self.stock_engine = MarketDataEngine( - self.stock_api_client, self.db_connector, self.logger - ) - self.crypto_engine = MarketDataEngine( - self.crypto_api_client, self.db_connector, self.logger - ) - - def test_process_stock_data(self): - # Mock the return values for the api_client methods - self.stock_api_client.get_stocks.return_value = ["AAPL", "GOOG", "MSFT"] - self.stock_api_client.get_data.return_value = { - "AAPL": {"price": 150.0}, - "GOOG": {"price": 2000.0}, - "MSFT": {"price": 300.0}, - } - - # Call the method under test - self.stock_engine.process_stock_data() - - # Assert that the methods were called with the expected arguments - self.stock_api_client.get_stocks.assert_called_once() - self.stock_api_client.get_data.assert_called_once_with(["AAPL", "GOOG", "MSFT"]) - self.db_connector.store_data.assert_called_once_with( - { - "AAPL": {"price": 150.0}, - "GOOG": {"price": 2000.0}, - "MSFT": {"price": 300.0}, - }, - "stock", - ) - - def test_process_crypto_data(self): - # Mock the return value for the api_client.get_data method - self.crypto_api_client.get_data.return_value = { - "BTC": {"price": 50000.0}, - "ETH": {"price": 4000.0}, - } - - # Call the method under test - self.crypto_engine.process_crypto_data() - - # Assert that the methods were called with the expected arguments - self.crypto_api_client.get_data.assert_called_once() - self.db_connector.store_data.assert_called_once_with( - {"BTC": {"price": 50000.0}, "ETH": {"price": 4000.0}}, "crypto" - ) - - -if __name__ == "__main__": - unittest.main()