diff --git a/app/__init__.py b/app/__init__.py
index a7f7092..de0472e 100644
--- a/app/__init__.py
+++ b/app/__init__.py
@@ -5,7 +5,12 @@ from app import pages, transactions, data, predict
 from app.models import db
 from app.database import init_db_command
 
-from app.data import create_dash_app
+from app.dashboard import (
+    create_summary_dash_app,
+    create_fraud_trends_dash_app,
+    create_geo_analysis_dash_app,
+    create_device_browser_dash_app
+)
 
 load_dotenv()
 
@@ -31,10 +36,13 @@ def create_app():
     app.register_blueprint(transactions.bp)
     app.register_blueprint(data.bp)
     app.register_blueprint(predict.bp)
-    #app.register_blueprint(train.bp)
 
-    # Initialize the Dash app, passing the Flask app as the server
-    create_dash_app(app)
+    # Initialize Dash apps and pass the `db` object for querying the database
+    create_summary_dash_app(app, db)
+    create_fraud_trends_dash_app(app, db)
+    create_geo_analysis_dash_app(app, db)
+    create_device_browser_dash_app(app, db)
+
 
     # Add CLI command
     app.cli.add_command(init_db_command)
diff --git a/app/dashboard.py b/app/dashboard.py
new file mode 100644
index 0000000..f399765
--- /dev/null
+++ b/app/dashboard.py
@@ -0,0 +1,131 @@
+import pandas as pd
+import plotly.express as px
+from dash import Dash, dcc, html, Input, Output
+from sqlalchemy import create_engine, func
+from app.models import Features
+
+
+def create_summary_dash_app(flask_app, db):
+    """Create the Dash app served at ``/summary/`` with summary statistics.
+
+    The layout is a function, so the statistics are queried again on every
+    page load instead of staying frozen at the values read when the Flask
+    app was created.
+    """
+    dash_app = Dash(__name__, server=flask_app, url_base_pathname='/summary/')
+
+    def serve_layout():
+        with flask_app.app_context():
+            # Query for summary statistics from the 'features' table
+            total_records = db.session.query(func.count(Features.user_id)).scalar()
+            avg_purchase_value = db.session.query(func.avg(Features.purchase_value)).scalar()
+
+        # AVG() yields NULL (None) for an empty table and None cannot be
+        # formatted with ':.2f', so show a placeholder instead.
+        if avg_purchase_value is None:
+            avg_text = 'N/A'
+        else:
+            avg_text = f"{avg_purchase_value:.2f}"
+
+        return html.Div([
+            html.H3('Summary Statistics'),
+            html.P(f"Total Records: {total_records}"),
+            html.P(f"Average Purchase Value: {avg_text}")
+        ])
+
+    dash_app.layout = serve_layout
+
+    return dash_app
+
+
+def create_fraud_trends_dash_app(flask_app, db):
+    """Create the Dash app served at ``/fraud_trends/`` (fraud cases over time)."""
+    dash_app = Dash(__name__, server=flask_app, url_base_pathname='/fraud_trends/')
+
+    dash_app.layout = html.Div([
+        html.H3('Fraud Trends Over Time'),
+        dcc.Graph(id='fraud-trends-graph')
+    ])
+
+    @dash_app.callback(
+        Output('fraud-trends-graph', 'figure'),
+        [Input('fraud-trends-graph', 'id')]
+    )
+    def update_fraud_trends(_):
+        with flask_app.app_context():
+            # Query for fraud trends over time
+            fraud_trends = (
+                db.session.query(Features.purchase_time, func.count(Features.user_id).label('fraud_cases'))
+                .filter(Features.class_ == 1)  # Assuming 'class' indicates fraud status
+                .group_by(Features.purchase_time)
+                # px.line joins points in row order, so sort chronologically.
+                .order_by(Features.purchase_time)
+                .all()
+            )
+            fraud_trends_df = pd.DataFrame(fraud_trends, columns=['purchase_time', 'fraud_cases'])
+
+        fig = px.line(fraud_trends_df, x='purchase_time', y='fraud_cases', title='Fraud Cases Over Time')
+        return fig
+
+    return dash_app
+
+
+def create_geo_analysis_dash_app(flask_app, db):
+    """Create the Dash app served at ``/geo_analysis/`` (fraud cases by country)."""
+    dash_app = Dash(__name__, server=flask_app, url_base_pathname='/geo_analysis/')
+
+    dash_app.layout = html.Div([
+        html.H3('Geographical Analysis of Fraud'),
+        dcc.Graph(id='geo-analysis-graph')
+    ])
+
+    @dash_app.callback(
+        Output('geo-analysis-graph', 'figure'),
+        [Input('geo-analysis-graph', 'id')]
+    )
+    def update_geo_analysis(_):
+        with flask_app.app_context():
+            # Query for geographical fraud data
+            geo_data = (
+                db.session.query(Features.country, func.count(Features.ip_address).label('ip_address_count'))
+                # Count fraud rows only, as the chart title says.
+                .filter(Features.class_ == 1)
+                .group_by(Features.country)
+                .all()
+            )
+            geo_data_df = pd.DataFrame(geo_data, columns=['country', 'ip_address_count'])
+
+        fig = px.bar(geo_data_df, x='ip_address_count', y='country', title='Fraud Cases by Country')
+        return fig
+
+    return dash_app
+
+
+def create_device_browser_dash_app(flask_app, db):
+    """Create the Dash app served at ``/device_browser/`` (fraud by device and browser)."""
+    dash_app = Dash(__name__, server=flask_app, url_base_pathname='/device_browser/')
+
+    dash_app.layout = html.Div([
+        html.H3('Device and Browser Analysis'),
+        dcc.Graph(id='device-browser-graph')
+    ])
+
+    @dash_app.callback(
+        Output('device-browser-graph', 'figure'),
+        [Input('device-browser-graph', 'id')]
+    )
+    def update_device_browser_analysis(_):
+        with flask_app.app_context():
+            # Query for device and browser data
+            device_browser_data = (
+                db.session.query(Features.device_id, Features.browser, func.count(Features.class_).label('fraud_cases'))
+                .filter(Features.class_ == 1)  # Assuming 'class' indicates fraud status
+                .group_by(Features.device_id, Features.browser)
+                .all()
+            )
+            device_browser_df = pd.DataFrame(device_browser_data, columns=['device_id', 'browser', 'fraud_cases'])
+
+        fig = px.bar(device_browser_df, x='device_id', y='fraud_cases', color='browser', title='Fraud by Device and Browser')
+        return fig
+
+    return dash_app
diff --git a/app/data.py b/app/data.py
index f29dfd8..5741ea4 100644
--- a/app/data.py
+++ b/app/data.py
@@ -15,6 +15,7 @@ from mlflow import pyfunc
 
 import xgboost as xgb
 from dash import Dash, dcc, html, Input, Output
+from flask import Blueprint
 
 load_dotenv()
 
@@ -108,36 +109,3 @@ def result():
     print("Prediction: ", prediction)
 
     return render_template('data/result.html', prediction=prediction, ID=ID)
-
-from flask import Blueprint
-from dash import Dash
-import dash_core_components as dcc
-import dash_html_components as html
-
-# Create a Flask Blueprint
-bp = Blueprint("data", __name__)
-
-# Do not link Dash directly to the Blueprint
-# Instead, define Dash within a function where you have access to the main Flask app
-
-def create_dash_app(flask_app):
-    # Create the Dash app and pass the Flask app as the server
-    dash_app = Dash(__name__, server=flask_app, url_base_pathname='/dash/')
-
-    # Define Dash layout
-    dash_app.layout = html.Div(children=[
-        html.H1('Dash App'),
-        dcc.Graph(
-            id='example-graph',
-            figure={
-                'data': [
-                    {'x': [1, 2, 3], 'y': [4, 1, 2], 'type': 'bar', 'name': 'Sample'}
-                ],
-                'layout': {
-                    'title': 'Sample Data Visualization'
-                }
-            }
-        )
-    ])
-
-    return dash_app
diff --git a/app/models.py b/app/models.py
index 27ee3b1..52ddbe9 100644
--- a/app/models.py
+++ b/app/models.py
@@ -20,3 +20,25 @@ class Transaction(db.Model):
     upper_bound_ip_address = db.Column(db.Float, nullable=False)
     country = db.Column(db.String, nullable=False)
 
+
+class Features(db.Model):
+    """ORM model for the ``features`` table read by the dashboard queries."""
+
+    __tablename__ = 'features'
+
+    user_id = db.Column(db.Integer, primary_key=True)
+    signup_time = db.Column(db.DateTime)
+    purchase_time = db.Column(db.DateTime)
+    purchase_value = db.Column(db.Float)
+    device_id = db.Column(db.String)
+    source = db.Column(db.String)
+    browser = db.Column(db.String)
+    sex = db.Column(db.String)
+    age = db.Column(db.Integer)
+    ip_address = db.Column(db.String)
+    # NOTE(review): with no explicit name the database column is called
+    # "class_", not "class" - confirm this matches the existing table.
+    class_ = db.Column(db.Integer)  # Fraud class: 0 = non-fraud, 1 = fraud
+    lower_bound_ip_address = db.Column(db.BigInteger)
+    upper_bound_ip_address = db.Column(db.BigInteger)
+    country = db.Column(db.String)