BNODE Model Export to ONNX Format
This module provides functionality to export trained Balanced Neural ODE (BNODE)
models to ONNX format for deployment in production environments. The export process
decomposes the BNODE architecture into individual ONNX components that can be
integrated into external inference pipelines.
Warning
This documentation is AI generated and may contain inaccuracies. Please verify.
Overview
BNODE models consist of multiple neural network components organized in a
variational autoencoder (VAE) structure with a latent ODE:
- Encoders: Transform high-dimensional inputs to latent representations
    - State encoder: Maps physical states to latent state space
    - Control encoder: Maps control inputs to latent control space (optional)
    - Parameter encoder: Maps system parameters to latent parameter space (optional)
- Latent ODE: Defines dynamics in the learned latent space
    - Can be linear (SSM-based) or nonlinear (neural network)
    - Supports variance propagation (constant or dynamic)
    - Optionally conditioned on latent controls and parameters
- Decoder: Reconstructs physical quantities from latent states
    - Maps latent states back to physical state space
    - Optionally reconstructs system outputs
    - Can incorporate latent parameters and controls
Export Process
The export workflow involves:
- Model Loading: Retrieve trained model from MLflow or local directory
- Configuration: Load training configuration and dataset for normalization
- Component Separation: Extract individual neural network modules
- ONNX Conversion: Export each component with dynamic batch dimensions
- Example I/O: Save sample inputs/outputs in HDF5 for validation
Exported Artifacts
For each BNODE model, the following files are generated:
- encoder_states.onnx: State encoder neural network
- encoder_controls.onnx: Control encoder (if applicable)
- encoder_parameters.onnx: Parameter encoder (if applicable)
- latent_ode.onnx: Latent ODE function
- latent_ode_ssm_from_param.onnx: SSM parameter mapping (linear models only)
- decoder.onnx: Decoder neural network
- *_example_io.hdf5: Example input/output data for each component
- bnode_config.yaml: Complete model configuration
Configuration
The export process is configured using Hydra with the onnx_export_config_class
dataclass. Configuration can be provided via:
- YAML config file (conf/onnx_export.yaml)
- Command-line overrides
- Programmatic instantiation
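As a concrete illustration, a minimal conf/onnx_export.yaml could look like the sketch below. The field names come from the configuration reference that follows; the exact file layout and the values shown are assumptions to be checked against onnx_export_config_class.

```yaml
# Hypothetical conf/onnx_export.yaml sketch; verify field names and
# defaults against onnx_export_config_class before use.
mlflow_run_id: null          # set this OR model_directory
model_directory: ./outputs/2024-11-07/10-30-45
mlflow_tracking_uri: ./mlruns
model_checkpoint_path: null  # null = latest checkpoint
dataset_path: null           # null = dataset.hdf5 from the artifacts
output_dir: ./exports/my_model
```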
Required Configuration Fields
Either mlflow_run_id OR model_directory must be specified:
- mlflow_run_id (str): MLflow run ID to retrieve model from tracking server
- model_directory (str): Local path to model artifacts directory
Optional Configuration Fields
- mlflow_tracking_uri (str): MLflow server URI (default: local ./mlruns)
- model_checkpoint_path (str): Specific checkpoint file (default: latest)
- config_path (str): Custom config file (default: .hydra/config_validated.yaml)
- dataset_path (str): Dataset for normalization (default: dataset.hdf5)
- output_dir (str): Export destination (default: Hydra output directory)
Dataset Requirements
A dataset in HDF5 format is required for model initialization and normalization.
The dataset must contain:
- Structure: Training/validation/test splits with trajectories
- Variables: States, controls, parameters, outputs (as applicable)
- Format: Shape (n_trajectories, n_variables, n_timesteps)
- Source: Generated by the data_generation module or provided externally
The dataset is used to:
1. Initialize model dimensions and normalization layers
2. Provide example inputs for ONNX graph tracing
3. Validate exported models against known data
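To illustrate the expected layout, the sketch below writes a toy HDF5 file with the (n_trajectories, n_variables, n_timesteps) shape described above. The group and variable names ('train', 'states', 'controls') are assumptions; verify them against files produced by the data_generation module.

```python
import h5py
import numpy as np

# Toy dataset sketch: 10 trajectories, 3 state variables, 100 timesteps.
# Group/variable names are assumptions, not the module's guaranteed schema.
n_traj, n_states, n_steps = 10, 3, 100
with h5py.File('dataset_sketch.hdf5', 'w') as f:
    grp = f.create_group('train')
    grp.create_dataset('states', data=np.random.randn(n_traj, n_states, n_steps))
    grp.create_dataset('controls', data=np.random.randn(n_traj, 1, n_steps))

with h5py.File('dataset_sketch.hdf5', 'r') as f:
    print(f['train']['states'].shape)  # (10, 3, 100)
```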
Typical Usage Examples
Example 1: Export from MLflow Run:
# Export model from MLflow tracking server
uv run bnode_export mlflow_run_id=abc123def456 \
mlflow_tracking_uri=http://localhost:5000 \
output_dir=./exports/my_model
Example 2: Export from Local Directory:
# Export model from local artifacts
uv run bnode_export model_directory=./outputs/2024-11-07/10-30-45 \
output_dir=./exports/my_model
Example 3: Custom Checkpoint and Dataset:
# Specify custom checkpoint and dataset paths
uv run bnode_export mlflow_run_id=abc123def456 \
model_checkpoint_path=./checkpoints/model_phase_3.pt \
dataset_path=./data/custom_dataset.hdf5 \
output_dir=./exports/custom_export
Example 4: Export with Hydra Multirun:
# Export multiple models in parallel
uv run bnode_export --multirun \
mlflow_run_id=run1,run2,run3 \
output_dir=./exports
ONNX Export Backend Selection
This module uses PyTorch's ONNX export with dynamo=False to leverage the
legacy TorchScript-based exporter. This choice is made for the following reasons:
Why Legacy Exporter (dynamo=False)?
- Stability: The TorchScript-based exporter is mature and battle-tested with complex neural network architectures, including custom layers and control flow.
- Complex Model Support: BNODE models contain:
    - Conditional logic (variance modes, parameter inclusion)
    - Custom normalization layers with stateful initialization
    - Multiple encoder/decoder components with optional inputs
    - Dynamic control flow based on model configuration
  The legacy exporter handles these patterns more reliably.
- Keyword Arguments: Models with optional keyword arguments (like latent parameters and controls) require the kwargs parameter, which works seamlessly with the legacy exporter.
About the New Exporter (dynamo=True)
Starting in PyTorch 2.9, dynamo=True is the default, using the new
torch.export-based exporter with ONNXScript. This provides:
- Better support for dynamic shapes and LLMs
- More modern export pipeline
- Enhanced control flow handling
However, it may encounter issues with:
- Complex conditional logic
- Custom layer implementations
- Stateful modules (like normalization with deferred initialization)
Migration Path
If you wish to try the new exporter:
- Change dynamo=False to dynamo=True in all torch.onnx.export calls
- Test thoroughly with your specific model configurations
- Address any tracing warnings or errors
- Validate exported ONNX models match PyTorch reference outputs
The legacy exporter will remain supported for the foreseeable future, so there's
no immediate need to migrate unless you require new exporter-specific features.
Deprecation Warnings
You may see deprecation warnings when using dynamo=False. These are
informational and do not affect functionality. The warnings encourage trying
the new exporter but the legacy path remains fully supported.
Notes
- All exported ONNX models use dynamic batch dimensions for flexible inference
- Models are exported in evaluation mode (dropout/batch norm frozen)
- Normalization parameters are embedded in the exported models
- Linear latent ODEs export separate SSM parameter mapping for efficiency
- Export uses the kwargs parameter to handle optional model inputs correctly
See Also
bnode_core.ode.trainer : Training pipeline for BNODE models
bnode_core.config.onnx_export_config_class : Configuration dataclass
bnode_core.ode.bnode : BNODE model architecture definitions
load_trained_latent_ode(cfg_export)
Load a trained BNODE model from MLflow or local directory.
This function retrieves a trained BNODE model and its associated artifacts
(configuration, dataset, checkpoint) from either an MLflow tracking server
or a local directory. It reconstructs the model architecture and loads the
trained weights.
Parameters:
- cfg_export (onnx_export_config_class, required): Export configuration containing:
    - mlflow_run_id: MLflow run identifier (if loading from MLflow)
    - model_directory: Local directory path (if loading locally)
    - mlflow_tracking_uri: MLflow tracking server URI
    - config_path: Optional custom configuration path
    - dataset_path: Optional custom dataset path
    - model_checkpoint_path: Optional specific checkpoint file
Returns:
- dict: Dictionary containing:
    - 'model': Initialized BNODE model with loaded weights
    - 'cfg': OmegaConf configuration object
    - 'dataset_file': Opened HDF5 dataset file handle
    - 'dataset': Processed training dataset (stacked format)
    - 'temp_dir': Temporary directory path if artifacts were downloaded, None otherwise
Raises:
- FileNotFoundError: If the dataset file cannot be found at the specified path.
Notes
- CUDA is disabled (use_cuda=False) for CPU-based ONNX export
- Normalization is not re-initialized (uses saved parameters)
- Latest checkpoint is used if model_checkpoint_path is None
- Model is loaded in evaluation mode for deterministic inference
- When loading from MLflow remote server, artifacts are downloaded to Hydra output folder
- Temporary artifact directory is named 'mlflow_artifacts_{run_id}' in Hydra output
Source code in src/bnode_core/ode/bnode/bnode_export.py
def load_trained_latent_ode(cfg_export):
"""Load a trained BNODE model from MLflow or local directory.
This function retrieves a trained BNODE model and its associated artifacts
(configuration, dataset, checkpoint) from either an MLflow tracking server
or a local directory. It reconstructs the model architecture and loads the
trained weights.
Args:
cfg_export (onnx_export_config_class): Export configuration containing:
- mlflow_run_id: MLflow run identifier (if loading from MLflow)
- model_directory: Local directory path (if loading locally)
- mlflow_tracking_uri: MLflow tracking server URI
- config_path: Optional custom configuration path
- dataset_path: Optional custom dataset path
- model_checkpoint_path: Optional specific checkpoint file
Returns:
dict (dict): Dictionary containing:
- 'model': Initialized BNODE model with loaded weights
- 'cfg': OmegaConf configuration object
- 'dataset_file': Opened HDF5 dataset file handle
- 'dataset': Processed training dataset (stacked format)
- 'temp_dir': Temporary directory path (if artifacts were downloaded, None otherwise)
Raises:
FileNotFoundError: If dataset file cannot be found at specified path.
Notes:
- CUDA is disabled (use_cuda=False) for CPU-based ONNX export
- Normalization is not re-initialized (uses saved parameters)
- Latest checkpoint is used if model_checkpoint_path is None
- Model is loaded in evaluation mode for deterministic inference
- When loading from MLflow remote server, artifacts are downloaded to Hydra output folder
- Temporary artifact directory is named 'mlflow_artifacts_{run_id}' in Hydra output
"""
temp_dir = None
# get artifacts directory
if cfg_export.mlflow_run_id is not None:
mlflow.set_tracking_uri(cfg_export.mlflow_tracking_uri)
mlflow_run = mlflow.get_run(cfg_export.mlflow_run_id)
artifact_uri = mlflow_run.info.artifact_uri
# Check if artifacts are on remote server (not local file://)
if not artifact_uri.startswith('file://'):
# Download artifacts to temporary directory in Hydra output folder
hydra_output_dir = filepaths.dir_current_hydra_output()
temp_dir = hydra_output_dir / f'mlflow_artifacts_{cfg_export.mlflow_run_id}'
temp_dir.mkdir(parents=True, exist_ok=True)
logging.info(f'Downloading MLflow artifacts from remote server to: {temp_dir}')
# Download all artifacts
mlflow.artifacts.download_artifacts(
run_id=cfg_export.mlflow_run_id,
dst_path=str(temp_dir)
)
dir_artifacts = temp_dir
logging.info(f'Successfully downloaded artifacts to {temp_dir}')
else:
# Local MLflow artifacts
dir_artifacts = Path(artifact_uri.replace('file://', ''))
else:
dir_artifacts = Path(cfg_export.model_directory)
logging.info('Resolved artifacts uri as {}'.format(str(dir_artifacts)))
# get config and dataset paths
if cfg_export.config_path is None:
path_config = dir_artifacts / '.hydra' / 'config_validated.yaml'
else:
raise ValueError('Custom config_path is not supported in this version.')
# dataset path
if cfg_export.dataset_path is None:
path_dataset = dir_artifacts / 'dataset.hdf5'
else:
if cfg_export.dataset_path.startswith('file://'):
path_dataset = Path(cfg_export.dataset_path.replace('file://', ''))
else:
path_dataset = Path(cfg_export.dataset_path)
# load config (and validate it using the dataclass?)
with open(path_config) as file:
cfg_dict = yaml.load(file, Loader=yaml.FullLoader)
cfg = OmegaConf.create(cfg_dict)
cfg.use_cuda = False
logging.info('Loaded config of BNODE: {}'.format(str(cfg)))
# load training dataset
if path_dataset.is_file():
dataset_file = h5py.File(path_dataset, 'r')
else:
raise FileNotFoundError(f'Dataset file {path_dataset} not found. Please provide a valid dataset path.')
dataset = make_stacked_dataset(dataset_file, 'train')
model = initialize_model(cfg, train_dataset=dataset, hdf5_dataset=None,
initialize_normalization=False, model_type='bnode')
# load latest checkpoint
if cfg_export.model_checkpoint_path is None:
path_checkpoint = sorted(dir_artifacts.rglob('model_phase_*.pt'))[-1]
else:
if cfg_export.model_checkpoint_path.startswith('file://'):
path_checkpoint = Path(cfg_export.model_checkpoint_path.replace('file://', ''))
else:
path_checkpoint = Path(cfg_export.model_checkpoint_path)
if not path_checkpoint.is_file():
raise FileNotFoundError(f'Checkpoint file {path_checkpoint} not found. Please provide a valid checkpoint path.')
model.load(path_checkpoint, device='cpu')
return {'model': model, 'cfg': cfg, 'dataset_file': dataset_file, 'dataset': dataset, 'temp_dir': temp_dir}
export_example_io_data(res, inputs, path_example_io)
Export example input/output data for ONNX model validation.
Saves the inputs and outputs of a model component to an HDF5 file for
later validation. This allows users to verify that their ONNX runtime
produces identical results to the PyTorch reference implementation.
Parameters:
- res (torch.Tensor, tuple, list, or dict, required): Model output(s). Can be:
    - torch.Tensor: Single output tensor
    - tuple/list: Multiple output tensors
    - dict: Named output tensors
- inputs (dict, required): Dictionary of input tensors with their names as keys. Each value can be a torch.Tensor or None.
- path_example_io (Path, required): Output path for the HDF5 file.
HDF5 Structure
/inputs/
    <input_name_1>: dataset (array)
    <input_name_2>: dataset (array)
    ...
/outputs/
    output: dataset (single output case)
    OR
    output_0, output_1, ...: datasets (multiple output case)
    OR
    <output_name_1>, <output_name_2>, ...: datasets (dict output case)
Notes
- Tensors are automatically converted to NumPy arrays
- None inputs are skipped
- Outputs are organized based on their type (tensor/tuple/dict)
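As a sketch of the validation workflow, the snippet below writes a tiny file in the same inputs/outputs layout and reads it back for comparison. The file name and values are stand-ins, and the onnxruntime call is shown only as a comment, since the runtime side depends on your deployment.

```python
import h5py
import numpy as np

# Write a tiny file in the layout produced by export_example_io_data
# (groups 'inputs' and 'outputs'); values here are stand-ins.
ref_out = np.array([[0.5, -1.0]], dtype=np.float32)
with h5py.File('encoder_states_example_io_sketch.hdf5', 'w') as f:
    f.create_group('inputs').create_dataset('x', data=np.ones((1, 3), np.float32))
    f.create_group('outputs').create_dataset('output', data=ref_out)

# Validation pattern: feed the stored inputs to your ONNX runtime and
# compare against the stored reference outputs, e.g. with onnxruntime:
#   sess = onnxruntime.InferenceSession('encoder_states.onnx')
#   got = sess.run(None, inputs)[0]
with h5py.File('encoder_states_example_io_sketch.hdf5', 'r') as f:
    inputs = {k: v[...] for k, v in f['inputs'].items()}
    reference = {k: v[...] for k, v in f['outputs'].items()}

got = reference['output']  # stand-in for the runtime result
assert np.allclose(got, ref_out, atol=1e-5)
```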
Source code in src/bnode_core/ode/bnode/bnode_export.py
def export_example_io_data(res, inputs, path_example_io):
"""Export example input/output data for ONNX model validation.
Saves the inputs and outputs of a model component to an HDF5 file for
later validation. This allows users to verify that their ONNX runtime
produces identical results to the PyTorch reference implementation.
Args:
res (different types): Model output(s). Can be:
- torch.Tensor: Single output tensor
- tuple/list: Multiple output tensors
- dict: Named output tensors
inputs (dict): Dictionary of input tensors with their names as keys.
Each value can be a torch.Tensor or None.
path_example_io (Path): Output path for the HDF5 file.
HDF5 Structure:
/inputs/
<input_name_1>: dataset (array)
<input_name_2>: dataset (array)
...
/outputs/
output: dataset (single output case)
OR
output_0, output_1, ...: datasets (multiple output case)
OR
<output_name_1>, <output_name_2>, ...: datasets (dict output case)
Notes:
- Tensors are automatically converted to NumPy arrays
- None inputs are skipped
- Outputs are organized based on their type (tensor/tuple/dict)
"""
with h5py.File(path_example_io, 'w') as f:
# Save inputs
grp_in = f.create_group('inputs')
for in_key, in_val in inputs.items():
if isinstance(in_val, torch.Tensor):
grp_in.create_dataset(in_key, data=in_val.detach().numpy())
elif in_val is not None:
grp_in.create_dataset(in_key, data=in_val)
# Save outputs
grp_out = f.create_group('outputs')
if isinstance(res, torch.Tensor):
grp_out.create_dataset('output', data=res.detach().cpu().numpy())
elif isinstance(res, (tuple, list)):
for i, out_val in enumerate(res):
grp_out.create_dataset(f'output_{i}', data=out_val.detach().cpu().numpy())
elif isinstance(res, dict):
for out_key, out_val in res.items():
grp_out.create_dataset(out_key, data=out_val.detach().cpu().numpy())
log_shapes_of_dict(d, name='')
Log the shapes of tensors in a data structure for debugging.
Recursively traverses dictionaries, lists, or tuples and logs the shapes
of any PyTorch tensors found. Useful for debugging model I/O during export.
Source code in src/bnode_core/ode/bnode/bnode_export.py
def log_shapes_of_dict(d, name=''):
"""Log the shapes of tensors in a data structure for debugging.
Recursively traverses dictionaries, lists, or tuples and logs the shapes
of any PyTorch tensors found. Useful for debugging model I/O during export.
"""
if name:
logging.info(f"Shapes in {name}:")
else:
logging.info("Shapes in .... :")
if isinstance(d, dict):
for key, value in d.items():
if isinstance(value, torch.Tensor):
logging.info(f"\t{key}: {value.shape}")
elif isinstance(value, (tuple, list)):
logging.info(f"\t{key}: {[v.shape for v in value if isinstance(v, torch.Tensor)]}")
else:
logging.info(f"\t{key}: {value}")
elif isinstance(d, (tuple, list)):
for i, value in enumerate(d):
if isinstance(value, torch.Tensor):
logging.info(f"\t[{i}]: {value.shape}")
elif isinstance(value, (tuple, list)):
logging.info(f"\t[{i}]: {[v.shape for v in value if isinstance(v, torch.Tensor)]}")
else:
logging.info(f"\t[{i}]: {value}")
else:
logging.info(f"\t{type(d)}: {d}")
export_bnode(cfg_export: onnx_export_config_class)
Main function for BNODE model export to ONNX format.
This function orchestrates the complete export process:
- Loads trained BNODE model and configuration
- Extracts individual encoder, decoder, and ODE components
- Exports each component to ONNX format with dynamic batch dimensions
- Saves example input/output data for validation
- Exports configuration file for reference
The function is designed to be invoked via the uv run bnode_export command,
which is registered in pyproject.toml. Hydra manages configuration loading
and command-line argument parsing.
Parameters:
- cfg_export (onnx_export_config_class, required): Export configuration managed by Hydra. Configuration can be specified via YAML files or command-line overrides.
Export Workflow
1. Model Loading:
    - Retrieves model from MLflow or local directory
    - Loads configuration and dataset for normalization
    - Restores trained weights from checkpoint
2. Component Extraction:
    - State encoder (always present)
    - Control encoder (if model uses controls)
    - Parameter encoder (if model uses parameters)
    - Latent ODE function
    - SSM parameter mapping (linear models only)
    - Decoder (states and/or outputs)
3. ONNX Export:
    - Each component is exported separately
    - Dynamic batch dimensions allow flexible inference
    - Input/output names are explicitly defined
    - Example I/O saved for each component
4. Artifact Organization:
    - All exports saved to output_dir or the Hydra output directory
    - Configuration exported as bnode_config.yaml
    - Example data in *_example_io.hdf5 files
Command-Line Usage
Export from MLflow:
uv run bnode_export mlflow_run_id=<run_id> \
mlflow_tracking_uri=<uri> \
output_dir=<output_path>
Export from local directory:
uv run bnode_export model_directory=<path> \
output_dir=<output_path>
With custom checkpoint:
uv run bnode_export mlflow_run_id=<run_id> \
model_checkpoint_path=<checkpoint.pt> \
dataset_path=<dataset.hdf5>
Exported Files
- encoder_states.onnx: State encoder
- encoder_controls.onnx: Control encoder (if applicable)
- encoder_parameters.onnx: Parameter encoder (if applicable)
- latent_ode.onnx: Latent dynamics function
- latent_ode_ssm_from_param.onnx: SSM mapping (linear models)
- decoder.onnx: Decoder network
- encoder_*_example_io.hdf5: Example encoder I/O
- latent_ode_example_io.hdf5: Example ODE I/O
- decoder_example_io.hdf5: Example decoder I/O
- bnode_config.yaml: Complete model configuration
ONNX Model Specifications
All exported models include:
- Dynamic axes: Batch dimension (axis 0) is dynamic
- Input names: Descriptive names for each input tensor
- Output names: Descriptive names for each output tensor
Encoder outputs:
latent_<type>_mu: Mean of latent distribution
latent_<type>_logvar: Log-variance of latent distribution
ODE outputs:
lat_states_mu_dot: State derivative (constant variance)
concat(lat_states_mu_dot,lat_states_logvar_dot): Combined derivatives (dynamic variance)
Decoder outputs:
states: Reconstructed physical states
outputs: Reconstructed system outputs (if applicable)
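For dynamic-variance models, a consumer of latent_ode.onnx has to split the concatenated derivative back into its two parts. A numpy sketch, assuming the two parts are equal-sized halves concatenated along the last axis (the latent dimension of 4 is illustrative):

```python
import numpy as np

# The dynamic-variance latent ODE returns mu_dot and logvar_dot
# concatenated along the last axis; split them back in the consumer.
# Latent dimension (4) and values are illustrative; equal-halves split
# along the last axis is an assumption to verify against your model.
n_lat = 4
concat_out = np.arange(2 * n_lat, dtype=np.float32).reshape(1, 2 * n_lat)
lat_states_mu_dot = concat_out[:, :n_lat]
lat_states_logvar_dot = concat_out[:, n_lat:]
print(lat_states_mu_dot.shape, lat_states_logvar_dot.shape)  # (1, 4) (1, 4)
```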
Notes
- Model is set to evaluation mode before export
- All computations performed on CPU (use_cuda=False)
- Normalization parameters are embedded in exported models
- Example I/O files can be used to validate ONNX Runtime inference
- Hydra output directory is used if output_dir not specified
See Also
load_trained_latent_ode: Model loading function
export_example_io_data: Example I/O export function
bnode_core.ode.trainer.initialize_model: Model initialization
Source code in src/bnode_core/ode/bnode/bnode_export.py
def export_bnode(cfg_export: onnx_export_config_class):
"""Main function for BNODE model export to ONNX format.
This function orchestrates the complete export process:
1. Loads trained BNODE model and configuration
2. Extracts individual encoder, decoder, and ODE components
3. Exports each component to ONNX format with dynamic batch dimensions
4. Saves example input/output data for validation
5. Exports configuration file for reference
The function is designed to be invoked via the ``uv run bnode_export`` command,
which is registered in ``pyproject.toml``. Hydra manages configuration loading
and command-line argument parsing.
Args:
cfg_export (onnx_export_config_class): Export configuration managed by Hydra.
Configuration can be specified via YAML files or command-line overrides.
Export Workflow:
1. **Model Loading**:
- Retrieves model from MLflow or local directory
- Loads configuration and dataset for normalization
- Restores trained weights from checkpoint
2. **Component Extraction**:
- State encoder (always present)
- Control encoder (if model uses controls)
- Parameter encoder (if model uses parameters)
- Latent ODE function
- SSM parameter mapping (linear models only)
- Decoder (states and/or outputs)
3. **ONNX Export**:
- Each component is exported separately
- Dynamic batch dimensions allow flexible inference
- Input/output names are explicitly defined
- Example I/O saved for each component
4. **Artifact Organization**:
- All exports saved to output_dir or Hydra output
- Configuration exported as ``bnode_config.yaml``
- Example data in ``*_example_io.hdf5`` files
Command-Line Usage:
Export from MLflow::
uv run bnode_export mlflow_run_id=<run_id> \\
mlflow_tracking_uri=<uri> \\
output_dir=<output_path>
Export from local directory::
uv run bnode_export model_directory=<path> \\
output_dir=<output_path>
With custom checkpoint::
uv run bnode_export mlflow_run_id=<run_id> \\
model_checkpoint_path=<checkpoint.pt> \\
dataset_path=<dataset.hdf5>
Exported Files:
- ``encoder_states.onnx``: State encoder
- ``encoder_controls.onnx``: Control encoder (if applicable)
- ``encoder_parameters.onnx``: Parameter encoder (if applicable)
- ``latent_ode.onnx``: Latent dynamics function
- ``latent_ode_ssm_from_param.onnx``: SSM mapping (linear models)
- ``decoder.onnx``: Decoder network
- ``encoder_*_example_io.hdf5``: Example encoder I/O
- ``latent_ode_example_io.hdf5``: Example ODE I/O
- ``decoder_example_io.hdf5``: Example decoder I/O
- ``bnode_config.yaml``: Complete model configuration
ONNX Model Specifications:
All exported models include:
- **Dynamic axes**: Batch dimension (axis 0) is dynamic
- **Input names**: Descriptive names for each input tensor
- **Output names**: Descriptive names for each output tensor
Encoder outputs:
- ``latent_<type>_mu``: Mean of latent distribution
- ``latent_<type>_logvar``: Log-variance of latent distribution
ODE outputs:
- ``lat_states_mu_dot``: State derivative (constant variance)
- ``concat(lat_states_mu_dot,lat_states_logvar_dot)``: Combined
derivatives (dynamic variance)
Decoder outputs:
- ``states``: Reconstructed physical states
- ``outputs``: Reconstructed system outputs (if applicable)
Notes:
- Model is set to evaluation mode before export
- All computations performed on CPU (use_cuda=False)
- Normalization parameters are embedded in exported models
- Example I/O files can be used to validate ONNX Runtime inference
- Hydra output directory is used if output_dir not specified
See Also:
- ``load_trained_latent_ode``: Model loading function
- ``export_example_io_data``: Example I/O export function
- ``bnode_core.ode.trainer.initialize_model``: Model initialization
"""
logging.info('Exporting BNODE using the following config {}'.format(str(cfg_export)))
# load model
res = load_trained_latent_ode(cfg_export)
model, cfg, dataset_file, dataset = res['model'], res['cfg'], res['dataset_file'], res['dataset']
temp_dir = res['temp_dir']
model.eval()
# determine output dir
dir_output = Path(cfg_export.output_dir) if cfg_export.output_dir is not None else filepaths.dir_current_hydra_output()
# export bnode config
path_config = dir_output / 'bnode_config.yaml'
logging.info(f'Exporting BNODE config to {path_config}')
path_config.parent.mkdir(parents=True, exist_ok=True)
with open(path_config, 'w') as f:
yaml.dump(OmegaConf.to_container(cfg, resolve=True), f, default_flow_style=False)
# get test points for graph construction
test_state = dataset[0]['states'][:,0].unsqueeze(0)
test_control = dataset[0]['controls'][:,0].unsqueeze(0) if model.include_controls else None
test_parameters = dataset[0]['parameters'].unsqueeze(0) if model.include_parameters else None
# export the encoders
encoders = {'states': model.state_encoder,
'controls': model.controls_encoder if model.include_controls else None,
'parameters': model.parameter_encoder if model.include_params_encoder else None
}
# construct test inputs for graph construction
inputs_dict = {
'states': {'x': test_state},
'controls': {'x': test_control} if model.params_to_control_encoder is False else {'x': test_control, 'params': test_parameters},
'parameters': {'x': test_parameters},
}
# handling of additional inputs to state encoder
if model.params_to_state_encoder is True:
inputs_dict['states']['params'] = test_parameters
if model.controls_to_state_encoder is True:
inputs_dict['states']['controls'] = test_control
latents_dict = {}
for key, encoder in encoders.items():
if encoder is not None:
path_encoder = dir_output / f'encoder_{key}.onnx'
logging.info(f'Exporting {key} encoder to {path_encoder}')
# test model
log_shapes_of_dict(inputs_dict[key], f'Inputs for {key} encoder')
res = encoder(**inputs_dict[key])
log_shapes_of_dict(res, f'Outputs of {key} encoder')
logging.info(f'Test result {res}')
# export
input_names = list(inputs_dict[key].keys())
output_names=['latent_' + key + '_mu', 'latent_' + key + '_logvar']
dynamic_axes={}
for name in input_names:
dynamic_axes[name] = {0: 'batch_size'}
for name in output_names:
dynamic_axes[name] = {0: 'batch_size'}
# Use legacy TorchScript-based exporter for better stability
torch.onnx.export(encoder,
args=(),
kwargs=inputs_dict[key],
f=path_encoder,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
dynamo=False
)
logging.info(f'Exported {key} encoder successfully')
# export also example io
path_example_io = dir_output / f'encoder_{key}_example_io.hdf5'
export_example_io_data(res, inputs_dict[key], path_example_io)
# save latent variable
latents_dict[key] = res[0] # the first is mu
# export ssm from parameters model and get A_from_param and B_from_param for the latent ODE function
ode = model.latent_ode_func
if ode.include_parameters is True and ode.linear is True:
# this is only possible if the model is linear and has parameters
logging.info('Exporting SSM from parameters')
ssm = model.latent_ode_func.ssm_from_param
path_ssm = dir_output / 'latent_ode_ssm_from_param.onnx'
logging.info(f'Export latent ODE SSM from parameters to {path_ssm}')
# construct test input
inputs = {
'lat_parameters': latents_dict['parameters'],
}
# test model
log_shapes_of_dict(inputs, 'Inputs for latent ODE SSM from parameters')
res = ssm(**inputs)
log_shapes_of_dict(res, 'Outputs of latent ODE SSM from parameters')
logging.info(f'Test result {res}')
# export
input_names=['lat_parameters']
output_names=['A', 'B'] if ode.include_controls else ['A']
dynamic_axes={}
for name in input_names:
dynamic_axes[name] = {0: 'batch_size'}
for name in output_names:
dynamic_axes[name] = {0: 'batch_size'}
torch.onnx.export(ssm,
args=(),
kwargs=inputs,
f=path_ssm,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
dynamo=False)
logging.info(f'Exported latent ODE SSM from parameters successfully')
# get A_from_param and B_from_param
if ode.include_controls:
A_from_param, B_from_param = res
else:
A_from_param = res
# export the latent ode function
path_ode = dir_output / 'latent_ode.onnx'
logging.info(f'Export latent ODE to {path_ode}')
# construct test input
inputs = {
'lat_states': latents_dict['states'],
'lat_parameters': latents_dict['parameters'] if ode.include_parameters is True else None,
'lat_controls': latents_dict['controls'] if ode.include_controls is True else None,
'A_from_param': A_from_param if ode.include_parameters is True and ode.linear else None,
'B_from_param': B_from_param if ode.include_parameters is True and ode.linear and ode.include_controls else None,
}
# test model
log_shapes_of_dict(inputs, 'Inputs for latent ODE')
res = ode(**inputs)
log_shapes_of_dict(res, 'Outputs of latent ODE')
logging.info(f'Test result {res}')
# export
input_names=[]
for key in inputs.keys():
if inputs[key] is not None:
input_names.append(key)
dynamic_axes={}
for name in input_names:
dynamic_axes[name] = {0: 'batch_size'}
if model.lat_ode_type == 'variance_constant' or model.lat_ode_type == 'vanilla':
output_names = ['lat_states_mu_dot']
dynamic_axes['lat_states_mu_dot'] = {0: 'batch_size'}
elif model.lat_ode_type == 'variance_dynamic':
output_names = ['concat(lat_states_mu_dot,lat_states_logvar_dot)']
dynamic_axes['concat(lat_states_mu_dot,lat_states_logvar_dot)'] = {0: 'batch_size'}
# Filter out None values and convert to tuple
filtered_inputs = {k: v for k, v in inputs.items() if v is not None}
torch.onnx.export(ode,
args=(),
kwargs=filtered_inputs,
f=path_ode,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
dynamo=False)
logging.info(f'Exported latent ODE successfully')
# export also example io
path_example_io = dir_output / f'latent_ode_example_io.hdf5'
export_example_io_data(res, inputs, path_example_io)
# export the decoder
# TODO: What to do with the split return? We have to take the first n elements for states etc. now... implement another function for this? / optional argument?
decoder = model.decoder
decoder.onnx_export = True # disable concatenation of outputs for ONNX export
path_decoder = dir_output / 'decoder.onnx'
logging.info(f'Export decoder to {path_decoder}')
# construct test input
inputs = {
'lat_state': latents_dict['states'],
'lat_parameters': latents_dict['parameters'] if decoder.include_parameters is True else None,
'lat_controls': latents_dict['controls'] if decoder.include_controls is True else None,
}
# test model
log_shapes_of_dict(inputs, 'Inputs for decoder')
res = decoder(**inputs)
log_shapes_of_dict(res, 'Outputs of decoder')
logging.info(f'Test result {res}')
input_names = []
# export
for key in inputs.keys():
if inputs[key] is not None:
input_names.append(key)
if decoder.include_outputs and decoder.include_states:
output_names = ['states', 'outputs']
elif decoder.include_outputs:
output_names = ['outputs']
elif decoder.include_states:
output_names = ['states']
dynamic_axes={}
for name in input_names:
dynamic_axes[name] = {0: 'batch_size'}
for name in output_names:
dynamic_axes[name] = {0: 'batch_size'}
# Filter out None values and convert to tuple
filtered_inputs = {k: v for k, v in inputs.items() if v is not None}
torch.onnx.export(decoder,
args=(),
kwargs=filtered_inputs,
f=path_decoder,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
dynamo=False)
logging.info(f'Exported decoder successfully')
# export also example io
path_example_io = dir_output / f'decoder_example_io.hdf5'
export_example_io_data(res, inputs, path_example_io)
if temp_dir is not None:
logging.info(f'Cleaning up temporary directory: {temp_dir}')
shutil.rmtree(temp_dir, ignore_errors=True)
logging.info('Temporary directory cleaned up successfully')
# copy the current hydra folder to the output for reference
dir_hydra_current = filepaths.dir_current_hydra_output()
dir_hydra_output_copy = dir_output / 'hydra'
if dir_hydra_current.is_dir() and (dir_hydra_output_copy.absolute() != dir_hydra_current.absolute()):
logging.info(f'Copying current Hydra directory {dir_hydra_current} to output {dir_hydra_output_copy}')
shutil.copytree(dir_hydra_current, dir_hydra_output_copy, dirs_exist_ok=True)
logging.info('Hydra directory copied successfully')
else:
logging.warning(f'Current Hydra directory {dir_hydra_current} not found and could not be copied.')