diff --git a/gpt4all-chat/chat.cpp b/gpt4all-chat/chat.cpp index 163b1dd37de2..e504f1657a21 100644 --- a/gpt4all-chat/chat.cpp +++ b/gpt4all-chat/chat.cpp @@ -385,7 +385,7 @@ bool Chat::serialize(QDataStream &stream, int version) const stream << m_modelInfo.filename(); if (version > 2) stream << m_collections; - if (!m_llmodel->serialize(stream, version)) + if (!m_llmodel->serialize(stream, version, true /*serializeKV*/)) return false; if (!m_chatModel->serialize(stream, version)) return false; @@ -404,29 +404,36 @@ bool Chat::deserialize(QDataStream &stream, int version) QString modelId; stream >> modelId; if (version > 4) { - if (!ModelList::globalInstance()->contains(modelId)) - return false; - m_modelInfo = ModelList::globalInstance()->modelInfo(modelId); + if (ModelList::globalInstance()->contains(modelId)) + m_modelInfo = ModelList::globalInstance()->modelInfo(modelId); } else { - if (!ModelList::globalInstance()->containsByFilename(modelId)) - return false; - m_modelInfo = ModelList::globalInstance()->modelInfoByFilename(modelId); + if (ModelList::globalInstance()->containsByFilename(modelId)) + m_modelInfo = ModelList::globalInstance()->modelInfoByFilename(modelId); } - emit modelInfoChanged(); + if (!m_modelInfo.id().isEmpty()) + emit modelInfoChanged(); + + bool deserializeKV = true; // make this a setting + bool discardKV = m_modelInfo.id().isEmpty(); // Prior to version 2 gptj models had a bug that fixed the kv_cache to F32 instead of F16 so // unfortunately, we cannot deserialize these if (version < 2 && m_modelInfo.filename().contains("gpt4all-j")) - return false; + discardKV = true; + if (version > 2) { stream >> m_collections; emit collectionListChanged(m_collections); } m_llmodel->setModelInfo(m_modelInfo); - if (!m_llmodel->deserialize(stream, version)) + if (!m_llmodel->deserialize(stream, version, deserializeKV, discardKV)) return false; if (!m_chatModel->deserialize(stream, version)) return false; + + if (!deserializeKV || discardKV) + 
m_llmodel->setStateFromText(m_chatModel->text()); + emit chatModelChanged(); return stream.status() == QDataStream::Ok; } diff --git a/gpt4all-chat/chatlistmodel.cpp b/gpt4all-chat/chatlistmodel.cpp index 062ebdd936a6..26c8b564eeae 100644 --- a/gpt4all-chat/chatlistmodel.cpp +++ b/gpt4all-chat/chatlistmodel.cpp @@ -84,13 +84,16 @@ void ChatSaver::saveChats(const QVector &chats) const QString savePath = MySettings::globalInstance()->modelPath(); for (Chat *chat : chats) { QString fileName = "gpt4all-" + chat->id() + ".chat"; - QFile file(savePath + "/" + fileName); - bool success = file.open(QIODevice::WriteOnly); + QString filePath = savePath + "/" + fileName; + QFile originalFile(filePath); + QFile tempFile(filePath + ".tmp"); // Temporary file + + bool success = tempFile.open(QIODevice::WriteOnly); if (!success) { - qWarning() << "ERROR: Couldn't save chat to file:" << file.fileName(); + qWarning() << "ERROR: Couldn't save chat to temporary file:" << tempFile.fileName(); continue; } - QDataStream out(&file); + QDataStream out(&tempFile); out << (quint32)CHAT_FORMAT_MAGIC; out << (qint32)CHAT_FORMAT_VERSION; @@ -98,11 +101,16 @@ void ChatSaver::saveChats(const QVector &chats) qDebug() << "serializing chat" << fileName; if (!chat->serialize(out, CHAT_FORMAT_VERSION)) { - qWarning() << "ERROR: Couldn't serialize chat to file:" << file.fileName(); - file.remove(); + qWarning() << "ERROR: Couldn't serialize chat to file:" << tempFile.fileName(); + tempFile.remove(); + continue; } - file.close(); + + if (originalFile.exists()) + originalFile.remove(); + tempFile.rename(filePath); } + qint64 elapsedTime = timer.elapsed(); qDebug() << "serializing chats took:" << elapsedTime << "ms"; emit saveChatsFinished(); @@ -224,7 +232,6 @@ void ChatsRestoreThread::run() chat->moveToThread(qApp->thread()); if (!chat->deserialize(in, version)) { qWarning() << "ERROR: Couldn't deserialize chat from file:" << file.fileName(); - file.remove(); } else { emit chatRestored(chat); } diff 
--git a/gpt4all-chat/chatllm.cpp b/gpt4all-chat/chatllm.cpp index 6a23879d0f7d..6528f41ceca1 100644 --- a/gpt4all-chat/chatllm.cpp +++ b/gpt4all-chat/chatllm.cpp @@ -69,6 +69,7 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer) , m_forceMetal(MySettings::globalInstance()->forceMetal()) , m_reloadingToChangeVariant(false) , m_processedSystemPrompt(false) + , m_restoreStateFromText(false) { moveToThread(&m_llmThread); connect(this, &ChatLLM::sendStartup, Network::globalInstance(), &Network::sendStartup); @@ -726,7 +727,35 @@ bool ChatLLM::handleSystemRecalculate(bool isRecalc) return false; } -bool ChatLLM::serialize(QDataStream &stream, int version) +bool ChatLLM::handleRestoreStateFromTextPrompt(int32_t token) +{ +#if defined(DEBUG) + qDebug() << "restore state from text prompt" << m_llmThread.objectName() << token << m_stopGenerating; +#endif + Q_UNUSED(token); + return !m_stopGenerating; +} + +bool ChatLLM::handleRestoreStateFromTextResponse(int32_t token, const std::string &response) +{ +#if defined(DEBUG) + qDebug() << "restore state from text response" << m_llmThread.objectName() << token << response << m_stopGenerating; +#endif + Q_UNUSED(token); + Q_UNUSED(response); + return false; +} + +bool ChatLLM::handleRestoreStateFromTextRecalculate(bool isRecalc) +{ +#if defined(DEBUG) + qDebug() << "restore state from text recalc" << m_llmThread.objectName() << isRecalc; +#endif + Q_UNUSED(isRecalc); + return false; +} + +bool ChatLLM::serialize(QDataStream &stream, int version, bool serializeKV) { if (version > 1) { stream << m_llModelType; @@ -741,6 +770,14 @@ bool ChatLLM::serialize(QDataStream &stream, int version) stream << response(); stream << generatedName(); stream << m_promptResponseTokens; + + if (!serializeKV) { +#if defined(DEBUG) + qDebug() << "serialize" << m_llmThread.objectName() << m_state.size(); +#endif + return stream.status() == QDataStream::Ok; + } + if (version <= 3) { int responseLogits; stream << responseLogits; @@ -759,7 +796,7 @@ bool 
ChatLLM::serialize(QDataStream &stream, int version) return stream.status() == QDataStream::Ok; } -bool ChatLLM::deserialize(QDataStream &stream, int version) +bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV, bool discardKV) { if (version > 1) { int internalStateVersion; @@ -773,26 +810,60 @@ bool ChatLLM::deserialize(QDataStream &stream, int version) stream >> nameResponse; m_nameResponse = nameResponse.toStdString(); stream >> m_promptResponseTokens; + + // If we do not deserialize the KV or it is discarded, then we need to restore the state from the + // text only. This will be a costly operation, but the chat has to be restored from the text archive + // alone. + m_restoreStateFromText = !deserializeKV || discardKV; + + if (!deserializeKV) { +#if defined(DEBUG) + qDebug() << "deserialize" << m_llmThread.objectName(); +#endif + return stream.status() == QDataStream::Ok; + } + if (version <= 3) { int responseLogits; stream >> responseLogits; } - stream >> m_ctx.n_past; + + int32_t n_past; + stream >> n_past; + if (!discardKV) m_ctx.n_past = n_past; + quint64 logitsSize; stream >> logitsSize; - m_ctx.logits.resize(logitsSize); - stream.readRawData(reinterpret_cast(m_ctx.logits.data()), logitsSize * sizeof(float)); + if (!discardKV) { + m_ctx.logits.resize(logitsSize); + stream.readRawData(reinterpret_cast(m_ctx.logits.data()), logitsSize * sizeof(float)); + } else { + stream.skipRawData(logitsSize * sizeof(float)); + } + quint64 tokensSize; stream >> tokensSize; - m_ctx.tokens.resize(tokensSize); - stream.readRawData(reinterpret_cast(m_ctx.tokens.data()), tokensSize * sizeof(int)); + if (!discardKV) { + m_ctx.tokens.resize(tokensSize); + stream.readRawData(reinterpret_cast(m_ctx.tokens.data()), tokensSize * sizeof(int)); + } else { + stream.skipRawData(tokensSize * sizeof(int)); + } + if (version > 0) { QByteArray compressed; stream >> compressed; - m_state = qUncompress(compressed); + if (!discardKV) + m_state = 
qUncompress(compressed); } else { - stream >> m_state; + if (!discardKV) + stream >> m_state; + else { + QByteArray state; /* scratch buffer: the saved state is consumed and dropped */ + stream >> state; + } } + #if defined(DEBUG) qDebug() << "deserialize" << m_llmThread.objectName(); #endif @@ -823,7 +894,7 @@ void ChatLLM::saveState() void ChatLLM::restoreState() { - if (!isModelLoaded() || m_state.isEmpty()) + if (!isModelLoaded()) return; if (m_llModelType == LLModelType::CHATGPT_) { @@ -838,10 +909,19 @@ void ChatLLM::restoreState() return; } + if (m_restoreStateFromText) { + Q_ASSERT(m_state.isEmpty()); + processRestoreStateFromText(); + } + #if defined(DEBUG) qDebug() << "restoreState" << m_llmThread.objectName() << "size:" << m_state.size(); #endif m_processedSystemPrompt = true; + + if (m_state.isEmpty()) + return; + m_llModelInfo.model->restoreState(static_cast(reinterpret_cast(m_state.data()))); m_state.clear(); m_state.resize(0); @@ -859,7 +939,10 @@ void ChatLLM::processSystemPrompt() return; } + // Start with a whole new context m_stopGenerating = false; + m_ctx = LLModel::PromptContext(); + auto promptFunc = std::bind(&ChatLLM::handleSystemPrompt, this, std::placeholders::_1); auto responseFunc = std::bind(&ChatLLM::handleSystemResponse, this, std::placeholders::_1, std::placeholders::_2); @@ -890,5 +973,54 @@ void ChatLLM::processSystemPrompt() printf("\n"); fflush(stdout); #endif - m_processedSystemPrompt = true; + + m_processedSystemPrompt = !m_stopGenerating; +} + +void ChatLLM::processRestoreStateFromText() +{ + Q_ASSERT(isModelLoaded()); + if (!isModelLoaded() || !m_restoreStateFromText || m_isServer) + return; + + m_isRecalc = true; + emit recalcChanged(); + + m_stopGenerating = false; + m_ctx = LLModel::PromptContext(); + + auto promptFunc = std::bind(&ChatLLM::handleRestoreStateFromTextPrompt, this, std::placeholders::_1); + auto responseFunc = std::bind(&ChatLLM::handleRestoreStateFromTextResponse, this, std::placeholders::_1, + std::placeholders::_2); + auto recalcFunc = 
std::bind(&ChatLLM::handleRestoreStateFromTextRecalculate, this, std::placeholders::_1); + + const QString promptTemplate = MySettings::globalInstance()->modelPromptTemplate(m_modelInfo); + const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo); + const int32_t top_k = MySettings::globalInstance()->modelTopK(m_modelInfo); + const float top_p = MySettings::globalInstance()->modelTopP(m_modelInfo); + const float temp = MySettings::globalInstance()->modelTemperature(m_modelInfo); + const int32_t n_batch = MySettings::globalInstance()->modelPromptBatchSize(m_modelInfo); + const float repeat_penalty = MySettings::globalInstance()->modelRepeatPenalty(m_modelInfo); + const int32_t repeat_penalty_tokens = MySettings::globalInstance()->modelRepeatPenaltyTokens(m_modelInfo); + int n_threads = MySettings::globalInstance()->threadCount(); + m_ctx.n_predict = n_predict; + m_ctx.top_k = top_k; + m_ctx.top_p = top_p; + m_ctx.temp = temp; + m_ctx.n_batch = n_batch; + m_ctx.repeat_penalty = repeat_penalty; + m_ctx.repeat_last_n = repeat_penalty_tokens; + m_llModelInfo.model->setThreadCount(n_threads); + for (const auto &pair : m_stateFromText) { + const QString str = pair.first == "Prompt: " ? 
promptTemplate.arg(pair.second) : pair.second; + m_llModelInfo.model->prompt(str.toStdString(), promptFunc, responseFunc, recalcFunc, m_ctx); + } + + if (!m_stopGenerating) { + m_restoreStateFromText = false; + m_stateFromText.clear(); + } + + m_isRecalc = false; + emit recalcChanged(); } diff --git a/gpt4all-chat/chatllm.h b/gpt4all-chat/chatllm.h index 829540414e75..4e07e48f788a 100644 --- a/gpt4all-chat/chatllm.h +++ b/gpt4all-chat/chatllm.h @@ -92,8 +92,9 @@ class ChatLLM : public QObject QString generatedName() const { return QString::fromStdString(m_nameResponse); } - bool serialize(QDataStream &stream, int version); - bool deserialize(QDataStream &stream, int version); + bool serialize(QDataStream &stream, int version, bool serializeKV); + bool deserialize(QDataStream &stream, int version, bool deserializeKV, bool discardKV); + void setStateFromText(const QVector> &stateFromText) { m_stateFromText = stateFromText; } public Q_SLOTS: bool prompt(const QList &collectionList, const QString &prompt); @@ -110,6 +111,7 @@ public Q_SLOTS: void handleForceMetalChanged(bool forceMetal); void handleDeviceChanged(); void processSystemPrompt(); + void processRestoreStateFromText(); Q_SIGNALS: void recalcChanged(); @@ -144,6 +146,9 @@ public Q_SLOTS: bool handleSystemPrompt(int32_t token); bool handleSystemResponse(int32_t token, const std::string &response); bool handleSystemRecalculate(bool isRecalc); + bool handleRestoreStateFromTextPrompt(int32_t token); + bool handleRestoreStateFromTextResponse(int32_t token, const std::string &response); + bool handleRestoreStateFromTextRecalculate(bool isRecalc); void saveState(); void restoreState(); @@ -168,6 +173,8 @@ public Q_SLOTS: bool m_forceMetal; bool m_reloadingToChangeVariant; bool m_processedSystemPrompt; + bool m_restoreStateFromText; + QVector> m_stateFromText; }; #endif // CHATLLM_H diff --git a/gpt4all-chat/chatmodel.h b/gpt4all-chat/chatmodel.h index 2ff4f388028b..0502ca72896a 100644 --- a/gpt4all-chat/chatmodel.h 
+++ b/gpt4all-chat/chatmodel.h @@ -285,6 +285,14 @@ class ChatModel : public QAbstractListModel return stream.status() == QDataStream::Ok; } + QVector> text() const + { + QVector> result; + for (const auto &c : m_chatItems) + result << qMakePair(c.name, c.value); + return result; + } + Q_SIGNALS: void countChanged();