-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: map2 deep in pure functional parallelism
- Loading branch information
Showing
2 changed files
with
123 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
package fpinscala | ||
package chapter7 | ||
package par | ||
|
||
import java.util.concurrent.ExecutorService | ||
import java.util.concurrent.TimeUnit | ||
import java.util.concurrent.Callable | ||
import java.util.concurrent.Future | ||
import fpinscala.chapter3.lizt.Lizt | ||
|
||
object Par: | ||
opaque type Par[A] = ExecutorService => Future[A] | ||
|
||
private case class UnitFuture[A](get: A) extends Future[A]: | ||
def isDone = true | ||
def get(timeout: Long, units: TimeUnit): A = get | ||
def isCancelled = false | ||
def cancel(evenIfRunning: Boolean) = false | ||
|
||
def unit[A](a: A): Par[A] = es => UnitFuture(a) | ||
|
||
def fork[A](a: => Par[A]): Par[A] = | ||
es => | ||
es.submit(new Callable[A] { | ||
def call = a(es).get() | ||
}) | ||
|
||
def lazyUnit[A](a: => A): Par[A] = fork(unit(a)) | ||
|
||
def asyncF[A, B](f: A => B): A => Par[B] = | ||
a => lazyUnit(f(a)) | ||
|
||
extension [A](pa: Par[A]) def run(es: ExecutorService): Future[A] = pa(es) | ||
extension [A](pa: Par[A]) | ||
def map2[B, C](pb: Par[B])(f: (A, B) => C): Par[C] = | ||
(es: ExecutorService) => | ||
val futureA = pa(es) | ||
val futureB = pb(es) | ||
UnitFuture(f(futureA.get, futureB.get)) | ||
|
||
def mapErku[A, B, C](pa: Par[A], pb: Par[B])(f: (A, B) => C): Par[C] = | ||
(es: ExecutorService) => | ||
val futureA = pa(es) | ||
val futureB = pb(es) | ||
UnitFuture(f(futureA.get, futureB.get)) | ||
|
||
extension [A](pa: Par[A]) | ||
def map2WithTimeout[B, C](pb: Par[B])(f: (A, B) => C): Par[C] = | ||
(es: ExecutorService) => | ||
new Future[C]: | ||
private val futureA = pa(es) | ||
private val futureB = pb(es) | ||
@volatile private var cache: Option[C] = None | ||
|
||
def isDone(): Boolean = cache.isDefined | ||
def get(): C = get(Long.MaxValue, TimeUnit.NANOSECONDS) | ||
|
||
def get(timeout: Long, units: TimeUnit): C = | ||
val timenanos = TimeUnit.NANOSECONDS.convert(timeout, units) | ||
val started = System.nanoTime() | ||
val a = futureA.get(timenanos, TimeUnit.NANOSECONDS) | ||
val elapsed = System.nanoTime() - started | ||
val b = futureB.get(timeout - elapsed, TimeUnit.NANOSECONDS) | ||
val c = f(a, b) | ||
cache = Some(c) | ||
c | ||
|
||
def isCancelled(): Boolean = futureA.isCancelled() || futureB.isCancelled() | ||
|
||
def cancel(evenIfRunning: Boolean): Boolean = | ||
futureA.cancel(evenIfRunning) || futureB.cancel(evenIfRunning) | ||
|
||
def sum(ints: Lizt[Int]): Par[Int] = | ||
val elements = Lizt.length(ints) | ||
if elements < 2 then Par.unit(Lizt.headOpshn(ints).getOrElse(0)) | ||
else | ||
val (l, r) = Lizt.splitAt(ints, elements / 2) | ||
mapErku(fork(sum(l)), fork(sum(r)))(_ + _) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
package fpinscala | ||
package chapter7 | ||
package tests | ||
|
||
import org.scalatest.matchers.should._ | ||
import par._ | ||
import org.scalatest.freespec.AnyFreeSpec | ||
import java.util.concurrent.Executors | ||
import fpinscala.chapter3.lizt.Lizt | ||
|
||
class ParTests extends AnyFreeSpec with Matchers { | ||
private val es = Executors.newFixedThreadPool(5) | ||
|
||
"Par should" - { | ||
"provide a convenient interface for summing integers" in { | ||
Par.sum(Lizt(1, 2, 3, 4)).run(es).get() shouldBe 10 | ||
} | ||
|
||
"be able to" - { | ||
"promote a constant value to a parallel computation" in { | ||
Par.unit(1).run(es).get should be(1) | ||
} | ||
|
||
"combine the results of two parallel computations" in { | ||
Par.unit(11).map2(Par.unit(2))(_ * _).run(es).get should be(22) | ||
} | ||
|
||
"mark a computation for concurrent evaluation" in { | ||
pending | ||
} | ||
|
||
"lazily mark a computation for concurrent evaluation" in { | ||
pending | ||
} | ||
|
||
"actually perform a computation and provide its' value" in { | ||
pending | ||
} | ||
|
||
"wrap any function into a lazy blanket" in { | ||
pending | ||
} | ||
} | ||
} | ||
} |