diff --git a/src/fstats.f90 b/src/fstats.f90
index 7a1608e..0d818f0 100644
--- a/src/fstats.f90
+++ b/src/fstats.f90
@@ -25,6 +25,9 @@ module fstats
     public :: f_distribution
     public :: chi_squared_distribution
     public :: binomial_distribution
+    public :: multivariate_distribution
+    public :: multivariate_distribution_function
+    public :: multivariate_normal_distribution
     public :: mean
     public :: variance
     public :: standard_deviation
diff --git a/src/fstats_distributions.f90 b/src/fstats_distributions.f90
index 559afa0..5f536ea 100644
--- a/src/fstats_distributions.f90
+++ b/src/fstats_distributions.f90
@@ -3,6 +3,8 @@ module fstats_distributions
     use ieee_arithmetic
     use fstats_special_functions
     use fstats_helper_routines
+    use ferror
+    use fstats_errors
     implicit none
     private
     public :: distribution
@@ -13,6 +15,9 @@ module fstats_distributions
     public :: f_distribution
     public :: chi_squared_distribution
     public :: binomial_distribution
+    public :: multivariate_distribution
+    public :: multivariate_distribution_function
+    public :: multivariate_normal_distribution
 
     real(real64), parameter :: pi = 2.0d0 * acos(0.0d0)
 
@@ -137,6 +142,48 @@ pure function distribution_property(this) result(rst)
         procedure, public :: variance => bd_variance
     end type
 
+! ******************************************************************************
+! MULTIVARIATE DISTRIBUTIONS
+! ------------------------------------------------------------------------------
+    type, abstract :: multivariate_distribution
+        !! Defines a multivariate probability distribution.
+    contains
+        procedure(multivariate_distribution_function), deferred, pass :: pdf
+            !! Computes the probability density function.
+    end type
+
+    interface
+        pure function multivariate_distribution_function(this, x) result(rst)
+            !! Defines an interface for a multivariate probability distribution
+            !! function.
+            use iso_fortran_env, only : real64
+            import multivariate_distribution
+            class(multivariate_distribution), intent(in) :: this
+                !! The distribution object.
+            real(real64), intent(in), dimension(:) :: x
+                !! The values at which to evaluate the function.
+            real(real64) :: rst
+                !! The value of the function.
+        end function
+    end interface
+
+! ------------------------------------------------------------------------------
+    type, extends(multivariate_distribution) :: multivariate_normal_distribution
+        !! Defines a multivariate normal (Gaussian) distribution.
+        real(real64), private, allocatable, dimension(:) :: m_means
+            !! An N-element array of mean values.
+        real(real64), private, allocatable, dimension(:,:) :: m_cov
+            !! The N-by-N covariance matrix.  This matrix must be
+            !! positive-definite.
+        real(real64), private, allocatable, dimension(:,:) :: m_covInv
+            !! The N-by-N inverse of the covariance matrix.
+        real(real64), private :: m_covDet
+            !! The determinant of the covariance matrix.
+    contains
+        procedure, public :: initialize => mvnd_init
+        procedure, public :: pdf => mvnd_pdf
+    end type
+
 contains
 ! ------------------------------------------------------------------------------
 pure elemental function dist_std_var(this, x) result(rst)
@@ -658,5 +705,194 @@ pure function bd_variance(this) result(rst)
     rst = this%n * this%p * (1.0d0 - this%p)
 end function
 
+! ******************************************************************************
+! MULTIVARIATE NORMAL DISTRIBUTION
+! ------------------------------------------------------------------------------
+subroutine mvnd_init(this, mu, sigma, err)
+    !! Initializes the multivariate normal distribution by defining the mean
+    !! values and covariance matrix.  The object is updated only after every
+    !! check, allocation, and the Cholesky factorization has succeeded, so a
+    !! failed call leaves any previously stored state untouched.
+    use linalg, only : cholesky_factor
+    class(multivariate_normal_distribution), intent(inout) :: this
+        !! The multivariate_normal_distribution object.
+    real(real64), intent(in), dimension(:) :: mu
+        !! An N-element array containing the mean values.
+    real(real64), intent(in), dimension(:,:) :: sigma
+        !! The N-by-N covariance matrix.  The PDF exists only if this matrix
+        !! is positive-definite; therefore, the positive-definite constraint
+        !! is checked within this routine and enforced.  An error is thrown if
+        !! the supplied matrix is not positive-definite.
+    class(errors), intent(inout), optional, target :: err
+        !! The error handling object.
+
+    ! Local Variables
+    integer(int32) :: n, flag
+    real(real64), allocatable, dimension(:,:) :: L, covInv
+    class(errors), pointer :: errmgr
+    type(errors), target :: deferr
+
+    ! Initialization
+    if (present(err)) then
+        errmgr => err
+    else
+        errmgr => deferr
+    end if
+    n = size(mu)
+
+    ! Input Checking
+    if (size(sigma, 1) /= n .or. size(sigma, 2) /= n) then
+        call report_matrix_size_error(errmgr, "mvnd_init", "sigma", n, n, &
+            size(sigma, 1), size(sigma, 2))
+        return
+    end if
+
+    ! Allocate local workspace: L starts as a copy of sigma and is handed to
+    ! the factorization routine; covInv receives the inverse.
+    allocate(L(n, n), stat = flag, source = sigma)
+    if (flag == 0) allocate(covInv(n, n), stat = flag)
+    if (flag /= 0) then
+        call report_memory_error(errmgr, "mvnd_init", flag)
+        return
+    end if
+
+    ! Compute the Cholesky factorization of the covariance matrix
+    call cholesky_factor(L, upper = .false., err = errmgr)
+    if (errmgr%has_error_occurred()) return
+
+    ! Compute the inverse from the factorization
+    call populate_identity(covInv)
+    call cholesky_inverse(L, covInv)
+
+    ! Everything succeeded - commit the results to the object
+    this%m_means = mu
+    this%m_cov = sigma
+    call move_alloc(covInv, this%m_covInv)
+    this%m_covDet = cholesky_determinant(L)
+end subroutine
+
+! ------------------------------------------------------------------------------
+pure function mvnd_pdf(this, x) result(rst)
+    !! Evaluates the PDF for the multivariate normal distribution.  A quiet
+    !! NaN is returned if the distribution has not been initialized or if
+    !! the size of x does not match the number of stored mean values.
+    class(multivariate_normal_distribution), intent(in) :: this
+        !! The multivariate_normal_distribution object.
+    real(real64), intent(in), dimension(:) :: x
+        !! The values at which to evaluate the function.
+    real(real64) :: rst
+        !! The value of the function.
+
+    ! Local Variables
+    integer(int32) :: n
+    real(real64) :: arg
+    real(real64), allocatable, dimension(:) :: delta, prod
+
+    ! Guard against use before initialization or a mismatched input size.
+    ! The allocation tests are kept separate from the size test as Fortran
+    ! does not guarantee short-circuit evaluation of logical expressions.
+    n = size(x)
+    rst = ieee_value(1.0d0, ieee_quiet_nan)
+    if (.not.allocated(this%m_means)) return
+    if (.not.allocated(this%m_covInv)) return
+    if (size(this%m_means) /= n) return
+
+    ! Process
+    delta = x - this%m_means
+    prod = matmul(this%m_covInv, delta)     ! prod = inv(sigma) * (x - mu)
+    arg = dot_product(delta, prod)          ! arg = (x - mu)**T * prod
+    rst = exp(-0.5d0 * arg) / sqrt((2.0d0 * pi)**n * this%m_covDet)
+end function
+
+! ******************************************************************************
+! SUPPORTING ROUTINES
+! ------------------------------------------------------------------------------
+subroutine cholesky_inverse(x, u)
+    !! Computes the inverse of a Cholesky-factored matrix.
+    use linalg, only : solve_triangular_system
+    real(real64), intent(in), dimension(:,:) :: x
+        !! The lower-triangular Cholesky factored matrix.
+    real(real64), intent(inout), dimension(:,:) :: u
+        !! On input, an N-by-N identity matrix.  On output, the N-by-N inverted
+        !! matrix.
+
+    ! To compute the inverse of a Cholesky factored matrix (L) consider the
+    ! following:
+    !
+    ! A = L * L**T
+    !
+    ! (L * L**T) * inv(A) = I, where I is an identity matrix
+    !
+    ! First, solve L * U = I, for the N-by-N matrix U
+    !
+    ! And then solve L**T * inv(A) = U for inv(A)
+
+    ! Solve L * U = I for U
+    call solve_triangular_system(.true., .false., .false., .true., 1.0d0, x, u)
+
+    ! Solve L**T * inv(A) = U for inv(A)
+    call solve_triangular_system(.true., .false., .true., .true., 1.0d0, x, u)
+end subroutine
+
+! ------------------------------------------------------------------------------
+pure function cholesky_determinant(x) result(rst)
+    !! Computes the determinant of a Cholesky factored (lower) matrix as the
+    !! product of the squares of the diagonal entries of the factor.
+    real(real64), intent(in), dimension(:,:) :: x
+        !! The lower-triangular Cholesky-factored matrix.
+    real(real64) :: rst
+        !! The determinant.
+
+    ! Local Variables
+    integer(int32) :: i, ep, n
+    real(real64) :: temp
+
+    ! Initialization
+    n = size(x, 1)
+    rst = 0.0d0
+
+    ! Compute the product of the squares of the diagonal.  The running
+    ! product is rescaled into [1, 10] with the power of ten tracked in ep.
+    temp = 1.0d0
+    ep = 0
+    do i = 1, n
+        temp = (x(i,i))**2 * temp
+        if (temp == 0.0d0) then
+            rst = 0.0d0
+            return
+        end if
+        if (.not.ieee_is_finite(temp)) then
+            ! Inf or NaN cannot be rescaled; the loop below would never
+            ! terminate on Inf, so return the value as-is.
+            rst = temp
+            return
+        end if
+
+        do while (abs(temp) < 1.0d0)
+            temp = 1.0d1 * temp
+            ep = ep - 1
+        end do
+
+        do while (abs(temp) > 1.0d1)
+            temp = 1.0d-1 * temp
+            ep = ep + 1
+        end do
+    end do
+    rst = temp * (1.0d1)**ep
+end function
+
+! ------------------------------------------------------------------------------
+subroutine populate_identity(x)
+    !! Populates the supplied matrix as an identity matrix.
+    real(real64), intent(inout), dimension(:,:) :: x
+
+    ! Local Variables
+    integer(int32) :: i, m, n, mn
+
+    ! Process
+    m = size(x, 1)
+    n = size(x, 2)
+    mn = min(m, n)
+    x = 0.0d0
+    do i = 1, mn
+        x(i,i) = 1.0d0
+    end do
+end subroutine
+
 ! ------------------------------------------------------------------------------
 end module
\ No newline at end of file
diff --git a/tests/fstats_distribution_tests.f90 b/tests/fstats_distribution_tests.f90
index 6e9c1bc..3aed30f 100644
--- a/tests/fstats_distribution_tests.f90
+++ b/tests/fstats_distribution_tests.f90
@@ -219,5 +219,45 @@ function test_standardized_variable() result(rst)
         end if
     end function
 
+! ------------------------------------------------------------------------------
+    function test_multivariate_normal_distribution() result(rst)
+        use linalg, only : mtx_inverse, det
+        ! Arguments
+        logical :: rst
+
+        ! Parameters
+        real(real64), parameter :: pi = 2.0d0 * acos(0.0d0)
+        real(real64), parameter :: tol = 1.0d-8
+
+        ! Local Variables
+        real(real64) :: x(2), mu(2), rho, s1, s2, sigma(2, 2), arg, ans, phi, &
+            dsig, inv(2, 2)
+        type(multivariate_normal_distribution) :: dist
+
+        ! Initialization
+        rst = .true.
+        call random_number(x)
+        call random_number(mu)
+        call random_number(rho)
+        call random_number(s1)
+        call random_number(s2)
+        sigma = reshape([s2**2, -rho * s1 * s2, -rho * s1 * s2, s1**2], [2, 2])
+        call dist%initialize(mu, sigma)
+
+        ! Compute the actual solution
+        inv = sigma
+        call mtx_inverse(inv)
+        arg = -0.5d0 * dot_product(x - mu, matmul(inv, x - mu))
+        dsig = det(sigma)
+        ans = exp(arg) / sqrt((2.0d0 * pi)**2 * dsig)
+
+        ! Test
+        phi = dist%pdf(x)
+        if (.not.is_equal(phi, ans, tol)) then
+            rst = .false.
+            print "(A)", "TEST FAILED: test_multivariate_normal_distribution -1"
+        end if
+    end function
+
 ! ------------------------------------------------------------------------------
 end module
\ No newline at end of file
diff --git a/tests/fstats_tests.f90 b/tests/fstats_tests.f90
index 6c2b346..d962f6e 100644
--- a/tests/fstats_tests.f90
+++ b/tests/fstats_tests.f90
@@ -30,6 +30,9 @@ program tests
     local = binomial_distribution_test_1()
     if (.not.local) overall = .false.
 
+    local = test_multivariate_normal_distribution()
+    if (.not.local) overall = .false.
+
     ! Statistics Tests
     local = mean_test_1()
     if (.not.local) overall = .false.