Neural Network Utils

bnode_core.nn.nn_utils.normalization

Normalization layers for neural network inputs with time series and 1D data support.

This module provides PyTorch normalization layers that compute and store mean/std statistics from data, then normalize (or denormalize) inputs during forward passes. Supports both time series data (batch, channels, time) and 1D feature vectors (batch, features).

NormalizationLayerTimeSeries

Bases: Module

Normalization layer for time series data with shape (batch, channels, time).

Computes and stores per-channel mean and standard deviation from input data, then normalizes future inputs to zero mean and unit variance. Can also denormalize outputs back to original scale. Statistics are computed once during first forward pass or via explicit initialization.

Expected input shape: (batch_size, n_channels, sequence_length)

Attributes:

    _initialized (bool): Whether mean/std have been computed from data.
    std (Tensor): Per-channel standard deviations, shape (n_channels,).
    mu (Tensor): Per-channel means, shape (n_channels,).

Source code in src/bnode_core/nn/nn_utils/normalization.py
class NormalizationLayerTimeSeries(nn.Module):
    """Normalization layer for time series data with shape (batch, channels, time).

    Computes and stores per-channel mean and standard deviation from input data, then
    normalizes future inputs to zero mean and unit variance. Can also denormalize outputs
    back to original scale. Statistics are computed once during first forward pass or via
    explicit initialization.

    Expected input shape: (batch_size, n_channels, sequence_length)

    Attributes:
        _initialized (bool): Whether mean/std have been computed from data.
        std (torch.Tensor): Per-channel standard deviations, shape (n_channels,).
        mu (torch.Tensor): Per-channel means, shape (n_channels,).
    """
    def __init__(self, n_channels):
        """Initialize normalization layer buffers.

        Args:
            n_channels (int): Number of channels in time series data.
        """
        super().__init__()
        self.register_buffer("_initialized", torch.tensor(False))
        self.register_buffer('std', torch.zeros(n_channels))
        self.register_buffer('mu', torch.zeros(n_channels))

    def initialize_normalization(self,x):
        """Compute and store mean and std from input data.

        Calculates per-channel statistics across batch and time dimensions. Adds small
        epsilon (1e-3) to variance for numerical stability. Only runs if not already
        initialized.

        Args:
            x (torch.Tensor): Input data with shape (batch_size, n_channels, sequence_length).

        Side Effects:
            Sets self.mu and self.std buffers if not already initialized.
        """
        if not self._initialized:
            variance = torch.var(x, dim=(0,2)).detach()
            self.std.set_(torch.sqrt(variance + torch.ones(variance.size()).to(variance.device) * 1e-3))
            self.mu.set_(torch.mean(x, dim=(0,2)).detach())
            self._initialized = torch.tensor(True)
            assert self.std.requires_grad == False
            assert self.mu.requires_grad == False

    def forward(self, x: torch.Tensor, denormalize: bool = False) -> torch.Tensor:
        """Normalize or denormalize input time series.

        If not initialized and normalizing, automatically initializes from input data.
        Normalizes via (x - mu) / std or denormalizes via x * std + mu.

        Args:
            x (torch.Tensor): Input with shape (batch_size, n_channels, sequence_length).
            denormalize (bool, optional): If False, normalize input. If True, denormalize
                (reverse transformation). Defaults to False.

        Returns:
            torch.Tensor: Normalized or denormalized data with same shape as input.
        """
        if denormalize is False:
            if not self._initialized:
                self.initialize_normalization(x)
        batch_size = x.shape[0]
        seq_len = x.shape[2]
        # add dimensions at position 0 (for number of batches) and at position 2 (for sequence length)
        # expand these dimensions
        std = self.std.unsqueeze(0).unsqueeze(2).expand(batch_size,-1,seq_len)
        mu = self.mu.unsqueeze(0).unsqueeze(2).expand(batch_size,-1,seq_len)
        if denormalize is False:
            x = torch.subtract(x,mu)
            x = torch.divide(x, std)
        else:
            x = torch.multiply(x, std)
            x = torch.add(x, mu)
        return x
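
Example

A minimal usage sketch; the random input below is a hypothetical stand-in for real (batch, channels, time) data:

import torch

from bnode_core.nn.nn_utils.normalization import NormalizationLayerTimeSeries

layer = NormalizationLayerTimeSeries(n_channels=3)

x = 5.0 * torch.randn(16, 3, 200) + 2.0      # hypothetical (batch, channels, time) data
x_norm = layer(x)                            # first call computes and stores per-channel mu/std
x_back = layer(x_norm, denormalize=True)     # invert the transformation

print(x_norm.mean().item(), x_norm.std().item())   # approximately 0 and 1
print(torch.allclose(x, x_back, atol=1e-4))        # True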

__init__(n_channels)

Initialize normalization layer buffers.

Parameters:

    n_channels (int, required): Number of channels in time series data.

Source code in src/bnode_core/nn/nn_utils/normalization.py
def __init__(self, n_channels):
    """Initialize normalization layer buffers.

    Args:
        n_channels (int): Number of channels in time series data.
    """
    super().__init__()
    self.register_buffer("_initialized", torch.tensor(False))
    self.register_buffer('std', torch.zeros(n_channels))
    self.register_buffer('mu', torch.zeros(n_channels))

initialize_normalization(x)

Compute and store mean and std from input data.

Calculates per-channel statistics across batch and time dimensions. Adds small epsilon (1e-3) to variance for numerical stability. Only runs if not already initialized.

Parameters:

    x (torch.Tensor, required): Input data with shape (batch_size, n_channels, sequence_length).

Side Effects

Sets self.mu and self.std buffers if not already initialized.

Source code in src/bnode_core/nn/nn_utils/normalization.py
def initialize_normalization(self,x):
    """Compute and store mean and std from input data.

    Calculates per-channel statistics across batch and time dimensions. Adds small
    epsilon (1e-3) to variance for numerical stability. Only runs if not already
    initialized.

    Args:
        x (torch.Tensor): Input data with shape (batch_size, n_channels, sequence_length).

    Side Effects:
        Sets self.mu and self.std buffers if not already initialized.
    """
    if not self._initialized:
        variance = torch.var(x, dim=(0,2)).detach()
        self.std.set_(torch.sqrt(variance + torch.ones(variance.size()).to(variance.device) * 1e-3))
        self.mu.set_(torch.mean(x, dim=(0,2)).detach())
        self._initialized = torch.tensor(True)
        assert self.std.requires_grad == False
        assert self.mu.requires_grad == False

forward(x: torch.Tensor, denormalize: bool = False) -> torch.Tensor

Normalize or denormalize input time series.

If not initialized and normalizing, automatically initializes from input data. Normalizes via (x - mu) / std or denormalizes via x * std + mu.

Parameters:

    x (torch.Tensor, required): Input with shape (batch_size, n_channels, sequence_length).
    denormalize (bool, default False): If False, normalize input. If True, denormalize (reverse transformation).

Returns:

    torch.Tensor: Normalized or denormalized data with same shape as input.

Source code in src/bnode_core/nn/nn_utils/normalization.py
def forward(self, x: torch.Tensor, denormalize: bool = False) -> torch.Tensor:
    """Normalize or denormalize input time series.

    If not initialized and normalizing, automatically initializes from input data.
    Normalizes via (x - mu) / std or denormalizes via x * std + mu.

    Args:
        x (torch.Tensor): Input with shape (batch_size, n_channels, sequence_length).
        denormalize (bool, optional): If False, normalize input. If True, denormalize
            (reverse transformation). Defaults to False.

    Returns:
        torch.Tensor: Normalized or denormalized data with same shape as input.
    """
    if denormalize is False:
        if not self._initialized:
            self.initialize_normalization(x)
    batch_size = x.shape[0]
    seq_len = x.shape[2]
    # add dimensions at position 0 (for number of batches) and at position 2 (for sequence length)
    # expand these dimensions
    std = self.std.unsqueeze(0).unsqueeze(2).expand(batch_size,-1,seq_len)
    mu = self.mu.unsqueeze(0).unsqueeze(2).expand(batch_size,-1,seq_len)
    if denormalize is False:
        x = torch.subtract(x,mu)
        x = torch.divide(x, std)
    else:
        x = torch.multiply(x, std)
        x = torch.add(x, mu)
    return x

NormalizationLayer1D

Bases: Module

Normalization layer for 1D feature vectors with shape (batch, features).

Computes and stores per-feature mean and standard deviation, then normalizes inputs to zero mean and unit variance. Can also denormalize outputs. Supports both 2D (batch, features) and 3D (batch, features, time) inputs. Accepts both torch.Tensor and numpy.ndarray for initialization.

Expected input shape: (batch_size, num_features) or (batch_size, num_features, sequence_length)

Attributes:

    _initialized (bool): Whether mean/std have been computed.
    std (Tensor): Per-feature standard deviations, shape (num_features,).
    mu (Tensor): Per-feature means, shape (num_features,).

Source code in src/bnode_core/nn/nn_utils/normalization.py
class NormalizationLayer1D(nn.Module):
    """Normalization layer for 1D feature vectors with shape (batch, features).

    Computes and stores per-feature mean and standard deviation, then normalizes inputs
    to zero mean and unit variance. Can also denormalize outputs. Supports both 2D
    (batch, features) and 3D (batch, features, time) inputs. Accepts both torch.Tensor
    and numpy.ndarray for initialization.

    Expected input shape: (batch_size, num_features) or (batch_size, num_features, sequence_length)

    Attributes:
        _initialized (bool): Whether mean/std have been computed.
        std (torch.Tensor): Per-feature standard deviations, shape (num_features,).
        mu (torch.Tensor): Per-feature means, shape (num_features,).
    """
    def __init__(self, num_features):
        """Initialize normalization layer buffers.

        Args:
            num_features (int): Number of features/channels to normalize.
        """
        super().__init__()
        self.register_buffer("_initialized", torch.tensor(False))
        self.register_buffer('std', torch.zeros((num_features)))
        self.register_buffer('mu', torch.zeros(num_features))

    def initialize_normalization(self, x, eps = 1e-5, verbose = False, name = None):
        """Compute and store mean and std from input data.

        Calculates per-feature statistics across batch dimension. Adds epsilon to variance
        for numerical stability. Supports both torch.Tensor and numpy.ndarray inputs.

        Args:
            x (torch.Tensor or np.ndarray): Input data with shape (batch_size, num_features).
            eps (float, optional): Small constant added to variance for stability. Defaults to 1e-5.
            verbose (bool, optional): If True, logs initialization info. Defaults to False.
            name (str, optional): Name for logging output. Defaults to None.

        Raises:
            ValueError: If x is neither torch.Tensor nor np.ndarray.
            RuntimeError: If normalization layer has already been initialized.

        Side Effects:
            Sets self.mu and self.std buffers, logs initialization if verbose=True.
        """
        if not self._initialized:
            if isinstance(x, torch.Tensor):
                variance = torch.var(x, dim=(0)).detach()
                self.std.set_(torch.sqrt(variance + torch.ones(variance.size()).to(variance.device) * eps))
                self.mu.set_(torch.mean(x, dim=(0)).detach())
            elif isinstance(x, np.ndarray):
                variance = np.var(x, axis=0)
                self.std.set_(torch.sqrt(torch.tensor(variance + np.ones(variance.shape) * eps, dtype=torch.float32)))
                self.mu.set_(torch.tensor(np.mean(x, axis=0), dtype=torch.float32))
            else:
                raise ValueError('Unknown type of input: {}'.format(type(x)))
            self._initialized = torch.tensor(True)
            assert self.std.requires_grad == False
            assert self.mu.requires_grad == False

            logging.info("Initialized normalization layer {} with mean {} and std {}".format(name, self.mu, self.std))
        else:
            raise RuntimeError("normalization layer has already been initialized")

    def forward(self, x: torch.Tensor, denormalize: bool = False) -> torch.Tensor:
        """Normalize or denormalize input features.

        If not initialized and normalizing, automatically initializes from input. Handles
        both 2D (batch, features) and 3D (batch, features, time) inputs by broadcasting.
        Normalizes via (x - mu) / std or denormalizes via x * std + mu.

        Args:
            x (torch.Tensor): Input with shape (batch_size, num_features) or 
                (batch_size, num_features, sequence_length).
            denormalize (bool, optional): If False, normalize input. If True, denormalize.
                Defaults to False.

        Returns:
            torch.Tensor: Normalized or denormalized data with same shape as input.
        """
        if not denormalize:
            if not self._initialized:
                self.initialize_normalization(x)
        batch_size = x.shape[0]
        # add dimension at position 0 and expand to batch_size
        std = self.std.unsqueeze(0).expand(batch_size,-1)
        mu = self.mu.unsqueeze(0).expand(batch_size,-1)
        if len(x.shape) == 3:
            # if x is a 3D tensor, we assume it has shape (batch_size, num_features, sequence_length)
            seq_len = x.shape[2]
            std = std.unsqueeze(2).expand(batch_size,-1,seq_len)
            mu = mu.unsqueeze(2).expand(batch_size,-1,seq_len)
        if not denormalize:
            x = torch.subtract(x, mu)
            x = torch.divide(x, std)
        else:
            x = torch.multiply(x, std)
            x = torch.add(x, mu)
        return x

    def __repr__(self) -> str:
        """Return string representation of the layer.

        Returns:
            str: String showing layer type and number of features.
        """
        return 'NormalizationLayer1D(num_features={})'.format(self.std.shape[0])
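
Example

A sketch of explicit initialization from a numpy array, assuming the statistics should come from a separate (hypothetical) training set rather than from the first forward pass:

import numpy as np
import torch

from bnode_core.nn.nn_utils.normalization import NormalizationLayer1D

layer = NormalizationLayer1D(num_features=4)

# Fix the statistics up front from a numpy array (hypothetical training data) ...
train_features = 3.0 * np.random.randn(1000, 4) + 1.0
layer.initialize_normalization(train_features, name='features')

# ... they are then reused by every forward() call; calling
# initialize_normalization a second time raises RuntimeError.
x = torch.randn(32, 4)
x_norm = layer(x)                             # (x - mu) / std, per feature
x_back = layer(x_norm, denormalize=True)
print(torch.allclose(x, x_back, atol=1e-5))   # True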

__init__(num_features)

Initialize normalization layer buffers.

Parameters:

    num_features (int, required): Number of features/channels to normalize.

Source code in src/bnode_core/nn/nn_utils/normalization.py
def __init__(self, num_features):
    """Initialize normalization layer buffers.

    Args:
        num_features (int): Number of features/channels to normalize.
    """
    super().__init__()
    self.register_buffer("_initialized", torch.tensor(False))
    self.register_buffer('std', torch.zeros((num_features)))
    self.register_buffer('mu', torch.zeros(num_features))

initialize_normalization(x, eps=1e-05, verbose=False, name=None)

Compute and store mean and std from input data.

Calculates per-feature statistics across batch dimension. Adds epsilon to variance for numerical stability. Supports both torch.Tensor and numpy.ndarray inputs.

Parameters:

    x (torch.Tensor or np.ndarray, required): Input data with shape (batch_size, num_features).
    eps (float, default 1e-05): Small constant added to variance for stability.
    verbose (bool, default False): If True, logs initialization info.
    name (str, default None): Name for logging output.

Raises:

    ValueError: If x is neither torch.Tensor nor np.ndarray.
    RuntimeError: If normalization layer has already been initialized.

Side Effects

Sets self.mu and self.std buffers, logs initialization if verbose=True.

Source code in src/bnode_core/nn/nn_utils/normalization.py
def initialize_normalization(self, x, eps = 1e-5, verbose = False, name = None):
    """Compute and store mean and std from input data.

    Calculates per-feature statistics across batch dimension. Adds epsilon to variance
    for numerical stability. Supports both torch.Tensor and numpy.ndarray inputs.

    Args:
        x (torch.Tensor or np.ndarray): Input data with shape (batch_size, num_features).
        eps (float, optional): Small constant added to variance for stability. Defaults to 1e-5.
        verbose (bool, optional): If True, logs initialization info. Defaults to False.
        name (str, optional): Name for logging output. Defaults to None.

    Raises:
        ValueError: If x is neither torch.Tensor nor np.ndarray.
        RuntimeError: If normalization layer has already been initialized.

    Side Effects:
        Sets self.mu and self.std buffers, logs initialization if verbose=True.
    """
    if not self._initialized:
        if isinstance(x, torch.Tensor):
            variance = torch.var(x, dim=(0)).detach()
            self.std.set_(torch.sqrt(variance + torch.ones(variance.size()).to(variance.device) * eps))
            self.mu.set_(torch.mean(x, dim=(0)).detach())
        elif isinstance(x, np.ndarray):
            variance = np.var(x, axis=0)
            self.std.set_(torch.sqrt(torch.tensor(variance + np.ones(variance.shape) * eps, dtype=torch.float32)))
            self.mu.set_(torch.tensor(np.mean(x, axis=0), dtype=torch.float32))
        else:
            raise ValueError('Unknown type of input: {}'.format(type(x)))
        self._initialized = torch.tensor(True)
        assert self.std.requires_grad == False
        assert self.mu.requires_grad == False

        logging.info("Initialized normalization layer {} with mean {} and std {}".format(name, self.mu, self.std))
    else:
        raise RuntimeError("normalization layer has already been initialized")

forward(x: torch.Tensor, denormalize: bool = False) -> torch.Tensor

Normalize or denormalize input features.

If not initialized and normalizing, automatically initializes from input. Handles both 2D (batch, features) and 3D (batch, features, time) inputs by broadcasting. Normalizes via (x - mu) / std or denormalizes via x * std + mu.

Parameters:

    x (torch.Tensor, required): Input with shape (batch_size, num_features) or (batch_size, num_features, sequence_length).
    denormalize (bool, default False): If False, normalize input. If True, denormalize.

Returns:

    torch.Tensor: Normalized or denormalized data with same shape as input.

Source code in src/bnode_core/nn/nn_utils/normalization.py
def forward(self, x: torch.Tensor, denormalize: bool = False) -> torch.Tensor:
    """Normalize or denormalize input features.

    If not initialized and normalizing, automatically initializes from input. Handles
    both 2D (batch, features) and 3D (batch, features, time) inputs by broadcasting.
    Normalizes via (x - mu) / std or denormalizes via x * std + mu.

    Args:
        x (torch.Tensor): Input with shape (batch_size, num_features) or 
            (batch_size, num_features, sequence_length).
        denormalize (bool, optional): If False, normalize input. If True, denormalize.
            Defaults to False.

    Returns:
        torch.Tensor: Normalized or denormalized data with same shape as input.
    """
    if not denormalize:
        if not self._initialized:
            self.initialize_normalization(x)
    batch_size = x.shape[0]
    # add dimension at position 0 and expand to batch_size
    std = self.std.unsqueeze(0).expand(batch_size,-1)
    mu = self.mu.unsqueeze(0).expand(batch_size,-1)
    if len(x.shape) == 3:
        # if x is a 3D tensor, we assume it has shape (batch_size, num_features, sequence_length)
        seq_len = x.shape[2]
        std = std.unsqueeze(2).expand(batch_size,-1,seq_len)
        mu = mu.unsqueeze(2).expand(batch_size,-1,seq_len)
    if not denormalize:
        x = torch.subtract(x, mu)
        x = torch.divide(x, std)
    else:
        x = torch.multiply(x, std)
        x = torch.add(x, mu)
    return x
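
Example

A short sketch of the 3D broadcasting behaviour described above; the random tensors are placeholders:

import torch

from bnode_core.nn.nn_utils.normalization import NormalizationLayer1D

layer = NormalizationLayer1D(num_features=2)

x2d = torch.randn(64, 2)      # first forward call initializes mu/std from this batch
_ = layer(x2d)

x3d = torch.randn(8, 2, 120)  # (batch, features, time)
y = layer(x3d)                # per-feature stats are broadcast to (8, 2, 120)
print(y.shape)                # torch.Size([8, 2, 120])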

__repr__() -> str

Return string representation of the layer.

Returns:

    str: String showing layer type and number of features.

Source code in src/bnode_core/nn/nn_utils/normalization.py
def __repr__(self) -> str:
    """Return string representation of the layer.

    Returns:
        str: String showing layer type and number of features.
    """
    return 'NormalizationLayer1D(num_features={})'.format(self.std.shape[0])

bnode_core.nn.nn_utils.load_data

Dataset loading utilities for neural network training.

Provides functions to load HDF5 datasets and their configurations, and create PyTorch-compatible dataset objects for training.

Attention

This documentation is generated by AI. Please be aware of possible inaccuracies.

TimeSeriesDataset

Bases: StackDataset

Dataset for time series with sliding window sampling of variable-length subsequences.

Extends StackDataset to enable extracting subsequences from longer time series via a sliding window approach. The full sequences are stored internally, but __getitem__ returns only a subsequence of specified length. This enables training on different sequence lengths without reloading data, and increases effective dataset size by treating each sliding window position as a separate sample.

The dataset expects dict-style data with a 'time' key, where all time series have shape (n_samples, n_channels, n_timesteps). Non-time-series data (2D) is replicated across all windows from the same sample.

Attributes:

    seq_len (int): Length of subsequences returned by __getitem__.
    mapping (list): List of [sample_idx, start_pos, end_pos] tuples defining each sliding window position across all samples.
    _length (int): Total number of sliding windows (dataset length).
    _length_old (int): Original number of samples before windowing.

Source code in src/bnode_core/nn/nn_utils/load_data.py
class TimeSeriesDataset(torch.utils.data.StackDataset):
    """Dataset for time series with sliding window sampling of variable-length subsequences.

    Extends StackDataset to enable extracting subsequences from longer time series via a
    sliding window approach. The full sequences are stored internally, but __getitem__
    returns only a subsequence of specified length. This enables training on different
    sequence lengths without reloading data, and increases effective dataset size by
    treating each sliding window position as a separate sample.

    The dataset expects dict-style data with a 'time' key, where all time series have shape
    (n_samples, n_channels, n_timesteps). Non-time-series data (2D) is replicated across
    all windows from the same sample.

    Attributes:
        seq_len (int): Length of subsequences returned by __getitem__.
        mapping (list): List of [sample_idx, start_pos, end_pos] tuples defining each
            sliding window position across all samples.
        _length (int): Total number of sliding windows (dataset length).
        _length_old (int): Original number of samples before windowing.
    """
    def __init__(self, seq_len: int, *args, **kwargs):
        """Initialize TimeSeriesDataset with sliding window parameters.

        Args:
            seq_len (int): Length of subsequences to extract. If larger than available
                time series length, will be clamped to maximum available length.
            *args: Positional arguments passed to parent StackDataset.
            **kwargs: Keyword arguments passed to parent StackDataset. Must result in
                a dict-style dataset with a 'time' key.

        Raises:
            AssertionError: If datasets is not a dict or lacks 'time' key.
        """
        super().__init__(*args, **kwargs)
        assert isinstance(self.datasets, dict), "can only handle dict style stacked datasets with one key-value pair time"
        assert 'time' in self.datasets.keys(), "need one dataset with key time to define the map" 
        self._length_old = self._length
        if seq_len > self.datasets['time'].shape[2]:
            Warning("seq_len is {}, setting to len of timeseries".format(seq_len))
            seq_len = self.datasets['time'].shape[2]
        self.seq_len = seq_len
        self.initialize_map(seq_len)

    def set_seq_len(self, seq_len: int):
        """Change the subsequence length and rebuild the sliding window mapping.

        Args:
            seq_len (int): New subsequence length. If None, 0, or larger than available
                time series length, will be clamped to maximum available length.
        """
        if seq_len == None or seq_len == 0 or seq_len > self.datasets['time'].shape[2]:
            Warning("seq_len is {}, setting to len of timeseries".format(seq_len))
            seq_len = self.datasets['time'].shape[2]
        self.seq_len = seq_len
        self.initialize_map(seq_len)

    def initialize_map(self, seq_len: int):
        """Create the sliding window index mapping for all samples.

        Builds a mapping list where each entry [sample_idx, start_pos, end_pos] defines
        a sliding window position. Windows slide by 1 timestep across each sample, then
        continue to the next sample. This treats each window position as an independent
        dataset item.

        Args:
            seq_len (int): Length of sliding windows. Must be at least 1.

        Raises:
            AssertionError: If seq_len < 1.

        Side Effects:
            - Sets self.mapping to list of [sample, start, end] tuples
            - Updates self._length to total number of windows across all samples
        """
        assert seq_len > 0, "seq_len must be at least 1"
        # define map
        n_batches_per_sample = (self.datasets['time'].shape[2] - (seq_len - 1))
        n_batches_total = n_batches_per_sample * self._length_old
        self.mapping = [[] for i in range(n_batches_total)]
        self._length = n_batches_total
        # fill out map. mapping shall contain [n_sample, start_position, stop_position]
        k_stop=seq_len # stop position in sequence
        j=0 # sample position in datasets
        for i in range(n_batches_total):
            self.mapping[i] = [j, k_stop-seq_len, k_stop]
            if k_stop + 2 > n_batches_per_sample + seq_len: # if the over next sequence would be out of bounds, go to next sample (+1 more because of > and not >=)
                k_stop = seq_len
                j += 1
            else:
                k_stop += 1
        #self.mapping = torch.tensor(np.array(self.mapping), dtype=torch.int64)


    def __getitem__(self, index: int) -> dict:
        """Get a single sliding window sample.

        Args:
            index (int): Index of the sliding window to retrieve.

        Returns:
            dict: Dictionary with same keys as self.datasets. For 3D+ arrays (time series),
                returns subsequence [start:end] from appropriate sample. For 2D arrays,
                returns full array for the sample.
        """
        i, k_start, k_stop = self.mapping[index]
        ret_val = {}
        for key, value in self.datasets.items():
            if value.ndim == 2:
                ret_val[key] = value[i, :]
            else:
                ret_val[key] = value[i, :, k_start:k_stop]
        return ret_val

    def __getitems__(self, indices: list) -> list:
        """Get multiple sliding window samples (batch retrieval).

        Args:
            indices (list): List of window indices to retrieve.

        Returns:
            list: List of dictionaries, one per index, each containing the requested
                sliding window data.

        Note:
            This method requires PyTorch >= 2.2 for optimal batched data loading.
        """
        samples = [None] * len(indices)
        for i, index in enumerate(indices):
            samples[i] = self.__getitem__(index)
        return samples

    def __len__(self) -> int:
        """Return the total number of sliding windows in the dataset.

        Returns:
            int: Total number of windows across all samples.
        """
        return self._length
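
Example

A sketch with hypothetical toy tensors showing how the sliding window increases the effective dataset size:

import torch

from bnode_core.nn.nn_utils.load_data import TimeSeriesDataset

# Hypothetical toy data: 8 samples, 2 channels, 50 timesteps.
time = torch.linspace(0.0, 1.0, 50).expand(8, 1, 50)
states = torch.randn(8, 2, 50)
parameters = torch.randn(8, 3)            # 2D data is returned whole for each window

ds = TimeSeriesDataset(20, time=time, states=states, parameters=parameters)
print(len(ds))                            # 8 * (50 - 19) = 248 sliding windows
sample = ds[0]
print(sample['states'].shape)             # torch.Size([2, 20])
print(sample['parameters'].shape)         # torch.Size([3])

ds.set_seq_len(50)                        # full sequences: one window per sample
print(len(ds))                            # 8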

__init__(seq_len: int, *args, **kwargs)

Initialize TimeSeriesDataset with sliding window parameters.

Parameters:

    seq_len (int, required): Length of subsequences to extract. If larger than available time series length, will be clamped to maximum available length.
    *args: Positional arguments passed to parent StackDataset.
    **kwargs: Keyword arguments passed to parent StackDataset. Must result in a dict-style dataset with a 'time' key.

Raises:

    AssertionError: If datasets is not a dict or lacks 'time' key.

Source code in src/bnode_core/nn/nn_utils/load_data.py
def __init__(self, seq_len: int, *args, **kwargs):
    """Initialize TimeSeriesDataset with sliding window parameters.

    Args:
        seq_len (int): Length of subsequences to extract. If larger than available
            time series length, will be clamped to maximum available length.
        *args: Positional arguments passed to parent StackDataset.
        **kwargs: Keyword arguments passed to parent StackDataset. Must result in
            a dict-style dataset with a 'time' key.

    Raises:
        AssertionError: If datasets is not a dict or lacks 'time' key.
    """
    super().__init__(*args, **kwargs)
    assert isinstance(self.datasets, dict), "can only handle dict style stacked datasets with one key-value pair time"
    assert 'time' in self.datasets.keys(), "need one dataset with key time to define the map" 
    self._length_old = self._length
    if seq_len > self.datasets['time'].shape[2]:
        Warning("seq_len is {}, setting to len of timeseries".format(seq_len))
        seq_len = self.datasets['time'].shape[2]
    self.seq_len = seq_len
    self.initialize_map(seq_len)

set_seq_len(seq_len: int)

Change the subsequence length and rebuild the sliding window mapping.

Parameters:

    seq_len (int, required): New subsequence length. If None, 0, or larger than available time series length, will be clamped to maximum available length.

Source code in src/bnode_core/nn/nn_utils/load_data.py
def set_seq_len(self, seq_len: int):
    """Change the subsequence length and rebuild the sliding window mapping.

    Args:
        seq_len (int): New subsequence length. If None, 0, or larger than available
            time series length, will be clamped to maximum available length.
    """
    if seq_len == None or seq_len == 0 or seq_len > self.datasets['time'].shape[2]:
        Warning("seq_len is {}, setting to len of timeseries".format(seq_len))
        seq_len = self.datasets['time'].shape[2]
    self.seq_len = seq_len
    self.initialize_map(seq_len)

initialize_map(seq_len: int)

Create the sliding window index mapping for all samples.

Builds a mapping list where each entry [sample_idx, start_pos, end_pos] defines a sliding window position. Windows slide by 1 timestep across each sample, then continue to the next sample. This treats each window position as an independent dataset item.

Parameters:

    seq_len (int, required): Length of sliding windows. Must be at least 1.

Raises:

    AssertionError: If seq_len < 1.

Side Effects

    - Sets self.mapping to list of [sample, start, end] tuples.
    - Updates self._length to total number of windows across all samples.

Source code in src/bnode_core/nn/nn_utils/load_data.py
def initialize_map(self, seq_len: int):
    """Create the sliding window index mapping for all samples.

    Builds a mapping list where each entry [sample_idx, start_pos, end_pos] defines
    a sliding window position. Windows slide by 1 timestep across each sample, then
    continue to the next sample. This treats each window position as an independent
    dataset item.

    Args:
        seq_len (int): Length of sliding windows. Must be at least 1.

    Raises:
        AssertionError: If seq_len < 1.

    Side Effects:
        - Sets self.mapping to list of [sample, start, end] tuples
        - Updates self._length to total number of windows across all samples
    """
    assert seq_len > 0, "seq_len must be at least 1"
    # define map
    n_batches_per_sample = (self.datasets['time'].shape[2] - (seq_len - 1))
    n_batches_total = n_batches_per_sample * self._length_old
    self.mapping = [[] for i in range(n_batches_total)]
    self._length = n_batches_total
    # fill out map. mapping shall contain [n_sample, start_position, stop_position]
    k_stop=seq_len # stop position in sequence
    j=0 # sample position in datasets
    for i in range(n_batches_total):
        self.mapping[i] = [j, k_stop-seq_len, k_stop]
        if k_stop + 2 > n_batches_per_sample + seq_len: # if the over next sequence would be out of bounds, go to next sample (+1 more because of > and not >=)
            k_stop = seq_len
            j += 1
        else:
            k_stop += 1
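
Example

An illustration of the resulting mapping for a hypothetical tiny dataset:

import torch

from bnode_core.nn.nn_utils.load_data import TimeSeriesDataset

# 2 samples, 1 channel, 5 timesteps, seq_len=3:
# 5 - (3 - 1) = 3 windows per sample, 6 windows in total.
time = torch.arange(5.0).expand(2, 1, 5)
ds = TimeSeriesDataset(3, time=time)
print(ds.mapping)
# [[0, 0, 3], [0, 1, 4], [0, 2, 5], [1, 0, 3], [1, 1, 4], [1, 2, 5]]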

__getitem__(index: int) -> dict

Get a single sliding window sample.

Parameters:

    index (int, required): Index of the sliding window to retrieve.

Returns:

    dict: Dictionary with same keys as self.datasets. For 3D+ arrays (time series), returns subsequence [start:end] from the appropriate sample. For 2D arrays, returns the full array for the sample.

Source code in src/bnode_core/nn/nn_utils/load_data.py
def __getitem__(self, index: int) -> dict:
    """Get a single sliding window sample.

    Args:
        index (int): Index of the sliding window to retrieve.

    Returns:
        dict: Dictionary with same keys as self.datasets. For 3D+ arrays (time series),
            returns subsequence [start:end] from appropriate sample. For 2D arrays,
            returns full array for the sample.
    """
    i, k_start, k_stop = self.mapping[index]
    ret_val = {}
    for key, value in self.datasets.items():
        if value.ndim == 2:
            ret_val[key] = value[i, :]
        else:
            ret_val[key] = value[i, :, k_start:k_stop]
    return ret_val

__getitems__(indices: list) -> list

Get multiple sliding window samples (batch retrieval).

Parameters:

    indices (list, required): List of window indices to retrieve.

Returns:

    list: List of dictionaries, one per index, each containing the requested sliding window data.

Note

This method requires PyTorch >= 2.2 for optimal batched data loading.

Source code in src/bnode_core/nn/nn_utils/load_data.py
def __getitems__(self, indices: list) -> list:
    """Get multiple sliding window samples (batch retrieval).

    Args:
        indices (list): List of window indices to retrieve.

    Returns:
        list: List of dictionaries, one per index, each containing the requested
            sliding window data.

    Note:
        This method requires PyTorch >= 2.2 for optimal batched data loading.
    """
    samples = [None] * len(indices)
    for i, index in enumerate(indices):
        samples[i] = self.__getitem__(index)
    return samples

__len__() -> int

Return the total number of sliding windows in the dataset.

Returns:

    int: Total number of windows across all samples.

Source code in src/bnode_core/nn/nn_utils/load_data.py
def __len__(self) -> int:
    """Return the total number of sliding windows in the dataset.

    Returns:
        int: Total number of windows across all samples.
    """
    return self._length

load_validate_dataset_config(path: Path) -> base_pModelClass

Load and validate dataset configuration from YAML file.

Parameters:

    path (Path, required): Path to the dataset configuration YAML file.

Returns:

    base_pModelClass: Validated dataset configuration.

Raises:

    FileNotFoundError: If the configuration file doesn't exist.

Note

Uses OmegaConf to load the YAML and validates it against the base_pModelClass schema.

Source code in src/bnode_core/nn/nn_utils/load_data.py
def load_validate_dataset_config(path: Path) -> base_pModelClass:
    """Load and validate dataset configuration from YAML file.

    Args:
        path: Path to the dataset configuration YAML file.

    Returns:
        Validated dataset configuration as base_pModelClass instance.

    Raises:
        FileNotFoundError: If configuration file doesn't exist.

    Note:
        Uses OmegaConf to load YAML and validates against base_pModelClass schema.
    """

    if not path.exists():
        raise FileNotFoundError('Dataset config file not found: {}'.format(path))

    logging.info('Loading dataset config file: {}'.format(path))
    _dataset_config_dict = OmegaConf.load(path)
    _dataset_config_dict = OmegaConf.to_object(_dataset_config_dict) # make dict
    dataset_config = base_pModelClass(**_dataset_config_dict) # validate
    logging.info('Validated dataset config file: {}'.format(path))
    return dataset_config

load_dataset_and_config(dataset_name: str, dataset_path: str) -> Tuple[h5py.File, Optional[base_pModelClass]]

Load HDF5 dataset and its configuration.

Loads the HDF5 dataset file and attempts to load its configuration. If the configuration file doesn't exist, None is returned for the config.

Parameters:

    dataset_name (str, required): Name/identifier of the dataset.
    dataset_path (str, required): Explicit path to the dataset file, or an empty string to use the default location.

Returns:

    Tuple[h5py.File, Optional[base_pModelClass]]: Tuple of (dataset, dataset_config) where:

    - dataset: Open h5py.File handle to the HDF5 dataset.
    - dataset_config: Validated configuration (base_pModelClass) or None if not found.

Note

The returned h5py.File should be closed when done (dataset.close()). Uses filepath_dataset_from_config to resolve the actual file path.

Source code in src/bnode_core/nn/nn_utils/load_data.py
def load_dataset_and_config(dataset_name: str, dataset_path: str) -> Tuple[h5py.File, Optional[base_pModelClass]]:
    """Load HDF5 dataset and its configuration.

    Loads the HDF5 dataset file and attempts to load its configuration.
    If configuration file doesn't exist, returns None for config.

    Args:
        dataset_name: Name/identifier of the dataset.
        dataset_path: Explicit path to dataset file, or empty string to use default location.

    Returns:
        Tuple of (dataset, dataset_config) where:

            - dataset: Open h5py.File handle to HDF5 dataset.
            - dataset_config: Validated configuration (base_pModelClass) or None if not found.

    Note:
        The returned h5py.File should be closed when done (dataset.close()).
        Uses filepath_dataset_from_config to resolve actual file path.
    """
    _path = filepath_dataset_from_config(dataset_name, dataset_path)

    dataset = h5py.File(_path, 'r')
    logging.info('Loaded dataset from file: {}'.format(_path))

    _path = filepath_dataset_config_from_name(dataset_name)
    if not _path.exists():
        logging.info('No dataset config file found, using information from dataset file')
        dataset_config = None
    else:
        dataset_config = load_validate_dataset_config(_path)
    return dataset, dataset_config
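
Example

A usage sketch; 'my_dataset' is a placeholder name, and an empty dataset_path falls back to the default location:

from bnode_core.nn.nn_utils.load_data import load_dataset_and_config

dataset, dataset_config = load_dataset_and_config('my_dataset', '')
try:
    print(list(dataset.keys()))     # HDF5 groups, e.g. 'time', 'train', 'test', ...
    if dataset_config is None:
        print('no YAML config found next to the dataset file')
finally:
    dataset.close()                 # the caller is responsible for closing the file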

make_stacked_dataset(dataset: h5py.File, context: str, seq_len_from_file: Optional[int] = None, seq_len_batches: Optional[int] = None) -> Union[torch.utils.data.StackDataset, TimeSeriesDataset]

Create a PyTorch dataset from HDF5 data with optional time series batching.

Loads time series data (states, derivatives, parameters, controls, outputs) from an HDF5 file and wraps it in a PyTorch StackDataset. If seq_len_batches is specified, returns a TimeSeriesDataset that enables sliding window sampling for variable-length sequences.

Parameters:

    dataset (h5py.File, required): Open HDF5 file containing time series data with groups for different contexts (train/test/validation).
    context (str, required): Dataset context to load. Must be one of: 'train', 'test', 'validation', 'common_test', or 'common_validation'.
    seq_len_from_file (int, default None): If provided, truncates loaded sequences to this length from the original file data. Defaults to the full sequence length.
    seq_len_batches (int, default None): If provided, returns a TimeSeriesDataset that extracts subsequences of this length via a sliding window. If None, returns a standard StackDataset with full sequences.

Returns:

    torch.utils.data.StackDataset or TimeSeriesDataset: A dataset that yields dictionaries containing tensors for 'time', 'states', and optionally 'states_der', 'parameters', 'controls', and 'outputs'. Each tensor has shape (batch, channels, time_steps).

Note

    - All None-valued arrays are automatically excluded from the returned dataset.
    - The time tensor is replicated across the batch dimension from a single time vector.
    - When seq_len_batches is used, the dataset length increases to accommodate all possible sliding windows across the original sequences.

Source code in src/bnode_core/nn/nn_utils/load_data.py
def make_stacked_dataset(
    dataset: h5py.File, 
    context: str, 
    seq_len_from_file: Optional[int] = None, 
    seq_len_batches: Optional[int] = None
) -> Union[torch.utils.data.StackDataset, 'TimeSeriesDataset']:
    """Create a PyTorch dataset from HDF5 data with optional time series batching.

    Loads time series data (states, derivatives, parameters, controls, outputs) from an HDF5
    file and wraps it in a PyTorch StackDataset. If seq_len_batches is specified, returns a
    TimeSeriesDataset that enables sliding window sampling for variable-length sequences.

    Args:
        dataset (h5py.File): Open HDF5 file containing time series data with groups for
            different contexts (train/test/validation).
        context (str): Dataset context to load. Must be one of: 'train', 'test', 'validation',
            'common_test', or 'common_validation'.
        seq_len_from_file (int, optional): If provided, truncates loaded sequences to this length
            from the original file data. Defaults to None (use full sequence length).
        seq_len_batches (int, optional): If provided, returns a TimeSeriesDataset that extracts
            subsequences of this length via sliding window. If None, returns standard StackDataset
            with full sequences. Defaults to None.

    Returns:
        torch.utils.data.StackDataset or TimeSeriesDataset: A dataset that yields dictionaries
            containing tensors for 'time', 'states', and optionally 'states_der', 'parameters',
            'controls', and 'outputs'. Each tensor has shape (batch, channels, time_steps).

    Note:
        - All None-valued arrays are automatically excluded from the returned dataset
        - Time tensor is replicated across batch dimension from single time vector
        - When seq_len_batches is used, the dataset length increases to accommodate all
          possible sliding windows across the original sequences
    """
    assert context in ['train', 'test', 'validation', 'common_test', 'common_validation'], 'context must be one of train, test, validation, common_test, common_validation'

    # get tensors of dataset
    time = dataset['time'][:]
    states = dataset[context]['states'][:]
    states_der = dataset[context]['states_der'][:] if 'states_der' in dataset[context].keys() else None
    parameters = dataset[context]['parameters'][:] if 'parameters' in dataset[context].keys() else None
    controls = dataset[context]['controls'][:] if 'controls' in dataset[context].keys() else None
    outputs = dataset[context]['outputs'][:] if 'outputs' in dataset[context].keys() else None

    # cut data from file to seq_len
    if seq_len_from_file is not None:
        time = time[:seq_len_from_file]
        states = states[:,:,:seq_len_from_file]
        states_der = states_der[:,:,:seq_len_from_file] if states_der is not None else None # TODO: add finite difference calculation for states_der and cfg.dataset_prep entries to say if derivatives are included
        parameters = parameters[:] if parameters is not None else None
        controls = controls[:,:,:seq_len_from_file] if controls is not None else None
        outputs = outputs[:,:,:seq_len_from_file] if outputs is not None else None

    # define wrapper to delete nones from kwargs dict
    def _delete_nones(**kwargs):
        kwargs = dict((k,v) for k,v in kwargs.items() if v is not None)
        return kwargs

    # make torch dataset with dict as output
    dataset_type = torch.utils.data.StackDataset if seq_len_batches is None else lambda *args, **kwargs: TimeSeriesDataset(seq_len_batches, *args, **kwargs)
    torch_dataset = dataset_type(
        **_delete_nones(
            time = torch.tensor(time, dtype=torch.float32).unsqueeze(0).expand(states.shape[0], -1).unsqueeze(1),
            states = torch.tensor(states, dtype=torch.float32),
            states_der = torch.tensor(states_der, dtype=torch.float32) if states_der is not None else None,
            parameters = torch.tensor(parameters, dtype=torch.float32) if parameters is not None else None,
            controls = torch.tensor(controls, dtype=torch.float32) if controls is not None else None,
            outputs = torch.tensor(outputs, dtype=torch.float32) if outputs is not None else None,
        )
    )
    logging.info('Created {} with {} and sequence length {}'.format(type(torch_dataset),torch_dataset.datasets.keys(), torch_dataset.datasets['time'].shape[2]))
    return torch_dataset
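
Example

A sketch of the intended workflow with a DataLoader, assuming a dataset named 'my_dataset' (placeholder) whose file contains a 'train' group:

from torch.utils.data import DataLoader

from bnode_core.nn.nn_utils.load_data import load_dataset_and_config, make_stacked_dataset

dataset, _ = load_dataset_and_config('my_dataset', '')
train_ds = make_stacked_dataset(dataset, 'train', seq_len_batches=64)
dataset.close()                    # the data has already been copied into tensors

loader = DataLoader(train_ds, batch_size=32, shuffle=True)
batch = next(iter(loader))
# batch is a dict of tensors; 3D entries have shape (32, n_channels, 64),
# while 2D entries such as 'parameters' have shape (32, n_parameters).
print(batch['states'].shape)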

bnode_core.nn.nn_utils.kullback_leibler

Kullback-Leibler divergence computation for VAE training.

Provides functions to compute KL divergence between learned latent distributions and standard normal prior, with support for timeseries data and dimension analysis.

Attention

This documentation is generated by AI. Please be aware of possible inaccuracies.

kullback_leibler(mu: torch.Tensor, logvar: torch.Tensor, per_dimension: bool = False, reduce: bool = True, time_series_aggregation_mode: Optional[str] = 'mean') -> torch.Tensor

Compute KL divergence KL(N(mu, exp(logvar)) || N(0, I)).

Calculates the Kullback-Leibler divergence between a learned normal distribution N(mu, sigma^2) and the standard normal prior N(0, 1). Uses the analytical formula: KL = -0.5 * sum(1 + logvar - mu^2 - exp(logvar))

Parameters:

    mu (Tensor, required): Mean of learned distribution, shape (batch, latent_dim) or (batch, latent_dim, seq_len) for timeseries.
    logvar (Tensor, required): Log-variance of learned distribution, same shape as mu.
    per_dimension (bool, default False): If True, return KL divergence per latent dimension instead of summing across dimensions.
    reduce (bool, default True): If True, return mean over batch. If False, return per-sample values.
    time_series_aggregation_mode (Optional[str], default 'mean'): How to aggregate over the time dimension if the input is a timeseries (3D). Options: 'mean', 'max', 'sum', or None (keep the time dimension).

Returns:

    Tensor: KL divergence tensor. Shape depends on the parameters:

    - per_dimension=False, reduce=True: scalar
    - per_dimension=False, reduce=False: (batch,)
    - per_dimension=True, reduce=True: (latent_dim,)
    - per_dimension=True, reduce=False: (batch, latent_dim)

Note

The KL divergence is always non-negative and equals zero only when the learned distribution matches the prior exactly.

Source code in src/bnode_core/nn/nn_utils/kullback_leibler.py
def kullback_leibler(
    mu: torch.Tensor, 
    logvar: torch.Tensor, 
    per_dimension: bool = False, 
    reduce: bool = True, 
    time_series_aggregation_mode: Optional[str] = 'mean'
) -> torch.Tensor:
    """Compute KL divergence KL(N(mu, exp(logvar)) || N(0, I)).

    Calculates the Kullback-Leibler divergence between a learned normal distribution
    N(mu, sigma^2) and the standard normal prior N(0, 1). Uses the analytical formula:
    KL = -0.5 * sum(1 + logvar - mu^2 - exp(logvar))

    Args:
        mu: Mean of learned distribution, shape (batch, latent_dim) or 
            (batch, latent_dim, seq_len) for timeseries.
        logvar: Log-variance of learned distribution, same shape as mu.
        per_dimension: If True, return KL divergence per latent dimension instead of
            summing across dimensions. Default: False.
        reduce: If True, return mean over batch. If False, return per-sample values.
            Default: True.
        time_series_aggregation_mode: How to aggregate over time dimension if input is
            timeseries (3D). Options: 'mean', 'max', 'sum', or None (keep time dim).
            Default: 'mean'.

    Returns:
        KL divergence tensor. Shape depends on parameters:
            - per_dimension=False, reduce=True: scalar
            - per_dimension=False, reduce=False: (batch,)
            - per_dimension=True, reduce=True: (latent_dim,)
            - per_dimension=True, reduce=False: (batch, latent_dim)

    Note:
        The KL divergence is always non-negative and equals zero only when the learned
        distribution matches the prior exactly.
    """
    is_timeseries = len(mu.shape) == 3
    kl = -0.5 *(1 + logvar - mu.pow(2) - logvar.exp())

    if is_timeseries:
        if time_series_aggregation_mode == 'mean':
            kl = torch.mean(kl, dim=2)
        elif time_series_aggregation_mode == 'max':
            kl = torch.max(kl, dim=2)
        elif time_series_aggregation_mode == 'sum':
            kl = torch.sum(kl, dim=2)
        elif time_series_aggregation_mode == None:
            pass

    if per_dimension is False:
        kl = torch.sum(kl, dim=1)  

    if reduce:
        kl = torch.mean(kl, dim=0)

    return kl
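
Example

A brief sketch of the shapes involved; the random tensors are placeholders:

import torch

from bnode_core.nn.nn_utils.kullback_leibler import kullback_leibler

mu = torch.zeros(4, 8)              # batch of 4, latent dimension 8
logvar = torch.zeros(4, 8)          # unit variance: KL against N(0, I) is exactly zero
print(kullback_leibler(mu, logvar))                     # tensor(0.)

mu = torch.randn(4, 8, 100)         # timeseries latent: (batch, latent_dim, seq_len)
logvar = torch.zeros(4, 8, 100)
kl = kullback_leibler(mu, logvar, per_dimension=True, reduce=False,
                      time_series_aggregation_mode='mean')
print(kl.shape)                     # torch.Size([4, 8])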

count_populated_dimensions(mu: torch.Tensor, logvar: torch.Tensor, threshold: float = 0.05, kl_timeseries_aggregation_mode: str = 'mean', return_idx: bool = False) -> torch.Tensor

Count number of latent dimensions actively used by the model.

A dimension is considered "populated" or "active" if its KL divergence exceeds a threshold. This helps diagnose posterior collapse (when KL → 0 for all dimensions) and track how many dimensions the model actually uses.

Parameters:

    mu (Tensor, required): Mean of learned distribution, shape (batch, latent_dim) or (batch, latent_dim, seq_len).
    logvar (Tensor, required): Log-variance of learned distribution, same shape as mu.
    threshold (float, default 0.05): Minimum KL divergence for a dimension to be considered active.
    kl_timeseries_aggregation_mode (str, default 'mean'): Aggregation mode for timeseries data ('mean', 'max', 'sum').
    return_idx (bool, default False): If True, also return a boolean mask of active dimensions.

Returns:

    Tuple of (n_dim_populated, idx):

    - n_dim_populated: Number of dimensions with KL > threshold (scalar tensor).
    - idx: Boolean mask of active dimensions (tensor, or None if return_idx=False).

Note

Computes KL per dimension averaged over batch (per_dimension=True, reduce=True). Uses torch.no_grad() for efficiency since this is a diagnostic metric.

Source code in src/bnode_core/nn/nn_utils/kullback_leibler.py
def count_populated_dimensions(
    mu: torch.Tensor, 
    logvar: torch.Tensor, 
    threshold: float = 0.05, 
    kl_timeseries_aggregation_mode: str = 'mean', 
    return_idx: bool = False
) -> torch.Tensor:
    """Count number of latent dimensions actively used by the model.

    A dimension is considered "populated" or "active" if its KL divergence exceeds
    a threshold. This helps diagnose posterior collapse (when KL → 0 for all dimensions)
    and track how many dimensions the model actually uses.

    Args:
        mu: Mean of learned distribution, shape (batch, latent_dim) or (batch, latent_dim, seq_len).
        logvar: Log-variance of learned distribution, same shape as mu.
        threshold: Minimum KL divergence for a dimension to be considered active. Default: 0.05.
        kl_timeseries_aggregation_mode: Aggregation mode for timeseries data ('mean', 'max', 'sum').
            Default: 'mean'.
        return_idx: If True, also return boolean mask of active dimensions. Default: False.

    Returns:
        Tuple of (n_dim_populated, idx):
            - n_dim_populated: Number of dimensions with KL > threshold (scalar tensor).
            - idx: Boolean mask of active dimensions (tensor or None if return_idx=False).

    Note:
        Computes KL per dimension averaged over batch (per_dimension=True, reduce=True).
        Uses torch.no_grad() for efficiency since this is a diagnostic metric.
    """
    with torch.no_grad():
        kl = kullback_leibler(mu, logvar, per_dimension = True, reduce = True, time_series_aggregation_mode = kl_timeseries_aggregation_mode)
        # print histogram of kl divergence to terminal
        # print('kl divergence histogram: 0.0 - 2.0')
        # print(torch.histc(kl, bins=21, min=0, max=2.0))
        idx = kl > threshold
        n_dim_populated = torch.sum(idx)
    if return_idx:
        return n_dim_populated.detach(), idx.detach()
    else:
        return n_dim_populated.detach(), None
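
Example

A diagnostic sketch with hypothetical encoder outputs:

import torch

from bnode_core.nn.nn_utils.kullback_leibler import count_populated_dimensions

mu = torch.randn(128, 16)                    # hypothetical (batch, latent_dim) means
logvar = torch.full((128, 16), -1.0)
n_active, idx = count_populated_dimensions(mu, logvar, threshold=0.05, return_idx=True)
print(int(n_active))                         # number of dimensions with mean KL above 0.05
print(idx)                                   # boolean mask over the 16 latent dimensions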

bnode_core.nn.nn_utils.early_stopping

Early stopping utility for PyTorch model training.

Monitors validation loss and stops training when no improvement is observed for a specified number of epochs (patience). Saves best model checkpoint.

Attention

This documentation is generated by AI. Please be aware of possible inaccuracies.

EarlyStopping

Stop training early if validation loss doesn't improve after given patience.

Tracks validation loss and saves model checkpoints when improvements occur. Triggers early stopping flag when loss plateaus for 'patience' epochs.
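
Example

A minimal sketch of the intended training-loop usage; model, optimizer, train_one_epoch, and evaluate are hypothetical placeholders, and the model is expected to provide a save() method (see save_checkpoint below):

from bnode_core.nn.nn_utils.early_stopping import EarlyStopping

early_stopping = EarlyStopping(patience=20, threshold=1e-4, threshold_mode='rel',
                               path='checkpoint.pt', optimizer_path='optimizer.pt')

for epoch in range(1000):
    train_loss = train_one_epoch(model, optimizer)   # hypothetical helpers
    val_loss = evaluate(model)
    early_stopping(val_loss, model, epoch=epoch, optimizer=optimizer,
                   corresponding_loss=train_loss)
    if early_stopping.early_stop:
        break   # validation loss has not improved for `patience` epochs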

Attributes:

    patience: Number of epochs to wait before stopping after a loss plateau.
    verbose: If True, print messages for each loss improvement.
    counter: Number of epochs since the last loss improvement.
    best_score: Best validation loss seen so far.
    corresponding_score: Training loss corresponding to the best validation loss.
    early_stop: Flag indicating whether to stop training.
    score_last_save: Validation loss at the last checkpoint save.
    threshold: Minimum loss improvement to qualify as an improvement.
    threshold_mode: Either 'abs' (absolute) or 'rel' (relative) threshold.
    path: File path for saving the model checkpoint.
    optimizer_path: File path for saving the optimizer state.

Source code in src/bnode_core/nn/nn_utils/early_stopping.py
class EarlyStopping:
    """Stop training early if validation loss doesn't improve after given patience.

    Tracks validation loss and saves model checkpoints when improvements occur.
    Triggers early stopping flag when loss plateaus for 'patience' epochs.

    Attributes:
        patience: Number of epochs to wait before stopping after loss plateau.
        verbose: If True, print messages for each loss improvement.
        counter: Number of epochs since last loss improvement.
        best_score: Best validation loss seen so far.
        corresponding_score: Training loss corresponding to best validation loss.
        early_stop: Flag indicating whether to stop training.
        score_last_save: Validation loss at last checkpoint save.
        threshold: Minimum loss improvement to qualify as improvement.
        threshold_mode: Either 'abs' (absolute) or 'rel' (relative) threshold.
        path: File path for saving model checkpoint.
        optimizer_path: File path for saving optimizer state.
    """

    def __init__(
        self, 
        patience: int = 7, 
        verbose: bool = False, 
        threshold: float = 0, 
        threshold_mode: str = 'abs', 
        path: str = 'checkpoint.pt', 
        optimizer_path: str = 'optimizer.pt', 
        trace_func: Callable = print
    ):
        """Initialize early stopping monitor.

        Args:
            patience: Number of epochs to wait after last validation loss improvement
                before triggering early stop. Default: 7.
            verbose: If True, prints message for each validation loss improvement. Default: False.
            threshold: Minimum change in monitored loss to qualify as improvement. Default: 0.
            threshold_mode: Either 'abs' (absolute: loss < best - threshold) or 
                'rel' (relative: loss < best * (1 - threshold)). Default: 'abs'.
            path: Path to save best model checkpoint. Default: 'checkpoint.pt'.
            optimizer_path: Path to save optimizer state. Default: 'optimizer.pt'.
            trace_func: Logging function for status messages. Default: print.
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.corresponding_score = None
        self.early_stop = False
        self.score_last_save = np.inf
        self.threshold = threshold
        self.threshold_mode = threshold_mode
        self.path = path
        self.optimizer_path = optimizer_path
        self.trace_func = trace_func
        self.trace_func('EarlyStopping initialized with patience = {}, threshold = {}, threshold_mode = {}'.format(
            self.patience, self.threshold, self.threshold_mode))

    def reset(self):
        """Reset all early stopping state to initial values.

        Useful for starting a new training phase with fresh early stopping.
        """
        self.reset_counter()
        self.best_score = None
        self.corresponding_score = None
        self.early_stop = False
        self.score_last_save = np.inf

    def reset_counter(self):
        """Reset only the patience counter and early_stop flag.

        Keeps best_score intact. Useful after manual interventions.
        """
        self.counter = 0
        self.early_stop = False

    def __call__(
        self, 
        loss: float, 
        model: nn.Module, 
        epoch: Optional[int] = None, 
        optimizer: Optional[optim.Optimizer] = None, 
        corresponding_loss: Optional[float] = None
    ):
        """Update early stopping state based on current validation loss.

        Checks if loss has improved according to threshold criteria. Saves checkpoint
        if improvement occurred, otherwise increments patience counter. Sets early_stop
        flag when counter reaches patience.

        Args:
            loss: Current validation loss.
            model: PyTorch model with save() method.
            epoch: Current epoch number (for logging). Optional.
            optimizer: PyTorch optimizer to save state. Optional.
            corresponding_loss: Training loss from same epoch (for tracking). Optional.

        Side Effects:
            - Updates counter, best_score, and corresponding_score
            - Saves model checkpoint when loss improves
            - Sets early_stop flag when patience exceeded
            - Handles NaN loss by setting to infinity
        """
        # if loss is not a number
        if np.isnan(loss):
            loss = np.inf
            logging.warning('EarlyStopping: loss is NaN. Setting to Inf for early stopping update.')
        score = loss

        # initial case
        if self.best_score is None:
            self.best_score = score
            self.corresponding_score = corresponding_loss
            self.save_checkpoint(loss, model, optimizer, epoch)

        _update_flag = False

        if self.threshold_mode == 'abs':
            if score < self.best_score - self.threshold:
                _update_flag = True
        elif self.threshold_mode == 'rel':
            if score < self.best_score * (1 - self.threshold):
                _update_flag = True
        else:
            raise ValueError('Invalid threshold mode selected.')

        if _update_flag:
            self.best_score = score
            self.counter = 0
            self.corresponding_score = corresponding_loss
            self.save_checkpoint(loss, model, optimizer, epoch)
        else:
            self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True

    def save_checkpoint(
        self, 
        loss: float, 
        model: nn.Module, 
        optimizer: Optional[optim.Optimizer], 
        epoch: Optional[int]
    ):
        """Save model checkpoint when validation loss improves.

        Args:
            loss: Current validation loss (for logging).
            model: PyTorch model with save() method.
            optimizer: PyTorch optimizer (state saved if not None).
            epoch: Current epoch number (for logging).

        Side Effects:
            - Calls model.save(path) to persist model state
            - Saves optimizer state_dict if optimizer provided
            - Updates score_last_save
            - Logs checkpoint save if verbose=True
        """
        if self.verbose:
            self.trace_func('----------------------> Epoch {} Validation loss decreased ({:.6f} --> {:.6f}).  Saving model to {}'.format(epoch, self.score_last_save, loss, self.path))
        model.save(self.path)
        self.score_last_save = loss
        if optimizer is not None:
            torch.save(optimizer.state_dict(), self.optimizer_path)
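
For orientation, here is a minimal, self-contained training-loop sketch showing how EarlyStopping is typically driven. TinyRegressor, the synthetic tensors, and the hyperparameters are illustrative stand-ins rather than part of bnode_core; the only requirement the class places on the model is that it exposes a save() method, which the stand-in provides.

import torch
import torch.nn as nn
import torch.optim as optim

from bnode_core.nn.nn_utils.early_stopping import EarlyStopping  # import path assumed from the source location above

class TinyRegressor(nn.Module):
    """Throwaway model with the save() method EarlyStopping expects."""
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(4, 1)

    def forward(self, x):
        return self.net(x)

    def save(self, path):
        torch.save(self.state_dict(), path)

model = TinyRegressor()
optimizer = optim.Adam(model.parameters(), lr=1e-2)
loss_fn = nn.MSELoss()

# Synthetic data purely for the sketch.
x_train, y_train = torch.randn(256, 4), torch.randn(256, 1)
x_val, y_val = torch.randn(64, 4), torch.randn(64, 1)

early_stopping = EarlyStopping(patience=5, verbose=True, path='checkpoint.pt')

for epoch in range(100):
    model.train()
    optimizer.zero_grad()
    train_loss = loss_fn(model(x_train), y_train)
    train_loss.backward()
    optimizer.step()

    model.eval()
    with torch.no_grad():
        val_loss = loss_fn(model(x_val), y_val).item()

    # Saves a checkpoint on improvement, otherwise advances the patience counter.
    early_stopping(val_loss, model, epoch=epoch, optimizer=optimizer,
                   corresponding_loss=train_loss.item())
    if early_stopping.early_stop:
        break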

__init__(patience: int = 7, verbose: bool = False, threshold: float = 0, threshold_mode: str = 'abs', path: str = 'checkpoint.pt', optimizer_path: str = 'optimizer.pt', trace_func: Callable = print)

Initialize early stopping monitor.

Parameters:

patience (int, default 7): Number of epochs to wait after last validation loss improvement before triggering early stop.
verbose (bool, default False): If True, prints a message for each validation loss improvement.
threshold (float, default 0): Minimum change in monitored loss to qualify as improvement.
threshold_mode (str, default 'abs'): Either 'abs' (absolute: loss < best - threshold) or 'rel' (relative: loss < best * (1 - threshold)).
path (str, default 'checkpoint.pt'): Path to save best model checkpoint.
optimizer_path (str, default 'optimizer.pt'): Path to save optimizer state.
trace_func (Callable, default print): Logging function for status messages.
Source code in src/bnode_core/nn/nn_utils/early_stopping.py
def __init__(
    self, 
    patience: int = 7, 
    verbose: bool = False, 
    threshold: float = 0, 
    threshold_mode: str = 'abs', 
    path: str = 'checkpoint.pt', 
    optimizer_path: str = 'optimizer.pt', 
    trace_func: Callable = print
):
    """Initialize early stopping monitor.

    Args:
        patience: Number of epochs to wait after last validation loss improvement
            before triggering early stop. Default: 7.
        verbose: If True, prints message for each validation loss improvement. Default: False.
        threshold: Minimum change in monitored loss to qualify as improvement. Default: 0.
        threshold_mode: Either 'abs' (absolute: loss < best - threshold) or 
            'rel' (relative: loss < best * (1 - threshold)). Default: 'abs'.
        path: Path to save best model checkpoint. Default: 'checkpoint.pt'.
        optimizer_path: Path to save optimizer state. Default: 'optimizer.pt'.
        trace_func: Logging function for status messages. Default: print.
    """
    self.patience = patience
    self.verbose = verbose
    self.counter = 0
    self.best_score = None
    self.corresponding_score = None
    self.early_stop = False
    self.score_last_save = np.inf
    self.threshold = threshold
    self.threshold_mode = threshold_mode
    self.path = path
    self.optimizer_path = optimizer_path
    self.trace_func = trace_func
    self.trace_func('EarlyStopping initialized with patience = {}, threshold = {}, threshold_mode = {}'.format(
        self.patience, self.threshold, self.threshold_mode))

reset()

Reset all early stopping state to initial values.

Useful for starting a new training phase with fresh early stopping.

Source code in src/bnode_core/nn/nn_utils/early_stopping.py
def reset(self):
    """Reset all early stopping state to initial values.

    Useful for starting a new training phase with fresh early stopping.
    """
    self.reset_counter()
    self.best_score = None
    self.corresponding_score = None
    self.early_stop = False
    self.score_last_save = np.inf

reset_counter()

Reset only the patience counter and early_stop flag.

Keeps best_score intact. Useful after manual interventions.

Source code in src/bnode_core/nn/nn_utils/early_stopping.py
def reset_counter(self):
    """Reset only the patience counter and early_stop flag.

    Keeps best_score intact. Useful after manual interventions.
    """
    self.counter = 0
    self.early_stop = False
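
Continuing the sketch above, the two reset variants differ in what they forget:

# Start a completely new training phase: forget best_score, counter and flag.
early_stopping.reset()

# Keep best_score but clear the patience counter and early_stop flag,
# e.g. after a manual intervention such as lowering the learning rate.
early_stopping.reset_counter()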

__call__(loss: float, model: nn.Module, epoch: Optional[int] = None, optimizer: Optional[optim.Optimizer] = None, corresponding_loss: Optional[float] = None)

Update early stopping state based on current validation loss.

Checks if loss has improved according to threshold criteria. Saves checkpoint if improvement occurred, otherwise increments patience counter. Sets early_stop flag when counter reaches patience.

Parameters:

loss (float, required): Current validation loss.
model (Module, required): PyTorch model with save() method.
epoch (Optional[int], default None): Current epoch number (for logging).
optimizer (Optional[Optimizer], default None): PyTorch optimizer to save state.
corresponding_loss (Optional[float], default None): Training loss from same epoch (for tracking).

Side Effects:

- Updates counter, best_score, and corresponding_score
- Saves model checkpoint when loss improves
- Sets early_stop flag when patience exceeded
- Handles NaN loss by setting to infinity
Source code in src/bnode_core/nn/nn_utils/early_stopping.py
def __call__(
    self, 
    loss: float, 
    model: nn.Module, 
    epoch: Optional[int] = None, 
    optimizer: Optional[optim.Optimizer] = None, 
    corresponding_loss: Optional[float] = None
):
    """Update early stopping state based on current validation loss.

    Checks if loss has improved according to threshold criteria. Saves checkpoint
    if improvement occurred, otherwise increments patience counter. Sets early_stop
    flag when counter reaches patience.

    Args:
        loss: Current validation loss.
        model: PyTorch model with save() method.
        epoch: Current epoch number (for logging). Optional.
        optimizer: PyTorch optimizer to save state. Optional.
        corresponding_loss: Training loss from same epoch (for tracking). Optional.

    Side Effects:
        - Updates counter, best_score, and corresponding_score
        - Saves model checkpoint when loss improves
        - Sets early_stop flag when patience exceeded
        - Handles NaN loss by setting to infinity
    """
    # if loss is not a number
    if np.isnan(loss):
        loss = np.inf
        logging.warning('EarlyStopping: loss is NaN. Setting to Inf for early stopping update.')
    score = loss

    # initial case
    if self.best_score is None:
        self.best_score = score
        self.corresponding_score = corresponding_loss
        self.save_checkpoint(loss, model, optimizer, epoch)

    _update_flag = False

    if self.threshold_mode == 'abs':
        if score < self.best_score - self.threshold:
            _update_flag = True
    elif self.threshold_mode == 'rel':
        if score < self.best_score * (1 - self.threshold):
            _update_flag = True
    else:
        raise ValueError('Invalid threshold mode selected.')

    if _update_flag:
        self.best_score = score
        self.counter = 0
        self.corresponding_score = corresponding_loss
        self.save_checkpoint(loss, model, optimizer, epoch)
    else:
        self.counter += 1
    if self.counter >= self.patience:
        self.early_stop = True

save_checkpoint(loss: float, model: nn.Module, optimizer: Optional[optim.Optimizer], epoch: Optional[int])

Save model checkpoint when validation loss improves.

Parameters:

loss (float, required): Current validation loss (for logging).
model (Module, required): PyTorch model with save() method.
optimizer (Optional[Optimizer], required): PyTorch optimizer (state saved if not None).
epoch (Optional[int], required): Current epoch number (for logging).

Side Effects:

- Calls model.save(path) to persist model state
- Saves optimizer state_dict if optimizer provided
- Updates score_last_save
- Logs checkpoint save if verbose=True
Source code in src/bnode_core/nn/nn_utils/early_stopping.py
def save_checkpoint(
    self, 
    loss: float, 
    model: nn.Module, 
    optimizer: Optional[optim.Optimizer], 
    epoch: Optional[int]
):
    """Save model checkpoint when validation loss improves.

    Args:
        loss: Current validation loss (for logging).
        model: PyTorch model with save() method.
        optimizer: PyTorch optimizer (state saved if not None).
        epoch: Current epoch number (for logging).

    Side Effects:
        - Calls model.save(path) to persist model state
        - Saves optimizer state_dict if optimizer provided
        - Updates score_last_save
        - Logs checkpoint save if verbose=True
    """
    if self.verbose:
        self.trace_func('----------------------> Epoch {} Validation loss decreased ({:.6f} --> {:.6f}).  Saving model to {}'.format(epoch, self.score_last_save, loss, self.path))
    model.save(self.path)
    self.score_last_save = loss
    if optimizer is not None:
        torch.save(optimizer.state_dict(), self.optimizer_path)
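
save_checkpoint delegates persistence to the model's own save() method, so restoring the best weights after training is left to the caller. Assuming save() wrote a plain state_dict as in the TinyRegressor sketch above, reloading might look like:

# Restore the best checkpoint and, if needed, the matching optimizer state
# (sketch; assumes the model's save() stored a plain state_dict).
model.load_state_dict(torch.load('checkpoint.pt'))
optimizer.load_state_dict(torch.load('optimizer.pt'))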

bnode_core.nn.nn_utils.count_parameters

Utility for counting trainable parameters in PyTorch models.

count_parameters(model: nn.Module) -> int

Count total number of trainable parameters in a model.

Parameters:

model (Module, required): PyTorch model (nn.Module).

Returns:

int: Total number of trainable parameters.

Example:

>>> model = VAE(...)
>>> num_params = count_parameters(model)
>>> print(f"Model has {num_params:,} trainable parameters")

Source code in src/bnode_core/nn/nn_utils/count_parameters.py
def count_parameters(model: nn.Module) -> int:
    """Count total number of trainable parameters in a model.

    Args:
        model: PyTorch model (nn.Module).

    Returns:
        Total number of trainable parameters (int).

    Example:
        >>> model = VAE(...)
        >>> num_params = count_parameters(model)
        >>> print(f"Model has {num_params:,} trainable parameters")
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
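
As a quick, concrete check with a throwaway module (parameters with requires_grad=False are excluded from the count):

import torch.nn as nn

mlp = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 1))
print(count_parameters(mlp))   # (10*32 + 32) + (32*1 + 1) = 385

# Frozen parameters are not counted.
for p in mlp[0].parameters():
    p.requires_grad = False
print(count_parameters(mlp))   # 32*1 + 1 = 33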

bnode_core.nn.nn_utils.capacity_scheduler

Capacity scheduler for VAE training with controlled KL divergence growth.

This module implements a scheduler that gradually increases the KL divergence capacity during VAE training to prevent posterior collapse while maintaining good reconstruction quality.

Attention

This documentation is generated by AI. Please be aware of possible inaccuracies.

capacity_scheduler

Schedule the KL divergence capacity for VAE bottleneck layer.

Gradually increases the target KL divergence (capacity) to allow the model to learn meaningful latent representations without posterior collapse. The capacity increases when validation loss plateaus, encouraging the model to use more of the latent space.

Attributes:

patience: Number of epochs to wait before increasing capacity after loss plateau.
capacity: Current KL divergence capacity target.
capacity_max: Maximum capacity value (stops increasing after reaching this).
capacity_increment: Amount to increase capacity by (absolute or relative).
capacity_increment_mode: Either 'abs' (add increment) or 'rel' (multiply by increment).
counter: Number of epochs since last loss improvement.
best_score: Best validation loss seen so far.
threshold: Minimum loss improvement to reset counter.
threshold_mode: Either 'abs' (absolute threshold) or 'rel' (relative threshold).
enabled: If False, scheduler is disabled and returns None for capacity.
reached_max_capacity: True if capacity has reached capacity_max.

Source code in src/bnode_core/nn/nn_utils/capacity_scheduler.py
class capacity_scheduler:
    """Schedule the KL divergence capacity for VAE bottleneck layer.

    Gradually increases the target KL divergence (capacity) to allow the model to
    learn meaningful latent representations without posterior collapse. The capacity
    increases when validation loss plateaus, encouraging the model to use more of
    the latent space.

    Attributes:
        patience: Number of epochs to wait before increasing capacity after loss plateau.
        capacity: Current KL divergence capacity target.
        capacity_max: Maximum capacity value (stops increasing after reaching this).
        capacity_increment: Amount to increase capacity by (absolute or relative).
        capacity_increment_mode: Either 'abs' (add increment) or 'rel' (multiply by increment).
        counter: Number of epochs since last loss improvement.
        best_score: Best validation loss seen so far.
        threshold: Minimum loss improvement to reset counter.
        threshold_mode: Either 'abs' (absolute threshold) or 'rel' (relative threshold).
        enabled: If False, scheduler is disabled and returns None for capacity.
        reached_max_capacity: True if capacity has reached capacity_max.
    """

    def __init__(
        self, 
        patience: int, 
        capacity_start: float, 
        capacity_max: float, 
        capacity_increment: float, 
        capacity_increment_mode: str, 
        threshold: float, 
        threshold_mode: str, 
        trace_func: Callable = logging.info, 
        enabled: bool = True
    ):
        """Initialize the capacity scheduler.

        Args:
            patience: Number of epochs to wait after last validation loss improvement
                before increasing capacity.
            capacity_start: Initial KL divergence capacity target.
            capacity_max: Maximum capacity value (caps further increases).
            capacity_increment: Amount to change capacity when triggered.
            capacity_increment_mode: Either 'abs' (additive: capacity += increment) or 
                'rel' (multiplicative: capacity *= increment).
            threshold: Minimum change in validation loss to qualify as improvement.
            threshold_mode: Either 'abs' (absolute: loss < best - threshold) or 
                'rel' (relative: loss < best * (1 - threshold)).
            trace_func: Logging function for status messages (default: logging.info).
            enabled: If False, scheduler is disabled and get_capacity() returns None.

        Raises:
            AssertionError: If capacity_increment_mode or threshold_mode are invalid.
        """
        self.patience = patience
        self.capacity = capacity_start
        self.capacity_max = capacity_max
        self.capacity_increment = capacity_increment
        assert capacity_increment_mode in ['abs', 'rel'], 'Invalid capacity increment mode selected.'
        self.capacity_increment_mode = capacity_increment_mode
        self.counter = 0
        self.best_score = np.inf
        self.corresponding_score = None
        self.early_stop = False
        self.threshold = threshold
        self.threshold_mode = threshold_mode
        assert threshold_mode in ['abs', 'rel'], 'Invalid threshold mode selected.'
        self.trace_func = trace_func
        self.enabled = enabled
        self.reached_max_capacity = False
        self.trace_func('CapacityScheduler initialized with patience = {}, capacity = {}, capacity_increment = {}, capacity_increment_mode = {}, threshold = {}, threshold_mode = {}'.format(
            self.patience, self.capacity, self.capacity_increment, self.capacity_increment_mode, self.threshold, self.threshold_mode))
        if not self.enabled:
            self.trace_func('CapacityScheduler disabled.')

    def update(self, score: float):
        """Update scheduler state based on current validation loss.

        Tracks validation loss improvements and increases capacity when loss
        plateaus for 'patience' epochs. Capacity increases until reaching
        capacity_max.

        Args:
            score: Current validation loss (typically MSE loss).

        Side Effects:
            - Updates counter and best_score
            - Increases capacity if patience threshold reached
            - Sets reached_max_capacity flag when maximum is hit
            - Logs capacity changes via trace_func
        """
        if self.enabled and not self.reached_max_capacity:
            _update_flag = False

            if self.threshold_mode == 'abs':
                if score < self.best_score - self.threshold:
                    _update_flag = True
            elif self.threshold_mode == 'rel':
                if score < self.best_score * (1 - self.threshold):
                    _update_flag = True

            if _update_flag:
                self.best_score = score
                self.counter = 0
            else:
                self.counter += 1

            if self.counter >= self.patience:
                if self.reached_max_capacity is False:
                    # update capacity
                    if self.capacity_increment_mode == 'abs':
                        new_capacity = self.capacity + self.capacity_increment
                    elif self.capacity_increment_mode == 'rel':
                        new_capacity = self.capacity * self.capacity_increment
                    if new_capacity > self.capacity_max:
                        new_capacity = self.capacity_max
                        self.reached_max_capacity = True
                        self.trace_func('\tCapacityScheduler reached maximum capacity of {}.'.format(self.capacity_max))
                    self.capacity = new_capacity

                    self.trace_func('\tCapacityScheduler updated capacity to {} after {} epochs.'.format(self.capacity, self.counter))
                    self.counter = 0

    def get_capacity(self) -> Optional[float]:
        """Get current capacity target for KL divergence loss.

        Returns:
            Current capacity value (float) if enabled, None if disabled.
        """
        return self.capacity if self.enabled else None

__init__(patience: int, capacity_start: float, capacity_max: float, capacity_increment: float, capacity_increment_mode: str, threshold: float, threshold_mode: str, trace_func: Callable = logging.info, enabled: bool = True)

Initialize the capacity scheduler.

Parameters:

patience (int, required): Number of epochs to wait after last validation loss improvement before increasing capacity.
capacity_start (float, required): Initial KL divergence capacity target.
capacity_max (float, required): Maximum capacity value (caps further increases).
capacity_increment (float, required): Amount to change capacity when triggered.
capacity_increment_mode (str, required): Either 'abs' (additive: capacity += increment) or 'rel' (multiplicative: capacity *= increment).
threshold (float, required): Minimum change in validation loss to qualify as improvement.
threshold_mode (str, required): Either 'abs' (absolute: loss < best - threshold) or 'rel' (relative: loss < best * (1 - threshold)).
trace_func (Callable, default logging.info): Logging function for status messages.
enabled (bool, default True): If False, scheduler is disabled and get_capacity() returns None.

Raises:

AssertionError: If capacity_increment_mode or threshold_mode are invalid.

Source code in src/bnode_core/nn/nn_utils/capacity_scheduler.py
def __init__(
    self, 
    patience: int, 
    capacity_start: float, 
    capacity_max: float, 
    capacity_increment: float, 
    capacity_increment_mode: str, 
    threshold: float, 
    threshold_mode: str, 
    trace_func: Callable = logging.info, 
    enabled: bool = True
):
    """Initialize the capacity scheduler.

    Args:
        patience: Number of epochs to wait after last validation loss improvement
            before increasing capacity.
        capacity_start: Initial KL divergence capacity target.
        capacity_max: Maximum capacity value (caps further increases).
        capacity_increment: Amount to change capacity when triggered.
        capacity_increment_mode: Either 'abs' (additive: capacity += increment) or 
            'rel' (multiplicative: capacity *= increment).
        threshold: Minimum change in validation loss to qualify as improvement.
        threshold_mode: Either 'abs' (absolute: loss < best - threshold) or 
            'rel' (relative: loss < best * (1 - threshold)).
        trace_func: Logging function for status messages (default: logging.info).
        enabled: If False, scheduler is disabled and get_capacity() returns None.

    Raises:
        AssertionError: If capacity_increment_mode or threshold_mode are invalid.
    """
    self.patience = patience
    self.capacity = capacity_start
    self.capacity_max = capacity_max
    self.capacity_increment = capacity_increment
    assert capacity_increment_mode in ['abs', 'rel'], 'Invalid capacity increment mode selected.'
    self.capacity_increment_mode = capacity_increment_mode
    self.counter = 0
    self.best_score = np.inf
    self.corresponding_score = None
    self.early_stop = False
    self.threshold = threshold
    self.threshold_mode = threshold_mode
    assert threshold_mode in ['abs', 'rel'], 'Invalid threshold mode selected.'
    self.trace_func = trace_func
    self.enabled = enabled
    self.reached_max_capacity = False
    self.trace_func('CapacityScheduler initialized with patience = {}, capacity = {}, capacity_increment = {}, capacity_increment_mode = {}, threshold = {}, threshold_mode = {}'.format(
        self.patience, self.capacity, self.capacity_increment, self.capacity_increment_mode, self.threshold, self.threshold_mode))
    if not self.enabled:
        self.trace_func('CapacityScheduler disabled.')

update(score: float)

Update scheduler state based on current validation loss.

Tracks validation loss improvements and increases capacity when loss plateaus for 'patience' epochs. Capacity increases until reaching capacity_max.

Parameters:

score (float, required): Current validation loss (typically MSE loss).

Side Effects:

- Updates counter and best_score
- Increases capacity if patience threshold reached
- Sets reached_max_capacity flag when maximum is hit
- Logs capacity changes via trace_func
Source code in src/bnode_core/nn/nn_utils/capacity_scheduler.py
def update(self, score: float):
    """Update scheduler state based on current validation loss.

    Tracks validation loss improvements and increases capacity when loss
    plateaus for 'patience' epochs. Capacity increases until reaching
    capacity_max.

    Args:
        score: Current validation loss (typically MSE loss).

    Side Effects:
        - Updates counter and best_score
        - Increases capacity if patience threshold reached
        - Sets reached_max_capacity flag when maximum is hit
        - Logs capacity changes via trace_func
    """
    if self.enabled and not self.reached_max_capacity:
        _update_flag = False

        if self.threshold_mode == 'abs':
            if score < self.best_score - self.threshold:
                _update_flag = True
        elif self.threshold_mode == 'rel':
            if score < self.best_score * (1 - self.threshold):
                _update_flag = True

        if _update_flag:
            self.best_score = score
            self.counter = 0
        else:
            self.counter += 1

        if self.counter >= self.patience:
            if self.reached_max_capacity is False:
                # update capacity
                if self.capacity_increment_mode == 'abs':
                    new_capacity = self.capacity + self.capacity_increment
                elif self.capacity_increment_mode == 'rel':
                    new_capacity = self.capacity * self.capacity_increment
                if new_capacity > self.capacity_max:
                    new_capacity = self.capacity_max
                    self.reached_max_capacity = True
                    self.trace_func('\tCapacityScheduler reached maximum capacity of {}.'.format(self.capacity_max))
                self.capacity = new_capacity

                self.trace_func('\tCapacityScheduler updated capacity to {} after {} epochs.'.format(self.capacity, self.counter))
                self.counter = 0

get_capacity() -> Optional[float]

Get current capacity target for KL divergence loss.

Returns:

Optional[float]: Current capacity value (float) if enabled, None if disabled.

Source code in src/bnode_core/nn/nn_utils/capacity_scheduler.py
def get_capacity(self) -> Optional[float]:
    """Get current capacity target for KL divergence loss.

    Returns:
        Current capacity value (float) if enabled, None if disabled.
    """
    return self.capacity if self.enabled else None
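
To show where the scheduler sits in a VAE training loop, here is a sketch. How bnode_core actually combines the scheduled capacity with the loss is not shown on this page, so the |KL - capacity| penalty below is only one common capacity-annealing formulation and should be treated as an assumption; model, beta, optimizer, compute_losses, train_loader, val_loader, evaluate, and max_epochs are likewise hypothetical stand-ins.

scheduler = capacity_scheduler(
    patience=10,
    capacity_start=0.0,
    capacity_max=25.0,
    capacity_increment=1.0,
    capacity_increment_mode='abs',
    threshold=1e-3,
    threshold_mode='rel',
)

for epoch in range(max_epochs):
    for batch in train_loader:                               # hypothetical loader
        recon_loss, kl_loss = compute_losses(model, batch)   # hypothetical helper
        capacity = scheduler.get_capacity()
        if capacity is None:
            # Scheduler disabled: plain beta-weighted KL term.
            loss = recon_loss + beta * kl_loss
        else:
            # Pull the KL term toward the scheduled capacity target
            # (one common formulation; an assumption, not necessarily bnode_core's).
            loss = recon_loss + beta * torch.abs(kl_loss - capacity)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # A plateau in validation loss advances the internal counter; after
    # `patience` epochs without improvement the capacity is increased.
    scheduler.update(evaluate(model, val_loader))             # hypothetical helper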