Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ build/
*.egg-info/
.venv/

*.sh

# MkDocs build output and caches
site/
.cache/
Expand Down
10 changes: 8 additions & 2 deletions docs/examples/covariance_field/covariance_field_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,14 @@
)

# --8<-- [start:imports]
from psyphy.model import WPPM, GaussianNoise, OddityTask, Prior
from psyphy.model.covariance_field import WPPMCovarianceField
from psyphy.model import (
WPPM,
GaussianNoise,
OddityTask,
OddityTaskConfig,
Prior,
WPPMCovarianceField, # (fast (\Sigma) evaluation)
)

# --8<-- [end:imports]

Expand Down
29 changes: 7 additions & 22 deletions docs/examples/wppm/full_wppm_fit_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@
import os
import sys

# --8<-- [start:jax_device_setup]
# Must be set BEFORE importing JAX, as JAX locks in its backend on first import.
# Unset any forced CPU override so JAX can auto-detect GPU/TPU if available.
os.environ.pop("JAX_PLATFORM_NAME", None)
# --8<-- [end:jax_device_setup]

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
Expand All @@ -34,40 +40,19 @@
from psyphy.data import TrialData # (batched trial container)
from psyphy.inference import MAPOptimizer # fitter
from psyphy.model import (
WPPM,
GaussianNoise,
OddityTask,
OddityTaskConfig,
Prior,
WPPMCovarianceField, # (fast (\Sigma) evaluation)
WPPM,
)

# --8<-- [end:imports]
PLOTS_DIR = os.path.join(os.path.dirname(__file__), "plots")


# --8<-- [start:jax_device_setup]
# Prefer GPU/TPU if available; otherwise fall back to CPU.
try:
has_accel = any(
getattr(d, "platform", "").lower() in ("gpu", "cuda", "tpu")
for d in jax.devices()
)
except Exception:
has_accel = False

if not has_accel:
# Force CPU backend if no accelerator detected (or JAX not yet initialized).
os.environ.setdefault("JAX_PLATFORM_NAME", "cpu")
else:
# Remove any forced setting so JAX can use the accelerator.
os.environ.pop("JAX_PLATFORM_NAME", None)


# print device used
print("DEVICE USED:", jax.devices()[0])
# Helper: invert criterion to d* for Oddity task
# --8<-- [end:jax_device_setup]


# # Robust ellipse plotting utilities
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
133 changes: 133 additions & 0 deletions docs/examples/wppm/quick_start.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# Quick start — fit your first covariance ellipse

> **Goal:** run the full `psyphy` workflow — simulate data, fit a model, inspect the result — in one short script with no GPU required.
>
> The complete runnable script is [`quick_start.py`](https://github.com/flatironinstitute/psyphy/blob/main/docs/examples/wppm/quick_start.py).
> For a spatially-varying field over a 2-D stimulus grid, see the [full example](full_wppm_fit_example.md).

---

### Runtime

| Hardware | Approximate time |
|---|---|
| GPU (any modern CUDA device) | < 5 s |
| CPU (laptop / M-series Mac) | < 2 min |

The three knobs that control runtime:

```python title="Compute settings (quick start defaults)"
--8<-- "docs/examples/wppm/quick_start.py:compute_settings"
```

---

## Step 0 — Imports

```python title="Imports"
--8<-- "docs/examples/wppm/quick_start.py:imports"
```

---

## Step 1 — Define a ground-truth model and sample parameters

We create a WPPM with known parameters to act as the synthetic observer.
Data will be generated from it so we have a ground truth to compare against.

```python title="Ground-truth model"
--8<-- "docs/examples/wppm/quick_start.py:truth_model"
```

---

## Step 2 — Simulate trials at a single reference point

We generate `NUM_TRIALS` oddity-task responses at a single reference stimulus
`ref = [0, 0]`. Probe displacements are scaled by the local covariance
(constant Mahalanobis radius), so trial difficulty stays roughly uniform.

```python title="Simulate data"
--8<-- "docs/examples/wppm/quick_start.py:simulate_data"
```

The `TrialData` container is the canonical input for fitting:

```python title="Data container"
--8<-- "docs/examples/wppm/quick_start.py:data"
```

---

## Step 3 — Build the model to fit

We build a fresh WPPM with the same hyperparameters but independent random
weights, then take one draw from the prior as the starting point for
optimization.

```python title="Model definition"
--8<-- "docs/examples/wppm/quick_start.py:build_model"
```

```python title="Prior sample (initialization)"
--8<-- "docs/examples/wppm/quick_start.py:prior"
```

---

## Step 4 — Fit with MAP optimization

```python title="Fit with MAPOptimizer"
--8<-- "docs/examples/wppm/quick_start.py:fit_map"
```

`MAPOptimizer` runs SGD + momentum and returns a `MAPPosterior` — a point
estimate at $W_\text{MAP}$.

---

## Step 5 — Inspect the fitted covariance ellipse

`WPPMCovarianceField` binds a `(model, params)` pair into a single callable
that returns $\Sigma(x)$ for any stimulus `x`:

```python title="Evaluate covariance fields"
--8<-- "docs/examples/wppm/quick_start.py:cov_fields"
```

The ellipse plot below overlays the ground truth (black), the prior
initialization (blue), and the MAP fit (red) at the single reference point:

```python title="Plot ellipses"
--8<-- "docs/examples/wppm/quick_start.py:plot_ellipses"
```

<div align="center">
<img src="../../examples/wppm/plots/quick_start_ellipses.png"
alt="Covariance ellipses: ground truth (black), prior (blue), MAP fit (red)"
width="480"/>
<p><em>Ground truth (black), prior sample (blue), and MAP-fitted (red) covariance ellipses at the single reference point.</em></p>
</div>

---

## Step 6 — Learning curve

```python title="Access learning curve"
--8<-- "docs/examples/wppm/quick_start.py:plot_learning_curve"
```

<div align="center">
<img src="../../examples/wppm/plots/quick_start_learning_curve.png"
alt="Learning curve"
width="480"/>
<p><em>Negative log-likelihood over optimizer steps.</em></p>
</div>

---

## Next steps

- **Spatially-varying field:** scale up to a full 2-D grid → [full example](full_wppm_fit_example.md).
- **Your own data:** replace the simulated `TrialData` with your own `refs`, `comparisons`, and `responses` arrays.
- **API reference:** see [`MAPOptimizer`](../../reference/inference.md), [`WPPM`](../../reference/model.md), and [`WPPMCovarianceField`](../../reference/model.md).
Loading
Loading