Skip to content

Commit

Permalink
clean observations
Browse files Browse the repository at this point in the history
  • Loading branch information
dufourc1 committed Dec 20, 2024
1 parent 994b942 commit 90cabac
Showing 1 changed file with 72 additions and 74 deletions.
146 changes: 72 additions & 74 deletions src/observations.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,13 @@
# remove all references to graphs, and only use sparse matrices ?
"""
Observations{G, D}
A struct to hold observations for a network. The type parameter `G` represents the network
structure and must support indexing and the `size` function.
# Fields
- `graph::G`: The network structure (e.g. adjacency matrix).
- `dist_ref::D`: distribution of the observations (used for getting support, type of elements, etc.)
"""
struct Observations{G, D}
graph::G
dist_ref::D
Expand All @@ -15,10 +24,6 @@ Get the number of nodes in the graph.
# Returns
- `num_nodes`: The number of nodes.
"""
function number_nodes(graph::Observations{AbstractGraph, D}) where {D}
return nv(graph.graph)
end

function number_nodes(graph::Observations)
return size(graph.graph, 1)
end
Expand All @@ -39,20 +44,22 @@ function get_obs(graph::Observations, x::Tuple)
return get_obs(graph, x[1], x[2])
end

function get_obs(graph::Observations, i::Int, j::Int)
return get_obs(graph.graph, i, j)
end
"""
get_obs(graph::Observations, i::Int, j::Int)
function get_obs(g::SimpleGraph, x::Tuple)
return get_obs(g, x[1], x[2])
end
Get the observation for the given pair of nodes.
function get_obs(g::SimpleGraph, i::Int, j::Int)
return convert(Bool, has_edge(g, i, j))
end
# Arguments
- `graph::Observations`: The graph observations.
- `i::Int`: The first node.
- `j::Int`: The second node.
get_obs(g::AbstractArray, x) = get_obs(g, x[1], x[2])
get_obs(g::AbstractArray, i, j) = g[i, j]
# Returns
- `obs`: The observation.
"""
function get_obs(graph::Observations, i::Int, j::Int)
return graph.graph[i, j]
end

"""
density(graph::Observations)
Expand All @@ -65,13 +72,8 @@ Get the density of the graph.
# Returns
- `density`: The density of the graph.
"""
density(graph::Observations) = density(graph.graph)
function density(g::AbstractGraph)
return Graphs.density(g)
end

function density(g::AbstractMatrix)
return sum(g) / ((size(g, 1) * (size(g, 1) - 1)))
function density(graph::Observations)
return sum(graph.graph) / ((size(graph.graph, 1) * (size(graph.graph, 1) - 1)))
end

"""
Expand All @@ -85,11 +87,7 @@ Get the degree of each node in the graph.
# Returns
- `degrees`: The degrees of the nodes.
"""
function get_degree(graph::Observations{AbstractGraph, D}) where {D}
Graphs.degree(graph.graph)
end

function get_degree(graph)
function get_degree(graph::Observations)
return sum(graph.graph, dims = 2)
end

Expand All @@ -104,14 +102,12 @@ Get the adjacency matrix of the graph.
# Returns
- `adj_matrix`: The adjacency matrix.
"""
function get_adj(graph::Observations{AbstractGraph, D}) where {D}
return Graphs.adjacency_matrix(graph.graph)
end

function get_adj(graph::Observations)
return graph.graph
end



function normalized_laplacian(graph::Observations)
return normalized_laplacian(graph.graph)
end
Expand All @@ -123,51 +119,43 @@ end

normalized_laplacian(g::CategoricalArray) = normalized_laplacian(levelcode.(g))

function normalized_laplacian(g::AbstractMatrix)
degrees = sum(g, dims = 1)
"""
normalized_laplacian(graph::Observations)
Get the normalized Laplacian of the graph.
# Arguments
- `graph::Observations`: The graph observations.
# Returns
- `L`: The normalized Laplacian matrix.
"""
function normalized_laplacian(graph::AbstractMatrix)
degrees = sum(graph, dims = 1)
degrees .-= minimum(degrees)
n = size(g, 1)
L = similar(g, Float64)
n = size(graph, 1)
L = similar(graph, Float64)
for j in 1:n
for i in 1:n
if i == j
L[i, j] = 1
elseif degrees[i] == 0 || degrees[j] == 0
L[i, j] = 0
elseif g[i, j] != 0
elseif graph[i, j] != 0
L[i, j] = -1 / sqrt(degrees[i] * degrees[j])
end
end
end
return L
end

function Metis.graph(graph::Observations{<:AbstractGraph, <:Bernoulli})
return Metis.graph(graph.graph)
end

function Metis.graph(g::Observations{<:AbstractMatrix, <:Bernoulli})
return Metis.graph(SimpleGraph(g.graph))
end

function Metis.graph(graph::Observations{<:AbstractGraph, <:UnivariateDistribution})
function Metis.graph(graph::Observations{G, <:UnivariateDistribution}) where {G}
use_weights = true
if minimum(graph.dist_ref) < 0
@warn "Negative values are not allowed for MetisStart, using binary graph"
return Metis.graph(graph.graph)
else
return Metis.graph(graph.graph, weights = true)
use_weights = false
end
end

function Metis.graph(g::Observations{<:AbstractMatrix, <:UnivariateDistribution})
return Metis.graph(
weights(SimpleWeightedGraph(g.graph)), weights = true)
end

function Metis.graph(g::Observations{<:CategoricalMatrix, <:UnivariateFinite})
A, _ = categorical_matrix(g)
return Metis.graph(
adjacency_matrix(SimpleWeightedGraph(A)), weights = true)
return Metis.graph(sparse(graph.graph), weights = use_weights)
end


Expand All @@ -185,21 +173,19 @@ Discretise the graph observations.
- `discretised_graph`: The discretised graph observations.
- `discretiser`: The discretiser used.
Assume that the diagonal is zero.
0 indicates no edge, while missing indicates no information about the edge.
By default maps 0 to 0. If you want another behaviour use the function where you
pass a `Discretizer` object.
number_levels will be the number of levels in the discretized distribution (excluding 0).
"""
function discretise(graph::Observations{G, D};
number_groups = nothing, number_levels = nothing) where {G, D}
function discretise(graph::Observations; number_groups = nothing, number_levels = nothing)
if isnothing(number_groups) && isnothing(number_levels)
throw(ArgumentError("Either `number_groups` or `number_levels` must be provided"))
end
if isnothing(number_levels)
number_levels = round(Int,get_num_levels_from_groups(number_nodes(graph), number_groups))
number_levels = round(Int, get_num_levels_from_groups(number_nodes(graph), number_groups))
else
if !isnothing(number_groups)
@warn "disregarding `number_groups` as `number_levels` is provided"
Expand All @@ -208,23 +194,35 @@ function discretise(graph::Observations{G, D};
return discretise(graph, DiscretizerZeroToZero(number_levels, extrema(graph.graph)...))
end

function discretise(graph::Observations{G, D}, discretiser ::Discretizer) where {G,D<:UnivariateDistribution}
A_encoded = encode(discretiser, _graph_to_mat(graph))
return Observations(A_encoded, DiscretizedDistribution(discretiser)), discretiser
end
"""
discretise(graph::Observations, discretiser::Discretizer)
Discretise the graph observations using the given discretiser.
function _graph_to_mat(graph::Observations{<:AbstractGraph, D}) where {D<:UnivariateDistribution}
return weights(graph.graph)
end
# Arguments
- `graph::Observations`: The graph observations.
- `discretiser::Discretizer`: The discretiser to use.
function _graph_to_mat(graph::Observations{<:AbstractMatrix, D}) where {D<:UnivariateDistribution}
return graph.graph
# Returns
- `discretised_graph`: The discretised graph observations.
- `discretiser`: The discretiser used.
"""
function discretise(graph::Observations, discretiser::Discretizer)
A_encoded = encode(discretiser, graph.graph)
return Observations(A_encoded, DiscretizedDistribution(discretiser)), discretiser
end


"""
get_num_levels_from_groups(n, number_groups)
Get the number of levels for the discretized distribution given n and k.
# Arguments
- `n`: The number of nodes.
- `number_groups`: The number of groups.
# Returns
- `num_levels`: The number of levels.
"""
function get_num_levels_from_groups(n, number_groups)
return max(1, n^(0.5 * (1 - log(number_groups) / log(n))))
Expand Down

0 comments on commit 90cabac

Please sign in to comment.