1717import os
1818import stat
1919import sys
20- from multiprocessing import Process , Queue
20+ from multiprocessing import Pool , Process , Queue , current_process
2121from pathlib import Path
2222
2323from torch .utils .data import Dataset as TorchDataset
@@ -69,7 +69,8 @@ class Dataset(TorchDataset):
6969 Number of directory entries to read for each readdir call.
7070 dir_cache_size: int (optional)
7171 Number of directory object entries to cache in memory.
72-
72+ readdir_workers: int (optional)
73+ Number of parallel workers for namespace scanning.
7374
7475 Methods
7576 -------
@@ -92,7 +93,8 @@ class Dataset(TorchDataset):
9293 def __init__ (self , pool = None , cont = None , path = None ,
9394 transform_fn = transform_fn_default ,
9495 readdir_batch_size = READDIR_BATCH_SIZE ,
95- dir_cache_size = DIR_CACHE_SIZE ):
96+ dir_cache_size = DIR_CACHE_SIZE ,
97+ readdir_workers = PARALLEL_SCAN_WORKERS ):
9698 super ().__init__ ()
9799
98100 self ._pool = pool
@@ -102,7 +104,8 @@ def __init__(self, pool=None, cont=None, path=None,
102104 self ._readdir_batch_size = readdir_batch_size
103105 self ._closed = False
104106
105- self .objects = self ._dfs .parallel_list (path , readdir_batch_size = self ._readdir_batch_size )
107+ self .objects = self ._dfs .parallel_list (
108+ path , readdir_batch_size = self ._readdir_batch_size , workers = readdir_workers )
106109
107110 def __len__ (self ):
108111 """ Returns number of items in this dataset """
@@ -216,6 +219,8 @@ class IterableDataset(TorchIterableDataset):
216219 Number of samples to fetch per iteration.
217220 dir_cache_size: int (optional)
218221 Number of directory object entries to cache in memory.
222+ readdir_workers: int (optional)
223+ Number of parallel workers for namespace scanning.
219224
220225
221226 Methods
@@ -233,7 +238,8 @@ def __init__(self, pool=None, cont=None, path=None,
233238 transform_fn = transform_fn_default ,
234239 readdir_batch_size = READDIR_BATCH_SIZE ,
235240 batch_size = ITER_BATCH_SIZE ,
236- dir_cache_size = DIR_CACHE_SIZE ):
241+ dir_cache_size = DIR_CACHE_SIZE ,
242+ readdir_workers = PARALLEL_SCAN_WORKERS ):
237243 super ().__init__ ()
238244
239245 self ._pool = pool
@@ -244,7 +250,8 @@ def __init__(self, pool=None, cont=None, path=None,
244250 self ._batch_size = batch_size
245251 self ._closed = False
246252
247- self .objects = self ._dfs .parallel_list (path , readdir_batch_size = self ._readdir_batch_size )
253+ self .objects = self ._dfs .parallel_list (
254+ path , readdir_batch_size = self ._readdir_batch_size , workers = readdir_workers )
248255 self .workset = self .objects
249256
250257 def __iter__ (self ):
@@ -646,6 +653,35 @@ def writer(self, file, ensure_path=True):
646653 self ._chunks_limit , self ._workers )
647654
648655
656+ def _readdir_worker_init (dfs , readdir_batch_size ):
657+ """
658+ Worker init for parallel readdir.
659+
660+ Receives `self` as an argument to re-init DAOS after fork, per worker process.
661+
662+ It has to be module function since the multiprocessing.Pool methods to init workers
663+ will pickle instance method with main process's _Dfs class reference.
664+ """
665+
666+ dfs .worker_init ()
667+ proc = current_process ()
668+ proc .dfs = dfs
669+ proc .readdir_batch_size = readdir_batch_size
670+
671+
672+ def _readdir_batch (work ):
673+ """
674+ Reads the anchored directory at `path` with `anchor_index` and returns
675+ list of discovered directories and files.
676+
677+ It has to be module function since the multiprocessing.Pool methods to submit jobs
678+ will pickle instance method with main process's _Dfs class reference.
679+ """
680+ path , anchor_index = work
681+ proc = current_process ()
682+ return proc .dfs .readdir_anchored (path , anchor_index , proc .readdir_batch_size )
683+
684+
649685class _Dfs ():
650686 """
651687 Class encapsulating libdfs interface to load PyTorch Dataset
@@ -676,49 +712,10 @@ def disconnect(self):
676712 raise OSError (ret , os .strerror (ret ))
677713 self ._dfs = None
678714
679- def list_worker_fn (self , in_work , out_dirs , out_files , readdir_batch_size = READDIR_BATCH_SIZE ):
680- """
681- Worker function to scan directory in parallel.
682- It expects to receive tuples (path, index) to scan the directory with an anchor index,
683- from the `in_work` queue.
684- It should emit tuples (scanned, to_scan) to the `out_dirs` queue, where `scanned` is the
685- number of scanned directories and `to_scan` is the list of directories to scan in parallel.
686- Upon completion it should emit the list of files in the `out_files` queue.
687- """
688-
689- self .worker_init ()
690-
691- result = []
692- while True :
693- work = in_work .get ()
694- if work is None :
695- break
696-
697- (path , index ) = work
698-
699- dirs = []
700- files = []
701- ret = torch_shim .torch_list_with_anchor (DAOS_MAGIC , self ._dfs ,
702- path , index , files , dirs , readdir_batch_size
703- )
704- if ret != 0 :
705- raise OSError (ret , os .strerror (ret ), path )
706-
707- dirs = [chunk for d in dirs for chunk in self .split_dir_for_parallel_scan (
708- os .path .join (path , d ))
709- ]
710- # Even if there are no dirs, we should emit the tuple to notify the main process
711- out_dirs .put ((1 , dirs ))
712-
713- files = [(os .path .join (path , file ), size ) for (file , size ) in files ]
714- result .extend (files )
715-
716- out_files .put (result )
717-
718715 def split_dir_for_parallel_scan (self , path ):
719716 """
720717 Splits dir for parallel readdir.
721- It returns list of tuples (dirname, anchor index ) to be consumed by worker function
718+ It returns list of tuples (dirname, anchor_index ) to be consumed by workers
722719 """
723720
724721 ret , splits = torch_shim .torch_recommended_dir_split (DAOS_MAGIC , self ._dfs , path )
@@ -727,6 +724,28 @@ def split_dir_for_parallel_scan(self, path):
727724
728725 return [(path , idx ) for idx in range (0 , splits )]
729726
727+ def readdir_anchored (self , path , anchor_index , readdir_batch_size ):
728+ """
729+ Scans one anchored by index directory at `path`.
730+
731+ Returns (dirs, files):
732+ `dirs` are (path, anchor_index) work items for directories found in this batch,
733+ `files` is a list of resulting tuples: (full_path, size).
734+ """
735+ dirs = []
736+ files = []
737+ ret = torch_shim .torch_list_with_anchor (
738+ DAOS_MAGIC , self ._dfs , path , anchor_index , files , dirs , readdir_batch_size )
739+ if ret != 0 :
740+ raise OSError (ret , os .strerror (ret ), path )
741+
742+ subdirs = [split
743+ for name in dirs
744+ for split in self .split_dir_for_parallel_scan (os .path .join (path , name ))]
745+
746+ files = [(os .path .join (path , name ), size ) for (name , size ) in files ]
747+ return subdirs , files
748+
730749 def parallel_list (self , path = None ,
731750 readdir_batch_size = READDIR_BATCH_SIZE ,
732751 workers = PARALLEL_SCAN_WORKERS ):
@@ -736,43 +755,42 @@ def parallel_list(self, path=None,
736755
737756 To fully use this feature the container should be configured with directory object classes
738757 supporting this mode, e.g. OC_SX.
758+
759+ Using multiprocessing.Pool ensures propagation of errors in the workers and cleaning up
760+ resources, regardless of operation outcome.
761+
762+ It would be even better to use `concurrent.futures.ProcessPoolExecutor`; however,
763+ its `initializer` and `initargs` arguments are available only in Python 3.7+.
764+
765+ Although Python 3.6 is EOL, many distributions still ship it by default.
766+ Keeping `_readdir_worker_init` and `_readdir_batch` as module-level functions
767+ instead of private class methods, is a small price that allows us to support
768+ a much broader range of platforms.
739769 """
770+
740771 if path is None :
741772 path = os .sep
742773
743774 if not path .startswith (os .sep ):
744775 raise ValueError ("relative path is unacceptable" )
745776
746- procs = []
747- work = Queue ()
748- dirs = Queue ()
749- files = Queue ()
750- for _ in range (workers ):
751- worker = Process (target = self .list_worker_fn , args = (
752- work , dirs , files , readdir_batch_size ))
753- worker .start ()
754- procs .append (worker )
755-
756- queued = 0
757- processed = 0
758- for anchored_dir in self .split_dir_for_parallel_scan (path ):
759- work .put (anchored_dir )
760- queued += 1
761-
762- while processed < queued :
763- (scanned , to_scan ) = dirs .get ()
764- processed += scanned
765- for d in to_scan :
766- work .put (d )
767- queued += 1
777+ if readdir_batch_size <= 0 :
778+ raise ValueError ("readdir batch size should be a positive number" )
768779
769- result = []
770- for _ in range (workers ):
771- work .put (None )
772- result .extend (files .get ())
780+ if workers <= 0 :
781+ raise ValueError ("at least one worker is required for namespace scanning" )
773782
774- for worker in procs :
775- worker .join ()
783+ result = []
784+ batch = self .split_dir_for_parallel_scan (path )
785+ with Pool (workers ,
786+ initializer = _readdir_worker_init ,
787+ initargs = (self , readdir_batch_size )) as pool :
788+ while batch :
789+ next_batch = []
790+ for dirs , files in pool .imap_unordered (_readdir_batch , batch ):
791+ next_batch .extend (dirs )
792+ result .extend (files )
793+ batch = next_batch
776794
777795 return result
778796
0 commit comments