mirror of
				https://github.com/nomic-ai/gpt4all.git
				synced 2025-11-04 00:03:33 -05:00 
			
		
		
		
	Much better memory mgmt for multi-threaded model loading/unloading.
This commit is contained in:
		
							parent
							
								
									e67b021948
								
							
						
					
					
						commit
						4448600e60
					
				@ -12,6 +12,7 @@ Chat::Chat(QObject *parent)
 | 
			
		||||
    , m_creationDate(QDateTime::currentSecsSinceEpoch())
 | 
			
		||||
    , m_llmodel(new ChatLLM(this))
 | 
			
		||||
    , m_isServer(false)
 | 
			
		||||
    , m_shouldDeleteLater(false)
 | 
			
		||||
{
 | 
			
		||||
    connectLLM();
 | 
			
		||||
}
 | 
			
		||||
@ -25,6 +26,7 @@ Chat::Chat(bool isServer, QObject *parent)
 | 
			
		||||
    , m_creationDate(QDateTime::currentSecsSinceEpoch())
 | 
			
		||||
    , m_llmodel(new Server(this))
 | 
			
		||||
    , m_isServer(true)
 | 
			
		||||
    , m_shouldDeleteLater(false)
 | 
			
		||||
{
 | 
			
		||||
    connectLLM();
 | 
			
		||||
}
 | 
			
		||||
@ -43,6 +45,7 @@ void Chat::connectLLM()
 | 
			
		||||
 | 
			
		||||
    // Should be in different threads
 | 
			
		||||
    connect(m_llmodel, &ChatLLM::isModelLoadedChanged, this, &Chat::isModelLoadedChanged, Qt::QueuedConnection);
 | 
			
		||||
    connect(m_llmodel, &ChatLLM::isModelLoadedChanged, this, &Chat::handleModelLoadedChanged, Qt::QueuedConnection);
 | 
			
		||||
    connect(m_llmodel, &ChatLLM::responseChanged, this, &Chat::handleResponseChanged, Qt::QueuedConnection);
 | 
			
		||||
    connect(m_llmodel, &ChatLLM::responseStarted, this, &Chat::responseStarted, Qt::QueuedConnection);
 | 
			
		||||
    connect(m_llmodel, &ChatLLM::responseStopped, this, &Chat::responseStopped, Qt::QueuedConnection);
 | 
			
		||||
@ -55,8 +58,6 @@ void Chat::connectLLM()
 | 
			
		||||
    connect(this, &Chat::modelNameChangeRequested, m_llmodel, &ChatLLM::modelNameChangeRequested, Qt::QueuedConnection);
 | 
			
		||||
    connect(this, &Chat::loadDefaultModelRequested, m_llmodel, &ChatLLM::loadDefaultModel, Qt::QueuedConnection);
 | 
			
		||||
    connect(this, &Chat::loadModelRequested, m_llmodel, &ChatLLM::loadModel, Qt::QueuedConnection);
 | 
			
		||||
    connect(this, &Chat::unloadModelRequested, m_llmodel, &ChatLLM::unloadModel, Qt::QueuedConnection);
 | 
			
		||||
    connect(this, &Chat::reloadModelRequested, m_llmodel, &ChatLLM::reloadModel, Qt::QueuedConnection);
 | 
			
		||||
    connect(this, &Chat::generateNameRequested, m_llmodel, &ChatLLM::generateName, Qt::QueuedConnection);
 | 
			
		||||
 | 
			
		||||
    // The following are blocking operations and will block the gui thread, therefore must be fast
 | 
			
		||||
@ -122,6 +123,12 @@ void Chat::handleResponseChanged()
 | 
			
		||||
    emit responseChanged();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void Chat::handleModelLoadedChanged()
 | 
			
		||||
{
 | 
			
		||||
    if (m_shouldDeleteLater)
 | 
			
		||||
        deleteLater();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void Chat::responseStarted()
 | 
			
		||||
{
 | 
			
		||||
    m_responseInProgress = true;
 | 
			
		||||
@ -180,15 +187,26 @@ void Chat::loadModel(const QString &modelName)
 | 
			
		||||
    emit loadModelRequested(modelName);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void Chat::unloadAndDeleteLater()
 | 
			
		||||
{
 | 
			
		||||
    if (!isModelLoaded()) {
 | 
			
		||||
        deleteLater();
 | 
			
		||||
        return;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    m_shouldDeleteLater = true;
 | 
			
		||||
    unloadModel();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void Chat::unloadModel()
 | 
			
		||||
{
 | 
			
		||||
    stopGenerating();
 | 
			
		||||
    emit unloadModelRequested();
 | 
			
		||||
    m_llmodel->setShouldBeLoaded(false);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void Chat::reloadModel()
 | 
			
		||||
{
 | 
			
		||||
    emit reloadModelRequested(m_savedModelName);
 | 
			
		||||
    m_llmodel->setShouldBeLoaded(true);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void Chat::generatedNameChanged()
 | 
			
		||||
@ -236,12 +254,10 @@ bool Chat::deserialize(QDataStream &stream, int version)
 | 
			
		||||
    stream >> m_userName;
 | 
			
		||||
    emit nameChanged();
 | 
			
		||||
    stream >> m_savedModelName;
 | 
			
		||||
 | 
			
		||||
    // Prior to version 2 gptj models had a bug that fixed the kv_cache to F32 instead of F16 so
 | 
			
		||||
    // unfortunately, we cannot deserialize these
 | 
			
		||||
    if (version < 2 && m_savedModelName.contains("gpt4all-j"))
 | 
			
		||||
        return false;
 | 
			
		||||
 | 
			
		||||
    if (!m_llmodel->deserialize(stream, version))
 | 
			
		||||
        return false;
 | 
			
		||||
    if (!m_chatModel->deserialize(stream, version))
 | 
			
		||||
 | 
			
		||||
@ -58,6 +58,7 @@ public:
 | 
			
		||||
    void loadModel(const QString &modelName);
 | 
			
		||||
    void unloadModel();
 | 
			
		||||
    void reloadModel();
 | 
			
		||||
    void unloadAndDeleteLater();
 | 
			
		||||
 | 
			
		||||
    qint64 creationDate() const { return m_creationDate; }
 | 
			
		||||
    bool serialize(QDataStream &stream, int version) const;
 | 
			
		||||
@ -87,8 +88,6 @@ Q_SIGNALS:
 | 
			
		||||
    void recalcChanged();
 | 
			
		||||
    void loadDefaultModelRequested();
 | 
			
		||||
    void loadModelRequested(const QString &modelName);
 | 
			
		||||
    void unloadModelRequested();
 | 
			
		||||
    void reloadModelRequested(const QString &modelName);
 | 
			
		||||
    void generateNameRequested();
 | 
			
		||||
    void modelListChanged();
 | 
			
		||||
    void modelLoadingError(const QString &error);
 | 
			
		||||
@ -96,6 +95,7 @@ Q_SIGNALS:
 | 
			
		||||
 | 
			
		||||
private Q_SLOTS:
 | 
			
		||||
    void handleResponseChanged();
 | 
			
		||||
    void handleModelLoadedChanged();
 | 
			
		||||
    void responseStarted();
 | 
			
		||||
    void responseStopped();
 | 
			
		||||
    void generatedNameChanged();
 | 
			
		||||
@ -112,6 +112,7 @@ private:
 | 
			
		||||
    qint64 m_creationDate;
 | 
			
		||||
    ChatLLM *m_llmodel;
 | 
			
		||||
    bool m_isServer;
 | 
			
		||||
    bool m_shouldDeleteLater;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
#endif // CHAT_H
 | 
			
		||||
 | 
			
		||||
@ -40,6 +40,7 @@ void ChatListModel::setShouldSaveChats(bool b)
 | 
			
		||||
 | 
			
		||||
void ChatListModel::removeChatFile(Chat *chat) const
 | 
			
		||||
{
 | 
			
		||||
    Q_ASSERT(chat != m_serverChat);
 | 
			
		||||
    const QString savePath = Download::globalInstance()->downloadLocalModelsPath();
 | 
			
		||||
    QFile file(savePath + "/gpt4all-" + chat->id() + ".chat");
 | 
			
		||||
    if (!file.exists())
 | 
			
		||||
@ -58,6 +59,8 @@ void ChatListModel::saveChats() const
 | 
			
		||||
    timer.start();
 | 
			
		||||
    const QString savePath = Download::globalInstance()->downloadLocalModelsPath();
 | 
			
		||||
    for (Chat *chat : m_chats) {
 | 
			
		||||
        if (chat == m_serverChat)
 | 
			
		||||
            continue;
 | 
			
		||||
        QString fileName = "gpt4all-" + chat->id() + ".chat";
 | 
			
		||||
        QFile file(savePath + "/" + fileName);
 | 
			
		||||
        bool success = file.open(QIODevice::WriteOnly);
 | 
			
		||||
 | 
			
		||||
@ -125,6 +125,7 @@ public:
 | 
			
		||||
 | 
			
		||||
    Q_INVOKABLE void removeChat(Chat* chat)
 | 
			
		||||
    {
 | 
			
		||||
        Q_ASSERT(chat != m_serverChat);
 | 
			
		||||
        if (!m_chats.contains(chat)) {
 | 
			
		||||
            qWarning() << "WARNING: Removing chat failed with id" << chat->id();
 | 
			
		||||
            return;
 | 
			
		||||
@ -138,11 +139,11 @@ public:
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        const int index = m_chats.indexOf(chat);
 | 
			
		||||
        if (m_chats.count() < 2) {
 | 
			
		||||
        if (m_chats.count() < 3 /*m_serverChat included*/) {
 | 
			
		||||
            addChat();
 | 
			
		||||
        } else {
 | 
			
		||||
            int nextIndex;
 | 
			
		||||
            if (index == m_chats.count() - 1)
 | 
			
		||||
            if (index == m_chats.count() - 2 /*m_serverChat is last*/)
 | 
			
		||||
                nextIndex = index - 1;
 | 
			
		||||
            else
 | 
			
		||||
                nextIndex = index + 1;
 | 
			
		||||
@ -155,7 +156,7 @@ public:
 | 
			
		||||
        beginRemoveRows(QModelIndex(), newIndex, newIndex);
 | 
			
		||||
        m_chats.removeAll(chat);
 | 
			
		||||
        endRemoveRows();
 | 
			
		||||
        delete chat;
 | 
			
		||||
        chat->unloadAndDeleteLater();
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    Chat *currentChat() const
 | 
			
		||||
@ -170,7 +171,7 @@ public:
 | 
			
		||||
            return;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if (m_currentChat && m_currentChat->isModelLoaded())
 | 
			
		||||
        if (m_currentChat)
 | 
			
		||||
            m_currentChat->unloadModel();
 | 
			
		||||
 | 
			
		||||
        m_currentChat = chat;
 | 
			
		||||
 | 
			
		||||
@ -15,6 +15,7 @@
 | 
			
		||||
#include <fstream>
 | 
			
		||||
 | 
			
		||||
//#define DEBUG
 | 
			
		||||
//#define DEBUG_MODEL_LOADING
 | 
			
		||||
 | 
			
		||||
#define MPT_INTERNAL_STATE_VERSION 0
 | 
			
		||||
#define GPTJ_INTERNAL_STATE_VERSION 0
 | 
			
		||||
@ -37,9 +38,51 @@ static QString modelFilePath(const QString &modelName)
 | 
			
		||||
    return QString();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
class LLModelStore {
 | 
			
		||||
public:
 | 
			
		||||
    static LLModelStore *globalInstance();
 | 
			
		||||
 | 
			
		||||
    LLModelInfo acquireModel(); // will block until llmodel is ready
 | 
			
		||||
    void releaseModel(const LLModelInfo &info); // must be called when you are done
 | 
			
		||||
 | 
			
		||||
private:
 | 
			
		||||
    LLModelStore()
 | 
			
		||||
    {
 | 
			
		||||
        // seed with empty model
 | 
			
		||||
        m_availableModels.append(LLModelInfo());
 | 
			
		||||
    }
 | 
			
		||||
    ~LLModelStore() {}
 | 
			
		||||
    QVector<LLModelInfo> m_availableModels;
 | 
			
		||||
    QMutex m_mutex;
 | 
			
		||||
    QWaitCondition m_condition;
 | 
			
		||||
    friend class MyLLModelStore;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
class MyLLModelStore : public LLModelStore { };
 | 
			
		||||
Q_GLOBAL_STATIC(MyLLModelStore, storeInstance)
 | 
			
		||||
LLModelStore *LLModelStore::globalInstance()
 | 
			
		||||
{
 | 
			
		||||
    return storeInstance();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
LLModelInfo LLModelStore::acquireModel()
 | 
			
		||||
{
 | 
			
		||||
    QMutexLocker locker(&m_mutex);
 | 
			
		||||
    while (m_availableModels.isEmpty())
 | 
			
		||||
        m_condition.wait(locker.mutex());
 | 
			
		||||
    return m_availableModels.takeFirst();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void LLModelStore::releaseModel(const LLModelInfo &info)
 | 
			
		||||
{
 | 
			
		||||
    QMutexLocker locker(&m_mutex);
 | 
			
		||||
    m_availableModels.append(info);
 | 
			
		||||
    Q_ASSERT(m_availableModels.count() < 2);
 | 
			
		||||
    m_condition.wakeAll();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
ChatLLM::ChatLLM(Chat *parent)
 | 
			
		||||
    : QObject{nullptr}
 | 
			
		||||
    , m_llmodel(nullptr)
 | 
			
		||||
    , m_promptResponseTokens(0)
 | 
			
		||||
    , m_promptTokens(0)
 | 
			
		||||
    , m_responseLogits(0)
 | 
			
		||||
@ -49,6 +92,7 @@ ChatLLM::ChatLLM(Chat *parent)
 | 
			
		||||
    moveToThread(&m_llmThread);
 | 
			
		||||
    connect(this, &ChatLLM::sendStartup, Network::globalInstance(), &Network::sendStartup);
 | 
			
		||||
    connect(this, &ChatLLM::sendModelLoaded, Network::globalInstance(), &Network::sendModelLoaded);
 | 
			
		||||
    connect(this, &ChatLLM::shouldBeLoadedChanged, this, &ChatLLM::handleShouldBeLoadedChanged, Qt::QueuedConnection);
 | 
			
		||||
    connect(m_chat, &Chat::idChanged, this, &ChatLLM::handleChatIdChanged);
 | 
			
		||||
    connect(&m_llmThread, &QThread::started, this, &ChatLLM::threadStarted);
 | 
			
		||||
    m_llmThread.setObjectName(m_chat->id());
 | 
			
		||||
@ -59,7 +103,13 @@ ChatLLM::~ChatLLM()
 | 
			
		||||
{
 | 
			
		||||
    m_llmThread.quit();
 | 
			
		||||
    m_llmThread.wait();
 | 
			
		||||
    delete m_llmodel;
 | 
			
		||||
 | 
			
		||||
    // The only time we should have a model loaded here is on shutdown
 | 
			
		||||
    // as we explicitly unload the model in all other circumstances
 | 
			
		||||
    if (isModelLoaded()) {
 | 
			
		||||
        delete m_modelInfo.model;
 | 
			
		||||
        m_modelInfo.model = nullptr;
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool ChatLLM::loadDefaultModel()
 | 
			
		||||
@ -76,50 +126,103 @@ bool ChatLLM::loadDefaultModel()
 | 
			
		||||
 | 
			
		||||
bool ChatLLM::loadModel(const QString &modelName)
 | 
			
		||||
{
 | 
			
		||||
    // This is a complicated method because N different possible threads are interested in the outcome
 | 
			
		||||
    // of this method. Why? Because we have a main/gui thread trying to monitor the state of N different
 | 
			
		||||
    // possible chat threads all vying for a single resource - the currently loaded model - as the user
 | 
			
		||||
    // switches back and forth between chats. It is important for our main/gui thread to never block
 | 
			
		||||
    // but simultaneously always have up2date information with regards to which chat has the model loaded
 | 
			
		||||
    // and what the type and name of that model is. I've tried to comment extensively in this method
 | 
			
		||||
    // to provide an overview of what we're doing here.
 | 
			
		||||
 | 
			
		||||
    // We're already loaded with this model
 | 
			
		||||
    if (isModelLoaded() && m_modelName == modelName)
 | 
			
		||||
        return true;
 | 
			
		||||
 | 
			
		||||
    if (isModelLoaded()) {
 | 
			
		||||
    QString filePath = modelFilePath(modelName);
 | 
			
		||||
    QFileInfo fileInfo(filePath);
 | 
			
		||||
 | 
			
		||||
    // We have a live model, but it isn't the one we want
 | 
			
		||||
    bool alreadyAcquired = isModelLoaded();
 | 
			
		||||
    if (alreadyAcquired) {
 | 
			
		||||
        resetContextProtected();
 | 
			
		||||
        delete m_llmodel;
 | 
			
		||||
        m_llmodel = nullptr;
 | 
			
		||||
#if defined(DEBUG_MODEL_LOADING)
 | 
			
		||||
        qDebug() << "already acquired model deleted" << m_chat->id() << m_modelInfo.model;
 | 
			
		||||
#endif
 | 
			
		||||
        delete m_modelInfo.model;
 | 
			
		||||
        m_modelInfo.model = nullptr;
 | 
			
		||||
        emit isModelLoadedChanged();
 | 
			
		||||
    } else {
 | 
			
		||||
        // This is a blocking call that tries to retrieve the model we need from the model store.
 | 
			
		||||
        // If it succeeds, then we just have to restore state. If the store has never had a model
 | 
			
		||||
        // returned to it, then the modelInfo.model pointer should be null which will happen on startup
 | 
			
		||||
        m_modelInfo = LLModelStore::globalInstance()->acquireModel();
 | 
			
		||||
#if defined(DEBUG_MODEL_LOADING)
 | 
			
		||||
        qDebug() << "acquired model from store" << m_chat->id() << m_modelInfo.model;
 | 
			
		||||
#endif
 | 
			
		||||
        // At this point it is possible that while we were blocked waiting to acquire the model from the
 | 
			
		||||
        // store, that our state was changed to not be loaded. If this is the case, release the model
 | 
			
		||||
        // back into the store and quit loading
 | 
			
		||||
        if (!m_shouldBeLoaded) {
 | 
			
		||||
            qDebug() << "no longer need model" << m_chat->id() << m_modelInfo.model;
 | 
			
		||||
            LLModelStore::globalInstance()->releaseModel(m_modelInfo);
 | 
			
		||||
            m_modelInfo = LLModelInfo();
 | 
			
		||||
            emit isModelLoadedChanged();
 | 
			
		||||
            return false;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        // Check if the store just gave us exactly the model we were looking for
 | 
			
		||||
        if (m_modelInfo.model && m_modelInfo.fileInfo == fileInfo) {
 | 
			
		||||
#if defined(DEBUG_MODEL_LOADING)
 | 
			
		||||
            qDebug() << "store had our model" << m_chat->id() << m_modelInfo.model;
 | 
			
		||||
#endif
 | 
			
		||||
            restoreState();
 | 
			
		||||
            emit isModelLoadedChanged();
 | 
			
		||||
            return true;
 | 
			
		||||
        } else {
 | 
			
		||||
            // Release the memory since we have to switch to a different model.
 | 
			
		||||
#if defined(DEBUG_MODEL_LOADING)
 | 
			
		||||
            qDebug() << "deleting model" << m_chat->id() << m_modelInfo.model;
 | 
			
		||||
#endif
 | 
			
		||||
            delete m_modelInfo.model;
 | 
			
		||||
            m_modelInfo.model = nullptr;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    bool isGPTJ = false;
 | 
			
		||||
    bool isMPT = false;
 | 
			
		||||
    QString filePath = modelFilePath(modelName);
 | 
			
		||||
    QFileInfo info(filePath);
 | 
			
		||||
    if (info.exists()) {
 | 
			
		||||
    // Guarantee we've released the previous models memory
 | 
			
		||||
    Q_ASSERT(!m_modelInfo.model);
 | 
			
		||||
 | 
			
		||||
    // Store the file info in the modelInfo in case we have an error loading
 | 
			
		||||
    m_modelInfo.fileInfo = fileInfo;
 | 
			
		||||
 | 
			
		||||
    if (fileInfo.exists()) {
 | 
			
		||||
        auto fin = std::ifstream(filePath.toStdString(), std::ios::binary);
 | 
			
		||||
        uint32_t magic;
 | 
			
		||||
        fin.read((char *) &magic, sizeof(magic));
 | 
			
		||||
        fin.seekg(0);
 | 
			
		||||
        fin.close();
 | 
			
		||||
        isGPTJ = magic == 0x67676d6c;
 | 
			
		||||
        isMPT = magic == 0x67676d6d;
 | 
			
		||||
        const bool isGPTJ = magic == 0x67676d6c;
 | 
			
		||||
        const bool isMPT = magic == 0x67676d6d;
 | 
			
		||||
        if (isGPTJ) {
 | 
			
		||||
            m_modelType = ModelType::GPTJ_;
 | 
			
		||||
            m_llmodel = new GPTJ;
 | 
			
		||||
            m_llmodel->loadModel(filePath.toStdString());
 | 
			
		||||
            m_modelType = LLModelType::GPTJ_;
 | 
			
		||||
            m_modelInfo.model = new GPTJ;
 | 
			
		||||
            m_modelInfo.model->loadModel(filePath.toStdString());
 | 
			
		||||
        } else if (isMPT) {
 | 
			
		||||
            m_modelType = ModelType::MPT_;
 | 
			
		||||
            m_llmodel = new MPT;
 | 
			
		||||
            m_llmodel->loadModel(filePath.toStdString());
 | 
			
		||||
            m_modelType = LLModelType::MPT_;
 | 
			
		||||
            m_modelInfo.model = new MPT;
 | 
			
		||||
            m_modelInfo.model->loadModel(filePath.toStdString());
 | 
			
		||||
        } else {
 | 
			
		||||
            m_modelType = ModelType::LLAMA_;
 | 
			
		||||
            m_llmodel = new LLamaModel;
 | 
			
		||||
            m_llmodel->loadModel(filePath.toStdString());
 | 
			
		||||
            m_modelType = LLModelType::LLAMA_;
 | 
			
		||||
            m_modelInfo.model = new LLamaModel;
 | 
			
		||||
            m_modelInfo.model->loadModel(filePath.toStdString());
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        restoreState();
 | 
			
		||||
 | 
			
		||||
#if defined(DEBUG)
 | 
			
		||||
    qDebug() << "chatllm modelLoadedChanged" << m_chat->id();
 | 
			
		||||
    fflush(stdout);
 | 
			
		||||
#if defined(DEBUG_MODEL_LOADING)
 | 
			
		||||
        qDebug() << "new model" << m_chat->id() << m_modelInfo.model;
 | 
			
		||||
#endif
 | 
			
		||||
        restoreState();
 | 
			
		||||
#if defined(DEBUG)
 | 
			
		||||
        qDebug() << "modelLoadedChanged" << m_chat->id();
 | 
			
		||||
        fflush(stdout);
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
        emit isModelLoadedChanged();
 | 
			
		||||
 | 
			
		||||
        static bool isFirstLoad = true;
 | 
			
		||||
@ -129,19 +232,20 @@ bool ChatLLM::loadModel(const QString &modelName)
 | 
			
		||||
        } else
 | 
			
		||||
            emit sendModelLoaded();
 | 
			
		||||
    } else {
 | 
			
		||||
        LLModelStore::globalInstance()->releaseModel(m_modelInfo); // release back into the store
 | 
			
		||||
        const QString error = QString("Could not find model %1").arg(modelName);
 | 
			
		||||
        emit modelLoadingError(error);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    if (m_llmodel)
 | 
			
		||||
        setModelName(info.completeBaseName().remove(0, 5)); // remove the ggml- prefix
 | 
			
		||||
    if (m_modelInfo.model)
 | 
			
		||||
        setModelName(fileInfo.completeBaseName().remove(0, 5)); // remove the ggml- prefix
 | 
			
		||||
 | 
			
		||||
    return m_llmodel;
 | 
			
		||||
    return m_modelInfo.model;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool ChatLLM::isModelLoaded() const
 | 
			
		||||
{
 | 
			
		||||
    return m_llmodel && m_llmodel->isModelLoaded();
 | 
			
		||||
    return m_modelInfo.model && m_modelInfo.model->isModelLoaded();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void ChatLLM::regenerateResponse()
 | 
			
		||||
@ -226,7 +330,7 @@ bool ChatLLM::handlePrompt(int32_t token)
 | 
			
		||||
    // m_promptResponseTokens and m_responseLogits are related to last prompt/response not
 | 
			
		||||
    // the entire context window which we can reset on regenerate prompt
 | 
			
		||||
#if defined(DEBUG)
 | 
			
		||||
    qDebug() << "chatllm prompt process" << m_chat->id() << token;
 | 
			
		||||
    qDebug() << "prompt process" << m_chat->id() << token;
 | 
			
		||||
#endif
 | 
			
		||||
    ++m_promptTokens;
 | 
			
		||||
    ++m_promptResponseTokens;
 | 
			
		||||
@ -287,12 +391,12 @@ bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int3
 | 
			
		||||
    m_ctx.n_batch = n_batch;
 | 
			
		||||
    m_ctx.repeat_penalty = repeat_penalty;
 | 
			
		||||
    m_ctx.repeat_last_n = repeat_penalty_tokens;
 | 
			
		||||
    m_llmodel->setThreadCount(n_threads);
 | 
			
		||||
    m_modelInfo.model->setThreadCount(n_threads);
 | 
			
		||||
#if defined(DEBUG)
 | 
			
		||||
    printf("%s", qPrintable(instructPrompt));
 | 
			
		||||
    fflush(stdout);
 | 
			
		||||
#endif
 | 
			
		||||
    m_llmodel->prompt(instructPrompt.toStdString(), promptFunc, responseFunc, recalcFunc, m_ctx);
 | 
			
		||||
    m_modelInfo.model->prompt(instructPrompt.toStdString(), promptFunc, responseFunc, recalcFunc, m_ctx);
 | 
			
		||||
#if defined(DEBUG)
 | 
			
		||||
    printf("\n");
 | 
			
		||||
    fflush(stdout);
 | 
			
		||||
@ -307,26 +411,55 @@ bool ChatLLM::prompt(const QString &prompt, const QString &prompt_template, int3
 | 
			
		||||
    return true;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void ChatLLM::setShouldBeLoaded(bool b)
 | 
			
		||||
{
 | 
			
		||||
#if defined(DEBUG_MODEL_LOADING)
 | 
			
		||||
    qDebug() << "setShouldBeLoaded" << m_chat->id() << b << m_modelInfo.model;
 | 
			
		||||
#endif
 | 
			
		||||
    m_shouldBeLoaded = b; // atomic
 | 
			
		||||
    emit shouldBeLoadedChanged();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void ChatLLM::handleShouldBeLoadedChanged()
 | 
			
		||||
{
 | 
			
		||||
    if (m_shouldBeLoaded)
 | 
			
		||||
        reloadModel();
 | 
			
		||||
    else
 | 
			
		||||
        unloadModel();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void ChatLLM::forceUnloadModel()
 | 
			
		||||
{
 | 
			
		||||
    m_shouldBeLoaded = false; // atomic
 | 
			
		||||
    unloadModel();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void ChatLLM::unloadModel()
 | 
			
		||||
{
 | 
			
		||||
#if defined(DEBUG)
 | 
			
		||||
    qDebug() << "chatllm unloadModel" << m_chat->id();
 | 
			
		||||
#endif
 | 
			
		||||
    if (!isModelLoaded())
 | 
			
		||||
        return;
 | 
			
		||||
 | 
			
		||||
    saveState();
 | 
			
		||||
    delete m_llmodel;
 | 
			
		||||
    m_llmodel = nullptr;
 | 
			
		||||
#if defined(DEBUG_MODEL_LOADING)
 | 
			
		||||
    qDebug() << "unloadModel" << m_chat->id() << m_modelInfo.model;
 | 
			
		||||
#endif
 | 
			
		||||
    LLModelStore::globalInstance()->releaseModel(m_modelInfo);
 | 
			
		||||
    m_modelInfo = LLModelInfo();
 | 
			
		||||
    emit isModelLoadedChanged();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void ChatLLM::reloadModel(const QString &modelName)
 | 
			
		||||
void ChatLLM::reloadModel()
 | 
			
		||||
{
 | 
			
		||||
#if defined(DEBUG)
 | 
			
		||||
    qDebug() << "chatllm reloadModel" << m_chat->id();
 | 
			
		||||
    if (isModelLoaded())
 | 
			
		||||
        return;
 | 
			
		||||
 | 
			
		||||
#if defined(DEBUG_MODEL_LOADING)
 | 
			
		||||
    qDebug() << "reloadModel" << m_chat->id() << m_modelInfo.model;
 | 
			
		||||
#endif
 | 
			
		||||
    if (modelName.isEmpty()) {
 | 
			
		||||
    if (m_modelName.isEmpty()) {
 | 
			
		||||
        loadDefaultModel();
 | 
			
		||||
    } else {
 | 
			
		||||
        loadModel(modelName);
 | 
			
		||||
        loadModel(m_modelName);
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -348,7 +481,7 @@ void ChatLLM::generateName()
 | 
			
		||||
    printf("%s", qPrintable(instructPrompt));
 | 
			
		||||
    fflush(stdout);
 | 
			
		||||
#endif
 | 
			
		||||
    m_llmodel->prompt(instructPrompt.toStdString(), promptFunc, responseFunc, recalcFunc, ctx);
 | 
			
		||||
    m_modelInfo.model->prompt(instructPrompt.toStdString(), promptFunc, responseFunc, recalcFunc, ctx);
 | 
			
		||||
#if defined(DEBUG)
 | 
			
		||||
    printf("\n");
 | 
			
		||||
    fflush(stdout);
 | 
			
		||||
@ -415,7 +548,7 @@ bool ChatLLM::serialize(QDataStream &stream, int version)
 | 
			
		||||
    QByteArray compressed = qCompress(m_state);
 | 
			
		||||
    stream << compressed;
 | 
			
		||||
#if defined(DEBUG)
 | 
			
		||||
    qDebug() << "chatllm serialize" << m_chat->id() << m_state.size();
 | 
			
		||||
    qDebug() << "serialize" << m_chat->id() << m_state.size();
 | 
			
		||||
#endif
 | 
			
		||||
    return stream.status() == QDataStream::Ok;
 | 
			
		||||
}
 | 
			
		||||
@ -452,7 +585,7 @@ bool ChatLLM::deserialize(QDataStream &stream, int version)
 | 
			
		||||
        stream >> m_state;
 | 
			
		||||
    }
 | 
			
		||||
#if defined(DEBUG)
 | 
			
		||||
    qDebug() << "chatllm deserialize" << m_chat->id();
 | 
			
		||||
    qDebug() << "deserialize" << m_chat->id();
 | 
			
		||||
#endif
 | 
			
		||||
    return stream.status() == QDataStream::Ok;
 | 
			
		||||
}
 | 
			
		||||
@ -462,12 +595,12 @@ void ChatLLM::saveState()
 | 
			
		||||
    if (!isModelLoaded())
 | 
			
		||||
        return;
 | 
			
		||||
 | 
			
		||||
    const size_t stateSize = m_llmodel->stateSize();
 | 
			
		||||
    const size_t stateSize = m_modelInfo.model->stateSize();
 | 
			
		||||
    m_state.resize(stateSize);
 | 
			
		||||
#if defined(DEBUG)
 | 
			
		||||
    qDebug() << "chatllm saveState" << m_chat->id() << "size:" << m_state.size();
 | 
			
		||||
    qDebug() << "saveState" << m_chat->id() << "size:" << m_state.size();
 | 
			
		||||
#endif
 | 
			
		||||
    m_llmodel->saveState(static_cast<uint8_t*>(reinterpret_cast<void*>(m_state.data())));
 | 
			
		||||
    m_modelInfo.model->saveState(static_cast<uint8_t*>(reinterpret_cast<void*>(m_state.data())));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void ChatLLM::restoreState()
 | 
			
		||||
@ -476,9 +609,9 @@ void ChatLLM::restoreState()
 | 
			
		||||
        return;
 | 
			
		||||
 | 
			
		||||
#if defined(DEBUG)
 | 
			
		||||
    qDebug() << "chatllm restoreState" << m_chat->id() << "size:" << m_state.size();
 | 
			
		||||
    qDebug() << "restoreState" << m_chat->id() << "size:" << m_state.size();
 | 
			
		||||
#endif
 | 
			
		||||
    m_llmodel->restoreState(static_cast<const uint8_t*>(reinterpret_cast<void*>(m_state.data())));
 | 
			
		||||
    m_modelInfo.model->restoreState(static_cast<const uint8_t*>(reinterpret_cast<void*>(m_state.data())));
 | 
			
		||||
    m_state.clear();
 | 
			
		||||
    m_state.resize(0);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -3,9 +3,23 @@
 | 
			
		||||
 | 
			
		||||
#include <QObject>
 | 
			
		||||
#include <QThread>
 | 
			
		||||
#include <QFileInfo>
 | 
			
		||||
 | 
			
		||||
#include "../gpt4all-backend/llmodel.h"
 | 
			
		||||
 | 
			
		||||
enum LLModelType {
 | 
			
		||||
    MPT_,
 | 
			
		||||
    GPTJ_,
 | 
			
		||||
    LLAMA_
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
struct LLModelInfo {
 | 
			
		||||
    LLModel *model = nullptr;
 | 
			
		||||
    QFileInfo fileInfo;
 | 
			
		||||
    // NOTE: This does not store the model type or name on purpose as this is left for ChatLLM which
 | 
			
		||||
    // must be able to serialize the information even if it is in the unloaded state
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
class Chat;
 | 
			
		||||
class ChatLLM : public QObject
 | 
			
		||||
{
 | 
			
		||||
@ -17,12 +31,6 @@ class ChatLLM : public QObject
 | 
			
		||||
    Q_PROPERTY(QString generatedName READ generatedName NOTIFY generatedNameChanged)
 | 
			
		||||
 | 
			
		||||
public:
 | 
			
		||||
    enum ModelType {
 | 
			
		||||
        MPT_,
 | 
			
		||||
        GPTJ_,
 | 
			
		||||
        LLAMA_
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    ChatLLM(Chat *parent);
 | 
			
		||||
    virtual ~ChatLLM();
 | 
			
		||||
 | 
			
		||||
@ -33,6 +41,9 @@ public:
 | 
			
		||||
 | 
			
		||||
    void stopGenerating() { m_stopGenerating = true; }
 | 
			
		||||
 | 
			
		||||
    bool shouldBeLoaded() const { return m_shouldBeLoaded; }
 | 
			
		||||
    void setShouldBeLoaded(bool b);
 | 
			
		||||
 | 
			
		||||
    QString response() const;
 | 
			
		||||
    QString modelName() const;
 | 
			
		||||
 | 
			
		||||
@ -52,10 +63,12 @@ public Q_SLOTS:
 | 
			
		||||
    bool loadDefaultModel();
 | 
			
		||||
    bool loadModel(const QString &modelName);
 | 
			
		||||
    void modelNameChangeRequested(const QString &modelName);
 | 
			
		||||
    void forceUnloadModel();
 | 
			
		||||
    void unloadModel();
 | 
			
		||||
    void reloadModel(const QString &modelName);
 | 
			
		||||
    void reloadModel();
 | 
			
		||||
    void generateName();
 | 
			
		||||
    void handleChatIdChanged();
 | 
			
		||||
    void handleShouldBeLoadedChanged();
 | 
			
		||||
 | 
			
		||||
Q_SIGNALS:
 | 
			
		||||
    void isModelLoadedChanged();
 | 
			
		||||
@ -71,6 +84,7 @@ Q_SIGNALS:
 | 
			
		||||
    void generatedNameChanged();
 | 
			
		||||
    void stateChanged();
 | 
			
		||||
    void threadStarted();
 | 
			
		||||
    void shouldBeLoadedChanged();
 | 
			
		||||
 | 
			
		||||
protected:
 | 
			
		||||
    LLModel::PromptContext m_ctx;
 | 
			
		||||
@ -89,16 +103,17 @@ private:
 | 
			
		||||
    void restoreState();
 | 
			
		||||
 | 
			
		||||
private:
 | 
			
		||||
    LLModel *m_llmodel;
 | 
			
		||||
    LLModelInfo m_modelInfo;
 | 
			
		||||
    LLModelType m_modelType;
 | 
			
		||||
    std::string m_response;
 | 
			
		||||
    std::string m_nameResponse;
 | 
			
		||||
    quint32 m_responseLogits;
 | 
			
		||||
    QString m_modelName;
 | 
			
		||||
    ModelType m_modelType;
 | 
			
		||||
    Chat *m_chat;
 | 
			
		||||
    QByteArray m_state;
 | 
			
		||||
    QThread m_llmThread;
 | 
			
		||||
    std::atomic<bool> m_stopGenerating;
 | 
			
		||||
    std::atomic<bool> m_shouldBeLoaded;
 | 
			
		||||
    bool m_isRecalc;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user