Source code for telnetlib3.client_base

"""Module provides class BaseClient."""

from __future__ import annotations

# std imports
import zlib
import asyncio
import logging
import weakref
import datetime
import collections
from typing import Any, Union, Optional, cast

# local
from ._base import TelnetProtocolBase, _log_exception, _process_data_chunk
from ._types import ShellCallback
from .telopt import DO, WILL, theNULL, name_commands
from .accessories import TRACE, hexdump
from .stream_reader import TelnetReader, TelnetReaderUnicode
from .stream_writer import TelnetWriter, TelnetWriterUnicode

__all__ = ("BaseClient",)


[docs] class BaseClient(TelnetProtocolBase, asyncio.streams.FlowControlMixin, asyncio.Protocol): """Base Telnet Client Protocol.""" _transport: Optional[asyncio.Transport] = None _closing = False _reader_factory = TelnetReader _reader_factory_encoding = TelnetReaderUnicode _writer_factory = TelnetWriter _writer_factory_encoding = TelnetWriterUnicode _check_later: Optional[asyncio.Handle] = None def __init__( self, shell: Optional[ShellCallback] = None, encoding: Union[str, bool] = "utf8", encoding_errors: str = "strict", force_binary: bool = False, connect_minwait: float = 0, connect_maxwait: float = 4.0, limit: Optional[int] = None, waiter_closed: Optional[asyncio.Future[None]] = None, _waiter_connected: Optional[asyncio.Future[None]] = None, ) -> None: """Class initializer.""" super().__init__() self.log = logging.getLogger("telnetlib3.client") #: encoding for new connections self.default_encoding = encoding self._encoding_errors = encoding_errors self.force_binary = force_binary self._extra: dict[str, Any] = {} self.waiter_closed = waiter_closed or asyncio.Future() #: a future used for testing self._waiter_connected = _waiter_connected or asyncio.Future() self._tasks: list[Any] = [] self.shell = shell #: minimum duration for :meth:`check_negotiation`. self.connect_minwait = connect_minwait #: maximum duration for :meth:`check_negotiation`. self.connect_maxwait = connect_maxwait self.reader: Optional[Union[TelnetReader, TelnetReaderUnicode]] = None self.writer: Optional[Union[TelnetWriter, TelnetWriterUnicode]] = None self._limit = limit # MCCP2: server→client decompression self._mccp2_decompressor: Optional[zlib._Decompress] = None self._mccp2_wbits_fallback: bool = False # MCCP3: client→server compression self._mccp3_compressor: Optional[zlib._Compress] = None self._mccp3_orig_write: Any = None # High-throughput receive pipeline self._rx_queue: collections.deque[bytes] = collections.deque() self._rx_bytes = 0 self._rx_task: Optional[asyncio.Task[Any]] = None self._reading_paused = False # Apply backpressure to transport when our queue grows too large self._read_high = 512 * 1024 # pause_reading() above this many buffered bytes self._read_low = 256 * 1024 # resume_reading() below this many buffered bytes # Base protocol methods
[docs] def eof_received(self) -> None: """Called when the other end calls write_eof() or equivalent.""" self.log.debug("EOF from server, closing.") self.connection_lost(None)
[docs] def connection_lost(self, exc: Optional[Exception]) -> None: """ Called when the connection is lost or closed. :param exc: Exception instance, or ``None`` to indicate a closing EOF sent by this end. """ if self._closing: return self._closing = True # Clean up MCCP compressors/decompressors self._mccp2_decompressor = None self._mccp2_wbits_fallback = False self._mccp3_compressor = None self._mccp3_orig_write = None # Drain any pending rx data before signalling EOF to prevent # _process_rx from calling feed_data() after feed_eof(). self._rx_queue.clear() self._rx_bytes = 0 if self._rx_task is not None and not self._rx_task.done(): self._rx_task.cancel() self._rx_task = None # inform yielding readers about closed connection if exc is None: self.log.info("Connection closed to %s", self) self.reader.feed_eof() else: self.log.info("Connection lost to %s: %s", self, exc) self.reader.set_exception(exc) # cancel protocol tasks, namely on-connect negotiations for task in self._tasks: task.cancel() # close transport (may already be closed), set waiter_closed and # cancel Future _waiter_connected. self._transport.close() if not self._waiter_connected.done(): # strangely, for symmetry, our '_waiter_connected' must be set if # we are disconnected before negotiation may be considered # complete. We set waiter_closed, and any function consuming # the StreamReader will receive eof. self._waiter_connected.set_result(None) if self.shell is None and not self.waiter_closed.done(): # when a shell is defined, we allow the completion of the coroutine # to set the result of waiter_closed. self.waiter_closed.set_result(weakref.proxy(self)) # break circular references. self._transport = None
[docs] def connection_made(self, transport: asyncio.BaseTransport) -> None: """ Called when a connection is made. Ensure ``super().connection_made(transport)`` is called when derived. """ _transport = cast(asyncio.Transport, transport) self._transport = _transport self._when_connected = datetime.datetime.now() self._last_received = datetime.datetime.now() reader_factory: type[TelnetReader] | type[TelnetReaderUnicode] = self._reader_factory writer_factory: type[TelnetWriter] | type[TelnetWriterUnicode] = self._writer_factory reader_kwds: dict[str, Any] = {} writer_kwds: dict[str, Any] = {} if self.default_encoding: reader_kwds["fn_encoding"] = self.encoding writer_kwds["fn_encoding"] = self.encoding reader_kwds["encoding_errors"] = self._encoding_errors writer_kwds["encoding_errors"] = self._encoding_errors reader_factory = self._reader_factory_encoding writer_factory = self._writer_factory_encoding if self._limit: reader_kwds["limit"] = self._limit self.reader = reader_factory(**reader_kwds) # Attach transport so TelnetReader can apply pause_reading/resume_reading try: self.reader.set_transport(_transport) except Exception: # Reader may not support transport coupling; ignore. pass self.writer = writer_factory( transport=_transport, protocol=self, reader=self.reader, client=True, **writer_kwds ) self.log.info("Connected to %s", self) self._log_tls_info(self.log) self._waiter_connected.add_done_callback(self.begin_shell) asyncio.get_event_loop().call_soon(self.begin_negotiation)
[docs] def begin_shell(self, future: asyncio.Future[None]) -> None: """Start the shell coroutine after negotiation completes.""" # Don't start shell if the connection was cancelled or errored if future.cancelled() or future.exception() is not None: return if self.shell is not None: assert self.reader is not None and self.writer is not None coro = self.shell(self.reader, self.writer) if asyncio.iscoroutine(coro): # When a shell is defined as a coroutine, we must ensure # that self.waiter_closed is not closed until the shell # has had an opportunity to respond to EOF. Because # feed_eof() occurs in connection_lost(), we must allow # the event loop to return to our shell coroutine before # the waiter_closed future is set. # # We accomplish this by chaining the completion of the # shell future to set the result of the waiter_closed # future. fut = asyncio.get_event_loop().create_task(coro) fut.add_done_callback( lambda fut_obj: ( self.waiter_closed.set_result(weakref.proxy(self)) if self.waiter_closed is not None and not self.waiter_closed.done() else None ) )
[docs] def data_received(self, data: bytes) -> None: """ Process bytes received by transport. Buffer incoming data and schedule async processing to keep the event loop responsive. Apply read-side backpressure using transport.pause_reading()/resume_reading(). """ if self.log.isEnabledFor(TRACE): self.log.log(TRACE, "recv %d bytes\n%s", len(data), hexdump(data, prefix="<< ")) self._last_received = datetime.datetime.now() # Detect SyncTERM font switching sequences and auto-switch encoding. self._detect_syncterm_font(data) # Enqueue and account for buffered size self._rx_queue.append(data) self._rx_bytes += len(data) # Start processor task if not running if self._rx_task is None or self._rx_task.done(): loop = asyncio.get_event_loop() self._rx_task = loop.create_task(self._process_rx()) # Pause reading if buffered bytes exceed high watermark if not self._reading_paused and self._rx_bytes >= self._read_high: if self._transport is not None: try: self._transport.pause_reading() self._reading_paused = True except Exception: # Some transports may not support pause_reading; ignore. pass
def _detect_syncterm_font(self, data: bytes) -> None: """ Scan *data* for SyncTERM font selection and switch encoding. When :attr:`_encoding_explicit` is set on the writer (indicating the user passed ``--encoding``), the font switch is logged but does not override the encoding. """ if self.writer is None: return from .server_fingerprinting import detect_syncterm_font encoding = detect_syncterm_font(data) if encoding is not None: self.log.debug("SyncTERM font switch: %s", encoding) if getattr(self.writer, "_encoding_explicit", False): self.log.debug( "ignoring font switch, explicit encoding: %s", self.writer.environ_encoding ) else: self.writer.environ_encoding = encoding self.force_binary = True # public properties
[docs] def begin_negotiation(self) -> None: """ Begin on-connect negotiation. A Telnet client is expected to send only a minimal amount of client session options immediately after connection, it is generally the server which dictates server option support. Deriving implementations should always call ``super().begin_negotiation()``. """ self._check_later = asyncio.get_event_loop().call_soon(self._check_negotiation_timer) self._tasks.append(self._check_later) # Send proactive WILL/DO for any "always" options if self.writer is not None: for opt in self.writer.always_will: self.writer.iac(WILL, opt) for opt in self.writer.always_do: self.writer.iac(DO, opt)
[docs] def encoding(self, outgoing: bool = False, incoming: bool = False) -> Union[str, bool]: """ Encoding that should be used for the direction indicated. The base implementation **always** returns ``encoding`` argument given to class initializer or, when unset (``None``), ``US-ASCII``. """ return self.default_encoding or "US-ASCII" # pragma: no cover
[docs] def check_negotiation(self, final: bool = False) -> bool: """ Callback, return whether negotiation is complete. :param final: Whether this is the final time this callback will be requested to answer regarding protocol negotiation. :returns: Whether negotiation is over (client end is satisfied). Method is called on each new command byte processed until negotiation is considered final, or after :attr:`connect_maxwait` has elapsed, setting the ``_waiter_connected`` attribute to value ``self`` when complete. If critical negotiations have completed (TTYPE and either NEW_ENVIRON or CHARSET), negotiation is considered complete immediately without waiting for connect_minwait. Otherwise, this method returns False until :attr:`connect_minwait` has elapsed, ensuring the server may batch telnet negotiation demands without prematurely entering the callback shell. Ensure ``super().check_negotiation()`` is called and conditionally combined when derived. """ from .telopt import TTYPE, CHARSET, NEW_ENVIRON # First check if there are any pending options if any(self.writer.pending_option.values()): return False # Check if critical options are enabled (terminal type and encoding info) have_terminal_type = self.writer.local_option.enabled(TTYPE) have_environ = self.writer.local_option.enabled(NEW_ENVIRON) have_charset = self.writer.remote_option.enabled( CHARSET ) and self.writer.local_option.enabled(CHARSET) # If we have terminal type and either environment or charset info, we can bypass the minwait critical_options_negotiated = have_terminal_type and (have_environ or have_charset) if critical_options_negotiated: if final: self.log.debug("Critical options negotiated, bypassing connect_minwait") return True # Otherwise, ensure we wait the minimum time for server to batch commands return self.duration > self.connect_minwait
# private methods def _process_chunk(self, data: bytes) -> bool: """Process a chunk of received bytes; return True if any IAC/SB cmd observed.""" self._last_received = datetime.datetime.now() # MCCP2: decompress server→client data when active if self._mccp2_decompressor is not None: try: data = self._mccp2_decompressor.decompress(data) except zlib.error: if not self._mccp2_wbits_fallback: self.log.debug("MCCP2 auto-detect failed, retrying raw deflate") self._mccp2_wbits_fallback = True self._mccp2_decompressor = zlib.decompressobj(wbits=-zlib.MAX_WBITS) try: data = self._mccp2_decompressor.decompress(data) except zlib.error: self.log.warning("MCCP2 decompression error, disabling") self._mccp2_end() return False else: self.log.warning("MCCP2 decompression error, disabling") self._mccp2_end() return False if self._mccp2_decompressor.eof: unused = self._mccp2_decompressor.unused_data self._mccp2_end() cmd = self._process_chunk_inner(data) if unused: cmd = self._process_chunk(unused) or cmd return cmd return self._process_chunk_inner(data) def _process_chunk_inner(self, data: bytes) -> bool: """Inner chunk processing with IAC interpretation and mid-chunk MCCP2 detection.""" try: mode = self.writer.mode except Exception: mode = "local" slc_needed = (mode == "remote") or (mode == "kludge" and self.writer.slc_simulated) if slc_needed: slc_vals = {defn.val[0] for defn in self.writer.slctab.values() if defn.val != theNULL} slc_special: frozenset[int] | None = frozenset({255} | slc_vals) else: slc_special = None cmd_received = _process_data_chunk( data, self.writer, self.reader, slc_special, self.log.warning ) if self.writer._compressed_remainder is not None: remainder = self.writer._compressed_remainder self.writer._compressed_remainder = None self._mccp2_start() if remainder: cmd_received = self._process_chunk(remainder) or cmd_received # MCCP3: start compressor when writer signals activation if self.writer.mccp3_active and self._mccp3_compressor is None: self._mccp3_start() return cmd_received async def _process_rx(self) -> None: """Async processor for receive queue that yields control and applies backpressure.""" processed = 0 any_cmd = False try: while self._rx_queue: # Stop processing if connection was closed (feed_eof already called) if self._closing: self._rx_queue.clear() self._rx_bytes = 0 break chunk = self._rx_queue.popleft() self._rx_bytes -= len(chunk) cmd = self._process_chunk(chunk) any_cmd = any_cmd or cmd processed += len(chunk) # Resume reading when we've drained below low watermark if self._reading_paused and self._rx_bytes <= self._read_low: if self._transport is not None: try: self._transport.resume_reading() self._reading_paused = False except Exception: pass # Yield periodically to keep loop responsive without excessive context switching if processed >= 128 * 1024: await asyncio.sleep(0) processed = 0 finally: self._rx_task = None # Aggressively re-check negotiation if any command was seen and not yet connected if any_cmd and not self._waiter_connected.done(): self._check_negotiation_timer() def _mccp2_start(self) -> None: """Start MCCP2 decompression of server→client data.""" self._mccp2_decompressor = zlib.decompressobj(wbits=zlib.MAX_WBITS | 32) self._mccp2_wbits_fallback = False self.log.debug("MCCP2 decompression started (server→client)") def _mccp2_end(self) -> None: """Stop MCCP2 decompression.""" self._mccp2_decompressor = None self.writer.mccp2_active = False self.log.debug("MCCP2 decompression ended (server→client)") def _mccp3_start(self) -> None: """Start MCCP3 compression of client→server data.""" self._mccp3_compressor = zlib.compressobj( zlib.Z_BEST_COMPRESSION, zlib.DEFLATED, 12, 5, zlib.Z_DEFAULT_STRATEGY ) # Wrap transport.write so all outbound bytes are compressed transport = self.writer._transport orig_write = transport.write def compressed_write(data: bytes) -> None: if self._mccp3_compressor is not None: compressed = self._mccp3_compressor.compress(data) compressed += self._mccp3_compressor.flush(zlib.Z_SYNC_FLUSH) orig_write(compressed) else: orig_write(data) transport.write = compressed_write # type: ignore[method-assign] self._mccp3_orig_write = orig_write self.log.debug("MCCP3 compression started (client→server)") def _mccp3_end(self) -> None: """Stop MCCP3 compression, flush Z_FINISH.""" if self._mccp3_compressor is not None: if not self.writer.is_closing(): self._mccp3_orig_write(self._mccp3_compressor.flush(zlib.Z_FINISH)) self._mccp3_compressor = None # Restore original transport.write self.writer._transport.write = self._mccp3_orig_write # type: ignore[method-assign] self.writer.mccp3_active = False self.log.debug("MCCP3 compression ended (client→server)") def _check_negotiation_timer(self) -> None: self._check_later.cancel() self._tasks.remove(self._check_later) later = self.connect_maxwait - self.duration final = bool(later < 0) if self.check_negotiation(final=final): self.log.debug("negotiation complete after %1.2fs.", self.duration) self._waiter_connected.set_result(None) elif final: self.log.debug("negotiation failed after %1.2fs.", self.duration) _failed = [ name_commands(cmd_option) for (cmd_option, pending) in self.writer.pending_option.items() if pending ] self.log.debug("failed-reply: %r", ", ".join(_failed)) self._waiter_connected.set_result(None) else: # keep re-queuing until complete. Aggressively re-queue until # connect_minwait, or connect_maxwait, whichever occurs next # in our time-series. sooner = self.connect_minwait - self.duration if sooner > 0: later = sooner self._check_later = asyncio.get_event_loop().call_later( later, self._check_negotiation_timer ) self._tasks.append(self._check_later) _log_exception = staticmethod(_log_exception)