From abd3144f468abc8895fd6b5bb3e5063b5c4d19e4 Mon Sep 17 00:00:00 2001 From: Jacob Date: Sat, 2 Dec 2023 17:48:58 -0800 Subject: [PATCH] adding Pulze API --- app/lib/api_key_dialog.dart | 12 ++- app/lib/pulze_ai_api.dart | 172 ++++++++++++++++++++++++++++++++++++ 2 files changed, 182 insertions(+), 2 deletions(-) create mode 100644 app/lib/pulze_ai_api.dart diff --git a/app/lib/api_key_dialog.dart b/app/lib/api_key_dialog.dart index 0aee40c..a1b294a 100644 --- a/app/lib/api_key_dialog.dart +++ b/app/lib/api_key_dialog.dart @@ -3,6 +3,7 @@ import 'package:glowby/openai_api.dart'; import 'package:glowby/utils.dart'; import 'hugging_face_api.dart'; +import 'pulze_ai_api.dart'; class ApiKeyDialog extends StatefulWidget { @override @@ -12,8 +13,10 @@ class ApiKeyDialog extends StatefulWidget { class _ApiKeyDialogState extends State { final _apiKeyController = TextEditingController(); final _huggingFaceTokenController = TextEditingController(); + final _pulzeAiController = TextEditingController(); String _apiKey = ''; String _huggingFaceToken = ''; + String _pulzeAiToken = ''; @override void initState() { @@ -25,6 +28,8 @@ class _ApiKeyDialogState extends State { _apiKeyController.text = _apiKey; _huggingFaceToken = HuggingFace_API.oat(); _huggingFaceTokenController.text = _huggingFaceToken; + _pulzeAiToken = PulzeAI_API.oat(); + _pulzeAiController.text = _pulzeAiToken; }); }); } @@ -32,6 +37,7 @@ class _ApiKeyDialogState extends State { void _saveApiKey(BuildContext context) { OpenAI_API.setOat(_apiKey); HuggingFace_API.setOat(_huggingFaceToken); + PulzeAI_API.setOat(_pulzeAiToken); Navigator.pop(context); // Hide the dialog ScaffoldMessenger.of(context).showSnackBar( @@ -116,13 +122,13 @@ class _ApiKeyDialogState extends State { SizedBox(height: 10), Text('Enter your Pulze.ai Token:'), TextField( - controller: _huggingFaceTokenController, + controller: _pulzeAiController, obscureText: true, decoration: InputDecoration( labelText: 'xxxxxxxxxxxxxxxxxxxxxxxxxxxxxx'), onChanged: (value) { setState(() { - _huggingFaceToken = value; + _pulzeAiToken = value; }); }, ), @@ -188,6 +194,8 @@ class _ApiKeyDialogState extends State { _apiKey = ''; _huggingFaceTokenController.clear(); _huggingFaceToken = ''; + _pulzeAiController.clear(); + _pulzeAiToken = ''; }); }, ), diff --git a/app/lib/pulze_ai_api.dart b/app/lib/pulze_ai_api.dart new file mode 100644 index 0000000..f4cea70 --- /dev/null +++ b/app/lib/pulze_ai_api.dart @@ -0,0 +1,172 @@ +import 'dart:convert'; +import 'package:flutter/foundation.dart'; +import 'package:http/http.dart' as http; +import 'package:flutter_secure_storage/flutter_secure_storage.dart'; + +class PulzeAI_API { + static String apiKey = ''; + static String _template = '''[ + { + "generated_text": "***" + } +] +'''; + static String _model = 'google/flan-t5-large'; + static String _systemMessage = ''; + static bool _sendMessages = false; + static const String _apiKeyKey = 'huggingface_api_key'; + static const String _templateKey = 'huggingface_template'; + static const String _modelKey = 'huggingface_model'; + static const String _systemMessageKey = 'huggingface_system_message'; + static const String _sendMessagesKey = 'huggingface_send_messages'; + static final FlutterSecureStorage _secureStorage = FlutterSecureStorage(); + + static String oat() { + if (apiKey == '') { + loadOat(); + } + + return apiKey; + } + + static void setOat(String value) async { + apiKey = value; + await _secureStorage.write(key: _apiKeyKey, value: apiKey); + } + + static Future loadOat() async { + try { + apiKey = await _secureStorage.read(key: _apiKeyKey) ?? ''; + _template = await _secureStorage.read(key: _templateKey) ?? + '''[ + { + "generated_text": "***" + } +] +'''; + _model = + await _secureStorage.read(key: _modelKey) ?? 'google/flan-t5-large'; + _systemMessage = await _secureStorage.read(key: _systemMessageKey) ?? ''; + _sendMessages = + await _secureStorage.read(key: _sendMessagesKey) == 'true'; + } catch (e) { + if (kDebugMode) { + print('Error loading OAT: $e'); + } + } + } + + static bool sendMessages() { + return _sendMessages; + } + + static void setSendMessages(bool sendMessages) { + _sendMessages = sendMessages; + _secureStorage.write( + key: _sendMessagesKey, value: _sendMessages.toString()); + } + + static String systemMessage() { + return _systemMessage; + } + + static void setSystemMessage(systemMessage) { + _systemMessage = systemMessage; + _secureStorage.write(key: _systemMessageKey, value: _systemMessage); + } + + static String model() { + return _model; + } + + static void setModel(model) { + _model = model; + _secureStorage.write(key: _modelKey, value: _model); + } + + static String template() { + return _template; + } + + static void setTemplate(template) { + _template = template; + _secureStorage.write(key: _templateKey, value: _template); + } + + static String? _findValueByTemplate(dynamic value, dynamic template) { + if (value is List && template is List && value.length == template.length) { + for (int i = 0; i < value.length; i++) { + final result = _findValueByTemplate(value[i], template[i]); + if (result != null) { + return result; + } + } + } else if (value is Map && template is Map) { + for (var key in template.keys) { + final result = _findValueByTemplate(value[key], template[key]); + if (result != null) { + return result; + } + } + } else if (template == "***") { + return value; + } + + return null; + } + + static Future generate(String text) async { + return await _generate(_model, text, _template); + } + + // Examples: + // generate('facebook/bart-large-cnn', 'What\'s the best way to play a guitar?', '[{"summary_text": "***"}]'); + // generate('google/flan-t5-large', 'What\'s the best way to play a guitar?', '[{"generated_text": "***"}]'); + static Future _generate( + String modelId, String text, String template) async { + if (apiKey == '') { + return 'Please enter your Pulze AI Access Token in the settings.'; + } + + final queryUrl = 'https://api.pulze.ai/v1/completions/'; + final headers = { + 'Content-Type': 'application/json', + 'Authorization': 'Bearer $apiKey', + 'Pulze-Labels': '{"hello": "world"}', // Added Pulze-Labels header + }; + + final body = jsonEncode({ + 'model': modelId, // Specify the model + 'prompt': _systemMessage == '' ? text : text + ' [System message]: ' + _systemMessage, + 'max_tokens': 7, // Added max_tokens + 'temperature': 0, // Added temperature + }); + + if (kDebugMode) { + print('Request URL: $queryUrl'); + } + + final response = + await http.post(Uri.parse(queryUrl), headers: headers, body: body); + + if (kDebugMode) { + print('Response Status Code: ${response.statusCode}'); + print('Response Body: ${response.body}'); + } + + if (response.statusCode == 200) { + final jsonResponse = jsonDecode(response.body); + final templateJson = jsonDecode(template); + final generatedText = _findValueByTemplate(jsonResponse, templateJson); + + if (kDebugMode) { + print('Generated Text: $generatedText'); + } + + return generatedText; + } else { + return 'Sorry, there was an error processing your request. Please try again later.'; + } +} + +}