Skip to content

Commit da1638f

Browse files
authored
Merge pull request #3 from FullFact/make-genai-utils-async-ai-1684
Make `run_prompt` asynchronous
2 parents 5a4dfb0 + afc993a commit da1638f

File tree

3 files changed

+136
-24
lines changed

3 files changed

+136
-24
lines changed

src/genai_utils/gemini.py

Lines changed: 94 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import logging
23
import os
34
import re
@@ -303,6 +304,98 @@ def run_prompt(
303304
use_grounding: bool = False,
304305
inline_citations: bool = False,
305306
labels: dict[str, str] = {},
307+
) -> str:
308+
"""
309+
A synchronous version of `run_prompt_async`; it calls `asyncio.run`, so it cannot be used from within an already running event loop.
310+
311+
Parameters
312+
----------
313+
prompt: str
314+
The prompt given to the model
315+
video_uri: str | None
316+
A Google Cloud URI for a video that you want to prompt.
317+
output_schema: types.SchemaUnion | None
318+
A valid schema for the model output.
319+
Generally, we'd recommend this being a pydantic BaseModel inheriting class,
320+
which defines the desired schema of the model output.
321+
```python
322+
from pydantic import BaseModel, Field
323+
324+
class Movie(BaseModel):
325+
title: str = Field(description="The title of the movie")
326+
year: int = Field(description="The year the film was released in the UK")
327+
328+
schema = Movie
329+
# or
330+
schema = list[Movie]
331+
```
332+
Use this if you want structured JSON output.
333+
system_instruction: str | None
334+
An instruction to the model which essentially goes before the prompt.
335+
For example:
336+
```
337+
You are a fact checker and you must base all your answers on evidence
338+
```
339+
generation_config: dict[str, Any]
340+
The parameters for the generation. See the docs (`generation config`_).
341+
safety_settings: list[types.SafetySetting]
342+
The safety settings for generation. Determines what will be blocked.
343+
See the docs (`safety settings`_).
344+
model_config: ModelConfig | None
345+
The config for the Gemini model.
346+
Specifies project, location, and model name.
347+
If None, will attempt to use environment variables:
348+
`GEMINI_PROJECT`, `GEMINI_LOCATION`, and `GEMINI_MODEL`.
349+
use_grounding: bool
350+
Whether Gemini should perform a Google search to ground results.
351+
This will allow it to pull from up-to-date information,
352+
and makes the output more likely to be factual.
353+
Does not work with structured output.
354+
See the docs (`grounding`_).
355+
inline_citations: bool
356+
Whether output should include citations inline with the text.
357+
These citations will be links to be used as evidence.
358+
This is only possible if grounding is set to true.
359+
labels: dict[str, str]
360+
Optional labels to attach to the API call for tracking and monitoring purposes.
361+
Labels are key-value pairs that can be used to organize and filter requests
362+
in Google Cloud logs and metrics.
363+
364+
Returns
365+
-------
366+
The text output of the Gemini model.
367+
368+
.. _generation config: https://cloud.google.com/vertex-ai/docs/reference/rest/v1/GenerationConfig
369+
.. _safety settings: https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/configure-safety-filters
370+
.. _grounding: https://ai.google.dev/gemini-api/docs/google-search
371+
"""
372+
return asyncio.run(
373+
run_prompt_async(
374+
prompt=prompt,
375+
video_uri=video_uri,
376+
output_schema=output_schema,
377+
system_instruction=system_instruction,
378+
generation_config=generation_config,
379+
safety_settings=safety_settings,
380+
model_config=model_config,
381+
use_grounding=use_grounding,
382+
inline_citations=inline_citations,
383+
labels=labels,
384+
)
385+
)
386+
387+
388+
async def run_prompt_async(
389+
prompt: str,
390+
video_uri: str | None = None,
391+
output_schema: types.SchemaUnion | None = None,
392+
system_instruction: str | None = None,
393+
generation_config: dict[str, Any] = DEFAULT_PARAMETERS,
394+
safety_settings: list[types.SafetySetting] = DEFAULT_SAFETY_SETTINGS,
395+
model_config: ModelConfig | None = None,
396+
use_grounding: bool = False,
397+
inline_citations: bool = False,
398+
labels: dict[str, str] = {},
306399
) -> str:
307400
"""
308401
Runs a prompt through the model.
@@ -405,7 +498,7 @@ class Movie(BaseModel):
405498
merged_labels = DEFAULT_LABELS | labels
406499
validate_labels(merged_labels)
407500

408-
response = client.models.generate_content(
501+
response = await client.aio.models.generate_content(
409502
model=model_config.model_name,
410503
contents=types.Content(role="user", parts=parts),
411504
config=types.GenerateContentConfig(

tests/genai_utils/test_gemini.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from unittest.mock import Mock, patch
33

44
from google.genai import Client
5+
from google.genai.client import AsyncClient
56
from google.genai.models import Models
67
from pydantic import BaseModel, Field
78

@@ -10,7 +11,7 @@
1011
GeminiError,
1112
ModelConfig,
1213
generate_model_config,
13-
run_prompt,
14+
run_prompt_async,
1415
)
1516

1617

@@ -24,7 +25,7 @@ class DummySchema(BaseModel):
2425
colour: str = Field(description="Colour of dog")
2526

2627

27-
def get_dummy():
28+
async def get_dummy():
2829
return DummyResponse()
2930

3031

@@ -57,25 +58,27 @@ def test_generate_model_config_no_env_vars():
5758

5859

5960
@patch("genai_utils.gemini.genai.Client")
60-
def test_dont_overwrite_generation_config(mock_client):
61+
async def test_dont_overwrite_generation_config(mock_client):
6162
copy_of_params = {**DEFAULT_PARAMETERS}
6263
client = Mock(Client)
6364
models = Mock(Models)
65+
async_client = Mock(AsyncClient)
6466

6567
models.generate_content.return_value = get_dummy()
66-
client.models = models
68+
client.aio = async_client
69+
async_client.models = models
6770
mock_client.return_value = client
6871

6972
assert DEFAULT_PARAMETERS == copy_of_params
70-
run_prompt(
73+
await run_prompt_async(
7174
"do something",
7275
output_schema=DummySchema,
7376
model_config=ModelConfig(
7477
project="project", location="location", model_name="model"
7578
),
7679
)
7780
models.generate_content.return_value = get_dummy()
78-
run_prompt(
81+
await run_prompt_async(
7982
"do something",
8083
model_config=ModelConfig(
8184
project="project", location="location", model_name="model"
@@ -89,16 +92,18 @@ def test_dont_overwrite_generation_config(mock_client):
8992

9093

9194
@patch("genai_utils.gemini.genai.Client")
92-
def test_error_if_grounding_with_schema(mock_client):
95+
async def test_error_if_grounding_with_schema(mock_client):
9396
client = Mock(Client)
9497
models = Mock(Models)
98+
async_client = Mock(AsyncClient)
9599

96100
models.generate_content.return_value = get_dummy()
97-
client.models = models
101+
client.aio = async_client
102+
async_client.models = models
98103
mock_client.return_value = client
99104

100105
try:
101-
run_prompt(
106+
await run_prompt_async(
102107
"do something",
103108
output_schema=DummySchema,
104109
use_grounding=True,
@@ -114,16 +119,18 @@ def test_error_if_grounding_with_schema(mock_client):
114119

115120

116121
@patch("genai_utils.gemini.genai.Client")
117-
def test_error_if_citations_and_no_grounding(mock_client):
122+
async def test_error_if_citations_and_no_grounding(mock_client):
118123
client = Mock(Client)
119124
models = Mock(Models)
125+
async_client = Mock(AsyncClient)
120126

121127
models.generate_content.return_value = get_dummy()
122-
client.models = models
128+
client.aio = async_client
129+
async_client.models = models
123130
mock_client.return_value = client
124131

125132
try:
126-
run_prompt(
133+
await run_prompt_async(
127134
"do something",
128135
use_grounding=False,
129136
inline_citations=True,

tests/genai_utils/test_labels.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,23 @@
33

44
import pytest
55
from google.genai import Client
6+
from google.genai.client import AsyncClient
67
from google.genai.models import Models
78

8-
from genai_utils.gemini import GeminiError, ModelConfig, run_prompt, validate_labels
9+
from genai_utils.gemini import (
10+
GeminiError,
11+
ModelConfig,
12+
run_prompt_async,
13+
validate_labels,
14+
)
915

1016

1117
class DummyResponse:
1218
candidates = "yes!"
1319
text = "response!"
1420

1521

16-
def get_dummy():
22+
async def get_dummy():
1723
return DummyResponse()
1824

1925

@@ -101,18 +107,20 @@ def test_validate_labels_valid_special_chars():
101107

102108

103109
@patch("genai_utils.gemini.genai.Client")
104-
def test_run_prompt_with_valid_labels(mock_client):
110+
async def test_run_prompt_with_valid_labels(mock_client):
105111
"""Test that run_prompt accepts and uses valid labels"""
106112
client = Mock(Client)
107113
models = Mock(Models)
114+
async_client = Mock(AsyncClient)
108115

109116
models.generate_content.return_value = get_dummy()
110-
client.models = models
117+
client.aio = async_client
118+
async_client.models = models
111119
mock_client.return_value = client
112120

113121
labels = {"team": "ai", "project": "test"}
114122

115-
run_prompt(
123+
await run_prompt_async(
116124
"test prompt",
117125
labels=labels,
118126
model_config=ModelConfig(
@@ -128,19 +136,21 @@ def test_run_prompt_with_valid_labels(mock_client):
128136

129137

130138
@patch("genai_utils.gemini.genai.Client")
131-
def test_run_prompt_with_invalid_labels(mock_client):
139+
async def test_run_prompt_with_invalid_labels(mock_client):
132140
"""Test that run_prompt rejects invalid labels"""
133141
client = Mock(Client)
134142
models = Mock(Models)
143+
async_client = Mock(AsyncClient)
135144

136145
models.generate_content.return_value = get_dummy()
137-
client.models = models
146+
client.aio = async_client
147+
async_client.models = models
138148
mock_client.return_value = client
139149

140150
invalid_labels = {"Invalid": "value"} # uppercase key
141151

142152
with pytest.raises(GeminiError, match="must start with a lowercase letter"):
143-
run_prompt(
153+
await run_prompt_async(
144154
"test prompt",
145155
labels=invalid_labels,
146156
model_config=ModelConfig(
@@ -151,7 +161,7 @@ def test_run_prompt_with_invalid_labels(mock_client):
151161

152162
@patch("genai_utils.gemini.genai.Client")
153163
@patch.dict(os.environ, {"GENAI_LABEL_TEAM": "ai", "GENAI_LABEL_ENV": "test"})
154-
def test_run_prompt_merges_env_labels(mock_client):
164+
async def test_run_prompt_merges_env_labels(mock_client):
155165
"""Test that run_prompt merges environment labels with request labels"""
156166
# Need to reload the module to pick up the new environment variables
157167
import importlib
@@ -162,14 +172,16 @@ def test_run_prompt_merges_env_labels(mock_client):
162172

163173
client = Mock(Client)
164174
models = Mock(Models)
175+
async_client = Mock(AsyncClient)
165176

166177
models.generate_content.return_value = get_dummy()
167-
client.models = models
178+
client.aio = async_client
179+
async_client.models = models
168180
mock_client.return_value = client
169181

170182
request_labels = {"project": "test"}
171183

172-
genai_utils.gemini.run_prompt(
184+
await genai_utils.gemini.run_prompt_async(
173185
"test prompt",
174186
labels=request_labels,
175187
model_config=ModelConfig(

0 commit comments

Comments
 (0)