diff --git a/openml/runs/run.py b/openml/runs/run.py
index 945264131..f39ec1f3c 100644
--- a/openml/runs/run.py
+++ b/openml/runs/run.py
@@ -6,6 +6,7 @@
 from collections import OrderedDict
 from pathlib import Path
 from typing import (
+    IO,
     TYPE_CHECKING,
     Any,
     Callable,
@@ -149,6 +150,7 @@ def __init__(  # noqa: PLR0913
         self.predictions_url = predictions_url
         self.description_text = description_text
         self.run_details = run_details
+        self._additional_files: dict[str, tuple[str, bytes]] = {}
         self._predictions = None
 
     @property
@@ -614,6 +616,62 @@ def _parse_publish_response(self, xml_response: dict) -> None:
         """Parse the id from the xml_response and assign it to self."""
         self.run_id = int(xml_response["oml:upload_run"]["oml:run_id"])
 
+    def add_file(
+        self,
+        file_name: str,
+        content: str | bytes | Path | IO[bytes],
+    ) -> None:
+        """Attach an additional file to this run.
+
+        The content is read into memory immediately and is included in the
+        file elements returned by ``_get_file_elements``.
+
+        Parameters
+        ----------
+        file_name : str
+            Name under which the file will be stored on OpenML,
+            e.g. "model.onnx" or "weights.bin". It doubles as the upload
+            key, so it must not be "description", "predictions" or "trace",
+            which are used for the run's own files.
+        content : {str, bytes, pathlib.Path, file-like}
+            The file content. A str is the content itself (not a path) and
+            is encoded as UTF-8. A Path or file object is read into memory
+            and stored as bytes.
+
+        Raises
+        ------
+        ValueError
+            If ``file_name`` is empty or is one of the reserved keys.
+        TypeError
+            If ``content`` does not yield str or bytes data.
+        """
+        # These keys are filled in by _get_file_elements; accepting them here
+        # would silently replace the run description, predictions or trace.
+        if not file_name or file_name in {"description", "predictions", "trace"}:
+            raise ValueError(f"Invalid or reserved file name: {file_name!r}")
+
+        raw: object
+        if isinstance(content, Path):
+            raw = content.read_bytes()
+        elif hasattr(content, "read"):  # file-like
+            raw = content.read()
+        else:
+            raw = content
+
+        # Text-mode file objects return str, so encode after reading.
+        if isinstance(raw, str):
+            content_bytes = raw.encode("utf-8")
+        elif isinstance(raw, (bytes, bytearray)):
+            content_bytes = bytes(raw)
+        else:
+            raise TypeError(
+                "content must be str, bytes, pathlib.Path or a file-like object, "
+                f"got {type(content).__name__}",
+            )
+
+        # store as (file_name, content) tuple
+        self._additional_files[file_name] = (file_name, content_bytes)
+
     def _get_file_elements(self) -> dict:
         """Get file_elements to upload to the server.
 
@@ -644,7 +702,9 @@ def _get_file_elements(self) -> dict:
                 self.model,
             )
 
-        file_elements = {"description": ("description.xml", self._to_xml())}
+        file_elements: dict[str, tuple[str, str | bytes]] = {
+            "description": ("description.xml", self._to_xml()),
+        }
 
         if self.error_message is None:
             predictions = arff.dumps(self._generate_arff_dict())
@@ -653,6 +713,10 @@ def _get_file_elements(self) -> dict:
         if self.trace is not None:
             trace_arff = arff.dumps(self.trace.trace_to_arff())
             file_elements["trace"] = ("trace.arff", trace_arff)
+
+        for key, (stored_name, content) in self._additional_files.items():
+            file_elements[key] = (stored_name, content)
+
         return file_elements
 
     def _to_dict(self) -> dict[str, dict]:  # noqa: PLR0912, C901