Replace helper functions with native burn functions

This commit is contained in:
Gadersd
2023-09-07 12:23:18 -04:00
parent a62795347f
commit f4c58c1790
20 changed files with 1091 additions and 950 deletions

View File

@@ -4,30 +4,34 @@ use crate::model::load::*;
use std::error::Error;
use burn::{
config::Config,
config::Config,
module::{Module, Param},
nn,
tensor::{
backend::Backend,
Tensor,
},
tensor::{backend::Backend, Tensor},
};
pub fn load_group_norm<B: Backend>(path: &str, device: &B::Device) -> Result<GroupNorm<B>, Box<dyn Error>> {
pub fn load_group_norm<B: Backend>(
path: &str,
device: &B::Device,
) -> Result<GroupNorm<B>, Box<dyn Error>> {
let n_group = load_usize::<B>("n_group", path, device)?.into();
let n_channel = load_usize::<B>("n_channel", path, device)?.into();
let eps = load_f32::<B>("eps", path, device)?.into();
let gamma = load_tensor::<B, 1>("weight", path, device).ok().unwrap_or_else(|| Tensor::ones_device([n_channel], device)).into();
let beta = load_tensor::<B, 1>("bias", path, device).ok().unwrap_or_else(|| Tensor::zeros_device([n_channel], device)).into();
let gamma = load_tensor::<B, 1>("weight", path, device)
.ok()
.unwrap_or_else(|| Tensor::ones_device([n_channel], device))
.into();
let beta = load_tensor::<B, 1>("bias", path, device)
.ok()
.unwrap_or_else(|| Tensor::zeros_device([n_channel], device))
.into();
Ok(
GroupNorm {
n_group,
n_channel,
gamma,
beta,
eps,
}
)
}
Ok(GroupNorm {
n_group,
n_channel,
gamma,
beta,
eps,
})
}