Add custom backend to enable flash attention
This commit is contained in:
10
Cargo.toml
10
Cargo.toml
@@ -6,15 +6,8 @@ edition = "2021"
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[features]
|
||||
default = ["torch-backend"]
|
||||
torch-backend = ["burn-tch"]
|
||||
wgpu-backend = ["burn-wgpu"]
|
||||
|
||||
[dependencies.burn-tch]
|
||||
package = "burn-tch"
|
||||
git = "https://github.com/burn-rs/burn.git"
|
||||
optional = true
|
||||
|
||||
[dependencies.burn-wgpu]
|
||||
package = "burn-wgpu"
|
||||
git = "https://github.com/burn-rs/burn.git"
|
||||
@@ -23,6 +16,9 @@ optional = true
|
||||
[dependencies]
|
||||
burn = { git = "https://github.com/burn-rs/burn.git" }
|
||||
burn-ndarray = { package = "burn-ndarray", git = "https://github.com/burn-rs/burn.git" }
|
||||
burn-tch = { package = "burn-tch", git = "https://github.com/burn-rs/burn.git" }
|
||||
burn-autodiff = { package = "burn-autodiff", git = "https://github.com/burn-rs/burn.git" }
|
||||
tch = "0.13.0"
|
||||
serde = {version = "1.0.171", features = ["std", "derive"]}
|
||||
npy = "0.4.0"
|
||||
num-traits = "0.2.15"
|
||||
|
||||
Reference in New Issue
Block a user