|
1 | 1 | import re
|
| 2 | +import logging |
2 | 3 |
|
3 | 4 | from .base import ResponseMicroService
|
| 5 | +from ..context import Context |
| 6 | +from ..exception import SATOSAError |
4 | 7 |
|
| 8 | +logger = logging.getLogger(__name__) |
5 | 9 |
|
6 | 10 | class AddStaticAttributes(ResponseMicroService):
|
7 | 11 | """
|
@@ -29,28 +33,62 @@ def __init__(self, config, *args, **kwargs):
|
29 | 33 | def process(self, context, data):
|
30 | 34 | # apply default filters
|
31 | 35 | provider_filters = self.attribute_filters.get("", {})
|
32 |
| - self._apply_requester_filters(data.attributes, provider_filters, data.requester) |
| 36 | + target_provider = data.auth_info.issuer |
| 37 | + self._apply_requester_filters(data.attributes, provider_filters, data.requester, context, target_provider) |
33 | 38 |
|
34 | 39 | # apply target provider specific filters
|
35 |
| - target_provider = data.auth_info.issuer |
36 | 40 | provider_filters = self.attribute_filters.get(target_provider, {})
|
37 |
| - self._apply_requester_filters(data.attributes, provider_filters, data.requester) |
| 41 | + self._apply_requester_filters(data.attributes, provider_filters, data.requester, context, target_provider) |
38 | 42 | return super().process(context, data)
|
39 | 43 |
|
40 |
| - def _apply_requester_filters(self, attributes, provider_filters, requester): |
| 44 | + def _apply_requester_filters(self, attributes, provider_filters, requester, context, target_provider): |
41 | 45 | # apply default requester filters
|
42 | 46 | default_requester_filters = provider_filters.get("", {})
|
43 |
| - self._apply_filter(attributes, default_requester_filters) |
| 47 | + self._apply_filters(attributes, default_requester_filters, context, target_provider) |
44 | 48 |
|
45 | 49 | # apply requester specific filters
|
46 | 50 | requester_filters = provider_filters.get(requester, {})
|
47 |
| - self._apply_filter(attributes, requester_filters) |
48 |
| - |
49 |
| - def _apply_filter(self, attributes, attribute_filters): |
50 |
| - for attribute_name, attribute_filter in attribute_filters.items(): |
51 |
| - regex = re.compile(attribute_filter) |
52 |
| - if attribute_name == "": # default filter for all attributes |
53 |
| - for attribute, values in attributes.items(): |
54 |
| - attributes[attribute] = list(filter(regex.search, attributes[attribute])) |
55 |
| - elif attribute_name in attributes: |
56 |
| - attributes[attribute_name] = list(filter(regex.search, attributes[attribute_name])) |
| 51 | + self._apply_filters(attributes, requester_filters, context, target_provider) |
| 52 | + |
| 53 | + def _apply_filters(self, attributes, attribute_filters, context, target_provider): |
| 54 | + for attribute_name, attribute_filters in attribute_filters.items(): |
| 55 | + if type(attribute_filters) == str: |
| 56 | + # convert simple notation to filter list |
| 57 | + attribute_filters = {'regexp': attribute_filters} |
| 58 | + |
| 59 | + for filter_type, filter_value in attribute_filters.items(): |
| 60 | + |
| 61 | + if filter_type == "regexp": |
| 62 | + filter_func = re.compile(filter_value).search |
| 63 | + elif filter_type == "shibmdscope_match_scope": |
| 64 | + mdstore = context.get_decoration(Context.KEY_METADATA_STORE) |
| 65 | + md_scopes = list(mdstore.shibmd_scopes(target_provider,"idpsso_descriptor")) if mdstore else [] |
| 66 | + filter_func = lambda v: self._shibmdscope_match_scope(v, md_scopes) |
| 67 | + elif filter_type == "shibmdscope_match_value": |
| 68 | + mdstore = context.get_decoration(Context.KEY_METADATA_STORE) |
| 69 | + md_scopes = list(mdstore.shibmd_scopes(target_provider,"idpsso_descriptor")) if mdstore else [] |
| 70 | + filter_func = lambda v: self._shibmdscope_match_value(v, md_scopes) |
| 71 | + else: |
| 72 | + raise SATOSAError("Unknown filter type") |
| 73 | + |
| 74 | + if attribute_name == "": # default filter for all attributes |
| 75 | + for attribute, values in attributes.items(): |
| 76 | + attributes[attribute] = list(filter(filter_func, attributes[attribute])) |
| 77 | + elif attribute_name in attributes: |
| 78 | + attributes[attribute_name] = list(filter(filter_func, attributes[attribute_name])) |
| 79 | + |
| 80 | + def _shibmdscope_match_value(self, value, md_scopes): |
| 81 | + for md_scope in md_scopes: |
| 82 | + if not md_scope['regexp'] and md_scope['text'] == value: |
| 83 | + return True |
| 84 | + elif md_scope['regexp'] and re.fullmatch(md_scope['text'], value): |
| 85 | + return True |
| 86 | + return False |
| 87 | + |
| 88 | + def _shibmdscope_match_scope(self, value, md_scopes): |
| 89 | + split_value = value.split('@') |
| 90 | + if len(split_value) != 2: |
| 91 | + logger.info(f"Discarding invalid scoped value {value}") |
| 92 | + return False |
| 93 | + value_scope = split_value[1] |
| 94 | + return self._shibmdscope_match_value(value_scope, md_scopes) |
0 commit comments