Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions .github/workflows/hamilton-sdk.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# CI workflow for the Hamilton UI SDK (ui/sdk): runs unit tests across
# supported Python versions whenever SDK files change.
name: SDK Test Workflow

on:
  push:
    branches:
      - main # or any specific branches you want to include
    # Only trigger when SDK sources change.
    paths:
      - 'ui/sdk/**'

  pull_request:
    paths:
      - 'ui/sdk/**'


jobs:
  sdk-unit-test:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        # Test against every Python minor version the SDK supports.
        python-version: ['3.9', '3.10', '3.11']
    defaults:
      run:
        # All run steps execute from the SDK subdirectory.
        working-directory: ui/sdk
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -r requirements.txt
          pip install -r requirements-test.txt
          pip install -e .
      - name: Run unit tests
        run: |
          pytest tests/
3 changes: 3 additions & 0 deletions ui/sdk/requirements-test.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
ibis-framework
langchain_core
polars
pyspark
pytest
ray
16 changes: 15 additions & 1 deletion ui/sdk/src/hamilton_sdk/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,22 @@ def _hash_module(
module,
)

def safe_getmembers(module):
    """Best-effort wrapper around ``inspect.getmembers``.

    Some modules are lazily loaded (e.g. ibis) and enumerating their members
    can raise; in that case we log at debug level and report no members so
    hashing can continue.
    """
    try:
        members = inspect.getmembers(module)
    except Exception as e:
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(
                f"Skipping hash for module {module.__name__} because we could not get the members. "
                f"Error: {e}"
            )
        members = []
    return members

# Loop through the module's attributes
for name, value in inspect.getmembers(module):
for name, value in safe_getmembers(module):
# Check if the attribute is a module
if inspect.ismodule(value):
if value.__package__ is None:
Expand Down
89 changes: 89 additions & 0 deletions ui/sdk/src/hamilton_sdk/tracking/ibis_stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from typing import Any, Dict

from hamilton_sdk.tracking import stats
from ibis.expr.datatypes import core

# import ibis.expr.types as ir
from ibis.expr.types import relations

"""Module that houses functions to introspect an Ibis Table. We don't have expression support yet.
"""

# Maps ibis type-name strings to the simplified "base" type names that the
# Backend/UI understands.
base_data_type_mapping_dict = {
    "timestamp": "datetime",
    "date": "datetime",
    "string": "str",
    "integer": "numeric",
    "double": "numeric",
    "float": "numeric",
    "boolean": "boolean",
    "long": "numeric",
    "short": "numeric",
}


def base_data_type_mapping(data_type: core.DataType) -> str:
    """Returns the base data type of the column.

    This uses the internal is_* type methods to determine the base data type,
    mirroring the categories in ``base_data_type_mapping_dict``.

    :param data_type: ibis data type of the column.
    :return: one of "boolean", "datetime", "str", "numeric", or "unhandled".
    """
    # NOTE(review): relies on ibis' DataType.is_* predicates -- confirm these
    # exist on the ibis version the SDK pins.
    if data_type.is_boolean():
        # Check boolean before numeric so booleans aren't misclassified.
        return "boolean"
    if data_type.is_temporal():
        return "datetime"
    if data_type.is_string():
        return "str"
    if data_type.is_numeric():
        return "numeric"
    return "unhandled"


# Template of the per-column metadata keys we report; commented-out keys are
# stats we cannot (yet) compute for an ibis table.
base_schema = {
    # we can't get all of these about an ibis dataframe
    "base_data_type": None,
    # 'count': 0,
    "data_type": None,
    # 'histogram': {},
    # 'max': 0,
    # 'mean': 0,
    # 'min': 0,
    # 'missing': 0,
    "name": None,
    "pos": None,
    # 'quantiles': {},
    # 'std': 0,
    # 'zeros': 0
}


def _introspect(table: relations.Table) -> Dict[str, Any]:
    """Introspect an Ibis table and return a dictionary of statistics.

    :param table: Ibis table expression to introspect.
    :return: Dictionary with a "columns" list of per-column metadata.
    """
    fields = table.schema().items()
    column_to_metadata = []
    for idx, (field_name, field_type) in enumerate(fields):
        # Start from the base schema so every column reports the same keys.
        values = base_schema.copy()
        values.update(
            {
                "name": field_name,
                "pos": idx,
                "data_type": str(field_type),
                "base_data_type": base_data_type_mapping(field_type),
                "nullable": field_type.nullable,
            }
        )
        column_to_metadata.append(values)
    return {
        "columns": column_to_metadata,
    }


@stats.compute_stats.register
def compute_stats_ibis_table(
    result: relations.Table, node_name: str, node_tags: dict
) -> Dict[str, Any]:
    """Compute observability stats for an Ibis table expression."""
    # TODO: create custom type instead of dict for UI
    introspection = _introspect(result)
    observability_value = {
        "type": str(type(result)),
        "value": introspection,
    }
    return {
        "observability_type": "dict",
        "observability_value": observability_value,
        "observability_schema_version": "0.0.2",
    }
49 changes: 49 additions & 0 deletions ui/sdk/src/hamilton_sdk/tracking/langchain_stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""
Module to pull a few things from langchain objects.
"""

from typing import Any, Dict

from hamilton_sdk.tracking import stats
from langchain_core import documents as lc_documents
from langchain_core import messages as lc_messages


@stats.compute_stats.register(lc_messages.BaseMessage)
def compute_stats_lc_messages(
    result: lc_messages.BaseMessage, node_name: str, node_tags: dict
) -> Dict[str, Any]:
    """Capture the content and type of a langchain message."""
    summary = {"value": result.content, "type": result.type}
    return {
        "observability_type": "dict",
        "observability_value": summary,
        "observability_schema_version": "0.0.2",
    }


@stats.compute_stats.register(lc_documents.Document)
def compute_stats_lc_docs(
    result: lc_documents.Document, node_name: str, node_tags: dict
) -> Dict[str, Any]:
    """Capture stats for a langchain document.

    Delegates to ``to_document()`` when the object provides it; otherwise
    pulls the page content and metadata directly, since not all documents
    are serializable.
    """
    if hasattr(result, "to_document"):
        return stats.compute_stats(result.to_document(), node_name, node_tags)
    # d.page_content # hack because not all documents are serializable
    summary = {"content": result.page_content, "metadata": result.metadata}
    return {
        "observability_type": "dict",
        "observability_value": summary,
        "observability_schema_version": "0.0.2",
    }


if __name__ == "__main__":
    # Example usage: smoke-test the handlers registered above.
    from langchain_core import messages

    msg = messages.BaseMessage(content="Hello, World!", type="greeting")
    print(stats.compute_stats(msg, "greeting", {}))

    doc = lc_documents.Document(page_content="Hello, World!", metadata={"source": "local_dir"})
    print(stats.compute_stats(doc, "document", {}))
110 changes: 110 additions & 0 deletions ui/sdk/src/hamilton_sdk/tracking/pyspark_stats.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
from typing import Any, Dict

import pyspark.sql as ps
from hamilton_sdk.tracking import stats

"""Module that houses functions to introspect a PySpark dataframe.
"""
# this is a mapping used in the Backend/UI.
# we should probably move this to a shared location.
# Keys are Spark type-name strings (as they appear in the schema's JSON
# form); values are the simplified "base" type names the UI understands.
base_data_type_mapping = {
    "timestamp": "datetime",
    "date": "datetime",
    "string": "str",
    "integer": "numeric",
    "double": "numeric",
    "float": "numeric",
    "boolean": "boolean",
    "long": "numeric",
    "short": "numeric",
}

# Template of the per-column metadata keys we report; commented-out keys are
# stats we cannot compute for a pyspark dataframe.
base_schema = {
    # we can't get all of these about a pyspark dataframe
    "base_data_type": None,
    # 'count': 0,
    "data_type": None,
    # 'histogram': {},
    # 'max': 0,
    # 'mean': 0,
    # 'min': 0,
    # 'missing': 0,
    "name": None,
    "pos": None,
    # 'quantiles': {},
    # 'std': 0,
    # 'zeros': 0
}


def _introspect(df: ps.DataFrame) -> Dict[str, Any]:
    """Introspect a PySpark dataframe and return a dictionary of statistics.

    :param df: PySpark dataframe to introspect.
    :return: Dictionary with a "columns" list of per-column metadata, plus
        the "cost" and "extended" plan explain strings.
    """
    fields = df.schema.jsonValue()["fields"]
    column_to_metadata = []
    for idx, field in enumerate(fields):
        values = base_schema.copy()
        data_type = field["type"]
        if not isinstance(data_type, str):
            # Complex types (struct/array/map) appear as nested dicts in the
            # JSON schema; a dict is unhashable (so .get() would raise) and
            # has no simple base-type mapping -- stringify and mark unhandled.
            data_type = str(data_type)
            base_data_type = "unhandled"
        else:
            base_data_type = base_data_type_mapping.get(data_type, "unhandled")
        values.update(
            {
                "name": field["name"],
                "pos": idx,
                "data_type": data_type,
                "base_data_type": base_data_type,
                "nullable": field["nullable"],
            }
        )
        column_to_metadata.append(values)
    # NOTE(review): these use private PySpark internals (_sc, _jvm, _jdf) to
    # obtain plan explain strings -- verify across supported Spark versions.
    cost_explain = df._sc._jvm.PythonSQLUtils.explainString(df._jdf.queryExecution(), "cost")
    extended_explain = df._sc._jvm.PythonSQLUtils.explainString(
        df._jdf.queryExecution(), "extended"
    )
    return {
        "columns": column_to_metadata,
        "cost_explain": cost_explain,
        "extended_explain": extended_explain,
    }


@stats.compute_stats.register
def compute_stats_psdf(result: ps.DataFrame, node_name: str, node_tags: dict) -> Dict[str, Any]:
    """Compute observability stats for a PySpark dataframe."""
    # TODO: create custom type instead of dict for UI
    introspection = _introspect(result)
    observability_value = {
        "type": str(type(result)),
        "value": introspection,
    }
    return {
        "observability_type": "dict",
        "observability_value": observability_value,
        "observability_schema_version": "0.0.2",
    }


if __name__ == "__main__":
    # Example usage: build a pandas frame covering a variety of dtypes,
    # convert it to a local Spark dataframe, and print the computed stats.
    import numpy as np
    import pandas as pd

    df = pd.DataFrame(
        {
            "a": [1, 2, 3, 4, 5],
            "b": ["a", "b", "c", "d", "e"],
            "c": [True, False, True, False, True],
            "d": [1.0, 2.0, 3.0, 4.0, 5.0],
            "e": pd.Categorical(["a", "b", "c", "d", "e"]),
            "f": pd.Series(["a", "b", "c", "d", "e"], dtype="string"),
            "g": pd.Series(["a", "b", "c", "d", "e"], dtype="object"),
            "h": pd.Series(
                ["20221231", None, "20221231", "20221231", "20221231"], dtype="datetime64[ns]"
            ),
            "i": pd.Series([None, None, None, None, None], name="a", dtype=np.float64),
            "j": pd.Series(name="a", data=pd.date_range("20230101", "20230105")),
        }
    )
    # local[1]: single-threaded local Spark -- enough for a smoke test.
    spark = ps.SparkSession.builder.master("local[1]").getOrCreate()
    psdf = spark.createDataFrame(df)
    import pprint

    res = compute_stats_psdf(psdf, "df", {})
    pprint.pprint(res)
21 changes: 9 additions & 12 deletions ui/sdk/src/hamilton_sdk/tracking/runs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import functools
import importlib
import logging
import sys
import time as py_time
Expand All @@ -14,21 +15,17 @@
from hamilton.data_quality import base as dq_base
from hamilton.lifecycle import base as lifecycle_base

try:
from hamilton_sdk.tracking import numpy_stats # noqa: F401
from hamilton_sdk.tracking import pandas_stats # noqa: F401

except ImportError:
pass

try:
from hamilton_sdk.tracking import polars_stats # noqa: F401

except ImportError:
pass
# Optional stats modules: each one registers compute_stats handlers for a
# third-party library and is skipped silently when that library isn't
# installed.
_modules_to_import = ["numpy", "pandas", "polars", "pyspark", "ibis", "langchain"]

logger = logging.getLogger(__name__)

for module in _modules_to_import:
    try:
        importlib.import_module(f"hamilton_sdk.tracking.{module}_stats")
    except ImportError:
        # Best-effort: the corresponding library is an optional dependency.
        logger.debug(f"Failed to import hamilton_sdk.tracking.{module}_stats")


def process_result(result: Any, node: h_node.Node) -> Any:
"""Processes result -- this is purely a by-type mapping.
Expand Down
7 changes: 7 additions & 0 deletions ui/sdk/src/hamilton_sdk/tracking/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,11 @@ def compute_stats_tuple(result: tuple, node_name: str, node_tags: dict) -> Dict[
},
"observability_schema_version": "0.0.2",
}
# namedtuple -- this how we guide people to not have something tracked easily.
# so we skip it if it has a `secret_key`. This is hacky -- better choice would
# be to have an internal object or way to decorate a parameter to not track it.
if hasattr(result, "_asdict") and not hasattr(result, "secret_key"):
return compute_stats_dict(result._asdict(), node_name, node_tags)
return {
"observability_type": "unsupported",
"observability_value": {
Expand Down Expand Up @@ -154,6 +159,8 @@ def compute_stats_list(result: list, node_name: str, node_tags: dict) -> Dict[st
# else just string it -- max 200 chars.
if len(v) > 200:
v = v[:200] + "..."
elif observed_type == "dict":
v = v_result["observability_value"]
result_values.append(v)
return {
# yes dict type -- that's so that we can display in the UI. It's a hack.
Expand Down
4 changes: 2 additions & 2 deletions ui/sdk/tests/test_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_hash_module_with_subpackage():
seen_modules = set()
result = _hash_module(submodule1, hash_object, seen_modules)

assert result.hexdigest() == "9f0b697b6b071d0ca6df18532031a8e553a8327531e249ff457d772b8bd392c7"
assert result.hexdigest() == "4466d1f61b2c57c2b5bfe8a9fec09acd53befcfdf2f5720075aef83e3d6c6bf8"
assert len(seen_modules) == 2
assert {m.__name__ for m in seen_modules} == {
"tests.test_package_to_hash.subpackage",
Expand All @@ -62,7 +62,7 @@ def test_hash_module_complex():
seen_modules = set()
result = _hash_module(test_package_to_hash, hash_object, seen_modules)

assert result.hexdigest() == "fc568608a2f766eac3cbae4021fb367247c6aa36ac4ae72ea98104c1ba2a5e1c"
assert result.hexdigest() == "c22023a4fdc8564de1cda70d05a19d5e8c0ddaaa9dcccf644a2b789b80f19896"
assert len(seen_modules) == 4
assert {m.__name__ for m in seen_modules} == {
"tests.test_package_to_hash",
Expand Down
Loading