Mirror of https://github.com/nomic-ai/gpt4all.git (synced 2025-08-16 00:04:55 -04:00)

Compare commits: 2d02c65177 ... b4dbbd1485 (4 commits)

Commits in this range (SHA1):

- b4dbbd1485
- 5f0aaf8bdb
- 4974ae917c
- 63849d9afc
Project description (markdown), adding the AVX/AVX2 CPU requirement:

@@ -29,7 +29,7 @@ Run on an M1 macOS Device (not sped up!)
 </p>

 ## GPT4All: An ecosystem of open-source on-edge large language models.

-GPT4All is an ecosystem to train and deploy **powerful** and **customized** large language models that run locally on consumer grade CPUs.
+GPT4All is an ecosystem to train and deploy **powerful** and **customized** large language models that run locally on consumer grade CPUs. Note that your CPU needs to support [AVX or AVX2 instructions](https://en.wikipedia.org/wiki/Advanced_Vector_Extensions).

 Learn more in the [documentation](https://docs.gpt4all.io).
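The AVX requirement above can be checked before installing. A minimal sketch, assuming a Linux host where CPU flags are exposed through /proc/cpuinfo (other platforms need a different probe):

```python
# Rough AVX/AVX2 capability check; assumes Linux and a readable /proc/cpuinfo.
def cpu_flags() -> set:
    with open("/proc/cpuinfo") as f:
        for line in f:
            if line.startswith("flags"):
                return set(line.split(":", 1)[1].split())
    return set()

flags = cpu_flags()
print("AVX supported: ", "avx" in flags)
print("AVX2 supported:", "avx2" in flags)
```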
Python bindings documentation (markdown), dropping the manual "Streaming and Chat Sessions" example:

@@ -91,22 +91,4 @@ To interact with GPT4All responses as the model generates, use the `streaming =
 [' Paris', ' is', ' a', ' city', ' that', ' has', ' been', ' a', ' major', ' cultural', ' and', ' economic', ' center', ' for', ' over', ' ', '2', ',', '0', '0']
 ```

-#### Streaming and Chat Sessions
-When streaming tokens in a chat session, you must manually handle collection and updating of the chat history.
-
-```python
-from gpt4all import GPT4All
-model = GPT4All("orca-mini-3b.ggmlv3.q4_0.bin")
-
-with model.chat_session():
-    tokens = list(model.generate(prompt='hello', top_k=1, streaming=True))
-    model.current_chat_session.append({'role': 'assistant', 'content': ''.join(tokens)})
-
-    tokens = list(model.generate(prompt='write me a poem about dogs', top_k=1, streaming=True))
-    model.current_chat_session.append({'role': 'assistant', 'content': ''.join(tokens)})
-
-    print(model.current_chat_session)
-```
-
-### API documentation
 ::: gpt4all.gpt4all.GPT4All
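The removed example is superseded by the callback-based generate() shown further down: when a chat session is active, generate() now appends the assistant message to current_chat_session itself, so streaming no longer needs manual bookkeeping. A minimal sketch of the new behaviour, reusing the model name from the removed example:

```python
from gpt4all import GPT4All

model = GPT4All("orca-mini-3b.ggmlv3.q4_0.bin")

with model.chat_session():
    # Consuming the generator drives the response callback, which fills in the
    # assistant entry of model.current_chat_session as tokens arrive.
    for token in model.generate(prompt='hello', top_k=1, streaming=True):
        print(token, end="")
    print(model.current_chat_session)  # history already includes the reply
```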
Python bindings, GPT4All/Embed4All module:

@@ -5,7 +5,7 @@ import os
 import time
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Dict, Iterable, List, Union, Optional
+from typing import Any, Dict, Iterable, List, Union, Optional

 import requests
 from tqdm import tqdm
@@ -13,7 +13,17 @@ from tqdm import tqdm
 from . import pyllmodel

 # TODO: move to config
-DEFAULT_MODEL_DIRECTORY = os.path.join(str(Path.home()), ".cache", "gpt4all").replace("\\", "\\\\")
+DEFAULT_MODEL_DIRECTORY = os.path.join(str(Path.home()), ".cache", "gpt4all").replace(
+    "\\", "\\\\"
+)
+
+DEFAULT_MODEL_CONFIG = {
+    "systemPrompt": "",
+    "promptTemplate": "### Human: \n{0}\n### Assistant:\n",
+}
+
+ConfigType = Dict[str,str]
+MessageType = Dict[str, str]

 class Embed4All:
     """
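The new promptTemplate uses Python str.format-style placeholders, so a user message is injected with a plain format call. A small illustrative snippet; the template string is the default from the diff above, and the variable name is just for the example:

```python
DEFAULT_PROMPT_TEMPLATE = "### Human: \n{0}\n### Assistant:\n"

# "{0}" is replaced by the user's message when the prompt is built.
print(DEFAULT_PROMPT_TEMPLATE.format("What is the capital of France?"))
```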
@@ -34,7 +44,7 @@ class Embed4All:
     def embed(
         self,
         text: str
-    ) -> list[float]:
+    ) -> List[float]:
         """
         Generate an embedding.

@@ -74,17 +84,20 @@ class GPT4All:
         self.model_type = model_type
         self.model = pyllmodel.LLModel()
         # Retrieve model and download if allowed
-        model_dest = self.retrieve_model(model_name, model_path=model_path, allow_download=allow_download)
-        self.model.load_model(model_dest)
+        self.config: ConfigType = self.retrieve_model(
+            model_name, model_path=model_path, allow_download=allow_download
+        )
+        self.model.load_model(self.config["path"])
         # Set n_threads
         if n_threads is not None:
             self.model.set_thread_count(n_threads)

-        self._is_chat_session_activated = False
-        self.current_chat_session = []
+        self._is_chat_session_activated: bool = False
+        self.current_chat_session: List[MessageType] = empty_chat_session()
+        self._current_prompt_template: str = "{0}"

     @staticmethod
-    def list_models() -> Dict:
+    def list_models() -> List[ConfigType]:
         """
         Fetch model list from https://gpt4all.io/models/models.json.

@@ -95,8 +108,11 @@ class GPT4All:

     @staticmethod
     def retrieve_model(
-        model_name: str, model_path: Optional[str] = None, allow_download: bool = True, verbose: bool = True
-    ) -> str:
+        model_name: str,
+        model_path: Optional[str] = None,
+        allow_download: bool = True,
+        verbose: bool = True,
+    ) -> ConfigType:
         """
         Find model file, and if it doesn't exist, download the model.

@@ -108,11 +124,25 @@ class GPT4All:
             verbose: If True (default), print debug messages.

         Returns:
-            Model file destination.
+            Model config.
         """

         model_filename = append_bin_suffix_if_missing(model_name)

+        # get the config for the model
+        config: ConfigType = DEFAULT_MODEL_CONFIG
+        if allow_download:
+            available_models = GPT4All.list_models()
+
+            for m in available_models:
+                if model_filename == m["filename"]:
+                    config.update(m)
+                    config["systemPrompt"] = config["systemPrompt"].strip()
+                    config["promptTemplate"] = config["promptTemplate"].replace(
+                        "%1", "{0}", 1
+                    )  # change to Python-style formatting
+                    break
+
         # Validate download directory
         if model_path is None:
             try:

@@ -131,31 +161,34 @@ class GPT4All:

         model_dest = os.path.join(model_path, model_filename).replace("\\", "\\\\")
         if os.path.exists(model_dest):
+            config.pop("url", None)
+            config["path"] = model_dest
             if verbose:
                 print("Found model file at ", model_dest)
-            return model_dest

         # If model file does not exist, download
         elif allow_download:
             # Make sure valid model filename before attempting download
-            available_models = GPT4All.list_models()

-            selected_model = None
-            for m in available_models:
-                if model_filename == m['filename']:
-                    selected_model = m
-                    break
-
-            if selected_model is None:
+            if "url" not in config:
                 raise ValueError(f"Model filename not in model list: {model_filename}")
-            url = selected_model.pop('url', None)
+            url = config.pop("url", None)

-            return GPT4All.download_model(model_filename, model_path, verbose=verbose, url=url)
+            config["path"] = GPT4All.download_model(
+                model_filename, model_path, verbose=verbose, url=url
+            )
         else:
             raise ValueError("Failed to retrieve model")

+        return config
+
     @staticmethod
-    def download_model(model_filename: str, model_path: str, verbose: bool = True, url: Optional[str] = None) -> str:
+    def download_model(
+        model_filename: str,
+        model_path: str,
+        verbose: bool = True,
+        url: Optional[str] = None,
+    ) -> str:
         """
         Download model from https://gpt4all.io.
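retrieve_model now returns the whole model config rather than just a file path. A minimal sketch of how the returned dict is consumed, mirroring what __init__ does above:

```python
from gpt4all import GPT4All

config = GPT4All.retrieve_model("orca-mini-3b.ggmlv3.q4_0.bin", allow_download=True)
print(config["path"])            # local destination of the model file
print(config["promptTemplate"])  # per-model template if listed, else the default
```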
@@ -191,7 +224,7 @@ class GPT4All:
         except Exception:
             if os.path.exists(download_path):
                 if verbose:
-                    print('Cleaning up the interrupted download...')
+                    print("Cleaning up the interrupted download...")
                 os.remove(download_path)
             raise

@@ -212,13 +245,14 @@ class GPT4All:
         max_tokens: int = 200,
         temp: float = 0.7,
         top_k: int = 40,
-        top_p: float = 0.1,
+        top_p: float = 0.4,
         repeat_penalty: float = 1.18,
         repeat_last_n: int = 64,
         n_batch: int = 8,
         n_predict: Optional[int] = None,
         streaming: bool = False,
-    ) -> Union[str, Iterable]:
+        callback: pyllmodel.ResponseCallbackType = pyllmodel.empty_response_callback,
+    ) -> Union[str, Iterable[str]]:
         """
         Generate outputs from any GPT4All model.

@@ -233,12 +267,14 @@ class GPT4All:
             n_batch: Number of prompt tokens processed in parallel. Larger values decrease latency but increase resource requirements.
             n_predict: Equivalent to max_tokens, exists for backwards compatibility.
             streaming: If True, this method will instead return a generator that yields tokens as the model generates them.
+            callback: A function with arguments token_id:int and response:str, which receives the tokens from the model as they are generated and stops the generation by returning False.

         Returns:
             Either the entire completion or a generator that yields the completion token by token.
         """
-        generate_kwargs = dict(
-            prompt=prompt,
+
+        # Preparing the model request
+        generate_kwargs: Dict[str, Any] = dict(
             temp=temp,
             top_k=top_k,
             top_p=top_p,
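The new callback parameter lets callers observe or cut off generation; returning False from it stops the model. A minimal sketch, assuming an already constructed GPT4All instance named model:

```python
def stop_at_paragraph(token_id: int, response: str) -> bool:
    # Returning False asks the model to stop generating further tokens.
    return "\n\n" not in response

text = model.generate("Write one short sentence about dogs.", callback=stop_at_paragraph)
print(text)
```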
@@ -249,42 +285,87 @@ class GPT4All:
         )

         if self._is_chat_session_activated:
+            generate_kwargs["reset_context"] = len(self.current_chat_session) == 1  # check if there is only one message, i.e. system prompt
             self.current_chat_session.append({"role": "user", "content": prompt})
-            generate_kwargs['prompt'] = self._format_chat_prompt_template(messages=self.current_chat_session[-1:])
-            generate_kwargs['reset_context'] = len(self.current_chat_session) == 1
+
+            prompt = self._format_chat_prompt_template(
+                messages = self.current_chat_session[-1:],
+                default_prompt_header = self.current_chat_session[0]["content"] if generate_kwargs["reset_context"] else "",
+            )
         else:
-            generate_kwargs['reset_context'] = True
+            generate_kwargs["reset_context"] = True

-        if streaming:
-            return self.model.prompt_model_streaming(**generate_kwargs)
-
-        output = self.model.prompt_model(**generate_kwargs)
+        # Prepare the callback, process the model response
+        output_collector: List[MessageType]
+        output_collector = [{"content": ""}]  # placeholder for the self.current_chat_session if chat session is not activated

         if self._is_chat_session_activated:
-            self.current_chat_session.append({"role": "assistant", "content": output})
+            self.current_chat_session.append({"role": "assistant", "content": ""})
+            output_collector = self.current_chat_session

-        return output
+        def _callback_wrapper(
+            callback: pyllmodel.ResponseCallbackType,
+            output_collector: List[MessageType],
+        ) -> pyllmodel.ResponseCallbackType:
+
+            def _callback(token_id: int, response: str) -> bool:
+                nonlocal callback, output_collector
+
+                output_collector[-1]["content"] += response
+
+                return callback(token_id, response)
+
+            return _callback
+
+        # Send the request to the model
+        if streaming:
+            return self.model.prompt_model_streaming(
+                prompt=prompt,
+                callback=_callback_wrapper(callback, output_collector),
+                **generate_kwargs,
+            )
+
+        self.model.prompt_model(
+            prompt=prompt,
+            callback=_callback_wrapper(callback, output_collector),
+            **generate_kwargs,
+        )
+
+        return output_collector[-1]["content"]

     @contextmanager
-    def chat_session(self):
-        '''
+    def chat_session(
+        self,
+        system_prompt: str = "",
+        prompt_template: str = "",
+    ):
+        """
         Context manager to hold an inference optimized chat session with a GPT4All model.
-        '''
+
+        Args:
+            system_prompt: An initial instruction for the model.
+            prompt_template: Template for the prompts with {0} being replaced by the user message.
+        """
         # Code to acquire resource, e.g.:
         self._is_chat_session_activated = True
-        self.current_chat_session = []
+        self.current_chat_session = empty_chat_session(system_prompt or self.config["systemPrompt"])
+        self._current_prompt_template = prompt_template or self.config["promptTemplate"]
         try:
             yield self
         finally:
             # Code to release resource, e.g.:
             self._is_chat_session_activated = False
-            self.current_chat_session = []
+            self.current_chat_session = empty_chat_session()
+            self._current_prompt_template = "{0}"

     def _format_chat_prompt_template(
-        self, messages: List[Dict], default_prompt_header=True, default_prompt_footer=True
+        self,
+        messages: List[MessageType],
+        default_prompt_header: str = "",
+        default_prompt_footer: str = "",
     ) -> str:
         """
-        Helper method for building a prompt using template from list of messages.
+        Helper method for building a prompt from list of messages using the self._current_prompt_template as a template for each message.

         Args:
             messages: List of dictionaries. Each dictionary should have a "role" key
|
|||||||
Returns:
|
Returns:
|
||||||
Formatted prompt.
|
Formatted prompt.
|
||||||
"""
|
"""
|
||||||
full_prompt = ""
|
|
||||||
|
if isinstance(default_prompt_header, bool):
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
warnings.warn(
|
||||||
|
"Using True/False for the 'default_prompt_header' is deprecated. Use a string instead.",
|
||||||
|
DeprecationWarning,
|
||||||
|
)
|
||||||
|
default_prompt_header = ""
|
||||||
|
|
||||||
|
if isinstance(default_prompt_footer, bool):
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
warnings.warn(
|
||||||
|
"Using True/False for the 'default_prompt_footer' is deprecated. Use a string instead.",
|
||||||
|
DeprecationWarning,
|
||||||
|
)
|
||||||
|
default_prompt_footer = ""
|
||||||
|
|
||||||
|
full_prompt = default_prompt_header + "\n\n" if default_prompt_header != "" else ""
|
||||||
|
|
||||||
for message in messages:
|
for message in messages:
|
||||||
if message["role"] == "user":
|
if message["role"] == "user":
|
||||||
user_message = "### Human: \n" + message["content"] + "\n### Assistant:\n"
|
user_message = self._current_prompt_template.format(message["content"])
|
||||||
full_prompt += user_message
|
full_prompt += user_message
|
||||||
if message["role"] == "assistant":
|
if message["role"] == "assistant":
|
||||||
assistant_message = message["content"] + '\n'
|
assistant_message = message["content"] + "\n"
|
||||||
full_prompt += assistant_message
|
full_prompt += assistant_message
|
||||||
|
|
||||||
|
full_prompt += "\n\n" + default_prompt_footer if default_prompt_footer != "" else ""
|
||||||
|
|
||||||
return full_prompt
|
return full_prompt
|
||||||
|
|
||||||
|
|
||||||
|
def empty_chat_session(system_prompt: str = "") -> List[MessageType]:
|
||||||
|
return [{"role": "system", "content": system_prompt}]
|
||||||
|
|
||||||
|
|
||||||
def append_bin_suffix_if_missing(model_name):
|
def append_bin_suffix_if_missing(model_name):
|
||||||
if not model_name.endswith(".bin"):
|
if not model_name.endswith(".bin"):
|
||||||
model_name += ".bin"
|
model_name += ".bin"
|
||||||
|
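A chat session's history therefore always starts with a system message, even when the system prompt is empty. A small illustrative snippet; the import path is an assumption based on the gpt4all.gpt4all module referenced by the docs:

```python
from gpt4all.gpt4all import empty_chat_session  # assumed import location

print(empty_chat_session())
# [{'role': 'system', 'content': ''}]

print(empty_chat_session("Answer in French."))
# [{'role': 'system', 'content': 'Answer in French.'}]
```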
Python bindings, LLModel C-library wrapper module:

@@ -6,26 +6,19 @@ import re
 import subprocess
 import sys
 import threading
-from typing import Iterable
+import logging
+from typing import Iterable, Callable, List

 import pkg_resources

-class DualStreamProcessor:
-    def __init__(self, stream=None):
-        self.stream = stream
-        self.output = ""
-
-    def write(self, text):
-        if self.stream is not None:
-            self.stream.write(text)
-            self.stream.flush()
-        self.output += text
+logger: logging.Logger = logging.getLogger(__name__)

 # TODO: provide a config file to make this more robust
 LLMODEL_PATH = os.path.join("llmodel_DO_NOT_MODIFY", "build").replace("\\", "\\\\")
-MODEL_LIB_PATH = str(pkg_resources.resource_filename("gpt4all", LLMODEL_PATH)).replace("\\", "\\\\")
+MODEL_LIB_PATH = str(pkg_resources.resource_filename("gpt4all", LLMODEL_PATH)).replace(
+    "\\", "\\\\"
+)

 def load_llmodel_library():

@@ -43,9 +36,9 @@ def load_llmodel_library():

     c_lib_ext = get_c_shared_lib_extension()

-    llmodel_file = "libllmodel" + '.' + c_lib_ext
+    llmodel_file = "libllmodel" + "." + c_lib_ext

-    llmodel_dir = str(pkg_resources.resource_filename('gpt4all', os.path.join(LLMODEL_PATH, llmodel_file))).replace(
+    llmodel_dir = str(pkg_resources.resource_filename("gpt4all", os.path.join(LLMODEL_PATH, llmodel_file))).replace(
         "\\", "\\\\"
     )

@@ -134,7 +127,15 @@ llmodel.llmodel_set_implementation_search_path.restype = None
 llmodel.llmodel_threadCount.argtypes = [ctypes.c_void_p]
 llmodel.llmodel_threadCount.restype = ctypes.c_int32

-llmodel.llmodel_set_implementation_search_path(MODEL_LIB_PATH.encode('utf-8'))
+llmodel.llmodel_set_implementation_search_path(MODEL_LIB_PATH.encode("utf-8"))


+ResponseCallbackType = Callable[[int, str], bool]
+RawResponseCallbackType = Callable[[int, bytes], bool]
+
+
+def empty_response_callback(token_id: int, response: str) -> bool:
+    return True
+
+
 class LLModel:
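ResponseCallbackType is the str-based callback the higher-level bindings pass down; it must return True to keep generation going. A minimal sketch of a conforming callback:

```python
def print_and_continue(token_id: int, response: str) -> bool:
    # Matches ResponseCallbackType = Callable[[int, str], bool]:
    # print each decoded token and keep generating.
    print(response, end="", flush=True)
    return True
```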
@@ -250,9 +251,10 @@ class LLModel:
     def generate_embedding(
         self,
         text: str
-    ) -> list[float]:
+    ) -> List[float]:
         if not text:
             raise ValueError("Text must not be None or empty")

         embedding_size = ctypes.c_size_t()
         c_text = ctypes.c_char_p(text.encode('utf-8'))
         embedding_ptr = llmodel.llmodel_embedding(self.model, c_text, ctypes.byref(embedding_size))

@@ -263,6 +265,7 @@ class LLModel:
     def prompt_model(
         self,
         prompt: str,
+        callback: ResponseCallbackType,
         n_predict: int = 4096,
         top_k: int = 40,
         top_p: float = 0.9,

@@ -272,8 +275,7 @@ class LLModel:
         repeat_last_n: int = 10,
         context_erase: float = 0.75,
         reset_context: bool = False,
-        streaming=False,
-    ) -> str:
+    ):
         """
         Generate response from model from a prompt.

@@ -281,26 +283,24 @@ class LLModel:
         ----------
         prompt: str
             Question, task, or conversation for model to respond to
-        streaming: bool
-            Stream response to stdout
+        callback(token_id:int, response:str): bool
+            The model sends response tokens to callback

         Returns
         -------
-        Model response str
+        None
         """

-        prompt_bytes = prompt.encode('utf-8')
+        logger.info(
+            "LLModel.prompt_model -- prompt:\n"
+            + "%s\n"
+            + "===/LLModel.prompt_model -- prompt/===",
+            prompt,
+        )
+
+        prompt_bytes = prompt.encode("utf-8")
         prompt_ptr = ctypes.c_char_p(prompt_bytes)

-        old_stdout = sys.stdout
-
-        stream_processor = DualStreamProcessor()
-
-        if streaming:
-            stream_processor.stream = sys.stdout
-
-        sys.stdout = stream_processor
-
         self._set_context(
             n_predict=n_predict,
             top_k=top_k,
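prompt_model now logs the full rendered prompt through the module logger instead of echoing tokens to stdout, so prompt inspection is opt-in. A minimal sketch of turning it on; this is standard logging configuration, nothing bindings-specific:

```python
import logging

# INFO level makes the "LLModel.prompt_model -- prompt:" records visible.
logging.basicConfig(level=logging.INFO)
```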
@@ -317,56 +317,37 @@ class LLModel:
             self.model,
             prompt_ptr,
             PromptCallback(self._prompt_callback),
-            ResponseCallback(self._response_callback),
+            ResponseCallback(self._callback_decoder(callback)),
             RecalculateCallback(self._recalculate_callback),
             self.context,
         )

-        # Revert to old stdout
-        sys.stdout = old_stdout
-        # Force new line
-        return stream_processor.output
-
     def prompt_model_streaming(
         self,
         prompt: str,
-        n_predict: int = 4096,
-        top_k: int = 40,
-        top_p: float = 0.9,
-        temp: float = 0.1,
-        n_batch: int = 8,
-        repeat_penalty: float = 1.2,
-        repeat_last_n: int = 10,
-        context_erase: float = 0.75,
-        reset_context: bool = False,
-    ) -> Iterable:
+        callback: ResponseCallbackType = empty_response_callback,
+        **kwargs
+    ) -> Iterable[str]:
         # Symbol to terminate from generator
         TERMINATING_SYMBOL = object()

         output_queue = queue.Queue()

-        prompt_bytes = prompt.encode('utf-8')
-        prompt_ptr = ctypes.c_char_p(prompt_bytes)
-
-        self._set_context(
-            n_predict=n_predict,
-            top_k=top_k,
-            top_p=top_p,
-            temp=temp,
-            n_batch=n_batch,
-            repeat_penalty=repeat_penalty,
-            repeat_last_n=repeat_last_n,
-            context_erase=context_erase,
-            reset_context=reset_context,
-        )
-
         # Put response tokens into an output queue
-        def _generator_response_callback(token_id, response):
-            output_queue.put(response.decode('utf-8', 'replace'))
-            return True
+        def _generator_callback_wrapper(callback: ResponseCallbackType) -> ResponseCallbackType:
+            def _generator_callback(token_id: int, response: str):
+                nonlocal callback
+
+                if callback(token_id, response):
+                    output_queue.put(response)
+                    return True
+
+                return False
+
+            return _generator_callback

-        def run_llmodel_prompt(model, prompt, prompt_callback, response_callback, recalculate_callback, context):
-            llmodel.llmodel_prompt(model, prompt, prompt_callback, response_callback, recalculate_callback, context)
+        def run_llmodel_prompt(prompt: str, callback: ResponseCallbackType, **kwargs):
+            self.prompt_model(prompt, callback, **kwargs)
             output_queue.put(TERMINATING_SYMBOL)

         # Kick off llmodel_prompt in separate thread so we can return generator

@@ -374,13 +355,10 @@ class LLModel:
         thread = threading.Thread(
             target=run_llmodel_prompt,
             args=(
-                self.model,
-                prompt_ptr,
-                PromptCallback(self._prompt_callback),
-                ResponseCallback(_generator_response_callback),
-                RecalculateCallback(self._recalculate_callback),
-                self.context,
+                prompt,
+                _generator_callback_wrapper(callback)
             ),
+            kwargs=kwargs,
         )
         thread.start()
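With this refactor, prompt_model_streaming just forwards to prompt_model on a worker thread and yields whatever the callback approves. A minimal sketch, assuming llm is an LLModel instance with a model already loaded:

```python
# Tokens arrive as decoded strings; generation keyword arguments are passed
# straight through to prompt_model.
for token in llm.prompt_model_streaming("Why is the sky blue?", n_predict=64):
    print(token, end="", flush=True)
```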
@@ -391,18 +369,19 @@ class LLModel:
                 break
             yield response

+    def _callback_decoder(self, callback: ResponseCallbackType) -> RawResponseCallbackType:
+        def _raw_callback(token_id: int, response: bytes) -> bool:
+            nonlocal callback
+            return callback(token_id, response.decode("utf-8", "replace"))
+
+        return _raw_callback
+
     # Empty prompt callback
     @staticmethod
-    def _prompt_callback(token_id):
-        return True
-
-    # Empty response callback method that just prints response to be collected
-    @staticmethod
-    def _response_callback(token_id, response):
-        sys.stdout.write(response.decode('utf-8', 'replace'))
+    def _prompt_callback(token_id: int) -> bool:
         return True

     # Empty recalculate callback
     @staticmethod
-    def _recalculate_callback(is_recalculating):
+    def _recalculate_callback(is_recalculating: bool) -> bool:
         return is_recalculating
C++ (Qt) side of the same default change, top-p raised from 0.1 to 0.4:

@@ -108,7 +108,7 @@
     QString m_name;
     QString m_filename;
     double m_temperature = 0.7;
-    double m_topP = 0.1;
+    double m_topP = 0.4;
     int m_topK = 40;
     int m_maxLength = 4096;
     int m_promptBatchSize = 128;