11import subprocess
22from typing import Optional , Dict , cast
3+
4+ import paramiko
35from dagster import DagsterError
4- from dagster_teradata . ttu . utils . encryption_utils import SecureCredentialManager
6+
57from dagster_teradata .ttu .utils .tpt_util import (
68 prepare_tpt_ddl_script ,
79 get_remote_temp_directory ,
1012 is_valid_file ,
1113 read_file ,
1214)
13- from dagster_teradata . ttu . utils . util import _setup_ssh_connection
15+
1416from dagster_teradata .ttu .tpt_executer import (
1517 execute_ddl ,
1618 _execute_tdload_via_ssh ,
1719 _execute_tdload_locally ,
1820 execute_tdload ,
19- TPTExecutor ,
2021)
22+
23+ from dagster_teradata .ttu .utils .encryption_utils import (
24+ SecureCredentialManager ,
25+ generate_random_password ,
26+ generate_encrypted_file_with_openssl ,
27+ decrypt_remote_file_to_string ,
28+ get_stored_credentials ,
29+ )
30+
2131from paramiko .client import SSHClient
2232
2333
34+
35+
36+
2437class DdlOperator :
2538 """Operator for executing DDL statements on Teradata using TPT."""
2639
@@ -40,6 +53,63 @@ def __init__(self, connection, teradata_connection_resource, log):
4053 self .ssh_key_path = None
4154 self .remote_port = None
4255
def _setup_ssh_connection(
    self,
    host: str,
    user: Optional[str],
    password: Optional[str],
    key_path: Optional[str],
    port: int,
) -> bool:
    """
    Open an SSH session to ``host`` and keep it on ``self.ssh_client``.

    When ``key_path`` is given, the private key at that path is used for
    authentication. Otherwise password authentication is used; if no
    password was supplied, ``get_stored_credentials`` is queried for
    ``host``/``user`` and a returned ``"password"`` entry is decrypted
    with ``self.cred_manager``. If nothing is stored, ``connect`` is still
    attempted with ``password=None``.

    Args:
        host: Remote hostname
        user: Remote username (required when stored credentials are looked up)
        password: Remote password (optional if key_path provided)
        key_path: Path to SSH private key (optional if password provided)
        port: SSH port

    Returns:
        bool: True once the connection is established. Failure is never
        signalled by returning False; it is always raised.

    Raises:
        DagsterError: If any step fails (missing username, unreadable key,
            credential lookup, or the SSH connect itself). The original
            exception is attached as ``__cause__``.
    """
    client = None
    try:
        client = paramiko.SSHClient()
        # NOTE(review): AutoAddPolicy accepts unknown host keys without
        # verification; confirm this is acceptable for the deployment.
        client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        self.ssh_client = client

        if key_path:
            # NOTE(review): RSAKey only loads RSA private keys; other key
            # types will fail here and surface as a DagsterError.
            key = paramiko.RSAKey.from_private_key_file(key_path)
            client.connect(host, port=port, username=user, pkey=key)
        else:
            if not password:
                if user is None:
                    raise ValueError(
                        "Username is required to fetch stored credentials"
                    )
                # Attempt to retrieve stored credentials
                creds = get_stored_credentials(self, host, user)
                password = (
                    self.cred_manager.decrypt(creds["password"]) if creds else None
                )

            client.connect(host, port=port, username=user, password=password)

        self.log.info(f"SSH connected to {user}@{host}")
        return True
    except Exception as e:
        # Do not leave a half-open client behind: release its resources and
        # clear the attribute so later `if self.ssh_client` checks are honest.
        if client is not None:
            client.close()
        self.ssh_client = None
        raise DagsterError(f"SSH connection failed: {e}") from e
112+
43113 def ddl_operator (
44114 self ,
45115 ddl : list [str ] = None ,
@@ -81,15 +151,16 @@ def ddl_operator(
81151 )
82152
83153 if self .remote_host :
84- self . ssh_client = _setup_ssh_connection (
154+ if not self . _setup_ssh_connection (
85155 host = self .remote_host ,
86156 user = cast (str , self .remote_user ),
87157 password = self .remote_remote_password ,
88158 key_path = self .ssh_key_path ,
89159 port = self .remote_port ,
90- )
91- if not self .ssh_client :
92- raise DagsterError ("Failed to establish SSH connection." )
160+ ):
161+ raise DagsterError (
162+ "Failed to establish SSH connection. Please check the provided credentials."
163+ )
93164
94165 if self .ssh_client and not self .remote_working_dir :
95166 self .remote_working_dir = get_remote_temp_directory (
@@ -99,7 +170,7 @@ def ddl_operator(
99170 if not self .remote_working_dir :
100171 self .remote_working_dir = "/tmp"
101172
102- return execute_ddl (tpt_ddl_script , self .remote_working_dir )
173+ return execute_ddl (self , tpt_ddl_script , self .remote_working_dir )
103174 except Exception as e :
104175 self .log .error ("DDL execution failed: %s" , str (e ))
105176 raise
@@ -167,6 +238,63 @@ def __init__(
167238 self .ssh_key_path = None
168239 self .remote_port = None
169240
def _setup_ssh_connection(
    self,
    host: str,
    user: Optional[str],
    password: Optional[str],
    key_path: Optional[str],
    port: int,
) -> bool:
    """
    Open an SSH session to ``host`` and keep it on ``self.ssh_client``.

    When ``key_path`` is given, the private key at that path is used for
    authentication. Otherwise password authentication is used; if no
    password was supplied, ``get_stored_credentials`` is queried for
    ``host``/``user`` and a returned ``"password"`` entry is decrypted
    with ``self.cred_manager``. If nothing is stored, ``connect`` is still
    attempted with ``password=None``.

    Args:
        host: Remote hostname
        user: Remote username (required when stored credentials are looked up)
        password: Remote password (optional if key_path provided)
        key_path: Path to SSH private key (optional if password provided)
        port: SSH port

    Returns:
        bool: True once the connection is established. Failure is never
        signalled by returning False; it is always raised.

    Raises:
        DagsterError: If any step fails (missing username, unreadable key,
            credential lookup, or the SSH connect itself). The original
            exception is attached as ``__cause__``.
    """
    client = None
    try:
        client = paramiko.SSHClient()
        # NOTE(review): AutoAddPolicy accepts unknown host keys without
        # verification; confirm this is acceptable for the deployment.
        client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        self.ssh_client = client

        if key_path:
            # NOTE(review): RSAKey only loads RSA private keys; other key
            # types will fail here and surface as a DagsterError.
            key = paramiko.RSAKey.from_private_key_file(key_path)
            client.connect(host, port=port, username=user, pkey=key)
        else:
            if not password:
                if user is None:
                    raise ValueError(
                        "Username is required to fetch stored credentials"
                    )
                # Attempt to retrieve stored credentials
                creds = get_stored_credentials(self, host, user)
                password = (
                    self.cred_manager.decrypt(creds["password"]) if creds else None
                )

            client.connect(host, port=port, username=user, password=password)

        self.log.info(f"SSH connected to {user}@{host}")
        return True
    except Exception as e:
        # Do not leave a half-open client behind: release its resources and
        # clear the attribute so later `if self.ssh_client` checks are honest.
        if client is not None:
            client.close()
        self.ssh_client = None
        raise DagsterError(f"SSH connection failed: {e}") from e
297+
170298 def tdload_operator (
171299 self ,
172300 source_table : Optional [str ] = None ,
@@ -221,16 +349,16 @@ def tdload_operator(
221349 self .log .info ("Prepared job vars for mode '%s'" , mode )
222350
223351 if self .remote_host :
224- self .ssh_client = _setup_ssh_connection (
225- self ,
352+ if not self ._setup_ssh_connection (
226353 host = self .remote_host ,
227354 user = cast (str , self .remote_user ),
228355 password = self .remote_remote_password ,
229356 key_path = self .ssh_key_path ,
230357 port = self .remote_port ,
231- )
232- if not self .ssh_client :
233- raise DagsterError ("SSH connection failed." )
358+ ):
359+ raise DagsterError (
360+ "Failed to establish SSH connection. Please check the provided credentials."
361+ )
234362
235363 if self .remote_host and not self .remote_working_dir :
236364 self .remote_working_dir = get_remote_temp_directory (
@@ -308,6 +436,7 @@ def _execute_based_on_configuration(
308436 )
309437 raise ValueError ("Invalid remote job var file path" )
310438
439+ else :
311440 return execute_tdload (
312441 self ,
313442 log = self .log ,
@@ -324,15 +453,15 @@ def _execute_based_on_configuration(
324453 file_path = tdload_job_var_file
325454 )
326455 raise ValueError ("Invalid local job var file path" )
327-
328- return execute_tdload (
329- self ,
330- log = self .log ,
331- remote_working_dir = self .remote_working_dir or "/tmp" ,
332- job_var_content = tdload_job_var_content ,
333- tdload_options = self .tdload_options ,
334- tdload_job_name = self .tdload_job_name ,
335- )
456+ else :
457+ return execute_tdload (
458+ self ,
459+ log = self .log ,
460+ remote_working_dir = self .remote_working_dir or "/tmp" ,
461+ job_var_content = tdload_job_var_content ,
462+ tdload_options = self .tdload_options ,
463+ tdload_job_name = self .tdload_job_name ,
464+ )
336465
337466 def _handle_remote_job_var_file (
338467 self , ssh_client : SSHClient , file_path : str | None
@@ -399,54 +528,3 @@ def on_kill(self):
399528 process .kill ()
400529 except Exception as e :
401530 self .log .error ("Process termination failed: %s" , str (e ))
402-
403-
class TPTOperator:
    """Enhanced TPT operator based on Airflow's TPTOperator.

    Holds the settings for one TPT job and hands them to ``TPTExecutor``
    when ``execute`` is called.
    """

    def __init__(
        self,
        teradata_connection_resource,
        operator_type: str,
        script: Optional[str] = None,
        variables: Optional[Dict] = None,
        source_table: Optional[str] = None,
        select_stmt: Optional[str] = None,
        target_table: Optional[str] = None,
        source_file: Optional[str] = None,
        target_file: Optional[str] = None,
        format_options: Optional[Dict] = None,
        ssh_conn_params: Optional[Dict] = None,
        log=None,
    ):
        """Store the job configuration; nothing is executed here."""
        # Connection-related settings.
        self.teradata_connection_resource = teradata_connection_resource
        self.ssh_conn_params = ssh_conn_params
        self.log = log

        # What kind of job to run and the optional script for it.
        self.operator_type = operator_type
        self.script = script

        # Source/target descriptors are kept exactly as passed in.
        self.source_table = source_table
        self.select_stmt = select_stmt
        self.target_table = target_table
        self.source_file = source_file
        self.target_file = target_file

        # A falsy mapping (None or empty) is replaced by a fresh empty dict.
        self.variables = variables if variables else {}
        self.format_options = format_options if format_options else {}

    def execute(self) -> int:
        """Execute the TPT operation.

        Returns:
            Whatever ``TPTExecutor.run_tpt_job`` returns (annotated as int).
        """
        # Connection parameters come from the resource's private attribute.
        connection_args = self.teradata_connection_resource._connection_args

        job_kwargs = {
            "operator_type": self.operator_type,
            "script": self.script,
            "variables": self.variables,
            "source_table": self.source_table,
            "select_stmt": self.select_stmt,
            "target_table": self.target_table,
            "source_file": self.source_file,
            "target_file": self.target_file,
            "format_options": self.format_options,
        }

        # The executor is used as a context manager around the job run.
        with TPTExecutor(connection_args, self.ssh_conn_params) as executor:
            return executor.run_tpt_job(**job_kwargs)
0 commit comments