Compare commits

...

13 Commits

Author SHA1 Message Date
Andriy Mulyar
cfd70b69fc
Update gpt4all_python_embedding.md
Signed-off-by: Andriy Mulyar <andriy.mulyar@gmail.com>
2023-07-14 14:54:56 -04:00
Andriy Mulyar
306105e62f
Update gpt4all_python_embedding.md
Signed-off-by: Andriy Mulyar <andriy.mulyar@gmail.com>
2023-07-14 14:54:36 -04:00
Andriy Mulyar
89e277bb3c
Update gpt4all_python_embedding.md
Signed-off-by: Andriy Mulyar <andriy.mulyar@gmail.com>
2023-07-14 14:30:14 -04:00
Adam Treat
f543affa9a Add better docs and threading support to bert. 2023-07-14 14:14:22 -04:00
Lakshay Kansal
6c8669cad3 highlighting rules for html and php and latex 2023-07-14 11:36:01 -04:00
Adam Treat
0c0a4f2c22 Add the docs. 2023-07-14 10:48:18 -04:00
Adam Treat
6656f0f41e Fix the test to work and not do timings. 2023-07-14 09:48:57 -04:00
Adam Treat
bb2b82e1b9 Add docs and bump version since we changed python api again. 2023-07-14 09:48:57 -04:00
Aaron Miller
c77ab849c0 LLModel objects should hold a reference to the library
prevents llmodel lib from being gc'd before live model objects
2023-07-14 09:48:57 -04:00
Aaron Miller
1c4a244291 bump mem allocation a bit 2023-07-14 09:48:57 -04:00
Aaron Miller
936dcd2bfc use default n_threads 2023-07-14 09:48:57 -04:00
Aaron Miller
15f1fe5445 rename embedder 2023-07-14 09:48:57 -04:00
Adam Treat
ee4186d579 Fixup bert python bindings. 2023-07-14 09:48:57 -04:00
11 changed files with 256 additions and 34 deletions

View File

@ -14,6 +14,7 @@
#include <regex>
#include <thread>
#include <algorithm>
#include <numeric>
//#define DEBUG_BERT
@ -462,11 +463,6 @@ void bert_eval(
ggml_set_f32(sum, 1.0f / N);
inpL = ggml_mul_mat(ctx0, inpL, sum);
// normalizer
ggml_tensor *length = ggml_sqrt(ctx0,
ggml_sum(ctx0, ggml_sqr(ctx0, inpL)));
inpL = ggml_scale(ctx0, inpL, ggml_div(ctx0, ggml_new_f32(ctx0, 1.0f), length));
ggml_tensor *output = inpL;
// run the computation
ggml_build_forward_expand(&gf, output);
@ -875,7 +871,7 @@ struct bert_ctx * bert_load_from_file(const char *fname)
// TODO: Max tokens should be a param?
int32_t N = new_bert->model.hparams.n_max_tokens;
new_bert->mem_per_input = 1.9 * (new_bert->mem_per_token * N); // add 10% to account for ggml object overhead
new_bert->mem_per_input = 2.2 * (new_bert->mem_per_token * N); // add 10% to account for ggml object overhead
}
#if defined(DEBUG_BERT)
@ -987,6 +983,9 @@ std::vector<float> Bert::embedding(const std::string &text)
}
std::transform(embeddingsSum.begin(), embeddingsSum.end(), embeddingsSum.begin(), [embeddingsSumTotal](float num){ return num / embeddingsSumTotal; });
double magnitude = std::sqrt(std::inner_product(embeddingsSum.begin(), embeddingsSum.end(), embeddingsSum.begin(), 0.0));
for (auto &value : embeddingsSum)
value /= magnitude;
std::vector<float> finalEmbeddings(embeddingsSum.begin(), embeddingsSum.end());
return finalEmbeddings;
}

View File

@ -1,8 +1,7 @@
# GPT4All Python API
# GPT4All Python Generation API
The `GPT4All` python package provides bindings to our C/C++ model backend libraries.
The source code and local build instructions can be found [here](https://github.com/nomic-ai/gpt4all/tree/main/gpt4all-bindings/python).
## Quickstart
```bash
@ -109,5 +108,5 @@ with model.chat_session():
print(model.current_chat_session)
```
### API documentation
::: gpt4all.gpt4all.GPT4All

View File

@ -0,0 +1,35 @@
# Embeddings
GPT4All supports generating high quality embeddings of arbitrary length documents of text using a CPU optimized contrastively trained [Sentence Transformer](https://www.sbert.net/). These embeddings are comparable in quality for many tasks with OpenAI.
## Quickstart
```bash
pip install gpt4all
```
### Generating embeddings
The embedding model will automatically be downloaded if not installed.
=== "Embed4All Example"
``` py
from gpt4all import GPT4All, Embed4All
text = 'The quick brown fox jumps over the lazy dog'
embedder = Embed4All()
output = embedder.embed(text)
print(output)
```
=== "Output"
```
[0.034696947783231735, -0.07192722707986832, 0.06923297047615051, ...]
```
### Speed of embedding generation
The following table lists the generation speed for text document captured on an Intel i913900HX CPU with DDR5 5600 running with 8 threads under stable load.
| Tokens | 128 | 512 | 2048 | 8129 | 16,384 |
| --------------- | ---- | ---- | ---- | ---- | ---- |
| Wall time (s) | .02 | .08 | .24 | .96 | 1.9 |
| Tokens / Second | 6508 | 6431 | 8622 | 8509 | 8369 |
### API documentation
::: gpt4all.gpt4all.Embed4All

View File

@ -1,2 +1,2 @@
from .gpt4all import GPT4All, embed # noqa
from .gpt4all import GPT4All, Embed4All # noqa
from .pyllmodel import LLModel # noqa

View File

@ -15,20 +15,36 @@ from . import pyllmodel
# TODO: move to config
DEFAULT_MODEL_DIRECTORY = os.path.join(str(Path.home()), ".cache", "gpt4all").replace("\\", "\\\\")
def embed(
text: str
) -> list[float]:
class Embed4All:
"""
Generate an embedding for all GPT4All.
Args:
text: The text document to generate an embedding for.
Returns:
An embedding of your document of text.
Python class that handles embeddings for GPT4All.
"""
model = GPT4All(model_name='ggml-all-MiniLM-L6-v2-f16.bin')
return model.model.generate_embedding(text)
def __init__(
self,
n_threads: Optional[int] = None,
):
"""
Constructor
Args:
n_threads: number of CPU threads used by GPT4All. Default is None, then the number of threads are determined automatically.
"""
self.gpt4all = GPT4All(model_name='ggml-all-MiniLM-L6-v2-f16.bin', n_threads=n_threads)
def embed(
self,
text: str
) -> list[float]:
"""
Generate an embedding.
Args:
text: The text document to generate an embedding for.
Returns:
An embedding of your document of text.
"""
return self.gpt4all.model.generate_embedding(text)
class GPT4All:
"""
@ -53,7 +69,7 @@ class GPT4All:
model_type: Model architecture. This argument currently does not have any functionality and is just used as
descriptive identifier for user. Default is None.
allow_download: Allow API to download models from gpt4all.io. Default is True.
n_threads: number of CPU threads used by GPT4All. Default is None, than the number of threads are determined automatically.
n_threads: number of CPU threads used by GPT4All. Default is None, then the number of threads are determined automatically.
"""
self.model_type = model_type
self.model = pyllmodel.LLModel()

View File

@ -154,10 +154,11 @@ class LLModel:
self.model = None
self.model_name = None
self.context = None
self.llmodel_lib = llmodel
def __del__(self):
if self.model is not None:
llmodel.llmodel_model_destroy(self.model)
self.llmodel_lib.llmodel_model_destroy(self.model)
def memory_needed(self, model_path: str) -> int:
model_path_enc = model_path.encode("utf-8")
@ -253,7 +254,7 @@ class LLModel:
embedding_size = ctypes.c_size_t()
c_text = ctypes.c_char_p(text.encode('utf-8'))
embedding_ptr = llmodel.llmodel_embedding(self.model, c_text, ctypes.byref(embedding_size))
embedding_array = ctypes.cast(embedding_ptr, ctypes.POINTER(ctypes.c_float * embedding_size.value)).contents
embedding_array = [embedding_ptr[i] for i in range(embedding_size.value)]
llmodel.llmodel_free_embedding(embedding_ptr)
return list(embedding_array)

View File

@ -0,0 +1,18 @@
import sys
from io import StringIO
from gpt4all import GPT4All, Embed4All
import time
def time_embedding(i, embedder):
text = 'foo bar ' * i
start_time = time.time()
output = embedder.embed(text)
end_time = time.time()
elapsed_time = end_time - start_time
print(f"Time report: {2 * i / elapsed_time} tokens/second with {2 * i} tokens taking {elapsed_time} seconds")
if __name__ == "__main__":
embedder = Embed4All(n_threads=8)
for i in [2**n for n in range(6, 14)]:
time_embedding(i, embedder)

File diff suppressed because one or more lines are too long

View File

@ -10,7 +10,9 @@ use_directory_urls: false
nav:
- 'index.md'
- 'Bindings':
- 'GPT4All in Python': 'gpt4all_python.md'
- 'GPT4All in Python':
- 'Generation': 'gpt4all_python.md'
- 'Embedding': 'gpt4all_python_embedding.md'
- 'GPT4All Chat Client': 'gpt4all_chat.md'
- 'gpt4all_cli.md'
# - 'Tutorials':
@ -68,4 +70,4 @@ plugins:
#- mkdocs-jupyter:
# ignore_h1_titles: True
# show_input: True
# show_input: True

View File

@ -61,7 +61,7 @@ copy_prebuilt_C_lib(SRC_CLIB_DIRECtORY,
setup(
name=package_name,
version="1.0.4",
version="1.0.6",
description="Python bindings for GPT4All",
author="Richard Guo",
author_email="richard@nomic.ai",

View File

@ -18,6 +18,9 @@ enum Language {
Go,
Json,
Csharp,
Latex,
Html,
Php
};
static QColor keywordColor = "#2e95d3"; // blue
@ -33,6 +36,11 @@ static QColor commandColor = functionCallColor;
static QColor variableColor = numberColor;
static QColor keyColor = functionColor;
static QColor valueColor = stringColor;
static QColor parameterColor = stringColor;
static QColor attributeNameColor = numberColor;
static QColor attributeValueColor = stringColor;
static QColor specialCharacterColor = functionColor;
static QColor doctypeColor = commentColor;
static Language stringToLanguage(const QString &language)
{
@ -62,6 +70,12 @@ static Language stringToLanguage(const QString &language)
return Go;
if (language == "json")
return Json;
if (language == "latex")
return Latex;
if (language == "html")
return Html;
if (language == "php")
return Php;
return None;
}
@ -561,6 +575,135 @@ static QVector<HighlightingRule> bashHighlightingRules()
return highlightingRules;
}
static QVector<HighlightingRule> latexHighlightingRules()
{
static QVector<HighlightingRule> highlightingRules;
if (highlightingRules.isEmpty()) {
HighlightingRule rule;
QTextCharFormat commandFormat;
commandFormat.setForeground(commandColor); // commandColor needs to be set to your liking
rule.pattern = QRegularExpression("\\\\[A-Za-z]+"); // Pattern for LaTeX commands
rule.format = commandFormat;
highlightingRules.append(rule);
QTextCharFormat commentFormat;
commentFormat.setForeground(commentColor); // commentColor needs to be set to your liking
rule.pattern = QRegularExpression("%[^\n]*"); // Pattern for LaTeX comments
rule.format = commentFormat;
highlightingRules.append(rule);
}
return highlightingRules;
}
static QVector<HighlightingRule> htmlHighlightingRules()
{
static QVector<HighlightingRule> highlightingRules;
if (highlightingRules.isEmpty()) {
HighlightingRule rule;
QTextCharFormat attributeNameFormat;
attributeNameFormat.setForeground(attributeNameColor);
rule.pattern = QRegularExpression("\\b(\\w+)\\s*=");
rule.format = attributeNameFormat;
highlightingRules.append(rule);
QTextCharFormat attributeValueFormat;
attributeValueFormat.setForeground(attributeValueColor);
rule.pattern = QRegularExpression("\".*?\"|'.*?'");
rule.format = attributeValueFormat;
highlightingRules.append(rule);
QTextCharFormat commentFormat;
commentFormat.setForeground(commentColor);
rule.pattern = QRegularExpression("<!--.*?-->");
rule.format = commentFormat;
highlightingRules.append(rule);
QTextCharFormat specialCharacterFormat;
specialCharacterFormat.setForeground(specialCharacterColor);
rule.pattern = QRegularExpression("&[a-zA-Z0-9#]*;");
rule.format = specialCharacterFormat;
highlightingRules.append(rule);
QTextCharFormat doctypeFormat;
doctypeFormat.setForeground(doctypeColor);
rule.pattern = QRegularExpression("<!DOCTYPE.*?>");
rule.format = doctypeFormat;
highlightingRules.append(rule);
}
return highlightingRules;
}
static QVector<HighlightingRule> phpHighlightingRules()
{
static QVector<HighlightingRule> highlightingRules;
if (highlightingRules.isEmpty()) {
HighlightingRule rule;
QTextCharFormat functionCallFormat;
functionCallFormat.setForeground(functionCallColor);
rule.pattern = QRegularExpression("\\b(\\w+)\\s*(?=\\()");
rule.format = functionCallFormat;
highlightingRules.append(rule);
QTextCharFormat functionFormat;
functionFormat.setForeground(functionColor);
rule.pattern = QRegularExpression("\\bfunction\\s+(\\w+)\\b");
rule.format = functionFormat;
highlightingRules.append(rule);
QTextCharFormat numberFormat;
numberFormat.setForeground(numberColor);
rule.pattern = QRegularExpression("\\b[0-9]*\\.?[0-9]+\\b");
rule.format = numberFormat;
highlightingRules.append(rule);
QTextCharFormat keywordFormat;
keywordFormat.setForeground(keywordColor);
QStringList keywordPatterns = {
"\\bif\\b", "\\belse\\b", "\\belseif\\b", "\\bwhile\\b", "\\bfor\\b",
"\\bforeach\\b", "\\breturn\\b", "\\bprint\\b", "\\binclude\\b", "\\brequire\\b",
"\\binclude_once\\b", "\\brequire_once\\b", "\\btry\\b", "\\bcatch\\b",
"\\bfinally\\b", "\\bcontinue\\b", "\\bbreak\\b", "\\bclass\\b", "\\bfunction\\b",
"\\bnew\\b", "\\bthrow\\b", "\\barray\\b", "\\bpublic\\b", "\\bprivate\\b",
"\\bprotected\\b", "\\bstatic\\b", "\\bglobal\\b", "\\bisset\\b", "\\bunset\\b",
"\\bnull\\b", "\\btrue\\b", "\\bfalse\\b"
};
for (const QString &pattern : keywordPatterns) {
rule.pattern = QRegularExpression(pattern);
rule.format = keywordFormat;
highlightingRules.append(rule);
}
QTextCharFormat stringFormat;
stringFormat.setForeground(stringColor);
rule.pattern = QRegularExpression("\".*?\"");
rule.format = stringFormat;
highlightingRules.append(rule);
rule.pattern = QRegularExpression("\'.*?\'");
rule.format = stringFormat;
highlightingRules.append(rule);
QTextCharFormat commentFormat;
commentFormat.setForeground(commentColor);
rule.pattern = QRegularExpression("//[^\n]*");
rule.format = commentFormat;
highlightingRules.append(rule);
rule.pattern = QRegularExpression("/\\*.*?\\*/");
rule.format = commentFormat;
highlightingRules.append(rule);
}
return highlightingRules;
}
static QVector<HighlightingRule> jsonHighlightingRules()
{
static QVector<HighlightingRule> highlightingRules;
@ -616,6 +759,12 @@ void SyntaxHighlighter::highlightBlock(const QString &text)
rules = javaHighlightingRules();
else if (block.userState() == Json)
rules = jsonHighlightingRules();
else if (block.userState() == Latex)
rules = latexHighlightingRules();
else if (block.userState() == Html)
rules = htmlHighlightingRules();
else if (block.userState() == Php)
rules = phpHighlightingRules();
for (const HighlightingRule &rule : qAsConst(rules)) {
QRegularExpressionMatchIterator matchIterator = rule.pattern.globalMatch(text);
@ -821,7 +970,10 @@ void ResponseText::handleCodeBlocks()
|| firstWord == "java"
|| firstWord == "go"
|| firstWord == "golang"
|| firstWord == "json") {
|| firstWord == "json"
|| firstWord == "latex"
|| firstWord == "html"
|| firstWord == "php") {
codeLanguage = firstWord;
capturedText.remove(0, match.captured(0).length());
}