Skip to content

Commit

Permalink
optimize hlaPredMerge()
Browse files Browse the repository at this point in the history
  • Loading branch information
zhengxwen committed Mar 12, 2024
1 parent 94df3b8 commit 788bace
Show file tree
Hide file tree
Showing 5 changed files with 134 additions and 19 deletions.
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
Package: HIBAG
Type: Package
Title: HLA Genotype Imputation with Attribute Bagging
Version: 1.38.3
Date: 2024-01-29
Version: 1.38.4
Date: 2024-03-11
Depends: R (>= 3.2.0)
Imports: methods, RcppParallel
Suggests: parallel, ggplot2, reshape2, gdsfmt, SNPRelate, SeqArray, knitr,
Expand Down
3 changes: 2 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ useDynLib(HIBAG,
HIBAG_SortAlleleStr, HIBAG_Kernel_Version, HIBAG_Kernel_SetTarget,
HIBAG_Predict_Resp, HIBAG_Predict_Dosage, HIBAG_Predict_Resp_Prob,
HIBAG_Training, HIBAG_SeqMerge, HIBAG_SeqRmDot,
HIBAG_bgzip_create
HIBAG_bgzip_create,
HIBAG_SumList, HIBAG_UpdateAddProbW, HIBAG_NormalizeProb
)

# Export function names
Expand Down
45 changes: 33 additions & 12 deletions R/HIBAG.R
Original file line number Diff line number Diff line change
Expand Up @@ -823,7 +823,8 @@ hlaPredict <- function(object, snp, cl=FALSE,
#

hlaPredMerge <- function(..., weight=NULL, equivalence=NULL, use.matching=TRUE,
ret.dosage=TRUE, ret.postprob=FALSE, max.resolution="", rm.suffix=FALSE)
ret.dosage=TRUE, ret.postprob=FALSE, max.resolution="", rm.suffix=FALSE,
verbose=TRUE)
{
# check "..."
pdlist <- list(...)
Expand All @@ -843,6 +844,13 @@ hlaPredMerge <- function(..., weight=NULL, equivalence=NULL, use.matching=TRUE,
"'hlaPredict(..., type=\"response+prob\")'.")
}
}
stopifnot(is.logical(verbose), length(verbose)==1L)
if (verbose)
{
cat("Aggregate ", length(pdlist), " set",
if (length(pdlist) > 1L) "s" else "",
" of predictions:\n", sep="")
}

# check equivalence
stopifnot(is.null(equivalence) | is.data.frame(equivalence))
Expand Down Expand Up @@ -924,11 +932,20 @@ hlaPredMerge <- function(..., weight=NULL, equivalence=NULL, use.matching=TRUE,
for (i in seq_along(pdlist))
{
h <- unique(unlist(strsplit(rownames(pdlist[[i]]$postprob), "/")))
hla.allele <- unique(c(hla.allele, replace(h)))
nh <- replace(h)
hla.allele <- unique(c(hla.allele, nh))
if (verbose)
{
cat(" ", i, ". # of unique alleles: ", length(h), sep="")
if (!is.null(equivalence)) cat(" ==> ", length(nh))
cat("\n")
}
}
hla.allele <- hlaUniqueAllele(hla.allele)
n.hla <- length(hla.allele)
n.samp <- length(samp.id)
if (verbose)
cat("# of unique allele in the merged set = ", n.hla, "\n", sep="")

prob <- matrix(0.0, nrow=n.hla*(n.hla+1L)/2L, ncol=n.samp)
m <- outer(hla.allele, hla.allele, paste, sep="/")
Expand All @@ -939,12 +956,16 @@ hlaPredMerge <- function(..., weight=NULL, equivalence=NULL, use.matching=TRUE,
# matching probabilities
has.matching <- all(
vapply(pdlist, function(x) !is.null(x$value$matching), TRUE))
if (has.matching) has.matching <- 0
if (has.matching)
{
# sum(i, weight[i] * pdlist[[i]]$value$matching)
has.matching <- .Call(HIBAG_SumList, weight,
lapply(pdlist, function(x) x$value$matching))
}

# for-loop
for (i in seq_along(pdlist))
{
w <- weight[i]
p <- pdlist[[i]]$postprob
h <- replace(unlist(strsplit(rownames(p), "/", fixed=TRUE)))
h1 <- h[seq(1L, length(h), 2L)]
Expand All @@ -953,17 +974,17 @@ hlaPredMerge <- function(..., weight=NULL, equivalence=NULL, use.matching=TRUE,
j2 <- match(paste(h2, h1, sep="/"), m)
j1[is.na(j1)] <- j2[is.na(j1)]
stopifnot(!anyNA(j1)) # check
if (use.matching)
p <- sweep(p, 2L, pdlist[[i]]$value$matching, "*")
p <- p * w
for (j in seq_along(j1))
prob[j1[j], ] <- prob[j1[j], ] + p[j, ]
if (is.numeric(has.matching))
has.matching <- has.matching + w * pdlist[[i]]$value$matching
# update probabilities, equal R code:
# if (use.matching)
# p <- sweep(p, 2L, pdlist[[i]]$value$matching, "*")
# prob[j1, ] <- prob[j1, ] + p * weight[i]
.Call(HIBAG_UpdateAddProbW, prob, j1, p, weight[i],
if (use.matching) pdlist[[i]]$value$matching else NULL)
}

# normalize prob
prob <- sweep(prob, 2L, colSums(prob), "/")
# equal: prob <- sweep(prob, 2L, colSums(prob), "/")
.Call(HIBAG_NormalizeProb, prob)
pb <- apply(prob, 2L, max)
pt <- unlist(strsplit(m[apply(prob, 2L, which.max)], "/", fixed=TRUE))
assembly <- pdlist[[1L]]$assembly
Expand Down
7 changes: 5 additions & 2 deletions man/hlaPredMerge.Rd
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ HLA types.
}
\usage{
hlaPredMerge(..., weight=NULL, equivalence=NULL, use.matching=TRUE,
ret.dosage=TRUE, ret.postprob=FALSE, max.resolution="", rm.suffix=FALSE)
ret.dosage=TRUE, ret.postprob=FALSE, max.resolution="", rm.suffix=FALSE,
verbose=TRUE)
}
\arguments{
\item{...}{The object(s) of \code{\link{hlaAlleleClass}}, having a field
Expand All @@ -20,7 +21,8 @@ hlaPredMerge(..., weight=NULL, equivalence=NULL, use.matching=TRUE,
sample sizes}
\item{equivalence}{a \code{data.frame} with two columns, the first column
for new equivalent alleles, and the second for the alleles possibly
existed in the object(s) passed to this function}
present in the object(s) passed to this function; an allele is left
unchanged if it is not found in the second column}
\item{use.matching}{if \code{TRUE}, use actual probabilities (i.e.,
posterior prob. * matching) for merging; otherwise, use posterior prob.
instead. \code{use.matching=TRUE} is recommended.}
Expand All @@ -32,6 +34,7 @@ hlaPredMerge(..., weight=NULL, equivalence=NULL, use.matching=TRUE,
"" for no limit on resolution}
\item{rm.suffix}{whether to remove the non-digit suffix in the last field,
e.g., for "01:22N", "N" is a non-digit suffix}
\item{verbose}{if \code{TRUE}, show information}
}
\details{
Calculate a new probability matrix for each pair of HLA alleles, by
Expand Down
94 changes: 92 additions & 2 deletions src/HIBAG.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// ===============================================================
//
// HIBAG R package (HLA Genotype Imputation with Attribute Bagging)
// Copyright (C) 2011-2022 Xiuwen Zheng (zhengx@u.washington.edu)
// Copyright (C) 2011-2024 Xiuwen Zheng (zhengx@u.washington.edu)
// All rights reserved.
//
// This program is free software: you can redistribute it and/or modify
Expand Down Expand Up @@ -1446,7 +1446,7 @@ SEXP HIBAG_Clear_GPU()


/**
* Call by .onLoad()
* Called by .onLoad()
**/
SEXP HIBAG_Init(SEXP data_lst)
{
Expand All @@ -1457,6 +1457,96 @@ SEXP HIBAG_Init(SEXP data_lst)
}


/**
 *  Weighted element-wise sum of the vectors held in a list:
 *      ans[k] = sum over i of weight[i] * list[[i]][k]
 *
 *  \param weight  one weight per list element; read through REAL(), so it
 *                 is assumed to be a double vector (TODO confirm at callers)
 *  \param list    an R list of vectors, all of the same length as the first
 *                 element; each is read through REAL()
 *  \return a new numeric vector with the weighted sum, or a logical FALSE
 *          when the list is empty
 *
 *  Raises an R error if 'weight' and 'list' differ in length, or if any
 *  list element differs in length from the first one.
 **/
SEXP HIBAG_SumList(SEXP weight, SEXP list)
{
	const int n_item = Rf_length(list);
	if (Rf_length(weight) != n_item)
		Rf_error("HIBAG_SumList error.");
	// an empty list has nothing to sum
	if (n_item <= 0)
		return Rf_ScalarLogical(FALSE);

	// the first element fixes the length of the result
	const int len = Rf_length(VECTOR_ELT(list, 0));
	SEXP ans = PROTECT(NEW_NUMERIC(len));
	double *acc = REAL(ans);
	for (int k=0; k < len; k++) acc[k] = 0;  // newly allocated, so zero fill

	const double *wt = REAL(weight);
	for (int i=0; i < n_item; i++)
	{
		SEXP item = VECTOR_ELT(list, i);
		if (Rf_length(item) != len)
			Rf_error("HIBAG_SumList, list length error.");
		const double *src = REAL(item);
		const double w = wt[i];
		for (int k=0; k < len; k++) acc[k] += w * src[k];
	}

	UNPROTECT(1);
	return ans;
}


/**
 *  Add a weighted matrix into selected rows of out_prob, in place:
 *      out_prob[ii[j], i] += in_p[j, i] * weight * m[i]
 *  where m[i] is matching[i] when 'matching' is given, and 1 when it is NULL.
 *
 *  \param out_prob  numeric matrix that is updated in place and returned
 *  \param ii        integer vector of 1-based row indices into out_prob,
 *                   one for each row of in_p
 *  \param in_p      numeric matrix with the same number of columns as out_prob
 *  \param weight    scalar weight applied to every entry of in_p
 *  \param matching  NULL, or a numeric vector with one value per column
 *  \return out_prob (the same object that was passed in)
 *
 *  Raises an R error if the inputs are not matrices, the dimensions or
 *  lengths disagree, or an index in 'ii' lies outside 1..nrow(out_prob).
 **/
SEXP HIBAG_UpdateAddProbW(SEXP out_prob, SEXP ii, SEXP in_p, SEXP weight,
	SEXP matching)
{
	// check
	if (!Rf_isMatrix(out_prob))
		Rf_error("HIBAG_UpdateAddProbW out_prob error.");
	if (!Rf_isMatrix(in_p))
		Rf_error("HIBAG_UpdateAddProbW in_p error.");
	const int *n1 = INTEGER(GET_DIM(out_prob));
	const int *n2 = INTEGER(GET_DIM(in_p));
	const int out_nrow = n1[0];
	const int in_nrow = n2[0], in_ncol = n2[1];
	if (n1[1] != in_ncol)
		Rf_error("HIBAG_UpdateAddProbW dim(prob) error.");
	if (Rf_length(ii) != in_nrow)
		Rf_error("HIBAG_UpdateAddProbW ii error.");
	// 'matching' is optional: only check its length when it is supplied,
	// since Rf_length(NULL) is 0 and would otherwise always be rejected
	const bool has_matching = !Rf_isNull(matching);
	if (has_matching && Rf_length(matching) != in_ncol)
		Rf_error("HIBAG_UpdateAddProbW matching error.");
	// validate all row indices before writing, so that out_prob is never
	// partially updated or written out of bounds (NA_integer_ is < 1)
	const int *I = INTEGER(ii);
	for (int j=0; j < in_nrow; j++)
	{
		if (I[j] < 1 || I[j] > out_nrow)
			Rf_error("HIBAG_UpdateAddProbW ii index out of range.");
	}

	// update
	const double w = Rf_asReal(weight);
	const double *pm = has_matching ? REAL(matching) : NULL;
	double *po = REAL(out_prob);
	const double *pi = REAL(in_p);
	for (int i=0; i < in_ncol; i++)
	{
		// per-column factor: weight, times the matching value when given
		const double w2 = pm ? w * pm[i] : w;
		for (int j=0; j < in_nrow; j++)
			po[I[j] - 1] += pi[j] * w2;
		// step both pointers to the next column
		po += out_nrow;
		pi += in_nrow;
	}
	return out_prob;
}


/**
 *  Rescale every column of a numeric matrix, in place, by dividing each
 *  entry by the sum of its column.
 *
 *  \param prob  numeric matrix, modified in place
 *  \return prob (the same object that was passed in)
 *
 *  Raises an R error if 'prob' is not a matrix. A column sum of zero is not
 *  treated specially, so the division is carried out as is.
 **/
SEXP HIBAG_NormalizeProb(SEXP prob)
{
	if (!Rf_isMatrix(prob))
		Rf_error("HIBAG_NormalizeProb prob error.");
	const int *dims = INTEGER(GET_DIM(prob));
	const int n_row = dims[0];
	double *col = REAL(prob);
	// walk the matrix one column at a time
	for (int c = dims[1]; c > 0; c--, col += n_row)
	{
		double *const col_end = col + n_row;
		// first pass: the column total
		double total = 0;
		for (const double *q = col; q < col_end; q++) total += *q;
		// second pass: divide each entry by the total
		for (double *q = col; q < col_end; q++) *q /= total;
	}
	return prob;
}


// -----------------------------------------------------------------------
// -----------------------------------------------------------------------

Expand Down

0 comments on commit 788bace

Please sign in to comment.