diff --git a/splink/accuracy.py b/splink/accuracy.py
index d8a92a0491..cfb8aa59e1 100644
--- a/splink/accuracy.py
+++ b/splink/accuracy.py
@@ -115,28 +115,28 @@ def truth_space_table_from_labels_with_predictions_sqls(
         power(2, truth_threshold) / (1 + power(2, truth_threshold))
             as match_probability,
         row_count,
-        P as p,
-        N as n,
-        TP as tp,
-        TN as tn,
-        FP as fp,
-        FN as fn,
-        P/row_count as P_rate,
+        cast(P as float8) as p,
+        cast(N as float8) as n,
+        cast(TP as float8) as tp,
+        cast(TN as float8) as tn,
+        cast(FP as float8) as fp,
+        cast(FN as float8) as fn,
+        cast(P as float8)/row_count as P_rate,
-        cast(N as float)/row_count as N_rate,
+        cast(N as float8)/row_count as N_rate,
-        cast(TP as float)/P as tp_rate,
-        cast(TN as float)/N as tn_rate,
-        cast(FP as float)/N as fp_rate,
-        cast(FN as float)/P as fn_rate,
-        case when TP+FP=0 then 1 else cast(TP as float)/(TP+FP) end as precision,
-        cast(TP as float)/P as recall,
-        cast(TN as float)/N as specificity,
-        case when TN+FN=0 then 1 else cast(TN as float)/(TN+FN) end as npv,
-        cast(TP+TN as float)/(P+N) as accuracy,
-        2.0*TP/(2*TP + FN + FP) as f1,
-        5.0*TP/(5*TP + 4*FN + FP) as f2,
-        1.25*TP/(1.25*TP + 0.25*FN + FP) as f0_5,
-        4.0*TP*TN/((4.0*TP*TN) + ((TP + TN)*(FP + FN))) as p4,
+        cast(TP as float8)/P as tp_rate,
+        cast(TN as float8)/N as tn_rate,
+        cast(FP as float8)/N as fp_rate,
+        cast(FN as float8)/P as fn_rate,
+        case when TP+FP=0 then 1 else cast(TP as float8)/(TP+FP) end as precision,
+        cast(TP as float8)/P as recall,
+        cast(TN as float8)/N as specificity,
+        case when TN+FN=0 then 1 else cast(TN as float8)/(TN+FN) end as npv,
+        cast(TP+TN as float8)/(P+N) as accuracy,
+        cast(2.0*TP/(2*TP + FN + FP) as float8) as f1,
+        cast(5.0*TP/(5*TP + 4*FN + FP) as float8) as f2,
+        cast(1.25*TP/(1.25*TP + 0.25*FN + FP) as float8) as f0_5,
+        cast(4.0*TP*TN/((4.0*TP*TN) + ((TP + TN)*(FP + FN))) as float8) as p4,
         case when TN+FN=0 or TP+FP=0 or P=0 or N=0 then 0
-            else cast((TP*TN)-(FP*FN) as float)/sqrt((TP+FP)*P*N*(TN+FN)) end as phi
+            else cast((TP*TN)-(FP*FN) as float8)/sqrt((TP+FP)*P*N*(TN+FN)) end as phi
     from __splink__labels_with_pos_neg_grouped_with_truth_stats
     """