Skip to content

Commit

Permalink
Fix tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
elliotwutingfeng committed Dec 21, 2023
1 parent a138a7f commit 20cee0a
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 18 deletions.
38 changes: 23 additions & 15 deletions test_train_arrival.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,22 @@
)


def _verify_arrival_time_response(d):
assert "results" in d
assert isinstance(d["results"], list)
assert d["results"]
first_result = d["results"][0]
assert isinstance(first_result, dict)
def _verify_arrival_time_response(
arrival_time_response, expected_station_name, expected_station_code=None
):
assert "results" in arrival_time_response
results = arrival_time_response["results"]
assert isinstance(results, list)
if not results:
return
assert all(isinstance(result, dict) for result in results)

station_names = set(result.get("mrt", "") for result in results) - set([""])
station_codes = set(result.get("code", "") for result in results) - set([""])
assert len(station_codes) == 1
assert expected_station_code is None or expected_station_code in station_codes
assert len(station_names) == 1
assert expected_station_name in station_names


def test_get_all_station_names():
Expand All @@ -25,17 +35,15 @@ def test_get_train_arrival_time_by_id():
expected_station_name, expected_station_code = test_case[0], test_case[1]
res = get_train_arrival_time_by_id(expected_station_name)
arrival_time_response = json.loads(res)
_verify_arrival_time_response(arrival_time_response)
first_result = arrival_time_response["results"][0]
assert first_result.get("code", "") == expected_station_code
assert first_result.get("mrt", expected_station_name)
_verify_arrival_time_response(
arrival_time_response, expected_station_name, expected_station_code
)


def test_get_all_train_arrival_time():
res = get_all_train_arrival_time(3)
limit = 5
res = get_all_train_arrival_time(limit)
arrival_time_responses = json.loads(res)
assert len(arrival_time_responses) == 3
assert len(arrival_time_responses) == limit
for station_name, arrival_time_response in arrival_time_responses.items():
_verify_arrival_time_response(arrival_time_response)
first_result = arrival_time_response["results"][0]
assert first_result.get("mrt", station_name)
_verify_arrival_time_response(arrival_time_response, station_name)
6 changes: 3 additions & 3 deletions train_arrival.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def get_train_arrival_time_by_id(station_name):
"""
params = {"station": station_name}

max_attempts = 5
max_attempts = 3

for attempt in range(max_attempts):
if attempt:
Expand All @@ -131,9 +131,9 @@ def get_train_arrival_time_by_id(station_name):
d = json.loads(data)
if not d.get("results", []):
continue
mrt_name = d["results"][0].get("mrt", "")
mrt_names = set(result.get("mrt", "") for result in d["results"]) - set([""])
if (
mrt_name != station_name
len(mrt_names) != 1 or station_name not in mrt_names
): # Ensure that the 'mrt' field matches station name.
continue
return data # Output guaranteed to be valid JSON.
Expand Down

0 comments on commit 20cee0a

Please sign in to comment.