Skip to content

Commit

Permalink
Fix errors related ndarray
Browse files Browse the repository at this point in the history
  • Loading branch information
StefanUlbrich committed Apr 21, 2022
1 parent ef564a2 commit fc06f40
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 33 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ documentation = "https://docs.rs/csaps"
readme = "README.md"
license = "MIT"

edition = "2021"
edition = "2018"


[badges]
Expand Down
19 changes: 5 additions & 14 deletions src/umv/evaluate.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,5 @@
use ndarray::{
Dimension,
Axis,
Array,
Array1,
Array2,
ArrayView1,
ArrayView2,
s,
stack,
};
use ndarray::{prelude::*, s, concatenate};


use crate::{
Real,
Expand All @@ -35,12 +26,12 @@ impl<'a, T> NdSpline<'a, T>
xi: ArrayView1<'a, T>) -> Array2<T>
{
let edges = {
let mesh = breaks.slice(s![1..-1]);
let mesh = breaks.slice(s![1 as i32..-1]);
let one = Array1::<T>::ones((1, ));
let left_bound = &one * T::neg_infinity();
let right_bound = &one * T::infinity();

stack![Axis(0), left_bound, mesh, right_bound]
concatenate![Axis(0), left_bound, mesh, right_bound]
};

let mut indices = digitize(&xi, &edges);
Expand Down Expand Up @@ -68,7 +59,7 @@ impl<'a, T> NdSpline<'a, T>
.map(coeffs_by_index)
.collect();

stack(Axis(1), &indexed_coeffs).unwrap()
concatenate(Axis(1), &indexed_coeffs).unwrap()
};

// Vectorized computing the spline pieces (polynoms) on the given data sites
Expand Down
36 changes: 18 additions & 18 deletions src/umv/make.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use ndarray::{prelude::*, stack};
use ndarray::{prelude::*, stack, concatenate};

use crate::{
Real,
Expand All @@ -17,9 +17,9 @@ impl<'a, T, D> CubicSmoothingSpline<'a, T, D>
{
pub(super) fn make_spline(&mut self) -> Result<()> {
let one = T::one();
let two = T::from(2.0).unwrap();
let three = T::from(3.0).unwrap();
let six = T::from(6.0).unwrap();
let two = T::from::<f64>(2.0).unwrap();
let three = T::from::<f64>(3.0).unwrap();
let six = T::from::<f64>(6.0).unwrap();

let breaks = self.x;

Expand All @@ -41,8 +41,8 @@ impl<'a, T, D> CubicSmoothingSpline<'a, T, D>
// The corner case for Nx2 data (2 data points)
if pcount == 2 {
drop(dx);
let yi = y.slice(s![.., 0]).insert_axis(Axis(1));
let coeffs = stack![Axis(1), dydx, yi];
let yi = y.slice(s![.., 0 as usize]).insert_axis(Axis(1));
let coeffs = concatenate![Axis(1), dydx, yi];

self.smooth = Some(one);
self.spline = Some(NdSpline::new(breaks, coeffs));
Expand All @@ -56,11 +56,11 @@ impl<'a, T, D> CubicSmoothingSpline<'a, T, D>
let qtwq = {
let qt = {
let odx = ones(pcount - 1) / &dx;
let odx_head = odx.slice(s![..-1]).insert_axis(Axis(0)).into_owned();
let odx_tail = odx.slice(s![1..]).insert_axis(Axis(0)).into_owned();
let odx_head = odx.slice(s![..-1 as i32]).insert_axis(Axis(0)).into_owned();
let odx_tail = odx.slice(s![1 as i32..]).insert_axis(Axis(0)).into_owned();
drop(odx);
let odx_body = -(&odx_tail + &odx_head);
let diags_qt = stack![Axis(0), odx_head, odx_body, odx_tail];
let diags_qt = concatenate![Axis(0), odx_head, odx_body, odx_tail];

sprsext::diags(diags_qt, &[0, 1, 2], (pcount - 2, pcount))
};
Expand All @@ -76,10 +76,10 @@ impl<'a, T, D> CubicSmoothingSpline<'a, T, D>
};

let r = {
let dx_head = dx.slice(s![..-1]).insert_axis(Axis(0)).into_owned();
let dx_tail = dx.slice(s![1..]).insert_axis(Axis(0)).into_owned();
let dx_head = dx.slice(s![..-1 as i32]).insert_axis(Axis(0)).into_owned();
let dx_tail = dx.slice(s![1 as i32..]).insert_axis(Axis(0)).into_owned();
let dx_body = (&dx_tail + &dx_head) * two;
let diags_r = stack![Axis(0), dx_tail, dx_body, dx_head];
let diags_r = concatenate![Axis(0), dx_tail, dx_body, dx_head];

sprsext::diags(diags_r, &[-1, 0, 1], (pcount - 2, pcount - 2))
};
Expand Down Expand Up @@ -109,11 +109,11 @@ impl<'a, T, D> CubicSmoothingSpline<'a, T, D>
sprsext::solve(&a, &b)
};

// Compute and stack spline coefficients
// Compute and concatenatespline coefficients
let coeffs = {
let vpad = |arr: &Array2<T>| -> Array2<T> {
let pad = Array2::<T>::zeros((1, arr.shape()[1]));
stack(Axis(0), &[pad.view(), arr.view(), pad.view()]).unwrap()
concatenate(Axis(0), &[pad.view(), arr.view(), pad.view()]).unwrap()
};

let dx = dx.insert_axis(Axis(1));
Expand All @@ -133,17 +133,17 @@ impl<'a, T, D> CubicSmoothingSpline<'a, T, D>
};

let c3 = vpad(&(usol * smooth));
let c3_head = c3.slice(s![..-1, ..]);
let c3_tail = c3.slice(s![1.., ..]);
let c3_head = c3.slice(s![..-1 as i32, ..]);
let c3_tail = c3.slice(s![1 as i32.., ..]);

let p1 = diff(&c3, Some(Axis(0))) / &dx;
let p2 = &c3_head * three;
let p3 = diff(&yi, Some(Axis(0))) / &dx - (&c3_head * two + c3_tail) * dx;
let p4 = yi.slice(s![..-1, ..]);
let p4 = yi.view().slice(s![..-1 as i32, ..]);

drop(c3);

stack![Axis(0), p1, p2, p3, p4].t().to_owned()
concatenate(Axis(0), &[p1.view(), p2.view(), p3.view(), p4]).unwrap().t().to_owned()
};

self.smooth = Some(smooth);
Expand Down

0 comments on commit fc06f40

Please sign in to comment.