Files
reason-flow/server/services/graphRagService.js
T
2025-11-06 11:08:59 +01:00

145 lines
4.9 KiB
JavaScript

const { Document } = require('../models');
const embeddingService = require('./embeddingService');
const logger = require('../utils/logger');
/**
 * Read a numeric setting from the environment.
 * Falls back to `fallback` when the variable is unset, empty, or does not parse
 * to a finite number (an invalid value would otherwise become NaN and silently
 * disable graph edges or empty the result list).
 *
 * @param {string} name - Environment variable name.
 * @param {number} fallback - Value used when the variable is missing or invalid.
 * @param {(raw: string) => number} parse - Parser applied to the raw string.
 * @returns {number} The parsed value or the fallback.
 */
function readNumericEnv(name, fallback, parse) {
  const raw = process.env[name];
  if (raw === undefined || raw === '') return fallback;
  const value = parse(raw);
  return Number.isFinite(value) ? value : fallback;
}

/**
 * Normalize a tag list into a Set of trimmed, lower-cased, non-empty strings.
 * Non-array input and non-string or blank entries are ignored, so documents
 * with missing tags never "share" an empty tag.
 *
 * @param {unknown} tags - Tag list as stored on a document.
 * @returns {Set<string>} Normalized tag keys.
 */
function normalizeTagSet(tags) {
  const set = new Set();
  if (!Array.isArray(tags)) return set;
  for (const tag of tags) {
    if (typeof tag !== 'string') continue;
    const key = tag.trim().toLowerCase();
    if (key !== '') set.add(key);
  }
  return set;
}

/**
 * Graph-based retrieval over indexed documents: documents become nodes, edges
 * connect documents whose embedding similarity (plus a small shared-tag bonus)
 * passes a threshold, and search expands outward from the documents most
 * similar to the query.
 */
class GraphRagService {
  constructor() {
    // Minimum hybrid score (cosine similarity + tag bonus) for two documents to share an edge.
    this.similarityThreshold = readNumericEnv('GRAPH_RAG_SIM_THRESHOLD', 0.2, Number.parseFloat);
    // Maximum number of edges kept per node after pruning.
    this.maxNeighbors = readNumericEnv('GRAPH_RAG_MAX_NEIGHBORS', 10, (s) => Number.parseInt(s, 10));
    // Maximum number of ranked results returned by graphSearch.
    this.maxResults = readNumericEnv('GRAPH_RAG_MAX_RESULTS', 10, (s) => Number.parseInt(s, 10));
    // Number of top query-similarity documents used as traversal starting points.
    this.maxSeeds = 5;
    // Queue entries deeper than this many hops are not expanded.
    this.maxHops = 2;
    // Traversal stops once this many distinct nodes have been scored.
    this.maxScoredNodes = 200;
  }

  /**
   * Similarity between two embedding vectors (delegates to embeddingService.cosineSimilarity).
   *
   * @param {number[]} a
   * @param {number[]} b
   * @returns {number}
   */
  scoreSimilarity(a, b) {
    return embeddingService.cosineSimilarity(a, b);
  }

  /**
   * Count tags shared by two tag lists, case-insensitively and ignoring
   * surrounding whitespace. Blank, non-string and non-array inputs contribute nothing.
   *
   * @param {string[]} [tagsA=[]]
   * @param {string[]} [tagsB=[]]
   * @returns {number} Number of distinct shared tags.
   */
  tagOverlap(tagsA = [], tagsB = []) {
    const setA = normalizeTagSet(tagsA);
    const setB = normalizeTagSet(tagsB);
    let overlap = 0;
    for (const tag of setA) {
      if (setB.has(tag)) overlap += 1;
    }
    return overlap;
  }

  /**
   * Build an undirected similarity graph as an adjacency map.
   * Every pair of nodes is compared (O(n²) similarity calls); an edge is added
   * in both directions when the hybrid score reaches `similarityThreshold`.
   * Each adjacency list is then sorted by score (descending) and trimmed to
   * `maxNeighbors` entries, so the final lists need not be symmetric.
   *
   * @param {{ id: *, embedding: number[], tags: string[] }[]} nodes
   * @returns {Map<*, { id: *, score: number, reason: { sim: number, tagScore: number } }[]>}
   */
  buildGraph(nodes) {
    const edges = new Map();
    const link = (from, to, score, reason) => {
      if (!edges.has(from)) edges.set(from, []);
      edges.get(from).push({ id: to, score, reason });
    };
    for (let i = 0; i < nodes.length; i++) {
      const ni = nodes[i];
      for (let j = i + 1; j < nodes.length; j++) {
        const nj = nodes[j];
        const sim = this.scoreSimilarity(ni.embedding, nj.embedding);
        const tagScore = this.tagOverlap(ni.tags, nj.tags);
        // Light tag bonus: +0.05 per shared tag, capped at 3 tags (+0.15).
        const hybrid = sim + Math.min(tagScore, 3) * 0.05;
        if (hybrid >= this.similarityThreshold) {
          link(ni.id, nj.id, hybrid, { sim, tagScore });
          link(nj.id, ni.id, hybrid, { sim, tagScore });
        }
      }
    }
    edges.forEach((list, id) => {
      list.sort((a, b) => b.score - a.score);
      edges.set(id, list.slice(0, this.maxNeighbors));
    });
    return edges;
  }

  /**
   * Search indexed documents by query, expanding from the most query-similar
   * documents through the similarity graph (breadth-first).
   *
   * A node's score is the max of its own query similarity and, for each edge
   * reaching it, the average of the source node's query similarity and the
   * edge score (scores below 0 are clamped to 0). `reasons` keeps the first
   * three scoring events; type 'seed' marks a node's direct query-similarity
   * score and is recorded for every expanded node, not only the initial seeds.
   *
   * @param {{ query: string, category?: string }} params - Query text and optional category filter.
   * @returns {Promise<{ results: Array<{ id: *, original_filename: string, snippet: string,
   *   category: string, created_at: *, score: number, hops: number, reasons: object[] }> }>}
   */
  async graphSearch({ query, category }) {
    const queryEmbedding = await embeddingService.embedText(query);

    // Load candidate docs
    const where = { is_indexed: true };
    if (category) where.category = category;
    const docs = await Document.findAll({
      where,
      attributes: ['id', 'original_filename', 'extracted_text', 'embeddings', 'tags', 'category', 'created_at']
    });

    const nodes = docs
      .filter((d) => Array.isArray(d.embeddings) && d.embeddings.length > 0)
      .map((d) => ({ id: d.id, embedding: d.embeddings, tags: d.tags || [], ref: d }));
    if (nodes.length === 0) {
      return { results: [] };
    }
    // O(1) id lookup for traversal and formatting (avoids repeated nodes.find scans).
    const nodeById = new Map(nodes.map((n) => [n.id, n]));

    // Seed scores: each node's direct similarity to the query (computed once, reused below).
    const seedScores = nodes.map((n) => ({
      id: n.id,
      score: this.scoreSimilarity(queryEmbedding, n.embedding)
    }));
    // Log similarity scores for debugging
    logger.info('Similarity scores:', seedScores.map((s) => ({ id: s.id, score: s.score.toFixed(4) })));
    const querySimilarity = new Map(seedScores.map((s) => [s.id, s.score]));
    const seeds = [...seedScores]
      .sort((a, b) => b.score - a.score)
      .slice(0, this.maxSeeds)
      .map((s) => s.id);

    const graph = this.buildGraph(nodes);

    // Expand neighborhoods from seeds
    const visited = new Set();
    const scored = new Map(); // id -> { score, hops, reasons }
    const pushScore = (id, candidate, meta) => {
      const prev = scored.get(id) || { score: 0, hops: Infinity, reasons: [] };
      scored.set(id, {
        score: Math.max(prev.score, candidate),
        hops: Math.min(prev.hops, meta.hops),
        // Keep only the first three reasons to bound the payload.
        reasons: prev.reasons.length < 3 ? [...prev.reasons, meta] : prev.reasons
      });
    };

    const queue = seeds.map((id) => ({ id, hops: 0 }));
    // A moving head index keeps dequeue O(1) (Array#shift is O(n)).
    for (let head = 0; head < queue.length && scored.size < this.maxScoredNodes; head++) {
      const { id, hops } = queue[head];
      if (visited.has(id) || hops > this.maxHops) continue;
      visited.add(id);

      const base = querySimilarity.get(id);
      pushScore(id, base, { type: 'seed', hops });

      for (const nbr of graph.get(id) || []) {
        const pathScore = (base + nbr.score) / 2;
        pushScore(nbr.id, pathScore, { type: 'edge', hops: hops + 1, via: id, edgeScore: nbr.score });
        if (!visited.has(nbr.id)) {
          queue.push({ id: nbr.id, hops: hops + 1 });
        }
      }
    }

    // Format results
    const ranked = Array.from(scored, ([id, info]) => {
      const ref = nodeById.get(id)?.ref;
      return {
        id,
        original_filename: ref?.original_filename,
        snippet: (ref?.extracted_text || '').slice(0, 400),
        category: ref?.category,
        created_at: ref?.created_at,
        score: Number(info.score.toFixed(4)),
        hops: info.hops,
        reasons: info.reasons
      };
    })
      .sort((a, b) => b.score - a.score)
      .slice(0, this.maxResults);

    return { results: ranked };
  }
}
// Export one shared instance; its constructor reads the GRAPH_RAG_* environment
// variables once, when this module is evaluated.
module.exports = new GraphRagService();