diff --git a/src/dstack/_internal/core/backends/aws/backend.py b/src/dstack/_internal/core/backends/aws/backend.py index 3dfd4f4093..1169227cc7 100644 --- a/src/dstack/_internal/core/backends/aws/backend.py +++ b/src/dstack/_internal/core/backends/aws/backend.py @@ -1,3 +1,5 @@ +from typing import Optional + import botocore.exceptions from dstack._internal.core.backends.aws.compute import AWSCompute @@ -11,9 +13,12 @@ class AWSBackend(Backend): TYPE = BackendType.AWS COMPUTE_CLASS = AWSCompute - def __init__(self, config: AWSConfig): + def __init__(self, config: AWSConfig, compute: Optional[AWSCompute] = None): self.config = config - self._compute = AWSCompute(self.config) + if compute is not None: + self._compute = compute + else: + self._compute = AWSCompute(self.config) self._check_credentials() def compute(self) -> AWSCompute: diff --git a/src/dstack/_internal/core/backends/aws/compute.py b/src/dstack/_internal/core/backends/aws/compute.py index 48720bb316..be3133456c 100644 --- a/src/dstack/_internal/core/backends/aws/compute.py +++ b/src/dstack/_internal/core/backends/aws/compute.py @@ -1,6 +1,7 @@ import threading from collections.abc import Iterable from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import dataclass, field from typing import Any, Callable, Dict, List, Optional, Tuple import boto3 @@ -19,6 +20,8 @@ ) from dstack._internal.core.backends.base.compute import ( Compute, + ComputeCache, + ComputeTTLCache, ComputeWithAllOffersCached, ComputeWithCreateInstanceSupport, ComputeWithGatewaySupport, @@ -94,6 +97,11 @@ def _ec2client_cache_methodkey(self, ec2_client, *args, **kwargs): return hashkey(*args, **kwargs) +@dataclass +class AWSQuotasCache(ComputeTTLCache): + execution_lock: threading.Lock = field(default_factory=threading.Lock) + + class AWSCompute( ComputeWithAllOffersCached, ComputeWithCreateInstanceSupport, @@ -106,7 +114,12 @@ class AWSCompute( ComputeWithVolumeSupport, Compute, ): - def __init__(self, config: AWSConfig): + def __init__( + self, + config: AWSConfig, + quotas_cache: Optional[AWSQuotasCache] = None, + zones_cache: Optional[ComputeCache] = None, + ): super().__init__() self.config = config if isinstance(config.creds, AWSAccessKeyCreds): @@ -119,23 +132,18 @@ def __init__(self, config: AWSConfig): # Caches to avoid redundant API calls when provisioning many instances # get_offers is already cached but we still cache its sub-functions # with more aggressive/longer caches. - self._offers_post_filter_cache_lock = threading.Lock() - self._offers_post_filter_cache = TTLCache(maxsize=10, ttl=180) - self._get_regions_to_quotas_cache_lock = threading.Lock() - self._get_regions_to_quotas_execution_lock = threading.Lock() - self._get_regions_to_quotas_cache = TTLCache(maxsize=10, ttl=300) - self._get_regions_to_zones_cache_lock = threading.Lock() - self._get_regions_to_zones_cache = Cache(maxsize=10) - self._get_vpc_id_subnet_id_or_error_cache_lock = threading.Lock() - self._get_vpc_id_subnet_id_or_error_cache = TTLCache(maxsize=100, ttl=600) - self._get_maximum_efa_interfaces_cache_lock = threading.Lock() - self._get_maximum_efa_interfaces_cache = Cache(maxsize=100) - self._get_subnets_availability_zones_cache_lock = threading.Lock() - self._get_subnets_availability_zones_cache = Cache(maxsize=100) - self._create_security_group_cache_lock = threading.Lock() - self._create_security_group_cache = TTLCache(maxsize=100, ttl=600) - self._get_image_id_and_username_cache_lock = threading.Lock() - self._get_image_id_and_username_cache = TTLCache(maxsize=100, ttl=600) + self._offers_post_filter_cache = ComputeTTLCache(cache=TTLCache(maxsize=10, ttl=180)) + if quotas_cache is None: + quotas_cache = AWSQuotasCache(cache=TTLCache(maxsize=10, ttl=600)) + self._regions_to_quotas_cache = quotas_cache + if zones_cache is None: + zones_cache = ComputeCache(cache=Cache(maxsize=10)) + self._regions_to_zones_cache = zones_cache + self._vpc_id_subnet_id_cache = ComputeTTLCache(cache=TTLCache(maxsize=100, ttl=600)) + self._maximum_efa_interfaces_cache = ComputeCache(cache=Cache(maxsize=100)) + self._subnets_availability_zones_cache = ComputeCache(cache=Cache(maxsize=100)) + self._security_group_cache = ComputeTTLCache(cache=TTLCache(maxsize=100, ttl=600)) + self._image_id_and_username_cache = ComputeTTLCache(cache=TTLCache(maxsize=100, ttl=600)) def get_all_offers_with_availability(self) -> List[InstanceOfferWithAvailability]: offers = get_catalog_offers( @@ -144,7 +152,7 @@ def get_all_offers_with_availability(self) -> List[InstanceOfferWithAvailability extra_filter=_supported_instances, ) regions = list(set(i.region for i in offers)) - with self._get_regions_to_quotas_execution_lock: + with self._regions_to_quotas_cache.execution_lock: # Cache lock does not prevent concurrent execution. # We use a separate lock to avoid requesting quotas in parallel and hitting rate limits. regions_to_quotas = self._get_regions_to_quotas(self.session, regions) @@ -173,9 +181,9 @@ def _get_offers_cached_key(self, requirements: Requirements) -> int: return hash(requirements.json()) @cachedmethod( - cache=lambda self: self._offers_post_filter_cache, + cache=lambda self: self._offers_post_filter_cache.cache, key=_get_offers_cached_key, - lock=lambda self: self._offers_post_filter_cache_lock, + lock=lambda self: self._offers_post_filter_cache.lock, ) def get_offers_post_filter( self, requirements: Requirements @@ -789,9 +797,9 @@ def _get_regions_to_quotas_key( return hashkey(tuple(regions)) @cachedmethod( - cache=lambda self: self._get_regions_to_quotas_cache, + cache=lambda self: self._regions_to_quotas_cache.cache, key=_get_regions_to_quotas_key, - lock=lambda self: self._get_regions_to_quotas_cache_lock, + lock=lambda self: self._regions_to_quotas_cache.lock, ) def _get_regions_to_quotas( self, @@ -808,9 +816,9 @@ def _get_regions_to_zones_key( return hashkey(tuple(regions)) @cachedmethod( - cache=lambda self: self._get_regions_to_zones_cache, + cache=lambda self: self._regions_to_zones_cache.cache, key=_get_regions_to_zones_key, - lock=lambda self: self._get_regions_to_zones_cache_lock, + lock=lambda self: self._regions_to_zones_cache.lock, ) def _get_regions_to_zones( self, @@ -832,9 +840,9 @@ def _get_vpc_id_subnet_id_or_error_cache_key( ) @cachedmethod( - cache=lambda self: self._get_vpc_id_subnet_id_or_error_cache, + cache=lambda self: self._vpc_id_subnet_id_cache.cache, key=_get_vpc_id_subnet_id_or_error_cache_key, - lock=lambda self: self._get_vpc_id_subnet_id_or_error_cache_lock, + lock=lambda self: self._vpc_id_subnet_id_cache.lock, ) def _get_vpc_id_subnet_id_or_error( self, @@ -853,9 +861,9 @@ def _get_vpc_id_subnet_id_or_error( ) @cachedmethod( - cache=lambda self: self._get_maximum_efa_interfaces_cache, + cache=lambda self: self._maximum_efa_interfaces_cache.cache, key=_ec2client_cache_methodkey, - lock=lambda self: self._get_maximum_efa_interfaces_cache_lock, + lock=lambda self: self._maximum_efa_interfaces_cache.lock, ) def _get_maximum_efa_interfaces( self, @@ -877,9 +885,9 @@ def _get_subnets_availability_zones_key( return hashkey(region, tuple(subnet_ids)) @cachedmethod( - cache=lambda self: self._get_subnets_availability_zones_cache, + cache=lambda self: self._subnets_availability_zones_cache.cache, key=_get_subnets_availability_zones_key, - lock=lambda self: self._get_subnets_availability_zones_cache_lock, + lock=lambda self: self._subnets_availability_zones_cache.lock, ) def _get_subnets_availability_zones( self, @@ -893,9 +901,9 @@ def _get_subnets_availability_zones( ) @cachedmethod( - cache=lambda self: self._create_security_group_cache, + cache=lambda self: self._security_group_cache.cache, key=_ec2client_cache_methodkey, - lock=lambda self: self._create_security_group_cache_lock, + lock=lambda self: self._security_group_cache.lock, ) def _create_security_group( self, @@ -923,9 +931,9 @@ def _get_image_id_and_username_cache_key( ) @cachedmethod( - cache=lambda self: self._get_image_id_and_username_cache, + cache=lambda self: self._image_id_and_username_cache.cache, key=_get_image_id_and_username_cache_key, - lock=lambda self: self._get_image_id_and_username_cache_lock, + lock=lambda self: self._image_id_and_username_cache.lock, ) def _get_image_id_and_username( self, diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py index 75a68e77ff..49513e3211 100644 --- a/src/dstack/_internal/core/backends/base/compute.py +++ b/src/dstack/_internal/core/backends/base/compute.py @@ -6,6 +6,7 @@ import threading from abc import ABC, abstractmethod from collections.abc import Iterable, Iterator +from dataclasses import dataclass, field from enum import Enum from functools import lru_cache from pathlib import Path @@ -14,7 +15,7 @@ import git import requests import yaml -from cachetools import TTLCache, cachedmethod +from cachetools import Cache, TTLCache, cachedmethod from gpuhunt import CPUArchitecture from dstack._internal import settings @@ -89,6 +90,18 @@ def to_cpu_architecture(self) -> CPUArchitecture: assert False, self +@dataclass +class ComputeCache: + cache: Cache + lock: threading.Lock = field(default_factory=threading.Lock) + + +@dataclass +class ComputeTTLCache: + cache: TTLCache + lock: threading.Lock = field(default_factory=threading.Lock) + + class Compute(ABC): """ A base class for all compute implementations with minimal features. diff --git a/src/dstack/_internal/core/backends/gcp/compute.py b/src/dstack/_internal/core/backends/gcp/compute.py index c2c18e3d9f..cd5ecb829f 100644 --- a/src/dstack/_internal/core/backends/gcp/compute.py +++ b/src/dstack/_internal/core/backends/gcp/compute.py @@ -1,7 +1,6 @@ import concurrent.futures import json import re -import threading from collections import defaultdict from collections.abc import Iterable from dataclasses import dataclass @@ -19,6 +18,7 @@ from dstack import version from dstack._internal.core.backends.base.compute import ( Compute, + ComputeTTLCache, ComputeWithAllOffersCached, ComputeWithCreateInstanceSupport, ComputeWithGatewaySupport, @@ -127,11 +127,9 @@ def __init__(self, config: GCPConfig): credentials=self.credentials ) self.reservations_client = compute_v1.ReservationsClient(credentials=self.credentials) - self._usable_subnets_cache_lock = threading.Lock() - self._usable_subnets_cache = TTLCache(maxsize=1, ttl=120) - self._find_reservation_cache_lock = threading.Lock() - # smaller TTL, since we check the reservation's in_use_count, which can change often - self._find_reservation_cache = TTLCache(maxsize=8, ttl=20) + self._usable_subnets_cache = ComputeTTLCache(cache=TTLCache(maxsize=1, ttl=120)) + # Smaller TTL since we check the reservation's in_use_count, which can change often + self._reservation_cache = ComputeTTLCache(cache=TTLCache(maxsize=8, ttl=20)) def get_all_offers_with_availability(self) -> List[InstanceOfferWithAvailability]: regions = get_or_error(self.config.regions) @@ -948,8 +946,8 @@ def _get_roce_subnets( return nic_subnets @cachedmethod( - cache=lambda self: self._usable_subnets_cache, - lock=lambda self: self._usable_subnets_cache_lock, + cache=lambda self: self._usable_subnets_cache.cache, + lock=lambda self: self._usable_subnets_cache.lock, ) def _list_usable_subnets(self) -> list[compute_v1.UsableSubnetwork]: # To avoid hitting the `ListUsable requests per minute` system limit, we fetch all subnets @@ -969,8 +967,8 @@ def _get_vpc_subnet(self, region: str) -> Optional[str]: ) @cachedmethod( - cache=lambda self: self._find_reservation_cache, - lock=lambda self: self._find_reservation_cache_lock, + cache=lambda self: self._reservation_cache.cache, + lock=lambda self: self._reservation_cache.lock, ) def _find_reservation(self, configured_name: str) -> dict[str, compute_v1.Reservation]: if match := RESERVATION_PATTERN.fullmatch(configured_name): diff --git a/src/dstack/_internal/server/services/backends/__init__.py b/src/dstack/_internal/server/services/backends/__init__.py index 53284e6175..ce0f17bde5 100644 --- a/src/dstack/_internal/server/services/backends/__init__.py +++ b/src/dstack/_internal/server/services/backends/__init__.py @@ -1,5 +1,6 @@ import asyncio import heapq +import time from collections.abc import Iterable, Iterator from typing import Callable, Coroutine, Dict, List, Optional, Tuple from uuid import UUID @@ -361,7 +362,7 @@ def get_filtered_offers_with_backends( yield (backend, offer) logger.info("Requesting instance offers from backends: %s", [b.TYPE.value for b in backends]) - tasks = [run_async(backend.compute().get_offers, requirements) for backend in backends] + tasks = [run_async(get_offers_tracked, backend, requirements) for backend in backends] offers_by_backend = [] for backend, result in zip(backends, await asyncio.gather(*tasks, return_exceptions=True)): if isinstance(result, BackendError): @@ -391,3 +392,13 @@ def check_backend_type_available(backend_type: BackendType): " Ensure that backend dependencies are installed." f" Available backends: {[b.value for b in list_available_backend_types()]}." ) + + +def get_offers_tracked( + backend: Backend, requirements: Requirements +) -> Iterator[InstanceOfferWithAvailability]: + start = time.time() + res = backend.compute().get_offers(requirements) + duration = time.time() - start + logger.debug("Got offers from %s in %.6fs", backend.TYPE.value, duration) + return res