fix(tgi): return the correct FinishReason on stop string
dacorvo committed Dec 12, 2024
1 parent ce5e91d commit 80b3d6d
Showing 1 changed file with 17 additions and 5 deletions.
@@ -303,6 +303,12 @@ def attention_mask(self) -> torch.LongTensor:
     def max_token(self) -> int:
         return self._generation_config.max_length
 
+    @property
+    def max_new_tokens(self) -> int:
+        # The current value of max_new_tokens: might be different from the target max_new_tokens
+        # if the slot has been paused and resumed.
+        return self._generation_config.max_new_tokens
+
     @property
     def truncate(self) -> int:
         return self._truncate
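
The comment on the new property is the key detail: the per-slot generation config is mutable, so the current budget can differ from the value the request asked for. A minimal sketch (not the repository's Slot class; the resume arithmetic below is an assumption, labeled as such) of how pause/resume can leave max_new_tokens below the original target:

from copy import deepcopy
from transformers import GenerationConfig

class SlotSketch:
    def __init__(self, generation_config: GenerationConfig):
        # Each slot keeps its own copy so per-request mutation is safe.
        self._generation_config = deepcopy(generation_config)

    @property
    def max_new_tokens(self) -> int:
        # Current budget, not necessarily the budget the request started with.
        return self._generation_config.max_new_tokens

    def resume(self, generated_tokens: int) -> None:
        # Assumption: resuming a paused slot shrinks the remaining budget by
        # the number of tokens generated before the pause.
        self._generation_config.max_new_tokens -= generated_tokens

slot = SlotSketch(GenerationConfig(max_new_tokens=20))
slot.resume(generated_tokens=5)
print(slot.max_new_tokens)  # 15: differs from the target value of 20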
@@ -384,7 +390,9 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
             slot = empty_slots.pop()
             slot.assign(self.batch_id, request, self.model.generation_config)
             new_slots.append(slot)
-            logger.debug(f"Request {slot.request_id} assigned to slot {slot.id}")
+            logger.debug(
+                f"Request {slot.request_id} assigned to slot {slot.id} with max_new_tokens {slot.max_new_tokens}"
+            )
         if self.rebuild_cache_on_prefill:
             # We will clear pending slots and prefill all slots
             prefill_slots = self.slots
@@ -453,6 +461,8 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
             # Append back the next token
             slot.append(next_tokens[i])
         logger.debug("Model ready for decoding")
+        if next_batch is not None:
+            logger.debug(f"Next batch is {next_batch.id} with requests: {next_batch.request_ids}")
         return generation, next_batch
 
     def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBatch]:
@@ -525,14 +535,16 @@ def _generate_token(
             if next_token == self.tokenizer.eos_token_id:
                 finish_reason = FinishReason.FINISH_REASON_EOS_TOKEN
             elif slot.stopped:
-                # For now we only support the length stopping criteria
-                finish_reason = FinishReason.FINISH_REASON_LENGTH
+                if slot.generated_tokens == slot.max_new_tokens:
+                    finish_reason = FinishReason.FINISH_REASON_LENGTH
+                else:
+                    finish_reason = FinishReason.FINISH_REASON_STOP_SEQUENCE
             if finish_reason is not None:
                 # We must include the generated text for each finished sequence in the response
                 generated_text = GeneratedText(
                     text=slot.generated_text, generated_tokens=slot.generated_tokens, finish_reason=finish_reason
                 )
-                logger.debug(f"Finished generating tokens for request {request_id}")
+                logger.debug(f"Decode complete for request {request_id} with {slot.generated_tokens} tokens")
                 # mark the slot as available
                 slot.clear()
             else:
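
This branch is the actual fix: slot.stopped fires for any stopping criterion, so the reason has to be disambiguated by checking whether the token budget was exhausted. A self-contained sketch of the same decision, with a stand-in enum in place of the TGI protobuf FinishReason:

from enum import Enum
from typing import Optional

class FinishReason(Enum):
    # Stand-in for the text-generation-inference protobuf enum.
    FINISH_REASON_LENGTH = 0
    FINISH_REASON_EOS_TOKEN = 1
    FINISH_REASON_STOP_SEQUENCE = 2

def finish_reason_for(
    next_token: int, eos_token_id: int, stopped: bool,
    generated_tokens: int, max_new_tokens: int,
) -> Optional[FinishReason]:
    if next_token == eos_token_id:
        return FinishReason.FINISH_REASON_EOS_TOKEN
    if stopped:
        # Only report LENGTH when the budget is actually exhausted;
        # otherwise the stop must have come from a stop sequence.
        if generated_tokens == max_new_tokens:
            return FinishReason.FINISH_REASON_LENGTH
        return FinishReason.FINISH_REASON_STOP_SEQUENCE
    return None  # still generating

# A slot stopped after 3 of 20 tokens can only have matched a stop string.
assert finish_reason_for(42, 2, True, 3, 20) is FinishReason.FINISH_REASON_STOP_SEQUENCE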
@@ -590,7 +602,7 @@ def clear(self, batch_id: Optional[int] = None):
     def _clear(self, keep_slot_ids: List):
         for slot in self.slots:
             if slot.state != Slot.State.EMPTY and slot.id not in keep_slot_ids:
-                logger.info(f"Removing slot {slot.id} with request {slot.request_id}")
+                logger.debug(f"Removing slot {slot.id} with request {slot.request_id}")
                 slot.clear()
 
     @classmethod
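
From a client's perspective, the change is visible in the finish_reason detail of a generation response. A hedged example using the huggingface_hub client (endpoint URL and stop string are placeholders):

from huggingface_hub import InferenceClient

client = InferenceClient("http://localhost:8080")  # placeholder TGI endpoint
output = client.text_generation(
    "The capital of France is",
    max_new_tokens=50,
    stop_sequences=["."],  # placeholder stop string
    details=True,
)
# Before this commit a stop-string hit was reported as "length";
# with the fix it is reported as "stop_sequence".
print(output.details.finish_reason)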