Commit e52dd59

deploy: 8099be9
puyuan1996 committed Jan 27, 2025
1 parent b269d4d commit e52dd59
Showing 2 changed files with 8 additions and 12 deletions.
11 changes: 4 additions & 7 deletions _modules/lzero/entry/train_muzero.html
@@ -110,7 +110,7 @@ Source code for lzero.entry.train_muzero
 from lzero.policy.random_policy import LightZeroRandomPolicy
 from lzero.worker import MuZeroCollector as Collector
 from lzero.worker import MuZeroEvaluator as Evaluator
-from .utils import random_collect
+from .utils import random_collect, calculate_update_per_collect
 
 
 def train_muzero(
@@ -277,12 +277,9 @@ Source code for lzero.entry.train_muzero
 
         # Collect data by default config n_sample/n_episode.
         new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
-        if cfg.policy.update_per_collect is None:
-            # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the replay_ratio.
-            # The length of game_segment (i.e., len(game_segment.action_segment)) can be smaller than cfg.policy.game_segment_length if it represents the final segment of the game.
-            # On the other hand, its length will be less than cfg.policy.game_segment_length + padding_length when it is not the last game segment. Typically, padding_length is the sum of unroll_steps and td_steps.
-            collected_transitions_num = sum(min(len(game_segment), cfg.policy.game_segment_length) for game_segment in new_data[0])
-            update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio)
+
+        # Determine updates per collection
+        update_per_collect = calculate_update_per_collect(cfg, new_data)
 
         # save returned new_data collected by the collector
         replay_buffer.push_game_segments(new_data)
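The refactor replaces the inline update count computation with a call to a shared `calculate_update_per_collect` helper in `lzero.entry.utils`. A minimal sketch of what that helper factors out, reconstructed from the inline logic this commit removes; the helper's actual signature and internals in `lzero` may differ, and the `cfg`/segment shapes below are illustrative stand-ins:

```python
from types import SimpleNamespace


def calculate_update_per_collect(cfg, new_data):
    """Return the number of training updates to run after one collection.

    If cfg.policy.update_per_collect is set explicitly, it is used as-is.
    Otherwise it is derived as collected_transitions * replay_ratio, capping
    each segment's length at game_segment_length: the final segment of a game
    can be shorter, while non-final segments carry extra padding (roughly
    unroll_steps + td_steps) that should not count as fresh transitions.
    """
    if cfg.policy.update_per_collect is not None:
        return cfg.policy.update_per_collect
    collected_transitions_num = sum(
        min(len(game_segment), cfg.policy.game_segment_length)
        for game_segment in new_data[0]
    )
    return int(collected_transitions_num * cfg.policy.replay_ratio)


# Usage sketch with stand-in segments (plain lists in place of GameSegment
# objects; new_data[0] holds the collected segments, as in the removed code):
cfg = SimpleNamespace(policy=SimpleNamespace(
    update_per_collect=None, replay_ratio=0.25, game_segment_length=400))
segments = [list(range(400)), list(range(400)), list(range(120))]
print(calculate_update_per_collect(cfg, (segments,)))  # 920 * 0.25 -> 230
```

Centralizing this also lets `train_unizero` and the other entry points share one definition instead of drifting copies.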
9 changes: 4 additions & 5 deletions _modules/lzero/policy/unizero.html
@@ -819,11 +819,10 @@ Source code for lzero.policy.unizero
         network_output = self._eval_model.initial_inference(self.last_batch_obs, self.last_batch_action, data)
         latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output)
 
-        if not self._eval_model.training:
-            # if not in training, obtain the scalars of the value/reward
-            pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy()  # shape(B, 1)
-            latent_state_roots = latent_state_roots.detach().cpu().numpy()
-            policy_logits = policy_logits.detach().cpu().numpy().tolist()  # list shape(B, A)
+        # if not in training, obtain the scalars of the value/reward
+        pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy()  # shape(B, 1)
+        latent_state_roots = latent_state_roots.detach().cpu().numpy()
+        policy_logits = policy_logits.detach().cpu().numpy().tolist()  # list shape(B, A)
 
         legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_eval_env_num)]
         if self._cfg.mcts_ctree:
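The dropped `if not self._eval_model.training:` guard was redundant in this eval path, so the tensor-to-numpy conversions now run unconditionally. For context, the `legal_actions` line in the surrounding code builds per-environment legal action indices from a binary `action_mask`. A standalone sketch of that comprehension (the helper name and mask values are made up for illustration):

```python
def legal_actions_from_mask(action_mask, active_eval_env_num):
    """For each active environment j, collect the action indices i whose
    mask entry action_mask[j][i] equals 1 (i.e., the legal actions)."""
    return [
        [i for i, x in enumerate(action_mask[j]) if x == 1]
        for j in range(active_eval_env_num)
    ]


# Two environments with five discrete actions each (illustrative masks):
mask = [[1, 0, 1, 1, 0], [0, 1, 0, 0, 1]]
print(legal_actions_from_mask(mask, 2))  # [[0, 2, 3], [1, 4]]
```

These index lists are what the MCTS roots (ctree or ptree, depending on `self._cfg.mcts_ctree`) consume to restrict the search to legal moves.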
