Aggregation

import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Optional

get_aggregation

Obtains the aggregation method to be used based on the model's configuration. The available methods are selecting the ``<cls>`` tokens' representations, selecting the markers' representations, max-pooling, and dynamic multi-pooling.

Args:

  • config: The configurations of the model.

Returns:

  • The function or module implementing the configured aggregation method.

def get_aggregation(config):
    """Obtains the aggregation method to be utilized.

    Obtains the aggregation method to be used based on the model's configuration. The available methods are
    selecting the `<cls>` tokens' representations, selecting the markers' representations, max-pooling, and
    dynamic multi-pooling.

    Args:
        config:
            The configurations of the model.

    Returns:
        The function or module implementing the configured aggregation method.
    """
    if config.aggregation == "cls":
        return select_cls
    elif config.aggregation == "marker":
        return select_marker
    elif config.aggregation == "dynamic_pooling":
        return DynamicPooling(config)
    elif config.aggregation == "max_pooling":
        return max_pooling
    else:
        raise ValueError("Invalid aggregation method: %s" % config.aggregation)

aggregate

Aggregates information for each position by dispatching the hidden states to the configured aggregation method. The aggregation methods include selecting the ``<cls>`` tokens' representations, selecting the markers' representations, max-pooling, and dynamic multi-pooling.

Args:

  • config: The configurations of the model.

  • method: The aggregation function or module returned by get_aggregation.

  • hidden_states: A tensor representing the hidden states output by the backbone model.

  • trigger_left: A tensor indicating the left position of the triggers.

  • trigger_right: A tensor indicating the right position of the triggers.

  • argument_left: A tensor indicating the left position of the arguments.

  • argument_right: A tensor indicating the right position of the arguments.

def aggregate(config,
              method,
              hidden_states: torch.Tensor,
              trigger_left: torch.Tensor,
              trigger_right: torch.Tensor,
              argument_left: torch.Tensor,
              argument_right: torch.Tensor):
    """Aggregates information to each position.
    Aggregates information to each position. The aggregation methods include selecting the "cls"s' representations,
    selecting the markers' representations, max-pooling, and dynamic multi-pooling.
    Args:
        config:
            The configurations of the model.
        method:
            The method proposed to be utilized in the aggregation process.
            TODO: The data type of the variable `method` should be configured.
        hidden_states (`torch.Tensor`):
            A tensor representing the hidden states output by the backbone model.
        trigger_left (`torch.Tensor`):
            A tensor indicating the left position of the triggers.
        trigger_right (`torch.Tensor`):
            A tensor indicating the right position of the triggers.
        argument_left (`torch.Tensor`):
            A tensor indicating the left position of the arguments.
        argument_right (`torch.Tensor`):
            A tensor indicating the right position of the arguments.
    """
    if config.aggregation == "cls":
        return method(hidden_states)
    elif config.aggregation == "marker":
        if argument_left is not None:
            return method(hidden_states, argument_left, argument_right)
        else:
            return method(hidden_states, trigger_left, trigger_right)
    elif config.aggregation == "max_pooling":
        return method(hidden_states)
    elif config.aggregation == "dynamic_pooling":
        return method(hidden_states, trigger_left, argument_left)
    else:
        raise ValueError("Invalid aggregation method: %s" % config.aggregation)
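A minimal usage sketch, assuming the module above is importable and using a hypothetical `SimpleNamespace` config; with `"marker"` aggregation and no argument positions, the dispatch falls back to the trigger markers:

import torch
from types import SimpleNamespace

# Hypothetical stand-in for the model configuration.
config = SimpleNamespace(aggregation="marker")
method = get_aggregation(config)

hidden_states = torch.randn(2, 16, 8)  # (batch_size, seq_length, hidden_size)
trigger_left = torch.tensor([3, 5])
trigger_right = torch.tensor([4, 6])

# `argument_left` is None, so the trigger markers are selected.
output = aggregate(config, method, hidden_states,
                   trigger_left, trigger_right,
                   argument_left=None, argument_right=None)
print(output.shape)  # torch.Size([2, 16]): left and right marker states concatenated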

max_pooling

Applies the max-pooling operation over the representation of the entire input sequence to capture the most useful information. The operation is applied to the hidden states output by the backbone model.

Args:

  • hidden_states: A tensor representing the hidden states output by the backbone model.

Returns:

  • pooled_states: A tensor representing the max-pooled hidden states, containing the most useful information of the sequence.

def max_pooling(hidden_states: torch.Tensor) -> torch.Tensor:
    """Applies the max-pooling operation over the sentence representation.
    Applies the max-pooling operation over the representation of the entire input sequence to capture the most useful
    information. The operation processes on the hidden states, which are output by the backbone model.
    Args:
        hidden_states (`torch.Tensor`):
            A tensor representing the hidden states output by the backbone model.
    Returns:
        pooled_states (`torch.Tensor`):
            A tensor represents the max-pooled hidden states, containing the most useful information of the sequence.
    """
    batch_size, seq_length, hidden_size = hidden_states.size()
    pooled_states = F.max_pool1d(input=hidden_states.transpose(1, 2), kernel_size=seq_length).squeeze(-1)
    return pooled_states
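For example, pooling a batch of random hidden states removes the sequence dimension, keeping one maximum per hidden dimension:

import torch

hidden_states = torch.randn(2, 16, 8)  # (batch_size, seq_length, hidden_size)
pooled = max_pooling(hidden_states)
print(pooled.shape)  # torch.Size([2, 8])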

select_cls

Returns the representations of each sequence's ``<cls>`` token by slicing the hidden state tensor output by the backbone model. The representations of the ``<cls>`` tokens contain general information about the sequences.

Args:

  • hidden_states: A tensor representing the hidden states output by the backbone model.

Returns:

  • A tensor containing the representations of each sequence’s <cls> token.

def select_cls(hidden_states: torch.Tensor) -> torch.Tensor:
    """Returns the representations of the `<cls>` tokens.
    Returns the representations of each sequence's `<cls>` token by slicing the hidden state tensor output by the
    backbone model. The representations of the `<cls>` tokens contain general information about the sequences.

    Args:
        hidden_states (`torch.Tensor`):
            A tensor representing the hidden states output by the backbone model.
    Returns:
        `torch.Tensor`:
            A tensor containing the representations of each sequence's `<cls>` token.
    """
    return hidden_states[:, 0, :]
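Since the `<cls>` token sits at position 0, the slice keeps the first time step of every sequence:

import torch

hidden_states = torch.randn(2, 16, 8)  # (batch_size, seq_length, hidden_size)
cls_states = select_cls(hidden_states)
assert torch.equal(cls_states, hidden_states[:, 0, :])
print(cls_states.shape)  # torch.Size([2, 8])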

select_marker

Returns the representations of each sequence’s marker tokens by slicing the hidden state tensor output by the backbone model.

Args:

  • hidden_states: A tensor representing the hidden states output by the backbone model.

  • left: A tensor indicating the left position of the markers.

  • right: A tensor indicating the right position of the markers.

Returns:

  • marker_output: A tensor containing the representations of each sequence's marker tokens, formed by concatenating the left and right marker tokens' representations.

def select_marker(hidden_states: torch.Tensor,
                  left: torch.Tensor,
                  right: torch.Tensor) -> torch.Tensor:
    """Returns the representations of the marker tokens.
    Returns the representations of each sequence's marker tokens by slicing the hidden state tensor output by the
    backbone model.
    Args:
        hidden_states (`torch.Tensor`):
            A tensor representing the hidden states output by the backbone model.
        left (`torch.Tensor`):
            A tensor indicating the left position of the markers.
        right (`torch.Tensor`):
            A tensor indicating the right position of the markers.

    Returns:
        marker_output (`torch.Tensor`):
            A tensor containing the representations of each sequence's marker tokens, formed by concatenating the
            left and right marker tokens' representations.
    """
    batch_size = hidden_states.size(0)
    batch_indice = torch.arange(batch_size)
    # Gather the hidden state at each sequence's left and right marker positions.
    left_states = hidden_states[batch_indice, left.to(torch.long), :]
    right_states = hidden_states[batch_indice, right.to(torch.long), :]
    marker_output = torch.cat((left_states, right_states), dim=-1)
    return marker_output
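For instance, with one marker pair per sequence the output doubles the hidden size, since the left and right marker states are concatenated:

import torch

hidden_states = torch.randn(2, 16, 8)  # (batch_size, seq_length, hidden_size)
left = torch.tensor([3, 5])            # left marker position per sequence
right = torch.tensor([4, 6])           # right marker position per sequence
markers = select_marker(hidden_states, left, right)
print(markers.shape)  # torch.Size([2, 16])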

DynamicPooling

Dynamic multi-pooling layer for a Convolutional Neural Network (CNN), which is able to capture more valuable information within a sentence, particularly in cases where a sentence contains multiple triggers and an argument candidate may play a different role with respect to each trigger.

Attributes:

  • dense: An nn.Linear layer defined in the constructor; note that it is not applied in forward.

  • activation: An nn.Tanh layer representing the tanh activation function.

  • dropout: An nn.Dropout layer for the dropout operation with the default dropout rate (0.5).

class DynamicPooling(nn.Module):
    """Dynamic multi-pooling layer for Convolutional Neural Network (CNN).
    Dynamic multi-pooling layer for Convolutional Neural Network (CNN), which is able to capture more valuable
    information within a sentence, particularly for some cases, such as multiple triggers are within a sentence and
    different argument candidate may play a different role with a different trigger.
    Attributes:
        dense (`nn.Linear`):
            TODO: The purpose of the linear layer should be configured.
        activation (`nn.Tanh`):
            An `nn.Tanh` layer representing the tanh activation function.
        dropout (`nn.Dropout`):
            An `nn.Dropout` layer for the dropout operation with the default dropout rate (0.5).
    """
    def __init__(self,
                 config) -> None:
        """Constructs a `DynamicPooling`."""
        super(DynamicPooling, self).__init__()
        self.dense = nn.Linear(config.hidden_size*config.head_scale, config.hidden_size*config.head_scale)
        self.activation = nn.Tanh()
        self.dropout = nn.Dropout()

    def get_mask(self,
                 position: torch.Tensor,
                 batch_size: int,
                 seq_length: int,
                 device: str) -> torch.Tensor:
        """Returns the mask indicating whether the token is padded or not."""
        all_masks = []
        for i in range(batch_size):
            mask = torch.zeros((seq_length), dtype=torch.int16, device=device)
            mask[:int(position[i])] = 1
            all_masks.append(mask.to(torch.bool))
        all_masks = torch.stack(all_masks, dim=0)
        return all_masks

    def max_pooling(self,
                    hidden_states: torch.Tensor,
                    mask: torch.Tensor) -> torch.Tensor:
        """Conducts the max-pooling operation on the hidden states."""
        batch_size, seq_length, hidden_size = hidden_states.size()
        conved = hidden_states.transpose(1, 2)
        conved = conved.transpose(0, 1)
        states = (conved * mask).transpose(0, 1)
        states += torch.ones_like(states)
        pooled_states = F.max_pool1d(input=states, kernel_size=seq_length).contiguous().view(batch_size, hidden_size)
        pooled_states -= torch.ones_like(pooled_states)
        return pooled_states

    def forward(self,
                hidden_states: torch.Tensor,
                trigger_position: torch.Tensor,
                argument_position: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Conducts the dynamic multi-pooling process on the hidden states."""
        batch_size, seq_length = hidden_states.size()[:2]
        trigger_mask = self.get_mask(trigger_position, batch_size, seq_length, hidden_states.device)
        if argument_position is not None:
            argument_mask = self.get_mask(argument_position, batch_size, seq_length, hidden_states.device)
            # Split the sequence into three segments relative to the trigger and argument
            # positions: before both, between the two, and after both.
            left_mask = torch.logical_and(trigger_mask, argument_mask).to(torch.float32)
            middle_mask = torch.logical_xor(trigger_mask, argument_mask).to(torch.float32)
            right_mask = 1 - torch.logical_or(trigger_mask, argument_mask).to(torch.float32)
            # pooling
            left_states = self.max_pooling(hidden_states, left_mask)
            middle_states = self.max_pooling(hidden_states, middle_mask)
            right_states = self.max_pooling(hidden_states, right_mask)
            pooled_output = torch.cat((left_states, middle_states, right_states), dim=-1)
        else:
            left_mask = trigger_mask.to(torch.float32)
            right_mask = 1 - left_mask
            left_states = self.max_pooling(hidden_states, left_mask)
            right_states = self.max_pooling(hidden_states, right_mask)
            pooled_output = torch.cat((left_states, right_states), dim=-1)

        return pooled_output
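A short sketch of the forward pass, again with a hypothetical `SimpleNamespace` standing in for the real config; given both trigger and argument positions, three pooled segments are concatenated, tripling the hidden size:

import torch
from types import SimpleNamespace

# Hypothetical config; the constructor only reads `hidden_size` and `head_scale`.
config = SimpleNamespace(hidden_size=8, head_scale=1)
pooler = DynamicPooling(config)

hidden_states = torch.randn(2, 16, 8)  # (batch_size, seq_length, hidden_size)
trigger_position = torch.tensor([5, 7])
argument_position = torch.tensor([3, 9])

output = pooler(hidden_states, trigger_position, argument_position)
print(output.shape)  # torch.Size([2, 24]): left, middle, and right segments concatenated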