diff --git a/src/mutual_information.jl b/src/mutual_information.jl index d27cd67..afa5fa0 100644 --- a/src/mutual_information.jl +++ b/src/mutual_information.jl @@ -68,8 +68,8 @@ function bitinformation(A::AbstractArray{T}; masked_value::Union{T,Nothing}=nothing, kwargs...) where {T<:Union{Integer,AbstractFloat}} - # create a BitArray mask if a masked_value is provided - isnothing(masked_value) || return bitinformation(A,A .== masked_value;dim,kwargs...) + # create a BitArray mask if a masked_value is provided, use === to also allow NaN comparison + isnothing(masked_value) || return bitinformation(A,A .=== masked_value;dim,kwargs...) A = permute_dim_forward(A,dim) # Permute A to take adjacent entry in dimension dim n = size(A)[1] # n elements in dim diff --git a/test/information.jl b/test/information.jl index b78e6ee..cb0541c 100644 --- a/test/information.jl +++ b/test/information.jl @@ -206,5 +206,12 @@ end round!(A,1) mask = A .== masked_value @test bitinformation(A,mask) == bitinformation(A;masked_value) + + # check that masked_value=NaN also works + A[:,2] .= NaN # put some NaNs somewhere + mask = BitArray(undef,size(A)...) # create corresponding mask + fill!(mask,false) + mask[:,2] .= true + @test bitinformation(A,mask) == bitinformation(A;masked_value=convert(T,NaN)) end end \ No newline at end of file