diff --git a/README.md b/README.md index c40310a..5863415 100644 --- a/README.md +++ b/README.md @@ -22,15 +22,17 @@ These example commands are executed from within the repository folder. ### Training ```shell script -python train.py --parser lin --dir models/lin --conll ../conll2016 +python cli/train.py lin path/to/model path/to/conll ``` +Training data format is json, the folder contains subfolders `en.{train,dev,test}` +with files `relations.json` and `parses.json`. -### Prediction +### Evaluation ```shell script -python parse.py -i path/to/some/textfile -m models/lin +python cli/test.py lin path/to/model path/to/conll ``` -### Evaluation +### Prediction ```shell script -python test.py --parser lin --dir models/lin --conll ../conll2016 +python cli/parse.py -i path/to/some/textfile -m models/lin ``` diff --git a/cli/__init__.py b/cli/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/parse.py b/cli/parse.py similarity index 100% rename from parse.py rename to cli/parse.py diff --git a/test.py b/cli/test.py similarity index 59% rename from test.py rename to cli/test.py index 809760a..2ccd5f0 100644 --- a/test.py +++ b/cli/test.py @@ -1,19 +1,16 @@ import os +import click from tqdm import tqdm from discopy.parsers import get_parser -from discopy.semi_utils import get_arguments -args = get_arguments() -os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu +os.environ['CUDA_VISIBLE_DEVICES'] = os.environ.get('CUDA_VISIBLE_DEVICES', '') from discopy.data.conll16 import get_conll_dataset from discopy.utils import init_logger import discopy.evaluate.exact -os.makedirs(args.dir, exist_ok=True) - logger = init_logger() @@ -41,33 +38,27 @@ def evaluate_parser(pdtb_gold, pdtb_pred, threshold=0.7): return discopy.evaluate.exact.evaluate_all(gold_relations, pred_relations, threshold=threshold) -if __name__ == '__main__': - parses_train, pdtb_train = get_conll_dataset(args.conll, 'en.train', load_trees=True, connective_mapping=True) - parses_val, pdtb_val = 
get_conll_dataset(args.conll, 'en.dev', load_trees=True, connective_mapping=True) - parses_test, pdtb_test = get_conll_dataset(args.conll, 'en.test', load_trees=True, connective_mapping=True) - parses_blind, pdtb_blind = get_conll_dataset(args.conll, 'en.blind-test', load_trees=True, connective_mapping=True) - +@click.command() +@click.argument('parser', type=str) +@click.argument('model-path', type=str) +@click.argument('conll-path', type=str) +@click.option('-t', '--threshold', default=0.9, type=float) +def main(parser, model_path, conll_path, threshold): + parses_test, pdtb_test = get_conll_dataset(conll_path, 'en.test', load_trees=True, connective_mapping=True) + parses_blind, pdtb_blind = get_conll_dataset(conll_path, 'en.blind-test', load_trees=True, connective_mapping=True) logger.info('Init Parser...') - parser = get_parser(args.parser) - parser_path = args.dir - - if args.train: - logger.info('Train end-to-end Parser...') - parser.fit(pdtb_train, parses_train, pdtb_val, parses_val) - parser.save(os.path.join(args.dir)) - elif os.path.exists(args.dir): - logger.info('Load pre-trained Parser...') - parser.load(args.dir) - else: - raise ValueError('Training and Loading not clear') - + parser = get_parser(parser) + logger.info('Load pre-trained Parser...') + parser.load(model_path) logger.info('component evaluation (test)') parser.score(pdtb_test, parses_test) - logger.info('extract discourse relations from test data') pdtb_pred = extract_discourse_relations(parser, parses_test) - evaluate_parser(pdtb_test, pdtb_pred, threshold=args.threshold) - + evaluate_parser(pdtb_test, pdtb_pred, threshold=threshold) logger.info('extract discourse relations from BLIND data') pdtb_pred = extract_discourse_relations(parser, parses_blind) - evaluate_parser(pdtb_blind, pdtb_pred, threshold=args.threshold) + evaluate_parser(pdtb_blind, pdtb_pred, threshold=threshold) + + +if __name__ == '__main__': + main() diff --git a/train.py b/cli/train.py similarity index 52% rename from 
train.py rename to cli/train.py index 4a43dae..98d2c5c 100644 --- a/train.py +++ b/cli/train.py @@ -1,29 +1,15 @@ -import argparse import os -from discopy.data.conll16 import get_conll_dataset -from discopy.parsers import get_parser - -os.environ['CUDA_VISIBLE_DEVICES'] = '0' - +import click from tqdm import tqdm +os.environ['CUDA_VISIBLE_DEVICES'] = os.environ.get('CUDA_VISIBLE_DEVICES', '') + +from discopy.data.conll16 import get_conll_dataset +from discopy.parsers import get_parser import discopy.evaluate.exact from discopy.utils import init_logger -argument_parser = argparse.ArgumentParser() -argument_parser.add_argument("--dir", help="", - default='tmp') -argument_parser.add_argument("--conll", help="", - default='') -argument_parser.add_argument("--parser", help="", - default='lin') -argument_parser.add_argument("--threshold", help="", - default=0.9, type=float) -args = argument_parser.parse_args() - -os.makedirs(args.dir, exist_ok=True) - logger = init_logger() @@ -51,19 +37,19 @@ def evaluate_parser(pdtb_gold, pdtb_pred, threshold=0.7): return discopy.evaluate.exact.evaluate_all(gold_relations, pred_relations, threshold=threshold) -if __name__ == '__main__': - parses_train, pdtb_train = get_conll_dataset(args.conll, 'en.train', load_trees=True, connective_mapping=True) - parses_val, pdtb_val = get_conll_dataset(args.conll, 'en.dev', load_trees=True, connective_mapping=True) - parses_test, pdtb_test = get_conll_dataset(args.conll, 'en.test', load_trees=True, connective_mapping=True) - parses_blind, pdtb_blind = get_conll_dataset(args.conll, 'en.blind-test', load_trees=True, connective_mapping=True) - +@click.command() +@click.argument('parser', type=str) +@click.argument('model-path', type=str) +@click.argument('conll-path', type=str) +def main(parser, model_path, conll_path): + parses_train, pdtb_train = get_conll_dataset(conll_path, 'en.train', load_trees=True, connective_mapping=True) + parses_val, pdtb_val = get_conll_dataset(conll_path, 'en.dev', 
load_trees=True, connective_mapping=True) logger.info('Init Parser...') - parser = get_parser(args.parser) - + parser = get_parser(parser) logger.info('Train end-to-end Parser...') parser.fit(pdtb_train, parses_train, pdtb_val, parses_val) - parser.save(os.path.join(args.dir)) + parser.save(os.path.join(model_path)) + - logger.info('extract discourse relations from test data') - pdtb_pred = extract_discourse_relations(parser, parses_test) - all_results = evaluate_parser(pdtb_test, pdtb_pred, threshold=args.threshold) +if __name__ == '__main__': + main() diff --git a/setup.py b/setup.py index 04ae7ba..9d162ef 100644 --- a/setup.py +++ b/setup.py @@ -28,8 +28,9 @@ zip_safe=False, entry_points={ 'console_scripts': [ - 'discopy=main:main', - 'discopy-parse=parse' + 'discopy-train=cli.train:main', + 'discopy-test=cli.test:main', + 'discopy-parse=cli.parse:main', ], } )