Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add API endpoint to query by SMILES #65

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 133 additions & 39 deletions backend/app/app/api/v2/endpoints/molecule.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,101 @@ def _pandas_to_buffer(df):

return buffer

def _query_data(molecule_id, data_type, db):

# Check for valid data type.
if data_type.lower() not in ["ml", "dft", "xtb", "xtb_ni"]:
raise HTTPException(status_code=400, detail="Invalid data type.")

table_name = f"{data_type}_data"
query = text(f"""
SELECT t.*, m.SMILES
FROM {table_name} t
JOIN molecule m ON t.molecule_id = m.molecule_id
WHERE t.molecule_id = :molecule_id
""")

stmt = query.bindparams(molecule_id=molecule_id)

results = db.execute(stmt).fetchall()

# Hacky way to get each row as a dictionary.
# do this to generalize for different data sets - column names may vary.
list_of_dicts = [row._asdict() for row in results]

return list_of_dicts

def _get_molecule_data_types(molecule_id, db):
"""Get the data types (e.g. ML, DFT, xTB) available for a molecule.

Parameters
----------
molecule_id : int
The molecule ID.
db : Session
The database session.

Returns
-------
dict
A dictionary with the data types available in the database. `True` or `False` values indicate whether the data type is available for the molecule.

"""

# First, fetch all table names that have 'data' in their name
query_tables = text("SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name LIKE '%data%';")
tables = db.execute(query_tables).fetchall()

# Prepare a dictionary to store the results
results = {}

# Loop through each table to check for the molecule_id
for table in tables:
table_name = table[0]
# Assuming the column storing molecule IDs is named 'molecule_id' in all tables.
# You might need to adjust this according to your database schema.
query_check = text(f"SELECT EXISTS(SELECT 1 FROM {table_name} WHERE molecule_id = :molecule_id) AS exists;")
exists = db.execute(query_check, {"molecule_id": molecule_id}).fetchone()[0]
results[table_name] = exists

return results

def _build_multiple_molecule_query(molecule_ids_list, data_type):

# Use pandas.read_sql_query to get the data.
table_name = f"{data_type}_data"

# Generating a safe query with placeholders
placeholders = ', '.join([':id' + str(i) for i in range(len(molecule_ids_list))])
query_parameters = {'id' + str(i): mid for i, mid in enumerate(molecule_ids_list)}

query = text(f"""
SELECT t.*, m.SMILES
FROM {table_name} t
JOIN molecule m ON t.molecule_id = m.molecule_id
WHERE t.molecule_id IN ({placeholders})
""")

return query, query_parameters

def _get_id_from_smiles(smiles, db):

canonical_smiles = valid_smiles(smiles)

query = text("""SELECT molecule_id FROM molecule WHERE canonical_smiles = :smiles""")

stmt = query.bindparams(smiles=canonical_smiles)

molecule_id = db.execute(stmt).fetchone()

try:
molecule_id = molecule_id[0]
except TypeError:
molecule_id = None

return molecule_id


def valid_smiles(smiles):
"""Check to see if a smile string is valid to represent a molecule.

Expand Down Expand Up @@ -99,21 +194,7 @@ async def get_data_types(db: Session = Depends(deps.get_db)):

@router.get("/{molecule_id}/data_types", response_model=Any)
async def get_molecule_data_types(molecule_id: int | str, db: Session = Depends(deps.get_db)):
# First, fetch all table names that have 'data' in their name
query_tables = text("SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_name LIKE '%data%';")
tables = db.execute(query_tables).fetchall()

# Prepare a dictionary to store the results
results = {}

# Loop through each table to check for the molecule_id
for table in tables:
table_name = table[0]
# Assuming the column storing molecule IDs is named 'molecule_id' in all tables.
# You might need to adjust this according to your database schema.
query_check = text(f"SELECT EXISTS(SELECT 1 FROM {table_name} WHERE molecule_id = :molecule_id) AS exists;")
exists = db.execute(query_check, {"molecule_id": molecule_id}).fetchone()[0]
results[table_name] = exists
results = _get_molecule_data_types(molecule_id, db)

return results

Expand All @@ -126,33 +207,50 @@ async def get_molecule_data(molecule_id: int | str,
if data_type.lower() not in ["ml", "dft", "xtb", "xtb_ni"]:
raise HTTPException(status_code=400, detail="Invalid data type.")

table_name = f"{data_type}_data"
query = text(f"""
SELECT t.*, m.SMILES
FROM {table_name} t
JOIN molecule m ON t.molecule_id = m.molecule_id
WHERE t.molecule_id = :molecule_id
""")

stmt = query.bindparams(molecule_id=molecule_id)
result = _query_data(molecule_id, data_type, db)

return result

results = db.execute(stmt).fetchall()
@router.get("/data", response_model=Any)
async def get_data(smiles: str, data_type: str="ml", db: Session = Depends(deps.get_db)):
"""Get the data for a molecule by its SMILES string."""

# Hacky way to get each row as a dictionary.
# do this to generalize for different data sets - column names may vary.
list_of_dicts = [row._asdict() for row in results]
molecule_id = _get_id_from_smiles(smiles, db)

return list_of_dicts
if molecule_id is None:
raise HTTPException(status_code=404, detail="Molecule not found")

# Check that the molecule has the requested data type.
valid_datatypes = _get_molecule_data_types(molecule_id, db)

if not valid_datatypes.get(f"{data_type}_data"):
raise HTTPException(status_code=404, detail="Data type not found for molecule")

result = _query_data(molecule_id, data_type, db)

return result

@router.get("/data/export/batch")
async def get_molecules_data(molecule_ids: str,
async def get_molecules_data(molecule_ids: Optional[str]=None,
molecule_smiles: Optional[str]=None,
data_type: str="ml",
return_type: str="csv",
context: Optional[str]=None,
db: Session = Depends(deps.get_db)):

if molecule_ids is None and molecule_smiles is None:
raise HTTPException(status_code=400, detail="No molecule IDs or SMILES provided.")

if molecule_ids is not None and molecule_smiles is not None:
raise HTTPException(status_code=400, detail="Provide either molecule IDs or SMILES, not both.")

if molecule_smiles:
molecule_smiles_list = [ x.strip() for x in molecule_smiles.split(",")]
molecule_ids_list = [_get_id_from_smiles(smiles, db) for smiles in molecule_smiles_list]

if molecule_ids:
molecule_ids_list = [ x.strip() for x in molecule_ids.split(",")]

molecule_ids_list = [ x.strip() for x in molecule_ids.split(",")]
first_molecule_id = molecule_ids_list[0]
num_molecules = len(molecule_ids_list)

Expand All @@ -172,14 +270,8 @@ async def get_molecules_data(molecule_ids: str,

# Generating a safe query with placeholders
placeholders = ', '.join([':id' + str(i) for i in range(len(molecule_ids_list))])
query_parameters = {'id' + str(i): mid for i, mid in enumerate(molecule_ids_list)}

query = text(f"""
SELECT t.*, m.SMILES
FROM {table_name} t
JOIN molecule m ON t.molecule_id = m.molecule_id
WHERE t.molecule_id IN ({placeholders})
""")
query, query_parameters = _build_multiple_molecule_query(molecule_ids_list, data_type)

df = pd.read_sql_query(query, db.bind, params=query_parameters)

Expand All @@ -201,12 +293,14 @@ async def get_molecules_data(molecule_ids: str,

return response



@router.get("/data/export/{molecule_id}")
async def export_molecule_data(molecule_id: int | str,
data_type: str="ml",
db: Session = Depends(deps.get_db)):
"""Export data for a single molecule as a CSV."""


# Check for valid data type.
if data_type.lower() not in ["ml", "dft", "xtb", "xtb_ni"]:
raise HTTPException(status_code=400, detail="Invalid data type.")
Expand Down