Skip to content

Commit

Permalink
Restore state from text if necessary.
Browse files Browse the repository at this point in the history
  • Loading branch information
manyoso committed Oct 10, 2023
1 parent ed53852 commit 4244c64
Show file tree
Hide file tree
Showing 5 changed files with 177 additions and 24 deletions.
27 changes: 17 additions & 10 deletions gpt4all-chat/chat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,7 @@ bool Chat::serialize(QDataStream &stream, int version) const
stream << m_modelInfo.filename();
if (version > 2)
stream << m_collections;
if (!m_llmodel->serialize(stream, version))
if (!m_llmodel->serialize(stream, version, true /*serializeKV*/))
return false;
if (!m_chatModel->serialize(stream, version))
return false;
Expand All @@ -404,29 +404,36 @@ bool Chat::deserialize(QDataStream &stream, int version)
QString modelId;
stream >> modelId;
if (version > 4) {
if (!ModelList::globalInstance()->contains(modelId))
return false;
m_modelInfo = ModelList::globalInstance()->modelInfo(modelId);
if (ModelList::globalInstance()->contains(modelId))
m_modelInfo = ModelList::globalInstance()->modelInfo(modelId);
} else {
if (!ModelList::globalInstance()->containsByFilename(modelId))
return false;
m_modelInfo = ModelList::globalInstance()->modelInfoByFilename(modelId);
if (ModelList::globalInstance()->containsByFilename(modelId))
m_modelInfo = ModelList::globalInstance()->modelInfoByFilename(modelId);
}
emit modelInfoChanged();
if (!m_modelInfo.id().isEmpty())
emit modelInfoChanged();

bool deserializeKV = true; // make this a setting
bool discardKV = m_modelInfo.id().isEmpty();

// Prior to version 2 gptj models had a bug that fixed the kv_cache to F32 instead of F16 so
// unfortunately, we cannot deserialize these
if (version < 2 && m_modelInfo.filename().contains("gpt4all-j"))
return false;
discardKV = true;

if (version > 2) {
stream >> m_collections;
emit collectionListChanged(m_collections);
}
m_llmodel->setModelInfo(m_modelInfo);
if (!m_llmodel->deserialize(stream, version))
if (!m_llmodel->deserialize(stream, version, deserializeKV, discardKV))
return false;
if (!m_chatModel->deserialize(stream, version))
return false;

if (!deserializeKV || discardKV)
m_llmodel->setStateFromText(m_chatModel->text());

emit chatModelChanged();
return stream.status() == QDataStream::Ok;
}
Expand Down
1 change: 0 additions & 1 deletion gpt4all-chat/chatlistmodel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,6 @@ void ChatsRestoreThread::run()
chat->moveToThread(qApp->thread());
if (!chat->deserialize(in, version)) {
qWarning() << "ERROR: Couldn't deserialize chat from file:" << file.fileName();
file.remove();
} else {
emit chatRestored(chat);
}
Expand Down
154 changes: 143 additions & 11 deletions gpt4all-chat/chatllm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
, m_forceMetal(MySettings::globalInstance()->forceMetal())
, m_reloadingToChangeVariant(false)
, m_processedSystemPrompt(false)
, m_restoreStateFromText(false)
{
moveToThread(&m_llmThread);
connect(this, &ChatLLM::sendStartup, Network::globalInstance(), &Network::sendStartup);
Expand Down Expand Up @@ -726,7 +727,35 @@ bool ChatLLM::handleSystemRecalculate(bool isRecalc)
return false;
}

bool ChatLLM::serialize(QDataStream &stream, int version)
bool ChatLLM::handleRestoreStateFromTextPrompt(int32_t token)
{
#if defined(DEBUG)
qDebug() << "restore state from text prompt" << m_llmThread.objectName() << token << m_stopGenerating;
#endif
Q_UNUSED(token);
return !m_stopGenerating;
}

bool ChatLLM::handleRestoreStateFromTextResponse(int32_t token, const std::string &response)
{
#if defined(DEBUG)
qDebug() << "restore state from text response" << m_llmThread.objectName() << token << response << m_stopGenerating;
#endif
Q_UNUSED(token);
Q_UNUSED(response);
return false;
}

bool ChatLLM::handleRestoreStateFromTextRecalculate(bool isRecalc)
{
#if defined(DEBUG)
qDebug() << "restore state from text recalc" << m_llmThread.objectName() << isRecalc;
#endif
Q_UNUSED(isRecalc);
return false;
}

bool ChatLLM::serialize(QDataStream &stream, int version, bool serializeKV)
{
if (version > 1) {
stream << m_llModelType;
Expand All @@ -741,6 +770,14 @@ bool ChatLLM::serialize(QDataStream &stream, int version)
stream << response();
stream << generatedName();
stream << m_promptResponseTokens;

if (!serializeKV) {
#if defined(DEBUG)
qDebug() << "serialize" << m_llmThread.objectName() << m_state.size();
#endif
return stream.status() == QDataStream::Ok;
}

if (version <= 3) {
int responseLogits;
stream << responseLogits;
Expand All @@ -759,7 +796,7 @@ bool ChatLLM::serialize(QDataStream &stream, int version)
return stream.status() == QDataStream::Ok;
}

bool ChatLLM::deserialize(QDataStream &stream, int version)
bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV, bool discardKV)
{
if (version > 1) {
int internalStateVersion;
Expand All @@ -773,26 +810,60 @@ bool ChatLLM::deserialize(QDataStream &stream, int version)
stream >> nameResponse;
m_nameResponse = nameResponse.toStdString();
stream >> m_promptResponseTokens;

// If we do not deserialize the KV or it is discarded, then we need to restore the state from the
// text only. This will be a costly operation, but the chat has to be restored from the text archive
// alone.
m_restoreStateFromText = !deserializeKV || discardKV;

if (!deserializeKV) {
#if defined(DEBUG)
qDebug() << "deserialize" << m_llmThread.objectName();
#endif
return stream.status() == QDataStream::Ok;
}

if (version <= 3) {
int responseLogits;
stream >> responseLogits;
}
stream >> m_ctx.n_past;

int32_t n_past;
stream >> n_past;
if (!discardKV) m_ctx.n_past = n_past;

quint64 logitsSize;
stream >> logitsSize;
m_ctx.logits.resize(logitsSize);
stream.readRawData(reinterpret_cast<char*>(m_ctx.logits.data()), logitsSize * sizeof(float));
if (!discardKV) {
m_ctx.logits.resize(logitsSize);
stream.readRawData(reinterpret_cast<char*>(m_ctx.logits.data()), logitsSize * sizeof(float));
} else {
stream.skipRawData(logitsSize * sizeof(float));
}

quint64 tokensSize;
stream >> tokensSize;
m_ctx.tokens.resize(tokensSize);
stream.readRawData(reinterpret_cast<char*>(m_ctx.tokens.data()), tokensSize * sizeof(int));
if (!discardKV) {
m_ctx.tokens.resize(tokensSize);
stream.readRawData(reinterpret_cast<char*>(m_ctx.tokens.data()), tokensSize * sizeof(int));
} else {
stream.skipRawData(tokensSize * sizeof(int));
}

if (version > 0) {
QByteArray compressed;
stream >> compressed;
m_state = qUncompress(compressed);
if (!discardKV)
m_state = qUncompress(compressed);
} else {
stream >> m_state;
if (!discardKV)
stream >> m_state;
else {
QByteArray state;
stream >> m_state;
}
}

#if defined(DEBUG)
qDebug() << "deserialize" << m_llmThread.objectName();
#endif
Expand Down Expand Up @@ -823,7 +894,7 @@ void ChatLLM::saveState()

void ChatLLM::restoreState()
{
if (!isModelLoaded() || m_state.isEmpty())
if (!isModelLoaded())
return;

if (m_llModelType == LLModelType::CHATGPT_) {
Expand All @@ -838,10 +909,19 @@ void ChatLLM::restoreState()
return;
}

if (m_restoreStateFromText) {
Q_ASSERT(m_state.isEmpty());
processRestoreStateFromText();
}

#if defined(DEBUG)
qDebug() << "restoreState" << m_llmThread.objectName() << "size:" << m_state.size();
#endif
m_processedSystemPrompt = true;

if (m_state.isEmpty())
return;

m_llModelInfo.model->restoreState(static_cast<const uint8_t*>(reinterpret_cast<void*>(m_state.data())));
m_state.clear();
m_state.resize(0);
Expand All @@ -859,7 +939,10 @@ void ChatLLM::processSystemPrompt()
return;
}

// Start with a whole new context
m_stopGenerating = false;
m_ctx = LLModel::PromptContext();

auto promptFunc = std::bind(&ChatLLM::handleSystemPrompt, this, std::placeholders::_1);
auto responseFunc = std::bind(&ChatLLM::handleSystemResponse, this, std::placeholders::_1,
std::placeholders::_2);
Expand Down Expand Up @@ -890,5 +973,54 @@ void ChatLLM::processSystemPrompt()
printf("\n");
fflush(stdout);
#endif
m_processedSystemPrompt = true;

m_processedSystemPrompt = !m_stopGenerating;
}

void ChatLLM::processRestoreStateFromText()
{
Q_ASSERT(isModelLoaded());
if (!isModelLoaded() || !m_restoreStateFromText || m_isServer)
return;

m_isRecalc = true;
emit recalcChanged();

m_stopGenerating = false;
m_ctx = LLModel::PromptContext();

auto promptFunc = std::bind(&ChatLLM::handleRestoreStateFromTextPrompt, this, std::placeholders::_1);
auto responseFunc = std::bind(&ChatLLM::handleRestoreStateFromTextResponse, this, std::placeholders::_1,
std::placeholders::_2);
auto recalcFunc = std::bind(&ChatLLM::handleRestoreStateFromTextRecalculate, this, std::placeholders::_1);

const QString promptTemplate = MySettings::globalInstance()->modelPromptTemplate(m_modelInfo);
const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo);
const int32_t top_k = MySettings::globalInstance()->modelTopK(m_modelInfo);
const float top_p = MySettings::globalInstance()->modelTopP(m_modelInfo);
const float temp = MySettings::globalInstance()->modelTemperature(m_modelInfo);
const int32_t n_batch = MySettings::globalInstance()->modelPromptBatchSize(m_modelInfo);
const float repeat_penalty = MySettings::globalInstance()->modelRepeatPenalty(m_modelInfo);
const int32_t repeat_penalty_tokens = MySettings::globalInstance()->modelRepeatPenaltyTokens(m_modelInfo);
int n_threads = MySettings::globalInstance()->threadCount();
m_ctx.n_predict = n_predict;
m_ctx.top_k = top_k;
m_ctx.top_p = top_p;
m_ctx.temp = temp;
m_ctx.n_batch = n_batch;
m_ctx.repeat_penalty = repeat_penalty;
m_ctx.repeat_last_n = repeat_penalty_tokens;
m_llModelInfo.model->setThreadCount(n_threads);
for (auto pair : m_stateFromText) {
const QString str = pair.first == "Prompt: " ? promptTemplate.arg(pair.second) : pair.second;
m_llModelInfo.model->prompt(str.toStdString(), promptFunc, responseFunc, recalcFunc, m_ctx);
}

if (!m_stopGenerating) {
m_restoreStateFromText = false;
m_stateFromText.clear();
}

m_isRecalc = false;
emit recalcChanged();
}
11 changes: 9 additions & 2 deletions gpt4all-chat/chatllm.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,9 @@ class ChatLLM : public QObject

QString generatedName() const { return QString::fromStdString(m_nameResponse); }

bool serialize(QDataStream &stream, int version);
bool deserialize(QDataStream &stream, int version);
bool serialize(QDataStream &stream, int version, bool serializeKV);
bool deserialize(QDataStream &stream, int version, bool deserializeKV, bool discardKV);
void setStateFromText(const QVector<QPair<QString, QString>> &stateFromText) { m_stateFromText = stateFromText; }

public Q_SLOTS:
bool prompt(const QList<QString> &collectionList, const QString &prompt);
Expand All @@ -110,6 +111,7 @@ public Q_SLOTS:
void handleForceMetalChanged(bool forceMetal);
void handleDeviceChanged();
void processSystemPrompt();
void processRestoreStateFromText();

Q_SIGNALS:
void recalcChanged();
Expand Down Expand Up @@ -144,6 +146,9 @@ public Q_SLOTS:
bool handleSystemPrompt(int32_t token);
bool handleSystemResponse(int32_t token, const std::string &response);
bool handleSystemRecalculate(bool isRecalc);
bool handleRestoreStateFromTextPrompt(int32_t token);
bool handleRestoreStateFromTextResponse(int32_t token, const std::string &response);
bool handleRestoreStateFromTextRecalculate(bool isRecalc);
void saveState();
void restoreState();

Expand All @@ -168,6 +173,8 @@ public Q_SLOTS:
bool m_forceMetal;
bool m_reloadingToChangeVariant;
bool m_processedSystemPrompt;
bool m_restoreStateFromText;
QVector<QPair<QString, QString>> m_stateFromText;
};

#endif // CHATLLM_H
8 changes: 8 additions & 0 deletions gpt4all-chat/chatmodel.h
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,14 @@ class ChatModel : public QAbstractListModel
return stream.status() == QDataStream::Ok;
}

QVector<QPair<QString, QString>> text() const
{
QVector<QPair<QString, QString>> result;
for (const auto &c : m_chatItems)
result << qMakePair(c.name, c.value);
return result;
}

Q_SIGNALS:
void countChanged();

Expand Down

0 comments on commit 4244c64

Please sign in to comment.