diff --git a/backend/dataall/base/db/exceptions.py b/backend/dataall/base/db/exceptions.py
index 95c2c2a73..3a3371cfe 100644
--- a/backend/dataall/base/db/exceptions.py
+++ b/backend/dataall/base/db/exceptions.py
@@ -181,3 +181,26 @@ def __init__(self, action, message):
def __str__(self):
return f'{self.message}'
+
+
+class ResourceThresholdExceeded(Exception):
+ def __init__(self, username, action):
+ self.username = username
+ self.action = action
+ self.message = f"""
+ An error occurred (ResourceThresholdExceeded) when calling {self.action} operation:
+ Requests exceeded max daily invocation count for User: {self.username}
+ """
+
+ def __str__(self):
+ return f'{self.message}'
+
+
+class ModelGuardrailException(Exception):
+ def __init__(self, message):
+ self.message = f"""
+ An error occurred (ModelGuardrailException) when invoking the model: {message}
+ """
+
+ def __str__(self):
+ return f'{self.message}'
diff --git a/backend/dataall/core/resource_threshold/__init__.py b/backend/dataall/core/resource_threshold/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/backend/dataall/core/resource_threshold/db/resource_threshold_models.py b/backend/dataall/core/resource_threshold/db/resource_threshold_models.py
new file mode 100644
index 000000000..ba5e9bece
--- /dev/null
+++ b/backend/dataall/core/resource_threshold/db/resource_threshold_models.py
@@ -0,0 +1,12 @@
+from dataall.base.db import Base, utils
+from sqlalchemy import String, Integer, Column, Date
+from datetime import date
+
+
+class ResourceThreshold(Base):
+ __tablename__ = 'resource_threshold'
+ actionUri = Column(String(64), primary_key=True, default=utils.uuid('resource_threshold'))
+ username = Column(String(64), nullable=False)
+ actionType = Column(String(64), nullable=False)
+ date = Column(Date, default=date.today, nullable=False)
+ count = Column(Integer, default=1, nullable=False)
diff --git a/backend/dataall/core/resource_threshold/db/resource_threshold_repositories.py b/backend/dataall/core/resource_threshold/db/resource_threshold_repositories.py
new file mode 100644
index 000000000..90bac8c35
--- /dev/null
+++ b/backend/dataall/core/resource_threshold/db/resource_threshold_repositories.py
@@ -0,0 +1,56 @@
+from dataall.core.resource_threshold.db.resource_threshold_models import ResourceThreshold
+from sqlalchemy import and_
+from datetime import date
+
+
+class ResourceThresholdRepository:
+ @staticmethod
+ def get_count_today(session, username, action_type):
+ amount = (
+ session.query(ResourceThreshold.count)
+ .filter(
+ and_(
+ ResourceThreshold.username == username,
+ ResourceThreshold.actionType == action_type,
+ ResourceThreshold.date == date.today(),
+ )
+ )
+ .scalar()
+ )
+ return amount if amount else 0
+
+ @staticmethod
+ def add_entry(session, username, action_type):
+ user_entry = ResourceThresholdRepository._get_user_entry(session, username, action_type)
+ if user_entry:
+ session.query(ResourceThreshold).filter(
+ and_(
+ ResourceThreshold.username == username,
+ ResourceThreshold.actionType == action_type,
+ )
+ ).update({ResourceThreshold.count: 1, ResourceThreshold.date: date.today()}, synchronize_session=False)
+ session.commit()
+ else:
+ action_entry = ResourceThreshold(username=username, actionType=action_type)
+ session.add(action_entry)
+ session.commit()
+
+ @staticmethod
+ def increment_count(session, username, action_type):
+ session.query(ResourceThreshold).filter(
+ and_(
+ ResourceThreshold.username == username,
+ ResourceThreshold.actionType == action_type,
+ ResourceThreshold.date == date.today(),
+ )
+ ).update({ResourceThreshold.count: ResourceThreshold.count + 1}, synchronize_session=False)
+ session.commit()
+
+ @staticmethod
+ def _get_user_entry(session, username, action_type):
+ entry = (
+ session.query(ResourceThreshold)
+ .filter(and_(ResourceThreshold.username == username, ResourceThreshold.actionType == action_type))
+ .first()
+ )
+ return entry
diff --git a/backend/dataall/core/resource_threshold/services/__init__.py b/backend/dataall/core/resource_threshold/services/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/backend/dataall/core/resource_threshold/services/resource_threshold_service.py b/backend/dataall/core/resource_threshold/services/resource_threshold_service.py
new file mode 100644
index 000000000..53574ae14
--- /dev/null
+++ b/backend/dataall/core/resource_threshold/services/resource_threshold_service.py
@@ -0,0 +1,42 @@
+from dataall.core.resource_threshold.db.resource_threshold_repositories import ResourceThresholdRepository
+from dataall.base.db import exceptions
+from functools import wraps
+from dataall.base.config import config
+from dataall.base.context import get_context
+
+import logging
+
+log = logging.getLogger(__name__)
+
+
+class ResourceThresholdService:
+ @staticmethod
+ def check_invocation_count(action_type, max_count_config_path):
+ def decorator(func):
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ context = get_context()
+ with context.db_engine.scoped_session() as session:
+ count = ResourceThresholdRepository.get_count_today(
+ session=session, username=context.username, action_type=action_type
+ )
+ max_count = config.get_property(max_count_config_path, 10)
+ log.info(
+ f'User {context.username} has invoked {action_type} {count} times today of max {max_count}'
+ )
+ if count < max_count:
+ if count == 0:
+ ResourceThresholdRepository.add_entry(
+ session=session, username=context.username, action_type=action_type
+ )
+ else:
+ ResourceThresholdRepository.increment_count(
+ session=session, username=context.username, action_type=action_type
+ )
+ return func(*args, **kwargs)
+ else:
+ raise exceptions.ResourceThresholdExceeded(username=context.username, action=action_type)
+
+ return wrapper
+
+ return decorator
diff --git a/backend/dataall/modules/s3_datasets/api/dataset/queries.py b/backend/dataall/modules/s3_datasets/api/dataset/queries.py
index 5043b868d..aeb7b3589 100644
--- a/backend/dataall/modules/s3_datasets/api/dataset/queries.py
+++ b/backend/dataall/modules/s3_datasets/api/dataset/queries.py
@@ -2,6 +2,7 @@
from dataall.modules.s3_datasets.api.dataset.resolvers import (
get_dataset,
get_dataset_assume_role_url,
+ list_s3_object_keys,
get_file_upload_presigned_url,
list_datasets_owned_by_env_group,
)
@@ -45,3 +46,12 @@
resolver=list_datasets_owned_by_env_group,
test_scope='Dataset',
)
+
+listS3ObjectKeys = gql.QueryField(
+ name='listS3ObjectKeys',
+ type=gql.ArrayType(gql.String),
+ args=[
+ gql.Argument(name='datasetUri', type=gql.NonNullableType(gql.String)),
+ ],
+ resolver=list_s3_object_keys,
+)
diff --git a/backend/dataall/modules/s3_datasets/api/dataset/resolvers.py b/backend/dataall/modules/s3_datasets/api/dataset/resolvers.py
index 90f6fd3d9..4f4b7807f 100644
--- a/backend/dataall/modules/s3_datasets/api/dataset/resolvers.py
+++ b/backend/dataall/modules/s3_datasets/api/dataset/resolvers.py
@@ -156,6 +156,10 @@ def list_datasets_owned_by_env_group(
return DatasetService.list_datasets_owned_by_env_group(environmentUri, groupUri, filter)
+def list_s3_object_keys(context, source, datasetUri: str = None):
+ return DatasetService.list_s3_object_keys(uri=datasetUri)
+
+
class RequestValidator:
@staticmethod
def validate_creation_request(data):
diff --git a/backend/dataall/modules/s3_datasets/aws/s3_dataset_client.py b/backend/dataall/modules/s3_datasets/aws/s3_dataset_client.py
index 94db4d056..acab43ec9 100644
--- a/backend/dataall/modules/s3_datasets/aws/s3_dataset_client.py
+++ b/backend/dataall/modules/s3_datasets/aws/s3_dataset_client.py
@@ -73,3 +73,17 @@ def get_bucket_encryption(self) -> (str, str, str):
f'Data.all Environment Pivot Role does not have s3:GetEncryptionConfiguration Permission for {dataset.S3BucketName} bucket: {e}'
)
raise Exception(f'Cannot fetch the bucket encryption configuration for {dataset.S3BucketName}: {e}')
+
+ def list_object_keys(self, bucket_name):
+ try:
+ response = self._client.list_objects_v2(
+ Bucket=bucket_name,
+ )
+
+ def txt_or_pdf(s):
+ return s.split('.')[-1] in ['pdf', 'txt']
+
+ return [ob['Key'] for ob in response.get('Contents', []) if txt_or_pdf(ob['Key'])]
+ except ClientError as e:
+ logging.error(f'Failed to list objects in {bucket_name} : {e}')
+ raise e
diff --git a/backend/dataall/modules/s3_datasets/services/dataset_service.py b/backend/dataall/modules/s3_datasets/services/dataset_service.py
index 2e68eb951..8a2977154 100644
--- a/backend/dataall/modules/s3_datasets/services/dataset_service.py
+++ b/backend/dataall/modules/s3_datasets/services/dataset_service.py
@@ -38,6 +38,7 @@
DATASET_ALL,
DATASET_READ,
IMPORT_DATASET,
+ GET_DATASET,
)
from dataall.modules.s3_datasets.db.dataset_repositories import DatasetRepository
from dataall.modules.datasets_base.db.dataset_repositories import DatasetBaseRepository
@@ -556,3 +557,11 @@ def delete_dataset_term_links(session, dataset_uri):
for table_uri in tables:
GlossaryRepository.delete_glossary_terms_links(session, table_uri, 'DatasetTable')
GlossaryRepository.delete_glossary_terms_links(session, dataset_uri, 'Dataset')
+
+ @staticmethod
+ @ResourcePolicyService.has_resource_permission(GET_DATASET)
+ def list_s3_object_keys(uri):
+ with get_context().db_engine.scoped_session() as session:
+ dataset = DatasetRepository.get_dataset_by_uri(session, uri)
+
+ return S3DatasetClient(dataset).list_object_keys(dataset.S3BucketName)
diff --git a/backend/dataall/modules/worksheets/api/queries.py b/backend/dataall/modules/worksheets/api/queries.py
index e1b92a4a2..55fc1c35f 100644
--- a/backend/dataall/modules/worksheets/api/queries.py
+++ b/backend/dataall/modules/worksheets/api/queries.py
@@ -1,5 +1,11 @@
from dataall.base.api import gql
-from dataall.modules.worksheets.api.resolvers import get_worksheet, list_worksheets, run_sql_query
+from dataall.modules.worksheets.api.resolvers import (
+ get_worksheet,
+ list_worksheets,
+ run_sql_query,
+ text_to_sql,
+ analyze_text_genai,
+)
getWorksheet = gql.QueryField(
@@ -28,3 +34,29 @@
],
resolver=run_sql_query,
)
+
+TextToSQL = gql.QueryField(
+ name='textToSQL',
+ type=gql.String,
+ args=[
+ gql.Argument(name='worksheetUri', type=gql.NonNullableType(gql.String)),
+ gql.Argument(name='environmentUri', type=gql.NonNullableType(gql.String)),
+ gql.Argument(name='prompt', type=gql.NonNullableType(gql.String)),
+ gql.Argument(name='databaseName', type=gql.NonNullableType(gql.String)),
+ gql.Argument(name='tableNames', type=gql.ArrayType(gql.String)),
+ ],
+ resolver=text_to_sql,
+)
+
+analyzeTextDocument = gql.QueryField(
+ name='analyzeTextDocument',
+ type=gql.String,
+ args=[
+ gql.Argument(name='worksheetUri', type=gql.NonNullableType(gql.String)),
+ gql.Argument(name='environmentUri', type=gql.NonNullableType(gql.String)),
+ gql.Argument(name='prompt', type=gql.NonNullableType(gql.String)),
+ gql.Argument(name='datasetUri', type=gql.NonNullableType(gql.String)),
+ gql.Argument(name='key', type=gql.NonNullableType(gql.String)),
+ ],
+ resolver=analyze_text_genai,
+)
diff --git a/backend/dataall/modules/worksheets/api/resolvers.py b/backend/dataall/modules/worksheets/api/resolvers.py
index 450667217..fa2863c6e 100644
--- a/backend/dataall/modules/worksheets/api/resolvers.py
+++ b/backend/dataall/modules/worksheets/api/resolvers.py
@@ -3,6 +3,7 @@
from dataall.modules.worksheets.db.worksheet_models import Worksheet
from dataall.modules.worksheets.db.worksheet_repositories import WorksheetRepository
from dataall.modules.worksheets.services.worksheet_service import WorksheetService
+from dataall.base.feature_toggle_checker import is_feature_enabled
from dataall.base.api.context import Context
@@ -14,27 +15,19 @@ def create_worksheet(context: Context, source, input: dict = None):
if not input.get('label'):
raise exceptions.RequiredParameter('label')
- with context.engine.scoped_session() as session:
- return WorksheetService.create_worksheet(
- session=session,
- username=context.username,
- data=input,
- )
+ return WorksheetService.create_worksheet(
+ data=input,
+ )
-def update_worksheet(context: Context, source, worksheetUri: str = None, input: dict = None):
- with context.engine.scoped_session() as session:
- return WorksheetService.update_worksheet(
- session=session, username=context.username, uri=worksheetUri, data=input
- )
+def update_worksheet(context: Context, source, worksheetUri: str, input: dict = None):
+ return WorksheetService.update_worksheet(uri=worksheetUri, data=input)
-def get_worksheet(context: Context, source, worksheetUri: str = None):
- with context.engine.scoped_session() as session:
- return WorksheetService.get_worksheet(
- session=session,
- uri=worksheetUri,
- )
+def get_worksheet(context: Context, source, worksheetUri: str):
+ return WorksheetService.get_worksheet(
+ uri=worksheetUri,
+ )
def resolve_user_role(context: Context, source: Worksheet):
@@ -59,13 +52,47 @@ def list_worksheets(context, source, filter: dict = None):
)
-def run_sql_query(context: Context, source, environmentUri: str = None, worksheetUri: str = None, sqlQuery: str = None):
- with context.engine.scoped_session() as session:
- return WorksheetService.run_sql_query(
- session=session, uri=environmentUri, worksheetUri=worksheetUri, sqlQuery=sqlQuery
- )
+def run_sql_query(context: Context, source, environmentUri: str, worksheetUri: str, sqlQuery: str):
+ return WorksheetService.run_sql_query(uri=environmentUri, worksheetUri=worksheetUri, sqlQuery=sqlQuery)
-def delete_worksheet(context, source, worksheetUri: str = None):
- with context.engine.scoped_session() as session:
- return WorksheetService.delete_worksheet(session=session, uri=worksheetUri)
+def delete_worksheet(context, source, worksheetUri: str):
+ return WorksheetService.delete_worksheet(uri=worksheetUri)
+
+
+@is_feature_enabled('modules.worksheets.features.nlq.active')
+def text_to_sql(
+ context: Context,
+ source,
+ environmentUri: str,
+ worksheetUri: str,
+ prompt: str,
+ databaseName: str,
+ tableNames: list,
+):
+ return WorksheetService.run_nlq(
+ uri=environmentUri,
+ prompt=prompt,
+ worksheetUri=worksheetUri,
+ db_name=databaseName,
+ table_names=tableNames,
+ )
+
+
+@is_feature_enabled('modules.worksheets.features.nlq.active')
+def analyze_text_genai(
+ context,
+ source,
+ worksheetUri: str,
+ environmentUri: str,
+ prompt: str,
+ datasetUri: str,
+ key: str,
+):
+ return WorksheetService.analyze_text_genai(
+ uri=environmentUri,
+ worksheetUri=worksheetUri,
+ prompt=prompt,
+ datasetUri=datasetUri,
+ key=key,
+ )
diff --git a/backend/dataall/modules/worksheets/aws/bedrock_client.py b/backend/dataall/modules/worksheets/aws/bedrock_client.py
new file mode 100644
index 000000000..13a0e286f
--- /dev/null
+++ b/backend/dataall/modules/worksheets/aws/bedrock_client.py
@@ -0,0 +1,51 @@
+from dataall.base.aws.sts import SessionHelper
+from langchain_aws import ChatBedrock as BedrockChat
+from langchain_core.output_parsers import StrOutputParser
+from langchain_core.prompts import PromptTemplate
+from dataall.base.db import exceptions
+import os
+
+TEXT_TO_SQL_EXAMPLES_PATH = os.path.join(os.path.dirname(__file__), 'bedrock_prompts', 'text_to_sql_examples.txt')
+TEXT_TO_SQL_TEMPLATE_PATH = os.path.join(os.path.dirname(__file__), 'bedrock_prompts', 'test_to_sql_template.txt')
+PROCESS_TEXT_TEMPLATE_PATH = os.path.join(os.path.dirname(__file__), 'bedrock_prompts', 'process_text_template.txt')
+
+
+class BedrockClient:
+ def __init__(self):
+ self._session = SessionHelper.get_session()
+ self._client = self._session.client('bedrock-runtime')
+ model_id = 'anthropic.claude-3-5-sonnet-20240620-v1:0'
+ model_kwargs = {
+ 'max_tokens': 2048,
+ 'temperature': 0,
+ 'top_k': 250,
+ 'top_p': 1,
+ 'stop_sequences': ['\n\nHuman'],
+ }
+ self._model = BedrockChat(
+ client=self._client,
+ model_id=model_id,
+ model_kwargs=model_kwargs,
+ )
+
+ def invoke_model_text_to_sql(self, prompt: str, metadata: str):
+ prompt_template = PromptTemplate.from_file(TEXT_TO_SQL_TEMPLATE_PATH)
+ chain = prompt_template | self._model | StrOutputParser()
+
+ with open(TEXT_TO_SQL_EXAMPLES_PATH, 'r') as f:
+ examples = f.read()
+
+ response = chain.invoke({'prompt': prompt, 'context': metadata, 'examples': examples})
+ if response.startswith('Error:'):
+ raise exceptions.ModelGuardrailException(response)
+ return response
+
+ def invoke_model_process_text(self, prompt: str, content: str):
+ prompt_template = PromptTemplate.from_file(PROCESS_TEXT_TEMPLATE_PATH)
+
+ chain = prompt_template | self._model | StrOutputParser()
+ response = chain.invoke({'prompt': prompt, 'content': content})
+
+ if response.startswith('Error:'):
+ raise exceptions.ModelGuardrailException(response)
+ return response
diff --git a/backend/dataall/modules/worksheets/aws/bedrock_prompts/process_text_template.txt b/backend/dataall/modules/worksheets/aws/bedrock_prompts/process_text_template.txt
new file mode 100644
index 000000000..1b9d56fa9
--- /dev/null
+++ b/backend/dataall/modules/worksheets/aws/bedrock_prompts/process_text_template.txt
@@ -0,0 +1,14 @@
+You are an AI assistant tasked with analyzing and processing text content. Your goal is to provide accurate and helpful responses based on the given content and user prompt.
+You must follow the steps:
+
+1. Determine if the document has the information to be able to answer the question. If not respond with "Error: The Document does not provide the information needed to answer your question"
+2. I want you to answer the question based on the information in the document.
+3. At the bottom I want you to provide the sources (the parts of the document where you found the results). The sources should be listed in order
+
+
+Content to analyze:
+{content}
+
+User prompt: {prompt}
+
+Please provide a response that addresses the user's prompt in the context of the given content. Be thorough, accurate, and helpful in your analysis.
diff --git a/backend/dataall/modules/worksheets/aws/bedrock_prompts/test_to_sql_template.txt b/backend/dataall/modules/worksheets/aws/bedrock_prompts/test_to_sql_template.txt
new file mode 100644
index 000000000..c24f5f9f0
--- /dev/null
+++ b/backend/dataall/modules/worksheets/aws/bedrock_prompts/test_to_sql_template.txt
@@ -0,0 +1,38 @@
+You will be given the name of an AWS Glue Database, metadata from one or more AWS Glue Table(s) and a user prompt from a user.
+
+Based on this information your job is to turn the prompt into a SQL query that will be sent to query the data within the tables in Amazon Athena.
+
+Take the following points into consideration. It is crucial that you follow them:
+
+- I only want you to return the SQL needed (NO EXPLANATION or anything else).
+
+- Tables are referenced in the form 'database_name.table_name' (for example 'SELECT * FROM database_name.table_name ...' and not 'SELECT * FROM table_name ...') since we don't have access to the table name directly, as it is not a global variable.
+
+- Take relations between tables into consideration, for example if you have a table with columns that might reference the other tables, you would need to join them in the query.
+
+- The generated SQL statement MUST be read-only (no WRITE, INSERT, ALTER or DELETE keywords)
+
+- Answer in the same form as the examples given below.
+
+Examples:
+{examples}
+
+
+I want you to follow the following steps when generating the SQL statement:
+
+Step 1: Determine if the given tables columns are suitable to answer the question.
+If not respond with "Error: The tables provided does not give enough information"
+
+Step 2: Determine if the user wants to perform any mutations, if so return "Error: Only READ queries are allowed"
+
+Step 3: Determine if joins will be needed.
+
+Step 4: Generate the SQL in order to solve the problem.
+
+
+Based on the following metadata:
+{context}
+
+
+User prompt: {prompt}
+
diff --git a/backend/dataall/modules/worksheets/aws/bedrock_prompts/text_to_sql_examples.txt b/backend/dataall/modules/worksheets/aws/bedrock_prompts/text_to_sql_examples.txt
new file mode 100644
index 000000000..8fa9bd84c
--- /dev/null
+++ b/backend/dataall/modules/worksheets/aws/bedrock_prompts/text_to_sql_examples.txt
@@ -0,0 +1,49 @@
+Example 1.
+User prompt: I want to get the average area of all listings
+
+Context: Based on the following metadata
+Database Name : dataall_homes_11p3uu8f
+Table Name: listings
+Column Metadata: [{'Name': 'price', 'Type': 'bigint'}, {'Name': 'area', 'Type': 'bigint'}, {'Name': 'bedrooms', 'Type': 'bigint'}, {'Name': 'bathrooms', 'Type': 'bigint'}, {'Name': 'stories', 'Type': 'bigint'}, {'Name': 'mainroad', 'Type': 'string'}, {'Name': 'guestroom', 'Type': 'string'}, {'Name': 'basement', 'Type': 'string'}, {'Name': 'hotwaterheating', 'Type': 'string'}, {'Name': 'airconditioning', 'Type': 'string'}, {'Name': 'parking', 'Type': 'bigint'}, {'Name': 'prefarea', 'Type': 'string'}, {'Name': 'furnishingstatus', 'Type': 'string'}, {'Name': 'passengerid', 'Type': 'bigint'}, {'Name': 'survived', 'Type': 'bigint'}, {'Name': 'pclass', 'Type': 'bigint'}, {'Name': 'name', 'Type': 'string'}, {'Name': 'sex', 'Type': 'string'}, {'Name': 'age', 'Type': 'double'}, {'Name': 'sibsp', 'Type': 'bigint'}, {'Name': 'parch', 'Type': 'bigint'}, {'Name': 'ticket', 'Type': 'string'}, {'Name': 'fare', 'Type': 'double'}, {'Name': 'cabin', 'Type': 'string'}, {'Name': 'embarked', 'Type': 'string'}]
+Partition Metadata: []
+
+Response: SELECT AVG(CAST(area AS DOUBLE)) FROM dataall_homes_11p3uu8f.listings WHERE area IS NOT NULL;
+
+
+Example 2.
+User prompt: I want to get the average of the 3 most expensive listings with less than 3 bedrooms
+
+Context: Based on the following metadata
+Database Name : dataall_homes_11p3uu8f
+Table Name: listings
+Column Metadata: [{'Name': 'price', 'Type': 'bigint'}, {'Name': 'area', 'Type': 'bigint'}, {'Name': 'bedrooms', 'Type': 'bigint'}, {'Name': 'bathrooms', 'Type': 'bigint'}, {'Name': 'stories', 'Type': 'bigint'}, {'Name': 'mainroad', 'Type': 'string'}, {'Name': 'guestroom', 'Type': 'string'}, {'Name': 'basement', 'Type': 'string'}, {'Name': 'hotwaterheating', 'Type': 'string'}, {'Name': 'airconditioning', 'Type': 'string'}, {'Name': 'parking', 'Type': 'bigint'}, {'Name': 'prefarea', 'Type': 'string'}, {'Name': 'furnishingstatus', 'Type': 'string'}, {'Name': 'passengerid', 'Type': 'bigint'}, {'Name': 'survived', 'Type': 'bigint'}, {'Name': 'pclass', 'Type': 'bigint'}, {'Name': 'name', 'Type': 'string'}, {'Name': 'sex', 'Type': 'string'}, {'Name': 'age', 'Type': 'double'}, {'Name': 'sibsp', 'Type': 'bigint'}, {'Name': 'parch', 'Type': 'bigint'}, {'Name': 'ticket', 'Type': 'string'}, {'Name': 'fare', 'Type': 'double'}, {'Name': 'cabin', 'Type': 'string'}, {'Name': 'embarked', 'Type': 'string'}]
+Partition Metadata: []
+
+Response: SELECT AVG(price) AS average_price FROM (SELECT price FROM dataall_homes_11p3uu8f.listings WHERE bedrooms < 3 ORDER BY price DESC LIMIT 3);
+
+
+Example 3.
+User prompt: I want to see if any letter has been sent from 900 Somerville Avenue to 2 Finnigan Street and what is the content
+
+Context: Based on the following metadata
+Database Name : dataall_packages_omf768qq
+Table name: packages
+Column Metadata: [{'Name': 'id', 'Type': 'bigint'}, {'Name': 'contents', 'Type': 'string'}, {'Name': 'from_address_id', 'Type': 'bigint'}, {'Name': 'to_address_id', 'Type': 'bigint'}]
+Partition Metadata: []
+
+Database Name : dataall_packages_omf768qq
+Table name: addresses
+Column Metadata: [{'Name': 'id', 'Type': 'bigint'}, {'Name': 'address', 'Type': 'string'}, {'Name': 'type', 'Type': 'string'}]
+Partition Metadata: []
+
+Database Name : dataall_packages_omf768qq
+Table name: drivers
+Column Metadata: [{'Name': 'id', 'Type': 'bigint'}, {'Name': 'name', 'Type': 'string'}]
+Partition Metadata: []
+
+Database Name : dataall_packages_omf768qq
+Table name: scans
+Column Metadata: [{'Name': 'id', 'Type': 'bigint'}, {'Name': 'driver_id', 'Type': 'bigint'}, {'Name': 'package_id', 'Type': 'bigint'}, {'Name': 'address_id', 'Type': 'bigint'}, {'Name': 'action', 'Type': 'string'}, {'Name': 'timestamp', 'Type': 'string'}]
+Partition Metadata: []
+
+Response: SELECT p.contents FROM dataall_packages_omf768qq.packages p JOIN dataall_packages_omf768qq.addresses a1 ON p.from_address_id = a1.id JOIN dataall_packages_omf768qq.addresses a2 ON p.to_address_id = a2.id WHERE a1.address = '900 Somerville Avenue' AND a2.address = '2 Finnigan Street';
diff --git a/backend/dataall/modules/worksheets/aws/glue_client.py b/backend/dataall/modules/worksheets/aws/glue_client.py
new file mode 100644
index 000000000..d123a7662
--- /dev/null
+++ b/backend/dataall/modules/worksheets/aws/glue_client.py
@@ -0,0 +1,36 @@
+import logging
+
+
+from botocore.exceptions import ClientError
+
+from dataall.base.aws.sts import SessionHelper
+
+log = logging.getLogger(__name__)
+
+
+class GlueClient:
+ def __init__(self, account_id, region, role=None):
+ pivot_role_session = SessionHelper.remote_session(accountid=account_id, region=region)
+ aws_session = (
+ SessionHelper.get_session(base_session=pivot_role_session, role_arn=role) if role else pivot_role_session
+ )
+ self._client = aws_session.client('glue', region_name=region)
+ self._account_id = account_id
+ self._region = region
+
+ def get_table_metadata(self, database, table_name):
+ try:
+ table_metadata = self._client.get_table(DatabaseName=database, Name=table_name)
+ table_name = table_metadata['Table']['Name']
+ column_metadata = table_metadata['Table']['StorageDescriptor']['Columns']
+ partition_metadata = table_metadata['Table']['PartitionKeys']
+ meta_data = f"""
+ Database Name: {database}
+ Table Name: {table_name}
+ Column Metadata: {column_metadata}
+ Partition Metadata: {partition_metadata}
+ """
+ except ClientError as e:
+ log.error(f'Error fetching metadata for {table_name=}: {e}')
+ raise e
+ return meta_data
diff --git a/backend/dataall/modules/worksheets/aws/s3_client.py b/backend/dataall/modules/worksheets/aws/s3_client.py
new file mode 100644
index 000000000..d5a7ba981
--- /dev/null
+++ b/backend/dataall/modules/worksheets/aws/s3_client.py
@@ -0,0 +1,53 @@
+import logging
+import pypdf
+from io import BytesIO
+
+from dataall.base.aws.sts import SessionHelper
+from botocore.exceptions import ClientError
+from dataall.base.db import exceptions
+
+logger = logging.getLogger(__name__)
+
+
+class S3Client:
+ file_extension_readers = {
+ 'txt': lambda content: S3Client._read_txt_content(content),
+ 'pdf': lambda content: S3Client._read_pdf_content(content),
+ }
+
+ def __init__(self, account_id, region, role=None):
+ pivot_role_session = SessionHelper.remote_session(accountid=account_id, region=region)
+ aws_session = (
+ SessionHelper.get_session(base_session=pivot_role_session, role_arn=role) if role else pivot_role_session
+ )
+ self._client = aws_session.client('s3', region_name=region)
+
+ @staticmethod
+ def _read_txt_content(content):
+ file_content = content['Body'].read().decode('utf-8')
+ return file_content
+
+ @staticmethod
+ def _read_pdf_content(content):
+ pdf_content = content['Body'].read()
+ pdf_buffer = BytesIO(pdf_content)
+ pdf_reader = pypdf.PdfReader(pdf_buffer)
+ full_text = ''
+ for page_num in range(len(pdf_reader.pages)):
+ page = pdf_reader.pages[page_num]
+ full_text += page.extract_text()
+ return full_text
+
+ def get_content(self, bucket_name, key):
+ try:
+ file_extension = key.split('.')[-1].lower()
+ if file_extension not in self.file_extension_readers.keys():
+ raise exceptions.InvalidInput('S3 Object Key', key, '.txt or .pdf file extensions only')
+
+ content = self._client.get_object(Bucket=bucket_name, Key=key)
+
+ return self.file_extension_readers[file_extension](content)
+
+ except ClientError as e:
+ logging.error(f'Failed to get content of {key} in {bucket_name} : {e}')
+ raise e
diff --git a/backend/dataall/modules/worksheets/services/worksheet_service.py b/backend/dataall/modules/worksheets/services/worksheet_service.py
index b72efa367..d5ecb92de 100644
--- a/backend/dataall/modules/worksheets/services/worksheet_service.py
+++ b/backend/dataall/modules/worksheets/services/worksheet_service.py
@@ -1,5 +1,10 @@
import logging
+from dataall.core.resource_threshold.services.resource_threshold_service import ResourceThresholdService
+from dataall.modules.worksheets.aws.glue_client import GlueClient
+from dataall.modules.worksheets.aws.s3_client import S3Client
+from dataall.modules.s3_datasets.db.dataset_repositories import DatasetRepository
+from dataall.modules.worksheets.aws.bedrock_client import BedrockClient
from dataall.core.activity.db.activity_models import Activity
from dataall.core.environment.services.environment_service import EnvironmentService
from dataall.base.db import exceptions
@@ -8,6 +13,7 @@
from dataall.modules.worksheets.aws.athena_client import AthenaClient
from dataall.modules.worksheets.db.worksheet_models import Worksheet
from dataall.modules.worksheets.db.worksheet_repositories import WorksheetRepository
+from dataall.base.context import get_context
from dataall.modules.worksheets.services.worksheet_permissions import (
MANAGE_WORKSHEETS,
UPDATE_WORKSHEET,
@@ -23,7 +29,7 @@
class WorksheetService:
@staticmethod
- def get_worksheet_by_uri(session, uri: str) -> Worksheet:
+ def _get_worksheet_by_uri(session, uri: str) -> Worksheet:
if not uri:
raise exceptions.RequiredParameter(param_name='worksheetUri')
worksheet = WorksheetRepository.find_worksheet_by_uri(session, uri)
@@ -33,85 +39,107 @@ def get_worksheet_by_uri(session, uri: str) -> Worksheet:
@staticmethod
@TenantPolicyService.has_tenant_permission(MANAGE_WORKSHEETS)
- def create_worksheet(session, username, data=None) -> Worksheet:
- worksheet = Worksheet(
- owner=username,
- label=data.get('label'),
- description=data.get('description', 'No description provided'),
- tags=data.get('tags'),
- chartConfig={'dimensions': [], 'measures': [], 'chartType': 'bar'},
- SamlAdminGroupName=data['SamlAdminGroupName'],
- )
-
- session.add(worksheet)
- session.commit()
-
- activity = Activity(
- action='WORKSHEET:CREATE',
- label='WORKSHEET:CREATE',
- owner=username,
- summary=f'{username} created worksheet {worksheet.name} ',
- targetUri=worksheet.worksheetUri,
- targetType='worksheet',
- )
- session.add(activity)
-
- ResourcePolicyService.attach_resource_policy(
- session=session,
- group=data['SamlAdminGroupName'],
- permissions=WORKSHEET_ALL,
- resource_uri=worksheet.worksheetUri,
- resource_type=Worksheet.__name__,
- )
- return worksheet
+ def create_worksheet(data=None) -> Worksheet:
+ context = get_context()
+ with context.db_engine.scoped_session() as session:
+ worksheet = Worksheet(
+ owner=context.username,
+ label=data.get('label'),
+ description=data.get('description', 'No description provided'),
+ tags=data.get('tags'),
+ chartConfig={'dimensions': [], 'measures': [], 'chartType': 'bar'},
+ SamlAdminGroupName=data['SamlAdminGroupName'],
+ )
+
+ session.add(worksheet)
+ session.commit()
+
+ activity = Activity(
+ action='WORKSHEET:CREATE',
+ label='WORKSHEET:CREATE',
+ owner=context.username,
+ summary=f'{context.username} created worksheet {worksheet.name} ',
+ targetUri=worksheet.worksheetUri,
+ targetType='worksheet',
+ )
+ session.add(activity)
+
+ ResourcePolicyService.attach_resource_policy(
+ session=session,
+ group=data['SamlAdminGroupName'],
+ permissions=WORKSHEET_ALL,
+ resource_uri=worksheet.worksheetUri,
+ resource_type=Worksheet.__name__,
+ )
+ return worksheet
@staticmethod
+ @TenantPolicyService.has_tenant_permission(MANAGE_WORKSHEETS)
@ResourcePolicyService.has_resource_permission(UPDATE_WORKSHEET)
- def update_worksheet(session, username, uri, data=None):
- worksheet = WorksheetService.get_worksheet_by_uri(session, uri)
- for field in data.keys():
- setattr(worksheet, field, data.get(field))
- session.commit()
-
- activity = Activity(
- action='WORKSHEET:UPDATE',
- label='WORKSHEET:UPDATE',
- owner=username,
- summary=f'{username} updated worksheet {worksheet.name} ',
- targetUri=worksheet.worksheetUri,
- targetType='worksheet',
- )
- session.add(activity)
- return worksheet
+ def update_worksheet(uri, data=None):
+ context = get_context()
+ with context.db_engine.scoped_session() as session:
+ worksheet = WorksheetService._get_worksheet_by_uri(session, uri)
+ for field in data.keys():
+ setattr(worksheet, field, data.get(field))
+ session.commit()
+
+ activity = Activity(
+ action='WORKSHEET:UPDATE',
+ label='WORKSHEET:UPDATE',
+ owner=context.username,
+ summary=f'{context.username} updated worksheet {worksheet.name} ',
+ targetUri=worksheet.worksheetUri,
+ targetType='worksheet',
+ )
+ session.add(activity)
+ return worksheet
@staticmethod
@ResourcePolicyService.has_resource_permission(GET_WORKSHEET)
- def get_worksheet(session, uri):
- worksheet = WorksheetService.get_worksheet_by_uri(session, uri)
- return worksheet
+ def get_worksheet(uri):
+ with get_context().db_engine.scoped_session() as session:
+ worksheet = WorksheetService._get_worksheet_by_uri(session, uri)
+ return worksheet
+
+ @staticmethod
+ def list_worksheets(filter):
+ context = get_context()
+ with context.db_engine.scoped_session() as session:
+ return WorksheetRepository.paginated_user_worksheets(
+ session=session,
+ username=context.username,
+ groups=context.groups,
+ uri=None,
+ data=filter,
+ check_perm=True,
+ )
@staticmethod
+ @TenantPolicyService.has_tenant_permission(MANAGE_WORKSHEETS)
@ResourcePolicyService.has_resource_permission(DELETE_WORKSHEET)
- def delete_worksheet(session, uri) -> bool:
- worksheet = WorksheetService.get_worksheet_by_uri(session, uri)
- session.delete(worksheet)
- ResourcePolicyService.delete_resource_policy(
- session=session,
- group=worksheet.SamlAdminGroupName,
- resource_uri=uri,
- resource_type=Worksheet.__name__,
- )
- return True
+ def delete_worksheet(uri) -> bool:
+ with get_context().db_engine.scoped_session() as session:
+ worksheet = WorksheetService._get_worksheet_by_uri(session, uri)
+ session.delete(worksheet)
+ ResourcePolicyService.delete_resource_policy(
+ session=session,
+ group=worksheet.SamlAdminGroupName,
+ resource_uri=uri,
+ resource_type=Worksheet.__name__,
+ )
+ return True
@staticmethod
@ResourcePolicyService.has_resource_permission(RUN_ATHENA_QUERY)
- def run_sql_query(session, uri, worksheetUri, sqlQuery):
- environment = EnvironmentService.get_environment_by_uri(session, uri)
- worksheet = WorksheetService.get_worksheet_by_uri(session, worksheetUri)
+ def run_sql_query(uri, worksheetUri, sqlQuery):
+ with get_context().db_engine.scoped_session() as session:
+ environment = EnvironmentService.get_environment_by_uri(session, uri)
+ worksheet = WorksheetService._get_worksheet_by_uri(session, worksheetUri)
- env_group = EnvironmentService.get_environment_group(
- session, worksheet.SamlAdminGroupName, environment.environmentUri
- )
+ env_group = EnvironmentService.get_environment_group(
+ session, worksheet.SamlAdminGroupName, environment.environmentUri
+ )
cursor = AthenaClient.run_athena_query(
aws_account_id=environment.AwsAccountId,
@@ -122,3 +150,48 @@ def run_sql_query(session, uri, worksheetUri, sqlQuery):
)
return AthenaClient.convert_query_output(cursor)
+
+ @staticmethod
+ @ResourcePolicyService.has_resource_permission(RUN_ATHENA_QUERY)
+ @ResourceThresholdService.check_invocation_count('nlq', 'modules.worksheets.features.nlq.max_count_per_day')
+ def run_nlq(uri, prompt, worksheetUri, db_name, table_names):
+ with get_context().db_engine.scoped_session() as session:
+ environment = EnvironmentService.get_environment_by_uri(session, uri)
+ worksheet = WorksheetService._get_worksheet_by_uri(session, worksheetUri)
+
+ env_group = EnvironmentService.get_environment_group(
+ session, worksheet.SamlAdminGroupName, environment.environmentUri
+ )
+
+ glue_client = GlueClient(
+ account_id=environment.AwsAccountId, region=environment.region, role=env_group.environmentIAMRoleArn
+ )
+
+ metadata = []
+ for table in table_names:
+ metadata.append(glue_client.get_table_metadata(database=db_name, table_name=table))
+
+ return BedrockClient().invoke_model_text_to_sql(prompt, '\n'.join(metadata))
+
+ @staticmethod
+ @ResourcePolicyService.has_resource_permission(RUN_ATHENA_QUERY)
+ @ResourceThresholdService.check_invocation_count('nlq', 'modules.worksheets.features.nlq.max_count_per_day')
+ def analyze_text_genai(uri, worksheetUri, prompt, datasetUri, key):
+ with get_context().db_engine.scoped_session() as session:
+ environment = EnvironmentService.get_environment_by_uri(session, uri)
+ worksheet = WorksheetService._get_worksheet_by_uri(session, worksheetUri)
+
+ env_group = EnvironmentService.get_environment_group(
+ session, worksheet.SamlAdminGroupName, environment.environmentUri
+ )
+
+ dataset = DatasetRepository.get_dataset_by_uri(session, datasetUri)
+
+ s3_client = S3Client(
+ account_id=environment.AwsAccountId,
+ region=environment.region,
+ role=env_group.environmentIAMRoleArn,
+ )
+
+ content = s3_client.get_content(dataset.S3BucketName, key)
+ return BedrockClient().invoke_model_process_text(prompt, content)
diff --git a/backend/migrations/env.py b/backend/migrations/env.py
index 2816677cc..84b6a1094 100644
--- a/backend/migrations/env.py
+++ b/backend/migrations/env.py
@@ -20,6 +20,7 @@
from dataall.modules.omics.db.omics_models import OmicsWorkflow, OmicsRun
from dataall.modules.metadata_forms.db.metadata_form_models import *
from dataall.modules.redshift_datasets.db.redshift_models import RedshiftDataset, RedshiftTable, RedshiftConnection
+from dataall.core.resource_threshold.db.resource_threshold_models import ResourceThreshold
# fmt: on
# enable ruff-format back
diff --git a/backend/migrations/versions/2258cd8d6e9f_resource_threshholds_added.py b/backend/migrations/versions/2258cd8d6e9f_resource_threshholds_added.py
new file mode 100644
index 000000000..1a4f76741
--- /dev/null
+++ b/backend/migrations/versions/2258cd8d6e9f_resource_threshholds_added.py
@@ -0,0 +1,37 @@
+"""resource_threshholds_added
+
+Revision ID: 2258cd8d6e9f
+Revises: 5a798acc6282
+Create Date: 2024-08-22 12:31:38.465650
+
+"""
+
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = '2258cd8d6e9f'
+down_revision = '5a798acc6282'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.create_table(
+ 'resource_threshold',
+ sa.Column('actionUri', sa.String(length=64), nullable=False),
+ sa.Column('username', sa.String(length=64), nullable=False),
+ sa.Column('actionType', sa.String(length=64), nullable=False),
+ sa.Column('date', sa.Date(), nullable=False),
+ sa.Column('count', sa.Integer(), nullable=False),
+ sa.PrimaryKeyConstraint('actionUri'),
+ )
+ # ### end Alembic commands ###
+
+
+def downgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.drop_table('resource_threshold')
+ # ### end Alembic commands ###
diff --git a/backend/requirements.txt b/backend/requirements.txt
index 05cb6619c..7048f6429 100644
--- a/backend/requirements.txt
+++ b/backend/requirements.txt
@@ -13,3 +13,6 @@ requests_aws4auth==1.1.1
sqlalchemy==1.3.24
alembic==1.13.1
retrying==1.3.4
+langchain-aws==0.2.2
+langchain-core==0.3.11
+pypdf==5.0.1
\ No newline at end of file
diff --git a/config.json b/config.json
index e3af66063..4461f8274 100644
--- a/config.json
+++ b/config.json
@@ -77,7 +77,13 @@
"active": true
},
"worksheets": {
- "active": true
+ "active": true,
+ "features": {
+ "nlq": {
+ "active": true,
+ "max_count_per_day": 25
+ }
+ }
},
"dashboards": {
"active": true
diff --git a/deploy/stacks/lambda_api.py b/deploy/stacks/lambda_api.py
index 7a32dc62f..a76d0f040 100644
--- a/deploy/stacks/lambda_api.py
+++ b/deploy/stacks/lambda_api.py
@@ -30,6 +30,7 @@
from .pyNestedStack import pyNestedClass
from .solution_bundling import SolutionBundling
from .waf_rules import get_waf_rules
+from .run_if import run_if
class LambdaApiStack(pyNestedClass):
@@ -154,7 +155,9 @@ def __init__(
retention=getattr(logs.RetentionDays, self.log_retention_duration),
),
description='dataall graphql function',
- role=self.create_function_role(envname, resource_prefix, 'graphql', pivot_role_name, vpc),
+ role=self.create_function_role(
+ envname, resource_prefix, 'graphql', pivot_role_name, vpc, self._get_bedrock_policy_statement() or []
+ ),
code=_lambda.DockerImageCode.from_ecr(
repository=ecr_repository, tag=image_tag, cmd=['api_handler.handler']
),
@@ -345,7 +348,25 @@ def create_lambda_sgs(self, envname, name, resource_prefix, vpc):
)
return lambda_sg
- def create_function_role(self, envname, resource_prefix, fn_name, pivot_role_name, vpc):
+ @run_if(['modules.worksheets.features.nlq.active'])
+ def _get_bedrock_policy_statement(self):
+ return [
+ iam.PolicyStatement(
+ actions=[
+ 'bedrock:InvokeModel',
+ 'bedrock:GetPrompt',
+ 'bedrock:CreateFoundationModelAgreement',
+ 'bedrock:InvokeFlow',
+ ],
+ resources=[
+ f'arn:aws:bedrock:{self.region}:{self.account}:flow/*',
+ f'arn:aws:bedrock:{self.region}:{self.account}:prompt/*',
+ f'arn:aws:bedrock:{self.region}::foundation-model/*',
+ ],
+ )
+ ]
+
+ def create_function_role(self, envname, resource_prefix, fn_name, pivot_role_name, vpc, extra_statements=[]):
role_name = f'{resource_prefix}-{envname}-{fn_name}-role'
role_inline_policy = iam.Policy(
@@ -473,7 +494,8 @@ def create_function_role(self, envname, resource_prefix, fn_name, pivot_role_nam
actions=['events:EnableRule', 'events:DisableRule'],
resources=[f'arn:aws:events:{self.region}:{self.account}:rule/dataall*'],
),
- ],
+ ]
+ + extra_statements,
)
role = iam.Role(
self,
diff --git a/documentation/userguide/docs/pictures/worksheets/ws_analyze_txt_doc.png b/documentation/userguide/docs/pictures/worksheets/ws_analyze_txt_doc.png
new file mode 100644
index 000000000..ba000b7bb
Binary files /dev/null and b/documentation/userguide/docs/pictures/worksheets/ws_analyze_txt_doc.png differ
diff --git a/documentation/userguide/docs/pictures/worksheets/ws_text_to_sql.png b/documentation/userguide/docs/pictures/worksheets/ws_text_to_sql.png
new file mode 100644
index 000000000..26a12c211
Binary files /dev/null and b/documentation/userguide/docs/pictures/worksheets/ws_text_to_sql.png differ
diff --git a/documentation/userguide/docs/worksheets.md b/documentation/userguide/docs/worksheets.md
index 006cdcf70..bfdd9bfdc 100644
--- a/documentation/userguide/docs/worksheets.md
+++ b/documentation/userguide/docs/worksheets.md
@@ -45,7 +45,7 @@ coming from Athena will pop-up automatically.

-If you want to save the current query for later or for other users, click on the **save** icon (between the edit and the
+If you want to save the current query for later or for other users, click on the **save** icon (next to the edit and
delete buttons).
!!! success "More than just SELECT"
@@ -55,4 +55,56 @@ delete buttons).
for more information on AWS Athena SQL syntax.
+## :material-new-box: **Experimental Features: GenAI Powered Worksheets**
+
+As part of data.all >= v2.7 we introduced support for generative AI powered worksheet features. These features include both:
+
+1. Natural Language Querying (NLQ) of Structured Data
+2. Text Document Analysis of Unstructured Data
+
+These features are optionally enabled/disabled via feature flags specified in data.all's configuration.
+
+More details on how to use each of these features are below.
+
+### Natural Language Querying (NLQ) of Structured Data
+
+data.all offers an NLQ feature to significantly reduce the barrier to entry for non-technical business users who need to quickly and easily query data to make informed decisions.
+
+Given a prompt and a selection of tables, data.all NLQ feature will generate the corresponding SQL statement that data.all users can execute against the data they have access to in data.all's Worksheets module.
+
+To start generating SQL, data.all users can select the TextToSQL Tab in the Worksheets View:
+
+
+
+Users select the Worksheet environment, database and one or more tables where the data of interest is stored. Then they introduce a prompt describing the operation they want to perform. For example, they could type something like "Give me the top 3 clients in the last 10 months". Once they send the request to generate the query, data.all will invoke the Claude 3.5 Sonnet model using Amazon Bedrock to generate a response.
+
+To enrich the context of the genAI request, data.all fetches the Glue metadata of the tables and database and passes it to the LLM. Access to Glue is limited to the tables the user has access to, in other words, we control that only accessible Glue tables are fetched.
+
+In addition, there are built-in guardrails to avoid mutating SQL statements (e.g. WRITE, UPSERT, DELETE, etc.).
+
+data.all Admins can additionally limit the number of invocations run against these LLMs by specifying a `max_count_per_day` feature flag in data.all's configuration (please reference data.all's [Deployment Guide](https://data-dot-all.github.io/dataall/deploy-aws/#configjson) for more information).
+
+
+### Text Document Analysis of Unstructured Data
+
+For unstructured text documents, data.all offers a feature to start analyzing your data using natural language.
+
+Given a prompt and a selected text document in an S3 Dataset, data.all's Document Analyzer feature will generate a response displayed in the data.all Worksheet Editor.
+
+!!! warning "Limitations of Document Analysis"
+ Currently data.all's Worksheet Document Analyzer is limited only to `.txt` and `.pdf` file extensions. Additionally, the feature is limited only to
+ text documents which are explicitly owned by one of the user's teams (documents that are given access via data.all shares are not yet supported).
+
+
+To start analyzing your text documents, data.all users can select the Document Analyzer Tab in the Worksheets View:
+
+
+
+Users select the Worksheet environment, S3 dataset bucket and S3 object key (.txt or .pdf file) where the data of interest is stored. Then they introduce a prompt describing the information they want from the text document. For example, they could type something like "Give me the most prevalent 3 themes across this document". Once they send the request, data.all will invoke the Claude 3.5 Sonnet model using Amazon Bedrock to generate a response.
+
+data.all fetches the content of the S3 Object and passes it to the LLM along with the user prompt. Access to S3 is limited to the buckets the user owns.
+
+There are built-in guardrails to reduce hallucinations by ensuring the selected S3 Object contains information pertaining to the user's prompt.
+
+data.all Admins can additionally limit the number of invocations run against these LLMs by specifying a `max_count_per_day` feature flag in data.all's configuration (please reference data.all's [Deployment Guide](https://data-dot-all.github.io/dataall/deploy-aws/#configjson) for more information).
diff --git a/frontend/src/modules/DatasetsBase/components/DatasetGovernance.js b/frontend/src/modules/DatasetsBase/components/DatasetGovernance.js
index f2e86a115..2cda63d26 100644
--- a/frontend/src/modules/DatasetsBase/components/DatasetGovernance.js
+++ b/frontend/src/modules/DatasetsBase/components/DatasetGovernance.js
@@ -9,9 +9,8 @@ import {
Typography
} from '@mui/material';
import PropTypes from 'prop-types';
-import { Label } from 'design';
+import { Label, UserModal } from 'design';
import { isFeatureEnabled } from 'utils';
-import { UserModal } from 'design';
import { useState } from 'react';
export const DatasetGovernance = (props) => {
diff --git a/frontend/src/modules/Environments/components/EnvironmentOverview.js b/frontend/src/modules/Environments/components/EnvironmentOverview.js
index ae6f012f0..2991c67e6 100644
--- a/frontend/src/modules/Environments/components/EnvironmentOverview.js
+++ b/frontend/src/modules/Environments/components/EnvironmentOverview.js
@@ -1,8 +1,7 @@
import React, { useState } from 'react';
import { Box, Grid } from '@mui/material';
import PropTypes from 'prop-types';
-import { ObjectBrief, ObjectMetadata } from 'design';
-import { UserModal } from 'design';
+import { ObjectBrief, ObjectMetadata, UserModal } from 'design';
import { EnvironmentConsoleAccess } from './EnvironmentConsoleAccess';
import { EnvironmentFeatures } from './EnvironmentFeatures';
diff --git a/frontend/src/modules/Organizations/components/OrganizationOverview.js b/frontend/src/modules/Organizations/components/OrganizationOverview.js
index 95234dd1d..4f58e6102 100644
--- a/frontend/src/modules/Organizations/components/OrganizationOverview.js
+++ b/frontend/src/modules/Organizations/components/OrganizationOverview.js
@@ -1,8 +1,7 @@
import React, { useState } from 'react';
import { Box, Grid } from '@mui/material';
import PropTypes from 'prop-types';
-import { ObjectBrief, ObjectMetadata } from 'design';
-import { UserModal } from 'design';
+import { ObjectBrief, ObjectMetadata, UserModal } from 'design';
export const OrganizationOverview = (props) => {
const { organization, ...other } = props;
diff --git a/frontend/src/modules/Worksheets/components/SQLQueryEditor.js b/frontend/src/modules/Worksheets/components/SQLQueryEditor.js
index dba02596f..2715963c0 100644
--- a/frontend/src/modules/Worksheets/components/SQLQueryEditor.js
+++ b/frontend/src/modules/Worksheets/components/SQLQueryEditor.js
@@ -3,7 +3,12 @@ import PropTypes from 'prop-types';
import { useRef } from 'react';
import { THEMES, useSettings } from 'design';
-export const SQLQueryEditor = ({ sql, setSqlBody }) => {
+export const SQLQueryEditor = ({
+ sql,
+ setSqlBody,
+ height = '19rem',
+ language = 'sql'
+}) => {
const { settings } = useSettings();
const valueGetter = useRef();
function handleEditorDidMount(_valueGetter) {
@@ -22,9 +27,9 @@ export const SQLQueryEditor = ({ sql, setSqlBody }) => {
options={{ minimap: { enabled: false } }}
theme={settings.theme === THEMES.LIGHT ? 'light' : 'vs-dark'}
inDiffEditor={false}
- height="19rem"
+ height={height}
editorDidMount={() => handleEditorDidMount()}
- language="sql"
+ language={language}
showPrintMargin
showGutter
highlightActiveLine
diff --git a/frontend/src/modules/Worksheets/components/WorksheetDocAnalyzer.js b/frontend/src/modules/Worksheets/components/WorksheetDocAnalyzer.js
new file mode 100644
index 000000000..5b7ac7fe0
--- /dev/null
+++ b/frontend/src/modules/Worksheets/components/WorksheetDocAnalyzer.js
@@ -0,0 +1,249 @@
+import {
+ Box,
+ Card,
+ CircularProgress,
+ MenuItem,
+ TextField
+} from '@mui/material';
+import { LoadingButton } from '@mui/lab';
+import React, { useCallback, useState } from 'react';
+import { Scrollbar } from 'design';
+import { SET_ERROR, useDispatch } from 'globalErrors';
+import { listS3ObjectKeys, useClient } from 'services';
+import { analyzeTextDocument } from '../services';
+import PropTypes from 'prop-types';
+
+export const WorksheetDocAnalyzer = ({
+ handleEnvironmentChange,
+ loadingEnvs,
+ currentEnv,
+ environmentOptions,
+ worksheet,
+ selectedDatabase,
+ loadingDatabases,
+ databaseOptions,
+ handleTextChange,
+ setSelectedDatabase
+}) => {
+ const dispatch = useDispatch();
+ const client = useClient();
+ const [invoking, setInvoking] = useState(false);
+ const [prompt, setPrompt] = useState('');
+ const [loadingKeys, setLoadingKeys] = useState(false);
+ const [keyOptions, setKeyOptions] = useState([]);
+ const [selectedKey, setSelectedKey] = useState('');
+ const filteredDBOptions = databaseOptions.filter((db) => 'bucketName' in db);
+
+ function handleBucketChange(event) {
+ setSelectedDatabase(event.target.value);
+ fetchKeys(currentEnv, event.target.value).catch((e) =>
+ dispatch({ type: SET_ERROR, error: e.message })
+ );
+ }
+ const fetchKeys = useCallback(
+ async (environment, dataset) => {
+ setLoadingKeys(true);
+ const response = await client.query(
+ listS3ObjectKeys({
+ datasetUri: dataset.value
+ })
+ );
+ if (!response.errors) {
+ setKeyOptions(response.data.listS3ObjectKeys);
+ } else {
+ dispatch({ type: SET_ERROR, error: response.errors[0].message });
+ }
+ setLoadingKeys(false);
+ },
+ [client, dispatch]
+ );
+
+ const handleSubmit = async () => {
+ setInvoking(true);
+ const response = await client.query(
+ analyzeTextDocument({
+ prompt: prompt,
+ key: selectedKey,
+ environmentUri: currentEnv.environmentUri,
+ worksheetUri: worksheet.worksheetUri,
+ datasetUri: selectedDatabase.value
+ })
+ );
+ if (!response.errors) {
+ handleTextChange(response.data.analyzeTextDocument);
+ } else {
+ dispatch({ type: SET_ERROR, error: response.errors[0].message });
+ }
+ setInvoking(false);
+ };
+
+ function handleKeyChange(event) {
+ setSelectedKey(event.target.value);
+ }
+
+ return (
+
+
+
+
+
+ {
+ handleEnvironmentChange(event);
+ }}
+ select
+ value={currentEnv}
+ variant="outlined"
+ InputProps={{
+ endAdornment: (
+ <>
+ {loadingEnvs ? (
+
+ ) : null}
+ >
+ )
+ }}
+ >
+ {environmentOptions.map((environment) => (
+
+ ))}
+
+
+
+
+
+
+ {
+ handleBucketChange(event);
+ }}
+ select
+ value={selectedDatabase}
+ variant="outlined"
+ InputProps={{
+ endAdornment: (
+ <>
+ {loadingDatabases ? (
+
+ ) : null}
+ >
+ )
+ }}
+ >
+ {filteredDBOptions.length > 0 ? (
+ filteredDBOptions.map((database) => (
+
+ ))
+ ) : (
+
+ )}
+
+
+
+ {
+ handleKeyChange(event);
+ }}
+ select
+ value={selectedKey}
+ variant="outlined"
+ InputProps={{
+ endAdornment: (
+ <>
+ {loadingKeys ? (
+
+ ) : null}
+ >
+ )
+ }}
+ >
+ {keyOptions.length > 0 ? (
+ keyOptions.map((key) => (
+
+ ))
+ ) : (
+
+ )}
+
+
+
+
+ setPrompt(e.target.value)}
+ variant="outlined"
+ />
+
+
+
+ Summarize
+
+
+
+
+
+
+ );
+};
+
+WorksheetDocAnalyzer.propTypes = {
+ handleEnvironmentChange: PropTypes.func.isRequired,
+ loadingEnvs: PropTypes.bool.isRequired,
+ currentEnv: PropTypes.object.isRequired,
+ environmentOptions: PropTypes.array.isRequired,
+ worksheet: PropTypes.object.isRequired,
+ selectedDatabase: PropTypes.object.isRequired,
+ loadingDatabases: PropTypes.bool.isRequired,
+ databaseOptions: PropTypes.array.isRequired,
+ handleTextChange: PropTypes.func.isRequired,
+ setSelectedDatabase: PropTypes.func.isRequired
+};
diff --git a/frontend/src/modules/Worksheets/components/WorksheetSQLEditor.js b/frontend/src/modules/Worksheets/components/WorksheetSQLEditor.js
new file mode 100644
index 000000000..5c89d68cd
--- /dev/null
+++ b/frontend/src/modules/Worksheets/components/WorksheetSQLEditor.js
@@ -0,0 +1,230 @@
+import {
+ Box,
+ Card,
+ CircularProgress,
+ List,
+ ListItem,
+ ListItemIcon,
+ MenuItem,
+ TextField,
+ Tooltip,
+ Typography
+} from '@mui/material';
+import PropTypes from 'prop-types';
+import React from 'react';
+import { CgHashtag } from 'react-icons/cg';
+import { VscSymbolString } from 'react-icons/vsc';
+import { Scrollbar } from 'design';
+
+export const WorksheetSQLEditor = ({
+ handleEnvironmentChange,
+ loadingEnvs,
+ currentEnv,
+ environmentOptions,
+ worksheet,
+ handleDatabaseChange,
+ selectedDatabase,
+ loadingDatabases,
+ databaseOptions,
+ handleTableChange,
+ selectedTable,
+ loadingTables,
+ tableOptions,
+ loadingColumns,
+ columns
+}) => {
+ return (
+
+
+
+
+
+ {
+ handleEnvironmentChange(event);
+ }}
+ select
+ value={currentEnv}
+ variant="outlined"
+ InputProps={{
+ endAdornment: (
+ <>
+ {loadingEnvs ? (
+
+ ) : null}
+ >
+ )
+ }}
+ >
+ {environmentOptions.map((environment) => (
+
+ ))}
+
+
+
+
+
+
+ {
+ handleDatabaseChange(event);
+ }}
+ select
+ value={selectedDatabase}
+ variant="outlined"
+ InputProps={{
+ endAdornment: (
+ <>
+ {loadingDatabases ? (
+
+ ) : null}
+ >
+ )
+ }}
+ >
+ {databaseOptions.length > 0 ? (
+ databaseOptions.map((database) => (
+
+ ))
+ ) : (
+
+ )}
+
+
+
+ {
+ handleTableChange(event);
+ }}
+ select
+ value={selectedTable}
+ variant="outlined"
+ InputProps={{
+ endAdornment: (
+ <>
+ {loadingTables ? (
+
+ ) : null}
+ >
+ )
+ }}
+ >
+ {tableOptions.length > 0 ? (
+ tableOptions.map((table) => (
+
+ ))
+ ) : (
+
+ )}
+
+
+ {loadingColumns ? (
+
+ ) : (
+
+ {columns && columns.length > 0 && (
+
+
+ Columns
+
+
+ {columns.map((col) => (
+
+
+ {col.typeName !== 'string' ? (
+
+
+
+ ) : (
+
+
+
+ )}
+
+
+
+ {col.name.substring(0, 22)}
+
+
+
+
+
+ ))}
+
+
+ )}
+
+ )}
+
+
+
+
+ );
+};
+
+WorksheetSQLEditor.propTypes = {
+ handleEnvironmentChange: PropTypes.func.isRequired,
+ loadingEnvs: PropTypes.bool.isRequired,
+ currentEnv: PropTypes.object.isRequired,
+ environmentOptions: PropTypes.array.isRequired,
+ worksheet: PropTypes.object.isRequired,
+ handleDatabaseChange: PropTypes.func.isRequired,
+ selectedDatabase: PropTypes.object.isRequired,
+ loadingDatabases: PropTypes.bool.isRequired,
+ databaseOptions: PropTypes.array.isRequired,
+ handleTableChange: PropTypes.func.isRequired,
+ selectedTable: PropTypes.object.isRequired,
+ loadingTables: PropTypes.bool.isRequired,
+ tableOptions: PropTypes.array.isRequired,
+ loadingColumns: PropTypes.bool.isRequired,
+ columns: PropTypes.array.isRequired
+};
diff --git a/frontend/src/modules/Worksheets/components/WorksheetTextToSQLEditor.js b/frontend/src/modules/Worksheets/components/WorksheetTextToSQLEditor.js
new file mode 100644
index 000000000..1dd1764a3
--- /dev/null
+++ b/frontend/src/modules/Worksheets/components/WorksheetTextToSQLEditor.js
@@ -0,0 +1,247 @@
+import { LoadingButton } from '@mui/lab';
+import {
+ Box,
+ Card,
+ CircularProgress,
+ MenuItem,
+ TextField,
+ Autocomplete,
+ Chip
+} from '@mui/material';
+
+import React, { useState } from 'react';
+import { Scrollbar } from 'design';
+import { SET_ERROR, useDispatch } from 'globalErrors';
+import { useClient } from 'services';
+import { textToSQL } from '../services';
+import PropTypes from 'prop-types';
+
+export const WorksheetTextToSQLEditor = ({
+ handleEnvironmentChange,
+ loadingEnvs,
+ currentEnv,
+ environmentOptions,
+ worksheet,
+ handleDatabaseChange,
+ selectedDatabase,
+ loadingDatabases,
+ databaseOptions,
+ loadingTables,
+ tableOptions,
+ handleSQLChange
+}) => {
+ const dispatch = useDispatch();
+ const client = useClient();
+ const [invoking, setInvoking] = useState(false);
+ const [selectedTables, setSelectedTables] = useState([]);
+ const [prompt, setPrompt] = useState('');
+
+ const handleSubmit = async () => {
+ setInvoking(true);
+ handleSQLChange('');
+
+ const response = await client.query(
+ textToSQL({
+ prompt: prompt,
+ environmentUri: currentEnv.environmentUri,
+ worksheetUri: worksheet.worksheetUri,
+ databaseName: selectedDatabase.label,
+ tableNames: selectedTables
+ })
+ );
+ if (!response.errors) {
+ handleSQLChange(response.data.textToSQL);
+ } else {
+ dispatch({ type: SET_ERROR, error: response.errors[0].message });
+ }
+ setInvoking(false);
+ };
+
+ function handlePromptChange(prompt) {
+ setPrompt(prompt);
+ }
+
+ function handleTablesChange(newValue) {
+ setSelectedTables(newValue);
+ setPrompt('');
+ }
+
+ function handleDatabaseChanges(event) {
+ setSelectedTables([]);
+ setPrompt('');
+ handleDatabaseChange(event);
+ }
+
+ return (
+
+
+
+
+
+ {
+ handleEnvironmentChange(event);
+ }}
+ select
+ value={currentEnv}
+ variant="outlined"
+ InputProps={{
+ endAdornment: (
+ <>
+ {loadingEnvs ? (
+
+ ) : null}
+ >
+ )
+ }}
+ >
+ {environmentOptions.map((environment) => (
+
+ ))}
+
+
+
+
+
+
+ {
+ handleDatabaseChanges(event);
+ }}
+ select
+ value={selectedDatabase}
+ variant="outlined"
+ InputProps={{
+ endAdornment: (
+ <>
+ {loadingDatabases ? (
+
+ ) : null}
+ >
+ )
+ }}
+ >
+ {databaseOptions.length > 0 ? (
+ databaseOptions.map((database) => (
+
+ ))
+ ) : (
+
+ )}
+
+
+
+
+ t.GlueTableName)}
+ value={selectedTables}
+ onChange={(_, newValue) => handleTablesChange(newValue)}
+ renderInput={(params) => (
+ <>
+ {loadingTables ? (
+
+ ) : (
+
+ )}
+ >
+ )}
+ renderTags={(value, getTagProps) =>
+ value.map((option, index) => (
+
+ ))
+ }
+ disabled={!selectedDatabase}
+ fullWidth
+ margin="normal"
+ />
+
+
+ handlePromptChange(e.target.value)}
+ variant="outlined"
+ />
+
+
+
+ Generate SQL
+
+
+
+
+
+
+ );
+};
+
+WorksheetTextToSQLEditor.propTypes = {
+ handleEnvironmentChange: PropTypes.func.isRequired,
+ loadingEnvs: PropTypes.bool.isRequired,
+ currentEnv: PropTypes.object.isRequired,
+ environmentOptions: PropTypes.array.isRequired,
+ worksheet: PropTypes.object.isRequired,
+ handleDatabaseChange: PropTypes.func.isRequired,
+ selectedDatabase: PropTypes.object.isRequired,
+ loadingDatabases: PropTypes.bool.isRequired,
+ databaseOptions: PropTypes.array.isRequired,
+ handleTableChange: PropTypes.func.isRequired,
+ loadingTables: PropTypes.bool.isRequired,
+ tableOptions: PropTypes.array.isRequired
+};
diff --git a/frontend/src/modules/Worksheets/components/index.js b/frontend/src/modules/Worksheets/components/index.js
index 54d2e1f8c..f438402ff 100644
--- a/frontend/src/modules/Worksheets/components/index.js
+++ b/frontend/src/modules/Worksheets/components/index.js
@@ -2,3 +2,6 @@ export * from './SQLQueryEditor';
export * from './WorksheetEditFormModal';
export * from './WorksheetListItem';
export * from './WorksheetResult';
+export * from './WorksheetSQLEditor';
+export * from './WorksheetTextToSQLEditor';
+export * from './WorksheetDocAnalyzer';
diff --git a/frontend/src/modules/Worksheets/services/analyzeTextDocument.js b/frontend/src/modules/Worksheets/services/analyzeTextDocument.js
new file mode 100644
index 000000000..32cae4369
--- /dev/null
+++ b/frontend/src/modules/Worksheets/services/analyzeTextDocument.js
@@ -0,0 +1,34 @@
+import { gql } from 'apollo-boost';
+
+export const analyzeTextDocument = ({
+ prompt,
+ environmentUri,
+ worksheetUri,
+ datasetUri,
+ key
+}) => ({
+ variables: {
+ prompt,
+ environmentUri,
+ worksheetUri,
+ datasetUri,
+ key
+ },
+ query: gql`
+ query analyzeTextDocument(
+ $prompt: String!
+ $environmentUri: String!
+ $worksheetUri: String!
+ $datasetUri: String!
+ $key: String!
+ ) {
+ analyzeTextDocument(
+ prompt: $prompt
+ environmentUri: $environmentUri
+ worksheetUri: $worksheetUri
+ datasetUri: $datasetUri
+ key: $key
+ )
+ }
+ `
+});
diff --git a/frontend/src/modules/Worksheets/services/index.js b/frontend/src/modules/Worksheets/services/index.js
index b10e7d361..53a7bf1a2 100644
--- a/frontend/src/modules/Worksheets/services/index.js
+++ b/frontend/src/modules/Worksheets/services/index.js
@@ -6,3 +6,5 @@ export * from './listWorksheets';
export * from './runAthenaSqlQuery';
export * from './updateWorksheet';
export * from './listSharedDatasetTableColumns';
+export * from './textToSQL';
+export * from './analyzeTextDocument';
diff --git a/frontend/src/modules/Worksheets/services/textToSQL.js b/frontend/src/modules/Worksheets/services/textToSQL.js
new file mode 100644
index 000000000..c6de1d098
--- /dev/null
+++ b/frontend/src/modules/Worksheets/services/textToSQL.js
@@ -0,0 +1,34 @@
+import { gql } from 'apollo-boost';
+
+export const textToSQL = ({
+ prompt,
+ environmentUri,
+ worksheetUri,
+ databaseName,
+ tableNames
+}) => ({
+ variables: {
+ prompt,
+ environmentUri,
+ worksheetUri,
+ databaseName,
+ tableNames
+ },
+ query: gql`
+ query textToSQL(
+ $prompt: String!
+ $environmentUri: String!
+ $worksheetUri: String!
+ $databaseName: String!
+ $tableNames: [String]
+ ) {
+ textToSQL(
+ prompt: $prompt
+ environmentUri: $environmentUri
+ worksheetUri: $worksheetUri
+ databaseName: $databaseName
+ tableNames: $tableNames
+ )
+ }
+ `
+});
diff --git a/frontend/src/modules/Worksheets/views/WorksheetView.js b/frontend/src/modules/Worksheets/views/WorksheetView.js
index 545c3f190..013b37d05 100644
--- a/frontend/src/modules/Worksheets/views/WorksheetView.js
+++ b/frontend/src/modules/Worksheets/views/WorksheetView.js
@@ -1,31 +1,27 @@
-import { PlayArrowOutlined, SaveOutlined } from '@mui/icons-material';
+import { PlayArrowOutlined } from '@mui/icons-material';
import { LoadingButton } from '@mui/lab';
import {
Box,
- Card,
CircularProgress,
Divider,
IconButton,
- List,
- ListItem,
- ListItemIcon,
- MenuItem,
- TextField,
- Tooltip,
- Typography
+ Typography,
+ Tabs,
+ Tab,
+ Alert,
+ Stack
} from '@mui/material';
import { useSnackbar } from 'notistack';
import React, { useCallback, useEffect, useState } from 'react';
import { Helmet } from 'react-helmet-async';
-import { CgHashtag } from 'react-icons/cg';
import { FaTrash } from 'react-icons/fa';
-import { VscSymbolString } from 'react-icons/vsc';
import { useNavigate, useParams } from 'react-router-dom';
import {
Defaults,
DeleteObjectWithFrictionModal,
PencilAltIcon,
- Scrollbar
+ SaveIcon,
+ useSettings
} from 'design';
import { SET_ERROR, useDispatch } from 'globalErrors';
import {
@@ -47,8 +43,32 @@ import {
import {
SQLQueryEditor,
WorksheetEditFormModal,
- WorksheetResult
+ WorksheetResult,
+ WorksheetTextToSQLEditor,
+ WorksheetDocAnalyzer,
+ WorksheetSQLEditor
} from '../components';
+import { isFeatureEnabled } from 'utils';
+
+const tabs = [
+ {
+ label: 'SQL Editor',
+ value: 'SQLEditor',
+ active: true
+ },
+ {
+ label: 'Text-To-SQL Editor',
+ value: 'TextToSQL',
+ active: isFeatureEnabled('worksheets', 'nlq')
+ },
+ {
+ label: 'Document Analyzer',
+ value: 'DocAnalyzer',
+ active: isFeatureEnabled('worksheets', 'nlq')
+ }
+];
+
+const activeTabs = tabs.filter((tab) => tab.active !== false);
const WorksheetView = () => {
const navigate = useNavigate();
@@ -63,6 +83,7 @@ const WorksheetView = () => {
const [sqlBody, setSqlBody] = useState(
" select 'A' as dim, 23 as nb\n union \n select 'B' as dim, 43 as nb "
);
+ const [textBody, setTextBody] = useState('');
const [currentEnv, setCurrentEnv] = useState();
const [loadingEnvs, setLoadingEnvs] = useState(false);
const [loadingDatabases, setLoadingDatabases] = useState(false);
@@ -76,6 +97,13 @@ const WorksheetView = () => {
const [runningQuery, setRunningQuery] = useState(false);
const [isEditWorksheetOpen, setIsEditWorksheetOpen] = useState(null);
const [isDeleteWorksheetOpen, setIsDeleteWorksheetOpen] = useState(null);
+ const [currentTab, setCurrentTab] = useState(activeTabs[0].value);
+ const { settings } = useSettings();
+
+ const handleTabChange = (event, newValue) => {
+ setCurrentTab(newValue);
+ };
+
const handleEditWorksheetModalOpen = () => {
setIsEditWorksheetOpen(true);
};
@@ -136,7 +164,8 @@ const WorksheetView = () => {
(d) => ({
...d,
value: d.datasetUri,
- label: d.GlueDatabaseName
+ label: d.GlueDatabaseName,
+ bucketName: d.S3BucketName
})
);
}
@@ -362,6 +391,14 @@ const WorksheetView = () => {
);
}
+ function handleSQLChange(value) {
+ setSqlBody(value);
+ }
+
+ function handleTextChange(value) {
+ setTextBody(value);
+ }
+
function handleDatabaseChange(event) {
setColumns([]);
setTableOptions([]);
@@ -392,6 +429,24 @@ const WorksheetView = () => {
Worksheet | data.all
+
+ {activeTabs.map((tab) => (
+
+ ))}
+
{
height: '100%'
}}
>
-
-
-
-
- {
- handleEnvironmentChange(event);
- }}
- select
- value={currentEnv}
- variant="outlined"
- InputProps={{
- endAdornment: (
- <>
- {loadingEnvs ? (
-
- ) : null}
- >
- )
- }}
- >
- {environmentOptions.map((environment) => (
-
- ))}
-
-
-
-
-
-
- {
- handleDatabaseChange(event);
- }}
- select
- value={selectedDatabase}
- variant="outlined"
- InputProps={{
- endAdornment: (
- <>
- {loadingDatabases ? (
-
- ) : null}
- >
- )
- }}
- >
- {databaseOptions.length > 0 ? (
- databaseOptions.map((database) => (
-
- ))
- ) : (
-
- )}
-
-
-
- {
- handleTableChange(event);
- }}
- select
- value={selectedTable}
- variant="outlined"
- InputProps={{
- endAdornment: (
- <>
- {loadingTables ? (
-
- ) : null}
- >
- )
- }}
- >
- {tableOptions.length > 0 ? (
- tableOptions.map((table) => (
-
- ))
- ) : (
-
- )}
-
-
- {loadingColumns ? (
-
- ) : (
-
- {columns && columns.length > 0 && (
-
-
- Columns
-
-
- {columns.map((col) => (
-
-
- {col.typeName !== 'string' ? (
-
-
-
- ) : (
-
-
-
- )}
-
-
-
- {col.name.substring(0, 22)}
-
-
-
-
-
- ))}
-
-
- )}
-
- )}
-
-
-
+ {currentTab === 'SQLEditor' && (
+
+ )}
+ {currentTab === 'TextToSQL' && (
+
+ )}
+ {currentTab === 'DocAnalyzer' && (
+
+ )}
{
flexGrow: 1
}}
>
-
-
-
- {worksheet.label}
-
+
+
+
+
+ {worksheet.label}
+
+
+
+ {currentTab !== 'DocAnalyzer' && (
+
+
+
+ )}
+
+
+
+
+
+
+
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
- }
- sx={{ m: 1 }}
- variant="contained"
+
+ {currentTab !== 'SQLEditor' && (
+
- Run Query
-
-
-
-
-
-
+
+
+ Experimental Feature: Carefully review this AI-generated
+ response for accuracy
+
+
+
+ )}
+ {currentTab !== 'DocAnalyzer' ? (
+ <>
+
+
+
+
+
+
+ }
+ sx={{ m: 1 }}
+ variant="contained"
+ >
+ Run Query
+
+
+
+
+
+
+ >
+ ) : (
+
+
+
+ )}
{worksheet && isEditWorksheetOpen && (
diff --git a/frontend/src/services/graphql/Datasets/index.js b/frontend/src/services/graphql/Datasets/index.js
index 6c83fc1e0..7a4a11e12 100644
--- a/frontend/src/services/graphql/Datasets/index.js
+++ b/frontend/src/services/graphql/Datasets/index.js
@@ -5,3 +5,4 @@ export * from './getDatasetSharedAssumeRoleUrl';
export * from './listDatasetTables';
export * from './listS3DatasetsOwnedByEnvGroup';
export * from './removeDatasetStorageLocation';
+export * from './listS3ObjectKeys';
diff --git a/frontend/src/services/graphql/Datasets/listS3ObjectKeys.js b/frontend/src/services/graphql/Datasets/listS3ObjectKeys.js
new file mode 100644
index 000000000..b78b5f3ff
--- /dev/null
+++ b/frontend/src/services/graphql/Datasets/listS3ObjectKeys.js
@@ -0,0 +1,12 @@
+import { gql } from 'apollo-boost';
+
+export const listS3ObjectKeys = ({ datasetUri }) => ({
+ variables: {
+ datasetUri
+ },
+ query: gql`
+ query listS3ObjectKeys($datasetUri: String!) {
+ listS3ObjectKeys(datasetUri: $datasetUri)
+ }
+ `
+});
diff --git a/tests_new/integration_tests/modules/s3_datasets/global_conftest.py b/tests_new/integration_tests/modules/s3_datasets/global_conftest.py
index e05861e73..81cc89662 100644
--- a/tests_new/integration_tests/modules/s3_datasets/global_conftest.py
+++ b/tests_new/integration_tests/modules/s3_datasets/global_conftest.py
@@ -448,6 +448,17 @@ def get_or_create_persistent_s3_dataset(
if withContent:
create_tables(client, s3_dataset)
create_folders(client, s3_dataset)
+ creds = json.loads(generate_dataset_access_token(client, s3_dataset.datasetUri))
+ dataset_session = boto3.Session(
+ aws_access_key_id=creds['AccessKey'],
+ aws_secret_access_key=creds['SessionKey'],
+ aws_session_token=creds['sessionToken'],
+ )
+ file_path = os.path.join(os.path.dirname(__file__), 'sample_data/folder1/txt_sample.txt')
+ s3_client = S3Client(dataset_session, s3_dataset.region)
+ s3_client.upload_file_to_prefix(
+ local_file_path=file_path, s3_path=f'{s3_dataset.S3BucketName}/sessionFolderA'
+ )
if s3_dataset.stack.status in ['CREATE_COMPLETE', 'UPDATE_COMPLETE']:
return s3_dataset
@@ -551,3 +562,8 @@ def persistent_imported_kms_s3_dataset1(
kms_alias=resource_name,
glue_database=resource_name,
)
+
+
+@pytest.fixture(scope='session')
+def persistent_s3_dataset1_folders(client1, persistent_s3_dataset1):
+ yield create_folders(client1, persistent_s3_dataset1)
diff --git a/tests_new/integration_tests/modules/s3_datasets/queries.py b/tests_new/integration_tests/modules/s3_datasets/queries.py
index c129954da..932b4a5d0 100644
--- a/tests_new/integration_tests/modules/s3_datasets/queries.py
+++ b/tests_new/integration_tests/modules/s3_datasets/queries.py
@@ -749,3 +749,17 @@ def get_table_profiling_run(client, tableUri):
}
response = client.query(query=query)
return response.data.getDatasetTableProfilingRun
+
+
+def list_s3_object_keys(client, datasetUri):
+ query = {
+ 'operationName': 'listS3ObjectKeys',
+ 'variables': {'datasetUri': datasetUri},
+ 'query': """
+ query listS3ObjectKeys($datasetUri: String!) {
+ listS3ObjectKeys(datasetUri: $datasetUri)
+ }
+ """,
+ }
+ response = client.query(query=query)
+ return response.data.listS3ObjectKeys
diff --git a/tests_new/integration_tests/modules/s3_datasets/sample_data/folder1/txt_sample.txt b/tests_new/integration_tests/modules/s3_datasets/sample_data/folder1/txt_sample.txt
new file mode 100644
index 000000000..c57eff55e
--- /dev/null
+++ b/tests_new/integration_tests/modules/s3_datasets/sample_data/folder1/txt_sample.txt
@@ -0,0 +1 @@
+Hello World!
\ No newline at end of file
diff --git a/tests_new/integration_tests/modules/s3_datasets/test_s3_dataset.py b/tests_new/integration_tests/modules/s3_datasets/test_s3_dataset.py
index 8e08f2855..7c9c232e4 100644
--- a/tests_new/integration_tests/modules/s3_datasets/test_s3_dataset.py
+++ b/tests_new/integration_tests/modules/s3_datasets/test_s3_dataset.py
@@ -17,6 +17,7 @@
start_glue_crawler,
update_dataset,
list_s3_datasets_owned_by_env_group,
+ list_s3_object_keys,
)
from integration_tests.core.stack.queries import update_stack
from integration_tests.core.stack.utils import check_stack_ready
@@ -284,3 +285,14 @@ def test_persistent_import_kms_s3_dataset_update(client1, persistent_imported_km
client1, env_uri=env_uri, stack_uri=stack_uri, target_uri=dataset_uri, target_type=target_type
)
assert_that(stack.status).is_in('CREATE_COMPLETE', 'UPDATE_COMPLETE')
+
+
+def test_list_s3_object_keys(client1, persistent_s3_dataset1):
+ response = list_s3_object_keys(client1, persistent_s3_dataset1.datasetUri)
+ assert_that(response).contains('sessionFolderA/txt_sample.txt')
+
+
+def test_list_s3_object_keys_unauthorized(client2, persistent_s3_dataset1):
+ assert_that(list_s3_object_keys).raises(GqlError).when_called_with(
+ client2, persistent_s3_dataset1.datasetUri
+ ).contains('UnauthorizedOperation', 'GET_DATASET', persistent_s3_dataset1.datasetUri)
diff --git a/tests_new/integration_tests/modules/worksheets/queries.py b/tests_new/integration_tests/modules/worksheets/queries.py
index 71af5bf5e..84616ebcf 100644
--- a/tests_new/integration_tests/modules/worksheets/queries.py
+++ b/tests_new/integration_tests/modules/worksheets/queries.py
@@ -182,3 +182,67 @@ def update_worksheet(client, worksheet_uri, name='', description='', tags=[]):
}
response = client.query(query=query)
return response.data.updateWorksheet
+
+
+def text_to_sql(client, prompt, environment_uri, worksheet_uri, database_name, table_names=()):
+ query = {
+ 'operationName': 'textToSQL',
+ 'variables': {
+ 'prompt': prompt,
+ 'environmentUri': environment_uri,
+ 'worksheetUri': worksheet_uri,
+ 'databaseName': database_name,
+ 'tableNames': list(table_names),
+ },
+ 'query': """
+ query textToSQL(
+ $prompt: String!
+ $environmentUri: String!
+ $worksheetUri: String!
+ $databaseName: String!
+ $tableNames: [String]
+ ) {
+ textToSQL(
+ prompt: $prompt
+ environmentUri: $environmentUri
+ worksheetUri: $worksheetUri
+ databaseName: $databaseName
+ tableNames: $tableNames
+ )
+ }
+ """,
+ }
+ response = client.query(query=query)
+ return response.data.textToSQL
+
+
+def analyze_text_document(client, prompt, environment_uri, worksheet_uri, dataset_uri, key):
+ query = {
+ 'operationName': 'analyzeTextDocument',
+ 'variables': {
+ 'prompt': prompt,
+ 'environmentUri': environment_uri,
+ 'worksheetUri': worksheet_uri,
+ 'datasetUri': dataset_uri,
+ 'key': key,
+ },
+ 'query': """
+ query analyzeTextDocument(
+ $prompt: String!
+ $environmentUri: String!
+ $worksheetUri: String!
+ $datasetUri: String!
+ $key: String!
+ ) {
+ analyzeTextDocument(
+ prompt: $prompt
+ environmentUri: $environmentUri
+ worksheetUri: $worksheetUri
+ datasetUri: $datasetUri
+ key: $key
+ )
+ }
+ """,
+ }
+ response = client.query(query=query)
+ return response.data.analyzeTextDocument
diff --git a/tests_new/integration_tests/modules/worksheets/test_worksheet.py b/tests_new/integration_tests/modules/worksheets/test_worksheet.py
index 3138776ea..aae6c3cc4 100644
--- a/tests_new/integration_tests/modules/worksheets/test_worksheet.py
+++ b/tests_new/integration_tests/modules/worksheets/test_worksheet.py
@@ -1,4 +1,5 @@
from assertpy import assert_that
+import pytest
from integration_tests.modules.worksheets.queries import (
create_worksheet,
@@ -7,8 +8,11 @@
list_worksheets,
run_athena_sql_query,
update_worksheet,
+ text_to_sql,
+ analyze_text_document,
)
from integration_tests.errors import GqlError
+from dataall.base.config import config
def test_create_worksheet(client1, worksheet1):
@@ -59,3 +63,97 @@ def test_update_worksheet_unauthorized(client2, worksheet1):
assert_that(update_worksheet).raises(GqlError).when_called_with(
client2, worksheet1.worksheetUri, worksheet1.label, 'updated desc', worksheet1.tags
).contains('UnauthorizedOperation', 'UPDATE_WORKSHEET')
+
+
+def test_run_athena_sql_query(client1, worksheet1, persistent_env1):
+ sql_query = 'SHOW DATABASES;'
+ rows = run_athena_sql_query(
+ client=client1,
+ query=sql_query,
+ environment_uri=persistent_env1.environmentUri,
+ worksheet_uri=worksheet1.worksheetUri,
+ ).rows
+ assert_that(rows).is_not_empty()
+ db_names = [r.cells[0].value for r in rows]
+ assert_that(db_names).contains('default')
+
+
+def test_run_athena_sql_query_unauthorized(client2, worksheet1, persistent_env1):
+ sql_query = 'SHOW DATABASES;'
+ assert_that(run_athena_sql_query).raises(GqlError).when_called_with(
+ client2, sql_query, persistent_env1.environmentUri, worksheet1.worksheetUri
+ ).contains('UnauthorizedOperation', 'RUN_ATHENA_QUERY')
+
+
+@pytest.mark.skipif(
+ not config.get_property('modules.worksheets.features.nlq.active'), reason='Feature Disabled by Config'
+)
+def test_text_to_sql(client1, worksheet1, persistent_env1):
+ prompt = 'Write me a query to list the databases I have access to in Athena'
+ response = text_to_sql(
+ client=client1,
+ prompt=prompt,
+ environment_uri=persistent_env1.environmentUri,
+ worksheet_uri=worksheet1.worksheetUri,
+ database_name='',
+ table_names=[],
+ )
+ # Results are nondeterministic - just asserting the response is not None
+ assert_that(response).is_not_none()
+
+
+@pytest.mark.skipif(
+ not config.get_property('modules.worksheets.features.nlq.active'), reason='Feature Disabled by Config'
+)
+def test_text_to_sql_unauthorized(client2, worksheet1, persistent_env1):
+ prompt = 'Write a query to access data in athena'
+ assert_that(text_to_sql).raises(GqlError).when_called_with(
+ client2, prompt, persistent_env1.environmentUri, worksheet1.worksheetUri, ''
+ ).contains('UnauthorizedOperation', 'RUN_ATHENA_QUERY')
+
+
+@pytest.mark.skipif(
+ not config.get_property('modules.worksheets.features.nlq.active'), reason='Feature Disabled by Config'
+)
+def test_analyze_text_doc(client1, worksheet1, persistent_env1, persistent_s3_dataset1):
+ prompt = 'Give me the first character of the first line of this text document'
+ response = analyze_text_document(
+ client=client1,
+ prompt=prompt,
+ environment_uri=persistent_env1.environmentUri,
+ worksheet_uri=worksheet1.worksheetUri,
+ dataset_uri=persistent_s3_dataset1.datasetUri,
+ key='sessionFolderA/txt_sample.txt',
+ )
+ # Results are nondeterministic - just asserting the response is not None
+ assert_that(response).is_not_none()
+
+
+@pytest.mark.skipif(
+ not config.get_property('modules.worksheets.features.nlq.active'), reason='Feature Disabled by Config'
+)
+def test_analyze_text_doc_invalid_object(client1, worksheet1, persistent_env1, persistent_s3_dataset1):
+ prompt = 'Give me a summary of this text document'
+ assert_that(analyze_text_document).raises(GqlError).when_called_with(
+ client1,
+ prompt,
+ persistent_env1.environmentUri,
+ worksheet1.worksheetUri,
+ persistent_s3_dataset1.datasetUri,
+ 'some_file.js',
+ ).contains('S3 Object Key', 'some_file.js', 'InvalidInput')
+
+
+@pytest.mark.skipif(
+ not config.get_property('modules.worksheets.features.nlq.active'), reason='Feature Disabled by Config'
+)
+def test_analyze_text_doc_unauthorized(client2, worksheet1, persistent_env1, persistent_s3_dataset1):
+ prompt = 'Give me a summary of this text document'
+ assert_that(analyze_text_document).raises(GqlError).when_called_with(
+ client2,
+ prompt,
+ persistent_env1.environmentUri,
+ worksheet1.worksheetUri,
+ persistent_s3_dataset1.datasetUri,
+ 'file.txt',
+ ).contains('UnauthorizedOperation', 'RUN_ATHENA_QUERY')