Skip to content

Commit

Permalink
add config methods to bert connective disambiguation model
Browse files Browse the repository at this point in the history
  • Loading branch information
rknaebel committed Jul 20, 2021
1 parent a4053e0 commit 3e8decd
Showing 1 changed file with 18 additions and 3 deletions.
21 changes: 18 additions & 3 deletions discopy/components/connective/bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,18 +70,33 @@ class ConnectiveClassifier(Component):
def __init__(self, input_dim, used_context: int = 0):
in_size = input_dim + 2 * used_context * input_dim
self.model = get_conn_model(in_size, 1, 1024)
self.input_dim = input_dim
self.used_context = used_context

def get_config(self):
return {
'model_name': self.model_name,
'input_dim': self.input_dim,
'used_context': self.used_context,
}

@staticmethod
def from_config(config: dict):
clf = ConnectiveClassifier(config['input_dim'], config['used_context'])
clf.sense_map = config['sense_map']
clf.classes = config['classes']
return clf

def load(self, path):
if not os.path.exists(os.path.join(path, f'connective_nn_{self.used_context}.model')):
if not os.path.exists(os.path.join(path, self.model_name)):
raise FileNotFoundError("Model not found.")
self.model = tf.keras.models.load_model(os.path.join(path, f'connective_nn_{self.used_context}.model'),
self.model = tf.keras.models.load_model(os.path.join(path, self.model_name),
compile=False)

def save(self, path):
if not os.path.exists(path):
os.makedirs(path)
self.model.save(os.path.join(path, f'connective_nn_{self.used_context}.model'))
self.model.save(os.path.join(path, self.model_name))

def fit(self, docs_train: List[Document], docs_val: List[Document] = None):
if docs_val is None:
Expand Down

0 comments on commit 3e8decd

Please sign in to comment.