import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
logger = logging.getLogger(__name__)


# Adapted from OCNet Repository (https://github.com/PkuRainBow/OCNet)
class OhemCrossEntropy2d(nn.Module):
    """Cross-entropy loss with online hard example mining (OHEM).

    Only "hard" pixels contribute to the loss: pixels whose predicted
    probability for the ground-truth class is at or below a threshold.
    Every other pixel is relabelled with the ignore index before the
    prediction is handed to ``nn.CrossEntropyLoss``.

    Adapted from the OCNet repository (https://github.com/PkuRainBow/OCNet).
    """

    def __init__(self, thresh=0.6, min_kept=0, weight=None, ignore_index=255):
        """Initialize the OHEM cross-entropy loss.

        Args:
            thresh (float, optional): probability threshold. A valid pixel is
                kept when the softmax probability of its ground-truth class
                is ``<= thresh``. Defaults to 0.6.
            min_kept (int, optional): lower bound on the number of mined
                pixels. When positive and the ``min_kept``-th smallest
                ground-truth probability exceeds ``thresh``, that probability
                is used as the threshold instead, so at least ``min_kept``
                pixels are kept. When the number of valid pixels does not
                exceed ``min_kept``, no mining is done and every valid pixel
                is kept. Defaults to 0.
            weight (Tensor, optional): a manual rescaling weight given to
                each class, forwarded to ``nn.CrossEntropyLoss``.
                Defaults to None.
            ignore_index (int, optional): target value that is ignored and
                does not contribute to the input gradient. Defaults to 255.
        """
        super().__init__()
        self.ignore_label = ignore_index
        self.thresh = float(thresh)
        self.min_kept = int(min_kept)
        self.criterion = nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index)

    def forward(self, predict, target):
        """Compute the cross-entropy loss over the mined hard pixels.

        Args:
            predict (Tensor): raw model output (logits), shape (N, C, H, W).
            target (Tensor): ground truth class indices, shape (N, H, W).

        Returns:
            Tensor: scalar loss produced by ``self.criterion`` on the mined
            pixels, or a zero that is still connected to ``predict``'s graph
            when no pixel is selected.
        """
        # The 4-way unpack doubles as a check that ``predict`` is 4-D.
        _, num_classes, _, _ = predict.size()

        # Mining is pure bookkeeping, so it runs on detached NumPy copies.
        # ``astype`` copies, so ``target`` itself is never modified.
        flat_label = target.detach().cpu().numpy().ravel().astype(np.int32)
        logits = np.moveaxis(predict.detach().cpu().numpy(), 1, 0).reshape((num_classes, -1))

        # Numerically stable softmax over the class axis (axis 0).
        prob = np.exp(logits - logits.max(axis=0, keepdims=True))
        prob /= prob.sum(axis=0, keepdims=True)

        valid_inds = np.flatnonzero(flat_label != self.ignore_label)
        num_valid = valid_inds.size

        if self.min_kept >= num_valid:
            # Too few valid pixels to mine from: keep all of them.
            logger.info('Labels: %s', num_valid)
        elif num_valid > 0:
            # Probability assigned to the ground-truth class of each valid pixel.
            true_prob = prob[flat_label[valid_inds], valid_inds]

            threshold = self.thresh
            if self.min_kept > 0:
                # Raise the threshold when fewer than ``min_kept`` pixels fall
                # under ``thresh``; only the k-th smallest value is needed, so
                # a partial sort is enough.
                kth = self.min_kept - 1
                kth_smallest = np.partition(true_prob, kth)[kth]
                if kth_smallest > self.thresh:
                    threshold = kth_smallest

            valid_inds = valid_inds[true_prob <= threshold]
            logger.info(
                'hard ratio: %s = %s / %s ',
                round(valid_inds.size / num_valid, 4), valid_inds.size, num_valid)

        # Everything that was not selected becomes "ignore".
        mined_label = np.full_like(flat_label, self.ignore_label)
        mined_label[valid_inds] = flat_label[valid_inds]
        logger.info('%s', valid_inds.size)

        if valid_inds.size == 0:
            # With nothing left to average over the criterion would yield NaN;
            # return a zero that keeps the autograd graph intact instead.
            return predict.sum() * 0.0

        # Rebuild the target on the same device as the prediction so the
        # criterion also works when the model runs on an accelerator.
        mined_target = torch.from_numpy(mined_label.reshape(target.shape)).long().to(predict.device)
        return self.criterion(predict, mined_target)