use super::GroupNorm; use crate::model::load::*; use std::error::Error; use burn::{ config::Config, module::{Module, Param}, nn, tensor::{backend::Backend, Tensor}, }; pub fn load_group_norm( path: &str, device: &B::Device, ) -> Result, Box> { let n_group = load_usize::("n_group", path, device)?.into(); let n_channel = load_usize::("n_channel", path, device)?.into(); let eps = load_f32::("eps", path, device)?.into(); let gamma = Param::from_tensor(load_tensor::("weight", path, device) .ok() .unwrap_or_else(|| Tensor::ones([n_channel], device)) ); let beta = Param::from_tensor(load_tensor::("bias", path, device) .ok() .unwrap_or_else(|| Tensor::zeros([n_channel], device)) ); Ok(GroupNorm { n_group, n_channel, gamma, beta, eps, }) }