Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/qcodes/dataset/descriptions/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,16 @@ def top_level_parameters(self) -> tuple[ParamSpecBase, ...]:
for node_id, in_degree in self._dependency_subgraph.in_degree
if in_degree == 0
}
# Parameters that are inferred from other parameters (have outgoing
# edges in the inference subgraph) should not be independent top-level
# parameters, since their data is part of the tree of the parameter
# they are inferred from.
parameters_inferred_from_others = {
self._node_to_paramspec(node_id)
for node_id, out_degree in self._inference_subgraph.out_degree
if out_degree > 0
}
dependency_top_level = dependency_top_level - parameters_inferred_from_others
standalone_top_level = {
self._node_to_paramspec(node_id)
for node_id, degree in self._graph.degree
Expand Down
65 changes: 62 additions & 3 deletions src/qcodes/dataset/exporters/export_to_xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from math import prod
from typing import TYPE_CHECKING, Literal

import numpy as np
from packaging import version as p_version

from qcodes.dataset.linked_datasets.links import links_to_str
Expand Down Expand Up @@ -61,6 +62,56 @@ def _calculate_index_shape(idx: pd.Index | pd.MultiIndex) -> dict[Hashable, int]
return expanded_shape


def _add_inferred_data_vars(
    dataset: DataSetProtocol,
    name: str,
    sub_dict: Mapping[str, npt.NDArray],
    xr_dataset: xr.Dataset,
) -> xr.Dataset:
    """Add inferred parameters as data variables to an xarray dataset.

    Parameters that are inferred from the top-level measurement parameter
    ``name``, present in ``sub_dict`` but not yet in ``xr_dataset``, are
    added as data variables along the existing dimensions. Inferred data
    whose length does not match the size spanned by the existing
    dimensions cannot be mapped onto the grid and is skipped with a
    warning.

    Args:
        dataset: The dataset whose run description provides the
            interdependencies (dependency/inference relations).
        name: Name of the top-level (dependent) parameter of this tree.
        sub_dict: Raw numpy data for the parameters in this tree, keyed
            by parameter name.
        xr_dataset: The xarray dataset to extend; mutated in place.

    Returns:
        ``xr_dataset`` with any matching inferred parameters added as
        data variables.
    """
    import warnings

    interdeps = dataset.description.interdeps
    meas_paramspec = interdeps.graph.nodes[name]["value"]
    _, deps, inferred = interdeps.all_parameters_in_tree_by_group(meas_paramspec)

    dep_names = {dep.name for dep in deps}
    dims = tuple(xr_dataset.dims)
    # Total number of points spanned by the existing dimensions; this is
    # loop-invariant, so compute it once (prod of an empty iterable is 1,
    # matching a dimensionless dataset).
    expected_size = prod(xr_dataset.sizes[d] for d in dims)

    for inf in inferred:
        # Skip parameters already represented (as dependencies/coordinates
        # or as existing variables) and parameters without captured data.
        if inf.name in dep_names:
            continue
        if inf.name in xr_dataset:
            continue
        if inf.name not in sub_dict:
            continue

        inf_data = sub_dict[inf.name]
        if inf_data.dtype == np.dtype("O"):
            # Object arrays hold per-row arrays; flatten row-wise when the
            # rows are concatenatable, otherwise fall back to a plain ravel.
            try:
                flat = np.concatenate(inf_data)
            except ValueError:
                flat = inf_data.ravel()
        else:
            flat = inf_data.ravel()

        # Only add if the data length matches the existing dataset size.
        if flat.shape[0] == expected_size:
            xr_dataset[inf.name] = (
                dims,
                flat.reshape(tuple(xr_dataset.sizes[d] for d in dims)),
            )
        else:
            warnings.warn(
                f"Could not add inferred parameter {inf.name!r} to the "
                f"xarray dataset for {name!r}: its data has "
                f"{flat.shape[0]} point(s) but the dataset dimensions "
                f"span {expected_size} point(s).",
                stacklevel=2,
            )

    return xr_dataset


def _load_to_xarray_dataset_dict_no_metadata(
dataset: DataSetProtocol,
datadict: Mapping[str, Mapping[str, npt.NDArray]],
Expand Down Expand Up @@ -100,25 +151,33 @@ def _load_to_xarray_dataset_dict_no_metadata(
interdeps=dataset.description.interdeps,
dependent_parameter=name,
).to_xarray()
xr_dataset_dict[name] = xr_dataset
xr_dataset_dict[name] = _add_inferred_data_vars(
dataset, name, sub_dict, xr_dataset
)
elif index_is_unique:
df = _data_to_dataframe(
sub_dict,
index,
interdeps=dataset.description.interdeps,
dependent_parameter=name,
)
xr_dataset_dict[name] = _xarray_data_set_from_pandas_multi_index(
xr_dataset = _xarray_data_set_from_pandas_multi_index(
dataset, use_multi_index, name, df, index
)
xr_dataset_dict[name] = _add_inferred_data_vars(
dataset, name, sub_dict, xr_dataset
)
else:
df = _data_to_dataframe(
sub_dict,
index,
interdeps=dataset.description.interdeps,
dependent_parameter=name,
)
xr_dataset_dict[name] = df.reset_index().to_xarray()
xr_dataset = df.reset_index().to_xarray()
xr_dataset_dict[name] = _add_inferred_data_vars(
dataset, name, sub_dict, xr_dataset
)

return xr_dataset_dict

Expand Down
72 changes: 72 additions & 0 deletions tests/dataset/test_parameter_with_setpoints_has_control.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from typing import TYPE_CHECKING

import numpy as np
import numpy.testing as npt

from qcodes.dataset import Measurement
from qcodes.parameters import ManualParameter, ParameterWithSetpoints
from qcodes.validators import Arrays

if TYPE_CHECKING:
from qcodes.dataset.experiment_container import Experiment


def test_parameter_with_setpoints_has_control(experiment: "Experiment"):
    """A parameter placed under ``has_control_of`` of the measured parameter
    is treated as inferred: it is stored alongside the measured data, excluded
    from the top-level parameters, and exported as a data variable in xarray.
    """
    n_points = 10
    setpoint_values = np.arange(n_points)
    inferred_values = np.linspace(-1, 1, n_points)

    setpoints = ManualParameter(
        "mp", vals=Arrays(shape=(n_points,)), initial_value=setpoint_values
    )
    controlled = ParameterWithSetpoints(
        "p1", vals=Arrays(shape=(n_points,)), setpoints=(setpoints,), set_cmd=None
    )

    class ControllingParameter(ParameterWithSetpoints):
        # Whenever this parameter is unpacked for storage, also report the
        # controlled parameter's current value.
        def unpack_self(self, value):
            unpacked = super().unpack_self(value)
            unpacked.append((controlled, controlled()))
            return unpacked

    controlling = ControllingParameter(
        "p2", vals=Arrays(shape=(n_points,)), setpoints=(setpoints,), set_cmd=None
    )
    controlling.has_control_of.add(controlled)

    controlled(inferred_values)
    measured_values = np.random.randn(n_points)
    controlling(measured_values)

    meas = Measurement()
    meas.register_parameter(controlling)

    # Only p2 should be top-level; p1 is inferred from p2.
    assert [p.name for p in meas._interdeps.top_level_parameters] == ["p2"]

    with meas.run() as ds:
        ds.add_result((controlling, controlling()))

    # Raw parameter data: one row of 10 points for every parameter in the tree.
    raw_data = ds.dataset.get_parameter_data()
    assert list(raw_data.keys()) == ["p2"], "Only p2 should be a top-level result"
    for name, arr in raw_data["p2"].items():
        assert arr.shape == (1, n_points), (
            f"Expected shape (1, 10) for {name}, got {arr.shape}"
        )

    xds = ds.dataset.to_xarray_dataset()

    # mp should be the only dimension (not a generic 'index').
    assert list(xds.sizes.keys()) == ["mp"]
    assert xds.sizes["mp"] == n_points

    # mp values are used as the coordinate axis.
    npt.assert_array_equal(xds.coords["mp"].values, setpoint_values)

    # p2 is the primary data variable with correct values.
    assert "p2" in xds.data_vars
    npt.assert_array_almost_equal(xds["p2"].values, measured_values)

    # p1 is included as a data variable (inferred from p2) with correct values.
    assert "p1" in xds.data_vars
    npt.assert_array_almost_equal(xds["p1"].values, inferred_values)

    # p1 data is also retrievable from the raw parameter data.
    npt.assert_array_almost_equal(raw_data["p2"]["p1"].ravel(), inferred_values)
Loading