Skip to content

Commit

Permalink
Register all timm models with ImageNet-1K and update documentation (#24)
Browse files Browse the repository at this point in the history
* register all timm models with imagenet 1k

* remove eva

* update readme
  • Loading branch information
dnth authored Oct 22, 2024
1 parent f3fc5bf commit 495d287
Show file tree
Hide file tree
Showing 6 changed files with 494 additions and 183 deletions.
15 changes: 1 addition & 14 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -194,20 +194,7 @@ model = xinfer.create_model(model)


TIMM:
<table>
<thead>
<tr>
<th>Model</th>
<th>Usage</th>
</tr>
</thead>
<tbody>
<tr>
<td><a href="https://github.com/baaivision/EVA/tree/master/EVA-02">EVA02 Series</a></td>
<td><code>xinfer.create_model("eva02_small_patch14_336.mim_in22k_ft_in1k")</code></td>
</tr>
</tbody>
</table>
All models from [TIMM](https://github.com/huggingface/pytorch-image-models) fine-tuned for ImageNet 1k are supported.

> [!NOTE]
> Wish to load an unsupported model?
Expand Down
594 changes: 474 additions & 120 deletions nbs/timm.ipynb

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ einops
pillow
gradio
itables
pandas
pandas
numpy<2.0.0
2 changes: 1 addition & 1 deletion xinfer/timm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .eva02 import EVA02
from .load_timm_models import load_timm_models
from .timm_model import TimmModel
47 changes: 0 additions & 47 deletions xinfer/timm/eva02.py

This file was deleted.

16 changes: 16 additions & 0 deletions xinfer/timm/load_timm_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import timm

from ..model_registry import ModelInputOutput, register_model
from .timm_model import TimmModel


def load_timm_models():
model_list = timm.list_models("*1k*", pretrained=True)

for model_id in model_list:
register_model(model_id, "timm", ModelInputOutput.IMAGE_TO_CATEGORIES)(
TimmModel
)


load_timm_models()

0 comments on commit 495d287

Please sign in to comment.