From 54f404509009e7f76c0faccf0f2b358dabb9bf87 Mon Sep 17 00:00:00 2001 From: Eric Vin Date: Fri, 24 Jan 2025 18:14:28 -0800 Subject: [PATCH] Fixed subtle testing bug --- examples/contracts/dev.contract | 3 ++- examples/contracts/highway.scenic | 3 ++- src/scenic/contracts/testing.py | 8 +++++--- src/scenic/contracts/utils.py | 4 +++- src/scenic/core/simulators.py | 6 +++++- 5 files changed, 17 insertions(+), 7 deletions(-) diff --git a/examples/contracts/dev.contract b/examples/contracts/dev.contract index 44a7e89b4..44cc70f6f 100644 --- a/examples/contracts/dev.contract +++ b/examples/contracts/dev.contract @@ -5,6 +5,7 @@ import random from typing import Union import numpy +import scenic from scenic.core.geometry import normalizeAngle from scenic.domains.driving.controllers import PIDLongitudinalController, PIDLateralController from scenic.domains.driving.actions import RegulatedControlAction @@ -396,7 +397,7 @@ cs_safety = compose over car.cs: max_brake=MAX_BRAKE, max_accel=MAX_ACCEL) using LeanContractProof(localPath("ScenicLean/"), "SafeThrottleFilter", localPath("repl")) -max_braking_force = assume car.cac satisfies MaxBrakingForce(MAX_BRAKE), with correctness 0.99, with confidence 0.99 +max_braking_force = assume car.cac satisfies MaxBrakingForce(MAX_BRAKE), with correctness 0.9, with confidence 0.95 keeps_distance_raw = compose over car: use accurate_distance diff --git a/examples/contracts/highway.scenic b/examples/contracts/highway.scenic index b185347c2..a063fa81c 100644 --- a/examples/contracts/highway.scenic +++ b/examples/contracts/highway.scenic @@ -32,7 +32,8 @@ class EgoCar(Car): targetDir[dynamic, final]: float(roadDirection[self.position].yaw) ego = new EgoCar at roadDirection.followFrom(toVector(leadCar), -STARTING_DISTANCE, stepSize=0.1), - with leadDist STARTING_DISTANCE, with behavior FollowLaneBehavior(), with name "EgoCar", with timestep 0.1 + with leadDist STARTING_DISTANCE, + with behavior FollowLaneBehavior(), with name "EgoCar", with timestep 0.1 # Create/activate monitor to store lead distance monitor UpdateDistance(tailCar, leadCar): diff --git a/src/scenic/contracts/testing.py b/src/scenic/contracts/testing.py index b028be7a0..f1db7fa12 100644 --- a/src/scenic/contracts/testing.py +++ b/src/scenic/contracts/testing.py @@ -216,13 +216,17 @@ def testScene(self, scene): # Instantiate simulator simulator = self.scenario.getSimulator() + def vw_update_hook(): + for vw_name, vw in base_value_windows.items(): + vw.update() + # Step contract till termination with simulator.simulateStepped(scene, maxSteps=self.maxSteps) as simulation: while not simulation.result or eval_step < sim_step: # If simulation not terminated, advance simulation one time step, catching any rejections if not simulation.result: try: - simulation.advance() + simulation.advance(vw_update_hook) except ( RejectSimulationException, RejectionException, @@ -237,8 +241,6 @@ def testScene(self, scene): continue # If the simulation didn't finish, update all base value windows - for vw_name, vw in base_value_windows.items(): - vw.update() # Increment simulation step sim_step += 1 diff --git a/src/scenic/contracts/utils.py b/src/scenic/contracts/utils.py index 9e21c52b8..eaf9e2db7 100644 --- a/src/scenic/contracts/utils.py +++ b/src/scenic/contracts/utils.py @@ -125,7 +125,9 @@ def leadDistance(source, target, network, maxDistance=250): # print(source.position, target.position, hash(network)) # Find all lanes this point could be a part of and recurse on them. - viable_lanes = [lane for lane in network.lanes if lane.containsPoint(source.position)] + viable_lanes = [ + lane for lane in network.lanes if lane.containsPoint(toVector(source)) + ] return min( min( diff --git a/src/scenic/core/simulators.py b/src/scenic/core/simulators.py index 3f5c3787d..841f15532 100644 --- a/src/scenic/core/simulators.py +++ b/src/scenic/core/simulators.py @@ -440,7 +440,7 @@ def _run(self): if self.terminationType: return - def advance(self): + def advance(self, preActionHook=None): if self.terminationType or self._cleaned: raise TerminatedSimulationException() @@ -530,6 +530,10 @@ def advance(self): # Log lastActions agent.lastActions = actions + # Run the preActionHook if provided + if preActionHook is not None: + preActionHook() + # Execute the actions if self.verbosity >= 3: for agent, actions in allActions.items():