Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 38 additions & 1 deletion openml/runs/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from collections import OrderedDict
from pathlib import Path
from typing import (
IO,
TYPE_CHECKING,
Any,
Callable,
Expand Down Expand Up @@ -149,6 +150,7 @@ def __init__( # noqa: PLR0913
self.predictions_url = predictions_url
self.description_text = description_text
self.run_details = run_details
self._additional_files: dict[str, tuple[str, bytes]] = {}
self._predictions = None

@property
Expand Down Expand Up @@ -614,6 +616,35 @@ def _parse_publish_response(self, xml_response: dict) -> None:
"""Parse the id from the xml_response and assign it to self."""
self.run_id = int(xml_response["oml:upload_run"]["oml:run_id"])

def add_file(
self,
file_name: str,
content: str | bytes | Path | IO[bytes],
) -> None:
"""Attach additional file to this run.

Parameters
----------
file_name : str
Name under which the file will be stored on OpenML,
e.g. "model.onnx" or "weights.bin".
content : {str, bytes, pathlib.Path, file-like}
The file content. If a Path or file object is passed,
it will be read into memory and stored as bytes.
"""
if isinstance(content, Path):
with content.open("rb") as f:
content_bytes = f.read()
elif hasattr(content, "read"): # file-like
content_bytes = content.read()
elif isinstance(content, str):
content_bytes = content.encode("utf-8")
else:
content_bytes = content

# store as (file_name, content) tuple
self._additional_files[file_name] = (file_name, content_bytes)

def _get_file_elements(self) -> dict:
"""Get file_elements to upload to the server.

Expand Down Expand Up @@ -644,7 +675,9 @@ def _get_file_elements(self) -> dict:
self.model,
)

file_elements = {"description": ("description.xml", self._to_xml())}
file_elements: dict[str, tuple[str, str | bytes]] = {
"description": ("description.xml", self._to_xml())
}

if self.error_message is None:
predictions = arff.dumps(self._generate_arff_dict())
Expand All @@ -653,6 +686,10 @@ def _get_file_elements(self) -> dict:
if self.trace is not None:
trace_arff = arff.dumps(self.trace.trace_to_arff())
file_elements["trace"] = ("trace.arff", trace_arff)

for key, (stored_name, content) in self._additional_files.items():
file_elements[key] = (stored_name, content)

return file_elements

def _to_dict(self) -> dict[str, dict]: # noqa: PLR0912, C901
Expand Down