Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ with one *slight* but **important** difference:

- Refactored {func}`image_method<differt.rt.image_method>` and {func}`fermat_path_on_linear_objects<differt.rt.fermat_path_on_linear_objects>` to use {func}`jnp.vectorize<jax.numpy.vectorize>` instead of a custom but complex chain of calls to {func}`jax.vmap`, reducing the code complexity while not affecting performance (by <gh-user:jeertmans>, in <gh-pr:298>).
- Ignored lints PLR091* globally, instead of per-case (by <gh-user:jeertmans>, in <gh-pr:298>).
- Improved code coverage for ray-triangle intersection tests (by <gh-user:jeertmans>, in <gh-pr:301>).
- Refactored benchmarks to reduce the number of benchmarks and avoid depending on JIT compilation (by <gh-user:jeertmans>, in <gh-pr:301>).

<!-- start changelog -->

Expand Down
44 changes: 44 additions & 0 deletions differt/src/differt/rt/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,7 @@ def triangles_visible_from_vertices(
visible from a given transmitter, coloring them in dark gray.

.. plotly::
:context: reset

>>> import equinox as eqx
>>> from differt.rt import (
Expand Down Expand Up @@ -576,6 +577,49 @@ def triangles_visible_from_vertices(
... )
>>> fig = scene.plot(backend="plotly")
>>> fig # doctest: +SKIP

In this example, a receiver is placed at the opposite side of the street canyon,
and its visible triangles are colored in blue. Triangles that are visible from both
the transmitter and the receiver are colored in yellow.

.. plotly::
:context:

>>> scene = eqx.tree_at(
... lambda s: s.receivers, scene, jnp.array([33, 0, 1.5])
... )
>>> visible_triangles = triangles_visible_from_vertices(
... jnp.stack((scene.transmitters, scene.receivers)),
... scene.mesh.triangle_vertices,
... )
>>> triangles_visible_from_tx = visible_triangles[0, :]
>>> triangles_visible_from_rx = visible_triangles[1, :]
>>> visible_by_tx_color = jnp.array([0.2, 0.2, 0.2])
>>> visible_by_rx_color = jnp.array([0.2, 0.8, 0.2])
>>> visible_by_both_color = jnp.array([0.8, 0.8, 0.2])
>>> scene = eqx.tree_at(
... lambda s: s.mesh.face_colors,
... scene,
... scene.mesh.face_colors.at[triangles_visible_from_tx, :].set(
... visible_by_tx_color
... ),
... )
>>> scene = eqx.tree_at(
... lambda s: s.mesh.face_colors,
... scene,
... scene.mesh.face_colors.at[triangles_visible_from_rx, :].set(
... visible_by_rx_color
... ),
... )
>>> scene = eqx.tree_at(
... lambda s: s.mesh.face_colors,
... scene,
... scene.mesh.face_colors.at[
... triangles_visible_from_tx & triangles_visible_from_rx, :
... ].set(visible_by_both_color),
... )
>>> fig = scene.plot(backend="plotly")
>>> fig # doctest: +SKIP
"""
vertices = jnp.asarray(vertices)
triangle_vertices = jnp.asarray(triangle_vertices)
Expand Down
95 changes: 32 additions & 63 deletions differt/tests/benchmarks/test_rt.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
from typing import Literal

import jax
import jax.numpy as jnp
import pytest
from jaxtyping import Array
from pytest_codspeed import BenchmarkFixture

from differt.geometry import Paths
from differt.rt import (
fermat_path_on_planar_mirrors,
generate_all_path_candidates,
image_method,
triangles_visible_from_vertices,
)
Expand All @@ -23,14 +20,18 @@ def test_image_method(
benchmark: BenchmarkFixture,
) -> None:
setup = large_random_planar_mirrors_setup
_ = benchmark(
lambda: image_method(

def bench_fun() -> None:
image_method(
setup.from_vertices,
setup.to_vertices,
setup.mirror_vertices,
setup.mirror_normals,
).block_until_ready()
)

bench_fun()

_ = benchmark(bench_fun)


@pytest.mark.benchmark(group="fermat_method")
Expand All @@ -39,14 +40,18 @@ def test_fermat(
benchmark: BenchmarkFixture,
) -> None:
setup = large_random_planar_mirrors_setup
_ = benchmark(
lambda: fermat_path_on_planar_mirrors(

def bench_fun() -> None:
fermat_path_on_planar_mirrors(
setup.from_vertices,
setup.to_vertices,
setup.mirror_vertices,
setup.mirror_normals,
).block_until_ready()
)

bench_fun()

_ = benchmark(bench_fun)


@pytest.mark.benchmark(group="triangles_visible_from_vertices")
Expand All @@ -55,74 +60,38 @@ def test_transmitter_visibility_in_simple_street_canyon_scene(
benchmark: BenchmarkFixture,
) -> None:
scene = simple_street_canyon_scene
_ = benchmark(
lambda: triangles_visible_from_vertices(

def bench_fun() -> None:
triangles_visible_from_vertices(
scene.transmitters,
scene.mesh.triangle_vertices,
).block_until_ready()
)

bench_fun()

_ = benchmark(bench_fun)


@pytest.mark.benchmark(group="compute_paths")
@pytest.mark.parametrize("order", [0, 1, 2])
@pytest.mark.parametrize("chunk_size", [None, 20_000])
@pytest.mark.parametrize("assume_quads", [False, True])
@pytest.mark.parametrize("method", ["exhaustive", "hybrid"])
def test_compute_paths_in_simple_street_canyon_scene(
order: int,
chunk_size: int | None,
assume_quads: bool,
method: Literal["exhaustive", "hybrid"],
simple_street_canyon_scene: TriangleScene,
benchmark: BenchmarkFixture,
) -> None:
scene = simple_street_canyon_scene.set_assume_quads(assume_quads)
if chunk_size:
scene = simple_street_canyon_scene.set_assume_quads()

def bench_fun() -> None:
for path in scene.compute_paths(
order,
def bench_fun() -> None:
num_valid_paths = jnp.array(0, dtype=jnp.int32)
for order in range(3):
for paths in scene.compute_paths(
order=order,
method=method,
chunk_size=chunk_size,
chunk_size=10_000,
):
path.vertices.block_until_ready()

else:
num_valid_paths += paths.num_valid_paths
num_valid_paths.block_until_ready()

def bench_fun() -> None:
scene.compute_paths(
order,
method=method,
chunk_size=None,
).vertices.block_until_ready()

_ = benchmark(bench_fun)


@pytest.mark.benchmark(group="jitted_compute_paths")
@pytest.mark.parametrize("provide_pc", [False, True])
def test_jitted_compute_paths(
provide_pc: bool,
simple_street_canyon_scene: TriangleScene,
benchmark: BenchmarkFixture,
) -> None:
scene = simple_street_canyon_scene

path_candidates = generate_all_path_candidates(scene.mesh.num_triangles, 2)

if provide_pc:

@jax.jit
def fun(path_candidates: Array) -> Paths:
return scene.compute_paths(path_candidates=path_candidates)

else:

@jax.jit
def fun(path_candidates: Array) -> Paths:
return scene.compute_paths(order=path_candidates.shape[1])

def bench_fun() -> None:
fun(path_candidates).vertices.block_until_ready()
bench_fun()

_ = benchmark(bench_fun)
16 changes: 6 additions & 10 deletions differt/tests/benchmarks/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,15 +105,6 @@ def loss(model: LOSModel, scene: TriangleScene) -> Float[Array, " "]:
return jnp.mean((pred - paths.mask.astype(pred.dtype)) ** 2)


def test_dataloader(
simple_street_canyon_scene: TriangleScene,
benchmark: BenchmarkFixture,
key: PRNGKeyArray,
) -> None:
dataloader = train_dataloader(simple_street_canyon_scene, key=key)
_ = benchmark(lambda: next(dataloader))


def test_train_step(
simple_street_canyon_scene: TriangleScene,
benchmark: BenchmarkFixture,
Expand Down Expand Up @@ -147,4 +138,9 @@ def make_step(
initial=(model, opt_state, jnp.array(0.0)),
)

_ = benchmark(lambda: next(iterator)[-1].block_until_ready()) # noqa: PLW0108
def bench_fun() -> None:
next(iterator)[-1].block_until_ready()

bench_fun()

_ = benchmark(bench_fun)
Loading