Skip to content

Commit 2ddbf39

Browse files
committed
Refactored scopes and auth
1 parent 3fbc632 commit 2ddbf39

File tree

11 files changed

+109
-193
lines changed

11 files changed

+109
-193
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,3 +87,4 @@ logs/*
8787

8888
# etc
8989
swat/etc/*chrome*
90+
swat/etc/custom_config.yaml

swat/base.py

Lines changed: 46 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,25 @@
11

2+
import copy
23
import dataclasses
34
import logging
45
import pickle
56
from dataclasses import dataclass, field
67
from pathlib import Path
78
from typing import Optional, Literal, Union
89

9-
import json
1010
from google.auth.transport.requests import Request
11+
from google.oauth2.service_account import Credentials as ServiceCredentials
1112
from google.oauth2.credentials import Credentials
1213

13-
from .utils import ROOT_DIR, PathlibEncoder
14+
import json
15+
import yaml
16+
17+
from .utils import ROOT_DIR, deep_merge
1418

1519

1620
DEFAULT_CRED_STORE_FILE = ROOT_DIR / 'swat' / 'etc' / '.cred_store.pkl'
1721
DEFAULT_EMULATION_ARTIFACTS_DIR = ROOT_DIR / 'swat' / 'etc' / 'artifacts'
22+
DEFAULT_CUSTOM_CONFIG_PATH = ROOT_DIR / 'swat' / 'etc' / 'custom_config.yaml'
1823

1924

2025
@dataclass
@@ -73,18 +78,22 @@ def to_dict(self):
7378
class Cred:
7479

7580
creds: Optional[CRED_TYPES]
76-
session: Optional[Credentials]
81+
82+
def session(self, scopes: Optional[list[str]] = None) -> Optional[Credentials]:
83+
if isinstance(self.creds, OAuthCreds):
84+
session = Credentials.from_authorized_user_info(str(self.creds.to_dict()), scopes=scopes)
85+
else:
86+
session = ServiceCredentials.from_service_account_info(self.creds.to_dict(), scopes=scopes)
87+
88+
if session.expired and session.refresh_token:
89+
session.refresh(Request())
90+
return session
7791

7892
@property
7993
def client_id(self) -> Optional[str]:
8094
if self.creds and hasattr(self.creds, 'client_id'):
8195
return self.creds.client_id
8296

83-
def refreshed_session(self) -> Optional[Credentials]:
84-
if self.session and self.session.expired and self.session.refresh_token:
85-
self.session.refresh(Request())
86-
return self.session
87-
8897
def to_dict(self):
8998
return {k: v for k, v in dataclasses.asdict(self).items() if not k.startswith('_')}
9099

@@ -100,14 +109,6 @@ def __post_init__(self):
100109
if not isinstance(self.path, Path):
101110
self.path = Path(self.path)
102111

103-
@property
104-
def has_sessions(self) -> bool:
105-
"""Return a boolean indicating if the creds have sessions."""
106-
for key, cred in self.store.items():
107-
if cred.session:
108-
return True
109-
return False
110-
111112
@classmethod
112113
def from_file(cls, file: Path = DEFAULT_CRED_STORE_FILE) -> Optional['CredStore']:
113114
if file.exists():
@@ -118,21 +119,21 @@ def save(self):
118119
logging.info(f'Saved cred store to {self.path}')
119120
self.path.write_bytes(pickle.dumps(self))
120121

121-
def add(self, key: str, creds: Optional[CRED_TYPES] = None, session: Optional[Credentials] = None,
122-
override: bool = False, type: Optional[Literal['oauth', 'service']] = None):
122+
def add(self, key: str, creds: Optional[CRED_TYPES] = None, override: bool = False,
123+
cred_type: Optional[Literal["oauth", "service"]] = None):
123124
"""Add a credential to the store."""
124125
if key in self.store and not override:
125126
raise ValueError(f'Value exists for: {key}')
126127

127-
cred = Cred(creds=creds, session=session)
128+
cred = Cred(creds=creds)
128129
self.store[key] = cred
129-
logging.info(f'Added {type} cred with key: {key}')
130+
logging.info(f'Added {cred_type} cred with key: {key}')
130131

131132
def remove(self, key: str) -> bool:
132133
"""Remove cred by key and type."""
133134
return self.store.pop(key, None) is not None
134135

135-
def get(self, key: str, validate_type: Optional[Literal['oauth', 'service']] = None,
136+
def get(self, key: str, validate_type: Optional[Literal["oauth", "service"]] = None,
136137
missing_error: bool = True) -> Optional[Cred]:
137138
value = self.store.get(key)
138139
creds = value.creds
@@ -156,23 +157,35 @@ def get_by_client_id(self, client_id: str, validate_type: Optional[Literal['oaut
156157

157158
def list_credentials(self) -> list[str]:
158159
"""Get the list of creds from the store."""
159-
return [f'{k}{f":{v.creds}" if v.creds else ""}' for k, v in self.store.items()]
160+
return [f'{k}:{v.creds.__class__.__name__}:{v.creds.project_id}' for k, v in self.store.items()]
161+
162+
163+
class Config:
164+
"""Config class for handling config and custom_config."""
165+
166+
def __init__(self, path: Path, custom_path: Optional[Path] = DEFAULT_CUSTOM_CONFIG_PATH):
167+
self.path = path
168+
self.custom_path = custom_path
169+
170+
assert path.exists(), f'Config file not found: {path}'
171+
self.config = yaml.safe_load(path.read_text())
172+
173+
self.custom_config = yaml.safe_load(custom_path.read_text()) if custom_path.exists() else {}
174+
175+
@property
176+
def merged(self) -> dict:
177+
"""Safely retrieve a fresh merge of primary and custom configs."""
178+
# I regret nothing
179+
config = copy.deepcopy(self.config)
180+
return deep_merge(config, self.custom_config)
160181

161-
def list_sessions(self) -> list[str]:
162-
"""Get the list of sessions from the store."""
163-
sessions = []
164-
for k, v in self.store.items():
165-
if v.session:
166-
if 'service' in v.session.__module__:
167-
sessions.append(f'{k}:{v.session.__module__}:{v.session.service_account_email}')
168-
else:
169-
sessions.append(f'{k}:{v.session.__module__}:{v.session.client_id}')
170-
return sessions
182+
def save_custom(self):
183+
self.custom_path.write_text(yaml.dump(self.custom_config))
171184

172185

173186
@dataclass
174187
class SWAT:
175188
"""Base object for SWAT."""
176189

177-
config: dict
190+
config: Config
178191
cred_store: CredStore = field(default_factory=lambda: CredStore.from_file() or CredStore())

swat/commands/audit.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def __call__(self,
4141
raise argparse.ArgumentError(self, f'invalid filter argument "{value}", expected "key=value"')
4242
setattr(namespace, self.dest, {key: val})
4343

44+
4445
@dataclass
4546
class Filters:
4647
"""Dataclass representing a set of filters."""
@@ -68,10 +69,13 @@ class Command(BaseCommand):
6869
parser = get_custom_argparse_formatter(prog='audit', description='Google Workspace Audit')
6970
parser.add_argument('application', help='Application name')
7071
parser.add_argument('duration', help='Duration in format Xs, Xm, Xh or Xd.')
71-
parser.add_argument('--columns', nargs='+', help='Columns to keep in the output. If not set, will take columns from config.')
72+
parser.add_argument('--columns', nargs='+',
73+
help='Columns to keep in the output. If not set, will take columns from config.')
7274
parser.add_argument('--export', action='store_true', default=False, help='Path to export the data')
73-
parser.add_argument('--export-format', choices=['csv', 'ndjson'], default='csv', help='Export format. Default is csv.')
74-
parser.add_argument('--filters', nargs='*', action=KeyValueAction, dest='filters', default={}, help='Filters to apply on the data')
75+
parser.add_argument('--export-format', choices=['csv', 'ndjson'], default='csv',
76+
help='Export format. Default is csv.')
77+
parser.add_argument('--filters', nargs='*', action=KeyValueAction, dest='filters', default={},
78+
help='Filters to apply on the data')
7579
parser.add_argument('--interactive', action='store_true', help='Interactive mode')
7680

7781
def __init__(self, **kwargs) -> None:
@@ -89,7 +93,7 @@ def __init__(self, **kwargs) -> None:
8993

9094
# Check if the session exists in the credential store
9195
if self.obj.cred_store.store.get('default') is None:
92-
self.logger.error(f'Please authenticate with "auth session --default --creds" before running this command.')
96+
self.logger.error(f'Please add "default" creds with "creds add default ..." before running this command.')
9397
return
9498

9599
try:
@@ -99,7 +103,8 @@ def __init__(self, **kwargs) -> None:
99103
return
100104

101105
try:
102-
self.service = build('admin', 'reports_v1', credentials=self.obj.cred_store.store['default'].session)
106+
creds = self.obj.cred_store.get('default', validate_type='oauth')
107+
self.service = build('admin', 'reports_v1', credentials=creds.session())
103108
except HttpError as err:
104109
self.logger.error(f'An error occurred: {err}')
105110
return
@@ -114,7 +119,6 @@ def __init__(self, **kwargs) -> None:
114119
self.args.filters = [f.strip('\'"') for f in self.args.filters]
115120
self.filters = Filters(self.args.filters)
116121

117-
118122
def export_data(self, df: pd.DataFrame) -> None:
119123
"""
120124
Exports the dataframe to a specified format.
@@ -135,7 +139,6 @@ def export_data(self, df: pd.DataFrame) -> None:
135139
else:
136140
self.logger.warning(f'Unsupported export format: {self.args.export_format}. No data was exported.')
137141

138-
139142
def flatten_json(self, y: dict) -> dict:
140143
"""
141144
Flattens a nested dictionary and returns a new dictionary with
@@ -187,7 +190,6 @@ def flatten_activities(self, activities: list) -> pd.DataFrame:
187190
flattened_data.append(merged_data)
188191
return pd.DataFrame(flattened_data)
189192

190-
191193
def fetch_data(self) -> pd.DataFrame:
192194
"""
193195
Fetches the activity data from the Google Workspace Audit service, using the provided start time,
@@ -227,9 +229,14 @@ def filter_columns(self, df: pd.DataFrame) -> pd.DataFrame:
227229
Returns:
228230
pd.DataFrame: The filtered dataframe.
229231
"""
230-
columns = self.args.columns or self.obj.config['google']['audit']['columns']
232+
columns = self.args.columns or self.obj.config.merged['google']['audit']['columns']
231233
modified_columns = ['.*' + column + '.*' for column in columns]
232-
df = df[[column for column in df.columns for pattern in modified_columns if re.search(pattern, column, re.IGNORECASE)]]
234+
df = df[
235+
[
236+
column for column in df.columns for pattern in modified_columns
237+
if re.search(pattern, column, re.IGNORECASE)
238+
]
239+
]
233240
return df
234241

235242
def interactive_session(self, df: pd.DataFrame, df_unfiltered: pd.DataFrame) -> None:
@@ -244,7 +251,8 @@ def interactive_session(self, df: pd.DataFrame, df_unfiltered: pd.DataFrame) ->
244251
None
245252
"""
246253
# Ask the user which columns to display
247-
selected_columns_input = input('Enter the columns to display, separated by commas (see logged available columns): ')
254+
selected_columns_input = input('Enter the columns to display, separated by commas '
255+
'(see logged available columns): ')
248256
selected_columns = [column.strip() for column in selected_columns_input.split(',')]
249257

250258
# Keep only the selected columns
@@ -267,7 +275,8 @@ def interactive_session(self, df: pd.DataFrame, df_unfiltered: pd.DataFrame) ->
267275
except ValueError:
268276
self.logger.warning(f'Invalid row number: {row_number}')
269277

270-
def show_results(self, df: pd.DataFrame) -> None:
278+
@staticmethod
279+
def show_results(df: pd.DataFrame) -> None:
271280
"""
272281
Prints the DataFrame to the console in a markdown table format.
273282
@@ -279,7 +288,6 @@ def show_results(self, df: pd.DataFrame) -> None:
279288
"""
280289
print(Fore.GREEN + df.to_markdown(headers='keys', tablefmt='fancy_grid') + Fore.RESET)
281290

282-
283291
def execute(self) -> None:
284292
"""
285293
Main execution method of the Command class.

swat/commands/auth.py

Lines changed: 0 additions & 80 deletions
This file was deleted.

swat/commands/creds.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ def add_creds(self):
4141
else:
4242
creds = OAuthCreds.from_file(self.args.creds) if self.args.creds else None
4343

44-
self.obj.cred_store.add(self.args.key, creds=creds, override=self.args.override)
44+
cred_type = 'service' if self.args.service_account else 'oauth'
45+
self.obj.cred_store.add(self.args.key, creds=creds, override=self.args.override, cred_type=cred_type)
4546
self.logger.info(f'Credentials added with key: {self.args.key}')
4647
except TypeError as e:
4748
self.logger.info(f'Invalid credentials file: {self.args.creds} - {e}')

0 commit comments

Comments
 (0)