diff --git a/.gitea/workflows/ci.yml b/.gitea/workflows/ci.yml index 17fe1c2..1061651 100644 --- a/.gitea/workflows/ci.yml +++ b/.gitea/workflows/ci.yml @@ -2,7 +2,9 @@ name: tests on: push: - pull_request: + +paths-ignore: +- 'README.md' env: CARGO_TERM_COLOR: always @@ -36,6 +38,11 @@ jobs: - uses: actions/checkout@v4 - run: sudo apt-get update - run: sudo apt-get install -y cmake + - uses: actions/setup-python@v5 + with: + python-version: '3.12' + - name: Install yt-dlp + run: pip install yt-dlp - uses: actions-rust-lang/setup-rust-toolchain@v1 - name: Run tests run: cargo test --verbose --all-features diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 19a7ef8..524a185 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -37,6 +37,11 @@ jobs: - uses: actions/checkout@v4 - run: sudo apt-get update - run: sudo apt-get install -y cmake + - uses: actions/setup-python@v5 + with: + python-version: '3.12' + - name: Install yt-dlp + run: pip install yt-dlp - uses: actions-rust-lang/setup-rust-toolchain@v1 - name: Run tests run: cargo test --verbose --all-features diff --git a/Cargo.lock b/Cargo.lock index 2265db8..0bc944d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1064,9 +1064,9 @@ dependencies = [ [[package]] name = "memchr" -version = "2.7.2" +version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c8640c5d730cb13ebd907d8d04b52f55ac9a2eec55b440c8892f40d56c76c1d" +checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" [[package]] name = "mime" @@ -1209,6 +1209,8 @@ name = "ohrwurm" version = "0.1.0" dependencies = [ "anyhow", + "async-trait", + "dashmap", "dotenv", "futures", "regex", @@ -1227,6 +1229,7 @@ dependencies = [ "twilight-standby", "twilight-util", "url", + "uuid", ] [[package]] @@ -1488,9 +1491,9 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.5.1" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "469052894dcb553421e483e4209ee581a45100d31b4018de03e5a7ad86374a7e" +checksum = "c82cf8cff14456045f55ec4241383baeff27af886adb72ffb2162f99911de0fd" dependencies = [ "bitflags 2.5.0", ] diff --git a/Cargo.toml b/Cargo.toml index b31ce0a..475c80f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,3 +25,6 @@ serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" url = "2.5.0" anyhow = "1.0.86" +dashmap = "5.5.3" +async-trait = "0.1.80" +uuid = "1.8.0" diff --git a/src/commands/loop_queue.rs b/src/commands/loop_queue.rs new file mode 100644 index 0000000..50f8090 --- /dev/null +++ b/src/commands/loop_queue.rs @@ -0,0 +1,120 @@ +use crate::state::{State, StateRef}; +use async_trait::async_trait; +use songbird::input::Compose; +use songbird::{Event, EventContext, EventHandler, TrackEvent}; +use std::ops::Sub; +use std::time::Duration; +use std::{error::Error, sync::Arc}; +use twilight_model::{ + gateway::payload::incoming::InteractionCreate, + http::interaction::{InteractionResponse, InteractionResponseType}, + id::{marker::GuildMarker, Id}, +}; +use twilight_util::builder::InteractionResponseDataBuilder; + +pub(crate) async fn loop_queue( + interaction: Box, + state: State, +) -> Result<(), Box> { + tracing::debug!( + "loop command in guild {:?} in channel {:?} by {:?}", + interaction.guild_id, + interaction.channel, + interaction.author(), + ); + + let guild_id: Id = if let Some(guild_id) = interaction.guild_id { + guild_id + } else { + return Ok(()); + }; + + let looping = if let Some(mut settings) = state.guild_settings.get_mut(&guild_id) { + settings.loop_queue = !settings.loop_queue; + settings.loop_queue + } else { + false + }; + + if let Some(call_lock) = state.songbird.get(guild_id) { + let mut call = call_lock.lock().await; + call.add_global_event( + Event::Track(TrackEvent::End), + TrackEndNotifier { + guild_id, + state: Arc::clone(&state), + }, + ); + } + + let mut message = "I'm not looping anymore!".to_string(); + if looping { + message = "I'm now looping the current queue!".to_string(); + } + + let interaction_response_data = InteractionResponseDataBuilder::new() + .content(message) + .build(); + + let response = InteractionResponse { + kind: InteractionResponseType::ChannelMessageWithSource, + data: Some(interaction_response_data), + }; + + state + .http + .interaction(interaction.application_id) + .create_response(interaction.id, &interaction.token, &response) + .await?; + + Ok(()) +} + +struct TrackEndNotifier { + guild_id: Id, + state: Arc, +} + +#[async_trait] +impl EventHandler for TrackEndNotifier { + async fn act(&self, ctx: &EventContext<'_>) -> Option { + if !self + .state + .guild_settings + .get(&self.guild_id) + .unwrap() + .loop_queue + { + return None; + } + let EventContext::Track(track_list) = ctx else { + return None; + }; + let (_, track_handle) = track_list.first()?; + if let Some(yt) = self + .state + .tracks + .get(&self.guild_id) + .unwrap() + .get(&track_handle.uuid()) + { + let mut src = yt.clone(); + if let Ok(metadata) = src.aux_metadata().await { + if let Some(call_lock) = self.state.songbird.get(self.guild_id) { + let mut call = call_lock.lock().await; + call.enqueue_with_preload( + src.into(), + metadata.duration.map(|duration| -> Duration { + if duration.as_secs() > 5 { + duration.sub(Duration::from_secs(5)) + } else { + duration + } + }), + ); + } + } + } + None + } +} diff --git a/src/commands/mod.rs b/src/commands/mod.rs index 7c4d557..21c6fcb 100644 --- a/src/commands/mod.rs +++ b/src/commands/mod.rs @@ -7,6 +7,12 @@ pub(crate) use leave::leave; mod pause; pub(crate) use pause::pause; +mod skip; +pub(crate) use skip::skip; + +mod loop_queue; +pub(crate) use loop_queue::loop_queue; + mod play; pub(crate) use play::play; diff --git a/src/commands/play.rs b/src/commands/play.rs index 7f4f4ae..2b986ec 100644 --- a/src/commands/play.rs +++ b/src/commands/play.rs @@ -5,8 +5,10 @@ use crate::state::State; use serde::{Deserialize, Serialize}; use songbird::input::{Compose, YoutubeDl}; +use songbird::tracks::Track; use std::io::{BufRead, BufReader}; -use std::{error::Error, ops::Sub, time::Duration}; +use std::ops::Sub; +use std::{error::Error, time::Duration}; use tokio::process::Command; use tracing::debug; use twilight_model::channel::message::MessageFlags; @@ -18,11 +20,12 @@ use url::Url; #[derive(Debug, Serialize, Deserialize)] struct YouTubeTrack { - url: String, + url: Option, + original_url: Option, title: String, channel: String, - playlist: String, - playlist_id: String, + playlist: Option, + playlist_id: Option, duration_string: String, } @@ -102,8 +105,16 @@ pub(crate) async fn play( let first_track = tracks.first().unwrap(); let content = format!( "Adding playlist [{}]({})", - first_track.playlist, - build_playlist_url(&first_track.playlist_id) + first_track + .playlist + .clone() + .unwrap_or("Unknown".to_string()), + build_playlist_url( + &first_track + .playlist_id + .clone() + .unwrap_or("Unknown".to_string()) + ) ); let embeds = vec![EmbedBuilder::new() .description(content) @@ -125,8 +136,16 @@ pub(crate) async fn play( let mut tracks_added = vec![]; for track in &tracks { tracing::debug!("track: {:?}", track); - let url = track.url.clone(); - let mut src = YoutubeDl::new(reqwest::Client::new(), url.to_string()); + let url = track.url.clone().or(track.original_url.clone()).ok_or("")?; + let mut src = YoutubeDl::new(reqwest::Client::new(), url.clone()); + let s = src.clone(); + let track: Track = src.clone().into(); + state + .tracks + .entry(guild_id) + .or_default() + .insert(track.uuid, s); + if let Ok(metadata) = src.aux_metadata().await { debug!("metadata: {:?}", metadata); tracks_added.push((url.clone(), metadata.title.clone())); @@ -134,7 +153,7 @@ pub(crate) async fn play( if let Some(call_lock) = state.songbird.get(guild_id) { let mut call = call_lock.lock().await; let handle = call.enqueue_with_preload( - src.into(), + track, metadata.duration.map(|duration| -> Duration { if duration.as_secs() > 5 { duration.sub(Duration::from_secs(5)) @@ -169,8 +188,16 @@ pub(crate) async fn play( let first_track = tracks.first().unwrap(); content.push_str(&format!( "Adding playlist: [{}]({})\n", - &first_track.playlist, - build_playlist_url(&first_track.playlist_id) + &first_track + .playlist + .clone() + .unwrap_or("Unknown".to_string()), + build_playlist_url( + &first_track + .playlist_id + .clone() + .unwrap_or("Unknown".to_string()) + ) )); content.push_str(&format!( "Added {} tracks to the queue:\n", @@ -192,3 +219,23 @@ pub(crate) async fn play( Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_get_tracks() { + let urls = [ + "https://www.youtube.com/playlist?list=PLFxxhcEeloYa1OlnWD6UgxlVQKJH5i_0p", + "https://music.youtube.com/watch?v=RO75ZzqUOJw", + "https://www.youtube.com/watch?v=qVHyl0P_P-M", + "https://www.youtube.com/watch?v=34CZjsEI1yU", + ]; + for url in urls.iter() { + println!("url: {:?}", url); + let tracks = get_tracks(url.to_string()).await.unwrap(); + assert!(!tracks.is_empty()); + } + } +} diff --git a/src/commands/skip.rs b/src/commands/skip.rs new file mode 100644 index 0000000..dac7c57 --- /dev/null +++ b/src/commands/skip.rs @@ -0,0 +1,45 @@ +use crate::state::State; +use std::error::Error; +use twilight_model::{ + gateway::payload::incoming::InteractionCreate, + http::interaction::{InteractionResponse, InteractionResponseType}, +}; +use twilight_util::builder::InteractionResponseDataBuilder; + +pub(crate) async fn skip( + interaction: Box, + state: State, +) -> Result<(), Box> { + tracing::debug!( + "skip command in guild {:?} in channel {:?} by {:?}", + interaction.guild_id, + interaction.channel, + interaction.author(), + ); + + let Some(guild_id) = interaction.guild_id else { + return Ok(()); + }; + + if let Some(call_lock) = state.songbird.get(guild_id) { + let call = call_lock.lock().await; + call.queue().skip()?; + } + + let interaction_response_data = InteractionResponseDataBuilder::new() + .content("Skipped the next track") + .build(); + + let response = InteractionResponse { + kind: InteractionResponseType::ChannelMessageWithSource, + data: Some(interaction_response_data), + }; + + state + .http + .interaction(interaction.application_id) + .create_response(interaction.id, &interaction.token, &response) + .await?; + + Ok(()) +} diff --git a/src/handler.rs b/src/handler.rs index 662acfc..09a9085 100644 --- a/src/handler.rs +++ b/src/handler.rs @@ -1,5 +1,5 @@ use crate::commands::queue::{build_action_row, build_queue_embeds, TRACKS_PER_PAGE}; -use crate::commands::{delete, join, leave, pause, play, queue, resume, stop}; +use crate::commands::{delete, join, leave, loop_queue, pause, play, queue, resume, skip, stop}; use crate::state::State; use futures::Future; use std::error::Error; @@ -19,6 +19,8 @@ enum InteractionCommand { Play(String), Stop, Pause, + Skip, + Loop, Resume, Leave, Join, @@ -108,6 +110,12 @@ impl Handler { InteractionCommand::Pause => { spawn(pause(interaction, Arc::clone(&self.state))) } + InteractionCommand::Skip => { + spawn(skip(interaction, Arc::clone(&self.state))) + } + InteractionCommand::Loop => { + spawn(loop_queue(interaction, Arc::clone(&self.state))) + } InteractionCommand::Resume => { spawn(resume(interaction, Arc::clone(&self.state))) } @@ -191,6 +199,8 @@ fn parse_interaction_command(command: &CommandData) -> InteractionCommand { } "stop" => InteractionCommand::Stop, "pause" => InteractionCommand::Pause, + "skip" => InteractionCommand::Skip, + "loop" => InteractionCommand::Loop, "resume" => InteractionCommand::Resume, "leave" => InteractionCommand::Leave, "join" => InteractionCommand::Join, diff --git a/src/main.rs b/src/main.rs index 7bf0cd7..cb2aeb0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -77,6 +77,8 @@ async fn main() -> Result<(), Box> { cache, songbird, standby: Standby::new(), + guild_settings: Default::default(), + tracks: Default::default(), }), ) }; diff --git a/src/state.rs b/src/state.rs index 192a08c..7c059dc 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,15 +1,25 @@ -use songbird::Songbird; +use dashmap::DashMap; +use songbird::{input::YoutubeDl, Songbird}; use std::sync::Arc; use twilight_cache_inmemory::InMemoryCache; use twilight_http::Client as HttpClient; +use twilight_model::id::{marker::GuildMarker, Id}; use twilight_standby::Standby; +use uuid::Uuid; pub(crate) type State = Arc; +#[derive(Debug)] +pub(crate) struct Settings { + pub(crate) loop_queue: bool, +} + #[derive(Debug)] pub(crate) struct StateRef { pub(crate) http: HttpClient, pub(crate) cache: InMemoryCache, pub(crate) songbird: Songbird, pub(crate) standby: Standby, + pub(crate) guild_settings: DashMap, Settings>, + pub(crate) tracks: DashMap, DashMap>, }