Commit ecbdcee

WIP TextGenerationPipeline

gabrielmbmb committed Aug 31, 2024
1 parent 7bb0e91 commit ecbdcee

Showing 6 changed files with 173 additions and 17 deletions.

@@ -1,5 +1,4 @@
 use anyhow::Result;
-use candle_core::Device;
 use candle_holder_examples::get_device_from_args;
 use candle_holder_pipelines::TextClassificationPipeline;

@@ -0,0 +1,7 @@
# Text Generation Pipeline

## Running the example

```bash
cargo run --example text_generation_pipeline
```
32 changes: 32 additions & 0 deletions candle-holder-examples/examples/text_generation_pipeline/main.rs
@@ -0,0 +1,32 @@
use anyhow::Result;
use candle_holder_examples::get_device_from_args;
use candle_holder_models::{GenerationConfig, GenerationParams};
use candle_holder_pipelines::TextGenerationPipeline;
use candle_holder_tokenizers::Message;

fn main() -> Result<()> {
    let device = get_device_from_args()?;
    println!("Device: {:?}", device);

    let pipeline =
        TextGenerationPipeline::new("meta-llama/Meta-Llama-3.1-8B-Instruct", &device, None, None)?;

    let generations = pipeline.run(
        vec![Message::user("How much is 2 + 2?")],
        Some(GenerationParams {
            generation_config: Some(GenerationConfig {
                do_sample: true,
                max_new_tokens: Some(256),
                top_p: Some(0.9),
                top_k: None,
                temperature: 0.6,
                ..GenerationConfig::default()
            }),
            ..Default::default()
        }),
    )?;

    println!("`pipeline.run` results: {:?}", generations);

    Ok(())
}
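
The example above drives the pipeline with a chat-style `Vec<Message>`. The `run` method added in this commit also accepts a plain string prompt via the `From<&str>` conversion, and passing `None` for the generation parameters falls back to the defaults. A minimal sketch of that variant, reusing the same model id and device helper as above (the prompt text is illustrative and this is untested against the WIP branch):

```rust
use anyhow::Result;
use candle_holder_examples::get_device_from_args;
use candle_holder_pipelines::TextGenerationPipeline;

fn main() -> Result<()> {
    let device = get_device_from_args()?;

    // Same constructor call as the chat example above.
    let pipeline =
        TextGenerationPipeline::new("meta-llama/Meta-Llama-3.1-8B-Instruct", &device, None, None)?;

    // A plain `&str` is converted into the pipeline's single-input variant internally;
    // `None` means default `GenerationParams`.
    let generations = pipeline.run("Briefly explain what a tokenizer does.", None)?;
    println!("{:?}", generations);

    Ok(())
}
```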
144 changes: 130 additions & 14 deletions candle-holder-pipelines/src/text_generation.rs
@@ -1,9 +1,68 @@
 use candle_core::{DType, Device};
 use candle_holder::{FromPretrainedParameters, Result};
-use candle_holder_models::{AutoModelForCausalLM, PreTrainedModel};
-use candle_holder_tokenizers::{
-    AutoTokenizer, BatchEncoding, Padding, PaddingOptions, PaddingSide, Tokenizer,
+use candle_holder_models::{
+    generation::generate::GenerateOutput, AutoModelForCausalLM, GenerationParams, PreTrainedModel,
 };
+use candle_holder_tokenizers::{AutoTokenizer, BatchEncoding, Message, PaddingSide, Tokenizer};

+enum RunTextGenerationInput {
+    Input(String),
+    Messages(Vec<Message>),
+}

+impl From<&str> for RunTextGenerationInput {
+    fn from(input: &str) -> Self {
+        RunTextGenerationInput::Input(input.to_owned())
+    }
+}

+impl From<String> for RunTextGenerationInput {
+    fn from(input: String) -> Self {
+        RunTextGenerationInput::Input(input)
+    }
+}

+impl From<Vec<Message>> for RunTextGenerationInput {
+    fn from(messages: Vec<Message>) -> Self {
+        RunTextGenerationInput::Messages(messages)
+    }
+}

+enum RunBatchTextGenerationInput {
+    Inputs(Vec<String>),
+    Messages(Vec<Vec<Message>>),
+}

+impl From<Vec<String>> for RunBatchTextGenerationInput {
+    fn from(inputs: Vec<String>) -> Self {
+        RunBatchTextGenerationInput::Inputs(inputs)
+    }
+}

+impl From<Vec<Vec<Message>>> for RunBatchTextGenerationInput {
+    fn from(messages: Vec<Vec<Message>>) -> Self {
+        RunBatchTextGenerationInput::Messages(messages)
+    }
+}

+impl From<RunTextGenerationInput> for RunBatchTextGenerationInput {
+    fn from(input: RunTextGenerationInput) -> Self {
+        match input {
+            RunTextGenerationInput::Input(input) => {
+                RunBatchTextGenerationInput::Inputs(vec![input])
+            }
+            RunTextGenerationInput::Messages(messages) => {
+                RunBatchTextGenerationInput::Messages(vec![messages])
+            }
+        }
+    }
+}

+#[derive(Debug, Clone)]
+pub struct TextGenerationPipelineOutput {
+    text: Option<String>,
+    messages: Option<Vec<Message>>,
+}

 /// A pipeline for generating text with a causal language model.
 pub struct TextGenerationPipeline {
@@ -41,20 +100,77 @@ impl TextGenerationPipeline {
         })
     }

-    fn preprocess(&self, inputs: Vec<String>) -> Result<BatchEncoding> {
-        let mut encodings = self.tokenizer.encode(
-            inputs,
-            true,
-            Some(PaddingOptions::new(Padding::Longest, None)),
-        )?;
+    fn preprocess<I: Into<RunBatchTextGenerationInput>>(
+        &self,
+        inputs: I,
+    ) -> Result<(BatchEncoding, Vec<String>, bool)> {
+        let (mut encodings, text_inputs, are_messages) = match inputs.into() {
+            RunBatchTextGenerationInput::Inputs(inputs) => (
+                self.tokenizer.encode(inputs.clone(), false, None)?,
+                inputs,
+                false,
+            ),
+            RunBatchTextGenerationInput::Messages(messages) => {
+                let messages: Result<Vec<String>> = messages
+                    .into_iter()
+                    .map(|messages| self.tokenizer.apply_chat_template(messages))
+                    .collect();
+                let inputs = messages?;
+                (
+                    self.tokenizer.encode(inputs.clone(), false, None)?,
+                    inputs,
+                    true,
+                )
+            }
+        };
         encodings.to_device(&self.device)?;
-        Ok(encodings)
+        Ok((encodings, text_inputs, are_messages))
     }

+    fn postprocess(
+        &self,
+        outputs: Vec<GenerateOutput>,
+        text_inputs: Vec<String>,
+        are_messages: bool,
+    ) -> Result<Vec<Vec<TextGenerationPipelineOutput>>> {
+        let mut results = vec![];
+        for (output, text_input) in outputs.iter().zip(text_inputs) {
+            let mut outputs = vec![];
+            let text_input_len = text_input.len();
+            for sequence in output.get_sequences() {
+                let text = self.tokenizer.decode(&sequence[..], true)?;
+                println!("text input: {}", text_input);
+                println!("text: {}", text);
+                let generated: String = text.chars().skip(text_input_len).collect();
+                println!("generated: {}", generated);
+                outputs.push(TextGenerationPipelineOutput {
+                    text: Some(generated),
+                    messages: None,
+                });
+            }
+            results.push(outputs);
+        }
+        Ok(results)
+    }

-    pub fn run<I: Into<String>>(&self, input: I) -> Result<()> {
-        let _encodings = self.preprocess(vec![input.into()])?;
-        Ok(())
+    pub fn run<I: Into<RunTextGenerationInput>>(
+        &self,
+        input: I,
+        params: Option<GenerationParams>,
+    ) -> Result<Vec<TextGenerationPipelineOutput>> {
+        let (encodings, text_inputs, are_messages) = self.preprocess(input.into())?;
+        let outputs = self
+            .model
+            .generate(encodings.get_input_ids(), params.unwrap_or_default())?;
+        Ok(self.postprocess(outputs, text_inputs, are_messages)?[0].clone())
     }

-    pub fn run_batch<I: Into<String>>(&mut self, inputs: Vec<I>) {}
+    // pub fn run_batch<I: Into<RunBatchTextGenerationInput>>(
+    //     &mut self,
+    //     inputs: I,
+    //     params: Option<GenerationParams>,
+    // ) -> Result<TextGenerationPipelineOutput> {
+    //     let (encodings, are_messages) = self.preprocess(inputs)?;
+    //     Ok(())
+    // }
 }
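
In `postprocess`, each generated sequence is decoded in full and the prompt is then stripped by skipping its length, so only the newly generated text is returned. A standalone sketch of that idea (a hypothetical helper, not part of this commit); note that it counts the prompt's characters, whereas the code above uses `text_input.len()`, a byte count that can differ from the character count for non-ASCII prompts:

```rust
/// Hypothetical helper mirroring the prefix-stripping done in `postprocess`:
/// decode the whole sequence, then drop the prompt to keep only new text.
fn strip_prompt(decoded: &str, prompt: &str) -> String {
    // `skip` advances over `char`s, so count the prompt in characters
    // (`prompt.len()` would count bytes instead).
    decoded.chars().skip(prompt.chars().count()).collect()
}

fn main() {
    let prompt = "How much is 2 + 2?";
    let decoded = "How much is 2 + 2? 2 + 2 equals 4.";
    assert_eq!(strip_prompt(decoded, prompt), " 2 + 2 equals 4.");
}
```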
2 changes: 1 addition & 1 deletion candle-holder-tokenizers/src/chat_template.rs
@@ -4,7 +4,7 @@ use minijinja_contrib::pycompat;
 use serde::{Deserialize, Serialize};

 /// Represents a message in a conversation between a user and an assistant.
-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Debug, Clone, Serialize, Deserialize)]
 pub struct Message {
     /// The role of the message. Can be "system", "user", or "assistant".
     role: String,
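
`Message` gains a `Clone` derive so conversations can be duplicated, e.g. for the `messages` field of `TextGenerationPipelineOutput`, which itself derives `Clone` and is cloned out of the batch results in `run`. A small, illustrative sketch of what the derive enables (not taken from the commit):

```rust
use candle_holder_tokenizers::Message;

fn main() {
    let conversation = vec![Message::user("How much is 2 + 2?")];
    // With `#[derive(Clone)]` on `Message`, the whole conversation can be copied,
    // e.g. so a pipeline output can own its own `Vec<Message>`.
    let owned_copy = conversation.clone();
    println!("messages in copy: {}", owned_copy.len());
}
```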
4 changes: 3 additions & 1 deletion candle-holder-tokenizers/src/from_pretrained.rs
@@ -174,7 +174,9 @@ impl TokenizerInfo {
     pub fn get_tokenizer_class(&self) -> &str {
         if let Some(config) = &self.config {
             if let Some(tokenizer_class) = &config.tokenizer_class {
-                return tokenizer_class;
+                if tokenizer_class != "PreTrainedTokenizerFast" {
+                    return tokenizer_class;
+                }
             }
         }

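
With this change, a configured `tokenizer_class` of `"PreTrainedTokenizerFast"` no longer short-circuits `get_tokenizer_class`; the lookup falls through to whatever default resolution follows the block shown above. A simplified, hypothetical sketch of that fall-through (the function name and the `"<default>"` fallback are placeholders, not the crate's actual internals):

```rust
// Simplified stand-in for `get_tokenizer_class`: a generic fast-tokenizer class
// is ignored so a more specific class can be resolved elsewhere.
fn resolve_tokenizer_class(configured: Option<&str>) -> &str {
    if let Some(class) = configured {
        if class != "PreTrainedTokenizerFast" {
            return class;
        }
    }
    // Placeholder for the crate's default resolution path.
    "<default>"
}

fn main() {
    assert_eq!(resolve_tokenizer_class(Some("LlamaTokenizer")), "LlamaTokenizer");
    // Previously the generic class was returned directly.
    assert_eq!(resolve_tokenizer_class(Some("PreTrainedTokenizerFast")), "<default>");
    assert_eq!(resolve_tokenizer_class(None), "<default>");
}
```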
