Machine Reading Comprehension (MRC) Converter
import collections
import logging
from typing import Dict, List, Optional
logger = logging.getLogger(__name__)
read_query_templates
Loads query templates from a prompt file. If a translation is required, the query templates would be translated from English to Chinese based on four types of regulations.
Args:
prompt_file
: A string indicating the path of the prompt file.translate
: A boolean variable indicating whether or not to translate the query templates into Chinese.
Returns:
query_templates
: A dictionary containing the query templates applicable for every event type and argument role.
def read_query_templates(prompt_file: str,
translate: Optional[bool] = False) -> Dict[str, Dict[str, List[str]]]:
"""Loads query templates from a prompt file.
Loads query templates from a prompt file. If a translation is required, the query templates would be translated from
English to Chinese based on four types of regulations.
Args:
prompt_file (`str`):
A string indicating the path of the prompt file.
translate (`bool`, `optional`, defaults to `False`):
A boolean variable indicating whether or not to translate the query templates into Chinese.
Returns:
query_templates (`Dict[str, Dict[str, List[str]]]`)
A dictionary containing the query templates applicable for every event type and argument role.
"""
et_translation = dict()
ar_translation = dict()
if translate:
# the event types and argument roles in ACE2005-zh are expressed in English, we translate them to Chinese
et_file = "/".join(prompt_file.split('/')[:-1]) + "/chinese_event_types.txt"
title = None
for line in open(et_file, encoding='utf-8').readlines():
num = line.split()[0]
chinese = line.split()[1][:line.split()[1].index("(")]
english = line[line.index("(") + 1:line.index(')')]
if '.' not in num:
title = chinese, english
if title:
et_translation['{}.{}'.format(title[1], english)] = "{}.{}".format(title[0], chinese)
ar_file = "/".join(prompt_file.split('/')[:-1]) + "/chinese_arg_roles.txt"
for line in open(ar_file, encoding='utf-8').readlines():
english, chinese = line.strip().split()
ar_translation[english] = chinese
query_templates = dict()
with open(prompt_file, "r", encoding='utf-8') as f:
for line in f:
event_arg, query = line.strip().split(",")
event_type, arg_name = event_arg.split("_")
if event_type not in query_templates:
query_templates[event_type] = dict()
if arg_name not in query_templates[event_type]:
query_templates[event_type][arg_name] = list()
if translate:
# 0 template arg_name
query_templates[event_type][arg_name].append(ar_translation[arg_name])
# 1 template arg_name + in trigger (replace [trigger] when forming the instance)
query_templates[event_type][arg_name].append(ar_translation[arg_name] + "在[trigger]中")
# 2 template arg_query
query_templates[event_type][arg_name].append(query)
# 3 arg_query + trigger (replace [trigger] when forming the instance)
query_templates[event_type][arg_name].append(query[:-1] + "在[trigger]中?")
else:
# 0 template arg_name
query_templates[event_type][arg_name].append(arg_name)
# 1 template arg_name + in trigger (replace [trigger] when forming the instance)
query_templates[event_type][arg_name].append(arg_name + " in [trigger]")
# 2 template arg_query
query_templates[event_type][arg_name].append(query)
# 3 arg_query + trigger (replace [trigger] when forming the instance)
query_templates[event_type][arg_name].append(query[:-1] + " in [trigger]?")
for event_type in query_templates:
for arg_name in query_templates[event_type]:
assert len(query_templates[event_type][arg_name]) == 4
return query_templates
_get_best_indexes
Gets the n-best logits from a list. The methods returns a list containing the indexes of the n-best logits that satisfies both the logits are n-best and greater than the logit of the “cls” token.
Args:
def _get_best_indexes(logits: List[int],
n_best_size: Optional[int] = 1,
larger_than_cls: Optional[bool] = False,
cls_logit: Optional[int] = None) -> List[int]:
"""Gets the n-best logits from a list.
Gets the n-best logits from a list. The methods returns a list containing the indexes of the n-best logits that
satisfies both the logits are n-best and greater than the logit of the "cls" token.
"""
index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)
best_indexes = []
for i in range(len(index_and_score)):
if i >= n_best_size:
break
if larger_than_cls:
if index_and_score[i][1] < cls_logit:
break
best_indexes.append(index_and_score[i][0])
return best_indexes
char_pos_to_word_pos
Returns the word-level position of a mention by counting the number of words before the start position of the mention.
Args:
text
: A string representing the source text that the mention is within.position
: An integer indicating the character-level position of the mention.
Returns:
An integer indicating the word-level position of the mention.
def char_pos_to_word_pos(text: str,
position: int) -> int:
"""Returns the word-level position of a mention.
Returns the word-level position of a mention by counting the number of words before the start position of the
mention.
Args:
text (`str`):
A string representing the source text that the mention is within.
position (`int`)
An integer indicating the character-level position of the mention.
Returns:
An integer indicating the word-level position of the mention.
"""
return len(text[:position].split())
make_predictions
Obtains the prediction from the Machine Reading Comprehension (MRC) model.
def make_predictions(all_start_logits, all_end_logits, training_args):
"""Obtains the prediction from the Machine Reading Comprehension (MRC) model."""
data_for_evaluation = training_args.data_for_evaluation
assert len(all_start_logits) == len(data_for_evaluation["ids"])
# all golden labels
final_all_labels = []
for arguments in data_for_evaluation["golden_arguments"]:
arguments_per_trigger = []
for argument in arguments["arguments"]:
event_argument_type = arguments["true_type"] + "_" + argument["role"]
for mention in argument["mentions"]:
arguments_per_trigger.append(
(event_argument_type, (mention["position"][0], mention["position"][1]), arguments["id"]))
final_all_labels.extend(arguments_per_trigger)
# predictions
_PrelimPrediction = collections.namedtuple("PrelimPrediction",
["start_index", "end_index", "start_logit", "end_logit"])
final_all_predictions = []
for example_id, (start_logits, end_logits) in enumerate(zip(all_start_logits, all_end_logits)):
event_argument_type = data_for_evaluation["pred_types"][example_id] + "_" + \
data_for_evaluation["roles"][example_id]
start_indexes = _get_best_indexes(start_logits, 20, True, start_logits[0])
end_indexes = _get_best_indexes(end_logits, 20, True, end_logits[0])
# add span preds
prelim_predictions = []
for start_index in start_indexes:
for end_index in end_indexes:
if start_index < data_for_evaluation["text_range"][example_id]["start"] or \
end_index < data_for_evaluation["text_range"][example_id]["start"]:
continue
if start_index >= data_for_evaluation["text_range"][example_id]["end"] or \
end_index >= data_for_evaluation["text_range"][example_id]["end"]:
continue
if end_index < start_index:
continue
word_start_index = start_index - 1
word_end_index = end_index - 1
length = word_end_index - word_start_index + 1
if length > 5:
continue
prelim_predictions.append(
_PrelimPrediction(start_index=word_start_index, end_index=word_end_index,
start_logit=start_logits[start_index], end_logit=end_logits[end_index]))
# sort
prelim_predictions = sorted(prelim_predictions, key=lambda x: (x.start_logit + x.end_logit), reverse=True)
# get final pred in format: [event_type_offset_argument_type, [start_offset, end_offset]]
max_num_pred_per_arg = 1
predictions_per_query = []
for _, pred in enumerate(prelim_predictions[:max_num_pred_per_arg]):
na_prob = (start_logits[0] + end_logits[0]) - (pred.start_logit + pred.end_logit)
predictions_per_query.append((event_argument_type, (pred.start_index, pred.end_index), na_prob,
data_for_evaluation["ids"][example_id]))
final_all_predictions.extend(predictions_per_query)
logger.info("\nAll predictions and labels generated. %d %d\n" % (len(final_all_predictions), len(final_all_labels)))
return final_all_predictions, final_all_labels
find_best_thresh
def find_best_thresh(new_preds, new_all_gold):
best_score = 0
best_na_thresh = 0
gold_arg_n, pred_arg_n = len(new_all_gold), 0
candidate_preds = []
for argument in new_preds:
candidate_preds.append(argument[:-2] + argument[-1:])
pred_arg_n += 1
pred_in_gold_n, gold_in_pred_n = 0, 0
# pred_in_gold_n
for argu in candidate_preds:
if argu in new_all_gold:
pred_in_gold_n += 1
# gold_in_pred_n
for argu in new_all_gold:
if argu in candidate_preds:
gold_in_pred_n += 1
prec_c, recall_c, f1_c = 0, 0, 0
if pred_arg_n != 0:
prec_c = 100.0 * pred_in_gold_n / pred_arg_n
else:
prec_c = 0
if gold_arg_n != 0:
recall_c = 100.0 * gold_in_pred_n / gold_arg_n
else:
recall_c = 0
if prec_c or recall_c:
f1_c = 2 * prec_c * recall_c / (prec_c + recall_c)
else:
f1_c = 0
if f1_c > best_score:
best_score = f1_c
best_na_thresh = argument[-2]
return best_na_thresh + 1e-10
compute_mrc_F1_cls
def compute_mrc_F1_cls(all_predictions, all_labels):
all_predictions = sorted(all_predictions, key=lambda x: x[-2])
# best_na_thresh = 0
best_na_thresh = find_best_thresh(all_predictions, all_labels)
print("Best thresh founded. %.6f" % best_na_thresh)
final_new_preds = []
for argument in all_predictions:
if argument[-2] < best_na_thresh:
final_new_preds.append(argument[:-2] + argument[-1:]) # no na_prob
# get results (classification)
gold_arg_n, pred_arg_n, pred_in_gold_n, gold_in_pred_n = 0, 0, 0, 0
# pred_arg_n
for argument in final_new_preds:
pred_arg_n += 1
# gold_arg_n
for argument in all_labels:
gold_arg_n += 1
# pred_in_gold_n
for argument in final_new_preds:
if argument in all_labels:
pred_in_gold_n += 1
# gold_in_pred_n
for argument in all_labels:
if argument in final_new_preds:
gold_in_pred_n += 1
prec_c, recall_c, f1_c = 0, 0, 0
if pred_arg_n != 0:
prec_c = 100.0 * pred_in_gold_n / pred_arg_n
else:
prec_c = 0
if gold_arg_n != 0:
recall_c = 100.0 * gold_in_pred_n / gold_arg_n
else:
recall_c = 0
if prec_c or recall_c:
f1_c = 2 * prec_c * recall_c / (prec_c + recall_c)
else:
f1_c = 0
logger.info("Precision: %.2f, recall: %.2f" % (prec_c, recall_c))
return f1_c