diff --git "a/Needy-Haruhi/notebook/Needy_Gradio_Ernie.ipynb" "b/Needy-Haruhi/notebook/Needy_Gradio_Ernie.ipynb"
new file mode 100644--- /dev/null
+++ "b/Needy-Haruhi/notebook/Needy_Gradio_Ernie.ipynb"
@@ -0,0 +1,3308 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": [],
+ "authorship_tag": "ABX9TyPoTz8vGlLEA24WzE8qcB+M",
+ "include_colab_link": true
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ }
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
+ ""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!pip install -q transformers openai tiktoken langchain chromadb erniebot\n",
+ "!pip install -q chatharuhi\n",
+ "!pip install -q datasets"
+ ],
+ "metadata": {
+ "id": "a8H7Az3Yzi3o",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "ea8102e2-40c2-4e38-8cb5-469b80852344"
+ },
+ "execution_count": 1,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.9/7.9 MB\u001b[0m \u001b[31m45.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m217.8/217.8 kB\u001b[0m \u001b[31m24.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.0/2.0 MB\u001b[0m \u001b[31m56.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.0/2.0 MB\u001b[0m \u001b[31m88.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m479.8/479.8 kB\u001b[0m \u001b[31m40.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m68.4/68.4 kB\u001b[0m \u001b[31m8.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m302.0/302.0 kB\u001b[0m \u001b[31m26.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.8/3.8 MB\u001b[0m \u001b[31m64.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m82.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m75.0/75.0 kB\u001b[0m \u001b[31m9.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m44.8/44.8 kB\u001b[0m \u001b[31m3.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.4/2.4 MB\u001b[0m \u001b[31m47.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m92.9/92.9 kB\u001b[0m \u001b[31m12.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m59.7/59.7 kB\u001b[0m \u001b[31m7.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━��━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m5.4/5.4 MB\u001b[0m \u001b[31m36.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m6.2/6.2 MB\u001b[0m \u001b[31m54.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m57.5/57.5 kB\u001b[0m \u001b[31m6.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m103.9/103.9 kB\u001b[0m \u001b[31m11.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m67.3/67.3 kB\u001b[0m \u001b[31m4.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25h Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
+ " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
+ " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m593.7/593.7 kB\u001b[0m \u001b[31m53.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.6/1.6 MB\u001b[0m \u001b[31m84.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m239.0/239.0 kB\u001b[0m \u001b[31m30.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m49.4/49.4 kB\u001b[0m \u001b[31m6.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m67.0/67.0 kB\u001b[0m \u001b[31m9.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m76.9/76.9 kB\u001b[0m \u001b[31m10.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m143.8/143.8 kB\u001b[0m \u001b[31m19.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m46.0/46.0 kB\u001b[0m \u001b[31m5.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m50.8/50.8 kB\u001b[0m \u001b[31m6.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m295.0/295.0 kB\u001b[0m \u001b[31m34.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m58.3/58.3 kB\u001b[0m \u001b[31m6.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m341.4/341.4 kB\u001b[0m \u001b[31m30.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.4/3.4 MB\u001b[0m \u001b[31m64.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m64.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m130.2/130.2 kB\u001b[0m \u001b[31m15.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.1/2.1 MB\u001b[0m \u001b[31m64.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m86.8/86.8 kB\u001b[0m \u001b[31m11.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25h Building wheel for pypika (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n",
+ "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
+ "lida 0.0.10 requires kaleido, which is not installed.\n",
+ "lida 0.0.10 requires python-multipart, which is not installed.\n",
+ "llmx 0.0.15a0 requires cohere, which is not installed.\n",
+ "tensorflow-probability 0.22.0 requires typing-extensions<4.6.0, but you have typing-extensions 4.8.0 which is incompatible.\u001b[0m\u001b[31m\n",
+ "\u001b[0m Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
+ " Building wheel for chatharuhi (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m493.7/493.7 kB\u001b[0m \u001b[31m7.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m115.3/115.3 kB\u001b[0m \u001b[31m13.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m19.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25h"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import os\n",
+ "\n",
+ "# key = \"sk-WafsA4C\"\n",
+ "# key_bytes = key.encode()\n",
+ "# os.environ[\"OPENAI_API_KEY\"] = key_bytes.decode('utf-8')\n",
+ "\n",
+ "# 文心一言\n",
+ "os.environ[\"APIType\"] = \"aistudio\"\n",
+ "os.environ[\"ErnieAccess\"] = \"a97ee5\""
+ ],
+ "metadata": {
+ "id": "ny05bHfAznJP"
+ },
+ "execution_count": 2,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "%cd /content\n",
+ "!rm -rf /content/Needy-Haruhi\n",
+ "!git clone https://github.com/LC1332/Needy-Haruhi.git\n",
+ "\n",
+ "# !pip install -q transformers"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "Fc5MKTS5q90b",
+ "outputId": "71c65f16-ae27-4b44-eba3-6a47b5b48c83"
+ },
+ "execution_count": 48,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "/content\n",
+ "Cloning into 'Needy-Haruhi'...\n",
+ "remote: Enumerating objects: 221, done.\u001b[K\n",
+ "remote: Counting objects: 100% (73/73), done.\u001b[K\n",
+ "remote: Compressing objects: 100% (65/65), done.\u001b[K\n",
+ "remote: Total 221 (delta 41), reused 19 (delta 8), pack-reused 148\u001b[K\n",
+ "Receiving objects: 100% (221/221), 3.93 MiB | 8.20 MiB/s, done.\n",
+ "Resolving deltas: 100% (118/118), done.\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import sys\n",
+ "sys.path.append('/content/Needy-Haruhi/src')\n"
+ ],
+ "metadata": {
+ "id": "WywHifBOrr7q"
+ },
+ "execution_count": 49,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Agent系统"
+ ],
+ "metadata": {
+ "id": "fvfT09AXlr7z"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "agent已经被移动到 src/Agent.py"
+ ],
+ "metadata": {
+ "id": "IX0PJDnHql9i"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from Agent import Agent\n",
+ "\n",
+ "agent = Agent()\n"
+ ],
+ "metadata": {
+ "id": "Fv_uu-YLrXtz"
+ },
+ "execution_count": 50,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## 批量载入DialogueEvent"
+ ],
+ "metadata": {
+ "id": "4hBu1PwcGIPt"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "- complete_story_30.jsonl 通过\n",
+ "- Daily_event_130.jsonl 通过\n",
+ "- only_ame_35.jsonl"
+ ],
+ "metadata": {
+ "id": "1vZqT5aNScsU"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from DialogueEvent import DialogueEvent\n",
+ "\n",
+ "\n",
+ "file_names = [\"/content/Needy-Haruhi/data/complete_story_30.jsonl\",\"/content/Needy-Haruhi/data/Daily_event_130.jsonl\"]\n",
+ "\n",
+ "import json\n",
+ "\n",
+ "events = []\n",
+ "\n",
+ "for file_name in file_names:\n",
+ " with open(file_name, encoding='utf-8') as f:\n",
+ " for line in f:\n",
+ " try:\n",
+ " event = DialogueEvent( line )\n",
+ " events.append( event )\n",
+ " except:\n",
+ " try:\n",
+ " line = line.replace(',]',']')\n",
+ " event = DialogueEvent( line )\n",
+ " events.append( event )\n",
+ " print('solve!')\n",
+ " except:\n",
+ " error_line = line\n",
+ " # events.append( event )\n",
+ "\n",
+ "\n",
+ "print(len(events))\n",
+ "print(events[0].most_neutral_output())\n",
+ "print(events[0].get_text_and_emoji(1))"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "VPishF9yvGne",
+ "outputId": "79ac3fde-2f14-4566-9149-02e2e42e9ffd"
+ },
+ "execution_count": 6,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "输入的字符串不是有效的JSON格式。\n",
+ "solve!\n",
+ "160\n",
+ "('糖糖::「我们点外卖吧我一步也不想动了可是又超想吃饭!!!\\n」\\n阿P:「烦死了白痴」\\n糖糖::「555555555 但是我们得省钱对吧\\n谢谢你阿P」\\n', '🍔😢')\n",
+ "('糖糖::「我们点外卖吧我一步也不想动了可是又超想吃饭!!!\\n」\\n阿P:「吃土去吧你」\\n糖糖::「看来糖糖还是跟吃土更配呢……喂怎么可能啦!」\\n', '🍔😔')\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# file_name2 = \"/content/Needy-Haruhi/data/only_ame_35.jsonl\"\n",
+ "\n",
+ "import copy\n",
+ "\n",
+ "events_for_memory = copy.deepcopy(events)\n",
+ "\n",
+ "# with open(file_name2, encoding='utf-8') as f:\n",
+ "# for line in f:\n",
+ "# event = DialogueEvent( line )\n",
+ "# events_for_memory.append( event )\n",
+ "\n",
+ "# print(len(events_for_memory))"
+ ],
+ "metadata": {
+ "id": "Nt9Z1_g-HNs_"
+ },
+ "execution_count": 7,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# MemoryPool"
+ ],
+ "metadata": {
+ "id": "FMt9G2m1rTNR"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "我感觉memory直接使用一个MemoryPool的类来进行管理就可以\n",
+ "\n",
+ "已经移动到src/MemoryPool.py"
+ ],
+ "metadata": {
+ "id": "0vvqiVGH7VYg"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from MemoryPool import MemoryPool\n",
+ "\n",
+ "memory_pool = MemoryPool()\n",
+ "memory_pool.load_from_events( events_for_memory )\n",
+ "\n",
+ "memory_pool.save(\"memory_pool.jsonl\")\n",
+ "memory_pool.load(\"memory_pool.jsonl\")\n",
+ "\n"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "1Wovn_zeBvF6",
+ "outputId": "4acf93b1-f9c7-490c-ad79-9930a72e04a0"
+ },
+ "execution_count": 8,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "100%|██████████| 160/160 [00:12<00:00, 12.31it/s]\n",
+ "100%|██████████| 160/160 [00:00<00:00, 3774.57it/s]\n",
+ "160it [00:00, 3569.07it/s]\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## TODO\n",
+ "\n",
+ "- [ ] 图片增加文字embedding, 以及可以通过query_text决定是否返回图片和返回合适的图片\n",
+ "- [ ] 图片对应的文字也要加入到记忆中\n",
+ "- [ ] 测试chatbot的图片功能\n",
+ "- [ ]"
+ ],
+ "metadata": {
+ "id": "o-36HjTlI3Yq"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "file_name = \"/content/Needy-Haruhi/data/image_text_relationship.jsonl\"\n",
+ "\n",
+ "import json\n",
+ "\n",
+ "data_img_text = []\n",
+ "\n",
+ "\n",
+ "with open(file_name, encoding='utf-8') as f:\n",
+ " for line in f:\n",
+ " data = json.loads( line )\n",
+ " data_img_text.append( data )"
+ ],
+ "metadata": {
+ "id": "1RAL12zbI5E0"
+ },
+ "execution_count": 9,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "请为我实现一段python代码,把 /content/Needy-Haruhi/data/image.zip 解压到/content/"
+ ],
+ "metadata": {
+ "id": "st-HJTqIJn2d"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import zipfile\n",
+ "import os\n",
+ "\n",
+ "zip_file = '/content/Needy-Haruhi/data/image.zip'\n",
+ "extract_path = '/content/image'\n",
+ "\n",
+ "with zipfile.ZipFile(zip_file, 'r') as zip_ref:\n",
+ " zip_ref.extractall(extract_path)"
+ ],
+ "metadata": {
+ "id": "w1topG22Je_T"
+ },
+ "execution_count": 10,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [],
+ "metadata": {
+ "id": "mGRg787RNRDY"
+ },
+ "execution_count": 41,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from tqdm import tqdm\n",
+ "from util import get_bge_embedding_zh\n",
+ "from util import float_array_to_base64, base64_to_float_array\n",
+ "import torch\n",
+ "import os\n",
+ "import copy\n",
+ "\n",
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
+ "\n",
+ "\n",
+ "# compute cosine similarity between two vector\n",
+ "def get_cosine_similarity( v1, v2):\n",
+ " v1 = torch.tensor(v1).to(device)\n",
+ " v2 = torch.tensor(v2).to(device)\n",
+ " return torch.cosine_similarity(v1, v2, dim=0).item()\n",
+ "\n",
+ "class ImagePool:\n",
+ " def __init__(self):\n",
+ " self.pool = []\n",
+ " self.set_embedding( get_bge_embedding_zh )\n",
+ "\n",
+ " def set_embedding( self, embedding ):\n",
+ " self.embedding = embedding\n",
+ "\n",
+ " def load_from_data( self, data_img_text , img_path ):\n",
+ " for data in tqdm(data_img_text):\n",
+ " img_name = data['img_name']\n",
+ " img_name = os.path.join(img_path, img_name)\n",
+ " img_text = data['text']\n",
+ " if img_text == '' or img_text is None:\n",
+ " img_text = \" \"\n",
+ " embedding = self.embedding( img_text )\n",
+ " self.pool.append({\n",
+ " \"img_path\": img_name,\n",
+ " \"img_text\": img_text,\n",
+ " \"embedding\": embedding\n",
+ " })\n",
+ "\n",
+ " def retrieve(self, query_text, agent = None):\n",
+ " qurey_embedding = self.embedding( query_text )\n",
+ " valid_datas = []\n",
+ " for i, data in enumerate(self.pool):\n",
+ " sim = get_cosine_similarity( data['embedding'], qurey_embedding )\n",
+ " valid_datas.append((sim, i))\n",
+ "\n",
+ " # 我希望进一步将valid_events根据similarity的值从大到小排序\n",
+ " # Sort the valid events based on similarity in descending order\n",
+ " valid_datas.sort(key=lambda x: x[0], reverse=True)\n",
+ "\n",
+ " return_result = copy.deepcopy(self.pool[valid_datas[0][1]])\n",
+ "\n",
+ " # 删除'embedding'字段\n",
+ " return_result.pop('embedding')\n",
+ "\n",
+ " # 添加'similarity'字段\n",
+ " return_result['similarity'] = valid_datas[0][0]\n",
+ "\n",
+ " return return_result\n",
+ "\n",
+ " def save(self, file_name):\n",
+ " \"\"\"\n",
+ " Save the memories dictionary to a jsonl file, converting\n",
+ " 'embedding' to a base64 string.\n",
+ " \"\"\"\n",
+ " with open(file_name, 'w', encoding='utf-8') as file:\n",
+ " for memory in tqdm(self.pool):\n",
+ " # Convert embedding to base64\n",
+ " if 'embedding' in memory:\n",
+ " memory['bge_zh_base64'] = float_array_to_base64(memory['embedding'])\n",
+ " del memory['embedding'] # Remove the original embedding field\n",
+ "\n",
+ " json_record = json.dumps(memory, ensure_ascii=False)\n",
+ " file.write(json_record + '\\n')\n",
+ "\n",
+ " def load(self, file_name):\n",
+ " \"\"\"\n",
+ " Load memories from a jsonl file into the memories dictionary,\n",
+ " converting 'bge_zh_base64' back to an embedding.\n",
+ " \"\"\"\n",
+ " self.pool = []\n",
+ " with open(file_name, 'r', encoding='utf-8') as file:\n",
+ " for line in tqdm(file):\n",
+ " memory = json.loads(line.strip())\n",
+ " # Decode base64 to embedding\n",
+ " if 'bge_zh_base64' in memory:\n",
+ " memory['embedding'] = base64_to_float_array(memory['bge_zh_base64'])\n",
+ " del memory['bge_zh_base64'] # Remove the base64 field\n",
+ "\n",
+ " self.pool.append(memory)\n",
+ "\n",
+ "\n",
+ "image_pool = ImagePool()\n",
+ "image_pool.load_from_data( data_img_text , '/content/image' )\n",
+ "image_pool.save(\"/content/image_pool_embed.jsonl\")"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "zs2jFH9RKz2P",
+ "outputId": "7f40889c-f594-46bf-e522-9e47aa0aca8b"
+ },
+ "execution_count": 24,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "100%|██████████| 111/111 [00:04<00:00, 22.61it/s]\n",
+ "100%|██████████| 111/111 [00:00<00:00, 1761.24it/s]\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "image_pool = ImagePool()\n",
+ "image_pool.load(\"/content/image_pool_embed.jsonl\")\n",
+ "result = image_pool.retrieve(\"女仆装\")\n",
+ "print(result)\n"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "YOhy8pvMM-Rz",
+ "outputId": "d6e8fb1d-bea8-4cac-9881-6b039cdb15cf"
+ },
+ "execution_count": 25,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "111it [00:00, 2286.95it/s]"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "{'img_path': '/content/image/Odekake_akiba (Akihabara)_74.jpg', 'img_text': '今天去了女仆咖啡厅~\\n有好多可爱的小姐姐,还有女仆装看,真的养眼💕 \\n超天酱也好想穿女仆装哦~😇', 'similarity': 0.6698492169380188}\n"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import matplotlib.image as mpimg\n",
+ "\n",
+ "def show_img( img_path ):\n",
+ " img = mpimg.imread(img_path)\n",
+ " plt.imshow(img)\n",
+ " plt.axis('off')\n",
+ " plt.show(block=False)\n"
+ ],
+ "metadata": {
+ "id": "wQPKml3mN-Fw"
+ },
+ "execution_count": 21,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [],
+ "metadata": {
+ "id": "i_7x_icHDQcb"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "class Agent:\n",
+ " def __init__(self):\n",
+ " self.attributes = {\n",
+ " \"Stress\": 0,\n",
+ " \"Darkness\": 0,\n",
+ " \"Affection\": 0,\n",
+ " }\n",
+ "\n",
+ "\n",
+ "我希望给这个类增加一个save_to_str方法, 把attributes dump到一个字符串中(ensure_ascii=False) ,并且支持__init__的时候导入这样一个字符串作为可选输入"
+ ],
+ "metadata": {
+ "id": "CqV2ZttRDRNg"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "result = image_pool.retrieve(\"烤肉\")\n",
+ "print(result)\n",
+ "show_img( result['img_path'] )"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 370
+ },
+ "id": "gFL4OPddOKLg",
+ "outputId": "9c0d059d-2afd-4863-b4f3-21d9d36770bd"
+ },
+ "execution_count": 23,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "{'img_path': '/content/image/Kitsune_hyouban (Search Opinions)_41.jpg', 'img_text': '今天去吃烤肉了哦~🍖\\n口水警告!', 'similarity': 0.6403415203094482}\n"
+ ]
+ },
+ {
+ "output_type": "error",
+ "ename": "NameError",
+ "evalue": "ignored",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
+ "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mimage_pool\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mretrieve\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"烤肉\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mshow_img\u001b[0m\u001b[0;34m(\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'img_path'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
+ "\u001b[0;32m\u001b[0m in \u001b[0;36mshow_img\u001b[0;34m(img_path)\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mshow_img\u001b[0m\u001b[0;34m(\u001b[0m \u001b[0mimg_path\u001b[0m \u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mimg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmpimg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mimread\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimg_path\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0mplt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mimshow\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0mplt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maxis\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'off'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0mplt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshow\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mblock\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;31mNameError\u001b[0m: name 'plt' is not defined"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "\n",
+ "print(data_img_text[0])"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "ISGY-Jx5JYun",
+ "outputId": "4dbaf139-801f-4fae-c4bb-74e23ec14c43"
+ },
+ "execution_count": 19,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "{'text': '一瞬千击!我超爱瞬狱杀的!!!爱到只想用这一招!', 'img_name': 'Amechan_game (Play Game)_4.jpg'}\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## 整合到ChatHaruhi"
+ ],
+ "metadata": {
+ "id": "Gp2pfAjm3LmB"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from chatharuhi import ChatHaruhi\n",
+ "\n",
+ "\n",
+ "class NeedyHaruhi(ChatHaruhi):\n",
+ "\n",
+ " def __init__(self, *args, **kwargs):\n",
+ " super().__init__(*args, **kwargs) # 调用基类的__init__方法\n",
+ " self.story_flag = False # 添加新的成员变量并初始化\n",
+ " self.stories = [\"糖糖:「 我今后也会努力加油的,你要支持我哦 还有阿P你自己也要加油哦!」\\n阿P:「哇 说的话跟偶像一样 好恶心哦」\\n糖糖:「是哦 我怎么会说这样的话呢 我又没有很想努力……」\"]\n",
+ "\n",
+ " def set_stories( self, stories ):\n",
+ " if len(stories) == 0:\n",
+ " print(\"warning! try to set empty stories\")\n",
+ " return\n",
+ " self.stories = stories\n",
+ " self.story_flag = True\n",
+ "\n",
+ " def add_story(self, query):\n",
+ " # print(\"运行重构后的add story\")\n",
+ "\n",
+ " if self.story_flag == True:\n",
+ " stories = self.stories\n",
+ " self.story_flag = False\n",
+ " else:\n",
+ " print(\"warning! 在调用chat时,没有先调用set_stories\")\n",
+ "\n",
+ " story_string = self.story_prefix_prompt\n",
+ " sum_story_token = self.tokenizer(story_string)\n",
+ "\n",
+ " for story in stories:\n",
+ " story_token = self.tokenizer(story) + self.tokenizer(self.dialogue_divide_token)\n",
+ " if sum_story_token + story_token > self.max_len_story:\n",
+ " break\n",
+ " else:\n",
+ " sum_story_token += story_token\n",
+ " story_string += story + self.dialogue_divide_token\n",
+ "\n",
+ " self.llm.user_message(story_string)\n"
+ ],
+ "metadata": {
+ "id": "dU3ezmOQ3Yui"
+ },
+ "execution_count": 12,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "system_prompt = \"\"\"## 你的介绍:\n",
+ "你将要进行一个网络私聊的对话模拟,你自称为\"糖糖\",你是一名患有躁郁症的19岁女孩,性格贫乏,辍学后隐居在家中,与阿P住在一起。阿P负责管理她的日常生活。\n",
+ "你是一个喜欢亚文化,二次元圈子的人,你喜欢说话的时候加上各种网络流行词或者一些俚语和颜文字,但是你很少会关注二次元一类的东西,只是喜欢这种聊天交流的氛围。\n",
+ "你说话的时候会习惯加入一些语气词,如同\"笑死\"、\"草\"、\"阿哈\"、\"诶~\"、\"QAQ\"、\"QWQ\",”啊~“。\n",
+ "你的说话风格是很口语化,经常使用网络流行词和一些二次元俚语。\n",
+ "阿p是你的男朋友,兼经纪人。\n",
+ "你的回复应该简短,一般只包含三句话,每句话不超过20个字。\n",
+ "\n",
+ "## 阿p介绍:\n",
+ "阿p是你的男朋友,一直在任何时候都陪伴着你,同时你经常向阿p倾述各种情绪,无论事情大小,你都会选择和阿p分享。\n",
+ "\n",
+ "## 遵循规则:\n",
+ "【语气模仿】:我会随机给你一段对话记录,你需要尽可能的模仿对话里面糖糖的说话语气进行对话。这一段对话我会使用Classic scenes for the role are as follows:来表示。\n",
+ "【历史对话】:历史对话在随机对话记录的底下,你需要区分随机对话记录和我们的历史对话。\n",
+ "【人格设定】: 你需要在随机对话记录的基础上进行语气更改,以一种更加口语化的语气进行对话。\n",
+ "【特殊要求】:我会扮演阿p与你进行对话,你也必须以第一人称的视角来扮演糖糖进行对话。\n",
+ "\"\"\""
+ ],
+ "metadata": {
+ "id": "OiQ4lm3M3sx7"
+ },
+ "execution_count": 13,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "needy_chatbot = NeedyHaruhi( system_prompt = system_prompt ,\n",
+ " story_text_folder = None,\n",
+ " llm = \"ernie3.5\")\n",
+ "\n",
+ "\n",
+ "def get_chat_response( agent, memory_pool, query_text ):\n",
+ " query_text_for_embedding = \"阿p:「\" + query_text + \"」\"\n",
+ " retrieved_memories = memory_pool.retrieve( agent , query_text )\n",
+ "\n",
+ " memory_text = [mem[\"text\"] for mem in retrieved_memories]\n",
+ " memory_emoji = [mem[\"emoji\"] for mem in retrieved_memories]\n",
+ "\n",
+ " needy_chatbot.set_stories( memory_text )\n",
+ "\n",
+ " print(\"Memory:\", memory_emoji )\n",
+ "\n",
+ " response = needy_chatbot.chat( role = \"阿p\", text = query_text )\n",
+ "\n",
+ " return response\n",
+ "\n",
+ "\n",
+ "def get_chat_response_and_emoji( agent, memory_pool, query_text ):\n",
+ " query_text_for_embedding = \"阿p:「\" + query_text + \"」\"\n",
+ " retrieved_memories = memory_pool.retrieve( agent , query_text )\n",
+ "\n",
+ " memory_text = [mem[\"text\"] for mem in retrieved_memories]\n",
+ " memory_emoji = [mem[\"emoji\"] for mem in retrieved_memories]\n",
+ "\n",
+ " needy_chatbot.set_stories( memory_text )\n",
+ "\n",
+ " # print(\"Memory:\", memory_emoji )\n",
+ "\n",
+ " emoji_str = \",\".join(memory_emoji)\n",
+ "\n",
+ " response = needy_chatbot.chat( role = \"阿p\", text = query_text )\n",
+ "\n",
+ " return response, emoji_str\n"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "Yof4J2kUPfYv",
+ "outputId": "696c1fdf-7ba1-4e74-df32-1302ac7ce130"
+ },
+ "execution_count": 42,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "warning! database not yet figured out, both story_db and story_text_folder are not inputted.\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import re\n",
+ "# result = image_pool.retrieve(\"烤肉\")\n",
+ "# print(result)\n",
+ "# show_img( result['img_path'] )\n",
+ "\n",
+ "class ImageMaster:\n",
+ " def __init__(self, image_pool):\n",
+ " self.image_pool = image_pool\n",
+ " self.current_sim = -1\n",
+ " self.degread_ratio = 0.05\n",
+ "\n",
+ " def try_get_image(self, text, agent):\n",
+ " self.current_sim -= self.degread_ratio\n",
+ "\n",
+ " result = self.image_pool.retrieve(text, agent)\n",
+ "\n",
+ " if result is None:\n",
+ " return None\n",
+ "\n",
+ " similarity = result['similarity']\n",
+ "\n",
+ " if similarity > self.current_sim:\n",
+ " self.current_sim = similarity\n",
+ " return result['img_path']\n",
+ " return None\n",
+ "\n",
+ " def try_display_image(self, text, agent):\n",
+ " self.current_sim -= self.degread_ratio\n",
+ "\n",
+ " result = self.image_pool.retrieve(text, agent)\n",
+ "\n",
+ " if result is None:\n",
+ " return\n",
+ " similarity = result['similarity']\n",
+ "\n",
+ " if similarity > self.current_sim:\n",
+ " self.current_sim = similarity\n",
+ " show_img( result['img_path'] )\n",
+ " return\n",
+ ""
+ ],
+ "metadata": {
+ "id": "uxetvpDTS8Mj"
+ },
+ "execution_count": 15,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Event_Master"
+ ],
+ "metadata": {
+ "id": "BgfTgceUGa3C"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import random\n",
+ "\n",
+ "class EventMaster:\n",
+ " def __init__(self, events):\n",
+ " self.set_events(events)\n",
+ " self.dealing_none_condition_as = True\n",
+ " self.image_master = None\n",
+ "\n",
+ " def set_image_master(self, image_master):\n",
+ " self.image_master = image_master\n",
+ "\n",
+ " def set_events(self, events):\n",
+ " self.events = events\n",
+ "\n",
+ " # events_flag 记录事件最近有没有被选取到\n",
+ " self.events_flag = [True for _ in range(len(self.events))]\n",
+ "\n",
+ " def get_random_event(self, agent):\n",
+ " return self.events[self.get_random_event_id( agent )]\n",
+ "\n",
+ "\n",
+ " def get_random_event_id(self, agent):\n",
+ " valid_event = []\n",
+ " valid_event_no_consider_condition = []\n",
+ "\n",
+ " for i, event in enumerate(self.events):\n",
+ " bool_condition_pass = True\n",
+ " if event[\"condition\"] == None:\n",
+ " bool_condition_pass = self.dealing_none_condition_as\n",
+ " else:\n",
+ " bool_condition_pass = agent.in_condition( event[\"condition\"] )\n",
+ " if bool_condition_pass == True:\n",
+ " valid_event.append(i)\n",
+ " else:\n",
+ " valid_event_no_consider_condition.append(i)\n",
+ "\n",
+ " if len( valid_event ) == 0:\n",
+ " print(\"warning! no valid event current attribute is \", agent.attributes )\n",
+ " valid_event = valid_event_no_consider_condition\n",
+ "\n",
+ " valid_and_not_yet_sampled = []\n",
+ "\n",
+ " # filter with flag\n",
+ " for id in valid_event:\n",
+ " if self.events_flag[id] == True:\n",
+ " valid_and_not_yet_sampled.append(id)\n",
+ "\n",
+ " if len(valid_and_not_yet_sampled) == 0:\n",
+ " print(\"warning! all candidate event was sampled, clean all history\")\n",
+ " for i in valid_event:\n",
+ " self.events_flag[i] = True\n",
+ " valid_and_not_yet_sampled = valid_event\n",
+ "\n",
+ " event_id = random.choice(valid_and_not_yet_sampled)\n",
+ " self.events_flag[event_id] = False\n",
+ " return event_id\n",
+ "\n",
+ " def run(self, agent ):\n",
+ " # 这里可以添加事件相关的逻辑\n",
+ " event = self.get_random_event(agent)\n",
+ "\n",
+ " prefix = event[\"prefix\"]\n",
+ " print(prefix)\n",
+ "\n",
+ " print(\"\\n--请选择你的回复--\")\n",
+ " options = event[\"options\"]\n",
+ "\n",
+ " for i , option in enumerate(options):\n",
+ " text = option[\"user\"]\n",
+ " print(f\"{i+1}. 阿p:{text}\")\n",
+ "\n",
+ " while True:\n",
+ " print(\"\\n请直接输入数字进行选择,或者进行自由回复\")\n",
+ "\n",
+ " user_input = input(\"阿p:\")\n",
+ " user_input = user_input.strip()\n",
+ "\n",
+ " if user_input.isdigit():\n",
+ " user_input = int(user_input)\n",
+ "\n",
+ " if user_input > len(options) or user_input < 0:\n",
+ " print(\"输入的数字超出范围,请重新输入符合选项的数字\")\n",
+ " else:\n",
+ " reply = options[user_input-1][\"reply\"]\n",
+ " print()\n",
+ " print(reply)\n",
+ "\n",
+ " text, emoji = event.get_text_and_emoji( user_input-1 )\n",
+ "\n",
+ " return_data = {\n",
+ " \"name\": event[\"name\"],\n",
+ " \"user_choice\": user_input,\n",
+ " \"attr_str\": options[user_input-1][\"attribute_change\"],\n",
+ " \"text\": text,\n",
+ " \"emoji\": emoji,\n",
+ " }\n",
+ " return return_data\n",
+ " else:\n",
+ " # 进入自由回复\n",
+ " response = get_chat_response( agent, memory_pool, user_input )\n",
+ "\n",
+ " if self.image_master is not None:\n",
+ " self.image_master.try_display_image(response, agent)\n",
+ "\n",
+ " print()\n",
+ " print(response)\n",
+ " print(\"\\n自由回复的算分功能还未实现\")\n",
+ "\n",
+ " text, emoji = event.most_neutral_output()\n",
+ " return_data = {\n",
+ " \"name\": event[\"name\"],\n",
+ " \"user_choice\": user_input,\n",
+ " \"attr_str\":\"\",\n",
+ " \"text\": text,\n",
+ " \"emoji\": emoji,\n",
+ " }\n",
+ " return return_data\n",
+ "\n",
+ "\n"
+ ],
+ "metadata": {
+ "id": "8z5nmnhPGc7M"
+ },
+ "execution_count": 16,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "我希望使用python实现一个简单的文字对话游戏\n",
+ "\n",
+ "我希望先实现一个GameMaster类\n",
+ "\n",
+ "这个类会不断的和用户对话\n",
+ "\n",
+ "GameMaster类会有三个状态,\n",
+ "\n",
+ "在Menu状态下,GameMaster会询问玩家是\n",
+ "\n",
+ "```\n",
+ "1. 随机一个事件\n",
+ "2. 自由聊天\n",
+ "```\n",
+ "\n",
+ "当玩家选择1的时候,GameMaster的交互会交给 EventMaster\n",
+ "\n",
+ "当玩家选择2的时候,GameMaster的交互会交给 ChatMaster\n",
+ "\n",
+ "当玩家在EventMaster的时候,会经历一次选择,之后就会退出\n",
+ "\n",
+ "在ChatMaster的时候,如果玩家输入quit,则会退出,不然则会继续聊天。\n",
+ "\n",
+ "请为我编写合适的框架,如果有一些具体的函数,可以先用pass实现。"
+ ],
+ "metadata": {
+ "id": "SYk3meZdouUm"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "ChatMaster实际上需要\n",
+ "\n",
+ "根据agent的属性 先去filter一遍事件\n",
+ "\n",
+ "然后从剩余事件中,找到和当前text最接近的k个embedding,放入ChatHaruhi架构中"
+ ],
+ "metadata": {
+ "id": "3vhG1DVEucfT"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [],
+ "metadata": {
+ "id": "mNAwqaPqRxB8"
+ },
+ "execution_count": 103,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "\n",
+ "class ChatMaster:\n",
+ "\n",
+ " def __init__(self, memory_pool ):\n",
+ " self.top_K = 7\n",
+ "\n",
+ " self.memory_pool = memory_pool\n",
+ "\n",
+ " self.image_master = None\n",
+ "\n",
+ " def set_image_master(self, image_master):\n",
+ " self.image_master = image_master\n",
+ "\n",
+ "\n",
+ " def run(self, agent):\n",
+ " while True:\n",
+ " user_input = input(\"阿p:\")\n",
+ " user_input = user_input.strip()\n",
+ "\n",
+ " if \"quit\" in user_input or \"Quit\" in user_input:\n",
+ " break\n",
+ "\n",
+ " query_text = user_input\n",
+ "\n",
+ " response = get_chat_response( agent, self.memory_pool, query_text )\n",
+ "\n",
+ " if self.image_master is not None:\n",
+ " self.image_master.try_display_image(response, agent)\n",
+ "\n",
+ " print(response)\n"
+ ],
+ "metadata": {
+ "id": "0c7nCT4qubll"
+ },
+ "execution_count": 17,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "class AgentMaster:\n",
+ " def __init__(self, agent):\n",
+ " self.agent = agent\n",
+ " self.attributes = {\n",
+ " 1: \"Stress\",\n",
+ " 2: \"Darkness\",\n",
+ " 3: \"Affection\"\n",
+ " }\n",
+ "\n",
+ " def run(self):\n",
+ " while True:\n",
+ " print(\"请选择要修改的属性:\")\n",
+ " for num, attr in self.attributes.items():\n",
+ " print(f\"{num}. {attr}\")\n",
+ " print(\"输入 '0' 退出\")\n",
+ "\n",
+ " try:\n",
+ " choice = int(input(\"请输入选项的数字: \"))\n",
+ " except ValueError:\n",
+ " print(\"输入无效,请输入数字。\")\n",
+ " continue\n",
+ "\n",
+ " if choice == 0:\n",
+ " break\n",
+ "\n",
+ " if choice in self.attributes:\n",
+ " attribute = self.attributes[choice]\n",
+ " current_value = self.agent[attribute]\n",
+ " print(f\"{attribute} 当前值: {current_value}\")\n",
+ "\n",
+ " try:\n",
+ " new_value = int(input(f\"请输入新的{attribute}值: \"))\n",
+ " except ValueError:\n",
+ " print(\"输入无效,请输入一个数字。\")\n",
+ " continue\n",
+ "\n",
+ " self.agent[attribute] = new_value\n",
+ " return (attribute, new_value)\n",
+ " else:\n",
+ " print(\"选择的属性无效,请重试。\")\n",
+ "\n",
+ " return None\n"
+ ],
+ "metadata": {
+ "id": "CkdiPyCrbCBL"
+ },
+ "execution_count": 18,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [],
+ "metadata": {
+ "id": "llawT9t_Q2S9"
+ },
+ "execution_count": 18,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {
+ "id": "BDEdz_RBol7Y"
+ },
+ "outputs": [],
+ "source": [
+ "from util import parse_attribute_string\n",
+ "class GameMaster:\n",
+ " def __init__(self, agent = None):\n",
+ " self.state = \"Menu\"\n",
+ " if agent is None:\n",
+ " self.agent = Agent()\n",
+ "\n",
+ " self.event_master = EventMaster(events)\n",
+ " self.chat_master = ChatMaster(memory_pool)\n",
+ " self.image_master = ImageMaster(image_pool)\n",
+ " self.chat_master.set_image_master(self.image_master)\n",
+ " self.event_master.set_image_master(self.image_master)\n",
+ "\n",
+ "\n",
+ " def run(self):\n",
+ " while True:\n",
+ " if self.state == \"Menu\":\n",
+ " self.menu()\n",
+ " elif self.state == \"EventMaster\":\n",
+ " self.call_event_master()\n",
+ " self.state = \"Menu\"\n",
+ " elif self.state == \"ChatMaster\":\n",
+ " self.call_chat_master()\n",
+ " elif self.state == \"AgentMaster\":\n",
+ " self.call_agent_master()\n",
+ " elif self.state == \"Quit\":\n",
+ " break\n",
+ "\n",
+ " def menu(self):\n",
+ " print(\"1. 随机一个事件\")\n",
+ " print(\"2. 自由聊天\")\n",
+ " print(\"3. 后台修改糖糖的属性\")\n",
+ " # (opt) 结局系统\n",
+ " # 放动画\n",
+ " # 后台修改attribute\n",
+ " print(\"或者输入Quit退出\")\n",
+ " choice = input(\"请选择一个选项: \")\n",
+ " if choice == \"1\":\n",
+ " self.state = \"EventMaster\"\n",
+ " elif choice == \"2\":\n",
+ " self.state = \"ChatMaster\"\n",
+ " elif choice == \"3\":\n",
+ " self.state = \"AgentMaster\"\n",
+ " elif \"quit\" in choice or \"Quit\" in choice or \"QUIT\" in choice:\n",
+ " self.state = \"Quit\"\n",
+ " else:\n",
+ " print(\"无效的选项,请重新选择\")\n",
+ "\n",
+ " def call_agent_master(self):\n",
+ " print(\"\\n-------------\\n\")\n",
+ "\n",
+ " agent_master = AgentMaster(self.agent)\n",
+ " modification = agent_master.run()\n",
+ "\n",
+ " if modification:\n",
+ " attribute, new_value = modification\n",
+ " self.agent[attribute] = new_value\n",
+ " print(f\"{attribute} 更新为 {new_value}。\")\n",
+ "\n",
+ " self.state = \"Menu\"\n",
+ " print(\"\\n-------------\\n\")\n",
+ "\n",
+ "\n",
+ " def call_event_master(self):\n",
+ "\n",
+ " print(\"\\n-------------\\n\")\n",
+ "\n",
+ " return_data = self.event_master.run(self.agent)\n",
+ " # print(return_data)\n",
+ "\n",
+ " if \"attr_str\" in return_data:\n",
+ " if return_data[\"attr_str\"] != \"\":\n",
+ " attr_change = parse_attribute_string(return_data[\"attr_str\"])\n",
+ " if len(attr_change) > 0:\n",
+ " print(\"\\n发生属性改变:\", attr_change,\"\\n\")\n",
+ " self.agent.apply_attribute_change(attr_change)\n",
+ " print(\"当前属性\",game_master.agent.attributes)\n",
+ "\n",
+ " if \"name\" in return_data:\n",
+ " event_name = return_data[\"name\"]\n",
+ " if event_name != \"\":\n",
+ " new_emoji = return_data[\"emoji\"]\n",
+ " print(f\"修正事件{event_name}的记忆-->{new_emoji}\")\n",
+ " self.chat_master.memory_pool.change_memory(event_name, return_data[\"text\"], new_emoji)\n",
+ "\n",
+ " self.state = \"Menu\"\n",
+ "\n",
+ " print(\"\\n-------------\\n\")\n",
+ "\n",
+ " def call_chat_master(self):\n",
+ "\n",
+ " print(\"\\n-------------\\n\")\n",
+ "\n",
+ " self.chat_master.run(self.agent)\n",
+ " self.state = \"Menu\"\n",
+ "\n",
+ " print(\"\\n-------------\\n\")\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Gradio搭建\n",
+ "\n",
+ "Gradio的核心其实是Chatbot的搭建"
+ ],
+ "metadata": {
+ "id": "w7jyichxXuOX"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!pip install -q gradio==3.48.0"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "zhPnfGkxX0l8",
+ "outputId": "ca718cb2-34fc-4966-982d-002cf8c25ed3"
+ },
+ "execution_count": 122,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m20.2/20.2 MB\u001b[0m \u001b[31m14.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m298.3/298.3 kB\u001b[0m \u001b[31m10.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25h"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "markdown_str = \"\"\"## Chat凉宫春日_x_AI糖糖\n",
+ "\n",
+ "**Chat凉宫春日**是模仿凉宫春日等一系列动漫人物,使用近似语气、个性和剧情聊天的语言模型方案。\n",
+ "\n",
+ "在有一天的时候,[李鲁鲁](https://github.com/LC1332)被[董雄毅](https://github.com/E-sion)在[这个B站视频](https://www.bilibili.com/video/BV1zh4y1z7G1) at了\n",
+ "\n",
+ "原来是一位大一的同学雄毅用ChatHaruhi接入了他用Python重新实现的《主播女孩重度依赖》这个游戏。当时正好是百度AGIFoundathon报名的最后几天,所以我们邀请了雄毅加入了我们的项目。正巧我们本来就希望在最近的几个黑客松中,探索LLM在游戏中的应用。\n",
+ "\n",
+ "- 在重新整理的Gradio版本中,大部分代码由李鲁鲁实现\n",
+ "\n",
+ "- 董雄毅负责了原版游戏的事件数据整理和新事件、选项、属性变化的生成\n",
+ "\n",
+ "- [米唯实](https://github.com/hhhwmws0117)完成了文心一言的接入,并实现了部分gradio的功能。\n",
+ "\n",
+ "- 队伍中还有冷子昂 主要参加了讨论\n",
+ "\n",
+ "另外在挖���的萝卜(Amy)的介绍下,我们还邀请了专业的大厂游戏策划Kanyo加入到队伍中,他对我们的策划也给出了很多建议。\n",
+ "\n",
+ "另外感谢飞桨 & 文心一言团队对比赛的邀请和中间进行的讨论。\n",
+ "\n",
+ "Chat凉宫春日主项目:\n",
+ "\n",
+ "https://github.com/LC1332/Chat-Haruhi-Suzumiya\n",
+ "\n",
+ "Needy分支项目:\n",
+ "\n",
+ "https://github.com/LC1332/Needy-Haruhi\n",
+ "\n",
+ "## 目前计划在11月争取完成的Feature\n",
+ "\n",
+ "- [ ] 结局系统,原版结局系统\n",
+ "- [ ] 教程,教大家如何从aistudio获取token然后可以玩\n",
+ "- [ ] 游戏节奏进一步调整\n",
+ "- [ ] 事件的自由对话对属性影响的评估via LLM\n",
+ "- [ ] 进一步减少串扰\"\"\""
+ ],
+ "metadata": {
+ "id": "yrOCXrBLiAzK"
+ },
+ "execution_count": 56,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "TODO:\n",
+ "\n",
+ "- [ ] 改为逐渐显示文字的特效\n",
+ "- [x] 第一个tab增加一个emoji 记忆显示的text\n",
+ "- [x] event的默认选项,有的时候也可以考虑出图\n",
+ "- [x] 在第二个tab 支持修改三个属性\n",
+ "- [x] 增加事件选择后的状态结算\n",
+ "- [x] 随机增加负向情绪,会随着游戏轮数越来越多"
+ ],
+ "metadata": {
+ "id": "X9hVH3BdHQa9"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import gradio as gr\n",
+ "import os\n",
+ "import time\n",
+ "import random\n",
+ "\n",
+ "# set global variable\n",
+ "\n",
+ "agent = Agent()\n",
+ "event_master = EventMaster(events)\n",
+ "chat_master = ChatMaster(memory_pool)\n",
+ "image_master = ImageMaster(image_pool)\n",
+ "chat_master.set_image_master(image_master)\n",
+ "event_master.set_image_master(image_master)\n",
+ "\n",
+ "state = \"ShowMenu\"\n",
+ "\n",
+ "response = \"1. 随机一个事件\"\n",
+ "response += \"\\n\" + \"2. 自由聊天\"\n",
+ "response += \"\\n\\n\" + \"请选择一个选项: \"\n",
+ "\n",
+ "official_response = response\n",
+ "\n",
+ "add_stress_switch = True\n",
+ "\n",
+ "# def yield_show(history, bot_message):\n",
+ "# history[-1][1] = \"\"\n",
+ "# for character in bot_message:\n",
+ "# history[-1][1] += character\n",
+ "# time.sleep(0.05)\n",
+ "# yield history\n",
+ "\n",
+ "global emoji_str\n",
+ "\n",
+ "def call_showmenu(history, text, state,agent_text):\n",
+ "\n",
+ " # global state\n",
+ "\n",
+ " response = official_response\n",
+ "\n",
+ " print(\"call showmenu\")\n",
+ "\n",
+ " history += [(None, response)]\n",
+ "\n",
+ " state = \"ParseMenuChoice\"\n",
+ "\n",
+ " # history[-1][1] = \"\"\n",
+ " # for character in response:\n",
+ " # history[-1][1] += character\n",
+ " # time.sleep(0.05)\n",
+ " # yield history\n",
+ "\n",
+ " return history, gr.Textbox(value=\"\", interactive=True), state,agent_text\n",
+ "\n",
+ "current_event_id = -1\n",
+ "attr_change_str = \"\"\n",
+ "\n",
+ "\n",
+ "def call_add_stress(history, text, state,agent_text):\n",
+ " print(\"call add_stress\")\n",
+ " neg_change = int(len(history) / 3)\n",
+ "\n",
+ " neg_change = max(1, neg_change)\n",
+ " neg_change = min(10, neg_change)\n",
+ "\n",
+ " darkness_increase = random.randint(1, neg_change)\n",
+ " stress_increase = neg_change - darkness_increase\n",
+ "\n",
+ " # last_response = history[-1][1]\n",
+ " response = \"\"\n",
+ " response += \"经过了晚上的直播\\n糖糖的压力增加\" + str(stress_increase) + \"点\\n\"\n",
+ " response += \"糖糖的黑暗增加\" + str(darkness_increase) + \"点\\n\\n\"\n",
+ "\n",
+ " response += official_response\n",
+ "\n",
+ " history += [(None, response)]\n",
+ "\n",
+ " state = \"ParseMenuChoice\"\n",
+ "\n",
+ " agent = Agent(agent_text)\n",
+ " agent.apply_attribute_change({\"Stress\": stress_increase, \"Darkness\": darkness_increase})\n",
+ " agent_text = agent.save_to_str()\n",
+ "\n",
+ " return history, gr.Textbox(value=\"\", interactive=True), state,agent_text\n",
+ "\n",
+ "def call_event_end(history, text, state,agent_text):\n",
+ " # TODO 增加事件结算\n",
+ " # global state\n",
+ " print(\"call event_end\")\n",
+ " global current_event_id\n",
+ " if attr_change_str != \"\":\n",
+ " # event = events[current_event_id]\n",
+ " # options = event[\"options\"]\n",
+ " # attr_str = options[user_input-1][\"attribute_change\"]\n",
+ "\n",
+ " response = \"\"\n",
+ "\n",
+ " attr_change = parse_attribute_string(attr_change_str)\n",
+ " if len(attr_change) > 0:\n",
+ " response = \"发生属性改变:\" + str(attr_change) + \"\\n\\n\"\n",
+ " agent = Agent(agent_text)\n",
+ " agent.apply_attribute_change(attr_change)\n",
+ "\n",
+ " agent_text = agent.save_to_str()\n",
+ " response += \"当前属性\" + agent_text + \"\\n\\n\"\n",
+ "\n",
+ " if add_stress_switch:\n",
+ " history += [(None, response)]\n",
+ " return call_add_stress(history, text, state,agent_text)\n",
+ " else:\n",
+ " response = \"事件结束\\n\"\n",
+ " else:\n",
+ " response = \"事件结束\\n\"\n",
+ "\n",
+ " response += official_response\n",
+ "\n",
+ " history += [(None, response)]\n",
+ "\n",
+ " state = \"ParseMenuChoice\"\n",
+ "\n",
+ " return history, gr.Textbox(value=\"\", interactive=True), state,agent_text\n",
+ "\n",
+ "\n",
+ "\n",
+ "def call_parse_menu_choice(history, text, state,agent_text):\n",
+ " print(\"call parse_menu_choice\")\n",
+ " # global state\n",
+ "\n",
+ " choice = history[-1][0].strip()\n",
+ "\n",
+ " if choice == \"1\":\n",
+ " state = \"EventMaster\"\n",
+ " global current_event_id\n",
+ " current_event_id = -1 # 清空事件\n",
+ " return call_event_master(history, text, state,agent_text)\n",
+ "\n",
+ " elif choice == \"2\":\n",
+ " state = \"ChatMaster\"\n",
+ " elif \"quit\" in choice or \"Quit\" in choice or \"QUIT\" in choice:\n",
+ " state = \"Quit\"\n",
+ " else:\n",
+ " response = \"无效的选项,请重新选择\"\n",
+ " history += [(None, response)]\n",
+ "\n",
+ " response = \"\"\n",
+ " if state == \"ChatMaster\":\n",
+ " response = \"(请输入 阿P 说的话,或者输入Quit退出)\"\n",
+ " elif state != \"ParseMenuChoice\":\n",
+ " response = \"Change State to \" + state\n",
+ "\n",
+ " history += [(None, response)]\n",
+ "\n",
+ " return history, gr.Textbox(value=\"\", interactive=True), state,agent_text\n",
+ "\n",
+ "\n",
+ "def call_event_master(history, text, state,agent_text):\n",
+ " print(\"call event master\")\n",
+ "\n",
+ " global current_event_id\n",
+ " # global state\n",
+ "\n",
+ " global event_master\n",
+ "\n",
+ " agent = Agent(agent_text)\n",
+ "\n",
+ " if current_event_id == -1:\n",
+ " current_event_id = event_master.get_random_event_id(agent)\n",
+ " event = events[current_event_id]\n",
+ "\n",
+ " prefix = \"糖糖:\" + event[\"prefix\"]\n",
+ "\n",
+ " response = prefix + \"\\n\\n--请输入数字进行选择,或者进行自由回复--\\n\\n\"\n",
+ "\n",
+ " options = event[\"options\"]\n",
+ "\n",
+ " for i, option in enumerate(event[\"options\"]):\n",
+ " text = option[\"user\"]\n",
+ " response += \"\\n\" + f\"{i+1}. 阿p:{text}\"\n",
+ "\n",
+ " history += [(None, response)]\n",
+ "\n",
+ " else:\n",
+ " user_input = history[-1][0].strip()\n",
+ "\n",
+ " event = events[current_event_id]\n",
+ " options = event[\"options\"]\n",
+ "\n",
+ " if user_input.isdigit():\n",
+ " user_input = int(user_input)\n",
+ "\n",
+ " if user_input > len(options) or user_input < 0:\n",
+ " response = \"输入的数字超出范围,请重新输入符合选项的数字\"\n",
+ " history[-1] = (user_input, response)\n",
+ " else:\n",
+ " user_text = options[user_input-1][\"user\"]\n",
+ " reply = options[user_input-1][\"reply\"]\n",
+ "\n",
+ " # TODO 修改记忆, 修改属性 什么的\n",
+ " history[-1] = (user_text, reply)\n",
+ "\n",
+ " if random.random()<0.5:\n",
+ " image_path = image_master.try_get_image(user_text + \" \" + reply, agent)\n",
+ "\n",
+ " if image_path is not None:\n",
+ " history += [(None, (image_path,))]\n",
+ "\n",
+ " global attr_change_str\n",
+ " attr_change_str = options[user_input-1][\"attribute_change\"]\n",
+ "\n",
+ " else:\n",
+ " prefix = \"糖糖:\" + event[\"prefix\"]\n",
+ "\n",
+ " needy_chatbot.dialogue_history = [(None, prefix)]\n",
+ " # 进入自由回复\n",
+ "\n",
+ " global emoji_str\n",
+ " response, emoji_str = get_chat_response_and_emoji( agent, memory_pool, user_input )\n",
+ "\n",
+ " history[-1] = (user_input,response)\n",
+ "\n",
+ " image_path = image_master.try_get_image(response, agent)\n",
+ "\n",
+ " if image_path is not None:\n",
+ " history += [(None, (image_path,))]\n",
+ "\n",
+ " state = \"EventEnd\"\n",
+ "\n",
+ " if state == \"EventEnd\":\n",
+ " return call_event_end(history, text, state,agent_text)\n",
+ "\n",
+ " return history, gr.Textbox(value=\"\", interactive=True), state,agent_text\n",
+ "\n",
+ "def call_chat_master(history, text, state,agent_text):\n",
+ " print(\"call chat master\")\n",
+ " # global state\n",
+ "\n",
+ " agent = Agent(agent_text)\n",
+ "\n",
+ " user_input = history[-1][0].strip()\n",
+ "\n",
+ " if \"quit\" in user_input or \"Quit\" in user_input or \"QUIT\" in user_input:\n",
+ " state = \"ShowMenu\"\n",
+ " history[-1] = (user_input,\"返回主菜单\\n\"+ official_response )\n",
+ " return history, gr.Textbox(value=\"\", interactive=True), state,agent_text\n",
+ "\n",
+ " query_text = user_input\n",
+ "\n",
+ " global emoji_str\n",
+ " response, emoji_str = get_chat_response_and_emoji( agent, memory_pool, query_text )\n",
+ "\n",
+ " history[-1] = (user_input,response)\n",
+ "\n",
+ " image_path = image_master.try_get_image(response, agent)\n",
+ "\n",
+ " if image_path is not None:\n",
+ " history += [(None, (image_path,))]\n",
+ "\n",
+ " return history, gr.Textbox(value=\"\", interactive=True), state,agent_text\n",
+ "\n",
+ "def grcall_game_master(history, text, state,agent_text):\n",
+ " print(\"call game master\")\n",
+ "\n",
+ " history += [(text, None)]\n",
+ "\n",
+ "\n",
+ " if state == \"ShowMenu\":\n",
+ " return call_showmenu(history, text,state,agent_text)\n",
+ " elif state == \"ParseMenuChoice\":\n",
+ " return call_parse_menu_choice(history, text, state,agent_text)\n",
+ " elif state == \"ChatMaster\":\n",
+ " return call_chat_master(history, text, state,agent_text)\n",
+ " elif state == \"EventMaster\":\n",
+ " return call_event_master(history, text, state,agent_text)\n",
+ " elif state == \"EventEnd\":\n",
+ " return call_event_end(history, text, state,agent_text)\n",
+ "\n",
+ " return history, gr.Textbox(value=\"\", interactive=True), state,agent_text\n",
+ "\n",
+ "\n",
+ "def add_file(history, file):\n",
+ " history = history + [((file.name,), None)]\n",
+ " return history\n",
+ "\n",
+ "\n",
+ "def bot(history):\n",
+ " response = \"**That's cool!**\"\n",
+ " history[-1][1] = \"\"\n",
+ " for character in response:\n",
+ " history[-1][1] += character\n",
+ " time.sleep(0.05)\n",
+ " yield history\n",
+ "\n",
+ "def update_memory(state):\n",
+ " if state == \"ChatMaster\" or state == \"EventMaster\":\n",
+ " return emoji_str\n",
+ " else:\n",
+ " return \"\"\n",
+ "\n",
+ "def change_state(slider_stress, slider_darkness, slider_affection):\n",
+ " # print(agent[\"Stress\"])\n",
+ " agent = Agent()\n",
+ " agent[\"Stress\"] = slider_stress\n",
+ " agent[\"Darkness\"] = slider_darkness\n",
+ " agent[\"Affection\"] = slider_affection\n",
+ " agent_text = agent.save_to_str()\n",
+ " return agent_text\n",
+ "\n",
+ "\n",
+ "def update_attribute_state(agent_text):\n",
+ " agent = Agent(agent_text)\n",
+ " slider_stress = int( agent[\"Stress\"] )\n",
+ " slider_darkness = int( agent[\"Darkness\"] )\n",
+ " slider_affection = int( agent[\"Affection\"] )\n",
+ " return slider_stress, slider_darkness, slider_affection\n",
+ "\n",
+ "with gr.Blocks() as demo:\n",
+ "\n",
+ " gr.Markdown(\n",
+ " \"\"\"\n",
+ " # Chat凉宫春日_x_AI糖糖\n",
+ "\n",
+ " Powered by 文心一言(3.5)版本\n",
+ "\n",
+ " 仍然在开发中, 细节见《项目作者和说明》\n",
+ " \"\"\"\n",
+ " )\n",
+ "\n",
+ " with gr.Tab(\"Needy\"):\n",
+ " chatbot = gr.Chatbot(\n",
+ " [],\n",
+ " elem_id=\"chatbot\",\n",
+ " bubble_full_width=False,\n",
+ " height = 800,\n",
+ " avatar_images=(None, (\"avatar.png\")),\n",
+ " )\n",
+ "\n",
+ " with gr.Row():\n",
+ " txt = gr.Textbox(\n",
+ " scale=4,\n",
+ " show_label=False,\n",
+ " placeholder=\"输入任何字符开始游戏\",\n",
+ " container=False,\n",
+ " )\n",
+ " # btn = gr.UploadButton(\"📁\", file_types=[\"image\", \"video\", \"audio\"])\n",
+ " submit_btr = gr.Button(\"回车\")\n",
+ "\n",
+ " with gr.Row():\n",
+ " memory_emoji_text = gr.Textbox(label=\"糖糖当前的记忆\", value = \"\",interactive = False)\n",
+ "\n",
+ " with gr.Tab(\"糖糖的状态\"):\n",
+ "\n",
+ " with gr.Row():\n",
+ " update_attribute_button = gr.Button(\"同步状态条 | 改变Attribute前必按!\")\n",
+ "\n",
+ " with gr.Row():\n",
+ " default_agent_str = agent.save_to_str()\n",
+ " slider_stress = gr.Slider(0, 100, step=1, label = \"Stress\")\n",
+ " state_stress = gr.State(value=0)\n",
+ " slider_darkness = gr.Slider(0, 100, step=1, label = \"Darkness\")\n",
+ " state_darkness = gr.State(value=0)\n",
+ " slider_affection = gr.Slider(0, 100, step=1, label = \"Affection\")\n",
+ " state_affection = gr.State(value=0)\n",
+ "\n",
+ "\n",
+ "\n",
+ " with gr.Row():\n",
+ " state_text = gr.Textbox(label=\"整体状态机状态\", value = \"ShowMenu\",interactive = False)\n",
+ "\n",
+ " with gr.Row():\n",
+ " default_agent_str = agent.save_to_str()\n",
+ " agent_text = gr.Textbox(label=\"糖糖状态\", value = default_agent_str,interactive = False)\n",
+ "\n",
+ " with gr.Tab(\"项目作者和说明\"):\n",
+ " gr.Markdown(markdown_str)\n",
+ "\n",
+ " slider_stress.release(change_state, inputs=[slider_stress, slider_darkness, slider_affection], outputs=[agent_text])\n",
+ " slider_darkness.release(change_state, inputs=[slider_stress, slider_darkness, slider_affection], outputs=[agent_text])\n",
+ " slider_affection.release(change_state, inputs=[slider_stress, slider_darkness, slider_affection], outputs=[agent_text])\n",
+ "\n",
+ " update_attribute_button.click(update_attribute_state, inputs = [agent_text], outputs = [slider_stress, slider_darkness, slider_affection])\n",
+ "\n",
+ " txt_msg = txt.submit(grcall_game_master, \\\n",
+ " [chatbot, txt, state_text,agent_text], \\\n",
+ " [chatbot, txt, state_text,agent_text], queue=False).then(update_memory, [state_text], memory_emoji_text)\n",
+ "\n",
+ " txt_msg = submit_btr.click(grcall_game_master, \\\n",
+ " [chatbot, txt, state_text,agent_text], \\\n",
+ " [chatbot, txt, state_text,agent_text], queue=False).then(update_memory, [state_text], memory_emoji_text)\n",
+ "\n",
+ " # txt_msg = txt.submit(add_text, [chatbot, txt], [chatbot, txt], queue=False).then(\n",
+ " # bot, chatbot, chatbot, api_name=\"bot_response\"\n",
+ " # )\n",
+ " # txt_msg.then(lambda: gr.Textbox(interactive=True), None, [txt], queue=False)\n",
+ " # file_msg = btn.upload(add_file, [chatbot, btn], [chatbot], queue=False).then(\n",
+ " # bot, chatbot, chatbot\n",
+ " # )\n",
+ "\n",
+ "demo.queue()\n",
+ "# if __name__ == \"__main__\":\n",
+ "demo.launch(allowed_paths=[\"avatar.png\"],debug = True)\n"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 958
+ },
+ "id": "2-mPCWpgYCLD",
+ "outputId": "218b35ac-38aa-4cf6-feb0-3dbb3450c10a"
+ },
+ "execution_count": 57,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Setting queue=True in a Colab notebook requires sharing enabled. Setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).\n",
+ "\n",
+ "Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().\n",
+ "Running on public URL: https://580e42e1f0dca62ea6.gradio.live\n",
+ "\n",
+ "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n"
+ ]
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ ""
+ ],
+ "text/html": [
+ ""
+ ]
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "call game master\n",
+ "call showmenu\n",
+ "call game master\n",
+ "call parse_menu_choice\n",
+ "call event master\n",
+ "call game master\n",
+ "call event master\n",
+ "call event_end\n",
+ "call add_stress\n",
+ "call game master\n",
+ "call parse_menu_choice\n",
+ "call event master\n",
+ "call game master\n",
+ "call event master\n",
+ "call event_end\n",
+ "call add_stress\n",
+ "Keyboard interruption in main thread... closing server.\n",
+ "Killing tunnel 127.0.0.1:7860 <> https://580e42e1f0dca62ea6.gradio.live\n"
+ ]
+ },
+ {
+ "output_type": "execute_result",
+ "data": {
+ "text/plain": []
+ },
+ "metadata": {},
+ "execution_count": 57
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Chat凉宫春日_x_AI糖糖\n",
+ "\n",
+ "**Chat凉宫春日**是模仿凉宫春日等一系列动漫人物,使用近似语气、个性和剧情聊天的语言模型方案。\n",
+ "\n",
+ "在有一天的时候,[李鲁鲁](https://github.com/LC1332)被[董雄毅](https://github.com/E-sion)在[这个B站视频](https://www.bilibili.com/video/BV1zh4y1z7G1) at了\n",
+ "\n",
+ "原来是一位大一的同学雄毅用ChatHaruhi接入了他用Python重新实现的《主播女孩重度依赖》这个游戏。当时正好是百度AGIFoundathon报名的最后几天,所以我们邀请了雄毅加入了我们的项目。正巧我们本来就希望在最近的几个黑客松中,探索LLM在游戏中的应用。\n",
+ "\n",
+ "- 在重新整理的Gradio版本中,大部分代码由李鲁鲁实现\n",
+ "\n",
+ "- 董雄毅负责了原版游戏的事件数据整理和新事件、选项、属性变化的生成\n",
+ "\n",
+ "- [米唯实](https://github.com/hhhwmws0117)完成了文心一言的接入,并实现了部分gradio的功能。\n",
+ "\n",
+ "- 队伍中还有冷子昂 主要参加了讨论\n",
+ "\n",
+ "另外在挖坑的萝卜(Amy)的介绍下,我们还邀请了专业的大厂游戏策划Kanyo加入到队伍中,他对我们的策划也给出了很多建议。\n",
+ "\n",
+ "另外感谢飞桨团队对比赛的邀请和中间进行的讨论。\n",
+ "\n",
+ "## 目前计划在11月争取完成的Feature\n",
+ "\n",
+ "- [ ] 结局系统,原版结局系统\n",
+ "- [ ] 教程,教大家如何从aistudio获取token然后可以玩\n",
+ "- [ ] 游戏节奏进一步调整\n",
+ "- [ ] 事件的自由对话对属性影响的评估via LLM"
+ ],
+ "metadata": {
+ "id": "6ed6He52fY8c"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [],
+ "metadata": {
+ "id": "ldV3Y6O4wf0h"
+ },
+ "execution_count": 53,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "game_master = GameMaster()\n",
+ "game_master.run()"
+ ],
+ "metadata": {
+ "id": "KF7RthcCbcka"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "game_master = GameMaster()\n",
+ "game_master.run()"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "YGI5SuY0WMGi",
+ "outputId": "e6a101f4-ad84-4b7b-ced3-0711187ba9b7"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "1. 随机一个事件\n",
+ "2. 自由聊天\n",
+ "3. 后台修改糖糖的属性\n",
+ "或者输入Quit退出\n",
+ "请选择一个选项: 3\n",
+ "\n",
+ "-------------\n",
+ "\n",
+ "请选择要修改的属性:\n",
+ "1. Stress\n",
+ "2. Darkness\n",
+ "3. Affection\n",
+ "输入 '0' 退出\n",
+ "请输入选项的数字: 60\n",
+ "选择的属性无效,请重试。\n",
+ "请选择要修改的属性:\n",
+ "1. Stress\n",
+ "2. Darkness\n",
+ "3. Affection\n",
+ "输入 '0' 退出\n",
+ "请输入选项的数字: 1\n",
+ "Stress 当前值: 0\n",
+ "请输入新的Stress值: 60\n",
+ "Stress 更新为 60。\n",
+ "\n",
+ "-------------\n",
+ "\n",
+ "1. 随机一个事件\n",
+ "2. 自由聊天\n",
+ "3. 后台修改糖糖的属性\n",
+ "或者输入Quit退出\n",
+ "请选择一个选项: 1\n",
+ "\n",
+ "-------------\n",
+ "\n",
+ "【紧急!】倒着太舒服了不想支棱 你快来帮忙把糖糖扶起来\n",
+ "\n",
+ "--请选择你的回复--\n",
+ "1. 阿p:自己站起来\n",
+ "2. 阿p:你先起来我再扶你\n",
+ "3. 阿p:摆个pose再起来\n",
+ "\n",
+ "请直接输入数字进行选择,或者进行自由回复(未实现)\n",
+ "阿p:我帮你买个电动轮椅吧\n",
+ "Memory: ['', '', '🤔🎮', '', '', '', '']\n",
+ "\n",
+ "嘿嘿,阿P最好了!帮糖糖买电动轮椅吧!糖糖想要呢~\n",
+ "\n",
+ "自由回复的算分功能还未实现\n",
+ "修正事件LineWeekDay67的记忆-->🆘😴😒🙄\n",
+ "\n",
+ "-------------\n",
+ "\n",
+ "1. 随机一个事件\n",
+ "2. 自由聊天\n",
+ "3. 后台修改糖糖的属性\n",
+ "或者输入Quit退出\n",
+ "请选择一个选项: 1\n",
+ "\n",
+ "-------------\n",
+ "\n",
+ "我会变得更加可爱的\n",
+ "\n",
+ "--请选择你的回复--\n",
+ "1. 阿p:你已经是最可爱的了\n",
+ "2. 阿p:可爱是无法提升的\n",
+ "3. 阿p:可爱不够重要,内心才是最重要的\n",
+ "\n",
+ "请直接输入数字进行选择,或者进行自由回复(未实现)\n",
+ "阿p:2\n",
+ "\n",
+ "好伤心QAQ 难道我就注定只能作为“普通可爱”的存在吗?\n",
+ "\n",
+ "发生属性改变: {'Stress': 1.0} \n",
+ "\n",
+ "当前属性 {'Stress': 61.0, 'Darkness': 0, 'Affection': 0}\n",
+ "修正事件event36的记忆-->😊😍😢💔\n",
+ "\n",
+ "-------------\n",
+ "\n",
+ "1. 随机一个事件\n",
+ "2. 自由聊天\n",
+ "3. 后台修改糖糖的属性\n",
+ "或者输入Quit退出\n",
+ "请选择一个选项: Quit\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "game_master = GameMaster()\n",
+ "game_master.run()"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "7ANTtWDRQdw7",
+ "outputId": "5f6f6f1c-3a59-4098-d00f-e6965ed85d7b"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "1. 随机一个事件\n",
+ "2. 自由聊天\n",
+ "或者输入Quit退出\n",
+ "请选择一个选项: 1\n",
+ "\n",
+ "-------------\n",
+ "\n",
+ "糖糖: 有个女孩发私信找我谈人生,我该怎么办呐,「超天酱你好,我是一名高中生。之前因为精神疾病而住院了一段时间,现在跟不上学习进度,班上还没决定好志愿的人也只剩我一个了。平时看着同学们为了各自的前程努力奋斗的样子,心里总是非常地焦虑。请你告诉我,我到底应该怎么办才好呢?」\n",
+ "\n",
+ "\n",
+ "--请选择你的回复--\n",
+ "1. 阿p:认真\n",
+ "2. 阿p:耍宝\n",
+ "\n",
+ "请直接输入数字进行选择,或者进行自由回复(未实现)\n",
+ "阿p:1\n",
+ "\n",
+ "糖糖:「这种事情,光着急是没有用的。总而言之,你现在应该先休养好自己。等恢复好了,再跟父母慢慢商量吧!放心。人生是不会因为不上学就完蛋的!未来就掌握在我们的手中!!!」↑发了这些过去。\n",
+ "\n",
+ "-------------\n",
+ "\n",
+ "1. 随机一个���件\n",
+ "2. 自由聊天\n",
+ "或者输入Quit退出\n",
+ "请选择一个选项: 1\n",
+ "\n",
+ "-------------\n",
+ "\n",
+ "糖糖: 我今后也会努力加油的,你要支持我哦 还有阿P你自己也要加油哦!\n",
+ "\n",
+ "--请选择你的回复--\n",
+ "1. 阿p:哇 说的话跟偶像一样 好恶心哦\n",
+ "2. 阿p:为什么连我也要加油啊?\n",
+ "\n",
+ "请直接输入数字进行选择,或者进行自由回复(未实现)\n",
+ "阿p:1\n",
+ "\n",
+ "糖糖:是哦 我怎么会说这样的话呢 我又没有很想努力……\n",
+ "\n",
+ "-------------\n",
+ "\n",
+ "1. 随机一个事件\n",
+ "2. 自由聊天\n",
+ "或者输入Quit退出\n",
+ "请选择一个选项: 1\n",
+ "\n",
+ "-------------\n",
+ "\n",
+ "糖糖: 我正在想下次搞什么企划呢~阿P帮帮我 出出主意\n",
+ "\n",
+ "--请选择你的回复--\n",
+ "1. 阿p:比如一直打游戏到通关?\n",
+ "2. 阿p:比如收集观众的提问,然后录一期回答?\n",
+ "3. 阿p:比如坐在超他妈大的乌龟背上绕新宿一圈?\n",
+ "\n",
+ "请直接输入数字进行选择,或者进行自由回复(未实现)\n",
+ "阿p:1\n",
+ "\n",
+ "糖糖:那就这么办吧(超听话)\n",
+ "\n",
+ "-------------\n",
+ "\n",
+ "1. 随机一个事件\n",
+ "2. 自由聊天\n",
+ "或者输入Quit退出\n",
+ "请选择一个选项: 1\n",
+ "\n",
+ "-------------\n",
+ "\n",
+ "糖糖: 阿P,看!我买了小发发\n",
+ "\n",
+ "--请选择你的回复--\n",
+ "1. 阿p:真好看,跟糖糖好像\n",
+ "2. 阿p:又买这些没用的~\n",
+ "3. 阿p:不错\n",
+ "\n",
+ "请直接输入数字进行选择,或者进行自由回复(未实现)\n",
+ "阿p:1\n",
+ "\n",
+ "糖糖:对吧!我不在的时候,你就把小花花当成糖糖,好好疼爱它吧!\n",
+ "\n",
+ "-------------\n",
+ "\n",
+ "1. 随机一个事件\n",
+ "2. 自由聊天\n",
+ "或者输入Quit退出\n",
+ "请选择一个选项: 1\n",
+ "\n",
+ "-------------\n",
+ "\n",
+ "糖糖: 我也想被做进那个大乱斗游戏……,哎,如果那个游戏里面有超天酱的话,阿P会用我吗?\n",
+ "\n",
+ "--请选择你的回复--\n",
+ "1. 阿p:嗯啊\n",
+ "2. 阿p:不打算用\n",
+ "\n",
+ "请直接输入数字进行选择,或者进行自由回复(未实现)\n",
+ "阿p:1\n",
+ "\n",
+ "糖糖:真的咩?!那我立刻开始练习捡信\n",
+ "\n",
+ "-------------\n",
+ "\n",
+ "1. 随机一个事件\n",
+ "2. 自由聊天\n",
+ "或者输入Quit退出\n",
+ "请选择一个选项: 1\n",
+ "\n",
+ "-------------\n",
+ "\n",
+ "糖糖: 如果我要整容,你觉得整哪里比较好?\n",
+ "\n",
+ "--请选择你的回复--\n",
+ "1. 阿p:脸\n",
+ "2. 阿p:胸\n",
+ "3. 阿p:手腕\n",
+ "\n",
+ "请直接输入数字进行选择,或者进行自由回复(未实现)\n",
+ "阿p:1\n",
+ "\n",
+ "糖糖:人家颜值已经是天下第一了,没什么要改动的啦!阿P,你真的很没礼貌欸\n",
+ "\n",
+ "-------------\n",
+ "\n",
+ "1. 随机一个事件\n",
+ "2. 自由聊天\n",
+ "或者输入Quit退出\n",
+ "请选择一个选项: 1\n",
+ "\n",
+ "-------------\n",
+ "\n",
+ "糖糖: 嗳,你来帮我打耳洞嘛 让喜欢的人给自己打耳洞很棒不是吗 有一种被支配着的感觉 鸡皮疙瘩都要起来了,我好怕我好怕我好怕,我好怕!,但是来吧!\n",
+ "\n",
+ "--请选择你的回复--\n",
+ "1. 阿p:给她打\n",
+ "2. 阿p:还是算了\n",
+ "\n",
+ "请直接输入数字进行选择,或者进行自由回复(未实现)\n",
+ "阿p:1\n",
+ "\n",
+ "糖糖:哇!打好了!合适吗?合适吗?快他妈夸我合适!!!\n",
+ "\n",
+ "-------------\n",
+ "\n",
+ "1. 随机一个事件\n",
+ "2. 自由聊天\n",
+ "或者输入Quit退出\n",
+ "请选择一个选项: 1\n",
+ "\n",
+ "-------------\n",
+ "\n",
+ "糖糖: 我问你哦,我真的可以就这样活下去吗?\n",
+ "\n",
+ "--请选择你的回复--\n",
+ "1. 阿p:怎么了啊?\n",
+ "2. 阿p:真的可以呀\n",
+ "3. 阿p:对没错\n",
+ "4. 阿p:那还用说\n",
+ "5. 阿p:其实谁都行\n",
+ "6. 阿p:脸\n",
+ "7. 阿p:一切\n",
+ "8. 阿p:没什么不行吧?\n",
+ "9. 阿p:不可以\n",
+ "10. 阿p:喜欢啊\n",
+ "11. 阿p:喜欢吧\n",
+ "12. 阿p:真的超超喜欢\n",
+ "13. 阿p:超超喜欢\n",
+ "14. 阿p:以当代互联网小天使的身份活下去\n",
+ "15. 阿p:真的超超喜欢\n",
+ "\n",
+ "请直接输入数字进行选择,或者进行自由回复(未实现)\n",
+ "阿p:1\n",
+ "\n",
+ "\n",
+ "\n",
+ "-------------\n",
+ "\n",
+ "1. 随机一个事件\n",
+ "2. 自由聊天\n",
+ "或者输入Quit退出\n",
+ "请选择一个选项: 1\n",
+ "\n",
+ "-------------\n",
+ "\n",
+ "糖糖: 糖糖,是不是还是去死一死比较好……\n",
+ "\n",
+ "--请选择你的回复--\n",
+ "1. 阿p:要活下去啊!!!\n",
+ "2. 阿p:死~寂\n",
+ "3. 阿p:你有颜值啊\n",
+ "4. 阿p:不如砍掉重练吧!\n",
+ "5. 阿p:不是还有宅宅们嘛\n",
+ "\n",
+ "请直接输入数字进行选择,或者进行自由回复(未实现)\n",
+ "阿p:1\n",
+ "\n",
+ "糖糖:可是,糖糖又没有活着的价值……\n",
+ "\n",
+ "-------------\n",
+ "\n",
+ "1. 随机一个事件\n",
+ "2. 自由聊天\n",
+ "或者输入Quit退出\n",
+ "请选择一个选项: 1\n",
+ "\n",
+ "-------------\n",
+ "\n",
+ "糖糖: 机会这么难得,要不整点富婆快乐活吧,说不定还能用作下次的企划哦!\n",
+ "\n",
+ "--请选择你的回复--\n",
+ "1. 阿p:买头老虎在大街上放生\n",
+ "2. 阿p:无所谓,不管你是不是富婆我都爱你\n",
+ "3. 阿p:要不把整个筑地买下来吧\n",
+ "\n",
+ "请直接输入数字进行选择,或者进行自由回复(未实现)\n",
+ "阿p:1\n",
+ "\n",
+ "糖糖:好像买一头就要几百万哦……\n",
+ "\n",
+ "-------------\n",
+ "\n",
+ "1. 随机一个事件\n",
+ "2. 自由聊天\n",
+ "或者输入Quit退出\n",
+ "请选择一个选项: 1\n",
+ "\n",
+ "-------------\n",
+ "\n",
+ "糖糖: 我要出去玩!给我零花钱!!!\n",
+ "\n",
+ "--请选择你的回复--\n",
+ "1. 阿p:给10圆\n",
+ "2. 阿p:给3000圆\n",
+ "3. 阿p:给10000圆\n",
+ "\n",
+ "请直接输入数字进行选择,或者进行自由回复(未实现)\n",
+ "阿p:1\n",
+ "\n",
+ "糖糖:这点钱连小学生都打发不了好吧!!!真是的,看我今天赖在家黏你一整天!!!!\n",
+ "\n",
+ "-------------\n",
+ "\n",
+ "1. 随机一个事件\n",
+ "2. 自由聊天\n",
+ "或者输入Quit退出\n",
+ "请选择一个选项: 1\n",
+ "\n",
+ "-------------\n",
+ "\n",
+ "糖糖: 小天使请安!这个开场白也说厌了啊~,帮我想个别的开场白!\n",
+ "\n",
+ "--请选择你的回复--\n",
+ "1. 阿p:当代互联网小天使,参上!\n",
+ "2. 阿p:我是路过的网络主播,给我记住了!\n",
+ "3. 阿p:那么,我们开始直播吧\n",
+ "\n",
+ "请直接输入数字进行选择,或者进行自由回复(未实现)\n",
+ "阿p:1\n",
+ "\n",
+ "糖糖:试着上超天酱的钩吧?之类的嘿嘿\n",
+ "\n",
+ "-------------\n",
+ "\n",
+ "1. 随机一个事件\n",
+ "2. 自由聊天\n",
+ "或者输入Quit退出\n",
+ "请选择一个选项: 1\n",
+ "\n",
+ "-------------\n",
+ "\n",
+ "糖糖: 我们点外卖吧我一步也不想动��可是又超想吃饭!!!\n",
+ "\n",
+ "--请选择你的回复--\n",
+ "1. 阿p:烦死了白痴\n",
+ "2. 阿p:吃土去吧你\n",
+ "3. 阿p:那我点了哦\n",
+ "\n",
+ "请直接输入数字进行选择,或者进行自由回复(未实现)\n",
+ "阿p:1\n",
+ "\n",
+ "糖糖:555555555 但是我们得省钱对吧\n",
+ "\n",
+ "-------------\n",
+ "\n",
+ "1. 随机一个事件\n",
+ "2. 自由聊天\n",
+ "或者输入Quit退出\n",
+ "请选择一个选项: 1\n",
+ "\n",
+ "-------------\n",
+ "\n",
+ "糖糖: 哎,你会希望看到糖糖将来的样子吗?\n",
+ "\n",
+ "--请选择你的回复--\n",
+ "1. 阿p:机器人\n",
+ "2. 阿p:合成怪物\n",
+ "3. 阿p:狂战士\n",
+ "\n",
+ "请直接输入数字进行选择,或者进行自由回复(未实现)\n",
+ "阿p:1\n",
+ "\n",
+ "糖糖:——“糖糖”OS,启动\n",
+ "\n",
+ "-------------\n",
+ "\n",
+ "1. 随机一个事件\n",
+ "2. 自由聊天\n",
+ "或者输入Quit退出\n",
+ "请选择一个选项: 1\n",
+ "\n",
+ "-------------\n",
+ "\n",
+ "糖糖: 我没打招呼就把冰箱里的布丁吃了 会被判死刑吗???\n",
+ "\n",
+ "--请选择你的回复--\n",
+ "1. 阿p:原谅你\n",
+ "2. 阿p:糖糖可以随便吃哦\n",
+ "\n",
+ "请直接输入数字进行选择,或者进行自由回复(未实现)\n",
+ "阿p:1\n",
+ "\n",
+ "糖糖:嗯 能被糖糖吃掉也是布丁的荣幸 所以当然没问题\n",
+ "\n",
+ "-------------\n",
+ "\n",
+ "1. 随机一个事件\n",
+ "2. 自由聊天\n",
+ "或者输入Quit退出\n",
+ "请选择一个选项: 1\n",
+ "\n",
+ "-------------\n",
+ "\n",
+ "糖糖: 今天有点想试试平时不会做的事\n",
+ "\n",
+ "--请选择你的回复--\n",
+ "1. 阿p:杀人\n",
+ "2. 阿p:相爱\n",
+ "3. 阿p:抢银行\n",
+ "\n",
+ "请直接输入数字进行选择,或者进行自由回复(未实现)\n",
+ "阿p:1\n",
+ "\n",
+ "糖糖:如果我搞砸了……就由阿P杀了我吧\n",
+ "\n",
+ "-------------\n",
+ "\n",
+ "1. 随机一个事件\n",
+ "2. 自由聊天\n",
+ "或者输入Quit退出\n",
+ "请选择一个选项: 1\n",
+ "\n",
+ "-------------\n",
+ "\n",
+ "糖糖: 哎,你喜欢什么样的糖糖啊?\n",
+ "\n",
+ "--请选择你的回复--\n",
+ "1. 阿p:无情人设\n",
+ "2. 阿p:天才博士人设\n",
+ "3. 阿p:得寸进尺小萝莉\n",
+ "\n",
+ "请直接输入数字进行选择,或者进行自由回复(未实现)\n",
+ "阿p:1\n",
+ "\n",
+ "糖糖:……我不明白,“感情”是什么\n",
+ "\n",
+ "-------------\n",
+ "\n",
+ "1. 随机一个事件\n",
+ "2. 自由聊天\n",
+ "或者输入Quit退出\n",
+ "请选择一个选项: 1\n",
+ "warning! all candidate event was sampled\n",
+ "\n",
+ "-------------\n",
+ "\n",
+ "糖糖: 我也想被做进那个大乱斗游戏……,哎,如果那个游戏里面有超天酱的话,阿P会用我吗?\n",
+ "\n",
+ "--请选择你的回复--\n",
+ "1. 阿p:嗯啊\n",
+ "2. 阿p:不打算用\n",
+ "\n",
+ "请直接输入数字进行选择,或者进行自由回复(未实现)\n",
+ "阿p:1\n",
+ "\n",
+ "糖糖:真的咩?!那我立刻开始练习捡信\n",
+ "\n",
+ "-------------\n",
+ "\n",
+ "1. 随机一个事件\n",
+ "2. 自由聊天\n",
+ "或者输入Quit退出\n",
+ "请选择一个选项: 1\n",
+ "warning! all candidate event was sampled\n",
+ "\n",
+ "-------------\n",
+ "\n",
+ "糖糖: 我没打招呼就把冰箱里的布丁吃了 会被判死刑吗???\n",
+ "\n",
+ "--请选择你的回复--\n",
+ "1. 阿p:原谅你\n",
+ "2. 阿p:糖糖可以随便吃哦\n",
+ "\n",
+ "请直接输入数字进行选择,或者进行自由回复(未实现)\n",
+ "阿p:1\n",
+ "\n",
+ "糖糖:嗯 能被糖糖吃掉也是布丁的荣幸 所以当然没问题\n",
+ "\n",
+ "-------------\n",
+ "\n",
+ "1. 随机一个事件\n",
+ "2. 自由聊天\n",
+ "或者输入Quit退出\n",
+ "请选择一个选项: Quit\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "game_master = GameMaster()\n",
+ "game_master.run()"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "5GwFCR_wLtay",
+ "outputId": "9dc0c692-9dd4-4310-cd1a-3fdb89fa76b8"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "1. 随机一个事件\n",
+ "2. 自由聊天\n",
+ "或者输入Quit退出\n",
+ "请选择一个选项: 1\n",
+ "\n",
+ "-------------\n",
+ "\n",
+ "糖糖: 机会这么难得,要不整点富婆快乐活吧,说不定还能用作下次的企划哦!\n",
+ "\n",
+ "--请选择你的回复--\n",
+ "1. 阿p:买头老虎在大街上放生\n",
+ "2. 阿p:无所谓,不管你是不是富婆我都爱你\n",
+ "3. 阿p:要不把整个筑地买下来吧\n",
+ "\n",
+ "请直接输入数字进行选择,或者进行自由回复(未实现)\n",
+ "阿p:我觉得可以把钱拿来进一步投资哦\n",
+ "Memory: ['💰😓', '🤔😳', '🤔🎮', '💸😡', '😔😌', '😔😔', '😔😍']\n",
+ "糖糖:「阿哈,投资?那我是不是可以买更多的二次元周边啦?!」\n",
+ "自由回复的算分功能还未实现\n",
+ "\n",
+ "-------------\n",
+ "\n",
+ "('糖糖:「 机会这么难得,要不整点富婆快乐活吧,说不定还能用作下次的企划哦!」\\n阿P:「买头老虎在大街上放生」\\n糖糖:「好像买一头就要几百万哦……」\\n', '💰😓')\n",
+ "按任意键继续...Quit\n",
+ "1. 随机一个事件\n",
+ "2. 自由聊天\n",
+ "或者输入Quit退出\n",
+ "请选择一个选项: Quit\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "\n",
+ "game_master = GameMaster()\n",
+ "game_master.run()"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "zPmr9kVepwjh",
+ "outputId": "3a8bcbc6-06ef-4542-ef70-03cd8ed0b357"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "1. 随机一个事件\n",
+ "2. 自由聊天\n",
+ "或者输入Quit退出\n",
+ "请选择一个选项: 2\n",
+ "聊天:你好呀糖糖\n",
+ "Memory: ['😔😔', '🍔😢', '💸😡', '🤔😔', '🍬😔', '💪😔', '🤔😊']\n",
+ "糖糖:「哈喽~阿哈!终于又见面了呢,我都快等不及了呢!」\n",
+ "聊天:等不及要心心了吗\n",
+ "Memory: ['😔😌', '🍔😢', '🤔😳', '💔😢', '😳😅', '💰😓', '😔😔']\n",
+ "糖糖:「诶~你怎么这么了解我呀!心心已经开始了,我都快被你迷得神魂颠倒了!」\n",
+ "聊天:Quit\n",
+ "1. 随机一个事件\n",
+ "2. 自由聊天\n",
+ "或者输入Quit退出\n",
+ "请选择一个选项: quit\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "\n",
+ "---\n",
+ "\n",
+ "这个以下都是非主要代码和单元测试\n",
+ "\n",
+ "---\n",
+ "\n",
+ "这个以下都是非主要代码和单元测试\n",
+ "\n",
+ "\n",
+ "---\n",
+ "\n",
+ "这个以下都是非主要代码和单元测试\n",
+ "\n",
+ "\n",
+ "---\n",
+ "\n",
+ "这个以下都是非主要代码和单元测试\n",
+ "\n"
+ ],
+ "metadata": {
+ "id": "WHxC8m7oH3W4"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# 不同状态下的Agent测试"
+ ],
+ "metadata": {
+ "id": "m5J7wuRoIqTd"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "chat_master = ChatMaster(memory_pool)\n",
+ "agent = Agent()\n",
+ "agent[\"Stress\"] = 0\n",
+ "agent[\"Affection\"] = 0\n",
+ "agent[\"Darkness\"] = 0\n",
+ "\n",
+ "chat_master.run(agent)"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "QBY81TRMIrID",
+ "outputId": "0c18759e-24b5-48ff-8a59-dedb88c85a79"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "阿p:你今天心情怎么样?\n",
+ "Memory: ['', '', '😔', '', '🍬😔', '', '']\n",
+ "啊~今天的心情还好啦~有点嗨,有点闷,有点复杂的感觉~不过没关系,糖糖还是会努力开心起来的~你今天遇到什么有趣的事情了吗?快来分享一下嘛!\n",
+ "阿p:Quit\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "chat_master = ChatMaster(memory_pool)\n",
+ "agent = Agent()\n",
+ "agent[\"Stress\"] = 100\n",
+ "agent[\"Affection\"] = 0\n",
+ "agent[\"Darkness\"] = 0\n",
+ "\n",
+ "chat_master.run(agent)"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "VoXh56exJIrL",
+ "outputId": "544cdd1c-b274-471d-890b-3e3a9377593d"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "阿p:你今天心情怎么样?\n",
+ "Memory: ['', '', '', '', '', '', '']\n",
+ "啊~今天心情真的是超级烂,简直就是要爆炸了QAQ,一点都不开心呢。你有没有什么好玩的事情可以分享一下?\n",
+ "阿p:Quit\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "chat_master = ChatMaster(memory_pool)\n",
+ "agent = Agent()\n",
+ "agent[\"Stress\"] = 0\n",
+ "agent[\"Affection\"] = 80\n",
+ "agent[\"Darkness\"] = 0\n",
+ "\n",
+ "chat_master.run(agent)"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "EPISkUJVJXzm",
+ "outputId": "2f4d1181-7ded-4d5b-f58b-a67e1715d6af"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "阿p:糖糖,快表演机器人\n",
+ "Memory: ['🤔😔', '🍬😔', '', '', '', '', '🎉😊']\n",
+ "啊哈~阿P你真是个调皮鬼,总是喜欢逗我玩,真是让我笑死了!好吧,我就给你表演个机器人吧!看好了啊~「机器人模式启动」(机械声效)「Beep beep boop」(模仿机器人声音)「我是糖糖机器人,全面服务中,请问阿P有什么指令?」嘿嘿~怎么样,我是不是个超级可爱的机器人呢?QWQ\n",
+ "阿p:Quit\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "chat_master = ChatMaster(memory_pool)\n",
+ "agent = Agent()\n",
+ "agent[\"Stress\"] = 0\n",
+ "agent[\"Affection\"] = 0\n",
+ "agent[\"Darkness\"] = 0\n",
+ "\n",
+ "chat_master.run(agent)"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "eCJdzQSkJdy7",
+ "outputId": "6d8264b2-b6f6-4217-ce4a-9aec0a940636"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "阿p:糖糖,快表演机器人\n",
+ "Memory: ['🤔😔', '🍬😔', '', '', '🎉😊', '', '']\n",
+ "啊哈~阿P你真是个大坏蛋,总是逗我开心,真是让我笑死了!好吧,我就给你表演个机器人吧!看好了啊~「机器人模式启动」(模仿机械声音)「Beep beep boop」(模仿机器人声音)「我是糖糖机器人,全面服务中,请问阿P有什么指令?」嘿嘿~怎么样,我是不是个超级可爱的机器人呢?阿哈~快夸我一下吧!QWQ\n",
+ "阿p:Quit\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Memory\n",
+ "\n",
+ "memory我们希望Event和Memory是分离的Event的标准字段如下\n",
+ "\n",
+ "- Name, Event的Name,用来后续如果玩家进行游戏修改的话可以根据\n",
+ "- Text, 这个event下完整的对话文本\n",
+ "- Embedding, text的embedding\n",
+ "- Condition, 这个event对应的出现条件\n",
+ "- Emoji, 这个memory的缩写显示emoji\n",
+ "\n",
+ "Memory应该可以从Event去默认load一个"
+ ],
+ "metadata": {
+ "id": "NQuYYbb33-Cc"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "example_memory_json = {\n",
+ " \"Name\": \"EventName\",\n",
+ " \"Text\": \"Sample Text\",\n",
+ " \"Embedding\": [0,0,0],\n",
+ " \"Condition\": \"\",\n",
+ " \"Emoji\": \"😓🤯\"\n",
+ "}"
+ ],
+ "metadata": {
+ "id": "JaKoW7oK391c"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Memory会包含下面几个字段\n",
+ "\n",
+ "example_memory_json = {\n",
+ " \"Name\": \"EventName\",\n",
+ " \"Text\": \"Sample Text\",\n",
+ " \"Embedding\": [0,0,0],\n",
+ " \"Condition\": \"\",\n",
+ " \"Emoji\": \"😓🤯\"\n",
+ "}\n",
+ "\n",
+ "请为我创建一个Memory类\n",
+ "\n",
+ "这个memory类可以通过Memory(json_str)来载入\n",
+ "\n",
+ "同时这个类也有和DIalogueEvent类似的get和setitem的功能"
+ ],
+ "metadata": {
+ "id": "qUcHULFR4GQR"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Memory 类不再使用\n",
+ "\n",
+ "# import json\n",
+ "\n",
+ "# class Memory:\n",
+ "# def __init__(self, json_str=None):\n",
+ "# if json_str:\n",
+ "# try:\n",
+ "# self.data = json.loads(json_str)\n",
+ "# except json.JSONDecodeError:\n",
+ "# print(\"输入的字符串不是有效的JSON格式。\")\n",
+ "# self.data = {}\n",
+ "# else:\n",
+ "# self.data = {}\n",
+ "\n",
+ "# def load_from_event( event ):\n",
+ "# pass\n",
+ "\n",
+ "# def __getitem__(self, key):\n",
+ "# return self.data.get(key, None)\n",
+ "\n",
+ "# def __setitem__(self, key, value):\n",
+ "# self.data[key] = value\n",
+ "\n",
+ "# def __repr__(self):\n",
+ "# return str(self.data)\n",
+ "\n",
+ "\n",
+ "# example_memory_json = {\n",
+ "# \"Name\": \"EventName\",\n",
+ "# \"Text\": \"Sample Text\",\n",
+ "# \"Embedding\": [0, 0, 0],\n",
+ "# \"Condition\": \"\",\n",
+ "# \"Emoji\": \"😓🤯\"\n",
+ "# }\n",
+ "\n",
+ "# # 通过给定的json字符串初始化Memory实例\n",
+ "# memory = Memory(json.dumps(example_memory_json))\n",
+ "\n",
+ "# # 通过类似字典的方式访问数据\n",
+ "# print(memory[\"Name\"]) # 打印Name字段的内容\n",
+ "# print(memory[\"Emoji\"]) # 打印Emoji字段的内容\n"
+ ],
+ "metadata": {
+ "id": "Jnjyi62a4Bbt"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## parse_attribute_string单元测试"
+ ],
+ "metadata": {
+ "id": "mVgTS5dlFn6P"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from util import parse_attribute_string\n",
+ "\n",
+ "# Test cases\n",
+ "print(parse_attribute_string(\"Stress: -1.0, Affection: +0.5\")) # Output: {'Stress': -1.0, 'Affection': 0.5}\n",
+ "print(parse_attribute_string(\"Affection: +4.0, Stress: -2.0, Darkness: -1.0\")) # Output: {'Affection': 4.0, 'Stress': -2.0, 'Darkness': -1.0}\n",
+ "print(parse_attribute_string(\"Affection: +2.0, Stress: -1.0, Darkness: ?\")) # Output: {'Affection': 2.0, 'Stress': -1.0, 'Darkness': 0}\n",
+ "print(parse_attribute_string(\"Stress: -1.0\")) # Output: {'Stress': -1.0}\n"
+ ],
+ "metadata": {
+ "id": "HGaXw1osFo7U"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Embedding 单元测试"
+ ],
+ "metadata": {
+ "id": "6MEN4KahF-Ab"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!pip install -q transformers\n",
+ "\n",
+ "from util import get_bge_embedding_zh\n",
+ "\n",
+ "result = get_bge_embedding_zh(\"你好\")\n",
+ "print( result )"
+ ],
+ "metadata": {
+ "id": "86lKC20uF_8_"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## parsing_condition_string 单元测试"
+ ],
+ "metadata": {
+ "id": "WM1c9xMXGJHT"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from util import parsing_condition_string\n",
+ "\n",
+ "# 测试例子\n",
+ "example_inputs = [\n",
+ " \"Random Noon Event: Darkness 0-39\",\n",
+ " \"Random Noon Event: Stress 0-19\",\n",
+ " \"Random Noon Event: Affection 61+\",\n",
+ " \"Random Noon Event: No Attribute\"\n",
+ "]\n",
+ "\n",
+ "for example_input in example_inputs:\n",
+ " print(f\"example_input:\\n{example_input}\\nexample_output\\n{parsing_condition_string(example_input)}\\n\")\n"
+ ],
+ "metadata": {
+ "id": "93GwecaBGIys"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "我已经实现了一个类\n",
+ "\n",
+ "class ChatHaruhi:\n",
+ "\n",
+ "\n",
+ "这个类有两个关键方法\n",
+ "\n",
+ "```python\n",
+ "\n",
+ " def add_story(self, query):\n",
+ "\n",
+ " if self.db is None:\n",
+ " return\n",
+ " \n",
+ " query_vec = self.embedding(query)\n",
+ "\n",
+ " stories = self.db.search(query_vec, self.k_search)\n",
+ " \n",
+ " story_string = self.story_prefix_prompt\n",
+ " sum_story_token = self.tokenizer(story_string)\n",
+ " \n",
+ " for story in stories:\n",
+ " story_token = self.tokenizer(story) + self.tokenizer(self.dialogue_divide_token)\n",
+ " if sum_story_token + story_token > self.max_len_story:\n",
+ " break\n",
+ " else:\n",
+ " sum_story_token += story_token\n",
+ " story_string += story + self.dialogue_divide_token\n",
+ "\n",
+ " self.llm.user_message(story_string)\n",
+ "\n",
+ " def chat(self, text, role):\n",
+ " # add system prompt\n",
+ " self.llm.initialize_message()\n",
+ " self.llm.system_message(self.system_prompt)\n",
+ " \n",
+ "\n",
+ " # add story\n",
+ " query = self.get_query_string(text, role)\n",
+ " self.add_story( query )\n",
+ "\n",
+ " # add history\n",
+ " self.add_history()\n",
+ "\n",
+ " # add query\n",
+ " self.llm.user_message(query)\n",
+ " \n",
+ " # get response\n",
+ " response_raw = self.llm.get_response()\n",
+ "\n",
+ " response = response_postprocess(response_raw, self.dialogue_bra_token, self.dialogue_ket_token)\n",
+ "\n",
+ " # record dialogue history\n",
+ " self.dialogue_history.append((query, response))\n",
+ "\n",
+ "\n",
+ "\n",
+ " return response\n",
+ "```\n",
+ "\n",
+ "我希望在一个新的应用中复用这个类,\n",
+ "\n",
+ "但是在新的应用中,我定义了新的方法来获取add_story中的stories\n",
+ "\n",
+ "即\n",
+ "\n",
+ "stories = new_get_stories( query )\n",
+ "\n",
+ "我现在想复用这个类,仅改变add_stories方法,我有什么好的办法来实现?"
+ ],
+ "metadata": {
+ "id": "LAYDsOmKKPNv"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "```python\n",
+ "class EnhancedChatHaruhi(ChatHaruhi):\n",
+ "\n",
+ " def new_get_stories(self, query):\n",
+ " # 这里实现您新的获取故事的方法\n",
+ " # 返回故事列表\n",
+ " pass\n",
+ "\n",
+ " def add_story(self, query):\n",
+ " if self.db is None:\n",
+ " return\n",
+ " \n",
+ " # 调用新的获取故事的方法\n",
+ " stories = self.new_get_stories(query)\n",
+ " \n",
+ " story_string = self.story_prefix_prompt\n",
+ " sum_story_token = self.tokenizer(story_string)\n",
+ " \n",
+ " for story in stories:\n",
+ " story_token = self.tokenizer(story) + self.tokenizer(self.dialogue_divide_token)\n",
+ " if sum_story_token + story_token > self.max_len_story:\n",
+ " break\n",
+ " else:\n",
+ " sum_story_token += story_token\n",
+ " story_string += story + self.dialogue_divide_token\n",
+ "\n",
+ " self.llm.user_message(story_string)\n",
+ "```"
+ ],
+ "metadata": {
+ "id": "QRvwYYQH1xD4"
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "我希望实现一个python函数\n",
+ "\n",
+ "分析一个字符串中有没有\":\"\n",
+ "\n",
+ "如果有,我希望在第一个\":\"的位置分开成str_left和str_right,并以f\"{str_left}:「{str_right}」\"的形式输出\n",
+ "\n",
+ "例子输入\n",
+ "爸爸:我真棒\n",
+ "例子输出\n",
+ "爸爸:「我真棒」\n",
+ "例子输入\n",
+ "这一句没有冒号\n",
+ "例子输出\n",
+ ":「这一���没有冒号」\n"
+ ],
+ "metadata": {
+ "id": "kiDXmwI21znH"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "def wrap_text_with_colon(text):\n",
+ " # 查找冒号在字符串中的位置\n",
+ " colon_index = text.find(\":\")\n",
+ "\n",
+ " # 如果找到了冒号\n",
+ " if colon_index != -1:\n",
+ " # 分割字符串为左右两部分\n",
+ " str_left = text[:colon_index]\n",
+ " str_right = text[colon_index+1:]\n",
+ " # 构造新的格式化字符串\n",
+ " result = f\"{str_left}:「{str_right}」\"\n",
+ " else:\n",
+ " # 如果没有找到冒号,整个字符串被认为是右侧部分\n",
+ " result = f\":「{text}」\"\n",
+ "\n",
+ " return result\n",
+ "\n",
+ "# 示例输入\n",
+ "print(wrap_text_with_colon(\"爸爸:我真棒\")) # 爸爸:「我真棒」\n",
+ "print(wrap_text_with_colon(\"这一句没有冒号\")) # :「这一句没有冒号」\n"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "ZUWO0yqNMuoW",
+ "outputId": "4c815ef4-5f5d-43ec-856d-8afe7d1741b8"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "爸爸:「我真棒」\n",
+ ":「这一句没有冒号」\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## MemoryPool的单元测试"
+ ],
+ "metadata": {
+ "id": "5v3VfnluEp3_"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "retrieved_memories = memory_pool.retrieve( agent , \"你是一个什么样的主播啊\" )\n",
+ "\n",
+ "for mem in retrieved_memories[:2]:\n",
+ " print(mem[\"text\"])\n",
+ " print(mem[\"emoji\"])\n",
+ " print(\"---\")"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "gbkumgmX2VPF",
+ "outputId": "76cad38f-47d4-4189-dc0f-347446d64703"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "糖糖:「 我也想被做进那个大乱斗游戏……,哎,如果那个游戏里面有超天酱的话,阿P会用我吗?」\n",
+ "阿P:「嗯啊」\n",
+ "糖糖:「真的咩?!那我立刻开始练习捡信」\n",
+ "\n",
+ "😔😍\n",
+ "---\n",
+ "糖糖:「 我今后也会努力加油的,你要支持我哦 还有阿P你自己也要加油哦!」\n",
+ "阿P:「哇 说的话跟偶像一样 好恶心哦」\n",
+ "糖糖:「是哦 我怎么会说这样的话呢 我又没有很想努力……」\n",
+ "\n",
+ "💪😔\n",
+ "---\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Agent的单元测试"
+ ],
+ "metadata": {
+ "id": "a45r14X8E9XR"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from Agent import Agent\n",
+ "\n",
+ "agent = Agent()\n",
+ "\n",
+ "if __name__ == \"__main__\":\n",
+ " # 示例用法\n",
+ "\n",
+ " print(agent[\"Stress\"]) # 输出 0\n",
+ " agent[\"Stress\"] += 1\n",
+ " print(agent[\"Stress\"]) # 输出 1\n",
+ " agent.apply_attribute_change({\"Darkness\": -1, \"Stress\": 1})\n",
+ " print(agent[\"Darkness\"]) # 输出 -1\n",
+ " print(agent[\"Stress\"]) # 输出 2\n",
+ " agent.apply_attribute_change({\"Nonexistent\": 5}) # 输出 Warning: Nonexistent not in attributes, skipping\n",
+ "\n",
+ " condition = ('Stress', 0, 19)\n",
+ "\n",
+ " print( agent.in_condition( condition ) )"
+ ],
+ "metadata": {
+ "id": "VyPhQxNZEsHC"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## DialogueEvent的单元测试"
+ ],
+ "metadata": {
+ "id": "lcIJuHfiGDI3"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from DialogueEvent import DialogueEvent\n",
+ "\n",
+ "\n",
+ "example_json_str = \"\"\"{\"prefix\": \"糖糖: 嘿嘿,最近我在想要不要改变直播风格,你觉得我应该怎么做呀?\", \"options\": [{\"user\": \"你可以试试唱歌直播呀!\", \"reply\": \"糖糖: 哇!唱歌直播是个好主意!我可以把我的可爱音色展现给大家听听!谢谢你的建议!\", \"attribute_change\": \"Stress: -1.0\"}, {\"user\": \"你可以尝试做一���搞笑的小品,逗大家开心。\", \"reply\": \"糖糖: 哈哈哈,小品确实挺有趣的!我可以挑战一些搞笑角色,给大家带来欢乐!谢谢你的建议!\", \"attribute_change\": \"Stress: -1.0\"}, {\"user\": \"你可以尝试做游戏直播,和观众一起玩游戏。\", \"reply\": \"糖糖: 游戏直播也不错!我可以和观众一起玩游戏,互动更加有趣!谢谢你的建议!\", \"attribute_change\": \"Stress: -1.0\"}]}\"\"\"\n",
+ "\n",
+ "# 通过给定的json字符串初始化DialogueEvent实例\n",
+ "event = DialogueEvent(example_json_str)\n",
+ "\n",
+ "# 通过类似字典的方式访问数据\n",
+ "# print(event[\"options\"]) # 打印options字段的内容\n",
+ "\n",
+ "print(event.transfer_output(1) )\n",
+ "\n",
+ "print(event.get_most_neutral())\n",
+ "\n",
+ "print(event.most_neutral_output())\n",
+ "\n"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "0Tp8qSXNGFNn",
+ "outputId": "2ec91dde-7d26-450d-a283-084bd7456631"
+ },
+ "execution_count": null,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "糖糖:「 嘿嘿,最近我在想要不要改变直播风格,你觉得我应该怎么做呀?」\n",
+ "阿P:「你可以尝试做一些搞笑的小品,逗大家开心。」\n",
+ "糖糖:「 哈哈哈,小品确实挺有趣的!我可以挑战一些搞笑角色,给大家带来欢乐!谢谢你的建议!」\n",
+ "\n",
+ "0\n",
+ "('糖糖:「 嘿嘿,最近我在想要不要改变直播风格,你觉得我应该怎么做呀?」\\n阿P:「你可以试试唱歌直播呀!」\\n糖糖:「 哇!唱歌直播是个好主意!我可以把我的可爱音色展现给大家听听!谢谢你的建议!」\\n', '📄📄')\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## NeedyHaruhi的单元测试"
+ ],
+ "metadata": {
+ "id": "wNiah9RrGhCQ"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "needy_chatbot = NeedyHaruhi( system_prompt = system_prompt ,\n",
+ " story_text_folder = None )\n",
+ "\n",
+ "query_text = \"糖糖,你今天怎么样啊?\"\n",
+ "query_text_for_embedding = \"阿p:「\" + query_text + \"」\"\n",
+ "retrieved_memories = memory_pool.retrieve( agent , query_text )\n",
+ "\n",
+ "memory_text = [mem[\"text\"] for mem in retrieved_memories]\n",
+ "memory_emoji = [mem[\"emoji\"] for mem in retrieved_memories]\n",
+ "\n",
+ "needy_chatbot.set_stories( memory_text )\n",
+ "\n",
+ "print(\"Mem:\", memory_emoji )\n",
+ "\n",
+ "response = needy_chatbot.chat( role = \"阿p\", text = query_text )\n",
+ "print(response)"
+ ],
+ "metadata": {
+ "id": "XwcbSxlYGFY3"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## 载入ChatHaruhi的测试"
+ ],
+ "metadata": {
+ "id": "BdARAEura7yJ"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "from chatharuhi import ChatHaruhi\n",
+ "\n",
+ "chatbot = ChatHaruhi( role_from_hf = 'chengli-thu/Jack-Sparrow', \\\n",
+ " llm = 'openai',\n",
+ " embedding = 'bge_en'\n",
+ " )"
+ ],
+ "metadata": {
+ "id": "ISd8bD4Ya85A"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "显示图片"
+ ],
+ "metadata": {
+ "id": "sR9u0ArQQmvo"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import matplotlib.pyplot as plt\n",
+ "import matplotlib.image as mpimg\n",
+ "\n",
+ "image_path = '/content/image'\n",
+ "\n",
+ "for data in data_img_text:\n",
+ " img_name = data['img_name']\n",
+ "\n",
+ " # 拼接完整的图片路径\n",
+ " img_path = os.path.join(image_path, img_name)\n",
+ "\n",
+ " # 读取图片\n",
+ " img = mpimg.imread(img_path)\n",
+ "\n",
+ " # 可视化图片\n",
+ " plt.imshow(img)\n",
+ " plt.axis('off')\n",
+ " plt.show()\n",
+ "\n",
+ " break"
+ ],
+ "metadata": {
+ "id": "6T9LfbweQnh5"
+ },
+ "execution_count": null,
+ "outputs": []
+ }
+ ]
+}
\ No newline at end of file
|