Constraint Decoding

Note

This module is adapted from Text2Event (https://github.com/luyaojie/Text2Event).

Licensed under the MIT License.

#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Copyright Text2Event from https://github.com/luyaojie/Text2Event.
# Licensed under the MIT License.

from typing import List, Dict
import re

# Debug switches: set both to True to print the source/target tokens and the
# valid-token candidates at every decoding step.
debug = False
debug_step = False

from ..input_engineering.seq2seq_processor import type_start, type_end


def get_label_name_tree(label_name_list, tokenizer, end_symbol='<end>'):
    """Build a prefix tree (trie) over the sub-token ids of every label name.

    Each path from the root to a leaf marked with `end_symbol` spells out the
    tokenization of one label in `label_name_list`.
    """
    sub_token_tree = dict()

    label_tree = dict()
    for typename in label_name_list:
        after_tokenized = tokenizer.encode(typename, add_special_tokens=False)
        label_tree[typename] = after_tokenized

    for _, sub_label_seq in label_tree.items():
        parent = sub_token_tree
        for value in sub_label_seq:
            if value not in parent:
                parent[value] = dict()
            parent = parent[value]

        parent[end_symbol] = None

    return sub_token_tree
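
# Illustration (not part of the original source): with a T5-style tokenizer and a
# role list such as ["attacker", "victim"], the returned trie is a nested dict keyed
# by sub-token ids, roughly {id("attack"): {id("er"): {'<end>': None}},
# id("victim"): {'<end>': None}}; the concrete ids and splits depend entirely on the
# tokenizer's vocabulary.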


def match_sublist(the_list, to_match):
    """Return the (start, end) index pair of every occurrence of `to_match` in `the_list`.

    :param the_list: [1, 2, 3, 4, 5, 6, 1, 2, 4, 5]
    :param to_match: [1, 2]
    :return:
        [(0, 1), (6, 7)]
    """
    len_to_match = len(to_match)
    matched_list = list()
    for index in range(len(the_list) - len_to_match + 1):
        if to_match == the_list[index:index + len_to_match]:
            matched_list += [(index, index + len_to_match - 1)]
    return matched_list


def find_bracket_position(generated_text, _type_start, _type_end):
    """Record every position at which the start/end bracket tokens occur in `generated_text`."""
    bracket_position = {_type_start: list(), _type_end: list()}
    for index, char in enumerate(generated_text):
        if char in bracket_position:
            bracket_position[char] += [index]
    return bracket_position
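
# Illustration (not part of the original source): if the structure markers have ids
# 32100 (_type_start) and 32101 (_type_end), which are placeholder values since the
# real ids come from the tokenizer, then generated_text = [32100, 57, 19, 32101, 32100]
# yields {32100: [0, 4], 32101: [3]}.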


def generated_search_src_sequence(generated, src_sequence, end_sequence_search_tokens=None):
    """Return the source tokens that can legally follow the span generated so far.

    The generated span must be a contiguous sub-sequence of `src_sequence`; the valid
    next tokens are the tokens that follow each of its occurrences in the source,
    plus `end_sequence_search_tokens` if given.
    """
    if len(generated) == 0:
        # Nothing has been generated yet, so every source token is valid.
        return src_sequence

    matched_tuples = match_sublist(the_list=src_sequence, to_match=generated)

    valid_token = list()
    for _, end in matched_tuples:
        next_index = end + 1
        if next_index < len(src_sequence):
            valid_token += [src_sequence[next_index]]

    if end_sequence_search_tokens:
        valid_token += end_sequence_search_tokens

    return valid_token
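
# Illustration (not part of the original source): with
# src_sequence = [1, 2, 3, 4, 5, 6, 1, 2, 4] and generated = [1, 2], the sub-sequence
# [1, 2] occurs twice in the source, so the tokens allowed next are [3, 4]
# (plus end_sequence_search_tokens, when provided).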


def get_constraint_decoder(tokenizer, type_schema, source_prefix=None):
    """Build a structure-constrained decoder for the given tokenizer and type schema."""
    return StruConstraintDecoder(tokenizer=tokenizer, type_schema=type_schema, source_prefix=source_prefix)
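
# A minimal sketch (an assumption about typical usage, not part of the original
# source) of how the constraint decoder is hooked into HuggingFace generation:
# `prefix_allowed_tokens_fn` is called at every decoding step and restricts the next
# token to the returned ids. `model`, `inputs`, and `schema` are hypothetical names
# used only for illustration.
#
#   constraint_decoder = get_constraint_decoder(tokenizer, type_schema=schema,
#                                               source_prefix=None)
#
#   def prefix_allowed_tokens_fn(batch_id, sent):
#       src_sentence = inputs["input_ids"][batch_id]
#       return constraint_decoder.constraint_decoding(batch_id,
#                                                     src_sentence=src_sentence,
#                                                     tgt_generated=sent)
#
#   outputs = model.generate(**inputs, num_beams=4,
#                            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn)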


class ConstraintDecoder:
    def __init__(self, tokenizer, source_prefix):
        self.tokenizer = tokenizer
        self.source_prefix = source_prefix
        self.source_prefix_tokenized = tokenizer.encode(source_prefix,
                                                        add_special_tokens=False) if source_prefix else []

    def get_state_valid_tokens(self, src_sentence: List[int], tgt_generated: List[int]) -> List[int]:
        """Return the token ids that may be generated next; implemented by subclasses."""
        raise NotImplementedError

    def constraint_decoding(self, batch_id, src_sentence, tgt_generated):
        if self.source_prefix_tokenized:
            # Remove the source prefix before matching against the source sentence.
            src_sentence = src_sentence[len(self.source_prefix_tokenized):]

        if debug:
            print("Src:", self.tokenizer.convert_ids_to_tokens(src_sentence))
            print("Tgt:", self.tokenizer.convert_ids_to_tokens(tgt_generated))
            print(batch_id, len(tgt_generated), tgt_generated)

        valid_token_ids = self.get_state_valid_tokens(
            src_sentence.tolist(),
            tgt_generated.tolist()
        )

        return valid_token_ids


class StruConstraintDecoder(ConstraintDecoder):
    def __init__(self, tokenizer, type_schema, *args, **kwargs):
        super().__init__(tokenizer, *args, **kwargs)
        self.tree_end = '<tree-end>'
        self.type_tree = get_label_name_tree(type_schema["role_list"],
                                             tokenizer=self.tokenizer,
                                             end_symbol=self.tree_end)
        self.type_start = self.tokenizer.convert_tokens_to_ids([type_start])[0]
        self.type_end = self.tokenizer.convert_tokens_to_ids([type_end])[0]

    def check_state(self, tgt_generated):
        """Infer the decoding state from the bracket tokens generated so far."""
        if tgt_generated[-1] == self.tokenizer.pad_token_id:
            return 'start', -1

        special_token_set = {self.type_start, self.type_end}
        special_index_token = list(
            filter(lambda x: x[1] in special_token_set, list(enumerate(tgt_generated))))

        if not special_index_token:
            # No bracket token has been generated yet.
            return 'error', 0

        last_special_index, last_special_token = special_index_token[-1]

        if len(special_index_token) == 1:
            if last_special_token != self.type_start:
                return 'error', 0

        bracket_position = find_bracket_position(
            tgt_generated, _type_start=self.type_start, _type_end=self.type_end)
        start_number, end_number = len(bracket_position[self.type_start]), len(
            bracket_position[self.type_end])

        if start_number == end_number:
            return 'end_generate', -1
        if start_number == end_number + 1:
            state = 'start_first_generation'
        elif start_number == end_number + 2:
            state = 'generate_span'
        else:
            state = 'error'
        return state, last_special_index
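
    # Illustration (not part of the original source): the bracket counts decide the
    # state. One unmatched type_start means 'start_first_generation' (a new record may
    # begin or the structure may close); two unmatched means 'generate_span' (a label
    # name or text span is being produced); balanced brackets mean 'end_generate';
    # deeper nesting is treated as 'error'.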

    def search_prefix_tree_and_sequence(self, generated: List[int], prefix_tree: Dict, src_sentence: List[int],
                                        end_sequence_search_tokens: List[int] = None):
        """
        Generate Type Name + Text Span
        :param generated: token ids produced since the last bracket token
        :param prefix_tree: prefix tree (trie) over the schema label names
        :param src_sentence: token ids of the source sentence
        :param end_sequence_search_tokens: token ids that may close the current span
        :return: list of valid next token ids
        """
        tree = prefix_tree
        for index, token in enumerate(generated):
            tree = tree[token]
            is_tree_end = len(tree) == 1 and self.tree_end in tree

            if is_tree_end:
                valid_token = generated_search_src_sequence(
                    generated=generated[index + 1:],
                    src_sequence=src_sentence,
                    end_sequence_search_tokens=end_sequence_search_tokens,
                )
                return valid_token

            if self.tree_end in tree:
                try:
                    valid_token = generated_search_src_sequence(
                        generated=generated[index + 1:],
                        src_sequence=src_sentence,
                        end_sequence_search_tokens=end_sequence_search_tokens,
                    )
                    return valid_token
                except IndexError:
                    # Still search tree
                    continue

        valid_token = list(tree.keys())
        return valid_token

    def get_state_valid_tokens(self, src_sentence, tgt_generated):
        """
        :param src_sentence: token ids of the source sentence
        :param tgt_generated: token ids generated so far
        :return:
            List[int], valid next token ids
        """
        if self.tokenizer.eos_token_id in src_sentence:
            src_sentence = src_sentence[:src_sentence.index(
                self.tokenizer.eos_token_id)]

        state, index = self.check_state(tgt_generated)

        print("State: %s" % state) if debug else None

        if state == 'error':
            print("Error:")
            print("Src:", src_sentence)
            print("Tgt:", tgt_generated)
            valid_tokens = [self.tokenizer.eos_token_id]

        elif state == 'start':
            valid_tokens = [self.type_start]

        elif state == 'start_first_generation':
            valid_tokens = [self.type_start, self.type_end]

        elif state == 'generate_span':

            if tgt_generated[-1] == self.type_start:
                # Start Event Label
                return list(self.type_tree.keys())

            elif tgt_generated[-1] == self.type_end:
                raise RuntimeError('Invalid %s in %s' %
                                   (self.type_end, tgt_generated))

            else:
                try:
                    valid_tokens = self.search_prefix_tree_and_sequence(
                        generated=tgt_generated[index + 1:],
                        prefix_tree=self.type_tree,
                        src_sentence=src_sentence,
                        end_sequence_search_tokens=[self.type_end]
                    )
                except Exception:
                    # An unexpected token can slip in when len(valid_tokens) < num_beams;
                    # fall back to ending the sequence.
                    print("Warning! An unexpected token was generated because len(valid_tokens) < num_beams.")
                    valid_tokens = [self.tokenizer.eos_token_id]

        elif state == 'end_generate':
            valid_tokens = [self.tokenizer.eos_token_id]

        else:
            raise NotImplementedError(
                'State `%s` for %s is not implemented.' % (state, self.__class__))

        print("Valid: %s" % valid_tokens) if debug else None
        return valid_tokens
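
    # Illustration (not part of the original source): under these constraints the
    # decoder can only emit a linearized structure of roughly the form
    # "type_start type_start <label from the schema trie> <span copied from the
    # source> type_end ... type_end", where type_start / type_end are whatever
    # bracket tokens seq2seq_processor defines.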


class SpanConstraintDecoder(ConstraintDecoder):
    def __init__(self, tokenizer, type_schema, *args, **kwargs):
        super().__init__(tokenizer, *args, **kwargs)
        self.tree_end = '<tree-end>'
        self.type_tree = get_label_name_tree(type_schema["role_list"],
                                             tokenizer=self.tokenizer,
                                             end_symbol=self.tree_end)

    def check_state(self, tgt_generated, special_tokens_in_tgt):
        """Return the decoding state and the index right after the last <extra_id_*> marker."""
        if tgt_generated[-1] == self.tokenizer.pad_token_id:
            return 'start', -1
        else:
            index = len(tgt_generated)
            for i, token in enumerate(tgt_generated):
                if token == special_tokens_in_tgt[-1]:
                    index = i + 1
                    break
            return "generate", index

    def get_special_tokens(self, sentence):
        """Return the ids of every <extra_id_*> token appearing in `sentence`, in order."""
        special_template = re.compile(r"<extra_id_\d+>")
        tokens = self.tokenizer.convert_ids_to_tokens(sentence)
        special_tokens = []
        for token in tokens:
            if special_template.match(token) is not None:
                special_tokens.append(token)
        return self.tokenizer.convert_tokens_to_ids(special_tokens)

    def truncate_src(self, src_sentence):
        """Keep only the part of the source that precedes the first <extra_id_*> marker."""
        special_template = re.compile(r"<extra_id_\d+>")
        index = len(src_sentence)
        tokens = self.tokenizer.convert_ids_to_tokens(src_sentence)
        for i, token in enumerate(tokens):
            if special_template.match(token) is not None:
                index = i
                break
        return src_sentence[:index]
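
    # Illustration (not part of the original source): for a hypothetical template
    # source such as "<prompt text> <extra_id_0> ... <extra_id_1> ..." tokenized into
    # ids, get_special_tokens collects the ids of the <extra_id_*> markers in order,
    # while truncate_src keeps only the plain-text prefix before the first marker,
    # which is the part spans may be copied from.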

    def get_state_valid_tokens(self, src_sentence, tgt_generated):
        """
        :param src_sentence: token ids of the source sentence
        :param tgt_generated: token ids generated so far
        :return:
            List[int], valid next token ids
        """
        if self.tokenizer.eos_token_id in src_sentence:
            src_sentence = src_sentence[:src_sentence.index(
                self.tokenizer.eos_token_id)]

        special_tokens_in_src = self.get_special_tokens(src_sentence)
        special_tokens_in_gen = self.get_special_tokens(tgt_generated)

        # truncate
        src_sentence = self.truncate_src(src_sentence)

        state, index = self.check_state(tgt_generated, special_tokens_in_gen)

        if state == 'start':
            valid_tokens = [special_tokens_in_src[0]]

        elif state == 'generate':
            # Either open the next slot with an unused <extra_id_*> marker (or "[SEP]"),
            # or keep copying a span from the source sentence.
            valid_special_tokens = [self.tokenizer.convert_tokens_to_ids("[SEP]")]
            for token in special_tokens_in_src:
                if token not in special_tokens_in_gen:
                    valid_special_tokens.append(token)
            valid_tokens = generated_search_src_sequence(
                generated=tgt_generated[index:],
                src_sequence=src_sentence,
                end_sequence_search_tokens=[self.tokenizer.eos_token_id],
            )
            valid_tokens = valid_special_tokens + valid_tokens
        else:
            raise NotImplementedError(
                'State `%s` for %s is not implemented.' % (state, self.__class__))

        print("Valid: %s" % valid_tokens) if debug else None
        return valid_tokens
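
# Illustration (not part of the original source): in SpanConstraintDecoder's
# 'generate' state the next token may be (a) an <extra_id_*> marker from the source
# template that has not been used yet, or the "[SEP]" id, either of which opens the
# next slot, or (b) any source token that continues the span copied so far, or the
# EOS id once the current span can end.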