Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.locks.ReentrantLock;

import org.slf4j.Logger;
Expand All @@ -22,8 +21,10 @@
import io.modelcontextprotocol.common.McpTransportContext;
import io.modelcontextprotocol.server.McpTransportContextExtractor;
import io.modelcontextprotocol.spec.HttpHeaders;
import io.modelcontextprotocol.spec.InMemoryMcpSessionStore;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSessionStore;
import io.modelcontextprotocol.spec.McpStreamableServerSession;
import io.modelcontextprotocol.spec.McpStreamableServerTransport;
import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider;
Expand Down Expand Up @@ -104,9 +105,9 @@ public class HttpServletStreamableServerTransportProvider extends HttpServlet
private McpStreamableServerSession.Factory sessionFactory;

/**
* Map of active client sessions, keyed by mcp-session-id.
* Store for active client sessions, keyed by mcp-session-id.
*/
private final ConcurrentHashMap<String, McpStreamableServerSession> sessions = new ConcurrentHashMap<>();
private final McpSessionStore sessionStore;

private McpTransportContextExtractor<HttpServletRequest> contextExtractor;

Expand Down Expand Up @@ -141,22 +142,25 @@ public class HttpServletStreamableServerTransportProvider extends HttpServlet
*/
private HttpServletStreamableServerTransportProvider(McpJsonMapper jsonMapper, String mcpEndpoint,
boolean disallowDelete, McpTransportContextExtractor<HttpServletRequest> contextExtractor,
Duration keepAliveInterval, ServerTransportSecurityValidator securityValidator) {
Duration keepAliveInterval, ServerTransportSecurityValidator securityValidator,
McpSessionStore sessionStore) {
Assert.notNull(jsonMapper, "JsonMapper must not be null");
Assert.notNull(mcpEndpoint, "MCP endpoint must not be null");
Assert.notNull(contextExtractor, "Context extractor must not be null");
Assert.notNull(securityValidator, "Security validator must not be null");
Assert.notNull(sessionStore, "Session store must not be null");

this.jsonMapper = jsonMapper;
this.mcpEndpoint = mcpEndpoint;
this.disallowDelete = disallowDelete;
this.contextExtractor = contextExtractor;
this.securityValidator = securityValidator;
this.sessionStore = sessionStore;

if (keepAliveInterval != null) {

this.keepAliveScheduler = KeepAliveScheduler
.builder(() -> (isClosing) ? Flux.empty() : Flux.fromIterable(sessions.values()))
.builder(() -> (isClosing) ? Flux.empty() : Flux.fromIterable(sessionStore.values()))
.initialDelay(keepAliveInterval)
.interval(keepAliveInterval)
.build();
Expand Down Expand Up @@ -187,15 +191,15 @@ public void setSessionFactory(McpStreamableServerSession.Factory sessionFactory)
*/
@Override
public Mono<Void> notifyClients(String method, Object params) {
if (this.sessions.isEmpty()) {
if (this.sessionStore.isEmpty()) {
logger.debug("No active sessions to broadcast message to");
return Mono.empty();
}

logger.debug("Attempting to broadcast message to {} active sessions", this.sessions.size());
logger.debug("Attempting to broadcast message to {} active sessions", this.sessionStore.size());

return Mono.fromRunnable(() -> {
this.sessions.values().parallelStream().forEach(session -> {
this.sessionStore.values().parallelStream().forEach(session -> {
try {
session.sendNotification(method, params).block();
}
Expand All @@ -209,7 +213,7 @@ public Mono<Void> notifyClients(String method, Object params) {
@Override
public Mono<Void> notifyClient(String sessionId, String method, Object params) {
return Mono.defer(() -> {
McpStreamableServerSession session = this.sessions.get(sessionId);
McpStreamableServerSession session = this.sessionStore.get(sessionId);
if (session == null) {
logger.debug("Session {} not found", sessionId);
return Mono.empty();
Expand All @@ -226,9 +230,9 @@ public Mono<Void> notifyClient(String sessionId, String method, Object params) {
public Mono<Void> closeGracefully() {
return Mono.fromRunnable(() -> {
this.isClosing = true;
logger.debug("Initiating graceful shutdown with {} active sessions", this.sessions.size());
logger.debug("Initiating graceful shutdown with {} active sessions", this.sessionStore.size());

this.sessions.values().parallelStream().forEach(session -> {
this.sessionStore.values().parallelStream().forEach(session -> {
try {
session.closeGracefully().block();
}
Expand All @@ -237,10 +241,10 @@ public Mono<Void> closeGracefully() {
}
});

this.sessions.clear();
this.sessionStore.clear();
logger.debug("Graceful shutdown completed");
}).then().doOnSuccess(v -> {
sessions.clear();
sessionStore.clear();
logger.debug("Graceful shutdown completed");
if (this.keepAliveScheduler != null) {
this.keepAliveScheduler.shutdown();
Expand Down Expand Up @@ -299,7 +303,7 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response)
return;
}

McpStreamableServerSession session = this.sessions.get(sessionId);
McpStreamableServerSession session = this.sessionStore.get(sessionId);

if (session == null) {
response.sendError(HttpServletResponse.SC_NOT_FOUND);
Expand Down Expand Up @@ -452,7 +456,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
});
McpStreamableServerSession.McpStreamableServerSessionInit init = this.sessionFactory
.startSession(initializeRequest);
this.sessions.put(init.session().getId(), init.session());
this.sessionStore.save(init.session().getId(), init.session());

try {
McpSchema.InitializeResult initResult = init.initResult().block();
Expand Down Expand Up @@ -493,7 +497,7 @@ protected void doPost(HttpServletRequest request, HttpServletResponse response)
return;
}

McpStreamableServerSession session = this.sessions.get(sessionId);
McpStreamableServerSession session = this.sessionStore.get(sessionId);

if (session == null) {
this.responseError(response, HttpServletResponse.SC_NOT_FOUND,
Expand Down Expand Up @@ -612,7 +616,7 @@ protected void doDelete(HttpServletRequest request, HttpServletResponse response
}

String sessionId = request.getHeader(HttpHeaders.MCP_SESSION_ID);
McpStreamableServerSession session = this.sessions.get(sessionId);
McpStreamableServerSession session = this.sessionStore.get(sessionId);

if (session == null) {
response.sendError(HttpServletResponse.SC_NOT_FOUND);
Expand All @@ -621,7 +625,7 @@ protected void doDelete(HttpServletRequest request, HttpServletResponse response

try {
session.delete().contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)).block();
this.sessions.remove(sessionId);
this.sessionStore.remove(sessionId);
response.setStatus(HttpServletResponse.SC_OK);
}
catch (Exception e) {
Expand Down Expand Up @@ -755,7 +759,7 @@ public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message, String messageId
}
catch (Exception e) {
logger.error("Failed to send message to session {}: {}", this.sessionId, e.getMessage());
HttpServletStreamableServerTransportProvider.this.sessions.remove(this.sessionId);
HttpServletStreamableServerTransportProvider.this.sessionStore.remove(this.sessionId);
this.asyncContext.complete();
}
finally {
Expand Down Expand Up @@ -801,7 +805,7 @@ public void close() {

this.closed = true;

// HttpServletStreamableServerTransportProvider.this.sessions.remove(this.sessionId);
// HttpServletStreamableServerTransportProvider.this.sessionStore.remove(this.sessionId);
this.asyncContext.complete();
logger.debug("Successfully completed async context for session {}", sessionId);
}
Expand Down Expand Up @@ -838,6 +842,8 @@ public static class Builder {

private ServerTransportSecurityValidator securityValidator = ServerTransportSecurityValidator.NOOP;

private McpSessionStore sessionStore;

/**
* Sets the JsonMapper to use for JSON serialization/deserialization of MCP
* messages.
Expand Down Expand Up @@ -909,6 +915,19 @@ public Builder securityValidator(ServerTransportSecurityValidator securityValida
return this;
}

/**
* Sets the session store for managing active client sessions. If not set, an
* {@link InMemoryMcpSessionStore} will be used by default.
* @param sessionStore The session store to use. Must not be null.
* @return this builder instance
* @throws IllegalArgumentException if sessionStore is null
*/
public Builder sessionStore(McpSessionStore sessionStore) {
Assert.notNull(sessionStore, "Session store must not be null");
this.sessionStore = sessionStore;
return this;
}

/**
* Builds a new instance of {@link HttpServletStreamableServerTransportProvider}
* with the configured settings.
Expand All @@ -919,7 +938,8 @@ public HttpServletStreamableServerTransportProvider build() {
Assert.notNull(this.mcpEndpoint, "MCP endpoint must be set");
return new HttpServletStreamableServerTransportProvider(
jsonMapper == null ? McpJsonDefaults.getMapper() : jsonMapper, mcpEndpoint, disallowDelete,
contextExtractor, keepAliveInterval, securityValidator);
contextExtractor, keepAliveInterval, securityValidator,
sessionStore == null ? new InMemoryMcpSessionStore() : sessionStore);
}

}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Copyright 2024-2026 the original author or authors.
*/

package io.modelcontextprotocol.spec;

import java.util.Collection;
import java.util.concurrent.ConcurrentHashMap;

/**
* Default in-memory implementation of {@link McpSessionStore} backed by a
* {@link ConcurrentHashMap}. This implementation is suitable for single-instance
* deployments where session state does not need to be shared across multiple server
* instances.
*
* <p>
* This is the default session store used by
* {@link io.modelcontextprotocol.server.transport.HttpServletStreamableServerTransportProvider}
* when no custom {@link McpSessionStore} is provided.
*
* @author WeiLin Wang
* @see McpSessionStore
*/
public class InMemoryMcpSessionStore implements McpSessionStore {

private final ConcurrentHashMap<String, McpStreamableServerSession> sessions = new ConcurrentHashMap<>();

@Override
public void save(String sessionId, McpStreamableServerSession session) {
this.sessions.put(sessionId, session);
}

@Override
public McpStreamableServerSession get(String sessionId) {
return this.sessions.get(sessionId);
}

@Override
public McpStreamableServerSession remove(String sessionId) {
return this.sessions.remove(sessionId);
}

@Override
public Collection<McpStreamableServerSession> values() {
return this.sessions.values();
}

@Override
public boolean isEmpty() {
return this.sessions.isEmpty();
}

@Override
public int size() {
return this.sessions.size();
}

@Override
public void clear() {
this.sessions.clear();
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/*
* Copyright 2024-2026 the original author or authors.
*/

package io.modelcontextprotocol.spec;

import java.util.Collection;

/**
* Strategy interface for storing and retrieving MCP server sessions. This abstraction
* allows the session storage mechanism to be customized, enabling implementations such as
* in-memory (default), Redis-backed, JDBC-backed, or any distributed store.
*
* <p>
* The default implementation {@link InMemoryMcpSessionStore} uses a
* {@link java.util.concurrent.ConcurrentHashMap} which is suitable for single-instance
* deployments. For distributed or multi-instance deployments, a custom implementation
* backed by a distributed data store should be used.
*
* <p>
* Note: {@link McpStreamableServerSession} objects contain active transport connections
* (SSE streams) that are inherently tied to the JVM instance. A distributed session store
* therefore stores the session reference per-node and coordinates session lifecycle
* across nodes (e.g., detecting when a session was created on a different node).
*
* @author WeiLin Wang
* @see InMemoryMcpSessionStore
* @see McpStreamableServerSession
*/
public interface McpSessionStore {

/**
* Stores a session with the given ID. If a session with the same ID already exists,
* it will be replaced.
* @param sessionId the unique session identifier
* @param session the session to store
*/
void save(String sessionId, McpStreamableServerSession session);

/**
* Retrieves a session by its ID.
* @param sessionId the unique session identifier
* @return the session associated with the given ID, or {@code null} if not found
*/
McpStreamableServerSession get(String sessionId);

/**
* Removes a session by its ID.
* @param sessionId the unique session identifier
* @return the previously stored session, or {@code null} if no session was stored
* with the given ID
*/
McpStreamableServerSession remove(String sessionId);

/**
* Returns all currently stored sessions.
* @return a collection of all stored sessions; never {@code null}
*/
Collection<McpStreamableServerSession> values();

/**
* Returns whether there are any sessions stored.
* @return {@code true} if no sessions are stored, {@code false} otherwise
*/
boolean isEmpty();

/**
* Returns the number of stored sessions.
* @return the session count
*/
int size();

/**
* Removes all stored sessions.
*/
void clear();

}