Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions TESTING.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Testing Guide

This repository relies on `pytest` and the [`uv`](https://github.com/astral-sh/uv) toolchain for managing virtual environments. The sections below outline how to install dependencies, run the full suite, execute focused subsets, and add new tests that match the project's standards.

## 1. Environment Setup

1. Ensure you have Python 3.13+ available on your path.
2. Install project dependencies with uv:
```bash
uv sync
```
This command reads `pyproject.toml` and `uv.lock`, creating the `.venv/` virtual environment with all runtime and development packages (including `pytest`, `pytest-mock`, snapshot tooling, and linting utilities).

If you add or upgrade dependencies, regenerate the lockfile with:
```bash
uv lock
```

## 2. Running Tests

### Full Suite

Run every test, including snapshot comparisons and Textual UI checks:
```bash
uv run pytest
```

### Targeted Runs

- Specific file: `uv run pytest tests/managers/test_splits.py`
- Single test node: `uv run pytest tests/components/modules/test_people.py::TestPeopleModule::test_action_edit_person_updates_successfully`
- With coverage: `uv run pytest --cov=bagels --cov-report=term-missing`

Passing `-k "<expression>"` lets you filter by substring (e.g., `-k "split and not delete"`).

## 3. Writing Tests

When extending the suite:

- **Structure**: Group related tests inside classes and keep helper functions private to the file. Co-locate fixtures near the tests that use them.
- **Parametrization**: Prefer `@pytest.mark.parametrize` for exercising multiple scenarios in a single test body. This keeps assertions DRY and ensures edge cases stay visible.
- **Mocking**: Use the `mocker` fixture from `pytest-mock` for patching collaborators. Spy on existing functions when you want call verification and rely on `MagicMock` for simulating external dependencies.
- **Edge Cases**: Include failure paths (exceptions, empty responses, missing records) alongside the happy path.
- **Snapshot & UI Tests**: For Textual components, rely on the existing snapshot helpers or slim `SimpleNamespace` fakes to avoid heavy UI bootstrapping.

Follow the existing naming convention (`test_<unit>.py`) and keep files under `tests/`.

## 4. Continuous Integration Expectations

- New tests must pass with `uv run pytest` before a pull request is opened.
- Update `pyproject.toml` and `uv.lock` when introducing additional dependencies.
- Keep commit messages descriptive so reviewers understand why changes were made (e.g., “Add CRUD coverage tests for split manager”).

## 5. Troubleshooting

- **Virtual environment issues**: Remove `.venv/` and re-run `uv sync`.
- **Snapshot diffs**: Accept new baselines only after verifying the textual output manually.
- **Textual warnings**: These often stem from missing `CONFIG`. Import `bagels.config` and use `Config.get_default()` in tests to seed defaults without reading user config files.

By following this workflow, your tests will integrate cleanly with the project’s automation and provide clear coverage of success, error, and edge scenarios.
15 changes: 15 additions & 0 deletions feature/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
[project]
name = "bagels-feature-tests"
version = "0.1.0"
description = "Testing dependencies for the feature branch."
requires-python = ">=3.13"

dependencies = [
"pytest>=8.3.1",
"pytest-cov>=5.0.0",
"pytest-xdist>=3.6.1",
"syrupy>=4.6.1",
"time-machine>=2.16.0",
"textual-dev==1.6.1",
"pytest-mock>=3.14.0",
]
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ dev-dependencies = [
"time-machine==2.16.0",
"ruff>=0.9.1",
"pre-commit>=4.0.1",
"freezegun>=1.5.1",
"pytest-mock>=3.14.0",
Comment on lines +80 to +81
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The freezegun library is being added, but time-machine is already included as a dependency on line 77. Both libraries serve the same purpose of mocking time in tests. Including both adds unnecessary bloat to the project's dependencies and can cause confusion for developers. It's best practice to choose one and use it consistently.

]

[tool.hatch.metadata]
Expand Down
2 changes: 1 addition & 1 deletion src/bagels/managers/categories.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def get_categories_count():
"""Count all categories excluding deleted ones."""
session = Session()
try:
stmt = select(Category)
stmt = select(Category).filter(Category.deletedAt.is_(None))
return len(session.scalars(stmt).all())
finally:
session.close()
Expand Down
110 changes: 98 additions & 12 deletions src/bagels/managers/records.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from datetime import datetime, timedelta
from typing import Any

from sqlalchemy import func
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import joinedload, sessionmaker

from bagels.managers.splits import create_split, get_splits_by_record_id, update_split
Expand All @@ -14,8 +16,38 @@
Session = sessionmaker(bind=db_engine)


# Custom exceptions for records manager
class RecordNotFoundException(Exception):
"""Raised when a record with the requested id cannot be found.

Attributes:
record_id -- id of the record that was not found
"""


class InvalidRecordDataException(Exception):
"""Raised when provided record data is invalid or missing required fields."""


class DatabaseOperationException(Exception):
"""Raised when an underlying database operation fails.

The original exception is chained on raise to preserve context.
"""



# region Create
def create_record(record_data: dict):
"""Create a Record from a mapping of values.

Raises:
InvalidRecordDataException: If record_data is not a dict or missing required keys.
DatabaseOperationException: If an underlying DB error occurs.
"""
if not isinstance(record_data, dict):
raise InvalidRecordDataException("record_data must be a dict")

session = Session()
try:
record = Record(**record_data)
Expand All @@ -24,18 +56,34 @@ def create_record(record_data: dict):
session.refresh(record)
session.expunge(record)
return record
except SQLAlchemyError as e:
# Wrap DB errors in a project-specific exception for callers/tests
raise DatabaseOperationException("Failed to create record") from e
finally:
session.close()


def create_record_and_splits(record_data: dict, splits_data: list[dict]):
"""Create a record and associated splits.

Raises:
InvalidRecordDataException: If inputs are of incorrect types.
DatabaseOperationException: If any DB operation fails.
"""
if not isinstance(splits_data, list):
raise InvalidRecordDataException("splits_data must be a list of mappings")

session = Session()
try:
record = create_record(record_data)
for split in splits_data:
if not isinstance(split, dict):
raise InvalidRecordDataException("each split must be a dict")
split["recordId"] = record.id
create_split(split)
return record
except SQLAlchemyError as e:
raise DatabaseOperationException("Failed to create record and splits") from e
finally:
session.close()
Comment on lines 66 to 88
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

This function and update_record_and_splits are not transactionally atomic. create_record and create_split each manage their own database sessions and transactions. If create_split fails for a subsequent split, the record and any previous splits will have already been committed, leaving the database in an inconsistent state. The entire operation should succeed or fail together.

To fix this, the manager functions (create_record, create_split, etc.) should be refactored to accept an optional session argument. This function would then create a single session, pass it to the underlying calls, and commit the transaction only after all operations have succeeded.


Expand All @@ -56,6 +104,9 @@ def get_record_by_id(record_id: int, populate_splits: bool = False):
)

record = query.get(record_id)
if not record:
# Documented exception: raised when the requested record does not exist
raise RecordNotFoundException(f"No record found for id={record_id}")
return record
finally:
session.close()
Expand All @@ -65,7 +116,9 @@ def get_record_total_split_amount(record_id: int):
session = Session()
try:
splits = get_splits_by_record_id(record_id)
return sum(split.amount for split in splits)
if splits is None:
raise RecordNotFoundException(f"No splits found for record id={record_id}")
return sum((getattr(split, "amount", 0) or 0) for split in splits)
Comment on lines 118 to +121
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The check if splits is None: is incorrect because get_splits_by_record_id always returns a list, which may be empty but will never be None. This makes the check dead code. Furthermore, raising a RecordNotFoundException is semantically incorrect here, as a record can validly exist without any splits. The original implementation correctly handled this by summing an empty list to 0.

The getattr and or 0 are also overly defensive, as split.amount is a non-nullable field on the Split model.

        splits = get_splits_by_record_id(record_id)
        return sum(split.amount for split in splits)

finally:
session.close()

Expand Down Expand Up @@ -102,7 +155,12 @@ def get_records(
Category.name.in_(category_names)
)
if operator_amount not in [None, ""]:
operator, amount = get_operator_amount(operator_amount)
try:
operator, amount = get_operator_amount(operator_amount)
except Exception as e:
raise InvalidRecordDataException(
f"Invalid operator_amount format: {operator_amount}"
) from e
Comment on lines +158 to +163
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Catching the generic Exception is too broad and can mask unexpected errors. The get_operator_amount function is most likely to raise a ValueError if the amount part of the string cannot be converted to a float. It's better to catch a more specific exception to make error handling more precise and avoid accidentally catching unrelated exceptions.

            except ValueError as e:
                raise InvalidRecordDataException(
                    f"Invalid operator_amount format: {operator_amount}"
                ) from e

if operator and amount:
query = query.filter(Record.amount.op(operator)(amount))
if label not in [None, ""]:
Expand All @@ -114,6 +172,8 @@ def get_records(

records = query.all()
return records
except SQLAlchemyError as e:
raise DatabaseOperationException("Failed to query records") from e
finally:
session.close()

Expand Down Expand Up @@ -169,6 +229,8 @@ def get_spending(start_date, end_date) -> list[float]:
return _calculate_daily_spending(
records, start_date, end_date, cumulative=False
)
except SQLAlchemyError as e:
raise DatabaseOperationException("Failed to calculate spending") from e
finally:
session.close()

Expand All @@ -179,6 +241,8 @@ def get_spending_trend(start_date, end_date) -> list[float]:
try:
records = _get_spending_records(session, start_date, end_date)
return _calculate_daily_spending(records, start_date, end_date, cumulative=True)
except SQLAlchemyError as e:
raise DatabaseOperationException("Failed to calculate spending trend") from e
finally:
session.close()

Expand All @@ -187,7 +251,9 @@ def is_record_all_splits_paid(record_id: int):
session = Session()
try:
splits = get_splits_by_record_id(record_id)
return all(split.isPaid for split in splits)
if splits is None:
raise RecordNotFoundException(f"No splits found for record id={record_id}")
return all(getattr(split, "isPaid", False) for split in splits)
Comment on lines 253 to +256
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This function has the same logical flaw as get_record_total_split_amount. The check if splits is None: is incorrect because get_splits_by_record_id always returns a list. Raising RecordNotFoundException is also semantically wrong, as a record can exist without splits. For a record with no splits, all([]) correctly evaluates to True, which is the desired behavior. The original implementation was correct.

        splits = get_splits_by_record_id(record_id)
        return all(split.isPaid for split in splits)

finally:
session.close()

Expand Down Expand Up @@ -260,22 +326,37 @@ def adjust_balance(r):
results.append(total_balance)
current += timedelta(days=1)
return results
except SQLAlchemyError as e:
raise DatabaseOperationException("Failed to calculate daily balance") from e
finally:
session.close()


# region Update
def update_record(record_id: int, updated_data: dict):
"""Update the record with the provided mapping.

Raises:
RecordNotFoundException: If the record doesn't exist.
InvalidRecordDataException: If updated_data is not a dict.
DatabaseOperationException: For underlying DB errors.
"""
if not isinstance(updated_data, dict):
raise InvalidRecordDataException("updated_data must be a dict")

session = Session()
try:
record = session.query(Record).get(record_id)
if record:
for key, value in updated_data.items():
setattr(record, key, value)
session.commit()
session.refresh(record)
session.expunge(record)
if not record:
raise RecordNotFoundException(f"No record found for id={record_id}")
for key, value in updated_data.items():
setattr(record, key, value)
session.commit()
session.refresh(record)
session.expunge(record)
return record
except SQLAlchemyError as e:
raise DatabaseOperationException("Failed to update record") from e
finally:
session.close()

Expand All @@ -290,6 +371,8 @@ def update_record_and_splits(
for index, split in enumerate(record_splits):
update_split(split.id, splits_data[index])
return record
except SQLAlchemyError as e:
raise DatabaseOperationException("Failed to update record and splits") from e
finally:
session.close()

Expand All @@ -299,9 +382,12 @@ def delete_record(record_id: int):
session = Session()
try:
record = session.query(Record).get(record_id)
if record:
session.delete(record)
session.commit()
if not record:
raise RecordNotFoundException(f"No record found for id={record_id}")
session.delete(record)
session.commit()
return record
except SQLAlchemyError as e:
raise DatabaseOperationException("Failed to delete record") from e
finally:
session.close()
Loading