Train a hybrid model - MLP/LSTM

In this tutorial, we demonstrate the steps to train a hybrid model that uses a multilayer perceptron (MLP) or a long short-term memory network (LSTM) as its neural-network component. The training data is derived from the pre-processed outputs generated by running the conceptual model. The hybrid model is trained to predict runoff \(Q\) by fitting its output to the observed runoff.

IMPORTANT: Although this tutorial uses a file listing two basins, it does not demonstrate a multi-basin model. Instead, a separate single-basin model is trained for each basin.

Before we start

  • This tutorial is rendered from a Jupyter notebook that is hosted on GitHub. If you’d like to run the code yourself, you can access the notebook and configuration files directly from the repository: 03-TrainHybridModel.

  • To run this notebook locally, ensure you have completed the setup steps outlined in Getting started. These steps include setting up the environment, installing the required packages, and preparing the data files necessary for the tutorial.

  • Dependency on a previous tutorial: Before running this tutorial, you must complete the 01-RunConceptModel tutorial. After completing it:

    1- Move the generated run folder to src/data.

    2- Update the data_dir field in config_run_train_mlp.yml (or config_run_train_lstm.yml) to point to this folder.
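If you prefer to script the two steps above, the sketch below moves a run folder under src/data and returns a copy of the configuration with data_dir pointing at it. The helper names relocate_run_folder and with_data_dir are hypothetical, and storing data_dir as a path relative to the working directory is an assumption; check which form your configuration expects.

```python
import shutil
from pathlib import Path


def relocate_run_folder(run_folder, data_root="src/data"):
    """Move a finished conceptual-model run folder into data_root.

    Returns the new location. Refuses to overwrite an existing folder with
    the same name instead of silently nesting or replacing it.
    """
    run_folder = Path(run_folder)
    data_root = Path(data_root)
    if not run_folder.is_dir():
        raise FileNotFoundError(f"Run folder {run_folder} not found")
    data_root.mkdir(parents=True, exist_ok=True)
    target = data_root / run_folder.name
    if target.exists():
        raise FileExistsError(f"{target} already exists")
    shutil.move(str(run_folder), str(target))
    return target


def with_data_dir(cfg, run_dir):
    """Return a shallow copy of cfg whose 'data_dir' entry points at run_dir."""
    updated = dict(cfg)
    updated["data_dir"] = str(run_dir)
    return updated
```

To persist the change, load the YAML file with yaml.safe_load, pass the result through with_data_dir, and write it back with yaml.dump. Note that yaml.dump does not preserve comments from the original file, so editing data_dir by hand may be preferable for a one-off change.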

Import packages

[13]:
%reload_ext autoreload
%autoreload 2

import sys
from pathlib import Path

# Dynamically set the project directory based on the notebook's location
notebook_dir = Path().resolve()
project_dir = str(notebook_dir.parent.parent)  # Adjust based on your project structure
sys.path.append(project_dir)

import os
import yaml

from src.thn_run import (
    _load_cfg_and_ds,
    get_basin_interpolators
)

from src.modelzoo_concept import get_concept_model
from src.modelzoo_nn import (
    get_nn_model,
    get_nn_pretrainer,
)
from src.modelzoo_hybrid import (
    get_hybrid_model,
    get_trainer,
)
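The cell above assumes the notebook sits exactly two levels below the project root (notebook_dir.parent.parent). If you move the notebook, a sketch like the following searches upward for a marker directory instead. Using the presence of a src directory as the marker is an assumption, and find_project_root and add_to_sys_path are hypothetical helpers, not part of the repository.

```python
import sys
from pathlib import Path


def find_project_root(start=None, marker="src"):
    """Return the closest directory at or above `start` that contains `marker`.

    Raises FileNotFoundError if no ancestor contains the marker.
    """
    start = (Path.cwd() if start is None else Path(start)).resolve()
    for candidate in (start, *start.parents):
        if (candidate / marker).exists():
            return candidate
    raise FileNotFoundError(f"No '{marker}' found at or above {start}")


def add_to_sys_path(path):
    """Prepend path to sys.path once, so `import src...` resolves to this project."""
    path = str(path)
    if path not in sys.path:
        sys.path.insert(0, path)
```

Prepending rather than appending gives the project's src package priority over any installed package with the same name; the original cell appends, which works as long as no such conflict exists.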

Constants

Feel free to run and explore both nn_type = 'mlp' and nn_type = 'lstm'; the matching configuration file is selected from nn_type.

[14]:
nn_type = 'mlp'
# nn_type = 'lstm'

config_file = f'config_run_train_{nn_type}.yml'
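A typo in nn_type would only surface later as a FileNotFoundError for a configuration file that does not exist. A small guard like the one below fails earlier with a clearer message. config_file_for is a hypothetical helper; the two supported values are the ones used in this tutorial.

```python
from pathlib import Path

SUPPORTED_NN_TYPES = ("mlp", "lstm")  # the two options used in this tutorial


def config_file_for(nn_type, config_dir="."):
    """Return the training config path for nn_type, validating the name first."""
    nn_type = nn_type.strip().lower()
    if nn_type not in SUPPORTED_NN_TYPES:
        raise ValueError(
            f"nn_type must be one of {SUPPORTED_NN_TYPES}, got {nn_type!r}"
        )
    return Path(config_dir) / f"config_run_train_{nn_type}.yml"
```

The returned Path works anywhere the tutorial uses config_file, since open() and Path() both accept Path objects.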

Load main config file

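This step is essential when running multiple single-basin models: the main configuration file provides the full basin list, which the training loop then splits into one single-basin run per basin. Refer to src/scripts_paper/run_hybrid_trainer_single_all_mlp.py for a script that runs this workflow across all basins. A parallelized version of the code demonstrated in this tutorial is also available for more efficient execution.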
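The parallelized version mentioned above is not reproduced here. To illustrate only the general idea, the sketch below fans per-basin runs out over worker processes with concurrent.futures and records failures per basin instead of aborting the whole batch. The worker argument stands for a hypothetical train_one_basin(basin) function wrapping the body of the training loop; it must be defined at module level so it can be pickled for worker processes.

```python
from concurrent.futures import ProcessPoolExecutor, as_completed


def run_basins_in_parallel(basins, worker, max_workers=2,
                           executor_cls=ProcessPoolExecutor):
    """Run worker(basin) for every basin concurrently.

    Returns (results, errors): dicts keyed by basin ID holding either the
    worker's return value or the exception it raised, so one failing basin
    does not stop the others.
    """
    results, errors = {}, {}
    with executor_cls(max_workers=max_workers) as pool:
        futures = {pool.submit(worker, basin): basin for basin in basins}
        for future in as_completed(futures):
            basin = futures[future]
            try:
                results[basin] = future.result()
            except Exception as exc:  # record the failure and keep going
                errors[basin] = exc
    return results, errors
```

Because the training loop names its temporary files after the basin and nn_type, parallel workers should not overwrite each other's files. If each worker trains on the CPU (the .pth weights in the log suggest PyTorch), limit threads per worker, for example with torch.set_num_threads inside the worker, so that workers do not oversubscribe the cores. Under the spawn start method (the default on macOS and Windows), functions defined directly in a notebook cannot be pickled for worker processes, so put the worker in a .py module.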

[15]:
# Load the MAIN configuration file
if Path(config_file).exists():
    with open(config_file, 'r') as f:
        cfg = yaml.safe_load(f)
else:
    raise FileNotFoundError(f'Configuration file {config_file} not found!')

# Read the basin list (skip blank lines; keep IDs as strings to preserve leading zeros)
with open(cfg['basin_file'], 'r') as f:
    all_basins = [line.strip() for line in f if line.strip()]

print(all_basins)
['01013500', '06431500']

Train hybrid model for each basin

[16]:
for basin in all_basins:

    # Temporary basin configuration file
    basin_file = f'temp_basin_{basin}_{nn_type}.txt'
    with open(basin_file, 'w') as f:
        f.write(basin)

    # Point the configuration at the single-basin file
    cfg['basin_file'] = basin_file

    # Write a temporary configuration file, e.g. config_run_train_mlp_temp_mlp_01013500.yml
    config_path = Path(config_file)
    config_file_temp = str(config_path.with_name(f'{config_path.stem}_temp_{nn_type}_{basin}.yml'))
    with open(config_file_temp, 'w') as f:
        yaml.dump(cfg, f)

    # Load the configuration file and dataset
    cfg_run, dataset = _load_cfg_and_ds(
        Path(config_file_temp), model='hybrid')

    # Delete the temporary files; the configuration and dataset are already loaded
    if os.path.isfile(basin_file):
        os.remove(basin_file)
    if os.path.isfile(config_file_temp):
        os.remove(config_file_temp)

    # Get the basin interpolators
    interpolators = get_basin_interpolators(dataset, cfg_run, project_dir)

    # Conceptual model
    time_idx0 = 0
    model_concept = get_concept_model(cfg_run, dataset.ds_train,
                                      interpolators, time_idx0,
                                      dataset.scaler)

    # Neural network model
    model_nn = get_nn_model(model_concept, dataset.ds_static)

    # Pretrainer
    pretrainer = get_nn_pretrainer(model_nn, dataset)

    # Pretrain the model
    pretrain_ok = pretrainer.train(
        loss=cfg_run.loss_pretrain,
        lr=cfg_run.lr_pretrain,
        epochs=cfg_run.epochs_pretrain,
        disable_pbar=False,
        any_log=False,
    )

    # Train the hybrid model
    if pretrain_ok:
        # Build the hybrid model
        model_hybrid = get_hybrid_model(cfg_run, pretrainer, dataset)

        # Build the trainer
        trainer = get_trainer(model_hybrid)

        # Train the model
        trainer.train()
    else:
        print(f'Pretraining failed for basin {basin}')

-- Loading the config file and the dataset
-- Using device: cpu --
Setting seed for reproducibility: 111
cfg.nn_model_dir is not defined - parameters MUST be defined in the config file
-- Loading basin dynamics into xarray data set.
100%|██████████| 1/1 [00:00<00:00, 19.78it/s]
------------------------------------------------------------
-- Pretraining the neural network model -- (cpu)
------------------------------------------------------------
# Epoch 00001: 100%|██████████| 29/29 [00:00<00:00, 154.14it/s, Loss=1.0589e+00]
* Plotting basin 01013500: 100%|██████████| 1/1 [00:00<00:00,  2.17it/s]
# Epoch 00002: 100%|██████████| 29/29 [00:00<00:00, 145.18it/s, Loss=3.4281e-01]
# Epoch 00003: 100%|██████████| 29/29 [00:00<00:00, 168.09it/s, Loss=1.8930e-01]
# Epoch 00004: 100%|██████████| 29/29 [00:00<00:00, 171.46it/s, Loss=1.3642e-01]
# Epoch 00005: 100%|██████████| 29/29 [00:00<00:00, 172.81it/s, Loss=1.1274e-01]
# Epoch 00006: 100%|██████████| 29/29 [00:00<00:00, 173.53it/s, Loss=8.4502e-02]
# Epoch 00007: 100%|██████████| 29/29 [00:00<00:00, 175.04it/s, Loss=8.5674e-02]
# Epoch 00008: 100%|██████████| 29/29 [00:00<00:00, 175.02it/s, Loss=4.5157e-02]
# Epoch 00009: 100%|██████████| 29/29 [00:00<00:00, 175.12it/s, Loss=3.9378e-02]
# Epoch 00010: 100%|██████████| 29/29 [00:00<00:00, 175.01it/s, Loss=3.3307e-02]
# Epoch 00011: 100%|██████████| 29/29 [00:00<00:00, 175.16it/s, Loss=3.2391e-02]
# Epoch 00012: 100%|██████████| 29/29 [00:00<00:00, 193.63it/s, Loss=2.7244e-02]
# Epoch 00013: 100%|██████████| 29/29 [00:00<00:00, 212.29it/s, Loss=2.6126e-02]
# Epoch 00014: 100%|██████████| 29/29 [00:00<00:00, 212.66it/s, Loss=1.8526e-02]
# Epoch 00015: 100%|██████████| 29/29 [00:00<00:00, 179.34it/s, Loss=3.0113e-02]
# Epoch 00016: 100%|██████████| 29/29 [00:00<00:00, 173.77it/s, Loss=2.6183e-02]
# Epoch 00017: 100%|██████████| 29/29 [00:00<00:00, 174.22it/s, Loss=2.6346e-02]
# Epoch 00018: 100%|██████████| 29/29 [00:00<00:00, 174.14it/s, Loss=2.0164e-02]
# Epoch 00019: 100%|██████████| 29/29 [00:00<00:00, 174.14it/s, Loss=1.7950e-02]
# Epoch 00020: 100%|██████████| 29/29 [00:00<00:00, 173.66it/s, Loss=1.5267e-02]
# Epoch 00021: 100%|██████████| 29/29 [00:00<00:00, 185.53it/s, Loss=1.5734e-02]
# Epoch 00022: 100%|██████████| 29/29 [00:00<00:00, 189.15it/s, Loss=1.8016e-02]
# Epoch 00023: 100%|██████████| 29/29 [00:00<00:00, 170.21it/s, Loss=2.8083e-02]
# Epoch 00024: 100%|██████████| 29/29 [00:00<00:00, 212.24it/s, Loss=2.7260e-02]
# Epoch 00025: 100%|██████████| 29/29 [00:00<00:00, 158.02it/s, Loss=1.5969e-02]
# Epoch 00026: 100%|██████████| 29/29 [00:00<00:00, 212.93it/s, Loss=1.1404e-02]
# Epoch 00027: 100%|██████████| 29/29 [00:00<00:00, 212.73it/s, Loss=1.4070e-02]
# Epoch 00028: 100%|██████████| 29/29 [00:00<00:00, 212.88it/s, Loss=1.1904e-02]
# Epoch 00029: 100%|██████████| 29/29 [00:00<00:00, 189.46it/s, Loss=1.8034e-02]
# Epoch 00030: 100%|██████████| 29/29 [00:00<00:00, 152.10it/s, Loss=2.9080e-02]
# Epoch 00031: 100%|██████████| 29/29 [00:00<00:00, 175.16it/s, Loss=2.1873e-02]
# Epoch 00032: 100%|██████████| 29/29 [00:00<00:00, 168.34it/s, Loss=2.0210e-02]
# Epoch 00033: 100%|██████████| 29/29 [00:00<00:00, 173.46it/s, Loss=2.6820e-02]
# Epoch 00034: 100%|██████████| 29/29 [00:00<00:00, 214.19it/s, Loss=1.6065e-02]
# Epoch 00035: 100%|██████████| 29/29 [00:00<00:00, 171.56it/s, Loss=1.4127e-02]
# Epoch 00036: 100%|██████████| 29/29 [00:00<00:00, 157.76it/s, Loss=1.8038e-02]
# Epoch 00037: 100%|██████████| 29/29 [00:00<00:00, 169.43it/s, Loss=1.8243e-02]
# Epoch 00038: 100%|██████████| 29/29 [00:00<00:00, 163.33it/s, Loss=1.2418e-02]
# Epoch 00039: 100%|██████████| 29/29 [00:00<00:00, 169.58it/s, Loss=1.2180e-02]
# Epoch 00040: 100%|██████████| 29/29 [00:00<00:00, 164.72it/s, Loss=1.1189e-02]
# Epoch 00041: 100%|██████████| 29/29 [00:00<00:00, 186.76it/s, Loss=1.3415e-02]
# Epoch 00042: 100%|██████████| 29/29 [00:00<00:00, 214.10it/s, Loss=1.7245e-02]
# Epoch 00043: 100%|██████████| 29/29 [00:00<00:00, 214.88it/s, Loss=1.9379e-02]
# Epoch 00044: 100%|██████████| 29/29 [00:00<00:00, 211.27it/s, Loss=1.3128e-02]
# Epoch 00045: 100%|██████████| 29/29 [00:00<00:00, 146.39it/s, Loss=1.5571e-02]
# Epoch 00046: 100%|██████████| 29/29 [00:00<00:00, 169.57it/s, Loss=1.2147e-02]
# Epoch 00047: 100%|██████████| 29/29 [00:00<00:00, 50.64it/s, Loss=1.8673e-02]
# Epoch 00048: 100%|██████████| 29/29 [00:00<00:00, 160.19it/s, Loss=1.9846e-02]
# Epoch 00049: 100%|██████████| 29/29 [00:00<00:00, 168.70it/s, Loss=1.3078e-02]
# Epoch 00050: 100%|██████████| 29/29 [00:00<00:00, 163.77it/s, Loss=1.6706e-02]
* Plotting basin 01013500: 100%|██████████| 1/1 [00:00<00:00,  2.15it/s]
* Plotting basin 01013500: 100%|██████████| 1/1 [00:02<00:00,  2.11s/it]
* Evaluating basin 01013500 (ds_train): 100%|██████████| 1/1 [00:00<00:00, 23.98it/s]
* Evaluating basin 01013500 (ds_valid): 100%|██████████| 1/1 [00:00<00:00, 42.52it/s]
------------------------------------------------------------
-- Training the hybrid model on cpu --
Initial learning rate: 1.00e-03
------------------------------------------------------------
# Epoch 00001 : 100%|██████████| 29/29 [00:20<00:00,  1.44it/s, Loss=-7.2505e-01]
-- Saving the basin plots (epoch 1) | --
* Plotting basin 01013500: 100%|██████████| 1/1 [00:13<00:00, 13.59s/it]
-- Best model updated at epoch 1 with loss -7.2505e-01
# Epoch 00002 : 100%|██████████| 29/29 [00:20<00:00,  1.45it/s, Loss=-7.8818e-01]
-- Best model updated at epoch 2 with loss -7.8818e-01
# Epoch 00003 : 100%|██████████| 29/29 [00:20<00:00,  1.44it/s, Loss=-8.3137e-01]
-- Best model updated at epoch 3 with loss -8.3137e-01
# Epoch 00004 : 100%|██████████| 29/29 [00:20<00:00,  1.45it/s, Loss=-8.3978e-01]
-- Best model updated at epoch 4 with loss -8.3978e-01
# Epoch 00005 : 100%|██████████| 29/29 [00:20<00:00,  1.44it/s, Loss=-8.5919e-01]
-- Best model updated at epoch 5 with loss -8.5919e-01
Learning rate updated from 1.00e-03 to 5.00e-04
# Epoch 00006 : 100%|██████████| 29/29 [00:19<00:00,  1.45it/s, Loss=-8.7086e-01]
-- Best model updated at epoch 6 with loss -8.7086e-01
# Epoch 00007 : 100%|██████████| 29/29 [00:20<00:00,  1.43it/s, Loss=-8.7791e-01]
-- Best model updated at epoch 7 with loss -8.7791e-01
# Epoch 00008 : 100%|██████████| 29/29 [00:19<00:00,  1.46it/s, Loss=-8.8585e-01]
-- Best model updated at epoch 8 with loss -8.8585e-01
# Epoch 00009 : 100%|██████████| 29/29 [00:20<00:00,  1.43it/s, Loss=-8.7632e-01]
# Epoch 00010 : 100%|██████████| 29/29 [00:19<00:00,  1.46it/s, Loss=-8.8353e-01]
-- Training completed | Evaluating the model --
Loaded best model from runs/run_train_hybrid_mlp_01013500_241202_204921/model_weights/trainer_exphydrom100_mlp_1basins.pth
* Evaluating basin 01013500 (ds_train): 100%|██████████| 1/1 [00:08<00:00,  8.99s/it]
* Evaluating basin 01013500 (ds_valid): 100%|██████████| 1/1 [00:04<00:00,  4.60s/it]
* Plotting basin 01013500: 100%|██████████| 1/1 [00:13<00:00, 13.63s/it]
-- Loading the config file and the dataset
-- Using device: cpu --
Setting seed for reproducibility: 111
cfg.nn_model_dir is not defined - parameters MUST be defined in the config file
-- Loading basin dynamics into xarray data set.
100%|██████████| 1/1 [00:00<00:00, 22.40it/s]
------------------------------------------------------------
-- Pretraining the neural network model -- (cpu)
------------------------------------------------------------
# Epoch 00001: 100%|██████████| 29/29 [00:00<00:00, 196.79it/s, Loss=9.2318e-01]
* Plotting basin 06431500: 100%|██████████| 1/1 [00:00<00:00,  2.21it/s]
# Epoch 00002: 100%|██████████| 29/29 [00:00<00:00, 198.48it/s, Loss=5.6483e-01]
# Epoch 00003: 100%|██████████| 29/29 [00:00<00:00, 209.55it/s, Loss=2.6139e-01]
# Epoch 00004: 100%|██████████| 29/29 [00:00<00:00, 209.46it/s, Loss=1.4373e-01]
# Epoch 00005: 100%|██████████| 29/29 [00:00<00:00, 183.16it/s, Loss=1.7507e-01]
# Epoch 00006: 100%|██████████| 29/29 [00:00<00:00, 171.53it/s, Loss=1.0512e-01]
# Epoch 00007: 100%|██████████| 29/29 [00:00<00:00, 171.93it/s, Loss=6.5881e-02]
# Epoch 00008: 100%|██████████| 29/29 [00:00<00:00, 171.84it/s, Loss=4.6967e-02]
# Epoch 00009: 100%|██████████| 29/29 [00:00<00:00, 171.77it/s, Loss=7.3533e-02]
# Epoch 00010: 100%|██████████| 29/29 [00:00<00:00, 171.74it/s, Loss=4.5364e-02]
# Epoch 00011: 100%|██████████| 29/29 [00:00<00:00, 171.97it/s, Loss=4.1242e-02]
# Epoch 00012: 100%|██████████| 29/29 [00:00<00:00, 171.70it/s, Loss=1.6208e-02]
# Epoch 00013: 100%|██████████| 29/29 [00:00<00:00, 171.84it/s, Loss=1.8649e-02]
# Epoch 00014: 100%|██████████| 29/29 [00:00<00:00, 171.84it/s, Loss=1.8316e-02]
# Epoch 00015: 100%|██████████| 29/29 [00:00<00:00, 171.79it/s, Loss=1.5166e-02]
# Epoch 00016: 100%|██████████| 29/29 [00:00<00:00, 171.67it/s, Loss=1.9659e-02]
# Epoch 00017: 100%|██████████| 29/29 [00:00<00:00, 172.03it/s, Loss=1.3842e-02]
# Epoch 00018: 100%|██████████| 29/29 [00:00<00:00, 172.03it/s, Loss=2.4619e-02]
# Epoch 00019: 100%|██████████| 29/29 [00:00<00:00, 155.12it/s, Loss=2.2544e-02]
# Epoch 00020: 100%|██████████| 29/29 [00:00<00:00, 146.50it/s, Loss=1.5178e-02]
# Epoch 00021: 100%|██████████| 29/29 [00:00<00:00, 145.66it/s, Loss=1.1820e-02]
# Epoch 00022: 100%|██████████| 29/29 [00:00<00:00, 142.52it/s, Loss=1.2231e-02]
# Epoch 00023: 100%|██████████| 29/29 [00:00<00:00, 205.93it/s, Loss=1.2024e-02]
# Epoch 00024: 100%|██████████| 29/29 [00:00<00:00, 206.32it/s, Loss=1.4387e-02]
# Epoch 00025: 100%|██████████| 29/29 [00:00<00:00, 205.22it/s, Loss=9.0759e-03]
# Epoch 00026: 100%|██████████| 29/29 [00:00<00:00, 205.75it/s, Loss=1.8350e-02]
# Epoch 00027: 100%|██████████| 29/29 [00:00<00:00, 193.76it/s, Loss=2.1324e-02]
# Epoch 00028: 100%|██████████| 29/29 [00:00<00:00, 204.98it/s, Loss=3.0382e-02]
# Epoch 00029: 100%|██████████| 29/29 [00:00<00:00, 205.71it/s, Loss=3.4781e-02]
# Epoch 00030: 100%|██████████| 29/29 [00:00<00:00, 201.55it/s, Loss=3.5706e-02]
# Epoch 00031: 100%|██████████| 29/29 [00:00<00:00, 147.46it/s, Loss=2.5784e-02]
# Epoch 00032: 100%|██████████| 29/29 [00:00<00:00, 160.57it/s, Loss=3.0937e-02]
# Epoch 00033: 100%|██████████| 29/29 [00:00<00:00, 208.36it/s, Loss=4.7720e-02]
# Epoch 00034: 100%|██████████| 29/29 [00:00<00:00, 194.34it/s, Loss=2.4296e-02]
# Epoch 00035: 100%|██████████| 29/29 [00:00<00:00, 149.93it/s, Loss=4.6494e-02]
# Epoch 00036: 100%|██████████| 29/29 [00:00<00:00, 159.44it/s, Loss=2.7136e-02]
# Epoch 00037: 100%|██████████| 29/29 [00:00<00:00, 155.29it/s, Loss=2.6448e-02]
# Epoch 00038: 100%|██████████| 29/29 [00:00<00:00, 164.62it/s, Loss=2.1001e-02]
# Epoch 00039: 100%|██████████| 29/29 [00:00<00:00, 159.75it/s, Loss=5.3858e-02]
# Epoch 00040: 100%|██████████| 29/29 [00:00<00:00, 160.78it/s, Loss=3.0169e-02]
# Epoch 00041: 100%|██████████| 29/29 [00:00<00:00, 162.99it/s, Loss=1.9988e-02]
# Epoch 00042: 100%|██████████| 29/29 [00:00<00:00, 162.71it/s, Loss=3.6176e-02]
# Epoch 00043: 100%|██████████| 29/29 [00:00<00:00, 164.67it/s, Loss=2.3878e-02]
# Epoch 00044: 100%|██████████| 29/29 [00:00<00:00, 163.59it/s, Loss=2.9262e-02]
# Epoch 00045: 100%|██████████| 29/29 [00:00<00:00, 157.78it/s, Loss=3.6939e-02]
# Epoch 00046: 100%|██████████| 29/29 [00:00<00:00, 157.90it/s, Loss=3.7963e-02]
# Epoch 00047: 100%|██████████| 29/29 [00:00<00:00, 168.81it/s, Loss=3.4369e-02]
# Epoch 00048: 100%|██████████| 29/29 [00:00<00:00, 161.98it/s, Loss=2.1830e-02]
# Epoch 00049: 100%|██████████| 29/29 [00:00<00:00, 161.07it/s, Loss=1.5596e-02]
# Epoch 00050: 100%|██████████| 29/29 [00:00<00:00, 167.28it/s, Loss=1.5110e-02]
* Plotting basin 06431500: 100%|██████████| 1/1 [00:00<00:00,  2.10it/s]
* Plotting basin 06431500: 100%|██████████| 1/1 [00:02<00:00,  2.49s/it]
* Evaluating basin 06431500 (ds_train): 100%|██████████| 1/1 [00:00<00:00, 22.55it/s]
* Evaluating basin 06431500 (ds_valid): 100%|██████████| 1/1 [00:00<00:00, 44.35it/s]
------------------------------------------------------------
-- Training the hybrid model on cpu --
Initial learning rate: 1.00e-03
------------------------------------------------------------
# Epoch 00001 : 100%|██████████| 29/29 [00:20<00:00,  1.45it/s, Loss=5.8657e+00]
-- Saving the basin plots (epoch 1) | --
* Plotting basin 06431500: 100%|██████████| 1/1 [00:13<00:00, 13.73s/it]
-- Best model updated at epoch 1 with loss 5.8657e+00
# Epoch 00002 : 100%|██████████| 29/29 [00:19<00:00,  1.46it/s, Loss=1.4855e+00]
-- Best model updated at epoch 2 with loss 1.4855e+00
# Epoch 00003 : 100%|██████████| 29/29 [00:19<00:00,  1.46it/s, Loss=1.4386e+00]
-- Best model updated at epoch 3 with loss 1.4386e+00
# Epoch 00004 : 100%|██████████| 29/29 [00:20<00:00,  1.43it/s, Loss=9.0867e-01]
-- Best model updated at epoch 4 with loss 9.0867e-01
# Epoch 00005 : 100%|██████████| 29/29 [00:20<00:00,  1.45it/s, Loss=5.5492e-01]
-- Best model updated at epoch 5 with loss 5.5492e-01
Learning rate updated from 1.00e-03 to 5.00e-04
# Epoch 00006 : 100%|██████████| 29/29 [00:20<00:00,  1.42it/s, Loss=4.2375e-01]
-- Best model updated at epoch 6 with loss 4.2375e-01
# Epoch 00007 : 100%|██████████| 29/29 [00:19<00:00,  1.45it/s, Loss=3.0190e-01]
-- Best model updated at epoch 7 with loss 3.0190e-01
# Epoch 00008 : 100%|██████████| 29/29 [00:20<00:00,  1.45it/s, Loss=4.4728e-01]
# Epoch 00009 : 100%|██████████| 29/29 [00:20<00:00,  1.43it/s, Loss=2.9819e-01]
-- Best model updated at epoch 9 with loss 2.9819e-01
# Epoch 00010 : 100%|██████████| 29/29 [00:20<00:00,  1.45it/s, Loss=3.6943e-01]
-- Training completed | Evaluating the model --
Loaded best model from runs/run_train_hybrid_mlp_06431500_241202_205336/model_weights/trainer_exphydrom100_mlp_1basins.pth
* Evaluating basin 06431500 (ds_train): 100%|██████████| 1/1 [00:08<00:00,  8.94s/it]
* Evaluating basin 06431500 (ds_valid): 100%|██████████| 1/1 [00:04<00:00,  4.64s/it]
* Plotting basin 06431500: 100%|██████████| 1/1 [00:13<00:00, 13.66s/it]
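The training loop deletes its temporary basin and configuration files only after _load_cfg_and_ds returns, so an exception during loading leaves them on disk. It also overwrites cfg['basin_file'] in the main configuration, which is harmless here only because the basin list was read beforehand. A context manager with try/finally is one way to guarantee cleanup and leave cfg untouched. single_basin_config is a hypothetical helper; the serializer is passed in (yaml.dump in the notebook) so the sketch itself does not depend on PyYAML.

```python
from contextlib import contextmanager
from pathlib import Path


@contextmanager
def single_basin_config(cfg, basin, tag, dump, workdir="."):
    """Yield the path of a temporary config that restricts training to `basin`.

    Writes a one-line basin file and a config copy whose 'basin_file' points
    at it, serializing with dump(data, stream). Both files are deleted on
    exit, even if the body raises. `cfg` itself is never modified.
    """
    workdir = Path(workdir)
    basin_file = workdir / f"temp_basin_{basin}_{tag}.txt"
    config_path = workdir / f"temp_config_{tag}_{basin}.yml"
    try:
        basin_file.write_text(basin)
        with open(config_path, "w") as f:
            dump({**cfg, "basin_file": str(basin_file)}, f)
        yield config_path
    finally:
        for path in (basin_file, config_path):
            path.unlink(missing_ok=True)  # missing_ok requires Python 3.8+


# Usage inside the training loop (yaml and _load_cfg_and_ds as imported above):
#     with single_basin_config(cfg, basin, nn_type, yaml.dump) as path:
#         cfg_run, dataset = _load_cfg_and_ds(path, model='hybrid')
```

With the default workdir, the basin file path stored in the temporary configuration is relative to the current working directory, as in the original loop.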

To inspect the evaluation and plotting steps in more detail, explore the evaluate and save_plots methods of the BaseHybridModelTrainer class (src/modelzoo_hybrid/basetrainer).
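The hybrid-training logs show two pieces of bookkeeping: the best model is updated only when the epoch loss improves (for basin 01013500, epochs 9 and 10 do not trigger an update), and the learning rate is halved from 1.00e-03 to 5.00e-04 after epoch 5 for both basins, even though the loss was still improving. The sketch below reproduces that behaviour, assuming lower loss is better and a fixed step schedule (halve every 5 epochs). Both rules are inferred from the log only; whether BaseHybridModelTrainer implements them exactly this way is not shown here, and step_lr and BestModelTracker are hypothetical names.

```python
import math


def step_lr(initial_lr, epoch, step_size=5, gamma=0.5):
    """Learning rate in effect during `epoch` (1-based) under step decay."""
    return initial_lr * gamma ** ((epoch - 1) // step_size)


class BestModelTracker:
    """Remember the epoch with the lowest loss seen so far."""

    def __init__(self):
        self.best_loss = math.inf
        self.best_epoch = None

    def update(self, epoch, loss):
        """Return True (and record the epoch) if `loss` beats the best so far."""
        if loss < self.best_loss:
            self.best_loss = loss
            self.best_epoch = epoch
            return True
        return False
```

In a real trainer, a True return from update would be the point to save the model weights. Replaying the logged epoch losses for basin 06431500 through BestModelTracker selects epoch 9, which is consistent with the "Loaded best model" line for that run.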