
BNODE Export

bnode_core.ode.bnode.bnode_export

BNODE Model Export to ONNX Format.

This module provides functionality to export trained Balanced Neural ODE (BNODE) models to ONNX format for deployment in production environments. The export process decomposes the BNODE architecture into individual ONNX components that can be integrated into external inference pipelines.

Warning

This documentation is AI generated and may contain inaccuracies. Please verify.

Overview

BNODE models consist of multiple neural network components organized in a variational autoencoder (VAE) structure with a latent ODE:

  1. Encoders: Transform high-dimensional inputs to latent representations

    • State encoder: Maps physical states to latent state space
    • Control encoder: Maps control inputs to latent control space (optional)
    • Parameter encoder: Maps system parameters to latent parameter space (optional)
  2. Latent ODE: Defines dynamics in the learned latent space

    • Can be linear (SSM-based) or nonlinear (neural network)
    • Supports variance propagation (constant or dynamic)
    • Optionally conditioned on latent controls and parameters
  3. Decoder: Reconstructs physical quantities from latent states

    • Maps latent states back to physical state space
    • Optionally reconstructs system outputs
    • Can incorporate latent parameters and controls
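The sketch below is a minimal pure-Python illustration of how the three component types chain together at inference time. `encode_states`, `latent_ode` and `decode` are hypothetical toy stand-ins, not the trained networks. The encoder and decoder use scalar gains, and the latent ODE is linear, dz/dt = A z + B u. Explicit Euler is an assumed integrator chosen for brevity. An external pipeline would call the exported ONNX models instead and may use a different solver.

```python
from typing import List, Sequence, Tuple

Matrix = Sequence[Sequence[float]]

def encode_states(x: Sequence[float]) -> Tuple[List[float], List[float]]:
    """Toy state encoder (hypothetical): returns (mu, logvar) of the latent state."""
    mu = [0.5 * v for v in x]
    logvar = [0.0 for _ in x]
    return mu, logvar

def latent_ode(z: Sequence[float], u: Sequence[float], A: Matrix, B: Matrix) -> List[float]:
    """Toy linear latent dynamics dz/dt = A z + B u (hypothetical)."""
    n = len(z)
    return [
        sum(A[i][j] * z[j] for j in range(n)) + sum(B[i][k] * u[k] for k in range(len(u)))
        for i in range(n)
    ]

def decode(z: Sequence[float]) -> List[float]:
    """Toy decoder (hypothetical): maps latent states back to physical states."""
    return [2.0 * v for v in z]

def rollout(x0, controls, A, B, dt):
    """Encode once, integrate in latent space with explicit Euler, decode every step."""
    z, _ = encode_states(x0)          # use the latent mean as the initial state
    trajectory = [decode(z)]
    for u in controls:
        dz = latent_ode(z, u, A, B)
        z = [zi + dt * dzi for zi, dzi in zip(z, dz)]
        trajectory.append(decode(z))
    return trajectory

traj = rollout([2.0], [[0.0], [0.0]], A=[[-1.0]], B=[[1.0]], dt=0.1)
print(traj)
```

The state is encoded only once. After that, every step stays in the latent space, and decoding is needed only where physical quantities are required.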

Export Process

The export workflow involves:

  1. Model Loading: Retrieve trained model from MLflow or local directory
  2. Configuration: Load training configuration and dataset for normalization
  3. Component Separation: Extract individual neural network modules
  4. ONNX Conversion: Export each component with dynamic batch dimensions
  5. Example I/O: Save sample inputs/outputs in HDF5 for validation

Exported Artifacts

For each BNODE model, the following files are generated:

  • encoder_states.onnx: State encoder neural network
  • encoder_controls.onnx: Control encoder (if applicable)
  • encoder_parameters.onnx: Parameter encoder (if applicable)
  • latent_ode.onnx: Latent ODE function
  • latent_ode_ssm_from_param.onnx: SSM parameter mapping (linear models with parameters only)
  • decoder.onnx: Decoder neural network
  • *_example_io.hdf5: Example input/output data for each component
  • bnode_config.yaml: Complete model configuration
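Which artifacts appear depends on the model configuration. The sketch below derives the expected file names from four flags and reports which of them are missing from an export directory.

- The flag names mirror attributes that `export_bnode` checks: `model.include_controls`, `model.include_params_encoder`, and the latent ODE's `include_parameters` and `linear`.
- The helper functions themselves are hypothetical and not part of bnode_core.
- The example I/O names follow the documented file list. The listed source does not write an example I/O file for the SSM mapping, so the sketch does not expect one.

```python
from pathlib import Path
from typing import List, Set

def expected_export_files(include_controls: bool, include_params_encoder: bool,
                          include_parameters: bool, linear: bool) -> Set[str]:
    """File names an export directory should contain for the given model flags."""
    files = {
        "bnode_config.yaml",
        "latent_ode.onnx", "latent_ode_example_io.hdf5",
        "decoder.onnx", "decoder_example_io.hdf5",
    }
    encoders = ["states"]                      # the state encoder is always exported
    if include_controls:
        encoders.append("controls")
    if include_params_encoder:
        encoders.append("parameters")
    for key in encoders:
        files.add(f"encoder_{key}.onnx")
        files.add(f"encoder_{key}_example_io.hdf5")
    if include_parameters and linear:          # SSM mapping only for linear ODEs with parameters
        files.add("latent_ode_ssm_from_param.onnx")
    return files

def missing_export_files(export_dir: str, **flags: bool) -> List[str]:
    """Sorted list of expected files that are absent from export_dir."""
    present = {p.name for p in Path(export_dir).iterdir()}
    return sorted(expected_export_files(**flags) - present)
```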

Configuration

The export process is configured using Hydra with the onnx_export_config_class dataclass. Configuration can be provided via:

  1. YAML config file (conf/onnx_export.yaml)
  2. Command-line overrides
  3. Programmatic instantiation
Required Configuration Fields

Either mlflow_run_id OR model_directory must be specified (if both are set, mlflow_run_id takes precedence):

  • mlflow_run_id (str): MLflow run ID to retrieve model from tracking server
  • model_directory (str): Local path to model artifacts directory
Optional Configuration Fields
  • mlflow_tracking_uri (str): MLflow server URI (default: local ./mlruns)
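  • model_checkpoint_path (str): Specific checkpoint file (default: the last model_phase_*.pt file found under the artifacts directory)
  • config_path (str): Custom config file (default: .hydra/config_validated.yaml; custom paths are currently rejected with a ValueError)
  • dataset_path (str): Dataset used for model initialization and example inputs (default: dataset.hdf5 in the artifacts directory)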
  • output_dir (str): Export destination (default: Hydra output directory)
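The default path rules can be summarized in code. This sketch mirrors how `load_trained_latent_ode` resolves the config and dataset files.

- `ExportPathsSketch`, `strip_file_scheme` and `resolve_paths` are hypothetical names.
- The explicit check that one of `mlflow_run_id` or `model_directory` is set is an addition. The original loader does not perform it.

```python
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

@dataclass
class ExportPathsSketch:
    """Hypothetical subset of the export config fields, for illustration only."""
    mlflow_run_id: Optional[str] = None
    model_directory: Optional[str] = None
    config_path: Optional[str] = None
    dataset_path: Optional[str] = None

def strip_file_scheme(path: str) -> Path:
    """Drop a leading 'file://' prefix, as the loader does for user-supplied paths."""
    return Path(path[len("file://"):] if path.startswith("file://") else path)

def resolve_paths(cfg: ExportPathsSketch, dir_artifacts: Path) -> dict:
    """Return the config and dataset paths the export would read."""
    if cfg.mlflow_run_id is None and cfg.model_directory is None:
        raise ValueError("Set either mlflow_run_id or model_directory.")  # added check
    if cfg.config_path is not None:
        raise ValueError("Custom config_path is not supported in this version.")
    config = dir_artifacts / ".hydra" / "config_validated.yaml"
    dataset = (dir_artifacts / "dataset.hdf5" if cfg.dataset_path is None
               else strip_file_scheme(cfg.dataset_path))
    return {"config": config, "dataset": dataset}
```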

Dataset Requirements

A dataset in HDF5 format is required for model initialization and normalization. The dataset must contain:

  • Structure: Training/validation/test splits with trajectories
  • Variables: States, controls, parameters, outputs (as applicable)
  • Format: Shape (n_trajectories, n_variables, n_timesteps)
  • Source: Generated by data_generation module or provided externally

The dataset is used to:

  1. Initialize model dimensions (normalization statistics are restored from the checkpoint rather than recomputed from the dataset)
  2. Provide example inputs for ONNX graph tracing
  3. Validate exported models against known data
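The export builds its tracing inputs from the first training sample.

- `dataset[0]['states'][:, 0].unsqueeze(0)` selects every state variable at the first time step and adds a batch axis.
- Parameters have no time axis, so they only get the batch axis.

The hypothetical helpers below reproduce those shape manipulations on nested lists. This makes the axis convention explicit without requiring PyTorch.

```python
from typing import List, Sequence

def first_timestep_batch(states: Sequence[Sequence[float]]) -> List[List[float]]:
    """Nested-list analogue of states[:, 0].unsqueeze(0).

    states has shape (n_variables, n_timesteps); the result has shape (1, n_variables).
    """
    return [[row[0] for row in states]]

def parameter_batch(parameters: Sequence[float]) -> List[List[float]]:
    """Nested-list analogue of parameters.unsqueeze(0): (n_parameters,) -> (1, n_parameters)."""
    return [list(parameters)]

# One trajectory with 2 state variables and 3 time steps.
states = [[1.0, 1.1, 1.2],
          [5.0, 5.5, 6.0]]
print(first_timestep_batch(states))  # [[1.0, 5.0]]
print(parameter_batch([0.3, 7.0]))   # [[0.3, 7.0]]
```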

Typical Usage Examples

Example 1: Export from MLflow Run:

# Export model from MLflow tracking server
uv run bnode_export mlflow_run_id=abc123def456 \
                    mlflow_tracking_uri=http://localhost:5000 \
                    output_dir=./exports/my_model

Example 2: Export from Local Directory:

# Export model from local artifacts
uv run bnode_export model_directory=./outputs/2024-11-07/10-30-45 \
                    output_dir=./exports/my_model

Example 3: Custom Checkpoint and Dataset:

# Specify custom checkpoint and dataset paths
uv run bnode_export mlflow_run_id=abc123def456 \
                    model_checkpoint_path=./checkpoints/model_phase_3.pt \
                    dataset_path=./data/custom_dataset.hdf5 \
                    output_dir=./exports/custom_export

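Example 4: Export with Hydra Multirun:

# Export several models in one invocation (Hydra's default launcher runs the jobs sequentially).
# output_dir is omitted so each job writes to its own Hydra output directory
# instead of all jobs overwriting the same files.
uv run bnode_export --multirun \
                    mlflow_run_id=run1,run2,run3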

ONNX Export Backend Selection

This module uses PyTorch's ONNX export with dynamo=False to leverage the legacy TorchScript-based exporter. This choice is made for the following reasons:

Why Legacy Exporter (dynamo=False)?

  1. Stability: The TorchScript-based exporter is mature and well tested on complex neural network architectures, including custom layers and control flow.

  2. Complex Model Support: BNODE models contain:

    • Conditional logic (variance modes, parameter inclusion)
    • Custom normalization layers with stateful initialization
    • Multiple encoder/decoder components with optional inputs
    • Dynamic control flow based on model configuration

    The legacy exporter handles these patterns more reliably.

  3. Keyword Arguments: Optional model inputs (such as latent parameters and controls) are passed to torch.onnx.export through its kwargs parameter, which works with the legacy exporter.

About the New Exporter (dynamo=True)

Starting in PyTorch 2.9, dynamo=True is the default, using the new torch.export-based exporter with ONNXScript. This provides:

  • Better support for dynamic shapes and LLMs
  • More modern export pipeline
  • Enhanced control flow handling

However, it may encounter issues with:

  • Complex conditional logic
  • Custom layer implementations
  • Stateful modules (like normalization with deferred initialization)

Migration Path

If you wish to try the new exporter:

  1. Change dynamo=False to dynamo=True in all torch.onnx.export calls
  2. Test thoroughly with your specific model configurations
  3. Address any tracing warnings or errors
  4. Validate exported ONNX models match PyTorch reference outputs
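Step 4 can be done with an elementwise tolerance check. The sketch below compares flattened PyTorch reference outputs against ONNX Runtime outputs. The reference values could, for example, be read from an *_example_io.hdf5 file.

It uses the same rule as `numpy.allclose`:

    |candidate - reference| <= atol + rtol * |reference|

The function names and default tolerances are assumptions. Small float32 discrepancies between runtimes are common, so tune `rtol` and `atol` for your model.

```python
from typing import Sequence

def max_abs_diff(reference: Sequence[float], candidate: Sequence[float]) -> float:
    """Largest elementwise absolute difference (0.0 for empty inputs)."""
    if len(reference) != len(candidate):
        raise ValueError("reference and candidate must have the same length")
    return max((abs(c - r) for r, c in zip(reference, candidate)), default=0.0)

def outputs_match(reference: Sequence[float], candidate: Sequence[float],
                  rtol: float = 1e-5, atol: float = 1e-6) -> bool:
    """numpy.allclose-style check: |candidate - reference| <= atol + rtol * |reference|."""
    if len(reference) != len(candidate):
        return False
    return all(abs(c - r) <= atol + rtol * abs(r) for r, c in zip(reference, candidate))
```

When a check fails, logging `max_abs_diff` helps distinguish numerical noise from a genuinely different graph.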

The legacy exporter is still available but is deprecated upstream, so check the release notes of your PyTorch version before upgrading. There is no immediate need to migrate unless you require features specific to the new exporter.

Deprecation Warnings

You may see deprecation warnings when using dynamo=False. They do not change the exported models, but they indicate that the legacy path is deprecated upstream and may be removed in a future PyTorch release.

Notes

  • All exported ONNX models use dynamic batch dimensions for flexible inference
  • Models are exported in evaluation mode (dropout/batch norm frozen)
  • Normalization parameters are embedded in the exported models
  • Linear latent ODEs with parameters export a separate SSM parameter mapping for efficiency
  • Export uses kwargs parameter to handle optional model inputs correctly

See Also

  • bnode_core.ode.trainer : Training pipeline for BNODE models
  • bnode_core.config.onnx_export_config_class : Configuration dataclass
  • bnode_core.ode.bnode : BNODE model architecture definitions

load_trained_latent_ode(cfg_export)

Load a trained BNODE model from MLflow or local directory.

This function retrieves a trained BNODE model and its associated artifacts (configuration, dataset, checkpoint) from either an MLflow tracking server or a local directory. It reconstructs the model architecture and loads the trained weights.

Parameters:

  • cfg_export (onnx_export_config_class, required): Export configuration containing:

    • mlflow_run_id: MLflow run identifier (if loading from MLflow)
    • model_directory: Local directory path (if loading locally)
    • mlflow_tracking_uri: MLflow tracking server URI
    • config_path: Optional custom configuration path
    • dataset_path: Optional custom dataset path
    • model_checkpoint_path: Optional specific checkpoint file

Returns:

  • dict: Dictionary containing:

    • 'model': Initialized BNODE model with loaded weights
    • 'cfg': OmegaConf configuration object
    • 'dataset_file': Opened HDF5 dataset file handle
    • 'dataset': Processed training dataset (stacked format)
    • 'temp_dir': Temporary directory path (if artifacts were downloaded, None otherwise)

Raises:

  • FileNotFoundError: If the dataset file, or an explicitly given model_checkpoint_path, cannot be found.
  • ValueError: If a custom config_path is given (not supported in this version).

Notes
  • CUDA is disabled (use_cuda=False) for CPU-based ONNX export
  • Normalization is not re-initialized (uses saved parameters)
  • If model_checkpoint_path is None, the last model_phase_*.pt file in sorted path order (searched recursively under the artifacts directory) is used
  • The function body does not call model.eval(); export_bnode switches the model to evaluation mode after loading
  • When the MLflow artifact URI is not a local file:// URI, artifacts are downloaded to the Hydra output folder
  • Downloaded artifacts are placed in a directory named 'mlflow_artifacts_{run_id}' inside the Hydra output folder
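The default checkpoint is chosen by sorting paths as strings, so phase numbers with different digit counts do not sort numerically: model_phase_10.pt sorts before model_phase_2.pt. It is not stated here whether training ever produces ten or more phases. If it does, a numeric key avoids picking the wrong file.

Both helpers below are hypothetical illustrations:

- The first reproduces what `sorted(...)[-1]` selects.
- The second parses the phase number from the file name.

```python
import re
from pathlib import Path
from typing import Sequence

def latest_checkpoint_lexicographic(paths: Sequence[Path]) -> Path:
    """What sorted(...)[-1] picks: the last path in string order."""
    return sorted(paths)[-1]

def latest_checkpoint_by_phase(paths: Sequence[Path]) -> Path:
    """Pick the checkpoint with the highest <n> in model_phase_<n>.pt."""
    def phase(p: Path) -> int:
        match = re.fullmatch(r"model_phase_(\d+)\.pt", Path(p).name)
        return int(match.group(1)) if match else -1  # non-matching names rank lowest
    return max(paths, key=phase)

ckpts = [Path("run/model_phase_1.pt"), Path("run/model_phase_2.pt"), Path("run/model_phase_10.pt")]
print(latest_checkpoint_lexicographic(ckpts).name)  # model_phase_2.pt
print(latest_checkpoint_by_phase(ckpts).name)       # model_phase_10.pt
```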
Source code in src/bnode_core/ode/bnode/bnode_export.py
def load_trained_latent_ode(cfg_export):
    """Load a trained BNODE model from MLflow or local directory.

    This function retrieves a trained BNODE model and its associated artifacts
    (configuration, dataset, checkpoint) from either an MLflow tracking server
    or a local directory. It reconstructs the model architecture and loads the
    trained weights.

    Args:
        cfg_export (onnx_export_config_class): Export configuration containing:

            - mlflow_run_id: MLflow run identifier (if loading from MLflow)
            - model_directory: Local directory path (if loading locally)
            - mlflow_tracking_uri: MLflow tracking server URI
            - config_path: Optional custom configuration path
            - dataset_path: Optional custom dataset path
            - model_checkpoint_path: Optional specific checkpoint file

    Returns:
        dict (dict): Dictionary containing:

            - 'model': Initialized BNODE model with loaded weights
            - 'cfg': OmegaConf configuration object
            - 'dataset_file': Opened HDF5 dataset file handle
            - 'dataset': Processed training dataset (stacked format)
            - 'temp_dir': Temporary directory path (if artifacts were downloaded, None otherwise)

    Raises:
        FileNotFoundError: If dataset file cannot be found at specified path.

    Notes:
        - CUDA is disabled (use_cuda=False) for CPU-based ONNX export
        - Normalization is not re-initialized (uses saved parameters)
        - Latest checkpoint is used if model_checkpoint_path is None
        - Model is loaded in evaluation mode for deterministic inference
        - When loading from MLflow remote server, artifacts are downloaded to Hydra output folder
        - Temporary artifact directory is named 'mlflow_artifacts_{run_id}' in Hydra output
    """
    temp_dir = None

    # get artifacts directory
    if cfg_export.mlflow_run_id is not None:
        mlflow.set_tracking_uri(cfg_export.mlflow_tracking_uri)
        mlflow_run = mlflow.get_run(cfg_export.mlflow_run_id)
        artifact_uri = mlflow_run.info.artifact_uri

        # Check if artifacts are on remote server (not local file://)
        if not artifact_uri.startswith('file://'):
            # Download artifacts to temporary directory in Hydra output folder
            hydra_output_dir = filepaths.dir_current_hydra_output()
            temp_dir = hydra_output_dir / f'mlflow_artifacts_{cfg_export.mlflow_run_id}'
            temp_dir.mkdir(parents=True, exist_ok=True)
            logging.info(f'Downloading MLflow artifacts from remote server to: {temp_dir}')

            # Download all artifacts
            mlflow.artifacts.download_artifacts(
                run_id=cfg_export.mlflow_run_id,
                dst_path=str(temp_dir)
            )
            dir_artifacts = temp_dir
            logging.info(f'Successfully downloaded artifacts to {temp_dir}')
        else:
            # Local MLflow artifacts
            dir_artifacts = Path(artifact_uri.replace('file://', ''))
    else:
        dir_artifacts = Path(cfg_export.model_directory)
    logging.info('Resolved artifacts uri as {}'.format(str(dir_artifacts)))

    # get config and dataset paths
    if cfg_export.config_path is None:
        path_config = dir_artifacts / '.hydra' / 'config_validated.yaml'
    else: 
        raise ValueError('Custom config_path is not supported in this version.')

    # dataset path
    if cfg_export.dataset_path is None:
        path_dataset = dir_artifacts / 'dataset.hdf5'
    else:
        if cfg_export.dataset_path.startswith('file://'):
            path_dataset = Path(cfg_export.dataset_path.replace('file://', ''))
        else:
            path_dataset = Path(cfg_export.dataset_path)


    # load config (and validate it using the dataclass?)
    with open(path_config) as file:
        cfg_dict = yaml.load(file, Loader=yaml.FullLoader)
        cfg = OmegaConf.create(cfg_dict)
        cfg.use_cuda = False
    logging.info('Loaded config of BNODE: {}'.format(str(cfg)))


    # load training dataset
    if path_dataset.is_file():
        dataset_file = h5py.File(path_dataset, 'r')
    else:
        raise FileNotFoundError(f'Dataset file {path_dataset} not found. Please provide a valid dataset path.')
    dataset = make_stacked_dataset(dataset_file, 'train')
    model = initialize_model(cfg, train_dataset=dataset, hdf5_dataset=None, 
                             initialize_normalization=False, model_type='bnode')

    # load latest checkpoint
    if cfg_export.model_checkpoint_path is None:
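        checkpoints = sorted(dir_artifacts.rglob('model_phase_*.pt'))
        if not checkpoints:
            # fail with a clear message instead of an IndexError
            raise FileNotFoundError(f'No model_phase_*.pt checkpoint found under {dir_artifacts}.')
        path_checkpoint = checkpoints[-1]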
    else:
        if cfg_export.model_checkpoint_path.startswith('file://'):
            path_checkpoint = Path(cfg_export.model_checkpoint_path.replace('file://', ''))
        else:
            path_checkpoint = Path(cfg_export.model_checkpoint_path)
        if not path_checkpoint.is_file():
            raise FileNotFoundError(f'Checkpoint file {path_checkpoint} not found. Please provide a valid checkpoint path.')

    model.load(path_checkpoint, device='cpu')
    return {'model': model, 'cfg': cfg, 'dataset_file': dataset_file, 'dataset': dataset, 'temp_dir': temp_dir}

export_example_io_data(res, inputs, path_example_io)

Export example input/output data for ONNX model validation.

Saves the inputs and outputs of a model component to an HDF5 file for later validation. This allows users to verify that their ONNX runtime produces identical results to the PyTorch reference implementation.

Parameters:

  • res (torch.Tensor, tuple/list, or dict, required): Model output(s). Can be:

    • torch.Tensor: Single output tensor
    • tuple/list: Multiple output tensors
    • dict: Named output tensors
  • inputs (dict, required): Dictionary of input tensors with their names as keys. Each value can be a torch.Tensor or None.
  • path_example_io (Path, required): Output path for the HDF5 file.
HDF5 Structure

  /inputs/
      <input_name_1>: dataset (array)
      <input_name_2>: dataset (array)
      ...
  /outputs/
      output: dataset (single output case)
      OR
      output_0, output_1, ...: datasets (multiple output case)
      OR
      <output_name_1>, <output_name_2>, ...: datasets (dict output case)

Notes
  • Tensors are automatically converted to NumPy arrays
  • None inputs are skipped
  • Outputs are organized based on their type (tensor/tuple/dict)
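The naming rules for the /inputs/ and /outputs/ groups can be stated compactly. The hypothetical helpers below compute which dataset names `export_example_io_data` writes for given inputs and a given result, without needing torch or h5py.

One simplification: in this sketch, any object that is not a dict, tuple or list stands in for a single tensor. The real function writes nothing for an output of an unrecognized type.

```python
from typing import Any, Dict, List

def input_dataset_names(inputs: Dict[str, Any]) -> List[str]:
    """Datasets written under /inputs/: every input except those that are None."""
    return [name for name, value in inputs.items() if value is not None]

def output_dataset_names(res: Any) -> List[str]:
    """Datasets written under /outputs/ for a single, tuple/list, or dict result."""
    if isinstance(res, dict):
        return list(res.keys())
    if isinstance(res, (tuple, list)):
        return [f"output_{i}" for i in range(len(res))]
    return ["output"]  # stands in for the single-tensor case
```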
Source code in src/bnode_core/ode/bnode/bnode_export.py
def export_example_io_data(res, inputs, path_example_io):
    """Export example input/output data for ONNX model validation.

    Saves the inputs and outputs of a model component to an HDF5 file for
    later validation. This allows users to verify that their ONNX runtime
    produces identical results to the PyTorch reference implementation.

    Args:
        res (different types): Model output(s). Can be:
            - torch.Tensor: Single output tensor
            - tuple/list: Multiple output tensors
            - dict: Named output tensors
        inputs (dict): Dictionary of input tensors with their names as keys.
            Each value can be a torch.Tensor or None.
        path_example_io (Path): Output path for the HDF5 file.

    HDF5 Structure:
        /inputs/
            <input_name_1>: dataset (array)
            <input_name_2>: dataset (array)
            ...
        /outputs/
            output: dataset (single output case)
            OR
            output_0, output_1, ...: datasets (multiple output case)
            OR
            <output_name_1>, <output_name_2>, ...: datasets (dict output case)

    Notes:
        - Tensors are automatically converted to NumPy arrays
        - None inputs are skipped
        - Outputs are organized based on their type (tensor/tuple/dict)
    """
    with h5py.File(path_example_io, 'w') as f:
        # Save inputs
        grp_in = f.create_group('inputs')
        for in_key, in_val in inputs.items():
            if isinstance(in_val, torch.Tensor):
                grp_in.create_dataset(in_key, data=in_val.detach().cpu().numpy())
            elif in_val is not None:
                grp_in.create_dataset(in_key, data=in_val)
        # Save outputs
        grp_out = f.create_group('outputs')
        if isinstance(res, torch.Tensor):
            grp_out.create_dataset('output', data=res.detach().cpu().numpy())
        elif isinstance(res, (tuple, list)):
            for i, out_val in enumerate(res):
                grp_out.create_dataset(f'output_{i}', data=out_val.detach().cpu().numpy())
        elif isinstance(res, dict):
            for out_key, out_val in res.items():
                grp_out.create_dataset(out_key, data=out_val.detach().cpu().numpy())

log_shapes_of_dict(d, name='')

Log the shapes of tensors in a data structure for debugging.

Recursively traverses dictionaries, lists, or tuples and logs the shapes of any PyTorch tensors found. Useful for debugging model I/O during export.

Source code in src/bnode_core/ode/bnode/bnode_export.py
def log_shapes_of_dict(d, name=''):
    """Log the shapes of tensors in a data structure for debugging.

    Recursively traverses dictionaries, lists, or tuples and logs the shapes
    of any PyTorch tensors found. Useful for debugging model I/O during export.
    """
    if name:
        logging.info(f"Shapes in {name}:")
    else:
        logging.info("Shapes in .... :")
    if isinstance(d, dict):
        for key, value in d.items():
            if isinstance(value, torch.Tensor):
                logging.info(f"\t{key}: {value.shape}")
            elif isinstance(value, (tuple, list)):
                logging.info(f"\t{key}: {[v.shape for v in value if isinstance(v, torch.Tensor)]}")
            else:
                logging.info(f"\t{key}: {value}")
    elif isinstance(d, (tuple, list)):
        for i, value in enumerate(d):
            if isinstance(value, torch.Tensor):
                logging.info(f"\t[{i}]: {value.shape}")
            elif isinstance(value, (tuple, list)):
                logging.info(f"\t[{i}]: {[v.shape for v in value if isinstance(v, torch.Tensor)]}")
            else:
                logging.info(f"\t[{i}]: {value}")
    else:
        logging.info(f"\t{type(d)}: {d}")

export_bnode(cfg_export: onnx_export_config_class)

Main function for BNODE model export to ONNX format.

This function orchestrates the complete export process:

  1. Loads trained BNODE model and configuration
  2. Extracts individual encoder, decoder, and ODE components
  3. Exports each component to ONNX format with dynamic batch dimensions
  4. Saves example input/output data for validation
  5. Exports configuration file for reference

The function is designed to be invoked via the uv run bnode_export command, which is registered in pyproject.toml. Hydra manages configuration loading and command-line argument parsing.

Parameters:

  • cfg_export (onnx_export_config_class, required): Export configuration managed by Hydra. Configuration can be specified via YAML files or command-line overrides.
Export Workflow
  1. Model Loading:

    • Retrieves model from MLflow or local directory
    • Loads configuration and dataset for normalization
    • Restores trained weights from checkpoint
  2. Component Extraction:

    • State encoder (always present)
    • Control encoder (if model uses controls)
    • Parameter encoder (if model uses parameters)
    • Latent ODE function
    • SSM parameter mapping (linear models with parameters only)
    • Decoder (states and/or outputs)
  3. ONNX Export:

    • Each component is exported separately
    • Dynamic batch dimensions allow flexible inference
    • Input/output names are explicitly defined
    • Example I/O saved for each component
  4. Artifact Organization:

    • All exports saved to output_dir or Hydra output
    • Configuration exported as bnode_config.yaml
    • Example data in *_example_io.hdf5 files
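The inputs that the exported latent_ode.onnx expects depend on the same model flags. The sketch below reproduces how `export_bnode` assembles the latent ODE inputs: it keeps only the non-None entries, in the order lat_states, lat_parameters, lat_controls, A_from_param, B_from_param. It also derives the SSM output names.

The function names are hypothetical. The sketch assumes that the filtering loop at the end of the listed source preserves the dictionary order.

```python
from typing import List

def latent_ode_input_names(include_parameters: bool, include_controls: bool,
                           linear: bool) -> List[str]:
    """Input names of latent_ode.onnx for the given latent ODE flags."""
    candidates = {
        "lat_states": True,
        "lat_parameters": include_parameters,
        "lat_controls": include_controls,
        "A_from_param": include_parameters and linear,
        "B_from_param": include_parameters and linear and include_controls,
    }
    return [name for name, used in candidates.items() if used]  # dicts keep insertion order

def ssm_output_names(include_parameters: bool, include_controls: bool,
                     linear: bool) -> List[str]:
    """Output names of latent_ode_ssm_from_param.onnx (empty if it is not exported)."""
    if not (include_parameters and linear):
        return []
    return ["A", "B"] if include_controls else ["A"]
```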
Command-Line Usage

Export from MLflow:

uv run bnode_export mlflow_run_id=<run_id> \
                    mlflow_tracking_uri=<uri> \
                    output_dir=<output_path>

Export from local directory:

uv run bnode_export model_directory=<path> \
                    output_dir=<output_path>

With custom checkpoint:

uv run bnode_export mlflow_run_id=<run_id> \
                    model_checkpoint_path=<checkpoint.pt> \
                    dataset_path=<dataset.hdf5>
Exported Files
  • encoder_states.onnx: State encoder
  • encoder_controls.onnx: Control encoder (if applicable)
  • encoder_parameters.onnx: Parameter encoder (if applicable)
  • latent_ode.onnx: Latent dynamics function
  • latent_ode_ssm_from_param.onnx: SSM mapping (linear models with parameters)
  • decoder.onnx: Decoder network
  • encoder_*_example_io.hdf5: Example encoder I/O
  • latent_ode_example_io.hdf5: Example ODE I/O
  • decoder_example_io.hdf5: Example decoder I/O
  • bnode_config.yaml: Complete model configuration
ONNX Model Specifications

All exported models include:

  • Dynamic axes: Batch dimension (axis 0) is dynamic
  • Input names: Descriptive names for each input tensor
  • Output names: Descriptive names for each output tensor
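The encoders follow a simple naming and dynamic-axis pattern. This hypothetical helper builds the output names and the `dynamic_axes` mapping the same way `export_bnode` does. The results are then passed to `torch.onnx.export(..., input_names=..., output_names=..., dynamic_axes=..., dynamo=False)`. The helper only constructs these arguments and does not call PyTorch.

```python
from typing import Dict, List, Tuple

def encoder_io_spec(component: str,
                    input_names: List[str]) -> Tuple[List[str], Dict[str, Dict[int, str]]]:
    """Output names and dynamic_axes for encoder_<component>.onnx."""
    output_names = [f"latent_{component}_mu", f"latent_{component}_logvar"]
    # Axis 0 of every input and output is the (dynamic) batch dimension.
    dynamic_axes = {name: {0: "batch_size"} for name in input_names + output_names}
    return output_names, dynamic_axes

names, axes = encoder_io_spec("states", ["x", "params"])
print(names)  # ['latent_states_mu', 'latent_states_logvar']
```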

Encoder outputs:

  • latent_<type>_mu: Mean of latent distribution
  • latent_<type>_logvar: Log-variance of latent distribution

ODE outputs:

  • lat_states_mu_dot: State derivative (constant variance)
  • concat(lat_states_mu_dot,lat_states_logvar_dot): Combined derivatives (dynamic variance)
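With dynamic variance, downstream code has to split the combined derivative into its two parts. The sketch below assumes that the two derivatives are concatenated along the last axis with the mean derivative first. This assumption rests only on the output name concat(lat_states_mu_dot,lat_states_logvar_dot). `split_mu_logvar_dot` is a hypothetical helper that operates on a batch of rows.

```python
from typing import List, Sequence, Tuple

def split_mu_logvar_dot(combined: Sequence[Sequence[float]],
                        n_latent: int) -> Tuple[List[List[float]], List[List[float]]]:
    """Split rows of length 2 * n_latent into (mu_dot, logvar_dot) batches."""
    mu_dot, logvar_dot = [], []
    for row in combined:
        if len(row) != 2 * n_latent:
            raise ValueError(f"expected rows of length {2 * n_latent}, got {len(row)}")
        mu_dot.append(list(row[:n_latent]))      # assumed: first half is the mean derivative
        logvar_dot.append(list(row[n_latent:]))  # assumed: second half is the log-variance derivative
    return mu_dot, logvar_dot
```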

Decoder outputs:

  • states: Reconstructed physical states
  • outputs: Reconstructed system outputs (if applicable)
Notes
  • Model is set to evaluation mode before export
  • All computations performed on CPU (use_cuda=False)
  • Normalization parameters are embedded in exported models
  • Example I/O files can be used to validate ONNX Runtime inference
  • Hydra output directory is used if output_dir not specified
See Also
  • load_trained_latent_ode: Model loading function
  • export_example_io_data: Example I/O export function
  • bnode_core.ode.trainer.initialize_model: Model initialization
Source code in src/bnode_core/ode/bnode/bnode_export.py
def export_bnode(cfg_export: onnx_export_config_class):
    """Main function for BNODE model export to ONNX format.

    This function orchestrates the complete export process:

    1. Loads trained BNODE model and configuration
    2. Extracts individual encoder, decoder, and ODE components
    3. Exports each component to ONNX format with dynamic batch dimensions
    4. Saves example input/output data for validation
    5. Exports configuration file for reference

    The function is designed to be invoked via the ``uv run bnode_export`` command,
    which is registered in ``pyproject.toml``. Hydra manages configuration loading
    and command-line argument parsing.

    Args:
        cfg_export (onnx_export_config_class): Export configuration managed by Hydra.
            Configuration can be specified via YAML files or command-line overrides.

    Export Workflow:
        1. **Model Loading**: 
            - Retrieves model from MLflow or local directory
            - Loads configuration and dataset for normalization
            - Restores trained weights from checkpoint

        2. **Component Extraction**:
            - State encoder (always present)
            - Control encoder (if model uses controls)
            - Parameter encoder (if model uses parameters)
            - Latent ODE function
            - SSM parameter mapping (linear models only)
            - Decoder (states and/or outputs)

        3. **ONNX Export**:
            - Each component is exported separately
            - Dynamic batch dimensions allow flexible inference
            - Input/output names are explicitly defined
            - Example I/O saved for each component

        4. **Artifact Organization**:
            - All exports saved to output_dir or Hydra output
            - Configuration exported as ``bnode_config.yaml``
            - Example data in ``*_example_io.hdf5`` files

    Command-Line Usage:
        Export from MLflow::

            uv run bnode_export mlflow_run_id=<run_id> \\
                                mlflow_tracking_uri=<uri> \\
                                output_dir=<output_path>

        Export from local directory::

            uv run bnode_export model_directory=<path> \\
                                output_dir=<output_path>

        With custom checkpoint::

            uv run bnode_export mlflow_run_id=<run_id> \\
                                model_checkpoint_path=<checkpoint.pt> \\
                                dataset_path=<dataset.hdf5>

    Exported Files:
        - ``encoder_states.onnx``: State encoder
        - ``encoder_controls.onnx``: Control encoder (if applicable)
        - ``encoder_parameters.onnx``: Parameter encoder (if applicable)
        - ``latent_ode.onnx``: Latent dynamics function
        - ``latent_ode_ssm_from_param.onnx``: SSM mapping (linear models)
        - ``decoder.onnx``: Decoder network
        - ``encoder_*_example_io.hdf5``: Example encoder I/O
        - ``latent_ode_example_io.hdf5``: Example ODE I/O
        - ``decoder_example_io.hdf5``: Example decoder I/O
        - ``bnode_config.yaml``: Complete model configuration

    ONNX Model Specifications:
        All exported models include:

        - **Dynamic axes**: Batch dimension (axis 0) is dynamic
        - **Input names**: Descriptive names for each input tensor
        - **Output names**: Descriptive names for each output tensor

        Encoder outputs:

        - ``latent_<type>_mu``: Mean of latent distribution
        - ``latent_<type>_logvar``: Log-variance of latent distribution

        ODE outputs:

        - ``lat_states_mu_dot``: State derivative (constant variance)
        - ``concat(lat_states_mu_dot,lat_states_logvar_dot)``: Combined
            derivatives (dynamic variance)

        Decoder outputs:

        - ``states``: Reconstructed physical states
        - ``outputs``: Reconstructed system outputs (if applicable)

    Notes:
        - Model is set to evaluation mode before export
        - All computations performed on CPU (use_cuda=False)
        - Normalization parameters are embedded in exported models
        - Example I/O files can be used to validate ONNX Runtime inference
        - Hydra output directory is used if output_dir not specified

    See Also:
        - ``load_trained_latent_ode``: Model loading function
        - ``export_example_io_data``: Example I/O export function
        - ``bnode_core.ode.trainer.initialize_model``: Model initialization
    """
    logging.info('Exporting BNODE using the following config {}'.format(str(cfg_export)))

    # load model
    res = load_trained_latent_ode(cfg_export)
    model, cfg, dataset_file, dataset = res['model'], res['cfg'], res['dataset_file'], res['dataset']
    temp_dir = res['temp_dir']
    model.eval()

    # determine output dir
    dir_output = Path(cfg_export.output_dir) if cfg_export.output_dir is not None else filepaths.dir_current_hydra_output()

    # export bnode config
    path_config = dir_output / 'bnode_config.yaml'
    logging.info(f'Exporting BNODE config to {path_config}')
    path_config.parent.mkdir(parents=True, exist_ok=True)
    with open(path_config, 'w') as f:
        yaml.dump(OmegaConf.to_container(cfg, resolve=True), f, default_flow_style=False)

    # get test points for graph construction
    test_state = dataset[0]['states'][:,0].unsqueeze(0)
    test_control = dataset[0]['controls'][:,0].unsqueeze(0) if model.include_controls else None
    test_parameters = dataset[0]['parameters'].unsqueeze(0) if model.include_parameters else None

    # export the encoders
    encoders = {'states': model.state_encoder, 
                'controls': model.controls_encoder if model.include_controls else None,
                'parameters': model.parameter_encoder if model.include_params_encoder else None
            }
    # construct test inputs for graph construction
    inputs_dict = {
        'states': {'x': test_state},
        'controls': {'x': test_control} if model.params_to_control_encoder is False else {'x': test_control, 'params': test_parameters},
        'parameters': {'x': test_parameters},
    }
    # handling of additional inputs to state encoder
    if model.params_to_state_encoder is True:
        inputs_dict['states']['params'] = test_parameters
    if model.controls_to_state_encoder is True:
        inputs_dict['states']['controls'] = test_control

    latents_dict = {}
    for key, encoder in encoders.items():
        if encoder is not None:
            path_encoder = dir_output / f'encoder_{key}.onnx'
            logging.info(f'Exporting {key} encoder to {path_encoder}')
            # test model
            log_shapes_of_dict(inputs_dict[key], f'Inputs for {key} encoder')
            res = encoder(**inputs_dict[key])
            log_shapes_of_dict(res, f'Outputs of {key} encoder')
            logging.info(f'Test result {res}')
            # export
            input_names = list(inputs_dict[key].keys())
            output_names=['latent_' + key + '_mu', 'latent_' + key + '_logvar']
            dynamic_axes={}
            for name in input_names:
                dynamic_axes[name] = {0: 'batch_size'}
            for name in output_names:
                dynamic_axes[name] = {0: 'batch_size'}
            # Use legacy TorchScript-based exporter for better stability
            torch.onnx.export(encoder, 
                              args=(),
                              kwargs=inputs_dict[key],
                              f=path_encoder, 
                              input_names=input_names,
                              output_names=output_names,
                              dynamic_axes=dynamic_axes,
                              dynamo=False
            )
            logging.info(f'Exported {key} encoder successfully')
            # export also example io
            path_example_io = dir_output / f'encoder_{key}_example_io.hdf5'
            export_example_io_data(res, inputs_dict[key], path_example_io)
            # save latent variable
            latents_dict[key] = res[0] # the first is mu

    # export ssm from parameters model and get A_from_param and B_from_param for the latent ODE function
    ode = model.latent_ode_func
    if ode.include_parameters is True and ode.linear is True:
        # this is only possible if the model is linear and has parameters
        logging.info('Exporting SSM from parameters')
        ssm = model.latent_ode_func.ssm_from_param
        path_ssm = dir_output / 'latent_ode_ssm_from_param.onnx'
        logging.info(f'Export latent ODE SSM from parameters to {path_ssm}')
        # construct test input
        inputs = {
            'lat_parameters': latents_dict['parameters'],
        }
        # test model
        log_shapes_of_dict(inputs, 'Inputs for latent ODE SSM from parameters')
        res = ssm(**inputs)
        log_shapes_of_dict(res, 'Outputs of latent ODE SSM from parameters')
        logging.info(f'Test result {res}')
        # export
        input_names=['lat_parameters']
        output_names=['A', 'B'] if ode.include_controls else ['A']
        dynamic_axes={}
        for name in input_names:
            dynamic_axes[name] = {0: 'batch_size'}
        for name in output_names:
            dynamic_axes[name] = {0: 'batch_size'}
        torch.onnx.export(ssm, 
                          args=(),
                          kwargs=inputs,
                          f=path_ssm, 
                          input_names=input_names, 
                          output_names=output_names, 
                          dynamic_axes=dynamic_axes,
                          dynamo=False)
        logging.info(f'Exported latent ODE SSM from parameters successfully')
        # get A_from_param and B_from_param
        if ode.include_controls:
            A_from_param, B_from_param = res
        else:
            A_from_param = res
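    # Illustrative sketch (assumption, not the verified model implementation):
    # for a linear latent ODE, the SSM exported above yields a per-sample state
    # matrix A (and an input matrix B when controls are included). Assuming the
    # standard state-space form dz/dt = A z + B u, one sample's latent derivative
    # could be evaluated in plain Python as follows. `_linear_latent_rhs` is a
    # hypothetical name, and A/B are taken as nested lists for a single sample.
    def _linear_latent_rhs(A, z, B=None, u=None):
        """Return A @ z, plus B @ u when both B and u are given."""
        dz = [sum(a_ij * z_j for a_ij, z_j in zip(row, z)) for row in A]
        if B is not None and u is not None:
            bu = [sum(b_ij * u_j for b_ij, u_j in zip(row, u)) for row in B]
            dz = [d + e for d, e in zip(dz, bu)]
        return dz
    # Usage: _linear_latent_rhs([[0.0, 1.0], [-1.0, 0.0]], [1.0, 2.0], [[1.0], [0.0]], [0.5])
    # -> [2.5, -1.0]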

    # export the latent ode function
    path_ode = dir_output / 'latent_ode.onnx'
    logging.info(f'Export latent ODE to {path_ode}')
    # construct test input
    inputs = {
        'lat_states': latents_dict['states'],
        'lat_parameters': latents_dict['parameters'] if ode.include_parameters is True else None,
        'lat_controls': latents_dict['controls'] if ode.include_controls is True else None,
        'A_from_param': A_from_param if ode.include_parameters is True and ode.linear else None,
        'B_from_param': B_from_param if ode.include_parameters is True and ode.linear and ode.include_controls else None,
    }
    # test model
    log_shapes_of_dict(inputs, 'Inputs for latent ODE')
    res = ode(**inputs)
    log_shapes_of_dict(res, 'Outputs of latent ODE')
    logging.info(f'Test result {res}')
    # export
    input_names=[]
    for key in inputs.keys():
        if inputs[key] is not None:
            input_names.append(key)
    dynamic_axes={}
    for name in input_names:
        dynamic_axes[name] = {0: 'batch_size'}
    if model.lat_ode_type in ('variance_constant', 'vanilla'):
        output_names = ['lat_states_mu_dot']
        dynamic_axes['lat_states_mu_dot'] = {0: 'batch_size'}
    elif model.lat_ode_type == 'variance_dynamic':
        output_names = ['concat(lat_states_mu_dot,lat_states_logvar_dot)']
        dynamic_axes['concat(lat_states_mu_dot,lat_states_logvar_dot)'] = {0: 'batch_size'}
    else:
        raise ValueError(f'Unsupported latent ODE type for ONNX export: {model.lat_ode_type}')
    # Drop inputs that are None; only the remaining tensors are passed as keyword arguments
    filtered_inputs = {k: v for k, v in inputs.items() if v is not None}
    torch.onnx.export(ode, 
                      args=(),
                      kwargs=filtered_inputs,
                      f=path_ode, 
                      input_names=input_names, 
                      output_names=output_names, 
                      dynamic_axes=dynamic_axes,
                      dynamo=False)
    logging.info(f'Exported latent ODE successfully')
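    # Illustrative sketch (assumption): with lat_ode_type == 'variance_dynamic'
    # the exported graph returns a single output named
    # 'concat(lat_states_mu_dot,lat_states_logvar_dot)'. Assuming the two halves
    # are concatenated along the last axis in the order given by the name
    # (mu_dot first), an external pipeline could split one sample's output row
    # like this. `_split_mu_logvar_dot` and `n_latent` are hypothetical names.
    def _split_mu_logvar_dot(row, n_latent):
        """Split a concatenated derivative row into (mu_dot, logvar_dot)."""
        if len(row) != 2 * n_latent:
            raise ValueError(f'expected {2 * n_latent} values, got {len(row)}')
        return row[:n_latent], row[n_latent:]
    # Usage: _split_mu_logvar_dot([0.1, 0.2, -0.3, -0.4], 2)
    # -> ([0.1, 0.2], [-0.3, -0.4])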
    # export also example io
    path_example_io = dir_output / f'latent_ode_example_io.hdf5'
    export_example_io_data(res, inputs, path_example_io)
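    # Illustrative sketch (assumption): latent_ode.onnx contains only the
    # right-hand side dz/dt = f(z), so an external inference pipeline must
    # supply its own integrator. Explicit Euler is used here purely for
    # brevity; it is not necessarily the solver used during training. In a
    # deployment, `rhs` might wrap an onnxruntime InferenceSession call on
    # latent_ode.onnx with fixed latent parameters/controls (not shown).
    # `_euler_integrate` is a hypothetical name.
    def _euler_integrate(rhs, z0, dt, n_steps):
        """Integrate dz/dt = rhs(z) with explicit Euler and return all states."""
        trajectory = [list(z0)]
        for _ in range(n_steps):
            z = trajectory[-1]
            dz = rhs(z)
            trajectory.append([zi + dt * dzi for zi, dzi in zip(z, dz)])
        return trajectory
    # Usage: _euler_integrate(lambda z: [1.0] * len(z), [0.0, 0.0], 0.5, 2)
    # -> [[0.0, 0.0], [0.5, 0.5], [1.0, 1.0]]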

    # export the decoder
    # TODO: How should the split return be handled? We now have to take the first n elements for the states, etc. Implement a separate function for this, or add an optional argument?
    decoder = model.decoder
    decoder.onnx_export = True  # disable concatenation of outputs for ONNX export
    path_decoder = dir_output / 'decoder.onnx'
    logging.info(f'Export decoder to {path_decoder}')
    # construct test input
    inputs = {
        'lat_state': latents_dict['states'],
        'lat_parameters': latents_dict['parameters'] if decoder.include_parameters is True else None,
        'lat_controls': latents_dict['controls'] if decoder.include_controls is True else None,
    }
    # test model
    log_shapes_of_dict(inputs, 'Inputs for decoder')
    res = decoder(**inputs)
    log_shapes_of_dict(res, 'Outputs of decoder')
    logging.info(f'Test result {res}')
    # export
    input_names = []
    for key in inputs.keys():
        if inputs[key] is not None:
            input_names.append(key)
    if decoder.include_outputs and decoder.include_states:
        output_names = ['states', 'outputs']
    elif decoder.include_outputs:
        output_names = ['outputs']
    elif decoder.include_states:
        output_names = ['states']
    else:
        raise ValueError('Decoder includes neither states nor outputs; nothing to export')
    dynamic_axes={}
    for name in input_names:
        dynamic_axes[name] = {0: 'batch_size'}
    for name in output_names:
        dynamic_axes[name] = {0: 'batch_size'}
    # Drop inputs that are None; only the remaining tensors are passed as keyword arguments
    filtered_inputs = {k: v for k, v in inputs.items() if v is not None}
    torch.onnx.export(decoder, 
                      args=(),
                      kwargs=filtered_inputs,
                      f=path_decoder, 
                      input_names=input_names, 
                      output_names=output_names, 
                      dynamic_axes=dynamic_axes,
                      dynamo=False)
    logging.info(f'Exported decoder successfully')
    # export also example io
    path_example_io = dir_output / f'decoder_example_io.hdf5'
    export_example_io_data(res, inputs, path_example_io)
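    # Illustrative sketch (hypothetical helper, not called by this function):
    # the *_example_io.hdf5 files written above hold reference inputs and
    # outputs, so a downstream runtime can check that it reproduces them. Once
    # an output has been read and flattened (HDF5 reading omitted), it could be
    # compared with a tolerance like this. The tolerance values are assumptions.
    def _outputs_match(reference, candidate, rel_tol=1e-5, abs_tol=1e-6):
        """Return True if two flat sequences agree elementwise within tolerance."""
        import math
        if len(reference) != len(candidate):
            return False
        return all(math.isclose(r, c, rel_tol=rel_tol, abs_tol=abs_tol)
                   for r, c in zip(reference, candidate))
    # Usage: _outputs_match([1.0, 2.0], [1.0, 2.0000001]) -> True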

    if temp_dir is not None:
        logging.info(f'Cleaning up temporary directory: {temp_dir}')
        shutil.rmtree(temp_dir, ignore_errors=True)
        logging.info('Temporary directory cleaned up successfully')

    # copy the current hydra folder to the output for reference
    dir_hydra_current = filepaths.dir_current_hydra_output()
    dir_hydra_output_copy = dir_output / 'hydra'
    if not dir_hydra_current.is_dir():
        logging.warning(f'Current Hydra directory {dir_hydra_current} not found and could not be copied.')
    elif dir_hydra_output_copy.absolute() != dir_hydra_current.absolute():
        logging.info(f'Copying current Hydra directory {dir_hydra_current} to output {dir_hydra_output_copy}')
        shutil.copytree(dir_hydra_current, dir_hydra_output_copy, dirs_exist_ok=True)
        logging.info('Hydra directory copied successfully')