@@ -16,13 +16,15 @@ use version_compare::Version;
1616use chrono:: { DateTime , Utc , Duration } ;
1717use serde:: { Serialize , Deserialize } ;
1818use chrono:: serde:: ts_seconds;
19+ use std:: os:: unix:: fs:: PermissionsExt ;
1920
2021const VERSION : & str = env ! ( "CARGO_PKG_VERSION" ) ;
2122
2223const DEEPINFRA_HOST_PROD : & str = "https://api.deepinfra.com" ;
2324const DEEPINFRA_HOST_DEV : & str = "https://localhost:7001" ;
2425const LOGIN_PATH : & str = "/github/login" ;
2526const VERSION_CHECK_SEC : i64 = 60 * 60 * 24 * 7 ; // 1 week
27+ const GITHUB_RELEASE_LATEST : & str = "https://github.com/deepinfra/deepctl/releases/latest/download" ;
2628
2729
2830#[ derive( Serialize , Deserialize , PartialEq , Debug ) ]
@@ -45,7 +47,6 @@ struct Cli {
4547 dev : bool ,
4648}
4749
48-
4950#[ derive( Subcommand ) ]
5051enum Commands {
5152 /// Authentication commands for Deep Infra
@@ -73,16 +74,10 @@ enum Commands {
7374 #[ arg( short( 'i' ) , value_parser = infer_args_parser) ]
7475 args : Vec < ( String , String ) > ,
7576 } ,
76-
77- /// test command with subcommands
78- Test {
79- #[ command( subcommand) ]
80- command : Option < TestSubcommands > ,
81- } ,
8277 /// version command
8378 Version {
8479 #[ command( subcommand) ]
85- command : Option < VersionSubcommands > ,
80+ command : VersionSubcommands ,
8681 }
8782}
8883
@@ -240,23 +235,17 @@ fn get_config_path(dev: bool) -> std::io::Result<std::path::PathBuf> {
240235 Ok ( config_path)
241236}
242237
243- fn get_version_path ( ) -> std:: io:: Result < std:: path:: PathBuf > {
244- let home = dirs:: home_dir ( ) . unwrap ( ) ;
245- let path = home. join ( ".deepinfra/" ) ;
246- let config_path = path. join ( "version.yaml" ) ;
247- Ok ( config_path)
238+ fn get_version_path ( ) -> std:: path:: PathBuf {
239+ dirs:: home_dir ( ) . unwrap ( ) . join ( ".deepinfra" ) . join ( "version.yaml" )
248240}
249241
250242fn read_version_data ( ) -> std:: io:: Result < VersionCheck > {
251- let config_path = get_version_path ( ) ? ;
252- let mut file = File :: open ( config_path) ?;
243+ let config_path = get_version_path ( ) ;
244+ let mut file = File :: open ( & config_path) ?;
253245 let mut contents = String :: new ( ) ;
254246 file. read_to_string ( & mut contents) ?;
255247
256- match serde_yaml:: from_str ( & contents) {
257- Ok ( v) => Ok ( v) ,
258- Err ( e) => Err ( io:: Error :: new ( io:: ErrorKind :: Other , e) ) ,
259- }
248+ serde_yaml:: from_str ( & contents) . map_err ( |e| io:: Error :: new ( io:: ErrorKind :: Other , e) )
260249}
261250
262251fn auth_logout ( _dev : bool ) -> std:: io:: Result < ( ) > {
@@ -553,42 +542,44 @@ fn check_version_with_server(dev: bool) -> std::io::Result<VersionCheck> {
553542 Ok ( version_data)
554543}
555544
556- fn version_check ( dev : bool ) -> std:: io:: Result < ( ) > {
557- let version_data = check_version_with_server ( dev) ?;
558- do_version_check ( & version_data, false ) ?;
559- Ok ( ( ) )
560- }
561-
562- fn main_version_check ( dev : bool ) -> std:: io:: Result < ( ) > {
563- let mut version_data: VersionCheck = read_version_data ( ) ?;
564- // println!("min_version: {}", version_data.min);
565- // println!("update_version: {}", version_data.update);
566- // println!("latest_version: {}", version_data.latest);
567- // println!("last_check: {}", version_data.last_check);
568- version_data = if version_data. last_check < Utc :: now ( ) - Duration :: seconds ( VERSION_CHECK_SEC ) {
545+ fn main_version_check ( dev : bool , force : bool ) -> std:: io:: Result < ( ) > {
546+ let crnt_version_data: Option < VersionCheck > = read_version_data ( ) . ok ( ) ;
547+ let version_data = if
548+ crnt_version_data. is_none ( ) ||
549+ force ||
550+ crnt_version_data. as_ref ( ) . unwrap ( ) . last_check < Utc :: now ( ) - Duration :: seconds ( VERSION_CHECK_SEC ) {
569551 println ! ( "checking version with server..." ) ;
570552 check_version_with_server ( dev) ?
571553 } else {
572- version_data
554+ crnt_version_data . unwrap ( )
573555 } ;
574556 do_version_check ( & version_data, true ) ?;
575557 Ok ( ( ) )
576558}
577559
560+ fn prompt_update ( reason : & str , latest : & str ) {
561+ println ! ( "Your version {} is {}. Please update to the latest version {}." ,
562+ VERSION , reason, latest) ;
563+
564+ let mut sudo_str = "sudo " ;
565+ if let Ok ( exe) = std:: env:: current_exe ( ) {
566+ if exe. as_path ( ) . starts_with ( dirs:: home_dir ( ) . unwrap ( ) . as_path ( ) ) {
567+ sudo_str = "" ;
568+ }
569+ }
570+ println ! ( "Update to the latest version using `{}deepctl version update`" , sudo_str) ;
571+ }
572+
578573fn do_version_check ( version_data : & VersionCheck , silent : bool ) -> std:: io:: Result < ( ) > {
579574 let this_version = Version :: from ( VERSION ) . unwrap ( ) ;
580575 let min_version = Version :: from ( & version_data. min ) . unwrap ( ) ;
581576 let update_version = Version :: from ( & version_data. update ) . unwrap ( ) ;
582577
583578 if this_version < min_version {
584- println ! ( "Your version {} is too old. Please update to the latest version {}." ,
585- VERSION , version_data. latest) ;
586- println ! ( "Update to the latest version using `deepctl version update`" ) ;
579+ prompt_update ( "too old" , & version_data. latest ) ;
587580 exit ( 1 ) ;
588581 } else if this_version < update_version {
589- println ! ( "Your version ({}) is outdated. Please update to the latest version {}." ,
590- VERSION , version_data. latest) ;
591- println ! ( "Update to the latest version using `deepctl version update`" ) ;
582+ prompt_update ( "outdated" , & version_data. latest )
592583 } else {
593584 if !silent {
594585 println ! ( "Your version ({}) is up to date." , VERSION ) ;
@@ -598,22 +589,62 @@ fn do_version_check(version_data: &VersionCheck, silent: bool) -> std::io::Resul
598589}
599590
600591fn write_version_data ( version_data : & VersionCheck ) -> std:: io:: Result < ( ) > {
601- let version_path = get_version_path ( ) ? ;
592+ let version_path = get_version_path ( ) ;
602593 fs:: create_dir_all ( & version_path. parent ( ) . unwrap ( ) ) ?;
603594 let mut version_file = File :: create ( & version_path) ?;
604595 let yaml = serde_yaml:: to_string ( & version_data) . unwrap ( ) ;
605596 version_file. write_all ( yaml. as_bytes ( ) ) ?;
606597 Ok ( ( ) )
607598}
608599
600+ fn perform_update ( dev : bool ) -> std:: io:: Result < ( ) > {
601+ let suffix = if cfg ! ( target_os = "macos" ) {
602+ "-macos"
603+ } else {
604+ "-linux"
605+ } ;
606+
607+ let client = get_http_client ( dev) ;
608+ let uri = format ! ( "{}/deepctl{}" , GITHUB_RELEASE_LATEST , suffix) ;
609+ let mut res = client. get ( & uri)
610+ . timeout ( std:: time:: Duration :: from_secs ( 300 ) )
611+ . send ( ) . unwrap_or_else ( |why| panic ! ( "Failed to fetch {}: {}" , uri, why) ) ;
612+ if !res. status ( ) . is_success ( ) {
613+ panic ! ( "Failed to fetch {}" , uri) ;
614+ }
615+
616+ let current_exe = std:: env:: current_exe ( ) . unwrap_or_else ( |why| {
617+ panic ! ( "Can't figure out where deepctl is installed: {}" , why) ;
618+ } ) ;
619+ let mut tmp_exe = current_exe. clone ( ) ;
620+ tmp_exe. set_file_name ( current_exe. file_name ( ) . unwrap ( ) . to_str ( ) . unwrap ( ) . to_owned ( ) + ".tmp" ) ;
621+ {
622+ let mut tmp_exe_f = File :: create ( & tmp_exe)
623+ . unwrap_or_else ( |why| panic ! ( "couldn't open {:?}: {}" , tmp_exe, why) ) ;
624+ res. copy_to ( & mut tmp_exe_f)
625+ . unwrap_or_else ( |why| panic ! ( "Failed to save new version to {:?}: {}" , tmp_exe, why) ) ;
626+ }
627+ fs:: set_permissions ( & tmp_exe, fs:: Permissions :: from_mode ( 0o755 ) ) . unwrap ( ) ;
628+ std:: fs:: rename ( & tmp_exe, & current_exe)
629+ . unwrap_or_else ( |why| panic ! ( "Failed to rename {:?} to {:?}: {}" , tmp_exe, current_exe, why) ) ;
630+ Ok ( ( ) )
631+ }
632+
609633fn main ( ) {
610634 let opts = Cli :: parse ( ) ;
611635
612636 if !matches ! ( opts. command, Commands :: Version { ..} ) {
613- main_version_check ( opts. dev ) . ok ( ) ;
637+ // User didn't ask for a version check|update, we check anyway.
638+ main_version_check ( opts. dev , false ) . unwrap ( ) ;
614639 }
615640
616641 match opts. command {
642+ Commands :: Version { command } => {
643+ match command {
644+ VersionSubcommands :: Check => main_version_check ( opts. dev , true ) . unwrap ( ) ,
645+ VersionSubcommands :: Update => perform_update ( opts. dev ) . unwrap ( ) ,
646+ }
647+ }
617648 Commands :: Auth { command } => {
618649 match command {
619650 AuthCommands :: Login => auth_login ( opts. dev ) . unwrap ( ) ,
@@ -640,21 +671,6 @@ fn main() {
640671 ModelCommands :: Info { model } => model_info ( & model, opts. dev ) . unwrap ( ) ,
641672 }
642673 }
643- Commands :: Test { command } => {
644- match command {
645- Some ( TestSubcommands :: Command1 ) => println ! ( "test command1" ) ,
646- Some ( TestSubcommands :: Command2 ) => println ! ( "test command2" ) ,
647- Some ( TestSubcommands :: Version ) => println ! ( "test version" ) ,
648- None => println ! ( "test" ) ,
649- }
650- }
651- Commands :: Version { command } => {
652- match command {
653- Some ( VersionSubcommands :: Check ) => version_check ( opts. dev ) . unwrap ( ) ,
654- Some ( VersionSubcommands :: Update ) => println ! ( "update" ) ,
655- None => println ! ( "{}" , VERSION ) ,
656- }
657- }
658674 }
659675}
660676
0 commit comments