Skip to content

Commit 15014ed

Browse files
fix: Replace unmaintained sshtunnel dependency with straight Paramiko (#581)
1 parent c59ba04 commit 15014ed

File tree

3 files changed

+167
-24
lines changed

3 files changed

+167
-24
lines changed

pyproject.toml

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ dependencies = [
3131
"psycopg[binary]==3.2.10",
3232
"psycopg2-binary==2.9.10",
3333
"sqlalchemy~=2.0",
34-
"sshtunnel==0.4.0",
3534
"singer-sdk[sql]~=0.52.0",
3635
]
3736

@@ -83,10 +82,6 @@ warn_redundant_casts = true
8382
warn_unused_configs = true
8483
warn_unused_ignores = true
8584

86-
[[tool.mypy.overrides]]
87-
module = ["sshtunnel"]
88-
ignore_missing_imports = true
89-
9085
[build-system]
9186
requires = [
9287
"hatchling==1.27.0",

target_postgres/connector.py

Lines changed: 165 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@
77
import itertools
88
import math
99
import signal
10+
import socket
1011
import sys
12+
import threading
1113
import typing as t
12-
from contextlib import contextmanager
14+
from contextlib import contextmanager, suppress
1315
from functools import cached_property
1416
from os import chmod, path
1517
from typing import cast
@@ -40,12 +42,169 @@
4042
TIMESTAMP,
4143
TypeDecorator,
4244
)
43-
from sshtunnel import SSHTunnelForwarder
4445

4546
if t.TYPE_CHECKING:
4647
from singer_sdk.sql.connector import FullyQualifiedName
4748

4849

50+
class SSHTunnelForwarder:
51+
"""SSH Tunnel forwarder using paramiko.
52+
53+
This class provides SSH tunnel functionality similar to sshtunnel package,
54+
but implemented directly with paramiko.
55+
"""
56+
57+
def __init__(
58+
self,
59+
ssh_address_or_host: tuple[str, int],
60+
ssh_username: str,
61+
ssh_pkey: paramiko.PKey,
62+
ssh_private_key_password: str | None,
63+
remote_bind_address: tuple[str, int],
64+
) -> None:
65+
"""Initialize SSH tunnel forwarder.
66+
67+
Args:
68+
ssh_address_or_host: Tuple of (ssh_host, ssh_port)
69+
ssh_username: SSH username
70+
ssh_pkey: Paramiko private key object
71+
ssh_private_key_password: Private key password (optional)
72+
remote_bind_address: Tuple of (remote_host, remote_port)
73+
"""
74+
self.ssh_host, self.ssh_port = ssh_address_or_host
75+
self.ssh_username = ssh_username
76+
self.ssh_pkey = ssh_pkey
77+
self.ssh_private_key_password = ssh_private_key_password
78+
self.remote_bind_host, self.remote_bind_port = remote_bind_address
79+
80+
self.ssh_client: paramiko.SSHClient | None = None
81+
self.local_bind_host = "127.0.0.1"
82+
self.local_bind_port: int | None = None
83+
self._server_socket: socket.socket | None = None
84+
self._thread: threading.Thread | None = None
85+
self._stop_event = threading.Event()
86+
87+
def start(self) -> None:
88+
"""Start the SSH tunnel."""
89+
# Create SSH client
90+
self.ssh_client = paramiko.SSHClient()
91+
self.ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
92+
93+
# Connect to SSH server
94+
self.ssh_client.connect(
95+
hostname=self.ssh_host,
96+
port=self.ssh_port,
97+
username=self.ssh_username,
98+
pkey=self.ssh_pkey,
99+
passphrase=self.ssh_private_key_password,
100+
)
101+
102+
# Create local socket for port forwarding
103+
self._server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
104+
self._server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
105+
self._server_socket.bind((self.local_bind_host, 0))
106+
self._server_socket.listen(5)
107+
108+
# Get the dynamically assigned local port
109+
self.local_bind_port = self._server_socket.getsockname()[1]
110+
111+
# Start forwarding thread
112+
self._thread = threading.Thread(target=self._forward_tunnel, daemon=True)
113+
self._thread.start()
114+
115+
def _forward_tunnel(self) -> None:
116+
"""Forward connections through the SSH tunnel."""
117+
if self._server_socket is None or self.ssh_client is None:
118+
return
119+
120+
while not self._stop_event.is_set():
121+
try:
122+
# Set timeout so we can check stop event periodically
123+
self._server_socket.settimeout(1.0)
124+
try:
125+
local_socket, _ = self._server_socket.accept()
126+
except TimeoutError:
127+
continue
128+
129+
# Create channel through SSH tunnel
130+
transport = self.ssh_client.get_transport()
131+
if transport is None:
132+
local_socket.close()
133+
continue
134+
135+
channel = transport.open_channel(
136+
"direct-tcpip",
137+
(self.remote_bind_host, self.remote_bind_port),
138+
local_socket.getpeername(),
139+
)
140+
141+
# Start forwarding data between local socket and channel
142+
threading.Thread(
143+
target=self._forward_data,
144+
args=(local_socket, channel),
145+
daemon=True,
146+
).start()
147+
except OSError:
148+
if not self._stop_event.is_set():
149+
break
150+
151+
def _forward_data(
152+
self, local_socket: socket.socket, channel: paramiko.Channel
153+
) -> None:
154+
"""Forward data between local socket and SSH channel.
155+
156+
Args:
157+
local_socket: Local socket
158+
channel: SSH channel
159+
"""
160+
try:
161+
162+
def forward_local_to_remote():
163+
while True:
164+
data = local_socket.recv(4096)
165+
if len(data) == 0:
166+
break
167+
channel.send(data)
168+
channel.close()
169+
170+
def forward_remote_to_local():
171+
while True:
172+
data = channel.recv(4096)
173+
if len(data) == 0:
174+
break
175+
local_socket.send(data)
176+
local_socket.close()
177+
178+
# Start both forwarding directions
179+
t1 = threading.Thread(target=forward_local_to_remote, daemon=True)
180+
t2 = threading.Thread(target=forward_remote_to_local, daemon=True)
181+
t1.start()
182+
t2.start()
183+
t1.join()
184+
t2.join()
185+
except OSError:
186+
pass
187+
finally:
188+
with suppress(OSError):
189+
local_socket.close()
190+
with suppress(OSError):
191+
channel.close()
192+
193+
def stop(self) -> None:
194+
"""Stop the SSH tunnel."""
195+
self._stop_event.set()
196+
197+
if self._server_socket:
198+
with suppress(OSError):
199+
self._server_socket.close()
200+
201+
if self._thread and self._thread.is_alive():
202+
self._thread.join(timeout=2.0)
203+
204+
if self.ssh_client:
205+
self.ssh_client.close()
206+
207+
49208
class JSONSchemaToPostgres(JSONSchemaToSQL):
50209
"""Convert JSON Schema types to Postgres types."""
51210

@@ -88,10 +247,13 @@ def __init__(self, config: dict) -> None:
88247
"""
89248
url: URL = make_url(self.get_sqlalchemy_url(config=config))
90249
ssh_config = config.get("ssh_tunnel", {})
91-
self.ssh_tunnel: SSHTunnelForwarder
250+
self.ssh_tunnel: SSHTunnelForwarder | None = None
92251

93252
if ssh_config.get("enable", False):
94253
# Return a new URL with SSH tunnel parameters
254+
if url.host is None or url.port is None:
255+
msg = "Database host and port must be specified for SSH tunnel"
256+
raise ValueError(msg)
95257
self.ssh_tunnel = SSHTunnelForwarder(
96258
ssh_address_or_host=(ssh_config["host"], ssh_config["port"]),
97259
ssh_username=ssh_config["username"],

uv.lock

Lines changed: 2 additions & 16 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)