diff --git a/crates/ott-common/src/discovery/dns.rs b/crates/ott-common/src/discovery/dns.rs index f675d3095..eb5797c3f 100644 --- a/crates/ott-common/src/discovery/dns.rs +++ b/crates/ott-common/src/discovery/dns.rs @@ -1,8 +1,11 @@ +use std::str::FromStr; + use async_trait::async_trait; use hickory_resolver::{ config::{NameServerConfig, Protocol, ResolverConfig, ResolverOpts}, TokioAsyncResolver, }; +use serde::Deserializer; use tracing::info; use super::*; @@ -12,6 +15,8 @@ pub struct DnsDiscoveryConfig { /// The port that monoliths should be listening on for load balancer connections. pub service_port: u16, /// The DNS server to query. Optional. If not provided, the system configuration will be used instead. + #[serde(deserialize_with = "deserialize_dns_server")] + #[serde(default)] pub dns_server: Option, /// The A record to query. If using docker-compose, this should be the service name for the monolith. pub query: String, @@ -21,6 +26,21 @@ pub struct DnsDiscoveryConfig { pub polling_interval: Option, } +fn deserialize_dns_server<'de, D>(deserializer: D) -> Result, D::Error> +where + D: Deserializer<'de>, +{ + let buf = String::deserialize(deserializer)?; + + match IpAddr::from_str(&buf) { + Ok(ip) => Ok(Some(SocketAddr::new(ip, 53))), + Err(_) => match SocketAddr::from_str(&buf) { + Ok(socket) => Ok(Some(socket)), + Err(e) => Err(serde::de::Error::custom(e)), + }, + } +} + pub struct DnsServiceDiscoverer { config: DnsDiscoveryConfig, } @@ -76,7 +96,7 @@ mod test { use serde_json::json; #[test] - fn server_deserializes_correctly() { + fn dns_server_deserializes_correctly() { let json = json!({ "service_port": 8080, "dns_server": "127.0.0.1:100", @@ -88,4 +108,43 @@ mod test { assert_eq!(config.dns_server, Some(([127, 0, 0, 1], 100).into())); } + + #[test] + fn dns_server_deserialization_defaults_to_port_53() { + let json = json!({ + "service_port": 8080, + "dns_server": "127.0.0.1", + "query": "".to_string(), + }); + + let config: DnsDiscoveryConfig = + serde_json::from_value(json).expect("Failed to deserialize json"); + + assert_eq!(config.dns_server, Some(([127, 0, 0, 1], 53).into())); + } + + #[test] + fn dns_server_is_optional() { + let json = json!({ + "service_port": 8080, + "query": "".to_string(), + }); + + let config: DnsDiscoveryConfig = + serde_json::from_value(json).expect("Failed to deserialize json"); + + assert!(config.dns_server.is_none()) + } + + #[test] + fn dns_server_failed_deserialization_throws_error() { + let json = json!({ + "not": "valid", + "DnsDiscoveryConfig": true, + }); + + let config: Result = serde_json::from_value(json); + + assert!(config.is_err()) + } }