diff --git a/chirp/projects/zoo/models.py b/chirp/projects/zoo/models.py index 6708813d..4e4fa117 100644 --- a/chirp/projects/zoo/models.py +++ b/chirp/projects/zoo/models.py @@ -47,6 +47,7 @@ def model_class_map() -> dict[str, Any]: 'separate_embed_model': SeparateEmbedModel, 'tfhub_model': TFHubModel, 'google_whale': GoogleWhaleModel, + 'handcrafted_features': HandcraftedFeaturesModel, } @@ -124,6 +125,24 @@ def get_preset_model_config(preset_name): model_config.model_url = 'https://tfhub.dev/google/vggish/1' model_config.embedding_index = -1 model_config.logits_index = -1 + elif preset_name == 'beans_baseline': + model_key = 'handcrafted_features' + embedding_dim = 80 + model_config.sample_rate = 32_000 + model_config.melspec_config = config_dict.ConfigDict({ + 'sample_rate': model_config.sample_rate, + 'features': 160, + 'stride': 320, + 'kernel_size': 640, + 'freq_range': (60.0, model_config.sample_rate / 2.0), + 'scaling_config': frontend.LogScalingConfig(), + }) + model_config.features_config = config_dict.ConfigDict({ + 'compute_mfccs': True, + 'aggregation': 'beans', + }) + model_config.window_size_s = 1.0 + model_config.hop_size_s = 1.0 else: raise ValueError('Unsupported model preset: %s' % preset_name) return model_key, embedding_dim, model_config