-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathgroup.py
29 lines (26 loc) · 1009 Bytes
/
group.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import numpy as np
import pandas as pd
from hmm import GroupLevelHMM
from membership import MembershipVector
from trajectory import Trajectory
class Group:
    """One group: an HMM together with the membership and trajectory data it updates.

    Holds references to a group-level model, the shared membership object,
    the per-user trajectory store, and this group's identifier.
    """

    def __init__(self, hmm, membership, trajectory, groupId):
        # Model whose ``score(trajectory)`` values are summed and then
        # exponentiated in ``update`` -- i.e. treated as log-likelihoods.
        self._hmm = hmm
        # Membership object providing ``userList``, ``getMeanProbByGroup``
        # and ``setProbByGroupUser``; ``update`` mutates it in place.
        self._membership = membership
        # Trajectory store: ``getTrajectoryByUser(userId)`` returns an
        # iterable of that user's trajectories.
        self._trajectory = trajectory
        # Identifier of this group, passed through to the membership calls.
        self._groupId = groupId

    def update(self):
        """Recompute every user's membership weight for this group.

        For each user ``u`` in ``membership.userList`` the value written via
        ``setProbByGroupUser`` is::

            p_g * exp(sum(hmm.score(x) for x in trajectories(u)))

        where ``p_g = membership.getMeanProbByGroup(groupId)``. The product is
        formed in log space and exponentiated once per user. The weight is
        not normalized across groups by this method.

        Returns:
            The same membership object, after it has been updated in place.

        Raises:
            ValueError: if ``p_g`` is greater than 1, negative, or NaN.
        """
        membership = self._membership
        p_g = membership.getMeanProbByGroup(self._groupId)
        if p_g > 1:
            raise ValueError("Probability cannot be bigger than 1")
        # NaN compares False against everything, so it slips past the check
        # above, and a negative value makes np.log return NaN. Either way NaN
        # would be written into every user's entry -- reject both up front.
        if np.isnan(p_g) or p_g < 0:
            raise ValueError("Probability cannot be negative or NaN")

        # Same for every user, so compute it once. p_g == 0 gives -inf here
        # (numpy emits a divide-by-zero RuntimeWarning), so the weight becomes 0.
        log_p_g = np.log(p_g)

        for user_id in membership.userList:
            trajectories = self._trajectory.getTrajectoryByUser(user_id)
            # Sum of per-trajectory scores; 0 when the user has none.
            log_likelihood = sum(self._hmm.score(t) for t in trajectories)
            # NOTE(review): exp of a summed log-likelihood underflows to 0.0
            # for long or many trajectories; normalizing across groups in log
            # space (log-sum-exp) would need data outside this class.
            p_guH = np.exp(log_likelihood + log_p_g)
            membership.setProbByGroupUser(p_guH, user_id, self._groupId)
        return membership