diff --git a/pyslurm/core/error.pyx b/pyslurm/core/error.pyx index 722f3097..954aa2a6 100644 --- a/pyslurm/core/error.pyx +++ b/pyslurm/core/error.pyx @@ -27,6 +27,19 @@ from pyslurm cimport slurm cimport libc.errno +def _check_modify_arguments(changes, **kwargs): + if changes is None and not kwargs: + raise ArgumentError("Nothing to change was provided") + + if changes is not None and kwargs: + raise ArgumentError("Provide either a changes object or keyword arguments, not both") + + +def _get_modify_arguments_for(cls, changes, **kwargs): + _check_modify_arguments(changes, **kwargs) + return changes or cls(**kwargs) + + def slurm_strerror(errno): """Convert a slurm errno to a string. @@ -69,7 +82,15 @@ class PyslurmError(Exception): """The base Exception for all Pyslurm errors.""" -class RPCError(PyslurmError): +class ClientError(PyslurmError): + pass + + +class ServerError(PyslurmError): + pass + + +class RPCError(ServerError): """Exception for handling Slurm RPC errors. Args: @@ -100,12 +121,26 @@ class RPCError(PyslurmError): super().__init__(self.msg) -def verify_rpc(errno): +class InvalidUsageError(ClientError): + pass + + +class ArgumentError(InvalidUsageError): + pass + + +class NotFoundError(RPCError): + pass + + +def verify_rpc(errno, msg=None): """Verify a Slurm RPC Args: errno (int): A Slurm error value + msg (str): + An optional message """ if errno != slurm.SLURM_SUCCESS: - raise RPCError(errno) + raise RPCError(errno, msg) diff --git a/pyslurm/db/__init__.py b/pyslurm/db/__init__.py index a821cc43..b633c82a 100644 --- a/pyslurm/db/__init__.py +++ b/pyslurm/db/__init__.py @@ -19,7 +19,7 @@ # with PySlurm; if not, write to the Free Software Foundation, Inc., # 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. -from .connection import Connection +from .connection import Connection, connect from .step import JobStep, JobSteps from .stats import JobStatistics, JobStepStatistics from .job import ( @@ -44,3 +44,18 @@ Association, AssociationFilter, ) +from .user import ( + Users, + User, + UserFilter, +) +from .account import ( + Accounts, + Account, + AccountFilter, +) +from .wckey import ( + WCKeys, + WCKey, + WCKeyFilter, +) diff --git a/pyslurm/db/account.pxd b/pyslurm/db/account.pxd new file mode 100644 index 00000000..cf5169bd --- /dev/null +++ b/pyslurm/db/account.pxd @@ -0,0 +1,108 @@ +######################################################################### +# account.pxd - pyslurm slurmdbd account api +######################################################################### +# Copyright (C) 2025 Toni Harzendorf +# +# This file is part of PySlurm +# +# PySlurm is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 2 of the License, or +# (at your option) any later version. + +# PySlurm is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License along +# with PySlurm; if not, write to the Free Software Foundation, Inc., +# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. +# +# cython: c_string_type=unicode, c_string_encoding=default +# cython: language_level=3 + +from libc.string cimport memcpy, memset +from pyslurm cimport slurm +from pyslurm.slurm cimport ( + slurmdb_account_rec_t, + slurmdb_assoc_rec_t, + slurmdb_assoc_cond_t, + slurmdb_account_cond_t, + slurmdb_accounts_get, + slurmdb_accounts_add, + slurmdb_accounts_remove, + slurmdb_accounts_modify, + slurmdb_destroy_account_rec, + slurmdb_destroy_account_cond, + try_xmalloc, +) +from pyslurm.db.util cimport ( + SlurmList, + SlurmListItem, + make_char_list, + slurm_list_to_pylist, + qos_list_to_pylist, +) +from pyslurm.db.tres cimport ( + _set_tres_limits, + TrackableResources, +) +from pyslurm.db.connection cimport Connection, ConnectionWrapper +from pyslurm.utils cimport cstr +from pyslurm.db.qos cimport QualitiesOfService, _set_qos_list +from pyslurm.db.assoc cimport Associations, Association, _parse_assoc_ptr +from pyslurm.xcollections cimport MultiClusterMap +from pyslurm.utils.uint cimport u16_set_bool_flag + + +cdef class AccountAPI(ConnectionWrapper): + pass + + +cdef class Accounts(dict): + cdef public: + Connection _db_conn + + +cdef class AccountFilter: + cdef slurmdb_account_cond_t *ptr + + cdef public: + with_assocs + with_deleted + with_coordinators + names + organizations + descriptions + + +cdef class Account: + """Slurm Database Account. + + Attributes: + name (str): + Name of the Account. + description (str): + Description of the Account. + organization (str): + Organization of the Account. + is_deleted (bool): + Whether this Account has been deleted or not. + association (pyslurm.db.Association): + This accounts association. + """ + cdef: + slurmdb_account_rec_t *ptr + + cdef readonly: + cluster + + cdef public: + associations + coordinators + association + Connection _db_conn + + @staticmethod + cdef Account from_ptr(slurmdb_account_rec_t *in_ptr) diff --git a/pyslurm/db/account.pyx b/pyslurm/db/account.pyx new file mode 100644 index 00000000..b1966389 --- /dev/null +++ b/pyslurm/db/account.pyx @@ -0,0 +1,352 @@ +######################################################################### +# account.pyx - pyslurm slurmdbd account api +######################################################################### +# Copyright (C) 2026 Toni Harzendorf +# +# This file is part of PySlurm +# +# PySlurm is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 2 of the License, or +# (at your option) any later version. + +# PySlurm is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License along +# with PySlurm; if not, write to the Free Software Foundation, Inc., +# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. +# +# cython: c_string_type=unicode, c_string_encoding=default +# cython: language_level=3 + +from pyslurm.core.error import ( + RPCError, + slurm_errno, + verify_rpc, + NotFoundError, + _get_modify_arguments_for, +) +from pyslurm.utils.helpers import ( + instance_to_dict, + user_to_uid, +) +from pyslurm.utils.uint import * +from pyslurm import xcollections +from pyslurm.db.error import handle_response +from typing import Any, Union, Optional, List, Dict + + +cdef class AccountAPI(ConnectionWrapper): + + def load(self, db_filter: Optional[AccountFilter] = None): + cdef: + Accounts out = Accounts() + Account account + SlurmList account_data + SlurmListItem account_ptr + SlurmList assoc_data + SlurmListItem assoc_ptr + Association assoc + QualitiesOfService qos_data + TrackableResources tres_data + + self.db_conn.validate() + + if not db_filter: + db_filter = AccountFilter() + + if db_filter.with_assocs is not False: + db_filter.with_assocs = True + + db_filter._create() + account_data = SlurmList.wrap(slurmdb_accounts_get(self.db_conn.ptr, db_filter.ptr)) + + if account_data.is_null: + raise RPCError(msg="Failed to get Account data from slurmdbd.") + + qos_data = self.db_conn.qos.load(name_is_key=False) + tres_data = self.db_conn.tres.load() + + for account_ptr in SlurmList.iter_and_pop(account_data): + account = Account.from_ptr(account_ptr.data) + out[account.name] = account + self.db_conn.apply_reuse(account) + + assoc_data = SlurmList.wrap(account.ptr.assoc_list, owned=False) + for assoc_ptr in SlurmList.iter_and_pop(assoc_data): + assoc = Association.from_ptr(assoc_ptr.data) + assoc.qos_data = qos_data + assoc.tres_data = tres_data + self.db_conn.apply_reuse(assoc) + _parse_assoc_ptr(assoc) + + if not assoc.user: + # This is the Association of the account itself. + account.association = assoc + else: + # These must be User Associations. + # TODO: maybe rename to user_associations + account.associations.append(assoc) + + self.db_conn.apply_reuse(out) + return out + + def delete(self, db_filter: AccountFilter): + out = [] + # Check is required because for some reason if the acct_cond doesn't + # contain any valid conditions, slurmdbd will delete all accounts. + # TODO: Maybe make it configurable + if not db_filter.names: + return + + self.db_conn.validate() + db_filter._create() + + response = SlurmList.wrap(slurmdb_accounts_remove(self.db_conn.ptr, db_filter.ptr)) + rc = slurm_errno() + self.db_conn.check_commit(rc) + return handle_response(response, rc) + + def modify( + self, + db_filter: AccountFilter, + changes: Optional[Account] = None, + **kwargs: Any + ): + cdef Account _changes + + _changes = _get_modify_arguments_for(Account, changes, **kwargs) + + self.db_conn.validate() + db_filter._create() + + response = SlurmList.wrap(slurmdb_accounts_modify( + self.db_conn.ptr, db_filter.ptr, _changes.ptr) + ) + rc = slurm_errno() + self.db_conn.check_commit(rc) + return handle_response(response, rc) + + def create(self, accounts: List[Account]): + cdef: + Account account + SlurmList account_list + list assocs_to_add = [] + + self.db_conn.validate() + account_list = SlurmList.create(slurmdb_destroy_account_rec, owned=False) + + for account in accounts: + if not account.association: + account.association = Association(account=account.name) + + assocs_to_add.append(account.association) + slurm.slurm_list_append(account_list.info, account.ptr) + + rc = slurmdb_accounts_add(self.db_conn.ptr, account_list.info) + + # Could also solve this construct via a simple try..finally, but I just + # don't want to execute commit/rollback potentially twice, even if it + # is completely fine. + try: + if rc == slurm.SLURM_SUCCESS: + self.db_conn.associations.create(assocs_to_add) + except RPCError: + # Just re-raise - required rollback was already taken care of + raise + except Exception: + # Doing this catch-all thing might be too cautious, but just in + # case anything goes wrong before Associations were attempted to be + # added, we make sure that adding the users is also rollbacked. + # + # Because we don't want to leave Users with no associations behind + # in the system, if associations were requested to be added. + self.db_conn.check_commit(slurm.SLURM_ERROR) + raise + + # TODO: SLURM_NO_CHANGE_IN_DATA + # Should this be an error? + + # Rollback or commit in case no associations were attempted to be added + self.db_conn.check_commit(rc) + verify_rpc(rc) + + +cdef class Accounts(dict): + + def __init__(self, accounts={}, **kwargs: Any): + super().__init__() + self.update(accounts) + self.update(kwargs) + self._db_conn = None + + @staticmethod + def load(db_conn: Connection, db_filter: Optional[AccountFilter] = None): + return db_conn.accounts.load(db_filter) + + def delete(self, db_conn: Optional[Connection] = None): + db_conn = Connection.reuse(self._db_conn, db_conn) + db_filter = AccountFilter(names=list(self.keys())) + db_conn.accounts.delete(db_filter) + + def modify(self, changes: Optional[Account] = None, db_conn: Optional[Connection] = None, **kwargs: Any): + db_conn = Connection.reuse(self._db_conn, db_conn) + db_filter = AccountFilter(names=list(self.keys())) + return db_conn.accounts.modify(db_filter, changes, **kwargs) + + def create(self, db_conn: Optional[Connection] = None): + db_conn = Connection.reuse(self._db_conn, db_conn) + db_conn.accounts.create(list(self.values())) + + +cdef class AccountFilter: + + def __cinit__(self): + self.ptr = NULL + + def __init__(self, **kwargs: Any): + for k, v in kwargs.items(): + setattr(self, k, v) + + def __dealloc__(self): + self._dealloc() + + def _dealloc(self): + slurmdb_destroy_account_cond(self.ptr) + self.ptr = NULL + + def _alloc(self): + self._dealloc() + self.ptr = try_xmalloc(sizeof(slurmdb_account_cond_t)) + if not self.ptr: + raise MemoryError("xmalloc failed for slurmdb_account_cond_t") + + memset(self.ptr, 0, sizeof(slurmdb_account_cond_t)) + + self.ptr.assoc_cond = try_xmalloc(sizeof(slurmdb_assoc_cond_t)) + if not self.ptr.assoc_cond: + raise MemoryError("xmalloc failed for slurmdb_assoc_cond_t") + + def _parse_flag(self, val, flag_val): + if val: + self.ptr.flags |= flag_val + + def _create(self): + self._alloc() + cdef slurmdb_account_cond_t *ptr = self.ptr + + make_char_list(&ptr.assoc_cond.acct_list, self.names) + self._parse_flag(self.with_assocs, slurm.SLURMDB_ACCT_FLAG_WASSOC) + self._parse_flag(self.with_deleted, slurm.SLURMDB_ACCT_FLAG_DELETED) + self._parse_flag(self.with_coordinators, slurm.SLURMDB_ACCT_FLAG_WCOORD) + + +cdef class Account: + + def __cinit__(self): + self.ptr = NULL + + def __init__(self, name: str = None, description: str = None, organization: str = None, **kwargs: Any): + self._alloc_impl() + self._init_defaults() + self.name = name + self.description = description or name + self.organization = organization or name + + for k, v in kwargs.items(): + setattr(self, k, v) + + def _init_defaults(self): + self.associations = [] + self.association = None + self.coordinators = [] + + def __dealloc__(self): + self._dealloc_impl() + + def _dealloc_impl(self): + slurmdb_destroy_account_rec(self.ptr) + self.ptr = NULL + + def _alloc_impl(self): + if not self.ptr: + self.ptr = try_xmalloc( + sizeof(slurmdb_account_rec_t)) + if not self.ptr: + raise MemoryError("xmalloc failed for slurmdb_account_rec_t") + + memset(self.ptr, 0, sizeof(slurmdb_account_rec_t)) + + def __repr__(self): + return f'pyslurm.db.{self.__class__.__name__}({self.name})' + + @staticmethod + cdef Account from_ptr(slurmdb_account_rec_t *in_ptr): + cdef Account wrap = Account.__new__(Account) + wrap.ptr = in_ptr + wrap._init_defaults() + return wrap + + def to_dict(self, recursive: bool = False): + """Database Account information formatted as a dictionary. + + Returns: + (dict): Database Account information as dict. + """ + return instance_to_dict(self, recursive) + + def __eq__(self, other: Any) -> bool: + if isinstance(other, Account): + return self.name == other.name + return NotImplemented + + @staticmethod + def load(db_conn: Connection, name: str): + account = db_conn.accounts.load().get(name) + if not account: + # TODO: Maybe don't raise here and just return None and let the + # Caller handle it? + raise NotFoundError(msg=f"Account {name} does not exist.") + return account + + def create(self, db_conn: Optional[Connection] = None): + Accounts({self.name: self}).create(self._db_conn or db_conn) + + def delete(self, db_conn: Optional[Connection] = None): + Accounts({self.name: self}).delete(self._db_conn or db_conn) + + def modify(self, changes: Optional[Account] = None, db_conn: Optional[Connection] = None, **kwargs: Any): + Accounts({self.name: self}).modify(changes=changes, db_conn=(self._db_conn or db_conn), **kwargs) + + @property + def name(self): + return cstr.to_unicode(self.ptr.name) + + @name.setter + def name(self, val): + cstr.fmalloc(&self.ptr.name, val) + + @property + def description(self): + return cstr.to_unicode(self.ptr.description) + + @description.setter + def description(self, val): + cstr.fmalloc(&self.ptr.description, val) + + @property + def organization(self): + return cstr.to_unicode(self.ptr.organization) + + @organization.setter + def organization(self, val): + cstr.fmalloc(&self.ptr.organization, val) + + @property + def is_deleted(self): + if self.ptr.flags & slurm.SLURMDB_ACCT_FLAG_DELETED: + return True + return False diff --git a/pyslurm/db/assoc.pxd b/pyslurm/db/assoc.pxd index b0dad3a9..cf84951e 100644 --- a/pyslurm/db/assoc.pxd +++ b/pyslurm/db/assoc.pxd @@ -26,11 +26,13 @@ from pyslurm cimport slurm from pyslurm.slurm cimport ( slurmdb_assoc_rec_t, slurmdb_assoc_cond_t, - slurmdb_associations_get, slurmdb_destroy_assoc_rec, slurmdb_destroy_assoc_cond, slurmdb_init_assoc_rec, + slurmdb_associations_get, slurmdb_associations_modify, + slurmdb_associations_add, + slurmdb_associations_remove, try_xmalloc, ) from pyslurm.db.util cimport ( @@ -44,7 +46,7 @@ from pyslurm.db.tres cimport ( _set_tres_limits, TrackableResources, ) -from pyslurm.db.connection cimport Connection +from pyslurm.db.connection cimport Connection, ConnectionWrapper from pyslurm.utils cimport cstr from pyslurm.utils.uint cimport * from pyslurm.db.qos cimport QualitiesOfService, _set_qos_list @@ -54,25 +56,40 @@ cdef _parse_assoc_ptr(Association ass) cdef _create_assoc_ptr(Association ass, conn=*) -cdef class Associations(MultiClusterMap): +cdef class AssociationAPI(ConnectionWrapper): pass +cdef class Associations(MultiClusterMap): + cdef public: + Connection _db_conn + + cdef class AssociationFilter: cdef slurmdb_assoc_cond_t *ptr cdef public: users ids + accounts + parent_accounts + clusters + partitions + qos cdef class Association: cdef: slurmdb_assoc_rec_t *ptr + slurmdb_assoc_rec_t *umsg QualitiesOfService qos_data TrackableResources tres_data + owned cdef public: + Connection _db_conn + default_qos + group_tres group_tres_mins group_tres_run_mins @@ -81,7 +98,21 @@ cdef class Association: max_tres_per_job max_tres_per_node qos + group_jobs + group_jobs_accrue + group_submit_jobs + group_wall_time + max_jobs + max_jobs_accrue + max_submit_jobs + max_wall_time_per_job + min_priority_threshold + priority + shares @staticmethod cdef Association from_ptr(slurmdb_assoc_rec_t *in_ptr) + +cdef class AssociationList(SlurmList): + pass diff --git a/pyslurm/db/assoc.pyx b/pyslurm/db/assoc.pyx index 56370a04..b30e8619 100644 --- a/pyslurm/db/assoc.pyx +++ b/pyslurm/db/assoc.pyx @@ -1,7 +1,7 @@ ######################################################################### # assoc.pyx - pyslurm slurmdbd association api ######################################################################### -# Copyright (C) 2023 Toni Harzendorf +# Copyright (C) 2026 Toni Harzendorf # # This file is part of PySlurm # @@ -22,64 +22,58 @@ # cython: c_string_type=unicode, c_string_encoding=default # cython: language_level=3 -from pyslurm.core.error import RPCError +from pyslurm.core.error import ( + RPCError, + slurm_errno, + verify_rpc, + NotFoundError, + _get_modify_arguments_for, +) from pyslurm.utils.helpers import ( instance_to_dict, user_to_uid, ) from pyslurm.utils.uint import * -from pyslurm.db.connection import _open_conn_or_error from pyslurm import settings from pyslurm import xcollections +from pyslurm.db.error import handle_response +from typing import Any, Union, Optional, List, Dict -cdef class Associations(MultiClusterMap): +cdef class AssociationAPI(ConnectionWrapper): - def __init__(self, assocs=None): - super().__init__(data=assocs, - typ="Associations", - val_type=Association, - id_attr=Association.id, - key_type=int) - - @staticmethod - def load(AssociationFilter db_filter=None, Connection db_connection=None): + def load(self, db_filter: Optional[AssociationFilter] = None): cdef: Associations out = Associations() Association assoc - AssociationFilter cond = db_filter SlurmList assoc_data SlurmListItem assoc_ptr - Connection conn QualitiesOfService qos_data TrackableResources tres_data - # Prepare SQL Filter - if not db_filter: - cond = AssociationFilter() - cond._create() + self.db_conn.validate() - # Setup DB Conn - conn = _open_conn_or_error(db_connection) + if not db_filter: + db_filter = AssociationFilter() + db_filter._create() - # Fetch Assoc Data assoc_data = SlurmList.wrap(slurmdb_associations_get( - conn.ptr, cond.ptr)) + self.db_conn.ptr, db_filter.ptr) + ) if assoc_data.is_null: - raise RPCError(msg="Failed to get Association data from slurmdbd") + raise RPCError(msg="Failed to get Association data from slurmdbd.") # Fetch other necessary dependencies needed for translating some # attributes (i.e QoS IDs to its name) - qos_data = QualitiesOfService.load(db_connection=conn, - name_is_key=False) - tres_data = TrackableResources.load(db_connection=conn) + qos_data = self.db_conn.qos.load(name_is_key=False) + tres_data = self.db_conn.tres.load() - # Setup Association objects for assoc_ptr in SlurmList.iter_and_pop(assoc_data): assoc = Association.from_ptr(assoc_ptr.data) assoc.qos_data = qos_data assoc.tres_data = tres_data + self.db_conn.apply_reuse(assoc) _parse_assoc_ptr(assoc) cluster = assoc.cluster @@ -87,59 +81,156 @@ cdef class Associations(MultiClusterMap): out.data[cluster] = {} out.data[cluster][assoc.id] = assoc + self.db_conn.apply_reuse(out) return out - @staticmethod - def modify(db_filter, Association changes, Connection db_connection=None): + def delete(self, db_filter: AssociationFilter): + out = [] + # TODO: Properly check if the filter is empty, cause it will then probably + # target all assocs. Or maybe that is fine and we need to clearly document + # to take caution + # if not db_filter.ids: + # return + + self.db_conn.validate() + a_filter._create() + + response = SlurmList.wrap(slurmdb_associations_remove( + self.db_conn.ptr, db_filter.ptr) + ) + rc = slurm_errno() + self.db_conn.check_commit(rc) + return handle_response(response, rc) + + def modify(self, db_filter: AssociationFilter, changes: Optional[Association] = None, **kwargs: Any): cdef: - AssociationFilter afilter - Connection conn - SlurmList response - SlurmListItem response_ptr - list out = [] - - # Prepare SQL Filter - if isinstance(db_filter, Associations): - assoc_ids = [ass.id for ass in db_filter] - afilter = AssociationFilter(ids=assoc_ids) - else: - afilter = db_filter - afilter._create() - - # Setup DB conn - conn = _open_conn_or_error(db_connection) + Association _changes + + out = [] + # TODO: prohibit mixing multiple user assocs with account assocs + # This is not possible, and the request will simply affect nothing... + + _changes = _get_modify_arguments_for(Association, changes, **kwargs) + + self.db_conn.validate() + db_filter._create() # Any data that isn't parsed yet or needs validation is done in this # function. - _create_assoc_ptr(changes, conn) + _create_assoc_ptr(_changes, self.db_conn) - # Modify associations, get the result - # This returns a List of char* with the associations that were - # modified + # Returns a List of char* with the associations that were modified response = SlurmList.wrap(slurmdb_associations_modify( - conn.ptr, afilter.ptr, changes.ptr)) + self.db_conn.ptr, db_filter.ptr, _changes.ptr)) + rc = slurm_errno() + self.db_conn.check_commit(rc) + return handle_response(response, rc) - if not response.is_null and response.cnt: - for response_ptr in response: - response_str = cstr.to_unicode(response_ptr.data) - if not response_str: - continue + def create(self, associations: List[Association]): + cdef: + Association assoc + AssociationList assoc_list = AssociationList(owned=False) - # TODO: Better format - out.append(response_str) + if not associations: + return - elif not response.is_null: - # There was no real error, but simply nothing has been modified - raise RPCError(msg="Nothing was modified") - else: - # Autodetects the last slurm error - raise RPCError() + self.db_conn.validate() - if not db_connection: - # Autocommit if no connection was explicitly specified. - conn.commit() + for i, assoc in enumerate(associations): + # Make sure to remove any duplicate associations, i.e. associations + # having the same account name set. For some reason, the slurmdbd + # doesn't like that. + if assoc not in assoc_list: + assoc_list.append(assoc) - return out + rc = slurmdb_associations_add(self.db_conn.ptr, assoc_list.info) + # TODO: SLURM_NO_CHANGE_IN_DATA + # Should this be an error? + self.db_conn.check_commit(rc) + verify_rpc(rc) + + +cdef class AssociationList(SlurmList): + + def __init__(self, owned=True): + self.info = slurm.slurm_list_create(slurm.slurmdb_destroy_assoc_rec) + self.owned = owned + + def append(self, Association assoc): + slurm.slurm_list_append(self.info, assoc.ptr) + assoc.owned = False + self.cnt = slurm.slurm_list_count(self.info) + + def __iter__(self): + return super().__iter__() + + def __next__(self): + if self.is_null or self.is_itr_null: + raise StopIteration + + if self.itr_cnt < self.cnt: + self.itr_cnt += 1 + assoc = Association.from_ptr(slurm.slurm_list_next(self.itr)) + assoc.owned = False + return assoc + + self._dealloc_itr() + raise StopIteration + + def extend(self, list_in): + for item in list_in: + self.append(item) + + +cdef class Associations(MultiClusterMap): + + def __init__(self, assocs=None): + super().__init__(data=assocs, + typ="Associations", + val_type=Association, + id_attr=Association.id, + key_type=int) + self._db_conn = None + + def _do_api_call(self, fn, **kwargs): + res_user = [] + res_accts = [] + db_filter_users = AssociationFilter(users=[], ids=[]) + db_filter_accounts = AssociationFilter(accounts=[], ids=[]) + + # We need to split User Associations from Account associations + # If we mix both, the request will not do anything... + for assoc in self.values(): + if assoc.user: + db_filter_users.users.append(assoc.user) + db_filter_users.ids.append(assoc.id) + else: + db_filter_accounts.accounts.append(assoc.account) + db_filter_accounts.ids.append(assoc.id) + + if db_filter_users.users: + res_user = fn(db_filter_users, **kwargs) + + if db_filter_accounts.accounts: + res_accts = fn(db_filter_accounts, **kwargs) + + return res_user + res_accts + + @staticmethod + def load(db_conn: Connection, db_filter: Optional[AssociationFilter] = None): + return db_conn.associations.load(db_filter) + + def delete(self, db_conn: Optional[Connection] = None): + db_conn = Connection.reuse(self._db_conn, db_conn) + self._do_api_call(db_conn.associations.delete, changes=changes) + + def modify(self, changes: Optional[Association] = None, db_conn: Optional[Connection] = None, **kwargs: Any): + db_conn = Connection.reuse(self._db_conn, db_conn) + return self._do_api_call(db_conn.associations.modify, changes=changes, **kwargs) + + def create(self, db_conn: Optional[Connection] = None): + db_conn = Connection.reuse(self._db_conn, db_conn) + self._do_api_call(db_conn.associations.create, associations=list(self.values())) cdef class AssociationFilter: @@ -147,7 +238,7 @@ cdef class AssociationFilter: def __cinit__(self): self.ptr = NULL - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any): for k, v in kwargs.items(): setattr(self, k, v) @@ -174,22 +265,38 @@ cdef class AssociationFilter: cdef slurmdb_assoc_cond_t *ptr = self.ptr make_char_list(&ptr.user_list, self.users) + make_char_list(&ptr.id_list, self.ids) + make_char_list(&ptr.acct_list, self.accounts) + make_char_list(&ptr.parent_acct_list, self.parent_accounts) + make_char_list(&ptr.cluster_list, self.clusters) + make_char_list(&ptr.partition_list, self.partitions) + # TODO: These should be QOS ids, not names + make_char_list(&ptr.qos_list, self.qos) + # TODO: ASSOC_COND_FLAGS cdef class Association: def __cinit__(self): self.ptr = NULL + self.owned = True - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any): self._alloc_impl() self.id = 0 + + # Only when an Account-Association is initialized, we default to + # "root" as the Parent Account. + user = kwargs.get("user") + self.parent_account = kwargs.pop("parent_account", + "root" if not user else None) self.cluster = settings.LOCAL_CLUSTER for k, v in kwargs.items(): setattr(self, k, v) def __dealloc__(self): - self._dealloc_impl() + if self.owned: + self._dealloc_impl() def _dealloc_impl(self): slurmdb_destroy_assoc_rec(self.ptr) @@ -213,7 +320,7 @@ cdef class Association: wrap.ptr = in_ptr return wrap - def to_dict(self, recursive = False): + def to_dict(self, recursive: bool = False): """Database Association information formatted as a dictionary. Returns: @@ -221,11 +328,29 @@ cdef class Association: """ return instance_to_dict(self, recursive) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if isinstance(other, Association): - return self.id == other.id and self.cluster == other.cluster +# return self.id == other.id and self.cluster == other.cluster + return self.cluster == other.cluster and self.partition == other.partition and self.account == other.account and self.user == other.user return NotImplemented + @staticmethod + def load(db_conn: Connection, id: Union[str, int]): + assoc = db_conn.associations.load().get(int(id)) + if not assoc: + raise NotFoundError(msg=f"Association with id '{id}' does not exist.") + return assoc + + def create(self, db_conn: Optional[Connection] = None): + db_conn = Connection.reuse(self._db_conn, db_conn) + db_conn.associations.create([self]) + + def delete(self, db_conn: Optional[Connection] = None): + Associations({self.id: self}).delete(self._db_conn or db_conn) + + def modify(self, changes: Optional[Association] = None, db_conn: Optional[Connection] = None, **kwargs: Any): + Associations({self.id: self}).modify(changes=changes, db_conn=(self._db_conn or db_conn), **kwargs) + @property def account(self): return cstr.to_unicode(self.ptr.acct) @@ -250,42 +375,6 @@ cdef class Association: def comment(self, val): cstr.fmalloc(&self.ptr.comment, val) - # uint32_t def_qos_id - - # uint16_t flags (ASSOC_FLAG_*) - - @property - def group_jobs(self): - return u32_parse(self.ptr.grp_jobs, zero_is_noval=False) - - @group_jobs.setter - def group_jobs(self, val): - self.ptr.grp_jobs = u32(val, zero_is_noval=False) - - @property - def group_jobs_accrue(self): - return u32_parse(self.ptr.grp_jobs_accrue, zero_is_noval=False) - - @group_jobs_accrue.setter - def group_jobs_accrue(self, val): - self.ptr.grp_jobs_accrue = u32(val, zero_is_noval=False) - - @property - def group_submit_jobs(self): - return u32_parse(self.ptr.grp_submit_jobs, zero_is_noval=False) - - @group_submit_jobs.setter - def group_submit_jobs(self, val): - self.ptr.grp_submit_jobs = u32(val, zero_is_noval=False) - - @property - def group_wall_time(self): - return u32_parse(self.ptr.grp_wall, zero_is_noval=False) - - @group_wall_time.setter - def group_wall_time(self, val): - self.ptr.grp_wall = u32(val, zero_is_noval=False) - @property def id(self): return u32_parse(self.ptr.id) @@ -294,58 +383,35 @@ cdef class Association: def id(self, val): self.ptr.id = val - @property - def is_default(self): - return u16_parse_bool(self.ptr.is_def) - - @property - def max_jobs(self): - return u32_parse(self.ptr.max_jobs, zero_is_noval=False) - - @max_jobs.setter - def max_jobs(self, val): - self.ptr.max_jobs = u32(val, zero_is_noval=False) - - @property - def max_jobs_accrue(self): - return u32_parse(self.ptr.max_jobs_accrue, zero_is_noval=False) - - @max_jobs_accrue.setter - def max_jobs_accrue(self, val): - self.ptr.max_jobs_accrue = u32(val, zero_is_noval=False) - - @property - def max_submit_jobs(self): - return u32_parse(self.ptr.max_submit_jobs, zero_is_noval=False) - - @max_submit_jobs.setter - def max_submit_jobs(self, val): - self.ptr.max_submit_jobs = u32(val, zero_is_noval=False) - @property - def max_wall_time_per_job(self): - return u32_parse(self.ptr.max_wall_pj, zero_is_noval=False) + # uint32_t def_qos_id - @max_wall_time_per_job.setter - def max_wall_time_per_job(self, val): - self.ptr.max_wall_pj = u32(val, zero_is_noval=False) + # uint16_t flags (ASSOC_FLAG_*) @property - def min_priority_threshold(self): - return u32_parse(self.ptr.min_prio_thresh, zero_is_noval=False) + def is_default(self): + return u16_parse_bool(self.ptr.is_def) - @min_priority_threshold.setter - def min_priority_threshold(self, val): - self.ptr.min_prio_thresh = u32(val, zero_is_noval=False) + @is_default.setter + def is_default(self, val): + self.ptr.is_def = u16_bool(val) @property def parent_account(self): return cstr.to_unicode(self.ptr.parent_acct) + @parent_account.setter + def parent_account(self, val): + cstr.fmalloc(&self.ptr.parent_acct, val) + @property def parent_account_id(self): return u32_parse(self.ptr.parent_id, zero_is_noval=False) + @property + def lineage(self): + return cstr.to_unicode(self.ptr.lineage) + @property def partition(self): return cstr.to_unicode(self.ptr.partition) @@ -354,22 +420,6 @@ cdef class Association: def partition(self, val): cstr.fmalloc(&self.ptr.partition, val) - @property - def priority(self): - return u32_parse(self.ptr.priority, zero_is_noval=False) - - @priority.setter - def priority(self, val): - self.ptr.priority = u32(val) - - @property - def shares(self): - return u32_parse(self.ptr.shares_raw, zero_is_noval=False) - - @shares.setter - def shares(self, val): - self.ptr.shares_raw = u32(val) - @property def user(self): return cstr.to_unicode(self.ptr.user) @@ -378,6 +428,10 @@ cdef class Association: def user(self, val): cstr.fmalloc(&self.ptr.user, val) + @property + def user_id(self): + return u32_parse(self.ptr.uid, zero_is_noval=False) + cdef _parse_assoc_ptr(Association ass): cdef: @@ -400,12 +454,25 @@ cdef _parse_assoc_ptr(Association ass): ass.ptr.max_tres_pn, tres) ass.qos = qos_list_to_pylist(ass.ptr.qos_list, qos) + ass.group_jobs = u32_parse(ass.ptr.grp_jobs, zero_is_noval=False) + ass.group_jobs_accrue = u32_parse(ass.ptr.grp_jobs_accrue, zero_is_noval=False) + ass.group_submit_jobs = u32_parse(ass.ptr.grp_submit_jobs, zero_is_noval=False) + ass.group_wall_time = u32_parse(ass.ptr.grp_wall, zero_is_noval=False) + ass.max_jobs = u32_parse(ass.ptr.max_jobs, zero_is_noval=False) + ass.max_jobs_accrue = u32_parse(ass.ptr.max_jobs_accrue, zero_is_noval=False) + ass.max_submit_jobs = u32_parse(ass.ptr.max_submit_jobs, zero_is_noval=False) + ass.max_wall_time_per_job = u32_parse(ass.ptr.max_wall_pj, zero_is_noval=False) + ass.min_priority_threshold = u32_parse(ass.ptr.min_prio_thresh, zero_is_noval=False) + ass.priority = u32_parse(ass.ptr.priority, zero_is_noval=False) + ass.shares = u32_parse(ass.ptr.shares_raw, zero_is_noval=False) + # TODO: default_qos + cdef _create_assoc_ptr(Association ass, conn=None): # _set_tres_limits will also check if specified TRES are valid and # translate them to its ID which is why we need to load the current TRES # available in the system. - ass.tres_data = TrackableResources.load(db_connection=conn) + ass.tres_data = conn.tres.load() _set_tres_limits(&ass.ptr.grp_tres, ass.group_tres, ass.tres_data) _set_tres_limits(&ass.ptr.grp_tres_mins, ass.group_tres_mins, ass.tres_data) @@ -423,6 +490,17 @@ cdef _create_assoc_ptr(Association ass, conn=None): # _set_qos_list will also check if specified QoS are valid and translate # them to its ID, which is why we need to load the current QOS available # in the system. - ass.qos_data = QualitiesOfService.load(db_connection=conn) - _set_qos_list(&ass.ptr.qos_list, self.qos, ass.qos_data) - + ass.qos_data = conn.qos.load() + _set_qos_list(&ass.ptr.qos_list, ass.qos, ass.qos_data) + + ass.ptr.grp_jobs = u32(ass.group_jobs, zero_is_noval=False) + ass.ptr.grp_jobs_accrue = u32(ass.group_jobs_accrue, zero_is_noval=False) + ass.ptr.grp_submit_jobs = u32(ass.group_submit_jobs, zero_is_noval=False) + ass.ptr.grp_wall = u32(ass.group_wall_time, zero_is_noval=False) + ass.ptr.max_jobs = u32(ass.max_jobs, zero_is_noval=False) + ass.ptr.max_jobs_accrue = u32(ass.max_jobs_accrue, zero_is_noval=False) + ass.ptr.max_submit_jobs = u32(ass.max_submit_jobs, zero_is_noval=False) + ass.ptr.max_wall_pj = u32(ass.max_wall_time_per_job, zero_is_noval=False) + ass.ptr.min_prio_thresh = u32(ass.min_priority_threshold, zero_is_noval=False) + ass.ptr.priority = u32(ass.priority) + ass.ptr.shares_raw = u32(ass.shares) diff --git a/pyslurm/db/connection.pxd b/pyslurm/db/connection.pxd index 6ac2dfc6..233c9a9c 100644 --- a/pyslurm/db/connection.pxd +++ b/pyslurm/db/connection.pxd @@ -31,6 +31,17 @@ from pyslurm.slurm cimport ( ) +cdef class ConnectionConfig: + cdef public: + transaction_mode + reuse_connection + + +cdef class ConnectionWrapper: + cdef: + Connection db_conn + + cdef class Connection: """A connection to the slurmdbd. @@ -41,3 +52,14 @@ cdef class Connection: cdef: void *ptr uint16_t flags + + cdef public: + config + + cdef readonly: + users + accounts + associations + tres + qos + jobs diff --git a/pyslurm/db/connection.pyx b/pyslurm/db/connection.pyx index 9e1a4428..d514dfa6 100644 --- a/pyslurm/db/connection.pyx +++ b/pyslurm/db/connection.pyx @@ -1,7 +1,7 @@ ######################################################################### # connection.pyx - pyslurm slurmdbd database connection ######################################################################### -# Copyright (C) 2023 Toni Harzendorf +# Copyright (C) 2026 Toni Harzendorf # # This file is part of PySlurm # @@ -22,17 +22,60 @@ # cython: c_string_type=unicode, c_string_encoding=default # cython: language_level=3 -from pyslurm.core.error import RPCError +from pyslurm.core.error import RPCError, PyslurmError +from contextlib import contextmanager +from pyslurm.db.user import UserAPI +from pyslurm.db.account import AccountAPI +from pyslurm.db.assoc import AssociationAPI +from pyslurm.db.tres import TrackableResourceAPI +from pyslurm.db.qos import QualityOfServiceAPI +from pyslurm.db.job import JobsAPI +from typing import Any, Optional +from pyslurm.utils.enums import StrEnum +from enum import auto -def _open_conn_or_error(conn): - if not conn: - conn = Connection.open() +class TransactionMode(StrEnum): + PER_OPERATION = auto() + MANUAL = auto() - if not conn.is_open: - raise ValueError("Database connection is not open") - return conn +cdef class ConnectionConfig: + + def __init__( + self, + transaction_mode: TransactionMode = TransactionMode.PER_OPERATION, + reuse_connection: bool = True, + ): + self.transaction_mode = transaction_mode + self.reuse_connection = reuse_connection + + +cdef class ConnectionWrapper: + + def __init__(self, db_conn: Connection): + self.db_conn = db_conn + + +class InvalidConnectionError(PyslurmError): + pass + + +class ConfigError(PyslurmError): + pass + + +@contextmanager +def connect(config: Optional[ConnectionConfig] = None, **kwargs: Any): + """A managed Slurm DB Connection""" + if config is not None and kwargs: + raise ConfigError("Must provide either a config directly, or kwargs, not both") + + connection = Connection.open(config, **kwargs) + try: + yield connection + finally: + connection.close() cdef class Connection: @@ -53,7 +96,36 @@ cdef class Connection: return f'pyslurm.db.{self.__class__.__name__} is {state}' @staticmethod - def open(): + def reuse( + reusable_conn: Optional[Connection] = None, + explicit_conn: Optional[Connection] = None + ): + if explicit_conn: + return explicit_conn + elif reusable_conn: + return reusable_conn + else: + raise InvalidConnectionError("No suitable Connection was provided") + + def apply_reuse(self, obj): + if self.config.reuse_connection: + obj._db_conn = self + + def validate(self): + if not self.is_open: + raise InvalidConnectionError("Connection is closed") + + def check_commit(self, rc): + if self.config.transaction_mode != TransactionMode.PER_OPERATION: + return + + if rc == slurm.SLURM_SUCCESS: + self.commit() + else: + self.rollback() + + @staticmethod + def open(config: Optional[ConnectionConfig] = None, **kwargs: Any): """Open a new connection to the slurmdbd Raises: @@ -68,11 +140,23 @@ cdef class Connection: >>> print(connection.is_open) True """ + if config is not None and kwargs: + raise ConfigError("Must provide either a config directly, or kwargs, not both") + cdef Connection conn = Connection.__new__(Connection) conn.ptr = slurmdb_connection_get(&conn.flags) if not conn.ptr: raise RPCError(msg="Failed to open onnection to slurmdbd") + conn.config = config or ConnectionConfig(**kwargs) + + # Initialize all DB APIs + conn.users = UserAPI(conn) + conn.accounts = AccountAPI(conn) + conn.associations = AssociationAPI(conn) + conn.tres = TrackableResourceAPI(conn) + conn.qos = QualityOfServiceAPI(conn) + conn.jobs = JobsAPI(conn) return conn def close(self): @@ -92,11 +176,17 @@ cdef class Connection: def commit(self): """Commit recent changes.""" + if not self.is_open: + raise InvalidConnectionError("Tried to commit when Connection is already closed.") + if slurmdb_connection_commit(self.ptr, 1) == slurm.SLURM_ERROR: raise RPCError("Failed to commit database changes.") def rollback(self): """Rollback recent changes.""" + if not self.is_open: + raise InvalidConnectionError("Tried to rollback when Connection is already closed.") + if slurmdb_connection_commit(self.ptr, 0) == slurm.SLURM_ERROR: raise RPCError("Failed to rollback database changes.") diff --git a/pyslurm/db/error.pyx b/pyslurm/db/error.pyx new file mode 100644 index 00000000..34d83ea9 --- /dev/null +++ b/pyslurm/db/error.pyx @@ -0,0 +1,200 @@ +######################################################################### +# error.pyx - pyslurm db specific errors +######################################################################### +# Copyright (C) 2025 Toni Harzendorf +# +# This file is part of PySlurm +# +# PySlurm is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 2 of the License, or +# (at your option) any later version. + +# PySlurm is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License along +# with PySlurm; if not, write to the Free Software Foundation, Inc., +# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. +# +# cython: c_string_type=unicode, c_string_encoding=default +# cython: language_level=3 + +from pyslurm.core.error import RPCError, slurm_errno, verify_rpc +from pyslurm.db.util cimport SlurmList, SlurmListItem +from pyslurm cimport slurm +import re + + +# The response involving assoc modification and deletions is a string that can +# be in the following form: +# +# C = X A = X U = X P = X +# +# And we have to parse this stuff... The Partition (P) is optional +assoc_str_pattern = re.compile(r'(\w)\s*=\s*(\w+)') + + +def handle_response(SlurmList response, rc): + out = [] + if not response.is_null and response.cnt: + if rc == slurm.ESLURM_JOBS_RUNNING_ON_ASSOC: + # The slurmdbd actually deletes the associations, even if + # Jobs are running. The client side must then decide whether to + # rollback or actually commit the changes. sacctmgr does the + # rollback. + + # By default, any errors are automatically rollbacked, for safety. + # User can disable this behaviour. + raise JobsRunningError.from_response(response, rc) + elif rc == slurm.ESLURM_NO_REMOVE_DEFAULT_ACCOUNT: + raise DefaultAccountError.from_response(response, rc) + + out = parse_basic_response(response) + elif not response.is_null or rc == slurm.SLURM_NO_CHANGE_IN_DATA: + # Nothing was modified + # TODO: Should this be an error actually? + pass + elif rc == slurm.ESLURM_INVALID_PARENT_ACCOUNT: + # When modifying Associations failed. + # TODO: proper error + verify_rpc(rc) + else: + # ESLURM_ONE_CHANGE - may happen when name of a user is attempted to be + # changed. only 1 user can be specified at a time. Could also detect + # that earlier + verify_rpc(rc) + + return out + + +class AssociationChangeInfo: + + def __init__(self, user, cluster, account, partition=None): + self.user = user + self.cluster = cluster + self.account = account + self.partition = partition + self.running_jobs = [] + + +class JobsRunningError(RPCError): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.associations = [] + + @staticmethod + def from_response(SlurmList response, rc): + running_jobs = parse_running_job_errors(response) + err = JobsRunningError(errno=rc) + err.associations = list(running_jobs.values()) + return err + + +class DefaultAccountError(RPCError): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.associations = [] + + @staticmethod + def from_response(SlurmList response, rc): + err = DefaultAccountError(errno=rc) + err.associations = parse_default_account_errors(response) + return err + + +def get_responses(SlurmList response): + cdef SlurmListItem response_ptr + + if response.is_null: + return [] + + for response_ptr in response: + response_str = response_ptr.to_str() + if response_str: + yield response_str + + +def parse_assoc_str(value): + matches = assoc_str_pattern.findall(value) + return dict(matches) + + +def get_assoc_response(SlurmList response): + for resp in get_responses(response): + yield parse_assoc_str(resp) + + +def parse_default_account_errors(SlurmList response): + assocs = [] + for item in get_assoc_response(response): + info = AssociationChangeInfo( + cluster = item["C"], + account = item["A"], + user = item["U"], + ) + assoc_str = f"{info.cluster}-{info.account}-{info.user}" + + if len(item) > 3: + info.partition = item["P"] + assoc_str = f"{assoc_str}-{info.partition}" + + assocs.append(info) + + return assocs + + +def parse_basic_response(SlurmList response): + return list(get_responses(response)) + + +def parse_running_job_errors(SlurmList response): + cdef SlurmListItem response_ptr + + running_jobs_for_assoc = {} + for response_ptr in response: + response_str = response_ptr.to_str() + if not response_str: + continue + + # The response is a string in the following form: + # JobId = X C = X A = X U = X P = X + # + # And we have to parse this stuff... The Partition (P) is optional + resp = response_str.rstrip().lstrip() + splitted = resp.split(" ") + values = [] + for item in splitted: + if not item: + continue + + key, value = item.split("=") + values.append(value.strip()) + + job_id = int(values[0]) + cluster = values[1] + account = values[2] + user = values[3] + partition = None + assoc_str = f"{cluster}-{account}-{user}" + + if len(values) > 4: + partition = values[4] + assoc_str = f"{assoc_str}-{partition}" + + if assoc_str not in running_jobs_for_assoc: + info = AssociationChangeInfo( + user = user, + cluster = cluster, + account = account, + partition = partition, + ) + running_jobs_for_assoc[assoc_str] = info + + running_jobs_for_assoc[assoc_str].running_jobs.append(job_id) + + return running_jobs_for_assoc diff --git a/pyslurm/db/job.pxd b/pyslurm/db/job.pxd index f3593c2f..4688583d 100644 --- a/pyslurm/db/job.pxd +++ b/pyslurm/db/job.pxd @@ -50,7 +50,7 @@ from pyslurm.db.util cimport ( ) from pyslurm.db.step cimport JobStep, JobSteps from pyslurm.db.stats cimport JobStatistics -from pyslurm.db.connection cimport Connection +from pyslurm.db.connection cimport Connection, ConnectionWrapper from pyslurm.utils cimport cstr from pyslurm.db.qos cimport QualitiesOfService from pyslurm.db.tres cimport ( @@ -63,6 +63,10 @@ from pyslurm.utils.uint cimport u32_parse_bool_flag from libc.stdint cimport uint32_t +cdef class JobsAPI(ConnectionWrapper): + pass + + cdef class JobFilter: """Query-Conditions for Jobs in the Slurm Database. @@ -140,6 +144,7 @@ cdef class JobFilter: cdef slurmdb_job_cond_t *ptr cdef public: + Connection _db_conn ids start_time end_time @@ -183,6 +188,7 @@ cdef class Jobs(MultiClusterMap): Total amount of requested memory in Mebibytes. """ cdef public: + Connection _db_conn stats cpus nodes @@ -364,6 +370,7 @@ cdef class Job: TrackableResources tres_data cdef public: + Connection _db_conn JobSteps steps JobStatistics stats diff --git a/pyslurm/db/job.pyx b/pyslurm/db/job.pyx index 3d580356..9890b3a9 100644 --- a/pyslurm/db/job.pyx +++ b/pyslurm/db/job.pyx @@ -23,7 +23,12 @@ # cython: language_level=3 from typing import Union, Any -from pyslurm.core.error import RPCError, PyslurmError +from pyslurm.core.error import ( + RPCError, + ArgumentError, + NotFoundError, + _get_modify_arguments_for, +) from pyslurm.utils.uint import * from pyslurm import settings from pyslurm import xcollections @@ -42,8 +47,213 @@ from pyslurm.utils.helpers import ( _get_exit_code, gres_from_tres_dict, ) -from pyslurm.db.connection import _open_conn_or_error from pyslurm.enums import SchedulerType +from typing import Any, Optional + + +cdef class JobsAPI(ConnectionWrapper): + + def load(self, db_filter: Optional[JobFilter] = None): + """Load Jobs from the Slurm Database + + Implements the slurmdb_jobs_get RPC. + + Args: + db_filter (pyslurm.db.JobFilter): + A search filter that the slurmdbd will apply when retrieving + Jobs from the database. + + Returns: + (pyslurm.db.Jobs): A Collection of database Jobs. + + Raises: + (pyslurm.RPCError): When getting the Jobs from the Database was not + successful + + Examples: + Without a Filter the default behaviour applies, which is + simply retrieving all Jobs from the same day: + + >>> import pyslurm + >>> db_jobs = pyslurm.db.Jobs.load() + >>> print(db_jobs) + pyslurm.db.Jobs({1: pyslurm.db.Job(1), 2: pyslurm.db.Job(2)}) + >>> print(db_jobs[1]) + pyslurm.db.Job(1) + + Now with a Job Filter, so only Jobs that have specific Accounts + are returned: + + >>> import pyslurm + >>> accounts = ["acc1", "acc2"] + >>> db_filter = pyslurm.db.JobFilter(accounts=accounts) + >>> db_jobs = pyslurm.db.Jobs.load(db_filter) + """ + cdef: + Jobs out = Jobs() + Job job + JobFilter cond = db_filter + SlurmList job_data + SlurmListItem job_ptr + QualitiesOfService qos_data + TrackableResources tres_data + + self.db_conn.validate() + + # Prepare SQL Filter + if not db_filter: + db_filter = JobFilter() + + db_filter._db_conn = self.db_conn + db_filter._create() + + # Fetch Job data + job_data = SlurmList.wrap(slurmdb_jobs_get(self.db_conn.ptr, db_filter.ptr)) + if job_data.is_null: + raise RPCError(msg="Failed to get Jobs from slurmdbd") + + # Fetch other necessary dependencies needed for translating some + # attributes (i.e QoS IDs to its name) + qos_data = self.db_conn.qos.load(name_is_key=False) + tres_data = self.db_conn.tres.load() + + # TODO: How to handle the possibility of duplicate job ids that could + # appear if IDs on a cluster are reset? + for job_ptr in SlurmList.iter_and_pop(job_data): + job = Job.from_ptr(job_ptr.data) + job.qos_data = qos_data + job.tres_data = tres_data + job._create_steps() + job.stats = JobStatistics.from_steps(job.steps) + self.db_conn.apply_reuse(job) + + elapsed = job.elapsed_time if job.elapsed_time else 0 + cpus = job.cpus if job.cpus else 1 + job.stats.elapsed_cpu_time = elapsed * cpus + + cluster = job.cluster + if cluster not in out.data: + out.data[cluster] = {} + out[cluster][job.id] = job + + out._add_stats(job) + + return out + + def modify( + self, + db_filter: Union[JobFilter, Jobs], + changes: Optional[Job] = None, + **kwargs: Any + ): + """Modify Slurm database Jobs. + + Implements the slurm_job_modify RPC. + + Args: + db_filter (Union[pyslurm.db.JobFilter, pyslurm.db.Jobs]): + A filter to decide which Jobs should be modified. + changes (pyslurm.db.Job): + Another [pyslurm.db.Job][] object that contains all the + changes to apply. Check the `Other Parameters` of the + [pyslurm.db.Job][] class to see which properties can be + modified. + + Returns: + (list[int]): A list of Jobs that were modified + + Raises: + (pyslurm.RPCError): When a failure modifying the Jobs occurred. + + Examples: + In its simplest form, you can do something like this: + + >>> import pyslurm + >>> + >>> db_filter = pyslurm.db.JobFilter(ids=[9999]) + >>> changes = pyslurm.db.Job(comment="A comment for the job") + >>> modified_jobs = pyslurm.db.Jobs.modify(db_filter, changes) + >>> print(modified_jobs) + [9999] + + In the above example, the changes will be automatically committed + if successful. + You can however also control this manually by providing your own + connection object: + + >>> import pyslurm + >>> + >>> db_conn = pyslurm.db.Connection.open() + >>> db_filter = pyslurm.db.JobFilter(ids=[9999]) + >>> changes = pyslurm.db.Job(comment="A comment for the job") + >>> modified_jobs = pyslurm.db.Jobs.modify( + ... db_filter, changes, db_conn) + + Now you can first examine which Jobs have been modified: + + >>> print(modified_jobs) + [9999] + + And then you can actually commit the changes: + + >>> db_conn.commit() + + You can also explicitly rollback these changes instead of + committing, so they will not become active: + + >>> db_conn.rollback() + """ + cdef: + JobFilter cond + Job _changes + SlurmList response + SlurmListItem response_ptr + list out = [] + + _changes = _get_modify_arguments_for(Job, changes, **kwargs) + + self.db_conn.validate() + + # Prepare SQL Filter + if isinstance(db_filter, Jobs): + job_ids = [job.id for job in self] + cond = JobFilter(ids=job_ids) + else: + cond = db_filter + + cond._db_conn = self.db_conn + cond._create() + + # Modify Jobs, get the result + # This returns a List of char* with the Jobs ids that were + # modified + response = SlurmList.wrap( + slurmdb_job_modify(self.db_conn.ptr, cond.ptr, _changes.ptr)) + + if not response.is_null and response.cnt: + for response_ptr in response: + response_str = cstr.to_unicode(response_ptr.data) + if not response_str: + continue + + # The strings in the list returned above have a structure + # like this: + # + # " submitted at " + # + # We are just interested in the Job-ID, so extract it + job_id = response_str.split(" ")[0] + if job_id and job_id.isdigit(): + out.append(int(job_id)) + + elif not response.is_null: + # There was no real error, but simply nothing has been modified + raise RPCError(msg="Nothing was modified") + else: + # Autodetects the last slurm error + raise RPCError() + + return out cdef class JobFilter: @@ -54,6 +264,7 @@ cdef class JobFilter: def __init__(self, **kwargs): for k, v in kwargs.items(): setattr(self, k, v) + self._db_conn = None def __dealloc__(self): self._dealloc() @@ -76,7 +287,7 @@ cdef class JobFilter: return None qos_id_list = [] - qos_data = QualitiesOfService.load() + qos_data = self._db_conn.qos.load(self._db_conn, name_is_key=True) for user_input in self.qos: found = False for qos in qos_data.values(): @@ -200,9 +411,10 @@ cdef class Jobs(MultiClusterMap): id_attr=Job.id, key_type=int) self._reset_stats() + self._db_conn = None @staticmethod - def load(JobFilter db_filter=None, Connection db_connection=None): + def load(Connection db_conn, JobFilter db_filter=None): """Load Jobs from the Slurm Database Implements the slurmdb_jobs_get RPC. @@ -211,9 +423,8 @@ cdef class Jobs(MultiClusterMap): db_filter (pyslurm.db.JobFilter): A search filter that the slurmdbd will apply when retrieving Jobs from the database. - db_connection (pyslurm.db.Connection): - An open database connection. By default if none is specified, - one will be opened automatically. + db_conn (pyslurm.db.Connection): + An open database connection. Returns: (pyslurm.db.Jobs): A Collection of database Jobs. @@ -241,56 +452,37 @@ cdef class Jobs(MultiClusterMap): >>> db_filter = pyslurm.db.JobFilter(accounts=accounts) >>> db_jobs = pyslurm.db.Jobs.load(db_filter) """ - cdef: - Jobs out = Jobs() - Job job - JobFilter cond = db_filter - SlurmList job_data - SlurmListItem job_ptr - Connection conn - QualitiesOfService qos_data - TrackableResources tres_data + return db_conn.jobs.load(db_filter) - # Prepare SQL Filter - if not db_filter: - cond = JobFilter() - cond._create() - - # Setup DB Conn - conn = _open_conn_or_error(db_connection) - - # Fetch Job data - job_data = SlurmList.wrap(slurmdb_jobs_get(conn.ptr, cond.ptr)) - if job_data.is_null: - raise RPCError(msg="Failed to get Jobs from slurmdbd") - - # Fetch other necessary dependencies needed for translating some - # attributes (i.e QoS IDs to its name) - qos_data = QualitiesOfService.load(db_connection=conn, - name_is_key=False) - tres_data = TrackableResources.load(db_connection=conn) - - # TODO: How to handle the possibility of duplicate job ids that could - # appear if IDs on a cluster are reset? - for job_ptr in SlurmList.iter_and_pop(job_data): - job = Job.from_ptr(job_ptr.data) - job.qos_data = qos_data - job.tres_data = tres_data - job._create_steps() - job.stats = JobStatistics.from_steps(job.steps) - - elapsed = job.elapsed_time if job.elapsed_time else 0 - cpus = job.cpus if job.cpus else 1 - job.stats.elapsed_cpu_time = elapsed * cpus + def modify( + self, + changes: Optional[Job] = None, + db_conn: Optional[Connection] = None, + **kwargs: Any + ): + """Modify all Database Jobs in this collection. - cluster = job.cluster - if cluster not in out.data: - out.data[cluster] = {} - out[cluster][job.id] = job + Args: + changes (pyslurm.db.Job): + Another [pyslurm.db.Job][] object that contains all the + changes to apply. Check the `Other Parameters` of the + [pyslurm.db.Job][] class to see which properties can be + modified. + db_conn (pyslurm.db.Connection): + A Connection to the slurmdbd. + **kwargs (Any): + Instead of providing a separate `Job` object that has the + changes, you can also pass them as keyword args. - out._add_stats(job) + Returns: + (list[int]): A list of Jobs that were modified - return out + Raises: + (pyslurm.RPCError): When a failure modifying the Jobs occurred. + """ + db_conn = Connection.reuse(self._db_conn, db_conn) + db_filter = JobFilter(ids=list(self.keys())) + return db_conn.jobs.modify(db_filter=db_filter, changes=changes, **kwargs) def _reset_stats(self): self.stats = JobStatistics() @@ -310,134 +502,6 @@ cdef class Jobs(MultiClusterMap): for job in self.values(): self._add_stats(job) - @staticmethod - def modify(db_filter, Job changes, db_connection=None): - """Modify Slurm database Jobs. - - Implements the slurm_job_modify RPC. - - Args: - db_filter (Union[pyslurm.db.JobFilter, pyslurm.db.Jobs]): - A filter to decide which Jobs should be modified. - changes (pyslurm.db.Job): - Another [pyslurm.db.Job][] object that contains all the - changes to apply. Check the `Other Parameters` of the - [pyslurm.db.Job][] class to see which properties can be - modified. - db_connection (pyslurm.db.Connection): - A Connection to the slurmdbd. By default, if no connection is - supplied, one will automatically be created internally. This - means that when the changes were considered successful by the - slurmdbd, those modifications will be **automatically - committed**. - - If you however decide to provide your own Connection instance - (which must be already opened before), and the changes were - successful, they will basically be in a kind of "staging - area". By the time this function returns, the changes are not - actually made. - You are then responsible to decide whether the changes should - be committed or rolled back by using the respective methods on - the connection object. This way, you have a chance to see - which Jobs were modified before you commit the changes. - - Returns: - (list[int]): A list of Jobs that were modified - - Raises: - (pyslurm.RPCError): When a failure modifying the Jobs occurred. - - Examples: - In its simplest form, you can do something like this: - - >>> import pyslurm - >>> - >>> db_filter = pyslurm.db.JobFilter(ids=[9999]) - >>> changes = pyslurm.db.Job(comment="A comment for the job") - >>> modified_jobs = pyslurm.db.Jobs.modify(db_filter, changes) - >>> print(modified_jobs) - [9999] - - In the above example, the changes will be automatically committed - if successful. - You can however also control this manually by providing your own - connection object: - - >>> import pyslurm - >>> - >>> db_conn = pyslurm.db.Connection.open() - >>> db_filter = pyslurm.db.JobFilter(ids=[9999]) - >>> changes = pyslurm.db.Job(comment="A comment for the job") - >>> modified_jobs = pyslurm.db.Jobs.modify( - ... db_filter, changes, db_conn) - - Now you can first examine which Jobs have been modified: - - >>> print(modified_jobs) - [9999] - - And then you can actually commit the changes: - - >>> db_conn.commit() - - You can also explicitly rollback these changes instead of - committing, so they will not become active: - - >>> db_conn.rollback() - """ - cdef: - JobFilter cond - Connection conn - SlurmList response - SlurmListItem response_ptr - list out = [] - - # Prepare SQL Filter - if isinstance(db_filter, Jobs): - job_ids = [job.id for job in self] - cond = JobFilter(ids=job_ids) - else: - cond = db_filter - cond._create() - - # Setup DB Conn - conn = _open_conn_or_error(db_connection) - - # Modify Jobs, get the result - # This returns a List of char* with the Jobs ids that were - # modified - response = SlurmList.wrap( - slurmdb_job_modify(conn.ptr, cond.ptr, changes.ptr)) - - if not response.is_null and response.cnt: - for response_ptr in response: - response_str = cstr.to_unicode(response_ptr.data) - if not response_str: - continue - - # The strings in the list returned above have a structure - # like this: - # - # " submitted at " - # - # We are just interested in the Job-ID, so extract it - job_id = response_str.split(" ")[0] - if job_id and job_id.isdigit(): - out.append(int(job_id)) - - elif not response.is_null: - # There was no real error, but simply nothing has been modified - raise RPCError(msg="Nothing was modified") - else: - # Autodetects the last slurm error - raise RPCError() - - if not db_connection: - # Autocommit if no connection was explicitly specified. - conn.commit() - - return out - cdef class Job: @@ -474,10 +538,18 @@ cdef class Job: return wrap @staticmethod - def load(job_id, cluster=None, with_script=False, with_env=False): + def load( + db_conn: Connection, + job_id: int, + cluster = None, + with_script = False, + with_env = False + ): """Load the information for a specific Job from the Database. Args: + db_conn (pyslurm.db.Connection): + A slurmdbd connection. job_id (int): ID of the Job to be loaded. cluster (str): @@ -494,27 +566,15 @@ cdef class Job: (pyslurm.db.Job): Returns a new Database Job instance Raises: - (pyslurm.RPCError): If requesting the information for the database + (pyslurm.NotFoundError): If requesting the information for the database Job was not successful. - - Examples: - >>> import pyslurm - >>> db_job = pyslurm.db.Job.load(10000) - - In the above example, attributes like `script` and `environment` - are not populated. You must explicitly request one of them to be - loaded: - - >>> import pyslurm - >>> db_job = pyslurm.db.Job.load(10000, with_script=True) - >>> print(db_job.script) """ cluster = settings.LOCAL_CLUSTER if not cluster else cluster jfilter = JobFilter(ids=[int(job_id)], clusters=[cluster], with_script=with_script, with_env=with_env) - job = Jobs.load(jfilter).get((cluster, int(job_id))) + job = db_conn.jobs.load(jfilter).get((cluster, int(job_id))) if not job: - raise RPCError(msg=f"Job {job_id} does not exist on " + raise NotFoundError(msg=f"Job {job_id} does not exist on " f"Cluster {cluster}") # TODO: There might be multiple entries when job ids were reset. @@ -531,6 +591,8 @@ cdef class Job: step = JobStep.from_ptr(step_ptr.data) step.tres_data = self.tres_data self.steps[step.id] = step + # TODO: + # self.db_conn.apply_reuse(step) def as_dict(self): return self.to_dict() @@ -551,8 +613,13 @@ cdef class Job: def __repr__(self): return f'pyslurm.db.{self.__class__.__name__}({self.id})' - def modify(self, changes, db_connection=None): - """Modify a Slurm database Job. + def modify( + self, + changes: Optional[Job] = None, + db_conn: Optional[Connection] = None, + **kwargs: Any + ): + """Modify this Database Job. Args: changes (pyslurm.db.Job): @@ -560,7 +627,7 @@ cdef class Job: changes to apply. Check the `Other Parameters` of the [pyslurm.db.Job][] class to see which properties can be modified. - db_connection (pyslurm.db.Connection): + db_conn (pyslurm.db.Connection): A slurmdbd connection. See [pyslurm.db.Jobs.modify][pyslurm.db.job.Jobs.modify] for more info on this parameter. @@ -568,8 +635,8 @@ cdef class Job: Raises: (pyslurm.RPCError): When modifying the Job failed. """ - cdef JobFilter jfilter = JobFilter(ids=[self.id]) - Jobs.modify(jfilter, changes, db_connection) + jobs = Jobs({self.id: self}) + jobs.modify(changes=changes, db_conn=(self._db_conn or db_conn), **kwargs) @property def account(self): @@ -860,10 +927,6 @@ cdef class Job: def wckey_id(self): return u32_parse(self.ptr.wckeyid) -# @property -# def wckey_id(self): -# return u32_parse(self.ptr.wckeyid) - @property def working_directory(self): return cstr.to_unicode(self.ptr.work_dir) diff --git a/pyslurm/db/qos.pxd b/pyslurm/db/qos.pxd index 02131974..d69d3ed5 100644 --- a/pyslurm/db/qos.pxd +++ b/pyslurm/db/qos.pxd @@ -38,17 +38,22 @@ from pyslurm.db.util cimport ( SlurmListItem, make_char_list, ) -from pyslurm.db.connection cimport Connection +from pyslurm.db.connection cimport Connection, ConnectionWrapper from pyslurm.utils cimport cstr from pyslurm.utils.uint cimport u16_set_bool_flag cdef _set_qos_list(list_t **in_list, vals, QualitiesOfService data) -cdef class QualitiesOfService(dict): +cdef class QualityOfServiceAPI(ConnectionWrapper): pass +cdef class QualitiesOfService(dict): + cdef public: + Connection _db_conn + + cdef class QualityOfServiceFilter: cdef slurmdb_qos_cond_t *ptr @@ -61,6 +66,9 @@ cdef class QualityOfServiceFilter: cdef class QualityOfService: + cdef public: + Connection _db_conn + cdef slurmdb_qos_rec_t *ptr @staticmethod diff --git a/pyslurm/db/qos.pyx b/pyslurm/db/qos.pyx index 9e24189b..80fa7805 100644 --- a/pyslurm/db/qos.pyx +++ b/pyslurm/db/qos.pyx @@ -22,19 +22,24 @@ # cython: c_string_type=unicode, c_string_encoding=default # cython: language_level=3 -from pyslurm.core.error import RPCError +from pyslurm.core.error import ( + RPCError, + slurm_errno, + verify_rpc, + NotFoundError, + _get_modify_arguments_for, +) from pyslurm.utils.helpers import instance_to_dict -from pyslurm.db.connection import _open_conn_or_error +from typing import Any, Union, Optional, List, Dict -cdef class QualitiesOfService(dict): - - def __init__(self): - pass +cdef class QualityOfServiceAPI(ConnectionWrapper): - @staticmethod - def load(QualityOfServiceFilter db_filter=None, - Connection db_connection=None, name_is_key=True): + def load( + self, + db_filter: Optional[QualityOfServiceFilter] = None, + name_is_key: bool = True + ): """Load QoS data from the Database Args: @@ -46,21 +51,16 @@ cdef class QualitiesOfService(dict): cdef: QualitiesOfService out = QualitiesOfService() QualityOfService qos - QualityOfServiceFilter cond = db_filter SlurmList qos_data SlurmListItem qos_ptr - Connection conn - # Prepare SQL Filter - if not db_filter: - cond = QualityOfServiceFilter() - cond._create() + self.db_conn.validate() - # Setup DB Conn - conn = _open_conn_or_error(db_connection) + if not db_filter: + db_filter = QualityOfServiceFilter() + db_filter._create() - # Fetch QoS Data - qos_data = SlurmList.wrap(slurmdb_qos_get(conn.ptr, cond.ptr)) + qos_data = SlurmList.wrap(slurmdb_qos_get(self.db_conn.ptr, db_filter.ptr)) if qos_data.is_null: raise RPCError(msg="Failed to get QoS data from slurmdbd") @@ -68,18 +68,44 @@ cdef class QualitiesOfService(dict): # Setup QOS objects for qos_ptr in SlurmList.iter_and_pop(qos_data): qos = QualityOfService.from_ptr(qos_ptr.data) + self.db_conn.apply_reuse(qos) _id = qos.name if name_is_key else qos.id out[_id] = qos return out +cdef class QualitiesOfService(dict): + + def __init__(self, qos={}, **kwargs: Any): + super().__init__() + self.update(qos) + self.update(kwargs) + self._db_conn = None + + @staticmethod + def load( + db_conn: Connection, + db_filter: Optional[Connection] = None, + name_is_key: bool = True + ): + """Load QoS data from the Database + + Args: + name_is_key (bool, optional): + By default, the keys in this dict are the names of each QoS. + If this is set to `False`, then the unique ID of the QoS will + be used as dict keys. + """ + return db_conn.qos.load(db_filter=db_filter, name_is_key=name_is_key) + + cdef class QualityOfServiceFilter: def __cinit__(self): self.ptr = NULL - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any): for k, v in kwargs.items(): setattr(self, k, v) @@ -133,9 +159,11 @@ cdef class QualityOfService: def __cinit__(self): self.ptr = NULL - def __init__(self, name=None): + def __init__(self, name: str = None, **kwargs: Any): self._alloc_impl() self.name = name + for k, v in kwargs.items(): + setattr(self, k, v) def __dealloc__(self): self._dealloc_impl() @@ -160,7 +188,7 @@ cdef class QualityOfService: def __repr__(self): return f'pyslurm.db.{self.__class__.__name__}({self.name})' - def to_dict(self, recursive = False): + def to_dict(self, recursive: bool = False): """Database QualityOfService information formatted as a dictionary. Returns: @@ -169,7 +197,7 @@ cdef class QualityOfService: return instance_to_dict(self, recursive) @staticmethod - def load(name): + def load(db_conn: Connection, name: str): """Load the information for a specific Quality of Service. Args: @@ -184,10 +212,9 @@ cdef class QualityOfService: (pyslurm.RPCError): If requesting the information from the database was not successful. """ - qfilter = QualityOfServiceFilter(names=[name]) - qos = QualitiesOfService.load(qfilter).get(name) + qos = db_conn.qos.load(name_is_key=True).get(name) if not qos: - raise RPCError(msg=f"QualityOfService {name} does not exist") + raise NotFoundError(msg=f"QualityOfService {name} does not exist") return qos diff --git a/pyslurm/db/tres.pxd b/pyslurm/db/tres.pxd index 952c8a7d..0dd7a5c9 100644 --- a/pyslurm/db/tres.pxd +++ b/pyslurm/db/tres.pxd @@ -37,7 +37,7 @@ from pyslurm.db.util cimport ( SlurmList, SlurmListItem, ) -from pyslurm.db.connection cimport Connection +from pyslurm.db.connection cimport Connection, ConnectionWrapper cdef find_tres_count(char *tres_str, typ, on_noval=*, on_inf=*) cdef find_tres_limit(char *tres_str, typ) @@ -46,6 +46,10 @@ cdef _tres_ids_to_names(char *tres_str, dict tres_id_map) cdef _set_tres_limits(char **dest, src, tres_data) +cdef class TrackableResourceAPI(ConnectionWrapper): + pass + + cdef class FilesystemResources(dict): """Collection of Filesystem TRES. This inherits from `dict`.""" pass diff --git a/pyslurm/db/tres.pyx b/pyslurm/db/tres.pyx index 598aee91..faba34d9 100644 --- a/pyslurm/db/tres.pyx +++ b/pyslurm/db/tres.pyx @@ -28,7 +28,6 @@ from pyslurm.constants import UNLIMITED from pyslurm.core.error import RPCError from pyslurm.utils.helpers import instance_to_dict, dehumanize from pyslurm.utils import cstr -from pyslurm.db.connection import _open_conn_or_error from pyslurm import xcollections import json import re @@ -43,6 +42,37 @@ TRES_NAME_REQUIRED = ["fs", "license", "interconnect", "gres"] gres_pattern = re.compile(r'[/:]') +cdef class TrackableResourceAPI(ConnectionWrapper): + + def load(self, db_filter: TrackableResourceFilter = None): + """Load Trackable Resources from the Database.""" + cdef: + TrackableResources out = TrackableResources() + TrackableResource tres + SlurmList tres_data + SlurmListItem tres_ptr + + self.db_conn.validate() + + if not db_filter: + db_filter = TrackableResourceFilter() + + db_filter._create() + tres_data = SlurmList.wrap(slurmdb_tres_get(self.db_conn.ptr, db_filter.ptr)) + + if tres_data.is_null: + raise RPCError(msg="Failed to get TRES data from slurmdbd") + + # Setup TRES objects + for tres_ptr in SlurmList.iter_and_pop(tres_data): + tres = TrackableResource.from_ptr( + tres_ptr.data) + out._handle_tres_type(tres) + out._id_map[tres.id] = tres + + return out + + cdef class FilesystemResources(dict): def to_dict(self, recursive=False): @@ -292,40 +322,12 @@ cdef class TrackableResources: elif hasattr(self, tres.type): setattr(self, tres.type, tres) elif tres.type: - print(tres.type, tres.name, tres.type_and_name) self.other[tres.type_and_name] = tres @staticmethod - def load(Connection db_connection=None): + def load(Connection db_conn): """Load Trackable Resources from the Database.""" - cdef: - TrackableResources out = TrackableResources() - TrackableResource tres - Connection conn - SlurmList tres_data - SlurmListItem tres_ptr - TrackableResourceFilter db_filter = TrackableResourceFilter() - - # Prepare SQL Filter - db_filter._create() - - # Setup DB Conn - conn = _open_conn_or_error(db_connection) - - # Fetch TRES data - tres_data = SlurmList.wrap(slurmdb_tres_get(conn.ptr, db_filter.ptr)) - - if tres_data.is_null: - raise RPCError(msg="Failed to get TRES data from slurmdbd") - - # Setup TRES objects - for tres_ptr in SlurmList.iter_and_pop(tres_data): - tres = TrackableResource.from_ptr( - tres_ptr.data) - out._handle_tres_type(tres) - out._id_map[tres.id] = tres - - return out + return db_conn.tres.load() @staticmethod cdef find_count_in_str(char *tres_str, typ, on_noval=0, on_inf=0): @@ -490,4 +492,7 @@ def _validate_tres_single(local_tres, dict tres_id_map): cdef _set_tres_limits(char **dest, src, tres_data): + if not src: + return + # TODO: Allow users to pass a dict cstr.from_dict(dest, src._validate(tres_data)) diff --git a/pyslurm/db/user.pxd b/pyslurm/db/user.pxd new file mode 100644 index 00000000..d5a27091 --- /dev/null +++ b/pyslurm/db/user.pxd @@ -0,0 +1,112 @@ +######################################################################### +# user.pxd - pyslurm slurmdbd user api +######################################################################### +# Copyright (C) 2025 Toni Harzendorf +# +# This file is part of PySlurm +# +# PySlurm is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 2 of the License, or +# (at your option) any later version. + +# PySlurm is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License along +# with PySlurm; if not, write to the Free Software Foundation, Inc., +# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. +# +# cython: c_string_type=unicode, c_string_encoding=default +# cython: language_level=3 + +from libc.string cimport memcpy, memset +from pyslurm cimport slurm +from pyslurm.slurm cimport ( + slurmdb_user_rec_t, + slurmdb_assoc_rec_t, + slurmdb_assoc_cond_t, + slurmdb_user_cond_t, + slurmdb_users_get, + slurmdb_users_add, + slurmdb_users_modify, + slurmdb_users_remove, + slurmdb_destroy_user_rec, + slurmdb_destroy_user_cond, + try_xmalloc, +) +from pyslurm.db.util cimport ( + SlurmList, + SlurmListItem, + make_char_list, + slurm_list_to_pylist, + qos_list_to_pylist, +) +from pyslurm.db.tres cimport ( + _set_tres_limits, + TrackableResources, +) +from pyslurm.db.connection cimport Connection, ConnectionWrapper +from pyslurm.utils cimport cstr +from pyslurm.db.qos cimport QualitiesOfService, _set_qos_list +from pyslurm.db.assoc cimport Associations, Association, _parse_assoc_ptr, AssociationFilter, AssociationList +from pyslurm.xcollections cimport MultiClusterMap +from pyslurm.utils.uint cimport u16_set_bool_flag + + +cdef class UserAPI(ConnectionWrapper): + pass + + +cdef class Users(dict): + cdef public: + Connection _db_conn + + +cdef class UserFilter: + cdef slurmdb_user_cond_t *ptr + + cdef public: + names + with_assocs + with_coordinators + with_wckeys + with_deleted + associations + + +cdef class User: + """Slurm Database User + + Attributes: + name (str): + The name of the User. + previous_name (str): + Previous name of the User, in case it was modified before. + user_id (int): + UID of the User. + default_account (str): + Default Account of the User. + default_wckey (str): + Default WCKey for the User. + is_deleted (bool): + Whether this User has been deleted or not. + admin_level (pyslurm.AdminLevel): + Admin Level of the User. + """ + cdef: + slurmdb_user_rec_t *ptr + + cdef public: + associations + coordinators + wckeys + Connection _db_conn + + cdef readonly: + default_association + + @staticmethod + cdef User from_ptr(slurmdb_user_rec_t *in_ptr) diff --git a/pyslurm/db/user.pyx b/pyslurm/db/user.pyx new file mode 100644 index 00000000..dae61424 --- /dev/null +++ b/pyslurm/db/user.pyx @@ -0,0 +1,386 @@ +######################################################################### +# user.pyx - pyslurm slurmdbd user api +######################################################################### +# Copyright (C) 2026 Toni Harzendorf +# +# This file is part of PySlurm +# +# PySlurm is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 2 of the License, or +# (at your option) any later version. + +# PySlurm is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License along +# with PySlurm; if not, write to the Free Software Foundation, Inc., +# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. +# +# cython: c_string_type=unicode, c_string_encoding=default +# cython: language_level=3 + +from pyslurm.core.error import ( + RPCError, + slurm_errno, + verify_rpc, + NotFoundError, + _get_modify_arguments_for, +) +from pyslurm.utils.helpers import ( + instance_to_dict, + user_to_uid, +) +from pyslurm.utils.uint import * +from pyslurm import xcollections +from pyslurm.utils.enums import SlurmEnum +from pyslurm.db.error import handle_response +from pyslurm.enums import AdminLevel +from typing import Any, Union, Optional, List, Dict + + +cdef class UserAPI(ConnectionWrapper): + + def load(self, db_filter: Optional[UserFilter] = None): + cdef: + Users out = Users() + UserFilter cond = db_filter + SlurmListItem user_ptr + SlurmListItem assoc_ptr + + self.db_conn.validate() + + if not db_filter: + cond = UserFilter() + + if cond.with_assocs is not False: + # If not explicitly disabled, always fetch the Associations of a + # User. + cond.with_assocs = True + cond._create() + + user_data = SlurmList.wrap(slurmdb_users_get(self.db_conn.ptr, cond.ptr)) + + if user_data.is_null: + raise RPCError(msg="Failed to get User data from slurmdbd") + + qos_data = self.db_conn.qos.load(name_is_key=False) + tres_data = self.db_conn.tres.load() + + for user_ptr in SlurmList.iter_and_pop(user_data): + user = User.from_ptr(user_ptr.data) + self.db_conn.apply_reuse(user) + out[user.name] = user + + assoc_data = SlurmList.wrap(user.ptr.assoc_list, owned=False) + for assoc_ptr in SlurmList.iter_and_pop(assoc_data): + assoc = Association.from_ptr(assoc_ptr.data) + assoc.qos_data = qos_data + assoc.tres_data = tres_data + _parse_assoc_ptr(assoc) + user.associations.append(assoc) + self.db_conn.apply_reuse(assoc) + + if assoc.user == user.name: + user.default_association = assoc + + self.db_conn.apply_reuse(out) + return out + + def delete(self, db_filter: UserFilter): + out = [] + + # TODO: test again when this is empty, does it really delete everything? + if not db_filter.names: + return + + self.db_conn.validate() + db_filter._create() + + response = SlurmList.wrap(slurmdb_users_remove(self.db_conn.ptr, db_filter.ptr)) + rc = slurm_errno() + self.db_conn.check_commit(rc) + return handle_response(response, rc) + + def modify(self, db_filter: UserFilter, changes: Optional[User] = None, **kwargs: Any): + cdef: + User _changes + SlurmListItem response_ptr + + # TODO: Properly check if the filter is empty, cause it will then probably + # target all users. Or maybe that is fine and we need to clearly document + # to take caution + #if not db_filter.names: + # return + + _changes = _get_modify_arguments_for(User, changes, **kwargs) + + self.db_conn.validate() + db_filter._create() + + response = SlurmList.wrap(slurmdb_users_modify( + self.db_conn.ptr, db_filter.ptr, _changes.ptr) + ) + rc = slurm_errno() + self.db_conn.check_commit(rc) + + return handle_response(response, rc) + + def create(self, users: List[User]): + cdef: + User user + SlurmList user_list + list assocs_to_add = [] + + if not users: + return + + self.db_conn.validate() + user_list = SlurmList.create(slurmdb_destroy_user_rec, owned=False) + + for user in users: + if user.default_account: + has_default_assoc = False + for assoc in user.associations: + if not assoc.is_default: + continue + + if has_default_assoc: + raise ValueError("Multiple Associations declared as default") + + has_default_assoc = True + if not assoc.account: + assoc.account = user.default_account + elif assoc.account != user.default_account: + raise ValueError("Ambigous account definition") + + # Do we really need to specify a default association anyway? + if not has_default_assoc: + # Caller didn't specify any default association, so we + # create a basic one. + assoc = Association(user=user.name, + account=user.default_account, is_default=True) + user.associations.append(assoc) + + assocs_to_add.extend(user.associations) + slurm.slurm_list_append(user_list.info, user.ptr) + + rc = slurmdb_users_add(self.db_conn.ptr, user_list.info) + + # Could also solve this construct via a simple try..finally, but I just + # don't want to execute commit/rollback potentially twice, even if it + # is completely fine. + try: + if rc == slurm.SLURM_SUCCESS: + self.db_conn.associations.create(assocs_to_add) + except RPCError: + # Just re-raise - required rollback was already taken care of + raise + except Exception: + # Doing this catch-all thing might be too cautious, but just in + # case anything goes wrong before Associations were attempted to be + # added, we make sure that adding the users is also rollbacked. + # + # Because we don't want to leave Users with no associations behind + # in the system, if associations were requested to be added. + self.db_conn.check_commit(slurm.SLURM_ERROR) + raise + + # TODO: SLURM_NO_CHANGE_IN_DATA + # Should this be an error? + + # Rollback or commit in case no associations were attempted to be added + self.db_conn.check_commit(rc) + verify_rpc(rc) + + +cdef class Users(dict): + + def __init__(self, users={}, **kwargs): + super().__init__() + self.update(users) + self.update(kwargs) + self._db_conn = None + + @staticmethod + def load(db_conn: Connection, db_filter: Optional[UserFilter] = None): + return db_conn.users.load(db_filter) + + def delete(self, db_conn: Optional[Connection] = None): + db_conn = Connection.reuse(self._db_conn, db_conn) + db_filter = UserFilter(names=list(self.keys())) + db_conn.users.delete(db_filter) + + def modify(self, changes: Optional[User] = None, db_conn: Optional[Connection] = None, **kwargs: Any): + db_conn = Connection.reuse(self._db_conn, db_conn) + db_filter = UserFilter(names=list(self.keys())) + return db_conn.users.modify(db_filter, changes=changes, **kwargs) + + def create(self, db_conn: Optional[Connection] = None): + db_conn = Connection.reuse(self._db_conn, db_conn) + db_conn.users.create(list(self.values())) + + +cdef class UserFilter: + + def __cinit__(self): + self.ptr = NULL + + def __init__(self, **kwargs: Any): + for k, v in kwargs.items(): + setattr(self, k, v) + + def __dealloc__(self): + self._dealloc() + + def _dealloc(self): + slurmdb_destroy_user_cond(self.ptr) + self.ptr = NULL + + def _alloc(self): + self._dealloc() + self.ptr = try_xmalloc(sizeof(slurmdb_user_cond_t)) + if not self.ptr: + raise MemoryError("xmalloc failed for slurmdb_user_cond_t") + + memset(self.ptr, 0, sizeof(slurmdb_user_cond_t)) + + self.ptr.assoc_cond = try_xmalloc(sizeof(slurmdb_assoc_cond_t)) + if not self.ptr.assoc_cond: + raise MemoryError("xmalloc failed for slurmdb_assoc_cond_t") + + + def _create(self): + self._alloc() + cdef slurmdb_user_cond_t *ptr = self.ptr + + make_char_list(&ptr.assoc_cond.user_list, self.names) + ptr.with_assocs = 1 if self.with_assocs else 0 + ptr.with_coords = 1 if self.with_coordinators else 0 + ptr.with_wckeys = 1 if self.with_wckeys else 0 + ptr.with_deleted = 1 if self.with_deleted else 0 + + +cdef class User: + + def __cinit__(self): + self.ptr = NULL + + def __init__(self, name: str = None, **kwargs: Any): + self._alloc_impl() + self.name = name + self._init_defaults() + for k, v in kwargs.items(): + setattr(self, k, v) + + def _init_defaults(self): + self.associations = [] + self.coordinators = [] + self.default_association = None + self.wckeys = [] + + def __dealloc__(self): + self._dealloc_impl() + + def _dealloc_impl(self): + slurmdb_destroy_user_rec(self.ptr) + self.ptr = NULL + + def _alloc_impl(self): + if not self.ptr: + self.ptr = try_xmalloc( + sizeof(slurmdb_user_rec_t)) + if not self.ptr: + raise MemoryError("xmalloc failed for slurmdb_user_rec_t") + + memset(self.ptr, 0, sizeof(slurmdb_user_rec_t)) + self.ptr.uid = slurm.NO_VAL + + def __repr__(self): + return f'pyslurm.db.{self.__class__.__name__}({self.name})' + + @staticmethod + cdef User from_ptr(slurmdb_user_rec_t *in_ptr): + cdef User wrap = User.__new__(User) + wrap.ptr = in_ptr + wrap._init_defaults() + return wrap + + def to_dict(self, recursive: bool = False): + """Database User information formatted as a dictionary. + + Returns: + (dict): Database User information as dict. + """ + return instance_to_dict(self, recursive) + + def __eq__(self, other: Any) -> bool: + if isinstance(other, User): + return self.name == other.name + return NotImplemented + + @staticmethod + def load(db_conn: Connection, name: str): + user = db_conn.users.load().get(name) + if not user: + raise NotFoundError(msg=f"User {name} does not exist.") + return user + + def create(self, db_conn: Optional[Connection] = None): + Users({self.name: self}).create(self._db_conn or db_conn) + + def delete(self, db_conn: Optional[Connection] = None): + Users({self.name: self}).delete(self._db_conn or db_conn) + + def modify( + self, + changes: Optional[User] = None, + db_conn: Optional[Connection] = None, + **kwargs: Any + ): + Users({self.name: self}).modify(changes=changes, db_conn=(self._db_conn or db_conn), **kwargs) + + @property + def name(self): + return cstr.to_unicode(self.ptr.name) + + @name.setter + def name(self, val): + cstr.fmalloc(&self.ptr.name, val) + + @property + def previous_name(self): + return cstr.to_unicode(self.ptr.old_name) + + @property + def user_id(self): + return u32_parse(self.ptr.uid, zero_is_noval=False) + + @property + def default_account(self): + return cstr.to_unicode(self.ptr.default_acct) + + @default_account.setter + def default_account(self, val): + cstr.fmalloc(&self.ptr.default_acct, val) + + @property + def default_wckey(self): + return cstr.to_unicode(self.ptr.default_wckey) + + @property + def is_deleted(self): + if self.ptr.flags & slurm.SLURMDB_USER_FLAG_DELETED: + return True + return False + + @property + def admin_level(self): + return AdminLevel.from_flag(self.ptr.admin_level, + default=AdminLevel.UNDEFINED) + + @admin_level.setter + def admin_level(self, val): + self.ptr.admin_level = AdminLevel(val)._flag diff --git a/pyslurm/db/wckey.pxd b/pyslurm/db/wckey.pxd new file mode 100644 index 00000000..37680fba --- /dev/null +++ b/pyslurm/db/wckey.pxd @@ -0,0 +1,71 @@ +######################################################################### +# wckey.pxd - pyslurm slurmdbd wckey api +######################################################################### +# Copyright (C) 2025 Toni Harzendorf +# +# This file is part of PySlurm +# +# PySlurm is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 2 of the License, or +# (at your option) any later version. + +# PySlurm is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License along +# with PySlurm; if not, write to the Free Software Foundation, Inc., +# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. +# +# cython: c_string_type=unicode, c_string_encoding=default +# cython: language_level=3 + +from libc.string cimport memcpy, memset +from pyslurm cimport slurm +from pyslurm.slurm cimport ( + slurmdb_wckey_rec_t, + slurmdb_wckey_cond_t, + slurmdb_wckeys_get, + slurmdb_destroy_wckey_rec, + slurmdb_destroy_wckey_cond, + try_xmalloc, +) +from pyslurm.db.util cimport ( + SlurmList, + SlurmListItem, + make_char_list, + slurm_list_to_pylist, + qos_list_to_pylist, +) +from pyslurm.db.tres cimport ( + _set_tres_limits, + TrackableResources, +) +from pyslurm.db.connection cimport Connection +from pyslurm.utils cimport cstr +from pyslurm.db.qos cimport QualitiesOfService, _set_qos_list +from pyslurm.db.assoc cimport Associations, Association, _parse_assoc_ptr +from pyslurm.xcollections cimport MultiClusterMap +from pyslurm.utils.uint cimport u16_set_bool_flag + + +cdef class WCKeys(MultiClusterMap): + pass + + +cdef class WCKeyFilter: + cdef slurmdb_wckey_cond_t *ptr + + cdef public: + names + + +cdef class WCKey: + cdef: + slurmdb_wckey_rec_t *ptr + _cluster + + @staticmethod + cdef WCKey from_ptr(slurmdb_wckey_rec_t *in_ptr) diff --git a/pyslurm/db/wckey.pyx b/pyslurm/db/wckey.pyx new file mode 100644 index 00000000..c072cb2b --- /dev/null +++ b/pyslurm/db/wckey.pyx @@ -0,0 +1,191 @@ +######################################################################### +# wckey.pyx - pyslurm slurmdbd wckey api +######################################################################### +# Copyright (C) 2025 Toni Harzendorf +# +# This file is part of PySlurm +# +# PySlurm is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 2 of the License, or +# (at your option) any later version. + +# PySlurm is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License along +# with PySlurm; if not, write to the Free Software Foundation, Inc., +# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. +# +# cython: c_string_type=unicode, c_string_encoding=default +# cython: language_level=3 + +from pyslurm.core.error import RPCError +from pyslurm.utils.helpers import ( + instance_to_dict, + user_to_uid, +) +from pyslurm.utils.uint import * +from pyslurm import settings +from pyslurm import xcollections + + +cdef class WCKeys(MultiClusterMap): + + def __init__(self, wckeys=None): + super().__init__(data=wckeys, + typ="WCKeys", + val_type=WCKey, + id_attr=WCKey.name, + key_type=str) + + @staticmethod + def load(Connection db_conn, WCKeyFilter db_filter=None): + cdef: + WCKeys out = WCKeys() + WCKey wckey + WCKeyFilter cond = db_filter + SlurmList wckey_data + SlurmListItem wckey_ptr + + db_conn.validate() + + if not db_filter: + cond = WCKeyFilter() + cond._create() + + wckey_data = SlurmList.wrap(slurmdb_wckeys_get(db_conn.ptr, cond.ptr)) + + if wckey_data.is_null: + raise RPCError(msg="Failed to get WCKey data from slurmdbd.") + + for wckey_ptr in SlurmList.iter_and_pop(wckey_data): + wckey = WCKey.from_ptr(wckey_ptr.data) + + cluster = wckey.cluster + if cluster not in out.data: + out.data[cluster] = {} + out.data[cluster][wckey.name] = wckey + + return out + + +cdef class WCKeyFilter: + + def __cinit__(self): + self.ptr = NULL + + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + + def __dealloc__(self): + self._dealloc() + + def _dealloc(self): + slurmdb_destroy_wckey_cond(self.ptr) + self.ptr = NULL + + def _alloc(self): + self._dealloc() + self.ptr = try_xmalloc(sizeof(slurmdb_wckey_cond_t)) + if not self.ptr: + raise MemoryError("xmalloc failed for slurmdb_wckey_cond_t") + + def _create(self): + self._alloc() + cdef slurmdb_wckey_cond_t *ptr = self.ptr + + make_char_list(&ptr.name_list, self.names) + + +cdef class WCKey: + + def __cinit__(self): + self.ptr = NULL + + def __init__(self, name=None, **kwargs): + self._alloc_impl() + self.name = name + self._init_defaults() + for k, v in kwargs.items(): + setattr(self, k, v) + + def _init_defaults(self): + self._cluster = settings.LOCAL_CLUSTER + + def __dealloc__(self): + self._dealloc_impl() + + def _dealloc_impl(self): + slurmdb_destroy_wckey_rec(self.ptr) + self.ptr = NULL + + def _alloc_impl(self): + if not self.ptr: + self.ptr = try_xmalloc( + sizeof(slurmdb_wckey_rec_t)) + if not self.ptr: + raise MemoryError("xmalloc failed for slurmdb_wckey_rec_t") + + memset(self.ptr, 0, sizeof(slurmdb_wckey_rec_t)) + + def __repr__(self): + return f'pyslurm.db.{self.__class__.__name__}({self.name})' + + @staticmethod + cdef WCKey from_ptr(slurmdb_wckey_rec_t *in_ptr): + cdef WCKey wrap = WCKey.__new__(WCKey) + wrap.ptr = in_ptr + wrap._init_defaults() + return wrap + + def to_dict(self): + """Database WCKey information formatted as a dictionary. + + Returns: + (dict): Database WCKey information as dict. + """ + return instance_to_dict(self) + + def __eq__(self, other): + if isinstance(other, WCKey): + return self.id == other.id and self.cluster == other.cluster + return NotImplemented + + @property + def name(self): + return cstr.to_unicode(self.ptr.name) + + @property + def cluster(self): + cluster = cstr.to_unicode(self.ptr.cluster) + if not cluster: + return self._cluster + return cluster + + @property + def user_name(self): + return cstr.to_unicode(self.ptr.user) + + @property + def user_id(self): + return self.ptr.uid + + @property + def is_default(self): + return bool(self.ptr.is_def) + + @property + def id(self): + return self.ptr.id + + @property + def is_deleted(self): + if self.ptr.flags & slurm.SLURMDB_WCKEY_FLAG_DELETED: + return True + return False + + # TODO: list_t *accounting_list diff --git a/pyslurm/enums.pyx b/pyslurm/enums.pyx index acca2f6c..bd59577f 100644 --- a/pyslurm/enums.pyx +++ b/pyslurm/enums.pyx @@ -32,10 +32,10 @@ from pyslurm cimport slurm class SchedulerType(SlurmEnum): - SUBMIT = auto(), slurm.SLURMDB_JOB_FLAG_SUBMIT - MAIN = auto(), slurm.SLURMDB_JOB_FLAG_SCHED + SUBMIT = auto(), slurm.SLURMDB_JOB_FLAG_SUBMIT + MAIN = auto(), slurm.SLURMDB_JOB_FLAG_SCHED BACKFILL = auto(), slurm.SLURMDB_JOB_FLAG_BACKFILL - UNKNOWN = auto() + UNKNOWN = auto() SchedulerType.SUBMIT.__doc__ = "Scheduled immediately on submit" @@ -43,6 +43,14 @@ SchedulerType.MAIN.__doc__ = "Scheduled by the Main Scheduler" SchedulerType.SUBMIT.__doc__ = "Scheduled by the Backfill Scheduler" +class AdminLevel(SlurmEnum): + UNDEFINED = auto(), slurm.SLURMDB_ADMIN_NOTSET + NONE = auto(), slurm.SLURMDB_ADMIN_NONE + OPERATOR = auto(), slurm.SLURMDB_ADMIN_OPERATOR + ADMINISTRATOR = auto(), slurm.SLURMDB_ADMIN_SUPER_USER + + __all__ = [ "SchedulerType", + "AdminLevel", ] diff --git a/pyslurm/utils/enums.pyx b/pyslurm/utils/enums.pyx index 5eb07a79..3b842b84 100644 --- a/pyslurm/utils/enums.pyx +++ b/pyslurm/utils/enums.pyx @@ -46,7 +46,7 @@ class DocstringSupport(EnumType): return cls -class SlurmEnum(str, Enum, metaclass=DocstringSupport): +class StrEnum(str, Enum, metaclass=DocstringSupport): def __new__(cls, name, *args): # https://docs.python.org/3/library/enum.html @@ -62,9 +62,6 @@ class SlurmEnum(str, Enum, metaclass=DocstringSupport): v = str(name) new_string = str.__new__(cls, v) new_string._value_ = v - - new_string._flag = int(args[0]) if len(args) >= 1 else 0 - new_string._clear_flag = int(args[1]) if len(args) >= 2 else 0 return new_string def __str__(self): @@ -73,13 +70,25 @@ class SlurmEnum(str, Enum, metaclass=DocstringSupport): @staticmethod def _generate_next_value_(name, _start, _count, _last_values): # We just care about the name of the member to be defined. - return name.upper() + return name.lower() + + +class SlurmEnum(StrEnum): + + def __new__(cls, name, *args): + v = str(name) + new_string = str.__new__(cls, v) + new_string._value_ = v + + new_string._flag = int(args[0]) if len(args) >= 1 else 0 + new_string._clear_flag = int(args[1]) if len(args) >= 2 else 0 + return new_string @classmethod def from_flag(cls, flags, default): out = cls(default) for item in cls: - if flags & item._flag or flags == item._flag: + if flags == item._flag: return item return out diff --git a/tests/integration/test_assoc.py b/tests/integration/test_assoc.py new file mode 100644 index 00000000..bb711172 --- /dev/null +++ b/tests/integration/test_assoc.py @@ -0,0 +1,205 @@ +######################################################################### +# test_assoc.py - database assoc/accounts/user api tests +######################################################################### +# Copyright (C) 2026 Toni Harzendorf +# +# This file is part of PySlurm +# +# PySlurm is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 2 of the License, or +# (at your option) any later version. + +# PySlurm is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License along +# with PySlurm; if not, write to the Free Software Foundation, Inc., +# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. +"""test_assoc.py - Integration test assoc/account/user functionalities.""" + +import pyslurm +import pytest +import uuid +from pyslurm.db import ( + User, + Account, + Association, +) + + +def _modify_account(account, conn, with_kwargs, **kwargs): + changes = Account(**kwargs) + + assert account.description != changes.description + assoc_before = account.association.to_dict(recursive=True) + + if with_kwargs: + account.modify(db_conn=conn, **kwargs) + else: + account.modify(changes, conn) + + account = Account.load(conn, account.name) + assoc_after = account.association.to_dict(recursive=True) + assert account.description == changes.description + # Make sure we didn't change anything in the Association + assert assoc_before == assoc_after + + +def _modify_user(user, conn, with_kwargs, **kwargs): + changes = User(**kwargs) + + assert user.admin_level == pyslurm.AdminLevel.NONE + assoc_before = user.default_association.to_dict(recursive=True) + + if with_kwargs: + user.modify(db_conn=conn, **kwargs) + else: + user.modify(changes, conn) + + user = User.load(conn, user.name) + assoc_after = user.default_association.to_dict(recursive=True) + assert user.admin_level == changes.admin_level + # Make sure we didn't change anything in the Association + assert assoc_before == assoc_after + + +def _modify_assoc(assoc, conn, with_kwargs, **kwargs): + changes = Association(**kwargs) + + assert assoc.group_jobs == "UNLIMITED" + assert assoc.group_submit_jobs == "UNLIMITED" + + if with_kwargs: + assoc.modify(db_conn=conn, **kwargs) + else: + assoc.modify(changes, conn) + + assoc = Association.load(conn, assoc.id) + assert assoc.group_jobs == changes.group_jobs + assert assoc.group_submit_jobs == changes.group_submit_jobs + assert assoc.group_jobs != "UNLIMITED" + assert assoc.group_submit_jobs != "UNLIMITED" + + +def _load_assoc(assoc_id, conn): + assocs = conn.associations.load() + return assocs.get(assoc_id) + + +def _load_account(name, conn): + accounts = conn.accounts.load() + assert len(accounts) + return accounts.get(name) + + +def _load_user(name, conn): + users = conn.users.load() + assert len(users) + return users.get(name) + + +def _delete_account(account, conn): + account = Account.load(conn, account.name) + account.delete() + conn.commit() + assert not _load_account(account.name, conn) + assert not _load_assoc(account.association.id, conn) + + +def _add_account(account, conn): + account.create(conn) + # Although everything works without this commit, the slurmdbd complains + # that it can't find the account assoc when going to add the user. + # Everything is created properly, but this error appears. This is + # probably why you can't create both account and user / associations + # directly with sacctmgr. Either it is designed like this, or this is a + # bug in the as_mysql plugin. + # The tests pass anyway, so it is fine, but needs to be documented. + conn.commit() + account = _load_account(account.name, conn) + assoc = account.association + assert assoc + assert _load_assoc(assoc.id, conn) + assert assoc.account == account.name + assert assoc.user is None + assert assoc.parent_account == "root" + return account + + +def _add_user(user, conn): + user.create(conn) + conn.commit() + user = _load_user(user.name, conn) + assert len(user.associations) == 1 + assoc = user.default_association + assert assoc + assert _load_assoc(assoc.id, conn) + assert assoc.user == user.name + assert assoc.account == user.default_account + assert assoc.is_default + return user + + +def _delete_user(user, conn): + assoc_id = user.default_association.id + user = User.load(conn, user.name) + user.delete() + conn.commit() + assert not _load_user(user.name, conn) + assert not _load_assoc(assoc_id, conn) + + +def _test_modify_delete(user, account, conn): + assert conn.is_open + _modify_account(account, conn, with_kwargs=False, description="this is a new description") + _modify_account(account, conn, with_kwargs=True, description="another description") + + _modify_user(user, conn, with_kwargs=False, admin_level="administrator") + _modify_user(user, conn, with_kwargs=True, admin_level="operator") + + _modify_assoc(user.default_association, conn, with_kwargs=False, + group_jobs=10, group_submit_jobs=20) + _modify_assoc(user.default_association, conn, with_kwargs=True, + group_jobs=50, group_submit_jobs=100) + + _delete_account(account, conn) + _delete_user(user, conn) + + +def _test_api(user, account, conn): + # Save them before reloading + user_name = user.name + acc_name = account.name + + account = _add_account(account, conn) + user = _add_user(user, conn) + assert user.name == user_name + assert account.name == acc_name + _test_modify_delete(user, account, conn) + + +def test_user_and_account_no_assoc(): + random_name = str(uuid.uuid4())[:8] + user_name = f"user_{random_name}" + acc_name = f"acc_{random_name}" + + with pyslurm.db.connect() as conn: + account = Account(name=acc_name) + user = User(name=user_name, default_account=acc_name) + _test_api(user, account, conn) + + +def test_user_and_accounts_with_assoc_empty(): + random_name = str(uuid.uuid4())[:8] + user_name = f"user_{random_name}" + acc_name = f"acc_{random_name}" + + with pyslurm.db.connect() as conn: + account_assoc = Association(account=acc_name) + account = Account(name=acc_name, association=account_assoc) + user_assoc = Association(user=user_name, account=acc_name) + user = User(name=user_name, associations=[user_assoc]) + _test_api(user, account, conn) diff --git a/tests/integration/test_db_connection.py b/tests/integration/test_db_connection.py index 95b6f311..4b4c2d59 100644 --- a/tests/integration/test_db_connection.py +++ b/tests/integration/test_db_connection.py @@ -33,6 +33,11 @@ def test_open(): conn = pyslurm.db.Connection.open() assert conn.is_open + with pyslurm.db.connect() as conn2: + pass + + assert not conn2.is_open + def test_close(): conn = pyslurm.db.Connection.open() diff --git a/tests/integration/test_db_job.py b/tests/integration/test_db_job.py index 310df51f..03892baf 100644 --- a/tests/integration/test_db_job.py +++ b/tests/integration/test_db_job.py @@ -38,18 +38,23 @@ def test_load_single(submit_job): job = submit_job() util.wait() - db_job = pyslurm.db.Job.load(job.id) - assert db_job.id == job.id + with pyslurm.db.connect() as conn: + db_job = pyslurm.db.Job.load(conn, job.id) - with pytest.raises(pyslurm.RPCError): - pyslurm.db.Job.load(0) + assert db_job.id == job.id + + with pytest.raises(pyslurm.core.error.NotFoundError): + pyslurm.db.Job.load(conn, 0) def test_parse_all(submit_job): job = submit_job() util.wait() - db_job = pyslurm.db.Job.load(job.id) + + with pyslurm.db.connect() as conn: + db_job = pyslurm.db.Job.load(conn, job.id) + job_dict = db_job.to_dict() assert job_dict["stats"] @@ -60,8 +65,9 @@ def test_to_json(submit_job): job = submit_job() util.wait() - jfilter = pyslurm.db.JobFilter(ids=[job.id]) - jobs = pyslurm.db.Jobs.load(jfilter) + with pyslurm.db.connect() as conn: + jfilter = pyslurm.db.JobFilter(ids=[job.id]) + jobs = conn.jobs.load(jfilter) json_data = jobs.to_json() dict_data = json.loads(json_data) @@ -74,66 +80,101 @@ def test_modify(submit_job): job = submit_job() util.wait(5) - jfilter = pyslurm.db.JobFilter(ids=[job.id]) - changes = pyslurm.db.Job(comment="test comment") - pyslurm.db.Jobs.modify(jfilter, changes) + # With explicit separate Job object as changes + with pyslurm.db.connect() as conn: + comment = "comment two" - job = pyslurm.db.Job.load(job.id) - assert job.comment == "test comment" + job = pyslurm.db.Job.load(conn, job.id) + assert job.comment != comment + jfilter = pyslurm.db.JobFilter(ids=[job.id]) + changes = pyslurm.db.Job(comment=comment) + conn.jobs.modify(jfilter, changes) + job = pyslurm.db.Job.load(conn, job.id) + assert job.comment == comment -def test_modify_with_existing_conn(submit_job): - job = submit_job() - util.wait(5) + # With filter via **kwargs + with pyslurm.db.connect(transaction_mode="manual") as conn: + comment = "comment two" + job = pyslurm.db.Job.load(conn, job.id) + assert job.comment != comment + + jfilter = pyslurm.db.JobFilter(ids=[job.id]) + conn.jobs.modify(jfilter, comment=comment) + + conn.commit() + job = pyslurm.db.Job.load(conn, job.id) + assert job.comment == comment - conn = pyslurm.db.Connection.open() - jfilter = pyslurm.db.JobFilter(ids=[job.id]) - changes = pyslurm.db.Job(comment="test comment") - pyslurm.db.Jobs.modify(jfilter, changes, conn) + with pytest.raises(pyslurm.core.error.ArgumentError): + conn.jobs.modify(jfilter) - job = pyslurm.db.Job.load(job.id) - assert job.comment != "test comment" + # Without filter, using modify() on the instance + # By default, connections are inherited + with pyslurm.db.connect() as conn: + comment = "comment three" + job = pyslurm.db.Job.load(conn, job.id) + assert job.comment != comment - conn.commit() - job = pyslurm.db.Job.load(job.id) - assert job.comment == "test comment" + job.modify(comment=comment) + job = pyslurm.db.Job.load(conn, job.id) + assert job.comment == comment -def test_if_steps_exist(submit_job): - # TODO - pass + # Without inherited connection, not supplying a connection will fail + with pyslurm.db.connect(reuse_connection=False) as conn: + comment = "comment four" + job = pyslurm.db.Job.load(conn, job.id) + assert job.comment != comment + job.modify(db_conn=conn, comment=comment) -def test_load_with_filter_node(submit_job): - # TODO - pass + job = pyslurm.db.Job.load(conn, job.id) + assert job.comment == comment + with pytest.raises(pyslurm.db.connection.InvalidConnectionError): + job.modify(comment=comment) -def test_load_with_filter_qos(submit_job): - # TODO - pass +# def test_if_steps_exist(submit_job): +# # TODO +# pass -def test_load_with_filter_cluster(submit_job): - # TODO - pass +# def test_load_with_filter_node(submit_job): +# # TODO +# pass -def test_load_with_filter_multiple(submit_job): - # TODO - pass + +# def test_load_with_filter_qos(submit_job): +# # TODO +# pass + + +# def test_load_with_filter_cluster(submit_job): +# # TODO +# pass + + +# def test_load_with_filter_multiple(submit_job): +# # TODO +# pass def test_load_with_script(submit_job): script = util.create_job_script() job = submit_job(script=script) util.wait(5) - db_job = pyslurm.db.Job.load(job.id, with_script=True) + + with pyslurm.db.connect() as conn: + db_job = pyslurm.db.Job.load(conn, job.id, with_script=True) assert db_job.script == script def test_load_with_env(submit_job): job = submit_job() util.wait(5) - db_job = pyslurm.db.Job.load(job.id, with_env=True) + + with pyslurm.db.connect() as conn: + db_job = pyslurm.db.Job.load(conn, job.id, with_env=True) assert db_job.environment diff --git a/tests/integration/test_db_qos.py b/tests/integration/test_db_qos.py index e1cde024..9a041a12 100644 --- a/tests/integration/test_db_qos.py +++ b/tests/integration/test_db_qos.py @@ -27,29 +27,32 @@ def test_load_single(): - qos = pyslurm.db.QualityOfService.load("normal") + with pyslurm.db.connect() as conn: + qos = pyslurm.db.QualityOfService.load(conn, "normal") - assert qos.name == "normal" - assert qos.id == 1 + assert qos.name == "normal" + assert qos.id == 1 - with pytest.raises(pyslurm.RPCError): - pyslurm.db.QualityOfService.load("qos_non_existent") + with pytest.raises(pyslurm.RPCError): + pyslurm.db.QualityOfService.load(conn, "qos_non_existent") def test_parse_all(submit_job): - qos = pyslurm.db.QualityOfService.load("normal") - qos_dict = qos.to_dict() + with pyslurm.db.connect() as conn: + qos = pyslurm.db.QualityOfService.load(conn, "normal") + qos_dict = qos.to_dict() - assert qos_dict - assert qos_dict["name"] == qos.name + assert qos_dict + assert qos_dict["name"] == qos.name def test_load_all(): - qos = pyslurm.db.QualitiesOfService.load() - assert qos - + with pyslurm.db.connect() as conn: + qos = conn.qos.load() + assert qos def test_load_with_filter_name(): - qfilter = pyslurm.db.QualityOfServiceFilter(names=["non_existent"]) - qos = pyslurm.db.QualitiesOfService.load(qfilter) - assert not qos + with pyslurm.db.connect() as conn: + db_filter = pyslurm.db.QualityOfServiceFilter(names=["non_existent"]) + qos = conn.qos.load(db_filter) + assert not qos