diff --git a/microSALT/utils/pubmlst/__init__.py b/microSALT/utils/pubmlst/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/microSALT/utils/pubmlst/api.py b/microSALT/utils/pubmlst/api.py
new file mode 100644
index 00000000..5f49aa38
--- /dev/null
+++ b/microSALT/utils/pubmlst/api.py
@@ -0,0 +1,102 @@
+"""Thin wrappers around the PubMLST REST API endpoints."""
+
+import requests
+
+from microSALT.utils.pubmlst.authentication import generate_oauth_header
+from microSALT.utils.pubmlst.helpers import REQUEST_TIMEOUT, fetch_paginated_data
+
+BASE_API = "https://rest.pubmlst.org"
+
+
+def _authorized_get(url, session_token, session_secret, failure):
+    """Send an OAuth-signed GET request and return the response on HTTP 200.
+
+    Args:
+        url: Full endpoint URL to request.
+        session_token: OAuth session token used to sign the request.
+        session_secret: Secret that belongs to ``session_token``.
+        failure: Message prefix for the error raised on a non-200 status.
+
+    Returns:
+        The ``requests.Response`` object of the successful request.
+
+    Raises:
+        ValueError: If the server answers with any status other than 200.
+    """
+    headers = {"Authorization": generate_oauth_header(url, session_token, session_secret)}
+    # Without a timeout a stalled connection would block the caller indefinitely.
+    response = requests.get(url, headers=headers, timeout=REQUEST_TIMEOUT)
+    if response.status_code != 200:
+        raise ValueError(f"{failure}: {response.status_code} - {response.text}")
+    return response
+
+
+def query_databases(session_token, session_secret):
+    """Query available PubMLST databases.
+
+    Returns:
+        The decoded JSON body of the ``/db`` endpoint.
+
+    Raises:
+        ValueError: If the request does not return HTTP 200.
+    """
+    url = f"{BASE_API}/db"
+    response = _authorized_get(url, session_token, session_secret, "Failed to query databases")
+    return response.json()
+
+
+def fetch_schemes(database, session_token, session_secret):
+    """Fetch available schemes for a database.
+
+    Returns:
+        The decoded JSON body of the ``/db/{database}/schemes`` endpoint.
+
+    Raises:
+        ValueError: If the request does not return HTTP 200.
+    """
+    url = f"{BASE_API}/db/{database}/schemes"
+    response = _authorized_get(url, session_token, session_secret, "Failed to fetch schemes")
+    return response.json()
+
+
+def download_profiles(database, scheme_id, session_token, session_secret):
+    """Download MLST profiles.
+
+    Returns:
+        The list collected by ``fetch_paginated_data`` over all result pages.
+
+    Raises:
+        ValueError: If a page request does not return HTTP 200.
+    """
+    url = f"{BASE_API}/db/{database}/schemes/{scheme_id}/profiles"
+    return fetch_paginated_data(url, session_token, session_secret)
+
+
+def download_locus(database, locus, session_token, session_secret):
+    """Download locus sequence files.
+
+    Returns:
+        The raw response body (bytes) of the ``alleles_fasta`` endpoint.
+
+    Raises:
+        ValueError: If the request does not return HTTP 200.
+    """
+    url = f"{BASE_API}/db/{database}/loci/{locus}/alleles_fasta"
+    response = _authorized_get(url, session_token, session_secret, "Failed to download locus")
+    return response.content  # Return raw FASTA content
+
+
+def check_database_metadata(database, session_token, session_secret):
+    """Check database metadata (last update).
+
+    Returns:
+        The decoded JSON body of the ``/db/{database}`` endpoint.
+
+    Raises:
+        ValueError: If the request does not return HTTP 200.
+    """
+    url = f"{BASE_API}/db/{database}"
+    response = _authorized_get(
+        url, session_token, session_secret, "Failed to check database metadata"
+    )
+    return response.json()
diff --git a/microSALT/utils/pubmlst/authentication.py b/microSALT/utils/pubmlst/authentication.py
new file mode 100644
index 00000000..46fc2d50
--- /dev/null
+++ b/microSALT/utils/pubmlst/authentication.py
@@ -0,0 +1,150 @@
+import base64
+import hashlib
+import hmac
+import json
+import os
+import time
+from datetime import datetime, timedelta
+from urllib.parse import parse_qsl, quote, urlsplit, urlunsplit
+
+from dateutil import parser
+from rauth import OAuth1Session
+
+import microSALT.utils.pubmlst.credentials as credentials
+
+BASE_API = "https://rest.pubmlst.org"
+SESSION_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), "session_credentials.json")
+SESSION_EXPIRATION_BUFFER = 60  # Seconds before expiration to renew
+SESSION_VALIDITY_HOURS = 12  # Validity assumed for a newly issued session token
+
+
+def save_session_token(token, secret, expiration_date):
+    """Save session token, secret, and expiration to a JSON file.
+
+    Args:
+        token: Session token to store.
+        secret: Session secret to store.
+        expiration_date: ``datetime`` at which the token expires.
+    """
+    session_data = {
+        "token": token,
+        "secret": secret,
+        "expiration": expiration_date.isoformat(),
+    }
+    # The file holds secrets, so create it readable by its owner only.
+    fd = os.open(SESSION_FILE, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
+    with os.fdopen(fd, "w") as f:
+        json.dump(session_data, f)
+    print(f"Session token saved to {SESSION_FILE}.")
+
+
+def load_session_token():
+    """Load session token from file if it exists and is valid.
+
+    Returns:
+        ``(token, secret)`` if a stored token has not yet reached its expiry
+        (minus ``SESSION_EXPIRATION_BUFFER`` seconds), else ``(None, None)``.
+    """
+    try:
+        with open(SESSION_FILE, "r") as f:
+            session_data = json.load(f)
+        token, secret = session_data["token"], session_data["secret"]
+        expiration = parser.parse(session_data["expiration"])
+        renew_after = expiration - timedelta(seconds=SESSION_EXPIRATION_BUFFER)
+        still_valid = datetime.now() < renew_after
+    except (OSError, ValueError, KeyError, TypeError):
+        # A missing, unreadable or malformed session file means "no token".
+        return None, None
+    if still_valid:
+        print("Using existing session token.")
+        return token, secret
+    return None, None
+
+
+def _percent_encode(value):
+    """Percent-encode ``value`` the way OAuth 1.0 requires (RFC 5849, section 3.6).
+
+    Everything outside the RFC 3986 unreserved set is escaped, so a space
+    becomes ``%20`` instead of the ``+`` that ``quote_plus`` would produce.
+    """
+    return quote(value, safe="~")
+
+
+def generate_oauth_header(url, token, token_secret):
+    """Generate the OAuth1 Authorization header.
+
+    Args:
+        url: URL of the GET request to sign; it may include a query string.
+        token: OAuth token sent as ``oauth_token``.
+        token_secret: Secret that belongs to ``token``; part of the signing key.
+
+    Returns:
+        The ``Authorization`` header value, beginning with ``OAuth ``.
+    """
+    oauth_params = {
+        "oauth_consumer_key": credentials.CLIENT_ID,
+        "oauth_token": token,
+        "oauth_signature_method": "HMAC-SHA1",
+        "oauth_timestamp": str(int(time.time())),
+        "oauth_nonce": base64.urlsafe_b64encode(os.urandom(32)).decode("utf-8").strip("="),
+        "oauth_version": "1.0",
+    }
+
+    # The base string uses the URL without its query string; query parameters
+    # are signed together with the OAuth parameters instead.
+    parts = urlsplit(url)
+    base_url = urlunsplit((parts.scheme.lower(), parts.netloc.lower(), parts.path, "", ""))
+    query_params = parse_qsl(parts.query, keep_blank_values=True)
+
+    # Create the signature base string (parameters are encoded, then sorted)
+    signed_params = sorted(
+        (_percent_encode(k), _percent_encode(v))
+        for k, v in [*oauth_params.items(), *query_params]
+    )
+    params_encoded = "&".join(f"{k}={v}" for k, v in signed_params)
+    base_string = f"GET&{_percent_encode(base_url)}&{_percent_encode(params_encoded)}"
+    signing_key = f"{_percent_encode(credentials.CLIENT_SECRET)}&{_percent_encode(token_secret)}"
+
+    # Sign the base string
+    hashed = hmac.new(signing_key.encode("utf-8"), base_string.encode("utf-8"), hashlib.sha1)
+    oauth_signature = base64.b64encode(hashed.digest()).decode("utf-8")
+
+    # Add the signature
+    oauth_params["oauth_signature"] = oauth_signature
+
+    # Construct the Authorization header
+    auth_header = "OAuth " + ", ".join(
+        f'{_percent_encode(k)}="{_percent_encode(v)}"' for k, v in oauth_params.items()
+    )
+    return auth_header
+
+
+def get_new_session_token():
+    """Request a new session token using client credentials.
+
+    Returns:
+        ``(session_token, session_secret)`` on HTTP 200, otherwise ``(None, None)``.
+    """
+    print("Fetching a new session token...")
+    db = "pubmlst_neisseria_seqdef"
+    url = f"{BASE_API}/db/{db}/oauth/get_session_token"
+
+    session = OAuth1Session(
+        consumer_key=credentials.CLIENT_ID,
+        consumer_secret=credentials.CLIENT_SECRET,
+        access_token=credentials.ACCESS_TOKEN,
+        access_token_secret=credentials.ACCESS_SECRET,
+    )
+
+    response = session.get(url, headers={"User-Agent": "BIGSdb downloader"})
+    if response.status_code != 200:
+        print(f"Error: {response.status_code} - {response.text}")
+        return None, None
+
+    token_data = response.json()
+    session_token = token_data["oauth_token"]
+    session_secret = token_data["oauth_token_secret"]
+    # The expiry is computed locally rather than read from the response.
+    expiration_time = datetime.now() + timedelta(hours=SESSION_VALIDITY_HOURS)
+    save_session_token(session_token, session_secret, expiration_time)
+    return session_token, session_secret
diff --git a/microSALT/utils/pubmlst/credentials.py b/microSALT/utils/pubmlst/credentials.py
new file mode 100644
index 00000000..a8189f01
--- /dev/null
+++ b/microSALT/utils/pubmlst/credentials.py
@@ -0,0 +1,6 @@
+# Placeholder values: set CLIENT_ID and CLIENT_SECRET, then run
+# get_credentials.py, which rewrites this file with all four values.
+CLIENT_ID = ""
+CLIENT_SECRET = ""
+ACCESS_TOKEN = ""
+ACCESS_SECRET = ""
diff --git a/microSALT/utils/pubmlst/get_credentials.py b/microSALT/utils/pubmlst/get_credentials.py
new file mode 100644
index 00000000..3e62a869
--- /dev/null
+++ b/microSALT/utils/pubmlst/get_credentials.py
@@ -0,0 +1,122 @@
+#!/usr/bin/env python3
+
+import json
+import os
+import sys
+
+from rauth import OAuth1Service
+
+BASE_WEB = {
+    "PubMLST": "https://pubmlst.org/bigsdb",
+}
+BASE_API = {
+    "PubMLST": "https://rest.pubmlst.org",
+}
+
+SITE = "PubMLST"
+DB = "pubmlst_test_seqdef"
+
+# Import client_id and client_secret from credentials.py
+try:
+    from microSALT.utils.pubmlst.credentials import CLIENT_ID, CLIENT_SECRET
+except ImportError:
+    print("Error: 'credentials.py' file not found or missing CLIENT_ID and CLIENT_SECRET.")
+    sys.exit(1)
+
+
+def main():
+    """Run the interactive OAuth flow and save the resulting access token."""
+    site = SITE
+    db = DB
+
+    # Empty placeholder values import fine but cannot identify the client.
+    if not CLIENT_ID or not CLIENT_SECRET:
+        print("Error: 'credentials.py' file not found or missing CLIENT_ID and CLIENT_SECRET.")
+        sys.exit(1)
+
+    access_token, access_secret = get_new_access_token(site, db, CLIENT_ID, CLIENT_SECRET)
+    print(f"\nAccess Token: {access_token}")
+    print(f"Access Token Secret: {access_secret}")
+
+    save_to_credentials_py(CLIENT_ID, CLIENT_SECRET, access_token, access_secret)
+
+
+def get_new_access_token(site, db, client_id, client_secret):
+    """Obtain a new access token and secret.
+
+    Args:
+        site: Key into ``BASE_API`` / ``BASE_WEB`` selecting the service.
+        db: Database name used in the OAuth endpoint URLs.
+        client_id: OAuth consumer key.
+        client_secret: OAuth consumer secret.
+
+    Returns:
+        ``(access_token, access_secret)`` taken from the server response.
+        The process exits with status 1 if a token request fails.
+    """
+    service = OAuth1Service(
+        name="BIGSdb_downloader",
+        consumer_key=client_id,
+        consumer_secret=client_secret,
+        request_token_url=f"{BASE_API[site]}/db/{db}/oauth/get_request_token",
+        access_token_url=f"{BASE_API[site]}/db/{db}/oauth/get_access_token",
+        base_url=BASE_API[site],
+    )
+
+    request_token, request_secret = get_request_token(service)
+    print(
+        "Please log in using your user account at "
+        f"{BASE_WEB[site]}?db={db}&page=authorizeClient&oauth_token={request_token} "
+        "using a web browser to obtain a verification code."
+    )
+    verifier = input("Please enter verification code: ").strip()
+
+    # Exchange request token for access token
+    raw_access = service.get_raw_access_token(
+        request_token, request_secret, params={"oauth_verifier": verifier}
+    )
+    if raw_access.status_code != 200:
+        print(f"Error obtaining access token: {raw_access.text}")
+        sys.exit(1)
+
+    access_data = raw_access.json()
+    return access_data["oauth_token"], access_data["oauth_token_secret"]
+
+
+def get_request_token(service):
+    """Handle JSON response from the request token endpoint.
+
+    Returns:
+        ``(request_token, request_secret)`` parsed from the response body.
+        The process exits with status 1 on a non-200 status or a bad body.
+    """
+    response = service.get_raw_request_token(params={"oauth_callback": "oob"})
+    if response.status_code != 200:
+        print(f"Error obtaining request token: {response.text}")
+        sys.exit(1)
+    try:
+        data = json.loads(response.text)
+        return data["oauth_token"], data["oauth_token_secret"]
+    except (json.JSONDecodeError, KeyError):
+        print(f"Failed to parse JSON response: {response.text}")
+        sys.exit(1)
+
+
+def save_to_credentials_py(client_id, client_secret, access_token, access_secret):
+    """Save tokens in the credentials.py file.
+
+    The file next to this script is overwritten with four assignments.
+    """
+    script_dir = os.path.dirname(os.path.abspath(__file__))
+    credentials_path = os.path.join(script_dir, "credentials.py")
+    with open(credentials_path, "w") as f:
+        # repr() yields a valid Python literal even if a value holds quotes.
+        f.write(f"CLIENT_ID = {client_id!r}\n")
+        f.write(f"CLIENT_SECRET = {client_secret!r}\n")
+        f.write(f"ACCESS_TOKEN = {access_token!r}\n")
+        f.write(f"ACCESS_SECRET = {access_secret!r}\n")
+    print(f"Tokens saved to {credentials_path}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/microSALT/utils/pubmlst/helpers.py b/microSALT/utils/pubmlst/helpers.py
new file mode 100644
index 00000000..ad01eeb0
--- /dev/null
+++ b/microSALT/utils/pubmlst/helpers.py
@@ -0,0 +1,38 @@
+import requests
+
+from microSALT.utils.pubmlst.authentication import generate_oauth_header
+
+REQUEST_TIMEOUT = 60  # Seconds passed to requests as the timeout of each call
+
+
+def fetch_paginated_data(url, session_token, session_secret):
+    """Fetch paginated data using the session token and secret.
+
+    Follows the ``paging.next`` link of every response until none is left.
+
+    Args:
+        url: URL of the first page.
+        session_token: OAuth session token used to sign each request.
+        session_secret: Secret that belongs to ``session_token``.
+
+    Returns:
+        A list holding the ``profiles`` entries of all pages, in page order.
+
+    Raises:
+        ValueError: If a page request does not return HTTP 200.
+    """
+    results = []
+    while url:
+        print(f"Fetching URL: {url}")
+        headers = {"Authorization": generate_oauth_header(url, session_token, session_secret)}
+        # Without a timeout a stalled connection would block this loop indefinitely.
+        response = requests.get(url, headers=headers, timeout=REQUEST_TIMEOUT)
+        print(f"Response Status Code: {response.status_code}")
+
+        if response.status_code != 200:
+            raise ValueError(f"Failed to fetch data: {response.status_code} - {response.text}")
+
+        data = response.json()
+        results.extend(data.get("profiles", []))
+        url = data.get("paging", {}).get("next")  # A missing link ends the loop
+    return results