Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Submit STAC using transactions API #297

Open
wants to merge 2 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions dags/veda_data_pipeline/groups/processing_tasks.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
from datetime import timedelta
import json
import logging
from copy import deepcopy
import smart_open
from airflow.models.variable import Variable
from airflow.decorators import task
from veda_data_pipeline.utils.submit_stac import submission_handler
from veda_data_pipeline.utils.submit_stac_transactions import submit_transactions_handler

group_kwgs = {"group_id": "Process", "tooltip": "Process"}

# DAG-wide settings are stored as one JSON document in the "aws_dags_variables"
# Airflow Variable; it is read and decoded once, at module import time.
airflow_vars = Variable.get("aws_dags_variables")
airflow_vars_json = json.loads(airflow_vars)

# Feature flag: when enabled, STAC items are submitted through the transactions
# handler instead of the "/ingestions" handler. The value is normalized to a
# real bool because a JSON string such as "false" is truthy in Python and
# would otherwise switch the feature on by accident.
_transactions_flag = airflow_vars_json.get("TRANSACTIONS_ENDPOINT_ENABLED", False)
if isinstance(_transactions_flag, str):
    TRANSACTIONS_ENDPOINT_ENABLED = _transactions_flag.strip().lower() in {"1", "true", "yes", "on"}
else:
    TRANSACTIONS_ENDPOINT_ENABLED = bool(_transactions_flag)

def log_task(text: str):
logging.info(text)
Expand All @@ -29,8 +32,18 @@ def remove_thumbnail_asset(ti):
payload.pop("assets")
return payload

# Pick the task settings, extra handler arguments and handler function that
# match the configured submission path.
# For the transactions path: assuming default chunk size (500), the task
# settings match the current dynamoDB configuration on the STAC ingestor.
task_kwargs = dict(
    retries=3 if TRANSACTIONS_ENDPOINT_ENABLED else 2,
    retry_delay=10 if TRANSACTIONS_ENDPOINT_ENABLED else 60,
    retry_exponential_backoff=True,
    max_active_tis_per_dag=2 if TRANSACTIONS_ENDPOINT_ENABLED else 5,
)
# Only the "/ingestions" handler takes an explicit endpoint argument.
submit_kwargs = {} if TRANSACTIONS_ENDPOINT_ENABLED else {"endpoint": "/ingestions"}
submit_handler = (
    submit_transactions_handler if TRANSACTIONS_ENDPOINT_ENABLED else submission_handler
)

# with exponential backoff enabled, retry delay is converted to seconds
@task(retries=2, retry_delay=60, retry_exponential_backoff=True, max_active_tis_per_dag=5)
@task(**task_kwargs)
def submit_to_stac_ingestor_task(built_stac: dict):
"""Submit STAC items to the STAC ingestor API."""
event = built_stac.copy()
Expand All @@ -44,11 +57,11 @@ def submit_to_stac_ingestor_task(built_stac: dict):
stac_items = json.loads(_file.read())

for item in stac_items:
submission_handler(
submit_handler(
event=item,
endpoint="/ingestions",
cognito_app_secret=cognito_app_secret,
stac_ingestor_api_url=stac_ingestor_api_url,
**submit_kwargs,
)
return event

Expand Down
3 changes: 0 additions & 3 deletions dags/veda_data_pipeline/utils/submit_stac.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,7 @@ def submission_handler(
endpoint: str = "/ingestions",
cognito_app_secret=None,
stac_ingestor_api_url=None,
context=None,
) -> None | dict:
if context is None:
context = {}

stac_item = event

Expand Down
121 changes: 121 additions & 0 deletions dags/veda_data_pipeline/utils/submit_stac_transactions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import json
import logging
import requests
from typing import List, TypedDict

import boto3

# Configure the root logger for INFO-level output as soon as this module is imported.
logging.basicConfig(level=logging.INFO)

class Creds(TypedDict):
    """Shape of the JSON body returned by the ``/oauth2/token`` request made in
    ``TransactionsApi._get_app_credentials``."""

    # Used as the bearer token by ``TransactionsApi.from_veda_auth_secret``.
    access_token: str
    # Not read anywhere in this module; presumably the token lifetime in seconds -- TODO confirm.
    expires_in: int
    # Not read anywhere in this module.
    token_type: str

class AppConfig(TypedDict):
    """Shape of the decoded secret returned by
    ``TransactionsApi._get_cognito_service_details``; its keys are unpacked as
    keyword arguments into ``TransactionsApi._get_app_credentials``."""

    # Prefix of the token URL: the request goes to ``{cognito_domain}/oauth2/token``.
    cognito_domain: str
    # Sent together with ``client_secret`` as the HTTP basic-auth pair of the token request.
    client_id: str
    client_secret: str
    # Sent as the ``scope`` form field of the token request.
    scope: str

class TransactionsApi:
    """Small client for the ``bulk_items`` endpoint of a STAC Transactions API.

    An instance holds the API base URL together with either a ready-made
    bearer token or the ID of a Secrets Manager secret from which a token is
    obtained the first time a request needs one.
    """

    @classmethod
    def from_veda_auth_secret(cls, *, secret_id: str, base_url: str) -> "TransactionsApi":
        """Build a client that is authenticated up front.

        :param secret_id: ID of the Secrets Manager secret holding the Cognito app config.
        :param base_url: Base URL of the API (a trailing slash is ignored).
        :return: A ``TransactionsApi`` carrying a freshly issued access token.
        """
        cognito_details = cls._get_cognito_service_details(secret_id)
        credentials = cls._get_app_credentials(**cognito_details)
        return cls(token=credentials["access_token"], base_url=base_url)

    @staticmethod
    def _get_cognito_service_details(secret_id: str) -> "AppConfig":
        """Fetch the secret from AWS Secrets Manager and JSON-decode its ``SecretString``."""
        client = boto3.client("secretsmanager")
        response = client.get_secret_value(SecretId=secret_id)
        return json.loads(response["SecretString"])

    @staticmethod
    def _get_app_credentials(
        cognito_domain: str, client_id: str, client_secret: str, scope: str, **kwargs
    ) -> "Creds":
        """Request an access token with the OAuth2 client-credentials grant.

        Extra keys present in the secret are accepted through ``**kwargs`` and ignored.

        :return: The decoded JSON body of the token response.
        :raises RuntimeError: If the token endpoint answers with an error status.
        """
        response = requests.post(
            f"{cognito_domain}/oauth2/token",
            headers={
                "Content-Type": "application/x-www-form-urlencoded",
            },
            auth=(client_id, client_secret),
            data={
                "grant_type": "client_credentials",
                # A space-separated list of scopes to request for the generated access token.
                "scope": scope,
            },
        )
        try:
            response.raise_for_status()
        except Exception as ex:
            logging.error("Token request failed: %s", response.text)
            # Only BaseException subclasses can be raised; raising a plain
            # string would itself fail with a TypeError and hide the cause.
            raise RuntimeError(f"Error, {ex}") from ex
        return response.json()

    def __init__(
        self,
        stac_ingestor_api_url: str = None,
        cognito_app_secret: str = None,
        *,
        base_url: str = None,
        token: str = None,
    ):
        """
        :param stac_ingestor_api_url: Base URL of the API (e.g., 'https://example.com/stac').
        :param cognito_app_secret: Optional ID of the Secrets Manager secret used to
            obtain a bearer token when ``token`` is not supplied.
        :param base_url: Keyword-only alias for ``stac_ingestor_api_url``; takes
            precedence when both are given.
        :param token: Optional ready-made bearer token.
        :raises ValueError: If no URL is supplied.
        """
        url = base_url if base_url is not None else stac_ingestor_api_url
        if not url:
            raise ValueError("A STAC API URL is required (stac_ingestor_api_url or base_url).")
        self.base_url = url.rstrip("/")
        # Kept under the original attribute name as well.
        self.stac_ingestor_api_url = self.base_url
        self.cognito_app_secret = cognito_app_secret
        self.token = token

    def _resolve_token(self):
        """Return the bearer token, fetching it via the Cognito secret on first use."""
        if self.token is None and self.cognito_app_secret:
            cognito_details = self._get_cognito_service_details(self.cognito_app_secret)
            self.token = self._get_app_credentials(**cognito_details)["access_token"]
        return self.token

    def _bulk_items_url(self, collection_id: str) -> str:
        """Return the bulk-items URL of the given collection."""
        return f"{self.base_url}/collections/{collection_id}/bulk_items"

    def _build_headers(self) -> dict:
        """Return the request headers, with ``Authorization`` when a token is available."""
        headers = {"Content-Type": "application/json"}
        token = self._resolve_token()
        if token:
            headers["Authorization"] = f"Bearer {token}"
        return headers

    def post_items(self, collection_id: str, items: List[dict]) -> dict:
        """
        Perform a POST request that submits STAC Items in bulk to the given collection.

        :param collection_id: The target collection ID.
        :param items: The STAC Item JSON bodies to submit.
        :return: The JSON response (as a dict) from the STAC API.
        :raises RuntimeError: If the response is not 200/201.
        """
        url = self._bulk_items_url(collection_id)
        headers = self._build_headers()

        logging.info("POST %s", url)
        response = requests.post(url, headers=headers, json=items)

        if response.status_code not in (200, 201):
            logging.error("Failed POST request: %s %s", response.status_code, response.text)
            raise RuntimeError(f"POST request failed: {response.text}")

        return response.json()


def submit_transactions_handler(
    event,
    cognito_app_secret=None,
    stac_ingestor_api_url=None,
):
    """
    Handler function that can be integrated in the same way as the existing `submission_handler`,
    but uses the TransactionsApi to POST STAC items in bulk to STAC's Transactions endpoint.

    :param event: A list of STAC item dicts to submit; a single item dict is also
        accepted. The target collection is taken from the first item's
        "collection" field.
    :param cognito_app_secret: Passed to TransactionsApi as its second argument.
    :param stac_ingestor_api_url: Passed to TransactionsApi as its first argument.
    :return: A dict with "statusCode" 200 and a JSON-encoded "body" that wraps the API response.
    :raises ValueError: If `event` contains no items.
    :raises RuntimeError: If the POST request fails.
    """
    # A bare item is treated as a batch of one so both call styles work.
    items = [event] if isinstance(event, dict) else list(event or [])
    if not items:
        raise ValueError("No STAC items to submit: 'event' is empty.")

    collection_id = items[0].get("collection")
    api = TransactionsApi(stac_ingestor_api_url, cognito_app_secret)
    try:
        response = api.post_items(collection_id, items)
        logging.info("STAC Item POST completed successfully.")
    except RuntimeError as err:
        logging.error("Error while performing POST: %s", str(err))
        raise
    return {
        "statusCode": 200,
        "body": json.dumps({
            "message": "POST request completed successfully",
            "response": response
        })
    }
2 changes: 0 additions & 2 deletions dags/veda_data_pipeline/veda_promotion_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,6 @@ def transfer_assets_to_production_bucket(ti=None, payload={}):
return payload

with DAG("veda_promotion_pipeline", params=template_dag_run_conf, **dag_args) as dag:
# ECS dependency variable

start = EmptyOperator(task_id="start", dag=dag)
end = EmptyOperator(task_id="end", dag=dag)

Expand Down