Skip to content

Commit

Permalink
Merge pull request dennybritz#43 from PayscaleNateW/parameterize-file…
Browse files Browse the repository at this point in the history
…s-and-train-dev-split

Parameterize files and train dev split
  • Loading branch information
dennybritz authored Nov 5, 2016
2 parents 9c161b5 + 6ad8ba2 commit 0b48d34
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 7 deletions.
6 changes: 3 additions & 3 deletions data_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@ def clean_str(string):
return string.strip().lower()


def load_data_and_labels():
def load_data_and_labels(positive_data_file, negative_data_file):
"""
Loads MR polarity data from files, splits the data into words and generates labels.
Returns split sentences and labels.
"""
# Load data from files
positive_examples = list(open("./data/rt-polaritydata/rt-polarity.pos", "r").readlines())
positive_examples = list(open(positive_data_file, "r").readlines())
positive_examples = [s.strip() for s in positive_examples]
negative_examples = list(open("./data/rt-polaritydata/rt-polarity.neg", "r").readlines())
negative_examples = list(open(negative_data_file, "r").readlines())
negative_examples = [s.strip() for s in negative_examples]
# Split by words
x_text = positive_examples + negative_examples
Expand Down
6 changes: 5 additions & 1 deletion eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
# Parameters
# ==================================================

# Data Parameters
# Paths to the two class-specific input files handed to data_helpers.load_data_and_labels.
tf.flags.DEFINE_string("positive_data_file", "./data/rt-polaritydata/rt-polarity.pos", "Data source for the positive data.")
# The help text for this flag previously read "positive data" (copy-paste slip).
tf.flags.DEFINE_string("negative_data_file", "./data/rt-polaritydata/rt-polarity.neg", "Data source for the negative data.")

# Eval Parameters
tf.flags.DEFINE_integer("batch_size", 64, "Batch Size (default: 64)")
tf.flags.DEFINE_string("checkpoint_dir", "", "Checkpoint directory from training run")
Expand All @@ -31,7 +35,7 @@

# CHANGE THIS: Load data. Load your own data here
if FLAGS.eval_train:
x_raw, y_test = data_helpers.load_data_and_labels()
x_raw, y_test = data_helpers.load_data_and_labels(FLAGS.positive_data_file, FLAGS.negative_data_file)
y_test = np.argmax(y_test, axis=1)
else:
x_raw = ["a masterpiece four years in the making", "everything is off."]
Expand Down
12 changes: 9 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
# Parameters
# ==================================================

# Data loading params
# NOTE: despite its name, dev_sample_percentage is used as a fraction of the
# data set (default .1); it is multiplied directly by len(y) when splitting.
tf.flags.DEFINE_float("dev_sample_percentage", .1, "Percentage of the training data to use for validation")
tf.flags.DEFINE_string("positive_data_file", "./data/rt-polaritydata/rt-polarity.pos", "Data source for the positive data.")
# The help text for this flag previously read "positive data" (copy-paste slip).
tf.flags.DEFINE_string("negative_data_file", "./data/rt-polaritydata/rt-polarity.neg", "Data source for the negative data.")

# Model Hyperparameters
tf.flags.DEFINE_integer("embedding_dim", 128, "Dimensionality of character embedding (default: 128)")
tf.flags.DEFINE_string("filter_sizes", "3,4,5", "Comma-separated filter sizes (default: '3,4,5')")
Expand Down Expand Up @@ -41,7 +46,7 @@

# Load data
print("Loading data...")
x_text, y = data_helpers.load_data_and_labels()
x_text, y = data_helpers.load_data_and_labels(FLAGS.positive_data_file, FLAGS.negative_data_file)

# Build vocabulary
max_document_length = max([len(x.split(" ")) for x in x_text])
Expand All @@ -56,8 +61,9 @@

# Split train/test set
# TODO: This is very crude, should use cross-validation
x_train, x_dev = x_shuffled[:-1000], x_shuffled[-1000:]
y_train, y_dev = y_shuffled[:-1000], y_shuffled[-1000:]
# Number of trailing (already shuffled) examples held out for validation,
# clamped to the valid range [0, len(y)].
dev_sample_count = int(FLAGS.dev_sample_percentage * float(len(y)))
dev_sample_count = max(0, min(dev_sample_count, len(y)))
# Split on a non-negative index. Negating the count breaks when the count is 0:
# the index becomes 0, so the training slice is empty and every example lands
# in the dev set.
# Assumes x_shuffled and y_shuffled each hold len(y) examples -- TODO confirm.
dev_sample_index = len(y) - dev_sample_count
x_train, x_dev = x_shuffled[:dev_sample_index], x_shuffled[dev_sample_index:]
y_train, y_dev = y_shuffled[:dev_sample_index], y_shuffled[dev_sample_index:]
print("Vocabulary Size: {:d}".format(len(vocab_processor.vocabulary_)))
print("Train/Dev split: {:d}/{:d}".format(len(y_train), len(y_dev)))

Expand Down

0 comments on commit 0b48d34

Please sign in to comment.