Skip to content

Commit

Permalink
Reducing number of samples for evaluation when testing
Browse files: browse the repository at this point in the history
  • Loading branch information
ivrodr-msft authored and mahilleb-msft committed May 31, 2017
1 parent d91223a commit 27775bf
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def train_and_test(network, trainer, train_source, test_source, minibatch_size,

# Train and evaluate the network.
def vgg16_train_and_eval(train_data, test_data, num_quantization_bits=32, minibatch_size=128, epoch_size = 1281167, max_epochs=80,
restore=True, log_to_file=None, num_mbs_per_log=None, gen_heartbeat=False):
restore=True, log_to_file=None, num_mbs_per_log=None, gen_heartbeat=False, testing=False):
_cntk_py.set_computation_network_trace_level(0)

progress_printer = ProgressPrinter(
Expand All @@ -186,7 +186,14 @@ def vgg16_train_and_eval(train_data, test_data, num_quantization_bits=32, miniba
network = create_vgg16()
trainer = create_trainer(network, epoch_size, num_quantization_bits, progress_printer)
train_source = create_image_mb_source(train_data, True, total_number_of_samples=max_epochs * epoch_size)
test_source = create_image_mb_source(test_data, False, total_number_of_samples=FULL_DATA_SWEEP)

if testing:
# reduce number of samples for validation when testing
num_of_validation_samples = max_epochs * epoch_size * 10
else:
num_of_validation_samples = FULL_DATA_SWEEP

test_source = create_image_mb_source(test_data, False, total_number_of_samples=num_of_validation_samples)
train_and_test(network, trainer, train_source, test_source, minibatch_size, epoch_size, restore)


Expand All @@ -203,6 +210,7 @@ def vgg16_train_and_eval(train_data, test_data, num_quantization_bits=32, miniba
parser.add_argument('-q', '--quantized_bits', help='Number of quantized bits used for gradient aggregation', type=int, required=False, default='32')
parser.add_argument('-r', '--restart', help='Indicating whether to restart from scratch (instead of restart from checkpoint file by default)', action='store_true')
parser.add_argument('-device', '--device', type=int, help="Force to run the script on a specified device", required=False, default=None)
parser.add_argument('-testing', '--testing', help='Indicate if running for testing purposes (validation only done in a portion of the test dataset)', action='store_true')

args = vars(parser.parse_args())

Expand Down Expand Up @@ -232,6 +240,8 @@ def vgg16_train_and_eval(train_data, test_data, num_quantization_bits=32, miniba
restore=not args['restart'],
log_to_file=args['logdir'],
num_mbs_per_log=200,
gen_heartbeat=True)
gen_heartbeat=True,
testing=args['testing'])

# Must call MPI finalize when the process exits without exceptions
Communicator.finalize()
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def train_and_test(network, trainer, train_source, test_source, minibatch_size,

# Train and evaluate the network.
def vgg19_train_and_eval(train_data, test_data, num_quantization_bits=32, minibatch_size=128, epoch_size = 1281167, max_epochs=80,
restore=True, log_to_file=None, num_mbs_per_log=None, gen_heartbeat=False):
restore=True, log_to_file=None, num_mbs_per_log=None, gen_heartbeat=False, testing=False):
_cntk_py.set_computation_network_trace_level(0)

progress_printer = ProgressPrinter(
Expand All @@ -186,7 +186,14 @@ def vgg19_train_and_eval(train_data, test_data, num_quantization_bits=32, miniba
network = create_vgg19()
trainer = create_trainer(network, epoch_size, num_quantization_bits, progress_printer)
train_source = create_image_mb_source(train_data, True, total_number_of_samples=max_epochs * epoch_size)
test_source = create_image_mb_source(test_data, False, total_number_of_samples=FULL_DATA_SWEEP)

if testing:
# reduce number of samples for validation when testing
num_of_validation_samples = max_epochs * epoch_size * 10
else:
num_of_validation_samples = FULL_DATA_SWEEP

test_source = create_image_mb_source(test_data, False, total_number_of_samples=num_of_validation_samples)
train_and_test(network, trainer, train_source, test_source, minibatch_size, epoch_size, restore)


Expand All @@ -203,6 +210,7 @@ def vgg19_train_and_eval(train_data, test_data, num_quantization_bits=32, miniba
parser.add_argument('-q', '--quantized_bits', help='Number of quantized bits used for gradient aggregation', type=int, required=False, default='32')
parser.add_argument('-r', '--restart', help='Indicating whether to restart from scratch (instead of restart from checkpoint file by default)', action='store_true')
parser.add_argument('-device', '--device', type=int, help="Force to run the script on a specified device", required=False, default=None)
parser.add_argument('-testing', '--testing', help='Indicate if running for testing purposes (validation only done in a portion of the test dataset)', action='store_true')

args = vars(parser.parse_args())

Expand Down Expand Up @@ -232,6 +240,8 @@ def vgg19_train_and_eval(train_data, test_data, num_quantization_bits=32, miniba
restore=not args['restart'],
log_to_file=args['logdir'],
num_mbs_per_log=200,
gen_heartbeat=True)
gen_heartbeat=True,
testing=args['testing'])

# Must call MPI finalize when the process exits without exceptions
Communicator.finalize()
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ def test_VGG16_imagenet_distributed(device_id):
"-datadir", prepare_ImageNet_data(),
"-q", "32",
"-device", str(device_id),
"-r"]
"-r",
"-testing"]

# Currently we only test for CPU since the memory usage is very high for GPU (~6 GB)
mpiexec_test(device_id, script_under_test, mpiexec_params, params, 0.99, True, timeout_seconds=900, use_only_cpu=True)
mpiexec_test(device_id, script_under_test, mpiexec_params, params, 0.99, True, timeout_seconds=500, use_only_cpu=True)
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ def test_VGG19_imagenet_distributed(device_id):
"-datadir", prepare_ImageNet_data(),
"-q", "32",
"-device", str(device_id),
"-r"]
"-r",
"-testing"]

# Currently we only test for CPU since the memory usage is very high for GPU (~6 GB)
mpiexec_test(device_id, script_under_test, mpiexec_params, params, 0.99, True, timeout_seconds=900, use_only_cpu=True)
mpiexec_test(device_id, script_under_test, mpiexec_params, params, 0.99, True, timeout_seconds=500, use_only_cpu=True)

0 comments on commit 27775bf

Please sign in to comment.