-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy patha3_sampling.py
More file actions
254 lines (217 loc) · 12.2 KB
/
Copy patha3_sampling.py
File metadata and controls
254 lines (217 loc) · 12.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
import torch
from torch.nn import functional as F
from typing import Any, Dict
from a3_utils import *
from transformers import (
T5Tokenizer,
T5ForConditionalGeneration
)
class TopKSamplerForT5(GeneratorForT5):
    ###########################################################################
    # NOTE: Caution - do not modify the args to the class + the args of
    # the sample function.
    #
    # However, feel free to add as many helper functions in this class as you want.
    ###########################################################################
    def __init__(self, model: T5ForConditionalGeneration, tokenizer: T5Tokenizer):
        super().__init__(model, tokenizer)

    def get_top_k(self, logits, temperature, top_k):
        """Return the ``top_k`` most probable next tokens after temperature scaling.

        Args:
            logits (torch.Tensor): 1-D tensor of next-token logits over the vocabulary.
            temperature (float): divisor applied to the logits before the softmax.
            top_k (int): number of highest-probability tokens to keep.

        Returns:
            tuple[list[float], list[int]]: the kept probabilities and their token
            ids, most probable first. The probabilities come from a softmax over
            the whole vocabulary, so they are NOT renormalized to sum to 1.
        """
        proba = F.softmax(logits / temperature, dim=0)
        # topk avoids sorting the full vocabulary; clamp k so an oversized
        # top_k keeps every token instead of raising.
        k = min(top_k, proba.shape[0])
        proba, indices = torch.topk(proba, k, dim=0)
        return proba.tolist(), indices.tolist()

    @staticmethod
    def _draw_token(proba, token_ids):
        """Sample one id from ``token_ids`` with weights proportional to ``proba``.

        ``torch.multinomial`` accepts non-negative weights that need not sum to
        one, so the truncated top-k distribution is renormalized implicitly.
        Drawing through torch (instead of NumPy) also makes generation
        reproducible under ``torch.manual_seed``.
        """
        weights = torch.tensor(proba, dtype=torch.float)
        position = torch.multinomial(weights, num_samples=1).item()
        return token_ids[position]

    def sample(
        self,
        inputs: dict,
        top_k: int,
        temperature: float,
        max_new_tokens: int,
    ) -> torch.LongTensor:
        """Generate token ids for T5ForConditionalGeneration using top-k sampling.

        At each step the next token is drawn from the ``top_k`` highest-scoring
        tokens, weighted by their temperature-scaled probabilities.

        This function always does early stopping: decoding stops as soon as the
        EOS token is sampled (the EOS token is kept in the output).
        It only handles inputs of batch size = 1 and top_k >= 1.

        ``temperature`` divides the logits before the softmax. Values below 1
        sharpen the distribution we sample from; values above 1 flatten it.

        Inherits variables and helper functions from GeneratorForT5().

        Args:
            inputs (dict): the tokenized input dictionary returned by the T5 tokenizer
            top_k (int): the number of highest probability vocabulary tokens to keep for top-k filtering/sampling
            temperature (float): the value used to modulate the next token probabilities, scales logits before softmax
            max_new_tokens (int): a limit for the amount of decoder outputs
                we desire to generate

        Returns:
            torch.LongTensor: top-k sampled sequence made of token ids of size (1,generated_seq_len)
                This should include the starting pad token!
        """
        ########################################################################
        # NOTE: Don't change this part, it's to help you debug!
        constraint_return = self.input_constraints(inputs, max_new_tokens, top_k=top_k)
        if constraint_return is None:
            return None
        else:
            max_new_tokens = constraint_return
        ########################################################################

        pad_id = self.tokenizer.pad_token_id
        eos_id = self.tokenizer.eos_token_id
        device = self.model.device
        # Reused for every cross-attention call, not just the first one.
        encoder_mask = inputs.get('attention_mask')

        model_inputs = dict(inputs)
        # The T5 decoder starts from the pad token.
        model_inputs['decoder_input_ids'] = torch.tensor([[pad_id]], device=device)
        model_inputs['decoder_attention_mask'] = torch.tensor([[1]], device=device)

        output_sentence = [pad_id]
        with torch.no_grad():
            # One full forward pass runs the encoder once; its hidden state is
            # then reused so later steps only run the decoder.
            first_output = self.model(**model_inputs)
            encoder_outputs = (first_output.encoder_last_hidden_state,)
            logits = first_output.logits[0, -1]

            for step in range(max_new_tokens):
                if step:
                    logits = self.model(
                        attention_mask=encoder_mask,
                        encoder_outputs=encoder_outputs,
                        decoder_input_ids=torch.tensor([output_sentence], device=device),
                    ).logits[0, -1]
                top_k_proba, top_k_tokens = self.get_top_k(logits, temperature, top_k)
                new_token_id = self._draw_token(top_k_proba, top_k_tokens)
                output_sentence.append(new_token_id)
                if new_token_id == eos_id:
                    break

        return torch.tensor([output_sentence])

class TopPSamplerForT5(GeneratorForT5):
    ###########################################################################
    # NOTE: Caution - do not modify the args to the class + the args of
    # the sample function.
    #
    # However, feel free to add as many helper functions in this class as you want.
    ###########################################################################
    def __init__(self, model: T5ForConditionalGeneration, tokenizer: T5Tokenizer):
        super().__init__(model, tokenizer)

    def get_top_p(self, logits, temperature, top_p):
        """Return the nucleus: the smallest most-probable token set reaching ``top_p``.

        Tokens are taken in decreasing probability until their cumulative mass
        is >= ``top_p``. The token that crosses the threshold is included, so the
        result always holds at least the top-scoring token. This covers the case
        where that token alone exceeds ``top_p``, and ``top_p <= 0``.

        Args:
            logits (torch.Tensor): 1-D tensor of next-token logits over the vocabulary.
            temperature (float): divisor applied to the logits before the softmax.
            top_p (float): cumulative probability mass to cover.

        Returns:
            tuple[list[float], list[int]]: the kept probabilities and their token
            ids, most probable first (NOT renormalized).
        """
        proba = F.softmax(logits / temperature, dim=0)
        proba, indices = torch.sort(proba, dim=0, descending=True)
        cumulative = torch.cumsum(proba, dim=0)
        # The running total is non-decreasing, so counting the positions still
        # below top_p gives the index of the threshold-crossing token; +1 keeps
        # that token. Clamp in case float rounding never reaches top_p.
        keep = int((cumulative < top_p).sum().item()) + 1
        keep = min(keep, proba.shape[0])
        return proba[:keep].tolist(), indices[:keep].tolist()

    @staticmethod
    def _draw_token(proba, token_ids):
        """Sample one id from ``token_ids`` with weights proportional to ``proba``.

        ``torch.multinomial`` accepts non-negative weights that need not sum to
        one, so the truncated nucleus distribution is renormalized implicitly.
        Drawing through torch (instead of NumPy) also makes generation
        reproducible under ``torch.manual_seed``.
        """
        weights = torch.tensor(proba, dtype=torch.float)
        position = torch.multinomial(weights, num_samples=1).item()
        return token_ids[position]

    def sample(
        self,
        inputs: dict,
        top_p: float,
        temperature: float,
        max_new_tokens: int
    ) -> torch.LongTensor:
        """Generate token ids for T5ForConditionalGeneration using top-p sampling.

        At each step the next token is sampled from the smallest set of most
        probable tokens whose probabilities cumulatively add up to top_p or higher.

        This function always does early stopping: decoding stops as soon as the
        EOS token is sampled (the EOS token is kept in the output).
        It only handles inputs of batch size = 1.

        If no tokens fall within the top_p cumulative probability mass (e.g.
        because the top-scoring token's probability is larger than top_p),
        the top-scoring token is sampled.

        ``temperature`` divides the logits before the softmax. Values below 1
        sharpen the distribution we sample from; values above 1 flatten it.

        Inherits variables and helper functions from GeneratorForT5().

        Args:
            inputs (dict): the tokenized input dictionary returned by the T5 tokenizer
            top_p (float): the cumulative probability mass to select the smallest
                set of most probable tokens with probabilities that
                cumulatively add up to top_p or higher.
            temperature (float): the value used to modulate the next token probabilities, scales logits before softmax
            max_new_tokens (int): a limit for the amount of decoder outputs
                we desire to generate

        Returns:
            torch.LongTensor: top-p sampled sequence made of token ids of size (1,generated_seq_len)
                This should include the starting pad token!
        """
        ########################################################################
        # NOTE: Don't change this part, it's to help you debug!
        constraint_return = self.input_constraints(inputs, max_new_tokens)
        if constraint_return is None:
            return None
        else:
            max_new_tokens = constraint_return
        ########################################################################

        pad_id = self.tokenizer.pad_token_id
        eos_id = self.tokenizer.eos_token_id
        device = self.model.device
        # Reused for every cross-attention call, not just the first one.
        encoder_mask = inputs.get('attention_mask')

        model_inputs = dict(inputs)
        # The T5 decoder starts from the pad token.
        model_inputs['decoder_input_ids'] = torch.tensor([[pad_id]], device=device)
        model_inputs['decoder_attention_mask'] = torch.tensor([[1]], device=device)

        output_sentence = [pad_id]
        with torch.no_grad():
            # One full forward pass runs the encoder once; its hidden state is
            # then reused so later steps only run the decoder.
            first_output = self.model(**model_inputs)
            encoder_outputs = (first_output.encoder_last_hidden_state,)
            logits = first_output.logits[0, -1]

            for step in range(max_new_tokens):
                if step:
                    logits = self.model(
                        attention_mask=encoder_mask,
                        encoder_outputs=encoder_outputs,
                        decoder_input_ids=torch.tensor([output_sentence], device=device),
                    ).logits[0, -1]
                top_p_proba, top_p_tokens = self.get_top_p(logits, temperature, top_p)
                new_token_id = self._draw_token(top_p_proba, top_p_tokens)
                output_sentence.append(new_token_id)
                if new_token_id == eos_id:
                    break

        return torch.tensor([output_sentence])

def main():
    """Scratch entry point: seed torch, set print precision, load t5-small.

    The loaded model and tokenizer are not used further; this function is
    space for manual experimentation.
    """
    ############################################################################
    # NOTE: You can use this space for testing but you are not required to do so!
    ############################################################################
    torch.manual_seed(421)
    torch.set_printoptions(precision=16)

    checkpoint = "t5-small"
    model = T5ForConditionalGeneration.from_pretrained(checkpoint)
    tokenizer = T5Tokenizer.from_pretrained(checkpoint)


if __name__ == '__main__':
    main()