|
|
@@ -0,0 +1,327 @@
|
|
|
+import {
|
|
|
+ BadRequestException,
|
|
|
+ ConflictException,
|
|
|
+ Injectable,
|
|
|
+ InternalServerErrorException,
|
|
|
+ Logger
|
|
|
+} from '@nestjs/common'
|
|
|
+import { InjectRepository } from '@nestjs/typeorm'
|
|
|
+import { KnowledgeBase } from './entities/knowledge-base.entity'
|
|
|
+import { Repository } from 'typeorm'
|
|
|
+import { Tiktoken, get_encoding } from '@dqbd/tiktoken'
|
|
|
+import { Configuration, OpenAIApi } from 'azure-openai'
|
|
|
+import { DataTypes, Sequelize } from 'sequelize'
|
|
|
+import { ConfigService } from '@nestjs/config'
|
|
|
+import { KnowledgeEmbedding } from './entities/knowledge-embedding.entity'
|
|
|
+import { VECTOR } from '../utils/pgvector'
|
|
|
+import * as queue from 'fastq'
|
|
|
+import { setTimeout } from 'timers/promises'
|
|
|
+import * as PdfParse from '@cyber2024/pdf-parse-fixed'
|
|
|
+import { createHash } from 'crypto'
|
|
|
+import { PageRequest } from '../common/dto/page-request'
|
|
|
+import { Pagination, paginate } from 'nestjs-typeorm-paginate'
|
|
|
+import { KnowledgeFile } from './entities/knowledge-file.entity'
|
|
|
+import { FileService } from 'src/file/file.service'
|
|
|
+import { FileStatus } from './enums/file-status.enum'
|
|
|
+
|
|
|
/**
 * Serialises an embedding vector as a bracketed, comma-separated list
 * (for example `[0.1, 0.2, 0.3]`).
 *
 * The result is interpolated into raw SQL fragments elsewhere in this file,
 * so every component is checked at runtime: TypeScript's `number[]` is erased
 * at compile time and gives no protection against string payloads.
 *
 * @param embedding - Vector components; each must be a finite number.
 * @returns The bracketed textual form of the vector.
 * @throws TypeError if any component is not a finite number.
 */
function formatEmbedding(embedding: number[]) {
  for (const component of embedding) {
    if (typeof component !== 'number' || !Number.isFinite(component)) {
      throw new TypeError(`Invalid embedding component: ${String(component)}`)
    }
  }
  return `[${embedding.join(', ')}]`
}
|
|
|
+
|
|
|
/**
 * Manages knowledge bases, their uploaded files, and the text embeddings
 * derived from those files.
 *
 * Knowledge bases and files are persisted through TypeORM repositories, while
 * embeddings are stored through a Sequelize model (`KnowledgeEmbedding`) that
 * this service initialises in its constructor.
 */
@Injectable()
export class KnowledgeBaseService {
  private readonly tokenizer: Tiktoken
  private readonly openai: OpenAIApi
  private readonly sequelize: Sequelize

  /**
   * Builds the tokenizer, the Azure OpenAI client and the Sequelize connection,
   * registers the `KnowledgeEmbedding` model and triggers a schema sync.
   *
   * @throws Error if `AZURE_OPENAI_API_KEY` is not configured.
   */
  constructor(
    @InjectRepository(KnowledgeBase)
    private readonly knowledgeBaseRepository: Repository<KnowledgeBase>,
    @InjectRepository(KnowledgeFile)
    private readonly knowledgeFileRepository: Repository<KnowledgeFile>,
    private readonly configService: ConfigService,
    private readonly fileService: FileService
  ) {
    this.tokenizer = get_encoding('cl100k_base')

    // Credentials must come from configuration, never from source control.
    const azureApiKey = configService.get<string>('AZURE_OPENAI_API_KEY')
    if (!azureApiKey) {
      throw new Error('AZURE_OPENAI_API_KEY is not configured')
    }
    // The endpoint is not a secret; keep the previous value as the default.
    const azureEndpoint = configService.get<string>('AZURE_OPENAI_ENDPOINT') || 'https://zouma.openai.azure.com'

    this.openai = new OpenAIApi(
      new Configuration({
        apiKey: azureApiKey,
        // add azure info into configuration
        azure: {
          apiKey: azureApiKey,
          endpoint: azureEndpoint
        }
      })
    )

    this.sequelize = new Sequelize({
      dialect: 'postgres',
      host: configService.get<string>('PG_HOST'),
      port: configService.get<number>('PG_PORT'),
      username: configService.get<string>('PG_USERNAME'),
      password: configService.get<string>('PG_PASSWORD'),
      database: configService.get<string>('PG_DATABASE'),
      // logging: (msg) => Logger.debug(msg, 'Sequelize')
      logging: false
    })

    KnowledgeEmbedding.init(
      {
        id: {
          primaryKey: true,
          autoIncrement: true,
          type: DataTypes.INTEGER
        },
        knowledgeId: {
          type: DataTypes.INTEGER
        },
        fileId: {
          type: DataTypes.INTEGER
        },
        fileHash: {
          type: DataTypes.STRING
        },
        text: {
          type: DataTypes.TEXT({
            length: 'long'
          })
        },
        embedding: {
          type: new VECTOR(1536)
        },
        index: {
          type: DataTypes.INTEGER
        }
      },
      { sequelize: this.sequelize }
    )

    // A constructor cannot await; attach a handler so a failed sync is logged
    // instead of surfacing as an unhandled promise rejection.
    this.sequelize.sync().catch((error) => {
      Logger.error(`sequelize sync failed: ${error.message}`, error.stack, 'KnowledgeBaseService')
    })
  }

  /**
   * Returns one page of knowledge bases.
   *
   * @param req - Page request; `req.page` and `req.search` are forwarded to `paginate`.
   */
  async findAllKnowledgeBase(req: PageRequest<KnowledgeBase>): Promise<Pagination<KnowledgeBase>> {
    return await paginate<KnowledgeBase>(this.knowledgeBaseRepository, req.page, req.search)
  }

  /**
   * Persists a new knowledge base via the repository's `save`.
   *
   * @param knowledgeBase - Fields of the knowledge base to store.
   * @returns The saved entity.
   */
  async createKnowledgeBase(knowledgeBase: Partial<KnowledgeBase>): Promise<KnowledgeBase> {
    return await this.knowledgeBaseRepository.save(knowledgeBase)
  }

  /**
   * Persists changes to a knowledge base via the repository's `save`.
   *
   * @param knowledgeBase - Fields of the knowledge base to store.
   * @returns The saved entity.
   */
  async updateKnowledgeBase(knowledgeBase: Partial<KnowledgeBase>): Promise<KnowledgeBase> {
    return await this.knowledgeBaseRepository.save(knowledgeBase)
  }

  /**
   * Deletes a knowledge base, then its file records, then its embeddings.
   *
   * @param knowledgeBaseId - Id of the knowledge base to remove.
   */
  async deleteKnowledgeBase(knowledgeBaseId: number): Promise<void> {
    await this.knowledgeBaseRepository.delete(knowledgeBaseId)
    await this.knowledgeFileRepository.delete({ knowledgeId: knowledgeBaseId })
    await KnowledgeEmbedding.destroy({
      where: {
        knowledgeId: knowledgeBaseId
      }
    })
  }

  /**
   * Loads a knowledge base by id using `findOneOrFail`, so the call rejects
   * when no row matches.
   *
   * @param knowledgeBaseId - Id to look up.
   */
  async getKnowledgeBaseById(knowledgeBaseId: number): Promise<KnowledgeBase> {
    return await this.knowledgeBaseRepository.findOneOrFail({ where: { id: knowledgeBaseId } })
  }

  /**
   * Returns one page of knowledge files.
   *
   * @param req - Page request; `req.page` and `req.search` are forwarded to `paginate`.
   */
  async fileAllKnowledgeFile(req: PageRequest<KnowledgeFile>): Promise<Pagination<KnowledgeFile>> {
    return await paginate<KnowledgeFile>(this.knowledgeFileRepository, req.page, req.search)
  }

  /**
   * Persists changes to a knowledge file via the repository's `save`.
   *
   * @param knowledgeFile - Fields of the knowledge file to store.
   * @returns The saved entity.
   */
  async updateKnowledgeFile(knowledgeFile: Partial<KnowledgeFile>): Promise<KnowledgeFile> {
    return await this.knowledgeFileRepository.save(knowledgeFile)
  }

  /**
   * Deletes a knowledge file record and the embeddings stored for it.
   *
   * @param knowledgeFileId - Id of the file to remove.
   */
  async deleteKnowledgeFile(knowledgeFileId: number): Promise<void> {
    await this.knowledgeFileRepository.delete(knowledgeFileId)
    await KnowledgeEmbedding.destroy({
      where: {
        fileId: knowledgeFileId
      }
    })
  }

  /**
   * Stores an uploaded file, records it against a knowledge base and starts
   * embedding generation in the background.
   *
   * @param file - The uploaded file (name, buffer, mime type and size are read).
   * @param knowledgeId - Knowledge base that will own the file.
   * @returns The saved `KnowledgeFile`; embedding generation may still be running.
   * @throws ConflictException if a file with the same MD5 hash is already recorded.
   */
  public async uploadKnowledgeFile(file: Express.Multer.File, knowledgeId: number) {
    const knowledgeBase = await this.getKnowledgeBaseById(knowledgeId)
    const { originalname, buffer, mimetype, size } = file
    const fileHash = this.calculateMD5(buffer)

    const existingFile = await this.knowledgeFileRepository.findOneBy({
      fileHash
    })
    if (existingFile) {
      throw new ConflictException(`File ${originalname} already exists`)
    }

    // NOTE(review): `slice(-1)` yields a one-element array rather than a string;
    // confirm this matches the FileService.uploadBuffer signature.
    const { url: fileUrl } = await this.fileService.uploadBuffer(
      buffer,
      mimetype.split('/')[1],
      originalname.split('.').slice(-1)
    )

    const knowledgeFile = new KnowledgeFile()
    knowledgeFile.orgId = knowledgeBase.orgId
    knowledgeFile.knowledgeId = knowledgeId
    knowledgeFile.fileHash = fileHash
    knowledgeFile.fileType = mimetype
    knowledgeFile.fileName = originalname
    knowledgeFile.size = size
    knowledgeFile.fileUrl = fileUrl
    await this.knowledgeFileRepository.save(knowledgeFile)

    // Deliberately not awaited so the upload returns immediately. The handler
    // covers the case where recording the FAILED status itself rejects.
    void this.processKnowledgeFile(knowledgeFile, buffer).catch((error) => {
      Logger.error(`processKnowledgeFile failed: ${error.message}`, error.stack, 'KnowledgeBaseService')
    })
    return knowledgeFile
  }

  /**
   * Parses the buffer as a PDF, splits its text into paragraphs, embeds each
   * paragraph and stores the embeddings, replacing any previously stored rows
   * that share the file's hash.
   *
   * The file status is set to PROCESSING at the start, DONE on completion and
   * FAILED (with the error message) if any step outside the per-row insert
   * throws. Per-row insert failures are logged and skipped.
   *
   * @param knowledgeFile - File record whose status and embeddings are updated.
   * @param buffer - Raw file contents, handed to the PDF parser.
   */
  public async processKnowledgeFile(knowledgeFile: KnowledgeFile, buffer: Buffer) {
    knowledgeFile.status = FileStatus.PROCESSING
    try {
      await this.knowledgeFileRepository.save(knowledgeFile)
      const pdf = await PdfParse(buffer)

      // Accumulate lines until one ends with sentence-final punctuation, then
      // emit the accumulated text as one paragraph.
      const contents: string[] = []
      let paragraph = ''
      for (const rawLine of pdf.text.trim().split('\n')) {
        const line = rawLine.trim()
        paragraph += line
        if (this.isFullSentence(line)) {
          contents.push(paragraph)
          paragraph = ''
        }
      }
      if (paragraph) {
        contents.push(paragraph)
      }

      const embeddings = await this.createEmbeddings(contents)
      Logger.log(
        `create embeddings finished, total token usage: ${embeddings.reduce((acc, cur) => acc + cur.token, 0)}`
      )

      // Remove rows from any earlier run for the same file content.
      await KnowledgeEmbedding.destroy({
        where: {
          fileHash: knowledgeFile.fileHash
        }
      })

      let i = 0
      for (const item of embeddings) {
        try {
          await KnowledgeEmbedding.create({
            knowledgeId: knowledgeFile.knowledgeId,
            fileId: knowledgeFile.id,
            fileHash: knowledgeFile.fileHash,
            text: item.text,
            embedding: formatEmbedding(item.embedding),
            index: i++
          })
        } catch (error) {
          // Best effort: one bad row must not abort the remaining inserts.
          Logger.error(error.message)
        }
      }

      knowledgeFile.status = FileStatus.DONE
      await this.knowledgeFileRepository.save(knowledgeFile)
    } catch (e) {
      knowledgeFile.status = FileStatus.FAILED
      knowledgeFile.error = e.message
      await this.knowledgeFileRepository.save(knowledgeFile)
    }
  }

  /**
   * Tests whether a string ends with one or more of the closing punctuation
   * characters listed in the pattern.
   *
   * @param str - Text to test.
   * @returns `true` when the trailing character(s) match.
   */
  isFullSentence(str: string): boolean {
    return /[.!?。!?…;;::”’)】》」』〕〉》〗〞〟»"'\])}]+$/.test(str)
  }

  /**
   * Computes the hex-encoded MD5 digest of a buffer.
   *
   * @param buffer - Bytes to hash.
   * @returns The digest as a hexadecimal string.
   */
  calculateMD5(buffer: Buffer): string {
    const hash = createHash('md5')
    hash.update(buffer)
    return hash.digest('hex')
  }

  /**
   * Requests an embedding for every entry of `content` through a queue with a
   * concurrency of 32.
   *
   * @param content - Texts to embed.
   * @returns Results in input order, without slots that were never filled
   *          (for example because the request failed) or whose text is empty.
   */
  async createEmbeddings(content: string[]) {
    const result = Array(content.length)
    const worker = async (arg: { text: string; index: number }) => {
      result[arg.index] = await this.getEmbedding(arg.text)
      Logger.log(`create embedding for ${arg.index + 1}/${content.length}`)
    }
    const q = queue.promise(worker, 32)
    content.forEach((text, index) => {
      q.push({
        text,
        index
      })
    })
    await q.drained()
    return result.filter((i) => i && i.text)
  }

  /**
   * Requests an embedding for a single text, waiting 2000 ms and retrying up
   * to three more times when the request fails.
   *
   * @param content - Text to embed.
   * @param retry - Number of retries already performed; callers normally omit it.
   * @returns The input text, its embedding vector and the reported token usage.
   * @throws InternalServerErrorException once the retries are exhausted.
   */
  async getEmbedding(content: string, retry = 0) {
    try {
      const response = await this.openai.createEmbedding({
        model: 'embedding',
        input: content
      })
      return {
        text: content,
        embedding: response.data.data[0].embedding,
        token: response.data.usage.total_tokens
      }
    } catch (error) {
      if (retry < 3) {
        Logger.error(`fetchEmbedding error: ${error.message}, retry ${retry}`, 'fetchEmbedding')
        await setTimeout(2000)
        return await this.getEmbedding(content, retry + 1)
      }
      Logger.error(error.stack, 'fetchEmbedding')
      throw new InternalServerErrorException(error.message)
    }
  }

  /**
   * Asks the chat completion model to extract comma-separated keywords from
   * the given text.
   *
   * @param text - Statement or question to extract keywords from.
   * @returns The content of the first returned choice.
   * @throws InternalServerErrorException if the request fails.
   */
  async getKeywords(text: string) {
    try {
      const res = await this.openai.createChatCompletion({
        model: 'gpt35',
        messages: [
          {
            role: 'user',
            content: `You need to extract keywords from the statement or question and return a series of keywords separated by commas.\ncontent: ${text}\nkeywords: `
          }
        ]
      })
      return res.data.choices[0].message.content
    } catch (error) {
      Logger.error(error.message)
      if (error.response) {
        Logger.error(error.response.data)
      }
      throw new InternalServerErrorException(error.message)
    }
  }

  /**
   * Returns up to 100 embeddings matching `name`, ordered by the `<->`
   * distance between the stored embedding and the supplied vector.
   *
   * NOTE(review): `name` is not among the attributes registered in the
   * constructor's `KnowledgeEmbedding.init` call — confirm the column exists.
   *
   * @param name - Value matched against the `name` column.
   * @param embedding - Query vector.
   */
  async searchEmbedding(name: string, embedding: number[]) {
    return await KnowledgeEmbedding.findAll({
      where: {
        name
      },
      order: this.sequelize.literal(`embedding <-> '${formatEmbedding(embedding)}'`),
      limit: 100
    })
  }

  /**
   * Returns the longest prefix of `context` whose cumulative token count stays
   * within a budget of 4096 - 1024 tokens.
   *
   * @param context - Candidate context strings, in priority order.
   * @returns The entries that fit; an empty array for empty or missing input.
   */
  cutContext(context: string[]) {
    if (!context || !context.length) return []
    let max = 4096 - 1024
    for (let i = 0; i < context.length; i++) {
      max -= this.tokenizer.encode(context[i]).length
      if (max < 0) {
        return context.slice(0, i)
      }
    }
    return context
  }

  /**
   * Extracts keywords from the question, embeds them, and returns up to 100
   * stored embeddings ordered by `<->` distance to that keyword embedding.
   *
   * NOTE(review): `orgId` is not among the attributes registered in the
   * constructor's `KnowledgeEmbedding.init` call — confirm the column exists.
   *
   * @param question - User question.
   * @param orgId - Organisation filter.
   * @param knowledgeId - Optional knowledge base filter; not applied when undefined.
   * @param fileId - Optional file filter; not applied when undefined.
   */
  async searchKnowledge(question: string, orgId: number, knowledgeId?: number, fileId?: number) {
    const keywords = await this.getKeywords(question)
    const { embedding: keywordEmbedding } = await this.getEmbedding(keywords)
    const context = await KnowledgeEmbedding.findAll({
      where: {
        orgId,
        // Omit unset optional filters: Sequelize rejects `undefined` where values.
        ...(knowledgeId !== undefined ? { knowledgeId } : {}),
        ...(fileId !== undefined ? { fileId } : {})
      },
      order: this.sequelize.literal(`embedding <-> '${formatEmbedding(keywordEmbedding)}'`),
      limit: 100
    })
    return context
  }
}
|