Skip to content

Commit

Permalink
Refactor blocking rule initialization in EMTrainingSession and Linker…
Browse files Browse the repository at this point in the history
… classes
  • Loading branch information
RobinL committed Jan 10, 2024
1 parent cc3ac68 commit 5c07616
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
13 changes: 7 additions & 6 deletions splink/em_training_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class EMTrainingSession:
def __init__(
self,
linker: Linker,
blocking_rule_for_training: str,
blocking_rule_for_training: BlockingRule,
fix_u_probabilities: bool = False,
fix_m_probabilities: bool = False,
fix_probability_two_random_records_match: bool = False,
Expand All @@ -54,10 +54,10 @@ def __init__(
self._settings_obj._training_mode = True

if not isinstance(blocking_rule_for_training, BlockingRule):
blocking_rule = BlockingRule(blocking_rule_for_training)
blocking_rule_for_training = BlockingRule(blocking_rule_for_training)

self._settings_obj._blocking_rule_for_training = blocking_rule
self._blocking_rule_for_training = blocking_rule
self._settings_obj._blocking_rule_for_training = blocking_rule_for_training
self._blocking_rule_for_training = blocking_rule_for_training
self._settings_obj._estimate_without_term_frequencies = (
estimate_without_term_frequencies
)
Expand All @@ -68,7 +68,7 @@ def __init__(
)
else:
self._comparison_levels_to_reverse_blocking_rule = self._original_settings_obj._get_comparison_levels_corresponding_to_training_blocking_rule( # noqa
blocking_rule_for_training
blocking_rule_for_training.blocking_rule_sql
)

self._settings_obj._probability_two_random_records_match = (
Expand All @@ -87,7 +87,8 @@ def __init__(
if not comparisons_to_deactivate:
comparisons_to_deactivate = []
br_cols = get_columns_used_from_sql(
blocking_rule_for_training, self._settings_obj._sql_dialect
blocking_rule_for_training.blocking_rule_sql,
self._settings_obj._sql_dialect,
)
for cc in self._settings_obj.comparisons:
cc_cols = cc._input_columns_used_by_case_statement
Expand Down
3 changes: 1 addition & 2 deletions splink/linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1651,8 +1651,7 @@ def estimate_parameters_using_expectation_maximisation(
# to be used by the training linkers
self._initialise_df_concat_with_tf()

# Extract the blocking rule
blocking_rule = blocking_rule_to_obj(blocking_rule).blocking_rule_sql
blocking_rule = blocking_rule_to_obj(blocking_rule)

if comparisons_to_deactivate:
# If user provided a string, convert to Comparison object
Expand Down

0 comments on commit 5c07616

Please sign in to comment.