Skip to content

Commit

Permalink
Added escape from api test
Browse files Browse the repository at this point in the history
Doesn't decode train services if API key test accepted
  • Loading branch information
plutomedia987 committed Jan 19, 2025
1 parent 6d70905 commit 904a9c7
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 28 deletions.
66 changes: 42 additions & 24 deletions custom_components/nationalrailuk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def rebuild_date(base, time):
class NationalRailClient:
"""Client for the National Rail API"""

def __init__(self, api_token, station, destinations) -> None:
def __init__(self, api_token, station, destinations, apiTest=False) -> None:
self.station = station
self.api_token = api_token
self.destinations = destinations if destinations is not None else []
Expand Down Expand Up @@ -80,6 +80,8 @@ def __init__(self, api_token, station, destinations) -> None:

self.client = AsyncClient(wsdl=WSDL, settings=settings, plugins=[self.history])

self.apitest = apiTest

# Prepackage the authorisation token
header = xsd.Element(
"{http://thalesgroup.com/RTTI/2013-11-28/Token/types}AccessToken",
Expand Down Expand Up @@ -115,20 +117,32 @@ async def get_raw_arrivals_departures(self):
_soapheaders=[self.header_value],
)

# Build header info
if not res[each]["generatedAt"]:
res[each]["generatedAt"] = batch["generatedAt"]
res[each]["locationName"] = batch["locationName"]
res[each]["crs"] = batch["crs"]
res[each]["filterLocationName"] = batch["filterLocationName"]
res[each]["filtercrs"] = batch["filtercrs"]

if not res[each][ft["keyName"]]:
res[each][ft["keyName"]] = batch["trainServices"]["service"]
else:
res[each][ft["keyName"]].append(
batch["trainServices"]["service"]
)
if not self.apitest:
# try:
# Build header info
if not res[each]["generatedAt"]:
res[each]["generatedAt"] = batch["generatedAt"]
res[each]["locationName"] = batch["locationName"]
res[each]["crs"] = batch["crs"]
res[each]["filterLocationName"] = batch[
"filterLocationName"
]
res[each]["filtercrs"] = batch["filtercrs"]

# print(batch)

if not res[each][ft["keyName"]]:
res[each][ft["keyName"]] = batch["trainServices"]["service"]
else:
res[each][ft["keyName"]].append(
batch["trainServices"]["service"]
)
# except (KeyError, TypeError) as err:
# raise NationalRailClientException(
# "No train services returned from API"
# ) from err
# except Fault as err:
# raise NationalRailClientException("Unknown error") from err

# with open("output.txt", "w") as convert_file:
# convert_file.write(str(res))
Expand Down Expand Up @@ -385,12 +399,16 @@ async def async_get_data(self):
# with open("output.txt", "w") as convert_file:
# convert_file.write(str(raw_data))

try:
_LOGGER.info("Procession station schedule for %s", self.station)
data = self.process_data(raw_data)
# with open("output.json", "w") as convert_file:
# convert_file.write(str(data))
except Exception as err:
_LOGGER.exception("Exception whilst processing data: ")
raise NationalRailClientException("unexpected data from api") from err
return data
if not self.apitest:
# print(f"API test: {self.apitest}")
try:
_LOGGER.info("Procession station schedule for %s", self.station)
data = self.process_data(raw_data)
# with open("output.json", "w") as convert_file:
# convert_file.write(str(data))
except Exception as err:
_LOGGER.exception("Exception whilst processing data: ")
raise NationalRailClientException("unexpected data from api") from err
return data

return {}
10 changes: 8 additions & 2 deletions custom_components/nationalrailuk/config_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> dict[str,
# validate the token by calling a known line

try:
my_api = NationalRailClient(data[CONF_TOKEN], "STP", ["ZFD"])
my_api = NationalRailClient(data[CONF_TOKEN], "STP", ["ZFD"], apiTest=True)
res = await my_api.async_get_data()
except NationalRailClientInvalidToken as err:
_LOGGER.exception(err)
Expand All @@ -57,9 +57,13 @@ async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> dict[str,

try:
my_api = NationalRailClient(
data[CONF_TOKEN], data[CONF_STATION], data[CONF_DESTINATIONS].split(",")
data[CONF_TOKEN],
data[CONF_STATION],
data[CONF_DESTINATIONS].split(","),
apiTest=False,
)
res = await my_api.async_get_data()
# print(res)
except NationalRailClientInvalidInput as err:
_LOGGER.exception(err)
raise InvalidInput() from err
Expand Down Expand Up @@ -115,6 +119,8 @@ async def async_step_user(
class InvalidToken(HomeAssistantError):
"""Error to indicate the Token is invalid."""

# print("Invalid token called")


class InvalidInput(HomeAssistantError):
"""Error to indicate there is invalid user input."""
10 changes: 8 additions & 2 deletions custom_components/nationalrailuk/sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __init__(self, hass, token, station, destinations):
destinations = destinations.split(",")
self.station = station
self.destinations = destinations
self.my_api = NationalRailClient(token, station, destinations)
self.my_api = NationalRailClient(token, station, destinations, apiTest=False)

self.last_data_refresh = None

Expand Down Expand Up @@ -94,6 +94,7 @@ async def _async_update_data(self):
async with async_timeout.timeout(30):
data = await self.my_api.async_get_data()
self.last_data_refresh = time.time()

# except aiohttp.ClientError as err:
# raise UpdateFailed(f"Error communicating with API: {err}") from err

Expand Down Expand Up @@ -166,14 +167,15 @@ class NationalRailSchedule(CoordinatorEntity):

attribution = "This uses National Rail Darwin Data Feeds"

def __init__(self, coordinator):
def __init__(self, coordinator: NationalRailScheduleCoordinator):
"""Pass coordinator to CoordinatorEntity."""
super().__init__(coordinator)
self.entity_id = f"sensor.{coordinator.data['name'].lower()}"

@property
def extra_state_attributes(self):
"""Return the state attributes."""
# print(self.coordinator.data)
return self.coordinator.data

@property
Expand All @@ -187,3 +189,7 @@ def state(self):
return (
self.coordinator.last_data_refresh
) # self.coordinator.data["next_train_expected"]

@property
def available(self) -> bool:
return True

0 comments on commit 904a9c7

Please sign in to comment.