From 5ded43881c3117c8a8e8c5822436935d5be542c8 Mon Sep 17 00:00:00 2001 From: overtrue Date: Sun, 17 May 2026 05:05:04 +0800 Subject: [PATCH] test(s3): cover custom XML request headers --- crates/s3/src/client.rs | 149 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 149 insertions(+) diff --git a/crates/s3/src/client.rs b/crates/s3/src/client.rs index dbc7c40..f52eec2 100644 --- a/crates/s3/src/client.rs +++ b/crates/s3/src/client.rs @@ -2816,6 +2816,18 @@ mod tests { use super::*; use aws_smithy_http_client::test_util::{CaptureRequestReceiver, capture_request}; use std::collections::HashMap; + use std::io::{Read, Write}; + use std::net::{TcpListener, TcpStream}; + use std::sync::mpsc; + use std::thread; + use std::time::{Duration, Instant}; + + #[derive(Debug)] + struct CapturedXmlRequest { + method: String, + target: String, + headers: Vec<(String, String)>, + } fn test_s3_client( response: Option>, @@ -2873,6 +2885,100 @@ mod tests { (client, request_receiver) } + fn read_xml_request(stream: &mut TcpStream) -> CapturedXmlRequest { + let mut buffer = Vec::new(); + let mut chunk = [0_u8; 1024]; + let header_end = loop { + let read = stream.read(&mut chunk).expect("read HTTP request"); + assert!(read > 0, "client closed connection before headers"); + buffer.extend_from_slice(&chunk[..read]); + + if let Some(position) = buffer.windows(4).position(|window| window == b"\r\n\r\n") { + break position + 4; + } + }; + + let headers_text = String::from_utf8_lossy(&buffer[..header_end]).into_owned(); + let content_length = headers_text + .lines() + .find_map(|line| { + let (name, value) = line.split_once(':')?; + name.eq_ignore_ascii_case("content-length") + .then(|| value.trim().parse::().expect("valid content length")) + }) + .unwrap_or(0); + + while buffer.len() - header_end < content_length { + let read = stream.read(&mut chunk).expect("read HTTP request body"); + assert!(read > 0, "client closed connection before body"); + buffer.extend_from_slice(&chunk[..read]); + } + + let mut lines = headers_text.lines(); + let request_line = lines.next().expect("request line"); + let mut parts = request_line.split_whitespace(); + let method = parts.next().expect("request method").to_string(); + let target = parts.next().expect("request target").to_string(); + let headers = lines + .filter_map(|line| { + let (name, value) = line.split_once(':')?; + Some((name.to_ascii_lowercase(), value.trim().to_string())) + }) + .collect(); + + CapturedXmlRequest { + method, + target, + headers, + } + } + + fn start_xml_test_server() -> ( + String, + mpsc::Receiver, + thread::JoinHandle<()>, + ) { + let listener = TcpListener::bind("127.0.0.1:0").expect("bind test server"); + listener + .set_nonblocking(true) + .expect("configure nonblocking listener"); + let endpoint = format!("http://{}", listener.local_addr().expect("local addr")); + let (sender, receiver) = mpsc::channel(); + + let handle = thread::spawn(move || { + let deadline = Instant::now() + Duration::from_secs(5); + let mut stream = loop { + match listener.accept() { + Ok((stream, _)) => break stream, + Err(error) if error.kind() == std::io::ErrorKind::WouldBlock => { + assert!(Instant::now() < deadline, "timed out waiting for request"); + thread::sleep(Duration::from_millis(10)); + } + Err(error) => panic!("accept request: {error}"), + } + }; + stream + .set_read_timeout(Some(Duration::from_secs(5))) + .expect("set stream read timeout"); + let request = read_xml_request(&mut stream); + sender.send(request).expect("send captured request"); + + let response = "HTTP/1.1 200 OK\r\ncontent-length: 2\r\nconnection: close\r\n\r\nok"; + stream + .write_all(response.as_bytes()) + .expect("write HTTP response"); + }); + + (endpoint, receiver, handle) + } + + fn header_value<'a>(headers: &'a [(String, String)], name: &str) -> Option<&'a str> { + headers + .iter() + .find(|(header_name, _)| header_name.eq_ignore_ascii_case(name)) + .map(|(_, value)| value.as_str()) + } + #[test] fn test_object_info_creation() { let info = ObjectInfo::file("test.txt", 1024); @@ -3595,6 +3701,49 @@ mod tests { assert!(!url.contains("x-amz-bucket-encrypt-enabled")); } + #[tokio::test] + async fn custom_headers_are_added_to_xml_requests_before_signing() { + let (endpoint, request_receiver, server_handle) = start_xml_test_server(); + let (client, _sdk_request_receiver) = test_s3_client_with_endpoint_and_headers( + &endpoint, + None, + vec![RequestHeader { + name: "x-amz-bucket-encrypt-enabled".to_string(), + value: "1".to_string(), + }], + ); + let url = client + .replication_url("bucket") + .expect("replication URL should build"); + + let response = client + .xml_request( + Method::PUT, + url, + Some("application/xml"), + Some(b"".to_vec()), + ) + .await + .expect("xml request should succeed"); + + assert_eq!(response, "ok"); + let request = request_receiver + .recv_timeout(Duration::from_secs(5)) + .expect("server should capture XML request"); + assert_eq!(request.method, "PUT"); + assert_eq!(request.target, "/bucket?replication="); + assert_eq!( + header_value(&request.headers, "x-amz-bucket-encrypt-enabled"), + Some("1") + ); + assert!( + header_value(&request.headers, "authorization") + .expect("authorization header") + .contains("x-amz-bucket-encrypt-enabled") + ); + server_handle.join().expect("server thread should finish"); + } + #[tokio::test] async fn delete_object_without_force_delete_omits_rustfs_header() { let (client, request_receiver) = test_s3_client(None);