# Runs the Black formatter check and mirrors the result onto the pull request
# as a "Formatter Failed" label.
name: Run Black Formatter

on:
  pull_request:
    types: [opened, synchronize, reopened]
  push:
    branches:
      - main

# The label steps below write to the issues API with the job token, so the
# token must be granted write access explicitly.
permissions:
  contents: read
  issues: write
  pull-requests: write

jobs:
  lint:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v3
      - uses: psf/black@stable

      # github.event.number only exists on pull_request events; on a push to
      # main there is no PR to label, so both label steps are skipped.
      - name: Add label if failure
        if: failure() && github.event_name == 'pull_request'
        run: |
          curl --request POST \
            --url "https://api.github.com/repos/${{ github.repository }}/issues/${{ github.event.number }}/labels" \
            --header "authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" \
            --header "Content-Type: application/json" \
            --data-raw '{"labels": ["Formatter Failed"]}'

      - name: Check and remove label if success
        if: success() && github.event_name == 'pull_request'
        run: |
          labels=$(curl -s \
            --request GET \
            --url "https://api.github.com/repos/${{ github.repository }}/issues/${{ github.event.number }}/labels" \
            --header "authorization: Bearer ${{ secrets.GITHUB_TOKEN }}")

          # Only issue the DELETE when the label is actually on the PR.
          if [[ $labels == *"Formatter Failed"* ]]; then
            curl --request DELETE \
              --url "https://api.github.com/repos/${{ github.repository }}/issues/${{ github.event.number }}/labels/Formatter%20Failed" \
              --header "authorization: Bearer ${{ secrets.GITHUB_TOKEN }}"
          fi
# Runs the unit tests and the DAG tests, and mirrors the result onto the pull
# request as a "Tests Failed" label.
name: Run Tests

on:
  pull_request:
    types: [opened, synchronize, reopened]
  push:
    branches:
      - main

# The label steps below write to the issues API with the job token, so the
# token must be granted write access explicitly.
permissions:
  contents: read
  issues: write
  pull-requests: write

jobs:
  test:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout repository
        # Same major version as the formatter workflow.
        uses: actions/checkout@v3

      - name: Set up Python
        uses: actions/setup-python@v2
        with:
          # Quoted so the version is always read as a string.
          python-version: "3.10.12"

      - name: Install dependencies
        run: |
          pip install --upgrade pip
          pip install requests
          pip install psycopg2-binary
          pip install python-dotenv
          pip install apache-airflow==2.8.1
          # Quoted: unquoted square brackets are a shell glob pattern.
          # Pinned so the extra cannot pull in a different Airflow release.
          pip install "apache-airflow[cncf.kubernetes]==2.8.1"
          pip install pandas
          pip install Flask-Session==0.5.0

      - name: Initialize Airflow database
        run: airflow db migrate

      - name: Run tests
        run: |
          python -m unittest discover tests
          python tests/dags_test.py

      # github.event.number only exists on pull_request events; on a push to
      # main there is no PR to label, so both label steps are skipped.
      - name: Add label if failure
        if: failure() && github.event_name == 'pull_request'
        run: |
          curl --request POST \
            --url "https://api.github.com/repos/${{ github.repository }}/issues/${{ github.event.number }}/labels" \
            --header "authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" \
            --header "Content-Type: application/json" \
            --data-raw '{"labels": ["Tests Failed"]}'

      - name: Check and remove label if present
        if: success() && github.event_name == 'pull_request'
        run: |
          labels=$(curl -s \
            --request GET \
            --url "https://api.github.com/repos/${{ github.repository }}/issues/${{ github.event.number }}/labels" \
            --header "authorization: Bearer ${{ secrets.GITHUB_TOKEN }}")

          # Only issue the DELETE when the label is actually on the PR.
          if [[ $labels == *"Tests Failed"* ]]; then
            curl --request DELETE \
              --url "https://api.github.com/repos/${{ github.repository }}/issues/${{ github.event.number }}/labels/Tests%20Failed" \
              --header "authorization: Bearer ${{ secrets.GITHUB_TOKEN }}"
          fi
The pipeline retrieves daily data for the top 5 stocks and top 5 cryptocurrencies based on market performance from Alpha Vantage, Financial Modeling Prep, and CoinMarketCap APIs and stores it in a PostgreSQL database. The pipeline is containerized using Docker and written in Python 3. +The pipeline follows object-oriented programming principles to ensure modularity, maintainability, and extensibility. Each component of the pipeline is designed as a separate class with well-defined responsibilities. + +Unit testing is implemented throughout the workflow to ensure the reliability and efficiency of the pipeline. These tests validate the functionality of each component and help identify any potential issues or bugs. + ## Project Components @@ -28,25 +32,123 @@ MarketTrackPipe is an automated Apache Airflow data pipeline for collecting and ``` - `core`: Contains core functionality for processing market data. +```mermaid +classDiagram +class BaseApiClient { + <> + +logger: logging.Logger + <> + +@abstractmethod get_data(): Dict[str, List[str]] +} +class StockApiClient { + +ALPHA_API_KEY: str + +PREP_API_KEY: str + +ALPHA_BASE_URL: str + +PREP_BASE_URL: str + +logger: logging.Logger + +get_stocks(): Dict[str, List[str]] + +get_data(symbols: Dict[str, List[str]]): Dict[str, List[Dict]] +} +class CryptoApiClient { + +COIN_API_KEY: str + +logger: logging.Logger + +get_data(): Dict[str, List[Dict]] +} +class Storage { + +host: str + +port: int + +database: str + +user: str + +password: str + +conn + +cur + +logger: logging.Logger + +_connect() + +_close() + +store_data(data: Dict[str, List[Dict[str, any]]], data_type: str): None +} +class MarketDataEngine { + +api_client: BaseApiClient + +db_connector: Storage + +logger: logging.Logger + +process_stock_data() + +process_crypto_data() +} +BaseApiClient <|-- StockApiClient +BaseApiClient <|-- CryptoApiClient +MarketDataEngine "1" --> "1" BaseApiClient +MarketDataEngine "1" --> "1" Storage +``` +
+ - `dags`: Contains the Apache Airflow DAG definitions for orchestrating the data collection and storage process. - `tests`: Contains the unit tests for testing individual components of the project. - `init.sql`: SQL script for creating and initializing the database schema. +```mermaid +graph TD; + subgraph DB + schema[market_data] + stock[stock_data] + crypto[crypto_data] + end + subgraph Fields + date_collected + symbol + name + market_cap + volume + price + change_percent + end + + schema --> |Schema| stock & crypto + + stock & crypto -->|Table| gainers & losers & actives + + gainers & losers & actives --> Fields +``` +
+ - `docker-compose.yml`: Defines the services and configures the project's containers, setting up the environment (postgres, pgadmin, airflow). The `MarketDataEngine` class within `core/market_data_processor.py` encapsulates the logic for retrieving and storing market data. The `market_data_dag.py` file within the `dags` directory sets up the Apache Airflow DAGs for collecting and storing market data. - -The `init.sql` defines two schemas in `market_data` database, one for `stock_data` and another for `crypto_data`, and then creates tables within each schema to store `gainer`, `loser`, and `active` data for both stock and crypto. - -The columns for each table are as follows: - -- `id` : a unique identifier for each row in the table -- `date_collected` : the date on which the data was collected, defaulting to the current date -- `symbol` : the stock or crypto symbol -- `name` : the name of the stock or crypto -- `market_cap` : the market capitalization of the stock or crypto -- `volume` : the trading volume of the stock or crypto -- `price` : the price of the stock or crypto -- `change_percent` : the percentage change in the price of the stock or crypto +
+```mermaid +graph TD; + subgraph MarketTrackPipe + A((Airflow)) + D(Docker) + P(PostgreSQL) + G(pgAdmin) + end + subgraph Core + MDE(MarketDataEngine) + SAPI(StockApiClient) + CAPI(CryptoApiClient) + STR(Storage) + end + subgraph Dags + MD_DAG_stocks(process_stock_data) + MD_DAG_crypto(process_crypto_data) + end + + D --> A & P & G + P --> G + A --> Dags + Dags --> MDE + MDE --> SAPI & CAPI + SAPI & CAPI --> API + API --> SAPI & CAPI + SAPI & CAPI --> STR + STR --> P + + style A fill:#f9f,stroke:#333,stroke-width:4px; + style D fill:#bbf,stroke:#333,stroke-width:2px; + style P fill:#f9f,stroke:#333,stroke-width:4px; + style MDE fill:#f9f,stroke:#333,stroke-width:4px; + style MD_DAG_stocks fill:#f9f,stroke:#333,stroke-width:4px; + style MD_DAG_crypto fill:#f9f,stroke:#333,stroke-width:4px; +``` ## Requirements @@ -81,7 +183,6 @@ The columns for each table are as follows: airflow trigger_dag data_collection_storage_crypto ``` - ## Setting up Pre-commit Hooks (Developer Setup) To ensure code quality and run unit tests before committing changes, MarketTrackPipe uses [pre-commit](https://pre-commit.com/) hooks. Follow these steps to set it up: @@ -97,9 +198,10 @@ To ensure code quality and run unit tests before committing changes, MarketTrack ```bash pre-commit install ``` - + This will install the pre-commit hook into your git repository.
+ 3. Now, every time you commit changes, pre-commit will automatically run unit tests to ensure code quality. Additionally, these tests are also executed in a GitHub Actions workflow on every pull request to the repository. ## Usage @@ -137,4 +239,3 @@ Additionally, a GitHub Action is configured to automatically run the black forma Please make sure to run `pip install pre-commit` and `pre-commit install` as mentioned in the setup instructions to enable the pre-commit hook on your local development environment. Contributors are encouraged to follow the black code style guidelines when making changes to the codebase. - diff --git a/core/market_data_processor.py b/core/market_data_processor.py index 5f8a6e7..5e07f26 100644 --- a/core/market_data_processor.py +++ b/core/market_data_processor.py @@ -6,14 +6,6 @@ from typing import Dict, List from abc import ABC, abstractmethod -# Configure logging -logging.basicConfig( - level=logging.WARNING, - format="[%(asctime)s] [%(levelname)s] [%(name)s] - %(message)s", -) - -logger = logging.getLogger(__name__) - class BaseApiClient(ABC): @@ -206,7 +198,7 @@ def get_data(self, symbols: Dict[str, List[str]]) -> Dict[str, List[Dict]]: class CryptoApiClient(BaseApiClient): - def __init__(self, COIN_API_KEY: str, logger: logging.Logger = None): + def __init__(self, COIN_API_KEY: str, logger: logging.Logger): super().__init__(logger=logger) self.COIN_API_KEY = COIN_API_KEY @@ -353,7 +345,7 @@ def __init__( database: str, user: str, password: str, - logger: logging.Logger = logger, + logger: logging.Logger, ): self.host = host self.port = port @@ -454,7 +446,7 @@ def __init__( self, api_client: BaseApiClient, db_connector: "Storage", - logger: logging.Logger = logger, + logger: logging.Logger, ): self.api_client = api_client self.db_connector = db_connector diff --git a/dags/market_data_dag.py b/dags/market_data_dag.py index c328324..755fd8e 100644 --- a/dags/market_data_dag.py +++ b/dags/market_data_dag.py @@ -4,6 +4,15 @@ from 
airflow.models import DAG from datetime import datetime from dotenv import load_dotenv +import logging + +# Configure logging +logging.basicConfig( + level=logging.WARNING, + format="[%(asctime)s] [%(levelname)s] [%(name)s] - %(message)s", +) + +logger = logging.getLogger(__name__) # Load environment variables from .env file load_dotenv() @@ -43,22 +52,23 @@ # Create instances of the classes stock_api_client = StockApiClient( - os.environ["ALPHA_API_KEY"], os.environ["PREP_API_KEY"] + os.environ["ALPHA_API_KEY"], os.environ["PREP_API_KEY"], logger ) -crypto_api_client = CryptoApiClient(os.environ["COIN_API_KEY"]) +crypto_api_client = CryptoApiClient(os.environ["COIN_API_KEY"], logger) db_connector = Storage( - os.getenv["POSTGRES_HOST"], - os.getenv["POSTGRES_PORT"], - os.getenv["POSTGRES_DB"], - os.getenv["POSTGRES_USER"], - os.getenv["POSTGRES_PASSWORD"], + os.getenv("POSTGRES_HOST"), + os.getenv("POSTGRES_PORT"), + os.getenv("POSTGRES_DB"), + os.getenv("POSTGRES_USER"), + os.getenv("POSTGRES_PASSWORD"), + logger, ) -stock_engine = MarketDataEngine(stock_api_client, db_connector) -crypto_engine = MarketDataEngine(crypto_api_client, db_connector) +stock_engine = MarketDataEngine(stock_api_client, db_connector, logger) +crypto_engine = MarketDataEngine(crypto_api_client, db_connector, logger) # Create the DAG for stock data collection and storage dag_stocks = DAG( - "data_collection_storage_stocks", + "process_stock_data", default_args=default_args_stocks, schedule_interval="0 23 * * 1-5", # Schedule to run everyday at 11 PM from Monday to Friday description="Collect and store stock data", @@ -66,7 +76,7 @@ # Create the DAG for cryptocurrency data collection and storage dag_cryptos = DAG( - "data_collection_storage_crypto", + "process_crypto_data", default_args=default_args_cryptos, schedule_interval="0 23 * * *", # Schedule to run everyday at 11 PM description="Collect and store cryptocurrency data", diff --git a/tests/dags_test.py b/tests/dags_test.py new file 
class TestMarketDataDag(unittest.TestCase):
    """
    Unit tests for the Market Data DAGs.

    The structural tests inspect the DAGs loaded from the project's ``dags``
    folder. The task tests run the DAG task callables with the API client and
    storage methods patched out.
    """

    @classmethod
    def setUpClass(cls):
        # The tests only read from the DagBag, so build it once for the class
        # rather than once per test method.
        cls.dagbag = DagBag(
            dag_folder=os.path.join(project_root, "dags"), include_examples=False
        )
        cls.stock_dag_id = "process_stock_data"
        cls.crypto_dag_id = "process_crypto_data"

    def test_dag_stocks_exists(self):
        self.assertIn(self.stock_dag_id, self.dagbag.dags)

    def test_dag_stocks_loaded(self):
        dag = self.dagbag.get_dag(self.stock_dag_id)
        self.assertDictEqual(self.dagbag.import_errors, {})
        self.assertIsNotNone(dag)
        self.assertEqual(len(dag.tasks), 1)

    def test_dag_stocks_schedule_interval(self):
        dag = self.dagbag.get_dag(self.stock_dag_id)
        self.assertEqual(dag.schedule_interval, "0 23 * * 1-5")

    # Stacked patch decorators are applied bottom-up, so the mock arguments
    # arrive in reverse decorator order: store_data, get_data, get_stocks.
    @patch.object(StockApiClient, "get_stocks")
    @patch.object(StockApiClient, "get_data")
    @patch.object(Storage, "store_data")
    def test_process_stock_data_task(
        self, mock_store_data, mock_get_data, mock_get_stocks
    ):
        # Setup mock behavior
        stocks = {"gainers": ["ABC"]}
        stock_data = {
            "gainers": [
                {
                    "symbol": "ABC",
                    "volume": "123456",
                    "price": "50.25",
                    "change_percent": "2.5",
                    "market_cap": "1.2B",
                    "name": "ABC Company",
                }
            ]
        }
        mock_get_stocks.return_value = stocks
        mock_get_data.return_value = stock_data

        # Wrap the DAG task's callable in a standalone operator and run it
        test = PythonOperator(
            task_id="get_stocks",
            python_callable=process_stock_data_task.python_callable,
        )
        test.execute(context={})

        # Check if the methods were called
        mock_get_stocks.assert_called_once()
        mock_get_data.assert_called_once()
        mock_store_data.assert_called_once()

    def test_dag_cryptos_exists(self):
        self.assertIn(self.crypto_dag_id, self.dagbag.dags)

    def test_dag_cryptos_loaded(self):
        dag = self.dagbag.get_dag(self.crypto_dag_id)
        self.assertDictEqual(self.dagbag.import_errors, {})
        self.assertIsNotNone(dag)
        self.assertEqual(len(dag.tasks), 1)

    def test_dag_cryptos_schedule_interval(self):
        dag = self.dagbag.get_dag(self.crypto_dag_id)
        self.assertEqual(dag.schedule_interval, "0 23 * * *")

    # The bottom decorator (store_data) supplies the first mock argument and
    # the top one (get_data) the second, so the parameters are named in that
    # order.
    @patch.object(CryptoApiClient, "get_data")
    @patch.object(Storage, "store_data")
    def test_process_crypto_data_task(self, mock_store_data, mock_get_crypto_data):
        # Payload in the category -> list-of-records shape the crypto client
        # returns
        mock_get_crypto_data.return_value = {
            "gainers": [
                {
                    "symbol": "BTC",
                    "name": "Bitcoin",
                    "volume": 10,
                    "price": 7,
                    "change_percent": 16,
                    "market_cap": 0,
                }
            ]
        }

        # Wrap the DAG task's callable in a standalone operator and run it
        test = PythonOperator(
            task_id="get_crypto",
            python_callable=process_crypto_data_task.python_callable,
        )
        test.execute(context={})

        # Check if the methods were called
        mock_get_crypto_data.assert_called_once()
        mock_store_data.assert_called_once()
class TestStockApiClient(unittest.TestCase):
    """
    Unit tests for the StockApiClient class.

    ``requests.get`` is patched in every test, so no network traffic is made.
    """

    def setUp(self):
        # logging.getLogger returns the same logger object on every call, so a
        # handler added here must be detached after each test; otherwise one
        # more StreamHandler piles up per test.
        self.logger = logging.getLogger(__name__)
        handler = logging.StreamHandler()
        self.logger.addHandler(handler)
        self.addCleanup(self.logger.removeHandler, handler)

        self.stock_api_client = StockApiClient(
            "alpha_api_key", "prep_api_key", logger=self.logger
        )

    @staticmethod
    def _mock_response(payload):
        """Return a mocked HTTP 200 response whose ``json()`` yields *payload*."""
        response = MagicMock()
        response.status_code = 200
        response.json.return_value = payload
        return response

    def test_init(self):
        self.assertEqual(self.stock_api_client.ALPHA_API_KEY, "alpha_api_key")
        self.assertEqual(self.stock_api_client.PREP_API_KEY, "prep_api_key")

    @patch("requests.get")
    def test_get_stocks(self, mock_get):
        mock_response = MagicMock()
        mock_response.json.return_value = [
            {"symbol": "ABC"},
            {"symbol": "DEF"},
            {"symbol": "GHI"},
            {"symbol": "JKL"},
            {"symbol": "MNO"},
        ]
        # Every request gets the same payload, so all categories match
        mock_get.return_value = mock_response

        stocks = self.stock_api_client.get_stocks()

        self.assertEqual(
            stocks,
            {
                "gainers": ["ABC", "DEF", "GHI", "JKL", "MNO"],
                "losers": ["ABC", "DEF", "GHI", "JKL", "MNO"],
                "actives": ["ABC", "DEF", "GHI", "JKL", "MNO"],
            },
        )

    @patch("requests.get")
    def test_get_stocks_empty_data(self, mock_get):
        # Mock the response from the API with empty data
        mock_get.return_value.json.return_value = []

        # Patch logger.error so any error the client logs stays out of the
        # test output, then expect the failure to propagate
        with patch.object(self.logger, "error"):
            with self.assertRaises(ValueError):
                self.stock_api_client.get_stocks()

    @patch("requests.get")
    def test_get_stocks_data_error(self, mock_get):
        # Mock the response from the API with records missing the "symbol" key
        mock_get.return_value.json.return_value = [{"invalid": "data"}]

        with patch.object(self.logger, "error"):
            with self.assertRaises(KeyError):
                self.stock_api_client.get_stocks()

    @patch("requests.get")
    def test_get_data(self, mock_get):
        # Successful Alpha Vantage response
        alpha_response = self._mock_response(
            {
                "Global Quote": {
                    "06. volume": "1000",
                    "05. price": "150.0",
                    "10. change percent": "5.0",
                }
            }
        )
        # Successful Financial Modeling Prep response
        prep_response = self._mock_response(
            [{"companyName": "Apple Inc.", "mktCap": "2000000"}]
        )
        # First request gets the Alpha Vantage payload, second the FMP payload
        mock_get.side_effect = [alpha_response, prep_response]

        # Call the method under test
        symbols = {"gainers": ["AAPL"]}
        stock_data = self.stock_api_client.get_data(symbols)

        # Assertions based on the mock data
        expected_stock_data = {
            "gainers": [
                {
                    "symbol": "AAPL",
                    "volume": "1000",
                    "price": "150.0",
                    "change_percent": "5.0",
                    "market_cap": "2000000",
                    "name": "Apple Inc.",
                }
            ]
        }
        self.assertEqual(stock_data, expected_stock_data)

    @patch("requests.get")
    def test_get_alpha_data_invalid_data(self, mock_get):
        symbols = {"gainers": ["AAPL"]}

        # HTTP 200 from Alpha Vantage, but the payload is keyed "Quote"
        # where the valid payload in test_get_data uses "Global Quote"
        alpha_response = self._mock_response(
            {
                "Quote": {
                    "06. volume": "1000",
                    "05. price": "150.0",
                    "10. change percent": "5.0",
                }
            }
        )
        # Well-formed Financial Modeling Prep response
        prep_response = self._mock_response(
            [{"companyName": "Apple Inc.", "mktCap": "2000000"}]
        )
        mock_get.side_effect = [alpha_response, prep_response]

        with patch.object(self.logger, "error"):
            with self.assertRaises(KeyError):
                self.stock_api_client.get_data(symbols)

    @patch("requests.get")
    def test_get_prep_data_invalid_data(self, mock_get):
        symbols = {"gainers": ["AAPL"]}

        # Well-formed Alpha Vantage response
        alpha_response = self._mock_response(
            {
                "Global Quote": {
                    "06. volume": "1000",
                    "05. price": "150.0",
                    "10. change percent": "5.0",
                }
            }
        )
        # HTTP 200 from Financial Modeling Prep, but the record lacks the
        # "companyName" and "mktCap" keys
        prep_response = self._mock_response([{"company": "Name"}])
        mock_get.side_effect = [alpha_response, prep_response]

        with patch.object(self.logger, "error"):
            with self.assertRaises(KeyError):
                self.stock_api_client.get_data(symbols)
class TestCryptoApiClient(unittest.TestCase):
    """
    Unit tests for the CryptoApiClient class.

    ``requests.get`` is patched in every test, so no network traffic is made.
    """

    def setUp(self):
        # logging.getLogger returns the same logger object on every call, so a
        # handler added here must be detached after each test; otherwise one
        # more StreamHandler piles up per test.
        self.logger = logging.getLogger(__name__)
        handler = logging.StreamHandler()
        self.logger.addHandler(handler)
        self.addCleanup(self.logger.removeHandler, handler)

        self.crypto_api_client = CryptoApiClient("coin_api_key", logger=self.logger)

    @patch("requests.get")
    def test_get_data(self, mock_get):
        # Mock the response from the API
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "data": [
                {
                    "name": "BREPE",
                    "symbol": "BREPE",
                    "quote": {
                        "USD": {
                            "price": 7,
                            "volume_24h": 10,
                            "percent_change_24h": 16,
                            "market_cap": 0,
                        }
                    },
                }
            ]
        }
        mock_get.return_value = mock_response

        # Call the method under test
        data = self.crypto_api_client.get_data()

        # Every request is answered with the same single coin, so each
        # category is expected to hold the same flattened record.
        expected_record = {
            "symbol": "BREPE",
            "name": "BREPE",
            "volume": 10,
            "price": 7,
            "change_percent": 16,
            "market_cap": 0,
        }
        expected_data = {
            "gainers": [expected_record],
            "losers": [expected_record],
            "actives": [expected_record],
        }

        self.assertEqual(data, expected_data)

    @patch("requests.get")
    def test_get_data_invalid_data(self, mock_get):
        # Mock the response from the API with invalid data (missing 'quote' key)
        mock_response = MagicMock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "data": [{"name": "BREPE", "symbol": "BREPE"}]
        }
        mock_get.return_value = mock_response

        # Patch logger.error so any error the client logs stays out of the
        # test output, then expect the failure to propagate
        with patch.object(self.logger, "error"):
            with self.assertRaises(KeyError):
                self.crypto_api_client.get_data()
class TestStorage(unittest.TestCase):
    """
    Unit tests for the Storage class.

    ``psycopg2.connect`` is patched in every test that could open a
    connection, so no database is required.
    """

    def setUp(self):
        # logging.getLogger returns the same logger object on every call, so a
        # handler added here must be detached after each test; otherwise one
        # more StreamHandler piles up per test.
        self.logger = logging.getLogger(__name__)
        handler = logging.StreamHandler()
        self.logger.addHandler(handler)
        self.addCleanup(self.logger.removeHandler, handler)

        self.storage = Storage(
            "localhost", 5432, "test_db", "user", "password", logger=self.logger
        )

    @patch("psycopg2.connect")
    def test_connect(self, mock_connect):
        # Mock the cursor and its methods
        mock_conn = mock_connect.return_value
        mock_cur = MagicMock()
        mock_conn.cursor.return_value = mock_cur

        self.storage._connect()

        mock_connect.assert_called_once_with(
            host="localhost",
            port=5432,
            database="test_db",
            user="user",
            password="password",
        )
        self.assertEqual(self.storage.conn, mock_conn)
        self.assertEqual(self.storage.cur, mock_cur)

        # Ensure the connection object supports context management
        with self.storage.conn:
            pass  # This should not raise an AttributeError

    @patch("psycopg2.connect")
    def test_close(self, mock_connect):
        mock_conn = mock_connect.return_value
        mock_cur = MagicMock()
        mock_conn.cursor.return_value = mock_cur

        self.storage._connect()
        self.storage._close()

        mock_cur.close.assert_called_once()
        mock_conn.close.assert_called_once()

    @patch("psycopg2.connect")
    def test_store_data_with_valid_data(self, mock_connect):
        # Mock the cursor and its methods
        mock_conn = mock_connect.return_value
        mock_cur = MagicMock()
        mock_conn.cursor.return_value = mock_cur

        # Test case with valid data
        data = {
            "gainers": [
                {
                    "symbol": "ABC",
                    "name": "ABC Corp",
                    "volume": 1000,
                    "price": 10.0,
                    "market_cap": 1000000,
                    "change_percent": 5.0,
                }
            ]
        }
        data_type = "stock"

        self.storage.store_data(data, data_type)

        # assert_called_once_with already implies exactly one call, so no
        # separate assert_called_once() is needed for execute
        mock_cur.execute.assert_called_once_with(
            "INSERT INTO stock_data.gainers (symbol, name, market_cap, volume, price, change_percent) VALUES (%s, %s, %s, %s, %s, %s)",
            ("ABC", "ABC Corp", 1000000, 1000, 10.0, 5.0),
        )
        mock_conn.commit.assert_called_once()

    @patch("psycopg2.connect")
    def test_store_data_with_invalid_data_empty_symbol(self, mock_connect):
        # Test case with invalid data (empty symbol)
        data_invalid = {
            "gainers": [
                {
                    "symbol": "",
                    "name": "ABC Corp",
                    "volume": 1000,
                    "price": 10.0,
                    "market_cap": 1000000,
                    "change_percent": 5.0,
                }
            ]
        }
        data_type = "stock"

        # Patch logger.error so any error Storage logs stays out of the test
        # output, then expect the failure to propagate
        with patch.object(self.logger, "error"):
            with self.assertRaises(ValueError):
                self.storage.store_data(data_invalid, data_type)

    @patch("psycopg2.connect")
    def test_store_data_with_invalid_data_type(self, mock_connect):
        # Test case with invalid data type
        data = {
            "gainers": [
                {
                    "symbol": "ABC",
                    "name": "ABC Corp",
                    "volume": 1000,
                    "price": 10.0,
                    "market_cap": 1000000,
                    "change_percent": 5.0,
                }
            ]
        }
        data_type_invalid = 123  # Invalid data type (not a string)

        with patch.object(self.logger, "error"):
            with self.assertRaises(TypeError):
                self.storage.store_data(data, data_type_invalid)
class TestMarketDataEngine(unittest.TestCase):
    """
    Unit tests for the MarketDataEngine class.

    The API clients, the storage layer and the logger are all spec'd mocks, so
    these tests only check how the engine wires its collaborators together.
    """

    def setUp(self):
        self.stock_api_client = MagicMock(spec=StockApiClient)
        self.crypto_api_client = MagicMock(spec=CryptoApiClient)
        self.db_connector = MagicMock(spec=Storage)
        self.logger = MagicMock(spec=logging.Logger)
        self.stock_engine = MarketDataEngine(
            self.stock_api_client, self.db_connector, self.logger
        )
        self.crypto_engine = MarketDataEngine(
            self.crypto_api_client, self.db_connector, self.logger
        )

    def test_process_stock_data(self):
        # Fixtures use the category -> list shapes the stock client produces
        symbols = {"gainers": ["AAPL"], "losers": ["GOOG"], "actives": ["MSFT"]}
        stock_data = {
            "gainers": [{"symbol": "AAPL", "price": 150.0}],
            "losers": [{"symbol": "GOOG", "price": 2000.0}],
            "actives": [{"symbol": "MSFT", "price": 300.0}],
        }
        self.stock_api_client.get_stocks.return_value = symbols
        self.stock_api_client.get_data.return_value = stock_data

        # Call the method under test
        self.stock_engine.process_stock_data()

        # Symbols from get_stocks are passed to get_data, and its result is
        # stored under the "stock" data type
        self.stock_api_client.get_stocks.assert_called_once()
        self.stock_api_client.get_data.assert_called_once_with(symbols)
        self.db_connector.store_data.assert_called_once_with(stock_data, "stock")

    def test_process_crypto_data(self):
        # Fixture uses the category -> list-of-records shape the crypto client
        # produces
        crypto_data = {
            "gainers": [{"symbol": "BTC", "price": 50000.0}],
            "losers": [{"symbol": "ETH", "price": 4000.0}],
        }
        self.crypto_api_client.get_data.return_value = crypto_data

        # Call the method under test
        self.crypto_engine.process_crypto_data()

        # The fetched data is stored under the "crypto" data type
        self.crypto_api_client.get_data.assert_called_once()
        self.db_connector.store_data.assert_called_once_with(crypto_data, "crypto")