mirror of
https://github.com/father-bot/chatgpt_telegram_bot.git
synced 2026-06-13 03:54:57 +03:00
+189
-7
@@ -31,6 +31,7 @@ import config
|
|||||||
import database
|
import database
|
||||||
import openai_utils
|
import openai_utils
|
||||||
|
|
||||||
|
import base64
|
||||||
|
|
||||||
# setup
|
# setup
|
||||||
db = database.Database()
|
db = database.Database()
|
||||||
@@ -177,6 +178,168 @@ async def retry_handle(update: Update, context: CallbackContext):
|
|||||||
|
|
||||||
await message_handle(update, context, message=last_dialog_message["user"], use_new_dialog_timeout=False)
|
await message_handle(update, context, message=last_dialog_message["user"], use_new_dialog_timeout=False)
|
||||||
|
|
||||||
|
async def _vision_message_handle_fn(
    update: Update, context: CallbackContext, use_new_dialog_timeout: bool = True
):
    """Answer a (possibly image-carrying) message with the vision model.

    Downloads the attached photo into memory, sends the caption/text plus
    the image to ``openai_utils.ChatGPT``, progressively edits a placeholder
    message with the answer, then stores the exchange and the token usage
    in the database.

    Args:
        update: Incoming Telegram update; ``update.message`` is read.
        context: Handler context; ``context.bot`` is used to fetch the file
            and to edit the placeholder message.
        use_new_dialog_timeout: When true, start a new dialog if the user
            has been idle longer than ``config.new_dialog_timeout``.
    """
    logger.info('_vision_message_handle_fn')
    user_id = update.message.from_user.id
    current_model = db.get_user_attribute(user_id, "current_model")

    if current_model != "gpt-4-vision-preview":
        await update.message.reply_text(
            "🥲 Images processing is only available for <b>gpt-4-vision-preview</b> model. Please change your settings in /settings",
            parse_mode=ParseMode.HTML,
        )
        return

    chat_mode = db.get_user_attribute(user_id, "current_chat_mode")

    # new dialog timeout
    if use_new_dialog_timeout:
        idle = datetime.now() - db.get_user_attribute(user_id, "last_interaction")
        # total_seconds() counts whole days too; .seconds alone wraps around
        # every 24h and would miss long pauses.
        if idle.total_seconds() > config.new_dialog_timeout and len(db.get_dialog_messages(user_id)) > 0:
            db.start_new_dialog(user_id)
            await update.message.reply_text(f"Starting new dialog due to timeout (<b>{config.chat_modes[chat_mode]['name']}</b> mode) ✅", parse_mode=ParseMode.HTML)
    db.set_user_attribute(user_id, "last_interaction", datetime.now())

    buf = None
    if update.message.effective_attachment:
        # the last entry of the attachment list is the one that gets downloaded
        photo = update.message.effective_attachment[-1]
        photo_file = await context.bot.get_file(photo.file_id)

        # store file in memory, not on disk
        buf = io.BytesIO()
        await photo_file.download_to_memory(buf)
        buf.name = "image.jpg"  # file extension is required
        buf.seek(0)  # move cursor to the beginning of the buffer

    # in case of CancelledError
    n_input_tokens, n_output_tokens = 0, 0

    try:
        message = update.message.caption or update.message.text

        # Validate before sending the placeholder so that an empty message
        # does not leave a dangling "..." in the chat.
        if not message:
            await update.message.reply_text(
                "🥲 You sent <b>empty message</b>. Please, try again!",
                parse_mode=ParseMode.HTML,
            )
            return

        # send placeholder message to user
        placeholder_message = await update.message.reply_text("...")

        # send typing action
        await update.message.chat.send_action(action="typing")

        dialog_messages = db.get_dialog_messages(user_id, dialog_id=None)
        parse_mode = {"html": ParseMode.HTML, "markdown": ParseMode.MARKDOWN}[
            config.chat_modes[chat_mode]["parse_mode"]
        ]

        chatgpt_instance = openai_utils.ChatGPT(model=current_model)
        if config.enable_message_streaming:
            gen = chatgpt_instance.send_vision_message_stream(
                message,
                dialog_messages=dialog_messages,
                image_buffer=buf,
                chat_mode=chat_mode,
            )
        else:
            (
                answer,
                (n_input_tokens, n_output_tokens),
                n_first_dialog_messages_removed,
            ) = await chatgpt_instance.send_vision_message(
                message,
                dialog_messages=dialog_messages,
                image_buffer=buf,
                chat_mode=chat_mode,
            )

            # wrap the single result so both modes share the loop below
            async def fake_gen():
                yield "finished", answer, (
                    n_input_tokens,
                    n_output_tokens,
                ), n_first_dialog_messages_removed

            gen = fake_gen()

        prev_answer = ""
        async for gen_item in gen:
            (
                status,
                answer,
                (n_input_tokens, n_output_tokens),
                n_first_dialog_messages_removed,
            ) = gen_item
            answer = current_model + " " + answer
            answer = answer[:4096]  # telegram message limit

            # update only when 100 new symbols are ready
            if abs(len(answer) - len(prev_answer)) < 100 and status != "finished":
                continue

            try:
                await context.bot.edit_message_text(
                    answer,
                    chat_id=placeholder_message.chat_id,
                    message_id=placeholder_message.message_id,
                    parse_mode=parse_mode,
                )
            except telegram.error.BadRequest as e:
                if str(e).startswith("Message is not modified"):
                    continue
                else:
                    # retry without parse_mode (plain text)
                    await context.bot.edit_message_text(
                        answer,
                        chat_id=placeholder_message.chat_id,
                        message_id=placeholder_message.message_id,
                    )

            await asyncio.sleep(0.01)  # wait a bit to avoid flooding

            prev_answer = answer

        # update user data
        user_parts = [{"type": "text", "text": message}]
        if buf is not None:
            base_image = base64.b64encode(buf.getvalue()).decode("utf-8")
            user_parts.append({"type": "image", "image": base_image})
        new_dialog_message = {"user": user_parts, "bot": answer, "date": datetime.now()}

        db.set_dialog_messages(
            user_id,
            db.get_dialog_messages(user_id, dialog_id=None) + [new_dialog_message],
            dialog_id=None
        )

        db.update_n_used_tokens(user_id, current_model, n_input_tokens, n_output_tokens)

    except asyncio.CancelledError:
        # note: intermediate token updates only work when enable_message_streaming=True (config.yml)
        db.update_n_used_tokens(user_id, current_model, n_input_tokens, n_output_tokens)
        raise

    except Exception as e:
        error_text = f"Something went wrong during completion. Reason: {e}"
        # logger.exception keeps the traceback alongside the message
        logger.exception(error_text)
        await update.message.reply_text(error_text)
        return
|
||||||
|
|
||||||
|
async def unsupport_message_handle(update: Update, context: CallbackContext, message=None):
    """Tell the user that this kind of message cannot be processed.

    Logs the notice and sends it back as a reply. ``context`` and
    ``message`` are accepted but not used.
    """
    notice = "I don't know how to read files or videos. Send the picture in normal mode (Quick Mode)."
    logger.error(notice)
    await update.message.reply_text(notice)
|
||||||
|
|
||||||
async def message_handle(update: Update, context: CallbackContext, message=None, use_new_dialog_timeout=True):
|
async def message_handle(update: Update, context: CallbackContext, message=None, use_new_dialog_timeout=True):
|
||||||
# check if bot was mentioned (for group chats)
|
# check if bot was mentioned (for group chats)
|
||||||
@@ -204,6 +367,8 @@ async def message_handle(update: Update, context: CallbackContext, message=None,
|
|||||||
await generate_image_handle(update, context, message=message)
|
await generate_image_handle(update, context, message=message)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
current_model = db.get_user_attribute(user_id, "current_model")
|
||||||
|
|
||||||
async def message_handle_fn():
|
async def message_handle_fn():
|
||||||
# new dialog timeout
|
# new dialog timeout
|
||||||
if use_new_dialog_timeout:
|
if use_new_dialog_timeout:
|
||||||
@@ -214,7 +379,6 @@ async def message_handle(update: Update, context: CallbackContext, message=None,
|
|||||||
|
|
||||||
# in case of CancelledError
|
# in case of CancelledError
|
||||||
n_input_tokens, n_output_tokens = 0, 0
|
n_input_tokens, n_output_tokens = 0, 0
|
||||||
current_model = db.get_user_attribute(user_id, "current_model")
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# send placeholder message to user
|
# send placeholder message to user
|
||||||
@@ -249,11 +413,12 @@ async def message_handle(update: Update, context: CallbackContext, message=None,
|
|||||||
gen = fake_gen()
|
gen = fake_gen()
|
||||||
|
|
||||||
prev_answer = ""
|
prev_answer = ""
|
||||||
|
|
||||||
async for gen_item in gen:
|
async for gen_item in gen:
|
||||||
status, answer, (n_input_tokens, n_output_tokens), n_first_dialog_messages_removed = gen_item
|
status, answer, (n_input_tokens, n_output_tokens), n_first_dialog_messages_removed = gen_item
|
||||||
|
answer = current_model + " " + answer
|
||||||
answer = answer[:4096] # telegram message limit
|
answer = answer[:4096] # telegram message limit
|
||||||
|
|
||||||
# update only when 100 new symbols are ready
|
# update only when 100 new symbols are ready
|
||||||
if abs(len(answer) - len(prev_answer)) < 100 and status != "finished":
|
if abs(len(answer) - len(prev_answer)) < 100 and status != "finished":
|
||||||
continue
|
continue
|
||||||
@@ -267,11 +432,12 @@ async def message_handle(update: Update, context: CallbackContext, message=None,
|
|||||||
await context.bot.edit_message_text(answer, chat_id=placeholder_message.chat_id, message_id=placeholder_message.message_id)
|
await context.bot.edit_message_text(answer, chat_id=placeholder_message.chat_id, message_id=placeholder_message.message_id)
|
||||||
|
|
||||||
await asyncio.sleep(0.01) # wait a bit to avoid flooding
|
await asyncio.sleep(0.01) # wait a bit to avoid flooding
|
||||||
|
|
||||||
prev_answer = answer
|
prev_answer = answer
|
||||||
|
|
||||||
# update user data
|
# update user data
|
||||||
new_dialog_message = {"user": _message, "bot": answer, "date": datetime.now()}
|
new_dialog_message = {"user": _message, "bot": answer, "date": datetime.now()}
|
||||||
|
|
||||||
db.set_dialog_messages(
|
db.set_dialog_messages(
|
||||||
user_id,
|
user_id,
|
||||||
db.get_dialog_messages(user_id, dialog_id=None) + [new_dialog_message],
|
db.get_dialog_messages(user_id, dialog_id=None) + [new_dialog_message],
|
||||||
@@ -300,7 +466,19 @@ async def message_handle(update: Update, context: CallbackContext, message=None,
|
|||||||
await update.message.reply_text(text, parse_mode=ParseMode.HTML)
|
await update.message.reply_text(text, parse_mode=ParseMode.HTML)
|
||||||
|
|
||||||
async with user_semaphores[user_id]:
|
async with user_semaphores[user_id]:
|
||||||
task = asyncio.create_task(message_handle_fn())
|
if current_model == "gpt-4-vision-preview" or update.message.photo is not None and len(update.message.photo) > 0:
|
||||||
|
logger.error('gpt-4-vision-preview')
|
||||||
|
if current_model != "gpt-4-vision-preview":
|
||||||
|
current_model = "gpt-4-vision-preview"
|
||||||
|
db.set_user_attribute(user_id, "current_model", "gpt-4-vision-preview")
|
||||||
|
task = asyncio.create_task(
|
||||||
|
_vision_message_handle_fn(update, context, use_new_dialog_timeout=use_new_dialog_timeout)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
task = asyncio.create_task(
|
||||||
|
message_handle_fn()
|
||||||
|
)
|
||||||
|
|
||||||
user_tasks[user_id] = task
|
user_tasks[user_id] = task
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -392,6 +570,7 @@ async def new_dialog_handle(update: Update, context: CallbackContext):
|
|||||||
|
|
||||||
user_id = update.message.from_user.id
|
user_id = update.message.from_user.id
|
||||||
db.set_user_attribute(user_id, "last_interaction", datetime.now())
|
db.set_user_attribute(user_id, "last_interaction", datetime.now())
|
||||||
|
db.set_user_attribute(user_id, "current_model", "gpt-3.5-turbo")
|
||||||
|
|
||||||
db.start_new_dialog(user_id)
|
db.start_new_dialog(user_id)
|
||||||
await update.message.reply_text("Starting new dialog ✅")
|
await update.message.reply_text("Starting new dialog ✅")
|
||||||
@@ -672,6 +851,9 @@ def run_bot() -> None:
|
|||||||
application.add_handler(CommandHandler("help_group_chat", help_group_chat_handle, filters=user_filter))
|
application.add_handler(CommandHandler("help_group_chat", help_group_chat_handle, filters=user_filter))
|
||||||
|
|
||||||
application.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND & user_filter, message_handle))
|
application.add_handler(MessageHandler(filters.TEXT & ~filters.COMMAND & user_filter, message_handle))
|
||||||
|
application.add_handler(MessageHandler(filters.PHOTO & ~filters.COMMAND & user_filter, message_handle))
|
||||||
|
application.add_handler(MessageHandler(filters.VIDEO & ~filters.COMMAND & user_filter, unsupport_message_handle))
|
||||||
|
application.add_handler(MessageHandler(filters.Document.ALL & ~filters.COMMAND & user_filter, unsupport_message_handle))
|
||||||
application.add_handler(CommandHandler("retry", retry_handle, filters=user_filter))
|
application.add_handler(CommandHandler("retry", retry_handle, filters=user_filter))
|
||||||
application.add_handler(CommandHandler("new", new_dialog_handle, filters=user_filter))
|
application.add_handler(CommandHandler("new", new_dialog_handle, filters=user_filter))
|
||||||
application.add_handler(CommandHandler("cancel", cancel_handle, filters=user_filter))
|
application.add_handler(CommandHandler("cancel", cancel_handle, filters=user_filter))
|
||||||
@@ -694,4 +876,4 @@ def run_bot() -> None:
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
run_bot()
|
run_bot()
|
||||||
+154
-12
@@ -1,4 +1,7 @@
|
|||||||
|
import base64
|
||||||
|
from io import BytesIO
|
||||||
import config
|
import config
|
||||||
|
import logging
|
||||||
|
|
||||||
import tiktoken
|
import tiktoken
|
||||||
import openai
|
import openai
|
||||||
@@ -8,6 +11,7 @@ import openai
|
|||||||
openai.api_key = config.openai_api_key
|
openai.api_key = config.openai_api_key
|
||||||
if config.openai_api_base is not None:
|
if config.openai_api_base is not None:
|
||||||
openai.api_base = config.openai_api_base
|
openai.api_base = config.openai_api_base
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
OPENAI_COMPLETION_OPTIONS = {
|
OPENAI_COMPLETION_OPTIONS = {
|
||||||
@@ -22,7 +26,7 @@ OPENAI_COMPLETION_OPTIONS = {
|
|||||||
|
|
||||||
class ChatGPT:
|
class ChatGPT:
|
||||||
def __init__(self, model="gpt-3.5-turbo"):
|
def __init__(self, model="gpt-3.5-turbo"):
|
||||||
assert model in {"text-davinci-003", "gpt-3.5-turbo-16k", "gpt-3.5-turbo", "gpt-4", "gpt-4-1106-preview"}, f"Unknown model: {model}"
|
assert model in {"text-davinci-003", "gpt-3.5-turbo-16k", "gpt-3.5-turbo", "gpt-4", "gpt-4-1106-preview", "gpt-4-vision-preview"}, f"Unknown model: {model}"
|
||||||
self.model = model
|
self.model = model
|
||||||
|
|
||||||
async def send_message(self, message, dialog_messages=[], chat_mode="assistant"):
|
async def send_message(self, message, dialog_messages=[], chat_mode="assistant"):
|
||||||
@@ -33,8 +37,9 @@ class ChatGPT:
|
|||||||
answer = None
|
answer = None
|
||||||
while answer is None:
|
while answer is None:
|
||||||
try:
|
try:
|
||||||
if self.model in {"gpt-3.5-turbo-16k", "gpt-3.5-turbo", "gpt-4", "gpt-4-1106-preview"}:
|
if self.model in {"gpt-3.5-turbo-16k", "gpt-3.5-turbo", "gpt-4", "gpt-4-1106-preview", "gpt-4-vision-preview"}:
|
||||||
messages = self._generate_prompt_messages(message, dialog_messages, chat_mode)
|
messages = self._generate_prompt_messages(message, dialog_messages, chat_mode)
|
||||||
|
|
||||||
r = await openai.ChatCompletion.acreate(
|
r = await openai.ChatCompletion.acreate(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
@@ -75,6 +80,7 @@ class ChatGPT:
|
|||||||
try:
|
try:
|
||||||
if self.model in {"gpt-3.5-turbo-16k", "gpt-3.5-turbo", "gpt-4", "gpt-4-1106-preview"}:
|
if self.model in {"gpt-3.5-turbo-16k", "gpt-3.5-turbo", "gpt-4", "gpt-4-1106-preview"}:
|
||||||
messages = self._generate_prompt_messages(message, dialog_messages, chat_mode)
|
messages = self._generate_prompt_messages(message, dialog_messages, chat_mode)
|
||||||
|
|
||||||
r_gen = await openai.ChatCompletion.acreate(
|
r_gen = await openai.ChatCompletion.acreate(
|
||||||
model=self.model,
|
model=self.model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
@@ -85,11 +91,15 @@ class ChatGPT:
|
|||||||
answer = ""
|
answer = ""
|
||||||
async for r_item in r_gen:
|
async for r_item in r_gen:
|
||||||
delta = r_item.choices[0].delta
|
delta = r_item.choices[0].delta
|
||||||
|
|
||||||
if "content" in delta:
|
if "content" in delta:
|
||||||
answer += delta.content
|
answer += delta.content
|
||||||
n_input_tokens, n_output_tokens = self._count_tokens_from_messages(messages, answer, model=self.model)
|
n_input_tokens, n_output_tokens = self._count_tokens_from_messages(messages, answer, model=self.model)
|
||||||
n_first_dialog_messages_removed = n_dialog_messages_before - len(dialog_messages)
|
n_first_dialog_messages_removed = 0
|
||||||
|
|
||||||
yield "not_finished", answer, (n_input_tokens, n_output_tokens), n_first_dialog_messages_removed
|
yield "not_finished", answer, (n_input_tokens, n_output_tokens), n_first_dialog_messages_removed
|
||||||
|
|
||||||
|
|
||||||
elif self.model == "text-davinci-003":
|
elif self.model == "text-davinci-003":
|
||||||
prompt = self._generate_prompt(message, dialog_messages, chat_mode)
|
prompt = self._generate_prompt(message, dialog_messages, chat_mode)
|
||||||
r_gen = await openai.Completion.acreate(
|
r_gen = await openai.Completion.acreate(
|
||||||
@@ -117,6 +127,109 @@ class ChatGPT:
|
|||||||
|
|
||||||
yield "finished", answer, (n_input_tokens, n_output_tokens), n_first_dialog_messages_removed # sending final answer
|
yield "finished", answer, (n_input_tokens, n_output_tokens), n_first_dialog_messages_removed # sending final answer
|
||||||
|
|
||||||
|
async def send_vision_message(
    self,
    message,
    dialog_messages=None,
    chat_mode="assistant",
    image_buffer: BytesIO = None,
):
    """Send a text (+ optional image) message and return the full answer.

    On ``openai.error.InvalidRequestError`` (treated as "too many tokens")
    the oldest dialog message is dropped and the request is retried.

    Args:
        message: Text of the current user message.
        dialog_messages: Previous dialog turns; defaults to an empty history.
        chat_mode: Key into ``config.chat_modes``.
        image_buffer: In-memory image to attach, or ``None``.

    Returns:
        ``(answer, (n_input_tokens, n_output_tokens),
        n_first_dialog_messages_removed)``.

    Raises:
        ValueError: If the model is not ``gpt-4-vision-preview``, or the
            request is still rejected with an empty dialog history.
    """
    # Checked once up front: the model cannot change between retries.
    if self.model != "gpt-4-vision-preview":
        raise ValueError(f"Unsupported model: {self.model}")

    # None instead of a shared mutable [] default
    dialog_messages = [] if dialog_messages is None else dialog_messages
    n_dialog_messages_before = len(dialog_messages)
    answer = None
    while answer is None:
        try:
            if image_buffer is not None:
                # a previous attempt may have consumed the buffer
                image_buffer.seek(0)
            messages = self._generate_prompt_messages(
                message, dialog_messages, chat_mode, image_buffer
            )
            r = await openai.ChatCompletion.acreate(
                model=self.model,
                messages=messages,
                **OPENAI_COMPLETION_OPTIONS
            )
            answer = r.choices[0].message.content

            answer = self._postprocess_answer(answer)
            n_input_tokens, n_output_tokens = (
                r.usage.prompt_tokens,
                r.usage.completion_tokens,
            )
        except openai.error.InvalidRequestError as e:  # too many tokens
            if len(dialog_messages) == 0:
                raise ValueError(
                    "Dialog messages is reduced to zero, but still has too many tokens to make completion"
                ) from e

            # forget first message in dialog_messages
            dialog_messages = dialog_messages[1:]

    n_first_dialog_messages_removed = n_dialog_messages_before - len(
        dialog_messages
    )

    return (
        answer,
        (n_input_tokens, n_output_tokens),
        n_first_dialog_messages_removed,
    )
|
||||||
|
|
||||||
|
async def send_vision_message_stream(
    self,
    message,
    dialog_messages=None,
    chat_mode="assistant",
    image_buffer: BytesIO = None,
):
    """Stream the answer to a text (+ optional image) message.

    Yields ``(status, answer, (n_input_tokens, n_output_tokens),
    n_first_dialog_messages_removed)`` tuples: ``"not_finished"`` with the
    answer accumulated so far for each content chunk, then one final
    ``"finished"`` tuple with the post-processed answer.

    On ``openai.error.InvalidRequestError`` (treated as "too many tokens")
    the oldest dialog message is dropped and the request is retried; with an
    empty history the error is re-raised.

    Raises:
        ValueError: If the model is not ``gpt-4-vision-preview``.
    """
    # Checked once up front: the model cannot change between retries.
    if self.model != "gpt-4-vision-preview":
        raise ValueError(f"Unsupported model: {self.model}")

    # None instead of a shared mutable [] default
    dialog_messages = [] if dialog_messages is None else dialog_messages
    n_dialog_messages_before = len(dialog_messages)

    # Defined up front so the final yield is valid even if the stream
    # delivers no content chunks.
    n_input_tokens, n_output_tokens = 0, 0
    n_first_dialog_messages_removed = 0

    answer = None
    while answer is None:
        try:
            if image_buffer is not None:
                # a previous attempt may have consumed the buffer
                image_buffer.seek(0)
            messages = self._generate_prompt_messages(
                message, dialog_messages, chat_mode, image_buffer
            )

            r_gen = await openai.ChatCompletion.acreate(
                model=self.model,
                messages=messages,
                stream=True,
                **OPENAI_COMPLETION_OPTIONS,
            )

            answer = ""
            async for r_item in r_gen:
                delta = r_item.choices[0].delta
                if "content" in delta:
                    answer += delta.content
                    (
                        n_input_tokens,
                        n_output_tokens,
                    ) = self._count_tokens_from_messages(
                        messages, answer, model=self.model
                    )
                    n_first_dialog_messages_removed = (
                        n_dialog_messages_before - len(dialog_messages)
                    )
                    yield "not_finished", answer, (
                        n_input_tokens,
                        n_output_tokens,
                    ), n_first_dialog_messages_removed

            answer = self._postprocess_answer(answer)

        except openai.error.InvalidRequestError:  # too many tokens
            if len(dialog_messages) == 0:
                raise
            # forget first message in dialog_messages
            dialog_messages = dialog_messages[1:]

    yield "finished", answer, (
        n_input_tokens,
        n_output_tokens,
    ), n_first_dialog_messages_removed
|
||||||
|
|
||||||
def _generate_prompt(self, message, dialog_messages, chat_mode):
|
def _generate_prompt(self, message, dialog_messages, chat_mode):
|
||||||
prompt = config.chat_modes[chat_mode]["prompt_start"]
|
prompt = config.chat_modes[chat_mode]["prompt_start"]
|
||||||
prompt += "\n\n"
|
prompt += "\n\n"
|
||||||
@@ -134,16 +247,32 @@ class ChatGPT:
|
|||||||
|
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
def _generate_prompt_messages(self, message, dialog_messages, chat_mode):
|
def _encode_image(self, image_buffer: BytesIO) -> bytes:
|
||||||
|
return base64.b64encode(image_buffer.read()).decode("utf-8")
|
||||||
|
|
||||||
|
def _generate_prompt_messages(self, message, dialog_messages, chat_mode, image_buffer: BytesIO = None):
    """Build the chat ``messages`` list for one completion request.

    The list starts with the system prompt of *chat_mode*, followed by the
    dialog history as alternating user/assistant messages in their original
    order, and ends with the current user message (text plus the optional
    image).

    Args:
        message: Text of the current user message.
        dialog_messages: Previous turns, each a dict with ``"user"`` and
            ``"bot"`` entries.
        chat_mode: Key into ``config.chat_modes``.
        image_buffer: In-memory image to attach to the current message, or
            ``None``.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts; user messages
        carry a list of typed content parts.
    """

    def as_content_parts(user_entry):
        # A history entry may be a plain string (the text handler appears
        # to store it that way) or an already-built list of content parts.
        # Extending a list with a bare string would split it into
        # characters, so wrap it instead.
        if isinstance(user_entry, str):
            return [{"type": "text", "text": user_entry}]
        return list(user_entry)

    prompt = config.chat_modes[chat_mode]["prompt_start"]

    messages = [{"role": "system", "content": prompt}]

    # Keep each user turn next to its reply so the dialog order is preserved.
    for dialog_message in dialog_messages:
        messages.append({"role": "user", "content": as_content_parts(dialog_message["user"])})
        messages.append({"role": "assistant", "content": dialog_message["bot"]})

    current_content = [{"type": "text", "text": message}]
    if image_buffer is not None:
        current_content.append(
            {
                "type": "image",
                "image": self._encode_image(image_buffer),
            }
        )
    messages.append({"role": "user", "content": current_content})

    return messages
|
||||||
|
|
||||||
def _postprocess_answer(self, answer):
|
def _postprocess_answer(self, answer):
|
||||||
answer = answer.strip()
|
answer = answer.strip()
|
||||||
@@ -164,6 +293,9 @@ class ChatGPT:
|
|||||||
elif model == "gpt-4-1106-preview":
|
elif model == "gpt-4-1106-preview":
|
||||||
tokens_per_message = 3
|
tokens_per_message = 3
|
||||||
tokens_per_name = 1
|
tokens_per_name = 1
|
||||||
|
elif model == "gpt-4-vision-preview":
|
||||||
|
tokens_per_message = 3
|
||||||
|
tokens_per_name = 1
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown model: {model}")
|
raise ValueError(f"Unknown model: {model}")
|
||||||
|
|
||||||
@@ -171,10 +303,20 @@ class ChatGPT:
|
|||||||
n_input_tokens = 0
|
n_input_tokens = 0
|
||||||
for message in messages:
|
for message in messages:
|
||||||
n_input_tokens += tokens_per_message
|
n_input_tokens += tokens_per_message
|
||||||
for key, value in message.items():
|
if isinstance(message["content"], list):
|
||||||
n_input_tokens += len(encoding.encode(value))
|
for sub_message in message["content"]:
|
||||||
if key == "name":
|
if "type" in sub_message:
|
||||||
n_input_tokens += tokens_per_name
|
if sub_message["type"] == "text":
|
||||||
|
n_input_tokens += len(encoding.encode(sub_message["text"]))
|
||||||
|
elif sub_message["type"] == "image_url":
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
if "type" in message:
|
||||||
|
if message["type"] == "text":
|
||||||
|
n_input_tokens += len(encoding.encode(message["text"]))
|
||||||
|
elif message["type"] == "image_url":
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
n_input_tokens += 2
|
n_input_tokens += 2
|
||||||
|
|
||||||
@@ -205,4 +347,4 @@ async def generate_images(prompt, n_images=4, size="512x512"):
|
|||||||
|
|
||||||
async def is_content_acceptable(prompt):
|
async def is_content_acceptable(prompt):
|
||||||
r = await openai.Moderation.acreate(input=prompt)
|
r = await openai.Moderation.acreate(input=prompt)
|
||||||
return not all(r.results[0].categories.values())
|
return not all(r.results[0].categories.values())
|
||||||
+14
-1
@@ -1,4 +1,4 @@
|
|||||||
available_text_models: ["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4-1106-preview", "gpt-4", "text-davinci-003"]
|
available_text_models: ["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-4-1106-preview", "gpt-4-vision-preview", "gpt-4", "text-davinci-003"]
|
||||||
|
|
||||||
info:
|
info:
|
||||||
gpt-3.5-turbo:
|
gpt-3.5-turbo:
|
||||||
@@ -53,6 +53,19 @@ info:
|
|||||||
fast: 4
|
fast: 4
|
||||||
cheap: 3
|
cheap: 3
|
||||||
|
|
||||||
|
gpt-4-vision-preview:
|
||||||
|
type: chat_completion
|
||||||
|
name: GPT-4 Vision
|
||||||
|
description: Ability to <b>understand images</b>, in addition to all other GPT-4 Turbo capabilities.
|
||||||
|
|
||||||
|
price_per_1000_input_tokens: 0.01
|
||||||
|
price_per_1000_output_tokens: 0.03
|
||||||
|
|
||||||
|
scores:
|
||||||
|
smart: 5
|
||||||
|
fast: 4
|
||||||
|
cheap: 3
|
||||||
|
|
||||||
text-davinci-003:
|
text-davinci-003:
|
||||||
type: completion
|
type: completion
|
||||||
name: GPT-3.5
|
name: GPT-3.5
|
||||||
|
|||||||
Reference in New Issue
Block a user