Skip to content

Commit

Permalink
Canonical parameters are now Functions.
Browse files Browse the repository at this point in the history
  • Loading branch information
tgvaughan committed Oct 14, 2024
1 parent 5cbae33 commit dbeae8c
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 71 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ protected Double updateRatesAndTimes(TreeInterface tree) {
if (origin.get() == null)
maxTime = treeInput.get().getRoot().getHeight();
else
maxTime = originIsRootEdge.get()? treeInput.get().getRoot().getHeight() + origin.get().getValue() : origin.get().getValue();
maxTime = originIsRootEdge.get()? treeInput.get().getRoot().getHeight() + origin.get().getArrayValue() : origin.get().getArrayValue();

/* the cut-off time (x_cut) should be smaller than the youngest internal node and fossil, as we assume sampling
exactly one representative extant species per clade and no fossil sampling between x_cut and the present */
Expand Down Expand Up @@ -61,35 +61,35 @@ else if (node.getHeight()/maxTime <= 1e-8) // allow some tolerance
times = timesSet.toArray(new Double[timesSet.size()]);

// increase dimension of birth rate
Double[] tempb = new Double[totalIntervals + 1];
double[] tempb = new double[totalIntervals + 1];
System.arraycopy(birth, 0, tempb, 0, totalIntervals);
tempb[totalIntervals] = birth[totalIntervals - 1];
birth = tempb;

// increase dimension of death rate
Double[] tempd = new Double[totalIntervals + 1];
double[] tempd = new double[totalIntervals + 1];
System.arraycopy(death, 0, tempd, 0, totalIntervals);
tempd[totalIntervals] = death[totalIntervals - 1];
death = tempd;

// increase dimension of sampling rate
Double[] tempp = new Double[totalIntervals + 1];
double[] tempp = new double[totalIntervals + 1];
System.arraycopy(psi, 0, tempp, 0, totalIntervals);
tempp[totalIntervals] = 0.0; // add 0 to the last entry
psi = tempp;

// set extant sampling proportion here as rho will be overwritten
samplingProp = rho[totalIntervals - 1];
// increase dimension of rho rate, add 0 to the last entry
Double[] tempr = new Double[totalIntervals + 1];
double[] tempr = new double[totalIntervals + 1];
System.arraycopy(rho, 0, tempr, 0, totalIntervals - 1);
tempr[totalIntervals - 1] = 0.0;
tempr[totalIntervals] = 1.0; // assuming complete sampling
rho = tempr;

// increase dimension of r rate
if (SAModel) {
Double[] temp = new Double[totalIntervals + 1];
double[] temp = new double[totalIntervals + 1];
System.arraycopy(r, 0, temp, 0, totalIntervals);
temp[totalIntervals] = r[totalIntervals - 1];
r = temp;
Expand Down
132 changes: 67 additions & 65 deletions src/bdsky/evolution/speciation/BirthDeathSkylineModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -82,31 +82,31 @@ public class BirthDeathSkylineModel extends SpeciesTreeDistribution {

// the times for rho sampling
public Input<RealParameter> rhoSamplingTimes =
new Input<RealParameter>("rhoSamplingTimes", "The times t_i specifying when rho-sampling occurs", (RealParameter) null);
new Input<RealParameter>("rhoSamplingTimes", "The times t_i specifying when rho-sampling occurs");


public Input<RealParameter> origin =
new Input<RealParameter>("origin", "The time from origin to last sample (must be larger than tree height)", (RealParameter) null);
public Input<Function> origin =
new Input<>("origin", "The time from origin to last sample (must be larger than tree height)");

public Input<Boolean> originIsRootEdge =
new Input<>("originIsRootEdge", "The origin is only the length of the root edge", false);

public Input<Boolean> conditionOnRootInput = new Input<Boolean>("conditionOnRoot", "the tree " +
"likelihood is conditioned on the root height otherwise on the time of origin", false);

public Input<RealParameter> birthRate =
new Input<RealParameter>("birthRate", "BirthRate = BirthRateVector * birthRateScalar, birthrate can change over time");
public Input<RealParameter> deathRate =
new Input<RealParameter>("deathRate", "The deathRate vector with birthRates between times");
public Input<RealParameter> samplingRate =
new Input<RealParameter>("samplingRate", "The sampling rate per individual"); // psi
public Input<RealParameter> removalProbability =
new Input<RealParameter>("removalProbability", "The probability of an individual to become noninfectious immediately after the sampling");

public Input<RealParameter> m_rho =
new Input<RealParameter>("rho", "The proportion of lineages sampled at rho-sampling times (default 0.)");
public Input<Function> birthRate =
new Input<>("birthRate", "BirthRate = BirthRateVector * birthRateScalar, birthrate can change over time");
public Input<Function> deathRate =
new Input<>("deathRate", "The deathRate vector with birthRates between times");
public Input<Function> samplingRate =
new Input<>("samplingRate", "The sampling rate per individual"); // psi
public Input<Function> removalProbability =
new Input<>("removalProbability", "The probability of an individual to become noninfectious immediately after the sampling");

public Input<Function> m_rho =
new Input<>("rho", "The proportion of lineages sampled at rho-sampling times (default 0.)");
public Input<Boolean> contemp =
new Input<Boolean>("contemp", "Only contemporaneous sampling (i.e. all tips are from same sampling time, default false)", false);
new Input<>("contemp", "Only contemporaneous sampling (i.e. all tips are from same sampling time, default false)", false);


public Input<Function> reproductiveNumberInput =
Expand Down Expand Up @@ -145,11 +145,11 @@ public class BirthDeathSkylineModel extends SpeciesTreeDistribution {
protected double taxonAge;

// these four arrays are totalIntervals in length
protected Double[] birth;
protected Double[] death;
protected Double[] psi;
protected Double[] rho;
protected Double[] r;
protected double[] birth;
protected double[] death;
protected double[] psi;
protected double[] rho;
protected double[] r;

// true if the node of the given index occurs at the time of a rho-sampling event
protected boolean[] isRhoTip;
Expand Down Expand Up @@ -222,8 +222,8 @@ public void initAndValidate() {
throw new RuntimeException("Origin parameter is not set when conditioning on the origin!");
if (conditionOnRootInput.get() && origin.get() != null)
throw new RuntimeException("Origin parameter should not be set when conditioning on the root!");
if (origin.get() != null && (!originIsRootEdge.get() && treeInput.get().getRoot().getHeight() >= origin.get().getValue()))
throw new RuntimeException("Origin parameter ("+origin.get().getValue()+") must be larger than " +
if (origin.get() != null && (!originIsRootEdge.get() && treeInput.get().getRoot().getHeight() >= origin.get().getArrayValue()))
throw new RuntimeException("Origin parameter ("+origin.get().getArrayValue()+") must be larger than " +
"tree height("+treeInput.get().getRoot().getHeight()+"). Please change initial origin value!");

if (removalProbability.get() != null) SAModel = true;
Expand Down Expand Up @@ -263,10 +263,10 @@ public void initAndValidate() {

if (birthRate.get() != null && deathRate.get() != null && samplingRate.get() != null) {

birth = birthRate.get().getValues();
death = deathRate.get().getValues();
psi = samplingRate.get().getValues();
if (SAModel) r = removalProbability.get().getValues();
birth = birthRate.get().getDoubleValues();
death = deathRate.get().getDoubleValues();
psi = samplingRate.get().getDoubleValues();
if (SAModel) r = removalProbability.get().getDoubleValues();

} else if (( reproductiveNumberInput.get() != null || transform_lambda_base) && becomeUninfectiousRate.get() != null && samplingProportion.get() != null) {

Expand Down Expand Up @@ -307,7 +307,7 @@ public void initAndValidate() {
if (SAModel) rChanges = removalProbability.get().getDimension() -1;

if (m_rho.get() != null) {
rho = m_rho.get().getValues();
rho = m_rho.get().getDoubleValues();
rhoChanges = m_rho.get().getDimension() - 1;
}

Expand All @@ -318,7 +318,7 @@ public void initAndValidate() {

if (m_rho.get().getDimension() == 1 && rhoSamplingTimes.get() == null || rhoSamplingTimes.get().getDimension() < 2) {
if (!contempData && ((samplingProportion.get() != null && samplingProportion.get().getDimension() == 1 && samplingProportion.get().getArrayValue() == 0.) ||
(samplingRate.get() != null && samplingRate.get().getDimension() == 1 && samplingRate.get().getValue() == 0.))) {
(samplingRate.get() != null && samplingRate.get().getDimension() == 1 && samplingRate.get().getArrayValue() == 0.))) {
contempData = true;
if (printTempResults)
System.out.println("Parameters were chosen for contemporaneously sampled data. Setting contemp=true.");
Expand All @@ -330,15 +330,15 @@ public void initAndValidate() {
throw new RuntimeException("when contemp=true, rho must have dimension 1");

else {
rho = new Double[totalIntervals];
rho = new double[totalIntervals];
Arrays.fill(rho, 0.);
rho[totalIntervals - 1] = m_rho.get().getValue();
rho[totalIntervals - 1] = m_rho.get().getArrayValue();
// rhoSamplingCount = 1;
}
}

} else {
rho = new Double[totalIntervals];
rho = new double[totalIntervals];
Arrays.fill(rho, 0.);
}
isRhoTip = new boolean[treeInput.get().getLeafNodeCount()];
Expand All @@ -362,16 +362,18 @@ public void initAndValidate() {
// sanity check for sampled ancestor analysis
// make sure that operators are valid for such an analysis
boolean isSAAnalysis = false;
if (removalProbability.get() != null && removalProbability.get().getValue() >= 1.0 && removalProbability.get().isEstimatedInput.get()) {
if (removalProbability.get() != null && removalProbability.get().getArrayValue() >= 1.0
&& removalProbability.get() instanceof RealParameter
&& ((RealParameter)removalProbability.get()).isEstimatedInput.get()) {
// default parameters have estimated=true by default.
// check there is an operator on this parameter
for (BEASTInterface o : removalProbability.get().getOutputs()) {
for (BEASTInterface o : ((RealParameter)removalProbability.get()).getOutputs()) {
if (o instanceof Operator) {
isSAAnalysis = true;
}
}
}
if (removalProbability.get() != null && removalProbability.get().getValue() < 1.0 || isSAAnalysis) {
if (removalProbability.get() != null && removalProbability.get().getArrayValue() < 1.0 || isSAAnalysis) {
// this is a sampled ancestor analysis
// check that there are no invalid operators in this analysis
List<Operator> operators = getOperators(this);
Expand Down Expand Up @@ -446,7 +448,7 @@ private boolean rhoSamplingConditionHolds() {

if (SAModel) {
for (int i=0; i<removalProbability.get().getDimension(); i++) {
if (removalProbability.get().getValue(i) != 0.0) {
if (removalProbability.get().getArrayValue(i) != 0.0) {
return false;
}
}
Expand All @@ -472,7 +474,7 @@ public void getChangeTimes(List<Double> changeTimes, RealParameter intervalTimes
double maxTime;

if (origin.get() != null) {
maxTime = originIsRootEdge.get()? treeInput.get().getRoot().getHeight() + origin.get().getValue() :origin.get().getValue();
maxTime = originIsRootEdge.get() ? treeInput.get().getRoot().getHeight() + origin.get().getArrayValue() : origin.get().getArrayValue();
} else {
maxTime = treeInput.get().getRoot().getHeight();
}
Expand Down Expand Up @@ -625,7 +627,7 @@ protected Double updateRatesAndTimes(TreeInterface tree) {

double t_root = tree.getRoot().getHeight();

if (origin.get() != null && (m_forceRateChange && timesSet.last() > (originIsRootEdge.get()? t_root+ origin.get().getValue() : origin.get().getValue()))) {
if (origin.get() != null && (m_forceRateChange && timesSet.last() > (originIsRootEdge.get() ? t_root+ origin.get().getArrayValue() : origin.get().getArrayValue()))) {
return Double.NEGATIVE_INFINITY;
}

Expand All @@ -639,16 +641,16 @@ else if (transform_d_r_s)
transformParameters_d_r_s();
else {

Double[] birthRates = birthRate.get().getValues();
Double[] deathRates = deathRate.get().getValues();
Double[] samplingRates = samplingRate.get().getValues();
Double[] removalProbabilities = new Double[1];
if (SAModel) removalProbabilities = removalProbability.get().getValues();
double[] birthRates = birthRate.get().getDoubleValues();
double[] deathRates = deathRate.get().getDoubleValues();
double[] samplingRates = samplingRate.get().getDoubleValues();
double[] removalProbabilities = new double[1];
if (SAModel) removalProbabilities = removalProbability.get().getDoubleValues();

birth = new Double[totalIntervals];
death = new Double[totalIntervals];
psi = new Double[totalIntervals];
if (SAModel) r = new Double[totalIntervals];
birth = new double[totalIntervals];
death = new double[totalIntervals];
psi = new double[totalIntervals];
if (SAModel) r = new double[totalIntervals];

birth[0] = birthRates[0];

Expand All @@ -671,8 +673,8 @@ else if (transform_d_r_s)

if (m_rho.get() != null && (m_rho.get().getDimension()==1 || rhoSamplingTimes.get() != null)) {

Double[] rhos = m_rho.get().getValues();
rho = new Double[totalIntervals];
double[] rhos = m_rho.get().getDoubleValues();
rho = new double[totalIntervals];

for (int i = 0; i < totalIntervals; i++) {

Expand All @@ -693,7 +695,7 @@ else if (transform_d_r_s)
/* calculate and store Ai, Bi and p0 */
public Double preCalculation(TreeInterface tree) {

if (origin.get() != null && (!originIsRootEdge.get() && tree.getRoot().getHeight() >= origin.get().getValue()) && taxonInput.get() == null ) {
if (origin.get() != null && (!originIsRootEdge.get() && tree.getRoot().getHeight() >= origin.get().getArrayValue()) && taxonInput.get() == null ) {
return Double.NEGATIVE_INFINITY;
}

Expand All @@ -706,13 +708,13 @@ public Double preCalculation(TreeInterface tree) {

if (m_rho.get() != null) {
if (contempData) {
rho = new Double[totalIntervals];
rho = new double[totalIntervals];
Arrays.fill(rho, 0.);
rho[totalIntervals-1] = m_rho.get().getValue();
rho[totalIntervals-1] = m_rho.get().getArrayValue();
}

} else {
rho = new Double[totalIntervals];
rho = new double[totalIntervals];
Arrays.fill(rho, 0.0);
}

Expand Down Expand Up @@ -926,13 +928,13 @@ protected void transformParameters() {
double birth_ratio = transform_lambda_base ? lambda_ratioInput.get().getDoubleValues()[0] : 1;
double[] b = becomeUninfectiousRate.get().getDoubleValues(); // delta = mu + psi*r
double[] p = samplingProportion.get().getDoubleValues(); // if SAModel: s = psi/(mu+psi)
Double[] removalProbabilities = new Double[1];
if (SAModel) removalProbabilities = removalProbability.get().getValues();
double[] removalProbabilities = new double[1];
if (SAModel) removalProbabilities = removalProbability.get().getDoubleValues();

birth = new Double[totalIntervals];
death = new Double[totalIntervals];
psi = new Double[totalIntervals];
if (SAModel) r = new Double[totalIntervals];
birth = new double[totalIntervals];
death = new double[totalIntervals];
psi = new double[totalIntervals];
if (SAModel) r = new double[totalIntervals];

if (isBDSIR()) birth[0] = R[0] * b[0]; // the rest will be done in BDSIR class

Expand All @@ -953,10 +955,10 @@ protected void transformParameters() {

protected void transformParameters_d_r_s() {

birth = new Double[totalIntervals];
death = new Double[totalIntervals];
psi = new Double[totalIntervals];
if (SAModel) r = new Double[totalIntervals];
birth = new double[totalIntervals];
death = new double[totalIntervals];
psi = new double[totalIntervals];
if (SAModel) r = new double[totalIntervals];

/* nd = lambda - mu - r * psi lambda = nd / (1 - to)
to = (mu + r * psi) / lambda --> psi = lambda * to * sp / (1 - sp + r * sp)
Expand All @@ -972,14 +974,14 @@ protected void transformParameters_d_r_s() {
birth[i] = nd[index(times[i], birthRateChangeTimes)] / (1 - to[index(times[i], deathRateChangeTimes)]);
}
} else { // lambda-turnover-samplingproportion parametrization
Double[] br = birthRate.get().getValues();
double[] br = birthRate.get().getDoubleValues();
for (int i = 0; i < totalIntervals; i++) {
birth[i] = br[index(times[i], birthRateChangeTimes)];
}
}

if (SAModel) {
Double[] rp = removalProbability.get().getValues();
double[] rp = removalProbability.get().getDoubleValues();
for (int i = 0; i < totalIntervals; i++) {
r[i] = rp[index(times[i], rChangeTimes)];
psi[i] = birth[i] * to[index(times[i], deathRateChangeTimes)] / (1 / sp[index(times[i], samplingRateChangeTimes)] - 1 + r[i]);
Expand Down Expand Up @@ -1047,7 +1049,7 @@ public double calculateTreeLogLikelihood(TreeInterface tree) {
return logP;

if (taxonInput.get() != null) {
if (taxonAge > origin.get().getValue()) {
if (taxonAge > origin.get().getArrayValue()) {
return Double.NEGATIVE_INFINITY;
}
double x = times[totalIntervals - 1] - taxonAge;
Expand Down

0 comments on commit dbeae8c

Please sign in to comment.