4343import uuid
4444from collections .abc import AsyncIterator , Callable
4545from functools import wraps
46- from typing import TYPE_CHECKING , Any
46+ from typing import TYPE_CHECKING , Any , cast
4747
4848if TYPE_CHECKING :
4949 from genkit .blocks .resource import ResourceFn , ResourceOptions
@@ -144,11 +144,11 @@ def wrapper(func: Callable) -> Callable:
144144 Returns:
145145 The wrapped function that executes the flow.
146146 """
147- flow_name = name if name is not None else func . __name__
147+ flow_name = name if name is not None else getattr ( func , ' __name__' , 'unnamed_flow' )
148148 flow_description = get_func_description (func , description )
149149 action = self .registry .register_action (
150150 name = flow_name ,
151- kind = ActionKind .FLOW ,
151+ kind = cast ( ActionKind , ActionKind .FLOW ) ,
152152 fn = func ,
153153 description = flow_description ,
154154 span_metadata = {'genkit:metadata:flow:name' : flow_name },
@@ -257,7 +257,7 @@ def wrapper(func: Callable) -> Callable:
257257 Returns:
258258 The wrapped function that executes the tool.
259259 """
260- tool_name = name if name is not None else func . __name__
260+ tool_name = name if name is not None else getattr ( func , ' __name__' , 'unnamed_tool' )
261261 tool_description = get_func_description (func , description )
262262
263263 input_spec = inspect .getfullargspec (func )
@@ -275,7 +275,7 @@ def tool_fn_wrapper(*args):
275275
276276 action = self .registry .register_action (
277277 name = tool_name ,
278- kind = ActionKind .TOOL ,
278+ kind = cast ( ActionKind , ActionKind .TOOL ) ,
279279 description = tool_description ,
280280 fn = tool_fn_wrapper ,
281281 metadata_fn = func ,
@@ -315,10 +315,10 @@ def define_retriever(
315315 self ,
316316 name : str ,
317317 fn : RetrieverFn ,
318- config_schema : BaseModel | dict [str , Any ] | None = None ,
318+ config_schema : type [ BaseModel ] | dict [str , Any ] | None = None ,
319319 metadata : dict [str , Any ] | None = None ,
320320 description : str | None = None ,
321- ) -> Callable [[ Callable ], Callable ] :
321+ ) -> Action :
322322 """Define a retriever action.
323323
324324 Args:
@@ -339,7 +339,7 @@ def define_retriever(
339339 retriever_description = get_func_description (fn , description )
340340 return self .registry .register_action (
341341 name = name ,
342- kind = ActionKind .RETRIEVER ,
342+ kind = cast ( ActionKind , ActionKind .RETRIEVER ) ,
343343 fn = fn ,
344344 metadata = retriever_meta ,
345345 description = retriever_description ,
@@ -349,10 +349,10 @@ def define_indexer(
349349 self ,
350350 name : str ,
351351 fn : IndexerFn ,
352- config_schema : BaseModel | dict [str , Any ] | None = None ,
352+ config_schema : type [ BaseModel ] | dict [str , Any ] | None = None ,
353353 metadata : dict [str , Any ] | None = None ,
354354 description : str | None = None ,
355- ) -> Callable [[ Callable ], Callable ] :
355+ ) -> Action :
356356 """Define an indexer action.
357357
358358 Args:
@@ -374,7 +374,7 @@ def define_indexer(
374374 indexer_description = get_func_description (fn , description )
375375 return self .registry .register_action (
376376 name = name ,
377- kind = ActionKind .INDEXER ,
377+ kind = cast ( ActionKind , ActionKind .INDEXER ) ,
378378 fn = fn ,
379379 metadata = indexer_meta ,
380380 description = indexer_description ,
@@ -384,7 +384,7 @@ def define_reranker(
384384 self ,
385385 name : str ,
386386 fn : RerankerFn ,
387- config_schema : BaseModel | dict [str , Any ] | None = None ,
387+ config_schema : type [ BaseModel ] | dict [str , Any ] | None = None ,
388388 metadata : dict [str , Any ] | None = None ,
389389 description : str | None = None ,
390390 ) -> Action :
@@ -426,7 +426,7 @@ def define_reranker(
426426 name = name ,
427427 fn = fn ,
428428 options = RerankerOptions (
429- config_schema = reranker_meta ['reranker' ].get ('customOptions' ),
429+ configSchema = reranker_meta ['reranker' ].get ('customOptions' ),
430430 label = reranker_meta ['reranker' ].get ('label' ),
431431 ),
432432 description = reranker_description ,
@@ -482,7 +482,7 @@ def define_evaluator(
482482 definition : str ,
483483 fn : EvaluatorFn ,
484484 is_billed : bool = False ,
485- config_schema : BaseModel | dict [str , Any ] | None = None ,
485+ config_schema : type [ BaseModel ] | dict [str , Any ] | None = None ,
486486 metadata : dict [str , Any ] | None = None ,
487487 description : str | None = None ,
488488 ) -> Action :
@@ -544,13 +544,13 @@ async def eval_stepper_fn(req: EvalRequest) -> EvalResponse:
544544 logger .debug (traceback .format_exc ())
545545 evaluation = Score (
546546 error = f'Evaluation of test case { datapoint .test_case_id } failed: \n { str (e )} ' ,
547- status = EvalStatusEnum .FAIL ,
547+ status = cast ( EvalStatusEnum , EvalStatusEnum .FAIL ) ,
548548 )
549549 eval_responses .append (
550550 EvalFnResponse (
551- span_id = span_id ,
552- trace_id = trace_id ,
553- test_case_id = datapoint .test_case_id ,
551+ spanId = span_id ,
552+ traceId = trace_id ,
553+ testCaseId = datapoint .test_case_id ,
554554 evaluation = evaluation ,
555555 )
556556 )
@@ -566,11 +566,11 @@ async def eval_stepper_fn(req: EvalRequest) -> EvalResponse:
566566 logger .debug (traceback .format_exc ())
567567 evaluation = Score (
568568 error = f'Evaluation of test case { datapoint .test_case_id } failed: \n { str (e )} ' ,
569- status = EvalStatusEnum .FAIL ,
569+ status = cast ( EvalStatusEnum , EvalStatusEnum .FAIL ) ,
570570 )
571571 eval_responses .append (
572572 EvalFnResponse (
573- test_case_id = datapoint .test_case_id ,
573+ testCaseId = datapoint .test_case_id ,
574574 evaluation = evaluation ,
575575 )
576576 )
@@ -581,7 +581,7 @@ async def eval_stepper_fn(req: EvalRequest) -> EvalResponse:
581581
582582 return self .registry .register_action (
583583 name = name ,
584- kind = ActionKind .EVALUATOR ,
584+ kind = cast ( ActionKind , ActionKind .EVALUATOR ) ,
585585 fn = eval_stepper_fn ,
586586 metadata = evaluator_meta ,
587587 description = evaluator_description ,
@@ -594,10 +594,10 @@ def define_batch_evaluator(
594594 definition : str ,
595595 fn : BatchEvaluatorFn ,
596596 is_billed : bool = False ,
597- config_schema : BaseModel | dict [str , Any ] | None = None ,
597+ config_schema : type [ BaseModel ] | dict [str , Any ] | None = None ,
598598 metadata : dict [str , Any ] | None = None ,
599599 description : str | None = None ,
600- ) -> Callable [[ Callable ], Callable ] :
600+ ) -> Action :
601601 """Define a batch evaluator action.
602602
603603 This action runs the callback function on the entire dataset.
@@ -627,7 +627,7 @@ def define_batch_evaluator(
627627 evaluator_description = get_func_description (fn , description )
628628 return self .registry .register_action (
629629 name = name ,
630- kind = ActionKind .EVALUATOR ,
630+ kind = cast ( ActionKind , ActionKind .EVALUATOR ) ,
631631 fn = fn ,
632632 metadata = evaluator_meta ,
633633 description = evaluator_description ,
@@ -666,7 +666,7 @@ def define_model(
666666 model_description = get_func_description (fn , description )
667667 return self .registry .register_action (
668668 name = name ,
669- kind = ActionKind .MODEL ,
669+ kind = cast ( ActionKind , ActionKind .MODEL ) ,
670670 fn = fn ,
671671 metadata = model_meta ,
672672 description = model_description ,
@@ -706,7 +706,7 @@ def define_embedder(
706706 embedder_description = get_func_description (fn , description )
707707 return self .registry .register_action (
708708 name = name ,
709- kind = ActionKind .EMBEDDER ,
709+ kind = cast ( ActionKind , ActionKind .EMBEDDER ) ,
710710 fn = fn ,
711711 metadata = embedder_meta ,
712712 description = embedder_description ,
0 commit comments