add thread requirement for jobs
All checks were successful
continuous-integration/drone/push Build is passing

This commit is contained in:
Johannes Heuel
2022-10-20 11:38:24 +02:00
parent 71098b44e6
commit ea30f57ea9
3 changed files with 59 additions and 12 deletions

View File

@@ -1,4 +1,4 @@
use clap::{App, Arg};
use clap::{arg, value_parser, App, Arg};
use env_logger::Env;
use futures::future::{AbortHandle, Abortable};
use log;
@@ -37,10 +37,11 @@ struct Worker {
id: String,
secret: String,
server: String,
threads: i32,
}
impl Worker {
async fn new(server: &str, secret: &str) -> Result<Worker, Box<dyn Error>> {
async fn new(server: &str, secret: &str, threads: i32) -> Result<Worker, Box<dyn Error>> {
let res = build_client(secret)
.get(format!("{}/register", server))
.send()
@@ -53,6 +54,7 @@ impl Worker {
id: r.id,
secret: secret.to_string(),
server: server.to_string(),
threads,
})
}
@@ -83,6 +85,7 @@ impl Worker {
.post(format!("{}/fetch", self.server))
.json(&FetchRequest {
worker_id: self.id.clone(),
threads: self.threads,
})
.send()
.await?;
@@ -136,15 +139,26 @@ async fn main() -> Result<(), Box<dyn Error>> {
.required(true)
.help("Set Zoidberg server address"),
)
.arg(
arg!(-j --threads <VALUE> "Sets number of threads")
.required(false)
.value_parser(value_parser!(i32)),
)
.get_matches();
let server = matches.value_of("server").unwrap();
let threads: i32 = if let Some(t) = matches.get_one::<i32>("threads") {
*t
} else {
1
};
let secret = std::env::var("ZOIDBERG_SECRET").unwrap_or_else(|_| {
eprintln!("Please set the $ZOIDBERG_SECRET environment variable");
std::process::exit(1);
});
let client = Arc::new(
Worker::new(server, &secret)
Worker::new(server, &secret, threads)
.await
.expect("Could not create client"),
);

View File

@@ -46,6 +46,8 @@ pub struct Job {
pub cmd: String,
#[serde(default = "Status::default")]
pub status: Status,
#[serde(default)]
pub threads: i32,
}
#[derive(Serialize, Deserialize)]
@@ -61,6 +63,8 @@ pub struct RegisterResponse {
#[derive(Serialize, Deserialize)]
pub struct FetchRequest {
pub worker_id: String,
#[serde(default)]
pub threads: i32,
}
#[derive(Serialize, Deserialize)]

View File

@@ -64,7 +64,9 @@ async fn fetch(
f: web::Json<FetchRequest>,
_: Authorization,
) -> Result<impl Responder> {
let requesting_worker = f.into_inner().worker_id;
let f = f.into_inner();
let requesting_worker = f.worker_id;
let threads = f.threads;
{
let workers = data.workers.lock().unwrap();
if workers.iter().filter(|w| w.id == requesting_worker).count() != 1 {
@@ -74,15 +76,24 @@ async fn fetch(
}
}
let mut new_jobs = data.new_jobs.lock().unwrap();
if let Some(j) = new_jobs.pop() {
new_jobs.sort_by(|a, b| b.threads.cmp(&a.threads));
if let Some(j) = new_jobs
.iter()
.filter(|x| x.threads <= threads)
.cloned()
.collect::<Vec<Job>>()
.first()
{
let mut jobs = data.jobs.lock().unwrap();
for cj in jobs.iter_mut() {
if cj.id == j.id {
cj.status = Status::Running(requesting_worker.clone())
}
}
return Ok(web::Json(FetchResponse::Jobs(vec![j])));
}
new_jobs.retain(|x| x.id != j.id);
return Ok(web::Json(FetchResponse::Jobs(vec![j.clone()])));
};
Ok(web::Json(FetchResponse::Nop))
}
@@ -277,11 +288,26 @@ mod tests {
id: "some_worker".to_string(),
last_heartbeat: None,
}]),
new_jobs: Mutex::new(vec![Job {
id: jobid,
cmd: cmd.clone(),
status: Status::Submitted,
}]),
new_jobs: Mutex::new(vec![
Job {
id: jobid,
cmd: cmd.clone(),
status: Status::Submitted,
threads: 1,
},
Job {
id: jobid + 1,
cmd: cmd.clone(),
status: Status::Submitted,
threads: 2,
},
Job {
id: jobid + 2,
cmd: cmd.clone(),
status: Status::Submitted,
threads: 3,
},
]),
jobs: Mutex::new(Vec::new()),
}))
.service(fetch),
@@ -291,6 +317,7 @@ mod tests {
.append_header(("cookie", "secret"))
.set_json(FetchRequest {
worker_id: "some_worker".to_string(),
threads: 1,
})
.uri("/fetch")
.to_request();
@@ -324,6 +351,7 @@ mod tests {
id: jobid,
cmd: cmd.clone(),
status: Status::Submitted,
threads: 1,
}]),
}))
.service(status),
@@ -375,6 +403,7 @@ mod tests {
id: 0,
cmd: String::from("hi"),
status: Status::Submitted,
threads: 1,
}])
.uri("/submit")
.to_request();