Skip to content

Commit 293fcfd

Browse files
committed
Use multipart inference API and avoid schema parsing
1 parent da73e05 commit 293fcfd

File tree

1 file changed

+7
-98
lines changed

1 file changed

+7
-98
lines changed

src/main.rs

Lines changed: 7 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ use dirs;
44
use reqwest::blocking::multipart;
55
use reqwest::{self, Method};
66
use serde::{Deserialize, Serialize};
7-
use std::collections::HashMap;
87
use std::env;
98
use std::fs;
109
use std::fs::File;
@@ -517,31 +516,6 @@ fn deploy_create(model_name: &str, task: Option<&ModelTask>, dev: bool) -> Resul
517516
Ok(())
518517
}
519518

520-
#[derive(Debug, PartialEq)]
521-
enum InferInputType {
522-
INTEGER,
523-
NUMBER,
524-
TEXT,
525-
BINARY
526-
}
527-
528-
impl InferInputType {
529-
fn from_str(inp: &str) -> Option<InferInputType> {
530-
match inp {
531-
"integer" => Some(Self::INTEGER),
532-
"number" => Some(Self::NUMBER),
533-
"string" => Some(Self::TEXT),
534-
_ => None,
535-
}
536-
}
537-
}
538-
539-
#[derive(Debug, PartialEq)]
540-
enum InferInputLocation {
541-
PARAMS,
542-
MULTIPART,
543-
}
544-
545519
fn read_binary_file(name: &str) -> Result<Vec<u8>> {
546520
let mut file = File::open(name)?;
547521
let mut buf: Vec<u8> = Vec::new();
@@ -553,14 +527,8 @@ fn parse_base64(b64: &str) -> Result<Vec<u8>> {
553527
Ok(base64::decode(b64)?)
554528
}
555529

556-
fn encode_base64(b64_bytes: &Vec<u8>) -> String {
557-
base64::encode(b64_bytes)
558-
}
559-
560-
type InputMapping = HashMap<String, (InferInputType, InferInputLocation)>;
561-
fn infer_body(mapping: InputMapping, args: &Vec<(String, String)>) -> Result<multipart::Form> {
530+
fn infer_body(args: &Vec<(String, String)>) -> Result<multipart::Form> {
562531
let mut form = multipart::Form::new();
563-
let mut params = serde_json::Map::new();
564532
for (key, inp_value) in args {
565533
let raw_value = if inp_value.starts_with("@") {
566534
read_binary_file(&inp_value[1..])?
@@ -570,71 +538,13 @@ fn infer_body(mapping: InputMapping, args: &Vec<(String, String)>) -> Result<mul
570538
inp_value.as_bytes().to_vec()
571539
};
572540

573-
let (typ, location) = mapping.get(key)
574-
.ok_or(DeepCtlError::BadInput(format!("unexpected input argument `{}`", key)))?;
575-
match location {
576-
InferInputLocation::PARAMS => {
577-
let parsed_value: serde_json::Value = match typ {
578-
InferInputType::INTEGER | InferInputType::NUMBER => serde_json::from_str(String::from_utf8(raw_value)?.trim())?,
579-
InferInputType::TEXT => serde_json::Value::String(String::from_utf8(raw_value)?),
580-
InferInputType::BINARY => serde_json::Value::String(encode_base64(&raw_value)),
581-
};
582-
params.insert(key.into(), parsed_value);
583-
},
584-
InferInputLocation::MULTIPART => {
585-
let mut part = multipart::Part::bytes(raw_value);
586-
// If there is no filename, fastapi returns 422
587-
part = part.file_name("filename.ext");
588-
form = form.part(key.to_owned(), part);
589-
}
590-
};
591-
}
592-
if !params.is_empty() {
593-
form = form.text("input", serde_json::to_string(&params)?);
594-
}
595-
Ok(form)
596-
}
597-
598-
fn get_model_in_schema(dev: bool, model_name: &str) -> Result<serde_json::Value> {
599-
let schema_cache = get_di_dir()?
600-
.join("schemas")
601-
.join(format!("{}.in.schema.json", model_name.replace("/", ":")));
602-
if ! schema_cache.exists() {
603-
let model_info = get_parsed_response(&format!("/models/{}", model_name), Method::GET, dev, false)?;
604-
let in_schema = model_info.get("in_schema")
605-
.ok_or(DeepCtlError::ApiMismatch(format!("/models/{} should contain in_schema", model_name)))?;
606-
// schema_cache is guaranteed to have a parent
607-
fs::create_dir_all(&schema_cache.parent().unwrap())?;
608-
serde_json::to_writer_pretty(File::create(&schema_cache)?, in_schema)?;
541+
let mut part = multipart::Part::bytes(raw_value);
542+
// the filename forces the backend to treat this data as binary
543+
part = part.file_name("filename.ext");
544+
form = form.part(key.to_owned(), part);
609545
}
610546

611-
let schema: serde_json::Value = serde_json::from_reader(File::open(&schema_cache)?)?;
612-
613-
Ok(schema.to_owned())
614-
}
615-
616-
fn schema_to_mapping(schema: serde_json::Value) -> Result<InputMapping> {
617-
let properties = schema.get("properties")
618-
.ok_or(DeepCtlError::ApiMismatch("in_schema should have properties".into()))?
619-
.as_object()
620-
.ok_or(DeepCtlError::ApiMismatch("in_schema properties should be hash".into()))?;
621-
let mut res = InputMapping::new();
622-
for (name, props) in properties {
623-
let format = props.get("format").and_then(serde_json::Value::as_str);
624-
if format == Some("binary") {
625-
res.insert(name.to_owned(), (InferInputType::BINARY, InferInputLocation::MULTIPART));
626-
} else {
627-
let raw_inp_type = props.get("type")
628-
.ok_or(DeepCtlError::ApiMismatch("schema property should have type".into()))?
629-
.as_str()
630-
.ok_or(DeepCtlError::ApiMismatch("schema property type should be string".into()))?;
631-
let inp_type = InferInputType::from_str(raw_inp_type)
632-
.ok_or(DeepCtlError::ApiMismatch(format!("unhandled schema type {}", raw_inp_type)))
633-
.context(format!("type {}", raw_inp_type))?;
634-
res.insert(name.to_owned(), (inp_type, InferInputLocation::PARAMS));
635-
}
636-
}
637-
Ok(res)
547+
Ok(form)
638548
}
639549

640550
fn infer_out_part(value: &serde_json::Value, location: &str) -> Result<()> {
@@ -715,8 +625,7 @@ fn infer_out_part(value: &serde_json::Value, location: &str) -> Result<()> {
715625
}
716626

717627
fn infer(model_name: &str, args: &Vec<(String, String)>, outs: &Vec<(String, String)>, dev: bool) -> Result<()> {
718-
let schema: serde_json::Value = get_model_in_schema(dev, model_name)?;
719-
let form = infer_body(schema_to_mapping(schema)?, args)?;
628+
let form = infer_body(args)?;
720629

721630
let json = get_parsed_response_extra(
722631
&format!("/v1/inference/{}", model_name),

0 commit comments

Comments
 (0)