Skip to content

Commit

Permalink
refactor: Update get_Z function to include label argument
Browse files Browse the repository at this point in the history
  • Loading branch information
LudensZhang committed Jul 8, 2024
1 parent 7bb4d52 commit c125a05
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 7 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ Generates synthetic microbiome data using the pretrained MGM model. A prompt fil
**Example:**

```bash
mgm generate -m infant_model_clf -p infant_data/prompt.txt -n 100 -o infant_synthetic.pkl
mgm generate -m infant_model_gen -p infant_data/prompt.txt -n 100 -o infant_synthetic.pkl
```

#### `reconstruct`
Expand All @@ -114,8 +114,8 @@ Reconstruct abundance from ranked corpus.
**Output:** Reconstructed corpus ; Reconstructor model; Decoded label

```bash
mgm reconstruct -a infant_data/abundance.csv -i infant_generate.pkl -g infant_model_generate -w True -o reconstructor_file
mgm reconstruct -r reconstructor_file/reconstructor_model.ckpt -i infant_generate.pkl -g infant_model_generate -w True -o reconstructor_file
mgm reconstruct -a infant_data/abundance.csv -i infant_synthetic.pkl -g infant_model_generate -w True -o reconstructor_file
mgm reconstruct -r reconstructor_file/reconstructor_model.ckpt -i infant_synthetic.pkl -g infant_model_generate -w True -o reconstructor_file
```

For detailed usage of each mode, refer to the help message:
Expand Down
2 changes: 1 addition & 1 deletion mgm/CLI/main_reconstruct.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def reconstruct(cfg, args):
P_train = abundance
P_train = torch.tensor(P_train.values)
Z_train = corpus.data
Z_train = get_Z(Z_train, position_encodings, vocab_size)
Z_train = get_Z(Z_train, position_encodings, vocab_size, label=args.withLabel)

del corpus

Expand Down
2 changes: 1 addition & 1 deletion mgm/src/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def gen_num_sent(start, model, num_sent, tokenizer, bad_words=None):
def loss_bc(p_i,q_i):
return torch.sum(torch.abs(p_i-q_i))/torch.sum(torch.abs(p_i+q_i))

def get_Z(corpus, position_encodings, vocab_size, label):
def get_Z(corpus, position_encodings, vocab_size, label=True):

if label:
corpus = torch.cat((corpus[:, 0:1], corpus[:, 2:],
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@
NAME = "microformer-mgm"
AUTHOR = "Haohong Zhang"
EMAIL = "haohongzh@gmail.com"
URL = "https://github.com/LudensZhang/MGM"
URL = "https://github.com/HUST-NingKang-Lab/MGM"
LICENSE = "MIT"
DESCRIPTION = "MGM (Microbial General Model) as a large-scaled pretrained language model for interpretable microbiome data analysis."


if __name__ == "__main__":
setup(
name=NAME,
version="0.5.5",
version="0.5.6",
author=AUTHOR,
author_email=EMAIL,
url=URL,
Expand Down

0 comments on commit c125a05

Please sign in to comment.