7
7
import inspect
8
8
import itertools
9
9
from types import MethodType
10
+ from typing import List
10
11
11
12
import numpy as np
12
13
@@ -96,6 +97,31 @@ def filter_attributes(ctx, f, **kwargs):
96
97
_process_exclusion (ctx , cls_attrs , kwargs ["exclude" ], f )
97
98
98
99
100
+ def validate_data_types (
101
+ prohibited_data_types : List [str ], reserved_words = ["collaborators" ], ** kwargs
102
+ ):
103
+ """Validates that the types of attributes in kwargs are not among the prohibited data types.
104
+ Raises a TypeError if any prohibited data type is found.
105
+
106
+ Args:
107
+ prohibited_data_types (List[str]): A list of prohibited data type names
108
+ (e.g., ['int', 'float']).
109
+ kwargs (dict): Arbitrary keyword arguments representing attribute names and their values.
110
+
111
+ Raises:
112
+ TypeError: If any prohibited data types are found in kwargs.
113
+ ValueError: If prohibited_data_types is empty.
114
+ """
115
+ if not prohibited_data_types :
116
+ raise ValueError ("prohibited_data_types must not be empty." )
117
+ for attr_name , attr_value in kwargs .items ():
118
+ if type (attr_value ).__name__ in prohibited_data_types and attr_value not in reserved_words :
119
+ raise TypeError (
120
+ f"The attribute '{ attr_name } ' = '{ attr_value } ' has a prohibited value type: "
121
+ f"{ type (attr_value ).__name__ } "
122
+ )
123
+
124
+
99
125
def _validate_include_exclude (kwargs , cls_attrs ):
100
126
"""Validates that 'include' and 'exclude' are not both present, and that
101
127
attributes in 'include' or 'exclude' exist in the context.
@@ -152,13 +178,13 @@ def _process_exclusion(ctx, cls_attrs, exclude_list, f):
152
178
delattr (ctx , attr )
153
179
154
180
155
- def checkpoint (ctx , parent_func , chkpnt_reserved_words = ["next" , "runtime" ]):
181
+ def checkpoint (ctx , parent_func , checkpoint_reserved_words = ["next" , "runtime" ]):
156
182
"""Optionally saves the current state for the task just executed.
157
183
158
184
Args:
159
185
ctx (any): The context to checkpoint.
160
186
parent_func (function): The function that was just executed.
161
- chkpnt_reserved_words (list, optional): A list of reserved words to
187
+ checkpoint_reserved_words (list, optional): A list of reserved words to
162
188
exclude from checkpointing. Defaults to ["next", "runtime"].
163
189
164
190
Returns:
@@ -173,7 +199,7 @@ def checkpoint(ctx, parent_func, chkpnt_reserved_words=["next", "runtime"]):
173
199
if ctx ._checkpoint :
174
200
# all objects will be serialized using Metaflow interface
175
201
print (f"Saving data artifacts for { parent_func .__name__ } " )
176
- artifacts_iter , _ = generate_artifacts (ctx = ctx , reserved_words = chkpnt_reserved_words )
202
+ artifacts_iter , _ = generate_artifacts (ctx = ctx , reserved_words = checkpoint_reserved_words )
177
203
task_id = ctx ._metaflow_interface .create_task (parent_func .__name__ )
178
204
ctx ._metaflow_interface .save_artifacts (
179
205
artifacts_iter (),
@@ -188,15 +214,15 @@ def checkpoint(ctx, parent_func, chkpnt_reserved_words=["next", "runtime"]):
188
214
189
215
def old_check_resource_allocation (num_gpus , each_participant_gpu_usage ):
190
216
remaining_gpu_memory = {}
191
- # TODO for each GPU the funtion tries see if all participant usages fit
217
+ # TODO for each GPU the function tries see if all participant usages fit
192
218
# into a GPU, it it doesn't it removes that participant from the
193
219
# participant list, and adds it to the remaining_gpu_memory dict. So any
194
220
# sum of GPU requirements above 1 triggers this.
195
- # But at this point the funtion will raise an error because
221
+ # But at this point the function will raise an error because
196
222
# remaining_gpu_memory is never cleared.
197
223
# The participant list should remove the participant if it fits in the gpu
198
- # and save the partipant if it doesn't and continue to the next GPU to see
199
- # if it fits in that one, only if we run out of GPUs should this funtion
224
+ # and save the participant if it doesn't and continue to the next GPU to see
225
+ # if it fits in that one, only if we run out of GPUs should this function
200
226
# raise an error.
201
227
for gpu in np .ones (num_gpus , dtype = int ):
202
228
for i , (participant_name , participant_gpu_usage ) in enumerate (
@@ -230,7 +256,7 @@ def check_resource_allocation(num_gpus, each_participant_gpu_usage):
230
256
if gpu == 0 :
231
257
break
232
258
if gpu < participant_gpu_usage :
233
- # participant doesn't fitm break to next GPU
259
+ # participant doesn't fit; break to next GPU
234
260
break
235
261
else :
236
262
# if participant fits remove from need_assigned
0 commit comments