diff --git a/src/graphnet/data/dataconverter.py b/src/graphnet/data/dataconverter.py index 41cec5eec563f597128d78dedf0128cbc5a1e6a0..da3e03622fea5f3ec15a34395ece0b5f116677c8 100644 --- a/src/graphnet/data/dataconverter.py +++ b/src/graphnet/data/dataconverter.py @@ -201,15 +201,16 @@ class DataConverter(ABC, Logger): super().__init__(name=__name__, class_name=self.__class__.__name__) @final - def __call__(self, directories: Union[str, List[str]]) -> None: + def __call__(self, directories: Union[str, List[str]], recursive: Optional[bool] = True) -> None: """Convert I3-files in `directories. Args: directories: One or more directories, the I3 files within which should be converted to an intermediate file format. + recursive: Whether or not to search the directories recursively. """ # Find all I3 and GCD files in the specified directories. - i3_files, gcd_files = find_i3_files(directories, self._gcd_rescue) + i3_files, gcd_files = find_i3_files(directories, self._gcd_rescue, recursive) if len(i3_files) == 0: self.error(f"No files found in {directories}.") return diff --git a/src/graphnet/utilities/filesys.py b/src/graphnet/utilities/filesys.py index 1923c7ce3f62c8d85ef4f5ba6c2d8d521e7c3ac4..54ca39214d4dc84e44adce7250f72b9054a16ece 100644 --- a/src/graphnet/utilities/filesys.py +++ b/src/graphnet/utilities/filesys.py @@ -31,7 +31,8 @@ def has_extension(filename: str, extensions: List[str]) -> bool: def find_i3_files( - directories: Union[str, List[str]], gcd_rescue: Optional[str] = None + directories: Union[str, List[str]], gcd_rescue: Optional[str] = None, + recursive: Optional[bool] = True, ) -> Tuple[List[str], List[str]]: """Find I3 files and corresponding GCD files in `directories`. @@ -43,6 +44,7 @@ def find_i3_files( directories: Directories to search recursively for I3 files. gcd_rescue: Path to the GCD that will be default if no GCD is present in the directory. + recursive: Whether or not to search the directories recursively. Returns: i3_list: Paths to I3 files in `directories` @@ -57,11 +59,14 @@ def find_i3_files( for directory in directories: - # Recursively find all I3-like files in `directory`. + # Find all I3-like files in `directory`, may or may not be recursively. paths = [] i3_patterns = ["*.bz2", "*.zst", "*.gz"] for i3_pattern in i3_patterns: - paths.extend(list(Path(directory).rglob(i3_pattern))) + if recursive: + paths.extend(list(Path(directory).rglob(i3_pattern))) + else: + paths.extend(list(Path(directory).glob(i3_pattern))) # Loop over all folders containing such I3-like files. folders = sorted(set([path.parent for path in paths]))