diff --git a/discopy/components/connective/bert.py b/discopy/components/connective/bert.py index 5ef2914..cd72fc9 100644 --- a/discopy/components/connective/bert.py +++ b/discopy/components/connective/bert.py @@ -70,18 +70,33 @@ class ConnectiveClassifier(Component): def __init__(self, input_dim, used_context: int = 0): in_size = input_dim + 2 * used_context * input_dim self.model = get_conn_model(in_size, 1, 1024) + self.input_dim = input_dim self.used_context = used_context + def get_config(self): + return { + 'model_name': self.model_name, + 'input_dim': self.input_dim, + 'used_context': self.used_context, + } + + @staticmethod + def from_config(config: dict): + clf = ConnectiveClassifier(config['input_dim'], config['used_context']) + clf.sense_map = config['sense_map'] + clf.classes = config['classes'] + return clf + def load(self, path): - if not os.path.exists(os.path.join(path, f'connective_nn_{self.used_context}.model')): + if not os.path.exists(os.path.join(path, self.model_name)): raise FileNotFoundError("Model not found.") - self.model = tf.keras.models.load_model(os.path.join(path, f'connective_nn_{self.used_context}.model'), + self.model = tf.keras.models.load_model(os.path.join(path, self.model_name), compile=False) def save(self, path): if not os.path.exists(path): os.makedirs(path) - self.model.save(os.path.join(path, f'connective_nn_{self.used_context}.model')) + self.model.save(os.path.join(path, self.model_name)) def fit(self, docs_train: List[Document], docs_val: List[Document] = None): if docs_val is None: