Add optional normalization (cont.) (#98)
* fix(ppo): optional reward scaling and minibatch advantage whitening

* feat(ppo): add optional reward clipping

* chore(ppo): add tests, comments

* fix(github): rename master to main for build

* feat(ppo): add manual reward scaling
maxreciprocate authored Nov 20, 2022
1 parent aafcae9 commit 3db86ca
Showing 7 changed files with 46 additions and 15 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/build.yml
@@ -2,9 +2,9 @@ name: Build

 on:
   push:
-    branches: [ master ]
+    branches: [ main ]
   pull_request:
-    branches: [ master ]
+    branches: [ main ]

 jobs:
   build:
6 changes: 4 additions & 2 deletions configs/ppo_config.yml
@@ -36,8 +36,10 @@ method:
   cliprange: 0.2 # clip range
   cliprange_value: 0.2 # clip range
   vf_coef: 2.3 # value term weight
-  scale_reward: True
-  clip_reward: 10
+  scale_reward: "running" # False | "ref" | "running" estimate against which to scale rewards
+  ref_mean: null
+  ref_std: null # rescale rewards with this deviation
+  cliprange_reward: 10
   gen_kwargs:
     max_length: 48 # LM max sample gen length
     min_length: 48 # LM min sample gen length
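To make the new settings concrete, here is a minimal sketch (illustrative helper and argument names, not code from the repo) of what each scale_reward value does to a batch of reward scores; it mirrors the orchestrator change further down in this commit.

import torch

def normalize_rewards(scores, scale_reward, running_std, ref_std, cliprange_reward):
    # scale_reward: False | "ref" | "running" -- which std estimate to divide by
    if scale_reward == "running":
        scores = scores / running_std   # std tracked online by RunningMoments
    elif scale_reward == "ref":
        scores = scores / ref_std       # fixed reference std (ref_std from the config)
    # optional symmetric clipping of the (scaled) rewards
    if cliprange_reward:
        scores = torch.clip(scores, -cliprange_reward, cliprange_reward)
    return scores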
4 changes: 4 additions & 0 deletions configs/test_config.yml
@@ -35,6 +35,10 @@ method:
   cliprange: 0.2 # clip range
   cliprange_value: 0.2 # clip range
   vf_coef: 1.0 # value term weight
+  scale_reward: "running" # False|"ref"|"running" estimate against which to scale rewards
+  cliprange_reward: 10
+  ref_mean: null
+  ref_std: null
   gen_kwargs:
     max_length: 48 # LM max sample gen length
     min_length: 48 # LM min sample gen length
20 changes: 20 additions & 0 deletions tests/test_ppo.py
@@ -1,6 +1,7 @@
 import unittest
 from trlx.data.configs import TRLConfig
 from trlx.model.nn.ppo_models import GPTHydraHeadWithValueModel
+from trlx.utils.modeling import RunningMoments
 from transformers import AutoTokenizer
 import torch

@@ -44,3 +45,22 @@ def test_forward(self):
         logits_diff = torch.sum(unfrozen_logits - frozen_logits).item()
         self.assertEqual(hs_diff, 0)
         self.assertEqual(logits_diff, 0)
+
+class TestStatistics(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.m = RunningMoments()
+        cls.a1 = torch.arange(100, dtype=float)
+        cls.a2 = torch.ones(100, dtype=float)
+        cls.a3 = torch.exp(torch.arange(10, dtype=float))
+        cls.a4 = torch.tensor([-10, -1, 0, 1, 10], dtype=float)
+
+    def test_running_moments(self):
+        assert torch.isclose(self.m.update(self.a1)[1], self.a1.std(unbiased=True), atol=1e-6)
+        assert torch.isclose(self.m.update(self.a2)[1], self.a2.std(unbiased=True), atol=1e-6)
+        assert torch.isclose(self.m.update(self.a3)[1], self.a3.std(unbiased=True), atol=1e-6)
+        assert torch.isclose(self.m.update(self.a4)[1], self.a4.std(unbiased=True), atol=1e-6)
+
+        a = torch.hstack((self.a1, self.a2, self.a3, self.a4))
+        assert torch.isclose(self.m.mean, a.mean(), atol=1e-6)
+        assert torch.isclose(self.m.std, a.std(unbiased=True), atol=1e-6)
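As the new test (and the orchestrator's "all_scores_mean, all_scores_std = self.running.update(scores)" below) indicates, RunningMoments.update returns the mean and standard deviation of the batch it is given, while the object itself accumulates statistics over every batch seen so far. A small usage sketch, illustrative rather than part of the commit:

import torch
from trlx.utils.modeling import RunningMoments

moments = RunningMoments()
xs = torch.randn(64, dtype=float)
batch_mean, batch_std = moments.update(xs)       # moments of this batch alone
print(moments.mean, moments.std, moments.count)  # aggregate moments across all update() calls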
6 changes: 4 additions & 2 deletions trlx/model/nn/ppo_models.py
@@ -111,8 +111,10 @@ class PPOConfig(MethodConfig):
     cliprange: float
     cliprange_value: float
     vf_coef: float
-    scale_reward: bool
-    clip_reward: float
+    scale_reward: str
+    ref_mean: Optional[float]
+    ref_std: Optional[float]
+    cliprange_reward: float
     gen_kwargs: dict

     def get_advantages_and_returns(
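Assuming the usual trlx config entry point, TRLConfig.load_yaml (the tests above import TRLConfig from trlx.data.configs; the exact loader name is an assumption here, not shown in this diff), the new PPOConfig fields would surface on the loaded method config roughly like this:

from trlx.data.configs import TRLConfig

# Sketch: read the new normalization settings off the method config.
config = TRLConfig.load_yaml("configs/ppo_config.yml")
method = config.method
print(method.scale_reward)              # "running", "ref", or False
print(method.ref_mean, method.ref_std)  # optional manual reference statistics
print(method.cliprange_reward)          # symmetric reward clipping bound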
12 changes: 7 additions & 5 deletions trlx/orchestrator/ppo_orchestrator.py
@@ -47,8 +47,8 @@ def __init__(
         self.rl_model.metric_fn = metric_fn

         self.running = RunningMoments()
-        self.ref_mean = None
-        self.ref_std = None
+        self.ref_mean = self.rl_model.config.method.ref_mean
+        self.ref_std = self.rl_model.config.method.ref_std

     def score(self, samples):
         """
@@ -84,19 +84,21 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):
         scores = torch.as_tensor(self.score(texts), device=samples.device)
         stats["exp_score_time"] = time() - exp_score_time

         # store statistics of the initial rollout as reference
         if self.ref_mean is None:
             self.ref_mean, self.ref_std = scores.mean(), scores.std()
         all_scores_mean, all_scores_std = self.running.update(scores)

         stats["exp_scores_mean"] = all_scores_mean
         stats["exp_scores_std"] = all_scores_std
         stats["running_mean"] = self.running.mean
         stats["running_std"] = self.running.std

-        if self.rl_model.config.method.scale_reward:
+        if self.rl_model.config.method.scale_reward == "running":
             scores /= self.running.std
+        elif self.rl_model.config.method.scale_reward == "ref":
+            scores /= self.ref_std

-        clip_reward = self.rl_model.config.method.clip_reward
+        clip_reward = self.rl_model.config.method.cliprange_reward
         if clip_reward:
             scores = torch.clip(scores, -clip_reward, clip_reward)

9 changes: 5 additions & 4 deletions trlx/utils/modeling.py
@@ -91,12 +91,13 @@ def update(self, xs: torch.Tensor) -> Tuple[float, float]:
         delta = xs_mean - self.mean
         tot_count = self.count + xs_count

-        m_a = self.var * self.count
-        m_b = xs_var * xs_count
-        m_2 = m_a + m_b + delta**2 * self.count * xs_count / tot_count
+        new_sum = xs_var * xs_count
+        # correct old_sum deviation accounting for the new mean
+        old_sum = self.var * self.count + delta**2 * self.count * xs_count / tot_count
+        tot_sum = old_sum + new_sum

         self.mean += delta * xs_count / tot_count
-        self.var = m_2 / tot_count
+        self.var = tot_sum / tot_count
         self.std = (self.var * tot_count / (tot_count - 1)).sqrt()
         self.count = tot_count

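The rewritten update is the standard parallel formula (Chan et al.) for merging the moments of two groups. In the notation of the diff, with the accumulated buffer holding count n_a, mean mu_a, (population) variance sigma_a^2 and the incoming batch holding n_b, mu_b, sigma_b^2:

$$
\begin{aligned}
\delta &= \mu_b - \mu_a, \qquad n = n_a + n_b, \qquad \mu = \mu_a + \delta\,\frac{n_b}{n},\\
M &= \underbrace{n_a\sigma_a^2 + \delta^2\,\frac{n_a n_b}{n}}_{\text{old\_sum}} \;+\; \underbrace{n_b\sigma_b^2}_{\text{new\_sum}}, \qquad \sigma^2 = \frac{M}{n}, \qquad s = \sqrt{\sigma^2\,\frac{n}{n-1}}.
\end{aligned}
$$

So tot_sum is the merged sum of squared deviations, self.var the population variance, and self.std the unbiased sample estimate that the new test compares against std(unbiased=True).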
