-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathucb.py
35 lines (30 loc) · 956 Bytes
/
ucb.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from .randmax import randmax
import numpy as np
from .base_mab import BaseMAB
class UCB(BaseMAB):
    """UCB1 index policy with exploration parameter alpha.

    Every arm is played once first; afterwards the arm handed back by
    ``randmax`` over the indices
    ``cumRewards / nbDraws + sqrt(alpha * log(t) / nbDraws)`` is played.
    NOTE(review): ``randmax`` is defined in another module -- presumably an
    argmax with random tie-breaking; confirm against its implementation.

    Parameters
    ----------
    nbArms : int
        Number of arms of the bandit.
    alpha : float, default 2
        Multiplier of ``log(t)`` inside the exploration bonus; a larger
        value yields a larger bonus for every arm.
    """

    def __init__(self, nbArms, alpha=2):
        self.nbArms = nbArms
        self.alpha = alpha
        self.Best = 0
        self.clear()

    def clear(self):
        """Reset all per-arm statistics and the round counter to zero."""
        self.nbDraws = np.zeros(self.nbArms)
        self.cumRewards = np.zeros(self.nbArms)
        self.arm_means = np.zeros(self.nbArms)
        self.t = 0

    def chooseArmToPlay(self):
        """Return the index of the arm to play in the current round."""
        # Initialisation phase: play the first arm that has no draw yet.
        # When rewards are reported for the arm that was just chosen this is
        # arm ``self.t`` (the historical behaviour); checking ``nbDraws``
        # directly also keeps the index below free of 0/0 and x/0 when
        # rewards were reported in any other order.
        undrawn = np.flatnonzero(self.nbDraws == 0)
        if undrawn.size:
            return int(undrawn[0])

        empirical_means = self.cumRewards / self.nbDraws
        exploration_bonus = np.sqrt(self.alpha * np.log(self.t) / self.nbDraws)
        return randmax(empirical_means + exploration_bonus)

    def receiveReward(self, arm, reward):
        """Record ``reward`` as the outcome of one play of ``arm``.

        Parameters
        ----------
        arm : int
            Index of the arm that was played.
        reward : float
            Reward observed for that play.
        """
        self.t += 1
        self.cumRewards[arm] += reward
        self.nbDraws[arm] += 1
        # Keep the stored empirical mean in sync; it was previously left at
        # its initial zero forever.
        self.arm_means[arm] = self.cumRewards[arm] / self.nbDraws[arm]