sources/sharepoint/__init__.py (new file, +114 lines)
@@ -0,0 +1,114 @@
from typing import Iterator, Dict

import dlt
from dlt.common.typing import TDataItems
from dlt.common.configuration.specs import configspec, BaseConfiguration
from dlt.common import logger
from dlt.sources import DltResource

from .helpers import SharepointClient
from .sharepoint_files_config import SharepointFilesConfig, SharepointListConfig


@configspec
class SharepointCredentials(BaseConfiguration):
client_id: str = None
tenant_id: str = None
site_id: str = None
client_secret: str = None
sub_site_id: str = ""


@dlt.source(name="sharepoint_list", max_table_nesting=0)
def sharepoint_list(
sharepoint_list_config: SharepointListConfig,
credentials: SharepointCredentials = dlt.secrets.value,
) -> Iterator[DltResource]:
client: SharepointClient = SharepointClient(**credentials)
client.connect()
logger.info(f"Connected to SharePoint site: {client.site_info}")

def get_pipe(sharepoint_list_config: SharepointListConfig):
def get_records(sharepoint_list_config: SharepointListConfig):
data = client.get_items_from_list(
list_title=sharepoint_list_config.list_title,
select=sharepoint_list_config.select,
)
yield from data

return dlt.resource(get_records, name=sharepoint_list_config.table_name)(
sharepoint_list_config
)

yield get_pipe(sharepoint_list_config=sharepoint_list_config)


@dlt.source(name="sharepoint_files", max_table_nesting=0)
def sharepoint_files(
sharepoint_files_config: SharepointFilesConfig,
credentials: SharepointCredentials = dlt.secrets.value,
) -> Iterator[DltResource]:
client: SharepointClient = SharepointClient(**credentials)
client.connect()
logger.info(f"Connected to SharePoint site: {client.site_info}")

def get_files(
sharepoint_files_config: SharepointFilesConfig,
last_update_timestamp: dlt.sources.incremental = dlt.sources.incremental(
cursor_path="lastModifiedDateTime",
initial_value="2020-01-01T00:00:00Z",
primary_key=(),
),
):
current_last_value = last_update_timestamp.last_value
logger.debug(f"current_last_value: {current_last_value}")
for file_item in client.get_files_from_path(
folder_path=sharepoint_files_config.folder_path,
file_name_startswith=sharepoint_files_config.file_name_startswith,
pattern=sharepoint_files_config.pattern,
):
logger.debug(
"filtering files based on lastModifiedDateTime, compare to last_value:"
f" {current_last_value}"
)
if (
file_item["lastModifiedDateTime"] > current_last_value
or not sharepoint_files_config.is_file_incremental
):
logger.info(
f"Processing file after lastModifiedDateTime filter: {file_item['name']}"
)

file_item["pd_function"] = (
sharepoint_files_config.file_type.get_pd_function()
)
file_item["pd_kwargs"] = sharepoint_files_config.pandas_kwargs
yield file_item
else:
logger.info(
f"Skipping file {file_item['name']} based on lastModifiedDateTime filter"
)

def get_records(file_item: Dict) -> TDataItems:
chunksize = file_item["pd_kwargs"].get("chunksize", None)
file_io = client.get_file_bytes_io(file_item=file_item)

if chunksize:
with file_item["pd_function"](file_io, **file_item["pd_kwargs"]) as reader:
for num, chunk in enumerate(reader):
logger.info(f"Processing chunk {num} of {file_item['name']}")
yield chunk
else:
df = file_item["pd_function"](file_io, **file_item["pd_kwargs"])
yield df
logger.debug(f"get_records done for {file_item['name']}")

def get_pipe(sharepoint_files_config: SharepointFilesConfig):
return dlt.resource(
get_files, name=f"{sharepoint_files_config.table_name}_files"
)(sharepoint_files_config) | dlt.transformer(
get_records, name=sharepoint_files_config.table_name, parallelized=False
)

yield get_pipe(sharepoint_files_config=sharepoint_files_config)
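
For reference, a minimal end-to-end run of these two sources could look like the sketch below. It assumes the credentials (client_id, tenant_id, site_id, client_secret) are resolvable by dlt.secrets.value from secrets.toml or environment variables, that the script runs from the sources directory so the package imports as "sharepoint", and that the target folder holds CSV files; the pipeline name, folder path, list title and table names are placeholders.

# minimal sketch, assuming credentials are resolved from secrets.toml / env vars
import dlt

from sharepoint import sharepoint_files, sharepoint_list
from sharepoint.sharepoint_files_config import (
    FileType,
    SharepointFilesConfig,
    SharepointListConfig,
)

files_config = SharepointFilesConfig(
    file_type=FileType.CSV,
    folder_path="Shared Documents/reports",  # placeholder folder
    table_name="reports",
    file_name_startswith="report_",
    pandas_kwargs={"sep": ","},
    is_file_incremental=True,
)
list_config = SharepointListConfig(
    table_name="tasks",
    list_title="Tasks",  # placeholder list title
)

pipeline = dlt.pipeline(
    pipeline_name="sharepoint_demo",  # placeholder pipeline name
    destination="duckdb",
    dataset_name="sharepoint_data",
)
print(pipeline.run(sharepoint_files(sharepoint_files_config=files_config)))
print(pipeline.run(sharepoint_list(sharepoint_list_config=list_config)))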
sources/sharepoint/helpers.py (new file, +178 lines)
@@ -0,0 +1,178 @@
from typing import Dict, List, Optional
from io import BytesIO
import re

from msal import ConfidentialClientApplication
from dlt.common import logger
from dlt.sources.helpers.rest_client import RESTClient
from dlt.sources.helpers.rest_client.auth import BearerTokenAuth
from dlt.sources.helpers.rest_client.paginators import JSONLinkPaginator


class SharepointClient:
# * playground: https://developer.microsoft.com/en-us/graph/graph-explorer
# * If the result contains more results, Microsoft Graph returns an @odata.nextLink property

def __init__(
self,
client_id: str,
tenant_id: str,
site_id: str,
client_secret: str,
sub_site_id: str = "",
) -> None:
self.client_id = client_id
self.tenant_id = tenant_id
self.client_secret = client_secret
self.sub_site_id = sub_site_id
self.site_id = site_id
        if not all([self.client_id, self.tenant_id, self.client_secret, self.site_id]):
            raise ValueError(
                "client_id, tenant_id, client_secret and site_id are required to connect to"
                " SharePoint"
            )
self.graph_api_url = "https://graph.microsoft.com/v1.0/sites"
self.graph_site_url = f"{self.graph_api_url}/{self.site_id}"
if self.sub_site_id:
self.graph_site_url += f"/sites/{self.sub_site_id}"

def connect(self) -> None:
authority = f"https://login.microsoftonline.com/{self.tenant_id}"
scope = ["https://graph.microsoft.com/.default"]

app = ConfidentialClientApplication(
self.client_id,
authority=authority,
client_credential=self.client_secret,
)

# Get the access token
token_response = app.acquire_token_for_client(scopes=scope)
access_token = token_response.get("access_token", None)

if access_token:
self.client = RESTClient(
base_url=self.graph_site_url,
auth=BearerTokenAuth(access_token),
paginator=JSONLinkPaginator(next_url_path="@odata.nextLink"),
)
logger.info(f"Connected to SharePoint site id: {self.site_id} successfully")
else:
            raise ConnectionError(f"Connection failed: {token_response}")

    @property
    def sub_sites(self) -> List:
        url = f"{self.graph_site_url}/sites"
        response = self.client.get(url)
        site_info = response.json()
        if "value" in site_info:
            return site_info["value"]
        logger.warning(f"No subsites found in {url}")
        return []

    @property
    def site_info(self) -> Dict:
        response = self.client.get(self.graph_site_url)
        site_info = response.json()
        if "error" not in site_info:
            return site_info
        logger.warning(f"No site_info found in {self.graph_site_url}")
        return {}

    def get_all_lists_in_site(self) -> List[Dict]:
        url = f"{self.graph_site_url}/lists"
        res = self.client.get(url)
        res.raise_for_status()
        lists_info = res.json()
        all_items = lists_info.get("value", [])
        # keep only generic (custom) lists, excluding libraries and system lists
        filtered_lists = [
            item
            for item in all_items
            if item.get("list", {}).get("template") == "genericList"
            and "Lists" in item.get("webUrl", "")
        ]
        if not filtered_lists:
            logger.warning(f"No lists found in {url}")
        return filtered_lists

    def get_items_from_list(
        self, list_title: str, select: Optional[str] = None
    ) -> List[Dict]:
        # TODO: pagination not yet implemented
        logger.warning(
            "Pagination is not implemented for get_items_from_list, "
            "only the first page of items will be returned."
        )
        # get_all_lists_in_site already restricts results to generic lists
        filtered_lists = self.get_all_lists_in_site()

possible_list_titles = [x["displayName"] for x in filtered_lists]
if list_title not in possible_list_titles:
raise ValueError(
f"List with title '{list_title}' not found in site {self.site_id}. "
f"Available lists: {possible_list_titles}"
)

# Get the list ID
list_id = next(
x["id"] for x in filtered_lists if x["displayName"] == list_title
)

url = f"{self.graph_site_url}/lists/{list_id}/items?expand=fields"
if select:
url += f"(select={select})"
res = self.client.get(url)
res.raise_for_status()
items_info = res.json()

if "value" in items_info:
output = [x.get("fields", {}) for x in items_info["value"]]
else:
output = []
        if output:
            logger.info(f"Got {len(output)} items from list: {list_title}")
        else:
            logger.warning(
                f"No items found in list: {list_title}, with select: {select}"
            )
        return output

    def get_files_from_path(
        self, folder_path: str, file_name_startswith: str, pattern: Optional[str] = None
    ) -> List[Dict]:
folder_url = (
f"{self.graph_site_url}/drive/root:/{folder_path}:/children?$filter=startswith(name,"
f" '{file_name_startswith}')"
)
logger.debug(f"Getting files from folder with endpoint: {folder_url}")
res = self.client.get(folder_url)
file_and_folder_items = res.json().get("value", [])
file_items = [x for x in file_and_folder_items if "file" in x.keys()]
if pattern:
logger.debug(f"Filtering files with pattern: {pattern}")
file_items = [x for x in file_items if re.search(pattern, x["name"])]

logger.debug(f"Got number files from ms graph api: {len(file_items)}")
return file_items

    def get_file_bytes_io(self, file_item: Dict) -> BytesIO:
file_url = file_item["@microsoft.graph.downloadUrl"]
response = self.client.get(file_url)
if response.status_code == 200:
bytes_io = BytesIO(response.content)
logger.info(
f"File {file_item['name']} downloaded to BytesIO, size: {len(bytes_io.getvalue())}"
)
return bytes_io
else:
raise FileNotFoundError(
f"File not found: {file_item['name']} or can't be downloaded"
)
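
On its own, SharepointClient can be used to explore a site before wiring up a pipeline. A minimal sketch, assuming placeholder credential values, the same import-path assumption as in the previous sketch, and a folder in the default document library:

# minimal sketch of using SharepointClient directly, with placeholder credentials;
# connect() must be called before any Graph request
from sharepoint.helpers import SharepointClient

client = SharepointClient(
    client_id="<app-client-id>",
    tenant_id="<tenant-id>",
    site_id="<site-id>",
    client_secret="<client-secret>",
)
client.connect()

# enumerate the generic lists exposed by the site
for sp_list in client.get_all_lists_in_site():
    print(sp_list["displayName"])

# list CSV files under a folder of the default document library
for file_item in client.get_files_from_path(
    folder_path="Shared Documents/reports",  # placeholder folder
    file_name_startswith="report_",
    pattern=r"\.csv$",
):
    print(file_item["name"], file_item["lastModifiedDateTime"])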
sources/sharepoint/sharepoint_files_config.py (new file, +70 lines)
@@ -0,0 +1,70 @@
from typing import Optional, Dict
import re
from enum import Enum

import pandas as pd
from pydantic import BaseModel


class FileType(Enum):
    EXCEL = "excel"
    CSV = "csv"
    JSON = "json"
    PARQUET = "parquet"
    SAS = "sas"
    SPSS = "spss"
    SAV = "sav"

    def get_pd_function(self):
        # map each file type to the pandas reader used to parse it
        return {
            self.EXCEL: pd.read_excel,
            self.CSV: pd.read_csv,
            self.JSON: pd.read_json,
            self.PARQUET: pd.read_parquet,
            self.SAS: pd.read_sas,
            self.SPSS: pd.read_spss,
            self.SAV: pd.read_spss,  # .sav files are also read with pandas.read_spss
        }[self]


class SharepointListConfig(BaseModel):
table_name: str
list_title: str
select: Optional[str] = None
is_incremental: Optional[bool] = False

def __init__(self, **data):
super().__init__(**data)
if self.is_incremental is True:
raise NotImplementedError(
"Incremental loading for Sharepoint List is not implemented yet."
)


class SharepointFilesConfig(BaseModel):
file_type: FileType
folder_path: str
table_name: str
file_name_startswith: str
pattern: Optional[str] = ".*"
pandas_kwargs: Dict = {}
is_file_incremental: bool = False

def __init__(self, **data):
super().__init__(**data)
self.folder_path = validate_folder_path(self.folder_path)
self.pattern = f"^{self.file_name_startswith}{self.pattern}"


def validate_folder_path(folder_path: str) -> str:
if folder_path.startswith("/"):
folder_path = folder_path[1:]
if folder_path.endswith("/"):
folder_path = folder_path[:-1]
    if not re.compile(r"^[a-zA-Z0-9_\-/\s\.]*$").match(folder_path):
        raise ValueError(
            "Invalid folder path, only alphanumeric characters, dashes, underscores, dots,"
            f" spaces and slashes are allowed: {folder_path}"
        )
if re.compile(r"//").search(folder_path):
raise ValueError(f"Invalid folder path with double slashes: {folder_path}")
return folder_path
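
A short illustration of how the config models behave at construction time; the folder path, table and list names are placeholders, and imports again assume the script runs from the sources directory:

# illustration of the validation behaviour of the config models
from sharepoint.sharepoint_files_config import (
    FileType,
    SharepointFilesConfig,
    SharepointListConfig,
)

cfg = SharepointFilesConfig(
    file_type=FileType.EXCEL,
    folder_path="/Shared Documents/finance/",  # leading/trailing slashes are stripped
    table_name="finance_reports",
    file_name_startswith="fin_",
    pattern=r".*\.xlsx$",
    pandas_kwargs={"sheet_name": 0},
)
print(cfg.folder_path)  # "Shared Documents/finance"
print(cfg.pattern)      # "^fin_.*\.xlsx$" - startswith prefix + user pattern

# incremental loading of SharePoint lists is explicitly not supported yet,
# so the line below raises NotImplementedError
SharepointListConfig(table_name="tasks", list_title="Tasks", is_incremental=True)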