[package]
# NOTE(review): the trailing space suggests an "<email>" part was lost from this
# entry; restore it or trim the space.
authors = ["nathanielsimard "]
categories = ["science"]
description = "Training crate for the Burn framework"
documentation = "https://docs.rs/burn-train"
edition.workspace = true
keywords = ["deep-learning", "machine-learning", "tensor", "pytorch", "ndarray"]
license.workspace = true
name = "burn-train"
readme.workspace = true
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-train"
version.workspace = true

[lints]
workspace = true

[features]
default = ["sys-metrics", "tui", "rl"]
# Feature set used for docs.rs builds (see [package.metadata.docs.rs] below).
doc = ["default"]
vision = ["burn-nn", "burn-store/pytorch", "burn-std/network", "dirs"]
# The `?` forwards `tracing` to burn-collective only when that optional
# dependency is already enabled (e.g. via `ddp`) instead of pulling it in.
tracing = [
    "burn-core/tracing",
    "burn-optim/tracing",
    "burn-collective?/tracing",
]
sys-metrics = ["nvml-wrapper", "sysinfo", "systemstat"]
tui = ["ratatui"]
rl = ["burn-rl"]
# Distributed Data Parallel
ddp = ["burn-collective", "burn-optim/collective"]

[dependencies]
burn-core = { path = "../burn-core", version = "=0.21.0-pre.2", features = [
    "dataset",
    "std",
], default-features = false }
burn-optim = { path = "../burn-optim", version = "=0.21.0-pre.2", features = [
    "std",
], default-features = false }
burn-rl = { path = "../burn-rl", version = "=0.21.0-pre.2", optional = true, default-features = false }
burn-collective = { path = "../burn-collective", version = "=0.21.0-pre.2", optional = true }
burn-nn = { path = "../burn-nn", version = "=0.21.0-pre.2", optional = true, default-features = false, features = ["std"] }
burn-store = { path = "../burn-store", version = "=0.21.0-pre.2", optional = true, default-features = false, features = ["std"] }
burn-std = { path = "../burn-std", version = "=0.21.0-pre.2", optional = true, default-features = false, features = ["std"] }
dirs = { workspace = true, optional = true }
log = { workspace = true }
tracing-subscriber = { workspace = true }
tracing-appender = { workspace = true }
tracing-core = { workspace = true }

# System Metrics
nvml-wrapper = { workspace = true, optional = true }
sysinfo = { workspace = true, optional = true }
systemstat = { workspace = true, optional = true }

# Text UI
ratatui = { workspace = true, optional = true, features = [
    "all-widgets",
    "crossterm",
] }

# Utilities
derive-new = { workspace = true }
serde = { workspace = true, features = ["std", "derive"] }
async-channel = { workspace = true }
# NOTE(review): burn-ndarray and rstest are normal (non-dev) dependencies, so every
# downstream user compiles them. If they are only used from tests, move them to
# [dev-dependencies] — confirm against the crate's non-test code first.
burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0-pre.2" }
rstest.workspace = true
thiserror.workspace = true
rand.workspace = true

[dev-dependencies]
# NOTE(review): redundant while burn-ndarray is also listed under [dependencies]
# (dev builds already see normal dependencies); kept in case that entry moves here.
burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0-pre.2" }
burn-autodiff = { path = "../burn-autodiff", version = "=0.21.0-pre.2" }

[package.metadata.docs.rs]
features = ["doc"]
rustdoc-args = ["--cfg", "docsrs"]