| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223 |
- import { ForbiddenException, Injectable, InternalServerErrorException, Logger } from '@nestjs/common'
- import { Observable } from 'rxjs'
- import { ChatGPTAPI, ChatMessage } from '../chatapi'
- import type { RequestProps } from './types'
- import { chatReplyProcess } from './chatgpt'
- import { Repository, MoreThanOrEqual, LessThanOrEqual, And } from 'typeorm'
- import { ChatHistory } from './entities/chat.entity'
- import { InjectRepository } from '@nestjs/typeorm'
- import { TokenUsage } from './entities/token-usage.entity'
- import { format } from 'date-fns'
- import { MembershipService } from '../membership/membership.service'
- import { MemberType } from '../membership/entities/membership.entity'
- import { get_encoding } from '@dqbd/tiktoken'
- import { fetchSSE } from 'src/chatapi/fetch-sse'
- import { HttpService } from '@nestjs/axios'
- import * as types from '../chatapi/types'
- import { SysConfigService } from 'src/sys-config/sys-config.service'
/**
 * Chat-facing service.
 *
 * Streams model replies produced by `chatReplyProcess` to the client, proxies
 * raw chat-completion requests to an Azure OpenAI deployment (configured via
 * the `AZURE_OPENAI_*` environment variables), and records chat history and
 * per-day token usage.
 */
@Injectable()
export class ChatService {
  /** `cl100k_base` encoder created once per service instance and reused for token counting. */
  tokenizer = get_encoding('cl100k_base')

  constructor(
    @InjectRepository(ChatHistory)
    private readonly chatHistoryRepository: Repository<ChatHistory>,
    @InjectRepository(TokenUsage)
    private readonly tokenUsageRepository: Repository<TokenUsage>,
    private readonly membershipService: MembershipService,
    private readonly httpService: HttpService,
    private readonly sysConfigService: SysConfigService
  ) {}

  /**
   * Runs a chat request and exposes the reply as an Observable.
   *
   * Each progress callback emits one JSON-serialised `ChatMessage`; every
   * emission after the first is prefixed with a newline so the concatenated
   * output is newline-delimited JSON. No membership check, history or usage
   * recording happens here.
   *
   * @param req request whose `body` is read as `RequestProps`
   * @param res response; only its `Content-Type` header is set here
   * @returns an Observable that completes when the reply finishes and errors
   *   when `chatReplyProcess` rejects
   */
  public chat(req, res): Observable<any> {
    res.setHeader('Content-Type', 'application/octet-stream')
    return new Observable((observer) => {
      const { prompt, options = {}, systemMessage, temperature, top_p } = req.body as RequestProps
      let firstChunk = true
      chatReplyProcess({
        message: prompt,
        lastContext: options,
        process: (chat: ChatMessage) => {
          observer.next(firstChunk ? JSON.stringify(chat) : `\n${JSON.stringify(chat)}`)
          firstChunk = false
        },
        systemMessage,
        temperature,
        top_p
      })
        // Exactly one terminal notification: complete on success, error on failure.
        .then(() => observer.complete())
        .catch((error) => observer.error(error))
    })
  }

  /**
   * Runs a chat request for a member and writes the reply straight to `res`
   * as newline-delimited JSON chunks, then records history and token usage.
   *
   * The system message falls back to the `system_message` sys-config value
   * when the request does not supply one.
   *
   * @param req request carrying `user.id` and a `RequestProps` body
   * @param res response the chunks are written to; always ended by this method
   *   once streaming has started
   * @throws ForbiddenException when the user has no membership, is a trial
   *   member with no tokens left, or the membership is expired
   */
  public async chat1(req, res) {
    res.setHeader('Content-type', 'application/octet-stream')
    const defSysMsg = (await this.sysConfigService.findByName('system_message'))?.value
    const membership = await this.membershipService.getMembership(req.user.id)
    if (!membership) {
      throw new ForbiddenException('请先成为会员')
    }
    if (membership.memberType == MemberType.Trial && membership.tokenLeft <= 0) {
      throw new ForbiddenException('您的试用额度已用完,请升级会员')
    }
    if (membership.isExpired) {
      throw new ForbiddenException('您的会员已过期,请即时续费')
    }
    try {
      const promptTime = new Date()
      const { prompt, options = {}, systemMessage, temperature, top_p } = req.body as RequestProps
      let firstChunk = true
      const result = await chatReplyProcess({
        message: prompt,
        lastContext: options,
        process: (chat: ChatMessage) => {
          res.write(firstChunk ? JSON.stringify(chat) : `\n${JSON.stringify(chat)}`)
          firstChunk = false
        },
        systemMessage: systemMessage || defSysMsg,
        temperature,
        top_p
      })
      // Bookkeeping never throws, so its failures cannot be appended to the
      // body that has already been streamed to the client.
      this.recordExchange(req.user.id, prompt, options, promptTime, result.data as ChatMessage)
    } catch (error) {
      Logger.error(error, 'CHAT')
      // `message` is non-enumerable on Error instances, so a bare
      // JSON.stringify(error) would drop it; copy it explicitly.
      // Non-Error rejections are serialised unchanged.
      res.write(JSON.stringify(error instanceof Error ? { ...error, message: error.message } : error))
    } finally {
      res.end()
    }
  }

  /**
   * Sends a non-streaming chat-completion request to the Azure OpenAI
   * deployment and records the reported token usage (without touching the
   * membership balance).
   *
   * Forces `req.body.stream` to `false` before forwarding the body.
   *
   * @param req request carrying `user.id` and the completion payload in `body`
   * @returns the upstream response body
   * @throws InternalServerErrorException carrying the upstream error body when
   *   one exists, otherwise the error message
   */
  public async chatProxy(req) {
    const url = this.azureChatCompletionsUrl()
    req.body.stream = false
    try {
      const { data } = await this.httpService.axiosRef.post(url, req.body, {
        headers: {
          'Content-Type': 'application/json',
          'api-key': `${process.env.AZURE_OPENAI_KEY}`
        }
      })
      const totalTokens = data?.usage?.total_tokens
      if (typeof totalTokens === 'number') {
        this.saveUsage(req.user.id, totalTokens, false).catch((e) => {
          Logger.error(e, 'SAVE USAGE')
        })
      } else {
        Logger.warn('response carried no usage.total_tokens; usage not recorded', 'CHAT PROXY')
      }
      return data
    } catch (e) {
      // `e.response` is absent when no HTTP response was received (for example
      // a network failure), so it must not be dereferenced unconditionally.
      throw new InternalServerErrorException(e?.response?.data ?? e?.message)
    }
  }

  /**
   * Sends a streaming chat-completion request to the Azure OpenAI deployment
   * and re-emits every server-sent message verbatim.
   *
   * Forces `req.body.stream` to `true`. The assistant text is accumulated from
   * the deltas; when the `[DONE]` sentinel arrives, the full transcript
   * (request messages plus accumulated reply) is tokenised locally and saved
   * as usage, and the Observable completes.
   *
   * @param req request carrying `user.id` and the completion payload in `body`
   * @returns an Observable of the raw SSE message strings
   */
  public streamChatProxy(req) {
    const url = this.azureChatCompletionsUrl()
    req.body.stream = true
    return new Observable((subscriber) => {
      let text = ''
      fetchSSE(url, {
        body: JSON.stringify(req.body),
        headers: {
          'Content-Type': 'application/json',
          'api-key': `${process.env.AZURE_OPENAI_KEY}`
        },
        method: 'POST',
        onMessage: (msg: string) => {
          subscriber.next(msg)
          if ('[DONE]' === msg) {
            Logger.log('done', 'CHAT PROXY')
            const transcript =
              req.body.messages.map((message) => `${message.role}:\n${message.content}`).join('\n\n') +
              '\n\nassistant:\n' +
              text
            // Fire-and-forget, but never leave the rejection unhandled.
            this.tiktokenAndSave(req.user.id, transcript).catch((e) => {
              Logger.error(e, 'CHAT PROXY')
            })
            return subscriber.complete()
          }
          try {
            const response: types.openai.CreateChatCompletionDeltaResponse = JSON.parse(msg)
            const content = response.choices?.[0]?.delta?.content
            if (content) text += content
          } catch (e) {
            // Parsing only feeds token accounting; the raw message has already
            // been forwarded, so a malformed chunk must not abort the stream.
            Logger.warn(`unparseable stream chunk ignored: ${e}`, 'CHAT PROXY')
          }
        },
        onError: (err) => {
          Logger.error(err, 'CHAT PROXY')
          subscriber.error(err)
        }
      }).catch((e) => {
        Logger.error(e, 'CHAT PROXY')
        subscriber.error(e)
      })
    })
  }

  /**
   * Counts the `cl100k_base` tokens in `text` and records them as usage
   * (without touching the membership balance).
   *
   * @param userId user the usage is attributed to
   * @param text text to tokenise
   */
  public async tiktokenAndSave(userId: number, text: string) {
    // TODO: use a better fix in the tokenizer
    const sanitized = text.replace(/<\|endoftext\|>/g, '')
    // Reuse the instance-level encoder instead of allocating a new one per
    // call that is never freed.
    const token = this.tokenizer.encode(sanitized).length
    await this.saveUsage(userId, token, false)
  }

  /**
   * Adds `usage` tokens to the user's usage row for today (`yyyy-MM-dd`),
   * creating the row when none exists, and optionally forwards the usage to
   * the membership service.
   *
   * NOTE(review): the row update is a read-then-write, so concurrent calls for
   * the same user and day may lose an increment — confirm whether that matters.
   *
   * @param userId user the usage is attributed to
   * @param usage number of tokens to add
   * @param member when true (default) the usage is also reported to the
   *   membership service
   */
  public async saveUsage(userId: number, usage: number, member = true): Promise<void> {
    const date = format(new Date(), 'yyyy-MM-dd')
    const tokenUsage = await this.tokenUsageRepository.findOneBy({
      userId,
      date
    })
    if (tokenUsage) {
      tokenUsage.usage += usage
      await this.tokenUsageRepository.save(tokenUsage)
    } else {
      await this.tokenUsageRepository.save(new TokenUsage({ userId, usage, date }))
    }
    if (member) {
      try {
        await this.membershipService.saveUsage(userId, usage)
      } catch (e) {
        // The daily row is already stored; a membership failure is logged
        // rather than left as an unhandled rejection.
        Logger.error(e, 'SAVE MEMBER USAGE')
      }
    }
  }

  /**
   * Returns the user's daily usage rows whose date lies in `[start, end]`.
   *
   * @param userId user whose usage is queried
   * @param start inclusive lower bound (`yyyy-MM-dd`); defaults to `2023-04-01`
   * @param end inclusive upper bound (`yyyy-MM-dd`); defaults to today
   */
  public async getUsage(userId: number, start?: string, end?: string): Promise<TokenUsage[]> {
    const today = format(new Date(), 'yyyy-MM-dd')
    return await this.tokenUsageRepository.findBy({
      userId,
      date: And(MoreThanOrEqual(start || '2023-04-01'), LessThanOrEqual(end || today))
    })
  }

  /** Builds the Azure OpenAI chat-completions URL from the `AZURE_OPENAI_*` environment variables. */
  private azureChatCompletionsUrl(): string {
    return `${process.env.AZURE_OPENAI_ENDPOINT}/openai/deployments/${process.env.AZURE_OPENAI_DEPLOYMENT}/chat/completions?api-version=${process.env.AZURE_OPENAI_VERSION}`
  }

  /**
   * Stores the user prompt and the assistant reply as chat history and records
   * the reply's total token usage. Best-effort: every failure is logged and
   * nothing is thrown to the caller.
   */
  private recordExchange(
    userId: number,
    prompt: string,
    options: NonNullable<RequestProps['options']>,
    promptTime: Date,
    reply: ChatMessage
  ): void {
    try {
      const usage = reply.detail.usage
      const entries = [
        new ChatHistory({
          messageId: reply.parentMessageId,
          parentMessageId: options.parentMessageId,
          userId,
          message: prompt,
          role: 'user',
          token: usage.prompt_tokens,
          time: promptTime
        }),
        new ChatHistory({
          messageId: reply.id,
          parentMessageId: reply.parentMessageId,
          userId,
          message: reply.text,
          role: 'assistant',
          token: usage.completion_tokens,
          time: new Date()
        })
      ]
      for (const entry of entries) {
        this.chatHistoryRepository.save(entry).catch((e) => {
          Logger.error(e, 'SAVE CHAT HISTORY')
        })
      }
      this.saveUsage(userId, usage.total_tokens).catch((e) => {
        Logger.error(e, 'SAVE USAGE')
      })
    } catch (e) {
      Logger.error(e, 'SAVE CHAT HISTORY')
    }
  }
}
|