diff --git a/Cargo.lock b/Cargo.lock index dfe0b73..c49864a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -233,9 +233,11 @@ dependencies = [ "anyhow", "bit_rev", "console-subscriber", + "crossterm", "flume", "indicatif", "tokio", + "tokio-util", "tracing", "tracing-subscriber", ] @@ -332,6 +334,31 @@ version = "0.8.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" +[[package]] +name = "crossterm" +version = "0.27.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f476fe445d41c9e991fd07515a6f463074b782242ccf4a5b7b1d1012e70824df" +dependencies = [ + "bitflags 2.9.3", + "crossterm_winapi", + "libc", + "mio 0.8.11", + "parking_lot", + "signal-hook", + "signal-hook-mio", + "winapi", +] + +[[package]] +name = "crossterm_winapi" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "acdd7c62a3665c7f6830a51635d9ac9b23ed385797f70a83bb8bafe9c572ab2b" +dependencies = [ + "winapi", +] + [[package]] name = "dashmap" version = "5.5.3" @@ -1062,6 +1089,18 @@ dependencies = [ "adler2", ] +[[package]] +name = "mio" +version = "0.8.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4a650543ca06a924e8b371db273b2756685faae30f8487da1b56505a8f78b0c" +dependencies = [ + "libc", + "log", + "wasi 0.11.1+wasi-snapshot-preview1", + "windows-sys 0.48.0", +] + [[package]] name = "mio" version = "1.0.4" @@ -1599,6 +1638,27 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" +[[package]] +name = "signal-hook" +version = "0.3.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d881a16cf4426aa584979d30bd82cb33429027e42122b169753d6ef1085ed6e2" +dependencies = [ + "libc", + "signal-hook-registry", +] + +[[package]] +name = "signal-hook-mio" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34db1a06d485c9142248b7a054f034b349b212551f3dfd19c94d45a754a217cd" +dependencies = [ + "libc", + "mio 0.8.11", + "signal-hook", +] + [[package]] name = "signal-hook-registry" version = "1.4.6" @@ -1792,7 +1852,7 @@ dependencies = [ "bytes", "io-uring", "libc", - "mio", + "mio 1.0.4", "parking_lot", "pin-project-lite", "signal-hook-registry", diff --git a/crates/cli/Cargo.toml b/crates/cli/Cargo.toml index cba89ce..bd04161 100644 --- a/crates/cli/Cargo.toml +++ b/crates/cli/Cargo.toml @@ -20,3 +20,5 @@ console-subscriber = { workspace = true, optional = true } tracing.workspace = true tracing-subscriber.workspace = true flume.workspace = true +crossterm = "0.27" +tokio-util = "0.7" diff --git a/crates/cli/src/main.rs b/crates/cli/src/main.rs index 73da317..340c2f4 100644 --- a/crates/cli/src/main.rs +++ b/crates/cli/src/main.rs @@ -1,17 +1,33 @@ +use crossterm::{ + event::{self, Event, KeyCode, KeyModifiers}, + terminal::{self}, +}; use indicatif::{ProgressBar, ProgressState, ProgressStyle}; use std::{ collections::HashMap, fmt::Write, io::SeekFrom, sync::{atomic::AtomicU64, Arc}, + time::Duration, }; use tokio::{ fs::{create_dir_all, File}, io::{AsyncSeekExt, AsyncWriteExt}, + signal, }; +use tokio_util::sync::CancellationToken; use tracing::trace; -use bit_rev::{session::Session, utils}; +use bit_rev::{ + session::{DownloadState, PieceResult, Session}, + utils, +}; + +fn graceful_shutdown() { + let _ = terminal::disable_raw_mode(); + println!("\n\nShutting down gracefully..."); + std::process::exit(0); +} #[tokio::main] async fn main() { @@ -25,12 +41,14 @@ async fn main() { let output = std::env::args().nth(2); if let Err(err) = download_file(&filename, output).await { + let _ = terminal::disable_raw_mode(); eprintln!("Error: {:?}", err); } } pub async fn download_file(filename: &str, out_file: Option) -> anyhow::Result<()> { - let session = Session::new(); + let session = Arc::new(Session::new()); + let shutdown_token = CancellationToken::new(); let add_torrent_result = session.add_torrent(filename.into()).await?; let torrent = add_torrent_result.torrent.clone(); @@ -85,50 +103,148 @@ pub async fn download_file(filename: &str, out_file: Option) -> anyhow:: let total_downloaded = Arc::new(AtomicU64::new(0)); let total_downloaded_clone = total_downloaded.clone(); + let session_clone = session.clone(); + // Spawn progress update task tokio::spawn(async move { loop { let new = total_downloaded_clone.load(std::sync::atomic::Ordering::Relaxed); pb.set_position(new); - pb.set_message("Downloading"); + let status = match session_clone.get_download_state() { + DownloadState::Init => "Initializing", + DownloadState::Downloading => "Downloading", + DownloadState::Paused => "Paused", + }; + pb.set_message(status); tokio::time::sleep(std::time::Duration::from_millis(100)).await; } }); - let mut hashset = std::collections::HashSet::new(); - while hashset.len() < torrent.piece_hashes.len() { - let pr = add_torrent_result.pr_rx.recv_async().await?; + // Enable raw mode for single keypress detection + terminal::enable_raw_mode().expect("Failed to enable raw mode"); - hashset.insert(pr.index); + // Set up Ctrl+C signal handler + let _shutdown_token_signal = shutdown_token.clone(); + tokio::spawn(async move { + let mut sigint = signal::unix::signal(signal::unix::SignalKind::interrupt()) + .expect("Failed to install SIGINT handler"); - // Map piece to files and write data accordingly - let file_mappings = utils::map_piece_to_files(&torrent, pr.index as usize); - let mut piece_offset = 0; + sigint.recv().await; + graceful_shutdown(); + }); - for mapping in file_mappings { - let file = file_handles.get_mut(&mapping.file_index).ok_or_else(|| { - anyhow::anyhow!("File handle not found for index {}", mapping.file_index) - })?; + // Spawn keyboard input handler + let session_input = session.clone(); + let shutdown_token_input = shutdown_token.clone(); + tokio::spawn(async move { + loop { + // Check for cancellation + if shutdown_token_input.is_cancelled() { + break; + } + + if event::poll(Duration::from_millis(100)).unwrap_or(false) { + if let Ok(Event::Key(key_event)) = event::read() { + match key_event.code { + KeyCode::Char('p') | KeyCode::Char('P') => { + match session_input.get_download_state() { + DownloadState::Paused => { + session_input.resume(); + } + DownloadState::Downloading => { + session_input.pause(); + } + DownloadState::Init => { + println!("\r\nCannot pause during initialization"); + } + } + } + KeyCode::Char('q') | KeyCode::Char('Q') => { + graceful_shutdown(); + } + KeyCode::Char('c') + if key_event.modifiers.contains(KeyModifiers::CONTROL) => + { + graceful_shutdown(); + } + _ => {} + } + } + } + tokio::time::sleep(Duration::from_millis(50)).await; + } + }); - // Seek to correct position in file - file.seek(SeekFrom::Start(mapping.file_offset as u64)) + let mut hashset = std::collections::HashSet::new(); + let mut pending_pieces: Vec<_> = Vec::new(); // Queue for pieces received while paused + + while hashset.len() < torrent.piece_hashes.len() { + // Check for shutdown signal + if shutdown_token.is_cancelled() { + break; + } + // Process any pending pieces first if we're now downloading + if session.get_download_state() == DownloadState::Downloading && !pending_pieces.is_empty() + { + let pieces_to_process = std::mem::take(&mut pending_pieces); + for pr in pieces_to_process { + process_piece( + &pr, + &torrent, + &mut file_handles, + &mut hashset, + &total_downloaded, + ) .await?; + } + } - // Write the portion of the piece that belongs to this file - let piece_data = &pr.buf[piece_offset..piece_offset + mapping.length]; - file.write_all(piece_data).await?; + // Use a timeout to periodically check if we should process pending pieces + let pr_result = tokio::time::timeout( + Duration::from_millis(100), + add_torrent_result.pr_rx.recv_async(), + ) + .await; - piece_offset += mapping.length; + match pr_result { + Ok(Ok(pr)) => { + // If paused, queue the piece but don't process it yet + if session.get_download_state() != DownloadState::Downloading { + pending_pieces.push(pr); + continue; + } - trace!( - "Wrote {} bytes to file {} at offset {}", - mapping.length, - mapping.file_index, - mapping.file_offset - ); + // Process piece immediately if downloading + process_piece( + &pr, + &torrent, + &mut file_handles, + &mut hashset, + &total_downloaded, + ) + .await?; + } + Ok(Err(_)) => { + // Channel closed + break; + } + Err(_) => { + // Timeout - continue loop to check pending pieces + continue; + } } + } - total_downloaded.fetch_add(pr.length as u64, std::sync::atomic::Ordering::Relaxed); + // Process any remaining pending pieces at the end + for pr in pending_pieces { + process_piece( + &pr, + &torrent, + &mut file_handles, + &mut hashset, + &total_downloaded, + ) + .await?; } // Sync all files @@ -136,5 +252,49 @@ pub async fn download_file(filename: &str, out_file: Option) -> anyhow:: file.sync_all().await?; } + // Restore terminal on completion + let _ = terminal::disable_raw_mode(); + println!("\nDownload completed!"); + Ok(()) } + +async fn process_piece( + pr: &PieceResult, + torrent: &bit_rev::torrent::Torrent, + file_handles: &mut HashMap, + hashset: &mut std::collections::HashSet, + total_downloaded: &Arc, +) -> anyhow::Result { + hashset.insert(pr.index); + + // Map piece to files and write data accordingly + let file_mappings = utils::map_piece_to_files(torrent, pr.index as usize); + let mut piece_offset = 0; + + for mapping in file_mappings { + let file = file_handles.get_mut(&mapping.file_index).ok_or_else(|| { + anyhow::anyhow!("File handle not found for index {}", mapping.file_index) + })?; + + // Seek to correct position in file + file.seek(SeekFrom::Start(mapping.file_offset as u64)) + .await?; + + // Write the portion of the piece that belongs to this file + let piece_data = &pr.buf[piece_offset..piece_offset + mapping.length]; + file.write_all(piece_data).await?; + + piece_offset += mapping.length; + + trace!( + "Wrote {} bytes to file {} at offset {}", + mapping.length, + mapping.file_index, + mapping.file_offset + ); + } + + total_downloaded.fetch_add(pr.length as u64, std::sync::atomic::Ordering::Relaxed); + Ok(true) +}