Label Smoother Sum

Note

Copyright Text2Event from https://github.com/luyaojie/Text2Event.

Licensed under the MIT License.

#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Copyright Text2Event from https://github.com/luyaojie/Text2Event.
# Licensed under the MIT License.

import torch

from dataclasses import dataclass
from typing import Dict

SumLabelSmoother

A label-smoothing sum module that operates on the pre-computed output of the model. Label smoothing is a regularization technique that addresses overfitting and overconfidence by adding noise to the target distribution, reducing the weight placed on the ground-truth labels when computing the loss. Unlike the standard label smoother, this variant sums the per-token losses over the non-padded positions instead of averaging them; the exact quantity it computes is sketched after the attribute list below.

Attributes:

  • epsilon: The label smoothing factor (a float, defaults to 0.1).

  • ignore_index: The label index to ignore when computing the loss (an integer, defaults to -100).
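
Concretely, writing ε for epsilon, V for the vocabulary size, and p_t(·) for the softmax distribution predicted at position t, each position whose label y_t is not equal to ignore_index contributes (this reading is inferred from the implementation below rather than stated explicitly in the original source):

  (1 − ε) · (−log p_t(y_t)) + (ε / V) · Σ_k (−log p_t(k))

The value returned by __call__ is the sum of this quantity over all non-ignored positions; no division by the number of active elements is performed.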

@dataclass
class SumLabelSmoother:
    """A label-smoothing sum module operated on the pre-computed output from the model.
    A label-smoothing sum module operated on the pre-computed output from the model, which is a regularization technique
    that addresses the overfitting and overconfidence problems by adding some noises to decrease the weights of the
    actual samples when calculating losses.
    Attributes:
        epsilon (`float`, `optional`, defaults to 0.1):
            A float variable indicating the label smoothing factor.
        ignore_index (`int`, `optional`, defaults to -100):
            An integer representing the index in the labels to ignore when computing the loss.
    """

    epsilon: float = 0.1
    ignore_index: int = -100

    def __call__(self,
                 model_output: Dict[str, torch.Tensor],
                 labels: torch.Tensor) -> torch.Tensor:
        """Computes the label-smoothed loss, summed over all non-padded positions."""
        logits = model_output["logits"] if isinstance(model_output, dict) else model_output[0]
        log_probs = -torch.nn.functional.log_softmax(logits, dim=-1)
        if labels.dim() == log_probs.dim() - 1:
            labels = labels.unsqueeze(-1)

        padding_mask = labels.eq(self.ignore_index)
        # In case the ignore_index is -100, the gather will fail, so we replace labels by 0. The padding_mask
        # will ignore them in any case.
        labels = torch.clamp(labels, min=0)  # clamp out-of-place so the caller's labels tensor is not mutated
        nll_loss = log_probs.gather(dim=-1, index=labels)
        smoothed_loss = log_probs.sum(dim=-1, keepdim=True)

        nll_loss.masked_fill_(padding_mask, 0.0)
        smoothed_loss.masked_fill_(padding_mask, 0.0)

        # Unlike the standard (mean) label smoother, do not normalize by the number of active (non-padded)
        # elements; the per-token losses are simply summed. The normalization is kept commented for reference:
        # num_active_elements = padding_mask.numel() - padding_mask.long().sum()
        nll_loss = nll_loss.sum()  # / num_active_elements
        smoothed_loss = smoothed_loss.sum()  # / (num_active_elements * log_probs.shape[-1])
        eps_i = self.epsilon / log_probs.size(-1)
        return (1 - self.epsilon) * nll_loss + eps_i * smoothed_loss
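
A minimal usage sketch follows (illustrative only: the tensor shapes, the synthetic logits, and the padding pattern are assumptions, not taken from the original repository). The smoother accepts either a dict with a logits entry or a tuple whose first element is the logits tensor, together with the padded label ids, and returns a scalar tensor that can be backpropagated directly.

import torch

# Synthetic example: shapes and values below are illustrative assumptions.
batch_size, seq_len, vocab_size = 2, 5, 32
logits = torch.randn(batch_size, seq_len, vocab_size, requires_grad=True)
labels = torch.randint(0, vocab_size, (batch_size, seq_len))
labels[:, -2:] = -100  # mark the last two positions of each sequence as padding

smoother = SumLabelSmoother(epsilon=0.1)     # the class defined above
loss = smoother({"logits": logits}, labels)  # scalar tensor: summed smoothed loss
loss.backward()                              # gradients flow back into the logits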