Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/src/migrating_to_lsp_plugin.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
| `register_plugin(MyPlugin)` / `unregister_plugin(MyPlugin)` | `MyPlugin.register()` / `MyPlugin.unregister()` - no standalone import needed |
| *(not present)* | `on_initialized_async()` |
| *(not present)* | `on_pre_send_response_async(response)` |
| *(not present)* | `on_transport_ready(reader, writer)` |

The methods `on_selection_modified_async` and `on_session_end_async` are available in `LspPlugin` with the same name and the same signature. `on_pre_send_notification_async` and `on_server_notification_async` keep the same names but use more specific argument types — see step 11.

Expand Down
10 changes: 10 additions & 0 deletions plugin/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,16 @@ def __init_subclass__(cls, **kwargs: Any) -> None:
cls.name = cls.__module__.split('.')[0] # pyright: ignore[reportAttributeAccessIssue]
cls.plugin_storage_path = Path(ST_STORAGE_PATH, cls.name) # pyright: ignore[reportAttributeAccessIssue]

async def on_transport_ready(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
"""
Called before any JSON-RPC communication starts.

Override to do custom communication with the language server. The communication is not logged.
Use the provided writer object to write bytes; use the reader object to read bytes.
When returning from this async method, the JSON-RPC communication starts.
"""
pass

@deprecated("override on_initialized instead")
def on_initialized_async(self) -> None:
pass
Expand Down
4 changes: 4 additions & 0 deletions plugin/core/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1425,6 +1425,10 @@ async def initialize(
loop = asyncio.get_running_loop()
if self._plugin_class and issubclass(self._plugin_class, LspPlugin):
self._plugin = self._plugin_class(weakref.ref(self))
if (reader := transport.reader) and (writer := transport.writer):
await self._plugin.on_transport_ready(reader, writer)
else:
raise RuntimeError("transport has already stopped")
self.transport = transport
self.working_directory = working_directory
params = get_initialize_params(variables, self._workspace_folders, self.config)
Expand Down
96 changes: 38 additions & 58 deletions plugin/core/transports.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ async def start(
raise Exception('Failed to create transport config due to not being able to pipe stdio')
return TransportWrapper(
callback_object=callbacks,
transport=StreamTransport(encode_json, decode_json, process.stdout, process.stdin),
transport=Transport(encode_json, decode_json, process.stdout, process.stdin),
process=process,
process_args=launch.command,
error_reader=ErrorReader(callbacks, process.stderr),
Expand Down Expand Up @@ -164,7 +164,7 @@ async def start(
)
return TransportWrapper(
callback_object=callbacks,
transport=StreamTransport(encode_json, decode_json, reader, writer),
transport=Transport(encode_json, decode_json, reader, writer),
process=process,
process_args=launch.command if launch else None,
error_reader=error_reader,
Expand Down Expand Up @@ -216,7 +216,7 @@ def __init__(self) -> None:

async def __call__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
async with self.cv:
transport = StreamTransport(encode_json, decode_json, reader, writer)
transport = Transport(encode_json, decode_json, reader, writer)
self.wrapper = TransportWrapper(callbacks, transport, self.process, command, self.error_reader)
self.cv.notify()

Expand Down Expand Up @@ -250,37 +250,7 @@ async def __call__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWri
# --- Transports -------------------------------------------------------------------------------------------------------


class TransportCallbacks:
async def on_transport_close(self, exit_code: int, exception: Exception | None) -> None: ...

async def on_payload(self, payload: JSONRPCMessage) -> None: ...

def on_stderr_message(self, message: str) -> None: ...


class Transport(ABC):
def __init__(self, encoder: Callable[[JSONRPCMessage], bytes], decoder: Callable[[bytes], JSONRPCMessage]) -> None:
self._encoder = encoder
self._decoder = decoder

@abstractmethod
async def read(self) -> JSONRPCMessage | None:
raise NotImplementedError

@abstractmethod
async def write(self, payload: JSONRPCMessage) -> None:
raise NotImplementedError

@abstractmethod
async def write_bytes(self, payload: bytes) -> None:
raise NotImplementedError

@abstractmethod
async def close(self) -> None:
raise NotImplementedError


async def parse_headers(reader: asyncio.StreamReader) -> dict[str, str]:
async def _parse_headers(reader: asyncio.StreamReader) -> dict[str, str]:
headers: dict[str, str] = {}
try:
headers_bytes = (await reader.readuntil(b'\r\n\r\n')).decode("ascii").rstrip()
Expand All @@ -294,27 +264,44 @@ async def parse_headers(reader: asyncio.StreamReader) -> dict[str, str]:
return headers


async def parse_content_length(reader: asyncio.StreamReader) -> int | None:
headers = await parse_headers(reader)
async def _parse_content_length(reader: asyncio.StreamReader) -> int | None:
headers = await _parse_headers(reader)
content_length = headers.get("content-length")
return int(content_length) if content_length else None


class StreamTransport(Transport):
class TransportCallbacks:
async def on_transport_close(self, exit_code: int, exception: Exception | None) -> None: ...

async def on_payload(self, payload: JSONRPCMessage) -> None: ...

def on_stderr_message(self, message: str) -> None: ...


@final
class Transport:
def __init__(
self,
encoder: Callable[[JSONRPCMessage], bytes],
decoder: Callable[[bytes], JSONRPCMessage],
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter,
) -> None:
super().__init__(encoder, decoder)
self._encoder = encoder
self._decoder = decoder
self._reader = reader
self._writer = writer

@override
@property
def reader(self) -> asyncio.StreamReader:
return self._reader

@property
def writer(self) -> asyncio.StreamWriter:
return self._writer

async def read(self) -> JSONRPCMessage:
content_length = await parse_content_length(self._reader)
content_length = await _parse_content_length(self._reader)
if content_length is None:
raise StopLoopError
body = await self._reader.readexactly(content_length)
Expand All @@ -323,7 +310,6 @@ async def read(self) -> JSONRPCMessage:
except Exception as ex:
raise Exception(f"JSON decode error: {ex}") from ex

@override
async def write(self, payload: JSONRPCMessage) -> None:
body = self._encoder(payload)
self._writer.writelines((f"Content-Length: {len(body)}\r\n\r\n".encode("ascii"), body))
Expand All @@ -334,27 +320,17 @@ async def write(self, payload: JSONRPCMessage) -> None:
# there's other logic that will make the transport shut down.
pass

@override
async def write_bytes(self, payload: bytes) -> None:
self._writer.write(payload)
await self._writer.drain()

@override
async def close(self) -> None:
self._writer.close()
await self._writer.wait_closed()


# --- TransportWrapper -------------------------------------------------------------------------------------------------


@final
class TransportWrapper:
"""
Double dispatch-like class that takes a (subclass of) Transport, and provides to a (subclass of) TransportCallbacks
appropriately decoded messages. The TransportWrapper is also responsible for keeping the spawned child
process around (if any), and also keeps track of the ErrorReader. It can be the case that there is no ErrorReader,
for instance when talking to a remote TCP language server. So it can be None.
Provides to a (subclass of) TransportCallbacks appropriately decoded messages. The TransportWrapper is also
responsible for keeping the spawned child process around (if any), and also keeps track of the ErrorReader. It can
be the case that there is no ErrorReader, for instance when talking to a remote TCP language server.
"""

def __init__(
Expand All @@ -380,14 +356,18 @@ def process_args(self) -> list[str] | None:
"""
return self._process_args

@property
def reader(self) -> asyncio.StreamReader | None:
return self._transport.reader if self._transport else None

@property
def writer(self) -> asyncio.StreamWriter | None:
return self._transport.writer if self._transport else None

async def send(self, payload: JSONRPCMessage) -> None:
if self._transport:
await self._transport.write(payload)

async def send_bytes(self, payload: bytes) -> None:
if self._transport:
await self._transport.write_bytes(payload)

async def close(self) -> None:
if self._error_reader:
self._error_reader.on_transport_close()
Expand Down