From b9fb22e3d9040ee8e83355817ec942cda88334d8 Mon Sep 17 00:00:00 2001 From: He-Pin Date: Mon, 27 Mar 2023 20:16:09 +0800 Subject: [PATCH] +str Add `startAfterNrOfConsumers` to BroadcastHub. --- .../pekko/stream/scaladsl/HubSpec.scala | 46 +++++++++ .../org/apache/pekko/stream/javadsl/Hub.scala | 29 ++++++ .../apache/pekko/stream/scaladsl/Hub.scala | 99 ++++++++++++++----- 3 files changed, 147 insertions(+), 27 deletions(-) diff --git a/stream-tests/src/test/scala/org/apache/pekko/stream/scaladsl/HubSpec.scala b/stream-tests/src/test/scala/org/apache/pekko/stream/scaladsl/HubSpec.scala index c0aa775e2f9..ce6c4495b2a 100644 --- a/stream-tests/src/test/scala/org/apache/pekko/stream/scaladsl/HubSpec.scala +++ b/stream-tests/src/test/scala/org/apache/pekko/stream/scaladsl/HubSpec.scala @@ -305,6 +305,52 @@ class HubSpec extends StreamSpec { f2.futureValue should ===(1 to 10) } + "broadcast elements to downstream after at least one subscriber" in { + val broadcast = Source(1 to 10).runWith(BroadcastHub.sink[Int](1, 256)) + val resultOne = broadcast.runWith(Sink.seq) // nothing happening yet + + Await.result(resultOne, 1.second) should be(1 to 10) // fails + } + + "broadcast all elements to all consumers" in { + val sourceQueue = Source.queue[Int](10) // used to block the source until we say so + val (queue, broadcast) = sourceQueue.toMat(BroadcastHub.sink(2, 256))(Keep.both).run() + val resultOne = broadcast.runWith(Sink.seq) // nothing happening yet + for (i <- 1 to 5) { + queue.offer(i) + } + val resultTwo = broadcast.runWith(Sink.seq) + for (i <- 6 to 10) { + queue.offer(i) + } + queue.complete() // only now is the source emptied + + Await.result(resultOne, 1.second) should be(1 to 10) + Await.result(resultTwo, 1.second) should be(1 to 10) + } + + "broadcast all elements to all consumers with hot upstream" in { + val broadcast = Source(1 to 10).runWith(BroadcastHub.sink[Int](2, 256)) + val resultOne = broadcast.runWith(Sink.seq) // nothing happening yet + val resultTwo = broadcast.runWith(Sink.seq) + + Await.result(resultOne, 1.second) should be(1 to 10) + Await.result(resultTwo, 1.second) should be(1 to 10) + } + + "broadcast all elements to all consumers with hot upstream even some subscriber unsubscribe" in { + val broadcast = Source(1 to 10).runWith(BroadcastHub.sink[Int](2, 256)) + val sub = broadcast.runWith(TestSink.apply()) + sub.request(1) + Thread.sleep(1000) + sub.cancel() + val resultOne = broadcast.runWith(Sink.seq) // nothing happening yet + val resultTwo = broadcast.runWith(Sink.seq) // nothing happening yet + + Await.result(resultOne, 1.second) should be(1 to 10) + Await.result(resultTwo, 1.second) should be(1 to 10) + } + "send the same prefix to consumers attaching around the same time if one cancels earlier" in { val (firstElem, source) = Source.maybe[Int].concat(Source(2 to 20)).toMat(BroadcastHub.sink(8))(Keep.both).run() diff --git a/stream/src/main/scala/org/apache/pekko/stream/javadsl/Hub.scala b/stream/src/main/scala/org/apache/pekko/stream/javadsl/Hub.scala index a9f49dfe0c9..d7f5cd9740c 100644 --- a/stream/src/main/scala/org/apache/pekko/stream/javadsl/Hub.scala +++ b/stream/src/main/scala/org/apache/pekko/stream/javadsl/Hub.scala @@ -164,6 +164,35 @@ object BroadcastHub { pekko.stream.scaladsl.BroadcastHub.sink[T](bufferSize).mapMaterializedValue(_.asJava).asJava } + /** + * Creates a [[Sink]] that receives elements from its upstream producer and broadcasts them to a dynamic set + * of consumers. After the [[Sink]] returned by this method is materialized, it returns a [[Source]] as materialized + * value. This [[Source]] can be materialized an arbitrary number of times and each materialization will receive the + * broadcast elements from the original [[Sink]]. + * + * Every new materialization of the [[Sink]] results in a new, independent hub, which materializes to its own + * [[Source]] for consuming the [[Sink]] of that materialization. + * + * If the original [[Sink]] is failed, then the failure is immediately propagated to all of its materialized + * [[Source]]s (possibly jumping over already buffered elements). If the original [[Sink]] is completed, then + * all corresponding [[Source]]s are completed. Both failure and normal completion is "remembered" and later + * materializations of the [[Source]] will see the same (failure or completion) state. [[Source]]s that are + * cancelled are simply removed from the dynamic set of consumers. + * + * @param clazz Type of elements this hub emits and consumes + * @param startAfterNrOfConsumers Elements are buffered until this number of consumers have been connected. + * This is only used initially when the operator is starting up, i.e. it is not honored when consumers have + * been removed (canceled). + * @param bufferSize Buffer size used by the producer. Gives an upper bound on how "far" from each other two + * concurrent consumers can be in terms of element. If the buffer is full, the producer + * is backpressured. Must be a power of two and less than 4096. + * @since 1.1.0 + */ + def of[T](@unused clazz: Class[T], startAfterNrOfConsumers: Int, bufferSize: Int): Sink[T, Source[T, NotUsed]] = { + pekko.stream.scaladsl.BroadcastHub.sink[T](startAfterNrOfConsumers, bufferSize).mapMaterializedValue( + _.asJava).asJava + } + /** * Creates a [[Sink]] with default buffer size 256 that receives elements from its upstream producer and broadcasts them to a dynamic set * of consumers. After the [[Sink]] returned by this method is materialized, it returns a [[Source]] as materialized diff --git a/stream/src/main/scala/org/apache/pekko/stream/scaladsl/Hub.scala b/stream/src/main/scala/org/apache/pekko/stream/scaladsl/Hub.scala index 6ca55cfa259..f336d961f74 100644 --- a/stream/src/main/scala/org/apache/pekko/stream/scaladsl/Hub.scala +++ b/stream/src/main/scala/org/apache/pekko/stream/scaladsl/Hub.scala @@ -424,6 +424,32 @@ object BroadcastHub { */ def sink[T](bufferSize: Int): Sink[T, Source[T, NotUsed]] = Sink.fromGraph(new BroadcastHub[T](bufferSize)) + /** + * Creates a [[Sink]] that receives elements from its upstream producer and broadcasts them to a dynamic set + * of consumers. After the [[Sink]] returned by this method is materialized, it returns a [[Source]] as materialized + * value. This [[Source]] can be materialized an arbitrary number of times and each materialization will receive the + * broadcast elements from the original [[Sink]]. + * + * Every new materialization of the [[Sink]] results in a new, independent hub, which materializes to its own + * [[Source]] for consuming the [[Sink]] of that materialization. + * + * If the original [[Sink]] is failed, then the failure is immediately propagated to all of its materialized + * [[Source]]s (possibly jumping over already buffered elements). If the original [[Sink]] is completed, then + * all corresponding [[Source]]s are completed. Both failure and normal completion is "remembered" and later + * materializations of the [[Source]] will see the same (failure or completion) state. [[Source]]s that are + * cancelled are simply removed from the dynamic set of consumers. + * + * @param startAfterNrOfConsumers Elements are buffered until this number of consumers have been connected. + * This is only used initially when the operator is starting up, i.e. it is not honored when consumers have + * been removed (canceled). + * @param bufferSize Buffer size used by the producer. Gives an upper bound on how "far" from each other two + * concurrent consumers can be in terms of element. If this buffer is full, the producer + * is backpressured. Must be a power of two and less than 4096. + * @since 1.1.0 + */ + def sink[T](startAfterNrOfConsumers: Int, bufferSize: Int): Sink[T, Source[T, NotUsed]] = + Sink.fromGraph(new BroadcastHub[T](startAfterNrOfConsumers, bufferSize)) + /** * Creates a [[Sink]] with default buffer size 256 that receives elements from its upstream producer and broadcasts them to a dynamic set * of consumers. After the [[Sink]] returned by this method is materialized, it returns a [[Source]] as materialized @@ -446,11 +472,13 @@ object BroadcastHub { /** * INTERNAL API */ -private[pekko] class BroadcastHub[T](bufferSize: Int) +private[pekko] class BroadcastHub[T](startAfterNrOfConsumers: Int, bufferSize: Int) extends GraphStageWithMaterializedValue[SinkShape[T], Source[T, NotUsed]] { + require(startAfterNrOfConsumers >= 0, "startAfterNrOfConsumers must >= 0") require(bufferSize > 0, "Buffer size must be positive") require(bufferSize < 4096, "Buffer size larger then 4095 is not allowed") require((bufferSize & bufferSize - 1) == 0, "Buffer size must be a power of two") + def this(bufferSize: Int) = this(0, bufferSize) private val Mask = bufferSize - 1 private val WheelMask = (bufferSize * 2) - 1 @@ -482,6 +510,7 @@ private[pekko] class BroadcastHub[T](bufferSize: Int) private[this] val callbackPromise: Promise[AsyncCallback[HubEvent]] = Promise() private[this] val noRegistrationsState = Open(callbackPromise.future, Nil) val state = new AtomicReference[HubState](noRegistrationsState) + private var initialized = false // Start from values that will almost immediately overflow. This has no effect on performance, any starting // number will do, however, this protects from regressions as these values *almost surely* overflow and fail @@ -511,7 +540,9 @@ private[pekko] class BroadcastHub[T](bufferSize: Int) override def preStart(): Unit = { setKeepGoing(true) callbackPromise.success(getAsyncCallback[HubEvent](onEvent)) - pull(in) + if (startAfterNrOfConsumers == 0) { + pull(in) + } } // Cannot complete immediately if there is no space in the queue to put the completion marker @@ -522,8 +553,29 @@ private[pekko] class BroadcastHub[T](bufferSize: Int) if (!isFull) pull(in) } + private def tryPull(): Unit = { + if (initialized && !isClosed(in) && !hasBeenPulled(in) && !isFull) { + pull(in) + } + } + private def onEvent(ev: HubEvent): Unit = { ev match { + case Advance(id, previousOffset) => + val newOffset = previousOffset + DemandThreshold + // Move the consumer from its last known offset to its new one. Check if we are unblocked. + val consumer = findAndRemoveConsumer(id, previousOffset) + addConsumer(consumer, newOffset) + checkUnblock(previousOffset) + case NeedWakeup(id, previousOffset, currentOffset) => + // Move the consumer from its last known offset to its new one. Check if we are unblocked. + val consumer = findAndRemoveConsumer(id, previousOffset) + addConsumer(consumer, currentOffset) + + // Also check if the consumer is now unblocked since we published an element since it went asleep. + if (currentOffset != tail) consumer.callback.invoke(Wakeup) + checkUnblock(previousOffset) + case RegistrationPending => state.getAndSet(noRegistrationsState).asInstanceOf[Open].registrations.foreach { consumer => val startFrom = head @@ -538,6 +590,10 @@ private[pekko] class BroadcastHub[T](bufferSize: Int) case _ => () } } + if (activeConsumers >= startAfterNrOfConsumers) { + initialized = true + } + tryPull() case UnRegister(id, previousOffset, finalOffset) => if (findAndRemoveConsumer(id, previousOffset) != null) @@ -552,24 +608,10 @@ private[pekko] class BroadcastHub[T](bufferSize: Int) head += 1 } head = finalOffset - if (!hasBeenPulled(in)) pull(in) + tryPull() } } else checkUnblock(previousOffset) - case Advance(id, previousOffset) => - val newOffset = previousOffset + DemandThreshold - // Move the consumer from its last known offset to its new one. Check if we are unblocked. - val consumer = findAndRemoveConsumer(id, previousOffset) - addConsumer(consumer, newOffset) - checkUnblock(previousOffset) - case NeedWakeup(id, previousOffset, currentOffset) => - // Move the consumer from its last known offset to its new one. Check if we are unblocked. - val consumer = findAndRemoveConsumer(id, previousOffset) - addConsumer(consumer, currentOffset) - - // Also check if the consumer is now unblocked since we published an element since it went asleep. - if (currentOffset != tail) consumer.callback.invoke(Wakeup) - checkUnblock(previousOffset) } } @@ -624,7 +666,7 @@ private[pekko] class BroadcastHub[T](bufferSize: Int) private def checkUnblock(offsetOfConsumerRemoved: Int): Unit = { if (unblockIfPossible(offsetOfConsumerRemoved)) { if (isClosed(in)) complete() - else if (!hasBeenPulled(in)) pull(in) + else tryPull() } } @@ -1106,6 +1148,9 @@ object PartitionHub { startAfterNrOfConsumers: Int, bufferSize: Int) extends GraphStageWithMaterializedValue[SinkShape[T], Source[T, NotUsed]] { + require(partitioner != null, "partitioner must not be null") + require(startAfterNrOfConsumers >= 0, "startAfterNrOfConsumers must >= 0") + require(bufferSize > 0, "Buffer size must be positive") import PartitionHub.ConsumerInfo import PartitionHub.Internal._ @@ -1231,20 +1276,20 @@ object PartitionHub { val newConsumers = (consumerInfo.consumers :+ consumer).sortBy(_.id) consumerInfo = new ConsumerInfoImpl(newConsumers) queue.init(consumer.id) - if (newConsumers.size >= startAfterNrOfConsumers) { - initialized = true - } - consumer.callback.invoke(Initialize) + } - if (initialized && pending.nonEmpty) { - pending.foreach(publish) - pending = Vector.empty[T] - } + if (consumerInfo.size >= startAfterNrOfConsumers) { + initialized = true + } - tryPull() + if (initialized && pending.nonEmpty) { + pending.foreach(publish) + pending = Vector.empty[T] } + tryPull() + case UnRegister(id) => val newConsumers = consumerInfo.consumers.filterNot(_.id == id) consumerInfo = new ConsumerInfoImpl(newConsumers)