Skip to content

Commit

Permalink
Merge pull request #52 from bminixhofer/update-tract
Browse files Browse the repository at this point in the history
Update tract
  • Loading branch information
bminixhofer authored Jul 23, 2021
2 parents 5a117f8 + adfc76f commit 7193727
Show file tree
Hide file tree
Showing 11 changed files with 87 additions and 86 deletions.
2 changes: 1 addition & 1 deletion bindings/javascript/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ serde_derive = "1.0"
serde_json = "1.0"
wasm-bindgen-futures = "0.4"
js-sys = "0.3"
ndarray = "=0.13.0"
ndarray = "0.15"
futures = "0.3"
# The `console_error_panic_hook` crate provides better debugging of panics by
# logging them with `console.error`. This is great for development, but requires
Expand Down
8 changes: 4 additions & 4 deletions bindings/javascript/dev_server/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion bindings/javascript/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
},
"homepage": "https://github.com/bminixhofer/nnsplit#readme",
"dependencies": {
"tractjs": "^0.3.2"
"tractjs": "^0.4.0"
},
"devDependencies": {
"cypress": "^4.9.0",
Expand Down
4 changes: 3 additions & 1 deletion bindings/javascript/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@ impl NNSplit {
options.into_serde().unwrap()
};

let backend = TractJSBackend::new(&path, options.length_divisor).await?;
// larger batch sizes seem to slow things down so hardcode to 1 for now
// would need more investigation to find the cause
let backend = TractJSBackend::new(&path, options.length_divisor, 1).await?;
let metadata = backend.get_metadata().await?;

Ok(NNSplit {
Expand Down
49 changes: 27 additions & 22 deletions bindings/javascript/src/tractjs_backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use js_sys::{Array, Float32Array, Promise, Uint32Array, Uint8Array};
use ndarray::prelude::*;
use serde_derive::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::collections::HashMap;
use std::{cmp, collections::HashMap};
use wasm_bindgen::prelude::*;
use wasm_bindgen_futures::JsFuture;

Expand All @@ -21,7 +21,7 @@ extern "C" {
fn load(path: &str, options: JsValue) -> Promise;

#[wasm_bindgen(method)]
fn predict_one(this: &Model, input: Tensor, symbol_values: JsValue) -> Promise;
fn predict_one(this: &Model, input: Tensor) -> Promise;

#[wasm_bindgen(method)]
fn get_metadata(this: &Model) -> Promise;
Expand All @@ -43,11 +43,15 @@ extern "C" {

pub struct TractJSBackend {
model: Model,
length_divisor: usize,
batch_size: usize,
}

impl TractJSBackend {
pub async fn new(model_path: &str, length_divisor: usize) -> Result<Self, JsValue> {
pub async fn new(
model_path: &str,
length_divisor: usize,
batch_size: usize,
) -> Result<Self, JsValue> {
let mut input_facts = HashMap::new();
input_facts.insert(
0,
Expand All @@ -65,33 +69,32 @@ impl TractJSBackend {
.await?
.into();

Ok(TractJSBackend {
model,
length_divisor,
})
Ok(TractJSBackend { model, batch_size })
}

pub async fn predict(&self, input: Array2<u8>) -> Result<Array3<f32>, JsValue> {
let shape: Array = vec![JsValue::from(1u32), JsValue::from(input.shape()[1] as u32)]
.into_iter()
.collect();
let mut symbol_values = HashMap::new();
symbol_values.insert("s", input.shape()[1] / self.length_divisor);

let preds = (0..input.shape()[0])
.map(|i| {
.step_by(self.batch_size)
.map(|start| {
let end = cmp::min(start + self.batch_size, input.shape()[0]);
let actual_batch_size = end - start;

let shape: Array = vec![
JsValue::from(actual_batch_size as u32),
JsValue::from(input.shape()[1] as u32),
]
.into_iter()
.collect();

let tensor = Tensor::new(
Uint8Array::from(input.slice(s![i, ..]).as_slice().expect_throw(
Uint8Array::from(input.slice(s![start..end, ..]).as_slice().expect_throw(
"converting ndarray to slice failed (likely not contiguous)",
))
.into(),
shape.clone(),
shape,
);

JsFuture::from(
self.model
.predict_one(tensor, JsValue::from_serde(&symbol_values).unwrap()),
)
JsFuture::from(self.model.predict_one(tensor))
})
.collect::<Vec<_>>();
let preds = join_all(preds)
Expand All @@ -107,11 +110,13 @@ impl TractJSBackend {

let data_vec = preds.into_iter().fold(Vec::new(), |mut arr, x| {
let curr: Float32Array = x.data().into();

arr.extend(curr.to_vec());
arr
});

let mut preds =
Array3::from_shape_vec(shape, data_vec).map_err(|_| "Array conversion error")?;
Array3::from_shape_vec(shape, data_vec).map_err(|_| "Array conversion error.")?;

// sigmoid
preds.mapv_inplace(|x| 1f32 / (1f32 + (-x).exp()));
Expand Down
6 changes: 3 additions & 3 deletions bindings/python/Cargo.build.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ name = "nnsplit"
crate-type = ["cdylib"]

[dependencies]
ndarray = "=0.13.0"
numpy = "0.12.2"
ndarray = "0.15.0"
numpy = "0.14.1"
lazy_static = "1.4"
serde_json = "1.0"

Expand All @@ -25,7 +25,7 @@ default-features = false
features = ["model-loader"]

[dependencies.pyo3]
version = "0.12"
version = "0.14"
features = ["extension-module"]

[workspace]
6 changes: 3 additions & 3 deletions bindings/python/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ requires-dist = ["onnxruntime==1.7", "tqdm>=4"]
name = "nnsplit_python"

[dependencies]
ndarray = "=0.13.0"
numpy = "0.12.2"
ndarray = "0.15.0"
numpy = "0.14.1"
lazy_static = "1.4"
serde_json = "1.0"

Expand All @@ -23,4 +23,4 @@ default-features = false
features = ["model-loader"]

[dependencies.pyo3]
version = "0.12"
version = "0.14"
8 changes: 4 additions & 4 deletions bindings/python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ impl<'a> IntoPy<Split> for core::Split<'a> {
/// * batch_size (int): Batch size to use.
/// * length_divisor (int): Total length will be padded until it is divisible by this number. Allows some additional optimizations.
#[pyclass]
#[text_signature = "(model_path, use_cuda=None, **kwargs)"]
#[pyo3(text_signature = "(model_path, use_cuda=None, **kwargs)")]
pub struct NNSplit {
backend: ONNXRuntimeBackend,
logic: core::NNSplitLogic,
Expand Down Expand Up @@ -185,7 +185,7 @@ impl NNSplit {
/// * padding (int): How much to zero pad the text on both sides.
/// * batch_size (int): Batch size to use.
/// * length_divisor (int): Total length will be padded until it is divisible by this number. Allows some additional optimizations.
#[text_signature = "(model_name, use_cuda=None, **kwargs)"]
#[pyo3(text_signature = "(model_name, use_cuda=None, **kwargs)")]
#[args(kwargs = "**")]
#[staticmethod]
pub fn load(
Expand Down Expand Up @@ -217,7 +217,7 @@ impl NNSplit {
/// verbose (bool): Whether to display a progress bar.
/// Returns:
/// splits (List[Split]): A list of `Split` objects with the same length as the input text list.
#[text_signature = "(texts, verbose=False)"]
#[pyo3(text_signature = "(texts, verbose=False)")]
pub fn split(
&self,
py: Python,
Expand All @@ -241,7 +241,7 @@ impl NNSplit {
///
/// Returns:
/// levels (List[str]): A list of strings describing the split levels, from top (largest split) to bottom (smallest split).
#[text_signature = "()"]
#[pyo3(text_signature = "()")]
pub fn get_levels(&self) -> PyResult<Vec<String>> {
Ok(self
.logic
Expand Down
12 changes: 8 additions & 4 deletions bindings/python/src/onnxruntime_backend.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ impl ONNXRuntimeBackend {
) -> PyResult<Array3<f32>> {
let prediction: &PyArray3<f32> = MODULE
.as_ref(py)
.call1("predict_batch", (session, data.to_pyarray(py)))?
.getattr("predict_batch")?
.call1((session, data.to_pyarray(py)))?
.extract()?;

let mut prediction = prediction.to_owned_array();
Expand All @@ -45,7 +46,8 @@ impl ONNXRuntimeBackend {
pub fn new<P: AsRef<str>>(py: Python, model_path: P, use_cuda: Option<bool>) -> PyResult<Self> {
let session = MODULE
.as_ref(py)
.call1("create_session", (model_path.as_ref(), use_cuda))?
.getattr("create_session")?
.call1((model_path.as_ref(), use_cuda))?
.into();

let dummy_data = Array2::<u8>::zeros((1, 12));
Expand All @@ -71,7 +73,8 @@ impl ONNXRuntimeBackend {
Some(
MODULE
.as_ref(py)
.call1("get_progress_bar", (input_shape[0],))?,
.getattr("get_progress_bar")?
.call1((input_shape[0],))?,
)
} else {
None
Expand Down Expand Up @@ -101,7 +104,8 @@ impl ONNXRuntimeBackend {
pub fn get_metadata(&self, py: Python) -> PyResult<HashMap<String, String>> {
MODULE
.as_ref(py)
.call1("get_metadata", (&self.session,))?
.getattr("get_metadata")?
.call1((&self.session,))?
.extract()
}
}
8 changes: 4 additions & 4 deletions nnsplit/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,21 @@ readme = "README.md"
include = ["src/*.rs", "tests/*.rs", "build.rs", "data", "models.csv"]

[dependencies]
ndarray = "=0.13.0"
ndarray = "0.15"
thiserror = "1.0"
lazy_static = "1.4"
serde = "1.0"
serde_derive = "1.0"
serde_json = "1.0"
tract-onnx = { version = "0.12.1", optional = true }
tract-onnx = { version = "0.15.2", optional = true }
directories = {version = "3.0.1", optional = true}
minreq = {version = "2.2.1", features = ["https"], optional = true}
url = {version = "2.2.0", optional = true}

[dev-dependencies]
serde_json = "1.0"
quickcheck_macros = "0.9"
quickcheck = "0.9"
quickcheck_macros = "1.0"
quickcheck = "1.0"
rand = "0.8"

[features]
Expand Down
Loading

0 comments on commit 7193727

Please sign in to comment.