Skip to content

Commit c86a05a

Browse files
authored
ICM fixes (4/n) (microsoft#27957)
### Description Fixes ICM issue https://portal.microsofticm.com/imp/v5/incidents/details/31000000562663/summary ### Motivation and Context Fix ICMs
1 parent e227e8a commit c86a05a

3 files changed

Lines changed: 77 additions & 3 deletions

File tree

onnxruntime/contrib_ops/cpu/word_conv_embedding.cc

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
// Copyright (c) Microsoft Corporation. All rights reserved.
22
// Licensed under the MIT License.
33

4+
#include <cstring>
5+
46
#include <core/common/safeint.h>
57
#include "word_conv_embedding.h"
68

@@ -14,6 +16,7 @@ namespace contrib {
1416
void WordConvEmbedding::CharEmbeddingLookup(
1517
const int* seq_ptr,
1618
const float* char_embedding_weight_p,
19+
size_t char_embedding_table_size,
1720
size_t seq_len,
1821
size_t word_len,
1922
size_t char_embedding_size,
@@ -26,7 +29,12 @@ void WordConvEmbedding::CharEmbeddingLookup(
2629
float* cur_dst_ptr = dst + word_inx * word_len * char_embedding_size;
2730
size_t char_length_to_lookup = std::max<size_t>(words_len_ptr[word_inx], filter_width);
2831
for (size_t char_inx = 0; char_inx < char_length_to_lookup; char_inx++) {
29-
memcpy(cur_dst_ptr, char_embedding_weight_p + (*cur_seq_ptr) * char_embedding_size, sizeof(float) * char_embedding_size);
32+
const int char_index = *cur_seq_ptr;
33+
if (char_index >= 0 && static_cast<size_t>(char_index) < char_embedding_table_size) {
34+
memcpy(cur_dst_ptr,
35+
char_embedding_weight_p + static_cast<size_t>(char_index) * char_embedding_size,
36+
sizeof(float) * char_embedding_size);
37+
}
3038
cur_dst_ptr += char_embedding_size;
3139
cur_seq_ptr++;
3240
}
@@ -131,7 +139,23 @@ void WordConvEmbedding::CalculateLengthOfEachWordInSequence(
131139
}
132140
}
133141

134-
Status WordConvEmbedding::ValidateInputShape(const TensorShape& w_conv_shape, const TensorShape& w_char_embedding_shape) const {
142+
Status WordConvEmbedding::ValidateInputShape(const TensorShape& sequence_shape, const TensorShape& w_conv_shape,
143+
const TensorShape& w_char_embedding_shape) const {
144+
if (sequence_shape.NumDimensions() <= 1) {
145+
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Sequence input must have rank greater than 1.",
146+
" Sequence rank: ", sequence_shape.NumDimensions());
147+
}
148+
149+
if (w_conv_shape.NumDimensions() <= 3) {
150+
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Conv weight input must have rank greater than 3.",
151+
" Conv weight rank: ", w_conv_shape.NumDimensions());
152+
}
153+
154+
if (w_char_embedding_shape.NumDimensions() <= 1) {
155+
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Char embedding input must have rank greater than 1.",
156+
" Char embedding rank: ", w_char_embedding_shape.NumDimensions());
157+
}
158+
135159
if (embedding_size_ != -1 && w_conv_shape[0] != embedding_size_) {
136160
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Conv filter size does not match embedding_size attribute.",
137161
" embedding_size attribute: ", embedding_size_,
@@ -156,6 +180,12 @@ Status WordConvEmbedding::ValidateInputShape(const TensorShape& w_conv_shape, co
156180
" Conv kernal size 2 : ", w_conv_shape[3]);
157181
}
158182

183+
if (w_conv_shape[2] > sequence_shape[1]) {
184+
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Conv kernel width must not exceed word length.",
185+
" Conv kernel width: ", w_conv_shape[2],
186+
" Word length: ", sequence_shape[1]);
187+
}
188+
159189
return Status::OK();
160190
}
161191

@@ -170,7 +200,7 @@ Status WordConvEmbedding::Compute(OpKernelContext* ctx) const {
170200
const TensorShape& w_conv_shape = w_conv.Shape();
171201
const TensorShape& w_char_embedding_shape = w_char_embedding.Shape();
172202

173-
ORT_RETURN_IF_ERROR(ValidateInputShape(w_conv_shape, w_char_embedding_shape));
203+
ORT_RETURN_IF_ERROR(ValidateInputShape(sequence_shape, w_conv_shape, w_char_embedding_shape));
174204

175205
int64_t seq_len = sequence_shape[0];
176206
int64_t word_len = sequence_shape[1];
@@ -198,6 +228,7 @@ Status WordConvEmbedding::Compute(OpKernelContext* ctx) const {
198228

199229
CharEmbeddingLookup(seq_ptr,
200230
w_char_embedding.Data<float>(),
231+
onnxruntime::narrow<size_t>(w_char_embedding_shape[0]),
201232
onnxruntime::narrow<size_t>(seq_len),
202233
onnxruntime::narrow<size_t>(word_len),
203234
onnxruntime::narrow<size_t>(char_embedding_size),

onnxruntime/contrib_ops/cpu/word_conv_embedding.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ class WordConvEmbedding final : public OpKernel {
2626
void CharEmbeddingLookup(
2727
const int* seq_ptr,
2828
const float* char_embedding_weight_p,
29+
size_t char_embedding_table_size,
2930
size_t seq_len,
3031
size_t word_len,
3132
size_t char_embedding_size,
@@ -51,6 +52,7 @@ class WordConvEmbedding final : public OpKernel {
5152
size_t word_len) const;
5253

5354
Status ValidateInputShape(
55+
const TensorShape& sequence_shape,
5456
const TensorShape& w_conv_shape,
5557
const TensorShape& w_char_embedding_shape) const;
5658

onnxruntime/test/contrib_ops/word_conv_embedding_test.cc

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// Copyright (c) Microsoft Corporation. All rights reserved.
22
// Licensed under the MIT License.
33

4+
#include <cmath>
45
#include <vector>
56
#include "gtest/gtest.h"
67
#include "test/providers/provider_test_utils.h"
@@ -126,5 +127,45 @@ TEST(ContribOpTest, WordConvEmbedding_char_embedding_shape_conv_shape_not_match)
126127
test.Run(OpTester::ExpectResult::kExpectFailure);
127128
}
128129

130+
TEST(ContribOpTest, WordConvEmbedding_out_of_range_char_index_treated_as_padding) {
131+
OpTester test("WordConvEmbedding", 1, onnxruntime::kMSDomain);
132+
133+
test.AddAttribute<int64_t>("embedding_size", 1LL);
134+
test.AddAttribute<int64_t>("conv_window_size", 2LL);
135+
test.AddAttribute<int64_t>("char_embedding_size", 1LL);
136+
137+
test.AddInput<int>("Sequence", {1, 2}, {1, 99});
138+
test.AddInput<float>("W", {1, 1, 2, 1}, {1.0f, 1.0f});
139+
test.AddInput<float>("B", {1}, {0.0f});
140+
test.AddInput<float>("C", {2, 1}, {123.0f, 2.0f});
141+
test.AddOutput<float>("Y", {1, 1}, {std::tanh(2.0f)});
142+
143+
test.Run();
144+
}
145+
146+
TEST(ContribOpTest, WordConvEmbedding_rejects_filter_width_larger_than_word_length) {
147+
OpTester test("WordConvEmbedding", 1, onnxruntime::kMSDomain);
148+
149+
test.AddInput<int>("Sequence", {1, 2}, {1, 2});
150+
test.AddInput<float>("W", {1, 1, 3, 1}, {1.0f, 1.0f, 1.0f});
151+
test.AddInput<float>("B", {1}, {0.0f});
152+
test.AddInput<float>("C", {3, 1}, {0.0f, 1.0f, 2.0f});
153+
test.AddOutput<float>("Y", {1, 1}, {0.0f});
154+
155+
test.Run(OpTester::ExpectResult::kExpectFailure, "Conv kernel width must not exceed word length");
156+
}
157+
158+
TEST(ContribOpTest, WordConvEmbedding_rejects_sequence_rank_one) {
159+
OpTester test("WordConvEmbedding", 1, onnxruntime::kMSDomain);
160+
161+
test.AddInput<int>("Sequence", {2}, {1, 2});
162+
test.AddInput<float>("W", {1, 1, 2, 1}, {1.0f, 1.0f});
163+
test.AddInput<float>("B", {1}, {0.0f});
164+
test.AddInput<float>("C", {3, 1}, {0.0f, 1.0f, 2.0f});
165+
test.AddOutput<float>("Y", {1, 1}, {0.0f});
166+
167+
test.Run(OpTester::ExpectResult::kExpectFailure, "Sequence input must have rank greater than 1");
168+
}
169+
129170
} // namespace test
130171
} // namespace onnxruntime

0 commit comments

Comments
 (0)