Machine Reading Comprehension (MRC) Processor

import json
import logging

from tqdm import tqdm
from .base_processor import (
    EAEDataProcessor,
    EAEInputExample,
    EAEInputFeatures
)
from .mrc_converter import read_query_templates
from .input_utils import get_words, get_left_and_right_pos
from collections import defaultdict

logger = logging.getLogger(__name__)

EAEMRCProcessor

Data processor for Machine Reading Comprehension (MRC) for event argument extraction. The class inherits from `EAEDataProcessor` and implements the methods left undefined there, `read_examples()` and `convert_examples_to_features()`. A new method, `remove_sub_word()`, is defined to drop non-initial sub-word tokens so that token positions align with word indices; the remaining attributes and methods are reused from `EAEDataProcessor`.
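
A minimal usage sketch follows. The config fields shown are exactly those read in this file; the base `EAEDataProcessor` may require additional fields, and the tokenizer choice, field values, and file paths are illustrative assumptions (a fast tokenizer is assumed because `remove_sub_word()` relies on `BatchEncoding.word_ids()`):

from argparse import Namespace
from transformers import BertTokenizerFast

config = Namespace(
    prompt_file="prompts/mrc_templates.csv",   # hypothetical template file
    mrc_template_id=0,
    dataset_name="ACE2005-EN",                 # hypothetical dataset name
    language="English",
    golden_trigger=True,
    eae_eval_mode="default",
    max_seq_length=160,
)
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
processor = EAEMRCProcessor(config, tokenizer,
                            input_file="data/train.unified.jsonl",   # hypothetical path
                            pred_file="output/trigger_preds.json",   # hypothetical path
                            is_training=True)
print(len(processor.input_features))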

class EAEMRCProcessor(EAEDataProcessor):
    """Data processor for Machine Reading Comprehension (MRC) for event argument extraction.

    Data processor for Machine Reading Comprehension (MRC) for event argument extraction. The class is inherited from
    the `EAEDataProcessor` class, in which the undefined functions, including `read_examples()` and
    `convert_examples_to_features()` are implemented; a new function entitled `remove_sub_word()` is defined to remove
    the annotations whose word is a sub-word, the rest of the attributes and functions are multiplexed from the
    `EAEDataProcessor` class.
    """

    def __init__(self,
                config,
                tokenizer,
                input_file: str,
                pred_file: str,
                is_training: bool = False) -> None:
        """Constructs a `EAEMRCProcessor`."""
        super().__init__(config, tokenizer, pred_file, is_training)
        self.read_examples(input_file)
        self.convert_examples_to_features()

    def read_examples(self,
                    input_file: str) -> None:
        """Obtains a collection of `EAEInputExample`s for the dataset."""
        self.examples = []
        self.data_for_evaluation["golden_arguments"] = []
        trigger_idx = 0
        query_templates = read_query_templates(self.config.prompt_file,
                                            translate=self.config.dataset_name == "ACE2005-ZH")
        template_id = self.config.mrc_template_id
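        # Each line of the input file is one JSON object. Labeled data carries an "events"
        # list (each event with its "triggers" and their "arguments") plus "negative_triggers";
        # unlabeled data instead carries trigger "candidates", handled in the else-branch below.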
        with open(input_file, "r", encoding="utf-8") as f:
            for idx, line in enumerate(tqdm(f.readlines(), desc="Reading from %s" % input_file)):
                item = json.loads(line.strip())
                if "events" in item:
                    words = get_words(text=item["text"], language=self.config.language)
                    for event in item["events"]:
                        for trigger in event["triggers"]:
                            if self.is_training or self.config.golden_trigger or self.event_preds is None:
                                pred_event_type = event["type"]
                            else:
                                pred_event_type = self.event_preds[trigger_idx]
                            trigger_idx += 1

                            # Evaluation mode for EAE
                            # If predicted event type is NA:
                            #   in [default] and [loose] modes, we don't consider the trigger
                            #   in [strict] mode, we consider the trigger
                            if self.config.eae_eval_mode in ["default", "loose"] and pred_event_type == "NA":
                                continue

                            # golden label for the trigger
                            arguments_per_trigger = dict(id=trigger_idx-1,
                                                        arguments=[],
                                                        pred_type=pred_event_type,
                                                        true_type=event["type"])
                            for argument in trigger["arguments"]:
                                arguments_per_role = dict(role=argument["role"], mentions=[])
                                for mention in argument["mentions"]:
                                    left_pos, right_pos = get_left_and_right_pos(text=item["text"],
                                                                                trigger=mention,
                                                                                language=self.config.language)
                                    arguments_per_role["mentions"].append({
                                        "position": [left_pos, right_pos - 1]
                                    })
                                arguments_per_trigger["arguments"].append(arguments_per_role)
                            self.data_for_evaluation["golden_arguments"].append(arguments_per_trigger)

                            if pred_event_type == "NA":
                                assert self.config.eae_eval_mode == "strict"
                                # in strict mode, we add the gold args for the trigger but do not make predictions
                                continue

                            trigger_left, trigger_right = get_left_and_right_pos(text=item["text"],
                                                                                trigger=trigger,
                                                                                language=self.config.language)

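                            # build one MRC query per argument role defined for the (predicted)
                            # event type; the "[trigger]" placeholder in the template is replaced
                            # with the first word-piece of the trigger word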
                            for role in query_templates[pred_event_type].keys():
                                query = query_templates[pred_event_type][role][template_id]
                                query = query.replace("[trigger]", self.tokenizer.tokenize(trigger["trigger_word"])[0])
                                query = get_words(text=query, language=self.config.language)
                                if self.is_training:
                                    no_answer = True
                                    for argument in trigger["arguments"]:
                                        if argument["role"] not in query_templates[pred_event_type]:
                                            logger.warning(
                                                "No template for %s in %s" % (argument["role"], pred_event_type))
                                        if argument["role"] != role:
                                            continue
                                        no_answer = False
                                        for mention in argument["mentions"]:
                                            left_pos, right_pos = get_left_and_right_pos(text=item["text"],
                                                                                        trigger=mention,
                                                                                        language=self.config.language)
                                            example = EAEInputExample(
                                                example_id=trigger_idx-1,
                                                text=words,
                                                pred_type=pred_event_type,
                                                true_type=event["type"],
                                                input_template=query,
                                                trigger_left=trigger_left,
                                                trigger_right=trigger_right,
                                                argument_left=left_pos,
                                                argument_right=right_pos - 1,
                                                argument_role=role
                                            )
                                            self.examples.append(example)
                                    if no_answer:
                                        example = EAEInputExample(
                                            example_id=trigger_idx-1,
                                            text=words,
                                            pred_type=pred_event_type,
                                            true_type=event["type"],
                                            input_template=query,
                                            trigger_left=trigger_left,
                                            trigger_right=trigger_right,
                                            argument_left=-1,
                                            argument_right=-1,
                                            argument_role=role
                                        )
                                        self.examples.append(example)
                                else:
                                    # one instance per query
                                    example = EAEInputExample(
                                        example_id=trigger_idx-1,
                                        text=words,
                                        pred_type=pred_event_type,
                                        true_type=event["type"],
                                        input_template=query,
                                        trigger_left=trigger_left,
                                        trigger_right=trigger_right,
                                        argument_left=-1,
                                        argument_right=-1,
                                        argument_role=role
                                    )
                                    self.examples.append(example)
                    # negative triggers
                    for neg_trigger in item["negative_triggers"]:
                        if self.is_training or self.config.golden_trigger or self.event_preds is None:
                            pred_event_type = "NA"
                        else:
                            pred_event_type = self.event_preds[trigger_idx]

                        trigger_idx += 1
                        if self.config.eae_eval_mode == "loose":
                            continue
                        elif self.config.eae_eval_mode in ["default", "strict"]:
                            if pred_event_type == "NA":
                                continue
                            trigger_left, trigger_right = get_left_and_right_pos(text=item["text"],
                                                                                trigger=neg_trigger,
                                                                                language=self.config.language)
                            for role in query_templates[pred_event_type].keys():
                                query = query_templates[pred_event_type][role][template_id]
                                query = query.replace("[trigger]",
                                                    self.tokenizer.tokenize(neg_trigger["trigger_word"])[0])
                                query = get_words(text=query, language=self.config.language)
                                # one instance per query
                                example = EAEInputExample(
                                    example_id=trigger_idx-1,
                                    text=words,
                                    pred_type=pred_event_type,
                                    true_type="NA",
                                    input_template=query,
                                    trigger_left=trigger_left,
                                    trigger_right=trigger_right,
                                    argument_left=-1,
                                    argument_right=-1,
                                    argument_role=role
                                )
                                self.examples.append(example)

                        else:
                            raise ValueError("Invalid eae_eval_mode: %s" % self.config.eae_eval_mode)
                else:
                    words = get_words(text=item["text"], language=self.config.language)
                    for candi in item["candidates"]:
                        trigger_left, trigger_right = get_left_and_right_pos(text=item["text"],
                                                                            trigger=candi,
                                                                            language=self.config.language)
                        pred_event_type = self.event_preds[trigger_idx]
                        trigger_idx += 1
                        if pred_event_type != "NA":
                            for role in query_templates[pred_event_type].keys():
                                query = query_templates[pred_event_type][role][template_id]
                                query = query.replace("[trigger]", self.tokenizer.tokenize(candi["trigger_word"])[0])
                                query = get_words(text=query, language=self.config.language)
                                # one instance per query
                                example = EAEInputExample(
                                    example_id=trigger_idx-1,
                                    text=words,
                                    pred_type=pred_event_type,
                                    true_type="NA",
                                    input_template=query,
                                    trigger_left=trigger_left,
                                    trigger_right=trigger_right,
                                    argument_left=-1,
                                    argument_right=-1,
                                    argument_role=role
                                )
                                self.examples.append(example)
            if self.event_preds is not None:
                assert trigger_idx == len(self.event_preds)

    def convert_examples_to_features(self) -> None:
        """Converts the `EAEInputExample`s into `EAEInputFeatures`s."""
        self.input_features = []
        self.data_for_evaluation["text_range"] = []
        self.data_for_evaluation["text"] = []

        for example in tqdm(self.examples, desc="Processing features for MRC"):
            # context
            input_context = self.tokenizer(example.text,
                                        truncation=True,
                                        max_length=self.config.max_seq_length,
                                        is_split_into_words=True)
            # template
            input_template = self.tokenizer(example.input_template,
                                            truncation=True,
                                            padding="max_length",
                                            max_length=self.config.max_seq_length,
                                            is_split_into_words=True)

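            # keep only the first word-piece of each context word (plus the leading special
            # token) so that token positions align with the word-level argument annotations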
            input_context = self.remove_sub_word(input_context)
            # concatenate
            input_ids = input_context["input_ids"] + input_template["input_ids"]
            attention_mask = input_context["attention_mask"] + input_template["attention_mask"]
            token_type_ids = [0] * len(input_context["input_ids"]) + [1] * len(input_template["input_ids"])
            # truncation
            input_ids = input_ids[:self.config.max_seq_length]
            attention_mask = attention_mask[:self.config.max_seq_length]
            token_type_ids = token_type_ids[:self.config.max_seq_length]
            # output labels
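            # argument positions are shifted by +1 because index 0 of the sequence is the
            # leading special token; (0, 0) marks a query with no answer in the context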
            start_position = 0 if example.argument_left == -1 else example.argument_left + 1
            end_position = 0 if example.argument_right == -1 else example.argument_right + 1
            # data for evaluation
            text_range = dict()
            text_range["start"] = 1
            text_range["end"] = text_range["start"] + sum(input_context["attention_mask"][1:])
            self.data_for_evaluation["text_range"].append(text_range)
            self.data_for_evaluation["text"].append(example.text)
            # features
            features = EAEInputFeatures(
                example_id=example.example_id,
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                argument_left=start_position,
                argument_right=end_position
            )
            self.input_features.append(features)

    @staticmethod
    def remove_sub_word(inputs):
        """Keeps only the first word-piece of each word (plus the leading special token).

        Non-initial sub-word tokens are dropped so that the remaining token positions align
        one-to-one with the word-level annotations.
        """
        outputs = defaultdict(list)
        pre_word_id = -1
        for token_id, word_id in enumerate(inputs.word_ids()):
            # keep the first token (e.g. [CLS]) and the first sub-token of every new word;
            # other special tokens with word_id None (e.g. the trailing [SEP]) are dropped
            if token_id == 0 or (word_id != pre_word_id and word_id is not None):
                for key in inputs:
                    outputs[key].append(inputs[key][token_id])
            pre_word_id = word_id
        return outputs
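
For intuition, a toy trace of `remove_sub_word()` with a hypothetical WordPiece tokenization (the sub-word splits shown depend on the tokenizer vocabulary and are illustrative only):

# Suppose the two words ["playing", "football"] are encoded as
#   tokens:   ["[CLS]", "play", "##ing", "foot", "##ball", "[SEP]"]
#   word_ids: [None,     0,      0,       1,      1,        None ]
# remove_sub_word() keeps token 0 ([CLS]) and the first sub-token of each new word, producing
# ["[CLS]", "play", "foot"]; the trailing [SEP] is dropped because its word_id is None, so the
# token at position i corresponds to word i - 1 of the original text.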