11# This is an integration test for the immutability functionality
22from __future__ import annotations
3- import json
4- import os
5- import pickle
6- import sys
7- import time
8- import paderbox
9- from paderbox .io .cache import url_to_local_path
10- from collections import defaultdict
11- from typing import Any
12- import lazy_dataset
13- import numpy as np
14- import psutil
15- import torch
16- from tabulate import tabulate
17-
18-
19- # Download from https://huggingface.co/datasets/merve/coco/resolve/main/annotations/instances_train2017.json
20- def create_coco () -> list [Any ]:
21- json_path = url_to_local_path ("https://huggingface.co/datasets/merve/coco/resolve/main/annotations/instances_train2017.json" )
22- with open (json_path ) as f :
23- obj = json .load (f )
24- return obj ["annotations" ]
25-
26-
27-
28- def get_mem_info (pid : int ) -> dict [str , int ]:
29- res = defaultdict (int )
30- for mmap in psutil .Process (pid ).memory_maps ():
31- res ['rss' ] += mmap .rss
32- res ['pss' ] += mmap .pss
33- res ['uss' ] += mmap .private_clean + mmap .private_dirty
34- res ['shared' ] += mmap .shared_clean + mmap .shared_dirty
35- if mmap .path .startswith ('/' ): # looks like a file path
36- res ['shared_file' ] += mmap .shared_clean + mmap .shared_dirty
37- return res
38-
39-
40- class MemoryMonitor ():
41- """Class used to monitor the memory usage of processes"""
42-
43- def __init__ (self , pids : list [int ] = None ):
44- if pids is None :
45- pids = [os .getpid ()]
46- self .pids = pids
47-
48- def add_pid (self , pid : int ):
49- assert pid not in self .pids
50- self .pids .append (pid )
51-
52- def _refresh (self ):
53- self .data = {pid : get_mem_info (pid ) for pid in self .pids }
54- return self .data
55-
56- def table (self ) -> str :
57- self ._refresh ()
58- table = []
59- keys = list (list (self .data .values ())[0 ].keys ())
60- now = str (int (time .perf_counter () % 1e5 ))
61- for pid , data in self .data .items ():
62- table .append ((now , str (pid )) + tuple (self .format (data [k ]) for k in keys ))
63- return tabulate (table , headers = ["time" , "PID" ] + keys )
64-
65- def str (self ):
66- self ._refresh ()
67- keys = list (list (self .data .values ())[0 ].keys ())
68- res = []
69- for pid in self .pids :
70- s = f"PID={ pid } "
71- for k in keys :
72- v = self .format (self .data [pid ][k ])
73- s += f", { k } ={ v } "
74- res .append (s )
75- return "\n " .join (res )
76-
77- @staticmethod
78- def format (size : int ) -> str :
79- for unit in ('' , 'K' , 'M' , 'G' ):
80- if size < 1024 :
81- break
82- size /= 1024.0
83- return "%.1f%s" % (size , unit )
84-
85-
86- def read_sample (x ):
87- """
88- A function that is supposed to read object x, incrementing its refcount.
89- This mimics what a real dataloader would do."""
90- if sys .version_info >= (3 , 10 , 6 ):
91- """Before this version, pickle does not increment refcount. This is a bug that's
92- fixed in https://github.com/python/cpython/pull/92931. """
93- return pickle .dumps (x )
94- else :
95- import msgpack
96- return msgpack .dumps (x )
97-
98-
99- class DatasetFromList (torch .utils .data .Dataset ):
100- def __init__ (self , lst ):
101- self .lst = lst
102-
103- def __len__ (self ):
104- return len (self .lst )
105-
106- def __getitem__ (self , idx : int ):
107- return self .lst [idx ]
108-
109-
110- def worker (_ , dataset : torch .utils .data .Dataset ):
111- while True :
112- for sample in dataset :
113- # read the data, with a fake latency
114- time .sleep (0.000001 )
115- result = read_sample (sample )
116-
1173
1184if __name__ == "__main__" :
5+ import json
6+ import os
7+ import pickle
8+ import sys
9+ import time
10+ import paderbox
11+ from paderbox .io .cache import url_to_local_path
12+ from collections import defaultdict
13+ from typing import Any
14+ import lazy_dataset
15+ import numpy as np
16+ import psutil
17+ from tabulate import tabulate
11918 import matplotlib .pyplot as plt
19+ import torch
20+
21+ # Download from https://huggingface.co/datasets/merve/coco/resolve/main/annotations/instances_train2017.json
22+ def create_coco () -> list [Any ]:
23+ json_path = url_to_local_path ("https://huggingface.co/datasets/merve/coco/resolve/main/annotations/instances_train2017.json" )
24+ with open (json_path ) as f :
25+ obj = json .load (f )
26+ return obj ["annotations" ]
27+
28+
29+
30+ def get_mem_info (pid : int ) -> dict [str , int ]:
31+ res = defaultdict (int )
32+ for mmap in psutil .Process (pid ).memory_maps ():
33+ res ['rss' ] += mmap .rss
34+ res ['pss' ] += mmap .pss
35+ res ['uss' ] += mmap .private_clean + mmap .private_dirty
36+ res ['shared' ] += mmap .shared_clean + mmap .shared_dirty
37+ if mmap .path .startswith ('/' ): # looks like a file path
38+ res ['shared_file' ] += mmap .shared_clean + mmap .shared_dirty
39+ return res
40+
41+
42+ class MemoryMonitor ():
43+ """Class used to monitor the memory usage of processes"""
44+
45+ def __init__ (self , pids : list [int ] = None ):
46+ if pids is None :
47+ pids = [os .getpid ()]
48+ self .pids = pids
49+
50+ def add_pid (self , pid : int ):
51+ assert pid not in self .pids
52+ self .pids .append (pid )
53+
54+ def _refresh (self ):
55+ self .data = {pid : get_mem_info (pid ) for pid in self .pids }
56+ return self .data
57+
58+ def table (self ) -> str :
59+ self ._refresh ()
60+ table = []
61+ keys = list (list (self .data .values ())[0 ].keys ())
62+ now = str (int (time .perf_counter () % 1e5 ))
63+ for pid , data in self .data .items ():
64+ table .append ((now , str (pid )) + tuple (self .format (data [k ]) for k in keys ))
65+ return tabulate (table , headers = ["time" , "PID" ] + keys )
66+
67+ def str (self ):
68+ self ._refresh ()
69+ keys = list (list (self .data .values ())[0 ].keys ())
70+ res = []
71+ for pid in self .pids :
72+ s = f"PID={ pid } "
73+ for k in keys :
74+ v = self .format (self .data [pid ][k ])
75+ s += f", { k } ={ v } "
76+ res .append (s )
77+ return "\n " .join (res )
78+
79+ @staticmethod
80+ def format (size : int ) -> str :
81+ for unit in ('' , 'K' , 'M' , 'G' ):
82+ if size < 1024 :
83+ break
84+ size /= 1024.0
85+ return "%.1f%s" % (size , unit )
86+
87+
88+ def read_sample (x ):
89+ """
90+ A function that is supposed to read object x, incrementing its refcount.
91+ This mimics what a real dataloader would do."""
92+ if sys .version_info >= (3 , 10 , 6 ):
93+ """Before this version, pickle does not increment refcount. This is a bug that's
94+ fixed in https://github.com/python/cpython/pull/92931. """
95+ return pickle .dumps (x )
96+ else :
97+ import msgpack
98+ return msgpack .dumps (x )
99+
100+
101+ class DatasetFromList (torch .utils .data .Dataset ):
102+
103+ def __init__ (self , lst ):
104+ self .lst = lst
105+
106+ def __len__ (self ):
107+ return len (self .lst )
108+
109+ def __getitem__ (self , idx : int ):
110+ return self .lst [idx ]
111+
112+
113+ def worker (_ , dataset : torch .utils .data .Dataset ):
114+ while True :
115+ for sample in dataset :
116+ # read the data, with a fake latency
117+ time .sleep (0.000001 )
118+ result = read_sample (sample )
120119 monitor = MemoryMonitor ()
121- immutable_warranty = "pickle " # copy pickle wu
120+ immutable_warranty = "wu " # copy pickle wu
122121 ds = lazy_dataset .new (create_coco (), immutable_warranty = immutable_warranty )
123122 print (monitor .table ())
124123
@@ -144,7 +143,7 @@ def worker(_, dataset: torch.utils.data.Dataset):
144143 axis .set_xlabel ("Times (s)" )
145144 axis .legend ()
146145 axis .set_ylabel ("Memory usage (MB)" )
147- # plt.savefig(f"/net/vol/deegen/SHK/Lazy_dataset_test/{immutable_warranty}.svg", format="svg")#, dpi=600)
148- plt .show ()
146+ plt .savefig (f"/net/vol/deegen/SHK/Lazy_dataset_test/{ immutable_warranty } .svg" , format = "svg" )#, dpi=600)
147+ # plt.show()
149148 finally :
150149 ctx .join ()
0 commit comments