1616
1717package org .springframework .ai .chat .memory .jdbc ;
1818
19+ import org .springframework .ai .chat .memory .ChatMemory ;
20+ import org .springframework .ai .chat .messages .*;
21+ import org .springframework .boot .autoconfigure .condition .ConditionalOnMissingClass ;
22+ import org .springframework .boot .jdbc .DatabaseDriver ;
23+ import org .springframework .jdbc .core .BatchPreparedStatementSetter ;
24+ import org .springframework .jdbc .core .JdbcTemplate ;
25+ import org .springframework .jdbc .core .RowMapper ;
26+ import org .springframework .stereotype .Component ;
27+ import org .springframework .util .Assert ;
28+
29+ import java .sql .Connection ;
1930import java .sql .PreparedStatement ;
2031import java .sql .ResultSet ;
2132import java .sql .SQLException ;
2233import java .util .List ;
2334
24- import org .springframework .ai .chat .memory .ChatMemory ;
25- import org .springframework .ai .chat .messages .AssistantMessage ;
26- import org .springframework .ai .chat .messages .Message ;
27- import org .springframework .ai .chat .messages .MessageType ;
28- import org .springframework .ai .chat .messages .SystemMessage ;
29- import org .springframework .ai .chat .messages .UserMessage ;
30- import org .springframework .jdbc .core .BatchPreparedStatementSetter ;
31- import org .springframework .jdbc .core .JdbcTemplate ;
32- import org .springframework .jdbc .core .RowMapper ;
35+ import static org .springframework .boot .jdbc .DatabaseDriver .SQLSERVER ;
3336
3437/**
3538 * An implementation of {@link ChatMemory} for JDBC. Creating an instance of
3639 * JdbcChatMemory example:
3740 * <code>JdbcChatMemory.create(JdbcChatMemoryConfig.builder().jdbcTemplate(jdbcTemplate).build());</code>
3841 *
3942 * @author Jonathan Leijendekker
43+ * @author Xavier Chopin
4044 * @since 1.0.0
4145 */
4246public class JdbcChatMemory implements ChatMemory {
@@ -45,14 +49,33 @@ public class JdbcChatMemory implements ChatMemory {
4549 INSERT INTO ai_chat_memory (conversation_id, content, type) VALUES (?, ?, ?)""" ;
4650
4751 private static final String QUERY_GET = """
48- SELECT content, type FROM ai_chat_memory WHERE conversation_id = ? ORDER BY "timestamp" DESC LIMIT ?""" ;
52+ SELECT content, type \
53+ FROM ai_chat_memory \
54+ WHERE conversation_id = ? \
55+ ORDER BY "timestamp" DESC \
56+ LIMIT ?
57+ """ ;
58+
59+ private static final String MSSQL_QUERY_GET = """
60+ SELECT content, type \
61+ FROM ( \
62+ SELECT TOP (?) content, type, [timestamp] \
63+ FROM ai_chat_memory \
64+ WHERE conversation_id = ? \
65+ ORDER BY [timestamp] DESC \
66+ ) AS recent \
67+ ORDER BY [timestamp] ASC \
68+ """ ;
4969
5070 private static final String QUERY_CLEAR = "DELETE FROM ai_chat_memory WHERE conversation_id = ?" ;
5171
5272 private final JdbcTemplate jdbcTemplate ;
5373
74+ private final DatabaseDriver driver ;
75+
5476 public JdbcChatMemory (JdbcChatMemoryConfig config ) {
5577 this .jdbcTemplate = config .getJdbcTemplate ();
78+ this .driver = this .detectDialect (this .jdbcTemplate );
5679 }
5780
5881 public static JdbcChatMemory create (JdbcChatMemoryConfig config ) {
@@ -66,16 +89,19 @@ public void add(String conversationId, List<Message> messages) {
6689
6790 @ Override
6891 public List <Message > get (String conversationId , int lastN ) {
69- return this .jdbcTemplate .query (QUERY_GET , new MessageRowMapper (), conversationId , lastN );
92+ return switch (driver ) {
93+ case SQLSERVER -> this .jdbcTemplate .query (MSSQL_QUERY_GET , new MessageRowMapper (), lastN , conversationId );
94+ default -> this .jdbcTemplate .query (QUERY_GET , new MessageRowMapper (), conversationId , lastN );
95+ };
7096 }
7197
7298 @ Override
7399 public void clear (String conversationId ) {
74100 this .jdbcTemplate .update (QUERY_CLEAR , conversationId );
75101 }
76102
77- private record AddBatchPreparedStatement (String conversationId ,
78- List < Message > messages ) implements BatchPreparedStatementSetter {
103+ private record AddBatchPreparedStatement (String conversationId , List < Message > messages )
104+ implements BatchPreparedStatementSetter {
79105 @ Override
80106 public void setValues (PreparedStatement ps , int i ) throws SQLException {
81107 var message = this .messages .get (i );
@@ -108,4 +134,15 @@ public Message mapRow(ResultSet rs, int i) throws SQLException {
108134
109135 }
110136
137+ private DatabaseDriver detectDialect (JdbcTemplate jdbcTemplate ) {
138+ try {
139+ Assert .notNull (jdbcTemplate .getDataSource (), "jdbcTemplate.dataSource must not be null" );
140+ try (Connection conn = jdbcTemplate .getDataSource ().getConnection ()) {
141+ String url = conn .getMetaData ().getURL ();
142+ return DatabaseDriver .fromJdbcUrl (url );
143+ }
144+ } catch (SQLException ex ) {
145+ throw new IllegalStateException ("Impossible to detect dialect" , ex );
146+ }
147+ }
111148}
0 commit comments