Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions projects/export/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ dependencies = [
"aframe",
"utils",
"jsonargparse>=4.27.1,<5",
"nvidia-cudnn-cu11",
"tensorrt",
"nvidia-cudnn-cu12",
"tensorrt>=10.0",
"urllib3>=2",
]

Expand Down
43 changes: 6 additions & 37 deletions projects/export/uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion projects/online/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ dev = [

[[tool.uv.index]]
name = "torch"
url = "https://download.pytorch.org/whl/cu118"
url = "https://download.pytorch.org/whl/cu124"
explicit = true

[tool.uv.sources]
Expand Down
4 changes: 3 additions & 1 deletion projects/train/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,12 @@ dependencies = [
"filelock>=3.13.1,<4",
"s3fs>=2024,<2025",
"lightray>=0.2.3",
"kr8s>=0.20.0",
]

[project.scripts]
train = "train.cli:main"
train-remote = "train.remote:main"

[dependency-groups]
dev = [
Expand All @@ -46,7 +48,7 @@ dev = [

[[tool.uv.index]]
name = "torch"
url = "https://download.pytorch.org/whl/cu118"
url = "https://download.pytorch.org/whl/cu124"
explicit = true

[tool.uv.sources]
Expand Down
142 changes: 142 additions & 0 deletions projects/train/train/helm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
"""
Tools for deploying helm charts
"""

import logging
import subprocess
import time

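# httpx-ws 0.9.0 references anyio.AsyncContextManagerMixin,
# which is not present in every anyio release (it appears to have
# been introduced in a later 4.x release, not removed -- TODO confirm).
# Fall back to a placeholder base class so the kr8s import below succeeds.
# Remove once the httpx-ws / anyio versions in the lockfile are compatible.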
import anyio

if not hasattr(anyio, "AsyncContextManagerMixin"):
anyio.AsyncContextManagerMixin = object

import kr8s

# Base URL for lightray GitHub release downloads; used to build the
# default chart archive URL. Note that the value ends with a slash.
CHART_REPO = "https://github.com/EthanMarx/lightray/releases/download/"


def authenticate():
    """
    Run `kubectl cluster-info` and log a warning if it fails.

    Standard output of the command is discarded; standard error is
    captured so it can be included in the warning. A non-zero exit
    status is only logged, never raised.
    """
    proc = subprocess.run(
        ["kubectl", "cluster-info"],
        stdout=subprocess.DEVNULL,
        stderr=subprocess.PIPE,
    )

    # nothing to report when kubectl succeeded
    if proc.returncode == 0:
        return

    stderr_text = proc.stderr.decode().strip()
    logging.warning(
        "kubectl cluster-info failed (returncode %d): %s",
        proc.returncode,
        stderr_text,
    )


# Replacement for the kr8s api auth refresh method (installed by
# monkey patching); see https://github.com/kr8s-org/kr8s/issues/214
async def auth(self):
    """
    Reauthenticate by invoking kubectl and then reloading the kubeconfig.

    Intended to be bound to a kr8s auth object in place of its
    `reauthenticate` method: `authenticate()` shells out to kubectl
    (which is expected to refresh the OIDC token), after which the
    auth object's own `_load_kubeconfig` coroutine is awaited.
    """
    authenticate()
    await self._load_kubeconfig()


def setup_kr8s_auth():
    """
    Build a kr8s API object whose auth refresh goes through kubectl.

    The `reauthenticate` attribute of the API's auth object is replaced
    with the module-level `auth` coroutine, bound to that auth object,
    so OIDC token refreshes are delegated to kubectl.

    Invoke this before performing kr8s operations that need
    access to the cluster.

    Returns:
        The patched kr8s API instance.
    """
    client = kr8s.api()
    auth_handler = client.auth
    # bind `auth` as a method of this particular auth object
    bound_refresh = auth.__get__(auth_handler, type(auth_handler))
    auth_handler.reauthenticate = bound_refresh
    return client


class HelmChart:
    """
    Install, wait on, and uninstall a single helm release.

    Instances can be used as context managers: entering installs the
    chart and blocks in `wait()`, exiting uninstalls the release.
    Subclasses must implement `wait()`.

    Args:
        chart_url: Chart location passed to `helm install`.
        release: Name of the helm release.
    """

    def __init__(self, chart_url: str, release: str):
        self.chart_url = chart_url
        self.release = release
        # `helm install` argv; extended in place by `build_command`
        self.base_cmd = ["helm", "install", self.release, self.chart_url]
        self.api = setup_kr8s_auth()

    def install(self):
        """
        Run `helm install` with the current command line.

        Raises:
            subprocess.CalledProcessError: if helm exits non-zero
                (the error is logged before being re-raised).
        """
        logging.info(f"Installing chart from {self.chart_url}")
        try:
            subprocess.run(self.base_cmd, check=True)
        except subprocess.CalledProcessError as e:
            logging.error(f"Error installing helm chart: {e}")
            raise

    def build_command(self, values: dict[str, str]) -> list[str]:
        """
        Append a `--set key=value` pair to the install command
        for every entry in `values`.

        The command stored on the instance is extended in place, so
        repeated calls accumulate flags.

        Returns:
            The updated install command (the same list object that
            `install` will execute).
        """
        for key, value in values.items():
            self.base_cmd += ["--set", f"{key}={value}"]
        return self.base_cmd

    def uninstall(self):
        """
        Run `helm uninstall` for this release.

        Raises:
            subprocess.CalledProcessError: if helm exits non-zero.
        """
        subprocess.run(["helm", "uninstall", self.release], check=True)

    def wait(self):
        """Block until the release is ready; defined by subclasses."""
        raise NotImplementedError

    def __enter__(self):
        self.install()
        try:
            self.wait()
        except BaseException:
            # __exit__ is never called when __enter__ raises, so tear the
            # release down here rather than leaving it deployed.
            try:
                self.uninstall()
            except Exception:
                # keep the original failure as the one that propagates
                logging.exception(
                    "Failed to uninstall release after wait() failed"
                )
            raise
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # There is no parent context manager to delegate to (the base
        # class is `object`), so uninstalling is all that is required.
        # Returning None lets any exception from the `with` body propagate.
        self.uninstall()


class RayCluster(HelmChart):
    """
    Helm release of the `ray-cluster` chart.

    Args:
        release: Name of the helm release; also used as the name
            prefix when looking up the release's services and pods.
        chart_path: Chart location. When None, the chart archive for
            `chart_version` under `CHART_REPO` is used.
        chart_version: Chart version used to build the default
            chart location.
    """

    def __init__(
        self,
        release: str,
        chart_path: str | None = None,
        chart_version: str = "0.1.3",
    ):
        # if no chart path is provided, use the chart
        # in the github repo
        if chart_path is None:
            # strip any trailing slash from the base so the joined URL
            # does not contain an empty path segment ("download//...")
            base = CHART_REPO.rstrip("/")
            archive = f"ray-cluster-{chart_version}"
            chart_path = f"{base}/{archive}/{archive}.tgz"
        super().__init__(chart_path, release)

    def get_ip(self):
        """
        Return the first ingress IP of the release's LoadBalancer service.

        Raises:
            IndexError: if no service whose name starts with the release
                name has type "LoadBalancer".
        """
        services = [
            s for s in kr8s.get("service") if s.name.startswith(self.release)
        ]
        service = [s for s in services if s.spec["type"] == "LoadBalancer"][0]
        service.refresh()
        return service.status.loadBalancer.ingress[0].ip

    def get_pods(self):
        """
        Return `(head, workers)` pods belonging to this release.

        Only pods whose name starts with the release name and whose
        phase is "Pending" or "Running" are considered, which excludes
        pods still terminating from a previous run.

        Raises:
            IndexError: if no matching pod has "head" in its name.
        """
        pods = [
            p
            for p in kr8s.get("pod")
            if p.name.startswith(self.release)
            and p.status.phase in ["Pending", "Running"]
        ]

        head = [p for p in pods if "head" in p.name][0]
        workers = [p for p in pods if "worker" in p.name]
        return head, workers

    def wait(self):
        """
        Poll every two seconds until at least one worker pod and the
        head pod report ready.

        Subclasses can override this to define "readiness" differently.
        """
        # NOTE(review): the pod list is fetched once, so workers created
        # after this call are never observed -- confirm that is acceptable.
        head, workers = self.get_pods()

        # NOTE(review): `any` means a single ready worker suffices;
        # confirm that `all` is not what is intended.
        while True:
            if any(p.ready() for p in workers) and head.ready():
                # return immediately instead of sleeping once more
                return
            time.sleep(2)
Loading
Loading