Official implementation of GeDi: Generative Discriminator Guided Sequence Generation
Blogpost here
Colab Notebook on controlling topic using GeDi here
Sept 29, 2020: Added support for GeDi-guided GPT-3 generation (API key needed).
GeDi is a method of using class-conditional language models (which we refer to as generative discriminators, or GeDis) to guide generation from other (potentially much larger) language models. This has several advantages over finetuning large language models directly, including:
- significantly less training computation;
- maintaining the diversity of the original language model (if we finetune a large pretrained language model on a specific attribute dataset, we will likely reduce the broad generation capabilities of the model);
- teaching the language model what not to generate, which is especially useful for applications like detoxification.
GeDi is a form of discriminator-guided generation. A discriminator that can classify an attribute could be used to guide language model generation towards that attribute by classifying the sequences that result from candidate next tokens. However, using a normal discriminator (such as BERT) to do this would be very computationally expensive during generation, since every candidate next token would have to be fed to the discriminator one by one to be classified. With generative discriminators, we can instead classify all candidate next tokens efficiently during generation using Bayes' rule (see Section 3.1 of the paper). As an added bonus, generative discriminators can be used as zero-shot classifiers, and can therefore be used to guide generation towards unseen topics.
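To make the Bayes' rule step concrete, here is a minimal sketch (not the code in generate_GeDi.py) under simplifying assumptions: for each class c, the GeDi has already produced the summed log-likelihood of the prefix and a vector of next-token log-probabilities over the whole vocabulary. Because one forward pass per class yields log P(x_t | x_<t, c) for every candidate token at once, the posterior for all candidates follows from one normalization per token rather than one discriminator call per token. The function name and argument layout are hypothetical, and the sketch leaves out any extra normalization the paper may apply in Section 3.1.

```python
import math
from typing import List


def log_sum_exp(values: List[float]) -> float:
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def next_token_class_posteriors(
    prefix_logp: List[float],
    next_logp: List[List[float]],
    log_prior: List[float],
) -> List[List[float]]:
    """Hypothetical sketch: P(c | x_<t, x_t = v) for every class c and candidate token v.

    prefix_logp[c]  : log P(x_<t | c), summed over the prefix tokens
    next_logp[c][v] : log P(x_t = v | x_<t, c) for every vocabulary entry v
    log_prior[c]    : log P(c)
    """
    num_classes = len(log_prior)
    vocab_size = len(next_logp[0])
    posteriors = [[0.0] * vocab_size for _ in range(num_classes)]
    for v in range(vocab_size):
        # Joint log-score of class c with the extended sequence (prefix followed by v).
        joint = [log_prior[c] + prefix_logp[c] + next_logp[c][v] for c in range(num_classes)]
        # Bayes' rule: normalize the joint scores over classes.
        norm = log_sum_exp(joint)
        for c in range(num_classes):
            posteriors[c][v] = math.exp(joint[c] - norm)
    return posteriors
```

In the two-class case, row c of the result is the per-token class posterior that a weighting step can then combine with the large language model's next-token distribution.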
Python 3.7, PyTorch 1.4 (we recommend creating a container from the official PyTorch Docker image pytorch/pytorch:1.4-cuda10.1-cudnn7-devel).
Run scripts/setup.sh:
cd scripts
bash setup.sh
This will install the required dependencies.
- First download the models:
cd scripts
bash get_models.sh
This downloads the topic, sentiment, and detoxifier models and saves them in the folder ../pretrained_models.
- To generate, use bash run_generation.sh, which calls ../generate_GeDi.py with the appropriate arguments (set for topic generation by default).
Important arguments include:
  - --mode: can be set to topic, sentiment, or detoxify
  - --gen_type: can be set to gedi for GeDi-guided generation, cclm for class-conditional generation, or gpt2 to generate from raw GPT-2
  - --gen_length: max length of generation
  - --gedi_model_name_or_path: path to the GeDi model. If not set, the script assumes you ran bash get_models.sh and infers the model directory from the --mode argument
  - --filter_p: equal to 1 - ρ in Equation 7 of the paper
  - --target_p: equal to τ from the paper
  - --disc_weight: exponent for posterior weighting (ω in Equation 6 of the paper)
  - --fp16: converts GPT2-XL weights to fp16 for faster generation and less GPU memory usage
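The roles of --disc_weight (ω), --filter_p (1 - ρ), and --target_p (τ) can be illustrated with the sketch below. It is a hedged reading of the weighting and filtering in Equations 6 and 7, not a copy of generate_GeDi.py: the language-model distribution is multiplied by the class posterior raised to the power ω, tokens ranked by class posterior are kept until their weighted mass reaches ρ, and any token whose posterior exceeds τ is also kept. The function name and the plain-list inputs are hypothetical.

```python
from typing import List


def gedi_weighted_filtered(
    lm_probs: List[float],
    class_posterior: List[float],
    disc_weight: float,
    filter_p: float,
    target_p: float,
) -> List[float]:
    """Hypothetical sketch: reweight and filter one next-token distribution.

    lm_probs[v]        : P_LM(x_t = v | x_<t) from the large language model
    class_posterior[v] : P(c | x_<t, x_t = v) from the GeDi for the desired class c
    """
    # Weighting: P_w(v) is proportional to P_LM(v) * P(c | v) ** omega.
    weighted = [p * (q ** disc_weight) for p, q in zip(lm_probs, class_posterior)]
    total = sum(weighted)
    weighted = [w / total for w in weighted]

    # The README states --filter_p = 1 - rho, so recover rho from it.
    rho = 1.0 - filter_p

    # Keep the highest-posterior tokens until their weighted mass reaches rho.
    order = sorted(range(len(weighted)), key=lambda v: class_posterior[v], reverse=True)
    keep = set()
    mass = 0.0
    for v in order:
        keep.add(v)
        mass += weighted[v]
        if mass >= rho:
            break

    # Also keep every token whose class posterior exceeds tau (--target_p).
    keep.update(v for v, q in enumerate(class_posterior) if q > target_p)

    # Zero out all other tokens and renormalize.
    filtered = [w if v in keep else 0.0 for v, w in enumerate(weighted)]
    norm = sum(filtered)
    return [w / norm for w in filtered]
```

With disc_weight set to 0 the posterior factor is 1 for every token, so only the filtering step changes the language-model distribution; larger values push harder toward the desired class.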
Running the script lets you enter control codes and prompts for generation in a continuous loop until you exit.
- Set --mode topic in scripts/run_generation.sh.
- You will be prompted to give a topic code. The model was trained on world, sports, business, and science, but can often generate other topics zero-shot, for instance space, fire, climate, education.
- If the topic code you give is more than one BPE token, the model often struggles, because the 4 training topics were all 1 BPE token. You will be warned that this might not work, but can proceed by hitting enter again (or can type a new topic code).
- After the topic code, you will be asked to give a prompt to the model to condition on for generation.
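The multi-token warning can be reproduced with a check like the following minimal sketch. The helper is hypothetical; `encode` stands for any callable that maps a string to a list of GPT-2 BPE token ids, for example the `encode` method of a Hugging Face `GPT2Tokenizer`. Whether a leading space should be prepended to the topic (GPT-2 BPE treats " world" and "world" as different strings) depends on how the script builds its input, so that choice is left to the caller.

```python
from typing import Callable, Iterable, List


def multi_token_topics(topics: Iterable[str], encode: Callable[[str], List[int]]) -> List[str]:
    """Hypothetical helper: return the topic codes that do not encode to exactly one BPE token.

    The four training topics were each a single BPE token, so codes returned here are
    the ones the model may struggle with.
    """
    return [topic for topic in topics if len(encode(topic)) != 1]
```

For example, `multi_token_topics(["world", "space", "climate"], tokenizer.encode)` would list whichever of those codes the tokenizer splits into several pieces.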
- Set --mode sentiment in scripts/run_generation.sh.
- The model can controllably generate positive or negative text. When generalizing to other domains such as stories, this often translates to a positive/negative mood or tone of the story (since sentiment implies an opinion).
- The model is set to positive sentiment by default. You will be prompted for the opportunity to change to negative sentiment by typing n. Note that the negative model can be very negative, and this sometimes results in toxic or offensive samples.
- You will then be asked to give a prompt to the model to condition on for generation.
- Set --mode detoxify in scripts/run_generation.sh.
- This mode can be used to avoid generating toxic or offensive text.
- You will then be asked to give a prompt to the model to condition on for generation.
- GeDi can often find a way to navigate especially aggressive prompts, but it still occasionally generates toxic text for certain prompts. We observed that this can be a problem for longer generations.
- Two of the baselines we consider are generating from GPT-2 (which gives the same result regardless of the control code), and generating from the GeDi model directly as a class-conditional language model (instead of using it to guide generation from GPT-2).
- Set --gen_type gpt2 to generate from GPT-2, and --gen_type cclm to generate directly from the GeDi as a class-conditional language model. --gen_type cclm corresponds to all experiments in Section 5 of the paper, and to the CC-LM baselines in Section 6.1.
- If you have your own GPT-3 API secret key, you can use GeDi to guide decoding from GPT-3.
- This is somewhat limited, since the GPT-3 API only allows access to the top 100 next-token log probabilities.
- It reuses the settings tuned for controlling GPT-2 (which uses all next-token log probabilities); retuning them for GPT-3 could give better results.
- It is also slow (up to 1 second per token), because modifying GPT-3 decoding requires calling the API one token at a time.
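Because only the top 100 log probabilities are visible, GeDi guidance can only re-rank those candidates. Below is a minimal sketch of one greedy step, assuming the GeDi class log-posteriors have already been computed for the same token strings the API returned. The function name and dictionary inputs are hypothetical, and the scheme in generate_GeDi.py may differ (for example in how it filters tokens).

```python
import math
from typing import Dict, Tuple


def gedi_rerank_top_k(
    top_logprobs: Dict[str, float],
    class_log_posterior: Dict[str, float],
    disc_weight: float,
) -> Tuple[str, Dict[str, float]]:
    """Hypothetical sketch of one GeDi-guided greedy step over a truncated distribution.

    top_logprobs        : log P_LM(token | prefix) for only the top-k tokens the API returned
    class_log_posterior : log P(c | prefix, token) from the GeDi for those same tokens
    """
    # Tokens outside the top-k are invisible, so only these candidates can be chosen.
    scores = {
        tok: lp + disc_weight * class_log_posterior[tok]
        for tok, lp in top_logprobs.items()
    }
    # Renormalize the weighted scores over the k visible tokens (stable softmax).
    m = max(scores.values())
    z = sum(math.exp(s - m) for s in scores.values())
    probs = {tok: math.exp(s - m) / z for tok, s in scores.items()}
    # Greedy decoding: take the highest-scoring candidate.
    best = max(scores, key=scores.get)
    return best, probs
```

Each such step needs one API call to obtain `top_logprobs` for the current prefix, which is why guided GPT-3 decoding proceeds one token at a time.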
To control sentiment from GPT-3 using your API key (which should have the prefix "sk"):
pip install openai
python ../generate_GeDi.py --penalize_cond --gen_length 100 --mode sentiment --gpt3_api_key sk-xxxxxxxx
You can also try changing --mode or other arguments. To generate directly from GPT-3 without GeDi, using the same greedy decoding scheme:
python ../generate_GeDi.py --penalize_cond --gen_length 100 --mode sentiment --gen_type gpt2 --gpt3_api_key sk-xxxxxxx
- This repository includes code to train a topic GeDi using GeDi training.
- There are some differences between this training script and the one used to train the pretrained model: the pretrained model only used half of AG News, and there were some slight differences in preprocessing.
- This runs in about 5 hours on a 16GB V100 GPU on GCP.
- First, download and process the topic data:
cd scripts
bash get_data.sh
- Then run training using:
bash run_training.sh, which calls ../train_GeDi.py with the appropriate arguments.
- The directory where the model is saved is specified by the output_dir argument.
- When generating from your trained GeDi, you will need to call ../generate_GeDi.py (called from bash run_generation.sh) with --gedi_model_name_or_path set to the directory of your trained model.
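For orientation, GeDi training as described in the paper mixes a generative class-conditional language-modeling loss with a discriminative loss obtained through Bayes' rule, weighted by λ (our reading: λ times the generative loss plus 1 - λ times the discriminative loss). The sketch below computes that mixture for a single sequence from per-class sequence log-likelihoods. It assumes a uniform class prior, a per-token average for the generative term, and no length normalization or learned scaling in the discriminative term; the function name and inputs are hypothetical, and train_GeDi.py may differ in these details.

```python
import math
from typing import List


def gedi_hybrid_loss(
    seq_logp: List[float],
    true_class: int,
    num_tokens: int,
    lam: float,
) -> float:
    """Hypothetical sketch of a hybrid generative/discriminative loss for one sequence.

    seq_logp[c] : log P(x_1..x_T | c), summed over the T = num_tokens tokens
    lam         : weight on the generative term (1 - lam goes to the discriminative term)
    """
    # Generative loss: per-token negative log-likelihood under the true control code.
    loss_g = -seq_logp[true_class] / num_tokens
    # Discriminative loss: -log P(true_class | x) via Bayes' rule; a uniform prior cancels out.
    m = max(seq_logp)
    log_norm = m + math.log(sum(math.exp(lp - m) for lp in seq_logp))
    loss_d = -(seq_logp[true_class] - log_norm)
    return lam * loss_g + (1.0 - lam) * loss_d
```

With lam = 1 this reduces to ordinary class-conditional language-model training; lowering it puts more weight on the model's ability to classify sequences.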
@article{KrauseGeDi2020,
title={{GeDi: Generative Discriminator Guided Sequence Generation}},
author={Krause, Ben and Gotmare, Akhilesh Deepak and McCann, Bryan and Keskar, Nitish Shirish and Joty, Shafiq and Socher, Richard and Rajani, Nazneen Fatema},
journal={arXiv preprint arXiv:2009.06367},
year={2020}
}
The code is released under the BSD-3 License (see LICENSE.txt for details), but we also ask that users respect the following:
This software should not be used to promote or profit from violence, hate, and division, environmental destruction, abuse of human rights, or the destruction of people's physical and mental health.
