# Local RAG (Retrieval-Augmented Generation) chat over plain-text books:
# embed chunks with sentence-transformers, retrieve by cosine similarity,
# and answer with a local GPT4All model.
# Standard library
import json
import os
import re
from pathlib import Path

# Third-party
import numpy as np
from gpt4all import GPT4All
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
# Retrieval — find the most relevant chunks from your documents using embeddings and cosine similarity
# Augmented — add that retrieved context to the prompt
# Generation — use the language model to generate an answer based on that context

# --------------------------
# GIT Configuration
# --------------------------
# git config --global credential.helper wincred
# git config credential.helper store
# git config --global user.name "Sean"
# git config --global user.email "skessler1964@gmail.com"

# IMPORTANT SETUP STEPS FOR RE-CREATING THIS ENVIRONMENT
# 1) Install Python 3.10.11
# 2) Create venv
#      python -m venv .venv
#      .venv/Scripts/activate
# 3) Install dependencies
#      pip install -r requirements.txt
# 4) Meta-Llama-3-8B-Instruct.Q4_0.gguf
#      \Users\skess\.cache\gpt4all\Meta-Llama-3-8B-Instruct.Q4_0.gguf
#      The model will auto-download on the first run and then switch to allow_download=False (see below).
#      The model is about 4.5G. The download is quick.
#      lm_model = GPT4All("Meta-Llama-3-8B-Instruct.Q4_0.gguf", model_path=r"C:\Users\skess\.cache\gpt4all", device="gpu", allow_download=False)
# 5) Hugging Face cache — used by the sentence transformer (sentence-transformers/all-MiniLM-L6-v2)
#      \Users\skess\.cache\huggingface  There is a folder structure under here.
#      embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2") will automatically download
#      the model if it is not already cached, so an internet connection would be required when running from scratch.

# IMPORTANT PYTHON NOTES - KEEP
#   .venv/Scripts/Activate
#   pip freeze > requirements.txt
#   pip install -r requirements.txt

# Still on the to-do list:
#   - Fix the enrichment length cap
#   - Semantic chunking
#   - Better table handling

# -------------------
# Embedding cache cleanup (delete these to force a rebuild)
# -------------------
#   del embeddings_cache.npz
#   del embeddings_cache_meta.json
# -------------------------
# Knowledge base selection
# -------------------------
BOOK_DIR = 'Books/History'  # just a string

# Collect every readable UTF-8 text file under BOOK_DIR (recursively).
book_files = []

for f in Path(BOOK_DIR).rglob('*'):
    if not f.is_file():
        continue
    try:
        # Read a sample so non-UTF-8 files actually raise UnicodeDecodeError.
        # Merely opening a text file performs no decoding, so the previous
        # `with open(...): pass` let binary files slip into book_files.
        with open(f, 'r', encoding='utf-8') as probe:
            probe.read(1024)
        book_files.append(str(f))  # store as string, not Path
    except (UnicodeDecodeError, PermissionError):
        continue

print(f"Found {len(book_files)} files")
# Overlap should be 10-20% of chunk size
CHUNK_SIZE = 700       # target chunk length in characters
CHUNK_OVERLAP = 100    # NOTE(review): declared here but chunking overlaps by re-including the previous unit — confirm still intended
DEBUG = False          # when True, ask_question prints the retrieved chunks
CACHE_FILE = "embeddings_cache.npz"        # cached embeddings + chunk text + sources
CACHE_META = "embeddings_cache_meta.json"  # file list/sizes used to validate the cache
MAX_HISTORY = 5        # max conversation exchanges kept and included in prompts
CURRENT_LEVEL = 10     # default answer-detail level (see LEVELS)
SEARCH_FILTER = None # None = search all books

# -------------------------
# CONVERSATIONAL HISTORY
# -------------------------
# List of {"question": str, "answer": str} dicts, most recent last.
conversation_history = []

# -------------------------
# LEVEL CONFIG
# -------------------------
# Per-level retrieval/generation settings:
#   expand      — whether to LLM-expand the query into alternative phrasings
#   top_k       — number of chunks retrieved
#   max_tokens  — generation budget for the answer
#   context_len — character cap applied to the joined context
LEVELS = {
    1: {"expand": False, "top_k": 1, "max_tokens": 75, "context_len": 500},
    2: {"expand": False, "top_k": 1, "max_tokens": 75, "context_len": 600},
    3: {"expand": False, "top_k": 2, "max_tokens": 100, "context_len": 700},
    4: {"expand": False, "top_k": 2, "max_tokens": 100, "context_len": 800},
    5: {"expand": False, "top_k": 3, "max_tokens": 125, "context_len": 1000},
    6: {"expand": False, "top_k": 3, "max_tokens": 150, "context_len": 1200},
    7: {"expand": True, "top_k": 3, "max_tokens": 150, "context_len": 1400},
    8: {"expand": True, "top_k": 4, "max_tokens": 175, "context_len": 1600},
    9: {"expand": True, "top_k": 5, "max_tokens": 175, "context_len": 1800},
    10: {"expand": True, "top_k": 5, "max_tokens": 200, "context_len": 2000},
}
# -------------------------
# Load models
# -------------------------

# -----------------------------------
# Load the sentence transformer model
# -----------------------------------
print("Loading embedding model...")
embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# -----------------------------------
# Load the language model — download on first run, reuse the local copy after.
# -----------------------------------
print("Loading language model...")
model_file = "Meta-Llama-3-8B-Instruct.Q4_0.gguf"
model_path = r"C:\Users\skess\.cache\gpt4all"

full_path = os.path.join(model_path, model_file)

# Permit a network download only when the local copy is missing.
allow_download = not os.path.exists(full_path)
if allow_download:
    print("Model not found locally. Downloading...")

lm_model = GPT4All(
    model_file,
    model_path=model_path,
    device="gpu",
    allow_download=allow_download,
)
# -------------------------
# Clean text
# -------------------------
def clean_text(text):
    """Normalize raw book text.

    Rejoins words hyphenated across line breaks, flattens newlines,
    strips footnote markers / standalone page numbers / bullet glyphs,
    and collapses runs of spaces.
    """
    substitutions = (
        (r'(\w+)-\n(\w+)', r'\1\2'),              # rejoin hyphenated line breaks
        (r'\n+', ' '),                            # newlines become spaces
        (r'(?<=[a-z])(\d{1,3})(?=\s[A-Z])', ''),  # footnote digits glued to a word
        (r'\s\d{1,4}\s', ' '),                    # standalone page numbers
        (r'[■•◆▪→]', ''),                         # bullet/arrow glyphs
        (r' +', ' '),                             # collapse repeated spaces
    )
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text)
    return text.strip()
# -------------------------
# Chunk text with overlap
# -------------------------
def chunk_text(text, chunk_size=CHUNK_SIZE, overlap=CHUNK_OVERLAP):
    """Split cleaned text into roughly chunk_size-character chunks.

    Paragraphs are kept intact when they fit; oversized paragraphs are
    broken at sentence boundaries. Adjacent chunks overlap by re-including
    the previous unit rather than by a fixed character count.

    NOTE(review): the `overlap` parameter is accepted but never used —
    overlap is achieved structurally via the previous unit. Confirm whether
    a character-based overlap was intended.
    """
    # Step 1 — Split into paragraphs first
    paragraphs = [p.strip() for p in re.split(r'\n\s*\n', text) if p.strip()]

    # Step 2 — Split any overly long paragraphs into sentences
    split_units = []
    for para in paragraphs:
        if len(para) <= chunk_size * 2:
            split_units.append(para)
        else:
            # Break long paragraph into sentences
            sentences = re.split(r'(?<=[.!?])\s+', para)
            current = ""
            for sentence in sentences:
                if len(current) + len(sentence) <= chunk_size:
                    current += " " + sentence
                else:
                    if current:
                        split_units.append(current.strip())
                    current = sentence
            # Flush the trailing sentence group
            if current:
                split_units.append(current.strip())

    # Step 3 — Combine units into chunks up to chunk_size
    # with overlap by re-including the previous unit
    chunks = []
    current_chunk = ""
    prev_unit = ""

    for unit in split_units:
        if len(current_chunk) + len(unit) + 1 <= chunk_size:
            current_chunk += " " + unit
        else:
            if current_chunk:
                chunks.append(current_chunk.strip())
            # Overlap — start new chunk with previous unit for context
            if prev_unit and len(prev_unit) + len(unit) + 1 <= chunk_size:
                current_chunk = prev_unit + " " + unit
            else:
                current_chunk = unit
        prev_unit = unit

    # Flush the final partial chunk
    if current_chunk:
        chunks.append(current_chunk.strip())

    return chunks
# -------------------------
# Check if cache is valid
# -------------------------
def cache_is_valid():
    """Return True when the embeddings cache matches the current book set.

    Valid when both cache files exist, the stored file list equals
    `book_files`, and every still-present book has an unchanged size.
    """
    if not (os.path.exists(CACHE_FILE) and os.path.exists(CACHE_META)):
        return False

    with open(CACHE_META, "r") as f:
        meta = json.load(f)

    if meta.get("book_files") != book_files:
        return False

    recorded_sizes = meta.get("file_sizes", {})
    for path in book_files:
        if not os.path.exists(path):
            # A vanished file is skipped here; the list comparison above
            # already guards against a changed book set.
            continue
        if recorded_sizes.get(path) != os.path.getsize(path):
            return False

    return True
# -------------------------
# Load or build embeddings
# -------------------------
# Parallel lists: all_chunks[i] is the chunk text, all_sources[i] its file path.
all_chunks = []
all_sources = []

if cache_is_valid():
    # Cache hit: reload the arrays saved by np.savez below.
    print("Loading embeddings from cache...")
    data = np.load(CACHE_FILE, allow_pickle=True)
    chunk_embeddings = data["embeddings"]
    all_chunks = list(data["chunks"])
    all_sources = list(data["sources"])
    print(f"Total chunks loaded from cache: {len(all_chunks)}")
else:
    print("Building embeddings from scratch...")
    for book_name in book_files:
        if not os.path.exists(book_name):
            print(f"Warning: {book_name} not found, skipping...")
            continue
        print(f"Loading {book_name}...")
        with open(book_name, "r", encoding="utf-8") as f:
            book_text = clean_text(f.read())
        book_chunks = chunk_text(book_text)
        all_chunks.extend(book_chunks)
        # One source entry per chunk keeps the two lists aligned.
        all_sources.extend([book_name] * len(book_chunks))
        print(f" -> {len(book_chunks)} chunks")

    print(f"Total chunks: {len(all_chunks)}")
    print("Embedding chunks (this may take a minute)...")
    chunk_embeddings = embed_model.encode(all_chunks, convert_to_tensor=False)

    print("Saving embeddings cache...")
    # object-dtype arrays let np.savez store variable-length strings.
    np.savez(
        CACHE_FILE,
        embeddings=chunk_embeddings,
        chunks=np.array(all_chunks, dtype=object),
        sources=np.array(all_sources, dtype=object)
    )
    # Record the file list and sizes so cache_is_valid() can detect changes.
    file_sizes = {b: os.path.getsize(b) for b in book_files if os.path.exists(b)}
    with open(CACHE_META, "w") as f:
        json.dump({"book_files": book_files, "file_sizes": file_sizes}, f)
    print("Cache saved.")
# -------------------------
# Book filter helper
# -------------------------
def get_filtered_indices(filter_term):
    """Return indices of chunks whose source filename contains filter_term."""
    if not filter_term:
        return list(range(len(all_chunks)))

    needle = filter_term.lower()
    matches = []
    for idx, src in enumerate(all_sources):
        if needle in os.path.basename(src).lower():
            matches.append(idx)
    return matches
def show_available_books():
    """Print a short list of available books with keywords."""
    print("\n--- Available books ---")
    for path in book_files:
        base = os.path.basename(path).replace('.txt', '')
        print(f" {base}")
    print("--- Use 'search <keyword>: your question' to filter ---\n")
# -------------------------
# Query expansion
# -------------------------
def expand_query(question):
    """Ask the LLM for up to three rephrasings of `question` grounded in the
    library's vocabulary; return [question] + accepted alternatives."""
    book_titles = ', '.join(
        os.path.basename(b).replace('.txt', '') for b in book_files
    )

    prompt = (
        f"You are helping search a library containing these documents:\n"
        f"{book_titles}\n\n"
        f"Generate 3 alternative ways to ask the following question using "
        f"vocabulary, concepts, and terminology that would likely appear in "
        f"these specific documents. Do not reference authors or books not in this list. "
        f"The alternative questions must ask about the SAME specific fact as the original. "
        f"Do not broaden or change the subject of the question. "
        f"Return ONLY the 3 questions, one per line, no numbering, no explanation.\n\n"
        f"Question: {question}"
    )
    with lm_model.chat_session():
        response = lm_model.generate(prompt, max_tokens=150)

    candidates = [ln.strip() for ln in response.strip().split('\n') if ln.strip()]

    # Keep only lines that look like genuine, distinct questions
    # (reasonable length, contains '?', no leading label, not the original).
    alternatives = []
    for cand in candidates:
        looks_like_question = 15 < len(cand) < 200 and '?' in cand
        if looks_like_question and cand != question and ':' not in cand[:20]:
            alternatives.append(cand)
    alternatives = alternatives[:3]

    all_queries = [question] + alternatives
    print(f" [Expanded queries: {len(all_queries)}]")
    for q in all_queries:
        print(f" - {q}")
    return all_queries
# ----------------------
# Topic Detection
# ----------------------
# Stopwords for topic detection: function words stripped from questions
# before measuring topical overlap with conversation history.
# (Fixed: the original set literal listed "his" twice — harmless in a set,
# but misleading to maintainers.)
# -------------------------
STOPWORDS = {
    "the", "is", "a", "an", "and", "or", "of", "to", "in", "on", "for", "with",
    "what", "which", "who", "how", "when", "where", "can", "i", "you", "it",
    "did", "do", "does", "was", "were", "he", "she", "they", "his", "her",
    "him", "them", "its", "be", "been", "have", "has", "had", "will",
    "would", "could", "should", "may", "might", "me", "my", "we", "our",
}
def topics_are_related(question, history, lookback=3):
    """
    Return True when `question` appears to continue the current topic.

    A question is related when it shares meaningful (non-stopword) words
    with the last `lookback` exchanges, or when it is a very short,
    pronoun-heavy question — those are almost certainly follow-ups.
    """
    if not history:
        return False

    q_lower = question.lower()

    # Very short questions with pronouns are almost certainly follow-ups
    pronoun_followups = {
        "he","she","they","him","her","them","his","it",
        "this","that","these","those","who","what","where","when"
    }
    all_tokens = set(q_lower.replace('?','').replace('.','').split())
    if len(all_tokens) <= 5 and all_tokens & pronoun_followups:
        print(f" [Pronoun follow-up detected — enriching]")
        return True

    # Meaningful words from the current question
    meaningful = set(q_lower.split()) - STOPWORDS
    if not meaningful:
        return False

    # Meaningful words from the recent history questions
    history_words = set()
    for exchange in history[-lookback:]:
        history_words |= set(exchange["question"].lower().split())
    history_words -= STOPWORDS

    overlap = len(meaningful & history_words)
    print(f" [Topic overlap: {overlap} word(s)]")
    return overlap > 0
def enrich_query_with_history(question):
    """Prefix short follow-up questions with recent question text so that
    retrieval has enough context; return `question` unchanged otherwise."""
    if not conversation_history:
        return question
    # Questions of six or more words stand on their own.
    if len(question.split()) >= 6:
        return question
    if not topics_are_related(question, conversation_history):
        print(f" [Topic shift detected — no enrichment]")
        return question

    recent_questions = [ex["question"] for ex in conversation_history[-3:]]
    context_words = " ".join(recent_questions)
    enriched = f"{context_words} {question}"

    # Don't enrich if result is too long — it will overwhelm the question
    if len(enriched.split()) > 30:
        print(f" [Enriched query too long — using original]")
        return question

    print(f" [Enriched query: {enriched}]")
    return enriched
# -------------------------
# Retrieve top relevant chunks
# -------------------------
def get_top_chunks(question, filter_term=None):
    """Retrieve the level-configured top_k best-matching chunks.

    Returns (chunks, sources) lists ordered best-first, restricted to
    books matching `filter_term` when one is given.
    """
    level_cfg = LEVELS[CURRENT_LEVEL]

    # Enrich short follow-up questions with history context
    retrieval_question = enrich_query_with_history(question)

    if level_cfg["expand"]:
        queries = expand_query(retrieval_question)
    else:
        queries = [retrieval_question]

    # Restrict the search to the filtered books; fall back to everything.
    search_indices = get_filtered_indices(filter_term)
    if not search_indices:
        print(f" [Warning: no books matched filter '{filter_term}' — searching all]")
        search_indices = list(range(len(all_chunks)))

    # Subset embeddings and metadata
    sub_embeddings = chunk_embeddings[search_indices]
    sub_chunks = [all_chunks[i] for i in search_indices]
    sub_sources = [all_sources[i] for i in search_indices]

    if filter_term:
        matched_books = set(os.path.basename(s) for s in sub_sources)
        print(f" [Filter '{filter_term}' matched: {', '.join(matched_books)}]")

    # Average cosine similarity across every query variant.
    sub_scores = np.zeros(len(sub_chunks))
    for q in queries:
        sub_scores += cosine_similarity(embed_model.encode([q]), sub_embeddings)[0]
    sub_scores /= len(queries)

    top_k = level_cfg["top_k"]
    best = sub_scores.argsort()[-top_k:][::-1]

    return [sub_chunks[i] for i in best], [sub_sources[i] for i in best]
# -------------------------
# Parse search filter from input
# -------------------------
def parse_input(user_input):
    """
    Detect the 'search <keyword>: <question>' syntax.

    Returns a (question, filter_term) tuple; without the search prefix,
    filter_term falls back to the global SEARCH_FILTER.
    """
    m = re.match(r'^search\s+(.+?):\s*(.+)$', user_input, re.IGNORECASE)
    if m is None:
        return user_input, SEARCH_FILTER
    return m.group(2).strip(), m.group(1).strip()
# -------------------------
# Ask question
# -------------------------
def ask_question(question, show_sources=False, filter_term=None):
    """Answer `question` via RAG.

    Retrieves top chunks, assembles a prompt containing recent
    conversation history and the retrieved context, generates with the
    local LLM, records the exchange in `conversation_history`, and
    returns the stripped answer string.
    """
    # conversation_history is reassigned (trimmed) below, hence global.
    global conversation_history

    level_cfg = LEVELS[CURRENT_LEVEL]
    top_chunks, sources = get_top_chunks(question, filter_term=filter_term)

    if DEBUG:
        print("\n--- Retrieved chunks ---")
        for i, chunk in enumerate(top_chunks):
            print(f"\nChunk {i+1}:")
            print(chunk[:300])  # preview only
        print("--- End chunks ---\n")

    # Hard character cap on the combined context for this level.
    context = " ".join(top_chunks)[:level_cfg["context_len"]]

    # Build conversation history string
    history_text = ""
    if conversation_history:
        history_text = "Previous conversation:\n"
        for exchange in conversation_history[-MAX_HISTORY:]:
            history_text += f"Q: {exchange['question']}\n"
            history_text += f"A: {exchange['answer']}\n"
            history_text += "\n"

    # System-style instructions.
    # NOTE(review): the last sentence appears after the "\n\n" that closes the
    # instruction paragraph — confirm the trailing instruction is intentional.
    prompt = (
        f"You are a helpful research assistant. "
        f"Answer the question using ONLY the provided context. "
        f"Be direct and concise. "
        f"Only say 'I don't know' if the context contains absolutely nothing relevant. "
        f"Do not reference outside sources. "
        f"Do not repeat or echo the conversation history in your answer. "
        f"Do not include 'Context:' or 'Q:' or 'A:' labels in your answer.\n\n"
        f"Do not include separator lines or notes about your sources in your answer. "
    )

    if history_text:
        prompt += (
            f"--- BACKGROUND ONLY - DO NOT REPEAT ---\n"
            f"{history_text}"
            f"--- END BACKGROUND ---\n\n"
        )

    prompt += (
        f"--- REFERENCE CONTEXT ---\n"
        f"{context}\n"
        f"--- END CONTEXT ---\n\n"
        f"Question: {question}\n\n"
        f"Answer:"
    )

    # Generate inside a chat session scoped to this single prompt.
    with lm_model.chat_session():
        response = lm_model.generate(prompt, max_tokens=level_cfg["max_tokens"])

    answer = response.strip()

    # Remember this exchange for follow-up enrichment and prompting.
    conversation_history.append({
        "question": question,
        "answer": answer
    })

    # Keep only the newest MAX_HISTORY exchanges.
    if len(conversation_history) > MAX_HISTORY:
        conversation_history = conversation_history[-MAX_HISTORY:]

    if show_sources:
        unique_sources = list(set(sources))
        short_sources = [os.path.basename(s) for s in unique_sources]
        print(f" [Sources: {', '.join(short_sources)}]")
        print(f" [Level: {CURRENT_LEVEL} | "
              f"expand={'on' if level_cfg['expand'] else 'off'} | "
              f"top_k={level_cfg['top_k']} | "
              f"max_tokens={level_cfg['max_tokens']}]")
        print(f" [Memory: {len(conversation_history)} exchanges]")
        if filter_term:
            print(f" [Filter: '{filter_term}']")

    return answer
# -------------------------
# Interactive loop
# -------------------------
# Simple REPL: special commands are handled inline; anything else is
# treated as a question for the RAG pipeline.
print("\nReady! Ask questions about your books")
print("Commands: 'exit', 'sources on/off', 'level 1-10',")
print(" 'memory clear', 'memory show', 'debug on/off'")
print(" 'books' — list available books")
print(" 'search <keyword>: question' — filter by book\n")
show_sources = False

while True:
    user_input = input(f"[L{CURRENT_LEVEL}] You: ")

    # Robustness: ignore blank input instead of sending an empty question
    # through retrieval and generation.
    if not user_input.strip():
        continue

    if user_input.lower() in ["exit", "quit"]:
        break
    elif user_input.lower() == "memory clear":
        conversation_history.clear()
        print("Conversation memory cleared.")
        continue
    elif user_input.lower() == "memory show":
        if not conversation_history:
            print("No conversation history.")
        else:
            print(f"\n--- Last {len(conversation_history)} exchanges ---")
            for i, exchange in enumerate(conversation_history):
                print(f"\nQ{i+1}: {exchange['question']}")
                print(f"A{i+1}: {exchange['answer'][:100]}...")  # answers truncated for display
            print("---\n")
        continue
    elif user_input.lower() == "debug on":
        DEBUG = True
        print("Debug mode enabled.")
        continue
    elif user_input.lower() == "debug off":
        DEBUG = False
        print("Debug mode disabled.")
        continue
    elif user_input.lower() == "sources on":
        show_sources = True
        print("Source display enabled.")
        continue
    elif user_input.lower() == "sources off":
        show_sources = False
        print("Source display disabled.")
        continue
    elif user_input.lower() == "books":
        show_available_books()
        continue
    elif user_input.lower().startswith("level "):
        try:
            lvl = int(user_input.split()[1])
            if 1 <= lvl <= 10:
                CURRENT_LEVEL = lvl
                cfg = LEVELS[CURRENT_LEVEL]
                print(f"Level set to {CURRENT_LEVEL} — "
                      f"expand={'on' if cfg['expand'] else 'off'}, "
                      f"top_k={cfg['top_k']}, "
                      f"max_tokens={cfg['max_tokens']}")
            else:
                print("Level must be between 1 and 10.")
        # Fixed: was a bare `except:`, which also swallowed
        # KeyboardInterrupt and SystemExit. Only malformed input
        # (missing or non-numeric level) should hit this path.
        except (ValueError, IndexError):
            print("Usage: level 1 through level 10")
        continue

    # Parse for search filter
    question, filter_term = parse_input(user_input)

    response = ask_question(question, show_sources=show_sources, filter_term=filter_term)
    print("Bot:", response)