Skip to content

Commit 4c028b0

Browse files
committed
Implement auto-update
TODO: - Test actual executable (built in github) - test on macos - cleanup error handling
1 parent 96132e1 commit 4c028b0

File tree

1 file changed

+72
-56
lines changed

1 file changed

+72
-56
lines changed

src/main.rs

Lines changed: 72 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,15 @@ use version_compare::Version;
1616
use chrono::{DateTime, Utc, Duration};
1717
use serde::{Serialize, Deserialize};
1818
use chrono::serde::ts_seconds;
19+
use std::os::unix::fs::PermissionsExt;
1920

2021
const VERSION: &str = env!("CARGO_PKG_VERSION");
2122

2223
const DEEPINFRA_HOST_PROD: &str = "https://api.deepinfra.com";
2324
const DEEPINFRA_HOST_DEV: &str = "https://localhost:7001";
2425
const LOGIN_PATH: &str = "/github/login";
2526
const VERSION_CHECK_SEC: i64 = 60 * 60 * 24 * 7; // 1 week
27+
const GITHUB_RELEASE_LATEST: &str = "https://github.com/deepinfra/deepctl/releases/latest/download";
2628

2729

2830
#[derive(Serialize, Deserialize, PartialEq, Debug)]
@@ -45,7 +47,6 @@ struct Cli {
4547
dev: bool,
4648
}
4749

48-
4950
#[derive(Subcommand)]
5051
enum Commands {
5152
/// Authentication commands for Deep Infra
@@ -73,16 +74,10 @@ enum Commands {
7374
#[arg(short('i'), value_parser = infer_args_parser)]
7475
args: Vec<(String, String)>,
7576
},
76-
77-
/// test command with subcommands
78-
Test {
79-
#[command(subcommand)]
80-
command: Option<TestSubcommands>,
81-
},
8277
/// version command
8378
Version {
8479
#[command(subcommand)]
85-
command: Option<VersionSubcommands>,
80+
command: VersionSubcommands,
8681
}
8782
}
8883

@@ -240,23 +235,17 @@ fn get_config_path(dev: bool) -> std::io::Result<std::path::PathBuf> {
240235
Ok(config_path)
241236
}
242237

243-
fn get_version_path() -> std::io::Result<std::path::PathBuf> {
244-
let home = dirs::home_dir().unwrap();
245-
let path = home.join(".deepinfra/");
246-
let config_path = path.join("version.yaml");
247-
Ok(config_path)
238+
fn get_version_path() -> std::path::PathBuf {
239+
dirs::home_dir().unwrap().join(".deepinfra").join("version.yaml")
248240
}
249241

250242
fn read_version_data() -> std::io::Result<VersionCheck> {
251-
let config_path = get_version_path()?;
252-
let mut file = File::open(config_path)?;
243+
let config_path = get_version_path();
244+
let mut file = File::open(&config_path)?;
253245
let mut contents = String::new();
254246
file.read_to_string(&mut contents)?;
255247

256-
match serde_yaml::from_str(&contents) {
257-
Ok(v) => Ok(v),
258-
Err(e) => Err(io::Error::new(io::ErrorKind::Other, e)),
259-
}
248+
serde_yaml::from_str(&contents).map_err(|e| io::Error::new(io::ErrorKind::Other, e))
260249
}
261250

262251
fn auth_logout(_dev: bool) -> std::io::Result<()> {
@@ -553,42 +542,44 @@ fn check_version_with_server(dev: bool) -> std::io::Result<VersionCheck> {
553542
Ok(version_data)
554543
}
555544

556-
fn version_check(dev: bool) -> std::io::Result<()> {
557-
let version_data = check_version_with_server(dev)?;
558-
do_version_check(&version_data, false)?;
559-
Ok(())
560-
}
561-
562-
fn main_version_check(dev: bool) -> std::io::Result<()> {
563-
let mut version_data: VersionCheck = read_version_data()?;
564-
// println!("min_version: {}", version_data.min);
565-
// println!("update_version: {}", version_data.update);
566-
// println!("latest_version: {}", version_data.latest);
567-
// println!("last_check: {}", version_data.last_check);
568-
version_data = if version_data.last_check < Utc::now() - Duration::seconds(VERSION_CHECK_SEC) {
545+
fn main_version_check(dev: bool, force: bool) -> std::io::Result<()> {
546+
let crnt_version_data: Option<VersionCheck> = read_version_data().ok();
547+
let version_data = if
548+
crnt_version_data.is_none() ||
549+
force ||
550+
crnt_version_data.as_ref().unwrap().last_check < Utc::now() - Duration::seconds(VERSION_CHECK_SEC) {
569551
println!("checking version with server...");
570552
check_version_with_server(dev)?
571553
} else {
572-
version_data
554+
crnt_version_data.unwrap()
573555
};
574556
do_version_check(&version_data, true)?;
575557
Ok(())
576558
}
577559

560+
fn prompt_update(reason: &str, latest: &str) {
561+
println!("Your version {} is {}. Please update to the latest version {}.",
562+
VERSION, reason, latest);
563+
564+
let mut sudo_str = "sudo ";
565+
if let Ok(exe) = std::env::current_exe() {
566+
if exe.as_path().starts_with(dirs::home_dir().unwrap().as_path()) {
567+
sudo_str = "";
568+
}
569+
}
570+
println!("Update to the latest version using `{}deepctl version update`", sudo_str);
571+
}
572+
578573
fn do_version_check(version_data: &VersionCheck, silent: bool) -> std::io::Result<()> {
579574
let this_version = Version::from(VERSION).unwrap();
580575
let min_version = Version::from(&version_data.min).unwrap();
581576
let update_version = Version::from(&version_data.update).unwrap();
582577

583578
if this_version < min_version {
584-
println!("Your version {} is too old. Please update to the latest version {}.",
585-
VERSION, version_data.latest);
586-
println!("Update to the latest version using `deepctl version update`");
579+
prompt_update("too old", &version_data.latest);
587580
exit(1);
588581
} else if this_version < update_version {
589-
println!("Your version ({}) is outdated. Please update to the latest version {}.",
590-
VERSION, version_data.latest);
591-
println!("Update to the latest version using `deepctl version update`");
582+
prompt_update("outdated", &version_data.latest)
592583
} else {
593584
if !silent {
594585
println!("Your version ({}) is up to date.", VERSION);
@@ -598,22 +589,62 @@ fn do_version_check(version_data: &VersionCheck, silent: bool) -> std::io::Resul
598589
}
599590

600591
fn write_version_data(version_data: &VersionCheck) -> std::io::Result<()> {
601-
let version_path = get_version_path()?;
592+
let version_path = get_version_path();
602593
fs::create_dir_all(&version_path.parent().unwrap())?;
603594
let mut version_file = File::create(&version_path)?;
604595
let yaml = serde_yaml::to_string(&version_data).unwrap();
605596
version_file.write_all(yaml.as_bytes())?;
606597
Ok(())
607598
}
608599

600+
fn perform_update(dev: bool) -> std::io::Result<()> {
601+
let suffix = if cfg!(target_os = "macos") {
602+
"-macos"
603+
} else {
604+
"-linux"
605+
};
606+
607+
let client = get_http_client(dev);
608+
let uri = format!("{}/deepctl{}", GITHUB_RELEASE_LATEST, suffix);
609+
let mut res = client.get(&uri)
610+
.timeout(std::time::Duration::from_secs(300))
611+
.send().unwrap_or_else(|why| panic!("Failed to fetch {}: {}", uri, why));
612+
if !res.status().is_success() {
613+
panic!("Failed to fetch {}", uri);
614+
}
615+
616+
let current_exe = std::env::current_exe().unwrap_or_else(|why| {
617+
panic!("Can't figure out where deepctl is installed: {}", why);
618+
});
619+
let mut tmp_exe = current_exe.clone();
620+
tmp_exe.set_file_name(current_exe.file_name().unwrap().to_str().unwrap().to_owned() + ".tmp");
621+
{
622+
let mut tmp_exe_f = File::create(&tmp_exe)
623+
.unwrap_or_else(|why| panic!("couldn't open {:?}: {}", tmp_exe, why));
624+
res.copy_to(&mut tmp_exe_f)
625+
.unwrap_or_else(|why| panic!("Failed to save new version to {:?}: {}", tmp_exe, why));
626+
}
627+
fs::set_permissions(&tmp_exe, fs::Permissions::from_mode(0o755)).unwrap();
628+
std::fs::rename(&tmp_exe, &current_exe)
629+
.unwrap_or_else(|why| panic!("Failed to rename {:?} to {:?}: {}", tmp_exe, current_exe, why));
630+
Ok(())
631+
}
632+
609633
fn main() {
610634
let opts = Cli::parse();
611635

612636
if !matches!(opts.command, Commands::Version{..}) {
613-
main_version_check(opts.dev).ok();
637+
// User didn't ask for a version check|update, we check anyway.
638+
main_version_check(opts.dev, false).unwrap();
614639
}
615640

616641
match opts.command {
642+
Commands::Version { command } => {
643+
match command {
644+
VersionSubcommands::Check => main_version_check(opts.dev, true).unwrap(),
645+
VersionSubcommands::Update => perform_update(opts.dev).unwrap(),
646+
}
647+
}
617648
Commands::Auth { command } => {
618649
match command {
619650
AuthCommands::Login => auth_login(opts.dev).unwrap(),
@@ -640,21 +671,6 @@ fn main() {
640671
ModelCommands::Info { model } => model_info(&model, opts.dev).unwrap(),
641672
}
642673
}
643-
Commands::Test { command } => {
644-
match command {
645-
Some(TestSubcommands::Command1) => println!("test command1"),
646-
Some(TestSubcommands::Command2) => println!("test command2"),
647-
Some(TestSubcommands::Version) => println!("test version"),
648-
None => println!("test"),
649-
}
650-
}
651-
Commands::Version { command } => {
652-
match command {
653-
Some(VersionSubcommands::Check) => version_check(opts.dev).unwrap(),
654-
Some(VersionSubcommands::Update) => println!("update"),
655-
None => println!("{}", VERSION),
656-
}
657-
}
658674
}
659675
}
660676

0 commit comments

Comments
 (0)