-
Notifications
You must be signed in to change notification settings - Fork 1.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(genai): add tokenizer package (#10699)
- Loading branch information
Showing
3 changed files
with
312 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
// Copyright 2024 Google LLC | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
package tokenizer_test | ||
|
||
import ( | ||
"fmt" | ||
"log" | ||
|
||
"cloud.google.com/go/vertexai/genai" | ||
"cloud.google.com/go/vertexai/genai/tokenizer" | ||
) | ||
|
||
func ExampleTokenizer_CountTokens() { | ||
tok, err := tokenizer.New("gemini-1.5-flash") | ||
if err != nil { | ||
log.Fatal(err) | ||
} | ||
|
||
ntoks, err := tok.CountTokens(genai.Text("a prompt"), genai.Text("another prompt")) | ||
if err != nil { | ||
log.Fatal(err) | ||
} | ||
|
||
fmt.Println("total token count:", ntoks.TotalTokens) | ||
|
||
// Output: total token count: 4 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,156 @@ | ||
// Copyright 2024 Google LLC | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
// Package tokenizer provides local token counting for Gemini models. This | ||
// tokenizer downloads its model from the web, but otherwise doesn't require | ||
// an API call for every CountTokens invocation. | ||
package tokenizer | ||
|
||
import ( | ||
"bytes" | ||
"crypto/sha256" | ||
"encoding/hex" | ||
"fmt" | ||
"io" | ||
"net/http" | ||
"os" | ||
"path/filepath" | ||
|
||
"cloud.google.com/go/vertexai/genai" | ||
"cloud.google.com/go/vertexai/internal/sentencepiece" | ||
) | ||
|
||
var supportedModels = map[string]bool{ | ||
"gemini-1.0-pro": true, | ||
"gemini-1.5-pro": true, | ||
"gemini-1.5-flash": true, | ||
"gemini-1.0-pro-001": true, | ||
"gemini-1.0-pro-002": true, | ||
"gemini-1.5-pro-001": true, | ||
"gemini-1.5-flash-001": true, | ||
} | ||
|
||
// Tokenizer is a local tokenizer for text. | ||
type Tokenizer struct { | ||
encoder *sentencepiece.Encoder | ||
} | ||
|
||
// CountTokensResponse is the response of [Tokenizer.CountTokens]. | ||
type CountTokensResponse struct { | ||
TotalTokens int32 | ||
} | ||
|
||
// New creates a new [Tokenizer] from a model name; the model name is the same | ||
// as you would pass to a [genai.Client.GenerativeModel]. | ||
func New(modelName string) (*Tokenizer, error) { | ||
if !supportedModels[modelName] { | ||
return nil, fmt.Errorf("model %s is not supported", modelName) | ||
} | ||
|
||
data, err := loadModelData(gemmaModelURL, gemmaModelHash) | ||
if err != nil { | ||
return nil, fmt.Errorf("loading model: %w", err) | ||
} | ||
|
||
encoder, err := sentencepiece.NewEncoder(bytes.NewReader(data)) | ||
if err != nil { | ||
return nil, fmt.Errorf("creating encoder: %w", err) | ||
} | ||
|
||
return &Tokenizer{encoder: encoder}, nil | ||
} | ||
|
||
// CountTokens counts the tokens in all the given parts and returns their | ||
// sum. Only [genai.Text] parts are suppored; an error will be returned if | ||
// non-text parts are provided. | ||
func (tok *Tokenizer) CountTokens(parts ...genai.Part) (*CountTokensResponse, error) { | ||
sum := 0 | ||
|
||
for _, part := range parts { | ||
if t, ok := part.(genai.Text); ok { | ||
toks := tok.encoder.Encode(string(t)) | ||
sum += len(toks) | ||
} else { | ||
return nil, fmt.Errorf("Tokenizer.CountTokens only supports Text parts") | ||
} | ||
} | ||
|
||
return &CountTokensResponse{TotalTokens: int32(sum)}, nil | ||
} | ||
|
||
// gemmaModelURL is the URL from which we download the model file. | ||
const gemmaModelURL = "https://raw.githubusercontent.com/google/gemma_pytorch/33b652c465537c6158f9a472ea5700e5e770ad3f/tokenizer/tokenizer.model" | ||
|
||
// gemmaModelHash is the expected hash of the model file (as calculated | ||
// by [hashString]). | ||
const gemmaModelHash = "61a7b147390c64585d6c3543dd6fc636906c9af3865a5548f27f31aee1d4c8e2" | ||
|
||
// downloadModelFile downloads a file from the given URL. | ||
func downloadModelFile(url string) ([]byte, error) { | ||
resp, err := http.Get(url) | ||
if err != nil { | ||
return nil, err | ||
} | ||
defer resp.Body.Close() | ||
|
||
return io.ReadAll(resp.Body) | ||
} | ||
|
||
// hashString computes a hex string of the SHA256 hash of data. | ||
func hashString(data []byte) string { | ||
hash256 := sha256.Sum256(data) | ||
return hex.EncodeToString(hash256[:]) | ||
} | ||
|
||
// loadModelData loads model data from the given URL, using a local file-system | ||
// cache. wantHash is the hash (as returned by [hashString] expected on the | ||
// loaded data. | ||
// | ||
// Caching logic: | ||
// | ||
// Assuming $TEMP_DIR is the temporary directory used by the OS, this function | ||
// uses the file $TEMP_DIR/vertexai_tokenizer_model/$urlhash as a cache, where | ||
// $urlhash is hashString(url). | ||
// | ||
// If this cache file doesn't exist, or the data it contains doesn't match | ||
// wantHash, downloads data from the URL and writes it into the cache. If the | ||
// URL's data doesn't match the hash, an error is returned. | ||
func loadModelData(url string, wantHash string) ([]byte, error) { | ||
urlhash := hashString([]byte(url)) | ||
cacheDir := filepath.Join(os.TempDir(), "vertexai_tokenizer_model") | ||
cachePath := filepath.Join(cacheDir, urlhash) | ||
|
||
cacheData, err := os.ReadFile(cachePath) | ||
if err != nil || hashString(cacheData) != wantHash { | ||
cacheData, err = downloadModelFile(url) | ||
if err != nil { | ||
return nil, fmt.Errorf("loading cache and downloading model: %w", err) | ||
} | ||
|
||
if hashString(cacheData) != wantHash { | ||
return nil, fmt.Errorf("downloaded model hash mismatch") | ||
} | ||
|
||
err = os.MkdirAll(cacheDir, 0770) | ||
if err != nil { | ||
return nil, fmt.Errorf("creating cache dir: %w", err) | ||
} | ||
err = os.WriteFile(cachePath, cacheData, 0660) | ||
if err != nil { | ||
return nil, fmt.Errorf("writing cache file: %w", err) | ||
} | ||
} | ||
|
||
return cacheData, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
// Copyright 2024 Google LLC | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
package tokenizer | ||
|
||
import ( | ||
"fmt" | ||
"os" | ||
"path/filepath" | ||
"testing" | ||
|
||
"cloud.google.com/go/vertexai/genai" | ||
) | ||
|
||
func TestDownload(t *testing.T) { | ||
b, err := downloadModelFile(gemmaModelURL) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
|
||
if hashString(b) != gemmaModelHash { | ||
t.Errorf("gemma model hash doesn't match") | ||
} | ||
} | ||
|
||
func TestLoadModelData(t *testing.T) { | ||
// Tests that loadModelData manages to load the model properly, and download | ||
// a new one as needed. | ||
checkDataAndErr := func(data []byte, err error) { | ||
t.Helper() | ||
if err != nil { | ||
t.Error(err) | ||
} | ||
gotHash := hashString(data) | ||
if gotHash != gemmaModelHash { | ||
t.Errorf("got hash=%v, want=%v", gotHash, gemmaModelHash) | ||
} | ||
} | ||
|
||
data, err := loadModelData(gemmaModelURL, gemmaModelHash) | ||
checkDataAndErr(data, err) | ||
|
||
// The cache should exist now and have the right data, try again. | ||
data, err = loadModelData(gemmaModelURL, gemmaModelHash) | ||
checkDataAndErr(data, err) | ||
|
||
// Overwrite cache file with wrong data, and try again. | ||
cacheDir := filepath.Join(os.TempDir(), "vertexai_tokenizer_model") | ||
cachePath := filepath.Join(cacheDir, hashString([]byte(gemmaModelURL))) | ||
_ = os.MkdirAll(cacheDir, 0770) | ||
_ = os.WriteFile(cachePath, []byte{0, 1, 2, 3}, 0660) | ||
data, err = loadModelData(gemmaModelURL, gemmaModelHash) | ||
checkDataAndErr(data, err) | ||
} | ||
|
||
func TestCreateTokenizer(t *testing.T) { | ||
// Create a tokenizer successfully | ||
_, err := New("gemini-1.5-flash") | ||
if err != nil { | ||
t.Error(err) | ||
} | ||
|
||
// Create a tokenizer with an unsupported model | ||
_, err = New("gemini-0.92") | ||
if err == nil { | ||
t.Errorf("got no error, want error") | ||
} | ||
} | ||
|
||
func TestCountTokens(t *testing.T) { | ||
var tests = []struct { | ||
parts []genai.Part | ||
wantCount int32 | ||
}{ | ||
{[]genai.Part{genai.Text("hello world")}, 2}, | ||
{[]genai.Part{genai.Text("<table><th></th></table>")}, 4}, | ||
{[]genai.Part{genai.Text("hello world"), genai.Text("<table><th></th></table>")}, 6}, | ||
} | ||
|
||
tok, err := New("gemini-1.5-flash") | ||
if err != nil { | ||
t.Error(err) | ||
} | ||
|
||
for i, tt := range tests { | ||
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { | ||
got, err := tok.CountTokens(tt.parts...) | ||
if err != nil { | ||
t.Error(err) | ||
} | ||
if got.TotalTokens != tt.wantCount { | ||
t.Errorf("got %v, want %v", got.TotalTokens, tt.wantCount) | ||
} | ||
}) | ||
} | ||
} | ||
|
||
func TestCountTokensNonText(t *testing.T) { | ||
tok, err := New("gemini-1.5-flash") | ||
if err != nil { | ||
t.Error(err) | ||
} | ||
|
||
_, err = tok.CountTokens(genai.Text("foo"), genai.ImageData("format", []byte{0, 1})) | ||
if err == nil { | ||
t.Error("got no error, want error") | ||
} | ||
} |