Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
200 changes: 112 additions & 88 deletions lib/crewai/src/crewai/llms/providers/openai/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -1696,14 +1696,107 @@ def _handle_completion(

return content

def _finalize_streaming_response(
self,
full_response: str,
tool_calls: dict[int, dict[str, Any]],
usage_data: dict[str, int],
params: dict[str, Any],
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> str | list[dict[str, Any]]:
"""Finalize a streaming response with usage tracking, tool call handling, and events.

Args:
full_response: The accumulated text response from the stream.
tool_calls: Accumulated tool calls from the stream, keyed by index.
usage_data: Token usage data from the stream.
params: The completion parameters containing messages.
available_functions: Available functions for tool calling.
from_task: Task that initiated the call.
from_agent: Agent that initiated the call.

Returns:
Tool calls list when tools were invoked without available_functions,
tool execution result when available_functions is provided,
or the text response string.
"""
self._track_token_usage_internal(usage_data)

if tool_calls and not available_functions:
tool_calls_list = [
{
"id": call_data["id"],
"type": "function",
"function": {
"name": call_data["name"],
"arguments": call_data["arguments"],
},
"index": call_data["index"],
}
for call_data in tool_calls.values()
]
self._emit_call_completed_event(
response=tool_calls_list,
call_type=LLMCallType.TOOL_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return tool_calls_list

if tool_calls and available_functions:
for call_data in tool_calls.values():
function_name = call_data["name"]
arguments = call_data["arguments"]

if not function_name or not arguments:
continue

if function_name not in available_functions:
logging.warning(
f"Function '{function_name}' not found in available functions"
)
continue

try:
function_args = json.loads(arguments)
except json.JSONDecodeError as e:
logging.error(f"Failed to parse streamed tool arguments: {e}")
continue

result = self._handle_tool_execution(
function_name=function_name,
function_args=function_args,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
)

if result is not None:
return result

full_response = self._apply_stop_words(full_response)

self._emit_call_completed_event(
response=full_response,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)

return full_response

def _handle_streaming_completion(
self,
params: dict[str, Any],
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str | BaseModel:
) -> str | list[dict[str, Any]] | BaseModel:
"""Handle streaming chat completion."""
full_response = ""
tool_calls: dict[int, dict[str, Any]] = {}
Expand Down Expand Up @@ -1820,54 +1913,20 @@ def _handle_streaming_completion(
response_id=response_id_stream,
)

self._track_token_usage_internal(usage_data)

if tool_calls and available_functions:
for call_data in tool_calls.values():
function_name = call_data["name"]
arguments = call_data["arguments"]

# Skip if function name is empty or arguments are empty
if not function_name or not arguments:
continue

# Check if function exists in available functions
if function_name not in available_functions:
logging.warning(
f"Function '{function_name}' not found in available functions"
)
continue

try:
function_args = json.loads(arguments)
except json.JSONDecodeError as e:
logging.error(f"Failed to parse streamed tool arguments: {e}")
continue

result = self._handle_tool_execution(
function_name=function_name,
function_args=function_args,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
)

if result is not None:
return result

full_response = self._apply_stop_words(full_response)

self._emit_call_completed_event(
response=full_response,
call_type=LLMCallType.LLM_CALL,
result = self._finalize_streaming_response(
full_response=full_response,
tool_calls=tool_calls,
usage_data=usage_data,
params=params,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)

return self._invoke_after_llm_call_hooks(
params["messages"], full_response, from_agent
)
if isinstance(result, str):
return self._invoke_after_llm_call_hooks(
params["messages"], result, from_agent
)
return result

async def _ahandle_completion(
self,
Expand Down Expand Up @@ -2016,7 +2075,7 @@ async def _ahandle_streaming_completion(
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str | BaseModel:
) -> str | list[dict[str, Any]] | BaseModel:
"""Handle async streaming chat completion."""
full_response = ""
tool_calls: dict[int, dict[str, Any]] = {}
Expand Down Expand Up @@ -2142,51 +2201,16 @@ async def _ahandle_streaming_completion(
response_id=response_id_stream,
)

self._track_token_usage_internal(usage_data)

if tool_calls and available_functions:
for call_data in tool_calls.values():
function_name = call_data["name"]
arguments = call_data["arguments"]

if not function_name or not arguments:
continue

if function_name not in available_functions:
logging.warning(
f"Function '{function_name}' not found in available functions"
)
continue

try:
function_args = json.loads(arguments)
except json.JSONDecodeError as e:
logging.error(f"Failed to parse streamed tool arguments: {e}")
continue

result = self._handle_tool_execution(
function_name=function_name,
function_args=function_args,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
)

if result is not None:
return result

full_response = self._apply_stop_words(full_response)

self._emit_call_completed_event(
response=full_response,
call_type=LLMCallType.LLM_CALL,
return self._finalize_streaming_response(
full_response=full_response,
tool_calls=tool_calls,
usage_data=usage_data,
params=params,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)

return full_response

def supports_function_calling(self) -> bool:
"""Check if the model supports function calling."""
return not self.is_o1_model
Expand Down
Loading
Loading