7
7
import inspect
8
8
import itertools
9
9
from types import MethodType
10
+ from typing import List
10
11
11
12
import numpy as np
12
13
@@ -96,6 +97,54 @@ def filter_attributes(ctx, f, **kwargs):
96
97
_process_exclusion (ctx , cls_attrs , kwargs ["exclude" ], f )
97
98
98
99
100
+ def validate_data_types (
101
+ prohibited_data_types : List [str ] = None ,
102
+ allowed_data_types : List [str ] = None ,
103
+ reserved_words = ["collaborators" ],
104
+ ** kwargs ,
105
+ ):
106
+ """Validates that the types of attributes in kwargs are not among the prohibited data types
107
+ and are among the allowed data types if specified.
108
+ Raises a TypeError if any prohibited data type is found or if a type is not allowed.
109
+
110
+ Args:
111
+ prohibited_data_types (List[str], optional): A list of prohibited data type names
112
+ (e.g., ['int', 'float']).
113
+ allowed_data_types (List[str], optional): A list of allowed data type names.
114
+ reserved_words: A list of strings that should be allowed as attribute values, even if 'str'
115
+ is included in prohibited_data_types.
116
+ kwargs (dict): Arbitrary keyword arguments representing attribute names and their values.
117
+
118
+ Raises:
119
+ TypeError: If any prohibited data types are found in kwargs or if a type is not allowed.
120
+ ValueError: If both prohibited_data_types and allowed_data_types are set simultaneously.
121
+ """
122
+ if prohibited_data_types is None :
123
+ prohibited_data_types = []
124
+ if allowed_data_types is None :
125
+ allowed_data_types = []
126
+
127
+ if prohibited_data_types and allowed_data_types :
128
+ raise ValueError ("Cannot set both 'prohibited_data_types' and 'allowed_data_types'." )
129
+
130
+ for attr_name , attr_value in kwargs .items ():
131
+ attr_type = type (attr_value ).__name__
132
+ if (
133
+ prohibited_data_types
134
+ and attr_type in prohibited_data_types
135
+ and attr_value not in reserved_words
136
+ ):
137
+ raise TypeError (
138
+ f"The attribute '{ attr_name } ' = '{ attr_value } ' "
139
+ f"has a prohibited value type: { attr_type } "
140
+ )
141
+ if allowed_data_types and attr_type not in allowed_data_types :
142
+ raise TypeError (
143
+ f"The attribute '{ attr_name } ' = '{ attr_value } ' "
144
+ f"has a type that is not allowed: { attr_type } "
145
+ )
146
+
147
+
99
148
def _validate_include_exclude (kwargs , cls_attrs ):
100
149
"""Validates that 'include' and 'exclude' are not both present, and that
101
150
attributes in 'include' or 'exclude' exist in the context.
@@ -152,13 +201,13 @@ def _process_exclusion(ctx, cls_attrs, exclude_list, f):
152
201
delattr (ctx , attr )
153
202
154
203
155
- def checkpoint (ctx , parent_func , chkpnt_reserved_words = ["next" , "runtime" ]):
204
+ def checkpoint (ctx , parent_func , checkpoint_reserved_words = ["next" , "runtime" ]):
156
205
"""Optionally saves the current state for the task just executed.
157
206
158
207
Args:
159
208
ctx (any): The context to checkpoint.
160
209
parent_func (function): The function that was just executed.
161
- chkpnt_reserved_words (list, optional): A list of reserved words to
210
+ checkpoint_reserved_words (list, optional): A list of reserved words to
162
211
exclude from checkpointing. Defaults to ["next", "runtime"].
163
212
164
213
Returns:
@@ -173,7 +222,7 @@ def checkpoint(ctx, parent_func, chkpnt_reserved_words=["next", "runtime"]):
173
222
if ctx ._checkpoint :
174
223
# all objects will be serialized using Metaflow interface
175
224
print (f"Saving data artifacts for { parent_func .__name__ } " )
176
- artifacts_iter , _ = generate_artifacts (ctx = ctx , reserved_words = chkpnt_reserved_words )
225
+ artifacts_iter , _ = generate_artifacts (ctx = ctx , reserved_words = checkpoint_reserved_words )
177
226
task_id = ctx ._metaflow_interface .create_task (parent_func .__name__ )
178
227
ctx ._metaflow_interface .save_artifacts (
179
228
artifacts_iter (),
@@ -195,7 +244,7 @@ def old_check_resource_allocation(num_gpus, each_participant_gpu_usage):
195
244
# But at this point the function will raise an error because
196
245
# remaining_gpu_memory is never cleared.
197
246
# The participant list should remove the participant if it fits in the gpu
198
- # and save the partipant if it doesn't and continue to the next GPU to see
247
+ # and save the participant if it doesn't and continue to the next GPU to see
199
248
# if it fits in that one, only if we run out of GPUs should this function
200
249
# raise an error.
201
250
for gpu in np .ones (num_gpus , dtype = int ):
@@ -230,7 +279,7 @@ def check_resource_allocation(num_gpus, each_participant_gpu_usage):
230
279
if gpu == 0 :
231
280
break
232
281
if gpu < participant_gpu_usage :
233
- # participant doesn't fitm break to next GPU
282
+ # participant doesn't fit, break to next GPU
234
283
break
235
284
else :
236
285
# if participant fits remove from need_assigned
0 commit comments