Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Fixed

- Pipeline no longer hangs indefinitely when inference server stalls mid-response ([#433](https://github.com/allenai/olmocr/issues/433))
- Reaching the per-page max backoff limit now skips the affected page instead of terminating the entire job

### Added

- `--request_timeout_s` CLI flag to control per-request timeout (default 120s)

## [v0.4.25](https://github.com/allenai/olmocr/releases/tag/v0.4.25) - 2026-01-25

## [v0.4.24](https://github.com/allenai/olmocr/releases/tag/v0.4.24) - 2026-01-23
Expand Down
132 changes: 68 additions & 64 deletions olmocr/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,10 @@ async def try_single_page(
r"---\nprimary_language: (?:[a-z]{2}|null)\nis_rotation_valid: (?:True|False|true|false)\nrotation_correction: (?:0|90|180|270)\nis_table: (?:True|False|true|false)\nis_diagram: (?:True|False|true|false)\n(?:---|---\n[\s\S]+)"
)

timeout_s = getattr(args, "request_timeout_s", None)

async with max_concurrent_requests_limit:
status_code, response_body = await apost(COMPLETION_URL, json_data=query, api_key=api_key)
status_code, response_body = await apost(COMPLETION_URL, json_data=query, api_key=api_key, timeout_s=timeout_s)

if status_code != 200:
logger.warning(
Expand Down Expand Up @@ -273,8 +275,8 @@ async def try_single_page_with_backoff(
)
await asyncio.sleep(sleep_delay)

logger.error(f"Max backoff attempts reached for {pdf_orig_path}-{page_num}, terminating job")
sys.exit(1)
logger.error(f"Max backoff attempts reached for {pdf_orig_path}-{page_num}, giving up on page")
return None


async def process_page(args, worker_id: int, pdf_orig_path: str, pdf_local_path: str, page_num: int) -> PageResult:
Expand Down Expand Up @@ -378,7 +380,7 @@ async def process_page(args, worker_id: int, pdf_orig_path: str, pdf_local_path:
# It feels strange perhaps, but httpx and aiohttp are very complex beasts
# Ex. the sessionpool in httpcore has 4 different locks in it, and I've noticed
# that at the scale of 100M+ requests, that they deadlock in different strange ways
async def apost(url, json_data, api_key=None):
async def apost(url, json_data, api_key=None, timeout_s=None):
parsed_url = urlparse(url)
host = parsed_url.hostname
# Default to 443 for HTTPS, 80 for HTTP
Expand All @@ -392,76 +394,77 @@ async def apost(url, json_data, api_key=None):

writer = None
try:
if use_ssl:
ssl_context = ssl.create_default_context()
reader, writer = await asyncio.open_connection(host, port, ssl=ssl_context)
else:
reader, writer = await asyncio.open_connection(host, port)
async with asyncio.timeout(timeout_s):
if use_ssl:
ssl_context = ssl.create_default_context()
reader, writer = await asyncio.open_connection(host, port, ssl=ssl_context)
else:
reader, writer = await asyncio.open_connection(host, port)

json_payload = json.dumps(json_data)
json_payload = json.dumps(json_data)

headers = [
f"POST {path} HTTP/1.1",
f"Host: {host}",
f"Content-Type: application/json",
f"Content-Length: {len(json_payload)}",
]
headers = [
f"POST {path} HTTP/1.1",
f"Host: {host}",
f"Content-Type: application/json",
f"Content-Length: {len(json_payload)}",
]

if api_key:
headers.append(f"Authorization: Bearer {api_key}")
if api_key:
headers.append(f"Authorization: Bearer {api_key}")

headers.append("Connection: close")
headers.append("Connection: close")

request = "\r\n".join(headers) + "\r\n\r\n" + json_payload
writer.write(request.encode())
await writer.drain()
request = "\r\n".join(headers) + "\r\n\r\n" + json_payload
writer.write(request.encode())
await writer.drain()

status_line = await reader.readline()
if not status_line:
raise ConnectionError("No response from server")
status_parts = status_line.decode().strip().split(" ", 2)
if len(status_parts) < 2:
raise ValueError(f"Malformed status line: {status_line.decode().strip()}")
status_code = int(status_parts[1])
status_line = await reader.readline()
if not status_line:
raise ConnectionError("No response from server")
status_parts = status_line.decode().strip().split(" ", 2)
if len(status_parts) < 2:
raise ValueError(f"Malformed status line: {status_line.decode().strip()}")
status_code = int(status_parts[1])

# Read headers
headers = {}
while True:
line = await reader.readline()
if line in (b"\r\n", b"\n", b""):
break
key, _, value = line.decode().partition(":")
headers[key.strip().lower()] = value.strip()

# Read response body
if "content-length" in headers:
body_length = int(headers["content-length"])
response_body = await reader.readexactly(body_length)
elif headers.get("transfer-encoding", "") == "chunked":
chunks = []
# Read headers
headers = {}
while True:
# Read chunk size line
size_line = await reader.readline()
chunk_size = int(size_line.strip(), 16) # Hex format

if chunk_size == 0:
await reader.readline() # Read final CRLF
line = await reader.readline()
if line in (b"\r\n", b"\n", b""):
break
key, _, value = line.decode().partition(":")
headers[key.strip().lower()] = value.strip()

# Read response body
if "content-length" in headers:
body_length = int(headers["content-length"])
response_body = await reader.readexactly(body_length)
elif headers.get("transfer-encoding", "") == "chunked":
chunks = []
while True:
# Read chunk size line
size_line = await reader.readline()
chunk_size = int(size_line.strip(), 16) # Hex format

if chunk_size == 0:
await reader.readline() # Read final CRLF
break

chunk_data = await reader.readexactly(chunk_size)
chunks.append(chunk_data)

# Read trailing CRLF after chunk data
await reader.readline()

response_body = b"".join(chunks)
elif headers.get("connection", "") == "close":
# Read until connection closes
response_body = await reader.read()
else:
raise ConnectionError("Cannot determine response body length")

chunk_data = await reader.readexactly(chunk_size)
chunks.append(chunk_data)

# Read trailing CRLF after chunk data
await reader.readline()

response_body = b"".join(chunks)
elif headers.get("connection", "") == "close":
# Read until connection closes
response_body = await reader.read()
else:
raise ConnectionError("Cannot determine response body length")

return status_code, response_body
return status_code, response_body
except Exception as e:
# Pass through errors
raise e
Expand Down Expand Up @@ -1228,6 +1231,7 @@ async def main():
parser.add_argument("--target_longest_image_dim", type=int, help="Dimension on longest side to use for rendering the pdf pages", default=1288)
parser.add_argument("--target_anchor_text_len", type=int, help="Maximum amount of anchor text to use (characters), not used for new models", default=-1)
parser.add_argument("--guided_decoding", action="store_true", help="Enable guided decoding for model YAML type outputs")
parser.add_argument("--request_timeout_s", type=float, default=120, help="Timeout (seconds) for a single HTTP request to the inference server, per attempt")
parser.add_argument(
"--disk_logging",
type=str,
Expand Down
43 changes: 40 additions & 3 deletions tests/test_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import base64
import json
import os
Expand All @@ -10,9 +11,11 @@

from olmocr.pipeline import (
PageResult,
apost,
build_page_query,
get_markdown_path,
process_page,
try_single_page_with_backoff,
)


Expand Down Expand Up @@ -215,7 +218,7 @@ async def test_process_page_with_rotation_correction(self):
# Counter to track number of API calls
call_count = 0

async def mock_apost(url, json_data, api_key=None):
async def mock_apost(url, json_data, api_key=None, timeout_s=None):
nonlocal call_count
call_count += 1

Expand Down Expand Up @@ -317,7 +320,7 @@ async def test_process_page_with_cumulative_rotation(self):
# Counter to track number of API calls
call_count = 0

async def mock_apost(url, json_data, api_key=None):
async def mock_apost(url, json_data, api_key=None, timeout_s=None):
nonlocal call_count
call_count += 1

Expand Down Expand Up @@ -426,7 +429,7 @@ async def test_process_page_rotation_wraps_around(self):
# Counter to track number of API calls
call_count = 0

async def mock_apost(url, json_data, api_key=None):
async def mock_apost(url, json_data, api_key=None, timeout_s=None):
nonlocal call_count
call_count += 1

Expand Down Expand Up @@ -568,3 +571,37 @@ def test_path_traversal_with_dotdot_stays_in_workspace(self):
assert resolved_path.startswith(resolved_workspace), (
f"BUG: Path traversal attack! Markdown path '{resolved_path}' escapes " f"workspace '{resolved_workspace}'. Paths with ../ should be sanitized."
)


class TestApostTimeout:
    @pytest.mark.asyncio
    async def test_apost_times_out_on_stalled_server(self):
        """apost must raise TimeoutError when the server accepts the request
        but never delivers the promised response body (stalled inference server).
        """
        # The handler must be releasable by the test: a bare `await asyncio.sleep(60)`
        # makes `server.wait_closed()` block for the full minute on Python 3.12.1+,
        # where wait_closed() also waits for all active connection handlers to finish.
        release = asyncio.Event()

        async def stalled_server(reader, writer):
            # Consume request headers up to the blank line (or EOF).
            while (await reader.readline()) not in (b"\r\n", b""):
                pass
            # Promise a huge body, then stall: the client's body read hangs.
            writer.write(b"HTTP/1.1 200 OK\r\nContent-Length: 999999\r\n\r\n")
            await writer.drain()
            try:
                await release.wait()
            finally:
                writer.close()

        server = await asyncio.start_server(stalled_server, "127.0.0.1", 0)
        port = server.sockets[0].getsockname()[1]
        try:
            with pytest.raises(TimeoutError):
                await apost(f"http://127.0.0.1:{port}/v1", json_data={}, timeout_s=0.5)
        finally:
            release.set()  # unblock the handler so wait_closed() returns promptly
            server.close()
            await server.wait_closed()


class TestBackoffNoExit:
    @pytest.mark.asyncio
    async def test_max_backoff_returns_none_not_exit(self):
        """Exhausting every backoff attempt must yield None for the page,
        not terminate the whole process.
        """
        failing_page = AsyncMock(side_effect=ConnectionError("down"))

        # Stub out the real sleep so burning through the backoff schedule is instant.
        with patch("olmocr.pipeline.try_single_page", failing_page), patch("asyncio.sleep", new_callable=AsyncMock):
            outcome = await try_single_page_with_backoff(MockArgs(), "t.pdf", "t.pdf", 1, 0, 0)

        assert outcome is None
        assert failing_page.call_count == 10