Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions include/libkrun.h
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,24 @@ int32_t krun_add_vsock_port2(uint32_t ctx_id,
uint32_t port,
const char *c_filepath,
bool listen);

/**
* Adds a Unix socket tunnel between host and guest that creates Unix socket tunnels over vsock.
*
* Arguments:
* "ctx_id" - the configuration context ID.
* "c_guest_filepath"- a null-terminated string representing the path of the UNIX socket in the guest.
* "c_host_filepath" - a null-terminated string representing the path of the UNIX socket in the host.
* "listen" - true if guest should listen on the socket, false if guest should connect
*
* Returns:
* Zero on success or a negative error number on failure.
*/
int32_t krun_add_vsock_unix_tunnel(uint32_t ctx_id,
const char *c_guest_filepath,
const char *c_host_filepath,
bool listen);

/**
* Returns the eventfd file descriptor to signal the guest to shut down orderly. This must be
* called before starting the microVM with "krun_start_event". Only available in libkrun-efi.
Expand Down
312 changes: 312 additions & 0 deletions init/init.c
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
#include <sys/stat.h>
#include <sys/time.h>
#include <sys/types.h>
#include <sys/un.h>
#include <sys/wait.h>
#include <pthread.h>

#include <linux/vm_sockets.h>

Expand All @@ -40,6 +42,21 @@

static int jsoneq(const char *, jsmntok_t *, const char *);

struct unix_tunnel {
char guest_path[256];
int listen;
int vsock_port;
};

// Global synchronization for tunnel setup
static pthread_mutex_t tunnel_setup_mutex = PTHREAD_MUTEX_INITIALIZER;
static pthread_cond_t tunnel_setup_cond = PTHREAD_COND_INITIALIZER;
static int tunnel_setup_count = 0;
static int tunnel_ready_count = 0;

static void setup_unix_tunnels(const char *tunnels_env);
static void *unix_tunnel_proxy_thread(void *arg);

#ifdef SEV
static char *sev_get_luks_passphrase(int *);
static char *snp_get_luks_passphrase(char *, char *, char *, int *);
Expand Down Expand Up @@ -975,6 +992,299 @@ void set_exit_code(int code)
close(fd);
}

static void *unix_tunnel_proxy_thread(void *arg)
{
struct unix_tunnel *tunnel = (struct unix_tunnel *)arg;
int listen_sock = -1, connect_sock = -1, vsock_sock = -1;
struct sockaddr_un unix_addr;
struct sockaddr_vm vm_addr;
fd_set read_fds;
char buffer[4096];
int max_fd, ret;

if (tunnel->listen) {
// Create Unix socket for listening
listen_sock = socket(AF_UNIX, SOCK_STREAM, 0);
if (listen_sock < 0) {
perror("Failed to create Unix listen socket");
goto thread_ready; // Still signal ready even on error
}

memset(&unix_addr, 0, sizeof(unix_addr));
unix_addr.sun_family = AF_UNIX;
if (strlen(tunnel->guest_path) >= sizeof(unix_addr.sun_path)) {
fprintf(stderr, "Unix socket path too long: %s\n", tunnel->guest_path);
close(listen_sock);
goto thread_ready;
}
strcpy(unix_addr.sun_path, tunnel->guest_path);
unlink(tunnel->guest_path); // Remove any existing socket

if (bind(listen_sock, (struct sockaddr *)&unix_addr, sizeof(unix_addr)) < 0) {
perror("Failed to bind Unix socket");
close(listen_sock);
goto thread_ready;
}

if (listen(listen_sock, 5) < 0) {
perror("Failed to listen on Unix socket");
close(listen_sock);
goto thread_ready;
}

// Signal that this thread is ready and then start the main loop
goto thread_ready;
} else {
// Connect mode: Listen on vsock, connect to Unix socket
vsock_sock = socket(AF_VSOCK, SOCK_STREAM, 0);
if (vsock_sock < 0) {
perror("Failed to create vsock socket");
goto thread_ready;
}

memset(&vm_addr, 0, sizeof(vm_addr));
vm_addr.svm_family = AF_VSOCK;
vm_addr.svm_cid = VMADDR_CID_ANY;
vm_addr.svm_port = tunnel->vsock_port;

if (bind(vsock_sock, (struct sockaddr *)&vm_addr, sizeof(vm_addr)) < 0) {
perror("Failed to bind vsock socket");
close(vsock_sock);
goto thread_ready;
}

if (listen(vsock_sock, 5) < 0) {
perror("Failed to listen on vsock socket");
close(vsock_sock);
goto thread_ready;
}

// Signal that this thread is ready and then start the main loop
goto thread_ready;
}

thread_ready:
// Signal that this thread has set up (whether successfully or not)
pthread_mutex_lock(&tunnel_setup_mutex);
tunnel_ready_count++;
pthread_cond_signal(&tunnel_setup_cond);
pthread_mutex_unlock(&tunnel_setup_mutex);

// If setup failed, exit thread
if (tunnel->listen && listen_sock < 0) {
return NULL;
}
if (!tunnel->listen && vsock_sock < 0) {
return NULL;
}

// Main proxy loop
if (tunnel->listen) {
// Listen mode: Accept Unix connections and proxy to vsock
while (1) {
// Accept Unix connection
connect_sock = accept(listen_sock, NULL, NULL);
if (connect_sock < 0) {
perror("Failed to accept Unix connection");
continue;
}

// Create vsock connection to host
int new_vsock_sock = socket(AF_VSOCK, SOCK_STREAM, 0);
if (new_vsock_sock < 0) {
perror("Failed to create vsock socket");
close(connect_sock);
continue;
}

memset(&vm_addr, 0, sizeof(vm_addr));
vm_addr.svm_family = AF_VSOCK;
vm_addr.svm_cid = VMADDR_CID_HOST;
vm_addr.svm_port = tunnel->vsock_port;

if (connect(new_vsock_sock, (struct sockaddr *)&vm_addr, sizeof(vm_addr)) < 0) {
perror("Failed to connect to host via vsock");
close(connect_sock);
close(new_vsock_sock);
continue;
}

// Proxy data between Unix socket and vsock
while (1) {
FD_ZERO(&read_fds);
FD_SET(connect_sock, &read_fds);
FD_SET(new_vsock_sock, &read_fds);
max_fd = (connect_sock > new_vsock_sock) ? connect_sock : new_vsock_sock;

ret = select(max_fd + 1, &read_fds, NULL, NULL, NULL);
if (ret <= 0) break;

if (FD_ISSET(connect_sock, &read_fds)) {
ret = read(connect_sock, buffer, sizeof(buffer));
if (ret <= 0) break;
if (write(new_vsock_sock, buffer, ret) != ret) break;
}

if (FD_ISSET(new_vsock_sock, &read_fds)) {
ret = read(new_vsock_sock, buffer, sizeof(buffer));
if (ret <= 0) break;
if (write(connect_sock, buffer, ret) != ret) break;
}
}

close(connect_sock);
close(new_vsock_sock);
}
} else {
// Connect mode: Accept vsock connections and proxy to Unix socket
while (1) {
// Accept vsock connection
connect_sock = accept(vsock_sock, NULL, NULL);
if (connect_sock < 0) {
perror("Failed to accept vsock connection");
continue;
}

// Create Unix socket connection
int new_listen_sock = socket(AF_UNIX, SOCK_STREAM, 0);
if (new_listen_sock < 0) {
perror("Failed to create Unix socket");
close(connect_sock);
continue;
}

memset(&unix_addr, 0, sizeof(unix_addr));
unix_addr.sun_family = AF_UNIX;
if (strlen(tunnel->guest_path) >= sizeof(unix_addr.sun_path)) {
fprintf(stderr, "Unix socket path too long: %s\n", tunnel->guest_path);
close(connect_sock);
close(new_listen_sock);
continue;
}
strcpy(unix_addr.sun_path, tunnel->guest_path);

if (connect(new_listen_sock, (struct sockaddr *)&unix_addr, sizeof(unix_addr)) < 0) {
perror("Failed to connect to Unix socket");
close(connect_sock);
close(new_listen_sock);
continue;
}

// Proxy data between vsock and Unix socket
while (1) {
FD_ZERO(&read_fds);
FD_SET(connect_sock, &read_fds);
FD_SET(new_listen_sock, &read_fds);
max_fd = (connect_sock > new_listen_sock) ? connect_sock : new_listen_sock;

ret = select(max_fd + 1, &read_fds, NULL, NULL, NULL);
if (ret <= 0) break;

if (FD_ISSET(connect_sock, &read_fds)) {
ret = read(connect_sock, buffer, sizeof(buffer));
if (ret <= 0) break;
if (write(new_listen_sock, buffer, ret) != ret) break;
}

if (FD_ISSET(new_listen_sock, &read_fds)) {
ret = read(new_listen_sock, buffer, sizeof(buffer));
if (ret <= 0) break;
if (write(connect_sock, buffer, ret) != ret) break;
}
}

close(connect_sock);
close(new_listen_sock);
}
}

return NULL;
}

static void setup_unix_tunnels(const char *tunnels_env)
{
char *tunnels_copy, *tunnel_str, *saveptr;
struct unix_tunnel *tunnel;
pthread_t thread;

if (!tunnels_env || strlen(tunnels_env) == 0) {
return;
}

tunnels_copy = strdup(tunnels_env);
if (!tunnels_copy) {
perror("Failed to allocate memory for tunnels");
return;
}

// Reset tunnel setup counters
pthread_mutex_lock(&tunnel_setup_mutex);
tunnel_setup_count = 0;
tunnel_ready_count = 0;
pthread_mutex_unlock(&tunnel_setup_mutex);

// Parse tunnels: "guest_path1:vsock_port1:listen,guest_path2:vsock_port2:connect,..."
tunnel_str = strtok_r(tunnels_copy, ",", &saveptr);
while (tunnel_str != NULL) {
char *guest_path, *vsock_port_str, *mode_str;
char *tunnel_saveptr;

// Parse individual tunnel: "guest_path:vsock_port:mode"
guest_path = strtok_r(tunnel_str, ":", &tunnel_saveptr);
vsock_port_str = strtok_r(NULL, ":", &tunnel_saveptr);
mode_str = strtok_r(NULL, ":", &tunnel_saveptr);

if (guest_path && vsock_port_str && mode_str) {
// Validate guest path length for Unix socket compatibility
if (strlen(guest_path) >= sizeof(((struct sockaddr_un*)0)->sun_path)) {
fprintf(stderr, "Guest socket path too long (max %zu chars): %s\n",
sizeof(((struct sockaddr_un*)0)->sun_path) - 1, guest_path);
tunnel_str = strtok_r(NULL, ",", &saveptr);
continue;
}

tunnel = malloc(sizeof(struct unix_tunnel));
if (!tunnel) {
perror("Failed to allocate memory for tunnel");
break;
}

strncpy(tunnel->guest_path, guest_path, sizeof(tunnel->guest_path) - 1);
tunnel->guest_path[sizeof(tunnel->guest_path) - 1] = '\0';
tunnel->listen = (strcmp(mode_str, "listen") == 0) ? 1 : 0;
tunnel->vsock_port = atoi(vsock_port_str);

// Increment tunnel setup count
pthread_mutex_lock(&tunnel_setup_mutex);
tunnel_setup_count++;
pthread_mutex_unlock(&tunnel_setup_mutex);

// Create thread for this tunnel
if (pthread_create(&thread, NULL, unix_tunnel_proxy_thread, tunnel) != 0) {
perror("Failed to create tunnel thread");
free(tunnel);
// Decrement count since thread creation failed
pthread_mutex_lock(&tunnel_setup_mutex);
tunnel_setup_count--;
pthread_mutex_unlock(&tunnel_setup_mutex);
} else {
pthread_detach(thread); // Detach so thread resources are cleaned up
}
}

tunnel_str = strtok_r(NULL, ",", &saveptr);
}

// Wait for all tunnel threads to be ready
pthread_mutex_lock(&tunnel_setup_mutex);
while (tunnel_ready_count < tunnel_setup_count) {
pthread_cond_wait(&tunnel_setup_cond, &tunnel_setup_mutex);
}
pthread_mutex_unlock(&tunnel_setup_mutex);

free(tunnels_copy);
}

int main(int argc, char **argv)
{
struct ifreq ifr;
Expand Down Expand Up @@ -1057,6 +1367,8 @@ int main(int argc, char **argv)
exec_argv[0] = &DEFAULT_KRUN_INIT[0];
}

setup_unix_tunnels(getenv("KRUN_TUNNELS"));

#ifdef __TIMESYNC__
if (fork() == 0) {
clock_worker();
Expand Down
Loading
Loading