Welcome to OmniEvent’s documentation!


Overview

OmniEvent is a powerful open-source toolkit for event extraction, including event detection and event argument extraction. We comprehensively cover various paradigms and provide fair and unified evaluations on widely-used English and Chinese datasets. Modular implementations make OmniEvent highly extensible.

Highlights

  • Comprehensive Capability
    • Supports end-to-end Event Extraction as well as its two subtasks independently: Event Detection and Event Argument Extraction.

    • Covers various paradigms: Token Classification, Sequence Labeling, MRC (QA), and Seq2Seq.

    • Implements both Transformers-based models (BERT, T5, etc.) and classical models (CNN, LSTM, CRF, etc.).

    • Both Chinese and English are supported for all event extraction sub-tasks, paradigms and models.

  • Modular Implementation
    • All models are decomposed into four modules:
      • Input Engineering: Prepare inputs and support various input engineering methods like prompting.

      • Backbone: Encode text into hidden states.

      • Aggregation: Fuse hidden states (e.g., select [CLS], pooling, GCN) into the final event representation.

      • Output Head: Map the event representation to the final outputs, such as Linear, CRF, MRC head, etc.

  • Unified Benchmark & Evaluation
    • Various datasets are processed into a unified format.

    • Predictions of different paradigms are all converted into a unified candidate set for fair evaluations.

    • Four evaluation modes (gold, loose, default, strict) well cover different previous evaluation settings.

  • Big Model Training & Inference
    • Efficient training and inference of big models for event extraction are supported with BMTrain.

  • Easy to Use & Highly Extensible
    • Datasets can be downloaded and processed with a single command.

    • Fully compatible with 🤗 Transformers and its Trainer.

    • Users can easily reproduce existing models and build customized models with OmniEvent.

Installation

With pip

This repository is tested on Python 3.9+ and PyTorch 1.12.1+. OmniEvent can be installed with pip as follows:

pip install OmniEvent

Easy Start

OmniEvent provides ready-to-use models for the users. Examples are shown below.

Make sure you have installed OmniEvent as instructed above. Note that it may take a few minutes to download the checkpoint for the first time.
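
A minimal inference sketch follows. The infer helper and its task argument ("ED", "EAE", or "EE") follow the repository README; treat the exact interface as an assumption and check the README if it has changed.

>>> from OmniEvent.infer import infer

>>> # End-to-end Event Extraction (EE); task="ED" or task="EAE" runs a single subtask
>>> text = "U.S. and British troops were moving on the strategic southern port city of Basra."
>>> results = infer(text=text, task="EE")
>>> print(results[0]["events"])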

Train your Own Model with OmniEvent

OmniEvent can help users easily train and evaluate their customized models on a specific dataset.

We show a step-by-step example of using OmniEvent to train and evaluate an Event Detection model on the ACE-EN dataset in the Seq2Seq paradigm. More examples are shown in the examples folder.

Step 1: Process the dataset into the unified format

We provide standard data processing scripts for commonly-adopted datasets. Check out the details in scripts/data_processing.

dataset=ace2005-en  # the dataset name
cd scripts/data_processing/$dataset
bash run.sh

Step 2: Set up the customized configurations

We keep track of the configurations of the dataset, model, and training parameters via a single *.yaml file. See ./configs for details.

>>> from OmniEvent.arguments import DataArguments, ModelArguments, TrainingArguments, ArgumentParser
>>> from OmniEvent.input_engineering.seq2seq_processor import type_start, type_end

>>> parser = ArgumentParser((ModelArguments, DataArguments, TrainingArguments))
>>> model_args, data_args, training_args = parser.parse_yaml_file(yaml_file="config/all-datasets/ed/s2s/ace-en.yaml")

>>> training_args.output_dir = 'output/ACE2005-EN/ED/seq2seq/t5-base/'
>>> data_args.markers = ["<event>", "</event>", type_start, type_end]

Step 3: Initialize the model and tokenizer

OmniEvent supports various backbones. The users can specify the model and tokenizer in the config file and initialize them as follows.

>>> from OmniEvent.backbone.backbone import get_backbone
>>> from OmniEvent.model.model import get_model

>>> backbone, tokenizer, config = get_backbone(model_type=model_args.model_type,
                                               model_name_or_path=model_args.model_name_or_path,
                                               tokenizer_name=model_args.model_name_or_path,
                                               markers=data_args.markers,
                                               new_tokens=data_args.markers)
>>> model = get_model(model_args, backbone)

Step 4: Initialize dataset and evaluation metric

OmniEvent prepares the DataProcessor and the corresponding evaluation metrics for different tasks and paradigms.

Note

The metrics here are paradigm-dependent and are not used for the final unified evaluation.

>>> from OmniEvent.input_engineering.seq2seq_processor import EDSeq2SeqProcessor
>>> from OmniEvent.evaluation.metric import compute_seq_F1

>>> train_dataset = EDSeq2SeqProcessor(data_args, tokenizer, data_args.train_file)
>>> eval_dataset = EDSeq2SeqProcessor(data_args, tokenizer, data_args.validation_file)
>>> metric_fn = compute_seq_F1

Step 5: Define Trainer and train

OmniEvent adopts Trainer from 🤗 Transformers for training and evaluation.

>>> from OmniEvent.trainer_seq2seq import Seq2SeqTrainer

>>> trainer = Seq2SeqTrainer(
        args=training_args,
        model=model,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=metric_fn,
        data_collator=train_dataset.collate_fn,
        tokenizer=tokenizer,
    )
>>> trainer.train()

Step 6: Unified Evaluation

Since the metrics in Step 4 depend on the paradigm, it is not fair to directly compare the performance of different paradigms.

OmniEvent evaluates models of different paradigms in a unified manner: the predictions of different models are converted to the word level and then evaluated.

>>> from OmniEvent.evaluation.utils import predict, get_pred_s2s
>>> from OmniEvent.evaluation.convert_format import get_trigger_detection_s2s

>>> logits, labels, metrics, test_dataset = predict(trainer=trainer, tokenizer=tokenizer, data_class=EDSeq2SeqProcessor,
                                                    data_args=data_args, data_file=data_args.test_file,
                                                    training_args=training_args)
>>> # paradigm-dependent metrics
>>> print("{} test performance before converting: {}".formate(test_dataset.dataset_name, metrics["test_micro_f1"]))
ACE2005-EN test performance before converting: 66.4215686224377

>>> preds = get_pred_s2s(logits, tokenizer)
>>> # convert to the unified prediction and evaluate
>>> pred_labels = get_trigger_detection_s2s(preds, labels, data_args.test_file, data_args, None)
ACE2005-EN test performance after converting: 67.41016109045849

For those datasets whose test set annotations are not given, such as MAVEN and LEVEN, OmniEvent provides APIs to generate submission files. See dump_result.py for details.

Supported Datasets & Models & Contests

Continually updated. Welcome to add more!

Datasets

Language | Domain    | Task   | Dataset
---------|-----------|--------|--------------------
English  | General   | ED     | MAVEN
English  | General   | ED EAE | ACE-EN
English  | General   | ED EAE | ACE-DYGIE
English  | General   | ED EAE | RichERE (KBP + ERE)
Chinese  | Legal     | ED     | LEVEN
Chinese  | General   | ED EAE | DuEE
Chinese  | General   | ED EAE | ACE-ZH
Chinese  | Financial | ED EAE | FewFC

Models

  • Paradigm
    • Token Classification (TC)

    • Sequence Labeling (SL)

    • Sequence to Sequence (Seq2Seq)

    • Machine Reading Comprehension (MRC)

  • Backbone
    • CNN / LSTM

    • Transformers (BERT, T5, etc.)

  • Aggregation
    • Select [CLS]

    • Dynamic/Max Pooling

    • Marker

    • GCN

  • Head
    • Linear / CRF / MRC heads

Contests

OmniEvent plans to support various event extraction contests. The list of supported contests is continually updated!

Experiments

We implement and evaluate state-of-the-art methods on some popular benchmarks using OmniEvent. The results of all Event Detection experiments are shown in the table below. The full results can be accessed via the links below.

[Figures exp1 and exp2: Event Detection result tables]

Convert the Dataset into Unified OmniEvent Format

To simplify subsequent data loading and modeling, we provide pre-processing scripts for commonly-used Event Extraction datasets. Users can download a dataset and convert it to the unified OmniEvent format by configuring the data path in the run.sh file under the scripts/data_processing folder with the same name as the dataset.

Unified OmniEvent Format

A unified OmniEvent dataset is a JSON Lines file with the extension .unified.jsonl (such as train.unified.jsonl, valid.unified.jsonl, and test.unified.jsonl), a convenient format for storing structured data with one record per line. Taking a record from TAC KBP 2016 as an example, a piece of data in the unified OmniEvent format looks as follows:

{
    "id": "NYT_ENG_20130910.0002-6",
    "text": "In 1997 , Chun was sentenced to life in prison and Roh to 17 years .",
    "events": [{
        "type": "sentence",
        "triggers": [{
            "id": "em-2342",
            "trigger_word": "sentenced",
            "position": [19, 28],
            "arguments": [{
                "role": "defendant",
                "mentions": [{
                    "id": "m-291",
                    "mention": "Chun",
                    "position": [10, 14]}]}, ... ]}, ... ]} ... ],
    "negative_triggers": [{
        "id": 0,
        "trigger_word": "In",
        "position": [0, 2]}, ... ],
    "entities":  [{
        "type": "PER",
        "mentions": [{
            "id": "m-291",
            "mention": "Chun",
            "position": [10, 14]}, ... ]}, ... ]}

Supported Datasets

The pre-processing scripts support almost all commonly-used Event Extraction datasets, so as to minimize data conversion difficulties. Additional pre-processing scripts are still being developed, and you can request support for further datasets via pull requests. Currently, we have developed pre-processing scripts for the following datasets:

  • ACE2005: ACE2005-EN, ACE2005-DyGIE, ACE2005-OneIE, ACE2005-ZH

  • DuEE: DuEE1.0, DuEE-fin

  • ERE: LDC2015E29, LDC2015E68, LDC2015E78

  • FewFC

  • TAC KBP: TAC KBP 2014, TAC KBP 2015, TAC KBP 2016, TAC KBP 2017

  • LEVEN

  • MAVEN

Dataset Conversion

Step 1: Download the Dataset

The first step of data conversion is to download the desired dataset from its corresponding website. For example, the DuEE 1.0 dataset can be downloaded from its official website.

Step 2: Configure the Dataset Path

After downloading the dataset from the Internet, configure the run.sh file under the folder with the same name as the dataset. For example, for the DuEE 1.0 dataset, configure the run.sh file under scripts/data_processing/duee, in which the data_dir path should match the location of the downloaded dataset; you can also change where the processed dataset is saved by configuring the save_dir path:

python duee.py \
    --data_dir ../../../data/original/DuEE1.0 \
    --save_dir ../../../data/processed/DuEE1.0

Step 3: Execute the run.sh File

After downloading the dataset and configuring the corresponding run.sh file, the dataset can be converted to the unified OmniEvent format by executing it. For example, for the DuEE 1.0 dataset, we execute the run.sh file as follows:

bash run.sh

Examples

Note

To make sure you run the latest versions of the example scripts, you need to install the repository from source as follows:

git clone https://github.com/THU-KEG/OmniEvent.git
cd OmniEvent
pip install .

BigModel

The BigModel directory contains tuning code for large PLMs. The tuning code is powered by the BMTrain engine.

ED

The ED directory contains examples of event detection.

EAE

The EAE directory contains examples of event argument extraction. You can conduct EAE independently using gold event triggers, or use the predictions of ED to perform end-to-end event extraction.

Tuning Large PLMs for Event Extraction

We provide an example script for tuning large pre-trained language models (PLMs) on event extraction tasks. We use BMTrain as the distributed training engine. BMTrain is an efficient large model training toolkit, see BMTrain and ModelCenter for more details. We adapt the code of ModelCenter for event extraction and place the code in OmniEvent/utils.

Setup

Install the code in OmniEvent/utils/ModelCenter:

cd utils/ModelCenter
pip install .

Easy Start

Run bash train.sh to train MT5-xxl. You can modify the config; the important hyper-parameters are as follows:

NNODES         # number of nodes
GPUS_PER_NODE  # number of GPUs used on each node
model-config   # we only support T5 and MT5

The original ModelCenter repo doesn’t support an inference method (i.e., generate) for decoder PLMs. We provide beam_search.py for inference.

Tokenizer

import collections
import logging
import numpy as np
import os

from transformers import PreTrainedTokenizer
from typing import Dict, Iterable, List, Optional, Tuple, Union

logger = logging.getLogger(__name__)

load_vocab

Loads a vocabulary file, allocates a unique id for each word within the vocabulary and saves the correspondence between words and ids into a dictionary. Generates and returns word embeddings if it is required.

Args:

  • vocab_file: The path of the vocabulary file.

  • return_embeddings: Whether or not to return the word embeddings.

Returns:

  • word_embeddings: A numpy array representing each word’s embedding within the vocabulary, with the size of (number of words) * (embedding dimension). Returned only if return_embeddings is set to True.

  • vocab: A dictionary indicating the unique id of each word within the vocabulary.

def load_vocab(vocab_file: str,
            return_embeddings: bool = False) -> Union[Dict[str, int], np.ndarray]:
    """Loads a vocabulary file into a dictionary.

    Loads a vocabulary file, allocates a unique id for each word within the vocabulary and saves the correspondence
    between words and ids into a dictionary. Generates and returns word embeddings if it is required.

    Args:
        vocab_file (`str`):
            The path of the vocabulary file.
        return_embeddings (`bool`, `optional`, defaults to `False`):
            Whether or not to return the word embeddings.

    Returns:
        word_embeddings (`np.ndarray`):
            A numpy array representing each word's embedding within the vocabulary, with the size of (number of
            words) * (embedding dimension). Returned only if `return_embeddings` is set to `True`.
        vocab (`Dict[str, int]`):
            A dictionary indicating the unique id of each word within the vocabulary.
    """
    vocab = collections.OrderedDict()
    vocab["[PAD]"] = 0
    with open(vocab_file, "r", encoding="utf-8") as reader:
        lines = reader.readlines()
    num_embeddings = len(lines) + 1
    embedding_dim = len(lines[0].split()) - 1
    for index, line in enumerate(lines):
        token = " ".join(line.split()[:-embedding_dim])
        if token in vocab:
            token = f"{token}_{index+1}"
        vocab[token] = index + 1
    if return_embeddings:
        word_embeddings = np.zeros((num_embeddings, embedding_dim), dtype=np.float32)
        for index, line in enumerate(lines):
            embedding = [float(value) for value in line.strip().split()[-embedding_dim:]]
            word_embeddings[index+1] = embedding
        return word_embeddings
    return vocab
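
For illustration, here is a hypothetical three-word vec.txt in the expected format (each line is a token followed by its embedding values; the values below are made up) and the resulting outputs:

>>> lines = "the 0.1 0.2 0.3\ncat 0.4 0.5 0.6\nsat 0.7 0.8 0.9\n"
>>> with open("vec.txt", "w", encoding="utf-8") as f:
...     _ = f.write(lines)

>>> load_vocab("vec.txt")
OrderedDict([('[PAD]', 0), ('the', 1), ('cat', 2), ('sat', 3)])
>>> load_vocab("vec.txt", return_embeddings=True).shape  # row 0 is the all-zero [PAD] vector
(4, 3)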

whitespace_tokenize

Cleans the whitespace at the beginning and end of the text and splits the text into a list based on whitespaces.

Args:

  • text: A string representing the input text to be processed.

Returns:

  • tokens: A list of strings in which each element represents a word within the input text.

def whitespace_tokenize(text: str) -> List[str]:
    """Runs basic whitespace cleaning and splitting on a piece of text.

    Cleans the whitespace at the beginning and end of the text and splits the text into a list based on whitespaces.

    Args:
        text (`str`):
            A string representing the input text to be processed.

    Returns:
        tokens (`List[str]`):
            A list of strings in which each element represents a word within the input text.
    """
    text = text.strip()
    if not text:
        return []
    tokens = text.split()
    return tokens
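
A quick check of the behavior:

>>> whitespace_tokenize("  the cat  sat \n")
['the', 'cat', 'sat']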

WordLevelTokenizer

This tokenizer inherits from PreTrainedTokenizer which contains most of the main methods. Users should refer to this superclass for more information regarding those methods.

Attributes:

  • vocab: A dictionary indicating the correspondence between words and ids within the vocabulary.

  • ids_to_tokens: A dictionary indicating the correspondence between ids and words within the vocabulary.

  • whitespace_tokenizer: A WhitespaceTokenizer instance for whitespace tokenization.

VOCAB_FILES_NAMES = {"vocab_file": "vec.txt"}

PRETRAINED_VOCAB_FILES_MAP = {}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {}

PRETRAINED_INIT_CONFIGURATION = {}

class WordLevelTokenizer(PreTrainedTokenizer):
    """Construct a word-level tokenizer based on whitespace splitting.

    This tokenizer inherits from `PreTrainedTokenizer` which contains most of the main methods. Users should refer to
    this superclass for more information regarding those methods.

    Attributes:
        vocab (`Dict[str, int]`):
            A dictionary indicating the correspondence between words and ids within the vocabulary.
        ids_to_tokens (`Dict[int, str]`):
            A dictionary indicating the correspondence between ids and words within the vocabulary.
        whitespace_tokenizer (`WhitespaceTokenizer`):
            A `WhitespaceTokenizer` instance for whitespace tokenization.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

    def __init__(self,
                vocab_file: str,
                do_lower_case: bool = True,
                never_split: Iterable = None,
                unk_token: str = "[UNK]",
                sep_token: str = "[SEP]",
                pad_token: str = "[PAD]",
                cls_token: str = "[CLS]",
                strip_accents: bool = None,
                model_max_length: int = 512,
                **kwargs):
        """Construct a WordLevelTokenizer."""
        kwargs["model_max_length"] = model_max_length
        super().__init__(
            do_lower_case=do_lower_case,
            never_split=never_split,
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            strip_accents=strip_accents,
            **kwargs,
        )

        if not os.path.isfile(vocab_file):
            raise ValueError(
                f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
                " model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
            )
        self.vocab = load_vocab(vocab_file)
        # insert special token
        for token in [unk_token, sep_token, pad_token, cls_token]:
            if token not in self.vocab:
                self.vocab[token] = len(self.vocab)
        self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
        self.whitespace_tokenizer = WhitespaceTokenizer(vocab=self.vocab, do_lower_case=do_lower_case,
                                                        unk_token=self.unk_token)

    @property
    def do_lower_case(self):
        """Returns whether or not to lowercase the input when tokenizing."""
        return self.whitespace_tokenizer.do_lower_case

    @property
    def vocab_size(self):
        """Returns the length of the vocabulary"""
        return len(self.vocab)

    def get_vocab(self):
        """Returns the vocabulary in a dictionary."""
        return dict(self.vocab, **self.added_tokens_encoder)

    def _tokenize(self,
                text: str):
        """Tokenizes the input text into tokens."""
        if self.do_lower_case:
            text = text.lower()
        split_tokens = self.whitespace_tokenizer.tokenize(text)
        return split_tokens

    def _convert_token_to_id(self,
                            token: str):
        """Converts a token (`str`) in an id using the vocab."""
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    def _convert_id_to_token(self,
                            index: int):
        """Converts an index (`int`) in a token (`str`) using the vocab."""
        return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self,
                                tokens: str):
        """Converts a sequence of tokens (`str`) in a single string."""
        out_string = " ".join(tokens).replace(" ##", "").strip()
        return out_string

    def build_inputs_with_special_tokens(self,
                                        token_ids_0: List[int],
                                        token_ids_1: Optional[List[int]] = None) -> List[int]:
        """Builds model inputs from a sequence or a pair of sequence.
        Builds model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
        adding special tokens. A BERT sequence has the following format:

        - single sequence: `[CLS] X [SEP]`
        - pair of sequences: `[CLS] A [SEP] B [SEP]`

        Args:
            token_ids_0 (`List[int]`):
                List of ids to which the special tokens will be added.
            token_ids_1 (`List[int]`, `optional`):
                Optional second list of ids for sequence pairs.

        Returns:
            `List[int]`: List of [input ids](../glossary#input-ids) with the appropriate special tokens.
        """
        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        return cls + token_ids_0 + sep + token_ids_1 + sep

    def get_special_tokens_mask(self,
                                token_ids_0: List[int],
                                token_ids_1: Optional[List[int]] = None,
                                already_has_special_tokens: bool = False) -> List[int]:
        """Retrieve sequence ids from a token list that has no special tokens added."""

        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        if token_ids_1 is not None:
            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1]

    def create_token_type_ids_from_sequences(self,
                                            token_ids_0: List[int],
                                            token_ids_1: Optional[List[int]] = None) -> List[int]:
        """Create a mask from the two sequences passed to be used in a sequence-pair classification task."""
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]

    def save_vocabulary(self,
                        save_directory: str,
                        filename_prefix: Optional[str] = None) -> Tuple[str]:
        """Saves the vocabulary (copy original file) and special tokens file to a directory."""
        index = 0
        if os.path.isdir(save_directory):
            vocab_file = os.path.join(
                save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
            )
        else:
            vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning(
                        f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
                        " Please check that the vocabulary is not corrupted!"
                    )
                    index = token_index
                writer.write(token + "\n")
                index += 1
        return (vocab_file,)
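
A minimal usage sketch, assuming the vec.txt file created above and a transformers version compatible with this initialization order (self.vocab is set after super().__init__, which newer releases may reject):

>>> tokenizer = WordLevelTokenizer(vocab_file="vec.txt")
>>> tokens = tokenizer.tokenize("The cat sat on the mat")  # lowercased; OOV words map to [UNK]
>>> ids = tokenizer.convert_tokens_to_ids(tokens)
>>> inputs = tokenizer.build_inputs_with_special_tokens(ids)  # [CLS] ... [SEP]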

WhitespaceTokenizer

Tokenizes a piece of text by checking whether each whitespace-separated token is in the vocabulary; tokens not in the vocabulary are replaced with the unknown token.

Attributes:

  • vocab: A dictionary indicating the correspondence between words and ids within the vocabulary.

  • do_lower_case: A boolean variable indicating whether or not to lowercase the input when tokenizing.

  • unk_token: A string representing the unknown token.

class WhitespaceTokenizer(object):
    """A tokenizer to conduct whitespace tokenization.

    Tokenizes a piece of text by checking whether each whitespace-separated token is in the vocabulary; tokens not in
    the vocabulary are replaced with the unknown token.

    Attributes:
        vocab (`Dict[str, int]`):
            A dictionary indicating the correspondence between words and ids within the vocabulary.
        do_lower_case (`bool`):
            A boolean variable indicating whether or not to lowercase the input when tokenizing.
        unk_token (`str`):
            A string representing the unknown token.
    """

    def __init__(self,
                vocab: Dict[str, int],
                do_lower_case: bool,
                unk_token: str):
        """Constructs a `WhitespaceTokenizer`."""
        self.vocab = vocab
        self.do_lower_case = do_lower_case
        self.unk_token = unk_token

    def tokenize(self,
                text: str) -> List[str]:
        """Tokenizes a piece of text into its word pieces."""

        output_tokens = []
        for token in whitespace_tokenize(text):
            if token in self.vocab:
                output_tokens.append(token)
            else:
                output_tokens.append(self.unk_token)
        return output_tokens
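
For illustration, with a tiny hand-built vocabulary (ids are arbitrary):

>>> wt = WhitespaceTokenizer(vocab={"cat": 1, "sat": 2}, do_lower_case=True, unk_token="[UNK]")
>>> wt.tokenize("the cat sat")
['[UNK]', 'cat', 'sat']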

Whitespace Tokenizer

import collections
import logging
import numpy as np
import os
from typing import List, Optional, Tuple
from transformers import PreTrainedTokenizer

logger = logging.getLogger(__name__)

load_vocab

Loads a vocabulary file into a dictionary.

def load_vocab(vocab_file, return_embeddings=False):
    """Loads a vocabulary file into a dictionary."""
    vocab = collections.OrderedDict()
    vocab["[PAD]"] = 0
    with open(vocab_file, "r", encoding="utf-8") as reader:
        lines = reader.readlines()
    num_embeddings = len(lines) + 1
    embedding_dim = len(lines[0].split()) - 1
    for index, line in enumerate(lines):
        token = " ".join(line.split()[:-embedding_dim])
        if token in vocab:
            token = f"{token}_{index+1}"
        vocab[token] = index + 1
    if return_embeddings:
        word_embeddings = np.zeros((num_embeddings, embedding_dim), dtype=np.float32)
        for index, line in enumerate(lines):
            embedding = [float(value) for value in line.strip().split()[-embedding_dim:]]
            word_embeddings[index+1] = embedding
        return word_embeddings
    return vocab

whitespace_tokenize

Runs basic whitespace cleaning and splitting on a piece of text.

def whitespace_tokenize(text):
    """Runs basic whitespace cleaning and splitting on a piece of text."""
    text = text.strip()
    if not text:
        return []
    tokens = text.split()
    return tokens
VOCAB_FILES_NAMES = {"vocab_file": "vec.txt"}

PRETRAINED_VOCAB_FILES_MAP = {}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {}

PRETRAINED_INIT_CONFIGURATION = {}

WordLevelTokenizer

Construct a word-level tokenizer based on whitespace splitting.

This tokenizer inherits from PreTrainedTokenizer which contains most of the main methods. Users should refer to this superclass for more information regarding those methods.

class WordLevelTokenizer(PreTrainedTokenizer):
    r"""
    Construct a word-level tokenizer based on whitespace splitting.
    This tokenizer inherits from [`PreTrainedTokenizer`] which contains most of the main methods. Users should refer to
    this superclass for more information regarding those methods.
    Args:
        vocab_file (`str`):
            File containing the vocabulary.
        do_lower_case (`bool`, *optional*, defaults to `True`):
            Whether or not to lowercase the input when tokenizing.
        do_basic_tokenize (`bool`, *optional*, defaults to `True`):
            Whether or not to do basic tokenization before WordPiece.
        never_split (`Iterable`, *optional*):
            Collection of tokens which will never be split during tokenization. Only has an effect when
            `do_basic_tokenize=True`
        unk_token (`str`, *optional*, defaults to `"[UNK]"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        sep_token (`str`, *optional*, defaults to `"[SEP]"`):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
            sequence classification or for a text and a question for question answering. It is also used as the last
            token of a sequence built with special tokens.
        pad_token (`str`, *optional*, defaults to `"[PAD]"`):
            The token used for padding, for example when batching sequences of different lengths.
        cls_token (`str`, *optional*, defaults to `"[CLS]"`):
            The classifier token which is used when doing sequence classification (classification of the whole sequence
            instead of per-token classification). It is the first token of the sequence when built with special tokens.
        mask_token (`str`, *optional*, defaults to `"[MASK]"`):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
        tokenize_chinese_chars (`bool`, *optional*, defaults to `True`):
            Whether or not to tokenize Chinese characters.
            This should likely be deactivated for Japanese (see this
            [issue](https://github.com/huggingface/transformers/issues/328)).
        strip_accents (`bool`, *optional*):
            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
            value for `lowercase` (as in the original BERT).
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

    def __init__(
        self,
        vocab_file,
        do_lower_case=True,
        never_split=None,
        unk_token="[UNK]",
        sep_token="[SEP]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        strip_accents=None,
        model_max_length=512,
        **kwargs
    ):
        kwargs["model_max_length"] = model_max_length
        super().__init__(
            do_lower_case=do_lower_case,
            never_split=never_split,
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            strip_accents=strip_accents,
            **kwargs,
        )

        if not os.path.isfile(vocab_file):
            raise ValueError(
                f"Can't find a vocabulary file at path '{vocab_file}'. To load the vocabulary from a Google pretrained"
                " model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`"
            )
        self.vocab = load_vocab(vocab_file)
        # insert special token
        for token in [unk_token, sep_token, pad_token, cls_token]:
            if token not in self.vocab:
                self.vocab[token] = len(self.vocab)
        self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
        self.whitespace_tokenizer = WhitespaceTokenizer(vocab=self.vocab, do_lower_case=do_lower_case, unk_token=self.unk_token)

    @property
    def do_lower_case(self):
        return self.whitespace_tokenizer.do_lower_case

    @property
    def vocab_size(self):
        return len(self.vocab)

    def get_vocab(self):
        return dict(self.vocab, **self.added_tokens_encoder)

    def _tokenize(self, text):
        if self.do_lower_case:
            text = text.lower()
        split_tokens = self.whitespace_tokenizer.tokenize(text)
        return split_tokens

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""
        out_string = " ".join(tokens).replace(" ##", "").strip()
        return out_string

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
        adding special tokens. A BERT sequence has the following format:
        - single sequence: `[CLS] X [SEP]`
        - pair of sequences: `[CLS] A [SEP] B [SEP]`
        Args:
            token_ids_0 (`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
        Returns:
            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
        """
        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        return cls + token_ids_0 + sep + token_ids_1 + sep

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer `prepare_for_model` method.
        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (`bool`, *optional*, defaults to `False`):
                Whether or not the token list is already formatted with special tokens for the model.
        Returns:
            `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """

        if already_has_special_tokens:
            return super().get_special_tokens_mask(
                token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
            )

        if token_ids_1 is not None:
            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence
        pair mask has the following format:
        ```
        0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
        | first sequence    | second sequence |
        ```
        If `token_ids_1` is `None`, this method only returns the first portion of the mask (0s).
        Args:
            token_ids_0 (`List[int]`):
                List of IDs.
            token_ids_1 (`List[int]`, *optional*):
                Optional second list of IDs for sequence pairs.
        Returns:
            `List[int]`: List of [token type IDs](../glossary#token-type-ids) according to the given sequence(s).
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        index = 0
        if os.path.isdir(save_directory):
            vocab_file = os.path.join(
                save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
            )
        else:
            vocab_file = (filename_prefix + "-" if filename_prefix else "") + save_directory
        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning(
                        f"Saving vocabulary to {vocab_file}: vocabulary indices are not consecutive."
                        " Please check that the vocabulary is not corrupted!"
                    )
                    index = token_index
                writer.write(token + "\n")
                index += 1
        return (vocab_file,)

WhitespaceTokenizer

Runs whitespace tokenization.

class WhitespaceTokenizer(object):
    """Runs whitespace tokenization."""

    def __init__(self, vocab, do_lower_case, unk_token):
        self.vocab = vocab
        self.do_lower_case = do_lower_case
        self.unk_token = unk_token

    def tokenize(self, text):
        """
        Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform
        tokenization using the given vocabulary.
        For example, `input = "unaffable"` wil return as output `["un", "##aff", "##able"]`.
        Args:
            text: A single token or whitespace separated tokens. This should have
                already been passed through *BasicTokenizer*.
        Returns:
            A list of wordpiece tokens.
        """

        output_tokens = []
        for token in whitespace_tokenize(text):
            if token in self.vocab:
                output_tokens.append(token)
            else:
                output_tokens.append(self.unk_token)
        return output_tokens

Base Processor

import os
import json
import torch
import logging

from torch.utils.data import Dataset
from typing import Dict, List, Optional, Union

logger = logging.getLogger(__name__)

EDInputExample

A single training/test example for event detection, representing the basic information of an event trigger, including its example id, the source text it is within, its start and end position, and the event type of the trigger.

Attributes:

  • example_id: A string or an integer for the unique id of the example.

  • text: A string representing the source text the event trigger is within.

  • trigger_left: An integer indicating the left position of the event trigger.

  • trigger_right: An integer indicating the right position of the event trigger.

  • labels: A string indicating the event type of the trigger.

class EDInputExample(object):
    """A single training/test example for event detection.

    A single training/test example for event detection, representing the basic information of an event trigger,
    including its example id, the source text it is within, its start and end position, and the label of the event.

    Attributes:
        example_id (`Union[int, str]`):
            A string or an integer for the unique id of the example.
        text (`str`):
            A string representing the source text the event trigger is within.
        trigger_left (`int`, `optional`, defaults to `None`):
            An integer indicating the left position of the event trigger.
        trigger_right (`int`, `optional`, defaults to `None`):
            An integer indicating the right position of the event trigger.
        labels (`str`, `optional`, defaults to `None`):
            A string indicating the event type of the trigger.
    """

    def __init__(self,
                 example_id: Union[int, str],
                 text: str,
                 trigger_left: Optional[int] = None,
                 trigger_right: Optional[int] = None,
                 labels: Optional[str] = None) -> None:
        """Constructs an `EDInputExample`."""
        self.example_id = example_id
        self.text = text
        self.trigger_left = trigger_left
        self.trigger_right = trigger_right
        self.labels = labels
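
For illustration, an EDInputExample for the sentence from the unified-format record shown earlier (the character offsets [19, 28] cover the trigger word "sentenced"):

>>> example = EDInputExample(
...     example_id="NYT_ENG_20130910.0002-6",
...     text="In 1997 , Chun was sentenced to life in prison and Roh to 17 years .",
...     trigger_left=19,
...     trigger_right=28,
...     labels="sentence",
... )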

EDInputFeatures

Input features of an instance for event detection, representing the basic features of an event trigger, including its example id, the indices of tokens in the vocabulary, attention masks, segment token indices, start and end position, and the event type of the trigger.

Attributes:

  • example_id: A string or an integer for the unique id of the example.

  • input_ids: A list of integers representing the indices of input sequence tokens in the vocabulary.

  • attention_mask: A list of integers (in 0/1) for masks to avoid attention on padding tokens.

  • token_type_ids: A list of integers indicating the first and second portions of the inputs.

  • trigger_left: An integer indicating the left position of the event trigger.

  • trigger_right: An integer indicating the right position of the event trigger.

  • labels: A string indicating the event type of the trigger.

class EDInputFeatures(object):
    """Input features of an instance for event detection.

    Input features of an instance for event detection, representing the basic features of an event trigger, including
    its example id, the indices of tokens in the vocabulary, attention masks, segment token indices, start and end
    position, and the label of the event.

    Attributes:
        example_id (`Union[int, str]`):
            A string or an integer for the unique id of the example.
        input_ids (`List[int]`):
            A list of integers representing the indices of input sequence tokens in the vocabulary.
        attention_mask (`List[int]`):
            A list of integers (in 0/1) for masks to avoid attention on padding tokens.
        token_type_ids (`List[int]`, `optional`, defaults to `None`):
            A list of integers indicating the first and second portions of the inputs.
        trigger_left (`int`, `optional`, defaults to `None`):
            An integer indicating the left position of the event trigger.
        trigger_right (`int`, `optional`, defaults to `None`):
            An integer indicating the right position of the event trigger.
        labels (`str`, `optional`, defaults to `None`):
            A string indicating the event type of the trigger.
    """

    def __init__(self,
                 example_id: Union[int, str],
                 input_ids: List[int],
                 attention_mask: List[int],
                 token_type_ids: Optional[List[int]] = None,
                 trigger_left: Optional[int] = None,
                 trigger_right: Optional[int] = None,
                 labels: Optional[str] = None) -> None:
        """Constructs an `EDInputFeatures`."""
        self.example_id = example_id
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.token_type_ids = token_type_ids
        self.trigger_left = trigger_left
        self.trigger_right = trigger_right
        self.labels = labels

EAEInputExample

A single training/test example for event argument extraction, representing the basic information of an event trigger, including its example id, the source text it is within, the predicted and actual event type, the input template for the Machine Reading Comprehension (MRC) paradigm, the start and end position of the event trigger and argument, and the label of the event.

Attributes:

  • example_id: A string or an integer for the unique id of the example.

  • text: A string representing the source text the event trigger and argument are within.

  • pred_type: A string indicating the event type predicted by the model.

  • true_type: A string indicating the actual event type from the annotation.

  • input_template: The input template for the MRC paradigm.

  • trigger_left: An integer indicating the left position of the event trigger.

  • trigger_right: An integer indicating the right position of the event trigger.

  • argument_left: An integer indicating the left position of the argument mention.

  • argument_right: An integer indicating the right position of the argument mention.

  • argument_role: A string indicating the argument role of the argument mention.

  • labels: A string indicating the label of the event.

class EAEInputExample(object):
    """A single training/test example for event argument extraction.

    A single training/test example for event argument extraction, representing the basic information of an event
    trigger, including its example id, the source text it is within, the predicted and actual event type, the input
    template for the Machine Reading Comprehension (MRC) paradigm, the start and end position of the event trigger and
    argument, and the label of the event.

    Attributes:
        example_id (`Union[int, str]`):
            A string or an integer for the unique id of the example.
        text (`str`):
            A string representing the source text the event trigger and argument is within.
        pred_type (`str`):
            A string indicating the event type predicted by the model.
        true_type (`str`):
            A string indicating the actual event type from the annotation.
        input_template:
            The input template for the MRC paradigm.
        trigger_left (`int`, `optional`, defaults to `None`):
            An integer indicating the left position of the event trigger.
        trigger_right (`int`, `optional`, defaults to `None`):
            An integer indicating the right position of the event trigger.
        argument_left (`int`, `optional`, defaults to `None`):
            An integer indicating the left position of the argument mention.
        argument_right (`int`, `optional`, defaults to `None`):
            An integer indicating the right position of the argument mention.
        argument_role (`str`, `optional`, defaults to `None`):
            A string indicating the argument role of the argument mention.
        labels (`str`, `optional`, defaults to `None`):
            A string indicating the label of the event.
    """

    def __init__(self,
                 example_id: Union[int, str],
                 text: str,
                 pred_type: str,
                 true_type: str,
                 input_template: Optional = None,
                 trigger_left: Optional[int] = None,
                 trigger_right: Optional[int] = None,
                 argument_left: Optional[int] = None,
                 argument_right: Optional[int] = None,
                 argument_role: Optional[str] = None,
                 labels: Optional[str] = None):
        """Constructs a `EAEInputExample`."""
        self.example_id = example_id
        self.text = text
        self.pred_type = pred_type
        self.true_type = true_type
        self.input_template = input_template
        self.trigger_left = trigger_left
        self.trigger_right = trigger_right
        self.argument_left = argument_left
        self.argument_right = argument_right
        self.argument_role = argument_role
        self.labels = labels
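
Similarly, a hypothetical EAEInputExample for the "defendant" argument of the same record; here labels is assumed to carry the argument role:

>>> example = EAEInputExample(
...     example_id="NYT_ENG_20130910.0002-6",
...     text="In 1997 , Chun was sentenced to life in prison and Roh to 17 years .",
...     pred_type="sentence",
...     true_type="sentence",
...     trigger_left=19,
...     trigger_right=28,
...     argument_left=10,
...     argument_right=14,
...     argument_role="defendant",
...     labels="defendant",
... )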

EAEInputFeatures

Input features of an instance for event argument extraction, representing the basic features of an argument mention, including its example id, the indices of tokens in the vocabulary, the attention mask, segment token indices, the start and end position of the event trigger and argument mention, and the event type of the trigger.

Attributes:

  • example_id: A string or an integer for the unique id of the example.

  • input_ids: A list of integers representing the indices of input sequence tokens in the vocabulary.

  • attention_mask: A list of integers (in 0/1) for masks to avoid attention on padding tokens.

  • token_type_ids: A list of integers indicating the first and second portions of the inputs.

  • trigger_left: An integer for the left position of the event trigger.

  • trigger_right: An integer for the right position of the event trigger.

  • argument_left: An integer for the left position of the argument mention.

  • argument_right: An integer for the right position of the argument mention.

  • labels: A string indicating the event type of the trigger.

class EAEInputFeatures(object):
    """Input features of an instance for event argument extraction.

    Input features of an instance for event argument extraction, representing the basic features of an argument mention,
    including its example id, the indices of tokens in the vocabulary, the attention mask, segment token indices, the
    start and end position of the event trigger and argument mention, and the label of the event.

    Attributes:
        example_id (`Union[int, str]`):
            A string or an integer for the unique id of the example.
        input_ids (`List[int]`):
            A list of integers representing the indices of input sequence tokens in the vocabulary.
        attention_mask (`List[int]`):
            A list of integers (in 0/1) for masks to avoid attention on padding tokens.
        token_type_ids (`List[int]`, `optional`, defaults to `None`):
            A list of integers indicating the first and second portions of the inputs.
        trigger_left (`int`, `optional`, defaults to `None`):
            An integer for the left position of the event trigger.
        trigger_right (`int`, `optional`, defaults to `None`):
            An integer for the right position of the event trigger.
        argument_left (`int`, `optional`, defaults to `None`):
            An integer for the left position of the argument mention.
        argument_right (`int`, `optional`, defaults to `None`):
            An integer for the right position of the argument mention.
        labels (`str`, `optional`, defaults to `None`):
            A string indicating the event type of the trigger.
    """

    def __init__(self,
                 example_id: Union[int, str],
                 input_ids: List[int],
                 attention_mask: List[int],
                 token_type_ids: Optional[List[int]] = None,
                 trigger_left: Optional[int] = None,
                 trigger_right: Optional[int] = None,
                 argument_left: Optional[int] = None,
                 argument_right: Optional[int] = None,
                 labels: Optional[str] = None) -> None:
        """Constructs an `EAEInputFeatures`."""
        self.example_id = example_id
        self.input_ids = input_ids
        self.attention_mask = attention_mask
        self.token_type_ids = token_type_ids
        self.trigger_left = trigger_left
        self.trigger_right = trigger_right
        self.argument_left = argument_left
        self.argument_right = argument_right
        self.labels = labels

EDDataProcessor

The base class of data processor for event detection, which is inherited to construct task-specific data processors.

Attributes:

  • config: The pre-defined configurations of the execution.

  • tokenizer: The tokenizer used for the tokenization process.

  • examples: A list of EDInputExamples constructed based on the input dataset.

  • input_features: A list of EDInputFeatures corresponding to the EDInputExamples.

class EDDataProcessor(Dataset):
    """Base class of data processor for event detection.

    The base class of data processor for event detection, which is inherited to construct task-specific data
    processors.

    Attributes:
        config:
            The pre-defined configurations of the execution.
        tokenizer:
            The tokenizer used for the tokenization process.
        examples (`List[EDInputExample]`):
            A list of `EDInputExample`s constructed based on the input dataset.
        input_features (`List[EDInputFeatures]`):
            A list of `EDInputFeatures`s corresponding to the `EDInputExample`s.
    """

    def __init__(self,
                 config,
                 tokenizer) -> None:
        """Constructs an `EDDataProcessor`."""
        self.config = config
        self.tokenizer = tokenizer
        self.examples = []
        self.input_features = []

    def read_examples(self,
                      input_file: str):
        """Obtains a collection of `EDInputExample`s for the dataset."""
        raise NotImplementedError

    def convert_examples_to_features(self):
        """Converts the `EDInputExample`s into `EDInputFeatures`s."""
        raise NotImplementedError

    def _truncate(self,
                  outputs: dict,
                  max_seq_length: int):
        """Truncates the sequence that exceeds the maximum length."""
        is_truncation = False
        if len(outputs["input_ids"]) > max_seq_length:
            print("An instance exceeds the maximum length.")
            is_truncation = True
            for key in ["input_ids", "attention_mask", "token_type_ids", "offset_mapping"]:
                if key not in outputs:
                    continue
                outputs[key] = outputs[key][:max_seq_length]
        return outputs, is_truncation

    def get_ids(self) -> List[Union[int, str]]:
        """Returns the id of the examples."""
        ids = []
        for example in self.examples:
            ids.append(example.example_id)
        return ids

    def __len__(self) -> int:
        """Returns the length of the examples."""
        return len(self.input_features)

    def __getitem__(self,
                    index: int) -> Dict[str, torch.Tensor]:
        """Obtains the features of a given example index and converts them into a dictionary."""
        features = self.input_features[index]
        data_dict = dict(
            input_ids=torch.tensor(features.input_ids, dtype=torch.long),
            attention_mask=torch.tensor(features.attention_mask, dtype=torch.float32)
        )
        if features.token_type_ids is not None and self.config.return_token_type_ids:
            data_dict["token_type_ids"] = torch.tensor(features.token_type_ids, dtype=torch.long)
        if features.trigger_left is not None:
            data_dict["trigger_left"] = torch.tensor(features.trigger_left, dtype=torch.float32)
        if features.trigger_right is not None:
            data_dict["trigger_right"] = torch.tensor(features.trigger_right, dtype=torch.float32)
        if features.labels is not None:
            data_dict["labels"] = torch.tensor(features.labels, dtype=torch.long)
        return data_dict

    def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
        """Collates the samples in batches."""
        output_batch = dict()
        for key in batch[0].keys():
            output_batch[key] = torch.stack([x[key] for x in batch], dim=0)
        if self.config.truncate_in_batch:
            input_length = int(output_batch["attention_mask"].sum(-1).max())
            for key in ["input_ids", "attention_mask", "token_type_ids"]:
                if key not in output_batch:
                    continue
                output_batch[key] = output_batch[key][:, :input_length]
            if "labels" in output_batch and len(output_batch["labels"].shape) == 2:
                if self.config.truncate_seq2seq_output:
                    output_length = int((output_batch["labels"] != -100).sum(-1).max())
                    output_batch["labels"] = output_batch["labels"][:, :output_length]
                else:
                    output_batch["labels"] = output_batch["labels"][:, :input_length]
        return output_batch
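
To make the intended subclassing pattern concrete, here is a minimal hypothetical subclass that reads the unified .unified.jsonl format shown earlier. The config fields type2id and max_seq_length are assumptions for this sketch, not part of the base class:

import json

class ToyEDProcessor(EDDataProcessor):
    """A hypothetical minimal subclass for illustration."""

    def __init__(self, config, tokenizer, input_file):
        super().__init__(config, tokenizer)
        self.read_examples(input_file)
        self.convert_examples_to_features()

    def read_examples(self, input_file):
        """Builds one EDInputExample per annotated trigger in a .unified.jsonl file."""
        with open(input_file, encoding="utf-8") as f:
            for line in f:
                item = json.loads(line)
                for event in item["events"]:
                    for trigger in event["triggers"]:
                        self.examples.append(EDInputExample(
                            example_id=item["id"],
                            text=item["text"],
                            trigger_left=trigger["position"][0],
                            trigger_right=trigger["position"][1],
                            labels=event["type"],
                        ))

    def convert_examples_to_features(self):
        """Tokenizes each example and maps its event type to an integer label id."""
        for example in self.examples:
            outputs = self.tokenizer(example.text,
                                     padding="max_length",
                                     truncation=True,
                                     max_length=self.config.max_seq_length)
            self.input_features.append(EDInputFeatures(
                example_id=example.example_id,
                input_ids=outputs["input_ids"],
                attention_mask=outputs["attention_mask"],
                labels=self.config.type2id[example.labels],  # assumed event-type-to-id mapping
            ))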

EAEDataProcessor

The base class of data processor for event argument extraction, which is inherited to construct task-specific data processors.

Attributes:

  • config: The pre-defined configurations of the execution.

  • tokenizer: The tokenizer used for the tokenization process.

  • is_training: A boolean variable indicating whether the processor is in training mode.

  • examples: A list of ``EAEInputExample``s constructed based on the input dataset.

  • input_features: A list of ``EAEInputFeatures``s corresponding to the ``EAEInputExample``s.

  • data_for_evaluation: A dictionary representing the evaluation data.

  • event_preds: A list of event prediction data if the file exists.

class EAEDataProcessor(Dataset):
    """Base class of data processor for event argument extraction.

    The base class of data processor for event argument extraction, which would be inherited to construct task-specific
    data processors.

    Attributes:
        config:
            The pre-defined configurations of the execution.
        tokenizer:
            The tokenizer used for the tokenization process.
        is_training (`bool`):
            A boolean variable indicating whether the processor is in training mode.
        examples (`List[EAEInputExample]`):
            A list of `EAEInputExample`s constructed based on the input dataset.
        input_features (`List[EAEInputFeatures]`):
            A list of `EAEInputFeatures`s corresponding to the `EAEInputExample`s.
        data_for_evaluation (`dict`):
            A dictionary representing the evaluation data.
        event_preds (`list`):
            A list of event prediction data if the file exists.
    """

    def __init__(self,
                 config,
                 tokenizer,
                 pred_file: str,
                 is_training: bool) -> None:
        """Constructs a EAEDataProcessor."""
        self.config = config
        self.tokenizer = tokenizer
        self.is_training = is_training
        if hasattr(config, "role2id"):
            self.config.role2id["X"] = -100
        self.examples = []
        self.input_features = []
        # data for trainer evaluation
        self.data_for_evaluation = {}
        # event prediction file path
        if pred_file is not None:
            if not os.path.exists(pred_file):
                logger.warning("%s doesn't exist.We use golden triggers" % pred_file)
                self.event_preds = None
            else:
                self.event_preds = json.load(open(pred_file))
        else:
            logger.warning("Event predictions is none! We use golden triggers.")
            self.event_preds = None

    def read_examples(self,
                      input_file: str):
        """Obtains a collection of `EAEInputExample`s for the dataset."""
        raise NotImplementedError

    def convert_examples_to_features(self):
        """Converts the `EAEInputExample`s into `EAEInputFeatures`s."""
        raise NotImplementedError

    def get_data_for_evaluation(self) -> Dict[str, Union[int, str]]:
        """Obtains the data for evaluation."""
        self.data_for_evaluation["pred_types"] = self.get_pred_types()
        self.data_for_evaluation["true_types"] = self.get_true_types()
        self.data_for_evaluation["ids"] = self.get_ids()
        if self.examples[0].argument_role is not None:
            self.data_for_evaluation["roles"] = self.get_roles()
        return self.data_for_evaluation

    def get_pred_types(self) -> List[str]:
        """Obtains the event type predicted by the model."""
        pred_types = []
        for example in self.examples:
            pred_types.append(example.pred_type)
        return pred_types

    def get_true_types(self) -> List[str]:
        """Obtains the actual event type from the annotation."""
        true_types = []
        for example in self.examples:
            true_types.append(example.true_type)
        return true_types

    def get_roles(self) -> List[str]:
        """Obtains the role of each argument mention."""
        roles = []
        for example in self.examples:
            roles.append(example.argument_role)
        return roles
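
    def get_single_pred(self,
                        trigger_idx: int,
                        input_file: str,
                        true_type: str) -> str:
        """Returns the predicted event type for a trigger, falling back to the gold type.

        Note: this helper is called by the subclasses below but is missing from this
        listing; the body is a sketch that mirrors the golden-trigger fallback inlined
        in `EAETCProcessor.read_examples()`.
        """
        # `input_file` is kept only for signature compatibility with the call sites.
        if self.is_training or self.config.golden_trigger or self.event_preds is None:
            return true_type
        return self.event_preds[trigger_idx]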

    def _truncate(self,
                  outputs: Dict[str, List[int]],
                  max_seq_length: int):
        """Truncates the sequence that exceeds the maximum length."""
        is_truncation = False
        if len(outputs["input_ids"]) > max_seq_length:
            print("An instance exceeds the maximum length.")
            is_truncation = True
            for key in ["input_ids", "attention_mask", "token_type_ids", "offset_mapping"]:
                if key not in outputs:
                    continue
                outputs[key] = outputs[key][:max_seq_length]
        return outputs, is_truncation

    def get_ids(self) -> List[Union[int, str]]:
        """Returns the id of the examples."""
        ids = []
        for example in self.examples:
            ids.append(example.example_id)
        return ids

    def __len__(self) -> int:
        """Returns the length of the examples."""
        return len(self.input_features)

    def __getitem__(self,
                    index: int) -> Dict[str, torch.Tensor]:
        """Returns the features of a given example index in a dictionary."""
        features = self.input_features[index]
        data_dict = dict(
            input_ids=torch.tensor(features.input_ids, dtype=torch.long),
            attention_mask=torch.tensor(features.attention_mask, dtype=torch.float32)
        )
        if features.token_type_ids is not None and self.config.return_token_type_ids:
            data_dict["token_type_ids"] = torch.tensor(features.token_type_ids, dtype=torch.long)
        if features.trigger_left is not None:
            data_dict["trigger_left"] = torch.tensor(features.trigger_left, dtype=torch.long)
        if features.trigger_right is not None:
            data_dict["trigger_right"] = torch.tensor(features.trigger_right, dtype=torch.long)
        if features.argument_left is not None:
            data_dict["argument_left"] = torch.tensor(features.argument_left, dtype=torch.long)
        if features.argument_right is not None:
            data_dict["argument_right"] = torch.tensor(features.argument_right, dtype=torch.long)
        if features.labels is not None:
            data_dict["labels"] = torch.tensor(features.labels, dtype=torch.long)
        return data_dict

    def collate_fn(self, batch) -> Dict[str, torch.Tensor]:
        """Collates the samples in batches."""
        output_batch = dict()
        for key in batch[0].keys():
            output_batch[key] = torch.stack([x[key] for x in batch], dim=0)
        if self.config.truncate_in_batch:
            input_length = int(output_batch["attention_mask"].sum(-1).max())
            for key in ["input_ids", "attention_mask", "token_type_ids"]:
                if key not in output_batch:
                    continue
                output_batch[key] = output_batch[key][:, :input_length]
            if "labels" in output_batch and len(output_batch["labels"].shape) == 2:
                if self.config.truncate_seq2seq_output:
                    output_length = int((output_batch["labels"] != -100).sum(-1).max())
                    output_batch["labels"] = output_batch["labels"][:, :output_length]
                else:
                    output_batch["labels"] = output_batch["labels"][:, :input_length]
        return output_batch
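
Since `read_examples()` and `convert_examples_to_features()` raise `NotImplementedError` in both base classes, a task-specific processor only needs to implement those two methods; `__getitem__`, `collate_fn`, and the evaluation helpers are inherited. The sketch below outlines a custom EAE processor; the JSON-lines schema (field names like `trigger_span`) is a hypothetical example, not a format OmniEvent prescribes:

import json

class MyEAEProcessor(EAEDataProcessor):
    """A minimal sketch; the input schema below is hypothetical."""

    def __init__(self, config, tokenizer, input_file, pred_file, is_training=False):
        super().__init__(config, tokenizer, pred_file, is_training)
        self.read_examples(input_file)
        self.convert_examples_to_features()

    def read_examples(self, input_file):
        self.examples = []
        with open(input_file, "r") as f:
            for line in f:
                item = json.loads(line)  # hypothetical fields below
                self.examples.append(EAEInputExample(
                    example_id=item["id"],
                    text=item["text"],
                    pred_type=item["event_type"],
                    true_type=item["event_type"],
                    trigger_left=item["trigger_span"][0],
                    trigger_right=item["trigger_span"][1],
                    argument_left=item["argument_span"][0],
                    argument_right=item["argument_span"][1],
                    labels=item["role"],
                ))

    def convert_examples_to_features(self):
        self.input_features = []
        for example in self.examples:
            outputs = self.tokenizer(example.text,
                                     padding="max_length",
                                     truncation=True,
                                     max_length=self.config.max_seq_length)
            self.input_features.append(EAEInputFeatures(
                example_id=example.example_id,
                input_ids=outputs["input_ids"],
                attention_mask=outputs["attention_mask"],
                token_type_ids=outputs.get("token_type_ids"),
                labels=self.config.role2id[example.labels],
            ))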

Token Classification Processor

import json
import logging

from tqdm import tqdm
from typing import List, Optional, Dict

from .base_processor import (
    EDDataProcessor,
    EDInputExample,
    EDInputFeatures,
    EAEDataProcessor,
    EAEInputExample,
    EAEInputFeatures
)

logger = logging.getLogger(__name__)

EDTCProcessor

Data processor for token classification for event detection. The class is inherited from the `EDDataProcessor` class, in which the undefined functions, including read_examples() and convert_examples_to_features(), are implemented; the rest of the attributes and functions are multiplexed from the EDDataProcessor class.

class EDTCProcessor(EDDataProcessor):
    """Data processor for token classification for event detection.

    Data processor for token classification for event detection. The class is inherited from the `EDDataProcessor` class,
    in which the undefined functions, including `read_examples()` and `convert_examples_to_features()` are implemented;
    the rest of the attributes and functions are multiplexed from the `EDDataProcessor` class.
    """

    def __init__(self,
                config,
                tokenizer: str,
                input_file: str) -> None:
        """Constructs an EDTCProcessor."""
        super().__init__(config, tokenizer)
        self.read_examples(input_file)
        self.convert_examples_to_features()

    def read_examples(self,
                    input_file: str) -> None:
        """Obtains a collection of `EDInputExample`s for the dataset."""
        self.examples = []
        with open(input_file, "r") as f:
            for line in tqdm(f.readlines(), desc="Reading from %s" % input_file):
                item = json.loads(line.strip())
                # training and valid set
                if "events" in item:
                    for event in item["events"]:
                        for trigger in event["triggers"]:
                            example = EDInputExample(
                                example_id=trigger["id"],
                                text=item["text"],
                                trigger_left=trigger["position"][0],
                                trigger_right=trigger["position"][1],
                                labels=event["type"]
                            )
                            self.examples.append(example)
                if "negative_triggers" in item:
                    for neg in item["negative_triggers"]:
                        example = EDInputExample(
                            example_id=neg["id"],
                            text=item["text"],
                            trigger_left=neg["position"][0],
                            trigger_right=neg["position"][1],
                            labels="NA"
                        )
                        self.examples.append(example)
                # test set
                if "candidates" in item:
                    for candidate in item["candidates"]:
                        example = EDInputExample(
                            example_id=candidate["id"],
                            text=item["text"],
                            trigger_left=candidate["position"][0],
                            trigger_right=candidate["position"][1],
                            labels="NA",
                        )
                        # # if test set has labels
                        # assert not (self.config.test_exists_labels ^ ("type" in candidate))
                        # if "type" in candidate:
                        #     example.labels = candidate["type"]
                        self.examples.append(example)

    def convert_examples_to_features(self) -> None:
        """Converts the `EDInputExample`s into `EDInputFeatures`s."""
        # merge and then tokenize
        self.input_features = []
        for example in tqdm(self.examples, desc="Processing features for TC"):
            text_left = example.text[:example.trigger_left]
            text_mid = example.text[example.trigger_left:example.trigger_right]
            text_right = example.text[example.trigger_right:]

            if self.config.language == "Chinese":
                text = text_left + self.config.markers[0] + text_mid + self.config.markers[1] + text_right
            else:
                text = text_left + self.config.markers[0] + " " + text_mid + " " + self.config.markers[1] + text_right

            outputs = self.tokenizer(text, padding="max_length", truncation=True, max_length=self.config.max_seq_length)
            is_overflow = False
            try:
                left = outputs["input_ids"].index(self.tokenizer.convert_tokens_to_ids(self.config.markers[0]))
                right = outputs["input_ids"].index(self.tokenizer.convert_tokens_to_ids(self.config.markers[1]))
            except ValueError:
                logger.warning("Markers are not in the input tokens.")
                left, right = 0, 0
                is_overflow = True

            # Roberta tokenizer doesn't return token_type_ids
            if "token_type_ids" not in outputs:
                outputs["token_type_ids"] = [0] * len(outputs["input_ids"])

            features = EDInputFeatures(
                example_id=example.example_id,
                input_ids=outputs["input_ids"],
                attention_mask=outputs["attention_mask"],
                token_type_ids=outputs["token_type_ids"],
                trigger_left=left,
                trigger_right=right
            )
            if example.labels is not None:
                features.labels = self.config.type2id[example.labels]
            self.input_features.append(features)
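
With the processor in place, a training pipeline only needs a tokenizer and a `DataLoader`. The snippet below is a hedged usage sketch: the config object, marker strings, label mapping, and file path are all placeholders, chosen to mirror the fields that `EDTCProcessor` and `collate_fn` actually read:

from types import SimpleNamespace
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

# Hypothetical configuration; the field names mirror what the processor reads.
config = SimpleNamespace(
    language="English",
    markers=["<event>", "</event>"],
    max_seq_length=128,
    type2id={"NA": 0, "Attack": 1},
    return_token_type_ids=True,
    truncate_in_batch=True,
    truncate_seq2seq_output=False,
)

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
tokenizer.add_special_tokens({"additional_special_tokens": config.markers})

processor = EDTCProcessor(config, tokenizer, "data/train.unified.jsonl")  # placeholder path
loader = DataLoader(processor, batch_size=16, collate_fn=processor.collate_fn)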

EAETCProcessor

Data processor for token classification for event argument extraction. The class is inherited from the EAEDataProcessor class, in which the undefined functions, including read_examples() and convert_examples_to_features() are implemented; a new function entitled insert_marker() is defined, and the rest of the attributes and functions are multiplexed from the EAEDataProcessor class.

class EAETCProcessor(EAEDataProcessor):
    """Data processor for token classification for event argument extraction.

    Data processor for token classification for event argument extraction. The class is inherited from the
    `EAEDataProcessor` class, in which the undefined functions, including `read_examples()` and
    `convert_examples_to_features()` are implemented; a new function entitled `insert_marker()` is defined, and
    the rest of the attributes and functions are multiplexed from the `EAEDataProcessor` class.
    """

    def __init__(self,
                config,
                tokenizer: str,
                input_file: str,
                pred_file: str,
                is_training: Optional[bool] = False):
        """Constructs a `EAETCProcessor`."""
        super().__init__(config, tokenizer, pred_file, is_training)
        self.read_examples(input_file)
        self.convert_examples_to_features()

    def read_examples(self,
                    input_file: str) -> None:
        """Obtains a collection of `EAEInputExample`s for the dataset."""
        self.examples = []
        trigger_idx = 0
        with open(input_file, "r") as f:
            all_lines = f.readlines()
            for line in tqdm(all_lines, desc="Reading from %s" % input_file):
                item = json.loads(line.strip())
                if "events" in item:
                    for event in item["events"]:
                        for trigger in event["triggers"]:
                            true_type = event["type"]
                            if self.is_training or self.config.golden_trigger or self.event_preds is None:
                                pred_type = true_type
                            else:
                                pred_type = self.event_preds[trigger_idx]

                            trigger_idx += 1

                            if self.config.eae_eval_mode in ['default', 'loose']:
                                if pred_type == "NA":
                                    continue

                            args_for_trigger = set()
                            positive_offsets = []
                            for argument in trigger["arguments"]:
                                for mention in argument["mentions"]:
                                    example = EAEInputExample(
                                        example_id=trigger["id"],
                                        text=item["text"],
                                        pred_type=pred_type,
                                        true_type=event["type"],
                                        trigger_left=trigger["position"][0],
                                        trigger_right=trigger["position"][1],
                                        argument_left=mention["position"][0],
                                        argument_right=mention["position"][1],
                                        labels=argument["role"]
                                    )
                                    args_for_trigger.add(mention['mention_id'])
                                    positive_offsets.append(mention["position"])
                                    self.examples.append(example)
                            if "entities" in item:
                                for entity in item["entities"]:
                                    # check whether the entity is an argument
                                    is_argument = False
                                    for mention in entity["mentions"]:
                                        if mention["mention_id"] in args_for_trigger:
                                            is_argument = True
                                            break
                                    if is_argument:
                                        continue
                                    # negative arguments
                                    for mention in entity["mentions"]:
                                        example = EAEInputExample(
                                            example_id=trigger["id"],
                                            text=item["text"],
                                            pred_type=pred_type,
                                            true_type=event["type"],
                                            trigger_left=trigger["position"][0],
                                            trigger_right=trigger["position"][1],
                                            argument_left=mention["position"][0],
                                            argument_right=mention["position"][1],
                                            labels="NA"
                                        )
                                        if "train" in input_file or self.config.golden_trigger:
                                            example.pred_type = event["type"]
                                        self.examples.append(example)
                            else:
                                for neg in item["negative_triggers"]:
                                    is_argument = False
                                    neg_set = set(range(neg["position"][0], neg["position"][1]))
                                    for pos_offset in positive_offsets:
                                        pos_set = set(range(pos_offset[0], pos_offset[1]))
                                        if not pos_set.isdisjoint(neg_set):
                                            is_argument = True
                                            break
                                    if is_argument:
                                        continue
                                    example = EAEInputExample(
                                        example_id=trigger["id"],
                                        text=item["text"],
                                        pred_type=pred_type,
                                        true_type=event["type"],
                                        trigger_left=trigger["position"][0],
                                        trigger_right=trigger["position"][1],
                                        argument_left=neg["position"][0],
                                        argument_right=neg["position"][1],
                                        labels="NA"
                                    )
                                    if "train" in input_file or self.config.golden_trigger:
                                        example.pred_type = event["type"]
                                    self.examples.append(example)

                    # negative triggers
                    for trigger in item["negative_triggers"]:
                        if self.config.eae_eval_mode in ['default', 'strict']:
                            if self.is_training or self.config.golden_trigger or self.event_preds is None:
                                pred_type = "NA"
                            else:
                                pred_type = self.event_preds[trigger_idx]

                            if pred_type != "NA":
                                if "entities" in item:
                                    for entity in item["entities"]:
                                        for mention in entity["mentions"]:
                                            example = EAEInputExample(
                                                example_id=trigger_idx,
                                                text=item["text"],
                                                pred_type=pred_type,
                                                true_type="NA",
                                                trigger_left=trigger["position"][0],
                                                trigger_right=trigger["position"][1],
                                                argument_left=mention["position"][0],
                                                argument_right=mention["position"][1],
                                                labels="NA"
                                            )
                                            self.examples.append(example)
                                else:
                                    for neg in item["negative_triggers"]:
                                        example = EAEInputExample(
                                            example_id=trigger_idx,
                                            text=item["text"],
                                            pred_type=pred_type,
                                            true_type="NA",  # negative triggers carry no gold event type
                                            trigger_left=trigger["position"][0],
                                            trigger_right=trigger["position"][1],
                                            argument_left=neg["position"][0],
                                            argument_right=neg["position"][1],
                                            labels="NA"
                                        )
                                        if "train" in input_file or self.config.golden_trigger:
                                            example.pred_type = event["type"]
                                        self.examples.append(example)
                        trigger_idx += 1
                else:
                    for candi in item["candidates"]:
                        pred_type = self.event_preds[trigger_idx]   # we can only use pred type here, gold not available
                        if pred_type != "NA":
                            if "entities" in item:
                                for entity in item["entities"]:
                                    for mention in entity["mentions"]:
                                        example = EAEInputExample(
                                            example_id=trigger_idx,
                                            text=item["text"],
                                            pred_type=pred_type,
                                            true_type="NA",
                                            trigger_left=candi["position"][0],
                                            trigger_right=candi["position"][1],
                                            argument_left=mention["position"][0],
                                            argument_right=mention["position"][1],
                                            labels="NA"
                                        )
                                        self.examples.append(example)
                            else:
                                for neg in item["negative_triggers"]:
                                    example = EAEInputExample(
                                        example_id=trigger_idx,
                                        text=item["text"],
                                        pred_type=pred_type,
                                        true_type="NA",  # gold event types are unavailable in the test set
                                        trigger_left=candi["position"][0],
                                        trigger_right=candi["position"][1],
                                        argument_left=neg["position"][0],
                                        argument_right=neg["position"][1],
                                        labels="NA"
                                    )
                                    self.examples.append(example)

                        trigger_idx += 1
            if self.event_preds is not None:
                assert trigger_idx == len(self.event_preds)

    def insert_marker(self,
                    text: str,
                    type: str,
                    trigger_position: List[int],
                    argument_position: List[int],
                    markers: Dict[str, str],
                    whitespace: Optional[bool] = True) -> str:
        """Adds a marker at the start and end position of event triggers and argument mentions."""
        markered_text = ""
        for i, char in enumerate(text):
            if i == trigger_position[0]:
                markered_text += markers[type][0]
                markered_text += " " if whitespace else ""
            if i == argument_position[0]:
                markered_text += markers["argument"][0]
                markered_text += " " if whitespace else ""
            markered_text += char
            if i == trigger_position[1] - 1:
                markered_text += " " if whitespace else ""
                markered_text += markers[type][1]
            if i == argument_position[1] - 1:
                markered_text += " " if whitespace else ""
                markered_text += markers["argument"][1]
        return markered_text

    def convert_examples_to_features(self) -> None:
        """Converts the `EAEInputExample`s into `EAEInputFeatures`s."""
        # merge and then tokenize
        self.input_features = []
        whitespace = self.config.language == "English"
        for example in tqdm(self.examples, desc="Processing features for TC"):
            text = self.insert_marker(example.text,
                                    example.pred_type,
                                    [example.trigger_left, example.trigger_right],
                                    [example.argument_left, example.argument_right],
                                    self.config.markers,
                                    whitespace)
            outputs = self.tokenizer(text,
                                    padding="max_length",
                                    truncation=True,
                                    max_length=self.config.max_seq_length)
            is_overflow = False
            # argument position
            try:
                argument_left = outputs["input_ids"].index(
                    self.tokenizer.convert_tokens_to_ids(self.config.markers["argument"][0]))
                argument_right = outputs["input_ids"].index(
                    self.tokenizer.convert_tokens_to_ids(self.config.markers["argument"][1]))
            except ValueError:
                argument_left, argument_right = 0, 0
                logger.warning("Argument markers are not in the input tokens.")
                is_overflow = True
            # trigger position
            try:
                trigger_left = outputs["input_ids"].index(
                    self.tokenizer.convert_tokens_to_ids(self.config.markers[example.pred_type][0]))
                trigger_right = outputs["input_ids"].index(
                    self.tokenizer.convert_tokens_to_ids(self.config.markers[example.pred_type][1]))
            except ValueError:
                trigger_left, trigger_right = 0, 0
                logger.warning("Trigger markers are not in the input tokens.")
            # Roberta tokenizer doesn't return token_type_ids
            if "token_type_ids" not in outputs:
                outputs["token_type_ids"] = [0] * len(outputs["input_ids"])

            features = EAEInputFeatures(
                example_id=example.example_id,
                input_ids=outputs["input_ids"],
                attention_mask=outputs["attention_mask"],
                token_type_ids=outputs["token_type_ids"],
                trigger_left=trigger_left,
                trigger_right=trigger_right,
                argument_left=argument_left,
                argument_right=argument_right
            )
            if example.labels is not None:
                features.labels = self.config.role2id[example.labels]
                if is_overflow:
                    features.labels = -100
            self.input_features.append(features)
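
To make the marker insertion concrete, here is a worked example of `insert_marker()` with hypothetical marker strings; since the method never touches `self`, it can be called unbound for demonstration:

text = "He was born in Leeds"  # trigger "born" at chars [7, 11], argument "Leeds" at [15, 20]
markers = {"Life:Be-Born": ["<t>", "</t>"], "argument": ["<a>", "</a>"]}

marked = EAETCProcessor.insert_marker(None, text, "Life:Be-Born", [7, 11], [15, 20], markers)
print(marked)  # "He was <t> born </t> in <a> Leeds </a>"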

Sequence Labeling Processor

import json
import logging
from typing import List, Union, Any, Optional

from tqdm import tqdm
from .base_processor import (
    EDDataProcessor,
    EDInputExample,
    EDInputFeatures,
    EAEDataProcessor,
    EAEInputExample,
    EAEInputFeatures
)
from .input_utils import get_left_and_right_pos, get_word_ids, get_words  # helpers used below (assumed to live in input_utils)

logger = logging.getLogger(__name__)

EDSLProcessor

Data processor for sequence labeling for event detection. The class is inherited from the EDDataProcessor class, in which the undefined functions, including read_examples() and convert_examples_to_features() are implemented; a new function entitled get_final_labels() is defined to obtain final results, and the rest of the attributes and functions are multiplexed from the EDDataProcessor class.

class EDSLProcessor(EDDataProcessor):
    """Data processor for sequence labeling for event detection.
    Data processor for sequence labeling for event detection. The class is inherited from the `EDDataProcessor` class,
    in which the undefined functions, including `read_examples()` and `convert_examples_to_features()` are implemented;
    a new function entitled `get_final_labels()` is defined to obtain final results, and the rest of the attributes and
    functions are multiplexed from the `EDDataProcessor` class.
    Attributes:
        is_overflow (`List[bool]`):
            A list of booleans indicating whether each input sequence exceeds the maximum sequence length.
    """

    def __init__(self,
                 config,
                 tokenizer: str,
                 input_file: str) -> None:
        """Constructs a EDSLProcessor."""
        super().__init__(config, tokenizer)
        self.read_examples(input_file)
        self.is_overflow = []
        self.convert_examples_to_features()

    def read_examples(self,
                      input_file: str) -> None:
        """Obtains a collection of `EDInputExample`s for the dataset."""
        self.examples = []
        language = self.config.language

        with open(input_file, "r", encoding="utf-8") as f:
            for line in tqdm(f.readlines(), desc="Reading from %s" % input_file):
                item = json.loads(line.strip())
                text = item["text"]
                words = get_words(text=text, language=language)
                labels = ["O"] * len(words)

                if "events" in item:
                    for event in item["events"]:
                        for trigger in event["triggers"]:
                            left_pos, right_pos = get_left_and_right_pos(text, trigger, language, True)
                            labels[left_pos] = f"B-{event['type']}"
                            for i in range(left_pos + 1, right_pos):
                                labels[i] = f"I-{event['type']}"

                example = EDInputExample(example_id=item["id"], text=words, labels=labels)
                self.examples.append(example)

    def get_final_labels(self,
                         example: EDInputExample,
                         word_ids_of_each_token: List[int],
                         label_all_tokens: Optional[bool] = False) -> List[Union[str, int]]:
        """Obtains the final label of each token."""
        final_labels = []
        pre_word_id = None
        for word_id in word_ids_of_each_token:
            if word_id is None:
                final_labels.append(-100)
            elif word_id != pre_word_id:  # first split token of a word
                final_labels.append(self.config.type2id[example.labels[word_id]])
            else:
                final_labels.append(self.config.type2id[example.labels[word_id]] if label_all_tokens else -100)
            pre_word_id = word_id

        return final_labels

    def convert_examples_to_features(self) -> None:
        """Converts the `EDInputExample`s into `EDInputFeatures`s."""
        self.input_features = []

        for example in tqdm(self.examples, desc="Processing features for SL"):
            outputs = self.tokenizer(example.text,
                                     padding="max_length",
                                     truncation=False,
                                     max_length=self.config.max_seq_length,
                                     is_split_into_words=True)
            # Roberta tokenizer doesn't return token_type_ids
            if "token_type_ids" not in outputs:
                outputs["token_type_ids"] = [0] * len(outputs["input_ids"])

            outputs, is_overflow = self._truncate(outputs, self.config.max_seq_length)
            self.is_overflow.append(is_overflow)

            word_ids_of_each_token = get_word_ids(self.tokenizer, outputs, example.text)[: self.config.max_seq_length]
            final_labels = self.get_final_labels(example, word_ids_of_each_token, label_all_tokens=False)

            features = EDInputFeatures(
                example_id=example.example_id,
                input_ids=outputs["input_ids"],
                attention_mask=outputs["attention_mask"],
                token_type_ids=outputs["token_type_ids"],
                labels=final_labels,
            )
            self.input_features.append(features)
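
The word-to-subword alignment in `get_final_labels()` is the standard recipe for sequence labeling with subword tokenizers: special tokens get -100, the first subword of each word carries the word's label id, and the remaining subwords get -100 unless `label_all_tokens` is set. A self-contained toy trace of the same loop (the `type2id` mapping and the `word_ids` sequence are illustrative):

words = ["Troops", "bombed", "Baghdad"]
labels = ["O", "B-Attack", "O"]
type2id = {"O": 0, "B-Attack": 1, "I-Attack": 2}

# Word ids a fast tokenizer might report for [CLS] Troops bomb ##ed Baghdad [SEP].
word_ids_of_each_token = [None, 0, 1, 1, 2, None]

final_labels, pre_word_id = [], None
for word_id in word_ids_of_each_token:
    if word_id is None:
        final_labels.append(-100)                      # special token
    elif word_id != pre_word_id:
        final_labels.append(type2id[labels[word_id]])  # first subword of a word
    else:
        final_labels.append(-100)                      # later subword, label_all_tokens=False
    pre_word_id = word_id

print(final_labels)  # [-100, 0, 1, -100, 0, -100]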

EAESLProcessor

Data processor for sequence labeling for event argument extraction. The class is inherited from the EAEDataProcessor class, in which the undefined functions, including read_examples() and convert_examples_to_features() are implemented; two new functions, entitled get_final_labels() and insert_marker(), are defined, and the rest of the attributes and functions are multiplexed from the EAEDataProcessor class.

Attributes:

  • positive_candidate_indices: A list of integers indicating the indices of positive trigger candidates.

class EAESLProcessor(EAEDataProcessor):
    """Data processor for sequence labeling for event argument extraction.
    Data processor for sequence labeling for event argument extraction. The class is inherited from the
    `EAEDataProcessor` class, in which the undefined functions, including `read_examples()` and
    `convert_examples_to_features()` are implemented; two new functions, entitled `get_final_labels()` and
    `insert_marker()`, are defined, and the rest of the attributes and functions are multiplexed from the
    `EAEDataProcessor` class.
    Attributes:
        positive_candidate_indices (`List[int]`):
            A list of integers indicating the indices of positive trigger candidates.
        is_overflow (`List[bool]`):
            A list of booleans indicating whether each input sequence exceeds the maximum sequence length.
    """

    def __init__(self,
                 config,
                 tokenizer: str,
                 input_file: str,
                 pred_file: str,
                 is_training: Optional[bool] = False) -> None:
        """Constructs an EAESLProcessor/"""
        super().__init__(config, tokenizer, pred_file, is_training)
        self.positive_candidate_indices = []
        self.is_overflow = []
        self.config.role2id["X"] = -100
        self.read_examples(input_file)
        self.convert_examples_to_features()

    def read_examples(self,
                      input_file: str) -> None:
        """Obtains a collection of `EAEInputExample`s for the dataset."""
        self.examples = []
        language = self.config.language
        trigger_idx = 0
        with open(input_file, "r", encoding="utf-8") as f:
            for line in tqdm(f.readlines(), desc="Reading from %s" % input_file):
                item = json.loads(line.strip())
                text = item["text"]
                words = get_words(text=text, language=language)

                if "events" in item:
                    for event in item["events"]:
                        for trigger in event["triggers"]:
                            pred_type = self.get_single_pred(trigger_idx, input_file, true_type=event["type"])
                            trigger_idx += 1

                            # Evaluation mode for EAE
                            # If the predicted event type is NA, we don't consider the trigger.
                            if self.config.eae_eval_mode in ["default", "loose"] and pred_type == "NA":
                                continue
                            trigger_left, trigger_right = get_left_and_right_pos(text, trigger, language, True)
                            labels = ["O"] * len(words)

                            for argument in trigger["arguments"]:
                                for mention in argument["mentions"]:
                                    left_pos, right_pos = get_left_and_right_pos(text, mention, language, True)
                                    labels[left_pos] = f"B-{argument['role']}"
                                    for i in range(left_pos + 1, right_pos):
                                        labels[i] = f"I-{argument['role']}"

                            example = EAEInputExample(
                                example_id=item["id"],
                                text=words,
                                pred_type=pred_type,
                                true_type=event["type"],
                                trigger_left=trigger_left,
                                trigger_right=trigger_right,
                                labels=labels,
                            )
                            self.examples.append(example)

                    # negative triggers
                    for neg in item["negative_triggers"]:
                        pred_type = self.get_single_pred(trigger_idx, input_file, true_type="NA")
                        trigger_idx += 1
                        if self.config.eae_eval_mode == "loose":
                            continue
                        elif self.config.eae_eval_mode in ["default", "strict"]:
                            if pred_type != "NA":
                                neg_left, neg_right = get_left_and_right_pos(text, neg, language, True)
                                example = EAEInputExample(
                                    example_id=item["id"],
                                    text=words,
                                    pred_type=pred_type,
                                    true_type="NA",
                                    trigger_left=neg_left,
                                    trigger_right=neg_right,
                                    labels=["O"] * len(words),
                                )
                                self.examples.append(example)
                        else:
                            raise ValueError("Invalid eac_eval_mode: %s" % self.config.eae_eval_mode)
                else:
                    for can in item["candidates"]:
                        can_left, can_right = get_left_and_right_pos(text, can, language, True)
                        labels = ["O"] * len(words)
                        pred_type = self.event_preds[trigger_idx]
                        trigger_idx += 1
                        if pred_type != "NA":
                            example = EAEInputExample(
                                example_id=item["id"],
                                text=words,
                                pred_type=pred_type,
                                true_type="NA",  # true type not given, set to NA.
                                trigger_left=can_left,
                                trigger_right=can_right,
                                labels=labels,
                            )
                            self.examples.append(example)
                            self.positive_candidate_indices.append(trigger_idx-1)
            if self.event_preds is not None:
                assert trigger_idx == len(self.event_preds)

    def get_final_labels(self,
                         labels: List[str],
                         word_ids_of_each_token: List[Any],
                         label_all_tokens: bool = False) -> List[Union[str, int]]:
        """Obtains the final label of each token."""
        final_labels = []
        pre_word_id = None
        for word_id in word_ids_of_each_token:
            if word_id is None:
                final_labels.append(-100)
            elif word_id != pre_word_id:  # first split token of a word
                final_labels.append(self.config.role2id[labels[word_id]])
            else:
                final_labels.append(self.config.role2id[labels[word_id]] if label_all_tokens else -100)
            pre_word_id = word_id

        return final_labels

    @staticmethod
    def insert_marker(text: list,
                      event_type: str,
                      labels,
                      trigger_pos: List[int],
                      markers):
        """Adds a marker at the start and end position of event triggers and argument mentions."""
        left, right = trigger_pos

        marked_text = text[:left] + [markers[event_type][0]] + text[left:right] + [markers[event_type][1]] + text[right:]
        marked_labels = labels[:left] + ["X"] + labels[left:right] + ["X"] + labels[right:]

        assert len(marked_text) == len(marked_labels)
        return marked_text, marked_labels

    def convert_examples_to_features(self) -> None:
        """Converts the `EAEInputExample`s into `EAEInputFeatures`s."""
        self.input_features = []
        self.is_overflow = []

        for example in tqdm(self.examples, desc="Processing features for SL"):
            text, labels = self.insert_marker(example.text,
                                              example.pred_type,
                                              example.labels,
                                              [example.trigger_left, example.trigger_right],
                                              self.config.markers)
            outputs = self.tokenizer(text,
                                     padding="max_length",
                                     truncation=False,
                                     max_length=self.config.max_seq_length,
                                     is_split_into_words=True)
            # Roberta tokenizer doesn't return token_type_ids
            if "token_type_ids" not in outputs:
                outputs["token_type_ids"] = [0] * len(outputs["input_ids"])
            outputs, is_overflow = self._truncate(outputs, self.config.max_seq_length)
            self.is_overflow.append(is_overflow)

            word_ids_of_each_token = get_word_ids(self.tokenizer, outputs, text)[: self.config.max_seq_length]
            final_labels = self.get_final_labels(labels, word_ids_of_each_token, label_all_tokens=False)

            features = EAEInputFeatures(
                example_id=example.example_id,
                input_ids=outputs["input_ids"],
                attention_mask=outputs["attention_mask"],
                token_type_ids=outputs["token_type_ids"],
                labels=final_labels,
            )
            self.input_features.append(features)
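
Note that this `insert_marker()` works on word lists with token positions (unlike the character-offset variant in the token classification processor) and labels the marker tokens with "X", which `role2id` maps to -100 so the loss ignores them. A worked example with hypothetical marker strings:

words = ["He", "was", "born", "in", "Leeds"]
labels = ["O", "O", "O", "O", "B-Place"]
markers = {"Life:Be-Born": ["<t>", "</t>"]}  # hypothetical marker strings

marked_text, marked_labels = EAESLProcessor.insert_marker(
    words, "Life:Be-Born", labels, [2, 3], markers)
print(marked_text)    # ['He', 'was', '<t>', 'born', '</t>', 'in', 'Leeds']
print(marked_labels)  # ['O', 'O', 'X', 'O', 'X', 'O', 'B-Place']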

Sequence-to-Sequence Processor

import re
import json
import logging
from typing import List, Union, Tuple, Optional

from tqdm import tqdm
from collections import defaultdict
from .base_processor import (
    EDInputExample,
    EDDataProcessor,
    EDInputFeatures,
    EAEDataProcessor,
    EAEInputExample,
    EAEInputFeatures
)
from .input_utils import get_plain_label, get_words  # helpers used below (assumed to live in input_utils)

type_start = "<"
type_end = ">"
split_word = ":"

logger = logging.getLogger(__name__)

extract_argument

Extracts the arguments from the raw text.

Args:

  • raw_text: A string indicating the input raw text.

  • instance_id: The id of the input example.

  • event_type: A string indicating the type of the event.

  • template: A compiled regular expression used to split the generated text into role-value spans.

def extract_argument(raw_text: str,
                     instance_id: Union[int, str],
                     event_type: str,
                     template=re.compile(f"{type_start}|{type_end}")) -> List[Tuple]:
    """Extracts the arguments from the raw text.
    Args:
        raw_text (`str`):
            A string indicating the input raw text.
        instance_id (`Union[int, str]`):
            The id of the input example.
        event_type (`str`):
            A string indicating the type of the event.
        template (`optional`, defaults to `re.compile(f"{type_start}|{type_end}")`):
            A compiled regular expression used to split the generated text into role-value spans.
    """
    arguments = []
    for span in template.split(raw_text):
        if span.strip() == "":
            continue
        words = span.strip().split(split_word)
        if len(words) != 2:
            continue
        role = words[0].strip().replace(" ", "")
        value = words[1].strip().replace(" ", "")
        if role != "" and value != "":
            arguments.append((instance_id, event_type, role, value))
    arguments = list(set(arguments))
    return arguments
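
For example, given a generated sequence for an attack event, `extract_argument()` splits on the angle-bracket delimiters and returns (instance_id, event_type, role, value) tuples. Two properties are visible in the code: `replace(" ", "")` also removes spaces inside multi-word mentions, and the final `list(set(...))` de-duplicates at the cost of a nondeterministic order:

raw_text = "< victim: John >< place: New York >"
print(extract_argument(raw_text, instance_id=0, event_type="attack"))
# e.g. [(0, 'attack', 'victim', 'John'), (0, 'attack', 'place', 'NewYork')]  (order may vary)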

EDSeq2SeqProcessor

Data processor for Sequence-to-Sequence (Seq2Seq) for event detection. The class is inherited from the EDDataProcessor class, in which the undefined functions, including read_examples() and convert_examples_to_features() are implemented; the rest of the attributes and functions are multiplexed from the EDDataProcessor class.

class EDSeq2SeqProcessor(EDDataProcessor):
    """Data processor for Sequence-to-Sequence (Seq2Seq) for event detection.
    Data processor for Sequence-to-Sequence (Seq2Seq) for event detection. The class is inherited from the
    `EDDataProcessor` class, in which the undefined functions, including `read_examples()` and
    `convert_examples_to_features()` are implemented; the rest of the attributes and functions are multiplexed from the
    `EDDataProcessor` class.
    """

    def __init__(self,
                 config,
                 tokenizer,
                 input_file: str) -> None:
        """Constructs a `EDSeq2SeqProcessor`."""
        super().__init__(config, tokenizer)
        self.read_examples(input_file)
        self.convert_examples_to_features()

    def read_examples(self,
                      input_file: str) -> None:
        """Obtains a collection of `EDInputExample`s for the dataset."""
        self.examples = []
        with open(input_file, "r", encoding="utf-8") as f:
            for idx, line in enumerate(tqdm(f.readlines(), desc="Reading from %s" % input_file)):
                item = json.loads(line.strip())
                if "source" in item:
                    kwargs = {"source": [item["source"]]}
                    if item["source"] in ["<duee>", "<fewfc>", "<leven>"]:
                        self.config.language = "Chinese"
                    else:
                        self.config.language = "English"
                else:
                    kwargs = {"source": []}

                words = get_words(text=item["text"], language=self.config.language)
                # training and valid set
                if "events" in item:
                    labels = []
                    for event in item["events"]:
                        event_type = get_plain_label(event["type"])
                        for trigger in event["triggers"]:
                            labels.append(f"{type_start} {event_type}{split_word} {trigger['trigger_word']} {type_end}")
                    labels = "".join(labels)

                    example = EDInputExample(
                        example_id=idx,
                        text=words,
                        labels=labels,
                        **kwargs,
                    )
                    self.examples.append(example)
                else:
                    example = EDInputExample(example_id=idx, text=words, labels="", **kwargs)
                    self.examples.append(example)

    def convert_examples_to_features(self) -> None:
        """Converts the `EDInputExample`s into `EDInputFeatures`s."""
        self.input_features = []
        for example in tqdm(self.examples, desc="Processing features for Seq2Seq"):
            # context
            input_context = self.tokenizer(example.kwargs["source"]+example.text,
                                           truncation=True,
                                           padding="max_length",
                                           max_length=self.config.max_seq_length,
                                           is_split_into_words=True)
            # output labels
            label_outputs = self.tokenizer(example.labels.split(),
                                           truncation=True,
                                           padding="max_length",
                                           max_length=self.config.max_out_length,
                                           is_split_into_words=True)
            # set -100 to unused token
            for i, flag in enumerate(label_outputs["attention_mask"]):
                if flag == 0:
                    label_outputs["input_ids"][i] = -100
            features = EDInputFeatures(
                example_id=example.example_id,
                input_ids=input_context["input_ids"],
                attention_mask=input_context["attention_mask"],
                labels=label_outputs["input_ids"],
            )
            self.input_features.append(features)
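
The decoder target built in `read_examples()` is simply a concatenation of `< type: trigger_word >` chunks. A quick illustration using the module-level delimiters (the event names are illustrative):

chunks = []
for event_type, trigger_word in [("attack", "exploded"), ("die", "died")]:
    chunks.append(f"{type_start} {event_type}{split_word} {trigger_word} {type_end}")
print("".join(chunks))  # < attack: exploded >< die: died >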

EAESeq2SeqProcessor

Data processor for sequence-to-sequence (Seq2Seq) for event argument extraction. The class is inherited from the EAEDataProcessor class, in which the undefined functions, including read_examples() and convert_examples_to_features() are implemented; a new function entitled insert_marker() is defined, and the rest of the attributes and functions are multiplexed from the EAEDataProcessor class.

class EAESeq2SeqProcessor(EAEDataProcessor):
    """Data processor for sequence to sequence for event argument extraction.
    Data processor for sequence-to-sequence (Seq2Seq) for event argument extraction. The class is inherited from the
    `EAEDataProcessor` class, in which the undefined functions, including `read_examples()` and
    `convert_examples_to_features()` are implemented; a new function entitled `insert_marker()` is defined, and
    the rest of the attributes and functions are multiplexed from the `EAEDataProcessor` class.
    """

    def __init__(self,
                 config,
                 tokenizer: str,
                 input_file: str,
                 pred_file: str,
                 is_training: Optional[bool] = False) -> None:
        """Constructs a `EAESeq2SeqProcessor`."""
        super().__init__(config, tokenizer, pred_file, is_training)
        self.read_examples(input_file)
        self.convert_examples_to_features()

    def read_examples(self,
                      input_file: str) -> None:
        """Obtains a collection of `EAEInputExample`s for the dataset."""
        self.examples = []
        self.data_for_evaluation["golden_arguments"] = []
        self.data_for_evaluation["roles"] = []
        language = self.config.language
        trigger_idx = 0
        with open(input_file, "r", encoding="utf-8") as f:
            for line in tqdm(f.readlines(), desc="Reading from %s" % input_file):
                item = json.loads(line.strip())
                if "source" in item:
                    kwargs = {"source": [item["source"]]}
                    if item["source"] in ["<duee>", "<fewfc>", "<leven>"]:
                        self.config.language = "Chinese"
                    else:
                        self.config.language = "English"
                else:
                    kwargs = {"source": []}

                text = item["text"]
                words = get_words(text=text, language=language)

                if "events" in item:
                    for event in item["events"]:
                        for trigger in event["triggers"]:
                            pred_type = self.get_single_pred(trigger_idx, input_file, true_type=event["type"])
                            pred_type = get_plain_label(pred_type)
                            trigger_idx += 1

                            # Evaluation mode for EAE
                            # If the predicted event type is NA, we don't consider the trigger.
                            if self.config.eae_eval_mode in ["default", "loose"] and pred_type == "NA":
                                continue

                            labels = []
                            arguments_per_trigger = defaultdict(list)
                            for argument in trigger["arguments"]:
                                role = get_plain_label(argument["role"])
                                for mention in argument["mentions"]:
                                    arguments_per_trigger[role].append(mention["mention"])
                                    labels.append(f"{type_start} {role}{split_word} {mention['mention']} {type_end}")
                            labels = "".join(labels)

                            self.data_for_evaluation["golden_arguments"].append(dict(arguments_per_trigger))
                            example = EAEInputExample(
                                example_id=trigger_idx - 1,
                                text=words,
                                pred_type=pred_type,
                                true_type=get_plain_label(event["type"]),
                                trigger_left=trigger["position"][0],
                                trigger_right=trigger["position"][1],
                                labels=labels,
                                **kwargs,
                            )
                            self.examples.append(example)
                    # negative triggers
                    for neg_trigger in item["negative_triggers"]:
                        pred_type = self.get_single_pred(trigger_idx, input_file, true_type="NA")
                        pred_type = get_plain_label(pred_type)
                        trigger_idx += 1

                        if self.config.eae_eval_mode == "loose":
                            continue
                        elif self.config.eae_eval_mode in ["default", "strict"]:
                            if pred_type != "NA":
                                arguments_per_trigger = {}
                                self.data_for_evaluation["golden_arguments"].append(dict(arguments_per_trigger))
                                example = EAEInputExample(
                                    example_id=trigger_idx - 1,
                                    text=words,
                                    pred_type=pred_type,
                                    true_type="NA",
                                    trigger_left=neg_trigger["position"][0],
                                    trigger_right=neg_trigger["position"][1],
                                    labels="",
                                    **kwargs,
                                )
                                self.examples.append(example)
                        else:
                            raise ValueError("Invaild eac_eval_mode: %s" % self.config.eae_eval_mode)
                else:
                    for candi in item["candidates"]:
                        pred_type = self.event_preds[trigger_idx]
                        pred_type = get_plain_label(pred_type)
                        trigger_idx += 1
                        if pred_type != "NA":
                            arguments_per_trigger = {}
                            self.data_for_evaluation["golden_arguments"].append(dict(arguments_per_trigger))
                            example = EAEInputExample(
                                example_id=trigger_idx - 1,
                                text=words,
                                pred_type=pred_type,
                                true_type="NA",  # true type not given, set to NA.
                                trigger_left=candi["position"][0],
                                trigger_right=candi["position"][1],
                                labels="",
                                **kwargs,
                            )
                            self.examples.append(example)
            if self.event_preds is not None and not self.config.golden_trigger:
                assert trigger_idx == len(self.event_preds)
            print('there are {} examples'.format(len(self.examples)))

    @staticmethod
    def insert_marker(tokens: List[str],
                      trigger_pos: List[int],
                      markers: List,
                      whitespace: Optional[bool] = True) -> List[str]:
        """Adds a marker at the start and end position of event triggers and argument mentions."""
        space = " " if whitespace else ""
        marked_words = []
        char_pos = 0
        for i, token in enumerate(tokens):
            if char_pos == trigger_pos[0]:
                marked_words.append(markers[0])
            char_pos += len(token) + len(space)
            marked_words.append(token)
            if char_pos == trigger_pos[1] + len(space):
                marked_words.append(markers[1])
        return marked_words

    def convert_examples_to_features(self) -> None:
        """Converts the `EAEInputExample`s into `EAEInputFeatures`s."""
        self.input_features = []
        whitespace = self.config.language == "English"
        for example in tqdm(self.examples, desc="Processing features for Seq2Seq"):
            # context
            words = self.insert_marker(example.text,
                                       [example.trigger_left, example.trigger_right],
                                       self.config.markers,
                                       whitespace)
            input_context = self.tokenizer(example.kwargs["source"] + words,
                                           truncation=True,
                                           padding="max_length",
                                           max_length=self.config.max_seq_length,
                                           is_split_into_words=True)
            # output labels
            label_outputs = self.tokenizer(example.labels.split(),
                                           padding="max_length",
                                           truncation=True,
                                           max_length=self.config.max_out_length,
                                           is_split_into_words=True)
            # set labels of padding positions to -100 so they are ignored by the loss
            for i, flag in enumerate(label_outputs["attention_mask"]):
                if flag == 0:
                    label_outputs["input_ids"][i] = -100
            features = EAEInputFeatures(
                example_id=example.example_id,
                input_ids=input_context["input_ids"],
                attention_mask=input_context["attention_mask"],
                labels=label_outputs["input_ids"],
            )
            self.input_features.append(features)
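
The marker insertion above can be sanity-checked in isolation. A minimal sketch, assuming the enclosing processor class is the EAESeq2SeqProcessor of OmniEvent.input_engineering.seq2seq_processor (insert_marker is a static method, so no instance is required):

>>> tokens = ["He", "was", "shot", "yesterday"]
>>> # character span of "shot" in "He was shot yesterday" is [7, 11]
>>> EAESeq2SeqProcessor.insert_marker(tokens, [7, 11], ["<t>", "</t>"])
['He', 'was', '<t>', 'shot', '</t>', 'yesterday']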

Machine Reading Comprehension (MRC) Converter

import collections
import logging
from typing import Dict, List, Optional

logger = logging.getLogger(__name__)

read_query_templates

Loads query templates from a prompt file. If a translation is required, the query templates are translated from English to Chinese based on four predefined translation rules.

Args:

  • prompt_file: A string indicating the path of the prompt file.

  • translate: A boolean variable indicating whether or not to translate the query templates into Chinese.

Returns:

  • query_templates: A dictionary containing the query templates applicable for every event type and argument role.

def read_query_templates(prompt_file: str,
                         translate: Optional[bool] = False) -> Dict[str, Dict[str, List[str]]]:
    """Loads query templates from a prompt file.
    Loads query templates from a prompt file. If a translation is required, the query templates are translated from
    English to Chinese based on four predefined translation rules.
    Args:
        prompt_file (`str`):
            A string indicating the path of the prompt file.
        translate (`bool`, `optional`, defaults to `False`):
            A boolean variable indicating whether or not to translate the query templates into Chinese.
    Returns:
        query_templates (`Dict[str, Dict[str, List[str]]]`):
            A dictionary containing the query templates applicable for every event type and argument role.
    """
    et_translation = dict()
    ar_translation = dict()
    if translate:
        # the event types and argument roles in ACE2005-zh are expressed in English; we translate them into Chinese
        et_file = "/".join(prompt_file.split('/')[:-1]) + "/chinese_event_types.txt"
        title = None

        for line in open(et_file, encoding='utf-8').readlines():
            num = line.split()[0]
            chinese = line.split()[1][:line.split()[1].index("(")]
            english = line[line.index("(") + 1:line.index(')')]
            if '.' not in num:
                title = chinese, english
            if title:
                et_translation['{}.{}'.format(title[1], english)] = "{}.{}".format(title[0], chinese)

        ar_file = "/".join(prompt_file.split('/')[:-1]) + "/chinese_arg_roles.txt"
        for line in open(ar_file, encoding='utf-8').readlines():
            english, chinese = line.strip().split()
            ar_translation[english] = chinese

    query_templates = dict()
    with open(prompt_file, "r", encoding='utf-8') as f:
        for line in f:
            event_arg, query = line.strip().split(",")
            event_type, arg_name = event_arg.split("_")

            if event_type not in query_templates:
                query_templates[event_type] = dict()
            if arg_name not in query_templates[event_type]:
                query_templates[event_type][arg_name] = list()

            if translate:
                # 0 template arg_name
                query_templates[event_type][arg_name].append(ar_translation[arg_name])
                # 1 template arg_name + in trigger (replace [trigger] when forming the instance)
                query_templates[event_type][arg_name].append(ar_translation[arg_name] + "在[trigger]中")
                # 2 template arg_query
                query_templates[event_type][arg_name].append(query)
                # 3 arg_query + trigger (replace [trigger] when forming the instance)
                query_templates[event_type][arg_name].append(query[:-1] + "在[trigger]中?")
            else:
                # 0 template arg_name
                query_templates[event_type][arg_name].append(arg_name)
                # 1 template arg_name + in trigger (replace [trigger] when forming the instance)
                query_templates[event_type][arg_name].append(arg_name + " in [trigger]")
                # 2 template arg_query
                query_templates[event_type][arg_name].append(query)
                # 3 arg_query + trigger (replace [trigger] when forming the instance)
                query_templates[event_type][arg_name].append(query[:-1] + " in [trigger]?")

    for event_type in query_templates:
        for arg_name in query_templates[event_type]:
            assert len(query_templates[event_type][arg_name]) == 4

    return query_templates
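
A minimal usage sketch: each line of the prompt file is expected to be formatted as {event_type}_{arg_name},{query}, and four templates are built per argument role. The prompt-file line below is hypothetical, not one shipped with OmniEvent:

>>> with open("prompts.csv", "w", encoding="utf-8") as f:
...     _ = f.write("Conflict.Attack_Target,Who or what is attacked?\n")
...
>>> templates = read_query_templates("prompts.csv")
>>> templates["Conflict.Attack"]["Target"]
['Target', 'Target in [trigger]', 'Who or what is attacked?', 'Who or what is attacked in [trigger]?']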

_get_best_indexes

Gets the n-best logits from a list. The method returns a list containing the indexes of the n-best logits; if larger_than_cls is set, only indexes whose logits are not lower than the logit of the "cls" token are kept.

Args:

  • logits: A list of logits to select from.

  • n_best_size: An integer indicating the maximum number of indexes to return.

  • larger_than_cls: A boolean variable indicating whether to keep only the indexes whose logits are not lower than cls_logit.

  • cls_logit: The logit of the "cls" token, serving as the threshold when larger_than_cls is True.

Returns:

  • best_indexes: A list containing the indexes of the n-best logits.

def _get_best_indexes(logits: List[int],
                      n_best_size: Optional[int] = 1,
                      larger_than_cls: Optional[bool] = False,
                      cls_logit: Optional[int] = None) -> List[int]:
    """Gets the n-best logits from a list.
    Gets the n-best logits from a list. The methods returns a list containing the indexes of the n-best logits that
    satisfies both the logits are n-best and greater than the logit of the "cls" token.
    """
    index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)

    best_indexes = []
    for i in range(len(index_and_score)):
        if i >= n_best_size:
            break
        if larger_than_cls:
            if index_and_score[i][1] < cls_logit:
                break
        best_indexes.append(index_and_score[i][0])
    return best_indexes
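
A toy illustration of the selection logic (note that, since the comparison uses <, the index of the "cls" logit itself survives the larger_than_cls filter):

>>> logits = [1.0, 2.0, 0.5, 3.0]   # logits[0] plays the role of the "cls" logit
>>> _get_best_indexes(logits, n_best_size=2)
[3, 1]
>>> _get_best_indexes(logits, n_best_size=4, larger_than_cls=True, cls_logit=logits[0])
[3, 1, 0]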

char_pos_to_word_pos

Returns the word-level position of a mention by counting the number of words before the start position of the mention.

Args:

  • text: A string representing the source text that the mention is within.

  • position: An integer indicating the character-level position of the mention.

Returns:

  • An integer indicating the word-level position of the mention.

def char_pos_to_word_pos(text: str,
                         position: int) -> int:
    """Returns the word-level position of a mention.
    Returns the word-level position of a mention by counting the number of words before the start position of the
    mention.
    Args:
        text (`str`):
            A string representing the source text that the mention is within.
        position (`int`):
            An integer indicating the character-level position of the mention.
    Returns:
        An integer indicating the word-level position of the mention.
    """
    return len(text[:position].split())
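
For example, character position 7 in the text below falls on "shot", which is preceded by two words:

>>> char_pos_to_word_pos("He was shot by the attacker", 7)
2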

make_predictions

Obtains the prediction from the Machine Reading Comprehension (MRC) model.

def make_predictions(all_start_logits, all_end_logits, training_args):
    """Obtains the prediction from the Machine Reading Comprehension (MRC) model."""
    data_for_evaluation = training_args.data_for_evaluation
    assert len(all_start_logits) == len(data_for_evaluation["ids"])
    # all golden labels
    final_all_labels = []
    for arguments in data_for_evaluation["golden_arguments"]:
        arguments_per_trigger = []
        for argument in arguments["arguments"]:
            event_argument_type = arguments["true_type"] + "_" + argument["role"]
            for mention in argument["mentions"]:
                arguments_per_trigger.append(
                    (event_argument_type, (mention["position"][0], mention["position"][1]), arguments["id"]))
        final_all_labels.extend(arguments_per_trigger)
    # predictions
    _PrelimPrediction = collections.namedtuple("PrelimPrediction",
                                               ["start_index", "end_index", "start_logit", "end_logit"])
    final_all_predictions = []
    for example_id, (start_logits, end_logits) in enumerate(zip(all_start_logits, all_end_logits)):
        event_argument_type = data_for_evaluation["pred_types"][example_id] + "_" + \
                              data_for_evaluation["roles"][example_id]
        start_indexes = _get_best_indexes(start_logits, 20, True, start_logits[0])
        end_indexes = _get_best_indexes(end_logits, 20, True, end_logits[0])
        # add span preds
        prelim_predictions = []
        for start_index in start_indexes:
            for end_index in end_indexes:
                if start_index < data_for_evaluation["text_range"][example_id]["start"] or \
                        end_index < data_for_evaluation["text_range"][example_id]["start"]:
                    continue
                if start_index >= data_for_evaluation["text_range"][example_id]["end"] or \
                        end_index >= data_for_evaluation["text_range"][example_id]["end"]:
                    continue
                if end_index < start_index:
                    continue
                word_start_index = start_index - 1
                word_end_index = end_index - 1
                length = word_end_index - word_start_index + 1
                if length > 5:
                    continue
                prelim_predictions.append(
                    _PrelimPrediction(start_index=word_start_index, end_index=word_end_index,
                                      start_logit=start_logits[start_index], end_logit=end_logits[end_index]))
        # sort
        prelim_predictions = sorted(prelim_predictions, key=lambda x: (x.start_logit + x.end_logit), reverse=True)
        # get final pred in format: [event_type_offset_argument_type, [start_offset, end_offset]]
        max_num_pred_per_arg = 1
        predictions_per_query = []
        for pred in prelim_predictions[:max_num_pred_per_arg]:
            na_prob = (start_logits[0] + end_logits[0]) - (pred.start_logit + pred.end_logit)
            predictions_per_query.append((event_argument_type, (pred.start_index, pred.end_index), na_prob,
                                          data_for_evaluation["ids"][example_id]))
        final_all_predictions.extend(predictions_per_query)

    logger.info("\nAll predictions and labels generated. %d %d\n" % (len(final_all_predictions), len(final_all_labels)))
    return final_all_predictions, final_all_labels

find_best_thresh

Finds the best "no-answer" threshold: the predictions, sorted in ascending order of their na_prob scores, are scanned one by one, and the na_prob that maximizes the classification F1 score against the gold arguments is returned.

def find_best_thresh(new_preds, new_all_gold):
    best_score = 0
    best_na_thresh = 0
    gold_arg_n, pred_arg_n = len(new_all_gold), 0

    candidate_preds = []
    for argument in new_preds:
        candidate_preds.append(argument[:-2] + argument[-1:])
        pred_arg_n += 1

        pred_in_gold_n, gold_in_pred_n = 0, 0
        # pred_in_gold_n
        for argu in candidate_preds:
            if argu in new_all_gold:
                pred_in_gold_n += 1
        # gold_in_pred_n
        for argu in new_all_gold:
            if argu in candidate_preds:
                gold_in_pred_n += 1

        prec_c, recall_c, f1_c = 0, 0, 0
        if pred_arg_n != 0:
            prec_c = 100.0 * pred_in_gold_n / pred_arg_n
        else:
            prec_c = 0
        if gold_arg_n != 0:
            recall_c = 100.0 * gold_in_pred_n / gold_arg_n
        else:
            recall_c = 0
        if prec_c or recall_c:
            f1_c = 2 * prec_c * recall_c / (prec_c + recall_c)
        else:
            f1_c = 0

        if f1_c > best_score:
            best_score = f1_c
            best_na_thresh = argument[-2]

    return best_na_thresh + 1e-10

compute_mrc_F1_cls

Computes the classification F1 score of the MRC predictions: the best "no-answer" threshold is first searched with find_best_thresh, predictions whose na_prob is not below the threshold are discarded, and the remaining predictions are matched against the gold labels.

def compute_mrc_F1_cls(all_predictions, all_labels):
    all_predictions = sorted(all_predictions, key=lambda x: x[-2])
    # best_na_thresh = 0
    best_na_thresh = find_best_thresh(all_predictions, all_labels)
    print("Best thresh founded. %.6f" % best_na_thresh)

    final_new_preds = []
    for argument in all_predictions:
        if argument[-2] < best_na_thresh:
            final_new_preds.append(argument[:-2] + argument[-1:])  # no na_prob

    # get results (classification)
    gold_arg_n, pred_arg_n, pred_in_gold_n, gold_in_pred_n = 0, 0, 0, 0
    # pred_arg_n
    for argument in final_new_preds:
        pred_arg_n += 1
    # gold_arg_n
    for argument in all_labels:
        gold_arg_n += 1
    # pred_in_gold_n
    for argument in final_new_preds:
        if argument in all_labels:
            pred_in_gold_n += 1
    # gold_in_pred_n
    for argument in all_labels:
        if argument in final_new_preds:
            gold_in_pred_n += 1

    prec_c, recall_c, f1_c = 0, 0, 0
    if pred_arg_n != 0:
        prec_c = 100.0 * pred_in_gold_n / pred_arg_n
    else:
        prec_c = 0
    if gold_arg_n != 0:
        recall_c = 100.0 * gold_in_pred_n / gold_arg_n
    else:
        recall_c = 0
    if prec_c or recall_c:
        f1_c = 2 * prec_c * recall_c / (prec_c + recall_c)
    else:
        f1_c = 0

    logger.info("Precision: %.2f, recall: %.2f" % (prec_c, recall_c))
    return f1_c
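
A toy end-to-end check of the threshold search and F1 computation, using the prediction and label formats produced by make_predictions (the values are made up for illustration):

>>> preds = [("attack_target", (3, 4), -1.5, 0),       # (type_role, span, na_prob, trigger_id)
...          ("attack_instrument", (6, 6), 2.0, 0)]
>>> golds = [("attack_target", (3, 4), 0)]             # (type_role, span, trigger_id)
>>> compute_mrc_F1_cls(preds, golds)
Best threshold found: -1.500000
100.0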

Machine Reading Comprehension (MRC) Processor

import json
import logging

from tqdm import tqdm
from .base_processor import (
    EAEDataProcessor,
    EAEInputExample,
    EAEInputFeatures
)
from .mrc_converter import read_query_templates
from .input_utils import get_words, get_left_and_right_pos
from collections import defaultdict

logger = logging.getLogger(__name__)

EAEMRCProcessor

Data processor for Machine Reading Comprehension (MRC) for event argument extraction. The class inherits from the EAEDataProcessor class and implements its undefined functions, read_examples() and convert_examples_to_features(); a new function, remove_sub_word(), removes the annotations of words that are sub-words. The rest of the attributes and functions are reused from the EAEDataProcessor class.

class EAEMRCProcessor(EAEDataProcessor):
    """Data processor for Machine Reading Comprehension (MRC) for event argument extraction.

    Data processor for Machine Reading Comprehension (MRC) for event argument extraction. The class inherits from the
    `EAEDataProcessor` class and implements its undefined functions, `read_examples()` and
    `convert_examples_to_features()`; a new function, `remove_sub_word()`, removes the annotations of words that are
    sub-words. The rest of the attributes and functions are reused from the `EAEDataProcessor` class.
    """

    def __init__(self,
                config,
                tokenizer,
                input_file: str,
                pred_file: str,
                is_training: bool = False) -> None:
        """Constructs a `EAEMRCProcessor`."""
        super().__init__(config, tokenizer, pred_file, is_training)
        self.read_examples(input_file)
        self.convert_examples_to_features()

    def read_examples(self,
                    input_file: str) -> None:
        """Obtains a collection of `EAEInputExample`s for the dataset."""
        self.examples = []
        self.data_for_evaluation["golden_arguments"] = []
        trigger_idx = 0
        query_templates = read_query_templates(self.config.prompt_file,
                                            translate=self.config.dataset_name == "ACE2005-ZH")
        template_id = self.config.mrc_template_id
        with open(input_file, "r", encoding="utf-8") as f:
            for idx, line in enumerate(tqdm(f.readlines(), desc="Reading from %s" % input_file)):
                item = json.loads(line.strip())
                if "events" in item:
                    words = get_words(text=item["text"], language=self.config.language)
                    for event in item["events"]:
                        for trigger in event["triggers"]:
                            if self.is_training or self.config.golden_trigger or self.event_preds is None:
                                pred_event_type = event["type"]
                            else:
                                pred_event_type = self.event_preds[trigger_idx]
                            trigger_idx += 1

                            # Evaluation mode for EAE
                            # If predicted event type is NA:
                            #   in [default] and [loose] modes, we don't consider the trigger
                            #   in [strict] mode, we consider the trigger
                            if self.config.eae_eval_mode in ["default", "loose"] and pred_event_type == "NA":
                                continue

                            # golden label for the trigger
                            arguments_per_trigger = dict(id=trigger_idx-1,
                                                        arguments=[],
                                                        pred_type=pred_event_type,
                                                        true_type=event["type"])
                            for argument in trigger["arguments"]:
                                arguments_per_role = dict(role=argument["role"], mentions=[])
                                for mention in argument["mentions"]:
                                    left_pos, right_pos = get_left_and_right_pos(text=item["text"],
                                                                                trigger=mention,
                                                                                language=self.config.language)
                                    arguments_per_role["mentions"].append({
                                        "position": [left_pos, right_pos - 1]
                                    })
                                arguments_per_trigger["arguments"].append(arguments_per_role)
                            self.data_for_evaluation["golden_arguments"].append(arguments_per_trigger)

                            if pred_event_type == "NA":
                                assert self.config.eae_eval_mode == "strict"
                                # in strict mode, we add the gold args for the trigger but do not make predictions
                                continue

                            trigger_left, trigger_right = get_left_and_right_pos(text=item["text"],
                                                                                trigger=trigger,
                                                                                language=self.config.language)

                            for role in query_templates[pred_event_type].keys():
                                query = query_templates[pred_event_type][role][template_id]
                                query = query.replace("[trigger]", self.tokenizer.tokenize(trigger["trigger_word"])[0])
                                query = get_words(text=query, language=self.config.language)
                                if self.is_training:
                                    no_answer = True
                                    for argument in trigger["arguments"]:
                                        if argument["role"] not in query_templates[pred_event_type]:
                                            # raise ValueError(
                                            #     "No template for %s in %s" % (argument["role"], pred_event_type)
                                            # )
                                            logger.warning(
                                                "No template for %s in %s" % (argument["role"], pred_event_type))
                                            pass
                                        if argument["role"] != role:
                                            continue
                                        no_answer = False
                                        for mention in argument["mentions"]:
                                            left_pos, right_pos = get_left_and_right_pos(text=item["text"],
                                                                                        trigger=mention,
                                                                                        language=self.config.language)
                                            example = EAEInputExample(
                                                example_id=trigger_idx-1,
                                                text=words,
                                                pred_type=pred_event_type,
                                                true_type=event["type"],
                                                input_template=query,
                                                trigger_left=trigger_left,
                                                trigger_right=trigger_right,
                                                argument_left=left_pos,
                                                argument_right=right_pos - 1,
                                                argument_role=role
                                            )
                                            self.examples.append(example)
                                    if no_answer:
                                        example = EAEInputExample(
                                            example_id=trigger_idx-1,
                                            text=words,
                                            pred_type=pred_event_type,
                                            true_type=event["type"],
                                            input_template=query,
                                            trigger_left=trigger_left,
                                            trigger_right=trigger_right,
                                            argument_left=-1,
                                            argument_right=-1,
                                            argument_role=role
                                        )
                                        self.examples.append(example)
                                else:
                                    # one instance per query
                                    example = EAEInputExample(
                                        example_id=trigger_idx-1,
                                        text=words,
                                        pred_type=pred_event_type,
                                        true_type=event["type"],
                                        input_template=query,
                                        trigger_left=trigger_left,
                                        trigger_right=trigger_right,
                                        argument_left=-1,
                                        argument_right=-1,
                                        argument_role=role
                                    )
                                    self.examples.append(example)
                    # negative triggers
                    for neg_trigger in item["negative_triggers"]:
                        if self.is_training or self.config.golden_trigger or self.event_preds is None:
                            pred_event_type = "NA"
                        else:
                            pred_event_type = self.event_preds[trigger_idx]

                        trigger_idx += 1
                        if self.config.eae_eval_mode == "loose":
                            continue
                        elif self.config.eae_eval_mode in ["default", "strict"]:
                            if pred_event_type == "NA":
                                continue
                            trigger_left, trigger_right = get_left_and_right_pos(text=item["text"],
                                                                                trigger=neg_trigger,
                                                                                language=self.config.language)
                            for role in query_templates[pred_event_type].keys():
                                query = query_templates[pred_event_type][role][template_id]
                                query = query.replace("[trigger]",
                                                    self.tokenizer.tokenize(neg_trigger["trigger_word"])[0])
                                query = get_words(text=query, language=self.config.language)
                                # one instance per query
                                example = EAEInputExample(
                                    example_id=trigger_idx-1,
                                    text=words,
                                    pred_type=pred_event_type,
                                    true_type="NA",
                                    input_template=query,
                                    trigger_left=trigger_left,
                                    trigger_right=trigger_right,
                                    argument_left=-1,
                                    argument_right=-1,
                                    argument_role=role
                                )
                                self.examples.append(example)

                        else:
                            raise ValueError("Invalid eae_eval_mode: %s" % self.config.eae_eval_mode)
                else:
                    for candi in item["candidates"]:
                        trigger_left, trigger_right = get_left_and_right_pos(text=item["text"],
                                                                            trigger=candi,
                                                                            language=self.config.language)
                        pred_event_type = self.event_preds[trigger_idx]
                        trigger_idx += 1
                        if pred_event_type != "NA":
                            for role in query_templates[pred_event_type].keys():
                                query = query_templates[pred_event_type][role][template_id]
                                query = query.replace("[trigger]", self.tokenizer.tokenize(candi["trigger_word"])[0])
                                query = get_words(text=query, language=self.config.language)
                                # one instance per query
                                example = EAEInputExample(
                                    example_id=trigger_idx-1,
                                    text=words,
                                    pred_type=pred_event_type,
                                    true_type="NA",
                                    input_template=query,
                                    trigger_left=trigger_left,
                                    trigger_right=trigger_right,
                                    argument_left=-1,
                                    argument_right=-1,
                                    argument_role=role
                                )
                                self.examples.append(example)
            if self.event_preds is not None:
                assert trigger_idx == len(self.event_preds)

    def convert_examples_to_features(self) -> None:
        """Converts the `EAEInputExample`s into `EAEInputFeatures`s."""
        self.input_features = []
        self.data_for_evaluation["text_range"] = []
        self.data_for_evaluation["text"] = []

        for example in tqdm(self.examples, desc="Processing features for MRC"):
            # context
            input_context = self.tokenizer(example.text,
                                        truncation=True,
                                        max_length=self.config.max_seq_length,
                                        is_split_into_words=True)
            # template
            input_template = self.tokenizer(example.input_template,
                                            truncation=True,
                                            padding="max_length",
                                            max_length=self.config.max_seq_length,
                                            is_split_into_words=True)

            input_context = self.remove_sub_word(input_context)
            # concatenate
            input_ids = input_context["input_ids"] + input_template["input_ids"]
            attention_mask = input_context["attention_mask"] + input_template["attention_mask"]
            token_type_ids = [0] * len(input_context["input_ids"]) + [1] * len(input_template["input_ids"])
            # truncation
            input_ids = input_ids[:self.config.max_seq_length]
            attention_mask = attention_mask[:self.config.max_seq_length]
            token_type_ids = token_type_ids[:self.config.max_seq_length]
            # output labels
            start_position = 0 if example.argument_left == -1 else example.argument_left + 1
            end_position = 0 if example.argument_right == -1 else example.argument_right + 1
            # data for evaluation
            text_range = dict()
            text_range["start"] = 1
            text_range["end"] = text_range["start"] + sum(input_context["attention_mask"][1:])
            self.data_for_evaluation["text_range"].append(text_range)
            self.data_for_evaluation["text"].append(example.text)
            # features
            features = EAEInputFeatures(
                example_id=example.example_id,
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                argument_left=start_position,
                argument_right=end_position
            )
            self.input_features.append(features)

    @staticmethod
    def remove_sub_word(inputs):
        """Removes the annotations whose word is a sub-word."""
        outputs = defaultdict(list)
        pre_word_id = -1
        for token_id, word_id in enumerate(inputs.word_ids()):
            if token_id == 0 or (word_id != pre_word_id and word_id is not None):
                for key in inputs:
                    outputs[key].append(inputs[key][token_id])
            pre_word_id = word_id
        return outputs
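
A minimal sketch of what remove_sub_word does, assuming the bert-base-uncased fast tokenizer is available (remove_sub_word is a static method, so no processor instance is needed; the exact sub-word splits depend on the vocabulary):

>>> from transformers import BertTokenizerFast
>>> tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
>>> enc = tokenizer(["word", "embeddings"], is_split_into_words=True)
>>> tokenizer.convert_ids_to_tokens(enc["input_ids"])
['[CLS]', 'word', 'em', '##bed', '##ding', '##s', '[SEP]']
>>> kept = EAEMRCProcessor.remove_sub_word(enc)
>>> tokenizer.convert_ids_to_tokens(kept["input_ids"])   # only the first sub-word of each word survives
['[CLS]', 'word', 'em']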

Input Utils

import os
import json
import logging

from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

from transformers import BatchEncoding, PreTrainedTokenizer

from .whitespace_tokenizer import WordLevelTokenizer

logger = logging.getLogger(__name__)

get_bio_labels

Generates the id of the BIO labels corresponding to the original label. The correspondences between the BIO labels and their ids are saved in a dictionary.

Args:

  • original_labels: A list of strings representing the original labels within the dataset.

  • labels_to_exclude: A list of strings indicating the labels to exclude, for which no ids would be generated.

Returns:

  • bio_labels: A dictionary containing the correspondence between the BIO labels and their ids.

def get_bio_labels(original_labels: List[str],
                   labels_to_exclude: Optional[List[str]] = ["NA"]) -> Dict[str, int]:
    """Generates the id of the BIO labels corresponding to the original label.

    Generates the id of the BIO labels corresponding to the original label. The correspondences between the BIO labels
    and their ids are saved in a dictionary.

    Args:
        original_labels (`List[str]`):
            A list of strings representing the original labels within the dataset.
        labels_to_exclude (`List[str]`, `optional`, defaults to ["NA"]):
            A list of strings indicating the labels to exclude, for which no ids would be generated.

    Returns:
        bio_labels (`Dict[str, int]`):
            A dictionary containing the correspondence between the BIO labels and their ids.
    """
    bio_labels = {"O": 0}
    for label in original_labels:
        if label in labels_to_exclude:
            continue
        bio_labels[f"B-{label}"] = len(bio_labels)
        bio_labels[f"I-{label}"] = len(bio_labels)
    return bio_labels
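
For example:

>>> get_bio_labels(["NA", "Attack", "Transport"])
{'O': 0, 'B-Attack': 1, 'I-Attack': 2, 'B-Transport': 3, 'I-Transport': 4}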

get_start_poses

Obtains the start position of each word within the sentence. The character-level start positions of each word are stored in a list.

Args:

  • sentence: A string representing the input sentence.

Returns:

  • start_poses: A list of integers representing the character-level start position of each word within the sentence.

def get_start_poses(sentence: str) -> List[int]:
    """Obtains the start position of each word within the sentence.

    Obtains the start position of each word within the sentence. The character-level start positions of each word are
    stored in a list.

    Args:
        sentence (`str`):
            A string representing the input sentence.

    Returns:
        start_poses (`List[int]`):
            A list of integers representing the character-level start position of each word within the sentence.
    """
    words = sentence.split()
    start_pos = 0
    start_poses = []
    for word in words:
        start_poses.append(start_pos)
        start_pos += len(word) + 1
    return start_poses

check_if_start

Check whether the start position of the mention is the beginning of a word, that is, check whether a trigger or an argument is a sub-word.

Args:

  • start_poses: A list of integers representing the character-level start position of each word within the sentence.

  • char_pos: A list of integers indicating the start and end position of a mention.

Returns:

Returns True if the start position of the mention is the start of a word; returns False otherwise.

def check_if_start(start_poses: List[int],
                   char_pos: List[int]) -> bool:
    """Check whether the start position of the mention is the beginning of a word.

    Check whether the start position of the mention is the beginning of a word, that is, check whether a trigger or an
    argument is a sub-word.

    Args:
        start_poses (`List[int]`):
            A list of integers representing the character-level start position of each word within the sentence.
        char_pos (`List[int]`):
            A list of integers indicating the start and end position of a mention.

    Returns:
        Returns `True` if the start position of the mention is the start of a word; returns `False` otherwise.
    """
    if char_pos[0] in start_poses:
        return True
    return False

get_word_position

Returns the word-level position of a given mention by matching the index of its character-level start position in the list containing the start position of each word within the sentence.

Args:

  • start_poses: A list of integers representing the character-level start position of each word within the sentence.

  • char_pos: A list of integers indicating the start and end position of a given mention.

Returns:

An integer indicating the word-level position of the given mention.

def get_word_position(start_poses: List[int],
                      char_pos: List[int]) -> int:
    """Returns the word-level position of a given mention.

    Returns the word-level position of a given mention by matching the index of its character-level start position in
    the list containing the start position of each word within the sentence.

    Args:
        start_poses (`List[int]`):
            A list of integers representing the character-level start position of each word within the sentence.
        char_pos (`List[int]`):
            A list of integers indicating the start and end position of a given mention.

    Returns:
        `int`:
            An integer indicating the word-level position of the given mention.
    """
    return start_poses.index(char_pos[0])
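
The three helpers above work together: get_start_poses builds the index of word start positions, while check_if_start and get_word_position query it. For example:

>>> start_poses = get_start_poses("The gunman opened fire")
>>> start_poses
[0, 4, 11, 18]
>>> check_if_start(start_poses, [18, 22])    # "fire" starts at character 18
True
>>> get_word_position(start_poses, [18, 22])
3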

get_words

Obtains the words within the source text. The recognition of words differs according to language. The words are obtained through splitting white spaces in English, while each Chinese character is regarded as a word in Chinese.

Args:

  • text: A string representing the input source text.

  • language: A string indicating the language of the source text, English or Chinese.

Returns:

  • words : A list of strings containing the words within the source text.

def get_words(text: str,
              language: str) -> List[str]:
    """Obtains the words within the given text.

    Obtains the words within the source text. The recognition of words differs according to language. The words are
    obtained through splitting white spaces in English, while each Chinese character is regarded as a word in Chinese.

    Args:
        text (`str`):
            A string representing the input source text.
        language (`str`):
            A string indicating the language of the source text, English or Chinese.

    Returns:
        words (`List[str]`):
            A list of strings containing the words within the source text.
    """
    if language == "English":
        words = text.split()
    elif language == "Chinese":
        words = list(text)
    else:
        raise NotImplementedError
    return words
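
For example:

>>> get_words("The gunman opened fire", language="English")
['The', 'gunman', 'opened', 'fire']
>>> get_words("枪手开火", language="Chinese")
['枪', '手', '开', '火']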

get_left_and_right_pos

Obtains the word-level position of the trigger word’s start and end position. The method of obtaining the position differs according to language. The method returns the number of words before the given position for English texts, while for Chinese, each character is regarded as a word.

Args:

  • text: A string representing the source text that the trigger word is within.

  • trigger: A dictionary containing the trigger word, position, and arguments of an event trigger.

  • language: A string indicating the language of the source text and trigger word, English or Chinese.

  • keep_space: A flag indicating whether to keep spaces in Chinese text when computing offsets. During data preprocessing, spaces must be kept because the annotated offsets take them into account. During evaluation, spaces are removed by the tokenizer and the output hidden states contain no logits for them; therefore, offset counting should not keep spaces.

def get_left_and_right_pos(text: str,
                           trigger: Dict[str, Union[int, str, List[int], List[Dict]]],
                           language: str,
                           keep_space: bool = False) -> Tuple[int, int]:
    """Obtains the word-level position of the trigger word's start and end position.
    Obtains the word-level position of the trigger word's start and end position. The method of obtaining the position
    differs according to language. The method returns the number of words before the given position for English texts,
    while for Chinese, each character is regarded as a word.
    Args:
        text (`str`):
            A string representing the source text that the trigger word is within.
        trigger (`Dict[str, Union[int, str, List[int], List[Dict]]]`):
            A dictionary containing the trigger word, position, and arguments of an event trigger.
        language (`str`):
            A string indicating the language of the source text and trigger word, English or Chinese.
        keep_space (`bool`):
            A flag indicating whether to keep spaces in Chinese text when computing offsets. During data
            preprocessing, spaces must be kept because the annotated offsets take them into account. During
            evaluation, spaces are removed by the tokenizer and the output hidden states contain no logits for them;
            therefore, offset counting should not keep spaces.
    Returns:
        left_pos (`int`), right_pos (`int`):
            Two integers indicating the number of words before the start and end position of the trigger word.
    """
    if language == "English":
        left_pos = len(text[:trigger["position"][0]].split())
        right_pos = len(text[:trigger["position"][1]].split())
    elif language == "Chinese":
        left_pos = trigger["position"][0] if keep_space else len("".join(text[:trigger["position"][0]].split()))
        right_pos = trigger["position"][1] if keep_space else len("".join(text[:trigger["position"][1]].split()))
    else:
        raise NotImplementedError
    return left_pos, right_pos
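
For example, with an English text and character-level trigger offsets:

>>> text = "He was shot yesterday"
>>> trigger = {"position": [7, 11]}   # character span of "shot"
>>> get_left_and_right_pos(text, trigger, language="English")
(2, 3)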

get_word_ids

Returns a list indicating the word corresponding to each token. Special tokens added by the tokenizer are mapped to None and other tokens are mapped to the index of their corresponding word (several tokens will be mapped to the same word index if they are parts of that word).

Args:

  • tokenizer: The tokenizer that has been used for word tokenization.

  • outputs: The outputs of the tokenizer.

  • word_list: A list of word strings.

Returns:

  • word_ids: A list mapping the tokens to their actual word in the initial sentence.

def get_word_ids(tokenizer: PreTrainedTokenizer,
                 outputs: BatchEncoding,
                 word_list: List[str]) -> List[int]:
    """Return a list mapping the tokens to their actual word in the initial sentence for a tokenizer.
    Return a list indicating the word corresponding to each token. Special tokens added by the tokenizer are mapped to
    None and other tokens are mapped to the index of their corresponding word (several tokens will be mapped to the same
    word index if they are parts of that word).
    Args:
        tokenizer (`PreTrainedTokenizer`):
            The tokenizer that has been used for word tokenization.
        outputs (`BatchEncoding`):
            The outputs of the tokenizer.
        word_list (`List[str]`):
            A list of word strings.
    Returns:
        word_ids (`List[int]`):
            A list mapping the tokens to their actual word in the initial sentence.
    """
    word_list = [w.lower() for w in word_list]
    try:
        word_ids = outputs.word_ids()
        return word_ids
    except Exception:
        # `word_ids()` is unavailable for the word-level tokenizer, so the mapping is reconstructed manually below
        assert isinstance(tokenizer, WordLevelTokenizer)
    tokens = tokenizer.convert_ids_to_tokens(outputs["input_ids"])
    word_ids = []
    word_idx = 0

    for token in tokens:
        if token not in word_list and token != "[UNK]":
            word_ids.append(None)
        else:
            if token != "[UNK]":
                assert token == word_list[word_idx]
            word_ids.append(word_idx)
            word_idx += 1
    return word_ids

check_pred_len

Check whether the length of the prediction sequence equals that of the original word sequence. The prediction sequence consists of a prediction for each word in the original sentence. Sometimes, there might be special tokens or extra spaces in the original sentence, and the tokenizer will automatically ignore them, which may cause the output length to differ from the input length.

Args:

  • pred: A list of predicted event types or argument roles.

  • item: A single item of the training/valid/test data.

  • language: The language of the input text.

def check_pred_len(pred: List[str],
                   item: Dict[str, Union[str, List[dict]]],
                   language: str) -> None:
    """Check whether the length of the prediction sequence equals that of the original word sequence.
    The prediction sequence consists of prediction for each word in the original sentence. Sometimes, there might be
    special tokens or extra spaces in the original sentence, and the tokenizer will automatically ignore them, which may
    cause the output length to differ from the input length.
    Args:
        pred (`List[str]`):
            A list of predicted event types or argument roles.
        item (`Dict[str, Union[str, List[dict]]]`):
            A single item of the training/valid/test data.
        language (`str`):
            The language of the input text.
    Returns:
        None.
    """
    if language == "English":
        if len(pred) != len(item["text"].split()):
            logger.warning("There might be special tokens in the input text: {}".format(item["text"]))

    elif language == "Chinese":
        if len(pred) != len("".join(item["text"].split())):  # remove space token
            logger.warning("There might be special tokens in the input text: {}".format(item["text"]))
    else:
        raise NotImplementedError

get_ed_candidates

Obtains the candidate tokens for the event detection (ED) task. The unified evaluation considers the prediction of each token that is possibly a trigger (ED candidate).

Args:

  • item: A single item of the training/valid/test data.

Returns:

  • candidates: A list of dictionaries that contain the possible triggers.

  • label_names: A list of strings containing the ground truth label for each possible trigger.

def get_ed_candidates(item: Dict[str, Union[str, List[dict]]]) -> Tuple[List[dict], List[str]]:
    """Obtain the candidate tokens for the event detection (ED) task.
    The unified evaluation considers prediction of each token that is possibly a trigger (ED candidate).
    Args:
        item (`Dict[str, Union[str, List[dict]]]`):
            A single item of the training/valid/test data.
    Returns:
        candidates (`List[dict]`), label_names (`List[str]`):
            candidates: A list of dictionaries that contain the possible triggers.
            label_names: A list of strings containing the ground truth label for each possible trigger.
    """
    candidates = []
    label_names = []
    if "events" in item:
        for event in item["events"]:
            for trigger in event["triggers"]:
                label_names.append(event["type"])
                candidates.append(trigger)
        for neg_trigger in item["negative_triggers"]:
            label_names.append("NA")
            candidates.append(neg_trigger)
    else:
        candidates = item["candidates"]

    return candidates, label_names
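
A toy item in the unified format, reduced to the fields this function reads:

>>> item = {"text": "He was shot yesterday",
...         "events": [{"type": "Attack",
...                     "triggers": [{"trigger_word": "shot", "position": [7, 11]}]}],
...         "negative_triggers": [{"trigger_word": "yesterday", "position": [12, 21]}]}
>>> candidates, label_names = get_ed_candidates(item)
>>> label_names
['Attack', 'NA']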

check_is_argument

Checks whether a given mention is an argument or not. If it is an argument, it has to be excluded from the negative arguments list.

Args:

  • mention: The mention that contains the word, position and other meta information like id, etc.

  • positive_offsets: A list that contains the offsets of all the ground truth arguments.

Returns:

  • is_argument: A flag that indicates whether the mention is an argument or not.

def check_is_argument(mention: Dict[str, Union[str, dict]] = None,
                      positive_offsets: List[Tuple[int, int]] = None) -> bool:
    """Check whether a given mention is argument or not.
    Check whether a given mention is argument or not. If it is an argument, we have to exclude it from the negative
    arguments list.
    Args:
        mention (`Dict[str, Union[str, dict]]`):
            The mention that contains the word, position and other meta information like id, etc.
        positive_offsets (`List[Tuple[int, int]]`):
            A list that contains the offsets of all the ground truth arguments.
    Returns:
        is_argument(`bool`):
            A flag that indicates whether the mention is an argument or not.
    """
    is_argument = False
    if positive_offsets:
        mention_set = set(range(mention["position"][0], mention["position"][1]))
        for pos_offset in positive_offsets:
            pos_set = set(range(pos_offset[0], pos_offset[1]))
            if not pos_set.isdisjoint(mention_set):
                is_argument = True
                break
    return is_argument
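
For example:

>>> check_is_argument({"position": [3, 7]}, positive_offsets=[(5, 9)])   # the spans overlap
True
>>> check_is_argument({"position": [3, 7]}, positive_offsets=[(7, 9)])   # the spans are disjoint
False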

get_negative_argument_candidates

Obtains the negative candidate arguments for the specified trigger, i.e., mentions that are not included in the actual arguments list. The unified evaluation considers the prediction of each token that is possibly an argument (EAE candidate).

Args:

  • item: A single item of the training/valid/test data.

  • positive_offsets: A list that contains the offsets of all the ground truth arguments.

Returns:

  • neg_arg_candidates: A list of dictionaries containing the candidate mentions that are not ground-truth arguments.

def get_negative_argument_candidates(item: Dict[str, Union[str, List[dict]]],
                                     positive_offsets: List[Tuple[int, int]] = None,
                                     ) -> List[Dict[str, Union[str, dict]]]:
    """Obtain the negative candidate arguments for each trigger in the event argument extraction (EAE) task.
    Obtain the negative candidate arguments, which are not included in the actual arguments list, for the specified
    trigger. The unified evaluation considers prediction of each token that is possibly an argument (EAE candidate).
    Args:
        item (`Dict[str, Union[str, List[dict]]]`):
            A single item of the training/valid/test data.
        positive_offsets (`List[Tuple[int, int]]`):
            A list that contains the offsets of all the ground truth arguments.
    Returns:
        neg_arg_candidates (`List[Dict[str, Union[str, dict]]]`):
            A list of dictionaries containing the candidate mentions that are not ground-truth arguments.
    """
    if "entities" in item:
        neg_arg_candidates = []
        for entity in item["entities"]:
            ent_is_arg = any([check_is_argument(men, positive_offsets) for men in entity["mentions"]])
            neg_arg_candidates.extend([] if ent_is_arg else entity["mentions"])
    else:
        neg_arg_candidates = item["negative_triggers"]
    return neg_arg_candidates

get_eae_candidates

Obtains the candidate arguments for each trigger in the event argument extraction (EAE) task. The unified evaluation considers the prediction of each token that is possibly an argument (EAE candidate), and the EAE task requires the model to predict the argument role of each candidate given a specific trigger.

Args:

  • item: A single item of the training/valid/test data.

  • trigger: A single item of trigger in the item.

Returns:

  • candidates: A list of dictionaries that contain the possible arguments.

  • label_names: A list of strings containing the ground truth label for each possible argument.

def get_eae_candidates(item: Dict[str, Union[str, List[dict]]],
                       trigger: Dict[str, Union[str, dict]]) -> Tuple[List[dict], List[str]]:
    """Obtain the candidate arguments for each trigger in the event argument extraction (EAE) task.
    The unified evaluation considers prediction of each token that is possibly an argument (EAE candidate). And the EAE
    task requires the model to predict the argument role of each candidate given a specific trigger.
    Args:
        item (`Dict[str, Union[str, List[dict]]]`):
            A single item of the training/valid/test data.
        trigger (`Dict[str, Union[str, List[dict]]]`):
            A single item of trigger in the item.
    Returns:
        candidates (`List[dict]`), label_names (`List[str]`):
            candidates: A list of dictionaries that contain the possible arguments.
            label_names: A list of strings containing the ground truth label for each possible argument.
    """
    candidates = []
    positive_offsets = []
    label_names = []
    if "arguments" in trigger:
        arguments = sorted(trigger["arguments"], key=lambda a: a["role"])
        for argument in arguments:
            for mention in argument["mentions"]:
                label_names.append(argument["role"])
                candidates.append(mention)
                positive_offsets.append(mention["position"])

    neg_arg_candidates = get_negative_argument_candidates(item, positive_offsets=positive_offsets)

    for neg in neg_arg_candidates:
        is_argument = check_is_argument(neg, positive_offsets)
        if not is_argument:
            label_names.append("NA")
            candidates.append(neg)

    return candidates, label_names
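
A toy example, reduced to the fields this function reads; since the item carries no "entities", its negative triggers serve as the negative argument candidates:

>>> trigger = {"trigger_word": "shot", "position": [7, 11],
...            "arguments": [{"role": "target",
...                           "mentions": [{"mention": "He", "position": [0, 2]}]}]}
>>> item = {"text": "He was shot yesterday",
...         "negative_triggers": [{"trigger_word": "yesterday", "position": [12, 21]}]}
>>> candidates, label_names = get_eae_candidates(item, trigger)
>>> label_names
['target', 'NA']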

get_event_preds

Loads the event detection predictions of each token for event argument extraction. The Event Argument Extraction task requires the event detection predictions. If the event prediction file exists, we use the predictions by the event detection model. Otherwise, we use the golden event type for each token.

Args:

  • pred_file: The file that contains the event detection predictions for each token.

Returns:

  • event_preds: A list of the predicted event types for each token, or None if the prediction file does not exist.

def get_event_preds(pred_file: Union[str, Path]) -> Optional[List[str]]:
    """Load the event detection predictions of each token for event argument extraction.
    The Event Argument Extraction task requires the event detection predictions. If the event prediction file exists,
    we use the predictions by the event detection model. Otherwise, we use the golden event type for each token.
    Args:
        pred_file (`Union[str, Path]`):
            The file that contains the event detection predictions for each token.
    Returns:
        event_preds (`Optional[List[str]]`):
            A list of the predicted event types for each token, or `None` if the prediction file does not exist.
    """
    if pred_file is not None and os.path.exists(pred_file):
        event_preds = json.load(open(pred_file))
    else:
        event_preds = None
        logger.info("Load {} failed, using golden triggers for EAE evaluation".format(pred_file))

    return event_preds

get_plain_label

This function is used in the Seq2Seq paradigm, in which the model has to generate the event types and argument roles. Some event types and argument roles are formatted, such as Attack.Time-Start; we convert them into a plain form, like timestart, by removing the leading event type and shifting upper case to lower case.

def get_plain_label(input_label: str) -> str:
    """Convert the formatted original event type or argument role to a plain one.
    This function is used in the Seq2Seq paradigm, in which the model has to generate the event types and argument
    roles. Some event types and argument roles are formatted, such as `Attack.Time-Start`; we convert them into a
    plain form, like `timestart`, by removing the leading event type and shifting upper case to lower case.
    Args:
        input_label (`str`):
            The original label with format.
    Returns:
        return_label (`str`):
            The plain label without format.
    """
    if input_label == "NA":
        return input_label

    return_label = "".join("".join(input_label.split(".")[-1].split("-")).split("_")).lower()

    return return_label
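
For example:

>>> get_plain_label("Attack.Time-Start")
'timestart'
>>> get_plain_label("Life.Die")
'die'
>>> get_plain_label("NA")
'NA'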

str_full_to_half

Converts a full-width string to a half-width one. The corpora of some datasets contain full-width strings, which may cause unexpected errors when mapping the tokens back to the original input sentence.

Args:

  • ustring: Original string.

Returns:

  • rstring: Output string with the full-width tokens converted.

def str_full_to_half(ustring: str) -> str:
    """Convert a full-width string to a half-width one.
    The corpora of some datasets contain full-width strings, which may cause unexpected errors when mapping the
    tokens back to the original input sentence.
    Args:
        ustring (`str`):
            Original string.
    Returns:
        rstring (`str`):
            Output string with the full-width tokens converted.
    """
    rstring = ""
    for uchar in ustring:
        inside_code = ord(uchar)
        if inside_code == 12288:   # full width space
            inside_code = 32
        elif 65281 <= inside_code <= 65374:    # full width char (exclude space)
            inside_code -= 65248
        rstring += chr(inside_code)
    return rstring
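
For example, full-width letters and punctuation are mapped to their half-width counterparts:

>>> str_full_to_half("ＯｍｎｉＥｖｅｎｔ！")
'OmniEvent!'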

Backbone

import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import List, Optional, Tuple, Union

from transformers import BertModel, BertTokenizerFast
from transformers import RobertaModel, RobertaTokenizerFast
from transformers import T5ForConditionalGeneration, T5TokenizerFast
from transformers import MT5ForConditionalGeneration
from transformers import BartForConditionalGeneration, BartTokenizerFast
from transformers.utils import ModelOutput

from ..input_engineering.whitespace_tokenizer import WordLevelTokenizer, load_vocab, VOCAB_FILES_NAMES

get_backbone

Obtains the backbone model and tokenizer. The backbone model is selected from BERT, RoBERTa, BART, T5, MT5, CNN, and LSTM, each paired with its corresponding tokenizer.

Args:

  • model_type: A string indicating the model being used as the backbone network.

  • model_name_or_path: A string indicating the path of the pre-trained model.

  • tokenizer_name: A string indicating the repository name for the model in the hub or a path to a local folder.

  • markers: A list of strings to mark the start and end position of event triggers and argument mentions.

  • model_args: The pre-defined arguments for the model.

  • new_tokens: A list of strings indicating new tokens to be added to the tokenizer’s vocabulary.

Returns:

  • model: The backbone model, which is selected from BERT, RoBERTa, BART, T5, MT5, CNN, and LSTM.

  • tokenizer: The tokenizer corresponding to the backbone model.

  • config: The configurations of the model.

def get_backbone(model_type: str,
                 model_name_or_path: str,
                 tokenizer_name: str,
                 markers: List[str],
                 model_args: Optional = None,
                 new_tokens: Optional[List[str]] = []):
    """Obtains the backbone model and tokenizer.
    Obtains the backbone model and tokenizer. The backbone model is selected from BERT, RoBERTa, T5, MT5, CNN, and LSTM,
    corresponding to a distinct tokenizer.
    Args:
        model_type (`str`):
            A string indicating the model being used as the backbone network.
        model_name_or_path (`str`):
            A string indicating the path of the pre-trained model.
        tokenizer_name (`str`):
            A string indicating the repository name for the model in the hub or a path to a local folder.
        markers (`List[str]`):
            A list of strings to mark the start and end position of event triggers and argument mentions.
        model_args (`optional`, defaults to `None`):
            The pre-defined arguments for the model.  TODO: The data type of `model_args` should be configured.
        new_tokens (`List[str]`, `optional`, defaults to []):
            A list of strings indicating new tokens to be added to the tokenizer's vocabulary.
    Returns:
        model (`Union[BertModel, RobertaModel, T5ForConditionalGeneration, CNN, LSTM]`):
            The backbone model, which is selected from BERT, RoBERTa, T5, MT5, CNN, and LSTM.
        tokenizer:
            The tokenizer corresponding to the backbone model.
        config:
            The configurations of the model.         TODO: The data type of `config` should be configured.
    """
    if model_type == "bert":
        model = BertModel.from_pretrained(model_name_or_path)
        tokenizer = BertTokenizerFast.from_pretrained(tokenizer_name, never_split=markers)
    elif model_type == "roberta":
        model = RobertaModel.from_pretrained(model_name_or_path)
        tokenizer = RobertaTokenizerFast.from_pretrained(tokenizer_name, never_split=markers, add_prefix_space=True)
    elif model_type == "bart":
        model = BartForConditionalGeneration.from_pretrained(model_name_or_path)
        tokenizer = BartTokenizerFast.from_pretrained(tokenizer_name, never_split=markers, add_prefix_space=True)
    elif model_type == "t5":
        model = T5ForConditionalGeneration.from_pretrained(model_name_or_path)
        tokenizer = T5TokenizerFast.from_pretrained(tokenizer_name, never_split=markers)
    elif model_type == "mt5":
        model = MT5ForConditionalGeneration.from_pretrained(model_name_or_path)
        tokenizer = T5TokenizerFast.from_pretrained(tokenizer_name, never_split=markers)
    elif model_type == "cnn":
        tokenizer = WordLevelTokenizer.from_pretrained(model_args.vocab_file)
        model = CNN(model_args, len(tokenizer))
    elif model_type == 'lstm':
        tokenizer = WordLevelTokenizer.from_pretrained(model_args.vocab_file)
        model = LSTM(model_args, len(tokenizer))
    else:
        raise ValueError("No such model. %s" % model_type)

    for token in new_tokens:
        tokenizer.add_tokens(token, special_tokens=True)
    if len(new_tokens) > 0:
        model.resize_token_embeddings(len(tokenizer))

    config = model.config
    return model, tokenizer, config
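
For instance, a BERT backbone and its matching tokenizer can be obtained as follows (the checkpoint name and marker tokens here are illustrative):

>>> model, tokenizer, config = get_backbone(
...     model_type="bert",
...     model_name_or_path="bert-base-uncased",
...     tokenizer_name="bert-base-uncased",
...     markers=["<event>", "</event>"],
...     new_tokens=["<event>", "</event>"],
... )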

WordEmbedding

Base class for word embedding, in which the word embeddings are loaded from a pre-trained word embedding file and can be resized to a different vocabulary size.

Attributes:

  • word_embeddings: A tensor representing the word embedding matrix, whose dimension is (number of tokens) * (embedding dimension).

  • position_embeddings: A tensor representing the position embedding matrix, whose dimension is (number of positions) * (embedding dimension).

  • dropout: An nn.Dropout layer for the dropout operation with the pre-defined dropout rate.

class WordEmbedding(nn.Module):
    """Base class for word embedding.
    Base class for word embedding, in which the word embeddings are loaded from a pre-trained word embedding file and
    can be resized to a different vocabulary size.
    Attributes:
        word_embeddings (`torch.Tensor`):
            A tensor representing the word embedding matrix, whose dimension is (number of tokens) * (embedding
            dimension).
        position_embeddings (`torch.Tensor`):
            A tensor representing the position embedding matrix, whose dimension is (number of positions) * (embedding
            dimension).
        dropout (`nn.Dropout`):
            An `nn.Dropout` layer for the dropout operation with the pre-defined dropout rate.
    """
    def __init__(self,
                 config,
                 vocab_size: int) -> None:
        """Constructs a `WordEmbedding`."""
        super(WordEmbedding, self).__init__()
        if not os.path.exists(os.path.join(config.vocab_file, VOCAB_FILES_NAMES["vocab_file"].replace("txt", "npy"))):
            embeddings = load_vocab(os.path.join(config.vocab_file, VOCAB_FILES_NAMES["vocab_file"]),
                                    return_embeddings=True)
            np.save(os.path.join(config.vocab_file, VOCAB_FILES_NAMES["vocab_file"].replace("txt", "npy")), embeddings)
        else:
            embeddings = np.load(os.path.join(config.vocab_file, VOCAB_FILES_NAMES["vocab_file"].replace("txt", "npy")))
        self.word_embeddings = nn.Embedding.from_pretrained(torch.tensor(embeddings), freeze=False, padding_idx=0)
        self.position_embeddings = nn.Embedding(config.num_position_embeddings, config.position_embedding_dim)
        self.register_buffer("position_ids", torch.arange(config.num_position_embeddings).expand((1, -1)))
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.resize_token_embeddings(vocab_size)

    def resize_token_embeddings(self,
                                vocab_size: int) -> None:
        """Resizes the embeddings from the pre-trained embedding dimension to pre-defined embedding size."""
        if len(self.word_embeddings.weight) > vocab_size:
            raise ValueError("Invalid vocab_size %d < original vocab size." % vocab_size)
        elif len(self.word_embeddings.weight) == vocab_size:
            pass
        else:
            num_added_token = vocab_size - len(self.word_embeddings.weight)
            embedding_dim = self.word_embeddings.weight.shape[1]
            average_embedding = torch.mean(self.word_embeddings.weight, dim=0).expand(1, -1)
            self.word_embeddings.weight = nn.Parameter(torch.cat(
                (
                    self.word_embeddings.weight.data,
                    average_embedding.expand(num_added_token, embedding_dim)
                )
            ))

    def forward(self,
                input_ids: torch.Tensor,
                position_ids: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Generates word embeddings and position embeddings and concatenates them together."""
        input_shape = input_ids.size()
        batch_size, seq_length = input_shape[0], input_shape[1]
        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length].expand(batch_size, seq_length)
        # input embeddings & position embeddings
        inputs_embeds = self.word_embeddings(input_ids)
        position_embeds = self.position_embeddings(position_ids)
        embeds = torch.cat((inputs_embeds, position_embeds), dim=-1)
        embeds = self.dropout(embeds)
        return embeds

Output

A class for the model’s output, containing the hidden states of the sequence.

class Output(ModelOutput):
    """A class for the model's output, containing the hidden states of the sequence."""
    last_hidden_state: torch.Tensor = None

CNN

A Convolutional Neural Network (CNN) as the backbone model, which comprises a 1-d convolutional layer, a ReLU activation layer, and a dropout layer. The last hidden state of the model is returned.

Attributes:

  • config: The configurations of the model.

  • embedding: A WordEmbedding instance representing the embedding matrices of tokens and positions.

  • conv: A nn.Conv1d layer representing 1-dimensional convolution layer.

  • dropout: An nn.Dropout layer for the dropout operation with the pre-defined dropout rate.

class CNN(nn.Module):
    """A Convolutional Neural Network (CNN) as backbone model.
    A Convolutional Neural Network (CNN) as the backbone model, which comprises a 1-d convolutional layer, a ReLU
    activation layer, and a dropout layer. The last hidden state of the model is returned.
    Attributes:
        config:
            The configurations of the model.
        embedding (`WordEmbedding`):
            A `WordEmbedding` instance representing the embedding matrices of tokens and positions.
        conv (`nn.Conv1d`):
            A `nn.Conv1d` layer representing 1-dimensional convolution layer.
        dropout (`nn.Dropout`):
            An `nn.Dropout` layer for the dropout operation with the pre-defined dropout rate.
    """
    def __init__(self,
                 config,
                 vocab_size: int,
                 kernel_size: Optional[int] = 3,
                 padding_size: Optional[int] = 1) -> None:
        """Constructs a `CNN`."""
        super(CNN, self).__init__()
        self.config = config
        self.embedding = WordEmbedding(config, vocab_size)
        self.conv = nn.Conv1d(config.word_embedding_dim + config.position_embedding_dim,
                              config.hidden_size,
                              kernel_size,
                              padding=padding_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def resize_token_embeddings(self,
                                vocab_size: int) -> None:
        """Resizes the embeddings from the pre-trained embedding dimension to pre-defined embedding size."""
        self.embedding.resize_token_embeddings(vocab_size)

    def forward(self,
                input_ids: torch.Tensor,
                attention_mask: torch.Tensor,
                token_type_ids: torch.Tensor,
                return_dict: Optional[bool] = True) -> Union[Output, Tuple[torch.Tensor]]:
        """Conducts the convolution operations on the input tokens."""
        x = self.embedding(input_ids)  # (B, L, H)
        x = x.transpose(1, 2)  # (B, H, L)
        x = F.relu(self.conv(x).transpose(1, 2))  # (B, L, H)
        x = self.dropout(x)
        if return_dict:
            return Output(last_hidden_state=x)
        else:
            return x

LSTM

A bidirectional two-layered Long Short-Term Memory (LSTM) network as the backbone model, which utilizes recurrent computations for hidden states and addresses long-term information preservation and short-term input skipping using gated memory cells.

Attributes:

  • config: The configurations of the model.

  • embedding: A WordEmbedding instance representing the embedding matrices of tokens and positions.

  • rnn: A nn.LSTM layer representing a bi-directional two-layered LSTM network, which manipulates the word embedding and position embedding for recurrent computations.

  • dropout: An nn.Dropout layer for the dropout operation with the pre-defined dropout rate.

class LSTM(nn.Module):
    """A Long Short-Term Memory (LSTM) network as backbone model.
    A bidirectional two-layered Long Short-Term Memory (LSTM) network as the backbone model, which utilizes recurrent
    computations for hidden states and addresses long-term information preservation and short-term input skipping
    using gated memory cells.
    Attributes:
        config:
            The configurations of the model.
        embedding (`WordEmbedding`):
            A `WordEmbedding` instance representing the embedding matrices of tokens and positions.
        rnn (`nn.LSTM`):
            A `nn.LSTM` layer representing a bi-directional two-layered LSTM network, which manipulates the word
            embedding and position embedding for recurrent computations.
        dropout (`nn.Dropout`):
            An `nn.Dropout` layer for the dropout operation with the pre-defined dropout rate.
       """
    def __init__(self,
                 config,
                 vocab_size: int) -> None:
        """Constructs a `LSTM`."""
        super(LSTM, self).__init__()
        self.config = config
        self.embedding = WordEmbedding(config, vocab_size)
        self.rnn = nn.LSTM(config.word_embedding_dim + config.position_embedding_dim,
                           config.hidden_size,
                           num_layers=2,
                           bidirectional=True,
                           batch_first=True,
                           dropout=config.hidden_dropout_prob)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def resize_token_embeddings(self,
                                vocab_size: int) -> None:
        """Resizes the embeddings from the pre-trained embedding dimension to pre-defined embedding size."""
        self.embedding.resize_token_embeddings(vocab_size)

    def prepare_pack_padded_sequence(self,
                                     input_ids: torch.Tensor,
                                     input_lengths: torch.Tensor,
                                     descending: Optional[bool] = True):
        """Sorts the input sequences based on their length."""
        sorted_input_lengths, indices = torch.sort(input_lengths, descending=descending)
        _, desorted_indices = torch.sort(indices, descending=False)
        sorted_input_ids = input_ids[indices]
        return sorted_input_ids, sorted_input_lengths, desorted_indices

    def forward(self,
                input_ids: torch.Tensor,
                attention_mask: torch.Tensor,
                token_type_ids: torch.Tensor,
                return_dict: Optional[bool] = True):
        """Forward propagation of a LSTM network."""
        # add a pseudo input of max_length
        add_pseudo = max(torch.sum(attention_mask, dim=-1).tolist()) != input_ids.shape[1]
        if add_pseudo:
            input_ids = torch.cat((torch.zeros_like(input_ids[0]).unsqueeze(0), input_ids), dim=0)
            attention_mask = torch.cat((torch.ones_like(attention_mask[0]).unsqueeze(0), attention_mask), dim=0)
        input_length = torch.sum(attention_mask, dim=-1).to(torch.long)
        sorted_input_ids, sorted_seq_length, desorted_indices = self.prepare_pack_padded_sequence(input_ids,
                                                                                                  input_length)
        x = self.embedding(sorted_input_ids)  # (B, L, H)
        packed_embedded = nn.utils.rnn.pack_padded_sequence(x, sorted_seq_length.cpu(), batch_first=True)
        self.rnn.flatten_parameters()
        packed_output, (hidden, cell) = self.rnn(packed_embedded)
        output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
        x = output[desorted_indices]
        if add_pseudo:
            x = self.dropout(x)[1:, :, :]  # remove the pseudo input
        else:
            x = self.dropout(x)

        if return_dict:
            return Output(
                last_hidden_state=x
            )
        else:
            return (x)
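
The sort/pack/unsort trick used in forward() can be illustrated with a self-contained sketch (all shapes and values below are illustrative):

import torch
import torch.nn as nn

lengths = torch.tensor([3, 5, 2])
x = torch.randn(3, 5, 8)  # (batch_size, max_length, feature_dim)
# sort by length, as required by pack_padded_sequence, and remember how to undo the sort
sorted_lengths, indices = torch.sort(lengths, descending=True)
_, desorted_indices = torch.sort(indices)
packed = nn.utils.rnn.pack_padded_sequence(x[indices], sorted_lengths, batch_first=True)
rnn = nn.LSTM(input_size=8, hidden_size=4, batch_first=True)
packed_output, _ = rnn(packed)
output, _ = nn.utils.rnn.pad_packed_sequence(packed_output, batch_first=True)
output = output[desorted_indices]  # restore the original batch order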

Model

import os
import torch
import torch.nn as nn

from typing import Dict, Optional, Union
from transformers import BartForConditionalGeneration, MT5ForConditionalGeneration, T5ForConditionalGeneration

from OmniEvent.aggregation.aggregation import get_aggregation, aggregate
from OmniEvent.head.head import get_head
from OmniEvent.head.classification import LinearHead
from OmniEvent.arguments import (
    ModelArguments,
    DataArguments,
    TrainingArguments,
    ArgumentParser
)
from OmniEvent.utils import check_web_and_convert_path

get_model

Returns the model proposed to be utilized for training and prediction based on the pre-defined paradigm. The paradigms of training and prediction include token classification, sequence labeling, Sequence-to-Sequence (Seq2Seq), and Machine Reading Comprehension (MRC).

Args:

  • model_args: The arguments of the model for training and prediction.

  • backbone: The backbone model obtained from the get_backbone() method.

Returns:

  • The model method/class proposed to be utilized for training and prediction.

def get_model(model_args,
              backbone):
    """Returns the model proposed to be utilized for training and prediction.
    Returns the model proposed to be utilized for training and prediction based on the pre-defined paradigm. The
    paradigms of training and prediction include token classification, sequence labeling, Sequence-to-Sequence
    (Seq2Seq), and Machine Reading Comprehension (MRC).
    Args:
        model_args:
            The arguments of the model for training and prediction.
        backbone:
            The backbone model obtained from the `get_backbone()` method.
    Returns:
        The model method/class proposed to be utilized for training and prediction.
    """
    if model_args.paradigm == "token_classification":
        return ModelForTokenClassification(model_args, backbone)
    elif model_args.paradigm == "sequence_labeling":
        return ModelForSequenceLabeling(model_args, backbone)
    elif model_args.paradigm == "seq2seq":
        return backbone
    elif model_args.paradigm == "mrc":
        return ModelForMRC(model_args, backbone)
    else:
        raise ValueError("No such paradigm")

get_model_cls

Returns the model class to be utilized for training and prediction based on the pre-defined paradigm, without instantiating it. For the Seq2Seq paradigm, the concrete class depends on the model type (BART, T5, or MT5).

def get_model_cls(model_args):
    if model_args.paradigm == "token_classification":
        return ModelForTokenClassification
    elif model_args.paradigm == "sequence_labeling":
        return ModelForSequenceLabeling
    elif model_args.paradigm == "seq2seq":
        if model_args.model_type == "bart":
            return BartForConditionalGeneration
        elif model_args.model_type == "t5":
            return T5ForConditionalGeneration
        elif model_args.model_type == "mt5":
            return MT5ForConditionalGeneration
        else:
            raise ValueError("Invalid model_type %s" % model_args.model_type)
    elif model_args.paradigm == "mrc":
        return ModelForMRC
    else:
        raise ValueError("No such paradigm")

BaseModel

class BaseModel(nn.Module):

    @classmethod
    def from_pretrained(cls, model_name_or_path: Union[str, os.PathLike], config=None, **kwargs):
        if config is None:
            parser = ArgumentParser((ModelArguments, DataArguments, TrainingArguments))
            model_args, _, _ = parser.from_pretrained(model_name_or_path, **kwargs)
        path = check_web_and_convert_path(model_name_or_path, 'model')
        model = get_model(model_args)
        model.load_state_dict(torch.load(path), strict=False)
        return model

ModelForTokenClassification

BERT model for token classification, which first obtains hidden states through the backbone model, then aggregates the hidden states through the aggregation method/class, and finally classifies each token into its corresponding label through a token-wise linear transformation.

Attributes:

  • config: The configurations of the model.

  • backbone: The backbone network obtained from the get_backbone() method to output initial hidden states.

  • aggregation: The aggregation method/class for aggregating the hidden states output by the backbone network.

  • cls_head: A ClassificationHead instance classifying each token into its corresponding label through a token-wise linear transformation.

class ModelForTokenClassification(BaseModel):
    """BERT model for token classification.
    BERT model for token classification, which first obtains hidden states through the backbone model, then aggregates
    the hidden states through the aggregation method/class, and finally classifies each token into its corresponding
    label through a token-wise linear transformation.
    Attributes:
        config:
            The configurations of the model.
        backbone:
            The backbone network obtained from the `get_backbone()` method to output initial hidden states.
        aggregation:
            The aggregation method/class for aggregating the hidden states output by the backbone network.
        cls_head (`ClassificationHead`):
            A `ClassificationHead` instance classifying each token into its corresponding label through a token-wise
            linear transformation.
    """

    def __init__(self,
                 config,
                 backbone) -> None:
        """Constructs a `ModelForTokenClassification`."""
        super(ModelForTokenClassification, self).__init__()
        self.config = config
        self.backbone = backbone
        self.aggregation = get_aggregation(config)
        self.cls_head = get_head(config)

    def forward(self,
                input_ids: torch.Tensor,
                attention_mask: torch.Tensor,
                token_type_ids: Optional[torch.Tensor] = None,
                trigger_left: Optional[torch.Tensor] = None,
                trigger_right: Optional[torch.Tensor] = None,
                argument_left: Optional[torch.Tensor] = None,
                argument_right: Optional[torch.Tensor] = None,
                labels: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """Manipulates the inputs through a backbone, aggregation, and classification module,
           returns the predicted logits and loss."""
        # backbone encode
        outputs = self.backbone(input_ids=input_ids,
                                attention_mask=attention_mask,
                                token_type_ids=token_type_ids,
                                return_dict=True)
        hidden_states = outputs.last_hidden_state
        # aggregation
        hidden_state = aggregate(self.config,
                                 self.aggregation,
                                 hidden_states,
                                 trigger_left,
                                 trigger_right,
                                 argument_left,
                                 argument_right)
        # classification
        logits = self.cls_head(hidden_state)
        # compute loss
        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
        return dict(loss=loss, logits=logits)

ModelForSequenceLabeling

BERT model for sequence labeling, which first obtains hidden states through the backbone model, then classifies each token into its corresponding label, and finally decodes the labels through a Conditional Random Field (CRF) module.

Attributes:

  • config: The configurations of the model.

  • backbone: The backbone network obtained from the get_backbone() method to output initial hidden states.

  • cls_head: A ClassificationHead instance classifying each token into its corresponding label through a token-wise linear transformation.

class ModelForSequenceLabeling(BaseModel):
    """BERT model for sequence labeling.
    BERT model for sequence labeling, which first obtains hidden states through the backbone model, then classifies
    each token into its corresponding label, and finally decodes the labels through a Conditional Random Field (CRF)
    module.
    Attributes:
        config:
            The configurations of the model.
        backbone:
            The backbone network obtained from the `get_backbone()` method to output initial hidden states.
        cls_head (`ClassificationHead`):
            A `ClassificationHead` instance classifying each token into its corresponding label through a token-wise
            linear transformation.
    """

    def __init__(self,
                 config,
                 backbone) -> None:
        """Constructs a `ModelForSequenceLabeling`."""
        super(ModelForSequenceLabeling, self).__init__()
        self.config = config
        self.backbone = backbone
        self.cls_head = LinearHead(config)
        self.head = get_head(config)

    def forward(self,
                input_ids: torch.Tensor,
                attention_mask: torch.Tensor,
                token_type_ids: Optional[torch.Tensor] = None,
                labels: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """Manipulates the inputs through a backbone, classification, and CRF module,
           returns the predicted logits and loss."""
        # backbone encode
        outputs = self.backbone(input_ids=input_ids,
                                attention_mask=attention_mask,
                                token_type_ids=token_type_ids,
                                return_dict=True)
        hidden_states = outputs.last_hidden_state
        # classification
        logits = self.cls_head(hidden_states)  # [batch_size, seq_length, num_labels]
        # compute loss
        loss = None
        if labels is not None:
            if self.config.head_type != "crf":
                loss_fn = nn.CrossEntropyLoss()
                loss = loss_fn(logits.reshape(-1, logits.shape[-1]), labels.reshape(-1))
            else:
                # CRF
                labels[:, 0] = 0
                mask = labels != -100
                tags = labels * mask.to(torch.long)
                loss = -self.head(emissions=logits,
                                  tags=tags,
                                  mask=mask,
                                  reduction="token_mean")
                labels[:, 0] = -100
        else:
            if self.config.head_type == "crf":
                mask = torch.ones_like(logits[:, :, 0])
                preds = self.head.decode(emissions=logits, mask=mask)
                logits = torch.LongTensor(preds)

        return dict(loss=loss, logits=logits)

ModelForMRC

BERT model for Machine Reading Comprehension (MRC), which first obtains hidden states through the backbone model, then predicts the start and end logits of each mention type through an MRC head.

Attributes:

  • config: The configurations of the model.

  • backbone: The backbone network obtained from the get_backbone() method to output initial hidden states.

  • mrc_head: An MRCHead instance classifying the hidden states into start and end logits of each mention type through token-wise linear transformations.

class ModelForMRC(BaseModel):
    """BERT Model for Machine Reading Comprehension (MRC).
    BERT model for Machine Reading Comprehension (MRC), which first obtains hidden states through the backbone model,
    then predicts the start and end logits of each mention type through an MRC head.
    Attributes:
        config:
            The configurations of the model.
        backbone:
            The backbone network obtained from the `get_backbone()` method to output initial hidden states.
        mrc_head (`MRCHead`):
            An `MRCHead` instance classifying the hidden states into start and end logits of each mention type
            through token-wise linear transformations.
    """

    def __init__(self,
                 config,
                 backbone) -> None:
        """Constructs a `ModelForMRC`."""
        super(ModelForMRC, self).__init__()
        self.backbone = backbone
        self.mrc_head = get_head(config)

    def forward(self,
                input_ids: torch.Tensor,
                attention_mask: torch.Tensor,
                token_type_ids: Optional[torch.Tensor] = None,
                argument_left: Optional[torch.Tensor] = None,
                argument_right: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """Manipulates the inputs through a backbone and a MRC head module,
           returns the predicted start and logits and loss."""
        # backbone encode
        outputs = self.backbone(input_ids=input_ids,
                                attention_mask=attention_mask,
                                token_type_ids=token_type_ids,
                                return_dict=True)
        hidden_states = outputs.last_hidden_state
        start_logits, end_logits = self.mrc_head(hidden_states)
        total_loss = None
        if argument_left is not None and argument_right is not None:
            # If we are on multi-GPU, the positions may carry an extra dimension; squeeze it
            if len(argument_left.size()) > 1:
                argument_left = argument_left.squeeze(-1)
            if len(argument_right.size()) > 1:
                argument_right = argument_right.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            argument_left = argument_left.clamp(0, ignored_index)
            argument_right = argument_right.clamp(0, ignored_index)

            loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, argument_left)
            end_loss = loss_fct(end_logits, argument_right)
            total_loss = (start_loss + end_loss) / 2

        logits = torch.cat((start_logits, end_logits), dim=-1)  # [batch_size, seq_length*2]
        return dict(loss=total_loss, logits=logits)

Label Smoother Sum

Note

Copyright Text2Event from https://github.com/luyaojie/Text2Event.

Licensed under the MIT License.

#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Copyright Text2Event from https://github.com/luyaojie/Text2Event.
# Licensed under the MIT License.

import torch

from dataclasses import dataclass
from typing import Dict

SumLabelSmoother

A label-smoothing sum module operated on the pre-computed output from the model. Label smoothing is a regularization technique that addresses overfitting and overconfidence by adding noise to down-weight the gold labels when computing the loss.

Attributes:

  • epsilon: A float variable indicating the label smoothing factor.

  • ignore_index: An integer representing the index in the labels to ignore when computing the loss.

@dataclass
class SumLabelSmoother:
    """A label-smoothing sum module operated on the pre-computed output from the model.
    A label-smoothing sum module operated on the pre-computed output from the model. Label smoothing is a
    regularization technique that addresses overfitting and overconfidence by adding noise to down-weight the gold
    labels when computing the loss.
    Attributes:
        epsilon (`float`, `optional`, defaults to 0.1):
            A float variable indicating the label smoothing factor.
        ignore_index (`int`, `optional`, defaults to -100):
            An integer representing the index in the labels to ignore when computing the loss.
    """

    epsilon: float = 0.1
    ignore_index: int = -100

    def __call__(self,
                 model_output: Dict[str, torch.Tensor],
                 labels: torch.Tensor) -> float:
        """Conducts the label smoothing process."""
        logits = model_output["logits"] if isinstance(model_output, dict) else model_output[0]
        log_probs = -torch.nn.functional.log_softmax(logits, dim=-1)
        if labels.dim() == log_probs.dim() - 1:
            labels = labels.unsqueeze(-1)

        padding_mask = labels.eq(self.ignore_index)
        # In case the ignore_index is -100, the gather will fail, so we replace labels by 0. The padding_mask
        # will ignore them in any case.
        labels.clamp_min_(0)
        nll_loss = log_probs.gather(dim=-1, index=labels)
        smoothed_loss = log_probs.sum(dim=-1, keepdim=True)

        nll_loss.masked_fill_(padding_mask, 0.0)
        smoothed_loss.masked_fill_(padding_mask, 0.0)

        # Take the mean over the label dimensions, then divide by the number of active elements (i.e. not-padded):
        # num_active_elements = padding_mask.numel() - padding_mask.long().sum()
        nll_loss = nll_loss.sum()  # / num_active_elements
        smoothed_loss = smoothed_loss.sum()  # / (num_active_elements * log_probs.shape[-1])
        eps_i = self.epsilon / log_probs.size(-1)
        return (1 - self.epsilon) * nll_loss + eps_i * smoothed_loss
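
A minimal sketch of applying the smoother to dummy model outputs (shapes and values are illustrative; -100 marks padded label positions):

>>> import torch
>>> smoother = SumLabelSmoother(epsilon=0.1)
>>> logits = torch.randn(2, 4, 10)  # (batch_size, seq_length, vocab_size)
>>> labels = torch.tensor([[1, 2, 3, -100], [4, 5, -100, -100]])
>>> loss = smoother({"logits": logits}, labels)  # a scalar training loss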

Constraint Decoding

Note

Copyright Text2Event from https://github.com/luyaojie/Text2Event.

Licensed under the MIT License.

#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Copyright Text2Event from https://github.com/luyaojie/Text2Event.
# Licensed under the MIT License.

from typing import List, Dict
import os
import re
import sys
sys.path.append("../")

# debug = True if 'DEBUG' in os.environ else False
# debug_step = True if 'DEBUG_STEP' in os.environ else False
debug = False
debug_step = False

from ..input_engineering.seq2seq_processor import type_start, type_end


def get_label_name_tree(label_name_list, tokenizer, end_symbol='<end>'):
    sub_token_tree = dict()

    label_tree = dict()
    for typename in label_name_list:
        after_tokenized = tokenizer.encode(typename, add_special_tokens=False)
        label_tree[typename] = after_tokenized

    for _, sub_label_seq in label_tree.items():
        parent = sub_token_tree
        for value in sub_label_seq:
            if value not in parent:
                parent[value] = dict()
            parent = parent[value]

        parent[end_symbol] = None

    return sub_token_tree
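
A toy run with a stub tokenizer (purely illustrative, mapping each character to one sub-token id) shows the shape of the resulting prefix tree:

class StubTokenizer:
    def encode(self, text, add_special_tokens=False):
        return [ord(char) for char in text]

tree = get_label_name_tree(["ab", "ac"], StubTokenizer())
# tree == {97: {98: {'<end>': None}, 99: {'<end>': None}}}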


def match_sublist(the_list, to_match):
    """
    :param the_list: [1, 2, 3, 4, 5, 6, 1, 2, 4, 5]
    :param to_match: [1, 2]
    :return:
        [(0, 1), (6, 7)]
    """
    len_to_match = len(to_match)
    matched_list = list()
    for index in range(len(the_list) - len_to_match + 1):
        if to_match == the_list[index:index + len_to_match]:
            matched_list += [(index, index + len_to_match - 1)]
    return matched_list


def find_bracket_position(generated_text, _type_start, _type_end):
    bracket_position = {_type_start: list(), _type_end: list()}
    for index, char in enumerate(generated_text):
        if char in bracket_position:
            bracket_position[char] += [index]
    return bracket_position


def generated_search_src_sequence(generated, src_sequence, end_sequence_search_tokens=None):
    # print(generated, src_sequence) if debug else None

    if len(generated) == 0:
        # It has not been generated yet. All SRC are valid.
        return src_sequence

    matched_tuples = match_sublist(the_list=src_sequence, to_match=generated)

    valid_token = list()
    for _, end in matched_tuples:
        next_index = end + 1
        if next_index < len(src_sequence):
            valid_token += [src_sequence[next_index]]

    if end_sequence_search_tokens:
        valid_token += end_sequence_search_tokens

    return valid_token


def get_constraint_decoder(tokenizer, type_schema, source_prefix=None):
    return StruConstraintDecoder(tokenizer=tokenizer, type_schema=type_schema, source_prefix=source_prefix)


class ConstraintDecoder:
    def __init__(self, tokenizer, source_prefix):
        self.tokenizer = tokenizer
        self.source_prefix = source_prefix
        self.source_prefix_tokenized = tokenizer.encode(source_prefix,
                                                        add_special_tokens=False) if source_prefix else []

    def get_state_valid_tokens(self, src_sentence: List[str], tgt_generated: List[str]) -> List[str]:
        pass

    def constraint_decoding(self, batch_id, src_sentence, tgt_generated):
        if self.source_prefix_tokenized:
            # Remove Source Prefix for Generation
            src_sentence = src_sentence[len(self.source_prefix_tokenized):]

        if debug:
            # if batch_id == 4:
            print("Src:", self.tokenizer.convert_ids_to_tokens(src_sentence))
            print("Tgt:", self.tokenizer.convert_ids_to_tokens(tgt_generated))
            print(batch_id, len(tgt_generated), tgt_generated)

        valid_token_ids = self.get_state_valid_tokens(
            src_sentence.tolist(),
            tgt_generated.tolist()
        )

        # if debug:
        #     print('========================================')
        #     print('valid tokens:', self.tokenizer.convert_ids_to_tokens(
        #         valid_token_ids), valid_token_ids)
        #     if debug_step:
        #         input()

        # return self.tokenizer.convert_tokens_to_ids(valid_tokens)
        return valid_token_ids


class StruConstraintDecoder(ConstraintDecoder):
    def __init__(self, tokenizer, type_schema, *args, **kwargs):
        super().__init__(tokenizer, *args, **kwargs)
        self.tree_end = '<tree-end>'
        self.type_tree = get_label_name_tree(type_schema["role_list"],
                                             tokenizer=self.tokenizer,
                                             end_symbol=self.tree_end)
        self.type_start = self.tokenizer.convert_tokens_to_ids([type_start])[0]
        self.type_end = self.tokenizer.convert_tokens_to_ids([type_end])[0]

    def check_state(self, tgt_generated):
        if tgt_generated[-1] == self.tokenizer.pad_token_id:
            return 'start', -1

        special_token_set = {self.type_start, self.type_end}
        special_index_token = list(
            filter(lambda x: x[1] in special_token_set, list(enumerate(tgt_generated))))

        last_special_index, last_special_token = special_index_token[-1]

        if len(special_index_token) == 1:
            if last_special_token != self.type_start:
                return 'error', 0

        bracket_position = find_bracket_position(
            tgt_generated, _type_start=self.type_start, _type_end=self.type_end)
        start_number, end_number = len(bracket_position[self.type_start]), len(
            bracket_position[self.type_end])

        if start_number == end_number:
            return 'end_generate', -1
        if start_number == end_number + 1:
            state = 'start_first_generation'
        elif start_number == end_number + 2:
            state = 'generate_span'
        else:
            state = 'error'
        return state, last_special_index

    def search_prefix_tree_and_sequence(self, generated: List[str], prefix_tree: Dict, src_sentence: List[str],
                                        end_sequence_search_tokens: List[str] = None):
        """
        Generate Type Name + Text Span
        :param generated:
        :param prefix_tree:
        :param src_sentence:
        :param end_sequence_search_tokens:
        :return:
        """
        tree = prefix_tree
        for index, token in enumerate(generated):
            tree = tree[token]
            is_tree_end = len(tree) == 1 and self.tree_end in tree

            if is_tree_end:
                valid_token = generated_search_src_sequence(
                    generated=generated[index + 1:],
                    src_sequence=src_sentence,
                    end_sequence_search_tokens=end_sequence_search_tokens,
                )
                return valid_token

            if self.tree_end in tree:
                try:
                    valid_token = generated_search_src_sequence(
                        generated=generated[index + 1:],
                        src_sequence=src_sentence,
                        end_sequence_search_tokens=end_sequence_search_tokens,
                    )
                    return valid_token
                except IndexError:
                    # Still search tree
                    continue

        valid_token = list(tree.keys())
        return valid_token

    def get_state_valid_tokens(self, src_sentence, tgt_generated):
        """
        :param src_sentence:
        :param tgt_generated:
        :return:
            List[str], valid token list
        """
        if self.tokenizer.eos_token_id in src_sentence:
            src_sentence = src_sentence[:src_sentence.index(
                self.tokenizer.eos_token_id)]

        state, index = self.check_state(tgt_generated)

        print("State: %s" % state) if debug else None

        if state == 'error':
            print("Error:")
            print("Src:", src_sentence)
            print("Tgt:", tgt_generated)
            valid_tokens = [self.tokenizer.eos_token_id]

        elif state == 'start':
            valid_tokens = [self.type_start]

        elif state == 'start_first_generation':
            valid_tokens = [self.type_start, self.type_end]

        elif state == 'generate_span':

            if tgt_generated[-1] == self.type_start:
                # Start Event Label
                return list(self.type_tree.keys())

            elif tgt_generated[-1] == self.type_end:
                raise RuntimeError('Invalid %s in %s' %
                                   (self.type_end, tgt_generated))

            else:
                try:
                    valid_tokens = self.search_prefix_tree_and_sequence(
                        generated=tgt_generated[index + 1:],
                        prefix_tree=self.type_tree,
                        src_sentence=src_sentence,
                        end_sequence_search_tokens=[self.type_end]
                    )
                except:
                    print("Warning! An unexpected token is generated due to len(valid_tokens) < num_beams.")
                    valid_tokens = [self.tokenizer.eos_token_id]

        elif state == 'end_generate':
            valid_tokens = [self.tokenizer.eos_token_id]

        else:
            raise NotImplementedError(
                'State `%s` for %s is not implemented.' % (state, self.__class__))

        print("Valid: %s" % valid_tokens) if debug else None
        return valid_tokens


class SpanConstraintDecoder(ConstraintDecoder):
    def __init__(self, tokenizer, type_schema, *args, **kwargs):
        super().__init__(tokenizer, *args, **kwargs)
        self.tree_end = '<tree-end>'
        self.type_tree = get_label_name_tree(type_schema["role_list"],
                                             tokenizer=self.tokenizer,
                                             end_symbol=self.tree_end)

    def check_state(self, tgt_generated, special_tokens_in_tgt):
        if tgt_generated[-1] == self.tokenizer.pad_token_id:
            return 'start', -1
        else:
            index = len(tgt_generated)
            for i, token in enumerate(tgt_generated):
                if token == special_tokens_in_tgt[-1]:
                    index = i+1
                    break
            return "generate", index

    def get_special_tokens(self, sentence):
        special_template = re.compile(r"<extra_id_\d+>")
        tokens = self.tokenizer.convert_ids_to_tokens(sentence)
        special_tokens = []
        for token in tokens:
            if special_template.match(token) is not None:
                special_tokens.append(token)
        return self.tokenizer.convert_tokens_to_ids(special_tokens)

    def truncate_src(self, src_sentence):
        special_template = re.compile(r"<extra_id_\d+>")
        index = len(src_sentence)
        tokens = self.tokenizer.convert_ids_to_tokens(src_sentence)
        for i, token in enumerate(tokens):
            if special_template.match(token) is not None:
                index = i
                break
        return src_sentence[:index]

    def get_state_valid_tokens(self, src_sentence, tgt_generated):
        """
        :param src_sentence:
        :param tgt_generated:
        :return:
            List[str], valid token list
        """
        if self.tokenizer.eos_token_id in src_sentence:
            src_sentence = src_sentence[:src_sentence.index(
                self.tokenizer.eos_token_id)]

        special_tokens_in_src = self.get_special_tokens(src_sentence)
        special_tokens_in_gen = self.get_special_tokens(tgt_generated)

        # truncate
        src_sentence = self.truncate_src(src_sentence)

        state, index = self.check_state(tgt_generated, special_tokens_in_gen)


        if state == 'start':
            valid_tokens = [special_tokens_in_src[0]]

        elif state == 'generate':
            # valid_tokens = [self.type_start, self.type_end]
            valid_special_tokens = [self.tokenizer.convert_tokens_to_ids("[SEP]")]
            for token in special_tokens_in_src:
                if token not in special_tokens_in_gen:
                    valid_special_tokens.append(token)
            valid_tokens = generated_search_src_sequence(
                        generated=tgt_generated[index:],
                        src_sequence=src_sentence,
                        end_sequence_search_tokens=[self.tokenizer.eos_token_id],
                    )
            valid_tokens = valid_special_tokens + valid_tokens
        else:
            raise NotImplementedError(
                'State `%s` for %s is not implemented.' % (state, self.__class__))

        print("Valid: %s" % valid_tokens) if debug else None
        return valid_tokens
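
In practice, a constraint decoder of this kind is typically plugged into 🤗 Transformers' generate() through its prefix_allowed_tokens_fn argument. A hedged sketch (where model, tokenizer, type_schema, and batch are assumed to be the trained seq2seq model, its tokenizer, the event schema dictionary, and a batch of encoded inputs):

constraint_decoder = get_constraint_decoder(tokenizer, type_schema, source_prefix=None)

def prefix_allowed_tokens_fn(batch_id, generated_ids):
    # look up the source sequence of this batch element so that only
    # schema tokens and source tokens can be generated
    src_sentence = batch["input_ids"][batch_id]
    return constraint_decoder.constraint_decoding(batch_id, src_sentence, generated_ids)

outputs = model.generate(**batch, prefix_allowed_tokens_fn=prefix_allowed_tokens_fn)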

Aggregation

import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Optional

get_aggregation

Obtains the aggregation method to be utilized based on the model's configurations. The aggregation methods include selecting the <cls> tokens' representations, selecting the markers' representations, max-pooling, and dynamic multi-pooling.

Args:

  • config: The configurations of the model.

Returns:

  • The proposed method/class for the aggregation process.

def get_aggregation(config):
    """Obtains the aggregation method to be utilized.
    Obtains the aggregation method to be utilized based on the model's configurations. The aggregation methods include
    selecting the `<cls>`s' representations, selecting the markers' representations, max-pooling, and dynamic
    multi-pooling.
    Args:
        config:
            The configurations of the model.
    Returns:
        The proposed method/class for the aggregation process.
    """
    if config.aggregation == "cls":
        return select_cls
    elif config.aggregation == "marker":
        return select_marker
    elif config.aggregation == "dynamic_pooling":
        return DynamicPooling(config)
    elif config.aggregation == "max_pooling":
        return max_pooling
    else:
        raise ValueError("Invaild %s aggregation method" % config.aggregation)

aggregate

Aggregates information to each position. The aggregation methods include selecting the <cls> tokens' representations, selecting the markers' representations, max-pooling, and dynamic multi-pooling.

Args:

  • config: The configurations of the model.

  • method: The method proposed to be utilized in the aggregation process.

  • hidden_states: A tensor representing the hidden states output by the backbone model.

  • trigger_left: A tensor indicating the left position of the triggers.

  • trigger_right: A tensor indicating the right position of the triggers.

  • argument_left: A tensor indicating the left position of the arguments.

  • argument_right: A tensor indicating the right position of the arguments.

def aggregate(config,
              method,
              hidden_states: torch.Tensor,
              trigger_left: torch.Tensor,
              trigger_right: torch.Tensor,
              argument_left: torch.Tensor,
              argument_right: torch.Tensor):
    """Aggregates information to each position.
    Aggregates information to each position. The aggregation methods include selecting the `<cls>` tokens'
    representations, selecting the markers' representations, max-pooling, and dynamic multi-pooling.
    Args:
        config:
            The configurations of the model.
        method:
            The method proposed to be utilized in the aggregation process.
            TODO: The data type of the variable `method` should be configured.
        hidden_states (`torch.Tensor`):
            A tensor representing the hidden states output by the backbone model.
        trigger_left (`torch.Tensor`):
            A tensor indicating the left position of the triggers.
        trigger_right (`torch.Tensor`):
            A tensor indicating the right position of the triggers.
        argument_left (`torch.Tensor`):
            A tensor indicating the left position of the arguments.
        argument_right (`torch.Tensor`):
            A tensor indicating the right position of the arguments.
    """
    if config.aggregation == "cls":
        return method(hidden_states)
    elif config.aggregation == "marker":
        if argument_left is not None:
            return method(hidden_states, argument_left, argument_right)
        else:
            return method(hidden_states, trigger_left, trigger_right)
    elif config.aggregation == "max_pooling":
        return method(hidden_states)
    elif config.aggregation == "dynamic_pooling":
        return method(hidden_states, trigger_left, argument_left)
    else:
        raise ValueError("Invaild %s aggregation method" % config.aggregation)

max_pooling

Applies the max-pooling operation over the representation of the entire input sequence to capture the most useful information. The operation processes on the hidden states, which are output by the backbone model.

Args:

  • hidden_states: A tensor representing the hidden states output by the backbone model.

Returns:

  • pooled_states: A tensor representing the max-pooled hidden states, containing the most useful information of the sequence.

def max_pooling(hidden_states: torch.Tensor) -> torch.Tensor:
    """Applies the max-pooling operation over the sentence representation.
    Applies the max-pooling operation over the representation of the entire input sequence to capture the most useful
    information. The operation processes on the hidden states, which are output by the backbone model.
    Args:
        hidden_states (`torch.Tensor`):
            A tensor representing the hidden states output by the backbone model.
    Returns:
        pooled_states (`torch.Tensor`):
            A tensor representing the max-pooled hidden states, containing the most useful information of the sequence.
    """
    batch_size, seq_length, hidden_size = hidden_states.size()
    pooled_states = F.max_pool1d(input=hidden_states.transpose(1, 2), kernel_size=seq_length).squeeze(-1)
    return pooled_states
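
For example (illustrative shapes):

>>> import torch
>>> hidden_states = torch.randn(2, 6, 4)  # (batch_size, seq_length, hidden_size)
>>> max_pooling(hidden_states).shape
torch.Size([2, 4])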

select_cls

Returns the representations of each sequence’s <cls> token by slicing the hidden state tensor output by the backbone model. The representations of the <cls> tokens contain general information of the sequences.

Args:

  • hidden_states: A tensor representing the hidden states output by the backbone model.

Returns:

  • A tensor containing the representations of each sequence’s <cls> token.

def select_cls(hidden_states: torch.Tensor) -> torch.Tensor:
    """Returns the representations of the `<cls>` tokens.
    Returns the representations of each sequence's `<cls>` token by slicing the hidden state tensor output by the
    backbone model. The representations of the `<cls>` tokens contain general information of the sequences.
    Args:
        hidden_states (`torch.Tensor`):
            A tensor representing the hidden states output by the backbone model.
    Returns:
        `torch.Tensor`:
            A tensor containing the representations of each sequence's `<cls>` token.
    """
    return hidden_states[:, 0, :]

select_marker

Returns the representations of each sequence’s marker tokens by slicing the hidden state tensor output by the backbone model.

Args:

  • hidden_states: A tensor representing the hidden states output by the backbone model.

  • left: A tensor indicating the left position of the markers.

  • right: A tensor indicating the right position of the markers.

Returns:

  • marker_output: A tensor containing the representations of each sequence's marker tokens, obtained by concatenating the representations of the left and right marker tokens.

def select_marker(hidden_states: torch.Tensor,
                  left: torch.Tensor,
                  right: torch.Tensor) -> torch.Tensor:
    """Returns the representations of the marker tokens.
    Returns the representations of each sequence's marker tokens by slicing the hidden state tensor output by the
    backbone model.
    Args:
        hidden_states (`torch.Tensor`):
            A tensor representing the hidden states output by the backbone model.
        left (`torch.Tensor`):
            A tensor indicating the left position of the markers.
        right (`torch.Tensor`):
            A tensor indicating the right position of the markers.
    Returns:
        marker_output (`torch.Tensor`):
            A tensor containing the representations of each sequence's marker tokens by concatenating their left and
            right token's representations.
    """
    batch_size = hidden_states.size(0)
    batch_indice = torch.arange(batch_size)
    left_states = hidden_states[batch_indice, left.to(torch.long), :]
    right_states = hidden_states[batch_indice, right.to(torch.long), :]
    marker_output = torch.cat((left_states, right_states), dim=-1)
    return marker_output
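
For example (illustrative shapes), the left and right marker representations are concatenated, doubling the hidden size:

>>> import torch
>>> hidden_states = torch.randn(2, 6, 4)
>>> left, right = torch.tensor([1, 2]), torch.tensor([3, 5])
>>> select_marker(hidden_states, left, right).shape
torch.Size([2, 8])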

DynamicPooling

Dynamic multi-pooling layer for a Convolutional Neural Network (CNN), which is able to capture more valuable information within a sentence, particularly in cases where multiple triggers appear in one sentence and an argument candidate may play different roles with respect to different triggers.

Attributes:

  • activation: An nn.Tanh layer representing the tanh activation function.

  • dropout: An nn.Dropout layer for the dropout operation with the default dropout rate (0.5).

class DynamicPooling(nn.Module):
    """Dynamic multi-pooling layer for Convolutional Neural Network (CNN).
    Dynamic multi-pooling layer for a Convolutional Neural Network (CNN), which is able to capture more valuable
    information within a sentence, particularly in cases where multiple triggers appear in one sentence and an
    argument candidate may play different roles with respect to different triggers.
    Attributes:
        dense (`nn.Linear`):
            TODO: The purpose of the linear layer should be configured.
        activation (`nn.Tanh`):
            An `nn.Tanh` layer representing the tanh activation function.
        dropout (`nn.Dropout`):
            An `nn.Dropout` layer for the dropout operation with the default dropout rate (0.5).
    """
    def __init__(self,
                 config) -> None:
        """Constructs a `DynamicPooling`."""
        super(DynamicPooling, self).__init__()
        self.dense = nn.Linear(config.hidden_size*config.head_scale, config.hidden_size*config.head_scale)
        self.activation = nn.Tanh()
        self.dropout = nn.Dropout()

    def get_mask(self,
                 position: torch.Tensor,
                 batch_size: int,
                 seq_length: int,
                 device: str) -> torch.Tensor:
        """Returns the mask indicating whether the token is padded or not."""
        all_masks = []
        for i in range(batch_size):
            mask = torch.zeros((seq_length), dtype=torch.int16, device=device)
            mask[:int(position[i])] = 1
            all_masks.append(mask.to(torch.bool))
        all_masks = torch.stack(all_masks, dim=0)
        return all_masks

    def max_pooling(self,
                    hidden_states: torch.Tensor,
                    mask: torch.Tensor) -> torch.Tensor:
        """Conducts the max-pooling operation on the hidden states."""
        batch_size, seq_length, hidden_size = hidden_states.size()
        conved = hidden_states.transpose(1, 2)
        conved = conved.transpose(0, 1)
        states = (conved * mask).transpose(0, 1)
        # Shift all values up by 1 before max-pooling and shift back afterwards, so that
        # positions zeroed out by the mask are less likely to be selected when the valid
        # hidden states are negative.
        states += torch.ones_like(states)
        pooled_states = F.max_pool1d(input=states, kernel_size=seq_length).contiguous().view(batch_size, hidden_size)
        pooled_states -= torch.ones_like(pooled_states)
        return pooled_states

    def forward(self,
                hidden_states: torch.Tensor,
                trigger_position: torch.Tensor,
                argument_position: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Conducts the dynamic multi-pooling process on the hidden states."""
        batch_size, seq_length = hidden_states.size()[:2]
        trigger_mask = self.get_mask(trigger_position, batch_size, seq_length, hidden_states.device)
        if argument_position is not None:
            argument_mask = self.get_mask(argument_position, batch_size, seq_length, hidden_states.device)
            left_mask = torch.logical_and(trigger_mask, argument_mask).to(torch.float32)
            middle_mask = torch.logical_xor(trigger_mask, argument_mask).to(torch.float32)
            right_mask = 1 - torch.logical_or(trigger_mask, argument_mask).to(torch.float32)
            # pooling
            left_states = self.max_pooling(hidden_states, left_mask)
            middle_states = self.max_pooling(hidden_states, middle_mask)
            right_states = self.max_pooling(hidden_states, right_mask)
            pooled_output = torch.cat((left_states, middle_states, right_states), dim=-1)
        else:
            left_mask = trigger_mask.to(torch.float32)
            right_mask = 1 - left_mask
            left_states = self.max_pooling(hidden_states, left_mask)
            right_states = self.max_pooling(hidden_states, right_mask)
            pooled_output = torch.cat((left_states, right_states), dim=-1)

        return pooled_output
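
A minimal sketch of running `DynamicPooling`; `SimpleNamespace` stands in for the real config object, and all sizes are illustrative assumptions:

>>> import torch
>>> from types import SimpleNamespace
>>> config = SimpleNamespace(hidden_size=16, head_scale=1)    # hypothetical config
>>> pooling = DynamicPooling(config)
>>> hidden_states = torch.randn(2, 8, 16)                     # (batch_size, seq_length, hidden_size)
>>> trigger_position = torch.tensor([3, 5])
>>> pooling(hidden_states, trigger_position).shape            # two segments concatenated
torch.Size([2, 32])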

Classification Head

from .classification import LinearHead, MRCHead
from .crf import CRF

get_head

def get_head(config):
    if config.head_type == "linear":
        return LinearHead(config)
    elif config.head_type == "mrc":
        return MRCHead(config)
    elif config.head_type == "crf":
        return CRF(config.num_labels, batch_first=True)
    elif config.head_type in ["none", "None"] or config.head_type is None:
        return None
    else:
        raise ValueError("Invalid head_type %s in config" % config.head_type)
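
A quick sketch of how `get_head` dispatches on `config.head_type`; the `SimpleNamespace` config below is a hypothetical stand-in for the real configuration object:

>>> from types import SimpleNamespace
>>> config = SimpleNamespace(head_type="linear", hidden_size=16, head_scale=1, num_labels=5)
>>> type(get_head(config)).__name__
'LinearHead'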

Classification Head

import torch
import torch.nn as nn

LinearHead

A token-wise classification head that classifies hidden states into label distributions through a linear transformation; the label with the highest probability is selected as the prediction for each position.

Attributes:

  • classifier: An nn.Linear layer classifying each logit into its corresponding label.

class LinearHead(nn.Module):
    """A token-wise classification head for classifying the hidden states to label distributions.
    A token-wise classification head that classifies hidden states into label distributions through a linear
    transformation; the label with the highest probability is selected as the prediction for each position.
    Attributes:
        classifier (`nn.Linear`):
            An `nn.Linear` layer classifying each logit into its corresponding label.
    """
    def __init__(self, config):
        super(LinearHead, self).__init__()
        self.classifier = nn.Linear(config.hidden_size*config.head_scale, config.num_labels)

    def forward(self,
                hidden_state: torch.Tensor) -> torch.Tensor:
        """Classifies hidden states to label distribution."""
        logits = self.classifier(hidden_state)
        return logits

MRCHead

A classification head for the Machine Reading Comprehension (MRC) paradigm, predicting the answer of each question corresponding to a mention type. The classifier returns two logits indicating the start and end position of each mention corresponding to the question.

Attributes:

  • qa_outputs: An nn.Linear layer transforming the hidden states to two logits, indicating the start and end position of a given mention type.

class MRCHead(nn.Module):
    """A token-wise classification head for the Machine Reading Comprehension (MRC) paradigm.
    A classification head for the Machine Reading Comprehension (MRC) paradigm, predicting the answer of each question
    corresponding to a mention type. The classifier returns two logits indicating the start and end position of each
    mention corresponding to the question.
    Attributes:
        qa_outputs (`nn.Linear`):
            An `nn.Linear` layer transforming the hidden states to two logits, indicating the start and end position
            of a given mention type.
    """
    def __init__(self,
                 config) -> None:
        """Constructs a `MRCHead`."""
        super(MRCHead, self).__init__()
        self.qa_outputs = nn.Linear(config.hidden_size, 2)

    def forward(self,
                hidden_state: torch.Tensor):
        """The forward propagation of `MRCHead`."""
        logits = self.qa_outputs(hidden_state)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()
        return start_logits, end_logits
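
A brief sketch of the `MRCHead` outputs; the config stand-in and shapes are illustrative assumptions:

>>> import torch
>>> from types import SimpleNamespace
>>> head = MRCHead(SimpleNamespace(hidden_size=16))    # hypothetical config
>>> start_logits, end_logits = head(torch.randn(2, 8, 16))
>>> start_logits.shape, end_logits.shape               # one start and one end logit per token
(torch.Size([2, 8]), torch.Size([2, 8]))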

Conditional Random Field (CRF)

Note

Copyright pytorch-crf from https://github.com/kmkurn/pytorch-crf.

Licensed under the MIT License.

CRF

This module implements a Conditional Random Field (CRF). The forward computation of this class computes the log likelihood of the given sequence of tags and emission score tensor. This class also has a CRF.decode() method, which finds the best tag sequence given an emission score tensor using the Viterbi algorithm.

Attributes:

  • num_tags: An integer indicating the number of tags to be predicted.

  • batch_first: A boolean variable indicating whether the first dimension of the inputs corresponds to the batch size.

  • start_transitions: An nn.Parameter matrix containing the start transition score tensor of size (num_tags,).

  • end_transitions: An nn.Parameter matrix containing the end transition score tensor of size (num_tags,).

  • transitions: An nn.Parameter matrix indicating the score tensor of size (num_tags, num_tags).

class CRF(nn.Module):
    """Conditional Random Field (CRF) module.
    This module implements a Conditional Random Field (CRF). The forward computation of this class computes the log
    likelihood of the given sequence of tags and emission score tensor. This class also has a `CRF.decode()` method,
    which finds the best tag sequence given an emission score tensor using the Viterbi algorithm.
    Attributes:
        num_tags (`int`):
            An integer indicating the number of tags to be predicted.
        batch_first (`bool`):
            A boolean variable indicating whether the first dimension of the inputs corresponds to the batch size.
        start_transitions (`nn.Parameter`):
            An `nn.Parameter` matrix containing the start transition score tensor of size `(num_tags,)`.
        end_transitions (`nn.Parameter`):
            An `nn.Parameter` matrix containing the end transition score tensor of size `(num_tags,)`.
        transitions (`nn.Parameter`):
            An `nn.Parameter` matrix indicating the score tensor of size `(num_tags, num_tags)`.
    """

    def __init__(self,
                 num_tags: int,
                 batch_first: bool = False) -> None:
        """Constructs a `CRF`."""
        if num_tags <= 0:
            raise ValueError(f'invalid number of tags: {num_tags}')
        super().__init__()
        self.num_tags = num_tags
        self.batch_first = batch_first
        self.start_transitions = nn.Parameter(torch.empty(num_tags))
        self.end_transitions = nn.Parameter(torch.empty(num_tags))
        self.transitions = nn.Parameter(torch.empty(num_tags, num_tags))

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialize the transition parameters.
        The parameters will be initialized randomly from a uniform distribution between -0.1 and 0.1.
        """
        nn.init.uniform_(self.start_transitions, -0.1, 0.1)
        nn.init.uniform_(self.end_transitions, -0.1, 0.1)
        nn.init.uniform_(self.transitions, -0.1, 0.1)

    def __repr__(self) -> str:
        """Displays the class name and the number of tags."""
        return f'{self.__class__.__name__}(num_tags={self.num_tags})'

    def forward(self,
                emissions: torch.Tensor,
                tags: torch.LongTensor,
                mask: Optional[torch.ByteTensor] = None,
                reduction: str = 'sum') -> torch.Tensor:
        """Compute the conditional log likelihood of a sequence of tags given emission scores."""
        self._validate(emissions, tags=tags, mask=mask)
        if reduction not in ('none', 'sum', 'mean', 'token_mean'):
            raise ValueError(f'invalid reduction: {reduction}')
        if mask is None:
            mask = torch.ones_like(tags, dtype=torch.uint8)

        if self.batch_first:
            emissions = emissions.transpose(0, 1)
            tags = tags.transpose(0, 1)
            mask = mask.transpose(0, 1)

        # shape: (batch_size,)
        numerator = self._compute_score(emissions, tags, mask)
        # shape: (batch_size,)
        denominator = self._compute_normalizer(emissions, mask)
        # shape: (batch_size,)
        llh = numerator - denominator

        if reduction == 'none':
            return llh
        if reduction == 'sum':
            return llh.sum()
        if reduction == 'mean':
            return llh.mean()
        assert reduction == 'token_mean'
        return llh.sum() / mask.type_as(emissions).sum()

    def decode(self, emissions: torch.Tensor,
               mask: Optional[torch.ByteTensor] = None) -> List[List[int]]:
        """Find the most likely tag sequence using Viterbi algorithm."""
        self._validate(emissions, mask=mask)
        if mask is None:
            mask = emissions.new_ones(emissions.shape[:2], dtype=torch.uint8)

        if self.batch_first:
            emissions = emissions.transpose(0, 1)
            mask = mask.transpose(0, 1)

        return self._viterbi_decode(emissions, mask)

    def _validate(self,
                emissions: torch.Tensor,
                tags: Optional[torch.LongTensor] = None,
                mask: Optional[torch.ByteTensor] = None) -> None:
        """Validates the emission dimension and whether its slice satisfies tag number, tag shape and mask shape."""
        if emissions.dim() != 3:
            raise ValueError(f'emissions must have dimension of 3, got {emissions.dim()}')
        if emissions.size(2) != self.num_tags:
            raise ValueError(
                f'expected last dimension of emissions is {self.num_tags}, '
                f'got {emissions.size(2)}')

        if tags is not None:
            if emissions.shape[:2] != tags.shape:
                raise ValueError(
                    'the first two dimensions of emissions and tags must match, '
                    f'got {tuple(emissions.shape[:2])} and {tuple(tags.shape)}')

        if mask is not None:
            if emissions.shape[:2] != mask.shape:
                raise ValueError(
                    'the first two dimensions of emissions and mask must match, '
                    f'got {tuple(emissions.shape[:2])} and {tuple(mask.shape)}')
            no_empty_seq = not self.batch_first and mask[0].all()
            no_empty_seq_bf = self.batch_first and mask[:, 0].all()
            if not no_empty_seq and not no_empty_seq_bf:
                raise ValueError('mask of the first timestep must all be on')

    def _compute_score(self,
                       emissions: torch.Tensor,
                       tags: torch.LongTensor,
                       mask: torch.ByteTensor) -> torch.Tensor:
        """Computes the score based on the emission and transition matrix."""
        # emissions: (seq_length, batch_size, num_tags)
        # tags: (seq_length, batch_size)
        # mask: (seq_length, batch_size)
        assert emissions.dim() == 3 and tags.dim() == 2
        assert emissions.shape[:2] == tags.shape
        assert emissions.size(2) == self.num_tags
        assert mask.shape == tags.shape
        assert mask[0].all()

        seq_length, batch_size = tags.shape
        mask = mask.type_as(emissions)

        # Start transition score and first emission
        # shape: (batch_size,)
        score = self.start_transitions[tags[0]]
        score += emissions[0, torch.arange(batch_size), tags[0]]

        for i in range(1, seq_length):
            # Transition score to next tag, only added if next timestep is valid (mask == 1)
            # shape: (batch_size,)
            score += self.transitions[tags[i - 1], tags[i]] * mask[i]

            # Emission score for next tag, only added if next timestep is valid (mask == 1)
            # shape: (batch_size,)
            score += emissions[i, torch.arange(batch_size), tags[i]] * mask[i]

        # End transition score
        # shape: (batch_size,)
        seq_ends = mask.long().sum(dim=0) - 1
        # shape: (batch_size,)
        last_tags = tags[seq_ends, torch.arange(batch_size)]
        # shape: (batch_size,)
        score += self.end_transitions[last_tags]

        return score

    def _compute_normalizer(self,
                            emissions: torch.Tensor,
                            mask: torch.ByteTensor) -> torch.Tensor:
        """Compute the log-sum-exp score."""
        # emissions: (seq_length, batch_size, num_tags)
        # mask: (seq_length, batch_size)
        assert emissions.dim() == 3 and mask.dim() == 2
        assert emissions.shape[:2] == mask.shape
        assert emissions.size(2) == self.num_tags
        assert mask[0].all()

        seq_length = emissions.size(0)

        # Start transition score and first emission; score has size of
        # (batch_size, num_tags) where for each batch, the j-th column stores
        # the score that the first timestep has tag j
        # shape: (batch_size, num_tags)
        score = self.start_transitions + emissions[0]

        for i in range(1, seq_length):
            # Broadcast score for every possible next tag
            # shape: (batch_size, num_tags, 1)
            broadcast_score = score.unsqueeze(2)

            # Broadcast emission score for every possible current tag
            # shape: (batch_size, 1, num_tags)
            broadcast_emissions = emissions[i].unsqueeze(1)

            # Compute the score tensor of size (batch_size, num_tags, num_tags) where
            # for each sample, entry at row i and column j stores the sum of scores of all
            # possible tag sequences so far that end with transitioning from tag i to tag j
            # and emitting
            # shape: (batch_size, num_tags, num_tags)
            next_score = broadcast_score + self.transitions + broadcast_emissions

            # Sum over all possible current tags, but we're in score space, so a sum
            # becomes a log-sum-exp: for each sample, entry i stores the sum of scores of
            # all possible tag sequences so far, that end in tag i
            # shape: (batch_size, num_tags)
            next_score = torch.logsumexp(next_score, dim=1)

            # Set score to the next score if this timestep is valid (mask == 1)
            # shape: (batch_size, num_tags)
            score = torch.where(mask[i].unsqueeze(1), next_score, score)

        # End transition score
        # shape: (batch_size, num_tags)
        score += self.end_transitions

        # Sum (log-sum-exp) over all possible tags
        # shape: (batch_size,)
        return torch.logsumexp(score, dim=1)

    def _viterbi_decode(self,
                        emissions: torch.FloatTensor,
                        mask: torch.ByteTensor) -> List[List[int]]:
        """Decodes the optimal path using Viterbi algorithm."""
        # emissions: (seq_length, batch_size, num_tags)
        # mask: (seq_length, batch_size)
        assert emissions.dim() == 3 and mask.dim() == 2
        assert emissions.shape[:2] == mask.shape
        assert emissions.size(2) == self.num_tags
        assert mask[0].all()

        seq_length, batch_size = mask.shape

        # Start transition and first emission
        # shape: (batch_size, num_tags)
        score = self.start_transitions + emissions[0]
        history = []

        # score is a tensor of size (batch_size, num_tags) where for every batch,
        # value at column j stores the score of the best tag sequence so far that ends
        # with tag j
        # history saves where the best tags candidate transitioned from; this is used
        # when we trace back the best tag sequence

        # Viterbi algorithm recursive case: we compute the score of the best tag sequence
        # for every possible next tag
        for i in range(1, seq_length):
            # Broadcast viterbi score for every possible next tag
            # shape: (batch_size, num_tags, 1)
            broadcast_score = score.unsqueeze(2)

            # Broadcast emission score for every possible current tag
            # shape: (batch_size, 1, num_tags)
            broadcast_emission = emissions[i].unsqueeze(1)

            # Compute the score tensor of size (batch_size, num_tags, num_tags) where
            # for each sample, entry at row i and column j stores the score of the best
            # tag sequence so far that ends with transitioning from tag i to tag j and emitting
            # shape: (batch_size, num_tags, num_tags)
            next_score = broadcast_score + self.transitions + broadcast_emission

            # Find the maximum score over all possible current tag
            # shape: (batch_size, num_tags)
            next_score, indices = next_score.max(dim=1)

            # Set score to the next score if this timestep is valid (mask == 1)
            # and save the index that produces the next score
            # shape: (batch_size, num_tags)
            score = torch.where(mask[i].unsqueeze(1), next_score, score)
            history.append(indices)

        # End transition score
        # shape: (batch_size, num_tags)
        score += self.end_transitions

        # Now, compute the best path for each sample

        # shape: (batch_size,)
        seq_ends = mask.long().sum(dim=0) - 1
        best_tags_list = []

        for idx in range(batch_size):
            # Find the tag which maximizes the score at the last timestep; this is our best tag
            # for the last timestep
            _, best_last_tag = score[idx].max(dim=0)
            best_tags = [best_last_tag.item()]

            # We trace back where the best last tag comes from, append that to our best tag
            # sequence, and trace it back again, and so on
            for hist in reversed(history[:seq_ends[idx]]):
                best_last_tag = hist[idx][best_tags[-1]]
                best_tags.append(best_last_tag.item())

            # Reverse the order because we start from the last timestep
            best_tags.reverse()
            best_tags_list.append(best_tags)

        return best_tags_list
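
A minimal sketch of computing the loss and decoding with the `CRF` module, following the usage pattern of pytorch-crf; all sizes below are illustrative assumptions:

>>> import torch
>>> num_tags, seq_length, batch_size = 5, 3, 2
>>> crf = CRF(num_tags)                                    # batch_first=False by default
>>> emissions = torch.randn(seq_length, batch_size, num_tags)
>>> tags = torch.randint(num_tags, (seq_length, batch_size))
>>> loss = -crf(emissions, tags)                           # negative log likelihood ('sum' reduction)
>>> best_paths = crf.decode(emissions)                     # one tag sequence per batch element
>>> len(best_paths), len(best_paths[0])
(2, 3)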

Evaluation Metrics

import copy

import torch
import numpy as np

from sklearn.metrics import f1_score
from seqeval.metrics import f1_score as span_f1_score
from seqeval.scheme import IOB2
from typing import Tuple, Dict, List, Optional, Union

from ..input_engineering.mrc_converter import make_predictions, compute_mrc_F1_cls
from ..input_engineering.seq2seq_processor import extract_argument

compute_unified_micro_f1

Computes the F1 score of the unified evaluation on the converted word-level predictions.

Args:

  • label_names: A list of ground truth labels of each word.

  • results: A list of predicted event types or argument roles of each word.

Returns:

  • micro_f1: The computed micro F1 score.

def compute_unified_micro_f1(label_names: List[str], results: List[str]) -> float:
    """Computes the F1 score of the converted word-level predictions.
    Computes the F1 score of the unified evaluation on the converted word-level predictions.
    Args:
        label_names (`List[str]`):
            A list of ground truth labels of each word.
        results (`List[str]`):
            A list of predicted event types or argument roles of each word.
    Returns:
        micro_f1 (`float`):
            The computed micro F1 score.
    """
    pos_labels = list(set(label_names))
    pos_labels.remove("NA")
    micro_f1 = f1_score(label_names, results, labels=pos_labels, average="micro") * 100.0
    return micro_f1
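
A small illustrative example; the event type names below are assumptions:

>>> label_names = ["NA", "Attack", "NA", "Transport"]
>>> results = ["NA", "Attack", "Attack", "NA"]
>>> print(compute_unified_micro_f1(label_names, results))   # one TP, one FP, one FN
50.0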

f1_score_overall

Computes the overall F1 score of the predictions based on the overall precision and recall, counting a prediction as true only when both the mention and its type are correct.

Args:

  • preds: A list of strings indicating the prediction of labels from the model.

  • labels: A list of strings indicating the actual labels obtained from the annotated dataset.

Returns:

  • precision, recall, f1: Three float variables representing the computation results of precision, recall, and F1 score, respectively.

def f1_score_overall(preds: Union[List[str], List[tuple]],
                     labels: Union[List[str], List[tuple]]) -> Tuple[float, float, float]:
    """Computes the overall F1 score of the predictions.
    Computes the overall F1 score of the predictions based on the overall precision and recall, counting a prediction
    as true only when both the mention and its type are correct.
    Args:
        preds (`Union[List[str], List[tuple]]`):
            A list of strings indicating the prediction of labels from the model.
        labels (`Union[List[str], List[tuple]]`):
            A list of strings indicating the actual labels obtained from the annotated dataset.
    Returns:
        precision (`float`), recall (`float`), and f1 (`float`):
            Three floats representing the computation results of precision, recall, and F1 score, respectively.
    """
    true_pos = 0
    label_stack = copy.deepcopy(labels)
    for pred in preds:
        if pred in label_stack:
            true_pos += 1
            label_stack.remove(pred)  # one prediction can only be matched to one ground truth.
    precision = true_pos / (len(preds)+1e-10)
    recall = true_pos / (len(labels)+1e-10)
    f1 = 2 * precision * recall / (precision + recall + 1e-10)
    return precision, recall, f1
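
A small illustrative example with hypothetical (mention, type) tuples:

>>> preds = [("explosion", "Attack"), ("moved", "Transport")]
>>> labels = [("explosion", "Attack"), ("left", "Transport")]
>>> precision, recall, f1 = f1_score_overall(preds, labels)
>>> round(precision, 2), round(recall, 2), round(f1, 2)
(0.5, 0.5, 0.5)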

compute_seq_F1

Computes the F1 score of the Sequence-to-Sequence (Seq2Seq) paradigm. The predictions of the model are first decoded into strings, and then the overall F1 score of the predictions is calculated.

Args:

  • logits: A numpy array of integers containing the predictions from the model to be decoded.

  • labels: A numpy array of integers containing the actual labels obtained from the annotated dataset.

Returns:

  • A dictionary containing the calculated F1 score.

def compute_seq_F1(logits: np.ndarray,
                   labels: np.ndarray,
                   **kwargs) -> Dict[str, float]:
    """Computes the F1 score of the Sequence-to-Sequence (Seq2Seq) paradigm.
    Computes the F1 score of the Sequence-to-Sequence (Seq2Seq) paradigm. The predictions of the model are first
    decoded into strings, and then the overall F1 score of the predictions is calculated.
    Args:
        logits (`np.ndarray`):
            A numpy array of integers containing the predictions from the model to be decoded.
        labels (`np.ndarray`):
            A numpy array of integers containing the actual labels obtained from the annotated dataset.
    Returns:
        `Dict[str, float]`:
            A dictionary containing the calculated F1 score.
    """
    tokenizer = kwargs["tokenizer"]
    training_args = kwargs["training_args"]
    decoded_preds = tokenizer.batch_decode(logits, skip_special_tokens=False)

    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=False)

    def clean_str(x_str):
        for to_remove_token in [tokenizer.eos_token, tokenizer.pad_token]:
            x_str = x_str.replace(to_remove_token, '')

        return x_str.strip()
    if training_args.task_name == "EAE":
        pred_types = training_args.data_for_evaluation["pred_types"]
        true_types = training_args.data_for_evaluation["true_types"]
        assert len(true_types) == len(decoded_labels)
        assert len(decoded_preds) == len(decoded_labels)
        pred_arguments, golden_arguments = [], []
        for i, (pred, label) in enumerate(zip(decoded_preds, decoded_labels)):
            pred = clean_str(pred)
            label = clean_str(label)
            pred_arguments.extend(extract_argument(pred, i, pred_types[i]))
            golden_arguments.extend(extract_argument(label, i, true_types[i]))
        precision, recall, micro_f1 = f1_score_overall(pred_arguments, golden_arguments)
    else:
        assert len(decoded_preds) == len(decoded_labels)
        pred_triggers, golden_triggers = [], []
        for i, (pred, label) in enumerate(zip(decoded_preds, decoded_labels)):
            pred = clean_str(pred)
            label = clean_str(label)
            pred_triggers.extend(extract_argument(pred, i, "NA"))
            golden_triggers.extend(extract_argument(label, i, "NA"))
        precision, recall, micro_f1 = f1_score_overall(pred_triggers, golden_triggers)
    return {"micro_f1": micro_f1*100}

select_start_position

Selects the preds and labels of the first sub-word token for each word. The PreTrainedTokenizer tends to split a word into sub-word tokens, so we select the prediction of the first sub-word token as the prediction of the whole word.

Args:

  • preds: The prediction ids of the input.

  • labels: The label ids of the input.

  • merge: Whether to merge the predictions and labels into a one-dimensional list.

Returns:

  • final_preds, final_labels: The tuple of final predictions and labels.

def select_start_position(preds: np.ndarray,
                          labels: np.ndarray,
                          merge: Optional[bool] = True) -> Tuple[List[List[str]], List[List[str]]]:
    """Select the preds and labels of the first sub-word token for each word.
    The PreTrainedTokenizer tends to split word into sub-word tokens, and we select the prediction of the first sub-word
    token as the prediction of this word.
    Args:
        preds (`np.ndarray`):
            The prediction ids of the input.
        labels (`np.ndarray`):
            The label ids of the input.
        merge (`bool`):
            Whether to merge the predictions and labels into a one-dimensional list.
    Returns:
        final_preds, final_labels (`Tuple[List[List[str]], List[List[str]]]`):
            The tuple of final predictions and labels.
    """
    final_preds = []
    final_labels = []

    if merge:
        final_preds = preds[labels != -100].tolist()
        final_labels = labels[labels != -100].tolist()
    else:
        for i in range(labels.shape[0]):
            final_preds.append(preds[i][labels[i] != -100].tolist())
            final_labels.append(labels[i][labels[i] != -100].tolist())

    return final_preds, final_labels
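
A small illustrative example; `-100` marks sub-word tokens that do not start a word:

>>> import numpy as np
>>> preds = np.array([[1, 2, 3, 4]])
>>> labels = np.array([[0, -100, 2, -100]])
>>> select_start_position(preds, labels, merge=True)
([1, 3], [0, 2])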

convert_to_names

Converts the given labels from ids to their names by looking up each id in the id2label dictionary, which contains the correspondence between the ids and names of each label.

Args:

  • instances: A list of lists containing the label ids of the instances.

  • id2label: A dictionary containing the correspondence between the ids and names of each label.

Returns:

  • name_instances: A list of string lists containing the label names, where each value corresponds to the id in the input list.

def convert_to_names(instances: List[List[int]],
                     id2label: Dict[int, str]) -> List[List[str]]:
    """Converts the given labels from id to their names.
    Converts the given labels from ids to their names by looking up each id in the `id2label` dictionary, which
    contains the correspondence between the ids and names of each label.
    Args:
        instances (`List[List[int]]`):
            A list of lists containing the label ids of the instances.
        id2label (`Dict[int, str]`):
            A dictionary containing the correspondence between the ids and names of each label.
    Returns:
        name_instances (`List[List[str]]`):
            A list of string lists containing the label names, where each value corresponds to the id in the input list.
    """
    name_instances = []
    for instance in instances:
        name_instances.append([id2label[item] for item in instance])
    return name_instances
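
A quick illustration with a hypothetical label mapping:

>>> id2label = {0: "NA", 1: "B-Attack", 2: "I-Attack"}
>>> convert_to_names([[0, 1, 2]], id2label)
[['NA', 'B-Attack', 'I-Attack']]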

compute_span_F1

Computes the F1 score of the Sequence Labeling (SL) paradigm. The predictions of the model are converted into strings, and then the overall F1 score is calculated.

Args:

  • logits: A numpy array of integers containing the predictions from the model to be decoded.

  • labels: A numpy array of integers containing the actual labels obtained from the annotated dataset.

Returns:

  • A dictionary containing the calculated F1 score.

def compute_span_F1(logits: np.ndarray,
                    labels: np.ndarray,
                    **kwargs) -> Dict[str, float]:
    """Computes the F1 score of the Sequence Labeling (SL) paradigm.
    Computes the F1 score of the Sequence Labeling (SL) paradigm. The predictions of the model are converted into
    strings, and then the overall F1 score is calculated.
    Args:
        logits (`np.ndarray`):
            A numpy array of integers containing the predictions from the model to be decoded.
        labels (`np.ndarray`):
            A numpy array of integers containing the actual labels obtained from the annotated dataset.
    Returns:
        `Dict[str, float]`:
            A dictionary containing the calculated F1 score.
    """

    preds = np.argmax(logits, axis=-1) if len(logits.shape) == 3 else logits
    # convert id to name
    training_args = kwargs["training_args"]
    if training_args.task_name == "EAE":
        id2label = {id: role for role, id in training_args.role2id.items()}
    elif training_args.task_name == "ED":
        id2label = {id: role for role, id in training_args.type2id.items()}
    else:
        raise ValueError("No such task!")
    final_preds, final_labels = select_start_position(preds, labels, False)
    final_preds = convert_to_names(final_preds, id2label)
    final_labels = convert_to_names(final_labels, id2label)

    # if the type is wrongly predicted, set arguments NA
    if training_args.task_name == "EAE":
        pred_types = training_args.data_for_evaluation["pred_types"]
        true_types = training_args.data_for_evaluation["true_types"]
        assert len(pred_types) == len(true_types)
        assert len(pred_types) == len(final_labels)
        for i, (pred, true) in enumerate(zip(pred_types, true_types)):
            if pred != true:
                final_preds[i] = [id2label[0]] * len(final_preds[i])  # set to NA

    micro_f1 = span_f1_score(final_labels, final_preds, mode='strict', scheme=IOB2) * 100.0
    return {"micro_f1": micro_f1}

compute_F1

Computes the F1 score of the Token Classification (TC) paradigm. The predictions of the model are converted into strings, and then the overall F1 score is calculated.

Args:

  • logits: A numpy array of integers containing the predictions from the model to be decoded.

  • labels: A numpy array of integers containing the actual labels obtained from the annotated dataset.

Returns:

  • A dictionary containing the calculated F1 score.

def compute_F1(logits: np.ndarray,
               labels: np.ndarray,
               **kwargs) -> Dict[str, float]:
    """Computes the F1 score of the Token Classification (TC) paradigm.
    Computes the F1 score of the Token Classification (TC) paradigm. The predictions of the model are converted into
    strings, and then the overall F1 score is calculated.
    Args:
        logits (`np.ndarray`):
            A numpy array of integers containing the predictions from the model to be decoded.
        labels (`np.ndarray`):
            A numpy array of integers containing the actual labels obtained from the annotated dataset.
    Returns:
        `Dict[str, float]`:
            A dictionary containing the calculated F1 score.
    """
    predictions = np.argmax(logits, axis=-1)
    training_args = kwargs["training_args"]
    # if the type is wrongly predicted, set arguments NA
    if training_args.task_name == "EAE":
        pred_types = training_args.data_for_evaluation["pred_types"]
        true_types = training_args.data_for_evaluation["true_types"]
        assert len(pred_types) == len(true_types)
        assert len(pred_types) == len(predictions)
        for i, (pred, true) in enumerate(zip(pred_types, true_types)):
            if pred != true:
                predictions[i] = 0 # set to NA
        pos_labels = list(set(training_args.role2id.values()))
    else:
        pos_labels = list(set(training_args.type2id.values()))
    pos_labels.remove(0)
    micro_f1 = f1_score(labels, predictions, labels=pos_labels, average="micro") * 100.0
    return {"micro_f1": micro_f1}

softmax

Conducts the softmax operation on the last dimension and returns a numpy array.

Args:

  • logits: A numpy array containing the logits of each instance.

  • dim: An integer indicating the dimension for the softmax operation.

Returns:

  • A numpy array representing the normalized probability of each logit corresponding to each type of label.

def softmax(logits: np.ndarray,
            dim: Optional[int] = -1) -> np.ndarray:
    """Conducts the softmax operation on the last dimension.
    Conducts the softmax operation on the last dimension and returns a numpy array.
    Args:
        logits (`np.ndarray`):
            A numpy array containing the logits of each instance.
        dim (`int`, `optional`, defaults to -1):
            An integer indicating the dimension for the softmax operation.
    Returns:
        `np.ndarray`:
            A numpy array representing the normalized probability of each logit corresponding to each type of label.
    """
    logits = torch.tensor(logits)
    return torch.softmax(logits, dim=dim).numpy()
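
A quick numeric illustration:

>>> import numpy as np
>>> softmax(np.array([[1.0, 2.0, 3.0]]))   # ≈ [[0.0900, 0.2447, 0.6652]], sums to 1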

compute_accuracy

Computes the accuracy of the predictions by calculating the fraction of correctly predicted labels over the total number of instances.

Args:

  • logits: A numpy array of integers containing the predictions from the model to be decoded.

  • labels: A numpy array of integers containing the actual labels obtained from the annotated dataset.

Returns:

  • A dictionary containing the calculated accuracy.

def compute_accuracy(logits: np.ndarray,
                     labels: np.ndarray,
                     **kwargs) -> Dict[str, float]:
    """Compute the accuracy of the predictions.
    Compute the accuracy of the predictions by calculating the fraction of the true label prediction count and the
    entire number of data pieces.
    Args:
        logits (`np.ndarray`):
            An numpy array of integers containing the predictions from the model to be decoded.
        labels:
            An numpy array of integers containing the actual labels obtained from the annotated dataset.
    Returns:
        `Dict[str: float]`:
            A dictionary containing the calculation result of the accuracy.
    """
    predictions = np.argmax(softmax(logits), axis=-1)
    accuracy = (predictions == labels).sum() / labels.shape[0]
    return {"accuracy": accuracy}

compute_mrc_F1

Computes the F1 score of the Machine Reading Comprehension (MRC) method. The predictions of the model are first decoded into strings, and then the overall F1 score is calculated.

Args:

  • logits: A numpy array of integers containing the predictions from the model to be decoded.

  • labels: A numpy array of integers containing the actual labels obtained from the annotated dataset.

Returns:

  • A dictionary containing the calculated F1 score.

def compute_mrc_F1(logits: np.ndarray,
                   labels: np.ndarray,
                   **kwargs) -> Dict[str, float]:
    """Computes the F1 score of the Machine Reading Comprehension (MRC) method.
    Computes the F1 score of the Machine Reading Comprehension (MRC) method. The predictions of the model are first
    decoded into strings, and then the overall F1 score is calculated.
    Args:
        logits (`np.ndarray`):
            A numpy array of integers containing the predictions from the model to be decoded.
        labels (`np.ndarray`):
            A numpy array of integers containing the actual labels obtained from the annotated dataset.
    Returns:
        `Dict[str, float]`:
            A dictionary containing the calculated F1 score.
    """
    start_logits, end_logits = np.split(logits, 2, axis=-1)
    all_predictions, all_labels = make_predictions(start_logits, end_logits, kwargs["training_args"])
    micro_f1 = compute_mrc_F1_cls(all_predictions, all_labels)
    return {"micro_f1": micro_f1}

Convert Format

import json
import logging
import numpy as np

from typing import List, Dict, Union, Tuple
from sklearn.metrics import f1_score
from .metric import select_start_position, compute_unified_micro_f1
from ..input_engineering.input_utils import (
    get_left_and_right_pos,
    check_pred_len,
    get_ed_candidates,
    get_eae_candidates,
    get_event_preds,
    get_plain_label,
)
logger = logging.getLogger(__name__)

get_pred_per_mention

Get the predicted event type or argument role for each mention via the predictions of different paradigms. The predictions of Sequence Labeling, Seq2Seq, MRC paradigms are not aligned to each word. We need to convert the paradigm-dependent predictions to word-level for the unified evaluation. This function is designed to get the prediction for each single mention, given the paradigm-dependent predictions.

Args:

  • pos_start: The start position of the mention in the sequence of tokens.

  • pos_end: The end position of the mention in the sequence of tokens.

  • preds: The predictions of the sequence of tokens.

  • id2label: A dictionary that contains the mapping from id to textual label.

  • label: The ground truth label of the input mention.

  • label2id: A dictionary that contains the mapping from textual label to id.

  • text: The text of the input context.

  • paradigm: A string that indicates the paradigm.

Returns:

  • A string which represents the predicted label.

def get_pred_per_mention(pos_start: int,
                         pos_end: int,
                         preds: List[Union[str, Tuple[str, str]]],
                         id2label: Dict[int, str] = None,
                         label: str = None,
                         label2id: Dict[str, int] = None,
                         text: str = None,
                         paradigm: str = "sl") -> str:
    """Get the predicted event type or argument role for each mention via the predictions of different paradigms.
    The predictions of Sequence Labeling, Seq2Seq, MRC paradigms are not aligned to each word. We need to convert the
    paradigm-dependent predictions to word-level for the unified evaluation. This function is designed to get the
    prediction for each single mention, given the paradigm-dependent predictions.
    Args:
        pos_start (`int`):
            The start position of the mention in the sequence of tokens.
        pos_end (`int`):
            The end position of the mention in the sequence of tokens.
        preds (`List[Union[str, Tuple[str, str]]]`):
            The predictions of the sequence of tokens.
        id2label (`Dict[int, str]`):
            A dictionary that contains the mapping from id to textual label.
        label (`str`):
            The ground truth label of the input mention.
        label2id (`Dict[str, int]`):
            A dictionary that contains the mapping from textual label to id.
        text (`str`):
            The text of the input context.
        paradigm (`str`):
            A string that indicates the paradigm.
    Returns:
        A string which represents the predicted label.
    """
    if paradigm == "sl":
        # sequence labeling paradigm
        if pos_start == pos_end or\
                pos_end > len(preds) or \
                id2label[int(preds[pos_start])] == "O" or \
                id2label[int(preds[pos_start])].split("-")[0] != "B":
            return "NA"

        predictions = set()
        for pos in range(pos_start, pos_end):
            _pred = id2label[int(preds[pos])][2:]
            predictions.add(_pred)

        if len(predictions) > 1:
            return "NA"
        else:
            return list(predictions)[0]

    elif paradigm == "s2s":
        # seq2seq paradigm
        predictions = []
        word = text[pos_start: pos_end]
        for i, pred in enumerate(preds):
            if pred[0] == word:
                if pred[1] in label2id:
                    pred_label = pred[1]
                    predictions.append(pred_label)
        if label in predictions:
            pred_label = label
        else:
            pred_label = predictions[0] if predictions else "NA"

        # remove the prediction that has been used for a specific mention.
        if (word, pred_label) in preds:
            preds.remove((word, pred_label))

        return pred_label

    elif paradigm == "mrc":
        # mrc paradigm
        predictions = []
        for pred in preds:
            if pred[1] == (pos_start, pos_end - 1):
                pred_role = pred[0].split("_")[-1]
                predictions.append(pred_role)

        if label in predictions:
            return label
        else:
            return predictions[0] if predictions else "NA"
    else:
        raise NotImplementedError
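
A minimal sketch for the sequence labeling paradigm; the label mapping below is a hypothetical example:

>>> id2label = {0: "O", 1: "B-Attack", 2: "I-Attack"}
>>> preds = [1, 2, 0, 0]
>>> get_pred_per_mention(pos_start=0, pos_end=2, preds=preds, id2label=id2label)
'Attack'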

get_trigger_detection_sl

Obtains the event detection prediction results of the ACE2005 dataset based on the sequence labeling paradigm, predicting the labels and calculating the micro F1 score based on the predictions and labels.

Args:

  • preds: A list of strings indicating the predicted types of the instances.

  • labels: A list of strings indicating the actual labels of the instances.

  • data_file: A string indicating the path of the testing data file.

  • data_args: The pre-defined arguments for data processing.

Returns:

  • results: A list of strings indicating the prediction results of event triggers.

def get_trigger_detection_sl(preds: np.array,
                             labels: np.array,
                             data_file: str,
                             data_args,
                             is_overflow) -> List[str]:
    """Obtains the event detection prediction results of the ACE2005 dataset based on the sequence labeling paradigm.
    Obtains the event detection prediction results of the ACE2005 dataset based on the sequence labeling paradigm,
    predicting the labels and calculating the micro F1 score based on the predictions and labels.
    Args:
        preds (`np.array`):
            A list of strings indicating the predicted types of the instances.
        labels (`np.array`):
            A list of strings indicating the actual labels of the instances.
        data_file (`str`):
            A string indicating the path of the testing data file.
        data_args:
            The pre-defined arguments for data processing.
        is_overflow:
            A list of booleans indicating whether the tokenized sequence of each instance overflows the maximum
            input length.
    Returns:
        results (`List[str]`):
            A list of strings indicating the prediction results of event triggers.
    """
    # get per-word predictions
    preds, labels = select_start_position(preds, labels, False)
    results = []
    label_names = []
    language = data_args.language

    with open(data_file, "r", encoding='utf-8') as f:
        lines = f.readlines()
        for i, line in enumerate(lines):
            item = json.loads(line.strip())

            if not is_overflow[i]:
                check_pred_len(pred=preds[i], item=item, language=language)

            candidates, label_names_per_item = get_ed_candidates(item=item)
            label_names.extend(label_names_per_item)

            # loop for converting
            for candidate in candidates:
                left_pos, right_pos = get_left_and_right_pos(text=item["text"], trigger=candidate, language=language)
                pred = get_pred_per_mention(left_pos, right_pos, preds[i], data_args.id2type)
                results.append(pred)

    if "events" in item:
        micro_f1 = compute_unified_micro_f1(label_names=label_names, results=results)
        logger.info("{} test performance after converting: {}".format(data_args.dataset_name, micro_f1))

    return results

get_argument_extraction_sl

Obtains the event argument extraction prediction results of the ACE2005 dataset based on the sequence labeling paradigm, predicting the labels of entities and negative triggers and calculating the micro F1 score based on the predictions and labels.

Args:

  • preds: A list of strings indicating the predicted types of the instances.

  • labels: A list of strings indicating the actual labels of the instances.

  • data_file: A string indicating the path of the testing data file.

  • data_args: The pre-defined arguments for data processing.

Returns:

  • results: A list of strings indicating the prediction results of event arguments.

def get_argument_extraction_sl(preds: np.array,
                               labels: np.array,
                               data_file: str,
                               data_args,
                               is_overflow) -> List[str]:
    """Obtains the event argument extraction results of the ACE2005 dataset based on the sequence labeling paradigm.
    Obtains the event argument extraction prediction results of the ACE2005 dataset based on the sequence labeling
    paradigm, predicting the labels of entities and negative triggers and calculating the micro F1 score based on the
    predictions and labels.
    Args:
        preds (`np.array`):
            A list of strings indicating the predicted types of the instances.
        labels (`np.array`):
            A list of strings indicating the actual labels of the instances.
        data_file (`str`):
            A string indicating the path of the testing data file.
        data_args:
            The pre-defined arguments for data processing.
        is_overflow:
            A list of booleans indicating whether the tokenized sequence of each instance overflows the maximum
            input length.
    Returns:
        results (`List[str]`):
            A list of strings indicating the prediction results of event arguments.
    """
    # evaluation mode
    eval_mode = data_args.eae_eval_mode
    language = data_args.language
    golden_trigger = data_args.golden_trigger

    # pred events
    event_preds = get_event_preds(pred_file=data_args.test_pred_file)

    # get per-word predictions
    preds, labels = select_start_position(preds, labels, False)
    results = []
    label_names = []
    with open(data_file, "r", encoding="utf-8") as f:
        trigger_idx = 0
        eae_instance_idx = 0
        lines = f.readlines()
        for line in lines:
            item = json.loads(line.strip())
            text = item["text"]
            for event in item["events"]:
                for trigger in event["triggers"]:
                    true_type = event["type"]
                    pred_type = true_type if golden_trigger or event_preds is None else event_preds[trigger_idx]
                    trigger_idx += 1

                    if eval_mode in ['default', 'loose']:
                        if pred_type == "NA":
                            continue

                    if not is_overflow[eae_instance_idx]:
                        check_pred_len(pred=preds[eae_instance_idx], item=item, language=language)

                    candidates, label_names_per_trigger = get_eae_candidates(item=item, trigger=trigger)
                    label_names.extend(label_names_per_trigger)

                    # loop for converting
                    for candi in candidates:
                        if true_type == pred_type:
                            # get word positions
                            left_pos, right_pos = get_left_and_right_pos(text=text, trigger=candi, language=language)
                            # get predictions
                            pred = get_pred_per_mention(left_pos, right_pos, preds[eae_instance_idx], data_args.id2role)
                        else:
                            pred = "NA"
                        # record results
                        results.append(pred)
                    eae_instance_idx += 1

            # negative triggers
            for trigger in item["negative_triggers"]:
                true_type = "NA"
                pred_type = true_type if golden_trigger or event_preds is None else event_preds[trigger_idx]
                trigger_idx += 1

                if eval_mode in ['default', 'strict']:  # loose mode has no neg
                    if pred_type != "NA":
                        if not is_overflow[eae_instance_idx]:
                            check_pred_len(pred=preds[eae_instance_idx], item=item, language=language)

                        candidates, label_names_per_trigger = get_eae_candidates(item=item, trigger=trigger)
                        label_names.extend(label_names_per_trigger)

                        # loop for converting
                        for candi in candidates:
                            # get word positions
                            left_pos, right_pos = get_left_and_right_pos(text=text, trigger=candi, language=language)
                            # get predictions
                            pred = get_pred_per_mention(left_pos, right_pos, preds[eae_instance_idx], data_args.id2role)
                            # record results
                            results.append(pred)

                        eae_instance_idx += 1

        assert len(preds) == eae_instance_idx

    pos_labels = list(set(label_names))
    pos_labels.remove("NA")
    micro_f1 = f1_score(label_names, results, labels=pos_labels, average="micro") * 100.0

    logger.info('Number of Instances: {}'.format(eae_instance_idx))
    logger.info("{} test performance after converting: {}".format(data_args.dataset_name, micro_f1))
    return results

get_argument_extraction_mrc

Obtains the event argument extraction prediction results of the ACE2005 dataset based on the MRC paradigm, predicting the labels of entities and negative triggers and calculating the micro F1 score based on the predictions and labels.

Args:

  • preds: A list of strings indicating the predicted types of the instances.

  • labels: A list of strings indicating the actual labels of the instances.

  • data_file: A string indicating the path of the testing data file.

  • data_args: The pre-defined arguments for data processing.

Returns:

  • results: A list of strings indicating the prediction results of event arguments.

def get_argument_extraction_mrc(preds, labels, data_file, data_args, is_overflow):
    """Obtains the event argument extraction results of the ACE2005 dataset based on the MRC paradigm.
    Obtains the event argument extraction prediction results of the ACE2005 dataset based on the MRC paradigm,
    predicting the labels of entities and negative triggers and calculating the micro F1 score based on the
    predictions and labels.
    Args:
        preds (`np.array`):
            A list of strings indicating the predicted types of the instances.
        labels (`np.array`):
            A list of strings indicating the actual labels of the instances.
        data_file (`str`):
            A string indicating the path of the testing data file.
        data_args:
            The pre-defined arguments for data processing.
        is_overflow:
            A list of booleans indicating whether the tokenized sequence of each instance overflows the maximum
            input length.
    Returns:
        results (`List[str]`):
            A list of strings indicating the prediction results of event arguments.
    """

    # evaluation mode
    eval_mode = data_args.eae_eval_mode
    golden_trigger = data_args.golden_trigger
    language = data_args.language

    # pred events
    event_preds = get_event_preds(pred_file=data_args.test_pred_file)

    # get per-word predictions
    results = []
    all_labels = []
    with open(data_args.test_file, "r", encoding="utf-8") as f:
        trigger_idx = 0
        eae_instance_idx = 0
        lines = f.readlines()
        for line in lines:
            item = json.loads(line.strip())
            text = item["text"]

            # preds per index
            preds_per_idx = []
            for pred in preds:
                if pred[-1] == trigger_idx:
                    preds_per_idx.append(pred)

            for event in item["events"]:
                for trigger in event["triggers"]:
                    true_type = event["type"]
                    pred_type = true_type if golden_trigger or event_preds is None else event_preds[trigger_idx]
                    trigger_idx += 1

                    if eval_mode in ['default', 'loose']:
                        if pred_type == "NA":
                            continue

                    # get candidates
                    candidates, labels_per_idx = get_eae_candidates(item, trigger)
                    all_labels.extend(labels_per_idx)

                    # loop for converting
                    for cid, candi in enumerate(candidates):
                        label = labels_per_idx[cid]
                        if pred_type == true_type:
                            # get word positions
                            left_pos, right_pos = get_left_and_right_pos(text=text, trigger=candi, language=language)
                            # get predictions
                            pred_role = get_pred_per_mention(pos_start=left_pos, pos_end=right_pos, preds=preds_per_idx,
                                                             label=label, paradigm='mrc')
                        else:
                            pred_role = "NA"
                        # record results
                        results.append(pred_role)
                    eae_instance_idx += 1

            # negative triggers
            for trigger in item["negative_triggers"]:
                true_type = "NA"
                pred_type = true_type if golden_trigger or event_preds is None else event_preds[trigger_idx]
                trigger_idx += 1

                if eval_mode in ['default', 'strict']:  # loose mode has no neg
                    if pred_type != "NA":
                        # get candidates
                        candidates, labels_per_idx = get_eae_candidates(item, trigger)
                        all_labels.extend(labels_per_idx)

                        # loop for converting
                        for candi in candidates:
                            label = "NA"
                            # get word positions
                            left_pos, right_pos = get_left_and_right_pos(text=text, trigger=candi, language=language)
                            # get predictions
                            pred_role = get_pred_per_mention(pos_start=left_pos, pos_end=right_pos, preds=preds_per_idx,
                                                             label=label, paradigm='mrc')
                            # record results
                            results.append(pred_role)

                        eae_instance_idx += 1

    pos_labels = list(data_args.role2id.keys())
    pos_labels.remove("NA")
    micro_f1 = f1_score(all_labels, results, labels=pos_labels, average="micro") * 100.0

    logger.info('Number of Instances: {}'.format(eae_instance_idx))
    logger.info("{} test performance after converting: {}".format(data_args.dataset_name, micro_f1))
    return results
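
The micro F1 above is computed only over the positive role labels, so "NA" never counts as a class of its own and enters the score only through the positions it occupies. A minimal sketch of this behaviour, assuming f1_score is sklearn.metrics.f1_score and using hypothetical role labels:

>>> from sklearn.metrics import f1_score
>>> all_labels = ["Victim", "NA", "Place"]  # gold role per candidate
>>> results = ["Victim", "Place", "NA"]     # converted prediction per candidate
>>> pos_labels = ["Victim", "Place"]        # role2id keys minus "NA"
>>> print(f1_score(all_labels, results, labels=pos_labels, average="micro") * 100.0)
50.0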

get_trigger_detection_s2s

Obtains the event detection predictions for the ACE2005 dataset under the Seq2Seq paradigm, predicting the trigger type of each candidate and calculating the micro F1 score from the predictions and gold labels.

Args:

  • preds: A list of strings indicating the predicted types of the instances.

  • labels: A list of strings indicating the actual labels of the instances.

  • data_file: A string indicating the path of the testing data file.

  • data_args: The pre-defined arguments for data processing.

Returns:

  • results: A list of strings indicating the prediction results of event triggers.

def get_trigger_detection_s2s(preds, labels, data_file, data_args, is_overflow):
    """Obtains the event detection prediction results of the ACE2005 dataset based on the Seq2Seq paradigm.
    Obtains the event detection predictions for the ACE2005 dataset under the Seq2Seq paradigm, predicting the
    trigger type of each candidate and calculating the micro F1 score from the predictions and gold labels.
    Args:
        preds (`np.array`):
            A list of strings indicating the predicted types of the instances.
        labels (`np.array`):
            A list of strings indicating the actual labels of the instances.
        data_file (`str`):
            A string indicating the path of the testing data file.
        data_args:
            The pre-defined arguments for data processing.
        is_overflow:
            Not used under the Seq2Seq paradigm; kept for interface consistency with the other paradigms.
    Returns:
        results (`List[str]`):
            A list of strings indicating the prediction results of event triggers.
    """
    # get per-word predictions
    results = []
    label_names = []
    with open(data_file, "r", encoding='utf-8') as f:
        lines = f.readlines()
        for idx, line in enumerate(lines):
            item = json.loads(line.strip())
            text = item["text"]
            preds_per_idx = preds[idx]

            candidates, labels_per_item = get_ed_candidates(item=item)
            for i, label in enumerate(labels_per_item):
                labels_per_item[i] = get_plain_label(label)
            label_names.extend(labels_per_item)

            # loop for converting
            for cid, candidate in enumerate(candidates):
                label = labels_per_item[cid]
                # get word positions
                left_pos, right_pos = candidate["position"]
                # get predictions
                pred_type = get_pred_per_mention(pos_start=left_pos, pos_end=right_pos, preds=preds_per_idx, text=text,
                                                 label=label, label2id=data_args.type2id, paradigm='s2s')
                # record results
                results.append(pred_type)

    if "events" in item:
        micro_f1 = compute_unified_micro_f1(label_names=label_names, results=results)
        logger.info("{} test performance after converting: {}".format(data_args.dataset_name, micro_f1))

    return results
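
The conversion functions in this section all iterate over the unified data format, one JSON object per line. A hypothetical line, pretty-printed here and restricted to the fields these loops read (all values are invented for illustration):

{
  "text": "He was shot yesterday.",
  "events": [
    {"type": "Attack",
     "triggers": [{"trigger_word": "shot", "position": [7, 11]}]}
  ],
  "negative_triggers": [{"trigger_word": "yesterday", "position": [12, 21]}]
}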

get_argument_extraction_s2s

Obtains the event argument extraction predictions for the ACE2005 dataset under the Seq2Seq paradigm, predicting the argument roles for the candidates of both gold and negative triggers and calculating the micro F1 score from the predictions and gold labels.

Args:

  • preds: A list of strings indicating the predicted types of the instances.

  • labels: A list of strings indicating the actual labels of the instances.

  • data_file: A string indicating the path of the testing data file.

  • data_args: The pre-defined arguments for data processing.

Returns:

  • results: A list of strings indicating the prediction results of event arguments.

def get_argument_extraction_s2s(preds, labels, data_file, data_args, is_overflow):
    """Obtains the event argument extraction results of the ACE2005 dataset based on the Seq2Seq paradigm.
    Obtains the event argument extraction predictions for the ACE2005 dataset under the Seq2Seq paradigm, predicting
    the argument roles for the candidates of both gold and negative triggers and calculating the micro F1 score from
    the predictions and gold labels.
    Args:
        preds (`np.array`):
            A list of strings indicating the predicted types of the instances.
        labels (`np.array`):
            A list of strings indicating the actual labels of the instances.
        data_file (`str`):
            A string indicating the path of the testing data file.
        data_args:
            The pre-defined arguments for data processing.
        is_overflow:
            Not used under the Seq2Seq paradigm; kept for interface consistency with the other paradigms.
    Returns:
        results (`List[str]`):
            A list of strings indicating the prediction results of event arguments.
    """

    # evaluation mode
    eval_mode = data_args.eae_eval_mode
    golden_trigger = data_args.golden_trigger

    # pred events
    event_preds = get_event_preds(pred_file=data_args.test_pred_file)

    # get per-word predictions
    results = []
    all_labels = []
    with open(data_args.test_file, "r", encoding="utf-8") as f:
        trigger_idx = 0
        eae_instance_idx = 0
        lines = f.readlines()
        for line in lines:
            item = json.loads(line.strip())
            text = item["text"]

            for event in item["events"]:
                for trigger in event["triggers"]:
                    true_type = event["type"]
                    pred_type = true_type if golden_trigger or event_preds is None else event_preds[trigger_idx]
                    trigger_idx += 1

                    if eval_mode in ['default', 'loose']:
                        if pred_type == "NA":
                            continue

                    # preds per index
                    preds_per_idx = preds[eae_instance_idx]
                    # get candidates
                    candidates, labels_per_idx = get_eae_candidates(item, trigger)
                    for i, label in enumerate(labels_per_idx):
                        labels_per_idx[i] = get_plain_label(label)
                    all_labels.extend(labels_per_idx)

                    # loop for converting
                    for cid, candidate in enumerate(candidates):
                        label = labels_per_idx[cid]
                        if pred_type == true_type:
                            # get word positions
                            left_pos, right_pos = candidate["position"]
                            # get predictions
                            pred_role = get_pred_per_mention(pos_start=left_pos, pos_end=right_pos, preds=preds_per_idx,
                                                             text=text, label=label, label2id=data_args.role2id,
                                                             paradigm='s2s')
                        else:
                            pred_role = "NA"
                        # record results
                        results.append(pred_role)
                    eae_instance_idx += 1

            # negative triggers
            for trigger in item["negative_triggers"]:
                true_type = "NA"
                pred_type = true_type if golden_trigger or event_preds is None else event_preds[trigger_idx]
                trigger_idx += 1

                if eval_mode in ['default', 'strict']:  # loose mode has no neg
                    if pred_type != "NA":
                        # preds per index
                        preds_per_idx = preds[eae_instance_idx]

                        # get candidates
                        candidates, labels_per_idx = get_eae_candidates(item, trigger)
                        for i, label in enumerate(labels_per_idx):
                            labels_per_idx[i] = get_plain_label(label)
                        all_labels.extend(labels_per_idx)

                        # loop for converting
                        for cid, candidate in enumerate(candidates):
                            label = labels_per_idx[cid]
                            # get word positions
                            left_pos, right_pos = candidate["position"]
                            # get predictions
                            pred_role = get_pred_per_mention(pos_start=left_pos, pos_end=right_pos, preds=preds_per_idx,
                                                             text=text, label=label, label2id=data_args.role2id,
                                                             paradigm='s2s')
                            # record results
                            results.append(pred_role)

                        eae_instance_idx += 1

        assert len(preds) == eae_instance_idx

    pos_labels = list(data_args.role2id.keys())
    pos_labels.remove("NA")
    micro_f1 = f1_score(all_labels, results, labels=pos_labels, average="micro") * 100.0

    logger.info("Number of Instances: {}".format(eae_instance_idx))
    logger.info("{} test performance after converting: {}".format(data_args.dataset_name, micro_f1))
    return results
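
The eval_mode and golden_trigger switches above control which triggers enter the argument evaluation: 'default' and 'loose' skip gold triggers whose predicted type is "NA", while 'default' and 'strict' additionally score the candidates of negative triggers that were wrongly predicted as events. A minimal configuration sketch (attribute names as used above; the file path is hypothetical):

data_args.eae_eval_mode = "default"   # one of 'gold', 'loose', 'default', 'strict'
data_args.golden_trigger = False      # True: always evaluate with the gold trigger types
data_args.test_pred_file = "output/ed/test_preds.json"   # ED predictions dumped by dump_preds below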

Dump Results

import jsonlines
import json
import numpy as np
from tqdm import tqdm
from collections import defaultdict
from typing import List, Dict, Union, Tuple
from .convert_format import get_pred_per_mention
from .metric import select_start_position
from ..input_engineering.input_utils import check_pred_len, get_left_and_right_pos

get_sentence_arguments

Gets the predicted arguments from a sentence in the Sequence Labeling paradigm.

Args:

  • input_sentence: A list of dictionaries each of which contains the word and the corresponding bio-role.

Returns:

  • arguments: A list of dictionaries each of which contains the word and the corresponding role.

def get_sentence_arguments(input_sentence: List[Dict[str, str]]) -> List[Dict[str, str]]:
    """Get the predicted arguments from a sentence in the Sequence Labeling paradigm.
    Args:
        input_sentence (`List[Dict[str, str]]`):
            A list of dictionaries each of which contains the word and the corresponding bio-role.
    Returns:
        arguments (`List[Dict[str, str]]`):
            A list of dictionaries each of which contains the word and the corresponding role.
    """
    input_sentence.append({"role": "NA", "word": "<EOS>"})
    arguments = []

    previous_role = None
    previous_arg = ""
    for item in input_sentence:
        if item["role"] != "NA" and previous_role is None:
            previous_role = item["role"]
            previous_arg = item["word"]

        elif item["role"] == previous_role:
            previous_arg += item["word"]

        elif item["role"] != "NA":
            arguments.append({"role": previous_role, "argument": previous_arg})
            previous_role = item["role"]
            previous_arg = item["word"]

        elif previous_role is not None:
            arguments.append({"role": previous_role, "argument": previous_arg})
            previous_role = None
            previous_arg = ""

    return arguments
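
A usage sketch with hypothetical inputs: consecutive words sharing the same non-"NA" role are merged into a single argument, and the words are concatenated without a separator, which suits the character-level Chinese inputs this is applied to in the DuEE submission below.

>>> sentence = [{"word": "北", "role": "地点"},
...             {"word": "京", "role": "地点"},
...             {"word": "的", "role": "NA"}]
>>> get_sentence_arguments(sentence)
[{'role': '地点', 'argument': '北京'}]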

get_maven_submission

Converts the predictions to the submission format of the MAVEN dataset and dumps the predictions into a json file.

Args:

  • preds: A list of strings indicating the predicted types of the instances.

  • instance_ids: A list of strings containing the id of each instance to be predicted.

  • result_file: A string indicating the path to place the written json file.

def get_maven_submission(preds: Union[np.array, List[str]],
                         instance_ids: List[str],
                         result_file: str) -> None:
    """Converts the predictions to the submission format of the MAVEN dataset.
    Converts the predictions to the submission format of the MAVEN dataset and dumps the predictions into a json file.
    Args:
        preds (`List[str]`):
            A list of strings indicating the predicted types of the instances.
        instance_ids (`List[str]`):
            A list of strings containing the id of each instance to be predicted.
        result_file (`str`):
            A string indicating the path to place the written json file.
    """
    all_results = defaultdict(list)
    for i, pred in enumerate(preds):
        example_id, candidate_id = instance_ids[i].split("-")
        all_results[example_id].append({
            "id": candidate_id,
            "type_id": int(pred)
        })
    with open(result_file, "w") as f:
        for data_id in all_results.keys():
            format_result = dict(id=data_id, predictions=[])
            for candidate in all_results[data_id]:
                format_result["predictions"].append(candidate)
            f.write(json.dumps(format_result) + "\n")
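
A sketch of the resulting file, assuming instance ids of the form "<document id>-<candidate id>" (the ids and type ids below are invented for illustration):

get_maven_submission(preds=["3", "0"],
                     instance_ids=["doc01-0", "doc01-1"],
                     result_file="results.jsonl")
# results.jsonl then contains one line per document:
# {"id": "doc01", "predictions": [{"id": "0", "type_id": 3}, {"id": "1", "type_id": 0}]}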

get_maven_submission_sl

Obtains the instances’ predictions in the test file of the MAVEN dataset based on the sequence labeling paradigm and converts the predictions to the dataset’s submission format. The converted predictions are dumped into a json file for submission.

Args:

  • preds: A list of strings indicating the predicted types of the instances.

  • labels: A list of strings indicating the actual labels of the instances.

  • result_file: A string indicating the path to place the written json file.

  • type2id: A dictionary containing the correspondences between event types and ids.

  • config: The configurations of the model.

def get_maven_submission_sl(preds: Union[np.array, List[str]],
                            labels: Union[np.array, List[str]],
                            is_overflow,
                            result_file: str,
                            type2id: Dict[str, int],
                            config) -> None:
    """Converts the predictions to the submission format of the MAVEN dataset based on the sequence labeling paradigm.
    Obtains the instances' predictions in the test file of the MAVEN dataset based on the sequence labeling paradigm and
    converts the predictions to the dataset's submission format. The converted predictions are dumped into a json file
    for submission.
    Args:
        preds (`List[str]`):
            A list of strings indicating the predicted types of the instances.
        labels (`List[str]`):
            A list of strings indicating the actual labels of the instances.
        is_overflow:
            A list of booleans indicating whether each instance's tokenized input overflows the maximum sequence
            length.
        result_file (`str`):
            A string indicating the path to place the written json file.
        type2id (`Dict[str, int]`):
            A dictionary containing the correspondences between event types and ids.
        config:
            The configurations of the model.
    """
    # get per-word predictions
    preds, _ = select_start_position(preds, labels, False)
    results = defaultdict(list)
    language = config.language

    with open(config.test_file, "r") as f:
        lines = f.readlines()
        for i, line in enumerate(lines):
            item = json.loads(line.strip())
            text = item["text"]

            # check for alignment
            if not is_overflow[i]:
                check_pred_len(pred=preds[i], item=item, language=language)

            for candidate in item["candidates"]:
                # get word positions
                word_pos_start, word_pos_end = get_left_and_right_pos(text=text, trigger=candidate, language=language)
                # get predictions
                pred = get_pred_per_mention(word_pos_start, word_pos_end, preds[i], config.id2type)
                # record results
                results[item["id"]].append({
                    "id": candidate["id"].split("-")[-1],
                    "type_id": int(type2id[pred]),
                })
    # dump results
    with open(result_file, "w") as f:
        for id, preds_per_doc in results.items():
            results_per_doc = dict(id=id, predictions=preds_per_doc)
            f.write(json.dumps(results_per_doc)+"\n")

get_maven_submission_seq2seq

Obtains the instances’ predictions in the test file of the MAVEN dataset based on the Sequence-to-Sequence (Seq2Seq) paradigm and converts the predictions to the dataset’s submission format. The converted predictions are dumped into a json file for submission.

Args:

  • preds: The textual predictions of the Event Type or Argument Role. A list of tuple lists, in which each tuple is (argument, role) or (trigger, event_type)

  • save_path: A string indicating the path to place the written json file.

  • data_args: The pre-defined arguments for data processing.

def get_maven_submission_seq2seq(preds: List[List[Tuple[str, str]]],
                                 save_path: str,
                                 data_args) -> None:
    """Converts the predictions to the submission format of the MAVEN dataset based on the Seq2Seq paradigm.
    Obtains the instances' predictions in the test file of the MAVEN dataset based on the Sequence-to-Sequence (Seq2Seq)
    paradigm and converts the predictions to the dataset's submission format. The converted predictions are dumped into
    a json file for submission.
    Args:
        preds (`List[List[Tuple[str, str]]]`):
            The textual predictions of the Event Type or Argument Role.
            A list of tuple lists, in which each tuple is (argument, role) or (trigger, event_type)
        save_path (`str`):
            A string indicating the path to place the written json file.
        data_args:
            The pre-defined arguments for data processing.
    """
    type2id = data_args.type2id
    results = defaultdict(list)
    with open(data_args.test_file, "r") as f:
        lines = f.readlines()
        for idx, line in enumerate(lines):
            item = json.loads(line.strip())
            text = item["text"]
            preds_per_idx = preds[idx]

            for candidate in item["candidates"]:
                label = "NA"
                left_pos, right_pos = candidate["position"]
                # get predictions
                pred_type = get_pred_per_mention(pos_start=left_pos, pos_end=right_pos, preds=preds_per_idx, text=text,
                                                 label=label, label2id=type2id, paradigm='s2s')

                # record results
                results[item["id"]].append({"id": candidate["id"].split("-")[-1], "type_id": int(type2id[pred_type])})
    # dump results
    with open(save_path, "w") as f:
        for id, preds_per_doc in results.items():
            results_per_doc = dict(id=id, predictions=preds_per_doc)
            f.write(json.dumps(results_per_doc) + "\n")

get_leven_submission

Converts the predictions to the submission format of the LEVEN dataset and dumps the predictions into a json file.

Args:

  • preds: A list of strings indicating the predicted types of the instances.

  • instance_ids: A list of strings containing the id of each instance to be predicted.

  • result_file: A string indicating the path to place the written json file.

Returns:

  • The input parameters are forwarded to the get_maven_submission() method, which converts and dumps the predictions.

def get_leven_submission(preds: Union[np.array, List[str]],
                         instance_ids: List[str],
                         result_file: str) -> None:
    """Converts the predictions to the submission format of the LEVEN dataset.
    Converts the predictions to the submission format of the LEVEN dataset and dumps the predictions into a json file.
    Args:
        preds (`List[str]`):
            A list of strings indicating the predicted types of the instances.
        instance_ids (`List[str]`):
            A list of strings containing the id of each instance to be predicted.
        result_file (`str`):
            A string indicating the path to place the written json file.
    Returns:
        The input parameters are forwarded to the `get_maven_submission()` method, which converts and dumps the
        predictions.
    """
    return get_maven_submission(preds, instance_ids, result_file)

get_leven_submission_sl

Obtains the instances’ predictions in the test file of the LEVEN dataset based on the sequence labeling paradigm and converts the predictions to the dataset’s submission format. The converted predictions are dumped into a json file for submission.

Args:

  • preds: A list of strings indicating the predicted type of the instances.

  • labels: A list of strings indicating the actual label of the instances.

  • result_file: A string indicating the path to place the written json file.

  • type2id: A dictionary containing the correspondences between event types and ids.

  • config: The configurations of the model.

Returns:

  • The input parameters are forwarded to the get_maven_submission_sl() method, which converts and dumps the predictions.

def get_leven_submission_sl(preds: Union[np.array, List[str]],
                            labels: Union[np.array, List[str]],
                            is_overflow,
                            result_file: str,
                            type2id: Dict[str, int],
                            config):
    """Converts the predictions to the submission format of the LEVEN dataset based on the sequence labeling paradigm.
    Obtains the instances' predictions in the test file of the LEVEN dataset based on the sequence labeling paradigm and
    converts the predictions to the dataset's submission format. The converted predictions are dumped into a json file
    for submission.
    Args:
        preds (`List[str]`):
            A list of strings indicating the predicted type of the instances.
        labels (`List[str]`):
            A list of strings indicating the actual label of the instances.
        is_overflow:
            A list of booleans indicating whether each instance's tokenized input overflows the maximum sequence
            length.
        result_file (`str`):
            A string indicating the path to place the written json file.
        type2id (`Dict[str, int]`):
            A dictionary containing the correspondences between event types and ids.
        config:
            The configurations of the model.
    Returns:
        The input parameters are forwarded to the `get_maven_submission_sl()` method, which converts and dumps the
        predictions.
    """
    return get_maven_submission_sl(preds, labels, is_overflow, result_file, type2id, config)

get_leven_submission_seq2seq

Obtains the instances’ predictions in the test file of the LEVEN dataset based on the Sequence-to-Sequence (Seq2Seq) paradigm and converts the predictions to the dataset’s submission format. The converted predictions are dumped into a json file for submission.

Args:

  • preds: The textual predictions of the Event Type or Argument Role. A list of tuple lists, in which each tuple is (argument, role) or (trigger, event_type)

  • save_path: A string indicating the path to place the written json file.

  • data_args: The pre-defined arguments for data processing.

Returns:

  • The input parameters are forwarded to the get_maven_submission_seq2seq() method, which converts and dumps the predictions. The submission formats of LEVEN and MAVEN are the same.

def get_leven_submission_seq2seq(preds: List[List[Tuple[str, str]]],
                                 save_path: str,
                                 data_args) -> None:
    """Converts the predictions to the submission format of the LEVEN dataset based on the Seq2Seq paradigm.
    Obtains the instances' predictions in the test file of the LEVEN dataset based on the Sequence-to-Sequence (Seq2Seq)
    paradigm and converts the predictions to the dataset's submission format. The converted predictions are dumped into
    a json file for submission.
    Args:
        preds (`List[List[Tuple[str, str]]]`):
            The textual predictions of the Event Type or Argument Role.
            A list of tuple lists, in which each tuple is (argument, role) or (trigger, event_type)
        save_path (`str`):
            A string indicating the path to place the written json file.
        data_args:
            The pre-defined arguments for data processing.
    Returns:
        The input parameters are forwarded to the `get_maven_submission_seq2seq()` method, which converts and dumps
        the predictions. The submission formats of LEVEN and MAVEN are the same.
    """
    return get_maven_submission_seq2seq(preds, save_path, data_args)

get_duee_submission_sl

Obtains the instances’ predictions in the test file of the DuEE dataset based on the sequence labeling paradigm and converts the predictions to the dataset’s submission format. The converted predictions are dumped into a json file for submission.

Args:

  • preds: A list of strings indicating the predicted types of the instances.

  • labels: A list of strings indicating the actual labels of the instances.

  • result_file: A string indicating the path to place the written json file.

  • config: The configurations of the model.

Returns:

  • all_results: A list of dictionaries containing the predictions of events.

def get_duee_submission_sl(preds: Union[np.array, List[str]],
                           labels: Union[np.array, List[str]],
                           is_overflow,
                           result_file: str,
                           config) -> List[Dict[str, Union[str, Dict]]]:
    """Converts the predictions to the submission format of the DuEE dataset based on the sequence labeling paradigm.
    Obtains the instances' predictions in the test file of the DuEE dataset based on the sequence labeling paradigm and
    converts the predictions to the dataset's submission format. The converted predictions are dumped into a json file
    for submission.
    Args:
        preds (`List[str]`):
            A list of strings indicating the predicted types of the instances.
        labels (`List[str]`):
            A list of strings indicating the actual labels of the instances.
        is_overflow:
            A list of booleans indicating whether each instance's tokenized input overflows the maximum sequence
            length.
        result_file (`str`):
            A string indicating the path to place the written json file.
        config:
            The configurations of the model.
    Returns:
        all_results (`List[Dict[str, Union[str, Dict]]]`):
            A list of dictionaries containing the predictions of events.
    """
    # trigger predictions
    ed_preds = json.load(open(config.test_pred_file))

    # get per-word predictions
    preds, labels = select_start_position(preds, labels, False)
    all_results = []

    with open(config.test_file, "r", encoding='utf-8') as f:
        trigger_idx = 0
        example_idx = 0
        lines = f.readlines()
        for line in tqdm(lines, desc='Generating DuEE1.0 Submission Files'):
            item = json.loads(line.strip())

            item_id = item["id"]
            event_list = []

            for tid, trigger in enumerate(item["candidates"]):
                pred_event_type = ed_preds[trigger_idx]
                if pred_event_type != "NA":
                    if not is_overflow[example_idx]:
                        if config.language == "English":
                            assert len(preds[example_idx]) == len(item["text"].split())
                        elif config.language == "Chinese":
                            assert len(preds[example_idx]) == len("".join(item["text"].split()))  # remove space token
                        else:
                            raise NotImplementedError

                    pred_event = dict(event_type=pred_event_type, arguments=[])
                    sentence_result = []
                    for cid, candidate in enumerate(item["candidates"]):
                        if cid == tid:
                            continue
                        char_pos = candidate["position"]
                        if config.language == "English":
                            word_pos_start = len(item["text"][:char_pos[0]].split())
                            word_pos_end = word_pos_start + len(item["text"][char_pos[0]:char_pos[1]].split())
                        elif config.language == "Chinese":
                            word_pos_start = len([w for w in item["text"][:char_pos[0]] if w.strip('\n\xa0 ')])
                            word_pos_end = len([w for w in item["text"][:char_pos[1]] if w.strip('\n\xa0 ')])
                        else:
                            raise NotImplementedError
                        # get predictions
                        pred = get_pred_per_mention(word_pos_start, word_pos_end, preds[example_idx], config.id2role)
                        sentence_result.append({"role": pred, "word": candidate["trigger_word"]})

                    pred_event["arguments"] = get_sentence_arguments(sentence_result)
                    if pred_event["arguments"]:
                        event_list.append(pred_event)

                    example_idx += 1

                trigger_idx += 1

            all_results.append({"id": item_id, "event_list": event_list})

    # dump results
    with jsonlines.open(result_file, "w") as f:
        for r in all_results:
            f.write(r)

    return all_results
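
Each line of the dumped file follows DuEE's submission schema: a document id plus an event_list whose arguments come from get_sentence_arguments above. A hypothetical prediction line (the event type and role names are illustrative, in DuEE's "类别-子类别" style):

{"id": "doc-0", "event_list": [{"event_type": "竞赛行为-夺冠", "arguments": [{"role": "冠军", "argument": "中国队"}]}]}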

Evaluation Utils

import os
import json
import shutil
import logging
import jsonlines
import numpy as np

from tqdm import tqdm
from pathlib import Path
from typing import List, Dict, Union, Tuple
from transformers import PreTrainedTokenizer

from ..trainer import Trainer
from ..trainer_seq2seq import Seq2SeqTrainer
from ..arguments import DataArguments, ModelArguments, TrainingArguments
from ..input_engineering.seq2seq_processor import extract_argument
from ..input_engineering.base_processor import EDDataProcessor, EAEDataProcessor
from ..input_engineering.mrc_converter import make_predictions, find_best_thresh

from .convert_format import get_trigger_detection_sl, get_trigger_detection_s2s

logger = logging.getLogger(__name__)

dump_preds

Save the Event Detection predictions for further use in the Event Argument Extraction task.

Args:

  • trainer: The trainer for event detection.

  • tokenizer: The tokenizer used for the tokenization process.

  • data_class: The processor of the input data.

  • output_dir: The file path to dump the event detection predictions.

  • model_args: The pre-defined arguments for model configuration.

  • data_args: The pre-defined arguments for data processing.

  • training_args: The pre-defined arguments for training event detection model.

  • mode: The mode of the prediction, can be ‘train’, ‘valid’ or ‘test’.

def dump_preds(trainer: Union[Trainer, Seq2SeqTrainer],
               tokenizer: PreTrainedTokenizer,
               data_class: type,
               output_dir: Union[str,Path],
               model_args: ModelArguments,
               data_args: DataArguments,
               training_args: TrainingArguments,
               mode: str = "train",
               ) -> None:
    """Dump the Event Detection predictions for each token in the dataset.
    Save the Event Detection predictions for further use in the Event Argument Extraction task.
    Args:
        trainer:
            The trainer for event detection.
        tokenizer (`PreTrainedTokenizer`):
            The tokenizer used for the tokenization process.
        data_class:
            The processor of the input data.
        output_dir (`Union[str, Path]`):
            The file path to dump the event detection predictions.
        model_args (`ModelArguments`):
            The pre-defined arguments for model configuration.
        data_args (`DataArguments`):
            The pre-defined arguments for data processing.
        training_args (`TrainingArguments`):
            The pre-defined arguments for training event detection model.
        mode (`str`):
            The mode of the prediction, can be 'train', 'valid' or 'test'.
    Returns:
        None
    """
    if mode == "train":
        data_file = data_args.train_file
    elif mode == "valid":
        data_file = data_args.validation_file
    elif mode == "test":
        data_file = data_args.test_file
    else:
        raise NotImplementedError

    logits, labels, metrics, dataset = predict(trainer=trainer, tokenizer=tokenizer, data_class=data_class,
                                               data_args=data_args, data_file=data_file,
                                               training_args=training_args)
    logger.info("\n")
    logger.info("{}-Dump Preds-{}{}".format("-" * 25, mode, "-" * 25))
    logger.info("Test file: {}, Metrics: {}, Split_Infer: {}".format(data_file, metrics, data_args.split_infer))

    preds = get_pred_s2s(logits, tokenizer) if model_args.paradigm == "seq2seq" else np.argmax(logits, axis=-1)

    if model_args.paradigm == "token_classification":
        pred_labels = [data_args.id2type[pred] for pred in preds]
    elif model_args.paradigm == "sequence_labeling":
        pred_labels = get_trigger_detection_sl(preds, labels, data_file, data_args, dataset.is_overflow)
    elif model_args.paradigm == "seq2seq":
        pred_labels = get_trigger_detection_s2s(preds, labels, data_file, data_args, None)
    else:
        raise NotImplementedError

    save_path = os.path.join(output_dir, "{}_preds.json".format(mode))

    json.dump(pred_labels, open(save_path, "w", encoding='utf-8'), ensure_ascii=False)
    logger.info("ED {} preds dumped to {}\n ED finished!".format(mode, save_path))

get_pred_s2s

Converts Seq2Seq outputs to textual Event Type predictions in the Event Detection task, or to textual Argument Role predictions in the Event Argument Extraction task.

Args:

  • logits: The token ids generated by the Seq2Seq model, which are decoded into textual predictions.

  • tokenizer: The tokenizer used for the tokenization process.

  • pred_types: The event detection predictions, only used in Event Argument Extraction task.

Returns:

  • preds: The textual predictions of the Event Type or Argument Role. A list of tuple lists, in which each tuple is (argument, role) or (trigger, event_type)

def get_pred_s2s(logits: np.array,
                 tokenizer: PreTrainedTokenizer,
                 pred_types: List[str] = None,
                 ) -> List[List[Tuple[str, str]]]:
    """Convert Seq2Seq output logits to textual Event Type Prediction or Argument Role Prediction.
    Convert Seq2Seq output logits to textual Event Type Prediction in Event Detection task,
        or to textual Argument Role Prediction in Event Argument Extraction task.
    Args:
        logits (`np.array`):
            The token ids generated by the Seq2Seq model, which are decoded into textual predictions.
        tokenizer (`PreTrainedTokenizer`):
            The tokenizer used for the tokenization process.
        pred_types (`List[str]`):
            The event detection predictions, only used in Event Argument Extraction task.
    Returns:
        preds (`List[List[Tuple[str, str]]]`):
            The textual predictions of the Event Type or Argument Role.
            A list of tuple lists, in which each tuple is (argument, role) or (trigger, event_type)
    """

    decoded_preds = tokenizer.batch_decode(logits, skip_special_tokens=False)

    def clean_str(x_str):
        for to_remove_token in [tokenizer.eos_token, tokenizer.pad_token]:
            x_str = x_str.replace(to_remove_token, '')
        return x_str.strip()

    preds = list()
    for i, pred in enumerate(decoded_preds):
        pred = clean_str(pred)
        pred_type = pred_types[i] if pred_types else "NA"
        arguments = extract_argument(pred, i, pred_type)
        tmp = list()
        for arg in arguments:
            tmp.append((arg[-1], arg[-2]))
        preds.append(tmp)

    return preds

get_pred_mrc

Converts MRC output logits to textual Event Type predictions in the Event Detection task, or to textual Argument Role predictions in the Event Argument Extraction task.

Args:

  • logits: The logits output of the MRC model.

  • training_args: The pre-defined arguments for training, used when making the MRC predictions.

Returns:

  • preds: The textual predictions of the Event Type or Argument Role. A list of tuple lists, in which each tuple is (argument, role) or (trigger, event_type)

def get_pred_mrc(logits: np.array,
                 training_args: TrainingArguments,
                 ) -> List[List[Tuple[str, str]]]:
    """Convert MRC output logits to textual Event Type Prediction or Argument Role Prediction.
    Convert MRC output logits to textual Event Type Prediction in Event Detection task,
        or to textual Argument Role Prediction in Event Argument Extraction task.
    Args:
        logits (`np.array`):
            The logits output of the MRC model.
        training_args (`TrainingArguments`):
            The pre-defined arguments for training, used when making the MRC predictions.
    Returns:
        preds (`List[List[Tuple[str, str]]]`):
            The textual predictions of the Event Type or Argument Role.
            A list of tuple lists, in which each tuple is (argument, role) or (trigger, event_type)
    """

    start_logits, end_logits = np.split(logits, 2, axis=-1)
    all_preds, all_labels = make_predictions(start_logits, end_logits, training_args)

    all_preds = sorted(all_preds, key=lambda x: x[-2])
    best_na_thresh = find_best_thresh(all_preds, all_labels)
    logger.info("Best thresh founded. %.6f" % best_na_thresh)

    final_preds = []
    for argument in all_preds:
        if argument[-2] < best_na_thresh:
            final_preds.append(argument[:-2] + argument[-1:])  # no na_prob

    return final_preds
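
The final loop keeps only the spans whose no-answer probability (the second-to-last tuple field) falls below the learned threshold and drops that field from the output. A minimal sketch with hypothetical three-field prediction tuples:

all_preds = [("New York", 0.12, "Place"), ("yesterday", 0.93, "Time")]  # (span, na_prob, role)
best_na_thresh = 0.5
final_preds = [p[:-2] + p[-1:] for p in all_preds if p[-2] < best_na_thresh]
# final_preds == [("New York", "Place")]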

predict

Predicts the test set of the Event Detection or Event Argument Extraction task. The predicted logits and labels, the evaluation metrics’ results, and the dataset are returned.

Args:

  • trainer: The trainer for event detection.

  • tokenizer: The tokenizer used for the tokenization process.

  • data_class: The processor of the input data.

  • data_args: The pre-defined arguments for data processing.

  • data_file: A string representing the file path of the dataset.

  • training_args: The pre-defined arguments for training.

Returns:

  • logits: A numpy array of integers containing the predictions from the model to be decoded.

  • labels: A numpy array of integers containing the actual labels obtained from the annotated dataset.

  • metrics: The evaluation metrics result based on the predictions and annotations.

  • dataset: An instance of the testing dataset.

def predict(trainer: Union[Trainer, Seq2SeqTrainer],
            tokenizer: PreTrainedTokenizer,
            data_class: type,
            data_args: DataArguments,
            data_file: str,
            training_args: TrainingArguments,
            ) -> Tuple[np.array, np.array, Dict, Union[EDDataProcessor, EAEDataProcessor]]:
    """Predicts the test set of the Event Detection task or Event Argument Extraction task.
    Predicts the test set of the event detection task. The predicted logits and labels, the evaluation metrics'
    results, and the dataset are returned.
    Args:
        trainer:
            The trainer for event detection.
        tokenizer (`PreTrainedTokenizer`):
            The tokenizer used for the tokenization process.
        data_class:
            The processor of the input data.
        data_args:
            The pre-defined arguments for data processing.
        data_file (`str`):
            A string representing the file path of the dataset.
        training_args (`TrainingArguments`):
            The pre-defined arguments for training.
    Returns:
        logits (`np.ndarray`):
            A numpy array of integers containing the predictions from the model to be decoded.
        labels (`np.ndarray`):
            A numpy array of integers containing the actual labels obtained from the annotated dataset.
        metrics:
            The evaluation metrics result based on the predictions and annotations.
        dataset:
            An instance of the testing dataset.
    """

    if training_args.task_name == "ED":
        pred_func = predict_sub_ed if data_args.split_infer else predict_ed
        return pred_func(trainer, tokenizer, data_class, data_args, data_file)

    elif training_args.task_name == 'EAE':
        pred_func = predict_sub_eae if data_args.split_infer else predict_eae
        return pred_func(trainer, tokenizer, data_class, data_args, training_args)

    else:
        raise NotImplementedError
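
A usage sketch of the dispatch: with training_args.task_name set to "ED", the call routes to predict_ed, or to predict_sub_ed when data_args.split_infer is enabled (the trainer and argument objects are assumed to be built as in the training examples):

training_args.task_name = "ED"
data_args.split_infer = True   # split the test file into sub-files for memory-friendly inference
logits, labels, metrics, dataset = predict(trainer, tokenizer, data_class,
                                           data_args, data_args.test_file, training_args)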

get_sub_files

Splits a large data file into several small data files for evaluation. Sometimes the test data file is too large to predict on at once due to GPU memory constraints, so we split it into several smaller files and make predictions on each.

Args:

  • input_test_file: The path to the large data file that needs to be split.

  • input_test_pred_file: The path to the Event Detection Predictions of the input_test_file. Only used in Event Argument Extraction task.

  • sub_size: The number of items contained in each split file.

Returns:

  • if input_test_pred_file is not None: (Event Argument Extraction task)
    • output_test_files, output_pred_files: The lists of paths to the split files.

  • else:
    • output_test_files: The list of paths to the split files.

def get_sub_files(input_test_file: str,
                  input_test_pred_file: str = None,
                  sub_size: int = 5000,
                  ) -> Union[List[str], Tuple[List[str], List[str]]]:
    """Split a large data file into several small data files for evaluation.
    Sometimes the test data file is too large to predict on at once due to GPU memory constraints.
    Therefore, we split the large file into several smaller ones and make predictions on each.
    Args:
        input_test_file (`str`):
            The path to the large data file that needs to be split.
        input_test_pred_file (`str`):
            The path to the Event Detection predictions of the input_test_file.
            Only used in the Event Argument Extraction task.
        sub_size (`int`):
            The number of items contained in each split file.
    Returns:
        if input_test_pred_file is not None: (Event Argument Extraction task)
            output_test_files, output_pred_files:
                The lists of paths to the split files.
        else:
            output_test_files:
                The list of paths to the split files.
    """
    test_data = list(jsonlines.open(input_test_file))
    sub_data_folder = '/'.join(input_test_file.split('/')[:-1]) + '/test_cache/'

    # clear the cache dir before split evaluate
    if os.path.isdir(sub_data_folder):
        shutil.rmtree(sub_data_folder)
        logger.info("Cleared Existing Cache Dir")

    os.makedirs(sub_data_folder, exist_ok=False)
    output_test_files = []

    pred_data, sub_pred_folder = None, None
    output_pred_files = []
    if input_test_pred_file:
        pred_data = json.load(open(input_test_pred_file, encoding='utf-8'))
        sub_pred_folder = '/'.join(input_test_pred_file.split('/')[:-1]) + '/test_cache/'
        os.makedirs(sub_pred_folder, exist_ok=True)

    pred_start = 0
    for sub_id, i in enumerate(range(0, len(test_data), sub_size)):
        test_data_sub = test_data[i: i + sub_size]
        test_file_sub = sub_data_folder + 'sub-{}.json'.format(sub_id)

        with jsonlines.open(test_file_sub, 'w') as f:
            for data in test_data_sub:
                f.write(data)

        output_test_files.append(test_file_sub)

        if input_test_pred_file:
            pred_end = pred_start + sum([len(d['candidates']) for d in test_data_sub])
            test_pred_sub = pred_data[pred_start: pred_end]
            pred_start = pred_end

            test_pred_file_sub = sub_pred_folder + 'sub-{}.json'.format(sub_id)

            with open(test_pred_file_sub, 'w', encoding='utf-8') as f:
                json.dump(test_pred_sub, f, ensure_ascii=False)

            output_pred_files.append(test_pred_file_sub)

    if input_test_pred_file:
        return output_test_files, output_pred_files

    return output_test_files
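
A usage sketch with a hypothetical path: splitting a 12,000-line test file with sub_size=5000 produces three cache files next to the input.

test_files = get_sub_files("data/processed/MAVEN/test.unified.jsonl", sub_size=5000)
# test_files == ["data/processed/MAVEN/test_cache/sub-0.json",
#                "data/processed/MAVEN/test_cache/sub-1.json",
#                "data/processed/MAVEN/test_cache/sub-2.json"]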

predict_ed

Predicts the test set of the event detection task. The predicted logits and labels, the evaluation metrics’ results, and the dataset are returned.

Args:

  • trainer: The trainer for event detection.

  • tokenizer: The tokenizer used for the tokenization process.

  • data_class: The processor of the input data.

  • data_args: The pre-defined arguments for data processing.

  • data_file: A string representing the file path of the dataset.

Returns:

  • logits: A numpy array of integers containing the predictions from the model to be decoded.

  • labels: A numpy array of integers containing the actual labels obtained from the annotated dataset.

  • metrics: The evaluation metrics result based on the predictions and annotations.

  • dataset: An instance of the testing dataset.

def predict_ed(trainer: Union[Trainer, Seq2SeqTrainer],
               tokenizer: PreTrainedTokenizer,
               data_class: type,
               data_args,
               data_file: str,
               ) -> Tuple[np.array, np.array, Dict, EDDataProcessor]:
    """Predicts the test set of the event detection task.
    Predicts the test set of the event detection task. The predicted logits and labels, the evaluation metrics'
    results, and the dataset are returned.
    Args:
        trainer:
            The trainer for event detection.
        tokenizer (`PreTrainedTokenizer`):
            The tokenizer used for the tokenization process.
        data_class:
            The processor of the input data.
        data_args:
            The pre-defined arguments for data processing.
        data_file (`str`):
            A string representing the file path of the dataset.
    Returns:
        logits (`np.ndarray`):
            A numpy array of integers containing the predictions from the model to be decoded.
        labels (`np.ndarray`):
            A numpy array of integers containing the actual labels obtained from the annotated dataset.
        metrics:
            The evaluation metrics result based on the predictions and annotations.
        dataset:
            An instance of the testing dataset.
    """
    dataset = data_class(data_args, tokenizer, data_file)
    logits, labels, metrics = trainer.predict(
        test_dataset=dataset,
        ignore_keys=["loss"]
    )
    return logits, labels, metrics, dataset

predict_sub_ed

Predicts the test set of the event detection task over a list of split data files. The logits and labels are predicted separately on each file, and the evaluation metrics’ results are calculated after concatenating the predictions. Finally, the predicted logits and labels, the evaluation metrics’ results, and the dataset are returned.

Args:

  • trainer: The trainer for event detection.

  • tokenizer: The tokenizer used for the tokenization process.

  • data_class: The processor of the input data.

  • data_args: The pre-defined arguments for data processing.

  • data_file: A string representing the file path of the dataset.

Returns:

  • logits: A numpy array of integers containing the predictions from the model to be decoded.

  • labels: A numpy array of integers containing the actual labels obtained from the annotated dataset.

  • metrics: The evaluation metrics result based on the predictions and annotations.

  • dataset: An instance of the testing dataset.

def predict_sub_ed(trainer: Union[Trainer, Seq2SeqTrainer],
                   tokenizer: PreTrainedTokenizer,
                   data_class: type,
                   data_args: DataArguments,
                   data_file: str,
                   ) -> Tuple[np.array, np.array, Dict, EDDataProcessor]:
    """Predicts the test set of the event detection task of subfile datasets.
    Predicts the test set of the event detection task over a list of split data files. The logits and labels are
    predicted separately on each file, and the evaluation metrics' results are calculated after concatenating the
    predictions. Finally, the predicted logits and labels, the evaluation metrics' results, and the dataset are
    returned.
    Args:
        trainer:
            The trainer for event detection.
        tokenizer (`PreTrainedTokenizer`):
            The tokenizer used for the tokenization process.
        data_class:
            The processor of the input data.
        data_args:
            The pre-defined arguments for data processing.
        data_file (`str`):
            A string representing the file path of the dataset.
    Returns:
        logits (`np.ndarray`):
            A numpy array of integers containing the predictions from the model to be decoded.
        labels (`np.ndarray`):
            A numpy array of integers containing the actual labels obtained from the annotated dataset.
        metrics:
            The evaluation metrics result based on the predictions and annotations.
        dataset:
            An instance of the testing dataset.
    """
    data_file_full = data_file
    data_file_list = get_sub_files(input_test_file=data_file_full,
                                   sub_size=data_args.split_infer_size)

    logits_list, labels_list = [], []
    for data_file in tqdm(data_file_list, desc='Split Evaluate'):
        data_args.truncate_in_batch = False
        logits, labels, metrics, _ = predict_ed(trainer, tokenizer, data_class, data_args, data_file)
        logits_list.append(logits)
        labels_list.append(labels)

    logits = np.concatenate(logits_list, axis=0)
    labels = np.concatenate(labels_list, axis=0)

    metrics = trainer.compute_metrics(logits=logits, labels=labels,
                                      **{"tokenizer": tokenizer, "training_args": trainer.args})

    dataset = data_class(data_args, tokenizer, data_file_full)
    return logits, labels, metrics, dataset

predict_eae

Predicts the test set of the event argument extraction task. The predicted logits and labels, the evaluation metrics’ results, and the dataset are returned.

Args:

  • trainer: The trainer for event argument extraction.

  • tokenizer: The tokenizer used for the tokenization process.

  • data_class: The processor of the input data.

  • data_args: The pre-defined arguments for data processing.

  • training_args: The pre-defined arguments for the training process.

Returns:

  • logits: A numpy array of integers containing the predictions from the model to be decoded.

  • labels: A numpy array of integers containing the actual labels obtained from the annotated dataset.

  • metrics: The evaluation metrics result based on the predictions and annotations.

  • test_dataset: An instance of the testing dataset.

def predict_eae(trainer: Union[Trainer, Seq2SeqTrainer],
                tokenizer: PreTrainedTokenizer,
                data_class: type,
                data_args: DataArguments,
                training_args: TrainingArguments,
                ) -> Tuple[np.array, np.array, Dict, EAEDataProcessor]:
    """Predicts the test set of the event argument extraction task.
    Predicts the test set of the event argument extraction task. The predicted logits and labels, the evaluation
    metrics' results, and the dataset are returned.
    Args:
        trainer:
            The trainer for event argument extraction.
        tokenizer (`PreTrainedTokenizer`):
            The tokenizer used for the tokenization process.
        data_class:
            The processor of the input data.
        data_args:
            The pre-defined arguments for data processing.
        training_args:
            The pre-defined arguments for the training process.
    Returns:
        logits (`np.ndarray`):
            A numpy array of integers containing the predictions from the model to be decoded.
        labels (`np.ndarray`):
            A numpy array of integers containing the actual labels obtained from the annotated dataset.
        metrics:
            The evaluation metrics result based on the predictions and annotations.
        test_dataset:
            An instance of the testing dataset.
    """
    test_dataset = data_class(data_args, tokenizer, data_args.test_file, data_args.test_pred_file)
    training_args.data_for_evaluation = test_dataset.get_data_for_evaluation()
    logits, labels, metrics = trainer.predict(test_dataset=test_dataset, ignore_keys=["loss"])

    return logits, labels, metrics, test_dataset

predict_sub_eae

Predicts the test set of the event argument extraction task over a list of split data files. The logits and labels are predicted separately on each file, and the evaluation metrics’ results are calculated after concatenating the predictions. Finally, the predicted logits and labels, the evaluation metrics’ results, and the dataset are returned.

Args:

  • trainer: The trainer for event argument extraction.

  • tokenizer: The tokenizer used for the tokenization process.

  • data_class: The processor of the input data.

  • data_args: The pre-defined arguments for data processing.

  • training_args: The pre-defined arguments for the training process.

Returns:

  • logits: A numpy array of integers containing the predictions from the model to be decoded.

  • labels: A numpy array of integers containing the actual labels obtained from the annotated dataset.

  • metrics: The evaluation metrics result based on the predictions and annotations.

  • test_dataset: An instance of the testing dataset.

def predict_sub_eae(trainer: Union[Trainer, Seq2SeqTrainer],
                    tokenizer: PreTrainedTokenizer,
                    data_class: type,
                    data_args: DataArguments,
                    training_args: TrainingArguments,
                    ) -> Tuple[np.array, np.array, Dict, EAEDataProcessor]:
    """Predicts the test set of the event detection task of subfile datasets.
    Predicts the test set of the event detection task of a list of datasets. The prediction of logits and labels are
    conducted separately on each file, and the evaluation metrics' results are calculated after concatenating the
    predictions together. Finally, the prediction of logits and labels, evaluation metrics' results, and the dataset
    would be returned.
    Args:
        trainer:
            The trainer for event detection.
        tokenizer (`PreTrainedTokenizer`):
            A string indicating the tokenizer proposed for the tokenization process.
        data_class:
            The processor of the input data.
        data_args:
            The pre-defined arguments for data processing.
        training_args:
            The pre-defined arguments for the training process.
    Returns:
        logits (`np.ndarray`):
            An numpy array of integers containing the predictions from the model to be decoded.
        labels: (`np.ndarray`):
            An numpy array of integers containing the actual labels obtained from the annotated dataset.
        metrics:
            The evaluation metrics result based on the predictions and annotations.
        test_dataset:
            An instance of the testing dataset.
    """
    test_file_full, test_pred_file_full = data_args.test_file, data_args.test_pred_file
    test_file_list, test_pred_file_list = get_sub_files(input_test_file=test_file_full,
                                                        input_test_pred_file=test_pred_file_full,
                                                        sub_size=data_args.split_infer_size)

    logits_list, labels_list = [], []
    for test_file, test_pred_file in tqdm(list(zip(test_file_list, test_pred_file_list)), desc='Split Evaluate'):
        data_args.test_file = test_file
        data_args.test_pred_file = test_pred_file

        logits, labels, metrics, _ = predict_eae(trainer, tokenizer, data_class, data_args, training_args)
        logits_list.append(logits)
        labels_list.append(labels)

    # TODO: concat operation is slow
    logits = np.concatenate(logits_list, axis=0)
    labels = np.concatenate(labels_list, axis=0)

    test_dataset_full = data_class(data_args, tokenizer, test_file_full, test_pred_file_full)
    training_args.data_for_evaluation = test_dataset_full.get_data_for_evaluation()

    metrics = trainer.compute_metrics(logits=logits, labels=labels,
                                      **{"tokenizer": tokenizer, "training_args": training_args})

    data_args.test_file = test_file_full
    data_args.test_pred_file = test_pred_file_full

    test_dataset = data_class(data_args, tokenizer, data_args.test_file, data_args.test_pred_file)
    return logits, labels, metrics, test_dataset