Skip to content

Commit 3b1d556

Browse files
committed
fix(py/genkit): ty check fixes for genkit.ai
1 parent 87bc74d commit 3b1d556

File tree

3 files changed

+54
-45
lines changed

3 files changed

+54
-45
lines changed

py/packages/genkit/src/genkit/ai/_aio.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@
2020
class while customizing it with any plugins.
2121
"""
2222

23+
import asyncio
2324
import uuid
24-
from asyncio import Future
2525
from collections.abc import AsyncIterator
2626
from pathlib import Path
27-
from typing import Any
27+
from typing import Any, cast
2828

2929
from genkit.aio import Channel
3030
from genkit.blocks.document import Document
@@ -253,7 +253,7 @@ def generate_stream(
253253
messages: list[Message] | None = None,
254254
tools: list[str] | None = None,
255255
return_tool_requests: bool | None = None,
256-
tool_choice: ToolChoice = None,
256+
tool_choice: ToolChoice | None = None,
257257
config: GenerationCommonConfig | dict[str, Any] | None = None,
258258
max_turns: int | None = None,
259259
context: dict[str, Any] | None = None,
@@ -268,7 +268,7 @@ def generate_stream(
268268
timeout: float | None = None,
269269
) -> tuple[
270270
AsyncIterator[GenerateResponseChunkWrapper],
271-
Future[GenerateResponseWrapper],
271+
asyncio.Future[GenerateResponseWrapper],
272272
]:
273273
"""Streams generated text or structured data using a language model.
274274
@@ -351,7 +351,7 @@ def generate_stream(
351351
use=use,
352352
on_chunk=lambda c: stream.send(c),
353353
)
354-
stream.set_close_future(resp)
354+
stream.set_close_future(asyncio.create_task(resp))
355355

356356
return stream, stream.closed
357357

@@ -389,19 +389,20 @@ async def retrieve(
389389

390390
request_options = {**(retriever_config or {}), **(options or {})}
391391

392-
retrieve_action = await self.registry.resolve_action(ActionKind.RETRIEVER, retriever_name)
392+
retrieve_action = await self.registry.resolve_action(cast(ActionKind, ActionKind.RETRIEVER), retriever_name)
393393
if retrieve_action is None:
394394
raise ValueError(f'Retriever "{retriever_name}" not found')
395395

396396
return (
397397
await retrieve_action.arun(
398398
RetrieverRequest(
399-
query=query,
399+
query=query, # type: ignore[arg-type]
400400
options=request_options if request_options else None,
401401
)
402402
)
403403
).response
404404

405+
405406
async def index(
406407
self,
407408
indexer: str | IndexerRef | None = None,
@@ -430,13 +431,16 @@ async def index(
430431

431432
req_options = {**(indexer_config or {}), **(options or {})}
432433

433-
index_action = await self.registry.resolve_action(ActionKind.INDEXER, indexer_name)
434+
index_action = await self.registry.resolve_action(cast(ActionKind, ActionKind.INDEXER), indexer_name)
434435
if index_action is None:
435436
raise ValueError(f'Indexer "{indexer_name}" not found')
436437

438+
if documents is None:
439+
raise ValueError('Documents must be specified for indexing.')
440+
437441
await index_action.arun(
438442
IndexerRequest(
439-
documents=documents,
443+
documents=documents, # type: ignore[arg-type]
440444
options=req_options if req_options else None,
441445
)
442446
)
@@ -464,11 +468,14 @@ async def embed(
464468
# Merge options passed to embed() with config from EmbedderRef
465469
final_options = {**(embedder_config or {}), **(options or {})}
466470

467-
embed_action = await self.registry.resolve_action(ActionKind.EMBEDDER, embedder_name)
471+
embed_action = await self.registry.resolve_action(cast(ActionKind, ActionKind.EMBEDDER), embedder_name)
468472
if embed_action is None:
469473
raise ValueError(f'Embedder "{embedder_name}" not found')
470474

471-
return (await embed_action.arun(EmbedRequest(input=documents, options=final_options))).response
475+
if documents is None:
476+
raise ValueError('Documents must be specified for embedding.')
477+
478+
return (await embed_action.arun(EmbedRequest(input=documents, options=final_options))).response # type: ignore[arg-type]
472479

473480
async def evaluate(
474481
self,
@@ -501,19 +508,22 @@ async def evaluate(
501508

502509
final_options = {**(evaluator_config or {}), **(options or {})}
503510

504-
eval_action = await self.registry.resolve_action(ActionKind.EVALUATOR, evaluator_name)
511+
eval_action = await self.registry.resolve_action(cast(ActionKind, ActionKind.EVALUATOR), evaluator_name)
505512
if eval_action is None:
506513
raise ValueError(f'Evaluator "{evaluator_name}" not found')
507514

508515
if not eval_run_id:
509516
eval_run_id = str(uuid.uuid4())
510517

518+
if dataset is None:
519+
raise ValueError('Dataset must be specified for evaluation.')
520+
511521
return (
512522
await eval_action.arun(
513523
EvalRequest(
514524
dataset=dataset,
515525
options=final_options,
516-
eval_run_id=eval_run_id,
526+
evalRunId=eval_run_id,
517527
)
518528
)
519529
).response

py/packages/genkit/src/genkit/ai/_registry.py

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
import uuid
4444
from collections.abc import AsyncIterator, Callable
4545
from functools import wraps
46-
from typing import TYPE_CHECKING, Any
46+
from typing import TYPE_CHECKING, Any, cast
4747

4848
if TYPE_CHECKING:
4949
from genkit.blocks.resource import ResourceFn, ResourceOptions
@@ -144,11 +144,11 @@ def wrapper(func: Callable) -> Callable:
144144
Returns:
145145
The wrapped function that executes the flow.
146146
"""
147-
flow_name = name if name is not None else func.__name__
147+
flow_name = name if name is not None else getattr(func, '__name__', 'unnamed_flow')
148148
flow_description = get_func_description(func, description)
149149
action = self.registry.register_action(
150150
name=flow_name,
151-
kind=ActionKind.FLOW,
151+
kind=cast(ActionKind, ActionKind.FLOW),
152152
fn=func,
153153
description=flow_description,
154154
span_metadata={'genkit:metadata:flow:name': flow_name},
@@ -257,7 +257,7 @@ def wrapper(func: Callable) -> Callable:
257257
Returns:
258258
The wrapped function that executes the tool.
259259
"""
260-
tool_name = name if name is not None else func.__name__
260+
tool_name = name if name is not None else getattr(func, '__name__', 'unnamed_tool')
261261
tool_description = get_func_description(func, description)
262262

263263
input_spec = inspect.getfullargspec(func)
@@ -275,7 +275,7 @@ def tool_fn_wrapper(*args):
275275

276276
action = self.registry.register_action(
277277
name=tool_name,
278-
kind=ActionKind.TOOL,
278+
kind=cast(ActionKind, ActionKind.TOOL),
279279
description=tool_description,
280280
fn=tool_fn_wrapper,
281281
metadata_fn=func,
@@ -315,10 +315,10 @@ def define_retriever(
315315
self,
316316
name: str,
317317
fn: RetrieverFn,
318-
config_schema: BaseModel | dict[str, Any] | None = None,
318+
config_schema: type[BaseModel] | dict[str, Any] | None = None,
319319
metadata: dict[str, Any] | None = None,
320320
description: str | None = None,
321-
) -> Callable[[Callable], Callable]:
321+
) -> Action:
322322
"""Define a retriever action.
323323
324324
Args:
@@ -339,7 +339,7 @@ def define_retriever(
339339
retriever_description = get_func_description(fn, description)
340340
return self.registry.register_action(
341341
name=name,
342-
kind=ActionKind.RETRIEVER,
342+
kind=cast(ActionKind, ActionKind.RETRIEVER),
343343
fn=fn,
344344
metadata=retriever_meta,
345345
description=retriever_description,
@@ -349,10 +349,10 @@ def define_indexer(
349349
self,
350350
name: str,
351351
fn: IndexerFn,
352-
config_schema: BaseModel | dict[str, Any] | None = None,
352+
config_schema: type[BaseModel] | dict[str, Any] | None = None,
353353
metadata: dict[str, Any] | None = None,
354354
description: str | None = None,
355-
) -> Callable[[Callable], Callable]:
355+
) -> Action:
356356
"""Define an indexer action.
357357
358358
Args:
@@ -374,7 +374,7 @@ def define_indexer(
374374
indexer_description = get_func_description(fn, description)
375375
return self.registry.register_action(
376376
name=name,
377-
kind=ActionKind.INDEXER,
377+
kind=cast(ActionKind, ActionKind.INDEXER),
378378
fn=fn,
379379
metadata=indexer_meta,
380380
description=indexer_description,
@@ -384,7 +384,7 @@ def define_reranker(
384384
self,
385385
name: str,
386386
fn: RerankerFn,
387-
config_schema: BaseModel | dict[str, Any] | None = None,
387+
config_schema: type[BaseModel] | dict[str, Any] | None = None,
388388
metadata: dict[str, Any] | None = None,
389389
description: str | None = None,
390390
) -> Action:
@@ -426,7 +426,7 @@ def define_reranker(
426426
name=name,
427427
fn=fn,
428428
options=RerankerOptions(
429-
config_schema=reranker_meta['reranker'].get('customOptions'),
429+
configSchema=reranker_meta['reranker'].get('customOptions'),
430430
label=reranker_meta['reranker'].get('label'),
431431
),
432432
description=reranker_description,
@@ -482,7 +482,7 @@ def define_evaluator(
482482
definition: str,
483483
fn: EvaluatorFn,
484484
is_billed: bool = False,
485-
config_schema: BaseModel | dict[str, Any] | None = None,
485+
config_schema: type[BaseModel] | dict[str, Any] | None = None,
486486
metadata: dict[str, Any] | None = None,
487487
description: str | None = None,
488488
) -> Action:
@@ -544,13 +544,13 @@ async def eval_stepper_fn(req: EvalRequest) -> EvalResponse:
544544
logger.debug(traceback.format_exc())
545545
evaluation = Score(
546546
error=f'Evaluation of test case {datapoint.test_case_id} failed: \n{str(e)}',
547-
status=EvalStatusEnum.FAIL,
547+
status=cast(EvalStatusEnum, EvalStatusEnum.FAIL),
548548
)
549549
eval_responses.append(
550550
EvalFnResponse(
551-
span_id=span_id,
552-
trace_id=trace_id,
553-
test_case_id=datapoint.test_case_id,
551+
spanId=span_id,
552+
traceId=trace_id,
553+
testCaseId=datapoint.test_case_id,
554554
evaluation=evaluation,
555555
)
556556
)
@@ -566,11 +566,11 @@ async def eval_stepper_fn(req: EvalRequest) -> EvalResponse:
566566
logger.debug(traceback.format_exc())
567567
evaluation = Score(
568568
error=f'Evaluation of test case {datapoint.test_case_id} failed: \n{str(e)}',
569-
status=EvalStatusEnum.FAIL,
569+
status=cast(EvalStatusEnum, EvalStatusEnum.FAIL),
570570
)
571571
eval_responses.append(
572572
EvalFnResponse(
573-
test_case_id=datapoint.test_case_id,
573+
testCaseId=datapoint.test_case_id,
574574
evaluation=evaluation,
575575
)
576576
)
@@ -581,7 +581,7 @@ async def eval_stepper_fn(req: EvalRequest) -> EvalResponse:
581581

582582
return self.registry.register_action(
583583
name=name,
584-
kind=ActionKind.EVALUATOR,
584+
kind=cast(ActionKind, ActionKind.EVALUATOR),
585585
fn=eval_stepper_fn,
586586
metadata=evaluator_meta,
587587
description=evaluator_description,
@@ -594,10 +594,10 @@ def define_batch_evaluator(
594594
definition: str,
595595
fn: BatchEvaluatorFn,
596596
is_billed: bool = False,
597-
config_schema: BaseModel | dict[str, Any] | None = None,
597+
config_schema: type[BaseModel] | dict[str, Any] | None = None,
598598
metadata: dict[str, Any] | None = None,
599599
description: str | None = None,
600-
) -> Callable[[Callable], Callable]:
600+
) -> Action:
601601
"""Define a batch evaluator action.
602602
603603
This action runs the callback function on the entire dataset.
@@ -627,7 +627,7 @@ def define_batch_evaluator(
627627
evaluator_description = get_func_description(fn, description)
628628
return self.registry.register_action(
629629
name=name,
630-
kind=ActionKind.EVALUATOR,
630+
kind=cast(ActionKind, ActionKind.EVALUATOR),
631631
fn=fn,
632632
metadata=evaluator_meta,
633633
description=evaluator_description,
@@ -666,7 +666,7 @@ def define_model(
666666
model_description = get_func_description(fn, description)
667667
return self.registry.register_action(
668668
name=name,
669-
kind=ActionKind.MODEL,
669+
kind=cast(ActionKind, ActionKind.MODEL),
670670
fn=fn,
671671
metadata=model_meta,
672672
description=model_description,
@@ -706,7 +706,7 @@ def define_embedder(
706706
embedder_description = get_func_description(fn, description)
707707
return self.registry.register_action(
708708
name=name,
709-
kind=ActionKind.EMBEDDER,
709+
kind=cast(ActionKind, ActionKind.EMBEDDER),
710710
fn=fn,
711711
metadata=embedder_meta,
712712
description=embedder_description,

py/packages/genkit/src/genkit/blocks/evaluator.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
model.
2222
"""
2323

24-
from collections.abc import Callable
24+
from collections.abc import Awaitable, Callable, Coroutine
2525
from typing import Any, TypeVar
2626

2727
from pydantic import BaseModel, ConfigDict, Field
@@ -35,12 +35,11 @@
3535
T = TypeVar('T')
3636

3737
# User-provided evaluator function that evaluates a single datapoint.
38-
# type EvaluatorFn[T] = Callable[[BaseDataPoint, T], EvalFnResponse]
39-
EvaluatorFn = Callable[[BaseDataPoint, T], EvalFnResponse]
38+
# Must be async (coroutine function).
39+
EvaluatorFn = Callable[[BaseDataPoint, T], Coroutine[Any, Any, EvalFnResponse]]
4040

4141
# User-provided batch evaluator function that evaluates an EvaluationRequest
42-
# type BatchEvaluatorFn[T] = Callable[[EvalRequest, T], list[EvalFnResponse]]
43-
BatchEvaluatorFn = Callable[[EvalRequest, T], list[EvalFnResponse]]
42+
BatchEvaluatorFn = Callable[[EvalRequest, T], Coroutine[Any, Any, list[EvalFnResponse]]]
4443

4544

4645
class EvaluatorRef(BaseModel):
@@ -62,4 +61,4 @@ def evaluator_ref(name: str, config_schema: Any | None = None) -> EvaluatorRef:
6261
Returns:
6362
An EvaluatorRef instance.
6463
"""
65-
return EvaluatorRef(name=name, config_schema=config_schema)
64+
return EvaluatorRef(name=name, configSchema=config_schema)

0 commit comments

Comments
 (0)