From 0f3226504384449ea3a80aa820b7e477914dd777 Mon Sep 17 00:00:00 2001
From: Do1e <dpj.email@qq.com>
Date: Tue, 28 Nov 2023 19:09:54 +0800
Subject: [PATCH] fix: #94 add s3_proxy and onedrive_proxy config

---
 core/settings.py |  2 ++
 core/storage.py  | 54 +++++++++++++++++++++++++++++++++++++++++-------
 2 files changed, 49 insertions(+), 7 deletions(-)

diff --git a/core/settings.py b/core/settings.py
index 17937293..2d748cab 100644
--- a/core/settings.py
+++ b/core/settings.py
@@ -18,12 +18,14 @@
     's3_secret_access_key': '',
     's3_bucket_name': '',
     's3_endpoint_url': '',
+    's3_proxy': 0,
     'aws_session_token': '',
     'onedrive_domain': '',
     'onedrive_client_id': '',
     'onedrive_username': '',
     'onedrive_password': '',
     'onedrive_root_path': 'filebox_storage',
+    'onedrive_proxy': 0,
     'admin_token': 'FileCodeBox2023',
     'openUpload': 1,
     'uploadSize': 1024 * 1024 * 10,
diff --git a/core/storage.py b/core/storage.py
index 3bec7a4a..446adcd9 100644
--- a/core/storage.py
+++ b/core/storage.py
@@ -2,9 +2,11 @@
 # @Author  : Lan
 # @File    : storage.py
 # @Software: PyCharm
+import aiohttp
 import asyncio
 from pathlib import Path
 import datetime
+import io
 import re
 import sys
 import aioboto3
@@ -91,6 +93,7 @@ def __init__(self):
         self.bucket_name = settings.s3_bucket_name
         self.endpoint_url = settings.s3_endpoint_url
         self.aws_session_token = settings.aws_session_token
+        self.proxy = settings.s3_proxy
         self.session = aioboto3.Session(aws_access_key_id=self.access_key_id, aws_secret_access_key=self.secret_access_key)
 
     async def save_file(self, file: UploadFile, save_path: str):
@@ -101,12 +104,32 @@ async def delete_file(self, file_code: FileCodes):
         async with self.session.client("s3", endpoint_url=self.endpoint_url) as s3:
             await s3.delete_object(Bucket=self.bucket_name, Key=await file_code.get_file_path())
 
+    async def get_file_response(self, file_code: FileCodes):
+        try:
+            filename = file_code.prefix + file_code.suffix
+            async with self.session.client("s3", endpoint_url=self.endpoint_url) as s3:
+                link = await s3.generate_presigned_url('get_object', Params={'Bucket': self.bucket_name, 'Key': await file_code.get_file_path()}, ExpiresIn=3600)
+            tmp = io.BytesIO()
+            async with aiohttp.ClientSession() as session:
+                async with session.get(link) as resp:
+                    tmp.write(await resp.read())
+            tmp.seek(0)
+            content = tmp.read()
+            tmp.close()
+            return Response(content, media_type="application/octet-stream", headers=
+                            {"Content-Disposition": f'attachment; filename="{filename.encode("utf-8").decode("latin-1")}"'})
+        except Exception:
+            raise HTTPException(status_code=503, detail='服务代理下载异常,请稍后再试')
+
     async def get_file_url(self, file_code: FileCodes):
         if file_code.prefix == '文本分享':
             return file_code.text
-        async with self.session.client("s3", endpoint_url=self.endpoint_url) as s3:
-            result = await s3.generate_presigned_url('get_object', Params={'Bucket': self.bucket_name, 'Key': await file_code.get_file_path()}, ExpiresIn=3600)
-            return result
+        if self.proxy:
+            return await get_file_url(file_code.code)
+        else:
+            async with self.session.client("s3", endpoint_url=self.endpoint_url) as s3:
+                result = await s3.generate_presigned_url('get_object', Params={'Bucket': self.bucket_name, 'Key': await file_code.get_file_path()}, ExpiresIn=3600)
+                return result
 
 
 class OneDriveFileStorage(FileStorageInterface):
@@ -122,6 +145,7 @@ def __init__(self):
         self.client_id = settings.onedrive_client_id
         self.username = settings.onedrive_username
         self.password = settings.onedrive_password
+        self.proxy = settings.onedrive_proxy
         self._ClientRequestException = ClientRequestException
 
         try:
@@ -193,11 +217,27 @@ def _get_file_url(self, save_path, name):
         premission = remote_file.create_link("view", "anonymous", expiration_datetime=expiration_datetime).execute_query()
         return self._convert_link_to_download_link(premission.link.webUrl)
 
+    async def get_file_response(self, file_code: FileCodes):
+        try:
+            filename = file_code.prefix + file_code.suffix
+            link = await asyncio.to_thread(self._get_file_url, await file_code.get_file_path(), filename)
+            tmp = io.BytesIO()
+            async with aiohttp.ClientSession() as session:
+                async with session.get(link) as resp:
+                    tmp.write(await resp.read())
+            tmp.seek(0)
+            content = tmp.read()
+            tmp.close()
+            return Response(content, media_type="application/octet-stream", headers=
+                            {"Content-Disposition": f'attachment; filename="{filename.encode("utf-8").decode("latin-1")}"'})
+        except Exception:
+            raise HTTPException(status_code=503, detail='服务代理下载异常,请稍后再试')
+
     async def get_file_url(self, file_code: FileCodes):
-        if file_code.prefix == '文本分享':
-            return file_code.text
-        result = await asyncio.to_thread(self._get_file_url, await file_code.get_file_path(), f'{file_code.prefix}{file_code.suffix}')
-        return result
+        if self.proxy:
+            return await get_file_url(file_code.code)
+        else:
+            return await asyncio.to_thread(self._get_file_url, await file_code.get_file_path(), f'{file_code.prefix}{file_code.suffix}')
 
 
 class OpenDALFileStorage(FileStorageInterface):