Mirror of https://github.com/nomic-ai/gpt4all.git (synced 2025-08-13 00:05:57 -04:00)

Compare commits 2d02c65177...b4dbbd1485 (4 commits: b4dbbd1485, 5f0aaf8bdb, 4974ae917c, 63849d9afc)
@@ -29,7 +29,7 @@ Run on an M1 macOS Device (not sped up!)
 </p>
 
 ## GPT4All: An ecosystem of open-source on-edge large language models.
 
-GPT4All is an ecosystem to train and deploy **powerful** and **customized** large language models that run locally on consumer grade CPUs.
+GPT4All is an ecosystem to train and deploy **powerful** and **customized** large language models that run locally on consumer grade CPUs. Note that your CPU needs to support [AVX or AVX2 instructions](https://en.wikipedia.org/wiki/Advanced_Vector_Extensions).
 
 Learn more in the [documentation](https://docs.gpt4all.io).
@@ -91,22 +91,4 @@ To interact with GPT4All responses as the model generates, use the `streaming =
 [' Paris', ' is', ' a', ' city', ' that', ' has', ' been', ' a', ' major', ' cultural', ' and', ' economic', ' center', ' for', ' over', ' ', '2', ',', '0', '0']
 ```
-
-#### Streaming and Chat Sessions
-
-When streaming tokens in a chat session, you must manually handle collection and updating of the chat history.
-
-```python
-from gpt4all import GPT4All
-model = GPT4All("orca-mini-3b.ggmlv3.q4_0.bin")
-
-with model.chat_session():
-    tokens = list(model.generate(prompt='hello', top_k=1, streaming=True))
-    model.current_chat_session.append({'role': 'assistant', 'content': ''.join(tokens)})
-
-    tokens = list(model.generate(prompt='write me a poem about dogs', top_k=1, streaming=True))
-    model.current_chat_session.append({'role': 'assistant', 'content': ''.join(tokens)})
-
-    print(model.current_chat_session)
-```
-
 ### API documentation
 
 ::: gpt4all.gpt4all.GPT4All
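The section removed above becomes obsolete with this change set: `generate()` now updates `current_chat_session` itself while streaming, through the callback wrapper added below. A minimal sketch of the resulting usage, reusing the model name from the removed example (not itself part of this diff):

```python
from gpt4all import GPT4All

model = GPT4All("orca-mini-3b.ggmlv3.q4_0.bin")  # model name taken from the removed example

with model.chat_session():
    # Tokens stream back one at a time; the session history is filled in
    # automatically as they arrive, so no manual append is needed.
    for token in model.generate(prompt='hello', top_k=1, streaming=True):
        print(token, end='')
    print()
    print(model.current_chat_session)  # now ends with the assistant reply
```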
@@ -5,7 +5,7 @@ import os
 import time
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Dict, Iterable, List, Union, Optional
+from typing import Any, Dict, Iterable, List, Union, Optional
 
 import requests
 from tqdm import tqdm
@@ -13,7 +13,17 @@ from tqdm import tqdm
 from . import pyllmodel
 
 # TODO: move to config
-DEFAULT_MODEL_DIRECTORY = os.path.join(str(Path.home()), ".cache", "gpt4all").replace("\\", "\\\\")
+DEFAULT_MODEL_DIRECTORY = os.path.join(str(Path.home()), ".cache", "gpt4all").replace(
+    "\\", "\\\\"
+)
+
+DEFAULT_MODEL_CONFIG = {
+    "systemPrompt": "",
+    "promptTemplate": "### Human: \n{0}\n### Assistant:\n",
+}
+
+ConfigType = Dict[str,str]
+MessageType = Dict[str, str]
 
 class Embed4All:
     """
@@ -34,7 +44,7 @@ class Embed4All:
     def embed(
         self,
         text: str
-    ) -> list[float]:
+    ) -> List[float]:
         """
         Generate an embedding.
 
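For context, the annotated `embed()` is called as below; the no-argument `Embed4All()` constructor is an assumption based on the existing class, not something introduced in this diff:

```python
from gpt4all import Embed4All

embedder = Embed4All()                          # assumed default embedding model
vector = embedder.embed("The quick brown fox")  # returns a List[float] per the new annotation
print(len(vector))
```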
@@ -74,17 +84,20 @@ class GPT4All:
         self.model_type = model_type
         self.model = pyllmodel.LLModel()
         # Retrieve model and download if allowed
-        model_dest = self.retrieve_model(model_name, model_path=model_path, allow_download=allow_download)
-        self.model.load_model(model_dest)
+        self.config: ConfigType = self.retrieve_model(
+            model_name, model_path=model_path, allow_download=allow_download
+        )
+        self.model.load_model(self.config["path"])
         # Set n_threads
         if n_threads is not None:
             self.model.set_thread_count(n_threads)
 
-        self._is_chat_session_activated = False
-        self.current_chat_session = []
+        self._is_chat_session_activated: bool = False
+        self.current_chat_session: List[MessageType] = empty_chat_session()
+        self._current_prompt_template: str = "{0}"
 
     @staticmethod
-    def list_models() -> Dict:
+    def list_models() -> List[ConfigType]:
         """
         Fetch model list from https://gpt4all.io/models/models.json.
 
@@ -95,8 +108,11 @@ class GPT4All:
 
     @staticmethod
     def retrieve_model(
-        model_name: str, model_path: Optional[str] = None, allow_download: bool = True, verbose: bool = True
-    ) -> str:
+        model_name: str,
+        model_path: Optional[str] = None,
+        allow_download: bool = True,
+        verbose: bool = True,
+    ) -> ConfigType:
         """
         Find model file, and if it doesn't exist, download the model.
 
@@ -108,11 +124,25 @@
             verbose: If True (default), print debug messages.
 
         Returns:
-            Model file destination.
+            Model config.
         """
 
         model_filename = append_bin_suffix_if_missing(model_name)
 
+        # get the config for the model
+        config: ConfigType = DEFAULT_MODEL_CONFIG
+        if allow_download:
+            available_models = GPT4All.list_models()
+
+            for m in available_models:
+                if model_filename == m["filename"]:
+                    config.update(m)
+                    config["systemPrompt"] = config["systemPrompt"].strip()
+                    config["promptTemplate"] = config["promptTemplate"].replace(
+                        "%1", "{0}", 1
+                    )  # change to Python-style formatting
+                    break
+
         # Validate download directory
         if model_path is None:
             try:
@@ -131,31 +161,34 @@
 
         model_dest = os.path.join(model_path, model_filename).replace("\\", "\\\\")
         if os.path.exists(model_dest):
+            config.pop("url", None)
+            config["path"] = model_dest
             if verbose:
                 print("Found model file at ", model_dest)
-            return model_dest
 
         # If model file does not exist, download
         elif allow_download:
             # Make sure valid model filename before attempting download
-            available_models = GPT4All.list_models()
-
-            selected_model = None
-            for m in available_models:
-                if model_filename == m['filename']:
-                    selected_model = m
-                    break
-
-            if selected_model is None:
+            if "url" not in config:
                 raise ValueError(f"Model filename not in model list: {model_filename}")
-            url = selected_model.pop('url', None)
+            url = config.pop("url", None)
 
-            return GPT4All.download_model(model_filename, model_path, verbose=verbose, url=url)
+            config["path"] = GPT4All.download_model(
+                model_filename, model_path, verbose=verbose, url=url
+            )
         else:
             raise ValueError("Failed to retrieve model")
 
+        return config
+
     @staticmethod
-    def download_model(model_filename: str, model_path: str, verbose: bool = True, url: Optional[str] = None) -> str:
+    def download_model(
+        model_filename: str,
+        model_path: str,
+        verbose: bool = True,
+        url: Optional[str] = None,
+    ) -> str:
         """
         Download model from https://gpt4all.io.
 
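With this change `retrieve_model()` returns the model's config dictionary rather than a bare file path. A brief sketch of inspecting it; the model name is illustrative, and the keys are those handled in the hunk above:

```python
from gpt4all import GPT4All

config = GPT4All.retrieve_model("orca-mini-3b.ggmlv3.q4_0.bin", allow_download=True)

print(config["path"])            # local file that load_model() will open
print(config["systemPrompt"])    # "" unless models.json defines one
print(config["promptTemplate"])  # "%1" already rewritten to Python-style "{0}"
```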
@@ -191,7 +224,7 @@
         except Exception:
             if os.path.exists(download_path):
                 if verbose:
-                    print('Cleaning up the interrupted download...')
+                    print("Cleaning up the interrupted download...")
                 os.remove(download_path)
             raise
 
@@ -212,13 +245,14 @@
         max_tokens: int = 200,
         temp: float = 0.7,
         top_k: int = 40,
-        top_p: float = 0.1,
+        top_p: float = 0.4,
         repeat_penalty: float = 1.18,
         repeat_last_n: int = 64,
         n_batch: int = 8,
         n_predict: Optional[int] = None,
         streaming: bool = False,
-    ) -> Union[str, Iterable]:
+        callback: pyllmodel.ResponseCallbackType = pyllmodel.empty_response_callback,
+    ) -> Union[str, Iterable[str]]:
         """
         Generate outputs from any GPT4All model.
 
@@ -233,12 +267,14 @@
             n_batch: Number of prompt tokens processed in parallel. Larger values decrease latency but increase resource requirements.
             n_predict: Equivalent to max_tokens, exists for backwards compatibility.
             streaming: If True, this method will instead return a generator that yields tokens as the model generates them.
+            callback: A function with arguments token_id:int and response:str, which receives the tokens from the model as they are generated and stops the generation by returning False.
 
         Returns:
             Either the entire completion or a generator that yields the completion token by token.
         """
-        generate_kwargs = dict(
-            prompt=prompt,
+
+        # Preparing the model request
+        generate_kwargs: Dict[str, Any] = dict(
             temp=temp,
             top_k=top_k,
             top_p=top_p,
@@ -249,42 +285,87 @@
         )
 
         if self._is_chat_session_activated:
+            generate_kwargs["reset_context"] = len(self.current_chat_session) == 1  # check if there is only one message, i.e. system prompt
             self.current_chat_session.append({"role": "user", "content": prompt})
-            generate_kwargs['prompt'] = self._format_chat_prompt_template(messages=self.current_chat_session[-1:])
-            generate_kwargs['reset_context'] = len(self.current_chat_session) == 1
+
+            prompt = self._format_chat_prompt_template(
+                messages = self.current_chat_session[-1:],
+                default_prompt_header = self.current_chat_session[0]["content"] if generate_kwargs["reset_context"] else "",
+            )
         else:
-            generate_kwargs['reset_context'] = True
+            generate_kwargs["reset_context"] = True
 
-        if streaming:
-            return self.model.prompt_model_streaming(**generate_kwargs)
-
-        output = self.model.prompt_model(**generate_kwargs)
+        # Prepare the callback, process the model response
+        output_collector: List[MessageType]
+        output_collector = [{"content": ""}]  # placeholder for the self.current_chat_session if chat session is not activated
 
         if self._is_chat_session_activated:
-            self.current_chat_session.append({"role": "assistant", "content": output})
+            self.current_chat_session.append({"role": "assistant", "content": ""})
+            output_collector = self.current_chat_session
 
-        return output
+        def _callback_wrapper(
+            callback: pyllmodel.ResponseCallbackType,
+            output_collector: List[MessageType],
+        ) -> pyllmodel.ResponseCallbackType:
+
+            def _callback(token_id: int, response: str) -> bool:
+                nonlocal callback, output_collector
+
+                output_collector[-1]["content"] += response
+
+                return callback(token_id, response)
+
+            return _callback
+
+        # Send the request to the model
+        if streaming:
+            return self.model.prompt_model_streaming(
+                prompt=prompt,
+                callback=_callback_wrapper(callback, output_collector),
+                **generate_kwargs,
+            )
+
+        self.model.prompt_model(
+            prompt=prompt,
+            callback=_callback_wrapper(callback, output_collector),
+            **generate_kwargs,
+        )
+
+        return output_collector[-1]["content"]
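The `generate()` rework above routes every token through `_callback_wrapper`, so a user-supplied `callback` can both observe tokens and stop generation by returning False. A small sketch; the model name and stopping rule are illustrative, only the callback signature comes from this diff:

```python
from gpt4all import GPT4All

model = GPT4All("orca-mini-3b.ggmlv3.q4_0.bin")  # illustrative model name

def stop_at_sentence_end(token_id: int, response: str) -> bool:
    # Conforms to pyllmodel.ResponseCallbackType; returning False stops generation.
    return "." not in response

text = model.generate("Name one planet.", callback=stop_at_sentence_end)
print(text)
```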
 
     @contextmanager
-    def chat_session(self):
-        '''
+    def chat_session(
+        self,
+        system_prompt: str = "",
+        prompt_template: str = "",
+    ):
+        """
         Context manager to hold an inference optimized chat session with a GPT4All model.
-        '''
+
+        Args:
+            system_prompt: An initial instruction for the model.
+            prompt_template: Template for the prompts with {0} being replaced by the user message.
+        """
         # Code to acquire resource, e.g.:
         self._is_chat_session_activated = True
-        self.current_chat_session = []
+        self.current_chat_session = empty_chat_session(system_prompt or self.config["systemPrompt"])
+        self._current_prompt_template = prompt_template or self.config["promptTemplate"]
         try:
             yield self
         finally:
             # Code to release resource, e.g.:
             self._is_chat_session_activated = False
-            self.current_chat_session = []
+            self.current_chat_session = empty_chat_session()
+            self._current_prompt_template = "{0}"
 
     def _format_chat_prompt_template(
-        self, messages: List[Dict], default_prompt_header=True, default_prompt_footer=True
+        self,
+        messages: List[MessageType],
+        default_prompt_header: str = "",
+        default_prompt_footer: str = "",
     ) -> str:
         """
-        Helper method for building a prompt using template from list of messages.
+        Helper method for building a prompt from list of messages using the self._current_prompt_template as a template for each message.
 
         Args:
             messages: List of dictionaries. Each dictionary should have a "role" key
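The reworked `chat_session()` above accepts a per-session `system_prompt` and `prompt_template`, falling back to the model's config values when they are left empty. A minimal usage sketch; the model name and prompt strings are illustrative:

```python
from gpt4all import GPT4All

model = GPT4All("orca-mini-3b.ggmlv3.q4_0.bin")  # illustrative model name

with model.chat_session(
    system_prompt="You are a concise assistant.",
    prompt_template="### Human: \n{0}\n### Assistant:\n",
):
    print(model.generate("hello"))
# Outside the block the session resets to empty_chat_session() and the "{0}" template.
```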
@@ -296,19 +377,44 @@
         Returns:
             Formatted prompt.
         """
-        full_prompt = ""
+
+        if isinstance(default_prompt_header, bool):
+            import warnings
+
+            warnings.warn(
+                "Using True/False for the 'default_prompt_header' is deprecated. Use a string instead.",
+                DeprecationWarning,
+            )
+            default_prompt_header = ""
+
+        if isinstance(default_prompt_footer, bool):
+            import warnings
+
+            warnings.warn(
+                "Using True/False for the 'default_prompt_footer' is deprecated. Use a string instead.",
+                DeprecationWarning,
+            )
+            default_prompt_footer = ""
+
+        full_prompt = default_prompt_header + "\n\n" if default_prompt_header != "" else ""
 
         for message in messages:
             if message["role"] == "user":
-                user_message = "### Human: \n" + message["content"] + "\n### Assistant:\n"
+                user_message = self._current_prompt_template.format(message["content"])
                 full_prompt += user_message
             if message["role"] == "assistant":
-                assistant_message = message["content"] + '\n'
+                assistant_message = message["content"] + "\n"
                 full_prompt += assistant_message
 
+        full_prompt += "\n\n" + default_prompt_footer if default_prompt_footer != "" else ""
+
         return full_prompt
 
 
+def empty_chat_session(system_prompt: str = "") -> List[MessageType]:
+    return [{"role": "system", "content": system_prompt}]
+
+
 def append_bin_suffix_if_missing(model_name):
     if not model_name.endswith(".bin"):
         model_name += ".bin"
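For reference, the loop above renders a message list with `self._current_prompt_template`. A standalone illustration that mirrors that logic with the default template from `DEFAULT_MODEL_CONFIG`; the message contents are made up:

```python
# Illustrative only: same formatting logic as _format_chat_prompt_template above.
prompt_template = "### Human: \n{0}\n### Assistant:\n"
messages = [
    {"role": "user", "content": "hello"},
    {"role": "assistant", "content": "Hi! How can I help?"},
]

full_prompt = ""
for message in messages:
    if message["role"] == "user":
        full_prompt += prompt_template.format(message["content"])
    if message["role"] == "assistant":
        full_prompt += message["content"] + "\n"

print(full_prompt)
# ### Human:
# hello
# ### Assistant:
# Hi! How can I help?
```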
@@ -6,26 +6,19 @@ import re
 import subprocess
 import sys
 import threading
-from typing import Iterable
+import logging
+from typing import Iterable, Callable, List
 
 import pkg_resources
 
 
-class DualStreamProcessor:
-    def __init__(self, stream=None):
-        self.stream = stream
-        self.output = ""
-
-    def write(self, text):
-        if self.stream is not None:
-            self.stream.write(text)
-            self.stream.flush()
-        self.output += text
+logger: logging.Logger = logging.getLogger(__name__)
 
 
 # TODO: provide a config file to make this more robust
 LLMODEL_PATH = os.path.join("llmodel_DO_NOT_MODIFY", "build").replace("\\", "\\\\")
-MODEL_LIB_PATH = str(pkg_resources.resource_filename("gpt4all", LLMODEL_PATH)).replace("\\", "\\\\")
+MODEL_LIB_PATH = str(pkg_resources.resource_filename("gpt4all", LLMODEL_PATH)).replace(
+    "\\", "\\\\"
+)
 
 
 def load_llmodel_library():
@@ -43,9 +36,9 @@ def load_llmodel_library():
 
     c_lib_ext = get_c_shared_lib_extension()
 
-    llmodel_file = "libllmodel" + '.' + c_lib_ext
+    llmodel_file = "libllmodel" + "." + c_lib_ext
 
-    llmodel_dir = str(pkg_resources.resource_filename('gpt4all', os.path.join(LLMODEL_PATH, llmodel_file))).replace(
+    llmodel_dir = str(pkg_resources.resource_filename("gpt4all", os.path.join(LLMODEL_PATH, llmodel_file))).replace(
         "\\", "\\\\"
     )
 
@@ -134,7 +127,15 @@ llmodel.llmodel_set_implementation_search_path.restype = None
 llmodel.llmodel_threadCount.argtypes = [ctypes.c_void_p]
 llmodel.llmodel_threadCount.restype = ctypes.c_int32
 
-llmodel.llmodel_set_implementation_search_path(MODEL_LIB_PATH.encode('utf-8'))
+llmodel.llmodel_set_implementation_search_path(MODEL_LIB_PATH.encode("utf-8"))
 
 
+ResponseCallbackType = Callable[[int, str], bool]
+RawResponseCallbackType = Callable[[int, bytes], bool]
+
+
+def empty_response_callback(token_id: int, response: str) -> bool:
+    return True
+
+
 class LLModel:
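The two `Callable` aliases and `empty_response_callback` added here pin down the token-callback contract used by the bindings: a callable that receives a token id plus a decoded (or raw) response and returns whether generation should continue. A tiny conforming example; the logging behaviour is only illustrative:

```python
def log_tokens(token_id: int, response: str) -> bool:
    # Matches ResponseCallbackType: return True to keep generating, False to stop.
    print(f"token {token_id}: {response!r}")
    return True
```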
@@ -250,9 +251,10 @@ class LLModel:
     def generate_embedding(
         self,
         text: str
-    ) -> list[float]:
+    ) -> List[float]:
         if not text:
             raise ValueError("Text must not be None or empty")
+
         embedding_size = ctypes.c_size_t()
         c_text = ctypes.c_char_p(text.encode('utf-8'))
         embedding_ptr = llmodel.llmodel_embedding(self.model, c_text, ctypes.byref(embedding_size))
@@ -263,6 +265,7 @@
     def prompt_model(
         self,
         prompt: str,
+        callback: ResponseCallbackType,
         n_predict: int = 4096,
         top_k: int = 40,
         top_p: float = 0.9,
@@ -272,8 +275,7 @@
         repeat_last_n: int = 10,
         context_erase: float = 0.75,
         reset_context: bool = False,
-        streaming=False,
-    ) -> str:
+    ):
         """
         Generate response from model from a prompt.
 
@@ -281,26 +283,24 @@
         ----------
         prompt: str
             Question, task, or conversation for model to respond to
-        streaming: bool
-            Stream response to stdout
+        callback(token_id:int, response:str): bool
+            The model sends response tokens to callback
 
         Returns
         -------
-        Model response str
+        None
         """
 
-        prompt_bytes = prompt.encode('utf-8')
+        logger.info(
+            "LLModel.prompt_model -- prompt:\n"
+            + "%s\n"
+            + "===/LLModel.prompt_model -- prompt/===",
+            prompt,
+        )
+
+        prompt_bytes = prompt.encode("utf-8")
         prompt_ptr = ctypes.c_char_p(prompt_bytes)
 
-        old_stdout = sys.stdout
-
-        stream_processor = DualStreamProcessor()
-
-        if streaming:
-            stream_processor.stream = sys.stdout
-
-        sys.stdout = stream_processor
-
         self._set_context(
             n_predict=n_predict,
             top_k=top_k,
@@ -317,56 +317,37 @@
             self.model,
             prompt_ptr,
             PromptCallback(self._prompt_callback),
-            ResponseCallback(self._response_callback),
+            ResponseCallback(self._callback_decoder(callback)),
             RecalculateCallback(self._recalculate_callback),
             self.context,
         )
 
-        # Revert to old stdout
-        sys.stdout = old_stdout
-        # Force new line
-        return stream_processor.output
-
     def prompt_model_streaming(
         self,
         prompt: str,
-        n_predict: int = 4096,
-        top_k: int = 40,
-        top_p: float = 0.9,
-        temp: float = 0.1,
-        n_batch: int = 8,
-        repeat_penalty: float = 1.2,
-        repeat_last_n: int = 10,
-        context_erase: float = 0.75,
-        reset_context: bool = False,
-    ) -> Iterable:
+        callback: ResponseCallbackType = empty_response_callback,
+        **kwargs
+    ) -> Iterable[str]:
         # Symbol to terminate from generator
         TERMINATING_SYMBOL = object()
 
         output_queue = queue.Queue()
 
-        prompt_bytes = prompt.encode('utf-8')
-        prompt_ptr = ctypes.c_char_p(prompt_bytes)
-
-        self._set_context(
-            n_predict=n_predict,
-            top_k=top_k,
-            top_p=top_p,
-            temp=temp,
-            n_batch=n_batch,
-            repeat_penalty=repeat_penalty,
-            repeat_last_n=repeat_last_n,
-            context_erase=context_erase,
-            reset_context=reset_context,
-        )
-
         # Put response tokens into an output queue
-        def _generator_response_callback(token_id, response):
-            output_queue.put(response.decode('utf-8', 'replace'))
-            return True
+        def _generator_callback_wrapper(callback: ResponseCallbackType) -> ResponseCallbackType:
+            def _generator_callback(token_id: int, response: str):
+                nonlocal callback
 
-        def run_llmodel_prompt(model, prompt, prompt_callback, response_callback, recalculate_callback, context):
-            llmodel.llmodel_prompt(model, prompt, prompt_callback, response_callback, recalculate_callback, context)
+                if callback(token_id, response):
+                    output_queue.put(response)
+                    return True
+
+                return False
+
+            return _generator_callback
+
+        def run_llmodel_prompt(prompt: str, callback: ResponseCallbackType, **kwargs):
+            self.prompt_model(prompt, callback, **kwargs)
             output_queue.put(TERMINATING_SYMBOL)
 
         # Kick off llmodel_prompt in separate thread so we can return generator
@@ -374,13 +355,10 @@
         thread = threading.Thread(
             target=run_llmodel_prompt,
             args=(
-                self.model,
-                prompt_ptr,
-                PromptCallback(self._prompt_callback),
-                ResponseCallback(_generator_response_callback),
-                RecalculateCallback(self._recalculate_callback),
-                self.context,
+                prompt,
+                _generator_callback_wrapper(callback)
             ),
+            kwargs=kwargs,
        )
         thread.start()
 
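The streaming path above runs the blocking prompt call on a worker thread and hands tokens to the caller through a queue terminated by a sentinel object. A self-contained sketch of that producer/consumer pattern outside the bindings; all names here are illustrative, not the bindings' API:

```python
import queue
import threading
from typing import Callable, Iterable


def stream_tokens(produce: Callable[[Callable[[str], None]], None]) -> Iterable[str]:
    # 'produce' is any blocking function that calls its argument once per token.
    TERMINATING_SYMBOL = object()          # sentinel marking the end of generation
    output_queue: "queue.Queue[object]" = queue.Queue()

    def worker() -> None:
        produce(output_queue.put)          # tokens flow into the queue
        output_queue.put(TERMINATING_SYMBOL)

    threading.Thread(target=worker, daemon=True).start()

    while True:
        item = output_queue.get()
        if item is TERMINATING_SYMBOL:
            break
        yield item


# Usage with a fake token producer:
def fake_producer(emit: Callable[[str], None]) -> None:
    for token in ("Hello", ", ", "world"):
        emit(token)


print(list(stream_tokens(fake_producer)))  # ['Hello', ', ', 'world']
```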
@@ -391,18 +369,19 @@
                 break
             yield response
 
+    def _callback_decoder(self, callback: ResponseCallbackType) -> RawResponseCallbackType:
+        def _raw_callback(token_id: int, response: bytes) -> bool:
+            nonlocal callback
+            return callback(token_id, response.decode("utf-8", "replace"))
+
+        return _raw_callback
+
     # Empty prompt callback
     @staticmethod
-    def _prompt_callback(token_id):
-        return True
-
-    # Empty response callback method that just prints response to be collected
-    @staticmethod
-    def _response_callback(token_id, response):
-        sys.stdout.write(response.decode('utf-8', 'replace'))
+    def _prompt_callback(token_id: int) -> bool:
         return True
 
     # Empty recalculate callback
     @staticmethod
-    def _recalculate_callback(is_recalculating):
+    def _recalculate_callback(is_recalculating: bool) -> bool:
         return is_recalculating
@@ -108,7 +108,7 @@ private:
     QString m_name;
     QString m_filename;
     double m_temperature = 0.7;
-    double m_topP = 0.1;
+    double m_topP = 0.4;
     int m_topK = 40;
    int m_maxLength = 4096;
     int m_promptBatchSize = 128;