diff --git a/src/graphnet/data/dataset.py b/src/graphnet/data/dataset.py
index a0110988b6e4e844ec83c02acdee5c5e89ddf6b0..22a5c185c1d63da5a8eebc67887eae5a1d3c3e23 100644
--- a/src/graphnet/data/dataset.py
+++ b/src/graphnet/data/dataset.py
@@ -17,6 +17,8 @@ from typing import (
 import numpy as np
 import torch
 from torch_geometric.data import Data
+from torch.utils.data import TensorDataset
+
 
 from graphnet.constants import GRAPHNET_ROOT_DIR
 from graphnet.data.utilities.string_selection_resolver import (
@@ -150,6 +152,7 @@ class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC):
         loss_weight_column: Optional[str] = None,
         loss_weight_default_value: Optional[float] = None,
         seed: Optional[int] = None,
+        timeseries: Optional[bool] = None,
     ):
         """Construct Dataset.
 
@@ -195,6 +198,7 @@ class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC):
                 subset of events when resolving a string-based selection (e.g.,
                 `"10000 random events ~ event_no % 5 > 0"` or `"20% random
                 events ~ event_no % 5 > 0"`).
+            timeseries: Whether the dataset should produce time-series
+                graphs (see `_create_timeseries`). Defaults to None, i.e.,
+                a regular (non-time-series) dataset.
         """
         # Base class constructor
         super().__init__(name=__name__, class_name=self.__class__.__name__)
@@ -218,6 +222,7 @@ class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC):
         self._index_column = index_column
         self._truth_table = truth_table
         self._loss_weight_default_value = loss_weight_default_value
+        self._timeseries = timeseries
 
         if node_truth is not None:
             assert isinstance(node_truth_table, str)
@@ -375,7 +380,14 @@ class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC):
         features, truth, node_truth, loss_weight = self._query(
             sequential_index
         )
-        graph = self._create_graph(features, truth, node_truth, loss_weight)
+        if self._timeseries:
+            graph = self._create_timeseries(
+                features, truth, node_truth, loss_weight
+            )
+        else:
+            graph = self._create_graph(
+                features, truth, node_truth, loss_weight
+            )
         return graph
 
     # Internal method(s)
@@ -616,6 +628,223 @@ class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC):
         graph["dataset_path"] = self._path
         return graph
 
+    def _create_timeseries(
+        self,
+        features: List[Tuple[float, ...]],
+        truth: Tuple[Any, ...],
+        node_truth: Optional[List[Tuple[Any, ...]]] = None,
+        loss_weight: Optional[float] = None,
+    ) -> Data:
+        """Create Pytorch Data (i.e. timeseries) object.
+
+        No preprocessing is performed at this stage, just as no node adjancency
+        is imposed. This means that the `edge_attr` and `edge_weight`
+        attributes are not set.
+
+        Args:
+            features: List of tuples, containing event features.
+            truth: List of tuples, containing truth information.
+            node_truth: List of tuples, containing node-level truth.
+            loss_weight: A weight associated with the event for weighing the
+                loss.
+
+        Returns:
+            Time-series `Data` object.
+        """
+        # Convert nested list to simple dict
+        truth_dict = {
+            key: truth[index] for index, key in enumerate(self._truth)
+        }
+
+        # Define custom labels
+        labels_dict = self._get_labels(truth_dict)
+
+        # Convert nested list to simple dict
+        if node_truth is not None:
+            node_truth_array = np.asarray(node_truth)
+            assert self._node_truth is not None
+            node_truth_dict = {
+                key: node_truth_array[:, index]
+                for index, key in enumerate(self._node_truth)
+            }
+
+        # Catch cases with no reconstructed pulses
+        if len(features):
+            data = np.asarray(features)[:, 1:]
+        else:
+            data = np.array([]).reshape((0, len(self._features) - 1))
+
+        # Construct graph data object
+        x = torch.tensor(data, dtype=self._dtype)  # pylint: disable=C0103
+        n_pulses = torch.tensor(len(x), dtype=torch.long)
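+        # If `dom_number` is not among the input features, derive one by
+        # assigning a unique index to every distinct (dom_x, dom_y, dom_z)
+        # position.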
+        if "dom_number" not in features:
+            number_index = -1
+            unique_index = [
+                self._features.index(key) - 1
+                for key in ["dom_x", "dom_y", "dom_z"]
+            ]
+            dom_number = torch.unique(
+                x[:, unique_index], return_inverse=True, dim=0, sorted=False
+            )[-1]
+            x = torch.cat([x, dom_number.reshape(-1, 1)], dim=1)
+        else:
+            number_index = self._features.index("dom_number") - 1
+
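+        # Group pulses by DOM: sort by DOM number, use the cumulative bin
+        # counts as group boundaries, and order the per-DOM groups by their
+        # number of activations (longest first), as required by
+        # `pack_sequence(..., enforce_sorted=True)`.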
+        dom_activation_sort = x[:, number_index].sort()[-1]
+        x = x[dom_activation_sort]
+        bin_count = torch.bincount(
+            x[:, number_index].type(torch.int64)
+        ).cumsum(0)
+        # NOTE: Removing `dom_number` from the features could be optional.
+        if number_index != -1:
+            x = x[:, torch.arange(x.shape[-1]) != number_index]
+        else:
+            x = x[:, :-1]
+        n_activations = bin_count - torch.cat(
+            [torch.tensor([0]), bin_count[:-1]]
+        )
+        activations_sort = n_activations.sort(descending=True)[-1]
+        x = np.array(
+            [np.array(_) for _ in torch.tensor_split(x, bin_count[:-1])],
+            dtype=object,
+        )[activations_sort]
+        x = [
+            torch.tensor(
+                dom_list[
+                    np.argsort(
+                        # Offset by one because the index column has been
+                        # dropped from the node feature array.
+                        dom_list[
+                            :,
+                            self._features.index("dom_time") - 1,
+                        ]
+                    )[::-1]
+                ]
+            )
+            for dom_list in x
+        ]
+        x, xyztc = Dataset.DOM_time_series_to_pack_sequence(
+            x, device="cpu", features=self._features[1:]
+        )
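+        # Store the `PackedSequence` as its four components so that the
+        # PyTorch Geometric collate step can concatenate them across events
+        # (see `TimeSeries.__cat_dim__`); the detector class later rebuilds
+        # the per-event `PackedSequence` objects from these components.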
+        graph = TimeSeries(x=xyztc, edge_index=None)
+        graph.time_series = [
+            x.data,
+            x.batch_sizes,
+            x.sorted_indices,
+            x.unsorted_indices,
+        ]
+        graph.n_pulses = n_pulses
+        graph.features = self._features[1:]
+        # graph.n_activations = n_activations[activations_sort]
+
+        # Add loss weight to graph.
+        if loss_weight is not None and self._loss_weight_column is not None:
+            # No loss weight was retrieved, i.e., it is missing for the current
+            # event.
+            if loss_weight < 0:
+                if self._loss_weight_default_value is None:
+                    raise ValueError(
+                        "At least one event is missing an entry in "
+                        f"{self._loss_weight_column} "
+                        "but loss_weight_default_value is None."
+                    )
+                graph[self._loss_weight_column] = torch.tensor(
+                    self._loss_weight_default_value, dtype=self._dtype
+                ).reshape(-1, 1)
+            else:
+                graph[self._loss_weight_column] = torch.tensor(
+                    loss_weight, dtype=self._dtype
+                ).reshape(-1, 1)
+
+        # Write attributes, either target labels, truth info or original
+        # features.
+        add_these_to_graph = [labels_dict, truth_dict]
+        if node_truth is not None:
+            add_these_to_graph.append(node_truth_dict)
+        for write_dict in add_these_to_graph:
+            for key, value in write_dict.items():
+                try:
+                    graph[key] = torch.tensor(value)
+                except TypeError:
+                    # Cannot convert `value` to Tensor due to its data type,
+                    # e.g. `str`.
+                    self.debug(
+                        (
+                            f"Could not assign `{key}` with type "
+                            f"'{type(value).__name__}' as attribute to graph."
+                        )
+                    )
+
+        # Additionally add original features as (static) attributes
+        for index, feature in enumerate(graph.features[:-2]):
+            if feature not in ["x"]:
+                graph[feature] = graph.x[:, index]
+
+        # Add custom labels to the graph
+        for key, fn in self._label_fns.items():
+            graph[key] = fn(graph)
+
+        # Add Dataset Path. Useful if multiple datasets are concatenated.
+        graph["dataset_path"] = self._path
+        return graph
+
+    @staticmethod
+    def DOM_time_series_to_pack_sequence(
+        tensor: List[torch.Tensor], device: torch.device, features: List[str]
+    ) -> Tuple[torch.nn.utils.rnn.PackedSequence, torch.Tensor]:
+        """Pack DOM time-series data into a `PackedSequence`.
+
+        Also returns, for each DOM, a tensor of its x/y/z-coordinates, mean
+        activation time and summed charge.
+
+        Args:
+            tensor (List[torch.Tensor]): One tensor of activations per DOM.
+            device (torch.device): Device to use.
+            features (List[str]): Names of the features in the tensors.
+
+        Returns:
+            tensor (torch.nn.utils.rnn.PackedSequence): Packed sequence of
+                DOM time-series data.
+            xyztc (torch.Tensor): Tensor of x/y/z-coordinates as well as mean
+                activation time and summed charge for each DOM.
+        """
+        xyztc = torch.stack(
+            [
+                torch.cat(
+                    [
+                        v[
+                            0,
+                            [
+                                features.index("dom_x"),
+                                features.index("dom_y"),
+                                features.index("dom_z"),
+                            ],
+                        ],
+                        torch.as_tensor(
+                            [
+                                v[:, features.index("dom_time")].mean(),
+                                v[:, features.index("charge")].sum(),
+                            ],
+                            device=device,
+                        ),
+                    ]
+                )
+                for v in tensor
+            ]
+        )
+
+        tensor = torch.nn.utils.rnn.pack_sequence(tensor, enforce_sorted=True)
+        return tensor, xyztc
+
+    @staticmethod
+    @torch.jit.ignore
+    def sorted_jit_ignore(tensor: List[torch.Tensor]) -> List[torch.Tensor]:
+        """Sort a list of tensors by the length of their elements.
+
+        Args:
+            tensor (List[torch.Tensor]): Input tensors.
+
+        Returns:
+            List[torch.Tensor]: Tensors sorted by length (longest first).
+        """
+        tensor = sorted(tensor, key=len, reverse=True)
+        return tensor
+
     def _get_labels(self, truth_dict: Dict[str, Any]) -> Dict[str, Any]:
         """Return dictionary of  labels, to be added as graph attributes."""
         if "pid" in truth_dict.keys():
@@ -664,6 +893,18 @@ class Dataset(Logger, Configurable, torch.utils.data.Dataset, ABC):
             return -1
 
 
+class TimeSeries(Data):
+    """Dataset class for time series data."""
+
+    def __cat_dim__(
+        self: Data, key: str, value: Any, *args: Any, **kwargs: Dict[str, Any]
+    ) -> Any:
+        """Specify change in concatenation dimension for specific key."""
+        if key == "time_series":
+            return 0
+        return super().__cat_dim__(key, value, *args, **kwargs)
+
+
 class EnsembleDataset(torch.utils.data.ConcatDataset):
     """Construct a single dataset from a collection of datasets."""
 
diff --git a/src/graphnet/models/components/DOM_handling.py b/src/graphnet/models/components/DOM_handling.py
new file mode 100644
index 0000000000000000000000000000000000000000..0bfc95b4d16d78910a96d6e4c7ae70600bceb027
--- /dev/null
+++ b/src/graphnet/models/components/DOM_handling.py
@@ -0,0 +1,147 @@
+"""Functions for handling event level DOM data."""
+from typing import List, Optional, Tuple
+from torch_geometric.nn.pool import knn_graph
+
+import torch
+
+
+@torch.jit.script
+def append_dom_id(
+    tensor: torch.Tensor,
+    batch: Optional[torch.Tensor] = None,
+    device: torch.device = torch.device("cpu"),
+) -> torch.Tensor:
+    """Assign a unique ID to each DOM.
+
+    The ID is assigned based on the position of the DOM. Requires the x-, y-
+    and z-coordinates to be the first three columns of the input tensor.
+
+    Args:
+        tensor (torch.Tensor): Input tensor.
+        batch (torch.Tensor, optional): Batch tensor. Defaults to None.
+        device (torch.device): Device to use.
+
+    Returns:
+        torch.Tensor: Input tensor with the DOM ID appended as the last
+            column.
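+
+    Example:
+        A minimal sketch (the pulse values below are made up, not from a real
+        event); the first two rows share an (x, y, z) position and therefore
+        receive the same DOM ID, appended as a new last column::
+
+            pulses = torch.tensor(
+                [
+                    [0.0, 0.0, 1.0, 10.2],  # x, y, z, time
+                    [0.0, 0.0, 1.0, 11.7],  # same DOM, later pulse
+                    [0.0, 0.0, 2.0, 9.8],  # a different DOM
+                ]
+            )
+            batch = torch.zeros(3, dtype=torch.int64)
+            pulses_with_ids = append_dom_id(pulses, batch)  # shape (3, 5)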
+    """
+    inverse_matrix = torch.zeros(
+        (tensor.shape[0], 3), device=device, dtype=torch.int64
+    )
+    for i in range(3):
+        _, inverse = torch.unique(tensor[:, i], return_inverse=True)
+        inverse_matrix[:, i] = inverse
+
+    if batch is None:
+        batch = torch.zeros(tensor.shape[0], device=device, dtype=torch.int64)
+
+    inverse_matrix = torch.hstack([batch.unsqueeze(1), inverse_matrix])
+    for i in range(3):
+        inverse_matrix[:, 1] = inverse_matrix[:, 0] + (
+            (torch.max(inverse_matrix[:, 0]) + 1) * (inverse_matrix[:, 1] + 1)
+        )
+
+        _, inverse_matrix[:, 1] = torch.unique(
+            inverse_matrix[:, 1], return_inverse=True
+        )
+
+        inverse_matrix = inverse_matrix[:, -(3 - i) :]
+
+    inverse_matrix = inverse_matrix.flatten()
+    tensor = torch.hstack([tensor, inverse_matrix.unsqueeze(1)])
+    return tensor
+
+
+def DOM_to_time_series(
+    tensor: torch.Tensor, batch: torch.Tensor
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Create time series for each activated DOM from dom activation data.
+
+    Returns a time-series view of the DOM activations as well as an updated
+    batch tensor. REQUIRES the DOM ID to be the last column of the input
+    tensor.
+
+    Args:
+        tensor (torch.Tensor): Input tensor.
+        batch (torch.Tensor): Batch tensor.
+
+    Returns:
+        tensor (Tuple[torch.Tensor, ...]): Time-series DOM data, one tensor
+            per activated DOM.
+        batch (torch.Tensor): Updated batch tensor with one entry per DOM.
+    """
+    dom_activation_sort = tensor[:, -1].sort()[-1]
+    tensor, batch = tensor[dom_activation_sort], batch[dom_activation_sort]
+    bin_count = torch.bincount(tensor[:, -1].type(torch.int64)).cumsum(0)
+    batch = batch[bin_count - 1]
+    tensor = tensor[:, :-1]
+    tensor = torch.tensor_split(tensor, bin_count.cpu()[:-1])
+    lengths_index = (
+        torch.as_tensor([v.size(0) for v in tensor]).sort()[-1].flip(0)
+    )
+    batch = batch[lengths_index]
+
+    return tensor, batch
+
+
+def DOM_time_series_to_pack_sequence(
+    tensor: List[torch.Tensor], device: torch.device
+) -> Tuple[torch.nn.utils.rnn.PackedSequence, torch.Tensor]:
+    """Packs DOM time series data into a pack sequence.
+
+    Also returns a tensor of x/y/z-coordinates and time of first/mean activation for each DOM.
+
+    Args:
+        tensor (torch.Tensor): Input tensor.
+        device (torch.device): Device to use.
+
+    Returns:
+        tensor (torch.nn.utils.rnn.PackedSequence): Packed sequence of DOM
+            time-series data.
+        xyztt (torch.Tensor): Tensor of x/y/z-coordinates as well as mean and
+            first activation time for each DOM.
+    """
+    tensor = sorted_jit_ignore(tensor)
+    xyztt = torch.stack(
+        [
+            torch.cat(
+                [
+                    v[0, :3],
+                    torch.as_tensor(
+                        [v[:, 3].mean(), v[:, 3].min()], device=device
+                    ),
+                ]
+            )
+            for v in tensor
+        ]
+    )
+
+    tensor = torch.nn.utils.rnn.pack_sequence(tensor, enforce_sorted=True)
+    return tensor, xyztt
+
+
+@torch.jit.ignore
+def sorted_jit_ignore(tensor: List[torch.Tensor]) -> List[torch.Tensor]:
+    """Sort a list of tensors by the length of their elements.
+
+    Args:
+        tensor (List[torch.Tensor]): Input tensors.
+
+    Returns:
+        List[torch.Tensor]: Tensors sorted by length (longest first).
+    """
+    tensor = sorted(tensor, key=len, reverse=True)
+    return tensor
+
+
+@torch.jit.ignore
+def knn_graph_ignore(
+    x: torch.Tensor, k: int, batch: Optional[torch.Tensor] = None
+) -> torch.Tensor:
+    """Create a kNN graph based on the input data.
+
+    Args:
+        x: Input data.
+        k: Number of neighbours.
+        batch: Batch index.
+
+    Returns:
+        torch.Tensor: Edge index tensor of the resulting kNN graph.
+    """
+    return knn_graph(x=x, k=k, batch=batch)
diff --git a/src/graphnet/models/detector/icecube.py b/src/graphnet/models/detector/icecube.py
index e6addb8e37f178bbae3994fdd8ff5cba759d36c9..4bd162042c754e70c9fd9b151c614bf8b4228e7d 100644
--- a/src/graphnet/models/detector/icecube.py
+++ b/src/graphnet/models/detector/icecube.py
@@ -71,6 +71,72 @@ class IceCubeKaggle(Detector):
         return data
 
 
+class IceCube86RNN(Detector):
+    """`Detector` class for IceCube-86 with a time-series data input."""
+
+    features = FEATURES.ICECUBE86[:5]
+
+    def _forward(self, data: Data) -> Data:
+        """Ingest data, build graph, and preprocess features.
+
+        Args:
+            data: Input graph data.
+
+        Returns:
+            Connected and preprocessed time-series graph data.
+        """
+        self._validate_features(data)
+
+        data.x[:, 0] /= 500.0  # dom_x
+        data.x[:, 1] /= 500.0  # dom_y
+        data.x[:, 2] /= 500.0  # dom_z
+        data.x[:, 3] = (data.x[:, 3] - 1.0e04) / 3.0e4  # dom_time
+        data.x[:, 4] = torch.log10(data.x[:, 4]) / 3.0  # charge
+
+        data.time_series[0] /= torch.tensor(
+            [500.0, 500.0, 500.0, 1.0, 1.0], device=data.x.device
+        )
+        data.time_series[0][:, 3] = (
+            data.time_series[0][:, 3] - 1.0e04
+        ) / 3.0e4
+        data.time_series[0][:, 4] = (
+            torch.log10(data.time_series[0][:, 4]) / 3.0
+        )
+
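+        # The collate step concatenated the four `PackedSequence` components
+        # (data, batch_sizes, sorted/unsorted indices) across events; split
+        # them back up per event and rebuild one `PackedSequence` per event.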
+        data.time_series[0] = torch.tensor_split(
+            data.time_series[0], (data.n_pulses.cumsum(0)[:-1]).cpu()
+        )
+
+        data.time_series[1] = torch.tensor_split(
+            data.time_series[1],
+            (
+                torch.argwhere(
+                    torch.gt(data.time_series[1][1:], data.time_series[1][:-1])
+                ).flatten()
+                + 1
+            ).cpu(),
+        )
+
+        time_series = []
+        for sequence, batch_sizes, sorted_indices, unsorted_indices in zip(
+            data.time_series[0],
+            data.time_series[1],
+            data.time_series[2],
+            data.time_series[3],
+        ):
+            time_series.append(
+                torch.nn.utils.rnn.PackedSequence(
+                    data=sequence,
+                    batch_sizes=batch_sizes.cpu(),
+                    sorted_indices=sorted_indices,
+                    unsorted_indices=unsorted_indices,
+                )
+            )
+
+        data.time_series = time_series
+        return data
+
+
 class IceCubeDeepCore(IceCube86):
     """`Detector` class for IceCube-DeepCore."""
 
diff --git a/src/graphnet/models/gnn/__init__.py b/src/graphnet/models/gnn/__init__.py
index 4da7f5b21e2948f709678dc0a6f9c8f3cb3e3b21..b5092e9fef6154ed792b3cdd3000385d1534c679 100644
--- a/src/graphnet/models/gnn/__init__.py
+++ b/src/graphnet/models/gnn/__init__.py
@@ -3,3 +3,4 @@
 from .convnet import ConvNet
 from .dynedge import DynEdge
 from .dynedge_jinst import DynEdgeJINST
+from .node_rnn import Node_RNN
diff --git a/src/graphnet/models/gnn/node_rnn.py b/src/graphnet/models/gnn/node_rnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..490c1f02425c7156111682b730435a00a21bfe6e
--- /dev/null
+++ b/src/graphnet/models/gnn/node_rnn.py
@@ -0,0 +1,68 @@
+"""Implementation of the NodeTimeRNN model.
+
+(cannot be used as a standalone model)
+"""
+import torch
+
+from graphnet.models.gnn.gnn import GNN
+from graphnet.models.components.DOM_handling import (
+    append_dom_id,
+    DOM_time_series_to_pack_sequence,
+    DOM_to_time_series,
+    knn_graph_ignore,
+)
+from graphnet.utilities.config import save_model_config
+from torch_geometric.data import Data
+
+
+class Node_RNN(GNN):
+    """Implementation of the RNN model architecture.
+
+    The model takes as input the typical DOM data format and transforms it
+    into a time series of DOM activations per DOM, before applying an RNN
+    layer and outputting an RNN summary for each DOM.
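+
+    Example:
+        A minimal sketch of how this block could be wired into a larger model
+        (the argument values below are placeholders, not recommended
+        settings)::
+
+            rnn = Node_RNN(
+                nb_inputs=5,  # dom_x, dom_y, dom_z, dom_time, charge
+                hidden_size=64,
+                num_layers=1,
+                nb_neighbours=8,
+            )
+            data = rnn(data)  # appends the RNN summary to `data.x`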
+    """
+
+    @save_model_config
+    def __init__(
+        self,
+        nb_inputs: int,
+        hidden_size: int,
+        num_layers: int,
+        nb_neighbours: int,
+    ) -> None:
+        """Construct `NodeTimeRNN`.
+
+        Args:
+            nb_inputs: Number of features in the input data.
+            hidden_size: Number of features for the RNN output and hidden layers.
+            num_layers: Number of layers in the RNN.
+            nb_neighbours: Number of neighbours to use when reconstructing the graph representation.
+        """
+        self._nb_neighbours = nb_neighbours
+        self._num_layers = num_layers
+        self._hidden_size = hidden_size
+        self._nb_inputs = nb_inputs
+        super().__init__(nb_inputs, hidden_size + 5)
+
+        self._rnn = torch.nn.RNN(
+            num_layers=self._num_layers,
+            input_size=self._nb_inputs,
+            hidden_size=self._hidden_size,
+        )
+
+    def forward(self, data: Data) -> Data:
+        """Apply learnable forward pass to the GNN."""
+        rnn_out = []
+        for b_uniq in data.batch.unique():
+            rnn_out.append(
+                self._rnn(data.time_series[b_uniq])[-1][0]
+            )  # apply rnn layer
+        rnn_out = torch.cat(rnn_out)
+
+        data.x = torch.hstack(
+            [data.x, rnn_out]
+        )  # append the RNN summary to the per-DOM static features
+        return data
diff --git a/src/graphnet/models/gnn/rnn_dynedge.py b/src/graphnet/models/gnn/rnn_dynedge.py
new file mode 100644
index 0000000000000000000000000000000000000000..0deec711319d7c3e5a688398fce5ea5f623406f1
--- /dev/null
+++ b/src/graphnet/models/gnn/rnn_dynedge.py
@@ -0,0 +1,90 @@
+"""RNN_DynEdge model implementation."""
+from typing import List, Optional, Tuple, Union
+
+import torch
+from graphnet.models.gnn.gnn import GNN
+from graphnet.models.gnn.dynedge import DynEdge
+from graphnet.models.gnn.node_rnn import Node_RNN
+
+from graphnet.utilities.config import save_model_config
+from torch_geometric.data import Data
+
+
+class RNN_DynEdge(GNN):
+    """The RNN_DynEdge model class."""
+
+    @save_model_config
+    def __init__(
+        self,
+        nb_inputs: int,
+        *,
+        nb_neighbours: int = 8,
+        RNN_layers: int = 1,
+        RNN_hidden_size: int = 64,
+        features_subset: Optional[Union[List[int], slice]] = None,
+        dynedge_layer_sizes: Optional[List[Tuple[int, ...]]] = None,
+        post_processing_layer_sizes: Optional[List[int]] = None,
+        readout_layer_sizes: Optional[List[int]] = None,
+        global_pooling_schemes: Optional[Union[str, List[str]]] = None,
+        add_global_variables_after_pooling: bool = False,
+    ):
+        """Initialize the RNN_DynEdge model.
+
+        Args:
+            nb_inputs (int): Number of input features.
+            nb_neighbours (int, optional): Number of neighbours to consider.
+                Defaults to 8.
+            RNN_layers (int, optional): Number of RNN layers.
+                Defaults to 1.
+            RNN_hidden_size (int, optional): Size of the hidden state of the
+                RNN. Also determines the size of the output of the RNN.
+                Defaults to 64.
+            features_subset (Optional[Union[List[int], slice]], optional):
+                Subset of features to use.
+            dynedge_layer_sizes (Optional[List[Tuple[int, ...]]], optional):
+                List of tuples of integers representing the sizes of the
+                hidden layers of the DynEdge model.
+            post_processing_layer_sizes (Optional[List[int]], optional):
+                List of integers representing the sizes of the hidden layers
+                of the post-processing model.
+            readout_layer_sizes (Optional[List[int]], optional): List of
+                integers representing the sizes of the hidden layers of the
+                readout model.
+            global_pooling_schemes (Optional[Union[str, List[str]]],
+                optional): Pooling schemes to use. Defaults to None.
+            add_global_variables_after_pooling (bool, optional): Whether to
+                add global variables after pooling. Defaults to False.
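+
+        Example:
+            A minimal sketch (the argument values below are placeholders,
+            not recommended settings)::
+
+                model = RNN_DynEdge(
+                    nb_inputs=5,
+                    nb_neighbours=8,
+                    RNN_layers=1,
+                    RNN_hidden_size=64,
+                    global_pooling_schemes=["min", "max", "mean", "sum"],
+                )
+                # One row per event; 512 values with the default readout.
+                readout = model(data)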
+        """
+        self._nb_neighbours = nb_neighbours
+        self._nb_inputs = nb_inputs
+        self._RNN_layers = RNN_layers
+        self._RNN_hidden_size = RNN_hidden_size
+
+        self._features_subset = features_subset
+        self._dynedge_layer_sizes = dynedge_layer_sizes
+        self._post_processing_layer_sizes = post_processing_layer_sizes
+        self._global_pooling_schemes = global_pooling_schemes
+        self._add_global_variables_after_pooling = (
+            add_global_variables_after_pooling
+        )
+        if readout_layer_sizes is None:
+            readout_layer_sizes = [
+                512,
+            ]
+        self._readout_layer_sizes = readout_layer_sizes
+
+        super().__init__(nb_inputs, self._readout_layer_sizes[-1])
+
+        self._rnn = Node_RNN(
+            num_layers=self._RNN_layers,
+            nb_inputs=self._nb_inputs,
+            hidden_size=self._RNN_hidden_size,
+            nb_neighbours=self._nb_neighbours,
+        )
+
+        self._dynedge = DynEdge(
+            nb_inputs=self._RNN_hidden_size + 5,
+            nb_neighbours=self._nb_neighbours,
+            dynedge_layer_sizes=self._dynedge_layer_sizes,
+            post_processing_layer_sizes=self._post_processing_layer_sizes,
+            readout_layer_sizes=self._readout_layer_sizes,
+            global_pooling_schemes=self._global_pooling_schemes,
+            add_global_variables_after_pooling=self._add_global_variables_after_pooling,
+        )
+
+    def forward(self, data: Data) -> torch.Tensor:
+        """Apply learnable forward pass of the RNN and DynEdge models."""
+        data = self._rnn(data)
+        readout = self._dynedge(data)
+
+        return readout
diff --git a/src/graphnet/utilities/config/dataset_config.py b/src/graphnet/utilities/config/dataset_config.py
index fc5beb44b01d429f173ce3e9b8d28f12d3ceaef4..782ff7c6f03239c76f85ef55c77bf75a4bf4c216 100644
--- a/src/graphnet/utilities/config/dataset_config.py
+++ b/src/graphnet/utilities/config/dataset_config.py
@@ -47,6 +47,7 @@ class DatasetConfig(BaseConfig):
     loss_weight_default_value: Optional[float] = None
 
     seed: Optional[int] = None
+    timeseries: Optional[bool] = None
 
     def __init__(self, **data: Any) -> None:
         """Construct `DataConfig`.