Skip to content

Commit

Permalink
CHEVROLET BOLT EUV 2022: Add a simple neural feed forward (commaai#31266
Browse files Browse the repository at this point in the history
)

* add simple neural feed forward

* update refs

* do not sample during inference in op

* update refs
  • Loading branch information
nuwandavek authored Feb 1, 2024
1 parent 628c829 commit 6196256
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 3 deletions.
19 changes: 17 additions & 2 deletions selfdrive/car/gm/interface.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
#!/usr/bin/env python3
import os
from cereal import car
from math import fabs, exp
from panda import Panda

from openpilot.common.basedir import BASEDIR
from openpilot.common.conversions import Conversions as CV
from openpilot.selfdrive.car import create_button_events, get_safety_config
from openpilot.selfdrive.car.gm.radar_interface import RADAR_HEADER_MSG
from openpilot.selfdrive.car.gm.values import CAR, CruiseButtons, CarControllerParams, EV_CAR, CAMERA_ACC_CAR, CanBus
from openpilot.selfdrive.car.interfaces import CarInterfaceBase, TorqueFromLateralAccelCallbackType, FRICTION_THRESHOLD, LatControlInputs
from openpilot.selfdrive.car.interfaces import CarInterfaceBase, TorqueFromLateralAccelCallbackType, FRICTION_THRESHOLD, LatControlInputs, NanoFFModel
from openpilot.selfdrive.controls.lib.drive_helpers import get_friction

ButtonType = car.CarState.ButtonEvent.Type
Expand All @@ -25,6 +27,8 @@
CAR.SILVERADO: [3.29974374, 1.0, 0.25571356, 0.0465122]
}

NEURAL_PARAMS_PATH = os.path.join(BASEDIR, 'selfdrive/car/torque_data/neural_ff_weights.json')


class CarInterface(CarInterfaceBase):
@staticmethod
Expand Down Expand Up @@ -61,8 +65,19 @@ def sig(val):
steer_torque = (sig(latcontrol_inputs.lateral_acceleration * a) * b) + (latcontrol_inputs.lateral_acceleration * c)
return float(steer_torque) + friction

def torque_from_lateral_accel_neural(self, latcontrol_inputs: LatControlInputs, torque_params: car.CarParams.LateralTorqueTuning, lateral_accel_error: float,
lateral_accel_deadzone: float, friction_compensation: bool, gravity_adjusted: bool) -> float:
friction = get_friction(lateral_accel_error, lateral_accel_deadzone, FRICTION_THRESHOLD, torque_params, friction_compensation)
inputs = list(latcontrol_inputs)
if gravity_adjusted:
inputs[0] += inputs[1]
return float(self.neural_ff_model.predict(inputs)) + friction

def torque_from_lateral_accel(self) -> TorqueFromLateralAccelCallbackType:
if self.CP.carFingerprint in NON_LINEAR_TORQUE_PARAMS:
if self.CP.carFingerprint == CAR.BOLT_EUV:
self.neural_ff_model = NanoFFModel(NEURAL_PARAMS_PATH, self.CP.carFingerprint)
return self.torque_from_lateral_accel_neural
elif self.CP.carFingerprint in NON_LINEAR_TORQUE_PARAMS:
return self.torque_from_lateral_accel_siglin
else:
return self.torque_from_lateral_accel_linear
Expand Down
33 changes: 33 additions & 0 deletions selfdrive/car/interfaces.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import os
import time
import numpy as np
Expand Down Expand Up @@ -473,3 +474,35 @@ def get_interface_attr(attr: str, combine_brands: bool = False, ignore_none: boo
pass

return result


class NanoFFModel:
def __init__(self, weights_loc: str, platform: str):
self.weights_loc = weights_loc
self.platform = platform
self.load_weights(platform)

def load_weights(self, platform: str):
with open(self.weights_loc, 'r') as fob:
self.weights = {k: np.array(v) for k, v in json.load(fob)[platform].items()}

def relu(self, x: np.ndarray):
return np.maximum(0.0, x)

def forward(self, x: np.ndarray):
assert x.ndim == 1
x = (x - self.weights['input_norm_mat'][:, 0]) / (self.weights['input_norm_mat'][:, 1] - self.weights['input_norm_mat'][:, 0])
x = self.relu(np.dot(x, self.weights['w_1']) + self.weights['b_1'])
x = self.relu(np.dot(x, self.weights['w_2']) + self.weights['b_2'])
x = self.relu(np.dot(x, self.weights['w_3']) + self.weights['b_3'])
x = np.dot(x, self.weights['w_4']) + self.weights['b_4']
return x

def predict(self, x: List[float], do_sample: bool = False):
x = self.forward(np.array(x))
if do_sample:
pred = np.random.laplace(x[0], np.exp(x[1]) / self.weights['temperature'])
else:
pred = x[0]
pred = pred * (self.weights['output_norm_mat'][1] - self.weights['output_norm_mat'][0]) + self.weights['output_norm_mat'][0]
return pred
1 change: 1 addition & 0 deletions selfdrive/car/torque_data/neural_ff_weights.json
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
{"CHEVROLET BOLT EUV 2022": {"w_1": [[0.3452189564704895, -0.15614677965641022, -0.04062516987323761, -0.5960758328437805, 0.3211185932159424, 0.31732726097106934, -0.04430829733610153, -0.37327295541763306, -0.14118380844593048, 0.12712529301643372, 0.2641555070877075, -0.3451094627380371, -0.005127656273543835, 0.6185108423233032, 0.03725295141339302, 0.3763789236545563], [-0.0708412230014801, 0.3667356073856354, 0.031383827328681946, 0.1740853488445282, -0.04695861041545868, 0.018055908381938934, 0.009072160348296165, -0.23640218377113342, -0.10362917929887772, 0.022628149017691612, -0.224413201212883, 0.20718418061733246, -0.016947750002145767, -0.3872031271457672, -0.15500062704086304, -0.06375953555107117], [-0.0838046595454216, -0.0242826659232378, -0.07765661180019379, 0.028858814388513565, -0.09516210108995438, 0.008368706330657005, 0.1689300835132599, 0.015036891214549541, -0.15121428668498993, 0.1388195902109146, 0.11486363410949707, 0.0651545450091362, 0.13559958338737488, 0.04300367832183838, -0.13856294751167297, -0.058136988431215286], [-0.006249868310987949, 0.08809533715248108, -0.040690965950489044, 0.02359287068247795, -0.00766348373144865, 0.24816390872001648, -0.17360293865203857, -0.03676899895071983, -0.17564819753170013, 0.18998438119888306, -0.050583917647600174, -0.006488069426268339, 0.10649101436138153, -0.024557121098041534, -0.103276826441288, 0.18448011577129364]], "b_1": [0.2935388386249542, 0.10967712104320526, -0.014007942751049995, 0.211833655834198, 0.33605605363845825, 0.37722209095954895, -0.16615016758441925, 0.3134673535823822, 0.06695777177810669, 0.3425212800502777, 0.3769673705101013, 0.23186539113521576, 0.5770409107208252, -0.05929069593548775, 0.01839117519557476, 0.03828774020075798], "w_2": [[-0.06261160969734192, 0.010185074992477894, -0.06083013117313385, -0.04531499370932579, -0.08979734033346176, 0.3432150185108185, -0.019801849499344826, 0.3010321259498596], [0.19698476791381836, -0.009238275699317455, 0.08842222392559052, -0.09516377002000809, -0.05022778362035751, 0.13626104593276978, -0.052890390157699585, 0.15569131076335907], [0.0724768117070198, -0.09018408507108688, 0.06850195676088333, -0.025572121143341064, 0.0680626779794693, -0.07648195326328278, 0.07993496209383011, -0.059752143919467926], [1.267876386642456, -0.05755887180566788, -0.08429178595542908, 0.021366603672504425, -0.0006479775765910745, -1.4292563199996948, -0.08077696710824966, -1.414825439453125], [0.04535430669784546, 0.06555880606174469, -0.027145234867930412, -0.07661093026399612, -0.05702832341194153, 0.23650476336479187, 0.0024587824009358883, 0.20126521587371826], [0.006042032968252897, 0.042880818247795105, 0.002187949838116765, -0.017126334831118584, -0.08352015167474747, 0.19801731407642365, -0.029196614399552345, 0.23713473975658417], [-0.01644900068640709, -0.04358499124646187, 0.014584392309188843, 0.07155826687812805, -0.09354910999536514, -0.033351872116327286, 0.07138452678918839, -0.04755295440554619], [-1.1012420654296875, -0.03534531593322754, 0.02167935110628605, -0.01116552110761404, -0.08436500281095505, 1.1038788557052612, 0.027903547510504723, 1.0676132440567017], [0.03843916580080986, -0.0952216386795044, 0.039226632565259933, 0.002778085647150874, -0.020275786519050598, -0.07848760485649109, 0.04803166165947914, 0.015538203530013561], [0.018385495990514755, -0.025189843028783798, 0.0036680365446954966, -0.02105865254998207, 0.04808586835861206, 0.1575016975402832, 0.02703506126999855, 0.23039312660694122], [-0.0033881019335240126, -0.10210853815078735, -0.04877309128642082, 0.006989633198827505, 0.046798162162303925, 0.38676899671554565, -0.032304272055625916, 0.2345031052827835], [0.22092825174331665, -0.09642873704433441, 0.04499409720301628, 0.05108088254928589, -0.10191166400909424, 0.12818090617656708, -0.021021494641900063, 0.09440375864505768], [0.1212429478764534, -0.028194155544042587, -0.0981956496834755, 0.08226924389600754, 0.055346839129924774, 0.27067816257476807, -0.09064067900180817, 0.12580905854701996], [-1.6740131378173828, -0.02066155895590782, -0.05924689769744873, 0.06347910314798355, -0.07821853458881378, 1.2807466983795166, 0.04589352011680603, 1.310766577720642], [-0.09893272817134857, -0.04093599319458008, -0.02502273954451084, 0.09490344673395157, -0.0211324505507946, -0.09021010994911194, 0.07936318963766098, -0.03593116253614426], [-0.08490308374166489, -0.015558987855911255, -0.048692114651203156, -0.007421435788273811, -0.040531404316425323, 0.25889304280281067, 0.06012800335884094, 0.27946868538856506]], "b_2": [0.07973937690258026, -0.010446485131978989, -0.003066520905122161, -0.031895797699689865, 0.006032303906977177, 0.24106740951538086, -0.008969511836767197, 0.2872662842273712], "w_3": [[-1.364486813545227, -0.11682678014039993, 0.01764785870909691, 0.03926877677440643], [-0.05695437639951706, 0.05472218990325928, 0.1266128271818161, 0.09950875490903854], [0.11415273696184158, -0.10069356113672256, 0.0864749327301979, -0.043946366757154465], [-0.10138195008039474, -0.040128443390131, -0.08937158435583115, -0.0048376512713730335], [-0.0028251828625798225, -0.04743027314543724, 0.06340016424655914, 0.07277824729681015], [0.49482327699661255, -0.06410001963376999, -0.0999293103814125, -0.14250673353672028], [0.042802367359399796, 0.0015462725423276424, -0.05991362780332565, 0.1022040992975235], [0.3523194193840027, 0.07343732565641403, 0.04157765582203865, -0.12358107417821884]], "b_3": [0.2653026282787323, -0.058485131710767746, -0.0744510293006897, 0.012550175189971924], "w_4": [[0.5988775491714478, 0.09668736904859543], [-0.04360569268465042, 0.06491032242774963], [-0.11868984252214432, -0.09601487964391708], [-0.06554870307445526, -0.14189276099205017]], "b_4": [-0.08148707449436188, -2.8251802921295166], "input_norm_mat": [[-3.0, 3.0], [-3.0, 3.0], [0.0, 40.0], [-3.0, 3.0]], "output_norm_mat": [-1.0, 1.0], "temperature": 100.0}}
2 changes: 1 addition & 1 deletion selfdrive/test/process_replay/ref_commit
Original file line number Diff line number Diff line change
@@ -1 +1 @@
eaab6bd55c5eab33fc9a0d8de8289b912e923887
d9a3a0d4e806b49ec537233d30926bec70308485

0 comments on commit 6196256

Please sign in to comment.