Skip to content
Snippets Groups Projects

Dom wise graph representation

Closed Jorge Prado requested to merge github/fork/RasmusOrsoe/dom_wise-graph-representation into main
1 unresolved thread
5 files
+ 106
4
Compare changes
  • Side-by-side
  • Inline
Files
5
@@ -19,6 +19,7 @@ class SQLiteDataset(torch.utils.data.Dataset):
truth_table: str = 'truth',
selection: Optional[List[int]] = None,
dtype: torch.dtype = torch.float32,
node_representation = 'pulse',
    • Created by: asogaard

      This would probably be best suited for an "enum" or similar type of structure (i.e., to have a fixed set of options), but this is perfectly fine for now.

      • Created by: asogaard

        If you stick to strings, though, I would add a assert node_representation in ['pulse', 'node'] check in the constructor (see also comment below).

      • Please register or sign in to reply
Please register or sign in to reply
):
# Check(s)
@@ -37,6 +38,8 @@ class SQLiteDataset(torch.utils.data.Dataset):
assert isinstance(features, (list, tuple))
assert isinstance(truth, (list, tuple))
assert node_representation.lower() in ['pulse', 'dom']
self._database = database
self._pulsemaps = pulsemaps
self._features = [index_column] + features
@@ -44,6 +47,7 @@ class SQLiteDataset(torch.utils.data.Dataset):
self._index_column = index_column
self._truth_table = truth_table
self._dtype = dtype
self._node_representation = node_representation
self._features_string = ', '.join(self._features)
self._truth_string = ', '.join(self._truth)
@@ -118,6 +122,36 @@ class SQLiteDataset(torch.utils.data.Dataset):
except:
return -1
def _get_unique_positions(self, tensor):
return torch.unique(tensor, return_counts = True, return_inverse=True, dim=0)
    • Created by: asogaard

      This loop-based implement could likely be optimised for speed, but if it works and you haven't been impeded by the speed I am happy to included it and open an issue for future improvement.

      • Created by: RasmusOrsoe

        It's fairly performant, but it is slower than the pulse representation. I agree with your suggestion and would welcome input on how to vectorize this. I thought about it a bit and sought inspiration from your upgrade code, but I could not see a direct application for this. So I'd be quite interested to see the solution.

      • Please register or sign in to reply
Please register or sign in to reply
def _make_dom_wise_representation(self,data):
unique_doms, inverse_idx, n_pulses_pr_dom = self._get_unique_positions(data.x[:,[self._features.index('dom_x') -1,
self._features.index('dom_y') -1,
self._features.index('dom_z') -1,
self._features.index('rde') -1,
self._features.index('pmt_area') -1]])
unique_inverse_indices = torch.unique(inverse_idx)
count = 0
pulse_statistics = torch.zeros(size = (len(unique_doms), 8))
#'dom_x','dom_y','dom_z','dom_time','charge','rde','pmt_area'
time_idx = self._features.index('dom_time')-1
charge_idx = self._features.index('charge')-1
for unique_inverse_idx in unique_inverse_indices:
time = data.x[inverse_idx == unique_inverse_idx,time_idx]
charge = data.x[inverse_idx == unique_inverse_idx,charge_idx]
pulse_statistics[count,0] = torch.min(time)
pulse_statistics[count,1] = torch.mean(time)
pulse_statistics[count,2] = torch.max(time)
pulse_statistics[count,3] = torch.std(time, unbiased=False)
pulse_statistics[count,4] = torch.min(charge)
pulse_statistics[count,5] = torch.mean(charge)
pulse_statistics[count,6] = torch.max(charge)
pulse_statistics[count,7] = torch.std(charge, unbiased=False)
count +=1
data.x = torch.cat((unique_doms, n_pulses_pr_dom.unsqueeze(1), pulse_statistics), dim = 1)
return data
def _create_graph(self, features, truth):
"""Create Pytorch Data (i.e.graph) object.
@@ -161,12 +195,18 @@ class SQLiteDataset(torch.utils.data.Dataset):
data = np.array([]).reshape((0, len(self._features) - 1))
# Construct graph data object
x = torch.tensor(data, dtype=self._dtype)
n_pulses = torch.tensor(len(x), dtype=torch.int32)
graph = Data(
x=x,
edge_index= None
)
if self._node_representation.lower() == 'dom':
graph = self._make_dom_wise_representation(graph)
    • Created by: asogaard

      DRY. Also, the check for node_representation in ['pulse', 'node'] should probably be done in the constructor, such that we can assume that it has a standard value in all other methods. I'll assume that here to simplify.

              x = torch.tensor(data, dtype=self._dtype)
              n_pulses = torch.tensor(len(x), dtype=torch.int32)
              graph = Data(
                  x=x,
                  edge_index= None
              )
                  
              if self._node_representation == 'dom':
                  graph = self._make_dom_wise_representation(graph)
Please register or sign in to reply
graph.n_pulses = n_pulses
graph.features = self._features[1:]
Loading