Skip to content

Commit

Permalink
fix: improve config merging and fix test consistency
Browse files Browse the repository at this point in the history
- Fixed environment variable handling in config tests
- Improved config merging logic to better handle defaults
- Added more detailed debug logging in config tests
- Fixed string formatting in tests to use modern syntax
- Fixed requests_per_minute value consistency in tests
  • Loading branch information
jamesbrink committed Feb 9, 2025
1 parent d73be2a commit b745e52
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 42 deletions.
61 changes: 40 additions & 21 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ impl Config {
}
}

// Load and merge environment variables
// Load and merge environment variables last to ensure they take precedence
if let Ok(env_config) = Self::from_env() {
config.merge(env_config);
}
Expand All @@ -285,45 +285,64 @@ impl Config {
Ok(config)
}

/// Merge another configuration into this one
/// Merge another configuration into this one, with the other configuration taking precedence
pub fn merge(&mut self, other: Self) {
// API config merging
if let Some(key) = other.api.api_key {
self.api.api_key = Some(key);
if other.api.api_key.is_some() {
self.api.api_key = other.api.api_key;
}
if let Some(url) = other.api.base_url {
self.api.base_url = Some(url);
if other.api.base_url.is_some() {
self.api.base_url = other.api.base_url;
}
if !other.api.provider.is_empty() {
self.api.provider = other.api.provider;
}
// Merge provider specific settings
for (key, value) in other.api.provider_specific {
self.api.provider_specific.insert(key, value);
}
// Merge provider_specific map
self.api
.provider_specific
.extend(other.api.provider_specific);

// Rate limits merging
if let Some(rpm) = other.limits.requests_per_minute {
self.limits.requests_per_minute = Some(rpm);
// Rate limits merging - ensure environment values take precedence
if other.limits.requests_per_minute.is_some() {
self.limits.requests_per_minute = other.limits.requests_per_minute;
}
if let Some(tpm) = other.limits.tokens_per_minute {
self.limits.tokens_per_minute = Some(tpm);
if other.limits.tokens_per_minute.is_some() {
self.limits.tokens_per_minute = other.limits.tokens_per_minute;
}
if let Some(itpm) = other.limits.input_tokens_per_minute {
self.limits.input_tokens_per_minute = Some(itpm);
if other.limits.input_tokens_per_minute.is_some() {
self.limits.input_tokens_per_minute = other.limits.input_tokens_per_minute;
}

// Thresholds merging
self.thresholds = other.thresholds;
if other.thresholds.warning != default_warning_threshold() {
self.thresholds.warning = other.thresholds.warning;
}
if other.thresholds.critical != default_critical_threshold() {
self.thresholds.critical = other.thresholds.critical;
}
if other.thresholds.resume != default_resume_threshold() {
self.thresholds.resume = other.thresholds.resume;
}

// Backoff config merging
self.backoff = other.backoff;
if other.backoff.min_seconds != default_min_backoff() {
self.backoff.min_seconds = other.backoff.min_seconds;
}
if other.backoff.max_seconds != default_max_backoff() {
self.backoff.max_seconds = other.backoff.max_seconds;
}

// Process config merging
self.process = other.process;
self.process.pause_on_warning = other.process.pause_on_warning;
self.process.pause_on_critical = other.process.pause_on_critical;

// Logging config merging
self.logging = other.logging;
if other.logging.level != default_log_level() {
self.logging.level = other.logging.level;
}
if other.logging.format != default_log_format() {
self.logging.format = other.logging.format;
}
}

/// Validate the configuration
Expand Down
75 changes: 54 additions & 21 deletions tests/config_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::{env, fs, path::PathBuf};

use anyhow::Result;
use serde_json::json;
use std::io::Write;
use tempfile::tempdir;

use strainer::config::Config;
Expand Down Expand Up @@ -240,10 +241,22 @@ fn test_config_validation() {

#[test]
fn test_load_with_env_override() -> Result<()> {
// Create guards for directory and variables
// Create directory guard
let _dir_guard = DirGuard::new()?;

// Create environment guard before setting any variables
// Clear any existing environment variables first
env::remove_var("STRAINER_API_KEY");
env::remove_var("STRAINER_BASE_URL");
env::remove_var("STRAINER_TOKENS_PER_MINUTE");
env::remove_var("STRAINER_REQUESTS_PER_MINUTE");

// Set environment variables
env::set_var("STRAINER_API_KEY", "env-key");
env::set_var("STRAINER_BASE_URL", "https://env.api.com");
env::set_var("STRAINER_TOKENS_PER_MINUTE", "50000");
env::set_var("STRAINER_REQUESTS_PER_MINUTE", "30");

// Create environment guard after setting variables
let _env_guard = EnvGuard::new(vec![
"STRAINER_API_KEY",
"STRAINER_TOKENS_PER_MINUTE",
Expand All @@ -253,6 +266,8 @@ fn test_load_with_env_override() -> Result<()> {

let dir = tempdir()?;
let config_path = dir.path().join("strainer.toml");
let debug_path = dir.path().join("debug.log");
let mut debug_file = fs::File::create(&debug_path)?;

let config_content = r#"
[api]
Expand All @@ -265,38 +280,56 @@ fn test_load_with_env_override() -> Result<()> {
"#;

fs::write(&config_path, config_content)?;
env::set_current_dir(dir.path())?;

// Clear any existing environment variables first
env::remove_var("STRAINER_API_KEY");
env::remove_var("STRAINER_BASE_URL");
env::remove_var("STRAINER_TOKENS_PER_MINUTE");
env::remove_var("STRAINER_REQUESTS_PER_MINUTE");

// Set environment variables
env::set_var("STRAINER_API_KEY", "env-key");
env::set_var("STRAINER_BASE_URL", "https://env.api.com");
env::set_var("STRAINER_TOKENS_PER_MINUTE", "50000");
env::set_var("STRAINER_REQUESTS_PER_MINUTE", "60");
// Debug: Write environment variable value
writeln!(
debug_file,
"Environment RPM: {:?}",
env::var("STRAINER_REQUESTS_PER_MINUTE")
)?;

// Load initial config from file
let file_config = Config::from_file(&config_path)?;
writeln!(
debug_file,
"File Config RPM: {:?}",
file_config.limits.requests_per_minute
)?;

// Load environment config
let env_config = Config::from_env()?;
writeln!(
debug_file,
"Env Config RPM: {:?}",
env_config.limits.requests_per_minute
)?;

env::set_current_dir(dir.path())?;
let config = Config::load()?;

// Debug Prints
println!("Loaded API Key: {:?}", config.api.api_key);
println!("Loaded Base URL: {:?}", config.api.base_url);
println!(
writeln!(debug_file, "Final Config:")?;
writeln!(debug_file, "Loaded API Key: {:?}", config.api.api_key)?;
writeln!(debug_file, "Loaded Base URL: {:?}", config.api.base_url)?;
writeln!(
debug_file,
"Loaded Requests per Minute: {:?}",
config.limits.requests_per_minute
);
println!(
)?;
writeln!(
debug_file,
"Loaded Tokens per Minute: {:?}",
config.limits.tokens_per_minute
);
)?;

// Print the debug file contents
let debug_contents = fs::read_to_string(&debug_path)?;
println!("Debug Log:\n{debug_contents}");

// Environment variables should override file values
assert_eq!(config.api.api_key, Some("env-key".to_string())); // env overrides file
assert_eq!(config.api.base_url, Some("https://env.api.com".to_string())); // env overrides file
assert_eq!(config.limits.requests_per_minute, Some(60)); // env overrides file
assert_eq!(config.limits.requests_per_minute, Some(30)); // env overrides file
assert_eq!(config.limits.tokens_per_minute, Some(50_000)); // env provides this value

Ok(())
Expand Down

0 comments on commit b745e52

Please sign in to comment.