diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index b840a4c0c23aa776db76a7cc4ab438686d4f3a4b..00f59c0bd31a28134fb4cb53c1ef53343776ba9b 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -46,8 +46,10 @@ jobs:
       - uses: actions/checkout@v3
       - name: Upgrade packages already installed on icecube/icetray
         run: |
+          python --version
           pip install --upgrade astropy  # Installed version incompatible with numpy 1.23.0 [https://github.com/astropy/astropy/issues/12534]
           pip install --ignore-installed PyYAML  # Distutils installed [https://github.com/pypa/pip/issues/5247]
+          pip install --upgrade psutil # lets see..
       - name: Install package
         uses: ./.github/actions/install
         with:
@@ -73,7 +75,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python-version: [3.7, 3.8, 3.9, '3.10']
+        python-version: [3.8, 3.9, '3.10']
     steps:
       - uses: actions/checkout@v3
       - name: Set up Python ${{ matrix.python-version }}
diff --git a/requirements/torch_cpu.txt b/requirements/torch_cpu.txt
index 76f533905859fcfd079b3f330bf28dd843e6dd67..6f68e36005c1b17431905e94a37b5baf881731bc 100644
--- a/requirements/torch_cpu.txt
+++ b/requirements/torch_cpu.txt
@@ -1,7 +1,2 @@
---find-links https://download.pytorch.org/whl/torch_stable.html
-torch==1.11+cpu
---find-links https://data.pyg.org/whl/torch-1.11.0+cpu.html
-torch-cluster==1.6.0
-torch_scatter==2.0.9
-torch-sparse==0.6.13
-torch_geometric==2.0.4
\ No newline at end of file
+--find-links https://download.pytorch.org/whl/cpu
+--find-links https://data.pyg.org/whl/torch-2.0.0+cpu.html
\ No newline at end of file
diff --git a/requirements/torch_gpu.txt b/requirements/torch_gpu.txt
index 4004fd8af68b32ac0015dcb3c5c2c0163b0b726c..c325f35afe41f2ebe975086ee74a595e46acc652 100644
--- a/requirements/torch_gpu.txt
+++ b/requirements/torch_gpu.txt
@@ -1,8 +1,3 @@
 # Contains packages recommended for functional performance
 --find-links https://download.pytorch.org/whl/torch_stable.html
-torch==1.11+cu115
---find-links https://data.pyg.org/whl/torch-1.11.0+cu115.html
-torch-cluster==1.6.0
-torch_scatter==2.0.9
-torch-sparse==0.6.13
-torch_geometric==2.0.4
\ No newline at end of file
+--find-links https://data.pyg.org/whl/torch-2.0.0+cu117.html
diff --git a/requirements/torch_macos.txt b/requirements/torch_macos.txt
index a9e43921c0f0012472faaefe2fcf9b9a1e7c76f1..be7a35257d19260cb76a351902f9d905345be654 100644
--- a/requirements/torch_macos.txt
+++ b/requirements/torch_macos.txt
@@ -1,7 +1,2 @@
 --find-links https://download.pytorch.org/whl/torch_stable.html
-torch==1.11
---find-links https://data.pyg.org/whl/torch-1.11.0+cpu.html
-torch-cluster==1.6.0
-torch_scatter==2.0.9
-torch-sparse==0.6.13
-torch_geometric==2.0.4
\ No newline at end of file
+--find-links https://data.pyg.org/whl/torch-2.0.0+cpu.html
\ No newline at end of file
diff --git a/setup.py b/setup.py
index b262f0fa4502ab40d28e6e88cf930cff0ec35e89..67cf250231f70549c9e596c02b0eed6d8c61eb39 100644
--- a/setup.py
+++ b/setup.py
@@ -47,12 +47,12 @@ EXTRAS_REQUIRE = {
         "versioneer",
     ],
     "torch": [
-        "torch>=1.11",
+        "torch>=2.0",
         "torch-cluster>=1.6",
         "torch-scatter>=2.0",
         "torch-sparse>=0.6",
-        "torch-geometric>=2.0",
-        "pytorch-lightning>=1.6, <2.0",
+        "torch-geometric>=2.1",
+        "pytorch-lightning>=2.0",
     ],
 }
 
@@ -61,7 +61,6 @@ CLASSIFIERS = [
     "Development Status :: 3 - Alpha",
     "Intended Audience :: Developers",
     "Intended Audience :: Science/Research",
-    "Programming Language :: Python :: 3.7",
     "Programming Language :: Python :: 3.8",
     "Programming Language :: Python :: 3.9",
     "Programming Language :: Python :: 3.10",
diff --git a/src/graphnet/data/dataloader.py b/src/graphnet/data/dataloader.py
index b199f5865afe89c0903cf23ac26f276d4a339f25..1ded6fa37d88b6fb32165bb405dbb90f6071a65f 100644
--- a/src/graphnet/data/dataloader.py
+++ b/src/graphnet/data/dataloader.py
@@ -34,6 +34,7 @@ class DataLoader(torch.utils.data.DataLoader):
         num_workers: int = 10,
         persistent_workers: bool = True,
         collate_fn: Callable = collate_fn,
+        prefetch_factor: int = 2,
         **kwargs: Any,
     ) -> None:
         """Construct `DataLoader`."""
@@ -45,7 +46,7 @@ class DataLoader(torch.utils.data.DataLoader):
             num_workers=num_workers,
             collate_fn=collate_fn,
             persistent_workers=persistent_workers,
-            prefetch_factor=2,
+            prefetch_factor=prefetch_factor,
             **kwargs,
         )
 
diff --git a/src/graphnet/models/model.py b/src/graphnet/models/model.py
index 5e95ae9174a927fc7842cfdbf4c9d075cc562ef8..00acf910910712628a3b6f2936b8952c0c80cda4 100644
--- a/src/graphnet/models/model.py
+++ b/src/graphnet/models/model.py
@@ -29,8 +29,8 @@ class Model(Logger, Configurable, LightningModule, ABC):
     def forward(self, x: Union[Tensor, Data]) -> Union[Tensor, Data]:
         """Forward pass."""
 
-    def _construct_trainers(
-        self,
+    @staticmethod
+    def _construct_trainer(
         max_epochs: int = 10,
         gpus: Optional[Union[List[int], int]] = None,
         callbacks: Optional[List[Callback]] = None,
@@ -40,16 +40,16 @@ class Model(Logger, Configurable, LightningModule, ABC):
         gradient_clip_val: Optional[float] = None,
         distribution_strategy: Optional[str] = "ddp",
         **trainer_kwargs: Any,
-    ) -> None:
+    ) -> Trainer:
 
         if gpus:
             accelerator = "gpu"
             devices = gpus
         else:
             accelerator = "cpu"
-            devices = None
+            devices = 1
 
-        self._trainer = Trainer(
+        trainer = Trainer(
             accelerator=accelerator,
             devices=devices,
             max_epochs=max_epochs,
@@ -58,21 +58,11 @@ class Model(Logger, Configurable, LightningModule, ABC):
             logger=logger,
             gradient_clip_val=gradient_clip_val,
             strategy=distribution_strategy,
+            default_root_dir=ckpt_path,
             **trainer_kwargs,
         )
 
-        inference_devices = devices
-        if isinstance(inference_devices, list):
-            inference_devices = inference_devices[:1]
-
-        self._inference_trainer = Trainer(
-            accelerator=accelerator,
-            devices=inference_devices,
-            callbacks=callbacks,
-            logger=logger,
-            strategy=None,
-            **trainer_kwargs,
-        )
+        return trainer
 
     def fit(
         self,
@@ -101,7 +91,7 @@ class Model(Logger, Configurable, LightningModule, ABC):
             )
 
         self.train(mode=True)
-        self._construct_trainers(
+        trainer = self._construct_trainer(
             max_epochs=max_epochs,
             gpus=gpus,
             callbacks=callbacks,
@@ -114,7 +104,7 @@ class Model(Logger, Configurable, LightningModule, ABC):
         )
 
         try:
-            self._trainer.fit(
+            trainer.fit(
                 self, train_dataloader, val_dataloader, ckpt_path=ckpt_path
             )
         except KeyboardInterrupt:
@@ -157,7 +147,7 @@ class Model(Logger, Configurable, LightningModule, ABC):
         self,
         dataloader: DataLoader,
         gpus: Optional[Union[List[int], int]] = None,
-        distribution_strategy: Optional[str] = None,
+        distribution_strategy: Optional[str] = "auto",
     ) -> List[Tensor]:
         """Return predictions for `dataloader`.
 
@@ -165,17 +155,17 @@ class Model(Logger, Configurable, LightningModule, ABC):
         """
         self.train(mode=False)
 
-        if not hasattr(self, "_inference_trainer"):
-            self._construct_trainers(
-                gpus=gpus, distribution_strategy=distribution_strategy
-            )
-        elif gpus is not None:
-            self.warning(
-                "A `Trainer` instance has already been constructed, possibly "
-                "when the model was trained. Will use this to get predictions. "
-                f"Argument `gpus = {gpus}` will be ignored."
-            )
-        predictions_list = self._inference_trainer.predict(self, dataloader)
+        callbacks = self._create_default_callbacks(
+            val_dataloader=None,
+        )
+
+        inference_trainer = self._construct_trainer(
+            gpus=gpus,
+            distribution_strategy=distribution_strategy,
+            callbacks=callbacks,
+        )
+
+        predictions_list = inference_trainer.predict(self, dataloader)
         assert len(predictions_list), "Got no predictions"
 
         nb_outputs = len(predictions_list[0])
@@ -195,7 +185,7 @@ class Model(Logger, Configurable, LightningModule, ABC):
         additional_attributes: Optional[List[str]] = None,
         index_column: str = "event_no",
         gpus: Optional[Union[List[int], int]] = None,
-        distribution_strategy: Optional[str] = None,
+        distribution_strategy: Optional[str] = "auto",
     ) -> pd.DataFrame:
         """Return predictions for `dataloader` as a DataFrame.
 
diff --git a/src/graphnet/models/standard_model.py b/src/graphnet/models/standard_model.py
index 41b70bb269f49ffede804dc7b542be27676787bb..844f4f55b91229d7dcb6e2ad179e461174a8d26b 100644
--- a/src/graphnet/models/standard_model.py
+++ b/src/graphnet/models/standard_model.py
@@ -179,7 +179,7 @@ class StandardModel(Model):
         self,
         dataloader: DataLoader,
         gpus: Optional[Union[List[int], int]] = None,
-        distribution_strategy: Optional[str] = None,
+        distribution_strategy: Optional[str] = "auto",
     ) -> List[Tensor]:
         """Return predictions for `dataloader`."""
         self.inference()
@@ -198,7 +198,7 @@ class StandardModel(Model):
         additional_attributes: Optional[List[str]] = None,
         index_column: str = "event_no",
         gpus: Optional[Union[List[int], int]] = None,
-        distribution_strategy: Optional[str] = None,
+        distribution_strategy: Optional[str] = "auto",
     ) -> pd.DataFrame:
         """Return predictions for `dataloader` as a DataFrame.
 
diff --git a/src/graphnet/training/callbacks.py b/src/graphnet/training/callbacks.py
index 04d30933f2bf29102592ae28ccfadd25f8f77990..a66255ca691d9dafb83bab8ea382bc22e5d8effa 100644
--- a/src/graphnet/training/callbacks.py
+++ b/src/graphnet/training/callbacks.py
@@ -123,12 +123,12 @@ class ProgressBar(TQDMProgressBar):
         lightning is to overwrite the progress bar from previous epochs.
         """
         if trainer.current_epoch > 0:
-            self.main_progress_bar.set_postfix(
+            self.train_progress_bar.set_postfix(
                 self.get_metrics(trainer, model)
             )
             print("")
         super().on_train_epoch_start(trainer, model)
-        self.main_progress_bar.set_description(
+        self.train_progress_bar.set_description(
             f"Epoch {trainer.current_epoch:2d}"
         )
 
@@ -150,5 +150,5 @@ class ProgressBar(TQDMProgressBar):
             assert isinstance(h, logging.StreamHandler)
             level = h.level
             h.setLevel(logging.ERROR)
-            logger.info(str(super().main_progress_bar))
+            logger.info(str(super().train_progress_bar))
             h.setLevel(level)