
Commit

Sketch a pixtral example.
LaurentMazare committed Sep 30, 2024
1 parent ba5834c commit 067e3b1
Showing 2 changed files with 6 additions and 9 deletions.
2 changes: 1 addition & 1 deletion candle-transformers/src/models/pixtral/mod.rs

@@ -1 +1 @@
-mod vision_model;
+pub mod vision_model;
13 changes: 5 additions & 8 deletions candle-transformers/src/models/pixtral/vision_model.rs

@@ -1,5 +1,4 @@
-#![allow(unused)]
-use candle::{DType, IndexOp, Module, Result, Tensor, D};
+use candle::{DType, Module, Result, Tensor, D};
 use candle_nn::{linear_b, rms_norm, Linear, RmsNorm, VarBuilder};
 
 fn default_act() -> candle_nn::Activation {
@@ -21,7 +20,7 @@ pub struct Config {
 }
 
 impl Config {
-    fn pixtral_12b_2409() -> Self {
+    pub fn pixtral_12b_2409() -> Self {
         Self {
             hidden_size: 1024,
             num_channels: 3,
@@ -186,6 +185,7 @@ impl Transformer {
         let vb = vb.pp("layers");
         for layer_idx in 0..cfg.num_hidden_layers {
             let layer = AttentionLayer::new(cfg, vb.pp(layer_idx))?;
+            layers.push(layer)
         }
         Ok(Self { layers })
     }
@@ -244,7 +244,7 @@ impl RotaryEmbedding {
     }
 
     fn apply_rotary_emb_qkv(&self, q: &Tensor, k: &Tensor) -> Result<(Tensor, Tensor)> {
-        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
+        let (_b_sz, _h, _seq_len, _n_embd) = q.dims4()?;
         let cos = &self.cos;
         let sin = &self.sin;
         let q_embed = candle_nn::rotary_emb::rope(q, cos, sin)?;
@@ -291,10 +291,7 @@ impl Module for Model {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let patch_embeds = xs.apply(&self.patch_conv)?;
         let patch_embeds = patch_embeds.flatten_from(2)?.t()?.apply(&self.ln_pre)?;
-        // TODO: emb
-        // TODO: generate_block_attention_mask
-        let attn_mask = None;
         self.transformer
-            .forward(&patch_embeds, &self.patch_positional_embedding, attn_mask)
+            .forward(&patch_embeds, &self.patch_positional_embedding, None)
     }
 }

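For reference, below is a minimal sketch of how the vision model made public by this commit might be driven from user code. It is not the example shipped with the commit: Model::new, the weight file name, and the dummy image size are assumptions, and the actual pixtral example may wire things up differently.

// Hedged sketch: assumes a `Model::new(&Config, VarBuilder)` constructor and a
// local safetensors file; the real pixtral example may differ.
use candle::{DType, Device, Module, Result, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::pixtral::vision_model::{Config, Model};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // `pixtral_12b_2409` is made public by this commit.
    let cfg = Config::pixtral_12b_2409();
    // Hypothetical weight file; a real example would fetch weights from the hub.
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["pixtral-vision.safetensors"], DType::F32, &device)?
    };
    let model = Model::new(&cfg, vb)?;
    // Dummy RGB image tensor: (batch, channels, height, width); sizes are placeholders.
    let image = Tensor::zeros((1, 3, 1024, 1024), DType::F32, &device)?;
    // `Model` implements `Module`, so `forward` yields the patch embeddings.
    let patch_embeds = model.forward(&image)?;
    println!("patch embeddings: {:?}", patch_embeds.dims());
    Ok(())
}

Making vision_model and Config::pixtral_12b_2409 public is what allows an out-of-crate example like this to construct the model at all; the surrounding glue above is only a guess.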