diff --git a/README.md b/README.md index 5079b86a..fa2457a0 100644 --- a/README.md +++ b/README.md @@ -572,7 +572,7 @@ Pre-trained Jumpstart models can be gotten from https://sagemaker.readthedocs.io ```bash hyp create hyp-jumpstart-endpoint \ - --version 1.0 \ + --version 1.2 \ --model-id jumpstart-model-id\ --instance-type ml.g5.8xlarge \ --endpoint-name endpoint-jumpstart @@ -587,8 +587,36 @@ hyp create hyp-jumpstart-endpoint \ | `--accept-eula` | BOOLEAN | No | Whether model terms of use have been accepted (default: false) | | `--model-version` | TEXT | No | Semantic version of the model (e.g., "1.0.0", 5-14 characters) | | `--endpoint-name` | TEXT | No | Name of SageMaker endpoint (1-63 characters, alphanumeric with hyphens) | -| `--tls-certificate-output-s3-uri` | TEXT | No | S3 URI to write the TLS certificate (optional) | +| `--tls-certificate-output-s3-uri` | TEXT | No | S3 URI to write the TLS certificate | | `--debug` | FLAG | No | Enable debug mode (default: false) | +| `--version` | TEXT | No | Schema version to use (default: "1.2") | +| `--accelerator-partition-type` | TEXT | No | MIG profile for GPU partitioning (must start with "mig-") | +| `--accelerator-partition-validation` | BOOLEAN | No | Enable MIG validation (default: true) | +| `--replicas` | INTEGER | No | Number of inference server replicas (default: 1) | +| `--max-deploy-time-in-seconds` | INTEGER | No | Maximum deployment time in seconds (default: 3600) | +| `--execution-role` | TEXT | No | IAM role ARN for deploying and managing the inference server | +| `--env` | JSON | No | Environment variables as JSON, e.g. 
`'{"KEY":"value"}'` | +| `--metrics-enabled` | BOOLEAN | No | Enable metrics collection | +| `--metrics-scrape-interval-seconds` | INTEGER | No | Scrape interval for metrics collection | +| `--model-metrics-path` | TEXT | No | Path where the model exposes metrics | +| `--model-metrics-port` | INTEGER | No | Port where the model exposes metrics | +| `--additional-configs` | JSON | No | Additional model configs as JSON key-value pairs | +| `--gated-model-download-role` | TEXT | No | IAM role ARN for downloading gated models | +| `--model-hub-name` | TEXT | No | Name of the model hub | +| `--intelligent-routing-enabled` | BOOLEAN | No | Enable intelligent routing | +| `--routing-strategy` | TEXT | No | Routing strategy: prefixaware, kvaware, session, or roundrobin | +| `--enable-l1-cache` | BOOLEAN | No | Enable L1 cache (CPU offloading) | +| `--enable-l2-cache` | BOOLEAN | No | Enable L2 cache | +| `--l2-cache-backend` | TEXT | No | L2 cache backend type | +| `--l2-cache-local-url` | TEXT | No | L2 cache URL to local storage | +| `--cache-config-file` | TEXT | No | KV cache configuration file path | +| `--load-balancer-health-check-path` | TEXT | No | Health check path for the ALB target group | +| `--load-balancer-routing-algorithm` | TEXT | No | Routing algorithm: least_outstanding_requests or round_robin | +| `--custom-certificate-acm-arn` | TEXT | No | ACM certificate ARN for custom TLS | +| `--custom-certificate-domain-name` | TEXT | No | Domain name for the custom TLS certificate | +| `--auto-scaling-spec` | JSON | No | Full autoScalingSpec JSON for autoscaling configuration | +| `--dns-hosted-zone-id` | TEXT | No | Route53 Hosted Zone ID for DNS automation | +| `--data-capture` | JSON | No | Data capture configuration JSON for SageMaker, LoadBalancer, and Model Pod tiers | #### Invoke a JumpstartModel Endpoint @@ -671,7 +699,7 @@ hyp create #### **Option 2**: Create custom endpoint through create command ```bash hyp create hyp-custom-endpoint \ - --version 1.0 
\ + --version 1.2 \ --endpoint-name endpoint-custom \ --model-name my-pytorch-model \ --model-source-type s3 \ @@ -686,17 +714,22 @@ hyp create hyp-custom-endpoint \ | Parameter | Type | Required | Description | |-----------|------|----------|-------------| -| `--instance-type` | TEXT | Yes | EC2 instance type for inference (must start with "ml.") | | `--model-name` | TEXT | Yes | Name of model to create on SageMaker (1-63 characters, alphanumeric with hyphens) | -| `--model-source-type` | TEXT | Yes | Model source type ("s3" or "fsx") | +| `--model-source-type` | TEXT | Yes | Model source type: "s3", "fsx", "huggingface", or "kubernetesVolume" | | `--image-uri` | TEXT | Yes | Docker image URI for inference | | `--container-port` | INTEGER | Yes | Port on which model server listens (1-65535) | | `--model-volume-mount-name` | TEXT | Yes | Name of the model volume mount | | `--namespace` | TEXT | No | Kubernetes namespace | | `--metadata-name` | TEXT | No | Name of the custom endpoint object | | `--endpoint-name` | TEXT | No | Name of SageMaker endpoint (1-63 characters, alphanumeric with hyphens) | -| `--env` | OBJECT | No | Environment variables as key-value pairs | +| `--version` | TEXT | No | Schema version to use (default: "1.2") | +| `--instance-type` | TEXT | No | EC2 instance type (mutually exclusive with --instance-types) | +| `--instance-types` | TEXT | No | Comma-separated list of instance types in order of preference | +| `--env` | JSON | No | Environment variables as JSON, e.g. 
`'{"KEY":"value"}'` | | `--metrics-enabled` | BOOLEAN | No | Enable metrics collection (default: false) | +| `--metrics-scrape-interval-seconds` | INTEGER | No | Scrape interval for metrics collection | +| `--model-metrics-path` | TEXT | No | Path where the model exposes metrics | +| `--model-metrics-port` | INTEGER | No | Port where the model exposes metrics | | `--model-version` | TEXT | No | Version of the model (semantic version format) | | `--model-location` | TEXT | No | Specific model data location | | `--prefetch-enabled` | BOOLEAN | No | Whether to pre-fetch model data (default: false) | @@ -706,10 +739,42 @@ hyp create hyp-custom-endpoint \ | `--fsx-mount-name` | TEXT | No | FSx File System Mount Name | | `--s3-bucket-name` | TEXT | No | S3 bucket location | | `--s3-region` | TEXT | No | S3 bucket region | +| `--huggingface-model-id` | TEXT | No | HuggingFace Hub model identifier (e.g. "meta-llama/Llama-3.1-8B-Instruct") | +| `--huggingface-commit-sha` | TEXT | No | Git commit SHA for the model revision (40-char hex) | +| `--huggingface-token-secret-name` | TEXT | No | Name of the K8s Secret containing the HuggingFace API token | +| `--huggingface-token-secret-key` | TEXT | No | Key in the K8s Secret for the HuggingFace API token | | `--model-volume-mount-path` | TEXT | No | Path inside container for model volume (default: "/opt/ml/model") | -| `--resources-limits` | OBJECT | No | Resource limits for the worker | -| `--resources-requests` | OBJECT | No | Resource requests for the worker | -| `--dimensions` | OBJECT | No | CloudWatch Metric dimensions as key-value pairs | +| `--resources-limits` | JSON | No | Resource limits, e.g. `'{"nvidia.com/gpu":"1"}'` | +| `--resources-requests` | JSON | No | Resource requests, e.g. 
`'{"cpu":"1","memory":"2Gi"}'` | +| `--replicas` | INTEGER | No | Number of inference server replicas (default: 1) | +| `--initial-replica-count` | INTEGER | No | Number of desired pods (defaults to 1) | +| `--max-deploy-time-in-seconds` | INTEGER | No | Maximum deployment time in seconds (default: 3600) | +| `--worker-args` | TEXT | No | Comma-separated arguments to the entrypoint | +| `--worker-command` | TEXT | No | Comma-separated entrypoint command array | +| `--working-dir` | TEXT | No | Working directory of the container | +| `--invocation-endpoint` | TEXT | No | Invocation endpoint path (default: "invocations") | +| `--intelligent-routing-enabled` | BOOLEAN | No | Enable intelligent routing | +| `--routing-strategy` | TEXT | No | Routing strategy: prefixaware, kvaware, session, or roundrobin | +| `--enable-l1-cache` | BOOLEAN | No | Enable L1 cache (CPU offloading) | +| `--enable-l2-cache` | BOOLEAN | No | Enable L2 cache | +| `--l2-cache-backend` | TEXT | No | L2 cache backend type | +| `--l2-cache-local-url` | TEXT | No | L2 cache URL to local storage | +| `--cache-config-file` | TEXT | No | KV cache configuration file path | +| `--load-balancer-health-check-path` | TEXT | No | Health check path for the ALB target group | +| `--load-balancer-routing-algorithm` | TEXT | No | Routing algorithm: least_outstanding_requests or round_robin | +| `--max-concurrent-requests` | INTEGER | No | Maximum concurrent requests per pod | +| `--max-queue-size` | INTEGER | No | Maximum request queue size | +| `--overflow-status-code` | INTEGER | No | HTTP status code when request limits exceeded (default: 429) | +| `--custom-certificate-acm-arn` | TEXT | No | ACM certificate ARN for custom TLS | +| `--custom-certificate-domain-name` | TEXT | No | Domain name for the custom TLS certificate | +| `--kubernetes` | JSON | No | Kubernetes customizations (initContainers, volumes, schedulerName, serviceAccountName) | +| `--node-affinity` | JSON | No | Node affinity JSON for advanced 
scheduling | +| `--tags` | JSON | No | Tags as JSON key-value pairs | +| `--probes` | JSON | No | Container probes JSON (livenessProbe, readinessProbe, startupProbe) | +| `--auto-scaling-spec` | JSON | No | Full autoScalingSpec JSON (overrides individual CloudWatch fields) | +| `--dns-hosted-zone-id` | TEXT | No | Route53 Hosted Zone ID for DNS automation | +| `--data-capture` | JSON | No | Data capture configuration JSON for SageMaker, LoadBalancer, and Model Pod tiers | +| `--dimensions` | JSON | No | CloudWatch Metric dimensions as key-value pairs | | `--metric-collection-period` | INTEGER | No | Period for CloudWatch query (default: 300) | | `--metric-collection-start-time` | INTEGER | No | StartTime for CloudWatch query (default: 300) | | `--metric-name` | TEXT | No | Metric name to query for CloudWatch trigger | @@ -720,7 +785,6 @@ hyp create hyp-custom-endpoint \ | `--cloud-watch-trigger-namespace` | TEXT | No | AWS CloudWatch namespace for the metric | | `--target-value` | NUMBER | No | Target value for the CloudWatch metric | | `--use-cached-metrics` | BOOLEAN | No | Enable caching of metric values (default: true) | -| `--invocation-endpoint` | TEXT | No | Invocation endpoint path (default: "invocations") | | `--debug` | FLAG | No | Enable debug mode (default: false) | diff --git a/hyperpod-custom-inference-template/hyperpod_custom_inference_template/registry.py b/hyperpod-custom-inference-template/hyperpod_custom_inference_template/registry.py index 5fbb3832..c3380717 100644 --- a/hyperpod-custom-inference-template/hyperpod_custom_inference_template/registry.py +++ b/hyperpod-custom-inference-template/hyperpod_custom_inference_template/registry.py @@ -12,16 +12,21 @@ # language governing permissions and limitations under the License. 
from hyperpod_custom_inference_template.v1_0 import model as v1_0 from hyperpod_custom_inference_template.v1_1 import model as v1_1 +from hyperpod_custom_inference_template.v1_2 import model as v1_2 from hyperpod_custom_inference_template.v1_0.template import ( TEMPLATE_CONTENT as v1_0_template, ) from hyperpod_custom_inference_template.v1_1.template import ( TEMPLATE_CONTENT as v1_1_template, ) +from hyperpod_custom_inference_template.v1_2.template import ( + TEMPLATE_CONTENT as v1_2_template, +) SCHEMA_REGISTRY = { "1.0": v1_0.FlatHPEndpoint, "1.1": v1_1.FlatHPEndpoint, + "1.2": v1_2.FlatHPEndpoint, } -TEMPLATE_REGISTRY = {"1.0": v1_0_template, "1.1": v1_1_template} +TEMPLATE_REGISTRY = {"1.0": v1_0_template, "1.1": v1_1_template, "1.2": v1_2_template} diff --git a/hyperpod-custom-inference-template/hyperpod_custom_inference_template/v1_2/__init__.py b/hyperpod-custom-inference-template/hyperpod_custom_inference_template/v1_2/__init__.py new file mode 100644 index 00000000..65490521 --- /dev/null +++ b/hyperpod-custom-inference-template/hyperpod_custom_inference_template/v1_2/__init__.py @@ -0,0 +1,12 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
diff --git a/hyperpod-custom-inference-template/hyperpod_custom_inference_template/v1_2/model.py b/hyperpod-custom-inference-template/hyperpod_custom_inference_template/v1_2/model.py new file mode 100644 index 00000000..6119d18d --- /dev/null +++ b/hyperpod-custom-inference-template/hyperpod_custom_inference_template/v1_2/model.py @@ -0,0 +1,736 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from pydantic import BaseModel, Field, model_validator, ConfigDict +from typing import Optional, List, Dict, Union, Literal, Any + +from sagemaker.hyperpod.inference.config.hp_endpoint_config import ( + Metrics, + ModelMetrics, + FsxStorage, + S3Storage, + HuggingFaceModel, + TokenSecretRef, + ModelSourceConfig, + TlsConfig, + CustomCertificateConfig, + EnvironmentVariables, + ModelInvocationPort, + ModelVolumeMount, + Resources, + Worker, + Dimensions, + AutoScalingSpec, + CloudWatchTrigger, + IntelligentRoutingSpec, + KvCacheSpec, + L2CacheSpec, + LoadBalancer, + Kubernetes, + Probes, + Probe, + RequestLimits, + Tags, + NodeAffinity, + DataCapture, + DnsConfig, +) +from sagemaker.hyperpod.inference.hp_endpoint import HPEndpoint +from sagemaker.hyperpod.common.config.metadata import Metadata + + +class FlatHPEndpoint(BaseModel): + model_config = ConfigDict(extra="forbid") + + namespace: Optional[str] = Field( + default=None, description="Kubernetes namespace", min_length=1 + ) + + metadata_name: Optional[str] = Field( + None, + alias="metadata_name", + 
description="Name of the custom endpoint object", + max_length=63, + pattern=r"^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$", + ) + + # endpoint_name + endpoint_name: Optional[str] = Field( + None, + alias="endpoint_name", + description="Name of SageMaker endpoint; empty string means no creation", + max_length=63, + pattern=r"^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$", + ) + + # Environment variables map + env: Optional[Dict[str, str]] = Field( + None, + alias="env", + description="Map of environment variable names to their values", + ) + + instance_type: Optional[str] = Field( + None, + alias="instance_type", + description="EC2 instance type for the inference server. Mutually exclusive with instance_types.", + pattern=r"^ml\..*", + ) + + instance_types: Optional[str] = Field( + None, + alias="instance_types", + description="Comma-separated list of instance types in order of preference", + ) + + # metrics.* + metrics_enabled: Optional[bool] = Field( + None, + alias="metrics_enabled", + description="Enable metrics collection", + ) + + # model_name and version + model_name: str = Field( + ..., + alias="model_name", + description="Name of model to create on SageMaker", + min_length=1, + max_length=63, + pattern=r"^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$", + ) + + model_version: Optional[str] = Field( + None, + alias="model_version", + description="Version of the model for the endpoint", + min_length=5, + max_length=14, + pattern=r"^\d{1,4}\.\d{1,4}\.\d{1,4}$", + ) + + # model_source_config.* + model_source_type: Literal["fsx", "s3", "huggingface", "kubernetesVolume"] = Field( + ..., + alias="model_source_type", + description="Source type: fsx, s3, huggingface, or kubernetesVolume", + ) + model_location: Optional[str] = Field( + None, + alias="model_location", + description="Specific model data location", + ) + prefetch_enabled: Optional[bool] = Field( + False, + alias="prefetch_enabled", + description="Whether to pre-fetch model data", + ) + + # tls_config + tls_certificate_output_s3_uri: 
Optional[str] = Field( + None, + alias="tls_certificate_output_s3_uri", + description="S3 URI for TLS certificate output", + pattern=r"^s3://([^/]+)/?(.*)$", + ) + + # worker.* + image_uri: str = Field( + ..., + alias="image_uri", + description="Inference server image name", + ) + container_port: int = Field( + ..., + alias="container_port", + description="Port on which the model server listens", + ge=1, + le=65535, + ) + model_volume_mount_path: Optional[str] = Field( + "/opt/ml/model", + alias="model_volume_mount_path", + description="Path inside container for model volume", + ) + model_volume_mount_name: str = Field( + ..., + alias="model_volume_mount_name", + description="Name of the model volume mount", + ) + + # FSXStorage + fsx_dns_name: Optional[str] = Field( + None, + alias="fsx_dns_name", + description="FSX File System DNS Name", + ) + fsx_file_system_id: Optional[str] = Field( + None, + alias="fsx_file_system_id", + description="FSX File System ID", + ) + fsx_mount_name: Optional[str] = Field( + None, + alias="fsx_mount_name", + description="FSX File System Mount Name", + ) + + # S3Storage + s3_bucket_name: Optional[str] = Field( + None, + alias="s3_bucket_name", + description="S3 bucket location", + ) + s3_region: Optional[str] = Field( + None, + alias="s3_region", + description="S3 bucket region", + ) + + # Resources + resources_limits: Optional[Dict[str, Union[int, str]]] = Field( + None, + alias="resources_limits", + description="Resource limits for the worker", + ) + resources_requests: Optional[Dict[str, Union[int, str]]] = Field( + None, + alias="resources_requests", + description="Resource requests for the worker", + ) + + # Dimensions + dimensions: Optional[Dict[str, str]] = Field( + None, + alias="dimensions", + description="CloudWatch Metric dimensions as key–value pairs", + ) + + # CloudWatch Trigger + metric_collection_period: Optional[int] = Field( + 300, description="Defines the Period for CloudWatch query" + ) + 
metric_collection_start_time: Optional[int] = Field( + 300, description="Defines the StartTime for CloudWatch query" + ) + metric_name: Optional[str] = Field( + None, description="Metric name to query for CloudWatch trigger" + ) + metric_stat: Optional[str] = Field( + "Average", + description=( + "Statistics metric to be used by Trigger. " + "Defines the Stat for the CloudWatch query. Default is Average." + ), + ) + metric_type: Optional[Literal["Value", "Average"]] = Field( + "Average", + description=( + "The type of metric to be used by HPA. " + "`Average` – Uses average value per pod; " + "`Value` – Uses absolute metric value." + ), + ) + min_value: Optional[float] = Field( + 0, + description=( + "Minimum metric value used in case of empty response " + "from CloudWatch. Default is 0." + ), + ) + cloud_watch_trigger_name: Optional[str] = Field( + None, description="Name for the CloudWatch trigger" + ) + cloud_watch_trigger_namespace: Optional[str] = Field( + None, description="AWS CloudWatch namespace for the metric" + ) + target_value: Optional[float] = Field( + None, description="Target value for the CloudWatch metric" + ) + use_cached_metrics: Optional[bool] = Field( + True, + description=( + "Enable caching of metric values during polling interval. " + "Default is true." + ), + ) + + invocation_endpoint: Optional[str] = Field( + default="invocations", + description=( + "The invocation endpoint of the model server. http://:/ would be pre-populated based on the other fields. " + "Please fill in the path after http://:/ specific to your model server." 
+ ), + ) + + # Intelligent Routing flattened fields + intelligent_routing_enabled: Optional[bool] = Field( + None, + alias="intelligent_routing_enabled", + description="Enable intelligent routing", + ) + routing_strategy: Optional[ + Literal["prefixaware", "kvaware", "session", "roundrobin"] + ] = Field( + None, + alias="routing_strategy", + description="Routing strategy for intelligent routing", + ) + + # KV Cache flattened fields + enable_l1_cache: Optional[bool] = Field( + None, + alias="enable_l1_cache", + description="Enable L1 cache (CPU offloading)", + ) + enable_l2_cache: Optional[bool] = Field( + None, + alias="enable_l2_cache", + description="Enable L2 cache", + ) + l2_cache_backend: Optional[str] = Field( + None, + alias="l2_cache_backend", + description="L2 cache backend type", + ) + l2_cache_local_url: Optional[str] = Field( + None, + alias="l2_cache_local_url", + description="L2 cache URL to local storage", + ) + cache_config_file: Optional[str] = Field( + None, + alias="cache_config_file", + description="KV cache configuration file path", + ) + + # maxDeployTimeInSeconds + max_deploy_time_in_seconds: Optional[int] = Field( + 3600, + alias="max_deploy_time_in_seconds", + description="Maximum deployment time in seconds. 
Defaults to 3600.", + ) + + # customCertificateConfig + custom_certificate_acm_arn: Optional[str] = Field( + None, + alias="custom_certificate_acm_arn", + description="ACM certificate ARN for custom TLS certificate", + pattern=r"^arn:aws:acm:[a-z0-9-]+:[0-9]{12}:certificate/[a-fA-F0-9-]+$", + ) + custom_certificate_domain_name: Optional[str] = Field( + None, + alias="custom_certificate_domain_name", + description="Domain name for the custom TLS certificate", + max_length=253, + pattern=r"^[a-z0-9]([a-z0-9-]{0,61}[a-z0-9])?(\.[a-z0-9]([a-z0-9-]{0,61}[a-z0-9])?)*$", + ) + + # LoadBalancer + load_balancer_health_check_path: Optional[str] = Field( + None, + alias="load_balancer_health_check_path", + description="Health check path for the ALB target group", + ) + load_balancer_routing_algorithm: Optional[str] = Field( + None, + alias="load_balancer_routing_algorithm", + description="Routing algorithm: least_outstanding_requests or round_robin", + ) + + # RequestLimits + max_concurrent_requests: Optional[int] = Field( + None, + alias="max_concurrent_requests", + description="Maximum concurrent requests per pod for nginx sidecar proxy", + ) + max_queue_size: Optional[int] = Field( + None, + alias="max_queue_size", + description="Maximum request queue size when concurrent limit is reached", + ) + + # Kubernetes customizations + kubernetes: Optional[Dict[str, Any]] = Field( + None, + alias="kubernetes", + description="Kubernetes customizations for the inference pod (initContainers, volumes, schedulerName)", + ) + + # Replicas + replicas: Optional[int] = Field( + 1, alias="replicas", description="Number of inference server replicas. Default 1." + ) + initial_replica_count: Optional[int] = Field( + None, alias="initial_replica_count", + description="Number of desired pods. 
Defaults to 1.", + ) + + # Worker args/command/workingDir + worker_args: Optional[str] = Field( + None, alias="worker_args", + description="Comma-separated arguments to the entrypoint", + ) + worker_command: Optional[str] = Field( + None, alias="worker_command", + description="Comma-separated entrypoint command array", + ) + working_dir: Optional[str] = Field( + None, alias="working_dir", + description="Working directory of the container", + ) + + # Overflow status code (requestLimits) + overflow_status_code: Optional[int] = Field( + None, alias="overflow_status_code", + description="HTTP status code when request limits exceeded. Default 429.", + ) + + # Metrics sub-fields + metrics_scrape_interval_seconds: Optional[int] = Field( + None, alias="metrics_scrape_interval_seconds", + description="Scrape interval in seconds for metrics collection", + ) + model_metrics_path: Optional[str] = Field( + None, alias="model_metrics_path", + description="Path where the model exposes metrics", + ) + model_metrics_port: Optional[int] = Field( + None, alias="model_metrics_port", + description="Port where the model exposes metrics", + ) + + # JSON flags for complex objects + node_affinity: Optional[Dict[str, Any]] = Field( + None, alias="node_affinity", + description="Node affinity JSON for advanced scheduling", + ) + tags: Optional[Dict[str, str]] = Field( + None, alias="tags", + description="Tags as key-value pairs to add to the SageMaker Endpoint", + ) + probes: Optional[Dict[str, Any]] = Field( + None, alias="probes", + description="Container probes JSON (livenessProbe, readinessProbe, startupProbe)", + ) + auto_scaling_spec: Optional[Dict[str, Any]] = Field( + None, alias="auto_scaling_spec", + description="Full autoScalingSpec JSON (overrides individual cloudwatch fields if provided)", + ) + + # HuggingFace model source fields + huggingface_model_id: Optional[str] = Field( + None, alias="huggingface_model_id", + description="HuggingFace Hub model identifier in org/model 
format", + ) + huggingface_commit_sha: Optional[str] = Field( + None, alias="huggingface_commit_sha", + description="Git commit SHA for the model revision (40-char hex)", + ) + huggingface_token_secret_name: Optional[str] = Field( + None, alias="huggingface_token_secret_name", + description="Name of the K8s Secret containing the HuggingFace API token", + ) + huggingface_token_secret_key: Optional[str] = Field( + None, alias="huggingface_token_secret_key", + description="Key in the K8s Secret for the HuggingFace API token", + ) + + # DNS config + dns_hosted_zone_id: Optional[str] = Field( + None, alias="dns_hosted_zone_id", + description="Route53 Hosted Zone ID for DNS automation", + pattern=r"^Z[A-Z0-9]+$", + ) + + # Data capture (JSON flag) + data_capture: Optional[Dict[str, Any]] = Field( + None, alias="data_capture", + description="Data capture configuration JSON for SageMaker, LoadBalancer, and Model Pod tiers", + ) + + @model_validator(mode="after") + def validate_model_source_config(self): + """Validate that required fields are provided based on model_source_type""" + if self.model_source_type == "s3": + if not self.s3_bucket_name or not self.s3_region: + raise ValueError( + "s3_bucket_name and s3_region are required when model_source_type is 's3'" + ) + elif self.model_source_type == "fsx": + if not self.fsx_file_system_id: + raise ValueError( + "fsx_file_system_id is required when model_source_type is 'fsx'" + ) + elif self.model_source_type == "huggingface": + if not self.huggingface_model_id: + raise ValueError( + "huggingface_model_id is required when model_source_type is 'huggingface'" + ) + return self + + @model_validator(mode="after") + def validate_name(self): + if not self.metadata_name and not self.endpoint_name: + raise ValueError("Either metadata_name or endpoint_name must be provided") + return self + + @model_validator(mode="after") + def validate_instance_type_fields(self): + has_instance = self.instance_type or self.instance_types + if 
self.instance_type and self.instance_types: + raise ValueError("instance_type and instance_types are mutually exclusive") + if self.node_affinity and has_instance: + raise ValueError("node_affinity cannot be specified with instance_type or instance_types simultaneously") + if not has_instance and not self.node_affinity: + raise ValueError("Either instance_type, instance_types, or node_affinity must be provided") + return self + + @model_validator(mode="after") + def validate_certificate_and_dns(self): + has_acm = self.custom_certificate_acm_arn is not None + has_domain = self.custom_certificate_domain_name is not None + has_dns = self.dns_hosted_zone_id is not None + if has_acm != has_domain: + raise ValueError( + "custom_certificate_acm_arn and custom_certificate_domain_name must both be provided together" + ) + if has_dns and not (has_acm and has_domain): + raise ValueError( + "dns_hosted_zone_id requires both custom_certificate_acm_arn and custom_certificate_domain_name" + ) + return self + + def to_domain(self) -> HPEndpoint: + if self.endpoint_name and not self.metadata_name: + self.metadata_name = self.endpoint_name + + metadata = Metadata(name=self.metadata_name, namespace=self.namespace) + + env_vars = None + if self.env: + env_vars = [ + EnvironmentVariables(name=k, value=v) for k, v in self.env.items() + ] + + dim_vars: list[Dimensions] = [] + if self.dimensions: + for name, value in self.dimensions.items(): + dim_vars.append(Dimensions(name=name, value=value)) + + cloud_watch_trigger = CloudWatchTrigger( + dimensions=dim_vars, + metric_collection_period=self.metric_collection_period, + metric_collection_start_time=self.metric_collection_start_time, + metric_name=self.metric_name, + metric_stat=self.metric_stat, + metric_type=self.metric_type, + min_value=self.min_value, + name=self.cloud_watch_trigger_name, + namespace=self.cloud_watch_trigger_namespace, + target_value=self.target_value, + use_cached_metrics=self.use_cached_metrics, + ) + + 
auto_scaling_spec = AutoScalingSpec(**self.auto_scaling_spec) if self.auto_scaling_spec else AutoScalingSpec(cloud_watch_trigger=cloud_watch_trigger) + + # nested metrics + model_metrics = None + if self.model_metrics_path or self.model_metrics_port: + model_metrics = ModelMetrics( + path=self.model_metrics_path, + port=self.model_metrics_port, + ) + metrics = None + if self.metrics_enabled is not None or self.metrics_scrape_interval_seconds is not None or model_metrics is not None: + metrics = Metrics( + enabled=self.metrics_enabled, + metrics_scrape_interval_seconds=self.metrics_scrape_interval_seconds, + model_metrics=model_metrics, + ) + + # Validate storage choice and build nested storage config + if self.model_source_type == "s3": + s3 = S3Storage( + bucket_name=self.s3_bucket_name, + region=self.s3_region, + ) + fsx = None + hf_model = None + elif self.model_source_type == "fsx": + fsx = FsxStorage( + dns_name=self.fsx_dns_name, + file_system_id=self.fsx_file_system_id, + mount_name=self.fsx_mount_name, + ) + s3 = None + hf_model = None + elif self.model_source_type == "huggingface": + s3 = None + fsx = None + token_ref = None + if self.huggingface_token_secret_name and self.huggingface_token_secret_key: + token_ref = TokenSecretRef( + name=self.huggingface_token_secret_name, + key=self.huggingface_token_secret_key, + ) + hf_model = HuggingFaceModel( + model_id=self.huggingface_model_id, + commit_sha=self.huggingface_commit_sha, + token_secret_ref=token_ref, + ) + elif self.model_source_type == "kubernetesVolume": + s3 = None + fsx = None + hf_model = None + else: + raise ValueError(f"Unsupported model_source_type: {self.model_source_type}") + + source = ModelSourceConfig( + model_location=self.model_location, + model_source_type=self.model_source_type, + prefetch_enabled=self.prefetch_enabled, + s3_storage=s3, + fsx_storage=fsx, + hugging_face_model=hf_model, + ) + + tls = TlsConfig( + tls_certificate_output_s3_uri=self.tls_certificate_output_s3_uri, + 
custom_certificate_config=CustomCertificateConfig( + acm_arn=self.custom_certificate_acm_arn, + domain_name=self.custom_certificate_domain_name, + ) if self.custom_certificate_acm_arn and self.custom_certificate_domain_name else None, + ) + + invocation_port = ModelInvocationPort( + container_port=self.container_port, + ) + volume_mount = ModelVolumeMount( + mount_path=self.model_volume_mount_path, + name=self.model_volume_mount_name, + ) + resources = Resources( + limits=self.resources_limits, + requests=self.resources_requests, + ) + request_limits = None + if self.max_concurrent_requests is not None or self.max_queue_size is not None or self.overflow_status_code is not None: + request_limits = RequestLimits( + max_concurrent_requests=self.max_concurrent_requests, + max_queue_size=self.max_queue_size, + overflow_status_code=self.overflow_status_code, + ) + + # Parse worker args/command from comma-separated strings + worker_args = [a.strip() for a in self.worker_args.split(",")] if self.worker_args else None + worker_command = [c.strip() for c in self.worker_command.split(",")] if self.worker_command else None + + # Build probes from JSON + worker_probes = Probes(**self.probes) if self.probes else None + + worker = Worker( + environment_variables=env_vars, + image=self.image_uri, + model_invocation_port=invocation_port, + model_volume_mount=volume_mount, + resources=resources, + request_limits=request_limits, + args=worker_args, + command=worker_command, + working_dir=self.working_dir, + probes=worker_probes, + ) + # Build intelligent routing spec from flattened fields + intelligent_routing_spec = None + if self.intelligent_routing_enabled is not None: + intelligent_routing_spec = IntelligentRoutingSpec( + enabled=self.intelligent_routing_enabled, + routing_strategy=self.routing_strategy, + ) + + # Build KV cache spec from flattened fields + kv_cache_spec = None + if any([self.enable_l1_cache, self.enable_l2_cache, self.cache_config_file]): + l2_cache_spec = None 
+ if self.l2_cache_backend or self.l2_cache_local_url: + l2_cache_spec = L2CacheSpec( + l2_cache_backend=self.l2_cache_backend, + l2_cache_local_url=self.l2_cache_local_url, + ) + + kv_cache_spec = KvCacheSpec( + enable_l1_cache=self.enable_l1_cache, + enable_l2_cache=self.enable_l2_cache, + l2_cache_spec=l2_cache_spec, + cache_config_file=self.cache_config_file, + ) + + # Build load balancer config + load_balancer = None + if self.load_balancer_health_check_path or self.load_balancer_routing_algorithm: + load_balancer = LoadBalancer( + health_check_path=self.load_balancer_health_check_path, + routing_algorithm=self.load_balancer_routing_algorithm, + ) + + # Parse instance_types from comma-separated string + instance_types_list = None + if self.instance_types: + instance_types_list = [t.strip() for t in self.instance_types.split(",")] + + # Build kubernetes config + kubernetes = None + if self.kubernetes: + kubernetes = Kubernetes(**self.kubernetes) + + # Build tags + tags_list = None + if self.tags: + tags_list = [Tags(name=k, value=v) for k, v in self.tags.items()] + + # Build node affinity from JSON + node_affinity = NodeAffinity(**self.node_affinity) if self.node_affinity else None + + # Build DNS config + dns_config = None + if self.dns_hosted_zone_id: + dns_config = DnsConfig(hosted_zone_id=self.dns_hosted_zone_id) + + # Build data capture from JSON + data_capture = DataCapture(**self.data_capture) if self.data_capture else None + + return HPEndpoint( + metadata=metadata, + endpoint_name=self.endpoint_name, + instance_type=self.instance_type, + instance_types=instance_types_list, + metrics=metrics, + model_name=self.model_name, + model_source_config=source, + model_version=self.model_version, + tls_config=tls, + worker=worker, + invocation_endpoint=self.invocation_endpoint, + auto_scaling_spec=auto_scaling_spec, + intelligent_routing_spec=intelligent_routing_spec, + kv_cache_spec=kv_cache_spec, + max_deploy_time_in_seconds=self.max_deploy_time_in_seconds, + 
load_balancer=load_balancer, + kubernetes=kubernetes, + replicas=self.replicas, + initial_replica_count=self.initial_replica_count, + tags=tags_list, + node_affinity=node_affinity, + dns_config=dns_config, + data_capture=data_capture, + ) diff --git a/hyperpod-custom-inference-template/hyperpod_custom_inference_template/v1_2/schema.json b/hyperpod-custom-inference-template/hyperpod_custom_inference_template/v1_2/schema.json new file mode 100644 index 00000000..237d6363 --- /dev/null +++ b/hyperpod-custom-inference-template/hyperpod_custom_inference_template/v1_2/schema.json @@ -0,0 +1,953 @@ +{ + "additionalProperties": false, + "properties": { + "namespace": { + "anyOf": [ + { + "minLength": 1, + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Kubernetes namespace", + "title": "Namespace" + }, + "metadata_name": { + "anyOf": [ + { + "maxLength": 63, + "pattern": "^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$", + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Name of the custom endpoint object", + "title": "Metadata Name" + }, + "endpoint_name": { + "anyOf": [ + { + "maxLength": 63, + "pattern": "^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$", + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Name of SageMaker endpoint; empty string means no creation", + "title": "Endpoint Name" + }, + "env": { + "anyOf": [ + { + "additionalProperties": { + "type": "string" + }, + "type": "object" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Map of environment variable names to their values", + "title": "Env" + }, + "instance_type": { + "anyOf": [ + { + "pattern": "^ml\\..*", + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "EC2 instance type for the inference server. 
Mutually exclusive with instance_types.", + "title": "Instance Type" + }, + "instance_types": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Comma-separated list of instance types in order of preference", + "title": "Instance Types" + }, + "metrics_enabled": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Enable metrics collection", + "title": "Metrics Enabled" + }, + "model_name": { + "description": "Name of model to create on SageMaker", + "maxLength": 63, + "minLength": 1, + "pattern": "^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$", + "title": "Model Name", + "type": "string" + }, + "model_version": { + "anyOf": [ + { + "maxLength": 14, + "minLength": 5, + "pattern": "^\\d{1,4}\\.\\d{1,4}\\.\\d{1,4}$", + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Version of the model for the endpoint", + "title": "Model Version" + }, + "model_source_type": { + "description": "Source type: fsx, s3, huggingface, or kubernetesVolume", + "enum": [ + "fsx", + "s3", + "huggingface", + "kubernetesVolume" + ], + "title": "Model Source Type", + "type": "string" + }, + "model_location": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Specific model data location", + "title": "Model Location" + }, + "prefetch_enabled": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "default": false, + "description": "Whether to pre-fetch model data", + "title": "Prefetch Enabled" + }, + "tls_certificate_output_s3_uri": { + "anyOf": [ + { + "pattern": "^s3://([^/]+)/?(.*)$", + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "S3 URI for TLS certificate output", + "title": "Tls Certificate Output S3 Uri" + }, + "image_uri": { + "description": "Inference server image name", + "title": "Image Uri", + "type": "string" + }, + 
"container_port": { + "description": "Port on which the model server listens", + "maximum": 65535, + "minimum": 1, + "title": "Container Port", + "type": "integer" + }, + "model_volume_mount_path": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": "/opt/ml/model", + "description": "Path inside container for model volume", + "title": "Model Volume Mount Path" + }, + "model_volume_mount_name": { + "description": "Name of the model volume mount", + "title": "Model Volume Mount Name", + "type": "string" + }, + "fsx_dns_name": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "FSX File System DNS Name", + "title": "Fsx Dns Name" + }, + "fsx_file_system_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "FSX File System ID", + "title": "Fsx File System Id" + }, + "fsx_mount_name": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "FSX File System Mount Name", + "title": "Fsx Mount Name" + }, + "s3_bucket_name": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "S3 bucket location", + "title": "S3 Bucket Name" + }, + "s3_region": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "S3 bucket region", + "title": "S3 Region" + }, + "resources_limits": { + "anyOf": [ + { + "additionalProperties": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "string" + } + ] + }, + "type": "object" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Resource limits for the worker", + "title": "Resources Limits" + }, + "resources_requests": { + "anyOf": [ + { + "additionalProperties": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "string" + } + ] + }, + "type": "object" + }, + { + "type": "null" + } + ], + "default": null, 
+ "description": "Resource requests for the worker", + "title": "Resources Requests" + }, + "dimensions": { + "anyOf": [ + { + "additionalProperties": { + "type": "string" + }, + "type": "object" + }, + { + "type": "null" + } + ], + "default": null, + "description": "CloudWatch Metric dimensions as key\u2013value pairs", + "title": "Dimensions" + }, + "metric_collection_period": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "default": 300, + "description": "Defines the Period for CloudWatch query", + "title": "Metric Collection Period" + }, + "metric_collection_start_time": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "default": 300, + "description": "Defines the StartTime for CloudWatch query", + "title": "Metric Collection Start Time" + }, + "metric_name": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Metric name to query for CloudWatch trigger", + "title": "Metric Name" + }, + "metric_stat": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": "Average", + "description": "Statistics metric to be used by Trigger. Defines the Stat for the CloudWatch query. Default is Average.", + "title": "Metric Stat" + }, + "metric_type": { + "anyOf": [ + { + "enum": [ + "Value", + "Average" + ], + "type": "string" + }, + { + "type": "null" + } + ], + "default": "Average", + "description": "The type of metric to be used by HPA. `Average` \u2013 Uses average value per pod; `Value` \u2013 Uses absolute metric value.", + "title": "Metric Type" + }, + "min_value": { + "anyOf": [ + { + "type": "number" + }, + { + "type": "null" + } + ], + "default": 0, + "description": "Minimum metric value used in case of empty response from CloudWatch. 
Default is 0.", + "title": "Min Value" + }, + "cloud_watch_trigger_name": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Name for the CloudWatch trigger", + "title": "Cloud Watch Trigger Name" + }, + "cloud_watch_trigger_namespace": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "AWS CloudWatch namespace for the metric", + "title": "Cloud Watch Trigger Namespace" + }, + "target_value": { + "anyOf": [ + { + "type": "number" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Target value for the CloudWatch metric", + "title": "Target Value" + }, + "use_cached_metrics": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "default": true, + "description": "Enable caching of metric values during polling interval. Default is true.", + "title": "Use Cached Metrics" + }, + "invocation_endpoint": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": "invocations", + "description": "The invocation endpoint of the model server. http://:/ would be pre-populated based on the other fields. 
Please fill in the path after http://:/ specific to your model server.", + "title": "Invocation Endpoint" + }, + "intelligent_routing_enabled": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Enable intelligent routing", + "title": "Intelligent Routing Enabled" + }, + "routing_strategy": { + "anyOf": [ + { + "enum": [ + "prefixaware", + "kvaware", + "session", + "roundrobin" + ], + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Routing strategy for intelligent routing", + "title": "Routing Strategy" + }, + "enable_l1_cache": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Enable L1 cache (CPU offloading)", + "title": "Enable L1 Cache" + }, + "enable_l2_cache": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Enable L2 cache", + "title": "Enable L2 Cache" + }, + "l2_cache_backend": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "L2 cache backend type", + "title": "L2 Cache Backend" + }, + "l2_cache_local_url": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "L2 cache URL to local storage", + "title": "L2 Cache Local Url" + }, + "cache_config_file": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "KV cache configuration file path", + "title": "Cache Config File" + }, + "max_deploy_time_in_seconds": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "default": 3600, + "description": "Maximum deployment time in seconds. 
Defaults to 3600.", + "title": "Max Deploy Time In Seconds" + }, + "custom_certificate_acm_arn": { + "anyOf": [ + { + "type": "string", + "pattern": "^arn:aws:acm:[a-z0-9-]+:[0-9]{12}:certificate/[a-fA-F0-9-]+$" + }, + { + "type": "null" + } + ], + "default": null, + "description": "ACM certificate ARN for custom TLS certificate", + "title": "Custom Certificate Acm Arn" + }, + "custom_certificate_domain_name": { + "anyOf": [ + { + "type": "string", + "pattern": "^[a-z0-9]([a-z0-9-]{0,61}[a-z0-9])?(\\.[a-z0-9]([a-z0-9-]{0,61}[a-z0-9])?)*$", + "maxLength": 253 + }, + { + "type": "null" + } + ], + "default": null, + "description": "Domain name for the custom TLS certificate", + "title": "Custom Certificate Domain Name" + }, + "load_balancer_health_check_path": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Health check path for the ALB target group", + "title": "Load Balancer Health Check Path" + }, + "load_balancer_routing_algorithm": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Routing algorithm: least_outstanding_requests or round_robin", + "title": "Load Balancer Routing Algorithm" + }, + "max_concurrent_requests": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Maximum concurrent requests per pod for nginx sidecar proxy", + "title": "Max Concurrent Requests" + }, + "max_queue_size": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Maximum request queue size when concurrent limit is reached", + "title": "Max Queue Size" + }, + "kubernetes": { + "anyOf": [ + { + "additionalProperties": true, + "type": "object" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Kubernetes customizations for the inference pod (initContainers, volumes, schedulerName)", + "title": "Kubernetes" + }, + "replicas": { + 
"anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "default": 1, + "description": "Number of inference server replicas. Default 1.", + "title": "Replicas" + }, + "initial_replica_count": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Number of desired pods. Defaults to 1.", + "title": "Initial Replica Count" + }, + "worker_args": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Comma-separated arguments to the entrypoint", + "title": "Worker Args" + }, + "worker_command": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Comma-separated entrypoint command array", + "title": "Worker Command" + }, + "working_dir": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Working directory of the container", + "title": "Working Dir" + }, + "overflow_status_code": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "default": null, + "description": "HTTP status code when request limits exceeded. 
Default 429.", + "title": "Overflow Status Code" + }, + "metrics_scrape_interval_seconds": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Scrape interval in seconds for metrics collection", + "title": "Metrics Scrape Interval Seconds" + }, + "model_metrics_path": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Path where the model exposes metrics", + "title": "Model Metrics Path" + }, + "model_metrics_port": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Port where the model exposes metrics", + "title": "Model Metrics Port" + }, + "node_affinity": { + "anyOf": [ + { + "additionalProperties": true, + "type": "object" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Node affinity JSON for advanced scheduling", + "title": "Node Affinity" + }, + "tags": { + "anyOf": [ + { + "additionalProperties": { + "type": "string" + }, + "type": "object" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Tags as key-value pairs to add to the SageMaker Endpoint", + "title": "Tags" + }, + "probes": { + "anyOf": [ + { + "additionalProperties": true, + "type": "object" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Container probes JSON (livenessProbe, readinessProbe, startupProbe)", + "title": "Probes" + }, + "auto_scaling_spec": { + "anyOf": [ + { + "additionalProperties": true, + "type": "object" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Full autoScalingSpec JSON (overrides individual cloudwatch fields if provided)", + "title": "Auto Scaling Spec" + }, + "huggingface_model_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "HuggingFace Hub model identifier in org/model format", + "title": "Huggingface Model Id" + }, + 
"huggingface_commit_sha": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Git commit SHA for the model revision (40-char hex)", + "title": "Huggingface Commit Sha" + }, + "huggingface_token_secret_name": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Name of the K8s Secret containing the HuggingFace API token", + "title": "Huggingface Token Secret Name" + }, + "huggingface_token_secret_key": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Key in the K8s Secret for the HuggingFace API token", + "title": "Huggingface Token Secret Key" + }, + "dns_hosted_zone_id": { + "anyOf": [ + { + "type": "string", + "pattern": "^Z[A-Z0-9]+$" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Route53 Hosted Zone ID for DNS automation", + "title": "Dns Hosted Zone Id" + }, + "data_capture": { + "anyOf": [ + { + "additionalProperties": true, + "type": "object" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Data capture configuration JSON for SageMaker, LoadBalancer, and Model Pod tiers", + "title": "Data Capture" + } + }, + "required": [ + "model_name", + "model_source_type", + "image_uri", + "container_port", + "model_volume_mount_name" + ], + "title": "FlatHPEndpoint", + "type": "object" +} diff --git a/hyperpod-custom-inference-template/hyperpod_custom_inference_template/v1_2/template.py b/hyperpod-custom-inference-template/hyperpod_custom_inference_template/v1_2/template.py new file mode 100644 index 00000000..497d9c3f --- /dev/null +++ b/hyperpod-custom-inference-template/hyperpod_custom_inference_template/v1_2/template.py @@ -0,0 +1,232 @@ +TEMPLATE_CONTENT = """ +apiVersion: inference.sagemaker.aws.amazon.com/v1 +kind: InferenceEndpointConfig +metadata: + name: {{ metadata_name or endpoint_name }} + namespace: {{ namespace }} +spec: + 
endpointName: {{ endpoint_name }} +{%- if instance_type %} + instanceType: {{ instance_type }} +{%- endif %} +{%- if instance_types %} + instanceTypes: +{%- for it in instance_types.split(",") %} + - {{ it.strip() }} +{%- endfor %} +{%- endif %} + modelName: {{ model_name }} + modelVersion: {{ model_version or "" }} +{%- if replicas is not none %} + replicas: {{ replicas }} +{%- endif %} +{%- if initial_replica_count is not none and initial_replica_count != "" %} + InitialReplicaCount: {{ initial_replica_count }} +{%- endif %} +{%- if max_deploy_time_in_seconds is not none %} + maxDeployTimeInSeconds: {{ max_deploy_time_in_seconds }} +{%- endif %} +{%- if tags %} + tags: +{%- for tag_name, tag_value in tags.items() %} + - name: {{ tag_name }} + value: "{{ tag_value }}" +{%- endfor %} +{%- endif %} + + + metrics: + enabled: {{ metrics_enabled or False }} +{%- if metrics_scrape_interval_seconds %} + metricsScrapeIntervalSeconds: {{ metrics_scrape_interval_seconds }} +{%- endif %} +{%- if model_metrics_path or model_metrics_port %} + modelMetrics: +{%- if model_metrics_path %} + path: "{{ model_metrics_path }}" +{%- endif %} +{%- if model_metrics_port %} + port: {{ model_metrics_port }} +{%- endif %} +{%- endif %} + + modelSourceConfig: + modelSourceType: {{ model_source_type }} + modelLocation: {{ model_location or "" }} + prefetchEnabled: {{ prefetch_enabled or False }} +{%- if model_source_type == "s3" %} + s3Storage: + bucketName: {{ s3_bucket_name }} + region: {{ s3_region }} +{%- elif model_source_type == "fsx" %} + fsxStorage: + dnsName: {{ fsx_dns_name }} + fileSystemId: {{ fsx_file_system_id }} + mountName: {{ fsx_mount_name or "" }} +{%- elif model_source_type == "huggingface" %} + huggingFaceModel: + modelId: {{ huggingface_model_id }} +{%- if huggingface_commit_sha %} + commitSHA: {{ huggingface_commit_sha }} +{%- endif %} +{%- if huggingface_token_secret_name and huggingface_token_secret_key %} + tokenSecretRef: + name: {{ huggingface_token_secret_name }} 
+ key: {{ huggingface_token_secret_key }} +{%- endif %} +{%- endif %} + + + tlsConfig: + tlsCertificateOutputS3Uri: {{ tls_certificate_output_s3_uri or "" }} +{%- if custom_certificate_acm_arn and custom_certificate_domain_name %} + customCertificateConfig: + acmArn: "{{ custom_certificate_acm_arn }}" + domainName: "{{ custom_certificate_domain_name }}" +{%- endif %} + +{%- if node_affinity %} + nodeAffinity: {{ node_affinity | tojson }} +{%- endif %} + +{%- if kubernetes %} + kubernetes: {{ kubernetes | tojson }} +{%- endif %} + + worker: + environmentVariables: + {%- if env %} + {%- for key, val in env.items() %} + - name: {{ key }} + value: "{{ val }}" + {%- endfor %} + {%- else %} + [] + {%- endif %} + image: {{ image_uri }} + modelInvocationPort: + containerPort: {{ container_port }} + modelVolumeMount: + name: {{ model_volume_mount_name }} + mountPath: {{ model_volume_mount_path }} + resources: +{%- if resources_limits %} + limits: +{%- for key, val in resources_limits.items() %} + {{ key }}: {{ val }} +{%- endfor %} +{%- else %} + {} +{%- endif %} +{%- if resources_requests %} + requests: +{%- for key, val in resources_requests.items() %} + {{ key }}: {{ val }} +{%- endfor %} +{%- endif %} +{%- if worker_args %} + args: +{%- for arg in worker_args.split(",") %} + - "{{ arg.strip() }}" +{%- endfor %} +{%- endif %} +{%- if worker_command %} + command: +{%- for cmd in worker_command.split(",") %} + - "{{ cmd.strip() }}" +{%- endfor %} +{%- endif %} +{%- if working_dir %} + workingDir: "{{ working_dir }}" +{%- endif %} +{%- if probes %} + probes: {{ probes | tojson }} +{%- endif %} +{%- if max_concurrent_requests or max_queue_size or overflow_status_code %} + requestLimits: +{%- if max_concurrent_requests %} + maxConcurrentRequests: {{ max_concurrent_requests }} +{%- endif %} +{%- if max_queue_size %} + maxQueueSize: {{ max_queue_size }} +{%- endif %} +{%- if overflow_status_code %} + overflowStatusCode: {{ overflow_status_code }} +{%- endif %} +{%- endif %} + 
+{%- if auto_scaling_spec %} + autoScalingSpec: {{ auto_scaling_spec | tojson }} +{%- else %} + autoScalingSpec: + cloudWatchTrigger: +{%- if dimensions %} + dimensions: +{%- for dim_key, dim_val in dimensions.items() %} + - name: {{ dim_key }} + value: {{ dim_val }} +{%- endfor %} +{%- endif %} + metricCollectionPeriod: {{ metric_collection_period }} + metricCollectionStartTime: {{ metric_collection_start_time }} + metricName: {{ metric_name or "" }} + metricStat: {{ metric_stat }} + metricType: {{ metric_type }} + minValue: {{ min_value }} + name: {{ cloud_watch_trigger_name or "" }} + namespace: {{ cloud_watch_trigger_namespace or "" }} + targetValue: {{ target_value or "" }} + useCachedMetrics: {{ use_cached_metrics or False }} +{%- endif %} + + invocationEndpoint: "{{ invocation_endpoint }}" + +{%- if intelligent_routing_enabled is defined and intelligent_routing_enabled is not none and intelligent_routing_enabled != "" %} + intelligentRoutingSpec: + enabled: {{ intelligent_routing_enabled }} +{%- if routing_strategy %} + routingStrategy: "{{ routing_strategy }}" +{%- endif %} +{%- endif %} + +{%- if (enable_l1_cache is defined and enable_l1_cache is not none and enable_l1_cache != "") or (enable_l2_cache is defined and enable_l2_cache is not none and enable_l2_cache != "") or cache_config_file %} + kvCacheSpec: +{%- if enable_l1_cache is defined and enable_l1_cache is not none and enable_l1_cache != "" %} + enableL1Cache: {{ enable_l1_cache }} +{%- endif %} +{%- if enable_l2_cache is defined and enable_l2_cache is not none and enable_l2_cache != "" %} + enableL2Cache: {{ enable_l2_cache }} +{%- endif %} +{%- if l2_cache_backend or l2_cache_local_url %} + l2CacheSpec: +{%- if l2_cache_backend %} + l2CacheBackend: "{{ l2_cache_backend }}" +{%- endif %} +{%- if l2_cache_local_url %} + l2CacheLocalUrl: "{{ l2_cache_local_url }}" +{%- endif %} +{%- endif %} +{%- if cache_config_file %} + cacheConfigFile: "{{ cache_config_file }}" +{%- endif %} +{%- endif %} + +{%- 
if load_balancer_health_check_path or load_balancer_routing_algorithm %} + loadBalancer: +{%- if load_balancer_health_check_path %} + healthCheckPath: "{{ load_balancer_health_check_path }}" +{%- endif %} +{%- if load_balancer_routing_algorithm %} + routingAlgorithm: "{{ load_balancer_routing_algorithm }}" +{%- endif %} +{%- endif %} + +{%- if data_capture %} + dataCapture: {{ data_capture | tojson }} +{%- endif %} + +{%- if dns_hosted_zone_id %} + dnsConfig: + hostedZoneId: "{{ dns_hosted_zone_id }}" +{%- endif %} +""" \ No newline at end of file diff --git a/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/registry.py b/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/registry.py index 96b80a47..decc1968 100644 --- a/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/registry.py +++ b/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/registry.py @@ -12,15 +12,19 @@ # language governing permissions and limitations under the License. 
from hyperpod_jumpstart_inference_template.v1_0 import model as v1_0 from hyperpod_jumpstart_inference_template.v1_1 import model as v1_1 +from hyperpod_jumpstart_inference_template.v1_2 import model as v1_2 from hyperpod_jumpstart_inference_template.v1_0.template import TEMPLATE_CONTENT as v1_0_template from hyperpod_jumpstart_inference_template.v1_1.template import TEMPLATE_CONTENT as v1_1_template +from hyperpod_jumpstart_inference_template.v1_2.template import TEMPLATE_CONTENT as v1_2_template SCHEMA_REGISTRY = { "1.0": v1_0.FlatHPJumpStartEndpoint, "1.1": v1_1.FlatHPJumpStartEndpoint, + "1.2": v1_2.FlatHPJumpStartEndpoint, } TEMPLATE_REGISTRY = { "1.0": v1_0_template, "1.1": v1_1_template, + "1.2": v1_2_template, } diff --git a/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_2/__init__.py b/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_2/__init__.py new file mode 100644 index 00000000..68054b98 --- /dev/null +++ b/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_2/__init__.py @@ -0,0 +1,12 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
\ No newline at end of file diff --git a/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_2/model.py b/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_2/model.py new file mode 100644 index 00000000..4526f5df --- /dev/null +++ b/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_2/model.py @@ -0,0 +1,371 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from pydantic import BaseModel, Field, model_validator, ConfigDict +from typing import Optional, Literal, Dict, Any, List + +# reuse the nested types +from sagemaker.hyperpod.inference.config.hp_jumpstart_endpoint_config import ( + Model, + SageMakerEndpoint, + Server, + TlsConfig, + CustomCertificateConfig, + Validations, + IntelligentRoutingSpec, + KvCacheSpec, + L2CacheSpec, + LoadBalancer, + AutoScalingSpec, + Metrics, + ModelMetrics, + EnvironmentVariables, + DataCapture, + DnsConfig, +) +from sagemaker.hyperpod.inference.hp_jumpstart_endpoint import HPJumpStartEndpoint +from sagemaker.hyperpod.common.config.metadata import Metadata + + +class FlatHPJumpStartEndpoint(BaseModel): + model_config = ConfigDict(extra="forbid") + + namespace: Optional[str] = Field( + default=None, description="Kubernetes namespace", min_length=1 + ) + + accept_eula: bool = Field( + False, + alias="accept_eula", + description="Whether model terms of use have been accepted", + ) + + metadata_name: Optional[str] = Field( + None, + 
alias="metadata_name", + description="Name of the jumpstart endpoint object", + max_length=63, + pattern=r"^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$", + ) + + model_id: str = Field( + ..., + alias="model_id", + description="Unique identifier of the model within the hub", + min_length=1, + max_length=63, + pattern=r"^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$", + ) + + model_version: Optional[str] = Field( + None, + alias="model_version", + description="Semantic version of the model to deploy (e.g. 1.0.0)", + min_length=5, + max_length=14, + pattern=r"^\d{1,4}\.\d{1,4}\.\d{1,4}$", + ) + + instance_type: str = Field( + ..., + alias="instance_type", + description="EC2 instance type for the inference server", + pattern=r"^ml\..*", + ) + + accelerator_partition_type: Optional[str] = Field( + None, + alias="accelerator_partition_type", + description="MIG profile to use for GPU partitioning", + pattern=r"^mig-.*$", + ) + + accelerator_partition_validation: Optional[bool] = Field( + True, + alias="accelerator_partition_validation", + description="Enable MIG validation for GPU partitioning. Default is true." 
+ ) + + endpoint_name: Optional[str] = Field( + None, + alias="endpoint_name", + description="Name of SageMaker endpoint; empty string means no creation", + max_length=63, + pattern=r"^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$", + ) + tls_certificate_output_s3_uri: Optional[str] = Field( + None, + alias="tls_certificate_output_s3_uri", + description="S3 URI to write the TLS certificate", + pattern=r"^s3://([^/]+)/?(.*)$", + ) + + # customCertificateConfig + custom_certificate_acm_arn: Optional[str] = Field( + None, + alias="custom_certificate_acm_arn", + description="ACM certificate ARN for custom TLS certificate", + pattern=r"^arn:aws:acm:[a-z0-9-]+:[0-9]{12}:certificate/[a-fA-F0-9-]+$", + ) + custom_certificate_domain_name: Optional[str] = Field( + None, + alias="custom_certificate_domain_name", + description="Domain name for the custom TLS certificate", + max_length=253, + pattern=r"^[a-z0-9]([a-z0-9-]{0,61}[a-z0-9])?(\.[a-z0-9]([a-z0-9-]{0,61}[a-z0-9])?)*$", + ) + + # Intelligent Routing + intelligent_routing_enabled: Optional[bool] = Field( + None, + alias="intelligent_routing_enabled", + description="Enable intelligent routing", + ) + routing_strategy: Optional[ + Literal["prefixaware", "kvaware", "session", "roundrobin"] + ] = Field( + None, + alias="routing_strategy", + description="Routing strategy for intelligent routing", + ) + + # KV Cache + enable_l1_cache: Optional[bool] = Field( + None, alias="enable_l1_cache", description="Enable L1 cache (CPU offloading)" + ) + enable_l2_cache: Optional[bool] = Field( + None, alias="enable_l2_cache", description="Enable L2 cache" + ) + l2_cache_backend: Optional[str] = Field( + None, alias="l2_cache_backend", description="L2 cache backend type" + ) + l2_cache_local_url: Optional[str] = Field( + None, alias="l2_cache_local_url", description="L2 cache URL to local storage" + ) + cache_config_file: Optional[str] = Field( + None, alias="cache_config_file", description="KV cache configuration file path" + ) + + # LoadBalancer + 
load_balancer_health_check_path: Optional[str] = Field( + None, + alias="load_balancer_health_check_path", + description="Health check path for the ALB target group", + ) + load_balancer_routing_algorithm: Optional[str] = Field( + None, + alias="load_balancer_routing_algorithm", + description="Routing algorithm: least_outstanding_requests or round_robin", + ) + + # Replicas + replicas: Optional[int] = Field( + 1, alias="replicas", description="Number of inference server replicas. Default 1." + ) + + # Max deploy time + max_deploy_time_in_seconds: Optional[int] = Field( + 3600, alias="max_deploy_time_in_seconds", + description="Maximum deployment time in seconds. Defaults to 3600.", + ) + + # Environment variables + env: Optional[Dict[str, str]] = Field( + None, alias="env", + description="Map of environment variable names to their values", + ) + + # Metrics + metrics_enabled: Optional[bool] = Field( + None, alias="metrics_enabled", description="Enable metrics collection" + ) + metrics_scrape_interval_seconds: Optional[int] = Field( + None, alias="metrics_scrape_interval_seconds", + description="Scrape interval in seconds for metrics collection", + ) + model_metrics_path: Optional[str] = Field( + None, alias="model_metrics_path", description="Path where the model exposes metrics" + ) + model_metrics_port: Optional[int] = Field( + None, alias="model_metrics_port", description="Port where the model exposes metrics" + ) + + # Model sub-fields + additional_configs: Optional[Dict[str, str]] = Field( + None, alias="additional_configs", description="Additional model configs as key-value pairs" + ) + gated_model_download_role: Optional[str] = Field( + None, alias="gated_model_download_role", + description="IAM role ARN for downloading gated models", + ) + model_hub_name: Optional[str] = Field( + None, alias="model_hub_name", description="Name of the model hub" + ) + + # Server execution role + execution_role: Optional[str] = Field( + None, alias="execution_role", + 
description="IAM role ARN for deploying and managing the inference server", + ) + + # Full autoScalingSpec JSON override + auto_scaling_spec: Optional[Dict[str, Any]] = Field( + None, alias="auto_scaling_spec", + description="Full autoScalingSpec JSON for autoscaling configuration", + ) + + # DNS config + dns_hosted_zone_id: Optional[str] = Field( + None, alias="dns_hosted_zone_id", + description="Route53 Hosted Zone ID for DNS automation", + pattern=r"^Z[A-Z0-9]+$", + ) + + # Data capture (JSON flag) + data_capture: Optional[Dict[str, Any]] = Field( + None, alias="data_capture", + description="Data capture configuration JSON for SageMaker, LoadBalancer, and Model Pod tiers", + ) + + @model_validator(mode="after") + def validate_name(self): + if not self.metadata_name and not self.endpoint_name: + raise ValueError("Either metadata_name or endpoint_name must be provided") + return self + + @model_validator(mode="after") + def validate_certificate_and_dns(self): + has_acm = self.custom_certificate_acm_arn is not None + has_domain = self.custom_certificate_domain_name is not None + has_dns = self.dns_hosted_zone_id is not None + if has_acm != has_domain: + raise ValueError( + "custom_certificate_acm_arn and custom_certificate_domain_name must both be provided together" + ) + if has_dns and not (has_acm and has_domain): + raise ValueError( + "dns_hosted_zone_id requires both custom_certificate_acm_arn and custom_certificate_domain_name" + ) + return self + + def to_domain(self) -> HPJumpStartEndpoint: + if self.endpoint_name and not self.metadata_name: + self.metadata_name = self.endpoint_name + + metadata = Metadata(name=self.metadata_name, namespace=self.namespace) + + model = Model( + accept_eula=self.accept_eula, + model_id=self.model_id, + model_version=self.model_version, + additional_configs=[{"name": k, "value": v} for k, v in self.additional_configs.items()] if self.additional_configs else None, + gated_model_download_role=self.gated_model_download_role, + 
model_hub_name=self.model_hub_name, + ) + validations = Validations( + accelerator_partition_validation=self.accelerator_partition_validation, + ) + server = Server( + instance_type=self.instance_type, + accelerator_partition_type=self.accelerator_partition_type, + validations=validations, + execution_role=self.execution_role, + ) + sage_ep = SageMakerEndpoint(name=self.endpoint_name) + tls = TlsConfig( + tls_certificate_output_s3_uri=self.tls_certificate_output_s3_uri, + custom_certificate_config=CustomCertificateConfig( + acm_arn=self.custom_certificate_acm_arn, + domain_name=self.custom_certificate_domain_name, + ) if self.custom_certificate_acm_arn and self.custom_certificate_domain_name else None, + ) + + # Build intelligent routing spec + intelligent_routing_spec = None + if self.intelligent_routing_enabled is not None: + intelligent_routing_spec = IntelligentRoutingSpec( + enabled=self.intelligent_routing_enabled, + routing_strategy=self.routing_strategy, + ) + + # Build KV cache spec + kv_cache_spec = None + if any([self.enable_l1_cache, self.enable_l2_cache, self.cache_config_file]): + l2_cache_spec = None + if self.l2_cache_backend or self.l2_cache_local_url: + l2_cache_spec = L2CacheSpec( + l2_cache_backend=self.l2_cache_backend, + l2_cache_local_url=self.l2_cache_local_url, + ) + kv_cache_spec = KvCacheSpec( + enable_l1_cache=self.enable_l1_cache, + enable_l2_cache=self.enable_l2_cache, + l2_cache_spec=l2_cache_spec, + cache_config_file=self.cache_config_file, + ) + + # Build load balancer config + load_balancer = None + if self.load_balancer_health_check_path or self.load_balancer_routing_algorithm: + load_balancer = LoadBalancer( + health_check_path=self.load_balancer_health_check_path, + routing_algorithm=self.load_balancer_routing_algorithm, + ) + + # Build env vars + env_vars = None + if self.env: + env_vars = [EnvironmentVariables(name=k, value=v) for k, v in self.env.items()] + + # Build metrics + model_metrics = None + if self.model_metrics_path 
or self.model_metrics_port: + model_metrics = ModelMetrics( + path=self.model_metrics_path, + port=self.model_metrics_port, + ) + metrics = None + if self.metrics_enabled is not None or self.metrics_scrape_interval_seconds is not None or model_metrics is not None: + metrics = Metrics( + enabled=self.metrics_enabled, + metrics_scrape_interval_seconds=self.metrics_scrape_interval_seconds, + model_metrics=model_metrics, + ) + + # Build autoScalingSpec from JSON + auto_scaling_spec = AutoScalingSpec(**self.auto_scaling_spec) if self.auto_scaling_spec else None + + # Build DNS config + dns_config = None + if self.dns_hosted_zone_id: + dns_config = DnsConfig(hosted_zone_id=self.dns_hosted_zone_id) + + # Build data capture from JSON + data_capture = DataCapture(**self.data_capture) if self.data_capture else None + + return HPJumpStartEndpoint( + metadata=metadata, + model=model, + server=server, + sage_maker_endpoint=sage_ep, + tls_config=tls, + intelligent_routing_spec=intelligent_routing_spec, + kv_cache_spec=kv_cache_spec, + load_balancer=load_balancer, + replicas=self.replicas, + max_deploy_time_in_seconds=self.max_deploy_time_in_seconds, + environment_variables=env_vars, + metrics=metrics, + auto_scaling_spec=auto_scaling_spec, + dns_config=dns_config, + data_capture=data_capture, + ) diff --git a/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_2/schema.json b/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_2/schema.json new file mode 100644 index 00000000..d37bf8c8 --- /dev/null +++ b/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_2/schema.json @@ -0,0 +1,471 @@ +{ + "additionalProperties": false, + "properties": { + "namespace": { + "anyOf": [ + { + "minLength": 1, + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Kubernetes namespace", + "title": "Namespace" + }, + "accept_eula": { + "default": false, + "description": "Whether 
model terms of use have been accepted", + "title": "Accept Eula", + "type": "boolean" + }, + "metadata_name": { + "anyOf": [ + { + "maxLength": 63, + "pattern": "^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$", + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Name of the jumpstart endpoint object", + "title": "Metadata Name" + }, + "model_id": { + "description": "Unique identifier of the model within the hub", + "maxLength": 63, + "minLength": 1, + "pattern": "^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$", + "title": "Model Id", + "type": "string" + }, + "model_version": { + "anyOf": [ + { + "maxLength": 14, + "minLength": 5, + "pattern": "^\\d{1,4}\\.\\d{1,4}\\.\\d{1,4}$", + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Semantic version of the model to deploy (e.g. 1.0.0)", + "title": "Model Version" + }, + "instance_type": { + "description": "EC2 instance type for the inference server", + "pattern": "^ml\\..*", + "title": "Instance Type", + "type": "string" + }, + "accelerator_partition_type": { + "anyOf": [ + { + "pattern": "^mig-.*$", + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "MIG profile to use for GPU partitioning", + "title": "Accelerator Partition Type" + }, + "accelerator_partition_validation": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "default": true, + "description": "Enable MIG validation for GPU partitioning. 
Default is true.", + "title": "Accelerator Partition Validation" + }, + "endpoint_name": { + "anyOf": [ + { + "maxLength": 63, + "pattern": "^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$", + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Name of SageMaker endpoint; empty string means no creation", + "title": "Endpoint Name" + }, + "tls_certificate_output_s3_uri": { + "anyOf": [ + { + "pattern": "^s3://([^/]+)/?(.*)$", + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "S3 URI to write the TLS certificate", + "title": "Tls Certificate Output S3 Uri" + }, + "custom_certificate_acm_arn": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "ACM certificate ARN for custom TLS certificate", + "title": "Custom Certificate Acm Arn" + }, + "custom_certificate_domain_name": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Domain name for the custom TLS certificate", + "title": "Custom Certificate Domain Name" + }, + "intelligent_routing_enabled": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Enable intelligent routing", + "title": "Intelligent Routing Enabled" + }, + "routing_strategy": { + "anyOf": [ + { + "enum": [ + "prefixaware", + "kvaware", + "session", + "roundrobin" + ], + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Routing strategy for intelligent routing", + "title": "Routing Strategy" + }, + "enable_l1_cache": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Enable L1 cache (CPU offloading)", + "title": "Enable L1 Cache" + }, + "enable_l2_cache": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Enable L2 cache", + "title": "Enable L2 Cache" + }, + 
"l2_cache_backend": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "L2 cache backend type", + "title": "L2 Cache Backend" + }, + "l2_cache_local_url": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "L2 cache URL to local storage", + "title": "L2 Cache Local Url" + }, + "cache_config_file": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "KV cache configuration file path", + "title": "Cache Config File" + }, + "load_balancer_health_check_path": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Health check path for the ALB target group", + "title": "Load Balancer Health Check Path" + }, + "load_balancer_routing_algorithm": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Routing algorithm: least_outstanding_requests or round_robin", + "title": "Load Balancer Routing Algorithm" + }, + "replicas": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "default": 1, + "description": "Number of inference server replicas. Default 1.", + "title": "Replicas" + }, + "max_deploy_time_in_seconds": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "default": 3600, + "description": "Maximum deployment time in seconds. 
Defaults to 3600.", + "title": "Max Deploy Time In Seconds" + }, + "env": { + "anyOf": [ + { + "additionalProperties": { + "type": "string" + }, + "type": "object" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Map of environment variable names to their values", + "title": "Env" + }, + "metrics_enabled": { + "anyOf": [ + { + "type": "boolean" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Enable metrics collection", + "title": "Metrics Enabled" + }, + "metrics_scrape_interval_seconds": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Scrape interval in seconds for metrics collection", + "title": "Metrics Scrape Interval Seconds" + }, + "model_metrics_path": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Path where the model exposes metrics", + "title": "Model Metrics Path" + }, + "model_metrics_port": { + "anyOf": [ + { + "type": "integer" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Port where the model exposes metrics", + "title": "Model Metrics Port" + }, + "additional_configs": { + "anyOf": [ + { + "additionalProperties": { + "type": "string" + }, + "type": "object" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Additional model configs as key-value pairs", + "title": "Additional Configs" + }, + "gated_model_download_role": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "IAM role ARN for downloading gated models", + "title": "Gated Model Download Role" + }, + "model_hub_name": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Name of the model hub", + "title": "Model Hub Name" + }, + "execution_role": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "IAM role 
ARN for deploying and managing the inference server", + "title": "Execution Role" + }, + "auto_scaling_spec": { + "anyOf": [ + { + "additionalProperties": true, + "type": "object" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Full autoScalingSpec JSON for autoscaling configuration", + "title": "Auto Scaling Spec" + }, + "dns_hosted_zone_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Route53 Hosted Zone ID for DNS automation", + "title": "Dns Hosted Zone Id" + }, + "data_capture": { + "anyOf": [ + { + "additionalProperties": true, + "type": "object" + }, + { + "type": "null" + } + ], + "default": null, + "description": "Data capture configuration JSON for SageMaker, LoadBalancer, and Model Pod tiers", + "title": "Data Capture" + } + }, + "required": [ + "model_id", + "instance_type" + ], + "title": "FlatHPJumpStartEndpoint", + "type": "object" +} diff --git a/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_2/template.py b/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_2/template.py new file mode 100644 index 00000000..2a176d11 --- /dev/null +++ b/hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_2/template.py @@ -0,0 +1,30 @@ +TEMPLATE_CONTENT = """ +apiVersion: inference.sagemaker.aws.amazon.com/v1 +kind: JumpStartModel +metadata: + name: {{ metadata_name or endpoint_name }} + namespace: {{ namespace or "default" }} +spec: + model: + acceptEula: {{ accept_eula or false }} + modelHubName: "SageMakerPublicHub" + modelId: {{ model_id }} + modelVersion: {{ model_version or "" }} + sageMakerEndpoint: + name: {{ endpoint_name or "" }} + server: + instanceType: {{ instance_type }} + {% if accelerator_partition_type is not none %}acceleratorPartitionType: "{{ accelerator_partition_type }}"{% endif %} + {% if accelerator_partition_validation is not none %}validations: + {% if 
accelerator_partition_validation is not none %} acceleratorPartitionValidation: {{ accelerator_partition_validation }}{% endif %} + {% endif %} + tlsConfig: + tlsCertificateOutputS3Uri: {{ tls_certificate_output_s3_uri or "" }} +{% if custom_certificate_acm_arn and custom_certificate_domain_name %} customCertificateConfig: + acmArn: "{{ custom_certificate_acm_arn }}" + domainName: "{{ custom_certificate_domain_name }}" +{% endif %} +{% if dns_hosted_zone_id is not none %} dnsConfig: + hostedZoneId: "{{ dns_hosted_zone_id }}" +{% endif %} +""" \ No newline at end of file diff --git a/src/sagemaker/hyperpod/cli/commands/inference.py b/src/sagemaker/hyperpod/cli/commands/inference.py index b4b1474f..1a38dafe 100644 --- a/src/sagemaker/hyperpod/cli/commands/inference.py +++ b/src/sagemaker/hyperpod/cli/commands/inference.py @@ -20,7 +20,7 @@ # CREATE @click.command("hyp-jumpstart-endpoint") -@click.option("--version", default="1.1", help="Schema version to use") +@click.option("--version", default="1.2", help="Schema version to use") @click.option("--debug", is_flag=True, help="Enable debug mode") @generate_click_command( schema_pkg="hyperpod_jumpstart_inference_template", @@ -35,7 +35,7 @@ def js_create(version, debug, js_endpoint): @click.command("hyp-custom-endpoint") -@click.option("--version", default="1.1", help="Schema version to use") +@click.option("--version", default="1.2", help="Schema version to use") @click.option("--debug", is_flag=True, help="Enable debug mode") @generate_click_command( schema_pkg="hyperpod_custom_inference_template", diff --git a/src/sagemaker/hyperpod/cli/inference_utils.py b/src/sagemaker/hyperpod/cli/inference_utils.py index 29f3c6b3..eb2a522d 100644 --- a/src/sagemaker/hyperpod/cli/inference_utils.py +++ b/src/sagemaker/hyperpod/cli/inference_utils.py @@ -49,6 +49,13 @@ def wrapped_func(*args, **kwargs): "dimensions": ("JSON object of dimensions, e.g. 
" '\'{"VAR1":"foo","VAR2":"bar"}\''), "resources_limits": ('JSON object of resource limits, e.g. \'{"cpu":"2","memory":"4Gi"}\''), "resources_requests": ('JSON object of resource requests, e.g. \'{"cpu":"1","memory":"2Gi"}\''), + "kubernetes": ('JSON object for kubernetes customizations, e.g. \'{"initContainers":[...],"volumes":[...],"schedulerName":"..."}\''), + "node_affinity": ('JSON object for node affinity scheduling'), + "tags": ('JSON object of tags, e.g. \'{"team":"ml","env":"prod"}\''), + "probes": ('JSON object for container probes, e.g. \'{"livenessProbe":{...},"readinessProbe":{...}}\''), + "auto_scaling_spec": ('JSON object for full autoscaling config'), + "additional_configs": ('JSON object of additional model configs, e.g. \'{"key1":"val1"}\''), + "data_capture": ('JSON object for data capture config, e.g. \'{"sagemakerEndpoint":{"enabled":true}}\''), } for flag_name, help_text in json_flags.items(): @@ -72,6 +79,13 @@ def wrapped_func(*args, **kwargs): "dimensions", "resources_limits", "resources_requests", + "kubernetes", + "node_affinity", + "tags", + "probes", + "auto_scaling_spec", + "additional_configs", + "data_capture", ): continue diff --git a/src/sagemaker/hyperpod/inference/config/hp_endpoint_config.py b/src/sagemaker/hyperpod/inference/config/hp_endpoint_config.py index 33471286..6e4c0c8f 100644 --- a/src/sagemaker/hyperpod/inference/config/hp_endpoint_config.py +++ b/src/sagemaker/hyperpod/inference/config/hp_endpoint_config.py @@ -1,18 +1,19 @@ from pydantic import BaseModel, ConfigDict, Field -from typing import Optional, List, Dict, Union, Literal +from typing import Any, Optional, List, Dict, Union, Literal class Dimensions(BaseModel): - model_config = ConfigDict(extra="forbid") + model_config = ConfigDict(extra="forbid", populate_by_name=True) name: str = Field(description="CloudWatch Metric dimension name") value: str = Field(description="CloudWatch Metric dimension value") class CloudWatchTrigger(BaseModel): + model_config = 
ConfigDict(extra="forbid", populate_by_name=True) + """CloudWatch metric trigger to use for autoscaling""" - model_config = ConfigDict(extra="forbid") activationTargetValue: Optional[float] = Field( default=0, @@ -71,7 +72,7 @@ class CloudWatchTrigger(BaseModel): class CloudWatchTriggerList(BaseModel): - model_config = ConfigDict(extra="forbid") + model_config = ConfigDict(extra="forbid", populate_by_name=True) activationTargetValue: Optional[float] = Field( default=0, @@ -130,9 +131,10 @@ class CloudWatchTriggerList(BaseModel): class PrometheusTrigger(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + """Prometheus metric trigger to use for autoscaling""" - model_config = ConfigDict(extra="forbid") activationTargetValue: Optional[float] = Field( default=0, @@ -176,7 +178,7 @@ class PrometheusTrigger(BaseModel): class PrometheusTriggerList(BaseModel): - model_config = ConfigDict(extra="forbid") + model_config = ConfigDict(extra="forbid", populate_by_name=True) activationTargetValue: Optional[float] = Field( default=0, @@ -220,7 +222,7 @@ class PrometheusTriggerList(BaseModel): class AutoScalingSpec(BaseModel): - model_config = ConfigDict(extra="forbid") + model_config = ConfigDict(extra="forbid", populate_by_name=True) cloudWatchTrigger: Optional[CloudWatchTrigger] = Field( default=None, @@ -279,10 +281,127 @@ class AutoScalingSpec(BaseModel): ) +class Kubernetes(BaseModel): + """User-provided customizations for the inference pod.""" + + model_config = ConfigDict(extra="forbid", populate_by_name=True) + + initContainers: Optional[List[Dict[str, Any]]] = Field( + default=None, alias="init_containers", + description="Init containers to run before the inference server starts.", + ) + schedulerName: Optional[str] = Field( + default=None, alias="scheduler_name", + description="Name of the scheduler to use for pod scheduling.", + ) + serviceAccountName: Optional[str] = Field( + default=None, + alias="service_account_name", + 
description="Name of the Kubernetes ServiceAccount to use for the inference pod. If not specified, the namespace's default service account will be used. This is useful for providing AWS credentials via IRSA to init containers or the worker.", + ) + volumes: Optional[List[Dict[str, Any]]] = Field( + default=None, + description="Additional volumes to add to the pod spec.", + ) + + +class NodeSelectorRequirement(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + + key: str + operator: str + values: Optional[List[str]] = None + + +class NodeSelectorTerm(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + + matchExpressions: Optional[List[NodeSelectorRequirement]] = Field( + default=None, alias="match_expressions" + ) + matchFields: Optional[List[NodeSelectorRequirement]] = Field( + default=None, alias="match_fields" + ) + + +class PreferredSchedulingTerm(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + + preference: NodeSelectorTerm + weight: int + + +class NodeSelector(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + + nodeSelectorTerms: List[NodeSelectorTerm] = Field(alias="node_selector_terms") + + +class NodeAffinity(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + + preferredDuringSchedulingIgnoredDuringExecution: Optional[ + List[PreferredSchedulingTerm] + ] = Field(default=None, alias="preferred_during_scheduling_ignored_during_execution") + requiredDuringSchedulingIgnoredDuringExecution: Optional[NodeSelector] = Field( + default=None, alias="required_during_scheduling_ignored_during_execution" + ) + + +class CustomCertificateConfig(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + + acmArn: str = Field(alias="acm_arn", description="ACM certificate ARN") + domainName: str = Field( + alias="domain_name", + description="Domain name to use from the certificate.", + ) + + +class 
Probe(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + + exec: Optional[Dict[str, Any]] = None + failureThreshold: Optional[int] = Field(default=None, alias="failure_threshold") + grpc: Optional[Dict[str, Any]] = None + httpGet: Optional[Dict[str, Any]] = Field(default=None, alias="http_get") + initialDelaySeconds: Optional[int] = Field( + default=None, alias="initial_delay_seconds" + ) + periodSeconds: Optional[int] = Field(default=None, alias="period_seconds") + successThreshold: Optional[int] = Field(default=None, alias="success_threshold") + tcpSocket: Optional[Dict[str, Any]] = Field(default=None, alias="tcp_socket") + terminationGracePeriodSeconds: Optional[int] = Field( + default=None, alias="termination_grace_period_seconds" + ) + timeoutSeconds: Optional[int] = Field(default=None, alias="timeout_seconds") + + +class Probes(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + + livenessProbe: Optional[Probe] = Field(default=None, alias="liveness_probe") + readinessProbe: Optional[Probe] = Field(default=None, alias="readiness_probe") + startupProbe: Optional[Probe] = Field(default=None, alias="startup_probe") + + +class RequestLimits(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + + maxConcurrentRequests: Optional[int] = Field( + default=None, alias="max_concurrent_requests" + ) + maxQueueSize: Optional[int] = Field(default=None, alias="max_queue_size") + overflowStatusCode: Optional[int] = Field( + default=429, alias="overflow_status_code" + ) + + class IntelligentRoutingSpec(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + """Configuration for intelligent routing This feature is currently not supported for existing deployments. 
Adding this configuration to an existing deployment will be rejected.""" - model_config = ConfigDict(extra="forbid") autoScalingSpec: Optional[AutoScalingSpec] = Field( default=None, alias="auto_scaling_spec" @@ -296,9 +415,10 @@ class IntelligentRoutingSpec(BaseModel): class L2CacheSpec(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + """Configuration for providing L2 Cache offloading""" - model_config = ConfigDict(extra="forbid") l2CacheBackend: Optional[str] = Field( default=None, @@ -313,9 +433,10 @@ class L2CacheSpec(BaseModel): class KvCacheSpec(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + """Configuration for KV Cache specification By default L1CacheOffloading will be enabled""" - model_config = ConfigDict(extra="forbid") cacheConfigFile: Optional[str] = Field( default=None, @@ -334,9 +455,10 @@ class KvCacheSpec(BaseModel): class LoadBalancer(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + """Configuration for Application Load Balancer""" - model_config = ConfigDict(extra="forbid") healthCheckPath: Optional[str] = Field( default="/ping", @@ -355,7 +477,7 @@ class LoadBalancer(BaseModel): class ModelMetrics(BaseModel): """Configuration for model container metrics scraping""" - model_config = ConfigDict(extra="forbid") + model_config = ConfigDict(extra="forbid", populate_by_name=True) path: Optional[str] = Field( default="/metrics", description="Path where the model exposes metrics" @@ -367,9 +489,10 @@ class ModelMetrics(BaseModel): class Metrics(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + """Configuration for metrics collection and exposure""" - model_config = ConfigDict(extra="forbid") enabled: Optional[bool] = Field( default=True, description="Enable metrics collection for this model deployment" @@ -387,7 +510,7 @@ class Metrics(BaseModel): class FsxStorage(BaseModel): - model_config = ConfigDict(extra="forbid") + 
model_config = ConfigDict(extra="forbid", populate_by_name=True) dnsName: Optional[str] = Field( default=None, alias="dns_name", description="FSX File System DNS Name" @@ -399,22 +522,69 @@ class FsxStorage(BaseModel): class S3Storage(BaseModel): - model_config = ConfigDict(extra="forbid") + model_config = ConfigDict(extra="forbid", populate_by_name=True) bucketName: str = Field(alias="bucket_name", description="S3 bucket location") region: str = Field(description="S3 bucket region") +class TokenSecretRef(BaseModel): + """Reference to a Kubernetes Secret containing the HuggingFace API token.""" + + model_config = ConfigDict(extra="forbid", populate_by_name=True) + + key: str = Field( + description="The key of the secret to select from. Must be a valid secret key." + ) + name: Optional[str] = Field( + default="", + description="Name of the referent.", + ) + optional: Optional[bool] = Field( + default=None, + description="Specify whether the Secret or its key must be defined", + ) + + +class HuggingFaceModel(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + + """HuggingFace model configuration. Required when modelSourceType is huggingface.""" + + + commitSHA: Optional[str] = Field( + default=None, + alias="commit_sha", + description="Git commit SHA for the model revision. Must be a full 40-character lowercase hex SHA. If not provided, the operator defaults to main branch.", + ) + modelId: str = Field( + alias="model_id", + description='HuggingFace Hub model identifier in org/model format (e.g. 
"meta-llama/Llama-3.1-8B-Instruct").', + ) + tokenSecretRef: Optional[TokenSecretRef] = Field( + default=None, + alias="token_secret_ref", + description="Reference to a Kubernetes Secret containing the HuggingFace API token.", + ) + + class ModelSourceConfig(BaseModel): - model_config = ConfigDict(extra="forbid") + model_config = ConfigDict(extra="forbid", populate_by_name=True) fsxStorage: Optional[FsxStorage] = Field(default=None, alias="fsx_storage") + huggingFaceModel: Optional[HuggingFaceModel] = Field( + default=None, + alias="hugging_face_model", + description='HuggingFace model configuration. Required when modelSourceType is "huggingface".', + ) modelLocation: Optional[str] = Field( default=None, alias="model_location", - description="Sepcific location where the model data exists", + description="Specific location where the model data exists", + ) + modelSourceType: Literal["fsx", "s3", "huggingface", "kubernetesVolume"] = Field( + alias="model_source_type" ) - modelSourceType: Literal["fsx", "s3"] = Field(alias="model_source_type") prefetchEnabled: Optional[bool] = Field( default=False, alias="prefetch_enabled", @@ -424,17 +594,22 @@ class ModelSourceConfig(BaseModel): class Tags(BaseModel): - model_config = ConfigDict(extra="forbid") + model_config = ConfigDict(extra="forbid", populate_by_name=True) name: str value: str class TlsConfig(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + """Configurations for TLS""" - model_config = ConfigDict(extra="forbid") + customCertificateConfig: Optional[CustomCertificateConfig] = Field( + default=None, alias="custom_certificate_config", + description="Customer-provided ACM certificate configuration", + ) tlsCertificateOutputS3Uri: Optional[str] = Field( default=None, alias="tls_certificate_output_s3_uri" ) @@ -443,7 +618,7 @@ class TlsConfig(BaseModel): class ConfigMapKeyRef(BaseModel): """Selects a key of a ConfigMap.""" - model_config = ConfigDict(extra="forbid") + model_config = 
ConfigDict(extra="forbid", populate_by_name=True) key: str = Field(description="The key to select.") name: Optional[str] = Field( @@ -457,9 +632,10 @@ class ConfigMapKeyRef(BaseModel): class FieldRef(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + """Selects a field of the pod: supports metadata.name, metadata.namespace, `metadata.labels['']`, `metadata.annotations['']`, spec.nodeName, spec.serviceAccountName, status.hostIP, status.podIP, status.podIPs.""" - model_config = ConfigDict(extra="forbid") apiVersion: Optional[str] = Field( default=None, @@ -473,9 +649,10 @@ class FieldRef(BaseModel): class ResourceFieldRef(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + """Selects a resource of the container: only resources limits and requests (limits.cpu, limits.memory, limits.ephemeral-storage, requests.cpu, requests.memory and requests.ephemeral-storage) are currently supported.""" - model_config = ConfigDict(extra="forbid") containerName: Optional[str] = Field( default=None, @@ -492,7 +669,7 @@ class ResourceFieldRef(BaseModel): class SecretKeyRef(BaseModel): """Selects a key of a secret in the pod's namespace""" - model_config = ConfigDict(extra="forbid") + model_config = ConfigDict(extra="forbid", populate_by_name=True) key: str = Field( description="The key of the secret to select from. Must be a valid secret key." @@ -508,9 +685,10 @@ class SecretKeyRef(BaseModel): class ValueFrom(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + """Source for the environment variable's value. 
Cannot be used if value is not empty.""" - model_config = ConfigDict(extra="forbid") configMapKeyRef: Optional[ConfigMapKeyRef] = Field( default=None, @@ -535,9 +713,10 @@ class ValueFrom(BaseModel): class EnvironmentVariables(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + """EnvVar represents an environment variable present in a Container.""" - model_config = ConfigDict(extra="forbid") name: str = Field( description="Name of the environment variable. Must be a C_IDENTIFIER." @@ -554,9 +733,10 @@ class EnvironmentVariables(BaseModel): class ModelInvocationPort(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + """Defines the port at which the model server will listen to the invocation requests.""" - model_config = ConfigDict(extra="forbid") containerPort: int = Field( alias="container_port", @@ -569,9 +749,10 @@ class ModelInvocationPort(BaseModel): class ModelVolumeMount(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + """Defines the volume where model will be loaded""" - model_config = ConfigDict(extra="forbid") mountPath: Optional[str] = Field( default="/opt/ml/model", @@ -584,7 +765,7 @@ class ModelVolumeMount(BaseModel): class Claims(BaseModel): """ResourceClaim references one entry in PodSpec.ResourceClaims.""" - model_config = ConfigDict(extra="forbid") + model_config = ConfigDict(extra="forbid", populate_by_name=True) name: str = Field( description="Name must match the name of one entry in pod.spec.resourceClaims of the Pod where this field is used. It makes that resource available inside a container." 
@@ -598,7 +779,7 @@ class Claims(BaseModel): class Resources(BaseModel): """Defines the Resources in terms of CPU, GPU, Memory needed for the model to be deployed""" - model_config = ConfigDict(extra="forbid") + model_config = ConfigDict(extra="forbid", populate_by_name=True) claims: Optional[List[Claims]] = Field( default=None, @@ -615,9 +796,10 @@ class Resources(BaseModel): class Worker(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + """Details of the worker""" - model_config = ConfigDict(extra="forbid") args: Optional[List[str]] = Field( default=None, description="Defines the Arguments to the entrypoint." @@ -640,6 +822,15 @@ class Worker(BaseModel): alias="model_volume_mount", description="Defines the volume where model will be loaded", ) + probes: Optional[Probes] = Field( + default=None, + description="Configuration for container probes (liveness, readiness, startup)", + ) + requestLimits: Optional[RequestLimits] = Field( + default=None, + alias="request_limits", + description="Configuration for request limiting on the nginx sidecar proxy", + ) resources: Resources = Field( description="Defines the Resources in terms of CPU, GPU, Memory needed for the model to be deployed" ) @@ -650,10 +841,188 @@ class Worker(BaseModel): ) +class CaptureContentTypeHeader(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + + """Configuration for how to treat different content type headers during capture""" + + + csvContentTypes: Optional[List[str]] = Field( + default=None, + alias="csv_content_types", + description="List of content type headers to treat as CSV", + ) + jsonContentTypes: Optional[List[str]] = Field( + default=None, + alias="json_content_types", + description="List of content type headers to treat as JSON", + ) + + +class CaptureOptions(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + + """CaptureOption defines what data to capture (input, output, or both).""" + + + 
captureMode: Literal["Input", "Output"] = Field( + alias="capture_mode", description="Capture mode: Input or Output" + ) + + +class BufferConfig(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + + """Configuration for buffering and flushing captured data""" + + + batchSize: Optional[int] = Field( + default=10, + alias="batch_size", + description="Number of records to batch before writing to S3", + ) + flushIntervalSeconds: Optional[int] = Field( + default=60, + alias="flush_interval_seconds", + description="Flush interval in seconds", + ) + + +class PayloadConfig(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + + """Configuration for payload size limits""" + + + maxPayloadSizeKB: Optional[int] = Field( + default=0, + alias="max_payload_size_kb", + description="Maximum payload size in KB to capture. 0 means no limit (capture full payload).", + ) + + +class DataCaptureModelPod(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + + """Configuration for Model Pod level data capture (Tier 3)""" + + + bufferConfig: Optional[BufferConfig] = Field( + default=None, + alias="buffer_config", + description="Configuration for buffering and flushing captured data", + ) + captureContentTypeHeader: Optional[CaptureContentTypeHeader] = Field( + default=None, + alias="capture_content_type_header", + description="Configuration for how to treat different content type headers during capture", + ) + captureOptions: Optional[List[CaptureOptions]] = Field( + default=None, + alias="capture_options", + description="Capture options (Input, Output, or both). Defaults to [Input, Output] when enabled.", + ) + enabled: bool = Field(description="Enable or disable model pod data capture") + initialSamplingPercentage: Optional[int] = Field( + default=None, + alias="initial_sampling_percentage", + description="Percentage of requests to capture (0-100). 
Defaults to 100 when enabled.", + ) + kmsKeyId: Optional[str] = Field( + default=None, + alias="kms_key_id", + description="Optional KMS key ID, ARN, alias name, or alias ARN for encrypting captured data", + ) + payloadConfig: Optional[PayloadConfig] = Field( + default=None, + alias="payload_config", + description="Configuration for payload size limits", + ) + + +class DataCaptureSagemakerEndpoint(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + + """Configuration for SageMaker Endpoint level data capture (Tier 1)""" + + + captureContentTypeHeader: Optional[CaptureContentTypeHeader] = Field( + default=None, + alias="capture_content_type_header", + description="Configuration for how to treat different content type headers during capture", + ) + captureOptions: Optional[List[CaptureOptions]] = Field( + default=None, + alias="capture_options", + description="Capture options (Input, Output, or both). Defaults to [Input, Output] when enabled.", + ) + enabled: bool = Field( + description="Enable or disable SageMaker endpoint data capture" + ) + initialSamplingPercentage: Optional[int] = Field( + default=None, + alias="initial_sampling_percentage", + description="Percentage of requests to capture (0-100). 
Defaults to 100 when enabled.", + ) + kmsKeyId: Optional[str] = Field( + default=None, + alias="kms_key_id", + description="Optional KMS key ID, ARN, alias name, or alias ARN for encrypting captured data", + ) + + +class DataCaptureLoadBalancer(BaseModel): + """Configuration for LoadBalancer level data capture (Tier 2)""" + + model_config = ConfigDict(extra="forbid", populate_by_name=True) + + enabled: bool = Field(description="Enable or disable load balancer access logs") + + +class DataCapture(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + + """Configuration for data capture across multiple tiers (SageMaker, LoadBalancer, Model Pod)""" + + + loadBalancer: Optional[DataCaptureLoadBalancer] = Field( + default=None, + alias="load_balancer", + description="Configuration for LoadBalancer level data capture (Tier 2)", + ) + modelPod: Optional[DataCaptureModelPod] = Field( + default=None, + alias="model_pod", + description="Configuration for Model Pod level data capture (Tier 3)", + ) + s3Uri: Optional[str] = Field( + default=None, + alias="s3_uri", + description="Common S3 URI for all data capture tiers. Each tier will write to a specific prefix within this bucket.", + ) + sagemakerEndpoint: Optional[DataCaptureSagemakerEndpoint] = Field( + default=None, + alias="sagemaker_endpoint", + description="Configuration for SageMaker Endpoint level data capture (Tier 1)", + ) + + +class DnsConfig(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + + """DNS automation configuration for Route53. 
Requires tlsConfig.customCertificateConfig to be set.""" + + + hostedZoneId: str = Field( + alias="hosted_zone_id", + description="Route53 Hosted Zone ID where the DNS record will be created.", + ) + + class _HPEndpoint(BaseModel): """InferenceEndpointConfigSpec defines the desired state of InferenceEndpointConfig.""" - model_config = ConfigDict(extra="ignore") + model_config = ConfigDict(extra="ignore", populate_by_name=True) InitialReplicaCount: Optional[int] = Field( default=None, @@ -663,13 +1032,30 @@ class _HPEndpoint(BaseModel): autoScalingSpec: Optional[AutoScalingSpec] = Field( default=None, alias="auto_scaling_spec" ) + dataCapture: Optional[DataCapture] = Field( + default=None, + alias="data_capture", + description="Configuration for data capture across multiple tiers (SageMaker, LoadBalancer, Model Pod)", + ) + dnsConfig: Optional[DnsConfig] = Field( + default=None, + alias="dns_config", + description="DNS automation configuration for Route53. Requires tlsConfig.customCertificateConfig to be set.", + ) endpointName: Optional[str] = Field( default=None, alias="endpoint_name", description="Name used for Sagemaker Endpoint Name of sagemaker endpoint. Defaults to empty string which represents that Sagemaker endpoint will not be created.", ) - instanceType: str = Field( - alias="instance_type", description="Instance Type to deploy the model on" + instanceType: Optional[str] = Field( + default=None, + alias="instance_type", + description="Single instance type to deploy the model on. 
Mutually exclusive with instanceTypes.", + ) + instanceTypes: Optional[List[str]] = Field( + default=None, + alias="instance_types", + description="List of instance types to deploy the model on, in order of preference.", ) intelligentRoutingSpec: Optional[IntelligentRoutingSpec] = Field( default=None, @@ -686,11 +1072,20 @@ class _HPEndpoint(BaseModel): alias="kv_cache_spec", description="Configuration for KV Cache specification By default L1CacheOffloading will be enabled", ) + kubernetes: Optional[Kubernetes] = Field( + default=None, + description="User-provided customizations for the inference pod.", + ) loadBalancer: Optional[LoadBalancer] = Field( default=None, alias="load_balancer", description="Configuration for Application Load Balancer", ) + maxDeployTimeInSeconds: Optional[int] = Field( + default=3600, + alias="max_deploy_time_in_seconds", + description="Maximum allowed time in seconds for the deployment to complete before timing out. Defaults to 1 hour (3600 seconds)", + ) metrics: Optional[Metrics] = Field( default=None, description="Configuration for metrics collection and exposure" ) @@ -704,6 +1099,11 @@ class _HPEndpoint(BaseModel): alias="model_version", description="Version of the model used in creating sagemaker endpoint", ) + nodeAffinity: Optional[NodeAffinity] = Field( + default=None, + alias="node_affinity", + description="Custom node affinity configuration for advanced scheduling.", + ) replicas: Optional[int] = Field( default=1, description="The desired number of inference server replicas. 
Default 1.", @@ -719,9 +1119,10 @@ class _HPEndpoint(BaseModel): class Conditions(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + """DeploymentCondition describes the state of a deployment at a certain point.""" - model_config = ConfigDict(extra="forbid") lastTransitionTime: Optional[str] = Field( default=None, @@ -752,9 +1153,10 @@ class Conditions(BaseModel): class Status(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + """Status of the Deployment Object""" - model_config = ConfigDict(extra="forbid") availableReplicas: Optional[int] = Field( default=None, @@ -784,6 +1186,11 @@ class Status(BaseModel): default=None, description="Total number of non-terminated pods targeted by this deployment (their labels match the selector).", ) + terminatingReplicas: Optional[int] = Field( + default=None, + alias="terminating_replicas", + description="Total number of terminating pods targeted by this deployment.", + ) unavailableReplicas: Optional[int] = Field( default=None, alias="unavailable_replicas", @@ -797,9 +1204,10 @@ class Status(BaseModel): class DeploymentStatus(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + """Details of the native kubernetes deployment that hosts the model""" - model_config = ConfigDict(extra="forbid") deploymentObjectOverallState: Optional[str] = Field( default=None, @@ -822,9 +1230,10 @@ class DeploymentStatus(BaseModel): class Sagemaker(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + """Status of the SageMaker endpoint""" - model_config = ConfigDict(extra="forbid") configArn: Optional[str] = Field( default=None, @@ -847,7 +1256,7 @@ class Sagemaker(BaseModel): class Endpoints(BaseModel): """EndpointStatus contains the status of SageMaker endpoints""" - model_config = ConfigDict(extra="forbid") + model_config = ConfigDict(extra="forbid", populate_by_name=True) sagemaker: Optional[Sagemaker] = Field( default=None, 
description="Status of the SageMaker endpoint" @@ -857,7 +1266,7 @@ class Endpoints(BaseModel): class ModelMetricsStatus(BaseModel): """Status of model container metrics collection""" - model_config = ConfigDict(extra="forbid") + model_config = ConfigDict(extra="forbid", populate_by_name=True) path: Optional[str] = Field( default=None, description="The path where metrics are available" @@ -868,9 +1277,10 @@ class ModelMetricsStatus(BaseModel): class MetricsStatus(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + """Status of metrics collection""" - model_config = ConfigDict(extra="forbid") enabled: bool = Field(description="Whether metrics collection is enabled") errorMessage: Optional[str] = Field( @@ -894,9 +1304,10 @@ class MetricsStatus(BaseModel): class TlsCertificate(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + """CertificateStatus represents the status of TLS certificates""" - model_config = ConfigDict(extra="forbid") certificateARN: Optional[str] = Field( default=None, @@ -908,6 +1319,11 @@ class TlsCertificate(BaseModel): alias="certificate_domain_names", description="The certificate domain names that is attached to the certificate", ) + certificateHealth: Optional[Literal["Valid", "Expiring", "Expired"]] = Field( + default=None, + alias="certificate_health", + description="Certificate health status", + ) certificateName: Optional[str] = Field( default=None, alias="certificate_name", @@ -938,10 +1354,54 @@ class TlsCertificate(BaseModel): ) +class DnsStatus(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + + """Status of the operator-managed Route53 DNS record""" + + + dnsHealth: Optional[Literal["Active", "Pending", "Error"]] = Field( + default=None, + alias="dns_health", + description="DNS resolution status: Active, Pending, or Error.", + ) + hostedZoneId: Optional[str] = Field( + default=None, + alias="hosted_zone_id", + description="Route53 hosted zone 
ID.", + ) + lastTransitionTime: Optional[str] = Field( + default=None, + alias="last_transition_time", + description="When the status last transitioned, used for propagation timeout.", + ) + managedByOperator: bool = Field( + alias="managed_by_operator", + description="Whether the operator manages this DNS record.", + ) + message: Optional[str] = Field( + default=None, description="Human-readable status or error message." + ) + previousHostedZoneId: Optional[str] = Field( + default=None, + alias="previous_hosted_zone_id", + description="Previous hosted zone ID, retained during domain/zone changes until cleanup completes.", + ) + previousRecordName: Optional[str] = Field( + default=None, + alias="previous_record_name", + description="Previous record name, retained during domain/zone changes until cleanup completes.", + ) + recordName: Optional[str] = Field( + default=None, alias="record_name", description="Route53 record name." + ) + + class InferenceEndpointConfigStatus(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + """ModelDeploymentStatus defines the observed state of ModelDeployment""" - model_config = ConfigDict(extra="forbid") conditions: Optional[List[Conditions]] = Field( default=None, @@ -952,6 +1412,11 @@ class InferenceEndpointConfigStatus(BaseModel): alias="deployment_status", description="Details of the native kubernetes deployment that hosts the model", ) + dnsStatus: Optional[DnsStatus] = Field( + default=None, + alias="dns_status", + description="Status of the operator-managed Route53 DNS record", + ) endpoints: Optional[Endpoints] = Field( default=None, description="EndpointStatus contains the status of SageMaker endpoints", diff --git a/src/sagemaker/hyperpod/inference/config/hp_jumpstart_endpoint_config.py b/src/sagemaker/hyperpod/inference/config/hp_jumpstart_endpoint_config.py index 5e971868..fd9161ee 100644 --- a/src/sagemaker/hyperpod/inference/config/hp_jumpstart_endpoint_config.py +++ 
b/src/sagemaker/hyperpod/inference/config/hp_jumpstart_endpoint_config.py @@ -3,16 +3,17 @@ class Dimensions(BaseModel): - model_config = ConfigDict(extra="forbid") + model_config = ConfigDict(extra="forbid", populate_by_name=True) name: str = Field(description="CloudWatch Metric dimension name") value: str = Field(description="CloudWatch Metric dimension value") class CloudWatchTrigger(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + """CloudWatch metric trigger to use for autoscaling""" - model_config = ConfigDict(extra="forbid") activationTargetValue: Optional[float] = Field( default=0, @@ -71,9 +72,113 @@ class CloudWatchTrigger(BaseModel): class PrometheusTrigger(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + """Prometheus metric trigger to use for autoscaling""" - model_config = ConfigDict(extra="forbid") + + activationTargetValue: Optional[float] = Field( + default=0, + alias="activation_target_value", + description="Activation Value for Prometheus metric to scale from 0 to 1. Only applicable if minReplicaCount = 0", + ) + customHeaders: Optional[str] = Field( + default=None, + alias="custom_headers", + description="Custom headers to include while querying the prometheus endpoint.", + ) + metricType: Optional[Literal["Value", "Average"]] = Field( + default="Average", + alias="metric_type", + description="The type of metric to be used by HPA. Enum: AverageValue - Uses average value of metric per pod, Value - Uses absolute metric value", + ) + name: Optional[str] = Field( + default=None, description="Name for the Prometheus trigger" + ) + namespace: Optional[str] = Field( + default=None, description="Namespace for namespaced queries" + ) + query: Optional[str] = Field( + default=None, description="PromQLQuery for the metric." 
+ ) + serverAddress: Optional[str] = Field( + default=None, + alias="server_address", + description="Server address for AMP workspace", + ) + targetValue: Optional[float] = Field( + default=None, + alias="target_value", + description="Target metric value for scaling", + ) + useCachedMetrics: Optional[bool] = Field( + default=True, + alias="use_cached_metrics", + description="Enable caching of metric values during polling interval. Default is true", + ) + + +class CloudWatchTriggerList(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + + activationTargetValue: Optional[float] = Field( + default=0, + alias="activation_target_value", + description="Activation Value for CloudWatch metric to scale from 0 to 1. Only applicable if minReplicaCount = 0", + ) + dimensions: Optional[List[Dimensions]] = Field( + default=None, description="Dimensions for Cloudwatch metrics" + ) + metricCollectionPeriod: Optional[int] = Field( + default=300, + alias="metric_collection_period", + description="Defines the Period for CloudWatch query", + ) + metricCollectionStartTime: Optional[int] = Field( + default=300, + alias="metric_collection_start_time", + description="Defines the StartTime for CloudWatch query", + ) + metricName: Optional[str] = Field( + default=None, + alias="metric_name", + description="Metric name to query for Cloudwatch trigger", + ) + metricStat: Optional[str] = Field( + default="Average", + alias="metric_stat", + description="Statistics metric to be used by Trigger. Used to define Stat for CloudWatch query. Default is Average.", + ) + metricType: Optional[Literal["Value", "Average"]] = Field( + default="Average", + alias="metric_type", + description="The type of metric to be used by HPA. Enum: AverageValue - Uses average value of metric per pod, Value - Uses absolute metric value", + ) + minValue: Optional[float] = Field( + default=0, + alias="min_value", + description="Minimum metric value used in case of empty response from CloudWatch. 
Default is 0.", + ) + name: Optional[str] = Field( + default=None, description="Name for the CloudWatch trigger" + ) + namespace: Optional[str] = Field( + default=None, description="AWS CloudWatch namespace for metric" + ) + targetValue: Optional[float] = Field( + default=None, + alias="target_value", + description="TargetValue for CloudWatch metric", + ) + useCachedMetrics: Optional[bool] = Field( + default=True, + alias="use_cached_metrics", + description="Enable caching of metric values during polling interval. Default is true", + ) + + +class PrometheusTriggerList(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) activationTargetValue: Optional[float] = Field( default=0, @@ -117,13 +222,18 @@ class PrometheusTrigger(BaseModel): class AutoScalingSpec(BaseModel): - model_config = ConfigDict(extra="forbid") + model_config = ConfigDict(extra="forbid", populate_by_name=True) cloudWatchTrigger: Optional[CloudWatchTrigger] = Field( default=None, alias="cloud_watch_trigger", description="CloudWatch metric trigger to use for autoscaling", ) + cloudWatchTriggerList: Optional[List[CloudWatchTriggerList]] = Field( + default=None, + alias="cloud_watch_trigger_list", + description="Multiple CloudWatch metric triggers to use for autoscaling. Takes priority over CloudWatchTrigger if both are provided.", + ) cooldownPeriod: Optional[int] = Field( default=300, alias="cooldown_period", @@ -154,6 +264,11 @@ class AutoScalingSpec(BaseModel): alias="prometheus_trigger", description="Prometheus metric trigger to use for autoscaling", ) + prometheusTriggerList: Optional[List[PrometheusTriggerList]] = Field( + default=None, + alias="prometheus_trigger_list", + description="Multiple Prometheus metric triggers to use for autoscaling. 
Takes priority over PrometheusTrigger if both are provided.", + ) scaleDownStabilizationTime: Optional[int] = Field( default=300, alias="scale_down_stabilization_time", @@ -167,7 +282,7 @@ class AutoScalingSpec(BaseModel): class EnvironmentVariables(BaseModel): - model_config = ConfigDict(extra="forbid") + model_config = ConfigDict(extra="forbid", populate_by_name=True) name: str value: str @@ -176,7 +291,7 @@ class EnvironmentVariables(BaseModel): class ModelMetrics(BaseModel): """Configuration for model container metrics scraping""" - model_config = ConfigDict(extra="forbid") + model_config = ConfigDict(extra="forbid", populate_by_name=True) path: Optional[str] = Field( default="/metrics", description="Path where the model exposes metrics" @@ -188,9 +303,10 @@ class ModelMetrics(BaseModel): class Metrics(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + """Configuration for metrics collection and exposure""" - model_config = ConfigDict(extra="forbid") enabled: Optional[bool] = Field( default=True, description="Enable metrics collection for this model deployment" @@ -208,14 +324,14 @@ class Metrics(BaseModel): class AdditionalConfigs(BaseModel): - model_config = ConfigDict(extra="forbid") + model_config = ConfigDict(extra="forbid", populate_by_name=True) name: str value: str class Model(BaseModel): - model_config = ConfigDict(extra="forbid") + model_config = ConfigDict(extra="forbid", populate_by_name=True) acceptEula: bool = Field( default=False, @@ -247,7 +363,7 @@ class Model(BaseModel): class SageMakerEndpoint(BaseModel): - model_config = ConfigDict(extra="forbid") + model_config = ConfigDict(extra="forbid", populate_by_name=True) name: Optional[str] = Field( default="", @@ -266,7 +382,7 @@ class Validations(BaseModel): class Server(BaseModel): - model_config = ConfigDict(extra="forbid") + model_config = ConfigDict(extra="forbid", populate_by_name=True) executionRole: Optional[str] = Field( default=None, @@ -290,22 +406,240 @@ 
class Server(BaseModel): ) +class IntelligentRoutingSpec(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + + """Configuration for intelligent routing""" + + + autoScalingSpec: Optional[AutoScalingSpec] = Field( + default=None, alias="auto_scaling_spec" + ) + enabled: Optional[bool] = Field( + default=False, description="Once set, the enabled field cannot be modified" + ) + routingStrategy: Optional[ + Literal["prefixaware", "kvaware", "session", "roundrobin"] + ] = Field(default="prefixaware", alias="routing_strategy") + + +class L2CacheSpec(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + + l2CacheBackend: Optional[str] = Field( + default=None, alias="l2_cache_backend" + ) + l2CacheLocalUrl: Optional[str] = Field( + default=None, alias="l2_cache_local_url" + ) + + +class KvCacheSpec(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + + cacheConfigFile: Optional[str] = Field( + default=None, alias="cache_config_file" + ) + enableL1Cache: Optional[bool] = Field( + default=True, alias="enable_l1_cache" + ) + enableL2Cache: Optional[bool] = Field( + default=False, alias="enable_l2_cache" + ) + l2CacheSpec: Optional[L2CacheSpec] = Field( + default=None, alias="l2_cache_spec" + ) + + +class LoadBalancer(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + + healthCheckPath: Optional[str] = Field( + default="/ping", alias="health_check_path" + ) + routingAlgorithm: Optional[Literal["least_outstanding_requests", "round_robin"]] = ( + Field(default="least_outstanding_requests", alias="routing_algorithm") + ) + + +class CustomCertificateConfig(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + + acmArn: str = Field(alias="acm_arn", description="ACM certificate ARN") + domainName: str = Field(alias="domain_name") + + class TlsConfig(BaseModel): - model_config = ConfigDict(extra="forbid") + model_config = 
ConfigDict(extra="forbid", populate_by_name=True) + customCertificateConfig: Optional[CustomCertificateConfig] = Field( + default=None, alias="custom_certificate_config" + ) tlsCertificateOutputS3Uri: Optional[str] = Field( default=None, alias="tls_certificate_output_s3_uri" ) +class CaptureContentTypeHeader(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + + """Configuration for how to treat different content type headers during capture""" + + + csvContentTypes: Optional[List[str]] = Field( + default=None, + alias="csv_content_types", + description="List of content type headers to treat as CSV", + ) + jsonContentTypes: Optional[List[str]] = Field( + default=None, + alias="json_content_types", + description="List of content type headers to treat as JSON", + ) + + +class CaptureOptions(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + + """CaptureOption defines what data to capture (input, output, or both).""" + + + captureMode: Literal["Input", "Output"] = Field( + alias="capture_mode", description="Capture mode: Input or Output" + ) + + +class BufferConfig(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + + """Configuration for buffering and flushing captured data""" + + + batchSize: Optional[int] = Field( + default=10, + alias="batch_size", + description="Number of records to batch before writing to S3", + ) + flushIntervalSeconds: Optional[int] = Field( + default=60, + alias="flush_interval_seconds", + description="Flush interval in seconds", + ) + + +class PayloadConfig(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + + """Configuration for payload size limits""" + + + maxPayloadSizeKB: Optional[int] = Field( + default=0, + alias="max_payload_size_kb", + description="Maximum payload size in KB to capture. 
0 means no limit.", + ) + + +class DataCaptureModelPod(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + + """Configuration for Model Pod level data capture (Tier 3)""" + + + bufferConfig: Optional[BufferConfig] = Field( + default=None, alias="buffer_config", + ) + captureContentTypeHeader: Optional[CaptureContentTypeHeader] = Field( + default=None, alias="capture_content_type_header", + ) + captureOptions: Optional[List[CaptureOptions]] = Field( + default=None, alias="capture_options", + ) + enabled: bool = Field(description="Enable or disable model pod data capture") + initialSamplingPercentage: Optional[int] = Field( + default=None, alias="initial_sampling_percentage", + ) + kmsKeyId: Optional[str] = Field(default=None, alias="kms_key_id") + payloadConfig: Optional[PayloadConfig] = Field( + default=None, alias="payload_config", + ) + + +class DataCaptureSagemakerEndpoint(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + + """Configuration for SageMaker Endpoint level data capture (Tier 1)""" + + + captureContentTypeHeader: Optional[CaptureContentTypeHeader] = Field( + default=None, alias="capture_content_type_header", + ) + captureOptions: Optional[List[CaptureOptions]] = Field( + default=None, alias="capture_options", + ) + enabled: bool = Field(description="Enable or disable SageMaker endpoint data capture") + initialSamplingPercentage: Optional[int] = Field( + default=None, alias="initial_sampling_percentage", + ) + kmsKeyId: Optional[str] = Field(default=None, alias="kms_key_id") + + +class DataCaptureLoadBalancer(BaseModel): + """Configuration for LoadBalancer level data capture (Tier 2)""" + + model_config = ConfigDict(extra="forbid", populate_by_name=True) + + enabled: bool = Field(description="Enable or disable load balancer access logs") + + +class DataCapture(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + + """Configuration for data capture across multiple 
tiers (SageMaker, LoadBalancer, Model Pod)""" + + + loadBalancer: Optional[DataCaptureLoadBalancer] = Field( + default=None, alias="load_balancer", + ) + modelPod: Optional[DataCaptureModelPod] = Field( + default=None, alias="model_pod", + ) + s3Uri: Optional[str] = Field(default=None, alias="s3_uri") + sagemakerEndpoint: Optional[DataCaptureSagemakerEndpoint] = Field( + default=None, alias="sagemaker_endpoint", + ) + + +class DnsConfig(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + + """DNS automation configuration for Route53.""" + + + hostedZoneId: str = Field( + alias="hosted_zone_id", + description="Route53 Hosted Zone ID where the DNS record will be created.", + ) + + class _HPJumpStartEndpoint(BaseModel): """Config defines the desired state of JumpStartModel.""" - model_config = ConfigDict(extra="ignore") + model_config = ConfigDict(extra="ignore", populate_by_name=True) autoScalingSpec: Optional[AutoScalingSpec] = Field( default=None, alias="auto_scaling_spec" ) + dataCapture: Optional[DataCapture] = Field( + default=None, + alias="data_capture", + description="Configuration for data capture across multiple tiers (SageMaker, LoadBalancer, Model Pod)", + ) + dnsConfig: Optional[DnsConfig] = Field( + default=None, + alias="dns_config", + description="DNS automation configuration for Route53. Requires tlsConfig.customCertificateConfig to be set.", + ) environmentVariables: Optional[List[EnvironmentVariables]] = Field( default=None, alias="environment_variables", @@ -316,6 +650,21 @@ class _HPJumpStartEndpoint(BaseModel): alias="max_deploy_time_in_seconds", description="Maximum allowed time in seconds for the deployment to complete before timing out. 
Defaults to 1 hour (3600 seconds)", ) + intelligentRoutingSpec: Optional[IntelligentRoutingSpec] = Field( + default=None, + alias="intelligent_routing_spec", + description="Configuration for intelligent routing", + ) + kvCacheSpec: Optional[KvCacheSpec] = Field( + default=None, + alias="kv_cache_spec", + description="Configuration for KV Cache specification", + ) + loadBalancer: Optional[LoadBalancer] = Field( + default=None, + alias="load_balancer", + description="Configuration for Application Load Balancer", + ) metrics: Optional[Metrics] = Field( default=None, description="Configuration for metrics collection and exposure" ) @@ -332,9 +681,10 @@ class _HPJumpStartEndpoint(BaseModel): class Conditions(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + """DeploymentCondition describes the state of a deployment at a certain point.""" - model_config = ConfigDict(extra="forbid") lastTransitionTime: Optional[str] = Field( default=None, @@ -365,9 +715,10 @@ class Conditions(BaseModel): class Status(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + """Status of the Deployment Object""" - model_config = ConfigDict(extra="forbid") availableReplicas: Optional[int] = Field( default=None, @@ -397,6 +748,11 @@ class Status(BaseModel): default=None, description="Total number of non-terminated pods targeted by this deployment (their labels match the selector).", ) + terminatingReplicas: Optional[int] = Field( + default=None, + alias="terminating_replicas", + description="Total number of terminating pods targeted by this deployment.", + ) unavailableReplicas: Optional[int] = Field( default=None, alias="unavailable_replicas", @@ -410,9 +766,10 @@ class Status(BaseModel): class DeploymentStatus(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + """Details of the native kubernetes deployment that hosts the model""" - model_config = ConfigDict(extra="forbid") deploymentObjectOverallState: 
Optional[str] = Field( default=None, @@ -435,9 +792,10 @@ class DeploymentStatus(BaseModel): class Sagemaker(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + """Status of the SageMaker endpoint""" - model_config = ConfigDict(extra="forbid") configArn: Optional[str] = Field( default=None, @@ -460,17 +818,17 @@ class Sagemaker(BaseModel): class Endpoints(BaseModel): """EndpointStatus contains the status of SageMaker endpoints""" - model_config = ConfigDict(extra="forbid") + model_config = ConfigDict(extra="forbid", populate_by_name=True) sagemaker: Optional[Sagemaker] = Field( default=None, description="Status of the SageMaker endpoint" ) -class ModelMetrics(BaseModel): +class ModelMetricsStatus(BaseModel): """Status of model container metrics collection""" - model_config = ConfigDict(extra="forbid") + model_config = ConfigDict(extra="forbid", populate_by_name=True) path: Optional[str] = Field( default=None, description="The path where metrics are available" @@ -481,9 +839,10 @@ class ModelMetrics(BaseModel): class MetricsStatus(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + """Status of metrics collection""" - model_config = ConfigDict(extra="forbid") enabled: bool = Field(description="Whether metrics collection is enabled") errorMessage: Optional[str] = Field( @@ -496,7 +855,7 @@ class MetricsStatus(BaseModel): alias="metrics_scrape_interval_seconds", description="Scrape interval in seconds for metrics collection from sidecar and model container.", ) - modelMetrics: Optional[ModelMetrics] = Field( + modelMetrics: Optional[ModelMetricsStatus] = Field( default=None, alias="model_metrics", description="Status of model container metrics collection", @@ -507,9 +866,10 @@ class MetricsStatus(BaseModel): class TlsCertificate(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + """CertificateStatus represents the status of TLS certificates""" - model_config = 
ConfigDict(extra="forbid") certificateARN: Optional[str] = Field( default=None, @@ -521,6 +881,11 @@ class TlsCertificate(BaseModel): alias="certificate_domain_names", description="The certificate domain names that is attached to the certificate", ) + certificateHealth: Optional[Literal["Valid", "Expiring", "Expired"]] = Field( + default=None, + alias="certificate_health", + description="Certificate health status", + ) certificateName: Optional[str] = Field( default=None, alias="certificate_name", @@ -551,20 +916,104 @@ class TlsCertificate(BaseModel): ) +class DataCaptureModelPodStatus(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + + """Health status of the model pod data capture tier""" + + + lastTransitionTime: Optional[str] = Field( + default=None, + alias="last_transition_time", + description="Time of the last health state transition", + ) + message: Optional[str] = Field( + default=None, + description="Human-readable message describing the health state", + ) + reason: Optional[str] = Field( + default=None, + description="Reason for unhealthy status (e.g., OOMKilled, S3UploadFailure, MultipleContainerRestarts)", + ) + status: Literal["Healthy", "Unhealthy"] = Field( + description="Current health status" + ) + + +class DataCaptureStatus(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + + """Health status of the data capture pipeline""" + + + modelPod: Optional[DataCaptureModelPodStatus] = Field( + default=None, + alias="model_pod", + description="Health status of the model pod data capture tier", + ) + + +class DnsStatus(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + + """Status of the operator-managed Route53 DNS record""" + + + dnsHealth: Optional[Literal["Active", "Pending", "Error"]] = Field( + default=None, alias="dns_health", + description="DNS resolution status: Active, Pending, or Error.", + ) + hostedZoneId: Optional[str] = Field( + default=None, 
alias="hosted_zone_id", + description="Route53 hosted zone ID.", + ) + lastTransitionTime: Optional[str] = Field( + default=None, alias="last_transition_time", + description="When the status last transitioned, used for propagation timeout.", + ) + managedByOperator: bool = Field( + alias="managed_by_operator", + description="Whether the operator manages this DNS record.", + ) + message: Optional[str] = Field( + default=None, description="Human-readable status or error message." + ) + previousHostedZoneId: Optional[str] = Field( + default=None, alias="previous_hosted_zone_id", + ) + previousRecordName: Optional[str] = Field( + default=None, alias="previous_record_name", + ) + recordName: Optional[str] = Field( + default=None, alias="record_name", description="Route53 record name." + ) + + class JumpStartModelStatus(BaseModel): + model_config = ConfigDict(extra="forbid", populate_by_name=True) + """ModelDeploymentStatus defines the observed state of ModelDeployment""" - model_config = ConfigDict(extra="forbid") conditions: Optional[List[Conditions]] = Field( default=None, description="Detailed conditions representing the state of the deployment", ) + dataCaptureStatus: Optional[DataCaptureStatus] = Field( + default=None, + alias="data_capture_status", + description="Health status of the data capture pipeline", + ) deploymentStatus: Optional[DeploymentStatus] = Field( default=None, alias="deployment_status", description="Details of the native kubernetes deployment that hosts the model", ) + dnsStatus: Optional[DnsStatus] = Field( + default=None, + alias="dns_status", + description="Status of the operator-managed Route53 DNS record", + ) endpoints: Optional[Endpoints] = Field( default=None, description="EndpointStatus contains the status of SageMaker endpoints", diff --git a/src/sagemaker/hyperpod/inference/hp_endpoint.py b/src/sagemaker/hyperpod/inference/hp_endpoint.py index 7c108231..eb6206ef 100644 --- a/src/sagemaker/hyperpod/inference/hp_endpoint.py +++ 
b/src/sagemaker/hyperpod/inference/hp_endpoint.py @@ -51,7 +51,11 @@ def _create_internal(self, spec, debug=False): annotations=self.metadata.annotations if self.metadata else None, ) - self.validate_instance_type(spec.instanceType) + if spec.instanceType: + self.validate_instance_type(spec.instanceType) + elif spec.instanceTypes: + for it in spec.instanceTypes: + self.validate_instance_type(it) self.call_create_api( metadata=metadata, diff --git a/test/integration_tests/inference/cli/test_cli_custom_v12_fields.py b/test/integration_tests/inference/cli/test_cli_custom_v12_fields.py new file mode 100644 index 00000000..e68a89f1 --- /dev/null +++ b/test/integration_tests/inference/cli/test_cli_custom_v12_fields.py @@ -0,0 +1,173 @@ +"""Integration test for v1.2 custom endpoint new fields. + +Creates a deployment with all new v1.2 flags, then verifies via describe/get +that the fields are persisted correctly in the CRD, then deletes. +No need to wait for InService — we only validate spec field round-trip. 
+""" +import pytest +from click.testing import CliRunner +from sagemaker.hyperpod.cli.commands.inference import ( + custom_create, + custom_describe, + custom_delete, +) +from sagemaker.hyperpod.inference.hp_endpoint import HPEndpoint +from test.integration_tests.utils import get_time_str + +NAMESPACE = "integration" +VERSION = "1.2" + + +@pytest.fixture(scope="module") +def runner(): + return CliRunner() + + +@pytest.fixture(scope="module") +def endpoint_name(): + return "custom-v12-fields-" + get_time_str() + + +@pytest.mark.dependency(name="create_v12") +def test_create_with_v12_fields(runner, endpoint_name): + """Create custom endpoint with all new v1.2 fields.""" + result = runner.invoke(custom_create, [ + "--namespace", NAMESPACE, + "--version", VERSION, + "--endpoint-name", endpoint_name, + "--model-name", "v12-field-test", + "--model-source-type", "s3", + "--model-location", "test-model", + "--s3-bucket-name", "sagemaker-hyperpod-beta-integ-test-model-bucket-n", + "--s3-region", "us-east-2", + "--image-uri", "763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-inference:2.3.0-transformers4.48.0-cpu-py311-ubuntu22.04", + "--container-port", "8080", + "--model-volume-mount-name", "model-weights", + # v1.2 new fields + "--replicas", "1", + "--initial-replica-count", "1", + "--max-deploy-time-in-seconds", "7200", + "--invocation-endpoint", "invocations", + "--worker-args", "--max-model-len,4096", + "--worker-command", "python,-m,vllm.entrypoints.openai.api_server", + "--working-dir", "/opt/ml", + "--metrics-enabled", "True", + "--metrics-scrape-interval-seconds", "30", + "--model-metrics-path", "/metrics", + "--model-metrics-port", "8080", + "--max-concurrent-requests", "10", + "--max-queue-size", "5", + "--overflow-status-code", "503", + "--load-balancer-health-check-path", "/ping", + "--load-balancer-routing-algorithm", "round_robin", + "--intelligent-routing-enabled", "False", + "--enable-l1-cache", "True", + "--kubernetes", 
'{"serviceAccountName":"default"}', + "--tags", '{"team":"ml","env":"integ-test"}', + "--node-affinity", '{"required_during_scheduling_ignored_during_execution":{"node_selector_terms":[{"match_expressions":[{"key":"node.kubernetes.io/instance-type","operator":"In","values":["ml.c5.2xlarge"]}]}]}}', + "--probes", '{"livenessProbe":{"httpGet":{"path":"/ping","port":8080},"periodSeconds":30}}', + "--auto-scaling-spec", '{"min_replica_count":1,"max_replica_count":3,"polling_interval":60}', + "--resources-requests", '{"cpu":"1","memory":"2Gi"}', + "--resources-limits", '{"cpu":"2","memory":"4Gi","nvidia.com/gpu":"0"}', + "--custom-certificate-acm-arn", "arn:aws:acm:us-east-2:249127818294:certificate/a1b2c3d4-e5f6-7890-abcd-ef1234567890", + "--custom-certificate-domain-name", "test.example.com", + ]) + assert result.exit_code == 0, result.output + + # Verify the CR was actually created on the cluster + from sagemaker.hyperpod.inference.hp_endpoint import HPEndpoint + ep = HPEndpoint.get(name=endpoint_name, namespace=NAMESPACE) + assert ep is not None, f"Endpoint {endpoint_name} not found on cluster after creation" + + +@pytest.mark.dependency(name="verify_v12", depends=["create_v12"]) +def test_verify_v12_fields_via_sdk(endpoint_name): + """Get the endpoint via SDK and verify all new v1.2 fields are correct.""" + ep = HPEndpoint.get(name=endpoint_name, namespace=NAMESPACE) + + # Basic fields + assert ep.replicas == 1 + assert ep.initialReplicaCount == 1 + assert ep.maxDeployTimeInSeconds == 7200 + assert ep.invocationEndpoint == "invocations" + + # Worker fields + assert ep.worker.args == ["--max-model-len", "4096"] + assert ep.worker.command == ["python", "-m", "vllm.entrypoints.openai.api_server"] + assert ep.worker.workingDir == "/opt/ml" + + # Metrics + assert ep.metrics.enabled is True + assert ep.metrics.metricsScrapeIntervalSeconds == 30 + assert ep.metrics.modelMetrics.path == "/metrics" + assert ep.metrics.modelMetrics.port == 8080 + + # Request limits + assert
ep.worker.requestLimits.maxConcurrentRequests == 10 + assert ep.worker.requestLimits.maxQueueSize == 5 + assert ep.worker.requestLimits.overflowStatusCode == 503 + + # Load balancer + assert ep.loadBalancer.healthCheckPath == "/ping" + assert ep.loadBalancer.routingAlgorithm == "round_robin" + + # Intelligent routing + assert ep.intelligentRoutingSpec.enabled is False + + # KV cache + assert ep.kvCacheSpec.enableL1Cache is True + + # Kubernetes + assert ep.kubernetes.serviceAccountName == "default" + + # Tags + tag_map = {t.name: t.value for t in ep.tags} + assert tag_map["team"] == "ml" + assert tag_map["env"] == "integ-test" + + # Resources + assert ep.worker.resources.requests["cpu"] == "1" + assert ep.worker.resources.limits["nvidia.com/gpu"] == "0" + + # Node affinity + assert ep.nodeAffinity is not None + terms = ep.nodeAffinity.requiredDuringSchedulingIgnoredDuringExecution.nodeSelectorTerms + assert terms[0].matchExpressions[0].key == "node.kubernetes.io/instance-type" + + # Probes + assert ep.worker.probes is not None + assert ep.worker.probes.livenessProbe is not None + + # Auto scaling spec + assert ep.autoScalingSpec.minReplicaCount == 1 + assert ep.autoScalingSpec.maxReplicaCount == 3 + assert ep.autoScalingSpec.pollingInterval == 60 + + # Custom certificate + assert ep.tlsConfig.customCertificateConfig.acmArn == "arn:aws:acm:us-east-2:249127818294:certificate/a1b2c3d4-e5f6-7890-abcd-ef1234567890" + assert ep.tlsConfig.customCertificateConfig.domainName == "test.example.com" + + +@pytest.mark.dependency(name="describe_v12", depends=["create_v12"]) +def test_describe_shows_v12_fields(runner, endpoint_name): + """Verify describe CLI output contains new v1.2 field values.""" + result = runner.invoke(custom_describe, [ + "--name", endpoint_name, + "--namespace", NAMESPACE, + "--full", + ]) + assert result.exit_code == 0 + output = result.output + assert endpoint_name in output + assert "7200" in output # maxDeployTimeInSeconds + assert "round_robin" in 
output # loadBalancer routing + + +@pytest.mark.dependency(depends=["verify_v12", "describe_v12"]) +def test_delete_v12_endpoint(runner, endpoint_name): + """Clean up the test endpoint.""" + result = runner.invoke(custom_delete, [ + "--name", endpoint_name, + "--namespace", NAMESPACE, + ]) + assert result.exit_code == 0 diff --git a/test/integration_tests/inference/cli/test_cli_jumpstart_v12_fields.py b/test/integration_tests/inference/cli/test_cli_jumpstart_v12_fields.py new file mode 100644 index 00000000..2ec0572a --- /dev/null +++ b/test/integration_tests/inference/cli/test_cli_jumpstart_v12_fields.py @@ -0,0 +1,149 @@ +"""Integration test for v1.2 JumpStart endpoint new fields. + +Creates a deployment with all new v1.2 flags, then verifies via describe/get +that the fields are persisted correctly in the CRD, then deletes. +""" +import pytest +from click.testing import CliRunner +from sagemaker.hyperpod.cli.commands.inference import ( + js_create, + js_describe, + js_delete, +) +from sagemaker.hyperpod.inference.hp_jumpstart_endpoint import HPJumpStartEndpoint +from test.integration_tests.utils import get_time_str + +NAMESPACE = "integration" +VERSION = "1.2" + + +@pytest.fixture(scope="module") +def runner(): + return CliRunner() + + +@pytest.fixture(scope="module") +def endpoint_name(): + return "js-v12-fields-" + get_time_str() + + +@pytest.mark.dependency(name="js_create_v12") +def test_create_with_v12_fields(runner, endpoint_name): + """Create JumpStart endpoint with all new v1.2 fields.""" + result = runner.invoke(js_create, [ + "--namespace", NAMESPACE, + "--version", VERSION, + "--model-id", "deepseek-llm-r1-distill-qwen-1-5b", + "--instance-type", "ml.g5.8xlarge", + "--endpoint-name", endpoint_name, + "--accept-eula", "True", + # v1.2 new fields + "--replicas", "1", + "--max-deploy-time-in-seconds", "5400", + "--execution-role", "arn:aws:iam::249127818294:role/test-role", + "--metrics-enabled", "True", + "--metrics-scrape-interval-seconds", "30", + 
"--model-metrics-path", "/metrics", + "--model-metrics-port", "8080", + "--intelligent-routing-enabled", "False", + "--routing-strategy", "roundrobin", + "--enable-l1-cache", "True", + "--enable-l2-cache", "True", + "--l2-cache-backend", "redis", + "--l2-cache-local-url", "redis://localhost:6379", + "--cache-config-file", "/opt/config/kv.yaml", + "--load-balancer-health-check-path", "/ping", + "--load-balancer-routing-algorithm", "round_robin", + "--env", '{"TEST_KEY":"test_value"}', + "--additional-configs", '{"config1":"value1"}', + "--auto-scaling-spec", '{"min_replica_count":1,"max_replica_count":5,"polling_interval":60}', + "--custom-certificate-acm-arn", "arn:aws:acm:us-east-2:249127818294:certificate/a1b2c3d4-e5f6-7890-abcd-ef1234567890", + "--custom-certificate-domain-name", "test.example.com", + "--gated-model-download-role", "arn:aws:iam::249127818294:role/gated-download", + "--model-hub-name", "SageMakerPublicHub", + ]) + assert result.exit_code == 0, result.output + + # Verify the CR was actually created on the cluster + from sagemaker.hyperpod.inference.hp_jumpstart_endpoint import HPJumpStartEndpoint + ep = HPJumpStartEndpoint.get(name=endpoint_name, namespace=NAMESPACE) + assert ep is not None, f"Endpoint {endpoint_name} not found on cluster after creation" + + +@pytest.mark.dependency(name="js_verify_v12", depends=["js_create_v12"]) +def test_verify_v12_fields_via_sdk(endpoint_name): + """Get the endpoint via SDK and verify all new v1.2 fields are correct.""" + ep = HPJumpStartEndpoint.get(name=endpoint_name, namespace=NAMESPACE) + + # Basic fields + assert ep.replicas == 1 + assert ep.maxDeployTimeInSeconds == 5400 + + # Server + assert ep.server.executionRole == "arn:aws:iam::249127818294:role/test-role" + + # Metrics + assert ep.metrics.enabled is True + assert ep.metrics.metricsScrapeIntervalSeconds == 30 + assert ep.metrics.modelMetrics.path == "/metrics" + assert ep.metrics.modelMetrics.port == 8080 + + # Intelligent routing + assert 
ep.intelligentRoutingSpec.enabled is False + assert ep.intelligentRoutingSpec.routingStrategy == "roundrobin" + + # KV cache + assert ep.kvCacheSpec.enableL1Cache is True + assert ep.kvCacheSpec.enableL2Cache is True + assert ep.kvCacheSpec.l2CacheSpec.l2CacheBackend == "redis" + assert ep.kvCacheSpec.l2CacheSpec.l2CacheLocalUrl == "redis://localhost:6379" + assert ep.kvCacheSpec.cacheConfigFile == "/opt/config/kv.yaml" + + # Load balancer + assert ep.loadBalancer.healthCheckPath == "/ping" + assert ep.loadBalancer.routingAlgorithm == "round_robin" + + # Environment variables + env_map = {e.name: e.value for e in ep.environmentVariables} + assert env_map["TEST_KEY"] == "test_value" + + # Additional configs + config_map = {c.name: c.value for c in ep.model.additionalConfigs} + assert config_map["config1"] == "value1" + + # Model + assert ep.model.gatedModelDownloadRole == "arn:aws:iam::249127818294:role/gated-download" + assert ep.model.modelHubName == "SageMakerPublicHub" + + # Auto scaling spec + assert ep.autoScalingSpec.minReplicaCount == 1 + assert ep.autoScalingSpec.maxReplicaCount == 5 + assert ep.autoScalingSpec.pollingInterval == 60 + + # Custom certificate + assert ep.tlsConfig.customCertificateConfig.acmArn == "arn:aws:acm:us-east-2:249127818294:certificate/a1b2c3d4-e5f6-7890-abcd-ef1234567890" + assert ep.tlsConfig.customCertificateConfig.domainName == "test.example.com" + + +@pytest.mark.dependency(name="js_describe_v12", depends=["js_create_v12"]) +def test_describe_shows_v12_fields(runner, endpoint_name): + """Verify describe CLI output contains new v1.2 field values.""" + result = runner.invoke(js_describe, [ + "--name", endpoint_name, + "--namespace", NAMESPACE, + "--full", + ]) + assert result.exit_code == 0 + output = result.output + assert endpoint_name in output + assert "5400" in output # maxDeployTimeInSeconds + + +@pytest.mark.dependency(depends=["js_verify_v12", "js_describe_v12"]) +def test_delete_v12_endpoint(runner, endpoint_name): + 
"""Clean up the test endpoint.""" + result = runner.invoke(js_delete, [ + "--name", endpoint_name, + "--namespace", NAMESPACE, + ]) + assert result.exit_code == 0 diff --git a/test/integration_tests/inference/sdk/test_sdk_custom_v12_fields.py b/test/integration_tests/inference/sdk/test_sdk_custom_v12_fields.py new file mode 100644 index 00000000..58ba8d27 --- /dev/null +++ b/test/integration_tests/inference/sdk/test_sdk_custom_v12_fields.py @@ -0,0 +1,127 @@ +"""Integration test for v1.2 custom endpoint new fields via SDK. + +Creates a deployment with all new v1.2 SDK fields, verifies via get, +then deletes. No need to wait for InService. +""" +import pytest +from sagemaker.hyperpod.inference.hp_endpoint import HPEndpoint +from sagemaker.hyperpod.inference.config.hp_endpoint_config import ( + ModelSourceConfig, S3Storage, TlsConfig, Worker, ModelVolumeMount, + ModelInvocationPort, Resources, EnvironmentVariables, Metrics, ModelMetrics, + LoadBalancer, IntelligentRoutingSpec, KvCacheSpec, Kubernetes, + RequestLimits, Probes, Tags, +) +from sagemaker.hyperpod.common.config.metadata import Metadata +from test.integration_tests.utils import get_time_str + +NAMESPACE = "integration" +ENDPOINT_NAME = "custom-sdk-v12-" + get_time_str() + + +@pytest.fixture(scope="module") +def custom_endpoint(): + metadata = Metadata(name=ENDPOINT_NAME, namespace=NAMESPACE) + + model_src = ModelSourceConfig( + model_source_type="s3", + model_location="test-model", + s3_storage=S3Storage( + bucket_name="sagemaker-hyperpod-beta-integ-test-model-bucket-n", + region="us-east-2", + ), + ) + + worker = Worker( + image="763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pytorch-inference:2.3.0-transformers4.48.0-cpu-py311-ubuntu22.04", + model_volume_mount=ModelVolumeMount(name="model-weights"), + model_invocation_port=ModelInvocationPort(container_port=8080), + resources=Resources( + requests={"cpu": "1", "memory": "2Gi"}, + limits={"cpu": "2", "memory": "4Gi", "nvidia.com/gpu": "0"}, + 
), + environment_variables=[ + EnvironmentVariables(name="SAGEMAKER_PROGRAM", value="inference.py"), + ], + args=["--max-model-len", "4096"], + working_dir="/opt/ml", + request_limits=RequestLimits( + max_concurrent_requests=10, + max_queue_size=5, + overflow_status_code=503, + ), + ) + + return HPEndpoint( + metadata=metadata, + endpoint_name=ENDPOINT_NAME, + instance_type="ml.c5.2xlarge", + model_name="v12-sdk-field-test", + model_source_config=model_src, + worker=worker, + replicas=1, + max_deploy_time_in_seconds=7200, + metrics=Metrics( + enabled=True, + metrics_scrape_interval_seconds=30, + model_metrics=ModelMetrics(path="/metrics", port=8080), + ), + load_balancer=LoadBalancer( + health_check_path="/ping", + routing_algorithm="round_robin", + ), + intelligent_routing_spec=IntelligentRoutingSpec(enabled=False), + kv_cache_spec=KvCacheSpec(enable_l1_cache=True), + kubernetes=Kubernetes(service_account_name="default"), + tags=[Tags(name="team", value="ml"), Tags(name="env", value="integ-test")], + ) + + +@pytest.mark.dependency(name="sdk_create_v12") +def test_create_endpoint(custom_endpoint): + custom_endpoint.create() + assert custom_endpoint.metadata.name == ENDPOINT_NAME + + +@pytest.mark.dependency(name="sdk_verify_v12", depends=["sdk_create_v12"]) +def test_verify_v12_fields(): + """Get endpoint and verify all new v1.2 fields.""" + ep = HPEndpoint.get(name=ENDPOINT_NAME, namespace=NAMESPACE) + + assert ep.replicas == 1 + assert ep.maxDeployTimeInSeconds == 7200 + + # Worker + assert ep.worker.args == ["--max-model-len", "4096"] + assert ep.worker.workingDir == "/opt/ml" + assert ep.worker.requestLimits.maxConcurrentRequests == 10 + assert ep.worker.requestLimits.maxQueueSize == 5 + assert ep.worker.requestLimits.overflowStatusCode == 503 + + # Metrics + assert ep.metrics.enabled is True + assert ep.metrics.metricsScrapeIntervalSeconds == 30 + assert ep.metrics.modelMetrics.path == "/metrics" + + # Load balancer + assert ep.loadBalancer.healthCheckPath == 
"/ping" + assert ep.loadBalancer.routingAlgorithm == "round_robin" + + # Intelligent routing + assert ep.intelligentRoutingSpec.enabled is False + + # KV cache + assert ep.kvCacheSpec.enableL1Cache is True + + # Kubernetes + assert ep.kubernetes.serviceAccountName == "default" + + # Tags + tag_map = {t.name: t.value for t in ep.tags} + assert tag_map["team"] == "ml" + assert tag_map["env"] == "integ-test" + + +@pytest.mark.dependency(depends=["sdk_verify_v12"]) +def test_delete_endpoint(): + ep = HPEndpoint.get(name=ENDPOINT_NAME, namespace=NAMESPACE) + ep.delete() diff --git a/test/integration_tests/inference/sdk/test_sdk_jumpstart_v12_fields.py b/test/integration_tests/inference/sdk/test_sdk_jumpstart_v12_fields.py new file mode 100644 index 00000000..5f77f9d4 --- /dev/null +++ b/test/integration_tests/inference/sdk/test_sdk_jumpstart_v12_fields.py @@ -0,0 +1,85 @@ +"""Integration test for v1.2 JumpStart endpoint new fields via SDK. + +Creates a deployment with all new v1.2 SDK fields, verifies via get, +then deletes. No need to wait for InService. 
+""" +import pytest +from sagemaker.hyperpod.inference.hp_jumpstart_endpoint import HPJumpStartEndpoint +from sagemaker.hyperpod.inference.config.hp_jumpstart_endpoint_config import ( + Model, Server, SageMakerEndpoint, Metrics, ModelMetrics, + EnvironmentVariables, IntelligentRoutingSpec, KvCacheSpec, LoadBalancer, +) +from sagemaker.hyperpod.common.config.metadata import Metadata +from test.integration_tests.utils import get_time_str + +NAMESPACE = "integration" +ENDPOINT_NAME = "js-sdk-v12-" + get_time_str() + + +@pytest.fixture(scope="module") +def endpoint_obj(): + metadata = Metadata(name=ENDPOINT_NAME, namespace=NAMESPACE) + + return HPJumpStartEndpoint( + metadata=metadata, + model=Model(model_id="deepseek-llm-r1-distill-qwen-1-5b", accept_eula=True), + server=Server(instance_type="ml.g5.8xlarge"), + sage_maker_endpoint=SageMakerEndpoint(name=ENDPOINT_NAME), + replicas=1, + max_deploy_time_in_seconds=5400, + metrics=Metrics( + enabled=True, + metrics_scrape_interval_seconds=30, + model_metrics=ModelMetrics(path="/metrics", port=8080), + ), + intelligent_routing_spec=IntelligentRoutingSpec(enabled=False), + kv_cache_spec=KvCacheSpec(enable_l1_cache=True), + load_balancer=LoadBalancer( + health_check_path="/ping", + routing_algorithm="round_robin", + ), + environment_variables=[ + EnvironmentVariables(name="TEST_KEY", value="test_value"), + ], + ) + + +@pytest.mark.dependency(name="js_sdk_create_v12") +def test_create_endpoint(endpoint_obj): + endpoint_obj.create() + assert endpoint_obj.metadata.name == ENDPOINT_NAME + + +@pytest.mark.dependency(name="js_sdk_verify_v12", depends=["js_sdk_create_v12"]) +def test_verify_v12_fields(): + """Get endpoint and verify all new v1.2 fields.""" + ep = HPJumpStartEndpoint.get(name=ENDPOINT_NAME, namespace=NAMESPACE) + + assert ep.replicas == 1 + assert ep.maxDeployTimeInSeconds == 5400 + + # Metrics + assert ep.metrics.enabled is True + assert ep.metrics.metricsScrapeIntervalSeconds == 30 + assert 
ep.metrics.modelMetrics.path == "/metrics" + assert ep.metrics.modelMetrics.port == 8080 + + # Intelligent routing + assert ep.intelligentRoutingSpec.enabled is False + + # KV cache + assert ep.kvCacheSpec.enableL1Cache is True + + # Load balancer + assert ep.loadBalancer.healthCheckPath == "/ping" + assert ep.loadBalancer.routingAlgorithm == "round_robin" + + # Environment variables + env_map = {e.name: e.value for e in ep.environmentVariables} + assert env_map["TEST_KEY"] == "test_value" + + +@pytest.mark.dependency(depends=["js_sdk_verify_v12"]) +def test_delete_endpoint(): + ep = HPJumpStartEndpoint.get(name=ENDPOINT_NAME, namespace=NAMESPACE) + ep.delete() diff --git a/test/unit_tests/cli/test_inference.py b/test/unit_tests/cli/test_inference.py index 177dedd1..e032efe1 100644 --- a/test/unit_tests/cli/test_inference.py +++ b/test/unit_tests/cli/test_inference.py @@ -107,10 +107,16 @@ def test_js_create_missing_required_args(): assert "Missing option" in result.output +@patch("sys.argv", ["pytest", "--version", "1.1"]) def test_js_create_with_mig_profile(): """ Test js_create with MIG profile (accelerator partition) options using v1.1 schema. """ + if "sagemaker.hyperpod.cli.commands.inference" in sys.modules: + importlib.reload(sys.modules["sagemaker.hyperpod.cli.commands.inference"]) + + from sagemaker.hyperpod.cli.commands.inference import js_create + with patch( "sagemaker.hyperpod.cli.inference_utils.load_schema_for_version" ) as mock_load_schema, patch( @@ -183,10 +189,16 @@ def test_js_create_missing_required_args(): assert "Missing option" in result.output +@patch("sys.argv", ["pytest", "--version", "1.1"]) def test_js_create_mig_validation_error_handling(): """ Test js_create properly handles MIG profile validation errors using v1.1 schema. 
""" + if "sagemaker.hyperpod.cli.commands.inference" in sys.modules: + importlib.reload(sys.modules["sagemaker.hyperpod.cli.commands.inference"]) + + from sagemaker.hyperpod.cli.commands.inference import js_create + with patch( "sagemaker.hyperpod.cli.commands.inference.HPJumpStartEndpoint" ) as mock_endpoint_class, patch( @@ -636,4 +648,402 @@ def test_custom_create_with_intelligent_routing_and_kv_cache(): ) assert result.exit_code == 0, result.output - domain_obj.create.assert_called_once_with(debug=False) \ No newline at end of file + domain_obj.create.assert_called_once_with(debug=False) + + +# ── v1.2 new field to_domain tests ────────────────────────────────────────── + + +def test_custom_to_domain_huggingface(): + from hyperpod_custom_inference_template.v1_2.model import FlatHPEndpoint + + flat = FlatHPEndpoint( + metadata_name="test", model_name="test-model", model_source_type="huggingface", + huggingface_model_id="meta-llama/Llama-3.1-8B-Instruct", + huggingface_token_secret_name="hf-secret", huggingface_token_secret_key="token", + image_uri="test:latest", container_port=8000, model_volume_mount_name="mw", + instance_type="ml.g5.8xlarge", + ) + domain = flat.to_domain() + assert domain.modelSourceConfig.modelSourceType == "huggingface" + assert domain.modelSourceConfig.huggingFaceModel.modelId == "meta-llama/Llama-3.1-8B-Instruct" + assert domain.modelSourceConfig.huggingFaceModel.tokenSecretRef.name == "hf-secret" + + +def test_custom_to_domain_kubernetes_volume(): + from hyperpod_custom_inference_template.v1_2.model import FlatHPEndpoint + + flat = FlatHPEndpoint( + metadata_name="test", model_name="test-model", model_source_type="kubernetesVolume", + model_location="/mnt/models/my-model", + image_uri="test:latest", container_port=8000, model_volume_mount_name="mw", + instance_type="ml.g5.8xlarge", + ) + domain = flat.to_domain() + assert domain.modelSourceConfig.modelSourceType == "kubernetesVolume" + assert domain.modelSourceConfig.modelLocation == 
"/mnt/models/my-model" + + +def test_custom_to_domain_dns_config(): + from hyperpod_custom_inference_template.v1_2.model import FlatHPEndpoint + + flat = FlatHPEndpoint( + metadata_name="test", model_name="test-model", model_source_type="s3", + s3_bucket_name="bucket", s3_region="us-east-2", + image_uri="test:latest", container_port=8000, model_volume_mount_name="mw", + instance_type="ml.g5.8xlarge", dns_hosted_zone_id="Z1234567890", + custom_certificate_acm_arn="arn:aws:acm:us-east-2:123456789012:certificate/abcd1234-abcd-1234-abcd-1234abcd1234", + custom_certificate_domain_name="test.example.com", + ) + domain = flat.to_domain() + assert domain.dnsConfig.hostedZoneId == "Z1234567890" + + +def test_custom_to_domain_data_capture_json(): + from hyperpod_custom_inference_template.v1_2.model import FlatHPEndpoint + + flat = FlatHPEndpoint( + metadata_name="test", model_name="test-model", model_source_type="s3", + s3_bucket_name="bucket", s3_region="us-east-2", + image_uri="test:latest", container_port=8000, model_volume_mount_name="mw", + instance_type="ml.g5.8xlarge", + data_capture={"sagemaker_endpoint": {"enabled": True}}, + ) + domain = flat.to_domain() + assert domain.dataCapture.sagemakerEndpoint.enabled is True + + +def test_custom_to_domain_service_account_via_kubernetes(): + from hyperpod_custom_inference_template.v1_2.model import FlatHPEndpoint + + flat = FlatHPEndpoint( + metadata_name="test", model_name="test-model", model_source_type="s3", + s3_bucket_name="bucket", s3_region="us-east-2", + image_uri="test:latest", container_port=8000, model_volume_mount_name="mw", + instance_type="ml.g5.8xlarge", + kubernetes={"service_account_name": "my-sa"}, + ) + domain = flat.to_domain() + assert domain.kubernetes.serviceAccountName == "my-sa" + + +def test_custom_huggingface_requires_model_id(): + from hyperpod_custom_inference_template.v1_2.model import FlatHPEndpoint + from pydantic import ValidationError + + with pytest.raises(ValidationError): + FlatHPEndpoint( 
+ metadata_name="test", model_name="test-model", model_source_type="huggingface", + image_uri="test:latest", container_port=8000, model_volume_mount_name="mw", + instance_type="ml.g5.8xlarge", + ) + + +def test_js_to_domain_dns_config(): + from hyperpod_jumpstart_inference_template.v1_2.model import FlatHPJumpStartEndpoint + + flat = FlatHPJumpStartEndpoint( + metadata_name="test", model_id="test-model", instance_type="ml.g5.8xlarge", + dns_hosted_zone_id="Z999", + custom_certificate_acm_arn="arn:aws:acm:us-east-2:123456789012:certificate/abcd1234-abcd-1234-abcd-1234abcd1234", + custom_certificate_domain_name="test.example.com", + ) + domain = flat.to_domain() + assert domain.dnsConfig.hostedZoneId == "Z999" + + +def test_js_to_domain_data_capture_json(): + from hyperpod_jumpstart_inference_template.v1_2.model import FlatHPJumpStartEndpoint + + flat = FlatHPJumpStartEndpoint( + metadata_name="test", model_id="test-model", instance_type="ml.g5.8xlarge", + data_capture={"load_balancer": {"enabled": True}}, + ) + domain = flat.to_domain() + assert domain.dataCapture.loadBalancer.enabled is True + + +def test_custom_node_affinity_without_instance_type(): + from hyperpod_custom_inference_template.v1_2.model import FlatHPEndpoint + + flat = FlatHPEndpoint( + metadata_name="test", model_name="test-model", model_source_type="kubernetesVolume", + image_uri="test:latest", container_port=8000, model_volume_mount_name="mw", + node_affinity={"required_during_scheduling_ignored_during_execution": { + "node_selector_terms": [{"match_expressions": [{"key": "kubernetes.io/hostname", "operator": "In", "values": ["node-1"]}]}] + }}, + ) + domain = flat.to_domain() + assert domain.nodeAffinity is not None + assert domain.instanceType is None + + +def test_custom_node_affinity_with_instance_type_fails(): + from hyperpod_custom_inference_template.v1_2.model import FlatHPEndpoint + from pydantic import ValidationError + + with pytest.raises(ValidationError, match="node_affinity cannot be 
specified with instance_type"): + FlatHPEndpoint( + metadata_name="test", model_name="test-model", model_source_type="kubernetesVolume", + image_uri="test:latest", container_port=8000, model_volume_mount_name="mw", + instance_type="ml.g5.8xlarge", + node_affinity={"required_during_scheduling_ignored_during_execution": { + "node_selector_terms": [{"match_expressions": [{"key": "k", "operator": "In", "values": ["v"]}]}] + }}, + ) + + +def test_custom_node_affinity_with_instance_types_fails(): + from hyperpod_custom_inference_template.v1_2.model import FlatHPEndpoint + from pydantic import ValidationError + + with pytest.raises(ValidationError, match="node_affinity cannot be specified with instance_type"): + FlatHPEndpoint( + metadata_name="test", model_name="test-model", model_source_type="kubernetesVolume", + image_uri="test:latest", container_port=8000, model_volume_mount_name="mw", + instance_types="ml.g5.8xlarge,ml.g5.4xlarge", + node_affinity={"required_during_scheduling_ignored_during_execution": { + "node_selector_terms": [{"match_expressions": [{"key": "k", "operator": "In", "values": ["v"]}]}] + }}, + ) + + +def test_custom_no_instance_type_no_node_affinity_fails(): + from hyperpod_custom_inference_template.v1_2.model import FlatHPEndpoint + from pydantic import ValidationError + + with pytest.raises(ValidationError, match="Either instance_type, instance_types, or node_affinity must be provided"): + FlatHPEndpoint( + metadata_name="test", model_name="test-model", model_source_type="kubernetesVolume", + image_uri="test:latest", container_port=8000, model_volume_mount_name="mw", + ) + + +# ── v1.2 validation fix tests (cert/DNS cross-validation, pattern validation, data_capture skip list) ── + +VALID_ACM_ARN = "arn:aws:acm:us-east-2:123456789012:certificate/abcd1234-abcd-1234-abcd-1234abcd1234" + +# Bug 2: cert mutual dependency — custom endpoint + +def test_custom_acm_arn_without_domain_raises(): + from hyperpod_custom_inference_template.v1_2.model import 
FlatHPEndpoint + from pydantic import ValidationError + + with pytest.raises(ValidationError, match="must both be provided together"): + FlatHPEndpoint( + metadata_name="test", model_name="m", model_source_type="s3", + s3_bucket_name="b", s3_region="us-east-2", + image_uri="i:l", container_port=8080, model_volume_mount_name="mw", + instance_type="ml.g5.8xlarge", + custom_certificate_acm_arn=VALID_ACM_ARN, + ) + + +def test_custom_domain_without_acm_arn_raises(): + from hyperpod_custom_inference_template.v1_2.model import FlatHPEndpoint + from pydantic import ValidationError + + with pytest.raises(ValidationError, match="must both be provided together"): + FlatHPEndpoint( + metadata_name="test", model_name="m", model_source_type="s3", + s3_bucket_name="b", s3_region="us-east-2", + image_uri="i:l", container_port=8080, model_volume_mount_name="mw", + instance_type="ml.g5.8xlarge", + custom_certificate_domain_name="test.example.com", + ) + + +def test_custom_both_cert_fields_valid(): + from hyperpod_custom_inference_template.v1_2.model import FlatHPEndpoint + + ep = FlatHPEndpoint( + metadata_name="test", model_name="m", model_source_type="s3", + s3_bucket_name="b", s3_region="us-east-2", + image_uri="i:l", container_port=8080, model_volume_mount_name="mw", + instance_type="ml.g5.8xlarge", + custom_certificate_acm_arn=VALID_ACM_ARN, + custom_certificate_domain_name="test.example.com", + ) + assert ep.custom_certificate_acm_arn is not None + assert ep.custom_certificate_domain_name is not None + + +# Bug 1: DNS requires both cert fields — custom endpoint + +def test_custom_dns_without_cert_raises(): + from hyperpod_custom_inference_template.v1_2.model import FlatHPEndpoint + from pydantic import ValidationError + + with pytest.raises(ValidationError, match="dns_hosted_zone_id requires both"): + FlatHPEndpoint( + metadata_name="test", model_name="m", model_source_type="s3", + s3_bucket_name="b", s3_region="us-east-2", + image_uri="i:l", container_port=8080, 
model_volume_mount_name="mw", + instance_type="ml.g5.8xlarge", + dns_hosted_zone_id="Z1234567890ABC", + ) + + +def test_custom_dns_with_only_acm_arn_raises(): + from hyperpod_custom_inference_template.v1_2.model import FlatHPEndpoint + from pydantic import ValidationError + + with pytest.raises(ValidationError, match="must both be provided together"): + FlatHPEndpoint( + metadata_name="test", model_name="m", model_source_type="s3", + s3_bucket_name="b", s3_region="us-east-2", + image_uri="i:l", container_port=8080, model_volume_mount_name="mw", + instance_type="ml.g5.8xlarge", + dns_hosted_zone_id="Z1234567890ABC", + custom_certificate_acm_arn=VALID_ACM_ARN, + ) + + +# Bug 3: pattern validation — custom endpoint + +def test_custom_invalid_acm_arn_pattern(): + from hyperpod_custom_inference_template.v1_2.model import FlatHPEndpoint + from pydantic import ValidationError + + with pytest.raises(ValidationError): + FlatHPEndpoint( + metadata_name="test", model_name="m", model_source_type="s3", + s3_bucket_name="b", s3_region="us-east-2", + image_uri="i:l", container_port=8080, model_volume_mount_name="mw", + instance_type="ml.g5.8xlarge", + custom_certificate_acm_arn="not-a-valid-arn", + custom_certificate_domain_name="test.example.com", + ) + + +def test_custom_invalid_domain_name_pattern(): + from hyperpod_custom_inference_template.v1_2.model import FlatHPEndpoint + from pydantic import ValidationError + + with pytest.raises(ValidationError): + FlatHPEndpoint( + metadata_name="test", model_name="m", model_source_type="s3", + s3_bucket_name="b", s3_region="us-east-2", + image_uri="i:l", container_port=8080, model_volume_mount_name="mw", + instance_type="ml.g5.8xlarge", + custom_certificate_acm_arn=VALID_ACM_ARN, + custom_certificate_domain_name="INVALID_DOMAIN!!", + ) + + +def test_custom_domain_name_too_long(): + from hyperpod_custom_inference_template.v1_2.model import FlatHPEndpoint + from pydantic import ValidationError + + with pytest.raises(ValidationError): + 
FlatHPEndpoint( + metadata_name="test", model_name="m", model_source_type="s3", + s3_bucket_name="b", s3_region="us-east-2", + image_uri="i:l", container_port=8080, model_volume_mount_name="mw", + instance_type="ml.g5.8xlarge", + custom_certificate_acm_arn=VALID_ACM_ARN, + custom_certificate_domain_name="a" * 254, + ) + + +def test_custom_invalid_hosted_zone_id_pattern(): + from hyperpod_custom_inference_template.v1_2.model import FlatHPEndpoint + from pydantic import ValidationError + + with pytest.raises(ValidationError): + FlatHPEndpoint( + metadata_name="test", model_name="m", model_source_type="s3", + s3_bucket_name="b", s3_region="us-east-2", + image_uri="i:l", container_port=8080, model_volume_mount_name="mw", + instance_type="ml.g5.8xlarge", + custom_certificate_acm_arn=VALID_ACM_ARN, + custom_certificate_domain_name="test.example.com", + dns_hosted_zone_id="invalid-zone", + ) + + +def test_custom_hosted_zone_must_start_with_z(): + from hyperpod_custom_inference_template.v1_2.model import FlatHPEndpoint + from pydantic import ValidationError + + with pytest.raises(ValidationError): + FlatHPEndpoint( + metadata_name="test", model_name="m", model_source_type="s3", + s3_bucket_name="b", s3_region="us-east-2", + image_uri="i:l", container_port=8080, model_volume_mount_name="mw", + instance_type="ml.g5.8xlarge", + custom_certificate_acm_arn=VALID_ACM_ARN, + custom_certificate_domain_name="test.example.com", + dns_hosted_zone_id="A1234567890", + ) + + +# Jumpstart cert/DNS validation + +def test_js_acm_arn_without_domain_raises(): + from hyperpod_jumpstart_inference_template.v1_2.model import FlatHPJumpStartEndpoint + from pydantic import ValidationError + + with pytest.raises(ValidationError, match="must both be provided together"): + FlatHPJumpStartEndpoint( + model_id="test-model", instance_type="ml.g5.8xlarge", endpoint_name="ep", + custom_certificate_acm_arn=VALID_ACM_ARN, + ) + + +def test_js_dns_without_cert_raises(): + from 
hyperpod_jumpstart_inference_template.v1_2.model import FlatHPJumpStartEndpoint + from pydantic import ValidationError + + with pytest.raises(ValidationError, match="dns_hosted_zone_id requires both"): + FlatHPJumpStartEndpoint( + model_id="test-model", instance_type="ml.g5.8xlarge", endpoint_name="ep", + dns_hosted_zone_id="Z1234567890ABC", + ) + + +def test_js_invalid_acm_arn_pattern(): + from hyperpod_jumpstart_inference_template.v1_2.model import FlatHPJumpStartEndpoint + from pydantic import ValidationError + + with pytest.raises(ValidationError): + FlatHPJumpStartEndpoint( + model_id="test-model", instance_type="ml.g5.8xlarge", endpoint_name="ep", + custom_certificate_acm_arn="bad-arn", + custom_certificate_domain_name="test.example.com", + ) + + +def test_js_invalid_hosted_zone_pattern(): + from hyperpod_jumpstart_inference_template.v1_2.model import FlatHPJumpStartEndpoint + from pydantic import ValidationError + + with pytest.raises(ValidationError): + FlatHPJumpStartEndpoint( + model_id="test-model", instance_type="ml.g5.8xlarge", endpoint_name="ep", + custom_certificate_acm_arn=VALID_ACM_ARN, + custom_certificate_domain_name="test.example.com", + dns_hosted_zone_id="invalid", + ) + + +def test_js_cert_dns_all_valid(): + from hyperpod_jumpstart_inference_template.v1_2.model import FlatHPJumpStartEndpoint + + ep = FlatHPJumpStartEndpoint( + model_id="test-model", instance_type="ml.g5.8xlarge", endpoint_name="ep", + custom_certificate_acm_arn=VALID_ACM_ARN, + custom_certificate_domain_name="test.example.com", + dns_hosted_zone_id="Z1234567890ABC", + ) + assert ep.dns_hosted_zone_id is not None + + +# Bug 4: data_capture skip list in inference_utils + +def test_data_capture_in_inference_utils_skip_list(): + import inspect + from sagemaker.hyperpod.cli.inference_utils import generate_click_command + + source = inspect.getsource(generate_click_command) + assert '"data_capture"' in source diff --git a/test/unit_tests/inference/test_hp_endpoint.py 
b/test/unit_tests/inference/test_hp_endpoint.py index 10a69a72..16c78e67 100644 --- a/test/unit_tests/inference/test_hp_endpoint.py +++ b/test/unit_tests/inference/test_hp_endpoint.py @@ -21,6 +21,32 @@ ModelVolumeMount, Resources, Worker, + Kubernetes, + CustomCertificateConfig, + NodeAffinity, + NodeSelectorTerm, + NodeSelectorRequirement, + NodeSelector, + PreferredSchedulingTerm, + Probe, + Probes, + RequestLimits, + Tags, + HuggingFaceModel, + TokenSecretRef, + DataCapture, + DataCaptureLoadBalancer, + DataCaptureModelPod, + DataCaptureSagemakerEndpoint, + CaptureOptions, + CaptureContentTypeHeader, + BufferConfig, + PayloadConfig, + DnsConfig, + DnsStatus, + InferenceEndpointConfigStatus, + TlsCertificate, + Status, ) from sagemaker.hyperpod.inference.config.constants import * from sagemaker.hyperpod.common.config import Metadata @@ -330,3 +356,99 @@ def test_list_pods_with_endpoint_name(self, mock_verify_config, mock_core_api): mock_core_api.return_value.list_namespaced_pod.assert_called_once_with( namespace="default" ) + + +class TestServiceAccountName(unittest.TestCase): + def test_kubernetes_with_service_account_name(self): + k8s = Kubernetes(service_account_name="my-inference-sa", scheduler_name="default-scheduler") + self.assertEqual(k8s.serviceAccountName, "my-inference-sa") + + def test_kubernetes_service_account_name_camel_case(self): + k8s = Kubernetes(serviceAccountName="my-sa") + self.assertEqual(k8s.serviceAccountName, "my-sa") + + def test_kubernetes_service_account_name_none_by_default(self): + k8s = Kubernetes() + self.assertIsNone(k8s.serviceAccountName) + + +class TestHuggingFaceModelConfig(unittest.TestCase): + def test_huggingface_model_basic(self): + hf = HuggingFaceModel(model_id="meta-llama/Llama-3.1-8B-Instruct") + self.assertEqual(hf.modelId, "meta-llama/Llama-3.1-8B-Instruct") + self.assertIsNone(hf.commitSHA) + self.assertIsNone(hf.tokenSecretRef) + + def test_huggingface_model_with_token(self): + hf = HuggingFaceModel( + 
model_id="meta-llama/Llama-3.1-8B-Instruct", + commit_sha="a" * 40, + token_secret_ref=TokenSecretRef(name="hf-secret", key="token"), + ) + self.assertEqual(hf.commitSHA, "a" * 40) + self.assertEqual(hf.tokenSecretRef.name, "hf-secret") + + def test_model_source_config_huggingface(self): + src = ModelSourceConfig( + model_source_type="huggingface", + hugging_face_model=HuggingFaceModel(model_id="meta-llama/Llama-3.1-8B-Instruct"), + ) + self.assertEqual(src.modelSourceType, "huggingface") + self.assertIsNotNone(src.huggingFaceModel) + + def test_model_source_config_kubernetes_volume(self): + src = ModelSourceConfig(model_source_type="kubernetesVolume", model_location="/mnt/models/my-model") + self.assertEqual(src.modelSourceType, "kubernetesVolume") + + +class TestDataCaptureConfig(unittest.TestCase): + def test_data_capture_sagemaker_endpoint(self): + dc = DataCapture(sagemaker_endpoint=DataCaptureSagemakerEndpoint(enabled=True)) + self.assertTrue(dc.sagemakerEndpoint.enabled) + + def test_data_capture_model_pod(self): + dc = DataCapture( + model_pod=DataCaptureModelPod( + enabled=True, initial_sampling_percentage=50, + buffer_config=BufferConfig(batch_size=20, flush_interval_seconds=120), + capture_options=[CaptureOptions(capture_mode="Input"), CaptureOptions(capture_mode="Output")], + ), + s3_uri="s3://my-bucket/captures", + ) + self.assertTrue(dc.modelPod.enabled) + self.assertEqual(dc.modelPod.bufferConfig.batchSize, 20) + self.assertEqual(len(dc.modelPod.captureOptions), 2) + + def test_data_capture_full(self): + dc = DataCapture( + sagemaker_endpoint=DataCaptureSagemakerEndpoint( + enabled=True, + capture_content_type_header=CaptureContentTypeHeader( + csv_content_types=["text/csv"], json_content_types=["application/json"], + ), + ), + load_balancer=DataCaptureLoadBalancer(enabled=False), + model_pod=DataCaptureModelPod( + enabled=True, payload_config=PayloadConfig(max_payload_size_kb=1024), + kms_key_id="arn:aws:kms:us-east-2:123:key/abc", + ), + 
s3_uri="s3://bucket/prefix", + ) + self.assertEqual(dc.sagemakerEndpoint.captureContentTypeHeader.csvContentTypes, ["text/csv"]) + self.assertFalse(dc.loadBalancer.enabled) + self.assertEqual(dc.modelPod.payloadConfig.maxPayloadSizeKB, 1024) + + +class TestDnsConfigAndStatus(unittest.TestCase): + def test_dns_config(self): + dns = DnsConfig(hosted_zone_id="Z1234567890") + self.assertEqual(dns.hostedZoneId, "Z1234567890") + + def test_dns_status(self): + ds = DnsStatus(dns_health="Active", hosted_zone_id="Z123", managed_by_operator=True, record_name="test.example.com") + self.assertEqual(ds.dnsHealth, "Active") + self.assertTrue(ds.managedByOperator) + + def test_inference_status_with_dns(self): + status = InferenceEndpointConfigStatus(dns_status=DnsStatus(managed_by_operator=True, dns_health="Pending")) + self.assertEqual(status.dnsStatus.dnsHealth, "Pending") diff --git a/test/unit_tests/inference/test_hp_jumpstart_endpoint.py b/test/unit_tests/inference/test_hp_jumpstart_endpoint.py index e6079eb9..07858573 100644 --- a/test/unit_tests/inference/test_hp_jumpstart_endpoint.py +++ b/test/unit_tests/inference/test_hp_jumpstart_endpoint.py @@ -8,6 +8,14 @@ SageMakerEndpoint, TlsConfig, Validations, + _HPJumpStartEndpoint, + DataCapture as JSDataCapture, + DataCaptureSagemakerEndpoint as JSDataCaptureSagemakerEndpoint, + DnsConfig as JSDnsConfig, + DnsStatus as JSDnsStatus, + DataCaptureStatus as JSDataCaptureStatus, + DataCaptureModelPodStatus as JSDataCaptureModelPodStatus, + JumpStartModelStatus, ) from sagemaker.hyperpod.common.config import Metadata @@ -492,4 +500,32 @@ def test_create_from_dict_missing_name_and_endpoint_name(self): self.assertIn( 'Input "name" is required if endpoint name is not provided', str(context.exception), - ) \ No newline at end of file + ) + + +class TestJumpStartDataCaptureDns(unittest.TestCase): + def test_jumpstart_with_data_capture(self): + ep = _HPJumpStartEndpoint( + model={"accept_eula": True, "model_id": "test"}, + 
server={"instance_type": "ml.g5.8xlarge"}, + data_capture=JSDataCapture(sagemaker_endpoint=JSDataCaptureSagemakerEndpoint(enabled=True)), + ) + self.assertIsNotNone(ep.dataCapture) + + def test_jumpstart_with_dns_config(self): + ep = _HPJumpStartEndpoint( + model={"accept_eula": True, "model_id": "test"}, + server={"instance_type": "ml.g5.8xlarge"}, + dns_config=JSDnsConfig(hosted_zone_id="Z999"), + ) + self.assertEqual(ep.dnsConfig.hostedZoneId, "Z999") + + def test_jumpstart_status_with_dns_and_data_capture(self): + status = JumpStartModelStatus( + dns_status=JSDnsStatus(managed_by_operator=True), + data_capture_status=JSDataCaptureStatus( + model_pod=JSDataCaptureModelPodStatus(status="Healthy"), + ), + ) + self.assertTrue(status.dnsStatus.managedByOperator) + self.assertEqual(status.dataCaptureStatus.modelPod.status, "Healthy")