Replace helper functions with native burn functions

This commit is contained in:
Gadersd
2023-09-07 12:23:18 -04:00
parent a62795347f
commit f4c58c1790
20 changed files with 1091 additions and 950 deletions

View File

@@ -1,25 +1,27 @@
pub mod load;
use burn::{
config::Config,
config::Config,
module::{Module, Param},
tensor::{
backend::Backend,
Tensor,
},
tensor::{backend::Backend, Tensor},
};
#[derive(Config)]
pub struct GroupNormConfig {
n_group: usize,
n_channel: usize,
n_group: usize,
n_channel: usize,
#[config(default = 1e-5)]
eps: f64,
eps: f64,
}
impl GroupNormConfig {
pub fn init<B: Backend>(&self) -> GroupNorm<B> {
assert!(self.n_channel % self.n_group == 0, "The number of channels {} must be divisible by the number of groups {}", self.n_channel, self.n_group);
assert!(
self.n_channel % self.n_group == 0,
"The number of channels {} must be divisible by the number of groups {}",
self.n_channel,
self.n_group
);
let n_per_group = self.n_channel / self.n_group;
@@ -29,22 +31,22 @@ impl GroupNormConfig {
let eps = self.eps;
GroupNorm {
n_group: self.n_group,
n_channel: self.n_channel,
gamma,
beta,
eps,
n_group: self.n_group,
n_channel: self.n_channel,
gamma,
beta,
eps,
}
}
}
#[derive(Module, Debug)]
pub struct GroupNorm<B: Backend> {
n_group: usize,
n_channel: usize,
gamma: Param<Tensor<B, 1>>,
beta: Param<Tensor<B, 1>>,
eps: f64,
n_group: usize,
n_channel: usize,
gamma: Param<Tensor<B, 1>>,
beta: Param<Tensor<B, 1>>,
eps: f64,
}
impl<B: Backend> GroupNorm<B> {
@@ -56,10 +58,17 @@ impl<B: Backend> GroupNorm<B> {
let mut affine_shape = [1; D];
affine_shape[1] = self.n_channel;
layernorm( x.reshape([n_batch, self.n_group, num_elements / (n_batch * self.n_group) ]), self.eps )
.reshape(shape)
.mul(self.gamma.val().reshape(affine_shape))
.add(self.beta.val().reshape(affine_shape))
layernorm(
x.reshape([
n_batch,
self.n_group,
num_elements / (n_batch * self.n_group),
]),
self.eps,
)
.reshape(shape)
.mul(self.gamma.val().reshape(affine_shape))
.add(self.beta.val().reshape(affine_shape))
}
}
@@ -68,5 +77,6 @@ pub fn layernorm<B: Backend, const D: usize>(x: Tensor<B, D>, eps: f64) -> Tenso
//x.sub(mean).div(var.sqrt().add_scalar(eps))
let u = x.clone() - x.mean_dim(D - 1);
u.clone().div( (u.clone() * u).mean_dim(D - 1).add_scalar(eps).sqrt() )
}
u.clone()
.div((u.clone() * u).mean_dim(D - 1).add_scalar(eps).sqrt())
}