Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 16 additions & 16 deletions modelstore/storage/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,7 @@ def client(self):
def validate(self) -> bool:
logger.debug("Querying for buckets with prefix=%s...", self.bucket_name)
try:
resource = boto3.resource("s3")
resource.meta.client.head_bucket(Bucket=self.bucket_name)
self.client.head_bucket(Bucket=self.bucket_name)
return True
except ClientError:
logger.error("Unable to access bucket: %s", self.bucket_name)
Expand Down Expand Up @@ -137,20 +136,21 @@ def _get_storage_location(self, meta_data: metadata.Storage) -> str:
def _read_json_objects(self, prefix: str) -> list:
logger.debug("Listing files in: %s/%s", self.bucket_name, prefix)
results = []
objects = self.client.list_objects_v2(Bucket=self.bucket_name, Prefix=prefix)
for version in objects.get("Contents", []):
object_path = version["Key"]
if not object_path.endswith(".json"):
logger.debug("Skipping non-json file: %s", object_path)
continue
if os.path.split(object_path)[0] != prefix:
# We don't want to read files in a sub-prefix
logger.debug("Skipping file in sub-prefix: %s", object_path)
continue

obj = self._read_json_object(object_path)
if obj is not None:
results.append(obj)
paginator = self.client.get_paginator("list_objects_v2")
for page in paginator.paginate(Bucket=self.bucket_name, Prefix=prefix):
for version in page.get("Contents", []):
object_path = version["Key"]
if not object_path.endswith(".json"):
logger.debug("Skipping non-json file: %s", object_path)
continue
if os.path.split(object_path)[0] != prefix:
# We don't want to read files in a sub-prefix
logger.debug("Skipping file in sub-prefix: %s", object_path)
continue

obj = self._read_json_object(object_path)
if obj is not None:
results.append(obj)
return sorted_by_created(results)

def _read_json_object(self, prefix: str) -> dict:
Expand Down
44 changes: 17 additions & 27 deletions modelstore/storage/backblaze.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def __init__(
endpoint: Optional[str] = None,
region: Optional[str] = None,
root_prefix: Optional[str] = None,
_client=None,
):
super().__init__(["boto3"], root_prefix, "MODEL_STORE_B2_ROOT_PREFIX")
self.bucket_name = environment.get_value(
Expand All @@ -86,7 +87,7 @@ def __init__(
)
if self.endpoint is None:
self.endpoint = f"https://s3.{self.region}.backblazeb2.com"
self.__client = None
self.__client = _client

@staticmethod
def _boto_config():
Expand Down Expand Up @@ -115,22 +116,10 @@ def client(self):
logger.error("Unable to create B2 s3 client!")
raise

def _get_resource(self):
"""Returns a boto s3 resource configured for B2"""
return boto3.resource(
"s3",
region_name=self.region,
endpoint_url=self.endpoint,
aws_access_key_id=self.key_id,
aws_secret_access_key=self.application_key,
config=self._boto_config(),
)

def validate(self) -> bool:
logger.debug("Querying for buckets with prefix=%s...", self.bucket_name)
try:
resource = self._get_resource()
resource.meta.client.head_bucket(Bucket=self.bucket_name)
self.client.head_bucket(Bucket=self.bucket_name)
return True
except ClientError:
logger.error("Unable to access bucket: %s", self.bucket_name)
Expand Down Expand Up @@ -184,19 +173,20 @@ def _get_storage_location(self, meta_data: metadata.Storage) -> str:
def _read_json_objects(self, prefix: str) -> list:
logger.debug("Listing files in: %s/%s", self.bucket_name, prefix)
results = []
objects = self.client.list_objects_v2(Bucket=self.bucket_name, Prefix=prefix)
for version in objects.get("Contents", []):
object_path = version["Key"]
if not object_path.endswith(".json"):
logger.debug("Skipping non-json file: %s", object_path)
continue
if os.path.split(object_path)[0] != prefix:
logger.debug("Skipping file in sub-prefix: %s", object_path)
continue

obj = self._read_json_object(object_path)
if obj is not None:
results.append(obj)
paginator = self.client.get_paginator("list_objects_v2")
for page in paginator.paginate(Bucket=self.bucket_name, Prefix=prefix):
for version in page.get("Contents", []):
object_path = version["Key"]
if not object_path.endswith(".json"):
logger.debug("Skipping non-json file: %s", object_path)
continue
if os.path.split(object_path)[0] != prefix:
logger.debug("Skipping file in sub-prefix: %s", object_path)
continue

obj = self._read_json_object(object_path)
if obj is not None:
results.append(obj)
return sorted_by_created(results)

def _read_json_object(self, prefix: str) -> dict:
Expand Down
3 changes: 1 addition & 2 deletions tests/storage/test_aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,9 @@ def get_file_contents(moto_boto, prefix):
def test_create_from_environment_variables(monkeypatch):
# Does not fail when environment variables exist
monkeypatch.setenv("MODEL_STORE_AWS_BUCKET", _MOCK_BUCKET_NAME)
# pylint: disable=bare-except
try:
_ = AWSStorage()
except:
except Exception:
pytest.fail("Failed to initialise storage from env variables")


Expand Down
14 changes: 6 additions & 8 deletions tests/storage/test_backblaze.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,26 +60,23 @@ def get_file_contents(moto_boto, prefix):


def _create_storage():
storage = BackblazeStorage(
client = boto3.client("s3", region_name="us-east-1")
return BackblazeStorage(
bucket_name=_MOCK_BUCKET_NAME,
key_id="testing",
application_key="testing",
region="us-east-1",
_client=client,
)
# Override endpoint so boto3 uses the default AWS endpoint
# that moto intercepts
storage.endpoint = None
return storage


def test_create_from_environment_variables(monkeypatch):
monkeypatch.setenv("MODEL_STORE_B2_BUCKET", _MOCK_BUCKET_NAME)
monkeypatch.setenv("B2_APPLICATION_KEY_ID", "testing")
monkeypatch.setenv("B2_APPLICATION_KEY", "testing")
# pylint: disable=bare-except
try:
_ = BackblazeStorage()
except:
except Exception:
pytest.fail("Failed to initialise storage from env variables")


Expand All @@ -104,13 +101,14 @@ def test_create_fails_with_missing_environment_variables(monkeypatch):
],
)
def test_validate(bucket_name, validate_should_pass):
client = boto3.client("s3", region_name="us-east-1")
storage = BackblazeStorage(
bucket_name=bucket_name,
key_id="testing",
application_key="testing",
region="us-east-1",
_client=client,
)
storage.endpoint = None
assert storage.validate() == validate_should_pass


Expand Down
Loading