Skip to content

Commit 730ee72

Browse files
authored
Savvas/develop/fix valid time (#1932)
* Fix valid_time bug for observations. Add new titles * Linting * Minor error corrected * Fix minor error and font size * Linting
1 parent 05af6fc commit 730ee72

File tree

3 files changed

+21
-7
lines changed

3 files changed

+21
-7
lines changed

packages/evaluate/src/weathergen/evaluate/plotting/plotter.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -401,10 +401,7 @@ def create_maps_per_sample(
401401
region,
402402
tag=tag,
403403
map_kwargs=dict(map_kwargs.get(var, {})) | map_kwargs_global,
404-
title=(
405-
f"{self.stream}, {var} : fstep = {self.fstep:03} "
406-
f"({format_datetime(valid_time)})"
407-
),
404+
title=self.get_map_title(var, valid_time, da_t),
408405
)
409406
plot_names.append(name)
410407

@@ -506,7 +503,7 @@ def scatter_plot(
506503
)
507504

508505
plt.colorbar(scatter_plt, ax=ax, orientation="horizontal", label=f"Variable: {varname}")
509-
plt.title(title)
506+
plt.title(title, fontsize=9.5)
510507
if regionname == "global":
511508
ax.set_global()
512509
else:
@@ -628,6 +625,22 @@ def animation(self, samples, fsteps, variables, select, tag) -> list[str]:
628625
def get_map_output_dir(self, tag):
629626
return self.out_plot_basedir / self.stream / "maps" / tag
630627

628+
def get_map_title(self, var, valid_time, data):
629+
title = f"{self.stream}, {var} : fstep = {self.fstep:03}"
630+
if valid_time is not None:
631+
title += f" ({format_datetime(valid_time)})"
632+
elif "valid_time" in data.coords:
633+
valid_time_start = data["valid_time"].values.min()
634+
valid_time_end = data["valid_time"].values.max()
635+
if valid_time_start != valid_time_end:
636+
title += (
637+
f" ({format_datetime(valid_time_start)} - {format_datetime(valid_time_end)})"
638+
)
639+
else:
640+
title += f" ({format_datetime(valid_time_start)})"
641+
642+
return title
643+
631644

632645
class LinePlots:
633646
def __init__(self, plotter_cfg: dict, output_basedir: str | Path):

packages/evaluate/src/weathergen/evaluate/utils/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -430,7 +430,8 @@ def plot_data(reader: Reader, stream: str, global_plotting_opts: dict) -> None:
430430
"fig_size": global_plotting_opts.get("fig_size", (8, 10)),
431431
"fps": global_plotting_opts.get("fps", 2),
432432
"regions": global_plotting_opts.get("regions", ["global"]),
433-
"plot_subtimesteps": reader.get_inference_stream_attr(stream, "tokenize_spacetime", False),
433+
"plot_subtimesteps": reader.get_inference_stream_attr(stream, "tokenize_spacetime", False)
434+
| plot_settings.get("plot_subtimesteps", False),
434435
}
435436
plotter = Plotter(plotter_cfg, reader.runplot_dir)
436437

src/weathergen/model/embeddings.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def forward_channels(self, x_in):
142142
x = peh(self.embed(x_in.transpose(-2, -1)))
143143

144144
for layer in self.layers:
145-
x = checkpoint(layer,x, use_reentrant=False)
145+
x = checkpoint(layer, x, use_reentrant=False)
146146

147147
# read out
148148
if self.unembed_mode == "full":

0 commit comments

Comments
 (0)