diff --git a/migrations/20240620151940_init.sql b/migrations/20240620151940_init.sql index 900212a..17cd4cd 100644 --- a/migrations/20240620151940_init.sql +++ b/migrations/20240620151940_init.sql @@ -1,4 +1,3 @@ --- Add migration script here CREATE TABLE IF NOT EXISTS tracks ( id INTEGER PRIMARY KEY AUTOINCREMENT, diff --git a/src/commands/play.rs b/src/commands/play.rs index 633de8c..a8a22b4 100644 --- a/src/commands/play.rs +++ b/src/commands/play.rs @@ -3,25 +3,36 @@ use crate::metadata::{Metadata, MetadataMap}; use crate::state::State; use crate::{colors, db}; +use anyhow::Context; use serde::{Deserialize, Serialize}; -use songbird::input::cached::Memory; use songbird::input::{Compose, YoutubeDl}; use songbird::tracks::Track; -use std::io::{BufRead, BufReader}; -use std::ops::Sub; +use std::sync::Arc; use std::{error::Error, time::Duration}; +use std::{ + io::{BufRead, BufReader}, + ops::Sub, +}; use tokio::process::Command; use tracing::debug; use twilight_model::channel::message::embed::{ EmbedAuthor, EmbedField, EmbedFooter, EmbedThumbnail, }; -use twilight_model::channel::message::MessageFlags; +use twilight_model::channel::message::{Embed, MessageFlags}; use twilight_model::gateway::payload::incoming::InteractionCreate; use twilight_model::http::interaction::{InteractionResponse, InteractionResponseType}; use twilight_util::builder::embed::EmbedBuilder; use twilight_util::builder::InteractionResponseDataBuilder; use url::Url; +struct TrackType { + url: String, + title: Option, + duration_string: String, + channel: String, + thumbnail: Option, +} + #[derive(Debug, Serialize, Deserialize)] struct YouTubeTrack { url: Option, @@ -31,6 +42,7 @@ struct YouTubeTrack { playlist: Option, playlist_id: Option, duration_string: String, + thumbnail: Option, } fn build_playlist_url(playlist_id: &str) -> String { @@ -67,6 +79,161 @@ async fn get_tracks( Ok(tracks) } +async fn persistence( + interaction: &InteractionCreate, + track: &YouTubeTrack, + state: State, +) -> Result<(), Box> { + let Some(guild_id) = interaction.guild_id else { + return Ok(()); + }; + let Some(user_id) = interaction.author_id() else { + return Ok(()); + }; + let url = track + .original_url + .clone() + .or(track.url.clone()) + .ok_or("Could not find url")?; + let (author_name, author_global_name) = if let Some(author) = interaction.author() { + (author.name.clone(), author.global_name.clone()) + } else { + ("".to_string(), None) + }; + + db::track::insert_guild(&state.pool, db::track::Guild::new(guild_id.to_string())) + .await + .expect("failed to insert guild: {e}"); + + db::track::insert_user( + &state.pool, + db::track::User::new(user_id.to_string(), author_name, author_global_name), + ) + .await + .expect("failed to insert user: {e}"); + + let track_id = db::track::insert_track( + &state.pool, + db::track::Track::new( + url.clone(), + track.title.clone(), + track.channel.clone(), + track.duration_string.clone(), + track.thumbnail.clone().unwrap_or_default(), + ), + ) + .await + .context("failed to insert track: {e}")?; + db::track::insert_query( + &state.pool, + db::track::Query::new(user_id.to_string(), guild_id.to_string(), track_id), + ) + .await + .context("failed to insert track: {e}")?; + Ok(()) +} + +fn build_single_track_added_embeds(tracks_added: &[TrackType]) -> Vec { + let track = tracks_added.first().unwrap(); + + let host = if let Ok(host) = Url::parse(&track.url) { + Some( + host.host_str() + .unwrap_or_default() + .trim_start_matches("www.") + .to_string(), + ) + } else { + None + }; + + let footer = match host { + Some(host) => EmbedFooter { + text: format!("Streaming from {}", host), + icon_url: Some(format!( + "https://www.google.com/s2/favicons?domain={}", + host + )), + proxy_icon_url: None, + }, + None => EmbedFooter { + text: String::new(), + icon_url: None, + proxy_icon_url: None, + }, + }; + + let mut embed = EmbedBuilder::new() + .author(EmbedAuthor { + name: "🔊 Added to queue".to_string(), + icon_url: None, + proxy_icon_url: None, + url: None, + }) + .title(track.title.clone().unwrap_or("Unknown".to_string())) + .url(track.url.clone()) + .color(colors::BLURPLE) + .footer(footer) + .field(EmbedField { + inline: true, + name: "Duration".to_string(), + value: track.duration_string.clone(), + }) + .field(EmbedField { + inline: true, + name: "Channel".to_string(), + value: track.channel.clone(), + }) + .build(); + + if let Some(thumbnail) = &track.thumbnail { + embed.thumbnail = Some(EmbedThumbnail { + height: None, + proxy_url: None, + url: thumbnail.to_string(), + width: None, + }); + } + + vec![embed] +} + +fn build_playlist_added_embeds(tracks: &[YouTubeTrack], num_tracks_added: usize) -> Vec { + let mut content = String::new(); + let first_track = tracks.first().unwrap(); + content.push_str(&format!( + "Adding playlist: [{}]({})\n", + &first_track + .playlist + .clone() + .unwrap_or("Unknown".to_string()), + build_playlist_url( + &first_track + .playlist_id + .clone() + .unwrap_or("Unknown".to_string()) + ) + )); + content.push_str(&format!( + "Added {} tracks to the queue.\n", + num_tracks_added + )); + let embed = EmbedBuilder::new() + .description(content) + .color(colors::BLURPLE) + .build(); + vec![embed] +} + +fn build_embeds(tracks: &[YouTubeTrack], tracks_added: &[TrackType]) -> Vec { + let num_tracks_added = tracks_added.len(); + match num_tracks_added { + 0 => vec![], + 1 => build_single_track_added_embeds(tracks_added), + _ => build_playlist_added_embeds(tracks, num_tracks_added), + } +} + pub(crate) async fn play( interaction: Box, state: State, @@ -165,14 +332,6 @@ pub(crate) async fn play( call.queue().resume()?; } - struct TrackType { - url: String, - title: Option, - duration_string: String, - channel: String, - thumbnail: Option, - } - let mut tracks_added = vec![]; for yttrack in &tracks { tracing::debug!("track: {:?}", yttrack); @@ -183,50 +342,12 @@ pub(crate) async fn play( .ok_or("Could not find url")?; let mut src = YoutubeDl::new(reqwest::Client::new(), url.clone()); - let src_copy = src.clone(); - let src_copy2 = src.clone(); - let _m = Memory::new(src_copy2.into()); - let track: Track = src_copy.into(); + let track: Track = src.clone().into(); if let Ok(metadata) = src.aux_metadata().await { debug!("metadata: {:?}", metadata); - let (author_name, author_global_name) = if let Some(author) = interaction.author() { - (author.name.clone(), author.global_name.clone()) - } else { - ("".to_string(), None) - }; - - db::track::insert_guild(&state.pool, db::track::Guild::new(guild_id.to_string())) - .await - .expect("failed to insert guild: {e}"); - - db::track::insert_user( - &state.pool, - db::track::User::new(user_id.to_string(), author_name, author_global_name), - ) - .await - .expect("failed to insert user: {e}"); - - let track_id = db::track::insert_track( - &state.pool, - db::track::Track::new( - url.clone(), - yttrack.title.clone(), - yttrack.channel.clone(), - yttrack.duration_string.clone(), - metadata.thumbnail.clone().unwrap_or_default(), - ), - ) - .await - .expect("failed to insert track: {e}"); - - db::track::insert_query( - &state.pool, - db::track::Query::new(user_id.to_string(), guild_id.to_string(), track_id), - ) - .await - .expect("failed to insert track: {e}"); + persistence(&interaction, yttrack, Arc::clone(&state)).await?; tracks_added.push(TrackType { url: url.clone(), @@ -258,87 +379,7 @@ pub(crate) async fn play( } } } - let mut content = String::new(); - let num_tracks_added = tracks_added.len(); - let embeds = match num_tracks_added { - 0 => { - vec![] - } - 1 => { - let track = tracks_added.first().unwrap(); - - let host = Url::parse(&track.url)?; - let host = host - .host_str() - .unwrap_or_default() - .trim_start_matches("www."); - let mut embed = EmbedBuilder::new() - .author(EmbedAuthor { - name: "🔊 Added to queue".to_string(), - icon_url: None, - proxy_icon_url: None, - url: None, - }) - .title(track.title.clone().unwrap_or("Unknown".to_string())) - .url(track.url.clone()) - .color(colors::BLURPLE) - .footer(EmbedFooter { - text: format!("Streaming from {}", host), - icon_url: Some(format!( - "https://www.google.com/s2/favicons?domain={}", - host - )), - proxy_icon_url: None, - }) - .field(EmbedField { - inline: true, - name: "Duration".to_string(), - value: track.duration_string.clone(), - }) - .field(EmbedField { - inline: true, - name: "Channel".to_string(), - value: track.channel.clone(), - }) - .build(); - - if let Some(thumbnail) = &track.thumbnail { - embed.thumbnail = Some(EmbedThumbnail { - height: None, - proxy_url: None, - url: thumbnail.to_string(), - width: None, - }); - } - - vec![embed] - } - _ => { - let first_track = tracks.first().unwrap(); - content.push_str(&format!( - "Adding playlist: [{}]({})\n", - &first_track - .playlist - .clone() - .unwrap_or("Unknown".to_string()), - build_playlist_url( - &first_track - .playlist_id - .clone() - .unwrap_or("Unknown".to_string()) - ) - )); - content.push_str(&format!( - "Added {} tracks to the queue:\n", - num_tracks_added - )); - let embed = EmbedBuilder::new() - .description(content) - .color(colors::BLURPLE) - .build(); - vec![embed] - } - }; + let embeds = build_embeds(&tracks, &tracks_added); state .http diff --git a/src/db/track.rs b/src/db/track.rs index b13bd22..c2e6e56 100644 --- a/src/db/track.rs +++ b/src/db/track.rs @@ -36,10 +36,19 @@ pub(crate) async fn insert_track( pool: &sqlx::SqlitePool, track: Track, ) -> Result { - let query = - "INSERT OR REPLACE INTO tracks (url, title, channel, duration, thumbnail, updated) VALUES ($1, $2, $3, $4, $5, $6)"; + let query = r#" + INSERT INTO tracks (url, title, channel, duration, thumbnail, updated) + VALUES ($1, $2, $3, $4, $5, $6) + ON CONFLICT (url) DO UPDATE SET + title = excluded.title, + channel = excluded.channel, + duration = excluded.duration, + thumbnail = excluded.thumbnail, + updated = excluded.updated + "#; + let res = sqlx::query(query) - .bind(track.url) + .bind(&track.url) .bind(track.title) .bind(track.channel) .bind(track.duration) @@ -47,7 +56,23 @@ pub(crate) async fn insert_track( .bind(track.updated) .execute(pool) .await?; - Ok(res.last_insert_rowid()) + + let id = res.last_insert_rowid(); + + // the track id is 0 if we only updated a row + if id == 0 { + let track: Track = sqlx::query_as( + r#" + SELECT id FROM tracks + WHERE url = $1 + "#, + ) + .bind(track.url) + .fetch_one(pool) + .await?; + return Ok(track.id); + } + Ok(id) } #[derive(Debug, FromRow)] @@ -70,8 +95,14 @@ impl User { } pub(crate) async fn insert_user(pool: &sqlx::SqlitePool, user: User) -> Result<(), sqlx::Error> { - let query = - "INSERT OR REPLACE INTO users (id, name, global_name, updated) VALUES ($1, $2, $3, $4)"; + let query = r#" + INSERT INTO users (id, name, global_name, updated) + VALUES ($1, $2, $3, $4) + ON CONFLICT (id) DO UPDATE SET + name = excluded.name, + global_name = excluded.global_name, + updated = excluded.updated + "#; sqlx::query(query) .bind(user.id) .bind(user.name) @@ -105,8 +136,10 @@ impl Query { } pub(crate) async fn insert_query(pool: &sqlx::SqlitePool, q: Query) -> Result { - let query = - "INSERT OR REPLACE INTO queries (user_id, guild_id, track_id, updated) VALUES ($1, $2, $3, $4)"; + let query = r#" + INSERT INTO queries (user_id, guild_id, track_id, updated) + VALUES ($1, $2, $3, $4) + "#; let res = sqlx::query(query) .bind(q.user_id) .bind(q.guild_id) @@ -132,15 +165,17 @@ impl Guild { } } -pub(crate) async fn insert_guild( - pool: &sqlx::SqlitePool, - guild: Guild, -) -> Result { - let query = "INSERT OR REPLACE INTO guilds (id, updated) VALUES ($1, $2)"; - let res = sqlx::query(query) +pub(crate) async fn insert_guild(pool: &sqlx::SqlitePool, guild: Guild) -> Result<(), sqlx::Error> { + let query = r#" + INSERT INTO guilds (id, updated) + VALUES ($1, $2) + ON CONFLICT (id) DO UPDATE SET + updated = excluded.updated + "#; + sqlx::query(query) .bind(guild.id) .bind(guild.updated) .execute(pool) .await?; - Ok(res.last_insert_rowid()) + Ok(()) } diff --git a/src/main.rs b/src/main.rs index 797c333..705a135 100644 --- a/src/main.rs +++ b/src/main.rs @@ -16,7 +16,7 @@ use futures::StreamExt; use signal::signal_handler; use songbird::{shards::TwilightMap, Songbird}; use state::StateRef; -use std::{env, error::Error, sync::Arc}; +use std::{env, error::Error, str::FromStr, sync::Arc}; use tokio::select; use tracing::{debug, info}; use twilight_cache_inmemory::InMemoryCache; @@ -34,16 +34,15 @@ async fn main() -> Result<(), Box> { println!("Starting up..."); - // Initialize the tracing subscriber. tracing_subscriber::fmt::init(); info!("Starting up..."); let (mut shards, state) = { let db = env::var("DATABASE_URL").map_err(|_| "DATABASE_URL is not set")?; - let options = SqliteConnectOptions::new() - .create_if_missing(true) - .filename(&db); + let options = SqliteConnectOptions::from_str(&db) + .expect("could not create options") + .create_if_missing(true); let pool = SqlitePoolOptions::new() .max_connections(5) .connect_with(options)