# Source code for losses.ohem_loss

import logging

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# Module-level logger; the bare ``logging.getLogger(__name__)`` call used to
# discard its result, so every message went through the root logger instead.
logger = logging.getLogger(__name__)


# Adapted from OCNet Repository (https://github.com/PkuRainBow/OCNet)
class OhemCrossEntropy2d(nn.Module):
    """Cross-entropy loss with online hard example mining (OHEM).

    Adapted from the OCNet repository (https://github.com/PkuRainBow/OCNet).

    Only "hard" pixels contribute to the loss: a pixel is kept when the
    softmax probability the model assigns to its ground-truth class is at
    most a threshold. Every other pixel is relabelled with ``ignore_index``
    before a regular ``nn.CrossEntropyLoss`` is applied.
    """

    def __init__(self, thresh=0.6, min_kept=0, weight=None, ignore_index=255):
        """Initialize the Ohem Cross Entropy loss.

        Args:
            thresh (float, optional): probability threshold; a valid pixel is
                kept when the predicted probability of its ground-truth class
                is ``<=`` this value. Defaults to 0.6.
            min_kept (int, optional): target minimum number of pixels to keep.
                When positive and smaller than the number of valid pixels, the
                threshold is raised to the ``min_kept``-th smallest
                ground-truth probability if that value exceeds ``thresh``.
                When ``min_kept`` is greater than or equal to the number of
                valid pixels, no mining happens and all valid pixels are used.
                Defaults to 0.
            weight (Tensor, optional): a manual rescaling weight given to each
                class. Defaults to None.
            ignore_index (int, optional): target value that is ignored and
                does not contribute to the input gradient. Defaults to 255.
        """
        super().__init__()
        self.ignore_label = ignore_index
        self.thresh = float(thresh)
        self.min_kept = int(min_kept)
        self.criterion = nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index)

    def forward(self, predict, target):
        """Compute the OHEM cross-entropy loss.

        The hard-pixel selection is done on detached NumPy copies, so it does
        not take part in autograd; gradients flow only through the final
        cross-entropy call on ``predict``.

        Args:
            predict (Tensor): the output from model, shape (N, C, H, W).
            target (Tensor): ground truth, shape (N, H, W).

        Returns:
            Tensor: Ohem loss score.

        Note:
            If no pixel survives the selection, every target value becomes
            ``ignore_index``; the result is then whatever the wrapped
            ``nn.CrossEntropyLoss`` yields for an all-ignored target
            (presumably NaN with mean reduction -- confirm before relying
            on it).
        """
        # Unpacking also enforces that ``predict`` is 4-D.
        _, num_classes, _, _ = predict.size()

        input_label = target.detach().cpu().numpy().ravel().astype(np.int32)

        # Class axis first, then flatten to (C, N*H*W) and apply a
        # numerically stable softmax over the class axis.
        logits = np.rollaxis(predict.detach().cpu().numpy(), 1).reshape((num_classes, -1))
        input_prob = np.exp(logits - logits.max(axis=0).reshape((1, -1)))
        input_prob /= input_prob.sum(axis=0).reshape((1, -1))

        valid_flag = input_label != self.ignore_label
        valid_inds = np.where(valid_flag)[0]
        label = input_label[valid_flag]
        num_valid = valid_flag.sum()

        if self.min_kept >= num_valid:
            # Not enough valid pixels to mine from: keep all of them.
            logger.info('Labels: %s', num_valid)
        elif num_valid > 0:
            prob = input_prob[:, valid_flag]
            # Probability assigned to the ground-truth class of each valid pixel.
            pred = prob[label, np.arange(len(label), dtype=np.int32)]
            threshold = self.thresh
            if self.min_kept > 0:
                # Raise the threshold if needed so that at least ``min_kept``
                # pixels fall at or below it.
                index = pred.argsort()
                threshold_index = index[min(len(index), self.min_kept) - 1]
                if pred[threshold_index] > self.thresh:
                    threshold = pred[threshold_index]
            kept_flag = pred <= threshold
            valid_inds = valid_inds[kept_flag]
            logger.info(
                'hard ratio: %s = %s / %s ',
                round(len(valid_inds) / num_valid, 4), len(valid_inds), num_valid)

        # Relabel everything as ignored, then restore labels of kept pixels.
        label = input_label[valid_inds].copy()
        input_label.fill(self.ignore_label)
        input_label[valid_inds] = label
        if logger.isEnabledFor(logging.INFO):
            # Skip the extra reduction when nobody listens at INFO level.
            logger.info('%s', np.sum(input_label != self.ignore_label))

        # The rebuilt target must live on the same device as ``predict``;
        # leaving it on the CPU breaks the criterion for GPU predictions.
        mined_target = torch.from_numpy(input_label.reshape(target.size())).long()
        mined_target = mined_target.to(predict.device)
        return self.criterion(predict, mined_target)