Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions lib/crewai/src/crewai/cli/constants.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from typing import Any


DEFAULT_CREWAI_ENTERPRISE_URL = "https://app.crewai.com"
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_PROVIDER = "workos"
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE = "client_01JNJQWBJ4SPFN3SWJM5T7BDG8"
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID = "client_01JYT06R59SP0NXYGD994NFXXX"
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN = "login.crewai.com"

ENV_VARS = {
ENV_VARS: dict[str, list[dict[str, Any]]] = {
"openai": [
{
"prompt": "Enter your OPENAI API key (press Enter to skip)",
Expand Down Expand Up @@ -112,7 +115,7 @@
}


PROVIDERS = [
PROVIDERS: list[str] = [
"openai",
"anthropic",
"gemini",
Expand All @@ -127,7 +130,7 @@
"sambanova",
]

MODELS = {
MODELS: dict[str, list[str]] = {
"openai": [
"gpt-4",
"gpt-4.1",
Expand Down
50 changes: 46 additions & 4 deletions lib/crewai/src/crewai/cli/create_crew.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import sys

import click
import tomli

from crewai.cli.constants import ENV_VARS, MODELS
from crewai.cli.provider import (
Expand All @@ -13,7 +14,31 @@
from crewai.cli.utils import copy_template, load_env_vars, write_env_file


def create_folder_structure(name, parent_folder=None):
def get_reserved_script_names() -> set[str]:
"""Get reserved script names from pyproject.toml template.

Returns:
Set of reserved script names that would conflict with crew folder names.
"""
package_dir = Path(__file__).parent
template_path = package_dir / "templates" / "crew" / "pyproject.toml"

with open(template_path, "r") as f:
template_content = f.read()

template_content = template_content.replace("{{folder_name}}", "_placeholder_")
template_content = template_content.replace("{{name}}", "placeholder")
template_content = template_content.replace("{{crew_name}}", "Placeholder")

template_data = tomli.loads(template_content)
script_names = set(template_data.get("project", {}).get("scripts", {}).keys())
script_names.discard("_placeholder_")
return script_names


def create_folder_structure(
name: str, parent_folder: str | None = None
) -> tuple[Path, str, str]:
import keyword
import re

Expand Down Expand Up @@ -51,6 +76,14 @@ def create_folder_structure(name, parent_folder=None):
f"Project name '{name}' would generate invalid Python module name '{folder_name}'"
)

reserved_names = get_reserved_script_names()
if folder_name in reserved_names:
raise ValueError(
f"Project name '{name}' would generate folder name '{folder_name}' which is reserved. "
f"Reserved names are: {', '.join(sorted(reserved_names))}. "
"Please choose a different name."
)

class_name = name.replace("_", " ").replace("-", " ").title().replace(" ", "")

class_name = re.sub(r"[^a-zA-Z0-9_]", "", class_name)
Expand Down Expand Up @@ -114,7 +147,9 @@ def create_folder_structure(name, parent_folder=None):
return folder_path, folder_name, class_name


def copy_template_files(folder_path, name, class_name, parent_folder):
def copy_template_files(
folder_path: Path, name: str, class_name: str, parent_folder: str | None
) -> None:
package_dir = Path(__file__).parent
templates_dir = package_dir / "templates" / "crew"

Expand Down Expand Up @@ -155,7 +190,12 @@ def copy_template_files(folder_path, name, class_name, parent_folder):
copy_template(src_file, dst_file, name, class_name, folder_path.name)


def create_crew(name, provider=None, skip_provider=False, parent_folder=None):
def create_crew(
name: str,
provider: str | None = None,
skip_provider: bool = False,
parent_folder: str | None = None,
) -> None:
folder_path, folder_name, class_name = create_folder_structure(name, parent_folder)
env_vars = load_env_vars(folder_path)
if not skip_provider:
Expand Down Expand Up @@ -189,7 +229,9 @@ def create_crew(name, provider=None, skip_provider=False, parent_folder=None):
if selected_provider is None: # User typed 'q'
click.secho("Exiting...", fg="yellow")
sys.exit(0)
if selected_provider: # Valid selection
if selected_provider and isinstance(
selected_provider, str
): # Valid selection
break
click.secho(
"No provider selected. Please try again or press 'q' to exit.", fg="red"
Expand Down
97 changes: 49 additions & 48 deletions lib/crewai/src/crewai/cli/provider.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from collections import defaultdict
from collections.abc import Sequence
import json
import os
from pathlib import Path
import time
from typing import Any

import certifi
import click
Expand All @@ -11,16 +13,15 @@
from crewai.cli.constants import JSON_URL, MODELS, PROVIDERS


def select_choice(prompt_message, choices):
"""
Presents a list of choices to the user and prompts them to select one.
def select_choice(prompt_message: str, choices: Sequence[str]) -> str | None:
"""Presents a list of choices to the user and prompts them to select one.

Args:
- prompt_message (str): The message to display to the user before presenting the choices.
- choices (list): A list of options to present to the user.
prompt_message: The message to display to the user before presenting the choices.
choices: A list of options to present to the user.

Returns:
- str: The selected choice from the list, or None if the user chooses to quit.
The selected choice from the list, or None if the user chooses to quit.
"""

provider_models = get_provider_data()
Expand Down Expand Up @@ -52,16 +53,14 @@ def select_choice(prompt_message, choices):
)


def select_provider(provider_models):
"""
Presents a list of providers to the user and prompts them to select one.
def select_provider(provider_models: dict[str, list[str]]) -> str | None | bool:
"""Presents a list of providers to the user and prompts them to select one.

Args:
- provider_models (dict): A dictionary of provider models.
provider_models: A dictionary of provider models.

Returns:
- str: The selected provider
- None: If user explicitly quits
The selected provider, None if user explicitly quits, or False if no selection.
"""
predefined_providers = [p.lower() for p in PROVIDERS]
all_providers = sorted(set(predefined_providers + list(provider_models.keys())))
Expand All @@ -80,16 +79,15 @@ def select_provider(provider_models):
return provider.lower() if provider else False


def select_model(provider, provider_models):
"""
Presents a list of models for a given provider to the user and prompts them to select one.
def select_model(provider: str, provider_models: dict[str, list[str]]) -> str | None:
"""Presents a list of models for a given provider to the user and prompts them to select one.

Args:
- provider (str): The provider for which to select a model.
- provider_models (dict): A dictionary of provider models.
provider: The provider for which to select a model.
provider_models: A dictionary of provider models.

Returns:
- str: The selected model, or None if the operation is aborted or an invalid selection is made.
The selected model, or None if the operation is aborted or an invalid selection is made.
"""
predefined_providers = [p.lower() for p in PROVIDERS]

Expand All @@ -107,16 +105,17 @@ def select_model(provider, provider_models):
)


def load_provider_data(cache_file, cache_expiry):
"""
Loads provider data from a cache file if it exists and is not expired. If the cache is expired or corrupted, it fetches the data from the web.
def load_provider_data(cache_file: Path, cache_expiry: int) -> dict[str, Any] | None:
"""Loads provider data from a cache file if it exists and is not expired.

If the cache is expired or corrupted, it fetches the data from the web.

Args:
- cache_file (Path): The path to the cache file.
- cache_expiry (int): The cache expiry time in seconds.
cache_file: The path to the cache file.
cache_expiry: The cache expiry time in seconds.

Returns:
- dict or None: The loaded provider data or None if the operation fails.
The loaded provider data or None if the operation fails.
"""
current_time = time.time()
if (
Expand All @@ -137,32 +136,31 @@ def load_provider_data(cache_file, cache_expiry):
return fetch_provider_data(cache_file)


def read_cache_file(cache_file):
"""
Reads and returns the JSON content from a cache file. Returns None if the file contains invalid JSON.
def read_cache_file(cache_file: Path) -> dict[str, Any] | None:
"""Reads and returns the JSON content from a cache file.

Args:
- cache_file (Path): The path to the cache file.
cache_file: The path to the cache file.

Returns:
- dict or None: The JSON content of the cache file or None if the JSON is invalid.
The JSON content of the cache file or None if the JSON is invalid.
"""
try:
with open(cache_file, "r") as f:
return json.load(f)
data: dict[str, Any] = json.load(f)
return data
except json.JSONDecodeError:
return None


def fetch_provider_data(cache_file):
"""
Fetches provider data from a specified URL and caches it to a file.
def fetch_provider_data(cache_file: Path) -> dict[str, Any] | None:
"""Fetches provider data from a specified URL and caches it to a file.

Args:
- cache_file (Path): The path to the cache file.
cache_file: The path to the cache file.

Returns:
- dict or None: The fetched provider data or None if the operation fails.
The fetched provider data or None if the operation fails.
"""
ssl_config = os.environ["SSL_CERT_FILE"] = certifi.where()

Expand All @@ -180,36 +178,39 @@ def fetch_provider_data(cache_file):
return None


def download_data(response):
"""
Downloads data from a given HTTP response and returns the JSON content.
def download_data(response: requests.Response) -> dict[str, Any]:
"""Downloads data from a given HTTP response and returns the JSON content.

Args:
- response (requests.Response): The HTTP response object.
response: The HTTP response object.

Returns:
- dict: The JSON content of the response.
The JSON content of the response.
"""
total_size = int(response.headers.get("content-length", 0))
block_size = 8192
data_chunks = []
data_chunks: list[bytes] = []
bar: Any
with click.progressbar(
length=total_size, label="Downloading", show_pos=True
) as progress_bar:
) as bar:
for chunk in response.iter_content(block_size):
if chunk:
data_chunks.append(chunk)
progress_bar.update(len(chunk))
bar.update(len(chunk))
data_content = b"".join(data_chunks)
return json.loads(data_content.decode("utf-8"))
result: dict[str, Any] = json.loads(data_content.decode("utf-8"))
return result


def get_provider_data():
"""
Retrieves provider data from a cache file, filters out models based on provider criteria, and returns a dictionary of providers mapped to their models.
def get_provider_data() -> dict[str, list[str]] | None:
"""Retrieves provider data from a cache file.

Filters out models based on provider criteria, and returns a dictionary of providers
mapped to their models.

Returns:
- dict or None: A dictionary of providers mapped to their models or None if the operation fails.
A dictionary of providers mapped to their models or None if the operation fails.
"""
cache_dir = Path.home() / ".crewai"
cache_dir.mkdir(exist_ok=True)
Expand Down
17 changes: 17 additions & 0 deletions lib/crewai/tests/cli/test_create_crew.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,23 @@ def test_create_folder_structure_folder_name_validation():
shutil.rmtree(folder_path)


def test_create_folder_structure_rejects_reserved_names():
"""Test that reserved script names are rejected to prevent pyproject.toml conflicts."""
with tempfile.TemporaryDirectory() as temp_dir:
reserved_names = ["test", "train", "replay", "run_crew", "run_with_trigger"]

for reserved_name in reserved_names:
with pytest.raises(ValueError, match="which is reserved"):
create_folder_structure(reserved_name, parent_folder=temp_dir)

with pytest.raises(ValueError, match="which is reserved"):
create_folder_structure(f"{reserved_name}/", parent_folder=temp_dir)

capitalized = reserved_name.capitalize()
with pytest.raises(ValueError, match="which is reserved"):
create_folder_structure(capitalized, parent_folder=temp_dir)


@mock.patch("crewai.cli.create_crew.create_folder_structure")
@mock.patch("crewai.cli.create_crew.copy_template")
@mock.patch("crewai.cli.create_crew.load_env_vars")
Expand Down
Loading