Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fetch feeds in batches for export_csv #907

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 60 additions & 41 deletions api/src/shared/common/db_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from typing import Iterator

from geoalchemy2 import WKTElement
from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy.orm import joinedload, Session
from sqlalchemy.orm.query import Query
Expand All @@ -14,15 +17,11 @@
Entitytype,
Redirectingid,
)

from shared.feed_filters.gtfs_feed_filter import GtfsFeedFilter, LocationFilter
from shared.feed_filters.gtfs_rt_feed_filter import GtfsRtFeedFilter, EntityTypeFilter

from .entity_type_enum import EntityType

from sqlalchemy import or_

from .error_handling import raise_internal_http_validation_error, invalid_bounding_coordinates, invalid_bounding_method
from .iter_utils import batched


def get_gtfs_feeds_query(
Expand Down Expand Up @@ -75,28 +74,39 @@ def get_gtfs_feeds_query(
return feed_query


def get_all_gtfs_feeds_query(
def get_all_gtfs_feeds(
db_session: Session,
include_wip: bool = False,
db_session: Session = None,
) -> Query[any]:
"""Get the DB query to use to retrieve all the GTFS feeds, filtering out the WIP if needed"""

feed_query = db_session.query(Gtfsfeed)

batch_size: int = 250,
) -> Iterator[Gtfsfeed]:
"""
Fetch all GTFS feeds.

@param db_session: The database session.
@param include_wip: Whether to include or exclude WIP feeds.
@param batch_size: The number of feeds to fetch from the database at a time.
A lower value means less memory but more queries.

@return: The GTFS feeds in an iterator.
"""
feed_query = db_session.query(Gtfsfeed).order_by(Gtfsfeed.stable_id).yield_per(batch_size)
if not include_wip:
feed_query = feed_query.filter(
or_(Gtfsfeed.operational_status == None, Gtfsfeed.operational_status != "wip") # noqa: E711
feed_query = feed_query.filter(Gtfsfeed.operational_status.is_distinct_from("wip"))

for batch in batched(feed_query, batch_size):
stable_ids = (f.stable_id for f in batch)
yield from (
db_session.query(Gtfsfeed)
.filter(Gtfsfeed.stable_id.in_(stable_ids))
.options(
joinedload(Gtfsfeed.gtfsdatasets)
.joinedload(Gtfsdataset.validation_reports)
.joinedload(Validationreport.features),
*get_joinedload_options(),
)
.order_by(Gtfsfeed.stable_id)
)

feed_query = feed_query.options(
joinedload(Gtfsfeed.gtfsdatasets)
.joinedload(Gtfsdataset.validation_reports)
.joinedload(Validationreport.features),
*get_joinedload_options(),
).order_by(Gtfsfeed.stable_id)

return feed_query


def get_gtfs_rt_feeds_query(
limit: int | None,
Expand Down Expand Up @@ -161,29 +171,38 @@ def get_gtfs_rt_feeds_query(
return feed_query


def get_all_gtfs_rt_feeds_query(
def get_all_gtfs_rt_feeds(
db_session: Session,
include_wip: bool = False,
db_session: Session = None,
) -> Query:
"""Get the DB query to use to retrieve all the GTFS rt feeds, filtering out the WIP if needed"""
feed_query = db_session.query(Gtfsrealtimefeed)

batch_size: int = 250,
) -> Iterator[Gtfsrealtimefeed]:
"""
Fetch all GTFS realtime feeds.

@param db_session: The database session.
@param include_wip: Whether to include or exclude WIP feeds.
@param batch_size: The number of feeds to fetch from the database at a time.
A lower value means less memory but more queries.

@return: The GTFS realtime feeds in an iterator.
"""
feed_query = db_session.query(Gtfsrealtimefeed.stable_id).order_by(Gtfsrealtimefeed.stable_id).yield_per(batch_size)
if not include_wip:
feed_query = feed_query.filter(
or_(
Gtfsrealtimefeed.operational_status == None, # noqa: E711
Gtfsrealtimefeed.operational_status != "wip",
feed_query = feed_query.filter(Gtfsrealtimefeed.operational_status.is_distinct_from("wip"))

for batch in batched(feed_query, batch_size):
stable_ids = (f.stable_id for f in batch)
yield from (
db_session.query(Gtfsrealtimefeed)
.filter(Gtfsrealtimefeed.stable_id.in_(stable_ids))
.options(
joinedload(Gtfsrealtimefeed.entitytypes),
joinedload(Gtfsrealtimefeed.gtfs_feeds),
*get_joinedload_options(),
)
.order_by(Gtfsfeed.stable_id)
)

feed_query = feed_query.options(
joinedload(Gtfsrealtimefeed.entitytypes),
joinedload(Gtfsrealtimefeed.gtfs_feeds),
*get_joinedload_options(),
).order_by(Gtfsfeed.stable_id)

return feed_query


def apply_bounding_filtering(
query: Query,
Expand Down
13 changes: 13 additions & 0 deletions api/src/shared/common/iter_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from itertools import islice
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I initially wanted to put this file under `api/src/utils`, but ran into import issues when running the tests, so I chose the path of least resistance. Happy to move it.



def batched(iterable, n: int):
    """
    Batch an iterable into tuples of length `n`. The last batch may be shorter.

    Mirrors the `batched` recipe from more-itertools, and is meant to be
    replaced by `itertools.batched` once we switch to Python 3.12+.

    The input is consumed lazily: at most `n` items are pulled from
    `iterable` for each batch that is yielded.

    Raises ValueError if `n` is less than 1. Because this is a generator, the
    error surfaces when iteration starts rather than when `batched` is called.
    """
    # Without this guard, n == 0 would make islice() return an empty tuple on
    # the first pass and the loop would end immediately, silently dropping
    # every item. Raising matches itertools.batched, so swapping in the
    # built-in later does not change behavior.
    if n < 1:
        raise ValueError("n must be at least one")
    it = iter(iterable)
    # An empty tuple is falsy, so the loop stops once `it` is exhausted.
    while batch := tuple(islice(it, n)):
        yield batch
32 changes: 9 additions & 23 deletions functions-python/export_csv/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

from shared.helpers.logger import Logger
from shared.database_gen.sqlacodegen_models import Gtfsfeed, Gtfsrealtimefeed
from shared.common.db_utils import get_all_gtfs_rt_feeds_query, get_all_gtfs_feeds_query
from shared.common.db_utils import get_all_gtfs_rt_feeds, get_all_gtfs_feeds

from shared.helpers.database import Database

Expand Down Expand Up @@ -114,33 +114,19 @@ def fetch_feeds() -> Iterator[Dict]:
logging.info(f"Using database {db.database_url}")
try:
with db.start_db_session() as session:
gtfs_feeds_query = get_all_gtfs_feeds_query(
include_wip=False,
db_session=session,
)

gtfs_feeds = gtfs_feeds_query.all()

logging.info(f"Retrieved {len(gtfs_feeds)} GTFS feeds.")

gtfs_rt_feeds_query = get_all_gtfs_rt_feeds_query(
include_wip=False,
db_session=session,
)

gtfs_rt_feeds = gtfs_rt_feeds_query.all()

logging.info(f"Retrieved {len(gtfs_rt_feeds)} GTFS realtime feeds.")

for feed in gtfs_feeds:
feed_count = 0
for feed in get_all_gtfs_feeds(session, include_wip=False):
yield get_feed_csv_data(feed)
feed_count += 1

logging.info(f"Processed {len(gtfs_feeds)} GTFS feeds.")
logging.info(f"Processed {feed_count} GTFS feeds.")

for feed in gtfs_rt_feeds:
rt_feed_count = 0
for feed in get_all_gtfs_rt_feeds(session, include_wip=True):
yield get_gtfs_rt_feed_csv_data(feed)
rt_feed_count += 1

logging.info(f"Processed {len(gtfs_rt_feeds)} GTFS realtime feeds.")
logging.info(f"Processed {rt_feed_count} GTFS realtime feeds.")

except Exception as error:
logging.error(f"Error retrieving feeds: {error}")
Expand Down
9 changes: 6 additions & 3 deletions functions-python/export_csv/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,8 @@ def populate_database():
session.add(tu_entitytype)

# GTFS Realtime feeds
for i in range(3):
gtfs_rt_feed = Gtfsrealtimefeed(
gtfs_rt_feeds = [
Gtfsrealtimefeed(
id=fake.uuid4(),
data_type="gtfs_rt",
feed_name=f"gtfs-rt-{i} Some fake name",
Expand All @@ -201,7 +201,10 @@ def populate_database():
provider=f"gtfs-rt-{i} Some fake company",
entitytypes=[vp_entitytype, tu_entitytype] if (i == 0) else [vp_entitytype],
)
session.add(gtfs_rt_feed)
for i in range(3)
]
gtfs_rt_feeds[0].gtfs_feeds.append(active_gtfs_feeds[0])
session.add_all(gtfs_rt_feeds)

session.commit()

Expand Down
2 changes: 1 addition & 1 deletion functions-python/export_csv/tests/test_export_csv_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
gtfs-2,gtfs,,,,,gtfs-2 Some fake company,gtfs-2 Some fake name,gtfs-2 Some fake note,gtfs-2_some_fake_email@fake.com,,https://gtfs-2_some_fake_producer_url,0,,,,https://gtfs-2_some_fake_license_url,,,,,,inactive,,gtfs-0,Some redirect comment
gtfs-deprecated-0,gtfs,,,,,gtfs-deprecated-0 Some fake company,gtfs-deprecated-0 Some fake name,gtfs-deprecated-0 Some fake note,gtfs-deprecated-0_some_fake_email@fake.com,,https://gtfs-deprecated-0_some_fake_producer_url,0,,,,https://gtfs-0_some_fake_license_url,,,,,,deprecated,,,
gtfs-deprecated-1,gtfs,,,,,gtfs-deprecated-1 Some fake company,gtfs-deprecated-1 Some fake name,gtfs-deprecated-1 Some fake note,gtfs-deprecated-1_some_fake_email@fake.com,,https://gtfs-deprecated-1_some_fake_producer_url,1,,,,https://gtfs-1_some_fake_license_url,,,,,,deprecated,,,
gtfs-rt-0,gtfs_rt,tu|vp,,,,gtfs-rt-0 Some fake company,gtfs-rt-0 Some fake name,gtfs-rt-0 Some fake note,gtfs-rt-0_some_fake_email@fake.com,,https://gtfs-rt-0_some_fake_producer_url,0,https://gtfs-rt-0_some_fake_authentication_info_url,gtfs-rt-0_fake_api_key_parameter_name,,https://gtfs-rt-0_some_fake_license_url,,,,,,,,,
gtfs-rt-0,gtfs_rt,tu|vp,,,,gtfs-rt-0 Some fake company,gtfs-rt-0 Some fake name,gtfs-rt-0 Some fake note,gtfs-rt-0_some_fake_email@fake.com,gtfs-0,https://gtfs-rt-0_some_fake_producer_url,0,https://gtfs-rt-0_some_fake_authentication_info_url,gtfs-rt-0_fake_api_key_parameter_name,,https://gtfs-rt-0_some_fake_license_url,,,,,,,,,
gtfs-rt-1,gtfs_rt,vp,,,,gtfs-rt-1 Some fake company,gtfs-rt-1 Some fake name,gtfs-rt-1 Some fake note,gtfs-rt-1_some_fake_email@fake.com,,https://gtfs-rt-1_some_fake_producer_url,1,https://gtfs-rt-1_some_fake_authentication_info_url,gtfs-rt-1_fake_api_key_parameter_name,,https://gtfs-rt-1_some_fake_license_url,,,,,,,,,
gtfs-rt-2,gtfs_rt,vp,,,,gtfs-rt-2 Some fake company,gtfs-rt-2 Some fake name,gtfs-rt-2 Some fake note,gtfs-rt-2_some_fake_email@fake.com,,https://gtfs-rt-2_some_fake_producer_url,2,https://gtfs-rt-2_some_fake_authentication_info_url,gtfs-rt-2_fake_api_key_parameter_name,,https://gtfs-rt-2_some_fake_license_url,,,,,,,,,
""" # noqa
Expand Down
4 changes: 2 additions & 2 deletions functions-python/helpers/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import logging
import os
import threading
from typing import Optional
from typing import Optional, ContextManager

from sqlalchemy import create_engine, text, event, Engine
from sqlalchemy.orm import sessionmaker, Session, mapper, class_mapper
Expand Down Expand Up @@ -159,7 +159,7 @@ def _get_session(self, echo: bool) -> "sessionmaker[Session]":
return self._Sessions[echo]

@contextmanager
def start_db_session(self, echo: bool = True):
def start_db_session(self, echo: bool = True) -> ContextManager[Session]:
"""
Context manager to start a database session with optional echo.

Expand Down
Loading