Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 149 additions & 0 deletions crates/s3/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2816,6 +2816,18 @@ mod tests {
use super::*;
use aws_smithy_http_client::test_util::{CaptureRequestReceiver, capture_request};
use std::collections::HashMap;
use std::io::{Read, Write};
use std::net::{TcpListener, TcpStream};
use std::sync::mpsc;
use std::thread;
use std::time::{Duration, Instant};

#[derive(Debug)]
struct CapturedXmlRequest {
method: String,
target: String,
headers: Vec<(String, String)>,
}

fn test_s3_client(
response: Option<http::Response<SdkBody>>,
Expand Down Expand Up @@ -2873,6 +2885,100 @@ mod tests {
(client, request_receiver)
}

fn read_xml_request(stream: &mut TcpStream) -> CapturedXmlRequest {
let mut buffer = Vec::new();
let mut chunk = [0_u8; 1024];
let header_end = loop {
let read = stream.read(&mut chunk).expect("read HTTP request");
assert!(read > 0, "client closed connection before headers");
buffer.extend_from_slice(&chunk[..read]);

if let Some(position) = buffer.windows(4).position(|window| window == b"\r\n\r\n") {
break position + 4;
}
};

let headers_text = String::from_utf8_lossy(&buffer[..header_end]).into_owned();
let content_length = headers_text
.lines()
.find_map(|line| {
let (name, value) = line.split_once(':')?;
name.eq_ignore_ascii_case("content-length")
.then(|| value.trim().parse::<usize>().expect("valid content length"))
})
.unwrap_or(0);

while buffer.len() - header_end < content_length {
let read = stream.read(&mut chunk).expect("read HTTP request body");
assert!(read > 0, "client closed connection before body");
buffer.extend_from_slice(&chunk[..read]);
}

let mut lines = headers_text.lines();
let request_line = lines.next().expect("request line");
let mut parts = request_line.split_whitespace();
let method = parts.next().expect("request method").to_string();
let target = parts.next().expect("request target").to_string();
let headers = lines
.filter_map(|line| {
let (name, value) = line.split_once(':')?;
Some((name.to_ascii_lowercase(), value.trim().to_string()))
})
.collect();

Comment thread
overtrue marked this conversation as resolved.
CapturedXmlRequest {
method,
target,
headers,
}
}

fn start_xml_test_server() -> (
String,
mpsc::Receiver<CapturedXmlRequest>,
thread::JoinHandle<()>,
) {
let listener = TcpListener::bind("127.0.0.1:0").expect("bind test server");
listener
.set_nonblocking(true)
.expect("configure nonblocking listener");
let endpoint = format!("http://{}", listener.local_addr().expect("local addr"));
let (sender, receiver) = mpsc::channel();

let handle = thread::spawn(move || {
let deadline = Instant::now() + Duration::from_secs(5);
let mut stream = loop {
match listener.accept() {
Ok((stream, _)) => break stream,
Err(error) if error.kind() == std::io::ErrorKind::WouldBlock => {
assert!(Instant::now() < deadline, "timed out waiting for request");
thread::sleep(Duration::from_millis(10));
}
Err(error) => panic!("accept request: {error}"),
}
};
stream
.set_read_timeout(Some(Duration::from_secs(5)))
.expect("set stream read timeout");
let request = read_xml_request(&mut stream);
sender.send(request).expect("send captured request");

let response = "HTTP/1.1 200 OK\r\ncontent-length: 2\r\nconnection: close\r\n\r\nok";
stream
.write_all(response.as_bytes())
.expect("write HTTP response");
});

(endpoint, receiver, handle)
}

fn header_value<'a>(headers: &'a [(String, String)], name: &str) -> Option<&'a str> {
headers
.iter()
.find(|(header_name, _)| header_name.eq_ignore_ascii_case(name))
.map(|(_, value)| value.as_str())
}

#[test]
fn test_object_info_creation() {
let info = ObjectInfo::file("test.txt", 1024);
Expand Down Expand Up @@ -3595,6 +3701,49 @@ mod tests {
assert!(!url.contains("x-amz-bucket-encrypt-enabled"));
}

#[tokio::test]
async fn custom_headers_are_added_to_xml_requests_before_signing() {
let (endpoint, request_receiver, server_handle) = start_xml_test_server();
let (client, _sdk_request_receiver) = test_s3_client_with_endpoint_and_headers(
&endpoint,
None,
vec![RequestHeader {
name: "x-amz-bucket-encrypt-enabled".to_string(),
value: "1".to_string(),
}],
);
let url = client
.replication_url("bucket")
.expect("replication URL should build");

let response = client
.xml_request(
Method::PUT,
url,
Some("application/xml"),
Some(b"<xml/>".to_vec()),
)
.await
.expect("xml request should succeed");

assert_eq!(response, "ok");
let request = request_receiver
.recv_timeout(Duration::from_secs(5))
.expect("server should capture XML request");
assert_eq!(request.method, "PUT");
assert_eq!(request.target, "/bucket?replication=");
assert_eq!(
header_value(&request.headers, "x-amz-bucket-encrypt-enabled"),
Some("1")
);
assert!(
header_value(&request.headers, "authorization")
.expect("authorization header")
.contains("x-amz-bucket-encrypt-enabled")
);
server_handle.join().expect("server thread should finish");
}

#[tokio::test]
async fn delete_object_without_force_delete_omits_rustfs_header() {
let (client, request_receiver) = test_s3_client(None);
Expand Down
Loading