diff --git a/modeling_florence2.py b/modeling_florence2.py index e5ee651..ccca154 100644 --- a/modeling_florence2.py +++ b/modeling_florence2.py @@ -29,6 +29,7 @@ from einops import rearrange from timm.models.layers import DropPath, trunc_normal_ from transformers.modeling_utils import PreTrainedModel +from transformers.generation.utils import GenerationMixin from transformers.utils import ( ModelOutput, add_start_docstrings, @@ -2059,7 +2060,7 @@ class Florence2LanguageModel(Florence2LanguagePreTrainedModel): ) -class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel): +class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel, GenerationMixin): base_model_prefix = "model" _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] _keys_to_ignore_on_load_missing = ["final_logits_bias"]