Compare commits

..

6 Commits

Author      SHA1        Date                        Message
Adam Treat  655372dbfa  2023-09-14 17:11:04 -04:00  Release notes for v2.4.17 and bump the version.
Adam Treat  aa33419c6e  2023-09-14 16:53:11 -04:00  Fallback to CPU more robustly.
Adam Treat  79843c269e  2023-09-14 11:24:25 -04:00  Release notes for v2.4.16 and bump the version.
Adam Treat  9013a089bd  2023-09-14 10:02:11 -04:00  Bump to new llama with new bugfix.
Adam Treat  3076e0bf26  2023-09-14 09:59:19 -04:00  Only show GPU when we're actually using it.
Adam Treat  1fa67a585c  2023-09-14 08:25:37 -04:00  Report the actual device we're using.
13 changed files with 91 additions and 9 deletions

@@ -1 +1 @@
-Subproject commit 8616ce08e5d48d2e17c06ae6af932b50d1d8a6e9
+Subproject commit 703ef9c1252aff4f6c4e1fdc60fffe6ab9def377


@@ -168,6 +168,10 @@ bool LLamaModel::loadModel(const std::string &modelPath)
     d_ptr->ctx = llama_init_from_file(modelPath.c_str(), d_ptr->params);
     if (!d_ptr->ctx) {
+#ifdef GGML_USE_KOMPUTE
+        // Explicitly free the device so next load it doesn't use it
+        ggml_vk_free_device();
+#endif
         std::cerr << "LLAMA ERROR: failed to load model from " << modelPath << std::endl;
         return false;
     }
@@ -194,7 +198,7 @@ int32_t LLamaModel::threadCount() const {
 LLamaModel::~LLamaModel()
 {
-    if(d_ptr->ctx) {
+    if (d_ptr->ctx) {
         llama_free(d_ptr->ctx);
     }
 }
@@ -337,6 +341,16 @@ bool LLamaModel::hasGPUDevice()
 #endif
 }
+bool LLamaModel::usingGPUDevice()
+{
+#if defined(GGML_USE_KOMPUTE)
+    return ggml_vk_using_vulkan();
+#elif defined(GGML_USE_METAL)
+    return true;
+#endif
+    return false;
+}
 #if defined(_WIN32)
 #define DLL_EXPORT __declspec(dllexport)
 #else
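Taken together, the two llamamodel.cpp additions above form the backend half of the CPU fallback: a failed Vulkan load releases the device so the next attempt does not grab it again, and usingGPUDevice() lets the caller check which backend actually ended up in use. Below is a minimal sketch of that load-then-release shape; createContext() is a hypothetical stand-in for llama_init_from_file(), and only ggml_vk_free_device() is taken from the diff (its no-argument signature is assumed).

// Sketch only: createContext() is a hypothetical helper, not part of the diff.
void *createContext(const char *path);
#ifdef GGML_USE_KOMPUTE
// Kompute call used in the hunk above; no-argument signature assumed here.
void ggml_vk_free_device();
#endif

bool loadBackend(const char *path)
{
    void *ctx = createContext(path);
    if (!ctx) {
#ifdef GGML_USE_KOMPUTE
        // Release the Vulkan device so a follow-up load attempt falls back to CPU.
        ggml_vk_free_device();
#endif
        return false;   // the caller may retry the load on CPU
    }
    return true;
}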


@@ -30,6 +30,7 @@ public:
     bool initializeGPUDevice(const GPUDevice &device) override;
     bool initializeGPUDevice(int device) override;
     bool hasGPUDevice() override;
+    bool usingGPUDevice() override;
 private:
     LLamaPrivate *d_ptr;


@@ -100,6 +100,7 @@ public:
     virtual bool initializeGPUDevice(const GPUDevice &/*device*/) { return false; }
     virtual bool initializeGPUDevice(int /*device*/) { return false; }
     virtual bool hasGPUDevice() { return false; }
+    virtual bool usingGPUDevice() { return false; }
 protected:
     // These are pure virtual because subclasses need to implement as the default implementation of


@@ -163,7 +163,7 @@ struct mpt_hparams {
     int32_t n_embd = 0; //max_seq_len
     int32_t n_head = 0; // n_heads
     int32_t n_layer = 0; //n_layers
     int32_t ftype = 0;
 };
 struct replit_layer {
@@ -220,7 +220,7 @@ static bool kv_cache_init(
     params.mem_size = cache.buf.size;
     params.mem_buffer = cache.buf.addr;
     params.no_alloc = false;
     cache.ctx = ggml_init(params);
     if (!cache.ctx) {
         fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__);
@@ -503,7 +503,7 @@ bool replit_model_load(const std::string & fname, std::istream &fin, replit_mode
     }
     GGML_CHECK_BUF(ggml_metal_add_buffer(model.ctx_metal, "data", data_ptr, data_size, max_size));
     GGML_CHECK_BUF(ggml_metal_add_buffer(model.ctx_metal, "kv", ggml_get_mem_buffer(model.kv_self.ctx),
                    ggml_get_mem_size(model.kv_self.ctx), 0));
     GGML_CHECK_BUF(ggml_metal_add_buffer(model.ctx_metal, "eval", model.eval_buf.addr, model.eval_buf.size, 0));
     GGML_CHECK_BUF(ggml_metal_add_buffer(model.ctx_metal, "scr0", model.scr0_buf.addr, model.scr0_buf.size, 0));
@@ -975,6 +975,14 @@ const std::vector<LLModel::Token> &Replit::endTokens() const
     return fres;
 }
+bool Replit::usingGPUDevice()
+{
+#if defined(GGML_USE_METAL)
+    return true;
+#endif
+    return false;
+}
 #if defined(_WIN32)
 #define DLL_EXPORT __declspec(dllexport)
 #else


@@ -27,6 +27,7 @@ public:
     size_t restoreState(const uint8_t *src) override;
     void setThreadCount(int32_t n_threads) override;
     int32_t threadCount() const override;
+    bool usingGPUDevice() override;
 private:
     ReplitPrivate *d_ptr;


@@ -18,7 +18,7 @@ endif()
 set(APP_VERSION_MAJOR 2)
 set(APP_VERSION_MINOR 4)
-set(APP_VERSION_PATCH 16)
+set(APP_VERSION_PATCH 17)
 set(APP_VERSION "${APP_VERSION_MAJOR}.${APP_VERSION_MINOR}.${APP_VERSION_PATCH}")
 # Include the binary directory for the generated header file


@@ -56,6 +56,7 @@ void Chat::connectLLM()
     connect(m_llmodel, &ChatLLM::recalcChanged, this, &Chat::handleRecalculating, Qt::QueuedConnection);
     connect(m_llmodel, &ChatLLM::generatedNameChanged, this, &Chat::generatedNameChanged, Qt::QueuedConnection);
     connect(m_llmodel, &ChatLLM::reportSpeed, this, &Chat::handleTokenSpeedChanged, Qt::QueuedConnection);
+    connect(m_llmodel, &ChatLLM::reportDevice, this, &Chat::handleDeviceChanged, Qt::QueuedConnection);
     connect(m_llmodel, &ChatLLM::databaseResultsChanged, this, &Chat::handleDatabaseResultsChanged, Qt::QueuedConnection);
     connect(m_llmodel, &ChatLLM::modelInfoChanged, this, &Chat::handleModelInfoChanged, Qt::QueuedConnection);
@@ -345,6 +346,12 @@ void Chat::handleTokenSpeedChanged(const QString &tokenSpeed)
     emit tokenSpeedChanged();
 }
+void Chat::handleDeviceChanged(const QString &device)
+{
+    m_device = device;
+    emit deviceChanged();
+}
 void Chat::handleDatabaseResultsChanged(const QList<ResultInfo> &results)
 {
     m_databaseResults = results;


@@ -25,6 +25,7 @@ class Chat : public QObject
     Q_PROPERTY(QList<QString> collectionList READ collectionList NOTIFY collectionListChanged)
     Q_PROPERTY(QString modelLoadingError READ modelLoadingError NOTIFY modelLoadingErrorChanged)
     Q_PROPERTY(QString tokenSpeed READ tokenSpeed NOTIFY tokenSpeedChanged);
+    Q_PROPERTY(QString device READ device NOTIFY deviceChanged);
     QML_ELEMENT
     QML_UNCREATABLE("Only creatable from c++!")
@@ -88,6 +89,7 @@ public:
     QString modelLoadingError() const { return m_modelLoadingError; }
     QString tokenSpeed() const { return m_tokenSpeed; }
+    QString device() const { return m_device; }
 public Q_SLOTS:
     void serverNewPromptResponsePair(const QString &prompt);
@@ -115,6 +117,7 @@ Q_SIGNALS:
     void isServerChanged();
     void collectionListChanged(const QList<QString> &collectionList);
     void tokenSpeedChanged();
+    void deviceChanged();
 private Q_SLOTS:
     void handleResponseChanged(const QString &response);
@@ -125,6 +128,7 @@ private Q_SLOTS:
     void handleRecalculating();
     void handleModelLoadingError(const QString &error);
     void handleTokenSpeedChanged(const QString &tokenSpeed);
+    void handleDeviceChanged(const QString &device);
     void handleDatabaseResultsChanged(const QList<ResultInfo> &results);
     void handleModelInfoChanged(const ModelInfo &modelInfo);
     void handleModelInstalled();
@@ -137,6 +141,7 @@ private:
     ModelInfo m_modelInfo;
     QString m_modelLoadingError;
     QString m_tokenSpeed;
+    QString m_device;
     QString m_response;
     QList<QString> m_collections;
     ChatModel *m_chatModel;
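The chat.cpp and chat.h hunks route the new device report through the standard Qt property chain: the ChatLLM worker emits a signal, a queued connection delivers it to the Chat object on the GUI thread, the slot stores the value, and the NOTIFY signal lets QML bindings such as currentChat.device refresh. The following is a minimal, self-contained sketch of that chain using hypothetical Worker and ChatFacade classes rather than the real ones.

#include <QObject>
#include <QString>

// Hypothetical worker standing in for ChatLLM.
class Worker : public QObject {
    Q_OBJECT
Q_SIGNALS:
    void reportDevice(const QString &device);   // emitted from the worker thread
};

// Hypothetical facade standing in for Chat.
class ChatFacade : public QObject {
    Q_OBJECT
    Q_PROPERTY(QString device READ device NOTIFY deviceChanged)
public:
    explicit ChatFacade(Worker *worker, QObject *parent = nullptr) : QObject(parent) {
        // Queued connection: the slot runs on this object's (GUI) thread.
        connect(worker, &Worker::reportDevice,
                this, &ChatFacade::handleDeviceChanged, Qt::QueuedConnection);
    }
    QString device() const { return m_device; }

Q_SIGNALS:
    void deviceChanged();

private Q_SLOTS:
    void handleDeviceChanged(const QString &device) {
        m_device = device;
        emit deviceChanged();   // any QML binding on `device` re-evaluates
    }

private:
    QString m_device;
};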


@@ -271,31 +271,48 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
         MySettings::globalInstance()->setDeviceList(deviceList);
         // Pick the best match for the device
+        QString actualDevice = m_llModelInfo.model->implementation().buildVariant() == "metal" ? "Metal" : "CPU";
         const QString requestedDevice = MySettings::globalInstance()->device();
         if (requestedDevice != "CPU") {
             const size_t requiredMemory = m_llModelInfo.model->requiredMem(filePath.toStdString());
             std::vector<LLModel::GPUDevice> availableDevices = m_llModelInfo.model->availableGPUDevices(requiredMemory);
             if (!availableDevices.empty() && requestedDevice == "Auto" && availableDevices.front().type == 2 /*a discrete gpu*/) {
                 m_llModelInfo.model->initializeGPUDevice(availableDevices.front());
+                actualDevice = QString::fromStdString(availableDevices.front().name);
             } else {
                 for (LLModel::GPUDevice &d : availableDevices) {
                     if (QString::fromStdString(d.name) == requestedDevice) {
                         m_llModelInfo.model->initializeGPUDevice(d);
+                        actualDevice = QString::fromStdString(d.name);
                         break;
                     }
                 }
             }
         }
+        // Report which device we're actually using
+        emit reportDevice(actualDevice);
         bool success = m_llModelInfo.model->loadModel(filePath.toStdString());
+        if (!success && actualDevice != "CPU") {
+            emit reportDevice("CPU");
+            success = m_llModelInfo.model->loadModel(filePath.toStdString());
+        }
         MySettings::globalInstance()->setAttemptModelLoad(QString());
         if (!success) {
-            delete std::exchange(m_llModelInfo.model, nullptr);
+            delete m_llModelInfo.model;
+            m_llModelInfo.model = nullptr;
             if (!m_isServer)
                 LLModelStore::globalInstance()->releaseModel(m_llModelInfo); // release back into the store
             m_llModelInfo = LLModelInfo();
             emit modelLoadingError(QString("Could not load model due to invalid model file for %1").arg(modelInfo.filename()));
         } else {
+            // We might have had to fallback to CPU after load if the model is not possible to accelerate
+            // for instance if the quantization method is not supported on Vulkan yet
+            if (actualDevice != "CPU" && !m_llModelInfo.model->usingGPUDevice())
+                emit reportDevice("CPU");
             switch (m_llModelInfo.model->implementation().modelType()[0]) {
             case 'L': m_llModelType = LLModelType::LLAMA_; break;
             case 'G': m_llModelType = LLModelType::GPTJ_; break;
@@ -306,7 +323,8 @@ bool ChatLLM::loadModel(const ModelInfo &modelInfo)
             case 'S': m_llModelType = LLModelType::STARCODER_; break;
             default:
                 {
-                    delete std::exchange(m_llModelInfo.model, nullptr);
+                    delete m_llModelInfo.model;
+                    m_llModelInfo.model = nullptr;
                     if (!m_isServer)
                         LLModelStore::globalInstance()->releaseModel(m_llModelInfo); // release back into the store
                     m_llModelInfo = LLModelInfo();
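The chatllm.cpp hunk above carries the core of "Fallback to CPU more robustly": report the device about to be tried, retry the load on CPU if the GPU attempt fails (for instance when the device runs out of memory), and correct the report if the backend quietly fell back even though the load nominally succeeded (e.g. an unsupported quantization on Vulkan). A condensed sketch of that flow, with hypothetical loadOn(), report(), and usingGPU() helpers standing in for LLModel::loadModel(), emit reportDevice(), and LLModel::usingGPUDevice():

#include <string>

// Hypothetical helpers, not the real gpt4all API.
bool loadOn(const std::string &path, const std::string &device);
void report(const std::string &device);
bool usingGPU();

bool loadWithFallback(const std::string &path, std::string device)
{
    report(device);                           // report what we are about to try

    bool success = loadOn(path, device);      // first attempt (possibly on a GPU)
    if (!success && device != "CPU") {
        device = "CPU";
        report(device);                       // GPU attempt failed; fall back
        success = loadOn(path, device);
    }

    // The backend may have fallen back on its own after a "successful" load,
    // so re-check and correct the reported device.
    if (success && device != "CPU" && !usingGPU())
        report("CPU");

    return success;
}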


@@ -129,6 +129,7 @@ Q_SIGNALS:
     void shouldBeLoadedChanged();
     void requestRetrieveFromDB(const QList<QString> &collections, const QString &text, int retrievalSize, QList<ResultInfo> *results);
     void reportSpeed(const QString &speed);
+    void reportDevice(const QString &device);
     void databaseResultsChanged(const QList<ResultInfo>&);
     void modelInfoChanged(const ModelInfo &modelInfo);


@@ -1013,7 +1013,7 @@ Window {
     anchors.rightMargin: 30
     color: theme.mutedTextColor
     visible: currentChat.tokenSpeed !== ""
-    text: qsTr("Speed: ") + currentChat.tokenSpeed + "<br>" + qsTr("Device: ") + MySettings.device
+    text: qsTr("Speed: ") + currentChat.tokenSpeed + "<br>" + qsTr("Device: ") + currentChat.device
     font.pixelSize: theme.fontSizeLarge
 }


@@ -480,6 +480,32 @@
 * Aaron Miller (Nomic AI)
 * Nils Sauer (Nomic AI)
 * Lakshay Kansal (Nomic AI)
+"
+    },
+    {
+        "version": "2.4.16",
+        "notes":
+"
+* Bugfix for properly falling back to CPU when GPU can't be used
+* Report the actual device we're using
+* Fix context bugs for GPU accelerated models
+",
+        "contributors":
+"
+* Adam Treat (Nomic AI)
+* Aaron Miller (Nomic AI)
+"
+    },
+    {
+        "version": "2.4.17",
+        "notes":
+"
+* Bugfix for properly falling back to CPU when GPU is out of memory
+",
+        "contributors":
+"
+* Adam Treat (Nomic AI)
+* Aaron Miller (Nomic AI)
 "
     }
 ]