diff --git a/pom.xml b/pom.xml index b0a6988d..f5b1b26c 100644 --- a/pom.xml +++ b/pom.xml @@ -48,6 +48,14 @@ 1.11.0 + + + com.github.ben-manes.caffeine + caffeine + 2.9.3 + true + + org.junit.jupiter diff --git a/src/main/java/org/crazycake/shiro/LettuceRedisClusterManager.java b/src/main/java/org/crazycake/shiro/LettuceRedisClusterManager.java index d95d5669..674be02d 100644 --- a/src/main/java/org/crazycake/shiro/LettuceRedisClusterManager.java +++ b/src/main/java/org/crazycake/shiro/LettuceRedisClusterManager.java @@ -29,7 +29,7 @@ * @author Teamo * @since 2022/05/19 */ -public class LettuceRedisClusterManager implements IRedisManager { +public class LettuceRedisClusterManager implements IRedisManager, AutoCloseable { /** * Comma-separated list of "host:port" pairs to bootstrap from. This represents an @@ -82,14 +82,18 @@ public class LettuceRedisClusterManager implements IRedisManager { */ private ClusterClientOptions clusterClientOptions = ClusterClientOptions.create(); + /** + * RedisClusterClient. + */ + private RedisClusterClient redisClusterClient; + private void initialize() { if (genericObjectPool == null) { synchronized (LettuceRedisClusterManager.class) { if (genericObjectPool == null) { - RedisClusterClient redisClusterClient = RedisClusterClient.create(getClusterRedisURI()); + redisClusterClient = RedisClusterClient.create(getClusterRedisURI()); redisClusterClient.setOptions(clusterClientOptions); - StatefulRedisClusterConnection connect = redisClusterClient.connect(new ByteArrayCodec()); - genericObjectPool = ConnectionPoolSupport.createGenericObjectPool(() -> connect, genericObjectPoolConfig); + genericObjectPool = ConnectionPoolSupport.createGenericObjectPool(() -> redisClusterClient.connect(new ByteArrayCodec()), genericObjectPoolConfig); } } } @@ -106,6 +110,12 @@ private StatefulRedisClusterConnection getStatefulConnection() { } } + private void returnObject(StatefulRedisClusterConnection connection) { + if (connection != null) { + genericObjectPool.returnObject(connection); + } + } + private List getClusterRedisURI() { Objects.requireNonNull(nodes, "nodes must not be null!"); return nodes.stream().map(node -> { @@ -128,7 +138,9 @@ public byte[] get(byte[] key) { return null; } byte[] value = null; - try (StatefulRedisClusterConnection connection = getStatefulConnection()) { + StatefulRedisClusterConnection connection = null; + try { + connection = getStatefulConnection(); if (isAsync) { RedisAdvancedClusterAsyncCommands async = connection.async(); value = LettuceFutures.awaitOrCancel(async.get(key), timeout.getSeconds(), TimeUnit.SECONDS); @@ -136,6 +148,8 @@ public byte[] get(byte[] key) { RedisAdvancedClusterCommands sync = connection.sync(); value = sync.get(key); } + } finally { + returnObject(connection); } return value; } @@ -145,7 +159,9 @@ public byte[] set(byte[] key, byte[] value, int expire) { if (key == null) { return null; } - try (StatefulRedisClusterConnection connection = getStatefulConnection()) { + StatefulRedisClusterConnection connection = null; + try { + connection = getStatefulConnection(); if (isAsync) { RedisAdvancedClusterAsyncCommands async = connection.async(); if (expire > 0) { @@ -161,13 +177,17 @@ public byte[] set(byte[] key, byte[] value, int expire) { sync.set(key, value); } } + } finally { + returnObject(connection); } return value; } @Override public void del(byte[] key) { - try (StatefulRedisClusterConnection connection = getStatefulConnection()) { + StatefulRedisClusterConnection connection = null; + try { + connection = getStatefulConnection(); if (isAsync) { RedisAdvancedClusterAsyncCommands async = connection.async(); async.del(key); @@ -175,14 +195,17 @@ public void del(byte[] key) { RedisAdvancedClusterCommands sync = connection.sync(); sync.del(key); } + } finally { + returnObject(connection); } } @Override public Long dbSize(byte[] pattern) { AtomicLong dbSize = new AtomicLong(0L); - - try (StatefulRedisClusterConnection connection = getStatefulConnection()) { + StatefulRedisClusterConnection connection = null; + try { + connection = getStatefulConnection(); if (isAsync) { RedisAdvancedClusterAsyncCommands async = connection.async(); Partitions parse = ClusterPartitionParser.parse(LettuceFutures.awaitOrCancel(async.clusterNodes(), timeout.getSeconds(), TimeUnit.SECONDS)); @@ -214,6 +237,8 @@ public Long dbSize(byte[] pattern) { } }); } + } finally { + returnObject(connection); } return dbSize.get(); } @@ -221,8 +246,9 @@ public Long dbSize(byte[] pattern) { @Override public Set keys(byte[] pattern) { Set keys = new HashSet<>(); - - try (StatefulRedisClusterConnection connection = getStatefulConnection()) { + StatefulRedisClusterConnection connection = null; + try { + connection = getStatefulConnection(); if (isAsync) { RedisAdvancedClusterAsyncCommands async = connection.async(); Partitions parse = ClusterPartitionParser.parse(LettuceFutures.awaitOrCancel(async.clusterNodes(), timeout.getSeconds(), TimeUnit.SECONDS)); @@ -254,10 +280,22 @@ public Set keys(byte[] pattern) { } }); } + } finally { + returnObject(connection); } return keys; } + @Override + public void close() throws Exception { + if (genericObjectPool != null) { + genericObjectPool.close(); + } + if (redisClusterClient != null) { + redisClusterClient.shutdown(); + } + } + public List getNodes() { return nodes; } diff --git a/src/main/java/org/crazycake/shiro/LettuceRedisManager.java b/src/main/java/org/crazycake/shiro/LettuceRedisManager.java index 5913ad25..79230fa5 100644 --- a/src/main/java/org/crazycake/shiro/LettuceRedisManager.java +++ b/src/main/java/org/crazycake/shiro/LettuceRedisManager.java @@ -16,7 +16,7 @@ * @author Teamo * @since 2022/05/18 */ -public class LettuceRedisManager extends AbstractLettuceRedisManager { +public class LettuceRedisManager extends AbstractLettuceRedisManager> { /** * Redis server host. @@ -33,14 +33,18 @@ public class LettuceRedisManager extends AbstractLettuceRedisManager { */ private volatile GenericObjectPool> genericObjectPool; - @SuppressWarnings({"unchecked", "rawtypes"}) + /** + * Redis client. + */ + private RedisClient redisClient; + private void initialize() { if (genericObjectPool == null) { synchronized (LettuceRedisManager.class) { if (genericObjectPool == null) { - RedisClient redisClient = RedisClient.create(createRedisURI()); + redisClient = RedisClient.create(createRedisURI()); redisClient.setOptions(getClientOptions()); - GenericObjectPoolConfig genericObjectPoolConfig = getGenericObjectPoolConfig(); + GenericObjectPoolConfig> genericObjectPoolConfig = getGenericObjectPoolConfig(); genericObjectPool = ConnectionPoolSupport.createGenericObjectPool(() -> redisClient.connect(new ByteArrayCodec()), genericObjectPoolConfig); } } @@ -72,6 +76,23 @@ protected StatefulRedisConnection getStatefulConnection() { } } + @Override + protected void returnObject(StatefulRedisConnection connect) { + if (connect != null) { + genericObjectPool.returnObject(connect); + } + } + + @Override + public void close() throws Exception { + if (genericObjectPool != null) { + genericObjectPool.close(); + } + if (redisClient != null) { + redisClient.shutdown(); + } + } + public String getHost() { return host; } @@ -87,4 +108,12 @@ public int getPort() { public void setPort(int port) { this.port = port; } + + public GenericObjectPool> getGenericObjectPool() { + return genericObjectPool; + } + + public void setGenericObjectPool(GenericObjectPool> genericObjectPool) { + this.genericObjectPool = genericObjectPool; + } } diff --git a/src/main/java/org/crazycake/shiro/LettuceRedisSentinelManager.java b/src/main/java/org/crazycake/shiro/LettuceRedisSentinelManager.java index c3308ab0..a5a8a789 100644 --- a/src/main/java/org/crazycake/shiro/LettuceRedisSentinelManager.java +++ b/src/main/java/org/crazycake/shiro/LettuceRedisSentinelManager.java @@ -19,7 +19,7 @@ * @author Teamo * @since 2022/05/19 */ -public class LettuceRedisSentinelManager extends AbstractLettuceRedisManager { +public class LettuceRedisSentinelManager extends AbstractLettuceRedisManager> { private static final String DEFAULT_MASTER_NAME = "mymaster"; private String masterName = DEFAULT_MASTER_NAME; @@ -35,18 +35,24 @@ public class LettuceRedisSentinelManager extends AbstractLettuceRedisManager { */ private volatile GenericObjectPool> genericObjectPool; - @SuppressWarnings({"unchecked", "rawtypes"}) + /** + * RedisClient. + */ + private RedisClient redisClient; + private void initialize() { if (genericObjectPool == null) { synchronized (LettuceRedisSentinelManager.class) { if (genericObjectPool == null) { RedisURI redisURI = this.createSentinelRedisURI(); - RedisClient redisClient = RedisClient.create(redisURI); + redisClient = RedisClient.create(); redisClient.setOptions(getClientOptions()); - StatefulRedisMasterReplicaConnection connect = MasterReplica.connect(redisClient, new ByteArrayCodec(), redisURI); - connect.setReadFrom(readFrom); - GenericObjectPoolConfig genericObjectPoolConfig = getGenericObjectPoolConfig(); - genericObjectPool = ConnectionPoolSupport.createGenericObjectPool(() -> connect, genericObjectPoolConfig); + GenericObjectPoolConfig> genericObjectPoolConfig = getGenericObjectPoolConfig(); + genericObjectPool = ConnectionPoolSupport.createGenericObjectPool(() -> { + StatefulRedisMasterReplicaConnection connect = MasterReplica.connect(redisClient, new ByteArrayCodec(), redisURI); + connect.setReadFrom(readFrom); + return connect; + }, genericObjectPoolConfig); } } } @@ -64,6 +70,23 @@ protected StatefulRedisMasterReplicaConnection getStatefulConnec } } + @Override + protected void returnObject(StatefulRedisMasterReplicaConnection connect) { + if (connect != null) { + genericObjectPool.returnObject(connect); + } + } + + @Override + public void close() throws Exception { + if (genericObjectPool != null) { + genericObjectPool.close(); + } + if (redisClient != null) { + redisClient.shutdown(); + } + } + private RedisURI createSentinelRedisURI() { Objects.requireNonNull(nodes, "nodes must not be null!"); @@ -118,4 +141,12 @@ public ReadFrom getReadFrom() { public void setReadFrom(ReadFrom readFrom) { this.readFrom = readFrom; } + + public GenericObjectPool> getGenericObjectPool() { + return genericObjectPool; + } + + public void setGenericObjectPool(GenericObjectPool> genericObjectPool) { + this.genericObjectPool = genericObjectPool; + } } diff --git a/src/main/java/org/crazycake/shiro/RedisSessionDAO.java b/src/main/java/org/crazycake/shiro/RedisSessionDAO.java index 6069a85b..08377a3e 100644 --- a/src/main/java/org/crazycake/shiro/RedisSessionDAO.java +++ b/src/main/java/org/crazycake/shiro/RedisSessionDAO.java @@ -3,7 +3,8 @@ import org.apache.shiro.session.Session; import org.apache.shiro.session.UnknownSessionException; import org.apache.shiro.session.mgt.eis.AbstractSessionDAO; -import org.crazycake.shiro.common.SessionInMemory; +import org.crazycake.shiro.cache.CacheStrategy; +import org.crazycake.shiro.cache.MapCacheStrategy; import org.crazycake.shiro.exception.SerializationException; import org.crazycake.shiro.serializer.ObjectSerializer; import org.crazycake.shiro.serializer.RedisSerializer; @@ -12,39 +13,34 @@ import org.slf4j.LoggerFactory; import java.io.Serializable; -import java.util.*; +import java.util.Collection; +import java.util.HashSet; +import java.util.Set; /** * Used for setting/getting authentication information from Redis */ public class RedisSessionDAO extends AbstractSessionDAO { - private static Logger logger = LoggerFactory.getLogger(RedisSessionDAO.class); + private static final Logger LOGGER = LoggerFactory.getLogger(RedisSessionDAO.class); private static final String DEFAULT_SESSION_KEY_PREFIX = "shiro:session:"; private String keyPrefix = DEFAULT_SESSION_KEY_PREFIX; - /** - * doReadSession be called about 10 times when login. - * Save Session in ThreadLocal to resolve this problem. sessionInMemoryTimeout is expiration of Session in ThreadLocal. - * The default value is 1000 milliseconds (1s). - * Most of time, you don't need to change it. - * - * You can turn it off by setting sessionInMemoryEnabled to false - */ - private static final long DEFAULT_SESSION_IN_MEMORY_TIMEOUT = 1000L; - private long sessionInMemoryTimeout = DEFAULT_SESSION_IN_MEMORY_TIMEOUT; - private static final boolean DEFAULT_SESSION_IN_MEMORY_ENABLED = true; private boolean sessionInMemoryEnabled = DEFAULT_SESSION_IN_MEMORY_ENABLED; - private static ThreadLocal sessionsInThread = new ThreadLocal(); + /** + * The cache strategy implementation (e.g., MapCacheStrategy) for managing in-memory session storage + * initialized with a default timeout of {@link CacheStrategy#DEFAULT_SESSION_IN_MEMORY_TIMEOUT} milliseconds. + */ + private CacheStrategy cacheStrategy = new MapCacheStrategy(); /** * expire time in seconds. * NOTE: Please make sure expire is longer than session.getTimeout(), * otherwise you might need the issue that session in Redis got erased when the Session is still available - * + *

* DEFAULT_EXPIRE: use the timeout of session instead of setting it by yourself * NO_EXPIRE: never expire */ @@ -87,7 +83,7 @@ public void update(Session session) throws UnknownSessionException { private void saveSession(Session session) throws UnknownSessionException { if (session == null || session.getId() == null) { - logger.error("session or session id is null"); + LOGGER.error("session or session id is null"); throw new UnknownSessionException("session or session id is null"); } byte[] key; @@ -96,7 +92,7 @@ private void saveSession(Session session) throws UnknownSessionException { key = keySerializer.serialize(getRedisSessionKey(session.getId())); value = valueSerializer.serialize(session); } catch (SerializationException e) { - logger.error("serialize session error. session id=" + session.getId()); + LOGGER.error("serialize session error. session id=" + session.getId()); throw new UnknownSessionException(e); } if (expire == DEFAULT_EXPIRE) { @@ -104,11 +100,11 @@ private void saveSession(Session session) throws UnknownSessionException { return; } if (expire != NO_EXPIRE && expire * MILLISECONDS_IN_A_SECOND < session.getTimeout()) { - logger.warn("Redis session expire time: " - + (expire * MILLISECONDS_IN_A_SECOND) - + " is less than Session timeout: " - + session.getTimeout() - + " . It may cause some problems."); + LOGGER.warn("Redis session expire time: " + + (expire * MILLISECONDS_IN_A_SECOND) + + " is less than Session timeout: " + + session.getTimeout() + + " . It may cause some problems."); } redisManager.set(key, value, expire); } @@ -123,7 +119,7 @@ public void delete(Session session) { this.removeExpiredSessionInMemory(); } if (session == null || session.getId() == null) { - logger.error("session or session id is null"); + LOGGER.error("session or session id is null"); return; } if (this.sessionInMemoryEnabled) { @@ -132,7 +128,7 @@ public void delete(Session session) { try { redisManager.del(keySerializer.serialize(getRedisSessionKey(session.getId()))); } catch (SerializationException e) { - logger.error("delete session error. session id=" + session.getId()); + LOGGER.error("delete session error. session id=" + session.getId()); } } @@ -158,7 +154,7 @@ public Collection getActiveSessions() { } } } catch (SerializationException e) { - logger.error("get active sessions error."); + LOGGER.error("get active sessions error."); } return sessions; } @@ -169,7 +165,7 @@ protected Serializable doCreate(Session session) { this.removeExpiredSessionInMemory(); } if (session == null) { - logger.error("session is null"); + LOGGER.error("session is null"); throw new UnknownSessionException("session is null"); } Serializable sessionId = this.generateSessionId(session); @@ -189,7 +185,7 @@ protected Session doReadSession(Serializable sessionId) { this.removeExpiredSessionInMemory(); } if (sessionId == null) { - logger.warn("session id is null"); + LOGGER.warn("session id is null"); return null; } if (this.sessionInMemoryEnabled) { @@ -201,87 +197,31 @@ protected Session doReadSession(Serializable sessionId) { Session session = null; try { String sessionRedisKey = getRedisSessionKey(sessionId); - logger.debug("read session: " + sessionRedisKey + " from Redis"); + LOGGER.debug("read session: " + sessionRedisKey + " from Redis"); session = (Session) valueSerializer.deserialize(redisManager.get(keySerializer.serialize(sessionRedisKey))); if (this.sessionInMemoryEnabled) { setSessionToThreadLocal(sessionId, session); } } catch (SerializationException e) { - logger.error("read session error. sessionId: " + sessionId); + LOGGER.error("read session error. sessionId: " + sessionId); } return session; } private void setSessionToThreadLocal(Serializable sessionId, Session session) { - this.initSessionsInThread(); - Map sessionMap = (Map) sessionsInThread.get(); - sessionMap.put(sessionId, this.createSessionInMemory(session)); + this.cacheStrategy.put(sessionId, session); } private void delSessionFromThreadLocal(Serializable sessionId) { - Map sessionMap = (Map) sessionsInThread.get(); - if (sessionMap == null) { - return; - } - sessionMap.remove(sessionId); - } - - private SessionInMemory createSessionInMemory(Session session) { - SessionInMemory sessionInMemory = new SessionInMemory(); - sessionInMemory.setCreateTime(new Date()); - sessionInMemory.setSession(session); - return sessionInMemory; - } - - private void initSessionsInThread() { - Map sessionMap = (Map) sessionsInThread.get(); - if (sessionMap == null) { - sessionMap = new HashMap(); - sessionsInThread.set(sessionMap); - } + this.cacheStrategy.remove(sessionId); } private void removeExpiredSessionInMemory() { - Map sessionMap = (Map) sessionsInThread.get(); - if (sessionMap == null) { - return; - } - Iterator it = sessionMap.keySet().iterator(); - while (it.hasNext()) { - Serializable sessionId = it.next(); - SessionInMemory sessionInMemory = sessionMap.get(sessionId); - if (sessionInMemory == null) { - it.remove(); - continue; - } - long liveTime = getSessionInMemoryLiveTime(sessionInMemory); - if (liveTime > sessionInMemoryTimeout) { - it.remove(); - } - } - if (sessionMap.size() == 0) { - sessionsInThread.remove(); - } + this.cacheStrategy.removeExpired(); } private Session getSessionFromThreadLocal(Serializable sessionId) { - if (sessionsInThread.get() == null) { - return null; - } - - Map sessionMap = (Map) sessionsInThread.get(); - SessionInMemory sessionInMemory = sessionMap.get(sessionId); - if (sessionInMemory == null) { - return null; - } - - logger.debug("read session from memory"); - return sessionInMemory.getSession(); - } - - private long getSessionInMemoryLiveTime(SessionInMemory sessionInMemory) { - Date now = new Date(); - return now.getTime() - sessionInMemory.getCreateTime().getTime(); + return this.cacheStrategy.get(sessionId); } private String getRedisSessionKey(Serializable sessionId) { @@ -321,11 +261,11 @@ public void setValueSerializer(RedisSerializer valueSerializer) { } public long getSessionInMemoryTimeout() { - return sessionInMemoryTimeout; + return cacheStrategy.getSessionInMemoryTimeout(); } public void setSessionInMemoryTimeout(long sessionInMemoryTimeout) { - this.sessionInMemoryTimeout = sessionInMemoryTimeout; + this.cacheStrategy.setSessionInMemoryTimeout(sessionInMemoryTimeout); } public int getExpire() { @@ -344,7 +284,11 @@ public void setSessionInMemoryEnabled(boolean sessionInMemoryEnabled) { this.sessionInMemoryEnabled = sessionInMemoryEnabled; } - public static ThreadLocal getSessionsInThread() { - return sessionsInThread; + public CacheStrategy getCacheStrategy() { + return cacheStrategy; + } + + public void setCacheStrategy(CacheStrategy cacheStrategy) { + this.cacheStrategy = cacheStrategy; } } diff --git a/src/main/java/org/crazycake/shiro/cache/CacheStrategy.java b/src/main/java/org/crazycake/shiro/cache/CacheStrategy.java new file mode 100644 index 00000000..16231902 --- /dev/null +++ b/src/main/java/org/crazycake/shiro/cache/CacheStrategy.java @@ -0,0 +1,66 @@ +package org.crazycake.shiro.cache; + +import org.apache.shiro.session.Session; + +import java.io.Serializable; + +/** + * Cache strategy + * + * @author Teamo + * @since 2025/3/30 + */ +public interface CacheStrategy { + + /** + * doReadSession be called about 10 times when login. + * Save Session in ThreadLocal to resolve this problem. sessionInMemoryTimeout is expiration of Session in ThreadLocal. + * The default value is 1000 milliseconds (1s). + * Most of time, you don't need to change it. + *

+ * You can turn it off by setting sessionInMemoryEnabled to false + */ + long DEFAULT_SESSION_IN_MEMORY_TIMEOUT = 1000L; + + /** + * Put session into cache + * + * @param sessionId session id + * @param session session + */ + void put(Serializable sessionId, Session session); + + /** + * Get session from cache + * + * @param sessionId session id + * @return session + */ + Session get(Serializable sessionId); + + /** + * Remove session from cache + * + * @param sessionId session id + */ + void remove(Serializable sessionId); + + /** + * Remove expired session from cache + */ + void removeExpired(); + + /** + * Set session in memory timeout + * + * @param sessionInMemoryTimeout session in memory timeout + */ + void setSessionInMemoryTimeout(long sessionInMemoryTimeout); + + /** + * Get session in memory timeout + * + * @return session in memory timeout + */ + long getSessionInMemoryTimeout(); +} diff --git a/src/main/java/org/crazycake/shiro/cache/CaffeineCacheStrategy.java b/src/main/java/org/crazycake/shiro/cache/CaffeineCacheStrategy.java new file mode 100644 index 00000000..54bd77e7 --- /dev/null +++ b/src/main/java/org/crazycake/shiro/cache/CaffeineCacheStrategy.java @@ -0,0 +1,89 @@ +package org.crazycake.shiro.cache; + +import com.github.benmanes.caffeine.cache.Cache; +import com.github.benmanes.caffeine.cache.Caffeine; +import org.apache.shiro.session.Session; +import org.crazycake.shiro.common.SessionInMemory; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.Serializable; +import java.util.Date; +import java.util.concurrent.TimeUnit; + +/** + * CaffeineCacheStrategy + * + * @author Teamo + * @since 2025/3/30 + */ +public class CaffeineCacheStrategy implements CacheStrategy { + + private static final Logger LOGGER = LoggerFactory.getLogger(CaffeineCacheStrategy.class); + + private final ThreadLocal> sessionsInThread = new ThreadLocal<>(); + + private long sessionInMemoryTimeout = DEFAULT_SESSION_IN_MEMORY_TIMEOUT; + + @Override + public void put(Serializable sessionId, Session session) { + getCache().put(sessionId, createSessionInMemory(session)); + } + + @Override + public Session get(Serializable sessionId) { + SessionInMemory sessionInMemory = getCache().getIfPresent(sessionId); + if (sessionInMemory == null) { + return null; + } + LOGGER.debug("read session from caffeine cache"); + return sessionInMemory.getSession(); + } + + @Override + public void remove(Serializable sessionId) { + getCache().invalidate(sessionId); + } + + @Override + public void removeExpired() { + Cache sessionCache = sessionsInThread.get(); + if (sessionCache == null) { + return; + } + + sessionCache.cleanUp(); + + if (sessionCache.asMap().isEmpty()) { + sessionsInThread.remove(); + } + } + + @Override + public void setSessionInMemoryTimeout(long sessionInMemoryTimeout) { + this.sessionInMemoryTimeout = sessionInMemoryTimeout; + } + + @Override + public long getSessionInMemoryTimeout() { + return this.sessionInMemoryTimeout; + } + + private Cache getCache() { + Cache cache = sessionsInThread.get(); + if (cache == null) { + cache = Caffeine.newBuilder() + .expireAfterWrite(this.sessionInMemoryTimeout, TimeUnit.MILLISECONDS) + .build(); + sessionsInThread.set(cache); + } + return cache; + } + + private SessionInMemory createSessionInMemory(Session session) { + SessionInMemory s = new SessionInMemory(); + s.setCreateTime(new Date()); + s.setSession(session); + return s; + } +} diff --git a/src/main/java/org/crazycake/shiro/cache/MapCacheStrategy.java b/src/main/java/org/crazycake/shiro/cache/MapCacheStrategy.java new file mode 100644 index 00000000..5fa82f0f --- /dev/null +++ b/src/main/java/org/crazycake/shiro/cache/MapCacheStrategy.java @@ -0,0 +1,85 @@ +package org.crazycake.shiro.cache; + +import org.apache.shiro.session.Session; +import org.crazycake.shiro.common.SessionInMemory; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.Serializable; +import java.time.Instant; +import java.util.Date; +import java.util.HashMap; +import java.util.Map; + +/** + * MapCacheStrategy + * + * @author Teamo + * @since 2025/3/30 + */ +public class MapCacheStrategy implements CacheStrategy { + + private static final Logger LOGGER = LoggerFactory.getLogger(MapCacheStrategy.class); + + private final ThreadLocal> sessionsInThread = ThreadLocal.withInitial(HashMap::new); + + private long sessionInMemoryTimeout = DEFAULT_SESSION_IN_MEMORY_TIMEOUT; + + @Override + public void put(Serializable sessionId, Session session) { + sessionsInThread.get().put(sessionId, createSessionInMemory(session)); + } + + @Override + public Session get(Serializable sessionId) { + SessionInMemory sessionInMemory = sessionsInThread.get().get(sessionId); + if (sessionInMemory == null) { + return null; + } + LOGGER.debug("read session from map cache"); + return sessionInMemory.getSession(); + } + + @Override + public void remove(Serializable sessionId) { + sessionsInThread.get().remove(sessionId); + } + + @Override + public void removeExpired() { + Map sessionMap = sessionsInThread.get(); + if (sessionMap == null) { + return; + } + + sessionMap.keySet().removeIf(id -> { + SessionInMemory sessionInMemory = sessionMap.get(id); + return sessionInMemory == null || getSessionInMemoryLiveTime(sessionInMemory) > sessionInMemoryTimeout; + }); + + if (sessionMap.isEmpty()) { + sessionsInThread.remove(); + } + } + + @Override + public void setSessionInMemoryTimeout(long sessionInMemoryTimeout) { + this.sessionInMemoryTimeout = sessionInMemoryTimeout; + } + + @Override + public long getSessionInMemoryTimeout() { + return this.sessionInMemoryTimeout; + } + + private SessionInMemory createSessionInMemory(Session session) { + SessionInMemory sessionInMemory = new SessionInMemory(); + sessionInMemory.setCreateTime(new Date()); + sessionInMemory.setSession(session); + return sessionInMemory; + } + + private long getSessionInMemoryLiveTime(SessionInMemory sessionInMemory) { + return Instant.now().toEpochMilli() - sessionInMemory.getCreateTime().getTime(); + } +} diff --git a/src/main/java/org/crazycake/shiro/common/AbstractLettuceRedisManager.java b/src/main/java/org/crazycake/shiro/common/AbstractLettuceRedisManager.java index 44ac427b..ea537ba4 100644 --- a/src/main/java/org/crazycake/shiro/common/AbstractLettuceRedisManager.java +++ b/src/main/java/org/crazycake/shiro/common/AbstractLettuceRedisManager.java @@ -16,7 +16,7 @@ * @author Teamo * @since 2022/05/19 */ -public abstract class AbstractLettuceRedisManager implements IRedisManager { +public abstract class AbstractLettuceRedisManager> implements IRedisManager, AutoCloseable { /** * Default value of count. @@ -56,16 +56,21 @@ public abstract class AbstractLettuceRedisManager implements IRedisManager { /** * genericObjectPoolConfig used to initialize GenericObjectPoolConfig object. */ - @SuppressWarnings("rawtypes") - private GenericObjectPoolConfig genericObjectPoolConfig = new GenericObjectPoolConfig<>(); + private GenericObjectPoolConfig genericObjectPoolConfig = new GenericObjectPoolConfig<>(); /** * Get a stateful connection. * * @return T */ - @SuppressWarnings("rawtypes") - protected abstract StatefulRedisConnection getStatefulConnection(); + protected abstract T getStatefulConnection(); + + /** + * Return a stateful connection. + * + * @param connect T + */ + protected abstract void returnObject(T connect); public Duration getTimeout() { return timeout; @@ -115,22 +120,23 @@ public void setClientOptions(ClientOptions clientOptions) { this.clientOptions = clientOptions; } - public GenericObjectPoolConfig getGenericObjectPoolConfig() { + public GenericObjectPoolConfig getGenericObjectPoolConfig() { return genericObjectPoolConfig; } - public void setGenericObjectPoolConfig(GenericObjectPoolConfig genericObjectPoolConfig) { + public void setGenericObjectPoolConfig(GenericObjectPoolConfig genericObjectPoolConfig) { this.genericObjectPoolConfig = genericObjectPoolConfig; } @Override - @SuppressWarnings("unchecked") public byte[] get(byte[] key) { if (key == null) { return null; } byte[] value = null; - try (StatefulRedisConnection connect = getStatefulConnection()) { + T connect = null; + try { + connect = getStatefulConnection(); if (isAsync) { RedisAsyncCommands async = connect.async(); RedisFuture redisFuture = async.get(key); @@ -139,17 +145,20 @@ public byte[] get(byte[] key) { RedisCommands sync = connect.sync(); value = sync.get(key); } + } finally { + returnObject(connect); } return value; } @Override - @SuppressWarnings({"unchecked"}) public byte[] set(byte[] key, byte[] value, int expire) { if (key == null) { return null; } - try (StatefulRedisConnection connect = getStatefulConnection()) { + T connect = null; + try { + connect = getStatefulConnection(); if (isAsync) { RedisAsyncCommands async = connect.async(); if (expire > 0) { @@ -165,14 +174,17 @@ public byte[] set(byte[] key, byte[] value, int expire) { sync.set(key, value); } } + } finally { + returnObject(connect); } return value; } @Override - @SuppressWarnings("unchecked") public void del(byte[] key) { - try (StatefulRedisConnection connect = getStatefulConnection()) { + T connect = null; + try { + connect = getStatefulConnection(); if (isAsync) { RedisAsyncCommands async = connect.async(); async.del(key); @@ -180,37 +192,45 @@ public void del(byte[] key) { RedisCommands sync = connect.sync(); sync.del(key); } + } finally { + returnObject(connect); } } @Override - @SuppressWarnings("unchecked") public Long dbSize(byte[] pattern) { long dbSize = 0L; KeyScanCursor scanCursor = new KeyScanCursor<>(); scanCursor.setCursor(ScanCursor.INITIAL.getCursor()); ScanArgs scanArgs = ScanArgs.Builder.matches(pattern).limit(count); - try (StatefulRedisConnection connect = getStatefulConnection()) { + T connect = null; + try { + connect = getStatefulConnection(); while (!scanCursor.isFinished()) { scanCursor = getKeyScanCursor(connect, scanCursor, scanArgs); dbSize += scanCursor.getKeys().size(); } + } finally { + returnObject(connect); } return dbSize; } @Override - @SuppressWarnings("unchecked") public Set keys(byte[] pattern) { Set keys = new HashSet<>(); KeyScanCursor scanCursor = new KeyScanCursor<>(); scanCursor.setCursor(ScanCursor.INITIAL.getCursor()); ScanArgs scanArgs = ScanArgs.Builder.matches(pattern).limit(count); - try (StatefulRedisConnection connect = getStatefulConnection()) { + T connect = null; + try { + connect = getStatefulConnection(); while (!scanCursor.isFinished()) { scanCursor = getKeyScanCursor(connect, scanCursor, scanArgs); keys.addAll(scanCursor.getKeys()); } + } finally { + returnObject(connect); } return keys; } diff --git a/src/test/java/org/crazycake/shiro/CacheStrategyTest.java b/src/test/java/org/crazycake/shiro/CacheStrategyTest.java new file mode 100644 index 00000000..c8dfb2cd --- /dev/null +++ b/src/test/java/org/crazycake/shiro/CacheStrategyTest.java @@ -0,0 +1,166 @@ +package org.crazycake.shiro; + +import org.apache.shiro.session.InvalidSessionException; +import org.apache.shiro.session.mgt.SimpleSession; +import org.crazycake.shiro.cache.CacheStrategy; +import org.crazycake.shiro.cache.CaffeineCacheStrategy; +import org.crazycake.shiro.cache.MapCacheStrategy; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Date; +import java.util.List; +import java.util.concurrent.CountDownLatch; + +import static org.mockito.Mockito.mock; + +/** + * @author Teamo + * @since 2025/3/31 + */ +public class CacheStrategyTest { + + private IRedisManager redisManager; + + @BeforeEach + public void setUp() { + redisManager = mock(IRedisManager.class); + } + + @Test + public void testCaffeinePerformance() throws InterruptedException { + RedisSessionDAO caffeineDAO = mountRedisSessionDAOWithCacheStrategy(new CaffeineCacheStrategy(), 3000); + measurePerformance(caffeineDAO, "CaffeineCache"); + } + + @Test + public void testMapPerformance() throws InterruptedException { + RedisSessionDAO mapDAO = mountRedisSessionDAOWithCacheStrategy(new MapCacheStrategy(), 3000); + measurePerformance(mapDAO, "MapCache"); + } + + private RedisSessionDAO mountRedisSessionDAOWithCacheStrategy( + CacheStrategy cacheStrategy, + Integer expire) { + RedisSessionDAO redisSessionDAO = new RedisSessionDAO(); + redisSessionDAO.setExpire(expire); + redisSessionDAO.setKeyPrefix("student:"); + redisSessionDAO.setRedisManager(redisManager); + redisSessionDAO.setCacheStrategy(cacheStrategy); + return redisSessionDAO; + } + + private void measurePerformance(RedisSessionDAO sessionDAO, String cacheType) throws InterruptedException { + Runtime runtime = Runtime.getRuntime(); + int threadsSize = 200; + CountDownLatch cdl = new CountDownLatch(threadsSize); + Runnable runnable = () -> { + // 调整循环次数以放大性能差异 + int loopCount = 15; + StudentSession session = null; + for (int i = 0; i < loopCount; i++) { + session = new StudentSession(i, 2000); + + sessionDAO.update(session); + + sessionDAO.doReadSession(session.getId()); + } + cdl.countDown(); + }; + List threads = new ArrayList<>(); + Thread thread = null; + for (int i = 0; i < threadsSize; i++) { + thread = new Thread(runnable); + threads.add(thread); + } + + System.gc(); + Thread.sleep(100); + + long initialUsedMemory = runtime.totalMemory() - runtime.freeMemory(); + System.out.println(cacheType + " initial used memory: " + initialUsedMemory / 1024 / 1024 + "MB"); + long startTime = System.currentTimeMillis(); + + threads.parallelStream().forEach(Thread::start); + + cdl.await(); + + long duration = System.currentTimeMillis() - startTime; + long finalUsedMemory = runtime.totalMemory() - runtime.freeMemory(); + System.out.println(cacheType + " total used time: " + duration + "ms"); + System.out.println(cacheType + " final used memory: " + ((finalUsedMemory / 1024 / 1024) - (initialUsedMemory / 1024 / 1024)) + "MB"); + } + + static class StudentSession extends SimpleSession { + private final Integer id; + private final long timeout; + + public StudentSession(Integer id, long timeout) { + this.id = id; + this.timeout = timeout; + } + + @Override + public Serializable getId() { + return id; + } + + @Override + public Date getStartTimestamp() { + return null; + } + + @Override + public Date getLastAccessTime() { + return null; + } + + @Override + public long getTimeout() throws InvalidSessionException { + return timeout; + } + + @Override + public void setTimeout(long l) throws InvalidSessionException { + + } + + @Override + public String getHost() { + return null; + } + + @Override + public void touch() throws InvalidSessionException { + + } + + @Override + public void stop() throws InvalidSessionException { + + } + + @Override + public Collection getAttributeKeys() throws InvalidSessionException { + return null; + } + + @Override + public Object getAttribute(Object o) throws InvalidSessionException { + return null; + } + + @Override + public void setAttribute(Object o, Object o1) throws InvalidSessionException { + + } + + @Override + public Object removeAttribute(Object o) throws InvalidSessionException { + return null; + } + } +}