From 9b4695dd2ce52b03ae41dbb28239d13c8446f39c Mon Sep 17 00:00:00 2001 From: "Samuel F. Potter" Date: Fri, 17 Nov 2023 22:08:25 -0800 Subject: [PATCH] make utri_solve a little more robust --- include/jmm/utri.h | 2 +- src/eik3.c | 3 ++- src/utri.c | 29 +++++++++++++++++++---------- 3 files changed, 22 insertions(+), 12 deletions(-) diff --git a/include/jmm/utri.h b/include/jmm/utri.h index 4129c2f..48aaa5a 100644 --- a/include/jmm/utri.h +++ b/include/jmm/utri.h @@ -9,7 +9,7 @@ typedef struct utri utri_s; void utri_alloc(utri_s **utri); void utri_dealloc(utri_s **utri); void utri_init(utri_s *u, eik3_s const *eik, size_t lhat, size_t const l[2]); -void utri_solve(utri_s *utri); +bool utri_solve(utri_s *utri); par3_s utri_get_par(utri_s const *u); dbl utri_get_value(utri_s const *utri); void utri_get_jet31t(utri_s const *utri, jet31t *jet); diff --git a/src/eik3.c b/src/eik3.c index c52cc68..9038830 100644 --- a/src/eik3.c +++ b/src/eik3.c @@ -308,7 +308,8 @@ do_utri(eik3_s *eik, size_t l, size_t l0, size_t l1, utri_cache_s *utri_cache, if (utri_is_degenerate(utri)) goto cleanup; - utri_solve(utri); + if (!utri_solve(utri)) + goto cleanup; if (par != NULL && utri_get_value(utri) < eik->jet[l].f) *par = utri_get_par(utri); diff --git a/src/utri.c b/src/utri.c index 3183bd8..7648600 100644 --- a/src/utri.c +++ b/src/utri.c @@ -195,27 +195,36 @@ static dbl hybrid_Dp(dbl lam, cubic_s const *p) { return cubic_df(p, lam); } -void utri_solve(utri_s *utri) { +bool utri_solve(utri_s *utri) { static int CALL_NUMBER = 0; + ++CALL_NUMBER; + if (utri->stype == STYPE_CONSTANT) { - dbl lam, f[2]; + dbl lam, f[2], Df[2]; if (hybrid((hybrid_cost_func_t)hybrid_f, 0, 1, utri, &lam)) - return; + return true; set_lambda(utri, 0); f[0] = utri->f; + Df[0] = utri->Df; set_lambda(utri, 1); f[1] = utri->f; + Df[1] = utri->Df; - assert(f[0] != f[1]); - - if (f[0] < f[1]) + if (f[0] < f[1] && Df[0] >= 0) { set_lambda(utri, 0); - else + return true; + } + + if (f[1] < f[0] && Df[1] <= 0) { set_lambda(utri, 1); + return true; + } + + return false; } else if (utri->stype == STYPE_FUNC_PTR) { @@ -315,11 +324,11 @@ void utri_solve(utri_s *utri) { uline_get_topt(uline, utri->topt); uline_dealloc(&uline); - } - else assert(false); + return true; + } - ++CALL_NUMBER; + assert(false); } static void get_update_inds(utri_s const *utri, size_t l[2]) {