Fix model initialization bugs

This commit is contained in:
Gadersd
2023-08-05 16:20:49 -04:00
parent 10ce5ac89d
commit ce91c8838f
4 changed files with 2 additions and 4 deletions

View File

@@ -369,7 +369,6 @@ pub struct UpsampleConfig {
impl UpsampleConfig {
fn init<B: Backend>(&self) -> Upsample<B> {
let conv = Conv2dConfig::new([self.n_channels, self.n_channels], [3, 3])
.with_stride([2, 2])
.with_padding(PaddingConfig2d::Explicit(1, 1))
.init();
@@ -493,7 +492,7 @@ pub struct TransformerBlockConfig {
impl TransformerBlockConfig {
fn init<B: Backend>(&self) -> TransformerBlock<B> {
let norm1 = nn::LayerNormConfig::new(self.n_state).init();
- let attn1 = MultiHeadAttentionConfig::new(self.n_state, self.n_context_state, self.n_head).init();
+ let attn1 = MultiHeadAttentionConfig::new(self.n_state, self.n_state, self.n_head).init();
let norm2 = nn::LayerNormConfig::new(self.n_state).init();
let attn2 = MultiHeadAttentionConfig::new(self.n_state, self.n_context_state, self.n_head).init();
let norm3 = nn::LayerNormConfig::new(self.n_state).init();