"""
Implementation of the block DTW, which is a variant of the SDTW that computes the SDTW on blocks of the signal instead of the entire signal.
Example
-------
>>> from dtw_loss_functions import block_dtw
>>> import torch
>>> block_size = 25
>>> use_cuda = torch.cuda.is_available()
>>> block_dtw_loss = block_dtw.block_dtw(block_size, sdtw_config = {'use_cuda' : use_cuda})
>>> batch_size = 5
>>> time_samples = 300
>>> channels = 1
>>> device = 'cuda' if use_cuda else 'cpu'
>>> x = torch.randn(batch_size, time_samples, channels).to(device)
>>> x_r = torch.randn(batch_size, time_samples, channels).to(device)
>>> output_block_dtw = block_dtw_loss(x, x_r)
Authors
-------
Alberto Zancanaro <alberto.zancanaro@uni.lu>
"""
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
# Imports
import torch
from .soft_dtw import soft_dtw
# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -

class block_dtw(torch.nn.Module) :
    """
    Class that computes the block DTW loss, a variant of the SDTW that computes the SDTW on blocks of the signal instead of on the entire signal.

    The block DTW can be computed in two ways:

    - Naive (SEQUENTIAL) implementation: compute the SDTW on each block separately and sum the results.
    - Optimized (PARALLEL) implementation: exploit reshaping of the input tensors to compute the SDTW on all blocks at once.

    This class automatically selects which implementation to use based on the length of the input tensors and the block size.
    See the docstring of :class:`block_dtw_optimized` for more details on the requirements of the optimized implementation.
    If you are not sure which implementation can be used, simply use this class.
    Note that if you know a priori that the optimized version can be used in your case, it is recommended to use the :class:`block_dtw_optimized` class directly, which is faster since it skips the check of the input tensors length against the block size.

    Attributes
    ----------
    block_size : int
        Size of the blocks into which to divide the signal.
    block_dtw_naive : :class:`block_dtw_naive`
        Instance of the naive implementation of the block DTW.
    block_dtw_optimized : :class:`block_dtw_optimized`
        Instance of the optimized implementation of the block DTW.

    Parameters
    ----------
    block_size : int
        Size of the blocks into which to divide the signal.
    implementation : str, optional
        Implementation to use for the SDTW. This parameter is passed to the SDTW implementation used in the block DTW. See the docstring of the :class:`~dtw_loss_functions.soft_dtw.soft_dtw` class for more details on the available implementations and their parameters. Default is ``'mag'``.
    sdtw_config : dict, optional
        Configuration dictionary for the SDTW function used in the block DTW. See the docstring of the :class:`~dtw_loss_functions.soft_dtw.soft_dtw` class for more details on the available parameters and their default values.
        If a parameter is not specified, or if the dictionary is empty, the default values will be used for all the parameters.
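
    Example
    -------
    A minimal sketch of the automatic dispatch (it assumes the default SDTW
    configuration; the loss values depend on the random inputs):

    >>> import torch
    >>> from dtw_loss_functions.block_dtw import block_dtw
    >>> loss_fn = block_dtw(block_size = 25)
    >>> x, x_r = torch.randn(5, 300, 1), torch.randn(5, 300, 1)
    >>> loss = loss_fn(x, x_r)   # 300 % 25 == 0 -> optimized implementation
    >>> x, x_r = torch.randn(5, 310, 1), torch.randn(5, 310, 1)
    >>> loss = loss_fn(x, x_r)   # 310 % 25 != 0 -> naive implementation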
"""

    def __init__(self, block_size : int,
                 implementation : str = 'mag',
                 sdtw_config : dict = {}
                 ) :
        super().__init__()

        self.block_size = block_size

        self.block_dtw_naive = block_dtw_naive(block_size, implementation, sdtw_config)
        self.block_dtw_optimized = block_dtw_optimized(block_size, implementation, sdtw_config)

    def forward(self, x : torch.Tensor, x_r : torch.Tensor) -> torch.Tensor :
        """
        Compute the block DTW loss between the input tensors ``x`` and ``x_r``.
        The optimized implementation is used if the signal length is divisible by the block size; otherwise the naive implementation is used.

        Parameters
        ----------
        x : torch.Tensor
            First input tensor of shape ``B x T x C``
        x_r : torch.Tensor
            Second input tensor of shape ``B x T x C``

        Returns
        -------
        recon_error : torch.Tensor
            Tensor of shape ``B`` containing the block DTW loss for each sample in the batch.
        """
        if x.shape[1] % self.block_size == 0 :
            return self.block_dtw_optimized(x, x_r)
        else :
            return self.block_dtw_naive(x, x_r)

class block_dtw_naive(soft_dtw) :
    """
    Naive implementation of the block DTW, which computes the SDTW on each block separately and sums the results.

    For details on the parameters, see the docstring of the :class:`block_dtw` class.
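
    Example
    -------
    A minimal sketch (it assumes the default SDTW configuration): when the
    signal length is divisible by the block size, the result is expected to
    match the one of :class:`block_dtw_optimized` up to numerical precision.

    >>> import torch
    >>> from dtw_loss_functions.block_dtw import block_dtw_naive, block_dtw_optimized
    >>> x, x_r = torch.randn(2, 100, 1), torch.randn(2, 100, 1)
    >>> naive_loss = block_dtw_naive(block_size = 25)(x, x_r)
    >>> optimized_loss = block_dtw_optimized(block_size = 25)(x, x_r)
    >>> same = torch.allclose(naive_loss, optimized_loss)   # expected to be True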
"""

    def __init__(self, block_size : int, implementation : str = 'mag', sdtw_config : dict = {}) :
        super().__init__(implementation = implementation, sdtw_config = sdtw_config)
        self.block_size = block_size

    def forward(self, x : torch.Tensor, x_r : torch.Tensor) -> torch.Tensor :
        """
        Compute the block DTW loss between the input tensors ``x`` and ``x_r`` by computing the SDTW on each block separately and summing the results.

        Parameters
        ----------
        x : torch.Tensor
            First input tensor of shape ``B x T x C``
        x_r : torch.Tensor
            Second input tensor of shape ``B x T x C``

        Returns
        -------
        recon_error : torch.Tensor
            Tensor of shape ``B`` containing the block DTW loss for each sample in the batch.
        """
        tmp_recon_loss = 0

        # Number of blocks (the last one may be shorter than block_size)
        n_blocks = (x.shape[1] + self.block_size - 1) // self.block_size

        for i in range(n_blocks) :
            # Get indices for the block
            idx_1 = i * self.block_size
            idx_2 = min((i + 1) * self.block_size, x.shape[1])

            # Get block of the signal
            x_block = x[:, idx_1:idx_2, :]
            x_r_block = x_r[:, idx_1:idx_2, :]

            # Compute the SDTW for the block and accumulate it
            tmp_recon_loss += super().forward(x_block, x_r_block)

        return tmp_recon_loss

class block_dtw_optimized(soft_dtw) :
    """
    Optimized implementation of the block DTW, which exploits reshaping of the input tensors to compute the SDTW on all blocks at once.

    This version can be used only if the length of the input tensors is divisible by the block size, i.e. if ``length_signal % block_size == 0``.
    Note that the class will not check if this condition is satisfied, so it is the responsibility of the user to ensure that the input tensors length and the block size are compatible.
    This requirement comes from the reshaping: since the SDTW implementation allows batched inputs, the input tensors can be reshaped to a new batch size equal to the original batch size times the number of blocks, and the SDTW can then be computed on all blocks at once.

    For details on the parameters, see the docstring of the :class:`block_dtw` class.
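
    Example
    -------
    A minimal sketch of the reshaping idea on a toy tensor (shape bookkeeping
    only, no SDTW call): each block of a sample becomes one element of the
    virtual batch.

    >>> import torch
    >>> B, T, C, block_size = 2, 6, 1, 3
    >>> x = torch.arange(B * T * C, dtype = torch.float).view(B, T, C)
    >>> x_blocks = x.view(B * (T // block_size), block_size, C)
    >>> torch.equal(x_blocks[0], x[0, :block_size, :])
    True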
"""

    def __init__(self, block_size : int, implementation : str = 'mag', sdtw_config : dict = {}) :
        super().__init__(implementation = implementation, sdtw_config = sdtw_config)
        self.block_size = block_size

    def forward(self, x : torch.Tensor, x_r : torch.Tensor) -> torch.Tensor :
        """
        Compute the block DTW loss between the input tensors ``x`` and ``x_r`` by exploiting reshaping of the input tensors to compute the SDTW on all blocks at once.

        Parameters
        ----------
        x : torch.Tensor
            First input tensor of shape ``B x T x C``
        x_r : torch.Tensor
            Second input tensor of shape ``B x T x C``

        Returns
        -------
        recon_error : torch.Tensor
            Tensor of shape ``B`` containing the block DTW loss for each sample in the batch.
        """
        # Compute new (virtual) batch size
        virtual_batch_size = x.shape[0] * (x.shape[1] // self.block_size)

        # Reshape input tensors so that each block is a separate batch element
        x_reshaped = x.view(virtual_batch_size, self.block_size, -1)
        x_r_reshaped = x_r.view(virtual_batch_size, self.block_size, -1)

        block_loss = super().forward(x_reshaped, x_r_reshaped)

        # Sum the per-block losses of each sample to recover the original batch size
        return block_loss.view(x.shape[0], -1).sum(dim = 1)