diff --git a/wsrepl/MessageHandler.py b/wsrepl/MessageHandler.py index 6753bff..b740bcf 100644 --- a/wsrepl/MessageHandler.py +++ b/wsrepl/MessageHandler.py @@ -1,6 +1,7 @@ import asyncio import textual import threading +from typing import Callable import ssl from collections import OrderedDict from urllib.parse import urlparse @@ -46,14 +47,16 @@ def __init__(self, exit() self.initial_messages: list[WSMessage] = self._load_initial_messages(initial_msgs_file) - processed_headers: OrderedDict = self._process_headers(headers, headers_file, user_agent, origin, cookies) + + def generate_headers() -> OrderedDict: + return self._process_headers(headers, headers_file, self.plugin.headers_callback, user_agent, origin, cookies) self._ws = WebsocketConnection( # Stuff WebsocketConnection needs to call back to us async_handler=self, # WebSocketApp args url=url, - header=processed_headers + header=generate_headers ) # Args passed to websocket.WebSocketApp.run_forever() @@ -61,7 +64,7 @@ def __init__(self, proxy = 'http://' + proxy if '://' not in proxy else proxy self._ws.connect_args = { - 'suppress_origin': 'Origin' in processed_headers, + 'suppress_origin': self.plugin.sets_origin or origin in self._process_headers(headers, headers_file, None, user_agent, origin, cookies), 'sslopt': {'cert_reqs': ssl.CERT_NONE} if not verify_tls else {}, 'ping_interval': 0, # Disable websocket-client's autoping because it doesn't provide feedback 'http_proxy_host': urlparse(proxy).hostname if proxy else None, @@ -91,8 +94,7 @@ def __init__(self, self.hide_ping_pong = hide_ping_pong self.hide_0x1_ping_pong = hide_0x1_ping_pong - def _process_headers(self, headers: list[str] | None, headers_file: str | None, - user_agent: str | None, origin: str | None, cookies: list[str]) -> OrderedDict: + def _process_headers(self, headers: list[str] | None, headers_file: str | None, get_plugin_headers: Callable[[], list[str]] | None, user_agent: str | None, origin: str | None, cookies: list[str]) -> OrderedDict: """Process headers and return an OrderedDict of headers.""" result = OrderedDict() cookie_headers = [] @@ -129,6 +131,18 @@ def _process_headers(self, headers: list[str] | None, headers_file: str | None, elif name not in result: result[name] = value + # Headers from plugins are next + if get_plugin_headers is not None: + for header in get_plugin_headers(): + name, value = map(str.strip, header.split(":", 1)) + if name in blacklisted_headers: + continue + + if name.lower().strip() == "cookie": + cookie_headers.append(value.strip()) + else: + result[name] = value + # Add User-Agent if not already present if user_agent and "User-Agent" not in result: result["User-Agent"] = user_agent diff --git a/wsrepl/Plugin.py b/wsrepl/Plugin.py index 6f58371..cdca69d 100644 --- a/wsrepl/Plugin.py +++ b/wsrepl/Plugin.py @@ -10,12 +10,17 @@ def __init__(self, message_handler) -> None: self.messages = [] self.ping_0x1_payload = "" self.pong_0x1_payload = "" + self.sets_origin = False self.init() def init(self): """Called when the plugin is loaded""" pass + def headers_callback(self): + """Set headers on connect and reconnect""" + return [] + async def send(self, message: WSMessage) -> None: """Send a message to the server""" await self.handler.send(message)