-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
282 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import requests | ||
|
||
from microSALT.utils.pubmlst.authentication import generate_oauth_header | ||
from microSALT.utils.pubmlst.helpers import fetch_paginated_data | ||
|
||
# Root URL of the PubMLST REST API; the endpoint URLs in this module are built from it.
BASE_API = "https://rest.pubmlst.org" | ||
|
||
|
||
def query_databases(session_token, session_secret, timeout=60):
    """Query available PubMLST databases.

    Args:
        session_token: OAuth session token used to sign the request.
        session_secret: Secret paired with ``session_token``.
        timeout: Seconds to wait for the server before giving up; passed
            unchanged to ``requests.get``.

    Returns:
        The decoded JSON body returned by the ``/db`` endpoint.

    Raises:
        ValueError: If the server answers with any status other than 200.
    """
    url = f"{BASE_API}/db"
    headers = {"Authorization": generate_oauth_header(url, session_token, session_secret)}
    # requests never times out by default, so a stalled connection would
    # otherwise block the caller indefinitely.
    response = requests.get(url, headers=headers, timeout=timeout)
    if response.status_code != 200:
        raise ValueError(f"Failed to query databases: {response.status_code} - {response.text}")
    return response.json()
|
||
|
||
def fetch_schemes(database, session_token, session_secret, timeout=60):
    """Fetch available schemes for a database.

    Args:
        database: Database name inserted into the ``/db/{database}/schemes`` path.
        session_token: OAuth session token used to sign the request.
        session_secret: Secret paired with ``session_token``.
        timeout: Seconds to wait for the server before giving up; passed
            unchanged to ``requests.get``.

    Returns:
        The decoded JSON body of the schemes endpoint.

    Raises:
        ValueError: If the server answers with any status other than 200.
    """
    url = f"{BASE_API}/db/{database}/schemes"
    headers = {"Authorization": generate_oauth_header(url, session_token, session_secret)}
    # requests never times out by default, so a stalled connection would
    # otherwise block the caller indefinitely.
    response = requests.get(url, headers=headers, timeout=timeout)
    if response.status_code != 200:
        raise ValueError(f"Failed to fetch schemes: {response.status_code} - {response.text}")
    return response.json()
|
||
|
||
def download_profiles(database, scheme_id, session_token, session_secret):
    """Download the MLST profiles of one scheme.

    Builds the ``/db/{database}/schemes/{scheme_id}/profiles`` URL and hands
    it to ``fetch_paginated_data``, returning whatever that helper returns.
    """
    profiles_url = f"{BASE_API}/db/{database}/schemes/{scheme_id}/profiles"
    profiles = fetch_paginated_data(profiles_url, session_token, session_secret)
    return profiles
|
||
|
||
def download_locus(database, locus, session_token, session_secret, timeout=60):
    """Download the allele sequences of a locus as FASTA.

    Args:
        database: Database name inserted into the request path.
        locus: Locus name inserted into the request path.
        session_token: OAuth session token used to sign the request.
        session_secret: Secret paired with ``session_token``.
        timeout: Seconds to wait for the server before giving up; passed
            unchanged to ``requests.get``.

    Returns:
        The raw response body (``bytes``) of the ``alleles_fasta`` endpoint.

    Raises:
        ValueError: If the server answers with any status other than 200.
    """
    url = f"{BASE_API}/db/{database}/loci/{locus}/alleles_fasta"
    headers = {"Authorization": generate_oauth_header(url, session_token, session_secret)}
    # requests never times out by default, so a stalled connection would
    # otherwise block the caller indefinitely.
    response = requests.get(url, headers=headers, timeout=timeout)
    if response.status_code != 200:
        raise ValueError(f"Failed to download locus: {response.status_code} - {response.text}")
    return response.content  # Raw FASTA content, not decoded
|
||
|
||
def check_database_metadata(database, session_token, session_secret, timeout=60):
    """Fetch the metadata document of a database.

    Args:
        database: Database name inserted into the ``/db/{database}`` path.
        session_token: OAuth session token used to sign the request.
        session_secret: Secret paired with ``session_token``.
        timeout: Seconds to wait for the server before giving up; passed
            unchanged to ``requests.get``.

    Returns:
        The decoded JSON body of the database endpoint.
        NOTE(review): the original docstring says this is used to check the
        last update; the exact fields are not visible here.

    Raises:
        ValueError: If the server answers with any status other than 200.
    """
    url = f"{BASE_API}/db/{database}"
    headers = {"Authorization": generate_oauth_header(url, session_token, session_secret)}
    # requests never times out by default, so a stalled connection would
    # otherwise block the caller indefinitely.
    response = requests.get(url, headers=headers, timeout=timeout)
    if response.status_code != 200:
        raise ValueError(
            f"Failed to check database metadata: {response.status_code} - {response.text}"
        )
    return response.json()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
import base64 | ||
import hashlib | ||
import hmac | ||
import json | ||
import os | ||
import time | ||
from datetime import datetime, timedelta | ||
from urllib.parse import quote_plus, urlencode | ||
|
||
from dateutil import parser | ||
from rauth import OAuth1Session | ||
|
||
import microSALT.utils.pubmlst.credentials as credentials | ||
|
||
# Root URL of the PubMLST REST API.
BASE_API = "https://rest.pubmlst.org" | ||
# JSON file, located next to this module, that save_session_token writes and
# load_session_token reads.
SESSION_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), "session_credentials.json") | ||
SESSION_EXPIRATION_BUFFER = 60 # Seconds before the stored expiration at which a cached token is already treated as expired
|
||
|
||
def save_session_token(token, secret, expiration_date):
    """Write the session token, its secret and its expiry to ``SESSION_FILE``.

    The data is stored as a JSON object with the keys ``token``, ``secret``
    and ``expiration``; the expiry is serialised with ``isoformat()``. Any
    existing file is overwritten.
    """
    payload = dict(
        token=token,
        secret=secret,
        expiration=expiration_date.isoformat(),
    )
    with open(SESSION_FILE, "w") as handle:
        json.dump(payload, handle)
    print(f"Session token saved to {SESSION_FILE}.")
|
||
|
||
def load_session_token():
    """Load the cached session token if the cache file exists and is still valid.

    A token counts as valid while the current local time is earlier than the
    stored expiration minus ``SESSION_EXPIRATION_BUFFER`` seconds.

    Returns:
        ``(token, secret)`` for a usable cached token, otherwise
        ``(None, None)``. An unreadable cache file (invalid JSON, missing
        keys, unparsable expiration) is treated like a missing one so the
        caller can simply request a fresh token.
    """
    if not os.path.exists(SESSION_FILE):
        return None, None

    try:
        with open(SESSION_FILE, "r") as f:
            session_data = json.load(f)
        expiration = parser.parse(session_data["expiration"])
        still_valid = datetime.now() < expiration - timedelta(seconds=SESSION_EXPIRATION_BUFFER)
        token, secret = session_data["token"], session_data["secret"]
    except (ValueError, KeyError, TypeError, OverflowError):
        # ValueError covers both json.JSONDecodeError and date-parsing errors;
        # TypeError covers a non-dict payload or a naive/aware datetime mix.
        print(f"Ignoring unreadable session file {SESSION_FILE}.")
        return None, None

    if still_valid:
        print("Using existing session token.")
        return token, secret
    return None, None
|
||
|
||
def generate_oauth_header(url, token, token_secret):
    """Generate the OAuth1 ``Authorization`` header for a GET request to ``url``.

    The request is signed with HMAC-SHA1 using ``credentials.CLIENT_ID`` and
    ``credentials.CLIENT_SECRET`` together with the supplied token pair.

    Args:
        url: Full request URL. It may include a query string (for example a
            pagination link); those parameters are included in the signature.
        token: OAuth token sent as ``oauth_token``.
        token_secret: Secret paired with ``token``; part of the signing key.

    Returns:
        The header value, starting with ``"OAuth "`` and followed by the
        comma-separated, percent-encoded ``oauth_*`` parameters.
    """
    # Imported locally so this function does not depend on names missing from
    # the module-level ``urllib.parse`` import.
    from urllib.parse import parse_qsl, quote, urlsplit, urlunsplit

    def _percent_encode(value):
        # OAuth 1.0 mandates RFC 3986 encoding: only unreserved characters
        # stay literal and a space becomes %20 (quote_plus would emit "+").
        return quote(str(value), safe="~")

    oauth_params = {
        "oauth_consumer_key": credentials.CLIENT_ID,
        "oauth_token": token,
        "oauth_signature_method": "HMAC-SHA1",
        "oauth_timestamp": str(int(time.time())),
        "oauth_nonce": base64.urlsafe_b64encode(os.urandom(32)).decode("utf-8").strip("="),
        "oauth_version": "1.0",
    }

    # The base string URI must exclude the query string; query parameters are
    # signed alongside the oauth_* parameters instead. Signing the raw URL
    # breaks every request whose URL carries a query (e.g. "next" page links).
    parts = urlsplit(url)
    base_uri = urlunsplit((parts.scheme, parts.netloc, parts.path, "", ""))
    query_pairs = parse_qsl(parts.query, keep_blank_values=True)

    # Parameters are encoded first, then sorted by name and value.
    encoded_pairs = sorted(
        (_percent_encode(name), _percent_encode(value))
        for name, value in [*oauth_params.items(), *query_pairs]
    )
    normalized_params = "&".join(f"{name}={value}" for name, value in encoded_pairs)
    base_string = "&".join(
        ("GET", _percent_encode(base_uri), _percent_encode(normalized_params))
    )

    # Both halves of the key are percent-encoded before being joined with "&".
    signing_key = (
        f"{_percent_encode(credentials.CLIENT_SECRET)}&{_percent_encode(token_secret)}"
    )

    # Sign the base string
    hashed = hmac.new(signing_key.encode("utf-8"), base_string.encode("utf-8"), hashlib.sha1)
    oauth_params["oauth_signature"] = base64.b64encode(hashed.digest()).decode("utf-8")

    # Only the oauth_* parameters go into the header; query parameters stay in the URL.
    return "OAuth " + ", ".join(
        f'{_percent_encode(name)}="{_percent_encode(value)}"'
        for name, value in oauth_params.items()
    )
|
||
|
||
def get_new_session_token():
    """Request a new session token using the stored client and access credentials.

    Calls the ``get_session_token`` endpoint of the hard-coded
    ``pubmlst_neisseria_seqdef`` database. On success the token pair is saved
    through ``save_session_token`` with an expiry of 12 hours from now.

    Returns:
        ``(session_token, session_secret)`` on HTTP 200, otherwise
        ``(None, None)`` after printing the error status and body.
    """
    print("Fetching a new session token...")
    db = "pubmlst_neisseria_seqdef"
    endpoint = f"{BASE_API}/db/{db}/oauth/get_session_token"

    oauth_session = OAuth1Session(
        consumer_key=credentials.CLIENT_ID,
        consumer_secret=credentials.CLIENT_SECRET,
        access_token=credentials.ACCESS_TOKEN,
        access_token_secret=credentials.ACCESS_SECRET,
    )
    reply = oauth_session.get(endpoint, headers={"User-Agent": "BIGSdb downloader"})

    # Guard clause: report the failure and signal "no token" to the caller.
    if reply.status_code != 200:
        print(f"Error: {reply.status_code} - {reply.text}")
        return None, None

    body = reply.json()
    new_token = body["oauth_token"]
    new_secret = body["oauth_token_secret"]
    expires_at = datetime.now() + timedelta(hours=12)  # 12-hour validity
    save_session_token(new_token, new_secret, expires_at)
    return new_token, new_secret
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# Credential placeholders: all four values are blank in this template.
# NOTE(review): presumably filled in by save_to_credentials_py, which writes a
# credentials.py containing these four names; confirm it targets this file.
CLIENT_ID = | ||
CLIENT_SECRET = | ||
ACCESS_TOKEN = | ||
ACCESS_SECRET = |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
#!/usr/bin/env python3 | ||
|
||
import json | ||
import os | ||
import sys | ||
|
||
from rauth import OAuth1Service | ||
|
||
# Site name -> web front-end URL; used to build the authorization link shown to the user.
BASE_WEB = { | ||
"PubMLST": "https://pubmlst.org/bigsdb", | ||
} | ||
# Site name -> REST API root; used to build the OAuth request/access token endpoints.
BASE_API = { | ||
"PubMLST": "https://rest.pubmlst.org", | ||
} | ||
|
||
# Key into BASE_WEB / BASE_API selecting which site main() talks to.
SITE = "PubMLST" | ||
# Database name inserted into the OAuth endpoint paths and the authorization link.
DB = "pubmlst_test_seqdef" | ||
|
||
# Import CLIENT_ID and CLIENT_SECRET from the credentials module; exit with
# status 1 if the module or either name cannot be imported.
# NOTE(review): this imports from the "pubmlst_old" package, while the other
# modules shown import from "microSALT.utils.pubmlst"; confirm the path is intended.
try: | ||
from microSALT.utils.pubmlst_old.credentials import CLIENT_ID, CLIENT_SECRET | ||
except ImportError: | ||
print("Error: 'credentials.py' file not found or missing CLIENT_ID and CLIENT_SECRET.") | ||
sys.exit(1) | ||
|
||
|
||
def main():
    """Run the interactive flow: obtain an access token pair, print it, save it.

    Uses the module-level ``SITE`` and ``DB`` together with the imported
    ``CLIENT_ID`` and ``CLIENT_SECRET``.
    """
    token, token_secret = get_new_access_token(SITE, DB, CLIENT_ID, CLIENT_SECRET)

    print(f"\nAccess Token: {token}")
    print(f"Access Token Secret: {token_secret}")

    save_to_credentials_py(CLIENT_ID, CLIENT_SECRET, token, token_secret)
|
||
|
||
def get_new_access_token(site, db, client_id, client_secret):
    """Obtain a new access token and secret through the interactive OAuth1 flow.

    Steps: fetch a request token, print the authorization link, read the
    verification code from standard input, then exchange the request token
    for an access token.

    Args:
        site: Key into ``BASE_API`` and ``BASE_WEB``.
        db: Database name inserted into the endpoint paths and the link.
        client_id: OAuth consumer key.
        client_secret: OAuth consumer secret.

    Returns:
        ``(oauth_token, oauth_token_secret)`` taken from the JSON response.
        The process exits with status 1 if the exchange does not return 200.
    """
    api_root = BASE_API[site]
    oauth_root = f"{api_root}/db/{db}/oauth"
    service = OAuth1Service(
        name="BIGSdb_downloader",
        consumer_key=client_id,
        consumer_secret=client_secret,
        request_token_url=f"{oauth_root}/get_request_token",
        access_token_url=f"{oauth_root}/get_access_token",
        base_url=api_root,
    )

    request_token, request_secret = get_request_token(service)

    authorize_url = f"{BASE_WEB[site]}?db={db}&page=authorizeClient&oauth_token={request_token}"
    print(
        f"Please log in using your user account at {authorize_url} "
        "using a web browser to obtain a verification code."
    )
    verifier = input("Please enter verification code: ")

    # Exchange the request token plus verifier for the access token.
    exchange = service.get_raw_access_token(
        request_token,
        request_secret,
        params={"oauth_verifier": verifier},
    )
    if exchange.status_code != 200:
        print(f"Error obtaining access token: {exchange.text}")
        sys.exit(1)

    granted = exchange.json()
    return granted["oauth_token"], granted["oauth_token_secret"]
|
||
|
||
def get_request_token(service):
    """Fetch a request token and decode the JSON response.

    Args:
        service: Object providing ``get_raw_request_token``; it is called
            with ``oauth_callback`` set to ``"oob"``.

    Returns:
        ``(oauth_token, oauth_token_secret)`` from the decoded response.
        The process exits with status 1 on a non-200 status or when the body
        is not valid JSON.
    """
    reply = service.get_raw_request_token(params={"oauth_callback": "oob"})
    if reply.status_code != 200:
        print(f"Error obtaining request token: {reply.text}")
        sys.exit(1)

    try:
        payload = json.loads(reply.text)
    except json.JSONDecodeError:
        print(f"Failed to parse JSON response: {reply.text}")
        sys.exit(1)

    # A missing key propagates as KeyError, exactly as before.
    return payload["oauth_token"], payload["oauth_token_secret"]
|
||
|
||
def save_to_credentials_py(client_id, client_secret, access_token, access_secret):
    """Save the client and access credentials to ``credentials.py``.

    The file is written next to this script and overwritten if it exists. It
    defines ``CLIENT_ID``, ``CLIENT_SECRET``, ``ACCESS_TOKEN`` and
    ``ACCESS_SECRET`` as Python literals.

    Args:
        client_id: Value written as ``CLIENT_ID``.
        client_secret: Value written as ``CLIENT_SECRET``.
        access_token: Value written as ``ACCESS_TOKEN``.
        access_secret: Value written as ``ACCESS_SECRET``.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    credentials_path = os.path.join(script_dir, "credentials.py")
    assignments = (
        ("CLIENT_ID", client_id),
        ("CLIENT_SECRET", client_secret),
        ("ACCESS_TOKEN", access_token),
        ("ACCESS_SECRET", access_secret),
    )
    # repr() emits a valid Python literal even when a value contains quotes or
    # backslashes; wrapping the raw text in double quotes would produce a
    # broken or silently altered module in that case.
    lines = [f"{name} = {value!r}\n" for name, value in assignments]
    with open(credentials_path, "w") as f:
        f.writelines(lines)
    print(f"Tokens saved to {credentials_path}")
|
||
|
||
# Run the interactive token flow only when this file is executed as a script, not on import.
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import requests | ||
|
||
from microSALT.utils.pubmlst.authentication import generate_oauth_header | ||
|
||
|
||
def fetch_paginated_data(url, session_token, session_secret, timeout=60):
    """Fetch paginated data using the session token and secret.

    Starting at ``url``, each page is requested with a freshly signed OAuth
    header. The entries under the page's ``"profiles"`` key are collected and
    the link under ``paging.next`` is followed until no such link is present.

    Args:
        url: URL of the first page.
        session_token: OAuth session token used to sign each request.
        session_secret: Secret paired with ``session_token``.
        timeout: Seconds to wait for the server on each request; passed
            unchanged to ``requests.get``.

    Returns:
        A list with the ``"profiles"`` entries of all pages, in page order.

    Raises:
        ValueError: If any page answers with a status other than 200.
    """
    results = []
    while url:
        # Announce the request before sending it, so a slow or stalled page is
        # identifiable in the output.
        print(f"Fetching URL: {url}")
        headers = {"Authorization": generate_oauth_header(url, session_token, session_secret)}
        # requests never times out by default; without this a stalled
        # connection would block the whole download indefinitely.
        response = requests.get(url, headers=headers, timeout=timeout)
        print(f"Response Status Code: {response.status_code}")

        if response.status_code != 200:
            raise ValueError(f"Failed to fetch data: {response.status_code} - {response.text}")

        data = response.json()
        results.extend(data.get("profiles", []))
        url = data.get("paging", {}).get("next")  # None on the last page ends the loop
    return results