Skip to content

Commit

Permalink
Improve iter_possible_adjacencies (#175)
Browse files Browse the repository at this point in the history
  • Loading branch information
lgoettgens authored Dec 5, 2024
1 parent 87dbf8c commit 8aec992
Showing 1 changed file with 79 additions and 45 deletions.
124 changes: 79 additions & 45 deletions src/ArcDiagram.jl
Original file line number Diff line number Diff line change
Expand Up @@ -631,14 +631,27 @@ function all_arc_diagrams(
end
end
if isodd(n_upper_verts + n_lower_verts)
return ArcDiagramIterator{Undirected}([], 0)
return ArcDiagramIterator{Undirected}(ArcDiagramUndirected[], 0)
end
forbidden_neighbors = Dict{Int, Vector{Int}}()
for i in 1:n_upper_verts
i = -i
forbidden_neighbors[i] = Vector{Int}()
end
for i in 1:n_lower_verts
forbidden_neighbors[i] = Vector{Int}()
end
for is in indep_sets
for i in is
union!(forbidden_neighbors[i], is)
end
end
iter, len = iter_possible_adjacencies_undir(
n_upper_verts,
n_lower_verts,
[0 for _ in 1:n_upper_verts],
[0 for _ in 1:n_lower_verts],
indep_sets,
forbidden_neighbors,
)
return ArcDiagramIterator{Undirected}(iter, len)
end
Expand All @@ -648,36 +661,46 @@ function iter_possible_adjacencies_undir(
n_lower_verts::Int,
partial_upper::Vector{Int},
partial_lower::Vector{Int},
indep_sets::AbstractVector{<:AbstractVector{Int}},
forbidden_neighbors::Dict{Int, Vector{Int}},
)
i = findfirst(==(0), partial_upper)
i = findfirst(iszero, partial_upper)
if !isnothing(i)
i = -i
relevant_indep_sets = filter(is -> i in is, indep_sets)
poss_upper_adjs = setdiff(setdiff(map(j -> -j, findall(==(0), partial_upper)), i), relevant_indep_sets...)
poss_lower_adjs = setdiff(findall(==(0), partial_lower), relevant_indep_sets...)
choices = Iterators.map([poss_upper_adjs; poss_lower_adjs]) do j
partial_upper2 = deepcopy(partial_upper)
partial_lower2 = deepcopy(partial_lower)
poss_upper_adjs = (-j for j in findall(iszero, partial_upper) if i != -j && !(-j in forbidden_neighbors[i]))
poss_lower_adjs = (j for j in findall(iszero, partial_lower) if !(j in forbidden_neighbors[i]))
choices = Iterators.map(Iterators.flatten([poss_upper_adjs, poss_lower_adjs])) do j
partial_upper2 = copy(partial_upper)
partial_lower2 = copy(partial_lower)
partial_upper2[-i] = j
if j < 0
partial_upper2[-j] = i
else
partial_lower2[j] = i
end
iter_possible_adjacencies_undir(n_upper_verts, n_lower_verts, partial_upper2, partial_lower2, indep_sets)
iter_possible_adjacencies_undir(
n_upper_verts,
n_lower_verts,
partial_upper2,
partial_lower2,
forbidden_neighbors,
)
end
return Iterators.flatten(Iterators.map(first, choices)), sum(last, choices; init=0)
else
i = findfirst(==(0), partial_lower)
i = findfirst(iszero, partial_lower)
if !isnothing(i)
relevant_indep_sets = filter(is -> i in is, indep_sets)
poss_lower_adjs = setdiff(setdiff(findall(==(0), partial_lower), i), relevant_indep_sets...)
poss_lower_adjs = (j for j in findall(iszero, partial_lower) if i != j && !(j in forbidden_neighbors[i]))
choices = Iterators.map(poss_lower_adjs) do j
partial_lower2 = deepcopy(partial_lower)
partial_lower2 = copy(partial_lower)
partial_lower2[i] = j
partial_lower2[j] = i
iter_possible_adjacencies_undir(n_upper_verts, n_lower_verts, partial_upper, partial_lower2, indep_sets)
iter_possible_adjacencies_undir(
n_upper_verts,
n_lower_verts,
partial_upper,
partial_lower2,
forbidden_neighbors,
)
end
return Iterators.flatten(Iterators.map(first, choices)), sum(last, choices; init=0)
else
Expand All @@ -704,7 +727,7 @@ function all_arc_diagrams(
end
end
if isodd(n_upper_verts + n_lower_verts)
return ArcDiagramIterator{Directed}([], 0)
return ArcDiagramIterator{Directed}(ArcDiagramDirected[], 0)
end
rets = if n_upper_verts == 0
[all_arc_diagrams(Directed, BitVector([]), n_lower_verts; indep_sets, check=false)]
Expand Down Expand Up @@ -734,10 +757,10 @@ function all_arc_diagrams(
end
end
if isodd(n_upper_verts + n_lower_verts)
return ArcDiagramIterator{Directed}([], 0)
return ArcDiagramIterator{Directed}(ArcDiagramDirected[], 0)
end
if abs(parity_diff(parity_upper_verts)) > n_lower_verts
return ArcDiagramIterator{Directed}([], 0)
return ArcDiagramIterator{Directed}(ArcDiagramDirected[], 0)
end
rets = if n_lower_verts == 0
[all_arc_diagrams(Directed, parity_upper_verts, BitVector([]); indep_sets, check=false)]
Expand Down Expand Up @@ -770,10 +793,23 @@ function all_arc_diagrams(
end
end
if isodd(n_upper_verts + n_lower_verts)
return ArcDiagramIterator{Directed}([], 0)
return ArcDiagramIterator{Directed}(ArcDiagramDirected[], 0)
end
if parity_diff(parity_upper_verts) != parity_diff(parity_lower_verts)
return ArcDiagramIterator{Directed}([], 0)
return ArcDiagramIterator{Directed}(ArcDiagramDirected[], 0)
end
forbidden_neighbors = Dict{Int, Vector{Int}}()
for i in 1:n_upper_verts
i = -i
forbidden_neighbors[i] = Vector{Int}()
end
for i in 1:n_lower_verts
forbidden_neighbors[i] = Vector{Int}()
end
for is in indep_sets
for i in is
union!(forbidden_neighbors[i], is)
end
end
iter, len = iter_possible_adjacencies_dir(
n_upper_verts,
Expand All @@ -782,7 +818,7 @@ function all_arc_diagrams(
parity_lower_verts,
[0 for _ in 1:n_upper_verts],
[0 for _ in 1:n_lower_verts],
indep_sets,
forbidden_neighbors,
)
return ArcDiagramIterator{Directed}(iter, len)
end
Expand All @@ -794,23 +830,22 @@ function iter_possible_adjacencies_dir(
parity_lower_verts::BitVector,
partial_upper::Vector{Int},
partial_lower::Vector{Int},
indep_sets::AbstractVector{<:AbstractVector{Int}},
forbidden_neighbors::Dict{Int, Vector{Int}},
)
i = findfirst(==(0), partial_upper)
i = findfirst(iszero, partial_upper)
if !isnothing(i)
i = -i
relevant_indep_sets = filter(is -> i in is, indep_sets)
poss_upper_adjs = [
j for j in setdiff(setdiff(map(j -> -j, findall(==(0), partial_upper)), i), relevant_indep_sets...) if
parity_upper_verts[-i] != parity_upper_verts[-j]
]
poss_lower_adjs = [
j for j in setdiff(findall(==(0), partial_lower), relevant_indep_sets...) if
parity_upper_verts[-i] == parity_lower_verts[j]
]
choices = Iterators.map([poss_upper_adjs; poss_lower_adjs]) do j
partial_upper2 = deepcopy(partial_upper)
partial_lower2 = deepcopy(partial_lower)
poss_upper_adjs = (
-j for j in findall(iszero, partial_upper) if
i != -j && !(-j in forbidden_neighbors[i]) && parity_upper_verts[-i] != parity_upper_verts[j]
)
poss_lower_adjs = (
j for j in findall(iszero, partial_lower) if
!(j in forbidden_neighbors[i]) && parity_upper_verts[-i] == parity_lower_verts[j]
)
choices = Iterators.map(Iterators.flatten([poss_upper_adjs, poss_lower_adjs])) do j
partial_upper2 = copy(partial_upper)
partial_lower2 = copy(partial_lower)
partial_upper2[-i] = j
if j < 0
partial_upper2[-j] = i
Expand All @@ -824,20 +859,19 @@ function iter_possible_adjacencies_dir(
parity_lower_verts,
partial_upper2,
partial_lower2,
indep_sets,
forbidden_neighbors,
)
end
return Iterators.flatten(Iterators.map(first, choices)), sum(last, choices; init=0)
else
i = findfirst(==(0), partial_lower)
i = findfirst(iszero, partial_lower)
if !isnothing(i)
relevant_indep_sets = filter(is -> i in is, indep_sets)
poss_lower_adjs = [
j for j in setdiff(setdiff(findall(==(0), partial_lower), i), relevant_indep_sets...) if
parity_lower_verts[i] != parity_lower_verts[j]
]
poss_lower_adjs = (
j for j in findall(iszero, partial_lower) if
i != j && !(j in forbidden_neighbors[i]) && parity_lower_verts[i] != parity_lower_verts[j]
)
choices = Iterators.map(poss_lower_adjs) do j
partial_lower2 = deepcopy(partial_lower)
partial_lower2 = copy(partial_lower)
partial_lower2[i] = j
partial_lower2[j] = i
iter_possible_adjacencies_dir(
Expand All @@ -847,7 +881,7 @@ function iter_possible_adjacencies_dir(
parity_lower_verts,
partial_upper,
partial_lower2,
indep_sets,
forbidden_neighbors,
)
end
return Iterators.flatten(Iterators.map(first, choices)), sum(last, choices; init=0)
Expand Down

0 comments on commit 8aec992

Please sign in to comment.