diff --git a/cmd/pod/create.go b/cmd/pod/create.go index b0ce127..6300df7 100644 --- a/cmd/pod/create.go +++ b/cmd/pod/create.go @@ -41,6 +41,7 @@ var ( createImageName string createTemplateID string createComputeType string + createMinCudaVersion string createGpuTypeID string createGpuCount int createVolumeInGb int @@ -61,6 +62,7 @@ func init() { createCmd.Flags().StringVar(&createTemplateID, "template-id", "", "template id (use 'runpodctl template search' to find templates)") createCmd.Flags().StringVar(&createImageName, "image", "", "docker image name (required if no template)") createCmd.Flags().StringVar(&createComputeType, "compute-type", "GPU", "compute type (GPU or CPU)") + createCmd.Flags().StringVar(&createMinCudaVersion, "min-cuda-version", "", "minimum cuda version required for gpu pod placement") createCmd.Flags().StringVar(&createGpuTypeID, "gpu-id", "", "gpu id (from 'runpodctl gpu list')") createCmd.Flags().IntVar(&createGpuCount, "gpu-count", 1, "number of gpus") createCmd.Flags().IntVar(&createVolumeInGb, "volume-in-gb", 0, "volume size in gb") @@ -100,6 +102,9 @@ func runCreate(cmd *cobra.Command, args []string) error { if computeType == "CPU" && gpuTypeID != "" { return fmt.Errorf("--gpu-id is not supported for compute type CPU") } + if computeType == "CPU" && strings.TrimSpace(createMinCudaVersion) != "" { + return fmt.Errorf("--min-cuda-version is only supported for compute type GPU") + } cloudType := strings.ToUpper(strings.TrimSpace(createCloudType)) if cloudType == "" { @@ -157,12 +162,22 @@ func createPodGraphQL(gpuTypeID, cloudType string, supportPublicIP bool) (map[st return nil, err } + req, err := buildCreatePodGQLRequest(gpuTypeID, cloudType, supportPublicIP) + if err != nil { + return nil, err + } + + return gqlClient.CreatePod(req) +} + +func buildCreatePodGQLRequest(gpuTypeID, cloudType string, supportPublicIP bool) (*api.CreatePodGQLInput, error) { req := &api.CreatePodGQLInput{ CloudType: cloudType, ContainerDiskInGb: createContainerDiskInGb, GpuCount: createGpuCount, GpuTypeId: gpuTypeID, ImageName: createImageName, + MinCudaVersion: strings.TrimSpace(createMinCudaVersion), Name: createName, StartSsh: createSSH, SupportPublicIp: supportPublicIP, @@ -198,7 +213,7 @@ func createPodGraphQL(gpuTypeID, cloudType string, supportPublicIP bool) (map[st } } - return gqlClient.CreatePod(req) + return req, nil } func createPodREST(computeType, gpuTypeID, cloudType string, supportPublicIP bool) (*api.Pod, error) { diff --git a/cmd/pod/create_test.go b/cmd/pod/create_test.go new file mode 100644 index 0000000..b0a32e8 --- /dev/null +++ b/cmd/pod/create_test.go @@ -0,0 +1,103 @@ +package pod + +import "testing" + +func TestCreateCmd_HasMinCudaVersionFlag(t *testing.T) { + flag := createCmd.Flags().Lookup("min-cuda-version") + if flag == nil { + t.Fatal("expected min-cuda-version flag") + } +} + +func TestBuildCreatePodGQLRequest_IncludesMinCudaVersion(t *testing.T) { + origName := createName + origImage := createImageName + origTemplateID := createTemplateID + origMinCudaVersion := createMinCudaVersion + origGpuCount := createGpuCount + origVolumeInGb := createVolumeInGb + origContainerDiskInGb := createContainerDiskInGb + origVolumeMountPath := createVolumeMountPath + origSSH := createSSH + origPorts := createPorts + origEnv := createEnv + origDataCenterIDs := createDataCenterIDs + origNetworkVolumeID := createNetworkVolumeID + + t.Cleanup(func() { + createName = origName + createImageName = origImage + createTemplateID = origTemplateID + createMinCudaVersion = origMinCudaVersion + createGpuCount = origGpuCount + createVolumeInGb = origVolumeInGb + createContainerDiskInGb = origContainerDiskInGb + createVolumeMountPath = origVolumeMountPath + createSSH = origSSH + createPorts = origPorts + createEnv = origEnv + createDataCenterIDs = origDataCenterIDs + createNetworkVolumeID = origNetworkVolumeID + }) + + createName = "cuda-pod" + createImageName = "runpod/test" + createTemplateID = "" + createMinCudaVersion = "12.6" + createGpuCount = 1 + createVolumeInGb = 50 + createContainerDiskInGb = 25 + createVolumeMountPath = "/workspace" + createSSH = true + createPorts = "22/tcp" + createEnv = `{"A":"1"}` + createDataCenterIDs = "DC-1,DC-2" + createNetworkVolumeID = "nv-123" + + req, err := buildCreatePodGQLRequest("NVIDIA GeForce RTX 4090", "SECURE", false) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if req.MinCudaVersion != "12.6" { + t.Fatalf("expected min cuda version, got %q", req.MinCudaVersion) + } + if req.DataCenterId != "DC-1" { + t.Fatalf("expected first data center id, got %q", req.DataCenterId) + } + if req.NetworkVolumeId != "nv-123" { + t.Fatalf("expected network volume id, got %q", req.NetworkVolumeId) + } + if len(req.Env) != 1 || req.Env[0].Key != "A" || req.Env[0].Value != "1" { + t.Fatalf("unexpected env payload: %#v", req.Env) + } +} + +func TestRunCreate_RejectsMinCudaVersionForCPU(t *testing.T) { + origTemplateID := createTemplateID + origImage := createImageName + origComputeType := createComputeType + origGpuTypeID := createGpuTypeID + origMinCudaVersion := createMinCudaVersion + + t.Cleanup(func() { + createTemplateID = origTemplateID + createImageName = origImage + createComputeType = origComputeType + createGpuTypeID = origGpuTypeID + createMinCudaVersion = origMinCudaVersion + }) + + createTemplateID = "" + createImageName = "ubuntu:22.04" + createComputeType = "CPU" + createGpuTypeID = "" + createMinCudaVersion = "12.6" + + err := runCreate(createCmd, nil) + if err == nil { + t.Fatal("expected error") + } + if err.Error() != "--min-cuda-version is only supported for compute type GPU" { + t.Fatalf("unexpected error: %v", err) + } +} diff --git a/docs/runpodctl_pod_create.md b/docs/runpodctl_pod_create.md index 21a248c..6d0435f 100644 --- a/docs/runpodctl_pod_create.md +++ b/docs/runpodctl_pod_create.md @@ -39,6 +39,7 @@ runpodctl pod create [flags] --gpu-id string gpu id (from 'runpodctl gpu list') -h, --help help for create --image string docker image name (required if no template) + --min-cuda-version string minimum cuda version required for gpu pod placement --name string pod name --network-volume-id string network volume id to attach --ports string comma-separated list of ports (e.g., '8888/http,22/tcp') @@ -59,4 +60,4 @@ runpodctl pod create [flags] * [runpodctl pod](runpodctl_pod.md) - manage gpu pods -###### Auto generated by spf13/cobra on 23-Mar-2026 +###### Auto generated by spf13/cobra on 9-Apr-2026 diff --git a/internal/api/graphql.go b/internal/api/graphql.go index 93b254b..9f22db6 100644 --- a/internal/api/graphql.go +++ b/internal/api/graphql.go @@ -227,6 +227,7 @@ type CreatePodGQLInput struct { GpuCount int `json:"gpuCount"` GpuTypeId string `json:"gpuTypeId,omitempty"` ImageName string `json:"imageName,omitempty"` + MinCudaVersion string `json:"minCudaVersion,omitempty"` Name string `json:"name,omitempty"` Ports string `json:"ports,omitempty"` StartSsh bool `json:"startSsh"` diff --git a/internal/api/graphql_pod_test.go b/internal/api/graphql_pod_test.go new file mode 100644 index 0000000..0ab796f --- /dev/null +++ b/internal/api/graphql_pod_test.go @@ -0,0 +1,54 @@ +package api + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" +) + +func TestCreatePod_IncludesMinCudaVersion(t *testing.T) { + var gotMinCudaVersion string + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + var input GraphQLInput + if err := json.NewDecoder(r.Body).Decode(&input); err != nil { + t.Fatalf("decode request: %v", err) + } + + payload, _ := input.Variables["input"].(map[string]interface{}) + if value, ok := payload["minCudaVersion"].(string); ok { + gotMinCudaVersion = value + } + + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "data": map[string]interface{}{ + "podFindAndDeployOnDemand": map[string]interface{}{ + "id": "pod-1", + }, + }, + }) + })) + defer server.Close() + + client := &GraphQLClient{ + url: server.URL, + apiKey: "test-key", + httpClient: server.Client(), + userAgent: "test", + } + + _, err := client.CreatePod(&CreatePodGQLInput{ + GpuCount: 1, + GpuTypeId: "NVIDIA GeForce RTX 4090", + ImageName: "runpod/test", + MinCudaVersion: "12.6", + StartSsh: true, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if gotMinCudaVersion != "12.6" { + t.Fatalf("expected minCudaVersion to be sent, got %q", gotMinCudaVersion) + } +}