Mirror of https://github.com/nomic-ai/gpt4all.git (synced 2025-07-17 00:01:54 -04:00)

Compare commits: 108d950874 ... af6fe5fbb5 (3 commits)
Commits in range: af6fe5fbb5, ca8baa294b, 889c8d1758
gpt4all-api/gpt4all_api/app/api_v1/routes/embeddings.py (new file, 65 lines)
```python
from typing import List, Union

from fastapi import APIRouter
from api_v1.settings import settings
from gpt4all import Embed4All
from pydantic import BaseModel, Field

### This should follow https://github.com/openai/openai-openapi/blob/master/openapi.yaml


class EmbeddingRequest(BaseModel):
    model: str = Field(
        settings.model, description="The model to generate an embedding from."
    )
    input: Union[str, List[str], List[int], List[List[int]]] = Field(
        ..., description="Input text to embed, encoded as a string or array of tokens."
    )


class EmbeddingUsage(BaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0


class Embedding(BaseModel):
    index: int = 0
    object: str = "embedding"
    embedding: List[float]


class EmbeddingResponse(BaseModel):
    object: str = "list"
    model: str
    data: List[Embedding]
    usage: EmbeddingUsage


router = APIRouter(prefix="/embeddings", tags=["Embedding Endpoints"])

embedder = Embed4All()


def get_embedding(data: EmbeddingRequest) -> EmbeddingResponse:
    """
    Calculates the embedding for the given input using a specified model.

    Args:
        data (EmbeddingRequest): An EmbeddingRequest object containing the input data
        and model name.

    Returns:
        EmbeddingResponse: An EmbeddingResponse object encapsulating the calculated embedding,
        usage info, and the model name.
    """
    embedding = embedder.embed(data.input)
    return EmbeddingResponse(
        data=[Embedding(embedding=embedding)], usage=EmbeddingUsage(), model=data.model
    )


@router.post("/", response_model=EmbeddingResponse)
def embeddings(data: EmbeddingRequest):
    """
    Creates a GPT4All embedding
    """
    return get_embedding(data)
```
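The test added in this commit (below) shows the intended client-side call. As a review note, the handler always returns a single `Embedding` with zeroed `EmbeddingUsage` counts even though `input` accepts batches, which diverges from the OpenAI schema the file cites. A minimal usage sketch against a locally running server, using the 0.x `openai` client as the tests do (the `api_key` value is a placeholder; I am assuming the local server does not check it):

```python
import openai

# Point the OpenAI client at the local GPT4All API server (same base URL as the tests).
openai.api_base = "http://localhost:4891/v1"
openai.api_key = "not-needed"  # placeholder; assuming the local server ignores auth

response = openai.Embedding.create(
    model="ggml-all-MiniLM-L6-v2-f16.bin",
    input="Who is Michael Jordan?",
)
embedding = response["data"][0]["embedding"]
print(len(embedding))  # dimensionality of the returned vector
```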
Changes to the OpenAI-compatibility test suite:

```diff
@@ -1,6 +1,8 @@
 """
 Use the OpenAI python API to test gpt4all models.
 """
+from typing import List, get_args
+
 import openai
 
 openai.api_base = "http://localhost:4891/v1"
@@ -43,3 +45,15 @@ def test_batched_completion():
     )
     assert len(response['choices'][0]['text']) > len(prompt)
     assert len(response['choices']) == 3
+
+
+def test_embedding():
+    model = "ggml-all-MiniLM-L6-v2-f16.bin"
+    prompt = "Who is Michael Jordan?"
+    response = openai.Embedding.create(model=model, input=prompt)
+    output = response["data"][0]["embedding"]
+    args = get_args(List[float])
+
+    assert response["model"] == model
+    assert isinstance(output, list)
+    assert all(isinstance(x, args) for x in output)
```
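The test exercises the HTTP layer; the route itself delegates to `Embed4All` from the `gpt4all` Python bindings. A small sketch of that direct path, assuming the bindings' default model download behavior (the exact vector length depends on the embedding model):

```python
from gpt4all import Embed4All

# Embed4All wraps a local sentence-embedding model; the gpt4all bindings
# fetch the default weights on first use.
embedder = Embed4All()
vector = embedder.embed("Who is Michael Jordan?")

assert isinstance(vector, list)
assert all(isinstance(x, float) for x in vector)
print(len(vector))  # embedding dimensionality, model-dependent
```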
Changes to the backend FAQ:

```diff
@@ -2,7 +2,7 @@
 
 ## What models are supported by the GPT4All ecosystem?
 
-Currently, there are five different model architectures that are supported:
+Currently, there are six different model architectures that are supported:
 
 1. GPT-J - Based off of the GPT-J architecture with examples found [here](https://huggingface.co/EleutherAI/gpt-j-6b)
 2. LLaMA - Based off of the LLaMA architecture with examples found [here](https://huggingface.co/models?sort=downloads&search=llama)
@@ -13,21 +13,31 @@ Currently, there are five different model architectures that are supported:
 
 ## Why so many different architectures? What differentiates them?
 
-One of the major differences is license. Currently, the LLAMA based models are subject to a non-commercial license, whereas the GPTJ and MPT base models allow commercial usage. In the early advent of the recent explosion of activity in open source local models, the llama models have generally been seen as performing better, but that is changing quickly. Every week - even every day! - new models are released with some of the GPTJ and MPT models competitive in performance/quality with LLAMA. What's more, there are some very nice architectural innovations with the MPT models that could lead to new performance/quality gains.
+One of the major differences is license. Currently, the LLaMA based models are subject to a non-commercial license, whereas the GPTJ and MPT base
+models allow commercial usage. However, its successor [Llama 2 is commercially licensable](https://ai.meta.com/llama/license/), too. In the early
+advent of the recent explosion of activity in open source local models, the LLaMA models have generally been seen as performing better, but that is
+changing quickly. Every week - even every day! - new models are released with some of the GPTJ and MPT models competitive in performance/quality with
+LLaMA. What's more, there are some very nice architectural innovations with the MPT models that could lead to new performance/quality gains.
 
 ## How does GPT4All make these models available for CPU inference?
 
-By leveraging the ggml library written by Georgi Gerganov and a growing community of developers. There are currently multiple different versions of this library. The original GitHub repo can be found [here](https://github.com/ggerganov/ggml), but the developer of the library has also created a LLAMA based version [here](https://github.com/ggerganov/llama.cpp). Currently, this backend is using the latter as a submodule.
+By leveraging the ggml library written by Georgi Gerganov and a growing community of developers. There are currently multiple different versions of
+this library. The original GitHub repo can be found [here](https://github.com/ggerganov/ggml), but the developer of the library has also created a
+LLaMA based version [here](https://github.com/ggerganov/llama.cpp). Currently, this backend is using the latter as a submodule.
 
 ## Does that mean GPT4All is compatible with all llama.cpp models and vice versa?
 
 Yes!
 
-The upstream [llama.cpp](https://github.com/ggerganov/llama.cpp) project has introduced several [compatibility breaking](https://github.com/ggerganov/llama.cpp/commit/b9fd7eee57df101d4a3e3eabc9fd6c2cb13c9ca1) quantization methods recently. This is a breaking change that renders all previous models (including the ones that GPT4All uses) inoperative with newer versions of llama.cpp since that change.
+The upstream [llama.cpp](https://github.com/ggerganov/llama.cpp) project has introduced several [compatibility breaking] quantization methods recently.
+This is a breaking change that renders all previous models (including the ones that GPT4All uses) inoperative with newer versions of llama.cpp since
+that change.
 
 Fortunately, we have engineered a submoduling system allowing us to dynamically load different versions of the underlying library so that
 GPT4All just works.
 
+[compatibility breaking]: https://github.com/ggerganov/llama.cpp/commit/b9fd7eee57df101d4a3e3eabc9fd6c2cb13c9ca1
+
 ## What are the system requirements?
 
 Your CPU needs to support [AVX or AVX2 instructions](https://en.wikipedia.org/wiki/Advanced_Vector_Extensions) and you need enough RAM to load a model into memory.
```
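An aside on the "submoduling system" mentioned in this hunk: GPT4All's real loader lives in the native backend, but the idea can be pictured as dispatch on the model file's format marker. A loose Python sketch of the concept only; the library names, magic values, and helper below are invented for illustration and are not GPT4All's actual loader:

```python
import ctypes

# Hypothetical map from a model file's leading magic bytes to the backend
# build that understands that on-disk format (exact magic values and byte
# order are glossed over here).
BACKENDS = {
    b"ggml": "libllmodel-legacy.so",  # pre-quantization-change formats
    b"ggjt": "libllmodel-latest.so",  # post-change formats
}

def load_backend_for(model_path: str) -> ctypes.CDLL:
    """Dynamically load the library version matching the model's format."""
    with open(model_path, "rb") as f:
        magic = f.read(4)
    return ctypes.CDLL(BACKENDS[magic])
```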
Later in the same FAQ:

```diff
@@ -39,7 +49,7 @@ In newer versions of llama.cpp, there has been some added support for NVIDIA GPU
 ## Ok, so bottom line... how do I make my model on Hugging Face compatible with GPT4All ecosystem right now?
 
 1. Check to make sure the Hugging Face model is available in one of our three supported architectures
-2. If it is, then you can use the conversion script inside of our pinned llama.cpp submodule for GPTJ and LLAMA based models
+2. If it is, then you can use the conversion script inside of our pinned llama.cpp submodule for GPTJ and LLaMA based models
 3. Or if your model is an MPT model you can use the conversion script located directly in this backend directory under the scripts subdirectory
 
 ## Language Bindings
```
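Step 1 of this checklist can be sanity-checked programmatically by reading the model's `config.json` from Hugging Face. A hypothetical sketch using `huggingface_hub`; the `SUPPORTED` set and the helper are illustrative, not part of GPT4All:

```python
import json

from huggingface_hub import hf_hub_download

# Transformer architectures corresponding to the convertible model families
# named in the FAQ (illustrative mapping).
SUPPORTED = {"GPTJForCausalLM", "LlamaForCausalLM", "MPTForCausalLM"}

def is_convertible(repo_id: str) -> bool:
    """Return True if the repo declares an architecture we can convert."""
    config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
    with open(config_path) as f:
        config = json.load(f)
    return any(arch in SUPPORTED for arch in config.get("architectures", []))

print(is_convertible("EleutherAI/gpt-j-6b"))  # the GPT-J example linked above
```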
Changes to the roadmap (TODO) list:

```diff
@@ -40,6 +40,7 @@ One click installers for macOS, Linux, and Windows at https://gpt4all.io
 * Syntax highlighting support for programming languages, etc.
 * REST API with a built-in webserver in the chat gui itself with a headless operation mode as well
 * Advanced settings for changing temperature, topk, etc. (DONE)
+* Improve the accessibility of the installer for screen reader users
 * YOUR IDEA HERE
 
 ## Building and running
```