Skip to content

Commit 64fa046

Browse files
feat: support streamable http server connection (#150)
* feat: support streamable http server connection * test: fix test * Update src/mcpm/router/client_connection.py
1 parent e331cad commit 64fa046

File tree

6 files changed

+429
-386
lines changed

6 files changed

+429
-386
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ dependencies = [
2424
"rich>=12.0.0",
2525
"requests>=2.28.0",
2626
"pydantic>=2.5.1",
27-
"mcp==1.6.0",
27+
"mcp>=1.8.0",
2828
"ruamel-yaml>=0.18.10",
2929
"watchfiles>=1.0.4",
3030
"duckdb>=1.2.2",

src/mcpm/router/client_connection.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22
import logging
33
from typing import Optional, TextIO, cast
44

5+
import requests
56
from mcp import ClientSession, InitializeResult, StdioServerParameters, stdio_client
67
from mcp.client.sse import sse_client
8+
from mcp.client.streamable_http import streamablehttp_client
79

810
from mcpm.core.schema import ServerConfig, SSEServerConfig, STDIOServerConfig
911

@@ -21,6 +23,11 @@ def _sse_transport_context(server_config: ServerConfig):
2123
return sse_client(server_config.url, headers=server_config.headers)
2224

2325

26+
def _streamable_http_transport_context(server_config: ServerConfig):
27+
server_config = cast(SSEServerConfig, server_config)
28+
return streamablehttp_client(server_config.url, headers=server_config.headers)
29+
30+
2431
class ServerConnection:
2532
def __init__(self, server_config: ServerConfig, errlog: TextIO) -> None:
2633
self.session: Optional[ClientSession] = None
@@ -31,14 +38,21 @@ def __init__(self, server_config: ServerConfig, errlog: TextIO) -> None:
3138
self._shutdown_event = asyncio.Event()
3239
self._errlog = errlog
3340

34-
self._transport_context_factory = (
35-
lambda config: _stdio_transport_context(config, errlog=self._errlog)
36-
if isinstance(config, STDIOServerConfig)
37-
else _sse_transport_context(config)
38-
)
39-
4041
self._server_task = asyncio.create_task(self._server_lifespan_cycle())
4142

43+
def _transport_context_factory(self, server_config: ServerConfig):
44+
if isinstance(server_config, STDIOServerConfig):
45+
return _stdio_transport_context(server_config, self._errlog)
46+
elif isinstance(server_config, SSEServerConfig):
47+
r = requests.head(server_config.url)
48+
if r.status_code != 200:
49+
return _streamable_http_transport_context(server_config)
50+
if r.headers.get("connection") == "keep-alive" and r.headers.get("content-type", "").startswith(
51+
"text/event-stream"
52+
):
53+
return _sse_transport_context(server_config)
54+
return _streamable_http_transport_context(server_config)
55+
4256
def healthy(self) -> bool:
4357
return self.session is not None and self._initialized
4458

@@ -56,7 +70,7 @@ async def wait_for_shutdown_request(self):
5670

5771
async def _server_lifespan_cycle(self):
5872
try:
59-
async with self._transport_context_factory(self.server_config) as (read, write):
73+
async with self._transport_context_factory(self.server_config) as (read, write, *_):
6074
async with ClientSession(read, write) as session:
6175
self.session_initialized_response = await session.initialize()
6276

src/mcpm/router/transport.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
1111
from mcp import types
1212
from mcp.server.sse import SseServerTransport
13+
from mcp.shared.message import SessionMessage
1314
from pydantic import ValidationError
1415
from sse_starlette import EventSourceResponse
1516
from starlette.background import BackgroundTask
@@ -86,11 +87,11 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send):
8687
return
8788

8889
logger.debug("Setting up SSE connection")
89-
read_stream: MemoryObjectReceiveStream[types.JSONRPCMessage | Exception]
90-
read_stream_writer: MemoryObjectSendStream[types.JSONRPCMessage | Exception]
90+
read_stream: MemoryObjectReceiveStream[SessionMessage | Exception]
91+
read_stream_writer: MemoryObjectSendStream[SessionMessage | Exception]
9192

92-
write_stream: MemoryObjectSendStream[types.JSONRPCMessage]
93-
write_stream_reader: MemoryObjectReceiveStream[types.JSONRPCMessage]
93+
write_stream: MemoryObjectSendStream[SessionMessage]
94+
write_stream_reader: MemoryObjectReceiveStream[SessionMessage]
9495

9596
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
9697
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
@@ -118,12 +119,12 @@ async def sse_writer():
118119
await sse_stream_writer.send({"event": "endpoint", "data": session_uri})
119120
logger.debug(f"Sent endpoint event: {session_uri}")
120121

121-
async for message in write_stream_reader:
122-
logger.debug(f"Sending message via SSE: {message}")
122+
async for session_message in write_stream_reader:
123+
logger.debug(f"Sending message via SSE: {session_message}")
123124
await sse_stream_writer.send(
124125
{
125126
"event": "message",
126-
"data": message.model_dump_json(by_alias=True, exclude_none=True),
127+
"data": session_message.message.model_dump_json(by_alias=True, exclude_none=True),
127128
}
128129
)
129130

@@ -228,7 +229,7 @@ async def handle_post_message(self, scope: Scope, receive: Receive, send: Send):
228229

229230
# add error handling, catch possible pipe errors
230231
try:
231-
await writer.send(message)
232+
await writer.send(SessionMessage(message=message))
232233
except (BrokenPipeError, ConnectionError, OSError) as e:
233234
# if it's EPIPE error or other connection error, log it but don't throw an exception
234235
if isinstance(e, OSError) and e.errno == 32: # EPIPE

src/mcpm/utils/errlog_manager.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
DEFAULT_ROOT_STDERR_LOG_DIR = get_log_directory("mcpm") / "errlogs"
88

9+
910
class ServerErrorLogManager:
1011
"""
1112
A manager for server error logs.
@@ -30,5 +31,5 @@ def close_errlog_file(self, server_id: str) -> None:
3031
del self._log_files[server_id]
3132

3233
def close_all(self) -> None:
33-
for server_id in self._log_files:
34+
for server_id in list(self._log_files.keys()):
3435
self.close_errlog_file(server_id)

tests/test_add.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def test_add_server_with_missing_arg(windsurf_manager, monkeypatch):
9696
patch("rich.progress.Progress.add_task"),
9797
):
9898
# Use CliRunner which provides its own isolated environment
99-
runner = CliRunner(mix_stderr=False)
99+
runner = CliRunner()
100100
result = runner.invoke(add, ["server-test", "--force", "--alias", "test-missing-arg"])
101101

102102
if result.exit_code != 0:
@@ -164,7 +164,7 @@ def test_add_server_with_empty_args(windsurf_manager, monkeypatch):
164164
patch("rich.progress.Progress.stop"),
165165
patch("rich.progress.Progress.add_task"),
166166
):
167-
runner = CliRunner(mix_stderr=False)
167+
runner = CliRunner()
168168
result = runner.invoke(add, ["server-test", "--force", "--alias", "test-empty-args"])
169169

170170
assert result.exit_code == 0

0 commit comments

Comments
 (0)