Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions src/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1545,6 +1545,34 @@ void ClearDecoderProviderOptionsHardwareVendorId(Config& config, std::string_vie
}
}

struct ToolCalling_Element : JSON::Element {
explicit ToolCalling_Element(Config::ToolCalling& v) : v_{v} {}

void OnValue(std::string_view name, JSON::Value value) override {
if (name == "tool_call_start_token") {
v_.tool_call_start_token = JSON::Get<std::string_view>(value);
} else if (name == "tool_call_end_token") {
v_.tool_call_end_token = JSON::Get<std::string_view>(value);
}
}

Config::ToolCalling& v_;
};

struct Reasoning_Element : JSON::Element {
explicit Reasoning_Element(Config::Reasoning& v) : v_{v} {}

void OnValue(std::string_view name, JSON::Value value) override {
if (name == "reasoning_start_token") {
v_.reasoning_start_token = JSON::Get<std::string_view>(value);
} else if (name == "reasoning_end_token") {
v_.reasoning_end_token = JSON::Get<std::string_view>(value);
}
}

Config::Reasoning& v_;
};

struct Root_Element : JSON::Element {
explicit Root_Element(Config& config) : config_{config} {}

Expand All @@ -1556,13 +1584,17 @@ struct Root_Element : JSON::Element {
if (name == "model") return model_element_;
if (name == "search") return search_element_;
if (name == "engine") return engine_element_;
if (name == "tool_calling") return tool_calling_element_;
if (name == "reasoning") return reasoning_element_;
throw JSON::unknown_value_error{};
}

Config& config_;
Model_Element model_element_{config_.model};
Search_Element search_element_{config_.search};
Engine_Element engine_element_{config_.engine};
ToolCalling_Element tool_calling_element_{config_.tool_calling};
Reasoning_Element reasoning_element_{config_.reasoning};
};

struct RootObject_Element : JSON::Element {
Expand Down
10 changes: 10 additions & 0 deletions src/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,16 @@ struct Config {
std::optional<StaticBatching> static_batching; // Static batching settings
} engine; // Engine settings

struct ToolCalling {
std::string tool_call_start_token; // e.g., "<tool_call>"
std::string tool_call_end_token; // e.g., "</tool_call>"
} tool_calling;

struct Reasoning {
std::string reasoning_start_token; // e.g., "<think>"
std::string reasoning_end_token; // e.g., "</think>"
} reasoning;

void AddMapping(const std::string& nominal_name, const std::string& graph_name);
// Returns graph name and true if the nominal name is found in the mapping
// otherwise returns the nominal name and false
Expand Down
52 changes: 52 additions & 0 deletions src/models/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
#include <algorithm>
#include <array>
#include <climits>
#include <functional>
#include <random>
#include <set>
#include <string>
#include <string_view>
#include <thread>
#include <unordered_map>

#include "../generators.h"
#include "../search.h"
Expand Down Expand Up @@ -285,6 +287,11 @@ Tokenizer::Tokenizer(Config& config) : bos_token_id_{config.model.bos_token_id},
// Resolve tokenizer_dir (may be empty, relative, absolute, or "package:"-scheme).
const fs::path tokenizer_dir = config.ResolvePath(config.model.tokenizer_dir);
CheckResult(OrtxCreateTokenizerWithOptions(tokenizer_.Address(), tokenizer_dir.string().c_str(), keys, values, 2));

// TODO: Once ORT Extensions supports an "additional_special_tokens" option, pass the generation
// tags (tool_calling/reasoning tokens) here so that models which don't already mark them as
// special in their tokenizer_config.json will still get correct skip_special_tokens behavior.
// This is needed for the FL SDK's dual-stream special token detection in OnnxChatGenerator::Decode().
}

std::unique_ptr<TokenizerStream> Tokenizer::CreateStream() const {
Expand Down Expand Up @@ -817,6 +824,51 @@ bool Model::IsPruned() const {
return logits_shape[1] == 1;
}

namespace {

// Fallback map for models whose genai_config.json doesn't yet have tool_calling/reasoning sections.
// Keyed by model.type string from genai_config.json.
// Inner map: tag_name -> value
const std::string* GetFallbackTag(const std::string& model_type, const std::string& tag_name) {
static const std::unordered_map<std::string, std::unordered_map<std::string, std::string>> fallback_map = {
{"qwen2", {{"tool_call_start", "<tool_call>"}, {"tool_call_end", "</tool_call>"}}},
{"qwen3", {{"tool_call_start", "<tool_call>"}, {"tool_call_end", "</tool_call>"}, {"reasoning_start", "<think>"}, {"reasoning_end", "</think>"}}},
{"phi3", {{"tool_call_start", "<tool_call>"}, {"tool_call_end", "</tool_call>"}}},
{"gptoss", {{"tool_call_start", "<|start|>"}, {"tool_call_end", "<|call|>"}}},
};
auto type_it = fallback_map.find(model_type);
if (type_it == fallback_map.end()) return nullptr;
auto tag_it = type_it->second.find(tag_name);
if (tag_it == type_it->second.end()) return nullptr;
return &tag_it->second;
}

} // namespace

const std::string& Model::GetTag(const std::string& tag_name) const {
// Check config first (tool_calling and reasoning sections)
static const std::unordered_map<std::string, std::function<const std::string&(const Config&)>> config_accessors = {
{"tool_call_start", [](const Config& c) -> const std::string& { return c.tool_calling.tool_call_start_token; }},
{"tool_call_end", [](const Config& c) -> const std::string& { return c.tool_calling.tool_call_end_token; }},
{"reasoning_start", [](const Config& c) -> const std::string& { return c.reasoning.reasoning_start_token; }},
{"reasoning_end", [](const Config& c) -> const std::string& { return c.reasoning.reasoning_end_token; }},
};

auto accessor_it = config_accessors.find(tag_name);
if (accessor_it != config_accessors.end()) {
const std::string& config_val = accessor_it->second(*config_);
if (!config_val.empty())
return config_val;
}

// Fallback to model-type-based map
const auto* fallback = GetFallbackTag(config_->model.type, tag_name);
if (fallback) return *fallback;

static const std::string empty;
return empty;
}

std::shared_ptr<Model> CreateModel(OrtEnv& ort_env, const char* config_path, const RuntimeSettings* settings /*= nullptr*/) {
std::string config_overlay;
if (settings) {
Expand Down
5 changes: 5 additions & 0 deletions src/models/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,11 @@ struct Model : std::enable_shared_from_this<Model>, LeakChecked<Model>, External

bool IsPruned() const;

// Generic tag accessor (reads from config with model-type fallback).
// Known tag names: "tool_call_start", "tool_call_end", "reasoning_start", "reasoning_end".
// Returns the tag value, or an empty string if the model doesn't define the tag.
const std::string& GetTag(const std::string& tag_name) const;

std::unique_ptr<Config> config_;
std::unique_ptr<OrtSessionOptions> session_options_;

Expand Down
6 changes: 6 additions & 0 deletions src/ort_genai.h
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,12 @@ struct OgaModel : OgaAbstract {
return p;
}

OgaString GetTag(const char* tag_name) const {
const char* p;
OgaCheckResult(OgaModelGetTag(this, tag_name, &p));
return p;
}

static void operator delete(void* p) { OgaDestroyModel(reinterpret_cast<OgaModel*>(p)); }
};

Expand Down
7 changes: 7 additions & 0 deletions src/ort_genai_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,13 @@ OgaResult* OGA_API_CALL OgaModelGetDeviceType(const OgaModel* model, const char*
OGA_CATCH
}

OgaResult* OGA_API_CALL OgaModelGetTag(const OgaModel* model, const char* tag_name, const char** out) {
OGA_TRY
*out = AllocOgaString(model->GetTag(tag_name).c_str());
return nullptr;
OGA_CATCH
}

OgaResult* OGA_API_CALL OgaCreateGeneratorParams(const OgaModel* model, OgaGeneratorParams** out) {
OGA_TRY
auto params = std::make_shared<Generators::GeneratorParams>(*model);
Expand Down
13 changes: 13 additions & 0 deletions src/ort_genai_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,19 @@ OGA_EXPORT OgaResult* OGA_API_CALL OgaModelGetType(const OgaModel* model, const
*/
OGA_EXPORT OgaResult* OGA_API_CALL OgaModelGetDeviceType(const OgaModel* model, const char** out);

/**
* \brief Returns a tag value for this model by name.
*
* Known tag names: "tool_call_start", "tool_call_end", "reasoning_start", "reasoning_end".
* Returns an empty string if the model doesn't define the requested tag.
*
Comment on lines +410 to +414
* \param[in] model The model to query.
* \param[in] tag_name The name of the tag to retrieve.
* \param[out] out The tag value string. Must be destroyed with OgaDestroyString.
* \return OgaResult containing the error message if the call failed.
*/
OGA_EXPORT OgaResult* OGA_API_CALL OgaModelGetTag(const OgaModel* model, const char* tag_name, const char** out);

/**
* \brief Destroys the given config
* \param[in] config The config to be destroyed.
Expand Down
78 changes: 78 additions & 0 deletions test/c_api_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1816,3 +1816,81 @@ TEST(CAPITests, ParakeetTdtTranscribeLong) {
auto transcription = RunParakeetTdt(PARAKEET_TDT_AUDIO_TEDLIUM);
EXPECT_FALSE(transcription.empty());
}

// Test tool_calling and reasoning config parsing and fallback map
TEST(CAPITests, Tags_Fallback) {
// tiny-random-gpt2 model has type "gpt2" which is NOT in the fallback map → empty tags
auto model = OgaModel::Create(MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32");

auto tool_start = model->GetTag("tool_call_start");
auto tool_end = model->GetTag("tool_call_end");
auto reasoning_start = model->GetTag("reasoning_start");
auto reasoning_end = model->GetTag("reasoning_end");

EXPECT_STREQ(static_cast<const char*>(tool_start), "");
EXPECT_STREQ(static_cast<const char*>(tool_end), "");
EXPECT_STREQ(static_cast<const char*>(reasoning_start), "");
EXPECT_STREQ(static_cast<const char*>(reasoning_end), "");
}

TEST(CAPITests, Tags_FromConfig) {
// Create a temporary model directory with tool_calling and reasoning sections
auto temp_dir = std::filesystem::temp_directory_path() / "oga_test_tool_tags";
std::filesystem::create_directories(temp_dir);
Comment on lines +1838 to +1839

// Copy minimal model files from tiny-random-gpt2
std::string src_dir = MODEL_PATH "hf-internal-testing/tiny-random-gpt2-fp32";
for (const auto& entry : std::filesystem::directory_iterator(src_dir)) {
if (entry.path().filename() != "genai_config.json") {
std::filesystem::copy_file(entry.path(), temp_dir / entry.path().filename(),
std::filesystem::copy_options::overwrite_existing);
}
}

// Write genai_config.json with tool_calling and reasoning sections
{
std::ofstream f((temp_dir / "genai_config.json").string());
f << R"({
"model": {
"type": "gpt2",
"pad_token_id": 98,
"bos_token_id": 98,
"eos_token_id": 98,
"vocab_size": 1000,
"context_length": 512,
"decoder": {
"session_options": { "provider_options": [] },
"filename": "past.onnx",
"num_key_value_heads": 4,
"head_size": 8,
"num_hidden_layers": 5,
"inputs": { "past_names": "past_%d" },
"outputs": { "present_names": "present_%d" }
}
},
"tool_calling": {
"tool_call_start_token": "<tool_call>",
"tool_call_end_token": "</tool_call>"
},
"reasoning": {
"reasoning_start_token": "<think>",
"reasoning_end_token": "</think>"
}
})";
}

auto model = OgaModel::Create(temp_dir.string().c_str());

auto tool_start = model->GetTag("tool_call_start");
auto tool_end = model->GetTag("tool_call_end");
auto reasoning_start = model->GetTag("reasoning_start");
auto reasoning_end = model->GetTag("reasoning_end");

EXPECT_STREQ(static_cast<const char*>(tool_start), "<tool_call>");
EXPECT_STREQ(static_cast<const char*>(tool_end), "</tool_call>");
EXPECT_STREQ(static_cast<const char*>(reasoning_start), "<think>");
EXPECT_STREQ(static_cast<const char*>(reasoning_end), "</think>");

// Cleanup
std::filesystem::remove_all(temp_dir);
}
Loading