diff --git a/app/filters.py b/app/filters.py index 7b30d76..38fd83c 100644 --- a/app/filters.py +++ b/app/filters.py @@ -37,6 +37,13 @@ def apply(self, query): column_to_filter = getattr(models.School, self.key) return query.filter(column_to_filter.in_(self.values)) +class TextMatchFilter(Filter): + supported_keys = ['name'] + + def apply(self, query): + column_to_filter = getattr(models.School, self.key) + return query.filter(column_to_filter.icontains(self.values)) + class UpdateTimestampFilter(Filter): supported_keys = ['update_timestamp'] @@ -67,7 +74,13 @@ def apply(self, query): class SchoolFilter: - filter_classes = [StateFilter, BasicFilter, LatLonSorter, BoundingBoxFilter, UpdateTimestampFilter] + filter_classes = [StateFilter, + BasicFilter, + LatLonSorter, + BoundingBoxFilter, + UpdateTimestampFilter, + TextMatchFilter + ] def __init__(self, params): self.used_filters = [] diff --git a/app/main.py b/app/main.py index 19fb4a5..cc224e6 100644 --- a/app/main.py +++ b/app/main.py @@ -36,6 +36,9 @@ def read_schools(skip: int = 0, state: Optional[List[State]] = Query(None), school_type: Optional[List[str]] = Query(None), legal_status: Optional[List[str]] = Query(None), + name: Optional[str] = Query(None, + description="Allows searching for names of schools." + "Searches for case-insensitive substrings."), by_lat: Optional[float] = Query(None, description="Allows ordering result by distance from a geographical point." "Must be used in combination with `by_lon`" @@ -66,6 +69,7 @@ def read_schools(skip: int = 0, "school_type": school_type, "legal_status": legal_status, "update_timestamp": update_timestamp, + "name": name, } if by_lat or by_lon: if not (by_lon and by_lat): diff --git a/test/test_api.py b/test/test_api.py index 749089e..00eaf63 100644 --- a/test/test_api.py +++ b/test/test_api.py @@ -347,6 +347,29 @@ def test_schools_by_update_date(self, client, db): assert response.status_code == 200 assert len(response.json()) == 1 + def test_schools_by_name(self, client, db): + # Arrange + for school in [ + SchoolFactory.create(name='Schule am kleinen Deich'), + SchoolFactory.create(name='Schule an der Dorfstraßse'), + ]: + db.add(school) + db.commit() + + # Act + unfiltered_response = client.get("/schools") + deich_response = client.get("/schools?name=deich") + no_match_response = client.get("/schools?name=nicht%20da") + + # Assert + assert unfiltered_response.status_code == 200 + assert deich_response.status_code == 200 + assert no_match_response.status_code == 200 + + assert len(unfiltered_response.json()) == 2 + assert len(deich_response.json()) == 1 + assert len(no_match_response.json()) == 0 + def test_get_single_no_result(self, client, db): # Arrange self.__setup_schools(db)