Skip to content

Commit 101a66b

Browse files
authored
fix(py/genkit): additional ty check fixes in genkit.blocks (#4246)
1 parent e54e164 commit 101a66b

File tree

5 files changed

+66
-40
lines changed

5 files changed

+66
-40
lines changed

py/packages/genkit/src/genkit/blocks/prompt.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@
2222
generation and management across different parts of the application.
2323
"""
2424

25+
import asyncio
2526
import os
2627
import weakref
27-
from asyncio import Future
2828
from collections.abc import AsyncIterator, Callable
2929
from pathlib import Path
3030
from typing import Any
@@ -240,7 +240,7 @@ def stream(
240240
timeout: float | None = None,
241241
) -> tuple[
242242
AsyncIterator[GenerateResponseChunkWrapper],
243-
Future[GenerateResponseWrapper],
243+
asyncio.Future[GenerateResponseWrapper],
244244
]:
245245
"""Streams the prompt with the given input and configuration.
246246
@@ -262,7 +262,7 @@ def stream(
262262
context=context,
263263
on_chunk=lambda c: stream.send(c),
264264
)
265-
stream.set_close_future(resp)
265+
stream.set_close_future(asyncio.create_task(resp))
266266

267267
return (stream, stream.closed)
268268

py/packages/genkit/src/genkit/blocks/resource.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
import inspect
2626
import re
2727
from collections.abc import Awaitable, Callable
28-
from typing import Any, Protocol, TypedDict
28+
from typing import Any, Protocol, TypedDict, cast
2929

3030
from pydantic import BaseModel
3131

@@ -88,6 +88,12 @@ def __call__(self, input: ResourceInput, ctx: ActionRunContext) -> Awaitable[Res
8888
...
8989

9090

91+
class MatchableAction(Protocol):
92+
"""Protocol for actions that have a matches method."""
93+
94+
matches: Callable[[ResourceInput], bool]
95+
96+
9197
ResourceArgument = Action | str
9298

9399

@@ -136,9 +142,9 @@ async def lookup_resource_by_name(registry: Registry, name: str) -> Action:
136142
ValueError: If the resource cannot be found.
137143
"""
138144
resource = (
139-
await registry.resolve_action(ActionKind.RESOURCE, name)
140-
or await registry.resolve_action(ActionKind.RESOURCE, f'/resource/{name}')
141-
or await registry.resolve_action(ActionKind.RESOURCE, f'/dynamic-action-provider/{name}')
145+
await registry.resolve_action(cast(ActionKind, ActionKind.RESOURCE), name)
146+
or await registry.resolve_action(cast(ActionKind, ActionKind.RESOURCE), f'/resource/{name}')
147+
or await registry.resolve_action(cast(ActionKind, ActionKind.RESOURCE), f'/dynamic-action-provider/{name}')
142148
)
143149
if not resource:
144150
raise ValueError(f'Resource {name} not found')
@@ -161,7 +167,7 @@ def define_resource(registry: Registry, opts: ResourceOptions, fn: ResourceFn) -
161167
"""
162168
action = dynamic_resource(opts, fn)
163169

164-
action.matches = create_matcher(opts.get('uri'), opts.get('template'))
170+
cast(MatchableAction, action).matches = create_matcher(opts.get('uri'), opts.get('template'))
165171

166172
# Mark as not dynamic since it's being registered
167173
action.metadata['dynamic'] = False
@@ -279,7 +285,7 @@ async def wrapped_fn(input_data: ResourceInput, ctx: ActionRunContext) -> Resour
279285

280286
act = Action(
281287
name=name,
282-
kind=ActionKind.RESOURCE,
288+
kind=cast(ActionKind, ActionKind.RESOURCE),
283289
fn=wrapped_fn,
284290
metadata={
285291
'resource': {
@@ -291,7 +297,7 @@ async def wrapped_fn(input_data: ResourceInput, ctx: ActionRunContext) -> Resour
291297
description=opts.get('description'),
292298
span_metadata={'genkit:metadata:resource:uri': uri},
293299
)
294-
act.matches = matcher
300+
cast(MatchableAction, act).matches = matcher
295301
return act
296302

297303

@@ -385,23 +391,27 @@ async def find_matching_resource(
385391
"""
386392
if dynamic_resources:
387393
for action in dynamic_resources:
388-
if hasattr(action, 'matches') and action.matches(input_data):
394+
if hasattr(action, 'matches') and cast(MatchableAction, action).matches(input_data):
389395
return action
390396

391397
# Try exact match in registry
392-
resource = await registry.resolve_action(ActionKind.RESOURCE, input_data.uri)
398+
resource = await registry.resolve_action(cast(ActionKind, ActionKind.RESOURCE), input_data.uri)
393399
if resource:
394400
return resource
395401

396402
# Iterate all resources to check for matches (e.g. templates)
397403
# This is less efficient but necessary for template matching if not optimized
398-
resources = registry.get_actions_by_kind(ActionKind.RESOURCE) if hasattr(registry, 'get_actions_by_kind') else {}
404+
resources = (
405+
registry.get_actions_by_kind(cast(ActionKind, ActionKind.RESOURCE))
406+
if hasattr(registry, 'get_actions_by_kind')
407+
else {}
408+
)
399409
if not resources and hasattr(registry, '_entries'):
400410
# Fallback for compatibility if registry instance is old (unlikely in this context)
401-
resources = registry._entries.get(ActionKind.RESOURCE, {})
411+
resources = registry._entries.get(cast(ActionKind, ActionKind.RESOURCE), {})
402412

403413
for action in resources.values():
404-
if hasattr(action, 'matches') and action.matches(input_data):
414+
if hasattr(action, 'matches') and cast(MatchableAction, action).matches(input_data):
405415
return action
406416

407417
return None

py/packages/genkit/src/genkit/blocks/retriever.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@
2222
to accomplish a task.
2323
"""
2424

25+
import inspect
2526
from collections.abc import Callable
26-
from typing import Any, Generic, TypeVar
27+
from typing import Any, Awaitable, Generic, TypeVar, cast
2728

2829
from pydantic import BaseModel, ConfigDict, Field
2930

@@ -34,8 +35,8 @@
3435
from genkit.core.typing import DocumentData, RetrieverResponse
3536

3637
T = TypeVar('T')
37-
# type RetrieverFn[T] = Callable[[Document, T], RetrieverResponse]
38-
RetrieverFn = Callable[[Document, T], RetrieverResponse]
38+
# type RetrieverFn[T] = Callable[[Document, T], RetrieverResponse | Awaitable[RetrieverResponse]]
39+
RetrieverFn = Callable[[Document, T], RetrieverResponse | Awaitable[RetrieverResponse]]
3940

4041

4142
class Retriever(Generic[T]):
@@ -115,9 +116,8 @@ def retriever_action_metadata(
115116
retriever_metadata_dict['retriever']['supports'] = options.supports.model_dump(exclude_none=True, by_alias=True)
116117

117118
retriever_metadata_dict['retriever']['customOptions'] = options.config_schema if options.config_schema else None
118-
119119
return ActionMetadata(
120-
kind=ActionKind.RETRIEVER,
120+
kind=cast(ActionKind, ActionKind.RETRIEVER),
121121
name=name,
122122
input_json_schema=to_json_schema(RetrieverRequest),
123123
output_json_schema=to_json_schema(RetrieverResponse),
@@ -191,7 +191,7 @@ def indexer_action_metadata(
191191
indexer_metadata_dict['indexer']['customOptions'] = options.config_schema if options.config_schema else None
192192

193193
return ActionMetadata(
194-
kind=ActionKind.INDEXER,
194+
kind=cast(ActionKind, ActionKind.INDEXER),
195195
name=name,
196196
input_json_schema=to_json_schema(IndexerRequest),
197197
output_json_schema=to_json_schema(None),
@@ -222,18 +222,20 @@ async def wrapper(
222222
request: RetrieverRequest,
223223
ctx: Any,
224224
) -> RetrieverResponse:
225-
return await fn(request.query, request.options)
225+
query = Document.from_document_data(request.query)
226+
res = fn(query, request.options)
227+
return await res if inspect.isawaitable(res) else res
226228

227229
registry.register_action(
228-
kind=ActionKind.RETRIEVER,
230+
kind=cast(ActionKind, ActionKind.RETRIEVER),
229231
name=name,
230232
fn=wrapper,
231233
metadata=metadata.metadata,
232234
span_metadata=metadata.metadata,
233235
)
234236

235237

236-
IndexerFn = Callable[[list[Document], T], None]
238+
IndexerFn = Callable[[list[Document], T], None | Awaitable[None]]
237239

238240

239241
def define_indexer(
@@ -249,11 +251,13 @@ async def wrapper(
249251
request: IndexerRequest,
250252
ctx: Any,
251253
) -> None:
252-
docs = [Document.from_data(d) for d in request.documents]
253-
await fn(docs, request.options)
254+
docs = [Document.from_document_data(d) for d in request.documents]
255+
res = fn(docs, request.options)
256+
if inspect.isawaitable(res):
257+
await res
254258

255259
registry.register_action(
256-
kind=ActionKind.INDEXER,
260+
kind=cast(ActionKind, ActionKind.INDEXER),
257261
name=name,
258262
fn=wrapper,
259263
metadata=metadata.metadata,

py/packages/genkit/src/genkit/blocks/tools.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,17 @@
2121
allowing for controlled interruptions and specific response formatting.
2222
"""
2323

24-
from typing import Any
24+
from typing import Any, Protocol, cast
2525

2626
from genkit.core.action import ActionRunContext
27-
from genkit.core.typing import Metadata, Part, ToolRequestPart, ToolResponse
27+
from genkit.core.typing import Metadata, Part, ToolRequestPart, ToolResponse, ToolResponsePart
28+
29+
30+
class ToolRequestLike(Protocol):
31+
"""Protocol for objects that look like a ToolRequest."""
32+
33+
name: str
34+
ref: str | None
2835

2936

3037
class ToolRunContext(ActionRunContext):
@@ -71,13 +78,13 @@ class ToolInterruptError(Exception):
7178
causing a hard failure.
7279
"""
7380

74-
def __init__(self, metadata: dict[str, Any]):
81+
def __init__(self, metadata: dict[str, Any] | None = None):
7582
"""Initializes the ToolInterruptError.
7683
7784
Args:
7885
metadata: Metadata associated with the interruption.
7986
"""
80-
self.metadata = metadata
87+
self.metadata = metadata or {}
8188

8289

8390
def tool_response(
@@ -110,13 +117,18 @@ def tool_response(
110117
elif metadata:
111118
interrupt_metadata = metadata
112119

120+
tr = cast(ToolRequestLike, tool_request)
113121
return Part(
114-
tool_response=ToolResponse(
115-
name=tool_request.name,
116-
ref=tool_request.ref,
117-
output=response_data,
118-
),
119-
metadata={
120-
'interruptResponse': interrupt_metadata,
121-
},
122+
root=ToolResponsePart(
123+
toolResponse=ToolResponse(
124+
name=tr.name,
125+
ref=tr.ref,
126+
output=response_data,
127+
),
128+
metadata=Metadata(
129+
root={
130+
'interruptResponse': interrupt_metadata,
131+
}
132+
),
133+
)
122134
)

py/packages/genkit/src/genkit/core/schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from pydantic import TypeAdapter
2222

2323

24-
def to_json_schema(schema: type | dict[str, Any]) -> dict[str, Any]:
24+
def to_json_schema(schema: type | dict[str, Any] | None) -> dict[str, Any]:
2525
"""Converts a Python type to a JSON schema.
2626
2727
If the input `schema` is already a dictionary (assumed json schema), it is

0 commit comments

Comments
 (0)