From 74931f235e47a5a8691c907018e40226f2674155 Mon Sep 17 00:00:00 2001 From: bolade Date: Sat, 30 Aug 2025 13:56:19 +0100 Subject: [PATCH] Refactor imports and enhance query functionality with LangGraph integration; update requirements for new dependencies --- app/__pycache__/main.cpython-312.pyc | Bin 2311 -> 2312 bytes app/db/__pycache__/tables.cpython-312.pyc | Bin 1295 -> 1299 bytes app/db/tables.py | 2 +- app/main.py | 17 +- app/pydantic_schemas.py | 2 +- .../langgraph_agent.cpython-312.pyc | Bin 11285 -> 7279 bytes app/services/langgraph_agent.py | 162 ++++++++++++++++++ app/services/openrouter.py | 3 +- requirements.txt | 4 + 9 files changed, 179 insertions(+), 11 deletions(-) diff --git a/app/__pycache__/main.cpython-312.pyc b/app/__pycache__/main.cpython-312.pyc index d9c2a4c1689218ae235fa426aae78fd34f019050..9cbd458af81d89c7020e6ba300163b1926a0d2e5 100644 GIT binary patch delta 889 zcmYjPO=uHA6rS1LG>OTk$tG!1+tNa*=@wgQQ4d=DX|Z1FKT;ZmWyy|dYO+~pHZ1|6 zK?L<=<6iXcq0kD7ced^+{KF1KQ{o;f#Eiew`g8OUfVHkYE=C`JMz{zV=LCmRFJx1(W1V$(hKN2)Il zA8P&m=v=L1WKe z3a@cj032?}qU9W6Ov|wuS27AYTlLcfcxcjDklNK&9)v0L!`sz3RN=OurBguA`S2#}uT?t``TDn7I&h*)`0 Y+CuO=DgNTmZxIQHJ-hQIkdb(?|H$R#`2YX_ delta 906 zcmZ8fO=uHA6rS1L{IvOPew*4t8nJ262CWCdR>42Si(0D`a#=!l+NLJE@omx)2sWso z$A&$q;6ZPNDkvVkcvH|rLdDjVP(crREB0bN`DW8V@xi?RdGF2l=0V_o!1KZ5=_N=@ ztylAp8=e@QZ6vRxDN%@GFL;Z-jE@q5Fvl8CLdy8jvUM#=n1A75V>M+9pptBSKDQj= zN+)+eSG;SqV=FEuDL&?!l@$Lysaq)%W)Y?D<|rX>0DIFyG!w;6(S@V?tL#+>C_yEp zgr7&&q>hJ@Dy`B~^cOLn(DRe-sS1PYHFm4Q^m2-tk?E?MyQD6nrGXVYGsjdO>Fav0 zf|g^tu7Pnxn4msmUbrzBz~VkUx(8XEd~*j^W}2>{rZfcp*I$Hpl=d3Wq?kR3VF(#- zrNLA`swM^9fim79=;OYgQWkV3=k!H5gm%cV**?Zi2Xto43FzyRh{|0&J2gYJ&nzoK zOzmlHsB)CP=TV$+7#@Zy+CKGLP%$Rnv_&N=43nEQEDsB@LFAKl9*b#{kviwHd InUr{=e-ZQDm;e9( diff --git a/app/db/__pycache__/tables.cpython-312.pyc b/app/db/__pycache__/tables.cpython-312.pyc index 2086f0dc2dd37b122cfe39748de2665e952bcbe3..8e6b31ce548b07a62420e563ee9e31229f4d306c 100644 GIT binary patch delta 51 zcmeC@n#{#}nwOW00SMZkZ_K#4k#{8vD`#Rsf!^fpEN((lU)cB=*(SJrWdKq|ia=EW Dg+2|v delta 47 zcmbQt)z8IynwOW00SNBz+n8~6BkxKUM%Kx@S=@xgzOeB#vQ2RL$^fK_6oD!LNp=iz diff --git a/app/db/tables.py b/app/db/tables.py index c66d162..3153b82 100644 --- a/app/db/tables.py +++ b/app/db/tables.py @@ -2,7 +2,7 @@ import datetime from sqlalchemy import Column, DateTime, Integer, String -from db.db import Base +from app.db.db import Base class InvestorTable(Base): diff --git a/app/main.py b/app/main.py index a09042e..400730f 100644 --- a/app/main.py +++ b/app/main.py @@ -1,11 +1,12 @@ import io import pandas as pd -from db.db import db_dependency, init_database +from app.db.db import db_dependency, init_database from fastapi import FastAPI, File, UploadFile -from services.openrouter import InvestorProcessor +from app.services.openrouter import InvestorProcessor -from app.services.querying import QueryProcessor +from app.pydantic_schemas import QueryRequest, QueryResponseList +from app.services.langgraph_agent import LangGraphQueryAgent app = FastAPI() @@ -31,11 +32,11 @@ async def parse_csv(db: db_dependency, file: UploadFile = File(...)): return {"results": [r.dict() for r in results]} -@app.post("/query") -async def query_investors(db: db_dependency, question: str): - processor = QueryProcessor(sql_session=db) - results = processor.process_query(question) - return {"results": [r.dict() for r in results]} +@app.post("/query", response_model=QueryResponseList) +async def query_investors(db: db_dependency, request: QueryRequest): + agent = LangGraphQueryAgent(sql_session=db) + result = agent.run(request.question) + return result if __name__ == "__main__": diff --git a/app/pydantic_schemas.py b/app/pydantic_schemas.py index 08588b5..af54a14 100644 --- a/app/pydantic_schemas.py +++ b/app/pydantic_schemas.py @@ -35,4 +35,4 @@ class QueryRequest(BaseModel): class QueryResponseList(BaseModel): - responses: List[QueryResponse] \ No newline at end of file + responses: List[QueryResponse] diff --git a/app/services/__pycache__/langgraph_agent.cpython-312.pyc b/app/services/__pycache__/langgraph_agent.cpython-312.pyc index fe524c0d6bcdaf52266d0021e07a6eb987a590c4..06c5a6dcea45774bcf19e7b18556d8b25035b0e2 100644 GIT binary patch literal 7279 zcmd5>U2GdycE0n=8IqzzNi=0imMBUJP0JBCZtScVj#Jz6PyXb@KMlgv8H_n2i86nb zJ0siDQZS8mU<(^S3c7&gHicCzuuU4byZe^M0{hm!h^7Ee23|zKwrKZ_YpK}QPd)dN zLt5%&oqbrKL+RXm&$;K$nfvpdbMO7vmKHC9C-$Geocfn`gnmN`&Lh+^%+}8!vw}3F zF&UI)k_vSL!Cd@kcjdO*%+q^z8jv)-gP>r48mOvw1NEyX^qxDh1q%s$}31mNk&6S_A+#{&S~s1;Ic^P zb>SiRB~w#LiW(0x!Bys=GLgvUJaNhPo|x2gCjAnW0b=><0sFJ_E1AdmxL$K)yg*K+)5SiK%>6PvkSRH&jE-Wz&hVd}=le z-(e))%j0W_x8EE~s5vvAQ*NZOK7rxw33X;BVd(gJI;9(ljGCLAB!j1@o<}l1_ zL`eWLMhwWu$nph#wD0So_*X+4L$4Kk)&m1&=C!Y52fwk=%W6;xmfNTC8k3hYhB*ux;Rp)t)jgpG{GJY=1@3EGXi z1vD)-Dm60GawFTwXq?7tLYjFfI?EW^sRIu}W{XqNBx6S(=Y0!YftwCG+-9nI;O%6{ zcRfz~4Eh#53G_O5DM9hMdNSydiM;FCp|^LJ5Q=8E!_z!}L|y55bv2JBKevwb9qPzt zuVdkk_WB)eGySrjg|TaMxN|7(9iR7&CQ@`hD%w>7oinOAWrxyr;8Zl-NZ~X^V#D@5 zQ<{It0oTYgT!YT)nl=FiSOvO~HX;)M5(WS;PcX-pvbw2is;Syuf?tNN;?$JgQird4 zQJ`xh9;{)Up&JGOqb=2p)vnrY@8}qsHsLL&GHIZ8TfG78s*)ND_$-==B5ZjZ=XJ^tG z-R1!Gz_0O!o|(WSL_SL70>wUqtknoSm<$$y>2xk_D$0Cs14fzoq4fcq`3$G467`Z>P!M*_RKibdSio)Ii>!! z*amyRDTO~9++h1QTiTbdEMEDsQXKktKgIe~J(PexK!_j?#YYx;a|KPJ1r}C1740NN zUSkSunz@MF<l*XM z#b;@a@u%Lm$C)F|)R`?~=!%Na0$*UK_cf(X&TROQ0`E3=fg5q^HdEk6ne`(R|O_s+pS7H5z(N zklle1xt>-dZ@vEJ=`Upb0;$ERvsrKlQ+n!}Vx;GF+i%puXzH`T`E<$RP*4qs=i8|Vumuti&S3&vwm6Z_m^#KJg0h87{s!c!K@fRg(=q8B ze3Fn;gpAc{rL$=hEY4(X5!{QO(=b`0_|FMB0VFQq??Yk>V6~~+9%pO7d}MPtf5RYw z8*w`|zk?4!alYTx<#KK?xIfev2S;zAz&_Xm*Ry?y##^hd8( zM-azXM^T`4>GI;`AHVkmwQxc)WpUoS$8R5h%m*!R%hLYE{g&)s3NMDOwvLs+rva-y zRB4Zv+M`ymbEW^&e#_sw^v2>FR(P+~-D~;d+Z|p({4dVW3snxZ2a2o9;?5@ANvpGk&pG3suaf{u%uY{C+3T3Cj-Y94}ox0<}OxEnQv z0*|K)0-kC*je?*F1pS=40A}wY^p{}8g1F!*cwAc+n_>Kd1hD7gN}Jo6f}{zWM1J4W znyfVfeW(NQMC`$hTtOt|90M(QANpK?>wN$&e+{?>QFHS&gfN`ayx)-BKsUMf(G4cv zGLBz`79ocAQ4mo~{U+@B<(#wKw`I-QiQBRhU>lf>Obu5&mvh&0`5QTW99Xu0N>?>F z1;Bto#(d9TP|S-M05BO9r_H%YI#=JQ;@3zMB=(F;_mqyi7-8a71vrfS$s+GevFwbLLVPcf>AQ-21m z`KT*S?TE~$pq}wRK#Qmz10lFh@ zCLV{z&Hn_GV+Ey2+|0Pe&A_cMY|v;t^N<@u%ghysF6e@HfrSEloxyDduE4@Fcaf7b zuur1jCM8Q}fj1lQ08xQU8^9$yu3ZnnhM%4wxX3`mB}#8-xX3`mC30_QxX3_*xc6O8 zL&HS|8ZME3L&HS|8ZHr4L&HS|8ZOaRL&HS|8ZMDoL&HV(ka1w{+s2iQ`ZDuLACF8j z#Pk6m{sU2}1MriuxA}~{h~nHhg`+szJerILC^*#^B!b!%6YYwk@M%)z3?XL;A!jJI zcwNoF*=s#~qEKavbd#^yeA?8rwxr*L!-jNja(4{%GU(5DyJ8b<;A4;*eKdU99F7%V z*bH=30?|?+x^}rTe53?_fg@We$caVG61{iRx6{klOX5DOtG9UO-pflz7mq&TLz}Ik zWo;$>X}Z)JDV}@WHEem?mtI_a(UOBpy^Fn%jl1xW5v9sT6+c31YhXEy-Vz`btut)g7&Lzf$Uc1=u}N${r|51G}}i!~VUZ zV}F?sRekV*+Xqlbc=`Iu!lw&=b*tPSvqItD37nU&B97z#Z@fY_<6b zMF{`%5Vn)mN@rr6uX#Nh#}n6Q`zF$Go1xIieAoF8hi7y1VYfL1f}eu2kpNOekG-9% zQybnG9M1Gc9-O;>?!o!{=RZGP?mbi-c`SFYeNY)bRvJFGA-}rLzDj#E&W?}AMY{WV z+UDG-e2~974gUTLJ(0rKh;gt{F=uhk=Ew6n9c-u0ntome#Pb1iG4laj2po*SJd2a~-gHIyEN;Q>-4 zgfwrA0l9;|W-h?qVlZaKMyj5EcHd@ee^r2#oNNtK4&8ce$-vthV_syVTOTq#?C{nP zdRTt@7{?y3y?`_u?;;l=6SMHGa1nyZ6thKhZU)vhCN2h(X@f;Vi1)z>T`NLjl+AqSgk26mO%x!Mk|1@>#B+SO5$TtykREnp|% zIm(JbYO6*Kz34om{ZG4OXE!~Sj@9mRj^pd_Mh5;35*$`JhGBkiKsR`v#rZ7CGh=v%Gg3k@WZe3oLW|2_lqs(*FVn(acr= literal 11285 zcmb_iYit`=cD}>m@GX&gQ18cBvL(@yNIUXVeyuFYvSVA0B`1!x-EwKpNE(?OGBZQV zVyGnDcoz!bcAfsHW$bRFEU<_cb^&jH6hVJw6E6^8_XiniOLm%cT_gpXe~jf|x9Otj zxpz2xiQY6t2jHFiIQPuuz4x5^opbK5tEy}Wp1=C%U(8SJLg-Um&>tq3+4@Jw+(8`T zs1TaLZ$pR*8)gh)dWH@gXN+NHh6$TyOq51ro-qRr9kPV2GgeY(4B5i=8GG0<;~-@w zR28nCsU~?-$QgFcxJcd{stMQ5)RMdpBz(d(@H?$C!|1h9teyIe*+Vmr+HZh29T|}7~`d-DPE37 zM489jq0%DriXw<}a!?zimKt+~tq^4HARf(79GWpak9e9la@0|%*K zL@@#r2va5d{Sw|;5}pMKucu_fNfjhNM_Z7hucm$TkudLzgkp;UIUt4w-&iCV3&X%= z-{pvO&3FFnm@gnI5z)UWNc^k>!|(;7QJ>69*M%T2`$7S6ZVoS&KY;(QH@YOb@KG+l%Q#%O>)trCX>?0%I@ii8%t3Qug*l3+XAPX;CI}5j^9QGTcB*?th^oQ9YAk`a@CUE<4|kQh}U@;){bmMvdr4j%Gw?c=HvuM zgbU6-zt>+8xd$}x7=0~@L zJod?W+q|MgkH0Hk#NAzUB&7n8c^m{w~$*fKgY{K zNg%37R_#Tl_+EWc*`ivR4Tv18^~OTMD)YQ7$m}c#i3|$Mh=otJgn1>v1(bltA+a!8 z$$}F+y=v`+=f)?eEUxfsfTsfNM?(Unsw<}% zxU2r4R?{@cD}K-wL5T<5=nqChAy6J5*p{?nGem9fysC- zSypPtLj$D()2sOX@mBq>7xuGPZ)^lo`DI96Lz&i|72~F*>EpWgkL%mg^*#{MjC&|k z+mmT(eqv?nZ7arGmM1Q>!;@wYr`W^qE^o5uahLW?OGmoJmum5)TTZ81PG@TR9=Y0+ zol{xb&`_OaQ0p0L3o)&jfivqu4J~PRPs-i1;qF^8{-uqGzG{+qSYa`k+srT|gs>10 z;d%vX?;s`LrsPFX-~7u|f?CiITGCldI1tE6BByV53PiSIKLhl`3tjenZ9f|N5hzgF zv7{eX(Pa-u>Hma`$XV75MXV%BxKyG6HS)DIw1Sm4c}i&O^xDE(7xZZm(RkZQy#26@ zu;<`ZAV)a62&c$wD6**9@Zl%(0Vz1IR^@hhuILO#!dC^6*Uv(aO~Nxo&JfiS;5fes zlCHY^1r}OOoxg;UR_O-$C}F2nI{_j1LA<0Te>5V1m=VT2H^)m@o2h2#6%|6f#~|&; zWs5%YVyTgyfu?w8C4niDy#XMSe+tQK=(D=^A1`dwc~&f06SCH%EzK!Q^XkC5G2M48 z)pu;uay(;k`V3^-LxFbR{)>;)0>vPkDQHZXJ^XUx#8TAwCvDii;*$rDml0C=6M}{Tl+wi z&LRrL%L<~vQ5fhf8w#QZn=U7I%QQzP3=8=)4G43OzD<4vX!R%bvJpfee}3sS2rcAK zr$SwV=8OrWK!XzTTS;SIG7g_53w+9o|;ctdSVp`_=a7K5I$hhsR?kL~)LQ|Q+BjVNJ)l`&)J znV_aoH&QzEPCJpbf7t>mY=>S`cypFHq|Ym7Rl4<-ghl86BlHu)92JKK;BUeF=WvdE z=?d5h%a&Jw_jE3|NG+lp^cA#7{T;nXdF*lTG!b{A5vDJmWAPb@Pj5IsC061?{B?kq zx#Ja}u7uG?;BDAHHn!}U9zQoedWqdukIC>EJW#P^k-sVn%2OJ(QymV~qVv8oo5K0; z*Uo=GIsYpY&+`EeV7UaqMy^Z|h=u)ua3m%wmD^`yA}7PyJR7Op9+YBYaGszbu#KcZ z@F{#kf_RPR1koRrBD0{_Dsz(~F)7IV6)C37>y+Cj^VHOZ^M!TVz6_U7k57$nug0r$ z?!x8qsXn|$9`@Xsv*YZrru6+VL1Fg7)V6OGI#pVs|05F&-bxnl8}lh%*Kgc;`;vWQ z7Uiq?jTuy0=n71>t;e}D=g(YX_wFfa_=?&ECJ&KrvXCkbN3+K+c|d|yb58Em`bsOP zHaDeL!gfo%7V9PLkFdnx9%d&! zwGy^kR0d!R)Uh~k=5O#puvSzPSiii;k$)rw$s|rL;N&@+T!cim zj^7CK#Pm>U!1EG5NIg~pHl#63h1H*G#TbXU3#18L2H1=J1!Tx7Edh|Cco`pgSWd`8 z;JkDUs^Z;ewppg$a?pec`=h zO{Rt>GkyK(zC)?LLzy1WJL21-{_^fkd&g4VvCM!kJ@Bp6z_$uryL!K9bhsP7q^oOc zvow4z+l6-eR>p4+rW|b6fn4qt#~+?pk*oPJGQi*L=zdGME38`{bZvC(PaZk{n_zNk z`q#4`zL-4xLh{h%RL5oL+xNzWO!qF#ZcSq1>nHM&oT|r$OwRm!8=R(uAbver2L|iw|+`n?&aVcKq}aVe+F=G znE{AptOUmdfEquepuZXO%yG7tmJ{UrA^--cCw{Ubx}tKP&cxX``;Y0L6-iKn)htvWKfG0yWi2-jYg~_5WD{YO3;}hI|P?jdR)j3h)L{Lvm%H z#uYzS1Zp(vqZHT3d>BxT5DbJuOROw}g-}2$oSj6-KgZUMhgPk+<3LueobGpnEO-L~ zA~@H8#7{e1?-SNw$jSLuAbjFa&@s*}< zA8gHOI?(+B@kDM@gJx+jrqR?AFM?YToGiTbJg$5JCzo+T)R?7QjhS<9PVyO*5~SUj z_S=xico}o5OLIj?$Dt783>a#acai^##^ZW9yyBn~`)wf2yHB)84f@ zk!fjDn@(iSsBvHm85^rtCO|QCuq%@pXT$B6KXP`jn;+OyeMjHD_)*`9kGlufFF*Kp zYH;Me9jU>wjqdTs$Whbq=GogLw{K+XyVCWZRJ|u_ur{}3I(o7+yy2T{@7l|$&Y=e* zA9Wtc(#W$X$1v4*;yvo4zLSr3bU*6s`8t}eOpo_*4RHLDw!0g$G<-MPhI;p}j3;Zm zQjTu$i)i|+is&;V{O=NdmZ|N%Pp^;MGp*f7y7qu)<#6>E$k17>sWcd7uZOVut)n}q z)vADpuGL1HshcNjPd1_V8w`+tzsWkNEin`q za#{tJEB|m*K{l5SAitGpK(FaW$ubQZfOyuIjRI(;gaLFEY$VW5riy)y9HaNGq&Ial z6wF4mOx@}Mvx!L%W9HTX=o=cigO2}y<8Dqc*D3f;f6;gL8c~rRW-94eeU!Ss0`tux z*C=^2cMcV~bc@VB(6>e~u&saEA)K&aZQBCjocmS!TtTnupD@c-IDsmd!j+uDN?sS| zM`=sg7Akm%5>}n!bubfOhO_E#=sWZe7zQod6L#I0#QOFix@8B`HE&9TmOruVc*-tu z7Iq2eQm~hjJPC*XwGu7w#9IDO&mb{p9k0Omh&chaUKS-6vZcXZXkp!X@-LvCNTd2vcAQB0`jS@=f!JcWKrz;+E2<_ zsE@b}hS~E0L4?B=NB?9ztxLJj!>eqlfCJ`u&D$`{YUd*vBdO2zB~00f2&47Rc?Z~O zp675X3B1gX46wWs^p;P7#Lq$SV3?iYBP2>71htTcuh;@(Ly;f}B$f?ITeKpJnyjUu zgB&uz=)4e$$dTy$l23D)ROC_Wd>Lj>!kA}^<1cJvmh3xypOkTdXo!se#^w1XAcMG! zPN!{I`2X;6O>9bfmK3yUg9yVQu+=E<>zp&A&4cJ1lWM062DQ>y`(UuNmN7x z=!ZY8j|UF~1Suf#1u;<~h0%H0DXdPQ3#gGLohHbaikgPj>NG7xRz^2jA@D@7*ffk( zHHHA^0Rjw!V!TYEy{d5m=1FC^SU4(6*gLH@EDGT7ftYbDh~s9QA3|}_m;zxo7<1$r zuK_%Q?W&5f9Wk8qs!Vm|Jnfv0OQkEoT=N3h0RZdOstHsz1jcx-E9q#1cWcjML|5&*Z31idPW;XITF-jdM&n?*aVXU| zw9&Z#wo$EX%nS}C>vyhQ`lx<))(q5uE?YhSi+-zP$-YA$wH+$&{{;GXCOduU&I75= z0~?)(Hd+p^oc*`Dj*oXAdGG4x?lZTUWJ6EN)thN-yY0xDsFt0Xou0LWS>)`mdeVWa zYkxTN#>}dHZFFra+0^%-X0vK2>DZS=^!}>sj=Bcd%Bk!C>gamw#k()2+mEK&k8U2D z+-Sc5Q_$X@ZXZgu4?Q@w(S9V|cx2`DqaJS_epVCsnXN-jNB{GQg#WhAE}4J!&+*r!}zeyit}v-Xn5E`;XG>{@1q}j_KkPbziKx@ z{ja(isQFbN4S8b8lwemJa3B_W7BZSeUI=U?0Ihq!5sKOxmT8A>1W~6RfUOCMrQ7j;JV)QPVQS*sQ{VX#p<7kg$|$*KWp;f{f!v zSVd?->?+VJ_z!3vvv}@|EWVlHEur{V=q)coLY!ZpHFafb8#7HUnb!79YiFj;{lso+ zwyjiwt<}Ect!M8(n{M-_+Pv$_8*RtZ4aZjOkL(RUJn_bf)kB;1ogX`DlJ&dRu07~W z^&Cm=I-0CKw&^(j$l3D4SKoLQBzSZ0$q%PC_dd7byqL6H%-T`o@vmeYRQjH8C_Xlu!#{Hlri{SO1Z41A$-!eFvDQdkAbP&9gdrv+g8EqtZ5JC9mAH0J2 z+3JJj=~s3*BX5FhHw;|6F_Q~b4bw|vaQh)3vB0K8ssSBqw*q#O4-`wZE_rey`Yv^Nl@49`%y)S=%L6e??&{#57 ztXJ@Ow;n+PSO37T7XmEE08m-V^UJ8@;DCoD)gQ1aVupVYhWz z*Nmes?ErsJ*V?{+a`a^xAoxngTJs045rlq;c8?6wzZ{?-h1q+$TEWN3(-p^G;8vV` z0Y7;cBv^4!%zmsm;FUGmnA6nSP!{1g+%$MXGMd2Sex622f|qJiz`zpcwA+VxZ)rCY z$vs4T_)3GA%7rfx2Im2K`n3ySz6TUKtsdiW|pQ+iGn15wp4R|WjU%{vpg>SH4Kn~pZqK&S(>7# ze?j)&p@Y9g2R}v5Pf+jgQ2!^W^%K None: + self.sql_session = sql_session + + # Setup Chroma collection + self.vector_db_client = vector_db_client or chromadb.PersistentClient( + path="./chroma_db" + ) + self.collection = self.vector_db_client.get_or_create_collection( + name="investor_descriptions", + metadata={ + "description": "Investor descriptions and investment thesis focus", + }, + ) + + # Build graph + graph = StateGraph(AgentState) + graph.add_node("sql_search", self._sql_search) + graph.add_node("vector_search", self._vector_search) + graph.add_node("merge", self._merge) + + # Parallel fan-out: START -> sql_search & vector_search -> merge -> END + graph.add_edge(START, "sql_search") + graph.add_edge(START, "vector_search") + graph.add_edge("sql_search", "merge") + graph.add_edge("vector_search", "merge") + graph.add_edge("merge", END) + + self.app = graph.compile() + + # Nodes + def _sql_search(self, state: AgentState) -> Dict[str, Any]: + results: List[QueryResponse] = [] + if self.sql_session is None: + return {"sql_results": results} + + # Simple LIKE-based search across a few fields + # Note: SQLite uses case-insensitive LIKE by default for ASCII. + q = ( + self.sql_session.query(InvestorTable) + .filter( + (func.lower(InvestorTable.name).like(f"%{state.question.lower()}%")) + | ( + func.lower(InvestorTable.sector_focus).like( + f"%{state.question.lower()}%" + ) + ) + | ( + func.lower(InvestorTable.stage_focus).like( + f"%{state.question.lower()}%" + ) + ) + | (func.lower(InvestorTable.region).like(f"%{state.question.lower()}%")) + ) + .limit(10) + ) + + for row in q.all(): + results.append( + QueryResponse( + name=row.name, + aum=row.aum, + check_size=row.check_size, + sector_focus=row.sector_focus, + stage_focus=row.stage_focus, + region=row.region, + investment_thesis="", + investor_description="", + reason="Matched SQL fields via LIKE", + ) + ) + + return {"sql_results": results} + + def _vector_search(self, state: AgentState) -> Dict[str, Any]: + results: List[QueryResponse] = [] + try: + q = self.collection.query(query_texts=[state.question], n_results=10) + # q has keys: ids, distances, documents, metadatas + docs = q.get("documents") or [] + metas = q.get("metadatas") or [] + if docs and metas: + for i, md in enumerate(metas[0]): + name = md.get("name", "Unknown") + results.append( + QueryResponse( + name=name, + aum=0, + check_size="", + sector_focus="", + stage_focus="", + region=md.get("headquarters", ""), + investment_thesis="", + investor_description=(docs[0][i] if docs[0] else ""), + reason="Vector similarity in Chroma", + ) + ) + except Exception: + # Best-effort; leave vector results empty on failure + pass + + return {"vector_results": results} + + def _merge(self, state: AgentState) -> Dict[str, Any]: + # Deduplicate by name, prefer SQL fields where available, keep first reason + merged: Dict[str, QueryResponse] = {} + + for item in state.vector_results + state.sql_results: + if item.name not in merged: + merged[item.name] = item + else: + existing = merged[item.name] + merged[item.name] = QueryResponse( + name=existing.name, + aum=existing.aum or item.aum, + check_size=existing.check_size or item.check_size, + sector_focus=existing.sector_focus or item.sector_focus, + stage_focus=existing.stage_focus or item.stage_focus, + region=existing.region or item.region, + investment_thesis=existing.investment_thesis + or item.investment_thesis, + investor_description=existing.investor_description + or item.investor_description, + reason=existing.reason or item.reason, + ) + + # Store back into sql_results to pass through the END with full state + return { + "sql_results": list(merged.values()), + "vector_results": [], + } + + # Public API + def run(self, question: str) -> QueryResponseList: + state = AgentState(question=question) + final_state: AgentState = self.app.invoke(state) + return QueryResponseList(responses=final_state.sql_results) diff --git a/app/services/openrouter.py b/app/services/openrouter.py index 6fa8a01..6a36a61 100644 --- a/app/services/openrouter.py +++ b/app/services/openrouter.py @@ -3,12 +3,13 @@ from typing import List, Optional import chromadb import pandas as pd -from db.tables import InvestorTable from langchain_core.prompts import PromptTemplate from langchain_openai import ChatOpenAI from pydantic_schemas import Investor, InvestorList from settings import settings +from app.db.tables import InvestorTable + # Add these imports for your databases # from sqlalchemy.ext.asyncio import AsyncSession # from your_vector_db import VectorDBClient diff --git a/requirements.txt b/requirements.txt index 10ba213..6a7dbd3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,9 +8,13 @@ chromadb>=0.4.0 # LLM integration openai>=1.0.0 +langchain>=0.2.0 +langchain-openai>=0.1.0 +langgraph>=0.2.0 # Environment management python-dotenv>=1.0.0 +pydantic-settings>=2.0.0 # Additional dependencies for data processing typing-extensions>=4.0.0