-
Notifications
You must be signed in to change notification settings - Fork 2
/
application.py
133 lines (98 loc) · 3.79 KB
/
application.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
from fastapi import FastAPI, Request, Response, HTTPException
from fastapi.responses import JSONResponse, FileResponse
from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles
from fastapi.staticfiles import StaticFiles
import os
from pydantic import BaseModel
import os
import glob
import shutil
from downloader import Downloader
from demucs_processor import DemucsProcessor
from spotify_to_yt import ConvertSpofity
app = FastAPI()

# Resolve this module's directory so static assets and templates are found
# regardless of the process's current working directory.
current_dir = os.path.dirname(os.path.abspath(__file__))

# Serve files from ./static (next to this module) under the /static URL prefix.
app.mount(
    "/static", StaticFiles(directory=os.path.join(current_dir, "static")), name="static"
)

# Jinja2 templates used to render the HTML pages.
templates = Jinja2Templates(directory=os.path.join(current_dir, "templates"))

# Module-level worker objects shared by all request handlers.
# NOTE(review): unlike static/ and templates/, the handlers use paths such as
# "tracks/..." and "*.mp3" relative to the working directory, not current_dir.
demucs_processor = DemucsProcessor(num_threads=4, segment_size=7)
downloader = Downloader()
class DownloadRequest(BaseModel):
    """Request body for ``POST /download_video``.

    ``url`` is either a Spotify link, which is converted to a YouTube URL, or
    any other URL, which is passed to the downloader unchanged. ``filetype``
    is forwarded as-is to ``Downloader.download_video``.
    """

    # Source URL; a value containing "spotify" triggers Spotify->YouTube lookup.
    url: str
    # Passed through to the downloader; accepted values are defined there.
    filetype: str
class ProcessRequest(BaseModel):
    """Request body for ``POST /process_audio``.

    All three fields are forwarded unchanged to
    ``DemucsProcessor.process_audio``.
    """

    # Name of the previously downloaded file (as returned by /download_video).
    filename: str
    # Audio file type; its accepted values are defined by DemucsProcessor.
    filetype: str
    # Requested number of stems; presumably selects the model variant
    # (tracks/htdemucs vs tracks/htdemucs_6s) — confirm in DemucsProcessor.
    numStems: int
@app.get("/")
async def home(request: Request):
    """Render the landing page after clearing leftovers from earlier runs.

    Every visit calls ``refresh_directories()``, which deletes the per-song
    folders under ``tracks/htdemucs`` and ``tracks/htdemucs_6s`` plus any
    ``.mp3``/``.wav``/``.flac`` files in the working directory.
    """
    # NOTE(review): this wipes results for every client on each page load;
    # confirm that is intended if more than one user can reach the server.
    refresh_directories()
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
@app.get("/delete")
async def delete(request: Request):
    """Render ``index.html`` for ``GET /delete``.

    Despite the route name, nothing is deleted here: the handler does not call
    ``refresh_directories()`` and only re-renders the landing page.
    """
    # NOTE(review): confirm whether this route was meant to perform cleanup.
    page_context = {"request": request}
    return templates.TemplateResponse("index.html", page_context)
@app.post("/download_video")
async def download_audio(request: DownloadRequest):
    """Download the media at ``request.url`` and report the saved filename.

    URLs containing ``"spotify"`` are first converted to a YouTube URL via
    ``ConvertSpofity``; any other URL is handed to the downloader unchanged.

    Returns:
        ``{"status": "success", "filename": <str>}`` on success.

    Raises:
        HTTPException: 400 if no downloadable URL could be determined (an
            empty URL, or a Spotify link for which no YouTube URL was found).
    """
    input_url = request.url
    if "spotify" in input_url:
        url = ConvertSpofity(input_url).get_youtube_url()
    else:
        url = input_url

    # Without this guard a falsy url leaves `filename` unassigned and the
    # return below raises UnboundLocalError, surfacing as an opaque 500.
    if not url:
        raise HTTPException(
            status_code=400, detail="Could not resolve a downloadable URL"
        )

    # NOTE(review): the download (and Spotify lookup) presumably block; running
    # them in an async handler stalls the event loop for their duration.
    filename = downloader.download_video(url, request.filetype)
    print("filename", filename)
    return {"status": "success", "filename": str(filename)}
@app.post("/process_audio")
async def process_audio(request: ProcessRequest):
    """Run stem separation on a previously downloaded file.

    Forwards ``filename``, ``filetype`` and ``numStems`` to
    ``demucs_processor.process_audio`` and echoes the filename back once that
    call returns.
    """
    name = request.filename
    demucs_processor.process_audio(name, request.filetype, request.numStems)
    return {"message": "Finished", "filename": str(name)}
@app.get("/download")
async def download(filename: str):
    """Send ``<filename>.zip`` to the client as an attachment.

    ``filename`` comes straight from the query string, so the resolved path
    must stay inside the current working directory, which is where the
    relative ``<filename>.zip`` is looked up. Paths that escape it (``..``
    components, absolute paths) are treated as missing.

    Raises:
        HTTPException: 404 if the archive does not exist, is not a regular
            file, or lies outside the working directory.
    """
    base_dir = os.path.realpath(os.getcwd())
    file_path = f"{filename}.zip"
    resolved = os.path.realpath(file_path)
    try:
        inside_base = os.path.commonpath([base_dir, resolved]) == base_dir
    except ValueError:  # e.g. paths on different drives (Windows)
        inside_base = False
    # isfile (not exists): a directory named "<x>.zip" cannot be streamed.
    if not inside_base or not os.path.isfile(resolved):
        raise HTTPException(status_code=404, detail="File not found")
    return FileResponse(resolved, filename=os.path.basename(file_path))
@app.get("/tracks/{stem_type}/{songname}")
async def serve_audio(stem_type: str, songname: str):
    """Return a JSON list of the entries in ``tracks/<stem_type>/<songname>``.

    Both path parameters are client-controlled, so the target is resolved and
    must remain inside the ``tracks`` directory (relative to the working
    directory). ``..`` segments that escape it are answered with 404.

    Raises:
        HTTPException: 404 if the path does not exist, is not a directory, or
            lies outside ``tracks``.
    """
    tracks_root = os.path.realpath("tracks")
    directory = os.path.realpath(os.path.join(tracks_root, stem_type, songname))
    try:
        contained = os.path.commonpath([tracks_root, directory]) == tracks_root
    except ValueError:  # e.g. paths on different drives (Windows)
        contained = False
    # isdir (not exists): a regular file here would make os.listdir raise.
    if not contained or not os.path.isdir(directory):
        raise HTTPException(status_code=404, detail="Directory not found")
    return JSONResponse(content=os.listdir(directory))
@app.get("/tracks/{stem_type}/{songname}/{filename}")
async def serve_file(stem_type: str, songname: str, filename: str):
    """Stream ``tracks/<stem_type>/<songname>/<filename>`` to the client.

    All path parameters are client-controlled, so the resolved file must stay
    inside the ``tracks`` directory (relative to the working directory).
    Anything that escapes it via ``..`` segments is reported as not found.

    Raises:
        HTTPException: 404 if the file does not exist, is not a regular file,
            or lies outside ``tracks``.
    """
    tracks_root = os.path.realpath("tracks")
    file_path = os.path.realpath(
        os.path.join(tracks_root, stem_type, songname, filename)
    )
    try:
        contained = os.path.commonpath([tracks_root, file_path]) == tracks_root
    except ValueError:  # e.g. paths on different drives (Windows)
        contained = False
    # isfile (not exists): FileResponse cannot stream a directory.
    if not contained or not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="File not found")
    return FileResponse(file_path)
@app.get("/login")
async def login(request: Request):
    """Render the login page from the ``login.html`` template."""
    ctx = {"request": request}
    return templates.TemplateResponse("login.html", ctx)
@app.get("/register")
async def register(request: Request):
    """Render the sign-up page from the ``register.html`` template."""
    ctx = {"request": request}
    return templates.TemplateResponse("register.html", ctx)
def refresh_directories():
    """Delete cached separation output and downloaded audio files.

    Removes every sub-directory of ``tracks/htdemucs`` and
    ``tracks/htdemucs_6s``; plain files directly inside them are kept. Then
    deletes all ``*.mp3``, ``*.wav`` and ``*.flac`` files in the current
    working directory. All paths are relative to the working directory, not
    to this module's location.
    """
    for pattern in ("tracks/htdemucs/*", "tracks/htdemucs_6s/*"):
        for entry in glob.glob(pattern):
            if not os.path.isdir(entry):
                continue
            shutil.rmtree(entry)

    # Collect every match up front, then delete.
    audio_files = [
        path
        for audio_pattern in ("*.mp3", "*.wav", "*.flac")
        for path in glob.glob(audio_pattern)
    ]
    for audio in audio_files:
        os.remove(audio)
@app.get("/flaskwebgui-keep-server-alive")
async def keep_alive():
    """Liveness endpoint that always answers with a fixed string.

    The route name suggests a flaskwebgui launcher polls it to keep the server
    running (assumption based on the path — confirm against the GUI wrapper).
    """
    message = "Server is alive"
    return message
if __name__ == "__main__":
    import uvicorn

    # Bind on all interfaces; the port is taken from $PORT, defaulting to 8001.
    listen_port = int(os.environ.get("PORT", 8001))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)