From c6de6e9b9d166d923b4cd2d80d7ca56533b4096a Mon Sep 17 00:00:00 2001
From: Curtis Wigington
Date: Tue, 25 Sep 2018 08:56:27 -0700
Subject: [PATCH] prep to add data

---
 README.md             |  18 +++---
 compare_results.py    |  80 +++++++++++++++++++++++++++
 run_decode.py         |   2 +-
 sample_config_60.yaml | 126 ++++++++++++++++++++++++++++++++++++++++++
 4 files changed, 217 insertions(+), 9 deletions(-)
 create mode 100755 compare_results.py
 create mode 100644 sample_config_60.yaml

diff --git a/README.md b/README.md
index 8eb1109..3bc8b90 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,12 @@
 # Start Follow Read
 
-This repository is the implementation of the methods described in our paper [Start, Follow, Read: Full-Page End-to-end Handwriting Recognition](http://example.com).
+This repository is the implementation of the methods described in our paper [Start, Follow, Read: Full-Page End-to-end Handwriting Recognition](http://openaccess.thecvf.com/content_ECCV_2018/html/Curtis_Wigington_Start_Follow_Read_ECCV_2018_paper.html).
 
 All steps to reproduce our results for the [ICDAR2017 Competition on Handwritten Text Recognition on the READ Dataset](https://scriptnet.iit.demokritos.gr/competitions/8/) can be found in this repo.
 
 This code is free for academic and research use. For commercial use of our code and methods please contact [BYU Tech Transfer](techtransfer.byu.edu).
 
-We will also include [pretrained models](http://example.com).
+We will also include pretrained models, results, and the segmentation data inferred during training. These can be found on the [release page](https://github.com/cwig/start_follow_read/releases).
 
 ## Dependencies
 
@@ -16,7 +16,7 @@ The dependencies are all found in `environment.yaml`. They are installed as foll
 conda env create -f environment.yml
 ```
 
-The environment is activated as `source activate sfr_env`. 
+The environment is activated as `source activate sfr_env`.
 
 You will need to install the following libraries from source. warp-ctc is needed for training. PyKaldi is used for the language model. A pretrained Start, Follow, Read network can run
@@ -61,7 +61,7 @@ python preprocessing/prep_train_b.py data/Train-B data/Train-B data/train_b data
 Currently we only support running the tests for the Test-B task, not Test-A. When we compute the results for the Test-B while fully exploiting the competition provided regions-of-interest (ROI) we have to do a preprocessing step. This process masks out parts of the image that are not contained in the ROI.
 
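To make the ROI step concrete, here is a rough sketch of that kind of masking (illustrative only; it is not the code in `preprocessing/prep_test_b_with_regions.py`, and the polygon format and white fill value are assumptions). The actual preprocessing command follows below.

```
import cv2
import numpy as np

def mask_outside_roi(image, roi_polygon):
    # Build a binary mask of the region-of-interest polygon, then blank out
    # every pixel that falls outside it.
    mask = np.zeros(image.shape[:2], dtype=np.uint8)
    cv2.fillPoly(mask, [np.asarray(roi_polygon, dtype=np.int32)], 255)
    masked = image.copy()
    masked[mask == 0] = 255  # fill everything outside the ROI with white
    return masked
```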
 ```
-python preprocessing/prep_test_b_with_regions.py data/Test-B data/Test-B data/train_b_roi
+python preprocessing/prep_test_b_with_regions.py data/Test-B data/Test-B data/test_b_roi
 ```
 
 #### Generate Character Settings
@@ -188,7 +188,7 @@ python run_hwr.py data/test_b_roi sample_config_60.yaml data/test_b_roi_results
 ```
 
 ```
-python run_decode.py sample_config_60.yaml data/output/test_b_roi_results --in_xml_folder data/Test-B --out_xml_folder data/output/test_b_roi_xml --roi --aug --lm
+python run_decode.py sample_config_60.yaml data/test_b_roi_results --in_xml_folder data/Test-B --out_xml_folder data/test_b_roi_xml --roi --aug --lm
 ```
 
 #### Without using the competition regions-of-interest
@@ -198,15 +198,17 @@ python run_hwr.py data/test_b sample_config_60.yaml data/test_b_results
 ```
 
 ```
-python run_decode.py sample_config_60.yaml data/output/test_b_results --in_xml_folder data/Test-B --out_xml_folder data/output/test_b_xml --aug --lm
+python run_decode.py sample_config_60.yaml data/test_b_results --in_xml_folder data/Test-B --out_xml_folder data/test_b_xml --aug --lm
 ```
 
 #### Submission
 
-The xml folder needs to be compressed to a `.tar` and then can be submitted to the online evaluation system. We also include the xml files from our baseline system so you can compute the error with regards those predictions instead of submitting to the online system.
+The xml folder needs to be compressed to a `.tar` and then can be submitted to the [online evaluation system](https://scriptnet.iit.demokritos.gr/competitions/8/).
+
+We also include the xml files from our system so you can compute the error with regard to those predictions instead of submitting to the online system. This will give you a rough idea of how your results compare to ours. The error rate is not computed in the same way as on the evaluation server.
 
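For the submission route just described, a minimal sketch of building the `.tar` with Python's standard `tarfile` module (the folder and archive names are only examples taken from the commands above):

```
import tarfile

# Pack the folder of generated PAGE XML files for upload to the evaluation server.
with tarfile.open("test_b_roi_xml.tar", "w") as tar:
    tar.add("data/test_b_roi_xml", arcname="test_b_roi_xml")
```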
 ``
-not yet documented
+python compare_results.py
 ``
 
 ## Validation (General)
diff --git a/compare_results.py b/compare_results.py
new file mode 100755
index 0000000..b3643b4
--- /dev/null
+++ b/compare_results.py
@@ -0,0 +1,80 @@
+# Compare two folders of PAGE XML transcriptions and report how far apart they are.
+# Usage: python compare_results.py <xml_folder_1> <xml_folder_2>
+import sys
+import os
+from preprocessing import parse_PAGE
+from collections import defaultdict
+import editdistance
+
+def read_xml(filename):
+    with open(filename) as f:
+        xml_string_data = f.read()
+    return xml_string_data
+
+def get_lines_in_region(data):
+    # Group the parsed text lines by the id of the region they belong to.
+    regions = defaultdict(list)
+    for l in data['lines']:
+        regions[l['region_id']].append(l)
+    return regions
+
+if __name__ == "__main__":
+
+    folder1 = sys.argv[1]
+    folder2 = sys.argv[2]
+
+    # Index every xml file in each folder by its base filename.
+    f1_files = {}
+    for root, folders, files in os.walk(folder1):
+        for f in files:
+            if f.endswith(".xml"):
+                f1_files[f] = os.path.join(root, f)
+
+    f2_files = {}
+    for root, folders, files in os.walk(folder2):
+        for f in files:
+            if f.endswith(".xml"):
+                f2_files[f] = os.path.join(root, f)
+
+    print len(f1_files)
+    print len(f2_files)
+
+    sum_dif = 0
+    results = []
+
+    for i, filename in enumerate(sorted(f1_files)):
+        path1 = f1_files[filename]
+        path2 = f2_files[filename]
+
+        xml1 = read_xml(path1)
+        xml2 = read_xml(path2)
+
+        data1 = parse_PAGE.readXMLFile(xml1)[0]
+        data2 = parse_PAGE.readXMLFile(xml2)[0]
+
+        region1 = get_lines_in_region(data1)
+        region2 = get_lines_in_region(data2)
+
+        joint_set = set(region1.keys()) | set(region2.keys())
+        xor_set = set(region1.keys()) ^ set(region2.keys())
+
+        # Warn when a region id appears in only one of the two files.
+        if len(xor_set) != 0:
+            print filename, xor_set
+
+        for region_id in joint_set:
+
+            full_r1 = "\n".join([l['ground_truth'] for l in region1[region_id]])
+            full_r2 = "\n".join([l['ground_truth'] for l in region2[region_id]])
+
+            dis = editdistance.eval(full_r1, full_r2)
+
+            length = len(full_r1) + len(full_r2)
+            if length == 0:
+                out = 0
+            else:
+                out = dis / float(length)
+
+            results.append((out, filename, region_id, i, full_r1, full_r2))
+            sum_dif += out
+
+    # Sum of the per-region normalized character distances (a rough error score,
+    # not the metric used by the evaluation server).
+    print "WER", sum_dif
diff --git a/run_decode.py b/run_decode.py
index 2e9953d..c4b6e3a 100644
--- a/run_decode.py
+++ b/run_decode.py
@@ -85,7 +85,7 @@ def main():
         idx_to_char[int(k)] = v
 
     if use_aug:
-        model_mode = "pretrain"
+        model_mode = "best_overall"
         _,_, hw = init_model(config, hw_dir=model_mode, only_load="hw")
         dtype = torch.cuda.FloatTensor
         hw.eval()
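A note on the comparison script added above: it is invoked as `python compare_results.py <xml_folder_1> <xml_folder_2>` (the folder names are placeholders), pairs xml files by name, and sums a per-region normalized character edit distance over all regions of all pages. A standalone sketch of that per-region quantity (the function name is illustrative, not part of the repo):

```
import editdistance  # same package the script imports

def region_distance(lines_a, lines_b):
    # Character-level edit distance between two regions, normalized by the
    # combined text length so the value is bounded by 1.
    text_a = "\n".join(lines_a)
    text_b = "\n".join(lines_b)
    total_len = len(text_a) + len(text_b)
    if total_len == 0:
        return 0.0
    return editdistance.eval(text_a, text_b) / float(total_len)

print(region_distance(["Start Follow Read"], ["Start Follou Read"]))  # ~0.03
```

Despite the "WER" label it prints, the total is a character-level score summed over regions, not a word error rate, so treat it only as a rough comparison.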
diff --git a/sample_config_60.yaml b/sample_config_60.yaml
new file mode 100644
index 0000000..c3510e7
--- /dev/null
+++ b/sample_config_60.yaml
@@ -0,0 +1,126 @@
+network:
+  sol:
+    base0: 16 #architecture dependent - don't change
+    base1: 16 #architecture dependent - don't change
+
+  lf:
+    look_ahead_matrix:
+    step_bias:
+
+  hw:
+    num_of_outputs: 197
+    num_of_channels: 3
+    cnn_out_size: 1024 #architecture dependent
+    input_height: 60 #architecture dependent
+    char_set_path: "data/char_set.json"
+
+  lm:
+    fst_path: "../hwn5-comp-2017/data/comp_lm/data/graph/HCLG.fst"
+    mdl_path: "../hwn5-comp-2017/data/comp_lm/data/lang_test/basic.mdl"
+    words_path: "../hwn5-comp-2017/data/comp_lm/data/graph/words.txt"
+    phones_path: "../hwn5-comp-2017/data/comp_lm/data/lang_test/phones.txt"
+    beam: 8
+
+pretraining:
+  training_set:
+    img_folder: ""
+    json_folder: ""
+    file_list: "data/train_a_training_set.json"
+
+  validation_set:
+    img_folder: ""
+    json_folder: ""
+    file_list: "data/train_a_validation_set.json"
+
+  sol:
+    alpha_alignment: 0.1
+    alpha_backprop: 0.1
+    learning_rate: 0.0001 #pyyaml bug: no scientific notation
+    crop_params:
+      prob_label: 0.5
+      crop_size: 256
+    training_rescale_range: [384, 640]
+    validation_rescale_range: [512,512] #Don't validate on a random range
+    batch_size: 1 #During pretrain, only 45 images. If batch is 32 you would get 32 and 13 in an epoch
+    images_per_epoch: 1000
+    stop_after_no_improvement: 10
+
+  lf:
+    learning_rate: 0.0001 #pyyaml bug: no scientific notation
+    batch_size: 1
+    images_per_epoch: 1000
+    stop_after_no_improvement: 10
+
+  hw:
+    learning_rate: 0.0002 #pyyaml bug: no scientific notation
+    batch_size: 8
+    images_per_epoch: 1000
+    stop_after_no_improvement: 10
+
+  snapshot_path: "data/snapshots/init"
+
+training:
+  training_set:
+    img_folder: ""
+    json_folder: ""
+    file_list: "data/train_b_training_set.json"
+
+  validation_set:
+    img_folder: ""
+    json_folder: ""
+    file_list: "data/train_b_validation_set.json"
+
+  sol:
+    alpha_alignment: 0.1
+    alpha_backprop: 0.1
+    learning_rate: 0.0001 #pyyaml bug: no scientific notation
+    crop_params:
+      prob_label: 0.5
+      crop_size: 256
+    training_rescale_range: [384, 640]
+    validation_rescale_range: [512,512] #You should not validate on a random range
+    validation_subset_size: 1000
+    batch_size: 1
+    images_per_epoch: 10000
+    reset_interval: 3600 #seconds
+
+  lf:
+    learning_rate: 0.0001 #pyyaml bug: no scientific notation
+    batch_size: 1
+    refresh_interval: 3600 #seconds
+    images_per_epoch: 1000 #batches
+    validation_subset_size: 100 #images
+    reset_interval: 3600 #seconds
+
+  hw:
+    learning_rate: 0.0002 #pyyaml bug: no scientific notation
+    batch_size: 8
+    refresh_interval: 3600 #seconds
+    images_per_epoch: 20000 #batches
+    validation_subset_size: 2000 #images
+    reset_interval: 3600 #seconds
+
+  alignment:
+    accept_threshold: 0.1
+    sol_resize_width: 512
+    metric: "cer"
+    train_refresh_groups: 10
+
+  validation_post_processing:
+    sol_thresholds: [0.1,0.3,0.5,0.7,0.9]
+    lf_nms_ranges: [[0,6],[0,16],[0,20]]
+    lf_nms_thresholds: [0.1,0.3,0.5,0.7,0.9]
+
+  snapshot:
+    best_overall: "data/snapshots/best_overall"
+    best_validation: "data/snapshots/best_validation"
+    current: "data/snapshots/current"
+    pretrain: "data/snapshots/init"
+
+post_processing:
+  sol_threshold: 0.1
+  lf_nms_range: [0,6]
+  lf_nms_threshold: 0.5
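For reference, here is a quick way to inspect a config like `sample_config_60.yaml` from Python, and the reason the learning rates above are written out in decimal (the snippet is illustrative; the training scripts load the config in their own way):

```
import yaml

with open("sample_config_60.yaml") as f:
    config = yaml.safe_load(f)

print(config['network']['hw']['input_height'])        # 60
print(config['pretraining']['sol']['learning_rate'])  # 0.0001, loaded as a float

# The "#pyyaml bug" comments exist because PyYAML's default resolver does not
# treat "1e-4" (no decimal point) as a number, so it would come back as a string.
print(type(yaml.safe_load("lr: 1e-4")['lr']))  # str
```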