-
Notifications
You must be signed in to change notification settings - Fork 850
fix: add asyncio.Lock to prevent concurrent file write conflicts #1520
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
| @@ -1,5 +1,6 @@ | ||||||||
| from __future__ import annotations | ||||||||
|
|
||||||||
| import asyncio | ||||||||
| import json | ||||||||
| from collections.abc import Sequence | ||||||||
| from pathlib import Path | ||||||||
|
|
@@ -21,6 +22,7 @@ def __init__(self, file_backend: Path): | |||||||
| self._next_checkpoint_id: int = 0 | ||||||||
| """The ID of the next checkpoint, starting from 0, incremented after each checkpoint.""" | ||||||||
| self._system_prompt: str | None = None | ||||||||
| self._write_lock = asyncio.Lock() | ||||||||
|
|
||||||||
| async def restore(self) -> bool: | ||||||||
| logger.debug("Restoring context from file: {file_backend}", file_backend=self._file_backend) | ||||||||
|
|
@@ -83,20 +85,21 @@ async def write_system_prompt(self, prompt: str) -> None: | |||||||
| """ | ||||||||
| prompt_line = json.dumps({"role": "_system_prompt", "content": prompt}) + "\n" | ||||||||
|
|
||||||||
| if not self._file_backend.exists() or self._file_backend.stat().st_size == 0: | ||||||||
| async with aiofiles.open(self._file_backend, "w", encoding="utf-8") as f: | ||||||||
| await f.write(prompt_line) | ||||||||
| else: | ||||||||
| tmp_path = self._file_backend.with_suffix(".tmp") | ||||||||
| async with aiofiles.open(tmp_path, "w", encoding="utf-8") as tmp_f: | ||||||||
| await tmp_f.write(prompt_line) | ||||||||
| async with aiofiles.open(self._file_backend, encoding="utf-8") as src_f: | ||||||||
| while True: | ||||||||
| chunk = await src_f.read(64 * 1024) | ||||||||
| if not chunk: | ||||||||
| break | ||||||||
| await tmp_f.write(chunk) | ||||||||
| await aiofiles.os.replace(tmp_path, self._file_backend) | ||||||||
| async with self._write_lock: | ||||||||
| if not self._file_backend.exists() or self._file_backend.stat().st_size == 0: | ||||||||
| async with aiofiles.open(self._file_backend, "w", encoding="utf-8") as f: | ||||||||
| await f.write(prompt_line) | ||||||||
| else: | ||||||||
| tmp_path = self._file_backend.with_suffix(".tmp") | ||||||||
| async with aiofiles.open(tmp_path, "w", encoding="utf-8") as tmp_f: | ||||||||
| await tmp_f.write(prompt_line) | ||||||||
| async with aiofiles.open(self._file_backend, encoding="utf-8") as src_f: | ||||||||
| while True: | ||||||||
| chunk = await src_f.read(64 * 1024) | ||||||||
| if not chunk: | ||||||||
| break | ||||||||
| await tmp_f.write(chunk) | ||||||||
| await aiofiles.os.replace(tmp_path, self._file_backend) | ||||||||
|
|
||||||||
| self._system_prompt = prompt | ||||||||
|
|
||||||||
|
|
@@ -105,8 +108,9 @@ async def checkpoint(self, add_user_message: bool): | |||||||
| self._next_checkpoint_id += 1 | ||||||||
| logger.debug("Checkpointing, ID: {id}", id=checkpoint_id) | ||||||||
|
|
||||||||
| async with aiofiles.open(self._file_backend, "a", encoding="utf-8") as f: | ||||||||
| await f.write(json.dumps({"role": "_checkpoint", "id": checkpoint_id}) + "\n") | ||||||||
| async with self._write_lock: | ||||||||
| async with aiofiles.open(self._file_backend, "a", encoding="utf-8") as f: | ||||||||
| await f.write(json.dumps({"role": "_checkpoint", "id": checkpoint_id}) + "\n") | ||||||||
| if add_user_message: | ||||||||
| await self.append_message( | ||||||||
| Message(role="user", content=[system(f"CHECKPOINT {checkpoint_id}")]) | ||||||||
|
|
@@ -131,43 +135,44 @@ async def revert_to(self, checkpoint_id: int): | |||||||
| logger.error("Checkpoint {checkpoint_id} does not exist", checkpoint_id=checkpoint_id) | ||||||||
| raise ValueError(f"Checkpoint {checkpoint_id} does not exist") | ||||||||
|
|
||||||||
| # rotate the context file | ||||||||
| rotated_file_path = await next_available_rotation(self._file_backend) | ||||||||
| if rotated_file_path is None: | ||||||||
| logger.error("No available rotation path found") | ||||||||
| raise RuntimeError("No available rotation path found") | ||||||||
| await aiofiles.os.replace(self._file_backend, rotated_file_path) | ||||||||
| logger.debug( | ||||||||
| "Rotated context file: {rotated_file_path}", rotated_file_path=rotated_file_path | ||||||||
| ) | ||||||||
|
|
||||||||
| # restore the context until the specified checkpoint | ||||||||
| self._history.clear() | ||||||||
| self._token_count = 0 | ||||||||
| self._next_checkpoint_id = 0 | ||||||||
| self._system_prompt = None | ||||||||
| async with ( | ||||||||
| aiofiles.open(rotated_file_path, encoding="utf-8") as old_file, | ||||||||
| aiofiles.open(self._file_backend, "w", encoding="utf-8") as new_file, | ||||||||
| ): | ||||||||
| async for line in old_file: | ||||||||
| if not line.strip(): | ||||||||
| continue | ||||||||
|
|
||||||||
| line_json = json.loads(line) | ||||||||
| if line_json["role"] == "_checkpoint" and line_json["id"] == checkpoint_id: | ||||||||
| break | ||||||||
| async with self._write_lock: | ||||||||
| # rotate the context file | ||||||||
| rotated_file_path = await next_available_rotation(self._file_backend) | ||||||||
| if rotated_file_path is None: | ||||||||
| logger.error("No available rotation path found") | ||||||||
| raise RuntimeError("No available rotation path found") | ||||||||
| await aiofiles.os.replace(self._file_backend, rotated_file_path) | ||||||||
| logger.debug( | ||||||||
| "Rotated context file: {rotated_file_path}", rotated_file_path=rotated_file_path | ||||||||
| ) | ||||||||
|
|
||||||||
| await new_file.write(line) | ||||||||
| if line_json["role"] == "_system_prompt": | ||||||||
| self._system_prompt = line_json["content"] | ||||||||
| elif line_json["role"] == "_usage": | ||||||||
| self._token_count = line_json["token_count"] | ||||||||
| elif line_json["role"] == "_checkpoint": | ||||||||
| self._next_checkpoint_id = line_json["id"] + 1 | ||||||||
| else: | ||||||||
| message = Message.model_validate(line_json) | ||||||||
| self._history.append(message) | ||||||||
| # restore the context until the specified checkpoint | ||||||||
| self._history.clear() | ||||||||
| self._token_count = 0 | ||||||||
| self._next_checkpoint_id = 0 | ||||||||
| self._system_prompt = None | ||||||||
| async with ( | ||||||||
| aiofiles.open(rotated_file_path, encoding="utf-8") as old_file, | ||||||||
| aiofiles.open(self._file_backend, "w", encoding="utf-8") as new_file, | ||||||||
| ): | ||||||||
| async for line in old_file: | ||||||||
| if not line.strip(): | ||||||||
| continue | ||||||||
|
|
||||||||
| line_json = json.loads(line) | ||||||||
| if line_json["role"] == "_checkpoint" and line_json["id"] == checkpoint_id: | ||||||||
| break | ||||||||
|
|
||||||||
| await new_file.write(line) | ||||||||
| if line_json["role"] == "_system_prompt": | ||||||||
| self._system_prompt = line_json["content"] | ||||||||
| elif line_json["role"] == "_usage": | ||||||||
| self._token_count = line_json["token_count"] | ||||||||
| elif line_json["role"] == "_checkpoint": | ||||||||
| self._next_checkpoint_id = line_json["id"] + 1 | ||||||||
| else: | ||||||||
| message = Message.model_validate(line_json) | ||||||||
| self._history.append(message) | ||||||||
|
|
||||||||
| async def clear(self): | ||||||||
| """ | ||||||||
|
|
@@ -182,16 +187,17 @@ async def clear(self): | |||||||
|
|
||||||||
| logger.debug("Clearing context") | ||||||||
|
|
||||||||
| # rotate the context file | ||||||||
| rotated_file_path = await next_available_rotation(self._file_backend) | ||||||||
| if rotated_file_path is None: | ||||||||
| logger.error("No available rotation path found") | ||||||||
| raise RuntimeError("No available rotation path found") | ||||||||
| await aiofiles.os.replace(self._file_backend, rotated_file_path) | ||||||||
| self._file_backend.touch() | ||||||||
| logger.debug( | ||||||||
| "Rotated context file: {rotated_file_path}", rotated_file_path=rotated_file_path | ||||||||
| ) | ||||||||
| async with self._write_lock: | ||||||||
| # rotate the context file | ||||||||
| rotated_file_path = await next_available_rotation(self._file_backend) | ||||||||
| if rotated_file_path is None: | ||||||||
| logger.error("No available rotation path found") | ||||||||
| raise RuntimeError("No available rotation path found") | ||||||||
| await aiofiles.os.replace(self._file_backend, rotated_file_path) | ||||||||
| self._file_backend.touch() | ||||||||
| logger.debug( | ||||||||
| "Rotated context file: {rotated_file_path}", rotated_file_path=rotated_file_path | ||||||||
| ) | ||||||||
|
|
||||||||
| self._history.clear() | ||||||||
| self._token_count = 0 | ||||||||
|
Comment on lines
202
to
203
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 🟡 In (Refers to lines 202-205) Prompt for agents. Was this helpful? React with 👍 or 👎 to provide feedback. |
||||||||
|
|
@@ -203,13 +209,15 @@ async def append_message(self, message: Message | Sequence[Message]): | |||||||
| messages = [message] if isinstance(message, Message) else message | ||||||||
| self._history.extend(messages) | ||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 🔴 In-memory state updated before acquiring `_write_lock`
Concrete scenario via asyncio.shield
The same pattern exists in
Suggested change
Was this helpful? React with 👍 or 👎 to provide feedback. |
||||||||
|
|
||||||||
| async with aiofiles.open(self._file_backend, "a", encoding="utf-8") as f: | ||||||||
| for message in messages: | ||||||||
| await f.write(message.model_dump_json(exclude_none=True) + "\n") | ||||||||
| async with self._write_lock: | ||||||||
| async with aiofiles.open(self._file_backend, "a", encoding="utf-8") as f: | ||||||||
| for message in messages: | ||||||||
| await f.write(message.model_dump_json(exclude_none=True) + "\n") | ||||||||
|
|
||||||||
| async def update_token_count(self, token_count: int): | ||||||||
| logger.debug("Updating token count in context: {token_count}", token_count=token_count) | ||||||||
| self._token_count = token_count | ||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 🔴 In-memory state updated before acquiring `_write_lock`
Suggested change
Was this helpful? React with 👍 or 👎 to provide feedback. |
||||||||
|
|
||||||||
| async with aiofiles.open(self._file_backend, "a", encoding="utf-8") as f: | ||||||||
| await f.write(json.dumps({"role": "_usage", "token_count": token_count}) + "\n") | ||||||||
| async with self._write_lock: | ||||||||
| async with aiofiles.open(self._file_backend, "a", encoding="utf-8") as f: | ||||||||
| await f.write(json.dumps({"role": "_usage", "token_count": token_count}) + "\n") | ||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Acquire `_write_lock` before advancing in-memory state
Because `run_soul()` cancels `soul_task` on user interrupt, any `checkpoint()`/`append_message()`/`update_token_count()` call that is queued on the new `_write_lock` can now be cancelled after `_next_checkpoint_id`, `_history`, or `_token_count` has already been mutated. That leaves the live `Context` ahead of `context.jsonl`: later turns in the same CLI can see phantom checkpoints/messages/token counts, while a resumed or exported session after restart does not. Moving the in-memory updates inside the locked section avoids this divergence under the exact concurrent-write condition this patch introduces.
Useful? React with 👍 / 👎.