diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java index fe38b2589..7b6ce4717 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java @@ -11,7 +11,6 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.locks.ReentrantLock; import org.slf4j.Logger; @@ -22,8 +21,10 @@ import io.modelcontextprotocol.common.McpTransportContext; import io.modelcontextprotocol.server.McpTransportContextExtractor; import io.modelcontextprotocol.spec.HttpHeaders; +import io.modelcontextprotocol.spec.InMemoryMcpSessionStore; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSessionStore; import io.modelcontextprotocol.spec.McpStreamableServerSession; import io.modelcontextprotocol.spec.McpStreamableServerTransport; import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; @@ -104,9 +105,9 @@ public class HttpServletStreamableServerTransportProvider extends HttpServlet private McpStreamableServerSession.Factory sessionFactory; /** - * Map of active client sessions, keyed by mcp-session-id. + * Store for active client sessions, keyed by mcp-session-id. */ - private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); + private final McpSessionStore sessionStore; private McpTransportContextExtractor contextExtractor; @@ -141,22 +142,25 @@ public class HttpServletStreamableServerTransportProvider extends HttpServlet */ private HttpServletStreamableServerTransportProvider(McpJsonMapper jsonMapper, String mcpEndpoint, boolean disallowDelete, McpTransportContextExtractor contextExtractor, - Duration keepAliveInterval, ServerTransportSecurityValidator securityValidator) { + Duration keepAliveInterval, ServerTransportSecurityValidator securityValidator, + McpSessionStore sessionStore) { Assert.notNull(jsonMapper, "JsonMapper must not be null"); Assert.notNull(mcpEndpoint, "MCP endpoint must not be null"); Assert.notNull(contextExtractor, "Context extractor must not be null"); Assert.notNull(securityValidator, "Security validator must not be null"); + Assert.notNull(sessionStore, "Session store must not be null"); this.jsonMapper = jsonMapper; this.mcpEndpoint = mcpEndpoint; this.disallowDelete = disallowDelete; this.contextExtractor = contextExtractor; this.securityValidator = securityValidator; + this.sessionStore = sessionStore; if (keepAliveInterval != null) { this.keepAliveScheduler = KeepAliveScheduler - .builder(() -> (isClosing) ? Flux.empty() : Flux.fromIterable(sessions.values())) + .builder(() -> (isClosing) ? Flux.empty() : Flux.fromIterable(sessionStore.values())) .initialDelay(keepAliveInterval) .interval(keepAliveInterval) .build(); @@ -187,15 +191,15 @@ public void setSessionFactory(McpStreamableServerSession.Factory sessionFactory) */ @Override public Mono notifyClients(String method, Object params) { - if (this.sessions.isEmpty()) { + if (this.sessionStore.isEmpty()) { logger.debug("No active sessions to broadcast message to"); return Mono.empty(); } - logger.debug("Attempting to broadcast message to {} active sessions", this.sessions.size()); + logger.debug("Attempting to broadcast message to {} active sessions", this.sessionStore.size()); return Mono.fromRunnable(() -> { - this.sessions.values().parallelStream().forEach(session -> { + this.sessionStore.values().parallelStream().forEach(session -> { try { session.sendNotification(method, params).block(); } @@ -209,7 +213,7 @@ public Mono notifyClients(String method, Object params) { @Override public Mono notifyClient(String sessionId, String method, Object params) { return Mono.defer(() -> { - McpStreamableServerSession session = this.sessions.get(sessionId); + McpStreamableServerSession session = this.sessionStore.get(sessionId); if (session == null) { logger.debug("Session {} not found", sessionId); return Mono.empty(); @@ -226,9 +230,9 @@ public Mono notifyClient(String sessionId, String method, Object params) { public Mono closeGracefully() { return Mono.fromRunnable(() -> { this.isClosing = true; - logger.debug("Initiating graceful shutdown with {} active sessions", this.sessions.size()); + logger.debug("Initiating graceful shutdown with {} active sessions", this.sessionStore.size()); - this.sessions.values().parallelStream().forEach(session -> { + this.sessionStore.values().parallelStream().forEach(session -> { try { session.closeGracefully().block(); } @@ -237,10 +241,10 @@ public Mono closeGracefully() { } }); - this.sessions.clear(); + this.sessionStore.clear(); logger.debug("Graceful shutdown completed"); }).then().doOnSuccess(v -> { - sessions.clear(); + sessionStore.clear(); logger.debug("Graceful shutdown completed"); if (this.keepAliveScheduler != null) { this.keepAliveScheduler.shutdown(); @@ -299,7 +303,7 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) return; } - McpStreamableServerSession session = this.sessions.get(sessionId); + McpStreamableServerSession session = this.sessionStore.get(sessionId); if (session == null) { response.sendError(HttpServletResponse.SC_NOT_FOUND); @@ -452,7 +456,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) }); McpStreamableServerSession.McpStreamableServerSessionInit init = this.sessionFactory .startSession(initializeRequest); - this.sessions.put(init.session().getId(), init.session()); + this.sessionStore.save(init.session().getId(), init.session()); try { McpSchema.InitializeResult initResult = init.initResult().block(); @@ -493,7 +497,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response) return; } - McpStreamableServerSession session = this.sessions.get(sessionId); + McpStreamableServerSession session = this.sessionStore.get(sessionId); if (session == null) { this.responseError(response, HttpServletResponse.SC_NOT_FOUND, @@ -612,7 +616,7 @@ protected void doDelete(HttpServletRequest request, HttpServletResponse response } String sessionId = request.getHeader(HttpHeaders.MCP_SESSION_ID); - McpStreamableServerSession session = this.sessions.get(sessionId); + McpStreamableServerSession session = this.sessionStore.get(sessionId); if (session == null) { response.sendError(HttpServletResponse.SC_NOT_FOUND); @@ -621,7 +625,7 @@ protected void doDelete(HttpServletRequest request, HttpServletResponse response try { session.delete().contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)).block(); - this.sessions.remove(sessionId); + this.sessionStore.remove(sessionId); response.setStatus(HttpServletResponse.SC_OK); } catch (Exception e) { @@ -755,7 +759,7 @@ public Mono sendMessage(McpSchema.JSONRPCMessage message, String messageId } catch (Exception e) { logger.error("Failed to send message to session {}: {}", this.sessionId, e.getMessage()); - HttpServletStreamableServerTransportProvider.this.sessions.remove(this.sessionId); + HttpServletStreamableServerTransportProvider.this.sessionStore.remove(this.sessionId); this.asyncContext.complete(); } finally { @@ -801,7 +805,7 @@ public void close() { this.closed = true; - // HttpServletStreamableServerTransportProvider.this.sessions.remove(this.sessionId); + // HttpServletStreamableServerTransportProvider.this.sessionStore.remove(this.sessionId); this.asyncContext.complete(); logger.debug("Successfully completed async context for session {}", sessionId); } @@ -838,6 +842,8 @@ public static class Builder { private ServerTransportSecurityValidator securityValidator = ServerTransportSecurityValidator.NOOP; + private McpSessionStore sessionStore; + /** * Sets the JsonMapper to use for JSON serialization/deserialization of MCP * messages. @@ -909,6 +915,19 @@ public Builder securityValidator(ServerTransportSecurityValidator securityValida return this; } + /** + * Sets the session store for managing active client sessions. If not set, an + * {@link InMemoryMcpSessionStore} will be used by default. + * @param sessionStore The session store to use. Must not be null. + * @return this builder instance + * @throws IllegalArgumentException if sessionStore is null + */ + public Builder sessionStore(McpSessionStore sessionStore) { + Assert.notNull(sessionStore, "Session store must not be null"); + this.sessionStore = sessionStore; + return this; + } + /** * Builds a new instance of {@link HttpServletStreamableServerTransportProvider} * with the configured settings. @@ -919,7 +938,8 @@ public HttpServletStreamableServerTransportProvider build() { Assert.notNull(this.mcpEndpoint, "MCP endpoint must be set"); return new HttpServletStreamableServerTransportProvider( jsonMapper == null ? McpJsonDefaults.getMapper() : jsonMapper, mcpEndpoint, disallowDelete, - contextExtractor, keepAliveInterval, securityValidator); + contextExtractor, keepAliveInterval, securityValidator, + sessionStore == null ? new InMemoryMcpSessionStore() : sessionStore); } } diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/InMemoryMcpSessionStore.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/InMemoryMcpSessionStore.java new file mode 100644 index 000000000..109980c8d --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/InMemoryMcpSessionStore.java @@ -0,0 +1,63 @@ +/* + * Copyright 2024-2026 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +import java.util.Collection; +import java.util.concurrent.ConcurrentHashMap; + +/** + * Default in-memory implementation of {@link McpSessionStore} backed by a + * {@link ConcurrentHashMap}. This implementation is suitable for single-instance + * deployments where session state does not need to be shared across multiple server + * instances. + * + *

+ * This is the default session store used by + * {@link io.modelcontextprotocol.server.transport.HttpServletStreamableServerTransportProvider} + * when no custom {@link McpSessionStore} is provided. + * + * @author WeiLin Wang + * @see McpSessionStore + */ +public class InMemoryMcpSessionStore implements McpSessionStore { + + private final ConcurrentHashMap sessions = new ConcurrentHashMap<>(); + + @Override + public void save(String sessionId, McpStreamableServerSession session) { + this.sessions.put(sessionId, session); + } + + @Override + public McpStreamableServerSession get(String sessionId) { + return this.sessions.get(sessionId); + } + + @Override + public McpStreamableServerSession remove(String sessionId) { + return this.sessions.remove(sessionId); + } + + @Override + public Collection values() { + return this.sessions.values(); + } + + @Override + public boolean isEmpty() { + return this.sessions.isEmpty(); + } + + @Override + public int size() { + return this.sessions.size(); + } + + @Override + public void clear() { + this.sessions.clear(); + } + +} diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpSessionStore.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpSessionStore.java new file mode 100644 index 000000000..1d182f7d8 --- /dev/null +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpSessionStore.java @@ -0,0 +1,78 @@ +/* + * Copyright 2024-2026 the original author or authors. + */ + +package io.modelcontextprotocol.spec; + +import java.util.Collection; + +/** + * Strategy interface for storing and retrieving MCP server sessions. This abstraction + * allows the session storage mechanism to be customized, enabling implementations such as + * in-memory (default), Redis-backed, JDBC-backed, or any distributed store. + * + *

+ * The default implementation {@link InMemoryMcpSessionStore} uses a + * {@link java.util.concurrent.ConcurrentHashMap} which is suitable for single-instance + * deployments. For distributed or multi-instance deployments, a custom implementation + * backed by a distributed data store should be used. + * + *

+ * Note: {@link McpStreamableServerSession} objects contain active transport connections + * (SSE streams) that are inherently tied to the JVM instance. A distributed session store + * therefore stores the session reference per-node and coordinates session lifecycle + * across nodes (e.g., detecting when a session was created on a different node). + * + * @author WeiLin Wang + * @see InMemoryMcpSessionStore + * @see McpStreamableServerSession + */ +public interface McpSessionStore { + + /** + * Stores a session with the given ID. If a session with the same ID already exists, + * it will be replaced. + * @param sessionId the unique session identifier + * @param session the session to store + */ + void save(String sessionId, McpStreamableServerSession session); + + /** + * Retrieves a session by its ID. + * @param sessionId the unique session identifier + * @return the session associated with the given ID, or {@code null} if not found + */ + McpStreamableServerSession get(String sessionId); + + /** + * Removes a session by its ID. + * @param sessionId the unique session identifier + * @return the previously stored session, or {@code null} if no session was stored + * with the given ID + */ + McpStreamableServerSession remove(String sessionId); + + /** + * Returns all currently stored sessions. + * @return a collection of all stored sessions; never {@code null} + */ + Collection values(); + + /** + * Returns whether there are any sessions stored. + * @return {@code true} if no sessions are stored, {@code false} otherwise + */ + boolean isEmpty(); + + /** + * Returns the number of stored sessions. + * @return the session count + */ + int size(); + + /** + * Removes all stored sessions. + */ + void clear(); + +}