Skip to content

Commit

Permalink
Add multivariate distribution
Browse files Browse the repository at this point in the history
  • Loading branch information
jchristopherson committed Feb 12, 2025
1 parent 90314c0 commit 3b959ef
Show file tree
Hide file tree
Showing 4 changed files with 277 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/fstats.f90
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ module fstats
public :: f_distribution
public :: chi_squared_distribution
public :: binomial_distribution
public :: multivariate_distribution
public :: multivariate_distribution_function
public :: multivariate_normal_distribution
public :: mean
public :: variance
public :: standard_deviation
Expand Down
231 changes: 231 additions & 0 deletions src/fstats_distributions.f90
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ module fstats_distributions
use ieee_arithmetic
use fstats_special_functions
use fstats_helper_routines
use ferror
use fstats_errors
implicit none
private
public :: distribution
Expand All @@ -13,6 +15,9 @@ module fstats_distributions
public :: f_distribution
public :: chi_squared_distribution
public :: binomial_distribution
public :: multivariate_distribution
public :: multivariate_distribution_function
public :: multivariate_normal_distribution

real(real64), parameter :: pi = 2.0d0 * acos(0.0d0)

Expand Down Expand Up @@ -137,6 +142,48 @@ pure function distribution_property(this) result(rst)
procedure, public :: variance => bd_variance
end type

! ******************************************************************************
! MULTIVARIATE DISTRIBUTIONS
! ------------------------------------------------------------------------------
type, abstract :: multivariate_distribution
    !! Abstract base type for probability distributions that are evaluated
    !! at a vector of values rather than at a single scalar.
contains
    procedure(multivariate_distribution_function), pass, deferred :: pdf
        !! Evaluates the probability density function.  Concrete
        !! extensions must supply the implementation.
end type multivariate_distribution

interface
    pure function multivariate_distribution_function(this, x) result(rst)
        !! Signature of a multivariate probability distribution function:
        !! it maps a vector of values onto a single scalar result.
        use iso_fortran_env, only : real64
        import :: multivariate_distribution
        class(multivariate_distribution), intent(in) :: this
            !! The distribution object.
        real(real64), intent(in) :: x(:)
            !! The values at which to evaluate the function.
        real(real64) :: rst
            !! The value of the function.
    end function multivariate_distribution_function
end interface

! ------------------------------------------------------------------------------
type, extends(multivariate_distribution) :: multivariate_normal_distribution
    !! A multivariate normal (Gaussian) distribution.  The private state
    !! is populated by the initialize binding and read by the pdf binding.
    real(real64), allocatable, private :: m_means(:)
        !! The N mean values.
    real(real64), allocatable, private :: m_cov(:,:)
        !! The N-by-N covariance matrix, which must be positive-definite.
    real(real64), allocatable, private :: m_covInv(:,:)
        !! The N-by-N inverse of the covariance matrix.
    real(real64), private :: m_covDet
        !! The determinant of the covariance matrix.
contains
    procedure, public :: initialize => mvnd_init
    procedure, public :: pdf => mvnd_pdf
end type multivariate_normal_distribution

contains
! ------------------------------------------------------------------------------
pure elemental function dist_std_var(this, x) result(rst)
Expand Down Expand Up @@ -658,5 +705,189 @@ pure function bd_variance(this) result(rst)
rst = this%n * this%p * (1.0d0 - this%p)
end function

! ******************************************************************************
! MULTIVARIATE NORMAL DISTRIBUTION
! ------------------------------------------------------------------------------
subroutine mvnd_init(this, mu, sigma, err)
    use linalg, only : cholesky_factor
    !! Initializes the multivariate normal distribution by defining the mean
    !! values and covariance matrix.
    !!
    !! All derived quantities (Cholesky factor, inverse, determinant) are
    !! computed in local storage first; the object is only modified once
    !! every step has succeeded.  If an error is reported the object is left
    !! exactly as it was on entry.
    class(multivariate_normal_distribution), intent(inout) :: this
        !! The multivariate_normal_distribution object.
    real(real64), intent(in), dimension(:) :: mu
        !! An N-element array containing the mean values.
    real(real64), intent(in), dimension(:,:) :: sigma
        !! The N-by-N covariance matrix.  The PDF exists only if this matrix
        !! is positive-definite.  The matrix is Cholesky-factored here, and
        !! any error reported by that factorization (expected for a matrix
        !! that is not positive-definite) aborts the initialization.
    class(errors), intent(inout), optional, target :: err
        !! The error handling object.  If omitted, a default-constructed
        !! errors object local to this routine is used.

    ! Local Variables
    integer(int32) :: n, flag
    real(real64), allocatable, dimension(:,:) :: L, covInv
    class(errors), pointer :: errmgr
    type(errors), target :: deferr

    ! Select the error handler
    if (present(err)) then
        errmgr => err
    else
        errmgr => deferr
    end if
    n = size(mu)

    ! Input Checking: sigma must be N-by-N
    if (size(sigma, 1) /= n .or. size(sigma, 2) /= n) then
        call report_matrix_size_error(errmgr, "mvnd_init", "sigma", n, n, &
            size(sigma, 1), size(sigma, 2))
        return
    end if

    ! Workspace: L starts as a copy of sigma and is factored in place;
    ! covInv receives the inverse.
    allocate(L(n, n), stat = flag, source = sigma)
    if (flag == 0) allocate(covInv(n, n), stat = flag)
    if (flag /= 0) then
        call report_memory_error(errmgr, "mvnd_init", flag)
        return
    end if

    ! Compute the Cholesky factorization of the covariance matrix
    call cholesky_factor(L, upper = .false., err = errmgr)
    if (errmgr%has_error_occurred()) return

    ! Compute the inverse from the factorization
    call populate_identity(covInv)
    call cholesky_inverse(L, covInv)

    ! Everything succeeded - commit the new state to the object.
    this%m_means = mu
    this%m_cov = sigma
    call move_alloc(covInv, this%m_covInv)
    this%m_covDet = cholesky_determinant(L)
end subroutine

! ------------------------------------------------------------------------------
pure function mvnd_pdf(this, x) result(rst)
    !! Evaluates the PDF for the multivariate normal distribution:
    !!
    !! exp(-(x - mu)**T * inv(sigma) * (x - mu) / 2) / sqrt((2 pi)**N * det(sigma))
    !!
    !! Because this is a pure function it cannot report errors.  If the
    !! object's mean vector or inverse covariance matrix is not allocated,
    !! or if the length of x does not match their dimensions, a quiet NaN
    !! is returned.
    class(multivariate_normal_distribution), intent(in) :: this
        !! The multivariate_normal_distribution object.
    real(real64), intent(in), dimension(:) :: x
        !! The values at which to evaluate the function.  Must have the same
        !! number of elements as the stored mean vector.
    real(real64) :: rst
        !! The value of the function, or NaN for an unusable object or
        !! a mismatched x.

    ! Local Variables
    integer(int32) :: n
    real(real64) :: arg
    real(real64), allocatable, dimension(:) :: delta, prod

    ! Guard against unallocated state and mismatched sizes.  The tests are
    ! kept in separate statements because Fortran does not guarantee
    ! short-circuit evaluation of logical expressions.
    n = size(x)
    rst = ieee_value(rst, ieee_quiet_nan)
    if (.not.allocated(this%m_means)) return
    if (.not.allocated(this%m_covInv)) return
    if (size(this%m_means) /= n) return
    if (size(this%m_covInv, 1) /= n .or. size(this%m_covInv, 2) /= n) return

    ! Process
    delta = x - this%m_means
    prod = matmul(this%m_covInv, delta) ! prod = inv(sigma) * (x - mu)
    arg = dot_product(delta, prod)      ! arg = (x - mu)**T * prod
    rst = exp(-0.5d0 * arg) / sqrt((2.0d0 * pi)**n * this%m_covDet)
end function

! ******************************************************************************
! SUPPORTING ROUTINES
! ------------------------------------------------------------------------------
subroutine cholesky_inverse(x, u)
    use linalg, only : solve_triangular_system
    !! Computes the inverse of a matrix A from its lower-triangular Cholesky
    !! factor L, where A = L * L**T.
    !!
    !! Since (L * L**T) * inv(A) = I, the inverse follows from two
    !! triangular solves carried out in place on u:
    !!
    !! 1. L * U = I is solved for U, and then
    !! 2. L**T * inv(A) = U is solved for inv(A).
    real(real64), intent(in), dimension(:,:) :: x
        !! The lower-triangular Cholesky factored matrix.
    real(real64), intent(inout), dimension(:,:) :: u
        !! On input, an N-by-N identity matrix. On output, the N-by-N inverted
        !! matrix.

    ! Flags shared by both solves.  The names reflect the linalg
    ! solve_triangular_system argument order (side, upper, transpose,
    ! non-unit diagonal) - NOTE(review): confirm against the linalg docs.
    logical, parameter :: left_side = .true.
    logical, parameter :: upper_tri = .false.
    logical, parameter :: nonunit_diag = .true.

    ! Step 1: solve L * U = I for U (no transpose)
    call solve_triangular_system(left_side, upper_tri, .false., &
        nonunit_diag, 1.0d0, x, u)

    ! Step 2: solve L**T * inv(A) = U for inv(A) (transpose)
    call solve_triangular_system(left_side, upper_tri, .true., &
        nonunit_diag, 1.0d0, x, u)
end subroutine

! ------------------------------------------------------------------------------
pure function cholesky_determinant(x) result(rst)
    !! Computes the determinant of a matrix A = L * L**T from its lower
    !! Cholesky factor L, i.e. the square of the product of the diagonal
    !! entries of L.
    !!
    !! The product is accumulated as a normalized mantissa and a separate
    !! binary exponent so that intermediate results can neither overflow nor
    !! underflow; only the final value is subject to the range of real64.
    real(real64), intent(in), dimension(:,:) :: x
        !! The lower-triangular Cholesky-factored matrix.
    real(real64) :: rst
        !! The determinant.  Zero if any diagonal entry is zero; Inf or NaN
        !! if a diagonal entry is Inf or NaN; 1 for an empty matrix.

    ! Local Variables
    integer(int32) :: i, n, ex
    real(real64) :: d, mant

    ! Initialization - only entries common to both dimensions are visited
    n = min(size(x, 1), size(x, 2))
    mant = 1.0d0
    ex = 0

    ! Accumulate prod(diag(L)) = mant * 2**ex
    do i = 1, n
        d = x(i,i)

        ! A zero on the diagonal makes the determinant exactly zero
        if (d == 0.0d0) then
            rst = 0.0d0
            return
        end if

        ! fraction/exponent are not meaningful for Inf or NaN; propagate
        ! the non-finite value instead (Inf**2 = Inf, NaN**2 = NaN).
        if (.not.ieee_is_finite(d)) then
            rst = d * d
            return
        end if

        ! Multiply the mantissas, add the exponents, and renormalize the
        ! running mantissa back into [0.5, 1) in magnitude.
        mant = mant * fraction(d)
        ex = ex + exponent(d) + exponent(mant)
        mant = fraction(mant)
    end do

    ! det(A) = (mant * 2**ex)**2
    rst = scale(mant * mant, 2 * ex)
end function

! ------------------------------------------------------------------------------
subroutine populate_identity(x)
    !! Overwrites the supplied matrix with an identity matrix: every entry
    !! is set to zero, then the main diagonal is set to one.  For a
    !! non-square matrix the diagonal runs over the smaller dimension.
    real(real64), intent(inout), dimension(:,:) :: x
        !! The matrix to overwrite.

    ! Local Variables
    integer(int32) :: k

    ! Clear everything, then place ones on the main diagonal
    x = 0.0d0
    do k = 1, min(size(x, 1), size(x, 2))
        x(k,k) = 1.0d0
    end do
end subroutine

! ------------------------------------------------------------------------------
end module
40 changes: 40 additions & 0 deletions tests/fstats_distribution_tests.f90
Original file line number Diff line number Diff line change
Expand Up @@ -219,5 +219,45 @@ function test_standardized_variable() result(rst)
end if
end function

! ------------------------------------------------------------------------------
function test_multivariate_normal_distribution() result(rst)
    use linalg, only : mtx_inverse, det
    !! Checks the PDF of a randomly generated bivariate normal distribution
    !! against a direct evaluation of the closed-form expression that uses
    !! an explicit matrix inverse and determinant.
    ! Arguments
    logical :: rst
        !! True if the test passed; false otherwise.

    ! Parameters
    real(real64), parameter :: pi = 2.0d0 * acos(0.0d0)
    real(real64), parameter :: tol = 1.0d-8

    ! Local Variables
    real(real64) :: x(2), mu(2), rho, s1, s2, sigma(2, 2), arg, ans, phi, &
        dsig, inv(2, 2)
    type(multivariate_normal_distribution) :: dist

    ! Initialization
    rst = .true.
    call random_number(x)
    call random_number(mu)
    call random_number(rho)
    call random_number(s1)
    call random_number(s2)

    ! random_number returns values in [0, 1).  Used directly, a standard
    ! deviation of (nearly) zero or a correlation of (nearly) one would give
    ! a singular or ill-conditioned covariance matrix, and very small
    ! variances make both PDF values underflow to zero so the comparison
    ! below would pass vacuously.  Map the draws into a safe range:
    ! standard deviations in [0.5, 1.5) and correlation in [0, 0.9).
    rho = 0.9d0 * rho
    s1 = 0.5d0 + s1
    s2 = 0.5d0 + s2

    sigma = reshape([s2**2, -rho * s1 * s2, -rho * s1 * s2, s1**2], [2, 2])
    call dist%initialize(mu, sigma)

    ! Compute the actual solution
    inv = sigma
    call mtx_inverse(inv)
    arg = -0.5d0 * dot_product(x - mu, matmul(inv, x - mu))
    dsig = det(sigma)
    ans = exp(arg) / sqrt((2.0d0 * pi)**2 * dsig)

    ! Test
    phi = dist%pdf(x)
    if (.not.is_equal(phi, ans, tol)) then
        rst = .false.
        print "(A)", "TEST FAILED: test_multivariate_normal_distribution -1"
    end if
end function

! ------------------------------------------------------------------------------
end module
3 changes: 3 additions & 0 deletions tests/fstats_tests.f90
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ program tests
local = binomial_distribution_test_1()
if (.not.local) overall = .false.

local = test_multivariate_normal_distribution()
if (.not.local) overall = .false.

! Statistics Tests
local = mean_test_1()
if (.not.local) overall = .false.
Expand Down

0 comments on commit 3b959ef

Please sign in to comment.