Source code for losses.boundary_loss

import logging

import torch
import torch.nn as nn
import torch.nn.functional as F

# NOTE: the logger returned by this call is not bound to a name; the logging
# calls in this module go through the module-level ``logging`` functions.
logging.getLogger(__name__)


def one_hot(label, n_classes, requires_grad=True):
    """Return the channels-first one-hot encoding of a label map.

    Args:
        label (Tensor): class indices, shape (N, H, W) or (N, 1, H, W).
            Values are cast to ``torch.long`` and used to index the rows of
            an ``n_classes`` x ``n_classes`` identity matrix.
        n_classes (int): number of classes, i.e. size of the one-hot axis.
        requires_grad (bool, optional): forwarded to ``torch.eye`` when the
            identity matrix is created. Defaults to True.

    Returns:
        Tensor: one-hot labels of shape (N, n_classes, H, W), located on the
        same device as ``label``.
    """
    # Only drop axis 1 when a channel axis is actually present. Squeezing
    # unconditionally would remove the height axis of an (N, H, W) label map
    # whenever H == 1 and make the 4-D permutation below fail.
    if label.dim() == 4:
        label = label.squeeze(1)
    label = label.type(torch.long)

    identity = torch.eye(
        n_classes, device=label.device, requires_grad=requires_grad
    )
    # Row k of the identity matrix is the one-hot vector of class k, so
    # indexing with the label map yields a tensor of shape (N, H, W, C).
    one_hot_label = identity[label]

    # (N, H, W, C) -> (N, C, H, W)
    return one_hot_label.permute(0, 3, 1, 2)


class BoundaryLoss(nn.Module):
    """Boundary loss for semantic segmentation.

    Boundary Loss proposed in the paper *Boundary Loss for Remote Sensing
    Imagery Semantic Segmentation* from *Alexey Bokhovkin et al.*
    (https://arxiv.org/abs/1905.07852)

    From: https://github.com/yiskw713/boundary_loss_for_remote_sensing
    """

    def __init__(self, theta0=19, theta=19, ignore_index=None):
        """Initialize the boundary loss.

        Args:
            theta0 (int, optional): size of the sliding window used to
                extract the boundary maps. Defaults to 19.
            theta (int, optional): predefined threshold on a distance, used
                as the window size of the extended boundary maps.
                Defaults to 19.
            ignore_index (int, optional): index to be ignored during
                training. Not implemented: the value is stored and an error
                is logged when it is not None. Defaults to None.

        Note:
            ``forward`` pads each pooling window with ``(size - 1) // 2``, so
            the pooled maps keep the input's spatial size only for odd
            ``theta0`` / ``theta``.
        """
        super().__init__()
        self.theta0 = theta0
        self.theta = theta
        self.ignore_index = ignore_index

        # Compare against None rather than truthiness: class index 0 is a
        # valid ignore_index request and must trigger the warning as well.
        if self.ignore_index is not None:
            logging.error(
                f'Ignore_index not implemented for Boundary Loss. Got ignore_index "{ignore_index}"'
            )

    def forward(self, pred, gt):
        """Forward function used during training.

        Args:
            pred (Tensor): the output from model (before softmax),
                shape (N, C, H, W).
            gt (Tensor): ground truth, shape (N, H, W).

        Returns:
            Tensor: boundary loss score, averaged over classes and
            mini-batch.
        """
        n, c, _, _ = pred.shape
        logging.debug(f"Prediction shape: {pred.shape}")

        # softmax so that predicted map can be distributed in [0, 1]
        pred = torch.softmax(pred, dim=1)

        # one-hot vector of ground truth
        logging.debug(f"Ground truth shape: {gt.shape}")
        one_hot_gt = one_hot(gt, c)

        # boundary map: max-pooling the inverted map spreads "not this class"
        # over the window; subtracting the inverted map keeps only the
        # in-class pixels that have an out-of-class pixel inside the window.
        boundary_pad = (self.theta0 - 1) // 2
        gt_b = F.max_pool2d(
            1 - one_hot_gt,
            kernel_size=self.theta0,
            stride=1,
            padding=boundary_pad,
        )
        gt_b -= 1 - one_hot_gt

        pred_b = F.max_pool2d(
            1 - pred,
            kernel_size=self.theta0,
            stride=1,
            padding=boundary_pad,
        )
        pred_b -= 1 - pred

        # extended boundary map: dilate each boundary with a theta window
        ext_pad = (self.theta - 1) // 2
        gt_b_ext = F.max_pool2d(
            gt_b, kernel_size=self.theta, stride=1, padding=ext_pad
        )
        pred_b_ext = F.max_pool2d(
            pred_b, kernel_size=self.theta, stride=1, padding=ext_pad
        )

        # reshape: flatten the spatial axes so sums run over pixels
        gt_b = gt_b.view(n, c, -1)
        pred_b = pred_b.view(n, c, -1)
        gt_b_ext = gt_b_ext.view(n, c, -1)
        pred_b_ext = pred_b_ext.view(n, c, -1)

        # Precision, Recall (1e-7 guards against empty boundary maps)
        P = torch.sum(pred_b * gt_b_ext, dim=2) / (torch.sum(pred_b, dim=2) + 1e-7)
        R = torch.sum(pred_b_ext * gt_b, dim=2) / (torch.sum(gt_b, dim=2) + 1e-7)

        # Boundary F1 Score
        BF1 = 2 * P * R / (P + R + 1e-7)

        # averaging (1 - BF1) over every class and mini-batch element
        loss = torch.mean(1 - BF1)

        return loss
# for debug: run a single optimisation step on random data
if __name__ == "__main__":
    import torch.optim as optim
    from torchvision.models import segmentation

    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'

    # Sizes of the random smoke-test batch.
    batch_size, num_classes, side = 8, 10, 224

    images = torch.randn(batch_size, 3, side, side).to(device)
    targets = torch.randint(
        0, num_classes, (batch_size, side, side)
    ).to(device)

    net = segmentation.fcn_resnet50(num_classes=num_classes).to(device)
    adam = optim.Adam(net.parameters(), lr=0.0001)
    boundary_criterion = BoundaryLoss()

    # The model output is indexed by 'out' to get the segmentation logits.
    outputs = net(images)
    loss = boundary_criterion(outputs['out'], targets)

    adam.zero_grad()
    loss.backward()
    adam.step()

    print(loss)