Skip to content

Commit

Permalink
core: Fiber improvements and fixes (#743)
Browse files Browse the repository at this point in the history
  • Loading branch information
fwbrasil authored Oct 13, 2024
1 parent 114a85d commit 5bf22f6
Show file tree
Hide file tree
Showing 4 changed files with 340 additions and 107 deletions.
119 changes: 65 additions & 54 deletions kyo-core/shared/src/main/scala/kyo/Fiber.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@ import scala.util.NotGiven
import scala.util.control.NonFatal
import scala.util.control.NoStackTrace

opaque type Fiber[E, A] = IOPromise[E, A]
opaque type Fiber[+E, +A] = IOPromise[E, A]

object Fiber extends FiberPlatformSpecific:

inline given [E, A]: Flat[Fiber[E, A]] = Flat.unsafe.bypass

private val _unit = success(()).mask
private val _never = IOPromise[Nothing, Unit]().mask
private val _unit = IOPromise(Result.unit).mask()
private val _never = IOPromise[Nothing, Unit]().mask()

private[kyo] inline def fromTask[E, A](inline ioTask: IOTask[?, E, A]): Fiber[E, A] = ioTask

Expand Down Expand Up @@ -83,21 +83,8 @@ object Fiber extends FiberPlatformSpecific:
* @return
* A Fiber that completes with the result of the Future
*/
def fromFuture[A](f: Future[A])(using frame: Frame): Fiber[Throwable, A] < IO =
import scala.util.*
IO {
val p = new IOPromise[Throwable, A] with (Try[A] => Unit):
def apply(result: Try[A]) =
result match
case Success(v) =>
completeDiscard(Result.success(v))
case Failure(ex) =>
completeDiscard(Result.fail(ex))

f.onComplete(p)(ExecutionContext.parasitic)
p
}
end fromFuture
def fromFuture[A](future: => Future[A])(using frame: Frame): Fiber[Throwable, A] < IO =
IO.Unsafe(Unsafe.fromFuture(future))

private def result[E, A](result: Result[E, A]): Fiber[E, A] = IOPromise(result)

Expand Down Expand Up @@ -149,7 +136,7 @@ object Fiber extends FiberPlatformSpecific:
* @param f
* The callback function
*/
def onComplete(f: Result[E, A] => Unit < IO)(using Frame): Unit < IO =
def onComplete[E2 >: E, A2 >: A](f: Result[E2, A2] => Unit < IO)(using Frame): Unit < IO =
import AllowUnsafe.embrace.danger
IO(self.onComplete(r => IO.Unsafe.run(f(r)).eval))

Expand Down Expand Up @@ -183,13 +170,7 @@ object Fiber extends FiberPlatformSpecific:
* A Future that completes with the result of the Fiber
*/
def toFuture(using E <:< Throwable, Frame): Future[A] < IO =
IO {
val r = scala.concurrent.Promise[A]()
self.onComplete { v =>
r.complete(v.toTry)
}
r.future
}
IO.Unsafe(Unsafe.toFuture(self)())

/** Maps the result of the Fiber.
*
Expand All @@ -198,13 +179,9 @@ object Fiber extends FiberPlatformSpecific:
* @return
* A new Fiber with the mapped result
*/
def map[B](f: A => B)(using Frame): Fiber[E, B] < IO =
IO {
val p = new IOPromise[E, B](interrupts = self) with (Result[E, A] => Unit):
def apply(v: Result[E, A]) = completeDiscard(v.map(f))
self.onComplete(p)
p
}
def map[B: Flat](f: A => B < IO)(using Frame): Fiber[E, B] < IO =
import AllowUnsafe.embrace.danger
IO.Unsafe(Unsafe.map(self)((r => IO.Unsafe.run(f(r)).eval)))

/** Flat maps the result of the Fiber.
*
Expand All @@ -213,13 +190,9 @@ object Fiber extends FiberPlatformSpecific:
* @return
* A new Fiber with the flat mapped result
*/
def flatMap[E2, B](f: A => Fiber[E2, B])(using Frame): Fiber[E | E2, B] < IO =
IO {
val p = new IOPromise[E | E2, B](interrupts = self) with (Result[E, A] => Unit):
def apply(r: Result[E, A]) = r.fold(completeDiscard)(v => becomeDiscard(f(v)))
self.onComplete(p)
p
}
def flatMap[E2, B](f: A => Fiber[E2, B] < IO)(using Frame): Fiber[E | E2, B] < IO =
import AllowUnsafe.embrace.danger
IO.Unsafe(Unsafe.flatMap(self)(r => IO.Unsafe.run(f(r)).eval))

/** Maps the Result of the Fiber using the provided function.
*
Expand All @@ -231,13 +204,8 @@ object Fiber extends FiberPlatformSpecific:
* @return
* A new Fiber with the mapped Result
*/
def mapResult[E2, B](f: Result[E, A] => Result[E2, B])(using Frame): Fiber[E2, B] < IO =
IO {
val p = new IOPromise[E2, B](interrupts = self) with (Result[E, A] => Unit):
def apply(r: Result[E, A]) = completeDiscard(Result(f(r)).flatten)
self.onComplete(p)
p
}
def mapResult[E2, B](f: Result[E, A] => Result[E2, B] < IO)(using Frame): Fiber[E2, B] < IO =
IO.Unsafe(Unsafe.mapResult(self)(r => IO.Unsafe.run(f(r)).eval))

/** Creates a new Fiber that runs with interrupt masking.
*
Expand All @@ -248,7 +216,7 @@ object Fiber extends FiberPlatformSpecific:
* @return
* A new Fiber that runs with interrupt masking
*/
def mask(using Frame): Fiber[E, A] < IO = IO(self.mask)
def mask(using Frame): Fiber[E, A] < IO = IO.Unsafe(Unsafe.mask(self)())

/** Interrupts the Fiber.
*
Expand Down Expand Up @@ -375,15 +343,27 @@ object Fiber extends FiberPlatformSpecific:
end if
}

opaque type Unsafe[E, A] = IOPromise[E, A]
opaque type Unsafe[+E, +A] = IOPromise[E, A]

/** WARNING: Low-level API meant for integrations, libraries, and performance-sensitive code. See AllowUnsafe for more details. */
object Unsafe:
inline given [E, A]: Flat[Unsafe[E, A]] = Flat.unsafe.bypass

def init[E, A]()(using AllowUnsafe): Unsafe[E, A] = IOPromise()
def init[E, A](result: Result[E, A])(using AllowUnsafe): Unsafe[E, A] = IOPromise(result)

def fromPromise[E, A](p: Promise.Unsafe[E, A]): Unsafe[E, A] = p.safe
def fromFuture[A](f: => Future[A])(using AllowUnsafe): Unsafe[Throwable, A] =
import scala.util.*
val p = new IOPromise[Throwable, A] with (Try[A] => Unit):
def apply(result: Try[A]) =
result match
case Success(v) =>
completeDiscard(Result.success(v))
case Failure(ex) =>
completeDiscard(Result.fail(ex))

f.onComplete(p)(ExecutionContext.parasitic)
p
end fromFuture

extension [E, A](self: Unsafe[E, A])
def done()(using AllowUnsafe): Boolean = self.done()
Expand All @@ -392,11 +372,42 @@ object Fiber extends FiberPlatformSpecific:
def block(deadline: Clock.Deadline.Unsafe)(using AllowUnsafe, Frame): Result[E | Timeout, A] = self.block(deadline)
def interrupt(error: Panic)(using AllowUnsafe): Boolean = self.interrupt(error)
def interruptDiscard(error: Panic)(using AllowUnsafe): Unit = discard(self.interrupt(error))
def safe: Fiber[E, A] = self
def mask()(using AllowUnsafe): Unsafe[E, A] = self.mask()

def toFuture()(using E <:< Throwable, AllowUnsafe): Future[A] =
val r = scala.concurrent.Promise[A]()
self.onComplete { v =>
r.complete(v.toTry)
}
r.future
end toFuture

def map[B](f: A => B)(using AllowUnsafe): Unsafe[E, B] =
val p = new IOPromise[E, B](interrupts = self) with (Result[E, A] => Unit):
def apply(v: Result[E, A]) = completeDiscard(v.map(f))
self.onComplete(p)
p
end map

def flatMap[E2, B](f: A => Unsafe[E2, B])(using AllowUnsafe): Unsafe[E | E2, B] =
val p = new IOPromise[E | E2, B](interrupts = self) with (Result[E, A] => Unit):
def apply(r: Result[E, A]) = r.fold(completeDiscard)(v => becomeDiscard(f(v)))
self.onComplete(p)
p
end flatMap

def mapResult[E2, B](f: Result[E, A] => Result[E2, B])(using AllowUnsafe): Unsafe[E2, B] =
val p = new IOPromise[E2, B](interrupts = self) with (Result[E, A] => Unit):
def apply(r: Result[E, A]) = completeDiscard(Result(f(r)).flatten)
self.onComplete(p)
p
end mapResult

def safe: Fiber[E, A] = self
end extension
end Unsafe

opaque type Promise[E, A] <: Fiber[E, A] = IOPromise[E, A]
opaque type Promise[+E, +A] <: Fiber[E, A] = IOPromise[E, A]

object Promise:
inline given [E, A]: Flat[Promise[E, A]] = Flat.unsafe.bypass
Expand Down Expand Up @@ -446,7 +457,7 @@ object Fiber extends FiberPlatformSpecific:
def unsafe: Unsafe[E, A] = self
end extension

opaque type Unsafe[E, A] <: Fiber.Unsafe[E, A] = IOPromise[E, A]
opaque type Unsafe[+E, +A] <: Fiber.Unsafe[E, A] = IOPromise[E, A]

/** WARNING: Low-level API meant for integrations, libraries, and performance-sensitive code. See AllowUnsafe for more details. */
object Unsafe:
Expand Down
66 changes: 33 additions & 33 deletions kyo-core/shared/src/main/scala/kyo/scheduler/IOPromise.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import scala.annotation.tailrec
import scala.util.control.NonFatal
import scala.util.control.NoStackTrace

private[kyo] class IOPromise[E, A](init: State[E, A]) extends Safepoint.Interceptor:
private[kyo] class IOPromise[+E, +A](init: State[E, A]) extends Safepoint.Interceptor:

@volatile private var state: State[E, A] = init

Expand All @@ -21,7 +21,7 @@ private[kyo] class IOPromise[E, A](init: State[E, A]) extends Safepoint.Intercep
def removeFinalizer(f: () => Unit): Unit = {}
def enter(frame: Frame, value: Any): Boolean = true

private def cas[E2 <: E, A2 <: A](curr: State[E2, A2], next: State[E2, A2]): Boolean =
private def cas[E2 >: E, A2 >: A](curr: State[E2, A2], next: State[E2, A2]): Boolean =
if stateHandle eq null then
((isNull(state) && isNull(curr)) || state.equals(curr)) && {
state = next.asInstanceOf[State[E, A]]
Expand Down Expand Up @@ -62,7 +62,7 @@ private[kyo] class IOPromise[E, A](init: State[E, A]) extends Safepoint.Intercep
interruptsLoop(this)
end interrupts

final def mask: IOPromise[E, A] =
final def mask(): IOPromise[E, A] =
val p = new IOPromise[E, A]:
override def interrupt(error: Panic): Boolean = false
onComplete(p.completeDiscard)
Expand Down Expand Up @@ -91,7 +91,7 @@ private[kyo] class IOPromise[E, A](init: State[E, A]) extends Safepoint.Intercep
compressLoop(this)
end compress

final private def merge(p: Pending[E, A]): Unit =
final private def merge[E2 >: E, A2 >: A](p: Pending[E2, A2]): Unit =
@tailrec def mergeLoop(promise: IOPromise[E, A]): Unit =
promise.state match
case p2: Pending[E, A] @unchecked =>
Expand All @@ -104,10 +104,10 @@ private[kyo] class IOPromise[E, A](init: State[E, A]) extends Safepoint.Intercep
mergeLoop(this)
end merge

final def becomeDiscard[E2 <: E, A2 <: A](other: IOPromise[E2, A2]): Unit =
final def becomeDiscard[E2 >: E, A2 >: A](other: IOPromise[E2, A2]): Unit =
discard(become(other))

final def become[E2 <: E, A2 <: A](other: IOPromise[E2, A2]): Boolean =
final def become[E2 >: E, A2 >: A](other: IOPromise[E2, A2]): Boolean =
@tailrec def becomeLoop(other: IOPromise[E2, A2]): Boolean =
state match
case p: Pending[E2, A2] @unchecked =>
Expand Down Expand Up @@ -153,24 +153,24 @@ private[kyo] class IOPromise[E, A](init: State[E, A]) extends Safepoint.Intercep

protected def onComplete(): Unit = {}

final private def interrupt(p: Pending[E, A], v: Panic): Boolean =
final private def interrupt[E2 >: E, A2 >: A](p: Pending[E2, A2], v: Panic): Boolean =
cas(p, v) && {
onComplete()
p.flushInterrupt(v)
true
}

final private def complete(p: Pending[E, A], v: Result[E, A]): Boolean =
final private def complete[E2 >: E, A2 >: A](p: Pending[E2, A2], v: Result[E2, A2]): Boolean =
cas(p, v) && {
onComplete()
p.flush(v)
true
}

final def completeDiscard[E2 <: E, A2 <: A](v: Result[E2, A2]): Unit =
final def completeDiscard[E2 >: E, A2 >: A](v: Result[E2, A2]): Unit =
discard(complete(v))

final def complete[E2 <: E, A2 <: A](v: Result[E2, A2]): Boolean =
final def complete[E2 >: E, A2 >: A](v: Result[E2, A2]): Boolean =
@tailrec def completeLoop(): Boolean =
state match
case p: Pending[E, A] @unchecked =>
Expand Down Expand Up @@ -232,16 +232,16 @@ private[kyo] object IOPromise extends IOPromisePlatformSpecific:

case class Interrupt(origin: Frame) extends Exception with NoStackTrace

type State[E, A] = Result[E, A] | Pending[E, A] | Linked[E, A]
type State[+E, +A] = Result[E, A] | Pending[E, A] | Linked[E, A]

case class Linked[E, A](p: IOPromise[E, A])
case class Linked[+E, +A](p: IOPromise[E, A])

abstract class Pending[E, A]:
abstract class Pending[+E, +A]:
self =>

def waiters: Int
def interrupt(v: Panic): Pending[E, A]
def run(v: Result[E, A]): Pending[E, A]
def run[E2 >: E, A2 >: A](v: Result[E2, A2]): Pending[E2, A2]

@nowarn("msg=anonymous")
inline def onComplete(inline f: Result[E, A] => Unit): Pending[E, A] =
Expand All @@ -250,8 +250,8 @@ private[kyo] object IOPromise extends IOPromisePlatformSpecific:
def interrupt(v: Panic) =
f(v)
self
def run(v: Result[E, A]) =
try f(v)
def run[E2 >: E, A2 >: A](v: Result[E2, A2]) =
try f(v.asInstanceOf[Result[E, A]])
catch
case ex if NonFatal(ex) =>
given Frame = Frame.internal
Expand All @@ -267,7 +267,7 @@ private[kyo] object IOPromise extends IOPromisePlatformSpecific:
discard(p.interrupt(panic))
self
def waiters: Int = self.waiters + 1
def run(v: Result[E, A]) =
def run[E2 >: E, A2 >: A](v: Result[E2, A2]) =
self

@nowarn("msg=anonymous")
Expand All @@ -277,29 +277,29 @@ private[kyo] object IOPromise extends IOPromisePlatformSpecific:
f(panic)
self
def waiters: Int = self.waiters + 1
def run(v: Result[E, A]) =
def run[E2 >: E, A2 >: A](v: Result[E2, A2]) =
self

final def merge(tail: Pending[E, A]): Pending[E, A] =
final def merge[E2 >: E, A2 >: A](tail: Pending[E2, A2]): Pending[E2, A2] =

@tailrec def runLoop(p: Pending[E, A], v: Result[E, A]): Pending[E, A] =
@tailrec def runLoop[E3 >: E2, A3 >: A2](p: Pending[? <: E3, ? <: A3], v: Result[E3, A3]): Pending[E3, A3] =
p match
case _ if (p eq Pending.Empty) => tail
case p: Pending[E, A] => runLoop(p.run(v), v)
case p: Pending[?, ?] => runLoop(p.run(v), v)

@tailrec def interruptLoop(p: Pending[E, A], panic: Panic): Pending[E, A] =
@tailrec def interruptLoop(p: Pending[E, A], panic: Panic): Pending[E2, A2] =
p match
case _ if (p eq Pending.Empty) => tail
case p: Pending[E, A] => interruptLoop(p.interrupt(panic), panic)

new Pending[E, A]:
def waiters: Int = self.waiters + 1
def interrupt(panic: Panic) = interruptLoop(self, panic)
def run(v: Result[E, A]) = runLoop(self, v)
new Pending[E2, A2]:
def waiters: Int = self.waiters + tail.waiters
def interrupt(panic: Panic) = interruptLoop(self, panic)
def run[E3 >: E2, A3 >: A2](v: Result[E3, A3]) = runLoop(self, v)
end new
end merge

final def flushInterrupt(v: Panic): Unit =
final def flushInterrupt[E2 >: E, A2 >: A](v: Panic): Unit =
@tailrec def flushInterruptLoop(p: Pending[E, A]): Unit =
p match
case _ if (p eq Pending.Empty) => ()
Expand All @@ -308,11 +308,11 @@ private[kyo] object IOPromise extends IOPromisePlatformSpecific:
flushInterruptLoop(this)
end flushInterrupt

final def flush(v: Result[E, A]): Unit =
@tailrec def flushLoop(p: Pending[E, A]): Unit =
final def flush[E2 >: E, A2 >: A](v: Result[E2, A2]): Unit =
@tailrec def flushLoop[E3 >: E2, A3 >: A2](p: Pending[? <: E3, ? <: A3]): Unit =
p match
case _ if (p eq Pending.Empty) => ()
case p: Pending[E, A] =>
case p: Pending[?, ?] =>
flushLoop(p.run(v))
flushLoop(this)
end flush
Expand All @@ -322,9 +322,9 @@ private[kyo] object IOPromise extends IOPromisePlatformSpecific:
object Pending:
def apply[E, A](): Pending[E, A] = Empty.asInstanceOf[Pending[E, A]]
case object Empty extends Pending[Nothing, Nothing]:
def waiters: Int = 0
def interrupt(v: Panic) = this
def run(v: Result[Nothing, Nothing]) = this
def waiters: Int = 0
def interrupt(v: Panic) = this
def run[E2, A2](v: Result[E2, A2]) = this
end Empty
end Pending
end IOPromise
Loading

0 comments on commit 5bf22f6

Please sign in to comment.