Skip to content

Commit d17c439

Browse files
authored
Bugfix/create postgresql table if not exists for vectorstore (#5573)
Create the PostgreSQL table for the vector store if it does not exist
1 parent d090b71 commit d17c439

File tree

3 files changed

+52
-9
lines changed

3 files changed

+52
-9
lines changed

packages/components/nodes/vectorstores/Postgres/driver/Base.ts

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ import { VectorStore } from '@langchain/core/vectorstores'
22
import { getCredentialData, getCredentialParam, ICommonObject, INodeData } from '../../../../src'
33
import { Document } from '@langchain/core/documents'
44
import { Embeddings } from '@langchain/core/embeddings'
5-
import { getDatabase, getHost, getPort, getSSL, getTableName } from '../utils'
5+
import { getDatabase, getHost, getPort, getSchemaName, getSSL, getTableName } from '../utils'
66

77
export abstract class VectorStoreDriver {
88
constructor(protected nodeData: INodeData, protected options: ICommonObject) {}
@@ -35,6 +35,18 @@ export abstract class VectorStoreDriver {
3535
return this.sanitizeTableName(getTableName(this.nodeData))
3636
}
3737

38+
/**
 * Resolves the schema name to use for this node.
 *
 * Delegates to the `getSchemaName` helper imported from `../utils` and, when
 * that helper yields a truthy value, passes it through `sanitizeTableName`
 * (the same sanitizer `getTableName` applies to the table name).
 *
 * @returns the sanitized schema name, or `undefined` when the helper yields a
 * falsy value (e.g. nothing configured)
 */
getSchemaName() {
39+
// This is the module-level helper from ../utils, not a recursive call to this method.
const schemaName = getSchemaName(this.nodeData)
40+
return schemaName ? this.sanitizeTableName(schemaName) : undefined
41+
}
42+
43+
/**
 * Builds the double-quoted table reference for this node.
 *
 * The result is `"schema"."table"` when `getSchemaName()` returns a value and
 * `"table"` otherwise. Both parts come from `getSchemaName()` / `getTableName()`,
 * which run their values through `sanitizeTableName`.
 *
 * NOTE(review): the quoting is only safe if `sanitizeTableName` rejects or strips
 * embedded double quotes; that helper is not visible here, so confirm.
 * NOTE(review): PostgreSQL treats quoted identifiers as case-sensitive, so the
 * table must have been created with exactly this spelling.
 *
 * @returns the quoted, optionally schema-qualified, table path
 */
getTablePath() {
44+
const schemaName = this.getSchemaName()
45+
const tableName = this.getTableName()
46+
// No schema configured: return the bare quoted table name.
if (!schemaName) return `"${tableName}"`
47+
return `"${schemaName}"."${tableName}"`
48+
}
49+
3850
getEmbeddings() {
3951
return this.nodeData.inputs?.embeddings as Embeddings
4052
}

packages/components/nodes/vectorstores/Postgres/driver/TypeORM.ts

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -52,16 +52,25 @@ export class TypeORMDriver extends VectorStoreDriver {
5252
async getArgs(): Promise<TypeORMVectorStoreArgs> {
5353
return {
5454
postgresConnectionOptions: await this.getPostgresConnectionOptions(),
55-
tableName: this.getTableName()
55+
tableName: this.getTableName(),
56+
schemaName: this.getSchemaName()
5657
}
5758
}
5859

5960
async instanciate(metadataFilters?: any) {
60-
return this.adaptInstance(await TypeORMVectorStore.fromDataSource(this.getEmbeddings(), await this.getArgs()), metadataFilters)
61+
return this.adaptInstance(
62+
await TypeORMVectorStore.fromDataSource(this.getEmbeddings(), await this.getArgs()),
63+
metadataFilters,
64+
this.getTablePath()
65+
)
6166
}
6267

6368
async fromDocuments(documents: Document[]) {
64-
return this.adaptInstance(await TypeORMVectorStore.fromDocuments(documents, this.getEmbeddings(), await this.getArgs()))
69+
return this.adaptInstance(
70+
await TypeORMVectorStore.fromDocuments(documents, this.getEmbeddings(), await this.getArgs()),
71+
undefined,
72+
this.getTablePath()
73+
)
6574
}
6675

6776
sanitizeDocuments(documents: Document[]) {
@@ -73,8 +82,8 @@ export class TypeORMDriver extends VectorStoreDriver {
7382
return documents
7483
}
7584

76-
protected async adaptInstance(instance: TypeORMVectorStore, metadataFilters?: any): Promise<VectorStore> {
77-
const tableName = this.getTableName()
85+
protected async adaptInstance(instance: TypeORMVectorStore, metadataFilters?: any, tablePath?: string): Promise<VectorStore> {
86+
const effectiveTablePath = tablePath ?? this.getTablePath()
7887

7988
// Rewrite the method to use pg pool connection instead of the default connection
8089
/* Otherwise a connection error is displayed when the chain tries to execute the function
@@ -86,7 +95,7 @@ export class TypeORMDriver extends VectorStoreDriver {
8695
return await TypeORMDriver.similaritySearchVectorWithScore(
8796
query,
8897
k,
89-
tableName,
98+
effectiveTablePath,
9099
await this.getPostgresConnectionOptions(),
91100
filter ?? metadataFilters,
92101
this.computedOperatorString
@@ -141,6 +150,8 @@ export class TypeORMDriver extends VectorStoreDriver {
141150

142151
instance.addDocuments = async (documents: Document[], options?: { ids?: string[] }): Promise<void> => {
143152
const texts = documents.map(({ pageContent }) => pageContent)
153+
// Ensure table exists before adding documents (this will create the table if it does not exist)
154+
await this.ensureTableInDatabase(instance, effectiveTablePath)
144155
return (instance.addVectors as any)(await this.getEmbeddings().embedDocuments(texts), documents, options)
145156
}
146157

@@ -162,10 +173,26 @@ export class TypeORMDriver extends VectorStoreDriver {
162173
}
163174
}
164175

176+
/**
177+
* Ensures the table exists in the database with the correct schema.
178+
* Creates the pgvector extension and table if they don't exist.
179+
*
* Runs two statements through `instance.appDataSource.query`, in order:
* `CREATE EXTENSION IF NOT EXISTS vector` and `CREATE TABLE IF NOT EXISTS`.
* The table is declared with the columns `id` (uuid primary key defaulting to
* `gen_random_uuid()`), `pageContent` (text), `metadata` (jsonb) and
* `embedding` (vector). Because of `IF NOT EXISTS`, an already existing table
* is left as it is; its columns are neither checked nor altered.
*
* `tablePath` is interpolated directly into the SQL text (it is not a bound
* parameter), so it must already be sanitized and quoted by the caller.
* No `CREATE SCHEMA` statement is issued here, so a schema named in
* `tablePath` has to exist beforehand.
*
* @param instance vector store whose `appDataSource` executes the statements
* @param tablePath quoted, optionally schema-qualified, table reference
* @returns a promise that resolves once both statements have completed
*/
180+
async ensureTableInDatabase(instance: TypeORMVectorStore, tablePath: string): Promise<void> {
181+
// The extension is created first because the table below declares a `vector` column.
await instance.appDataSource.query('CREATE EXTENSION IF NOT EXISTS vector;')
182+
await instance.appDataSource.query(`
183+
CREATE TABLE IF NOT EXISTS ${tablePath} (
184+
"id" uuid NOT NULL DEFAULT gen_random_uuid() PRIMARY KEY,
185+
"pageContent" text,
186+
metadata jsonb,
187+
embedding vector
188+
);
189+
`)
190+
}
191+
165192
static similaritySearchVectorWithScore = async (
166193
query: number[],
167194
k: number,
168-
tableName: string,
195+
tablePath: string,
169196
postgresConnectionOptions: ICommonObject,
170197
filter?: any,
171198
distanceOperator: string = '<=>'
@@ -186,7 +213,7 @@ export class TypeORMDriver extends VectorStoreDriver {
186213

187214
const queryString = `
188215
SELECT *, embedding ${distanceOperator} $1 as "_distance"
189-
FROM ${tableName}
216+
FROM ${tablePath}
190217
WHERE ((metadata @> $2) AND NOT (metadata ? '${FLOWISE_CHATID}')) ${chatflowOr}
191218
ORDER BY "_distance" ASC
192219
LIMIT $3;`

packages/components/nodes/vectorstores/Postgres/utils.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,7 @@ export function getTableName(nodeData?: INodeData) {
2323
export function getContentColumnName(nodeData?: INodeData) {
2424
return defaultChain(nodeData?.inputs?.contentColumnName, process.env.POSTGRES_VECTORSTORE_CONTENT_COLUMN_NAME, 'pageContent')
2525
}
26+
27+
/**
 * Resolves the PostgreSQL schema name for the vector store.
 *
 * Passes two candidates to `defaultChain`, in this order: the node input
 * `schemaName` and the `POSTGRES_VECTORSTORE_SCHEMA_NAME` environment variable.
 * Unlike `getContentColumnName`, no literal fallback is supplied.
 *
 * NOTE(review): presumably `defaultChain` returns the first usable candidate and
 * nothing when both are missing; its implementation is not visible here, so confirm.
 *
 * @param nodeData node whose `inputs.schemaName` is consulted; may be omitted
 * @returns whatever `defaultChain` selects from the two candidates
 */
export function getSchemaName(nodeData?: INodeData) {
28+
return defaultChain(nodeData?.inputs?.schemaName, process.env.POSTGRES_VECTORSTORE_SCHEMA_NAME)
29+
}

0 commit comments

Comments
 (0)