pub mod load; use burn::{ config::Config, module::{Module, Param}, tensor::{ backend::Backend, Tensor, }, }; #[derive(Config)] pub struct GroupNormConfig { n_group: usize, n_channel: usize, #[config(default = 1e-5)] eps: f64, } impl GroupNormConfig { pub fn init(&self) -> GroupNorm { assert!(self.n_channel % self.n_group == 0, "The number of channels {} must be divisible by the number of groups {}", self.n_channel, self.n_group); let n_per_group = self.n_channel / self.n_group; let gamma = Tensor::ones([self.n_channel]).into(); let beta = Tensor::zeros([self.n_channel]).into(); let eps = self.eps; GroupNorm { n_group: self.n_group, n_channel: self.n_channel, gamma, beta, eps, } } } #[derive(Module, Debug)] pub struct GroupNorm { n_group: usize, n_channel: usize, gamma: Param>, beta: Param>, eps: f64, } impl GroupNorm { pub fn forward(&self, x: Tensor) -> Tensor { let shape = x.shape(); let n_batch = shape.dims[0]; let num_elements = shape.num_elements(); let mut affine_shape = [1; D]; affine_shape[1] = self.n_channel; layernorm( x.reshape([n_batch, self.n_group, num_elements / (n_batch * self.n_group) ]), self.eps ) .reshape(shape) .mul(self.gamma.val().reshape(affine_shape)) .add(self.beta.val().reshape(affine_shape)) } } pub fn layernorm(x: Tensor, eps: f64) -> Tensor { //let (var, mean) = x.clone().var_mean_bias(D - 1); //x.sub(mean).div(var.sqrt().add_scalar(eps)) let u = x.clone() - x.mean_dim(D - 1); u.clone().div( (u.clone() * u).mean_dim(D - 1).add_scalar(eps).sqrt() ) }