diff --git a/src/cortex-cli/src/stats_cmd.rs b/src/cortex-cli/src/stats_cmd.rs index 1e407503..3728bc5c 100644 --- a/src/cortex-cli/src/stats_cmd.rs +++ b/src/cortex-cli/src/stats_cmd.rs @@ -192,9 +192,11 @@ impl StatsCli { /// Get the cortex home directory. fn get_cortex_home() -> PathBuf { - dirs::home_dir() - .map(|h| h.join(".cortex")) - .unwrap_or_else(|| PathBuf::from(".cortex")) + cortex_common::get_cortex_home().unwrap_or_else(|| { + dirs::home_dir() + .map(|h| h.join(".cortex")) + .unwrap_or_else(|| PathBuf::from(".cortex")) + }) } /// Get pricing for a model. @@ -735,6 +737,21 @@ mod tests { assert!((cost - 12.5).abs() < 0.001); } + #[test] + fn test_get_cortex_home_respects_cortex_home() { + let previous = std::env::var_os("CORTEX_HOME"); + let expected = std::env::temp_dir().join("cortex-stats-test-home"); + + unsafe { std::env::set_var("CORTEX_HOME", &expected) }; + assert_eq!(get_cortex_home(), expected); + + if let Some(value) = previous { + unsafe { std::env::set_var("CORTEX_HOME", value) }; + } else { + unsafe { std::env::remove_var("CORTEX_HOME") }; + } + } + #[test] fn test_validate_days_range() { // Valid values