Skip to content

Commit

Permalink
add separate package for cli and update setup entry points
Browse files Browse the repository at this point in the history
  • Loading branch information
rknaebel committed Nov 24, 2020
1 parent e80c91e commit feb11b5
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 66 deletions.
12 changes: 7 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,17 @@ These example commands are executed from within the repository folder.

### Training
```shell script
python train.py --parser lin --dir models/lin --conll ../conll2016
python cli/train.py lin path/to/model path/to/conll
```
Training data is in JSON format; the folder contains subfolders `en.{train,dev,test}`
with files `relations.json` and `parses.json`.

### Prediction
### Evaluation
```shell script
python parse.py -i path/to/some/textfile -m models/lin
python cli/test.py lin path/to/model path/to/conll
```

### Evaluation
### Prediction
```shell script
python test.py --parser lin --dir models/lin --conll ../conll2016
python cli/parse.py -i path/to/some/textfile -m models/lin
```
Empty file added cli/__init__.py
Empty file.
File renamed without changes.
47 changes: 19 additions & 28 deletions test.py → cli/test.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,16 @@
import os

import click
from tqdm import tqdm

from discopy.parsers import get_parser
from discopy.semi_utils import get_arguments

args = get_arguments()
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
os.environ['CUDA_VISIBLE_DEVICES'] = os.environ.get('CUDA_VISIBLE_DEVICES', '')

from discopy.data.conll16 import get_conll_dataset
from discopy.utils import init_logger
import discopy.evaluate.exact

os.makedirs(args.dir, exist_ok=True)

logger = init_logger()


Expand Down Expand Up @@ -41,33 +38,27 @@ def evaluate_parser(pdtb_gold, pdtb_pred, threshold=0.7):
return discopy.evaluate.exact.evaluate_all(gold_relations, pred_relations, threshold=threshold)


if __name__ == '__main__':
parses_train, pdtb_train = get_conll_dataset(args.conll, 'en.train', load_trees=True, connective_mapping=True)
parses_val, pdtb_val = get_conll_dataset(args.conll, 'en.dev', load_trees=True, connective_mapping=True)
parses_test, pdtb_test = get_conll_dataset(args.conll, 'en.test', load_trees=True, connective_mapping=True)
parses_blind, pdtb_blind = get_conll_dataset(args.conll, 'en.blind-test', load_trees=True, connective_mapping=True)

@click.command()
@click.argument('parser', type=str)
@click.argument('model-path', type=str)
@click.argument('conll-path', type=str)
@click.option('-t', '--threshold', default=0.9, type=float)
def main(parser, model_path, conll_path, threshold):
parses_test, pdtb_test = get_conll_dataset(conll_path, 'en.test', load_trees=True, connective_mapping=True)
parses_blind, pdtb_blind = get_conll_dataset(conll_path, 'en.blind-test', load_trees=True, connective_mapping=True)
logger.info('Init Parser...')
parser = get_parser(args.parser)
parser_path = args.dir

if args.train:
logger.info('Train end-to-end Parser...')
parser.fit(pdtb_train, parses_train, pdtb_val, parses_val)
parser.save(os.path.join(args.dir))
elif os.path.exists(args.dir):
logger.info('Load pre-trained Parser...')
parser.load(args.dir)
else:
raise ValueError('Training and Loading not clear')

parser = get_parser(parser)
logger.info('Load pre-trained Parser...')
parser.load(model_path)
logger.info('component evaluation (test)')
parser.score(pdtb_test, parses_test)

logger.info('extract discourse relations from test data')
pdtb_pred = extract_discourse_relations(parser, parses_test)
evaluate_parser(pdtb_test, pdtb_pred, threshold=args.threshold)

evaluate_parser(pdtb_test, pdtb_pred, threshold=threshold)
logger.info('extract discourse relations from BLIND data')
pdtb_pred = extract_discourse_relations(parser, parses_blind)
evaluate_parser(pdtb_blind, pdtb_pred, threshold=args.threshold)
evaluate_parser(pdtb_blind, pdtb_pred, threshold=threshold)


if __name__ == '__main__':
main()
48 changes: 17 additions & 31 deletions train.py → cli/train.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,15 @@
import argparse
import os

from discopy.data.conll16 import get_conll_dataset
from discopy.parsers import get_parser

os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import click
from tqdm import tqdm

os.environ['CUDA_VISIBLE_DEVICES'] = os.environ.get('CUDA_VISIBLE_DEVICES', '')

from discopy.data.conll16 import get_conll_dataset
from discopy.parsers import get_parser
import discopy.evaluate.exact
from discopy.utils import init_logger

argument_parser = argparse.ArgumentParser()
argument_parser.add_argument("--dir", help="",
default='tmp')
argument_parser.add_argument("--conll", help="",
default='')
argument_parser.add_argument("--parser", help="",
default='lin')
argument_parser.add_argument("--threshold", help="",
default=0.9, type=float)
args = argument_parser.parse_args()

os.makedirs(args.dir, exist_ok=True)

logger = init_logger()


Expand Down Expand Up @@ -51,19 +37,19 @@ def evaluate_parser(pdtb_gold, pdtb_pred, threshold=0.7):
return discopy.evaluate.exact.evaluate_all(gold_relations, pred_relations, threshold=threshold)


if __name__ == '__main__':
parses_train, pdtb_train = get_conll_dataset(args.conll, 'en.train', load_trees=True, connective_mapping=True)
parses_val, pdtb_val = get_conll_dataset(args.conll, 'en.dev', load_trees=True, connective_mapping=True)
parses_test, pdtb_test = get_conll_dataset(args.conll, 'en.test', load_trees=True, connective_mapping=True)
parses_blind, pdtb_blind = get_conll_dataset(args.conll, 'en.blind-test', load_trees=True, connective_mapping=True)

@click.command()
@click.argument('parser', type=str)
@click.argument('model-path', type=str)
@click.argument('conll-path', type=str)
def main(parser, model_path, conll_path):
parses_train, pdtb_train = get_conll_dataset(conll_path, 'en.train', load_trees=True, connective_mapping=True)
parses_val, pdtb_val = get_conll_dataset(conll_path, 'en.dev', load_trees=True, connective_mapping=True)
logger.info('Init Parser...')
parser = get_parser(args.parser)

parser = get_parser(parser)
logger.info('Train end-to-end Parser...')
parser.fit(pdtb_train, parses_train, pdtb_val, parses_val)
parser.save(os.path.join(args.dir))
parser.save(os.path.join(model_path))


logger.info('extract discourse relations from test data')
pdtb_pred = extract_discourse_relations(parser, parses_test)
all_results = evaluate_parser(pdtb_test, pdtb_pred, threshold=args.threshold)
if __name__ == '__main__':
main()
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@
zip_safe=False,
entry_points={
'console_scripts': [
'discopy=main:main',
'discopy-parse=parse'
'discopy-train=cli.train:main',
'discopy-test=cli.test:main',
'discopy-parse=cli.parse:main',
],
}
)

0 comments on commit feb11b5

Please sign in to comment.