EasyDel implementation of Clip JAX model#357
Conversation
| """ | ||
|
|
||
| from transformers import FlaxCLIPModel | ||
| from easydel import CLIPModel |
There was a problem hiding this comment.
Using CLIPModel is appropriate when we already have EasyDeL formatted weights. In our case, we are loading weights directly from HuggingFace and converting them into JAX NumPy arrays.
Easydel provides AutoModel family of classes which is exactly for this purposes (model params taking from HF and transfer to jax arrays; after that calling model implementation e.g. ClipModel).
For this use case we should use AutoEasyDeLModelForZeroShotImageClassification (from easydel import AutoEasyDeLModelForZeroShotImageClassification).
| framework=Framework.JAX, | ||
| ) | ||
|
|
||
| def _load_processor(self, dtype_override=None): |
There was a problem hiding this comment.
Why did we deleted this _load_processor? It was better to use AutoImageProcessor instead of CLIPProcessor?
| @@ -132,7 +110,7 @@ def load_model(self, dtype_override=None): | |||
| from_pt = pretrained_model_name == "openai/clip-vit-large-patch14-336" | |||
There was a problem hiding this comment.
This flag from_pt is unnecessary in the easydel
Problem description
Modify loader of CLIP Jax model to use EasyDel Instead of Transformers library
What's changed
load_processorand used AutoImageProcessorget_input_activations_partition_specandload_parameters_partition_specChecklist