How to output the list of probabilities on each token via model.generate?

Right now I have:

from transformers import GPTNeoForCausalLM, GPT2Tokenizer

model = GPTNeoForCausalLM.from_pretrained(model_name).cuda()  # model on the same device as the inputs
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()
gen_tokens = model.generate(input_ids, do_sample=specifiedDoSample, output_scores=True, temperature=specifiedTemperature, max_new_tokens=specifiedNumTokens, repetition_penalty=specifiedRepetitionPenalty, top_p=specifiedTopP)
gen_text = tokenizer.batch_decode(gen_tokens)[0]
print(gen_text)

This prints the generated text. However, I also want it to list the top N tokens at each step along with their probabilities (N being a number I specify), similar to OpenAI's beta playground where you can select "Show probabilities: Full spectrum". For example, if the prompt is "You are now a", the output for the next token should look something like {"vampire": 51%, "corpse": 32%, ...}.

What is the easiest way to do this via Huggingface Transformers?



Solution 1:[1]

You need to add output_scores=True, return_dict_in_generate=True to the call to the generate method. This gives you one scores entry per generated step; each entry is a tensor of scores (apply a softmax to get probabilities) over the whole vocabulary, for each sequence in the beam search.

Look at generation_utils.py in the transformers source tree, starting at "def generate"
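
A minimal sketch of this approach (top_n is an illustrative variable for how many candidates to show, and the generation arguments are kept simple with greedy decoding):

import torch

top_n = 5  # number of candidate tokens to display per step
outputs = model.generate(
    input_ids,
    max_new_tokens=specifiedNumTokens,
    output_scores=True,
    return_dict_in_generate=True,
)

# outputs.scores is a tuple with one tensor per generated step,
# each of shape (batch_size * num_beams, vocab_size)
for step, step_scores in enumerate(outputs.scores):
    probs = torch.softmax(step_scores[0], dim=-1)  # distribution for the first sequence
    top_probs, top_ids = probs.topk(top_n)
    tokens = tokenizer.convert_ids_to_tokens(top_ids.tolist())
    print(step, {tok: f"{p:.1%}" for tok, p in zip(tokens, top_probs.tolist())})

# with return_dict_in_generate=True the generated ids are under outputs.sequences
gen_text = tokenizer.batch_decode(outputs.sequences)[0]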

Solution 2:[2]

A potential workaround is in the thread https://github.com/huggingface/transformers/issues/10012.

Use beam search as described in the thread, with n beams where n is the number of probabilities you want to display, but only looking 1 token into the future. Then, according to a comment by mshuffett:

I just moved this line below the return_dict_in_generate block.

next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores)

I tried it and it worked perfectly. The next single token's probabilities are now displayed correctly.
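
For reference, a sketch of what that generate call might look like under this workaround (top_n is an illustrative name for the number of alternatives, and this assumes the source change from the thread has been applied):

top_n = 5
outputs = model.generate(
    input_ids,
    num_beams=top_n,
    max_new_tokens=1,  # only look one token into the future
    output_scores=True,
    return_dict_in_generate=True,
)
# the scores for the single generated step are in outputs.scores[0];
# softmax and top-k them as in the Solution 1 sketch above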

Alternatively, you can try the solutions described in https://github.com/huggingface/transformers/issues/16010. I haven't gotten around to them because they look slightly more involved than the easy workaround above.

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

[1] Solution 1
[2] Solution 2: pete