Skip to content

Commit

Permalink
refactor(project): Save point of project and project structure refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
breadrock1 committed Oct 26, 2024
1 parent cfa933a commit 97dc8c4
Show file tree
Hide file tree
Showing 11 changed files with 47 additions and 60 deletions.
3 changes: 2 additions & 1 deletion config/development.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ allowed = "*"
max_age = 3600

[elastic]
address = "localhost:9200"
address = "158.160.44.99:9200"
enabled_tls = "true"
username = "elastic"
password = "elastic"
Expand All @@ -27,3 +27,4 @@ expired = 3600
address = "localhost:8085"
is_truncate = "false"
is_normalize = "false"
enabled_tls = "false"
1 change: 1 addition & 0 deletions config/production.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,4 @@ expired = 3600
address = "embeddings:8085"
is_truncate = "false"
is_normalize = "false"
enabled_tls = "false"
27 changes: 13 additions & 14 deletions src/elastic.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
use crate::errors::Successful;
use crate::searcher::errors::SearcherResult;
use crate::Connectable;

use crate::searcher::errors::SearcherResult;
use elasticsearch::auth::Credentials;
use elasticsearch::cert::CertificateValidation;
use elasticsearch::http::headers::HeaderMap;
use elasticsearch::http::response::Response;
use elasticsearch::http::transport::{BuildError, SingleNodeConnectionPool, TransportBuilder};
use elasticsearch::http::{Method, Url};
use elasticsearch::{Elasticsearch, SearchParts};
use elasticsearch::{Elasticsearch, Error, SearchParts};
use getset::{CopyGetters, Getters};
use serde_derive::Deserialize;
use serde_json::Value;
Expand All @@ -23,28 +23,27 @@ pub struct ElasticClient {
}

#[derive(Clone, Deserialize, CopyGetters, Getters)]
#[getset(get = "pub")]
pub struct ElasticConfig {
#[getset(get = "pub")]
address: String,
#[getset(get_copy = "pub")]
enabled_tls: bool,
#[getset(get = "pub")]
username: String,
#[getset(get = "pub")]
password: String,
#[getset(skip)]
#[getset(get_copy = "pub")]
enabled_tls: bool,
}

impl ElasticClient {
pub fn es_client(&self) -> Arc<RwLock<Elasticsearch>> {
self.es_client.clone()
}

pub async fn send_request(
pub async fn send_native_request(
&self,
method: Method,
body: Option<&[u8]>,
target_url: &str,
) -> Result<Response, elasticsearch::Error> {
) -> Result<Response, Error> {
let es_client = self.es_client();
let elastic = es_client.write().await;
elastic
Expand All @@ -62,28 +61,28 @@ impl ElasticClient {
pub async fn search_request(
es: EsCxt,
query: &Value,
scroll: Option<&str>,
indexes: &[&str],
result: (i64, i64),
) -> SearcherResult<Response> {
let (size, offset) = result;
let elastic = es.read().await;
let response = elastic
.search(SearchParts::Index(indexes))
.allow_no_indices(true)
.pretty(true)
.from(offset)
.size(size)
.body(query)
.pretty(true)
.allow_no_indices(true)
.scroll(scroll.unwrap_or("1m"))
.send()
.await?;

let response = response.error_for_status_code()?;
Ok(response)
}

pub async fn extract_response_msg(
response: Response,
) -> Result<Successful, elasticsearch::Error> {
pub async fn extract_response_msg(response: Response) -> Result<Successful, Error> {
let _ = response.error_for_status_code()?;
Ok(Successful::new(200, "Done"))
}
Expand Down
8 changes: 2 additions & 6 deletions src/embeddings/native/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,9 @@ impl EmbeddingsService for EmbeddingsClient {
"normalize": self.is_normalize(),
}))
.send()
.await
.map_err(EmbeddingsError::from)?;
.await?;

let embed_data = response
.json::<Vec<Vec<f64>>>()
.await
.map_err(EmbeddingsError::from)?;
let embed_data = response.json::<Vec<Vec<f64>>>().await?;

let Some(tokens) = embed_data.first() else {
let msg = "loaded empty tokens array";
Expand Down
10 changes: 3 additions & 7 deletions src/searcher/elastic/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ use crate::searcher::errors::{PaginatedResult, SearcherError, SearcherResult};
use crate::searcher::forms::DocumentType;
use crate::searcher::forms::{DeletePaginatesForm, ScrollNextForm};
use crate::searcher::forms::{FulltextParams, SemanticParams};
use crate::searcher::models::Paginated;
use crate::searcher::{PaginatorService, SearcherService};
use crate::storage::models::{Document, DocumentVectors};

Expand All @@ -30,10 +29,7 @@ impl SearcherService for ElasticClient {
Ok(converter::to_unified_paginated(founded, return_as))
}

async fn search_semantic(
&self,
params: &SemanticParams,
) -> PaginatedResult<Value> {
async fn search_semantic(&self, params: &SemanticParams) -> PaginatedResult<Value> {
let es = self.es_client();
let query = DocumentVectors::build_search_query(params).await;
let founded = DocumentVectors::search(es, &query, params).await?;
Expand All @@ -55,6 +51,7 @@ impl PaginatorService for ElasticClient {
.iter()
.map(String::as_str)
.collect::<Vec<&str>>();

let es_client = self.es_client();
let elastic = es_client.read().await;
let response = elastic
Expand Down Expand Up @@ -85,8 +82,7 @@ impl PaginatorService for ElasticClient {
.send()
.await?;

let paginated = response.json::<Vec<Document>>().await?;
let paginated = Paginated::new(paginated);
let paginated = search::extract_searcher_result::<Document>(response).await?;
Ok(converter::to_unified_paginated(paginated, doc_type))
}
}
10 changes: 7 additions & 3 deletions src/searcher/elastic/search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@ impl Searcher<Document> for Document {
query: &Value,
params: &Self::Params,
) -> PaginatedResult<Document> {
let scroll = params.scroll_lifetime();
let results = params.result_size();
let indexes = params.folder_ids().split(',').collect::<Vec<&str>>();
let response = ElasticClient::search_request(es_cxt, query, &indexes, results).await?;
let response =
ElasticClient::search_request(es_cxt, query, Some(scroll), &indexes, results).await?;
let documents = extract_searcher_result::<Document>(response).await?;
Ok(documents)
}
Expand All @@ -41,15 +43,17 @@ impl Searcher<DocumentVectors> for DocumentVectors {
query: &Value,
params: &Self::Params,
) -> PaginatedResult<DocumentVectors> {
let scroll = params.scroll_lifetime();
let results = params.result_size();
let indexes = params.folder_ids().split(',').collect::<Vec<&str>>();
let response = ElasticClient::search_request(es_cxt, query, &indexes, results).await?;
let response =
ElasticClient::search_request(es_cxt, query, Some(scroll), &indexes, results).await?;
let documents = extract_searcher_result::<DocumentVectors>(response).await?;
Ok(documents)
}
}

async fn extract_searcher_result<T>(response: Response) -> PaginatedResult<T>
pub(super) async fn extract_searcher_result<T>(response: Response) -> PaginatedResult<T>
where
T: SearchQueryBuilder<T> + DocumentsTrait + serde::Serialize,
{
Expand Down
15 changes: 6 additions & 9 deletions src/searcher/endpoints.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
use crate::cacher::CacherService;

use crate::embeddings::EmbeddingsService;
use crate::errors::{ErrorResponse, JsonResponse, PaginateResponse, Successful, WebError};
use crate::searcher::forms::DocumentType;
use crate::errors::{ErrorResponse, JsonResponse, PaginateResponse, Successful};
use crate::searcher::forms::{DeletePaginatesForm, DocumentTypeQuery, ScrollNextForm};
use crate::searcher::forms::{FulltextParams, SemanticParams};
use crate::searcher::models::Paginated;
Expand Down Expand Up @@ -146,10 +145,7 @@ async fn search_semantic(
let client = cxt.get_ref();

let mut search_form = form.0;
let query_tokens = embed
.load_from_text(search_form.query())
.await
.map_err(WebError::from)?;
let query_tokens = embed.load_from_text(search_form.query()).await?;

search_form.set_tokens(query_tokens);

Expand Down Expand Up @@ -219,7 +215,7 @@ async fn delete_paginate_sessions(
),
),
request_body(
content = PaginateNextForm,
content = ScrollNextForm,
example = json!(ScrollNextForm::test_example(None))
),
responses(
Expand Down Expand Up @@ -248,7 +244,7 @@ async fn paginate_next(
cxt: PaginateContext,
#[cfg(feature = "enable-cacher")] cacher: CacherPaginateContext,
form: Json<ScrollNextForm>,
document_type: Query<DocumentType>,
document_type: Query<DocumentTypeQuery>,
) -> PaginateResponse<Vec<Value>> {
let client = cxt.get_ref();
let pag_form = form.0;
Expand All @@ -259,7 +255,8 @@ async fn paginate_next(
return Ok(Json(docs));
}

let documents = client.paginate(&pag_form, &document_type).await?;
let doc_type = document_type.0.get_type();
let documents = client.paginate(&pag_form, &doc_type).await?;

#[cfg(feature = "enable-cacher")]
cacher.insert(&pag_form, &documents).await;
Expand Down
6 changes: 1 addition & 5 deletions src/searcher/forms.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,12 @@ impl DocumentType {
#[derive(Default, Deserialize, IntoParams, ToSchema)]
pub struct DocumentTypeQuery {
document_type: Option<DocumentType>,
is_grouped: bool,
}

impl DocumentTypeQuery {
pub fn get_type(&self) -> DocumentType {
self.document_type.clone().unwrap_or(DocumentType::Document)
}

pub fn is_grouped(&self) -> bool {
self.is_grouped
}
}

#[derive(Builder, Debug, Deserialize, Serialize, Getters, CopyGetters, IntoParams, ToSchema)]
Expand Down Expand Up @@ -110,6 +105,7 @@ pub struct SemanticParams {
query: String,

#[getset(skip)]
#[serde(skip_serializing_if = "Option::is_none")]
query_tokens: Option<Vec<f64>>,

#[schema(example = "test-folder")]
Expand Down
5 changes: 1 addition & 4 deletions src/searcher/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,7 @@ pub trait SearcherService {
return_as: &DocumentType,
) -> PaginatedResult<Value>;

async fn search_semantic(
&self,
params: &SemanticParams,
) -> PaginatedResult<Value>;
async fn search_semantic(&self, params: &SemanticParams) -> PaginatedResult<Value>;
}

#[async_trait::async_trait]
Expand Down
2 changes: 1 addition & 1 deletion src/storage/elastic/helper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ pub async fn filter_folders(
let indexes = &[INFO_FOLDER_ID];
let results = (params.result_size(), params.result_offset());
let query = DocumentVectors::build_retrieve_query(&params).await;
let response = ElasticClient::search_request(es_cxt, &query, indexes, results).await?;
let response = ElasticClient::search_request(es_cxt, &query, None, indexes, results).await?;

let value = response.json::<Value>().await?;
let Some(founded) = &value[&"hits"][&"hits"].as_array() else {
Expand Down
20 changes: 10 additions & 10 deletions src/storage/elastic/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use crate::errors::Successful;
use crate::storage::elastic::retrieve::Retrieve;
use crate::storage::elastic::store::StoreTrait;
use crate::storage::elastic::update::UpdateTrait;
use crate::storage::errors::{StorageError, StorageResult};
use crate::storage::errors::StorageResult;
use crate::storage::forms::{CreateFolderForm, RetrieveParams};
use crate::storage::models::INFO_FOLDER_ID;
use crate::storage::models::{Document, DocumentVectors, Folder, FolderType, InfoFolder};
Expand All @@ -32,7 +32,7 @@ const CAT_INDICES_URL: &str = "/_cat/indices?format=json";
impl FolderService for ElasticClient {
async fn get_folders(&self, show_all: bool) -> StorageResult<Vec<Folder>> {
let response = self
.send_request(Method::Get, None, CAT_INDICES_URL)
.send_native_request(Method::Get, None, CAT_INDICES_URL)
.await?;

let folders = response.json::<Vec<Folder>>().await?;
Expand All @@ -42,10 +42,9 @@ impl FolderService for ElasticClient {
async fn get_folder(&self, folder_id: &str) -> StorageResult<Folder> {
let target_url = format!("/{folder_id}/_stats");
let response = self
.send_request(Method::Get, None, &target_url)
.send_native_request(Method::Get, None, &target_url)
.await?
.error_for_status_code()
.map_err(StorageError::from)?;
.error_for_status_code()?;

let value = response.json::<Value>().await?;
let mut folder = Folder::from_value(value).await?;
Expand Down Expand Up @@ -113,13 +112,15 @@ impl DocumentService for ElasticClient {
match folder_type {
FolderType::Vectors => {
let query = DocumentVectors::build_retrieve_query(params).await;
let response = ElasticClient::search_request(es, &query, &folders, results).await?;
let response =
ElasticClient::search_request(es, &query, None, &folders, results).await?;
let value = helper::extract_from_response::<DocumentVectors>(response).await?;
Ok(value)
}
_ => {
let query = Document::build_retrieve_query(params).await;
let response = ElasticClient::search_request(es, &query, &folders, results).await?;
let response =
ElasticClient::search_request(es, &query, None, &folders, results).await?;
let value = helper::extract_from_response::<Document>(response).await?;
Ok(value)
}
Expand All @@ -134,10 +135,9 @@ impl DocumentService for ElasticClient {
) -> StorageResult<Value> {
let s_doc_path = format!("/{}/_doc/{}", folder_id, doc_id);
let response = self
.send_request(Method::Get, None, &s_doc_path)
.send_native_request(Method::Get, None, &s_doc_path)
.await?
.error_for_status_code()
.map_err(StorageError::from)?;
.error_for_status_code()?;

match folder_type {
FolderType::Vectors => {
Expand Down

0 comments on commit 97dc8c4

Please sign in to comment.