[package]
authors = ["nathanielsimard "]
categories = ["science", "no-std", "embedded", "wasm"]
description = "Optimizer building blocks for the Burn deep learning framework"
documentation = "https://docs.rs/burn-optim"
edition.workspace = true
keywords = ["deep-learning", "machine-learning", "tensor", "pytorch", "ndarray"]
license.workspace = true
name = "burn-optim"
readme.workspace = true
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-optim"
version.workspace = true

[lints]
workspace = true

[features]
default = [
    "std",
    "burn-core/default",
]
doc = [
    "std",
    # Doc features
    "burn-core/doc",
]
std = [
    "burn-core/std",
    "num-traits/std",
    "serde/std",
    # Enables the optional `log` dependency declared below.
    "log",
]
# The `?` form only forwards `tracing` to optional dependencies that are
# already enabled by some other feature; it does not pull them in.
tracing = [
    "burn-collective?/tracing",
    "burn-core/tracing",
    "burn-cuda?/tracing",
    "burn-fusion?/tracing",
    "burn-remote?/tracing",
    "burn-rocm?/tracing",
    "burn-router?/tracing",
    "burn-tch?/tracing",
    "burn-wgpu?/tracing",
]
collective = ["burn-collective"]
test-cuda = [
    "burn-cuda/default",
] # To use cuda during testing, default uses ndarray.
test-rocm = [
    "burn-rocm/default",
] # To use hip during testing, default uses ndarray.
test-tch = [
    "burn-tch/default",
] # To use tch during testing, default uses ndarray.
test-wgpu = [
    "burn-wgpu/default",
] # To use wgpu during testing, default uses ndarray.
test-vulkan = [
    "test-wgpu",
    "burn-wgpu/vulkan",
] # To use wgpu with its vulkan feature during testing, default uses ndarray.
test-metal = [
    "test-wgpu",
    "burn-wgpu/metal",
] # To use wgpu with its metal feature during testing, default uses ndarray.
# Memory checks are disabled by default
test-memory-checks = ["burn-fusion/memory-checks"]

[dependencies]
# ** Please make sure all dependencies support no_std when std is disabled **
burn-core = { path = "../burn-core", version = "=0.21.0-pre.2", default-features = false }
burn-collective = { path = "../burn-collective", version = "=0.21.0-pre.2", optional = true, default-features = false }

num-traits = { workspace = true }
derive-new = { workspace = true }
log = { workspace = true, optional = true }
serde = { workspace = true, features = ["derive"] }
# The same implementation of HashMap in std but with no_std support (only alloc crate is needed)
hashbrown = { workspace = true, features = ["serde"] } # no_std compatible

# FOR TESTING
burn-cuda = { path = "../burn-cuda", version = "=0.21.0-pre.2", optional = true, default-features = false }
burn-rocm = { path = "../burn-rocm", version = "=0.21.0-pre.2", optional = true, default-features = false }
burn-remote = { path = "../burn-remote", version = "=0.21.0-pre.2", default-features = false, optional = true }
burn-router = { path = "../burn-router", version = "=0.21.0-pre.2", default-features = false, optional = true }
burn-tch = { path = "../burn-tch", version = "=0.21.0-pre.2", optional = true }
burn-wgpu = { path = "../burn-wgpu", version = "=0.21.0-pre.2", optional = true, default-features = false }
burn-fusion = { path = "../burn-fusion", version = "=0.21.0-pre.2", optional = true }

[dev-dependencies]
burn-nn = { path = "../burn-nn", version = "=0.21.0-pre.2" }
burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0-pre.2" }
burn-autodiff = { path = "../burn-autodiff", version = "=0.21.0-pre.2" }
rstest = { workspace = true }

[package.metadata.docs.rs]
features = ["doc"]
rustdoc-args = ["--cfg", "docsrs"]