From 444a5806c0a9eefe6635ddff6e3b73e98e71a96c Mon Sep 17 00:00:00 2001 From: Toni Harzendorf Date: Fri, 20 Feb 2026 20:50:23 +0100 Subject: [PATCH 01/13] wip association api --- pyslurm/db/__init__.py | 17 +- pyslurm/db/account.pxd | 102 ++++++++++++ pyslurm/db/account.pyx | 274 +++++++++++++++++++++++++++++++ pyslurm/db/assoc.pxd | 30 +++- pyslurm/db/assoc.pyx | 184 +++++++++++++++------ pyslurm/db/connection.pyx | 29 +++- pyslurm/db/error.pyx | 150 +++++++++++++++++ pyslurm/db/job.pyx | 58 ++----- pyslurm/db/qos.pyx | 15 +- pyslurm/db/tres.pyx | 14 +- pyslurm/db/user.pxd | 110 +++++++++++++ pyslurm/db/user.pyx | 335 ++++++++++++++++++++++++++++++++++++++ pyslurm/db/wckey.pxd | 71 ++++++++ pyslurm/db/wckey.pyx | 191 ++++++++++++++++++++++ 14 files changed, 1463 insertions(+), 117 deletions(-) create mode 100644 pyslurm/db/account.pxd create mode 100644 pyslurm/db/account.pyx create mode 100644 pyslurm/db/error.pyx create mode 100644 pyslurm/db/user.pxd create mode 100644 pyslurm/db/user.pyx create mode 100644 pyslurm/db/wckey.pxd create mode 100644 pyslurm/db/wckey.pyx diff --git a/pyslurm/db/__init__.py b/pyslurm/db/__init__.py index a821cc43..b633c82a 100644 --- a/pyslurm/db/__init__.py +++ b/pyslurm/db/__init__.py @@ -19,7 +19,7 @@ # with PySlurm; if not, write to the Free Software Foundation, Inc., # 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. -from .connection import Connection +from .connection import Connection, connect from .step import JobStep, JobSteps from .stats import JobStatistics, JobStepStatistics from .job import ( @@ -44,3 +44,18 @@ Association, AssociationFilter, ) +from .user import ( + Users, + User, + UserFilter, +) +from .account import ( + Accounts, + Account, + AccountFilter, +) +from .wckey import ( + WCKeys, + WCKey, + WCKeyFilter, +) diff --git a/pyslurm/db/account.pxd b/pyslurm/db/account.pxd new file mode 100644 index 00000000..850a1761 --- /dev/null +++ b/pyslurm/db/account.pxd @@ -0,0 +1,102 @@ +######################################################################### +# account.pxd - pyslurm slurmdbd account api +######################################################################### +# Copyright (C) 2025 Toni Harzendorf +# +# This file is part of PySlurm +# +# PySlurm is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 2 of the License, or +# (at your option) any later version. + +# PySlurm is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License along +# with PySlurm; if not, write to the Free Software Foundation, Inc., +# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. +# +# cython: c_string_type=unicode, c_string_encoding=default +# cython: language_level=3 + +from libc.string cimport memcpy, memset +from pyslurm cimport slurm +from pyslurm.slurm cimport ( + slurmdb_account_rec_t, + slurmdb_assoc_rec_t, + slurmdb_assoc_cond_t, + slurmdb_account_cond_t, + slurmdb_accounts_get, + slurmdb_accounts_add, + slurmdb_accounts_remove, + slurmdb_accounts_modify, + slurmdb_destroy_account_rec, + slurmdb_destroy_account_cond, + try_xmalloc, +) +from pyslurm.db.util cimport ( + SlurmList, + SlurmListItem, + make_char_list, + slurm_list_to_pylist, + qos_list_to_pylist, +) +from pyslurm.db.tres cimport ( + _set_tres_limits, + TrackableResources, +) +from pyslurm.db.connection cimport Connection +from pyslurm.utils cimport cstr +from pyslurm.db.qos cimport QualitiesOfService, _set_qos_list +from pyslurm.db.assoc cimport Associations, Association, _parse_assoc_ptr +from pyslurm.xcollections cimport MultiClusterMap +from pyslurm.utils.uint cimport u16_set_bool_flag + + +cdef class Accounts(dict): + pass + + +cdef class AccountFilter: + cdef slurmdb_account_cond_t *ptr + + cdef public: + with_assocs + with_deleted + with_coordinators + names + organizations + descriptions + + +cdef class Account: + """Slurm Database Account. + + Attributes: + name (str): + Name of the Account. + description (str): + Description of the Account. + organization (str): + Organization of the Account. + is_deleted (bool): + Whether this Account has been deleted or not. + association (pyslurm.db.Association): + This accounts association. + """ + cdef: + slurmdb_account_rec_t *ptr + + cdef readonly: + cluster + + cdef public: + associations + coordinators + association + + @staticmethod + cdef Account from_ptr(slurmdb_account_rec_t *in_ptr) diff --git a/pyslurm/db/account.pyx b/pyslurm/db/account.pyx new file mode 100644 index 00000000..1f5eff48 --- /dev/null +++ b/pyslurm/db/account.pyx @@ -0,0 +1,274 @@ +######################################################################### +# account.pyx - pyslurm slurmdbd account api +######################################################################### +# Copyright (C) 2023 Toni Harzendorf +# +# This file is part of PySlurm +# +# PySlurm is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 2 of the License, or +# (at your option) any later version. + +# PySlurm is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License along +# with PySlurm; if not, write to the Free Software Foundation, Inc., +# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. +# +# cython: c_string_type=unicode, c_string_encoding=default +# cython: language_level=3 + +from pyslurm.core.error import RPCError, verify_rpc, slurm_errno +from pyslurm.utils.helpers import ( + instance_to_dict, + user_to_uid, +) +from pyslurm.utils.uint import * +from pyslurm import xcollections +from pyslurm.db.error import DefaultAccountError, JobsRunningError + + +cdef class Accounts(dict): + + def __init__(self, accounts=None): + super().__init__() + + @staticmethod + def load(Connection db_conn, AccountFilter db_filter=None): + cdef: + Accounts out = Accounts() + Account account + AccountFilter cond = db_filter + SlurmList account_data + SlurmListItem account_ptr + SlurmList assoc_data + SlurmListItem assoc_ptr + Association assoc + QualitiesOfService qos_data + TrackableResources tres_data + + db_conn.validate() + + if not db_filter: + cond = AccountFilter() + + if cond.with_assocs is not False: + cond.with_assocs = True + + cond._create() + account_data = SlurmList.wrap(slurmdb_accounts_get(db_conn.ptr, cond.ptr)) + + if account_data.is_null: + raise RPCError(msg="Failed to get Account data from slurmdbd.") + + qos_data = QualitiesOfService.load(db_conn=db_conn, + name_is_key=False) + tres_data = TrackableResources.load(db_conn=db_conn) + + for account_ptr in SlurmList.iter_and_pop(account_data): + account = Account.from_ptr(account_ptr.data) + out[account.name] = account + + assoc_data = SlurmList.wrap(account.ptr.assoc_list, owned=False) + for assoc_ptr in SlurmList.iter_and_pop(assoc_data): + assoc = Association.from_ptr(assoc_ptr.data) + assoc.qos_data = qos_data + assoc.tres_data = tres_data + _parse_assoc_ptr(assoc) + + if not assoc.user: + # This is the Association of the account itself. + account.association = assoc + else: + # These must be User Associations. + account.associations.append(assoc) + + return out + + @staticmethod + def create(Connection db_conn, accounts): + cdef: + Account account + SlurmList account_list + list assocs_to_add = [] + + db_conn.validate() + account_list = SlurmList.create(slurmdb_destroy_account_rec, owned=False) + + for account in accounts: +# if not account.associations and add_assoc: +# # For convenience, we create the associations by default +# # automatically, just like sacctmgr. Can be disabled for more +# # control. +# assoc = Association(account=account.name) +# account.associations.append(assoc) + + assocs_to_add.extend(account.associations) + slurm.slurm_list_append(account_list.info, account.ptr) + + verify_rpc(slurmdb_accounts_add(db_conn.ptr, account_list.info)) + Associations.create(db_conn, assocs_to_add) + + def delete(self, Connection db_conn): + cdef: + AccountFilter a_filter + SlurmList response + list out = [] + + db_conn.validate() + + a_filter = AccountFilter(names=list(self.keys())) + a_filter._create() + + response = SlurmList.wrap(slurmdb_accounts_remove(db_conn.ptr, a_filter.ptr)) + rc = slurm_errno() + + if rc == slurm.SLURM_SUCCESS or rc == slurm.SLURM_NO_CHANGE_IN_DATA: + return + +# if rc == slurm.ESLURM_ACCESS_DENIED or response.is_null: +# verify_rpc(rc) + + # Handle the error cases. + if rc == slurm.ESLURM_JOBS_RUNNING_ON_ASSOC: + raise JobsRunningError.from_response(response, rc) + elif rc == slurm.ESLURM_NO_REMOVE_DEFAULT_ACCOUNT: + raise DefaultAccountError.from_response(response, rc) + else: + verify_rpc(rc) + + +cdef class AccountFilter: + + def __cinit__(self): + self.ptr = NULL + + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + + def __dealloc__(self): + self._dealloc() + + def _dealloc(self): + slurmdb_destroy_account_cond(self.ptr) + self.ptr = NULL + + def _alloc(self): + self._dealloc() + self.ptr = try_xmalloc(sizeof(slurmdb_account_cond_t)) + if not self.ptr: + raise MemoryError("xmalloc failed for slurmdb_account_cond_t") + + memset(self.ptr, 0, sizeof(slurmdb_account_cond_t)) + + self.ptr.assoc_cond = try_xmalloc(sizeof(slurmdb_assoc_cond_t)) + if not self.ptr.assoc_cond: + raise MemoryError("xmalloc failed for slurmdb_assoc_cond_t") + + def _parse_flag(self, val, flag_val): + if val: + self.ptr.flags |= flag_val + + def _create(self): + self._alloc() + cdef slurmdb_account_cond_t *ptr = self.ptr + + make_char_list(&ptr.assoc_cond.acct_list, self.names) + self._parse_flag(self.with_assocs, slurm.SLURMDB_ACCT_FLAG_WASSOC) + self._parse_flag(self.with_deleted, slurm.SLURMDB_ACCT_FLAG_DELETED) + self._parse_flag(self.with_coordinators, slurm.SLURMDB_ACCT_FLAG_WCOORD) + + +cdef class Account: + + def __cinit__(self): + self.ptr = NULL + + def __init__(self, name=None, **kwargs): + self._alloc_impl() + self.name = name + self._init_defaults() + for k, v in kwargs.items(): + setattr(self, k, v) + + def _init_defaults(self): + self.associations = [] + self.association = None + self.coordinators = [] + + def __dealloc__(self): + self._dealloc_impl() + + def _dealloc_impl(self): + slurmdb_destroy_account_rec(self.ptr) + self.ptr = NULL + + def _alloc_impl(self): + if not self.ptr: + self.ptr = try_xmalloc( + sizeof(slurmdb_account_rec_t)) + if not self.ptr: + raise MemoryError("xmalloc failed for slurmdb_account_rec_t") + + memset(self.ptr, 0, sizeof(slurmdb_account_rec_t)) + + def __repr__(self): + return f'pyslurm.db.{self.__class__.__name__}({self.name})' + + @staticmethod + cdef Account from_ptr(slurmdb_account_rec_t *in_ptr): + cdef Account wrap = Account.__new__(Account) + wrap.ptr = in_ptr + wrap._init_defaults() + return wrap + + def to_dict(self): + """Database Account information formatted as a dictionary. + + Returns: + (dict): Database Account information as dict. + """ + return instance_to_dict(self) + + def __eq__(self, other): + if isinstance(other, Account): + return self.name == other.name + return NotImplemented + + def create(self, Connection db_conn): + Accounts.create(db_conn, [self]) + + @property + def name(self): + return cstr.to_unicode(self.ptr.name) + + @name.setter + def name(self, val): + cstr.fmalloc(&self.ptr.name, val) + + @property + def description(self): + return cstr.to_unicode(self.ptr.description) + + @description.setter + def description(self, val): + cstr.fmalloc(&self.ptr.description, val) + + @property + def organization(self): + return cstr.to_unicode(self.ptr.organization) + + @organization.setter + def organization(self, val): + cstr.fmalloc(&self.ptr.organization, val) + + @property + def is_deleted(self): + if self.ptr.flags & slurm.SLURMDB_ACCT_FLAG_DELETED: + return True + return False diff --git a/pyslurm/db/assoc.pxd b/pyslurm/db/assoc.pxd index b0dad3a9..aac46f0c 100644 --- a/pyslurm/db/assoc.pxd +++ b/pyslurm/db/assoc.pxd @@ -26,11 +26,13 @@ from pyslurm cimport slurm from pyslurm.slurm cimport ( slurmdb_assoc_rec_t, slurmdb_assoc_cond_t, - slurmdb_associations_get, slurmdb_destroy_assoc_rec, slurmdb_destroy_assoc_cond, slurmdb_init_assoc_rec, + slurmdb_associations_get, slurmdb_associations_modify, + slurmdb_associations_add, + slurmdb_associations_remove, try_xmalloc, ) from pyslurm.db.util cimport ( @@ -64,13 +66,35 @@ cdef class AssociationFilter: cdef public: users ids + accounts + parent_accounts + clusters + partitions + qos + + +cdef class AssociationLimits: + + cdef public: + group_tres + group_tres_mins + group_tres_run_mins + max_tres_mins_per_job + max_tres_run_mins_per_user + max_tres_per_job + max_tres_per_node + qos + + group_jobs cdef class Association: cdef: slurmdb_assoc_rec_t *ptr + slurmdb_assoc_rec_t *umsg QualitiesOfService qos_data TrackableResources tres_data + owned cdef public: group_tres @@ -81,7 +105,11 @@ cdef class Association: max_tres_per_job max_tres_per_node qos + default_qos @staticmethod cdef Association from_ptr(slurmdb_assoc_rec_t *in_ptr) + +cdef class AssociationList(SlurmList): + pass diff --git a/pyslurm/db/assoc.pyx b/pyslurm/db/assoc.pyx index 56370a04..25265f26 100644 --- a/pyslurm/db/assoc.pyx +++ b/pyslurm/db/assoc.pyx @@ -22,15 +22,47 @@ # cython: c_string_type=unicode, c_string_encoding=default # cython: language_level=3 -from pyslurm.core.error import RPCError +from pyslurm.core.error import RPCError, verify_rpc, slurm_errno from pyslurm.utils.helpers import ( instance_to_dict, user_to_uid, ) from pyslurm.utils.uint import * -from pyslurm.db.connection import _open_conn_or_error from pyslurm import settings from pyslurm import xcollections +from pyslurm.db.error import JobsRunningError, DefaultAccountError + + +cdef class AssociationList(SlurmList): + + def __init__(self, owned=True): + self.info = slurm.slurm_list_create(slurm.slurmdb_destroy_assoc_rec) + self.owned = owned + + def append(self, Association assoc): + slurm.slurm_list_append(self.info, assoc.ptr) + assoc.owned = False + self.cnt = slurm.slurm_list_count(self.info) + + def __iter__(self): + return super().__iter__() + + def __next__(self): + if self.is_null or self.is_itr_null: + raise StopIteration + + if self.itr_cnt < self.cnt: + self.itr_cnt += 1 + assoc = Association.from_ptr(slurm.slurm_list_next(self.itr)) + assoc.owned = False + return assoc + + self._dealloc_itr() + raise StopIteration + + def extend(self, list_in): + for item in list_in: + self.append(item) cdef class Associations(MultiClusterMap): @@ -43,39 +75,34 @@ cdef class Associations(MultiClusterMap): key_type=int) @staticmethod - def load(AssociationFilter db_filter=None, Connection db_connection=None): + def load(Connection db_conn, AssociationFilter db_filter=None): cdef: Associations out = Associations() Association assoc AssociationFilter cond = db_filter SlurmList assoc_data SlurmListItem assoc_ptr - Connection conn QualitiesOfService qos_data TrackableResources tres_data - # Prepare SQL Filter + db_conn.validate() + if not db_filter: cond = AssociationFilter() cond._create() - # Setup DB Conn - conn = _open_conn_or_error(db_connection) - - # Fetch Assoc Data assoc_data = SlurmList.wrap(slurmdb_associations_get( - conn.ptr, cond.ptr)) + db_conn.ptr, cond.ptr)) if assoc_data.is_null: - raise RPCError(msg="Failed to get Association data from slurmdbd") + raise RPCError(msg="Failed to get Association data from slurmdbd.") # Fetch other necessary dependencies needed for translating some # attributes (i.e QoS IDs to its name) - qos_data = QualitiesOfService.load(db_connection=conn, + qos_data = QualitiesOfService.load(db_conn=db_conn, name_is_key=False) - tres_data = TrackableResources.load(db_connection=conn) + tres_data = TrackableResources.load(db_conn=db_conn) - # Setup Association objects for assoc_ptr in SlurmList.iter_and_pop(assoc_data): assoc = Association.from_ptr(assoc_ptr.data) assoc.qos_data = qos_data @@ -90,14 +117,16 @@ cdef class Associations(MultiClusterMap): return out @staticmethod - def modify(db_filter, Association changes, Connection db_connection=None): + def modify(Connection db_conn, db_filter, Association changes): cdef: AssociationFilter afilter - Connection conn SlurmList response SlurmListItem response_ptr list out = [] + db_conn.validate() + + # TODO: make db_filter optional? # Prepare SQL Filter if isinstance(db_filter, Associations): assoc_ids = [ass.id for ass in db_filter] @@ -106,18 +135,13 @@ cdef class Associations(MultiClusterMap): afilter = db_filter afilter._create() - # Setup DB conn - conn = _open_conn_or_error(db_connection) - # Any data that isn't parsed yet or needs validation is done in this # function. - _create_assoc_ptr(changes, conn) + _create_assoc_ptr(changes, db_conn) - # Modify associations, get the result - # This returns a List of char* with the associations that were - # modified + # Returns a List of char* with the associations that were modified response = SlurmList.wrap(slurmdb_associations_modify( - conn.ptr, afilter.ptr, changes.ptr)) + db_conn.ptr, afilter.ptr, changes.ptr)) if not response.is_null and response.cnt: for response_ptr in response: @@ -130,17 +154,66 @@ cdef class Associations(MultiClusterMap): elif not response.is_null: # There was no real error, but simply nothing has been modified - raise RPCError(msg="Nothing was modified") + return None else: # Autodetects the last slurm error raise RPCError() - if not db_connection: - # Autocommit if no connection was explicitly specified. - conn.commit() - return out + @staticmethod + def create(Connection db_conn, associations, auto_add=True): + cdef: + Association assoc + AssociationList assoc_list = AssociationList(owned=False) + + if not associations: + return + + conn.validate() + + for i, assoc in enumerate(associations): + # Make sure to remove any duplicate associations, i.e. associations + # having the same account name set. For some reason, the slurmdbd + # doesn't like that. + if assoc not in assoc_list: + assoc_list.append(assoc) + + verify_rpc(slurmdb_associations_add(db_conn.ptr, assoc_list.info)) + + def delete(self, Connection db_conn): + cdef: + AssociationFilter afilter + SlurmList response + SlurmListItem response_ptr + + db_conn.validate() + + ids = [assoc.id for assoc in self.values()] + if not ids: + return + + a_filter = AssociationFilter(ids=ids) + a_filter._create() + + response = SlurmList.wrap(slurmdb_associations_remove(db_conn.ptr, + a_filter.ptr)) + rc = slurm_errno() + + if rc == slurm.SLURM_SUCCESS or rc == slurm.SLURM_NO_CHANGE_IN_DATA: + return + + #if rc == slurm.ESLURM_ACCESS_DENIED or response.is_null: + # verify_rpc(rc) + + # Handle the error cases. + if rc == slurm.ESLURM_JOBS_RUNNING_ON_ASSOC: + raise JobsRunningError.from_response(response, rc) + elif rc == slurm.ESLURM_NO_REMOVE_DEFAULT_ACCOUNT: + raise DefaultAccountError.from_response(response, rc) + else: + verify_rpc(rc) + cdef class AssociationFilter: @@ -174,12 +247,25 @@ cdef class AssociationFilter: cdef slurmdb_assoc_cond_t *ptr = self.ptr make_char_list(&ptr.user_list, self.users) + make_char_list(&ptr.id_list, self.ids) + make_char_list(&ptr.acct_list, self.accounts) + make_char_list(&ptr.parent_acct_list, self.parent_accounts) + make_char_list(&ptr.cluster_list, self.clusters) + make_char_list(&ptr.partition_list, self.partitions) + # TODO: These should be QOS ids, not names + make_char_list(&ptr.qos_list, self.qos) + # TODO: ASSOC_COND_FLAGS + + +cdef class AssociationLimits: + pass cdef class Association: def __cinit__(self): self.ptr = NULL + self.owned = True def __init__(self, **kwargs): self._alloc_impl() @@ -189,7 +275,8 @@ cdef class Association: setattr(self, k, v) def __dealloc__(self): - self._dealloc_impl() + if self.owned: + self._dealloc_impl() def _dealloc_impl(self): slurmdb_destroy_assoc_rec(self.ptr) @@ -223,7 +310,8 @@ cdef class Association: def __eq__(self, other): if isinstance(other, Association): - return self.id == other.id and self.cluster == other.cluster +# return self.id == other.id and self.cluster == other.cluster + return self.cluster == other.cluster and self.partition == other.partition and self.account == other.account and self.user == other.user return NotImplemented @property @@ -254,21 +342,8 @@ cdef class Association: # uint16_t flags (ASSOC_FLAG_*) - @property - def group_jobs(self): - return u32_parse(self.ptr.grp_jobs, zero_is_noval=False) - - @group_jobs.setter - def group_jobs(self, val): - self.ptr.grp_jobs = u32(val, zero_is_noval=False) - - @property - def group_jobs_accrue(self): - return u32_parse(self.ptr.grp_jobs_accrue, zero_is_noval=False) - @group_jobs_accrue.setter def group_jobs_accrue(self, val): - self.ptr.grp_jobs_accrue = u32(val, zero_is_noval=False) @property def group_submit_jobs(self): @@ -346,6 +421,10 @@ cdef class Association: def parent_account_id(self): return u32_parse(self.ptr.parent_id, zero_is_noval=False) + @property + def lineage(self): + return cstr.to_unicode(self.ptr.lineage) + @property def partition(self): return cstr.to_unicode(self.ptr.partition) @@ -378,12 +457,18 @@ cdef class Association: def user(self, val): cstr.fmalloc(&self.ptr.user, val) + @property + def user_id(self): + return u32_parse(self.ptr.uid, zero_is_noval=False) + cdef _parse_assoc_ptr(Association ass): cdef: TrackableResources tres = ass.tres_data QualitiesOfService qos = ass.qos_data + policy = ass.policy + ass.group_tres = TrackableResources.from_cstr( ass.ptr.grp_tres, tres) ass.group_tres_mins = TrackableResources.from_cstr( @@ -400,12 +485,16 @@ cdef _parse_assoc_ptr(Association ass): ass.ptr.max_tres_pn, tres) ass.qos = qos_list_to_pylist(ass.ptr.qos_list, qos) + policy.group_jobs = u32_parse(ass.ptr.grp_jobs, zero_is_noval=False) + policy.group_jobs_accrue = u32_parse(ass.ptr.grp_jobs_accrue, zero_is_noval=False) + # TODO: default_qos + cdef _create_assoc_ptr(Association ass, conn=None): # _set_tres_limits will also check if specified TRES are valid and # translate them to its ID which is why we need to load the current TRES # available in the system. - ass.tres_data = TrackableResources.load(db_connection=conn) + ass.tres_data = TrackableResources.load(db_conn=conn) _set_tres_limits(&ass.ptr.grp_tres, ass.group_tres, ass.tres_data) _set_tres_limits(&ass.ptr.grp_tres_mins, ass.group_tres_mins, ass.tres_data) @@ -423,6 +512,9 @@ cdef _create_assoc_ptr(Association ass, conn=None): # _set_qos_list will also check if specified QoS are valid and translate # them to its ID, which is why we need to load the current QOS available # in the system. - ass.qos_data = QualitiesOfService.load(db_connection=conn) + ass.qos_data = QualitiesOfService.load(db_conn=conn) _set_qos_list(&ass.ptr.qos_list, self.qos, ass.qos_data) + ass.ptr.group_jobs = u32(ass.policy.group_jobs, zero_is_noval=False) + ass.ptr.group_jobs_accrue = u32(ass.policy.group_jobs_accrue, zero_is_noval=False) + diff --git a/pyslurm/db/connection.pyx b/pyslurm/db/connection.pyx index 9e1a4428..1e1833d8 100644 --- a/pyslurm/db/connection.pyx +++ b/pyslurm/db/connection.pyx @@ -22,17 +22,22 @@ # cython: c_string_type=unicode, c_string_encoding=default # cython: language_level=3 -from pyslurm.core.error import RPCError +from pyslurm.core.error import RPCError, PyslurmError +from contextlib import contextmanager -def _open_conn_or_error(conn): - if not conn: - conn = Connection.open() +class InvalidConnectionError(PyslurmError): + pass - if not conn.is_open: - raise ValueError("Database connection is not open") - return conn +@contextmanager +def connect(): + """A managed Slurm DB Connection""" + connection = Connection.open() + try: + yield connection + finally: + connection.close() cdef class Connection: @@ -52,6 +57,10 @@ cdef class Connection: state = "open" if self.is_open else "closed" return f'pyslurm.db.{self.__class__.__name__} is {state}' + def validate(self): + if not self.is_open: + raise InvalidConnectionError("Connection is closed") + @staticmethod def open(): """Open a new connection to the slurmdbd @@ -92,11 +101,17 @@ cdef class Connection: def commit(self): """Commit recent changes.""" + if not self.is_open: + return + if slurmdb_connection_commit(self.ptr, 1) == slurm.SLURM_ERROR: raise RPCError("Failed to commit database changes.") def rollback(self): """Rollback recent changes.""" + if not self.is_open: + return + if slurmdb_connection_commit(self.ptr, 0) == slurm.SLURM_ERROR: raise RPCError("Failed to rollback database changes.") diff --git a/pyslurm/db/error.pyx b/pyslurm/db/error.pyx new file mode 100644 index 00000000..8c15cd5f --- /dev/null +++ b/pyslurm/db/error.pyx @@ -0,0 +1,150 @@ +######################################################################### +# error.pyx - pyslurm db specific errors +######################################################################### +# Copyright (C) 2025 Toni Harzendorf +# +# This file is part of PySlurm +# +# PySlurm is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 2 of the License, or +# (at your option) any later version. + +# PySlurm is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License along +# with PySlurm; if not, write to the Free Software Foundation, Inc., +# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. +# +# cython: c_string_type=unicode, c_string_encoding=default +# cython: language_level=3 + +from pyslurm.core.error import RPCError, slurm_errno, verify_rpc +from pyslurm.db.util cimport SlurmList, SlurmListItem + + +class AssociationChangeInfo: + + def __init__(self, user, cluster, account, partition=None): + self.user = user + self.cluster = cluster + self.account = account + self.partition = partition + self.running_jobs = [] + + +class JobsRunningError(RPCError): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.associations = [] + + @staticmethod + def from_response(SlurmList response, rc): + running_jobs = parse_running_job_errors(response) + err = JobsRunningError(errno=rc) + err.associations = list(running_jobs.values()) + return err + + +class DefaultAccountError(RPCError): + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.associations = [] + + @staticmethod + def from_response(SlurmList response, rc): + err = DefaultAccountError(errno=rc) + err.associations = parse_default_account_errors(response) + return err + + +def parse_default_account_errors(SlurmList response): + cdef SlurmListItem response_ptr + + assocs = [] + for response_ptr in response: + response_str = response_ptr.to_str() + if not response_str: + continue + + # The response is a string in the following form: + # C = X A = X U = X P = X + # + # And we have to parse this stuff... The Partition (P) is optional + resp = response_str.rstrip().lstrip() + splitted = resp.split(" ") + values = [] + for item in splitted: + if not item: + continue + + key, value = item.split("=") + values.append(value.strip()) + + info = AssociationChangeInfo( + cluster = values[0], + account = values[1], + user = values[2], + ) + assoc_str = f"{info.cluster}-{info.account}-{info.user}" + + if len(values) > 3: + info.partition = values[3] + assoc_str = f"{assoc_str}-{info.partition}" + + assocs.append(info) + + return assocs + + +def parse_running_job_errors(SlurmList response): + cdef SlurmListItem response_ptr + + running_jobs_for_assoc = {} + for response_ptr in response: + response_str = response_ptr.to_str() + if not response_str: + continue + + # The response is a string in the following form: + # JobId = X C = X A = X U = X P = X + # + # And we have to parse this stuff... The Partition (P) is optional + resp = response_str.rstrip().lstrip() + splitted = resp.split(" ") + values = [] + for item in splitted: + if not item: + continue + + key, value = item.split("=") + values.append(value.strip()) + + job_id = int(values[0]) + cluster = values[1] + account = values[2] + user = values[3] + partition = None + assoc_str = f"{cluster}-{account}-{user}" + + if len(values) > 4: + partition = values[4] + assoc_str = f"{assoc_str}-{partition}" + + if assoc_str not in running_jobs_for_assoc: + info = AssociationChangeInfo( + user = user, + cluster = cluster, + account = account, + partition = partition, + ) + running_jobs_for_assoc[assoc_str] = info + + running_jobs_for_assoc[assoc_str].running_jobs.append(job_id) + + return running_jobs_for_assoc diff --git a/pyslurm/db/job.pyx b/pyslurm/db/job.pyx index 3d580356..99e2bef4 100644 --- a/pyslurm/db/job.pyx +++ b/pyslurm/db/job.pyx @@ -42,7 +42,6 @@ from pyslurm.utils.helpers import ( _get_exit_code, gres_from_tres_dict, ) -from pyslurm.db.connection import _open_conn_or_error from pyslurm.enums import SchedulerType @@ -202,7 +201,7 @@ cdef class Jobs(MultiClusterMap): self._reset_stats() @staticmethod - def load(JobFilter db_filter=None, Connection db_connection=None): + def load(Connection db_conn, JobFilter db_filter=None): """Load Jobs from the Slurm Database Implements the slurmdb_jobs_get RPC. @@ -211,9 +210,8 @@ cdef class Jobs(MultiClusterMap): db_filter (pyslurm.db.JobFilter): A search filter that the slurmdbd will apply when retrieving Jobs from the database. - db_connection (pyslurm.db.Connection): - An open database connection. By default if none is specified, - one will be opened automatically. + db_conn (pyslurm.db.Connection): + An open database connection. Returns: (pyslurm.db.Jobs): A Collection of database Jobs. @@ -247,28 +245,26 @@ cdef class Jobs(MultiClusterMap): JobFilter cond = db_filter SlurmList job_data SlurmListItem job_ptr - Connection conn QualitiesOfService qos_data TrackableResources tres_data + db_conn.validate() + # Prepare SQL Filter if not db_filter: cond = JobFilter() cond._create() - # Setup DB Conn - conn = _open_conn_or_error(db_connection) - # Fetch Job data - job_data = SlurmList.wrap(slurmdb_jobs_get(conn.ptr, cond.ptr)) + job_data = SlurmList.wrap(slurmdb_jobs_get(db_conn.ptr, cond.ptr)) if job_data.is_null: raise RPCError(msg="Failed to get Jobs from slurmdbd") # Fetch other necessary dependencies needed for translating some # attributes (i.e QoS IDs to its name) - qos_data = QualitiesOfService.load(db_connection=conn, + qos_data = QualitiesOfService.load(db_conn=db_conn, name_is_key=False) - tres_data = TrackableResources.load(db_connection=conn) + tres_data = TrackableResources.load(db_conn=db_conn) # TODO: How to handle the possibility of duplicate job ids that could # appear if IDs on a cluster are reset? @@ -311,7 +307,7 @@ cdef class Jobs(MultiClusterMap): self._add_stats(job) @staticmethod - def modify(db_filter, Job changes, db_connection=None): + def modify(Connection db_conn, db_filter, Job changes): """Modify Slurm database Jobs. Implements the slurm_job_modify RPC. @@ -324,22 +320,8 @@ cdef class Jobs(MultiClusterMap): changes to apply. Check the `Other Parameters` of the [pyslurm.db.Job][] class to see which properties can be modified. - db_connection (pyslurm.db.Connection): - A Connection to the slurmdbd. By default, if no connection is - supplied, one will automatically be created internally. This - means that when the changes were considered successful by the - slurmdbd, those modifications will be **automatically - committed**. - - If you however decide to provide your own Connection instance - (which must be already opened before), and the changes were - successful, they will basically be in a kind of "staging - area". By the time this function returns, the changes are not - actually made. - You are then responsible to decide whether the changes should - be committed or rolled back by using the respective methods on - the connection object. This way, you have a chance to see - which Jobs were modified before you commit the changes. + db_conn (pyslurm.db.Connection): + A Connection to the slurmdbd. Returns: (list[int]): A list of Jobs that were modified @@ -387,11 +369,12 @@ cdef class Jobs(MultiClusterMap): """ cdef: JobFilter cond - Connection conn SlurmList response SlurmListItem response_ptr list out = [] + db_conn.validate() + # Prepare SQL Filter if isinstance(db_filter, Jobs): job_ids = [job.id for job in self] @@ -400,14 +383,11 @@ cdef class Jobs(MultiClusterMap): cond = db_filter cond._create() - # Setup DB Conn - conn = _open_conn_or_error(db_connection) - # Modify Jobs, get the result # This returns a List of char* with the Jobs ids that were # modified response = SlurmList.wrap( - slurmdb_job_modify(conn.ptr, cond.ptr, changes.ptr)) + slurmdb_job_modify(db_conn.ptr, cond.ptr, changes.ptr)) if not response.is_null and response.cnt: for response_ptr in response: @@ -432,10 +412,6 @@ cdef class Jobs(MultiClusterMap): # Autodetects the last slurm error raise RPCError() - if not db_connection: - # Autocommit if no connection was explicitly specified. - conn.commit() - return out @@ -551,7 +527,7 @@ cdef class Job: def __repr__(self): return f'pyslurm.db.{self.__class__.__name__}({self.id})' - def modify(self, changes, db_connection=None): + def modify(self, Connection db_conn, changes): """Modify a Slurm database Job. Args: @@ -560,7 +536,7 @@ cdef class Job: changes to apply. Check the `Other Parameters` of the [pyslurm.db.Job][] class to see which properties can be modified. - db_connection (pyslurm.db.Connection): + db_conn (pyslurm.db.Connection): A slurmdbd connection. See [pyslurm.db.Jobs.modify][pyslurm.db.job.Jobs.modify] for more info on this parameter. @@ -569,7 +545,7 @@ cdef class Job: (pyslurm.RPCError): When modifying the Job failed. """ cdef JobFilter jfilter = JobFilter(ids=[self.id]) - Jobs.modify(jfilter, changes, db_connection) + Jobs.modify(db_conn, jfilter, changes) @property def account(self): diff --git a/pyslurm/db/qos.pyx b/pyslurm/db/qos.pyx index 9e24189b..f896eb86 100644 --- a/pyslurm/db/qos.pyx +++ b/pyslurm/db/qos.pyx @@ -24,7 +24,6 @@ from pyslurm.core.error import RPCError from pyslurm.utils.helpers import instance_to_dict -from pyslurm.db.connection import _open_conn_or_error cdef class QualitiesOfService(dict): @@ -33,8 +32,8 @@ cdef class QualitiesOfService(dict): pass @staticmethod - def load(QualityOfServiceFilter db_filter=None, - Connection db_connection=None, name_is_key=True): + def load(Connection db_conn, QualityOfServiceFilter db_filter=None, + name_is_key=True): """Load QoS data from the Database Args: @@ -49,18 +48,14 @@ cdef class QualitiesOfService(dict): QualityOfServiceFilter cond = db_filter SlurmList qos_data SlurmListItem qos_ptr - Connection conn - # Prepare SQL Filter + db_conn.validate() + if not db_filter: cond = QualityOfServiceFilter() cond._create() - # Setup DB Conn - conn = _open_conn_or_error(db_connection) - - # Fetch QoS Data - qos_data = SlurmList.wrap(slurmdb_qos_get(conn.ptr, cond.ptr)) + qos_data = SlurmList.wrap(slurmdb_qos_get(db_conn.ptr, cond.ptr)) if qos_data.is_null: raise RPCError(msg="Failed to get QoS data from slurmdbd") diff --git a/pyslurm/db/tres.pyx b/pyslurm/db/tres.pyx index 598aee91..12b90fa6 100644 --- a/pyslurm/db/tres.pyx +++ b/pyslurm/db/tres.pyx @@ -28,7 +28,6 @@ from pyslurm.constants import UNLIMITED from pyslurm.core.error import RPCError from pyslurm.utils.helpers import instance_to_dict, dehumanize from pyslurm.utils import cstr -from pyslurm.db.connection import _open_conn_or_error from pyslurm import xcollections import json import re @@ -292,28 +291,21 @@ cdef class TrackableResources: elif hasattr(self, tres.type): setattr(self, tres.type, tres) elif tres.type: - print(tres.type, tres.name, tres.type_and_name) self.other[tres.type_and_name] = tres @staticmethod - def load(Connection db_connection=None): + def load(Connection db_conn): """Load Trackable Resources from the Database.""" cdef: TrackableResources out = TrackableResources() TrackableResource tres - Connection conn SlurmList tres_data SlurmListItem tres_ptr TrackableResourceFilter db_filter = TrackableResourceFilter() - # Prepare SQL Filter + db_conn.validate() db_filter._create() - - # Setup DB Conn - conn = _open_conn_or_error(db_connection) - - # Fetch TRES data - tres_data = SlurmList.wrap(slurmdb_tres_get(conn.ptr, db_filter.ptr)) + tres_data = SlurmList.wrap(slurmdb_tres_get(db_conn.ptr, db_filter.ptr)) if tres_data.is_null: raise RPCError(msg="Failed to get TRES data from slurmdbd") diff --git a/pyslurm/db/user.pxd b/pyslurm/db/user.pxd new file mode 100644 index 00000000..20dbd3c8 --- /dev/null +++ b/pyslurm/db/user.pxd @@ -0,0 +1,110 @@ +######################################################################### +# user.pxd - pyslurm slurmdbd user api +######################################################################### +# Copyright (C) 2025 Toni Harzendorf +# +# This file is part of PySlurm +# +# PySlurm is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 2 of the License, or +# (at your option) any later version. + +# PySlurm is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License along +# with PySlurm; if not, write to the Free Software Foundation, Inc., +# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. +# +# cython: c_string_type=unicode, c_string_encoding=default +# cython: language_level=3 + +from libc.string cimport memcpy, memset +from pyslurm cimport slurm +from pyslurm.slurm cimport ( + slurmdb_user_rec_t, + slurmdb_assoc_rec_t, + slurmdb_assoc_cond_t, + slurmdb_user_cond_t, + slurmdb_users_get, + slurmdb_users_add, + slurmdb_users_modify, + slurmdb_users_remove, + slurmdb_destroy_user_rec, + slurmdb_destroy_user_cond, + try_xmalloc, +) +from pyslurm.db.util cimport ( + SlurmList, + SlurmListItem, + make_char_list, + slurm_list_to_pylist, + qos_list_to_pylist, +) +from pyslurm.db.tres cimport ( + _set_tres_limits, + TrackableResources, +) +from pyslurm.db.connection cimport Connection +from pyslurm.utils cimport cstr +from pyslurm.db.qos cimport QualitiesOfService, _set_qos_list +from pyslurm.db.assoc cimport Associations, Association, _parse_assoc_ptr, AssociationFilter, AssociationList +from pyslurm.xcollections cimport MultiClusterMap +from pyslurm.utils.uint cimport u16_set_bool_flag + + +cdef class Users(dict): + pass + + +# cdef class UserAddRequest: +# cdef slurmdb_add_assoc_cond_t *ptr + + +cdef class UserFilter: + cdef slurmdb_user_cond_t *ptr + + cdef public: + names + with_assocs + with_coordinators + with_wckeys + with_deleted + associations + + +cdef class User: + """Slurm Database User + + Attributes: + name (str): + The name of the User. + previous_name (str): + Previous name of the User, in case it was modified before. + user_id (int): + UID of the User. + default_account (str): + Default Account of the User. + default_wckey (str): + Default WCKey for the User. + is_deleted (bool): + Whether this User has been deleted or not. + admin_level (pyslurm.AdminLevel): + Admin Level of the User. + """ + cdef: + slurmdb_user_rec_t *ptr + + cdef public: + associations + coordinators + wckeys + + cdef readonly: + default_association + + @staticmethod + cdef User from_ptr(slurmdb_user_rec_t *in_ptr) diff --git a/pyslurm/db/user.pyx b/pyslurm/db/user.pyx new file mode 100644 index 00000000..20c9aebe --- /dev/null +++ b/pyslurm/db/user.pyx @@ -0,0 +1,335 @@ +######################################################################### +# user.pyx - pyslurm slurmdbd user api +######################################################################### +# Copyright (C) 2023 Toni Harzendorf +# +# This file is part of PySlurm +# +# PySlurm is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 2 of the License, or +# (at your option) any later version. + +# PySlurm is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License along +# with PySlurm; if not, write to the Free Software Foundation, Inc., +# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. +# +# cython: c_string_type=unicode, c_string_encoding=default +# cython: language_level=3 + +from pyslurm.core.error import RPCError, slurm_errno, verify_rpc +from pyslurm.utils.helpers import ( + instance_to_dict, + user_to_uid, +) +from pyslurm.utils.uint import * +from pyslurm import xcollections +from pyslurm.utils.enums import SlurmEnum +from pyslurm.db.error import JobsRunningError + + +class AdminLevel(SlurmEnum): + UNDEFINED = "UNDEFINED", slurm.SLURMDB_ADMIN_NOTSET + NONE = "NONE", slurm.SLURMDB_ADMIN_NONE + OPERATOR = "OPERATOR", slurm.SLURMDB_ADMIN_OPERATOR + ADMINISTRATOR = "ADMINISTRATOR", slurm.SLURMDB_ADMIN_SUPER_USER + + +cdef class Users(dict): + + def __init__(self, **kwargs): + super().__init__(kwargs) + + @staticmethod + def load(Connection db_conn, UserFilter db_filter=None): + cdef: + Users out = Users() + User user + UserFilter cond = db_filter + SlurmList user_data + SlurmListItem user_ptr + SlurmList assoc_data + SlurmListItem assoc_ptr + Association assoc + QualitiesOfService qos_data + TrackableResources tres_data + + db_conn.validate() + + if not db_filter: + cond = UserFilter() + + if cond.with_assocs is not False: + # If not explicitly disabled, always fetch the Associations of a + # User. + cond.with_assocs = True + cond._create() + + user_data = SlurmList.wrap(slurmdb_users_get(db_conn.ptr, cond.ptr)) + + if user_data.is_null: + raise RPCError(msg="Failed to get User data from slurmdbd") + + qos_data = QualitiesOfService.load(db_conn=db_conn, + name_is_key=False) + tres_data = TrackableResources.load(db_conn=db_conn) + + for user_ptr in SlurmList.iter_and_pop(user_data): + user = User.from_ptr(user_ptr.data) + out[user.name] = user + + assoc_data = SlurmList.wrap(user.ptr.assoc_list, owned=False) + for assoc_ptr in SlurmList.iter_and_pop(assoc_data): + assoc = Association.from_ptr(assoc_ptr.data) + assoc.qos_data = qos_data + assoc.tres_data = tres_data + _parse_assoc_ptr(assoc) + user.associations.append(assoc) + + if assoc.user == user.name: + user.default_association = assoc + + return out + + def delete(self, Connection db_conn): + cdef: + UserFilter u_filter + SlurmList response + SlurmListItem response_ptr + + db_conn.validate() + + names = list(self.keys()) + if not names: + return + + u_filter = UserFilter(names=names) + u_filter._create() + + response = SlurmList.wrap(slurmdb_users_remove(db_conn.ptr, u_filter.ptr)) + rc = slurm_errno() + + if rc == slurm.SLURM_SUCCESS or rc == slurm.SLURM_NO_CHANGE_IN_DATA: + return + + #if rc == slurm.ESLURM_ACCESS_DENIED or response.is_null: + # verify_rpc(rc) + + # Handle the error case. Running Jobs should be the only possible error + # where slurmdbd sends a response list. + if rc == slurm.ESLURM_JOBS_RUNNING_ON_ASSOC: + raise JobsRunningError.from_response(response, rc) + else: + verify_rpc(rc) + + def modify(self, Connection db_conn, User changes): + cdef: + UserFilter u_filter + AssociationFilter a_filter + SlurmList response + SlurmListItem response_ptr + list out = [] + + db_conn.validate() + + u_filter = UserFilter(names=list(self.keys())) +# a_filter = AssociationFilter() + + # u_filter.ptr.assoc_cond = a_filter.ptr + u_filter._create() + response = SlurmList.wrap(slurmdb_users_modify( + db_conn.ptr, u_filter.ptr, changes.ptr)) + + if not response.is_null and response.cnt: + for response_ptr in response: + response_str = cstr.to_unicode(response_ptr.data) + if not response_str: + continue + + out.append(response_str) + + elif not response.is_null: + # There was no real error, but simply nothing has been modified + return out + else: + # Autodetects the last slurm error + raise RPCError(msg="Failed to modify users.") + + return out + + @staticmethod + def create(Connection db_conn, users): + cdef: + User user + SlurmList user_list + list assocs_to_add = [] + + db_conn.validate() + user_list = SlurmList.create(slurmdb_destroy_user_rec, owned=False) + + for user in users: + assocs_to_add.extend(user.associations) + slurm.slurm_list_append(user_list.info, user.ptr) + + verify_rpc(slurmdb_users_add(db_conn.ptr, user_list.info)) + Associations.create(db_conn, assocs_to_add) + + +cdef class UserFilter: + + def __cinit__(self): + self.ptr = NULL + + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + + def __dealloc__(self): + self._dealloc() + + def _dealloc(self): + slurmdb_destroy_user_cond(self.ptr) + self.ptr = NULL + + def _alloc(self): + self._dealloc() + self.ptr = try_xmalloc(sizeof(slurmdb_user_cond_t)) + if not self.ptr: + raise MemoryError("xmalloc failed for slurmdb_user_cond_t") + + memset(self.ptr, 0, sizeof(slurmdb_user_cond_t)) + + self.ptr.assoc_cond = try_xmalloc(sizeof(slurmdb_assoc_cond_t)) + if not self.ptr.assoc_cond: + raise MemoryError("xmalloc failed for slurmdb_assoc_cond_t") + + + def _create(self): + self._alloc() + cdef slurmdb_user_cond_t *ptr = self.ptr + + make_char_list(&ptr.assoc_cond.user_list, self.names) + ptr.with_assocs = 1 if self.with_assocs else 0 + ptr.with_coords = 1 if self.with_coordinators else 0 + ptr.with_wckeys = 1 if self.with_wckeys else 0 + ptr.with_deleted = 1 if self.with_deleted else 0 + + +cdef class User: + + def __cinit__(self): + self.ptr = NULL + + def __init__(self, name=None, **kwargs): + self._alloc_impl() + self.name = name + self._init_defaults() + for k, v in kwargs.items(): + setattr(self, k, v) + + def _init_defaults(self): + self.associations = [] + self.coordinators = [] + self.default_association = None + self.wckeys = [] + + def __dealloc__(self): + self._dealloc_impl() + + def _dealloc_impl(self): + slurmdb_destroy_user_rec(self.ptr) + self.ptr = NULL + + def _alloc_impl(self): + if not self.ptr: + self.ptr = try_xmalloc( + sizeof(slurmdb_user_rec_t)) + if not self.ptr: + raise MemoryError("xmalloc failed for slurmdb_user_rec_t") + + memset(self.ptr, 0, sizeof(slurmdb_user_rec_t)) + self.ptr.uid = slurm.NO_VAL + + def __repr__(self): + return f'pyslurm.db.{self.__class__.__name__}({self.name})' + + @staticmethod + cdef User from_ptr(slurmdb_user_rec_t *in_ptr): + cdef User wrap = User.__new__(User) + wrap.ptr = in_ptr + wrap._init_defaults() + return wrap + + def to_dict(self): + """Database User information formatted as a dictionary. + + Returns: + (dict): Database User information as dict. + """ + return instance_to_dict(self) + + def __eq__(self, other): + if isinstance(other, User): + return self.name == other.name + return NotImplemented + + @staticmethod + def load(Connection db_conn, name): + user = Users.load(db_conn=db_conn).get(name) + if not user: + raise RPCError(msg=f"User {name} does not exist.") + + return user + + def create(self, Connection db_conn): + Users.create(db_conn, [self]) + + def delete(self, Connection db_conn): + Users({self.name: self}).delete(db_conn) + + def modify(self, Connection db_conn): + Users({self.name: self}).modify(self, db_conn) + + @property + def name(self): + return cstr.to_unicode(self.ptr.name) + + @name.setter + def name(self, val): + cstr.fmalloc(&self.ptr.name, val) + + @property + def previous_name(self): + return cstr.to_unicode(self.ptr.old_name) + + @property + def user_id(self): + return u32_parse(self.ptr.uid, zero_is_noval=False) + + @property + def default_account(self): + return cstr.to_unicode(self.ptr.default_acct) + + @property + def default_wckey(self): + return cstr.to_unicode(self.ptr.default_wckey) + + @property + def is_deleted(self): + if self.ptr.flags & slurm.SLURMDB_USER_FLAG_DELETED: + return True + return False + + @property + def admin_level(self): + return AdminLevel.from_flag(self.ptr.admin_level, + default=AdminLevel.UNDEFINED) + + @admin_level.setter + def admin_level(self, val): + self.ptr.admin_level = AdminLevel(val)._flag diff --git a/pyslurm/db/wckey.pxd b/pyslurm/db/wckey.pxd new file mode 100644 index 00000000..37680fba --- /dev/null +++ b/pyslurm/db/wckey.pxd @@ -0,0 +1,71 @@ +######################################################################### +# wckey.pxd - pyslurm slurmdbd wckey api +######################################################################### +# Copyright (C) 2025 Toni Harzendorf +# +# This file is part of PySlurm +# +# PySlurm is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 2 of the License, or +# (at your option) any later version. + +# PySlurm is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License along +# with PySlurm; if not, write to the Free Software Foundation, Inc., +# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. +# +# cython: c_string_type=unicode, c_string_encoding=default +# cython: language_level=3 + +from libc.string cimport memcpy, memset +from pyslurm cimport slurm +from pyslurm.slurm cimport ( + slurmdb_wckey_rec_t, + slurmdb_wckey_cond_t, + slurmdb_wckeys_get, + slurmdb_destroy_wckey_rec, + slurmdb_destroy_wckey_cond, + try_xmalloc, +) +from pyslurm.db.util cimport ( + SlurmList, + SlurmListItem, + make_char_list, + slurm_list_to_pylist, + qos_list_to_pylist, +) +from pyslurm.db.tres cimport ( + _set_tres_limits, + TrackableResources, +) +from pyslurm.db.connection cimport Connection +from pyslurm.utils cimport cstr +from pyslurm.db.qos cimport QualitiesOfService, _set_qos_list +from pyslurm.db.assoc cimport Associations, Association, _parse_assoc_ptr +from pyslurm.xcollections cimport MultiClusterMap +from pyslurm.utils.uint cimport u16_set_bool_flag + + +cdef class WCKeys(MultiClusterMap): + pass + + +cdef class WCKeyFilter: + cdef slurmdb_wckey_cond_t *ptr + + cdef public: + names + + +cdef class WCKey: + cdef: + slurmdb_wckey_rec_t *ptr + _cluster + + @staticmethod + cdef WCKey from_ptr(slurmdb_wckey_rec_t *in_ptr) diff --git a/pyslurm/db/wckey.pyx b/pyslurm/db/wckey.pyx new file mode 100644 index 00000000..c072cb2b --- /dev/null +++ b/pyslurm/db/wckey.pyx @@ -0,0 +1,191 @@ +######################################################################### +# wckey.pyx - pyslurm slurmdbd wckey api +######################################################################### +# Copyright (C) 2025 Toni Harzendorf +# +# This file is part of PySlurm +# +# PySlurm is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 2 of the License, or +# (at your option) any later version. + +# PySlurm is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License along +# with PySlurm; if not, write to the Free Software Foundation, Inc., +# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. +# +# cython: c_string_type=unicode, c_string_encoding=default +# cython: language_level=3 + +from pyslurm.core.error import RPCError +from pyslurm.utils.helpers import ( + instance_to_dict, + user_to_uid, +) +from pyslurm.utils.uint import * +from pyslurm import settings +from pyslurm import xcollections + + +cdef class WCKeys(MultiClusterMap): + + def __init__(self, wckeys=None): + super().__init__(data=wckeys, + typ="WCKeys", + val_type=WCKey, + id_attr=WCKey.name, + key_type=str) + + @staticmethod + def load(Connection db_conn, WCKeyFilter db_filter=None): + cdef: + WCKeys out = WCKeys() + WCKey wckey + WCKeyFilter cond = db_filter + SlurmList wckey_data + SlurmListItem wckey_ptr + + db_conn.validate() + + if not db_filter: + cond = WCKeyFilter() + cond._create() + + wckey_data = SlurmList.wrap(slurmdb_wckeys_get(db_conn.ptr, cond.ptr)) + + if wckey_data.is_null: + raise RPCError(msg="Failed to get WCKey data from slurmdbd.") + + for wckey_ptr in SlurmList.iter_and_pop(wckey_data): + wckey = WCKey.from_ptr(wckey_ptr.data) + + cluster = wckey.cluster + if cluster not in out.data: + out.data[cluster] = {} + out.data[cluster][wckey.name] = wckey + + return out + + +cdef class WCKeyFilter: + + def __cinit__(self): + self.ptr = NULL + + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + + def __dealloc__(self): + self._dealloc() + + def _dealloc(self): + slurmdb_destroy_wckey_cond(self.ptr) + self.ptr = NULL + + def _alloc(self): + self._dealloc() + self.ptr = try_xmalloc(sizeof(slurmdb_wckey_cond_t)) + if not self.ptr: + raise MemoryError("xmalloc failed for slurmdb_wckey_cond_t") + + def _create(self): + self._alloc() + cdef slurmdb_wckey_cond_t *ptr = self.ptr + + make_char_list(&ptr.name_list, self.names) + + +cdef class WCKey: + + def __cinit__(self): + self.ptr = NULL + + def __init__(self, name=None, **kwargs): + self._alloc_impl() + self.name = name + self._init_defaults() + for k, v in kwargs.items(): + setattr(self, k, v) + + def _init_defaults(self): + self._cluster = settings.LOCAL_CLUSTER + + def __dealloc__(self): + self._dealloc_impl() + + def _dealloc_impl(self): + slurmdb_destroy_wckey_rec(self.ptr) + self.ptr = NULL + + def _alloc_impl(self): + if not self.ptr: + self.ptr = try_xmalloc( + sizeof(slurmdb_wckey_rec_t)) + if not self.ptr: + raise MemoryError("xmalloc failed for slurmdb_wckey_rec_t") + + memset(self.ptr, 0, sizeof(slurmdb_wckey_rec_t)) + + def __repr__(self): + return f'pyslurm.db.{self.__class__.__name__}({self.name})' + + @staticmethod + cdef WCKey from_ptr(slurmdb_wckey_rec_t *in_ptr): + cdef WCKey wrap = WCKey.__new__(WCKey) + wrap.ptr = in_ptr + wrap._init_defaults() + return wrap + + def to_dict(self): + """Database WCKey information formatted as a dictionary. + + Returns: + (dict): Database WCKey information as dict. + """ + return instance_to_dict(self) + + def __eq__(self, other): + if isinstance(other, WCKey): + return self.id == other.id and self.cluster == other.cluster + return NotImplemented + + @property + def name(self): + return cstr.to_unicode(self.ptr.name) + + @property + def cluster(self): + cluster = cstr.to_unicode(self.ptr.cluster) + if not cluster: + return self._cluster + return cluster + + @property + def user_name(self): + return cstr.to_unicode(self.ptr.user) + + @property + def user_id(self): + return self.ptr.uid + + @property + def is_default(self): + return bool(self.ptr.is_def) + + @property + def id(self): + return self.ptr.id + + @property + def is_deleted(self): + if self.ptr.flags & slurm.SLURMDB_WCKEY_FLAG_DELETED: + return True + return False + + # TODO: list_t *accounting_list From caf58df11730fbcb93fae4c3f293a62b7125a9b1 Mon Sep 17 00:00:00 2001 From: Toni Harzendorf Date: Sat, 21 Feb 2026 23:33:05 +0100 Subject: [PATCH 02/13] wip --- pyslurm/db/account.pyx | 48 ++++++--- pyslurm/db/assoc.pxd | 29 +++-- pyslurm/db/assoc.pyx | 138 ++++++++---------------- pyslurm/db/user.pyx | 42 +++++++- tests/integration/test_db_connection.py | 5 + 5 files changed, 138 insertions(+), 124 deletions(-) diff --git a/pyslurm/db/account.pyx b/pyslurm/db/account.pyx index 1f5eff48..2e5346d1 100644 --- a/pyslurm/db/account.pyx +++ b/pyslurm/db/account.pyx @@ -34,8 +34,10 @@ from pyslurm.db.error import DefaultAccountError, JobsRunningError cdef class Accounts(dict): - def __init__(self, accounts=None): + def __init__(self, accounts={}, **kwargs): super().__init__() + self.update(accounts) + self.update(kwargs) @staticmethod def load(Connection db_conn, AccountFilter db_filter=None): @@ -85,6 +87,7 @@ cdef class Accounts(dict): account.association = assoc else: # These must be User Associations. + # TODO: maybe rename to user_associations account.associations.append(assoc) return out @@ -100,17 +103,15 @@ cdef class Accounts(dict): account_list = SlurmList.create(slurmdb_destroy_account_rec, owned=False) for account in accounts: -# if not account.associations and add_assoc: -# # For convenience, we create the associations by default -# # automatically, just like sacctmgr. Can be disabled for more -# # control. -# assoc = Association(account=account.name) -# account.associations.append(assoc) - - assocs_to_add.extend(account.associations) + if not account.association: + account.association = Association(account=account.name) + + assocs_to_add.append(account.association) slurm.slurm_list_append(account_list.info, account.ptr) verify_rpc(slurmdb_accounts_add(db_conn.ptr, account_list.info)) + # TODO: Maybe don't create the associations automatically? And don't do + # any hidden stuff? Associations.create(db_conn, assocs_to_add) def delete(self, Connection db_conn): @@ -119,9 +120,15 @@ cdef class Accounts(dict): SlurmList response list out = [] + # Check is required because for some reason if the acct_cond doesn't + # contain any valid conditions, slurmdbd will delete all accounts. + names = list(self.keys()) + if not names: + return + db_conn.validate() - a_filter = AccountFilter(names=list(self.keys())) + a_filter = AccountFilter(names=names) a_filter._create() response = SlurmList.wrap(slurmdb_accounts_remove(db_conn.ptr, a_filter.ptr)) @@ -189,10 +196,13 @@ cdef class Account: def __cinit__(self): self.ptr = NULL - def __init__(self, name=None, **kwargs): + def __init__(self, name=None, description=None, organization=None, **kwargs): self._alloc_impl() - self.name = name self._init_defaults() + self.name = name + self.description = description or name + self.organization = organization or name + for k, v in kwargs.items(): setattr(self, k, v) @@ -240,9 +250,23 @@ cdef class Account: return self.name == other.name return NotImplemented + @staticmethod + def load(Connection db_conn, name): + account = Accounts.load(db_conn=db_conn).get(name) + if not account: + raise RPCError(msg=f"Account {name} does not exist.") + + return account + def create(self, Connection db_conn): Accounts.create(db_conn, [self]) + def delete(self, Connection db_conn): + Accounts({self.name: self}).delete(db_conn) + + def modify(self, Connection db_conn): + Accounts({self.name: self}).modify(self, db_conn) + @property def name(self): return cstr.to_unicode(self.ptr.name) diff --git a/pyslurm/db/assoc.pxd b/pyslurm/db/assoc.pxd index aac46f0c..f1d08397 100644 --- a/pyslurm/db/assoc.pxd +++ b/pyslurm/db/assoc.pxd @@ -73,21 +73,6 @@ cdef class AssociationFilter: qos -cdef class AssociationLimits: - - cdef public: - group_tres - group_tres_mins - group_tres_run_mins - max_tres_mins_per_job - max_tres_run_mins_per_user - max_tres_per_job - max_tres_per_node - qos - - group_jobs - - cdef class Association: cdef: slurmdb_assoc_rec_t *ptr @@ -97,6 +82,8 @@ cdef class Association: owned cdef public: + default_qos + group_tres group_tres_mins group_tres_run_mins @@ -105,7 +92,17 @@ cdef class Association: max_tres_per_job max_tres_per_node qos - default_qos + group_jobs + group_jobs_accrue + group_submit_jobs + group_wall_time + max_jobs + max_jobs_accrue + max_submit_jobs + max_wall_time_per_job + min_priority_threshold + priority + shares @staticmethod cdef Association from_ptr(slurmdb_assoc_rec_t *in_ptr) diff --git a/pyslurm/db/assoc.pyx b/pyslurm/db/assoc.pyx index 25265f26..4b573af9 100644 --- a/pyslurm/db/assoc.pyx +++ b/pyslurm/db/assoc.pyx @@ -170,13 +170,18 @@ cdef class Associations(MultiClusterMap): if not associations: return - conn.validate() + db_conn.validate() for i, assoc in enumerate(associations): # Make sure to remove any duplicate associations, i.e. associations # having the same account name set. For some reason, the slurmdbd # doesn't like that. if assoc not in assoc_list: +# print(assoc.user) +# print(assoc.account) +# print(assoc.is_default) +# print(assoc.parent_account) +# print(assoc.cluster) assoc_list.append(assoc) verify_rpc(slurmdb_associations_add(db_conn.ptr, assoc_list.info)) @@ -257,10 +262,6 @@ cdef class AssociationFilter: # TODO: ASSOC_COND_FLAGS -cdef class AssociationLimits: - pass - - cdef class Association: def __cinit__(self): @@ -270,6 +271,12 @@ cdef class Association: def __init__(self, **kwargs): self._alloc_impl() self.id = 0 + + # Only when an Account-Association is initialized, we default to + # "root" as the Parent Account. + user = kwargs.get("user") + self.parent_account = kwargs.pop("parent_account", + "root" if not user else None) self.cluster = settings.LOCAL_CLUSTER for k, v in kwargs.items(): setattr(self, k, v) @@ -338,29 +345,6 @@ cdef class Association: def comment(self, val): cstr.fmalloc(&self.ptr.comment, val) - # uint32_t def_qos_id - - # uint16_t flags (ASSOC_FLAG_*) - - @group_jobs_accrue.setter - def group_jobs_accrue(self, val): - - @property - def group_submit_jobs(self): - return u32_parse(self.ptr.grp_submit_jobs, zero_is_noval=False) - - @group_submit_jobs.setter - def group_submit_jobs(self, val): - self.ptr.grp_submit_jobs = u32(val, zero_is_noval=False) - - @property - def group_wall_time(self): - return u32_parse(self.ptr.grp_wall, zero_is_noval=False) - - @group_wall_time.setter - def group_wall_time(self, val): - self.ptr.grp_wall = u32(val, zero_is_noval=False) - @property def id(self): return u32_parse(self.ptr.id) @@ -369,54 +353,27 @@ cdef class Association: def id(self, val): self.ptr.id = val - @property - def is_default(self): - return u16_parse_bool(self.ptr.is_def) - - @property - def max_jobs(self): - return u32_parse(self.ptr.max_jobs, zero_is_noval=False) - - @max_jobs.setter - def max_jobs(self, val): - self.ptr.max_jobs = u32(val, zero_is_noval=False) - @property - def max_jobs_accrue(self): - return u32_parse(self.ptr.max_jobs_accrue, zero_is_noval=False) - - @max_jobs_accrue.setter - def max_jobs_accrue(self, val): - self.ptr.max_jobs_accrue = u32(val, zero_is_noval=False) - - @property - def max_submit_jobs(self): - return u32_parse(self.ptr.max_submit_jobs, zero_is_noval=False) - - @max_submit_jobs.setter - def max_submit_jobs(self, val): - self.ptr.max_submit_jobs = u32(val, zero_is_noval=False) - - @property - def max_wall_time_per_job(self): - return u32_parse(self.ptr.max_wall_pj, zero_is_noval=False) + # uint32_t def_qos_id - @max_wall_time_per_job.setter - def max_wall_time_per_job(self, val): - self.ptr.max_wall_pj = u32(val, zero_is_noval=False) + # uint16_t flags (ASSOC_FLAG_*) @property - def min_priority_threshold(self): - return u32_parse(self.ptr.min_prio_thresh, zero_is_noval=False) + def is_default(self): + return u16_parse_bool(self.ptr.is_def) - @min_priority_threshold.setter - def min_priority_threshold(self, val): - self.ptr.min_prio_thresh = u32(val, zero_is_noval=False) + @is_default.setter + def is_default(self, val): + self.ptr.is_def = u16_bool(val) @property def parent_account(self): return cstr.to_unicode(self.ptr.parent_acct) + @parent_account.setter + def parent_account(self, val): + cstr.fmalloc(&self.ptr.parent_acct, val) + @property def parent_account_id(self): return u32_parse(self.ptr.parent_id, zero_is_noval=False) @@ -433,22 +390,6 @@ cdef class Association: def partition(self, val): cstr.fmalloc(&self.ptr.partition, val) - @property - def priority(self): - return u32_parse(self.ptr.priority, zero_is_noval=False) - - @priority.setter - def priority(self, val): - self.ptr.priority = u32(val) - - @property - def shares(self): - return u32_parse(self.ptr.shares_raw, zero_is_noval=False) - - @shares.setter - def shares(self, val): - self.ptr.shares_raw = u32(val) - @property def user(self): return cstr.to_unicode(self.ptr.user) @@ -467,8 +408,6 @@ cdef _parse_assoc_ptr(Association ass): TrackableResources tres = ass.tres_data QualitiesOfService qos = ass.qos_data - policy = ass.policy - ass.group_tres = TrackableResources.from_cstr( ass.ptr.grp_tres, tres) ass.group_tres_mins = TrackableResources.from_cstr( @@ -485,8 +424,17 @@ cdef _parse_assoc_ptr(Association ass): ass.ptr.max_tres_pn, tres) ass.qos = qos_list_to_pylist(ass.ptr.qos_list, qos) - policy.group_jobs = u32_parse(ass.ptr.grp_jobs, zero_is_noval=False) - policy.group_jobs_accrue = u32_parse(ass.ptr.grp_jobs_accrue, zero_is_noval=False) + ass.group_jobs = u32_parse(ass.ptr.grp_jobs, zero_is_noval=False) + ass.group_jobs_accrue = u32_parse(ass.ptr.grp_jobs_accrue, zero_is_noval=False) + ass.group_submit_jobs = u32_parse(ass.ptr.grp_submit_jobs, zero_is_noval=False) + ass.group_wall_time = u32_parse(ass.ptr.grp_wall, zero_is_noval=False) + ass.max_jobs = u32_parse(ass.ptr.max_jobs, zero_is_noval=False) + ass.max_jobs_accrue = u32_parse(ass.ptr.max_jobs_accrue, zero_is_noval=False) + ass.max_submit_jobs = u32_parse(ass.ptr.max_submit_jobs, zero_is_noval=False) + ass.max_wall_time_per_job = u32_parse(ass.ptr.max_wall_pj, zero_is_noval=False) + ass.min_priority_threshold = u32_parse(ass.ptr.min_prio_thresh, zero_is_noval=False) + ass.priority = u32_parse(ass.ptr.priority, zero_is_noval=False) + ass.shares = u32_parse(ass.ptr.shares_raw, zero_is_noval=False) # TODO: default_qos @@ -513,8 +461,16 @@ cdef _create_assoc_ptr(Association ass, conn=None): # them to its ID, which is why we need to load the current QOS available # in the system. ass.qos_data = QualitiesOfService.load(db_conn=conn) - _set_qos_list(&ass.ptr.qos_list, self.qos, ass.qos_data) - - ass.ptr.group_jobs = u32(ass.policy.group_jobs, zero_is_noval=False) - ass.ptr.group_jobs_accrue = u32(ass.policy.group_jobs_accrue, zero_is_noval=False) - + _set_qos_list(&ass.ptr.qos_list, ass.qos, ass.qos_data) + + ass.ptr.grp_jobs = u32(ass.group_jobs, zero_is_noval=False) + ass.ptr.grp_jobs_accrue = u32(ass.group_jobs_accrue, zero_is_noval=False) + ass.ptr.grp_submit_jobs = u32(ass.group_submit_jobs, zero_is_noval=False) + ass.ptr.grp_wall = u32(ass.group_wall_time, zero_is_noval=False) + ass.ptr.max_jobs = u32(ass.max_jobs, zero_is_noval=False) + ass.ptr.max_jobs_accrue = u32(ass.max_jobs_accrue, zero_is_noval=False) + ass.ptr.max_submit_jobs = u32(ass.max_submit_jobs, zero_is_noval=False) + ass.ptr.max_wall_pj = u32(ass.max_wall_time_per_job, zero_is_noval=False) + ass.ptr.min_prio_thresh = u32(ass.min_priority_threshold, zero_is_noval=False) + ass.ptr.priority = u32(ass.priority) + ass.ptr.shares_raw = u32(ass.shares) diff --git a/pyslurm/db/user.pyx b/pyslurm/db/user.pyx index 20c9aebe..bcb5d98b 100644 --- a/pyslurm/db/user.pyx +++ b/pyslurm/db/user.pyx @@ -42,8 +42,10 @@ class AdminLevel(SlurmEnum): cdef class Users(dict): - def __init__(self, **kwargs): - super().__init__(kwargs) + def __init__(self, users={}, **kwargs): + super().__init__() + self.update(accounts) + self.update(kwargs) @staticmethod def load(Connection db_conn, UserFilter db_filter=None): @@ -102,12 +104,13 @@ cdef class Users(dict): SlurmList response SlurmListItem response_ptr - db_conn.validate() - + # TODO: test again when this is empty, does it really delete everything? names = list(self.keys()) if not names: return + db_conn.validate() + u_filter = UserFilter(names=names) u_filter._create() @@ -173,10 +176,35 @@ cdef class Users(dict): user_list = SlurmList.create(slurmdb_destroy_user_rec, owned=False) for user in users: + if user.default_account: + has_default_assoc = False + for assoc in user.associations: + if not assoc.is_default: + continue + + if has_default_assoc: + raise ValueError("Multiple Associations declared as default") + + has_default_assoc = True + if not assoc.account: + assoc.account = user.default_account + elif assoc.account != user.default_account: + raise ValueError("Ambigous account definition") + + # Do we really need to specify a default association anyway? + if not has_default_assoc: + # Caller didn't specify any default association, so we + # create a basic one. + assoc = Association(user=user.name, + account=user.default_account, is_default=True) + user.associations.append(assoc) + assocs_to_add.extend(user.associations) slurm.slurm_list_append(user_list.info, user.ptr) verify_rpc(slurmdb_users_add(db_conn.ptr, user_list.info)) + # TODO: Maybe don't create the associations automatically? And don't do + # any hidden stuff? Associations.create(db_conn, assocs_to_add) @@ -293,7 +321,7 @@ cdef class User: Users({self.name: self}).delete(db_conn) def modify(self, Connection db_conn): - Users({self.name: self}).modify(self, db_conn) + Users({self.name: self}).modify(db_conn, self) @property def name(self): @@ -315,6 +343,10 @@ cdef class User: def default_account(self): return cstr.to_unicode(self.ptr.default_acct) + @default_account.setter + def default_account(self, val): + cstr.fmalloc(&self.ptr.default_acct, val) + @property def default_wckey(self): return cstr.to_unicode(self.ptr.default_wckey) diff --git a/tests/integration/test_db_connection.py b/tests/integration/test_db_connection.py index 95b6f311..4b4c2d59 100644 --- a/tests/integration/test_db_connection.py +++ b/tests/integration/test_db_connection.py @@ -33,6 +33,11 @@ def test_open(): conn = pyslurm.db.Connection.open() assert conn.is_open + with pyslurm.db.connect() as conn2: + pass + + assert not conn2.is_open + def test_close(): conn = pyslurm.db.Connection.open() From 576317a12a45b19e98ac78267f8bfd6fbbd6390d Mon Sep 17 00:00:00 2001 From: Toni Harzendorf Date: Sun, 22 Feb 2026 20:59:22 +0100 Subject: [PATCH 03/13] wip --- pyslurm/db/account.pyx | 42 ++++++++++++++++++++++++++++++++++++++---- pyslurm/db/assoc.pyx | 2 +- pyslurm/db/user.pyx | 19 ++++++------------- 3 files changed, 45 insertions(+), 18 deletions(-) diff --git a/pyslurm/db/account.pyx b/pyslurm/db/account.pyx index 2e5346d1..95cb89d4 100644 --- a/pyslurm/db/account.pyx +++ b/pyslurm/db/account.pyx @@ -92,6 +92,38 @@ cdef class Accounts(dict): return out + def modify(self, Connection db_conn, Account changes): + cdef: + AccountFilter acct_filter + SlurmList response + SlurmListItem response_ptr + list out = [] + + db_conn.validate() + + acct_filter = AccountFilter(names=list(self.keys())) + acct_filter._create() + + response = SlurmList.wrap(slurmdb_accounts_modify( + db_conn.ptr, acct_filter.ptr, changes.ptr)) + + if not response.is_null and response.cnt: + for response_ptr in response: + response_str = cstr.to_unicode(response_ptr.data) + if not response_str: + continue + + out.append(response_str) + + elif not response.is_null: + # There was no real error, but simply nothing has been modified + return out + else: + # Autodetects the last slurm error + raise RPCError(msg="Failed to modify accounts.") + + return out + @staticmethod def create(Connection db_conn, accounts): cdef: @@ -237,13 +269,13 @@ cdef class Account: wrap._init_defaults() return wrap - def to_dict(self): + def to_dict(self, recursive=False): """Database Account information formatted as a dictionary. Returns: (dict): Database Account information as dict. """ - return instance_to_dict(self) + return instance_to_dict(self, recursive) def __eq__(self, other): if isinstance(other, Account): @@ -254,6 +286,8 @@ cdef class Account: def load(Connection db_conn, name): account = Accounts.load(db_conn=db_conn).get(name) if not account: + # TODO: Maybe don't raise here and just return None and let the + # Caller handle it? raise RPCError(msg=f"Account {name} does not exist.") return account @@ -264,8 +298,8 @@ cdef class Account: def delete(self, Connection db_conn): Accounts({self.name: self}).delete(db_conn) - def modify(self, Connection db_conn): - Accounts({self.name: self}).modify(self, db_conn) + def modify(self, Connection db_conn, Account changes): + Accounts({self.name: self}).modify(db_conn, changes) @property def name(self): diff --git a/pyslurm/db/assoc.pyx b/pyslurm/db/assoc.pyx index 4b573af9..4ff13800 100644 --- a/pyslurm/db/assoc.pyx +++ b/pyslurm/db/assoc.pyx @@ -307,7 +307,7 @@ cdef class Association: wrap.ptr = in_ptr return wrap - def to_dict(self, recursive = False): + def to_dict(self, recursive=False): """Database Association information formatted as a dictionary. Returns: diff --git a/pyslurm/db/user.pyx b/pyslurm/db/user.pyx index bcb5d98b..43a0c839 100644 --- a/pyslurm/db/user.pyx +++ b/pyslurm/db/user.pyx @@ -31,20 +31,14 @@ from pyslurm.utils.uint import * from pyslurm import xcollections from pyslurm.utils.enums import SlurmEnum from pyslurm.db.error import JobsRunningError - - -class AdminLevel(SlurmEnum): - UNDEFINED = "UNDEFINED", slurm.SLURMDB_ADMIN_NOTSET - NONE = "NONE", slurm.SLURMDB_ADMIN_NONE - OPERATOR = "OPERATOR", slurm.SLURMDB_ADMIN_OPERATOR - ADMINISTRATOR = "ADMINISTRATOR", slurm.SLURMDB_ADMIN_SUPER_USER +from pyslurm.enums import AdminLevel cdef class Users(dict): def __init__(self, users={}, **kwargs): super().__init__() - self.update(accounts) + self.update(users) self.update(kwargs) @staticmethod @@ -293,13 +287,13 @@ cdef class User: wrap._init_defaults() return wrap - def to_dict(self): + def to_dict(self, recursive=False): """Database User information formatted as a dictionary. Returns: (dict): Database User information as dict. """ - return instance_to_dict(self) + return instance_to_dict(self, recursive) def __eq__(self, other): if isinstance(other, User): @@ -311,7 +305,6 @@ cdef class User: user = Users.load(db_conn=db_conn).get(name) if not user: raise RPCError(msg=f"User {name} does not exist.") - return user def create(self, Connection db_conn): @@ -320,8 +313,8 @@ cdef class User: def delete(self, Connection db_conn): Users({self.name: self}).delete(db_conn) - def modify(self, Connection db_conn): - Users({self.name: self}).modify(db_conn, self) + def modify(self, Connection db_conn, User changes): + Users({self.name: self}).modify(db_conn, changes) @property def name(self): From 149da1a0ed3f08fb2182c83fdd97094db0ad42ea Mon Sep 17 00:00:00 2001 From: Toni Harzendorf Date: Sun, 22 Feb 2026 20:59:31 +0100 Subject: [PATCH 04/13] SlurmEnum: actually only check for equality --- pyslurm/utils/enums.pyx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyslurm/utils/enums.pyx b/pyslurm/utils/enums.pyx index 5eb07a79..d44fd29b 100644 --- a/pyslurm/utils/enums.pyx +++ b/pyslurm/utils/enums.pyx @@ -79,7 +79,7 @@ class SlurmEnum(str, Enum, metaclass=DocstringSupport): def from_flag(cls, flags, default): out = cls(default) for item in cls: - if flags & item._flag or flags == item._flag: + if flags == item._flag: return item return out From 21add6c9828aadf33c526c96d5aa640261243958 Mon Sep 17 00:00:00 2001 From: Toni Harzendorf Date: Sun, 22 Feb 2026 21:00:02 +0100 Subject: [PATCH 05/13] enums: add AdminLevel --- pyslurm/enums.pyx | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/pyslurm/enums.pyx b/pyslurm/enums.pyx index acca2f6c..bd59577f 100644 --- a/pyslurm/enums.pyx +++ b/pyslurm/enums.pyx @@ -32,10 +32,10 @@ from pyslurm cimport slurm class SchedulerType(SlurmEnum): - SUBMIT = auto(), slurm.SLURMDB_JOB_FLAG_SUBMIT - MAIN = auto(), slurm.SLURMDB_JOB_FLAG_SCHED + SUBMIT = auto(), slurm.SLURMDB_JOB_FLAG_SUBMIT + MAIN = auto(), slurm.SLURMDB_JOB_FLAG_SCHED BACKFILL = auto(), slurm.SLURMDB_JOB_FLAG_BACKFILL - UNKNOWN = auto() + UNKNOWN = auto() SchedulerType.SUBMIT.__doc__ = "Scheduled immediately on submit" @@ -43,6 +43,14 @@ SchedulerType.MAIN.__doc__ = "Scheduled by the Main Scheduler" SchedulerType.SUBMIT.__doc__ = "Scheduled by the Backfill Scheduler" +class AdminLevel(SlurmEnum): + UNDEFINED = auto(), slurm.SLURMDB_ADMIN_NOTSET + NONE = auto(), slurm.SLURMDB_ADMIN_NONE + OPERATOR = auto(), slurm.SLURMDB_ADMIN_OPERATOR + ADMINISTRATOR = auto(), slurm.SLURMDB_ADMIN_SUPER_USER + + __all__ = [ "SchedulerType", + "AdminLevel", ] From 3487d0f58480510cd72adfb3e168c79ce72992e3 Mon Sep 17 00:00:00 2001 From: Toni Harzendorf Date: Sun, 22 Feb 2026 21:17:50 +0100 Subject: [PATCH 06/13] add the first tests for assoc/account/user db apis --- tests/integration/test_assoc.py | 167 ++++++++++++++++++++++++++++++++ 1 file changed, 167 insertions(+) create mode 100644 tests/integration/test_assoc.py diff --git a/tests/integration/test_assoc.py b/tests/integration/test_assoc.py new file mode 100644 index 00000000..7170ef20 --- /dev/null +++ b/tests/integration/test_assoc.py @@ -0,0 +1,167 @@ +######################################################################### +# test_assoc.py - database assoc/accounts/user api tests +######################################################################### +# Copyright (C) 2026 Toni Harzendorf +# +# This file is part of PySlurm +# +# PySlurm is free software; you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation; either version 2 of the License, or +# (at your option) any later version. + +# PySlurm is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License along +# with PySlurm; if not, write to the Free Software Foundation, Inc., +# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. +"""test_assoc.py - Integration test assoc/account/user functionalities.""" + +import pyslurm +import pytest +import uuid +from pyslurm.db import ( + User, + Account, + Association, +) + + +def _modify_account(account, conn): + new_desc = "this is a new description" + changes = Account(description=new_desc) + assert account.description != new_desc + assoc_before = account.association.to_dict(recursive=True) + account.modify(conn, changes) + account = Account.load(conn, account.name) + assoc_after = account.association.to_dict(recursive=True) + assert account.description == new_desc + # Make sure we didn't change anything in the Association + assert assoc_before == assoc_after + + +def _modify_user(user, conn): + user_changes = User( + admin_level = pyslurm.AdminLevel.ADMINISTRATOR + ) + assert user.admin_level == pyslurm.AdminLevel.NONE + assoc_before = user.default_association.to_dict(recursive=True) + user.modify(conn, user_changes) + user = User.load(conn, user.name) + assoc_after = user.default_association.to_dict(recursive=True) + assert user.admin_level == pyslurm.AdminLevel.ADMINISTRATOR + # Make sure we didn't change anything in the Association + assert assoc_before == assoc_after + + +def _load_assoc(assoc_id, conn): + assocs = pyslurm.db.Associations.load(conn) + return assocs.get(assoc_id) + + +def _load_account(name, conn): + accounts = pyslurm.db.Accounts.load(conn) + assert len(accounts) + return accounts.get(name) + + +def _load_user(name, conn): + users = pyslurm.db.Users.load(conn) + assert len(users) + return users.get(name) + + +def _delete_account(account, conn): + account = Account.load(conn, account.name) + account.delete(conn) + conn.commit() + assert not _load_account(account.name, conn) + assert not _load_assoc(account.association.id, conn) + + +def _add_account(account, conn): + account.create(conn) + # Although everything works without this commit, the slurmdbd complains + # that it can't find the account assoc when going to add the user. + # Everything is created properly, but this error appears. This is + # probably why you can't create both account and user / associations + # directly with sacctmgr. Either it is designed like this, or this is a + # bug in the as_mysql plugin. + # The tests pass anyway, so it is fine, but needs to be documented. + conn.commit() + account = _load_account(account.name, conn) + assoc = account.association + assert assoc + assert _load_assoc(assoc.id, conn) + assert assoc.account == account.name + assert assoc.user is None + assert assoc.parent_account == "root" + return account + + +def _add_user(user, conn): + user.create(conn) + conn.commit() + user = _load_user(user.name, conn) + assert len(user.associations) == 1 + assoc = user.default_association + assert assoc + assert _load_assoc(assoc.id, conn) + assert assoc.user == user.name + assert assoc.account == user.default_account + assert assoc.is_default + return user + + +def _delete_user(user, conn): + assoc_id = user.default_association.id + user = User.load(conn, user.name) + user.delete(conn) + conn.commit() + assert not _load_user(user.name, conn) + assert not _load_assoc(assoc_id, conn) + + +def _test_modify_delete(user, account, conn): + assert conn.is_open + _modify_account(account, conn) + _modify_user(user, conn) + _delete_account(account, conn) + _delete_user(user, conn) + + +def test_user_and_account_no_assoc(): + random_name = str(uuid.uuid4())[:8] + user_name = f"user_{random_name}" + acc_name = f"acc_{random_name}" + + with pyslurm.db.connect() as conn: + account = Account(name=acc_name) + user = User(name=user_name, default_account=acc_name) + + account = _add_account(account, conn) + user = _add_user(user, conn) + assert user.name == user_name + assert account.name == acc_name + _test_modify_delete(user, account, conn) + + +def test_user_and_accounts_with_assoc_empty(): + random_name = str(uuid.uuid4())[:8] + user_name = f"user_{random_name}" + acc_name = f"acc_{random_name}" + + with pyslurm.db.connect() as conn: + account_assoc = Association(account=acc_name) + account = Account(name=acc_name, association=account_assoc) + user_assoc = Association(user=user_name, account=acc_name) + user = User(name=user_name, associations=[user_assoc]) + + account = _add_account(account, conn) + user = _add_user(user, conn) + assert user.name == user_name + assert account.name == acc_name + _test_modify_delete(user, account, conn) From bad74eceb4b1a9b387d66c6b4b938a3e520f0e41 Mon Sep 17 00:00:00 2001 From: Toni Harzendorf Date: Tue, 24 Feb 2026 20:45:20 +0100 Subject: [PATCH 07/13] wip --- pyslurm/core/error.pyx | 6 +- pyslurm/db/account.pxd | 6 +- pyslurm/db/account.pyx | 172 +++++++++++++---------- pyslurm/db/assoc.pyx | 286 ++++++++++++++++++++------------------ pyslurm/db/connection.pxd | 14 ++ pyslurm/db/connection.pyx | 24 +++- pyslurm/db/error.pyx | 17 +++ pyslurm/db/user.pxd | 6 +- pyslurm/db/user.pyx | 152 +++++++++++++------- 9 files changed, 417 insertions(+), 266 deletions(-) diff --git a/pyslurm/core/error.pyx b/pyslurm/core/error.pyx index 722f3097..f186a5b0 100644 --- a/pyslurm/core/error.pyx +++ b/pyslurm/core/error.pyx @@ -100,12 +100,14 @@ class RPCError(PyslurmError): super().__init__(self.msg) -def verify_rpc(errno): +def verify_rpc(errno, msg=None): """Verify a Slurm RPC Args: errno (int): A Slurm error value + msg (str): + An optional message """ if errno != slurm.SLURM_SUCCESS: - raise RPCError(errno) + raise RPCError(errno, msg) diff --git a/pyslurm/db/account.pxd b/pyslurm/db/account.pxd index 850a1761..dca2b2d2 100644 --- a/pyslurm/db/account.pxd +++ b/pyslurm/db/account.pxd @@ -48,7 +48,7 @@ from pyslurm.db.tres cimport ( _set_tres_limits, TrackableResources, ) -from pyslurm.db.connection cimport Connection +from pyslurm.db.connection cimport Connection, ConnectionWrapper from pyslurm.utils cimport cstr from pyslurm.db.qos cimport QualitiesOfService, _set_qos_list from pyslurm.db.assoc cimport Associations, Association, _parse_assoc_ptr @@ -56,6 +56,10 @@ from pyslurm.xcollections cimport MultiClusterMap from pyslurm.utils.uint cimport u16_set_bool_flag +cdef class AccountAPI(ConnectionWrapper): + pass + + cdef class Accounts(dict): pass diff --git a/pyslurm/db/account.pyx b/pyslurm/db/account.pyx index 95cb89d4..1e834691 100644 --- a/pyslurm/db/account.pyx +++ b/pyslurm/db/account.pyx @@ -29,22 +29,19 @@ from pyslurm.utils.helpers import ( ) from pyslurm.utils.uint import * from pyslurm import xcollections -from pyslurm.db.error import DefaultAccountError, JobsRunningError +from pyslurm.db.error import ( + DefaultAccountError, + JobsRunningError, + parse_basic_response, +) -cdef class Accounts(dict): +cdef class AccountAPI(ConnectionWrapper): - def __init__(self, accounts={}, **kwargs): - super().__init__() - self.update(accounts) - self.update(kwargs) - - @staticmethod - def load(Connection db_conn, AccountFilter db_filter=None): + def load(self, db_filter: AccountFilter = None): cdef: Accounts out = Accounts() Account account - AccountFilter cond = db_filter SlurmList account_data SlurmListItem account_ptr SlurmList assoc_data @@ -53,23 +50,23 @@ cdef class Accounts(dict): QualitiesOfService qos_data TrackableResources tres_data - db_conn.validate() + self.db_conn.validate() if not db_filter: - cond = AccountFilter() + db_filter = AccountFilter() - if cond.with_assocs is not False: - cond.with_assocs = True + if db_filter.with_assocs is not False: + db_filter.with_assocs = True - cond._create() - account_data = SlurmList.wrap(slurmdb_accounts_get(db_conn.ptr, cond.ptr)) + db_filter._create() + account_data = SlurmList.wrap(slurmdb_accounts_get(self.db_conn.ptr, db_filter.ptr)) if account_data.is_null: raise RPCError(msg="Failed to get Account data from slurmdbd.") - qos_data = QualitiesOfService.load(db_conn=db_conn, + qos_data = QualitiesOfService.load(db_conn=self.db_conn, name_is_key=False) - tres_data = TrackableResources.load(db_conn=db_conn) + tres_data = TrackableResources.load(db_conn=self.db_conn) for account_ptr in SlurmList.iter_and_pop(account_data): account = Account.from_ptr(account_ptr.data) @@ -92,46 +89,87 @@ cdef class Accounts(dict): return out - def modify(self, Connection db_conn, Account changes): + + def delete(self, db_filter: AccountFilter): cdef: - AccountFilter acct_filter SlurmList response - SlurmListItem response_ptr list out = [] - db_conn.validate() + # Check is required because for some reason if the acct_cond doesn't + # contain any valid conditions, slurmdbd will delete all accounts. + # TODO: Maybe make it configurable + if not db_filter.names: + return - acct_filter = AccountFilter(names=list(self.keys())) - acct_filter._create() + self.db_conn.validate() + db_filter._create() - response = SlurmList.wrap(slurmdb_accounts_modify( - db_conn.ptr, acct_filter.ptr, changes.ptr)) + response = SlurmList.wrap(slurmdb_accounts_remove(self.db_conn.ptr, db_filter.ptr)) + rc = slurm_errno() + self.db_conn.check_commit(rc) + + if rc == slurm.SLURM_SUCCESS or rc == slurm.SLURM_NO_CHANGE_IN_DATA: + return + +# if rc == slurm.ESLURM_ACCESS_DENIED or response.is_null: +# verify_rpc(rc) + + # Handle the error cases. + if rc == slurm.ESLURM_JOBS_RUNNING_ON_ASSOC: + raise JobsRunningError.from_response(response, rc) + elif rc == slurm.ESLURM_NO_REMOVE_DEFAULT_ACCOUNT: + raise DefaultAccountError.from_response(response, rc) + else: + verify_rpc(rc) - if not response.is_null and response.cnt: - for response_ptr in response: - response_str = cstr.to_unicode(response_ptr.data) - if not response_str: - continue - out.append(response_str) + def modify(self, db_filter: AccountFilter, changes: Account): + cdef: + SlurmList response + SlurmListItem response_ptr + list out = [] - elif not response.is_null: - # There was no real error, but simply nothing has been modified + self.db_conn.validate() + db_filter._create() + + response = SlurmList.wrap(slurmdb_accounts_modify( + self.db_conn.ptr, db_filter.ptr, changes.ptr) + ) + rc = slurm_errno() + self.db_conn.check_commit(rc) + + if rc == slurm.SLURM_SUCCESS: + return parse_basic_response(response) + elif rc == slurm.SLURM_NO_CHANGE_IN_DATA: return out else: - # Autodetects the last slurm error + # verify_rpc(rc) raise RPCError(msg="Failed to modify accounts.") - return out - @staticmethod - def create(Connection db_conn, accounts): +# if not response.is_null and response.cnt: +# for response_ptr in response: +# response_str = cstr.to_unicode(response_ptr.data) +# if not response_str: +# continue + +# out.append(response_str) + +# elif not response.is_null: +# # There was no real error, but simply nothing has been modified +# return out +# else: +# # Autodetects the last slurm error +# raise RPCError(msg="Failed to modify accounts.") + + + def create(self, accounts): cdef: Account account SlurmList account_list list assocs_to_add = [] - db_conn.validate() + self.db_conn.validate() account_list = SlurmList.create(slurmdb_destroy_account_rec, owned=False) for account in accounts: @@ -141,44 +179,38 @@ cdef class Accounts(dict): assocs_to_add.append(account.association) slurm.slurm_list_append(account_list.info, account.ptr) - verify_rpc(slurmdb_accounts_add(db_conn.ptr, account_list.info)) + rc = slurmdb_accounts_add(self.db_conn.ptr, account_list.info) + # TODO: Only commit here when we don't add any associations? + # So we don't leave any Accounts without associations behind? + self.db_conn.check_commit(rc) + verify_rpc(rc) # TODO: Maybe don't create the associations automatically? And don't do # any hidden stuff? - Associations.create(db_conn, assocs_to_add) - - def delete(self, Connection db_conn): - cdef: - AccountFilter a_filter - SlurmList response - list out = [] + Associations.create(self.db_conn, assocs_to_add) - # Check is required because for some reason if the acct_cond doesn't - # contain any valid conditions, slurmdbd will delete all accounts. - names = list(self.keys()) - if not names: - return - db_conn.validate() +cdef class Accounts(dict): - a_filter = AccountFilter(names=names) - a_filter._create() + def __init__(self, accounts={}, **kwargs): + super().__init__() + self.update(accounts) + self.update(kwargs) - response = SlurmList.wrap(slurmdb_accounts_remove(db_conn.ptr, a_filter.ptr)) - rc = slurm_errno() + @staticmethod + def load(db_conn: Connection, db_filter: AccountFilter = None): + return db_conn.accounts.load(db_filter) - if rc == slurm.SLURM_SUCCESS or rc == slurm.SLURM_NO_CHANGE_IN_DATA: - return + def delete(self, db_conn: Connection): + db_filter = AccountFilter(names=list(self.keys())) + db_conn.accounts.delete(db_filter) -# if rc == slurm.ESLURM_ACCESS_DENIED or response.is_null: -# verify_rpc(rc) + def modify(self, db_conn: Connection, changes: Account): + db_filter = AccountFilter(names=list(self.keys())) + return db_conn.accounts.modify(db_filter, changes) - # Handle the error cases. - if rc == slurm.ESLURM_JOBS_RUNNING_ON_ASSOC: - raise JobsRunningError.from_response(response, rc) - elif rc == slurm.ESLURM_NO_REMOVE_DEFAULT_ACCOUNT: - raise DefaultAccountError.from_response(response, rc) - else: - verify_rpc(rc) + @staticmethod + def create(db_conn: Connection, accounts): + db_conn.accounts.create(accounts) cdef class AccountFilter: @@ -284,7 +316,7 @@ cdef class Account: @staticmethod def load(Connection db_conn, name): - account = Accounts.load(db_conn=db_conn).get(name) + account = db_conn.accounts.load().get(name) if not account: # TODO: Maybe don't raise here and just return None and let the # Caller handle it? @@ -293,7 +325,7 @@ cdef class Account: return account def create(self, Connection db_conn): - Accounts.create(db_conn, [self]) + db_conn.accounts.create([self]) def delete(self, Connection db_conn): Accounts({self.name: self}).delete(db_conn) diff --git a/pyslurm/db/assoc.pyx b/pyslurm/db/assoc.pyx index 4ff13800..4fce9e76 100644 --- a/pyslurm/db/assoc.pyx +++ b/pyslurm/db/assoc.pyx @@ -33,6 +33,143 @@ from pyslurm import xcollections from pyslurm.db.error import JobsRunningError, DefaultAccountError +def load(db_conn: Connection, db_filter: AssociationFilter = None): + cdef: + Associations out = Associations() + Association assoc + SlurmList assoc_data + SlurmListItem assoc_ptr + QualitiesOfService qos_data + TrackableResources tres_data + + db_conn.validate() + + if not db_filter: + db_filter = AssociationFilter() + db_filter._create() + + assoc_data = SlurmList.wrap(slurmdb_associations_get( + db_conn.ptr, db_filter.ptr) + ) + + if assoc_data.is_null: + raise RPCError(msg="Failed to get Association data from slurmdbd.") + + # Fetch other necessary dependencies needed for translating some + # attributes (i.e QoS IDs to its name) + qos_data = QualitiesOfService.load(db_conn=db_conn, + name_is_key=False) + tres_data = TrackableResources.load(db_conn=db_conn) + + for assoc_ptr in SlurmList.iter_and_pop(assoc_data): + assoc = Association.from_ptr(assoc_ptr.data) + assoc.qos_data = qos_data + assoc.tres_data = tres_data + _parse_assoc_ptr(assoc) + + cluster = assoc.cluster + if cluster not in out.data: + out.data[cluster] = {} + out.data[cluster][assoc.id] = assoc + + return out + + +def delete(db_conn: Connection, db_filter: AssociationFilter): + cdef: + SlurmList response + SlurmListItem response_ptr + + # TODO: Properly check if the filter is empty, cause it will then probably + # target all assocs. Or maybe that is fine and we need to clearly document + # to take caution + # if not db_filter.ids: + # return + + db_conn.validate() + a_filter._create() + + response = SlurmList.wrap(slurmdb_associations_remove( + db_conn.ptr, db_filter.ptr) + ) + rc = slurm_errno() + db_conn.check_commit(rc) + + if rc == slurm.SLURM_SUCCESS or rc == slurm.SLURM_NO_CHANGE_IN_DATA: + return + + #if rc == slurm.ESLURM_ACCESS_DENIED or response.is_null: + # verify_rpc(rc) + + # Handle the error cases. + if rc == slurm.ESLURM_JOBS_RUNNING_ON_ASSOC: + raise JobsRunningError.from_response(response, rc) + elif rc == slurm.ESLURM_NO_REMOVE_DEFAULT_ACCOUNT: + raise DefaultAccountError.from_response(response, rc) + else: + verify_rpc(rc) + + +def modify(db_conn: Connection, db_filter: AssociationFilter, changes: Association): + cdef: + SlurmList response + SlurmListItem response_ptr + list out = [] + + db_conn.validate() + db_filter._create() + + # Any data that isn't parsed yet or needs validation is done in this + # function. + _create_assoc_ptr(changes, db_conn) + + # Returns a List of char* with the associations that were modified + response = SlurmList.wrap(slurmdb_associations_modify( + db_conn.ptr, db_filter.ptr, changes.ptr)) + rc = slurm_errno() + db_conn.check_commit(rc) + + if not response.is_null and response.cnt: + for response_ptr in response: + response_str = cstr.to_unicode(response_ptr.data) + if not response_str: + continue + + # TODO: Better format + out.append(response_str) + + elif not response.is_null: + # There was no real error, but simply nothing has been modified + return None + else: + # Autodetects the last slurm error + raise RPCError() + + return out + + +def create(db_conn: Connection, associations): + cdef: + Association assoc + AssociationList assoc_list = AssociationList(owned=False) + + if not associations: + return + + db_conn.validate() + + for i, assoc in enumerate(associations): + # Make sure to remove any duplicate associations, i.e. associations + # having the same account name set. For some reason, the slurmdbd + # doesn't like that. + if assoc not in assoc_list: + assoc_list.append(assoc) + + rc = slurmdb_associations_add(db_conn.ptr, assoc_list.info) + db_conn.check_commit(rc) + verify_rpc(rc) + + cdef class AssociationList(SlurmList): def __init__(self, owned=True): @@ -75,149 +212,20 @@ cdef class Associations(MultiClusterMap): key_type=int) @staticmethod - def load(Connection db_conn, AssociationFilter db_filter=None): - cdef: - Associations out = Associations() - Association assoc - AssociationFilter cond = db_filter - SlurmList assoc_data - SlurmListItem assoc_ptr - QualitiesOfService qos_data - TrackableResources tres_data - - db_conn.validate() - - if not db_filter: - cond = AssociationFilter() - cond._create() - - assoc_data = SlurmList.wrap(slurmdb_associations_get( - db_conn.ptr, cond.ptr)) - - if assoc_data.is_null: - raise RPCError(msg="Failed to get Association data from slurmdbd.") - - # Fetch other necessary dependencies needed for translating some - # attributes (i.e QoS IDs to its name) - qos_data = QualitiesOfService.load(db_conn=db_conn, - name_is_key=False) - tres_data = TrackableResources.load(db_conn=db_conn) - - for assoc_ptr in SlurmList.iter_and_pop(assoc_data): - assoc = Association.from_ptr(assoc_ptr.data) - assoc.qos_data = qos_data - assoc.tres_data = tres_data - _parse_assoc_ptr(assoc) - - cluster = assoc.cluster - if cluster not in out.data: - out.data[cluster] = {} - out.data[cluster][assoc.id] = assoc - - return out + def load(db_conn: Connection, db_filter: AssociationFilter = None): + return load(db_conn, db_filter) - @staticmethod - def modify(Connection db_conn, db_filter, Association changes): - cdef: - AssociationFilter afilter - SlurmList response - SlurmListItem response_ptr - list out = [] - - db_conn.validate() - - # TODO: make db_filter optional? - # Prepare SQL Filter - if isinstance(db_filter, Associations): - assoc_ids = [ass.id for ass in db_filter] - afilter = AssociationFilter(ids=assoc_ids) - else: - afilter = db_filter - afilter._create() - - # Any data that isn't parsed yet or needs validation is done in this - # function. - _create_assoc_ptr(changes, db_conn) - - # Returns a List of char* with the associations that were modified - response = SlurmList.wrap(slurmdb_associations_modify( - db_conn.ptr, afilter.ptr, changes.ptr)) - - if not response.is_null and response.cnt: - for response_ptr in response: - response_str = cstr.to_unicode(response_ptr.data) - if not response_str: - continue - - # TODO: Better format - out.append(response_str) - - elif not response.is_null: - # There was no real error, but simply nothing has been modified - return None - else: - # Autodetects the last slurm error - raise RPCError() + def delete(self, db_conn: Connection): + db_filter = AssociationFilter(ids=list(self.keys())) + delete(db_conn, db_filter, changes) - return out + def modify(self, db_conn: Connection, changes: Association): + db_filter = AssociationFilter(ids=list(self.keys())) + return modify(db_conn, db_filter, changes) @staticmethod - def create(Connection db_conn, associations, auto_add=True): - cdef: - Association assoc - AssociationList assoc_list = AssociationList(owned=False) - - if not associations: - return - - db_conn.validate() - - for i, assoc in enumerate(associations): - # Make sure to remove any duplicate associations, i.e. associations - # having the same account name set. For some reason, the slurmdbd - # doesn't like that. - if assoc not in assoc_list: -# print(assoc.user) -# print(assoc.account) -# print(assoc.is_default) -# print(assoc.parent_account) -# print(assoc.cluster) - assoc_list.append(assoc) - - verify_rpc(slurmdb_associations_add(db_conn.ptr, assoc_list.info)) - - def delete(self, Connection db_conn): - cdef: - AssociationFilter afilter - SlurmList response - SlurmListItem response_ptr - - db_conn.validate() - - ids = [assoc.id for assoc in self.values()] - if not ids: - return - - a_filter = AssociationFilter(ids=ids) - a_filter._create() - - response = SlurmList.wrap(slurmdb_associations_remove(db_conn.ptr, - a_filter.ptr)) - rc = slurm_errno() - - if rc == slurm.SLURM_SUCCESS or rc == slurm.SLURM_NO_CHANGE_IN_DATA: - return - - #if rc == slurm.ESLURM_ACCESS_DENIED or response.is_null: - # verify_rpc(rc) - - # Handle the error cases. - if rc == slurm.ESLURM_JOBS_RUNNING_ON_ASSOC: - raise JobsRunningError.from_response(response, rc) - elif rc == slurm.ESLURM_NO_REMOVE_DEFAULT_ACCOUNT: - raise DefaultAccountError.from_response(response, rc) - else: - verify_rpc(rc) + def create(db_conn: Connection, associations): + create(db_conn, associations) cdef class AssociationFilter: diff --git a/pyslurm/db/connection.pxd b/pyslurm/db/connection.pxd index 6ac2dfc6..de041c8d 100644 --- a/pyslurm/db/connection.pxd +++ b/pyslurm/db/connection.pxd @@ -31,6 +31,11 @@ from pyslurm.slurm cimport ( ) +cdef class ConnectionWrapper: + cdef: + Connection db_conn + + cdef class Connection: """A connection to the slurmdbd. @@ -41,3 +46,12 @@ cdef class Connection: cdef: void *ptr uint16_t flags + + cdef public: + commit_on_success + rollback_on_error + + cdef readonly: + users + accounts + diff --git a/pyslurm/db/connection.pyx b/pyslurm/db/connection.pyx index 1e1833d8..c75bc041 100644 --- a/pyslurm/db/connection.pyx +++ b/pyslurm/db/connection.pyx @@ -24,6 +24,14 @@ from pyslurm.core.error import RPCError, PyslurmError from contextlib import contextmanager +from pyslurm.db.user import UserAPI +from pyslurm.db.account import AccountAPI + + +cdef class ConnectionWrapper: + + def __init__(self, db_conn: Connection): + self.db_conn = db_conn class InvalidConnectionError(PyslurmError): @@ -31,7 +39,7 @@ class InvalidConnectionError(PyslurmError): @contextmanager -def connect(): +def connect(commit_on_success=True, rollback_on_error=True): """A managed Slurm DB Connection""" connection = Connection.open() try: @@ -61,8 +69,14 @@ cdef class Connection: if not self.is_open: raise InvalidConnectionError("Connection is closed") + def check_commit(self, rc): + if self.commit_on_success and rc == slurm.SLURM_SUCCESS: + self.commit() + elif self.rollback_on_error and rc != slurm.SLURM_SUCCESS: + self.rollback() + @staticmethod - def open(): + def open(commit_on_success=True, rollback_on_error=True): """Open a new connection to the slurmdbd Raises: @@ -82,6 +96,12 @@ cdef class Connection: if not conn.ptr: raise RPCError(msg="Failed to open onnection to slurmdbd") + conn.commit_on_success = commit_on_success + conn.rollback_on_error = rollback_on_error + + # APIs + conn.users = UserAPI(conn) + conn.accounts = AccountAPI(conn) return conn def close(self): diff --git a/pyslurm/db/error.pyx b/pyslurm/db/error.pyx index 8c15cd5f..e4e7d80d 100644 --- a/pyslurm/db/error.pyx +++ b/pyslurm/db/error.pyx @@ -63,6 +63,23 @@ class DefaultAccountError(RPCError): return err +def get_responses(SlurmList response): + cdef SlurmListItem response_ptr + + #TODO: check also for count? + if response.is_null: + return [] + + for response_ptr in response: + response_str = response_ptr.to_str() + if response_str: + yield response_str + + +def parse_basic_response(SlurmList response): + return get_responses(response) + + def parse_default_account_errors(SlurmList response): cdef SlurmListItem response_ptr diff --git a/pyslurm/db/user.pxd b/pyslurm/db/user.pxd index 20dbd3c8..b639edfc 100644 --- a/pyslurm/db/user.pxd +++ b/pyslurm/db/user.pxd @@ -48,7 +48,7 @@ from pyslurm.db.tres cimport ( _set_tres_limits, TrackableResources, ) -from pyslurm.db.connection cimport Connection +from pyslurm.db.connection cimport Connection, ConnectionWrapper from pyslurm.utils cimport cstr from pyslurm.db.qos cimport QualitiesOfService, _set_qos_list from pyslurm.db.assoc cimport Associations, Association, _parse_assoc_ptr, AssociationFilter, AssociationList @@ -56,6 +56,10 @@ from pyslurm.xcollections cimport MultiClusterMap from pyslurm.utils.uint cimport u16_set_bool_flag +cdef class UserAPI(ConnectionWrapper): + pass + + cdef class Users(dict): pass diff --git a/pyslurm/db/user.pyx b/pyslurm/db/user.pyx index 43a0c839..ed049d9e 100644 --- a/pyslurm/db/user.pyx +++ b/pyslurm/db/user.pyx @@ -30,19 +30,13 @@ from pyslurm.utils.helpers import ( from pyslurm.utils.uint import * from pyslurm import xcollections from pyslurm.utils.enums import SlurmEnum -from pyslurm.db.error import JobsRunningError +from pyslurm.db.error import JobsRunningError, parse_basic_response from pyslurm.enums import AdminLevel -cdef class Users(dict): - - def __init__(self, users={}, **kwargs): - super().__init__() - self.update(users) - self.update(kwargs) +cdef class UserAPI(ConnectionWrapper): - @staticmethod - def load(Connection db_conn, UserFilter db_filter=None): + def load(self, db_filter: UserFilter = None): cdef: Users out = Users() User user @@ -55,7 +49,7 @@ cdef class Users(dict): QualitiesOfService qos_data TrackableResources tres_data - db_conn.validate() + self.db_conn.validate() if not db_filter: cond = UserFilter() @@ -66,14 +60,14 @@ cdef class Users(dict): cond.with_assocs = True cond._create() - user_data = SlurmList.wrap(slurmdb_users_get(db_conn.ptr, cond.ptr)) + user_data = SlurmList.wrap(slurmdb_users_get(self.db_conn.ptr, cond.ptr)) if user_data.is_null: raise RPCError(msg="Failed to get User data from slurmdbd") - qos_data = QualitiesOfService.load(db_conn=db_conn, + qos_data = QualitiesOfService.load(db_conn=self.db_conn, name_is_key=False) - tres_data = TrackableResources.load(db_conn=db_conn) + tres_data = TrackableResources.load(db_conn=self.db_conn) for user_ptr in SlurmList.iter_and_pop(user_data): user = User.from_ptr(user_ptr.data) @@ -92,24 +86,21 @@ cdef class Users(dict): return out - def delete(self, Connection db_conn): + + def delete(self, db_filter: UserFilter): cdef: - UserFilter u_filter SlurmList response - SlurmListItem response_ptr # TODO: test again when this is empty, does it really delete everything? - names = list(self.keys()) - if not names: + if not db_filter.names: return - db_conn.validate() + self.db_conn.validate() + db_filter._create() - u_filter = UserFilter(names=names) - u_filter._create() - - response = SlurmList.wrap(slurmdb_users_remove(db_conn.ptr, u_filter.ptr)) + response = SlurmList.wrap(slurmdb_users_remove(self.db_conn.ptr, db_filter.ptr)) rc = slurm_errno() + self.db_conn.check_commit(rc) if rc == slurm.SLURM_SUCCESS or rc == slurm.SLURM_NO_CHANGE_IN_DATA: return @@ -120,53 +111,85 @@ cdef class Users(dict): # Handle the error case. Running Jobs should be the only possible error # where slurmdbd sends a response list. if rc == slurm.ESLURM_JOBS_RUNNING_ON_ASSOC: + # TODO: The slurmdbd actually deletes the associations, even if + # Jobs are running. The client side must then decide whether to + # rollback or actually commit the changes. sacctmgr does the + # rollback. + + # Should we also do this here automatically to prevent + # anyone accidentally forgetting this? Or let the caller handle it? + # If we do it, then it might rollback changes that were done + # earlier and haven't been committed yet. raise JobsRunningError.from_response(response, rc) else: verify_rpc(rc) - def modify(self, Connection db_conn, User changes): + + def modify(self, db_filter: UserFilter, changes: User): cdef: - UserFilter u_filter - AssociationFilter a_filter SlurmList response SlurmListItem response_ptr list out = [] - db_conn.validate() + # TODO: Properly check if the filter is empty, cause it will then probably + # target all users. Or maybe that is fine and we need to clearly document + # to take caution + #if not db_filter.names: + # return - u_filter = UserFilter(names=list(self.keys())) -# a_filter = AssociationFilter() + self.db_conn.validate() + db_filter._create() - # u_filter.ptr.assoc_cond = a_filter.ptr - u_filter._create() response = SlurmList.wrap(slurmdb_users_modify( - db_conn.ptr, u_filter.ptr, changes.ptr)) - - if not response.is_null and response.cnt: - for response_ptr in response: - response_str = cstr.to_unicode(response_ptr.data) - if not response_str: - continue - - out.append(response_str) + self.db_conn.ptr, db_filter.ptr, changes.ptr) + ) + rc = slurm_errno() + self.db_conn.check_commit(rc) - elif not response.is_null: - # There was no real error, but simply nothing has been modified + if rc == slurm.SLURM_SUCCESS: + return parse_basic_response(response) + elif rc == slurm.SLURM_NO_CHANGE_IN_DATA: return out else: - # Autodetects the last slurm error + # verify_rpc(rc) + # ESLURM_ONE_CHANGE - when the name is changed, only 1 user can be + # specified at a time + + # SLURM_ERROR - general error raise RPCError(msg="Failed to modify users.") - return out +# if not response.is_null and response.cnt: +# for response_ptr in response: +# response_str = cstr.to_unicode(response_ptr.data) +# if not response_str: +# continue - @staticmethod - def create(Connection db_conn, users): +# out.append(response_str) + +# elif not response.is_null: +# # There was no real error, but simply nothing has been modified +# return out +# else: +# # TODO: handle errors better +# # ESLURM_NO_CHANGE_IN_DATA + +# # ESLURM_ONE_CHANGE - when the name is changed, only 1 user can be +# # specified at a time + +# # Autodetects the last slurm error +# raise RPCError(msg="Failed to modify users.") + + + def create(self, users): cdef: User user SlurmList user_list list assocs_to_add = [] - db_conn.validate() + if not users: + return + + self.db_conn.validate() user_list = SlurmList.create(slurmdb_destroy_user_rec, owned=False) for user in users: @@ -196,10 +219,37 @@ cdef class Users(dict): assocs_to_add.extend(user.associations) slurm.slurm_list_append(user_list.info, user.ptr) - verify_rpc(slurmdb_users_add(db_conn.ptr, user_list.info)) + rc = slurmdb_users_add(self.db_conn.ptr, user_list.info) + # TODO: Only commit here when we don't add any associations? + # So we don't leave any Users without associations behind? + self.db_conn.check_commit(rc) + verify_rpc(rc) # TODO: Maybe don't create the associations automatically? And don't do # any hidden stuff? - Associations.create(db_conn, assocs_to_add) + Associations.create(self.db_conn, assocs_to_add) + + +cdef class Users(dict): + + def __init__(self, users={}, **kwargs): + super().__init__() + self.update(users) + self.update(kwargs) + + @staticmethod + def load(db_conn: Connection, db_filter: UserFilter = None): + return db_conn.users.load(db_filter) + + def delete(self, Connection db_conn): + db_filter = UserFilter(names=list(self.keys())) + db_conn.users.delete(db_filter) + + def modify(self, db_conn: Connection, changes: User): + db_filter = UserFilter(names=list(self.keys())) + return db_conn.users.modify(db_filter, changes) + + def create(self, db_conn: Connection): + db_conn.users.create(list(self.values())) cdef class UserFilter: @@ -302,13 +352,13 @@ cdef class User: @staticmethod def load(Connection db_conn, name): - user = Users.load(db_conn=db_conn).get(name) + user = db_conn.users.load().get(name) if not user: raise RPCError(msg=f"User {name} does not exist.") return user def create(self, Connection db_conn): - Users.create(db_conn, [self]) + db_conn.users.create([self]) def delete(self, Connection db_conn): Users({self.name: self}).delete(db_conn) From 099ed6435dfc21c3bb13a485d96cd9170549a2e2 Mon Sep 17 00:00:00 2001 From: Toni Harzendorf Date: Thu, 26 Feb 2026 17:09:59 +0100 Subject: [PATCH 08/13] wip unified DB API --- pyslurm/db/account.pxd | 4 +- pyslurm/db/account.pyx | 38 ++-- pyslurm/db/assoc.pxd | 10 +- pyslurm/db/assoc.pyx | 299 ++++++++++++++------------- pyslurm/db/connection.pxd | 15 +- pyslurm/db/connection.pyx | 65 +++++- pyslurm/db/job.pxd | 9 +- pyslurm/db/job.pyx | 344 ++++++++++++++++++++----------- pyslurm/db/qos.pxd | 12 +- pyslurm/db/qos.pyx | 53 +++-- pyslurm/db/tres.pxd | 6 +- pyslurm/db/tres.pyx | 54 +++-- pyslurm/db/user.pxd | 8 +- pyslurm/db/user.pyx | 34 +-- tests/integration/test_assoc.py | 14 +- tests/integration/test_db_job.py | 49 +++-- tests/integration/test_db_qos.py | 33 +-- 17 files changed, 660 insertions(+), 387 deletions(-) diff --git a/pyslurm/db/account.pxd b/pyslurm/db/account.pxd index dca2b2d2..cf5169bd 100644 --- a/pyslurm/db/account.pxd +++ b/pyslurm/db/account.pxd @@ -61,7 +61,8 @@ cdef class AccountAPI(ConnectionWrapper): cdef class Accounts(dict): - pass + cdef public: + Connection _db_conn cdef class AccountFilter: @@ -101,6 +102,7 @@ cdef class Account: associations coordinators association + Connection _db_conn @staticmethod cdef Account from_ptr(slurmdb_account_rec_t *in_ptr) diff --git a/pyslurm/db/account.pyx b/pyslurm/db/account.pyx index 1e834691..4338acc4 100644 --- a/pyslurm/db/account.pyx +++ b/pyslurm/db/account.pyx @@ -64,19 +64,20 @@ cdef class AccountAPI(ConnectionWrapper): if account_data.is_null: raise RPCError(msg="Failed to get Account data from slurmdbd.") - qos_data = QualitiesOfService.load(db_conn=self.db_conn, - name_is_key=False) - tres_data = TrackableResources.load(db_conn=self.db_conn) + qos_data = self.db_conn.qos.load(name_is_key=False) + tres_data = self.db_conn.tres.load() for account_ptr in SlurmList.iter_and_pop(account_data): account = Account.from_ptr(account_ptr.data) out[account.name] = account + self.db_conn.apply_reuse(account) assoc_data = SlurmList.wrap(account.ptr.assoc_list, owned=False) for assoc_ptr in SlurmList.iter_and_pop(assoc_data): assoc = Association.from_ptr(assoc_ptr.data) assoc.qos_data = qos_data assoc.tres_data = tres_data + self.db_conn.apply_reuse(assoc) _parse_assoc_ptr(assoc) if not assoc.user: @@ -87,6 +88,7 @@ cdef class AccountAPI(ConnectionWrapper): # TODO: maybe rename to user_associations account.associations.append(assoc) + self.db_conn.apply_reuse(out) return out @@ -186,7 +188,7 @@ cdef class AccountAPI(ConnectionWrapper): verify_rpc(rc) # TODO: Maybe don't create the associations automatically? And don't do # any hidden stuff? - Associations.create(self.db_conn, assocs_to_add) + self.db_conn.associations.create(assocs_to_add) cdef class Accounts(dict): @@ -195,22 +197,25 @@ cdef class Accounts(dict): super().__init__() self.update(accounts) self.update(kwargs) + self._db_conn = None @staticmethod def load(db_conn: Connection, db_filter: AccountFilter = None): return db_conn.accounts.load(db_filter) - def delete(self, db_conn: Connection): + def delete(self, db_conn: Connection | None = None): + db_conn = Connection.reuse(self._db_conn, db_conn) db_filter = AccountFilter(names=list(self.keys())) db_conn.accounts.delete(db_filter) - def modify(self, db_conn: Connection, changes: Account): + def modify(self, changes: Account, db_conn: Connection | None = None): + db_conn = Connection.reuse(self._db_conn, db_conn) db_filter = AccountFilter(names=list(self.keys())) return db_conn.accounts.modify(db_filter, changes) - @staticmethod - def create(db_conn: Connection, accounts): - db_conn.accounts.create(accounts) + def create(self, db_conn: Connection | None = None): + db_conn = Connection.reuse(self._db_conn, db_conn) + db_conn.accounts.create(list(self.values())) cdef class AccountFilter: @@ -315,23 +320,22 @@ cdef class Account: return NotImplemented @staticmethod - def load(Connection db_conn, name): + def load(db_conn: Connection, name: str): account = db_conn.accounts.load().get(name) if not account: # TODO: Maybe don't raise here and just return None and let the # Caller handle it? raise RPCError(msg=f"Account {name} does not exist.") - return account - def create(self, Connection db_conn): - db_conn.accounts.create([self]) + def create(self, db_conn: Connection | None = None): + Accounts({self.name: self}).create(self._db_conn or db_conn) - def delete(self, Connection db_conn): - Accounts({self.name: self}).delete(db_conn) + def delete(self, db_conn: Connection | None = None): + Accounts({self.name: self}).delete(self._db_conn or db_conn) - def modify(self, Connection db_conn, Account changes): - Accounts({self.name: self}).modify(db_conn, changes) + def modify(self, changes: Account, db_conn: Connection | None = None): + Accounts({self.name: self}).modify(changes, self._db_conn or db_conn) @property def name(self): diff --git a/pyslurm/db/assoc.pxd b/pyslurm/db/assoc.pxd index f1d08397..cf84951e 100644 --- a/pyslurm/db/assoc.pxd +++ b/pyslurm/db/assoc.pxd @@ -46,7 +46,7 @@ from pyslurm.db.tres cimport ( _set_tres_limits, TrackableResources, ) -from pyslurm.db.connection cimport Connection +from pyslurm.db.connection cimport Connection, ConnectionWrapper from pyslurm.utils cimport cstr from pyslurm.utils.uint cimport * from pyslurm.db.qos cimport QualitiesOfService, _set_qos_list @@ -56,10 +56,15 @@ cdef _parse_assoc_ptr(Association ass) cdef _create_assoc_ptr(Association ass, conn=*) -cdef class Associations(MultiClusterMap): +cdef class AssociationAPI(ConnectionWrapper): pass +cdef class Associations(MultiClusterMap): + cdef public: + Connection _db_conn + + cdef class AssociationFilter: cdef slurmdb_assoc_cond_t *ptr @@ -82,6 +87,7 @@ cdef class Association: owned cdef public: + Connection _db_conn default_qos group_tres diff --git a/pyslurm/db/assoc.pyx b/pyslurm/db/assoc.pyx index 4fce9e76..8a086828 100644 --- a/pyslurm/db/assoc.pyx +++ b/pyslurm/db/assoc.pyx @@ -33,143 +33,146 @@ from pyslurm import xcollections from pyslurm.db.error import JobsRunningError, DefaultAccountError -def load(db_conn: Connection, db_filter: AssociationFilter = None): - cdef: - Associations out = Associations() - Association assoc - SlurmList assoc_data - SlurmListItem assoc_ptr - QualitiesOfService qos_data - TrackableResources tres_data - - db_conn.validate() - - if not db_filter: - db_filter = AssociationFilter() - db_filter._create() - - assoc_data = SlurmList.wrap(slurmdb_associations_get( - db_conn.ptr, db_filter.ptr) - ) +cdef class AssociationAPI(ConnectionWrapper): + + def load(self, db_filter: AssociationFilter = None): + cdef: + Associations out = Associations() + Association assoc + SlurmList assoc_data + SlurmListItem assoc_ptr + QualitiesOfService qos_data + TrackableResources tres_data + + self.db_conn.validate() + + if not db_filter: + db_filter = AssociationFilter() + db_filter._create() + + assoc_data = SlurmList.wrap(slurmdb_associations_get( + self.db_conn.ptr, db_filter.ptr) + ) + + if assoc_data.is_null: + raise RPCError(msg="Failed to get Association data from slurmdbd.") + + # Fetch other necessary dependencies needed for translating some + # attributes (i.e QoS IDs to its name) + qos_data = self.db_conn.qos.load(name_is_key=False) + tres_data = self.db_conn.tres.load() + + for assoc_ptr in SlurmList.iter_and_pop(assoc_data): + assoc = Association.from_ptr(assoc_ptr.data) + assoc.qos_data = qos_data + assoc.tres_data = tres_data + self.db_conn.apply_reuse(assoc) + _parse_assoc_ptr(assoc) + + cluster = assoc.cluster + if cluster not in out.data: + out.data[cluster] = {} + out.data[cluster][assoc.id] = assoc + + self.db_conn.apply_reuse(out) + return out + + + def delete(self, db_filter: AssociationFilter): + cdef: + SlurmList response + SlurmListItem response_ptr + + # TODO: Properly check if the filter is empty, cause it will then probably + # target all assocs. Or maybe that is fine and we need to clearly document + # to take caution + # if not db_filter.ids: + # return + + self.db_conn.validate() + a_filter._create() + + response = SlurmList.wrap(slurmdb_associations_remove( + self.db_conn.ptr, db_filter.ptr) + ) + rc = slurm_errno() + self.db_conn.check_commit(rc) + + if rc == slurm.SLURM_SUCCESS or rc == slurm.SLURM_NO_CHANGE_IN_DATA: + return + + #if rc == slurm.ESLURM_ACCESS_DENIED or response.is_null: + # verify_rpc(rc) + + # Handle the error cases. + if rc == slurm.ESLURM_JOBS_RUNNING_ON_ASSOC: + raise JobsRunningError.from_response(response, rc) + elif rc == slurm.ESLURM_NO_REMOVE_DEFAULT_ACCOUNT: + raise DefaultAccountError.from_response(response, rc) + else: + verify_rpc(rc) + + + def modify(self, db_filter: AssociationFilter, changes: Association): + cdef: + SlurmList response + SlurmListItem response_ptr + list out = [] + + self.db_conn.validate() + db_filter._create() + + # Any data that isn't parsed yet or needs validation is done in this + # function. + _create_assoc_ptr(changes, self.db_conn) + + # Returns a List of char* with the associations that were modified + response = SlurmList.wrap(slurmdb_associations_modify( + self.db_conn.ptr, db_filter.ptr, changes.ptr)) + rc = slurm_errno() + self.db_conn.check_commit(rc) + + if not response.is_null and response.cnt: + for response_ptr in response: + response_str = cstr.to_unicode(response_ptr.data) + if not response_str: + continue + + # TODO: Better format + out.append(response_str) + + elif not response.is_null: + # There was no real error, but simply nothing has been modified + return None + else: + # Autodetects the last slurm error + raise RPCError() - if assoc_data.is_null: - raise RPCError(msg="Failed to get Association data from slurmdbd.") + return out - # Fetch other necessary dependencies needed for translating some - # attributes (i.e QoS IDs to its name) - qos_data = QualitiesOfService.load(db_conn=db_conn, - name_is_key=False) - tres_data = TrackableResources.load(db_conn=db_conn) - for assoc_ptr in SlurmList.iter_and_pop(assoc_data): - assoc = Association.from_ptr(assoc_ptr.data) - assoc.qos_data = qos_data - assoc.tres_data = tres_data - _parse_assoc_ptr(assoc) + def create(self, associations): + cdef: + Association assoc + AssociationList assoc_list = AssociationList(owned=False) - cluster = assoc.cluster - if cluster not in out.data: - out.data[cluster] = {} - out.data[cluster][assoc.id] = assoc + if not associations: + return - return out + self.db_conn.validate() + for i, assoc in enumerate(associations): + # Make sure to remove any duplicate associations, i.e. associations + # having the same account name set. For some reason, the slurmdbd + # doesn't like that. + if assoc not in assoc_list: + assoc_list.append(assoc) -def delete(db_conn: Connection, db_filter: AssociationFilter): - cdef: - SlurmList response - SlurmListItem response_ptr - - # TODO: Properly check if the filter is empty, cause it will then probably - # target all assocs. Or maybe that is fine and we need to clearly document - # to take caution - # if not db_filter.ids: - # return - - db_conn.validate() - a_filter._create() - - response = SlurmList.wrap(slurmdb_associations_remove( - db_conn.ptr, db_filter.ptr) - ) - rc = slurm_errno() - db_conn.check_commit(rc) - - if rc == slurm.SLURM_SUCCESS or rc == slurm.SLURM_NO_CHANGE_IN_DATA: - return - - #if rc == slurm.ESLURM_ACCESS_DENIED or response.is_null: - # verify_rpc(rc) - - # Handle the error cases. - if rc == slurm.ESLURM_JOBS_RUNNING_ON_ASSOC: - raise JobsRunningError.from_response(response, rc) - elif rc == slurm.ESLURM_NO_REMOVE_DEFAULT_ACCOUNT: - raise DefaultAccountError.from_response(response, rc) - else: + rc = slurmdb_associations_add(self.db_conn.ptr, assoc_list.info) + self.db_conn.check_commit(rc) verify_rpc(rc) -def modify(db_conn: Connection, db_filter: AssociationFilter, changes: Association): - cdef: - SlurmList response - SlurmListItem response_ptr - list out = [] - - db_conn.validate() - db_filter._create() - - # Any data that isn't parsed yet or needs validation is done in this - # function. - _create_assoc_ptr(changes, db_conn) - - # Returns a List of char* with the associations that were modified - response = SlurmList.wrap(slurmdb_associations_modify( - db_conn.ptr, db_filter.ptr, changes.ptr)) - rc = slurm_errno() - db_conn.check_commit(rc) - - if not response.is_null and response.cnt: - for response_ptr in response: - response_str = cstr.to_unicode(response_ptr.data) - if not response_str: - continue - - # TODO: Better format - out.append(response_str) - - elif not response.is_null: - # There was no real error, but simply nothing has been modified - return None - else: - # Autodetects the last slurm error - raise RPCError() - - return out - - -def create(db_conn: Connection, associations): - cdef: - Association assoc - AssociationList assoc_list = AssociationList(owned=False) - - if not associations: - return - - db_conn.validate() - - for i, assoc in enumerate(associations): - # Make sure to remove any duplicate associations, i.e. associations - # having the same account name set. For some reason, the slurmdbd - # doesn't like that. - if assoc not in assoc_list: - assoc_list.append(assoc) - - rc = slurmdb_associations_add(db_conn.ptr, assoc_list.info) - db_conn.check_commit(rc) - verify_rpc(rc) - - cdef class AssociationList(SlurmList): def __init__(self, owned=True): @@ -210,22 +213,25 @@ cdef class Associations(MultiClusterMap): val_type=Association, id_attr=Association.id, key_type=int) + self._db_conn = None @staticmethod - def load(db_conn: Connection, db_filter: AssociationFilter = None): - return load(db_conn, db_filter) + def load(db_conn: Connection, db_filter: AssociationFilter | None = None): + return db_conn.associations.load(db_filter) - def delete(self, db_conn: Connection): + def delete(self, db_conn: Connection | None = None): + db_conn = Connection.reuse(self._db_conn, db_conn) db_filter = AssociationFilter(ids=list(self.keys())) - delete(db_conn, db_filter, changes) + db_conn.associations.delete(db_filter, changes) - def modify(self, db_conn: Connection, changes: Association): + def modify(self, changes: Association, db_conn: Connection | None = None): + db_conn = Connection.reuse(self._db_conn, db_conn) db_filter = AssociationFilter(ids=list(self.keys())) - return modify(db_conn, db_filter, changes) + return db_conn.associations.modify(db_filter, changes) - @staticmethod - def create(db_conn: Connection, associations): - create(db_conn, associations) + def create(self, db_conn: Connection | None = None): + db_conn = Connection.reuse(self._db_conn, db_conn) + db_conn.associations.create(list(self.values())) cdef class AssociationFilter: @@ -329,6 +335,23 @@ cdef class Association: return self.cluster == other.cluster and self.partition == other.partition and self.account == other.account and self.user == other.user return NotImplemented +# @staticmethod +# def load(db_conn: Connection, name: str): +# user = db_conn.users.load().get(name) +# if not user: +# raise RPCError(msg=f"User {name} does not exist.") +# return user + + def create(self, db_conn: Connection = None): + db_conn = Connection.reuse(self._db_conn, db_conn) + db_conn.associations.create([self]) + + def delete(self, db_conn: Connection = None): + Associations({self.id: self}).delete(self._db_conn or db_conn) + + def modify(self, changes: Association, db_conn: Connection | None = None): + Associations({self.id: self}).modify(changes, self._db_conn or db_conn) + @property def account(self): return cstr.to_unicode(self.ptr.acct) @@ -450,7 +473,7 @@ cdef _create_assoc_ptr(Association ass, conn=None): # _set_tres_limits will also check if specified TRES are valid and # translate them to its ID which is why we need to load the current TRES # available in the system. - ass.tres_data = TrackableResources.load(db_conn=conn) + ass.tres_data = conn.tres.load() _set_tres_limits(&ass.ptr.grp_tres, ass.group_tres, ass.tres_data) _set_tres_limits(&ass.ptr.grp_tres_mins, ass.group_tres_mins, ass.tres_data) @@ -468,7 +491,7 @@ cdef _create_assoc_ptr(Association ass, conn=None): # _set_qos_list will also check if specified QoS are valid and translate # them to its ID, which is why we need to load the current QOS available # in the system. - ass.qos_data = QualitiesOfService.load(db_conn=conn) + ass.qos_data = conn.qos.load() _set_qos_list(&ass.ptr.qos_list, ass.qos, ass.qos_data) ass.ptr.grp_jobs = u32(ass.group_jobs, zero_is_noval=False) diff --git a/pyslurm/db/connection.pxd b/pyslurm/db/connection.pxd index de041c8d..0a01a31d 100644 --- a/pyslurm/db/connection.pxd +++ b/pyslurm/db/connection.pxd @@ -31,6 +31,13 @@ from pyslurm.slurm cimport ( ) +cdef class ConnectionConfig: + cdef public: + commit_on_success + rollback_on_error + reuse_connection + + cdef class ConnectionWrapper: cdef: Connection db_conn @@ -48,10 +55,12 @@ cdef class Connection: uint16_t flags cdef public: - commit_on_success - rollback_on_error + config cdef readonly: users accounts - + associations + tres + qos + jobs diff --git a/pyslurm/db/connection.pyx b/pyslurm/db/connection.pyx index c75bc041..0b76457c 100644 --- a/pyslurm/db/connection.pyx +++ b/pyslurm/db/connection.pyx @@ -1,7 +1,7 @@ ######################################################################### # connection.pyx - pyslurm slurmdbd database connection ######################################################################### -# Copyright (C) 2023 Toni Harzendorf +# Copyright (C) 2026 Toni Harzendorf # # This file is part of PySlurm # @@ -26,6 +26,24 @@ from pyslurm.core.error import RPCError, PyslurmError from contextlib import contextmanager from pyslurm.db.user import UserAPI from pyslurm.db.account import AccountAPI +from pyslurm.db.assoc import AssociationAPI +from pyslurm.db.tres import TrackableResourceAPI +from pyslurm.db.qos import QualityOfServiceAPI +from pyslurm.db.job import JobsAPI +from typing import Any + + +cdef class ConnectionConfig: + + def __init__( + self, + commit_on_success: bool = True, + rollback_on_error: bool = True, + reuse_connection: bool = True, + ): + self.commit_on_success = commit_on_success + self.rollback_on_error = rollback_on_error + self.reuse_connection = reuse_connection cdef class ConnectionWrapper: @@ -38,10 +56,17 @@ class InvalidConnectionError(PyslurmError): pass +class ConfigError(PyslurmError): + pass + + @contextmanager -def connect(commit_on_success=True, rollback_on_error=True): +def connect(config: ConnectionConfig | None = None, **kwargs: Any): """A managed Slurm DB Connection""" - connection = Connection.open() + if config is not None and kwargs: + raise ConfigError("Must provide either a config directly, or kwargs, not both") + + connection = Connection.open(config, **kwargs) try: yield connection finally: @@ -65,18 +90,34 @@ cdef class Connection: state = "open" if self.is_open else "closed" return f'pyslurm.db.{self.__class__.__name__} is {state}' + @staticmethod + def reuse( + reusable_conn: Connection | None = None, + explicit_conn: Connection | None = None + ): + if explicit_conn: + return explicit_conn + elif reusable_conn: + return reusable_conn + else: + raise InvalidConnectionError("No suitable Connection was provided") + + def apply_reuse(self, obj): + if self.config.reuse_connection: + obj._db_conn = self + def validate(self): if not self.is_open: raise InvalidConnectionError("Connection is closed") def check_commit(self, rc): - if self.commit_on_success and rc == slurm.SLURM_SUCCESS: + if self.config.commit_on_success and rc == slurm.SLURM_SUCCESS: self.commit() - elif self.rollback_on_error and rc != slurm.SLURM_SUCCESS: + elif self.config.rollback_on_error and rc != slurm.SLURM_SUCCESS: self.rollback() @staticmethod - def open(commit_on_success=True, rollback_on_error=True): + def open(config: ConnectionConfig | None = None, **kwargs: Any): """Open a new connection to the slurmdbd Raises: @@ -91,17 +132,23 @@ cdef class Connection: >>> print(connection.is_open) True """ + if config is not None and kwargs: + raise ConfigError("Must provide either a config directly, or kwargs, not both") + cdef Connection conn = Connection.__new__(Connection) conn.ptr = slurmdb_connection_get(&conn.flags) if not conn.ptr: raise RPCError(msg="Failed to open onnection to slurmdbd") - conn.commit_on_success = commit_on_success - conn.rollback_on_error = rollback_on_error + conn.config = config or ConnectionConfig(**kwargs) - # APIs + # Initialize all DB APIs conn.users = UserAPI(conn) conn.accounts = AccountAPI(conn) + conn.associations = AssociationAPI(conn) + conn.tres = TrackableResourceAPI(conn) + conn.qos = QualityOfServiceAPI(conn) + conn.jobs = JobsAPI(conn) return conn def close(self): diff --git a/pyslurm/db/job.pxd b/pyslurm/db/job.pxd index f3593c2f..4688583d 100644 --- a/pyslurm/db/job.pxd +++ b/pyslurm/db/job.pxd @@ -50,7 +50,7 @@ from pyslurm.db.util cimport ( ) from pyslurm.db.step cimport JobStep, JobSteps from pyslurm.db.stats cimport JobStatistics -from pyslurm.db.connection cimport Connection +from pyslurm.db.connection cimport Connection, ConnectionWrapper from pyslurm.utils cimport cstr from pyslurm.db.qos cimport QualitiesOfService from pyslurm.db.tres cimport ( @@ -63,6 +63,10 @@ from pyslurm.utils.uint cimport u32_parse_bool_flag from libc.stdint cimport uint32_t +cdef class JobsAPI(ConnectionWrapper): + pass + + cdef class JobFilter: """Query-Conditions for Jobs in the Slurm Database. @@ -140,6 +144,7 @@ cdef class JobFilter: cdef slurmdb_job_cond_t *ptr cdef public: + Connection _db_conn ids start_time end_time @@ -183,6 +188,7 @@ cdef class Jobs(MultiClusterMap): Total amount of requested memory in Mebibytes. """ cdef public: + Connection _db_conn stats cpus nodes @@ -364,6 +370,7 @@ cdef class Job: TrackableResources tres_data cdef public: + Connection _db_conn JobSteps steps JobStatistics stats diff --git a/pyslurm/db/job.pyx b/pyslurm/db/job.pyx index 99e2bef4..b70bae69 100644 --- a/pyslurm/db/job.pyx +++ b/pyslurm/db/job.pyx @@ -45,6 +45,200 @@ from pyslurm.utils.helpers import ( from pyslurm.enums import SchedulerType +cdef class JobsAPI(ConnectionWrapper): + + def load(self, db_filter: JobFilter | None = None): + """Load Jobs from the Slurm Database + + Implements the slurmdb_jobs_get RPC. + + Args: + db_filter (pyslurm.db.JobFilter): + A search filter that the slurmdbd will apply when retrieving + Jobs from the database. + + Returns: + (pyslurm.db.Jobs): A Collection of database Jobs. + + Raises: + (pyslurm.RPCError): When getting the Jobs from the Database was not + successful + + Examples: + Without a Filter the default behaviour applies, which is + simply retrieving all Jobs from the same day: + + >>> import pyslurm + >>> db_jobs = pyslurm.db.Jobs.load() + >>> print(db_jobs) + pyslurm.db.Jobs({1: pyslurm.db.Job(1), 2: pyslurm.db.Job(2)}) + >>> print(db_jobs[1]) + pyslurm.db.Job(1) + + Now with a Job Filter, so only Jobs that have specific Accounts + are returned: + + >>> import pyslurm + >>> accounts = ["acc1", "acc2"] + >>> db_filter = pyslurm.db.JobFilter(accounts=accounts) + >>> db_jobs = pyslurm.db.Jobs.load(db_filter) + """ + cdef: + Jobs out = Jobs() + Job job + JobFilter cond = db_filter + SlurmList job_data + SlurmListItem job_ptr + QualitiesOfService qos_data + TrackableResources tres_data + + self.db_conn.validate() + + # Prepare SQL Filter + if not db_filter: + db_filter = JobFilter() + + db_filter._db_conn = db_conn + db_filter._create() + + # Fetch Job data + job_data = SlurmList.wrap(slurmdb_jobs_get(self.db_conn.ptr, db_filter.ptr)) + if job_data.is_null: + raise RPCError(msg="Failed to get Jobs from slurmdbd") + + # Fetch other necessary dependencies needed for translating some + # attributes (i.e QoS IDs to its name) + qos_data = self.db_conn.qos.load(name_is_key=False) + tres_data = self.db_conn.tres.load() + + # TODO: How to handle the possibility of duplicate job ids that could + # appear if IDs on a cluster are reset? + for job_ptr in SlurmList.iter_and_pop(job_data): + job = Job.from_ptr(job_ptr.data) + job.qos_data = qos_data + job.tres_data = tres_data + job._create_steps() + job.stats = JobStatistics.from_steps(job.steps) + + elapsed = job.elapsed_time if job.elapsed_time else 0 + cpus = job.cpus if job.cpus else 1 + job.stats.elapsed_cpu_time = elapsed * cpus + + cluster = job.cluster + if cluster not in out.data: + out.data[cluster] = {} + out[cluster][job.id] = job + + out._add_stats(job) + + return out + + def modify(self, db_filter: JobFilter | Jobs, changes: Job): + """Modify Slurm database Jobs. + + Implements the slurm_job_modify RPC. + + Args: + db_filter (Union[pyslurm.db.JobFilter, pyslurm.db.Jobs]): + A filter to decide which Jobs should be modified. + changes (pyslurm.db.Job): + Another [pyslurm.db.Job][] object that contains all the + changes to apply. Check the `Other Parameters` of the + [pyslurm.db.Job][] class to see which properties can be + modified. + + Returns: + (list[int]): A list of Jobs that were modified + + Raises: + (pyslurm.RPCError): When a failure modifying the Jobs occurred. + + Examples: + In its simplest form, you can do something like this: + + >>> import pyslurm + >>> + >>> db_filter = pyslurm.db.JobFilter(ids=[9999]) + >>> changes = pyslurm.db.Job(comment="A comment for the job") + >>> modified_jobs = pyslurm.db.Jobs.modify(db_filter, changes) + >>> print(modified_jobs) + [9999] + + In the above example, the changes will be automatically committed + if successful. + You can however also control this manually by providing your own + connection object: + + >>> import pyslurm + >>> + >>> db_conn = pyslurm.db.Connection.open() + >>> db_filter = pyslurm.db.JobFilter(ids=[9999]) + >>> changes = pyslurm.db.Job(comment="A comment for the job") + >>> modified_jobs = pyslurm.db.Jobs.modify( + ... db_filter, changes, db_conn) + + Now you can first examine which Jobs have been modified: + + >>> print(modified_jobs) + [9999] + + And then you can actually commit the changes: + + >>> db_conn.commit() + + You can also explicitly rollback these changes instead of + committing, so they will not become active: + + >>> db_conn.rollback() + """ + cdef: + JobFilter cond + SlurmList response + SlurmListItem response_ptr + list out = [] + + self.db_conn.validate() + + # Prepare SQL Filter + if isinstance(db_filter, Jobs): + job_ids = [job.id for job in self] + cond = JobFilter(ids=job_ids) + + cond._db_conn = db_conn + cond._create() + + # Modify Jobs, get the result + # This returns a List of char* with the Jobs ids that were + # modified + response = SlurmList.wrap( + slurmdb_job_modify(self.db_conn.ptr, cond.ptr, changes.ptr)) + + if not response.is_null and response.cnt: + for response_ptr in response: + response_str = cstr.to_unicode(response_ptr.data) + if not response_str: + continue + + # The strings in the list returned above have a structure + # like this: + # + # " submitted at " + # + # We are just interested in the Job-ID, so extract it + job_id = response_str.split(" ")[0] + if job_id and job_id.isdigit(): + out.append(int(job_id)) + + elif not response.is_null: + # There was no real error, but simply nothing has been modified + raise RPCError(msg="Nothing was modified") + else: + # Autodetects the last slurm error + raise RPCError() + + return out + + cdef class JobFilter: def __cinit__(self): @@ -53,6 +247,7 @@ cdef class JobFilter: def __init__(self, **kwargs): for k, v in kwargs.items(): setattr(self, k, v) + self._db_conn = None def __dealloc__(self): self._dealloc() @@ -75,7 +270,7 @@ cdef class JobFilter: return None qos_id_list = [] - qos_data = QualitiesOfService.load() + qos_data = self._db_conn.qos.load(self._db_conn, name_is_key=True) for user_input in self.qos: found = False for qos in qos_data.values(): @@ -199,6 +394,7 @@ cdef class Jobs(MultiClusterMap): id_attr=Job.id, key_type=int) self._reset_stats() + self._db_conn = None @staticmethod def load(Connection db_conn, JobFilter db_filter=None): @@ -239,75 +435,9 @@ cdef class Jobs(MultiClusterMap): >>> db_filter = pyslurm.db.JobFilter(accounts=accounts) >>> db_jobs = pyslurm.db.Jobs.load(db_filter) """ - cdef: - Jobs out = Jobs() - Job job - JobFilter cond = db_filter - SlurmList job_data - SlurmListItem job_ptr - QualitiesOfService qos_data - TrackableResources tres_data - - db_conn.validate() - - # Prepare SQL Filter - if not db_filter: - cond = JobFilter() - cond._create() + return db_conn.jobs.load(db_filter) - # Fetch Job data - job_data = SlurmList.wrap(slurmdb_jobs_get(db_conn.ptr, cond.ptr)) - if job_data.is_null: - raise RPCError(msg="Failed to get Jobs from slurmdbd") - - # Fetch other necessary dependencies needed for translating some - # attributes (i.e QoS IDs to its name) - qos_data = QualitiesOfService.load(db_conn=db_conn, - name_is_key=False) - tres_data = TrackableResources.load(db_conn=db_conn) - - # TODO: How to handle the possibility of duplicate job ids that could - # appear if IDs on a cluster are reset? - for job_ptr in SlurmList.iter_and_pop(job_data): - job = Job.from_ptr(job_ptr.data) - job.qos_data = qos_data - job.tres_data = tres_data - job._create_steps() - job.stats = JobStatistics.from_steps(job.steps) - - elapsed = job.elapsed_time if job.elapsed_time else 0 - cpus = job.cpus if job.cpus else 1 - job.stats.elapsed_cpu_time = elapsed * cpus - - cluster = job.cluster - if cluster not in out.data: - out.data[cluster] = {} - out[cluster][job.id] = job - - out._add_stats(job) - - return out - - def _reset_stats(self): - self.stats = JobStatistics() - self.cpus = 0 - self.nodes = 0 - self.memory = 0 - - def _add_stats(self, job): - self.stats.add(job.stats) - self.cpus += job.cpus - self.nodes += job.num_nodes - self.memory += job.memory - - def calc_stats(self): - """(Re)Calculate Statistics for the Job Collection.""" - self._reset_stats() - for job in self.values(): - self._add_stats(job) - - @staticmethod - def modify(Connection db_conn, db_filter, Job changes): + def modify(self, changes: Job, db_conn: Connection | None = None): """Modify Slurm database Jobs. Implements the slurm_job_modify RPC. @@ -367,52 +497,27 @@ cdef class Jobs(MultiClusterMap): >>> db_conn.rollback() """ - cdef: - JobFilter cond - SlurmList response - SlurmListItem response_ptr - list out = [] - - db_conn.validate() - - # Prepare SQL Filter - if isinstance(db_filter, Jobs): - job_ids = [job.id for job in self] - cond = JobFilter(ids=job_ids) - else: - cond = db_filter - cond._create() - - # Modify Jobs, get the result - # This returns a List of char* with the Jobs ids that were - # modified - response = SlurmList.wrap( - slurmdb_job_modify(db_conn.ptr, cond.ptr, changes.ptr)) - - if not response.is_null and response.cnt: - for response_ptr in response: - response_str = cstr.to_unicode(response_ptr.data) - if not response_str: - continue + db_conn = Connection.reuse(self._db_conn, db_conn) + db_filter = JobFilter(names=list(self.keys())) + return db_conn.jobs.modify(db_filter, changes) - # The strings in the list returned above have a structure - # like this: - # - # " submitted at " - # - # We are just interested in the Job-ID, so extract it - job_id = response_str.split(" ")[0] - if job_id and job_id.isdigit(): - out.append(int(job_id)) + def _reset_stats(self): + self.stats = JobStatistics() + self.cpus = 0 + self.nodes = 0 + self.memory = 0 - elif not response.is_null: - # There was no real error, but simply nothing has been modified - raise RPCError(msg="Nothing was modified") - else: - # Autodetects the last slurm error - raise RPCError() + def _add_stats(self, job): + self.stats.add(job.stats) + self.cpus += job.cpus + self.nodes += job.num_nodes + self.memory += job.memory - return out + def calc_stats(self): + """(Re)Calculate Statistics for the Job Collection.""" + self._reset_stats() + for job in self.values(): + self._add_stats(job) cdef class Job: @@ -450,10 +555,18 @@ cdef class Job: return wrap @staticmethod - def load(job_id, cluster=None, with_script=False, with_env=False): + def load( + db_conn: Connection, + job_id: int, + cluster = None, + with_script = False, + with_env = False + ): """Load the information for a specific Job from the Database. Args: + db_conn (pyslurm.db.Connection): + A slurmdbd connection. job_id (int): ID of the Job to be loaded. cluster (str): @@ -527,7 +640,7 @@ cdef class Job: def __repr__(self): return f'pyslurm.db.{self.__class__.__name__}({self.id})' - def modify(self, Connection db_conn, changes): + def modify(self, changes, db_conn: Connection | None = None): """Modify a Slurm database Job. Args: @@ -544,8 +657,7 @@ cdef class Job: Raises: (pyslurm.RPCError): When modifying the Job failed. """ - cdef JobFilter jfilter = JobFilter(ids=[self.id]) - Jobs.modify(db_conn, jfilter, changes) + Jobs({self.id: self}).modify(changes, self._db_conn or db_conn) @property def account(self): diff --git a/pyslurm/db/qos.pxd b/pyslurm/db/qos.pxd index 02131974..d69d3ed5 100644 --- a/pyslurm/db/qos.pxd +++ b/pyslurm/db/qos.pxd @@ -38,17 +38,22 @@ from pyslurm.db.util cimport ( SlurmListItem, make_char_list, ) -from pyslurm.db.connection cimport Connection +from pyslurm.db.connection cimport Connection, ConnectionWrapper from pyslurm.utils cimport cstr from pyslurm.utils.uint cimport u16_set_bool_flag cdef _set_qos_list(list_t **in_list, vals, QualitiesOfService data) -cdef class QualitiesOfService(dict): +cdef class QualityOfServiceAPI(ConnectionWrapper): pass +cdef class QualitiesOfService(dict): + cdef public: + Connection _db_conn + + cdef class QualityOfServiceFilter: cdef slurmdb_qos_cond_t *ptr @@ -61,6 +66,9 @@ cdef class QualityOfServiceFilter: cdef class QualityOfService: + cdef public: + Connection _db_conn + cdef slurmdb_qos_rec_t *ptr @staticmethod diff --git a/pyslurm/db/qos.pyx b/pyslurm/db/qos.pyx index f896eb86..e780f989 100644 --- a/pyslurm/db/qos.pyx +++ b/pyslurm/db/qos.pyx @@ -26,14 +26,13 @@ from pyslurm.core.error import RPCError from pyslurm.utils.helpers import instance_to_dict -cdef class QualitiesOfService(dict): - - def __init__(self): - pass +cdef class QualityOfServiceAPI(ConnectionWrapper): - @staticmethod - def load(Connection db_conn, QualityOfServiceFilter db_filter=None, - name_is_key=True): + def load( + self, + db_filter: QualityOfServiceFilter | None = None, + name_is_key: bool = True + ): """Load QoS data from the Database Args: @@ -45,17 +44,16 @@ cdef class QualitiesOfService(dict): cdef: QualitiesOfService out = QualitiesOfService() QualityOfService qos - QualityOfServiceFilter cond = db_filter SlurmList qos_data SlurmListItem qos_ptr - db_conn.validate() + self.db_conn.validate() if not db_filter: - cond = QualityOfServiceFilter() - cond._create() + db_filter = QualityOfServiceFilter() + db_filter._create() - qos_data = SlurmList.wrap(slurmdb_qos_get(db_conn.ptr, cond.ptr)) + qos_data = SlurmList.wrap(slurmdb_qos_get(self.db_conn.ptr, db_filter.ptr)) if qos_data.is_null: raise RPCError(msg="Failed to get QoS data from slurmdbd") @@ -63,12 +61,38 @@ cdef class QualitiesOfService(dict): # Setup QOS objects for qos_ptr in SlurmList.iter_and_pop(qos_data): qos = QualityOfService.from_ptr(qos_ptr.data) + self.db_conn.apply_reuse(qos) _id = qos.name if name_is_key else qos.id out[_id] = qos return out +cdef class QualitiesOfService(dict): + + def __init__(self, qos={}, **kwargs): + super().__init__() + self.update(qos) + self.update(kwargs) + self._db_conn = None + + @staticmethod + def load( + db_conn: Connection, + db_filter: Connection | None = None, + name_is_key: bool = True + ): + """Load QoS data from the Database + + Args: + name_is_key (bool, optional): + By default, the keys in this dict are the names of each QoS. + If this is set to `False`, then the unique ID of the QoS will + be used as dict keys. + """ + return db_conn.qos.load(db_filter=db_filter, name_is_key=name_is_key) + + cdef class QualityOfServiceFilter: def __cinit__(self): @@ -164,7 +188,7 @@ cdef class QualityOfService: return instance_to_dict(self, recursive) @staticmethod - def load(name): + def load(db_conn: Connection, name: str): """Load the information for a specific Quality of Service. Args: @@ -179,8 +203,7 @@ cdef class QualityOfService: (pyslurm.RPCError): If requesting the information from the database was not successful. """ - qfilter = QualityOfServiceFilter(names=[name]) - qos = QualitiesOfService.load(qfilter).get(name) + qos = db_conn.qos.load(name_is_key=True).get(name) if not qos: raise RPCError(msg=f"QualityOfService {name} does not exist") diff --git a/pyslurm/db/tres.pxd b/pyslurm/db/tres.pxd index 952c8a7d..53dfe28d 100644 --- a/pyslurm/db/tres.pxd +++ b/pyslurm/db/tres.pxd @@ -37,7 +37,7 @@ from pyslurm.db.util cimport ( SlurmList, SlurmListItem, ) -from pyslurm.db.connection cimport Connection +from pyslurm.db.connection cimport Connection, ConnectionWrapper cdef find_tres_count(char *tres_str, typ, on_noval=*, on_inf=*) cdef find_tres_limit(char *tres_str, typ) @@ -46,6 +46,10 @@ cdef _tres_ids_to_names(char *tres_str, dict tres_id_map) cdef _set_tres_limits(char **dest, src, tres_data) +cdef class TrackeblResourceAPI(ConnectionWrapper): + pass + + cdef class FilesystemResources(dict): """Collection of Filesystem TRES. This inherits from `dict`.""" pass diff --git a/pyslurm/db/tres.pyx b/pyslurm/db/tres.pyx index 12b90fa6..53436add 100644 --- a/pyslurm/db/tres.pyx +++ b/pyslurm/db/tres.pyx @@ -42,6 +42,37 @@ TRES_NAME_REQUIRED = ["fs", "license", "interconnect", "gres"] gres_pattern = re.compile(r'[/:]') +cdef class TrackableResourceAPI(ConnectionWrapper): + + def load(self, db_filter: TrackableResourceFilter = None): + """Load Trackable Resources from the Database.""" + cdef: + TrackableResources out = TrackableResources() + TrackableResource tres + SlurmList tres_data + SlurmListItem tres_ptr + + self.db_conn.validate() + + if not db_filter: + db_filter = TrackableResourceFilter() + + db_filter._create() + tres_data = SlurmList.wrap(slurmdb_tres_get(self.db_conn.ptr, db_filter.ptr)) + + if tres_data.is_null: + raise RPCError(msg="Failed to get TRES data from slurmdbd") + + # Setup TRES objects + for tres_ptr in SlurmList.iter_and_pop(tres_data): + tres = TrackableResource.from_ptr( + tres_ptr.data) + out._handle_tres_type(tres) + out._id_map[tres.id] = tres + + return out + + cdef class FilesystemResources(dict): def to_dict(self, recursive=False): @@ -296,28 +327,7 @@ cdef class TrackableResources: @staticmethod def load(Connection db_conn): """Load Trackable Resources from the Database.""" - cdef: - TrackableResources out = TrackableResources() - TrackableResource tres - SlurmList tres_data - SlurmListItem tres_ptr - TrackableResourceFilter db_filter = TrackableResourceFilter() - - db_conn.validate() - db_filter._create() - tres_data = SlurmList.wrap(slurmdb_tres_get(db_conn.ptr, db_filter.ptr)) - - if tres_data.is_null: - raise RPCError(msg="Failed to get TRES data from slurmdbd") - - # Setup TRES objects - for tres_ptr in SlurmList.iter_and_pop(tres_data): - tres = TrackableResource.from_ptr( - tres_ptr.data) - out._handle_tres_type(tres) - out._id_map[tres.id] = tres - - return out + return db_conn.tres.load() @staticmethod cdef find_count_in_str(char *tres_str, typ, on_noval=0, on_inf=0): diff --git a/pyslurm/db/user.pxd b/pyslurm/db/user.pxd index b639edfc..d5a27091 100644 --- a/pyslurm/db/user.pxd +++ b/pyslurm/db/user.pxd @@ -61,11 +61,8 @@ cdef class UserAPI(ConnectionWrapper): cdef class Users(dict): - pass - - -# cdef class UserAddRequest: -# cdef slurmdb_add_assoc_cond_t *ptr + cdef public: + Connection _db_conn cdef class UserFilter: @@ -106,6 +103,7 @@ cdef class User: associations coordinators wckeys + Connection _db_conn cdef readonly: default_association diff --git a/pyslurm/db/user.pyx b/pyslurm/db/user.pyx index ed049d9e..e79f2cf1 100644 --- a/pyslurm/db/user.pyx +++ b/pyslurm/db/user.pyx @@ -65,12 +65,12 @@ cdef class UserAPI(ConnectionWrapper): if user_data.is_null: raise RPCError(msg="Failed to get User data from slurmdbd") - qos_data = QualitiesOfService.load(db_conn=self.db_conn, - name_is_key=False) - tres_data = TrackableResources.load(db_conn=self.db_conn) + qos_data = self.db_conn.qos.load(name_is_key=False) + tres_data = self.db_conn.tres.load() for user_ptr in SlurmList.iter_and_pop(user_data): user = User.from_ptr(user_ptr.data) + self.db_conn.apply_reuse(user) out[user.name] = user assoc_data = SlurmList.wrap(user.ptr.assoc_list, owned=False) @@ -80,10 +80,12 @@ cdef class UserAPI(ConnectionWrapper): assoc.tres_data = tres_data _parse_assoc_ptr(assoc) user.associations.append(assoc) + self.db_conn.apply_reuse(assoc) if assoc.user == user.name: user.default_association = assoc + self.db_conn.apply_reuse(out) return out @@ -226,7 +228,7 @@ cdef class UserAPI(ConnectionWrapper): verify_rpc(rc) # TODO: Maybe don't create the associations automatically? And don't do # any hidden stuff? - Associations.create(self.db_conn, assocs_to_add) + self.db_conn.associations.create(assocs_to_add) cdef class Users(dict): @@ -235,20 +237,24 @@ cdef class Users(dict): super().__init__() self.update(users) self.update(kwargs) + self._db_conn = None @staticmethod def load(db_conn: Connection, db_filter: UserFilter = None): return db_conn.users.load(db_filter) - def delete(self, Connection db_conn): + def delete(self, db_conn: Connection | None = None): + db_conn = Connection.reuse(self._db_conn, db_conn) db_filter = UserFilter(names=list(self.keys())) db_conn.users.delete(db_filter) - def modify(self, db_conn: Connection, changes: User): + def modify(self, changes: User, db_conn: Connection | None = None): + db_conn = Connection.reuse(self._db_conn, db_conn) db_filter = UserFilter(names=list(self.keys())) return db_conn.users.modify(db_filter, changes) - def create(self, db_conn: Connection): + def create(self, db_conn: Connection | None = None): + db_conn = Connection.reuse(self._db_conn, db_conn) db_conn.users.create(list(self.values())) @@ -351,20 +357,20 @@ cdef class User: return NotImplemented @staticmethod - def load(Connection db_conn, name): + def load(db_conn: Connection, name: str): user = db_conn.users.load().get(name) if not user: raise RPCError(msg=f"User {name} does not exist.") return user - def create(self, Connection db_conn): - db_conn.users.create([self]) + def create(self, db_conn: Connection = None): + Users({self.name: self}).create(self._db_conn or db_conn) - def delete(self, Connection db_conn): - Users({self.name: self}).delete(db_conn) + def delete(self, db_conn: Connection = None): + Users({self.name: self}).delete(self._db_conn or db_conn) - def modify(self, Connection db_conn, User changes): - Users({self.name: self}).modify(db_conn, changes) + def modify(self, changes: User, db_conn: Connection | None = None): + Users({self.name: self}).modify(changes, self._db_conn or db_conn) @property def name(self): diff --git a/tests/integration/test_assoc.py b/tests/integration/test_assoc.py index 7170ef20..5a62ba07 100644 --- a/tests/integration/test_assoc.py +++ b/tests/integration/test_assoc.py @@ -35,7 +35,7 @@ def _modify_account(account, conn): changes = Account(description=new_desc) assert account.description != new_desc assoc_before = account.association.to_dict(recursive=True) - account.modify(conn, changes) + account.modify(changes, conn) account = Account.load(conn, account.name) assoc_after = account.association.to_dict(recursive=True) assert account.description == new_desc @@ -49,7 +49,7 @@ def _modify_user(user, conn): ) assert user.admin_level == pyslurm.AdminLevel.NONE assoc_before = user.default_association.to_dict(recursive=True) - user.modify(conn, user_changes) + user.modify(user_changes, conn) user = User.load(conn, user.name) assoc_after = user.default_association.to_dict(recursive=True) assert user.admin_level == pyslurm.AdminLevel.ADMINISTRATOR @@ -58,25 +58,25 @@ def _modify_user(user, conn): def _load_assoc(assoc_id, conn): - assocs = pyslurm.db.Associations.load(conn) + assocs = conn.associations.load() return assocs.get(assoc_id) def _load_account(name, conn): - accounts = pyslurm.db.Accounts.load(conn) + accounts = conn.accounts.load() assert len(accounts) return accounts.get(name) def _load_user(name, conn): - users = pyslurm.db.Users.load(conn) + users = conn.users.load() assert len(users) return users.get(name) def _delete_account(account, conn): account = Account.load(conn, account.name) - account.delete(conn) + account.delete() conn.commit() assert not _load_account(account.name, conn) assert not _load_assoc(account.association.id, conn) @@ -119,7 +119,7 @@ def _add_user(user, conn): def _delete_user(user, conn): assoc_id = user.default_association.id user = User.load(conn, user.name) - user.delete(conn) + user.delete() conn.commit() assert not _load_user(user.name, conn) assert not _load_assoc(assoc_id, conn) diff --git a/tests/integration/test_db_job.py b/tests/integration/test_db_job.py index 310df51f..a992c65e 100644 --- a/tests/integration/test_db_job.py +++ b/tests/integration/test_db_job.py @@ -38,7 +38,9 @@ def test_load_single(submit_job): job = submit_job() util.wait() - db_job = pyslurm.db.Job.load(job.id) + + with pyslurm.db.connect() as conn: + db_job = pyslurm.db.Job.load(job.id) assert db_job.id == job.id @@ -49,7 +51,10 @@ def test_load_single(submit_job): def test_parse_all(submit_job): job = submit_job() util.wait() - db_job = pyslurm.db.Job.load(job.id) + + with pyslurm.db.connect() as conn: + db_job = pyslurm.db.Job.load(job.id) + job_dict = db_job.to_dict() assert job_dict["stats"] @@ -60,8 +65,9 @@ def test_to_json(submit_job): job = submit_job() util.wait() - jfilter = pyslurm.db.JobFilter(ids=[job.id]) - jobs = pyslurm.db.Jobs.load(jfilter) + with pyslurm.db.connect() as conn: + jfilter = pyslurm.db.JobFilter(ids=[job.id]) + jobs = conn.jobs.load(jfilter) json_data = jobs.to_json() dict_data = json.loads(json_data) @@ -74,29 +80,30 @@ def test_modify(submit_job): job = submit_job() util.wait(5) - jfilter = pyslurm.db.JobFilter(ids=[job.id]) - changes = pyslurm.db.Job(comment="test comment") - pyslurm.db.Jobs.modify(jfilter, changes) + with pyslurm.db.connect() as conn: + jfilter = pyslurm.db.JobFilter(ids=[job.id]) + changes = pyslurm.db.Job(comment="test comment") + pyslurm.db.Jobs.modify(jfilter, changes) job = pyslurm.db.Job.load(job.id) assert job.comment == "test comment" -def test_modify_with_existing_conn(submit_job): +def test_modify_with_no_auto_commit(submit_job): job = submit_job() util.wait(5) - conn = pyslurm.db.Connection.open() - jfilter = pyslurm.db.JobFilter(ids=[job.id]) - changes = pyslurm.db.Job(comment="test comment") - pyslurm.db.Jobs.modify(jfilter, changes, conn) + with pyslurm.db.connect(commit_on_success=False) as conn: + jfilter = pyslurm.db.JobFilter(ids=[job.id]) + changes = pyslurm.db.Job(comment="test comment") + conn.jobs.modify(jfilter, changes) - job = pyslurm.db.Job.load(job.id) - assert job.comment != "test comment" + job = pyslurm.db.Job.load(job.id) + assert job.comment != "test comment" - conn.commit() - job = pyslurm.db.Job.load(job.id) - assert job.comment == "test comment" + conn.commit() + job = pyslurm.db.Job.load(job.id) + assert job.comment == "test comment" def test_if_steps_exist(submit_job): @@ -128,12 +135,16 @@ def test_load_with_script(submit_job): script = util.create_job_script() job = submit_job(script=script) util.wait(5) - db_job = pyslurm.db.Job.load(job.id, with_script=True) + + with pyslurm.db.connect() as conn: + db_job = pyslurm.db.Job.load(job.id, with_script=True) assert db_job.script == script def test_load_with_env(submit_job): job = submit_job() util.wait(5) - db_job = pyslurm.db.Job.load(job.id, with_env=True) + + with pyslurm.db.connect() as conn: + db_job = pyslurm.db.Job.load(job.id, with_env=True) assert db_job.environment diff --git a/tests/integration/test_db_qos.py b/tests/integration/test_db_qos.py index e1cde024..9a041a12 100644 --- a/tests/integration/test_db_qos.py +++ b/tests/integration/test_db_qos.py @@ -27,29 +27,32 @@ def test_load_single(): - qos = pyslurm.db.QualityOfService.load("normal") + with pyslurm.db.connect() as conn: + qos = pyslurm.db.QualityOfService.load(conn, "normal") - assert qos.name == "normal" - assert qos.id == 1 + assert qos.name == "normal" + assert qos.id == 1 - with pytest.raises(pyslurm.RPCError): - pyslurm.db.QualityOfService.load("qos_non_existent") + with pytest.raises(pyslurm.RPCError): + pyslurm.db.QualityOfService.load(conn, "qos_non_existent") def test_parse_all(submit_job): - qos = pyslurm.db.QualityOfService.load("normal") - qos_dict = qos.to_dict() + with pyslurm.db.connect() as conn: + qos = pyslurm.db.QualityOfService.load(conn, "normal") + qos_dict = qos.to_dict() - assert qos_dict - assert qos_dict["name"] == qos.name + assert qos_dict + assert qos_dict["name"] == qos.name def test_load_all(): - qos = pyslurm.db.QualitiesOfService.load() - assert qos - + with pyslurm.db.connect() as conn: + qos = conn.qos.load() + assert qos def test_load_with_filter_name(): - qfilter = pyslurm.db.QualityOfServiceFilter(names=["non_existent"]) - qos = pyslurm.db.QualitiesOfService.load(qfilter) - assert not qos + with pyslurm.db.connect() as conn: + db_filter = pyslurm.db.QualityOfServiceFilter(names=["non_existent"]) + qos = conn.qos.load(db_filter) + assert not qos From ee046b2f242241eacd090ab596bd86d1315e3d9a Mon Sep 17 00:00:00 2001 From: Toni Harzendorf Date: Fri, 27 Feb 2026 22:54:00 +0100 Subject: [PATCH 09/13] wip with proper type hints --- pyslurm/core/error.pyx | 35 ++++++++- pyslurm/db/account.pyx | 52 ++++++++------ pyslurm/db/assoc.pyx | 97 ++++++++++++++++--------- pyslurm/db/job.pyx | 120 ++++++++++++------------------- pyslurm/db/qos.pyx | 25 ++++--- pyslurm/db/tres.pyx | 3 + pyslurm/db/user.pyx | 55 ++++++++------ tests/integration/test_assoc.py | 90 ++++++++++++++++------- tests/integration/test_db_job.py | 106 +++++++++++++++++---------- 9 files changed, 364 insertions(+), 219 deletions(-) diff --git a/pyslurm/core/error.pyx b/pyslurm/core/error.pyx index f186a5b0..954aa2a6 100644 --- a/pyslurm/core/error.pyx +++ b/pyslurm/core/error.pyx @@ -27,6 +27,19 @@ from pyslurm cimport slurm cimport libc.errno +def _check_modify_arguments(changes, **kwargs): + if changes is None and not kwargs: + raise ArgumentError("Nothing to change was provided") + + if changes is not None and kwargs: + raise ArgumentError("Provide either a changes object or keyword arguments, not both") + + +def _get_modify_arguments_for(cls, changes, **kwargs): + _check_modify_arguments(changes, **kwargs) + return changes or cls(**kwargs) + + def slurm_strerror(errno): """Convert a slurm errno to a string. @@ -69,7 +82,15 @@ class PyslurmError(Exception): """The base Exception for all Pyslurm errors.""" -class RPCError(PyslurmError): +class ClientError(PyslurmError): + pass + + +class ServerError(PyslurmError): + pass + + +class RPCError(ServerError): """Exception for handling Slurm RPC errors. Args: @@ -100,6 +121,18 @@ class RPCError(PyslurmError): super().__init__(self.msg) +class InvalidUsageError(ClientError): + pass + + +class ArgumentError(InvalidUsageError): + pass + + +class NotFoundError(RPCError): + pass + + def verify_rpc(errno, msg=None): """Verify a Slurm RPC diff --git a/pyslurm/db/account.pyx b/pyslurm/db/account.pyx index 4338acc4..1d369ff2 100644 --- a/pyslurm/db/account.pyx +++ b/pyslurm/db/account.pyx @@ -1,7 +1,7 @@ ######################################################################### # account.pyx - pyslurm slurmdbd account api ######################################################################### -# Copyright (C) 2023 Toni Harzendorf +# Copyright (C) 2026 Toni Harzendorf # # This file is part of PySlurm # @@ -22,7 +22,13 @@ # cython: c_string_type=unicode, c_string_encoding=default # cython: language_level=3 -from pyslurm.core.error import RPCError, verify_rpc, slurm_errno +from pyslurm.core.error import ( + RPCError, + slurm_errno, + verify_rpc, + NotFoundError, + _get_modify_arguments_for, +) from pyslurm.utils.helpers import ( instance_to_dict, user_to_uid, @@ -34,11 +40,12 @@ from pyslurm.db.error import ( JobsRunningError, parse_basic_response, ) +from typing import Any, Union, Optional, List, Dict cdef class AccountAPI(ConnectionWrapper): - def load(self, db_filter: AccountFilter = None): + def load(self, db_filter: Optional[AccountFilter] = None): cdef: Accounts out = Accounts() Account account @@ -125,17 +132,20 @@ cdef class AccountAPI(ConnectionWrapper): verify_rpc(rc) - def modify(self, db_filter: AccountFilter, changes: Account): + def modify(self, db_filter: AccountFilter, changes: Optional[Account] = None, **kwargs: Any): cdef: SlurmList response + Account _changes SlurmListItem response_ptr list out = [] + _changes = _get_modify_arguments_for(Account, changes, **kwargs) + self.db_conn.validate() db_filter._create() response = SlurmList.wrap(slurmdb_accounts_modify( - self.db_conn.ptr, db_filter.ptr, changes.ptr) + self.db_conn.ptr, db_filter.ptr, _changes.ptr) ) rc = slurm_errno() self.db_conn.check_commit(rc) @@ -165,7 +175,7 @@ cdef class AccountAPI(ConnectionWrapper): # raise RPCError(msg="Failed to modify accounts.") - def create(self, accounts): + def create(self, accounts: List[str]): cdef: Account account SlurmList account_list @@ -193,27 +203,27 @@ cdef class AccountAPI(ConnectionWrapper): cdef class Accounts(dict): - def __init__(self, accounts={}, **kwargs): + def __init__(self, accounts={}, **kwargs: Any): super().__init__() self.update(accounts) self.update(kwargs) self._db_conn = None @staticmethod - def load(db_conn: Connection, db_filter: AccountFilter = None): + def load(db_conn: Connection, db_filter: Optional[AccountFilter] = None): return db_conn.accounts.load(db_filter) - def delete(self, db_conn: Connection | None = None): + def delete(self, db_conn: Optional[Connection] = None): db_conn = Connection.reuse(self._db_conn, db_conn) db_filter = AccountFilter(names=list(self.keys())) db_conn.accounts.delete(db_filter) - def modify(self, changes: Account, db_conn: Connection | None = None): + def modify(self, changes: Optional[Account] = None, db_conn: Optional[Connection] = None, **kwargs: Any): db_conn = Connection.reuse(self._db_conn, db_conn) db_filter = AccountFilter(names=list(self.keys())) - return db_conn.accounts.modify(db_filter, changes) + return db_conn.accounts.modify(db_filter, changes, **kwargs) - def create(self, db_conn: Connection | None = None): + def create(self, db_conn: Optional[Connection] = None): db_conn = Connection.reuse(self._db_conn, db_conn) db_conn.accounts.create(list(self.values())) @@ -223,7 +233,7 @@ cdef class AccountFilter: def __cinit__(self): self.ptr = NULL - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any): for k, v in kwargs.items(): setattr(self, k, v) @@ -265,7 +275,7 @@ cdef class Account: def __cinit__(self): self.ptr = NULL - def __init__(self, name=None, description=None, organization=None, **kwargs): + def __init__(self, name: str = None, description: str = None, organization: str = None, **kwargs: Any): self._alloc_impl() self._init_defaults() self.name = name @@ -306,7 +316,7 @@ cdef class Account: wrap._init_defaults() return wrap - def to_dict(self, recursive=False): + def to_dict(self, recursive: bool = False): """Database Account information formatted as a dictionary. Returns: @@ -314,7 +324,7 @@ cdef class Account: """ return instance_to_dict(self, recursive) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if isinstance(other, Account): return self.name == other.name return NotImplemented @@ -325,17 +335,17 @@ cdef class Account: if not account: # TODO: Maybe don't raise here and just return None and let the # Caller handle it? - raise RPCError(msg=f"Account {name} does not exist.") + raise NotFoundError(msg=f"Account {name} does not exist.") return account - def create(self, db_conn: Connection | None = None): + def create(self, db_conn: Optional[Connection] = None): Accounts({self.name: self}).create(self._db_conn or db_conn) - def delete(self, db_conn: Connection | None = None): + def delete(self, db_conn: Optional[Connection] = None): Accounts({self.name: self}).delete(self._db_conn or db_conn) - def modify(self, changes: Account, db_conn: Connection | None = None): - Accounts({self.name: self}).modify(changes, self._db_conn or db_conn) + def modify(self, changes: Optional[Account] = None, db_conn: Optional[Connection] = None, **kwargs: Any): + Accounts({self.name: self}).modify(changes=changes, db_conn=(self._db_conn or db_conn), **kwargs) @property def name(self): diff --git a/pyslurm/db/assoc.pyx b/pyslurm/db/assoc.pyx index 8a086828..d92ef671 100644 --- a/pyslurm/db/assoc.pyx +++ b/pyslurm/db/assoc.pyx @@ -1,7 +1,7 @@ ######################################################################### # assoc.pyx - pyslurm slurmdbd association api ######################################################################### -# Copyright (C) 2023 Toni Harzendorf +# Copyright (C) 2026 Toni Harzendorf # # This file is part of PySlurm # @@ -22,7 +22,13 @@ # cython: c_string_type=unicode, c_string_encoding=default # cython: language_level=3 -from pyslurm.core.error import RPCError, verify_rpc, slurm_errno +from pyslurm.core.error import ( + RPCError, + slurm_errno, + verify_rpc, + NotFoundError, + _get_modify_arguments_for, +) from pyslurm.utils.helpers import ( instance_to_dict, user_to_uid, @@ -31,11 +37,12 @@ from pyslurm.utils.uint import * from pyslurm import settings from pyslurm import xcollections from pyslurm.db.error import JobsRunningError, DefaultAccountError +from typing import Any, Union, Optional, List, Dict cdef class AssociationAPI(ConnectionWrapper): - def load(self, db_filter: AssociationFilter = None): + def load(self, db_filter: Optional[AssociationFilter] = None): cdef: Associations out = Associations() Association assoc @@ -112,23 +119,28 @@ cdef class AssociationAPI(ConnectionWrapper): else: verify_rpc(rc) - - def modify(self, db_filter: AssociationFilter, changes: Association): + def modify(self, db_filter: AssociationFilter, changes: Optional[Association] = None, **kwargs: Any): cdef: + Association _changes SlurmList response SlurmListItem response_ptr list out = [] + # TODO: prohibit mixing multiple user assocs with account assocs + # This is not possible, and the request will simply affect nothing... + + _changes = _get_modify_arguments_for(Association, changes, **kwargs) + self.db_conn.validate() db_filter._create() # Any data that isn't parsed yet or needs validation is done in this # function. - _create_assoc_ptr(changes, self.db_conn) + _create_assoc_ptr(_changes, self.db_conn) # Returns a List of char* with the associations that were modified response = SlurmList.wrap(slurmdb_associations_modify( - self.db_conn.ptr, db_filter.ptr, changes.ptr)) + self.db_conn.ptr, db_filter.ptr, _changes.ptr)) rc = slurm_errno() self.db_conn.check_commit(rc) @@ -150,8 +162,7 @@ cdef class AssociationAPI(ConnectionWrapper): return out - - def create(self, associations): + def create(self, associations: List[Association]): cdef: Association assoc AssociationList assoc_list = AssociationList(owned=False) @@ -215,23 +226,45 @@ cdef class Associations(MultiClusterMap): key_type=int) self._db_conn = None + def _do_api_call(self, fn, **kwargs): + res_user = [] + res_accts = [] + db_filter_users = AssociationFilter(users=[], ids=[]) + db_filter_accounts = AssociationFilter(accounts=[], ids=[]) + + # We need to split User Associations from Account associations + # If we mix both, the request will not do anything... + for assoc in self.values(): + if assoc.user: + db_filter_users.users.append(assoc.user) + db_filter_users.ids.append(assoc.id) + else: + db_filter_accounts.accounts.append(assoc.account) + db_filter_accounts.ids.append(assoc.id) + + if db_filter_users.users: + res_user = fn(db_filter_users, **kwargs) + + if db_filter_accounts.accounts: + res_accts = fn(db_filter_accounts, **kwargs) + + return res_user + res_accts + @staticmethod - def load(db_conn: Connection, db_filter: AssociationFilter | None = None): + def load(db_conn: Connection, db_filter: Optional[AssociationFilter] = None): return db_conn.associations.load(db_filter) - def delete(self, db_conn: Connection | None = None): + def delete(self, db_conn: Optional[Connection] = None): db_conn = Connection.reuse(self._db_conn, db_conn) - db_filter = AssociationFilter(ids=list(self.keys())) - db_conn.associations.delete(db_filter, changes) + self._do_api_call(db_conn.associations.delete, changes=changes) - def modify(self, changes: Association, db_conn: Connection | None = None): + def modify(self, changes: Optional[Association] = None, db_conn: Optional[Connection] = None, **kwargs: Any): db_conn = Connection.reuse(self._db_conn, db_conn) - db_filter = AssociationFilter(ids=list(self.keys())) - return db_conn.associations.modify(db_filter, changes) + return self._do_api_call(db_conn.associations.modify, changes=changes, **kwargs) - def create(self, db_conn: Connection | None = None): + def create(self, db_conn: Optional[Connection] = None): db_conn = Connection.reuse(self._db_conn, db_conn) - db_conn.associations.create(list(self.values())) + self._do_api_call(db_conn.associations.create, associations=list(self.values())) cdef class AssociationFilter: @@ -239,7 +272,7 @@ cdef class AssociationFilter: def __cinit__(self): self.ptr = NULL - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any): for k, v in kwargs.items(): setattr(self, k, v) @@ -282,7 +315,7 @@ cdef class Association: self.ptr = NULL self.owned = True - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any): self._alloc_impl() self.id = 0 @@ -321,7 +354,7 @@ cdef class Association: wrap.ptr = in_ptr return wrap - def to_dict(self, recursive=False): + def to_dict(self, recursive: bool = False): """Database Association information formatted as a dictionary. Returns: @@ -329,28 +362,28 @@ cdef class Association: """ return instance_to_dict(self, recursive) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if isinstance(other, Association): # return self.id == other.id and self.cluster == other.cluster return self.cluster == other.cluster and self.partition == other.partition and self.account == other.account and self.user == other.user return NotImplemented -# @staticmethod -# def load(db_conn: Connection, name: str): -# user = db_conn.users.load().get(name) -# if not user: -# raise RPCError(msg=f"User {name} does not exist.") -# return user + @staticmethod + def load(db_conn: Connection, id: Union[str, int]): + assoc = db_conn.associations.load().get(int(id)) + if not assoc: + raise NotFoundError(msg=f"Association with id '{id}' does not exist.") + return assoc - def create(self, db_conn: Connection = None): + def create(self, db_conn: Optional[Connection] = None): db_conn = Connection.reuse(self._db_conn, db_conn) db_conn.associations.create([self]) - def delete(self, db_conn: Connection = None): + def delete(self, db_conn: Optional[Connection] = None): Associations({self.id: self}).delete(self._db_conn or db_conn) - def modify(self, changes: Association, db_conn: Connection | None = None): - Associations({self.id: self}).modify(changes, self._db_conn or db_conn) + def modify(self, changes: Optional[Association] = None, db_conn: Optional[Connection] = None, **kwargs: Any): + Associations({self.id: self}).modify(changes=changes, db_conn=(self._db_conn or db_conn), **kwargs) @property def account(self): diff --git a/pyslurm/db/job.pyx b/pyslurm/db/job.pyx index b70bae69..e396ba7e 100644 --- a/pyslurm/db/job.pyx +++ b/pyslurm/db/job.pyx @@ -23,7 +23,12 @@ # cython: language_level=3 from typing import Union, Any -from pyslurm.core.error import RPCError, PyslurmError +from pyslurm.core.error import ( + RPCError, + ArgumentError, + NotFoundError, + _get_modify_arguments_for, +) from pyslurm.utils.uint import * from pyslurm import settings from pyslurm import xcollections @@ -98,7 +103,7 @@ cdef class JobsAPI(ConnectionWrapper): if not db_filter: db_filter = JobFilter() - db_filter._db_conn = db_conn + db_filter._db_conn = self.db_conn db_filter._create() # Fetch Job data @@ -119,6 +124,7 @@ cdef class JobsAPI(ConnectionWrapper): job.tres_data = tres_data job._create_steps() job.stats = JobStatistics.from_steps(job.steps) + self.db_conn.apply_reuse(job) elapsed = job.elapsed_time if job.elapsed_time else 0 cpus = job.cpus if job.cpus else 1 @@ -133,7 +139,12 @@ cdef class JobsAPI(ConnectionWrapper): return out - def modify(self, db_filter: JobFilter | Jobs, changes: Job): + def modify( + self, + db_filter: JobFilter | Jobs, + changes: Job | None = None, + **kwargs: Any + ): """Modify Slurm database Jobs. Implements the slurm_job_modify RPC. @@ -193,25 +204,30 @@ cdef class JobsAPI(ConnectionWrapper): """ cdef: JobFilter cond + Job _changes SlurmList response SlurmListItem response_ptr list out = [] + _changes = _get_modify_arguments_for(Job, changes, **kwargs) + self.db_conn.validate() # Prepare SQL Filter if isinstance(db_filter, Jobs): job_ids = [job.id for job in self] cond = JobFilter(ids=job_ids) + else: + cond = db_filter - cond._db_conn = db_conn + cond._db_conn = self.db_conn cond._create() # Modify Jobs, get the result # This returns a List of char* with the Jobs ids that were # modified response = SlurmList.wrap( - slurmdb_job_modify(self.db_conn.ptr, cond.ptr, changes.ptr)) + slurmdb_job_modify(self.db_conn.ptr, cond.ptr, _changes.ptr)) if not response.is_null and response.cnt: for response_ptr in response: @@ -437,14 +453,15 @@ cdef class Jobs(MultiClusterMap): """ return db_conn.jobs.load(db_filter) - def modify(self, changes: Job, db_conn: Connection | None = None): - """Modify Slurm database Jobs. - - Implements the slurm_job_modify RPC. + def modify( + self, + changes: Job | None = None, + db_conn: Connection | None = None, + **kwargs: Any + ): + """Modify all Database Jobs in this collection. Args: - db_filter (Union[pyslurm.db.JobFilter, pyslurm.db.Jobs]): - A filter to decide which Jobs should be modified. changes (pyslurm.db.Job): Another [pyslurm.db.Job][] object that contains all the changes to apply. Check the `Other Parameters` of the @@ -452,54 +469,19 @@ cdef class Jobs(MultiClusterMap): modified. db_conn (pyslurm.db.Connection): A Connection to the slurmdbd. + **kwargs (Any): + Instead of providing a separate `Job` object that has the + changes, you can also pass them as keyword args. Returns: (list[int]): A list of Jobs that were modified Raises: (pyslurm.RPCError): When a failure modifying the Jobs occurred. - - Examples: - In its simplest form, you can do something like this: - - >>> import pyslurm - >>> - >>> db_filter = pyslurm.db.JobFilter(ids=[9999]) - >>> changes = pyslurm.db.Job(comment="A comment for the job") - >>> modified_jobs = pyslurm.db.Jobs.modify(db_filter, changes) - >>> print(modified_jobs) - [9999] - - In the above example, the changes will be automatically committed - if successful. - You can however also control this manually by providing your own - connection object: - - >>> import pyslurm - >>> - >>> db_conn = pyslurm.db.Connection.open() - >>> db_filter = pyslurm.db.JobFilter(ids=[9999]) - >>> changes = pyslurm.db.Job(comment="A comment for the job") - >>> modified_jobs = pyslurm.db.Jobs.modify( - ... db_filter, changes, db_conn) - - Now you can first examine which Jobs have been modified: - - >>> print(modified_jobs) - [9999] - - And then you can actually commit the changes: - - >>> db_conn.commit() - - You can also explicitly rollback these changes instead of - committing, so they will not become active: - - >>> db_conn.rollback() """ db_conn = Connection.reuse(self._db_conn, db_conn) - db_filter = JobFilter(names=list(self.keys())) - return db_conn.jobs.modify(db_filter, changes) + db_filter = JobFilter(ids=list(self.keys())) + return db_conn.jobs.modify(db_filter=db_filter, changes=changes, **kwargs) def _reset_stats(self): self.stats = JobStatistics() @@ -583,27 +565,15 @@ cdef class Job: (pyslurm.db.Job): Returns a new Database Job instance Raises: - (pyslurm.RPCError): If requesting the information for the database + (pyslurm.NotFoundError): If requesting the information for the database Job was not successful. - - Examples: - >>> import pyslurm - >>> db_job = pyslurm.db.Job.load(10000) - - In the above example, attributes like `script` and `environment` - are not populated. You must explicitly request one of them to be - loaded: - - >>> import pyslurm - >>> db_job = pyslurm.db.Job.load(10000, with_script=True) - >>> print(db_job.script) """ cluster = settings.LOCAL_CLUSTER if not cluster else cluster jfilter = JobFilter(ids=[int(job_id)], clusters=[cluster], with_script=with_script, with_env=with_env) - job = Jobs.load(jfilter).get((cluster, int(job_id))) + job = db_conn.jobs.load(jfilter).get((cluster, int(job_id))) if not job: - raise RPCError(msg=f"Job {job_id} does not exist on " + raise NotFoundError(msg=f"Job {job_id} does not exist on " f"Cluster {cluster}") # TODO: There might be multiple entries when job ids were reset. @@ -620,6 +590,8 @@ cdef class Job: step = JobStep.from_ptr(step_ptr.data) step.tres_data = self.tres_data self.steps[step.id] = step + # TODO: + # self.db_conn.apply_reuse(step) def as_dict(self): return self.to_dict() @@ -640,8 +612,13 @@ cdef class Job: def __repr__(self): return f'pyslurm.db.{self.__class__.__name__}({self.id})' - def modify(self, changes, db_conn: Connection | None = None): - """Modify a Slurm database Job. + def modify( + self, + changes: Job | None = None, + db_conn: Connection | None = None, + **kwargs: Any + ): + """Modify this Database Job. Args: changes (pyslurm.db.Job): @@ -657,7 +634,8 @@ cdef class Job: Raises: (pyslurm.RPCError): When modifying the Job failed. """ - Jobs({self.id: self}).modify(changes, self._db_conn or db_conn) + jobs = Jobs({self.id: self}) + jobs.modify(changes=changes, db_conn=(self._db_conn or db_conn), **kwargs) @property def account(self): @@ -948,10 +926,6 @@ cdef class Job: def wckey_id(self): return u32_parse(self.ptr.wckeyid) -# @property -# def wckey_id(self): -# return u32_parse(self.ptr.wckeyid) - @property def working_directory(self): return cstr.to_unicode(self.ptr.work_dir) diff --git a/pyslurm/db/qos.pyx b/pyslurm/db/qos.pyx index e780f989..80fa7805 100644 --- a/pyslurm/db/qos.pyx +++ b/pyslurm/db/qos.pyx @@ -22,15 +22,22 @@ # cython: c_string_type=unicode, c_string_encoding=default # cython: language_level=3 -from pyslurm.core.error import RPCError +from pyslurm.core.error import ( + RPCError, + slurm_errno, + verify_rpc, + NotFoundError, + _get_modify_arguments_for, +) from pyslurm.utils.helpers import instance_to_dict +from typing import Any, Union, Optional, List, Dict cdef class QualityOfServiceAPI(ConnectionWrapper): def load( self, - db_filter: QualityOfServiceFilter | None = None, + db_filter: Optional[QualityOfServiceFilter] = None, name_is_key: bool = True ): """Load QoS data from the Database @@ -70,7 +77,7 @@ cdef class QualityOfServiceAPI(ConnectionWrapper): cdef class QualitiesOfService(dict): - def __init__(self, qos={}, **kwargs): + def __init__(self, qos={}, **kwargs: Any): super().__init__() self.update(qos) self.update(kwargs) @@ -79,7 +86,7 @@ cdef class QualitiesOfService(dict): @staticmethod def load( db_conn: Connection, - db_filter: Connection | None = None, + db_filter: Optional[Connection] = None, name_is_key: bool = True ): """Load QoS data from the Database @@ -98,7 +105,7 @@ cdef class QualityOfServiceFilter: def __cinit__(self): self.ptr = NULL - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any): for k, v in kwargs.items(): setattr(self, k, v) @@ -152,9 +159,11 @@ cdef class QualityOfService: def __cinit__(self): self.ptr = NULL - def __init__(self, name=None): + def __init__(self, name: str = None, **kwargs: Any): self._alloc_impl() self.name = name + for k, v in kwargs.items(): + setattr(self, k, v) def __dealloc__(self): self._dealloc_impl() @@ -179,7 +188,7 @@ cdef class QualityOfService: def __repr__(self): return f'pyslurm.db.{self.__class__.__name__}({self.name})' - def to_dict(self, recursive = False): + def to_dict(self, recursive: bool = False): """Database QualityOfService information formatted as a dictionary. Returns: @@ -205,7 +214,7 @@ cdef class QualityOfService: """ qos = db_conn.qos.load(name_is_key=True).get(name) if not qos: - raise RPCError(msg=f"QualityOfService {name} does not exist") + raise NotFoundError(msg=f"QualityOfService {name} does not exist") return qos diff --git a/pyslurm/db/tres.pyx b/pyslurm/db/tres.pyx index 53436add..faba34d9 100644 --- a/pyslurm/db/tres.pyx +++ b/pyslurm/db/tres.pyx @@ -492,4 +492,7 @@ def _validate_tres_single(local_tres, dict tres_id_map): cdef _set_tres_limits(char **dest, src, tres_data): + if not src: + return + # TODO: Allow users to pass a dict cstr.from_dict(dest, src._validate(tres_data)) diff --git a/pyslurm/db/user.pyx b/pyslurm/db/user.pyx index e79f2cf1..d851dea5 100644 --- a/pyslurm/db/user.pyx +++ b/pyslurm/db/user.pyx @@ -1,7 +1,7 @@ ######################################################################### # user.pyx - pyslurm slurmdbd user api ######################################################################### -# Copyright (C) 2023 Toni Harzendorf +# Copyright (C) 2026 Toni Harzendorf # # This file is part of PySlurm # @@ -22,7 +22,13 @@ # cython: c_string_type=unicode, c_string_encoding=default # cython: language_level=3 -from pyslurm.core.error import RPCError, slurm_errno, verify_rpc +from pyslurm.core.error import ( + RPCError, + slurm_errno, + verify_rpc, + NotFoundError, + _get_modify_arguments_for, +) from pyslurm.utils.helpers import ( instance_to_dict, user_to_uid, @@ -32,11 +38,12 @@ from pyslurm import xcollections from pyslurm.utils.enums import SlurmEnum from pyslurm.db.error import JobsRunningError, parse_basic_response from pyslurm.enums import AdminLevel +from typing import Any, Union, Optional, List, Dict cdef class UserAPI(ConnectionWrapper): - def load(self, db_filter: UserFilter = None): + def load(self, db_filter: Optional[UserFilter] = None): cdef: Users out = Users() User user @@ -127,9 +134,10 @@ cdef class UserAPI(ConnectionWrapper): verify_rpc(rc) - def modify(self, db_filter: UserFilter, changes: User): + def modify(self, db_filter: UserFilter, changes: Optional[User] = None, **kwargs: Any): cdef: SlurmList response + User _changes SlurmListItem response_ptr list out = [] @@ -139,11 +147,13 @@ cdef class UserAPI(ConnectionWrapper): #if not db_filter.names: # return + _changes = _get_modify_arguments_for(User, changes, **kwargs) + self.db_conn.validate() db_filter._create() response = SlurmList.wrap(slurmdb_users_modify( - self.db_conn.ptr, db_filter.ptr, changes.ptr) + self.db_conn.ptr, db_filter.ptr, _changes.ptr) ) rc = slurm_errno() self.db_conn.check_commit(rc) @@ -182,7 +192,7 @@ cdef class UserAPI(ConnectionWrapper): # raise RPCError(msg="Failed to modify users.") - def create(self, users): + def create(self, users: List[str]): cdef: User user SlurmList user_list @@ -240,20 +250,20 @@ cdef class Users(dict): self._db_conn = None @staticmethod - def load(db_conn: Connection, db_filter: UserFilter = None): + def load(db_conn: Connection, db_filter: Optional[UserFilter] = None): return db_conn.users.load(db_filter) - def delete(self, db_conn: Connection | None = None): + def delete(self, db_conn: Optional[Connection] = None): db_conn = Connection.reuse(self._db_conn, db_conn) db_filter = UserFilter(names=list(self.keys())) db_conn.users.delete(db_filter) - def modify(self, changes: User, db_conn: Connection | None = None): + def modify(self, changes: Optional[User] = None, db_conn: Optional[Connection] = None, **kwargs: Any): db_conn = Connection.reuse(self._db_conn, db_conn) db_filter = UserFilter(names=list(self.keys())) - return db_conn.users.modify(db_filter, changes) + return db_conn.users.modify(db_filter, changes=changes, **kwargs) - def create(self, db_conn: Connection | None = None): + def create(self, db_conn: Optional[Connection] = None): db_conn = Connection.reuse(self._db_conn, db_conn) db_conn.users.create(list(self.values())) @@ -263,7 +273,7 @@ cdef class UserFilter: def __cinit__(self): self.ptr = NULL - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any): for k, v in kwargs.items(): setattr(self, k, v) @@ -303,7 +313,7 @@ cdef class User: def __cinit__(self): self.ptr = NULL - def __init__(self, name=None, **kwargs): + def __init__(self, name: str = None, **kwargs: Any): self._alloc_impl() self.name = name self._init_defaults() @@ -343,7 +353,7 @@ cdef class User: wrap._init_defaults() return wrap - def to_dict(self, recursive=False): + def to_dict(self, recursive: bool = False): """Database User information formatted as a dictionary. Returns: @@ -351,7 +361,7 @@ cdef class User: """ return instance_to_dict(self, recursive) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if isinstance(other, User): return self.name == other.name return NotImplemented @@ -360,17 +370,22 @@ cdef class User: def load(db_conn: Connection, name: str): user = db_conn.users.load().get(name) if not user: - raise RPCError(msg=f"User {name} does not exist.") + raise NotFoundError(msg=f"User {name} does not exist.") return user - def create(self, db_conn: Connection = None): + def create(self, db_conn: Optional[Connection] = None): Users({self.name: self}).create(self._db_conn or db_conn) - def delete(self, db_conn: Connection = None): + def delete(self, db_conn: Optional[Connection] = None): Users({self.name: self}).delete(self._db_conn or db_conn) - def modify(self, changes: User, db_conn: Connection | None = None): - Users({self.name: self}).modify(changes, self._db_conn or db_conn) + def modify( + self, + changes: Optional[User] = None, + db_conn: Optional[Connection] = None, + **kwargs: Any + ): + Users({self.name: self}).modify(changes=changes, db_conn=(self._db_conn or db_conn), **kwargs) @property def name(self): diff --git a/tests/integration/test_assoc.py b/tests/integration/test_assoc.py index 5a62ba07..628778a7 100644 --- a/tests/integration/test_assoc.py +++ b/tests/integration/test_assoc.py @@ -30,33 +30,60 @@ ) -def _modify_account(account, conn): - new_desc = "this is a new description" - changes = Account(description=new_desc) - assert account.description != new_desc +def _modify_account(account, conn, with_kwargs, **kwargs): + changes = Account(**kwargs) + + assert account.description != changes.description assoc_before = account.association.to_dict(recursive=True) - account.modify(changes, conn) + + if with_kwargs: + account.modify(db_conn=conn, **kwargs) + else: + account.modify(changes, conn) + account = Account.load(conn, account.name) assoc_after = account.association.to_dict(recursive=True) - assert account.description == new_desc + assert account.description == changes.description # Make sure we didn't change anything in the Association assert assoc_before == assoc_after -def _modify_user(user, conn): - user_changes = User( - admin_level = pyslurm.AdminLevel.ADMINISTRATOR - ) +def _modify_user(user, conn, with_kwargs, **kwargs): + changes = User(**kwargs) + assert user.admin_level == pyslurm.AdminLevel.NONE assoc_before = user.default_association.to_dict(recursive=True) - user.modify(user_changes, conn) + + if with_kwargs: + user.modify(db_conn=conn, **kwargs) + else: + user.modify(changes, conn) + user = User.load(conn, user.name) assoc_after = user.default_association.to_dict(recursive=True) - assert user.admin_level == pyslurm.AdminLevel.ADMINISTRATOR + assert user.admin_level == changes.admin_level # Make sure we didn't change anything in the Association assert assoc_before == assoc_after +def _modify_assoc(assoc, conn, with_kwargs, **kwargs): + changes = Association(**kwargs) + + assert assoc.group_jobs == "UNLIMITED" + assert assoc.group_submit_jobs == "UNLIMITED" + + if with_kwargs: + assoc.modify(db_conn=conn, **kwargs) + else: + assoc.modify(changes, conn) + + assoc = Association.load(conn, assoc.id) + assert assoc.group_jobs == changes.group_jobs + assert assoc.group_submit_jobs == changes.group_submit_jobs + assert assoc.group_jobs != "UNLIMITED" + assert assoc.group_submit_jobs != "UNLIMITED" + + def _load_assoc(assoc_id, conn): assocs = conn.associations.load() return assocs.get(assoc_id) @@ -127,12 +154,33 @@ def _delete_user(user, conn): def _test_modify_delete(user, account, conn): assert conn.is_open - _modify_account(account, conn) - _modify_user(user, conn) + _modify_account(account, conn, with_kwargs=False, description="this is a new description") + _modify_account(account, conn, with_kwargs=True, description="another description") + + _modify_user(user, conn, with_kwargs=False, admin_level="ADMINISTRATOR") + _modify_user(user, conn, with_kwargs=True, admin_level="OPERATOR") + + _modify_assoc(user.default_association, conn, with_kwargs=False, + group_jobs=10, group_submit_jobs=20) + _modify_assoc(user.default_association, conn, with_kwargs=True, + group_jobs=50, group_submit_jobs=100) + _delete_account(account, conn) _delete_user(user, conn) +def _test_api(user, account, conn): + # Save them, before reloading + user_name = user.name + acc_name = account.name + + account = _add_account(account, conn) + user = _add_user(user, conn) + assert user.name == user_name + assert account.name == acc_name + _test_modify_delete(user, account, conn) + + def test_user_and_account_no_assoc(): random_name = str(uuid.uuid4())[:8] user_name = f"user_{random_name}" @@ -141,12 +189,7 @@ def test_user_and_account_no_assoc(): with pyslurm.db.connect() as conn: account = Account(name=acc_name) user = User(name=user_name, default_account=acc_name) - - account = _add_account(account, conn) - user = _add_user(user, conn) - assert user.name == user_name - assert account.name == acc_name - _test_modify_delete(user, account, conn) + _test_api(user, account, conn) def test_user_and_accounts_with_assoc_empty(): @@ -159,9 +202,4 @@ def test_user_and_accounts_with_assoc_empty(): account = Account(name=acc_name, association=account_assoc) user_assoc = Association(user=user_name, account=acc_name) user = User(name=user_name, associations=[user_assoc]) - - account = _add_account(account, conn) - user = _add_user(user, conn) - assert user.name == user_name - assert account.name == acc_name - _test_modify_delete(user, account, conn) + _test_api(user, account, conn) diff --git a/tests/integration/test_db_job.py b/tests/integration/test_db_job.py index a992c65e..73f82bbf 100644 --- a/tests/integration/test_db_job.py +++ b/tests/integration/test_db_job.py @@ -40,12 +40,12 @@ def test_load_single(submit_job): util.wait() with pyslurm.db.connect() as conn: - db_job = pyslurm.db.Job.load(job.id) + db_job = pyslurm.db.Job.load(conn, job.id) - assert db_job.id == job.id + assert db_job.id == job.id - with pytest.raises(pyslurm.RPCError): - pyslurm.db.Job.load(0) + with pytest.raises(pyslurm.core.error.NotFoundError): + pyslurm.db.Job.load(conn, 0) def test_parse_all(submit_job): @@ -53,7 +53,7 @@ def test_parse_all(submit_job): util.wait() with pyslurm.db.connect() as conn: - db_job = pyslurm.db.Job.load(job.id) + db_job = pyslurm.db.Job.load(conn, job.id) job_dict = db_job.to_dict() @@ -80,55 +80,85 @@ def test_modify(submit_job): job = submit_job() util.wait(5) + # With explicit separate Job object as changes with pyslurm.db.connect() as conn: - jfilter = pyslurm.db.JobFilter(ids=[job.id]) - changes = pyslurm.db.Job(comment="test comment") - pyslurm.db.Jobs.modify(jfilter, changes) - - job = pyslurm.db.Job.load(job.id) - assert job.comment == "test comment" - + comment = "comment two" -def test_modify_with_no_auto_commit(submit_job): - job = submit_job() - util.wait(5) + job = pyslurm.db.Job.load(conn, job.id) + assert job.comment != comment - with pyslurm.db.connect(commit_on_success=False) as conn: jfilter = pyslurm.db.JobFilter(ids=[job.id]) - changes = pyslurm.db.Job(comment="test comment") + changes = pyslurm.db.Job(comment=comment) conn.jobs.modify(jfilter, changes) + job = pyslurm.db.Job.load(conn, job.id) + assert job.comment == comment + + # With filter via **kwargs + with pyslurm.db.connect(commit_on_success=False) as conn: + comment = "comment two" + job = pyslurm.db.Job.load(conn, job.id) + assert job.comment != comment - job = pyslurm.db.Job.load(job.id) - assert job.comment != "test comment" + jfilter = pyslurm.db.JobFilter(ids=[job.id]) + conn.jobs.modify(jfilter, comment=comment) conn.commit() - job = pyslurm.db.Job.load(job.id) - assert job.comment == "test comment" + job = pyslurm.db.Job.load(conn, job.id) + assert job.comment == comment + + with pytest.raises(pyslurm.core.error.ArgumentError): + conn.jobs.modify(jfilter) + + # Without filter, using modify() on the instance + # By default, connections are inherited + with pyslurm.db.connect() as conn: + comment = "comment three" + job = pyslurm.db.Job.load(conn, job.id) + assert job.comment != comment + + job.modify(comment=comment) + + job = pyslurm.db.Job.load(conn, job.id) + assert job.comment == comment + + # Without inherited connection, not supplying a connection will fail + with pyslurm.db.connect(reuse_connection=False) as conn: + comment = "comment four" + job = pyslurm.db.Job.load(conn, job.id) + assert job.comment != comment + + job.modify(db_conn=conn, comment=comment) + + job = pyslurm.db.Job.load(conn, job.id) + assert job.comment == comment + + with pytest.raises(pyslurm.db.connection.InvalidConnectionError): + job.modify(comment=comment) -def test_if_steps_exist(submit_job): - # TODO - pass +# def test_if_steps_exist(submit_job): +# # TODO +# pass -def test_load_with_filter_node(submit_job): - # TODO - pass +# def test_load_with_filter_node(submit_job): +# # TODO +# pass -def test_load_with_filter_qos(submit_job): - # TODO - pass +# def test_load_with_filter_qos(submit_job): +# # TODO +# pass -def test_load_with_filter_cluster(submit_job): - # TODO - pass +# def test_load_with_filter_cluster(submit_job): +# # TODO +# pass -def test_load_with_filter_multiple(submit_job): - # TODO - pass +# def test_load_with_filter_multiple(submit_job): +# # TODO +# pass def test_load_with_script(submit_job): @@ -137,7 +167,7 @@ def test_load_with_script(submit_job): util.wait(5) with pyslurm.db.connect() as conn: - db_job = pyslurm.db.Job.load(job.id, with_script=True) + db_job = pyslurm.db.Job.load(conn, job.id, with_script=True) assert db_job.script == script @@ -146,5 +176,5 @@ def test_load_with_env(submit_job): util.wait(5) with pyslurm.db.connect() as conn: - db_job = pyslurm.db.Job.load(job.id, with_env=True) + db_job = pyslurm.db.Job.load(conn, job.id, with_env=True) assert db_job.environment From bd687c47f18214f9619837ec8ea647159993e31f Mon Sep 17 00:00:00 2001 From: Toni Harzendorf Date: Sat, 28 Feb 2026 00:06:06 +0100 Subject: [PATCH 10/13] wip --- pyslurm/db/connection.pxd | 3 +- pyslurm/db/connection.pyx | 34 +++++++++++-------- pyslurm/db/error.pyx | 56 ++++++++++++++++---------------- pyslurm/db/job.pyx | 15 +++++---- pyslurm/utils/enums.pyx | 19 ++++++++--- tests/integration/test_assoc.py | 6 ++-- tests/integration/test_db_job.py | 2 +- 7 files changed, 76 insertions(+), 59 deletions(-) diff --git a/pyslurm/db/connection.pxd b/pyslurm/db/connection.pxd index 0a01a31d..233c9a9c 100644 --- a/pyslurm/db/connection.pxd +++ b/pyslurm/db/connection.pxd @@ -33,8 +33,7 @@ from pyslurm.slurm cimport ( cdef class ConnectionConfig: cdef public: - commit_on_success - rollback_on_error + transaction_mode reuse_connection diff --git a/pyslurm/db/connection.pyx b/pyslurm/db/connection.pyx index 0b76457c..d514dfa6 100644 --- a/pyslurm/db/connection.pyx +++ b/pyslurm/db/connection.pyx @@ -30,19 +30,24 @@ from pyslurm.db.assoc import AssociationAPI from pyslurm.db.tres import TrackableResourceAPI from pyslurm.db.qos import QualityOfServiceAPI from pyslurm.db.job import JobsAPI -from typing import Any +from typing import Any, Optional +from pyslurm.utils.enums import StrEnum +from enum import auto + + +class TransactionMode(StrEnum): + PER_OPERATION = auto() + MANUAL = auto() cdef class ConnectionConfig: def __init__( self, - commit_on_success: bool = True, - rollback_on_error: bool = True, + transaction_mode: TransactionMode = TransactionMode.PER_OPERATION, reuse_connection: bool = True, ): - self.commit_on_success = commit_on_success - self.rollback_on_error = rollback_on_error + self.transaction_mode = transaction_mode self.reuse_connection = reuse_connection @@ -61,7 +66,7 @@ class ConfigError(PyslurmError): @contextmanager -def connect(config: ConnectionConfig | None = None, **kwargs: Any): +def connect(config: Optional[ConnectionConfig] = None, **kwargs: Any): """A managed Slurm DB Connection""" if config is not None and kwargs: raise ConfigError("Must provide either a config directly, or kwargs, not both") @@ -92,8 +97,8 @@ cdef class Connection: @staticmethod def reuse( - reusable_conn: Connection | None = None, - explicit_conn: Connection | None = None + reusable_conn: Optional[Connection] = None, + explicit_conn: Optional[Connection] = None ): if explicit_conn: return explicit_conn @@ -111,13 +116,16 @@ cdef class Connection: raise InvalidConnectionError("Connection is closed") def check_commit(self, rc): - if self.config.commit_on_success and rc == slurm.SLURM_SUCCESS: + if self.config.transaction_mode != TransactionMode.PER_OPERATION: + return + + if rc == slurm.SLURM_SUCCESS: self.commit() - elif self.config.rollback_on_error and rc != slurm.SLURM_SUCCESS: + else: self.rollback() @staticmethod - def open(config: ConnectionConfig | None = None, **kwargs: Any): + def open(config: Optional[ConnectionConfig] = None, **kwargs: Any): """Open a new connection to the slurmdbd Raises: @@ -169,7 +177,7 @@ cdef class Connection: def commit(self): """Commit recent changes.""" if not self.is_open: - return + raise InvalidConnectionError("Tried to commit when Connection is already closed.") if slurmdb_connection_commit(self.ptr, 1) == slurm.SLURM_ERROR: raise RPCError("Failed to commit database changes.") @@ -177,7 +185,7 @@ cdef class Connection: def rollback(self): """Rollback recent changes.""" if not self.is_open: - return + raise InvalidConnectionError("Tried to rollback when Connection is already closed.") if slurmdb_connection_commit(self.ptr, 0) == slurm.SLURM_ERROR: raise RPCError("Failed to rollback database changes.") diff --git a/pyslurm/db/error.pyx b/pyslurm/db/error.pyx index e4e7d80d..3bb84622 100644 --- a/pyslurm/db/error.pyx +++ b/pyslurm/db/error.pyx @@ -24,6 +24,16 @@ from pyslurm.core.error import RPCError, slurm_errno, verify_rpc from pyslurm.db.util cimport SlurmList, SlurmListItem +import re + + +# The response involving assoc modification and deletions is a string that can +# be in the following form: +# +# C = X A = X U = X P = X +# +# And we have to parse this stuff... The Partition (P) is optional +assoc_str_pattern = re.compile(r'(\w)\s*=\s*(\w+)') class AssociationChangeInfo: @@ -76,42 +86,28 @@ def get_responses(SlurmList response): yield response_str -def parse_basic_response(SlurmList response): - return get_responses(response) - - -def parse_default_account_errors(SlurmList response): - cdef SlurmListItem response_ptr +def parse_assoc_str(value): + matches = assoc_str_pattern.findall(value) + return dict(matches) - assocs = [] - for response_ptr in response: - response_str = response_ptr.to_str() - if not response_str: - continue - # The response is a string in the following form: - # C = X A = X U = X P = X - # - # And we have to parse this stuff... The Partition (P) is optional - resp = response_str.rstrip().lstrip() - splitted = resp.split(" ") - values = [] - for item in splitted: - if not item: - continue +def get_assoc_response(SlurmList response): + for resp in get_responses(response): + yield parse_assoc_str(resp) - key, value = item.split("=") - values.append(value.strip()) +def parse_default_account_errors_2(SlurmList response): + assocs = [] + for item in get_assoc_response(response): info = AssociationChangeInfo( - cluster = values[0], - account = values[1], - user = values[2], + cluster = item["C"], + account = item["A"], + user = item["U"], ) assoc_str = f"{info.cluster}-{info.account}-{info.user}" - if len(values) > 3: - info.partition = values[3] + if len(item) > 3: + info.partition = item["P"] assoc_str = f"{assoc_str}-{info.partition}" assocs.append(info) @@ -119,6 +115,10 @@ def parse_default_account_errors(SlurmList response): return assocs +def parse_basic_response(SlurmList response): + return list(get_responses(response)) + + def parse_running_job_errors(SlurmList response): cdef SlurmListItem response_ptr diff --git a/pyslurm/db/job.pyx b/pyslurm/db/job.pyx index e396ba7e..9890b3a9 100644 --- a/pyslurm/db/job.pyx +++ b/pyslurm/db/job.pyx @@ -48,11 +48,12 @@ from pyslurm.utils.helpers import ( gres_from_tres_dict, ) from pyslurm.enums import SchedulerType +from typing import Any, Optional cdef class JobsAPI(ConnectionWrapper): - def load(self, db_filter: JobFilter | None = None): + def load(self, db_filter: Optional[JobFilter] = None): """Load Jobs from the Slurm Database Implements the slurmdb_jobs_get RPC. @@ -141,8 +142,8 @@ cdef class JobsAPI(ConnectionWrapper): def modify( self, - db_filter: JobFilter | Jobs, - changes: Job | None = None, + db_filter: Union[JobFilter, Jobs], + changes: Optional[Job] = None, **kwargs: Any ): """Modify Slurm database Jobs. @@ -455,8 +456,8 @@ cdef class Jobs(MultiClusterMap): def modify( self, - changes: Job | None = None, - db_conn: Connection | None = None, + changes: Optional[Job] = None, + db_conn: Optional[Connection] = None, **kwargs: Any ): """Modify all Database Jobs in this collection. @@ -614,8 +615,8 @@ cdef class Job: def modify( self, - changes: Job | None = None, - db_conn: Connection | None = None, + changes: Optional[Job] = None, + db_conn: Optional[Connection] = None, **kwargs: Any ): """Modify this Database Job. diff --git a/pyslurm/utils/enums.pyx b/pyslurm/utils/enums.pyx index d44fd29b..3b842b84 100644 --- a/pyslurm/utils/enums.pyx +++ b/pyslurm/utils/enums.pyx @@ -46,7 +46,7 @@ class DocstringSupport(EnumType): return cls -class SlurmEnum(str, Enum, metaclass=DocstringSupport): +class StrEnum(str, Enum, metaclass=DocstringSupport): def __new__(cls, name, *args): # https://docs.python.org/3/library/enum.html @@ -62,9 +62,6 @@ class SlurmEnum(str, Enum, metaclass=DocstringSupport): v = str(name) new_string = str.__new__(cls, v) new_string._value_ = v - - new_string._flag = int(args[0]) if len(args) >= 1 else 0 - new_string._clear_flag = int(args[1]) if len(args) >= 2 else 0 return new_string def __str__(self): @@ -73,7 +70,19 @@ class SlurmEnum(str, Enum, metaclass=DocstringSupport): @staticmethod def _generate_next_value_(name, _start, _count, _last_values): # We just care about the name of the member to be defined. - return name.upper() + return name.lower() + + +class SlurmEnum(StrEnum): + + def __new__(cls, name, *args): + v = str(name) + new_string = str.__new__(cls, v) + new_string._value_ = v + + new_string._flag = int(args[0]) if len(args) >= 1 else 0 + new_string._clear_flag = int(args[1]) if len(args) >= 2 else 0 + return new_string @classmethod def from_flag(cls, flags, default): diff --git a/tests/integration/test_assoc.py b/tests/integration/test_assoc.py index 628778a7..bb711172 100644 --- a/tests/integration/test_assoc.py +++ b/tests/integration/test_assoc.py @@ -157,8 +157,8 @@ def _test_modify_delete(user, account, conn): _modify_account(account, conn, with_kwargs=False, description="this is a new description") _modify_account(account, conn, with_kwargs=True, description="another description") - _modify_user(user, conn, with_kwargs=False, admin_level="ADMINISTRATOR") - _modify_user(user, conn, with_kwargs=True, admin_level="OPERATOR") + _modify_user(user, conn, with_kwargs=False, admin_level="administrator") + _modify_user(user, conn, with_kwargs=True, admin_level="operator") _modify_assoc(user.default_association, conn, with_kwargs=False, group_jobs=10, group_submit_jobs=20) @@ -170,7 +170,7 @@ def _test_modify_delete(user, account, conn): def _test_api(user, account, conn): - # Save them, before reloading + # Save them before reloading user_name = user.name acc_name = account.name diff --git a/tests/integration/test_db_job.py b/tests/integration/test_db_job.py index 73f82bbf..03892baf 100644 --- a/tests/integration/test_db_job.py +++ b/tests/integration/test_db_job.py @@ -94,7 +94,7 @@ def test_modify(submit_job): assert job.comment == comment # With filter via **kwargs - with pyslurm.db.connect(commit_on_success=False) as conn: + with pyslurm.db.connect(transaction_mode="manual") as conn: comment = "comment two" job = pyslurm.db.Job.load(conn, job.id) assert job.comment != comment From 3cd4d2db02a052979e0de8b3df40ecb4eeabcbb4 Mon Sep 17 00:00:00 2001 From: Toni Harzendorf Date: Sat, 28 Feb 2026 00:17:26 +0100 Subject: [PATCH 11/13] fix typo --- pyslurm/db/tres.pxd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyslurm/db/tres.pxd b/pyslurm/db/tres.pxd index 53dfe28d..0dd7a5c9 100644 --- a/pyslurm/db/tres.pxd +++ b/pyslurm/db/tres.pxd @@ -46,7 +46,7 @@ cdef _tres_ids_to_names(char *tres_str, dict tres_id_map) cdef _set_tres_limits(char **dest, src, tres_data) -cdef class TrackeblResourceAPI(ConnectionWrapper): +cdef class TrackableResourceAPI(ConnectionWrapper): pass From 4ddb5eaa30a79a1c14388f887414dabe6a8868b0 Mon Sep 17 00:00:00 2001 From: Toni Harzendorf Date: Sat, 28 Feb 2026 00:27:05 +0100 Subject: [PATCH 12/13] fix another typo --- pyslurm/db/error.pyx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyslurm/db/error.pyx b/pyslurm/db/error.pyx index 3bb84622..613d5b0d 100644 --- a/pyslurm/db/error.pyx +++ b/pyslurm/db/error.pyx @@ -96,7 +96,7 @@ def get_assoc_response(SlurmList response): yield parse_assoc_str(resp) -def parse_default_account_errors_2(SlurmList response): +def parse_default_account_errors(SlurmList response): assocs = [] for item in get_assoc_response(response): info = AssociationChangeInfo( From 1345c94bc38c1fc9575c552092242b93ab7e14ec Mon Sep 17 00:00:00 2001 From: Toni Harzendorf Date: Sun, 1 Mar 2026 13:30:46 +0100 Subject: [PATCH 13/13] wip --- pyslurm/db/account.pyx | 98 ++++++++++++++------------------------- pyslurm/db/assoc.pyx | 48 +++---------------- pyslurm/db/error.pyx | 35 +++++++++++++- pyslurm/db/user.pyx | 103 ++++++++++++----------------------------- 4 files changed, 106 insertions(+), 178 deletions(-) diff --git a/pyslurm/db/account.pyx b/pyslurm/db/account.pyx index 1d369ff2..b1966389 100644 --- a/pyslurm/db/account.pyx +++ b/pyslurm/db/account.pyx @@ -35,11 +35,7 @@ from pyslurm.utils.helpers import ( ) from pyslurm.utils.uint import * from pyslurm import xcollections -from pyslurm.db.error import ( - DefaultAccountError, - JobsRunningError, - parse_basic_response, -) +from pyslurm.db.error import handle_response from typing import Any, Union, Optional, List, Dict @@ -98,12 +94,8 @@ cdef class AccountAPI(ConnectionWrapper): self.db_conn.apply_reuse(out) return out - def delete(self, db_filter: AccountFilter): - cdef: - SlurmList response - list out = [] - + out = [] # Check is required because for some reason if the acct_cond doesn't # contain any valid conditions, slurmdbd will delete all accounts. # TODO: Maybe make it configurable @@ -116,28 +108,15 @@ cdef class AccountAPI(ConnectionWrapper): response = SlurmList.wrap(slurmdb_accounts_remove(self.db_conn.ptr, db_filter.ptr)) rc = slurm_errno() self.db_conn.check_commit(rc) + return handle_response(response, rc) - if rc == slurm.SLURM_SUCCESS or rc == slurm.SLURM_NO_CHANGE_IN_DATA: - return - -# if rc == slurm.ESLURM_ACCESS_DENIED or response.is_null: -# verify_rpc(rc) - - # Handle the error cases. - if rc == slurm.ESLURM_JOBS_RUNNING_ON_ASSOC: - raise JobsRunningError.from_response(response, rc) - elif rc == slurm.ESLURM_NO_REMOVE_DEFAULT_ACCOUNT: - raise DefaultAccountError.from_response(response, rc) - else: - verify_rpc(rc) - - - def modify(self, db_filter: AccountFilter, changes: Optional[Account] = None, **kwargs: Any): - cdef: - SlurmList response - Account _changes - SlurmListItem response_ptr - list out = [] + def modify( + self, + db_filter: AccountFilter, + changes: Optional[Account] = None, + **kwargs: Any + ): + cdef Account _changes _changes = _get_modify_arguments_for(Account, changes, **kwargs) @@ -149,33 +128,9 @@ cdef class AccountAPI(ConnectionWrapper): ) rc = slurm_errno() self.db_conn.check_commit(rc) + return handle_response(response, rc) - if rc == slurm.SLURM_SUCCESS: - return parse_basic_response(response) - elif rc == slurm.SLURM_NO_CHANGE_IN_DATA: - return out - else: - # verify_rpc(rc) - raise RPCError(msg="Failed to modify accounts.") - - -# if not response.is_null and response.cnt: -# for response_ptr in response: -# response_str = cstr.to_unicode(response_ptr.data) -# if not response_str: -# continue - -# out.append(response_str) - -# elif not response.is_null: -# # There was no real error, but simply nothing has been modified -# return out -# else: -# # Autodetects the last slurm error -# raise RPCError(msg="Failed to modify accounts.") - - - def create(self, accounts: List[str]): + def create(self, accounts: List[Account]): cdef: Account account SlurmList account_list @@ -192,13 +147,32 @@ cdef class AccountAPI(ConnectionWrapper): slurm.slurm_list_append(account_list.info, account.ptr) rc = slurmdb_accounts_add(self.db_conn.ptr, account_list.info) - # TODO: Only commit here when we don't add any associations? - # So we don't leave any Accounts without associations behind? + + # Could also solve this construct via a simple try..finally, but I just + # don't want to execute commit/rollback potentially twice, even if it + # is completely fine. + try: + if rc == slurm.SLURM_SUCCESS: + self.db_conn.associations.create(assocs_to_add) + except RPCError: + # Just re-raise - required rollback was already taken care of + raise + except Exception: + # Doing this catch-all thing might be too cautious, but just in + # case anything goes wrong before Associations were attempted to be + # added, we make sure that adding the users is also rollbacked. + # + # Because we don't want to leave Users with no associations behind + # in the system, if associations were requested to be added. + self.db_conn.check_commit(slurm.SLURM_ERROR) + raise + + # TODO: SLURM_NO_CHANGE_IN_DATA + # Should this be an error? + + # Rollback or commit in case no associations were attempted to be added self.db_conn.check_commit(rc) verify_rpc(rc) - # TODO: Maybe don't create the associations automatically? And don't do - # any hidden stuff? - self.db_conn.associations.create(assocs_to_add) cdef class Accounts(dict): diff --git a/pyslurm/db/assoc.pyx b/pyslurm/db/assoc.pyx index d92ef671..b30e8619 100644 --- a/pyslurm/db/assoc.pyx +++ b/pyslurm/db/assoc.pyx @@ -36,7 +36,7 @@ from pyslurm.utils.helpers import ( from pyslurm.utils.uint import * from pyslurm import settings from pyslurm import xcollections -from pyslurm.db.error import JobsRunningError, DefaultAccountError +from pyslurm.db.error import handle_response from typing import Any, Union, Optional, List, Dict @@ -84,12 +84,8 @@ cdef class AssociationAPI(ConnectionWrapper): self.db_conn.apply_reuse(out) return out - def delete(self, db_filter: AssociationFilter): - cdef: - SlurmList response - SlurmListItem response_ptr - + out = [] # TODO: Properly check if the filter is empty, cause it will then probably # target all assocs. Or maybe that is fine and we need to clearly document # to take caution @@ -104,28 +100,13 @@ cdef class AssociationAPI(ConnectionWrapper): ) rc = slurm_errno() self.db_conn.check_commit(rc) - - if rc == slurm.SLURM_SUCCESS or rc == slurm.SLURM_NO_CHANGE_IN_DATA: - return - - #if rc == slurm.ESLURM_ACCESS_DENIED or response.is_null: - # verify_rpc(rc) - - # Handle the error cases. - if rc == slurm.ESLURM_JOBS_RUNNING_ON_ASSOC: - raise JobsRunningError.from_response(response, rc) - elif rc == slurm.ESLURM_NO_REMOVE_DEFAULT_ACCOUNT: - raise DefaultAccountError.from_response(response, rc) - else: - verify_rpc(rc) + return handle_response(response, rc) def modify(self, db_filter: AssociationFilter, changes: Optional[Association] = None, **kwargs: Any): cdef: Association _changes - SlurmList response - SlurmListItem response_ptr - list out = [] + out = [] # TODO: prohibit mixing multiple user assocs with account assocs # This is not possible, and the request will simply affect nothing... @@ -143,24 +124,7 @@ cdef class AssociationAPI(ConnectionWrapper): self.db_conn.ptr, db_filter.ptr, _changes.ptr)) rc = slurm_errno() self.db_conn.check_commit(rc) - - if not response.is_null and response.cnt: - for response_ptr in response: - response_str = cstr.to_unicode(response_ptr.data) - if not response_str: - continue - - # TODO: Better format - out.append(response_str) - - elif not response.is_null: - # There was no real error, but simply nothing has been modified - return None - else: - # Autodetects the last slurm error - raise RPCError() - - return out + return handle_response(response, rc) def create(self, associations: List[Association]): cdef: @@ -180,6 +144,8 @@ cdef class AssociationAPI(ConnectionWrapper): assoc_list.append(assoc) rc = slurmdb_associations_add(self.db_conn.ptr, assoc_list.info) + # TODO: SLURM_NO_CHANGE_IN_DATA + # Should this be an error? self.db_conn.check_commit(rc) verify_rpc(rc) diff --git a/pyslurm/db/error.pyx b/pyslurm/db/error.pyx index 613d5b0d..34d83ea9 100644 --- a/pyslurm/db/error.pyx +++ b/pyslurm/db/error.pyx @@ -24,6 +24,7 @@ from pyslurm.core.error import RPCError, slurm_errno, verify_rpc from pyslurm.db.util cimport SlurmList, SlurmListItem +from pyslurm cimport slurm import re @@ -36,6 +37,39 @@ import re assoc_str_pattern = re.compile(r'(\w)\s*=\s*(\w+)') +def handle_response(SlurmList response, rc): + out = [] + if not response.is_null and response.cnt: + if rc == slurm.ESLURM_JOBS_RUNNING_ON_ASSOC: + # The slurmdbd actually deletes the associations, even if + # Jobs are running. The client side must then decide whether to + # rollback or actually commit the changes. sacctmgr does the + # rollback. + + # By default, any errors are automatically rollbacked, for safety. + # User can disable this behaviour. + raise JobsRunningError.from_response(response, rc) + elif rc == slurm.ESLURM_NO_REMOVE_DEFAULT_ACCOUNT: + raise DefaultAccountError.from_response(response, rc) + + out = parse_basic_response(response) + elif not response.is_null or rc == slurm.SLURM_NO_CHANGE_IN_DATA: + # Nothing was modified + # TODO: Should this be an error actually? + pass + elif rc == slurm.ESLURM_INVALID_PARENT_ACCOUNT: + # When modifying Associations failed. + # TODO: proper error + verify_rpc(rc) + else: + # ESLURM_ONE_CHANGE - may happen when name of a user is attempted to be + # changed. only 1 user can be specified at a time. Could also detect + # that earlier + verify_rpc(rc) + + return out + + class AssociationChangeInfo: def __init__(self, user, cluster, account, partition=None): @@ -76,7 +110,6 @@ class DefaultAccountError(RPCError): def get_responses(SlurmList response): cdef SlurmListItem response_ptr - #TODO: check also for count? if response.is_null: return [] diff --git a/pyslurm/db/user.pyx b/pyslurm/db/user.pyx index d851dea5..dae61424 100644 --- a/pyslurm/db/user.pyx +++ b/pyslurm/db/user.pyx @@ -36,7 +36,7 @@ from pyslurm.utils.helpers import ( from pyslurm.utils.uint import * from pyslurm import xcollections from pyslurm.utils.enums import SlurmEnum -from pyslurm.db.error import JobsRunningError, parse_basic_response +from pyslurm.db.error import handle_response from pyslurm.enums import AdminLevel from typing import Any, Union, Optional, List, Dict @@ -46,15 +46,9 @@ cdef class UserAPI(ConnectionWrapper): def load(self, db_filter: Optional[UserFilter] = None): cdef: Users out = Users() - User user UserFilter cond = db_filter - SlurmList user_data SlurmListItem user_ptr - SlurmList assoc_data SlurmListItem assoc_ptr - Association assoc - QualitiesOfService qos_data - TrackableResources tres_data self.db_conn.validate() @@ -95,10 +89,8 @@ cdef class UserAPI(ConnectionWrapper): self.db_conn.apply_reuse(out) return out - def delete(self, db_filter: UserFilter): - cdef: - SlurmList response + out = [] # TODO: test again when this is empty, does it really delete everything? if not db_filter.names: @@ -110,36 +102,12 @@ cdef class UserAPI(ConnectionWrapper): response = SlurmList.wrap(slurmdb_users_remove(self.db_conn.ptr, db_filter.ptr)) rc = slurm_errno() self.db_conn.check_commit(rc) - - if rc == slurm.SLURM_SUCCESS or rc == slurm.SLURM_NO_CHANGE_IN_DATA: - return - - #if rc == slurm.ESLURM_ACCESS_DENIED or response.is_null: - # verify_rpc(rc) - - # Handle the error case. Running Jobs should be the only possible error - # where slurmdbd sends a response list. - if rc == slurm.ESLURM_JOBS_RUNNING_ON_ASSOC: - # TODO: The slurmdbd actually deletes the associations, even if - # Jobs are running. The client side must then decide whether to - # rollback or actually commit the changes. sacctmgr does the - # rollback. - - # Should we also do this here automatically to prevent - # anyone accidentally forgetting this? Or let the caller handle it? - # If we do it, then it might rollback changes that were done - # earlier and haven't been committed yet. - raise JobsRunningError.from_response(response, rc) - else: - verify_rpc(rc) - + return handle_response(response, rc) def modify(self, db_filter: UserFilter, changes: Optional[User] = None, **kwargs: Any): cdef: - SlurmList response User _changes SlurmListItem response_ptr - list out = [] # TODO: Properly check if the filter is empty, cause it will then probably # target all users. Or maybe that is fine and we need to clearly document @@ -158,41 +126,9 @@ cdef class UserAPI(ConnectionWrapper): rc = slurm_errno() self.db_conn.check_commit(rc) - if rc == slurm.SLURM_SUCCESS: - return parse_basic_response(response) - elif rc == slurm.SLURM_NO_CHANGE_IN_DATA: - return out - else: - # verify_rpc(rc) - # ESLURM_ONE_CHANGE - when the name is changed, only 1 user can be - # specified at a time + return handle_response(response, rc) - # SLURM_ERROR - general error - raise RPCError(msg="Failed to modify users.") - -# if not response.is_null and response.cnt: -# for response_ptr in response: -# response_str = cstr.to_unicode(response_ptr.data) -# if not response_str: -# continue - -# out.append(response_str) - -# elif not response.is_null: -# # There was no real error, but simply nothing has been modified -# return out -# else: -# # TODO: handle errors better -# # ESLURM_NO_CHANGE_IN_DATA - -# # ESLURM_ONE_CHANGE - when the name is changed, only 1 user can be -# # specified at a time - -# # Autodetects the last slurm error -# raise RPCError(msg="Failed to modify users.") - - - def create(self, users: List[str]): + def create(self, users: List[User]): cdef: User user SlurmList user_list @@ -232,13 +168,32 @@ cdef class UserAPI(ConnectionWrapper): slurm.slurm_list_append(user_list.info, user.ptr) rc = slurmdb_users_add(self.db_conn.ptr, user_list.info) - # TODO: Only commit here when we don't add any associations? - # So we don't leave any Users without associations behind? + + # Could also solve this construct via a simple try..finally, but I just + # don't want to execute commit/rollback potentially twice, even if it + # is completely fine. + try: + if rc == slurm.SLURM_SUCCESS: + self.db_conn.associations.create(assocs_to_add) + except RPCError: + # Just re-raise - required rollback was already taken care of + raise + except Exception: + # Doing this catch-all thing might be too cautious, but just in + # case anything goes wrong before Associations were attempted to be + # added, we make sure that adding the users is also rollbacked. + # + # Because we don't want to leave Users with no associations behind + # in the system, if associations were requested to be added. + self.db_conn.check_commit(slurm.SLURM_ERROR) + raise + + # TODO: SLURM_NO_CHANGE_IN_DATA + # Should this be an error? + + # Rollback or commit in case no associations were attempted to be added self.db_conn.check_commit(rc) verify_rpc(rc) - # TODO: Maybe don't create the associations automatically? And don't do - # any hidden stuff? - self.db_conn.associations.create(assocs_to_add) cdef class Users(dict):