diff --git a/text-generation-inference/server/text_generation_server/generator.py b/text-generation-inference/server/text_generation_server/generator.py
index e394726bb..3ddee690c 100644
--- a/text-generation-inference/server/text_generation_server/generator.py
+++ b/text-generation-inference/server/text_generation_server/generator.py
@@ -303,6 +303,12 @@ def attention_mask(self) -> torch.LongTensor:
     def max_token(self) -> int:
         return self._generation_config.max_length

+    @property
+    def max_new_tokens(self) -> int:
+        # The current value of max_new_tokens: might differ from the target max_new_tokens
+        # if the slot has been paused and resumed.
+        return self._generation_config.max_new_tokens
+
     @property
     def truncate(self) -> int:
         return self._truncate
@@ -384,7 +390,9 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
             slot = empty_slots.pop()
             slot.assign(self.batch_id, request, self.model.generation_config)
             new_slots.append(slot)
-            logger.debug(f"Request {slot.request_id} assigned to slot {slot.id}")
+            logger.debug(
+                f"Request {slot.request_id} assigned to slot {slot.id} with max_new_tokens {slot.max_new_tokens}"
+            )
         if self.rebuild_cache_on_prefill:
             # We will clear pending slots and prefill all slots
             prefill_slots = self.slots
@@ -453,6 +461,8 @@ def prefill(self, batch: Batch) -> Tuple[List[Generation], CachedBatch]:
             # Append back the next token
             slot.append(next_tokens[i])
         logger.debug("Model ready for decoding")
+        if next_batch is not None:
+            logger.debug(f"Next batch is {next_batch.id} with requests: {next_batch.request_ids}")
         return generation, next_batch

     def decode(self, batches: List[CachedBatch]) -> Tuple[List[Generation], CachedBatch]:
@@ -525,14 +535,16 @@ def _generate_token(
            if next_token == self.tokenizer.eos_token_id:
                finish_reason = FinishReason.FINISH_REASON_EOS_TOKEN
            elif slot.stopped:
-                # For now we only support the length stopping criteria
-                finish_reason = FinishReason.FINISH_REASON_LENGTH
+                if slot.generated_tokens == slot.max_new_tokens:
+                    finish_reason = FinishReason.FINISH_REASON_LENGTH
+                else:
+                    finish_reason = FinishReason.FINISH_REASON_STOP_SEQUENCE
            if finish_reason is not None:
                # We must include the generated text for each finished sequence in the response
                generated_text = GeneratedText(
                    text=slot.generated_text, generated_tokens=slot.generated_tokens, finish_reason=finish_reason
                )
-                logger.debug(f"Finished generating tokens for request {request_id}")
+                logger.debug(f"Decode complete for request {request_id} with {slot.generated_tokens} tokens")
                # mark the slot as available
                slot.clear()
            else:
@@ -590,7 +602,7 @@ def clear(self, batch_id: Optional[int] = None):
     def _clear(self, keep_slot_ids: List):
         for slot in self.slots:
             if slot.state != Slot.State.EMPTY and slot.id not in keep_slot_ids:
-                logger.info(f"Removing slot {slot.id} with request {slot.request_id}")
+                logger.debug(f"Removing slot {slot.id} with request {slot.request_id}")
                 slot.clear()

     @classmethod
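
For context, the finish-reason branching introduced in _generate_token can be summarized in isolation. The sketch below is illustrative only: FinishReasonStub and choose_finish_reason are hypothetical stand-ins for the real FinishReason protobuf enum and the Slot attributes (stopped, generated_tokens, max_new_tokens) used in generator.py, not the library's API.

from enum import Enum, auto


class FinishReasonStub(Enum):
    # Hypothetical stand-in for the FinishReason protobuf enum used in generator.py.
    FINISH_REASON_EOS_TOKEN = auto()
    FINISH_REASON_LENGTH = auto()
    FINISH_REASON_STOP_SEQUENCE = auto()


def choose_finish_reason(
    next_token: int,
    eos_token_id: int,
    stopped: bool,
    generated_tokens: int,
    max_new_tokens: int,
):
    """Mirror the branching added in the diff above (illustrative sketch).

    Returns None while the sequence is still generating.
    """
    if next_token == eos_token_id:
        return FinishReasonStub.FINISH_REASON_EOS_TOKEN
    if stopped:
        # The slot stopped: if the per-slot token budget is exhausted it is a
        # length stop, otherwise another stopping criterion (e.g. a stop sequence) fired.
        if generated_tokens == max_new_tokens:
            return FinishReasonStub.FINISH_REASON_LENGTH
        return FinishReasonStub.FINISH_REASON_STOP_SEQUENCE
    return None


if __name__ == "__main__":
    # Budget exhausted -> LENGTH; stopped before the budget -> STOP_SEQUENCE.
    print(choose_finish_reason(42, 2, True, 20, 20))  # FINISH_REASON_LENGTH
    print(choose_finish_reason(42, 2, True, 5, 20))   # FINISH_REASON_STOP_SEQUENCE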