Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions docs/routing-rules.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,19 @@ defaults to the `adhoc` group.
```yaml
routing:
defaultRoutingGroup: "test-group"
# Optional: cache backend metadata to reduce database look-ups
databaseCacheTTL: "5m"
Copy link
Member

@Peiyingy Peiyingy Oct 31, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's create a CacheConfiguration instead of putting all the configs under routing configs. We can define cacheEnabled, expireAfterWrite, and maximumSize explicitly instead of using if (!routingConfiguration.getDatabaseCacheTTL().isZero()) to decide if cache is enabled, and set default values for them. Also, we can make cache behavior configurable there.

```

Set `databaseCacheTTL` to a non-zero [Airlift duration](https://airlift.github.io/airlift/units/) value to enable in-memory caching of backend
metadata retrieved from the gateway database. Trino Gateway caches the list of
backend clusters for the specified time and refreshes it asynchronously. Use
this setting to reduce database load and improve routing performance.

A value of `0s` (the default) disables the cache and queries the database on
every request.


The routing rules engine feature enables you to either write custom logic to
route requests based on the request info such as any of the [request
headers](https://trino.io/docs/current/develop/client-protocol.html#client-request-headers),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@

import io.airlift.units.Duration;

import static io.airlift.units.Duration.ZERO;
import static java.util.concurrent.TimeUnit.MINUTES;

public class RoutingConfiguration
{
private Duration asyncTimeout = new Duration(2, MINUTES);
private Duration databaseCacheTTL = ZERO;

private boolean addXForwardedHeaders = true;

Expand Down Expand Up @@ -54,4 +56,14 @@ public void setDefaultRoutingGroup(String defaultRoutingGroup)
{
this.defaultRoutingGroup = defaultRoutingGroup;
}

public Duration getDatabaseCacheTTL()
{
return databaseCacheTTL;
}

public void setDatabaseCacheTTL(Duration databaseCacheTTL)
{
this.databaseCacheTTL = databaseCacheTTL;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,24 +23,6 @@ public interface GatewayBackendDao
@SqlQuery("SELECT * FROM gateway_backend")
List<GatewayBackend> findAll();

@SqlQuery("""
SELECT * FROM gateway_backend
WHERE active = true
""")
List<GatewayBackend> findActiveBackend();

@SqlQuery("""
SELECT * FROM gateway_backend
WHERE active = true AND routing_group = :routingGroup
""")
List<GatewayBackend> findActiveBackendByRoutingGroup(String routingGroup);

@SqlQuery("""
SELECT * FROM gateway_backend
WHERE name = :name
""")
List<GatewayBackend> findByName(String name);

@SqlQuery("""
SELECT * FROM gateway_backend
WHERE name = :name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,13 @@
*/
package io.trino.gateway.ha.router;

import com.google.common.cache.CacheBuilder;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's also use this opportunity to upgrade the caching dependency from guava to caffeine.

import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
import com.google.common.collect.ImmutableList;
import com.google.common.util.concurrent.MoreExecutors;
import io.airlift.log.Logger;
import io.airlift.stats.CounterStat;
import io.trino.gateway.ha.config.ProxyBackendConfiguration;
import io.trino.gateway.ha.config.RoutingConfiguration;
import io.trino.gateway.ha.persistence.dao.GatewayBackend;
Expand All @@ -24,35 +29,92 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.Executors;

import static com.google.common.base.Preconditions.checkState;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.util.Objects.requireNonNull;

public class HaGatewayManager
implements GatewayBackendManager
{
private static final Logger log = Logger.get(HaGatewayManager.class);
private static final Object ALL_BACKEND_CACHE_KEY = new Object();

private final GatewayBackendDao dao;
private final String defaultRoutingGroup;
private final boolean cacheEnabled;
private final LoadingCache<Object, List<GatewayBackend>> backendCache;

private final CounterStat backendLookupSuccesses = new CounterStat();
private final CounterStat backendLookupFailures = new CounterStat();

public HaGatewayManager(Jdbi jdbi, RoutingConfiguration routingConfiguration)
{
dao = requireNonNull(jdbi, "jdbi is null").onDemand(GatewayBackendDao.class);
this.defaultRoutingGroup = routingConfiguration.getDefaultRoutingGroup();
if (!routingConfiguration.getDatabaseCacheTTL().isZero()) {
cacheEnabled = true;
backendCache = CacheBuilder
.newBuilder()
.initialCapacity(1)
.refreshAfterWrite(routingConfiguration.getDatabaseCacheTTL().toJavaTime())
.build(CacheLoader.asyncReloading(
CacheLoader.from(this::fetchAllBackends),
MoreExecutors.listeningDecorator(Executors.newSingleThreadExecutor())));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about using a ThreadFactory to prevent thread leaks?
By default, newSingleThreadExecutor seems to create non-daemon threads.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

   ThreadFactory daemonThreadFactory = runnable -> {
       Thread thread = new Thread(runnable, "backend-cache-refresh");
       thread.setDaemon(true);
       return thread;
   };
   ...

   MoreExecutors.listeningDecorator(Executors.newSingleThreadExecutor(daemonThreadFactory))));

// Load the data once during initialization. This ensures a fail-fast behavior in case of database misconfiguration.
backendCache.getUnchecked(ALL_BACKEND_CACHE_KEY);
}
else {
cacheEnabled = false;
backendCache = null;
}
}

private List<GatewayBackend> fetchAllBackends()
{
try {
List<GatewayBackend> backends = dao.findAll();
backendLookupSuccesses.update(1);
return backends;
}
catch (Exception e) {
backendLookupFailures.update(1);
log.warn(e, "Failed to fetch backends");
throw e;
Comment on lines +83 to +84
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we throw exception? or just return empty list to maintain service availability
I mean the cache will retry on the next refresh cycle anyway

}
}

private void invalidateBackendCache()
{
if (cacheEnabled) {
backendCache.invalidateAll();
}
}

private List<GatewayBackend> getOrFetchAllBackends()
{
if (cacheEnabled) {
return backendCache.getUnchecked(ALL_BACKEND_CACHE_KEY);
}
else {
return fetchAllBackends();
}
}

@Override
public List<ProxyBackendConfiguration> getAllBackends()
{
List<GatewayBackend> proxyBackendList = dao.findAll();
List<GatewayBackend> proxyBackendList = getOrFetchAllBackends();
return upcast(proxyBackendList);
}

@Override
public List<ProxyBackendConfiguration> getAllActiveBackends()
{
List<GatewayBackend> proxyBackendList = dao.findActiveBackend();
List<GatewayBackend> proxyBackendList = getOrFetchAllBackends().stream()
.filter(GatewayBackend::active)
.collect(toImmutableList());
return upcast(proxyBackendList);
}

Expand All @@ -71,14 +133,19 @@ public List<ProxyBackendConfiguration> getActiveDefaultBackends()
@Override
public List<ProxyBackendConfiguration> getActiveBackends(String routingGroup)
{
List<GatewayBackend> proxyBackendList = dao.findActiveBackendByRoutingGroup(routingGroup);
List<GatewayBackend> proxyBackendList = getOrFetchAllBackends().stream()
.filter(GatewayBackend::active)
.filter(backend -> backend.routingGroup().equals(routingGroup))
.collect(toImmutableList());
return upcast(proxyBackendList);
}

@Override
public Optional<ProxyBackendConfiguration> getBackendByName(String name)
{
List<GatewayBackend> proxyBackendList = dao.findByName(name);
List<GatewayBackend> proxyBackendList = getOrFetchAllBackends().stream()
.filter(backend -> backend.name().equals(name))
.collect(toImmutableList());
return upcast(proxyBackendList).stream().findAny();
}

Expand All @@ -102,6 +169,7 @@ private void updateClusterActivationStatus(String clusterName, boolean newStatus
boolean previousStatus = model.active();
changeActiveStatus.run();
logActivationStatusChange(clusterName, newStatus, previousStatus);
invalidateBackendCache();
}

private static void logActivationStatusChange(String clusterName, boolean newStatus, boolean previousStatus)
Expand All @@ -117,6 +185,7 @@ public ProxyBackendConfiguration addBackend(ProxyBackendConfiguration backend)
String backendProxyTo = removeTrailingSlash(backend.getProxyTo());
String backendExternalUrl = removeTrailingSlash(backend.getExternalUrl());
dao.create(backend.getName(), backend.getRoutingGroup(), backendProxyTo, backendExternalUrl, backend.isActive());
invalidateBackendCache();
return backend;
}

Expand All @@ -133,12 +202,14 @@ public ProxyBackendConfiguration updateBackend(ProxyBackendConfiguration backend
dao.update(backend.getName(), backend.getRoutingGroup(), backendProxyTo, backendExternalUrl, backend.isActive());
logActivationStatusChange(backend.getName(), backend.isActive(), model.active());
}
invalidateBackendCache();
return backend;
}

public void deleteBackend(String name)
{
dao.deleteByName(name);
invalidateBackendCache();
}

private static List<ProxyBackendConfiguration> upcast(List<GatewayBackend> gatewayBackendList)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,32 +13,41 @@
*/
package io.trino.gateway.ha.router;

import io.airlift.units.Duration;
import io.trino.gateway.ha.config.ProxyBackendConfiguration;
import io.trino.gateway.ha.config.RoutingConfiguration;
import io.trino.gateway.ha.persistence.JdbcConnectionManager;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.junit.jupiter.api.TestInstance.Lifecycle;

import java.util.concurrent.TimeUnit;

import static io.trino.gateway.ha.TestingJdbcConnectionManager.createTestingJdbcConnectionManager;
import static org.assertj.core.api.Assertions.assertThat;

@TestInstance(Lifecycle.PER_CLASS)
final class TestHaGatewayManager
{
private HaGatewayManager haGatewayManager;

@BeforeAll
void setUp()
@Test
void testGatewayManagerWithCache()
{
JdbcConnectionManager connectionManager = createTestingJdbcConnectionManager();
RoutingConfiguration routingConfiguration = new RoutingConfiguration();
haGatewayManager = new HaGatewayManager(connectionManager.getJdbi(), routingConfiguration);
routingConfiguration.setDatabaseCacheTTL(new Duration(5, TimeUnit.SECONDS));
testGatewayManager(new HaGatewayManager(connectionManager.getJdbi(), routingConfiguration));
}

@Test
void testGatewayManager()
void testGatewayManagerWithoutCache()
{
JdbcConnectionManager connectionManager = createTestingJdbcConnectionManager();
RoutingConfiguration routingConfiguration = new RoutingConfiguration();
routingConfiguration.setDatabaseCacheTTL(new Duration(0, TimeUnit.SECONDS));
testGatewayManager(new HaGatewayManager(connectionManager.getJdbi(), routingConfiguration));
}

void testGatewayManager(HaGatewayManager haGatewayManager)
{
ProxyBackendConfiguration backend = new ProxyBackendConfiguration();
backend.setActive(true);
Expand Down Expand Up @@ -101,6 +110,10 @@ void testGatewayManager()
@Test
void testRemoveTrailingSlashInUrl()
{
JdbcConnectionManager connectionManager = createTestingJdbcConnectionManager();
RoutingConfiguration routingConfiguration = new RoutingConfiguration();
HaGatewayManager haGatewayManager = new HaGatewayManager(connectionManager.getJdbi(), routingConfiguration);

ProxyBackendConfiguration etl = new ProxyBackendConfiguration();
etl.setActive(false);
etl.setRoutingGroup("etl");
Expand Down