diff --git a/dedupe/datamodel.py b/dedupe/datamodel.py index bb4335658..786988d97 100644 --- a/dedupe/datamodel.py +++ b/dedupe/datamodel.py @@ -8,6 +8,7 @@ import numpy import dedupe.variables +from dedupe.variables.base import CustomType from dedupe.variables.base import FieldType as FieldVariable from dedupe.variables.base import MissingDataType, Variable from dedupe.variables.interaction import InteractionType @@ -28,21 +29,23 @@ ) from dedupe.predicates import Predicate -VARIABLE_CLASSES = {k: v for k, v in FieldVariable.all_subclasses() if k} +VARIABLE_CLASSES = {k: v for k, v in Variable.all_subclasses() if k} class DataModel(object): version = 1 def __init__(self, variable_definitions: Iterable[VariableDefinition]): - variable_definitions = list(variable_definitions) - if not variable_definitions: - raise ValueError("The variable definitions cannot be empty") - all_variables: list[Variable] - self.primary_variables, all_variables = typify_variables(variable_definitions) - self._derived_start = len(all_variables) - - all_variables += interactions(variable_definitions, self.primary_variables) + variables = typify_variables(variable_definitions) + non_interactions: list[FieldVariable] = [ + v for v in variables if not isinstance(v, InteractionType) # type: ignore[misc] + ] + self.primary_variables = non_interactions + expanded_primary = _expand_higher_variables(self.primary_variables) + self._derived_start = len(expanded_primary) + + all_variables = expanded_primary.copy() + all_variables += _expanded_interactions(variables) all_variables += missing(all_variables) self._missing_field_indices = missing_field_indices(all_variables) @@ -50,9 +53,6 @@ def __init__(self, variable_definitions: Iterable[VariableDefinition]): self._len = len(all_variables) - def __len__(self) -> int: - return self._len - # Changing this from a property to just a normal attribute causes # pickling problems, because we are removing static methods from # their class context. 
This could be fixed by defining comparators @@ -82,7 +82,7 @@ def distances( ) -> numpy.typing.NDArray[numpy.float_]: num_records = len(record_pairs) - distances = numpy.empty((num_records, len(self)), "f4") + distances = numpy.empty((num_records, self._len), "f4") for i, (record_1, record_2) in enumerate(record_pairs): @@ -144,11 +144,12 @@ def __setstate__(self, d): def typify_variables( variable_definitions: Iterable[VariableDefinition], -) -> tuple[list[FieldVariable], list[Variable]]: - primary_variables: list[FieldVariable] = [] - all_variables: list[Variable] = [] - only_custom = True +) -> list[Variable]: + variable_definitions = list(variable_definitions) + if not variable_definitions: + raise ValueError("The variable definitions cannot be empty") + variables: list[Variable] = [] for definition in variable_definitions: try: variable_type = definition["type"] @@ -167,12 +168,6 @@ def typify_variables( "{'field' : 'Phone', type: 'String'}" ) - if variable_type != "Custom": - only_custom = False - - if variable_type == "Interaction": - continue - if variable_type == "FuzzyCategorical" and "other fields" not in definition: definition["other fields"] = [ # type: ignore d["field"] @@ -183,30 +178,35 @@ def typify_variables( try: variable_class = VARIABLE_CLASSES[variable_type] except KeyError: + valid = ", ".join(VARIABLE_CLASSES) raise KeyError( - "Field type %s not valid. Valid types include %s" - % (definition["type"], ", ".join(VARIABLE_CLASSES)) + f"Variable type {variable_type} not valid. 
Valid types include {valid}" ) - variable_object = variable_class(definition) - assert isinstance(variable_object, FieldVariable) - - primary_variables.append(variable_object) + assert isinstance(variable_object, Variable) + variables.append(variable_object) - if hasattr(variable_object, "higher_vars"): - all_variables.extend(variable_object.higher_vars) - else: - variable_object = cast(Variable, variable_object) - all_variables.append(variable_object) - - if only_custom: + no_blocking_variables = all( + isinstance(v, (CustomType, InteractionType)) for v in variables + ) + if no_blocking_variables: raise ValueError( - "At least one of the variable types needs to be a type" - "other than 'Custom'. 'Custom' types have no associated" - "blocking rules" + "At least one of the variable types needs to be a type " + "other than 'Custom' or 'Interaction', " + "since these types have no associated blocking rules." ) - return primary_variables, all_variables + return variables + + +def _expand_higher_variables(variables: Iterable[Variable]) -> list[Variable]: + result: list[Variable] = [] + for variable in variables: + if hasattr(variable, "higher_vars"): + result.extend(variable.higher_vars) + else: + result.append(variable) + return result def missing(variables: list[Variable]) -> list[MissingDataType]: @@ -217,16 +217,12 @@ def missing(variables: list[Variable]) -> list[MissingDataType]: return missing_variables -def interactions( - definitions: Iterable[VariableDefinition], primary_variables: list[FieldVariable] -) -> list[InteractionType]: - field_d = {field.name: field for field in primary_variables} - +def _expanded_interactions(variables: list[Variable]) -> list[InteractionType]: + field_vars = {var.name: var for var in variables if isinstance(var, FieldVariable)} interactions = [] - for definition in definitions: - if definition["type"] == "Interaction": - var = InteractionType(definition) - var.expandInteractions(field_d) + for var in variables: + if 
isinstance(var, InteractionType): + var.expandInteractions(field_vars) interactions.extend(var.higher_vars) return interactions @@ -236,15 +232,27 @@ def missing_field_indices(variables: list[Variable]) -> list[int]: def interaction_indices(variables: list[Variable]) -> list[list[int]]: - var_names = [var.name for var in variables] + _ensure_unique_names(variables) + name_to_index = {var.name: i for i, var in enumerate(variables)} indices = [] for var in variables: if hasattr(var, "interaction_fields"): - interaction_indices = [var_names.index(f) for f in var.interaction_fields] # type: ignore + interaction_indices = [name_to_index[f] for f in var.interaction_fields] # type: ignore indices.append(interaction_indices) return indices +def _ensure_unique_names(variables: Iterable[Variable]) -> None: + seen = set() + for var in variables: + if var.name in seen: + raise ValueError( + "Variable name used more than once! " + f"Choose a unique name for each variable: '{var.name}'" + ) + seen.add(var.name) + + def reduce_method(m): # type: ignore[no-untyped-def] return (getattr, (m.__self__, m.__func__.__name__)) diff --git a/dedupe/variables/base.py b/dedupe/variables/base.py index f80b28faa..c958db5ea 100644 --- a/dedupe/variables/base.py +++ b/dedupe/variables/base.py @@ -58,21 +58,16 @@ def all_subclasses( class DerivedType(Variable): - type = "Derived" - def __init__(self, definition: VariableDefinition): self.name = "(%s: %s)" % (str(definition["name"]), str(definition["type"])) super(DerivedType, self).__init__(definition) class MissingDataType(Variable): - type = "MissingData" + has_missing = False def __init__(self, name: str): - - self.name = "(%s: Not Missing)" % name - - self.has_missing = False + self.name = f"({name}: Not Missing)" class FieldType(Variable): diff --git a/tests/test_api.py b/tests/test_api.py index 84ac9169a..6f9f40dd5 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -58,11 +58,13 @@ def test_initialize_fields(self): [], ) + # only
customs with self.assertRaises(ValueError): dedupe.api.ActiveMatching( [{"field": "name", "type": "Custom", "comparator": lambda x, y: 1}], ) + # Only customs with self.assertRaises(ValueError): dedupe.api.ActiveMatching( [ @@ -71,6 +73,44 @@ def test_initialize_fields(self): ], ) + # Only custom and interactions + with self.assertRaises(ValueError): + dedupe.api.ActiveMatching( + [ + {"field": "name", "type": "Custom", "comparator": lambda x, y: 1}, + {"field": "age", "type": "Custom", "comparator": lambda x, y: 1}, + {"type": "Interaction", "interaction variables": ["name", "age"]}, + ], + ) + + # Only interactions + with self.assertRaises(ValueError): + dedupe.api.ActiveMatching( + [ + {"type": "Interaction", "interaction variables": []}, + ], + ) + + # Duplicate variable names (explicitly) + with self.assertRaises(ValueError) as e: + dedupe.api.ActiveMatching( + [ + {"field": "age", "type": "String", "variable name": "my_age"}, + {"field": "age", "type": "ShortString", "variable name": "my_age"}, + ], + ) + assert "Variable name used more than once!" in str(e.exception) + + # Duplicate variable names (implicitly) + with self.assertRaises(ValueError) as e: + dedupe.api.ActiveMatching( + [ + {"field": "age", "type": "String"}, + {"field": "age", "type": "String"}, + ], + ) + assert "Variable name used more than once!" in str(e.exception) + dedupe.api.ActiveMatching( [ {"field": "name", "type": "Custom", "comparator": lambda x, y: 1}, @@ -78,6 +118,19 @@ def test_initialize_fields(self): ], ) + dedupe.api.ActiveMatching( + [ + {"field": "name", "variable name": "name", "type": "String"}, + { + "field": "age", + "variable name": "age", + "type": "Custom", + "comparator": lambda x, y: 1, + }, + {"type": "Interaction", "interaction variables": ["name", "age"]}, + ], + ) + def test_check_record(self): matcher = dedupe.api.ActiveMatching(self.field_definition)