-
Notifications
You must be signed in to change notification settings - Fork 6
[Maintenance] Refactored scopes and auth #66
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -87,3 +87,4 @@ logs/* | |
|
||
# etc | ||
swat/etc/*chrome* | ||
swat/etc/custom_config.yaml |
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -1,20 +1,25 @@ | ||||||
|
||||||
import copy | ||||||
import dataclasses | ||||||
import logging | ||||||
import pickle | ||||||
from dataclasses import dataclass, field | ||||||
from pathlib import Path | ||||||
from typing import Optional, Literal, Union | ||||||
|
||||||
import json | ||||||
from google.auth.transport.requests import Request | ||||||
from google.oauth2.service_account import Credentials as ServiceCredentials | ||||||
from google.oauth2.credentials import Credentials | ||||||
|
||||||
from .utils import ROOT_DIR, PathlibEncoder | ||||||
import json | ||||||
import yaml | ||||||
|
||||||
from .utils import ROOT_DIR, deep_merge | ||||||
|
||||||
|
||||||
DEFAULT_CRED_STORE_FILE = ROOT_DIR / 'swat' / 'etc' / '.cred_store.pkl' | ||||||
DEFAULT_EMULATION_ARTIFACTS_DIR = ROOT_DIR / 'swat' / 'etc' / 'artifacts' | ||||||
DEFAULT_CUSTOM_CONFIG_PATH = ROOT_DIR / 'swat' / 'etc' / 'custom_config.yaml' | ||||||
|
||||||
|
||||||
@dataclass | ||||||
|
@@ -73,18 +78,22 @@ def to_dict(self): | |||||
class Cred: | ||||||
|
||||||
creds: Optional[CRED_TYPES] | ||||||
session: Optional[Credentials] | ||||||
|
||||||
def session(self, scopes: Optional[list[str]] = None) -> Optional[Credentials]: | ||||||
if isinstance(self.creds, OAuthCreds): | ||||||
session = Credentials.from_authorized_user_info(self.creds.to_dict(), scopes=scopes) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
else: | ||||||
session = ServiceCredentials.from_service_account_info(self.creds.to_dict(), scopes=scopes) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
||||||
if session.expired and session.refresh_token: | ||||||
session.refresh(Request()) | ||||||
return session | ||||||
|
||||||
@property | ||||||
def client_id(self) -> Optional[str]: | ||||||
if self.creds and hasattr(self.creds, 'client_id'): | ||||||
return self.creds.client_id | ||||||
|
||||||
def refreshed_session(self) -> Optional[Credentials]: | ||||||
if self.session and self.session.expired and self.session.refresh_token: | ||||||
self.session.refresh(Request()) | ||||||
return self.session | ||||||
|
||||||
def to_dict(self): | ||||||
return {k: v for k, v in dataclasses.asdict(self).items() if not k.startswith('_')} | ||||||
|
||||||
|
@@ -100,14 +109,6 @@ def __post_init__(self): | |||||
if not isinstance(self.path, Path): | ||||||
self.path = Path(self.path) | ||||||
|
||||||
@property | ||||||
def has_sessions(self) -> bool: | ||||||
"""Return a boolean indicating if the creds have sessions.""" | ||||||
for key, cred in self.store.items(): | ||||||
if cred.session: | ||||||
return True | ||||||
return False | ||||||
|
||||||
@classmethod | ||||||
def from_file(cls, file: Path = DEFAULT_CRED_STORE_FILE) -> Optional['CredStore']: | ||||||
if file.exists(): | ||||||
|
@@ -118,21 +119,21 @@ def save(self): | |||||
logging.info(f'Saved cred store to {self.path}') | ||||||
self.path.write_bytes(pickle.dumps(self)) | ||||||
|
||||||
def add(self, key: str, creds: Optional[CRED_TYPES] = None, session: Optional[Credentials] = None, | ||||||
override: bool = False, type: Optional[Literal['oauth', 'service']] = None): | ||||||
def add(self, key: str, creds: Optional[CRED_TYPES] = None, override: bool = False, | ||||||
cred_type: Optional[Literal["oauth", "service"]] = None): | ||||||
"""Add a credential to the store.""" | ||||||
if key in self.store and not override: | ||||||
raise ValueError(f'Value exists for: {key}') | ||||||
|
||||||
cred = Cred(creds=creds, session=session) | ||||||
cred = Cred(creds=creds) | ||||||
self.store[key] = cred | ||||||
logging.info(f'Added {type} cred with key: {key}') | ||||||
logging.info(f'Added {cred_type} cred with key: {key}') | ||||||
|
||||||
def remove(self, key: str) -> bool: | ||||||
"""Remove cred by key and type.""" | ||||||
return self.store.pop(key, None) is not None | ||||||
|
||||||
def get(self, key: str, validate_type: Optional[Literal['oauth', 'service']] = None, | ||||||
def get(self, key: str, validate_type: Optional[Literal["oauth", "service"]] = None, | ||||||
missing_error: bool = True) -> Optional[Cred]: | ||||||
value = self.store.get(key) | ||||||
creds = value.creds | ||||||
|
@@ -156,23 +157,35 @@ def get_by_client_id(self, client_id: str, validate_type: Optional[Literal['oaut | |||||
|
||||||
def list_credentials(self) -> list[str]: | ||||||
"""Get the list of creds from the store.""" | ||||||
return [f'{k}{f":{v.creds}" if v.creds else ""}' for k, v in self.store.items()] | ||||||
return [f'{k}:{v.creds.__class__.__name__}:{v.creds.project_id}' for k, v in self.store.items()] | ||||||
|
||||||
|
||||||
class Config: | ||||||
"""Config class for handling config and custom_config.""" | ||||||
|
||||||
def __init__(self, path: Path, custom_path: Optional[Path] = DEFAULT_CUSTOM_CONFIG_PATH): | ||||||
self.path = path | ||||||
self.custom_path = custom_path | ||||||
|
||||||
assert path.exists(), f'Config file not found: {path}' | ||||||
self.config = yaml.safe_load(path.read_text()) | ||||||
|
||||||
self.custom_config = yaml.safe_load(custom_path.read_text()) if custom_path.exists() else {} | ||||||
|
||||||
@property | ||||||
def merged(self) -> dict: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. when getting items, use this property when setting, set directly to |
||||||
"""Safely retrieve a fresh merge of primary and custom configs.""" | ||||||
# I regret nothing | ||||||
config = copy.deepcopy(self.config) | ||||||
return deep_merge(config, self.custom_config) | ||||||
|
||||||
def list_sessions(self) -> list[str]: | ||||||
"""Get the list of sessions from the store.""" | ||||||
sessions = [] | ||||||
for k, v in self.store.items(): | ||||||
if v.session: | ||||||
if 'service' in v.session.__module__: | ||||||
sessions.append(f'{k}:{v.session.__module__}:{v.session.service_account_email}') | ||||||
else: | ||||||
sessions.append(f'{k}:{v.session.__module__}:{v.session.client_id}') | ||||||
return sessions | ||||||
def save_custom(self): | ||||||
self.custom_path.write_text(yaml.dump(self.custom_config)) | ||||||
|
||||||
|
||||||
@dataclass | ||||||
class SWAT: | ||||||
"""Base object for SWAT.""" | ||||||
|
||||||
config: dict | ||||||
config: Config | ||||||
cred_store: CredStore = field(default_factory=lambda: CredStore.from_file() or CredStore()) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -41,6 +41,7 @@ def __call__(self, | |
raise argparse.ArgumentError(self, f'invalid filter argument "{value}", expected "key=value"') | ||
setattr(namespace, self.dest, {key: val}) | ||
|
||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This command is failing, but I think unrelated? |
||
@dataclass | ||
class Filters: | ||
"""Dataclass representing a set of filters.""" | ||
|
@@ -68,10 +69,13 @@ class Command(BaseCommand): | |
parser = get_custom_argparse_formatter(prog='audit', description='Google Workspace Audit') | ||
parser.add_argument('application', help='Application name') | ||
parser.add_argument('duration', help='Duration in format Xs, Xm, Xh or Xd.') | ||
parser.add_argument('--columns', nargs='+', help='Columns to keep in the output. If not set, will take columns from config.') | ||
parser.add_argument('--columns', nargs='+', | ||
help='Columns to keep in the output. If not set, will take columns from config.') | ||
parser.add_argument('--export', action='store_true', default=False, help='Path to export the data') | ||
parser.add_argument('--export-format', choices=['csv', 'ndjson'], default='csv', help='Export format. Default is csv.') | ||
parser.add_argument('--filters', nargs='*', action=KeyValueAction, dest='filters', default={}, help='Filters to apply on the data') | ||
parser.add_argument('--export-format', choices=['csv', 'ndjson'], default='csv', | ||
help='Export format. Default is csv.') | ||
parser.add_argument('--filters', nargs='*', action=KeyValueAction, dest='filters', default={}, | ||
help='Filters to apply on the data') | ||
parser.add_argument('--interactive', action='store_true', help='Interactive mode') | ||
|
||
def __init__(self, **kwargs) -> None: | ||
|
@@ -89,7 +93,7 @@ def __init__(self, **kwargs) -> None: | |
|
||
# Check if the session exists in the credential store | ||
if self.obj.cred_store.store.get('default') is None: | ||
self.logger.error(f'Please authenticate with "auth session --default --creds" before running this command.') | ||
self.logger.error(f'Please add "default" creds with "creds add default ..." before running this command.') | ||
return | ||
|
||
try: | ||
|
@@ -99,7 +103,8 @@ def __init__(self, **kwargs) -> None: | |
return | ||
|
||
try: | ||
self.service = build('admin', 'reports_v1', credentials=self.obj.cred_store.store['default'].session) | ||
creds = self.obj.cred_store.get('default', validate_type='oauth') | ||
self.service = build('admin', 'reports_v1', credentials=creds.session()) | ||
except HttpError as err: | ||
self.logger.error(f'An error occurred: {err}') | ||
return | ||
|
@@ -114,7 +119,6 @@ def __init__(self, **kwargs) -> None: | |
self.args.filters = [f.strip('\'"') for f in self.args.filters] | ||
self.filters = Filters(self.args.filters) | ||
|
||
|
||
def export_data(self, df: pd.DataFrame) -> None: | ||
""" | ||
Exports the dataframe to a specified format. | ||
|
@@ -135,7 +139,6 @@ def export_data(self, df: pd.DataFrame) -> None: | |
else: | ||
self.logger.warning(f'Unsupported export format: {self.args.export_format}. No data was exported.') | ||
|
||
|
||
def flatten_json(self, y: dict) -> dict: | ||
""" | ||
Flattens a nested dictionary and returns a new dictionary with | ||
|
@@ -187,7 +190,6 @@ def flatten_activities(self, activities: list) -> pd.DataFrame: | |
flattened_data.append(merged_data) | ||
return pd.DataFrame(flattened_data) | ||
|
||
|
||
def fetch_data(self) -> pd.DataFrame: | ||
""" | ||
Fetches the activity data from the Google Workspace Audit service, using the provided start time, | ||
|
@@ -227,9 +229,14 @@ def filter_columns(self, df: pd.DataFrame) -> pd.DataFrame: | |
Returns: | ||
pd.DataFrame: The filtered dataframe. | ||
""" | ||
columns = self.args.columns or self.obj.config['google']['audit']['columns'] | ||
columns = self.args.columns or self.obj.config.merged['google']['audit']['columns'] | ||
modified_columns = ['.*' + column + '.*' for column in columns] | ||
df = df[[column for column in df.columns for pattern in modified_columns if re.search(pattern, column, re.IGNORECASE)]] | ||
df = df[ | ||
[ | ||
column for column in df.columns for pattern in modified_columns | ||
if re.search(pattern, column, re.IGNORECASE) | ||
] | ||
] | ||
return df | ||
|
||
def interactive_session(self, df: pd.DataFrame, df_unfiltered: pd.DataFrame) -> None: | ||
|
@@ -244,7 +251,8 @@ def interactive_session(self, df: pd.DataFrame, df_unfiltered: pd.DataFrame) -> | |
None | ||
""" | ||
# Ask the user which columns to display | ||
selected_columns_input = input('Enter the columns to display, separated by commas (see logged available columns): ') | ||
selected_columns_input = input('Enter the columns to display, separated by commas ' | ||
'(see logged available columns): ') | ||
selected_columns = [column.strip() for column in selected_columns_input.split(',')] | ||
|
||
# Keep only the selected columns | ||
|
@@ -267,7 +275,8 @@ def interactive_session(self, df: pd.DataFrame, df_unfiltered: pd.DataFrame) -> | |
except ValueError: | ||
self.logger.warning(f'Invalid row number: {row_number}') | ||
|
||
def show_results(self, df: pd.DataFrame) -> None: | ||
@staticmethod | ||
def show_results(df: pd.DataFrame) -> None: | ||
""" | ||
Prints the DataFrame to the console in a markdown table format. | ||
|
||
|
@@ -279,7 +288,6 @@ def show_results(self, df: pd.DataFrame) -> None: | |
""" | ||
print(Fore.GREEN + df.to_markdown(headers='keys', tablefmt='fancy_grid') + Fore.RESET) | ||
|
||
|
||
def execute(self) -> None: | ||
""" | ||
Main execution method of the Command class. | ||
|
This file was deleted.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
to get a session, just call this method, with optional scopes