Skip to content

Commit

Permalink
Parameterized the data files loaded in
Browse files Browse the repository at this point in the history
  • Loading branch information
PayscaleNateW committed Nov 4, 2016
1 parent f7cd6c0 commit 6ad8ba2
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 6 deletions.
6 changes: 3 additions & 3 deletions data_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@ def clean_str(string):
return string.strip().lower()


def load_data_and_labels():
def load_data_and_labels(positive_data_file, negative_data_file):
"""
Loads MR polarity data from files, splits the data into words and generates labels.
Returns split sentences and labels.
"""
# Load data from files
positive_examples = list(open("./data/rt-polaritydata/rt-polarity.pos", "r").readlines())
positive_examples = list(open(positive_data_file, "r").readlines())
positive_examples = [s.strip() for s in positive_examples]
negative_examples = list(open("./data/rt-polaritydata/rt-polarity.neg", "r").readlines())
negative_examples = list(open(negative_data_file, "r").readlines())
negative_examples = [s.strip() for s in negative_examples]
# Split by words
x_text = positive_examples + negative_examples
Expand Down
6 changes: 5 additions & 1 deletion eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
# Parameters
# ==================================================

# Data Parameters
tf.flags.DEFINE_string("positive_data_file", "./data/rt-polaritydata/rt-polarity.pos", "Data source for the positive data.")
tf.flags.DEFINE_string("negative_data_file", "./data/rt-polaritydata/rt-polarity.neg", "Data source for the positive data.")

# Eval Parameters
tf.flags.DEFINE_integer("batch_size", 64, "Batch Size (default: 64)")
tf.flags.DEFINE_string("checkpoint_dir", "", "Checkpoint directory from training run")
Expand All @@ -31,7 +35,7 @@

# CHANGE THIS: Load data. Load your own data here
if FLAGS.eval_train:
x_raw, y_test = data_helpers.load_data_and_labels()
x_raw, y_test = data_helpers.load_data_and_labels(FLAGS.positive_data_file, FLAGS.negative_data_file)
y_test = np.argmax(y_test, axis=1)
else:
x_raw = ["a masterpiece four years in the making", "everything is off."]
Expand Down
8 changes: 6 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
# Parameters
# ==================================================

# Data loading params
tf.flags.DEFINE_float("dev_sample_percentage", .1, "Percentage of the training data to use for validation")
tf.flags.DEFINE_string("positive_data_file", "./data/rt-polaritydata/rt-polarity.pos", "Data source for the positive data.")
tf.flags.DEFINE_string("negative_data_file", "./data/rt-polaritydata/rt-polarity.neg", "Data source for the positive data.")

# Model Hyperparameters
tf.flags.DEFINE_integer("embedding_dim", 128, "Dimensionality of character embedding (default: 128)")
tf.flags.DEFINE_string("filter_sizes", "3,4,5", "Comma-separated filter sizes (default: '3,4,5')")
Expand All @@ -27,7 +32,6 @@
# Misc Parameters
tf.flags.DEFINE_boolean("allow_soft_placement", True, "Allow device soft device placement")
tf.flags.DEFINE_boolean("log_device_placement", False, "Log placement of ops on devices")
tf.flags.DEFINE_float("dev_sample_percentage", .1, "Percentage of the training data to use for validation")

FLAGS = tf.flags.FLAGS
FLAGS._parse_flags()
Expand All @@ -42,7 +46,7 @@

# Load data
print("Loading data...")
x_text, y = data_helpers.load_data_and_labels()
x_text, y = data_helpers.load_data_and_labels(FLAGS.positive_data_file, FLAGS.negative_data_file)

# Build vocabulary
max_document_length = max([len(x.split(" ")) for x in x_text])
Expand Down

0 comments on commit 6ad8ba2

Please sign in to comment.