Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ui/sdk/requirements-test.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
ibis-framework
langchain_core
polars
pydantic
pyspark
pytest
ray
2 changes: 1 addition & 1 deletion ui/sdk/src/hamilton_sdk/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = (0, 4, 2)
__version__ = (0, 4, 3)
22 changes: 22 additions & 0 deletions ui/sdk/src/hamilton_sdk/tracking/pydantic_stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from typing import Any, Dict

import pydantic
from hamilton_sdk.tracking import stats


@stats.compute_stats.register
def compute_stats_pydantic(
    result: pydantic.BaseModel, node_name: str, node_tags: dict
) -> Dict[str, Any]:
    """Compute tracking stats for a pydantic model result.

    Serializes the model to a plain dict so the value is JSON-friendly for
    the UI, preferring the pydantic v2 API and falling back to v1.

    :param result: The pydantic model produced by the node.
    :param node_name: Name of the node that produced the result (unused here,
        but part of the ``compute_stats`` dispatch signature).
    :param node_tags: Tags of the node (unused here, same reason).
    :return: Observability payload with the serialized model.
    """
    # Bug fix: the original checked for "dump_model", which is not a pydantic
    # method — pydantic v2's serializer is ``model_dump()``. The old check was
    # always false for real models, forcing the deprecated v1 ``.dict()`` path.
    if hasattr(result, "model_dump"):
        # pydantic v2
        serialized = result.model_dump()
    else:
        # pydantic v1 fallback
        serialized = result.dict()
    return {
        "observability_type": "dict",
        "observability_value": {
            "type": str(type(result)),
            "value": serialized,
        },
        "observability_schema_version": "0.0.2",
    }
3 changes: 1 addition & 2 deletions ui/sdk/src/hamilton_sdk/tracking/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from hamilton.data_quality import base as dq_base
from hamilton.lifecycle import base as lifecycle_base

_modules_to_import = ["numpy", "pandas", "polars", "pyspark", "ibis", "langchain"]
_modules_to_import = ["numpy", "pandas", "polars", "pyspark", "ibis", "langchain", "pydantic"]

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -44,7 +44,6 @@ def process_result(result: Any, node: h_node.Node) -> Any:
:param node: The node that produced the result
:return: The processed result - it has to be JSON serializable!
"""

try:
start = py_time.time()
statistics = stats.compute_stats(result, node.name, node.tags)
Expand Down
6 changes: 5 additions & 1 deletion ui/sdk/src/hamilton_sdk/tracking/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,11 @@ def make_json_safe(item: Union[dict, list, str, float, int, bool]) -> Any:
elif hasattr(item, "to_json"):
# we convert to json string and then deserialize it so that
# it's not a string in the UI.
return json.loads(item.to_json())
try:
return json.loads(item.to_json())
except Exception:
# to_json can raise (e.g. pandas frames with duplicate indexes);
# fall back to a truncated string representation instead.
return str(item)[0:200] + "..."
elif hasattr(item, "to_dict"):
return make_json_safe(item.to_dict())
else:
Expand Down
46 changes: 46 additions & 0 deletions ui/sdk/tests/tracking/test_pydantic_stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from hamilton_sdk.tracking import pydantic_stats
from pydantic import BaseModel


class TestModel(BaseModel):
    """Plain pydantic model with no custom serialization hooks."""

    name: str
    value: int


class TestModel2(BaseModel):
    """Pydantic model that also defines a custom ``dump_model`` method.

    NOTE(review): ``dump_model`` matches the attribute name that
    ``compute_stats_pydantic`` probes for; pydantic's own v2 serializer
    is spelled ``model_dump`` — confirm which name is actually intended.
    """

    name: str
    value: int

    def dump_model(self):
        # Hand-rolled serialization, exercised by the "with_dump_model" test.
        return {"name": self.name, "value": self.value}


class EmptyModel(BaseModel):
    """Model with no fields — serializes to an empty dict."""

    pass


def test_compute_stats_df_with_dump_model():
    """A model exposing ``dump_model`` is serialized through that hook."""
    instance = TestModel2(name="test", value=2)
    actual = pydantic_stats.compute_stats_pydantic(instance, "node1", {"tag1": "value1"})
    expected = {
        "observability_type": "dict",
        "observability_value": {
            "type": str(type(instance)),
            "value": {"name": "test", "value": 2},
        },
        "observability_schema_version": "0.0.2",
    }
    for key, expected_value in expected.items():
        assert actual[key] == expected_value


def test_compute_stats_df_without_dump_model():
    """A plain pydantic model is serialized via pydantic's own machinery."""
    instance = TestModel(name="test", value=1)
    actual = pydantic_stats.compute_stats_pydantic(instance, "node1", {"tag1": "value1"})
    expected = {
        "observability_type": "dict",
        "observability_value": {
            "type": str(type(instance)),
            "value": {"name": "test", "value": 1},
        },
        "observability_schema_version": "0.0.2",
    }
    for key, expected_value in expected.items():
        assert actual[key] == expected_value


def test_compute_stats_df_with_empty_model():
    """A field-less model serializes to an empty dict value."""
    instance = EmptyModel()
    actual = pydantic_stats.compute_stats_pydantic(instance, "node1", {"tag1": "value1"})
    expected = {
        "observability_type": "dict",
        "observability_value": {
            "type": str(type(instance)),
            "value": {},
        },
        "observability_schema_version": "0.0.2",
    }
    for key, expected_value in expected.items():
        assert actual[key] == expected_value
23 changes: 23 additions & 0 deletions ui/sdk/tests/tracking/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,29 @@ def test_make_json_safe_with_pandas_dataframe():
}


def test_make_json_safe_with_pandas_dataframe_duplicate_indexes():
    """``to_json`` fails with duplicate indexes, so ``make_json_safe`` falls
    back to a truncated string representation of the frame."""
    input_dataframe = pd.DataFrame(
        {
            "A": 1.0,
            "B": pd.Timestamp("20130102"),
            "C": pd.Series(1, index=list(range(4)), dtype=np.float64),
            "D": np.array([3] * 4, dtype="int32"),
            "E": pd.Categorical(["test", "train", "test", "train"]),
            "F": "foo",
        },
        # Duplicate index labels are what make DataFrame.to_json raise.
        index=[0, 1, 0, 1],
    )
    actual = utils.make_json_safe(input_dataframe)
    # Expected value is the str(df) fallback truncated to 200 chars + "...",
    # not JSON.
    assert actual == (
        " A B C D E F\n"
        "0 1.0 2013-01-02 1.0 3 test foo\n"
        "1 1.0 2013-01-02 1.0 3 train foo\n"
        "0 1.0 2013-01-02 1.0 3 test foo\n"
        "1 1.0 2013-01-02 1.0 3 train foo..."
    )


def test_make_json_safe_with_pandas_series():
index = pd.date_range("2022-01-01", periods=6, freq="w")
input_series = pd.Series([1, 10, 50, 100, 200, 400], index=index)
Expand Down