Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions changelog.d/19510.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Simplify Rust HTTP client response streaming and limiting.
3 changes: 3 additions & 0 deletions rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ pyo3 = { version = "0.27.2", features = [
"anyhow",
"abi3",
"abi3-py310",
# So we can pass `bytes::Bytes` directly back to Python efficiently,
# https://docs.rs/pyo3/latest/pyo3/bytes/index.html
"bytes",
] }
pyo3-log = "0.13.1"
pythonize = "0.27.0"
Expand Down
2 changes: 1 addition & 1 deletion rust/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ impl NotFoundError {
import_exception!(synapse.api.errors, HttpResponseException);

impl HttpResponseException {
pub fn new(status: StatusCode, bytes: Vec<u8>) -> pyo3::PyErr {
pub fn new(status: StatusCode, bytes: bytes::Bytes) -> pyo3::PyErr {
HttpResponseException::new_err((
status.as_u16(),
status.canonical_reason().unwrap_or_default(),
Expand Down
75 changes: 21 additions & 54 deletions rust/src/http_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
use std::{collections::HashMap, future::Future, sync::OnceLock};

use anyhow::Context;
use futures::TryStreamExt;
use headers::HeaderMapExt;
use http_body_util::BodyExt;
use once_cell::sync::OnceCell;
use pyo3::{create_exception, exceptions::PyException, prelude::*};
use reqwest::RequestBuilder;
Expand Down Expand Up @@ -236,62 +235,30 @@ impl HttpClient {

let status = response.status();

// Find the expected `Content-Length` so we can pre-allocate the buffer
// necessary to read the response. It's expected that not every request will
// have a `Content-Length` header.
//
// `response.content_length()` does exist but the "value does not directly
// represents the value of the `Content-Length` header, but rather the size
// of the response’s body"
// (https://docs.rs/reqwest/latest/reqwest/struct.Response.html#method.content_length)
// and we want to avoid reading the entire body at this point because we
// purposely stream it below until the `response_limit`.
let content_length = {
let content_length = response
.headers()
.typed_get::<headers::ContentLength>()
// We need a `usize` for the `Vec::with_capacity(...)` usage below
.and_then(|content_length| content_length.0.try_into().ok());

// Sanity check that the request isn't too large from the information
// they told us (may be inaccurate so we also check below as we actually
// read the bytes)
if let Some(content_length_bytes) = content_length {
if content_length_bytes > response_limit {
Err(anyhow::anyhow!(
"Response size (defined by `Content-Length`) too large"
))?;
}
}

content_length
};

// Stream the response to avoid allocating a giant object on the server
// above our expected `response_limit`.
let mut stream = response.bytes_stream();
// Pre-allocate the buffer based on the expected `Content-Length`
let mut buffer = Vec::with_capacity(
content_length
// Default to pre-allocating nothing when the request doesn't have a
// `Content-Length` header
.unwrap_or(0),
);
while let Some(chunk) = stream.try_next().await.context("reading body")? {
if buffer.len() + chunk.len() > response_limit {
Err(anyhow::anyhow!("Response size too large"))?;
}

buffer.extend_from_slice(&chunk);
}
// A light-weight way to read the response up until the `response_limit`. We
// want to avoid allocating a giant response object on the server above our
// expected `response_limit` to avoid out-of-memory DOS problems.
let body = reqwest::Body::from(response);
let limited_body = http_body_util::Limited::new(body, response_limit);
let collected = limited_body
.collect()
.await
.map_err(anyhow::Error::from_boxed)
.with_context(|| {
format!(
"Response body exceeded response limit ({} bytes)",
response_limit
Comment on lines +249 to +250
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the future, I'd like to improve these error messages even further (include context of the request method/URL)

RuntimeError: Response exceeded response limit (1)
	
	Caused by:
	    length limit exceeded

)
})?;
let bytes: bytes::Bytes = collected.to_bytes();

if !status.is_success() {
return Err(HttpResponseException::new(status, buffer));
return Err(HttpResponseException::new(status, bytes));
}

let r = Python::attach(|py| buffer.into_pyobject(py).map(|o| o.unbind()))?;

Ok(r)
// Because of the `pyo3` `bytes` feature, we can pass this back to Python
// land efficiently
Ok(bytes)
})
}
}
Expand Down
18 changes: 18 additions & 0 deletions tests/synapse_rust/test_http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,24 @@ async def do_request() -> None:
self.get_success(self.till_deferred_has_result(do_request()))
self.assertEqual(self.server.calls, 1)

def test_request_response_limit_exceeded(self) -> None:
"""
Test to make sure we handle the response limit being exceeded
"""

async def do_request() -> None:
await self._rust_http_client.get(
url=self.server.endpoint,
# Small limit so we hit the limit
response_limit=1,
)

self.assertFailure(
self.till_deferred_has_result(do_request()),
RuntimeError,
)
self.assertEqual(self.server.calls, 1)

async def test_logging_context(self) -> None:
"""
Test to make sure the `LoggingContext` (logcontext) is handled correctly
Expand Down
Loading