Commit 229b08f: Refactoring after discussions

yger committed Feb 20, 2025
1 parent c83570e commit 229b08f

Showing 3 changed files with 115 additions and 65 deletions.
49 changes: 49 additions & 0 deletions src/spikeinterface/core/sortinganalyzer.py
@@ -812,6 +812,55 @@ def get_sorting_property(self, key: str, ids: Optional[Iterable] = None) -> np.n
"""
return self.sorting.get_property(key, ids=ids)

def can_perform_merges(
    self,
    merge_unit_groups,
    new_unit_ids=None,
    merging_mode="soft",
    sparsity_overlap=0.75,
    new_id_strategy="append",
    **kwargs,
):
    """
    Check whether the given merges can be performed with the given merging parameters.

    Parameters
    ----------
    merge_unit_groups : list/tuple of lists/tuples
        A list of lists for every merge group. Each element needs to have at least two elements
        (two units to merge).
    new_unit_ids : list | None, default: None
        The unit ids of the merged units. If None, they are generated according to `new_id_strategy`.
    merging_mode : "soft" | "hard", default: "soft"
        The merging mode. "hard" merges are always possible, while "soft" merges require the sparsity
        masks of the units to merge to overlap sufficiently.
    sparsity_overlap : float, default: 0.75
        The fraction of overlap that units should share in order to accept soft merges.
    new_id_strategy : str, default: "append"
        The strategy used to generate new unit ids when `new_unit_ids` is None
        (see `generate_unit_ids_for_merge_group`).
    **kwargs
        Ignored. Accepted so that a full set of apply-merge keyword arguments can be passed directly.

    Returns
    -------
    can_perform_merges : bool
        True if the requested merges can be performed.
    """
assert merging_mode in ["hard", "soft"], "merging_mode must be 'hard' or 'soft'"

new_unit_ids = generate_unit_ids_for_merge_group(
self.unit_ids, merge_unit_groups, new_unit_ids, new_id_strategy
)
all_unit_ids = _get_ids_after_merging(self.unit_ids, merge_unit_groups, new_unit_ids=new_unit_ids)

if merging_mode == "hard":
return True
elif merging_mode == "soft":
if self.sparsity is None:
return True
else:
            for unit_id in all_unit_ids:
if unit_id in new_unit_ids:
                    # This is a new unit: its sparsity mask will be the intersection
                    # of the masks of all the units being merged
current_merge_group = merge_unit_groups[list(new_unit_ids).index(unit_id)]
merge_unit_indices = self.sorting.ids_to_indices(current_merge_group)
union_mask = np.sum(self.sparsity.mask[merge_unit_indices], axis=0) > 0
intersection_mask = np.prod(self.sparsity.mask[merge_unit_indices], axis=0) > 0
                    overlap = np.sum(intersection_mask) / np.sum(union_mask)
                    if overlap < sparsity_overlap:
return False
return True

def _save_or_select_or_merge(
self,
format="binary_folder",
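To make the acceptance rule above concrete, here is a minimal numpy-only sketch of the soft-merge criterion: a merge group is accepted only when the intersection of the units' sparsity masks covers at least `sparsity_overlap` of their union (the masks and threshold are made up for illustration).

```python
import numpy as np

sparsity_overlap = 0.75

# Hypothetical boolean sparsity masks (units x channels): True where a unit has signal
masks = np.array(
    [
        [True, True, True, False],   # unit A
        [True, True, False, False],  # unit B
    ]
)

union_mask = np.sum(masks, axis=0) > 0          # channels covered by any unit (3 here)
intersection_mask = np.prod(masks, axis=0) > 0  # channels covered by every unit (2 here)
overlap = np.sum(intersection_mask) / np.sum(union_mask)

print(overlap)                      # 0.666..., below the 0.75 threshold
print(overlap >= sparsity_overlap)  # False -> this soft merge would be rejected
```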
122 changes: 66 additions & 56 deletions src/spikeinterface/curation/auto_merge.py
@@ -415,77 +415,74 @@ def _auto_merge_units_single_iteration(
    sorting_analyzer: SortingAnalyzer,
    compute_merge_kwargs: dict = {},
    apply_merge_kwargs: dict = {},
    extra_outputs: bool = False,
    force_copy: bool = True,
    merging_strategy: str = "warning",
    **job_kwargs,
) -> SortingAnalyzer:
"""
Compute merge unit groups and apply it on a SortingAnalyzer. Used by auto_merge_units, see documentation there.
Internally uses `compute_merge_unit_groups()`
Parameters
----------
sorting_analyzer : SortingAnalyzer
The SortingAnalyzer to be merged
compute_merge_kwargs : dict
The parameters to be passed to compute_merge_unit_groups()
apply_merge_kwargs : dict
The parameters to be passed to merge_units()
extra_outputs : bool, default: False
If True, additional outputs are returned
force_copy : bool, default: True
If True, the sorting_analyzer is copied before applying the merges
merging_strategy : str, default: "warning"
The strategy to use when merging units. Can be "warning" or "strict". If "warning", the merging will
stop with a warning if soft merges can not be done. Otherwise, if "strict", an error is thrown
Returns
-------
sorting_analyzer:
The new sorting analyzer where all the merges from all the presets have been applied
merges, outs:
resolved_merges, merge_unit_groups, outs:
Returned only when extra_outputs=True
A list with the merges performed, and dictionaries that contains data for debugging and plotting.
Note that if recursive, then you are receiving list of lists (for all merges and outs at every step)
"""

if force_copy:
        # Copy so that the user's existing extensions are not erased
sorting_analyzer = sorting_analyzer.copy()

    merge_unit_groups = compute_merge_unit_groups(
        sorting_analyzer, **compute_merge_kwargs, extra_outputs=extra_outputs, force_copy=False, **job_kwargs
    )

    if extra_outputs:
        merge_unit_groups, outs = merge_unit_groups

    merged_units = len(merge_unit_groups) > 0
    if merged_units:
        can_do_merges = sorting_analyzer.can_perform_merges(merge_unit_groups, **apply_merge_kwargs)
        if can_do_merges:
            merged_analyzer, new_unit_ids = sorting_analyzer.merge_units(
                merge_unit_groups, return_new_unit_ids=True, **apply_merge_kwargs, **job_kwargs
            )
        else:
            merged_analyzer = sorting_analyzer
            new_unit_ids = []

            if merging_strategy == "warning":
                warnings.warn("Some merges cannot be performed. Merging is stopped", UserWarning)
                if extra_outputs:
                    return merged_analyzer, {}, merge_unit_groups, outs
                else:
                    return merged_analyzer
            elif merging_strategy == "strict":
                raise ValueError("Some merges cannot be performed. Merging is stopped")
    else:
        merged_analyzer = sorting_analyzer
        new_unit_ids = []

    resolved_merges = {key: value for (key, value) in zip(new_unit_ids, merge_unit_groups)}

if extra_outputs:
return merged_analyzer, resolved_merges, merge_unit_groups, outs
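For clarity, a minimal sketch of how the two strategies surface to a caller of this internal helper. `analyzer` is a placeholder SortingAnalyzer, and the step names and apply-merge kwargs are assumptions based on `compute_merge_unit_groups()` and `merge_units()`, not values taken from this diff.

```python
# merging_strategy="warning": emits a UserWarning and returns the analyzer unmerged.
# merging_strategy="strict": raises ValueError in the same situation.
try:
    merged = _auto_merge_units_single_iteration(
        analyzer,  # placeholder SortingAnalyzer
        compute_merge_kwargs={"steps": ["correlogram", "template_similarity"]},
        apply_merge_kwargs={"merging_mode": "soft", "sparsity_overlap": 0.9},
        merging_strategy="strict",
    )
except ValueError:
    # a strict run refuses to continue when any soft merge fails the sparsity-overlap check
    merged = analyzer
```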
@@ -688,6 +685,7 @@ def auto_merge_units(
steps: list[str] | None = None,
apply_merge_kwargs: dict = {},
recursive: bool = False,
merging_strategy: str = "warning",
extra_outputs: bool = False,
force_copy: bool = True,
**job_kwargs,
@@ -715,6 +713,9 @@
    recursive : bool, default: False
        If True, each preset of the list is applied until no further merges can be done, before trying
        the next one
    merging_strategy : str, default: "warning"
        The strategy to use when soft merges cannot be performed. If "warning", merging stops
        with a warning; if "strict", an error is raised
    extra_outputs : bool, default: False
        If True, an additional list of the merges applied at every preset and a dictionary (`outs`) with
        processed data are returned.
    force_copy : bool, default: True
@@ -772,22 +773,31 @@
compute_merge_kwargs = {"steps": to_launch}

compute_merge_kwargs.update({"steps_params": params})
    found_merges = True
while found_merges:

num_units = len(sorting_analyzer.unit_ids)
sorting_analyzer = _auto_merge_units_single_iteration(
sorting_analyzer,
compute_merge_kwargs,
apply_merge_kwargs=apply_merge_kwargs,
extra_outputs=extra_outputs,
merging_strategy=merging_strategy,
force_copy=False,
**job_kwargs,
)

if extra_outputs:
sorting_analyzer, new_merges, merge_unit_groups, outs = sorting_analyzer
all_merging_groups += [merge_unit_groups]
resolved_merges = resolve_pairs(resolved_merges, new_merges)
all_outs += [outs]

if not recursive:
found_merges = False
else:
found_merges = len(sorting_analyzer.unit_ids) < num_units

if extra_outputs:
return sorting_analyzer, resolved_merges, merge_unit_groups, all_outs
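A hedged end-to-end sketch of the public entry point after this refactor. The preset name, the list of required extensions, the generator helpers, and the `spikeinterface.curation` import path are assumptions based on the SpikeInterface API rather than values taken from this diff.

```python
import spikeinterface.full as si
from spikeinterface.curation import auto_merge_units  # assumed export path

# Toy ground-truth data so the sketch is self-contained
recording, sorting = si.generate_ground_truth_recording(durations=[60.0], num_units=10, seed=42)
analyzer = si.create_sorting_analyzer(sorting, recording, sparse=True)
# The merge presets need a few extensions to be computed beforehand
analyzer.compute(["random_spikes", "templates", "noise_levels", "correlograms", "template_similarity"])

merged_analyzer, resolved_merges, merge_unit_groups, outs = auto_merge_units(
    analyzer,
    presets=["similarity_correlograms"],  # assumed preset name
    recursive=True,              # repeat each preset until no further merges are found
    merging_strategy="warning",  # warn and stop instead of raising when soft merges fail
    extra_outputs=True,
)
print(resolved_merges)  # {new_unit_id: [original unit ids], ...}
```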
9 changes: 0 additions & 9 deletions src/spikeinterface/generation/drifting_generator.py
@@ -52,15 +52,6 @@
contact_shapes="circle",
contact_shape_params={"radius": 8},
),
"MicroMacro": dict(
num_columns=1,
num_contact_per_column=[8],
xpitch=300,
ypitch=300,
y_shift_per_column=[0],
contact_shapes="circle",
contact_shape_params={"radius": 8},
),
}
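Anyone relying on the removed "MicroMacro" entry can still build an equivalent layout directly with probeinterface. A hedged sketch reusing the deleted parameters, assuming `generate_multi_columns_probe` accepts these keyword arguments as in current probeinterface:

```python
from probeinterface import generate_multi_columns_probe

# Rebuild the deleted "MicroMacro" layout from its parameters
micro_macro_probe = generate_multi_columns_probe(
    num_columns=1,
    num_contact_per_column=[8],
    xpitch=300,
    ypitch=300,
    y_shift_per_column=[0],
    contact_shapes="circle",
    contact_shape_params={"radius": 8},
)
print(micro_macro_probe)
```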


