Skip to content

[Maintenance] Refactored scopes and auth #66

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,4 @@ logs/*

# etc
swat/etc/*chrome*
swat/etc/custom_config.yaml
79 changes: 46 additions & 33 deletions swat/base.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,25 @@

import copy
import dataclasses
import logging
import pickle
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, Literal, Union

import json
from google.auth.transport.requests import Request
from google.oauth2.service_account import Credentials as ServiceCredentials
from google.oauth2.credentials import Credentials

from .utils import ROOT_DIR, PathlibEncoder
import json
import yaml

from .utils import ROOT_DIR, deep_merge


DEFAULT_CRED_STORE_FILE = ROOT_DIR / 'swat' / 'etc' / '.cred_store.pkl'
DEFAULT_EMULATION_ARTIFACTS_DIR = ROOT_DIR / 'swat' / 'etc' / 'artifacts'
DEFAULT_CUSTOM_CONFIG_PATH = ROOT_DIR / 'swat' / 'etc' / 'custom_config.yaml'


@dataclass
Expand Down Expand Up @@ -73,18 +78,22 @@ def to_dict(self):
class Cred:

creds: Optional[CRED_TYPES]
session: Optional[Credentials]

def session(self, scopes: Optional[list[str]] = None) -> Optional[Credentials]:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to get a session, just call this method, with optional scopes

if isinstance(self.creds, OAuthCreds):
session = Credentials.from_authorized_user_info(self.creds.to_dict(), scopes=scopes)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
session = Credentials.from_authorized_user_info(self.creds.to_dict(), scopes=scopes)
session = Credentials.from_authorized_user_info(self.creds.to_dict()['installed'], scopes=scopes)

else:
session = ServiceCredentials.from_service_account_info(self.creds.to_dict(), scopes=scopes)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
session = ServiceCredentials.from_service_account_info(self.creds.to_dict(), scopes=scopes)
session = ServiceCredentials.from_service_account_info(self.creds.to_dict()['installed'], scopes=scopes)


if session.expired and session.refresh_token:
session.refresh(Request())
return session

@property
def client_id(self) -> Optional[str]:
if self.creds and hasattr(self.creds, 'client_id'):
return self.creds.client_id

def refreshed_session(self) -> Optional[Credentials]:
if self.session and self.session.expired and self.session.refresh_token:
self.session.refresh(Request())
return self.session

def to_dict(self):
return {k: v for k, v in dataclasses.asdict(self).items() if not k.startswith('_')}

Expand All @@ -100,14 +109,6 @@ def __post_init__(self):
if not isinstance(self.path, Path):
self.path = Path(self.path)

@property
def has_sessions(self) -> bool:
"""Return a boolean indicating if the creds have sessions."""
for key, cred in self.store.items():
if cred.session:
return True
return False

@classmethod
def from_file(cls, file: Path = DEFAULT_CRED_STORE_FILE) -> Optional['CredStore']:
if file.exists():
Expand All @@ -118,21 +119,21 @@ def save(self):
logging.info(f'Saved cred store to {self.path}')
self.path.write_bytes(pickle.dumps(self))

def add(self, key: str, creds: Optional[CRED_TYPES] = None, session: Optional[Credentials] = None,
override: bool = False, type: Optional[Literal['oauth', 'service']] = None):
def add(self, key: str, creds: Optional[CRED_TYPES] = None, override: bool = False,
cred_type: Optional[Literal["oauth", "service"]] = None):
"""Add a credential to the store."""
if key in self.store and not override:
raise ValueError(f'Value exists for: {key}')

cred = Cred(creds=creds, session=session)
cred = Cred(creds=creds)
self.store[key] = cred
logging.info(f'Added {type} cred with key: {key}')
logging.info(f'Added {cred_type} cred with key: {key}')

def remove(self, key: str) -> bool:
"""Remove cred by key and type."""
return self.store.pop(key, None) is not None

def get(self, key: str, validate_type: Optional[Literal['oauth', 'service']] = None,
def get(self, key: str, validate_type: Optional[Literal["oauth", "service"]] = None,
missing_error: bool = True) -> Optional[Cred]:
value = self.store.get(key)
creds = value.creds
Expand All @@ -156,23 +157,35 @@ def get_by_client_id(self, client_id: str, validate_type: Optional[Literal['oaut

def list_credentials(self) -> list[str]:
"""Get the list of creds from the store."""
return [f'{k}{f":{v.creds}" if v.creds else ""}' for k, v in self.store.items()]
return [f'{k}:{v.creds.__class__.__name__}:{v.creds.project_id}' for k, v in self.store.items()]


class Config:
"""Config class for handling config and custom_config."""

def __init__(self, path: Path, custom_path: Optional[Path] = DEFAULT_CUSTOM_CONFIG_PATH):
self.path = path
self.custom_path = custom_path

assert path.exists(), f'Config file not found: {path}'
self.config = yaml.safe_load(path.read_text())

self.custom_config = yaml.safe_load(custom_path.read_text()) if custom_path.exists() else {}

@property
def merged(self) -> dict:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

when getting items, use this property

when setting, set directly to custom_config - these are all saved on exit

"""Safely retrieve a fresh merge of primary and custom configs."""
# I regret nothing
config = copy.deepcopy(self.config)
return deep_merge(config, self.custom_config)

def list_sessions(self) -> list[str]:
"""Get the list of sessions from the store."""
sessions = []
for k, v in self.store.items():
if v.session:
if 'service' in v.session.__module__:
sessions.append(f'{k}:{v.session.__module__}:{v.session.service_account_email}')
else:
sessions.append(f'{k}:{v.session.__module__}:{v.session.client_id}')
return sessions
def save_custom(self):
self.custom_path.write_text(yaml.dump(self.custom_config))


@dataclass
class SWAT:
"""Base object for SWAT."""

config: dict
config: Config
cred_store: CredStore = field(default_factory=lambda: CredStore.from_file() or CredStore())
34 changes: 21 additions & 13 deletions swat/commands/audit.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def __call__(self,
raise argparse.ArgumentError(self, f'invalid filter argument "{value}", expected "key=value"')
setattr(namespace, self.dest, {key: val})


Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This command is failing, but I think unrelated?

@dataclass
class Filters:
"""Dataclass representing a set of filters."""
Expand Down Expand Up @@ -68,10 +69,13 @@ class Command(BaseCommand):
parser = get_custom_argparse_formatter(prog='audit', description='Google Workspace Audit')
parser.add_argument('application', help='Application name')
parser.add_argument('duration', help='Duration in format Xs, Xm, Xh or Xd.')
parser.add_argument('--columns', nargs='+', help='Columns to keep in the output. If not set, will take columns from config.')
parser.add_argument('--columns', nargs='+',
help='Columns to keep in the output. If not set, will take columns from config.')
parser.add_argument('--export', action='store_true', default=False, help='Path to export the data')
parser.add_argument('--export-format', choices=['csv', 'ndjson'], default='csv', help='Export format. Default is csv.')
parser.add_argument('--filters', nargs='*', action=KeyValueAction, dest='filters', default={}, help='Filters to apply on the data')
parser.add_argument('--export-format', choices=['csv', 'ndjson'], default='csv',
help='Export format. Default is csv.')
parser.add_argument('--filters', nargs='*', action=KeyValueAction, dest='filters', default={},
help='Filters to apply on the data')
parser.add_argument('--interactive', action='store_true', help='Interactive mode')

def __init__(self, **kwargs) -> None:
Expand All @@ -89,7 +93,7 @@ def __init__(self, **kwargs) -> None:

# Check if the session exists in the credential store
if self.obj.cred_store.store.get('default') is None:
self.logger.error(f'Please authenticate with "auth session --default --creds" before running this command.')
self.logger.error(f'Please add "default" creds with "creds add default ..." before running this command.')
return

try:
Expand All @@ -99,7 +103,8 @@ def __init__(self, **kwargs) -> None:
return

try:
self.service = build('admin', 'reports_v1', credentials=self.obj.cred_store.store['default'].session)
creds = self.obj.cred_store.get('default', validate_type='oauth')
self.service = build('admin', 'reports_v1', credentials=creds.session())
except HttpError as err:
self.logger.error(f'An error occurred: {err}')
return
Expand All @@ -114,7 +119,6 @@ def __init__(self, **kwargs) -> None:
self.args.filters = [f.strip('\'"') for f in self.args.filters]
self.filters = Filters(self.args.filters)


def export_data(self, df: pd.DataFrame) -> None:
"""
Exports the dataframe to a specified format.
Expand All @@ -135,7 +139,6 @@ def export_data(self, df: pd.DataFrame) -> None:
else:
self.logger.warning(f'Unsupported export format: {self.args.export_format}. No data was exported.')


def flatten_json(self, y: dict) -> dict:
"""
Flattens a nested dictionary and returns a new dictionary with
Expand Down Expand Up @@ -187,7 +190,6 @@ def flatten_activities(self, activities: list) -> pd.DataFrame:
flattened_data.append(merged_data)
return pd.DataFrame(flattened_data)


def fetch_data(self) -> pd.DataFrame:
"""
Fetches the activity data from the Google Workspace Audit service, using the provided start time,
Expand Down Expand Up @@ -227,9 +229,14 @@ def filter_columns(self, df: pd.DataFrame) -> pd.DataFrame:
Returns:
pd.DataFrame: The filtered dataframe.
"""
columns = self.args.columns or self.obj.config['google']['audit']['columns']
columns = self.args.columns or self.obj.config.merged['google']['audit']['columns']
modified_columns = ['.*' + column + '.*' for column in columns]
df = df[[column for column in df.columns for pattern in modified_columns if re.search(pattern, column, re.IGNORECASE)]]
df = df[
[
column for column in df.columns for pattern in modified_columns
if re.search(pattern, column, re.IGNORECASE)
]
]
return df

def interactive_session(self, df: pd.DataFrame, df_unfiltered: pd.DataFrame) -> None:
Expand All @@ -244,7 +251,8 @@ def interactive_session(self, df: pd.DataFrame, df_unfiltered: pd.DataFrame) ->
None
"""
# Ask the user which columns to display
selected_columns_input = input('Enter the columns to display, separated by commas (see logged available columns): ')
selected_columns_input = input('Enter the columns to display, separated by commas '
'(see logged available columns): ')
selected_columns = [column.strip() for column in selected_columns_input.split(',')]

# Keep only the selected columns
Expand All @@ -267,7 +275,8 @@ def interactive_session(self, df: pd.DataFrame, df_unfiltered: pd.DataFrame) ->
except ValueError:
self.logger.warning(f'Invalid row number: {row_number}')

def show_results(self, df: pd.DataFrame) -> None:
@staticmethod
def show_results(df: pd.DataFrame) -> None:
"""
Prints the DataFrame to the console in a markdown table format.

Expand All @@ -279,7 +288,6 @@ def show_results(self, df: pd.DataFrame) -> None:
"""
print(Fore.GREEN + df.to_markdown(headers='keys', tablefmt='fancy_grid') + Fore.RESET)


def execute(self) -> None:
"""
Main execution method of the Command class.
Expand Down
80 changes: 0 additions & 80 deletions swat/commands/auth.py

This file was deleted.

3 changes: 2 additions & 1 deletion swat/commands/creds.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ def add_creds(self):
else:
creds = OAuthCreds.from_file(self.args.creds) if self.args.creds else None

self.obj.cred_store.add(self.args.key, creds=creds, override=self.args.override)
cred_type = 'service' if self.args.service_account else 'oauth'
self.obj.cred_store.add(self.args.key, creds=creds, override=self.args.override, cred_type=cred_type)
self.logger.info(f'Credentials added with key: {self.args.key}')
except TypeError as e:
self.logger.info(f'Invalid credentials file: {self.args.creds} - {e}')
Expand Down
Loading