// The LIBTORCH environment variable can be used to specify the directory // where libtorch has been installed. // When not specified this script downloads the cpu version for libtorch // and extracts it in OUT_DIR. // // On Linux, the TORCH_CUDA_VERSION environment variable can be used, // like 9.0, 90, or cu90 to specify the version of CUDA to use for libtorch. use std::path::{Path, PathBuf}; use std::{env, fs}; const PYTHON_PRINT_PYTORCH_DETAILS: &str = r" import torch from torch.utils import cpp_extension print('LIBTORCH_VERSION:', torch.__version__.split('+')[0]) print('LIBTORCH_CXX11:', torch._C._GLIBCXX_USE_CXX11_ABI) for include_path in cpp_extension.include_paths(): print('LIBTORCH_INCLUDE:', include_path) for library_path in cpp_extension.library_paths(): print('LIBTORCH_LIB:', library_path) "; #[derive(Debug, Clone, Copy, PartialEq, Eq)] enum Os { Linux, Macos, Windows, } #[allow(dead_code)] #[derive(Debug, Clone)] struct SystemInfo { os: Os, cxx11_abi: String, libtorch_include_dirs: Vec, libtorch_lib_dir: PathBuf, } fn env_var_rerun(name: &str) -> Result { println!("cargo:rerun-if-env-changed={name}"); env::var(name) } impl SystemInfo { fn new() -> Option { let os = match env::var("CARGO_CFG_TARGET_OS") .expect("Unable to get TARGET_OS") .as_str() { "linux" => Os::Linux, "windows" => Os::Windows, "macos" => Os::Macos, os => panic!("unsupported TARGET_OS '{os}'"), }; // Locate the currently active Python binary, similar to: // https://github.com/PyO3/maturin/blob/243b8ec91d07113f97a6fe74d9b2dcb88086e0eb/src/target.rs#L547 let python_interpreter = match os { Os::Windows => PathBuf::from("python.exe"), Os::Linux | Os::Macos => { if env::var_os("VIRTUAL_ENV").is_some() { PathBuf::from("python") } else { PathBuf::from("python3") } } }; let mut libtorch_include_dirs = vec![]; let mut libtorch_lib_dir = None; let cxx11_abi = if env_var_rerun("LIBTORCH_USE_PYTORCH").is_ok() { let output = std::process::Command::new(&python_interpreter) .arg("-c") .arg(PYTHON_PRINT_PYTORCH_DETAILS) .output() .expect("error running python interpreter"); let mut cxx11_abi = None; for line in String::from_utf8_lossy(&output.stdout).lines() { match line.strip_prefix("LIBTORCH_CXX11: ") { Some("True") => cxx11_abi = Some("1".to_owned()), Some("False") => cxx11_abi = Some("0".to_owned()), _ => {} } if let Some(path) = line.strip_prefix("LIBTORCH_INCLUDE: ") { libtorch_include_dirs.push(PathBuf::from(path)) } if let Some(path) = line.strip_prefix("LIBTORCH_LIB: ") { libtorch_lib_dir = Some(PathBuf::from(path)) } } match cxx11_abi { Some(cxx11_abi) => cxx11_abi, None => panic!("no cxx11 abi returned by python {output:?}"), } } else { let libtorch = Self::prepare_libtorch_dir(os)?; let includes = env_var_rerun("LIBTORCH_INCLUDE") .map(PathBuf::from) .unwrap_or_else(|_| libtorch.clone()); let lib = env_var_rerun("LIBTORCH_LIB") .map(PathBuf::from) .unwrap_or_else(|_| libtorch.clone()); libtorch_include_dirs.push(includes.join("include")); libtorch_include_dirs.push(includes.join("include/torch/csrc/api/include")); if lib.ends_with("lib") { // DEP_TCH_LIBTORCH_LIB might already point to /lib libtorch_lib_dir = Some(lib); } else { libtorch_lib_dir = Some(lib.join("lib")); } env_var_rerun("LIBTORCH_CXX11_ABI").unwrap_or_else(|_| "1".to_owned()) }; let libtorch_lib_dir = libtorch_lib_dir?; Some(Self { os, cxx11_abi, libtorch_include_dirs, libtorch_lib_dir, }) } fn check_system_location(os: Os) -> Option { match os { Os::Linux => Path::new("/usr/lib/libtorch.so") .exists() .then(|| PathBuf::from("/usr")), _ => None, } } fn prepare_libtorch_dir(os: Os) -> Option { if let Ok(libtorch) = env_var_rerun("DEP_TCH_LIBTORCH_LIB") { Some(PathBuf::from(libtorch)) } else if let Ok(libtorch) = env_var_rerun("LIBTORCH") { Some(PathBuf::from(libtorch)) } else if let Some(pathbuf) = Self::check_system_location(os) { Some(pathbuf) } else { check_out_dir() } } fn make(&self, use_cuda: bool, use_hip: bool) { let cuda_dependency = if use_cuda || use_hip { "src/cuda_hack/dummy_cuda_dependency.cpp" } else { "src/cuda_hack/fake_cuda_dependency.cpp" }; println!("cargo:rerun-if-changed={cuda_dependency}"); match self.os { Os::Linux | Os::Macos => { cc::Build::new() .cpp(true) .pic(true) .warnings(false) .includes(&self.libtorch_include_dirs) .flag(format!("-Wl,-rpath={}", self.libtorch_lib_dir.display())) .flag("-std=c++17") .flag(format!("-D_GLIBCXX_USE_CXX11_ABI={}", self.cxx11_abi)) .files(&[cuda_dependency]) .compile("burn-tch"); } Os::Windows => { cc::Build::new() .cpp(true) .pic(true) .warnings(false) .includes(&self.libtorch_include_dirs) .flag("/std:c++17") .files(&[cuda_dependency]) .compile("burn-tch"); } }; } fn make_cpu() { let cuda_dependency = "src/cuda_hack/fake_cuda_dependency.cpp"; println!("cargo:rerun-if-changed={cuda_dependency}"); let os = env::var("CARGO_CFG_TARGET_OS").expect("Unable to get TARGET_OS"); match os.as_str() { "windows" => { cc::Build::new() .cpp(true) .pic(true) .warnings(false) .flag("/std:c++17") .files(&[cuda_dependency]) .compile("burn-tch"); } _ => { cc::Build::new() .cpp(true) .pic(true) .warnings(false) .flag("-std=c++17") .files(&[cuda_dependency]) .compile("tch"); } }; } } fn check_out_dir() -> Option { let out_dir = env_var_rerun("OUT_DIR").ok()?; let libtorch_dir = PathBuf::from(out_dir).join("libtorch"); libtorch_dir.exists().then_some(libtorch_dir) } fn main() { let system_info = SystemInfo::new(); let out_dir = env_var_rerun("OUT_DIR").expect("Failed to get out dir"); let mut gpu_found = false; let found_dir = system_info.is_some(); if let Some(system_info) = &system_info { let si_lib = &system_info.libtorch_lib_dir; let use_cuda = si_lib.join("libtorch_cuda.so").exists() || si_lib.join("torch_cuda.dll").exists(); let use_hip = si_lib.join("libtorch_hip.so").exists() || si_lib.join("torch_hip.dll").exists(); system_info.make(use_cuda, use_hip); gpu_found = use_cuda || use_hip; } else { SystemInfo::make_cpu(); } let check_file = PathBuf::from(out_dir).join("tch_gpu_check.rs"); if gpu_found { fs::write(check_file, "#[allow(clippy::no_effect)]\n()").unwrap(); } else { let message = if !found_dir { r#"Could not find libtorch dir. If you are trying to use the automatically downloaded version, the path is not directly available on Windows. Instead, try setting the `LIBTORCH` environment variable for the manual download instructions. If the library has already been downloaded in the torch-sys OUT_DIR, you can point the variable to this path (or move the downloaded lib and point to it)."# } else { "No libtorch_cuda or libtorch_hip found. Download the GPU version of libtorch to use a GPU device" }; fs::write(check_file, format!("panic!(\"{message}\")")).unwrap(); } }