Prompt Learning Basics & PET [EACL 2021]

https://arxiv.org/pdf/2001.07676.pdf

Figure 1 of the paper uses a text sentiment classification task as its running example: is "Best pizza ever!" expressing a positive or a negative sentiment?

The main ideas proposed in the paper are:

1. The input is recast as a cloze (fill-in-the-blank) task, e.g. "Best pizza ever! It was __". This enables zero-shot prediction: a masked language model pre-trained with MLM (Masked Language Modeling), such as BERT, predicts the probability that the missing token is "great" versus "bad". In other words, the pre-trained model has already learned the semantics of the label words "great" and "bad" from its pre-training data, and can judge how plausible each one is as a completion of the text (see the sketch after this list).

2. Classification could in fact already be done with step 1 alone, but the authors argue that a classifier can be trained further: an ensemble of multiple prompt-based models is used to produce aggregated soft labels on unlabeled data.

3. The unlabeled text data together with these soft labels are then used to train a standard text classification model.
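To make idea 1 concrete, below is a minimal zero-shot cloze-scoring sketch (illustrative only, not the paper's code); the checkpoint bert-base-uncased and the label words "great"/"bad" are assumptions for this example.

import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

# Assumed checkpoint and label words; PET itself works with any MLM-pretrained model.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")
model.eval()

# Build the cloze-style input from the pattern "It was __".
text = "Best pizza ever! It was " + tokenizer.mask_token + "."
inputs = tokenizer(text, return_tensors="pt")
mask_pos = (inputs["input_ids"][0] == tokenizer.mask_token_id).nonzero(as_tuple=True)[0].item()

with torch.no_grad():
    logits = model(**inputs).logits[0, mask_pos]  # vocabulary logits at the mask position

# Compare the logits of the two label words; the higher one is the zero-shot prediction.
label_words = {"positive": "great", "negative": "bad"}
scores = {lab: logits[tokenizer.convert_tokens_to_ids(word)].item()
          for lab, word in label_words.items()}
print(scores)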

Summary and thoughts

The idea of PET can also be used for supervised few-shot learning: for labeled data, when constructing samples we first add a Pattern to each sentence; then, in addition to the Mask position that the Pattern itself contains, we randomly mask other parts of the input as well, which gives the model extra regularization.
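A rough sketch of this few-shot sample construction is shown below; the pattern text, the helper name, and the 15% masking rate are illustrative assumptions, not values from the paper.

import random

def build_pet_input(tokens, tokenizer, mask_prob=0.15):
    # The pattern carries its own mask position, which is later mapped to a label word.
    pattern = ["It", "was", tokenizer.mask_token, "."]
    # Additionally mask random tokens of the original text as extra regularization.
    noised = [tokenizer.mask_token if random.random() < mask_prob else tok for tok in tokens]
    return noised + pattern

# Usage (assuming a BERT-style tokenizer is available):
# build_pet_input("Best pizza ever !".split(), tokenizer)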

Hands-on practice

Next, we use the OpenPrompt framework for the hands-on part. OpenPrompt's overall architecture consists of the modules used in the steps below: PLM, Template, Verbalizer, PromptModel, PromptDataLoader, and PromptTrainer.

Note that not every module has to be used in every prompt-learning setup; for example, generation tasks do not learn a Verbalizer. PromptTrainer is a controller that manages the flow of data and the training process, with some prompt-specific behavior; users can also implement the training loop in the conventional way (as we do below).

Sentiment classification

Step 1. Define the task

The first step is to identify the NLP task at hand: think about what your data looks like and what you want to get out of it. In essence, this step amounts to defining the classes and the InputExamples of the task. For simplicity, we take sentiment analysis as the example.

from openprompt.data_utils import InputExample

classes = [  # sentiment analysis has two classes: negative and positive
    "negative",
    "positive"
]
dataset = [  # For simplicity, just two examples
    # text_a is the input text; some other datasets may have multiple input sentences per sample
    InputExample(
        guid=0,
        text_a="The sun is shining today.",
        label=1
    ),
    InputExample(
        guid=1,
        text_a="It has been raining every day recently.",
        label=0
    ),
]

Step 2. Get PLM

from openprompt.plms import load_plm

# load_plm returns the model, its tokenizer, the model config, and a tokenizer wrapper class
plm, tokenizer, model_config, WrapperClass = load_plm("bert", "uer/chinese_roberta_L-4_H-256")

Step 3. Define Template

Template is a modifier of the original input text and one of the most important modules in prompt learning.

from openprompt.prompts import ManualTemplate

promptTemplate = ManualTemplate(
    text='weather{"mask"}: {"placeholder":"text_a"}',
    tokenizer=tokenizer,
)

Here {"placeholder":"text_a"} will be replaced by the text_a field of each InputExample, and the {"mask"} position is where the model predicts a label word.
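As a quick sanity check (assuming OpenPrompt's wrap_one_example helper, as used in its tutorials), we can wrap one of the examples from Step 1 and inspect how the template is applied:

# Optional: inspect how the template combines with a concrete InputExample.
wrapped_example = promptTemplate.wrap_one_example(dataset[0])
print(wrapped_example)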

Step 4. Define the Verbalizer

The Verbalizer is another important component of prompt learning (though not always necessary, e.g. for generation tasks); it projects the original labels onto a set of label words.

from openprompt.prompts import ManualVerbalizer

promptVerbalizer = ManualVerbalizer(
    classes=classes,
    label_words={
        "negative": ["not good", "bad"],
        "front": ["good", "very nice"],
    },
    tokenizer=tokenizer,
)

Step 5. Build the PromptModel

Given the task, we now have a PLM, a Template, and a Verbalizer; we combine them into a PromptModel.

from openprompt import PromptForClassification

prompt_model = PromptForClassification(
    template=promptTemplate,
    plm=plm,
    verbalizer=promptVerbalizer,
)

Note that although this example simply combines the three modules, more complex interactions between them can also be defined.

Step 6. Define the DataLoader

PromptDataLoader is essentially the prompt-learning version of the PyTorch DataLoader; it additionally takes a tokenizer, a Template, and a tokenizer wrapper class:

from openprompt import PromptDataLoader

data_loader = PromptDataLoader(
    dataset=dataset,
    tokenizer=tokenizer,
    template=promptTemplate,
    tokenizer_wrapper_class=WrapperClass,
)

Step 7. Training and Inference

import torch
from transformers import AdamW

no_decay = ['bias', 'LayerNorm.weight']
# it is common practice to exclude bias and LayerNorm parameters from weight decay
optimizer_grouped_parameters = [
    {'params': [p for n, p in prompt_model.named_parameters() if not any(nd in n for nd in no_decay)],
     'weight_decay': 0.01},
    {'params': [p for n, p in prompt_model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
optimizer = AdamW(optimizer_grouped_parameters, lr=1e-4)
loss_func = torch.nn.CrossEntropyLoss()
for epoch in range(10):
    tot_loss = 0
    for step, inputs in enumerate(data_loader):
        logits = prompt_model(inputs)
        labels = inputs['label']
        loss = loss_func(logits, labels)
        loss.backward()
        tot_loss += loss.item()
        optimizer.step()
        optimizer.zero_grad()
        if step % 100 == 1:
            print("Epoch {}, average loss: {}".format(epoch, tot_loss / (step + 1)), flush=True)

# verification (for this toy example we simply reuse the training dataloader)
validation_dataloader = data_loader

prompt_model.eval()
allpreds = []
alllabels = []
with torch.no_grad():  # no gradients are needed for evaluation
    for step, inputs in enumerate(validation_dataloader):
        logits = prompt_model(inputs)
        print(logits)  # inspect the class logits
        labels = inputs['label']
        alllabels.extend(labels.cpu().tolist())
        allpreds.extend(torch.argmax(logits, dim=-1).cpu().tolist())

acc = sum([int(i == j) for i, j in zip(allpreds, alllabels)]) / len(allpreds)
print(acc)

In practice, even without any training, running the verification loop directly (i.e., zero-shot) already predicts these two examples correctly.
