Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ with one *slight* but **important** difference:
- Simplified {func}`deepmimo.export<differt.plugins.deepmimo.export>` to reduce redundant code (by <gh-user:jeertmans>, in <gh-pr:356>).
- Changed type checker from `pyright` to `ty` (by <gh-user:jeertmans>, in <gh-pr:292>).
- Slightly improved code coverage (by <gh-user:jeertmans>, in <gh-pr:362>).
- Bumped minimum required JAX version to [`0.8.1`](https://docs.jax.dev/en/latest/changelog.html#jax-0-8-1-november-18-2025) to use new {func}`jax.jit` syntax as the use of {func}`functools.partial` now raises errors from `ty`, see <ext-gh-issue:jax-ml/jax#34697> (by <gh-user:jeertmans>, in <gh-pr:370>).

### Fixed

Expand Down
2 changes: 1 addition & 1 deletion differt/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ dependencies = [
"equinox>=0.13.1",
"filelock>=3.15.4",
"fpt-jax==0.1.0",
"jax>=0.7.2",
"jax>=0.8.1",
"jaxtyping>=0.3.2",
"numpy>=1.26.1",
"requests>=2.32.0",
Expand Down
2 changes: 1 addition & 1 deletion differt/src/differt/em/_material.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def from_itu_properties(

aliases = ("itu_" + name.lower().replace(" ", "_"),)

@partial(jax.jit, inline=True, static_argnums=(1, 2, 3, 4))
@jax.jit(inline=True, static_argnums=(1, 2, 3, 4))
def callback(
f: Float[ArrayLike, " *batch"],
a: Float[ArrayLike, ""],
Expand Down
9 changes: 4 additions & 5 deletions differt/src/differt/em/_utd.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# ruff: noqa: N802, N806
# type: ignore # noqa: PGH003
from functools import partial
from typing import Any, Literal, overload

import equinox as eqx
Expand All @@ -12,18 +11,18 @@
# TODO: use ArrayLike instead of Array as inputs


@partial(jax.jit, inline=True)
@jax.jit(inline=True)
def _cot(x: Float[Array, " *batch"]) -> Float[Array, " *batch"]:
return 1 / jnp.tan(x)


@partial(jax.jit, inline=True)
@jax.jit(inline=True)
def _sign(x: Float[Array, " *batch"]) -> Float[Array, " *batch"]:
ones = jnp.ones_like(x)
return jnp.where(x >= 0, ones, -ones)


@partial(jax.jit, inline=True, static_argnames=("mode"))
@jax.jit(inline=True, static_argnames=("mode"))
def _N(
beta: Float[Array, " *#batch"], n: Float[Array, " *#batch"], mode: Literal["+", "-"]
) -> Float[Array, " *batch"]:
Expand All @@ -32,7 +31,7 @@ def _N(
return jnp.round((beta + jnp.pi) / (2 * n * jnp.pi))


@partial(jax.jit, inline=True, static_argnames=("mode"))
@jax.jit(inline=True, static_argnames=("mode"))
def _a(
beta: Float[Array, " *#batch"], n: Float[Array, " *#batch"], mode: Literal["+", "-"]
) -> Float[Array, " *batch"]:
Expand Down
3 changes: 1 addition & 2 deletions differt/src/differt/em/_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from functools import partial
from typing import Any

import jax
Expand Down Expand Up @@ -343,7 +342,7 @@ def transition_matrices(
raise NotImplementedError


@partial(jax.jit, static_argnames=("dB",))
@jax.jit(static_argnames=("dB",))
def fspl(
d: Float[ArrayLike, " *#batch"],
f: Float[ArrayLike, " *#batch"],
Expand Down
9 changes: 4 additions & 5 deletions differt/src/differt/geometry/_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from functools import partial
from typing import Literal, no_type_check, overload

import equinox as eqx
Expand Down Expand Up @@ -33,7 +32,7 @@ def normalize(
): ...


@partial(jax.jit, static_argnames=("keepdims",), inline=True)
@jax.jit(static_argnames=("keepdims",), inline=True)
def normalize(
vectors: Float[ArrayLike, "*batch 3"],
keepdims: bool = False,
Expand Down Expand Up @@ -79,7 +78,7 @@ def normalize(
)


@partial(jax.jit, inline=True)
@jax.jit(inline=True)
def perpendicular_vectors(u: Float[ArrayLike, "*batch 3"]) -> Float[Array, "*batch 3"]:
"""
Generate a vector perpendicular to the input vectors.
Expand Down Expand Up @@ -115,7 +114,7 @@ def perpendicular_vectors(u: Float[ArrayLike, "*batch 3"]) -> Float[Array, "*bat
return normalize(w)[0]


@partial(jax.jit, inline=True)
@jax.jit(inline=True)
def orthogonal_basis(
u: Float[ArrayLike, "*batch 3"],
) -> tuple[Float[Array, "*batch 3"], Float[Array, "*batch 3"]]:
Expand Down Expand Up @@ -153,7 +152,7 @@ def orthogonal_basis(
return v, w


@partial(jax.jit, inline=True)
@jax.jit(inline=True)
def path_lengths(
paths: Float[ArrayLike, "*batch path_length 3"],
) -> Float[Array, " *batch"]:
Expand Down
17 changes: 9 additions & 8 deletions differt/src/differt/rt/_fermat.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,16 +168,17 @@ def fermat_path_on_linear_objects(
return jnp.empty((*batch, 0, 3), dtype=dtype)
return jnp.broadcast_to(object_origins, (*batch, num_objects, 3)).astype(dtype)

return fpt_jax.trace_rays(
from_vertices,
to_vertices,
# Needed until https://github.com/jax-ml/jax/issues/34697 is resolved
return fpt_jax.trace_rays( # type: ignore[invalid-return-type]
from_vertices, # type: ignore[invalid-argument-type]
to_vertices, # type: ignore[too-many-positional-arguments]
object_origins,
object_vectors,
num_iters=steps,
unroll=unroll,
num_iters_linesearch=linesearch_steps,
unroll_linesearch=unroll_linesearch,
implicit_diff=implicit_diff,
num_iters=steps, # type: ignore[unknown-argument]
unroll=unroll, # type: ignore[unknown-argument]
num_iters_linesearch=linesearch_steps, # type: ignore[unknown-argument]
unroll_linesearch=unroll_linesearch, # type: ignore[unknown-argument]
implicit_diff=implicit_diff, # type: ignore[unknown-argument]
)


Expand Down
5 changes: 2 additions & 3 deletions differt/src/differt/rt/_image_method.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from functools import partial
from typing import overload

import chex
Expand All @@ -9,7 +8,7 @@
from differt.utils import smoothing_function


@partial(jax.jit, inline=True)
@jax.jit(inline=True)
def image_of_vertices_with_respect_to_mirrors(
vertices: Float[ArrayLike, "*#batch 3"],
mirror_vertices: Float[ArrayLike, "*#batch 3"],
Expand Down Expand Up @@ -80,7 +79,7 @@ def image_of_vertices_with_respect_to_mirrors(
)


@partial(jax.jit, inline=True)
@jax.jit(inline=True)
def intersection_of_rays_with_planes(
ray_origins: Float[ArrayLike, "*#batch 3"],
ray_directions: Float[ArrayLike, "*#batch 3"],
Expand Down
6 changes: 2 additions & 4 deletions differt/src/differt/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
"""General purpose utilities."""

from functools import partial

import jax
import jax.numpy as jnp
from jaxtyping import Array, ArrayLike, Float, Num, PRNGKeyArray


@partial(jax.jit, static_argnames=("shape",))
@jax.jit(static_argnames=("shape",))
def sample_points_in_bounding_box(
bounding_box: Float[ArrayLike, "2 3"],
shape: tuple[int, ...] = (),
Expand Down Expand Up @@ -69,7 +67,7 @@ def safe_divide(
return jnp.where(zero_div, jnp.zeros(shape, dtype=dtype), num / den)


@partial(jax.jit, inline=True)
@jax.jit(inline=True)
def smoothing_function(
x: Float[ArrayLike, " *#batch"],
/,
Expand Down
3 changes: 1 addition & 2 deletions differt/tests/em/test_fresnel.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import chex
import jax
import jax.experimental
import jax.numpy as jnp
import pytest
from jaxtyping import PRNGKeyArray
Expand All @@ -21,7 +20,7 @@
("Glass", 2.511971),
],
)
@jax.experimental.enable_x64()
@jax.enable_x64()
def test_refractive_indices(mat_name: str, expected: float) -> None:
frequency = 1e9 # Hz
mat = materials[mat_name]
Expand Down
4 changes: 2 additions & 2 deletions differt/tests/em/test_interaction_type.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import chex
import jax.experimental
import jax
import jax.numpy as jnp
import pytest
from jaxtyping import DTypeLike
Expand All @@ -10,7 +10,7 @@
class TestInteractionType:
@pytest.mark.parametrize("dtype", [jnp.int32, jnp.int64])
def test_array(self, dtype: DTypeLike) -> None:
with jax.experimental.enable_x64(dtype == jnp.int64):
with jax.enable_x64(dtype == jnp.int64):
arr = jnp.array(list(InteractionType), dtype=dtype)
assert arr.dtype == dtype

Expand Down
2 changes: 1 addition & 1 deletion differt/tests/geometry/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def test_fibonacci_lattice(
expected_dtype: jnp.dtype,
expectation: AbstractContextManager[Exception],
) -> None:
with jax.experimental.enable_x64(expected_dtype == jnp.float64), expectation: # type: ignore[reportAttributeAccessIssue]
with jax.enable_x64(expected_dtype == jnp.float64), expectation:
got = fibonacci_lattice(n, dtype=dtype)

normalized, lengths = normalize(got)
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ bump = [
]
codspeed = ["pytest-codspeed>=2.2.0"]
cuda = [
"jax[cuda]>=0.7.2",
"jax[cuda]>=0.8.1",
]
dev = [
{include-group = "codspeed"},
Expand All @@ -21,6 +21,7 @@ docs = [
{include-group = "beartype"},
{include-group = "sionna"},
"differt[all]",
"jax==0.8.1", # TODO: 0.8.2 introduced metaclass for jax.Array, breaking some references in the docs
"myst-nb>=0.17.2",
"pillow>=10.1.0",
"plotly<6",
Expand Down
Loading
Loading