Neural Network Utils
bnode_core.nn.nn_utils.normalization
Normalization layers for neural network inputs with time series and 1D data support.
This module provides PyTorch normalization layers that compute and store mean/std statistics from data, then normalize (or denormalize) inputs during forward passes. Supports both time series data (batch, channels, time) and 1D feature vectors (batch, features).
NormalizationLayerTimeSeries
Bases: Module
Normalization layer for time series data with shape (batch, channels, time).
Computes and stores per-channel mean and standard deviation from input data, then normalizes future inputs to zero mean and unit variance. Can also denormalize outputs back to original scale. Statistics are computed once during first forward pass or via explicit initialization.
Expected input shape: (batch_size, n_channels, sequence_length)
Attributes:
| Name | Type | Description |
|---|---|---|
| `_initialized` | `bool` | Whether mean/std have been computed from data. |
| `std` | `Tensor` | Per-channel standard deviations, shape (n_channels,). |
| `mu` | `Tensor` | Per-channel means, shape (n_channels,). |
Source code in src/bnode_core/nn/nn_utils/normalization.py
__init__(n_channels)
Initialize normalization layer buffers.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `n_channels` | `int` | Number of channels in time series data. | required |
Source code in src/bnode_core/nn/nn_utils/normalization.py
initialize_normalization(x)
Compute and store mean and std from input data.
Calculates per-channel statistics across batch and time dimensions. Adds small epsilon (1e-3) to variance for numerical stability. Only runs if not already initialized.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Tensor` | Input data with shape (batch_size, n_channels, sequence_length). | required |
Side Effects
Sets self.mu and self.std buffers if not already initialized.
Source code in src/bnode_core/nn/nn_utils/normalization.py
forward(x: torch.Tensor, denormalize: bool = False) -> torch.Tensor
Normalize or denormalize input time series.
If not initialized and normalizing, automatically initializes from input data. Normalizes via (x - mu) / std or denormalizes via x * std + mu.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Tensor` | Input with shape (batch_size, n_channels, sequence_length). | required |
| `denormalize` | `bool` | If False, normalize input. If True, denormalize (reverse transformation). Defaults to False. | `False` |
Returns:
| Type | Description |
|---|---|
| `Tensor` | Normalized or denormalized data with the same shape as the input. |
Source code in src/bnode_core/nn/nn_utils/normalization.py
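Example (a minimal usage sketch based on the signatures above; the import path is assumed from the module name):

```python
import torch

# Assumed import path, inferred from the module name above.
from bnode_core.nn.nn_utils.normalization import NormalizationLayerTimeSeries

norm = NormalizationLayerTimeSeries(n_channels=3)
x = torch.randn(16, 3, 100)  # (batch_size, n_channels, sequence_length)

x_norm = norm(x)                          # first call computes mu/std from x, then normalizes
x_back = norm(x_norm, denormalize=True)   # reverses the transformation: x * std + mu
```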
NormalizationLayer1D
Bases: Module
Normalization layer for 1D feature vectors with shape (batch, features).
Computes and stores per-feature mean and standard deviation, then normalizes inputs to zero mean and unit variance. Can also denormalize outputs. Supports both 2D (batch, features) and 3D (batch, features, time) inputs. Accepts both torch.Tensor and numpy.ndarray for initialization.
Expected input shape: (batch_size, num_features) or (batch_size, num_features, sequence_length)
Attributes:
| Name | Type | Description |
|---|---|---|
| `_initialized` | `bool` | Whether mean/std have been computed. |
| `std` | `Tensor` | Per-feature standard deviations, shape (num_features,). |
| `mu` | `Tensor` | Per-feature means, shape (num_features,). |
Source code in src/bnode_core/nn/nn_utils/normalization.py
__init__(num_features)
Initialize normalization layer buffers.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `num_features` | `int` | Number of features/channels to normalize. | required |
Source code in src/bnode_core/nn/nn_utils/normalization.py
initialize_normalization(x, eps=1e-05, verbose=False, name=None)
Compute and store mean and std from input data.
Calculates per-feature statistics across batch dimension. Adds epsilon to variance for numerical stability. Supports both torch.Tensor and numpy.ndarray inputs.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Tensor` or `ndarray` | Input data with shape (batch_size, num_features). | required |
| `eps` | `float` | Small constant added to variance for stability. Defaults to 1e-5. | `1e-05` |
| `verbose` | `bool` | If True, logs initialization info. Defaults to False. | `False` |
| `name` | `str` | Name for logging output. Defaults to None. | `None` |
Raises:
| Type | Description |
|---|---|
| `ValueError` | If x is neither torch.Tensor nor np.ndarray. |
| `RuntimeError` | If the normalization layer has already been initialized. |
Side Effects
Sets self.mu and self.std buffers, logs initialization if verbose=True.
Source code in src/bnode_core/nn/nn_utils/normalization.py
forward(x: torch.Tensor, denormalize: bool = False) -> torch.Tensor
Normalize or denormalize input features.
If not initialized and normalizing, automatically initializes from input. Handles both 2D (batch, features) and 3D (batch, features, time) inputs by broadcasting. Normalizes via (x - mu) / std or denormalizes via x * std + mu.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Tensor` | Input with shape (batch_size, num_features) or (batch_size, num_features, sequence_length). | required |
| `denormalize` | `bool` | If False, normalize input. If True, denormalize. Defaults to False. | `False` |
Returns:
| Type | Description |
|---|---|
| `Tensor` | Normalized or denormalized data with the same shape as the input. |
Source code in src/bnode_core/nn/nn_utils/normalization.py
__repr__() -> str
Return string representation of the layer.
Returns:
| Name | Type | Description |
|---|---|---|
| `str` | `str` | String showing layer type and number of features. |
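Example (a minimal usage sketch based on the signatures above; the import path is assumed from the module name):

```python
import numpy as np
import torch

# Assumed import path, inferred from the module name above.
from bnode_core.nn.nn_utils.normalization import NormalizationLayer1D

norm = NormalizationLayer1D(num_features=8)

# Explicit initialization from a numpy array (accepted alongside torch.Tensor).
stats_data = np.random.randn(1000, 8).astype(np.float32)
norm.initialize_normalization(stats_data, eps=1e-5, verbose=True, name="features")

batch = torch.randn(32, 8)
z = norm(batch)                        # (x - mu) / std
x_back = norm(z, denormalize=True)     # x * std + mu
```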
bnode_core.nn.nn_utils.load_data
Dataset loading utilities for neural network training.
Provides functions to load HDF5 datasets and their configurations, and create PyTorch-compatible dataset objects for training.
Attention
This documentation is generated by AI. Please be aware of possible inaccuracies.
TimeSeriesDataset
Bases: StackDataset
Dataset for time series with sliding window sampling of variable-length subsequences.
Extends StackDataset to enable extracting subsequences from longer time series via a sliding window approach. The full sequences are stored internally, but `__getitem__` returns only a subsequence of the specified length. This enables training on different sequence lengths without reloading data, and increases the effective dataset size by treating each sliding window position as a separate sample.
The dataset expects dict-style data with a 'time' key, where all time series have shape (n_samples, n_channels, n_timesteps). Non-time-series data (2D) is replicated across all windows from the same sample.
Attributes:
| Name | Type | Description |
|---|---|---|
| `seq_len` | `int` | Length of subsequences returned by `__getitem__`. |
| `mapping` | `list` | List of [sample_idx, start_pos, end_pos] tuples defining each sliding window position across all samples. |
| `_length` | `int` | Total number of sliding windows (dataset length). |
| `_length_old` | `int` | Original number of samples before windowing. |
Source code in src/bnode_core/nn/nn_utils/load_data.py
__init__(seq_len: int, *args, **kwargs)
Initialize TimeSeriesDataset with sliding window parameters.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `seq_len` | `int` | Length of subsequences to extract. If larger than the available time series length, it is clamped to the maximum available length. | required |
| `*args` | | Positional arguments passed to the parent StackDataset. | `()` |
| `**kwargs` | | Keyword arguments passed to the parent StackDataset. Must result in a dict-style dataset with a 'time' key. | `{}` |
Raises:
| Type | Description |
|---|---|
| `AssertionError` | If datasets is not a dict or lacks a 'time' key. |
Source code in src/bnode_core/nn/nn_utils/load_data.py
set_seq_len(seq_len: int)
Change the subsequence length and rebuild the sliding window mapping.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `seq_len` | `int` | New subsequence length. If None, 0, or larger than the available time series length, it is clamped to the maximum available length. | required |
Source code in src/bnode_core/nn/nn_utils/load_data.py
initialize_map(seq_len: int)
Create the sliding window index mapping for all samples.
Builds a mapping list where each entry [sample_idx, start_pos, end_pos] defines a sliding window position. Windows slide by 1 timestep across each sample, then continue to the next sample. This treats each window position as an independent dataset item.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `seq_len` | `int` | Length of sliding windows. Must be at least 1. | required |
Raises:
| Type | Description |
|---|---|
| `AssertionError` | If seq_len < 1. |
Side Effects
- Sets self.mapping to list of [sample, start, end] tuples
- Updates self._length to total number of windows across all samples
Source code in src/bnode_core/nn/nn_utils/load_data.py
__getitem__(index: int) -> dict
Get a single sliding window sample.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `index` | `int` | Index of the sliding window to retrieve. | required |
Returns:
| Name | Type | Description |
|---|---|---|
| `dict` | `dict` | Dictionary with the same keys as self.datasets. For 3D+ arrays (time series), returns the subsequence [start:end] from the appropriate sample. For 2D arrays, returns the full array for the sample. |
Source code in src/bnode_core/nn/nn_utils/load_data.py
__getitems__(indices: list) -> list
Get multiple sliding window samples (batch retrieval).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `indices` | `list` | List of window indices to retrieve. | required |
Returns:
| Name | Type | Description |
|---|---|---|
| `list` | `list` | List of dictionaries, one per index, each containing the requested sliding window data. |
Note
This method requires PyTorch >= 2.2 for optimal batched data loading.
Source code in src/bnode_core/nn/nn_utils/load_data.py
__len__() -> int
Return the total number of sliding windows in the dataset.
Returns:
| Name | Type | Description |
|---|---|---|
| `int` | `int` | Total number of windows across all samples. |
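Example (a minimal sketch of the sliding-window behaviour; the import path is assumed from the module name, and the keyword-dataset construction assumes the usual dict-style StackDataset form described above):

```python
import torch

# Assumed import path; shapes follow the (n_samples, n_channels, n_timesteps)
# convention described for the time series data above.
from bnode_core.nn.nn_utils.load_data import TimeSeriesDataset

time = torch.linspace(0, 1, 50).reshape(1, 1, 50).repeat(4, 1, 1)  # (4, 1, 50)
states = torch.randn(4, 3, 50)                                     # (4, 3, 50)
parameters = torch.randn(4, 2)                                     # 2D data, replicated per window

ds = TimeSeriesDataset(20, time=time, states=states, parameters=parameters)
print(len(ds))                   # total sliding windows (one per start position per sample)

sample = ds[0]                   # dict with the same keys as the datasets
print(sample["states"].shape)    # expected: a subsequence of length seq_len, e.g. (3, 20)

ds.set_seq_len(10)               # rebuild the window mapping for a new subsequence length
```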
load_validate_dataset_config(path: Path) -> base_pModelClass
Load and validate dataset configuration from YAML file.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `path` | `Path` | Path to the dataset configuration YAML file. | required |
Returns:
| Type | Description |
|---|---|
| `base_pModelClass` | Validated dataset configuration as a base_pModelClass instance. |
Raises:
| Type | Description |
|---|---|
| `FileNotFoundError` | If the configuration file doesn't exist. |
Note
Uses OmegaConf to load YAML and validates against base_pModelClass schema.
Source code in src/bnode_core/nn/nn_utils/load_data.py
load_dataset_and_config(dataset_name: str, dataset_path: str) -> Tuple[h5py.File, Optional[base_pModelClass]]
Load HDF5 dataset and its configuration.
Loads the HDF5 dataset file and attempts to load its configuration. If configuration file doesn't exist, returns None for config.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `dataset_name` | `str` | Name/identifier of the dataset. | required |
| `dataset_path` | `str` | Explicit path to the dataset file, or an empty string to use the default location. | required |
Returns:
| Type | Description |
|---|---|
| `Tuple[File, Optional[base_pModelClass]]` | Tuple of (dataset, dataset_config); dataset_config is None if no configuration file exists. |
Note
The returned h5py.File should be closed when done (dataset.close()). Uses filepath_dataset_from_config to resolve actual file path.
Source code in src/bnode_core/nn/nn_utils/load_data.py
make_stacked_dataset(dataset: h5py.File, context: str, seq_len_from_file: Optional[int] = None, seq_len_batches: Optional[int] = None) -> Union[torch.utils.data.StackDataset, TimeSeriesDataset]
Create a PyTorch dataset from HDF5 data with optional time series batching.
Loads time series data (states, derivatives, parameters, controls, outputs) from an HDF5 file and wraps it in a PyTorch StackDataset. If seq_len_batches is specified, returns a TimeSeriesDataset that enables sliding window sampling for variable-length sequences.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `dataset` | `File` | Open HDF5 file containing time series data with groups for different contexts (train/test/validation). | required |
| `context` | `str` | Dataset context to load. Must be one of: 'train', 'test', 'validation', 'common_test', or 'common_validation'. | required |
| `seq_len_from_file` | `int` | If provided, truncates loaded sequences to this length from the original file data. Defaults to None (use full sequence length). | `None` |
| `seq_len_batches` | `int` | If provided, returns a TimeSeriesDataset that extracts subsequences of this length via sliding window. If None, returns a standard StackDataset with full sequences. Defaults to None. | `None` |
Returns:
| Type | Description |
|---|---|
| `Union[StackDataset, TimeSeriesDataset]` | A dataset that yields dictionaries containing tensors for 'time', 'states', and optionally 'states_der', 'parameters', 'controls', and 'outputs'. Each tensor has shape (batch, channels, time_steps). |
Note
- All None-valued arrays are automatically excluded from the returned dataset
- Time tensor is replicated across batch dimension from single time vector
- When seq_len_batches is used, the dataset length increases to accommodate all possible sliding windows across the original sequences
Source code in src/bnode_core/nn/nn_utils/load_data.py
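Example (a minimal loading sketch; the import path is assumed from the module name, and "my_dataset" is a placeholder name):

```python
import torch

# Assumed import path; an empty dataset_path falls back to the default location
# as described above.
from bnode_core.nn.nn_utils.load_data import load_dataset_and_config, make_stacked_dataset

dataset, dataset_config = load_dataset_and_config("my_dataset", "")

train_ds = make_stacked_dataset(dataset, context="train", seq_len_batches=64)
loader = torch.utils.data.DataLoader(train_ds, batch_size=32, shuffle=True)

batch = next(iter(loader))
print(batch.keys())   # e.g. dict_keys(['time', 'states', ...])

dataset.close()       # the returned h5py.File should be closed when done
```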
bnode_core.nn.nn_utils.kullback_leibler
Kullback-Leibler divergence computation for VAE training.
Provides functions to compute KL divergence between learned latent distributions and standard normal prior, with support for timeseries data and dimension analysis.
Attention
This documentation is generated by AI. Please be aware of possible inaccuracies.
kullback_leibler(mu: torch.Tensor, logvar: torch.Tensor, per_dimension: bool = False, reduce: bool = True, time_series_aggregation_mode: Optional[str] = 'mean') -> torch.Tensor
Compute KL divergence KL(N(mu, exp(logvar)) || N(0, I)).
Calculates the Kullback-Leibler divergence between a learned normal distribution N(mu, sigma^2) and the standard normal prior N(0, 1). Uses the analytical formula: KL = -0.5 * sum(1 + logvar - mu^2 - exp(logvar))
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `mu` | `Tensor` | Mean of the learned distribution, shape (batch, latent_dim) or (batch, latent_dim, seq_len) for timeseries. | required |
| `logvar` | `Tensor` | Log-variance of the learned distribution, same shape as mu. | required |
| `per_dimension` | `bool` | If True, return KL divergence per latent dimension instead of summing across dimensions. Default: False. | `False` |
| `reduce` | `bool` | If True, return the mean over the batch. If False, return per-sample values. Default: True. | `True` |
| `time_series_aggregation_mode` | `Optional[str]` | How to aggregate over the time dimension if the input is a timeseries (3D). Options: 'mean', 'max', 'sum', or None (keep time dim). Default: 'mean'. | `'mean'` |
Returns:
| Type | Description |
|---|---|
| `Tensor` | KL divergence tensor. Shape depends on parameters: per_dimension=False, reduce=True → scalar; per_dimension=False, reduce=False → (batch,); per_dimension=True, reduce=True → (latent_dim,); per_dimension=True, reduce=False → (batch, latent_dim). |
Note
The KL divergence is always non-negative and equals zero only when the learned distribution matches the prior exactly.
Source code in src/bnode_core/nn/nn_utils/kullback_leibler.py
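Example (a minimal sketch based on the signature above; the import path is assumed from the module name):

```python
import torch

# Assumed import path, inferred from the module name above.
from bnode_core.nn.nn_utils.kullback_leibler import kullback_leibler

mu = torch.zeros(32, 16)       # (batch, latent_dim)
logvar = torch.zeros(32, 16)   # sigma^2 = 1, so KL against N(0, I) is exactly 0

kl = kullback_leibler(mu, logvar)                              # scalar (mean over batch)
kl_per_dim = kullback_leibler(mu, logvar, per_dimension=True)  # shape (latent_dim,)
kl_per_sample = kullback_leibler(mu, logvar, reduce=False)     # shape (batch,)
```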
count_populated_dimensions(mu: torch.Tensor, logvar: torch.Tensor, threshold: float = 0.05, kl_timeseries_aggregation_mode: str = 'mean', return_idx: bool = False) -> torch.Tensor
Count number of latent dimensions actively used by the model.
A dimension is considered "populated" or "active" if its KL divergence exceeds a threshold. This helps diagnose posterior collapse (when KL → 0 for all dimensions) and track how many dimensions the model actually uses.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `mu` | `Tensor` | Mean of the learned distribution, shape (batch, latent_dim) or (batch, latent_dim, seq_len). | required |
| `logvar` | `Tensor` | Log-variance of the learned distribution, same shape as mu. | required |
| `threshold` | `float` | Minimum KL divergence for a dimension to be considered active. Default: 0.05. | `0.05` |
| `kl_timeseries_aggregation_mode` | `str` | Aggregation mode for timeseries data ('mean', 'max', 'sum'). Default: 'mean'. | `'mean'` |
| `return_idx` | `bool` | If True, also return a boolean mask of active dimensions. Default: False. | `False` |
Returns:
| Type | Description |
|---|---|
| `Tensor` | Tuple of (n_dim_populated, idx): n_dim_populated is the number of dimensions with KL > threshold (scalar tensor); idx is a boolean mask of active dimensions (or None if return_idx=False). |
Note
Computes KL per dimension averaged over batch (per_dimension=True, reduce=True). Uses torch.no_grad() for efficiency since this is a diagnostic metric.
Source code in src/bnode_core/nn/nn_utils/kullback_leibler.py
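Example (a minimal diagnostic sketch based on the signature above; the import path is assumed from the module name):

```python
import torch

# Assumed import path, inferred from the module name above.
from bnode_core.nn.nn_utils.kullback_leibler import count_populated_dimensions

mu = torch.randn(32, 16)
logvar = torch.randn(32, 16)

n_active, idx = count_populated_dimensions(mu, logvar, threshold=0.05, return_idx=True)
print(int(n_active))   # number of dimensions with per-dimension KL above the threshold
print(idx)             # boolean mask of active dimensions
```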
bnode_core.nn.nn_utils.early_stopping
Early stopping utility for PyTorch model training.
Monitors validation loss and stops training when no improvement is observed for a specified number of epochs (patience). Saves best model checkpoint.
Attention
This documentation is generated by AI. Please be aware of possible inaccuracies.
EarlyStopping
Stop training early if validation loss doesn't improve after given patience.
Tracks validation loss and saves model checkpoints when improvements occur. Triggers early stopping flag when loss plateaus for 'patience' epochs.
Attributes:
| Name | Type | Description |
|---|---|---|
| `patience` | | Number of epochs to wait before stopping after a loss plateau. |
| `verbose` | | If True, print messages for each loss improvement. |
| `counter` | | Number of epochs since the last loss improvement. |
| `best_score` | | Best validation loss seen so far. |
| `corresponding_score` | | Training loss corresponding to the best validation loss. |
| `early_stop` | | Flag indicating whether to stop training. |
| `score_last_save` | | Validation loss at the last checkpoint save. |
| `threshold` | | Minimum loss improvement to qualify as an improvement. |
| `threshold_mode` | | Either 'abs' (absolute) or 'rel' (relative) threshold. |
| `path` | | File path for saving the model checkpoint. |
| `optimizer_path` | | File path for saving the optimizer state. |
Source code in src/bnode_core/nn/nn_utils/early_stopping.py
__init__(patience: int = 7, verbose: bool = False, threshold: float = 0, threshold_mode: str = 'abs', path: str = 'checkpoint.pt', optimizer_path: str = 'optimizer.pt', trace_func: Callable = print)
Initialize early stopping monitor.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `patience` | `int` | Number of epochs to wait after the last validation loss improvement before triggering early stop. Default: 7. | `7` |
| `verbose` | `bool` | If True, prints a message for each validation loss improvement. Default: False. | `False` |
| `threshold` | `float` | Minimum change in the monitored loss to qualify as improvement. Default: 0. | `0` |
| `threshold_mode` | `str` | Either 'abs' (absolute: loss < best - threshold) or 'rel' (relative: loss < best * (1 - threshold)). Default: 'abs'. | `'abs'` |
| `path` | `str` | Path to save the best model checkpoint. Default: 'checkpoint.pt'. | `'checkpoint.pt'` |
| `optimizer_path` | `str` | Path to save the optimizer state. Default: 'optimizer.pt'. | `'optimizer.pt'` |
| `trace_func` | `Callable` | Logging function for status messages. Default: print. | `print` |
Source code in src/bnode_core/nn/nn_utils/early_stopping.py
reset()
Reset all early stopping state to initial values.
Useful for starting a new training phase with fresh early stopping.
Source code in src/bnode_core/nn/nn_utils/early_stopping.py
reset_counter()
Reset only the patience counter and early_stop flag.
Keeps best_score intact. Useful after manual interventions.
__call__(loss: float, model: nn.Module, epoch: Optional[int] = None, optimizer: Optional[optim.Optimizer] = None, corresponding_loss: Optional[float] = None)
Update early stopping state based on current validation loss.
Checks if loss has improved according to threshold criteria. Saves checkpoint if improvement occurred, otherwise increments patience counter. Sets early_stop flag when counter reaches patience.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `loss` | `float` | Current validation loss. | required |
| `model` | `Module` | PyTorch model with a save() method. | required |
| `epoch` | `Optional[int]` | Current epoch number (for logging). Optional. | `None` |
| `optimizer` | `Optional[Optimizer]` | PyTorch optimizer whose state is saved. Optional. | `None` |
| `corresponding_loss` | `Optional[float]` | Training loss from the same epoch (for tracking). Optional. | `None` |
Side Effects
- Updates counter, best_score, and corresponding_score
- Saves model checkpoint when loss improves
- Sets early_stop flag when patience exceeded
- Handles NaN loss by setting to infinity
Source code in src/bnode_core/nn/nn_utils/early_stopping.py
save_checkpoint(loss: float, model: nn.Module, optimizer: Optional[optim.Optimizer], epoch: Optional[int])
Save model checkpoint when validation loss improves.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `loss` | `float` | Current validation loss (for logging). | required |
| `model` | `Module` | PyTorch model with a save() method. | required |
| `optimizer` | `Optional[Optimizer]` | PyTorch optimizer (state saved if not None). | required |
| `epoch` | `Optional[int]` | Current epoch number (for logging). | required |
Side Effects
- Calls model.save(path) to persist model state
- Saves optimizer state_dict if optimizer provided
- Updates score_last_save
- Logs checkpoint save if verbose=True
Source code in src/bnode_core/nn/nn_utils/early_stopping.py
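Example (a minimal training-loop sketch; the import path is assumed from the module name, train_one_epoch and evaluate are placeholders, and the model is assumed to expose a save() method as noted above):

```python
# Assumed import path, inferred from the module name above.
from bnode_core.nn.nn_utils.early_stopping import EarlyStopping

early_stopping = EarlyStopping(patience=10, verbose=True,
                               path="checkpoint.pt", optimizer_path="optimizer.pt")

for epoch in range(max_epochs):
    train_loss = train_one_epoch(model, optimizer)   # placeholder training step
    val_loss = evaluate(model)                       # placeholder validation step

    early_stopping(val_loss, model, epoch=epoch, optimizer=optimizer,
                   corresponding_loss=train_loss)
    if early_stopping.early_stop:
        break
```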
bnode_core.nn.nn_utils.count_parameters
Utility for counting trainable parameters in PyTorch models.
count_parameters(model: nn.Module) -> int
Count total number of trainable parameters in a model.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Module` | PyTorch model (nn.Module). | required |
Returns:
| Type | Description |
|---|---|
| `int` | Total number of trainable parameters. |
Example
    model = VAE(...)
    num_params = count_parameters(model)
    print(f"Model has {num_params:,} trainable parameters")
Source code in src/bnode_core/nn/nn_utils/count_parameters.py
bnode_core.nn.nn_utils.capacity_scheduler
Capacity scheduler for VAE training with controlled KL divergence growth.
This module implements a scheduler that gradually increases the KL divergence capacity during VAE training to prevent posterior collapse while maintaining good reconstruction quality.
Attention
This documentation is generated by AI. Please be aware of possible inaccuracies.
capacity_scheduler
Schedule the KL divergence capacity for VAE bottleneck layer.
Gradually increases the target KL divergence (capacity) to allow the model to learn meaningful latent representations without posterior collapse. The capacity increases when validation loss plateaus, encouraging the model to use more of the latent space.
Attributes:
| Name | Type | Description |
|---|---|---|
| `patience` | | Number of epochs to wait before increasing capacity after a loss plateau. |
| `capacity` | | Current KL divergence capacity target. |
| `capacity_max` | | Maximum capacity value (stops increasing after reaching this). |
| `capacity_increment` | | Amount to increase the capacity by (absolute or relative). |
| `capacity_increment_mode` | | Either 'abs' (add increment) or 'rel' (multiply by increment). |
| `counter` | | Number of epochs since the last loss improvement. |
| `best_score` | | Best validation loss seen so far. |
| `threshold` | | Minimum loss improvement to reset the counter. |
| `threshold_mode` | | Either 'abs' (absolute threshold) or 'rel' (relative threshold). |
| `enabled` | | If False, the scheduler is disabled and returns None for the capacity. |
| `reached_max_capacity` | | True if the capacity has reached capacity_max. |
Source code in src/bnode_core/nn/nn_utils/capacity_scheduler.py
__init__(patience: int, capacity_start: float, capacity_max: float, capacity_increment: float, capacity_increment_mode: str, threshold: float, threshold_mode: str, trace_func: Callable = logging.info, enabled: bool = True)
Initialize the capacity scheduler.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `patience` | `int` | Number of epochs to wait after the last validation loss improvement before increasing capacity. | required |
| `capacity_start` | `float` | Initial KL divergence capacity target. | required |
| `capacity_max` | `float` | Maximum capacity value (caps further increases). | required |
| `capacity_increment` | `float` | Amount to change the capacity by when triggered. | required |
| `capacity_increment_mode` | `str` | Either 'abs' (additive: capacity += increment) or 'rel' (multiplicative: capacity *= increment). | required |
| `threshold` | `float` | Minimum change in validation loss to qualify as improvement. | required |
| `threshold_mode` | `str` | Either 'abs' (absolute: loss < best - threshold) or 'rel' (relative: loss < best * (1 - threshold)). | required |
| `trace_func` | `Callable` | Logging function for status messages (default: logging.info). | `logging.info` |
| `enabled` | `bool` | If False, the scheduler is disabled and get_capacity() returns None. | `True` |
Raises:
| Type | Description |
|---|---|
| `AssertionError` | If capacity_increment_mode or threshold_mode is invalid. |
Source code in src/bnode_core/nn/nn_utils/capacity_scheduler.py
update(score: float)
Update scheduler state based on current validation loss.
Tracks validation loss improvements and increases capacity when loss plateaus for 'patience' epochs. Capacity increases until reaching capacity_max.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `score` | `float` | Current validation loss (typically MSE loss). | required |
Side Effects
- Updates counter and best_score
- Increases capacity if patience threshold reached
- Sets reached_max_capacity flag when maximum is hit
- Logs capacity changes via trace_func
Source code in src/bnode_core/nn/nn_utils/capacity_scheduler.py
get_capacity() -> Optional[float]
Get current capacity target for KL divergence loss.
Returns:
| Type | Description |
|---|---|
| `Optional[float]` | Current capacity value (float) if enabled, None if disabled. |
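Example (a minimal sketch of the scheduler in a VAE training loop; the import path is assumed from the module name, and the per-epoch validation loss is a placeholder):

```python
import logging

# Assumed import path, inferred from the module name above.
from bnode_core.nn.nn_utils.capacity_scheduler import capacity_scheduler

scheduler = capacity_scheduler(
    patience=5,
    capacity_start=0.0,
    capacity_max=25.0,
    capacity_increment=1.0,
    capacity_increment_mode="abs",
    threshold=1e-4,
    threshold_mode="rel",
    trace_func=logging.info,
    enabled=True,
)

for epoch in range(max_epochs):
    val_loss = ...                        # placeholder: validation loss for this epoch
    scheduler.update(val_loss)            # increases capacity when val_loss plateaus
    capacity = scheduler.get_capacity()   # current KL capacity target, or None if disabled
```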