Skip to content

Commit

Permalink
update type annotations (#383)
Browse files Browse the repository at this point in the history
* update type annotations

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* pre-commit fixes

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: jaimergp <jaimergp@users.noreply.github.com>
  • Loading branch information
3 people authored Nov 28, 2023
1 parent 993d739 commit a244048
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions conda_libmamba_solver/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,15 @@
We maintain a map of subdir-specific URLs to `conda.model.channel.Channel`
and `libmamba.Repo` objects.
"""
from __future__ import annotations

import logging
import os
from dataclasses import dataclass
from functools import lru_cache, partial
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Dict, Iterable, Optional, Tuple, Union
from typing import Iterable

import libmambapy as api
from conda.base.constants import REPODATA_FN
Expand Down Expand Up @@ -109,8 +111,8 @@ class LibMambaIndexHelper(IndexHelper):
def __init__(
self,
installed_records: Iterable[PackageRecord] = (),
channels: Iterable[Union[Channel, str]] = None,
subdirs: Iterable[str] = None,
channels: Iterable[Channel | str] | None = None,
subdirs: Iterable[str] | None = None,
repodata_fn: str = REPODATA_FN,
query_format=api.QueryFormat.JSON,
):
Expand Down Expand Up @@ -217,7 +219,7 @@ def _repo_from_records(
finally:
os.unlink(f.name)

def _fetch_channel(self, url: str) -> Tuple[str, os.PathLike]:
def _fetch_channel(self, url: str) -> tuple[str, os.PathLike]:
channel = Channel.from_url(url)
if not channel.subdir:
raise ValueError(f"Channel URLs must specify a subdir! Provided: {url}")
Expand All @@ -238,7 +240,7 @@ def _fetch_channel(self, url: str) -> Tuple[str, os.PathLike]:

def _json_path_to_repo_info(
self, url: str, json_path: str, try_solv: bool = False
) -> Optional[_ChannelRepoInfo]:
) -> _ChannelRepoInfo | None:
channel = Channel.from_url(url)
noauth_url = channel.urls(with_credentials=False, subdirs=(channel.subdir,))[0]
json_path = Path(json_path)
Expand Down Expand Up @@ -279,7 +281,7 @@ def _json_path_to_repo_info(
noauth_url=noauth_url,
)

def _load_channels(self) -> Dict[str, _ChannelRepoInfo]:
def _load_channels(self) -> dict[str, _ChannelRepoInfo]:
# 1. Obtain and deduplicate URLs from channels
urls = []
seen_noauth = set()
Expand Down Expand Up @@ -330,24 +332,22 @@ def _load_installed(self, records: Iterable[PackageRecord]) -> api.Repo:
return repo

def whoneeds(
self, query: Union[str, MatchSpec], records=True
) -> Union[Iterable[PackageRecord], dict, str]:
self, query: str | MatchSpec, records=True
) -> Iterable[PackageRecord] | dict | str:
result_str = self._query.whoneeds(self._prepare_query(query), self._format)
if self._format == api.QueryFormat.JSON:
return self._process_query_result(result_str, records=records)
return result_str

def depends(
self, query: Union[str, MatchSpec], records=True
) -> Union[Iterable[PackageRecord], dict, str]:
self, query: str | MatchSpec, records=True
) -> Iterable[PackageRecord] | dict | str:
result_str = self._query.depends(self._prepare_query(query), self._format)
if self._format == api.QueryFormat.JSON:
return self._process_query_result(result_str, records=records)
return result_str

def search(
self, query: Union[str, MatchSpec], records=True
) -> Union[Iterable[PackageRecord], dict, str]:
def search(self, query: str | MatchSpec, records=True) -> Iterable[PackageRecord] | dict | str:
result_str = self._query.find(self._prepare_query(query), self._format)
if self._format == api.QueryFormat.JSON:
return self._process_query_result(result_str, records=records)
Expand All @@ -364,7 +364,7 @@ def explicit_pool(self, specs: Iterable[MatchSpec]) -> Iterable[str]:
explicit_pool.add(record.name)
return tuple(explicit_pool)

def _prepare_query(self, query: Union[str, MatchSpec]) -> str:
def _prepare_query(self, query: str | MatchSpec) -> str:
if isinstance(query, str):
if "[" not in query:
return query
Expand All @@ -391,7 +391,7 @@ def _process_query_result(
self,
result_str,
records=True,
) -> Union[Iterable[PackageRecord], dict]:
) -> Iterable[PackageRecord] | dict:
result = json_load(result_str)
if result.get("result", {}).get("status") != "OK":
query_type = result.get("query", {}).get("type", "<Unknown>")
Expand Down

0 comments on commit a244048

Please sign in to comment.