@@ -4,7 +4,6 @@ use dirs;
44use reqwest:: blocking:: multipart;
55use reqwest:: { self , Method } ;
66use serde:: { Deserialize , Serialize } ;
7- use std:: collections:: HashMap ;
87use std:: env;
98use std:: fs;
109use std:: fs:: File ;
@@ -517,31 +516,6 @@ fn deploy_create(model_name: &str, task: Option<&ModelTask>, dev: bool) -> Resul
517516 Ok ( ( ) )
518517}
519518
520- #[ derive( Debug , PartialEq ) ]
521- enum InferInputType {
522- INTEGER ,
523- NUMBER ,
524- TEXT ,
525- BINARY
526- }
527-
528- impl InferInputType {
529- fn from_str ( inp : & str ) -> Option < InferInputType > {
530- match inp {
531- "integer" => Some ( Self :: INTEGER ) ,
532- "number" => Some ( Self :: NUMBER ) ,
533- "string" => Some ( Self :: TEXT ) ,
534- _ => None ,
535- }
536- }
537- }
538-
539- #[ derive( Debug , PartialEq ) ]
540- enum InferInputLocation {
541- PARAMS ,
542- MULTIPART ,
543- }
544-
545519fn read_binary_file ( name : & str ) -> Result < Vec < u8 > > {
546520 let mut file = File :: open ( name) ?;
547521 let mut buf: Vec < u8 > = Vec :: new ( ) ;
@@ -553,14 +527,8 @@ fn parse_base64(b64: &str) -> Result<Vec<u8>> {
553527 Ok ( base64:: decode ( b64) ?)
554528}
555529
556- fn encode_base64 ( b64_bytes : & Vec < u8 > ) -> String {
557- base64:: encode ( b64_bytes)
558- }
559-
560- type InputMapping = HashMap < String , ( InferInputType , InferInputLocation ) > ;
561- fn infer_body ( mapping : InputMapping , args : & Vec < ( String , String ) > ) -> Result < multipart:: Form > {
530+ fn infer_body ( args : & Vec < ( String , String ) > ) -> Result < multipart:: Form > {
562531 let mut form = multipart:: Form :: new ( ) ;
563- let mut params = serde_json:: Map :: new ( ) ;
564532 for ( key, inp_value) in args {
565533 let raw_value = if inp_value. starts_with ( "@" ) {
566534 read_binary_file ( & inp_value[ 1 ..] ) ?
@@ -570,71 +538,13 @@ fn infer_body(mapping: InputMapping, args: &Vec<(String, String)>) -> Result<mul
570538 inp_value. as_bytes ( ) . to_vec ( )
571539 } ;
572540
573- let ( typ, location) = mapping. get ( key)
574- . ok_or ( DeepCtlError :: BadInput ( format ! ( "unexpected input argument `{}`" , key) ) ) ?;
575- match location {
576- InferInputLocation :: PARAMS => {
577- let parsed_value: serde_json:: Value = match typ {
578- InferInputType :: INTEGER | InferInputType :: NUMBER => serde_json:: from_str ( String :: from_utf8 ( raw_value) ?. trim ( ) ) ?,
579- InferInputType :: TEXT => serde_json:: Value :: String ( String :: from_utf8 ( raw_value) ?) ,
580- InferInputType :: BINARY => serde_json:: Value :: String ( encode_base64 ( & raw_value) ) ,
581- } ;
582- params. insert ( key. into ( ) , parsed_value) ;
583- } ,
584- InferInputLocation :: MULTIPART => {
585- let mut part = multipart:: Part :: bytes ( raw_value) ;
586- // If there is no filename, fastapi returns 422
587- part = part. file_name ( "filename.ext" ) ;
588- form = form. part ( key. to_owned ( ) , part) ;
589- }
590- } ;
591- }
592- if !params. is_empty ( ) {
593- form = form. text ( "input" , serde_json:: to_string ( & params) ?) ;
594- }
595- Ok ( form)
596- }
597-
598- fn get_model_in_schema ( dev : bool , model_name : & str ) -> Result < serde_json:: Value > {
599- let schema_cache = get_di_dir ( ) ?
600- . join ( "schemas" )
601- . join ( format ! ( "{}.in.schema.json" , model_name. replace( "/" , ":" ) ) ) ;
602- if ! schema_cache. exists ( ) {
603- let model_info = get_parsed_response ( & format ! ( "/models/{}" , model_name) , Method :: GET , dev, false ) ?;
604- let in_schema = model_info. get ( "in_schema" )
605- . ok_or ( DeepCtlError :: ApiMismatch ( format ! ( "/models/{} should contain in_schema" , model_name) ) ) ?;
606- // schema_cache is guaranteed to have a parent
607- fs:: create_dir_all ( & schema_cache. parent ( ) . unwrap ( ) ) ?;
608- serde_json:: to_writer_pretty ( File :: create ( & schema_cache) ?, in_schema) ?;
541+ let mut part = multipart:: Part :: bytes ( raw_value) ;
542+ // the filename forces the backend to treat this data as binary
543+ part = part. file_name ( "filename.ext" ) ;
544+ form = form. part ( key. to_owned ( ) , part) ;
609545 }
610546
611- let schema: serde_json:: Value = serde_json:: from_reader ( File :: open ( & schema_cache) ?) ?;
612-
613- Ok ( schema. to_owned ( ) )
614- }
615-
616- fn schema_to_mapping ( schema : serde_json:: Value ) -> Result < InputMapping > {
617- let properties = schema. get ( "properties" )
618- . ok_or ( DeepCtlError :: ApiMismatch ( "in_schema should have properties" . into ( ) ) ) ?
619- . as_object ( )
620- . ok_or ( DeepCtlError :: ApiMismatch ( "in_schema properties should be hash" . into ( ) ) ) ?;
621- let mut res = InputMapping :: new ( ) ;
622- for ( name, props) in properties {
623- let format = props. get ( "format" ) . and_then ( serde_json:: Value :: as_str) ;
624- if format == Some ( "binary" ) {
625- res. insert ( name. to_owned ( ) , ( InferInputType :: BINARY , InferInputLocation :: MULTIPART ) ) ;
626- } else {
627- let raw_inp_type = props. get ( "type" )
628- . ok_or ( DeepCtlError :: ApiMismatch ( "schema property should have type" . into ( ) ) ) ?
629- . as_str ( )
630- . ok_or ( DeepCtlError :: ApiMismatch ( "schema property type should be string" . into ( ) ) ) ?;
631- let inp_type = InferInputType :: from_str ( raw_inp_type)
632- . ok_or ( DeepCtlError :: ApiMismatch ( format ! ( "unhandled schema type {}" , raw_inp_type) ) )
633- . context ( format ! ( "type {}" , raw_inp_type) ) ?;
634- res. insert ( name. to_owned ( ) , ( inp_type, InferInputLocation :: PARAMS ) ) ;
635- }
636- }
637- Ok ( res)
547+ Ok ( form)
638548}
639549
640550fn infer_out_part ( value : & serde_json:: Value , location : & str ) -> Result < ( ) > {
@@ -715,8 +625,7 @@ fn infer_out_part(value: &serde_json::Value, location: &str) -> Result<()> {
715625}
716626
717627fn infer ( model_name : & str , args : & Vec < ( String , String ) > , outs : & Vec < ( String , String ) > , dev : bool ) -> Result < ( ) > {
718- let schema: serde_json:: Value = get_model_in_schema ( dev, model_name) ?;
719- let form = infer_body ( schema_to_mapping ( schema) ?, args) ?;
628+ let form = infer_body ( args) ?;
720629
721630 let json = get_parsed_response_extra (
722631 & format ! ( "/v1/inference/{}" , model_name) ,
0 commit comments