Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/BEAST2-Dev/bdsky
Browse files Browse the repository at this point in the history
  • Loading branch information
denisekuehnert committed Jun 29, 2016
2 parents 81b524a + 77048ba commit f514dc0
Show file tree
Hide file tree
Showing 2 changed files with 264 additions and 19 deletions.
67 changes: 58 additions & 9 deletions src/beast/evolution/speciation/BirthDeathSkylineModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,12 @@ public class BirthDeathSkylineModel extends SpeciesTreeDistribution {
public Input<RealParameter> samplingProportion =
new Input<RealParameter>("samplingProportion", "The samplingProportion = samplingRate / becomeUninfectiousRate", Input.Validate.XOR, samplingRate);

public Input<RealParameter> netDiversification =
new Input<RealParameter>("netDiversification", "Net diversification rate", Input.Validate.XOR, birthRate);
public Input<RealParameter> turnOver =
new Input<RealParameter>("turnOver", "Turn over rate", Input.Validate.XOR, deathRate);
// public Input<RealParameter> samplingProportion =
// new Input<RealParameter>("samplingProportion", "Sampling proportion", Input.Validate.XOR, samplingRate);

public Input<Boolean> forceRateChange =
new Input<Boolean>("forceRateChange", "If there is more than one interval and we estimate the time of rate change, do we enforce it to be within the tree interval? Default true", true);
Expand Down Expand Up @@ -175,7 +181,7 @@ public class BirthDeathSkylineModel extends SpeciesTreeDistribution {

protected Double[] times = new Double[]{0.};

protected Boolean transform;
protected Boolean transform, transform_d_r_s;
Boolean m_forceRateChange;

Boolean birthRateTimesRelative = false;
Expand Down Expand Up @@ -226,10 +232,9 @@ public void initAndValidate() {
rhoSamplingCount = 0;
printTempResults = false;


transform = transform_d_r_s = false;
if (birthRate.get() != null && deathRate.get() != null && samplingRate.get() != null) {

transform = false;
death = deathRate.get().getValues();
psi = samplingRate.get().getValues();
birth = birthRate.get().getValues();
Expand All @@ -239,8 +244,14 @@ public void initAndValidate() {

transform = true;

} else if (netDiversification.get() != null && turnOver.get() != null && samplingProportion.get() != null) {

transform_d_r_s = true;

} else {
throw new RuntimeException("Either specify birthRate, deathRate and samplingRate OR specify R0, becomeUninfectiousRate and samplingProportion!");
throw new RuntimeException("Either specify birthRate, deathRate and samplingRate " +
"OR specify R0, becomeUninfectiousRate and samplingProportion " +
"OR specify netDiversification, turnOver and samplingProportion!");
}


Expand All @@ -250,6 +261,12 @@ public void initAndValidate() {
samplingChanges = samplingProportion.get().getDimension() - 1;
deathChanges = becomeUninfectiousRate.get().getDimension() - 1;

} else if (transform_d_r_s) {

if (birthChanges < 1) birthChanges = netDiversification.get().getDimension() - 1;
deathChanges = turnOver.get().getDimension() - 1;
samplingChanges = samplingProportion.get().getDimension() - 1;

} else {

if (birthChanges < 1) birthChanges = birthRate.get().getDimension() - 1;
Expand Down Expand Up @@ -562,6 +579,8 @@ protected Double updateRatesAndTimes(TreeInterface tree) {

if (transform)
transformParameters();
else if (transform_d_r_s)
transformParameters_d_r_s();
else {

Double[] birthRates = birthRate.get().getValues();
Expand Down Expand Up @@ -592,9 +611,6 @@ protected Double updateRatesAndTimes(TreeInterface tree) {
}
}




if (m_rho.get() != null && (m_rho.get().getDimension()==1 || rhoSamplingTimes.get() != null)) {

Double[] rhos = m_rho.get().getValues();
Expand Down Expand Up @@ -852,7 +868,7 @@ protected void transformParameters() {

Double[] R = R0.get().getValues(); // if SAModel: R0 = lambda/delta
Double[] b = becomeUninfectiousRate.get().getValues(); // delta = mu + psi*r
Double[] p = samplingProportion.get().getValues(); // if SAModel: s = psi/(mu+psi*r)
Double[] p = samplingProportion.get().getValues(); // if SAModel: s = psi/(mu+psi)
Double[] removalProbabilities = new Double[1];
if (SAModel) removalProbabilities = removalProbability.get().getValues();

Expand All @@ -871,7 +887,7 @@ protected void transformParameters() {
} else {
birth[i] = R[birthChanges > 0 ? index(times[i], birthRateChangeTimes) : 0] * b[deathChanges > 0 ? index(times[i], deathRateChangeTimes) : 0];
r[i] = removalProbabilities[rChanges > 0 ? index(times[i], rChangeTimes) : 0];
psi[i] = p[samplingChanges > 0 ? index(times[i], samplingRateChangeTimes) : 0] * b[deathChanges > 0 ? index(times[i], deathRateChangeTimes) : 0]
psi[i] = p[samplingChanges > 0 ? index(times[i], samplingRateChangeTimes) : 0] * b[deathChanges > 0 ? index(times[i], deathRateChangeTimes) : 0]
/ (1+(r[i]-1)*p[samplingChanges > 0 ? index(times[i], samplingRateChangeTimes) : 0]);
death[i] = b[deathChanges > 0 ? index(times[i], deathRateChangeTimes) : 0] - psi[i]*r[i];

Expand All @@ -881,6 +897,39 @@ protected void transformParameters() {
}
}

protected void transformParameters_d_r_s() {

Double[] nd = netDiversification.get().getValues();
Double[] to = turnOver.get().getValues();
Double[] sp = samplingProportion.get().getValues();
Double[] rp = new Double[1];
if (SAModel) rp = removalProbability.get().getValues();

birth = new Double[totalIntervals];
death = new Double[totalIntervals];
psi = new Double[totalIntervals];
if (SAModel) r = new Double[totalIntervals];

/* nd = lambda - mu - r * psi lambda = nd / (1 - to)
to = (mu + r * psi) / lambda --> mu = lambda * to * (1 - sp) / (1 - sp + r * sp)
sp = psi / (mu + psi) psi = mu * sp / (1 - sp)
SAModel: 0 <= rp < 1; No SA: rp = 1
Relation to transform: nd = (R0 - 1) * delta, to = 1/R0, sp = s
*/ // isBDSIR()???
for (int i = 0; i < totalIntervals; i++) {
birth[i] = nd[birthChanges > 0 ? index(times[i], birthRateChangeTimes) : 0] / (1 - to[deathChanges > 0 ? index(times[i], deathRateChangeTimes) : 0]);
if (SAModel) {
r[i] = rp[rChanges > 0 ? index(times[i], rChangeTimes) : 0];
death[i] = birth[i] * to[deathChanges > 0 ? index(times[i], deathRateChangeTimes) : 0]
/ (1 + r[i] / (1 / sp[samplingChanges > 0 ? index(times[i], samplingRateChangeTimes) : 0] - 1));
} else {
death[i] = birth[i] * to[deathChanges > 0 ? index(times[i], deathRateChangeTimes) : 0]
* (1 - sp[samplingChanges > 0 ? index(times[i], samplingRateChangeTimes) : 0]);
}
psi[i] = death[i] / (1 / sp[samplingChanges > 0 ? index(times[i], samplingRateChangeTimes) : 0] - 1);
}
}

@Override
public double calculateTreeLogLikelihood(TreeInterface tree) {

Expand Down
Loading

0 comments on commit f514dc0

Please sign in to comment.