Skip to content

Commit

Permalink
count init model
Browse files Browse the repository at this point in the history
  • Loading branch information
Lurkrazy committed Mar 20, 2024
1 parent d8966f1 commit a2e2ec9
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/auto_scheduler/search_policy/sketch_policy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -236,8 +236,9 @@ State SketchPolicyNode::Search(int n_trials, int early_stopping, int num_measure
count_sampled = 0;
int num_start = n_trials;
int max_time = 40;
int max_measurement = 1000;
max_measurement_ = 1000;
int start_idx = 0;
init_model_size_ = sample_init_min_pop_;
auto start_time = std::chrono::high_resolution_clock::now(); // init start time
std::chrono::minutes max_duration(static_cast<int>(max_time)); // convert max_time to minutes
std::chrono::minutes duration; // declare duration variable
Expand Down Expand Up @@ -324,7 +325,7 @@ State SketchPolicyNode::Search(int n_trials, int early_stopping, int num_measure
std::cout << "Num of sampled: #" << count_sampled << std::endl;

duration = std::chrono::duration_cast<std::chrono::minutes>(std::chrono::high_resolution_clock::now() - start_time);
} while (count_sampled != -1 && duration <= max_duration && measured_states_throughputs_.size() < max_measurement);
} while (count_sampled != -1 && duration <= max_duration && measured_states_throughputs_.size() < max_measurement_);

PrintTitle("Done", verbose);

Expand Down Expand Up @@ -1363,7 +1364,7 @@ int SketchPolicyNode::DGD_Move(
gflops_map_[state_str] =
search_task->compute_dag->flop_ct / FloatArrayMean(results[i]->costs) / 1e9;
measured_states_throughputs_.push_back(gflops_map_[state_str]);
if (measured_states_throughputs_.size() > 1000 + sample_init_min_pop_) {
if (measured_states_throughputs_.size() > max_measurement_ + init_model_size_) {
std::cout << "stop here" << std::endl;
return -1;
}
Expand Down
2 changes: 2 additions & 0 deletions src/auto_scheduler/search_policy/sketch_policy.h
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,8 @@ class SketchPolicyNode : public SearchPolicyNode {
std::unordered_set<std::string> cache_failed;
/*! \brief The minimul output population of SampleInitPopulation */
int sample_init_min_pop_;
int init_model_size_;
int max_measurement_;
float global_tolerant_threashold;

friend class SketchPolicy;
Expand Down

0 comments on commit a2e2ec9

Please sign in to comment.