Compare commits

...

11 Commits

Author SHA1 Message Date
cebtenzzre
aed2068342
python: always check status code of HTTP responses (#1502) 2023-10-11 18:11:28 -04:00
Aaron Miller
afaa291eab python bindings should be quiet by default
* disable llama.cpp logging unless GPT4ALL_VERBOSE_LLAMACPP envvar is
  nonempty
* make verbose flag for retrieve_model default false (but also be
  overridable via gpt4all constructor)

should be able to run a basic test:

```python
import gpt4all
model = gpt4all.GPT4All('/Users/aaron/Downloads/rift-coder-v0-7b-q4_0.gguf')
print(model.generate('def fib(n):'))
```

and see no non-model output when successful
2023-10-11 14:14:36 -07:00
cebtenzzre
7b611b49f2
llmodel: print an error if the CPU does not support AVX (#1499) 2023-10-11 15:09:40 -04:00
cebtenzzre
f81b4b45bf
python: support Path in GPT4All.__init__ (#1462) 2023-10-11 14:12:40 -04:00
Aaron Miller
043617168e do not process prompts on gpu yet 2023-10-11 13:15:50 -04:00
Aaron Miller
64001a480a mat*mat for q4_0, q8_0 2023-10-11 13:15:50 -04:00
cebtenzzre
04499d1c7d
chatllm: do not write uninitialized data to stream (#1486) 2023-10-11 11:31:34 -04:00
cebtenzzre
7a19047329
llmodel: do not call magic_match unless build variant is correct (#1488) 2023-10-11 11:30:48 -04:00
Adam Treat
df8528df73 Another codespell attempted fix. 2023-10-11 09:17:38 -04:00
Adam Treat
f0742c22f4 Restore state from text if necessary. 2023-10-11 09:16:02 -04:00
Adam Treat
35f9cdb70a Do not delete saved chats if we fail to serialize properly. 2023-10-11 09:16:02 -04:00
11 changed files with 240 additions and 45 deletions

View File

@ -1,3 +1,3 @@
[codespell]
ignore-words-list = blong, belong, afterall, som
ignore-words-list = blong, belong, afterall, som, assistent
skip = .git,*.pdf,*.svg,*.lock

@ -1 +1 @@
Subproject commit 7b8f00f5ccf4fc3cc67fe1ced792b3aec1ae6c1c
Subproject commit 3742085b0429cbe0ede49bcb9f891e4a5e25a724

View File

@ -238,6 +238,10 @@ if (LLAMA_KOMPUTE)
kompute/op_norm.comp
kompute/op_rmsnorm.comp
kompute/op_diagmask.comp
kompute/op_mul_mat_mat_f32.comp
kompute/op_mul_mat_mat_f16.comp
kompute/op_mul_mat_mat_q8_0.comp
kompute/op_mul_mat_mat_q4_0.comp
kompute/op_mul_mat_f16.comp
kompute/op_mul_mat_q8_0.comp
kompute/op_mul_mat_q4_0.comp
@ -268,6 +272,10 @@ if (LLAMA_KOMPUTE)
shaderop_norm.h
shaderop_rmsnorm.h
shaderop_diagmask.h
shaderop_mul_mat_mat_f32.h
shaderop_mul_mat_mat_f16.h
shaderop_mul_mat_mat_q8_0.h
shaderop_mul_mat_mat_q4_0.h
shaderop_mul_mat_f16.h
shaderop_mul_mat_q8_0.h
shaderop_mul_mat_q4_0.h

View File

@ -36,6 +36,17 @@ namespace {
const char *modelType_ = "LLaMA";
}
static void null_log_callback(enum ggml_log_level level, const char* text, void* userdata) {
(void)level;
(void)text;
(void)userdata;
}
static bool llama_verbose() {
const char* var = getenv("GPT4ALL_VERBOSE_LLAMACPP");
return var && *var;
}
struct gpt_params {
int32_t seed = -1; // RNG seed
int32_t n_keep = 0; // number of tokens to keep from initial prompt
@ -144,7 +155,9 @@ bool LLamaModel::loadModel(const std::string &modelPath)
d_ptr->params.use_mlock = params.use_mlock;
#endif
#ifdef GGML_USE_METAL
std::cerr << "llama.cpp: using Metal" << std::endl;
if (llama_verbose()) {
std::cerr << "llama.cpp: using Metal" << std::endl;
}
// metal always runs the whole model if n_gpu_layers is not 0, at least
// currently
d_ptr->params.n_gpu_layers = 1;
@ -390,6 +403,9 @@ DLL_EXPORT bool magic_match(const char * fname) {
}
DLL_EXPORT LLModel *construct() {
if (!llama_verbose()) {
llama_log_set(null_log_callback, nullptr);
}
return new LLamaModel;
}
}

View File

@ -113,17 +113,18 @@ const std::vector<LLModel::Implementation> &LLModel::Implementation::implementat
const LLModel::Implementation* LLModel::Implementation::implementation(const char *fname, const std::string& buildVariant) {
for (const auto& i : implementationList()) {
if (!i.m_magicMatch(fname)) continue;
if (buildVariant != i.m_buildVariant) continue;
if (!i.m_magicMatch(fname)) continue;
return &i;
}
return nullptr;
}
LLModel *LLModel::Implementation::construct(const std::string &modelPath, std::string buildVariant) {
if (!has_at_least_minimal_hardware())
if (!has_at_least_minimal_hardware()) {
std::cerr << "LLModel ERROR: CPU does not support AVX\n";
return nullptr;
}
// Get correct implementation
const Implementation* impl = nullptr;

View File

@ -1,6 +1,8 @@
"""
Python only API for running all GPT4All models.
"""
from __future__ import annotations
import os
import sys
import time
@ -60,11 +62,12 @@ class GPT4All:
def __init__(
self,
model_name: str,
model_path: Optional[str] = None,
model_path: Optional[Union[str, os.PathLike[str]]] = None,
model_type: Optional[str] = None,
allow_download: bool = True,
n_threads: Optional[int] = None,
device: Optional[str] = "cpu",
verbose: bool = False,
):
"""
Constructor
@ -89,7 +92,7 @@ class GPT4All:
self.model_type = model_type
self.model = pyllmodel.LLModel()
# Retrieve model and download if allowed
self.config: ConfigType = self.retrieve_model(model_name, model_path=model_path, allow_download=allow_download)
self.config: ConfigType = self.retrieve_model(model_name, model_path=model_path, allow_download=allow_download, verbose=verbose)
if device is not None:
if device != "cpu":
self.model.init_gpu(model_path=self.config["path"], device=device)
@ -110,14 +113,17 @@ class GPT4All:
Returns:
Model list in JSON format.
"""
return requests.get("https://gpt4all.io/models/models2.json").json()
resp = requests.get("https://gpt4all.io/models/models2.json")
if resp.status_code != 200:
raise ValueError(f'Request failed: HTTP {resp.status_code} {resp.reason}')
return resp.json()
@staticmethod
def retrieve_model(
model_name: str,
model_path: Optional[str] = None,
model_path: Optional[Union[str, os.PathLike[str]]] = None,
allow_download: bool = True,
verbose: bool = True,
verbose: bool = False,
) -> ConfigType:
"""
Find model file, and if it doesn't exist, download the model.
@ -160,7 +166,7 @@ class GPT4All:
)
model_path = DEFAULT_MODEL_DIRECTORY
else:
model_path = model_path.replace("\\", "\\\\")
model_path = str(model_path).replace("\\", "\\\\")
if not os.path.exists(model_path):
raise ValueError(f"Invalid model directory: {model_path}")
@ -185,7 +191,7 @@ class GPT4All:
@staticmethod
def download_model(
model_filename: str,
model_path: str,
model_path: Union[str, os.PathLike[str]],
verbose: bool = True,
url: Optional[str] = None,
) -> str:
@ -212,6 +218,9 @@ class GPT4All:
download_url = get_download_url(model_filename)
response = requests.get(download_url, stream=True)
if response.status_code != 200:
raise ValueError(f'Request failed: HTTP {response.status_code} {response.reason}')
total_size_in_bytes = int(response.headers.get("content-length", 0))
block_size = 2**20 # 1 MB

View File

@ -385,7 +385,7 @@ bool Chat::serialize(QDataStream &stream, int version) const
stream << m_modelInfo.filename();
if (version > 2)
stream << m_collections;
if (!m_llmodel->serialize(stream, version))
if (!m_llmodel->serialize(stream, version, true /*serializeKV*/))
return false;
if (!m_chatModel->serialize(stream, version))
return false;
@ -404,29 +404,36 @@ bool Chat::deserialize(QDataStream &stream, int version)
QString modelId;
stream >> modelId;
if (version > 4) {
if (!ModelList::globalInstance()->contains(modelId))
return false;
m_modelInfo = ModelList::globalInstance()->modelInfo(modelId);
if (ModelList::globalInstance()->contains(modelId))
m_modelInfo = ModelList::globalInstance()->modelInfo(modelId);
} else {
if (!ModelList::globalInstance()->containsByFilename(modelId))
return false;
m_modelInfo = ModelList::globalInstance()->modelInfoByFilename(modelId);
if (ModelList::globalInstance()->containsByFilename(modelId))
m_modelInfo = ModelList::globalInstance()->modelInfoByFilename(modelId);
}
emit modelInfoChanged();
if (!m_modelInfo.id().isEmpty())
emit modelInfoChanged();
bool deserializeKV = true; // make this a setting
bool discardKV = m_modelInfo.id().isEmpty();
// Prior to version 2 gptj models had a bug that fixed the kv_cache to F32 instead of F16 so
// unfortunately, we cannot deserialize these
if (version < 2 && m_modelInfo.filename().contains("gpt4all-j"))
return false;
discardKV = true;
if (version > 2) {
stream >> m_collections;
emit collectionListChanged(m_collections);
}
m_llmodel->setModelInfo(m_modelInfo);
if (!m_llmodel->deserialize(stream, version))
if (!m_llmodel->deserialize(stream, version, deserializeKV, discardKV))
return false;
if (!m_chatModel->deserialize(stream, version))
return false;
if (!deserializeKV || discardKV)
m_llmodel->setStateFromText(m_chatModel->text());
emit chatModelChanged();
return stream.status() == QDataStream::Ok;
}

View File

@ -84,13 +84,16 @@ void ChatSaver::saveChats(const QVector<Chat *> &chats)
const QString savePath = MySettings::globalInstance()->modelPath();
for (Chat *chat : chats) {
QString fileName = "gpt4all-" + chat->id() + ".chat";
QFile file(savePath + "/" + fileName);
bool success = file.open(QIODevice::WriteOnly);
QString filePath = savePath + "/" + fileName;
QFile originalFile(filePath);
QFile tempFile(filePath + ".tmp"); // Temporary file
bool success = tempFile.open(QIODevice::WriteOnly);
if (!success) {
qWarning() << "ERROR: Couldn't save chat to file:" << file.fileName();
qWarning() << "ERROR: Couldn't save chat to temporary file:" << tempFile.fileName();
continue;
}
QDataStream out(&file);
QDataStream out(&tempFile);
out << (quint32)CHAT_FORMAT_MAGIC;
out << (qint32)CHAT_FORMAT_VERSION;
@ -98,11 +101,16 @@ void ChatSaver::saveChats(const QVector<Chat *> &chats)
qDebug() << "serializing chat" << fileName;
if (!chat->serialize(out, CHAT_FORMAT_VERSION)) {
qWarning() << "ERROR: Couldn't serialize chat to file:" << file.fileName();
file.remove();
qWarning() << "ERROR: Couldn't serialize chat to file:" << tempFile.fileName();
tempFile.remove();
continue;
}
file.close();
if (originalFile.exists())
originalFile.remove();
tempFile.rename(filePath);
}
qint64 elapsedTime = timer.elapsed();
qDebug() << "serializing chats took:" << elapsedTime << "ms";
emit saveChatsFinished();
@ -224,7 +232,6 @@ void ChatsRestoreThread::run()
chat->moveToThread(qApp->thread());
if (!chat->deserialize(in, version)) {
qWarning() << "ERROR: Couldn't deserialize chat from file:" << file.fileName();
file.remove();
} else {
emit chatRestored(chat);
}

View File

@ -69,6 +69,7 @@ ChatLLM::ChatLLM(Chat *parent, bool isServer)
, m_forceMetal(MySettings::globalInstance()->forceMetal())
, m_reloadingToChangeVariant(false)
, m_processedSystemPrompt(false)
, m_restoreStateFromText(false)
{
moveToThread(&m_llmThread);
connect(this, &ChatLLM::sendStartup, Network::globalInstance(), &Network::sendStartup);
@ -726,7 +727,35 @@ bool ChatLLM::handleSystemRecalculate(bool isRecalc)
return false;
}
bool ChatLLM::serialize(QDataStream &stream, int version)
bool ChatLLM::handleRestoreStateFromTextPrompt(int32_t token)
{
#if defined(DEBUG)
qDebug() << "restore state from text prompt" << m_llmThread.objectName() << token << m_stopGenerating;
#endif
Q_UNUSED(token);
return !m_stopGenerating;
}
bool ChatLLM::handleRestoreStateFromTextResponse(int32_t token, const std::string &response)
{
#if defined(DEBUG)
qDebug() << "restore state from text response" << m_llmThread.objectName() << token << response << m_stopGenerating;
#endif
Q_UNUSED(token);
Q_UNUSED(response);
return false;
}
bool ChatLLM::handleRestoreStateFromTextRecalculate(bool isRecalc)
{
#if defined(DEBUG)
qDebug() << "restore state from text recalc" << m_llmThread.objectName() << isRecalc;
#endif
Q_UNUSED(isRecalc);
return false;
}
bool ChatLLM::serialize(QDataStream &stream, int version, bool serializeKV)
{
if (version > 1) {
stream << m_llModelType;
@ -741,8 +770,16 @@ bool ChatLLM::serialize(QDataStream &stream, int version)
stream << response();
stream << generatedName();
stream << m_promptResponseTokens;
if (!serializeKV) {
#if defined(DEBUG)
qDebug() << "serialize" << m_llmThread.objectName() << m_state.size();
#endif
return stream.status() == QDataStream::Ok;
}
if (version <= 3) {
int responseLogits;
int responseLogits = 0;
stream << responseLogits;
}
stream << m_ctx.n_past;
@ -759,7 +796,7 @@ bool ChatLLM::serialize(QDataStream &stream, int version)
return stream.status() == QDataStream::Ok;
}
bool ChatLLM::deserialize(QDataStream &stream, int version)
bool ChatLLM::deserialize(QDataStream &stream, int version, bool deserializeKV, bool discardKV)
{
if (version > 1) {
int internalStateVersion;
@ -773,26 +810,60 @@ bool ChatLLM::deserialize(QDataStream &stream, int version)
stream >> nameResponse;
m_nameResponse = nameResponse.toStdString();
stream >> m_promptResponseTokens;
// If we do not deserialize the KV or it is discarded, then we need to restore the state from the
// text only. This will be a costly operation, but the chat has to be restored from the text archive
// alone.
m_restoreStateFromText = !deserializeKV || discardKV;
if (!deserializeKV) {
#if defined(DEBUG)
qDebug() << "deserialize" << m_llmThread.objectName();
#endif
return stream.status() == QDataStream::Ok;
}
if (version <= 3) {
int responseLogits;
stream >> responseLogits;
}
stream >> m_ctx.n_past;
int32_t n_past;
stream >> n_past;
if (!discardKV) m_ctx.n_past = n_past;
quint64 logitsSize;
stream >> logitsSize;
m_ctx.logits.resize(logitsSize);
stream.readRawData(reinterpret_cast<char*>(m_ctx.logits.data()), logitsSize * sizeof(float));
if (!discardKV) {
m_ctx.logits.resize(logitsSize);
stream.readRawData(reinterpret_cast<char*>(m_ctx.logits.data()), logitsSize * sizeof(float));
} else {
stream.skipRawData(logitsSize * sizeof(float));
}
quint64 tokensSize;
stream >> tokensSize;
m_ctx.tokens.resize(tokensSize);
stream.readRawData(reinterpret_cast<char*>(m_ctx.tokens.data()), tokensSize * sizeof(int));
if (!discardKV) {
m_ctx.tokens.resize(tokensSize);
stream.readRawData(reinterpret_cast<char*>(m_ctx.tokens.data()), tokensSize * sizeof(int));
} else {
stream.skipRawData(tokensSize * sizeof(int));
}
if (version > 0) {
QByteArray compressed;
stream >> compressed;
m_state = qUncompress(compressed);
if (!discardKV)
m_state = qUncompress(compressed);
} else {
stream >> m_state;
if (!discardKV)
stream >> m_state;
else {
QByteArray state;
stream >> m_state;
}
}
#if defined(DEBUG)
qDebug() << "deserialize" << m_llmThread.objectName();
#endif
@ -823,7 +894,7 @@ void ChatLLM::saveState()
void ChatLLM::restoreState()
{
if (!isModelLoaded() || m_state.isEmpty())
if (!isModelLoaded())
return;
if (m_llModelType == LLModelType::CHATGPT_) {
@ -838,10 +909,19 @@ void ChatLLM::restoreState()
return;
}
if (m_restoreStateFromText) {
Q_ASSERT(m_state.isEmpty());
processRestoreStateFromText();
}
#if defined(DEBUG)
qDebug() << "restoreState" << m_llmThread.objectName() << "size:" << m_state.size();
#endif
m_processedSystemPrompt = true;
if (m_state.isEmpty())
return;
m_llModelInfo.model->restoreState(static_cast<const uint8_t*>(reinterpret_cast<void*>(m_state.data())));
m_state.clear();
m_state.resize(0);
@ -859,7 +939,10 @@ void ChatLLM::processSystemPrompt()
return;
}
// Start with a whole new context
m_stopGenerating = false;
m_ctx = LLModel::PromptContext();
auto promptFunc = std::bind(&ChatLLM::handleSystemPrompt, this, std::placeholders::_1);
auto responseFunc = std::bind(&ChatLLM::handleSystemResponse, this, std::placeholders::_1,
std::placeholders::_2);
@ -890,5 +973,54 @@ void ChatLLM::processSystemPrompt()
printf("\n");
fflush(stdout);
#endif
m_processedSystemPrompt = true;
m_processedSystemPrompt = !m_stopGenerating;
}
void ChatLLM::processRestoreStateFromText()
{
Q_ASSERT(isModelLoaded());
if (!isModelLoaded() || !m_restoreStateFromText || m_isServer)
return;
m_isRecalc = true;
emit recalcChanged();
m_stopGenerating = false;
m_ctx = LLModel::PromptContext();
auto promptFunc = std::bind(&ChatLLM::handleRestoreStateFromTextPrompt, this, std::placeholders::_1);
auto responseFunc = std::bind(&ChatLLM::handleRestoreStateFromTextResponse, this, std::placeholders::_1,
std::placeholders::_2);
auto recalcFunc = std::bind(&ChatLLM::handleRestoreStateFromTextRecalculate, this, std::placeholders::_1);
const QString promptTemplate = MySettings::globalInstance()->modelPromptTemplate(m_modelInfo);
const int32_t n_predict = MySettings::globalInstance()->modelMaxLength(m_modelInfo);
const int32_t top_k = MySettings::globalInstance()->modelTopK(m_modelInfo);
const float top_p = MySettings::globalInstance()->modelTopP(m_modelInfo);
const float temp = MySettings::globalInstance()->modelTemperature(m_modelInfo);
const int32_t n_batch = MySettings::globalInstance()->modelPromptBatchSize(m_modelInfo);
const float repeat_penalty = MySettings::globalInstance()->modelRepeatPenalty(m_modelInfo);
const int32_t repeat_penalty_tokens = MySettings::globalInstance()->modelRepeatPenaltyTokens(m_modelInfo);
int n_threads = MySettings::globalInstance()->threadCount();
m_ctx.n_predict = n_predict;
m_ctx.top_k = top_k;
m_ctx.top_p = top_p;
m_ctx.temp = temp;
m_ctx.n_batch = n_batch;
m_ctx.repeat_penalty = repeat_penalty;
m_ctx.repeat_last_n = repeat_penalty_tokens;
m_llModelInfo.model->setThreadCount(n_threads);
for (auto pair : m_stateFromText) {
const QString str = pair.first == "Prompt: " ? promptTemplate.arg(pair.second) : pair.second;
m_llModelInfo.model->prompt(str.toStdString(), promptFunc, responseFunc, recalcFunc, m_ctx);
}
if (!m_stopGenerating) {
m_restoreStateFromText = false;
m_stateFromText.clear();
}
m_isRecalc = false;
emit recalcChanged();
}

View File

@ -92,8 +92,9 @@ public:
QString generatedName() const { return QString::fromStdString(m_nameResponse); }
bool serialize(QDataStream &stream, int version);
bool deserialize(QDataStream &stream, int version);
bool serialize(QDataStream &stream, int version, bool serializeKV);
bool deserialize(QDataStream &stream, int version, bool deserializeKV, bool discardKV);
void setStateFromText(const QVector<QPair<QString, QString>> &stateFromText) { m_stateFromText = stateFromText; }
public Q_SLOTS:
bool prompt(const QList<QString> &collectionList, const QString &prompt);
@ -110,6 +111,7 @@ public Q_SLOTS:
void handleForceMetalChanged(bool forceMetal);
void handleDeviceChanged();
void processSystemPrompt();
void processRestoreStateFromText();
Q_SIGNALS:
void recalcChanged();
@ -144,6 +146,9 @@ protected:
bool handleSystemPrompt(int32_t token);
bool handleSystemResponse(int32_t token, const std::string &response);
bool handleSystemRecalculate(bool isRecalc);
bool handleRestoreStateFromTextPrompt(int32_t token);
bool handleRestoreStateFromTextResponse(int32_t token, const std::string &response);
bool handleRestoreStateFromTextRecalculate(bool isRecalc);
void saveState();
void restoreState();
@ -168,6 +173,8 @@ private:
bool m_forceMetal;
bool m_reloadingToChangeVariant;
bool m_processedSystemPrompt;
bool m_restoreStateFromText;
QVector<QPair<QString, QString>> m_stateFromText;
};
#endif // CHATLLM_H

View File

@ -285,6 +285,14 @@ public:
return stream.status() == QDataStream::Ok;
}
QVector<QPair<QString, QString>> text() const
{
QVector<QPair<QString, QString>> result;
for (const auto &c : m_chatItems)
result << qMakePair(c.name, c.value);
return result;
}
Q_SIGNALS:
void countChanged();