> genericObjectPool) {
+ this.genericObjectPool = genericObjectPool;
+ }
}
diff --git a/src/main/java/org/crazycake/shiro/RedisSessionDAO.java b/src/main/java/org/crazycake/shiro/RedisSessionDAO.java
index 6069a85b..08377a3e 100644
--- a/src/main/java/org/crazycake/shiro/RedisSessionDAO.java
+++ b/src/main/java/org/crazycake/shiro/RedisSessionDAO.java
@@ -3,7 +3,8 @@
import org.apache.shiro.session.Session;
import org.apache.shiro.session.UnknownSessionException;
import org.apache.shiro.session.mgt.eis.AbstractSessionDAO;
-import org.crazycake.shiro.common.SessionInMemory;
+import org.crazycake.shiro.cache.CacheStrategy;
+import org.crazycake.shiro.cache.MapCacheStrategy;
import org.crazycake.shiro.exception.SerializationException;
import org.crazycake.shiro.serializer.ObjectSerializer;
import org.crazycake.shiro.serializer.RedisSerializer;
@@ -12,39 +13,34 @@
import org.slf4j.LoggerFactory;
import java.io.Serializable;
-import java.util.*;
+import java.util.Collection;
+import java.util.HashSet;
+import java.util.Set;
/**
* Used for setting/getting authentication information from Redis
*/
public class RedisSessionDAO extends AbstractSessionDAO {
- private static Logger logger = LoggerFactory.getLogger(RedisSessionDAO.class);
+ private static final Logger LOGGER = LoggerFactory.getLogger(RedisSessionDAO.class);
private static final String DEFAULT_SESSION_KEY_PREFIX = "shiro:session:";
private String keyPrefix = DEFAULT_SESSION_KEY_PREFIX;
- /**
- * doReadSession be called about 10 times when login.
- * Save Session in ThreadLocal to resolve this problem. sessionInMemoryTimeout is expiration of Session in ThreadLocal.
- * The default value is 1000 milliseconds (1s).
- * Most of time, you don't need to change it.
- *
- * You can turn it off by setting sessionInMemoryEnabled to false
- */
- private static final long DEFAULT_SESSION_IN_MEMORY_TIMEOUT = 1000L;
- private long sessionInMemoryTimeout = DEFAULT_SESSION_IN_MEMORY_TIMEOUT;
-
private static final boolean DEFAULT_SESSION_IN_MEMORY_ENABLED = true;
private boolean sessionInMemoryEnabled = DEFAULT_SESSION_IN_MEMORY_ENABLED;
- private static ThreadLocal sessionsInThread = new ThreadLocal();
+ /**
+ * The cache strategy implementation (e.g., MapCacheStrategy) for managing in-memory session storage
+ * initialized with a default timeout of {@link CacheStrategy#DEFAULT_SESSION_IN_MEMORY_TIMEOUT} milliseconds.
+ */
+ private CacheStrategy cacheStrategy = new MapCacheStrategy();
/**
* expire time in seconds.
* NOTE: Please make sure expire is longer than session.getTimeout(),
* otherwise you might need the issue that session in Redis got erased when the Session is still available
- *
+ *
* DEFAULT_EXPIRE: use the timeout of session instead of setting it by yourself
* NO_EXPIRE: never expire
*/
@@ -87,7 +83,7 @@ public void update(Session session) throws UnknownSessionException {
private void saveSession(Session session) throws UnknownSessionException {
if (session == null || session.getId() == null) {
- logger.error("session or session id is null");
+ LOGGER.error("session or session id is null");
throw new UnknownSessionException("session or session id is null");
}
byte[] key;
@@ -96,7 +92,7 @@ private void saveSession(Session session) throws UnknownSessionException {
key = keySerializer.serialize(getRedisSessionKey(session.getId()));
value = valueSerializer.serialize(session);
} catch (SerializationException e) {
- logger.error("serialize session error. session id=" + session.getId());
+ LOGGER.error("serialize session error. session id=" + session.getId());
throw new UnknownSessionException(e);
}
if (expire == DEFAULT_EXPIRE) {
@@ -104,11 +100,11 @@ private void saveSession(Session session) throws UnknownSessionException {
return;
}
if (expire != NO_EXPIRE && expire * MILLISECONDS_IN_A_SECOND < session.getTimeout()) {
- logger.warn("Redis session expire time: "
- + (expire * MILLISECONDS_IN_A_SECOND)
- + " is less than Session timeout: "
- + session.getTimeout()
- + " . It may cause some problems.");
+ LOGGER.warn("Redis session expire time: "
+ + (expire * MILLISECONDS_IN_A_SECOND)
+ + " is less than Session timeout: "
+ + session.getTimeout()
+ + " . It may cause some problems.");
}
redisManager.set(key, value, expire);
}
@@ -123,7 +119,7 @@ public void delete(Session session) {
this.removeExpiredSessionInMemory();
}
if (session == null || session.getId() == null) {
- logger.error("session or session id is null");
+ LOGGER.error("session or session id is null");
return;
}
if (this.sessionInMemoryEnabled) {
@@ -132,7 +128,7 @@ public void delete(Session session) {
try {
redisManager.del(keySerializer.serialize(getRedisSessionKey(session.getId())));
} catch (SerializationException e) {
- logger.error("delete session error. session id=" + session.getId());
+ LOGGER.error("delete session error. session id=" + session.getId());
}
}
@@ -158,7 +154,7 @@ public Collection getActiveSessions() {
}
}
} catch (SerializationException e) {
- logger.error("get active sessions error.");
+ LOGGER.error("get active sessions error.");
}
return sessions;
}
@@ -169,7 +165,7 @@ protected Serializable doCreate(Session session) {
this.removeExpiredSessionInMemory();
}
if (session == null) {
- logger.error("session is null");
+ LOGGER.error("session is null");
throw new UnknownSessionException("session is null");
}
Serializable sessionId = this.generateSessionId(session);
@@ -189,7 +185,7 @@ protected Session doReadSession(Serializable sessionId) {
this.removeExpiredSessionInMemory();
}
if (sessionId == null) {
- logger.warn("session id is null");
+ LOGGER.warn("session id is null");
return null;
}
if (this.sessionInMemoryEnabled) {
@@ -201,87 +197,31 @@ protected Session doReadSession(Serializable sessionId) {
Session session = null;
try {
String sessionRedisKey = getRedisSessionKey(sessionId);
- logger.debug("read session: " + sessionRedisKey + " from Redis");
+ LOGGER.debug("read session: " + sessionRedisKey + " from Redis");
session = (Session) valueSerializer.deserialize(redisManager.get(keySerializer.serialize(sessionRedisKey)));
if (this.sessionInMemoryEnabled) {
setSessionToThreadLocal(sessionId, session);
}
} catch (SerializationException e) {
- logger.error("read session error. sessionId: " + sessionId);
+ LOGGER.error("read session error. sessionId: " + sessionId);
}
return session;
}
private void setSessionToThreadLocal(Serializable sessionId, Session session) {
- this.initSessionsInThread();
- Map sessionMap = (Map) sessionsInThread.get();
- sessionMap.put(sessionId, this.createSessionInMemory(session));
+ this.cacheStrategy.put(sessionId, session);
}
private void delSessionFromThreadLocal(Serializable sessionId) {
- Map sessionMap = (Map) sessionsInThread.get();
- if (sessionMap == null) {
- return;
- }
- sessionMap.remove(sessionId);
- }
-
- private SessionInMemory createSessionInMemory(Session session) {
- SessionInMemory sessionInMemory = new SessionInMemory();
- sessionInMemory.setCreateTime(new Date());
- sessionInMemory.setSession(session);
- return sessionInMemory;
- }
-
- private void initSessionsInThread() {
- Map sessionMap = (Map) sessionsInThread.get();
- if (sessionMap == null) {
- sessionMap = new HashMap();
- sessionsInThread.set(sessionMap);
- }
+ this.cacheStrategy.remove(sessionId);
}
private void removeExpiredSessionInMemory() {
- Map sessionMap = (Map) sessionsInThread.get();
- if (sessionMap == null) {
- return;
- }
- Iterator it = sessionMap.keySet().iterator();
- while (it.hasNext()) {
- Serializable sessionId = it.next();
- SessionInMemory sessionInMemory = sessionMap.get(sessionId);
- if (sessionInMemory == null) {
- it.remove();
- continue;
- }
- long liveTime = getSessionInMemoryLiveTime(sessionInMemory);
- if (liveTime > sessionInMemoryTimeout) {
- it.remove();
- }
- }
- if (sessionMap.size() == 0) {
- sessionsInThread.remove();
- }
+ this.cacheStrategy.removeExpired();
}
private Session getSessionFromThreadLocal(Serializable sessionId) {
- if (sessionsInThread.get() == null) {
- return null;
- }
-
- Map sessionMap = (Map) sessionsInThread.get();
- SessionInMemory sessionInMemory = sessionMap.get(sessionId);
- if (sessionInMemory == null) {
- return null;
- }
-
- logger.debug("read session from memory");
- return sessionInMemory.getSession();
- }
-
- private long getSessionInMemoryLiveTime(SessionInMemory sessionInMemory) {
- Date now = new Date();
- return now.getTime() - sessionInMemory.getCreateTime().getTime();
+ return this.cacheStrategy.get(sessionId);
}
private String getRedisSessionKey(Serializable sessionId) {
@@ -321,11 +261,11 @@ public void setValueSerializer(RedisSerializer valueSerializer) {
}
public long getSessionInMemoryTimeout() {
- return sessionInMemoryTimeout;
+ return cacheStrategy.getSessionInMemoryTimeout();
}
public void setSessionInMemoryTimeout(long sessionInMemoryTimeout) {
- this.sessionInMemoryTimeout = sessionInMemoryTimeout;
+ this.cacheStrategy.setSessionInMemoryTimeout(sessionInMemoryTimeout);
}
public int getExpire() {
@@ -344,7 +284,11 @@ public void setSessionInMemoryEnabled(boolean sessionInMemoryEnabled) {
this.sessionInMemoryEnabled = sessionInMemoryEnabled;
}
- public static ThreadLocal getSessionsInThread() {
- return sessionsInThread;
+ public CacheStrategy getCacheStrategy() {
+ return cacheStrategy;
+ }
+
+ public void setCacheStrategy(CacheStrategy cacheStrategy) {
+ this.cacheStrategy = cacheStrategy;
}
}
diff --git a/src/main/java/org/crazycake/shiro/cache/CacheStrategy.java b/src/main/java/org/crazycake/shiro/cache/CacheStrategy.java
new file mode 100644
index 00000000..16231902
--- /dev/null
+++ b/src/main/java/org/crazycake/shiro/cache/CacheStrategy.java
@@ -0,0 +1,66 @@
+package org.crazycake.shiro.cache;
+
+import org.apache.shiro.session.Session;
+
+import java.io.Serializable;
+
+/**
+ * Cache strategy
+ *
+ * @author Teamo
+ * @since 2025/3/30
+ */
+public interface CacheStrategy {
+
+ /**
+ * doReadSession be called about 10 times when login.
+ * Save Session in ThreadLocal to resolve this problem. sessionInMemoryTimeout is expiration of Session in ThreadLocal.
+ * The default value is 1000 milliseconds (1s).
+ * Most of time, you don't need to change it.
+ *
+ * You can turn it off by setting sessionInMemoryEnabled to false
+ */
+ long DEFAULT_SESSION_IN_MEMORY_TIMEOUT = 1000L;
+
+ /**
+ * Put session into cache
+ *
+ * @param sessionId session id
+ * @param session session
+ */
+ void put(Serializable sessionId, Session session);
+
+ /**
+ * Get session from cache
+ *
+ * @param sessionId session id
+ * @return session
+ */
+ Session get(Serializable sessionId);
+
+ /**
+ * Remove session from cache
+ *
+ * @param sessionId session id
+ */
+ void remove(Serializable sessionId);
+
+ /**
+ * Remove expired session from cache
+ */
+ void removeExpired();
+
+ /**
+ * Set session in memory timeout
+ *
+ * @param sessionInMemoryTimeout session in memory timeout
+ */
+ void setSessionInMemoryTimeout(long sessionInMemoryTimeout);
+
+ /**
+ * Get session in memory timeout
+ *
+ * @return session in memory timeout
+ */
+ long getSessionInMemoryTimeout();
+}
diff --git a/src/main/java/org/crazycake/shiro/cache/CaffeineCacheStrategy.java b/src/main/java/org/crazycake/shiro/cache/CaffeineCacheStrategy.java
new file mode 100644
index 00000000..54bd77e7
--- /dev/null
+++ b/src/main/java/org/crazycake/shiro/cache/CaffeineCacheStrategy.java
@@ -0,0 +1,89 @@
+package org.crazycake.shiro.cache;
+
+import com.github.benmanes.caffeine.cache.Cache;
+import com.github.benmanes.caffeine.cache.Caffeine;
+import org.apache.shiro.session.Session;
+import org.crazycake.shiro.common.SessionInMemory;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.Serializable;
+import java.util.Date;
+import java.util.concurrent.TimeUnit;
+
+/**
+ * CaffeineCacheStrategy
+ *
+ * @author Teamo
+ * @since 2025/3/30
+ */
+public class CaffeineCacheStrategy implements CacheStrategy {
+
+ private static final Logger LOGGER = LoggerFactory.getLogger(CaffeineCacheStrategy.class);
+
+ private final ThreadLocal> sessionsInThread = new ThreadLocal<>();
+
+ private long sessionInMemoryTimeout = DEFAULT_SESSION_IN_MEMORY_TIMEOUT;
+
+ @Override
+ public void put(Serializable sessionId, Session session) {
+ getCache().put(sessionId, createSessionInMemory(session));
+ }
+
+ @Override
+ public Session get(Serializable sessionId) {
+ SessionInMemory sessionInMemory = getCache().getIfPresent(sessionId);
+ if (sessionInMemory == null) {
+ return null;
+ }
+ LOGGER.debug("read session from caffeine cache");
+ return sessionInMemory.getSession();
+ }
+
+ @Override
+ public void remove(Serializable sessionId) {
+ getCache().invalidate(sessionId);
+ }
+
+ @Override
+ public void removeExpired() {
+ Cache sessionCache = sessionsInThread.get();
+ if (sessionCache == null) {
+ return;
+ }
+
+ sessionCache.cleanUp();
+
+ if (sessionCache.asMap().isEmpty()) {
+ sessionsInThread.remove();
+ }
+ }
+
+ @Override
+ public void setSessionInMemoryTimeout(long sessionInMemoryTimeout) {
+ this.sessionInMemoryTimeout = sessionInMemoryTimeout;
+ }
+
+ @Override
+ public long getSessionInMemoryTimeout() {
+ return this.sessionInMemoryTimeout;
+ }
+
+ private Cache getCache() {
+ Cache cache = sessionsInThread.get();
+ if (cache == null) {
+ cache = Caffeine.newBuilder()
+ .expireAfterWrite(this.sessionInMemoryTimeout, TimeUnit.MILLISECONDS)
+ .build();
+ sessionsInThread.set(cache);
+ }
+ return cache;
+ }
+
+ private SessionInMemory createSessionInMemory(Session session) {
+ SessionInMemory s = new SessionInMemory();
+ s.setCreateTime(new Date());
+ s.setSession(session);
+ return s;
+ }
+}
diff --git a/src/main/java/org/crazycake/shiro/cache/MapCacheStrategy.java b/src/main/java/org/crazycake/shiro/cache/MapCacheStrategy.java
new file mode 100644
index 00000000..5fa82f0f
--- /dev/null
+++ b/src/main/java/org/crazycake/shiro/cache/MapCacheStrategy.java
@@ -0,0 +1,85 @@
+package org.crazycake.shiro.cache;
+
+import org.apache.shiro.session.Session;
+import org.crazycake.shiro.common.SessionInMemory;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.Serializable;
+import java.time.Instant;
+import java.util.Date;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * MapCacheStrategy
+ *
+ * @author Teamo
+ * @since 2025/3/30
+ */
+public class MapCacheStrategy implements CacheStrategy {
+
+ private static final Logger LOGGER = LoggerFactory.getLogger(MapCacheStrategy.class);
+
+ private final ThreadLocal