Skip to content

Commit

Permalink
added report output
Browse files Browse the repository at this point in the history
  • Loading branch information
sampathkethineedi committed Mar 8, 2019
1 parent f8d1efd commit 3c9e569
Showing 1 changed file with 18 additions and 7 deletions.
25 changes: 18 additions & 7 deletions snorkel_process.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
report = []
import datetime
report = ['\n*** Report for process run on '+str(datetime.datetime.now())+' ***\n']


def doc_parse(path):
Expand Down Expand Up @@ -56,7 +57,8 @@ def extract_candidates(candExtractor, cSubClass):

for i, sents in enumerate([train_sents, dev_sents, test_sents]):
candExtractor.apply(sents, split=i)
print("Number of candidates:", session.query(cSubClass).filter(cSubClass.split == i).count())
report.append("Candidates in split "+str(i)+' : '+
str(session.query(cSubClass).filter(cSubClass.split == i).count())+'\n')


def apply_LF(lf_file):
Expand All @@ -71,7 +73,8 @@ def apply_LF(lf_file):
np.random.seed(1701)
L_train = labeler.apply(split=0)
L_train.todense()
print(L_train.lf_stats(session))
report.append('\n#LF Stats\n')
report.append(L_train.lf_stats(session).to_csv(sep=' ', index=False, header=True))
return L_train


Expand All @@ -86,11 +89,19 @@ def apply_GenMod(L_train):
gen_model.train(L_train, cardinality=3)
# print(gen_model.weights.lf_accuracy)
train_marginals = gen_model.marginals(L_train)
print(gen_model.learned_lf_stats())
report.append('\n#Gen Model Stats\n')
report.append(gen_model.learned_lf_stats().to_csv(sep=' ', index=False, header=True))
save_marginals(session, L_train, train_marginals)


def runSnorkelProcess(path, restart, lf):
def generate_report(name):
    """
    Append the accumulated report entries to ``<name>/report.txt``.

    :param name: Directory that receives ``report.txt``; it must already
        exist, because the file is opened in append mode and no directory
        is created here.
    """
    # Append mode keeps the output of earlier runs; the ``with`` block closes
    # the handle on exit, so no explicit close() call is needed.
    with open(name+'/report.txt', 'a') as report_file:
        # ``report`` is the module-level list of strings filled in by the
        # processing steps; entries are written as-is, with no separator added.
        report_file.writelines(report)


def runSnorkelProcess(name, path, restart, lf):
"""
Main process flow
:param path: Path to TSV file
Expand All @@ -106,6 +117,7 @@ def runSnorkelProcess(path, restart, lf):
def_cand_extractor()
l_train = apply_LF(lf)
apply_GenMod(l_train)
generate_report(name)


if __name__ == "__main__":
Expand Down Expand Up @@ -140,6 +152,5 @@ def runSnorkelProcess(path, restart, lf):
from snorkel.annotations import save_marginals

session = SnorkelSession()
runSnorkelProcess(args.path, args.restart, args.lf)
runSnorkelProcess(args.name, args.path, args.restart, args.lf)
db_process(args.name)
print(report)

0 comments on commit 3c9e569

Please sign in to comment.