Skip to content

Commit

Permalink
Merge pull request #10 from bobertoyin/9-loosen-type-for-query-parame…
Browse files Browse the repository at this point in the history
…ters

feat: changing query params to slice of key-value pairs
  • Loading branch information
bobertoyin authored May 26, 2022
2 parents ff5cba5 + b0a0f35 commit 40e98f6
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 74 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "mbta-rs"
version = "0.2.1"
version = "0.3.0"
edition = "2021"
authors = ["Robert Yin <bobertoyin@gmail.com>"]
description = "Simple Rust client for interacting with the MBTA V3 API."
Expand Down
18 changes: 9 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,19 +55,19 @@ serde_json = "*"

Simple example usage:
```rust
use std::{collections::HashMap, env};
use std::env;
use mbta_rs::Client;

let client = match env::var("MBTA_TOKEN") {
Ok(token) => Client::with_key(token),
Err(_) => Client::without_key()
};

let query_params = HashMap::from([
("page[limit]".to_string(), "3".to_string())
]);
let query_params = [
("page[limit]", "3")
];

let alerts_response = client.alerts(query_params);
let alerts_response = client.alerts(&query_params);
if let Ok(response) = alerts_response {
for alert in response.data {
println!("MBTA alert: {}", alert.attributes.header);
Expand Down Expand Up @@ -99,7 +99,7 @@ let client = match env::var("MBTA_TOKEN") {
Err(_) => Client::without_key()
};

let routes = client.routes(HashMap::from([("filter[type]".into(), "0,1".into())])).expect("failed to get routes");
let routes = client.routes(&[("filter[type]", "0,1")]).expect("failed to get routes");
let mut map = StaticMapBuilder::new()
.width(1000)
.height(1000)
Expand All @@ -110,17 +110,17 @@ let mut map = StaticMapBuilder::new()
.expect("failed to build map");

for route in routes.data {
let query = HashMap::from([("filter[route]".into(), route.id)]);
let query_params = [("filter[route]", &route.id)];
let shapes = client
.shapes(query.clone())
.shapes(&query_params)
.expect("failed to get shapes");
for shape in shapes.data {
shape
.plot(&mut map, true, PlotStyle::new((route.attributes.color.clone(), 3.0), Some(("#FFFFFF".into(), 1.0))))
.expect("failed to plot shape");
}
let stops = client
.stops(query.clone())
.stops(&query_params)
.expect("failed to get stops");
for stop in stops.data {
stop.plot(
Expand Down
44 changes: 24 additions & 20 deletions src/client.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! The client for interacting with the V3 API.
use std::collections::{HashMap, HashSet};
use std::collections::HashSet;

use serde::de::DeserializeOwned;

Expand All @@ -27,34 +27,34 @@ macro_rules! mbta_endpoint_multiple {
///
/// # Arguments
///
/// * `query_params` - a [HashMap] of query parameter names to values
/// * `query_params` - a slice of pairings of query parameter names to values
///
/// ```
/// # use std::{collections::HashMap, env};
/// # use std::env;
/// # use mbta_rs::Client;
/// #
/// # let client = match env::var("MBTA_TOKEN") {
/// # Ok(token) => Client::with_key(token),
/// # Err(_) => Client::without_key()
/// # };
/// #
/// # let query_params = HashMap::from([
/// # ("page[limit]".to_string(), "3".to_string())
/// # ]);
#[doc = concat!("let ", stringify!($func), "_response = client.", stringify!($func), "(query_params);\n")]
/// # let query_params = [
/// # ("page[limit]", "3")
/// # ];
#[doc = concat!("let ", stringify!($func), "_response = client.", stringify!($func), "(&query_params);\n")]
#[doc = concat!("if let Ok(", stringify!($func), ") = ", stringify!($func), "_response {\n")]
#[doc = concat!(" for item in ", stringify!($func), ".data {\n")]
/// println!("{}", item.id);
/// }
/// }
/// ```
pub fn $func(&self, query_params: HashMap<String, String>) -> Result<Response<$model>, ClientError> {
pub fn $func<K: AsRef<str>, V: AsRef<str>>(&self, query_params: &[(K, V)]) -> Result<Response<$model>, ClientError> {
let allowed_query_params: HashSet<String> = $allowed_query_params.into_iter().map(|s: &str| s.to_string()).collect();
for (k, v) in &query_params {
if !allowed_query_params.contains(&k.to_string()) {
for (k, v) in query_params {
if !allowed_query_params.contains(k.as_ref()) {
return Err(ClientError::InvalidQueryParam {
name: k.to_string(),
value: v.to_string(),
name: k.as_ref().to_string(),
value: v.as_ref().to_string(),
});
}
}
Expand All @@ -76,7 +76,7 @@ macro_rules! mbta_endpoint_single {
#[doc = concat!("* `id` - the id of the ", stringify!($func), " to return")]
///
/// ```
/// # use std::{collections::HashMap, env};
/// # use std::env;
/// # use mbta_rs::Client;
/// #
/// # let client = match env::var("MBTA_TOKEN") {
Expand All @@ -91,7 +91,7 @@ macro_rules! mbta_endpoint_single {
/// }
/// ```
pub fn $func(&self, id: &str) -> Result<Response<$model>, ClientError> {
self.get(&format!("{}/{}", $endpoint, id), HashMap::new())
self.get::<$model, String, String>(&format!("{}/{}", $endpoint, id), &[])
}
}
};
Expand Down Expand Up @@ -279,7 +279,7 @@ pub struct Client {
impl Client {
/// Create a [Client] without an API key.
///
/// "Without an api key in the query string or as a request header, requests will be tracked by IP address and have stricter rate limit."
/// > "Without an api key in the query string or as a request header, requests will be tracked by IP address and have stricter rate limit." - Massachusetts Bay Transportation Authority
pub fn without_key() -> Self {
Self {
api_key: None,
Expand Down Expand Up @@ -312,20 +312,24 @@ impl Client {
}
}

/// Helper method for making generalized GET requests to any endpoint with any query parameters.
/// Presumes that all query parameters given in the [HashMap] are valid.
/// Helper method for making generalized `GET` requests to any endpoint with any query parameters.
/// Presumes that all query parameters given are valid.
///
/// # Arguments
///
/// * query_params - a [HashMap] of query parameter names to values
fn get<T: DeserializeOwned>(&self, endpoint: &str, query_params: HashMap<String, String>) -> Result<Response<T>, ClientError> {
/// * query_params - a slice of pairings of query parameter names to values
fn get<T: DeserializeOwned, K: AsRef<str>, V: AsRef<str>>(
&self,
endpoint: &str,
query_params: &[(K, V)],
) -> Result<Response<T>, ClientError> {
let path = format!("{}/{}", self.base_url, endpoint);
let request = ureq::get(&path);
let request = match &self.api_key {
Some(key) => request.set("x-api-key", key),
None => request,
};
let request = query_params.iter().fold(request, |r, (k, v)| r.query(k, v));
let request = query_params.iter().fold(request, |r, (k, v)| r.query(k.as_ref(), v.as_ref()));
let response: Response<T> = request.call()?.into_json()?;
Ok(response)
}
Expand Down
13 changes: 7 additions & 6 deletions tests/map.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
//! Simple tests for tile map plotting.
use std::{collections::HashMap, fs::remove_file, path::PathBuf};
use std::{fs::remove_file, path::PathBuf};

use mbta_rs::{map::*, *};
use raster::{compare::similar, open};
Expand All @@ -24,7 +24,8 @@ fn image_file(relative_path: &str) -> PathBuf {
#[rstest]
fn test_simple_map_render(client: Client) {
// Arrange
let routes = client.routes(HashMap::from([("filter[type]".into(), "0,1".into())])).expect("failed to get routes");
let route_params = [("filter[type]", "0,1")];
let routes = client.routes(&route_params).expect("failed to get routes");
let mut map = StaticMapBuilder::new()
.width(1000)
.height(1000)
Expand All @@ -40,19 +41,19 @@ fn test_simple_map_render(client: Client) {

// Act
for route in routes.data {
let query = HashMap::from([("filter[route]".into(), route.id)]);
let shapes = client.shapes(query.clone()).expect("failed to get shapes");
let params = [("filter[route]", &route.id)];
let shapes = client.shapes(&params).expect("failed to get shapes");
for shape in shapes.data {
shape
.plot(&mut map, true, PlotStyle::new((route.attributes.color.clone(), 3.0), Some(("#FFFFFF".into(), 1.0))))
.expect("failed to plot shape");
}
let stops = client.stops(query.clone()).expect("failed to get stops");
let stops = client.stops(&params).expect("failed to get stops");
for stop in stops.data {
stop.plot(&mut map, true, PlotStyle::new((route.attributes.color.clone(), 3.0), Some(("#FFFFFF".into(), 1.0))))
.expect("failed to plot stop");
}
let vehicles = client.vehicles(query).expect("failed to get vehicles");
let vehicles = client.vehicles(&params).expect("failed to get vehicles");
for vehicle in vehicles.data {
vehicle
.plot(&mut map, true, IconStyle::new(image_file("train.png"), 12.5, 12.5))
Expand Down
22 changes: 8 additions & 14 deletions tests/simple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ macro_rules! test_endpoint_plural_and_singular {
(plural_func=$plural_func:ident, singular_func=$singular_func:ident) => {
#[cfg(test)]
mod $plural_func {
use std::collections::HashMap;

use rstest::*;

use mbta_rs::*;
Expand All @@ -24,11 +22,10 @@ macro_rules! test_endpoint_plural_and_singular {
#[rstest]
fn success_plural_models(client: Client) {
// Arrange
let params = [("page[limit]", "3")];

// Act
let $plural_func = client
.$plural_func(HashMap::from([("page[limit]".into(), "3".into())]))
.expect(&format!("failed to get {}", stringify!($plural_func)));
let $plural_func = client.$plural_func(&params).expect(&format!("failed to get {}", stringify!($plural_func)));

// Assert
assert_eq!($plural_func.data.len(), 3);
Expand All @@ -39,11 +36,10 @@ macro_rules! test_endpoint_plural_and_singular {
#[rstest]
fn request_failure_plural_models(client: Client) {
// Arrange
let params = [("sort", "foobar")];

// Act
let error = client
.$plural_func(HashMap::from([("sort".into(), "foobar".into())]))
.expect_err(&format!("{} did not fail", stringify!($plural_func)));
let error = client.$plural_func(&params).expect_err(&format!("{} did not fail", stringify!($plural_func)));

// Assert
if let ClientError::ResponseError { errors } = error {
Expand All @@ -56,11 +52,10 @@ macro_rules! test_endpoint_plural_and_singular {
#[rstest]
fn query_param_failure_plural_models(client: Client) {
// Arrange
let params = [("foo", "bar")];

// Act
let error = client
.$plural_func(HashMap::from([("foo".into(), "bar".into())]))
.expect_err(&format!("{} did not fail", stringify!($plural_func)));
let error = client.$plural_func(&params).expect_err(&format!("{} did not fail", stringify!($plural_func)));

// Assert
if let ClientError::InvalidQueryParam { name, value } = error {
Expand All @@ -74,9 +69,8 @@ macro_rules! test_endpoint_plural_and_singular {
#[rstest]
fn success_singular_model(client: Client) {
// Arrange
let $plural_func = client
.$plural_func(HashMap::from([("page[limit]".into(), "3".into())]))
.expect(&format!("failed to get {}", stringify!($plural_func)));
let params = [("page[limit]", "3")];
let $plural_func = client.$plural_func(&params).expect(&format!("failed to get {}", stringify!($plural_func)));

// Act & Assert
for $singular_func in $plural_func.data {
Expand Down
36 changes: 12 additions & 24 deletions tests/with_route.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ macro_rules! test_endpoint_plural_with_route {
(plural_func=$plural_func:ident) => {
#[cfg(test)]
mod $plural_func {
use std::collections::HashMap;

use rstest::*;

use mbta_rs::*;
Expand All @@ -25,15 +23,12 @@ macro_rules! test_endpoint_plural_with_route {
#[rstest]
fn success_plural_models(client: Client) {
// Arrange
let routes = client.routes(HashMap::from([("page[limit]".into(), "1".into())])).expect("failed to get routes");
let route_params = [("page[limit]", "1")];
let routes = client.routes(&route_params).expect("failed to get routes");
let params = [("page[limit]", "3"), ("filter[route]", &routes.data[0].id.clone())];

// Act
let $plural_func = client
.$plural_func(HashMap::from([
("page[limit]".into(), "3".into()),
("filter[route]".into(), routes.data[0].id.clone()),
]))
.expect(&format!("failed to get {}", stringify!($plural_func)));
let $plural_func = client.$plural_func(&params).expect(&format!("failed to get {}", stringify!($plural_func)));

// Assert
assert_eq!($plural_func.data.len(), 3);
Expand All @@ -44,11 +39,10 @@ macro_rules! test_endpoint_plural_with_route {
#[rstest]
fn request_failure_plural_models(client: Client) {
// Arrange
let params = [("sort", "foobar")];

// Act
let error = client
.$plural_func(HashMap::from([("sort".into(), "foobar".into())]))
.expect_err(&format!("{} did not fail", stringify!($plural_func)));
let error = client.$plural_func(&params).expect_err(&format!("{} did not fail", stringify!($plural_func)));

// Assert
if let ClientError::ResponseError { errors } = error {
Expand All @@ -61,11 +55,10 @@ macro_rules! test_endpoint_plural_with_route {
#[rstest]
fn query_param_failure_plural_models(client: Client) {
// Arrange
let params = [("foo", "bar")];

// Act
let error = client
.$plural_func(HashMap::from([("foo".into(), "bar".into())]))
.expect_err(&format!("{} did not fail", stringify!($plural_func)));
let error = client.$plural_func(&params).expect_err(&format!("{} did not fail", stringify!($plural_func)));

// Assert
if let ClientError::InvalidQueryParam { name, value } = error {
Expand All @@ -86,8 +79,6 @@ macro_rules! test_endpoint_singular_with_route {
(plural_func=$plural_func:ident, singular_func=$singular_func:ident) => {
#[cfg(test)]
mod $singular_func {
use std::collections::HashMap;

use rstest::*;

use mbta_rs::*;
Expand All @@ -104,13 +95,10 @@ macro_rules! test_endpoint_singular_with_route {
#[rstest]
fn success_singular_model(client: Client) {
// Arrange
let routes = client.routes(HashMap::from([("page[limit]".into(), "1".into())])).expect("failed to get routes");
let $plural_func = client
.$plural_func(HashMap::from([
("page[limit]".into(), "3".into()),
("filter[route]".into(), routes.data[0].id.clone()),
]))
.expect(&format!("failed to get {}", stringify!($plural_func)));
let route_params = [("page[limit]", "1")];
let routes = client.routes(&route_params).expect("failed to get routes");
let params = [("page[limit]", "3"), ("filter[route]", &routes.data[0].id.clone())];
let $plural_func = client.$plural_func(&params).expect(&format!("failed to get {}", stringify!($plural_func)));

// Act & Assert
for $singular_func in $plural_func.data {
Expand Down

0 comments on commit 40e98f6

Please sign in to comment.