Skip to content
Snippets Groups Projects
Commit cb84b8a6 authored by Daniel Guderian's avatar Daniel Guderian
Browse files

fixted tests

parent 808450f9
No related branches found
No related tags found
1 merge request!14revive make_data_split
processed_data_neutrino/processed_graph_neutrino.h5
processed_data_muon/processed_graph_muon.h5
processed_data_neutrino/processed_graph_neutrino.h5
......@@ -2,7 +2,8 @@ from unittest import TestCase
import os
import h5py
import numpy as np
from orcasong.tools.make_data_split import *
import toml
import orcasong.tools.make_data_split as mds
__author__ = 'Daniel Guderian'
......@@ -35,7 +36,9 @@ class TestMakeDataSplit(TestCase):
#create list_file_dir
if not os.path.exists(list_file_dir):
os.makedirs(list_file_dir)
@classmethod
def tearDownClass(cls):
#remove the lists created
......@@ -44,33 +47,34 @@ class TestMakeDataSplit(TestCase):
os.removedirs(list_file_dir)
def test_read_keys_off_config(self):
#decode config
self.cfg = toml.load(config_file)
self.cfg['toml_filename'] = config_file
self.cfg = read_config(config_file)
#get input groups and compare
self.ip_group_keys = get_all_ip_group_keys(self.cfg)
self.ip_group_keys = mds.get_all_ip_group_keys(self.cfg)
self.assertSequenceEqual(self.ip_group_keys,self.input_categories_list)
def test_get_filepath_and_n_events(self):
os.chdir(test_data_dir)
self.n_evts_total = 0
for key in self.ip_group_keys:
print('Collecting information from input group ' + key)
self.cfg[key]['fpaths'] = get_h5_filepaths(self.cfg[key]['dir'])
self.cfg[key]['n_files'] = len(self.cfg[key]['fpaths'])
self.cfg[key]['n_evts'], self.cfg[key]['n_evts_per_file_mean'], self.cfg[key]['run_ids'] = get_number_of_evts_and_run_ids(self.cfg[key]['fpaths'], dataset_key='y')
self.n_evts_total += self.cfg[key]['n_evts']
#repeat first 2 steps
self.cfg = read_config(config_file)
self.ip_group_keys = mds.get_all_ip_group_keys(self.cfg)
self.cfg,self.n_evts_total = update_cfg(self.cfg)
for key in self.ip_group_keys:
self.assertIn(self.cfg[key]['fpaths'][0],self.file_path_list)
self.assertIn(self.cfg[key]['n_evts'],self.n_events_list)
def test_make_split(self):
#main
#repeat first 3 steps
self.cfg = read_config(config_file)
self.ip_group_keys = mds.get_all_ip_group_keys(self.cfg)
self.cfg,self.n_evts_total = update_cfg(self.cfg)
self.cfg['n_evts_total'] = self.n_evts_total
print_input_statistics(self.cfg, self.ip_group_keys)
mds.print_input_statistics(self.cfg, self.ip_group_keys)
for key in self.ip_group_keys:
add_fpaths_for_data_split_to_cfg(self.cfg, key)
make_dsplit_list_files(self.cfg)
mds.add_fpaths_for_data_split_to_cfg(self.cfg, key)
mds.make_dsplit_list_files(self.cfg)
#assert the single output lists
with open(list_output_val) as f:
......@@ -82,4 +86,26 @@ class TestMakeDataSplit(TestCase):
self.assertIn(line,self.file_path_list)
f.close
def update_cfg(cfg):
''' Update the cfg with file paths and also return the total number of events'''
#get input groups and compare
ip_group_keys = mds.get_all_ip_group_keys(cfg)
os.chdir(test_data_dir)
n_evts_total = 0
for key in ip_group_keys:
print('Collecting information from input group ' + key)
cfg[key]['fpaths'] = mds.get_h5_filepaths(cfg[key]['dir'])
cfg[key]['n_files'] = len(cfg[key]['fpaths'])
cfg[key]['n_evts'], cfg[key]['n_evts_per_file_mean'], cfg[key]['run_ids'] = mds.get_number_of_evts_and_run_ids(cfg[key]['fpaths'], dataset_key='y')
n_evts_total += cfg[key]['n_evts']
return cfg,n_evts_total
def read_config(config_file):
#decode config
cfg = toml.load(config_file)
cfg['toml_filename'] = config_file
return cfg
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment