Skip to content

Commit c4558e5

Browse files
yinghsienwucopybara-github
authored andcommitted
feat: Update VertexRagStore
PiperOrigin-RevId: 746177020
1 parent 8b1db9c commit c4558e5

2 files changed

Lines changed: 203 additions & 1 deletion

File tree

google/genai/tests/models/test_generate_content_tools.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ def divide_floats(a: float, b: float) -> float:
164164
exception_if_mldev='retrieval',
165165
),
166166
pytest_helper.TestTableItem(
167-
name='test_rag_model',
167+
name='test_rag_model_old',
168168
parameters=types._GenerateContentParameters(
169169
model='gemini-1.5-flash',
170170
contents=t.t_contents(
@@ -191,6 +191,39 @@ def divide_floats(a: float, b: float) -> float:
191191
),
192192
exception_if_mldev='retrieval',
193193
),
194+
pytest_helper.TestTableItem(
195+
name='test_rag_model_ga',
196+
parameters=types._GenerateContentParameters(
197+
model='gemini-2.0-flash-001',
198+
contents=t.t_contents(
199+
None,
200+
'How much gain or loss did Google get in the Motorola Mobile'
201+
' deal in 2014?',
202+
),
203+
config={
204+
'tools': [
205+
types.Tool(
206+
retrieval=types.Retrieval(
207+
vertex_rag_store=types.VertexRagStore(
208+
rag_resources=[
209+
types.VertexRagStoreRagResource(
210+
rag_corpus='projects/964831358985/locations/us-central1/ragCorpora/3379951520341557248'
211+
)
212+
],
213+
rag_retrieval_config=types.RagRetrievalConfig(
214+
top_k=3,
215+
filter=types.RagRetrievalConfigFilter(
216+
vector_similarity_threshold=0.5,
217+
),
218+
),
219+
)
220+
),
221+
),
222+
]
223+
},
224+
),
225+
exception_if_mldev='retrieval',
226+
),
194227
pytest_helper.TestTableItem(
195228
name='test_function_call',
196229
parameters=types._GenerateContentParameters(

google/genai/types.py

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1518,6 +1518,168 @@ class VertexRagStoreRagResourceDict(TypedDict, total=False):
15181518
]
15191519

15201520

1521+
class RagRetrievalConfigFilter(_common.BaseModel):
1522+
"""Config for filters."""
1523+
1524+
metadata_filter: Optional[str] = Field(
1525+
default=None, description="""Optional. String for metadata filtering."""
1526+
)
1527+
vector_distance_threshold: Optional[float] = Field(
1528+
default=None,
1529+
description="""Optional. Only returns contexts with vector distance smaller than the threshold.""",
1530+
)
1531+
vector_similarity_threshold: Optional[float] = Field(
1532+
default=None,
1533+
description="""Optional. Only returns contexts with vector similarity larger than the threshold.""",
1534+
)
1535+
1536+
1537+
class RagRetrievalConfigFilterDict(TypedDict, total=False):
1538+
"""Config for filters."""
1539+
1540+
metadata_filter: Optional[str]
1541+
"""Optional. String for metadata filtering."""
1542+
1543+
vector_distance_threshold: Optional[float]
1544+
"""Optional. Only returns contexts with vector distance smaller than the threshold."""
1545+
1546+
vector_similarity_threshold: Optional[float]
1547+
"""Optional. Only returns contexts with vector similarity larger than the threshold."""
1548+
1549+
1550+
RagRetrievalConfigFilterOrDict = Union[
1551+
RagRetrievalConfigFilter, RagRetrievalConfigFilterDict
1552+
]
1553+
1554+
1555+
class RagRetrievalConfigHybridSearch(_common.BaseModel):
1556+
"""Config for Hybrid Search."""
1557+
1558+
alpha: Optional[float] = Field(
1559+
default=None,
1560+
description="""Optional. Alpha value controls the weight between dense and sparse vector search results. The range is [0, 1], while 0 means sparse vector search only and 1 means dense vector search only. The default value is 0.5 which balances sparse and dense vector search equally.""",
1561+
)
1562+
1563+
1564+
class RagRetrievalConfigHybridSearchDict(TypedDict, total=False):
1565+
"""Config for Hybrid Search."""
1566+
1567+
alpha: Optional[float]
1568+
"""Optional. Alpha value controls the weight between dense and sparse vector search results. The range is [0, 1], while 0 means sparse vector search only and 1 means dense vector search only. The default value is 0.5 which balances sparse and dense vector search equally."""
1569+
1570+
1571+
RagRetrievalConfigHybridSearchOrDict = Union[
1572+
RagRetrievalConfigHybridSearch, RagRetrievalConfigHybridSearchDict
1573+
]
1574+
1575+
1576+
class RagRetrievalConfigRankingLlmRanker(_common.BaseModel):
1577+
"""Config for LlmRanker."""
1578+
1579+
model_name: Optional[str] = Field(
1580+
default=None,
1581+
description="""Optional. The model name used for ranking. Format: `gemini-1.5-pro`""",
1582+
)
1583+
1584+
1585+
class RagRetrievalConfigRankingLlmRankerDict(TypedDict, total=False):
1586+
"""Config for LlmRanker."""
1587+
1588+
model_name: Optional[str]
1589+
"""Optional. The model name used for ranking. Format: `gemini-1.5-pro`"""
1590+
1591+
1592+
RagRetrievalConfigRankingLlmRankerOrDict = Union[
1593+
RagRetrievalConfigRankingLlmRanker, RagRetrievalConfigRankingLlmRankerDict
1594+
]
1595+
1596+
1597+
class RagRetrievalConfigRankingRankService(_common.BaseModel):
1598+
"""Config for Rank Service."""
1599+
1600+
model_name: Optional[str] = Field(
1601+
default=None,
1602+
description="""Optional. The model name of the rank service. Format: `semantic-ranker-512@latest`""",
1603+
)
1604+
1605+
1606+
class RagRetrievalConfigRankingRankServiceDict(TypedDict, total=False):
1607+
"""Config for Rank Service."""
1608+
1609+
model_name: Optional[str]
1610+
"""Optional. The model name of the rank service. Format: `semantic-ranker-512@latest`"""
1611+
1612+
1613+
RagRetrievalConfigRankingRankServiceOrDict = Union[
1614+
RagRetrievalConfigRankingRankService,
1615+
RagRetrievalConfigRankingRankServiceDict,
1616+
]
1617+
1618+
1619+
class RagRetrievalConfigRanking(_common.BaseModel):
1620+
"""Config for ranking and reranking."""
1621+
1622+
llm_ranker: Optional[RagRetrievalConfigRankingLlmRanker] = Field(
1623+
default=None, description="""Optional. Config for LlmRanker."""
1624+
)
1625+
rank_service: Optional[RagRetrievalConfigRankingRankService] = Field(
1626+
default=None, description="""Optional. Config for Rank Service."""
1627+
)
1628+
1629+
1630+
class RagRetrievalConfigRankingDict(TypedDict, total=False):
1631+
"""Config for ranking and reranking."""
1632+
1633+
llm_ranker: Optional[RagRetrievalConfigRankingLlmRankerDict]
1634+
"""Optional. Config for LlmRanker."""
1635+
1636+
rank_service: Optional[RagRetrievalConfigRankingRankServiceDict]
1637+
"""Optional. Config for Rank Service."""
1638+
1639+
1640+
RagRetrievalConfigRankingOrDict = Union[
1641+
RagRetrievalConfigRanking, RagRetrievalConfigRankingDict
1642+
]
1643+
1644+
1645+
class RagRetrievalConfig(_common.BaseModel):
1646+
"""Specifies the context retrieval config."""
1647+
1648+
filter: Optional[RagRetrievalConfigFilter] = Field(
1649+
default=None, description="""Optional. Config for filters."""
1650+
)
1651+
hybrid_search: Optional[RagRetrievalConfigHybridSearch] = Field(
1652+
default=None, description="""Optional. Config for Hybrid Search."""
1653+
)
1654+
ranking: Optional[RagRetrievalConfigRanking] = Field(
1655+
default=None,
1656+
description="""Optional. Config for ranking and reranking.""",
1657+
)
1658+
top_k: Optional[int] = Field(
1659+
default=None,
1660+
description="""Optional. The number of contexts to retrieve.""",
1661+
)
1662+
1663+
1664+
class RagRetrievalConfigDict(TypedDict, total=False):
1665+
"""Specifies the context retrieval config."""
1666+
1667+
filter: Optional[RagRetrievalConfigFilterDict]
1668+
"""Optional. Config for filters."""
1669+
1670+
hybrid_search: Optional[RagRetrievalConfigHybridSearchDict]
1671+
"""Optional. Config for Hybrid Search."""
1672+
1673+
ranking: Optional[RagRetrievalConfigRankingDict]
1674+
"""Optional. Config for ranking and reranking."""
1675+
1676+
top_k: Optional[int]
1677+
"""Optional. The number of contexts to retrieve."""
1678+
1679+
1680+
RagRetrievalConfigOrDict = Union[RagRetrievalConfig, RagRetrievalConfigDict]
1681+
1682+
15211683
class VertexRagStore(_common.BaseModel):
15221684
"""Retrieve from Vertex RAG Store for grounding."""
15231685

@@ -1529,6 +1691,10 @@ class VertexRagStore(_common.BaseModel):
15291691
default=None,
15301692
description="""Optional. The representation of the rag source. It can be used to specify corpus only or ragfiles. Currently only support one corpus or multiple files from one corpus. In the future we may open up multiple corpora support.""",
15311693
)
1694+
rag_retrieval_config: Optional[RagRetrievalConfig] = Field(
1695+
default=None,
1696+
description="""Optional. The retrieval config for the Rag query.""",
1697+
)
15321698
similarity_top_k: Optional[int] = Field(
15331699
default=None,
15341700
description="""Optional. Number of top k results to return from the selected corpora.""",
@@ -1548,6 +1714,9 @@ class VertexRagStoreDict(TypedDict, total=False):
15481714
rag_resources: Optional[list[VertexRagStoreRagResourceDict]]
15491715
"""Optional. The representation of the rag source. It can be used to specify corpus only or ragfiles. Currently only support one corpus or multiple files from one corpus. In the future we may open up multiple corpora support."""
15501716

1717+
rag_retrieval_config: Optional[RagRetrievalConfigDict]
1718+
"""Optional. The retrieval config for the Rag query."""
1719+
15511720
similarity_top_k: Optional[int]
15521721
"""Optional. Number of top k results to return from the selected corpora."""
15531722

0 commit comments

Comments
 (0)