@@ -16,17 +16,18 @@ class Athena:
16
16
def __init__ (self , session ):
17
17
self ._session = session
18
18
self ._client_athena = session .boto3_session .client (
19
- service_name = "athena" , config = session .botocore_config )
19
+ service_name = "athena" , config = session .botocore_config
20
+ )
20
21
21
22
def get_query_columns_metadata (self , query_execution_id ):
22
23
response = self ._client_athena .get_query_results (
23
- QueryExecutionId = query_execution_id , MaxResults = 1 )
24
+ QueryExecutionId = query_execution_id , MaxResults = 1
25
+ )
24
26
col_info = response ["ResultSet" ]["ResultSetMetadata" ]["ColumnInfo" ]
25
27
return {x ["Name" ]: x ["Type" ] for x in col_info }
26
28
27
29
def get_query_dtype (self , query_execution_id ):
28
- cols_metadata = self .get_query_columns_metadata (
29
- query_execution_id = query_execution_id )
30
+ cols_metadata = self .get_query_columns_metadata (query_execution_id = query_execution_id )
30
31
logger .debug (f"cols_metadata: { cols_metadata } " )
31
32
dtype = {}
32
33
parse_timestamps = []
@@ -53,10 +54,11 @@ def create_athena_bucket(self):
53
54
54
55
:return: Bucket s3 path (E.g. s3://aws-athena-query-results-ACCOUNT-REGION/)
55
56
"""
56
- account_id = (self ._session .boto3_session .client (
57
- service_name = "sts" ,
58
- config = self ._session .botocore_config ).get_caller_identity ().get (
59
- "Account" ))
57
+ account_id = (
58
+ self ._session .boto3_session .client (
59
+ service_name = "sts" , config = self ._session .botocore_config
60
+ ).get_caller_identity ().get ("Account" )
61
+ )
60
62
session_region = self ._session .boto3_session .region_name
61
63
s3_output = f"s3://aws-athena-query-results-{ account_id } -{ session_region } /"
62
64
s3_resource = self ._session .boto3_session .resource ("s3" )
@@ -82,7 +84,8 @@ def run_query(self, query, database, s3_output=None, workgroup=None):
82
84
QueryString = query ,
83
85
QueryExecutionContext = {"Database" : database },
84
86
ResultConfiguration = {"OutputLocation" : s3_output },
85
- WorkGroup = workgroup )
87
+ WorkGroup = workgroup
88
+ )
86
89
return response ["QueryExecutionId" ]
87
90
88
91
def wait_query (self , query_execution_id ):
@@ -93,24 +96,20 @@ def wait_query(self, query_execution_id):
93
96
:return: Query response
94
97
"""
95
98
final_states = ["FAILED" , "SUCCEEDED" , "CANCELLED" ]
96
- response = self ._client_athena .get_query_execution (
97
- QueryExecutionId = query_execution_id )
99
+ response = self ._client_athena .get_query_execution (QueryExecutionId = query_execution_id )
98
100
state = response ["QueryExecution" ]["Status" ]["State" ]
99
101
while state not in final_states :
100
102
sleep (QUERY_WAIT_POLLING_DELAY )
101
- response = self ._client_athena .get_query_execution (
102
- QueryExecutionId = query_execution_id )
103
+ response = self ._client_athena .get_query_execution (QueryExecutionId = query_execution_id )
103
104
state = response ["QueryExecution" ]["Status" ]["State" ]
104
105
logger .debug (f"state: { state } " )
105
106
logger .debug (
106
107
f"StateChangeReason: { response ['QueryExecution' ]['Status' ].get ('StateChangeReason' )} "
107
108
)
108
109
if state == "FAILED" :
109
- raise QueryFailed (
110
- response ["QueryExecution" ]["Status" ].get ("StateChangeReason" ))
110
+ raise QueryFailed (response ["QueryExecution" ]["Status" ].get ("StateChangeReason" ))
111
111
elif state == "CANCELLED" :
112
- raise QueryCancelled (
113
- response ["QueryExecution" ]["Status" ].get ("StateChangeReason" ))
112
+ raise QueryCancelled (response ["QueryExecution" ]["Status" ].get ("StateChangeReason" ))
114
113
return response
115
114
116
115
def repair_table (self , database , table , s3_output = None , workgroup = None ):
@@ -130,17 +129,17 @@ def repair_table(self, database, table, s3_output=None, workgroup=None):
130
129
:return: Query execution ID
131
130
"""
132
131
query = f"MSCK REPAIR TABLE { table } ;"
133
- query_id = self .run_query (query = query ,
134
- database = database ,
135
- s3_output = s3_output ,
136
- workgroup = workgroup )
132
+ query_id = self .run_query (
133
+ query = query , database = database , s3_output = s3_output , workgroup = workgroup
134
+ )
137
135
self .wait_query (query_execution_id = query_id )
138
136
return query_id
139
137
140
138
@staticmethod
141
139
def _normalize_name (name ):
142
- name = "" .join (c for c in unicodedata .normalize ("NFD" , name )
143
- if unicodedata .category (c ) != "Mn" )
140
+ name = "" .join (
141
+ c for c in unicodedata .normalize ("NFD" , name ) if unicodedata .category (c ) != "Mn"
142
+ )
144
143
name = name .replace (" " , "_" )
145
144
name = name .replace ("-" , "_" )
146
145
name = name .replace ("." , "_" )
0 commit comments