11// Copyright (c) Microsoft Corporation. All rights reserved.
22// Licensed under the MIT License.
33
4+ #include < cstring>
5+
46#include < core/common/safeint.h>
57#include " word_conv_embedding.h"
68
@@ -14,6 +16,7 @@ namespace contrib {
1416void WordConvEmbedding::CharEmbeddingLookup (
1517 const int * seq_ptr,
1618 const float * char_embedding_weight_p,
19+ size_t char_embedding_table_size,
1720 size_t seq_len,
1821 size_t word_len,
1922 size_t char_embedding_size,
@@ -26,7 +29,12 @@ void WordConvEmbedding::CharEmbeddingLookup(
2629 float * cur_dst_ptr = dst + word_inx * word_len * char_embedding_size;
2730 size_t char_length_to_lookup = std::max<size_t >(words_len_ptr[word_inx], filter_width);
2831 for (size_t char_inx = 0 ; char_inx < char_length_to_lookup; char_inx++) {
29- memcpy (cur_dst_ptr, char_embedding_weight_p + (*cur_seq_ptr) * char_embedding_size, sizeof (float ) * char_embedding_size);
32+ const int char_index = *cur_seq_ptr;
33+ if (char_index >= 0 && static_cast <size_t >(char_index) < char_embedding_table_size) {
34+ memcpy (cur_dst_ptr,
35+ char_embedding_weight_p + static_cast <size_t >(char_index) * char_embedding_size,
36+ sizeof (float ) * char_embedding_size);
37+ }
3038 cur_dst_ptr += char_embedding_size;
3139 cur_seq_ptr++;
3240 }
@@ -131,7 +139,23 @@ void WordConvEmbedding::CalculateLengthOfEachWordInSequence(
131139 }
132140}
133141
134- Status WordConvEmbedding::ValidateInputShape (const TensorShape& w_conv_shape, const TensorShape& w_char_embedding_shape) const {
142+ Status WordConvEmbedding::ValidateInputShape (const TensorShape& sequence_shape, const TensorShape& w_conv_shape,
143+ const TensorShape& w_char_embedding_shape) const {
144+ if (sequence_shape.NumDimensions () <= 1 ) {
145+ return ORT_MAKE_STATUS (ONNXRUNTIME , FAIL , " Sequence input must have rank greater than 1." ,
146+ " Sequence rank: " , sequence_shape.NumDimensions ());
147+ }
148+
149+ if (w_conv_shape.NumDimensions () <= 3 ) {
150+ return ORT_MAKE_STATUS (ONNXRUNTIME , FAIL , " Conv weight input must have rank greater than 3." ,
151+ " Conv weight rank: " , w_conv_shape.NumDimensions ());
152+ }
153+
154+ if (w_char_embedding_shape.NumDimensions () <= 1 ) {
155+ return ORT_MAKE_STATUS (ONNXRUNTIME , FAIL , " Char embedding input must have rank greater than 1." ,
156+ " Char embedding rank: " , w_char_embedding_shape.NumDimensions ());
157+ }
158+
135159 if (embedding_size_ != -1 && w_conv_shape[0 ] != embedding_size_) {
136160 return ORT_MAKE_STATUS (ONNXRUNTIME , FAIL , " Conv filter size does not match embedding_size attribute." ,
137161 " embedding_size attribute: " , embedding_size_,
@@ -156,6 +180,12 @@ Status WordConvEmbedding::ValidateInputShape(const TensorShape& w_conv_shape, co
156180 " Conv kernal size 2 : " , w_conv_shape[3 ]);
157181 }
158182
183+ if (w_conv_shape[2 ] > sequence_shape[1 ]) {
184+ return ORT_MAKE_STATUS (ONNXRUNTIME , FAIL , " Conv kernel width must not exceed word length." ,
185+ " Conv kernel width: " , w_conv_shape[2 ],
186+ " Word length: " , sequence_shape[1 ]);
187+ }
188+
159189 return Status::OK ();
160190}
161191
@@ -170,7 +200,7 @@ Status WordConvEmbedding::Compute(OpKernelContext* ctx) const {
170200 const TensorShape& w_conv_shape = w_conv.Shape ();
171201 const TensorShape& w_char_embedding_shape = w_char_embedding.Shape ();
172202
173- ORT_RETURN_IF_ERROR (ValidateInputShape (w_conv_shape, w_char_embedding_shape));
203+ ORT_RETURN_IF_ERROR (ValidateInputShape (sequence_shape, w_conv_shape, w_char_embedding_shape));
174204
175205 int64_t seq_len = sequence_shape[0 ];
176206 int64_t word_len = sequence_shape[1 ];
@@ -198,6 +228,7 @@ Status WordConvEmbedding::Compute(OpKernelContext* ctx) const {
198228
199229 CharEmbeddingLookup (seq_ptr,
200230 w_char_embedding.Data <float >(),
231+ onnxruntime::narrow<size_t >(w_char_embedding_shape[0 ]),
201232 onnxruntime::narrow<size_t >(seq_len),
202233 onnxruntime::narrow<size_t >(word_len),
203234 onnxruntime::narrow<size_t >(char_embedding_size),
0 commit comments