Merge pull request #6238 from mitar/manual-sgd
[pycaffe] expose interface for manual, step-by-step optimization
shelhamer authored Jun 8, 2018
2 parents a357693 + 1bdcb74 commit 2a1c552
Showing 6 changed files with 32 additions and 15 deletions.
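What this change enables: Python code can now drive optimization one step at a time instead of calling solve() or step(). A minimal sketch of such a loop, assuming a hypothetical solver.prototxt; forward(), backward(), and clear_param_diffs() are pre-existing pycaffe Net methods, while apply_update() and lr are the bindings added by this commit:

    import caffe

    solver = caffe.SGDSolver('solver.prototxt')  # hypothetical solver config
    for _ in range(100):
        solver.net.clear_param_diffs()  # zero out gradients from the last step
        solver.net.forward()            # compute the loss
        solver.net.backward()           # fill parameter diffs with gradients
        solver.apply_update()           # new binding: update weights, bump iter
        print(solver.iter, solver.lr)   # lr follows the configured schedule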
include/caffe/sgd_solvers.hpp (5 changes: 3 additions & 2 deletions)
@@ -23,10 +23,11 @@ class SGDSolver : public Solver<Dtype> {
 
   const vector<shared_ptr<Blob<Dtype> > >& history() { return history_; }
 
+  virtual void ApplyUpdate();
+  Dtype GetLearningRate();
+
  protected:
   void PreSolve();
-  Dtype GetLearningRate();
-  virtual void ApplyUpdate();
   virtual void Normalize(int param_id);
   virtual void Regularize(int param_id);
   virtual void ComputeUpdateValue(int param_id, Dtype rate);
include/caffe/solver.hpp (3 changes: 2 additions & 1 deletion)
@@ -94,9 +94,10 @@ class Solver {
    */
   virtual inline const char* type() const { return ""; }
 
- protected:
   // Make and apply the update value for the current iteration.
   virtual void ApplyUpdate() = 0;
+
+ protected:
   string SnapshotFilename(const string extension);
   string SnapshotToBinaryProto();
   string SnapshotToHDF5();
python/caffe/_caffe.cpp (20 changes: 12 additions & 8 deletions)
@@ -490,7 +490,9 @@ BOOST_PYTHON_MODULE(_caffe) {
   bp::class_<SolverParameter>("SolverParameter", bp::no_init)
     .add_property("max_iter", &SolverParameter::max_iter)
     .add_property("display", &SolverParameter::display)
-    .add_property("layer_wise_reduce", &SolverParameter::layer_wise_reduce);
+    .add_property("layer_wise_reduce", &SolverParameter::layer_wise_reduce)
+    .add_property("base_lr", &SolverParameter::base_lr,
+        &SolverParameter::set_base_lr);
   bp::class_<LayerParameter>("LayerParameter", bp::no_init);
 
   bp::class_<Solver<Dtype>, shared_ptr<Solver<Dtype> >, boost::noncopyable>(
@@ -507,26 +509,28 @@ BOOST_PYTHON_MODULE(_caffe) {
     .def("restore", &Solver<Dtype>::Restore)
     .def("snapshot", &Solver<Dtype>::Snapshot)
     .def("share_weights", &share_weights)
+    .def("apply_update", &Solver<Dtype>::ApplyUpdate)
     .add_property("param", bp::make_function(&Solver<Dtype>::param,
-        bp::return_value_policy<bp::copy_const_reference>()));
+        bp::return_internal_reference<>()));
   BP_REGISTER_SHARED_PTR_TO_PYTHON(Solver<Dtype>);
 
   bp::class_<SGDSolver<Dtype>, bp::bases<Solver<Dtype> >,
       shared_ptr<SGDSolver<Dtype> >, boost::noncopyable>(
-      "SGDSolver", bp::init<string>());
-  bp::class_<NesterovSolver<Dtype>, bp::bases<Solver<Dtype> >,
+      "SGDSolver", bp::init<string>())
+      .add_property("lr", &SGDSolver<Dtype>::GetLearningRate);
+  bp::class_<NesterovSolver<Dtype>, bp::bases<SGDSolver<Dtype> >,
       shared_ptr<NesterovSolver<Dtype> >, boost::noncopyable>(
       "NesterovSolver", bp::init<string>());
-  bp::class_<AdaGradSolver<Dtype>, bp::bases<Solver<Dtype> >,
+  bp::class_<AdaGradSolver<Dtype>, bp::bases<SGDSolver<Dtype> >,
      shared_ptr<AdaGradSolver<Dtype> >, boost::noncopyable>(
      "AdaGradSolver", bp::init<string>());
-  bp::class_<RMSPropSolver<Dtype>, bp::bases<Solver<Dtype> >,
+  bp::class_<RMSPropSolver<Dtype>, bp::bases<SGDSolver<Dtype> >,
      shared_ptr<RMSPropSolver<Dtype> >, boost::noncopyable>(
      "RMSPropSolver", bp::init<string>());
-  bp::class_<AdaDeltaSolver<Dtype>, bp::bases<Solver<Dtype> >,
+  bp::class_<AdaDeltaSolver<Dtype>, bp::bases<SGDSolver<Dtype> >,
      shared_ptr<AdaDeltaSolver<Dtype> >, boost::noncopyable>(
      "AdaDeltaSolver", bp::init<string>());
-  bp::class_<AdamSolver<Dtype>, bp::bases<Solver<Dtype> >,
+  bp::class_<AdamSolver<Dtype>, bp::bases<SGDSolver<Dtype> >,
      shared_ptr<AdamSolver<Dtype> >, boost::noncopyable>(
      "AdamSolver", bp::init<string>());
 
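Two details in the bindings above are worth noting: base_lr is exposed with both a getter and a setter, and solver.param switches from returning a copy (copy_const_reference) to a live reference (return_internal_reference), so assignments through it affect the running solver and hence GetLearningRate(). A short sketch of the resulting API, again assuming a hypothetical solver.prototxt:

    import caffe

    solver = caffe.SGDSolver('solver.prototxt')  # hypothetical solver config
    print(solver.param.base_lr)   # read the configured base learning rate
    solver.param.base_lr = 0.001  # write it, e.g. for a hand-rolled LR schedule
    print(solver.lr)              # effective rate computed by GetLearningRate()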
python/caffe/test/test_solver.py (11 changes: 11 additions & 0 deletions)
@@ -38,6 +38,17 @@ def test_solve(self):
         self.solver.solve()
         self.assertEqual(self.solver.iter, 100)
 
+    def test_apply_update(self):
+        net = self.solver.net
+        data = net.layers[1].blobs[0].data[...]
+        # Reset the weights of that layer to 0
+        data[...] = 0
+        net.layers[1].blobs[0].diff[...] = 1
+        # Apply the update, the initial learning rate should be 0.01
+        self.solver.apply_update()
+        # Check that the new weights are -0.01, with a precision of 1e-7
+        self.assertTrue((data - -0.01 * np.ones(data.shape)).max() < 1e-7)
+
     def test_net_memory(self):
         """Check that nets survive after the solver is destroyed."""
 
src/caffe/solver.cpp (4 changes: 0 additions & 4 deletions)
@@ -266,10 +266,6 @@ void Solver<Dtype>::Step(int iters) {
     }
     ApplyUpdate();
 
-    // Increment the internal iter_ counter -- its value should always indicate
-    // the number of times the weights have been updated.
-    ++iter_;
-
     SolverAction::Enum request = GetRequestedAction();
 
     // Save a snapshot if needed.
src/caffe/solvers/sgd_solver.cpp (4 changes: 4 additions & 0 deletions)
@@ -120,6 +120,10 @@ void SGDSolver<Dtype>::ApplyUpdate() {
     ComputeUpdateValue(param_id, rate);
   }
   this->net_->Update();
+
+  // Increment the internal iter_ counter -- its value should always indicate
+  // the number of times the weights have been updated.
+  ++this->iter_;
 }
 
 template <typename Dtype>
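Moving the iter_ increment out of Solver::Step() and into SGDSolver::ApplyUpdate() keeps the counter correct for both entry points: step() and solve() still advance it exactly once per weight update, and a manual apply_update() call now advances it too, which in turn advances the learning-rate schedule. A sketch of the observable behavior, assuming a hypothetical solver.prototxt:

    import caffe

    solver = caffe.SGDSolver('solver.prototxt')  # hypothetical solver config
    before = solver.iter
    solver.apply_update()             # updates the weights and increments iter_
    assert solver.iter == before + 1  # the LR schedule sees the new iteration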
