Skip to content

Commit

Permalink
deploy: 8a142a9
Browse files Browse the repository at this point in the history
  • Loading branch information
puyuan1996 committed Jan 27, 2025
1 parent eccc135 commit b269d4d
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 10 deletions.
13 changes: 6 additions & 7 deletions _modules/lzero/mcts/buffer/game_buffer.html
Original file line number Diff line number Diff line change
Expand Up @@ -655,14 +655,13 @@ <h1>Source code for lzero.mcts.buffer.game_buffer</h1><div class="highlight"><pr
<span class="c1"># print(f&#39;valid_len is {valid_len}&#39;)</span>

<span class="k">if</span> <span class="n">meta</span><span class="p">[</span><span class="s1">&#39;priorities&#39;</span><span class="p">]</span> <span class="ow">is</span> <span class="kc">None</span><span class="p">:</span>
<span class="n">max_prio</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">game_pos_priorities</span><span class="o">.</span><span class="n">max</span><span class="p">()</span> <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">game_segment_buffer</span> <span class="k">else</span> <span class="mi">1</span>
<span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">game_segment_buffer</span><span class="p">:</span>
<span class="n">max_prio</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">game_pos_priorities</span><span class="o">.</span><span class="n">max</span><span class="p">()</span> <span class="k">if</span> <span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">game_pos_priorities</span><span class="p">)</span> <span class="o">&gt;</span> <span class="mi">0</span> <span class="k">else</span> <span class="mi">1</span>
<span class="k">else</span><span class="p">:</span>
<span class="n">max_prio</span> <span class="o">=</span> <span class="mi">1</span>

<span class="c1"># if no &#39;priorities&#39; provided, set the valid part of the new-added game history the max_prio</span>
<span class="bp">self</span><span class="o">.</span><span class="n">game_pos_priorities</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">(</span>
<span class="p">(</span>
<span class="bp">self</span><span class="o">.</span><span class="n">game_pos_priorities</span><span class="p">,</span> <span class="p">[</span><span class="n">max_prio</span>
<span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">valid_len</span><span class="p">)]</span> <span class="o">+</span> <span class="p">[</span><span class="mf">0.</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">valid_len</span><span class="p">,</span> <span class="n">data_length</span><span class="p">)]</span>
<span class="p">)</span>
<span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">game_pos_priorities</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">concatenate</span><span class="p">((</span><span class="bp">self</span><span class="o">.</span><span class="n">game_pos_priorities</span><span class="p">,</span> <span class="p">[</span><span class="n">max_prio</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">valid_len</span><span class="p">)]</span> <span class="o">+</span> <span class="p">[</span><span class="mf">0.</span> <span class="k">for</span> <span class="n">_</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">valid_len</span><span class="p">,</span> <span class="n">data_length</span><span class="p">)]))</span>
<span class="k">else</span><span class="p">:</span>
<span class="k">assert</span> <span class="n">data_length</span> <span class="o">==</span> <span class="nb">len</span><span class="p">(</span><span class="n">meta</span><span class="p">[</span><span class="s1">&#39;priorities&#39;</span><span class="p">]),</span> <span class="s2">&quot; priorities should be of same length as the game steps&quot;</span>
<span class="n">priorities</span> <span class="o">=</span> <span class="n">meta</span><span class="p">[</span><span class="s1">&#39;priorities&#39;</span><span class="p">]</span><span class="o">.</span><span class="n">copy</span><span class="p">()</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
Expand Down
3 changes: 1 addition & 2 deletions _modules/lzero/model/sampled_efficientzero_model.html
Original file line number Diff line number Diff line change
Expand Up @@ -315,8 +315,7 @@ <h1>Source code for lzero.model.sampled_efficientzero_model</h1><div class="high
<span class="c1"># (3,96,96), and frame_stack_num is 4. Due to downsample, the encoding of observation (latent_state) is</span>
<span class="c1"># (64, 96/16, 96/16), where 64 is the number of channels, 96/16 is the size of the latent state. Thus,</span>
<span class="c1"># self.projection_input_dim = 64 * 96/16 * 96/16 = 64*6*6 = 2304</span>
<span class="bp">self</span><span class="o">.</span><span class="n">projection_input_dim</span> <span class="o">=</span> <span class="n">num_channels</span> <span class="o">*</span> <span class="n">math</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span><span class="n">observation_shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">/</span> <span class="mi">16</span>
<span class="p">)</span> <span class="o">*</span> <span class="n">math</span><span class="o">.</span><span class="n">ceil</span><span class="p">(</span><span class="n">observation_shape</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span> <span class="o">/</span> <span class="mi">16</span><span class="p">)</span>
<span class="bp">self</span><span class="o">.</span><span class="n">projection_input_dim</span> <span class="o">=</span> <span class="n">num_channels</span> <span class="o">*</span> <span class="n">latent_size</span>
<span class="k">else</span><span class="p">:</span>
<span class="bp">self</span><span class="o">.</span><span class="n">projection_input_dim</span> <span class="o">=</span> <span class="n">num_channels</span> <span class="o">*</span> <span class="n">observation_shape</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">*</span> <span class="n">observation_shape</span><span class="p">[</span><span class="mi">2</span><span class="p">]</span>

Expand Down
2 changes: 1 addition & 1 deletion _modules/lzero/policy/sampled_muzero.html
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,7 @@ <h1>Source code for lzero.policy.sampled_muzero</h1><div class="highlight"><pre>
<span class="s1">&#39;total_loss&#39;</span><span class="p">:</span> <span class="n">loss</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">(),</span>
<span class="s1">&#39;policy_loss&#39;</span><span class="p">:</span> <span class="n">policy_loss</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">(),</span>
<span class="s1">&#39;policy_entropy&#39;</span><span class="p">:</span> <span class="n">policy_entropy</span><span class="o">.</span><span class="n">item</span><span class="p">()</span> <span class="o">/</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_cfg</span><span class="o">.</span><span class="n">num_unroll_steps</span> <span class="o">+</span> <span class="mi">1</span><span class="p">),</span>
<span class="s1">&#39;target_policy_entropy&#39;</span><span class="p">:</span> <span class="n">target_policy_entropy</span><span class="o">.</span><span class="n">item</span><span class="p">()</span> <span class="o">/</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_cfg</span><span class="o">.</span><span class="n">num_unroll_steps</span> <span class="o">+</span> <span class="mi">1</span><span class="p">),</span>
<span class="s1">&#39;target_policy_entropy&#39;</span><span class="p">:</span> <span class="n">target_policy_entropy</span> <span class="o">/</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">_cfg</span><span class="o">.</span><span class="n">num_unroll_steps</span> <span class="o">+</span> <span class="mi">1</span><span class="p">),</span>
<span class="s1">&#39;reward_loss&#39;</span><span class="p">:</span> <span class="n">reward_loss</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">(),</span>
<span class="s1">&#39;value_loss&#39;</span><span class="p">:</span> <span class="n">value_loss</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">(),</span>
<span class="s1">&#39;consistency_loss&#39;</span><span class="p">:</span> <span class="n">consistency_loss</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">_cfg</span><span class="o">.</span><span class="n">num_unroll_steps</span><span class="p">,</span>
Expand Down

0 comments on commit b269d4d

Please sign in to comment.