@@ -117,9 +117,12 @@ def __init__(self, config: KubernetesConfig):
     def get_offers_by_requirements(
         self, requirements: Requirements
     ) -> list[InstanceOfferWithAvailability]:
+        gpu_request = 0
+        if (gpu_spec := requirements.resources.gpu) is not None:
+            gpu_request = _get_gpu_request_from_gpu_spec(gpu_spec)
         instance_offers: list[InstanceOfferWithAvailability] = []
         for node in self.api.list_node().items:
-            if (instance_offer := _get_instance_offer_from_node(node)) is not None:
+            if (instance_offer := _get_instance_offer_from_node(node, gpu_request)) is not None:
                 instance_offers.extend(
                     filter_offers_by_requirements([instance_offer], requirements)
                 )
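
Note on this hunk: gpu_request is computed once from the requirements and defaults to 0 for CPU-only runs, so every node still produces an offer, just with an empty GPU list. A standalone sketch of that default (the SimpleNamespace stand-ins are hypothetical, not dstack types):

    from types import SimpleNamespace

    # Hypothetical stand-in for requirements.resources.gpu with count.min unset.
    gpu_spec = SimpleNamespace(count=SimpleNamespace(min=None))
    gpu_request = gpu_spec.count.min or 0  # the same expression the new helper uses
    assert gpu_request == 0  # CPU-only run: node offers advertise no GPUs
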
@@ -188,15 +191,15 @@ def run_job(
         if (cpu_max := resources_spec.cpu.count.max) is not None:
             resources_limits["cpu"] = str(cpu_max)
         if (gpu_spec := resources_spec.gpu) is not None:
-            gpu_min = gpu_spec.count.min
-            if gpu_min is not None and gpu_min > 0:
+            if (gpu_request := _get_gpu_request_from_gpu_spec(gpu_spec)) > 0:
                 gpu_resource, node_affinity, node_taint = _get_pod_spec_parameters_for_gpu(
                     self.api, gpu_spec
                 )
-                logger.debug("Requesting GPU resource: %s=%d", gpu_resource, gpu_min)
+                logger.debug("Requesting GPU resource: %s=%d", gpu_resource, gpu_request)
                 # Limit must be set (GPU resources cannot be overcommitted)
                 # and must be equal to request.
-                resources_requests[gpu_resource] = resources_limits[gpu_resource] = str(gpu_min)
+                resources_requests[gpu_resource] = str(gpu_request)
+                resources_limits[gpu_resource] = str(gpu_request)
                 # It should be NoSchedule, but we also add NoExecute toleration just in case.
                 for effect in [TaintEffect.NO_SCHEDULE, TaintEffect.NO_EXECUTE]:
                     tolerations.append(
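
Concretely, for a spec requesting two NVIDIA GPUs the container spec ends up with matching entries on both sides. A rough illustration (nvidia.com/gpu is the standard NVIDIA device-plugin resource name; the counts are made up):

    resources_requests = {"cpu": "2"}  # illustrative pre-existing entries
    resources_limits = {"cpu": "4"}
    gpu_resource, gpu_request = "nvidia.com/gpu", 2
    resources_requests[gpu_resource] = str(gpu_request)
    resources_limits[gpu_resource] = str(gpu_request)
    # Kubernetes rejects pods whose extended-resource request != limit,
    # hence request and limit are set to the same value.
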
@@ -335,7 +338,10 @@ def update_provisioning_data(
         provisioning_data.hostname = get_or_error(service_spec.cluster_ip)
         pod_spec = get_or_error(pod.spec)
         node = self.api.read_node(name=get_or_error(pod_spec.node_name))
-        if (instance_offer := _get_instance_offer_from_node(node)) is not None:
+        # The original offer has a list of GPUs already sliced according to the pod spec's GPU
+        # resource request, which is inferred from dstack's GPUSpec; see _get_gpu_request_from_gpu_spec.
+        gpu_request = len(provisioning_data.instance_type.resources.gpus)
+        if (instance_offer := _get_instance_offer_from_node(node, gpu_request)) is not None:
             provisioning_data.instance_type = instance_offer.instance
             provisioning_data.region = instance_offer.region
             provisioning_data.price = instance_offer.price
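
This inference relies on the invariant the other hunks establish: the offer stored at provisioning time already had its GPU list sliced to the request, so the list's length recovers the original request. Roughly (values hypothetical):

    # At submission: offer built with gpus=gpus[:2], so len(...) == 2.
    stored_offer_gpus = ["H100", "H100"]  # illustrative
    gpu_request = len(stored_offer_gpus)  # == 2, the originally requested count
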
@@ -475,7 +481,13 @@ def terminate_gateway(
         )
 
 
-def _get_instance_offer_from_node(node: client.V1Node) -> Optional[InstanceOfferWithAvailability]:
+def _get_gpu_request_from_gpu_spec(gpu_spec: GPUSpec) -> int:
+    return gpu_spec.count.min or 0
+
+
+def _get_instance_offer_from_node(
+    node: client.V1Node, gpu_request: int
+) -> Optional[InstanceOfferWithAvailability]:
     try:
         node_name = get_or_error(get_or_error(node.metadata).name)
         node_status = get_or_error(node.status)
@@ -499,7 +511,7 @@ def _get_instance_offer_from_node(node: client.V1Node) -> Optional[InstanceOffer
                 cpus=cpus,
                 cpu_arch=cpu_arch,
                 memory_mib=memory_mib,
-                gpus=gpus,
+                gpus=gpus[:gpu_request],
                 spot=False,
                 disk=Disk(size_mib=disk_size_mib),
             ),
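
One last note on the slice semantics (standalone sketch, hypothetical GPU names): requesting zero yields an empty list, and over-requesting never raises; it simply returns everything the node exposes, presumably leaving filter_offers_by_requirements to reject the undersized offer:

    node_gpus = ["H100", "H100"]
    assert node_gpus[:0] == []                # CPU-only run
    assert node_gpus[:1] == ["H100"]          # exactly the requested count
    assert node_gpus[:8] == ["H100", "H100"]  # slice caps at what the node has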