Skip to content

Commit c2bb87f

Browse files
committed
fix: accept logprobs as integer and warn on unsupported params instead of rejecting
- Add custom deserializer for chat completion logprobs field to accept both boolean and integer values (0=false, non-zero=true) - Change unsupported field validation to log warnings instead of returning errors, improving compatibility with clients that send extra parameters like add_special_tokens, prompt_cache_key, request_id, and chat_template https://claude.ai/code/session_01EoHsvxn4B3WMuJkwP5f6od
1 parent aaa8a56 commit c2bb87f

File tree

4 files changed

+102
-23
lines changed

4 files changed

+102
-23
lines changed

lib/async-openai/src/types/chat.rs

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,39 @@ use std::{collections::HashMap, pin::Pin};
1212

1313
use derive_builder::Builder;
1414
use futures::Stream;
15-
use serde::{Deserialize, Serialize};
15+
use serde::{Deserialize, Deserializer, Serialize};
1616
use utoipa::ToSchema;
1717

1818
use url::Url;
1919
use uuid::{Uuid, uuid};
2020

21+
/// Custom deserializer that accepts both boolean and integer values for logprobs.
22+
/// Integer values are coerced to boolean (0 = false, non-zero = true).
23+
fn deserialize_bool_or_int<'de, D>(deserializer: D) -> Result<Option<bool>, D::Error>
24+
where
25+
D: Deserializer<'de>,
26+
{
27+
let value: Option<serde_json::Value> = Option::deserialize(deserializer)?;
28+
match value {
29+
None => Ok(None),
30+
Some(serde_json::Value::Bool(b)) => Ok(Some(b)),
31+
Some(serde_json::Value::Number(n)) => {
32+
if let Some(i) = n.as_i64() {
33+
Ok(Some(i != 0))
34+
} else if let Some(f) = n.as_f64() {
35+
Ok(Some(f != 0.0))
36+
} else {
37+
Err(serde::de::Error::custom(
38+
"logprobs must be a boolean or integer",
39+
))
40+
}
41+
}
42+
Some(_) => Err(serde::de::Error::custom(
43+
"logprobs must be a boolean or integer",
44+
)),
45+
}
46+
}
47+
2148
use crate::error::OpenAIError;
2249

2350
#[derive(ToSchema, Debug, Serialize, Deserialize, Clone, PartialEq)]
@@ -949,7 +976,12 @@ pub struct CreateChatCompletionRequest {
949976
pub logit_bias: Option<HashMap<String, serde_json::Value>>, // default: null
950977

951978
/// Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each output token returned in the `content` of `message`.
952-
#[serde(skip_serializing_if = "Option::is_none")]
979+
/// Also accepts integer values for compatibility (0 = false, non-zero = true).
980+
#[serde(
981+
skip_serializing_if = "Option::is_none",
982+
default,
983+
deserialize_with = "deserialize_bool_or_int"
984+
)]
953985
pub logprobs: Option<bool>,
954986

955987
/// An integer between 0 and 20 specifying the number of most likely tokens to return at each token position, each with an associated log probability. `logprobs` must be set to `true` if this parameter is used.

lib/llm/src/http/service/openai.rs

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2786,17 +2786,9 @@ mod tests {
27862786
assert!(request.unsupported_fields.contains_key("documents"));
27872787
assert!(request.unsupported_fields.contains_key("chat_template"));
27882788

2789+
// Unsupported fields should now produce a warning but not an error
27892790
let result = validate_chat_completion_fields_generic(&request);
2790-
assert!(result.is_err());
2791-
if let Err(error_response) = result {
2792-
assert_eq!(error_response.0, StatusCode::BAD_REQUEST);
2793-
let msg = &error_response.1.message;
2794-
assert!(msg.contains("Unsupported parameter"));
2795-
// Verify all fields appear in the error message
2796-
assert!(msg.contains("add_special_tokens"));
2797-
assert!(msg.contains("documents"));
2798-
assert!(msg.contains("chat_template"));
2799-
}
2791+
assert!(result.is_ok());
28002792
}
28012793

28022794
#[test]
@@ -2819,16 +2811,9 @@ mod tests {
28192811
);
28202812
assert!(request.unsupported_fields.contains_key("response_format"));
28212813

2814+
// Unsupported fields should now produce a warning but not an error
28222815
let result = validate_completion_fields_generic(&request);
2823-
assert!(result.is_err());
2824-
if let Err(error_response) = result {
2825-
assert_eq!(error_response.0, StatusCode::BAD_REQUEST);
2826-
let msg = &error_response.1.message;
2827-
assert!(msg.contains("Unsupported parameter"));
2828-
// Verify both fields appear in error message
2829-
assert!(msg.contains("add_special_tokens"));
2830-
assert!(msg.contains("response_format"));
2831-
}
2816+
assert!(result.is_ok());
28322817
}
28332818

28342819
#[tokio::test]

lib/llm/src/protocols/openai/chat_completions.rs

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,4 +426,62 @@ mod tests {
426426
assert_eq!(output_options.skip_special_tokens, Some(skip_value));
427427
}
428428
}
429+
430+
#[test]
431+
fn test_logprobs_accepts_integer() {
432+
// logprobs: 1 (integer) should be treated as true
433+
let json_str = json!({
434+
"model": "test-model",
435+
"messages": [{"role": "user", "content": "Hello"}],
436+
"logprobs": 1
437+
});
438+
let request: NvCreateChatCompletionRequest =
439+
serde_json::from_value(json_str).expect("Failed to deserialize request with logprobs integer");
440+
assert_eq!(request.inner.logprobs, Some(true));
441+
442+
// logprobs: 0 should be treated as false
443+
let json_str = json!({
444+
"model": "test-model",
445+
"messages": [{"role": "user", "content": "Hello"}],
446+
"logprobs": 0
447+
});
448+
let request: NvCreateChatCompletionRequest =
449+
serde_json::from_value(json_str).expect("Failed to deserialize request with logprobs 0");
450+
assert_eq!(request.inner.logprobs, Some(false));
451+
452+
// logprobs: true should still work
453+
let json_str = json!({
454+
"model": "test-model",
455+
"messages": [{"role": "user", "content": "Hello"}],
456+
"logprobs": true
457+
});
458+
let request: NvCreateChatCompletionRequest =
459+
serde_json::from_value(json_str).expect("Failed to deserialize request with logprobs bool");
460+
assert_eq!(request.inner.logprobs, Some(true));
461+
}
462+
463+
#[test]
464+
fn test_unsupported_fields_warn_not_error() {
465+
use crate::engines::ValidateRequest;
466+
467+
let json_str = json!({
468+
"model": "test-model",
469+
"messages": [{"role": "user", "content": "Hello"}],
470+
"add_special_tokens": true,
471+
"prompt_cache_key": "key123",
472+
"request_id": "req-456",
473+
"chat_template": "custom"
474+
});
475+
let request: NvCreateChatCompletionRequest =
476+
serde_json::from_value(json_str).expect("Failed to deserialize request");
477+
478+
// These fields should be captured as unsupported
479+
assert!(request.unsupported_fields.contains_key("add_special_tokens"));
480+
assert!(request.unsupported_fields.contains_key("prompt_cache_key"));
481+
assert!(request.unsupported_fields.contains_key("request_id"));
482+
assert!(request.unsupported_fields.contains_key("chat_template"));
483+
484+
// Validation should succeed (warn, not error)
485+
assert!(ValidateRequest::validate(&request).is_ok());
486+
}
429487
}

lib/llm/src/protocols/openai/validate.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,11 @@ pub const MAX_REPETITION_PENALTY: f32 = 2.0;
9797
// Shared Fields
9898
//
9999

100-
/// Validates that no unsupported fields are present in the request
100+
/// Validates unsupported fields in the request.
101+
/// Instead of rejecting requests with unsupported fields, this now logs a warning
102+
/// and allows the request to proceed. This improves compatibility with clients
103+
/// that send extra parameters (e.g. add_special_tokens, prompt_cache_key,
104+
/// request_id, chat_template).
101105
pub fn validate_no_unsupported_fields(
102106
unsupported_fields: &std::collections::HashMap<String, serde_json::Value>,
103107
) -> Result<(), anyhow::Error> {
@@ -106,7 +110,7 @@ pub fn validate_no_unsupported_fields(
106110
.keys()
107111
.map(|s| format!("`{}`", s))
108112
.collect();
109-
anyhow::bail!("Unsupported parameter(s): {}", fields.join(", "));
113+
tracing::warn!("Ignoring unsupported parameter(s): {}", fields.join(", "));
110114
}
111115
Ok(())
112116
}

0 commit comments

Comments
 (0)