-
-
Notifications
You must be signed in to change notification settings - Fork 1
/
attrib_tuned_model.py
39 lines (28 loc) · 1.32 KB
/
attrib_tuned_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from typing import NamedTuple
import numpy as np
import tensorflow as tf
from config import ATTRIB_CONFIDENCE, ATTRIB_MODEL_PATH
from crossing_type import CrossingType
class AttribClassification(NamedTuple):
    """Result of classifying one sample: validity flag, crossing type and score."""

    # Whether the prediction is considered usable by the producer of this record.
    is_valid: bool
    # Crossing category attached to the prediction.
    crossing_type: CrossingType
    # Score associated with the prediction, as a plain Python float.
    confidence: float
class AttribTunedModel:
    """Attribute model with its tuned weights loaded, exposing single-sample prediction.

    The network is created by ``attrib_model.get_attrib_model`` and its weights
    are restored from ``ATTRIB_MODEL_PATH`` once, when the instance is built.
    """

    def __init__(self):
        """Create the network and load the tuned weights from ``ATTRIB_MODEL_PATH``."""
        # NOTE(review): imported here rather than at module level -- presumably
        # to defer a heavy or circular import; confirm before hoisting it.
        from attrib_model import get_attrib_model

        self._model = get_attrib_model()
        self._model.load_weights(str(ATTRIB_MODEL_PATH))

    def predict_single(
        self,
        X: np.ndarray,
        min_confidence: float = ATTRIB_CONFIDENCE,
    ) -> AttribClassification:
        """Classify a single sample.

        Args:
            X: One un-batched sample. A leading axis of size one is added
                before the sample is handed to the model.
            min_confidence: The result is marked valid only when the model
                score is strictly greater than this threshold.

        Returns:
            An ``AttribClassification`` holding the validity flag, the crossing
            type (currently always ``CrossingType.UNKNOWN``) and the score.

        Raises:
            AssertionError: If the model does not produce exactly one value
                for the sample.
        """
        with tf.device('/CPU:0'):  # force CPU to better understand real performance
            # Index 0 drops the batch axis that was added just for this call.
            pred_proba: np.ndarray = self._model.predict(X[np.newaxis, ...], verbose=0)[0]

        # Raised explicitly instead of via an ``assert`` statement so the check
        # still runs under ``python -O``; the exception type is unchanged.
        if len(pred_proba) != 1:
            raise AssertionError(
                f"expected exactly one model output, got {len(pred_proba)}"
            )

        confidence = float(pred_proba[0])
        is_valid = confidence > min_confidence

        # The model emits a single score, so no crossing type is derived from
        # it here; every result is reported as UNKNOWN.
        # TODO: map predictions to UNCONTROLLED / TRAFFIC_SIGNALS once the
        # model provides the outputs needed to tell them apart.
        crossing_type = CrossingType.UNKNOWN

        return AttribClassification(is_valid, crossing_type, confidence)