Add custom backend to enable flash attention

This commit is contained in:
Gadersd
2023-09-07 12:54:27 -04:00
parent f4c58c1790
commit 01b1aea897
7 changed files with 177 additions and 34 deletions

View File

@@ -6,15 +6,8 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[features]
default = ["torch-backend"]
torch-backend = ["burn-tch"]
wgpu-backend = ["burn-wgpu"]
[dependencies.burn-tch]
package = "burn-tch"
git = "https://github.com/burn-rs/burn.git"
optional = true
[dependencies.burn-wgpu]
package = "burn-wgpu"
git = "https://github.com/burn-rs/burn.git"
@@ -23,6 +16,9 @@ optional = true
[dependencies]
burn = { git = "https://github.com/burn-rs/burn.git" }
burn-ndarray = { package = "burn-ndarray", git = "https://github.com/burn-rs/burn.git" }
burn-tch = { package = "burn-tch", git = "https://github.com/burn-rs/burn.git" }
burn-autodiff = { package = "burn-autodiff", git = "https://github.com/burn-rs/burn.git" }
tch = "0.13.0"
serde = {version = "1.0.171", features = ["std", "derive"]}
npy = "0.4.0"
num-traits = "0.2.15"