Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,058 Bytes
14dc68f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
#!/usr/bin/env python3
import os
import argparse
import openai
import pinecone
from dotenv import load_dotenv
load_dotenv()


def _require_env(name, default=""):
    """Return environment variable *name*, raising if it is unset or empty.

    ``load_dotenv()`` has already loaded any ``.env`` file into the process
    environment, so values from either source are seen here. An explicit
    ``raise`` is used instead of ``assert`` because assertions are stripped
    when Python runs with ``-O``, which would let a missing credential through.
    """
    value = os.getenv(name, default)
    if not value:
        raise RuntimeError(f"{name} environment variable is missing from .env")
    return value


OPENAI_API_KEY = _require_env("OPENAI_API_KEY")
PINECONE_API_KEY = _require_env("PINECONE_API_KEY")
# Falls back to "us-east1-gcp" when unset; only an explicitly empty value fails.
PINECONE_ENVIRONMENT = _require_env("PINECONE_ENVIRONMENT", "us-east1-gcp")
# Table config
PINECONE_TABLE_NAME = _require_env("TABLE_NAME")
# Function to query records from the Pinecone index
def query_records(index, query, top_k=1000):
    """Fetch the ``top_k`` nearest records to *query* and format them for printing.

    *index* is queried with ``include_metadata=True``. Each match contributes one
    string made of its ``task`` metadata followed by a colon, a newline, its
    ``result`` metadata, and a trailing dashed divider line.
    """
    response = index.query(query, top_k=top_k, include_metadata=True)
    formatted = []
    for match in response.matches:
        meta = match.metadata
        formatted.append(f"{meta['task']}:\n{meta['result']}\n------------------")
    return formatted
# Get embedding for the text
def get_ada_embedding(text):
    """Return the ``text-embedding-ada-002`` embedding vector for *text*.

    Newlines are flattened to single spaces before the request is sent, and the
    ``embedding`` field of the first entry in the response's ``data`` list is
    returned.

    NOTE: ``openai.Embedding.create`` is the pre-1.0 ``openai`` package API.
    """
    single_line = text.replace("\n", " ")
    response = openai.Embedding.create(input=[single_line], model="text-embedding-ada-002")
    first_item = response["data"][0]
    return first_item["embedding"]
def main():
    """Embed the objective text and print the matching records from the index.

    The objective is the positional command-line words joined by spaces, or the
    ``OBJECTIVE`` environment variable when no words are given. Exits with an
    argparse usage error if the resulting objective is empty.
    """
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description="Query Pinecone index using a string.")
    objective_help = (
        "\n"
        "main objective description. Doesn't need to be quoted.\n"
        "if not specified, get objective from environment.\n"
    )
    parser.add_argument(
        'objective',
        nargs='*',
        metavar='<objective>',
        help=objective_help,
        default=[os.getenv("OBJECTIVE", "")],
    )
    args = parser.parse_args()

    objective = ' '.join(args.objective).strip()
    if not objective:
        # OpenAI's embeddings API does not accept an empty string; report a
        # usage error locally instead of failing on the remote call.
        parser.error("no objective given; pass it as arguments or set OBJECTIVE")

    # Configure OpenAI
    openai.api_key = OPENAI_API_KEY

    # Initialize Pinecone. Pass the configured environment explicitly so the
    # value read from PINECONE_ENVIRONMENT (or this script's default) is the one
    # used, rather than whatever default the client applies when it is omitted.
    pinecone.init(api_key=PINECONE_API_KEY, environment=PINECONE_ENVIRONMENT)

    # Connect to the objective index
    index = pinecone.Index(PINECONE_TABLE_NAME)

    # Query records from the index
    query = get_ada_embedding(objective)
    for record in query_records(index, query):
        print(record)


if __name__ == "__main__":
    main()
|