diff --git a/presets/inference/text-generation/inference_api.py b/presets/inference/text-generation/inference_api.py index ddcf2780f..bf739844d 100644 --- a/presets/inference/text-generation/inference_api.py +++ b/presets/inference/text-generation/inference_api.py @@ -123,30 +123,18 @@ def health_check(): return {"status": "Healthy"} class GenerateKwargs(BaseModel): - max_length: int = 200 + max_length: int = 200 # Length of input prompt+max_new_tokens min_length: int = 0 - do_sample: bool = True + do_sample: bool = False early_stopping: bool = False num_beams: int = 1 - num_beam_groups: int = 1 - diversity_penalty: float = 0.0 temperature: float = 1.0 - top_k: int = 10 + top_k: int = 50 top_p: float = 1 typical_p: float = 1 repetition_penalty: float = 1 - length_penalty: float = 1 - no_repeat_ngram_size: int = 0 - encoder_no_repeat_ngram_size: int = 0 - bad_words_ids: Optional[List[int]] = None - num_return_sequences: int = 1 - output_scores: bool = False - return_dict_in_generate: bool = False pad_token_id: Optional[int] = tokenizer.pad_token_id eos_token_id: Optional[int] = tokenizer.eos_token_id - forced_bos_token_id: Optional[int] = None - forced_eos_token_id: Optional[int] = None - remove_invalid_values: Optional[bool] = None class Config: extra = Extra.allow # Allows for additional fields not explicitly defined