Skip to content

Commit

Permalink
Add initial methods for NCES2 and ROCES
Browse files: browse the repository at this point in the history
  • Loading branch information
Jean-KOUAGOU committed Dec 6, 2024
1 parent d328569 commit c74de16
Show file tree
Hide file tree
Showing 7 changed files with 417 additions and 40 deletions.
2 changes: 1 addition & 1 deletion examples/train_nces.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def start(args):

nces = NCES(knowledge_base_path=knowledge_base_path, learner_names=args.models,
path_of_embeddings=path_of_embeddings, max_length=48, proj_dim=128, rnn_n_layers=2, drop_prob=0.1,
num_heads=4, num_seeds=1, num_inds=32, verbose=True, load_pretrained=args.load_pretrained)
num_heads=4, num_seeds=1, m=32, verbose=True, load_pretrained=args.load_pretrained)

nces.train(training_data, epochs=args.epochs, learning_rate=args.learning_rate, num_workers=2, save_model=True)

Expand Down
2 changes: 1 addition & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def get_default_arguments(description=None):
parser.add_argument("--drop_prob", type=float, default=0.1, help="Drop probability.")
parser.add_argument("--num_heads", type=int, default=4, help="Number of heads")
parser.add_argument("--num_seeds", type=int, default=1, help="Number of seeds (only for SetTransformer).")
parser.add_argument("--num_inds", type=int, default=32, help="Number of inducing points (only for SetTransformer).")
parser.add_argument("--m", type=int, default=32, help="Number of inducing points (only for SetTransformer).")
parser.add_argument("--ln", type=bool, default=False, help="Layer normalization (only for SetTransformer).")
parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate.")
parser.add_argument("--decay_rate", type=int, default=0, help="Decay rate.")
Expand Down
27 changes: 18 additions & 9 deletions ontolearn/base_nces.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,14 @@
import numpy as np
from torch.functional import F
from torch.nn.utils.rnn import pad_sequence
from .utils import read_csv
from abc import abstractmethod


class BaseNCES:

def __init__(self, knowledge_base_path, learner_names, path_of_embeddings, batch_size=256, learning_rate=1e-4,
decay_rate=0.0, clip_value=5.0, num_workers=4):
self.name = "NCES"
def __init__(self, knowledge_base_path, quality_func, num_predictions, proj_dim=128, drop_prob=0.1,
num_heads=4, num_seeds=1, m=32, ln=False, learning_rate=1e-4, decay_rate=0.0, clip_value=5.0,
batch_size=256, num_workers=4, max_length=48, load_pretrained=True, sorted_examples=False, verbose: int = 0):
kb = KnowledgeBase(path=knowledge_base_path)
self.kb_namespace = list(kb.ontology.classes_in_signature())[0].iri.get_namespace()
self.renderer = DLSyntaxObjectRenderer()
Expand All @@ -52,15 +51,25 @@ def __init__(self, knowledge_base_path, learner_names, path_of_embeddings, batch
self.all_individuals = set([ind.str.split("/")[-1] for ind in kb.individuals()])
self.inv_vocab = np.array(vocab, dtype='object')
self.vocab = {vocab[i]: i for i in range(len(vocab))}
self.learner_names = learner_names
self.num_examples = self.find_optimal_number_of_examples(kb)
self.batch_size = batch_size
self.num_predictions = num_predictions
self.proj_dim = proj_dim
self.drop_prob = drop_prob
self.num_heads = num_heads
self.num_seeds = num_seeds
self.m = m
self.ln = ln
self.learning_rate = learning_rate
self.decay_rate = decay_rate
self.clip_value = clip_value
self.batch_size = batch_size
self.num_workers = num_workers
self.instance_embeddings = read_csv(path_of_embeddings)
self.input_size = self.instance_embeddings.shape[1]
self.max_length = max_length
self.load_pretrained = load_pretrained
self.sorted_examples = sorted_examples
self.verbose = verbose
self.num_examples = self.find_optimal_number_of_examples(kb)
self.best_predictions = None


@staticmethod
def find_optimal_number_of_examples(kb):
Expand Down
6 changes: 3 additions & 3 deletions ontolearn/clip_architectures.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,13 +121,13 @@ def forward(self, x1, x2):

class LengthLearner_SetTransformer(nn.Module):
"""SetTransformer architecture."""
def __init__(self, input_size, output_size, proj_dim=256, num_heads=4, num_seeds=1, num_inds=32):
def __init__(self, input_size, output_size, proj_dim=256, num_heads=4, num_seeds=1, m=32):
super().__init__()
self.name = 'SetTransformer'
self.loss = nn.CrossEntropyLoss()
self.enc = nn.Sequential(
ISAB(input_size, proj_dim, num_heads, num_inds),
ISAB(proj_dim, proj_dim, num_heads, num_inds))
ISAB(input_size, proj_dim, num_heads, m),
ISAB(proj_dim, proj_dim, num_heads, m))
self.dec = nn.Sequential(
PMA(proj_dim, num_heads, num_seeds),
nn.Linear(proj_dim, output_size))
Expand Down
Loading

0 comments on commit c74de16

Please sign in to comment.