Skip to content

Commit

Permalink
Brave search tool calling.
Browse files Browse the repository at this point in the history
Signed-off-by: Adam Treat <treat.adam@gmail.com>
  • Loading branch information
manyoso committed Jul 25, 2024
1 parent f9cd2e3 commit 1bafbaa
Show file tree
Hide file tree
Showing 19 changed files with 650 additions and 138 deletions.
3 changes: 3 additions & 0 deletions gpt4all-chat/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ endif()

qt_add_executable(chat
main.cpp
bravesearch.h bravesearch.cpp
chat.h chat.cpp
chatllm.h chatllm.cpp
chatmodel.h chatlistmodel.h chatlistmodel.cpp
Expand All @@ -122,6 +123,7 @@ qt_add_executable(chat
modellist.h modellist.cpp
mysettings.h mysettings.cpp
network.h network.cpp
sourceexcerpt.h
server.h server.cpp
logger.h logger.cpp
${APP_ICON_RESOURCE}
Expand Down Expand Up @@ -155,6 +157,7 @@ qt_add_qml_module(chat
qml/ThumbsDownDialog.qml
qml/Toast.qml
qml/ToastManager.qml
qml/ToolSettings.qml
qml/MyBusyIndicator.qml
qml/MyButton.qml
qml/MyCheckBox.qml
Expand Down
221 changes: 221 additions & 0 deletions gpt4all-chat/bravesearch.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
#include "bravesearch.h"

#include <QCoreApplication>
#include <QDebug>
#include <QGuiApplication>
#include <QJsonArray>
#include <QJsonDocument>
#include <QJsonObject>
#include <QJsonValue>
#include <QNetworkAccessManager>
#include <QNetworkReply>
#include <QNetworkRequest>
#include <QThread>
#include <QUrl>
#include <QUrlQuery>

using namespace Qt::Literals::StringLiterals;

QPair<QString, QList<SourceExcerpt>> BraveSearch::search(const QString &apiKey, const QString &query, int topK, unsigned long timeout)
{
QThread workerThread;
BraveAPIWorker worker;
worker.moveToThread(&workerThread);
connect(&worker, &BraveAPIWorker::finished, &workerThread, &QThread::quit, Qt::DirectConnection);
connect(this, &BraveSearch::request, &worker, &BraveAPIWorker::request, Qt::QueuedConnection);
workerThread.start();
emit request(apiKey, query, topK);
workerThread.wait(timeout);
workerThread.quit();
workerThread.wait();
return worker.response();
}

void BraveAPIWorker::request(const QString &apiKey, const QString &query, int topK)
{
m_topK = topK;
QUrl jsonUrl("https://api.search.brave.com/res/v1/web/search");
QUrlQuery urlQuery;
urlQuery.addQueryItem("q", query);
jsonUrl.setQuery(urlQuery);
QNetworkRequest request(jsonUrl);
QSslConfiguration conf = request.sslConfiguration();
conf.setPeerVerifyMode(QSslSocket::VerifyNone);
request.setSslConfiguration(conf);

request.setRawHeader("X-Subscription-Token", apiKey.toUtf8());
// request.setRawHeader("Accept-Encoding", "gzip");
request.setRawHeader("Accept", "application/json");

m_networkManager = new QNetworkAccessManager(this);
QNetworkReply *reply = m_networkManager->get(request);
connect(qGuiApp, &QCoreApplication::aboutToQuit, reply, &QNetworkReply::abort);
connect(reply, &QNetworkReply::finished, this, &BraveAPIWorker::handleFinished);
connect(reply, &QNetworkReply::errorOccurred, this, &BraveAPIWorker::handleErrorOccurred);
}

static QPair<QString, QList<SourceExcerpt>> cleanBraveResponse(const QByteArray& jsonResponse, qsizetype topK = 1)
{
QJsonParseError err;
QJsonDocument document = QJsonDocument::fromJson(jsonResponse, &err);
if (err.error != QJsonParseError::NoError) {
qWarning() << "ERROR: Couldn't parse: " << jsonResponse << err.errorString();
return QPair<QString, QList<SourceExcerpt>>();
}

QJsonObject searchResponse = document.object();
QJsonObject cleanResponse;
QString query;
QJsonArray cleanArray;

QList<SourceExcerpt> infos;

if (searchResponse.contains("query")) {
QJsonObject queryObj = searchResponse["query"].toObject();
if (queryObj.contains("original")) {
query = queryObj["original"].toString();
}
}

if (searchResponse.contains("mixed")) {
QJsonObject mixedResults = searchResponse["mixed"].toObject();
QJsonArray mainResults = mixedResults["main"].toArray();

for (int i = 0; i < std::min(mainResults.size(), topK); ++i) {
QJsonObject m = mainResults[i].toObject();
QString r_type = m["type"].toString();
int idx = m["index"].toInt();
QJsonObject resultsObject = searchResponse[r_type].toObject();
QJsonArray resultsArray = resultsObject["results"].toArray();

QJsonValue cleaned;
SourceExcerpt info;
if (r_type == "web") {
// For web data - add a single output from the search
QJsonObject resultObj = resultsArray[idx].toObject();
QStringList selectedKeys = {"type", "title", "url", "description", "date", "extra_snippets"};
QJsonObject cleanedObj;
for (const auto& key : selectedKeys) {
if (resultObj.contains(key)) {
cleanedObj.insert(key, resultObj[key]);
}
}

info.date = resultObj["date"].toString();
info.text = resultObj["description"].toString(); // fixme
info.url = resultObj["url"].toString();
QJsonObject meta_url = resultObj["meta_url"].toObject();
info.favicon = meta_url["favicon"].toString();
info.title = resultObj["title"].toString();

cleaned = cleanedObj;
} else if (r_type == "faq") {
// For faq data - take a list of all the questions & answers
QStringList selectedKeys = {"type", "question", "answer", "title", "url"};
QJsonArray cleanedArray;
for (const auto& q : resultsArray) {
QJsonObject qObj = q.toObject();
QJsonObject cleanedObj;
for (const auto& key : selectedKeys) {
if (qObj.contains(key)) {
cleanedObj.insert(key, qObj[key]);
}
}
cleanedArray.append(cleanedObj);
}
cleaned = cleanedArray;
} else if (r_type == "infobox") {
QJsonObject resultObj = resultsArray[idx].toObject();
QStringList selectedKeys = {"type", "title", "url", "description", "long_desc"};
QJsonObject cleanedObj;
for (const auto& key : selectedKeys) {
if (resultObj.contains(key)) {
cleanedObj.insert(key, resultObj[key]);
}
}
cleaned = cleanedObj;
} else if (r_type == "videos") {
QStringList selectedKeys = {"type", "url", "title", "description", "date"};
QJsonArray cleanedArray;
for (const auto& q : resultsArray) {
QJsonObject qObj = q.toObject();
QJsonObject cleanedObj;
for (const auto& key : selectedKeys) {
if (qObj.contains(key)) {
cleanedObj.insert(key, qObj[key]);
}
}
cleanedArray.append(cleanedObj);
}
cleaned = cleanedArray;
} else if (r_type == "locations") {
QStringList selectedKeys = {"type", "title", "url", "description", "coordinates", "postal_address", "contact", "rating", "distance", "zoom_level"};
QJsonArray cleanedArray;
for (const auto& q : resultsArray) {
QJsonObject qObj = q.toObject();
QJsonObject cleanedObj;
for (const auto& key : selectedKeys) {
if (qObj.contains(key)) {
cleanedObj.insert(key, qObj[key]);
}
}
cleanedArray.append(cleanedObj);
}
cleaned = cleanedArray;
} else if (r_type == "news") {
QStringList selectedKeys = {"type", "title", "url", "description"};
QJsonArray cleanedArray;
for (const auto& q : resultsArray) {
QJsonObject qObj = q.toObject();
QJsonObject cleanedObj;
for (const auto& key : selectedKeys) {
if (qObj.contains(key)) {
cleanedObj.insert(key, qObj[key]);
}
}
cleanedArray.append(cleanedObj);
}
cleaned = cleanedArray;
} else {
cleaned = QJsonValue();
}

infos.append(info);
cleanArray.append(cleaned);
}
}

cleanResponse.insert("query", query);
cleanResponse.insert("top_k", cleanArray);
QJsonDocument cleanedDoc(cleanResponse);

// qDebug().noquote() << document.toJson(QJsonDocument::Indented);
// qDebug().noquote() << cleanedDoc.toJson(QJsonDocument::Indented);

return qMakePair(cleanedDoc.toJson(QJsonDocument::Indented), infos);
}

void BraveAPIWorker::handleFinished()
{
QNetworkReply *jsonReply = qobject_cast<QNetworkReply *>(sender());
Q_ASSERT(jsonReply);

if (jsonReply->error() == QNetworkReply::NoError && jsonReply->isFinished()) {
QByteArray jsonData = jsonReply->readAll();
jsonReply->deleteLater();
m_response = cleanBraveResponse(jsonData, m_topK);
} else {
QByteArray jsonData = jsonReply->readAll();
qWarning() << "ERROR: Could not search brave" << jsonReply->error() << jsonReply->errorString() << jsonData;
jsonReply->deleteLater();
}
}

void BraveAPIWorker::handleErrorOccurred(QNetworkReply::NetworkError code)
{
QNetworkReply *reply = qobject_cast<QNetworkReply *>(sender());
Q_ASSERT(reply);
qWarning().noquote() << "ERROR: BraveAPIWorker::handleErrorOccurred got HTTP Error" << code << "response:"
<< reply->errorString();
emit finished();
}
51 changes: 51 additions & 0 deletions gpt4all-chat/bravesearch.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#ifndef BRAVESEARCH_H
#define BRAVESEARCH_H

#include "sourceexcerpt.h"

#include <QObject>
#include <QString>
#include <QNetworkAccessManager>
#include <QNetworkReply>

class BraveAPIWorker : public QObject {
Q_OBJECT
public:
BraveAPIWorker()
: QObject(nullptr)
, m_networkManager(nullptr)
, m_topK(1) {}
virtual ~BraveAPIWorker() {}

QPair<QString, QList<SourceExcerpt>> response() const { return m_response; }

public Q_SLOTS:
void request(const QString &apiKey, const QString &query, int topK);

Q_SIGNALS:
void finished();

private Q_SLOTS:
void handleFinished();
void handleErrorOccurred(QNetworkReply::NetworkError code);

private:
QNetworkAccessManager *m_networkManager;
QPair<QString, QList<SourceExcerpt>> m_response;
int m_topK;
};

class BraveSearch : public QObject {
Q_OBJECT
public:
BraveSearch()
: QObject(nullptr) {}
virtual ~BraveSearch() {}

QPair<QString, QList<SourceExcerpt>> search(const QString &apiKey, const QString &query, int topK, unsigned long timeout = 2000);

Q_SIGNALS:
void request(const QString &apiKey, const QString &query, int topK);
};

#endif // BRAVESEARCH_H
29 changes: 22 additions & 7 deletions gpt4all-chat/chat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ void Chat::connectLLM()
connect(m_llmodel, &ChatLLM::responseChanged, this, &Chat::handleResponseChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::promptProcessing, this, &Chat::promptProcessing, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::generatingQuestions, this, &Chat::generatingQuestions, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::toolCalled, this, &Chat::toolCalled, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::responseStopped, this, &Chat::responseStopped, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::modelLoadingError, this, &Chat::handleModelLoadingError, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::modelLoadingWarning, this, &Chat::modelLoadingWarning, Qt::QueuedConnection);
Expand All @@ -67,7 +68,7 @@ void Chat::connectLLM()
connect(m_llmodel, &ChatLLM::generatedQuestionFinished, this, &Chat::generatedQuestionFinished, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::reportSpeed, this, &Chat::handleTokenSpeedChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::loadedModelInfoChanged, this, &Chat::loadedModelInfoChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::databaseResultsChanged, this, &Chat::handleDatabaseResultsChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::sourceExcerptsChanged, this, &Chat::handleSourceExcerptsChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::modelInfoChanged, this, &Chat::handleModelInfoChanged, Qt::QueuedConnection);
connect(m_llmodel, &ChatLLM::trySwitchContextOfLoadedModelCompleted, this, &Chat::handleTrySwitchContextOfLoadedModelCompleted, Qt::QueuedConnection);

Expand Down Expand Up @@ -121,6 +122,7 @@ void Chat::resetResponseState()
emit tokenSpeedChanged();
m_responseInProgress = true;
m_responseState = m_collections.empty() ? Chat::PromptProcessing : Chat::LocalDocsRetrieval;
m_toolDescription = QString();
emit responseInProgressChanged();
emit responseStateChanged();
}
Expand All @@ -134,7 +136,7 @@ void Chat::prompt(const QString &prompt)
void Chat::regenerateResponse()
{
const int index = m_chatModel->count() - 1;
m_chatModel->updateSources(index, QList<ResultInfo>());
m_chatModel->updateSources(index, QList<SourceExcerpt>());
emit regenerateResponseRequested();
}

Expand Down Expand Up @@ -189,8 +191,13 @@ void Chat::handleModelLoadingPercentageChanged(float loadingPercentage)

void Chat::promptProcessing()
{
m_responseState = !databaseResults().isEmpty() ? Chat::LocalDocsProcessing : Chat::PromptProcessing;
emit responseStateChanged();
if (sourceExcerpts().isEmpty())
m_responseState = Chat::PromptProcessing;
else if (m_responseState == Chat::ToolCalled)
m_responseState = Chat::ToolProcessing;
else
m_responseState = Chat::LocalDocsProcessing;
emit responseStateChanged();
}

void Chat::generatingQuestions()
Expand All @@ -199,6 +206,14 @@ void Chat::generatingQuestions()
emit responseStateChanged();
}

void Chat::toolCalled(const QString &description)
{
m_responseState = Chat::ToolCalled;
m_toolDescription = description;
emit toolDescriptionChanged();
emit responseStateChanged();
}

void Chat::responseStopped(qint64 promptResponseMs)
{
m_tokenSpeed = QString();
Expand Down Expand Up @@ -357,11 +372,11 @@ QString Chat::fallbackReason() const
return m_llmodel->fallbackReason();
}

void Chat::handleDatabaseResultsChanged(const QList<ResultInfo> &results)
void Chat::handleSourceExcerptsChanged(const QList<SourceExcerpt> &sourceExcerpts)
{
m_databaseResults = results;
m_sourceExcerpts = sourceExcerpts;
const int index = m_chatModel->count() - 1;
m_chatModel->updateSources(index, m_databaseResults);
m_chatModel->updateSources(index, m_sourceExcerpts);
}

void Chat::handleModelInfoChanged(const ModelInfo &modelInfo)
Expand Down
Loading

0 comments on commit 1bafbaa

Please sign in to comment.