Skip to content

Commit

Permalink
Add support for Llama inference through NeuronModelForCausalLM (#223)
Browse files Browse the repository at this point in the history
* chore: use latest AWS SDK

* chore(tgi): use latest AWS SDK

* feat(generate): add support for llama

* fix(tgi): return a string for info.dtype

* fix(tgi): slot.select should return a scalar

* fix(tgi): insert leading space in next token text when needed

* fix(NeuronGenerationMixin): remove Marian hack
  • Loading branch information
dacorvo authored Sep 12, 2023
1 parent fd29acd commit 974f343
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 16 deletions.
6 changes: 6 additions & 0 deletions optimum/exporters/neuron/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,3 +351,9 @@ def check_model_inputs_order(
class GPT2NeuronConfig(TextNeuronDecoderConfig):
NEURONX_ARGS = ["n_positions"]
NEURONX_CLASS = "gpt2.model.GPT2ForSampling"


# Neuron decoder export configuration for Llama. The decorator arguments associate it
# with the "llama" model type and the "text-generation" task (presumably through the
# tasks manager registry -- confirm against register_in_tasks_manager).
@register_in_tasks_manager("llama", "text-generation")
class LLamaNeuronConfig(TextNeuronDecoderConfig):
# Same NEURONX_ARGS value as GPT2NeuronConfig.
NEURONX_ARGS = ["n_positions"]
# Dotted path naming the Llama sampling model class; presumably resolved by the base
# config relative to the transformers-neuronx package -- TODO confirm.
NEURONX_CLASS = "llama.model.LlamaForSampling"
4 changes: 0 additions & 4 deletions optimum/neuron/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,10 +504,6 @@ def beam_search(
else:
next_token_logits = outputs.logits[:, -1, :]

# hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
# cannot be generated both before and after the `nn.functional.log_softmax` operation.
next_token_logits = self.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)

# Manually compute log softmax
# log_softmax(vi) = vi - max(vi) - log(sum(exp(vi - max(vi))))
logit_max, _ = torch.max(next_token_logits, dim=-1, keepdim=True)
Expand Down
8 changes: 4 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,12 @@
],
"neuronx": [
"wheel",
"neuronx-cc==2.*",
"torch-neuronx",
"transformers-neuronx",
"neuronx-cc>=2.9",
"torch-neuronx>=1.13.1.1.10.1",
"transformers-neuronx>=0.6.106",
"torch==1.13.1.*",
"torchvision==0.14.*",
"neuronx_distributed >= 0.2.0",
"neuronx_distributed >= 0.3.0",
],
"diffusers": ["diffusers"],
}
Expand Down
13 changes: 7 additions & 6 deletions text-generation-inference/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -89,18 +89,19 @@ RUN wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEU
# Install neuronx packages (versions pinned below; the former "2.12.2" label no longer matches these pins)
RUN apt-get update -y \
&& apt-get install -y --no-install-recommends \
aws-neuronx-dkms=2.11.9.0 \
aws-neuronx-collectives=2.15.16.0-db4e2d9a9 \
aws-neuronx-runtime-lib=2.15.14.0-279f319f2 \
aws-neuronx-tools=2.12.2.0 \
aws-neuronx-dkms=2.12.18.0 \
aws-neuronx-collectives=2.16.16.0-e59c7bb3e \
aws-neuronx-runtime-lib=2.16.14.0-61fdc395f \
aws-neuronx-tools=2.13.4.0 \
&& rm -rf /var/lib/apt/lists/* \
&& apt-get clean

ENV PATH="/opt/bin/:/opt/aws/neuron/bin:${PATH}"

RUN pip3 install \
torch-neuronx==1.13.1.1.9.1 \
transformers-neuronx==0.5.58 \
neuronx-cc==2.9.0.40 \
torch-neuronx==1.13.1.1.10.1 \
transformers-neuronx==0.6.106 \
--extra-index-url=https://pip.repos.neuron.amazonaws.com

# Install HuggingFace packages
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def select(self, input_ids: torch.LongTensor, logits: torch.Tensor) -> torch.Lon
Return:
`torch.LongTensor`: A scalar `torch.LongTensor` containing the selected token.
"""
return self._selector.select(input_ids, logits)
return self._selector.select(input_ids, logits)[0]

@property
def stopped(self) -> bool:
Expand Down Expand Up @@ -248,7 +248,7 @@ def info(self) -> InfoResponse:
dtype = getattr(self.model.config, "torch_dtype", "float32")
return InfoResponse(
requires_padding=True,
dtype=dtype,
dtype=str(dtype),
device_type="xla",
)

Expand Down Expand Up @@ -370,6 +370,11 @@ def _generate_token(
slot_input_ids = input_ids[i : i + 1, :]
next_token = slot.select(slot_input_ids, next_token_logits)
next_token_text = self.tokenizer.decode(next_token)
if not slot.generated_text.endswith(" ") and not next_token_text.startswith(" "):
# Some tokenizers do not prepend spaces automatically when decoding a single token
contextual_text = self.tokenizer.decode([slot.next_token, next_token])
if contextual_text[: -len(next_token_text)].endswith(" "):
next_token_text = " " + next_token_text
slot.append(next_token, next_token_text)
generated_text = None
finish_reason = None
Expand Down Expand Up @@ -447,6 +452,7 @@ def from_pretrained(
Args:
model_id (`str`):
The *model_id* of a model on the HuggingFace hub or the path to a local model.
In either case, the hub or local path must also contain a Tokenizer.
revision (`str`):
The revision of the model on the HuggingFace hub.
Expand Down

0 comments on commit 974f343

Please sign in to comment.