diff --git a/src/utils.rs b/src/utils.rs index 66c72a0..9f60fc6 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -351,15 +351,15 @@ where /// Calculate the recall. pub fn calculate_recall(truth: &[i32], res: &[i32], topk: usize) -> f32 { + assert_eq!(res.len(), topk); let mut count = 0; - let length = topk.min(truth.len()); - for t in truth.iter().take(length) { - for y in res.iter().take(length.min(res.len())) { - if *t == *y { + for id in res { + for t in truth.iter().take(topk) { + if *id == *t { count += 1; break; } } } - (count as f32) / (length as f32) + (count as f32) / (topk as f32) }