Skip to content
Snippets Groups Projects
Commit a4cfe2b0 authored by Daniel Guderian's avatar Daniel Guderian
Browse files

added test routine

parent 611acf0c
No related branches found
No related tags found
1 merge request!14revive make_data_split
File added
File added
......@@ -38,11 +38,11 @@
# --- Main options ---#
n_files_train = 3
n_files_train = 1
n_files_validate = 1
n_files_rest = 0
output_file_folder = "/sps/km3net/users/guderian/NN_stuff/split_data_output/ORCA4/graph/ts/"
output_file_folder = "data_split_test_output"
output_file_name = "test_list"
......@@ -54,16 +54,16 @@ print_only = false # only print information of your input_groups, don't make any
# --- Input groups : these are the datafiles, that should be concatenated somehow --- #
[elec_cc]
dir = "/sps/km3net/users/guderian/NN_stuff/graphs/ORCA4/base/gsg_elecCC-CC_1-500GeV.km3sim/test/"
run_ids_train = [6763, 6767]
run_ids_validate = [6768, 6769]
[neutrino]
dir = "processed_data_neutrino"
run_ids_train = [1, 6767]
run_ids_validate = [1, 6769]
[muon_nc]
dir = "/sps/km3net/users/guderian/NN_stuff/graphs/ORCA4/base/gsg_muonNC-NC_1-500GeV.km3sim/test/"
run_ids_train = [6763, 6767]
run_ids_validate = [6768, 6769]
[muon]
dir = "processed_data_muon"
run_ids_train = [1, 6767]
run_ids_validate = [9999, 6769]
# --- Input groups : these are the datafiles, that should be concatenated somehow --- #
\ No newline at end of file
......@@ -16,7 +16,7 @@ DET_FILE_NEUTRINO = os.path.join(test_dir, "data", "neutrino_detector_file.detx"
class TestStdRecoExtractor(TestCase):
""" Assert that the neutrino info is extracted correctly File has 18 events. """
""" Assert that the neutrino info is extracted correctly. File has 18 events. """
@classmethod
def setUpClass(cls):
......@@ -33,7 +33,7 @@ class TestStdRecoExtractor(TestCase):
cls.outfile = os.path.join(cls.tmpdir.name, "binned.h5")
cls.proc.run(infile=NEUTRINO_FILE, outfile=cls.outfile)
cls.f = h5py.File(cls.outfile, "r")
@classmethod
def tearDownClass(cls):
    # Close the h5 file opened in setUpClass so the temp dir it lives in
    # can be cleaned up without a dangling open handle.
    cls.f.close()
......@@ -192,3 +192,5 @@ class TestStdRecoExtractor(TestCase):
}
for k, v in target.items():
np.testing.assert_equal(y[k], v)
\ No newline at end of file
......@@ -7,8 +7,79 @@ from orcasong.tools.make_data_split import *
__author__ = 'Daniel Guderian'

# Directory this test file lives in; all fixture paths are built from it.
test_dir = os.path.dirname(os.path.realpath(__file__))

# NOTE(review): the next three assignments look like leftovers from an older
# revision — `neutrino_file` and `config_file` are re-assigned below and these
# values are never used afterwards; confirm and remove.
mupage = os.path.join(test_dir, "data", "mupage.root.h5")
neutrino_file = os.path.join(test_dir, "data", "neutrino.h5")
config_file = os.path.join(test_dir, "data", "test_make_data_split_config.toml")

# Root of the test fixtures.
test_data_dir = os.path.join(test_dir, "data")

# These are files that were processed with orcasong.
mupage_file = os.path.join(test_data_dir, "processed_data_muon", "processed_graph_muon.h5")
neutrino_file = os.path.join(test_data_dir,"processed_data_neutrino", "processed_graph_neutrino.h5")

# Config file containing 2 input groups ("neutrino" and "muon").
config_file = os.path.join(test_data_dir, "test_make_data_split_config.toml")

# The list files that the data split will create (checked and removed by the test).
list_file_dir = os.path.join(test_data_dir, "data_split_test_output", "conc_list_files")
list_output_val = os.path.join(list_file_dir, "test_list_validate_0.txt")
list_output_train = os.path.join(list_file_dir, "test_list_train_0.txt")
#no idea how to tbh...
\ No newline at end of file
class TestMakeDataSplit(TestCase):
    """Run make_data_split end to end, like in the normal use case.

    The toml config defines two input groups ("neutrino" and "muon") whose
    processed h5 files live under ``test_data_dir``. The pipeline is exercised
    fully: reading the config, collecting file paths and event counts, and
    writing the train/validate list files, whose contents are then compared
    against the expected lists.

    Shared pipeline state (the decoded config and its derived statistics) is
    cached on the *class* by idempotent helpers, because unittest creates a
    fresh instance per test method and makes no ordering guarantees — the
    previous version stored this state on ``self`` in one test and read it in
    another, which raises AttributeError.
    """

    @classmethod
    def setUpClass(cls):
        # Expected input-group keys, in config order.
        cls.input_categories_list = ["neutrino", "muon"]
        # Expected file paths; include variants with a trailing linebreak,
        # since the written list files contain one path per line.
        cls.file_path_list = [
            'processed_data_muon/processed_graph_muon.h5',
            'processed_data_neutrino/processed_graph_neutrino.h5',
            'processed_data_muon/processed_graph_muon.h5\n',
            'processed_data_neutrino/processed_graph_neutrino.h5\n',
        ]
        cls.file_path_list_val = [
            'processed_data_neutrino/processed_graph_neutrino.h5',
            'processed_data_neutrino/processed_graph_neutrino.h5\n',
        ]
        # Expected number of events per input file (18 neutrino, 3 muon).
        cls.n_events_list = [18, 3]
        # The config addresses the input dirs relative to the test data dir.
        os.chdir(test_data_dir)
        # Create the output dir for the list files.
        if not os.path.exists(list_file_dir):
            os.makedirs(list_file_dir)

    @classmethod
    def tearDownClass(cls):
        # Remove the created lists and the then-empty output directories.
        os.remove(list_output_val)
        os.remove(list_output_train)
        os.removedirs(list_file_dir)

    @classmethod
    def _get_cfg(cls):
        # Decode the toml config once and cache it on the class.
        if not hasattr(cls, "cfg"):
            cls.cfg = toml.load(config_file)
            cls.cfg['toml_filename'] = config_file
            cls.ip_group_keys = get_all_ip_group_keys(cls.cfg)
        return cls.cfg

    @classmethod
    def _collect_stats(cls):
        # Gather file paths and event counts for all input groups (cached,
        # so any test needing them can run first without relying on order).
        cfg = cls._get_cfg()
        if 'n_evts_total' not in cfg:
            n_evts_total = 0
            for key in cls.ip_group_keys:
                print('Collecting information from input group ' + key)
                cfg[key]['fpaths'] = get_h5_filepaths(cfg[key]['dir'])
                cfg[key]['n_files'] = len(cfg[key]['fpaths'])
                (cfg[key]['n_evts'],
                 cfg[key]['n_evts_per_file_mean'],
                 cfg[key]['run_ids']) = get_number_of_evts_and_run_ids(
                    cfg[key]['fpaths'], dataset_key='y')
                n_evts_total += cfg[key]['n_evts']
            cfg['n_evts_total'] = n_evts_total
        return cfg

    def test_read_keys_off_config(self):
        # The input group keys must match the categories from the config.
        self._get_cfg()
        self.assertSequenceEqual(self.ip_group_keys, self.input_categories_list)

    def test_get_filepath_and_n_events(self):
        # Every group must point at a known file and a known event count.
        cfg = self._collect_stats()
        for key in self.ip_group_keys:
            self.assertIn(cfg[key]['fpaths'][0], self.file_path_list)
            self.assertIn(cfg[key]['n_evts'], self.n_events_list)

    def test_make_split(self):
        # Run the full split and check the contents of both output lists.
        cfg = self._collect_stats()
        print_input_statistics(cfg, self.ip_group_keys)
        for key in self.ip_group_keys:
            add_fpaths_for_data_split_to_cfg(cfg, key)
        make_dsplit_list_files(cfg)

        # `with` closes the files; the old trailing `f.close` (no call
        # parentheses, and referencing the wrong handle) was a no-op.
        with open(list_output_val) as val_file:
            for line in val_file:
                self.assertIn(line, self.file_path_list_val)
        with open(list_output_train) as train_file:
            for line in train_file:
                self.assertIn(line, self.file_path_list)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment