-
Notifications
You must be signed in to change notification settings - Fork 811
Feature: AWS bedrock generator #1462
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
jmartin-tech
merged 13 commits into
NVIDIA:main
from
fabriziorocco:feature/bedrock_generator
Nov 26, 2025
Merged
Changes from all commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
fe20c7f
bedrock
fabriziorocco 2d0f456
Add files via upload
fabriziorocco 8dd4329
added bedrock generator command
fabriziorocco 2f70bbb
Update BedrockGenerator tests for new model names
fabriziorocco 09c5f35
Imported boto3 for Bedrock Generator
fabriziorocco 1e24c76
Add bedrock doc
fabriziorocco 8b81ebf
Tests for Bedrock generator fix
fabriziorocco 5392793
add doc
fabriziorocco e040348
fix static code review
fabriziorocco bfe6965
fix value error tests
fabriziorocco 3b3c8aa
Update garak/generators/bedrock.py
fabriziorocco f9fc4f6
Update garak/generators/bedrock.py
fabriziorocco ff9aab0
fixed
fabriziorocco File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| garak.generators.bedrock | ||
| ========================= | ||
|
|
||
| .. automodule:: garak.generators.bedrock | ||
| :members: | ||
| :undoc-members: | ||
| :show-inheritance: |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,265 @@ | ||
| """AWS Bedrock generator | ||
|
|
||
| Supports foundation models available through AWS Bedrock using standard AWS authentication. | ||
|
|
||
| To get started with this generator: | ||
|
|
||
| #. Visit https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html | ||
| to see available models | ||
| #. Set up AWS credentials: https://docs.aws.amazon.com/cli/latest/userguide/cli-configure-files.html or a Bedrock API Key | ||
| #. Run garak with --target_type bedrock and --target_name <model-id> | ||
|
|
||
| """ | ||
|
|
||
| import logging | ||
| import os | ||
| import re | ||
| from typing import List, Union | ||
|
|
||
| import backoff | ||
|
|
||
| from garak import _config | ||
| from garak.attempt import Message, Conversation | ||
| import garak.exception | ||
| from garak.generators.base import Generator | ||
|
|
||
| MODEL_ALIASES = { | ||
| "claude-4-5-haiku": "global.anthropic.claude-haiku-4-5-20251001-v1:0", | ||
| "claude-4-5-sonnet": "global.anthropic.claude-sonnet-4-5-20250929-v1:0", | ||
| "claude-4-1-opus": "us.anthropic.claude-opus-4-1-20250805-v1:0", # US Inference Endpoint | ||
| "claude-4-opus": "us.anthropic.claude-opus-4-20250514-v1:0", # US Inference Endpoint | ||
| "claude-4-sonnet": "global.anthropic.claude-sonnet-4-20250514-v1:0", | ||
| "nova-premier": "us.amazon.nova-premier-v1:0", # US Inference Endpoint | ||
| "nova-pro": "us.amazon.nova-pro-v1:0", # US Inference Endpoint | ||
| "nova-lite": "us.amazon.nova-lite-v1:0", # US Inference Endpoint | ||
| "nova-micro": "us.amazon.nova-micro-v1:0", # US Inference Endpoint | ||
| } | ||
|
|
||
|
|
||
| class BedrockGenerator(Generator): | ||
| """Interface for AWS Bedrock foundation models using Converse API""" | ||
|
|
||
| active = True | ||
| generator_family_name = "Bedrock" | ||
| supports_multiple_generations = False | ||
|
|
||
| DEFAULT_PARAMS = Generator.DEFAULT_PARAMS | { | ||
| "temperature": 0.7, | ||
| "top_p": 1.0, | ||
| "stop": [], | ||
| "region": "us-east-1", | ||
| } | ||
|
|
||
| def __init__(self, name="", config_root=_config): | ||
| """Initialize the Bedrock generator. | ||
|
|
||
| Args: | ||
| name: Model name or alias (e.g., "claude-3-sonnet" or full model ID) | ||
| config_root: Configuration root object | ||
| """ | ||
| self.name = name | ||
| self._load_config(config_root) | ||
|
|
||
| if self.name in MODEL_ALIASES: | ||
| resolved_name = MODEL_ALIASES[self.name] | ||
| logging.info(f"Resolved model alias '{self.name}' to: {resolved_name}") | ||
| self.name = resolved_name | ||
|
|
||
| # Validate model ID format | ||
| if self.name: | ||
| # Check if model is in our known aliases (already resolved) | ||
| if self.name in MODEL_ALIASES.values(): | ||
| # Valid known model | ||
| pass | ||
| # Check if it's an ARN format | ||
| elif self.name.startswith("arn:aws:bedrock:"): | ||
| arn_pattern = r"^arn:aws:bedrock:[a-z0-9-]+:[0-9]+:inference-profile/[a-z0-9.:-]+$" | ||
| if not re.match(arn_pattern, self.name): | ||
| raise ValueError( | ||
| f"Model ID '{self.name}' appears to be an ARN but is not in the correct format. " | ||
| f"Expected format: 'arn:aws:bedrock:region:account:inference-profile/model-id'" | ||
| ) | ||
| # Check if it matches standard Bedrock model ID format (must have provider.model structure) | ||
| elif "." in self.name: | ||
| # Standard model IDs should have format: provider.model-name or region.provider.model-name | ||
| bedrock_id_pattern = r"^([a-z0-9-]+\.)+[a-z0-9.:-]+$" | ||
| if not re.match(bedrock_id_pattern, self.name): | ||
| raise ValueError( | ||
| f"Model ID '{self.name}' does not appear to be a valid Bedrock model ID format. " | ||
| f"Expected format examples:\n" | ||
| f" - Model ID: 'anthropic.claude-v2' or 'us.amazon.nova-pro-v1:0'\n" | ||
| f" - Inference profile: 'us.anthropic.claude-4-1-sonnet-v2:0'\n" | ||
| f" - ARN: 'arn:aws:bedrock:region:account:inference-profile/model-id'" | ||
| ) | ||
| else: | ||
| # No dots and not an ARN | ||
| supported_aliases = ", ".join(sorted(MODEL_ALIASES.keys())) | ||
| raise ValueError( | ||
| f"Model ID '{self.name}' is not in the list of supported Bedrock models. " | ||
| f"Please use one of the known aliases: {supported_aliases}\n" | ||
| f"Or provide a full Bedrock model ID (e.g., 'anthropic.claude-v2' or 'us.amazon.nova-pro-v1:0')" | ||
| ) | ||
|
|
||
| super().__init__(self.name, config_root=config_root) | ||
| self._validate_env_var() | ||
| self._load_client() | ||
|
|
||
| def _validate_env_var(self): | ||
| """Validate and set region from environment variables if not configured. | ||
|
|
||
| Checks AWS_REGION and AWS_DEFAULT_REGION environment variables only if | ||
| the region parameter is still at its default value. | ||
| """ | ||
| if self.region == "us-east-1": | ||
| env_region = os.getenv("AWS_REGION") or os.getenv("AWS_DEFAULT_REGION") | ||
| if env_region: | ||
| logging.info(f"Using AWS region from environment: {env_region}") | ||
| self.region = env_region | ||
|
|
||
| return super()._validate_env_var() | ||
|
|
||
| def _load_client(self): | ||
| """Load and configure the boto3 bedrock-runtime client. | ||
|
|
||
| Uses boto3's standard credential chain for authentication. | ||
| """ | ||
| try: | ||
| import boto3 | ||
| except ImportError: | ||
| raise ImportError( | ||
| "boto3 is required for the Bedrock generator. " | ||
| "Install it with: pip install boto3" | ||
| ) | ||
|
|
||
| self.client = boto3.client( | ||
| service_name="bedrock-runtime", | ||
| region_name=self.region, | ||
| ) | ||
|
|
||
| logging.info(f"Loaded boto3 bedrock-runtime client for region {self.region}") | ||
|
|
||
| def _clear_client(self): | ||
| """Clear the boto3 client to enable object pickling.""" | ||
| if hasattr(self, "client"): | ||
| self.client = None | ||
|
|
||
| def __getstate__(self): | ||
| """Prepare object for pickling by clearing the boto3 client.""" | ||
| self._clear_client() | ||
| return self.__dict__ | ||
|
|
||
| def __setstate__(self, state): | ||
| """Restore object from pickle and reload the boto3 client.""" | ||
| self.__dict__.update(state) | ||
| self._load_client() | ||
|
|
||
| @staticmethod | ||
| def _conversation_to_list(conversation: Conversation) -> list[dict]: | ||
| """Convert Conversation object to Bedrock Converse API message format. | ||
|
|
||
| AWS Bedrock expects messages in the format: | ||
| {"role": "user", "content": [{"text": "message text"}]} | ||
|
|
||
| Args: | ||
| conversation: Conversation object to convert | ||
|
|
||
| Returns: | ||
| List of message dictionaries in Bedrock format | ||
| """ | ||
| turn_list = [ | ||
| {"role": turn.role, "content": [{"text": turn.content.text}]} | ||
| for turn in conversation.turns | ||
| ] | ||
| return turn_list | ||
|
|
||
| @backoff.on_exception( | ||
| backoff.fibo, | ||
| garak.exception.GarakBackoffTrigger, | ||
| max_value=70, | ||
| ) | ||
fabriziorocco marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| def _call_model( | ||
| self, prompt: Conversation, generations_this_call: int = 1 | ||
| ) -> List[Union[Message, None]]: | ||
| """Call the Bedrock model using the Converse API. | ||
|
|
||
| Args: | ||
| prompt: Conversation object containing the prompt turns | ||
| generations_this_call: Number of generations to request (currently only 1 is supported) | ||
|
|
||
| Returns: | ||
| List of Message objects containing the generated text, or [None] on error | ||
| """ | ||
| if self.client is None: | ||
| self._load_client() | ||
|
|
||
| messages = self._conversation_to_list(prompt) | ||
|
|
||
| if not messages: | ||
| logging.error("No valid messages to send to Bedrock") | ||
| return [None] | ||
|
|
||
| inference_config = {} | ||
| if self.temperature is not None: | ||
| inference_config["temperature"] = float(self.temperature) | ||
| if hasattr(self, "max_tokens") and self.max_tokens is not None: | ||
| inference_config["maxTokens"] = int(self.max_tokens) | ||
| if self.top_p is not None: | ||
| inference_config["topP"] = float(self.top_p) | ||
| if self.stop: | ||
| inference_config["stopSequences"] = self.stop | ||
|
|
||
| call_args = { | ||
| "modelId": self.name, | ||
| "messages": messages, | ||
| } | ||
| if inference_config: | ||
| call_args["inferenceConfig"] = inference_config | ||
|
|
||
| try: | ||
| response = self.client.converse(**call_args) | ||
|
|
||
| if not response or "output" not in response: | ||
| logging.error("Malformed response from Bedrock: missing 'output' field") | ||
| return [None] | ||
|
|
||
| if "message" not in response["output"]: | ||
| logging.error( | ||
| "Malformed response from Bedrock: missing 'message' in output" | ||
| ) | ||
| return [None] | ||
|
|
||
| message = response["output"]["message"] | ||
| if "content" not in message or not message["content"]: | ||
| logging.error( | ||
| "Malformed response from Bedrock: missing or empty 'content' in message" | ||
| ) | ||
| return [None] | ||
|
|
||
| content_block = message["content"][0] | ||
| if "text" not in content_block: | ||
| logging.error( | ||
| "Malformed response from Bedrock: missing 'text' in content block" | ||
| ) | ||
| return [None] | ||
|
|
||
| text = content_block["text"] | ||
| return [Message(text=text)] | ||
|
|
||
| except Exception as e: | ||
| from botocore.exceptions import ClientError | ||
|
|
||
| if isinstance(e, ClientError): | ||
| error_code = e.response.get("Error", {}).get("Code", "") | ||
| error_message = e.response.get("Error", {}).get("Message", "") | ||
|
|
||
| logging.error(f"Bedrock API error [{error_code}]: {error_message}") | ||
|
|
||
| if error_code in ["ThrottlingException", "ServiceUnavailableException"]: | ||
| raise garak.exception.GarakBackoffTrigger from e | ||
|
|
||
| return [None] | ||
|
|
||
| logging.exception("Error calling Bedrock model") | ||
| return [None] | ||
|
|
||
| DEFAULT_CLASS = "BedrockGenerator" | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry for the misdirection, I see this is a slightly different format than the what is used in the base class.