xiongzhu 3 жил өмнө
parent
commit
43b6292ff0

+ 2 - 1
src/chat/chat.module.ts

@@ -3,9 +3,10 @@ import { ChatService } from './chat.service'
 import { ChatController } from './chat.controller'
 import { TypeOrmModule } from '@nestjs/typeorm'
 import { ChatHistory } from './entities/chat.entity'
+import { TokenUsage } from './entities/token-usage.entity'
 
 @Module({
-    imports: [TypeOrmModule.forFeature([ChatHistory])],
+    imports: [TypeOrmModule.forFeature([ChatHistory, TokenUsage])],
     providers: [ChatService],
     controllers: [ChatController]
 })

+ 55 - 3
src/chat/chat.service.ts

@@ -6,10 +6,17 @@ import { chatReplyProcess } from './chatgpt'
 import { Repository } from 'typeorm'
 import { ChatHistory } from './entities/chat.entity'
 import { InjectRepository } from '@nestjs/typeorm'
+import { TokenUsage } from './entities/token-usage.entity'
+import { format } from 'date-fns'
 
 @Injectable()
 export class ChatService {
-    constructor(@InjectRepository(ChatHistory) private readonly chatHistoryRepository: Repository<ChatHistory>) {}
+    constructor(
+        @InjectRepository(ChatHistory)
+        private readonly chatHistoryRepository: Repository<ChatHistory>,
+        @InjectRepository(TokenUsage)
+        private readonly tokenUsageRepository: Repository<TokenUsage>
+    ) {}
 
     public chat(req, res): Observable<any> {
         res.setHeader('Content-Type', 'application/octet-stream')
@@ -46,7 +53,7 @@ export class ChatService {
         res.setHeader('Content-type', 'application/octet-stream')
 
         try {
-            Logger.log(JSON.stringify(req.body, null, 2), 'ASK')
+            const promptTime = new Date()
             const { prompt, options = {}, systemMessage, temperature, top_p } = req.body as RequestProps
             let firstChunk = true
             const result = await chatReplyProcess({
@@ -61,11 +68,56 @@ export class ChatService {
                 top_p
             })
             let chatMessage = result.data as ChatMessage
-            Logger.log(JSON.stringify(result, null, 2), 'ANSWER')
+
+            this.chatHistoryRepository
+                .save(
+                    new ChatHistory({
+                        messageId: chatMessage.parentMessageId,
+                        parentMessageId: options.parentMessageId,
+                        userId: req.user.id,
+                        message: prompt,
+                        role: 'user',
+                        token: chatMessage.detail.usage.prompt_tokens,
+                        time: promptTime
+                    })
+                )
+                .catch((e) => {
+                    Logger.error(e, 'SAVE CHAT HISTORY')
+                })
+            this.chatHistoryRepository
+                .save(
+                    new ChatHistory({
+                        messageId: chatMessage.id,
+                        parentMessageId: chatMessage.parentMessageId,
+                        userId: req.user.id,
+                        message: chatMessage.text,
+                        role: 'assistant',
+                        token: chatMessage.detail.usage.completion_tokens,
+                        time: new Date()
+                    })
+                )
+                .catch((e) => {
+                    Logger.error(e, 'SAVE CHAT HISTORY')
+                })
+            this.saveUsage(req.user.id, chatMessage.detail.usage.total_tokens)
         } catch (error) {
             res.write(JSON.stringify(error))
         } finally {
             res.end()
         }
     }
+
+    public async saveUsage(userId: number, usage: number) {
+        const date = format(new Date(), 'yyyy-MM-dd')
+        const tokenUsage = await this.tokenUsageRepository.findOneBy({
+            userId,
+            date
+        })
+        if (tokenUsage) {
+            tokenUsage.usage += usage
+            await this.tokenUsageRepository.save(tokenUsage)
+        } else {
+            await this.tokenUsageRepository.save(new TokenUsage({ userId, usage, date }))
+        }
+    }
 }

+ 1 - 1
src/chat/chatgpt/index.ts

@@ -34,7 +34,7 @@ let api: ChatGPTAPI | ChatGPTUnofficialProxyAPI
     if (isNotEmptyString(process.env.OPENAI_API_KEY)) {
         const OPENAI_API_BASE_URL = process.env.OPENAI_API_BASE_URL
         const store = new KeyvRedis(process.env.REDIS_URI)
-        const messageStore = new Keyv({ store, namespace: 'chatgpt', ttl: 100 })
+        const messageStore = new Keyv({ store, namespace: 'chatgpt', ttl: 90 * 24 * 60 * 60 * 1000 })
         const options: ChatGPTAPIOptions = {
             apiKey: process.env.AZURE_OPENAI_KEY,
             apiEndpoint: process.env.AZURE_OPENAI_ENDPOINT,

+ 18 - 6
src/chat/entities/chat.entity.ts

@@ -1,20 +1,32 @@
-import { CreateDateColumn, Entity, PrimaryGeneratedColumn } from 'typeorm'
+import { Column, CreateDateColumn, Entity, PrimaryColumn, PrimaryGeneratedColumn } from 'typeorm'
 
 @Entity()
 export class ChatHistory {
     @PrimaryGeneratedColumn()
-    id: number
+    id: string
 
-    userId: number
+    @Column()
+    messageId: string
+
+    @Column({ nullable: true })
+    parentMessageId?: string
 
-    chatId: string
+    @Column()
+    userId: number
 
+    @Column({ type: 'text' })
     message: string
 
+    @Column()
     role: string
 
+    @Column()
     token: number
 
-    @CreateDateColumn()
-    createdAt: Date
+    @Column()
+    time: Date
+
+    constructor(data: Partial<ChatHistory> ) {
+        Object.assign(this, data)
+    }
 }

+ 21 - 0
src/chat/entities/token-usage.entity.ts

@@ -0,0 +1,21 @@
+import { Column, Entity, PrimaryGeneratedColumn, Unique } from 'typeorm'
+
+@Entity()
+@Unique('user_date', ['userId', 'date'])
+export class TokenUsage {
+    @PrimaryGeneratedColumn()
+    id: number
+
+    @Column()
+    userId: number
+
+    @Column({ type: 'date' })
+    date: string
+
+    @Column({ default: 0 })
+    usage: number
+
+    constructor(data: Partial<TokenUsage>) {
+        Object.assign(this, data)
+    }
+}

+ 5 - 1
src/membership/membership.service.ts

@@ -28,6 +28,7 @@ export class MembershipService {
         membership.userId = userId
         membership.expireAt = addDays(new Date(), 7)
         membership.memberType = MemberType.Trial
+        membership.tokenLeft = 3000
         return await this.memberShipRepository.save(membership)
     }
 
@@ -46,11 +47,14 @@ export class MembershipService {
             membership.userId = userId
             membership.planId = planId
             membership.expireAt = addDays(new Date(), plan.duration)
+            membership.tokenLeft = plan.tokenLimit
         } else {
             if (membership.expireAt < new Date() || membership.memberType == MemberType.Trial) {
                 membership.expireAt = addDays(new Date(), plan.duration)
+                membership.tokenLeft = plan.tokenLimit
             } else {
                 membership.expireAt = addDays(membership.expireAt, plan.duration)
+                membership.tokenLeft = Math.max(membership.tokenLeft, 0) + plan.tokenLimit
             }
         }
         membership.memberType = MemberType.Paid
@@ -65,7 +69,7 @@ export class MembershipService {
             userId: userId
         })
         if (!membership) {
-            throw new NotFoundException(`Membership #${userId} not found`)
+            return await this.trial(userId)
         }
         return membership
     }

+ 18 - 0
src/weixin/types.ts

@@ -70,4 +70,22 @@ declare namespace WechatPay {
         summary: string
         resource: Resource
     }
+
+    interface EncryptCertificate {
+        algorithm: string
+        nonce: string
+        associated_data: string
+        ciphertext: string
+    }
+
+    interface CertData {
+        serial_no: string
+        effective_time: Date
+        expire_time: Date
+        encrypt_certificate: EncryptCertificate
+    }
+
+    interface GetCertResponse {
+        data: CertData[]
+    }
 }

+ 7 - 6
src/weixin/weixin.service.ts

@@ -204,8 +204,8 @@ export class WeixinService {
                 this.weixinConfiguration.certSerial,
                 this.privateKey
             )
-            Logger.log(JSON.stringify(result.data, null, 2), '获取微信证书')
-            let data = JSON.stringify(result.data)
+            Logger.log('OK', '获取平台证书')
+            const data = result.data as WechatPay.GetCertResponse
             let headers = result.headers
             let serial = headers['wechatpay-serial']
             let timestamp = headers['wechatpay-timestamp']
@@ -214,7 +214,7 @@ export class WeixinService {
 
             // 根据序列号查证书  验证签名
             // let verifySignature: boolean = PayKit.verifySignature(signature, data, nonce, timestamp, wxPublicKey)
-            let verifySignature: boolean = PayKit.verifySign(headers, data, this.platformPlublicKey)
+            let verifySignature: boolean = PayKit.verifySign(headers, JSON.stringify(data), this.platformPlublicKey)
             Logger.log(verifySignature, '验证签名')
 
             let certPath = this.weixinConfiguration.certPath + 'platform_cert.pem'
@@ -223,13 +223,14 @@ export class WeixinService {
             })
             let decrypt = PayKit.aes256gcmDecrypt(
                 this.weixinConfiguration.mchKey,
-                result.data.data[0].encrypt_certificate.nonce,
-                result.data.data[0].encrypt_certificate.associated_data,
-                result.data.data[0].encrypt_certificate.ciphertext
+                data.data[0].encrypt_certificate.nonce,
+                data.data[0].encrypt_certificate.associated_data,
+                data.data[0].encrypt_certificate.ciphertext
             )
             // 保存证书
             fs.writeFileSync(certPath, decrypt)
             this.platformPlublicKey = fs.readFileSync(this.weixinConfiguration.certPath + 'platform_cert.pem')
+            Logger.log(this.weixinConfiguration.certPath + 'platform_cert.pem', '保存平台证书')
             return data
         } catch (error) {
             Logger.error(error)