3
3
from diffgram import __version__
4
4
5
5
from diffgram .file .view import get_label_file_dict
6
- from diffgram .core .directory import get_directory_list
7
- from diffgram .core .directory import set_directory_by_name
8
6
from diffgram .convert .convert import convert_label
9
7
from diffgram .label .label_new import label_new
10
8
@@ -29,7 +27,10 @@ def __init__(
29
27
client_secret = None ,
30
28
debug = False ,
31
29
staging = False ,
32
- host = None
30
+ host = None ,
31
+ set_default_directory = True ,
32
+ refresh_local_label_dict = True
33
+
33
34
):
34
35
35
36
self .session = requests .Session ()
@@ -50,24 +51,41 @@ def __init__(
50
51
self .host = "https://diffgram.com"
51
52
else :
52
53
self .host = host
53
- self .directory_id = None
54
- self .name_to_file_id = None
54
+
55
55
self .auth (
56
56
project_string_id = project_string_id ,
57
57
client_id = client_id ,
58
58
client_secret = client_secret )
59
- self .client_id = client_id
60
- self .client_secret = client_secret
61
59
62
60
self .file = FileConstructor (self )
63
- self .train = Train (self )
61
+ # self.train = Train(self)
64
62
self .job = Job (self )
65
63
self .guide = Guide (self )
66
- self .directory = Directory (self , validate_ids = False )
64
+ self .directory = Directory (self ,
65
+ init_file_ids = False ,
66
+ validate_ids = False )
67
67
self .export = Export (self )
68
68
self .task = Task (client = self )
69
+
70
+ self .directory_id = None
71
+ self .name_to_file_id = None
72
+
73
+
74
+ if set_default_directory is True :
75
+ self .set_default_directory ()
76
+ print ("Default directory set:" , self .directory_id )
77
+
78
+ if refresh_local_label_dict is True :
79
+ self .get_label_file_dict ()
80
+
81
+ self .client_id = client_id
82
+ self .client_secret = client_secret
83
+
69
84
self .label_schema_list = self .get_label_schema_list ()
70
85
86
+ self .directory_list = []
87
+
88
+
71
89
def get_member_list (self ):
72
90
url = '/api/project/{}/view' .format (self .project_string_id )
73
91
response = self .session .get (url = self .host + url )
@@ -216,9 +234,7 @@ def handle_errors(self,
216
234
def auth (self ,
217
235
project_string_id ,
218
236
client_id = None ,
219
- client_secret = None ,
220
- set_default_directory = True ,
221
- refresh_local_label_dict = True
237
+ client_secret = None
222
238
):
223
239
"""
224
240
Define authorization configuration
@@ -242,55 +258,65 @@ def auth(self,
242
258
if client_id and client_secret :
243
259
self .session .auth = (client_id , client_secret )
244
260
245
- if set_default_directory is True :
246
- self .set_default_directory ()
247
261
248
- if refresh_local_label_dict is True :
249
- # Refresh local labels from Diffgram project
250
- self .get_label_file_dict ()
262
+ def set_directory_by_name (self , name ):
263
+ """
264
+
265
+ Arguments
266
+ self
267
+ name, string
268
+
269
+ """
270
+
271
+ if name is None :
272
+ raise Exception ("No name provided." )
273
+
274
+ # Don't refresh by default, just set from existing
275
+
276
+ names_attempted = []
277
+ did_set = False
278
+
279
+ if not self .directory_list :
280
+ self .directory_list = self .directory .get_directory_list ()
281
+
282
+ for directory in self .directory_list :
283
+
284
+ if directory .nickname == name :
285
+ self .set_default_directory (directory = directory )
286
+ did_set = True
287
+ break
288
+ else :
289
+ names_attempted .append (directory .nickname )
290
+
291
+ if did_set is False :
292
+ raise Exception (name , " does not exist. Valid names are: " +
293
+ str (names_attempted ))
294
+
251
295
252
296
def set_default_directory (self ,
253
- directory_id = None ):
297
+ directory_id = None ,
298
+ directory = None ):
254
299
"""
255
300
-> If no id is provided fetch directory list for project
256
301
and set first directory to default.
257
302
-> Sets the headers of self.session
258
303
259
- Arguments
260
- directory_id, int, defaults to None
261
-
262
- Returns
263
- None
264
-
265
- Future
266
- TODO return error if invalid directory?
267
-
268
304
"""
269
305
270
306
if directory_id :
271
- # TODO check if valid?
272
- # data = {}
273
- # data["directory_id"] = directory_id
274
307
self .directory_id = directory_id
275
- else :
276
-
277
- data = self .get_directory_list ()
278
-
279
- self .default_directory = data ['default_directory' ]
280
-
281
- # Hold over till refactoring (would prefer to
282
- # just call self.directory_default.id
283
- self .directory_id = self .default_directory ['id' ]
308
+ if directory :
309
+ self .directory_id = directory .id
310
+ self .default_directory = directory
311
+
312
+ self .directory_list = self .directory .get_directory_list ()
284
313
285
- self .directory_list = data ["directory_list" ]
286
314
self .session .headers .update (
287
315
{'directory_id' : str (self .directory_id )})
288
316
289
317
290
318
# TODO review not using this pattern anymore
291
319
292
320
setattr (Project , "get_label_file_dict" , get_label_file_dict )
293
- setattr (Project , "get_directory_list" , get_directory_list )
294
321
setattr (Project , "convert_label" , convert_label )
295
322
setattr (Project , "label_new" , label_new )
296
- setattr (Project , "set_directory_by_name" , set_directory_by_name )
0 commit comments