Add files via upload
Add initial project files
This commit is contained in:
33
src/model/groupnorm/load.rs
Normal file
33
src/model/groupnorm/load.rs
Normal file
@@ -0,0 +1,33 @@
|
||||
use super::GroupNorm;
|
||||
use crate::model::load::*;
|
||||
|
||||
use std::error::Error;
|
||||
|
||||
use burn::{
|
||||
config::Config,
|
||||
module::{Module, Param},
|
||||
nn,
|
||||
tensor::{
|
||||
backend::Backend,
|
||||
Tensor,
|
||||
},
|
||||
};
|
||||
|
||||
pub fn load_group_norm<B: Backend>(path: &str, device: &B::Device) -> Result<GroupNorm<B>, Box<dyn Error>> {
|
||||
let n_group = load_usize::<B>("n_group", path, device)?.into();
|
||||
let n_channel = load_usize::<B>("n_channel", path, device)?.into();
|
||||
let eps = load_f32::<B>("eps", path, device)?.into();
|
||||
|
||||
let gamma = load_tensor::<B, 1>("weight", path, device).ok().unwrap_or_else(|| Tensor::ones_device([n_channel], device)).into();
|
||||
let beta = load_tensor::<B, 1>("bias", path, device).ok().unwrap_or_else(|| Tensor::zeros_device([n_channel], device)).into();
|
||||
|
||||
Ok(
|
||||
GroupNorm {
|
||||
n_group,
|
||||
n_channel,
|
||||
gamma,
|
||||
beta,
|
||||
eps,
|
||||
}
|
||||
)
|
||||
}
|
||||
Reference in New Issue
Block a user