-
Notifications
You must be signed in to change notification settings - Fork 18
/
Copy pathDPG.py
90 lines (70 loc) · 2.75 KB
/
DPG.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import sys
import random
import numpy as np
from peloton_bloomfilters import BloomFilter
from numpy_ringbuffer import RingBuffer
class DPG:
    """Track generated guesses against an attacked set and drive latent sampling.

    Every guess passed to the instance is de-duplicated through a Bloom filter.
    Guesses that are new and present in the attacked set are counted as matches
    and removed from that set. Unless the instance is static, the latent vector
    of each match is stored in a ring buffer; once the matched fraction exceeds
    ``hot_start`` the instance switches to "dynamic" mode, in which ``guess``
    samples new latents around previously matched ones instead of around zero.
    """

    # Capacity of the ring buffer that stores latents of matched guesses.
    memory_bank_size = 10 ** 6

    def __init__(self, latent_size, stddv_p, stddv, hot_start, N, batch_size, accuracy_bloomf=0.01, static=False):
        """Set up counters and, unless ``static``, the latent memory bank.

        Args:
            latent_size: Width of one latent vector.
            stddv_p: Scale of the zero-mean normal used before dynamic mode.
            stddv: Scale of the normal centred on stored latents in dynamic mode.
            hot_start: Matched-fraction threshold that switches dynamic mode on.
            N: First argument given to ``BloomFilter`` (presumably its expected
                capacity -- confirm against the bloom filter library).
            batch_size: Number of latents drawn per ``guess`` call.
            accuracy_bloomf: Second argument given to ``BloomFilter``
                (presumably its error rate -- confirm).
            static: When true, no latents are stored and dynamic mode never starts.
        """
        self.i = 0  # number of unique guesses seen so far
        self.I = 0  # total number of calls to the instance
        self.BI = 1  # value appended to ``P`` for every stored latent
        self.DYNAMIC = False
        self.FLAG = False  # set to True once a previously unseen guess arrives
        self.hot_start = hot_start
        self.latent_size = latent_size
        self.stddv_p = stddv_p
        self.stddv = stddv
        self.LOG = False
        # Logging period used by ``__call__``; callers may also overwrite the
        # attribute directly.
        self.log_fq = 1
        self.STATIC = static
        self.init_att_size = 0
        self.N = N
        self.batch_size = batch_size
        self.accuracy_bloomf = accuracy_bloomf
        self.matched_i = 0
        self.P = []
        # Created lazily on the first call; ``None`` means "not created yet".
        self.guesses = None
        if not self.STATIC:
            self.guessed_z = RingBuffer(capacity=self.memory_bank_size, dtype=(np.float32, self.latent_size))

    def enable_logging(self, log_fq=None):
        """Turn on logging and reset both log lists.

        Args:
            log_fq: Optional logging period. A unique-guess entry is recorded
                only on calls whose running index is a multiple of this value.
                When omitted, the current ``self.log_fq`` is kept.

        Raises:
            ValueError: If ``log_fq`` is given and is smaller than 1.
        """
        if log_fq is not None:
            if log_fq < 1:
                raise ValueError("log_fq must be a positive integer")
            self.log_fq = log_fq
        print("LOGGING!")
        self.LOG = True
        self.log_unique = []
        self.log_guessed = []

    def _matched_ratio(self):
        """Return matched guesses over the initial attacked-set size (0.0 if that size is 0)."""
        if not self.init_att_size:
            return 0.0
        return self.matched_i / self.init_att_size

    def __call__(self, z, x, attacked_set):
        """Register one guess and report whether it is new.

        Args:
            z: Latent associated with the guess; stored when the guess matches
                and the instance is not static.
            x: The guess; must be usable with the Bloom filter and with
                ``attacked_set`` membership tests.
            attacked_set: Mutable collection of targets. Matched guesses are
                removed from it in place.

        Returns:
            ``x`` if the guess had not been seen before, otherwise ``None``.
        """
        new = None
        self.I += 1

        # Lazy one-time setup. Keyed on the filter itself rather than on the
        # recorded size, so an initially empty attacked set cannot cause the
        # filter (and with it the guess history) to be rebuilt on every call.
        if self.guesses is None:
            self.guesses = BloomFilter(self.N, self.accuracy_bloomf)
            self.init_att_size = len(attacked_set)

        # NOTE(review): the nesting below was reconstructed from a copy of the
        # file that had lost its indentation -- confirm against the upstream source.
        if x not in self.guesses:
            self.guesses.add(x)
            self.i += 1
            self.FLAG = True
            new = x

            if self.LOG and not self.I % self.log_fq:
                self.log_unique.append((self.I, self.i, self.DYNAMIC))

            # Matched
            if x in attacked_set:
                self.matched_i += 1
                attacked_set.remove(x)

                if not self.STATIC:
                    self.guessed_z.append(z)
                    self.P.append(self.BI)

                matched_ratio = self._matched_ratio()

                if self.LOG:
                    self.log_guessed.append((self.I, self.i, x, matched_ratio, self.DYNAMIC))

                if not self.STATIC and not self.DYNAMIC and matched_ratio > self.hot_start:
                    print("DYNAMIC starts now ....")
                    self.DYNAMIC = True

        return new

    def guess(self, z_ph, G, sess):
        """Draw a batch of latents and run them through ``G``.

        In dynamic mode with at least one stored latent, each latent of the
        batch is sampled from a normal centred on a randomly picked stored
        latent with scale ``stddv``. Otherwise the batch is sampled from a
        zero-mean normal with scale ``stddv_p``.

        Args:
            z_ph: Key under which the latent batch is fed to ``sess.run``.
            G: Object evaluated by ``sess.run``.
            sess: Object exposing ``run(fetches, feed_dict)``.

        Returns:
            Tuple ``(z, x)`` of the sampled latents and the result of ``sess.run``.
        """
        if self.DYNAMIC and len(self.guessed_z):
            idxs = np.random.randint(0, len(self.guessed_z), self.batch_size, np.int32)
            centres = self.guessed_z[idxs]
            z = np.random.normal(centres, self.stddv)
        else:
            z = np.random.normal(0, scale=self.stddv_p, size=(self.batch_size, self.latent_size))
        x = sess.run(G, {z_ph: z})
        return z, x