feature: OpenAICompatible classes does one generation at a time by default#1524
Conversation
OpenAICompatible classes does one generation at a time by default
jmartin-tech
left a comment
There was a problem hiding this comment.
Looks good, this seems like a good default expectation to provide, should we expand this to also put n in the suppressed_params base set and expect extending classes to remove it when they support multiple generations?
Or should we handle that more dynamically by having OpenAICompatible add it during the suppression or params code if self.supports_multiple_generations is False?
|
Since the OpenAI API is used so widely, I prefer not making assumptions in `OpenAICompatible` itself. On a related note — we're not checking whether the number of generations requested in the garak config matches the number of outputs actually returned by the generator. This may be happening silently with some endpoints when multiple generations are requested. |
| ENV_VAR = "OpenAICompatible_API_KEY".upper() # Placeholder override when extending | ||
|
|
||
| active = True | ||
| supports_multiple_generations = True | ||
| supports_multiple_generations = False | ||
| generator_family_name = "OpenAICompatible" # Placeholder override when extending | ||
|
|
||
| # template defaults optionally override when extending | ||
| DEFAULT_PARAMS = Generator.DEFAULT_PARAMS | { | ||
| "temperature": 0.7, | ||
| "top_p": 1.0, | ||
| "uri": "http://localhost:8000/v1/", | ||
| "frequency_penalty": 0.0, | ||
| "presence_penalty": 0.0, | ||
| "seed": None, | ||
| "stop": ["#", ";"], | ||
| "suppressed_params": set(), | ||
| "retry_json": True, | ||
| "extra_params": {}, | ||
| } | ||
|
|
||
| # avoid attempt to pickle the client attribute | ||
| def __getstate__(self) -> object: | ||
| self._clear_client() | ||
| return dict(self.__dict__) | ||
|
|
||
| # restore the client attribute | ||
| def __setstate__(self, d) -> object: | ||
| self.__dict__.update(d) | ||
| self._load_client() | ||
|
|
||
| def _load_client(self): | ||
| # When extending `OpenAICompatible` this method is a likely location for target application specific | ||
| # customization and must populate self.generator with an openai api compliant object | ||
| self.client = openai.OpenAI(base_url=self.uri, api_key=self.api_key) | ||
| if self.name in ("", None): | ||
| raise ValueError( | ||
| f"{self.generator_family_name} requires model name to be set, e.g. --target_name org/private-model-name" | ||
| ) | ||
| self.generator = self.client.chat.completions | ||
|
|
||
| def _clear_client(self): | ||
| self.generator = None | ||
| self.client = None | ||
|
|
||
| def _validate_config(self): | ||
| pass | ||
|
|
||
| def __init__(self, name="", config_root=_config): | ||
| self.name = name | ||
| self._load_config(config_root) | ||
| self.fullname = f"{self.generator_family_name} {self.name}" | ||
| self.key_env_var = self.ENV_VAR | ||
|
|
||
| self._load_client() | ||
|
|
||
| if self.generator not in ( | ||
| self.client.chat.completions, | ||
| self.client.completions, | ||
| ): | ||
| raise ValueError( | ||
| "Unsupported model at generation time in generators/openai.py - please add a clause!" | ||
| ) | ||
|
|
||
| self._validate_config() | ||
|
|
||
| super().__init__(self.name, config_root=config_root) | ||
|
|
||
| # clear client config to enable object to `pickle` | ||
| self._clear_client() | ||
|
|
||
    # noinspection PyArgumentList
    @backoff.on_exception(
        backoff.fibo,
        (
            openai.RateLimitError,
            openai.InternalServerError,
            openai.APITimeoutError,
            openai.APIConnectionError,
            garak.exception.GarakBackoffTrigger,
        ),
        max_value=70,
    )
    def _call_model(
        self, prompt: Union[Conversation, List[dict]], generations_this_call: int = 1
    ) -> List[Union[Message, None]]:
        """Issue one generation request against the OpenAI-compatible endpoint.

        Builds the ``create()`` argument dict from instance attributes that
        match the endpoint's signature (minus ``suppressed_params``), routes
        to either the completions or chat endpoint, and unpacks the response
        choices. Retries (via the backoff decorator) on rate limits,
        server/connection errors, and GarakBackoffTrigger.

        Args:
            prompt: conversation to submit; a list of dicts is also accepted
                for chat endpoints
            generations_this_call: number of generations to request via `n`,
                unless `n` is in ``suppressed_params``

        Returns:
            list of Message objects; ``[None]`` on unrecoverable per-prompt
            errors; an empty list when the prompt type is unusable
        """
        if self.client is None:
            # reload client once when consuming the generator
            self._load_client()

        # TODO: refactor to always use local scoped variables for _call_model client objects to avoid serialization state issues
        client = self.client
        generator = self.generator
        is_completion = generator == client.completions

        create_args = {}
        # `n` is only sent when not suppressed; endpoints with patchy OpenAI
        # API support may mishandle n > 1 silently
        if "n" not in self.suppressed_params:
            create_args["n"] = generations_this_call
        # mirror instance attributes into the request for every parameter the
        # endpoint's create() signature accepts, unless suppressed or None
        for arg in inspect.signature(generator.create).parameters:
            if arg == "model":
                create_args[arg] = self.name
                continue
            if arg == "extra_params":
                # handled separately below
                continue
            if hasattr(self, arg) and arg not in self.suppressed_params:
                if getattr(self, arg) is not None:
                    create_args[arg] = getattr(self, arg)

        # pass-through for endpoint-specific arguments outside the signature
        if hasattr(self, "extra_params"):
            for k, v in self.extra_params.items():
                create_args[k] = v

        if is_completion:
            # the completions endpoint takes a single prompt string
            if not isinstance(prompt, Conversation) or len(prompt.turns) > 1:
                msg = (
                    f"Expected a Conversation with one Turn for {self.generator_family_name} completions model {self.name}, but got {type(prompt)}. "
                    f"Returning nothing!"
                )
                logging.error(msg)
                return list()

            create_args["prompt"] = prompt.last_message().text

        else:  # is chat
            if isinstance(prompt, Conversation):
                messages = self._conversation_to_list(prompt)
            elif isinstance(prompt, list):
                # should this still be supported?
                messages = prompt
            else:
                msg = (
                    f"Expected a Conversation or list of dicts for {self.generator_family_name} Chat model {self.name}, but got {type(prompt)} instead. "
                    f"Returning nothing!"
                )
                logging.error(msg)
                return list()

            create_args["messages"] = messages

        try:
            response = generator.create(**create_args)
        except openai.BadRequestError as e:
            # unrecoverable for this prompt; signal failure with [None]
            msg = "Bad request: " + str(repr(prompt))
            logging.exception(e)
            logging.error(msg)
            return [None]
        except json.decoder.JSONDecodeError as e:
            # malformed response body; optionally retry via the backoff decorator
            logging.exception(e)
            if self.retry_json:
                raise garak.exception.GarakBackoffTrigger from e
            else:
                raise e

        if not hasattr(response, "choices"):
            logging.debug(
                "Did not get a well-formed response, retrying. Expected object with .choices member, got: '%s'"
                % repr(response)
            )
            msg = "no .choices member in generator response"
            if self.retry_json:
                raise garak.exception.GarakBackoffTrigger(msg)
            else:
                return [None]

        if is_completion:
            return [Message(c.text) for c in response.choices]
        else:
            return [Message(c.message.content) for c in response.choices]
|
|
||
|
|
||
class OpenAIGenerator(OpenAICompatible):
    """Generator wrapper for OpenAI text2text models. Expects API key in the OPENAI_API_KEY environment variable"""

    ENV_VAR = "OPENAI_API_KEY"
    active = True
    generator_family_name = "OpenAI"
    # the first-party OpenAI API handles `n` reliably, unlike generic
    # OpenAI-compatible endpoints, so multiple generations stay enabled here
    supports_multiple_generations = True

    # remove uri as it is not overridable in this class.
    DEFAULT_PARAMS = {
        k: val for k, val in OpenAICompatible.DEFAULT_PARAMS.items() if k != "uri"
    }

    def _load_client(self):
        """Create the OpenAI client and select the endpoint matching self.name.

        Raises:
            ValueError: when no model name is set, or the model is unknown
            garak.exception.BadGeneratorException: for 'o'-class reasoning
                models, which must use OpenAIReasoningGenerator instead
        """
        self.client = openai.OpenAI(api_key=self.api_key)

        if self.name == "":
            openai_model_list = sorted([m.id for m in self.client.models.list().data])
            raise ValueError(
                f"Model name is required for {self.generator_family_name}, use --target_name\n"
                + " API returns following available models: ▶️ "
                + " ".join(openai_model_list)
                + "\n"
                + " ⚠️ Not all these are text generation models"
            )

        if self.name in completion_models:
            self.generator = self.client.completions
        elif self.name in chat_models:
            self.generator = self.client.chat.completions
        elif "-".join(self.name.split("-")[:-1]) in chat_models and re.match(
            r"^.+-[01][0-9][0-3][0-9]$", self.name
        ):  # handle model names -MMDDish suffix
            # the base name matched chat_models, so route to the chat endpoint
            # (was mis-assigned to the completions endpoint, contradicting the
            # chat_models membership check in the guard above)
            self.generator = self.client.chat.completions
        else:
            raise ValueError(
                f"No {self.generator_family_name} API defined for '{self.name}' in generators/openai.py - please add one!"
            )

        if self.__class__.__name__ == "OpenAIGenerator" and self.name.startswith("o"):
            msg = "'o'-class models should use openai.OpenAIReasoningGenerator. Try e.g. `-m openai.OpenAIReasoningGenerator` instead of `-m openai`"
            logging.error(msg)
            raise garak.exception.BadGeneratorException("🛑 " + msg)

    def __init__(self, name="", config_root=_config):
        """Load config, record the model's context length if known, then defer to base init.

        Args:
            name: OpenAI model name, e.g. "gpt-4o"
            config_root: garak configuration root to load settings from
        """
        self.name = name
        self._load_config(config_root)
        if self.name in context_lengths:
            self.context_len = context_lengths[self.name]

        super().__init__(self.name, config_root=config_root)
|
|
||
|
|
||
class OpenAIReasoningGenerator(OpenAIGenerator):
    """Generator wrapper for OpenAI reasoning models, e.g. `o1` family."""

    # limited to one generation per request
    supports_multiple_generations = False

    DEFAULT_PARAMS = Generator.DEFAULT_PARAMS | {
        "top_p": 1.0,
        "frequency_penalty": 0.0,
        "presence_penalty": 0.0,
        "seed": None,
        "stop": ["#", ";"],
        # presumably these standard chat parameters are rejected by the
        # reasoning endpoints — confirm against the o-family API docs
        "suppressed_params": {"n", "temperature", "max_tokens", "stop"},
        "retry_json": True,
        # reasoning models take max_completion_tokens rather than max_tokens
        "max_completion_tokens": 1500,
    }
|
|
# default class instantiated when this generator module is selected
DEFAULT_CLASS = "OpenAIGenerator"
There was a problem hiding this comment.
Since `n` is already handled specially via `generations_this_call`, we can base the suppression on `self.supports_multiple_generations` instead:

```python
if self.supports_multiple_generations:
    create_args["n"] = generations_this_call
elif "n" not in self.suppressed_params:
    create_args["n"] = 1
```

In practice, we really should not allow `n` to diverge from what the generator actually supports. |
Signed-off-by: Jeffrey Martin <jemartin@nvidia.com>
Full support for OpenAI API is patchy.
Some targets that support OpenAI API access may only be able to generate one result at a time; this can fail silently when multiple results are requested.
This PR changes the default of `OpenAICompatible` to not support multiple concurrent generations, while updating `OpenAIGenerator` to still support multiple generations.