Pretrain a neural network
In this tutorial, we demonstrate the steps to pretrain the Neural Network (NN) component of the hybrid model. The training data is derived from the pre-processed outputs generated after running the conceptual model. The NN is trained to predict key quantities - \(ET\), \(M\), \(Pr\), \(Ps\), and \(Q\) - by matching the outputs of the conceptual model.
IMPORTANT: In this tutorial, we use a file listing 4 basins and demonstrate running a multi-basin model.
Before we start
This tutorial is rendered from a Jupyter notebook that is hosted on GitHub. If you’d like to run the code yourself, you can access the notebook and configuration files directly from the repository: 02-PretrainNNmodel.
To run this notebook locally, ensure you have completed the setup steps outlined in Getting started. These steps include setting up the environment, installing the required packages, and preparing the data files necessary for the tutorial.
Dependency on a Previous Tutorial: Before running this tutorial, you must complete the 01-RunConceptModel Tutorial. After completing it:
1- Move the generated run folder to src/data.
2- Update the data_dir field in the config_run_pretrain.yml file to point to this folder.
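The two steps above can also be scripted. The helper below is a hypothetical sketch that is not part of the project's src package. It uses only the standard library: it moves a finished run folder into src/data and rewrites the data_dir line of the YAML config. It assumes that data_dir is a top-level key written on a single line (for example data_dir: some/path). If your config is laid out differently, edit the file by hand or with a YAML library instead.

```python
import re
import shutil
from pathlib import Path


def relocate_run_and_update_config(run_dir: Path, data_root: Path, config_path: Path) -> Path:
    """Move run_dir into data_root and point the config's data_dir at its new location.

    Hypothetical helper: assumes data_dir is a top-level YAML key written on one
    line, e.g. "data_dir: some/path".
    """
    # Validate the config first so nothing is moved if it cannot be updated
    text = config_path.read_text()
    if len(re.findall(r"(?m)^data_dir:", text)) != 1:
        raise ValueError("expected exactly one top-level data_dir entry in the config")

    # Step 1: move the run folder (shutil.move returns the destination path)
    data_root.mkdir(parents=True, exist_ok=True)
    new_location = Path(shutil.move(str(run_dir), str(data_root / run_dir.name)))

    # Step 2: rewrite only the data_dir line and leave the rest of the file untouched
    new_line = f"data_dir: {new_location.as_posix()}"
    config_path.write_text(re.sub(r"(?m)^data_dir:.*$", lambda _match: new_line, text))
    return new_location
```

Call it with the run folder produced by 01-RunConceptModel, the src/data folder of your checkout and config_run_pretrain.yml. The path is written to the config exactly as it is built from data_root. If your config expects a relative data_dir, pass a relative data_root.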
Import packages
[1]:
%reload_ext autoreload
%autoreload 2
import sys
from pathlib import Path
# Dynamically set the project directory based on the notebook's location
notebook_dir = Path().resolve()
project_dir = str(notebook_dir.parent.parent) # Adjust based on your project structure
sys.path.append(project_dir)
from src.thn_run import (
_load_cfg_and_ds,
get_basin_interpolators
)
from src.modelzoo_concept import get_concept_model
from src.modelzoo_nn import (
get_nn_model,
get_nn_pretrainer,
)
Constants
[2]:
config_file = 'config_run_pretrain.yml'
Load config file and prepare dataset
[3]:
cfg, dataset = _load_cfg_and_ds(Path(config_file), model='pretrainer')
-- Loading the config file and the dataset
-- Using device: cpu --
Setting seed for reproducibility: 1111
-- Loading basin dynamics into xarray data set.
100%|██████████| 4/4 [00:00<00:00, 21.40it/s]
[4]:
cfg._cfg['experiment_name']
[4]:
'run_pretrain_nn_model'
A folder has been created in the runs directory with the name specified as experiment_name in the configuration file, appended with a YYMMDD_HHMMSS timestamp. This folder will contain the configuration, results, plots, and metrics associated with the run.
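The naming rule can be illustrated with a short sketch. It assumes that the timestamp is joined to experiment_name with an underscore; check a folder under runs to confirm. run_folder_name and run_timestamp are hypothetical helpers and are not functions of the project.

```python
from datetime import datetime
from typing import Optional


def run_folder_name(experiment_name: str, when: Optional[datetime] = None) -> str:
    """Return '<experiment_name>_<YYMMDD_HHMMSS>' for the given (or current) time."""
    if when is None:
        when = datetime.now()
    return f"{experiment_name}_{when:%y%m%d_%H%M%S}"


def run_timestamp(folder_name: str) -> datetime:
    """Recover the creation time from the last two '_'-separated fields of a folder name."""
    return datetime.strptime("_".join(folder_name.rsplit("_", 2)[-2:]), "%y%m%d_%H%M%S")


print(run_folder_name("run_pretrain_nn_model", datetime(2025, 1, 31, 9, 5, 7)))
# run_pretrain_nn_model_250131_090507
```

The timestamp is fixed-width, so sorting run folders by name also sorts them chronologically.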
[5]:
# Dataset attributes
dataset.__dict__.keys()
[5]:
dict_keys(['cfg', 'is_train', '_compute_scaler', 'scaler', 'basins', '_disable_pbar', '_per_basin_target_stds', '_dates', 'start_and_end_dates', 'num_samples', 'period_starts', 'alias_map', 'alias_map_clean', 'ds_train', 'ds_valid', 'ds_static'])
[6]:
display(
'basins', dataset.basins,
'start_and_end_dates', dataset.start_and_end_dates,
'ds_train', dataset.ds_train,
'ds_valid', dataset.ds_valid,
)
'basins'
['01013500', '01022500', '01030500', '06431500']
'start_and_end_dates'
{'train': {'start_date': Timestamp('1980-10-01 00:00:00'),
'end_date': Timestamp('2000-09-30 00:00:00')},
'valid': {'start_date': Timestamp('2000-10-01 00:00:00'),
'end_date': Timestamp('2010-09-30 00:00:00')}}
'ds_train'
<xarray.Dataset> Size: 2MB
Dimensions: (basin: 4, date: 7305)
Coordinates:
* date (date) datetime64[ns] 58kB 1980-10-01 1980-10-02 ... 2000-09-30
* basin (basin) <U8 128B '01013500' '01022500' '01030500' '06431500'
Data variables: (12/15)
dayl (basin, date) float32 117kB 11.33 11.28 11.23 ... 11.52 11.42
et_bucket (basin, date) float32 117kB 0.8673 1.153 1.251 ... 1.213 1.155
m_bucket (basin, date) float32 117kB 1.192e-07 1.192e-07 ... 0.1266
obs_runoff (basin, date) float32 117kB 0.551 0.5607 0.5586 ... 0.441 0.4353
pr_bucket (basin, date) float32 117kB 3.1 4.24 ... 1.192e-07 1.192e-07
prcp (basin, date) float32 117kB 3.1 4.24 8.02 15.27 ... 0.0 0.0 0.0
... ...
s_water (basin, date) float32 117kB 1.303e+03 1.305e+03 ... 984.9 983.5
srad (basin, date) float32 117kB 192.6 206.3 165.4 ... 350.7 303.8
tmax (basin, date) float32 117kB 10.05 15.82 15.86 ... 22.11 19.83
tmean (basin, date) float32 117kB 6.08 10.53 11.84 ... 13.41 12.77
tmin (basin, date) float32 117kB 2.11 5.24 7.81 ... 1.95 4.71 5.72
vp (basin, date) float32 117kB 711.3 898.6 ... 755.6 868.9
'ds_valid'
<xarray.Dataset> Size: 906kB
Dimensions: (basin: 4, date: 3652)
Coordinates:
* date (date) datetime64[ns] 29kB 2000-10-01 2000-10-02 ... 2010-09-30
* basin (basin) <U8 128B '01013500' '01022500' '01030500' '06431500'
Data variables: (12/15)
dayl (basin, date) float32 58kB 11.33 11.28 11.23 ... 11.52 11.52
et_bucket (basin, date) float32 58kB 1.485 1.449 1.463 ... 1.703 1.2 1.113
m_bucket (basin, date) float32 58kB 1.192e-07 1.192e-07 ... 0.1205 0.1076
obs_runoff (basin, date) float32 58kB 0.1158 0.1126 ... 0.4525 0.4525
pr_bucket (basin, date) float32 58kB 1.192e-07 1.192e-07 ... 1.192e-07
prcp (basin, date) float32 58kB 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0
... ...
s_water (basin, date) float32 58kB 1.504e+03 1.502e+03 ... 1.047e+03
srad (basin, date) float32 58kB 327.8 331.3 314.3 ... 352.8 363.1
tmax (basin, date) float32 58kB 20.31 20.4 20.0 ... 26.98 20.26 19.66
tmean (basin, date) float32 58kB 12.23 11.92 12.17 ... 12.22 11.05
tmin (basin, date) float32 58kB 4.16 3.44 4.34 ... 8.61 4.18 2.44
vp (basin, date) float32 58kB 822.2 783.9 827.6 ... 810.6 710.9
Feel free to generate plots from the training and validation sets to get familiar with the data.
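The sizes in the dataset summaries can be checked quickly. The date dimension counts both the start and the end date. The training period therefore has 7305 days, which includes the leap days of 1984, 1988, 1992, 1996 and 2000, and the validation period has 3652 days. Each data variable holds 4 basins × days × 4 bytes (float32), which gives roughly the 117kB and 58kB reported.

```python
from datetime import date


def inclusive_days(start: date, end: date) -> int:
    """Number of daily time steps from start to end, both ends included."""
    return (end - start).days + 1


n_basins, bytes_per_value = 4, 4  # 4 basins, float32 values
for period, (start, end) in {
    "train": (date(1980, 10, 1), date(2000, 9, 30)),
    "valid": (date(2000, 10, 1), date(2010, 9, 30)),
}.items():
    n_days = inclusive_days(start, end)
    print(period, n_days, "days;", n_basins * n_days * bytes_per_value, "bytes per data variable")
# train 7305 days; 116880 bytes per data variable
# valid 3652 days; 58432 bytes per data variable
```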
Create interpolators
Because the time-series data are loaded at daily resolution, the forcings must be interpolated while the system of ODEs is being solved. Both adaptive-step methods and fixed-step methods with a step finer than one day evaluate the model at times that fall between the daily samples.
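The following minimal sketch shows why interpolation is needed. Even a fixed-step Runge-Kutta (RK4) solver evaluates the right-hand side at the midpoint t + h/2. With a sub-daily step, the solver therefore asks for the forcing at times that do not exist in the daily series. Linear interpolation is used here only for illustration: get_basin_interpolators may use a different scheme. The linear-reservoir ODE is a hypothetical toy and is not the project's conceptual model. The forcing values are the first four daily prcp values of basin 01013500 shown above.

```python
import bisect


class LinearInterpolator:
    """Piecewise-linear interpolation of a daily series at any time t (in days)."""

    def __init__(self, times, values):
        self.times = list(times)
        self.values = list(values)

    def __call__(self, t: float) -> float:
        # Hold the first/last value outside the sampled range
        if t <= self.times[0]:
            return self.values[0]
        if t >= self.times[-1]:
            return self.values[-1]
        i = bisect.bisect_right(self.times, t)  # times[i-1] <= t < times[i]
        t0, t1 = self.times[i - 1], self.times[i]
        v0, v1 = self.values[i - 1], self.values[i]
        return v0 + (v1 - v0) * (t - t0) / (t1 - t0)


def rk4_step(f, t, y, h):
    """One classical Runge-Kutta step; k2 and k3 need f at the midpoint t + h/2."""
    k1 = f(t, y)
    k2 = f(t + h / 2, y + h / 2 * k1)
    k3 = f(t + h / 2, y + h / 2 * k2)
    k4 = f(t + h, y + h * k3)
    return y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)


# Daily forcing sampled at t = 0, 1, 2, 3 days
precip = LinearInterpolator([0.0, 1.0, 2.0, 3.0], [3.1, 4.24, 8.02, 15.27])


def toy_storage_rhs(t, s, k=0.1):
    """Hypothetical linear reservoir dS/dt = P(t) - k*S (not the project's model)."""
    return precip(t) - k * s


s, t, h = 0.0, 0.0, 0.25  # quarter-day steps: forcing is needed at t = 0.125, 0.25, ...
while t < 3.0:
    s = rk4_step(toy_storage_rhs, t, s, h)
    t += h
print(precip(0.5))  # halfway between 3.1 and 4.24, i.e. about 3.67
```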
[7]:
# Get the basin interpolators
interpolators = get_basin_interpolators(dataset, cfg, project_dir)
Create the conceptual model
[8]:
time_idx0 = 0 # Start from the first time index - 0 for training
model_concept = get_concept_model(cfg, dataset.ds_train, interpolators,
time_idx0, dataset.scaler)
Create the neural network model
[9]:
model_nn = get_nn_model(model_concept, dataset.ds_static)
Create the pretrainer model
[10]:
pretrainer = get_nn_pretrainer(model_nn, dataset)
Pretrain the model
[11]:
pretrain_ok = pretrainer.train(loss=cfg.loss_pretrain, lr=cfg.lr_pretrain,
epochs=cfg.epochs_pretrain,
disable_pbar=False,
any_log=False)
------------------------------------------------------------
-- Pretraining the neural network model -- (cpu)
------------------------------------------------------------
# Epoch 00001: 100%|██████████| 116/116 [00:00<00:00, 175.79it/s, Loss=1.0095e+00]
* Plotting basin 01030500: 100%|██████████| 4/4 [00:02<00:00, 1.90it/s]
# Epoch 00002: 100%|██████████| 116/116 [00:00<00:00, 138.03it/s, Loss=3.6343e-01]
# Epoch 00003: 100%|██████████| 116/116 [00:00<00:00, 161.53it/s, Loss=3.0324e-01]
# Epoch 00004: 100%|██████████| 116/116 [00:00<00:00, 146.59it/s, Loss=2.6700e-01]
# Epoch 00005: 100%|██████████| 116/116 [00:00<00:00, 156.76it/s, Loss=2.3470e-01]
# Epoch 00006: 100%|██████████| 116/116 [00:00<00:00, 163.43it/s, Loss=2.1603e-01]
# Epoch 00007: 100%|██████████| 116/116 [00:00<00:00, 150.90it/s, Loss=2.0421e-01]
# Epoch 00008: 100%|██████████| 116/116 [00:00<00:00, 152.84it/s, Loss=2.0980e-01]
# Epoch 00009: 100%|██████████| 116/116 [00:00<00:00, 173.21it/s, Loss=2.0567e-01]
# Epoch 00010: 100%|██████████| 116/116 [00:00<00:00, 166.42it/s, Loss=2.1032e-01]
# Epoch 00011: 100%|██████████| 116/116 [00:00<00:00, 179.96it/s, Loss=2.1279e-01]
# Epoch 00012: 100%|██████████| 116/116 [00:00<00:00, 149.49it/s, Loss=2.1788e-01]
# Epoch 00013: 100%|██████████| 116/116 [00:00<00:00, 201.65it/s, Loss=2.1850e-01]
# Epoch 00014: 100%|██████████| 116/116 [00:00<00:00, 204.54it/s, Loss=2.1837e-01]
# Epoch 00015: 100%|██████████| 116/116 [00:00<00:00, 205.26it/s, Loss=2.1834e-01]
# Epoch 00016: 100%|██████████| 116/116 [00:00<00:00, 206.15it/s, Loss=2.1539e-01]
# Epoch 00017: 100%|██████████| 116/116 [00:00<00:00, 201.06it/s, Loss=2.1600e-01]
# Epoch 00018: 100%|██████████| 116/116 [00:00<00:00, 188.47it/s, Loss=2.1032e-01]
# Epoch 00019: 100%|██████████| 116/116 [00:00<00:00, 207.30it/s, Loss=2.0384e-01]
# Epoch 00020: 100%|██████████| 116/116 [00:00<00:00, 200.19it/s, Loss=1.9702e-01]
# Epoch 00021: 100%|██████████| 116/116 [00:00<00:00, 138.75it/s, Loss=1.9039e-01]
# Epoch 00022: 100%|██████████| 116/116 [00:01<00:00, 103.36it/s, Loss=1.8611e-01]
# Epoch 00023: 100%|██████████| 116/116 [00:00<00:00, 130.51it/s, Loss=1.8147e-01]
# Epoch 00024: 100%|██████████| 116/116 [00:00<00:00, 130.23it/s, Loss=1.7528e-01]
# Epoch 00025: 100%|██████████| 116/116 [00:00<00:00, 134.12it/s, Loss=1.7129e-01]
# Epoch 00026: 100%|██████████| 116/116 [00:00<00:00, 133.02it/s, Loss=1.6737e-01]
# Epoch 00027: 100%|██████████| 116/116 [00:00<00:00, 129.72it/s, Loss=1.6380e-01]
# Epoch 00028: 100%|██████████| 116/116 [00:00<00:00, 131.77it/s, Loss=1.6077e-01]
# Epoch 00029: 100%|██████████| 116/116 [00:00<00:00, 130.17it/s, Loss=1.5746e-01]
# Epoch 00030: 100%|██████████| 116/116 [00:00<00:00, 133.09it/s, Loss=1.5431e-01]
# Epoch 00031: 100%|██████████| 116/116 [00:00<00:00, 129.34it/s, Loss=1.5180e-01]
# Epoch 00032: 100%|██████████| 116/116 [00:01<00:00, 102.10it/s, Loss=1.5100e-01]
# Epoch 00033: 100%|██████████| 116/116 [00:00<00:00, 132.66it/s, Loss=1.4985e-01]
# Epoch 00034: 100%|██████████| 116/116 [00:00<00:00, 130.04it/s, Loss=1.4877e-01]
# Epoch 00035: 100%|██████████| 116/116 [00:00<00:00, 134.54it/s, Loss=1.4753e-01]
# Epoch 00036: 100%|██████████| 116/116 [00:00<00:00, 133.63it/s, Loss=1.4661e-01]
# Epoch 00037: 100%|██████████| 116/116 [00:00<00:00, 130.98it/s, Loss=1.4545e-01]
# Epoch 00038: 100%|██████████| 116/116 [00:00<00:00, 132.88it/s, Loss=1.4368e-01]
# Epoch 00039: 100%|██████████| 116/116 [00:00<00:00, 135.64it/s, Loss=1.4226e-01]
# Epoch 00040: 100%|██████████| 116/116 [00:00<00:00, 134.35it/s, Loss=1.4064e-01]
# Epoch 00041: 100%|██████████| 116/116 [00:00<00:00, 148.39it/s, Loss=1.3904e-01]
# Epoch 00042: 100%|██████████| 116/116 [00:01<00:00, 111.46it/s, Loss=1.3721e-01]
# Epoch 00043: 100%|██████████| 116/116 [00:00<00:00, 127.96it/s, Loss=1.3627e-01]
# Epoch 00044: 100%|██████████| 116/116 [00:00<00:00, 133.08it/s, Loss=1.3440e-01]
# Epoch 00045: 100%|██████████| 116/116 [00:00<00:00, 136.01it/s, Loss=1.3359e-01]
# Epoch 00046: 100%|██████████| 116/116 [00:00<00:00, 132.79it/s, Loss=1.3252e-01]
# Epoch 00047: 100%|██████████| 116/116 [00:00<00:00, 130.88it/s, Loss=1.3132e-01]
# Epoch 00048: 100%|██████████| 116/116 [00:00<00:00, 132.89it/s, Loss=1.3019e-01]
# Epoch 00049: 100%|██████████| 116/116 [00:00<00:00, 130.17it/s, Loss=1.2775e-01]
# Epoch 00050: 100%|██████████| 116/116 [00:00<00:00, 130.18it/s, Loss=1.2687e-01]
* Plotting basin 01030500: 100%|██████████| 4/4 [00:02<00:00, 1.49it/s]
* Plotting basin 01030500: 100%|██████████| 4/4 [00:09<00:00, 2.30s/it]
* Evaluating basin 06431500 (ds_train): 100%|██████████| 4/4 [00:00<00:00, 24.85it/s]
* Evaluating basin 06431500 (ds_valid): 100%|██████████| 4/4 [00:00<00:00, 46.46it/s]
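This tutorial does not show how the pretrainer works internally. The NumPy sketch below is only a generic illustration of what supervised pretraining amounts to: a small network is fitted by mini-batch gradient descent so that it reproduces a set of target series. Here those targets are synthetic stand-ins for the conceptual-model outputs \(ET\), \(M\), \(Pr\), \(Ps\), and \(Q\). pretrain_mlp is a hypothetical name. The architecture, loss (cfg.loss_pretrain), learning rate (cfg.lr_pretrain) and optimiser actually used by get_nn_pretrainer are set by the project and its config, and may differ.

```python
import numpy as np


def pretrain_mlp(X, Y, hidden=32, lr=0.05, epochs=50, batch=64, seed=1111):
    """Fit a one-hidden-layer tanh MLP to targets Y with mini-batch gradient descent.

    Returns the fitted parameters and the full-data MSE after every epoch.
    """
    rng = np.random.default_rng(seed)
    n, d_in = X.shape
    W1 = rng.normal(scale=0.5, size=(d_in, hidden))
    b1 = np.zeros(hidden)
    W2 = rng.normal(scale=0.5, size=(hidden, Y.shape[1]))
    b2 = np.zeros(Y.shape[1])
    history = []
    for _ in range(epochs):
        order = rng.permutation(n)  # reshuffle samples every epoch
        for start in range(0, n, batch):
            idx = order[start:start + batch]
            x, y = X[idx], Y[idx]
            h = np.tanh(x @ W1 + b1)
            err = h @ W2 + b2 - y
            # Gradients of 0.5 * squared error (summed over targets, averaged over the batch)
            dh = (err @ W2.T) * (1.0 - h**2)  # back-propagate through tanh
            W2 -= lr * (h.T @ err) / len(idx)
            b2 -= lr * err.mean(axis=0)
            W1 -= lr * (x.T @ dh) / len(idx)
            b1 -= lr * dh.mean(axis=0)
        pred = np.tanh(X @ W1 + b1) @ W2 + b2
        history.append(float(np.mean((pred - Y) ** 2)))
    return (W1, b1, W2, b2), history


# Synthetic stand-ins: 4 input features and 5 targets playing the role of ET, M, Pr, Ps, Q
rng = np.random.default_rng(0)
X = rng.normal(size=(2000, 4))
Y = np.tanh(X @ rng.normal(size=(4, 5)))
Y = (Y - Y.mean(axis=0)) / Y.std(axis=0)  # standardise so no target dominates the MSE
params, history = pretrain_mlp(X, Y)
print(f"MSE after epoch 1: {history[0]:.4f}, after epoch {len(history)}: {history[-1]:.4f}")
```

The targets are standardised so that fluxes of very different magnitudes contribute comparably to the loss.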
Feel free to explore the model_plots and model_results folders to evaluate the outcomes of the pretraining stage. They show how closely the pretrained NN reproduces the target variables computed by the conceptual model: \(ET\), \(M\), \(Pr\), \(Ps\), and \(Q\).
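To find these folders, the following sketch locates the newest run of an experiment and lists the files in its output subfolders. It assumes that run folders are named experiment_name followed by an underscore and the YYMMDD_HHMMSS timestamp, and that model_plots and model_results sit directly inside the run folder. latest_run_dir and list_outputs are hypothetical helpers. The glob pattern would also match other experiments whose names start with the same prefix.

```python
from pathlib import Path


def latest_run_dir(runs_dir: Path, experiment_name: str) -> Path:
    """Newest folder named '<experiment_name>_YYMMDD_HHMMSS' inside runs_dir.

    The fixed-width timestamp makes lexicographic order match chronological order.
    """
    runs = sorted(
        (p for p in runs_dir.glob(f"{experiment_name}_*") if p.is_dir()),
        key=lambda p: p.name,
    )
    if not runs:
        raise FileNotFoundError(f"no run of {experiment_name!r} found in {runs_dir}")
    return runs[-1]


def list_outputs(run_dir: Path, subfolders=("model_plots", "model_results")) -> dict:
    """Map each existing output subfolder to the relative paths of the files it contains."""
    return {
        name: sorted(
            f.relative_to(run_dir / name).as_posix()
            for f in (run_dir / name).rglob("*")
            if f.is_file()
        )
        for name in subfolders
        if (run_dir / name).is_dir()
    }
```

In the notebook, you would call latest_run_dir with the runs directory and cfg._cfg['experiment_name'], then pass the result to list_outputs.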
The model_weights folder contains the best-performing version of the pretrained model, which will be loaded by the hybrid model at the start of the training process. This process will be demonstrated in the next tutorial: 03-TrainHybridModel.
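Keeping a "best-performing version" usually means storing a copy of the parameters whenever the monitored loss improves. The sketch below illustrates that pattern. It feeds in the Loss values for epochs 1 to 9 from the progress bars above, and toy parameters stand in for network weights. This tutorial does not state which quantity the project ranks checkpoints by, whether the training loss shown in the log or a validation metric. In a PyTorch workflow the stored object would typically be a copy of model.state_dict(), written to disk with torch.save. BestCheckpoint is a hypothetical class.

```python
import copy


class BestCheckpoint:
    """Remember the parameters with the lowest monitored loss seen so far."""

    def __init__(self):
        self.best_loss = float("inf")
        self.best_epoch = None
        self.best_params = None

    def update(self, epoch: int, loss: float, params) -> bool:
        """Store a deep copy of params if loss improves; return True when it does."""
        if loss < self.best_loss:
            self.best_loss, self.best_epoch = loss, epoch
            # Copy, because training keeps mutating the live parameters in place
            self.best_params = copy.deepcopy(params)
            return True
        return False


# Loss values of epochs 1-9 from the pretraining log above
logged = [1.0095, 0.36343, 0.30324, 0.26700, 0.23470, 0.21603, 0.20421, 0.20980, 0.20567]
ckpt = BestCheckpoint()
params = {"w": [0.0]}
for epoch, loss in enumerate(logged, start=1):
    params["w"][0] = float(epoch)  # pretend training changed the weights
    ckpt.update(epoch, loss, params)
print(ckpt.best_epoch, ckpt.best_loss, ckpt.best_params)  # 7 0.20421 {'w': [7.0]}
```

The loss rises after epoch 7, so the epoch-7 copy is kept even though the live parameters continue to change.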