Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 41 additions & 5 deletions nemo/deploy/nlp/query_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,20 @@ def query_llm(
stop_words_list=None,
bad_words_list=None,
no_repeat_ngram_size=None,
max_output_len=512,
top_k=1,
top_p=0.0,
temperature=1.0,
min_output_len=None,
max_output_len=None,
top_k=None,
top_p=None,
temperature=None,
random_seed=None,
task_id=None,
lora_uids=None,
use_greedy: bool = None,
repetition_penalty: float = None,
add_BOS: bool = None,
all_probs: bool = None,
compute_logprob: bool = None,
end_strings=None,
init_timeout=60.0,
):
"""
Expand All @@ -110,6 +117,9 @@ def query_llm(
prompts = str_list2numpy(prompts)
inputs = {"prompts": prompts}

if min_output_len is not None:
inputs["min_output_len"] = np.full(prompts.shape, min_output_len, dtype=np.int_)

if max_output_len is not None:
inputs["max_output_len"] = np.full(prompts.shape, max_output_len, dtype=np.int_)

Expand All @@ -127,6 +137,7 @@ def query_llm(

if stop_words_list is not None:
inputs["stop_words_list"] = str_list2numpy(stop_words_list)

if bad_words_list is not None:
inputs["bad_words_list"] = str_list2numpy(bad_words_list)

Expand All @@ -141,12 +152,37 @@ def query_llm(
lora_uids = np.char.encode(lora_uids, "utf-8")
inputs["lora_uids"] = np.full((prompts.shape[0], len(lora_uids)), lora_uids)

if use_greedy is not None:
inputs["use_greedy"] = np.full(prompts.shape, use_greedy, dtype=np.bool_)

if repetition_penalty is not None:
inputs["repetition_penalty"] = np.full(prompts.shape, repetition_penalty, dtype=np.single)

if add_BOS is not None:
inputs["add_BOS"] = np.full(prompts.shape, add_BOS, dtype=np.bool_)

if all_probs is not None:
inputs["all_probs"] = np.full(prompts.shape, all_probs, dtype=np.bool_)

if compute_logprob is not None:
inputs["compute_logprob"] = np.full(prompts.shape, compute_logprob, dtype=np.bool_)

if end_strings is not None:
inputs["end_strings"] = str_list2numpy(end_strings)

with ModelClient(self.url, self.model_name, init_timeout_s=init_timeout) as client:
result_dict = client.infer_batch(**inputs)
output_type = client.model_config.outputs[0].dtype

if output_type == np.bytes_:
sentences = np.char.decode(result_dict["outputs"].astype("bytes"), "utf-8")
if "outputs" in result_dict.keys():
output = result_dict["outputs"]
elif "sentences" in result_dict.keys():
output = result_dict["sentences"]
else:
return "Unknown output keyword."

sentences = np.char.decode(output.astype("bytes"), "utf-8")
return sentences
else:
return result_dict["outputs"]
Expand Down
29 changes: 20 additions & 9 deletions scripts/deploy/nlp/deploy_triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,6 @@ def get_args(argv):
description=f"Deploy nemo models to Triton",
)
parser.add_argument("-nc", "--nemo_checkpoint", type=str, help="Source .nemo file")
parser.add_argument(
"-dsn",
"--direct_serve_nemo",
default=False,
action='store_true',
help="Serve the nemo model directly instead of exporting to TRTLLM first. Will ignore other TRTLLM-specific arguments.",
)
parser.add_argument(
"-ptnc",
"--ptuning_nemo_checkpoint",
Expand Down Expand Up @@ -147,6 +140,15 @@ def get_args(argv):
action='store_true',
help='Use TensorRT LLM C++ runtime',
)
parser.add_argument(
"-b",
'--backend',
nargs='?',
const=None,
default='TensorRT-LLM',
choices=['TensorRT-LLM', 'vLLM', 'In-Framework'],
help="Different options to deploy nemo model.",
)
parser.add_argument("-dm", "--debug_mode", default=False, action='store_true', help="Enable debug mode")

args = parser.parse_args(argv)
Expand Down Expand Up @@ -261,7 +263,8 @@ def get_trtllm_deployable(args):

def get_nemo_deployable(args):
if args.nemo_checkpoint is None:
raise ValueError("Direct serve requires a .nemo checkpoint")
raise ValueError("In-Framework deployment requires a .nemo checkpoint")

return MegatronLLMDeployable(args.nemo_checkpoint, args.num_gpus)


Expand All @@ -277,7 +280,15 @@ def nemo_deploy(argv):
LOGGER.info("Logging level set to {}".format(loglevel))
LOGGER.info(args)

triton_deployable = get_nemo_deployable(args) if args.direct_serve_nemo else get_trtllm_deployable(args)
backend = args.backend.lower()
if backend == 'tensorrt-llm':
triton_deployable = get_trtllm_deployable(args)
elif backend == 'in-framework':
triton_deployable = get_nemo_deployable(args)
elif backend == 'vllm':
raise ValueError("vLLM will be supported in the next release.")
else:
raise ValueError("Backend: {0} is not supported.".format(backend))

try:
nm = DeployPyTriton(
Expand Down
Loading