From 8a8a807547fb12b9f7bc8055596455982fea2e5c Mon Sep 17 00:00:00 2001 From: Johannes Heuel Date: Fri, 16 Feb 2024 14:23:20 +0100 Subject: [PATCH] split main.rs into files --- Cargo.lock | 1 + Cargo.toml | 1 + src/commands/join.rs | 31 ++++ src/commands/leave.rs | 17 +++ src/commands/mod.rs | 20 +++ src/commands/pause.rs | 29 ++++ src/commands/play.rs | 101 +++++++++++++ src/commands/queue.rs | 58 +++++++ src/commands/resume.rs | 29 ++++ src/commands/stop.rs | 29 ++++ src/handler.rs | 76 ++++++++++ src/main.rs | 335 ++++++++--------------------------------- src/metadata.rs | 11 ++ src/state.rs | 15 ++ 14 files changed, 477 insertions(+), 276 deletions(-) create mode 100644 src/commands/join.rs create mode 100644 src/commands/leave.rs create mode 100644 src/commands/mod.rs create mode 100644 src/commands/pause.rs create mode 100644 src/commands/play.rs create mode 100644 src/commands/queue.rs create mode 100644 src/commands/resume.rs create mode 100644 src/commands/stop.rs create mode 100644 src/handler.rs create mode 100644 src/metadata.rs create mode 100644 src/state.rs diff --git a/Cargo.lock b/Cargo.lock index d3eb832..c1366de 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -416,6 +416,7 @@ dependencies = [ "twilight-model", "twilight-standby", "twilight-util", + "url", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index ec4c009..b557579 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,3 +22,4 @@ twilight-cache-inmemory = "0.15" twilight-util = { version = "0.15", features=["builder"] } dotenv = "0.15.0" serde_json = "1.0" +url = "2.5.0" diff --git a/src/commands/join.rs b/src/commands/join.rs new file mode 100644 index 0000000..60df9b8 --- /dev/null +++ b/src/commands/join.rs @@ -0,0 +1,31 @@ +use std::{error::Error, num::NonZeroU64}; +use twilight_model::channel::Message; + +use crate::state::State; + +pub(crate) async fn join( + msg: Message, + state: State, +) -> Result<(), Box> { + let user_id = msg.author.id; + let guild_id = msg.guild_id.ok_or("No guild id attached to the message.")?; + let channel_id = state + .cache + .voice_state(user_id, guild_id) + .ok_or("Cannot get voice state for user")? + .channel_id(); + let channel_id = + NonZeroU64::new(channel_id.into()).ok_or("Joined voice channel must have nonzero ID.")?; + state + .songbird + .join(guild_id, channel_id) + .await + .map_err(|e| format!("Could not join voice channel: {:?}", e))?; + + // signal that we are not listening + if let Some(call_lock) = state.songbird.get(guild_id) { + let mut call = call_lock.lock().await; + call.deafen(true).await?; + } + Ok(()) +} diff --git a/src/commands/leave.rs b/src/commands/leave.rs new file mode 100644 index 0000000..f0f6138 --- /dev/null +++ b/src/commands/leave.rs @@ -0,0 +1,17 @@ +use crate::state::State; +use std::error::Error; +use twilight_model::channel::Message; + +pub(crate) async fn leave( + msg: Message, + state: State, +) -> Result<(), Box> { + tracing::debug!( + "leave command in channel {} by {}", + msg.channel_id, + msg.author.name + ); + let guild_id = msg.guild_id.unwrap(); + state.songbird.leave(guild_id).await?; + Ok(()) +} diff --git a/src/commands/mod.rs b/src/commands/mod.rs new file mode 100644 index 0000000..8ba472b --- /dev/null +++ b/src/commands/mod.rs @@ -0,0 +1,20 @@ +mod join; +pub(crate) use join::join; + +mod leave; +pub(crate) use leave::leave; + +mod pause; +pub(crate) use pause::pause; + +mod play; +pub(crate) use play::play; + +mod queue; +pub(crate) use queue::queue; + +mod resume; +pub(crate) use resume::resume; + +mod stop; +pub(crate) use stop::stop; diff --git a/src/commands/pause.rs b/src/commands/pause.rs new file mode 100644 index 0000000..7a14f9f --- /dev/null +++ b/src/commands/pause.rs @@ -0,0 +1,29 @@ +use crate::state::State; +use std::error::Error; +use twilight_model::channel::Message; + +pub(crate) async fn pause( + msg: Message, + state: State, +) -> Result<(), Box> { + tracing::debug!( + "pause command in channel {} by {}", + msg.channel_id, + msg.author.name + ); + + let guild_id = msg.guild_id.unwrap(); + + if let Some(call_lock) = state.songbird.get(guild_id) { + let call = call_lock.lock().await; + call.queue().pause()?; + } + + state + .http + .create_message(msg.channel_id) + .content("Paused the track")? + .await?; + + Ok(()) +} diff --git a/src/commands/play.rs b/src/commands/play.rs new file mode 100644 index 0000000..827a36e --- /dev/null +++ b/src/commands/play.rs @@ -0,0 +1,101 @@ +use crate::commands::join; +use crate::metadata::{Metadata, MetadataMap}; +use crate::state::State; +use serde_json::Value; +use songbird::input::{Compose, YoutubeDl}; +use std::io::{BufRead, BufReader}; +use std::{error::Error, ops::Sub, time::Duration}; +use tokio::process::Command; +use tracing::debug; +use twilight_model::channel::Message; +use url::Url; + +pub(crate) async fn play( + msg: Message, + state: State, + query: String, +) -> Result<(), Box> { + tracing::debug!( + "play command in channel {} by {}", + msg.channel_id, + msg.author.name + ); + + join(msg.clone(), state.clone()).await?; + + let guild_id = msg.guild_id.unwrap(); + + // handle keyword queries + let query = if Url::parse(&query).is_err() { + format!("ytsearch:{query}") + } else { + query + }; + + // handle playlist links + let urls = if query.contains("list=") { + get_playlist_urls(query).await? + } else { + vec![query] + }; + + for url in urls { + let mut src = YoutubeDl::new(reqwest::Client::new(), url.to_string()); + if let Ok(metadata) = src.aux_metadata().await { + debug!("metadata: {:?}", metadata); + + if let Some(call_lock) = state.songbird.get(guild_id) { + let mut call = call_lock.lock().await; + let handle = call.enqueue_with_preload( + src.into(), + metadata.duration.map(|duration| -> Duration { + if duration.as_secs() > 5 { + duration.sub(Duration::from_secs(5)) + } else { + duration + } + }), + ); + let mut x = handle.typemap().write().await; + x.insert::(Metadata { + title: metadata.title, + duration: metadata.duration, + }); + } + } else { + state + .http + .create_message(msg.channel_id) + .content("Cannot find any results")? + .await?; + } + } + + Ok(()) +} + +async fn get_playlist_urls( + url: String, +) -> Result, Box> { + let output = Command::new("yt-dlp") + .args(vec![&url, "--flat-playlist", "-j"]) + .output() + .await?; + + let reader = BufReader::new(output.stdout.as_slice()); + let urls = reader + .lines() + .flatten() + .map(|line| { + let entry: Value = serde_json::from_str(&line).unwrap(); + entry + .get("webpage_url") + .unwrap() + .as_str() + .unwrap() + .to_string() + }) + .collect(); + + Ok(urls) +} diff --git a/src/commands/queue.rs b/src/commands/queue.rs new file mode 100644 index 0000000..2dda61c --- /dev/null +++ b/src/commands/queue.rs @@ -0,0 +1,58 @@ +use crate::{metadata::MetadataMap, state::State}; +use std::error::Error; +use twilight_model::channel::Message; + +pub(crate) async fn queue( + msg: Message, + state: State, +) -> Result<(), Box> { + tracing::debug!( + "queue command in channel {} by {}", + msg.channel_id, + msg.author.name + ); + let guild_id = msg.guild_id.unwrap(); + + if let Some(call_lock) = state.songbird.get(guild_id) { + let call = call_lock.lock().await; + let queue = call.queue().current_queue(); + let mut message = String::new(); + if queue.is_empty() { + message.push_str("There are no tracks in the queue.\n"); + } else { + message.push_str("Currently playing:\n"); + } + for track in queue { + let map = track.typemap().read().await; + let metadata = map.get::().unwrap(); + message.push_str( + format!( + "* `{}", + metadata.title.clone().unwrap_or("Unknown".to_string()), + ) + .as_str(), + ); + if let Some(duration) = metadata.duration { + let res = duration.as_secs(); + let hours = res / (60 * 60); + let res = res - hours * 60 * 60; + let minutes = res / 60; + let res = res - minutes * 60; + let seconds = res; + message.push_str(" ("); + if hours > 0 { + message.push_str(format!("{:02}:", hours).as_str()); + } + message.push_str(format!("{:02}:{:02}", minutes, seconds).as_str()); + message.push(')'); + } + message.push_str("`\n"); + } + state + .http + .create_message(msg.channel_id) + .content(&message)? + .await?; + } + Ok(()) +} diff --git a/src/commands/resume.rs b/src/commands/resume.rs new file mode 100644 index 0000000..f907754 --- /dev/null +++ b/src/commands/resume.rs @@ -0,0 +1,29 @@ +use crate::state::State; +use std::error::Error; +use twilight_model::channel::Message; + +pub(crate) async fn resume( + msg: Message, + state: State, +) -> Result<(), Box> { + tracing::debug!( + "resume command in channel {} by {}", + msg.channel_id, + msg.author.name + ); + + let guild_id = msg.guild_id.unwrap(); + + if let Some(call_lock) = state.songbird.get(guild_id) { + let call = call_lock.lock().await; + call.queue().resume()?; + } + + state + .http + .create_message(msg.channel_id) + .content("Resumed the track")? + .await?; + + Ok(()) +} diff --git a/src/commands/stop.rs b/src/commands/stop.rs new file mode 100644 index 0000000..f41a7d3 --- /dev/null +++ b/src/commands/stop.rs @@ -0,0 +1,29 @@ +use crate::state::State; +use std::error::Error; +use twilight_model::channel::Message; + +pub(crate) async fn stop( + msg: Message, + state: State, +) -> Result<(), Box> { + tracing::debug!( + "stop command in channel {} by {}", + msg.channel_id, + msg.author.name + ); + + let guild_id = msg.guild_id.unwrap(); + + if let Some(call_lock) = state.songbird.get(guild_id) { + let call = call_lock.lock().await; + call.queue().stop(); + } + + state + .http + .create_message(msg.channel_id) + .content("Stopped the track and cleared the queue")? + .await?; + + Ok(()) +} diff --git a/src/handler.rs b/src/handler.rs new file mode 100644 index 0000000..7922b91 --- /dev/null +++ b/src/handler.rs @@ -0,0 +1,76 @@ +use crate::commands::{join, leave, pause, play, queue, resume, stop}; +use crate::state::State; + +use futures::Future; +use std::error::Error; +use std::sync::Arc; + +use twilight_gateway::Event; +use twilight_model::channel::Message; + +enum ChatCommand { + Play(Message, String), + Stop(Message), + Pause(Message), + Resume(Message), + Leave(Message), + Join(Message), + Queue(Message), + NotImplemented, +} + +fn parse_command(event: Event) -> Option { + match event { + Event::MessageCreate(msg_create) => { + if msg_create.guild_id.is_none() || !msg_create.content.starts_with('!') { + return None; + } + let split: Vec<&str> = msg_create.content.splitn(2, ' ').collect(); + match split.as_slice() { + ["!play", query] => { + Some(ChatCommand::Play(msg_create.0.clone(), query.to_string())) + } + ["!stop"] | ["!stop", _] => Some(ChatCommand::Stop(msg_create.0)), + ["!pause"] | ["!pause", _] => Some(ChatCommand::Pause(msg_create.0)), + ["!resume"] | ["!resume", _] => Some(ChatCommand::Resume(msg_create.0)), + ["!leave"] | ["!leave", _] => Some(ChatCommand::Leave(msg_create.0)), + ["!join"] | ["!join", _] => Some(ChatCommand::Join(msg_create.0)), + ["!queue"] | ["!queue", _] => Some(ChatCommand::Queue(msg_create.0)), + _ => Some(ChatCommand::NotImplemented), + } + } + _ => None, + } +} + +fn spawn( + fut: impl Future>> + Send + 'static, +) { + tokio::spawn(async move { + if let Err(why) = fut.await { + tracing::debug!("handler error: {:?}", why); + } + }); +} + +pub(crate) struct Handler { + state: State, +} + +impl Handler { + pub(crate) fn new(state: State) -> Self { + Self { state } + } + pub(crate) async fn act(&mut self, event: Event) { + match parse_command(event) { + Some(ChatCommand::Play(msg, query)) => spawn(play(msg, Arc::clone(&self.state), query)), + Some(ChatCommand::Stop(msg)) => spawn(stop(msg, Arc::clone(&self.state))), + Some(ChatCommand::Pause(msg)) => spawn(pause(msg, Arc::clone(&self.state))), + Some(ChatCommand::Resume(msg)) => spawn(resume(msg, Arc::clone(&self.state))), + Some(ChatCommand::Leave(msg)) => spawn(leave(msg, Arc::clone(&self.state))), + Some(ChatCommand::Join(msg)) => spawn(join(msg, Arc::clone(&self.state))), + Some(ChatCommand::Queue(msg)) => spawn(queue(msg, Arc::clone(&self.state))), + _ => {} + } + } +} diff --git a/src/main.rs b/src/main.rs index 28b6ca5..ff07403 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,95 +1,30 @@ +mod handler; +use handler::Handler; +mod commands; +mod metadata; +mod state; use dotenv::dotenv; use futures::StreamExt; -use serde_json::Value; -use songbird::{ - input::{Compose, YoutubeDl}, - shards::TwilightMap, - typemap::TypeMapKey, - Songbird, +use songbird::{shards::TwilightMap, Songbird}; +use state::StateRef; +use std::{env, error::Error, sync::Arc}; +use tokio::{ + select, + signal::unix::{signal, SignalKind}, + sync::watch, }; -use std::{ - env, - error::Error, - future::Future, - io::{BufRead, BufReader}, - num::NonZeroU64, - ops::Sub, - sync::Arc, - time::Duration, -}; -use tokio::process::Command; -use tracing::debug; +use tracing::{debug, info}; use twilight_cache_inmemory::InMemoryCache; use twilight_gateway::{ stream::{self, ShardEventStream}, - Event, Intents, Shard, + Intents, Shard, }; use twilight_http::Client as HttpClient; use twilight_model::application::command::CommandType; -use twilight_model::{channel::Message, id::Id}; +use twilight_model::id::Id; use twilight_standby::Standby; use twilight_util::builder::command::{CommandBuilder, StringBuilder}; -type State = Arc; - -#[derive(Debug)] -struct StateRef { - http: HttpClient, - cache: InMemoryCache, - songbird: Songbird, - standby: Standby, -} - -struct Metadata { - title: Option, - artist: Option, -} -struct MetadataMap; -impl TypeMapKey for MetadataMap { - type Value = Metadata; -} - -enum ChatCommand { - Play(Message, String), - Stop(Message), - Leave(Message), - Join(Message), - Queue(Message), - NotImplemented, -} - -fn parse_command(event: Event) -> Option { - match event { - Event::MessageCreate(msg_create) => { - if msg_create.guild_id.is_none() || !msg_create.content.starts_with('!') { - return None; - } - let split: Vec<&str> = msg_create.content.splitn(2, ' ').collect(); - match split.as_slice() { - ["!play", query] => { - Some(ChatCommand::Play(msg_create.0.clone(), query.to_string())) - } - ["!stop"] | ["!stop", _] => Some(ChatCommand::Stop(msg_create.0)), - ["!leave"] | ["!leave", _] => Some(ChatCommand::Leave(msg_create.0)), - ["!join"] | ["!join", _] => Some(ChatCommand::Join(msg_create.0)), - ["!queue"] | ["!queue", _] => Some(ChatCommand::Queue(msg_create.0)), - _ => Some(ChatCommand::NotImplemented), - } - } - _ => None, - } -} - -fn spawn( - fut: impl Future>> + Send + 'static, -) { - tokio::spawn(async move { - if let Err(why) = fut.await { - tracing::debug!("handler error: {:?}", why); - } - }); -} - #[tokio::main] async fn main() -> Result<(), Box> { dotenv().ok(); @@ -97,6 +32,20 @@ async fn main() -> Result<(), Box> { // Initialize the tracing subscriber. tracing_subscriber::fmt::init(); + let (stop_tx, mut stop_rx) = watch::channel(()); + + tokio::spawn(async move { + let mut sigterm = signal(SignalKind::terminate()).unwrap(); + let mut sigint = signal(SignalKind::interrupt()).unwrap(); + loop { + select! { + _ = sigterm.recv() => println!("Receive SIGTERM"), + _ = sigint.recv() => println!("Receive SIGTERM"), + }; + stop_tx.send(()).unwrap(); + } + }); + let (mut shards, state) = { let token = env::var("DISCORD_TOKEN")?; let app_id = env::var("DISCORD_APP_ID")?.parse()?; @@ -148,209 +97,43 @@ async fn main() -> Result<(), Box> { ) }; + let mut handler = Handler::new(Arc::clone(&state)); let mut stream = ShardEventStream::new(shards.iter_mut()); loop { - let event = match stream.next().await { - Some((_, Ok(event))) => event, - Some((_, Err(source))) => { - tracing::warn!(?source, "error receiving event"); - - if source.is_fatal() { - break; + select! { + biased; + _ = stop_rx.changed() => { + for guild in state.cache.iter().guilds(){ + info!("Leaving guild {:?}", guild.id()); + state.songbird.leave(guild.id()).await?; } + // need to grab next event to properly leave voice channels + stream.next().await; + break; + }, + next = stream.next() => { + let event = match next { + Some((_, Ok(event))) => event, + Some((_, Err(source))) => { + tracing::warn!(?source, "error receiving event"); - continue; - } - None => break, - }; - debug!("Event: {:?}", &event); - - state.cache.update(&event); - state.standby.process(&event); - state.songbird.process(&event).await; - - match parse_command(event) { - Some(ChatCommand::Play(msg, query)) => spawn(play(msg, Arc::clone(&state), query)), - Some(ChatCommand::Stop(msg)) => spawn(stop(msg, Arc::clone(&state))), - Some(ChatCommand::Leave(msg)) => spawn(leave(msg, Arc::clone(&state))), - Some(ChatCommand::Join(msg)) => spawn(join(msg, Arc::clone(&state))), - Some(ChatCommand::Queue(msg)) => spawn(queue(msg, Arc::clone(&state))), - _ => {} - } - } - - Ok(()) -} - -async fn join(msg: Message, state: State) -> Result<(), Box> { - let user_id = msg.author.id; - let guild_id = msg.guild_id.ok_or("No guild id attached to the message.")?; - let channel_id = state - .cache - .voice_state(user_id, guild_id) - .ok_or("Cannot get voice state for user")? - .channel_id(); - let channel_id = - NonZeroU64::new(channel_id.into()).ok_or("Joined voice channel must have nonzero ID.")?; - state - .songbird - .join(guild_id, channel_id) - .await - .map_err(|e| format!("Could not join voice channel: {:?}", e))?; - - // signal that we are not listening - if let Some(call_lock) = state.songbird.get(guild_id) { - let mut call = call_lock.lock().await; - call.deafen(true).await?; - } - Ok(()) -} - -async fn leave(msg: Message, state: State) -> Result<(), Box> { - tracing::debug!( - "leave command in channel {} by {}", - msg.channel_id, - msg.author.name - ); - let guild_id = msg.guild_id.unwrap(); - state.songbird.leave(guild_id).await?; - Ok(()) -} - -async fn get_playlist_urls( - url: String, -) -> Result, Box> { - let output = Command::new("yt-dlp") - .args(vec![&url, "--flat-playlist", "-j"]) - .output() - .await?; - - let reader = BufReader::new(output.stdout.as_slice()); - let urls = reader - .lines() - .flatten() - .map(|line| { - let entry: Value = serde_json::from_str(&line).unwrap(); - entry - .get("webpage_url") - .unwrap() - .as_str() - .unwrap() - .to_string() - }) - .collect(); - - Ok(urls) -} - -async fn queue(msg: Message, state: State) -> Result<(), Box> { - tracing::debug!( - "queue command in channel {} by {}", - msg.channel_id, - msg.author.name - ); - let guild_id = msg.guild_id.unwrap(); - - if let Some(call_lock) = state.songbird.get(guild_id) { - let call = call_lock.lock().await; - let queue = call.queue().current_queue(); - let mut message = String::new(); - message.push_str("Currently playing:\n"); - for track in queue { - let map = track.typemap().read().await; - let metadata = map.get::().unwrap(); - message.push_str( - format!( - "* {}\n", - metadata.title.clone().unwrap_or("Unknown".to_string()), - ) - .as_str(), - ); - } - state - .http - .create_message(msg.channel_id) - .content(&message)? - .await?; - } - Ok(()) -} - -async fn play( - msg: Message, - state: State, - query: String, -) -> Result<(), Box> { - tracing::debug!( - "play command in channel {} by {}", - msg.channel_id, - msg.author.name - ); - - join(msg.clone(), state.clone()).await?; - - let guild_id = msg.guild_id.unwrap(); - - let urls = if query.contains("list=") { - get_playlist_urls(query).await? - } else { - vec![query] - }; - - for url in urls { - let mut src = YoutubeDl::new(reqwest::Client::new(), url.to_string()); - if let Ok(metadata) = src.aux_metadata().await { - debug!("metadata: {:?}", metadata); - - if let Some(call_lock) = state.songbird.get(guild_id) { - let mut call = call_lock.lock().await; - let handle = call.enqueue_with_preload( - src.into(), - metadata.duration.map(|duration| -> Duration { - if duration.as_secs() > 5 { - duration.sub(Duration::from_secs(5)) - } else { - duration + if source.is_fatal() { + break; } - }), - ); - let mut x = handle.typemap().write().await; - x.insert::(Metadata { - title: metadata.title, - artist: metadata.artist, - }); + + continue; + } + None => break, + }; + debug!("Event: {:?}", &event); + + state.cache.update(&event); + state.standby.process(&event); + state.songbird.process(&event).await; + + handler.act(event).await; } - } else { - state - .http - .create_message(msg.channel_id) - .content("Cannot find any results")? - .await?; } } - - Ok(()) -} - -async fn stop(msg: Message, state: State) -> Result<(), Box> { - tracing::debug!( - "stop command in channel {} by {}", - msg.channel_id, - msg.author.name - ); - - let guild_id = msg.guild_id.unwrap(); - - if let Some(call_lock) = state.songbird.get(guild_id) { - let mut call = call_lock.lock().await; - call.stop(); - } - - state - .http - .create_message(msg.channel_id) - .content("Stopped the track")? - .await?; - Ok(()) } diff --git a/src/metadata.rs b/src/metadata.rs new file mode 100644 index 0000000..90f304e --- /dev/null +++ b/src/metadata.rs @@ -0,0 +1,11 @@ +use songbird::typemap::TypeMapKey; +use std::time::Duration; + +pub(crate) struct Metadata { + pub(crate) title: Option, + pub(crate) duration: Option, +} +pub(crate) struct MetadataMap; +impl TypeMapKey for MetadataMap { + type Value = Metadata; +} diff --git a/src/state.rs b/src/state.rs new file mode 100644 index 0000000..192a08c --- /dev/null +++ b/src/state.rs @@ -0,0 +1,15 @@ +use songbird::Songbird; +use std::sync::Arc; +use twilight_cache_inmemory::InMemoryCache; +use twilight_http::Client as HttpClient; +use twilight_standby::Standby; + +pub(crate) type State = Arc; + +#[derive(Debug)] +pub(crate) struct StateRef { + pub(crate) http: HttpClient, + pub(crate) cache: InMemoryCache, + pub(crate) songbird: Songbird, + pub(crate) standby: Standby, +}