Skip to content

Commit

Permalink
Fix SQ Hints
Browse files Browse the repository at this point in the history
  • Loading branch information
amengus87 committed Sep 26, 2024
1 parent c1ce38a commit 082a07c
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 17 deletions.
2 changes: 1 addition & 1 deletion backend/src/main/java/ai/dragon/entity/FarmEntity.java
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,6 @@ public FarmEntity() {
this.raagIdentifier = UUID.randomUUID().toString();
this.languageModel = LanguageModelType.OpenAiModel;
this.chatMemoryStrategy = ChatMemoryStrategy.MaxMessages;
this.queryRouter = QueryRouterType.Default;
this.queryRouter = QueryRouterType.DEFAULT;
}
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package ai.dragon.enumeration;

public enum QueryRouterType {
Default("Default"),
LanguageModel("LanguageModel");
DEFAULT("Default"),
LANGUAGE_MODEL("LanguageModel");

private String value;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ public abstract class AbstractRepository<T extends AbstractEntity> {
@Autowired
protected DatabaseService databaseService;

private ObjectRepository<T> repository;
private ObjectRepository<T> objectRepository;

public void executeTransaction(Consumer<AbstractRepository<T>> transactionConsumer) {
if (this instanceof TransactionalRepository) {
Expand Down Expand Up @@ -239,31 +239,31 @@ private void validate(T entity) {
}

protected ObjectRepository<T> getObjectRepository() {
if (repository != null) {
return repository;
if (objectRepository != null) {
return objectRepository;
}

Nitrite db = databaseService.getNitriteDB();
repository = db.getRepository(getGenericSuperclass());
objectRepository = db.getRepository(getGenericSuperclass());

return repository;
return objectRepository;
}

// Inner class to handle transactional operations
private static class TransactionalRepository<T extends AbstractEntity> extends AbstractRepository<T> {
private final ObjectRepository<T> repository;
private final ObjectRepository<T> objectRepository;
private final Class<T> type;

TransactionalRepository(ObjectRepository<T> repository, Class<T> type, DatabaseService databaseService) {
TransactionalRepository(ObjectRepository<T> objectRepository, Class<T> type, DatabaseService databaseService) {
super();
this.repository = repository;
this.objectRepository = objectRepository;
this.databaseService = databaseService;
this.type = type;
}

@Override
protected ObjectRepository<T> getObjectRepository() {
return this.repository;
return this.objectRepository;
}

@Override
Expand Down
2 changes: 1 addition & 1 deletion backend/src/main/java/ai/dragon/service/RaagService.java
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ private void buildRetrievalAugmentor(AiServices<AiAssistant> assistantBuilder, F
DefaultRetrievalAugmentorBuilder retrievalAugmentorBuilder = DefaultRetrievalAugmentor.builder();
Map<ContentRetriever, String> retrievers = this.buildRetrieverMap(farm, servletRequest);
if (retrievers != null && !retrievers.isEmpty()) {
if (QueryRouterType.LanguageModel.equals(farm.getQueryRouter())) {
if (QueryRouterType.LANGUAGE_MODEL.equals(farm.getQueryRouter())) {
retrievalAugmentorBuilder.queryRouter(new LanguageModelQueryRouter(chatLanguageModel, retrievers,
retrievalSettings.getLanguageQueryRouterPromptTemplate(),
retrievalSettings.getLanguageQueryRouterFallbackStrategy()));
Expand Down
2 changes: 1 addition & 1 deletion backend/src/main/java/ai/dragon/util/fluenttry/Try.java
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ public T run(Callable<T> callable) {
}
executorService.shutdown();
if (ex != null && fallback != null) {
return (T) fallback.apply(ex);
return fallback.apply(ex);
}
return result;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ static void beforeAll(@Autowired FarmRepository farmRepository,
.setLanguageModelSettings(List.of(apiKeySetting, omniModelNameSetting));
farmWithSunspotsAndWebSSHSilosFallbackFail
.setSilos(List.of(sunspotsSilo.getUuid(), websshSilo.getUuid()));
farmWithSunspotsAndWebSSHSilosFallbackFail.setQueryRouter(QueryRouterType.LanguageModel);
farmWithSunspotsAndWebSSHSilosFallbackFail.setQueryRouter(QueryRouterType.LANGUAGE_MODEL);
farmWithSunspotsAndWebSSHSilosFallbackFail
.setRetrievalAugmentorSettings(List.of("languageQueryRouterFallbackStrategy=FAIL"));
farmRepository.save(farmWithSunspotsAndWebSSHSilosFallbackFail);
Expand All @@ -188,7 +188,7 @@ static void beforeAll(@Autowired FarmRepository farmRepository,
.setLanguageModelSettings(List.of(apiKeySetting, omniModelNameSetting));
farmWithSunspotsAndWebSSHSilosDoNotRoute
.setSilos(List.of(sunspotsSilo.getUuid(), websshSilo.getUuid()));
farmWithSunspotsAndWebSSHSilosDoNotRoute.setQueryRouter(QueryRouterType.LanguageModel);
farmWithSunspotsAndWebSSHSilosDoNotRoute.setQueryRouter(QueryRouterType.LANGUAGE_MODEL);
farmWithSunspotsAndWebSSHSilosDoNotRoute.setRetrievalAugmentorSettings(
List.of("languageQueryRouterFallbackStrategy=DO_NOT_ROUTE"));
farmRepository.save(farmWithSunspotsAndWebSSHSilosDoNotRoute);
Expand All @@ -201,7 +201,7 @@ static void beforeAll(@Autowired FarmRepository farmRepository,
.setLanguageModelSettings(List.of(apiKeySetting, omniModelNameSetting));
farmWithSunspotsAndWebSSHSilosRouteToAll
.setSilos(List.of(sunspotsSilo.getUuid(), websshSilo.getUuid()));
farmWithSunspotsAndWebSSHSilosRouteToAll.setQueryRouter(QueryRouterType.LanguageModel);
farmWithSunspotsAndWebSSHSilosRouteToAll.setQueryRouter(QueryRouterType.LANGUAGE_MODEL);
farmWithSunspotsAndWebSSHSilosRouteToAll.setRetrievalAugmentorSettings(
List.of("languageQueryRouterFallbackStrategy=ROUTE_TO_ALL"));
farmRepository.save(farmWithSunspotsAndWebSSHSilosRouteToAll);
Expand Down

0 comments on commit 082a07c

Please sign in to comment.