diff --git a/src/cli.rs b/src/cli.rs new file mode 100644 index 0000000..cc84133 --- /dev/null +++ b/src/cli.rs @@ -0,0 +1,30 @@ +use clap::Parser; + +#[derive(Parser)] +#[command(name = "websocket-debug")] +#[command(about = "A WebSocket debugging tool that logs and saves messages")] +pub struct Args { + /// WebSocket URLs to connect to (e.g., ws://localhost:8080 or wss://example.com/ws) + #[arg(required = true)] + pub urls: Vec<String>, + + /// Bearer token for Authorization header + #[arg(long)] + pub bearer_token: Option<String>, + + /// Enable debug logging (shows request/response headers) + #[arg(long)] + pub debug: bool, + + /// Query string parameters to add to all URLs (pre-encoded, e.g., "name=First%20Last&key=value") + #[arg(short = 'q', long = "query-string-all")] + pub query_string_all: Option<String>, + + /// jq expression(s) to evaluate on JSON text messages for logging (can be specified multiple times) + #[arg(short = 'j', long = "jaq")] + pub jaq: Vec<String>, + + /// Normalize JSON messages: save as .json with pretty-printing and sorted keys (for easier diffing) + #[arg(short = 'n', long = "json-normalize")] + pub json_normalize: bool, +} diff --git a/src/connection.rs b/src/connection.rs new file mode 100644 index 0000000..775e7ff --- /dev/null +++ b/src/connection.rs @@ -0,0 +1,165 @@ +use tokio_tungstenite::{ + connect_async_with_config, MaybeTlsStream, WebSocketStream, + tungstenite::{ + client::IntoClientRequest, + http::header::{HeaderValue, AUTHORIZATION, USER_AGENT}, + }, +}; +use tracing::{debug, info}; +use url::{form_urlencoded, Url}; + +/// Process URLs by adding extra query parameters if specified.
+pub fn process_urls( + urls: &[String], + query_string_all: Option<&String>, +) -> Result<Vec<String>, url::ParseError> { + let extra_params: Vec<(String, String)> = query_string_all + .map(|qs| { + form_urlencoded::parse(qs.as_bytes()) + .map(|(k, v)| (k.into_owned(), v.into_owned())) + .collect() + }) + .unwrap_or_default(); + + let mut processed_urls = Vec::new(); + for url_str in urls { + let mut url = Url::parse(url_str)?; + + if !extra_params.is_empty() { + let existing: Vec<(String, String)> = url + .query_pairs() + .map(|(k, v)| (k.into_owned(), v.into_owned())) + .collect(); + + let mut query_pairs = url.query_pairs_mut(); + query_pairs.clear(); + for (k, v) in &existing { + query_pairs.append_pair(k, v); + } + for (k, v) in &extra_params { + query_pairs.append_pair(k, v); + } + } + + processed_urls.push(url.to_string()); + } + + Ok(processed_urls) +} + +/// Connect to a WebSocket URL with optional bearer token and debug logging. +pub async fn connect( + url: &str, + idx: usize, + bearer_token: Option<&str>, + debug_enabled: bool, +) -> Result<(char, String, WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>), String> { + let letter = (b'A' + idx as u8) as char; + + let mut request = url + .into_client_request() + .map_err(|e| format!("[{}] Invalid URL {}: {}", letter, url, e))?; + + if let Some(token) = bearer_token { + let auth_value = HeaderValue::from_str(&format!("Bearer {}", token)) + .map_err(|e| format!("Invalid bearer token: {}", e))?; + request.headers_mut().insert(AUTHORIZATION, auth_value); + } + + let user_agent = format!("{}/{}", env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION")); + request + .headers_mut() + .insert(USER_AGENT, HeaderValue::from_str(&user_agent).unwrap()); + + info!("[{}] Connecting to {}", letter, url); + + if debug_enabled { + log_request(&request, letter); + } + + let (ws_stream, response) = match connect_async_with_config(request, None, false).await { + Ok(result) => result, + Err(e) => { + if debug_enabled { + if let tokio_tungstenite::tungstenite::Error::Http(ref
response) = e { + log_error_response(response, letter); + } + } + return Err(format!("[{}] Failed to connect to {}: {}", letter, url, e)); + } + }; + + info!( + "[{}] Connected successfully (status: {})", + letter, + response.status() + ); + + if debug_enabled { + log_response(&response, letter); + } + + Ok((letter, url.to_string(), ws_stream)) +} + +fn log_request(request: &tokio_tungstenite::tungstenite::http::Request<()>, letter: char) { + debug!( + "[{}] Request: {} {}", + letter, + request.method(), + request.uri() + ); + for (name, value) in request.headers() { + debug!( + "[{}] {}: {}", + letter, + name, + value.to_str().unwrap_or("") + ); + } +} + +fn log_response( + response: &tokio_tungstenite::tungstenite::http::Response<Option<Vec<u8>>>, + letter: char, +) { + debug!( + "[{}] Response: {} {}", + letter, + response.status().as_u16(), + response.status().canonical_reason().unwrap_or("") + ); + for (name, value) in response.headers() { + debug!( + "[{}] {}: {}", + letter, + name, + value.to_str().unwrap_or("") + ); + } +} + +fn log_error_response( + response: &tokio_tungstenite::tungstenite::http::Response<Option<Vec<u8>>>, + letter: char, +) { + debug!( + "[{}] Error response: {} {}", + letter, + response.status().as_u16(), + response.status().canonical_reason().unwrap_or("") + ); + for (name, value) in response.headers() { + debug!( + "[{}] {}: {}", + letter, + name, + value.to_str().unwrap_or("") + ); + } + if let Some(body) = response.body() { + if let Ok(body_str) = std::str::from_utf8(body) { + debug!("[{}] Response body: {}", letter, body_str); + } + } +} diff --git a/src/filter.rs b/src/filter.rs new file mode 100644 index 0000000..3fbaf59 --- /dev/null +++ b/src/filter.rs @@ -0,0 +1,79 @@ +use std::sync::Arc; + +use jaq_interpret::{Ctx, Filter, FilterT, ParseCtx, RcIter, Val}; +use tracing::{info, warn}; + +pub type JaqFilters = Arc<Vec<Filter>>; + +/// Compile a list of jaq expressions into filters.
+pub fn compile_filters(exprs: &[String]) -> Result<JaqFilters, String> { + let mut filters = Vec::new(); + + for expr in exprs { + let mut defs = ParseCtx::new(Vec::new()); + defs.insert_natives(jaq_core::core()); + defs.insert_defs(jaq_std::std()); + + let (parsed, errs) = jaq_parse::parse(expr, jaq_parse::main()); + if !errs.is_empty() { + let err_msgs: Vec<String> = errs.iter().map(|e| format!("{:?}", e)).collect(); + return Err(format!( + "Failed to parse jaq expression '{}': {}", + expr, + err_msgs.join(", ") + )); + } + + let parsed = parsed.ok_or_else(|| format!("Failed to parse jaq expression '{}'", expr))?; + let filter = defs.compile(parsed); + + if !defs.errs.is_empty() { + return Err(format!( + "Failed to compile jaq expression '{}' ({} error(s))", + expr, + defs.errs.len() + )); + } + + info!("Using jaq filter: {}", expr); + filters.push(filter); + } + + Ok(Arc::new(filters)) +} + +/// Evaluate all filters against a JSON value, returning the outputs for each filter. +pub fn evaluate_filters( + filters: &JaqFilters, + json_val: &serde_json::Value, + letter: char, + seq_num: u64, +) -> Vec<String> { + let mut all_outputs = Vec::new(); + + for filter in filters.iter() { + let inputs = RcIter::new(core::iter::empty()); + let ctx = Ctx::new([], &inputs); + let out = filter.run((ctx, Val::from(json_val.clone()))); + + let mut filter_outputs = Vec::new(); + for result in out { + match result { + Ok(val) => { + filter_outputs.push(val.to_string()); + } + Err(e) => { + warn!("[{}:{}] jaq error: {}", letter, seq_num, e); + } + } + } + + if filter_outputs.is_empty() { + all_outputs.push("(no output)".to_string()); + } else { + all_outputs.push(filter_outputs.join(", ")); + } + } + + all_outputs +} diff --git a/src/handler.rs b/src/handler.rs new file mode 100644 index 0000000..3f88e1e --- /dev/null +++ b/src/handler.rs @@ -0,0 +1,125 @@ +use std::fs; +use std::path::PathBuf; + +use futures_util::StreamExt; +use tokio::net::TcpStream; +use tokio_tungstenite::{tungstenite::Message, MaybeTlsStream,
WebSocketStream}; +use tracing::{error, info, warn}; + +use crate::filter::{evaluate_filters, JaqFilters}; + +/// Configuration for the connection handler. +#[derive(Clone)] +pub struct HandlerConfig { + pub session_dir: PathBuf, + pub jaq_filters: JaqFilters, + pub json_normalize: bool, +} + +/// Handle a WebSocket connection, processing messages and writing them to files. +pub async fn handle_connection( + letter: char, + url: String, + ws_stream: WebSocketStream<MaybeTlsStream<TcpStream>>, + config: HandlerConfig, +) -> (char, String) { + let (_, mut read) = ws_stream.split(); + let mut seq_num: u64 = 0; + + while let Some(message_result) = read.next().await { + match message_result { + Ok(message) => { + match message { + Message::Text(text) => { + handle_text_message(&text, letter, seq_num, &config); + } + Message::Binary(data) => { + handle_binary_message(&data, letter, seq_num, &config); + } + Message::Ping(data) => { + info!("[{}] Ping: {} bytes", letter, data.len()); + continue; + } + Message::Pong(data) => { + info!("[{}] Pong: {} bytes", letter, data.len()); + continue; + } + Message::Close(frame) => { + if let Some(cf) = frame { + info!("[{}] Connection closed: {} - {}", letter, cf.code, cf.reason); + } else { + info!("[{}] Connection closed", letter); + } + break; + } + Message::Frame(_) => { + continue; + } + } + seq_num += 1; + } + Err(e) => { + warn!("[{}] Error receiving message: {}", letter, e); + } + } + } + + info!("[{}] Session ended. Received {} messages.", letter, seq_num); + (letter, url) +} + +fn handle_text_message(text: &str, letter: char, seq_num: u64, config: &HandlerConfig) { + // Log message based on jaq filters + if config.jaq_filters.is_empty() { + let preview: String = text.chars().take(50).collect(); + let truncated = if text.len() > 50 { "..."
} else { "" }; + info!("[{}:{}] Text: {}{}", letter, seq_num, preview, truncated); + } else { + match serde_json::from_str::<serde_json::Value>(text) { + Ok(json_val) => { + let outputs = evaluate_filters(&config.jaq_filters, &json_val, letter, seq_num); + info!("[{}:{}] {}", letter, seq_num, outputs.join(" | ")); + } + Err(e) => { + warn!("[{}:{}] JSON parse error: {}", letter, seq_num, e); + let preview: String = text.chars().take(50).collect(); + let truncated = if text.len() > 50 { "..." } else { "" }; + info!("[{}:{}] Text: {}{}", letter, seq_num, preview, truncated); + } + } + } + + // Write message to file + let (filename, content) = if config.json_normalize { + if let Ok(json_val) = serde_json::from_str::<serde_json::Value>(text) { + let pretty = serde_json::to_string_pretty(&json_val).unwrap_or_else(|_| text.to_string()); + ( + config.session_dir.join(format!("{}{}.json", letter, seq_num)), + pretty, + ) + } else { + ( + config.session_dir.join(format!("{}{}.txt", letter, seq_num)), + text.to_string(), + ) + } + } else { + ( + config.session_dir.join(format!("{}{}.txt", letter, seq_num)), + text.to_string(), + ) + }; + + if let Err(e) = fs::write(&filename, content) { + error!("[{}] Failed to write {:?}: {}", letter, filename, e); + } +} + +fn handle_binary_message(data: &[u8], letter: char, seq_num: u64, config: &HandlerConfig) { + info!("[{}:{}] Binary: {} bytes", letter, seq_num, data.len()); + + let filename = config.session_dir.join(format!("{}{}.bin", letter, seq_num)); + if let Err(e) = fs::write(&filename, data) { + error!("[{}] Failed to write {:?}: {}", letter, filename, e); + } +} diff --git a/src/main.rs b/src/main.rs index 9c25a15..2a7393c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,61 +1,27 @@ +mod cli; +mod connection; +mod filter; +mod handler; + use std::fs; use std::path::PathBuf; -use std::sync::Arc; use chrono::Local; use clap::Parser; -use futures_util::StreamExt; -use jaq_interpret::{Ctx, Filter, FilterT, ParseCtx, RcIter, Val}; use tokio::task::JoinSet; -use
tokio_tungstenite::{ - connect_async_with_config, - tungstenite::{ - client::IntoClientRequest, - http::header::{HeaderValue, AUTHORIZATION, USER_AGENT}, - Message, - }, -}; -use tracing::{debug, error, info, warn, Level}; -use url::{form_urlencoded, Url}; +use tracing::{error, info, Level}; -#[derive(Parser)] -#[command(name = "websocket-debug")] -#[command(about = "A WebSocket debugging tool that logs and saves messages")] -struct Args { - /// WebSocket URLs to connect to (e.g., ws://localhost:8080 or wss://example.com/ws) - #[arg(required = true)] - urls: Vec<String>, - - /// Bearer token for Authorization header - #[arg(long)] - bearer_token: Option<String>, - - /// Enable debug logging (shows request/response headers) - #[arg(long)] - debug: bool, - - /// Query string parameters to add to all URLs (pre-encoded, e.g., "name=First%20Last&key=value") - #[arg(short = 'q', long = "query-string-all")] - query_string_all: Option<String>, - - /// jq expression(s) to evaluate on JSON text messages for logging (can be specified multiple times) - #[arg(short = 'j', long = "jaq")] - jaq: Vec<String>, - - /// Normalize JSON messages: save as .json with pretty-printing and sorted keys (for easier diffing) - #[arg(short = 'n', long = "json-normalize")] - json_normalize: bool, -} +use cli::Args; +use connection::{connect, process_urls}; +use filter::compile_filters; +use handler::{handle_connection, HandlerConfig}; #[tokio::main] async fn main() -> Result<(), Box<dyn std::error::Error>> { let args = Args::parse(); - let log_level = if args.debug { - Level::DEBUG - } else { - Level::INFO - }; + // Initialize logging + let log_level = if args.debug { Level::DEBUG } else { Level::INFO }; tracing_subscriber::fmt() .with_target(false) .with_thread_ids(false) @@ -68,180 +34,21 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> { fs::create_dir_all(&session_dir)?; info!("Created session directory: {}", session_dir.display()); - // Compile jaq filters if provided - let jaq_filters: Arc<Vec<Filter>> = { - let mut filters = Vec::new(); + // Compile jaq filters +
let jaq_filters = compile_filters(&args.jaq)?; - for expr in &args.jaq { - let mut defs = ParseCtx::new(Vec::new()); - defs.insert_natives(jaq_core::core()); - defs.insert_defs(jaq_std::std()); - - let (parsed, errs) = jaq_parse::parse(expr, jaq_parse::main()); - if !errs.is_empty() { - let err_msgs: Vec<String> = errs.iter().map(|e| format!("{:?}", e)).collect(); - return Err(format!( - "Failed to parse jaq expression '{}': {}", - expr, - err_msgs.join(", ") - ) - .into()); - } - - let parsed = - parsed.ok_or_else(|| format!("Failed to parse jaq expression '{}'", expr))?; - let filter = defs.compile(parsed); - - if !defs.errs.is_empty() { - return Err(format!( - "Failed to compile jaq expression '{}' ({} error(s))", - expr, - defs.errs.len() - ) - .into()); - } - - info!("Using jaq filter: {}", expr); - filters.push(filter); - } - - Arc::new(filters) - }; - - // Parse extra query params once if specified - let extra_params: Vec<(String, String)> = args - .query_string_all - .as_ref() - .map(|qs| { - form_urlencoded::parse(qs.as_bytes()) - .map(|(k, v)| (k.into_owned(), v.into_owned())) - .collect() - }) - .unwrap_or_default(); - - // Process URLs and add query string parameters if specified - let mut processed_urls = Vec::new(); - for url_str in &args.urls { - let mut url = Url::parse(url_str)?; - - if !extra_params.is_empty() { - let existing: Vec<(String, String)> = url - .query_pairs() - .map(|(k, v)| (k.into_owned(), v.into_owned())) - .collect(); - - let mut query_pairs = url.query_pairs_mut(); - query_pairs.clear(); - for (k, v) in &existing { - query_pairs.append_pair(k, v); - } - for (k, v) in &extra_params { - query_pairs.append_pair(k, v); - } - } - - processed_urls.push(url.to_string()); - } + // Process URLs with extra query parameters + let processed_urls = process_urls(&args.urls, args.query_string_all.as_ref())?; // Connect to all URLs simultaneously let mut connect_futures = Vec::new(); for (idx, url) in processed_urls.iter().enumerate() { - let letter =
(b'A' + idx as u8) as char; - let bearer_token = args.bearer_token.clone(); let url = url.clone(); + let bearer_token = args.bearer_token.clone(); let debug_enabled = args.debug; connect_futures.push(async move { - let mut request = url - .as_str() - .into_client_request() - .map_err(|e| format!("[{}] Invalid URL {}: {}", letter, url, e))?; - - if let Some(ref token) = bearer_token { - let auth_value = HeaderValue::from_str(&format!("Bearer {}", token)) - .map_err(|e| format!("Invalid bearer token: {}", e))?; - request.headers_mut().insert(AUTHORIZATION, auth_value); - } - - let user_agent = format!("{}/{}", env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION")); - request - .headers_mut() - .insert(USER_AGENT, HeaderValue::from_str(&user_agent).unwrap()); - - info!("[{}] Connecting to {}", letter, url); - - if debug_enabled { - debug!( - "[{}] Request: {} {}", - letter, - request.method(), - request.uri() - ); - for (name, value) in request.headers() { - debug!( - "[{}] {}: {}", - letter, - name, - value.to_str().unwrap_or("") - ); - } - } - - let (ws_stream, response) = match connect_async_with_config(request, None, false).await - { - Ok(result) => result, - Err(e) => { - if debug_enabled { - if let tokio_tungstenite::tungstenite::Error::Http(ref response) = e { - debug!( - "[{}] Error response: {} {}", - letter, - response.status().as_u16(), - response.status().canonical_reason().unwrap_or("") - ); - for (name, value) in response.headers() { - debug!( - "[{}] {}: {}", - letter, - name, - value.to_str().unwrap_or("") - ); - } - if let Some(body) = response.body() { - if let Ok(body_str) = std::str::from_utf8(body) { - debug!("[{}] Response body: {}", letter, body_str); - } - } - } - } - return Err(format!("[{}] Failed to connect to {}: {}", letter, url, e)); - } - }; - - info!( - "[{}] Connected successfully (status: {})", - letter, - response.status() - ); - - if debug_enabled { - debug!( - "[{}] Response: {} {}", - letter, - response.status().as_u16(), - 
response.status().canonical_reason().unwrap_or("") - ); - for (name, value) in response.headers() { - debug!( - "[{}] {}: {}", - letter, - name, - value.to_str().unwrap_or("") - ); - } - } - - Ok::<_, String>((letter, url, ws_stream)) + connect(&url, idx, bearer_token.as_deref(), debug_enabled).await }); } @@ -261,156 +68,17 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> { info!("All {} connections established", connections.len()); - // Spawn tasks for each connection + // Spawn handler tasks for each connection let mut join_set: JoinSet<(char, String)> = JoinSet::new(); for (letter, url, ws_stream) in connections { - let session_dir = session_dir.clone(); - let jaq_filters = jaq_filters.clone(); - let json_normalize = args.json_normalize; + let config = HandlerConfig { + session_dir: session_dir.clone(), + jaq_filters: jaq_filters.clone(), + json_normalize: args.json_normalize, + }; - join_set.spawn(async move { - let (_, mut read) = ws_stream.split(); - let mut seq_num: u64 = 0; - - while let Some(message_result) = read.next().await { - match message_result { - Ok(message) => { - match message { - Message::Text(text) => { - // Determine what to log based on jaq filters - if jaq_filters.is_empty() { - let preview: String = text.chars().take(50).collect(); - let truncated = if text.len() > 50 { "..."
} else { "" }; - info!( - "[{}:{}] Text: {}{}", - letter, seq_num, preview, truncated - ); - } else { - match serde_json::from_str::<serde_json::Value>(&text) { - Ok(json_val) => { - let mut all_outputs = Vec::new(); - - for filter in jaq_filters.iter() { - let inputs = RcIter::new(core::iter::empty()); - let ctx = Ctx::new([], &inputs); - let out = - filter.run((ctx, Val::from(json_val.clone()))); - - let mut filter_outputs = Vec::new(); - for result in out { - match result { - Ok(val) => { - filter_outputs.push(val.to_string()); - } - Err(e) => { - warn!( - "[{}:{}] jaq error: {}", - letter, seq_num, e - ); - } - } - } - if filter_outputs.is_empty() { - all_outputs.push("(no output)".to_string()); - } else { - all_outputs.push(filter_outputs.join(", ")); - } - } - - info!( - "[{}:{}] {}", - letter, - seq_num, - all_outputs.join(" | ") - ); - } - Err(e) => { - warn!( - "[{}:{}] JSON parse error: {}", - letter, seq_num, e - ); - let preview: String = text.chars().take(50).collect(); - let truncated = - if text.len() > 50 { "..."
} else { "" }; - info!( - "[{}:{}] Text: {}{}", - letter, seq_num, preview, truncated - ); - } - } - } - - // Write message to file - let (filename, content) = if json_normalize { - if let Ok(json_val) = - serde_json::from_str::<serde_json::Value>(&text) - { - let pretty = - serde_json::to_string_pretty(&json_val).unwrap_or(text.clone()); - ( - session_dir.join(format!("{}{}.json", letter, seq_num)), - pretty, - ) - } else { - ( - session_dir.join(format!("{}{}.txt", letter, seq_num)), - text.clone(), - ) - } - } else { - ( - session_dir.join(format!("{}{}.txt", letter, seq_num)), - text.clone(), - ) - }; - if let Err(e) = fs::write(&filename, content) { - error!("[{}] Failed to write {:?}: {}", letter, filename, e); - } - } - Message::Binary(data) => { - info!("[{}:{}] Binary: {} bytes", letter, seq_num, data.len()); - - let filename = - session_dir.join(format!("{}{}.bin", letter, seq_num)); - if let Err(e) = fs::write(&filename, &data) { - error!("[{}] Failed to write {:?}: {}", letter, filename, e); - } - } - Message::Ping(data) => { - info!("[{}] Ping: {} bytes", letter, data.len()); - continue; - } - Message::Pong(data) => { - info!("[{}] Pong: {} bytes", letter, data.len()); - continue; - } - Message::Close(frame) => { - if let Some(cf) = frame { - info!( - "[{}] Connection closed: {} - {}", - letter, cf.code, cf.reason - ); - } else { - info!("[{}] Connection closed", letter); - } - break; - } - Message::Frame(_) => { - continue; - } - } - seq_num += 1; - } - Err(e) => { - warn!("[{}] Error receiving message: {}", letter, e); - } - } - } - - info!("[{}] Session ended. Received {} messages.", letter, seq_num); - (letter, url) - }); + join_set.spawn(handle_connection(letter, url, ws_stream, config)); } // Wait for all connections to finish