# Classification Head

from .classification import LinearHead, MRCHead
from .crf import CRF

# get_head: factory that builds the head selected by config.head_type

def get_head(config):
    """Build the output head selected by ``config.head_type``.

    Supported values of ``config.head_type``:

    * ``"linear"`` -- returns ``LinearHead(config)``.
    * ``"mrc"``    -- returns ``MRCHead(config)``.
    * ``"crf"``    -- returns ``CRF(config.num_labels, batch_first=True)``.
    * ``None``, ``"none"`` or ``"None"`` -- no head; returns ``None``.

    Args:
        config: Object exposing a ``head_type`` attribute. ``num_labels`` is
            additionally read when ``head_type == "crf"``; ``LinearHead`` and
            ``MRCHead`` receive the whole config object.

    Returns:
        The constructed head, or ``None`` when no head is requested.

    Raises:
        ValueError: If ``head_type`` is not one of the supported values.
    """
    head_type = config.head_type

    if head_type == "linear":
        return LinearHead(config)
    if head_type == "mrc":
        return MRCHead(config)
    if head_type == "crf":
        return CRF(config.num_labels, batch_first=True)
    if head_type is None or head_type in ("none", "None"):
        return None

    # Wrap the value in a 1-tuple: with a bare ``%``, a tuple ``head_type``
    # would be unpacked as several format arguments and raise TypeError
    # instead of the intended ValueError.
    raise ValueError("Invalid head_type %s in config" % (head_type,))