Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions src/generators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -380,8 +380,20 @@ Generator::Generator(const Model& model, const GeneratorParams& params) : model_
throw std::runtime_error("search max_length is 0");
if (params.search.max_length > model.config_->model.context_length)
throw std::runtime_error("max_length (" + std::to_string(params.search.max_length) + ") cannot be greater than model context_length (" + std::to_string(model.config_->model.context_length) + ")");
if (params.search.batch_size < 1)
throw std::runtime_error("batch_size must be 1 or greater, is " + std::to_string(params.search.batch_size));

constexpr int kMaxBatchSize = 32;
constexpr int kMaxNumBeams = 32;
constexpr int kMaxNumBeamsCuda = 32;

if (params.search.batch_size < 1 || params.search.batch_size > kMaxBatchSize)
throw std::runtime_error("batch_size (" + std::to_string(params.search.batch_size) + ") must be in [1, " + std::to_string(kMaxBatchSize) + "]");

const int max_num_beams = (params.search.num_beams > 1 &&
(params.p_device->GetType() == DeviceType::CUDA || params.p_device->GetType() == DeviceType::NvTensorRtRtx))
? kMaxNumBeamsCuda
: kMaxNumBeams;
if (params.search.num_beams < 1 || params.search.num_beams > max_num_beams)
throw std::runtime_error("num_beams (" + std::to_string(params.search.num_beams) + ") must be in [1, " + std::to_string(max_num_beams) + "]");
if (params.config.model.vocab_size < 1)
throw std::runtime_error("vocab_size must be 1 or greater, is " + std::to_string(params.config.model.vocab_size));

Expand Down
42 changes: 41 additions & 1 deletion test/model_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,46 @@ Print all primes between 1 and n
}
#endif

// Validation tests for search parameter bounds
#if !USE_DML
TEST(ModelTests, NumBeamsUpperBoundThrows) {
auto model = OgaModel::Create(MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32");
auto params = OgaGeneratorParams::Create(*model);
params->SetSearchOption("max_length", 20);
params->SetSearchOption("batch_size", 1);
params->SetSearchOption("num_beams", 512); // exceeds upper bound of 256

EXPECT_THROW(OgaGenerator::Create(*model, *params), std::runtime_error);
}

TEST(ModelTests, BatchSizeUpperBoundThrows) {
auto model = OgaModel::Create(MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32");
auto params = OgaGeneratorParams::Create(*model);
params->SetSearchOption("max_length", 20);
params->SetSearchOption("batch_size", 512); // exceeds upper bound of 256

EXPECT_THROW(OgaGenerator::Create(*model, *params), std::runtime_error);
}

TEST(ModelTests, NumBeamsZeroThrows) {
auto model = OgaModel::Create(MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32");
auto params = OgaGeneratorParams::Create(*model);
params->SetSearchOption("max_length", 20);
params->SetSearchOption("batch_size", 1);
params->SetSearchOption("num_beams", 0); // below lower bound of 1

EXPECT_THROW(OgaGenerator::Create(*model, *params), std::runtime_error);
}

TEST(ModelTests, BatchSizeZeroThrows) {
auto model = OgaModel::Create(MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32");
auto params = OgaGeneratorParams::Create(*model);
params->SetSearchOption("max_length", 20);
params->SetSearchOption("batch_size", 0); // below lower bound of 1

EXPECT_THROW(OgaGenerator::Create(*model, *params), std::runtime_error);
}
#endif
// --- Validation tests (no model files required) ---

TEST(ValidationTests, WindowIndexAcceptsValidParams) {
Expand Down Expand Up @@ -520,4 +560,4 @@ TEST(ValidationTests, WindowIndexRejectsExcessiveDimensions) {
TEST(ValidationTests, WindowIndexRejectsTotalSizeOverflow) {
// Each dim individually <= kMaxElements, but grid_t * padded_h * padded_w > kMaxElements
EXPECT_THROW(Generators::ValidateWindowIndexParams(1000, 2000000, 2000000, 2, 14, 112), std::runtime_error);
}
}
Loading