diff --git a/jvm/src/main/scala/async/VThreadSupport.scala b/jvm/src/main/scala/async/VThreadSupport.scala index a547332e..83153146 100644 --- a/jvm/src/main/scala/async/VThreadSupport.scala +++ b/jvm/src/main/scala/async/VThreadSupport.scala @@ -3,6 +3,7 @@ package gears.async import scala.annotation.unchecked.uncheckedVariance import java.util.concurrent.locks.ReentrantLock import scala.concurrent.duration.FiniteDuration +import java.lang.invoke.{VarHandle, MethodHandles} object VThreadScheduler extends Scheduler: private val VTFactory = Thread @@ -12,12 +13,27 @@ object VThreadScheduler extends Scheduler: override def execute(body: Runnable): Unit = VTFactory.newThread(body) - override def schedule(delay: FiniteDuration, body: Runnable): Cancellable = + override def schedule(delay: FiniteDuration, body: Runnable): Cancellable = ScheduledRunnable(delay, body) + + private class ScheduledRunnable(val delay: FiniteDuration, val body: Runnable) extends Cancellable { + @volatile var interruptGuard = true // to avoid interrupting the body + val th = VTFactory.newThread: () => - Thread.sleep(delay.toMillis) - execute(body) + try Thread.sleep(delay.toMillis) + catch case e: InterruptedException => () /* we got cancelled, don't propagate */ + if ScheduledRunnable.interruptGuardVar.getAndSet(this, false) then body.run() th.start() - () => th.interrupt() + + final override def cancel(): Unit = + if ScheduledRunnable.interruptGuardVar.getAndSet(this, false) then th.interrupt() + } + + private object ScheduledRunnable: + val interruptGuardVar = + MethodHandles + .lookup() + .in(classOf[ScheduledRunnable]) + .findVarHandle(classOf[ScheduledRunnable], "interruptGuard", classOf[Boolean]) object VThreadSupport extends AsyncSupport: diff --git a/shared/src/test/scala/SchedulerBehavior.scala b/shared/src/test/scala/SchedulerBehavior.scala new file mode 100644 index 00000000..9f9a98b8 --- /dev/null +++ b/shared/src/test/scala/SchedulerBehavior.scala @@ -0,0 +1,42 @@ +import gears.async.{Async, Future, Listener} +import gears.async.AsyncOperations.* +import gears.async.default.given +import concurrent.duration.DurationInt +import gears.async.Future.Promise +import scala.util.Success + +class SchedulerBehavior extends munit.FunSuite { + test("schedule cancellation works") { + Async.blocking: + var bodyRan = false + val cancellable = Async.current.scheduler.schedule(1.seconds, () => bodyRan = true) + + // cancel immediately + cancellable.cancel() + + sleep(1000) + assert(!bodyRan) + } + + test("schedule cancellation doesn't abort inner code") { + Async.blocking: + var bodyRan = false + val fut = Promise[Unit]() + val cancellable = Async.current.scheduler.schedule( + 50.milliseconds, + () => + fut.complete(Success(())) + Async.blocking: + sleep(500) + bodyRan = true + ) + + // cancel after body started running + fut.await + cancellable.cancel() + + sleep(1000) + + assert(bodyRan) + } +}