diff --git a/zoidberg_client/src/main.rs b/zoidberg_client/src/main.rs index f814be8..1fe650a 100644 --- a/zoidberg_client/src/main.rs +++ b/zoidberg_client/src/main.rs @@ -1,4 +1,4 @@ -use clap::{App, Arg}; +use clap::{arg, value_parser, App, Arg}; use env_logger::Env; use futures::future::{AbortHandle, Abortable}; use log; @@ -37,10 +37,11 @@ struct Worker { id: String, secret: String, server: String, + threads: i32, } impl Worker { - async fn new(server: &str, secret: &str) -> Result> { + async fn new(server: &str, secret: &str, threads: i32) -> Result> { let res = build_client(secret) .get(format!("{}/register", server)) .send() @@ -53,6 +54,7 @@ impl Worker { id: r.id, secret: secret.to_string(), server: server.to_string(), + threads, }) } @@ -83,6 +85,7 @@ impl Worker { .post(format!("{}/fetch", self.server)) .json(&FetchRequest { worker_id: self.id.clone(), + threads: self.threads, }) .send() .await?; @@ -136,15 +139,26 @@ async fn main() -> Result<(), Box> { .required(true) .help("Set Zoidberg server address"), ) + .arg( + arg!(-j --threads "Sets number of threads") + .required(false) + .value_parser(value_parser!(i32)), + ) .get_matches(); let server = matches.value_of("server").unwrap(); + let threads: i32 = if let Some(t) = matches.get_one::("threads") { + *t + } else { + 1 + }; + let secret = std::env::var("ZOIDBERG_SECRET").unwrap_or_else(|_| { eprintln!("Please set the $ZOIDBERG_SECRET environment variable"); std::process::exit(1); }); let client = Arc::new( - Worker::new(server, &secret) + Worker::new(server, &secret, threads) .await .expect("Could not create client"), ); diff --git a/zoidberg_lib/src/types.rs b/zoidberg_lib/src/types.rs index 8fcb0db..7fc0620 100644 --- a/zoidberg_lib/src/types.rs +++ b/zoidberg_lib/src/types.rs @@ -46,6 +46,8 @@ pub struct Job { pub cmd: String, #[serde(default = "Status::default")] pub status: Status, + #[serde(default)] + pub threads: i32, } #[derive(Serialize, Deserialize)] @@ -61,6 +63,8 @@ pub struct RegisterResponse { #[derive(Serialize, Deserialize)] pub struct FetchRequest { pub worker_id: String, + #[serde(default)] + pub threads: i32, } #[derive(Serialize, Deserialize)] diff --git a/zoidberg_server/src/main.rs b/zoidberg_server/src/main.rs index 1c83231..89d53dd 100644 --- a/zoidberg_server/src/main.rs +++ b/zoidberg_server/src/main.rs @@ -64,7 +64,9 @@ async fn fetch( f: web::Json, _: Authorization, ) -> Result { - let requesting_worker = f.into_inner().worker_id; + let f = f.into_inner(); + let requesting_worker = f.worker_id; + let threads = f.threads; { let workers = data.workers.lock().unwrap(); if workers.iter().filter(|w| w.id == requesting_worker).count() != 1 { @@ -74,15 +76,24 @@ async fn fetch( } } let mut new_jobs = data.new_jobs.lock().unwrap(); - if let Some(j) = new_jobs.pop() { + new_jobs.sort_by(|a, b| b.threads.cmp(&a.threads)); + + if let Some(j) = new_jobs + .iter() + .filter(|x| x.threads <= threads) + .cloned() + .collect::>() + .first() + { let mut jobs = data.jobs.lock().unwrap(); for cj in jobs.iter_mut() { if cj.id == j.id { cj.status = Status::Running(requesting_worker.clone()) } } - return Ok(web::Json(FetchResponse::Jobs(vec![j]))); - } + new_jobs.retain(|x| x.id != j.id); + return Ok(web::Json(FetchResponse::Jobs(vec![j.clone()]))); + }; Ok(web::Json(FetchResponse::Nop)) } @@ -277,11 +288,26 @@ mod tests { id: "some_worker".to_string(), last_heartbeat: None, }]), - new_jobs: Mutex::new(vec![Job { - id: jobid, - cmd: cmd.clone(), - status: Status::Submitted, - }]), + new_jobs: Mutex::new(vec![ + Job { + id: jobid, + cmd: cmd.clone(), + status: Status::Submitted, + threads: 1, + }, + Job { + id: jobid + 1, + cmd: cmd.clone(), + status: Status::Submitted, + threads: 2, + }, + Job { + id: jobid + 2, + cmd: cmd.clone(), + status: Status::Submitted, + threads: 3, + }, + ]), jobs: Mutex::new(Vec::new()), })) .service(fetch), @@ -291,6 +317,7 @@ mod tests { .append_header(("cookie", "secret")) .set_json(FetchRequest { worker_id: "some_worker".to_string(), + threads: 1, }) .uri("/fetch") .to_request(); @@ -324,6 +351,7 @@ mod tests { id: jobid, cmd: cmd.clone(), status: Status::Submitted, + threads: 1, }]), })) .service(status), @@ -375,6 +403,7 @@ mod tests { id: 0, cmd: String::from("hi"), status: Status::Submitted, + threads: 1, }]) .uri("/submit") .to_request();