Skip to content

Commit

Permalink
Separate data from code:
Browse files Browse the repository at this point in the history
* Update interface to use new parameterisation
* Update interface to take beta values again
* Remove obsolete population/matrix getters
* Update e2e test
* Update data exporter
  • Loading branch information
giovannic committed May 1, 2020
1 parent 4fc766a commit e7f6907
Show file tree
Hide file tree
Showing 11 changed files with 91 additions and 319 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/e2e.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
- name: Install squire
run: |
R.exe -e "install.packages(c('V8', 'odin', 'deSolve', 'jsonlite', 'remotes'))"
R.exe -e "remotes::install_github(c('mrc-ide/odin.js', 'mrc-ide/squire@squire_js'))"
R.exe -e "remotes::install_github(c('mrc-ide/odin.js', 'mrc-ide/squire@feat/shared_parameters'))"
shell: pwsh
- name: Install Chrome
run: |
Expand Down
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ RUN R -e 'install.packages( \
c("V8", "odin", "deSolve", "jsonlite", "remotes"))'

RUN R -e 'library(remotes); install_github(c("mrc-ide/odin.js", \
"mrc-ide/squire"))'
"mrc-ide/squire@feat/shared_parameters"))'

# Install node
RUN apt install -y curl
Expand Down
47 changes: 20 additions & 27 deletions R/export.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,34 +6,27 @@ if (length(args) != 1) {
stop("The only argument is the output directory")
}

countries <- get_lmic_countries()
populations <- lapply(countries, function(c) get_population(c)$n)
names(populations) <- countries
countries <- unique(squire::population$country)
names(countries) <- unique(squire::population$iso3c)

processMatrix <- function(c) {
m <- get_mixing_matrix(c)
p <- get_population(c)$n
m <- process_contact_matrix_scaled_age(m, p)
mm <- t(t(m) / p)
aperm(array(c(mm), dim = c(dim(mm), 1)), c(3, 1, 2))
}
out_dir <- args[1]
for (iso3c in names(countries)) {
pars <- parameters_explicit_SEEIR(
countries[[iso3c]],
hosp_bed_capacity = 1, #dummy value
ICU_bed_capacity = 1 #dummy value
)

matrices <- lapply(countries, processMatrix)
names(matrices) <- countries
country_data <- list(
population = pars$population,
contactMatrix = pars$mix_mat_set,
beta = pars$beta
)

dur_R <- 2.09
dur_hosp <- 5
processEigen <- function(c) {
m <- get_mixing_matrix(c)
p <- get_population(c)$n
m <- process_contact_matrix_scaled_age(m, p)
adjusted_eigen(dur_R, dur_hosp, default_probs()$prob_hosp, m)
write_json(
country_data,
file.path(out_dir, paste0(iso3c, '.json')),
matrix='columnmajor',
digits=NA
)
}

eigens <- lapply(countries, processEigen)
names(eigens) <- countries

out_dir <- args[1]
write_json(populations, file.path(out_dir, 'population.json'), digits=NA)
write_json(matrices, file.path(out_dir, 'matrices.json'), matrix='columnmajor', digits=NA)
write_json(eigens, file.path(out_dir, 'eigens.json'), auto_unbox=TRUE, digits=NA)
54 changes: 7 additions & 47 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,43 +36,7 @@ npm test
You can access the model using ES6 import syntax:

```js
import { runModel, getPopulation, getMixingMatrix, getBeta } from './squire.js'
```

#### getPopulation

Returns the population for each age group in a country.

The age groups are fixed to the following 17:

0-4, 5-9, 10-14, 15-19, 20-24, 25-29, 30-34, 35-39,
40-44, 45-49, 50-54, 55-59, 60-64, 65-69, 70-74, 75-79, 80+

```js
getPopulation('Nigeria');
// Outputs array of length 17
```

#### getMixingMatrix

Returns the matrix representing the mixing between age groups in a country.

```js
getMixingMatrix('Nigeria');
// Outputs a 17 x 17 nested array
```

#### estimateBeta

Estimates the transmissibility parameter (beta) for a country, given R0. If an
array of R0 is given, a corresponding array of beta estimates will be returned.

```js
estimateBeta('Nigeria', 3);
// returns an estimate for beta in Nigeria given an R0 of 3

estimateBeta('Nigeria', [3, 3*.5, 3*.3, 3*.25]);
// returns 4 estimates for beta in Nigeria given 4 different R0s
import { runModel } from './squire.js'
```

#### runModel
Expand All @@ -83,8 +47,7 @@ signature:
```js
function runModel(
population,
ttMatrix,
mixMatSet,
contactMatrix,
ttBeta,
betaSet,
nBeds,
Expand All @@ -97,9 +60,7 @@ function runModel(
Parameters:

* population - is an array of populations for each age group
* ttMatrix - is an array of timesteps at which the mixing matrix will change
* ttMatSet - is an array of mixing matrices to change to in line with
`ttMatrix`
* contactMatrix - is the contact matrix to use for the simulation
* ttBeta - is an array of timsteps at which the transmissibility will change
* betaSet - is an array of beta values that will change in line with `ttBeta`
* nBeds - is the country's capacity for hosiptal beds
Expand All @@ -110,12 +71,11 @@ Parameters:
You can get some basic model output by running the following example:

```js
const mm = getMixingMatrix('Nigeria')
const beta = estimateBeta('Nigeria', [3, 3/2, 3])
import nigeriaData from './data/NGA.json';
const beta = [nigeriaData.beta, nigeriaData.beta/2, nigeriaData.beta]);
let results = runModel(
getPopulation('Nigeria'),
[0],
[mm],
nigeriaData.population,
nigeriaData.contactMatrix,
[0, 50, 200],
beta,
10000000000,
Expand Down
42 changes: 26 additions & 16 deletions e2e/test.js
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import {
approxEqualArray,
getColumn
approxEqualArray
} from './utils.js'

import { flattenNested } from '../src/utils.js'
import { flattenNested } from '../src/utils.js';
import json from '@rollup/plugin-json';

const fs = require('fs')
const webdriver = require('selenium-webdriver');
Expand All @@ -23,31 +23,36 @@ let driver = new webdriver.Builder()
.setChromeOptions(options)
.build();

async function test() {
let bundle = await rollup.rollup({ input: './e2e/test_script.js' });
async function load() {
let bundle = await rollup.rollup({
input: './e2e/test_script.js',
plugins: json()
});
bundle = await bundle.generate({ format: 'es' });
bundle = bundle.output[0].code;
await driver.executeScript(
return await driver.executeScript(
`var s=window.document.createElement('script');
s.type = 'module';
s.textContent = ${bundle};
window.document.head.appendChild(s);`
)
}

async function test() {
let scenario = 0;
let failed = false;
for (const country of [ 'St. Lucia', 'Nigeria', 'India' ]) {
for (const country of [ 'LCA', 'NGA', 'IND' ]) {
for (const bed of [ 100, 100000, 100000000 ]) {
for (const R0 of [ 4, 3, 2, 1 ]) {
let beta = JSON.parse(
fs.readFileSync(`./data/pars_${scenario}.json`)
).beta_set;
let results = await driver.executeScript(
`const mm = getMixingMatrix('${country}');
const beta = estimateBeta('${country}', [${R0}, ${R0/2}]);
return runModel(
getPopulation('${country}'),
[0],
[mm],
`return runModel(
${country}.population,
${country}.contactMatrix,
[0, 50],
beta,
[${beta}],
${bed},
${bed},
0,
Expand Down Expand Up @@ -93,8 +98,13 @@ async function test() {
return failed;
}

test()
.then(failed => { process.exit(failed) })
async function run() {
await load();
const failed = await test();
process.exit(failed);
}

run()
.catch(e => {
console.log(e);
process.exit(1)
Expand Down
17 changes: 8 additions & 9 deletions e2e/test_script.js
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import {
getPopulation,
getMixingMatrix,
estimateBeta,
runModel
} from '../build/squire.js'
import { runModel } from '../build/squire.js'
import LCA from '../data/LCA.json'
import NGA from '../data/NGA.json'
import IND from '../data/IND.json'


window.getPopulation = getPopulation;
window.getMixingMatrix = getMixingMatrix;
window.estimateBeta = estimateBeta;
window.runModel = runModel;
window.LCA = LCA
window.NGA = NGA;
window.IND = IND;
2 changes: 0 additions & 2 deletions e2e/utils.js
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,3 @@ export function approxEqualArray(x, y, tolerance) {
}
return xy < tolerance;
}

export function getColumn(y, i) { return y.map(row => row[i]); }
34 changes: 5 additions & 29 deletions src/index.js
Original file line number Diff line number Diff line change
@@ -1,27 +1,9 @@

import { transpose312, reshape3d, wellFormedArray } from './utils.js'
import population from '../data/population.json'
import matrices from '../data/matrices.json'
import eigens from '../data/eigens.json'
import { wellFormedArray } from './utils.js'
import pars from '../data/pars_0.json'

export const getCountries = () => Object.keys(population);
export const getPopulation = (country) => { return population[country] };
export const getMixingMatrix = (country) => { return matrices[country] };
export const estimateBeta = (country, R0) => {
const lambda = eigens[country];
if (lambda == null) {
throw Error("Unknown country");
}
if (Array.isArray(R0)) {
return R0.map(r => { return r / lambda })
}
return R0 / lambda;
};

export const runModel = function(
population,
ttMatrix,
mixMatSet,
ttBeta,
betaSet,
Expand All @@ -35,8 +17,8 @@ export const runModel = function(
throw Error("timeStart is greater than timeEnd");
}

if (!wellFormedArray(mixMatSet, [mixMatSet.length, population.length, population.length])) {
throw Error("mixMatSet must have the dimensions t * nAge * nAge");
if (!wellFormedArray(mixMatSet, [population.length, population.length, 1])) {
throw Error("mixMatSet must have the dimensions nAge x nAge x 1");
}

if (population.length !== mixMatSet[0].length) {
Expand All @@ -47,10 +29,6 @@ export const runModel = function(
throw Error("mismatch between ttBeta and betaSet size");
}

if (ttMatrix.length !== mixMatSet.length) {
throw Error("mismatch between ttMatrix and mixMatSet size");
}

if (nBeds < 0 || nICUBeds < 0) {
throw Error("Bed counts must be greater than or equal to 0");
}
Expand All @@ -64,10 +42,8 @@ export const runModel = function(
const user = {
...pars,
S_0: true_pop,
tt_matrix: ttMatrix,
mix_mat_set: transpose312(
reshape3d(mixMatSet, [nGroups, nGroups, ttMatrix.length])
),
tt_matrix: [0],
mix_mat_set: mixMatSet,
tt_beta: ttBeta,
beta_set: betaSet,
hosp_bed_capacity: nBeds,
Expand Down
49 changes: 0 additions & 49 deletions src/utils.js
Original file line number Diff line number Diff line change
@@ -1,52 +1,3 @@
export const transpose312 = function(array) {
let dim = [3, 1, 2].map(d => getDimSize(array, d));
let transposed = Array(dim[2]).fill(0).map(
() => Array(dim[0]).fill(0).map(
() => Array(dim[1]).fill(0)
)
);
const data = flattenNested(array);
let counter = 0;
for (let third = 0; third < dim[1]; ++third) {
for (let first = 0; first < dim[2]; ++first) {
for (let second = 0; second < dim[0]; ++second) {
transposed[first][second][third] = data[counter];
counter++;
}
}
}
return transposed;
}

export const reshape3d = function(array, dims) {
if (!(dims || dims.length != 3)) {
throw new Error("Dims array is incorrect");
}
let reshaped = Array(dims[2]).fill(0).map(
() => Array(dims[1]).fill(0).map(
() => Array(dims[0]).fill(0)
)
);
const data = flattenNested(array);
let counter = 0;
for (let first = 0; first < dims[2]; ++first) {
for (let second = 0; second < dims[1]; ++second) {
for (let third = 0; third < dims[0]; ++third) {
reshaped[first][second][third] = data[counter];
counter++;
}
}
}
return reshaped;
}

const getDimSize = function(array, dim) {
if (dim == 1) {
return array.length;
}
return getDimSize(array[0], dim - 1);
}

export const flattenNested = function(array) {
const flat = [].concat(...array);
return flat.some(Array.isArray) ? flattenNested(flat) : flat;
Expand Down
Loading

0 comments on commit e7f6907

Please sign in to comment.