-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
133 lines (104 loc) · 3.45 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import models
import yfinance
from fastapi import FastAPI, Request, Depends, BackgroundTasks
from fastapi.templating import Jinja2Templates
from database import SessionLocal, engine
from pydantic import BaseModel
from models import Stock
from sqlalchemy.orm import Session
# The ASGI application object; the route decorators below register on it.
app = FastAPI()
# Issue CREATE TABLE for the models registered on models.Base using the
# shared engine. This runs at import time, before any request is served.
models.Base.metadata.create_all(bind=engine)
# Jinja2 template loader rooted at the "templates" directory.
templates = Jinja2Templates(directory="templates")
class StockRequest(BaseModel):
    """
    Request body accepted by the POST endpoints: a single ticker symbol.
    """

    # Ticker symbol to add or refresh.
    symbol: str
def get_db():
    """
    Yield a database session and close it once the caller is done.

    Written as a generator so it can be used with ``Depends(get_db)``:
    the code after ``yield`` runs when the generator is finalised, which
    guarantees the session is closed even if the consumer raised.

    Yields:
        A session created by ``SessionLocal``.
    """
    # Create the session *before* entering the try block. If SessionLocal()
    # itself fails there is nothing to close, and the finally clause must not
    # replace the real error with an UnboundLocalError on ``db``.
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()
@app.get('/')
def home(request: Request, symbol=None, forward_pe=None, sector=None, order=None, db: Session = Depends(get_db)):
    """
    Show all stocks in the database.

    Optional filters (each applied only when a truthy value is given):
      - symbol: keep rows whose symbol equals this value.
      - forward_pe: keep rows whose forward P/E is greater than this value.
      - sector: keep rows whose sector equals this value.

    ``order`` selects the sort: 'Market Cap' (also used when ``order`` is
    omitted), 'Country', 'Sector', 'P/E' or 'Volume'. Any other value
    leaves the rows unsorted.

    Returns the rendered "home.html" template with the request, the list of
    matching stocks and the symbol filter in its context.
    """
    query = db.query(Stock)
    if symbol:
        query = query.filter(Stock.symbol == symbol)
    if forward_pe:
        query = query.filter(Stock.forward_pe > forward_pe)
    if sector:
        query = query.filter(Stock.sector == sector)

    # Sort expression for every supported ``order`` value.
    orderings = {
        'Market Cap': Stock.marketcap.desc(),
        'Country': Stock.country,
        'Sector': Stock.sector,
        'P/E': Stock.forward_pe.desc(),
        'Volume': Stock.volume.desc(),
    }
    ordering = orderings.get('Market Cap' if order is None else order)
    # Compare against None explicitly: SQL expressions do not support
    # truth-value testing.
    if ordering is not None:
        query = query.order_by(ordering)

    # Always materialise the result. Previously an unrecognised ``order``
    # skipped every branch and handed the template a lazy Query object
    # instead of a list.
    stocks = query.all()

    return templates.TemplateResponse("home.html", {
        "request": request,
        "stocks": stocks,
        "symbol": symbol
    })
def fetch_stock_data(symbol: str):
    """
    Background task to use yfinance and load key statistics.

    Looks up the stored stock row for ``symbol``, fetches its Yahoo Finance
    info and copies name, country, sector, moving averages, price, forward
    P/E, market cap and volume onto the row.

    If the row has no symbol, or Yahoo reports no regular market price for
    it, the row is deleted instead. If no row exists for ``symbol`` nothing
    is done.

    Fields missing from the Yahoo payload are stored as None.

    Returns:
        None in every case; the result is the database update itself.
    """
    db = SessionLocal()
    try:
        stock = db.query(Stock).filter(Stock.symbol == symbol).first()
        if stock is None:
            # Nothing stored under this symbol, so there is no row to update
            # or delete.
            return None

        # Only contact Yahoo when there is a symbol to look up.
        info = (yfinance.Ticker(stock.symbol).info or {}) if stock.symbol else {}

        # A missing price is treated like an explicit None: the ticker is
        # considered unknown and the placeholder row is removed.
        if info.get('regularMarketPrice') is None:
            db.delete(stock)
            db.commit()
            return None

        # .get() keeps one absent key (e.g. no forward P/E for a fund) from
        # aborting the whole update with a KeyError.
        stock.name = info.get('shortName')
        stock.country = info.get('country')
        stock.sector = info.get('sector')
        stock.ma200 = info.get('twoHundredDayAverage')
        stock.ma50 = info.get('fiftyDayAverage')
        stock.price = info.get('previousClose')
        stock.forward_pe = info.get('forwardPE')
        stock.marketcap = info.get('marketCap')
        stock.volume = info.get('volume')

        db.add(stock)
        db.commit()
        return None
    finally:
        # This task owns its session (it does not go through get_db), so it
        # must close it itself on every exit path.
        db.close()
@app.post("/stock")
async def create_stock(stock_request: StockRequest, db: Session = Depends(get_db)):
    """
    Store a new stock row for the ticker symbol in the request body.

    Only the symbol is saved here; no market data is fetched.

    Returns None when the symbol is missing or empty, otherwise a dict with
    "code" and "message" keys reporting success.
    """
    ticker = stock_request.symbol
    if ticker in (None, ""):
        return None

    new_stock = Stock()
    new_stock.symbol = ticker
    db.add(new_stock)
    db.commit()

    response = {
        "code": "success",
        "message": "stock was added to the database"
    }
    return response
@app.post("/fetchall")
async def fetchall(stock_request: StockRequest, background_tasks: BackgroundTasks):
    """
    Schedule a yfinance refresh for the symbol in the request body.

    The refresh itself (fetch_stock_data) runs as a background task, so the
    response is returned without waiting for it.

    Returns None when the symbol is missing or empty, otherwise a dict with
    "code" and "message" keys reporting success.
    """
    ticker = stock_request.symbol
    if ticker is not None and ticker != "":
        background_tasks.add_task(fetch_stock_data, ticker)
        return {"code": "success", "message": "fetch all stocks"}
    return None