From dbefa1360c269d9bbe21e36c5890e1c583e44965 Mon Sep 17 00:00:00 2001 From: Andreea Popescu Date: Wed, 15 Jan 2025 12:45:43 +0000 Subject: [PATCH] safe math --- pallets/subtensor/src/epoch/math.rs | 71 +++++++++++++++++------------ 1 file changed, 43 insertions(+), 28 deletions(-) diff --git a/pallets/subtensor/src/epoch/math.rs b/pallets/subtensor/src/epoch/math.rs index 9d63ac1c5..682e9cfd9 100644 --- a/pallets/subtensor/src/epoch/math.rs +++ b/pallets/subtensor/src/epoch/math.rs @@ -1027,29 +1027,28 @@ pub fn weighted_median_col_sparse( // ratio=0: Result = A // ratio=1: Result = B #[allow(dead_code)] -pub fn interpolate( - mat1: &Vec>, - mat2: &Vec>, - ratio: I32F32, -) -> Vec> { +pub fn interpolate(mat1: &[Vec], mat2: &[Vec], ratio: I32F32) -> Vec> { if ratio == I32F32::from_num(0) { - return mat1.clone(); + return mat1.to_owned(); } if ratio == I32F32::from_num(1) { - return mat2.clone(); + return mat2.to_owned(); } assert!(mat1.len() == mat2.len()); - if mat1.len() == 0 { + if mat1.is_empty() { return vec![vec![]; 1]; } - if mat1[0].len() == 0 { + if mat1.first().unwrap_or(&vec![]).is_empty() { return vec![vec![]; 1]; } - let mut result: Vec> = vec![vec![I32F32::from_num(0); mat1[0].len()]; mat1.len()]; - for i in 0..mat1.len() { - assert!(mat1[i].len() == mat2[i].len()); - for j in 0..mat1[i].len() { - result[i][j] = mat1[i][j] + ratio * (mat2[i][j] - mat1[i][j]); + let mut result: Vec> = + vec![vec![I32F32::from_num(0); mat1.first().unwrap_or(&vec![]).len()]; mat1.len()]; + for (i, (row1, row2)) in mat1.iter().zip(mat2.iter()).enumerate() { + assert!(row1.len() == row2.len()); + for (j, (&v1, &v2)) in row1.iter().zip(row2.iter()).enumerate() { + if let Some(res) = result.get_mut(i).unwrap_or(&mut vec![]).get_mut(j) { + *res = v1.saturating_add(ratio.saturating_mul(v2.saturating_sub(v1))); + } } } result @@ -1061,34 +1060,50 @@ pub fn interpolate( // ratio=1: Result = B #[allow(dead_code)] pub fn interpolate_sparse( - mat1: &Vec>, - mat2: &Vec>, + mat1: &[Vec<(u16, I32F32)>], + mat2: &[Vec<(u16, I32F32)>], columns: u16, ratio: I32F32, ) -> Vec> { if ratio == I32F32::from_num(0) { - return mat1.clone(); + return mat1.to_owned(); } if ratio == I32F32::from_num(1) { - return mat2.clone(); + return mat2.to_owned(); } assert!(mat1.len() == mat2.len()); let rows = mat1.len(); let zero: I32F32 = I32F32::from_num(0); let mut result: Vec> = vec![vec![]; rows]; for i in 0..rows { - let mut row1: Vec = vec![zero; columns as usize]; - for (j, value) in mat1[i].iter() { - row1[*j as usize] = *value; - } - let mut row2: Vec = vec![zero; columns as usize]; - for (j, value) in mat2[i].iter() { - row2[*j as usize] = *value; - } + let row1: Vec = mat1.get(i).unwrap_or(&vec![]).iter().fold( + vec![zero; columns as usize], + |mut acc, (j, value)| { + if let Some(entry) = acc.get_mut(*j as usize) { + *entry = *value; + } + acc + }, + ); + + let row2: Vec = mat2.get(i).unwrap_or(&vec![]).iter().fold( + vec![zero; columns as usize], + |mut acc, (j, value)| { + if let Some(entry) = acc.get_mut(*j as usize) { + *entry = *value; + } + acc + }, + ); + for j in 0..columns as usize { - let interp: I32F32 = row1[j] + ratio * (row2[j] - row1[j]); + let v1 = row1.get(j).unwrap_or(&zero); + let v2 = row2.get(j).unwrap_or(&zero); + let interp = v1.saturating_add(ratio.saturating_mul(v2.saturating_sub(*v1))); if zero < interp { - result[i].push((j as u16, interp)) + if let Some(res) = result.get_mut(i) { + res.push((j as u16, interp)); + } } } }