diff --git a/Pilot/dirac-pilot.py b/Pilot/dirac-pilot.py index 9c434c97..a43de292 100644 --- a/Pilot/dirac-pilot.py +++ b/Pilot/dirac-pilot.py @@ -49,6 +49,12 @@ getCommand, pythonPathCheck, ) + +try: + from Pilot.proxyTools import revokePilotToken +except ImportError: + from proxyTools import revokePilotToken + ############################ if __name__ == "__main__": @@ -64,7 +70,7 @@ sys.stdout.write(bufContent) # now the remote logger. remote = pilotParams.pilotLogging and (pilotParams.loggerURL is not None) - if remote: + if remote and pilotParams.jwt: # In a remote logger enabled Dirac version we would have some classic logger content from a wrapper, # which we passed in: receivedContent = "" @@ -76,12 +82,18 @@ bufsize=pilotParams.loggerBufsize, pilotUUID=pilotParams.pilotUUID, debugFlag=pilotParams.debugFlag, - wnVO=pilotParams.wnVO, + jwt=pilotParams.jwt ) log.info("Remote logger activated") - log.buffer.write(receivedContent) + log.buffer.write(log.format_to_json( + "INFO", + receivedContent, + )) log.buffer.flush() - log.buffer.write(bufContent) + log.buffer.write(log.format_to_json( + "INFO", + bufContent, + )) else: log = Logger("Pilot", debugFlag=pilotParams.debugFlag) @@ -104,7 +116,7 @@ log.info("Executing commands: %s" % str(pilotParams.commands)) - if remote: + if remote and pilotParams.jwt: # It's safer to cancel the timer here. Each command has got its own logger object with a timer cancelled by the # finaliser. No need for a timer in the "else" code segment below. try: @@ -122,5 +134,22 @@ log.error("Command %s could not be instantiated" % commandName) # send the last message and abandon ship. if remote: - log.buffer.flush() + log.buffer.flush(force=True) sys.exit(-1) + + log.info("Pilot tasks finished.") + + if not remote: + log.buffer.flush() + + if pilotParams.jwt: + if remote: + log.buffer.flush(force=True) + + log.info("Revoking pilot token.") + revokePilotToken( + pilotParams.diracXServer, + pilotParams.pilotUUID, + pilotParams.jwt, + pilotParams.clientID + ) diff --git a/Pilot/pilotCommands.py b/Pilot/pilotCommands.py index 945a6b78..20ccd014 100644 --- a/Pilot/pilotCommands.py +++ b/Pilot/pilotCommands.py @@ -28,7 +28,6 @@ def __init__(self, pilotParams): import sys import time import traceback -import subprocess from collections import Counter ############################ @@ -44,7 +43,7 @@ def __init__(self, pilotParams): from shlex import quote except ImportError: from pipes import quote - + try: from Pilot.pilotTools import ( CommandBase, @@ -61,6 +60,16 @@ def __init__(self, pilotParams): safe_listdir, sendMessage, ) + +try: + from Pilot.proxyTools import BaseRequest, refreshTokenLoop +except ImportError: + from proxyTools import BaseRequest, refreshTokenLoop + +try: + from urllib.error import HTTPError, URLError +except ImportError: + from urllib2 import HTTPError, URLError ############################ @@ -92,16 +101,37 @@ def wrapper(self): self.log.info( "Flushing the remote logger buffer for pilot on sys.exit(): %s (exit code:%s)" % (pRef, str(exCode)) ) - self.log.buffer.flush() # flush the buffer unconditionally (on sys.exit()). - try: - sendMessage(self.log.url, self.log.pilotUUID, self.log.wnVO, "finaliseLogs", {"retCode": str(exCode)}) - except Exception as exc: - self.log.error("Remote logger couldn't be finalised %s " % str(exc)) + if self.pp.jwt: + try: + sendMessage(self.log.url, self.log.pilotUUID, self.pp.jwt, [ + { + "severity": "ERROR", + "message": str(exCode) + }, + { + "severity": "ERROR", + "message": traceback.format_exc() + } + ]) + + self.log.buffer.flush(force=True) + except Exception as exc: + self.log.error("Remote logger couldn't be finalised %s " % str(exc)) + raise + + # No force here because there's no remote logger if we're here + self.log.buffer.flush() raise except Exception as exc: # unexpected exit: document it and bail out. self.log.error(str(exc)) self.log.error(traceback.format_exc()) + + if self.pp.jwt: + # Force flush if it's a remote logger + self.log.buffer.flush(force=True) + else: + self.log.buffer.flush() raise finally: self.log.buffer.cancelTimer() @@ -132,7 +162,7 @@ def __init__(self, pilotParams): @logFinalizer def execute(self): """Get host and local user info, and other basic checks, e.g. space available""" - + self.log.info("Uname = %s" % " ".join(os.uname())) self.log.info("Host Name = %s" % socket.gethostname()) self.log.info("Host FQDN = %s" % socket.getfqdn()) @@ -1232,3 +1262,5 @@ def execute(self): """Standard entry point to a pilot command""" self._setNagiosOptions() self._runNagiosProbes() + + diff --git a/Pilot/pilotTools.py b/Pilot/pilotTools.py index 8afe0f62..e52953f0 100644 --- a/Pilot/pilotTools.py +++ b/Pilot/pilotTools.py @@ -2,6 +2,7 @@ from __future__ import absolute_import, division, print_function +import enum import fcntl import getopt import json @@ -9,7 +10,7 @@ import re import select import signal -import ssl +import threading import subprocess import sys import threading @@ -22,11 +23,8 @@ # python 2 -> 3 "hacks" try: from urllib.error import HTTPError, URLError - from urllib.parse import urlencode from urllib.request import urlopen except ImportError: - from urllib import urlencode - from urllib2 import HTTPError, URLError, urlopen try: @@ -69,9 +67,9 @@ def load_module_from_path(module_name, path_to_module): basestring = str try: - from Pilot.proxyTools import getVO + from Pilot.proxyTools import X509BasedRequest, getVO, TokenBasedRequest, BaseRequest, refreshTokenLoop except ImportError: - from proxyTools import getVO + from proxyTools import X509BasedRequest, getVO, TokenBasedRequest, BaseRequest, refreshTokenLoop try: FileNotFoundError # pylint: disable=used-before-assignment @@ -526,7 +524,7 @@ def __init__( pilotUUID="unknown", flushInterval=10, bufsize=1000, - wnVO="unknown", + jwt = {} ): """ c'tor @@ -536,36 +534,45 @@ def __init__( super(RemoteLogger, self).__init__(name, debugFlag, pilotOutput) self.url = url self.pilotUUID = pilotUUID - self.wnVO = wnVO self.isPilotLoggerOn = isPilotLoggerOn - sendToURL = partial(sendMessage, url, pilotUUID, wnVO, "sendMessage") - self.buffer = FixedSizeBuffer(sendToURL, bufsize=bufsize, autoflush=flushInterval) + sendToURL = partial(sendMessage, url, pilotUUID) + self.buffer = FixedSizeBuffer(sendToURL, bufsize=bufsize, autoflush=flushInterval, jwt=jwt) + + def format_to_json(self, level, message): + splitted_message = message.split("\n") + + output = [] + for mess in splitted_message: + if mess: + output.append({ + "timestamp": datetime.utcnow().strftime("%Y-%m-%dT%H:%M:%S.%fZ"), + "severity": level, + "message": mess, + "scope": self.name + }) + return output def debug(self, msg, header=True, _sendPilotLog=False): - # TODO: Send pilot log remotely? super(RemoteLogger, self).debug(msg, header) if ( self.isPilotLoggerOn and self.debugFlag ): # the -d flag activates this debug flag in CommandBase via PilotParams - self.sendMessage(self.messageTemplate.format(level="DEBUG", message=msg)) + self.sendMessage(self.format_to_json(level="DEBUG", message=msg)) def error(self, msg, header=True, _sendPilotLog=False): - # TODO: Send pilot log remotely? super(RemoteLogger, self).error(msg, header) if self.isPilotLoggerOn: - self.sendMessage(self.messageTemplate.format(level="ERROR", message=msg)) + self.sendMessage(self.format_to_json(level="ERROR", message=msg)) def warn(self, msg, header=True, _sendPilotLog=False): - # TODO: Send pilot log remotely? super(RemoteLogger, self).warn(msg, header) if self.isPilotLoggerOn: - self.sendMessage(self.messageTemplate.format(level="WARNING", message=msg)) + self.sendMessage(self.format_to_json(level="WARNING", message=msg)) def info(self, msg, header=True, _sendPilotLog=False): - # TODO: Send pilot log remotely? super(RemoteLogger, self).info(msg, header) if self.isPilotLoggerOn: - self.sendMessage(self.messageTemplate.format(level="INFO", message=msg)) + self.sendMessage(self.format_to_json(level="INFO", message=msg)) def sendMessage(self, msg): """ @@ -577,7 +584,7 @@ def sendMessage(self, msg): :rtype: None """ try: - self.buffer.write(msg + "\n") + self.buffer.write(msg) except Exception as err: super(RemoteLogger, self).error("Message not sent") super(RemoteLogger, self).error(str(err)) @@ -604,7 +611,7 @@ class FixedSizeBuffer(object): Once it's full, a message is sent to a remote server and the buffer is renewed. """ - def __init__(self, senderFunc, bufsize=1000, autoflush=10): + def __init__(self, senderFunc, bufsize=1000, autoflush=10, jwt={}): """ Constructor. @@ -622,34 +629,32 @@ def __init__(self, senderFunc, bufsize=1000, autoflush=10): self._timer.start() else: self._timer = None - self.output = StringIO() + self.output = [] self.bufsize = bufsize self._nlines = 0 self.senderFunc = senderFunc + self.jwt = jwt @synchronized - def write(self, text): + def write(self, content_json): """ Write text to a string buffer. Newline characters are counted and number of lines in the buffer is increased accordingly. - :param text: text string to write - :type text: str + :param content_json: Json to send, format following format_to_json + :type content_json: list[dict] :return: None :rtype: None """ - # reopen the buffer in a case we had to flush a partially filled buffer - if self.output.closed: - self.output = StringIO() - self.output.write(text) - self._nlines += max(1, text.count("\n")) + + self.output.extend(content_json) + + try: + self._nlines += max(1, len(content_json)) + except Exception: + raise ValueError(content_json) self.sendFullBuffer() - @synchronized - def getValue(self): - content = self.output.getvalue() - return content - @synchronized def sendFullBuffer(self): """ @@ -659,22 +664,19 @@ def sendFullBuffer(self): if self._nlines >= self.bufsize: self.flush() - self.output = StringIO() + self.output = [] @synchronized - def flush(self): + def flush(self, force=False): """ Flush the buffer and send log records to a remote server. The buffer is closed as well. :return: None :rtype: None """ - if not self.output.closed and self._nlines > 0: - self.output.flush() - buf = self.getValue() - self.senderFunc(buf) + if force or (self.output and self._nlines > 0): + self.senderFunc(self.jwt, self.output) self._nlines = 0 - self.output.close() def cancelTimer(self): """ @@ -687,40 +689,40 @@ def cancelTimer(self): self._timer.cancel() -def sendMessage(url, pilotUUID, wnVO, method, rawMessage): +def sendMessage(url, pilotUUID, jwt={}, rawMessage=[]): """ Invoke a remote method on a Tornado server and pass a JSON message to it. :param str url: Server URL :param str pilotUUID: pilot unique ID - :param str wnVO: VO name, relevant only if not contained in a proxy :param str method: a method to be invoked :param str rawMessage: a message to be sent, in JSON format + :param dict jwt: JWT for the requests :return: None. """ + caPath = os.getenv("X509_CERT_DIR") - cert = os.getenv("X509_USER_PROXY") - context = ssl.create_default_context() - context.load_verify_locations(capath=caPath) - - message = json.dumps((json.dumps(rawMessage), pilotUUID, wnVO)) - - try: - context.load_cert_chain(cert) # this is a proxy - raw_data = {"method": method, "args": message} - except IsADirectoryError: # assuming it'a dir containing cert and key - context.load_cert_chain(os.path.join(cert, "hostcert.pem"), os.path.join(cert, "hostkey.pem")) - raw_data = {"method": method, "args": message, "extraCredentials": '"hosts"'} + raw_data = { + "pilot_stamp": pilotUUID, + "lines": rawMessage + } - if sys.version_info.major == 3: - data = urlencode(raw_data).encode("utf-8") # encode to bytes ! for python3 - else: - # Python2 - data = urlencode(raw_data) + config = None + + config = TokenBasedRequest( + url=url, + caPath=caPath, + jwtData=jwt, + pilotUUID=pilotUUID + ) - res = urlopen(url, data, context=context) - res.close() + # Do the request + _res = config.executeRequest( + raw_data=raw_data, + insecure=True, + json_output=False + ) class CommandBase(object): @@ -755,7 +757,7 @@ def __init__(self, pilotParams): debugFlag=self.debugFlag, flushInterval=interval, bufsize=bufsize, - wnVO=pilotParams.wnVO, + jwt=pilotParams.jwt ) self.log.isPilotLoggerOn = isPilotLoggerOn @@ -806,7 +808,10 @@ def executeAndGetOutput(self, cmd, environDict=None): sys.stdout.write(outChunk) sys.stdout.flush() if hasattr(self.log, "buffer") and self.log.isPilotLoggerOn: - self.log.buffer.write(outChunk) + self.log.buffer.write(self.log.format_to_json( + "COMMAND", + outChunk + )) outData += outChunk # If no data was read on any of the pipes then the process has finished if not dataWasRead: @@ -908,10 +913,19 @@ def __init__(self): self.site = "" self.setup = "" self.configServer = "" + self.diracXServer = "" self.ceName = "" self.ceType = "" self.queueName = "" self.gridCEType = "" + self.pilotSecret = "" + self.clientID = "" + self.refreshTokenEvery = 300 + self.jwt = { + "access_token": "", + "refresh_token": "" + } + self.jwt_lock = threading.Lock() # maxNumberOfProcessors: the number of # processors allocated to the pilot which the pilot can allocate to one payload # used to set payloadProcessors unless other limits are reached (like the number of processors on the WN) @@ -996,6 +1010,7 @@ def __init__(self): ("y:", "CEType=", "CE Type (normally InProcess)"), ("z", "pilotLogging", "Activate pilot logging system"), ("C:", "configurationServer=", "Configuration servers to use"), + ("", "diracx_URL=", "DiracX Server URL to use"), ("D:", "disk=", "Require at least MB available"), ("E:", "commandExtensions=", "Python modules with extra commands"), ("F:", "pilotCFGFile=", "Specify pilot CFG file"), @@ -1021,6 +1036,9 @@ def __init__(self): ("", "preinstalledEnvPrefix=", "preinstalled pilot environment area prefix"), ("", "architectureScript=", "architecture script to use"), ("", "CVMFS_locations=", "comma-separated list of CVMS locations"), + ("", "pilotSecret=", "secret that the pilot uses with DiracX"), + ("", "clientID=", "client id used by DiracX to revoke a token"), + ("", "refreshTokenEvery=", "how often we have to refresh a token (in seconds)") ) # Possibly get Setup and JSON URL/filename from command line @@ -1047,6 +1065,52 @@ def __init__(self): self.installEnv["X509_USER_PROXY"] = self.certsLocation os.environ["X509_USER_PROXY"] = self.certsLocation + + + if self.pilotUUID or not self.pilotSecret or not self.diracXServer: + self.log.info("Fetching JWT in DiracX (URL: %s)" % self.diracXServer) + + config = BaseRequest( + "%s/api/pilots/token" % ( + self.diracXServer + ), + os.getenv("X509_CERT_DIR"), + self.pilotUUID + ) + + try: + self.jwt = config.executeRequest({ + "pilot_stamp": self.pilotUUID, + "pilot_secret": self.pilotSecret + }, insecure=True) + except (HTTPError, URLError) as e: + self.log.error("Request failed: %s" % str(e)) + self.log.error("Could not fetch pilot tokens. Aborting...") + sys.exit(1) + + self.log.info("Fetched the pilot token with the pilot secret.") + + self.log.info("Starting the refresh thread.") + self.log.info("Refreshing the token every %d seconds." % self.refreshTokenEvery) + # Start background refresh thread + t = threading.Thread( + target=refreshTokenLoop, + args=( + self.diracXServer, + self.pilotUUID, + self.jwt, + self.jwt_lock, + self.log, + self.refreshTokenEvery + ) + ) + t.daemon = True + t.start() + else: + self.log.info("PilotUUID, pilotSecret, and diracXServer are needed to support DiracX.") + + + def __setSecurityDir(self, envName, dirLocation): """Set the environment variable of the `envName`, and add it also to the Pilot Parameters @@ -1151,6 +1215,8 @@ def __initCommandLine2(self): self.keepPythonPath = True elif o in ("-C", "--configurationServer"): self.configServer = v + elif o == "--diracx_URL": + self.diracXServer = v elif o in ("-G", "--Group"): self.userGroup = v elif o in ("-x", "--execute"): @@ -1224,6 +1290,12 @@ def __initCommandLine2(self): self.architectureScript = v elif o == "--CVMFS_locations": self.CVMFS_locations = v.split(",") + elif o == "--pilotSecret": + self.pilotSecret = v + elif o == "--clientID": + self.clientID = v + elif o == "--refreshTokenEvery": + self.refreshTokenEvery = int(v) def __loadJSON(self): """ diff --git a/Pilot/proxyTools.py b/Pilot/proxyTools.py index a5fa652e..b7c882b2 100644 --- a/Pilot/proxyTools.py +++ b/Pilot/proxyTools.py @@ -1,11 +1,31 @@ -"""few functions for dealing with proxies""" +"""few functions for dealing with proxies and authentication""" from __future__ import absolute_import, division, print_function +import json +from multiprocessing import Value +import os +import time import re +import ssl +import sys from base64 import b16decode from subprocess import PIPE, Popen +try: + IsADirectoryError # pylint: disable=used-before-assignment +except NameError: + IsADirectoryError = IOError + +try: + from urllib.parse import urlencode + from urllib.error import HTTPError + from urllib.request import Request, urlopen +except ImportError: + from urllib import urlencode + + from urllib2 import Request, urlopen, HTTPError + VOMS_FQANS_OID = b"1.3.6.1.4.1.8005.100.100.4" VOMS_EXTENSION_OID = b"1.3.6.1.4.1.8005.100.100.5" @@ -30,15 +50,10 @@ def findExtension(oid, lines): def getVO(proxy_data): """Fetches the VO in a chain certificate - Args: - proxy_data (bytes): Bytes for the proxy chain - - Raises: - Exception: Any error related to openssl - NotImplementedError: Not documented error - - Returns: - str: A VO + :param proxy_data: Bytes for the proxy chain + :type proxy_data: bytes + :return: A VO + :rtype: str """ chain = re.findall(br"-----BEGIN CERTIFICATE-----\n.+?\n-----END CERTIFICATE-----", proxy_data, flags=re.DOTALL) @@ -65,3 +80,259 @@ def getVO(proxy_data): if match: return match.groups()[0].decode() raise NotImplementedError("Something went very wrong") + + +class BaseRequest(object): + """This class helps supporting multiple kinds of requests that require connections""" + + def __init__(self, url, caPath, pilotUUID, name="unknown"): + self.name = name + self.url = url + self.caPath = caPath + self.headers = { + "User-Agent": "Dirac Pilot [Unknown ID]" + } + self.pilotUUID = pilotUUID + # We assume we have only one context, so this variable could be shared to avoid opening n times a cert. + # On the contrary, to avoid race conditions, we do avoid using "self.data" and "self.headers" + self._context = None + + self._prepareRequest() + + def generateUserAgent(self): + """To analyse the traffic, we can send a taylor-made User-Agent""" + self.addHeader("User-Agent", "Dirac Pilot [%s]" % self.pilotUUID) + + def _prepareRequest(self): + """As previously, loads the SSL certificates of the server (to avoid "unknown issuer")""" + # Load the SSL context + self._context = ssl.create_default_context() + self._context.load_verify_locations(capath=self.caPath) + + def addHeader(self, key, value): + """Add a header (key, value) into the request header""" + self.headers[key] = value + + def executeRequest(self, raw_data, insecure=False, content_type="json", json_output=True): + """Execute a HTTP request with the data, headers, and the pre-defined data (SSL + auth) + + :param raw_data: Data to send + :type raw_data: dict + :param insecure: Deactivate proxy verification WARNING Debug ONLY + :type insecure: bool + :param content_type: Data format to send, either "json" or "x-www-form-urlencoded" or "query" + :type content_type: str + :param json_output: If we have an output + :type json_output: bool + :return: Parsed JSON response + :rtype: dict + """ + if content_type == "json": + data = json.dumps(raw_data).encode("utf-8") + self.addHeader("Content-Type", "application/json") + self.addHeader("Content-Length", str(len(data))) + else: + + data = urlencode(raw_data) + + if content_type == "x-www-form-urlencoded": + if sys.version_info.major == 3: + data = urlencode(raw_data).encode("utf-8") # encode to bytes ! for python3 + + self.addHeader("Content-Type", "application/x-www-form-urlencoded") + self.addHeader("Content-Length", str(len(data))) + elif content_type == "query": + self.url = self.url + "?" + data + data = None # No body + else: + raise ValueError("Invalid content_type. Use 'json' or 'x-www-form-urlencoded'.") + + + request = Request(self.url, data=data, headers=self.headers, method="POST") + + ctx = self._context # Save in case of an insecure request + + if insecure: + # DEBUG ONLY + # Overrides context + ctx = ssl.create_default_context() + ctx.check_hostname = False + ctx.verify_mode = ssl.CERT_NONE + + + try: + if sys.version_info.major == 3: + # Python 3 code + with urlopen(request, context=ctx) as res: + response_data = res.read().decode("utf-8") # Decode response bytes + else: + # Python 2 code + res = urlopen(request, context=ctx) + try: + response_data = res.read() + finally: + res.close() + except HTTPError as e: + raise RuntimeError("HTTPError : %s" % e.read().decode()) + + if json_output: + try: + return json.loads(response_data) # Parse JSON response + except ValueError: # In Python 2, json.JSONDecodeError is a subclass of ValueError + raise ValueError("Invalid JSON response: %s" % response_data) + + +class TokenBasedRequest(BaseRequest): + """Connected Request with JWT support""" + + def __init__(self, url, caPath, jwtData, pilotUUID): + super(TokenBasedRequest, self).__init__(url, caPath, pilotUUID, "TokenBasedConnection") + self.jwtData = jwtData + self.addJwtToHeader() + + def addJwtToHeader(self): + # Adds the JWT in the HTTP request (in the Bearer field) + self.headers["Authorization"] = "Bearer %s" % self.jwtData["access_token"] + + def executeRequest(self, raw_data, insecure=False, content_type="json", json_output=True): + + return super(TokenBasedRequest, self).executeRequest( + raw_data, + insecure=insecure, + content_type=content_type, + json_output=json_output + ) + +class X509BasedRequest(BaseRequest): + """Connected Request with X509 support""" + + def __init__(self, url, caPath, certEnv, pilotUUID): + super(X509BasedRequest, self).__init__(url, caPath, pilotUUID, "X509BasedConnection") + + self.certEnv = certEnv + self._hasExtraCredentials = False + + # Load X509 once + try: + self._context.load_cert_chain(self.certEnv) + except IsADirectoryError: # assuming it'a dir containing cert and key + self._context.load_cert_chain( + os.path.join(self.certEnv, "hostcert.pem"), os.path.join(self.certEnv, "hostkey.pem") + ) + self._hasExtraCredentials = True + + def executeRequest(self, raw_data, insecure=False, content_type="json", json_output=True): + # Adds a flag if the passed cert is a Directory + if self._hasExtraCredentials: + raw_data["extraCredentials"] = '"hosts"' + return super(X509BasedRequest, self).executeRequest( + raw_data, + insecure=insecure, + content_type=content_type, + json_output=json_output + ) + + +def refreshPilotToken(url, pilotUUID, jwt, jwt_lock): + """ + Refresh the JWT token in a separate thread. + + :param str url: Server URL + :param str pilotUUID: Pilot unique ID + :param dict jwt: Shared dict with current JWT; updated in-place + :param threading.Lock jwt_lock: Lock to safely update the jwt dict + :return: None + """ + + # PRECONDITION: jwt must contain "refresh_token" + if not jwt or "refresh_token" not in jwt: + raise ValueError("To refresh a token, a pilot needs a JWT with refresh_token") + + # Get CA path from environment + caPath = os.getenv("X509_CERT_DIR") + + # Create request object with required configuration + config = TokenBasedRequest( + url="%s/api/pilots/refresh-token" % url, + caPath=caPath, + pilotUUID=pilotUUID, + jwtData=jwt + ) + + # Perform the request to refresh the token + response = config.executeRequest( + raw_data={ + "refresh_token": jwt["refresh_token"] + }, + insecure=True, + ) + + # Ensure thread-safe update of the shared jwt dictionary + jwt_lock.acquire() + try: + jwt.update(response) + finally: + jwt_lock.release() + + +def revokePilotToken(url, pilotUUID, jwt, clientID): + """ + Refresh the JWT token in a separate thread. + + :param str url: Server URL + :param str pilotUUID: Pilot unique ID + :param str clientID: ClientID used to revoke tokens + :param dict jwt: Shared dict with current JWT; + :return: None + """ + + # PRECONDITION: jwt must contain "refresh_token" + if not jwt or "refresh_token" not in jwt: + raise ValueError("To refresh a token, a pilot needs a JWT with refresh_token") + + # Get CA path from environment + caPath = os.getenv("X509_CERT_DIR") + + # Create request object with required configuration + config = BaseRequest( + url="%s/api/auth/revoke" % url, + caPath=caPath, + pilotUUID=pilotUUID + ) + + # Prepare refresh token payload + payload = { + "refresh_token": jwt["refresh_token"], + "client_id": clientID + } + + # Perform the request to revoke the token + _response = config.executeRequest( + raw_data=payload, + insecure=True, + content_type="query" + ) + +# === Token refresher thread function === +def refreshTokenLoop(url, pilotUUID, jwt, jwt_lock, logger, interval=600): + """ + Periodically refresh the pilot JWT token. + + :param str url: DiracX server URL + :param str pilotUUID: Pilot UUID + :param dict jwt: Shared JWT dictionary + :param threading.Lock jwt_lock: Lock to safely update JWT + :param Logger logger: Logger to debug + :param int interval: Sleep time between refreshes in seconds + :return: None + """ + while True: + time.sleep(interval) + + try: + refreshPilotToken(url, pilotUUID, jwt, jwt_lock) + + logger.info("Token refreshed.") + except Exception as e: + logger.error("Token refresh failed: %s\n" % str(e)) + continue