-
-
Notifications
You must be signed in to change notification settings - Fork 816
Description
Feature request
Provide a Windows package of bitsandbytes built with ROCm support.
Motivation
Now that ROCm is officially supported on Windows, a ROCm-enabled Windows build would help many developers train models more efficiently with bitsandbytes, especially hobbyists who may not want to commit to Linux.
Your contribution
I managed to get a working build (with a lot of help from Claude Opus 4.5) using TheRock nightly builds on my gfx1151 GPU. Here are my reproduction steps:
git clone https://github.com/bitsandbytes-foundation/bitsandbytes bnb
cd bnb
py -3.12 -m venv venv
venv\scripts\activate
pip install --index-url https://rocm.nightlies.amd.com/v2/gfx1151/ torch torchaudio torchvision
pip install --index-url https://rocm.nightlies.amd.com/v2/gfx1151/ "rocm[libraries,devel]"
rocm-sdk init
pip install triton-windows transformers
Create a file named rocmvariables.bat with the following contents:
:: rocmvariables.bat - export the ROCm SDK locations reported by `rocm-sdk`
:: so that CMake and the HIP toolchain can find them.
:: Run it from a shell where the venv containing `rocm-sdk` is activated.

:: Clear any stale values so a failed lookup below is detected instead of
:: silently reusing paths from an earlier session.
set "ROCM_ROOT="
set "ROCM_BIN="
for /f "delims=" %%i in ('rocm-sdk path --root') do set "ROCM_ROOT=%%i"
for /f "delims=" %%i in ('rocm-sdk path --bin') do set "ROCM_BIN=%%i"

:: Fail loudly rather than exporting empty paths (which would clear HIP_PATH
:: and prepend a bogus \lib\llvm\bin entry to PATH).
if not defined ROCM_ROOT (
    1>&2 echo rocmvariables.bat: "rocm-sdk path --root" returned nothing. Activate the venv first.
    exit /b 1
)
if not defined ROCM_BIN (
    1>&2 echo rocmvariables.bat: "rocm-sdk path --bin" returned nothing. Activate the venv first.
    exit /b 1
)

:: Set environment variables
set "ROCM_HOME=%ROCM_ROOT%"
set "ROCM_PATH=%ROCM_ROOT%"
set "HIP_PATH=%ROCM_ROOT%"
:: Put the SDK's bundled LLVM tools and ROCm binaries ahead of anything else on PATH.
set "PATH=%ROCM_ROOT%\lib\llvm\bin;%ROCM_BIN%;%PATH%"
Then run it in the same shell, with the venv activated:
.\rocmvariables.bat
I had Claude fix the build for me; here is the resulting git diff:
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 922b04b..10b5b05 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -263,6 +263,8 @@ endif()
if(WIN32)
# Export all symbols
set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON)
+ # Prevent Windows SDK min/max macros from conflicting with std::min/std::max
+ add_compile_definitions(NOMINMAX)
endif()
if(MSVC)
@@ -330,14 +332,22 @@ if(BUILD_HIP)
find_package_and_print_version(hipsparse REQUIRED)
## hacky way of excluding hip::amdhip64 (with it linked many tests unexpectedly fail e.g. adam8bit because of inaccuracies)
- set_target_properties(hip::host PROPERTIES INTERFACE_LINK_LIBRARIES "")
- set_target_properties(hip-lang::host PROPERTIES INTERFACE_LINK_LIBRARIES "")
- set(CMAKE_HIP_IMPLICIT_LINK_LIBRARIES "")
+ ## On Windows, we need to link amdhip64 explicitly
+ if(NOT WIN32)
+ set_target_properties(hip::host PROPERTIES INTERFACE_LINK_LIBRARIES "")
+ set_target_properties(hip-lang::host PROPERTIES INTERFACE_LINK_LIBRARIES "")
+ set(CMAKE_HIP_IMPLICIT_LINK_LIBRARIES "")
+ endif()
target_include_directories(bitsandbytes PRIVATE ${CMAKE_SOURCE_DIR} ${CMAKE_SOURCE_DIR}/include ${ROCM_PATH}/include /include)
target_link_directories(bitsandbytes PRIVATE ${ROCM_PATH}/lib /lib)
target_link_libraries(bitsandbytes PUBLIC roc::hipblas hip::hiprand roc::hipsparse)
+ # On Windows, link the HIP runtime and rocblas directly
+ if(WIN32)
+ target_link_libraries(bitsandbytes PUBLIC amdhip64 rocblas)
+ endif()
+
target_compile_definitions(bitsandbytes PUBLIC BNB_USE_HIP)
set_source_files_properties(${HIP_FILES} PROPERTIES LANGUAGE HIP)
set_target_properties(bitsandbytes PROPERTIES LINKER_LANGUAGE CXX)
diff --git a/bitsandbytes/cuda_specs.py b/bitsandbytes/cuda_specs.py
index 71e7568..02261a7 100644
--- a/bitsandbytes/cuda_specs.py
+++ b/bitsandbytes/cuda_specs.py
@@ -1,6 +1,7 @@
import dataclasses
from functools import lru_cache
import logging
+import os
import re
import subprocess
from typing import Optional
@@ -83,10 +84,21 @@ def get_rocm_gpu_arch() -> str:
logger = logging.getLogger(__name__)
try:
if torch.version.hip:
- result = subprocess.run(["rocminfo"], capture_output=True, text=True)
- match = re.search(r"Name:\s+gfx([a-zA-Z\d]+)", result.stdout)
+ # On Windows, use hipinfo.exe; on Linux, use rocminfo
+ if os.name == "nt":
+ cmd = ["hipinfo.exe"]
+ arch_pattern = r"gcnArchName:\s+(gfx[a-zA-Z\d]+)"
+ else:
+ cmd = ["rocminfo"]
+ arch_pattern = r"Name:\s+gfx([a-zA-Z\d]+)"
+
+ result = subprocess.run(cmd, capture_output=True, text=True)
+ match = re.search(arch_pattern, result.stdout)
if match:
- return "gfx" + match.group(1)
+ if os.name == "nt":
+ return match.group(1)
+ else:
+ return "gfx" + match.group(1)
else:
return "unknown"
else:
@@ -107,8 +119,17 @@ def get_rocm_warpsize() -> int:
logger = logging.getLogger(__name__)
try:
if torch.version.hip:
- result = subprocess.run(["rocminfo"], capture_output=True, text=True)
- match = re.search(r"Wavefront Size:\s+([0-9]{2})\(0x[0-9]{2}\)", result.stdout)
+ # On Windows, use hipinfo.exe; on Linux, use rocminfo
+ if os.name == "nt":
+ cmd = ["hipinfo.exe"]
+ # hipinfo.exe output format: "warpSize: 32" or "warpSize: 64"
+ warp_pattern = r"warpSize:\s+(\d+)"
+ else:
+ cmd = ["rocminfo"]
+ warp_pattern = r"Wavefront Size:\s+([0-9]{2})\(0x[0-9]{2}\)"
+
+ result = subprocess.run(cmd, capture_output=True, text=True)
+ match = re.search(warp_pattern, result.stdout)
if match:
return int(match.group(1))
else:
diff --git a/csrc/ops.cuh b/csrc/ops.cuh
index 709432d..2debab1 100644
--- a/csrc/ops.cuh
+++ b/csrc/ops.cuh
@@ -10,6 +10,13 @@
#include <cstdint>
#include <iostream>
#include <stdio.h>
+#ifdef _WIN32
+#include <windows.h>
+#include <io.h>
+#include <process.h>
+#else
+#include <unistd.h>
+#endif
#include <common.h>
#include <cublasLt.h>
diff --git a/csrc/ops_hip.cuh b/csrc/ops_hip.cuh
index 4eb4462..5960148 100644
--- a/csrc/ops_hip.cuh
+++ b/csrc/ops_hip.cuh
@@ -11,7 +11,17 @@
#include <cstdint>
#include <iostream>
#include <stdio.h>
+
+#ifdef _WIN32
+#ifndef NOMINMAX
+#define NOMINMAX
+#endif
+#include <windows.h>
+#include <io.h>
+#include <process.h>
+#else
#include <unistd.h>
+#endif
#include <common.h>
#include <functional>
Configure and build (adjust the compiler and venv paths below to match your machine):
cmake -B build -G Ninja -DCOMPUTE_BACKEND=hip -DCMAKE_BUILD_TYPE=Release -DCMAKE_CXX_COMPILER="C:/Program Files/Microsoft Visual Studio/2022/Community/VC/Tools/Llvm/x64/bin/clang++.exe" -DCMAKE_HIP_COMPILER="C:/projects/bnb/venv/Lib/site-packages/_rocm_sdk_devel/lib/llvm/bin/clang++.exe" -DCMAKE_PREFIX_PATH="C:/projects/bnb/venv/Lib/site-packages/_rocm_sdk_devel" -DBNB_ROCM_ARCH="gfx1151"
cmake --build build --config Release
Install:
pip install .
Or build a wheel:
pip wheel . -w dist
I then tested the build by loading Qwen3-30B-A3B-Instruct-2507 in 4-bit and running inference.