Skip to content

Commit d020d52

Browse files
committed
Made tdload code changes
1 parent 0943ab2 commit d020d52

File tree

6 files changed

+1173
-369
lines changed

6 files changed

+1173
-369
lines changed

libraries/dagster-teradata/dagster_teradata/resources.py

Lines changed: 1 addition & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030
from . import constants
3131
from dagster_teradata.ttu.bteq import Bteq
32-
from dagster_teradata.ttu.tpt import DdlOperator, TdLoadOperator, TPTOperator
32+
from dagster_teradata.ttu.tpt import DdlOperator, TdLoadOperator
3333
from dagster_teradata.teradata_compute_cluster_manager import (
3434
TeradataComputeClusterManager,
3535
)
@@ -676,53 +676,6 @@ def tdload_operator(
676676
remote_port=remote_port,
677677
)
678678

679-
def run_tpt_job(
680-
self,
681-
operator_type: str,
682-
script: Optional[str] = None,
683-
variables: Optional[Dict] = None,
684-
source_table: Optional[str] = None,
685-
select_stmt: Optional[str] = None,
686-
target_table: Optional[str] = None,
687-
source_file: Optional[str] = None,
688-
target_file: Optional[str] = None,
689-
format_options: Optional[Dict] = None,
690-
ssh_conn_params: Optional[Dict] = None,
691-
) -> int:
692-
"""
693-
Run a TPT job with the given parameters.
694-
695-
Args:
696-
operator_type (str): Type of TPT operator (tdload, tbuild, tpump, etc.)
697-
script (Optional[str]): TPT script content
698-
variables (Optional[Dict]): TPT variables
699-
source_table (Optional[str]): Source table name
700-
select_stmt (Optional[str]): SELECT statement for data extraction
701-
target_table (Optional[str]): Target table name
702-
source_file (Optional[str]): Source file path
703-
target_file (Optional[str]): Target file path
704-
format_options (Optional[Dict]): Format options for source/target
705-
ssh_conn_params (Optional[Dict]): SSH connection parameters for remote execution
706-
707-
Returns:
708-
int: Return code from the TPT operation
709-
"""
710-
operator = TPTOperator(
711-
teradata_connection_resource=self.teradata_connection_resource,
712-
operator_type=operator_type,
713-
script=script,
714-
variables=variables,
715-
source_table=source_table,
716-
select_stmt=select_stmt,
717-
target_table=target_table,
718-
source_file=source_file,
719-
target_file=target_file,
720-
format_options=format_options,
721-
ssh_conn_params=ssh_conn_params,
722-
log=self.log,
723-
)
724-
return operator.execute()
725-
726679
def drop_database(self, databases: Union[str, Sequence[str]]) -> None:
727680
"""
728681
Drop one or more databases in Teradata.

libraries/dagster-teradata/dagster_teradata/ttu/bteq.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,7 @@ def _transfer_to_and_execute_bteq_on_remote(
416416
raise DagsterError(
417417
"Failed to establish SSH connection. `ssh_client` is None."
418418
)
419+
self.log.info("Successfully established SSH connection with host: %s", remote_host)
419420
verify_bteq_installed_remote(self.ssh_client)
420421
password = generate_random_password() # Encryption/Decryption password
421422
encrypted_file_path = os.path.join(tmp_dir, "bteq_script.enc")
@@ -542,7 +543,7 @@ def execute_bteq_script_at_local(
542543
stdout=subprocess.PIPE,
543544
stderr=subprocess.STDOUT,
544545
shell=True,
545-
# preexec_fn=os.setsid,
546+
preexec_fn=os.setsid,
546547
)
547548
encode_bteq_script = bteq_script.encode(str(temp_file_read_encoding or "UTF-8"))
548549
self.log.debug("encode_bteq_script : %s", encode_bteq_script)

libraries/dagster-teradata/dagster_teradata/ttu/tpt.py

Lines changed: 151 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import subprocess
22
from typing import Optional, Dict, cast
3+
4+
import paramiko
35
from dagster import DagsterError
4-
from dagster_teradata.ttu.utils.encryption_utils import SecureCredentialManager
6+
57
from dagster_teradata.ttu.utils.tpt_util import (
68
prepare_tpt_ddl_script,
79
get_remote_temp_directory,
@@ -10,17 +12,28 @@
1012
is_valid_file,
1113
read_file,
1214
)
13-
from dagster_teradata.ttu.utils.util import _setup_ssh_connection
15+
1416
from dagster_teradata.ttu.tpt_executer import (
1517
execute_ddl,
1618
_execute_tdload_via_ssh,
1719
_execute_tdload_locally,
1820
execute_tdload,
19-
TPTExecutor,
2021
)
22+
23+
from dagster_teradata.ttu.utils.encryption_utils import (
24+
SecureCredentialManager,
25+
generate_random_password,
26+
generate_encrypted_file_with_openssl,
27+
decrypt_remote_file_to_string,
28+
get_stored_credentials,
29+
)
30+
2131
from paramiko.client import SSHClient
2232

2333

34+
35+
36+
2437
class DdlOperator:
2538
"""Operator for executing DDL statements on Teradata using TPT."""
2639

@@ -40,6 +53,63 @@ def __init__(self, connection, teradata_connection_resource, log):
4053
self.ssh_key_path = None
4154
self.remote_port = None
4255

56+
def _setup_ssh_connection(
57+
self,
58+
host: str,
59+
user: Optional[str],
60+
password: Optional[str],
61+
key_path: Optional[str],
62+
port: int,
63+
) -> bool:
64+
"""
65+
Establish SSH connection using either password or key authentication.
66+
67+
Args:
68+
host: Remote hostname
69+
user: Remote username
70+
password: Remote password (optional if key_path provided)
71+
key_path: Path to an RSA SSH private key file (optional if password provided)
72+
port: SSH port
73+
74+
Returns:
75+
bool: Always True on success; any failure raises DagsterError instead of returning False
76+
77+
Raises:
78+
DagsterError: If connection fails
79+
80+
Note:
81+
- Tries stored credentials if no password provided
82+
- Does not prompt for a password: if no stored credentials are found, connects with password=None
83+
- This method itself does not store credentials after a successful connection
84+
"""
85+
try:
86+
self.ssh_client = paramiko.SSHClient()
87+
self.ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
88+
89+
if key_path:
90+
key = paramiko.RSAKey.from_private_key_file(key_path)
91+
self.ssh_client.connect(host, port=port, username=user, pkey=key)
92+
else:
93+
if not password:
94+
if user is None:
95+
raise ValueError(
96+
"Username is required to fetch stored credentials"
97+
)
98+
# Attempt to retrieve stored credentials
99+
creds = get_stored_credentials(self, host, user)
100+
password = (
101+
self.cred_manager.decrypt(creds["password"]) if creds else None
102+
)
103+
104+
self.ssh_client.connect(
105+
host, port=port, username=user, password=password
106+
)
107+
108+
self.log.info(f"SSH connected to {user}@{host}")
109+
return True
110+
except Exception as e:
111+
raise DagsterError(f"SSH connection failed: {e}")
112+
43113
def ddl_operator(
44114
self,
45115
ddl: list[str] = None,
@@ -81,15 +151,16 @@ def ddl_operator(
81151
)
82152

83153
if self.remote_host:
84-
self.ssh_client = _setup_ssh_connection(
154+
if not self._setup_ssh_connection(
85155
host=self.remote_host,
86156
user=cast(str, self.remote_user),
87157
password=self.remote_remote_password,
88158
key_path=self.ssh_key_path,
89159
port=self.remote_port,
90-
)
91-
if not self.ssh_client:
92-
raise DagsterError("Failed to establish SSH connection.")
160+
):
161+
raise DagsterError(
162+
"Failed to establish SSH connection. Please check the provided credentials."
163+
)
93164

94165
if self.ssh_client and not self.remote_working_dir:
95166
self.remote_working_dir = get_remote_temp_directory(
@@ -99,7 +170,7 @@ def ddl_operator(
99170
if not self.remote_working_dir:
100171
self.remote_working_dir = "/tmp"
101172

102-
return execute_ddl(tpt_ddl_script, self.remote_working_dir)
173+
return execute_ddl(self, tpt_ddl_script, self.remote_working_dir)
103174
except Exception as e:
104175
self.log.error("DDL execution failed: %s", str(e))
105176
raise
@@ -167,6 +238,63 @@ def __init__(
167238
self.ssh_key_path = None
168239
self.remote_port = None
169240

241+
def _setup_ssh_connection(
242+
self,
243+
host: str,
244+
user: Optional[str],
245+
password: Optional[str],
246+
key_path: Optional[str],
247+
port: int,
248+
) -> bool:
249+
"""
250+
Establish SSH connection using either password or key authentication.
251+
252+
Args:
253+
host: Remote hostname
254+
user: Remote username
255+
password: Remote password (optional if key_path provided)
256+
key_path: Path to an RSA SSH private key file (optional if password provided)
257+
port: SSH port
258+
259+
Returns:
260+
bool: Always True on success; any failure raises DagsterError instead of returning False
261+
262+
Raises:
263+
DagsterError: If connection fails
264+
265+
Note:
266+
- Tries stored credentials if no password provided
267+
- Does not prompt for a password: if no stored credentials are found, connects with password=None
268+
- This method itself does not store credentials after a successful connection
269+
"""
270+
try:
271+
self.ssh_client = paramiko.SSHClient()
272+
self.ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
273+
274+
if key_path:
275+
key = paramiko.RSAKey.from_private_key_file(key_path)
276+
self.ssh_client.connect(host, port=port, username=user, pkey=key)
277+
else:
278+
if not password:
279+
if user is None:
280+
raise ValueError(
281+
"Username is required to fetch stored credentials"
282+
)
283+
# Attempt to retrieve stored credentials
284+
creds = get_stored_credentials(self, host, user)
285+
password = (
286+
self.cred_manager.decrypt(creds["password"]) if creds else None
287+
)
288+
289+
self.ssh_client.connect(
290+
host, port=port, username=user, password=password
291+
)
292+
293+
self.log.info(f"SSH connected to {user}@{host}")
294+
return True
295+
except Exception as e:
296+
raise DagsterError(f"SSH connection failed: {e}")
297+
170298
def tdload_operator(
171299
self,
172300
source_table: Optional[str] = None,
@@ -221,16 +349,16 @@ def tdload_operator(
221349
self.log.info("Prepared job vars for mode '%s'", mode)
222350

223351
if self.remote_host:
224-
self.ssh_client = _setup_ssh_connection(
225-
self,
352+
if not self._setup_ssh_connection(
226353
host=self.remote_host,
227354
user=cast(str, self.remote_user),
228355
password=self.remote_remote_password,
229356
key_path=self.ssh_key_path,
230357
port=self.remote_port,
231-
)
232-
if not self.ssh_client:
233-
raise DagsterError("SSH connection failed.")
358+
):
359+
raise DagsterError(
360+
"Failed to establish SSH connection. Please check the provided credentials."
361+
)
234362

235363
if self.remote_host and not self.remote_working_dir:
236364
self.remote_working_dir = get_remote_temp_directory(
@@ -308,6 +436,7 @@ def _execute_based_on_configuration(
308436
)
309437
raise ValueError("Invalid remote job var file path")
310438

439+
else:
311440
return execute_tdload(
312441
self,
313442
log=self.log,
@@ -324,15 +453,15 @@ def _execute_based_on_configuration(
324453
file_path=tdload_job_var_file
325454
)
326455
raise ValueError("Invalid local job var file path")
327-
328-
return execute_tdload(
329-
self,
330-
log=self.log,
331-
remote_working_dir=self.remote_working_dir or "/tmp",
332-
job_var_content=tdload_job_var_content,
333-
tdload_options=self.tdload_options,
334-
tdload_job_name=self.tdload_job_name,
335-
)
456+
else:
457+
return execute_tdload(
458+
self,
459+
log=self.log,
460+
remote_working_dir=self.remote_working_dir or "/tmp",
461+
job_var_content=tdload_job_var_content,
462+
tdload_options=self.tdload_options,
463+
tdload_job_name=self.tdload_job_name,
464+
)
336465

337466
def _handle_remote_job_var_file(
338467
self, ssh_client: SSHClient, file_path: str | None
@@ -399,54 +528,3 @@ def on_kill(self):
399528
process.kill()
400529
except Exception as e:
401530
self.log.error("Process termination failed: %s", str(e))
402-
403-
404-
class TPTOperator:
405-
"""Enhanced TPT operator based on Airflow's TPTOperator."""
406-
407-
def __init__(
408-
self,
409-
teradata_connection_resource,
410-
operator_type: str,
411-
script: Optional[str] = None,
412-
variables: Optional[Dict] = None,
413-
source_table: Optional[str] = None,
414-
select_stmt: Optional[str] = None,
415-
target_table: Optional[str] = None,
416-
source_file: Optional[str] = None,
417-
target_file: Optional[str] = None,
418-
format_options: Optional[Dict] = None,
419-
ssh_conn_params: Optional[Dict] = None,
420-
log=None,
421-
):
422-
self.teradata_connection_resource = teradata_connection_resource
423-
self.operator_type = operator_type
424-
self.script = script
425-
self.variables = variables or {}
426-
self.source_table = source_table
427-
self.select_stmt = select_stmt
428-
self.target_table = target_table
429-
self.source_file = source_file
430-
self.target_file = target_file
431-
self.format_options = format_options or {}
432-
self.ssh_conn_params = ssh_conn_params
433-
self.log = log
434-
435-
def execute(self) -> int:
436-
"""Execute the TPT operation."""
437-
# Get connection parameters
438-
conn_params = self.teradata_connection_resource._connection_args
439-
440-
# Create TPT executor
441-
with TPTExecutor(conn_params, self.ssh_conn_params) as executor:
442-
return executor.run_tpt_job(
443-
operator_type=self.operator_type,
444-
script=self.script,
445-
variables=self.variables,
446-
source_table=self.source_table,
447-
select_stmt=self.select_stmt,
448-
target_table=self.target_table,
449-
source_file=self.source_file,
450-
target_file=self.target_file,
451-
format_options=self.format_options,
452-
)

0 commit comments

Comments
 (0)