DRIFT Search
In [1]:
Copied!
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License.
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License.
In [2]:
Copied!
import os
from pathlib import Path
import pandas as pd
import tiktoken
from graphrag.config.enums import ModelType
from graphrag.config.models.drift_search_config import DRIFTSearchConfig
from graphrag.config.models.language_model_config import LanguageModelConfig
from graphrag.language_model.manager import ModelManager
from graphrag.query.indexer_adapters import (
read_indexer_entities,
read_indexer_relationships,
read_indexer_report_embeddings,
read_indexer_reports,
read_indexer_text_units,
)
from graphrag.query.structured_search.drift_search.drift_context import (
DRIFTSearchContextBuilder,
)
from graphrag.query.structured_search.drift_search.search import DRIFTSearch
from graphrag.vector_stores.lancedb import LanceDBVectorStore
INPUT_DIR = "./inputs/operation dulce"
LANCEDB_URI = f"{INPUT_DIR}/lancedb"
COMMUNITY_REPORT_TABLE = "community_reports"
COMMUNITY_TABLE = "communities"
ENTITY_TABLE = "entities"
RELATIONSHIP_TABLE = "relationships"
COVARIATE_TABLE = "covariates"
TEXT_UNIT_TABLE = "text_units"
COMMUNITY_LEVEL = 2
# read nodes table to get community and degree data
entity_df = pd.read_parquet(f"{INPUT_DIR}/{ENTITY_TABLE}.parquet")
community_df = pd.read_parquet(f"{INPUT_DIR}/{COMMUNITY_TABLE}.parquet")
print(f"Entity df columns: {entity_df.columns}")
entities = read_indexer_entities(entity_df, community_df, COMMUNITY_LEVEL)
# load description embeddings to an in-memory lancedb vectorstore
# to connect to a remote db, specify url and port values.
description_embedding_store = LanceDBVectorStore(
collection_name="default-entity-description",
)
description_embedding_store.connect(db_uri=LANCEDB_URI)
full_content_embedding_store = LanceDBVectorStore(
collection_name="default-community-full_content",
)
full_content_embedding_store.connect(db_uri=LANCEDB_URI)
print(f"Entity count: {len(entity_df)}")
entity_df.head()
relationship_df = pd.read_parquet(f"{INPUT_DIR}/{RELATIONSHIP_TABLE}.parquet")
relationships = read_indexer_relationships(relationship_df)
print(f"Relationship count: {len(relationship_df)}")
relationship_df.head()
text_unit_df = pd.read_parquet(f"{INPUT_DIR}/{TEXT_UNIT_TABLE}.parquet")
text_units = read_indexer_text_units(text_unit_df)
print(f"Text unit records: {len(text_unit_df)}")
text_unit_df.head()
import os
from pathlib import Path
import pandas as pd
import tiktoken
from graphrag.config.enums import ModelType
from graphrag.config.models.drift_search_config import DRIFTSearchConfig
from graphrag.config.models.language_model_config import LanguageModelConfig
from graphrag.language_model.manager import ModelManager
from graphrag.query.indexer_adapters import (
read_indexer_entities,
read_indexer_relationships,
read_indexer_report_embeddings,
read_indexer_reports,
read_indexer_text_units,
)
from graphrag.query.structured_search.drift_search.drift_context import (
DRIFTSearchContextBuilder,
)
from graphrag.query.structured_search.drift_search.search import DRIFTSearch
from graphrag.vector_stores.lancedb import LanceDBVectorStore
INPUT_DIR = "./inputs/operation dulce"
LANCEDB_URI = f"{INPUT_DIR}/lancedb"
COMMUNITY_REPORT_TABLE = "community_reports"
COMMUNITY_TABLE = "communities"
ENTITY_TABLE = "entities"
RELATIONSHIP_TABLE = "relationships"
COVARIATE_TABLE = "covariates"
TEXT_UNIT_TABLE = "text_units"
COMMUNITY_LEVEL = 2
# read nodes table to get community and degree data
entity_df = pd.read_parquet(f"{INPUT_DIR}/{ENTITY_TABLE}.parquet")
community_df = pd.read_parquet(f"{INPUT_DIR}/{COMMUNITY_TABLE}.parquet")
print(f"Entity df columns: {entity_df.columns}")
entities = read_indexer_entities(entity_df, community_df, COMMUNITY_LEVEL)
# load description embeddings to an in-memory lancedb vectorstore
# to connect to a remote db, specify url and port values.
description_embedding_store = LanceDBVectorStore(
collection_name="default-entity-description",
)
description_embedding_store.connect(db_uri=LANCEDB_URI)
full_content_embedding_store = LanceDBVectorStore(
collection_name="default-community-full_content",
)
full_content_embedding_store.connect(db_uri=LANCEDB_URI)
print(f"Entity count: {len(entity_df)}")
entity_df.head()
relationship_df = pd.read_parquet(f"{INPUT_DIR}/{RELATIONSHIP_TABLE}.parquet")
relationships = read_indexer_relationships(relationship_df)
print(f"Relationship count: {len(relationship_df)}")
relationship_df.head()
text_unit_df = pd.read_parquet(f"{INPUT_DIR}/{TEXT_UNIT_TABLE}.parquet")
text_units = read_indexer_text_units(text_unit_df)
print(f"Text unit records: {len(text_unit_df)}")
text_unit_df.head()
--------------------------------------------------------------------------- ModuleNotFoundError Traceback (most recent call last) Cell In[2], line 4 1 import os 2 from pathlib import Path ----> 4 import pandas as pd 5 import tiktoken 7 from graphrag.config.enums import ModelType ModuleNotFoundError: No module named 'pandas'
In [3]:
Copied!
api_key = os.environ["GRAPHRAG_API_KEY"]
llm_model = os.environ["GRAPHRAG_LLM_MODEL"]
embedding_model = os.environ["GRAPHRAG_EMBEDDING_MODEL"]
chat_config = LanguageModelConfig(
api_key=api_key,
type=ModelType.OpenAIChat,
model=llm_model,
max_retries=20,
)
chat_model = ModelManager().get_or_create_chat_model(
name="local_search",
model_type=ModelType.OpenAIChat,
config=chat_config,
)
token_encoder = tiktoken.encoding_for_model(llm_model)
embedding_config = LanguageModelConfig(
api_key=api_key,
type=ModelType.OpenAIEmbedding,
model=embedding_model,
max_retries=20,
)
text_embedder = ModelManager().get_or_create_embedding_model(
name="local_search_embedding",
model_type=ModelType.OpenAIEmbedding,
config=embedding_config,
)
api_key = os.environ["GRAPHRAG_API_KEY"]
llm_model = os.environ["GRAPHRAG_LLM_MODEL"]
embedding_model = os.environ["GRAPHRAG_EMBEDDING_MODEL"]
chat_config = LanguageModelConfig(
api_key=api_key,
type=ModelType.OpenAIChat,
model=llm_model,
max_retries=20,
)
chat_model = ModelManager().get_or_create_chat_model(
name="local_search",
model_type=ModelType.OpenAIChat,
config=chat_config,
)
token_encoder = tiktoken.encoding_for_model(llm_model)
embedding_config = LanguageModelConfig(
api_key=api_key,
type=ModelType.OpenAIEmbedding,
model=embedding_model,
max_retries=20,
)
text_embedder = ModelManager().get_or_create_embedding_model(
name="local_search_embedding",
model_type=ModelType.OpenAIEmbedding,
config=embedding_config,
)
--------------------------------------------------------------------------- KeyError Traceback (most recent call last) Cell In[3], line 1 ----> 1 api_key = os.environ["GRAPHRAG_API_KEY"] 2 llm_model = os.environ["GRAPHRAG_LLM_MODEL"] 3 embedding_model = os.environ["GRAPHRAG_EMBEDDING_MODEL"] File E:\Python\Python-3.10.13\lib\os.py:680, in _Environ.__getitem__(self, key) 677 value = self._data[self.encodekey(key)] 678 except KeyError: 679 # raise KeyError with the original key value --> 680 raise KeyError(key) from None 681 return self.decodevalue(value) KeyError: 'GRAPHRAG_API_KEY'
In [4]:
Copied!
def read_community_reports(
input_dir: str,
community_report_table: str = COMMUNITY_REPORT_TABLE,
):
"""Embeds the full content of the community reports and saves the DataFrame with embeddings to the output path."""
input_path = Path(input_dir) / f"{community_report_table}.parquet"
return pd.read_parquet(input_path)
report_df = read_community_reports(INPUT_DIR)
reports = read_indexer_reports(
report_df,
community_df,
COMMUNITY_LEVEL,
content_embedding_col="full_content_embeddings",
)
read_indexer_report_embeddings(reports, full_content_embedding_store)
def read_community_reports(
input_dir: str,
community_report_table: str = COMMUNITY_REPORT_TABLE,
):
"""Embeds the full content of the community reports and saves the DataFrame with embeddings to the output path."""
input_path = Path(input_dir) / f"{community_report_table}.parquet"
return pd.read_parquet(input_path)
report_df = read_community_reports(INPUT_DIR)
reports = read_indexer_reports(
report_df,
community_df,
COMMUNITY_LEVEL,
content_embedding_col="full_content_embeddings",
)
read_indexer_report_embeddings(reports, full_content_embedding_store)
--------------------------------------------------------------------------- NameError Traceback (most recent call last) Cell In[4], line 3 1 def read_community_reports( 2 input_dir: str, ----> 3 community_report_table: str = COMMUNITY_REPORT_TABLE, 4 ): 5 """Embeds the full content of the community reports and saves the DataFrame with embeddings to the output path.""" 6 input_path = Path(input_dir) / f"{community_report_table}.parquet" NameError: name 'COMMUNITY_REPORT_TABLE' is not defined
In [5]:
Copied!
drift_params = DRIFTSearchConfig(
temperature=0,
max_tokens=12_000,
primer_folds=1,
drift_k_followups=3,
n_depth=3,
n=1,
)
context_builder = DRIFTSearchContextBuilder(
model=chat_model,
text_embedder=text_embedder,
entities=entities,
relationships=relationships,
reports=reports,
entity_text_embeddings=description_embedding_store,
text_units=text_units,
token_encoder=token_encoder,
config=drift_params,
)
search = DRIFTSearch(
model=chat_model, context_builder=context_builder, token_encoder=token_encoder
)
drift_params = DRIFTSearchConfig(
temperature=0,
max_tokens=12_000,
primer_folds=1,
drift_k_followups=3,
n_depth=3,
n=1,
)
context_builder = DRIFTSearchContextBuilder(
model=chat_model,
text_embedder=text_embedder,
entities=entities,
relationships=relationships,
reports=reports,
entity_text_embeddings=description_embedding_store,
text_units=text_units,
token_encoder=token_encoder,
config=drift_params,
)
search = DRIFTSearch(
model=chat_model, context_builder=context_builder, token_encoder=token_encoder
)
--------------------------------------------------------------------------- NameError Traceback (most recent call last) Cell In[5], line 1 ----> 1 drift_params = DRIFTSearchConfig( 2 temperature=0, 3 max_tokens=12_000, 4 primer_folds=1, 5 drift_k_followups=3, 6 n_depth=3, 7 n=1, 8 ) 10 context_builder = DRIFTSearchContextBuilder( 11 model=chat_model, 12 text_embedder=text_embedder, (...) 19 config=drift_params, 20 ) 22 search = DRIFTSearch( 23 model=chat_model, context_builder=context_builder, token_encoder=token_encoder 24 ) NameError: name 'DRIFTSearchConfig' is not defined
In [6]:
Copied!
resp = await search.search("Who is agent Mercer?")
resp = await search.search("Who is agent Mercer?")
--------------------------------------------------------------------------- NameError Traceback (most recent call last) Cell In[6], line 1 ----> 1 resp = await search.search("Who is agent Mercer?") NameError: name 'search' is not defined
In [7]:
Copied!
resp.response
resp.response
--------------------------------------------------------------------------- NameError Traceback (most recent call last) Cell In[7], line 1 ----> 1 resp.response NameError: name 'resp' is not defined
In [8]:
Copied!
print(resp.context_data)
print(resp.context_data)
--------------------------------------------------------------------------- NameError Traceback (most recent call last) Cell In[8], line 1 ----> 1 print(resp.context_data) NameError: name 'resp' is not defined