|
@@ -8,7 +8,7 @@ import {
|
|
|
} from '@nestjs/common'
|
|
} from '@nestjs/common'
|
|
|
import { InjectRepository } from '@nestjs/typeorm'
|
|
import { InjectRepository } from '@nestjs/typeorm'
|
|
|
import { KnowledgeBase } from './entities/knowledge-base.entity'
|
|
import { KnowledgeBase } from './entities/knowledge-base.entity'
|
|
|
-import { Repository } from 'typeorm'
|
|
|
|
|
|
|
+import { In, Repository } from 'typeorm'
|
|
|
import { Tiktoken, get_encoding } from '@dqbd/tiktoken'
|
|
import { Tiktoken, get_encoding } from '@dqbd/tiktoken'
|
|
|
import { Configuration, OpenAIApi } from 'azure-openai'
|
|
import { Configuration, OpenAIApi } from 'azure-openai'
|
|
|
import { DataTypes, Sequelize } from 'sequelize'
|
|
import { DataTypes, Sequelize } from 'sequelize'
|
|
@@ -39,6 +39,7 @@ function formatEmbedding(embedding: number[]) {
|
|
|
export class KnowledgeBaseService implements OnModuleInit {
|
|
export class KnowledgeBaseService implements OnModuleInit {
|
|
|
private readonly tokenizer: Tiktoken
|
|
private readonly tokenizer: Tiktoken
|
|
|
private readonly openai: OpenAIApi
|
|
private readonly openai: OpenAIApi
|
|
|
|
|
+ private readonly embeddingApi: OpenAIApi
|
|
|
private readonly sequelize: Sequelize
|
|
private readonly sequelize: Sequelize
|
|
|
private embeddings: OpenAIEmbeddings
|
|
private embeddings: OpenAIEmbeddings
|
|
|
private vectorStore: TypeORMVectorStore
|
|
private vectorStore: TypeORMVectorStore
|
|
@@ -53,11 +54,23 @@ export class KnowledgeBaseService implements OnModuleInit {
|
|
|
this.tokenizer = get_encoding('cl100k_base')
|
|
this.tokenizer = get_encoding('cl100k_base')
|
|
|
this.openai = new OpenAIApi(
|
|
this.openai = new OpenAIApi(
|
|
|
new Configuration({
|
|
new Configuration({
|
|
|
- apiKey: 'beb32e4625a94b65ba8bc0ba1688c4d2',
|
|
|
|
|
|
|
+ apiKey: process.env.AZURE_OPENAI_KEY,
|
|
|
// add azure info into configuration
|
|
// add azure info into configuration
|
|
|
azure: {
|
|
azure: {
|
|
|
- apiKey: 'beb32e4625a94b65ba8bc0ba1688c4d2',
|
|
|
|
|
- endpoint: 'https://zouma.openai.azure.com'
|
|
|
|
|
|
|
+ apiKey: process.env.AZURE_OPENAI_KEY,
|
|
|
|
|
+ endpoint: process.env.AZURE_OPENAI_ENDPOINT,
|
|
|
|
|
+ deploymentName: process.env.AZURE_OPENAI_DEPLOYMENT
|
|
|
|
|
+ }
|
|
|
|
|
+ })
|
|
|
|
|
+ )
|
|
|
|
|
+ this.embeddingApi = new OpenAIApi(
|
|
|
|
|
+ new Configuration({
|
|
|
|
|
+ apiKey: process.env.AZURE_EMBEDDING_KEY,
|
|
|
|
|
+ // add azure info into configuration
|
|
|
|
|
+ azure: {
|
|
|
|
|
+ apiKey: process.env.AZURE_EMBEDDING_KEY,
|
|
|
|
|
+ endpoint: `https://${process.env.AZURE_EMBEDDING_INSTANCE}.openai.azure.com`,
|
|
|
|
|
+ deploymentName: process.env.AZURE_EMBEDDING_DEPLOYMENT
|
|
|
}
|
|
}
|
|
|
})
|
|
})
|
|
|
)
|
|
)
|
|
@@ -125,7 +138,7 @@ export class KnowledgeBaseService implements OnModuleInit {
|
|
|
password: process.env.PG_PASSWORD,
|
|
password: process.env.PG_PASSWORD,
|
|
|
database: process.env.PG_DATABASE
|
|
database: process.env.PG_DATABASE
|
|
|
},
|
|
},
|
|
|
- verbose: true,
|
|
|
|
|
|
|
+ verbose: true
|
|
|
})
|
|
})
|
|
|
await this.vectorStore.ensureTableInDatabase()
|
|
await this.vectorStore.ensureTableInDatabase()
|
|
|
}
|
|
}
|
|
@@ -203,8 +216,8 @@ export class KnowledgeBaseService implements OnModuleInit {
|
|
|
this.processExcelKnowledgeFile(knowledgeFile, buffer)
|
|
this.processExcelKnowledgeFile(knowledgeFile, buffer)
|
|
|
break
|
|
break
|
|
|
case 'application/pdf':
|
|
case 'application/pdf':
|
|
|
- //this.processPdfKnowledgeFile(knowledgeFile, buffer)
|
|
|
|
|
- this.processPdfKnowledgeFile1(knowledgeFile, buffer)
|
|
|
|
|
|
|
+ this.processPdfKnowledgeFile(knowledgeFile, buffer)
|
|
|
|
|
+ // this.processPdfKnowledgeFile1(knowledgeFile, buffer)
|
|
|
break
|
|
break
|
|
|
}
|
|
}
|
|
|
return knowledgeFile
|
|
return knowledgeFile
|
|
@@ -384,7 +397,7 @@ export class KnowledgeBaseService implements OnModuleInit {
|
|
|
|
|
|
|
|
async getEmbedding(content: string, retry = 0) {
|
|
async getEmbedding(content: string, retry = 0) {
|
|
|
try {
|
|
try {
|
|
|
- const response = await this.openai.createEmbedding({
|
|
|
|
|
|
|
+ const response = await this.embeddingApi.createEmbedding({
|
|
|
model: 'embedding',
|
|
model: 'embedding',
|
|
|
input: content
|
|
input: content
|
|
|
})
|
|
})
|
|
@@ -407,7 +420,7 @@ export class KnowledgeBaseService implements OnModuleInit {
|
|
|
async getKeywords(text: string) {
|
|
async getKeywords(text: string) {
|
|
|
try {
|
|
try {
|
|
|
const res = await this.openai.createChatCompletion({
|
|
const res = await this.openai.createChatCompletion({
|
|
|
- model: 'gpt35',
|
|
|
|
|
|
|
+ model: 'gpt-35-turbo',
|
|
|
messages: [
|
|
messages: [
|
|
|
{
|
|
{
|
|
|
role: 'user',
|
|
role: 'user',
|
|
@@ -425,16 +438,6 @@ export class KnowledgeBaseService implements OnModuleInit {
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- async searchEmbedding(name: string, embedding: number[]) {
|
|
|
|
|
- return await KnowledgeEmbedding.findAll({
|
|
|
|
|
- where: {
|
|
|
|
|
- name
|
|
|
|
|
- },
|
|
|
|
|
- order: this.sequelize.literal(`embedding <-> '${formatEmbedding(embedding)}'`),
|
|
|
|
|
- limit: 100
|
|
|
|
|
- })
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
cutContext(context: string[]) {
|
|
cutContext(context: string[]) {
|
|
|
if (!context || !context.length) return []
|
|
if (!context || !context.length) return []
|
|
|
let max = 4096 - 1024
|
|
let max = 4096 - 1024
|
|
@@ -457,7 +460,18 @@ export class KnowledgeBaseService implements OnModuleInit {
|
|
|
order: this.sequelize.literal(`embedding <-> '${formatEmbedding(keywordEmbedding)}'`),
|
|
order: this.sequelize.literal(`embedding <-> '${formatEmbedding(keywordEmbedding)}'`),
|
|
|
limit: 100
|
|
limit: 100
|
|
|
})
|
|
})
|
|
|
- const context = await this.cutContext(relatedEmbeddings.map((item) => item.text))
|
|
|
|
|
|
|
+ const files = await this.knowledgeFileRepository.findBy({
|
|
|
|
|
+ fileHash: In([...new Set(relatedEmbeddings.map((item) => item.fileHash))])
|
|
|
|
|
+ })
|
|
|
|
|
+ const context = await this.cutContext(
|
|
|
|
|
+ relatedEmbeddings.map((item) => {
|
|
|
|
|
+ const file = files.find((i) => i.fileHash === item.fileHash)
|
|
|
|
|
+ if (file) {
|
|
|
|
|
+ return `***\n[${file.fileName}](${file.fileUrl}):\n\n${item.text}`
|
|
|
|
|
+ }
|
|
|
|
|
+ return item.text
|
|
|
|
|
+ })
|
|
|
|
|
+ )
|
|
|
return context
|
|
return context
|
|
|
}
|
|
}
|
|
|
}
|
|
}
|