Skip to content

Commit

Permalink
Merge branch 'main' into adapter_e2e_test
Browse files Browse the repository at this point in the history
  • Loading branch information
bangqipropel authored Jun 27, 2024
2 parents cc0f886 + f70dcc5 commit 923dca7
Show file tree
Hide file tree
Showing 3 changed files with 222 additions and 37 deletions.
16 changes: 12 additions & 4 deletions api/v1alpha1/workspace_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,13 @@ func (w *Workspace) Validate(ctx context.Context) (errs *apis.FieldError) {
errs = errs.Also(w.validateCreate().ViaField("spec"))
if w.Inference != nil {
// TODO: Add Adapter Spec Validation - Including DataSource Validation for Adapter
errs = errs.Also(w.Resource.validateCreate(*w.Inference).ViaField("resource"),
errs = errs.Also(w.Resource.validateCreateWithInference(w.Inference).ViaField("resource"),

Check warning on line 48 in api/v1alpha1/workspace_validation.go

View check run for this annotation

Codecov / codecov/patch

api/v1alpha1/workspace_validation.go#L48

Added line #L48 was not covered by tests
w.Inference.validateCreate().ViaField("inference"))
}
if w.Tuning != nil {
// TODO: Add validate resource based on Tuning Spec
errs = errs.Also(w.Tuning.validateCreate(ctx, w.Namespace).ViaField("tuning"))
errs = errs.Also(w.Resource.validateCreateWithTuning(w.Tuning).ViaField("resource"),
w.Tuning.validateCreate(ctx, w.Namespace).ViaField("tuning"))

Check warning on line 54 in api/v1alpha1/workspace_validation.go

View check run for this annotation

Codecov / codecov/patch

api/v1alpha1/workspace_validation.go#L53-L54

Added lines #L53 - L54 were not covered by tests
}
} else {
klog.InfoS("Validate update", "workspace", fmt.Sprintf("%s/%s", w.Namespace, w.Name))
Expand Down Expand Up @@ -299,7 +300,14 @@ func (r *DataDestination) validateUpdate(old *DataDestination) (errs *apis.Field
return errs
}

func (r *ResourceSpec) validateCreate(inference InferenceSpec) (errs *apis.FieldError) {
func (r *ResourceSpec) validateCreateWithTuning(tuning *TuningSpec) (errs *apis.FieldError) {
if *r.Count > 1 {
errs = errs.Also(apis.ErrInvalidValue("Tuning does not currently support multinode configurations. Please set the node count to 1. Future support with DeepSpeed will allow this.", "count"))
}
return errs
}

func (r *ResourceSpec) validateCreateWithInference(inference *InferenceSpec) (errs *apis.FieldError) {
var presetName string
if inference.Preset != nil {
presetName = strings.ToLower(string(inference.Preset.Name))
Expand All @@ -308,7 +316,7 @@ func (r *ResourceSpec) validateCreate(inference InferenceSpec) (errs *apis.Field

// Check if instancetype exists in our SKUs map
if skuConfig, exists := SupportedGPUConfigs[instanceType]; exists {
if inference.Preset != nil {
if presetName != "" {
model := plugin.KaitoModelRegister.MustGet(presetName) // InferenceSpec has been validated so the name is valid.
// Validate GPU count for given SKU
machineCount := *r.Count
Expand Down
111 changes: 78 additions & 33 deletions api/v1alpha1/workspace_validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ func TestResourceSpecValidateCreate(t *testing.T) {
preset bool
errContent string // Content expect error to include, if any
expectErrs bool
validateTuning bool // To indicate if we are testing tuning validation
}{
{
name: "Valid resource",
Expand All @@ -214,6 +215,7 @@ func TestResourceSpecValidateCreate(t *testing.T) {
preset: true,
errContent: "",
expectErrs: false,
validateTuning: false,
},
{
name: "Insufficient total GPU memory",
Expand All @@ -227,6 +229,7 @@ func TestResourceSpecValidateCreate(t *testing.T) {
preset: true,
errContent: "Insufficient total GPU memory",
expectErrs: true,
validateTuning: false,
},

{
Expand All @@ -241,6 +244,7 @@ func TestResourceSpecValidateCreate(t *testing.T) {
preset: true,
errContent: "Insufficient number of GPUs",
expectErrs: true,
validateTuning: false,
},
{
name: "Insufficient per GPU memory",
Expand All @@ -254,6 +258,7 @@ func TestResourceSpecValidateCreate(t *testing.T) {
preset: true,
errContent: "Insufficient per GPU memory",
expectErrs: true,
validateTuning: false,
},

{
Expand All @@ -262,27 +267,30 @@ func TestResourceSpecValidateCreate(t *testing.T) {
InstanceType: "Standard_invalid_sku",
Count: pointerToInt(1),
},
errContent: "Unsupported instance",
expectErrs: true,
errContent: "Unsupported instance",
expectErrs: true,
validateTuning: false,
},
{
name: "Only Template set",
resourceSpec: &ResourceSpec{
InstanceType: "Standard_NV12s_v3",
Count: pointerToInt(1),
},
preset: false,
errContent: "",
expectErrs: false,
preset: false,
errContent: "",
expectErrs: false,
validateTuning: false,
},
{
name: "N-Prefix SKU",
resourceSpec: &ResourceSpec{
InstanceType: "Standard_Nsku",
Count: pointerToInt(1),
},
errContent: "",
expectErrs: false,
errContent: "",
expectErrs: false,
validateTuning: false,
},

{
Expand All @@ -291,44 +299,81 @@ func TestResourceSpecValidateCreate(t *testing.T) {
InstanceType: "Standard_Dsku",
Count: pointerToInt(1),
},
errContent: "",
expectErrs: false,
errContent: "",
expectErrs: false,
validateTuning: false,
},
{
name: "Tuning validation with single node",
resourceSpec: &ResourceSpec{
InstanceType: "Standard_NC6",
Count: pointerToInt(1),
},
errContent: "",
expectErrs: false,
validateTuning: true,
},
{
name: "Tuning validation with multinode",
resourceSpec: &ResourceSpec{
InstanceType: "Standard_NC6",
Count: pointerToInt(2),
},
errContent: "Tuning does not currently support multinode configurations",
expectErrs: true,
validateTuning: true,
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
var spec InferenceSpec
if tc.validateTuning {
tuningSpec := &TuningSpec{}
errs := tc.resourceSpec.validateCreateWithTuning(tuningSpec)
hasErrs := errs != nil
if hasErrs != tc.expectErrs {
t.Errorf("validateCreateWithTuning() errors = %v, expectErrs %v", errs, tc.expectErrs)
}

if tc.preset {
spec = InferenceSpec{
Preset: &PresetSpec{
PresetMeta: PresetMeta{
Name: ModelName("test-validation"),
},
},
if hasErrs && tc.errContent != "" {
errMsg := errs.Error()
if !strings.Contains(errMsg, tc.errContent) {
t.Errorf("validateCreateWithTuning() error message = %v, expected to contain = %v", errMsg, tc.errContent)
}
}
} else {
spec = InferenceSpec{
Template: &v1.PodTemplateSpec{}, // Assuming a non-nil TemplateSpec implies it's set
var spec InferenceSpec

if tc.preset {
spec = InferenceSpec{
Preset: &PresetSpec{
PresetMeta: PresetMeta{
Name: ModelName("test-validation"),
},
},
}
} else {
spec = InferenceSpec{
Template: &v1.PodTemplateSpec{}, // Assuming a non-nil TemplateSpec implies it's set
}
}
}

gpuCountRequirement = tc.modelGPUCount
totalGPUMemoryRequirement = tc.modelTotalGPUMemory
perGPUMemoryRequirement = tc.modelPerGPUMemory
gpuCountRequirement = tc.modelGPUCount
totalGPUMemoryRequirement = tc.modelTotalGPUMemory
perGPUMemoryRequirement = tc.modelPerGPUMemory

errs := tc.resourceSpec.validateCreate(spec)
hasErrs := errs != nil
if hasErrs != tc.expectErrs {
t.Errorf("validateCreate() errors = %v, expectErrs %v", errs, tc.expectErrs)
}
errs := tc.resourceSpec.validateCreateWithInference(&spec)
hasErrs := errs != nil
if hasErrs != tc.expectErrs {
t.Errorf("validateCreate() errors = %v, expectErrs %v", errs, tc.expectErrs)
}

// If there is an error and errContent is not empty, check that the error contains the expected content.
if hasErrs && tc.errContent != "" {
errMsg := errs.Error()
if !strings.Contains(errMsg, tc.errContent) {
t.Errorf("validateCreate() error message = %v, expected to contain = %v", errMsg, tc.errContent)
// If there is an error and errContent is not empty, check that the error contains the expected content.
if hasErrs && tc.errContent != "" {
errMsg := errs.Error()
if !strings.Contains(errMsg, tc.errContent) {
t.Errorf("validateCreate() error message = %v, expected to contain = %v", errMsg, tc.errContent)
}
}
}
})
Expand Down
132 changes: 132 additions & 0 deletions docs/tuning/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
# Kaito Tuning Workspace API

This guide provides instructions on how to use the Kaito Tuning Workspace API for parameter-efficient fine-tuning (PEFT) of models. The API supports methods like LoRA and QLoRA and allows users to specify their own datasets and configuration settings.

## Getting Started

To use the Kaito Tuning Workspace API, you need to define a Workspace custom resource (CR). Below are examples of how to define the CR and its various components.

## Example Workspace Definitions
Here are three examples of using the API to define a workspace for tuning different models:

Example 1: Tuning [`phi-3-mini`](../../examples/fine-tuning/kaito_workspace_tuning_phi_3.yaml)

Example 2: Tuning `falcon-7b`
```yaml
apiVersion: kaito.sh/v1alpha1
kind: Workspace
metadata:
name: workspace-tuning-falcon
resource:
instanceType: "Standard_NC6s_v3"
labelSelector:
matchLabels:
app: tuning-phi-3-falcon
tuning:
preset:
name: falcon-7b
method: qlora
input:
image: ACR_REPO_HERE.azurecr.io/IMAGE_NAME_HERE:0.0.1
imagePullSecrets:
- IMAGE_PULL_SECRETS_HERE
output:
image: ACR_REPO_HERE.azurecr.io/IMAGE_NAME_HERE:0.0.1 # Tuning Output
imagePushSecret: aimodelsregistrysecret

```
Generic TuningSpec Structure:
```yaml
tuning:
preset:
name: preset-model
method: lora or qlora
config: custom-configmap (optional)
input: # Image or URL
urls:
- "https://example.com/dataset.parquet?download=true"
# Future updates will include support for v1.Volume as an input type.
output: # Image
image: "youracr.azurecr.io/custom-adapter:0.0.1"
imagePushSecret: youracrsecret
```
## Default ConfigMaps
The default configuration for different tuning methods can be specified
using ConfigMaps. Below are examples for LoRA and QLoRA methods.
- [Default LoRA ConfigMap](../../charts/kaito/workspace/templates/lora-params.yaml)
- [QLoRA ConfigMap](../../charts/kaito/workspace/templates/qlora-params.yaml)
## Using Custom ConfigMaps
You can specify your own custom ConfigMap and include it in the `Config`
field of the `TuningSpec`

For more information on configurable parameters, please refer to the respective documentation links provided in the default ConfigMap examples.

## Key Parameters in Kaito ConfigMap Structure
The following sections highlight important parameters for each configuration area in the Kaito ConfigMap. For a complete list of parameters, please refer to the provided Hugging Face documentation links.

ModelConfig [Full List](https://huggingface.co/docs/transformers/v4.40.2/en/model_doc/auto#transformers.AutoModelForCausalLM.from_pretrained)
- torch_dtype: Specifies the data type for PyTorch tensors, e.g., "bfloat16".
- local_files_only: Indicates whether to only use local files.
- device_map: Configures device mapping for the model, typically "auto".

QuantizationConfig [Full List](https://huggingface.co/docs/transformers/v4.40.2/en/main_classes/quantization#transformers.BitsAndBytesConfig)
- load_in_4bit: Enables loading the model with 4-bit precision.
- bnb_4bit_quant_type: Specifies the type of 4-bit quantization, e.g., "nf4".
- bnb_4bit_compute_dtype: Data type for computation, e.g., "bfloat16".
- bnb_4bit_use_double_quant: Enables double quantization.

LoraConfig [Full List](https://huggingface.co/docs/peft/v0.8.2/en/package_reference/lora#peft.LoraConfig)
- r: Rank of the low-rank matrices used in LoRA.
- lora_alpha: Scaling factor for LoRA.
- lora_dropout: Dropout rate for LoRA layers.

TrainingArguments [Full List](https://huggingface.co/docs/transformers/v4.40.2/en/main_classes/trainer#transformers.TrainingArguments)
- output_dir: Directory where the training results will be saved.
- ddp_find_unused_parameters: Flag to handle unused parameters during distributed training.
- save_strategy: Strategy for saving checkpoints, e.g., "epoch".
- per_device_train_batch_size: Batch size per device during training.
- num_train_epochs: Total number of training epochs to perform, defaults to 3.0.

DataCollator [Full List](https://huggingface.co/docs/transformers/v4.40.2/en/main_classes/data_collator#transformers.DataCollatorForLanguageModeling)
- mlm: Masked language modeling flag.

DatasetConfig [Full List](https://github.com/Azure/kaito/blob/main/presets/tuning/text-generation/cli.py#L44)
- shuffle_dataset: Whether to shuffle the dataset.
- train_test_split: Proportion of data used for training, typically set to 1 for using all data.

## Expected Dataset Format
The dataset should follow a specific format. For example:

- Conversational Format
```json
{
"messages": [
{"role": "system", "content": "Marv is a factual chatbot that is also sarcastic."},
{"role": "user", "content": "What's the capital of France?"},
{"role": "assistant", "content": "Paris, as if everyone doesn't know that already."}
]
}
```
For a practical example, refer to [HuggingFace Dolly 15k OAI-style dataset](https://huggingface.co/datasets/philschmid/dolly-15k-oai-style/tree/main).

- Instruction Format
```json
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
{"prompt": "<prompt text>", "completion": "<ideal generated text>"}
```

For a practical example, refer to [HuggingFace Instruction Dataset](https://huggingface.co/datasets/HuggingFaceH4/instruction-dataset/tree/main)


## Dataset Format Support

The SFTTrainer supports popular dataset formats, allowing direct passage of the dataset without pre-processing. Supported formats include conversational and instruction formats. If your dataset is not in one of these formats, it will be passed directly to the SFTTrainer without any preprocessing. This may result in undefined behavior if the dataset does not align with the trainer's expected input structure.

To ensure proper function, you may need to preprocess the dataset to match one of the supported formats.

For example usage and more details, refer to the [Official Hugging Face documentation and tutorials](https://huggingface.co/docs/trl/v0.9.4/sft_trainer#dataset-format-support).

0 comments on commit 923dca7

Please sign in to comment.