Replace helper functions with native burn functions
This commit is contained in:
@@ -4,30 +4,34 @@ use crate::model::load::*;
|
||||
use std::error::Error;
|
||||
|
||||
use burn::{
|
||||
config::Config,
|
||||
config::Config,
|
||||
module::{Module, Param},
|
||||
nn,
|
||||
tensor::{
|
||||
backend::Backend,
|
||||
Tensor,
|
||||
},
|
||||
tensor::{backend::Backend, Tensor},
|
||||
};
|
||||
|
||||
pub fn load_group_norm<B: Backend>(path: &str, device: &B::Device) -> Result<GroupNorm<B>, Box<dyn Error>> {
|
||||
pub fn load_group_norm<B: Backend>(
|
||||
path: &str,
|
||||
device: &B::Device,
|
||||
) -> Result<GroupNorm<B>, Box<dyn Error>> {
|
||||
let n_group = load_usize::<B>("n_group", path, device)?.into();
|
||||
let n_channel = load_usize::<B>("n_channel", path, device)?.into();
|
||||
let eps = load_f32::<B>("eps", path, device)?.into();
|
||||
|
||||
let gamma = load_tensor::<B, 1>("weight", path, device).ok().unwrap_or_else(|| Tensor::ones_device([n_channel], device)).into();
|
||||
let beta = load_tensor::<B, 1>("bias", path, device).ok().unwrap_or_else(|| Tensor::zeros_device([n_channel], device)).into();
|
||||
let gamma = load_tensor::<B, 1>("weight", path, device)
|
||||
.ok()
|
||||
.unwrap_or_else(|| Tensor::ones_device([n_channel], device))
|
||||
.into();
|
||||
let beta = load_tensor::<B, 1>("bias", path, device)
|
||||
.ok()
|
||||
.unwrap_or_else(|| Tensor::zeros_device([n_channel], device))
|
||||
.into();
|
||||
|
||||
Ok(
|
||||
GroupNorm {
|
||||
n_group,
|
||||
n_channel,
|
||||
gamma,
|
||||
beta,
|
||||
eps,
|
||||
}
|
||||
)
|
||||
}
|
||||
Ok(GroupNorm {
|
||||
n_group,
|
||||
n_channel,
|
||||
gamma,
|
||||
beta,
|
||||
eps,
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user