
Trainer Module

bnode_core.ode.trainer

Neural ODE and Balanced Neural ODE Training Module.

This module provides the main training pipeline for Neural ODE (NODE) and Balanced Neural ODE (BNODE) models. It handles model initialization, multi-phase training, validation, testing, and MLflow experiment tracking.

Architecture Support

The trainer automatically detects and supports two model architectures:

  • Neural ODE (NODE): Direct neural differential equation models.
  • Balanced Neural ODE (BNODE): Latent-space ODE models with encoder-decoder architecture for improved training stability and representation learning.

Training Pipeline Overview

The training process follows these stages:

  1. Model Instantiation

    • Automatically detects NODE vs BNODE from config
    • Initializes normalization layers using dataset statistics
    • Sets up device (CPU/CUDA) based on availability and config
  2. Pre-training (Optional, NODE only)

    • Can be enabled in config: nn_model.training.pre_train=true
    • Trains on state derivatives (state_der) if present in dataset
    • Uses collocation method for initial parameter estimation
    • Not supported for BNODE models (no latent-state gradients are available; you can approximate this behavior with a short main training phase that uses states_grad_loss)
  3. Multi-Phase Main Training

    • Configured as a list in nn_model.training.main_training
    • Each phase can have different hyperparameters:
      • Solver type (euler, rk4, dopri5, etc.)
      • Learning rate, batch size, sequence length
      • Early stopping patience and threshold
    • See resources/config/nn_model/bnode_pytest.yaml for an example, and the illustrative sketch after this list
  4. Final Testing

    • Evaluates model on all dataset splits (train/val/test)
    • Optionally saves predictions and internal variables to dataset
    • Logs final metrics to MLflow
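
The multi-phase structure can be pictured as a list of per-phase overrides. The sketch below is illustrative only: the key names (solver, lr, batch_size, seq_len_train) are assumptions, and the authoritative schema lives in bnode_core.config and resources/config/nn_model/bnode_pytest.yaml.

from omegaconf import OmegaConf

# Illustrative multi-phase layout; key names are assumptions, see the real
# config files under resources/config/nn_model/ for the actual schema.
cfg_sketch = OmegaConf.create({
    "nn_model": {
        "training": {
            "pre_train": False,            # optional collocation pre-training (NODE only)
            "main_training": [             # one entry per training phase
                {"solver": "euler", "lr": 1e-3, "batch_size": 256, "seq_len_train": 32},
                {"solver": "dopri5", "lr": 1e-4, "batch_size": 128, "seq_len_train": 128},
            ],
            "test": True,                  # run final testing on all splits
        }
    }
})
print(OmegaConf.to_yaml(cfg_sketch))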

Key Training Features

Compatibility with NODE and BNODE

  • Trainer auto-detects model type from config
  • Both models expose a consistent training interface, e.g. the model_and_loss_evaluation method (see the sketch below).
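
A minimal sketch of how this shared interface is called, mirroring the invocation in the trainer source further below; model, data_batch, and phase_cfg stand in for objects that the pipeline creates:

import torch

def evaluate_batch(model, data_batch, phase_cfg):
    # mirrors the model_and_loss_evaluation call in the trainer source below
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    with torch.no_grad():
        losses, model_outputs = model.model_and_loss_evaluation(
            data_batch,              # one batch from a DataLoader
            phase_cfg,               # the current phase's training config
            False,                   # pre_train flag (False during main training)
            device,
            return_model_outputs=True,
            test=True,
        )
    return losses, model_outputs    # scalar metrics and arrays such as states_hat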

Adaptive Batch Processing

Each epoch processes a specified number of batches rather than the entire dataset, configured via nn_model.training.main_training[i].batches_per_epoch.
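
A minimal sketch of this semantics, assuming a standard PyTorch DataLoader (not the trainer's actual loop):

from itertools import islice

def run_epoch(dataloader, batches_per_epoch, step_fn):
    # consume only a fixed number of batches from the (shuffled) dataloader per epoch
    for batch in islice(iter(dataloader), batches_per_epoch):
        step_fn(batch)  # forward pass, loss, backward pass, optimizer step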

NaN Recovery

  • If NaN loss detected, automatically reloads last checkpoint
  • Reduces gradient clipping norm to stabilize training
  • Note: LR scheduling might be a better long-term solution
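
A hedged sketch of the recovery idea; the checkpoint reload mirrors the model.load(path=..., device=...) call used in the trainer source, while the halving factor for the clipping norm is an illustrative assumption:

import math
import torch

def step_with_nan_recovery(model, optimizer, loss, checkpoint_path, clip_norm, device):
    if not math.isfinite(loss.item()):
        # NaN/inf loss: reload the last checkpoint and tighten gradient clipping
        model.load(path=checkpoint_path, device=device)
        return clip_norm * 0.5  # illustrative reduction factor
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm)
    optimizer.step()
    optimizer.zero_grad()
    return clip_norm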

Reparameterization Control (BNODE)

  • Training uses active reparameterization (variational inference)
  • During evaluation (validation/test, and at the final test over all dataset splits), reparameterization is disabled; the same applies in deterministic mode
  • Ensures consistent evaluation metrics
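
The standard reparameterization trick, shown here as a generic sketch rather than the BNODE class itself; during evaluation the sample collapses to the mean:

import torch

def sample_latent(mu, log_var, reparameterize=True):
    if not reparameterize:        # validation / test / deterministic mode
        return mu
    std = torch.exp(0.5 * log_var)
    return mu + std * torch.randn_like(std)   # mu + sigma * eps, eps ~ N(0, I)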

Progressive Sequence Length Increase

  • When switching phases, the training sequence length is increased gradually (controlled by seq_len_increase_in_batches)
  • An initial test with the final sequence length assesses extrapolation before the increase begins
  • Validation/test always use the full sequence length to monitor extrapolation performance
  • The increase is aborted early once extrapolation is stable: loss_train < 2 * loss_validation for N consecutive epochs (seq_len_increase_abort_after_n_stable_epochs)
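
The early-abort criterion can be sketched as follows; the function and variable names are placeholders, only the 2x threshold and the notion of N consecutive stable epochs come from the description above:

def stable_extrapolation(train_losses, val_losses, n_stable_epochs):
    # abort the ramp-up early once loss_train < 2 * loss_validation
    # has held for n_stable_epochs consecutive epochs
    recent = list(zip(train_losses, val_losses))[-n_stable_epochs:]
    return len(recent) == n_stable_epochs and all(lt < 2.0 * lv for lt, lv in recent)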

MLflow Integration

  • Logs metrics at end of each phase: {metric}_{context}_job{phase}_final
  • Final test metrics logged as: {metric}_final
  • All Hydra outputs and trained models saved as artifacts
  • Experiment tracking with run name, parameters, and tags
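
A small sketch of the naming convention; the metric names and values are placeholders, and mlflow.log_metrics is the standard MLflow call:

import mlflow

phase, context, epoch = 1, "validation", 100   # placeholder values
metrics = {"loss": 0.0123}                      # placeholder metric dict

with mlflow.start_run():
    # end-of-phase metrics: {metric}_{context}_job{phase}_final
    mlflow.log_metrics({f"{k}_{context}_job{phase}_final": v for k, v in metrics.items()}, step=epoch)
    # final test metrics: {metric}_final
    mlflow.log_metrics({f"{k}_final": v for k, v in metrics.items()}, step=epoch)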

Typical Usage Examples

Like the other modules of the bnode_core package, the trainer uses Hydra for configuration management.

Basic training with default config:

uv run trainer nn_model=latent_ode_base dataset_name=myDataset

Training with custom model configuration:

uv run trainer nn_model=myCustomModel dataset_name=myDataset \
    mlflow_experiment_name=my_experiment \
    nn_model.network.lat_states_dim=1024

Hyperparameter sweep (multi-run mode):

uv run trainer \
    nn_model=latent_ode_base \
    dataset_name=myDataset \
    nn_model.training.beta_start_override=0.1,0.01,0.001 \
    -m

Override specific training parameters:

uv run trainer \
    nn_model=latent_ode_base \
    dataset_name=myDataset \
    nn_model.training.lr_start_override=1e-4 \
    nn_model.training.batch_size_override=512 \
    use_cuda=false

View available configuration options (from Hydra):

uv run trainer --help

Configuration

For detailed configuration options, see:

  • Config Documentation: Consult the Config section of the documentation
  • Config Files: examples in resources/config/nn_model/ directory
  • Config Schema: bnode_core.config module for all available parameters
  • Search Tip: Use Ctrl+F in config files to find specific parameter behavior

Command Line Interface

The trainer is registered as a UV script in pyproject.toml, enabling direct execution via uv run trainer. All Hydra config parameters can be overridden via command line using dot notation.

Notes

  • CUDA is automatically used if available (override with use_cuda=false)
  • Model checkpoints saved after each phase: model_phase_{i}.pt
  • Failed artifact logging tracked in could_not_log_artifacts.txt
  • Supports mixed precision training (AMP) when enabled; a generic AMP sketch follows this list
  • Early stopping based on validation loss with configurable patience
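
The AMP note above refers to the standard PyTorch pattern; the sketch below shows that generic pattern under the assumption of a CUDA device, not the trainer's exact implementation:

import torch

scaler = torch.cuda.amp.GradScaler()

def amp_step(model, optimizer, loss_fn, batch):
    # mixed-precision forward pass, scaled backward pass, and optimizer step
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = loss_fn(model, batch)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    return loss.detach()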

See Also

bnode_core.config : Configuration schemas and validation
bnode_core.ode.node.node_architecture : Neural ODE model implementation
bnode_core.ode.bnode.bnode_architecture : Balanced Neural ODE model implementation
bnode_core.nn.nn_utils.load_data : Dataset loading utilities

initialize_model(cfg: train_test_config_class, train_dataset: TimeSeriesDataset, hdf5_dataset: hdf5_dataset_class, initialize_normalization=True, model_type: str = None)

Initialize and configure NODE or BNODE model with dataset statistics.

Automatically detects model type from config and initializes normalization layers using training dataset statistics. Handles device placement (CPU/CUDA) and copies model architecture file to Hydra output directory.

Parameters:

  • cfg (train_test_config_class, required): Validated Hydra configuration.
  • train_dataset (TimeSeriesDataset, required): Training dataset for normalization.
  • hdf5_dataset (hdf5_dataset_class, required): HDF5 dataset handle for statistics.
  • initialize_normalization (bool, default: True): Whether to initialize normalization layers from dataset statistics.
  • model_type (str, default: None): Force a specific model type ('node' or 'bnode'); if None, auto-detects from config.

Returns:

  • model (torch.nn.Module): Initialized model (NeuralODE or BalancedNeuralODE) moved to the appropriate device.

Side Effects
  • Modifies cfg.use_cuda based on availability
  • Copies model architecture source file to Hydra output directory
  • Logs device and parameter count information
Notes
  • CUDA is used if available and cfg.use_cuda=True
  • Normalization uses training set statistics only
  • Model type detection based on network class in config
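
A hedged usage sketch, assuming cfg is an already-validated Hydra config and that load_dataset_and_config and make_stacked_dataset (both used in the trainer source below) are importable from bnode_core.nn.nn_utils.load_data:

from bnode_core.ode.trainer import initialize_model
from bnode_core.nn.nn_utils.load_data import load_dataset_and_config, make_stacked_dataset  # assumed import path

def build_model(cfg):
    hdf5_dataset, _ = load_dataset_and_config(cfg.dataset_name, cfg.dataset_path)
    # full sequence length, as in the final-test branch of the trainer
    train_dataset = make_stacked_dataset(hdf5_dataset, 'train', None, None)
    return initialize_model(cfg, train_dataset, hdf5_dataset)  # auto-detects NODE vs BNODE
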
Source code in src/bnode_core/ode/trainer.py
def initialize_model(cfg: train_test_config_class, train_dataset: TimeSeriesDataset, hdf5_dataset: hdf5_dataset_class, 
                     initialize_normalization=True, model_type: str = None):
    """Initialize and configure NODE or BNODE model with dataset statistics.

    Automatically detects model type from config and initializes normalization
    layers using training dataset statistics. Handles device placement (CPU/CUDA)
    and copies model architecture file to Hydra output directory.

    Args:
        cfg (train_test_config_class): Validated Hydra configuration.
        train_dataset (TimeSeriesDataset): Training dataset for normalization.
        hdf5_dataset (hdf5_dataset_class): HDF5 dataset handle for statistics.
        initialize_normalization (bool, optional): Whether to initialize normalization
            layers from dataset statistics. Defaults to True.
        model_type (str, optional): Force specific model type ('node' or 'bnode').
            If None, auto-detects from config. Defaults to None.

    Returns:
        model (torch.nn.Module): Initialized model (NeuralODE or BalancedNeuralODE) moved
            to appropriate device.

    Side Effects:
        - Modifies cfg.use_cuda based on availability
        - Copies model architecture source file to Hydra output directory
        - Logs device and parameter count information

    Notes:
        - CUDA is used if available and cfg.use_cuda=True
        - Normalization uses training set statistics only
        - Model type detection based on network class in config
    """
    _cuda_available = torch.cuda.is_available()
    logging.info('CUDA available: {} | cfg.use_cuda: {}'.format(_cuda_available, cfg.use_cuda))
    if _cuda_available and cfg.use_cuda:
        cfg.use_cuda = True
    else:
        cfg.use_cuda = False
    logging.info("---> Training with cuda: {}".format(cfg.use_cuda))
    device = torch.device('cuda' if torch.cuda.is_available() and cfg.use_cuda else 'cpu')
    # create model (insert specific creations here)
    from bnode_core.config import neural_ode_network_class, latent_ode_network_class
    if model_type == None:
        if type(cfg.nn_model.network) is neural_ode_network_class:
            model_type='node'
        elif type(cfg.nn_model.network) is latent_ode_network_class:
            model_type='bnode'
        else: 
            raise ValueError('The neural network class could not be resolved')
        assert model_type in ['node', 'bnode']
    if model_type == 'node':
        model = NeuralODE(states_dim=train_dataset[0]['states'].shape[0],
                        controls_dim=train_dataset[0]['controls'].shape[0] if 'controls' in train_dataset[0].keys() else 0,
                        parameters_dim=train_dataset[0]['parameters'].shape[0] if 'parameters' in train_dataset[0].keys() else 0,
                        outputs_dim=train_dataset[0]['outputs'].shape[0] if 'outputs' in train_dataset[0].keys() else 0,
                        controls_to_output_nn=cfg.nn_model.network.controls_to_output_nn,
                        hidden_dim=cfg.nn_model.network.linear_hidden_dim, 
                        n_layers=cfg.nn_model.network.n_linear_layers,
                        hidden_dim_output_nn=cfg.nn_model.network.hidden_dim_output_nn,
                        n_layers_output_nn=cfg.nn_model.network.n_layers_output_nn,
                        activation=eval(cfg.nn_model.network.activation),
                        intialization=cfg.nn_model.training.pre_training.initialization_type,
                        initialization_ode=cfg.nn_model.training.initialization_type_ode,)
        # initialize normalizations
        if initialize_normalization:
            model.normalization_init(hdf5_dataset)
    elif model_type == 'bnode':
        model = BalancedNeuralODE(
                        states_dim=train_dataset[0]['states'].shape[0],
                        lat_states_mu_dim=cfg.nn_model.network.lat_states_dim,
                        parameters_dim=train_dataset[0]['parameters'].shape[0] if 'parameters' in train_dataset[0].keys() else 0,
                        lat_parameters_dim=cfg.nn_model.network.lat_parameters_dim,
                        controls_dim=train_dataset[0]['controls'].shape[0] if 'controls' in train_dataset[0].keys() else 0,
                        lat_controls_dim=cfg.nn_model.network.lat_controls_dim,
                        outputs_dim=train_dataset[0]['outputs'].shape[0] if 'outputs' in train_dataset[0].keys() else 0,
                        hidden_dim=cfg.nn_model.network.linear_hidden_dim,
                        n_layers=cfg.nn_model.network.n_linear_layers,
                        controls_to_decoder=cfg.nn_model.network.controls_to_decoder,
                        predict_states=cfg.nn_model.network.predict_states,
                        activation=eval(cfg.nn_model.network.activation),
                        initialization_type=cfg.nn_model.training.initialization_type,
                        initialization_type_ode=cfg.nn_model.training.initialization_type_ode,
                        initialization_type_ode_matrix=cfg.nn_model.training.initialization_type_ode_matrix,
                        lat_ode_type=cfg.nn_model.network.lat_ode_type,
                        include_params_encoder= cfg.nn_model.network.include_params_encoder,
                        params_to_state_encoder=cfg.nn_model.network.params_to_state_encoder,
                        params_to_control_encoder=cfg.nn_model.network.params_to_control_encoder,
                        params_to_decoder=cfg.nn_model.network.params_to_decoder,
                        controls_to_state_encoder=cfg.nn_model.network.controls_to_state_encoder,
                        state_encoder_linear = cfg.nn_model.network.state_encoder_linear,
                        control_encoder_linear = cfg.nn_model.network.control_encoder_linear,
                        parameter_encoder_linear = cfg.nn_model.network.parameter_encoder_linear,
                        ode_linear = cfg.nn_model.network.ode_linear,
                        decoder_linear = cfg.nn_model.network.decoder_linear,
                        lat_state_mu_independent = cfg.nn_model.network.lat_state_mu_independent,
                        )
        # initialize normalizations
        if initialize_normalization:
            model.normalization_init(hdf5_dataset)
    logging.info('Initialized model: {}'.format(model))
    logging.info('Number of trainable parameters: {}'.format(sum(p.numel() for p in model.parameters() if p.requires_grad)))
    model.to(device)
    logging.info('moved model to {}'.format(device))
    return model

train_all_phases(cfg: train_test_config_class)

Execute complete multi-phase training pipeline with MLflow tracking.

Main orchestration function that coordinates:

  • Dataset loading
  • Model initialization
  • Optional pre-training (NODE only)
  • Multi-phase main training
  • Final testing and evaluation
  • MLflow artifact logging

The function processes a job list consisting of optional pre-training, multiple main training phases, and final testing. Each phase can have different hyperparameters and training strategies.
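
A simplified excerpt of how that job list is assembled (the full logic, including the skip conditions for loading pre-trained or trained models, is in the source below); cfg is the validated config passed to the function:

def build_job_list(cfg):
    # each job is a dict: {'skip': bool, 'test': bool, 'train_cfg': cfg, 'pre_train': bool}
    job_list = [{'skip': not cfg.nn_model.training.pre_train, 'test': False,
                 'train_cfg': cfg.nn_model.training.pre_training, 'pre_train': True}]
    for main_train_cfg in cfg.nn_model.training.main_training:
        job_list.append({'skip': False, 'test': False,
                         'train_cfg': main_train_cfg, 'pre_train': False})
    if cfg.nn_model.training.test:
        job_list.append({'skip': False, 'test': True,
                         'train_cfg': cfg.nn_model.training.main_training[-1], 'pre_train': False})
    return job_list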

Parameters:

  • cfg (train_test_config_class, required): Validated Hydra configuration containing:
    • dataset_path, dataset_name: Dataset location and identifier
    • nn_model.training.pre_train: Enable pre-training (NODE only)
    • nn_model.training.main_training: List of training phase configs
    • nn_model.training.test: Enable final testing
    • use_cuda: Device preference
    • mlflow_experiment_name: MLflow experiment name

Side Effects
  • Creates/updates model checkpoints: model_phase_{i}.pt
  • Logs metrics, parameters, and artifacts to MLflow
  • Saves predictions to dataset if configured
  • Copies Hydra outputs to MLflow artifacts
  • Creates could_not_log_artifacts.txt on logging failures
Training Flow
  1. Load HDF5 dataset and log to MLflow
  2. Build job list (pre-train, main phases, test)
  3. For each job:
    • Initialize/reload dataloaders if needed
    • Initialize/load model if needed
    • Execute training or testing
    • Save checkpoint and log metrics
  4. Copy all outputs to MLflow artifacts

Raises:

  • RuntimeError: If CUDA memory errors occur repeatedly
  • FileNotFoundError: If dataset or checkpoint files are missing

Notes
  • Decorated with @log_hydra_to_mlflow for automatic config logging
  • Memory errors trigger dataloader recreation with adjusted settings
  • NaN losses trigger checkpoint reload and gradient clipping adjustment
  • Progressive sequence length increase during phase transitions
See Also

train_one_phase : Single training phase execution
initialize_model : Model instantiation and initialization

Source code in src/bnode_core/ode/trainer.py
@log_hydra_to_mlflow
def train_all_phases(cfg: train_test_config_class):
    """Execute complete multi-phase training pipeline with MLflow tracking.

    Main orchestration function that coordinates:

    - Dataset loading
    - Model initialization  
    - Optional pre-training (NODE only)
    - Multi-phase main training
    - Final testing and evaluation
    - MLflow artifact logging

    The function processes a job list consisting of optional pre-training,
    multiple main training phases, and final testing. Each phase can have
    different hyperparameters and training strategies.

    Args:
        cfg (train_test_config_class): Validated Hydra configuration containing:
            - dataset_path, dataset_name: Dataset location and identifier
            - nn_model.training.pre_train: Enable pre-training (NODE only)
            - nn_model.training.main_training: List of training phase configs
            - nn_model.training.test: Enable final testing
            - use_cuda: Device preference
            - mlflow_experiment_name: MLflow experiment name

    Side Effects:
        - Creates/updates model checkpoints: model_phase_{i}.pt
        - Logs metrics, parameters, and artifacts to MLflow
        - Saves predictions to dataset if configured
        - Copies Hydra outputs to MLflow artifacts
        - Creates could_not_log_artifacts.txt on logging failures

    Training Flow:
        1. Load HDF5 dataset and log to MLflow
        2. Build job list (pre-train, main phases, test)
        3. For each job:
           - Initialize/reload dataloaders if needed
           - Initialize/load model if needed
           - Execute training or testing
           - Save checkpoint and log metrics
        4. Copy all outputs to MLflow artifacts

    Raises:
        RuntimeError: If CUDA memory errors occur repeatedly
        FileNotFoundError: If dataset or checkpoint files missing

    Notes:
        - Decorated with @log_hydra_to_mlflow for automatic config logging
        - Memory errors trigger dataloader recreation with adjusted settings
        - NaN losses trigger checkpoint reload and gradient clipping adjustment
        - Progressive sequence length increase during phase transitions

    See Also:
        train_one_phase : Single training phase execution
        initialize_model : Model instantiation and initialization
    """
    logging.info('Start training all phases....')
    device = torch.device('cuda' if torch.cuda.is_available() and cfg.use_cuda else 'cpu')
    logging.info('Using device: {}'.format(device))

    # load hdf5 dataset
    hdf5_dataset, _ = load_dataset_and_config(cfg.dataset_name, cfg.dataset_path)
    mlflow.log_param('dataset_name', cfg.dataset_name)

    # collect jobs
    # job_list=[] filled with dict of style: {'skip': bool, 'test': bool, 'train_cfg': cfg, 'pre_train': bool}
    job_list = []
    # pre-training
    job_list.append({'skip': not cfg.nn_model.training.pre_train or cfg.nn_model.training.load_pretrained_model or cfg.nn_model.training.load_trained_model_for_test,
                     'test': False, 'train_cfg': cfg.nn_model.training.pre_training, 'pre_train': True})
    # main training
    for idx, main_train_cfg in enumerate(cfg.nn_model.training.main_training):
        job_list.append({'skip': cfg.nn_model.training.load_trained_model_for_test, 'test': False, 'train_cfg': main_train_cfg, 'pre_train': False})
    # test
    if cfg.nn_model.training.test is True:
        job_list.append({'skip': False, 'test': True, 'train_cfg': cfg.nn_model.training.main_training[-1], 'pre_train': False})
    logging.info('Created job list: {}'.format(job_list))

    # flags
    _created_datasets_and_loaders=False
    _loaded_seq_len=-1
    _loaded_batch_size=-1
    _created_model=False
    _epoch_0 = 0
    _reload_dataloaders_required = False
    for idx, job in enumerate(job_list):
        while True: # loop to catch memory errors
            try:
                if job['skip'] is False: # create dataloaders for this job
                    if job['pre_train'] is True:
                        logging.info('Starting Pre-Training with settings {}'.format(job['train_cfg']))
                    elif job['test'] is True:
                        logging.info('Starting Testing with settings {}'.format(job['train_cfg']))
                    else:
                        logging.info('Starting Train Job {} with settings {}'.format(idx, job['train_cfg']))
                    # loading datasets and initializing dataloaders
                    # set seq_len
                    if job['pre_train'] is True:
                        _load_seq_len = job['train_cfg'].load_seq_len
                        _seq_len_batches = 1
                    elif job['test'] is True:
                        _load_seq_len = None
                        _seq_len_batches = None
                    else:
                        _load_seq_len = job['train_cfg'].load_seq_len
                        _seq_len_batches = job['train_cfg'].seq_len_train
                    if _created_datasets_and_loaders is False or _load_seq_len != _loaded_seq_len: 
                        if _created_datasets_and_loaders is True:
                            _keys = list(datasets.keys())
                            for key in _keys:
                                del datasets[key]
                        # make torch tensor datasets
                        datasets = {}
                        for context in ['train', 'test', 'validation', 'common_test']:
                            datasets[context] = make_stacked_dataset(hdf5_dataset, context, _load_seq_len, _seq_len_batches)
                        _loaded_seq_len = _load_seq_len
                        _reload_dataloaders_required = True
                    else:
                        for context in ['train', 'test', 'validation', 'common_test']:
                            datasets[context].set_seq_len(_seq_len_batches)
                        _reload_dataloaders_required = True # TODO; check if this is necessary
                    _batch_size = job['train_cfg'].batch_size if job['test'] is False else cfg.nn_model.training.batch_size_test
                    _drop_last = True if job['test'] is False else False
                    _shuffle = True if job['test'] is False else False
                    if _created_datasets_and_loaders is False or _loaded_batch_size != _batch_size or _reload_dataloaders_required is True or job['test'] is True:
                        # initialiaze batch_loader, as batch size can't be set to a new value
                        if _created_datasets_and_loaders is True:
                            #del dataloaders
                            _keys = list(dataloaders.keys())
                            for key in _keys:
                                del dataloaders[key]
                        # create new
                        dataloaders={}
                        for context in ['train', 'test', 'validation', 'common_test']:
                            if job['test'] is True and len(datasets[context]) == 0: # when only testing, datasets can be empty
                                dataloaders[context] = None
                                logging.info('Only Testing: No data for context {} in dataset. Skipping loading dataloader for this context'.format(context))
                            else:
                                _num_workers = cfg.n_workers_train_loader if context == 'train' else cfg.n_workers_other_loaders
                                if context == 'train' and job['pre_train'] is True:
                                    _num_workers = 1 * _num_workers
                                if _batch_size > len(datasets[context]):
                                    _batch_size_here = int(len(datasets[context])/2)+3
                                    logging.warning('Batch size {} is larger than dataset size {} for context {}. Setting batch size to {}'.format(_batch_size, len(datasets[context]), context, _batch_size_here))
                                else:
                                    _batch_size_here = _batch_size
                                if len(datasets[context]) == 0:
                                    raise ValueError('While creating dataloaders, dataset for context {} is empty. Aborting.'.format(context))
                                dataloaders[context] = torch.utils.data.DataLoader(datasets[context], batch_size=_batch_size_here, shuffle=_shuffle,
                                                                                    num_workers = _num_workers, persistent_workers=True, 
                                                                                    pin_memory=True, drop_last=_drop_last, prefetch_factor=cfg.prefetch_factor)
                        _created_datasets_and_loaders = True
                        _loaded_batch_size = _batch_size
                        # update seq_len train for this job to the actual seq_len of the dataset
                        if 'seq_len' in datasets['train'].__dict__.keys(): # for custom dataset (wiht map)
                            job['train_cfg'].seq_len_train = datasets['train'].seq_len
                        else:
                            job['train_cfg'].seq_len_train = datasets['train'].datasets['time'].shape[2]


                    _created_model_this_job = False	
                    # initialize model
                    if _created_model is False:
                        model = initialize_model(cfg, datasets['train'], hdf5_dataset)
                        _created_model, _created_model_this_job = True, True
                    if cfg.nn_model.training.load_pretrained_model is True and _created_model_this_job is True:
                        _path = filepaths.filepath_from_local_or_ml_artifacts(cfg.nn_model.training.path_pretrained_model)
                        model.load(path=_path, device=device)
                        logging.info('Loaded pretrained model from {}'.format(_path))
                        if cfg.nn_model.training.pre_trained_model_seq_len is not None: 
                            job_list[idx]['train_cfg'].seq_len_epoch_start = cfg.nn_model.training.pre_trained_model_seq_len
                            logging.info('Set seq_len_epoch_start for next job to {}'.format(cfg.nn_model.training.pre_trained_model_seq_len))
                        else:
                            job_list[idx]['train_cfg'].seq_len_epoch_start = job['train_cfg'].seq_len_train
                            logging.info('Set seq_len_epoch_start for this job to seq_len_train {} as no pre_trained_model_seq_len is given in config'.format(job['train_cfg'].seq_len_train))
                    if cfg.nn_model.training.load_trained_model_for_test is True:
                        _path = cfg.nn_model.training.path_trained_model
                        _path = filepaths.filepath_from_local_or_ml_artifacts(_path)
                        model.load(path=_path, device=device)
                        logging.info('Loaded trained model from {}'.format(_path))

                if job['skip'] is True:
                    if job['pre_train'] is True:
                        logging.info('Skipping Pre-Training')
                    else:
                        logging.info('Skipping Train Job {} as trained model is loaded in following phases'.format(idx))
                else:
                    if job['test'] is False:
                        # train one phase
                        _epoch_0 = train_one_phase(cfg, model, dataloaders, job['train_cfg'], job['test'], job['pre_train'], idx, _epoch_0)
                        # set seq_len_epoch_start for next job
                        if len(job_list) > idx+1:
                            # consequently, seq_len_epoch_start should be seq_len_train
                            job_list[idx+1]['train_cfg'].seq_len_epoch_start = job['train_cfg'].seq_len_train if job['pre_train'] is False else 1
                            logging.info('Set seq_len_epoch_start for next job to {}, the seq_len_train of this job'.format(job_list[idx+1]['train_cfg'].seq_len_epoch_start))
                    else:
                        logging.info('Testing model')
                        hdf5_dataset.close()
                        # copy dataset to hydra output directory
                        _save_predictions = cfg.nn_model.training.save_predictions_in_dataset
                        if _save_predictions is True:
                            _path = filepaths.filepath_dataset_current_hydra_output()
                            shutil.copy(filepaths.filepath_dataset_from_config(cfg.dataset_name, cfg.dataset_path), _path)
                            logging.info('Adding predictions to dataset')
                            logging.info('copied dataset to file: {}'.format(_path))
                            hdf5_dataset = h5py.File(_path, 'r+')
                        else:
                            logging.info('Not saving predictions in dataset')
                        for context in ['train', 'test', 'validation', 'common_test']:
                            if dataloaders[context] is None:
                                logging.info('No data for context {} in dataset. Skipping.'.format(context))
                            else:
                                logging.info('Testing of dataset for context {}'.format(context))
                                if _save_predictions is True:
                                    # Stream results batch-by-batch to HDF5 to reduce RAM usage
                                    total_len = len(dataloaders[context].dataset)
                                    data_iter = iter(dataloaders[context])
                                    created_dsets = False
                                    write_offset = 0
                                    metrics_sum = {}
                                    n_batches = 0
                                    keys_to_save = []
                                    while True:
                                        try:
                                            data_batch = next(data_iter)
                                        except StopIteration:
                                            break
                                        with torch.no_grad():
                                            logging.info(f"\t Batch {n_batches+1}/{int(total_len/cfg.nn_model.training.batch_size_test)+1}")
                                            ret_vals_batch, model_outputs_batch = model.model_and_loss_evaluation(
                                                data_batch, job['train_cfg'], job['pre_train'], device,
                                                return_model_outputs=True, test=True
                                            )
                                        # Initialize datasets on first batch according to save policy
                                        if not created_dsets:
                                            # Decide which keys to save
                                            for key in model_outputs_batch.keys():
                                                if cfg.nn_model.training.test_save_internal_variables is True:
                                                    _save = True
                                                else:
                                                    if key in ['states_hat', 'states_der_hat', 'outputs_hat']:
                                                        _save = True
                                                    elif cfg.nn_model.training.test_save_internal_variables_for == context:
                                                        _save = True
                                                        logging.info('Saving internal variable {} as test_save_internal_variables_for context is {}'.format(key, context))
                                                    else:
                                                        _save = False
                                                        logging.info('Not saving internal variable {} as test_save_no_internal_variables is True'.format(key))
                                                if _save:
                                                    keys_to_save.append(key)
                                            # Create HDF5 datasets per key with full size on first dimension
                                            for key in keys_to_save:
                                                arr = model_outputs_batch[key]
                                                shape_rest = arr.shape[1:]
                                                dset_shape = (total_len,) + shape_rest
                                                hdf5_dataset.create_dataset(context + '/' + key, shape=dset_shape, dtype=arr.dtype)
                                            created_dsets = True
                                        # Write this batch to HDF5
                                        Batch = next(iter(model_outputs_batch.values())).shape[0] if len(model_outputs_batch) > 0 else 0
                                        for key in keys_to_save:
                                            arr = model_outputs_batch[key]
                                            hdf5_dataset[context + '/' + key][write_offset:write_offset + arr.shape[0], ...] = arr
                                        write_offset += Batch
                                        # Accumulate metrics for averaging later (match old np.mean over batches)
                                        if n_batches == 0:
                                            metrics_sum = {k: float(v) for k, v in ret_vals_batch.items()}
                                        else:
                                            for k, v in ret_vals_batch.items():
                                                metrics_sum[k] += float(v)
                                        n_batches += 1
                                    # Compute mean metrics across batches
                                    ret_vals = {k: (metrics_sum[k] / max(n_batches, 1)) for k in metrics_sum.keys()}
                                else:
                                    ret_vals = test_or_validate_one_epoch(model, dataloaders[context], job['train_cfg'], job['pre_train'], device, all_batches=True, return_model_outputs=False)
                                # log stats with logging
                                logging.info('Stats for context {}: {}'.format(context, ret_vals))
                                # log stats with mlflow
                                mlflow.log_metrics(append_context_to_dict_keys(ret_vals, context), step=_epoch_0+1) 
                                mlflow.log_metrics(append_context_to_dict_keys(ret_vals, '{}_final'.format(context)), step=_epoch_0+1)
                                # save loss function values
                                if _save_predictions is True:
                                    for key, value in ret_vals.items():
                                        hdf5_dataset.create_dataset(context+'/'+key, data=value)
                        if _save_predictions is True:
                            hdf5_dataset.close()
                            # save this file
                            shutil.copy(Path(__file__), filepaths.dir_current_hydra_output())
                            logging.info('copied current trainer.py: {} \nto: \n{}'.format(Path(__file__), filepaths.dir_current_hydra_output()))
                if cfg.use_cuda:
                    torch.cuda.empty_cache() 
                break # break the exception loop
            except RuntimeError as e:
                if 'CUDA out of memory' in str(e) or 'CUDA memory is almost full' in str(e):
                    logging.warning('CUDA out of memory error. Trying again in 10 seconds')
                    pyTime.sleep(10)
                    logging.info('Setting batch size to {}'.format(int(_batch_size * 0.7)))
                    if not job['test']:
                        job['train_cfg'].batch_size = int(_batch_size * 0.7)
                    else:
                        cfg.nn_model.training.batch_size_test = int(_batch_size * 0.7)
                    if cfg.use_cuda:
                        torch.cuda.empty_cache()
                else:
                    raise e

main()

Entry point for (B)NODE training via Hydra CLI.

Initializes Hydra configuration system and launches train_all_phases with validated config. Auto-detects config directory and uses 'train_test_ode' as the default config name.

This function is registered as 'trainer' in pyproject.toml, enabling command-line execution via:

uv run trainer [config_overrides]

Examples:

See module docstring for usage examples.

Side Effects
  • Registers config store with Hydra
  • Auto-detects config directory from filepaths
  • Launches Hydra-decorated train_all_phases
Source code in src/bnode_core/ode/trainer.py
def main():
    """Entry point for (B)NODE training via Hydra CLI.

    Initializes Hydra configuration system and launches train_all_phases with
    validated config. Auto-detects config directory and uses 'train_test_ode'
    as the default config name.

    This function is registered as 'trainer' in pyproject.toml, enabling
    command-line execution via::

        uv run trainer [config_overrides]

    Examples:
        See module docstring for usage examples.

    Side Effects:
        - Registers config store with Hydra
        - Auto-detects config directory from filepaths
        - Launches Hydra-decorated train_all_phases
    """
    cs = get_config_store()
    config_dir = filepaths.config_dir_auto_recognize()
    config_name = 'train_test_ode'
    hydra.main(config_path=str(config_dir.absolute()), config_name=config_name, version_base=None)(train_all_phases)()