Restore state from text #1493

Merged · 2 commits · Oct 11, 2023
27 changes: 17 additions & 10 deletions gpt4all-chat/chat.cpp
@@ -385,7 +385,7 @@ bool Chat::serialize(QDataStream &stream, int version) const
stream << m_modelInfo.filename();
if (version > 2)
stream << m_collections;
if (!m_llmodel->serialize(stream, version))
if (!m_llmodel->serialize(stream, version, true /*serializeKV*/))
return false;
if (!m_chatModel->serialize(stream, version))
return false;
@@ -404,29 +404,36 @@ bool Chat::deserialize(QDataStream &stream, int version)
QString modelId;
stream >> modelId;
if (version > 4) {
if (!ModelList::globalInstance()->contains(modelId))
return false;
m_modelInfo = ModelList::globalInstance()->modelInfo(modelId);
if (ModelList::globalInstance()->contains(modelId))
m_modelInfo = ModelList::globalInstance()->modelInfo(modelId);
} else {
if (!ModelList::globalInstance()->containsByFilename(modelId))
return false;
m_modelInfo = ModelList::globalInstance()->modelInfoByFilename(modelId);
if (ModelList::globalInstance()->containsByFilename(modelId))
m_modelInfo = ModelList::globalInstance()->modelInfoByFilename(modelId);
}
emit modelInfoChanged();
if (!m_modelInfo.id().isEmpty())
emit modelInfoChanged();

bool deserializeKV = true; // make this a setting
bool discardKV = m_modelInfo.id().isEmpty();

// Prior to version 2 gptj models had a bug that fixed the kv_cache to F32 instead of F16 so
// unfortunately, we cannot deserialize these
if (version < 2 && m_modelInfo.filename().contains("gpt4all-j"))
return false;
discardKV = true;

if (version > 2) {
stream >> m_collections;
emit collectionListChanged(m_collections);
}
m_llmodel->setModelInfo(m_modelInfo);
if (!m_llmodel->deserialize(stream, version))
if (!m_llmodel->deserialize(stream, version, deserializeKV, discardKV))
return false;
if (!m_chatModel->deserialize(stream, version))
return false;

if (!deserializeKV || discardKV)
m_llmodel->setStateFromText(m_chatModel->text());

emit chatModelChanged();
return stream.status() == QDataStream::Ok;
}
23 changes: 15 additions & 8 deletions gpt4all-chat/chatlistmodel.cpp
@@ -84,25 +84,33 @@ void ChatSaver::saveChats(const QVector<Chat *> &chats)
const QString savePath = MySettings::globalInstance()->modelPath();
for (Chat *chat : chats) {
QString fileName = "gpt4all-" + chat->id() + ".chat";
QFile file(savePath + "/" + fileName);
bool success = file.open(QIODevice::WriteOnly);
QString filePath = savePath + "/" + fileName;
QFile originalFile(filePath);
QFile tempFile(filePath + ".tmp"); // Temporary file

bool success = tempFile.open(QIODevice::WriteOnly);
if (!success) {
qWarning() << "ERROR: Couldn't save chat to file:" << file.fileName();
qWarning() << "ERROR: Couldn't save chat to temporary file:" << tempFile.fileName();
continue;
}
QDataStream out(&file);
QDataStream out(&tempFile);

out << (quint32)CHAT_FORMAT_MAGIC;
out << (qint32)CHAT_FORMAT_VERSION;
out.setVersion(QDataStream::Qt_6_2);

qDebug() << "serializing chat" << fileName;
if (!chat->serialize(out, CHAT_FORMAT_VERSION)) {
qWarning() << "ERROR: Couldn't serialize chat to file:" << file.fileName();
file.remove();
qWarning() << "ERROR: Couldn't serialize chat to file:" << tempFile.fileName();
tempFile.remove();
continue;
}
file.close();

if (originalFile.exists())
originalFile.remove();
tempFile.rename(filePath);
Review comment from @apage43 (Member), Oct 10, 2023:
"rename should be able to atomically replace a file without needing to remove it first. this should be fine either way though"

Reply (Member):
https://doc.qt.io/qt-6/qfile.html#rename
"If a file with the name newName already exists, rename() returns false (i.e., QFile will not overwrite it)."

}

qint64 elapsedTime = timer.elapsed();
qDebug() << "serializing chats took:" << elapsedTime << "ms";
emit saveChatsFinished();
@@ -224,7 +232,6 @@ void ChatsRestoreThread::run()
chat->moveToThread(qApp->thread());
if (!chat->deserialize(in, version)) {
qWarning() << "ERROR: Couldn't deserialize chat from file:" << file.fileName();
file.remove();
} else {
emit chatRestored(chat);
}
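For reference on the rename() discussion in the review thread above: a minimal sketch (not part of this PR, names illustrative) of the same write-to-temp-then-swap pattern using Qt's QSaveFile, which stages the data in a temporary file and atomically replaces the target on commit(), so no explicit remove()/rename() is needed.

```cpp
#include <QByteArray>
#include <QDataStream>
#include <QSaveFile>
#include <QString>

// Hypothetical helper, not project code: save a serialized chat payload atomically.
bool saveChatAtomically(const QString &filePath, const QByteArray &payload)
{
    QSaveFile file(filePath);                       // writes to a hidden temporary file
    if (!file.open(QIODevice::WriteOnly))
        return false;

    QDataStream out(&file);
    out.setVersion(QDataStream::Qt_6_2);
    out << payload;                                 // stand-in for chat->serialize(out, ...)

    // commit() flushes, closes, and atomically renames the temporary file over
    // filePath; if it fails, the original file is left untouched.
    return out.status() == QDataStream::Ok && file.commit();
}
```

Either way, the property the temporary-file change above is after is the same: a failed or interrupted save never clobbers the existing chat file.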
154 changes: 143 additions & 11 deletions gpt4all-chat/chatllm.cpp
@@ -69,6 +69,7 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
, m_forceMetal(MySettings::globalInstance()->forceMetal())
, m_reloadingToChangeVariant(false)
, m_processedSystemPrompt(false)
, m_restoreStateFromText(false)
{
moveToThread(&m_llmThread);
connect(this, &ChatLLM::sendStartup, Network::globalInstance(), &Network::sendStartup);
@@ -726,7 +727,35 @@ bool ChatLLM::handleSystemRecalculate(bool isRecalc)
return false;
}

bool ChatLLM::serialize(QDataStream &stream, int version)
bool ChatLLM::handleRestoreStateFromTextPrompt(int32_t token)
{
#if defined(DEBUG)
qDebug() << "restore state from text prompt" << m_llmThread.objectName() << token << m_stopGenerating;
#endif
Q_UNUSED(token);
return !m_stopGenerating;
}

bool ChatLLM::handleRestoreStateFromTextResponse(int32_t token, const std::string &response)
{
#if defined(DEBUG)
qDebug() << "restore state from text response" << m_llmThread.objectName() << token << response << m_stopGenerating;
#endif
Q_UNUSED(token);
Q_UNUSED(response);
return false;
}

bool ChatLLM::handleRestoreStateFromTextRecalculate(bool isRecalc)
{
#if defined(DEBUG)
qDebug() << "restore state from text recalc" << m_llmThread.objectName() << isRecalc;
#endif
Q_UNUSED(isRecalc);
return false;
}

bool ChatLLM::serialize(QDataStream &stream, int version, bool serializeKV)
{
if (version > 1) {
stream << m_llModelType;
@@ -741,6 +770,14 @@ bool ChatLLM::serialize(QDataStream &stream, int version)
stream << response();
stream << generatedName();
stream << m_promptResponseTokens;

if (!serializeKV) {
#if defined(DEBUG)
qDebug() << "serialize" << m_llmThread.objectName() << m_state.size();
#endif
return stream.status() == QDataStream::Ok;
}

if (version <= 3) {
int responseLogits;
stream << responseLogits;
@@ -759,7 +796,7 @@ bool ChatLLM::serialize(QDataStream &stream, int version)
return stream.status() == QDataStream::Ok;
}

bool ChatLLM::deserialize(QDataStream &stream, int version)
bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV, bool discardKV)
{
if (version > 1) {
int internalStateVersion;
@@ -773,26 +810,60 @@ bool ChatLLM::deserialize(QDataStream &stream, int version)
stream >> nameResponse;
m_nameResponse = nameResponse.toStdString();
stream >> m_promptResponseTokens;

// If we do not deserialize the KV or it is discarded, then we need to restore the state from the
// text only. This will be a costly operation, but the chat has to be restored from the text archive
// alone.
m_restoreStateFromText = !deserializeKV || discardKV;

if (!deserializeKV) {
#if defined(DEBUG)
qDebug() << "deserialize" << m_llmThread.objectName();
#endif
return stream.status() == QDataStream::Ok;
}

if (version <= 3) {
int responseLogits;
stream >> responseLogits;
}
stream >> m_ctx.n_past;

int32_t n_past;
stream >> n_past;
if (!discardKV) m_ctx.n_past = n_past;

quint64 logitsSize;
stream >> logitsSize;
m_ctx.logits.resize(logitsSize);
stream.readRawData(reinterpret_cast<char*>(m_ctx.logits.data()), logitsSize * sizeof(float));
if (!discardKV) {
m_ctx.logits.resize(logitsSize);
stream.readRawData(reinterpret_cast<char*>(m_ctx.logits.data()), logitsSize * sizeof(float));
} else {
stream.skipRawData(logitsSize * sizeof(float));
}

quint64 tokensSize;
stream >> tokensSize;
m_ctx.tokens.resize(tokensSize);
stream.readRawData(reinterpret_cast<char*>(m_ctx.tokens.data()), tokensSize * sizeof(int));
if (!discardKV) {
m_ctx.tokens.resize(tokensSize);
stream.readRawData(reinterpret_cast<char*>(m_ctx.tokens.data()), tokensSize * sizeof(int));
} else {
stream.skipRawData(tokensSize * sizeof(int));
}

if (version > 0) {
QByteArray compressed;
stream >> compressed;
m_state = qUncompress(compressed);
if (!discardKV)
m_state = qUncompress(compressed);
} else {
stream >> m_state;
if (!discardKV)
stream >> m_state;
else {
QByteArray state;
stream >> state;
}
}

#if defined(DEBUG)
qDebug() << "deserialize" << m_llmThread.objectName();
#endif
@@ -823,7 +894,7 @@ void ChatLLM::saveState()

void ChatLLM::restoreState()
{
if (!isModelLoaded() || m_state.isEmpty())
if (!isModelLoaded())
return;

if (m_llModelType == LLModelType::CHATGPT_) {
@@ -838,10 +909,19 @@ void ChatLLM::restoreState()
return;
}

if (m_restoreStateFromText) {
Q_ASSERT(m_state.isEmpty());
processRestoreStateFromText();
}

#if defined(DEBUG)
qDebug() << "restoreState" << m_llmThread.objectName() << "size:" << m_state.size();
#endif
m_processedSystemPrompt = true;

if (m_state.isEmpty())
return;

m_llModelInfo.model->restoreState(static_cast<const uint8_t*>(reinterpret_cast<void*>(m_state.data())));
m_state.clear();
m_state.resize(0);
@@ -859,7 +939,10 @@ void ChatLLM::processSystemPrompt()
return;
}

// Start with a whole new context
m_stopGenerating = false;
m_ctx = LLModel::PromptContext();

auto promptFunc = std::bind(&ChatLLM::handleSystemPrompt, this, std::placeholders::_1);
auto responseFunc = std::bind(&ChatLLM::handleSystemResponse, this, std::placeholders::_1,
std::placeholders::_2);
@@ -890,5 +973,54 @@ void ChatLLM::processSystemPrompt()
printf("\n");
fflush(stdout);
#endif
m_processedSystemPrompt = true;

m_processedSystemPrompt = !m_stopGenerating;
}

void ChatLLM::processRestoreStateFromText()
{
Q_ASSERT(isModelLoaded());
if (!isModelLoaded() || !m_restoreStateFromText || m_isServer)
return;

m_isRecalc = true;
emit recalcChanged();

m_stopGenerating = false;
m_ctx = LLModel::PromptContext();

auto promptFunc = std::bind(&ChatLLM::handleRestoreStateFromTextPrompt, this, std::placeholders::_1);
auto responseFunc = std::bind(&ChatLLM::handleRestoreStateFromTextResponse, this, std::placeholders::_1,
std::placeholders::_2);
auto recalcFunc = std::bind(&ChatLLM::handleRestoreStateFromTextRecalculate, this, std::placeholders::_1);

const QString promptTemplate = MySettings::globalInstance()->modelPromptTemplate(m_modelInfo);
const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo);
const int32_t top_k = MySettings::globalInstance()->modelTopK(m_modelInfo);
const float top_p = MySettings::globalInstance()->modelTopP(m_modelInfo);
const float temp = MySettings::globalInstance()->modelTemperature(m_modelInfo);
const int32_t n_batch = MySettings::globalInstance()->modelPromptBatchSize(m_modelInfo);
const float repeat_penalty = MySettings::globalInstance()->modelRepeatPenalty(m_modelInfo);
const int32_t repeat_penalty_tokens = MySettings::globalInstance()->modelRepeatPenaltyTokens(m_modelInfo);
int n_threads = MySettings::globalInstance()->threadCount();
m_ctx.n_predict = n_predict;
m_ctx.top_k = top_k;
m_ctx.top_p = top_p;
m_ctx.temp = temp;
m_ctx.n_batch = n_batch;
m_ctx.repeat_penalty = repeat_penalty;
m_ctx.repeat_last_n = repeat_penalty_tokens;
m_llModelInfo.model->setThreadCount(n_threads);
for (auto pair : m_stateFromText) {
const QString str = pair.first == "Prompt: " ? promptTemplate.arg(pair.second) : pair.second;
m_llModelInfo.model->prompt(str.toStdString(), promptFunc, responseFunc, recalcFunc, m_ctx);
}

if (!m_stopGenerating) {
m_restoreStateFromText = false;
m_stateFromText.clear();
}

m_isRecalc = false;
emit recalcChanged();
}
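The handleRestoreStateFromText* callbacks bound in processRestoreStateFromText() follow the usual ChatLLM pattern of adapting member functions with std::bind. A small self-contained sketch (simplified stand-in types, not the real LLModel API) of why the prompt callback returns !m_stopGenerating while the response callback always returns false:

```cpp
#include <cstdint>
#include <functional>
#include <iostream>
#include <string>

struct Replayer {
    bool stopGenerating = false;

    // Prompt callback: keep feeding transcript tokens unless the user interrupted.
    bool onPrompt(int32_t /*token*/) { return !stopGenerating; }

    // Response callback: always return false — replaying the transcript should only
    // rebuild the context, never generate new output.
    bool onResponse(int32_t /*token*/, const std::string & /*piece*/) { return false; }
};

int main()
{
    Replayer r;
    auto promptFunc   = std::bind(&Replayer::onPrompt, &r, std::placeholders::_1);
    auto responseFunc = std::bind(&Replayer::onResponse, &r,
                                  std::placeholders::_1, std::placeholders::_2);

    std::cout << promptFunc(42) << ' ' << responseFunc(42, "Paris") << '\n'; // prints "1 0"
    return 0;
}
```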
11 changes: 9 additions & 2 deletions gpt4all-chat/chatllm.h
@@ -92,8 +92,9 @@ class ChatLLM : public QObject

QString generatedName() const { return QString::fromStdString(m_nameResponse); }

bool serialize(QDataStream &stream, int version);
bool deserialize(QDataStream &stream, int version);
bool serialize(QDataStream &stream, int version, bool serializeKV);
bool deserialize(QDataStream &stream, int version, bool deserializeKV, bool discardKV);
void setStateFromText(const QVector<QPair<QString, QString>> &stateFromText) { m_stateFromText = stateFromText; }

public Q_SLOTS:
bool prompt(const QList<QString> &collectionList, const QString &prompt);
Expand All @@ -110,6 +111,7 @@ public Q_SLOTS:
void handleForceMetalChanged(bool forceMetal);
void handleDeviceChanged();
void processSystemPrompt();
void processRestoreStateFromText();

Q_SIGNALS:
void recalcChanged();
@@ -144,6 +146,9 @@ public Q_SLOTS:
bool handleSystemPrompt(int32_t token);
bool handleSystemResponse(int32_t token, const std::string &response);
bool handleSystemRecalculate(bool isRecalc);
bool handleRestoreStateFromTextPrompt(int32_t token);
bool handleRestoreStateFromTextResponse(int32_t token, const std::string &response);
bool handleRestoreStateFromTextRecalculate(bool isRecalc);
void saveState();
void restoreState();

Expand All @@ -168,6 +173,8 @@ public Q_SLOTS:
bool m_forceMetal;
bool m_reloadingToChangeVariant;
bool m_processedSystemPrompt;
bool m_restoreStateFromText;
QVector<QPair<QString, QString>> m_stateFromText;
};

#endif // CHATLLM_H
8 changes: 8 additions & 0 deletions gpt4all-chat/chatmodel.h
@@ -285,6 +285,14 @@ class ChatModel : public QAbstractListModel
return stream.status() == QDataStream::Ok;
}

QVector<QPair<QString, QString>> text() const
{
QVector<QPair<QString, QString>> result;
for (const auto &c : m_chatItems)
result << qMakePair(c.name, c.value);
return result;
}

Q_SIGNALS:
void countChanged();

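ChatModel::text() is what feeds ChatLLM::setStateFromText(). A standalone illustration (hypothetical transcript and prompt template, not project code) of the (name, value) pairs it produces and how processRestoreStateFromText() wraps only the "Prompt: " entries in the model's prompt template before replaying them:

```cpp
#include <QDebug>
#include <QPair>
#include <QString>
#include <QVector>

int main()
{
    // Each pair is (item name, item text), as built by ChatModel::text().
    const QVector<QPair<QString, QString>> transcript = {
        { "Prompt: ",   "What is the capital of France?" },
        { "Response: ", "Paris." },
    };

    const QString promptTemplate = "### Human:\n%1\n### Assistant:\n"; // illustrative template

    for (const auto &pair : transcript) {
        // Mirrors the check in ChatLLM::processRestoreStateFromText(): only prompt
        // entries are wrapped in the template, responses are replayed verbatim.
        const QString str = pair.first == "Prompt: "
            ? promptTemplate.arg(pair.second)
            : pair.second;
        // The real code passes str to LLModel::prompt() to rebuild the KV cache.
        qDebug().noquote() << str;
    }
    return 0;
}
```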