
Commit

fix bug in generate function
LudensZhang committed Jun 3, 2024
1 parent 417f8ab commit 509432d
Showing 4 changed files with 6 additions and 5 deletions.
2 changes: 1 addition & 1 deletion mgm/CLI/CLI_utils.py
@@ -22,7 +22,7 @@ def find_pkg_resource(path):
raise FileNotFoundError('Resource {} not found, please check'.format(path))

def get_CLI_parser():
- modes = ['construct', 'map','pretrain', 'train', 'finetune', 'predict']
+ modes = ['construct', 'map','pretrain', 'train', 'finetune', 'predict', 'generate']
# noinspection PyTypeChecker
parser = argparse.ArgumentParser(
description=('MGM (Microbiao General Model) is a large-scaled pretrained language model for interpretable microbiome data analysis.\n'
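The added 'generate' entry makes the new generation subcommand an accepted mode. As a rough illustration (not the actual get_CLI_parser implementation; the positional argument name mode and the parser layout are assumptions), a mode list like this is typically exposed through argparse choices:

import argparse

# Simplified, hypothetical sketch of a mode-based CLI; the real parser in
# mgm/CLI/CLI_utils.py defines many more arguments per mode.
modes = ['construct', 'map', 'pretrain', 'train', 'finetune', 'predict', 'generate']

parser = argparse.ArgumentParser(description='MGM command-line interface (sketch)')
parser.add_argument('mode', choices=modes,
                    help="Subcommand to run; 'generate' is the newly accepted value.")

args = parser.parse_args(['generate'])  # e.g. invoked as: mgm generate ...
print(args.mode)                        # -> generate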
1 change: 1 addition & 0 deletions mgm/CLI/main_generate.py
@@ -35,6 +35,7 @@ def generate(cfg, args):
gen_sent = gen_num_sent(start,
model,
num_sent=args.num_samples,
+ tokenizer=extended_tokenizer,
bad_words=bad_words)

dump(gen_sent, open(args.output, "wb"))
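Forwarding tokenizer=extended_tokenizer matches the updated gen_num_sent signature in mgm/src/utils.py; the tokenizer supplies the pad_token_id used to pad every generated sequence to length 512 before the tensor is dumped. A minimal sketch of loading that output afterwards, assuming dump here comes from pickle (suggested by the custom Unpickler in mgm/src/utils.py, but an assumption nonetheless) and using a placeholder file name:

import pickle
import torch

# Hypothetical round trip: main_generate.py writes the padded token-id tensor
# with dump(gen_sent, open(args.output, "wb")); this reads it back.
with open("generated_samples.pkl", "rb") as f:  # placeholder for args.output
    gen_sent = pickle.load(f)

print(type(gen_sent))   # expected: a torch.Tensor of token ids
print(gen_sent.shape)   # expected: (n_prompts * num_samples, 512)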
6 changes: 3 additions & 3 deletions mgm/src/utils.py
@@ -44,7 +44,7 @@ def find_class(self, module, name):
return MicroTokenizer
return super().find_class(module, name)

- def generate(sent, model, do_sample=True, bad_words_ids=None, num_return_sequences=100):
+ def generate(sent, model, tokenizer, do_sample=True, bad_words_ids=None, num_return_sequences=100):
sent = sent.to(model.device)
gen_sent = model.generate(sent,
max_length=512,
@@ -56,8 +56,8 @@ def generate(sent, model, do_sample=True, bad_words_ids=None, num_return_sequences=100):
low_memory=True if num_return_sequences > 1 else False)
return gen_sent.cpu().detach()

- def gen_num_sent(start, model, num_sent, bad_words=None):
- gen_sent = [generate(sent, model, bad_words_ids=bad_words, num_return_sequences=num_sent) for sent in start]
+ def gen_num_sent(start, model, num_sent, tokenizer, bad_words=None):
+ gen_sent = [generate(sent, model, tokenizer, bad_words_ids=bad_words, num_return_sequences=num_sent) for sent in start]
gen_sent = [torch.cat([sent, torch.ones(num_sent, 512 - sent.shape[1], dtype=torch.long) * tokenizer.pad_token_id], dim=1) for sent in gen_sent]
gen_sent = torch.cat(gen_sent, dim=0)
return gen_sent
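This is the bug the commit title refers to: gen_num_sent pads each generated batch with tokenizer.pad_token_id, but tokenizer was not a parameter, so the old code would fail with a NameError at the padding step unless a global tokenizer happened to exist. The fix threads the tokenizer from the call site in main_generate.py through generate and gen_num_sent. A self-contained sketch of that padding step, with a stand-in tokenizer (DummyTokenizer and the concrete sizes are illustrative assumptions, not the project's classes):

import torch

class DummyTokenizer:
    """Stand-in for the extended MicroTokenizer; only pad_token_id is needed here."""
    pad_token_id = 0  # illustrative value

def pad_to_length(batch, tokenizer, length=512):
    # Mirrors the padding in gen_num_sent: right-pad each row of generated
    # token ids with pad_token_id until every row has `length` tokens.
    pad = torch.ones(batch.shape[0], length - batch.shape[1], dtype=torch.long) * tokenizer.pad_token_id
    return torch.cat([batch, pad], dim=1)

batch = torch.randint(1, 100, (4, 37))  # 4 generated sequences, 37 tokens each
padded = pad_to_length(batch, DummyTokenizer())
print(padded.shape)  # torch.Size([4, 512])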
2 changes: 1 addition & 1 deletion setup.py
@@ -13,7 +13,7 @@
if __name__ == "__main__":
setup(
name=NAME,
- version="0.4.0",
+ version="0.4.1",
author=AUTHOR,
author_email=EMAIL,
url=URL,
