Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ public class Async {

private final Vertx vertx;
private final boolean useVirtualEventLoopThreads;
private static final String VTHREAD_CTX = "VTHREAD_CTX";

public Async(Vertx vertx) {
this(vertx, false);
Expand All @@ -32,17 +33,21 @@ public Async(Vertx vertx, boolean useVirtualEventLoopThreads) {
* Run a task on a virtual thread
*/
public void run(Handler<Void> task) {
assert !Thread.currentThread().isVirtual();

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually expect to be able to call .run from a virtual thread all the time

Context ctx = vertx.getOrCreateContext();
EventLoop eventLoop;
if (ctx.isEventLoopContext()) {
eventLoop = ((ContextInternal)ctx).nettyEventLoop();
} else {
if (!ctx.isEventLoopContext()) {
throw new IllegalStateException();
}
// Scheduler scheduler = useVirtualEventLoopThreads ? new SchedulerImpl(LoomaniaScheduler2.threadFactory(eventLoop)): new SchedulerImpl(SchedulerImpl.DEFAULT_THREAD_FACTORY);
Scheduler scheduler = useVirtualEventLoopThreads ? new EventLoopScheduler(eventLoop) : new DefaultScheduler(DefaultScheduler.DEFAULT_THREAD_FACTORY);
VirtualThreadContext context = VirtualThreadContext.create(vertx, eventLoop, scheduler);
context.runOnContext(task);
var unsafeCtx = (ContextInternal) ctx;
VirtualThreadContext virtualCtx = unsafeCtx.getLocal(VTHREAD_CTX);
if (virtualCtx == null) {
EventLoop eventLoop = unsafeCtx.nettyEventLoop();
// Scheduler scheduler = useVirtualEventLoopThreads ? new SchedulerImpl(LoomaniaScheduler2.threadFactory(eventLoop)): new SchedulerImpl(SchedulerImpl.DEFAULT_THREAD_FACTORY);
Scheduler scheduler = useVirtualEventLoopThreads ? new EventLoopScheduler(eventLoop) : new DefaultScheduler(DefaultScheduler.DEFAULT_THREAD_FACTORY);
virtualCtx = VirtualThreadContext.create(vertx, eventLoop, scheduler);
unsafeCtx.putLocal(VTHREAD_CTX, virtualCtx);
}
virtualCtx.runOnContext(task);
}

private static VirtualThreadContext virtualThreadContext() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,23 +52,29 @@ private static ThreadFactory threadFactory(Executor carrier) {

private final ThreadFactory threadFactory;
private final LinkedList<Runnable> tasks = new LinkedList<>();
private boolean flag;
private boolean runOnContext;

public EventLoopScheduler(EventLoop carrier) {
this(command -> {
if (carrier.inEventLoop()) {
command.run();
} else {
carrier.execute(command);
}
execute(carrier, command);
});
}

public EventLoopScheduler(Executor carrier) {
private static void execute(EventLoop carrier, Runnable command) {
if (carrier.inEventLoop()) {
command.run();
} else {
carrier.execute(command);
}
}

private EventLoopScheduler(Executor carrier) {
threadFactory = threadFactory(command -> {
if (flag) {
if (runOnContext) {
tasks.addLast(command);
} else {
// "external" continuations are prioritized and placed
// upfront, to be consumed first
tasks.addFirst(command);
}
carrier.execute(() -> {
Expand All @@ -84,11 +90,11 @@ public Consumer<Runnable> unschedule() {

public void execute(Runnable runnable) {
Thread thread = threadFactory.newThread(runnable);
flag = true;
runOnContext = true;
try {
thread.start();
} finally {
flag = false;
runOnContext = false;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,10 @@
import io.vertx.core.Context;
import io.vertx.core.Handler;
import io.vertx.core.Vertx;
import io.vertx.core.impl.CloseFuture;
import io.vertx.core.impl.ContextBase;
import io.vertx.core.impl.ContextInternal;
import io.vertx.core.impl.Deployment;
import io.vertx.core.impl.VertxImpl;
import io.vertx.core.impl.VertxInternal;
import io.vertx.core.impl.WorkerPool;
import io.vertx.core.impl.*;

import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.*;
import java.util.concurrent.locks.Lock;
import java.util.function.Consumer;

Expand All @@ -32,7 +22,11 @@ public static VirtualThreadContext create(Vertx vertx, EventLoop nettyEventLoop,
}

private final Scheduler scheduler;
private final ThreadLocal<Boolean> inThread = new ThreadLocal<>();

// Use this instead of a ThreadLocal because must friendly with Virtual Threads!
// ideally we should use https://github.com/JCTools/JCTools/blob/master/jctools-core/src/main/java/org/jctools/maps/NonBlockingHashMap.java
// which doesn't use any synchronized op!
private final ConcurrentHashSet<Thread> inThread = new ConcurrentHashSet<>();

VirtualThreadContext(VertxInternal vertx,
EventLoop eventLoop,
Expand Down Expand Up @@ -97,38 +91,39 @@ public boolean isWorkerContext() {
private <T> void run(ContextInternal ctx, T value, Handler<T> task) {
Objects.requireNonNull(task, "Task handler must not be null");
scheduler.execute(() -> {
inThread.set(true);
var current = Thread.currentThread();
inThread.add(current);

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I adopted this fix separately in my fork, it's an essential one

try {
ctx.dispatch(value, task);
} finally {
inThread.remove();
inThread.remove(current);
}
});
}

private <T> void execute2(T argument, Handler<T> task) {
if (Context.isOnWorkerThread()) {
inThread.set(true);
try {
task.handle(argument);
} finally {
inThread.remove();
}
handle(argument, task);
} else {
scheduler.execute(() -> {
inThread.set(true);
try {
task.handle(argument);
} finally {
inThread.remove();
}
handle(argument, task);
});
}
}

private <T> void handle(T argument, Handler<T> task) {
var current = Thread.currentThread();
inThread.add(current);
try {
task.handle(argument);
} finally {
inThread.remove(current);
}
}

@Override
public boolean inThread() {
return inThread.get() == Boolean.TRUE;
return inThread.contains(Thread.currentThread());
}

@Override
Expand All @@ -138,7 +133,8 @@ public ContextInternal duplicate() {
}

public void lock(Lock lock) {
inThread.remove();
var current = Thread.currentThread();
inThread.remove(current);
Consumer<Runnable> cont = scheduler.unschedule();
CompletableFuture<Void> latch = new CompletableFuture<>();
try {
Expand All @@ -154,12 +150,13 @@ public void lock(Lock lock) {
} catch (ExecutionException e) {
throwAsUnchecked(e);
} finally {
inThread.set(true);
inThread.add(current);
}
}

public <T> T await(CompletionStage<T> fut) {
inThread.remove();
var current = Thread.currentThread();
inThread.remove(current);
Consumer<Runnable> cont = scheduler.unschedule();
CompletableFuture<T> latch = new CompletableFuture<>();
fut.whenComplete((v, err) -> {
Expand All @@ -175,7 +172,7 @@ public <T> T await(CompletionStage<T> fut) {
throwAsUnchecked(e.getCause());
return null;
} finally {
inThread.set(true);
inThread.add(current);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,23 @@ public void testSuspend() {
await();
}

@Test
public void testSuspendFromEventLoop() {
vertx.runOnContext(v0 -> {
async.run(v1 -> {
CompletableFuture<Void> cf = new CompletableFuture<>();
vertx.runOnContext(v2 -> {
cf.complete(null);
});
try {
cf.get(10, TimeUnit.SECONDS);
} catch (Exception e) {
fail(e);
}
testComplete();
});
});
await();
}

}